Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 44 additions & 48 deletions ipsae.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
else:
print("Wrong PDB or PAE file type ", pdb_path)
sys.exit()

file_path = path_stem + ".txt"
file2_path = path_stem + "_byres.txt"
pml_path = path_stem + ".pml"
Expand All @@ -96,7 +96,7 @@

# Define the ptm and d0 functions
def ptm_func(x, d0):
    """Return the pTM-style weight ``1 / (1 + (x / d0)**2)``.

    The result is 1.0 when ``x`` is 0, 0.5 when ``x`` equals ``d0``, and
    falls towards 0 as ``x`` grows relative to ``d0``.

    Args:
        x: Error-like value to score (the visible call sites pass PAE
            values). The expression is plain arithmetic, so a NumPy array
            also works element-wise; the file additionally wraps this
            function with ``np.vectorize`` as ``ptm_func_vec``.
        d0: Normalising scale. Must be non-zero; a Python scalar 0 raises
            ``ZeroDivisionError``.

    Returns:
        The weight as a float (or an array when ``x`` is an array).
    """
    return 1.0 / (1 + (x / d0) ** 2.0)
ptm_func_vec=np.vectorize(ptm_func) # vector version

# Define the d0 functions for numbers and arrays; minimum value = 1.0; from Yang and Skolnick, PROTEINS: Structure, Function, and Bioinformatics 57:702–710 (2004)
Expand Down Expand Up @@ -166,12 +166,12 @@ def parse_cif_atom_line(line,fielddict):
#HETATM 1307 C C . TPO A 1 160 ? -2.115 -11.668 12.263 1.00 96.19 160 A 1
#HETATM 1308 O O . TPO A 1 160 ? -1.790 -11.556 11.113 1.00 95.75 160 A 1
# ...
#HETATM 2608 P PG . ATP C 3 . ? -6.858 4.182 10.275 1.00 84.94 1 C 1
#HETATM 2609 O O1G . ATP C 3 . ? -6.178 5.238 11.074 1.00 75.56 1 C 1
#HETATM 2610 O O2G . ATP C 3 . ? -5.889 3.166 9.748 1.00 75.15 1 C 1
#HETATM 2608 P PG . ATP C 3 . ? -6.858 4.182 10.275 1.00 84.94 1 C 1
#HETATM 2609 O O1G . ATP C 3 . ? -6.178 5.238 11.074 1.00 75.56 1 C 1
#HETATM 2610 O O2G . ATP C 3 . ? -5.889 3.166 9.748 1.00 75.15 1 C 1
# ...
#HETATM 2639 MG MG . MG D 4 . ? -7.262 2.709 4.825 1.00 91.47 1 D 1
#HETATM 2640 MG MG . MG E 5 . ? -4.994 2.251 8.755 1.00 85.96 1 E 1
#HETATM 2639 MG MG . MG D 4 . ? -7.262 2.709 4.825 1.00 91.47 1 D 1
#HETATM 2640 MG MG . MG E 5 . ? -4.994 2.251 8.755 1.00 85.96 1 E 1


# Boltz1 mmcif files (in non-standard order)
Expand Down Expand Up @@ -211,7 +211,7 @@ def parse_cif_atom_line(line,fielddict):
#HETATM 2665 O O1B . ATP . 1 ? C -7.04640 8.36577 -7.14326 1 3 C ATP 1 1
#HETATM 2666 O O2B . ATP . 1 ? C -5.79036 7.13926 -5.33995 1 3 C ATP 1 1


linelist = line.split()
atom_num = linelist[ fielddict['id'] ]
atom_name = linelist[ fielddict['label_atom_id'] ]
Expand Down Expand Up @@ -248,7 +248,7 @@ def parse_cif_atom_line(line,fielddict):
def contiguous_ranges(numbers):
if not numbers: # Check if the set is empty
return

sorted_numbers = sorted(numbers) # Sort the numbers
start = sorted_numbers[0]
end = start
Expand All @@ -266,7 +266,7 @@ def format_range(start, end):
else:
ranges.append(format_range(start, end))
start = end = number

