diff --git a/CHANGELOG.md b/CHANGELOG.md index 8006c1e..516a35b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ ## [Main branch] +## [0.17.1] - 2024-05-19 +* Fixed interface for python 3.8 and 3.9 + ## [0.17.0] - 2024-05-02 * [Fused neighborhood attention](https://github.com/SHI-Labs/NATTEN/tree/main/docs/fna) (FNA) kernels * 1D, 2D and 3D Neighborhood Attention are supported, diff --git a/README.md b/README.md index d90744c..60d5391 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ train models based on neighborhood attention faster and more efficiently. FNA can be seen as a generalization of methods such as [Flash Attention](https://github.com/Dao-AILab/flash-attention/) and [FMHA](https://github.com/facebookresearch/xformers/) from back-to-back matrix multiplication to back-to-back tensor-tensor contraction, and comes with neighborhood attention masking built in. -This accelerates accelerates neighborhood attention, a multi-dimensional sliding window attention pattern, +This accelerates neighborhood attention, a multi-dimensional sliding window attention pattern, by never storing the attention tensor to global memory, which aside from reducing global memory footprint also reduces the memory bandwidth bottleneck. diff --git a/src/natten/__init__.py b/src/natten/__init__.py index 8fb2e47..241430a 100644 --- a/src/natten/__init__.py +++ b/src/natten/__init__.py @@ -116,4 +116,4 @@ "disable_tiled_na", ] -__version__ = "0.17.0" +__version__ = "0.17.1" diff --git a/src/natten/functional.py b/src/natten/functional.py index 3af05ff..88d64e4 100644 --- a/src/natten/functional.py +++ b/src/natten/functional.py @@ -53,12 +53,12 @@ qk_cross_forward, ) from .types import ( - CausalArg1DType, - CausalArg2DType, - CausalArg3DType, - Dimension1DType, - Dimension2DType, - Dimension3DType, + CausalArg1DTypeOrDed, + CausalArg2DTypeOrDed, + CausalArg3DTypeOrDed, + Dimension1DTypeOrDed, + Dimension2DTypeOrDed, + Dimension3DTypeOrDed, FnaBackwardConfigType, FnaForwardConfigType, NoneType, @@ -86,9 +86,9 @@ def forward( key: Tensor, bias: Optional[Tensor], additional_key: Optional[Tensor], - kernel_size_: int | Dimension1DType, - dilation_: int | Dimension1DType, - is_causal_: bool | CausalArg1DType, + kernel_size_: Dimension1DTypeOrDed, + dilation_: Dimension1DTypeOrDed, + is_causal_: CausalArg1DTypeOrDed, ) -> Tensor: kernel_size, dilation, is_causal = check_all_args( 1, kernel_size_, dilation_, is_causal_ @@ -263,9 +263,9 @@ def forward( attn: Tensor, value: Tensor, additional_value: Optional[Tensor], - kernel_size_: int | Dimension1DType, - dilation_: int | Dimension1DType, - is_causal_: bool | CausalArg1DType, + kernel_size_: Dimension1DTypeOrDed, + dilation_: Dimension1DTypeOrDed, + is_causal_: CausalArg1DTypeOrDed, ): kernel_size, dilation, is_causal = check_all_args( 1, kernel_size_, dilation_, is_causal_ @@ -429,9 +429,9 @@ def forward( key: Tensor, bias: Optional[Tensor], additional_key: Optional[Tensor], - kernel_size_: int | Dimension2DType, - dilation_: int | Dimension2DType, - is_causal_: bool | CausalArg2DType, + kernel_size_: Dimension2DTypeOrDed, + dilation_: Dimension2DTypeOrDed, + is_causal_: CausalArg2DTypeOrDed, ): kernel_size, dilation, is_causal = check_all_args( 2, kernel_size_, dilation_, is_causal_ @@ -605,9 +605,9 @@ def forward( attn: Tensor, value: Tensor, additional_value: Optional[Tensor], - kernel_size_: int | Dimension2DType, - dilation_: int | Dimension2DType, - is_causal_: bool | CausalArg2DType, + kernel_size_: Dimension2DTypeOrDed, + dilation_: Dimension2DTypeOrDed, + is_causal_: CausalArg2DTypeOrDed, ) -> Tensor: kernel_size, dilation, is_causal = check_all_args( 2, kernel_size_, dilation_, is_causal_ @@ -771,9 +771,9 @@ def forward( key: Tensor, bias: Optional[Tensor], additional_key: Optional[Tensor], - kernel_size_: int | Dimension3DType, - dilation_: int | Dimension3DType, - is_causal_: bool | CausalArg3DType, + kernel_size_: Dimension3DTypeOrDed, + dilation_: Dimension3DTypeOrDed, + is_causal_: CausalArg3DTypeOrDed, ) -> Tensor: kernel_size, dilation, is_causal = check_all_args( 3, kernel_size_, dilation_, is_causal_ @@ -957,9 +957,9 @@ def forward( attn: Tensor, value: Tensor, additional_value: Optional[Tensor], - kernel_size_: int | Dimension3DType, - dilation_: int | Dimension3DType, - is_causal_: bool | CausalArg3DType, + kernel_size_: Dimension3DTypeOrDed, + dilation_: Dimension3DTypeOrDed, + is_causal_: CausalArg3DTypeOrDed, ) -> Tensor: kernel_size, dilation, is_causal = check_all_args( 3, kernel_size_, dilation_, is_causal_ @@ -1131,9 +1131,9 @@ def forward( key: Tensor, value: Tensor, bias: Optional[Tensor], - kernel_size_: int | Dimension1DType, - dilation_: int | Dimension1DType, - is_causal_: bool | CausalArg1DType, + kernel_size_: Dimension1DTypeOrDed, + dilation_: Dimension1DTypeOrDed, + is_causal_: CausalArg1DTypeOrDed, scale: float, tiling_config_: FnaForwardConfigType, tiling_config_backward_: FnaBackwardConfigType, @@ -1268,9 +1268,9 @@ def forward( key: Tensor, value: Tensor, bias: Optional[Tensor], - kernel_size_: int | Dimension2DType, - dilation_: int | Dimension2DType, - is_causal_: bool | CausalArg2DType, + kernel_size_: Dimension2DTypeOrDed, + dilation_: Dimension2DTypeOrDed, + is_causal_: CausalArg2DTypeOrDed, scale: float, tiling_config_: FnaForwardConfigType, tiling_config_backward_: FnaBackwardConfigType, @@ -1405,9 +1405,9 @@ def forward( key: Tensor, value: Tensor, bias: Optional[Tensor], - kernel_size_: int | Dimension3DType, - dilation_: int | Dimension3DType, - is_causal_: bool | CausalArg3DType, + kernel_size_: Dimension3DTypeOrDed, + dilation_: Dimension3DTypeOrDed, + is_causal_: CausalArg3DTypeOrDed, scale: float, tiling_config_: FnaForwardConfigType, tiling_config_backward_: FnaBackwardConfigType, @@ -1536,10 +1536,10 @@ def backward(ctx, grad_out: Tensor) -> Tuple[ def na1d_qk( query: Tensor, key: Tensor, - kernel_size: int | Dimension1DType, - dilation: int | Dimension1DType = 1, + kernel_size: Dimension1DTypeOrDed, + dilation: Dimension1DTypeOrDed = 1, additional_keys: Optional[Tensor] = None, - is_causal: Optional[bool | CausalArg1DType] = False, + is_causal: Optional[CausalArg1DTypeOrDed] = False, rpb: Optional[Tensor] = None, ) -> Tensor: if query.is_nested or key.is_nested: @@ -1560,10 +1560,10 @@ def na1d_qk( def na1d_av( attn: Tensor, value: Tensor, - kernel_size: int | Dimension1DType, - dilation: int | Dimension1DType = 1, + kernel_size: Dimension1DTypeOrDed, + dilation: Dimension1DTypeOrDed = 1, additional_values: Optional[Tensor] = None, - is_causal: Optional[bool | CausalArg1DType] = False, + is_causal: Optional[CausalArg1DTypeOrDed] = False, ) -> Tensor: if attn.is_nested or value.is_nested: return na1d_av_nested( @@ -1582,10 +1582,10 @@ def na1d_av( def na2d_qk( query: Tensor, key: Tensor, - kernel_size: int | Dimension2DType, - dilation: int | Dimension2DType = 1, + kernel_size: Dimension2DTypeOrDed, + dilation: Dimension2DTypeOrDed = 1, additional_keys: Optional[Tensor] = None, - is_causal: Optional[bool | CausalArg2DType] = False, + is_causal: Optional[CausalArg2DTypeOrDed] = False, rpb: Optional[Tensor] = None, ) -> Tensor: if query.is_nested or key.is_nested: @@ -1606,10 +1606,10 @@ def na2d_qk( def na2d_av( attn: Tensor, value: Tensor, - kernel_size: int | Dimension2DType, - dilation: int | Dimension2DType = 1, + kernel_size: Dimension2DTypeOrDed, + dilation: Dimension2DTypeOrDed = 1, additional_values: Optional[Tensor] = None, - is_causal: Optional[bool | CausalArg2DType] = False, + is_causal: Optional[CausalArg2DTypeOrDed] = False, ) -> Tensor: if attn.is_nested or value.is_nested: return na2d_av_nested( @@ -1628,10 +1628,10 @@ def na2d_av( def na3d_qk( query: Tensor, key: Tensor, - kernel_size: int | Dimension3DType, - dilation: int | Dimension3DType = 1, + kernel_size: Dimension3DTypeOrDed, + dilation: Dimension3DTypeOrDed = 1, additional_keys: Optional[Tensor] = None, - is_causal: Optional[bool | CausalArg3DType] = False, + is_causal: Optional[CausalArg3DTypeOrDed] = False, rpb: Optional[Tensor] = None, ) -> Tensor: if query.is_nested or key.is_nested: @@ -1658,10 +1658,10 @@ def na3d_qk( def na3d_av( attn: Tensor, value: Tensor, - kernel_size: int | Dimension3DType, - dilation: int | Dimension3DType, + kernel_size: Dimension3DTypeOrDed, + dilation: Dimension3DTypeOrDed, additional_values: Optional[Tensor] = None, - is_causal: Optional[bool | CausalArg3DType] = False, + is_causal: Optional[CausalArg3DTypeOrDed] = False, ) -> Tensor: if attn.is_nested or value.is_nested: return na3d_av_nested( @@ -1686,9 +1686,9 @@ def na1d( query: Tensor, key: Tensor, value: Tensor, - kernel_size: int | Dimension1DType, - dilation: int | Dimension1DType = 1, - is_causal: Optional[bool | CausalArg1DType] = False, + kernel_size: Dimension1DTypeOrDed, + dilation: Dimension1DTypeOrDed = 1, + is_causal: Optional[CausalArg1DTypeOrDed] = False, rpb: Optional[Tensor] = None, scale: Optional[float] = None, ) -> Tensor: @@ -1720,9 +1720,9 @@ def na2d( query: Tensor, key: Tensor, value: Tensor, - kernel_size: int | Dimension2DType, - dilation: int | Dimension2DType = 1, - is_causal: Optional[bool | CausalArg2DType] = False, + kernel_size: Dimension2DTypeOrDed, + dilation: Dimension2DTypeOrDed = 1, + is_causal: Optional[CausalArg2DTypeOrDed] = False, rpb: Optional[Tensor] = None, scale: Optional[float] = None, ) -> Tensor: @@ -1754,9 +1754,9 @@ def na3d( query: Tensor, key: Tensor, value: Tensor, - kernel_size: int | Dimension3DType, - dilation: int | Dimension3DType = 1, - is_causal: Optional[bool | CausalArg3DType] = False, + kernel_size: Dimension3DTypeOrDed, + dilation: Dimension3DTypeOrDed = 1, + is_causal: Optional[CausalArg3DTypeOrDed] = False, rpb: Optional[Tensor] = None, scale: Optional[float] = None, ) -> Tensor: diff --git a/src/natten/na1d.py b/src/natten/na1d.py index 5e905b0..48312d6 100644 --- a/src/natten/na1d.py +++ b/src/natten/na1d.py @@ -28,7 +28,7 @@ from .context import is_fna_enabled from .functional import na1d, na1d_av, na1d_qk -from .types import CausalArg1DType, Dimension1DType +from .types import CausalArg1DTypeOrDed, Dimension1DTypeOrDed from .utils import check_all_args, log logger = log.get_logger(__name__) @@ -43,9 +43,9 @@ def __init__( self, dim: int, num_heads: int, - kernel_size: int | Dimension1DType, - dilation: int | Dimension1DType = 1, - is_causal: bool | CausalArg1DType = False, + kernel_size: Dimension1DTypeOrDed, + dilation: Dimension1DTypeOrDed = 1, + is_causal: CausalArg1DTypeOrDed = False, rel_pos_bias: bool = False, qkv_bias: bool = True, qk_scale: Optional[float] = None, diff --git a/src/natten/na2d.py b/src/natten/na2d.py index 97637b6..1760687 100644 --- a/src/natten/na2d.py +++ b/src/natten/na2d.py @@ -28,7 +28,7 @@ from .context import is_fna_enabled from .functional import na2d, na2d_av, na2d_qk -from .types import CausalArg2DType, Dimension2DType +from .types import CausalArg2DTypeOrDed, Dimension2DTypeOrDed from .utils import check_all_args, log logger = log.get_logger(__name__) @@ -43,9 +43,9 @@ def __init__( self, dim: int, num_heads: int, - kernel_size: int | Dimension2DType, - dilation: int | Dimension2DType = 1, - is_causal: bool | CausalArg2DType = False, + kernel_size: Dimension2DTypeOrDed, + dilation: Dimension2DTypeOrDed = 1, + is_causal: CausalArg2DTypeOrDed = False, rel_pos_bias: bool = False, qkv_bias: bool = True, qk_scale: Optional[float] = None, diff --git a/src/natten/na3d.py b/src/natten/na3d.py index 438a14d..2e0bb7e 100644 --- a/src/natten/na3d.py +++ b/src/natten/na3d.py @@ -28,7 +28,7 @@ from .context import is_fna_enabled from .functional import na3d, na3d_av, na3d_qk -from .types import CausalArg3DType, Dimension3DType +from .types import CausalArg3DTypeOrDed, Dimension3DTypeOrDed from .utils import check_all_args, log logger = log.get_logger(__name__) @@ -43,9 +43,9 @@ def __init__( self, dim: int, num_heads: int, - kernel_size: int | Dimension3DType, - dilation: int | Dimension3DType = 1, - is_causal: bool | CausalArg3DType = False, + kernel_size: Dimension3DTypeOrDed, + dilation: Dimension3DTypeOrDed = 1, + is_causal: CausalArg3DTypeOrDed = False, rel_pos_bias: bool = False, qkv_bias: bool = True, qk_scale: Optional[float] = None, diff --git a/src/natten/nested.py b/src/natten/nested.py index a5b4c17..709120f 100644 --- a/src/natten/nested.py +++ b/src/natten/nested.py @@ -20,7 +20,7 @@ # SOFTWARE. # ################################################################################################# -from typing import List, Optional +from typing import Optional import torch from torch import Tensor @@ -37,12 +37,13 @@ from .ops import av_cross_forward, qk_cross_forward from .types import ( - CausalArg1DType, - CausalArg2DType, - CausalArg3DType, - Dimension1DType, - Dimension2DType, - Dimension3DType, + CausalArg1DTypeOrDed, + CausalArg2DTypeOrDed, + CausalArg3DTypeOrDed, + Dimension1DTypeOrDed, + Dimension2DTypeOrDed, + Dimension3DTypeOrDed, + ListOrNestedTensor, ) from .utils import ( check_additional_keys, @@ -57,10 +58,10 @@ def na1d_qk_nested( query: Tensor, key: Tensor, bias: Optional[Tensor], - kernel_size: int | Dimension1DType, - dilation: int | Dimension1DType, + kernel_size: Dimension1DTypeOrDed, + dilation: Dimension1DTypeOrDed, additional_keys: Optional[Tensor] = None, - is_causal: Optional[bool | CausalArg1DType] = False, + is_causal: Optional[CausalArg1DTypeOrDed] = False, ) -> Tensor: kernel_size_, dilation_, is_causal_ = check_all_args( 1, kernel_size, dilation, is_causal @@ -105,7 +106,7 @@ def na1d_qk_nested( "nested." ) - additional_keys_list: List | Tensor = ( + additional_keys_list: ListOrNestedTensor = ( [None for _ in range(query.size(0))] if additional_keys is None else additional_keys @@ -135,10 +136,10 @@ def na1d_qk_nested( def na1d_av_nested( attn: Tensor, value: Tensor, - kernel_size: int | Dimension1DType, - dilation: int | Dimension1DType, + kernel_size: Dimension1DTypeOrDed, + dilation: Dimension1DTypeOrDed, additional_values: Optional[Tensor] = None, - is_causal: Optional[bool | CausalArg1DType] = False, + is_causal: Optional[CausalArg1DTypeOrDed] = False, ): kernel_size_, dilation_, is_causal_ = check_all_args( 1, kernel_size, dilation, is_causal @@ -173,12 +174,12 @@ def na1d_av_nested( attn = attn.to(value.dtype) out = torch.empty_like(value) - additional_values_list: List | Tensor = ( + additional_values_list: ListOrNestedTensor = ( [None for _ in range(attn.size(0))] if additional_values is None else additional_values ) - additional_outputs_list: List | Tensor = ( + additional_outputs_list: ListOrNestedTensor = ( [None for _ in range(attn.size(0))] if additional_values is None else torch.empty_like(out) @@ -204,10 +205,10 @@ def na2d_qk_nested( query: Tensor, key: Tensor, bias: Optional[Tensor], - kernel_size: int | Dimension2DType, - dilation: int | Dimension2DType, + kernel_size: Dimension2DTypeOrDed, + dilation: Dimension2DTypeOrDed, additional_keys: Optional[Tensor] = None, - is_causal: Optional[bool | CausalArg2DType] = False, + is_causal: Optional[CausalArg2DTypeOrDed] = False, ) -> Tensor: kernel_size_, dilation_, is_causal_ = check_all_args( 2, kernel_size, dilation, is_causal @@ -252,7 +253,7 @@ def na2d_qk_nested( "nested." ) - additional_keys_list: List | Tensor = ( + additional_keys_list: ListOrNestedTensor = ( [None for _ in range(query.size(0))] if additional_keys is None else additional_keys @@ -282,10 +283,10 @@ def na2d_qk_nested( def na2d_av_nested( attn: Tensor, value: Tensor, - kernel_size: int | Dimension2DType, - dilation: int | Dimension2DType, + kernel_size: Dimension2DTypeOrDed, + dilation: Dimension2DTypeOrDed, additional_values: Optional[Tensor] = None, - is_causal: Optional[bool | CausalArg2DType] = False, + is_causal: Optional[CausalArg2DTypeOrDed] = False, ): kernel_size_, dilation_, is_causal_ = check_all_args( 2, kernel_size, dilation, is_causal @@ -320,12 +321,12 @@ def na2d_av_nested( attn = attn.to(value.dtype) out = torch.empty_like(value) - additional_values_list: List | Tensor = ( + additional_values_list: ListOrNestedTensor = ( [None for _ in range(attn.size(0))] if additional_values is None else additional_values ) - additional_outputs_list: List | Tensor = ( + additional_outputs_list: ListOrNestedTensor = ( [None for _ in range(attn.size(0))] if additional_values is None else torch.empty_like(out) @@ -351,10 +352,10 @@ def na3d_qk_nested( query: Tensor, key: Tensor, bias: Optional[Tensor], - kernel_size: int | Dimension3DType, - dilation: int | Dimension3DType, + kernel_size: Dimension3DTypeOrDed, + dilation: Dimension3DTypeOrDed, additional_keys: Optional[Tensor] = None, - is_causal: Optional[bool | CausalArg3DType] = False, + is_causal: Optional[CausalArg3DTypeOrDed] = False, ) -> Tensor: kernel_size_, dilation_, is_causal_ = check_all_args( 3, kernel_size, dilation, is_causal @@ -399,7 +400,7 @@ def na3d_qk_nested( "nested." ) - additional_keys_list: List | Tensor = ( + additional_keys_list: ListOrNestedTensor = ( [None for _ in range(query.size(0))] if additional_keys is None else additional_keys @@ -429,10 +430,10 @@ def na3d_qk_nested( def na3d_av_nested( attn: Tensor, value: Tensor, - kernel_size: int | Dimension3DType, - dilation: int | Dimension3DType, + kernel_size: Dimension3DTypeOrDed, + dilation: Dimension3DTypeOrDed, additional_values: Optional[Tensor] = None, - is_causal: Optional[bool | CausalArg3DType] = False, + is_causal: Optional[CausalArg3DTypeOrDed] = False, ): kernel_size_, dilation_, is_causal_ = check_all_args( 3, kernel_size, dilation, is_causal @@ -467,12 +468,12 @@ def na3d_av_nested( attn = attn.to(value.dtype) out = torch.empty_like(value) - additional_values_list: List | Tensor = ( + additional_values_list: ListOrNestedTensor = ( [None for _ in range(attn.size(0))] if additional_values is None else additional_values ) - additional_outputs_list: List | Tensor = ( + additional_outputs_list: ListOrNestedTensor = ( [None for _ in range(attn.size(0))] if additional_values is None else torch.empty_like(out) diff --git a/src/natten/types.py b/src/natten/types.py index c0d56a5..bb2fee3 100644 --- a/src/natten/types.py +++ b/src/natten/types.py @@ -21,7 +21,9 @@ # ################################################################################################# -from typing import Tuple +from typing import List, Tuple, Union + +from torch import Tensor NoneType = type(None) @@ -33,25 +35,36 @@ CausalArg2DType = Tuple[bool, bool] CausalArg3DType = Tuple[bool, bool, bool] -DimensionType = Dimension1DType | Dimension2DType | Dimension3DType -CausalArgType = CausalArg1DType | CausalArg2DType | CausalArg3DType +# NOTE: switch to | when < 3.10 support is dropped +Dimension1DTypeOrDed = Union[int, Dimension1DType] +Dimension2DTypeOrDed = Union[int, Dimension2DType] +Dimension3DTypeOrDed = Union[int, Dimension3DType] + +CausalArg1DTypeOrDed = Union[bool, CausalArg1DType] +CausalArg2DTypeOrDed = Union[bool, CausalArg2DType] +CausalArg3DTypeOrDed = Union[bool, CausalArg3DType] + +DimensionType = Union[Dimension1DType, Dimension2DType, Dimension3DType] +CausalArgType = Union[CausalArg1DType, CausalArg2DType, CausalArg3DType] # (query_tile_shape, kv_tile_shape) -FnaTileShapeType = ( - Tuple[Dimension1DType, Dimension1DType] - | Tuple[Dimension2DType, Dimension2DType] - | Tuple[Dimension3DType, Dimension3DType] -) +FnaTileShapeType = Union[ + Tuple[Dimension1DType, Dimension1DType], + Tuple[Dimension2DType, Dimension2DType], + Tuple[Dimension3DType, Dimension3DType], +] # (query_tile_shape, kv_tile_shape) FnaForwardConfigType = FnaTileShapeType # (query_tile_shape, kv_tile_shape, num_kv_splits, use_torch_to_compute_delta) -FnaBackwardConfigType = ( - Tuple[Dimension1DType, Dimension1DType, Dimension1DType, bool] - | Tuple[Dimension2DType, Dimension2DType, Dimension2DType, bool] - | Tuple[Dimension3DType, Dimension3DType, Dimension3DType, bool] -) +FnaBackwardConfigType = Union[ + Tuple[Dimension1DType, Dimension1DType, Dimension1DType, bool], + Tuple[Dimension2DType, Dimension2DType, Dimension2DType, bool], + Tuple[Dimension3DType, Dimension3DType, Dimension3DType, bool], +] + +ListOrNestedTensor = Union[List, Tensor] # Redundant, but here to accommodate the type checker. diff --git a/tools/utils/formatting.py b/tools/utils/formatting.py index f3ec60f..a79732b 100644 --- a/tools/utils/formatting.py +++ b/tools/utils/formatting.py @@ -21,7 +21,7 @@ # ################################################################################################# -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union from torch.profiler import profile as torch_profile @@ -29,6 +29,9 @@ from .ops import CustomOp, NAOp +# NOTE: switch to | when < 3.10 support is dropped +OpType = Union[NAOp, CustomOp] + def _format_time(time_us: float) -> str: """ @@ -46,7 +49,7 @@ def _format_time(time_us: float) -> str: class Result: def __init__( self, - op: NAOp | CustomOp, + op: OpType, op_str: str, time: float, index: int = -1, @@ -164,9 +167,7 @@ def remove_wrapper(s, wrapper): return ".".join(namespace_split[:-1]), namespace_split[-1] -def convert_ops( - ops: Dict[NAOp | CustomOp, List[float]], tags: Dict -) -> Optional[List[Result]]: +def convert_ops(ops: Dict[OpType, List[float]], tags: Dict) -> Optional[List[Result]]: output = [] for op, values in ops.items(): if len(values) and isinstance(op, NAOp): @@ -292,7 +293,7 @@ def extract_na_ops( na_dim: int, ) -> Optional[List[Result]]: events = profiler.events() - logged_ops: Dict[NAOp | CustomOp, List[float]] = {na_op: [] for na_op in NAOp} + logged_ops: Dict[OpType, List[float]] = {na_op: [] for na_op in NAOp} tags: Dict[NAOp, Optional[str]] = {na_op: None for na_op in NAOp} for evt in events: op, valid, tag = str_to_na_op(sym=evt.key, na_dim=na_dim) diff --git a/tools/utils/problem.py b/tools/utils/problem.py index 4a3a189..bf529c7 100644 --- a/tools/utils/problem.py +++ b/tools/utils/problem.py @@ -23,7 +23,10 @@ import copy import math -from typing import Any, List, Optional +from typing import Any, List, Optional, Union + +# NOTE: switch to | when < 3.10 support is dropped +CausalType = Optional[Union[bool, tuple]] class Problem: @@ -38,7 +41,7 @@ def __init__( dilation: List[int], dtype: Any, has_bias: bool, - is_causal: Optional[bool | tuple] = None, + is_causal: CausalType = None, ): self.na_dim = na_dim self.batch_size = batch_size @@ -103,7 +106,7 @@ def generate_1d_problem( dilation: int, dtype: Any, has_bias: bool, - is_causal: Optional[bool | tuple] = None, + is_causal: CausalType = None, ) -> Problem: return Problem( na_dim=1, @@ -129,7 +132,7 @@ def generate_2d_problem( dilation: int, dtype: Any, has_bias: bool, - is_causal: Optional[bool | tuple] = None, + is_causal: CausalType = None, ) -> Problem: return Problem( na_dim=2, @@ -156,7 +159,7 @@ def generate_3d_problem( dilation: int, dtype: Any, has_bias: bool, - is_causal: Optional[bool | tuple] = None, + is_causal: CausalType = None, ) -> Problem: return Problem( na_dim=3, diff --git a/webpage/index.html b/webpage/index.html index a21798c..6945967 100644 --- a/webpage/index.html +++ b/webpage/index.html @@ -48,7 +48,7 @@

