Add "low-rank" variational families #76

Merged
merged 46 commits into from
Sep 13, 2024

@Red-Portal (Member) commented Aug 3, 2024

This PR adds low-rank variational families, which cannot simply be represented as a location-scale family (the reparameterization path has to be modified).

  • add MvLocationLowRankScale
  • implement rand
  • implement logpdf
  • implement entropy
  • implement mean, var, and cov
  • add a bunch of tests
  • add documentation
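For readers new to this family, here is a minimal sketch of what the checklist above amounts to, assuming a standard normal base distribution and a covariance of the form D² + UUᵀ (D diagonal of length d, U a d×r factor). The struct `LowRankDiagSketch` and its field names are hypothetical and are not the API added in this PR. The relevant difference from a location-scale family shows up in `rand`: each draw consumes d + r base samples instead of the d used by z = μ + Cu.

```julia
using LinearAlgebra, Random, Statistics

# Hypothetical low-rank-plus-diagonal Gaussian family (illustration only).
# Covariance: Σ = Diagonal(scale_diag)^2 + scale_factors * scale_factors'.
struct LowRankDiagSketch{T<:Real}
    location::Vector{T}       # μ, length d
    scale_diag::Vector{T}     # diagonal of D, length d
    scale_factors::Matrix{T}  # U, size d × r
end

# Reparameterized draw: z = μ + D .* u_diag + U * u_factor.
# It consumes d + r standard-normal samples, so the path differs from the
# location-scale path z = μ + C * u, which consumes exactly d.
function Random.rand(rng::AbstractRNG, q::LowRankDiagSketch{T}) where {T}
    d, r = size(q.scale_factors)
    u_diag = randn(rng, T, d)
    u_factor = randn(rng, T, r)
    return q.location .+ q.scale_diag .* u_diag .+ q.scale_factors * u_factor
end

Statistics.mean(q::LowRankDiagSketch) = copy(q.location)

# Marginal variances: diag(D^2 + U U') = D.^2 + row-wise sums of squares of U.
Statistics.var(q::LowRankDiagSketch) =
    q.scale_diag .^ 2 .+ vec(sum(abs2, q.scale_factors; dims = 2))

Statistics.cov(q::LowRankDiagSketch) =
    Hermitian(Diagonal(q.scale_diag .^ 2) + q.scale_factors * q.scale_factors')

# Usage: d = 3, rank r = 2.
μ = [1.0, -2.0, 0.5]
q = LowRankDiagSketch(μ, [0.5, 1.0, 2.0], [1.0 0.0; 0.5 1.0; 0.0 -1.0])
z = rand(MersenneTwister(1), q)
```

Keeping D and U separate means the dense d×d covariance is only formed when `cov` is explicitly requested.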

The tricky parts will be logpdf and entropy, since doing them efficiently requires low-rank Cholesky updates. Given that low-rank Cholesky updates are niche, I am not sure whether AD support for them is up to the task.
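As a sketch of the low-rank Cholesky route (not necessarily how the PR implements it), Julia's standard library already provides `LinearAlgebra.lowrankupdate`, which applies a rank-1 update to an existing `Cholesky` factorization in O(d²). Writing down the factor of the diagonal part directly and then applying one update per column of U costs O(r d²), rather than the O(d³) of refactorizing Σ = D² + UUᵀ from scratch. Whether AD backends can differentiate through this function is exactly the open question raised here and is not checked below. `lowrank_cholesky`, `gaussian_logpdf`, and `gaussian_entropy` are hypothetical helper names.

```julia
using LinearAlgebra

# Hypothetical helper: Cholesky factor of Σ = Diagonal(D)^2 + U*U'.
# The factor of the diagonal part is written down directly (no O(d^3) work);
# each column of U is then folded in with an O(d^2) rank-1 update.
function lowrank_cholesky(D::AbstractVector{<:Real}, U::AbstractMatrix{<:Real})
    C = Cholesky(Matrix(Diagonal(abs.(D))), 'U', 0)  # dense, upper-triangular storage
    for j in axes(U, 2)
        C = lowrankupdate(C, U[:, j])                 # C.U'C.U gains the term u*u'
    end
    return C
end

# Gaussian log-density and entropy, both needing only the factor.
function gaussian_logpdf(μ::AbstractVector, C::Cholesky, x::AbstractVector)
    d = length(μ)
    r = x - μ
    return -(d * log(2π) + logdet(C) + dot(r, C \ r)) / 2
end

gaussian_entropy(C::Cholesky) = (size(C.U, 1) * (1 + log(2π)) + logdet(C)) / 2

# Example: d = 3, rank r = 2.
D = [0.5, 1.0, 2.0]
U = [1.0 0.0; 0.5 1.0; 0.0 -1.0]
μ = zeros(3)
C = lowrank_cholesky(D, U)
```

An alternative with the same goal, swapped in here only as a pointer, is the matrix determinant lemma plus the Woodbury identity, which works with an r×r system instead of updating a d×d factor.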

@Red-Portal requested a review from @yebai on August 9, 2024 05:10
@Red-Portal (Member, Author)

Okay, the tests didn't pass because the latest Enzyme patch somehow broke something, but all tests passed locally.

@yebai requested reviews from @mhauru and @sunxd3 and removed the review request for @yebai on August 9, 2024 08:33
@sunxd3 commented Aug 9, 2024

Interesting! I'll take a look later (probably over the weekend). Sorry for the possible delay.

@mhauru (Member) left a comment

Hi @Red-Portal! Just yesterday I was actually thinking about something like MvLocationScale; I hadn't realised it already existed. This low-rank version seems cool too.

I left a bunch of questions and proposals as inline comments. They're all about the code; I don't have much expertise on the theory side here.

More broadly, and probably not something to address in this PR: is there a reason to keep MvLocationScale and MvLocationScaleLowRank in AdvancedVI rather than somewhere more central in TuringLang, so that they could be used more broadly with Turing.jl, or maybe even in Distributions.jl?

That said, I suspect the main reason @yebai tagged me as a reviewer is the Enzyme failure. I'll look into it.

Resolved review threads (outdated):
  • docs/src/elbo/families.md (2 threads)
  • src/families/location_scale_low_rank.jl (5 threads)
  • test/families/location_scale_low_rank.jl (3 threads)
@mhauru (Member) commented Aug 9, 2024

The Enzyme issue is this: in test/interface/ad.jl, the definitions A = randn(D, D); f(λ′) = λ′'*A*λ′ / 2 create a closure over A, and then we call AD on it as

    _, y = Enzyme.autodiff(
        Enzyme.ReverseWithPrimal,
        f,
        Enzyme.Active,
        Enzyme.Duplicated(θ, ∇θ),
    )

Since f is a closure, Enzyme is unable to determine whether it might be mutated or contain differentiable data. (Note that a callable is an object like any other, so all sorts of things could be stored in it.) This can be fixed by changing the call to

    _, y = Enzyme.autodiff(
        Enzyme.ReverseWithPrimal,
        Enzyme.Const(f),  # Only line that has changed
        Enzyme.Active,
        Enzyme.Duplicated(θ, ∇θ),
    )

which explicitly tells Enzyme that you, the caller, guarantee that f contains nothing that needs differentiating. That's probably the right thing to do here, though I don't know whether there could be situations where the assumption does not hold, in which case marking f as constant would of course be very bad.
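The remark that a callable is an object like any other can be checked without Enzyme: in Julia, a closure is an instance of a compiler-generated struct whose fields hold the captured variables. This minimal sketch assumes the test's f captures A in the same way; `make_quadratic` and `grad_quadratic` are hypothetical names. The analytic gradient ((A + Aᵀ)/2)λ is included as a reference value for whatever an AD backend writes into the shadow argument (∇θ in the call above).

```julia
# Hypothetical helper mirroring f(λ′) = λ′'*A*λ′ / 2 from the test,
# written so that f is explicitly a closure over A.
make_quadratic(A) = λ -> λ' * A * λ / 2

A = [2.0 1.0; 0.0 3.0]
f = make_quadratic(A)

# The closure stores A as a field. This stored data is what Enzyme cannot
# assume is constant unless the call site wraps f in Enzyme.Const.
captured = fieldnames(typeof(f))
stored_A = getfield(f, :A)

# Exact gradient of λ' * A * λ / 2, useful for checking an AD result.
grad_quadratic(A, λ) = (A + A') * λ / 2
```

Marking f as Const is therefore a promise about the contents of those fields; it is only safe when nothing stored in them is being differentiated.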

@Red-Portal requested a review from @mhauru on August 10, 2024 06:58
@sunxd3 left a comment

It all makes sense.

Is #75 solved by this? I saw that this PR references it, but I don't see it mentioned directly anywhere.

Resolved review threads (outdated):
  • docs/src/elbo/families.md
  • docs/make.jl
@Red-Portal (Member, Author)

Hi @sunxd3! Thanks for the comments! #75 was mentioned in my reply to a review comment by @mhauru. I'm planning to fix #75 after this PR is merged.

@mhauru (Member) left a comment

Looks good to me apart from the bug where LocationScaleLowRank assumes zero mean and unit variance. I know you said you'd fix that in a separate PR, but I'm a bit wary of committing code with a known, significant bug to master. Could we make the code error out if the normalisation assumption is violated? I can easily imagine a situation where this gets into master, someone tries to use it before the other PR is done, and it silently gives wrong results.
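A minimal sketch of the kind of guard proposed here, assuming the base distribution supports `Statistics.mean` and `Statistics.var` (Distributions.jl univariate distributions do). `check_standardized` and `ToyBase` are hypothetical names; `ToyBase` only stands in for a real univariate distribution so that the example runs with the standard library alone. Where such a check would be called (for example, in the family's constructor) is left open.

```julia
using Statistics

# Hypothetical guard: reject base distributions that are not standardized, so a
# family that assumes zero mean and unit variance fails loudly instead of
# silently returning wrong results.
function check_standardized(dist; atol::Real = 1e-8)
    m, v = mean(dist), var(dist)
    isapprox(m, 0; atol = atol) ||
        throw(ArgumentError("base distribution must have zero mean, got mean = $m"))
    isapprox(v, 1; atol = atol) ||
        throw(ArgumentError("base distribution must have unit variance, got var = $v"))
    return nothing
end

# Stand-in for a univariate distribution (illustration only).
struct ToyBase
    μ::Float64
    σ2::Float64
end
Statistics.mean(d::ToyBase) = d.μ
Statistics.var(d::ToyBase) = d.σ2
```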

Comment on lines +133 to +135
[:meanfield, :fullrank],
realtype in [Float32, Float64],
bijector in [nothing, :identity]
Member

This is probably the formatter's work, but the indents here are quite odd.

@mhauru (Member) left a comment

Thanks @Red-Portal, and sorry for being a bit slow to respond. I have a few more questions. I also still need to build and read the new docs; I haven't done that yet.

Comment on lines +129 to +130
σ2 = var(q.dist)
return σ2 * Hermitian(C * C')
Member

I don't know the theory here well, but is there a reason why this involves var(q.dist) rather than cov(q.dist)? I could have imagined it being something like C * cov(q.dist) * C', though that's just a not-very-educated guess.

Member Author

Good point. I was assuming that q.dist was constrained to be a univariate distribution, which would make all of this valid, but it seems I have to use ContinuousUnivariateDistribution to enforce that. Let me fix this later.

Member

Oh I see, yeah, this makes sense in the univariate case. Is there a reason you want to restrict q.dist to being univariate? Is it just less of a headache to implement?

Member Author

Yeah, I thought it would be the easiest way to force people to provide a standardized isotropic distribution. We're not quite forcing it to be standardized, but at least this guarantees that it is isotropic.
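The point in this thread can be illustrated numerically. For z = μ + Cu, cov(z) = C cov(u) Cᵀ. If the entries of u are i.i.d. draws from a univariate base with variance σ², then cov(u) = σ²I and the expression collapses to σ² C Cᵀ, which is what the var(q.dist) line computes; for a general multivariate base, the full C * cov(q.dist) * C' form would be needed. In this minimal sketch, `pushforward_cov` and `isotropic_cov` are hypothetical names, and Uniform(-1, 1) is used only because it is not standardized and can be sampled with the standard library.

```julia
using LinearAlgebra, Random, Statistics

# General pushforward: covariance of μ + C*u given the covariance Σu of u.
pushforward_cov(C, Σu) = Hermitian(C * Σu * C')

# Special case when every entry of u is an i.i.d. draw from a univariate base
# distribution with variance σ2, i.e. Σu = σ2 * I.
isotropic_cov(C, σ2) = σ2 * Hermitian(C * C')

# Monte Carlo check with a deliberately non-standardized base, Uniform(-1, 1),
# whose variance is 1/3.
rng = MersenneTwister(42)
d, n = 3, 200_000
μ = [1.0, 0.0, -1.0]
C = [1.0 0.0 0.0; 0.5 2.0 0.0; -0.3 0.1 0.7]
u = 2 .* rand(rng, d, n) .- 1
z = μ .+ C * u
σ2 = 1 / 3
empirical = cov(z; dims = 2)
```

The empirical covariance matches σ² C Cᵀ with σ² = 1/3, not C Cᵀ, which is why the univariate restriction alone guarantees isotropy but not standardization.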

Resolved review threads:
  • src/families/location_scale_low_rank.jl
  • test/runtests.jl (outdated)
  • test/families/location_scale.jl
  • src/families/location_scale.jl
@mhauru (Member) commented Sep 10, 2024

Docs look good. I still wonder if these distributions would be useful more broadly than just within AdvancedVI.

@Red-Portal added the bug (Something isn't working) and enhancement (New feature or request) labels on Sep 10, 2024
@Red-Portal added this to the v0.3.0 milestone on Sep 10, 2024
@Red-Portal (Member, Author) commented Sep 10, 2024

@yebai The original formatter complaints came from manual touches suggested by @mhauru, because the formatter did a pretty ugly job in a few places. But I guess manual formatting will be hard to maintain in the long run, since people will just want to run the formatter without having to manually revert to non-standard styles.

@mhauru (Member) left a comment

A test seems to fail because of floating-point rounding, so I proposed a few two-character changes that will hopefully fix it. I'm not sure all of them are necessary; reject any that don't make sense to you.

If you want to add the constraint that the base distribution needs to be univariate, then after that I'm out of nits to pick and happy to approve.

EDIT: Oh, and on the formatter thing: yeah, even if the formatter makes it ugly, I'm happy to go with what the formatter does, for consistency and ease.

Resolved review threads (outdated):
  • test/families/location_scale_low_rank.jl (7 threads)
@mhauru (Member) left a comment

Thanks @Red-Portal! An Enzyme test now seems to fail, but I'll approve since the code itself looks good. It's not immediately obvious to me what the issue is. Feel free to look into it, or I can try to dig into it, hopefully sometime in the next couple of days.

Resolved review thread (outdated): src/families/location_scale.jl
@Red-Portal (Member, Author) commented Sep 12, 2024

@mhauru The Enzyme issue is due to an LLVM update. It should be resolved automatically pretty soon.

Edit: Seems like it got fixed.

@mhauru (Member) left a comment

Thanks @Red-Portal, looks great! I have no idea what the right value for scale_eps is; I'm happy to take your word for it.

@Red-Portal merged commit 57c9e58 into TuringLang:master on Sep 13, 2024
9 of 10 checks passed
Labels: bug (Something isn't working), enhancement (New feature or request)
Projects: none yet
4 participants