Add HSTU ragged attention operator (#2453)
Summary:
Add the HSTU ragged attention operator to TritonBench, covering both the standard and persistent Triton kernels.

On H100:
```
$ python run_benchmark.py triton --op ragged_attention

            x_val    hstu_triton_ragged_attention-latency    hstu_triton_ragged_attention_persistent-latency
-----------------  --------------------------------------  -------------------------------------------------
(8, 4, 512, 2048)                               0.0141706                                          0.0128713
(8, 4, 512, 2048)                               0.0187315                                          0.0171204
(8, 4, 512, 2048)                               0.0156807                                          0.0155399
(8, 4, 512, 2048)                               0.0165724                                          0.0154679
(8, 4, 512, 2048)                               0.0163886                                          0.0157738
(8, 4, 512, 2048)                               0.0173378                                          0.0155991
(8, 4, 512, 2048)                               0.0164874                                          0.0153128
(8, 4, 512, 2048)                               0.0203275                                          0.0172193
(8, 4, 512, 2048)                               0.0214526                                          0.0185414
(8, 4, 512, 2048)                               0.0172307                                          0.0169625
```
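For reference, each `x_val` above is the `(batch_size, num_heads, max_seq_len, num_buckets)` tuple reported by `get_x_val` in `operator.py`, so all ten rows share the default shape. The defaults can be overridden through the flags defined in `parse_op_args`; a hedged example, assuming the usual TritonBench convention of forwarding unrecognized CLI arguments to the operator:

```
$ python run_benchmark.py triton --op ragged_attention --batch-size 4 --heads 8 --max-seq-len-log2 10 --num-buckets 2048
```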


Differential Revision: D62513596

Pulled By: xuzhao9
xuzhao9 authored and facebook-github-bot committed Sep 12, 2024
1 parent 7acad50 commit 0ae15fa
Showing 3 changed files with 244 additions and 0 deletions.
1 change: 1 addition & 0 deletions torchbenchmark/operators/ragged_attention/__init__.py
@@ -0,0 +1 @@
from .operator import Operator
181 changes: 181 additions & 0 deletions torchbenchmark/operators/ragged_attention/hstu.py
@@ -0,0 +1,181 @@
import torch
import triton
from torchbenchmark import add_path, SUBMODULE_PATH

try:
    # Internal Import
    from hammer.generative_recommenders.ops.triton.triton_ragged_hstu_attention import (
        _ragged_hstu_attn_fwd,
        _ragged_hstu_attn_fwd_persistent,
    )
except ModuleNotFoundError:
    # OSS Import
    import importlib

    with add_path(str(SUBMODULE_PATH)):
        triton_ragged_hstu_attention = importlib.import_module(
            "generative-recommenders.ops.triton.triton_ragged_hstu_attention"
        )
        _ragged_hstu_attn_fwd = triton_ragged_hstu_attention._ragged_hstu_attn_fwd
        _ragged_hstu_attn_fwd_persistent = (
            triton_ragged_hstu_attention._ragged_hstu_attn_fwd_persistent
        )

from typing import Tuple


class RaggedHSTUAttn(torch.nn.Module):
    def __init__(
        self,
        batch_size,
        num_heads,
        max_seq_len,
        num_buckets,
        persistent_kernel: bool = False,
    ) -> None:
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.max_seq_len = max_seq_len
        self.num_buckets = num_buckets
        super().__init__()
        self.all_ts_weights = torch.nn.Parameter(
            torch.randn(
                (self.num_buckets + 1,),
                dtype=torch.bfloat16,
            ).cuda()
        )
        self.all_pos_weights = torch.nn.Parameter(
            torch.randn(
                (2 * self.max_seq_len - 1,),
                dtype=torch.bfloat16,
            ).cuda()
        )
        self.persistent_kernel = persistent_kernel

    def forward(
        self, qkv: torch.Tensor, seq_offsets: torch.Tensor, timestamps: torch.Tensor
    ) -> torch.Tensor:
        torch._check(timestamps.size(0) + 1 == seq_offsets.size(0))

        # Slice the fused qkv projection into 128-dim query, key, and value heads.
        q = qkv[:, :, :128]
        k = qkv[:, :, 128:256]
        v = qkv[:, :, 256:384]
        out = torch.zeros_like(v)

        Z = timestamps.size(0)
        N = timestamps.size(1) - 1
        _, H, DimQ = q.shape
        _, _, DimV = v.shape

        kwargs = {
            "Q": q,
            "K": k,
            "V": v,
            "seq_offsets": seq_offsets,
            "delta_x_offsets": None,
            "TS": timestamps,
            "TW": self.all_ts_weights,
            "PW": self.all_pos_weights,
            "Bias": None,
            "seq2_offsets": None,
            "num_targets": None,
            "Scale": None,
            "Out": out,
            "stride_qm": q.stride(0),
            "stride_qh": q.stride(1),
            "stride_kn": k.stride(0),
            "stride_kh": k.stride(1),
            "stride_vn": v.stride(0),
            "stride_vh": v.stride(1),
            "stride_sz": None,
            "stride_sm": None,
            "stride_ts": timestamps.stride(0),
            "stride_om": out.stride(0),
            "stride_oh": out.stride(1),
            "alpha": 0.08838834764831843,  # == 1 / sqrt(128) up to float rounding
            "Z": Z,
            "H": H,
            "MAX_SEQ_LEN": N,
            "DimQ": DimQ,
            "DimV": DimV,
            "DeltaSize": None,
            "num_buckets": self.num_buckets,
            "max_pos_ind": None,
            "time_bucket_incr": 60.0,
            "time_bucket_div": 1.0,
            "time_delta": 0.0,
            "INVALID_MASK_TYPE": "lower_triangular",
            "CAUSAL": True,
            "BUCKET_FN": "sqrt",
            "ATTN_BIAS_TYPE": "fused",
            "USE_TIME_BIAS": False,
            "USE_POS_BIAS": False,
            "HAS_MAX_POS_IND": False,
            "HAS_MULTIPLE_TARGETS": False,
            "HAS_ATTN_SCALE": False,
            "IS_DELTA_Q": False,
            "ALLOW_TF32": True,
            "BLOCK_D_Q": DimQ,
            "BLOCK_D_V": DimV,
            "max_attn_len": 0,
            "HAS_MAX_ATTN_LEN": False,
        }
        if self.persistent_kernel:
            grid = (1216,)
            # pyre-fixme[16]: Module `triton_ragged_hstu_attention` has no attribute
            _ragged_hstu_attn_fwd_persistent[grid](**kwargs)
        else:
            grid = lambda meta: (  # noqa E731
                triton.cdiv(N, meta["BLOCK_M"]),
                Z * H,
            )
            # pyre-fixme[16]: Module `triton_ragged_hstu_attention` has no attribute
            _ragged_hstu_attn_fwd[grid](**kwargs)

        return out


def get_test_inputs(
    batch_size, num_heads, max_seq_len
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    timestamp_deltas: torch.Tensor = (
        torch.randint(
            86400,
            size=(batch_size, max_seq_len + 1),
        )
        .requires_grad_(False)
        .cuda()
    )
    timestamps = timestamp_deltas.cumsum(dim=1)

    lengths = (
        torch.randint(
            max_seq_len + 1,
            size=(batch_size,),
        )
        .requires_grad_(False)
        .cuda()
    )
    seq_offsets = (
        torch.zeros(
            (batch_size + 1,),
            dtype=torch.int64,
        )
        .requires_grad_(False)
        .cuda()
    )
    seq_offsets[1:] = torch.cumsum(
        lengths,
        dim=0,
    )
    L = int(seq_offsets[-1].item())

    qkv = (
        torch.randn(
            (L, num_heads, 512),
            dtype=torch.bfloat16,
        )
        .requires_grad_(False)
        .cuda()
    )
    return qkv, seq_offsets, timestamps
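
To exercise the new module outside the benchmark harness, here is a minimal sketch that uses only the helpers defined in `hstu.py` above. It assumes a CUDA device with Triton and bfloat16 support and that the generative-recommenders submodule (or the internal `hammer` package) is importable; the shape values are illustrative, not prescribed by the commit.

```
# Minimal smoke test for RaggedHSTUAttn using the helpers added in hstu.py.
# Assumes a CUDA GPU with Triton available; shapes are illustrative only.
from torchbenchmark.operators.ragged_attention.hstu import (
    get_test_inputs,
    RaggedHSTUAttn,
)

batch_size, num_heads, max_seq_len, num_buckets = 8, 4, 512, 2048

qkv, seq_offsets, timestamps = get_test_inputs(batch_size, num_heads, max_seq_len)
attn = RaggedHSTUAttn(
    batch_size,
    num_heads,
    max_seq_len,
    num_buckets,
    persistent_kernel=False,  # set True to launch the persistent kernel variant
)
out = attn(qkv, seq_offsets, timestamps)
# The output is jagged like the value slice of qkv: (total_seq_len, num_heads, 128).
print(out.shape, out.dtype)
```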
62 changes: 62 additions & 0 deletions torchbenchmark/operators/ragged_attention/operator.py
@@ -0,0 +1,62 @@
import argparse

from typing import List, Optional

from torchbenchmark.util.triton_op import BenchmarkOperator, register_benchmark

from .hstu import get_test_inputs, RaggedHSTUAttn


def parse_op_args(args: List[str]):
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=8, help="Batch size")
    parser.add_argument("--heads", type=int, default=4, help="Number of heads")
    parser.add_argument("--max-seq-len-log2", type=int, default=9)
    parser.add_argument("--num-buckets", type=int, default=2048)
    return parser.parse_args(args)


class Operator(BenchmarkOperator):
    DEFAULT_PRECISION = "bf16"

    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
        super().__init__(tb_args, extra_args=extra_args)
        args = parse_op_args(self.extra_args)
        self.batch_size = args.batch_size
        self.num_heads = args.heads
        self.max_seq_len = 2**args.max_seq_len_log2
        self.num_buckets = args.num_buckets
        # set a default number of inputs
        self._num_inputs = 10

    @register_benchmark()
    def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps):
        attn = RaggedHSTUAttn(
            self.batch_size,
            self.num_heads,
            self.max_seq_len,
            self.num_buckets,
            persistent_kernel=False,
        )
        return lambda: attn(qkv, seq_offsets, timestamps)

    @register_benchmark()
    def hstu_triton_ragged_attention_persistent(self, qkv, seq_offsets, timestamps):
        attn = RaggedHSTUAttn(
            self.batch_size,
            self.num_heads,
            self.max_seq_len,
            self.num_buckets,
            persistent_kernel=True,
        )
        return lambda: attn(qkv, seq_offsets, timestamps)

    def get_x_val(self, example_inputs):
        return (self.batch_size, self.num_heads, self.max_seq_len, self.num_buckets)

    def get_input_iter(self):
        for _input_id in range(self._num_inputs):
            inputs = get_test_inputs(self.batch_size, self.num_heads, self.max_seq_len)
            yield inputs