Install with pip

-

Latest release: 0.17.0

+

Latest release: 0.17.1

@@ -83,15 +83,15 @@

Run this command:

-
pip3 install natten==0.17.0+torch230cu121 -f https://shi-labs.com/natten/wheels/
+
pip3 install natten==0.17.1+torch230cu121 -f https://shi-labs.com/natten/wheels/

Run this command:

-
pip3 install natten==0.17.0+torch230cu118 -f https://shi-labs.com/natten/wheels
+
pip3 install natten==0.17.1+torch230cu118 -f https://shi-labs.com/natten/wheels

Run this command:

-
pip3 install natten==0.17.0+torch230cpu -f https://shi-labs.com/natten/wheels
+
pip3 install natten==0.17.1+torch230cpu -f https://shi-labs.com/natten/wheels
@@ -106,15 +106,15 @@

Run this command:

-
pip3 install natten==0.17.0+torch220cu121 -f https://shi-labs.com/natten/wheels/
+
pip3 install natten==0.17.1+torch220cu121 -f https://shi-labs.com/natten/wheels/

Run this command:

-
pip3 install natten==0.17.0+torch220cu118 -f https://shi-labs.com/natten/wheels
+
pip3 install natten==0.17.1+torch220cu118 -f https://shi-labs.com/natten/wheels

