snip accumulation test
oleksost committed Sep 16, 2024
1 parent 369c404 commit da81782
Showing 2 changed files with 59 additions and 15 deletions.
22 changes: 7 additions & 15 deletions mttl/models/modifiers/sparse_mask.py
@@ -311,16 +311,18 @@ def get_weights_for_mask_learning(self):
         return (
             self.base_weight,
             self.base_bias,
-            csr_matrix(self.sparse_weights.data.cpu().float(), shape=self.sparse_weights.shape),
+            csr_matrix(
+                self.sparse_weights.data.cpu().float(), shape=self.sparse_weights.shape
+            ),
             self.sparse_bias,
         )
 
     def reset_sparse_weights(self, mask: csr_matrix):
         self.binary_mask = torch.tensor(
             mask.toarray(), device=self.base_weight.device, dtype=self.base_weight.dtype
         )
-        r,c = get_2d_indices_from_csr_matrix(mask)
-        self.sparse_weights.data[r,c] = torch.tensor(
+        r, c = get_2d_indices_from_csr_matrix(mask)
+        self.sparse_weights.data[r, c] = torch.tensor(
             mask.data, dtype=self.base_weight.dtype, device=self.base_weight.device
         )
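A minimal sketch of the index round-trip this hunk relies on, assuming get_2d_indices_from_csr_matrix simply returns the CSR matrix's nonzero coordinates (the nonzero() call below is a hypothetical stand-in for that helper):

import torch
from scipy.sparse import csr_matrix

mask = csr_matrix(torch.eye(4).numpy())  # stand-in for a learned sparse mask
dense = torch.zeros(4, 4)
r, c = mask.nonzero()  # assumed equivalent of get_2d_indices_from_csr_matrix
dense[torch.from_numpy(r).long(), torch.from_numpy(c).long()] = torch.tensor(
    mask.data, dtype=dense.dtype
)
assert dense.diagonal().sum() == 4.0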

@@ -568,17 +570,6 @@ def __init__(self, sparse_layer: SparseLinear, config: SparseMaskConfig):
             sparse_layer.base_weight, device="cpu"
         )
 
-    # make sure accumulated_sparse_weights are on CPU
-    def cuda(self, *args, **kwargs):
-        super().cuda(*args, **kwargs)
-        self.accumulated_sparse_weights = self.accumulated_sparse_weights.cpu()
-        return self
-
-    def to(self, *args, **kwargs):
-        super().to(*args, **kwargs)
-        self.accumulated_sparse_weights = self.accumulated_sparse_weights.cpu()
-        return self
-
     def switch_to_mask_update_modus(self):
         self.updating_the_mask = True
         self._selected_indices = None
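One plausible reading of why the cuda()/to() overrides can be dropped, sketched under the assumption that accumulated_sparse_weights is a plain tensor attribute rather than a registered buffer: nn.Module.to() only migrates parameters and buffers, so a plain attribute created on CPU in __init__ stays on CPU without any override.

import torch
import torch.nn as nn

class Wrapper(nn.Module):  # hypothetical minimal module, not the repo's class
    def __init__(self):
        super().__init__()
        self.accumulated = torch.zeros(2)  # plain attribute, not registered via register_buffer

m = Wrapper().to(torch.float64)
assert m.accumulated.dtype == torch.float32  # Module.to() left the plain attribute untouched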
@@ -632,8 +623,9 @@ def switch_to_weights_update_modus(self):
         self.sparse_layer_weights, self.sparse_layer_biases = None, None
         # update the mask of the sparse layer
         # SNIP sets the new weights to zeros but weights that have been learned in the past are kept
-        new_weights = torch_coo_to_scipy_csr(self.selected_params) * 0.0
+        new_weights = torch_coo_to_scipy_csr(self.selected_params)
         r, c = get_2d_indices_from_csr_matrix(new_weights)
+        new_weights *= 0.0
         new_weights[r, c] = self.accumulated_sparse_weights[r, c].float()
         self.sparse_layer.reset_sparse_weights(new_weights)
         self._selected_indices = None
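A hedged sketch of why the hunk above reorders things, assuming get_2d_indices_from_csr_matrix reads the matrix's nonzero() coordinates: multiplying a scipy CSR matrix by a scalar keeps its stored structure as explicit zeros, but nonzero() filters zero values out, so the indices must be captured before the data is zeroed.

import numpy as np
from scipy.sparse import csr_matrix

m = csr_matrix(np.array([[0.0, 2.0], [3.0, 0.0]]))
r, c = m.nonzero()  # capture coordinates while the stored values are still nonzero
m *= 0.0  # stored structure survives, values become explicit zeros
m[r, c] = np.array([7.0, 8.0])  # restore previously accumulated weights at those slots
assert m.toarray()[0, 1] == 7.0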
52 changes: 52 additions & 0 deletions tests/test_sparse_masks.py
@@ -2,6 +2,8 @@
 import os
 
 import pytest
+import torch
+import torch.nn as nn
 from pytorch_lightning import seed_everything
 
 from mttl.models.modifiers import modify_transformer
@@ -10,6 +12,7 @@
     ScatteredSparseLinearModule,
     SNIPMaskUpdateWrapper,
     SparseLinearModule,
+    SparseMaskAdapter,
     SparseMaskConfig,
 )
 
@@ -220,5 +223,54 @@ def test_snip_updater(dummy_batch):
     assert parent_module._mask_update_steps == 1
 
 
+@pytest.mark.parametrize(
+    "sps_type", ["sp_add+sp_mm", "scattered", "masked_linear", "triton_block_sparse"]
+)
+def test_snip_weight_accumulation(sps_type):
+    os.environ["CONFIG_PATH"] = "./"
+
+    seed_everything(0)
+    from transformers.models.llama.configuration_llama import LlamaConfig
+
+    adapter_config = SparseMaskConfig(
+        sps_impl=sps_type,
+        sps_type="block_sparse",
+        keep_ratio=0.02,
+        block_size=10,
+        mask_updater="snip",
+    )
+    snip_module = SparseMaskAdapter(adapter_config, nn.Linear(100, 100)).sparse_layer
+
+    assert snip_module.accumulated_sparse_weights.sum() == 0.0
+    sparse_weights = snip_module.sparse_layer.sparse_weights
+    sparse_weights.requires_grad = False
+    idxs_perm = torch.randperm(sparse_weights.flatten().shape[0])
+    idxs1 = idxs_perm[:100]
+    sparse_weights.flatten()[idxs1] += 1.0
+    assert sparse_weights.sum() == 100.0
+
+    assert snip_module.accumulated_sparse_weights.sum() == 0.0
+    snip_module.switch_to_mask_update_modus()
+    assert snip_module.accumulated_sparse_weights.sum() == 100.0
+
+    idxs2 = idxs_perm[100:200]
+    sparse_weights.flatten()[idxs2] += 1.0
+    assert sparse_weights.sum() == 200.0
+    snip_module.switch_to_mask_update_modus()
+    assert snip_module.accumulated_sparse_weights.sum() == 200.0
+
+    selected_indices = torch.zeros_like(snip_module.accumulated_sparse_weights)
+    # half already existing and half new
+    _, idxs = torch.topk(
+        snip_module.accumulated_sparse_weights.flatten(), 300, sorted=True
+    )
+    selected_indices.flatten()[idxs[100:]] = 1.0
+    assert selected_indices.sum() == 200.0
+    snip_module._selected_indices = selected_indices.float().to_sparse_coo()
+    snip_module.sparse_layer.sparse_weights *= 0.0
+    snip_module.switch_to_weights_update_modus()
+    assert snip_module.sparse_layer.sparse_weights.sum() == 100.0
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
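The test hands _selected_indices over as a torch sparse COO tensor; a minimal sketch of the COO-to-CSR hop implied by torch_coo_to_scipy_csr, assumed here to be a thin conversion helper:

import torch
from scipy.sparse import csr_matrix

sel = torch.zeros(4, 4)
sel[0, 1] = 1.0
coo = sel.to_sparse_coo()  # torch sparse COO, as in the test
r, c = coo.indices().numpy()  # 2 x nnz index matrix, unpacked into rows/cols
csr = csr_matrix((coo.values().numpy(), (r, c)), shape=tuple(coo.shape))
assert csr[0, 1] == 1.0

To run only the new parametrized test, standard pytest selection applies, e.g. pytest tests/test_sparse_masks.py -k test_snip_weight_accumulation.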
