From 7c30fee0abf502e30b3e32c4b7a6e47c20ad74d8 Mon Sep 17 00:00:00 2001 From: Orange Date: Mon, 11 Aug 2025 09:24:52 +0000 Subject: [PATCH 1/3] Fixed issues with designing in scaffoldguided mode, for example: design_ppi_scaffold.sh. The solutions to issues 272 and 273 did not fully address the issue. --- config/inference/base.yaml | 2 +- examples/design_ppi_scaffolded.sh | 2 +- rfdiffusion/inference/model_runners.py | 33 ++++++++++++++++++++------ rfdiffusion/inference/utils.py | 4 ++-- 4 files changed, 30 insertions(+), 11 deletions(-) diff --git a/config/inference/base.yaml b/config/inference/base.yaml index 3bb0a5c1..509bd4e9 100644 --- a/config/inference/base.yaml +++ b/config/inference/base.yaml @@ -126,7 +126,7 @@ logging: inputs: False scaffoldguided: - scaffoldguided: False + scaffoldguided_enable: False target_pdb: False target_path: null scaffold_list: null diff --git a/examples/design_ppi_scaffolded.sh b/examples/design_ppi_scaffolded.sh index 63ca165e..5bcf5af5 100755 --- a/examples/design_ppi_scaffolded.sh +++ b/examples/design_ppi_scaffolded.sh @@ -7,4 +7,4 @@ # We then provide a path to a directory of different scaffolds (we've provided some for you to use, from Cao et al., 2022) # We generate 10 designs, and reduce the noise added during inference to 0 (which improves the quality of designs) -../scripts/run_inference.py scaffoldguided.target_path=input_pdbs/insulin_target.pdb inference.output_prefix=example_outputs/design_ppi_scaffolded scaffoldguided.scaffoldguided=True 'ppi.hotspot_res=[A59,A83,A91]' scaffoldguided.target_pdb=True scaffoldguided.target_ss=target_folds/insulin_target_ss.pt scaffoldguided.target_adj=target_folds/insulin_target_adj.pt scaffoldguided.scaffold_dir=./ppi_scaffolds/ inference.num_designs=10 denoiser.noise_scale_ca=0 denoiser.noise_scale_frame=0 +../scripts/run_inference.py scaffoldguided.target_path=input_pdbs/insulin_target.pdb inference.output_prefix=example_outputs/design_ppi_scaffolded scaffoldguided.scaffoldguided_enable=True 'ppi.hotspot_res=[A59,A83,A91]' scaffoldguided.target_pdb=True scaffoldguided.target_ss=target_folds/insulin_target_ss.pt scaffoldguided.target_adj=target_folds/insulin_target_adj.pt scaffoldguided.scaffold_dir=./ppi_scaffolds/ inference.num_designs=10 denoiser.noise_scale_ca=0 denoiser.noise_scale_frame=0 diff --git a/rfdiffusion/inference/model_runners.py b/rfdiffusion/inference/model_runners.py index 1ad3c7f1..edb637d0 100644 --- a/rfdiffusion/inference/model_runners.py +++ b/rfdiffusion/inference/model_runners.py @@ -77,14 +77,14 @@ def initialize(self, conf: DictConfig) -> None: if conf.contigmap.provide_seq is not None: # this is only used for partial diffusion assert conf.diffuser.partial_T is not None, "The provide_seq input is specifically for partial diffusion" - if conf.scaffoldguided.scaffoldguided: + if conf.scaffoldguided.scaffoldguided_enable: self.ckpt_path = f'{model_directory}/InpaintSeq_Fold_ckpt.pt' else: self.ckpt_path = f'{model_directory}/InpaintSeq_ckpt.pt' - elif conf.ppi.hotspot_res is not None and conf.scaffoldguided.scaffoldguided is False: + elif conf.ppi.hotspot_res is not None and conf.scaffoldguided.scaffoldguided_enable is False: # use complex trained model self.ckpt_path = f'{model_directory}/Complex_base_ckpt.pt' - elif conf.scaffoldguided.scaffoldguided is True: + elif conf.scaffoldguided.scaffoldguided_enable is True: # use complex and secondary structure-guided model self.ckpt_path = f'{model_directory}/Complex_Fold_base_ckpt.pt' else: @@ -279,7 +279,6 @@ def sample_init(self, 
return_forward_trajectory=False): self.mask_seq = torch.from_numpy(self.contig_map.inpaint_seq)[None,:] self.mask_str = torch.from_numpy(self.contig_map.inpaint_str)[None,:] self.binderlen = len(self.contig_map.inpaint) - ####################################### ### Resolve cyclic peptide indicies ### ####################################### @@ -301,7 +300,7 @@ def sample_init(self, return_forward_trajectory=False): self.cyclic_reses = is_cyclized else: self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() - + #################### ### Get Hotspots ### #################### @@ -681,7 +680,6 @@ def sample_step(self, *, t, x_t, seq_init, final_step): #################### ### Forward Pass ### #################### - with torch.no_grad(): msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model(msa_masked, msa_full, @@ -771,7 +769,7 @@ def __init__(self, conf: DictConfig): else: # initialize BlockAdjacency sampling class assert all(x is None for x in (conf.contigmap.inpaint_str_helix, conf.contigmap.inpaint_str_strand, conf.contigmap.inpaint_str_loop)), "can't provide scaffold_dir if you're also specifying per-residue ss" - self.blockadjacency = iu.BlockAdjacency(conf.scaffoldguided, conf.inference.num_designs) + self.blockadjacency = iu.BlockAdjacency(conf, conf.inference.num_designs) ################################################# @@ -945,6 +943,27 @@ def sample_init(self): xT = torch.clone(fa_stack[-1].squeeze()[:,:14,:]) + + ################################ + ### Add to Cyclic_reses init ### + ################################ + + if self._conf.inference.cyclic: + if self._conf.inference.cyc_chains is None: + self.cyclic_reses = ~self.mask_str.to(self.device).squeeze() + else: + assert isinstance(self._conf.inference.cyc_chains, str), 'cyc_chains arg must be string' + cyc_chains = self._conf.inference.cyc_chains + cyc_chains = [i.upper() for i in cyc_chains] + hal_idx = self.contig_map.hal + is_cyclized = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() + for ch in cyc_chains: + ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx]).bool() + is_cyclized[ch_mask] = True + self.cyclic_reses = is_cyclized + else: + self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() + return xT, seq_T def _preprocess(self, seq, xyz_t, t): diff --git a/rfdiffusion/inference/utils.py b/rfdiffusion/inference/utils.py index 2ed6105b..f325e4ce 100644 --- a/rfdiffusion/inference/utils.py +++ b/rfdiffusion/inference/utils.py @@ -502,7 +502,7 @@ def get_next_pose( def sampler_selector(conf: DictConfig): - if conf.scaffoldguided.scaffoldguided: + if conf.scaffoldguided.scaffoldguided_enable: sampler = model_runners.ScaffoldedSampler(conf) else: if conf.inference.model_runner == "default": @@ -1012,4 +1012,4 @@ def ss_from_contig(ss_masks: dict): for idx, mask in enumerate([ss_masks['helix'],ss_masks['strand'], ss_masks['loop']]): ss[mask,idx] = 1 ss[mask, 3] = 0 # remove the mask token - return ss \ No newline at end of file + return ss From 723a66408c9a5f722812e7bbe13a116fa2fd4f42 Mon Sep 17 00:00:00 2001 From: woodsh17 Date: Tue, 18 Nov 2025 12:26:26 -0600 Subject: [PATCH 2/3] Reverting changes to flag name, so you still use scaffoldguided.scaffoldguided=True instead of scaffoldguided_enabled --- config/inference/base.yaml | 2 +- examples/design_ppi_scaffolded.sh | 2 +- rfdiffusion/inference/model_runners.py | 6 +++--- rfdiffusion/inference/utils.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git 
a/config/inference/base.yaml b/config/inference/base.yaml index 509bd4e9..3bb0a5c1 100644 --- a/config/inference/base.yaml +++ b/config/inference/base.yaml @@ -126,7 +126,7 @@ logging: inputs: False scaffoldguided: - scaffoldguided_enable: False + scaffoldguided: False target_pdb: False target_path: null scaffold_list: null diff --git a/examples/design_ppi_scaffolded.sh b/examples/design_ppi_scaffolded.sh index 5bcf5af5..63ca165e 100755 --- a/examples/design_ppi_scaffolded.sh +++ b/examples/design_ppi_scaffolded.sh @@ -7,4 +7,4 @@ # We then provide a path to a directory of different scaffolds (we've provided some for you to use, from Cao et al., 2022) # We generate 10 designs, and reduce the noise added during inference to 0 (which improves the quality of designs) -../scripts/run_inference.py scaffoldguided.target_path=input_pdbs/insulin_target.pdb inference.output_prefix=example_outputs/design_ppi_scaffolded scaffoldguided.scaffoldguided_enable=True 'ppi.hotspot_res=[A59,A83,A91]' scaffoldguided.target_pdb=True scaffoldguided.target_ss=target_folds/insulin_target_ss.pt scaffoldguided.target_adj=target_folds/insulin_target_adj.pt scaffoldguided.scaffold_dir=./ppi_scaffolds/ inference.num_designs=10 denoiser.noise_scale_ca=0 denoiser.noise_scale_frame=0 +../scripts/run_inference.py scaffoldguided.target_path=input_pdbs/insulin_target.pdb inference.output_prefix=example_outputs/design_ppi_scaffolded scaffoldguided.scaffoldguided=True 'ppi.hotspot_res=[A59,A83,A91]' scaffoldguided.target_pdb=True scaffoldguided.target_ss=target_folds/insulin_target_ss.pt scaffoldguided.target_adj=target_folds/insulin_target_adj.pt scaffoldguided.scaffold_dir=./ppi_scaffolds/ inference.num_designs=10 denoiser.noise_scale_ca=0 denoiser.noise_scale_frame=0 diff --git a/rfdiffusion/inference/model_runners.py b/rfdiffusion/inference/model_runners.py index 32b89bbd..48c2d7c8 100644 --- a/rfdiffusion/inference/model_runners.py +++ b/rfdiffusion/inference/model_runners.py @@ -78,14 +78,14 @@ def initialize(self, conf: DictConfig) -> None: if conf.contigmap.provide_seq is not None: # this is only used for partial diffusion assert conf.diffuser.partial_T is not None, "The provide_seq input is specifically for partial diffusion" - if conf.scaffoldguided.scaffoldguided_enable: + if conf.scaffoldguided.scaffoldguided: self.ckpt_path = f'{model_directory}/InpaintSeq_Fold_ckpt.pt' else: self.ckpt_path = f'{model_directory}/InpaintSeq_ckpt.pt' - elif conf.ppi.hotspot_res is not None and conf.scaffoldguided.scaffoldguided_enable is False: + elif conf.ppi.hotspot_res is not None and conf.scaffoldguided.scaffoldguided is False: # use complex trained model self.ckpt_path = f'{model_directory}/Complex_base_ckpt.pt' - elif conf.scaffoldguided.scaffoldguided_enable is True: + elif conf.scaffoldguided.scaffoldguided is True: # use complex and secondary structure-guided model self.ckpt_path = f'{model_directory}/Complex_Fold_base_ckpt.pt' else: diff --git a/rfdiffusion/inference/utils.py b/rfdiffusion/inference/utils.py index f325e4ce..3fb14112 100644 --- a/rfdiffusion/inference/utils.py +++ b/rfdiffusion/inference/utils.py @@ -502,7 +502,7 @@ def get_next_pose( def sampler_selector(conf: DictConfig): - if conf.scaffoldguided.scaffoldguided_enable: + if conf.scaffoldguided.scaffoldguided: sampler = model_runners.ScaffoldedSampler(conf) else: if conf.inference.model_runner == "default": From ecf161b4e2579fdbb9f4a668ac39656bfbc2180a Mon Sep 17 00:00:00 2001 From: woodsh17 Date: Tue, 18 Nov 2025 15:43:09 -0600 Subject: [PATCH 3/3] Move 
cyclic_reses initialization to a helper function and call it for Sampler and ScaffoldedSampler --- rfdiffusion/inference/model_runners.py | 842 +++++++++++++++---------- 1 file changed, 502 insertions(+), 340 deletions(-) diff --git a/rfdiffusion/inference/model_runners.py b/rfdiffusion/inference/model_runners.py index 48c2d7c8..2bd0d530 100644 --- a/rfdiffusion/inference/model_runners.py +++ b/rfdiffusion/inference/model_runners.py @@ -19,11 +19,11 @@ from rfdiffusion.model_input_logger import pickle_function_call import sys -SCRIPT_DIR=os.path.dirname(os.path.realpath(__file__)) +SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) -TOR_INDICES = util.torsion_indices +TOR_INDICES = util.torsion_indices TOR_CAN_FLIP = util.torsion_can_flip -REF_ANGLES = util.reference_angles +REF_ANGLES = util.reference_angles class Sampler: @@ -36,23 +36,27 @@ def __init__(self, conf: DictConfig): """ self.initialized = False self.initialize(conf) - + def initialize(self, conf: DictConfig) -> None: """ Initialize sampler. Args: conf: Configuration - + - Selects appropriate model from input - Assembles Config from model checkpoint and command line overrides """ self._log = logging.getLogger(__name__) if torch.cuda.is_available(): - self.device = torch.device('cuda') + self.device = torch.device("cuda") else: - self.device = torch.device('cpu') - needs_model_reload = not self.initialized or conf.inference.ckpt_override_path != self._conf.inference.ckpt_override_path + self.device = torch.device("cpu") + needs_model_reload = ( + not self.initialized + or conf.inference.ckpt_override_path + != self._conf.inference.ckpt_override_path + ) # Assign config to Sampler self._conf = conf @@ -71,29 +75,42 @@ def initialize(self, conf: DictConfig) -> None: # Initialize inference only helper objects to Sampler if conf.inference.ckpt_override_path is not None: self.ckpt_path = conf.inference.ckpt_override_path - print("WARNING: You're overriding the checkpoint path from the defaults. Check that the model you're providing can run with the inputs you're providing.") + print( + "WARNING: You're overriding the checkpoint path from the defaults. Check that the model you're providing can run with the inputs you're providing." 
+ ) else: - if conf.contigmap.inpaint_seq is not None or conf.contigmap.provide_seq is not None or conf.contigmap.inpaint_str: + if ( + conf.contigmap.inpaint_seq is not None + or conf.contigmap.provide_seq is not None + or conf.contigmap.inpaint_str + ): # use model trained for inpaint_seq if conf.contigmap.provide_seq is not None: # this is only used for partial diffusion - assert conf.diffuser.partial_T is not None, "The provide_seq input is specifically for partial diffusion" + assert ( + conf.diffuser.partial_T is not None + ), "The provide_seq input is specifically for partial diffusion" if conf.scaffoldguided.scaffoldguided: - self.ckpt_path = f'{model_directory}/InpaintSeq_Fold_ckpt.pt' + self.ckpt_path = f"{model_directory}/InpaintSeq_Fold_ckpt.pt" else: - self.ckpt_path = f'{model_directory}/InpaintSeq_ckpt.pt' - elif conf.ppi.hotspot_res is not None and conf.scaffoldguided.scaffoldguided is False: + self.ckpt_path = f"{model_directory}/InpaintSeq_ckpt.pt" + elif ( + conf.ppi.hotspot_res is not None + and conf.scaffoldguided.scaffoldguided is False + ): # use complex trained model - self.ckpt_path = f'{model_directory}/Complex_base_ckpt.pt' + self.ckpt_path = f"{model_directory}/Complex_base_ckpt.pt" elif conf.scaffoldguided.scaffoldguided is True: # use complex and secondary structure-guided model - self.ckpt_path = f'{model_directory}/Complex_Fold_base_ckpt.pt' + self.ckpt_path = f"{model_directory}/Complex_Fold_base_ckpt.pt" else: # use default model - self.ckpt_path = f'{model_directory}/Base_ckpt.pt' + self.ckpt_path = f"{model_directory}/Base_ckpt.pt" # for saving in trb file: - assert self._conf.inference.trb_save_ckpt_path is None, "trb_save_ckpt_path is not the place to specify an input model. Specify in inference.ckpt_override_path" - self._conf['inference']['trb_save_ckpt_path']=self.ckpt_path + assert ( + self._conf.inference.trb_save_ckpt_path is None + ), "trb_save_ckpt_path is not the place to specify an input model. Specify in inference.ckpt_override_path" + self._conf["inference"]["trb_save_ckpt_path"] = self.ckpt_path ####################### ### Assemble Config ### @@ -109,7 +126,7 @@ def initialize(self, conf: DictConfig) -> None: self.assemble_config_from_chk() # self.initialize_sampler(conf) - self.initialized=True + self.initialized = True # Initialize helper objects self.inf_conf = self._conf.inference @@ -148,9 +165,13 @@ def initialize(self, conf: DictConfig) -> None: if self.inf_conf.input_pdb is None: # set default pdb - script_dir=os.path.dirname(os.path.realpath(__file__)) - self.inf_conf.input_pdb=os.path.join(script_dir, '../../examples/input_pdbs/1qys.pdb') - self.target_feats = iu.process_target(self.inf_conf.input_pdb, parse_hetatom=True, center=False) + script_dir = os.path.dirname(os.path.realpath(__file__)) + self.inf_conf.input_pdb = os.path.join( + script_dir, "../../examples/input_pdbs/1qys.pdb" + ) + self.target_feats = iu.process_target( + self.inf_conf.input_pdb, parse_hetatom=True, center=False + ) self.chain_idx = None self.idx_pdb = None @@ -163,25 +184,24 @@ def initialize(self, conf: DictConfig) -> None: self.t_step_input = int(self.diffuser_conf.partial_T) else: self.t_step_input = int(self.diffuser_conf.T) - + @property def T(self): - ''' - Return the maximum number of timesteps - that this design protocol will perform. + """ + Return the maximum number of timesteps + that this design protocol will perform. 
- Output: - T (int): The maximum number of timesteps to perform - ''' + Output: + T (int): The maximum number of timesteps to perform + """ return self.diffuser_conf.T def load_checkpoint(self) -> None: """Loads RF checkpoint, from which config can be generated.""" - self._log.info(f'Reading checkpoint from {self.ckpt_path}') - print('This is inf_conf.ckpt_path') + self._log.info(f"Reading checkpoint from {self.ckpt_path}") + print("This is inf_conf.ckpt_path") print(self.ckpt_path) - self.ckpt = torch.load( - self.ckpt_path, map_location=self.device) + self.ckpt = torch.load(self.ckpt_path, map_location=self.device) def assemble_config_from_chk(self) -> None: """ @@ -193,7 +213,7 @@ def assemble_config_from_chk(self) -> None: Actions: - Replaces all -model and -diffuser items - Throws a warning if there are items in -model and -diffuser that aren't in the checkpoint - + This throws an error if there is a flag in the checkpoint 'config_dict' that isn't in the inference config. This should ensure that whenever a feature is added in the training setup, it is accounted for in the inference script. @@ -204,71 +224,90 @@ def assemble_config_from_chk(self) -> None: overrides = HydraConfig.get().overrides.task print("Assembling -model, -diffuser and -preprocess configs from checkpoint") - for cat in ['model','diffuser','preprocess']: + for cat in ["model", "diffuser", "preprocess"]: for key in self._conf[cat]: try: - print(f"USING MODEL CONFIG: self._conf[{cat}][{key}] = {self.ckpt['config_dict'][cat][key]}") - self._conf[cat][key] = self.ckpt['config_dict'][cat][key] + print( + f"USING MODEL CONFIG: self._conf[{cat}][{key}] = {self.ckpt['config_dict'][cat][key]}" + ) + self._conf[cat][key] = self.ckpt["config_dict"][cat][key] except: pass - + # add overrides back in again for override in overrides: - if override.split(".")[0] in ['model','diffuser','preprocess']: - print(f'WARNING: You are changing {override.split("=")[0]} from the value this model was trained with. Are you sure you know what you are doing?') - mytype = type(self._conf[override.split(".")[0]][override.split(".")[1].split("=")[0]]) - self._conf[override.split(".")[0]][override.split(".")[1].split("=")[0]] = mytype(override.split("=")[1]) + if override.split(".")[0] in ["model", "diffuser", "preprocess"]: + print( + f'WARNING: You are changing {override.split("=")[0]} from the value this model was trained with. Are you sure you know what you are doing?' + ) + mytype = type( + self._conf[override.split(".")[0]][ + override.split(".")[1].split("=")[0] + ] + ) + self._conf[override.split(".")[0]][ + override.split(".")[1].split("=")[0] + ] = mytype(override.split("=")[1]) def load_model(self): """Create RosettaFold model from preloaded checkpoint.""" - + # Read input dimensions from checkpoint. 
- self.d_t1d=self._conf.preprocess.d_t1d - self.d_t2d=self._conf.preprocess.d_t2d - model = RoseTTAFoldModule(**self._conf.model, d_t1d=self.d_t1d, d_t2d=self.d_t2d, T=self._conf.diffuser.T).to(self.device) + self.d_t1d = self._conf.preprocess.d_t1d + self.d_t2d = self._conf.preprocess.d_t2d + model = RoseTTAFoldModule( + **self._conf.model, + d_t1d=self.d_t1d, + d_t2d=self.d_t2d, + T=self._conf.diffuser.T, + ).to(self.device) if self._conf.logging.inputs: - pickle_dir = pickle_function_call(model, 'forward', 'inference') - print(f'pickle_dir: {pickle_dir}') + pickle_dir = pickle_function_call(model, "forward", "inference") + print(f"pickle_dir: {pickle_dir}") model = model.eval() - self._log.info(f'Loading checkpoint.') - model.load_state_dict(self.ckpt['model_state_dict'], strict=True) + self._log.info(f"Loading checkpoint.") + model.load_state_dict(self.ckpt["model_state_dict"], strict=True) return model def construct_contig(self, target_feats): """ Construct contig class describing the protein to be generated """ - self._log.info(f'Using contig: {self.contig_conf.contigs}') + self._log.info(f"Using contig: {self.contig_conf.contigs}") return ContigMap(target_feats, **self.contig_conf) def construct_denoiser(self, L, visible): """Make length-specific denoiser.""" denoise_kwargs = OmegaConf.to_container(self.diffuser_conf) denoise_kwargs.update(OmegaConf.to_container(self.denoiser_conf)) - denoise_kwargs.update({ - 'L': L, - 'diffuser': self.diffuser, - 'potential_manager': self.potential_manager, - }) + denoise_kwargs.update( + { + "L": L, + "diffuser": self.diffuser, + "potential_manager": self.potential_manager, + } + ) return iu.Denoise(**denoise_kwargs) def sample_init(self, return_forward_trajectory=False): """ Initial features to start the sampling process. - + Modify signature and function body for different initialization based on the config. - + Returns: xt: Starting positions with a portion of them randomly sampled. seq_t: Starting sequence with a portion of them set to unknown. 
""" - + ####################### ### Parse input pdb ### ####################### - self.target_feats = iu.process_target(self.inf_conf.input_pdb, parse_hetatom=True, center=False) + self.target_feats = iu.process_target( + self.inf_conf.input_pdb, parse_hetatom=True, center=False + ) ################################ ### Generate specific contig ### @@ -278,58 +317,44 @@ def sample_init(self, return_forward_trajectory=False): self.contig_map = self.construct_contig(self.target_feats) self.mappings = self.contig_map.get_mappings() - self.mask_seq = torch.from_numpy(self.contig_map.inpaint_seq)[None,:] - self.mask_str = torch.from_numpy(self.contig_map.inpaint_str)[None,:] - self.binderlen = len(self.contig_map.inpaint) + self.mask_seq = torch.from_numpy(self.contig_map.inpaint_seq)[None, :] + self.mask_str = torch.from_numpy(self.contig_map.inpaint_str)[None, :] + self.binderlen = len(self.contig_map.inpaint) ####################################### ### Resolve cyclic peptide indicies ### ####################################### - if self._conf.inference.cyclic: - if self._conf.inference.cyc_chains is None: - # default to all residues being cyclized - self.cyclic_reses = ~self.mask_str.to(self.device).squeeze() - else: - # use cyc_chains arg to determine cyclic_reses mask - assert type(self._conf.inference.cyc_chains) is str, 'cyc_chains arg must be string' - cyc_chains = self._conf.inference.cyc_chains - cyc_chains = [i.upper() for i in cyc_chains] - hal_idx = self.contig_map.hal # the pdb indices of output, knowledge of different chains - is_cyclized = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() # initially empty + self._init_cyclic_reses(self.mask_str, self.contig_map) - for ch in cyc_chains: - ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx]).bool() - is_cyclized[ch_mask] = True # set this whole chain to be cyclic - self.cyclic_reses = is_cyclized - else: - self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() - #################### ### Get Hotspots ### #################### - self.hotspot_0idx=iu.get_idx0_hotspots(self.mappings, self.ppi_conf, self.binderlen) - + self.hotspot_0idx = iu.get_idx0_hotspots( + self.mappings, self.ppi_conf, self.binderlen + ) ##################################### ### Initialise Potentials Manager ### ##################################### - self.potential_manager = PotentialManager(self.potential_conf, - self.ppi_conf, - self.diffuser_conf, - self.inf_conf, - self.hotspot_0idx, - self.binderlen) + self.potential_manager = PotentialManager( + self.potential_conf, + self.ppi_conf, + self.diffuser_conf, + self.inf_conf, + self.hotspot_0idx, + self.binderlen, + ) ################################### ### Initialize other attributes ### ################################### - xyz_27 = self.target_feats['xyz_27'] - mask_27 = self.target_feats['mask_27'] - seq_orig = self.target_feats['seq'] + xyz_27 = self.target_feats["xyz_27"] + mask_27 = self.target_feats["mask_27"] + seq_orig = self.target_feats["seq"] L_mapped = len(self.contig_map.ref) - contig_map=self.contig_map + contig_map = self.contig_map self.diffusion_mask = self.mask_str length_bound = self.contig_map.sampled_mask_length_bound.copy() @@ -342,7 +367,9 @@ def sample_init(self, return_forward_trajectory=False): # Iterate over each chain for last_res in length_bound: - chain_ids = {contig_ref[0] for contig_ref in self.contig_map.ref[first_res: last_res]} + chain_ids = { + contig_ref[0] for contig_ref in self.contig_map.ref[first_res:last_res] + } # If 
we are designing this chain, it will have a '-' in the contig map # Renumber this chain from 1 if "_" in chain_ids: @@ -351,20 +378,29 @@ def sample_init(self, return_forward_trajectory=False): # If there are no fixed residues that have a chain id, pick the first available letter if not chain_ids: if not available_chains: - raise ValueError(f"No available chains! You are trying to design a new chain, and you have " - f"already used all upper- and lower-case chain ids (up to 52 chains): " - f"{','.join(all_chains)}.") + raise ValueError( + f"No available chains! You are trying to design a new chain, and you have " + f"already used all upper- and lower-case chain ids (up to 52 chains): " + f"{','.join(all_chains)}." + ) chain_id = available_chains[0] available_chains.remove(chain_id) # Otherwise, use the chain of the fixed (motif) residues else: - assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}" + assert ( + len(chain_ids) == 1 + ), f"Error: Multiple chain IDs in chain: {chain_ids}" chain_id = list(chain_ids)[0] self.chain_idx += [chain_id] * (last_res - first_res) # If this is a fixed chain, maintain the chain and residue numbering else: - self.idx_pdb += [contig_ref[1] for contig_ref in self.contig_map.ref[first_res: last_res]] - assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}" + self.idx_pdb += [ + contig_ref[1] + for contig_ref in self.contig_map.ref[first_res:last_res] + ] + assert ( + len(chain_ids) == 1 + ), f"Error: Multiple chain IDs in chain: {chain_ids}" self.chain_idx += [list(chain_ids)[0]] * (last_res - first_res) first_res = last_res @@ -373,61 +409,74 @@ def sample_init(self, return_forward_trajectory=False): #################################### if self.diffuser_conf.partial_T: - assert xyz_27.shape[0] == L_mapped, f"there must be a coordinate in the input PDB for \ + assert ( + xyz_27.shape[0] == L_mapped + ), f"there must be a coordinate in the input PDB for \ each residue implied by the contig string for partial diffusion. length of \ input PDB != length of contig string: {xyz_27.shape[0]} != {L_mapped}" - assert contig_map.hal_idx0 == contig_map.ref_idx0, f'for partial diffusion there can \ + assert ( + contig_map.hal_idx0 == contig_map.ref_idx0 + ), f"for partial diffusion there can \ be no offset between the index of a residue in the input and the index of the \ - residue in the output, {contig_map.hal_idx0} != {contig_map.ref_idx0}' + residue in the output, {contig_map.hal_idx0} != {contig_map.ref_idx0}" # Partially diffusing from a known structure - xyz_mapped=xyz_27 + xyz_mapped = xyz_27 atom_mask_mapped = mask_27 else: # Fully diffusing from points initialised at the origin # adjust size of input xt according to residue map - xyz_mapped = torch.full((1,1,L_mapped,27,3), np.nan) - xyz_mapped[:, :, contig_map.hal_idx0, ...] = xyz_27[contig_map.ref_idx0,...] + xyz_mapped = torch.full((1, 1, L_mapped, 27, 3), np.nan) + xyz_mapped[:, :, contig_map.hal_idx0, ...] = xyz_27[ + contig_map.ref_idx0, ... 
+ ] xyz_motif_prealign = xyz_mapped.clone() - motif_prealign_com = xyz_motif_prealign[0,0,:,1].mean(dim=0) - self.motif_com = xyz_27[contig_map.ref_idx0,1].mean(dim=0) + motif_prealign_com = xyz_motif_prealign[0, 0, :, 1].mean(dim=0) + self.motif_com = xyz_27[contig_map.ref_idx0, 1].mean(dim=0) xyz_mapped = get_init_xyz(xyz_mapped).squeeze() # adjust the size of the input atom map atom_mask_mapped = torch.full((L_mapped, 27), False) atom_mask_mapped[contig_map.hal_idx0] = mask_27[contig_map.ref_idx0] - # Diffuse the contig-mapped coordinates + # Diffuse the contig-mapped coordinates if self.diffuser_conf.partial_T: - assert self.diffuser_conf.partial_T <= self.diffuser_conf.T, "Partial_T must be less than T" + assert ( + self.diffuser_conf.partial_T <= self.diffuser_conf.T + ), "Partial_T must be less than T" self.t_step_input = int(self.diffuser_conf.partial_T) else: self.t_step_input = int(self.diffuser_conf.T) - t_list = np.arange(1, self.t_step_input+1) + t_list = np.arange(1, self.t_step_input + 1) ################################# ### Generate initial sequence ### ################################# - seq_t = torch.full((1,L_mapped), 21).squeeze() # 21 is the mask token + seq_t = torch.full((1, L_mapped), 21).squeeze() # 21 is the mask token seq_t[contig_map.hal_idx0] = seq_orig[contig_map.ref_idx0] - + # Unmask sequence if desired if self._conf.contigmap.provide_seq is not None: - seq_t[self.mask_seq.squeeze()] = seq_orig[self.mask_seq.squeeze()] + seq_t[self.mask_seq.squeeze()] = seq_orig[self.mask_seq.squeeze()] seq_t[~self.mask_seq.squeeze()] = 21 - seq_t = torch.nn.functional.one_hot(seq_t, num_classes=22).float() # [L,22] - seq_orig = torch.nn.functional.one_hot(seq_orig, num_classes=22).float() # [L,22] + seq_t = torch.nn.functional.one_hot(seq_t, num_classes=22).float() # [L,22] + seq_orig = torch.nn.functional.one_hot( + seq_orig, num_classes=22 + ).float() # [L,22] fa_stack, xyz_true = self.diffuser.diffuse_pose( xyz_mapped, torch.clone(seq_t), atom_mask_mapped.squeeze(), diffusion_mask=self.diffusion_mask.squeeze(), - t_list=t_list) - xT = fa_stack[-1].squeeze()[:,:14,:] + t_list=t_list, + ) + xT = fa_stack[-1].squeeze()[:, :14, :] xt = torch.clone(xT) - self.denoiser = self.construct_denoiser(len(self.contig_map.ref), visible=self.mask_seq.squeeze()) + self.denoiser = self.construct_denoiser( + len(self.contig_map.ref), visible=self.mask_seq.squeeze() + ) ###################### ### Apply Symmetry ### @@ -435,8 +484,8 @@ def sample_init(self, return_forward_trajectory=False): if self.symmetry is not None: xt, seq_t = self.symmetry.apply_symmetry(xt, seq_t) - self._log.info(f'Sequence init: {seq2chars(torch.argmax(seq_t, dim=-1))}') - + self._log.info(f"Sequence init: {seq2chars(torch.argmax(seq_t, dim=-1))}") + self.msa_prev = None self.pair_prev = None self.state_prev = None @@ -446,15 +495,32 @@ def sample_init(self, return_forward_trajectory=False): ######################################### if self.potential_conf.guiding_potentials is not None: - if any(list(filter(lambda x: "substrate_contacts" in x, self.potential_conf.guiding_potentials))): - assert len(self.target_feats['xyz_het']) > 0, "If you're using the Substrate Contact potential, \ + if any( + list( + filter( + lambda x: "substrate_contacts" in x, + self.potential_conf.guiding_potentials, + ) + ) + ): + assert ( + len(self.target_feats["xyz_het"]) > 0 + ), "If you're using the Substrate Contact potential, \ you need to make sure there's a ligand in the input_pdb file!" 
- het_names = np.array([i['name'].strip() for i in self.target_feats['info_het']]) - xyz_het = self.target_feats['xyz_het'][het_names == self._conf.potentials.substrate] + het_names = np.array( + [i["name"].strip() for i in self.target_feats["info_het"]] + ) + xyz_het = self.target_feats["xyz_het"][ + het_names == self._conf.potentials.substrate + ] xyz_het = torch.from_numpy(xyz_het) - assert xyz_het.shape[0] > 0, f'expected >0 heteroatoms from ligand with name {self._conf.potentials.substrate}' - xyz_motif_prealign = xyz_motif_prealign[0,0][self.diffusion_mask.squeeze()] - motif_prealign_com = xyz_motif_prealign[:,1].mean(dim=0) + assert ( + xyz_het.shape[0] > 0 + ), f"expected >0 heteroatoms from ligand with name {self._conf.potentials.substrate}" + xyz_motif_prealign = xyz_motif_prealign[0, 0][ + self.diffusion_mask.squeeze() + ] + motif_prealign_com = xyz_motif_prealign[:, 1].mean(dim=0) xyz_het_com = xyz_het.mean(dim=0) for pot in self.potential_manager.potentials_to_apply: pot.motif_substrate_atoms = xyz_het @@ -463,18 +529,51 @@ def sample_init(self, return_forward_trajectory=False): pot.diffuser = self.diffuser return xt, seq_t + def _init_cyclic_reses(self, mask_str, contig_map): + """ + Centralized logic for initializing self.cyclic_reses. + + mask_str: tensor-like mask (can be 1D or have a leading batch dim), where True indicates motif/fixed. + contig_map: object with attribute 'hal' (iterable of per-residue chain ids). + """ + mask = mask_str.squeeze() + # ensure on correct device and dtype + mask = mask.to(self.device) + + if self._conf.inference.cyclic: + if self._conf.inference.cyc_chains is None: + self.cyclic_reses = ~mask + else: + assert isinstance( + self._conf.inference.cyc_chains, str + ), "cyc_chains arg must be string" + cyc_chains = [i.upper() for i in self._conf.inference.cyc_chains] + hal_idx = contig_map.hal + is_cyclized = torch.zeros_like(mask).bool().to(self.device).squeeze() + # build boolean mask per residue based on chain id in hal_idx + ch_mask_list = [idx[0] for idx in hal_idx] + for ch in cyc_chains: + ch_mask = torch.tensor( + [c == ch for c in ch_mask_list], + dtype=torch.bool, + device=self.device, + ) + is_cyclized[ch_mask] = True + self.cyclic_reses = is_cyclized + else: + self.cyclic_reses = torch.zeros_like(mask).bool().to(self.device).squeeze() + def _preprocess(self, seq, xyz_t, t, repack=False): - """ Function to prepare inputs to diffusion model - - seq (L,22) one-hot sequence + + seq (L,22) one-hot sequence msa_masked (1,1,L,48) msa_full (1,1,L,25) - - xyz_t (L,14,3) template crds (diffused) + + xyz_t (L,14,3) template crds (diffused) t1d (1,L,28) this is the t1d before tacking on the chi angles: - seq + unknown/mask (21) @@ -484,10 +583,10 @@ def _preprocess(self, seq, xyz_t, t, repack=False): - contacting residues: for ppi. 
Target residues in contact with binder (1) - empty feature (legacy) (1) - ss (H, E, L, MASK) (4) - + t2d (1, L, L, 45) - last plane is block adjacency - """ + """ L = seq.shape[0] T = self.T @@ -497,61 +596,60 @@ def _preprocess(self, seq, xyz_t, t, repack=False): ################## ### msa_masked ### ################## - msa_masked = torch.zeros((1,1,L,48)) - msa_masked[:,:,:,:22] = seq[None, None] - msa_masked[:,:,:,22:44] = seq[None, None] - msa_masked[:,:,0,46] = 1.0 - msa_masked[:,:,-1,47] = 1.0 + msa_masked = torch.zeros((1, 1, L, 48)) + msa_masked[:, :, :, :22] = seq[None, None] + msa_masked[:, :, :, 22:44] = seq[None, None] + msa_masked[:, :, 0, 46] = 1.0 + msa_masked[:, :, -1, 47] = 1.0 ################ ### msa_full ### ################ - msa_full = torch.zeros((1,1,L,25)) - msa_full[:,:,:,:22] = seq[None, None] - msa_full[:,:,0,23] = 1.0 - msa_full[:,:,-1,24] = 1.0 + msa_full = torch.zeros((1, 1, L, 25)) + msa_full[:, :, :, :22] = seq[None, None] + msa_full[:, :, 0, 23] = 1.0 + msa_full[:, :, -1, 24] = 1.0 ########### ### t1d ### - ########### + ########### # Here we need to go from one hot with 22 classes to one hot with 21 classes (last plane is missing token) - t1d = torch.zeros((1,1,L,21)) + t1d = torch.zeros((1, 1, L, 21)) seqt1d = torch.clone(seq) for idx in range(L): - if seqt1d[idx,21] == 1: - seqt1d[idx,20] = 1 - seqt1d[idx,21] = 0 - - t1d[:,:,:,:21] = seqt1d[None,None,:,:21] - + if seqt1d[idx, 21] == 1: + seqt1d[idx, 20] = 1 + seqt1d[idx, 21] = 0 + + t1d[:, :, :, :21] = seqt1d[None, None, :, :21] # Set timestep feature to 1 where diffusion mask is True, else 1-t/T timefeature = torch.zeros((L)).float() timefeature[self.mask_str.squeeze()] = 1 - timefeature[~self.mask_str.squeeze()] = 1 - t/self.T - timefeature = timefeature[None,None,...,None] + timefeature[~self.mask_str.squeeze()] = 1 - t / self.T + timefeature = timefeature[None, None, ..., None] t1d = torch.cat((t1d, timefeature), dim=-1).float() - + ############# ### xyz_t ### ############# if self.preprocess_conf.sidechain_input: - xyz_t[torch.where(seq == 21, True, False),3:,:] = float('nan') + xyz_t[torch.where(seq == 21, True, False), 3:, :] = float("nan") else: - xyz_t[~self.mask_str.squeeze(),3:,:] = float('nan') + xyz_t[~self.mask_str.squeeze(), 3:, :] = float("nan") - xyz_t=xyz_t[None, None] - xyz_t = torch.cat((xyz_t, torch.full((1,1,L,13,3), float('nan'))), dim=3) + xyz_t = xyz_t[None, None] + xyz_t = torch.cat((xyz_t, torch.full((1, 1, L, 13, 3), float("nan"))), dim=3) ########### ### t2d ### ########### t2d = xyz_to_t2d(xyz_t) - - ########### + + ########### ### idx ### ########### idx = torch.tensor(self.contig_map.rf)[None] @@ -559,15 +657,17 @@ def _preprocess(self, seq, xyz_t, t, repack=False): ############### ### alpha_t ### ############### - seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,L) - alpha, _, alpha_mask, _ = util.get_torsions(xyz_t.reshape(-1, L, 27, 3), seq_tmp, TOR_INDICES, TOR_CAN_FLIP, REF_ANGLES) - alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0])) + seq_tmp = t1d[..., :-1].argmax(dim=-1).reshape(-1, L) + alpha, _, alpha_mask, _ = util.get_torsions( + xyz_t.reshape(-1, L, 27, 3), seq_tmp, TOR_INDICES, TOR_CAN_FLIP, REF_ANGLES + ) + alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[..., 0])) alpha[torch.isnan(alpha)] = 0.0 - alpha = alpha.reshape(1,-1,L,10,2) - alpha_mask = alpha_mask.reshape(1,-1,L,10,1) + alpha = alpha.reshape(1, -1, L, 10, 2) + alpha_mask = alpha_mask.reshape(1, -1, L, 10, 1) alpha_t = torch.cat((alpha, alpha_mask), 
dim=-1).reshape(1, -1, L, 30) - #put tensors on device + # put tensors on device msa_masked = msa_masked.to(self.device) msa_full = msa_full.to(self.device) seq = seq.to(self.device) @@ -576,49 +676,69 @@ def _preprocess(self, seq, xyz_t, t, repack=False): t1d = t1d.to(self.device) t2d = t2d.to(self.device) alpha_t = alpha_t.to(self.device) - + ###################### ### added_features ### ###################### - if self.preprocess_conf.d_t1d >= 24: # add hotspot residues + if self.preprocess_conf.d_t1d >= 24: # add hotspot residues hotspot_tens = torch.zeros(L).float() if self.ppi_conf.hotspot_res is None: - print("WARNING: you're using a model trained on complexes and hotspot residues, without specifying hotspots.\ - If you're doing monomer diffusion this is fine") - hotspot_idx=[] + print( + "WARNING: you're using a model trained on complexes and hotspot residues, without specifying hotspots.\ + If you're doing monomer diffusion this is fine" + ) + hotspot_idx = [] else: - hotspots = [(i[0],int(i[1:])) for i in self.ppi_conf.hotspot_res] - hotspot_idx=[] - for i,res in enumerate(self.contig_map.con_ref_pdb_idx): + hotspots = [(i[0], int(i[1:])) for i in self.ppi_conf.hotspot_res] + hotspot_idx = [] + for i, res in enumerate(self.contig_map.con_ref_pdb_idx): if res in hotspots: hotspot_idx.append(self.contig_map.hal_idx0[i]) hotspot_tens[hotspot_idx] = 1.0 # Add blank (legacy) feature and hotspot tensor - t1d=torch.cat((t1d, torch.zeros_like(t1d[...,:1]), hotspot_tens[None,None,...,None].to(self.device)), dim=-1) + t1d = torch.cat( + ( + t1d, + torch.zeros_like(t1d[..., :1]), + hotspot_tens[None, None, ..., None].to(self.device), + ), + dim=-1, + ) + + return ( + msa_masked, + msa_full, + seq[None], + torch.squeeze(xyz_t, dim=0), + idx, + t1d, + t2d, + xyz_t, + alpha_t, + ) - return msa_masked, msa_full, seq[None], torch.squeeze(xyz_t, dim=0), idx, t1d, t2d, xyz_t, alpha_t - def sample_step(self, *, t, x_t, seq_init, final_step): - '''Generate the next pose that the model should be supplied at timestep t-1. + """Generate the next pose that the model should be supplied at timestep t-1. Args: t (int): The timestep that has just been predicted seq_t (torch.tensor): (L,22) The sequence at the beginning of this timestep x_t (torch.tensor): (L,14,3) The residue positions at the beginning of this timestep seq_init (torch.tensor): (L,22) The initialized sequence used in updating the sequence. - + Returns: px0: (L,14,3) The model's prediction of x0. x_t_1: (L,14,3) The updated positions of the next step. seq_t_1: (L,22) The updated sequence of the next step. tors_t_1: (L, ?) The updated torsion angles of the next step. plddt: (L, 1) Predicted lDDT of x0. 
- ''' - msa_masked, msa_full, seq_in, xt_in, idx_pdb, t1d, t2d, xyz_t, alpha_t = self._preprocess( - seq_init, x_t, t) + """ + msa_masked, msa_full, seq_in, xt_in, idx_pdb, t1d, t2d, xyz_t, alpha_t = ( + self._preprocess(seq_init, x_t, t) + ) - N,L = msa_masked.shape[:2] + N, L = msa_masked.shape[:2] if self.symmetry is not None: idx_pdb, self.chain_idx = self.symmetry.res_idx_procesing(res_idx=idx_pdb) @@ -628,38 +748,40 @@ def sample_step(self, *, t, x_t, seq_init, final_step): state_prev = None with torch.no_grad(): - msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model(msa_masked, - msa_full, - seq_in, - xt_in, - idx_pdb, - t1d=t1d, - t2d=t2d, - xyz_t=xyz_t, - alpha_t=alpha_t, - msa_prev = msa_prev, - pair_prev = pair_prev, - state_prev = state_prev, - t=torch.tensor(t), - return_infer=True, - motif_mask=self.diffusion_mask.squeeze().to(self.device)) - - # prediction of X0 - _, px0 = self.allatom(torch.argmax(seq_in, dim=-1), px0, alpha) - px0 = px0.squeeze()[:,:14] - + msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model( + msa_masked, + msa_full, + seq_in, + xt_in, + idx_pdb, + t1d=t1d, + t2d=t2d, + xyz_t=xyz_t, + alpha_t=alpha_t, + msa_prev=msa_prev, + pair_prev=pair_prev, + state_prev=state_prev, + t=torch.tensor(t), + return_infer=True, + motif_mask=self.diffusion_mask.squeeze().to(self.device), + ) + + # prediction of X0 + _, px0 = self.allatom(torch.argmax(seq_in, dim=-1), px0, alpha) + px0 = px0.squeeze()[:, :14] + ##################### ### Get next pose ### ##################### - + if t > final_step: - seq_t_1 = nn.one_hot(seq_init,num_classes=22).to(self.device) + seq_t_1 = nn.one_hot(seq_init, num_classes=22).to(self.device) x_t_1, px0 = self.denoiser.get_next_pose( xt=x_t, px0=px0, t=t, diffusion_mask=self.mask_str.squeeze(), - align_motif=self.inf_conf.align_motif + align_motif=self.inf_conf.align_motif, ) else: x_t_1 = torch.clone(px0).to(x_t.device) @@ -679,7 +801,7 @@ class SelfConditioning(Sampler): """ def sample_step(self, *, t, x_t, seq_init, final_step): - ''' + """ Generate the next pose that the model should be supplied at timestep t-1. Args: t (int): The timestep that has just been predicted @@ -691,24 +813,27 @@ def sample_step(self, *, t, x_t, seq_init, final_step): x_t_1: (L,14,3) The updated positions of the next step. seq_t_1: (L) The sequence to the next step (== seq_init) plddt: (L, 1) Predicted lDDT of x0. 
- ''' + """ - msa_masked, msa_full, seq_in, xt_in, idx_pdb, t1d, t2d, xyz_t, alpha_t = self._preprocess( - seq_init, x_t, t) - B,N,L = xyz_t.shape[:3] + msa_masked, msa_full, seq_in, xt_in, idx_pdb, t1d, t2d, xyz_t, alpha_t = ( + self._preprocess(seq_init, x_t, t) + ) + B, N, L = xyz_t.shape[:3] ################################## ######## Str Self Cond ########### ################################## - if (t < self.diffuser.T) and (t != self.diffuser_conf.partial_T): - zeros = torch.zeros(B,1,L,24,3).float().to(xyz_t.device) - xyz_t = torch.cat((self.prev_pred.unsqueeze(1),zeros), dim=-2) # [B,T,L,27,3] - t2d_44 = xyz_to_t2d(xyz_t) # [B,T,L,L,44] + if (t < self.diffuser.T) and (t != self.diffuser_conf.partial_T): + zeros = torch.zeros(B, 1, L, 24, 3).float().to(xyz_t.device) + xyz_t = torch.cat( + (self.prev_pred.unsqueeze(1), zeros), dim=-2 + ) # [B,T,L,27,3] + t2d_44 = xyz_to_t2d(xyz_t) # [B,T,L,L,44] else: xyz_t = torch.zeros_like(xyz_t) - t2d_44 = torch.zeros_like(t2d[...,:44]) + t2d_44 = torch.zeros_like(t2d[..., :44]) # No effect if t2d is only dim 44 - t2d[...,:44] = t2d_44 + t2d[..., :44] = t2d_44 if self.symmetry is not None: idx_pdb, self.chain_idx = self.symmetry.res_idx_procesing(res_idx=idx_pdb) @@ -717,32 +842,36 @@ def sample_step(self, *, t, x_t, seq_init, final_step): ### Forward Pass ### #################### with torch.no_grad(): - msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model(msa_masked, - msa_full, - seq_in, - xt_in, - idx_pdb, - t1d=t1d, - t2d=t2d, - xyz_t=xyz_t, - alpha_t=alpha_t, - msa_prev = None, - pair_prev = None, - state_prev = None, - t=torch.tensor(t), - return_infer=True, - motif_mask=self.diffusion_mask.squeeze().to(self.device), - cyclic_reses=self.cyclic_reses) + msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model( + msa_masked, + msa_full, + seq_in, + xt_in, + idx_pdb, + t1d=t1d, + t2d=t2d, + xyz_t=xyz_t, + alpha_t=alpha_t, + msa_prev=None, + pair_prev=None, + state_prev=None, + t=torch.tensor(t), + return_infer=True, + motif_mask=self.diffusion_mask.squeeze().to(self.device), + cyclic_reses=self.cyclic_reses, + ) if self.symmetry is not None and self.inf_conf.symmetric_self_cond: - px0 = self.symmetrise_prev_pred(px0=px0,seq_in=seq_in, alpha=alpha)[:,:,:3] + px0 = self.symmetrise_prev_pred(px0=px0, seq_in=seq_in, alpha=alpha)[ + :, :, :3 + ] self.prev_pred = torch.clone(px0) # prediction of X0 - _, px0 = self.allatom(torch.argmax(seq_in, dim=-1), px0, alpha) - px0 = px0.squeeze()[:,:14] - + _, px0 = self.allatom(torch.argmax(seq_in, dim=-1), px0, alpha) + px0 = px0.squeeze()[:, :14] + ########################### ### Generate Next Input ### ########################### @@ -755,10 +884,11 @@ def sample_step(self, *, t, x_t, seq_init, final_step): t=t, diffusion_mask=self.mask_str.squeeze(), align_motif=self.inf_conf.align_motif, - include_motif_sidechains=self.preprocess_conf.motif_sidechain_input + include_motif_sidechains=self.preprocess_conf.motif_sidechain_input, ) self._log.info( - f'Timestep {t}, input to next step: { seq2chars(torch.argmax(seq_t_1, dim=-1).tolist())}') + f"Timestep {t}, input to next step: { seq2chars(torch.argmax(seq_t_1, dim=-1).tolist())}" + ) else: x_t_1 = torch.clone(px0).to(x_t.device) px0 = px0.to(x_t.device) @@ -776,15 +906,20 @@ def symmetrise_prev_pred(self, px0, seq_in, alpha): """ Method for symmetrising px0 output for self-conditioning """ - _,px0_aa = self.allatom(torch.argmax(seq_in, dim=-1), px0, alpha) - px0_sym,_ = 
self.symmetry.apply_symmetry(px0_aa.to('cpu').squeeze()[:,:14], torch.argmax(seq_in, dim=-1).squeeze().to('cpu')) + _, px0_aa = self.allatom(torch.argmax(seq_in, dim=-1), px0, alpha) + px0_sym, _ = self.symmetry.apply_symmetry( + px0_aa.to("cpu").squeeze()[:, :14], + torch.argmax(seq_in, dim=-1).squeeze().to("cpu"), + ) px0_sym = px0_sym[None].to(self.device) return px0_sym + class ScaffoldedSampler(SelfConditioning): - """ + """ Model Runner for Scaffold-Constrained diffusion """ + def __init__(self, conf: DictConfig): """ Initialize scaffolded sampler. @@ -799,15 +934,30 @@ def __init__(self, conf: DictConfig): super().__init__(conf) # initialize BlockAdjacency sampling class if conf.scaffoldguided.scaffold_dir is None: - assert any(x is not None for x in (conf.contigmap.inpaint_str_helix, conf.contigmap.inpaint_str_strand, conf.contigmap.inpaint_str_loop)) + assert any( + x is not None + for x in ( + conf.contigmap.inpaint_str_helix, + conf.contigmap.inpaint_str_strand, + conf.contigmap.inpaint_str_loop, + ) + ) if conf.contigmap.inpaint_str_loop is not None: - assert conf.scaffoldguided.mask_loops == False, "You shouldn't be masking loops if you're specifying loop secondary structure" + assert ( + conf.scaffoldguided.mask_loops == False + ), "You shouldn't be masking loops if you're specifying loop secondary structure" else: # initialize BlockAdjacency sampling class - assert all(x is None for x in (conf.contigmap.inpaint_str_helix, conf.contigmap.inpaint_str_strand, conf.contigmap.inpaint_str_loop)), "can't provide scaffold_dir if you're also specifying per-residue ss" + assert all( + x is None + for x in ( + conf.contigmap.inpaint_str_helix, + conf.contigmap.inpaint_str_strand, + conf.contigmap.inpaint_str_loop, + ) + ), "can't provide scaffold_dir if you're also specifying per-residue ss" self.blockadjacency = iu.BlockAdjacency(conf, conf.inference.num_designs) - ################################################# ### Initialize target, if doing binder design ### ################################################# @@ -817,18 +967,22 @@ def __init__(self, conf: DictConfig): self.target_pdb = self.target.get_target() if conf.scaffoldguided.target_ss is not None: self.target_ss = torch.load(conf.scaffoldguided.target_ss).long() - self.target_ss = torch.nn.functional.one_hot(self.target_ss, num_classes=4) + self.target_ss = torch.nn.functional.one_hot( + self.target_ss, num_classes=4 + ) if self._conf.scaffoldguided.contig_crop is not None: - self.target_ss=self.target_ss[self.target_pdb['crop_mask']] + self.target_ss = self.target_ss[self.target_pdb["crop_mask"]] if conf.scaffoldguided.target_adj is not None: self.target_adj = torch.load(conf.scaffoldguided.target_adj).long() - self.target_adj=torch.nn.functional.one_hot(self.target_adj, num_classes=3) + self.target_adj = torch.nn.functional.one_hot( + self.target_adj, num_classes=3 + ) if self._conf.scaffoldguided.contig_crop is not None: - self.target_adj=self.target_adj[self.target_pdb['crop_mask']] - self.target_adj=self.target_adj[:,self.target_pdb['crop_mask']] + self.target_adj = self.target_adj[self.target_pdb["crop_mask"]] + self.target_adj = self.target_adj[:, self.target_pdb["crop_mask"]] else: self.target = None - self.target_pdb=False + self.target_pdb = False def sample_init(self): """ @@ -838,58 +992,67 @@ def sample_init(self): ########################## ### Process Fold Input ### ########################## - if hasattr(self, 'blockadjacency'): + if hasattr(self, "blockadjacency"): self.L, self.ss, self.adj = 
self.blockadjacency.get_scaffold() self.adj = nn.one_hot(self.adj.long(), num_classes=3) else: - self.L=100 # shim. Get's overwritten + self.L = 100 # shim. Get's overwritten ############################## ### Auto-contig generation ### - ############################## + ############################## - if self.contig_conf.contigs is None: + if self.contig_conf.contigs is None: # process target - xT = torch.full((self.L, 27,3), np.nan) - xT = get_init_xyz(xT[None,None]).squeeze() - seq_T = torch.full((self.L,),21) - self.diffusion_mask = torch.full((self.L,),False) - atom_mask = torch.full((self.L,27), False) - self.binderlen=self.L + xT = torch.full((self.L, 27, 3), np.nan) + xT = get_init_xyz(xT[None, None]).squeeze() + seq_T = torch.full((self.L,), 21) + self.diffusion_mask = torch.full((self.L,), False) + atom_mask = torch.full((self.L, 27), False) + self.binderlen = self.L if self.target: - target_L = np.shape(self.target_pdb['xyz'])[0] + target_L = np.shape(self.target_pdb["xyz"])[0] # xyz target_xyz = torch.full((target_L, 27, 3), np.nan) - target_xyz[:,:14,:] = torch.from_numpy(self.target_pdb['xyz']) + target_xyz[:, :14, :] = torch.from_numpy(self.target_pdb["xyz"]) xT = torch.cat((xT, target_xyz), dim=0) # seq - seq_T = torch.cat((seq_T, torch.from_numpy(self.target_pdb['seq'])), dim=0) + seq_T = torch.cat( + (seq_T, torch.from_numpy(self.target_pdb["seq"])), dim=0 + ) # diffusion mask - self.diffusion_mask = torch.cat((self.diffusion_mask, torch.full((target_L,), True)),dim=0) + self.diffusion_mask = torch.cat( + (self.diffusion_mask, torch.full((target_L,), True)), dim=0 + ) # atom mask mask_27 = torch.full((target_L, 27), False) - mask_27[:,:14] = torch.from_numpy(self.target_pdb['mask']) + mask_27[:, :14] = torch.from_numpy(self.target_pdb["mask"]) atom_mask = torch.cat((atom_mask, mask_27), dim=0) self.L += target_L # generate contigmap object contig = [] - for idx,i in enumerate(self.target_pdb['pdb_idx'][:-1]): - if idx==0: - start=i[1] - if i[1] + 1 != self.target_pdb['pdb_idx'][idx+1][1] or i[0] != self.target_pdb['pdb_idx'][idx+1][0]: - contig.append(f'{i[0]}{start}-{i[1]}/0 ') - start = self.target_pdb['pdb_idx'][idx+1][1] - contig.append(f"{self.target_pdb['pdb_idx'][-1][0]}{start}-{self.target_pdb['pdb_idx'][-1][1]}/0 ") + for idx, i in enumerate(self.target_pdb["pdb_idx"][:-1]): + if idx == 0: + start = i[1] + if ( + i[1] + 1 != self.target_pdb["pdb_idx"][idx + 1][1] + or i[0] != self.target_pdb["pdb_idx"][idx + 1][0] + ): + contig.append(f"{i[0]}{start}-{i[1]}/0 ") + start = self.target_pdb["pdb_idx"][idx + 1][1] + contig.append( + f"{self.target_pdb['pdb_idx'][-1][0]}{start}-{self.target_pdb['pdb_idx'][-1][1]}/0 " + ) contig.append(f"{self.binderlen}-{self.binderlen}") contig = ["".join(contig)] else: contig = [f"{self.binderlen}-{self.binderlen}"] - self.contig_map=ContigMap(self.target_pdb, contig) + self.contig_map = ContigMap(self.target_pdb, contig) self.mappings = self.contig_map.get_mappings() self.mask_seq = self.diffusion_mask self.mask_str = self.diffusion_mask - L_mapped=len(self.contig_map.ref) + L_mapped = len(self.contig_map.ref) ############################ ### Specific Contig mode ### @@ -897,59 +1060,65 @@ def sample_init(self): else: # get contigmap from command line - assert self.target is None, "Giving a target is the wrong way of handling this is you're doing contigs and secondary structure" + assert ( + self.target is None + ), "Giving a target is the wrong way of handling this is you're doing contigs and secondary structure" # process target 
and reinitialise potential_manager. This is here because the 'target' is always set up to be the second chain in out inputs. self.target_feats = iu.process_target(self.inf_conf.input_pdb) self.contig_map = self.construct_contig(self.target_feats) self.mappings = self.contig_map.get_mappings() - self.mask_seq = torch.from_numpy(self.contig_map.inpaint_seq)[None,:] - self.mask_str = torch.from_numpy(self.contig_map.inpaint_str)[None,:] - self.binderlen = len(self.contig_map.inpaint) + self.mask_seq = torch.from_numpy(self.contig_map.inpaint_seq)[None, :] + self.mask_str = torch.from_numpy(self.contig_map.inpaint_str)[None, :] + self.binderlen = len(self.contig_map.inpaint) self.L = len(self.contig_map.inpaint_seq) target_feats = self.target_feats contig_map = self.contig_map - xyz_27 = target_feats['xyz_27'] - mask_27 = target_feats['mask_27'] - seq_orig = target_feats['seq'] + xyz_27 = target_feats["xyz_27"] + mask_27 = target_feats["mask_27"] + seq_orig = target_feats["seq"] L_mapped = len(self.contig_map.ref) - seq_T=torch.full((L_mapped,),21) + seq_T = torch.full((L_mapped,), 21) seq_T[contig_map.hal_idx0] = seq_orig[contig_map.ref_idx0] seq_T[~self.mask_seq.squeeze()] = 21 diffusion_mask = self.mask_str self.diffusion_mask = diffusion_mask - - xT = torch.full((1,1,L_mapped,27,3), np.nan) - xT[:, :, contig_map.hal_idx0, ...] = xyz_27[contig_map.ref_idx0,...] + + xT = torch.full((1, 1, L_mapped, 27, 3), np.nan) + xT[:, :, contig_map.hal_idx0, ...] = xyz_27[contig_map.ref_idx0, ...] xT = get_init_xyz(xT).squeeze() atom_mask = torch.full((L_mapped, 27), False) atom_mask[contig_map.hal_idx0] = mask_27[contig_map.ref_idx0] - if hasattr(self.contig_map, 'ss_spec'): - self.adj=torch.full((L_mapped, L_mapped),2) # masked - self.adj=nn.one_hot(self.adj.long(), num_classes=3) - self.ss=iu.ss_from_contig(self.contig_map.ss_spec) - assert L_mapped==self.adj.shape[0] - + if hasattr(self.contig_map, "ss_spec"): + self.adj = torch.full((L_mapped, L_mapped), 2) # masked + self.adj = nn.one_hot(self.adj.long(), num_classes=3) + self.ss = iu.ss_from_contig(self.contig_map.ss_spec) + assert L_mapped == self.adj.shape[0] + #################### ### Get hotspots ### #################### - self.hotspot_0idx=iu.get_idx0_hotspots(self.mappings, self.ppi_conf, self.binderlen) - + self.hotspot_0idx = iu.get_idx0_hotspots( + self.mappings, self.ppi_conf, self.binderlen + ) + ######################### ### Set up potentials ### ######################### - self.potential_manager = PotentialManager(self.potential_conf, - self.ppi_conf, - self.diffuser_conf, - self.inf_conf, - self.hotspot_0idx, - self.binderlen) + self.potential_manager = PotentialManager( + self.potential_conf, + self.ppi_conf, + self.diffuser_conf, + self.inf_conf, + self.hotspot_0idx, + self.binderlen, + ) - self.chain_idx=['A' if i < self.binderlen else 'B' for i in range(self.L)] + self.chain_idx = ["A" if i < self.binderlen else "B" for i in range(self.L)] ######################## ### Handle Partial T ### @@ -960,8 +1129,8 @@ def sample_init(self): self.t_step_input = int(self.diffuser_conf.partial_T) else: self.t_step_input = int(self.diffuser_conf.T) - t_list = np.arange(1, self.t_step_input+1) - seq_T=torch.nn.functional.one_hot(seq_T, num_classes=22).float() + t_list = np.arange(1, self.t_step_input + 1) + seq_T = torch.nn.functional.one_hot(seq_T, num_classes=22).float() fa_stack, xyz_true = self.diffuser.diffuse_pose( xT, @@ -969,7 +1138,8 @@ def sample_init(self): atom_mask.squeeze(), diffusion_mask=self.diffusion_mask.squeeze(), 
t_list=t_list, - include_motif_sidechains=self.preprocess_conf.motif_sidechain_input) + include_motif_sidechains=self.preprocess_conf.motif_sidechain_input, + ) ####################### ### Set up Denoiser ### @@ -977,56 +1147,48 @@ def sample_init(self): self.denoiser = self.construct_denoiser(self.L, visible=self.mask_seq.squeeze()) - - xT = torch.clone(fa_stack[-1].squeeze()[:,:14,:]) + xT = torch.clone(fa_stack[-1].squeeze()[:, :14, :]) ################################ ### Add to Cyclic_reses init ### ################################ - - if self._conf.inference.cyclic: - if self._conf.inference.cyc_chains is None: - self.cyclic_reses = ~self.mask_str.to(self.device).squeeze() - else: - assert isinstance(self._conf.inference.cyc_chains, str), 'cyc_chains arg must be string' - cyc_chains = self._conf.inference.cyc_chains - cyc_chains = [i.upper() for i in cyc_chains] - hal_idx = self.contig_map.hal - is_cyclized = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() - for ch in cyc_chains: - ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx]).bool() - is_cyclized[ch_mask] = True - self.cyclic_reses = is_cyclized - else: - self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() + self._init_cyclic_reses(self.mask_str, self.contig_map) return xT, seq_T - + def _preprocess(self, seq, xyz_t, t): - msa_masked, msa_full, seq, xyz_prev, idx_pdb, t1d, t2d, xyz_t, alpha_t = super()._preprocess(seq, xyz_t, t, repack=False) - + msa_masked, msa_full, seq, xyz_prev, idx_pdb, t1d, t2d, xyz_t, alpha_t = ( + super()._preprocess(seq, xyz_t, t, repack=False) + ) + ################################### ### Add Adj/Secondary Structure ### ################################### - assert self.preprocess_conf.d_t1d == 28, "The checkpoint you're using hasn't been trained with sec-struc/block adjacency features" - assert self.preprocess_conf.d_t2d == 47, "The checkpoint you're using hasn't been trained with sec-struc/block adjacency features" - + assert ( + self.preprocess_conf.d_t1d == 28 + ), "The checkpoint you're using hasn't been trained with sec-struc/block adjacency features" + assert ( + self.preprocess_conf.d_t2d == 47 + ), "The checkpoint you're using hasn't been trained with sec-struc/block adjacency features" + ##################### ### Handle Target ### ##################### if self.target: - blank_ss = torch.nn.functional.one_hot(torch.full((self.L-self.binderlen,), 3), num_classes=4) + blank_ss = torch.nn.functional.one_hot( + torch.full((self.L - self.binderlen,), 3), num_classes=4 + ) full_ss = torch.cat((self.ss, blank_ss), dim=0) if self._conf.scaffoldguided.target_ss is not None: - full_ss[self.binderlen:] = self.target_ss + full_ss[self.binderlen :] = self.target_ss else: full_ss = self.ss - t1d=torch.cat((t1d, full_ss[None,None].to(self.device)), dim=-1) + t1d = torch.cat((t1d, full_ss[None, None].to(self.device)), dim=-1) t1d = t1d.float() - + ########### ### t2d ### ########### @@ -1034,19 +1196,19 @@ def _preprocess(self, seq, xyz_t, t): if self.d_t2d == 47: if self.target: full_adj = torch.zeros((self.L, self.L, 3)) - full_adj[:,:,-1] = 1. 
#set to mask - full_adj[:self.binderlen, :self.binderlen] = self.adj + full_adj[:, :, -1] = 1.0 # set to mask + full_adj[: self.binderlen, : self.binderlen] = self.adj if self._conf.scaffoldguided.target_adj is not None: - full_adj[self.binderlen:,self.binderlen:] = self.target_adj + full_adj[self.binderlen :, self.binderlen :] = self.target_adj else: full_adj = self.adj - t2d=torch.cat((t2d, full_adj[None,None].to(self.device)),dim=-1) + t2d = torch.cat((t2d, full_adj[None, None].to(self.device)), dim=-1) ########### ### idx ### ########### if self.target: - idx_pdb[:,self.binderlen:] += 200 + idx_pdb[:, self.binderlen :] += 200 return msa_masked, msa_full, seq, xyz_prev, idx_pdb, t1d, t2d, xyz_t, alpha_t
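
---

Note on the net effect of patch 3/3: the cyclic_reses setup that previously lived only in Sampler.sample_init is now shared by Sampler and ScaffoldedSampler through the new _init_cyclic_reses helper. The sketch below restates that helper's logic as a standalone function with a toy two-chain input; the function name, the hand-built mask, and the hal-index list are illustrative stand-ins, not the real class method or Hydra config.

```python
# Standalone sketch of the logic centralized by _init_cyclic_reses (patch 3/3).
# The inputs below are hypothetical stand-ins for self.mask_str / contig_map.hal.
import torch

def init_cyclic_reses(mask_str, hal_idx, cyclic, cyc_chains):
    """Return a per-residue boolean mask of residues treated as cyclized."""
    mask = mask_str.squeeze()
    if not cyclic:
        # cyclic peptide mode off: nothing is cyclized
        return torch.zeros_like(mask).bool()
    if cyc_chains is None:
        # default: every diffused (non-motif) residue is cyclized
        return ~mask
    assert isinstance(cyc_chains, str), "cyc_chains arg must be string"
    is_cyclized = torch.zeros_like(mask).bool()
    chain_ids = [idx[0] for idx in hal_idx]
    for ch in (c.upper() for c in cyc_chains):
        # mark every residue belonging to this chain as cyclic
        ch_mask = torch.tensor([c == ch for c in chain_ids], dtype=torch.bool)
        is_cyclized[ch_mask] = True
    return is_cyclized

# Toy example: 3 designed residues on chain A, 2 fixed motif residues on chain B.
mask_str = torch.tensor([[False, False, False, True, True]])  # True = structure fixed
hal_idx = [("A", 1), ("A", 2), ("A", 3), ("B", 10), ("B", 11)]
print(init_cyclic_reses(mask_str, hal_idx, cyclic=True, cyc_chains="a"))
# -> tensor([ True,  True,  True, False, False])
```

As in examples/design_ppi_scaffolded.sh above, scaffold-guided runs keep the original flag name after patch 2/3, i.e. scaffoldguided.scaffoldguided=True.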