Run this command:

-
pip3 install natten==0.17.0+torch220cpu -f https://shi-labs.com/natten/wheels
+
pip3 install natten==0.17.1+torch220cpu -f https://shi-labs.com/natten/wheels
@@ -129,15 +129,15 @@

Run this command:

-
pip3 install natten==0.17.0+torch210cu121 -f https://shi-labs.com/natten/wheels/
+
pip3 install natten==0.17.1+torch210cu121 -f https://shi-labs.com/natten/wheels/

Run this command:

-
pip3 install natten==0.17.0+torch210cu118 -f https://shi-labs.com/natten/wheels
+
pip3 install natten==0.17.1+torch210cu118 -f https://shi-labs.com/natten/wheels

Run this command:

-
pip3 install natten==0.17.0+torch210cpu -f https://shi-labs.com/natten/wheels
+
pip3 install natten==0.17.1+torch210cpu -f https://shi-labs.com/natten/wheels
@@ -152,21 +152,21 @@

Run this command:

-
pip3 install natten==0.17.0+torch200cu118 -f https://shi-labs.com/natten/wheels
+
pip3 install natten==0.17.1+torch200cu118 -f https://shi-labs.com/natten/wheels

Run this command:

-
pip3 install natten==0.17.0+torch200cu117 -f https://shi-labs.com/natten/wheels
+
pip3 install natten==0.17.1+torch200cu117 -f https://shi-labs.com/natten/wheels

Run this command:

-
pip3 install natten==0.17.0+torch200cpu -f https://shi-labs.com/natten/wheels
+
pip3 install natten==0.17.1+torch200cpu -f https://shi-labs.com/natten/wheels
-

Your build isn't listed? Mac user? Just do:

pip install natten==0.17.0

+

Your build isn't listed? Mac user? Just do:

pip install natten==0.17.1

Careful though, without pre-compiled wheels installing might take a while.

You're also required to have CUDA > 11.7, cmake > 3.20 and PyTorch > 2.0 installed before attempting to install/build NATTEN.