# Append the last range after the loop
ranges.append(format_range(start, end))

Expand All @@ -290,7 +290,7 @@ def init_chainpairdict_set(chainlist):
def classify_chains(chains, residue_types):
nuc_residue_set = {"DA", "DC", "DT", "DG", "A", "C", "U", "G"}
chain_types = {}

# Get unique chains and iterate over them
unique_chains = np.unique(chains)
for chain in unique_chains:
Expand All @@ -300,10 +300,10 @@ def classify_chains(chains, residue_types):
chain_residues = residue_types[indices]
# Count nucleic acid residues
nuc_count = sum(residue in nuc_residue_set for residue in chain_residues)

# Determine if the chain is a nucleic acid or protein
chain_types[chain] = 'nucleic_acid' if nuc_count > 0 else 'protein'

return chain_types


Expand All @@ -318,7 +318,7 @@ def classify_chains(chains, residue_types):

# For af3 and boltz1: need mask to identify CA atom tokens in plddt vector and pae matrix;
# Skip ligand atom tokens and non-CA-atom tokens in PTMs (those not in residue_set)
token_mask=list()
token_mask=list()
residue_set= {"ALA", "ARG", "ASN", "ASP", "CYS",
"GLN", "GLU", "GLY", "HIS", "ILE",
"LEU", "LYS", "MET", "PHE", "PRO",
Expand All @@ -335,7 +335,7 @@ def classify_chains(chains, residue_types):
(atomsite,fieldname)=line.split(".")
atomsitefield_dict[fieldname]=atomsitefield_num
atomsitefield_num += 1

if line.startswith("ATOM") or line.startswith("HETATM"):
if cif:
atom=parse_cif_atom_line(line, atomsitefield_dict)
Expand Down Expand Up @@ -394,49 +394,49 @@ def classify_chains(chains, residue_types):
chain_pair_type[chain1][chain2]='nucleic_acid'
else:
chain_pair_type[chain1][chain2]='protein'

# Calculate distance matrix using NumPy broadcasting
distances = np.sqrt(((coordinates[:, np.newaxis, :] - coordinates[np.newaxis, :, :])**2).sum(axis=2))

# Load AF2, AF3, or BOLTZ1 data and extract plddt and pae_matrix (and ptm_matrix if available)
if af2:


if os.path.exists(pae_file_path):
if pae_file_path.endswith('.pkl'):
data = np.load(pae_file_path, allow_pickle=True)
else:
with open(pae_file_path, 'r') as file:
data = json.load(file)

if 'iptm' in data: iptm_af2 = float(data['iptm'])
else: iptm_af2=-1.0
if 'ptm' in data: ptm_af2 = float(data['ptm'])
else: ptm_af2=-1.0

if 'plddt' in data:
plddt = np.array(data['plddt'])
cb_plddt = np.array(data['plddt']) # for pDockQ
else:
plddt = np.zeros(numres)
cb_plddt = np.zeros(numres)

if 'pae' in data:
pae_matrix = np.array(data['pae'])
elif 'predicted_aligned_error' in data:
pae_matrix=np.array(data['predicted_aligned_error'])

else:
print("AF2 PAE file does not exist: ", pae_file_path)
sys.exit()

if boltz1:
# Boltz1 filenames:
# AURKA_TPX2_model_0.cif
# confidence_AURKA_TPX2_model_0.json
# pae_AURKA_TPX2_model_0.npz
# plddt_AURKA_TPX2_model_0.npz


plddt_file_path=pae_file_path.replace("pae","plddt")
if os.path.exists(plddt_file_path):
Expand All @@ -447,7 +447,7 @@ def classify_chains(chains, residue_types):
else:
plddt = np.zeros(ntokens)
cb_plddt = np.zeros(ntokens)

if os.path.exists(pae_file_path):
data_pae = np.load(pae_file_path)
pae_matrix_boltz1=np.array(data_pae['pae'])
Expand All @@ -456,7 +456,7 @@ def classify_chains(chains, residue_types):
else:
print("Boltz1 PAE file does not exist: ", pae_file_path)
sys.exit()

summary_file_path=pae_file_path.replace("pae","confidence")
summary_file_path=summary_file_path.replace(".npz",".json")
iptm_boltz1= {chain1: {chain2: 0 for chain2 in unique_chains if chain1 != chain2} for chain1 in unique_chains}
Expand All @@ -465,11 +465,9 @@ def classify_chains(chains, residue_types):
data_summary = json.load(file)

boltz1_chain_pair_iptm_data=data_summary['pair_chains_iptm']
for chain1 in unique_chains:
nchain1= ord(chain1) - ord('A') # map A,B,C... to 0,1,2...
for chain2 in unique_chains:
for nchain1, chain1 in enumerate(unique_chains):
for nchain2, chain2 in enumerate(unique_chains):
if chain1 == chain2: continue
nchain2=ord(chain2) - ord('A')
iptm_boltz1[chain1][chain2]=boltz1_chain_pair_iptm_data[str(nchain1)][str(nchain2)]
else:
print("Boltz1 summary file does not exist: ", summary_file_path)
Expand All @@ -493,7 +491,7 @@ def classify_chains(chains, residue_types):
atom_plddts=np.array(data['atom_plddts'])
plddt=atom_plddts[CA_atom_num] # pull out residue plddts from Calpha atoms
cb_plddt=atom_plddts[CB_atom_num] # pull out residue plddts from Cbeta atoms for pDockQ

# Get pairwise residue PAE matrix by identifying one token per protein residue.
# Modified residues have separate tokens for each atom, so need to pull out Calpha atom as token
# Skip ligands
Expand All @@ -502,7 +500,7 @@ def classify_chains(chains, residue_types):
else:
print("no PAE data in AF3 json file; quitting")
sys.exit()

# Set pae_matrix for AF3 from subset of full PAE matrix from json file
token_array=np.array(token_mask)
pae_matrix = pae_matrix_af3[np.ix_(token_array.astype(bool), token_array.astype(bool))]
Expand All @@ -519,11 +517,9 @@ def classify_chains(chains, residue_types):
with open(summary_file_path,'r') as file:
data_summary=json.load(file)
af3_chain_pair_iptm_data=data_summary['chain_pair_iptm']
for chain1 in unique_chains:
nchain1= ord(chain1) - ord('A') # map A,B,C... to 0,1,2...
for chain2 in unique_chains:
for nchain1, chain1 in enumerate(unique_chains):
for nchain2, chain2 in enumerate(unique_chains):
if chain1 == chain2: continue
nchain2=ord(chain2) - ord('A')
iptm_af3[chain1][chain2]=af3_chain_pair_iptm_data[nchain1][nchain2]
else:
print("AF3 summary file does not exist: ", summary_file_path)
Expand All @@ -536,7 +532,7 @@ def classify_chains(chains, residue_types):
# ipsae_d0chn = calculate ipsae from PAEs with PAE cutoff; d0 = numres in chain pair = len(chain1) + len(chain2)
# ipsae_d0dom = calculate ipsae from PAEs with PAE cutoff; d0 from number of residues in chain1 and chain2 that have interchain PAE<cutoff
# ipsae_d0res = calculate ipsae from PAEs with PAE cutoff; d0 from number of residues in chain2 that have interchain PAE<cutoff given residue in chain1
#
#
# for each chain_pair iptm/ipsae, there is (for example)
# ipsae_d0res_byres = by-residue array;
# ipsae_d0res_asym = asymmetric pair value (A->B is different from B->A)
Expand Down Expand Up @@ -617,7 +613,7 @@ def classify_chains(chains, residue_types):

for residue in chain2residues:
pDockQ_unique_residues[chain1][chain2].add(residue)

if npairs>0:
nres=len(list(pDockQ_unique_residues[chain1][chain2]))
mean_plddt= cb_plddt[ list(pDockQ_unique_residues[chain1][chain2])].mean()
Expand All @@ -628,7 +624,7 @@ def classify_chains(chains, residue_types):
x=0.0
pDockQ[chain1][chain2]=0.0
nres=0

# pDockQ2

for chain1 in unique_chains:
Expand All @@ -646,7 +642,7 @@ def classify_chains(chains, residue_types):
pae_list=pae_matrix[i][valid_pairs]
pae_list_ptm=ptm_func_vec(pae_list,10.0)
sum += pae_list_ptm.sum()

if npairs>0:
nres=len(list(pDockQ_unique_residues[chain1][chain2]))
mean_plddt= cb_plddt[ list(pDockQ_unique_residues[chain1][chain2])].mean()
Expand All @@ -658,16 +654,16 @@ def classify_chains(chains, residue_types):
x=0.0
nres=0
pDockQ2[chain1][chain2]=0.0

# LIS

for chain1 in unique_chains:
for chain2 in unique_chains:
if chain1==chain2: continue

mask = (chains[:, None] == chain1) & (chains[None, :] == chain2) # Select residues for (chain1, chain2)
selected_pae = pae_matrix[mask] # Get PAE values for this pair

if selected_pae.size > 0: # Ensure we have values
valid_pae = selected_pae[selected_pae <= 12] # Apply the threshold
if valid_pae.size > 0:
Expand Down Expand Up @@ -714,7 +710,7 @@ def classify_chains(chains, residue_types):
for j in np.where(valid_pairs_ipsae)[0]:
jresnum=residues[j]['resnum']
unique_residues_chain2[chain1][chain2].add(jresnum)

# Track unique residues contributing to iptm in interface
valid_pairs = (chains == chain2) & (pae_matrix[i] < pae_cutoff) & (distances[i] < dist_cutoff)
dist_valid_pair_counts[chain1][chain2] += np.sum(valid_pairs)
Expand Down Expand Up @@ -748,7 +744,7 @@ def classify_chains(chains, residue_types):

n0res_byres[chain1][chain2] = n0res_byres_all
d0res_byres[chain1][chain2] = d0res_byres_all

for i in range(numres):
if chains[i] != chain1:
continue
Expand Down Expand Up @@ -777,7 +773,7 @@ def classify_chains(chains, residue_types):
f'{ipsae_d0res_byres[chain1][chain2][i]:8.4f}\n'
)
OUT2.write(outstring)

