-
Notifications
You must be signed in to change notification settings - Fork 225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
API: revamp cross derivative shortcuts #2458
base: master
Are you sure you want to change the base?
Conversation
@@ -86,8 +86,16 @@ def generate_fd_shortcuts(dims, so, to=0): | |||
from devito.finite_differences.derivative import Derivative | |||
|
|||
def diff_f(expr, deriv_order, dims, fd_order, side=None, **kwargs): | |||
return Derivative(expr, *as_tuple(dims), deriv_order=deriv_order, | |||
fd_order=fd_order, side=side, **kwargs) | |||
# Spearate dimension to always have cross derivatives return nested |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@FabioLuporini this basically does u.dxdy -> u.dx.dy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, you could added it to the comment too (e.g., u.dxdy -> u.dx.dy)
@@ -135,8 +136,13 @@ def _lower_exprs(expressions, subs): | |||
if dimension_map: | |||
indices = [j.xreplace(dimension_map) for j in indices] | |||
|
|||
mapper[i] = f.indexed[indices] | |||
# Handle Array |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@FabioLuporini There is probably a cleaner way but don't have time to spend more on this rn
ce1c71e
to
2c51dfe
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #2458 +/- ##
===========================================
- Coverage 87.01% 57.45% -29.56%
===========================================
Files 239 238 -1
Lines 44958 44973 +15
Branches 8390 8398 +8
===========================================
- Hits 39118 25839 -13279
- Misses 5108 18293 +13185
- Partials 732 841 +109 ☔ View full report in Codecov by Sentry. |
447e1e4
to
d54831c
Compare
d54831c
to
c7360e2
Compare
if isinstance(f, Array) and f.initvalue is not None: | ||
initv = [_lower_exprs(i, subs) for i in f.initvalue] | ||
# TODO: fix rebuild to avoid new name | ||
f = f._rebuild(name='%si' % f.name, initvalue=initv) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do you need a new name? does it work if u instead pass function=None
?
mapper[i] = f.indexed[indices] | ||
# Handle Array | ||
if isinstance(f, Array) and f.initvalue is not None: | ||
initv = [_lower_exprs(i, subs) for i in f.initvalue] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can call it 'initvalue'
@@ -222,25 +220,37 @@ def _process_weights(cls, **kwargs): | |||
|
|||
def __call__(self, x0=None, fd_order=None, side=None, method=None, weights=None): | |||
side = side or self._side | |||
method = method or self._method | |||
weights = weights if weights is not None else self._weights | |||
|
|||
x0 = self._process_x0(self.dims, x0=x0) | |||
_x0 = frozendict({**self.x0, **x0}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we actually need the underscore for all these variable names ?
it'd get way less verbose and then easier to read without the initial underscore
raise TypeError("Multi-dimensional Derivative, input expected as a dict") | ||
raise TypeError("fd_order incompatible with dimensions") | ||
|
||
# In case this was called on a cross derivative we need to propagate |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"... on a perfect cross derivative (e.g. u.dxdy
), so we need to ... " ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually this comment belongs to the inside of the if
# In case this was called on a cross derivative we need to propagate | ||
# the call to the nested derivative | ||
if isinstance(self.expr, Derivative): | ||
_fd_orders = {k: v for k, v in _fd_order.items() if k in self.expr.dims} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need small utility functions for these ? I have a feeling, but I may well be wrong, this is not the only place we have to perform this kind of information retrieval
@@ -293,8 +303,12 @@ def _xreplace(self, subs): | |||
except AttributeError: | |||
return new, True | |||
|
|||
# Resolve nested derivatives | |||
dsubs = {k: v for k, v in subs.items() if isinstance(k, Derivative)} | |||
new_expr = self.expr.xreplace(dsubs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nitpicking alert:
expr = ...
would be more homogeneous than new_expr =
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking good, also fixes some of the aspects of Derivative
that have been bugging me
|
||
obj._fd_order = fd_o if skip else DimensionTuple(*fd_o, getters=obj._dims) | ||
obj._deriv_order = orders if skip else DimensionTuple(*orders, getters=obj._dims) | ||
obj._fd_order = DimensionTuple(*as_tuple(fd_o), getters=obj._dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This gets rid of the whole thing where fd_order
could be int or tuple right?
@@ -222,25 +220,37 @@ def _process_weights(cls, **kwargs): | |||
|
|||
def __call__(self, x0=None, fd_order=None, side=None, method=None, weights=None): | |||
side = side or self._side | |||
method = method or self._method |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does the first of these use x or y
whilst the second uses x if x is not None else y
when the default kwval in both cases is None
?
# the call to the nested derivative | ||
if isinstance(self.expr, Derivative): | ||
_fd_orders = {k: v for k, v in _fd_order.items() if k in self.expr.dims} | ||
_x0s = {k: v for k, v in _x0.items() if k in self.expr.dims and |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens here if the user does something like u.dx(x0=x+hx/2).dx(x0=x-hx/2)
? Should this consolidate them to u.dx2(x0=x)
or do we just assume that this isn't actually something anyone is going to do?
prio = lambda x: getattr(x, '_fd_priority', 0) | ||
# We want to get the object with highest priority | ||
# We also need to make sure that the object with the largest | ||
# set of dimensions is used when multiple ones with the same |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: this comment could be made clearer by replacing "ones" with "objects"
@@ -86,8 +86,16 @@ def generate_fd_shortcuts(dims, so, to=0): | |||
from devito.finite_differences.derivative import Derivative | |||
|
|||
def diff_f(expr, deriv_order, dims, fd_order, side=None, **kwargs): | |||
return Derivative(expr, *as_tuple(dims), deriv_order=deriv_order, | |||
fd_order=fd_order, side=side, **kwargs) | |||
# Spearate dimension to always have cross derivatives return nested |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo: "Spearate" -> "Separate", "dimension" -> "dimensions"
grid = Grid((11, 11)) | ||
f = Function(name="f", grid=grid, space_order=2) | ||
|
||
assert f.dxdy == f.dx.dy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would be good to test that chaining derivatives with various x0
leads to sensible consolidation too.
No description provided.