Support empty batches (#530)
Summary:
## Background

Poisson sampling can sometimes result in an empty input batch, especially if the sampling rate (and hence the expected batch size) is small. This is not out of the ordinary and should be handled accordingly: gradients (signal) should be set to 0, while noise should still be added.
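For intuition, here is a minimal sketch (illustrative only, not part of this diff; the constants and `expected_batch_size` are made up) of how often Poisson sampling yields an empty batch and what a DP-SGD step should do with one:

```python
import torch

# Each example is included independently with probability sample_rate,
# so a batch is empty with probability (1 - sample_rate) ** n.
n, sample_rate = 1000, 0.002  # expected batch size = 2
p_empty = (1 - sample_rate) ** n  # ~0.135 -- empty batches are routine
batch_indices = torch.rand(n).lt(sample_rate).nonzero().flatten()  # may be empty

# Desired handling of an empty batch: zero signal, noise still added.
noise_multiplier, max_grad_norm, expected_batch_size = 1.0, 1.0, 2
param = torch.nn.Parameter(torch.zeros(10))
summed_grad = torch.zeros_like(param)  # no samples -> zero clipped sum
noise = torch.normal(0.0, noise_multiplier * max_grad_norm, size=param.shape)
param.grad = (summed_grad + noise) / expected_batch_size  # "mean" reduction
```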

We've made an [attempt](https://github.com/pytorch/opacus/blob/main/opacus/data_loader.py#L31) to support this behaviour, but it wasn't fully covered with tests and got broken over time. As a result, we currently have a DataLoader that is capable of producing zero-sized batches, a GradSampleModule that only partially supports them, and a DPOptimizer that doesn't support them at all.

This PR addresses Issue #522 (thanks to xichens for reporting).

## Improvements

This diff fixes the following:

* DPOptimizer can now handle empty batches
* BatchMemoryManager can now handle empty batches
* Adds a PrivacyEngine test with empty batches
* Adds BatchMemoryManager test with empty batches
* DataLoader now respects the dtype of the inputs (previously, empty batches only worked with float input tensors)
* ExpandedWeights still can't process empty batches, which we call out in our README (FYI samdow)

Pull Request resolved: #530

Reviewed By: alexandresablayrolles

Differential Revision: D40676213

Pulled By: ffuuugor

fbshipit-source-id: dc637fd91a3c20d481d22c5de97d22d42e423a71
Igor Shilov authored and facebook-github-bot committed Nov 7, 2022
1 parent 64680d1 commit c5562a7
Showing 16 changed files with 191 additions and 49 deletions.
36 changes: 30 additions & 6 deletions opacus/data_loader.py
@@ -13,7 +13,7 @@
# limitations under the License.

import logging
-from typing import Any, Optional, Sequence
+from typing import Any, Optional, Sequence, Tuple, Type, Union

import torch
from opacus.utils.uniform_sampler import (
@@ -29,7 +29,10 @@


def wrap_collate_with_empty(
-    collate_fn: Optional[_collate_fn_t], sample_empty_shapes: Sequence
+    *,
+    collate_fn: Optional[_collate_fn_t],
+    sample_empty_shapes: Sequence[Tuple],
+    dtypes: Sequence[Union[torch.dtype, Type]],
):
"""
Wraps given collate function to handle empty batches.
@@ -49,12 +52,15 @@ def collate(batch):
if len(batch) > 0:
return collate_fn(batch)
else:
-            return [torch.zeros(x) for x in sample_empty_shapes]
+            return [
+                torch.zeros(shape, dtype=dtype)
+                for shape, dtype in zip(sample_empty_shapes, dtypes)
+            ]

return collate


def shape_safe(x: Any):
def shape_safe(x: Any) -> Tuple:
"""
Exception-safe getter for ``shape`` attribute
Expand All @@ -67,6 +73,19 @@ def shape_safe(x: Any):
return x.shape if hasattr(x, "shape") else ()


def dtype_safe(x: Any) -> Union[torch.dtype, Type]:
"""
Exception-safe getter for ``dtype`` attribute
Args:
x: any object
Returns:
``x.dtype`` if attribute exists, type of x otherwise
"""
return x.dtype if hasattr(x, "dtype") else type(x)


class DPDataLoader(DataLoader):
"""
DataLoader subclass that always does Poisson sampling and supports empty batches
@@ -143,7 +162,8 @@ def __init__(
sample_rate=sample_rate,
generator=generator,
)
-        sample_empty_shapes = [[0, *shape_safe(x)] for x in dataset[0]]
+        sample_empty_shapes = [(0, *shape_safe(x)) for x in dataset[0]]
+        dtypes = [dtype_safe(x) for x in dataset[0]]
if collate_fn is None:
collate_fn = default_collate

@@ -156,7 +176,11 @@
dataset=dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
-            collate_fn=wrap_collate_with_empty(collate_fn, sample_empty_shapes),
+            collate_fn=wrap_collate_with_empty(
+                collate_fn=collate_fn,
+                sample_empty_shapes=sample_empty_shapes,
+                dtypes=dtypes,
+            ),
pin_memory=pin_memory,
timeout=timeout,
worker_init_fn=worker_init_fn,
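To make the dtype fix concrete, here is a usage sketch of the wrapper above (my example, not from the diff; `default_collate` is torch's standard collate function):

```python
import torch
from torch.utils.data._utils.collate import default_collate

from opacus.data_loader import dtype_safe, shape_safe, wrap_collate_with_empty

dataset = [(torch.randn(5), 7)]  # one (float features, int label) example
sample_empty_shapes = [(0, *shape_safe(x)) for x in dataset[0]]  # [(0, 5), (0,)]
dtypes = [dtype_safe(x) for x in dataset[0]]  # [torch.float32, int]

collate = wrap_collate_with_empty(
    collate_fn=default_collate,
    sample_empty_shapes=sample_empty_shapes,
    dtypes=dtypes,
)

x, y = collate([])  # an empty Poisson batch
assert x.shape == (0, 5) and x.dtype == torch.float32
assert y.shape == (0,) and y.dtype == torch.int64  # previously float32
```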
1 change: 1 addition & 0 deletions opacus/grad_sample/README.md
@@ -74,6 +74,7 @@ Please note that these are known limitations and we plan to improve Expanded Weights
| `batch_first=False` | ✅ Supported | Not supported | ✅ Supported |
| Recurrent networks | ✅ Supported | Not supported | ✅ Supported |
| Padding `same` in Conv | ✅ Supported | Not supported | ✅ Supported |
| Empty poisson batches | ✅ Supported | Not supported | Not supported |

† Note that performance differences are unstable and can vary a lot depending on the exact model and batch size.
Numbers above are averaged over benchmarks with small models consisting of convolutional and linear layers.
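A practical consequence of the new table row: with Poisson sampling (which can produce empty batches), stick to the hooks-based grad sampler. A hedged sketch, assuming `model`, `optimizer`, and `data_loader` already exist:

```python
from opacus import PrivacyEngine

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    poisson_sampling=True,  # batches may be empty
    grad_sample_mode="hooks",  # ExpandedWeights ("ew") can't handle them
)
```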
8 changes: 8 additions & 0 deletions opacus/grad_sample/conv.py
@@ -42,6 +42,14 @@ def compute_conv_grad_sample(
"""
activations = activations[0]
n = activations.shape[0]
if n == 0:
# Empty batch
ret = {}
ret[layer.weight] = torch.zeros_like(layer.weight).unsqueeze(0)
if layer.bias is not None and layer.bias.requires_grad:
ret[layer.bias] = torch.zeros_like(layer.bias).unsqueeze(0)
return ret

# get activations and backprops in shape depending on the Conv layer
if type(layer) == nn.Conv2d:
activations = unfold2d(
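A quick shape check of the early return above (plain torch, my illustration): the placeholder keeps the per-sample layout, with the sample dimension first and an all-zero contribution.

```python
import torch
import torch.nn as nn

layer = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
ret = {layer.weight: torch.zeros_like(layer.weight).unsqueeze(0)}
if layer.bias is not None and layer.bias.requires_grad:
    ret[layer.bias] = torch.zeros_like(layer.bias).unsqueeze(0)

assert ret[layer.weight].shape == (1, 8, 3, 3, 3)  # (sample, *weight.shape)
assert ret[layer.bias].shape == (1, 8)  # (sample, *bias.shape)
```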
4 changes: 4 additions & 0 deletions opacus/grad_sample/embedding.py
@@ -40,6 +40,10 @@ def compute_embedding_grad_sample(
torch.backends.cudnn.deterministic = True

batch_size = activations.shape[0]
if batch_size == 0:
ret[layer.weight] = torch.zeros_like(layer.weight).unsqueeze(0)
return ret

index = (
activations.unsqueeze(-1)
.expand(*activations.shape, layer.embedding_dim)
18 changes: 11 additions & 7 deletions opacus/optimizers/optimizer.py
@@ -394,13 +394,17 @@ def clip_and_accumulate(self):
Stores clipped and aggregated gradients into ``p.summed_grad``
"""

-        per_param_norms = [
-            g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
-        ]
-        per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
-        per_sample_clip_factor = (self.max_grad_norm / (per_sample_norms + 1e-6)).clamp(
-            max=1.0
-        )
+        if len(self.grad_samples[0]) == 0:
+            # Empty batch
+            per_sample_clip_factor = torch.zeros((0,))
+        else:
+            per_param_norms = [
+                g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
+            ]
+            per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
+            per_sample_clip_factor = (
+                self.max_grad_norm / (per_sample_norms + 1e-6)
+            ).clamp(max=1.0)

for p in self.params:
_check_processed_flag(p.grad_sample)
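For reference, a sketch of the clipping arithmetic this guards (illustrative shapes, not library code; the final contraction mirrors the `summed_grad` accumulation that follows in `clip_and_accumulate`):

```python
import torch

max_grad_norm = 1.0

# Non-empty batch: 4 samples, one parameter with 10 entries.
grad_sample = torch.randn(4, 10)
per_sample_norms = grad_sample.reshape(4, -1).norm(2, dim=-1)
per_sample_clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(max=1.0)
clipped_sum = torch.einsum("i,i...->...", per_sample_clip_factor, grad_sample)

# Empty batch: short-circuit to a zero-length clip factor; the same
# contraction then yields an all-zero signal, and noise is added later.
empty_factor = torch.zeros((0,))
empty_grad_sample = torch.randn(0, 10)
zero_sum = torch.einsum("i,i...->...", empty_factor, empty_grad_sample)
assert zero_sum.shape == (10,) and zero_sum.abs().sum() == 0
```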
88 changes: 78 additions & 10 deletions opacus/tests/batch_memory_manager_test.py
@@ -37,37 +37,42 @@ class BatchMemoryManagerTest(unittest.TestCase):
GSM_MODE = "hooks"

def setUp(self) -> None:
-        self.data_size = 100
-        self.batch_size = 10
+        self.data_size = 256
self.inps = torch.randn(self.data_size, 5)
self.tgts = torch.randn(
self.data_size,
)

self.dataset = TensorDataset(self.inps, self.tgts)

-    def _init_training(self, **data_loader_kwargs):
+    def _init_training(self, batch_size=10, **data_loader_kwargs):
model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data_loader = DataLoader(
-            self.dataset, batch_size=self.batch_size, **data_loader_kwargs
+            self.dataset, batch_size=batch_size, **data_loader_kwargs
)

return model, optimizer, data_loader

@given(
num_workers=st.integers(0, 4),
pin_memory=st.booleans(),
batch_size=st.sampled_from([8, 16, 64]),
max_physical_batch_size=st.sampled_from([4, 8]),
)
@settings(deadline=10000)
def test_basic(
self,
num_workers: int,
pin_memory: bool,
batch_size: int,
max_physical_batch_size: int,
):
batches_per_step = max(1, batch_size // max_physical_batch_size)
model, optimizer, data_loader = self._init_training(
num_workers=num_workers,
pin_memory=pin_memory,
batch_size=batch_size,
)

privacy_engine = PrivacyEngine()
@@ -80,22 +85,19 @@ def test_basic(
poisson_sampling=False,
grad_sample_mode=self.GSM_MODE,
)
-        max_physical_batch_size = 3
with BatchMemoryManager(
data_loader=data_loader,
max_physical_batch_size=max_physical_batch_size,
optimizer=optimizer,
) as new_data_loader:
-            self.assertEqual(
-                len(data_loader), len(data_loader.dataset) // self.batch_size
-            )
+            self.assertEqual(len(data_loader), len(data_loader.dataset) // batch_size)
self.assertEqual(
len(new_data_loader),
len(data_loader.dataset) // max_physical_batch_size,
)
weights_before = torch.clone(model._module.fc.weight)
for i, (x, y) in enumerate(new_data_loader):
-                self.assertTrue(x.shape[0] <= 3)
+                self.assertTrue(x.shape[0] <= max_physical_batch_size)

out = model(x)
loss = (y - out).mean()
@@ -104,7 +106,63 @@
optimizer.step()
optimizer.zero_grad()

-                if i % 4 < 3:
+                if (i + 1) % batches_per_step > 0:
self.assertTrue(
torch.allclose(model._module.fc.weight, weights_before)
)
else:
self.assertFalse(
torch.allclose(model._module.fc.weight, weights_before)
)
weights_before = torch.clone(model._module.fc.weight)

@given(
num_workers=st.integers(0, 4),
pin_memory=st.booleans(),
)
@settings(deadline=10000)
def test_empty_batch(
self,
num_workers: int,
pin_memory: bool,
):
batch_size = 2
max_physical_batch_size = 10
torch.manual_seed(30)

model, optimizer, data_loader = self._init_training(
num_workers=num_workers,
pin_memory=pin_memory,
batch_size=batch_size,
)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
module=model,
optimizer=optimizer,
data_loader=data_loader,
noise_multiplier=0.0,
max_grad_norm=1e5,
poisson_sampling=True,
grad_sample_mode=self.GSM_MODE,
)
with BatchMemoryManager(
data_loader=data_loader,
max_physical_batch_size=max_physical_batch_size,
optimizer=optimizer,
) as new_data_loader:
weights_before = torch.clone(model._module.fc.weight)
for i, (x, y) in enumerate(new_data_loader):
self.assertTrue(x.shape[0] <= max_physical_batch_size)

out = model(x)
loss = (y - out).mean()

loss.backward()
optimizer.step()
optimizer.zero_grad()

if len(x) == 0:
self.assertTrue(
torch.allclose(model._module.fc.weight, weights_before)
)
@@ -174,3 +232,13 @@ def test_equivalent_to_one_batch(self):
)
class BatchMemoryManagerTestWithExpandedWeights(BatchMemoryManagerTest):
GSM_MODE = "ew"

def test_empty_batch(self):
pass


@unittest.skipIf(
torch.__version__ >= API_CUTOFF_VERSION, "not supported in this torch version"
)
class BatchMemoryManagerTestWithFunctorch(BatchMemoryManagerTest):
GSM_MODE = "functorch"
36 changes: 24 additions & 12 deletions opacus/tests/grad_samples/common.py
@@ -15,7 +15,7 @@

import io
import unittest
-from typing import Dict, List, Union
+from typing import Dict, Iterable, List, Tuple, Union

import numpy as np
import torch
@@ -36,6 +36,13 @@ def shrinker(x, factor: int = 2):
return max(1, x // factor) # if avoid returning 0 for x == 1


def is_batch_empty(batch: Union[torch.Tensor, Iterable[torch.Tensor]]):
if type(batch) is torch.Tensor:
return batch.numel() == 0
else:
return batch[0].numel() == 0


class ModelWithLoss(nn.Module):
"""
To test the gradients of a module, we need to have a loss.
@@ -221,7 +228,7 @@ def compute_opacus_grad_sample(

def run_test(
self,
-        x: Union[torch.Tensor, PackedSequence],
+        x: Union[torch.Tensor, PackedSequence, Tuple],
module: nn.Module,
batch_first=True,
atol=10e-6,
@@ -235,7 +242,9 @@
except ImportError:
grad_sample_modes = ["hooks"]

-        if type(module) is nn.EmbeddingBag:
+        if type(module) is nn.EmbeddingBag or (
+            type(x) is not PackedSequence and is_batch_empty(x)
+        ):
grad_sample_modes = ["hooks"]

for grad_sample_mode in grad_sample_modes:
@@ -277,6 +286,14 @@ def run_test_with_reduction(
grad_sample_mode="hooks",
chunk_method=iter,
):
opacus_grad_samples = self.compute_opacus_grad_sample(
x,
module,
batch_first=batch_first,
loss_reduction=loss_reduction,
grad_sample_mode=grad_sample_mode,
)

if type(x) is PackedSequence:
x_unpacked = _unpack_packedsequences(x)
microbatch_grad_samples = self.compute_microbatch_grad_sample(
@@ -285,22 +302,17 @@
batch_first=batch_first,
loss_reduction=loss_reduction,
)
-        else:
+        elif not is_batch_empty(x):
microbatch_grad_samples = self.compute_microbatch_grad_sample(
x,
module,
batch_first=batch_first,
loss_reduction=loss_reduction,
chunk_method=chunk_method,
)

-        opacus_grad_samples = self.compute_opacus_grad_sample(
-            x,
-            module,
-            batch_first=batch_first,
-            loss_reduction=loss_reduction,
-            grad_sample_mode=grad_sample_mode,
-        )
else:
# We've checked opacus can handle 0-sized batch. Microbatch doesn't make sense
return

if microbatch_grad_samples.keys() != opacus_grad_samples.keys():
raise ValueError(
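A quick check of the helper's two branches (my illustration; the import path follows this diff):

```python
import torch

from opacus.tests.grad_samples.common import is_batch_empty

assert is_batch_empty(torch.zeros(0, 3))  # bare tensor, numel == 0
assert is_batch_empty((torch.zeros(0, 3), torch.zeros(0)))  # checks first element
assert not is_batch_empty(torch.zeros(2, 3))
```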
6 changes: 4 additions & 2 deletions opacus/tests/grad_samples/conv1d_test.py
@@ -25,7 +25,7 @@

class Conv1d_test(GradSampleHooks_test):
@given(
-        N=st.integers(1, 4),
+        N=st.integers(0, 4),
C=st.sampled_from([1, 3, 32]),
W=st.integers(6, 10),
out_channels_mapper=st.sampled_from([expander, shrinker]),
@@ -67,4 +67,6 @@ def test_conv1d(
dilation=dilation,
groups=groups,
)
-        self.run_test(x, conv, batch_first=True, atol=10e-5, rtol=10e-4)
+        self.run_test(
+            x, conv, batch_first=True, atol=10e-5, rtol=10e-4, ew_compatible=N > 0
+        )
