diff --git a/embodichain/lab/gym/envs/managers/dataset_manager.py b/embodichain/lab/gym/envs/managers/dataset_manager.py index 0a8e9d4..a0ca168 100644 --- a/embodichain/lab/gym/envs/managers/dataset_manager.py +++ b/embodichain/lab/gym/envs/managers/dataset_manager.py @@ -246,9 +246,45 @@ def finalize(self) -> Optional[str]: return None - """ - Operations - Functor settings. - """ + def get_cached_data(self) -> list[Dict[str, Any]]: + """Get cached data from all dataset functors (for online training). + + Iterates through all functors and collects cached data from those + that support online training mode (have get_cached_data method). + + Returns: + List of cached data dictionaries from all functors. + """ + all_cached_data = [] + + # Iterate through all modes and functors + for mode_cfgs in self._mode_functor_cfgs.values(): + for functor_cfg in mode_cfgs: + if hasattr(functor_cfg.func, "get_cached_data"): + cached_data = functor_cfg.func.get_cached_data() + all_cached_data.extend(cached_data) + + return all_cached_data + + def clear_cache(self) -> int: + """Clear cached data from all dataset functors (for online training). + + Iterates through all functors and clears their cache if they + support online training mode (have clear_cache method). + + Returns: + Total number of cached items cleared across all functors. + """ + total_cleared = 0 + + # Iterate through all modes and functors + for mode_cfgs in self._mode_functor_cfgs.values(): + for functor_cfg in mode_cfgs: + if hasattr(functor_cfg.func, "clear_cache"): + cleared = functor_cfg.func.clear_cache() + total_cleared += cleared + + return total_cleared def get_functor_cfg(self, functor_name: str) -> DatasetFunctorCfg: """Gets the configuration for the specified functor. @@ -267,10 +303,6 @@ def get_functor_cfg(self, functor_name: str) -> DatasetFunctorCfg: return self._mode_functor_cfgs[mode][functors.index(functor_name)] logger.log_error(f"Dataset functor '{functor_name}' not found.") - """ - Helper functions. - """ - def _prepare_functors(self): """Prepare dataset functors from configuration. diff --git a/embodichain/lab/gym/envs/managers/datasets.py b/embodichain/lab/gym/envs/managers/datasets.py index 09c0c1e..625ce8e 100644 --- a/embodichain/lab/gym/envs/managers/datasets.py +++ b/embodichain/lab/gym/envs/managers/datasets.py @@ -40,11 +40,9 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset, HF_LEROBOT_HOME LEROBOT_AVAILABLE = True - __all__ = ["LeRobotRecorder"] except ImportError: LEROBOT_AVAILABLE = False - __all__ = [] @@ -72,6 +70,11 @@ def __init__(self, cfg: DatasetFunctorCfg, env: EmbodiedEnv): - export_success_only: Whether to export only successful episodes env: The environment instance """ + if not LEROBOT_AVAILABLE: + logger.log_error( + "LeRobot is not installed. Please install it with: pip install lerobot" + ) + super().__init__(cfg, env) # Extract parameters from cfg.params @@ -104,8 +107,6 @@ def __init__(self, cfg: DatasetFunctorCfg, env: EmbodiedEnv): # Initialize dataset self._initialize_dataset() - logger.log_info(f"LeRobotRecorder initialized at: {self.dataset_path}") - @property def dataset_path(self) -> str: """Path to the dataset directory.""" @@ -215,7 +216,6 @@ def _save_episodes( elif terminateds is not None: is_success = terminateds[env_id].item() - logger.log_info(f"Episode {env_id} success: {is_success}") if self.export_success_only and not is_success: logger.log_info(f"Skipping failed episode for env {env_id}") continue @@ -295,11 +295,6 @@ def _initialize_dataset(self) -> None: fps = self.robot_meta.get("control_freq", 30) features = self._build_features() - logger.log_info("------------------------------------------") - logger.log_info(f"Building dataset: {dataset_name}") - logger.log_info(f"Parent directory: {lerobot_data_root}") - logger.log_info(f"Full path: {self.dataset_full_path}") - self.dataset = LeRobotDataset.create( repo_id=dataset_name, fps=fps, diff --git a/embodichain/lab/gym/envs/tasks/tableware/scoop_ice.py b/embodichain/lab/gym/envs/tasks/tableware/scoop_ice.py index 6a3c4c5..8eb37ed 100644 --- a/embodichain/lab/gym/envs/tasks/tableware/scoop_ice.py +++ b/embodichain/lab/gym/envs/tasks/tableware/scoop_ice.py @@ -99,7 +99,6 @@ def add_xpos_offset(self, arm_qpos: np.ndarray, offset: np.ndarray, is_left: boo return arm_qpos_offset_batch[0].to("cpu").numpy() def pack_qpos(self): - self.num_envs = self.sim.num_envs left_arm_qpos = self.trajectory["left_arm"] # [waypoint_num, dof] logger.log_info("Adding x and z offset to left arm trajectory...") left_arm_qpos = self.add_xpos_offset( diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index ebe06d6..6ebb6d6 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -453,7 +453,8 @@ class ComponentCfg: env_cfg.sim_steps_per_control = config["env"].get("sim_steps_per_control", 4) - # load dataset config + # TODO: support more env events, eg, grasp pose generation, mesh preprocessing, etc. + env_cfg.dataset = ComponentCfg() if "dataset" in config["env"]: # Define modules to search for dataset functions @@ -481,8 +482,6 @@ class ComponentCfg: setattr(env_cfg.dataset, dataset_name, dataset) - # TODO: support more env events, eg, grasp pose generation, mesh preprocessing, etc. - env_cfg.events = ComponentCfg() if "events" in config["env"]: # Define modules to search for event functions diff --git a/embodichain/lab/sim/common.py b/embodichain/lab/sim/common.py index 4e5b9b3..8c11ab3 100644 --- a/embodichain/lab/sim/common.py +++ b/embodichain/lab/sim/common.py @@ -24,6 +24,7 @@ from embodichain.lab.sim.cfg import ObjectBaseCfg from embodichain.utils import logger +from copy import deepcopy T = TypeVar("T") @@ -57,7 +58,7 @@ def __init__( if entities is None or len(entities) == 0: logger.log_error("Invalid entities list: must not be empty.") - self.cfg = cfg.copy() + self.cfg = deepcopy(cfg) self.uid = self.cfg.uid if self.uid is None: logger.log_error("UID must be set in the configuration.")