From 7e4418a362419693d7cf4142545c2986e6a418af Mon Sep 17 00:00:00 2001 From: JohnToro-CZAF Date: Thu, 28 Aug 2025 03:57:17 +0700 Subject: [PATCH] HOTFIX: newer torch doesn't supporting init tensor from numpy anymore --- rf_diffusion/frame_diffusion/data/all_atom.py | 10 +++++----- se3_flow_matching/data/all_atom.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/rf_diffusion/frame_diffusion/data/all_atom.py b/rf_diffusion/frame_diffusion/data/all_atom.py index 8f472ba..48de0f1 100644 --- a/rf_diffusion/frame_diffusion/data/all_atom.py +++ b/rf_diffusion/frame_diffusion/data/all_atom.py @@ -9,12 +9,12 @@ Rotation = ru.Rotation # Residue Constants from OpenFold/AlphaFold2. -IDEALIZED_POS37 = torch.tensor(residue_constants.restype_atom37_rigid_group_positions) +IDEALIZED_POS37 = torch.from_numpy(residue_constants.restype_atom37_rigid_group_positions) IDEALIZED_POS37_MASK = torch.any(IDEALIZED_POS37, axis=-1) -IDEALIZED_POS = torch.tensor(residue_constants.restype_atom14_rigid_group_positions) -DEFAULT_FRAMES = torch.tensor(residue_constants.restype_rigid_group_default_frame) -ATOM_MASK = torch.tensor(residue_constants.restype_atom14_mask) -GROUP_IDX = torch.tensor(residue_constants.restype_atom14_to_rigid_group) +IDEALIZED_POS = torch.from_numpy(residue_constants.restype_atom14_rigid_group_positions) +DEFAULT_FRAMES = torch.from_numpy(residue_constants.restype_rigid_group_default_frame) +ATOM_MASK = torch.from_numpy(residue_constants.restype_atom14_mask) +GROUP_IDX = torch.from_numpy(residue_constants.restype_atom14_to_rigid_group) # def torsion_angles_to_frames( diff --git a/se3_flow_matching/data/all_atom.py b/se3_flow_matching/data/all_atom.py index a1c1ac2..9ab5655 100644 --- a/se3_flow_matching/data/all_atom.py +++ b/se3_flow_matching/data/all_atom.py @@ -30,10 +30,10 @@ # Residue Constants from OpenFold/AlphaFold2. -IDEALIZED_POS = torch.tensor(residue_constants.restype_atom14_rigid_group_positions) -DEFAULT_FRAMES = torch.tensor(residue_constants.restype_rigid_group_default_frame) -ATOM_MASK = torch.tensor(residue_constants.restype_atom14_mask) -GROUP_IDX = torch.tensor(residue_constants.restype_atom14_to_rigid_group) +IDEALIZED_POS = torch.from_numpy(residue_constants.restype_atom14_rigid_group_positions) +DEFAULT_FRAMES = torch.from_numpy(residue_constants.restype_rigid_group_default_frame) +ATOM_MASK = torch.from_numpy(residue_constants.restype_atom14_mask) +GROUP_IDX = torch.from_numpy(residue_constants.restype_atom14_to_rigid_group) def to_atom37(trans, rots): fake_psi = torch.zeros(trans.shape[:-1] + (2,), device=trans.device)