diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7ec7f02a8..0ec2b8ada 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -48,3 +48,29 @@ repos: rev: 23.7.0 hooks: - id: black + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: "v1.0.0" + hooks: + - id: mypy + additional_dependencies: [typing_extensions>=4.4.0] + args: + - --ignore-missing-imports + - --config=pyproject.toml + files: ".*(_draft.*)$" + exclude: | + (?x)^( + .*creation_functions.py| + .*data_type_functions.py| + .*elementwise_functions.py| + .*fft.py| + .*indexing_functions.py| + .*linalg.py| + .*linear_algebra_functions.py| + .*manipulation_functions.py| + .*searching_functions.py| + .*set_functions.py| + .*sorting_functions.py| + .*statistical_functions.py| + .*utility_functions.py| + )$ diff --git a/pyproject.toml b/pyproject.toml index 57af04207..ae63cad21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,5 +30,15 @@ doc = [ requires = ["setuptools"] build-backend = "setuptools.build_meta" + [tool.black] line-length = 88 + + +[tool.mypy] +python_version = "3.9" +mypy_path = "$MYPY_CONFIG_FILE_DIR/src/array_api_stubs/_draft/" +files = [ + "src/array_api_stubs/_draft/**/*.py" +] +follow_imports = "silent" diff --git a/spec/draft/API_specification/array_object.rst b/spec/draft/API_specification/array_object.rst index b15bbdc43..0d848ba12 100644 --- a/spec/draft/API_specification/array_object.rst +++ b/spec/draft/API_specification/array_object.rst @@ -30,47 +30,47 @@ Arithmetic Operators A conforming implementation of the array API standard must provide and support an array object supporting the following Python arithmetic operators. -- ``+x``: :meth:`.array.__pos__` +- ``+x``: :meth:`.Array.__pos__` - `operator.pos(x) `_ - `operator.__pos__(x) `_ -- `-x`: :meth:`.array.__neg__` +- `-x`: :meth:`.Array.__neg__` - `operator.neg(x) `_ - `operator.__neg__(x) `_ -- `x1 + x2`: :meth:`.array.__add__` +- `x1 + x2`: :meth:`.Array.__add__` - `operator.add(x1, x2) `_ - `operator.__add__(x1, x2) `_ -- `x1 - x2`: :meth:`.array.__sub__` +- `x1 - x2`: :meth:`.Array.__sub__` - `operator.sub(x1, x2) `_ - `operator.__sub__(x1, x2) `_ -- `x1 * x2`: :meth:`.array.__mul__` +- `x1 * x2`: :meth:`.Array.__mul__` - `operator.mul(x1, x2) `_ - `operator.__mul__(x1, x2) `_ -- `x1 / x2`: :meth:`.array.__truediv__` +- `x1 / x2`: :meth:`.Array.__truediv__` - `operator.truediv(x1,x2) `_ - `operator.__truediv__(x1, x2) `_ -- `x1 // x2`: :meth:`.array.__floordiv__` +- `x1 // x2`: :meth:`.Array.__floordiv__` - `operator.floordiv(x1, x2) `_ - `operator.__floordiv__(x1, x2) `_ -- `x1 % x2`: :meth:`.array.__mod__` +- `x1 % x2`: :meth:`.Array.__mod__` - `operator.mod(x1, x2) `_ - `operator.__mod__(x1, x2) `_ -- `x1 ** x2`: :meth:`.array.__pow__` +- `x1 ** x2`: :meth:`.Array.__pow__` - `operator.pow(x1, x2) `_ - `operator.__pow__(x1, x2) `_ @@ -82,7 +82,7 @@ Array Operators A conforming implementation of the array API standard must provide and support an array object supporting the following Python array operators. -- `x1 @ x2`: :meth:`.array.__matmul__` +- `x1 @ x2`: :meth:`.Array.__matmul__` - `operator.matmul(x1, x2) `_ - `operator.__matmul__(x1, x2) `_ @@ -94,34 +94,34 @@ Bitwise Operators A conforming implementation of the array API standard must provide and support an array object supporting the following Python bitwise operators. -- `~x`: :meth:`.array.__invert__` +- `~x`: :meth:`.Array.__invert__` - `operator.inv(x) `_ - `operator.invert(x) `_ - `operator.__inv__(x) `_ - `operator.__invert__(x) `_ -- `x1 & x2`: :meth:`.array.__and__` +- `x1 & x2`: :meth:`.Array.__and__` - `operator.and(x1, x2) `_ - `operator.__and__(x1, x2) `_ -- `x1 | x2`: :meth:`.array.__or__` +- `x1 | x2`: :meth:`.Array.__or__` - `operator.or(x1, x2) `_ - `operator.__or__(x1, x2) `_ -- `x1 ^ x2`: :meth:`.array.__xor__` +- `x1 ^ x2`: :meth:`.Array.__xor__` - `operator.xor(x1, x2) `_ - `operator.__xor__(x1, x2) `_ -- `x1 << x2`: :meth:`.array.__lshift__` +- `x1 << x2`: :meth:`.Array.__lshift__` - `operator.lshift(x1, x2) `_ - `operator.__lshift__(x1, x2) `_ -- `x1 >> x2`: :meth:`.array.__rshift__` +- `x1 >> x2`: :meth:`.Array.__rshift__` - `operator.rshift(x1, x2) `_ - `operator.__rshift__(x1, x2) `_ @@ -133,32 +133,32 @@ Comparison Operators A conforming implementation of the array API standard must provide and support an array object supporting the following Python comparison operators. -- `x1 < x2`: :meth:`.array.__lt__` +- `x1 < x2`: :meth:`.Array.__lt__` - `operator.lt(x1, x2) `_ - `operator.__lt__(x1, x2) `_ -- `x1 <= x2`: :meth:`.array.__le__` +- `x1 <= x2`: :meth:`.Array.__le__` - `operator.le(x1, x2) `_ - `operator.__le__(x1, x2) `_ -- `x1 > x2`: :meth:`.array.__gt__` +- `x1 > x2`: :meth:`.Array.__gt__` - `operator.gt(x1, x2) `_ - `operator.__gt__(x1, x2) `_ -- `x1 >= x2`: :meth:`.array.__ge__` +- `x1 >= x2`: :meth:`.Array.__ge__` - `operator.ge(x1, x2) `_ - `operator.__ge__(x1, x2) `_ -- `x1 == x2`: :meth:`.array.__eq__` +- `x1 == x2`: :meth:`.Array.__eq__` - `operator.eq(x1, x2) `_ - `operator.__eq__(x1, x2) `_ -- `x1 != x2`: :meth:`.array.__ne__` +- `x1 != x2`: :meth:`.Array.__ne__` - `operator.ne(x1, x2) `_ - `operator.__ne__(x1, x2) `_ @@ -251,13 +251,13 @@ Attributes :toctree: generated :template: property.rst - array.dtype - array.device - array.mT - array.ndim - array.shape - array.size - array.T + Array.dtype + Array.device + Array.mT + Array.ndim + Array.shape + Array.size + Array.T ------------------------------------------------- @@ -271,37 +271,37 @@ Methods :toctree: generated :template: property.rst - array.__abs__ - array.__add__ - array.__and__ - array.__array_namespace__ - array.__bool__ - array.__complex__ - array.__dlpack__ - array.__dlpack_device__ - array.__eq__ - array.__float__ - array.__floordiv__ - array.__ge__ - array.__getitem__ - array.__gt__ - array.__index__ - array.__int__ - array.__invert__ - array.__le__ - array.__lshift__ - array.__lt__ - array.__matmul__ - array.__mod__ - array.__mul__ - array.__ne__ - array.__neg__ - array.__or__ - array.__pos__ - array.__pow__ - array.__rshift__ - array.__setitem__ - array.__sub__ - array.__truediv__ - array.__xor__ - array.to_device + Array.__abs__ + Array.__add__ + Array.__and__ + Array.__array_namespace__ + Array.__bool__ + Array.__complex__ + Array.__dlpack__ + Array.__dlpack_device__ + Array.__eq__ + Array.__float__ + Array.__floordiv__ + Array.__ge__ + Array.__getitem__ + Array.__gt__ + Array.__index__ + Array.__int__ + Array.__invert__ + Array.__le__ + Array.__lshift__ + Array.__lt__ + Array.__matmul__ + Array.__mod__ + Array.__mul__ + Array.__ne__ + Array.__neg__ + Array.__or__ + Array.__pos__ + Array.__pow__ + Array.__rshift__ + Array.__setitem__ + Array.__sub__ + Array.__truediv__ + Array.__xor__ + Array.to_device diff --git a/spec/draft/purpose_and_scope.md b/spec/draft/purpose_and_scope.md index eee3c7f4c..312d508ef 100644 --- a/spec/draft/purpose_and_scope.md +++ b/spec/draft/purpose_and_scope.md @@ -318,7 +318,7 @@ namespace (e.g. `import package_name.array_api`). This has two issues though: To address both issues, a uniform way must be provided by a conforming implementation to access the API namespace, namely a [method on the array -object](array.__array_namespace__): +object](Array.__array_namespace__): ``` xp = x.__array_namespace__() diff --git a/src/_array_api_conf.py b/src/_array_api_conf.py index d3a136eaa..e9fbaac83 100644 --- a/src/_array_api_conf.py +++ b/src/_array_api_conf.py @@ -63,8 +63,11 @@ ] nitpick_ignore_regex = [ ("py:class", ".*array"), + ("py:class", ".*Array"), ("py:class", ".*device"), + ("py:class", ".*Device"), ("py:class", ".*dtype"), + ("py:class", ".*Self"), ("py:class", ".*NestedSequence"), ("py:class", ".*SupportsBufferProtocol"), ("py:class", ".*PyCapsule"), @@ -77,6 +80,7 @@ "array": "array", "Device": "device", "Dtype": "dtype", + "DType": "dtype", } # Make autosummary show the signatures of functions in the tables using actual diff --git a/src/array_api_stubs/_draft/_types.py b/src/array_api_stubs/_draft/_types.py index 2a73dda24..448d419f0 100644 --- a/src/array_api_stubs/_draft/_types.py +++ b/src/array_api_stubs/_draft/_types.py @@ -29,6 +29,7 @@ from dataclasses import dataclass from typing import ( + TYPE_CHECKING, Any, List, Literal, @@ -41,9 +42,22 @@ ) from enum import Enum -array = TypeVar("array") -device = TypeVar("device") -dtype = TypeVar("dtype") + +if TYPE_CHECKING: + from .array_object import Array + from .data_types import DType + + +class Device(Protocol): + """Protocol for device objects.""" + + def __eq__(self, value: Any) -> bool: + ... + + +array = TypeVar("array", bound="Array") +device = TypeVar("device", bound=Device) +dtype = TypeVar("dtype", bound="DType") SupportsDLPack = TypeVar("SupportsDLPack") SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") PyCapsule = TypeVar("PyCapsule") @@ -61,7 +75,7 @@ class finfo_object: max: float min: float smallest_normal: float - dtype: dtype + dtype: DType @dataclass @@ -71,7 +85,7 @@ class iinfo_object: bits: int max: int min: int - dtype: dtype + dtype: DType _T_co = TypeVar("_T_co", covariant=True) diff --git a/src/array_api_stubs/_draft/array_object.py b/src/array_api_stubs/_draft/array_object.py index 2976b46b2..e523d3acf 100644 --- a/src/array_api_stubs/_draft/array_object.py +++ b/src/array_api_stubs/_draft/array_object.py @@ -1,27 +1,30 @@ from __future__ import annotations -__all__ = ["array"] - -from ._types import ( - array, - dtype as Dtype, - device as Device, - Optional, - Tuple, - Union, - Any, - PyCapsule, - Enum, - ellipsis, -) - - -class _array: - def __init__(self: array) -> None: +__all__ = ["Array"] + +from typing import TYPE_CHECKING, Protocol, TypeVar +from enum import Enum +from .data_types import DType +from ._types import Device + +if TYPE_CHECKING: + from ._types import ( + Any, + PyCapsule, + ellipsis, + ) + +array = TypeVar("array", bound="Array") +# NOTE: when working with py3.11+ this can be ``typing.Self``. + + +class Array(Protocol): + def __init__(self) -> None: """Initialize the attributes for the array object class.""" + ... @property - def dtype(self: array) -> Dtype: + def dtype(self) -> DType: """ Data type of the array elements. @@ -30,9 +33,10 @@ def dtype(self: array) -> Dtype: out: dtype array data type. """ + ... @property - def device(self: array) -> Device: + def device(self) -> Device: """ Hardware device the array data resides on. @@ -41,6 +45,7 @@ def device(self: array) -> Device: out: device a ``device`` object (see :ref:`device-support`). """ + ... @property def mT(self: array) -> array: @@ -54,9 +59,10 @@ def mT(self: array) -> array: out: array array whose last two dimensions (axes) are permuted in reverse order relative to original array (i.e., for an array instance having shape ``(..., M, N)``, the returned array must have shape ``(..., N, M)``). The returned array must have the same data type as the original array. """ + ... @property - def ndim(self: array) -> int: + def ndim(self) -> int: """ Number of array dimensions (axes). @@ -65,9 +71,10 @@ def ndim(self: array) -> int: out: int number of array dimensions (axes). """ + ... @property - def shape(self: array) -> Tuple[Optional[int], ...]: + def shape(self) -> tuple[int | None, ...]: """ Array dimensions. @@ -83,9 +90,10 @@ def shape(self: array) -> Tuple[Optional[int], ...]: .. note:: The returned value should be a tuple; however, where warranted, an array library may choose to return a custom shape object. If an array library returns a custom shape object, the object must be immutable, must support indexing for dimension retrieval, and must behave similarly to a tuple. """ + ... @property - def size(self: array) -> Optional[int]: + def size(self) -> int | None: """ Number of elements in an array. @@ -101,6 +109,7 @@ def size(self: array) -> Optional[int]: .. note:: For array libraries having graph-based computational models, an array may have unknown dimensions due to data-dependent operations. """ + ... @property def T(self: array) -> array: @@ -118,6 +127,7 @@ def T(self: array) -> array: .. note:: Limiting the transpose to two-dimensional arrays (matrices) deviates from the NumPy et al practice of reversing all axes for arrays having more than two-dimensions. This is intentional, as reversing all axes was found to be problematic (e.g., conflicting with the mathematical definition of a transpose which is limited to matrices; not operating on batches of matrices; et cetera). In order to reverse all axes, one is recommended to use the functional ``permute_dims`` interface found in this specification. """ + ... def __abs__(self: array, /) -> array: """ @@ -147,8 +157,9 @@ def __abs__(self: array, /) -> array: .. versionchanged:: 2022.12 Added complex data type support. """ + ... - def __add__(self: array, other: Union[int, float, array], /) -> array: + def __add__(self: array, other: int | float | Array, /) -> array: """ Calculates the sum for each element of an array instance with the respective element of the array ``other``. @@ -173,8 +184,9 @@ def __add__(self: array, other: Union[int, float, array], /) -> array: .. versionchanged:: 2022.12 Added complex data type support. """ + ... - def __and__(self: array, other: Union[int, bool, array], /) -> array: + def __and__(self: array, other: int | bool | array, /) -> array: """ Evaluates ``self_i & other_i`` for each element of an array instance with the respective element of the array ``other``. @@ -194,10 +206,9 @@ def __and__(self: array, other: Union[int, bool, array], /) -> array: .. note:: Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.bitwise_and`. """ + ... - def __array_namespace__( - self: array, /, *, api_version: Optional[str] = None - ) -> Any: + def __array_namespace__(self, /, *, api_version: str | None = None) -> Any: """ Returns an object that has all the array API functions on it. @@ -213,8 +224,9 @@ def __array_namespace__( out: Any an object representing the array API namespace. It should have every top-level function defined in the specification as an attribute. It may contain other public names as well, but it is recommended to only include those names that are part of the specification. """ + ... - def __bool__(self: array, /) -> bool: + def __bool__(self, /) -> bool: """ Converts a zero-dimensional array to a Python ``bool`` object. @@ -244,8 +256,9 @@ def __bool__(self: array, /) -> bool: .. versionchanged:: 2022.12 Added boolean and complex data type support. """ + ... - def __complex__(self: array, /) -> complex: + def __complex__(self, /) -> complex: """ Converts a zero-dimensional array to a Python ``complex`` object. @@ -278,10 +291,11 @@ def __complex__(self: array, /) -> complex: .. versionadded:: 2022.12 """ + ... def __dlpack__( - self: array, /, *, stream: Optional[Union[int, Any]] = None - ) -> PyCapsule: + self, /, *, stream: Any | None = None + ) -> PyCapsule: # type: ignore[type-var] """ Exports the array for consumption by :func:`~array_api.from_dlpack` as a DLPack capsule. @@ -349,8 +363,9 @@ def __dlpack__( .. versionchanged:: 2022.12 Added BufferError. """ + ... - def __dlpack_device__(self: array, /) -> Tuple[Enum, int]: + def __dlpack_device__(self, /) -> tuple[Enum, int]: """ Returns device type and device ID in DLPack format. Meant for use within :func:`~array_api.from_dlpack`. @@ -375,8 +390,12 @@ def __dlpack_device__(self: array, /) -> Tuple[Enum, int]: VPI = 9 ROCM = 10 """ + ... - def __eq__(self: array, other: Union[int, float, bool, array], /) -> array: + # Note that __eq__ returns an array while `object.__eq__` returns a bool. + # Hence Mypy will complain that this violates the Liskov substitution + # principle - ignore that. + def __eq__(self: array, other: int | float | bool | Array, /) -> array: # type: ignore[override] r""" Computes the truth value of ``self_i == other_i`` for each element of an array instance with the respective element of the array ``other``. @@ -396,8 +415,9 @@ def __eq__(self: array, other: Union[int, float, bool, array], /) -> array: .. note:: Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.equal`. """ + ... - def __float__(self: array, /) -> float: + def __float__(self, /) -> float: """ Converts a zero-dimensional array to a Python ``float`` object. @@ -427,8 +447,9 @@ def __float__(self: array, /) -> float: .. versionchanged:: 2022.12 Added boolean and complex data type support. """ + ... - def __floordiv__(self: array, other: Union[int, float, array], /) -> array: + def __floordiv__(self: array, other: int | float | Array, /) -> array: """ Evaluates ``self_i // other_i`` for each element of an array instance with the respective element of the array ``other``. @@ -451,8 +472,9 @@ def __floordiv__(self: array, other: Union[int, float, array], /) -> array: .. note:: Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.floor_divide`. """ + ... - def __ge__(self: array, other: Union[int, float, array], /) -> array: + def __ge__(self: array, other: int | float | Array, /) -> array: """ Computes the truth value of ``self_i >= other_i`` for each element of an array instance with the respective element of the array ``other``. @@ -475,12 +497,11 @@ def __ge__(self: array, other: Union[int, float, array], /) -> array: .. note:: Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.greater_equal`. """ + ... def __getitem__( self: array, - key: Union[ - int, slice, ellipsis, Tuple[Union[int, slice, ellipsis, None], ...], array - ], + key: int | slice | ellipsis | tuple[int | slice | ellipsis | None, ...] | array, /, ) -> array: """ @@ -498,8 +519,9 @@ def __getitem__( out: array an array containing the accessed value(s). The returned array must have the same data type as ``self``. """ + ... - def __gt__(self: array, other: Union[int, float, array], /) -> array: + def __gt__(self: array, other: int | float | Array, /) -> array: """ Computes the truth value of ``self_i > other_i`` for each element of an array instance with the respective element of the array ``other``. @@ -522,8 +544,9 @@ def __gt__(self: array, other: Union[int, float, array], /) -> array: .. note:: Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.greater`. """ + ... - def __index__(self: array, /) -> int: + def __index__(self, /) -> int: """ Converts a zero-dimensional integer array to a Python ``int`` object. @@ -540,8 +563,9 @@ def __index__(self: array, /) -> int: out: int a Python ``int`` object representing the single element of the array instance. """ + ... - def __int__(self: array, /) -> int: + def __int__(self, /) -> int: """ Converts a zero-dimensional array to a Python ``int`` object. @@ -580,6 +604,7 @@ def __int__(self: array, /) -> int: .. versionchanged:: 2022.12 Added boolean and complex data type support. """ + ... def __invert__(self: array, /) -> array: """ @@ -599,8 +624,9 @@ def __invert__(self: array, /) -> array: .. note:: Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.bitwise_invert`. """ + ... - def __le__(self: array, other: Union[int, float, array], /) -> array: + def __le__(self: array, other: int | float | Array, /) -> array: """ Computes the truth value of ``self_i <= other_i`` for each element of an array instance with the respective element of the array ``other``. @@ -623,8 +649,9 @@ def __le__(self: array, other: Union[int, float, array], /) -> array: .. note:: Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.less_equal`. """ + ... - def __lshift__(self: array, other: Union[int, array], /) -> array: + def __lshift__(self: array, other: int | array, /) -> array: """ Evaluates ``self_i << other_i`` for each element of an array instance with the respective element of the array ``other``. @@ -644,8 +671,9 @@ def __lshift__(self: array, other: Union[int, array], /) -> array: .. note:: Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.bitwise_left_shift`. """ + ... - def __lt__(self: array, other: Union[int, float, array], /) -> array: + def __lt__(self: array, other: int | float | Array, /) -> array: """ Computes the truth value of ``self_i < other_i`` for each element of an array instance with the respective element of the array ``other``. @@ -668,6 +696,7 @@ def __lt__(self: array, other: Union[int, float, array], /) -> array: .. note:: Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.less`. """ + ... def __matmul__(self: array, other: array, /) -> array: """ @@ -716,8 +745,9 @@ def __matmul__(self: array, other: array, /) -> array: .. versionchanged:: 2022.12 Added complex data type support. """ + ... - def __mod__(self: array, other: Union[int, float, array], /) -> array: + def __mod__(self: array, other: int | float | Array, /) -> array: """ Evaluates ``self_i % other_i`` for each element of an array instance with the respective element of the array ``other``. @@ -740,8 +770,9 @@ def __mod__(self: array, other: Union[int, float, array], /) -> array: .. note:: Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.remainder`. """ + ... - def __mul__(self: array, other: Union[int, float, array], /) -> array: + def __mul__(self: array, other: int | float | Array, /) -> array: r""" Calculates the product for each element of an array instance with the respective element of the array ``other``. @@ -769,8 +800,10 @@ def __mul__(self: array, other: Union[int, float, array], /) -> array: .. versionchanged:: 2022.12 Added complex data type support. """ + ... - def __ne__(self: array, other: Union[int, float, bool, array], /) -> array: + # See note above __eq__ method for explanation of the `type: ignore` + def __ne__(self: array, other: int | float | bool | Array, /) -> array: # type: ignore[override] """ Computes the truth value of ``self_i != other_i`` for each element of an array instance with the respective element of the array ``other``. @@ -796,6 +829,7 @@ def __ne__(self: array, other: Union[int, float, bool, array], /) -> array: .. versionchanged:: 2022.12 Added complex data type support. """ + ... def __neg__(self: array, /) -> array: """ @@ -826,8 +860,9 @@ def __neg__(self: array, /) -> array: .. versionchanged:: 2022.12 Added complex data type support. """ + ... - def __or__(self: array, other: Union[int, bool, array], /) -> array: + def __or__(self: array, other: int | bool | array, /) -> array: """ Evaluates ``self_i | other_i`` for each element of an array instance with the respective element of the array ``other``. @@ -847,6 +882,7 @@ def __or__(self: array, other: Union[int, bool, array], /) -> array: .. note:: Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.bitwise_or`. """ + ... def __pos__(self: array, /) -> array: """ @@ -871,8 +907,9 @@ def __pos__(self: array, /) -> array: .. versionchanged:: 2022.12 Added complex data type support. """ + ... - def __pow__(self: array, other: Union[int, float, array], /) -> array: + def __pow__(self: array, other: int | float | Array, /) -> array: r""" Calculates an implementation-dependent approximation of exponentiation by raising each element (the base) of an array instance to the power of ``other_i`` (the exponent), where ``other_i`` is the corresponding element of the array ``other``. @@ -902,8 +939,9 @@ def __pow__(self: array, other: Union[int, float, array], /) -> array: .. versionchanged:: 2022.12 Added complex data type support. """ + ... - def __rshift__(self: array, other: Union[int, array], /) -> array: + def __rshift__(self: array, other: int | array, /) -> array: """ Evaluates ``self_i >> other_i`` for each element of an array instance with the respective element of the array ``other``. @@ -923,13 +961,12 @@ def __rshift__(self: array, other: Union[int, array], /) -> array: .. note:: Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.bitwise_right_shift`. """ + ... def __setitem__( self: array, - key: Union[ - int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], array - ], - value: Union[int, float, bool, array], + key: int | slice | ellipsis | tuple[int | slice | ellipsis, ...] | array, + value: int | float | bool | array, /, ) -> None: """ @@ -953,12 +990,13 @@ def __setitem__( When ``value`` is an ``array`` of a different data type than ``self``, how values are cast to the data type of ``self`` is implementation defined. """ + ... - def __sub__(self: array, other: Union[int, float, array], /) -> array: + def __sub__(self: array, other: int | float | Array, /) -> array: """ Calculates the difference for each element of an array instance with the respective element of the array ``other``. - The result of ``self_i - other_i`` must be the same as ``self_i + (-other_i)`` and must be governed by the same floating-point rules as addition (see :meth:`array.__add__`). + The result of ``self_i - other_i`` must be the same as ``self_i + (-other_i)`` and must be governed by the same floating-point rules as addition (see :meth:`Array.__add__`). Parameters ---------- @@ -981,8 +1019,9 @@ def __sub__(self: array, other: Union[int, float, array], /) -> array: .. versionchanged:: 2022.12 Added complex data type support. """ + ... - def __truediv__(self: array, other: Union[int, float, array], /) -> array: + def __truediv__(self: array, other: int | float | Array, /) -> array: r""" Evaluates ``self_i / other_i`` for each element of an array instance with the respective element of the array ``other``. @@ -1012,8 +1051,9 @@ def __truediv__(self: array, other: Union[int, float, array], /) -> array: .. versionchanged:: 2022.12 Added complex data type support. """ + ... - def __xor__(self: array, other: Union[int, bool, array], /) -> array: + def __xor__(self: array, other: int | bool | array, /) -> array: """ Evaluates ``self_i ^ other_i`` for each element of an array instance with the respective element of the array ``other``. @@ -1033,9 +1073,10 @@ def __xor__(self: array, other: Union[int, bool, array], /) -> array: .. note:: Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.bitwise_xor`. """ + ... def to_device( - self: array, device: Device, /, *, stream: Optional[Union[int, Any]] = None + self: array, device: "Device", /, *, stream: Any | None = None # type: ignore[type-var] ) -> array: """ Copy the array from the device on which it currently resides to the specified ``device``. @@ -1046,8 +1087,8 @@ def to_device( array instance. device: device a ``device`` object (see :ref:`device-support`). - stream: Optional[Union[int, Any]] - stream object to use during copy. In addition to the types supported in :meth:`array.__dlpack__`, implementations may choose to support any library-specific stream object with the caveat that any code using such an object would not be portable. + stream: Optional[Any] + stream object to use during copy. In addition to the types supported in :meth:`Array.__dlpack__`, implementations may choose to support any library-specific stream object with the caveat that any code using such an object would not be portable. Returns ------- @@ -1058,6 +1099,4 @@ def to_device( .. note:: If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation. """ - - -array = _array + ... diff --git a/src/array_api_stubs/_draft/data_types.py b/src/array_api_stubs/_draft/data_types.py index d15f4a9f7..a0615bfeb 100644 --- a/src/array_api_stubs/_draft/data_types.py +++ b/src/array_api_stubs/_draft/data_types.py @@ -1,22 +1,26 @@ -__all__ = ["__eq__"] +from __future__ import annotations +__all__ = ["DType"] -from ._types import dtype +from typing import Protocol -def __eq__(self: dtype, other: dtype, /) -> bool: - """ - Computes the truth value of ``self == other`` in order to test for data type object equality. - Parameters - ---------- - self: dtype - data type instance. May be any supported data type. - other: dtype - other data type instance. May be any supported data type. +class DType(Protocol): + def __eq__(self, other: DType, /) -> bool: + """ + Computes the truth value of ``self == other`` in order to test for data type object equality. - Returns - ------- - out: bool - a boolean indicating whether the data type objects are equal. - """ + Parameters + ---------- + self: dtype + data type instance. May be any supported data type. + other: dtype + other data type instance. May be any supported data type. + + Returns + ------- + out: bool + a boolean indicating whether the data type objects are equal. + """ + ... diff --git a/src/array_api_stubs/_draft/linalg.py b/src/array_api_stubs/_draft/linalg.py index d05b53a9f..8302a5b87 100644 --- a/src/array_api_stubs/_draft/linalg.py +++ b/src/array_api_stubs/_draft/linalg.py @@ -299,7 +299,7 @@ def matrix_norm( /, *, keepdims: bool = False, - ord: Optional[Union[int, float, Literal[inf, -inf, "fro", "nuc"]]] = "fro", + ord: Optional[Union[int, float, Literal[inf, -inf, "fro", "nuc"]]] = "fro", # type: ignore ) -> array: """ Computes the matrix norm of a matrix (or a stack of matrices) ``x``. @@ -793,7 +793,7 @@ def vector_norm( *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - ord: Union[int, float, Literal[inf, -inf]] = 2, + ord: Union[int, float, Literal[inf, -inf]] = 2, # type: ignore ) -> array: r""" Computes the vector norm of a vector (or batch of vectors) ``x``.