Skip to content

Commit

Permalink
use ruff best practices
Browse files Browse the repository at this point in the history
  • Loading branch information
quintenroets committed Aug 8, 2024
1 parent 72fd579 commit b58b339
Show file tree
Hide file tree
Showing 68 changed files with 370 additions and 242 deletions.
17 changes: 10 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,24 +74,27 @@ pythonpath = [
]
addopts = "-p no:warnings"


[tool.ruff]
fix = true

[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort
"UP", # pyupgrade
select = ["ALL"]
ignore = [
"ANN101", # annotate self
"ANN102", # annotate cls
"ANN401", # annotated with Any
"D", # docstrings
]

[tool.ruff.lint.isort]
known-first-party = ["src"]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]
"tests/*" = [
"S101", # assert used
"PLR2004" # Magic value used in comparison
]

[tool.setuptools.package-data]
revnets = ["py.typed"]
3 changes: 1 addition & 2 deletions src/revnets/cli/entry_point.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from package_utils.context.entry_point import create_entry_point

from revnets.context import context
from revnets.main.main import Experiment

from ..context import context

experiment = Experiment()

entry_point = create_entry_point(experiment.run, context)
3 changes: 1 addition & 2 deletions src/revnets/cli/launch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from package_utils.context.entry_point import create_entry_point

from revnets.context import context
from revnets.launching import main

from ..context import context

entry_point = create_entry_point(main, context)
5 changes: 3 additions & 2 deletions src/revnets/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from package_utils.context import Context as Context_

from ..models import Config, Options, Path
from revnets.models import Config, Options, Path


class Context(Context_[Options, Config, None]):
Expand Down Expand Up @@ -49,7 +49,8 @@ def dtype(self) -> torch.dtype:
case 64:
dtype = torch.float64
case _:
raise ValueError(f"Unsupported precision {self.config.precision}")
msg = f"Unsupported precision {self.config.precision}"
raise ValueError(msg)
return dtype


Expand Down
20 changes: 12 additions & 8 deletions src/revnets/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
from torch.utils import data
from torch.utils.data import DataLoader, Subset, random_split

from ..context import context
from revnets.context import context

T = TypeVar("T")


@dataclass
class DataModule(LightningDataModule):
batch_size: int = field(
default_factory=lambda: context.config.target_network_training.batch_size
default_factory=lambda: context.config.target_network_training.batch_size,
)
evaluation_batch_size: int = 1000
validation_ratio: float = field(
default_factory=lambda: context.config.validation_ratio
default_factory=lambda: context.config.validation_ratio,
)
train: data.Dataset[Any] = field(init=False)
validation: data.Dataset[Any] = field(init=False)
Expand All @@ -37,17 +37,21 @@ def split(self, dataset: data.Dataset[Any]) -> list[Subset[Any]]:
random_generator = torch.Generator().manual_seed(seed)
return random_split(dataset, split_sizes, random_generator)

def train_dataloader(self, shuffle: bool = True) -> DataLoader[Any]:
return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)
def train_dataloader(self, *, shuffle: bool = True) -> DataLoader[Any]:
return DataLoader(self.train, batch_size=self.batch_size, shuffle=shuffle)

def val_dataloader(self, shuffle: bool = False) -> DataLoader[Any]:
def val_dataloader(self, *, shuffle: bool = False) -> DataLoader[Any]:
return DataLoader(
self.validation, batch_size=self.evaluation_batch_size, shuffle=False
self.validation,
batch_size=self.evaluation_batch_size,
shuffle=shuffle,
)

def test_dataloader(self) -> DataLoader[Any]:
return DataLoader(
self.test, batch_size=self.evaluation_batch_size, shuffle=False
self.test,
batch_size=self.evaluation_batch_size,
shuffle=False,
)

@property
Expand Down
20 changes: 13 additions & 7 deletions src/revnets/data/mnist.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,34 @@
from dataclasses import dataclass
from dataclasses import dataclass, field

from torchvision import datasets, transforms

from ..models import Path
from revnets.models import Path

from . import base


@dataclass
class DataModule(base.DataModule):
path: str = str(Path.data / "mnist")
transformation: transforms.Compose = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
transformation: transforms.Compose = field(
default_factory=lambda: transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))],
),
)

def prepare_data(self) -> None:
for train in (False, True):
datasets.MNIST(self.path, train=train, download=True)

def setup(self, stage: str | None = None) -> None:
def setup(self, stage: str | None = None) -> None: # noqa: ARG002
self.train_validation = self.load_dataset(train=True)
self.test = self.load_dataset(train=False)
self.split_train_validation()

