Update JAX fixed-step solver templates to be compilable w.r.t. t_span and t_eval args #126

Open
DanPuzzuoli opened this issue Sep 1, 2022 · 0 comments

Comments

@DanPuzzuoli
Collaborator

Summary

Make the fixed step solvers compilable/differentiable with respect to the t_span and t_eval arguments.

This isn't extremely urgent, but it would still be very nice to round out the features of these solvers.

Details

Issue #122 outlines a bug that is ultimately due to the fact that the JAX solvers in dynamics cannot be compiled if t_eval is not None. The fix PR, #125, resolves this issue by updating jax_odeint and the diffrax solver wrapper so that they can be compiled with respect to t_eval.

As described in #122, however, updating the fixed step JAX solvers built into dynamics to be compilable with respect to both t_span and t_eval is non-trivial, because their looping structure depends on the values of t_span and t_eval. As a result, the fix in #125 is only partial: for the JAX fixed step solvers, the problem is avoided rather than fundamentally fixed.

To make the fixed step solvers fully compilable/differentiable with respect to the t_span and t_eval arguments, the functions fixed_step_solver_template_jax and fixed_step_lmde_solver_parallel_template_jax need to be updated to use more advanced JAX control flow. Preserving differentiability with respect to other parameters may also require defining custom differentiation rules (vjp and jvp rules).
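
To illustrate the kind of control flow involved, here is a minimal sketch and not the existing templates' implementation. It assumes the number of steps is chosen statically (the hypothetical `n_steps` argument) and the step size is derived from t_span, so the loop length no longer depends on traced values. The names `rk4_step` and `fixed_step_solve` are made up for this example, and RK4 stands in for whatever step rule a template would take.

```python
import jax
import jax.numpy as jnp


def rk4_step(rhs, t, y, dt):
    """One classical RK4 step (stand-in for a generic fixed-step rule)."""
    k1 = rhs(t, y)
    k2 = rhs(t + dt / 2, y + dt * k1 / 2)
    k3 = rhs(t + dt / 2, y + dt * k2 / 2)
    k4 = rhs(t + dt, y + dt * k3)
    return y + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6


def fixed_step_solve(rhs, t_span, y0, n_steps):
    """Hypothetical solver: n_steps is static, t_span may be traced."""
    t0, t1 = t_span[0], t_span[1]
    # Step size is a traced value; the loop length is not.
    dt = (t1 - t0) / n_steps

    def body(y, k):
        return rk4_step(rhs, t0 + k * dt, y, dt), None

    # lax.scan has a fixed trip count, so it supports jit and reverse-mode AD.
    y_final, _ = jax.lax.scan(body, y0, jnp.arange(n_steps))
    return y_final


rhs = lambda t, y: -y  # test problem: y' = -y, so y(t) = exp(-t)
y0 = jnp.array(1.0, dtype=jnp.float32)

# Compiled with t_span as a traced argument.
solve = jax.jit(lambda t_span: fixed_step_solve(rhs, t_span, y0, 100))
y1 = solve(jnp.array([0.0, 1.0]))

# Differentiated with respect to the end time t1.
dy_dt1 = jax.grad(lambda t1: fixed_step_solve(rhs, (0.0, t1), y0, 100))(1.0)
```

The trade-off in this sketch is that the step size is no longer bounded by a maximum value, since a wider t_span simply stretches each step. Keeping a bound on the step size while remaining compilable would need something like `jax.lax.while_loop`, which does not support reverse-mode differentiation; that is presumably where custom vjp/jvp rules would come in.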
