diff --git a/rf_diffusion/frame_diffusion/data/all_atom.py b/rf_diffusion/frame_diffusion/data/all_atom.py index 8f472ba..48de0f1 100644 --- a/rf_diffusion/frame_diffusion/data/all_atom.py +++ b/rf_diffusion/frame_diffusion/data/all_atom.py @@ -9,12 +9,12 @@ Rotation = ru.Rotation # Residue Constants from OpenFold/AlphaFold2. -IDEALIZED_POS37 = torch.tensor(residue_constants.restype_atom37_rigid_group_positions) +IDEALIZED_POS37 = torch.from_numpy(residue_constants.restype_atom37_rigid_group_positions) IDEALIZED_POS37_MASK = torch.any(IDEALIZED_POS37, axis=-1) -IDEALIZED_POS = torch.tensor(residue_constants.restype_atom14_rigid_group_positions) -DEFAULT_FRAMES = torch.tensor(residue_constants.restype_rigid_group_default_frame) -ATOM_MASK = torch.tensor(residue_constants.restype_atom14_mask) -GROUP_IDX = torch.tensor(residue_constants.restype_atom14_to_rigid_group) +IDEALIZED_POS = torch.from_numpy(residue_constants.restype_atom14_rigid_group_positions) +DEFAULT_FRAMES = torch.from_numpy(residue_constants.restype_rigid_group_default_frame) +ATOM_MASK = torch.from_numpy(residue_constants.restype_atom14_mask) +GROUP_IDX = torch.from_numpy(residue_constants.restype_atom14_to_rigid_group) # def torsion_angles_to_frames( diff --git a/se3_flow_matching/data/all_atom.py b/se3_flow_matching/data/all_atom.py index a1c1ac2..9ab5655 100644 --- a/se3_flow_matching/data/all_atom.py +++ b/se3_flow_matching/data/all_atom.py @@ -30,10 +30,10 @@ # Residue Constants from OpenFold/AlphaFold2. -IDEALIZED_POS = torch.tensor(residue_constants.restype_atom14_rigid_group_positions) -DEFAULT_FRAMES = torch.tensor(residue_constants.restype_rigid_group_default_frame) -ATOM_MASK = torch.tensor(residue_constants.restype_atom14_mask) -GROUP_IDX = torch.tensor(residue_constants.restype_atom14_to_rigid_group) +IDEALIZED_POS = torch.from_numpy(residue_constants.restype_atom14_rigid_group_positions) +DEFAULT_FRAMES = torch.from_numpy(residue_constants.restype_rigid_group_default_frame) +ATOM_MASK = torch.from_numpy(residue_constants.restype_atom14_mask) +GROUP_IDX = torch.from_numpy(residue_constants.restype_atom14_to_rigid_group) def to_atom37(trans, rots): fake_psi = torch.zeros(trans.shape[:-1] + (2,), device=trans.device)