def load_dataset(self, train: bool) -> datasets.MNIST:
def load_dataset(self, *, train: bool) -> datasets.MNIST:
return datasets.MNIST(
self.path, train=train, download=True, transform=self.transformation
self.path,
train=train,
download=True,
transform=self.transformation,
)
17 changes: 10 additions & 7 deletions src/revnets/data/mnist1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@

import pickle
from dataclasses import dataclass
from typing import cast
from typing import TYPE_CHECKING, cast

import numpy as np
import requests
import torch
from numpy.typing import NDArray
from package_utils.dataclasses import SerializationMixin
from simple_classproperty import classproperty
from sklearn.preprocessing import StandardScaler
from torch.utils.data import TensorDataset

from ..models import Path
from revnets.models import Path

from . import base

if TYPE_CHECKING: # pragma: nocover
import numpy as np
from numpy.typing import NDArray


@dataclass
class RawData(SerializationMixin):
Expand All @@ -27,7 +30,7 @@ class RawData(SerializationMixin):
@classmethod
def from_path(cls, path: Path) -> RawData:
with path.open("rb") as fp:
data = pickle.load(fp)
data = pickle.load(fp) # noqa: S301
return cls(data["x"], data["y"], data["x_test"], data["y_test"])

def scale(self) -> None:
Expand Down Expand Up @@ -71,7 +74,7 @@ def prepare_data(self) -> None:
self.process()

def download(self) -> None:
response = requests.get(self.download_url, allow_redirects=True)
response = requests.get(self.download_url, allow_redirects=True, timeout=10)
self.raw_path.byte_content = response.content

def process(self) -> None:
Expand All @@ -81,7 +84,7 @@ def process(self) -> None:
path = str(self.path)
torch.save(data, path)

def setup(self, stage: str) -> None:
def setup(self, stage: str) -> None: # noqa: ARG002
path = str(self.path)
self.train_validation, self.test = torch.load(path)
self.split_train_validation()
7 changes: 3 additions & 4 deletions src/revnets/evaluations/analysis/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Module
from torch.nn.functional import relu

from revnets.evaluations import base
from revnets.utils.data import compute_targets

from .. import base


@dataclass
class Evaluator(base.Evaluator):
Expand Down Expand Up @@ -46,7 +45,7 @@ def visualize_model_outputs(self, model: Module, name: str) -> None:
inputs = self.create_queries()
outputs = compute_targets(inputs, model)
if self.activation:
outputs = F.relu(outputs) # pragma: nocover
outputs = relu(outputs) # pragma: nocover
ActivationsVisualizer(outputs, name).run()

def create_queries(self) -> torch.Tensor:
Expand Down
3 changes: 1 addition & 2 deletions src/revnets/evaluations/analysis/trained_target.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from dataclasses import dataclass

from revnets.evaluations import base
from revnets.training import Trainer
from revnets.training.targets import Network

from .. import base


@dataclass
class Evaluator(base.Evaluator):
Expand Down
16 changes: 8 additions & 8 deletions src/revnets/evaluations/analysis/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from torch.nn import Module

from revnets.context import context

from ...utils.colors import get_colors
from ..weights import layers_mae
from revnets.evaluations.weights import layers_mae
from revnets.utils.colors import get_colors

cpu = torch.device("cpu")

Expand All @@ -32,23 +31,24 @@ def visualize_network_weights(self, network: Module, name: str) -> None:

@classmethod
def visualize_layer_weights(
cls, weights: torch.Tensor, title: str, n_show: int | None = 10
cls,
weights: torch.Tensor,
title: str,
n_show: int | None = 10,
) -> None:
weights = weights[:n_show].cpu()

# weights = torch.transpose(weights, 0, 1)

n_neurons = len(weights)
colors = get_colors(number_of_colors=n_neurons)
ax = cls.create_figure()

for i, (neuron, color) in enumerate(zip(weights, colors)):
for i, (neuron, color) in enumerate(zip(weights, colors, strict=False)):
label = f"Neuron {i + 1}"
ax.plot(neuron, color=color, label=label)

n_features = len(weights[0])
interval = n_features // 4
x_ticks = list(range(0, n_features, interval)) + [n_features - 1]
x_ticks = [*list(range(0, n_features, interval)), n_features - 1]
if n_features - 2 not in x_ticks and False:
x_ticks.insert(-1, n_features - 2)
x_tick_labels = [str(xtick) for xtick in x_ticks[:-1]] + ["Bias weight"]
Expand Down
2 changes: 1 addition & 1 deletion src/revnets/evaluations/attack/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

import torch

from revnets.evaluations import base
from revnets.training import Trainer
from revnets.training.targets import Metrics

from .. import base
from .network import AttackNetwork


Expand Down
Loading

0 comments on commit b58b339

Please sign in to comment.