Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions ipsae.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,17 @@
af3 = True
boltz1 = False
cif = True
elif ".cif" in pdb_path and pae_file_path.endswith(".npz"):
pdb_stem=pdb_path.replace(".cif","")
path_stem = f'{pdb_path.replace(".cif","")}_{pae_string}_{dist_string}'
elif (".cif" in pdb_path or ".pdb" in pdb_path) and pae_file_path.endswith(".npz"):
pdb_stem=pdb_path.replace(".cif","").replace(".pdb","")
# Determine extension for replacement
ext = ".cif" if ".cif" in pdb_path else ".pdb"
path_stem = f'{pdb_path.replace(ext,"")}_{pae_string}_{dist_string}'

af2 = False
af3 = False
boltz1 = True
cif = True
cif = False # Set to False if using PDB so it uses the PDB parser
if ext == ".cif": cif = True
else:
print("Wrong PDB or PAE file type ", pdb_path)
sys.exit()
Expand Down Expand Up @@ -216,7 +220,10 @@ def parse_cif_atom_line(line,fielddict):
atom_num = linelist[ fielddict['id'] ]
atom_name = linelist[ fielddict['label_atom_id'] ]
residue_name = linelist[ fielddict['label_comp_id'] ]
chain_id = linelist[ fielddict['label_asym_id'] ]
if 'auth_asym_id' in fielddict:
chain_id = linelist[ fielddict['auth_asym_id'] ]
else:
chain_id = linelist[ fielddict['label_asym_id'] ]
residue_seq_num = linelist[ fielddict['label_seq_id'] ]
x = linelist[ fielddict['Cartn_x'] ]
y = linelist[ fielddict['Cartn_y'] ]
Expand Down Expand Up @@ -441,7 +448,12 @@ def classify_chains(chains, residue_types):
plddt_file_path=pae_file_path.replace("pae","plddt")
if os.path.exists(plddt_file_path):
data_plddt=np.load(plddt_file_path)
plddt_boltz1=np.array(100.0*data_plddt['plddt'])
raw_plddt = data_plddt['plddt']
# Only multiply by 100 if the max value is <= 1.0 (meaning it's normalized)
if np.max(raw_plddt) <= 1.0:
plddt_boltz1 = np.array(100.0 * raw_plddt)
else:
plddt_boltz1 = np.array(raw_plddt)
plddt = plddt_boltz1[np.ix_(token_array.astype(bool))]
cb_plddt = plddt_boltz1[np.ix_(token_array.astype(bool))]
else:
Expand All @@ -463,8 +475,12 @@ def classify_chains(chains, residue_types):
if os.path.exists(summary_file_path):
with open(summary_file_path, 'r') as file:
data_summary = json.load(file)

boltz1_chain_pair_iptm_data=data_summary['pair_chains_iptm']
if 'pair_chains_iptm' in data_summary:
boltz1_chain_pair_iptm_data=data_summary['pair_chains_iptm']
else:
# Boltz2 specific or missing key fallback
print(f"Warning: 'pair_chains_iptm' key not found in {summary_file_path}. ipTM scores will be 0.")
boltz1_chain_pair_iptm_data = {}
for chain1 in unique_chains:
nchain1= ord(chain1) - ord('A') # map A,B,C... to 0,1,2...
for chain2 in unique_chains:
Expand Down Expand Up @@ -967,4 +983,4 @@ def classify_chains(chains, residue_types):
chain1_residues = f'chain {chain1} and resi {contiguous_ranges(unique_residues_chain1[chain1][chain2])}'
chain2_residues = f'chain {chain2} and resi {contiguous_ranges(unique_residues_chain2[chain1][chain2])}'
PML.write(f'alias {chain_pair}, color gray80, all; color {color1}, {chain1_residues}; color {color2}, {chain2_residues}\n\n')
OUT.write("\n")
OUT.write("\n")