diff --git a/rfdiffusion/inference/symmetry.py b/rfdiffusion/inference/symmetry.py index 864a5abe..fc8548b2 100644 --- a/rfdiffusion/inference/symmetry.py +++ b/rfdiffusion/inference/symmetry.py @@ -72,6 +72,12 @@ def __init__(self, global_sym, recenter, radius, model_only_neighbors=False): self._init_octahedral() self.apply_symmetry = self._apply_octahedral + elif global_sym.lower().startswith('translational'): + # Translational symmetry + self._log.info('Initializing translational symmetry.') + self._init_translational(global_sym) + self.apply_symmetry = self._apply_translational + elif global_sym.lower() in saved_symmetries: # Using a saved symmetry self._log.info('Initializing %s symmetry order.'%global_sym) @@ -218,6 +224,38 @@ def _init_from_symrots_file(self, name): assert len(self.sym_rots) == self.order assert np.isclose(((self.sym_rots[0]-np.eye(3))**2).sum(), 0) + def _init_translational(self, global_sym): + # Example: 'translational3_0_0_10' for 3 subunits, translation vector (0,0,10) + # Format: translational{n}_{dx}_{dy}_{dz} + import re + m = re.match(r'translational(\d+)_([\d\.-]+)_([\d\.-]+)_([\d\.-]+)', global_sym.lower()) + if not m: + raise ValueError(f'Invalid translational symmetry string: {global_sym}') + order = int(m.group(1)) + dx = float(m.group(2)) + dy = float(m.group(3)) + dz = float(m.group(4)) + self.trans_vec = torch.tensor([dx, dy, dz], dtype=torch.float32) + self.order = order + + def _apply_translational(self, coords_in, seq_in): + coords_out = torch.clone(coords_in) + seq_out = torch.clone(seq_in) + + if seq_out.shape[0] % self.order != 0: + raise ValueError(f'Sequence length must be divisible by {self.order}') + + subunit_len = seq_out.shape[0] // self.order + + for i in range(self.order): + start_i = subunit_len * i + end_i = subunit_len * (i + 1) + translation = i * self.trans_vec[None, None, :] # shape: [1, 1, 3] + coords_out[start_i:end_i] = coords_in[:subunit_len] + translation + seq_out[start_i:end_i] = seq_in[:subunit_len] + + return coords_out, seq_out + def close_neighbors(self): """close_neighbors finds the rotations within self.sym_rots that correspond to close neighbors.