Skip to content

Commit

Permalink
Fix python 3.8 and 3.9 releases (#129)
Browse files (browse the repository at this point in the history)
Fixes #128.
  • Loading branch information
alihassanijr authored May 19, 2024
1 parent 8422e23 commit 15db931
Show file tree
Hide file tree
Showing 12 changed files with 167 additions and 146 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## [Main branch]

## [0.17.1] - 2024-05-19
* Fixed interface for Python 3.8 and 3.9

## [0.17.0] - 2024-05-02
* [Fused neighborhood attention](https://github.com/SHI-Labs/NATTEN/tree/main/docs/fna) (FNA) kernels
* 1D, 2D and 3D Neighborhood Attention are supported,
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ train models based on neighborhood attention faster and more efficiently.
FNA can be seen as a generalization of methods such as [Flash Attention](https://github.com/Dao-AILab/flash-attention/) and
[FMHA](https://github.com/facebookresearch/xformers/) from back-to-back matrix multiplication to
back-to-back tensor-tensor contraction, and comes with neighborhood attention masking built in.
This accelerates accelerates neighborhood attention, a multi-dimensional sliding window attention pattern,
This accelerates neighborhood attention, a multi-dimensional sliding window attention pattern,
by never storing the attention tensor to global memory, which aside from reducing global memory footprint also reduces
the memory bandwidth bottleneck.

Expand Down
2 changes: 1 addition & 1 deletion src/natten/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,4 @@
"disable_tiled_na",
]

__version__ = "0.17.0"
__version__ = "0.17.1"
120 changes: 60 additions & 60 deletions src/natten/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@
qk_cross_forward,
)
from .types import (
CausalArg1DType,
CausalArg2DType,
CausalArg3DType,
Dimension1DType,
Dimension2DType,
Dimension3DType,
CausalArg1DTypeOrDed,
CausalArg2DTypeOrDed,
CausalArg3DTypeOrDed,
Dimension1DTypeOrDed,
Dimension2DTypeOrDed,
Dimension3DTypeOrDed,
FnaBackwardConfigType,
FnaForwardConfigType,
NoneType,
Expand Down Expand Up @@ -86,9 +86,9 @@ def forward(
key: Tensor,
bias: Optional[Tensor],
additional_key: Optional[Tensor],
kernel_size_: int | Dimension1DType,
dilation_: int | Dimension1DType,
is_causal_: bool | CausalArg1DType,
kernel_size_: Dimension1DTypeOrDed,
dilation_: Dimension1DTypeOrDed,
is_causal_: CausalArg1DTypeOrDed,
) -> Tensor:
kernel_size, dilation, is_causal = check_all_args(
1, kernel_size_, dilation_, is_causal_
Expand Down Expand Up @@ -263,9 +263,9 @@ def forward(
attn: Tensor,
value: Tensor,
additional_value: Optional[Tensor],
kernel_size_: int | Dimension1DType,
dilation_: int | Dimension1DType,
is_causal_: bool | CausalArg1DType,
kernel_size_: Dimension1DTypeOrDed,
dilation_: Dimension1DTypeOrDed,
is_causal_: CausalArg1DTypeOrDed,
):
kernel_size, dilation, is_causal = check_all_args(
1, kernel_size_, dilation_, is_causal_
Expand Down Expand Up @@ -429,9 +429,9 @@ def forward(
key: Tensor,
bias: Optional[Tensor],
additional_key: Optional[Tensor],
kernel_size_: int | Dimension2DType,
dilation_: int | Dimension2DType,
is_causal_: bool | CausalArg2DType,
kernel_size_: Dimension2DTypeOrDed,
dilation_: Dimension2DTypeOrDed,
is_causal_: CausalArg2DTypeOrDed,
):
kernel_size, dilation, is_causal = check_all_args(
2, kernel_size_, dilation_, is_causal_
Expand Down Expand Up @@ -605,9 +605,9 @@ def forward(
attn: Tensor,
value: Tensor,
additional_value: Optional[Tensor],
kernel_size_: int | Dimension2DType,
dilation_: int | Dimension2DType,
is_causal_: bool | CausalArg2DType,
kernel_size_: Dimension2DTypeOrDed,
dilation_: Dimension2DTypeOrDed,
is_causal_: CausalArg2DTypeOrDed,
) -> Tensor:
kernel_size, dilation, is_causal = check_all_args(
2, kernel_size_, dilation_, is_causal_
Expand Down Expand Up @@ -771,9 +771,9 @@ def forward(
key: Tensor,
bias: Optional[Tensor],
additional_key: Optional[Tensor],
kernel_size_: int | Dimension3DType,
dilation_: int | Dimension3DType,
is_causal_: bool | CausalArg3DType,
kernel_size_: Dimension3DTypeOrDed,
dilation_: Dimension3DTypeOrDed,
is_causal_: CausalArg3DTypeOrDed,
) -> Tensor:
kernel_size, dilation, is_causal = check_all_args(
3, kernel_size_, dilation_, is_causal_
Expand Down Expand Up @@ -957,9 +957,9 @@ def forward(
attn: Tensor,
value: Tensor,
additional_value: Optional[Tensor],
kernel_size_: int | Dimension3DType,
dilation_: int | Dimension3DType,
is_causal_: bool | CausalArg3DType,
kernel_size_: Dimension3DTypeOrDed,
dilation_: Dimension3DTypeOrDed,
is_causal_: CausalArg3DTypeOrDed,
) -> Tensor:
kernel_size, dilation, is_causal = check_all_args(
3, kernel_size_, dilation_, is_causal_
Expand Down Expand Up @@ -1131,9 +1131,9 @@ def forward(
key: Tensor,
value: Tensor,
bias: Optional[Tensor],
kernel_size_: int | Dimension1DType,
dilation_: int | Dimension1DType,
is_causal_: bool | CausalArg1DType,
kernel_size_: Dimension1DTypeOrDed,
dilation_: Dimension1DTypeOrDed,
is_causal_: CausalArg1DTypeOrDed,
scale: float,
tiling_config_: FnaForwardConfigType,
tiling_config_backward_: FnaBackwardConfigType,
Expand Down Expand Up @@ -1268,9 +1268,9 @@ def forward(
key: Tensor,
value: Tensor,
bias: Optional[Tensor],
kernel_size_: int | Dimension2DType,
dilation_: int | Dimension2DType,
is_causal_: bool | CausalArg2DType,
kernel_size_: Dimension2DTypeOrDed,
dilation_: Dimension2DTypeOrDed,
is_causal_: CausalArg2DTypeOrDed,
scale: float,
tiling_config_: FnaForwardConfigType,
tiling_config_backward_: FnaBackwardConfigType,
Expand Down Expand Up @@ -1405,9 +1405,9 @@ def forward(
key: Tensor,
value: Tensor,
bias: Optional[Tensor],
kernel_size_: int | Dimension3DType,
dilation_: int | Dimension3DType,
is_causal_: bool | CausalArg3DType,
kernel_size_: Dimension3DTypeOrDed,
dilation_: Dimension3DTypeOrDed,
is_causal_: CausalArg3DTypeOrDed,
scale: float,
tiling_config_: FnaForwardConfigType,
tiling_config_backward_: FnaBackwardConfigType,
Expand Down Expand Up @@ -1536,10 +1536,10 @@ def backward(ctx, grad_out: Tensor) -> Tuple[
def na1d_qk(
query: Tensor,
key: Tensor,
kernel_size: int | Dimension1DType,
dilation: int | Dimension1DType = 1,
kernel_size: Dimension1DTypeOrDed,
dilation: Dimension1DTypeOrDed = 1,
additional_keys: Optional[Tensor] = None,
is_causal: Optional[bool | CausalArg1DType] = False,
is_causal: Optional[CausalArg1DTypeOrDed] = False,
rpb: Optional[Tensor] = None,
) -> Tensor:
if query.is_nested or key.is_nested:
Expand All @@ -1560,10 +1560,10 @@ def na1d_qk(
def na1d_av(
attn: Tensor,
value: Tensor,
kernel_size: int | Dimension1DType,
dilation: int | Dimension1DType = 1,
kernel_size: Dimension1DTypeOrDed,
dilation: Dimension1DTypeOrDed = 1,
additional_values: Optional[Tensor] = None,
is_causal: Optional[bool | CausalArg1DType] = False,
is_causal: Optional[CausalArg1DTypeOrDed] = False,
) -> Tensor:
if attn.is_nested or value.is_nested:
return na1d_av_nested(
Expand All @@ -1582,10 +1582,10 @@ def na1d_av(
def na2d_qk(
query: Tensor,
key: Tensor,
kernel_size: int | Dimension2DType,
dilation: int | Dimension2DType = 1,
kernel_size: Dimension2DTypeOrDed,
dilation: Dimension2DTypeOrDed = 1,
additional_keys: Optional[Tensor] = None,
is_causal: Optional[bool | CausalArg2DType] = False,
is_causal: Optional[CausalArg2DTypeOrDed] = False,
rpb: Optional[Tensor] = None,
) -> Tensor:
if query.is_nested or key.is_nested:
Expand All @@ -1606,10 +1606,10 @@ def na2d_qk(
def na2d_av(
attn: Tensor,
value: Tensor,
kernel_size: int | Dimension2DType,
dilation: int | Dimension2DType = 1,
kernel_size: Dimension2DTypeOrDed,
dilation: Dimension2DTypeOrDed = 1,
additional_values: Optional[Tensor] = None,
is_causal: Optional[bool | CausalArg2DType] = False,
is_causal: Optional[CausalArg2DTypeOrDed] = False,
) -> Tensor:
if attn.is_nested or value.is_nested:
return na2d_av_nested(
Expand All @@ -1628,10 +1628,10 @@ def na2d_av(
def na3d_qk(
query: Tensor,
key: Tensor,
kernel_size: int | Dimension3DType,
dilation: int | Dimension3DType = 1,
kernel_size: Dimension3DTypeOrDed,
dilation: Dimension3DTypeOrDed = 1,
additional_keys: Optional[Tensor] = None,
is_causal: Optional[bool | CausalArg3DType] = False,
is_causal: Optional[CausalArg3DTypeOrDed] = False,
rpb: Optional[Tensor] = None,
) -> Tensor:
if query.is_nested or key.is_nested:
Expand All @@ -1658,10 +1658,10 @@ def na3d_qk(
def na3d_av(
attn: Tensor,
value: Tensor,
kernel_size: int | Dimension3DType,
dilation: int | Dimension3DType,
kernel_size: Dimension3DTypeOrDed,
dilation: Dimension3DTypeOrDed,
additional_values: Optional[Tensor] = None,
is_causal: Optional[bool | CausalArg3DType] = False,
is_causal: Optional[CausalArg3DTypeOrDed] = False,
) -> Tensor:
if attn.is_nested or value.is_nested:
return na3d_av_nested(
Expand All @@ -1686,9 +1686,9 @@ def na1d(
query: Tensor,
key: Tensor,
value: Tensor,
kernel_size: int | Dimension1DType,
dilation: int | Dimension1DType = 1,
is_causal: Optional[bool | CausalArg1DType] = False,
kernel_size: Dimension1DTypeOrDed,
dilation: Dimension1DTypeOrDed = 1,
is_causal: Optional[CausalArg1DTypeOrDed] = False,
rpb: Optional[Tensor] = None,
scale: Optional[float] = None,
) -> Tensor:
Expand Down Expand Up @@ -1720,9 +1720,9 @@ def na2d(
query: Tensor,
key: Tensor,
value: Tensor,
kernel_size: int | Dimension2DType,
dilation: int | Dimension2DType = 1,
is_causal: Optional[bool | CausalArg2DType] = False,
kernel_size: Dimension2DTypeOrDed,
dilation: Dimension2DTypeOrDed = 1,
is_causal: Optional[CausalArg2DTypeOrDed] = False,
rpb: Optional[Tensor] = None,
scale: Optional[float] = None,
) -> Tensor:
Expand Down Expand Up @@ -1754,9 +1754,9 @@ def na3d(
query: Tensor,
key: Tensor,
value: Tensor,
kernel_size: int | Dimension3DType,
dilation: int | Dimension3DType = 1,
is_causal: Optional[bool | CausalArg3DType] = False,
kernel_size: Dimension3DTypeOrDed,
dilation: Dimension3DTypeOrDed = 1,
is_causal: Optional[CausalArg3DTypeOrDed] = False,
rpb: Optional[Tensor] = None,
scale: Optional[float] = None,
) -> Tensor:
Expand Down
8 changes: 4 additions & 4 deletions src/natten/na1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from .context import is_fna_enabled
from .functional import na1d, na1d_av, na1d_qk
from .types import CausalArg1DType, Dimension1DType
from .types import CausalArg1DTypeOrDed, Dimension1DTypeOrDed
from .utils import check_all_args, log

logger = log.get_logger(__name__)
Expand All @@ -43,9 +43,9 @@ def __init__(
self,
dim: int,
num_heads: int,
kernel_size: int | Dimension1DType,
dilation: int | Dimension1DType = 1,
is_causal: bool | CausalArg1DType = False,
kernel_size: Dimension1DTypeOrDed,
dilation: Dimension1DTypeOrDed = 1,
is_causal: CausalArg1DTypeOrDed = False,
rel_pos_bias: bool = False,
qkv_bias: bool = True,
qk_scale: Optional[float] = None,
Expand Down
8 changes: 4 additions & 4 deletions src/natten/na2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from .context import is_fna_enabled
from .functional import na2d, na2d_av, na2d_qk
from .types import CausalArg2DType, Dimension2DType
from .types import CausalArg2DTypeOrDed, Dimension2DTypeOrDed
from .utils import check_all_args, log

logger = log.get_logger(__name__)
Expand All @@ -43,9 +43,9 @@ def __init__(
self,
dim: int,
num_heads: int,
kernel_size: int | Dimension2DType,
dilation: int | Dimension2DType = 1,
is_causal: bool | CausalArg2DType = False,
kernel_size: Dimension2DTypeOrDed,
dilation: Dimension2DTypeOrDed = 1,
is_causal: CausalArg2DTypeOrDed = False,
rel_pos_bias: bool = False,
qkv_bias: bool = True,
qk_scale: Optional[float] = None,
Expand Down
8 changes: 4 additions & 4 deletions src/natten/na3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from .context import is_fna_enabled
from .functional import na3d, na3d_av, na3d_qk
from .types import CausalArg3DType, Dimension3DType
from .types import CausalArg3DTypeOrDed, Dimension3DTypeOrDed
from .utils import check_all_args, log

logger = log.get_logger(__name__)
Expand All @@ -43,9 +43,9 @@ def __init__(
self,
dim: int,
num_heads: int,
kernel_size: int | Dimension3DType,
dilation: int | Dimension3DType = 1,
is_causal: bool | CausalArg3DType = False,
kernel_size: Dimension3DTypeOrDed,
dilation: Dimension3DTypeOrDed = 1,
is_causal: CausalArg3DTypeOrDed = False,
rel_pos_bias: bool = False,
qkv_bias: bool = True,
qk_scale: Optional[float] = None,
Expand Down
Loading

0 comments on commit 15db931

Please sign in to comment.