Skip to content

Commit

Permalink
account for both arrow and torch datasets when splitting
Browse files Browse the repository at this point in the history
  • Loading branch information
Alessandro Sordoni committed Sep 17, 2024
1 parent 4c2b2ae commit 078ec89
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions mttl/datamodule/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,9 @@ def task_to_id(self):
return self._task_to_id

def create_train_valid_split(
self, dataset: ArrowDataset, validation_portion: float = 0.05
self,
dataset: Union[ArrowDataset, torch.utils.data.Dataset],
validation_portion: float = 0.05,
):
# always use the same split for the dataset
validation_portion = validation_portion or self.config.validation_portion
Expand All @@ -647,10 +649,25 @@ def create_train_valid_split(
)
return dataset, None

split_dataset = dataset.train_test_split(
test_size=validation_portion, seed=self.rng.seed()
)
return split_dataset["train"], split_dataset["test"]
if isinstance(dataset, ArrowDataset):
split_dataset = dataset.train_test_split(
test_size=validation_portion, generator=self.rng
)
return split_dataset["train"], split_dataset["test"]
elif isinstance(dataset, torch.utils.data.Dataset):
split_dataset = torch.utils.data.random_split(
dataset,
[
int(len(dataset) * (1 - validation_portion)),
int(len(dataset) * validation_portion),
],
seed=self.rng,
)
return split_dataset[0], split_dataset[1]
else:
raise ValueError(
"Only ArrowDataset and torch.utils.data.Dataset are supported for train/valid split."
)

def subsample_dataset(self, dataset, n_samples, per_task=False):
"""
Expand Down

0 comments on commit 078ec89

Please sign in to comment.