diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 14a8fad5b..5a7411b61 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -650,11 +650,12 @@ def __call__( start = self.rng.integers(0, n - fragment_length, endpoint=True) end = start + fragment_length terminal = (end == n) and traj.terminal + # Copy the slices to enable garbage collection of full trajectory. fragment = TrajectoryWithRew( - obs=traj.obs[start : end + 1], - acts=traj.acts[start:end], - infos=traj.infos[start:end] if traj.infos is not None else None, - rews=traj.rews[start:end], + obs=traj.obs[start : end + 1].copy(), + acts=traj.acts[start:end].copy(), + infos=traj.infos[start:end].copy() if traj.infos is not None else None, + rews=traj.rews[start:end].copy(), terminal=terminal, ) fragments.append(fragment)