
Improve inference speed of multi-query attention model #3

Open

harm-devries opened this issue Sep 23, 2022 · 2 comments
Labels
help wanted Extra attention is needed

Comments

@harm-devries
Contributor

harm-devries commented Sep 23, 2022

The multi-query attention paper reports up to 10x speed-ups over incremental decoding with a multi-head attention model. We've implemented multi-query attention, but only observed up to 25% speed-ups when it is fully integrated into the Transformers model. We did observe up to 2x speed-ups for a simplified version of the attention layer (without softmax and layer normalization). See more details here.
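For context, the key property of multi-query attention during incremental decoding is that all query heads share a single key/value head, so the per-step KV cache (and the memory traffic it causes) shrinks by a factor of the number of heads. Below is a minimal sketch of one decoding step with illustrative shapes and names; it is not the implementation used in this repo.

```python
import torch

def multi_query_attention_step(q, k_cache, v_cache, k_new, v_new):
    """One incremental decoding step with multi-query attention (sketch).

    q:       (batch, num_heads, 1, head_dim)  -- one new query per head
    k_cache: (batch, past_len, head_dim)      -- single shared key head
    v_cache: (batch, past_len, head_dim)      -- single shared value head
    k_new:   (batch, 1, head_dim)
    v_new:   (batch, 1, head_dim)
    """
    # The KV cache has no head dimension: it is shared by all query heads.
    k = torch.cat((k_cache, k_new), dim=1)  # (batch, past_len + 1, head_dim)
    v = torch.cat((v_cache, v_new), dim=1)

    # Broadcast the single K/V head across all query heads.
    scores = torch.einsum("bhqd,bkd->bhqk", q, k) / q.shape[-1] ** 0.5
    probs = scores.softmax(dim=-1)
    out = torch.einsum("bhqk,bkd->bhqd", probs, v)  # (batch, num_heads, 1, head_dim)
    return out, k, v
```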

Further inference gains are likely possible but require further investigation. For example, we would like to benchmark the difference in a more optimized inference environment such as DeepSpeed-Inference. We are also happy to discuss other solutions and directions in the #wg-inference channel.

@harm-devries harm-devries changed the title Improve inference speed of multi-query attention Improve inference speed of multi-query attention model Sep 23, 2022
@pacman100

Hello @harm-devries,

One central point missing from the experiments above is that the batch_size is quite small at 8. An important assumption in the paper is that the batch_size needs to be large; otherwise the ratio of memory operations to arithmetic computation is large and becomes the bottleneck. To quote the paper: "Theoretically, given large batch size b, this should dramatically improve performance of incremental generation." As such, multi-query attention helps in large-batch inference settings. For example, in the paper they use a batch_size of 1024, so the (1 / batch_size) term would be ~10^-3, whereas in the experiments above it is orders of magnitude larger at ~10^-1.

It would be worthwhile to try experiments with larger batch sizes. A rough back-of-the-envelope illustration of the argument is sketched below.
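The following is a hedged, approximate calculation in the spirit of the paper's analysis (constants dropped); the values of b, n, d, and h are illustrative and not taken from the experiments above.

```python
# Back-of-the-envelope: ratio of memory traffic to arithmetic for a full
# incremental decode (approximate, constants dropped).
# b = batch size, n = generated length, d = model dim, h = number of heads.
def memory_to_compute_ratio(b, n, d, h):
    multi_head = n / d + 1 / b           # KV-cache reads + weight reads
    multi_query = n / (d * h) + 1 / b    # KV term shrunk by a factor of h
    return multi_head, multi_query

for b in (8, 1024):
    mh, mq = memory_to_compute_ratio(b=b, n=1024, d=2048, h=16)
    print(f"b={b:5d}: multi-head ~{mh:.3f}, multi-query ~{mq:.3f}")

# b=    8: multi-head ~0.625, multi-query ~0.156  -> the 1/b term (~0.125)
#          dominates the multi-query ratio, so little speed-up is visible
# b= 1024: multi-head ~0.501, multi-query ~0.032  -> the 1/b term (~10^-3)
#          is negligible and multi-query becomes compute-bound
```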

@pacman100

pacman100 commented Oct 21, 2022

To confirm my hypothesis, I quickly ran the experiments below on an A100 (80GB) GPU.
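(For reference, here is a minimal sketch of how such a generation benchmark can be timed. This is not the actual profile_hf_generate.py, and the checkpoint name below is only a placeholder.)

```python
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint for illustration only.
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda().eval()
tokenizer = AutoTokenizer.from_pretrained("gpt2")

def time_generate(batch_size, input_length, max_gen_length):
    # Random token ids as a synthetic input batch.
    input_ids = torch.randint(
        tokenizer.vocab_size, (batch_size, input_length), device="cuda"
    )
    torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        model.generate(
            input_ids,
            max_length=max_gen_length,
            do_sample=False,
            num_beams=1,
            pad_token_id=tokenizer.eos_token_id,
        )
    torch.cuda.synchronize()
    return time.time() - start

print(time_generate(batch_size=1024, input_length=16, max_gen_length=128))
```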

In the experiment below, with batch_size=1024 and seq_length=128, you can see MultiQuery1 being ~6.3x faster (≈38s vs ≈6s) than MultiHead.

python profile_hf_generate.py
/home/sourab/bigcode/transformers/src/transformers/__init__.py
NVIDIA A100-SXM4-80GB
-------------------- attention_type == AttentionType.MULTI_QUERY---------------------
{'get_test_batch': 2.193450927734375e-05, 'generate_text_batch': 10.881535291671753, 'input_batch_size': 1024, 'input_batch_length': 16, 'max_gen_length': 128, 'num_beams': 1, 'do_sample': False, 'pad_token_id': 50256, 'dtype': torch.int64, 'device': device(type='cuda'), 'cuda_device_name': 'NVIDIA A100-SXM4-80GB'}
-------------------- attention_type == AttentionType.MULTI_QUERY---------------------
{'get_test_batch': 2.193450927734375e-05, 'generate_text_batch': 10.306073904037476, 'input_batch_size': 1024, 'input_batch_length': 16, 'max_gen_length': 128, 'num_beams': 1, 'do_sample': False, 'pad_token_id': 50256, 'dtype': torch.int64, 'device': device(type='cuda'), 'cuda_device_name': 'NVIDIA A100-SXM4-80GB'}
-------------------- attention_type == AttentionType.MULTI_QUERY_1---------------------
{'get_test_batch': 2.09808349609375e-05, 'generate_text_batch': 6.453148603439331, 'input_batch_size': 1024, 'input_batch_length': 16, 'max_gen_length': 128, 'num_beams': 1, 'do_sample': False, 'pad_token_id': 50256, 'dtype': torch.int64, 'device': device(type='cuda'), 'cuda_device_name': 'NVIDIA A100-SXM4-80GB'}
-------------------- attention_type == AttentionType.MULTI_HEAD---------------------
{'get_test_batch': 2.288818359375e-05, 'generate_text_batch': 38.42392134666443, 'input_batch_size': 1024, 'input_batch_length': 16, 'max_gen_length': 128, 'num_beams': 1, 'do_sample': False, 'pad_token_id': 50256, 'dtype': torch.int64, 'device': device(type='cuda'), 'cuda_device_name': 'NVIDIA A100-SXM4-80GB'}

In the experiment below, with batch_size=1024 and seq_length=1024, MultiHead runs out of memory, while both multi-query variants handle the large batch and sequence length, with MultiQuery1 being ~4.3x faster (≈346s vs ≈81s) than MultiQuery.

python profile_hf_generate.py
/home/sourab/bigcode/transformers/src/transformers/__init__.py
NVIDIA A100-SXM4-80GB
Downloading vocab.json: 100%|██████████████████████████████████| 0.99M/0.99M [00:00<00:00, 1.73MB/s]
Downloading merges.txt: 100%|████████████████████████████████████| 446k/446k [00:00<00:00, 1.01MB/s]
Downloading config.json: 100%|██████████████████████████████████████| 665/665 [00:00<00:00, 587kB/s]
-------------------- attention_type == AttentionType.MULTI_QUERY---------------------
{'get_test_batch': 2.288818359375e-05, 'generate_text_batch': 363.0819971561432, 'input_batch_size': 1024, 'input_batch_length': 16, 'max_gen_length': 1024, 'num_beams': 1, 'do_sample': False, 'pad_token_id': 50256, 'dtype': torch.int64, 'device': device(type='cuda'), 'cuda_device_name': 'NVIDIA A100-SXM4-80GB'}
-------------------- attention_type == AttentionType.MULTI_QUERY---------------------
{'get_test_batch': 2.3603439331054688e-05, 'generate_text_batch': 346.3282334804535, 'input_batch_size': 1024, 'input_batch_length': 16, 'max_gen_length': 1024, 'num_beams': 1, 'do_sample': False, 'pad_token_id': 50256, 'dtype': torch.int64, 'device': device(type='cuda'), 'cuda_device_name': 'NVIDIA A100-SXM4-80GB'}
-------------------- attention_type == AttentionType.MULTI_QUERY_1---------------------
{'get_test_batch': 2.2172927856445312e-05, 'generate_text_batch': 80.81027579307556, 'input_batch_size': 1024, 'input_batch_length': 16, 'max_gen_length': 1024, 'num_beams': 1, 'do_sample': False, 'pad_token_id': 50256, 'dtype': torch.int64, 'device': device(type='cuda'), 'cuda_device_name': 'NVIDIA A100-SXM4-80GB'}
-------------------- attention_type == AttentionType.MULTI_HEAD---------------------
Traceback (most recent call last):
  File "/home/sourab/bigcode/bigcode-analysis/multi_query_experiments/profile_hf_generate.py", line 90, in <module>
    profile(AttentionType.MULTI_HEAD)
  File "/home/sourab/bigcode/bigcode-analysis/multi_query_experiments/profile_hf_generate.py", line 82, in profile
    inputs, outputs, stats = time_generate(tokenizer.vocab_size, model, 1024, 16, 1024, device=torch.device('cuda'))
  File "/home/sourab/bigcode/bigcode-analysis/multi_query_experiments/profile_hf_generate.py", line 44, in time_generate
    outputs = generate_text_batch(
  File "/home/sourab/bigcode/bigcode-analysis/multi_query_experiments/profile_hf_generate.py", line 22, in generate_text_batch
    return model.generate(
  File "/home/sourab/miniconda3/envs/temp/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/sourab/bigcode/transformers/src/transformers/generation_utils.py", line 1294, in generate
    return self.greedy_search(
  File "/home/sourab/bigcode/transformers/src/transformers/generation_utils.py", line 1689, in greedy_search
    outputs = self(
  File "/home/sourab/miniconda3/envs/temp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/sourab/bigcode/transformers/src/transformers/models/gpt2/modeling_gpt2.py", line 1219, in forward
    transformer_outputs = self.transformer(
  File "/home/sourab/miniconda3/envs/temp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/sourab/bigcode/transformers/src/transformers/models/gpt2/modeling_gpt2.py", line 1058, in forward
    outputs = block(
  File "/home/sourab/miniconda3/envs/temp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/sourab/bigcode/transformers/src/transformers/models/gpt2/modeling_gpt2.py", line 507, in forward
    attn_outputs = self.attn(
  File "/home/sourab/miniconda3/envs/temp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/sourab/bigcode/transformers/src/transformers/models/gpt2/modeling_gpt2.py", line 430, in forward
    value = torch.cat((past_value, value), dim=-2)
RuntimeError: CUDA out of memory.
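The OOM is expected: with multi-head attention the KV cache stores a separate key and value per head, so it is num_heads times larger than the multi-query cache. A rough estimate with purely illustrative model dimensions (24 layers, 16 heads, head_dim 64, fp16 cache), not the actual config used above:

```python
def kv_cache_bytes(batch, seq_len, n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    # 2x for keys and values, cached for every layer.
    return 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_elem

common = dict(batch=1024, seq_len=1024, n_layers=24, head_dim=64, bytes_per_elem=2)
multi_head = kv_cache_bytes(n_kv_heads=16, **common)  # one K/V per head
multi_query = kv_cache_bytes(n_kv_heads=1, **common)  # one shared K/V head

print(f"multi-head  KV cache: ~{multi_head / 2**30:.0f} GiB")
print(f"multi-query KV cache: ~{multi_query / 2**30:.0f} GiB")
# multi-head  KV cache: ~96 GiB  -> does not fit on an 80GB A100
# multi-query KV cache: ~6 GiB
```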

So multi-query attention doesn't help reduce latency in adaptive inference settings where we generate for a single prompt (batch_size=1, e.g. the model hub inference API) or for small batch sizes. I hope this helps.

Appendix: replicating the main experiments' results with batch_size=8 and seq_length=1024 for reference, in order to rule out the GPU as the cause of the behaviour above.

python profile_hf_generate.py
/home/sourab/bigcode/transformers/src/transformers/__init__.py
NVIDIA A100-SXM4-80GB
-------------------- attention_type == AttentionType.MULTI_QUERY---------------------
{'get_test_batch': 2.1696090698242188e-05, 'generate_text_batch': 18.797884225845337, 'input_batch_size': 8, 'input_batch_length': 16, 'max_gen_length': 1024, 'num_beams': 1, 'do_sample': False, 'pad_token_id': 50256, 'dtype': torch.int64, 'device': device(type='cuda'), 'cuda_device_name': 'NVIDIA A100-SXM4-80GB'}
-------------------- attention_type == AttentionType.MULTI_QUERY---------------------
{'get_test_batch': 2.193450927734375e-05, 'generate_text_batch': 18.270429134368896, 'input_batch_size': 8, 'input_batch_length': 16, 'max_gen_length': 1024, 'num_beams': 1, 'do_sample': False, 'pad_token_id': 50256, 'dtype': torch.int64, 'device': device(type='cuda'), 'cuda_device_name': 'NVIDIA A100-SXM4-80GB'}
-------------------- attention_type == AttentionType.MULTI_QUERY_1---------------------
{'get_test_batch': 2.288818359375e-05, 'generate_text_batch': 16.58125400543213, 'input_batch_size': 8, 'input_batch_length': 16, 'max_gen_length': 1024, 'num_beams': 1, 'do_sample': False, 'pad_token_id': 50256, 'dtype': torch.int64, 'device': device(type='cuda'), 'cuda_device_name': 'NVIDIA A100-SXM4-80GB'}
-------------------- attention_type == AttentionType.MULTI_HEAD---------------------
{'get_test_batch': 2.2172927856445312e-05, 'generate_text_batch': 19.13312315940857, 'input_batch_size': 8, 'input_batch_length': 16, 'max_gen_length': 1024, 'num_beams': 1, 'do_sample': False, 'pad_token_id': 50256, 'dtype': torch.int64, 'device': device(type='cuda'), 'cuda_device_name': 'NVIDIA A100-SXM4-80GB'}

@harm-devries harm-devries transferred this issue from bigcode-project/admin Oct 31, 2022