Sparse masks #108

Merged: 64 commits merged into main on Sep 19, 2024
Conversation

@oleksost (Collaborator) commented on Sep 5, 2024

Implements sparse masks in 3 different ways:

  • masked linear: compute- and memory-inefficient, since it has to update sparse weights that are kept in dense format
  • scattered sparse: more memory- and compute-efficient; keeps only the sparse weights and uses torch.scatter_add to update just those entries (see the sketch after this list)
  • sparse linear: uses spops kernels to make things even faster. This also supports structured operations, so block sparsity should be fast out of the box (not 100% sure, will double-check). This does not work on some GPUs (spops is compiled for sm_80 architectures like A100).
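
For reference, here is a minimal sketch of the scattered-sparse idea in plain PyTorch. This is not the PR's actual implementation; the class name, the random index selection, and the keep_ratio argument are made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScatteredSparseLinearSketch(nn.Module):
    """Keeps only the sparse trainable values plus their flat indices."""

    def __init__(self, base_linear: nn.Linear, keep_ratio: float = 0.005):
        super().__init__()
        self.base = base_linear
        self.base.weight.requires_grad_(False)  # dense weights stay frozen
        n_params = self.base.weight.numel()
        n_keep = max(1, int(keep_ratio * n_params))
        # Random mask for illustration only; the PR selects indices with SNIP.
        self.register_buffer("flat_idx", torch.randperm(n_params)[:n_keep])
        self.sparse_values = nn.Parameter(torch.zeros(n_keep))

    def forward(self, x):
        # Scatter the trainable values into a flat zero buffer, reshape it to
        # the dense weight shape, and add it to the frozen base weight.
        delta = torch.zeros_like(self.base.weight).view(-1)
        delta = delta.scatter_add(0, self.flat_idx, self.sparse_values)
        weight = self.base.weight + delta.view_as(self.base.weight)
        return F.linear(x, weight, self.base.bias)
```

Only sparse_values is trainable here; the backward of scatter_add routes the gradient of the dense delta back to those entries, which is what makes this variant cheaper than keeping a full dense trainable copy.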

Also implements mask updates. Currently, only the SNIP updater is implemented; SPieL is in the pipeline.
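
To make the SNIP criterion concrete, here is a rough sketch of how a mask could be selected from it (saliency = |weight * grad|, keep the top-k entries). The function name and signature are hypothetical, not the PR's API:

```python
import torch


def snip_mask(weight: torch.Tensor, grad: torch.Tensor, keep_ratio: float) -> torch.Tensor:
    # SNIP saliency: magnitude of weight * gradient, computed from one
    # (or a few accumulated) backward passes.
    scores = (weight * grad).abs().flatten()
    n_keep = max(1, int(keep_ratio * scores.numel()))
    topk_idx = torch.topk(scores, n_keep).indices
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask[topk_idx] = True
    return mask.view_as(weight)
```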

TODOs:

  • Tests are not implemented yet.
  • When updating the mask periodically with SNIP, should we accumulate the weight updates for all masks used so far on CPU (as the masked linear case does by default)?
  • Do some profiling
  • Make sure block structure is leveraged
  • SPieL mask updater

Currently, a manual profiler gives me these numbers (for GPT-neo 125M with 0.5% sparsity):

  • SparseLinearModule (spops) with regular sparsity - Runtime: 0.066590s, Allocated Memory: 4552.14MB, Reserved Memory: 4645.19MB
  • SparseLinearModule (spops) with block sparsity - Runtime: 0.067642s, Allocated Memory: 4553.58MB, Reserved Memory: 4645.19MB
  • ScatteredSparseLinearModule with block sparsity - Runtime: 0.052826s, Allocated Memory: 4734.14MB, Reserved Memory: 4817.16MB
  • ScatteredSparseLinearModule with regular sparsity - Runtime: 0.052953s, Allocated Memory: 4734.66MB, Reserved Memory: 4817.16MB
  • MaskedLinear with regular sparsity - Runtime: 0.056629s, Allocated Memory: 4892.71MB, Reserved Memory: 4970.25MB
  • MaskedLinear with block sparsity - Runtime: 0.055440s, Allocated Memory: 4889.36MB, Reserved Memory: 4978.64MB

So ScatteredSparseLinearModule is currently the fastest, but the spops SparseLinearModule uses the least memory.
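
For context, numbers like the above can be collected with a manual loop along these lines (a sketch, assuming a CUDA device and an HF-style model whose forward returns a .loss; not the PR's profiler code):

```python
import time

import torch


def profile_step(model, batch, n_iters: int = 20):
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        loss = model(**batch).loss  # assumes an HF-style forward with .loss
        loss.backward()
        model.zero_grad(set_to_none=True)
    torch.cuda.synchronize()
    runtime = (time.perf_counter() - start) / n_iters
    allocated = torch.cuda.max_memory_allocated() / 1024**2
    reserved = torch.cuda.max_memory_reserved() / 1024**2
    print(f"Runtime: {runtime:.6f}s, "
          f"Allocated Memory: {allocated:.2f}MB, Reserved Memory: {reserved:.2f}MB")
```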


Profiled block-sparse multiplication with profile_block_sparcity.py: the stk and Triton block-sparse kernels outperform a naive torch.matmul:
[profiling plot: stk and Triton block-sparse matmul vs. torch.matmul]
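
As a side note, here is a toy illustration of the difference between the "regular" (unstructured) and "block" sparsity patterns compared above; the block size and helper names are made up, and the weight dims are assumed divisible by the block size:

```python
import torch


def random_unstructured_mask(shape, keep_ratio=0.005):
    scores = torch.rand(shape)
    threshold = torch.quantile(scores.flatten(), 1.0 - keep_ratio)
    return scores >= threshold


def random_block_mask(shape, keep_ratio=0.005, block=16):
    # Score whole block x block tiles instead of individual weights, then
    # expand every kept tile back to element resolution.
    rows, cols = shape[0] // block, shape[1] // block
    block_scores = torch.rand(rows, cols)
    threshold = torch.quantile(block_scores.flatten(), 1.0 - keep_ratio)
    block_mask = block_scores >= threshold
    return block_mask.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)
```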

@oleksost changed the title from "WIP, DO NOT MERGE: Sparse masks" to "Sparse masks" on Sep 5, 2024
Resolved review threads (outdated): mttl/config.py, mttl/models/modifiers/sparse_mask.py (two threads)
@pclucas14 (Contributor) left a comment:

Nice man, very happy with this PR

@@ -0,0 +1,200 @@
# several options to compare for block sparce operations:
Contributor:

small typo in the file name (sparcity instead of sparsity)

@oleksost (Collaborator, Author):

addressed

@@ -19,6 +20,26 @@
from mttl.utils import generate_random_string, rank_zero_only_and_wait, remote_login


def setup_profiler(args: ExpertConfig):
Contributor:

can we put this in utils? @matheper I know you want single-use code to stay out of utils, but I think this could be useful somewhere else in the future

nltk
Contributor:

where is this used?

@oleksost (Collaborator, Author):

addressed

@oleksost (Collaborator, Author):

it's used by the rouge evaluators (not an automatically installed dependency)

@oleksost merged commit 0163d2f into main on Sep 19, 2024. 6 checks passed.