Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dsl: Fix missing sympy assumptions during rebuilding #2436

Merged
merged 3 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions devito/tools/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ def __init__(self, a, b, c=4):

kwargs.update({i: getattr(self, i) for i in self.__rkwargs__ if i not in kwargs})

# If this object has SymPy assumptions associated with it, which were not
# in the kwargs, then include them
try:
assumptions = self._assumptions_orig
kwargs.update({k: v for k, v in assumptions.items() if k not in kwargs})
except AttributeError:
pass

# Should we use a custom reconstructor?
try:
cls = self._rcls
Expand Down
5 changes: 3 additions & 2 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,11 @@ class AbstractSymbol(sympy.Symbol, Basic, Pickable, Evaluable):
def _filter_assumptions(cls, **kwargs):
"""Extract sympy.Symbol-specific kwargs."""
assumptions = {}
# pop predefined assumptions
# Pop predefined assumptions
for key in ('real', 'imaginary', 'commutative'):
kwargs.pop(key, None)
# extract sympy.Symbol-specific kwargs

# Extract sympy.Symbol-specific kwargs
for i in list(kwargs):
if i in _assume_rules.defined_facts:
assumptions[i] = kwargs.pop(i)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
TensorFunction, TensorTimeFunction, VectorTimeFunction)
from devito.types import (DeviceID, NThreadsBase, NPThreads, Object, LocalObject,
Scalar, Symbol, ThreadID)
from devito.types.basic import AbstractSymbol


@pytest.fixture
Expand Down Expand Up @@ -44,6 +45,16 @@ class TestHashing:
Test hashing of symbolic objects.
"""

def test_abstractsymbol(self):
"""Test that different Symbols have different hash values."""
s0 = AbstractSymbol('s')
s1 = AbstractSymbol('s')
assert s0 is not s1
assert hash(s0) == hash(s1)

s2 = AbstractSymbol('s', nonnegative=True)
assert hash(s0) != hash(s2)

def test_constant(self):
"""Test that different Constants have different hash value."""
c0 = Constant(name='c')
Expand Down
18 changes: 17 additions & 1 deletion tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
PointerArray, Lock, PThreadArray, SharedData, Timer,
DeviceID, NPThreads, ThreadID, TempFunction, Indirection,
FIndexed)
from devito.types.basic import BoundSymbol
from devito.types.basic import BoundSymbol, AbstractSymbol
from devito.tools import EnrichedTuple
from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer,
CallFromPointer, DefFunction)
Expand All @@ -29,6 +29,22 @@
@pytest.mark.parametrize('pickle', [pickle0, pickle1])
class TestBasic:

def test_abstractsymbol(self, pickle):
s0 = AbstractSymbol('s')
s1 = AbstractSymbol('s', nonnegative=True, integer=False)

pkl_s0 = pickle.dumps(s0)
pkl_s1 = pickle.dumps(s1)

new_s0 = pickle.loads(pkl_s0)
new_s1 = pickle.loads(pkl_s1)

assert s0.assumptions0 == new_s0.assumptions0
assert s1.assumptions0 == new_s1.assumptions0

assert s0 == new_s0
assert s1 == new_s1

def test_constant(self, pickle):
c = Constant(name='c')
assert c.data == 0.
Expand Down
41 changes: 41 additions & 0 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from devito.tools import as_tuple
from devito.types import (Array, Bundle, FIndexed, LocalObject, Object,
Symbol as dSymbol)
from devito.types.basic import AbstractSymbol


def test_float_indices():
Expand Down Expand Up @@ -70,6 +71,46 @@ def test_floatification_issue_1627(dtype, expected):
assert str(exprs[0]) == expected


def test_sympy_assumptions():
"""
Ensure that AbstractSymbol assumptions are set correctly and
preserved during rebuild.
"""
s0 = AbstractSymbol('s')
s1 = AbstractSymbol('s', nonnegative=True, integer=False, real=True)

assert s0.is_negative is None
assert s0.is_positive is None
assert s0.is_integer is None
assert s0.is_real is True
assert s1.is_negative is False
assert s1.is_positive is True
assert s1.is_integer is False
assert s1.is_real is True

s0r = s0._rebuild()
s1r = s1._rebuild()

assert s0.assumptions0 == s0r.assumptions0
assert s0 == s0r

assert s1.assumptions0 == s1r.assumptions0
assert s1 == s1r


def test_modified_sympy_assumptions():
"""
Check that sympy assumptions can be changed during a rebuild.
"""
s0 = AbstractSymbol('s')
s1 = AbstractSymbol('s', nonnegative=True, integer=False, real=True)

s2 = s0._rebuild(nonnegative=True, integer=False, real=True)

assert s2.assumptions0 == s1.assumptions0
assert s2 == s1
EdCaunt marked this conversation as resolved.
Show resolved Hide resolved


def test_constant():
c = Constant(name='c')

Expand Down
Loading