-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6095e0a
commit 85a7a15
Showing
55 changed files
with
242 additions
and
1,686 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,2 @@ | ||
from . import ( | ||
arbitrary_correlated_features, | ||
base, | ||
correlated_features, | ||
mnist, | ||
mnist1d, | ||
output_supervision, | ||
random, | ||
) | ||
from .base import Dataset | ||
from . import mnist, mnist1d | ||
from .base import DataModule |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,146 +1,51 @@ | ||
from dataclasses import dataclass, field | ||
from typing import cast | ||
from typing import Any, TypeVar | ||
|
||
import torch | ||
from pytorch_lightning import LightningDataModule, LightningModule | ||
from pytorch_lightning import LightningDataModule | ||
from torch.utils import data | ||
from torch.utils.data import ConcatDataset, DataLoader, Subset | ||
|
||
from revnets import training | ||
from revnets.training import Network | ||
from torch.utils.data import DataLoader, random_split | ||
|
||
from ..context import context | ||
from .split import Split | ||
from .utils import split_train_val | ||
|
||
T = TypeVar("T") | ||
|
||
@dataclass | ||
class Dataset(LightningDataModule): | ||
""" | ||
For data size experiments, we want to know how many samples we need; in order to | ||
have fair comparisons, we keep the number of effective samples the same by scaling | ||
the number of repetitions in the training set. | ||
""" | ||
|
||
eval_batch_size: int = field(init=False) | ||
train_val_dataset: data.Dataset[tuple[torch.Tensor, ...]] | None = None | ||
train_dataset: data.Dataset[tuple[torch.Tensor, ...]] | None = None | ||
val_dataset: data.Dataset[tuple[torch.Tensor, ...]] | None = None | ||
test_dataset: data.Dataset[tuple[torch.Tensor, ...]] | None = None | ||
repetition_factor: float | None = None | ||
@dataclass | ||
class DataModule(LightningDataModule): | ||
batch_size: int = context.config.target_network_training.batch_size | ||
evaluation_batch_size: int = 1000 | ||
validation_ratio = context.config.validation_ratio | ||
train: data.Dataset[Any] = field(init=False) | ||
validation: data.Dataset[Any] = field(init=False) | ||
train_validation: data.Dataset[Any] = field(init=False) | ||
test: data.Dataset[Any] = field(init=False) | ||
|
||
def __post_init__(self) -> None: | ||
super().__init__() | ||
|
||
def prepare(self) -> None: | ||
self.prepare_data() | ||
self.setup("train") | ||
|
||
def prepare_data(self) -> None: | ||
pass | ||
|
||
def setup(self, stage: str | None = None) -> None: | ||
# allow overriding train_dataset | ||
if self.train_dataset is None: | ||
assert self.train_val_dataset is not None | ||
self.train_dataset, self.val_dataset = split_train_val( | ||
self.train_val_dataset, val_fraction=self.validation_ratio | ||
) | ||
|
||
def train_dataloader( | ||
self, shuffle: bool = True, batch_size: int | None = None | ||
) -> DataLoader[tuple[torch.Tensor, ...]]: | ||
if batch_size is None: | ||
batch_size = self.batch_size | ||
return self.get_dataloader( | ||
Split.train, batch_size, shuffle=shuffle, use_repeat=True | ||
def split_train_validation(self) -> None: | ||
split_sizes = self.calculate_split_sizes() | ||
seed = context.config.experiment.target_network_seed | ||
random_generator = torch.Generator().manual_seed(seed) | ||
split = random_split(self.train_validation, split_sizes, random_generator) | ||
self.train, self.validation = split | ||
|
||
def calculate_split_sizes(self) -> tuple[int, int]: | ||
total_size = len(self.train_validation) # type: ignore[arg-type] | ||
validation_size = int(self.validation_ratio * total_size) | ||
train_size = total_size - validation_size | ||
return train_size, validation_size | ||
|
||
def train_dataloader(self, shuffle: bool = True) -> DataLoader[Any]: | ||
return DataLoader(self.train, batch_size=self.batch_size, shuffle=True) | ||
|
||
def val_dataloader(self, shuffle: bool = False) -> DataLoader[Any]: | ||
return DataLoader( | ||
self.validation, batch_size=self.evaluation_batch_size, shuffle=False | ||
) | ||
|
||
def val_dataloader( | ||
self, shuffle: bool = False | ||
) -> DataLoader[tuple[torch.Tensor, ...]]: | ||
assert self.eval_batch_size is not None | ||
return self.get_dataloader(Split.valid, self.eval_batch_size, shuffle=shuffle) | ||
|
||
def test_dataloader( | ||
self, shuffle: bool = False | ||
) -> DataLoader[tuple[torch.Tensor, ...]]: | ||
assert self.eval_batch_size is not None | ||
return self.get_dataloader(Split.test, self.eval_batch_size, shuffle=shuffle) | ||
|
||
def get_all_inputs(self, split: Split) -> torch.Tensor: | ||
assert self.eval_batch_size is not None | ||
dataloader = self.get_dataloader( | ||
split, batch_size=self.eval_batch_size, shuffle=False | ||
def test_dataloader(self) -> DataLoader[Any]: | ||
return DataLoader( | ||
self.test, batch_size=self.evaluation_batch_size, shuffle=False | ||
) | ||
batched_inputs = tuple(batch[0] for batch in dataloader) | ||
return torch.vstack(batched_inputs) | ||
|
||
def get_all_targets(self, split: Split) -> torch.Tensor: | ||
assert self.eval_batch_size is not None | ||
dataloader = self.get_dataloader( | ||
split, batch_size=self.eval_batch_size, shuffle=False | ||
) | ||
batched_targets = tuple(batch[1] for batch in dataloader) | ||
return torch.vstack(batched_targets) | ||
|
||
def get_dataloader( | ||
self, | ||
split: Split, | ||
batch_size: int, | ||
shuffle: bool = False, | ||
use_repeat: bool = False, | ||
) -> DataLoader[tuple[torch.Tensor, ...]]: | ||
dataset = ( | ||
self.create_debug_dataset(split) | ||
if context.config.debug | ||
else self.create_dataset(split, use_repeat) | ||
) | ||
if batch_size == -1: | ||
batch_size = len(dataset) # type: ignore[arg-type] | ||
if context.config.debug: | ||
shuffle = False | ||
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) | ||
|
||
def create_debug_dataset( | ||
self, split: Split | ||
) -> data.Dataset[tuple[torch.Tensor, ...]]: | ||
# memorize 1 training batch during debugging | ||
dataset = self.get_dataset(Split.train) | ||
dataset = Subset(dataset, list(range(self.batch_size))) | ||
if split.is_train: | ||
dataset = ConcatDataset([dataset for _ in range(100)]) | ||
return dataset | ||
|
||
def create_dataset( | ||
self, split: Split, use_repeat: bool | ||
) -> data.Dataset[tuple[torch.Tensor, ...]]: | ||
dataset = self.get_dataset(split) | ||
if split.is_train and self.repetition_factor is not None and use_repeat: | ||
repetition_factor_int = int(self.repetition_factor) | ||
repetition_fraction = self.repetition_factor - repetition_factor_int | ||
datasets = [dataset] * repetition_factor_int | ||
if repetition_fraction: | ||
last_length = int(len(dataset) * repetition_fraction) # type: ignore[arg-type] | ||
last_dataset = Subset(dataset, list(range(last_length))) | ||
datasets.append(last_dataset) | ||
|
||
dataset = ConcatDataset(datasets) | ||
return dataset | ||
|
||
def get_dataset(self, datatype: Split) -> data.Dataset[tuple[torch.Tensor, ...]]: | ||
match datatype: | ||
case Split.train: | ||
dataset = self.train_dataset | ||
case Split.valid: | ||
dataset = self.val_dataset | ||
case Split.test: | ||
dataset = self.test_dataset | ||
case Split.train_val: | ||
dataset = self.train_val_dataset | ||
return cast(data.Dataset[tuple[torch.Tensor, ...]], dataset) | ||
|
||
def calibrate(self, network: Network | LightningModule) -> None: | ||
self.eval_batch_size = training.calculate_max_batch_size(network, self) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.