diff --git a/devito/tools/abc.py b/devito/tools/abc.py index b943256979..162b3287d3 100644 --- a/devito/tools/abc.py +++ b/devito/tools/abc.py @@ -143,6 +143,14 @@ def __init__(self, a, b, c=4): kwargs.update({i: getattr(self, i) for i in self.__rkwargs__ if i not in kwargs}) + # If this object has SymPy assumptions associated with it, which were not + # in the kwargs, then include them + try: + assumptions = self._assumptions_orig + kwargs.update({k: v for k, v in assumptions.items() if k not in kwargs}) + except AttributeError: + pass + # Should we use a custom reconstructor? try: cls = self._rcls diff --git a/devito/types/basic.py b/devito/types/basic.py index e2859ea07e..d1dd3dcb93 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -338,10 +338,11 @@ class AbstractSymbol(sympy.Symbol, Basic, Pickable, Evaluable): def _filter_assumptions(cls, **kwargs): """Extract sympy.Symbol-specific kwargs.""" assumptions = {} - # pop predefined assumptions + # Pop predefined assumptions for key in ('real', 'imaginary', 'commutative'): kwargs.pop(key, None) - # extract sympy.Symbol-specific kwargs + + # Extract sympy.Symbol-specific kwargs for i in list(kwargs): if i in _assume_rules.defined_facts: assumptions[i] = kwargs.pop(i) diff --git a/tests/test_caching.py b/tests/test_caching.py index f4346706ea..8dca69fa60 100644 --- a/tests/test_caching.py +++ b/tests/test_caching.py @@ -11,6 +11,7 @@ TensorFunction, TensorTimeFunction, VectorTimeFunction) from devito.types import (DeviceID, NThreadsBase, NPThreads, Object, LocalObject, Scalar, Symbol, ThreadID) +from devito.types.basic import AbstractSymbol @pytest.fixture @@ -44,6 +45,16 @@ class TestHashing: Test hashing of symbolic objects. """ + def test_abstractsymbol(self): + """Test that different Symbols have different hash values.""" + s0 = AbstractSymbol('s') + s1 = AbstractSymbol('s') + assert s0 is not s1 + assert hash(s0) == hash(s1) + + s2 = AbstractSymbol('s', nonnegative=True) + assert hash(s0) != hash(s2) + def test_constant(self): """Test that different Constants have different hash value.""" c0 = Constant(name='c') diff --git a/tests/test_pickle.py b/tests/test_pickle.py index bf1b859a75..bb1ddb4027 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -18,7 +18,7 @@ PointerArray, Lock, PThreadArray, SharedData, Timer, DeviceID, NPThreads, ThreadID, TempFunction, Indirection, FIndexed) -from devito.types.basic import BoundSymbol +from devito.types.basic import BoundSymbol, AbstractSymbol from devito.tools import EnrichedTuple from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer, CallFromPointer, DefFunction) @@ -29,6 +29,22 @@ @pytest.mark.parametrize('pickle', [pickle0, pickle1]) class TestBasic: + def test_abstractsymbol(self, pickle): + s0 = AbstractSymbol('s') + s1 = AbstractSymbol('s', nonnegative=True, integer=False) + + pkl_s0 = pickle.dumps(s0) + pkl_s1 = pickle.dumps(s1) + + new_s0 = pickle.loads(pkl_s0) + new_s1 = pickle.loads(pkl_s1) + + assert s0.assumptions0 == new_s0.assumptions0 + assert s1.assumptions0 == new_s1.assumptions0 + + assert s0 == new_s0 + assert s1 == new_s1 + def test_constant(self, pickle): c = Constant(name='c') assert c.data == 0. diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 9dd0c48584..353fdc934c 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -17,6 +17,7 @@ from devito.tools import as_tuple from devito.types import (Array, Bundle, FIndexed, LocalObject, Object, Symbol as dSymbol) +from devito.types.basic import AbstractSymbol def test_float_indices(): @@ -70,6 +71,46 @@ def test_floatification_issue_1627(dtype, expected): assert str(exprs[0]) == expected +def test_sympy_assumptions(): + """ + Ensure that AbstractSymbol assumptions are set correctly and + preserved during rebuild. + """ + s0 = AbstractSymbol('s') + s1 = AbstractSymbol('s', nonnegative=True, integer=False, real=True) + + assert s0.is_negative is None + assert s0.is_positive is None + assert s0.is_integer is None + assert s0.is_real is True + assert s1.is_negative is False + assert s1.is_positive is True + assert s1.is_integer is False + assert s1.is_real is True + + s0r = s0._rebuild() + s1r = s1._rebuild() + + assert s0.assumptions0 == s0r.assumptions0 + assert s0 == s0r + + assert s1.assumptions0 == s1r.assumptions0 + assert s1 == s1r + + +def test_modified_sympy_assumptions(): + """ + Check that sympy assumptions can be changed during a rebuild. + """ + s0 = AbstractSymbol('s') + s1 = AbstractSymbol('s', nonnegative=True, integer=False, real=True) + + s2 = s0._rebuild(nonnegative=True, integer=False, real=True) + + assert s2.assumptions0 == s1.assumptions0 + assert s2 == s1 + + def test_constant(): c = Constant(name='c')