Skip to content

Commit

Permalink
Drop torchvision dependency for unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tvogels committed Nov 22, 2022
1 parent 8887fa7 commit 32f690a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pip.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
python-version: ${{ matrix.python-version }}

- name: Install PyTorch
run: pip install torch torchvision pillow==6.1 --extra-index-url https://download.pytorch.org/whl/cpu
run: pip install torch --extra-index-url https://download.pytorch.org/whl/cpu

- name: Build and install
run: pip install --verbose .[test]
Expand Down
3 changes: 1 addition & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,4 @@ where = src

[options.extras_require]
test =
pytest
torchvision
pytest
13 changes: 10 additions & 3 deletions tests/powersgd_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import torch
import torchvision

from powersgd import PowerSGD, Config

def build_model():
return torch.nn.Sequential(
torch.nn.Conv2d(3, 100, 3),
torch.nn.ReLU(),
torch.nn.Conv2d(100, 50, 5),
torch.nn.Linear(50, 1)
)


def test_no_compression_in_the_beginning():
model = torchvision.models.resnet50()
model = build_model()
params = list(model.parameters())
config = Config(
rank=1,
Expand All @@ -29,7 +36,7 @@ def test_no_compression_in_the_beginning():

def test_error_feedback_mechanism():
torch.set_default_dtype(torch.float64)
model = torchvision.models.resnet50()
model = build_model()
params = list(model.parameters())
config = Config(
rank=2,
Expand Down

0 comments on commit 32f690a

Please sign in to comment.