Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attempt at implementation of VarNameVector (Metadata alternative) #555

Open
wants to merge 196 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
196 commits
Select commit Hold shift + click to select a range
5af1afa
initial implementation of VarNameVector
torfjelde Oct 31, 2023
8ce53f7
added some hacky getval and getdist get things to work for VarInfo
torfjelde Nov 7, 2023
fc6a051
Apply suggestions from code review
torfjelde Nov 7, 2023
7cd599d
added arbitrary metadata field as discussed
torfjelde Nov 12, 2023
ed0a757
renamed idcs to varname_to_index
torfjelde Nov 12, 2023
4ebd252
renamed vns to varnames for VarNameVector
torfjelde Nov 12, 2023
9f12c9a
added keys impl for Metadata
torfjelde Nov 12, 2023
5a15121
added push! and update! for VarNameVector
torfjelde Nov 13, 2023
edde2c1
added getindex_raw! and setindex_raw! for VarNameVector
torfjelde Nov 13, 2023
ed46002
added `iterate` and `convert` (for `AbstractDict) impls for `VarNameV…
torfjelde Nov 13, 2023
5b00059
make the key and eltype part of the `VarNameVector` type
torfjelde Nov 13, 2023
bef7e0a
added more tests for VarNameVector
torfjelde Nov 13, 2023
006ee8d
formatting
torfjelde Nov 13, 2023
9802811
more testing for VarNameVector
torfjelde Nov 13, 2023
88b1721
minor changes to some comments
torfjelde Nov 13, 2023
ca7b173
added a bunch more tests for VarNameVector + several bugfixes in the …
torfjelde Nov 13, 2023
fb01b94
formatting
torfjelde Nov 13, 2023
9634839
added `similar` implementation for `VarNameVector`
torfjelde Nov 13, 2023
5179f6f
formatting
torfjelde Nov 13, 2023
9f632bb
removed debug statement
torfjelde Nov 13, 2023
3c210f7
made VarInfo slighly more generic wrt. underlying metadata
torfjelde Nov 13, 2023
8bf6589
Merge branch 'master' into torfjelde/varnamevector
torfjelde Nov 14, 2023
8b2720f
fixed incorrect behavior in `keys` for `Metadata`
torfjelde Nov 14, 2023
9fa6446
minor style changes to VarNameVector tests
torfjelde Nov 14, 2023
0900c57
style
torfjelde Nov 14, 2023
1f7e633
added testing of `update!` with smaller sizes and fixed bug related t…
torfjelde Nov 14, 2023
8d05586
formatting
torfjelde Nov 14, 2023
7801fe1
move functionality related to `push!` for `VarNameVector` into `push!`
torfjelde Nov 14, 2023
cdc2373
Update src/varnamevector.jl
torfjelde Nov 16, 2023
d2d776d
Merge branch 'master' into torfjelde/varnamevector
torfjelde Nov 20, 2023
ae4bcb7
several fixes to make sampling with VarNameVector + initiall tests for
torfjelde Dec 30, 2023
97e1bcc
VarInfo + VarNameVector tests for all demo models
torfjelde Dec 30, 2023
be3c1b4
Merge remote-tracking branch 'origin/torfjelde/varnamevector' into to…
torfjelde Dec 30, 2023
ad343f3
Apply suggestions from code review
torfjelde Dec 30, 2023
f707b25
added docs on the design of `VarNameVector`
torfjelde Dec 31, 2023
4e7af1d
Merge branch 'master' into torfjelde/varnamevector
torfjelde Dec 31, 2023
f1faf18
Apply suggestions from code review
torfjelde Dec 31, 2023
87d3d01
added note on `update!`
torfjelde Dec 31, 2023
9c3b265
further elaboration of the design of `VarInfo` and `VarNameVector`
torfjelde Jan 1, 2024
958c66b
more writing improvements
torfjelde Jan 1, 2024
74c6efd
added docstring to `has_inactive_ranges` and `inactive_ranges_sweep!`
torfjelde Jan 1, 2024
d9ea878
moved docs on `VarInfo` design to a separate internals section
torfjelde Jan 1, 2024
5acce98
writing improvements for internal docs
torfjelde Jan 1, 2024
6f95cdd
further motivation of the design choices made in `VarNameVector`
torfjelde Jan 1, 2024
38a4b08
improved writing
torfjelde Jan 1, 2024
60edd10
VarNameVector is now grown as much as needed
torfjelde Jan 1, 2024
3f9d34f
updated `delete!`
torfjelde Jan 2, 2024
fb822b5
Significant changes to implementation of `VarNameVector`:
torfjelde Jan 2, 2024
66bc090
added `copy` when constructing `VectorVarInfo` from `VarInfo`
torfjelde Jan 2, 2024
ccd86f2
added missing `isempty` impl
torfjelde Jan 2, 2024
1d4a000
remove impl of `iterate` and instead implemented `pairs` and `values`…
torfjelde Jan 2, 2024
9a16dd1
added missing `empty!` for `num_inactive`
torfjelde Jan 2, 2024
e49b762
removed redundant `shift_left!` methd
torfjelde Jan 2, 2024
2b445c9
fixed `delete!` for `VarNameVector`
torfjelde Jan 2, 2024
e3c2633
added `is_contiguous` as an alterantive to `!has_inactive`
torfjelde Jan 2, 2024
19a829c
updates to internal docs
torfjelde Jan 2, 2024
a358bc4
renamed `sweep_inactive_ranges!` to `contiguify!`
torfjelde Jan 2, 2024
46be8d5
improvements to internal docs
torfjelde Jan 2, 2024
57d688e
more improvements to internal docs
torfjelde Jan 2, 2024
0968a07
moved additional methods description in internals to earlier in the doc
torfjelde Jan 2, 2024
0d008a4
moved internals docs to a separate directory and split into files
torfjelde Jan 2, 2024
ccd0d64
more improvements to internals doc
torfjelde Jan 2, 2024
7c45e67
formatting
torfjelde Jan 2, 2024
373215b
added tests for `delete!` and fixed reference to old method
torfjelde Jan 2, 2024
0cdafbf
addition to `delete!` test
torfjelde Jan 2, 2024
51c041f
added `values_as` impls for `VarNameVector`
torfjelde Jan 2, 2024
20b3742
added docs for `replace_valus` and `values_as` for `VarNameVector`
torfjelde Jan 2, 2024
ef6c618
fixed doctest
torfjelde Jan 2, 2024
8a1209c
formatting
torfjelde Jan 2, 2024
adeadf0
temporarily disable doctests so we can build docs
torfjelde Jan 2, 2024
7ff179d
added missing compat entry for ForwardDiff in docs
torfjelde Jan 2, 2024
c7ec08a
moved some shared code into methods to make things a bit cleaner
torfjelde Jan 3, 2024
c5a5e58
added impl of `merge` for `VarNameVector`
torfjelde Jan 3, 2024
c376d95
renamed a few variables in `merge` impl for `VarNameVector`
torfjelde Jan 3, 2024
f71baa5
forgot to include some changes in previous commit
torfjelde Jan 3, 2024
af25f3c
added impl of `subset` for `VarNameVector`
torfjelde Jan 3, 2024
c28f076
fixed `pairs` impl for `VarNameVector`
torfjelde Jan 3, 2024
f5d2c63
added missing impl of `subset` for `VectorVarInfo`
torfjelde Jan 3, 2024
3eb6c7f
added missing impl of `merge_metadata` for `VarNameVector`
torfjelde Jan 3, 2024
9ba8144
added a bunch of `from_vec_transform` and `tovec` impls to make
torfjelde Jan 3, 2024
acd6951
make default args use `from_vec_transform` rather than `FromVec`
torfjelde Jan 3, 2024
790f743
fixed `values_as` fro `VarInfo` with `VarNameVector` as `metadata`
torfjelde Jan 3, 2024
c474bb0
fixed impl of `getindex_raw` when using integer index for `VarNameVec…
torfjelde Jan 4, 2024
8251463
added tests for `getindex` with `Int` index for `VarNameVector`
torfjelde Jan 4, 2024
5df7031
fix for `setindex!` and `setindex_raw!` for `VarNameVector`
torfjelde Jan 4, 2024
683b776
introduction of `from_vec_transform` and `tovec` and its usage in `Va…
torfjelde Jan 19, 2024
4dae00d
moved definition of `is_splat_symbol` to the file where it's used
torfjelde Jan 19, 2024
e3b52a4
added `VarInfo` constructor with vector input for `VectorVarInfo`
torfjelde Jan 19, 2024
9626be1
make `extract_priors` take the `rng` as an argument
torfjelde Jan 19, 2024
e731fd6
added `replace_values` for `Metadata`
torfjelde Jan 19, 2024
0785abf
make link and invlink act on the `metadata` field for `VarInfo` +
torfjelde Jan 19, 2024
b3e0955
added temporary defs of `with_logabsdet_jacobian` and `inverse` for
torfjelde Jan 19, 2024
ff963ce
added invlink_with_logpdf overload for `ThreadSafeVarInfo`
torfjelde Jan 19, 2024
03f2b2b
added `is_transformed` field to `VarNameVector`
torfjelde Jan 19, 2024
949b33a
removed unnecessary defintions of `with_logabsdet_jacobian` and
torfjelde Jan 19, 2024
cc5ecc4
fixed issue where we were storing the wrong transformations in `VarNa…
torfjelde Jan 19, 2024
1aae1b4
make sure `extract_priors` doesn't mutate the `varinfo`
torfjelde Jan 19, 2024
8e0853d
updated `similar` for `VarNameVector` and fixed `invlink` for `VarNam…
torfjelde Jan 19, 2024
229b168
added handling of `is_transformed` in `merge` for `VarNameVector`
torfjelde Jan 19, 2024
c581dcf
removed unnecesasry `deepcopy` from outer `link`
torfjelde Jan 19, 2024
b4d3f55
updated `push!` to also `push!` on `is_transformed`
torfjelde Jan 19, 2024
ed1d006
skip tests for mutating linking when using VarNameVector
torfjelde Jan 19, 2024
f132209
use same projection for `Cholesky` in `VarNameVector` as in `VarInfo`
torfjelde Jan 19, 2024
49454de
fixed `settrans!!` for `VarInfo` with `VarNameVector`
torfjelde Jan 19, 2024
01ff2dd
fixed bug in `set_flag!`
torfjelde Jan 19, 2024
20adedf
fixed another typo
torfjelde Jan 19, 2024
8f9566a
fixed return values of `settrans!!`
torfjelde Jan 19, 2024
5532046
updated static transformation tests
torfjelde Jan 20, 2024
3c5d2ac
Update test/simple_varinfo.jl
torfjelde Jan 20, 2024
317d969
Merge branch 'master' into torfjelde/varnamevector
torfjelde Jan 20, 2024
f8441ea
Merge remote-tracking branch 'origin/torfjelde/varnamevector' into to…
torfjelde Jan 25, 2024
ab16323
Merge branch 'master' into torfjelde/varnamevector
torfjelde Jan 25, 2024
a9be219
removed unnecessary impl of `extract_priors`
torfjelde Jan 25, 2024
53c8d33
make `short_varinfo_name` of `TypedVarInfo` a bit more informative
torfjelde Jan 25, 2024
61d85ad
moved impl of `has_varnamevector` for `ThreadSafeVarInfo`
torfjelde Jan 25, 2024
9ace554
added back `extract_priors` impl as we do need it
torfjelde Jan 25, 2024
67c9821
forgot to include tests for `VarNameVector` in `runtests.jl`
torfjelde Jan 25, 2024
32a2d31
fix for `relax_container_types` in `test/varnamevector.jl`
torfjelde Jan 25, 2024
b3bb42d
fixed `need_transforms_relaxation`
torfjelde Jan 26, 2024
25ff2b1
updated some tests to not refer directly to `FromVec`
torfjelde Jan 28, 2024
004f038
introduce `from_internal_transform` and its siblings
torfjelde Jan 28, 2024
38c89bd
remove `with_logabsdet_jacobian_and_reconstruct` in favour of
torfjelde Jan 28, 2024
218dc23
added `internal_to_linked_internal_transform` + fixed a few bugs in
torfjelde Jan 28, 2024
1df4293
added `linked_internal_to_internal_transform` as a complement to `int…
torfjelde Jan 28, 2024
f8df896
fixed bugs in `invlink` for `VarInfo` using `linked_internal_to_inter…
torfjelde Jan 28, 2024
d62f26a
more work on removing calls to `reconstruct`
torfjelde Jan 28, 2024
b4517d6
removed redundant comment
torfjelde Jan 28, 2024
b7d4754
added `from_linked_vec_transform` specialization for `LKJ`
torfjelde Jan 28, 2024
0244dd9
more work on removing references to `reconstruct`
torfjelde Jan 28, 2024
e886d07
added `copy` in `values_from_metadata` to preserve behavior and avoid
torfjelde Jan 28, 2024
2af6605
remove `reconstruct_and_link` and `invlink_and_reconstruct`
torfjelde Jan 28, 2024
a0664d7
replaced references to `link_and_reconstruct` and `invlink_and_recons…
torfjelde Jan 28, 2024
f2d59b2
introduced `recombine` and replaced calls to `reconstruct` with `n` s…
torfjelde Jan 28, 2024
e3bfa76
completely removed `reconstruct`
torfjelde Jan 28, 2024
c0aef81
renamed `maybe_reconstruct_and_link` to `to_maybe_linked_internal` and
torfjelde Jan 28, 2024
f7c0853
added impls of `from_*_internal_transform` for `ThreadSafeVarInfo`
torfjelde Jan 30, 2024
77b835e
removed `reconstruct` from docs and from exports
torfjelde Jan 30, 2024
b83c262
renamed `getval` to `getindex_internal` and made `dist` an optional
torfjelde Jan 31, 2024
c4faf3e
updated docs + added description of how internals of transforms work
torfjelde Jan 31, 2024
c8d9695
added a bunch of illustrations for the transforms docs + dot files us…
torfjelde Jan 31, 2024
95dc8e3
temporarily removed `VarNameVector` completely
torfjelde Jan 31, 2024
8930f9c
formatting
torfjelde Jan 31, 2024
e45b668
Update docs/src/internals/transformations.md
torfjelde Jan 31, 2024
0e71092
Update docs/src/internals/transformations.md
torfjelde Jan 31, 2024
2de9ac9
removed refs to VectorVarInfo
torfjelde Jan 31, 2024
9b71428
added impls of `from_internal_transform` for `ThreadSafeVarInfo`
torfjelde Feb 1, 2024
786e9bf
reverted accidental removal of old `VarInfo` constructor
torfjelde Feb 1, 2024
f1fe42c
fixed incorrect `recombine` call
torfjelde Feb 1, 2024
2273954
removed undefined refs to `VarNameVector` stuff in `setup_varinfos`
torfjelde Feb 1, 2024
ab7c189
bump minior version because Turing breaks
torfjelde Feb 1, 2024
3a86601
fix: was using `from_linked_internal_transform` in
torfjelde Feb 1, 2024
28c7d85
removed `getindex_raw`
torfjelde Feb 1, 2024
59514d6
removed redundant docstrings
torfjelde Feb 1, 2024
cdc882b
fixed tests
torfjelde Feb 1, 2024
57ba7c0
fixed comparisons in tests
torfjelde Feb 1, 2024
902e59c
try relative references for images in transformation docs
torfjelde Feb 1, 2024
d7aba55
another attempt at fixing asset-references
torfjelde Feb 3, 2024
b0dd2f8
Merge branch 'master' into torfjelde/transformations
torfjelde Feb 3, 2024
1f51203
fixed getindex diagrams in docs
torfjelde Feb 3, 2024
0eb79b1
minor changes to comments
torfjelde Feb 3, 2024
071bebf
remove Combinatorics as a test dep, as it's not needed for this PR
torfjelde Feb 3, 2024
bbdc060
reverted unnecessary change
torfjelde Feb 3, 2024
e2f4d18
disable type-stability tests for models on older Julia versions
torfjelde Feb 3, 2024
3d823ac
removed seemingly completely unused impl of `setval!`
torfjelde Feb 3, 2024
54792f4
Revert "temporarily removed `VarNameVector` completely"
torfjelde Feb 3, 2024
ff68206
Revert "remove Combinatorics as a test dep, as it's not needed for th…
torfjelde Feb 6, 2024
9b1014d
Merge remote-tracking branch 'origin/master' into torfjelde/varnameve…
mhauru Aug 22, 2024
19978ec
More work on `VarNameVector` (#637)
mhauru Sep 3, 2024
95668eb
Merge remote-tracking branch 'origin/master' into torfjelde/varnameve…
mhauru Sep 3, 2024
1e4efe6
Bump Bijectors dependecy
mhauru Sep 3, 2024
3ee9832
Remove dead TODO note
mhauru Sep 3, 2024
26753e9
Remove old TODOs, improve VNV invlinking
mhauru Sep 3, 2024
ea18e1f
Fix from_vec_transform for 0-dim arrays
mhauru Sep 4, 2024
ffbf2ad
Fix unflatten for VarInfo
mhauru Sep 4, 2024
f077f4a
Fix some VarInfo index getters
mhauru Sep 4, 2024
e27af80
Change how VNV handles transformations, and other VNV stuff
mhauru Sep 4, 2024
b5677b4
Small docs fixes
mhauru Sep 4, 2024
9d1c8d3
Merge remote-tracking branch 'origin/master' into torfjelde/varnameve…
mhauru Sep 4, 2024
b778082
Small fixes all over for VNV
mhauru Sep 5, 2024
9750e60
Add comments
mhauru Sep 5, 2024
9ecc506
Fix some tests
mhauru Sep 5, 2024
3f1b9a2
Change long string formatting to support Julia 1.6
mhauru Sep 5, 2024
9145965
Small changes to ReshapeTransformation
mhauru Sep 5, 2024
937956d
Revert unrelated changes to ReverseDiff extension
mhauru Sep 5, 2024
4fbe5d2
Improve VarNamedVector VarInfo testing
mhauru Sep 5, 2024
9f11e7b
Fix some depwarns
mhauru Sep 5, 2024
86d97ae
Improvements to test/simple_varinfo.jl
mhauru Sep 5, 2024
2535517
Fix for unset_flag!, better docstring
mhauru Sep 5, 2024
93ef3ee
Add a comment about hasvalue/getvalue
mhauru Sep 5, 2024
f35eca6
Add @non_differentiable calls to work around Zygote limitations
mhauru Sep 9, 2024
d55fc00
Fix docs, workaround Zygote issue
mhauru Sep 9, 2024
5bbba91
Remove outdated workaround
mhauru Sep 17, 2024
851630f
Move has_varnamedvector(varinfo::VarInfo) to abstract_varinfo.jl
mhauru Sep 17, 2024
45c89c4
Make copies of logp and num_produce in subset
mhauru Sep 17, 2024
30252dc
Rename getindex_raw to getindex_internal
mhauru Sep 17, 2024
be77c36
Add push!(::VarNamedVector, ::Pair)
mhauru Sep 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.29"
version = "0.30"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -46,7 +46,7 @@ AbstractMCMC = "5"
AbstractPPL = "0.8.4"
Accessors = "0.1"
BangBang = "0.4.1"
Bijectors = "0.13.9"
Bijectors = "0.13.18"
ChainRulesCore = "1"
Compat = "4"
ConstructionBase = "1.5.4"
Expand Down
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ makedocs(;
"Home" => "index.md",
"API" => "api.md",
"Tutorials" => ["tutorials/prob-interface.md"],
"Internals" => ["internals/transformations.md"],
"Internals" => ["internals/varinfo.md", "internals/transformations.md"],
],
checkdocs=:exports,
doctest=false,
Expand Down
310 changes: 310 additions & 0 deletions docs/src/internals/varinfo.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions ext/DynamicPPLChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,6 @@ ChainRulesCore.@non_differentiable DynamicPPL.updategid!(
# No need + causes issues for some AD backends, e.g. Zygote.
ChainRulesCore.@non_differentiable DynamicPPL.infer_nested_eltype(x)

ChainRulesCore.@non_differentiable DynamicPPL.recontiguify_ranges!(ranges)

end # module
143 changes: 137 additions & 6 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,152 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
return keys(c.info.varname_to_symbol)
end

"""
generated_quantities(model::Model, chain::MCMCChains.Chains)

Execute `model` for each of the samples in `chain` and return an array of the values
returned by the `model` for each sample.

# Examples
## General
Often you might have additional quantities computed inside the model that you want to
inspect, e.g.
```julia
@model function demo(x)
# sample and observe
θ ~ Prior()
x ~ Likelihood()
return interesting_quantity(θ, x)
end
m = demo(data)
chain = sample(m, alg, n)
# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
# from the posterior/`chain`:
generated_quantities(m, chain) # <= results in a `Vector` of returned values
# from `interesting_quantity(θ, x)`
```
## Concrete (and simple)
```julia
julia> using DynamicPPL, Turing

julia> @model function demo(xs)
s ~ InverseGamma(2, 3)
m_shifted ~ Normal(10, √s)
m = m_shifted - 10

for i in eachindex(xs)
xs[i] ~ Normal(m, √s)
end

return (m, )
end
demo (generic function with 1 method)

julia> model = demo(randn(10));

julia> chain = sample(model, MH(), 10);

julia> generated_quantities(model, chain)
10×1 Array{Tuple{Float64},2}:
(2.1964758025119338,)
(2.1964758025119338,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.043088571494005024,)
(-0.16489786710222099,)
(-0.16489786710222099,)
```
"""
function DynamicPPL.generated_quantities(
model::DynamicPPL.Model, chain_full::MCMCChains.Chains
)
chain = MCMCChains.get_sections(chain_full, :parameters)
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
# Update the varinfo with the current sample and make variables not present in `chain`
# to be sampled.
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
if DynamicPPL.supports_varname_indexing(chain)
varname_pairs = _varname_pairs_with_varname_indexing(
chain, varinfo, sample_idx, chain_idx
)
else
varname_pairs = _varname_pairs_without_varname_indexing(
chain, varinfo, sample_idx, chain_idx
)
end
fixed_model = DynamicPPL.fix(model, Dict(varname_pairs))
return fixed_model()
end
end

"""
_varname_pairs_with_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)

# TODO: Some of the variables can be a view into the `varinfo`, so we need to
# `deepcopy` the `varinfo` before passing it to `model`.
model(deepcopy(varinfo))
Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
from the chain.

This implementation assumes `chain` can be indexed using variable names, and is the
preffered implementation.
"""
function _varname_pairs_with_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)
vns = DynamicPPL.varnames(chain)
vn_parents = Iterators.map(vns) do vn
# The call nested_setindex_maybe! is used to handle cases where vn is not
# the variable name used in the model, but rather subsumed by one. Except
# for the subsumption part, this could be
# vn => getindex_varname(chain, sample_idx, vn, chain_idx)
# TODO(mhauru) This call to nested_setindex_maybe! is unintuitive.
DynamicPPL.nested_setindex_maybe!(
varinfo, DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx), vn
)
end
varname_pairs = Iterators.map(Iterators.filter(!isnothing, vn_parents)) do vn_parent
vn_parent => varinfo[vn_parent]
end
return varname_pairs
end

"""
Check which keys in `key_strings` are subsumed by `vn_string` and return the their values.

The subsumption check is done with `DynamicPPL.subsumes_string`, which is quite weak, and
won't catch all cases. We should get rid of this if we can.
"""
# TODO(mhauru) See docstring above.
function _vcat_subsumed_values(vn_string, values, key_strings)
indices = findall(Base.Fix1(DynamicPPL.subsumes_string, vn_string), key_strings)
return !isempty(indices) ? reduce(vcat, values[indices]) : nothing
end

"""
_varname_pairs_without_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)

Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
from the chain.

This implementation does not assume that `chain` can be indexed using variable names. It is
thus not guaranteed to work in cases where the variable names have complex subsumption
patterns, such as if the model has a variable `x` but the chain stores `x.a[1]`.
"""
function _varname_pairs_without_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)
values = chain.value[sample_idx, :, chain_idx]
keys = Base.keys(chain)
keys_strings = map(string, keys)
varname_pairs = [
vn => _vcat_subsumed_values(string(vn), values, keys_strings) for
vn in Base.keys(varinfo)
]
return varname_pairs
end

end
3 changes: 3 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ export AbstractVarInfo,
VarInfo,
UntypedVarInfo,
TypedVarInfo,
VectorVarInfo,
SimpleVarInfo,
VarNamedVector,
push!!,
empty!!,
subset,
Expand Down Expand Up @@ -175,6 +177,7 @@ include("sampler.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("varnamedvector.jl")
include("abstract_varinfo.jl")
include("threadsafe.jl")
include("varinfo.jl")
Expand Down
17 changes: 12 additions & 5 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ julia> # Just use an example model to construct the `VarInfo` because we're lazy
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;

julia> # For the sake of brevity, let's just check the type.
md = values_as(vi); md.s isa DynamicPPL.Metadata
md = values_as(vi); md.s isa Union{DynamicPPL.Metadata, DynamicPPL.VarNamedVector}
true

julia> values_as(vi, NamedTuple)
Expand All @@ -321,7 +321,7 @@ julia> # Just use an example model to construct the `VarInfo` because we're lazy
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;

julia> # For the sake of brevity, let's just check the type.
values_as(vi) isa DynamicPPL.Metadata
values_as(vi) isa Union{DynamicPPL.Metadata, Vector}
true

julia> values_as(vi, NamedTuple)
Expand Down Expand Up @@ -349,7 +349,7 @@ Determine the default `eltype` of the values returned by `vi[spl]`.
This should generally not be called explicitly, as it's only used in
[`matchingvalue`](@ref) to determine the default type to use in place of
type-parameters passed to the model.

This method is considered legacy, and is likely to be deprecated in the future.
"""
function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior})
Expand All @@ -363,6 +363,13 @@ function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromP
return eltype(T)
end

"""
has_varnamedvector(varinfo::VarInfo)

Returns `true` if `varinfo` uses `VarNamedVector` as metadata.
"""
has_varnamedvector(vi::AbstractVarInfo) = false

# TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert
# the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which
# might result in a `Vector{Any}`.
Expand Down Expand Up @@ -554,7 +561,7 @@ end
link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
link([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)

Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`.
Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`.

If `t` is not provided, `default_transformation(model, vi)` will be used.

Expand All @@ -573,7 +580,7 @@ end
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)

Transform the variables in `vi` to their constrained space, using the (inverse of)
Transform the variables in `vi` to their constrained space, using the (inverse of)
transformation `t`, mutating `vi` if possible.

If `t` is not provided, `default_transformation(model, vi)` will be used.
Expand Down
15 changes: 12 additions & 3 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,10 @@ function assume(
if haskey(vi, vn)
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if
# that's okay.
unset_flag!(vi, vn, "del", true)
r = init(rng, dist, sampler)
f = to_maybe_linked_internal_transform(vi, vn, dist)
BangBang.setindex!!(vi, f(r), vn)
Expand Down Expand Up @@ -516,7 +519,10 @@ function get_and_set_val!(
if haskey(vi, vns[1])
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
unset_flag!(vi, vns[1], "del")
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if
# that's okay.
unset_flag!(vi, vns[1], "del", true)
r = init(rng, dist, spl, n)
for i in 1:n
vn = vns[i]
Expand Down Expand Up @@ -554,7 +560,10 @@ function get_and_set_val!(
if haskey(vi, vns[1])
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
unset_flag!(vi, vns[1], "del")
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if
# that's okay.
unset_flag!(vi, vns[1], "del", true)
f = (vn, dist) -> init(rng, dist, spl)
r = f.(vns, dists)
for i in eachindex(vns)
Expand Down
Loading
Loading