Skip to content

Commit

Permalink
Allows booleans as optimize, closes #219
Browse files Browse the repository at this point in the history
  • Loading branch information
dgasmith committed Sep 6, 2024
1 parent 4fed3e0 commit 5d6d6aa
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
13 changes: 8 additions & 5 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def contract_path(
#> 5 defg,hd->efgh efgh->efgh
```
"""
if optimize is True:
if (optimize is True) or (optimize is None):
optimize = "auto"

# Hidden option, only einsum should call this
Expand Down Expand Up @@ -341,9 +341,11 @@ def contract_path(
naive_cost = helpers.flop_count(indices, inner_product, num_ops, size_dict)

# Compute the path
if not isinstance(optimize, (str, paths.PathOptimizer)):
if optimize is False:
path_tuple: PathType = [tuple(range(num_ops))]
elif not isinstance(optimize, (str, paths.PathOptimizer)):
# Custom path supplied
path_tuple: PathType = optimize # type: ignore
path_tuple = optimize # type: ignore
elif num_ops <= 2:
# Nothing to be optimized
path_tuple = [tuple(range(num_ops))]
Expand Down Expand Up @@ -536,11 +538,12 @@ def contract(
- `'branch-2'` An even more restricted version of 'branch-all' that
only searches the best two options at each step. Scales exponentially
with the number of terms in the contraction.
- `'auto'` Choose the best of the above algorithms whilst aiming to
- `'auto', None, True` Choose the best of the above algorithms whilst aiming to
keep the path finding time below 1ms.
- `'auto-hq'` Aim for a high quality contraction, choosing the best
of the above algorithms whilst aiming to keep the path finding time
below 1sec.
- `False` will not optimize the contraction.
memory_limit:- Give the upper bound of the largest intermediate tensor contract will build.
- None or -1 means there is no limit.
Expand Down Expand Up @@ -571,7 +574,7 @@ def contract(
performed optimally. When NumPy is linked to a threaded BLAS, potential
speedups are on the order of 20-100 for a six core machine.
"""
if optimize is True:
if (optimize is True) or (optimize is None):
optimize = "auto"

operands_list = [subscripts] + list(operands)
Expand Down
13 changes: 13 additions & 0 deletions opt_einsum/tests/test_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# NumPy is required for the majority of this file
np = pytest.importorskip("numpy")


tests = [
# Test scalar-like operations
"a,->a",
Expand Down Expand Up @@ -99,6 +100,18 @@
]


@pytest.mark.parametrize("optimize", (True, False, None))
def test_contract_plain_types(optimize: OptimizeKind) -> None:
expr = "ij,jk,kl->il"
ops = [np.random.rand(2, 2), np.random.rand(2, 2), np.random.rand(2, 2)]

path = contract_path(expr, *ops, optimize=optimize)
assert len(path) == 2

result = contract(expr, *ops, optimize=optimize)
assert result.shape == (2, 2)


@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare(optimize: OptimizeKind, string: str) -> None:
Expand Down

0 comments on commit 5d6d6aa

Please sign in to comment.