Sparse masks #108

Merged
merged 64 commits into from
Sep 19, 2024
Changes from 6 commits
3e1fc4b
wip
oleksost Aug 26, 2024
d583249
snip gpu: works on v100 for me
oleksost Aug 27, 2024
3e2f77d
batched projection
oleksost Aug 28, 2024
7d27a86
Merge branch 'main' into snip_clustering
oleksost Sep 3, 2024
5b814ba
sparse masks with spops
oleksost Sep 5, 2024
6d0c927
clean up
oleksost Sep 5, 2024
5f0b196
added scattered implementation
oleksost Sep 5, 2024
9e09b4a
removed callback
oleksost Sep 5, 2024
f21ae88
spops
oleksost Sep 5, 2024
d75827e
renamed config to aguments
oleksost Sep 6, 2024
4ba52e6
Merge remote-tracking branch 'origin/main' into sparse_masks
oleksost Sep 6, 2024
44967a2
reorganized
oleksost Sep 9, 2024
3618279
removed unneccessary arguments
oleksost Sep 9, 2024
b053dd9
black
oleksost Sep 9, 2024
94022d8
requirements
oleksost Sep 9, 2024
2a1a9b1
requirements
oleksost Sep 9, 2024
6ecda06
requirements
oleksost Sep 9, 2024
9de2ba8
requirements
oleksost Sep 9, 2024
ff126f8
try import spops
oleksost Sep 9, 2024
7c7b090
tests
oleksost Sep 9, 2024
74f72ba
black
oleksost Sep 9, 2024
8518ac1
black
oleksost Sep 9, 2024
e5ae2bf
added profiler
oleksost Sep 9, 2024
97a4c66
black
oleksost Sep 9, 2024
8c92375
as expert
oleksost Sep 9, 2024
69cf475
black
oleksost Sep 9, 2024
b97fb7e
nvm
oleksost Sep 9, 2024
c04b075
black
oleksost Sep 9, 2024
b6544e8
nvm
oleksost Sep 9, 2024
5321396
profiler
oleksost Sep 10, 2024
0f80994
manual profiler
oleksost Sep 11, 2024
da821a7
black
oleksost Sep 11, 2024
28bc13f
black
oleksost Sep 11, 2024
130f8e7
nvm
oleksost Sep 11, 2024
03627ef
reorganized things
oleksost Sep 13, 2024
ee18fe0
nvm
oleksost Sep 13, 2024
48b0737
black
oleksost Sep 13, 2024
28e0e87
profile block sparsity
oleksost Sep 13, 2024
4b2be32
black
oleksost Sep 13, 2024
b044e18
block sparsity
oleksost Sep 15, 2024
4c0e045
reorg
oleksost Sep 15, 2024
3fed1ea
black
oleksost Sep 15, 2024
46b78a8
black
oleksost Sep 16, 2024
1e4e972
benchmarking
oleksost Sep 16, 2024
7ffb4f9
black
oleksost Sep 16, 2024
24dbe09
profiling sparse amsks
oleksost Sep 16, 2024
e1b8e2e
config for benchmarking
oleksost Sep 16, 2024
7762e1f
black
oleksost Sep 16, 2024
369c404
accumulate snips weights
oleksost Sep 16, 2024
da81782
snip accumulation test
oleksost Sep 16, 2024
ac8cad1
nvm
oleksost Sep 16, 2024
d43c572
to file
oleksost Sep 16, 2024
1e2be9b
rename file
oleksost Sep 18, 2024
7150cb3
remove nltk
oleksost Sep 18, 2024
bd0ab93
added some sparse implementations
oleksost Sep 18, 2024
effdc04
black
oleksost Sep 18, 2024
ff31684
removed not implemented
oleksost Sep 18, 2024
b599551
add_experts_from_library to module
oleksost Sep 18, 2024
8fdc608
linear_sd import
oleksost Sep 18, 2024
ec9d485
nvm
oleksost Sep 18, 2024
660cb46
readded mask_updater param
oleksost Sep 18, 2024
9cc2573
readded mask updated + black
oleksost Sep 18, 2024
481d0b2
hf hub version
oleksost Sep 19, 2024
0bc1662
eaded nltk
oleksost Sep 19, 2024
44 changes: 44 additions & 0 deletions mttl/callbacks.py
@@ -24,6 +24,50 @@
DEBUG = False


class UpdateSparseMask(pl.Callback):
    def __init__(self, update_interval=5, dm=None, save_mask_dir=None):
        super().__init__()
        self.update_interval = update_interval
        self.update_counter = 0

        #
        self.dm = dm
        self.save_mask_dir = save_mask_dir

    def update_mask(self, pl_module, batch):

        from mttl.models.modifiers.sparse_mask import (
            make_sparse_model_during_training,
            save_mask,
        )

        # make_sparse_model(pl_module, self.dm, keep_ratio=self.keep_ratio)
        make_sparse_model_during_training(pl_module, batch)
        # save mask
        # f_name = f"{self.save_mask_dir}/mask"
        # save_mask(pl_module, f_name)

    def on_train_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch: torch.Tensor,
        batch_idx: int,
    ) -> None:
        """
        only updates the mask on epoch=0
        """
        if trainer.current_epoch == 0:
            self.update_counter += 1
            if (
                self.update_counter % self.update_interval == 0
                or self.update_counter == 1  # to set mask at the beginning
            ):
                # Update mask
                self.update_mask(pl_module, batch)
                self.update_counter = 0  # Reset counter for next interval


class LiveCheckpointCallback(pl.Callback):
    """A better model checkpoint callback, that works in synchrony with LiveLogMixin."""

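A note on the counter logic in `UpdateSparseMask.on_train_batch_start`: if I read it correctly, resetting `update_counter` to 0 after each update means the `== 1` branch matches again on the very next batch, so the mask would be refreshed on every batch of epoch 0 rather than every `update_interval` batches. The sketch below replays that logic in plain Python and compares it with a counter that is never reset. `IntervalGate` and `callback_as_written` are hypothetical names used only for illustration; they are not part of mttl, and the "intended" behavior is an assumption on my part.

```python
class IntervalGate:
    """Hypothetical helper (not part of mttl).

    Fires on the first step and then on every `interval`-th step,
    using a step counter that is never reset.
    """

    def __init__(self, interval: int = 5):
        if interval < 1:
            raise ValueError("interval must be >= 1")
        self.interval = interval
        self.step = 0

    def should_update(self) -> bool:
        self.step += 1
        return self.step == 1 or self.step % self.interval == 0


def callback_as_written(interval: int, n_steps: int) -> list:
    """Replay the counter logic from the diff and return the steps that fire."""
    counter, fired = 0, []
    for step in range(1, n_steps + 1):
        counter += 1
        if counter % interval == 0 or counter == 1:
            fired.append(step)
            counter = 0  # same reset as in the callback
    return fired


gate = IntervalGate(interval=5)
gated = [step for step in range(1, 13) if gate.should_update()]
as_written = callback_as_written(interval=5, n_steps=12)

print("never-reset counter fires on:", gated)
print("logic from the diff fires on:", as_written)
```

If updating on every batch is the intent, `update_interval` has no effect and could be dropped; otherwise dropping the reset (or the `== 1` special case after the first step) would restore the interval.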
6 changes: 6 additions & 0 deletions mttl/config.py
@@ -404,6 +404,12 @@ class TrainingArgs(DataArgs):
    pipeline_eval_tasks: str = None
    save_if_loaded_from_ckpt: bool = True
    dataset_type: str = None


    keep_ratio: float = 0.05
    BLOCK_SIZE: int = 16  # used for block-sparsity, decides the size of the block
    sps_type: str = "block_sparse"  # block_sparse,regular_sparse
    sps_impl: str = "sp_add+sp_mm"  # sp_add+sp_mm, scatter+filte

    @property
    def dataset_config(self):
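The visible hunks do not show how `keep_ratio` and `BLOCK_SIZE` are consumed, so the following is only one plausible reading, written in plain Python for illustration: split a score matrix into `block_size` x `block_size` tiles, rank tiles by the sum of absolute scores, and keep the top `keep_ratio` fraction. The function `block_mask` and the scoring rule are assumptions, not the code in `mttl/models/modifiers/sparse_mask.py`.

```python
import math


def block_mask(scores, block_size, keep_ratio):
    """Return a 0/1 mask that keeps the highest-scoring square blocks.

    Hypothetical sketch: `scores` is a list of equal-length rows whose
    dimensions are multiples of `block_size`.
    """
    rows, cols = len(scores), len(scores[0])
    if rows % block_size or cols % block_size:
        raise ValueError("matrix dimensions must be multiples of block_size")

    # Score each block by the sum of absolute values inside it.
    block_scores = []
    for bi in range(0, rows, block_size):
        for bj in range(0, cols, block_size):
            total = sum(
                abs(scores[i][j])
                for i in range(bi, bi + block_size)
                for j in range(bj, bj + block_size)
            )
            block_scores.append((total, bi, bj))

    # Keep at least one block, otherwise ceil(keep_ratio * number of blocks).
    n_keep = max(1, math.ceil(keep_ratio * len(block_scores)))
    block_scores.sort(key=lambda item: item[0], reverse=True)

    mask = [[0] * cols for _ in range(rows)]
    for _, bi, bj in block_scores[:n_keep]:
        for i in range(bi, bi + block_size):
            for j in range(bj, bj + block_size):
                mask[i][j] = 1
    return mask


scores = [
    [0.1, 0.2, 5.0, 4.0],
    [0.0, 0.1, 3.0, 6.0],
    [0.3, 0.1, 0.2, 0.1],
    [0.2, 0.0, 0.1, 0.3],
]
mask = block_mask(scores, block_size=2, keep_ratio=0.25)
for row in mask:
    print(row)
```

Under this reading, the defaults `keep_ratio = 0.05` and `BLOCK_SIZE = 16` would keep roughly 5% of the 16 x 16 tiles of each weight matrix trainable.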
1 change: 1 addition & 0 deletions mttl/models/modifiers/__init__.py
@@ -6,3 +6,4 @@
import mttl.models.modifiers.lora # noqa: F401
import mttl.models.modifiers.mlp # noqa: F401
import mttl.models.modifiers.prompt_tuning # noqa: F401
import mttl.models.modifiers.sparse_mask # noqa: F401