46 changes: 39 additions & 7 deletions embodichain/lab/gym/envs/managers/dataset_manager.py
@@ -246,9 +246,45 @@ def finalize(self) -> Optional[str]:

return None

"""
Operations - Functor settings.
"""
def get_cached_data(self) -> list[Dict[str, Any]]:
"""Get cached data from all dataset functors (for online training).

Iterates through all functors and collects cached data from those
that support online training mode (i.e., implement a get_cached_data method).

Returns:
List of cached data dictionaries from all functors.
"""
all_cached_data = []

# Iterate through all modes and functors
for mode_cfgs in self._mode_functor_cfgs.values():
for functor_cfg in mode_cfgs:
if hasattr(functor_cfg.func, "get_cached_data"):
cached_data = functor_cfg.func.get_cached_data()
all_cached_data.extend(cached_data)

return all_cached_data

def clear_cache(self) -> int:
"""Clear cached data from all dataset functors (for online training).

Iterates through all functors and clears their cache if they
support online training mode (i.e., implement a clear_cache method).

Returns:
Total number of cached items cleared across all functors.
"""
total_cleared = 0

# Iterate through all modes and functors
for mode_cfgs in self._mode_functor_cfgs.values():
for functor_cfg in mode_cfgs:
if hasattr(functor_cfg.func, "clear_cache"):
cleared = functor_cfg.func.clear_cache()
total_cleared += cleared

return total_cleared

def get_functor_cfg(self, functor_name: str) -> DatasetFunctorCfg:
"""Gets the configuration for the specified functor.
@@ -267,10 +303,6 @@ def get_functor_cfg(self, functor_name: str) -> DatasetFunctorCfg:
return self._mode_functor_cfgs[mode][functors.index(functor_name)]
logger.log_error(f"Dataset functor '{functor_name}' not found.")

"""
Helper functions.
"""

def _prepare_functors(self):
"""Prepare dataset functors from configuration.

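The two new methods are meant to be used as a pair by an online-training consumer: drain the functor caches with get_cached_data(), train on the batch, then call clear_cache() so memory does not grow across rollouts. A minimal sketch under that assumption (consume_online_cache and trainer are illustrative names, not part of this diff):

# Hedged sketch (not from this PR): one way a training loop could drain the
# manager's functor caches between rollouts. `trainer` and its update() method
# are hypothetical; only get_cached_data() and clear_cache() come from this change.
def consume_online_cache(dataset_manager, trainer) -> int:
    batch = dataset_manager.get_cached_data()  # list of cached step dictionaries
    if batch:
        trainer.update(batch)  # hypothetical trainer hook
    # Reset the caches so memory does not grow across rollouts; the return
    # value is the number of items that were dropped.
    return dataset_manager.clear_cache()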
15 changes: 5 additions & 10 deletions embodichain/lab/gym/envs/managers/datasets.py
@@ -40,11 +40,9 @@
from lerobot.datasets.lerobot_dataset import LeRobotDataset, HF_LEROBOT_HOME

LEROBOT_AVAILABLE = True

__all__ = ["LeRobotRecorder"]
except ImportError:
LEROBOT_AVAILABLE = False

__all__ = []


@@ -72,6 +70,11 @@ def __init__(self, cfg: DatasetFunctorCfg, env: EmbodiedEnv):
- export_success_only: Whether to export only successful episodes
env: The environment instance
"""
if not LEROBOT_AVAILABLE:
logger.log_error(
"LeRobot is not installed. Please install it with: pip install lerobot"
)

super().__init__(cfg, env)

# Extract parameters from cfg.params
@@ -104,8 +107,6 @@ def __init__(self, cfg: DatasetFunctorCfg, env: EmbodiedEnv):
# Initialize dataset
self._initialize_dataset()

logger.log_info(f"LeRobotRecorder initialized at: {self.dataset_path}")

@property
def dataset_path(self) -> str:
"""Path to the dataset directory."""
@@ -215,7 +216,6 @@ def _save_episodes(
elif terminateds is not None:
is_success = terminateds[env_id].item()

logger.log_info(f"Episode {env_id} success: {is_success}")
if self.export_success_only and not is_success:
logger.log_info(f"Skipping failed episode for env {env_id}")
continue
@@ -295,11 +295,6 @@ def _initialize_dataset(self) -> None:
fps = self.robot_meta.get("control_freq", 30)
features = self._build_features()

logger.log_info("------------------------------------------")
logger.log_info(f"Building dataset: {dataset_name}")
logger.log_info(f"Parent directory: {lerobot_data_root}")
logger.log_info(f"Full path: {self.dataset_full_path}")

self.dataset = LeRobotDataset.create(
repo_id=dataset_name,
fps=fps,
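Because the manager discovers online-training support through hasattr checks, any dataset functor (LeRobotRecorder included) can opt in simply by implementing the two methods. A minimal illustrative functor, not taken from this file:

# Hedged sketch (not from this PR): a functor opts into online training by
# exposing the two duck-typed methods the manager probes for. The class name
# and its internals are illustrative only.
class InMemoryCacheFunctor:
    def __init__(self) -> None:
        self._cache: list[dict] = []

    def __call__(self, step_data: dict) -> None:
        # Record whatever the environment hands the functor at each step.
        self._cache.append(step_data)

    def get_cached_data(self) -> list[dict]:
        # Return a copy so callers keep the data even after clear_cache().
        return list(self._cache)

    def clear_cache(self) -> int:
        cleared = len(self._cache)
        self._cache.clear()
        return cleared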
1 change: 0 additions & 1 deletion embodichain/lab/gym/envs/tasks/tableware/scoop_ice.py
@@ -99,7 +99,6 @@ def add_xpos_offset(self, arm_qpos: np.ndarray, offset: np.ndarray, is_left: boo
return arm_qpos_offset_batch[0].to("cpu").numpy()

def pack_qpos(self):
self.num_envs = self.sim.num_envs
left_arm_qpos = self.trajectory["left_arm"] # [waypoint_num, dof]
logger.log_info("Adding x and z offset to left arm trajectory...")
left_arm_qpos = self.add_xpos_offset(
5 changes: 2 additions & 3 deletions embodichain/lab/gym/utils/gym_utils.py
@@ -453,7 +453,8 @@ class ComponentCfg:

env_cfg.sim_steps_per_control = config["env"].get("sim_steps_per_control", 4)

# load dataset config
# TODO: support more env events, e.g., grasp pose generation, mesh preprocessing, etc.

env_cfg.dataset = ComponentCfg()
if "dataset" in config["env"]:
# Define modules to search for dataset functions
@@ -481,8 +482,6 @@

setattr(env_cfg.dataset, dataset_name, dataset)

# TODO: support more env events, eg, grasp pose generation, mesh preprocessing, etc.

env_cfg.events = ComponentCfg()
if "events" in config["env"]:
# Define modules to search for event functions
3 changes: 2 additions & 1 deletion embodichain/lab/sim/common.py
@@ -24,6 +24,7 @@

from embodichain.lab.sim.cfg import ObjectBaseCfg
from embodichain.utils import logger
from copy import deepcopy

T = TypeVar("T")

@@ -57,7 +58,7 @@ def __init__(
if entities is None or len(entities) == 0:
logger.log_error("Invalid entities list: must not be empty.")

self.cfg = cfg.copy()
self.cfg = deepcopy(cfg)
self.uid = self.cfg.uid
if self.uid is None:
logger.log_error("UID must be set in the configuration.")
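The switch from cfg.copy() to deepcopy(cfg) matters whenever the configuration carries nested mutable fields: a shallow copy shares those inner objects, so mutating one entity's config silently mutates the original template. A small, self-contained illustration (Cfg is a hypothetical stand-in for ObjectBaseCfg):

# Hedged sketch (not from this file): shallow copies share nested mutables,
# deep copies do not. Cfg is a hypothetical stand-in for ObjectBaseCfg.
from copy import copy, deepcopy
from dataclasses import dataclass, field

@dataclass
class Cfg:
    uid: str = "obj"
    tags: list = field(default_factory=lambda: ["default"])

template = Cfg()
shallow = copy(template)
shallow.tags.append("mutated")
print(template.tags)  # ['default', 'mutated'] -- the shared inner list leaked the edit

template = Cfg()
deep = deepcopy(template)
deep.tags.append("mutated")
print(template.tags)  # ['default'] -- the template is untouched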