diff --git a/ipsae.py b/ipsae.py index 45f64de..9669c0b 100644 --- a/ipsae.py +++ b/ipsae.py @@ -431,49 +431,86 @@ def classify_chains(chains, residue_types): sys.exit() if boltz1: - # Boltz1 filenames: - # AURKA_TPX2_model_0.cif - # confidence_AURKA_TPX2_model_0.json - # pae_AURKA_TPX2_model_0.npz - # plddt_AURKA_TPX2_model_0.npz - + # Boltz1 / Boltz-2 filenames: + # .cif + # confidence_.json + # pae_.npz + # plddt_.npz + + plddt_file_path = pae_file_path.replace("pae", "plddt") + numres = len(residues) # number of residues we parsed from the structure - plddt_file_path=pae_file_path.replace("pae","plddt") + # --- plDDT handling --- if os.path.exists(plddt_file_path): - data_plddt=np.load(plddt_file_path) - plddt_boltz1=np.array(100.0*data_plddt['plddt']) - plddt = plddt_boltz1[np.ix_(token_array.astype(bool))] - cb_plddt = plddt_boltz1[np.ix_(token_array.astype(bool))] + data_plddt = np.load(plddt_file_path) + # Boltz plDDT is 0–1; convert to 0–100 like AF + plddt_boltz1 = np.array(100.0 * data_plddt["plddt"]) + M = len(plddt_boltz1) + n_common = min(M, numres) + + # Always make per-residue arrays of length numres + plddt = np.zeros(numres, dtype=float) + cb_plddt = np.zeros(numres, dtype=float) + plddt[:n_common] = plddt_boltz1[:n_common] + cb_plddt[:n_common] = plddt_boltz1[:n_common] + + if M != numres: + print( + f"WARNING (Boltz1): plDDT length ({M}) != number of residues ({numres}); " + f"using first {n_common} positions and zero-padding the rest." + ) else: - plddt = np.zeros(ntokens) - cb_plddt = np.zeros(ntokens) - + print("Boltz1 pLDDT file does not exist: ", plddt_file_path) + # fall back to zeros + plddt = np.zeros(numres, dtype=float) + cb_plddt = np.zeros(numres, dtype=float) + + # --- PAE handling --- if os.path.exists(pae_file_path): data_pae = np.load(pae_file_path) - pae_matrix_boltz1=np.array(data_pae['pae']) - pae_matrix = pae_matrix_boltz1[np.ix_(token_array.astype(bool), token_array.astype(bool))] - + pae_matrix_boltz1 = np.array(data_pae["pae"]) + M_pae = pae_matrix_boltz1.shape[0] + if pae_matrix_boltz1.shape[0] != pae_matrix_boltz1.shape[1]: + print("ERROR: Boltz1 PAE matrix is not square; quitting.") + sys.exit(1) + + n_common_pae = min(M_pae, numres) + + # Big default PAE so out-of-range residues are effectively ignored + pae_matrix = np.full((numres, numres), 99.0, dtype=float) + pae_matrix[:n_common_pae, :n_common_pae] = pae_matrix_boltz1[:n_common_pae, :n_common_pae] + + if M_pae != numres: + print( + f"WARNING (Boltz1): PAE matrix size ({M_pae}) != number of residues ({numres}); " + f"using top-left {n_common_pae}×{n_common_pae} block and filling the rest with 99 Å." + ) else: print("Boltz1 PAE file does not exist: ", pae_file_path) - sys.exit() - - summary_file_path=pae_file_path.replace("pae","confidence") - summary_file_path=summary_file_path.replace(".npz",".json") - iptm_boltz1= {chain1: {chain2: 0 for chain2 in unique_chains if chain1 != chain2} for chain1 in unique_chains} + sys.exit(1) + + # --- ipTM per chain pair from Boltz confidence json --- + summary_file_path = pae_file_path.replace("pae", "confidence").replace(".npz", ".json") + iptm_boltz1 = {chain1: {chain2: 0.0 for chain2 in unique_chains if chain1 != chain2} + for chain1 in unique_chains} + if os.path.exists(summary_file_path): - with open(summary_file_path, 'r') as file: + with open(summary_file_path, "r") as file: data_summary = json.load(file) - boltz1_chain_pair_iptm_data=data_summary['pair_chains_iptm'] - for chain1 in unique_chains: - nchain1= ord(chain1) - ord('A') # map A,B,C... to 0,1,2... - for chain2 in unique_chains: - if chain1 == chain2: continue - nchain2=ord(chain2) - ord('A') - iptm_boltz1[chain1][chain2]=boltz1_chain_pair_iptm_data[str(nchain1)][str(nchain2)] + boltz1_chain_pair_iptm_data = data_summary["pair_chains_iptm"] + for chain1 in unique_chains: + nchain1 = ord(chain1) - ord("A") # map A,B,C... to 0,1,2... + for chain2 in unique_chains: + if chain1 == chain2: + continue + nchain2 = ord(chain2) - ord("A") + iptm_boltz1[chain1][chain2] = boltz1_chain_pair_iptm_data[str(nchain1)][str(nchain2)] else: print("Boltz1 summary file does not exist: ", summary_file_path) + + if af3: # Example Alphafold3 server filenames # fold_aurka_0_tpx2_0_full_data_0.json @@ -967,4 +1004,4 @@ def classify_chains(chains, residue_types): chain1_residues = f'chain {chain1} and resi {contiguous_ranges(unique_residues_chain1[chain1][chain2])}' chain2_residues = f'chain {chain2} and resi {contiguous_ranges(unique_residues_chain2[chain1][chain2])}' PML.write(f'alias {chain_pair}, color gray80, all; color {color1}, {chain1_residues}; color {color2}, {chain2_residues}\n\n') - OUT.write("\n") + OUT.write("\n") \ No newline at end of file