allow sharing for tensorflow and theano as numpy backends (#58)
* allow sharing for tensorflow and theano as numpy backends

* fix: constants wasn't propagating into cache call
jcmgray authored and dgasmith committed Aug 29, 2018
1 parent 3d72bba commit 1117f09
Showing 5 changed files with 84 additions and 10 deletions.
docs/source/index.rst (4 additions, 1 deletion)
@@ -7,7 +7,10 @@ dimension and index. However, it is only optimized to contract two terms
at a time resulting in non-optimal scaling.

For example, consider the following index transformation:
``M_{pqrs} = C_{pi} C_{qj} I_{ijkl} C_{rk} C_{sl}``

.. math::
    M_{pqrs} = C_{pi} C_{qj} I_{ijkl} C_{rk} C_{sl}

Consider two different algorithms:

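As an aside (not part of the diff), the transformation above can be written as a single einsum expression; a minimal sketch with hypothetical dimensions, contrasting plain numpy with opt_einsum's optimized path finding:

import numpy as np
from opt_einsum import contract

d = 10
I = np.random.rand(d, d, d, d)
C = np.random.rand(d, d)

# plain einsum: contracts all five terms at once, with poor scaling
M_np = np.einsum('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C)

# opt_einsum: picks an efficient pairwise contraction order
M_oe = contract('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C)

assert np.allclose(M_np, M_oe)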
opt_einsum/backends/tensorflow.py (4 additions, 0 deletions)
@@ -6,8 +6,11 @@

import numpy as np

from ..sharing import to_backend_cache_wrap

__all__ = ["to_tensorflow", "build_expression", "evaluate_constants"]


_CACHED_TF_DEVICE = None


@@ -31,6 +34,7 @@ def _get_tensorflow_and_device():
    return _CACHED_TF_DEVICE


@to_backend_cache_wrap(constants=True)
def to_tensorflow(array, constant=False):
"""Convert a numpy array to a ``tensorflow.placeholder`` instance.
"""
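As an illustration (not from the diff), wrapping to_tensorflow with to_backend_cache_wrap(constants=True) means that, inside a shared_intermediates context, conversions are memoized by (function name, id(array), constant) rather than repeated. A rough usage sketch, with hypothetical shapes, assuming an active default tf.Session as in the tests below:

import numpy as np
import tensorflow as tf
from opt_einsum import contract_expression, sharing

C = np.random.rand(5, 5)                                   # constant operand
expr = contract_expression('ab,bc->ac', (5, 5), C, constants=[1])

with tf.Session().as_default(), sharing.shared_intermediates():
    x = np.random.rand(5, 5)
    y1 = expr(x, backend='tensorflow')   # converts x and caches the conversion
    y2 = expr(x, backend='tensorflow')   # same id(x): cache hit, no re-conversion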
opt_einsum/backends/theano.py (3 additions, 0 deletions)
@@ -6,9 +6,12 @@

import numpy as np

from ..sharing import to_backend_cache_wrap

__all__ = ["to_theano", "build_expression", "evaluate_constants"]


@to_backend_cache_wrap(constants=True)
def to_theano(array, constant=False):
"""Convert a numpy array to ``theano.tensor.TensorType`` instance.
"""
opt_einsum/sharing.py (24 additions, 8 deletions)
@@ -180,18 +180,34 @@ def cached_einsum(*args, **kwargs):
    return cached_einsum


def to_backend_cache_wrap(to_backend):
def to_backend_cache_wrap(to_backend=None, constants=False):
"""Decorates an ``to_backend()`` implementation to be memoized inside a
:func:`shared_intermediates` context (e.g. ``to_cupy``, ``to_torch``).
"""
# manage the case that decorator is called with args
if to_backend is None:
return functools.partial(to_backend_cache_wrap, constants=constants)

    @functools.wraps(to_backend)
    def cached_to_backend(array):
        if not currently_sharing():
            return to_backend(array)
    if constants:

        @functools.wraps(to_backend)
        def cached_to_backend(array, constant=False):
            if not currently_sharing():
                return to_backend(array, constant=constant)

            # hash by id
            key = to_backend.__name__, id(array), constant
            return _memoize(key, to_backend, array, constant=constant)

    else:

        @functools.wraps(to_backend)
        def cached_to_backend(array):
            if not currently_sharing():
                return to_backend(array)

            # hash by id
            key = to_backend.__name__, id(array)
            return _memoize(key, to_backend, array)
        # hash by id
        key = to_backend.__name__, id(array)
        return _memoize(key, to_backend, array)

    return cached_to_backend
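For context (not part of the diff), the new `if to_backend is None` branch is the standard trick for letting a decorator be applied either bare or with keyword arguments; a minimal self-contained sketch of the pattern, with hypothetical names:

import functools

def cache_wrap(fn=None, constants=False):
    # called with arguments: fn is None, so re-invoke with the options baked in
    if fn is None:
        return functools.partial(cache_wrap, constants=constants)

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapped

@cache_wrap                     # bare form: fn is the decorated function
def f(x):
    return x

@cache_wrap(constants=True)     # parameterized form, as used for to_tensorflow/to_theano
def g(x, constant=False):
    return x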
opt_einsum/tests/test_backends.py (49 additions, 1 deletion)
@@ -1,7 +1,7 @@
import numpy as np
import pytest

from opt_einsum import contract, helpers, contract_expression, backends
from opt_einsum import contract, helpers, contract_expression, backends, sharing

try:
import tensorflow as tf
@@ -90,6 +90,31 @@ def test_tensorflow_with_constants():
    assert isinstance(res_got3, tf.Tensor)


@pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.")
@pytest.mark.parametrize("string", tests)
def test_tensorflow_with_sharing(string):
    views = helpers.build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)

    shps = [v.shape for v in views]
    expr = contract_expression(string, *shps, optimize=True)

    sess = tf.Session(config=_TF_CONFIG)

    with sess.as_default(), sharing.shared_intermediates() as cache:
        tfl1 = expr(*views, backend='tensorflow')
        assert sharing.get_sharing_cache() is cache
        cache_sz = len(cache)
        assert cache_sz > 0
        tfl2 = expr(*views, backend='tensorflow')
        assert len(cache) == cache_sz

    assert all(isinstance(t, tf.Tensor) for t in cache.values())

    assert np.allclose(ein, tfl1)
    assert np.allclose(ein, tfl2)


@pytest.mark.skipif(not found_theano, reason="Theano not installed.")
@pytest.mark.parametrize("string", tests)
def test_theano(string):
@@ -134,6 +159,29 @@ def test_theano_with_constants():
    assert isinstance(res_got3, theano.tensor.TensorVariable)


@pytest.mark.skipif(not found_theano, reason="Theano not installed.")
@pytest.mark.parametrize("string", tests)
def test_theano_with_sharing(string):
    views = helpers.build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)

    shps = [v.shape for v in views]
    expr = contract_expression(string, *shps, optimize=True)

    with sharing.shared_intermediates() as cache:
        thn1 = expr(*views, backend='theano')
        assert sharing.get_sharing_cache() is cache
        cache_sz = len(cache)
        assert cache_sz > 0
        thn2 = expr(*views, backend='theano')
        assert len(cache) == cache_sz

    assert all(isinstance(t, theano.tensor.TensorVariable) for t in cache.values())

    assert np.allclose(ein, thn1)
    assert np.allclose(ein, thn2)


@pytest.mark.parametrize("string", tests)
def test_cupy(string): # pragma: no cover
    cupy = pytest.importorskip("cupy")

