Skip to content

Commit

Permalink
Merge pull request #2429 from devitocodes/fd-corner
Browse files Browse the repository at this point in the history
api: Make Derivative Reconstructable
  • Loading branch information
mloubout committed Jul 26, 2024
2 parents cd6ffbb + f072f51 commit 4ef520b
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 31 deletions.
49 changes: 23 additions & 26 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from .differentiable import Differentiable
from .tools import direct, transpose
from .rsfd import d45
from devito.tools import as_mapper, as_tuple, filter_ordered, frozendict, is_integer
from devito.tools import (as_mapper, as_tuple, filter_ordered, frozendict, is_integer,
Reconstructable)
from devito.types.utils import DimensionTuple

__all__ = ['Derivative']


class Derivative(sympy.Derivative, Differentiable):
class Derivative(sympy.Derivative, Differentiable, Reconstructable):

"""
An unevaluated Derivative, which carries metadata (Dimensions,
Expand Down Expand Up @@ -86,7 +87,7 @@ class Derivative(sympy.Derivative, Differentiable):

_fd_priority = 3

__rargs__ = ('expr', 'dims')
__rargs__ = ('expr', '*dims')
__rkwargs__ = ('side', 'deriv_order', 'fd_order', 'transpose', '_ppsubs',
'x0', 'method')

Expand Down Expand Up @@ -201,7 +202,7 @@ def _process_x0(cls, dims, **kwargs):
# Only given a value
_x0 = kwargs.get('x0')
assert len(dims) == 1 or _x0 is None
if _x0 is not None:
if _x0 is not None and _x0 is not dims[0]:
x0 = frozendict({dims[0]: _x0})
else:
x0 = frozendict({})
Expand All @@ -215,8 +216,7 @@ def __call__(self, x0=None, fd_order=None, side=None, method=None):
fd_order = fd_order or self._fd_order
side = side or self._side
method = method or self._method
return self._new_from_self(fd_order=fd_order, side=side, x0=_x0,
method=method)
return self._rebuild(fd_order=fd_order, side=side, x0=_x0, method=method)

if side is not None:
raise TypeError("Side only supported for first order single"
Expand All @@ -230,18 +230,13 @@ def __call__(self, x0=None, fd_order=None, side=None, method=None):
except AttributeError:
raise TypeError("Multi-dimensional Derivative, input expected as a dict")

return self._new_from_self(fd_order=_fd_order, x0=_x0)
return self._rebuild(fd_order=_fd_order, x0=_x0)

def _new_from_self(self, **kwargs):
expr = kwargs.pop('expr', self.expr)
_kwargs = {'deriv_order': self.deriv_order, 'fd_order': self.fd_order,
'side': self.side, 'transpose': self.transpose, 'subs': self._ppsubs,
'x0': self.x0, 'preprocessed': True, 'method': self.method}
_kwargs.update(**kwargs)
return Derivative(expr, *self.dims, **_kwargs)
def _rebuild(self, *args, **kwargs):
kwargs['preprocessed'] = True
return super()._rebuild(*args, **kwargs)

def func(self, expr, *args, **kwargs):
return self._new_from_self(expr=expr, **kwargs)
func = _rebuild

def _subs(self, old, new, **hints):
# Basic case
Expand All @@ -251,7 +246,7 @@ def _subs(self, old, new, **hints):
if self.expr.has(old):
newexpr = self.expr._subs(old, new, **hints)
try:
return self._new_from_self(expr=newexpr)
return self._rebuild(expr=newexpr)
except ValueError:
# Expr replacement leads to non-differentiable expression
# e.g `f.dx.subs(f: 1) = 1.dx = 0`
Expand All @@ -260,7 +255,7 @@ def _subs(self, old, new, **hints):

# In case `x0` was passed as a substitution instead of `(x0=`
if str(old) == 'x0':
return self._new_from_self(x0={self.dims[0]: new})
return self._rebuild(x0={self.dims[0]: new})

# Trying to substitute by another derivative with different metadata
# Only need to check if is a Derivative since one for the cases above would
Expand Down Expand Up @@ -289,13 +284,11 @@ def _xreplace(self, subs):
return new, True

subs = self._ppsubs + (subs,) # Postponed substitutions
return self._new_from_self(subs=subs), True
return self._rebuild(subs=subs), True

@cached_property
def _metadata(self):
state = list(self.__rargs__ + self.__rkwargs__)
state.remove('expr')
ret = [getattr(self, i) for i in state]
ret = [self.dims] + [getattr(self, i) for i in self.__rkwargs__]
ret.append(self.expr.staggered or (None,))
return tuple(ret)

Expand Down Expand Up @@ -348,7 +341,7 @@ def T(self):
else:
adjoint = direct

return self._new_from_self(transpose=adjoint)
return self._rebuild(transpose=adjoint)

def _eval_at(self, func):
"""
Expand All @@ -360,6 +353,10 @@ def _eval_at(self, func):
# do not overwrite it
if self.x0 or self.side is not None or func.function is self.expr.function:
return self
# For basic equation of the form f = Derivative(g, ...) we can just
# compare staggering
if self.expr.staggered == func.staggered:
return self

