diff --git a/diffusion_policy/dataset/blockpush_lowdim_dataset.py b/diffusion_policy/dataset/blockpush_lowdim_dataset.py index 86242c22..a9c043f4 100644 --- a/diffusion_policy/dataset/blockpush_lowdim_dataset.py +++ b/diffusion_policy/dataset/blockpush_lowdim_dataset.py @@ -29,6 +29,7 @@ def __init__(self, n_episodes=self.replay_buffer.n_episodes, val_ratio=val_ratio, seed=seed) + self.val_mask = val_mask train_mask = ~val_mask self.sampler = SequenceSampler( replay_buffer=self.replay_buffer, @@ -52,9 +53,9 @@ def get_validation_dataset(self): sequence_length=self.horizon, pad_before=self.pad_before, pad_after=self.pad_after, - episode_mask=~self.train_mask + episode_mask=self.val_mask ) - val_set.train_mask = ~self.train_mask + val_set.train_mask = self.val_mask return val_set def get_normalizer(self, mode='limits', **kwargs): diff --git a/diffusion_policy/dataset/kitchen_lowdim_dataset.py b/diffusion_policy/dataset/kitchen_lowdim_dataset.py index 601e21cb..f4fa69fd 100644 --- a/diffusion_policy/dataset/kitchen_lowdim_dataset.py +++ b/diffusion_policy/dataset/kitchen_lowdim_dataset.py @@ -40,6 +40,7 @@ def __init__(self, n_episodes=self.replay_buffer.n_episodes, val_ratio=val_ratio, seed=seed) + self.val_mask = val_mask train_mask = ~val_mask self.sampler = SequenceSampler( replay_buffer=self.replay_buffer, @@ -60,9 +61,9 @@ def get_validation_dataset(self): sequence_length=self.horizon, pad_before=self.pad_before, pad_after=self.pad_after, - episode_mask=~self.train_mask + episode_mask=self.val_mask ) - val_set.train_mask = ~self.train_mask + val_set.train_mask = self.val_mask return val_set def get_normalizer(self, mode='limits', **kwargs): diff --git a/diffusion_policy/dataset/kitchen_mjl_lowdim_dataset.py b/diffusion_policy/dataset/kitchen_mjl_lowdim_dataset.py index e3173818..36f60d1f 100644 --- a/diffusion_policy/dataset/kitchen_mjl_lowdim_dataset.py +++ b/diffusion_policy/dataset/kitchen_mjl_lowdim_dataset.py @@ -61,6 +61,7 @@ def __init__(self, n_episodes=self.replay_buffer.n_episodes, val_ratio=val_ratio, seed=seed) + self.val_mask = val_mask train_mask = ~val_mask self.sampler = SequenceSampler( replay_buffer=self.replay_buffer, @@ -81,9 +82,9 @@ def get_validation_dataset(self): sequence_length=self.horizon, pad_before=self.pad_before, pad_after=self.pad_after, - episode_mask=~self.train_mask + episode_mask=self.val_mask ) - val_set.train_mask = ~self.train_mask + val_set.train_mask = self.val_mask return val_set def get_normalizer(self, mode='limits', **kwargs): diff --git a/diffusion_policy/dataset/pusht_dataset.py b/diffusion_policy/dataset/pusht_dataset.py index dc3ec1c8..cca45654 100644 --- a/diffusion_policy/dataset/pusht_dataset.py +++ b/diffusion_policy/dataset/pusht_dataset.py @@ -30,6 +30,7 @@ def __init__(self, n_episodes=self.replay_buffer.n_episodes, val_ratio=val_ratio, seed=seed) + self.val_mask = val_mask train_mask = ~val_mask train_mask = downsample_mask( mask=train_mask, @@ -58,9 +59,9 @@ def get_validation_dataset(self): sequence_length=self.horizon, pad_before=self.pad_before, pad_after=self.pad_after, - episode_mask=~self.train_mask + episode_mask=self.val_mask ) - val_set.train_mask = ~self.train_mask + val_set.train_mask = self.val_mask return val_set def get_normalizer(self, mode='limits', **kwargs): diff --git a/diffusion_policy/dataset/pusht_image_dataset.py b/diffusion_policy/dataset/pusht_image_dataset.py index f096a8f0..e6489620 100644 --- a/diffusion_policy/dataset/pusht_image_dataset.py +++ b/diffusion_policy/dataset/pusht_image_dataset.py @@ -28,6 +28,7 @@ def __init__(self, n_episodes=self.replay_buffer.n_episodes, val_ratio=val_ratio, seed=seed) + self.val_mask = val_mask; train_mask = ~val_mask train_mask = downsample_mask( mask=train_mask, @@ -52,9 +53,9 @@ def get_validation_dataset(self): sequence_length=self.horizon, pad_before=self.pad_before, pad_after=self.pad_after, - episode_mask=~self.train_mask + episode_mask=self.val_mask ) - val_set.train_mask = ~self.train_mask + val_set.train_mask = self.val_mask return val_set def get_normalizer(self, mode='limits', **kwargs):