99 changes: 68 additions & 31 deletions ipsae.py
@@ -431,49 +431,86 @@ def classify_chains(chains, residue_types):
sys.exit()

if boltz1:
    # Boltz1 / Boltz-2 filenames (e.g. stem = AURKA_TPX2_model_0):
    # <stem>.cif
    # confidence_<stem>.json
    # pae_<stem>.npz
    # plddt_<stem>.npz

    plddt_file_path = pae_file_path.replace("pae", "plddt")
    numres = len(residues)  # number of residues we parsed from the structure

    # --- plDDT handling ---
    if os.path.exists(plddt_file_path):
        data_plddt = np.load(plddt_file_path)
        # Boltz plDDT is 0–1; convert to 0–100 like AF
        plddt_boltz1 = np.array(100.0 * data_plddt["plddt"])
        M = len(plddt_boltz1)
        n_common = min(M, numres)

        # Always make per-residue arrays of length numres
        plddt = np.zeros(numres, dtype=float)
        cb_plddt = np.zeros(numres, dtype=float)
        plddt[:n_common] = plddt_boltz1[:n_common]
        cb_plddt[:n_common] = plddt_boltz1[:n_common]

        if M != numres:
            print(
                f"WARNING (Boltz1): plDDT length ({M}) != number of residues ({numres}); "
                f"using first {n_common} positions and zero-padding the rest."
            )
    else:
        print("Boltz1 pLDDT file does not exist: ", plddt_file_path)
        # fall back to zeros
        plddt = np.zeros(numres, dtype=float)
        cb_plddt = np.zeros(numres, dtype=float)

    # --- PAE handling ---
    if os.path.exists(pae_file_path):
        data_pae = np.load(pae_file_path)
        pae_matrix_boltz1 = np.array(data_pae["pae"])
        M_pae = pae_matrix_boltz1.shape[0]
        if pae_matrix_boltz1.shape[0] != pae_matrix_boltz1.shape[1]:
            print("ERROR: Boltz1 PAE matrix is not square; quitting.")
            sys.exit(1)

        n_common_pae = min(M_pae, numres)

        # Big default PAE so out-of-range residues are effectively ignored
        pae_matrix = np.full((numres, numres), 99.0, dtype=float)
        pae_matrix[:n_common_pae, :n_common_pae] = pae_matrix_boltz1[:n_common_pae, :n_common_pae]

        if M_pae != numres:
            print(
                f"WARNING (Boltz1): PAE matrix size ({M_pae}) != number of residues ({numres}); "
                f"using top-left {n_common_pae}×{n_common_pae} block and filling the rest with 99 Å."
            )
    else:
        print("Boltz1 PAE file does not exist: ", pae_file_path)
        sys.exit(1)

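Aside: the plddt_*.npz and pae_*.npz inputs handled above are ordinary numpy archives, so they can be sanity-checked before running the script. A minimal sketch, assuming a hypothetical model stem; the 'plddt' and 'pae' keys and the 0-1 plDDT scale are taken from the code above:

    import numpy as np

    stem = "AURKA_TPX2_model_0"  # hypothetical stem; substitute your own model name
    plddt = np.load(f"plddt_{stem}.npz")["plddt"]  # (N,) per-residue plDDT on a 0-1 scale
    pae = np.load(f"pae_{stem}.npz")["pae"]        # (N, N) predicted aligned error in Angstroms

    print("plDDT length:", len(plddt), "PAE shape:", pae.shape)
    # A mismatch with the residue count parsed from the CIF is what triggers the WARNING branches above.
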
    # --- ipTM per chain pair from Boltz confidence json ---
    summary_file_path = pae_file_path.replace("pae", "confidence").replace(".npz", ".json")
    iptm_boltz1 = {chain1: {chain2: 0.0 for chain2 in unique_chains if chain1 != chain2}
                   for chain1 in unique_chains}

    if os.path.exists(summary_file_path):
        with open(summary_file_path, "r") as file:
            data_summary = json.load(file)

        boltz1_chain_pair_iptm_data = data_summary["pair_chains_iptm"]
        for chain1 in unique_chains:
            nchain1 = ord(chain1) - ord("A")  # map A,B,C... to 0,1,2...
            for chain2 in unique_chains:
                if chain1 == chain2:
                    continue
                nchain2 = ord(chain2) - ord("A")
                iptm_boltz1[chain1][chain2] = boltz1_chain_pair_iptm_data[str(nchain1)][str(nchain2)]
    else:
        print("Boltz1 summary file does not exist: ", summary_file_path)



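Aside: the str(nchain1)/str(nchain2) lookups above imply that pair_chains_iptm in the Boltz confidence JSON is a nested mapping keyed by stringified chain indices. The sketch below is illustrative only; the key layout follows that inference and the numeric values are placeholders:

    import json

    # Placeholder content mimicking confidence_<stem>.json for a two-chain (A, B) model.
    data_summary = json.loads("""
    {
      "pair_chains_iptm": {
        "0": {"0": 0.92, "1": 0.83},
        "1": {"0": 0.83, "1": 0.95}
      }
    }
    """)

    # Chain A vs chain B, mirroring iptm_boltz1["A"]["B"] in the loop above.
    print(data_summary["pair_chains_iptm"]["0"]["1"])  # 0.83
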
if af3:
    # Example Alphafold3 server filenames
    # fold_aurka_0_tpx2_0_full_data_0.json
@@ -967,4 +1004,4 @@ def classify_chains(chains, residue_types):
chain1_residues = f'chain {chain1} and resi {contiguous_ranges(unique_residues_chain1[chain1][chain2])}'
chain2_residues = f'chain {chain2} and resi {contiguous_ranges(unique_residues_chain2[chain1][chain2])}'
PML.write(f'alias {chain_pair}, color gray80, all; color {color1}, {chain1_residues}; color {color2}, {chain2_residues}\n\n')
OUT.write("\n")