x0 = func.indices_ref._getters
if self.expr.is_Add:
Expand All @@ -370,19 +367,19 @@ def _eval_at(self, func):
mapper = as_mapper(self.expr._args_diff, lambda i: i.staggered)
args = [self.expr.func(*v) for v in mapper.values()]
args.extend([a for a in self.expr.args if a not in self.expr._args_diff])
args = [self._new_from_self(expr=a, x0=x0) for a in args]
args = [self._rebuild(expr=a, x0=x0) for a in args]
return self.expr.func(*args)
elif self.expr.is_Mul:
# For Mul, We treat the basic case `u(x + h_x/2) * v(x) which is what appear
# in most equation with div(a * u) for example. The expression is re-centered
# at the highest priority index (see _gather_for_diff) to compute the
# derivative at x0.
return self._new_from_self(x0=x0, expr=self.expr._gather_for_diff)
return self._rebuild(expr=self.expr._gather_for_diff, x0=x0)
else:
# For every other cases, that has more functions or more complexe arithmetic,
# there is not actual way to decide what to do so it’s as safe to use
# the expression as is.
return self._new_from_self(x0=x0)
return self._rebuild(x0=x0)

def _evaluate(self, **kwargs):
# Evaluate finite-difference.
Expand Down
16 changes: 12 additions & 4 deletions devito/passes/equations/linearity.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _(expr, mapper, nn_derivs=None):
@aggregate_coeffs.register(sympy.Derivative)
def _(expr, mapper, nn_derivs=None):
# Opens up a new derivative scope, so do not propagate `nn_derivs`
args = [aggregate_coeffs(a, mapper) for a in expr.args]
args = [aggregate_coeffs(expr.expr, mapper)]
expr = reuse_if_untouched(expr, args)

return expr
Expand Down Expand Up @@ -164,10 +164,10 @@ def _(expr, mapper, nn_derivs=None):
return expr

if len(derivs) == 1 and with_deriv is derivs[0]:
expr = with_deriv._new_from_self(expr=expr.func(*hope_coeffs, with_deriv.expr))
expr = with_deriv._rebuild(expr=expr.func(*hope_coeffs, with_deriv.expr))
else:
others = [expr.func(*hope_coeffs, a) for a in others]
derivs = [a._new_from_self(expr=expr.func(*hope_coeffs, a.expr)) for a in derivs]
derivs = [a._rebuild(expr=expr.func(*hope_coeffs, a.expr)) for a in derivs]
expr = with_deriv.func(*(derivs + others))

return expr
Expand All @@ -190,6 +190,14 @@ def _(expr):
return expr


@factorize_derivatives.register(sympy.Derivative)
def _(expr):
args = [factorize_derivatives(expr.expr)]
expr = reuse_if_untouched(expr, args)

return expr


@factorize_derivatives.register(sympy.Add)
def _(expr):
args = [factorize_derivatives(a) for a in expr.args]
Expand All @@ -216,7 +224,7 @@ def _(expr):
if len(v) == 1:
args.append(c)
else:
args.append(c._new_from_self(expr=expr.func(*[i.expr for i in v])))
args.append(c._rebuild(expr=expr.func(*[i.expr for i in v])))
expr = expr.func(*args)

return expr
2 changes: 1 addition & 1 deletion devito/tools/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def __init__(self, a, b, c=4):

kwargs.update({i: getattr(self, i) for i in self.__rkwargs__ if i not in kwargs})

# Should we use a constum reconstructor?
# Should we use a custom reconstructor?
try:
cls = self._rcls
except AttributeError:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,16 @@ def test_deriv_spec(self):
assert dxy0.x0 == {y: y+y.spacing/2}
assert dxy02.x0 == {x: x+x.spacing/2}

def test_deriv_stagg_plain(self):
grid = Grid((11, 11))
x, y = grid.dimensions
f1 = Function(name="f1", grid=grid, space_order=2, staggered=NODE)
f2 = Function(name="f2", grid=grid, space_order=2, staggered=NODE)

eq0 = Eq(f1, f2.laplace).evaluate
assert eq0.rhs == f2.laplace.evaluate
assert eq0.rhs != 0


class TestTwoStageEvaluation:

Expand Down

0 comments on commit 4ef520b

Please sign in to comment.