diff --git a/ipsae.py b/ipsae.py index 45f64de..cbfae06 100644 --- a/ipsae.py +++ b/ipsae.py @@ -84,7 +84,7 @@ else: print("Wrong PDB or PAE file type ", pdb_path) sys.exit() - + file_path = path_stem + ".txt" file2_path = path_stem + "_byres.txt" pml_path = path_stem + ".pml" @@ -96,7 +96,7 @@ # Define the ptm and d0 functions def ptm_func(x,d0): - return 1.0/(1+(x/d0)**2.0) + return 1.0/(1+(x/d0)**2.0) ptm_func_vec=np.vectorize(ptm_func) # vector version # Define the d0 functions for numbers and arrays; minimum value = 1.0; from Yang and Skolnick, PROTEINS: Structure, Function, and Bioinformatics 57:702–710 (2004) @@ -166,12 +166,12 @@ def parse_cif_atom_line(line,fielddict): #HETATM 1307 C C . TPO A 1 160 ? -2.115 -11.668 12.263 1.00 96.19 160 A 1 #HETATM 1308 O O . TPO A 1 160 ? -1.790 -11.556 11.113 1.00 95.75 160 A 1 # ... - #HETATM 2608 P PG . ATP C 3 . ? -6.858 4.182 10.275 1.00 84.94 1 C 1 - #HETATM 2609 O O1G . ATP C 3 . ? -6.178 5.238 11.074 1.00 75.56 1 C 1 - #HETATM 2610 O O2G . ATP C 3 . ? -5.889 3.166 9.748 1.00 75.15 1 C 1 + #HETATM 2608 P PG . ATP C 3 . ? -6.858 4.182 10.275 1.00 84.94 1 C 1 + #HETATM 2609 O O1G . ATP C 3 . ? -6.178 5.238 11.074 1.00 75.56 1 C 1 + #HETATM 2610 O O2G . ATP C 3 . ? -5.889 3.166 9.748 1.00 75.15 1 C 1 # ... - #HETATM 2639 MG MG . MG D 4 . ? -7.262 2.709 4.825 1.00 91.47 1 D 1 - #HETATM 2640 MG MG . MG E 5 . ? -4.994 2.251 8.755 1.00 85.96 1 E 1 + #HETATM 2639 MG MG . MG D 4 . ? -7.262 2.709 4.825 1.00 91.47 1 D 1 + #HETATM 2640 MG MG . MG E 5 . ? -4.994 2.251 8.755 1.00 85.96 1 E 1 # Boltz1 mmcif files (in non-standard order)) @@ -211,7 +211,7 @@ def parse_cif_atom_line(line,fielddict): #HETATM 2665 O O1B . ATP . 1 ? C -7.04640 8.36577 -7.14326 1 3 C ATP 1 1 #HETATM 2666 O O2B . ATP . 1 ? C -5.79036 7.13926 -5.33995 1 3 C ATP 1 1 - + linelist = line.split() atom_num = linelist[ fielddict['id'] ] atom_name = linelist[ fielddict['label_atom_id'] ] @@ -248,7 +248,7 @@ def parse_cif_atom_line(line,fielddict): def contiguous_ranges(numbers): if not numbers: # Check if the set is empty return - + sorted_numbers = sorted(numbers) # Sort the numbers start = sorted_numbers[0] end = start @@ -266,7 +266,7 @@ def format_range(start, end): else: ranges.append(format_range(start, end)) start = end = number - + # Append the last range after the loop ranges.append(format_range(start, end)) @@ -290,7 +290,7 @@ def init_chainpairdict_set(chainlist): def classify_chains(chains, residue_types): nuc_residue_set = {"DA", "DC", "DT", "DG", "A", "C", "U", "G"} chain_types = {} - + # Get unique chains and iterate over them unique_chains = np.unique(chains) for chain in unique_chains: @@ -300,10 +300,10 @@ def classify_chains(chains, residue_types): chain_residues = residue_types[indices] # Count nucleic acid residues nuc_count = sum(residue in nuc_residue_set for residue in chain_residues) - + # Determine if the chain is a nucleic acid or protein chain_types[chain] = 'nucleic_acid' if nuc_count > 0 else 'protein' - + return chain_types @@ -318,7 +318,7 @@ def classify_chains(chains, residue_types): # For af3 and boltz1: need mask to identify CA atom tokens in plddt vector and pae matrix; # Skip ligand atom tokens and non-CA-atom tokens in PTMs (those not in residue_set) -token_mask=list() +token_mask=list() residue_set= {"ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE", "LEU", "LYS", "MET", "PHE", "PRO", @@ -335,7 +335,7 @@ def classify_chains(chains, residue_types): (atomsite,fieldname)=line.split(".") atomsitefield_dict[fieldname]=atomsitefield_num atomsitefield_num += 1 - + if line.startswith("ATOM") or line.startswith("HETATM"): if cif: atom=parse_cif_atom_line(line, atomsitefield_dict) @@ -394,13 +394,13 @@ def classify_chains(chains, residue_types): chain_pair_type[chain1][chain2]='nucleic_acid' else: chain_pair_type[chain1][chain2]='protein' - + # Calculate distance matrix using NumPy broadcasting distances = np.sqrt(((coordinates[:, np.newaxis, :] - coordinates[np.newaxis, :, :])**2).sum(axis=2)) # Load AF2, AF3, or BOLTZ1 data and extract plddt and pae_matrix (and ptm_matrix if available) if af2: - + if os.path.exists(pae_file_path): if pae_file_path.endswith('.pkl'): @@ -408,35 +408,35 @@ def classify_chains(chains, residue_types): else: with open(pae_file_path, 'r') as file: data = json.load(file) - + if 'iptm' in data: iptm_af2 = float(data['iptm']) else: iptm_af2=-1.0 if 'ptm' in data: ptm_af2 = float(data['ptm']) else: ptm_af2=-1.0 - + if 'plddt' in data: plddt = np.array(data['plddt']) cb_plddt = np.array(data['plddt']) # for pDockQ else: plddt = np.zeros(numres) cb_plddt = np.zeros(numres) - + if 'pae' in data: pae_matrix = np.array(data['pae']) elif 'predicted_aligned_error' in data: pae_matrix=np.array(data['predicted_aligned_error']) - + else: print("AF2 PAE file does not exist: ", pae_file_path) sys.exit() - + if boltz1: # Boltz1 filenames: # AURKA_TPX2_model_0.cif # confidence_AURKA_TPX2_model_0.json # pae_AURKA_TPX2_model_0.npz # plddt_AURKA_TPX2_model_0.npz - + plddt_file_path=pae_file_path.replace("pae","plddt") if os.path.exists(plddt_file_path): @@ -447,7 +447,7 @@ def classify_chains(chains, residue_types): else: plddt = np.zeros(ntokens) cb_plddt = np.zeros(ntokens) - + if os.path.exists(pae_file_path): data_pae = np.load(pae_file_path) pae_matrix_boltz1=np.array(data_pae['pae']) @@ -456,7 +456,7 @@ def classify_chains(chains, residue_types): else: print("Boltz1 PAE file does not exist: ", pae_file_path) sys.exit() - + summary_file_path=pae_file_path.replace("pae","confidence") summary_file_path=summary_file_path.replace(".npz",".json") iptm_boltz1= {chain1: {chain2: 0 for chain2 in unique_chains if chain1 != chain2} for chain1 in unique_chains} @@ -465,11 +465,9 @@ def classify_chains(chains, residue_types): data_summary = json.load(file) boltz1_chain_pair_iptm_data=data_summary['pair_chains_iptm'] - for chain1 in unique_chains: - nchain1= ord(chain1) - ord('A') # map A,B,C... to 0,1,2... - for chain2 in unique_chains: + for nchain1, chain1 in enumerate(unique_chains): + for nchain2, chain2 in enumerate(unique_chains): if chain1 == chain2: continue - nchain2=ord(chain2) - ord('A') iptm_boltz1[chain1][chain2]=boltz1_chain_pair_iptm_data[str(nchain1)][str(nchain2)] else: print("Boltz1 summary file does not exist: ", summary_file_path) @@ -493,7 +491,7 @@ def classify_chains(chains, residue_types): atom_plddts=np.array(data['atom_plddts']) plddt=atom_plddts[CA_atom_num] # pull out residue plddts from Calpha atoms cb_plddt=atom_plddts[CB_atom_num] # pull out residue plddts from Cbeta atoms for pDockQ - + # Get pairwise residue PAE matrix by identifying one token per protein residue. # Modified residues have separate tokens for each atom, so need to pull out Calpha atom as token # Skip ligands @@ -502,7 +500,7 @@ def classify_chains(chains, residue_types): else: print("no PAE data in AF3 json file; quitting") sys.exit() - + # Set pae_matrix for AF3 from subset of full PAE matrix from json file token_array=np.array(token_mask) pae_matrix = pae_matrix_af3[np.ix_(token_array.astype(bool), token_array.astype(bool))] @@ -519,11 +517,9 @@ def classify_chains(chains, residue_types): with open(summary_file_path,'r') as file: data_summary=json.load(file) af3_chain_pair_iptm_data=data_summary['chain_pair_iptm'] - for chain1 in unique_chains: - nchain1= ord(chain1) - ord('A') # map A,B,C... to 0,1,2... - for chain2 in unique_chains: + for nchain1, chain1 in enumerate(unique_chains): + for nchain2, chain2 in enumerate(unique_chains): if chain1 == chain2: continue - nchain2=ord(chain2) - ord('A') iptm_af3[chain1][chain2]=af3_chain_pair_iptm_data[nchain1][nchain2] else: print("AF3 summary file does not exist: ", summary_file_path) @@ -536,7 +532,7 @@ def classify_chains(chains, residue_types): # ipsae_d0chn = calculate ipsae from PAEs with PAE cutoff; d0 = numres in chain pair = len(chain1) + len(chain2) # ipsae_d0dom = calculate ipsae from PAEs with PAE cutoff; d0 from number of residues in chain1 and chain2 that have interchain PAEB is different from B->A) @@ -617,7 +613,7 @@ def classify_chains(chains, residue_types): for residue in chain2residues: pDockQ_unique_residues[chain1][chain2].add(residue) - + if npairs>0: nres=len(list(pDockQ_unique_residues[chain1][chain2])) mean_plddt= cb_plddt[ list(pDockQ_unique_residues[chain1][chain2])].mean() @@ -628,7 +624,7 @@ def classify_chains(chains, residue_types): x=0.0 pDockQ[chain1][chain2]=0.0 nres=0 - + # pDockQ2 for chain1 in unique_chains: @@ -646,7 +642,7 @@ def classify_chains(chains, residue_types): pae_list=pae_matrix[i][valid_pairs] pae_list_ptm=ptm_func_vec(pae_list,10.0) sum += pae_list_ptm.sum() - + if npairs>0: nres=len(list(pDockQ_unique_residues[chain1][chain2])) mean_plddt= cb_plddt[ list(pDockQ_unique_residues[chain1][chain2])].mean() @@ -658,16 +654,16 @@ def classify_chains(chains, residue_types): x=0.0 nres=0 pDockQ2[chain1][chain2]=0.0 - + # LIS for chain1 in unique_chains: for chain2 in unique_chains: if chain1==chain2: continue - + mask = (chains[:, None] == chain1) & (chains[None, :] == chain2) # Select residues for (chain1, chain2) selected_pae = pae_matrix[mask] # Get PAE values for this pair - + if selected_pae.size > 0: # Ensure we have values valid_pae = selected_pae[selected_pae <= 12] # Apply the threshold if valid_pae.size > 0: @@ -714,7 +710,7 @@ def classify_chains(chains, residue_types): for j in np.where(valid_pairs_ipsae)[0]: jresnum=residues[j]['resnum'] unique_residues_chain2[chain1][chain2].add(jresnum) - + # Track unique residues contributing to iptm in interface valid_pairs = (chains == chain2) & (pae_matrix[i] < pae_cutoff) & (distances[i] < dist_cutoff) dist_valid_pair_counts[chain1][chain2] += np.sum(valid_pairs) @@ -748,7 +744,7 @@ def classify_chains(chains, residue_types): n0res_byres[chain1][chain2] = n0res_byres_all d0res_byres[chain1][chain2] = d0res_byres_all - + for i in range(numres): if chains[i] != chain1: continue @@ -777,7 +773,7 @@ def classify_chains(chains, residue_types): f'{ipsae_d0res_byres[chain1][chain2][i]:8.4f}\n' ) OUT2.write(outstring) - + # Compute interchain ipTM and ipSAE for each chain pair for chain1 in unique_chains: for chain2 in unique_chains: @@ -860,7 +856,7 @@ def classify_chains(chains, residue_types): d0res_max[chain1][chain2]=maxd0 d0res_max[chain2][chain1]=maxd0 - + chaincolor={'A':'magenta', 'B':'marine', 'C':'lime', 'D':'orange', 'E':'yellow', 'F':'cyan', 'G':'lightorange', 'H':'pink', 'I':'deepteal', 'J':'forest', 'K':'lightblue', 'L':'slate', @@ -904,7 +900,7 @@ def classify_chains(chains, residue_types): if af2: iptm_af = iptm_af2 # same for all chain pairs in entry if af3: iptm_af = iptm_af3[chain1][chain2] # symmetric value for each chain pair if boltz1: iptm_af=iptm_boltz1[chain1][chain2] - + outstring=f'{chain1} {chain2} {pae_string:3} {dist_string:3} {"asym":5} ' + ( f'{ipsae_d0res_asym[chain1][chain2]:8.6f} ' f'{ipsae_d0chn_asym[chain1][chain2]:8.6f} ' @@ -962,7 +958,7 @@ def classify_chains(chains, residue_types): f'{pdb_stem}\n') OUT.write(outstring) PML.write("# " + outstring) - + chain_pair= f'color_{chain1}_{chain2}' chain1_residues = f'chain {chain1} and resi {contiguous_ranges(unique_residues_chain1[chain1][chain2])}' chain2_residues = f'chain {chain2} and resi {contiguous_ranges(unique_residues_chain2[chain1][chain2])}'