Enable the use of structured gradients for Zygote #62

Closed
wants to merge 6 commits into from



@Red-Portal Red-Portal commented Jun 6, 2024

This PR restructures the project so that Zygote can use structured gradients without having to flatten everything. Also, the interface now forces the use of `Optimisers.destructure`, unlike the previous version, where we exposed some control over this. Overall, the changes are summarized as follows:

  • Zygote can now use structured gradients directly, without having to flatten the parameters.
  • The optimize interface is simpler, as we no longer expose control over restructure (we will always use `Optimisers.destructure` internally).
  • We do not use DiffResults anymore and just pass around gradients directly. This makes most of the package immutable which is nice, but could impact memory usage/GC time. However, in large-scale problems, Zygote will be the only option, so I think this is okay in the sense of prioritizing Zygote-friendliness.
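For readers unfamiliar with the distinction, here is a minimal sketch of the two gradient routes discussed in the list above. The parameter container `q` and the toy `loss` are hypothetical stand-ins and are not taken from this PR; only `Zygote.gradient` and `Optimisers.destructure` are real library calls.

```julia
using Optimisers, Zygote

# Hypothetical variational parameters; the field names are illustrative only.
q = (location = [0.0, 1.0], scale = [1.0, 2.0])

# Toy objective standing in for a real variational objective (an assumption).
loss(q) = sum(abs2, q.location) + sum(log, q.scale)

# Structured route: Zygote returns a NamedTuple that mirrors the layout of `q`.
g_struct = only(Zygote.gradient(loss, q))

# Flattened route: `destructure` yields a flat vector and a closure `re`
# that rebuilds the original structure from that vector.
θ, re = Optimisers.destructure(q)
g_flat = only(Zygote.gradient(θ -> loss(re(θ)), θ))
```

In the structured route the gradient keeps the field names of `q`, whereas the flattened route yields a plain vector whose entries follow the order in which `destructure` traverses the fields.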

Any concerns/comments would be much appreciated!

Also, sorry that the diff is overlapping with #61 !

- Internally always use `Optimisers.destructure`
- Use structured gradients for Zygote
- Don't use `DiffResults` and just pass around gradients
@Red-Portal Red-Portal requested review from devmotion and torfjelde and removed request for devmotion June 6, 2024 03:11
@torfjelde
Member

> However, in large-scale problems, Zygote will be the only option, so I think this is okay in the sense of prioritizing Zygote-friendliness.

I'm not so certain this is true? There are many "large-scale" cases where ReverseDiff.jl easily beats Zygote.jl. Similarly, we have the "up-and-coming" ADs like Enzyme.jl and Tapir.jl, which both promise further perf gains beyond Zygote.jl. IMO it seems a bit premature to go down the Zygote-friendly route 😕


maybe_destructure(::ADTypes.AutoZygote, q) = (q, identity)

maybe_destructure(::ADTypes.AbstractADType, q) = Optimisers.destructure(q)
The annoyance with using dispatch to do the destructuring is that you now need to define new structs for every type of parameterization of a distribution you want to do.

As in, how do you separate between, say, a MvNormal with a diag and dense covariance matrix here?

Member Author

The intent here was to decide whether to call `destructure` at all, depending on the ADType.
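As a rough illustration of that intent, the sketch below reuses the two `maybe_destructure` methods quoted in this review thread and calls them with two AD backends. The parameter container `q` is hypothetical, and the choice of `AutoForwardDiff` as the non-Zygote backend is only an example.

```julia
using ADTypes, Optimisers

# The two methods quoted from the diff under review.
maybe_destructure(::ADTypes.AutoZygote, q) = (q, identity)
maybe_destructure(::ADTypes.AbstractADType, q) = Optimisers.destructure(q)

# Hypothetical parameter container; the field names are illustrative only.
q = (location = [0.0, 1.0], scale = [1.0, 2.0])

# Zygote backend: the parameters pass through untouched.
params_z, re_z = maybe_destructure(AutoZygote(), q)

# Any other backend: the parameters are flattened into a vector,
# and `re_f` rebuilds the structure from it.
params_f, re_f = maybe_destructure(AutoForwardDiff(), q)
```

Here the flattening decision depends only on the AD backend type, not on the type of `q`.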

@Red-Portal
Member Author

@torfjelde I agree. Let's not do this.

@Red-Portal Red-Portal closed this Jun 6, 2024