Faster & memory-efficient logprobs calculation #583

Open · wants to merge 1 commit into main
Conversation

@li-plus (Contributor) commented Dec 2, 2023

The current logprobs_of_labels computes logprobs with a log_softmax followed by a gather. When the input logits tensor is not contiguous, log_softmax first materializes a contiguous copy of it, which is very large (batch_size * seq_len * vocab_size elements, e.g. 32 * 2048 * 64000 * 2 B = 8 GB in fp16 for typical settings).

This PR feeds the contiguous logits directly into log_softmax, reducing peak CUDA memory and removing the redundant copy.
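For context, here is a minimal sketch of the before/after call pattern. This is an illustration assuming the helper matches the log_softmax + gather shape of trlx/utils/modeling.py, not the exact diff:

```python
import torch
import torch.nn.functional as F

def logprobs_of_labels_before(logits, labels):
    # Callers pass logits[:, :-1, :] -- a non-contiguous slice -- so
    # log_softmax silently materializes an extra contiguous copy first.
    logprobs = F.log_softmax(logits, dim=-1)
    return torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

def logprobs_of_labels_after(logits, labels):
    # Callers may pass the full, contiguous logits; the shift that drops
    # trailing positions moves after log_softmax, onto the cheap gather side.
    logprobs = F.log_softmax(logits, dim=-1)
    logprobs = logprobs[:, :labels.shape[1], :]
    return torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
```

Written this way, both call patterns in the test script below return the same values, which is what the allclose assertion checks; the gather output is tiny (batch_size * seq_len elements), so slicing after log_softmax costs nothing extra.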

Test script:

```python
import torch
from torch.utils.benchmark import Timer
from trlx.utils.modeling import logprobs_of_labels

def perf():
    batch_size, seq_len, vocab_size = 32, 2048, 64000
    logits = torch.randn((batch_size, seq_len, vocab_size), dtype=torch.half, device='cuda')
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long, device='cuda')

    # correctness
    assert torch.allclose(logprobs_of_labels(logits[:, :-1, :], input_ids[:, 1:]), logprobs_of_labels(logits, input_ids[:, 1:]))

    # peak memory test
    torch.cuda.empty_cache()
    logprobs_of_labels(logits[:, :-1, :], input_ids[:, 1:])
    print(f'original allocated: {torch.cuda.memory_allocated() / 1e9:.3f} GB, reserved: {torch.cuda.memory_reserved() / 1e9:.3f} GB')

    torch.cuda.empty_cache()
    logprobs_of_labels(logits, input_ids[:, 1:])
    print(f'optimized allocated: {torch.cuda.memory_allocated() / 1e9:.3f} GB, reserved: {torch.cuda.memory_reserved() / 1e9:.3f} GB')

    # speed test
    timer = Timer(stmt="logprobs_of_labels(logits[:, :-1, :], input_ids[:, 1:])", globals={**globals(), **locals()})
    elapsed_org = timer.timeit(100).mean
    print(f'original costs: {elapsed_org:.4f} s')

    timer = Timer(stmt="logprobs_of_labels(logits, input_ids[:, 1:])", globals={**globals(), **locals()})
    elapsed_opt = timer.timeit(100).mean
    print(f'optimized costs: {elapsed_opt:.4f} s')

perf()
```

Tested on a Tesla V100, the method in this PR is both faster (1.6x speedup) and more memory-efficient:

```
original allocated: 8.389 GB, reserved: 25.164 GB
optimized allocated: 8.389 GB, reserved: 16.779 GB
original costs: 0.0700 s
optimized costs: 0.0435 s
```
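These figures line up with the tensor sizes: the fp16 logits alone are 32 * 2048 * 64000 * 2 B ≈ 8.39 GB, matching the allocated figure; the original path peaks near 3x that (input + hidden contiguous copy + log_softmax output), while the optimized path peaks near 2x (input + output). The underlying issue is that slicing the sequence dimension produces a non-contiguous view, which a quick standalone check shows (illustrative snippet, not part of the PR):

```python
import torch

x = torch.randn(2, 4, 8)
print(x[:, :-1, :].is_contiguous())  # False: slicing a middle dim leaves gaps in memory
print(x[:-1].is_contiguous())        # True: slicing the leading dim keeps a dense block
```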

@codecov-commenter commented Dec 2, 2023

Codecov Report

Attention: 6 lines in your changes are missing coverage. Please review.

Comparing base (91a0f43) 43.58% to head (730d900) 43.58%. The report is 1 commit behind head on main.

❗ The current head 730d900 differs from the pull request's most recent head aa1031a. Consider uploading reports for commit aa1031a to get more accurate results.

| File | Patch % | Missing lines |
| --- | --- | --- |
| trlx/models/modeling_nemo_ppo.py | 0.00% | 3 ⚠️ |
| trlx/trainer/accelerate_ppo_trainer.py | 57.14% | 3 ⚠️ |


Additional details and impacted files
```diff
@@           Coverage Diff           @@
##             main     #583   +/-   ##
=======================================
  Coverage   43.58%   43.58%
=======================================
  Files          33       33
  Lines        4974     4974
=======================================
  Hits         2168     2168
  Misses       2806     2806
```

