Skip to content

Commit

Permalink
Removes explicit kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
dgasmith committed Sep 6, 2024
1 parent 1992f4a commit 0c0911c
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 51 deletions.
53 changes: 11 additions & 42 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@

## Common types

_OrderKACF = Literal[None, "K", "A", "C", "F"]

_Casting = Literal["no", "equiv", "safe", "same_kind", "unsafe"]
_MemoryLimit = Union[None, int, Decimal, Literal["max_input"]]


Expand Down Expand Up @@ -479,9 +476,6 @@ def contract(
subscripts: str,
*operands: ArrayType,
out: ArrayType = ...,
dtype: Any = ...,
order: _OrderKACF = ...,
casting: _Casting = ...,
use_blas: bool = ...,
optimize: OptimizeKind = ...,
memory_limit: _MemoryLimit = ...,
Expand All @@ -495,9 +489,6 @@ def contract(
subscripts: ArrayType,
*operands: Union[ArrayType, Collection[int]],
out: ArrayType = ...,
dtype: Any = ...,
order: _OrderKACF = ...,
casting: _Casting = ...,
use_blas: bool = ...,
optimize: OptimizeKind = ...,
memory_limit: _MemoryLimit = ...,
Expand All @@ -510,9 +501,6 @@ def contract(
subscripts: Union[str, ArrayType],
*operands: Union[ArrayType, Collection[int]],
out: Optional[ArrayType] = None,
dtype: Optional[str] = None,
order: _OrderKACF = "K",
casting: _Casting = "safe",
use_blas: bool = True,
optimize: OptimizeKind = True,
memory_limit: _MemoryLimit = None,
Expand All @@ -527,9 +515,6 @@ def contract(
subscripts: Specifies the subscripts for summation.
*operands: These are the arrays for the operation.
out: A output array in which set the resulting output.
dtype: The dtype of the given contraction, see np.einsum.
order: The order of the resulting contraction, see np.einsum.
casting: The casting procedure for operations of different dtype, see np.einsum.
use_blas: Do you use BLAS for valid operations, may use extra memory for more intermediates.
optimize:- Choose the type of path the contraction will be optimized with
- if a list is given uses this as the path.
Expand Down Expand Up @@ -590,17 +575,14 @@ def contract(
optimize = "auto"

operands_list = [subscripts] + list(operands)
einsum_kwargs = {"out": out, "dtype": dtype, "order": order, "casting": casting}

# If no optimization, run pure einsum
if optimize is False:
return _einsum(*operands_list, **einsum_kwargs)
return _einsum(*operands_list, out=out, **kwargs)

# Grab non-einsum kwargs
gen_expression = kwargs.pop("_gen_expression", False)
constants_dict = kwargs.pop("_constants_dict", {})
if len(kwargs):
raise TypeError(f"Did not understand the following kwargs: {kwargs.keys()}")

if gen_expression:
full_str = operands_list[0]
Expand All @@ -613,11 +595,9 @@ def contract(

# check if performing contraction or just building expression
if gen_expression:
return ContractExpression(full_str, contraction_list, constants_dict, dtype=dtype, order=order, casting=casting)
return ContractExpression(full_str, contraction_list, constants_dict, **kwargs)

return _core_contract(
operands, contraction_list, backend=backend, out=out, dtype=dtype, order=order, casting=casting
)
return _core_contract(operands, contraction_list, backend=backend, out=out, **kwargs)


@lru_cache(None)
Expand Down Expand Up @@ -651,9 +631,7 @@ def _core_contract(
backend: Optional[str] = "auto",
evaluate_constants: bool = False,
out: Optional[ArrayType] = None,
dtype: Optional[str] = None,
order: _OrderKACF = "K",
casting: _Casting = "safe",
**kwargs: Any,
) -> ArrayType:
"""Inner loop used to perform an actual contraction given the output
from a ``contract_path(..., einsum_call=True)`` call.
Expand Down Expand Up @@ -703,7 +681,7 @@ def _core_contract(
axes = ((), ())

# Contract!
new_view = _tensordot(*tmp_operands, axes=axes, backend=backend)
new_view = _tensordot(*tmp_operands, axes=axes, backend=backend, **kwargs)

# Build a new view if needed
if (tensor_result != results_index) or handle_out:
Expand All @@ -718,9 +696,7 @@ def _core_contract(
out_kwarg: Union[None, ArrayType] = None
if handle_out:
out_kwarg = out
new_view = _einsum(
einsum_str, *tmp_operands, backend=backend, dtype=dtype, order=order, casting=casting, out=out_kwarg
)
new_view = _einsum(einsum_str, *tmp_operands, backend=backend, out=out_kwarg, **kwargs)

# Append new items and dereference what we can
operands.append(new_view)
Expand Down Expand Up @@ -768,15 +744,11 @@ def __init__(
contraction: str,
contraction_list: ContractionListType,
constants_dict: Dict[int, ArrayType],
dtype: Optional[str] = None,
order: _OrderKACF = "K",
casting: _Casting = "safe",
**kwargs: Any,
):
self.contraction_list = contraction_list
self.dtype = dtype
self.order = order
self.casting = casting
self.contraction = format_const_einsum_str(contraction, constants_dict.keys())
self.contraction_list = contraction_list
self.kwargs = kwargs

# need to know _full_num_args to parse constants with, and num_args to call with
self._full_num_args = contraction.count(",") + 1
Expand Down Expand Up @@ -844,9 +816,7 @@ def _contract(
out=out,
backend=backend,
evaluate_constants=evaluate_constants,
dtype=self.dtype,
order=self.order,
casting=self.casting,
**self.kwargs,
)

def _contract_with_conversion(
Expand Down Expand Up @@ -943,8 +913,7 @@ def __str__(self) -> str:
for i, c in enumerate(self.contraction_list):
s.append(f"\n {i + 1}. ")
s.append(f"'{c[2]}'" + (f" [{c[-1]}]" if c[-1] else ""))
kwargs = {"dtype": self.dtype, "order": self.order, "casting": self.casting}
s.append(f"\neinsum_kwargs={kwargs}")
s.append(f"\neinsum_kwargs={self.kwargs}")
return "".join(s)


Expand Down
9 changes: 1 addition & 8 deletions opt_einsum/tests/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,16 +163,9 @@ def test_value_errors(contract_fn: Any) -> None:
# broadcasting to new dimensions must be enabled explicitly
with pytest.raises(ValueError):
contract_fn("i", np.arange(6).reshape(2, 3))
if contract_fn is contract:
# contract_path does not have an `out` parameter
with pytest.raises(ValueError):
contract_fn("i->i", [[0, 1], [0, 1]], out=np.arange(4).reshape(2, 2))

with pytest.raises(TypeError):
contract_fn("i->i", [[0, 1], [0, 1]], bad_kwarg=True)

with pytest.raises(ValueError):
contract_fn("i->i", [[0, 1], [0, 1]], memory_limit=-1)
contract_fn("ij->ij", [[0, 1], [0, 1]], bad_kwarg=True)


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion opt_einsum/tests/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_flop_cost() -> None:


def test_bad_path_option() -> None:
with pytest.raises(TypeError):
with pytest.raises(KeyError):
oe.contract("a,b,c", [1], [2], [3], optimize="optimall", shapes=True) # type: ignore


Expand Down

0 comments on commit 0c0911c

Please sign in to comment.