Commit: test pipelines and reconstructions

quintenroets committed Apr 11, 2024
1 parent 6095e0a commit 85a7a15
Showing 55 changed files with 242 additions and 1,686 deletions.
README.md (2 changes: 1 addition & 1 deletion)

@@ -1,7 +1,7 @@
 # Revnets
 ![Python version](https://img.shields.io/badge/python-3.10+-brightgreen)
 ![Operating system](https://img.shields.io/badge/os-linux%20%7c%20macOS-brightgreen)
-![Coverage](https://img.shields.io/badge/coverage-64%25-brightgreen)
+![Coverage](https://img.shields.io/badge/coverage-83%25-brightgreen)
 
 Reverse engineer internal parameters of black box neural networks
src/revnets/data/__init__.py (12 changes: 2 additions & 10 deletions)

@@ -1,10 +1,2 @@
-from . import (
-    arbitrary_correlated_features,
-    base,
-    correlated_features,
-    mnist,
-    mnist1d,
-    output_supervision,
-    random,
-)
-from .base import Dataset
+from . import mnist, mnist1d
+from .base import DataModule
src/revnets/data/arbitrary_correlated_features.py (19 changes: 0 additions & 19 deletions)

This file was deleted.
src/revnets/data/base.py (161 changes: 33 additions & 128 deletions)

@@ -1,146 +1,51 @@
 from dataclasses import dataclass, field
-from typing import cast
+from typing import Any, TypeVar
 
 import torch
-from pytorch_lightning import LightningDataModule, LightningModule
+from pytorch_lightning import LightningDataModule
 from torch.utils import data
-from torch.utils.data import ConcatDataset, DataLoader, Subset
-
-from revnets import training
-from revnets.training import Network
+from torch.utils.data import DataLoader, random_split
 
 from ..context import context
-from .split import Split
-from .utils import split_train_val
+
+T = TypeVar("T")
 
 
 @dataclass
-class Dataset(LightningDataModule):
-    """
-    For data size experiments, we want to now how many samples we need in order to have
-    fair comparisons, we keep the number of effective samples the same by scaling the
-    number of repetitions in the training set.
-    """
-
-    eval_batch_size: int = field(init=False)
-    train_val_dataset: data.Dataset[tuple[torch.Tensor, ...]] | None = None
-    train_dataset: data.Dataset[tuple[torch.Tensor, ...]] | None = None
-    val_dataset: data.Dataset[tuple[torch.Tensor, ...]] | None = None
-    test_dataset: data.Dataset[tuple[torch.Tensor, ...]] | None = None
-    repetition_factor: float | None = None
+class DataModule(LightningDataModule):
+    batch_size: int = context.config.target_network_training.batch_size
+    evaluation_batch_size: int = 1000
+    validation_ratio = context.config.validation_ratio
+    train: data.Dataset[Any] = field(init=False)
+    validation: data.Dataset[Any] = field(init=False)
+    train_validation: data.Dataset[Any] = field(init=False)
+    test: data.Dataset[Any] = field(init=False)
 
     def __post_init__(self) -> None:
        super().__init__()
 
-    def prepare(self) -> None:
-        self.prepare_data()
-        self.setup("train")
-
-    def prepare_data(self) -> None:
-        pass
-
-    def setup(self, stage: str | None = None) -> None:
-        # allow overriding train_dataset
-        if self.train_dataset is None:
-            assert self.train_val_dataset is not None
-            self.train_dataset, self.val_dataset = split_train_val(
-                self.train_val_dataset, val_fraction=self.validation_ratio
-            )
-
-    def train_dataloader(
-        self, shuffle: bool = True, batch_size: int | None = None
-    ) -> DataLoader[tuple[torch.Tensor, ...]]:
-        if batch_size is None:
-            batch_size = self.batch_size
-        return self.get_dataloader(
-            Split.train, batch_size, shuffle=shuffle, use_repeat=True
-        )
-
-    def val_dataloader(
-        self, shuffle: bool = False
-    ) -> DataLoader[tuple[torch.Tensor, ...]]:
-        assert self.eval_batch_size is not None
-        return self.get_dataloader(Split.valid, self.eval_batch_size, shuffle=shuffle)
-
-    def test_dataloader(
-        self, shuffle: bool = False
-    ) -> DataLoader[tuple[torch.Tensor, ...]]:
-        assert self.eval_batch_size is not None
-        return self.get_dataloader(Split.test, self.eval_batch_size, shuffle=shuffle)
-
-    def get_all_inputs(self, split: Split) -> torch.Tensor:
-        assert self.eval_batch_size is not None
-        dataloader = self.get_dataloader(
-            split, batch_size=self.eval_batch_size, shuffle=False
-        )
-        batched_inputs = tuple(batch[0] for batch in dataloader)
-        return torch.vstack(batched_inputs)
-
-    def get_all_targets(self, split: Split) -> torch.Tensor:
-        assert self.eval_batch_size is not None
-        dataloader = self.get_dataloader(
-            split, batch_size=self.eval_batch_size, shuffle=False
-        )
-        batched_targets = tuple(batch[1] for batch in dataloader)
-        return torch.vstack(batched_targets)
-
-    def get_dataloader(
-        self,
-        split: Split,
-        batch_size: int,
-        shuffle: bool = False,
-        use_repeat: bool = False,
-    ) -> DataLoader[tuple[torch.Tensor, ...]]:
-        dataset = (
-            self.create_debug_dataset(split)
-            if context.config.debug
-            else self.create_dataset(split, use_repeat)
-        )
-        if batch_size == -1:
-            batch_size = len(dataset)  # type: ignore[arg-type]
-        if context.config.debug:
-            shuffle = False
-        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
-
-    def create_debug_dataset(
-        self, split: Split
-    ) -> data.Dataset[tuple[torch.Tensor, ...]]:
-        # memorize 1 training batch during debugging
-        dataset = self.get_dataset(Split.train)
-        dataset = Subset(dataset, list(range(self.batch_size)))
-        if split.is_train:
-            dataset = ConcatDataset([dataset for _ in range(100)])
-        return dataset
-
-    def create_dataset(
-        self, split: Split, use_repeat: bool
-    ) -> data.Dataset[tuple[torch.Tensor, ...]]:
-        dataset = self.get_dataset(split)
-        if split.is_train and self.repetition_factor is not None and use_repeat:
-            repetition_factor_int = int(self.repetition_factor)
-            repetition_fraction = self.repetition_factor - repetition_factor_int
-            datasets = [dataset] * repetition_factor_int
-            if repetition_fraction:
-                last_length = int(len(dataset) * repetition_fraction)  # type: ignore[arg-type]
-                last_dataset = Subset(dataset, list(range(last_length)))
-                datasets.append(last_dataset)
-
-            dataset = ConcatDataset(datasets)
-        return dataset
-
-    def get_dataset(self, datatype: Split) -> data.Dataset[tuple[torch.Tensor, ...]]:
-        match datatype:
-            case Split.train:
-                dataset = self.train_dataset
-            case Split.valid:
-                dataset = self.val_dataset
-            case Split.test:
-                dataset = self.test_dataset
-            case Split.train_val:
-                dataset = self.train_val_dataset
-        return cast(data.Dataset[tuple[torch.Tensor, ...]], dataset)
-
-    def calibrate(self, network: Network | LightningModule) -> None:
-        self.eval_batch_size = training.calculate_max_batch_size(network, self)
+    def split_train_validation(self) -> None:
+        split_sizes = self.calculate_split_sizes()
+        seed = context.config.experiment.target_network_seed
+        random_generator = torch.Generator().manual_seed(seed)
+        split = random_split(self.train_validation, split_sizes, random_generator)
+        self.train, self.validation = split
+
+    def calculate_split_sizes(self) -> tuple[int, int]:
+        total_size = len(self.train_validation)  # type: ignore[arg-type]
+        validation_size = int(self.validation_ratio * total_size)
+        train_size = total_size - validation_size
+        return train_size, validation_size
+
+    def train_dataloader(self, shuffle: bool = True) -> DataLoader[Any]:
+        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)
+
+    def val_dataloader(self, shuffle: bool = False) -> DataLoader[Any]:
+        return DataLoader(
+            self.validation, batch_size=self.evaluation_batch_size, shuffle=False
+        )
+
+    def test_dataloader(self) -> DataLoader[Any]:
+        return DataLoader(
+            self.test, batch_size=self.evaluation_batch_size, shuffle=False
+        )
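Note: the new `DataModule` derives its train/validation split deterministically from the target-network seed via `random_split`, replacing the old `split_train_val` helper. A minimal standalone sketch of that behavior (only torch assumed; the `validation_ratio` and seed values below are placeholders for the values the module reads from `context.config`):

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy stand-in for `train_validation`.
dataset = TensorDataset(torch.arange(100, dtype=torch.float32).unsqueeze(1))

# Mirrors `calculate_split_sizes`, with an assumed validation_ratio of 0.1.
validation_size = int(0.1 * len(dataset))
train_size = len(dataset) - validation_size

# Mirrors `split_train_validation`: a seeded generator makes the split
# reproducible, so repeated runs partition the data identically.
generator = torch.Generator().manual_seed(42)  # placeholder seed
train, validation = random_split(dataset, (train_size, validation_size), generator)
assert len(train) == 90 and len(validation) == 10
```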
src/revnets/data/correlated_features.py (41 changes: 0 additions & 41 deletions)

This file was deleted.
src/revnets/data/mnist.py (17 changes: 9 additions & 8 deletions)

@@ -7,9 +7,9 @@
 
 
 @dataclass
-class Dataset(base.Dataset):
+class DataModule(base.DataModule):
     path: str = str(Path.data / "mnist")
-    transform: transforms.Compose = transforms.Compose(
+    transformation: transforms.Compose = transforms.Compose(
         [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
     )
 
@@ -18,10 +18,11 @@ def prepare_data(self) -> None:
         datasets.MNIST(self.path, train=train, download=True)
 
     def setup(self, stage: str | None = None) -> None:
-        self.train_val_dataset = datasets.MNIST(
-            self.path, train=True, download=True, transform=self.transform
-        )
-        self.test_dataset = datasets.MNIST(
-            self.path, train=False, download=True, transform=self.transform
-        )
-        super().setup()
+        self.train_validation = self.load_dataset(train=True)
+        self.test = self.load_dataset(train=False)
+        self.split_train_validation()
+
+    def load_dataset(self, train: bool) -> datasets.MNIST:
+        return datasets.MNIST(
+            self.path, train=train, download=True, transform=self.transformation
+        )
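For reference, the refactored module follows the standard LightningDataModule lifecycle; a hypothetical usage sketch (assuming the revnets package layout shown in this diff and a populated `context.config`):

```python
from revnets.data import mnist

datamodule = mnist.DataModule()
datamodule.prepare_data()  # download MNIST to the configured path if missing
datamodule.setup()         # load train/test datasets and split off validation
inputs, targets = next(iter(datamodule.train_dataloader()))
```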
src/revnets/data/mnist1d.py (68 changes: 45 additions & 23 deletions)

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import pickle
 from dataclasses import dataclass
 
@@ -20,34 +22,54 @@ class RawData(SerializationMixin):
     x_test: NDArray[np.float32]
     y_test: NDArray[np.float32]
 
+    @classmethod
+    def from_path(cls, path: Path) -> RawData:
+        with path.open("rb") as fp:
+            data = pickle.load(fp)
+        return cls(data["x"], data["y"], data["x_test"], data["y_test"])
+
+    def scale(self) -> None:
+        scaler = StandardScaler()
+        self.x = scaler.fit_transform(self.x)
+        self.x_test = scaler.transform(self.x_test)
+
+    def extract_train_validation(self) -> TensorDataset:
+        x = torch.Tensor(self.x)
+        y = torch.LongTensor(self.y)
+        return TensorDataset(x, y)
+
+    def extract_test(self) -> TensorDataset:
+        x = torch.Tensor(self.x_test)
+        y = torch.LongTensor(self.y_test)
+        return TensorDataset(x, y)
+
 
 @dataclass
-class Dataset(base.Dataset):
-    path: Path = Path.data / "mnist_1D.pkl"
-
-    def prepare_data(self) -> None:
-        data = self.load_data()
-
-        scaler = StandardScaler()
-        data.x = scaler.fit_transform(data.x)
-        data.x_test = scaler.transform(data.x_test)
-
-        x = torch.Tensor(data.x)
-        x_test = torch.Tensor(data.x_test)
-
-        y = torch.LongTensor(data.y)
-        y_test = torch.LongTensor(data.y_test)
-
-        self.train_val_dataset = TensorDataset(x, y)
-        self.test_dataset = TensorDataset(x_test, y_test)
-
-    def load_data(self) -> RawData:
-        self.check_download()
-        with self.path.open("rb") as fp:
-            data = pickle.load(fp)
-        return RawData(data["x"], data["y"], data["x_test"], data["y_test"])
-
-    def check_download(self) -> None:
-        if not self.path.exists():
-            url = "https://github.com/greydanus/mnist1d/raw/master/mnist1d_data.pkl"
-            self.path.byte_content = requests.get(url, allow_redirects=True).content
+class DataModule(base.DataModule):
+    path: Path = Path.data / "mnist_1D"
+    raw_path: Path = Path.data / "mnist_1D.pkl"
+    download_url: str = (
+        "https://github.com/greydanus/mnist1d/raw/master/mnist1d_data.pkl"
+    )
+
+    def prepare_data(self) -> None:
+        if not self.path.exists():
+            if not self.raw_path.exists():
+                self.download()
+            self.process()
+
+    def download(self) -> None:
+        response = requests.get(self.download_url, allow_redirects=True)
+        self.raw_path.byte_content = response.content
+
+    def process(self) -> None:
+        raw_data = RawData.from_path(self.raw_path)
+        raw_data.scale()
+        data = raw_data.extract_train_validation(), raw_data.extract_test()
+        path = str(self.path)
+        torch.save(data, path)
+
+    def setup(self, stage: str) -> None:
+        path = str(self.path)
+        self.train_validation, self.test = torch.load(path)
+        self.split_train_validation()
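The MNIST-1D module now separates the raw download from a processed cache: `prepare_data` fetches the pickle once, scales and tensorizes it, and stores the result with `torch.save`, which `setup` later reloads. A self-contained sketch of the same flow with assumed local file names (the real module resolves paths through its `Path` helper):

```python
import pickle

import torch
from sklearn.preprocessing import StandardScaler
from torch.utils.data import TensorDataset

# Raw pickle as published at github.com/greydanus/mnist1d.
with open("mnist_1D.pkl", "rb") as fp:
    raw = pickle.load(fp)

# Fit the scaler on the training inputs and reuse its statistics for test,
# so no test-set information leaks into the normalization.
scaler = StandardScaler()
x = scaler.fit_transform(raw["x"])
x_test = scaler.transform(raw["x_test"])

train_validation = TensorDataset(torch.Tensor(x), torch.LongTensor(raw["y"]))
test = TensorDataset(torch.Tensor(x_test), torch.LongTensor(raw["y_test"]))

# Processed cache that `setup` can reload with torch.load.
torch.save((train_validation, test), "mnist_1D")
```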