# Compute interchain ipTM and ipSAE for each chain pair
for chain1 in unique_chains:
for chain2 in unique_chains:
Expand Down Expand Up @@ -860,7 +856,7 @@ def classify_chains(chains, residue_types):
d0res_max[chain1][chain2]=maxd0
d0res_max[chain2][chain1]=maxd0


chaincolor={'A':'magenta', 'B':'marine', 'C':'lime', 'D':'orange',
'E':'yellow', 'F':'cyan', 'G':'lightorange', 'H':'pink',
'I':'deepteal', 'J':'forest', 'K':'lightblue', 'L':'slate',
Expand Down Expand Up @@ -904,7 +900,7 @@ def classify_chains(chains, residue_types):
if af2: iptm_af = iptm_af2 # same for all chain pairs in entry
if af3: iptm_af = iptm_af3[chain1][chain2] # symmetric value for each chain pair
if boltz1: iptm_af=iptm_boltz1[chain1][chain2]

outstring=f'{chain1} {chain2} {pae_string:3} {dist_string:3} {"asym":5} ' + (
f'{ipsae_d0res_asym[chain1][chain2]:8.6f} '
f'{ipsae_d0chn_asym[chain1][chain2]:8.6f} '
Expand Down Expand Up @@ -962,7 +958,7 @@ def classify_chains(chains, residue_types):
f'{pdb_stem}\n')
OUT.write(outstring)
PML.write("# " + outstring)

chain_pair= f'color_{chain1}_{chain2}'
chain1_residues = f'chain {chain1} and resi {contiguous_ranges(unique_residues_chain1[chain1][chain2])}'
chain2_residues = f'chain {chain2} and resi {contiguous_ranges(unique_residues_chain2[chain1][chain2])}'
Expand Down