diff --git a/Project.toml b/Project.toml index f7be0257d..b7d1589f0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.29" +version = "0.30" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -46,7 +46,7 @@ AbstractMCMC = "5" AbstractPPL = "0.8.4" Accessors = "0.1" BangBang = "0.4.1" -Bijectors = "0.13.9" +Bijectors = "0.13.18" ChainRulesCore = "1" Compat = "4" ConstructionBase = "1.5.4" diff --git a/docs/make.jl b/docs/make.jl index b0168076d..42a82436c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -23,7 +23,7 @@ makedocs(; "Home" => "index.md", "API" => "api.md", "Tutorials" => ["tutorials/prob-interface.md"], - "Internals" => ["internals/transformations.md"], + "Internals" => ["internals/varinfo.md", "internals/transformations.md"], ], checkdocs=:exports, doctest=false, diff --git a/docs/src/internals/varinfo.md b/docs/src/internals/varinfo.md new file mode 100644 index 000000000..c1219444f --- /dev/null +++ b/docs/src/internals/varinfo.md @@ -0,0 +1,310 @@ +# Design of `VarInfo` + +[`VarInfo`](@ref) is a fairly simple structure. + +```@docs; canonical=false +VarInfo +``` + +It contains + + - a `logp` field for accumulation of the log-density evaluation, and + - a `metadata` field for storing information about the realizations of the different variables. + +Representing `logp` is fairly straight-forward: we'll just use a `Real` or an array of `Real`, depending on the context. + +**Representing `metadata` is a bit trickier**. This is supposed to contain all the necessary information for each `VarName` to enable the different executions of the model + extraction of different properties of interest after execution, e.g. the realization / value corresponding to a variable `@varname(x)`. + +!!! note + + We want to work with `VarName` rather than something like `Symbol` or `String` as `VarName` contains additional structural information, e.g. 
a `Symbol("x[1]")` can be a result of either `var"x[1]" ~ Normal()` or `x[1] ~ Normal()`; these scenarios are disambiguated by `VarName`. + +To ensure that `VarInfo` is simple and intuitive to work with, we want `VarInfo`, and hence the underlying `metadata`, to replicate the following functionality of `Dict`: + + - `keys(::Dict)`: return all the `VarName`s present in `metadata`. + - `haskey(::Dict)`: check if a particular `VarName` is present in `metadata`. + - `getindex(::Dict, ::VarName)`: return the realization corresponding to a particular `VarName`. + - `setindex!(::Dict, val, ::VarName)`: set the realization corresponding to a particular `VarName`. + - `push!(::Dict, ::Pair)`: add a new key-value pair to the container. + - `delete!(::Dict, ::VarName)`: delete the realization corresponding to a particular `VarName`. + - `empty!(::Dict)`: delete all realizations in `metadata`. + - `merge(::Dict, ::Dict)`: merge two `metadata` structures according to similar rules as `Dict`. + +*But* for general-purpose samplers, we often want to work with a simple flattened structure, typically a `Vector{<:Real}`. Therefore we also want `varinfo` to be able to replicate the following functionality of `Vector{<:Real}`: + + - `getindex(::Vector{<:Real}, ::Int)`: return the i-th value in the flat representation of `metadata`. + + + For example, if `metadata` contains a realization of `x ~ MvNormal(zeros(3), I)`, then `getindex(varinfo, 1)` should return the realization of `x[1]`, `getindex(varinfo, 2)` should return the realization of `x[2]`, etc. + + - `setindex!(::Vector{<:Real}, val, ::Int)`: set the i-th value in the flat representation of `metadata`. + - `length(::Vector{<:Real})`: return the length of the flat representation of `metadata`. + - `similar(::Vector{<:Real})`: return a new instance with the same `eltype` as the input. 
+ +We also want some additional methods that are *not* part of the `Dict` or `Vector` interface: + + - `push!(container, varname::VarName, value[, transform])`: add a new element to the container, but with an optional transformation that has been applied to `value`, and should be reverted when returning `container[varname]`. One can also provide a `Pair` instead of a `VarName` and a `value`. + + - `update!(container, ::VarName, value[, transform])`: similar to `push!` but if the `VarName` is already present in the container, then we update the corresponding value instead of adding a new element. + +In addition, we want to be able to access the transformed / "unconstrained" realization for a particular `VarName` and so we also need corresponding methods for this: + + - `getindex_internal` and `setindex_internal!` for extracting and mutating the internal, possibly unconstrained, representation of a particular `VarName`. + +Finally, we want the underlying representation used in `metadata` to have a few performance-related properties: + + 1. Type-stable when possible, but functional when not. + 2. Efficient storage and iteration when possible, but functional when not. + +The "but functional when not" is important as we want to support arbitrary models, which means that we can't always have these performance properties. + +In the following sections, we'll outline how we achieve this in [`VarInfo`](@ref). + +## Type-stability + +Ensuring type-stability is somewhat non-trivial to address since we want this to be the case even when models mix continuous (typically `Float64`) and discrete (typically `Int`) variables. + +Suppose we have an implementation of `metadata` which implements the functionality outlined in the previous section. The way we approach this in `VarInfo` is to use a `NamedTuple` with a separate `metadata` *for each distinct `Symbol` used*. 
For example, if we have a model of the form + +```@example varinfo-design +using DynamicPPL, Distributions, FillArrays + +@model function demo() + x ~ product_distribution(Fill(Bernoulli(0.5), 2)) + y ~ Normal(0, 1) + return nothing +end +``` + +then we construct a type-stable representation by using a `NamedTuple{(:x, :y), Tuple{Vx, Vy}}` where + + - `Vx` is a container with `eltype` `Bool`, and + - `Vy` is a container with `eltype` `Float64`. + +Since `VarName` contains the `Symbol` used in its type, something like `getindex(varinfo, @varname(x))` can be resolved to `getindex(varinfo.metadata.x, @varname(x))` at compile-time. + +For example, with the model above we have + +```@example varinfo-design +# Type-unstable `VarInfo` +varinfo_untyped = DynamicPPL.untyped_varinfo( + demo(), SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata +) +typeof(varinfo_untyped.metadata) +``` + +```@example varinfo-design +# Type-stable `VarInfo` +varinfo_typed = DynamicPPL.typed_varinfo(demo()) +typeof(varinfo_typed.metadata) +``` + +They both work as expected but one results in concrete typing and the other does not: + +```@example varinfo-design +varinfo_untyped[@varname(x)], varinfo_untyped[@varname(y)] +``` + +```@example varinfo-design +varinfo_typed[@varname(x)], varinfo_typed[@varname(y)] +``` + +Notice that the untyped `VarInfo` uses `Vector{Real}` to store the boolean entries while the typed uses `Vector{Bool}`. This is because the untyped version needs the underlying container to be able to handle both the `Bool` for `x` and the `Float64` for `y`, while the typed version can use a `Vector{Bool}` for `x` and a `Vector{Float64}` for `y` due to its usage of `NamedTuple`. + +!!! warning + + Of course, this `NamedTuple` approach is *not* necessarily going to help us in scenarios where the `Symbol` does not correspond to a unique type, e.g. 
+ + ```julia + x[1] ~ Bernoulli(0.5) + x[2] ~ Normal(0, 1) + ``` + + In this case we'll end up with a `NamedTuple{(:x,), Tuple{Vx}}` where `Vx` is a container with `eltype` `Union{Bool, Float64}` or something worse. This is *not* type-stable but will still be functional. + + In practice, we rarely observe such mixing of types, therefore in DynamicPPL, and more widely in Turing.jl, we use a `NamedTuple` approach for type-stability with great success. + +!!! warning + + Another downside with such a `NamedTuple` approach is that a model with lots of tilde-statements, e.g. `a ~ Normal()`, `b ~ Normal()`, ..., `z ~ Normal()`, will result in a `NamedTuple` with 26 entries, potentially leading to long compilation times. + + For these scenarios it can be useful to fall back to "untyped" representations. + +Hence we obtain a "type-stable when possible"-representation by wrapping it in a `NamedTuple` and partially resolving the `getindex`, `setindex!`, etc. methods at compile-time. When type-stability is *not* desired, we can simply use a single `metadata` for all `VarName`s instead of a `NamedTuple` wrapping a collection of `metadata`s. + +## Efficient storage and iteration + +Efficient storage and iteration we achieve through implementation of the `metadata`. In particular, we do so with [`VarNamedVector`](@ref): + +```@docs +DynamicPPL.VarNamedVector +``` + +In a [`VarNamedVector{<:VarName,T}`](@ref), we achieve the desiderata by storing the values for different `VarName`s contiguously in a `Vector{T}` and keeping track of which ranges correspond to which `VarName`s. + +This does require a bit of book-keeping, in particular when it comes to insertions and deletions. Internally, this is handled by assigning each `VarName` a unique `Int` index in the `varname_to_index` field, which is then used to index into the following fields: + + - `varnames::Vector{<:VarName}`: the `VarName`s in the order they appear in the `Vector{T}`. 
+ - `ranges::Vector{UnitRange{Int}}`: the ranges of indices in the `Vector{T}` that correspond to each `VarName`. + - `transforms::Vector`: the transforms associated with each `VarName`. + +Mutating functions, e.g. `setindex!(vnv::VarNamedVector, val, vn::VarName)`, are then treated according to the following rules: + + 1. If `vn` is not already present: add it to the end of `vnv.varnames`, add the `val` to the underlying `vnv.vals`, etc. + + 2. If `vn` is already present in `vnv`: + + 1. If `val` has the *same length* as the existing value for `vn`: replace existing value. + 2. If `val` has a *smaller length* than the existing value for `vn`: replace existing value and mark the remaining indices as "inactive" by increasing the entry in the `vnv.num_inactive` field. + 3. If `val` has a *larger length* than the existing value for `vn`: expand the underlying `vnv.vals` to accommodate the new value, update all `VarName`s occurring after `vn`, and update the `vnv.ranges` to point to the new range for `vn`. + +This means that `VarNamedVector` is allowed to grow as needed, while "shrinking" (i.e. insertion of smaller elements) is handled by simply marking the redundant indices as "inactive". This turns out to be efficient for use-cases that we are generally interested in. + +For example, we want to optimize code-paths which effectively boil down to the inner loop in the following example: + +```julia +# Construct a `VarInfo` with types inferred from `model`. +varinfo = VarInfo(model) + +# Repeatedly sample from `model`. +for _ in 1:num_samples + rand!(rng, model, varinfo) + + # Do something with `varinfo`. + # ... +end +``` + +There are typically a few scenarios where we encounter changing representation sizes of a random variable `x`: + + 1. We're working with a transformed version of `x` which is represented in a lower-dimensional space, e.g. transforming a `x ~ LKJ(2, 1)` to unconstrained `y = f(x)` takes us from 2-by-2 `Matrix{Float64}` to a 1-length `Vector{Float64}`. + 2. 
`x` has a random size, e.g. in a mixture model with a prior on the number of components. Here the size of `x` can vary widely between every realization of the `Model`. + +In scenario (1), we're usually *shrinking* the representation of `x`, and so we end up not making any allocations for the underlying `Vector{T}` but instead just marking the redundant part as "inactive". + +In scenario (2), we end up increasing the allocated memory for the randomly sized `x`, eventually leading to a vector that is large enough to hold realizations without needing to reallocate. But this can still lead to unnecessary memory usage, which might be undesirable. Hence one has to make a decision regarding the trade-off between memory usage and performance for the use-case at hand. + +To help with this, we have the following functions: + +```@docs +DynamicPPL.has_inactive +DynamicPPL.num_inactive +DynamicPPL.num_allocated +DynamicPPL.is_contiguous +DynamicPPL.contiguify! +``` + +For example, one might encounter the following scenario: + +```@example varinfo-design +vnv = DynamicPPL.VarNamedVector(@varname(x) => [true]) +println("Before insertion: number of allocated entries $(DynamicPPL.num_allocated(vnv))") + +for i in 1:5 + x = fill(true, rand(1:100)) + DynamicPPL.update!(vnv, @varname(x), x) + println( + "After insertion #$(i) of length $(length(x)): number of allocated entries $(DynamicPPL.num_allocated(vnv))", + ) +end +``` + +We can then insert a call to [`DynamicPPL.contiguify!`](@ref) after every insertion whenever the allocation grows too large to reduce overall memory usage: + +```@example varinfo-design +vnv = DynamicPPL.VarNamedVector(@varname(x) => [true]) +println("Before insertion: number of allocated entries $(DynamicPPL.num_allocated(vnv))") + +for i in 1:5 + x = fill(true, rand(1:100)) + DynamicPPL.update!(vnv, @varname(x), x) + if DynamicPPL.num_allocated(vnv) > 10 + DynamicPPL.contiguify!(vnv) + end + println( + "After insertion #$(i) of length $(length(x)): number of 
allocated entries $(DynamicPPL.num_allocated(vnv))", + ) +end +``` + +This does incur a runtime cost as it requires re-allocation of the `ranges` in addition to a `resize!` of the underlying `Vector{T}`. However, this also ensures that the underlying `Vector{T}` is contiguous, which is important for performance. Hence, if we're about to do a lot of work with the `VarNamedVector` without insertions, etc., it can be worth it to do a sweep to ensure that the underlying `Vector{T}` is contiguous. + +!!! note + + Higher-dimensional arrays, e.g. `Matrix`, are handled by simply vectorizing them before storing them in the `Vector{T}`, and composing the `VarName`'s transformation with a `DynamicPPL.ReshapeTransform`. + +Continuing from the example from the previous section, we can use a `VarInfo` with a `VarNamedVector` as the `metadata` field: + +```@example varinfo-design +# Type-unstable +varinfo_untyped_vnv = DynamicPPL.VectorVarInfo(varinfo_untyped) +varinfo_untyped_vnv[@varname(x)], varinfo_untyped_vnv[@varname(y)] +``` + +```@example varinfo-design +# Type-stable +varinfo_typed_vnv = DynamicPPL.VectorVarInfo(varinfo_typed) +varinfo_typed_vnv[@varname(x)], varinfo_typed_vnv[@varname(y)] +``` + +If we now try to `delete!` `@varname(x)` + +```@example varinfo-design +haskey(varinfo_untyped_vnv, @varname(x)) +``` + +```@example varinfo-design +DynamicPPL.has_inactive(varinfo_untyped_vnv.metadata) +``` + +```@example varinfo-design +# `delete!` +DynamicPPL.delete!(varinfo_untyped_vnv.metadata, @varname(x)) +DynamicPPL.has_inactive(varinfo_untyped_vnv.metadata) +``` + +```@example varinfo-design +haskey(varinfo_untyped_vnv, @varname(x)) +``` + +Or insert a differently-sized value for `@varname(x)` + +```@example varinfo-design +DynamicPPL.update!(varinfo_untyped_vnv.metadata, @varname(x), fill(true, 1)) +varinfo_untyped_vnv[@varname(x)] +``` + +```@example varinfo-design +DynamicPPL.num_allocated(varinfo_untyped_vnv.metadata, @varname(x)) +``` + +```@example 
varinfo-design +DynamicPPL.update!(varinfo_untyped_vnv.metadata, @varname(x), fill(true, 4)) +varinfo_untyped_vnv[@varname(x)] +``` + +```@example varinfo-design +DynamicPPL.num_allocated(varinfo_untyped_vnv.metadata, @varname(x)) +``` + +### Performance summary + +In the end, we have the following "rough" performance characteristics for `VarNamedVector`: + +| Method | Is blazingly fast? | +|:----------------------------------------:|:--------------------------------------------------------------------------------------------:| +| `getindex` | ${\color{green} \checkmark}$ | +| `setindex!` | ${\color{green} \checkmark}$ | +| `push!` | ${\color{green} \checkmark}$ | +| `delete!` | ${\color{red} \times}$ | +| `update!` on existing `VarName` | ${\color{green} \checkmark}$ if smaller or same size / ${\color{red} \times}$ if larger size | +| `values_as(::VarNamedVector, Vector{T})` | ${\color{green} \checkmark}$ if contiguous / ${\color{orange} \div}$ otherwise | + +## Other methods + +```@docs +DynamicPPL.replace_values(::VarNamedVector, vals::AbstractVector) +``` + +```@docs; canonical=false +DynamicPPL.values_as(::VarNamedVector) +``` diff --git a/ext/DynamicPPLChainRulesCoreExt.jl b/ext/DynamicPPLChainRulesCoreExt.jl index 1c6e188fb..1559467f8 100644 --- a/ext/DynamicPPLChainRulesCoreExt.jl +++ b/ext/DynamicPPLChainRulesCoreExt.jl @@ -24,4 +24,6 @@ ChainRulesCore.@non_differentiable DynamicPPL.updategid!( # No need + causes issues for some AD backends, e.g. Zygote. 
ChainRulesCore.@non_differentiable DynamicPPL.infer_nested_eltype(x) +ChainRulesCore.@non_differentiable DynamicPPL.recontiguify_ranges!(ranges) + end # module diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 7c7fb216d..c91fb1fe0 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -42,6 +42,65 @@ function DynamicPPL.varnames(c::MCMCChains.Chains) return keys(c.info.varname_to_symbol) end +""" + generated_quantities(model::Model, chain::MCMCChains.Chains) + +Execute `model` for each of the samples in `chain` and return an array of the values +returned by the `model` for each sample. + +# Examples +## General +Often you might have additional quantities computed inside the model that you want to +inspect, e.g. +```julia +@model function demo(x) + # sample and observe + θ ~ Prior() + x ~ Likelihood() + return interesting_quantity(θ, x) +end +m = demo(data) +chain = sample(m, alg, n) +# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples +# from the posterior/`chain`: +generated_quantities(m, chain) # <= results in a `Vector` of returned values + # from `interesting_quantity(θ, x)` +``` +## Concrete (and simple) +```julia +julia> using DynamicPPL, Turing + +julia> @model function demo(xs) + s ~ InverseGamma(2, 3) + m_shifted ~ Normal(10, √s) + m = m_shifted - 10 + + for i in eachindex(xs) + xs[i] ~ Normal(m, √s) + end + + return (m, ) + end +demo (generic function with 1 method) + +julia> model = demo(randn(10)); + +julia> chain = sample(model, MH(), 10); + +julia> generated_quantities(model, chain) +10×1 Array{Tuple{Float64},2}: + (2.1964758025119338,) + (2.1964758025119338,) + (0.09270081916291417,) + (0.09270081916291417,) + (0.09270081916291417,) + (0.09270081916291417,) + (0.09270081916291417,) + (0.043088571494005024,) + (-0.16489786710222099,) + (-0.16489786710222099,) +``` +""" function DynamicPPL.generated_quantities( model::DynamicPPL.Model, chain_full::MCMCChains.Chains 
) @@ -49,14 +108,86 @@ function DynamicPPL.generated_quantities( varinfo = DynamicPPL.VarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) return map(iters) do (sample_idx, chain_idx) - # Update the varinfo with the current sample and make variables not present in `chain` - # to be sampled. - DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx) + if DynamicPPL.supports_varname_indexing(chain) + varname_pairs = _varname_pairs_with_varname_indexing( + chain, varinfo, sample_idx, chain_idx + ) + else + varname_pairs = _varname_pairs_without_varname_indexing( + chain, varinfo, sample_idx, chain_idx + ) + end + fixed_model = DynamicPPL.fix(model, Dict(varname_pairs)) + return fixed_model() + end +end + +""" + _varname_pairs_with_varname_indexing( + chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx + ) - # TODO: Some of the variables can be a view into the `varinfo`, so we need to - # `deepcopy` the `varinfo` before passing it to `model`. - model(deepcopy(varinfo)) +Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values +from the chain. + +This implementation assumes `chain` can be indexed using variable names, and is the +preferred implementation. +""" +function _varname_pairs_with_varname_indexing( + chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx +) + vns = DynamicPPL.varnames(chain) + vn_parents = Iterators.map(vns) do vn + # The call nested_setindex_maybe! is used to handle cases where vn is not + # the variable name used in the model, but rather subsumed by one. Except + # for the subsumption part, this could be + # vn => getindex_varname(chain, sample_idx, vn, chain_idx) + # TODO(mhauru) This call to nested_setindex_maybe! is unintuitive. 
+ DynamicPPL.nested_setindex_maybe!( + varinfo, DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx), vn + ) end + varname_pairs = Iterators.map(Iterators.filter(!isnothing, vn_parents)) do vn_parent + vn_parent => varinfo[vn_parent] + end + return varname_pairs +end + +""" +Check which keys in `key_strings` are subsumed by `vn_string` and return their values. + +The subsumption check is done with `DynamicPPL.subsumes_string`, which is quite weak, and +won't catch all cases. We should get rid of this if we can. +""" +# TODO(mhauru) See docstring above. +function _vcat_subsumed_values(vn_string, values, key_strings) + indices = findall(Base.Fix1(DynamicPPL.subsumes_string, vn_string), key_strings) + return !isempty(indices) ? reduce(vcat, values[indices]) : nothing +end + +""" + _varname_pairs_without_varname_indexing( + chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx + ) + +Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values +from the chain. + +This implementation does not assume that `chain` can be indexed using variable names. It is +thus not guaranteed to work in cases where the variable names have complex subsumption +patterns, such as if the model has a variable `x` but the chain stores `x.a[1]`. 
+""" +function _varname_pairs_without_varname_indexing( + chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx +) + values = chain.value[sample_idx, :, chain_idx] + keys = Base.keys(chain) + keys_strings = map(string, keys) + varname_pairs = [ + vn => _vcat_subsumed_values(string(vn), values, keys_strings) for + vn in Base.keys(varinfo) + ] + return varname_pairs end end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index eb027b45b..969d69936 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -45,7 +45,9 @@ export AbstractVarInfo, VarInfo, UntypedVarInfo, TypedVarInfo, + VectorVarInfo, SimpleVarInfo, + VarNamedVector, push!!, empty!!, subset, @@ -175,6 +177,7 @@ include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") +include("varnamedvector.jl") include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 7ddd09b2e..3f513d71d 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -295,7 +295,7 @@ julia> # Just use an example model to construct the `VarInfo` because we're lazy julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; julia> # For the sake of brevity, let's just check the type. - md = values_as(vi); md.s isa DynamicPPL.Metadata + md = values_as(vi); md.s isa Union{DynamicPPL.Metadata, DynamicPPL.VarNamedVector} true julia> values_as(vi, NamedTuple) @@ -321,7 +321,7 @@ julia> # Just use an example model to construct the `VarInfo` because we're lazy julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; julia> # For the sake of brevity, let's just check the type. - values_as(vi) isa DynamicPPL.Metadata + values_as(vi) isa Union{DynamicPPL.Metadata, Vector} true julia> values_as(vi, NamedTuple) @@ -349,7 +349,7 @@ Determine the default `eltype` of the values returned by `vi[spl]`. 
This should generally not be called explicitly, as it's only used in [`matchingvalue`](@ref) to determine the default type to use in place of type-parameters passed to the model. - + This method is considered legacy, and is likely to be deprecated in the future. """ function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior}) @@ -363,6 +363,13 @@ function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromP return eltype(T) end +""" + has_varnamedvector(varinfo::VarInfo) + +Returns `true` if `varinfo` uses `VarNamedVector` as metadata. +""" +has_varnamedvector(vi::AbstractVarInfo) = false + # TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert # the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which # might result in a `Vector{Any}`. @@ -554,7 +561,7 @@ end link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) link([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) -Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`. +Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`. If `t` is not provided, `default_transformation(model, vi)` will be used. @@ -573,7 +580,7 @@ end invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) -Transform the variables in `vi` to their constrained space, using the (inverse of) +Transform the variables in `vi` to their constrained space, using the (inverse of) transformation `t`, mutating `vi` if possible. If `t` is not provided, `default_transformation(model, vi)` will be used. 
diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 13231837f..1961965ca 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -240,7 +240,10 @@ function assume( if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") - unset_flag!(vi, vn, "del") + # TODO(mhauru) Is it important to unset the flag here? The `true` allows us + # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if + # that's okay. + unset_flag!(vi, vn, "del", true) r = init(rng, dist, sampler) f = to_maybe_linked_internal_transform(vi, vn, dist) BangBang.setindex!!(vi, f(r), vn) @@ -516,7 +519,10 @@ function get_and_set_val!( if haskey(vi, vns[1]) # Always overwrite the parameters with new ones for `SampleFromUniform`. if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - unset_flag!(vi, vns[1], "del") + # TODO(mhauru) Is it important to unset the flag here? The `true` allows us + # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if + # that's okay. + unset_flag!(vi, vns[1], "del", true) r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] @@ -554,7 +560,10 @@ function get_and_set_val!( if haskey(vi, vns[1]) # Always overwrite the parameters with new ones for `SampleFromUniform`. if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - unset_flag!(vi, vns[1], "del") + # TODO(mhauru) Is it important to unset the flag here? The `true` allows us + # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if + # that's okay. 
+ unset_flag!(vi, vns[1], "del", true) f = (vn, dist) -> init(rng, dist, spl) r = f.(vns, dists) for i in eachindex(vns) diff --git a/src/model.jl b/src/model.jl index 082ec3871..2a1a6db88 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1203,74 +1203,6 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC end end -""" - generated_quantities(model::Model, chain::AbstractChains) - -Execute `model` for each of the samples in `chain` and return an array of the values -returned by the `model` for each sample. - -# Examples -## General -Often you might have additional quantities computed inside the model that you want to -inspect, e.g. -```julia -@model function demo(x) - # sample and observe - θ ~ Prior() - x ~ Likelihood() - return interesting_quantity(θ, x) -end -m = demo(data) -chain = sample(m, alg, n) -# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples -# from the posterior/`chain`: -generated_quantities(m, chain) # <= results in a `Vector` of returned values - # from `interesting_quantity(θ, x)` -``` -## Concrete (and simple) -```julia -julia> using DynamicPPL, Turing - -julia> @model function demo(xs) - s ~ InverseGamma(2, 3) - m_shifted ~ Normal(10, √s) - m = m_shifted - 10 - - for i in eachindex(xs) - xs[i] ~ Normal(m, √s) - end - - return (m, ) - end -demo (generic function with 1 method) - -julia> model = demo(randn(10)); - -julia> chain = sample(model, MH(), 10); - -julia> generated_quantities(model, chain) -10×1 Array{Tuple{Float64},2}: - (2.1964758025119338,) - (2.1964758025119338,) - (0.09270081916291417,) - (0.09270081916291417,) - (0.09270081916291417,) - (0.09270081916291417,) - (0.09270081916291417,) - (0.043088571494005024,) - (-0.16489786710222099,) - (-0.16489786710222099,) -``` -""" -function generated_quantities(model::Model, chain::AbstractChains) - varinfo = VarInfo(model) - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - return map(iters) do (sample_idx, chain_idx) - 
setval_and_resample!(varinfo, chain, sample_idx, chain_idx) - model(varinfo) - end -end - """ generated_quantities(model::Model, parameters::NamedTuple) generated_quantities(model::Model, values, keys) @@ -1297,7 +1229,7 @@ demo (generic function with 2 methods) julia> model = demo(randn(10)); -julia> parameters = (; s = 1.0, m_shifted=10); +julia> parameters = (; s = 1.0, m_shifted=10.0); julia> generated_quantities(model, parameters) (0.0,) @@ -1307,13 +1239,10 @@ julia> generated_quantities(model, values(parameters), keys(parameters)) ``` """ function generated_quantities(model::Model, parameters::NamedTuple) - varinfo = VarInfo(model) - setval_and_resample!(varinfo, values(parameters), keys(parameters)) - return model(varinfo) + fixed_model = fix(model, parameters) + return fixed_model() end function generated_quantities(model::Model, values, keys) - varinfo = VarInfo(model) - setval_and_resample!(varinfo, values, keys) - return model(varinfo) + return generated_quantities(model, NamedTuple{keys}(values)) end diff --git a/src/sampler.jl b/src/sampler.jl index cfc58942e..833aaf7e2 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -150,7 +150,7 @@ function set_values!!( flattened_param_vals = varinfo[spl] length(flattened_param_vals) == length(initial_params) || throw( DimensionMismatch( - "Provided initial value size ($(length(initial_params))) doesn't match the model size ($(length(theta)))", + "Provided initial value size ($(length(initial_params))) doesn't match the model size ($(length(flattened_param_vals)))", ), ) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index d8afb9cec..06a151f82 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -322,15 +322,17 @@ Base.getindex(vi::SimpleVarInfo, vn::VarName) = getindex_internal(vi, vn) function Base.getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) return map(Base.Fix1(getindex, vi), vns) end -# HACK: Needed to disambiguiate. +# HACK: Needed to disambiguate. 
Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) Base.getindex(svi::SimpleVarInfo, ::Colon) = values_as(svi, Vector) getindex_internal(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) # `AbstractDict` -function getindex_internal(vi::SimpleVarInfo{<:AbstractDict}, vn::VarName) - return nested_getindex(vi.values, vn) +function getindex_internal( + vi::SimpleVarInfo{<:Union{AbstractDict,VarNamedVector}}, vn::VarName +) + return getvalue(vi.values, vn) end Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn) @@ -399,14 +401,28 @@ end function BangBang.push!!( vi::SimpleVarInfo{<:AbstractDict}, vn::VarName, - r, + value, dist::Distribution, gidset::Set{Selector}, ) - vi.values[vn] = r + vi.values[vn] = value return vi end +function BangBang.push!!( + vi::SimpleVarInfo{<:VarNamedVector}, + vn::VarName, + value, + dist::Distribution, + gidset::Set{Selector}, +) + # The semantics of push!! for SimpleVarInfo and VarNamedVector are different. For + # SimpleVarInfo, push!! allows the key to exist already, for VarNamedVector it does not. + # Hence we need to call update!! here, which has the same semantics as push!! does for + # SimpleVarInfo. 
+ return Accessors.@set vi.values = update!!(vi.values, vn, value) +end + const SimpleOrThreadSafeSimple{T,V,C} = Union{ SimpleVarInfo{T,V,C},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,C}} } @@ -456,6 +472,8 @@ function _subset(x::NamedTuple, vns) return NamedTuple{Tuple(syms)}(Tuple(map(Base.Fix1(getindex, x), syms))) end +_subset(x::VarNamedVector, vns) = subset(x, vns) + # `merge` function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) values = merge(varinfo_left.values, varinfo_right.values) @@ -563,6 +581,9 @@ end function values_as(vi::SimpleVarInfo{<:AbstractDict}, ::Type{NamedTuple}) return NamedTuple((Symbol(k), v) for (k, v) in vi.values) end +function values_as(vi::SimpleVarInfo, ::Type{T}) where {T} + return values_as(vi.values, T) +end """ logjoint(model::Model, θ) @@ -708,3 +729,5 @@ end function ThreadSafeVarInfo(vi::SimpleVarInfo{<:Any,<:Ref}) return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()]) end + +has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector diff --git a/src/test_utils.jl b/src/test_utils.jl index 6f7481c40..9a606b4ef 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -37,20 +37,35 @@ function setup_varinfos( model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false ) # VarInfo - vi_untyped = VarInfo() - model(vi_untyped) - vi_typed = DynamicPPL.TypedVarInfo(vi_untyped) + vi_untyped_metadata = VarInfo(DynamicPPL.Metadata()) + vi_untyped_vnv = VarInfo(DynamicPPL.VarNamedVector()) + model(vi_untyped_metadata) + model(vi_untyped_vnv) + vi_typed_metadata = DynamicPPL.TypedVarInfo(vi_untyped_metadata) + vi_typed_vnv = DynamicPPL.TypedVarInfo(vi_untyped_vnv) + # SimpleVarInfo svi_typed = SimpleVarInfo(example_values) svi_untyped = SimpleVarInfo(OrderedDict()) + svi_vnv = SimpleVarInfo(VarNamedVector()) # SimpleVarInfo{<:Any,<:Ref} svi_typed_ref = SimpleVarInfo(example_values, Ref(getlogp(svi_typed))) svi_untyped_ref = SimpleVarInfo(OrderedDict(), 
Ref(getlogp(svi_untyped))) + svi_vnv_ref = SimpleVarInfo(VarNamedVector(), Ref(getlogp(svi_vnv))) - lp = getlogp(vi_typed) + lp = getlogp(vi_typed_metadata) varinfos = map(( - vi_untyped, vi_typed, svi_typed, svi_untyped, svi_typed_ref, svi_untyped_ref + vi_untyped_metadata, + vi_untyped_vnv, + vi_typed_metadata, + vi_typed_vnv, + svi_typed, + svi_untyped, + svi_vnv, + svi_typed_ref, + svi_untyped_ref, + svi_vnv_ref, )) do vi # Set them all to the same values. DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 4fbf0d124..ec890a674 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -55,6 +55,8 @@ function setlogp!!(vi::ThreadSafeVarInfoWithRef, logp) return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), vi.logps) end +has_varnamedvector(vi::DynamicPPL.ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) + function BangBang.push!!( vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) @@ -188,8 +190,10 @@ end values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo) values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T) -function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) - return unset_flag!(vi.varinfo, vn, flag) +function unset_flag!( + vi::ThreadSafeVarInfo, vn::VarName, flag::String, ignoreable::Bool=false +) + return unset_flag!(vi.varinfo, vn, flag, ignoreable) end function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String) return is_flagged(vi.varinfo, vn, flag) diff --git a/src/utils.jl b/src/utils.jl index 9ddeb6247..fed303021 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -48,7 +48,7 @@ true i.e., regardless of whether you evaluate the log prior, the log likelihood or the log joint density. If you would like to avoid this behaviour you should check the evaluation context. It can be accessed with the internal variable `__context__`. 
- For instance, in the following example the log density is not accumulated when only the log prior is computed: + For instance, in the following example the log density is not accumulated when only the log prior is computed: ```jldoctest; setup = :(using Distributions) julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x); @@ -225,21 +225,45 @@ invlink_transform(dist) = inverse(link_transform(dist)) # Helper functions for vectorize/reconstruct values # ##################################################### -# Useful transformation going from the flattened representation. -struct FromVec{Size} <: Bijectors.Bijector +""" + UnwrapSingletonTransform + +A transformation that unwraps a singleton array into a scalar. + +This transformation can be inverted by calling `tovec`. +""" +struct UnwrapSingletonTransform <: Bijectors.Bijector end + +(f::UnwrapSingletonTransform)(x) = only(x) + +Bijectors.with_logabsdet_jacobian(f::UnwrapSingletonTransform, x) = (f(x), 0) +function Bijectors.with_logabsdet_jacobian( + ::Bijectors.Inverse{<:UnwrapSingletonTransform}, x +) + return (tovec(x), 0) +end + +""" + ReshapeTransform(size::Size) + +A `Bijector` that transforms an `AbstractVector` to a realization of size `size`. + +This transformation can be inverted by calling `tovec`. +""" +struct ReshapeTransform{Size} <: Bijectors.Bijector size::Size end -FromVec(x::Union{Real,AbstractArray}) = FromVec(size(x)) +ReshapeTransform(x::AbstractArray) = ReshapeTransform(size(x)) # TODO: Should we materialize the `reshape`? -(f::FromVec)(x) = reshape(x, f.size) -(f::FromVec{Tuple{}})(x) = only(x) -# TODO: Specialize for `Tuple{<:Any}` since this correspond to a `Vector`. +(f::ReshapeTransform)(x) = reshape(x, f.size) -Bijectors.with_logabsdet_jacobian(f::FromVec, x) = (f(x), 0) -# We want to use the inverse of `FromVec` so it preserves the size information. 
-Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:FromVec}, x) = (tovec(x), 0) +Bijectors.with_logabsdet_jacobian(f::ReshapeTransform, x) = (f(x), 0) +# We want to use the inverse of `ReshapeTransform` so it preserves the size information. +function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ReshapeTransform}, x) + return (tovec(x), 0) +end struct ToChol <: Bijectors.Bijector uplo::Char @@ -253,16 +277,18 @@ Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y::Cholesky) = Return the transformation from the vector representation of `x` to original representation. """ -from_vec_transform(x::Union{Real,AbstractArray}) = from_vec_transform_for_size(size(x)) -from_vec_transform(C::Cholesky) = ToChol(C.uplo) ∘ FromVec(size(C.UL)) +from_vec_transform(x::AbstractArray) = from_vec_transform_for_size(size(x)) +from_vec_transform(C::Cholesky) = ToChol(C.uplo) ∘ ReshapeTransform(size(C.UL)) +from_vec_transform(::Real) = UnwrapSingletonTransform() """ from_vec_transform_for_size(sz::Tuple) -Return the transformation from the vector representation of a realization of size `sz` to original representation. +Return the transformation from the vector representation of a realization of size `sz` to +original representation. """ -from_vec_transform_for_size(sz::Tuple) = FromVec(sz) -from_vec_transform_for_size(::Tuple{()}) = FromVec(()) +from_vec_transform_for_size(sz::Tuple) = ReshapeTransform(sz) +# TODO(mhauru) Is the below used? If not, this function can be removed. from_vec_transform_for_size(::Tuple{<:Any}) = identity """ @@ -272,7 +298,8 @@ Return the transformation from the vector representation of a realization from distribution `dist` to the original representation compatible with `dist`. 
""" from_vec_transform(dist::Distribution) = from_vec_transform_for_size(size(dist)) -from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) ∘ FromVec(size(dist)) +from_vec_transform(::UnivariateDistribution) = UnwrapSingletonTransform() +from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) ∘ ReshapeTransform(size(dist)) """ from_vec_transform(f, size::Tuple) @@ -300,6 +327,17 @@ function from_linked_vec_transform(dist::Distribution) return f_invlink ∘ f_vec end +# UnivariateDistributions need to be handled as a special case, because size(dist) is (), +# which makes the usual machinery think we are dealing with a 0-dim array, whereas in +# actuality we are dealing with a scalar. +# TODO(mhauru) Hopefully all this can go once the old Gibbs sampler is removed and +# VarNamedVector takes over from Metadata. +function from_linked_vec_transform(dist::UnivariateDistribution) + f_invlink = invlink_transform(dist) + f_vec = from_vec_transform(inverse(f_invlink), size(dist)) + return UnwrapSingletonTransform() ∘ f_invlink ∘ f_vec +end + # Specializations that circumvent the `from_vec_transform` machinery. function from_linked_vec_transform(dist::LKJCholesky) return inverse(Bijectors.VecCholeskyBijector(dist.uplo)) @@ -854,6 +892,7 @@ end Return type corresponding to `float(typeof(x))` if possible; otherwise return `Real`. """ float_type_with_fallback(::Type) = Real +float_type_with_fallback(::Type{Union{}}) = Real float_type_with_fallback(::Type{T}) where {T<:Real} = float(T) """ diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 52ba6eb61..c5003d53a 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -177,7 +177,7 @@ julia> # Approach 1: Convert back to constrained space using `invlink` and extra julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions # used in the very first model evaluation, hence the support of `y` # is not updated even though `x` has changed. 
- lb ≤ varinfo_invlinked[@varname(y)] ≤ ub + lb ≤ first(varinfo_invlinked[@varname(y)]) ≤ ub false julia> # Approach 2: Extract realizations using `values_as_in_model`. diff --git a/src/varinfo.jl b/src/varinfo.jl index 2670397d9..8b548cc14 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -101,6 +101,7 @@ struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo logp::Base.RefValue{Tlogp} num_produce::Base.RefValue{Int} end +const VectorVarInfo = VarInfo{<:VarNamedVector} const UntypedVarInfo = VarInfo{<:Metadata} const TypedVarInfo = VarInfo{<:NamedTuple} const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ @@ -119,6 +120,49 @@ function VarInfo(old_vi::VarInfo, spl, x::AbstractVector) ) end +# No-op if we're already working with a `VarNamedVector`. +metadata_to_varnamedvector(vnv::VarNamedVector) = vnv +function metadata_to_varnamedvector(md::Metadata) + idcs = copy(md.idcs) + vns = copy(md.vns) + ranges = copy(md.ranges) + vals = copy(md.vals) + is_unconstrained = map(Base.Fix1(istrans, md), md.vns) + transforms = map(md.dists, is_unconstrained) do dist, trans + if trans + return from_linked_vec_transform(dist) + else + return from_vec_transform(dist) + end + end + + return VarNamedVector( + OrderedDict{eltype(keys(idcs)),Int}(idcs), + vns, + ranges, + vals, + transforms, + is_unconstrained, + ) +end + +function VectorVarInfo(vi::UntypedVarInfo) + md = metadata_to_varnamedvector(vi.metadata) + lp = getlogp(vi) + return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) +end + +function VectorVarInfo(vi::TypedVarInfo) + md = map(metadata_to_varnamedvector, vi.metadata) + lp = getlogp(vi) + return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) +end + +function has_varnamedvector(vi::VarInfo) + return vi.metadata isa VarNamedVector || + (vi isa TypedVarInfo && any(Base.Fix2(isa, VarNamedVector), values(vi.metadata))) +end + """ untyped_varinfo([rng, ]model[, sampler, context]) @@ -129,11 +173,12 @@ function untyped_varinfo( model::Model, 
sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), + metadata_type::Type=VarNamedVector, ) - varinfo = VarInfo() + varinfo = VarInfo(metadata_type()) return last(evaluate!!(model, varinfo, SamplingContext(rng, sampler, context))) end -function untyped_varinfo(model::Model, args::Union{AbstractSampler,AbstractContext}...) +function untyped_varinfo(model::Model, args::Union{AbstractSampler,AbstractContext,Type}...) return untyped_varinfo(Random.default_rng(), model, args...) end @@ -149,15 +194,53 @@ function VarInfo( model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), + # TODO(mhauru) Revisit the default. We probably don't want it to be VarNamedVector just + # yet. + metadata_type::Type=VarNamedVector, ) - return typed_varinfo(rng, model, sampler, context) + return typed_varinfo(rng, model, sampler, context, metadata_type) end VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...) unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, SampleFromPrior(), x) # TODO: deprecate. -unflatten(vi::VarInfo, spl::AbstractSampler, x::AbstractVector) = VarInfo(vi, spl, x) +function unflatten(vi::VarInfo, spl::AbstractSampler, x::AbstractVector) + md = unflatten(vi.metadata, spl, x) + return VarInfo(md, Base.RefValue{eltype(x)}(getlogp(vi)), Ref(get_num_produce(vi))) +end + +# The Val(getspace(spl)) is used to dispatch into the below generated function. 
+function unflatten(metadata::NamedTuple, spl::AbstractSampler, x::AbstractVector) + return unflatten(metadata, Val(getspace(spl)), x) +end + +@generated function unflatten( + metadata::NamedTuple{names}, ::Val{space}, x +) where {names,space} + exprs = [] + offset = :(0) + for f in names + mdf = :(metadata.$f) + if inspace(f, space) || length(space) == 0 + len = :(sum(length, $mdf.ranges)) + push!(exprs, :($f = unflatten($mdf, x[($offset + 1):($offset + $len)]))) + offset = :($offset + $len) + else + push!(exprs, :($f = $mdf)) + end + end + length(exprs) == 0 && return :(NamedTuple()) + return :($(exprs...),) +end + +# For Metadata unflatten and replace_values are the same. For VarNamedVector they are not. +function unflatten(md::Metadata, x::AbstractVector) + return replace_values(md, x) +end +function unflatten(md::Metadata, spl::AbstractSampler, x::AbstractVector) + return replace_values(md, spl, x) +end # without AbstractSampler function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) @@ -256,13 +339,22 @@ end function subset(varinfo::UntypedVarInfo, vns::AbstractVector{<:VarName}) metadata = subset(varinfo.metadata, vns) - return VarInfo(metadata, varinfo.logp, varinfo.num_produce) + return VarInfo(metadata, deepcopy(varinfo.logp), deepcopy(varinfo.num_produce)) +end + +function subset(varinfo::VectorVarInfo, vns::AbstractVector{<:VarName}) + metadata = subset(varinfo.metadata, vns) + return VarInfo(metadata, deepcopy(varinfo.logp), deepcopy(varinfo.num_produce)) end function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName{sym}}) where {sym} # If all the variables are using the same symbol, then we can just extract that field from the metadata. 
metadata = subset(getfield(varinfo.metadata, sym), vns) - return VarInfo(NamedTuple{(sym,)}(tuple(metadata)), varinfo.logp, varinfo.num_produce) + return VarInfo( + NamedTuple{(sym,)}(tuple(metadata)), + deepcopy(varinfo.logp), + deepcopy(varinfo.num_produce), + ) end function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName}) @@ -271,7 +363,9 @@ function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName}) subset(getfield(varinfo.metadata, sym), filter(==(sym) ∘ getsym, vns)) end - return VarInfo(NamedTuple{syms}(metadatas), varinfo.logp, varinfo.num_produce) + return VarInfo( + NamedTuple{syms}(metadatas), deepcopy(varinfo.logp), deepcopy(varinfo.num_produce) + ) end function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName}) @@ -338,6 +432,10 @@ function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) ) end +function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector) + return merge(vnv_left, vnv_right) +end + @generated function merge_metadata( metadata_left::NamedTuple{names_left}, metadata_right::NamedTuple{names_right} ) where {names_left,names_right} @@ -528,6 +626,10 @@ Return the distribution from which `vn` was sampled in `vi`. """ getdist(vi::VarInfo, vn::VarName) = getdist(getmetadata(vi, vn), vn) getdist(md::Metadata, vn::VarName) = md.dists[getidx(md, vn)] +# TODO(mhauru) Remove this once the old Gibbs sampler stuff is gone. +function getdist(::VarNamedVector, ::VarName) + throw(ErrorException("getdist does not exist for VarNamedVector")) +end getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi, vn), vn) # TODO(torfjelde): Use `view` instead of `getindex`. 
Requires addressing type-stability issues though, @@ -571,6 +673,7 @@ function getall(md::Metadata) Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0) ) end +getall(vnv::VarNamedVector) = getindex_internal(vnv, Colon()) """ setall!(vi::VarInfo, val) @@ -586,6 +689,12 @@ function _setall!(metadata::Metadata, val) metadata.vals[r] .= val[r] end end +function _setall!(vnv::VarNamedVector, val) + # TODO(mhauru) Do something more efficient here. + for i in 1:length(vnv) + vnv[i] = val[i] + end +end @generated function _setall!(metadata::NamedTuple{names}, val) where {names} expr = Expr(:block) start = :(1) @@ -698,7 +807,7 @@ end length(exprs) == 0 && return :(NamedTuple()) return :($(exprs...),) end -@inline function findinds(f_meta, s, ::Val{space}) where {space} +@inline function findinds(f_meta::Metadata, s, ::Val{space}) where {space} # Get all the idcs of the vns in `space` and that belong to the selector `s` return filter( (i) -> @@ -707,11 +816,27 @@ end 1:length(f_meta.gids), ) end -@inline function findinds(f_meta) +@inline function findinds(f_meta::Metadata) # Get all the idcs of the vns return filter((i) -> isempty(f_meta.gids[i]), 1:length(f_meta.gids)) end +function findinds(vnv::VarNamedVector, ::Selector, ::Val{space}) where {space} + # New Metadata objects are created with an empty list of gids, which is intrepreted as + # all Selectors applying to all variables. We assume the same behavior for + # VarNamedVector, and thus ignore the Selector argument. 
+ if space !== () + msg = "VarNamedVector does not support selecting variables based on samplers" + throw(ErrorException(msg)) + else + return findinds(vnv) + end +end + +function findinds(vnv::VarNamedVector) + return 1:length(vnv.varnames) +end + # Get all vns of variables belonging to spl _getvns(vi::VarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl))) function _getvns(vi::VarInfo, spl::Union{SampleFromPrior,SampleFromUniform}) @@ -727,7 +852,7 @@ end @generated function _getvns(metadata, idcs::NamedTuple{names}) where {names} exprs = [] for f in names - push!(exprs, :($f = metadata.$f.vns[idcs.$f])) + push!(exprs, :($f = Base.keys(metadata.$f)[idcs.$f])) end length(exprs) == 0 && return :(NamedTuple()) return :($(exprs...),) @@ -774,6 +899,8 @@ end return results end +# TODO(mhauru) These set_flag! methods return the VarInfo. They should probably be called +# set_flag!!. """ set_flag!(vi::VarInfo, vn::VarName, flag::String) @@ -787,13 +914,32 @@ function set_flag!(md::Metadata, vn::VarName, flag::String) return md.flags[flag][getidx(md, vn)] = true end +function set_flag!(vnv::VarNamedVector, ::VarName, flag::String) + if flag == "del" + # The "del" flag is effectively always set for a VarNamedVector, so this is a no-op. + else + throw(ErrorException("Flag $flag not valid for VarNamedVector")) + end + return vnv +end + #### #### APIs for typed and untyped VarInfo #### # VarInfo -VarInfo(meta=Metadata()) = VarInfo(meta, Ref{Float64}(0.0), Ref(0)) +# TODO(mhauru) Revisit the default for meta. We probably should keep it as Metadata as long +# as the old Gibbs sampler is in use. 
+VarInfo(meta=VarNamedVector()) = VarInfo(meta, Ref{Float64}(0.0), Ref(0)) + +function TypedVarInfo(vi::VectorVarInfo) + new_metas = group_by_symbol(vi.metadata) + logp = getlogp(vi) + num_produce = get_num_produce(vi) + nt = NamedTuple(new_metas) + return VarInfo(nt, Ref(logp), Ref(num_produce)) +end """ TypedVarInfo(vi::UntypedVarInfo) @@ -905,8 +1051,14 @@ end Add `gid` to the set of sampler selectors associated with `vn` in `vi`. """ -function setgid!(vi::VarInfo, gid::Selector, vn::VarName) - return push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) +setgid!(vi::VarInfo, gid::Selector, vn::VarName) = setgid!(getmetadata(vi, vn), gid, vn) + +function setgid!(m::Metadata, gid::Selector, vn::VarName) + return push!(m.gids[getidx(m, vn)], gid) +end + +function setgid!(vnv::VarNamedVector, gid::Selector, vn::VarName) + throw(ErrorException("Calling setgid! on a VarNamedVector isn't valid.")) end istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn) @@ -953,18 +1105,18 @@ and parameters sampled in `vi` to 0. """ reset_num_produce!(vi::VarInfo) = set_num_produce!(vi, 0) -isempty(vi::UntypedVarInfo) = isempty(vi.metadata.idcs) -isempty(vi::TypedVarInfo) = _isempty(vi.metadata) +# Need to introduce the _isempty to avoid type piracy of isempty(::NamedTuple). +isempty(vi::VarInfo) = _isempty(vi.metadata) +_isempty(metadata::Metadata) = isempty(metadata.idcs) +_isempty(vnv::VarNamedVector) = isempty(vnv) @generated function _isempty(metadata::NamedTuple{names}) where {names} - expr = Expr(:&&, :true) - for f in names - push!(expr.args, :(isempty(metadata.$f.idcs))) - end - return expr + return Expr(:&&, (:(isempty(metadata.$f)) for f in names)...) end # X -> R for all variables associated with given sampler function link!!(t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) + # If we're working with a `VarNamedVector`, we always use immutable. 
+ has_varnamedvector(vi) && return link(t, vi, spl, model) # Call `_link!` instead of `link!` to avoid deprecation warning. _link!(vi, spl) return vi @@ -1007,10 +1159,8 @@ function _link!(vi::UntypedVarInfo, spl::AbstractSampler) vns = _getvns(vi, spl) if ~istrans(vi, vns[1]) for vn in vns - dist = getdist(vi, vn) - _inner_transform!( - vi, vn, dist, internal_to_linked_internal_transform(vi, vn, dist) - ) + f = internal_to_linked_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) settrans!!(vi, true, vn) end else @@ -1037,13 +1187,8 @@ end if ~istrans(vi, f_vns[1]) # Iterate over all `f_vns` and transform for vn in f_vns - dist = getdist(vi, vn) - _inner_transform!( - vi, - vn, - dist, - internal_to_linked_internal_transform(vi, vn, dist), - ) + f = internal_to_linked_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) settrans!!(vi, true, vn) end else @@ -1060,6 +1205,8 @@ end function invlink!!( t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model ) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return invlink(t, vi, spl, model) # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. 
_invlink!(vi, spl) return vi @@ -1111,10 +1258,8 @@ function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) vns = _getvns(vi, spl) if istrans(vi, vns[1]) for vn in vns - dist = getdist(vi, vn) - _inner_transform!( - vi, vn, dist, linked_internal_to_internal_transform(vi, vn, dist) - ) + f = linked_internal_to_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) settrans!!(vi, false, vn) end else @@ -1141,13 +1286,8 @@ end if istrans(vi, f_vns[1]) # Iterate over all `f_vns` and transform for vn in f_vns - dist = getdist(vi, vn) - _inner_transform!( - vi, - vn, - dist, - linked_internal_to_internal_transform(vi, vn, dist), - ) + f = linked_internal_to_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) settrans!!(vi, false, vn) end else @@ -1160,11 +1300,11 @@ end return expr end -function _inner_transform!(vi::VarInfo, vn::VarName, dist, f) - return _inner_transform!(getmetadata(vi, vn), vi, vn, dist, f) +function _inner_transform!(vi::VarInfo, vn::VarName, f) + return _inner_transform!(getmetadata(vi, vn), vi, vn, f) end -function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, dist, f) +function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) # TODO: Use inplace versions to avoid allocations yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, vn)) # Determine the new range. 
@@ -1202,10 +1342,12 @@ function link( return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, spl, model) end -function _link(model::Model, varinfo::UntypedVarInfo, spl::AbstractSampler) +function _link( + model::Model, varinfo::Union{UntypedVarInfo,VectorVarInfo}, spl::AbstractSampler +) varinfo = deepcopy(varinfo) return VarInfo( - _link_metadata!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), + _link_metadata!!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) @@ -1229,7 +1371,7 @@ end vals = Expr(:tuple) for f in names if inspace(f, space) || length(space) == 0 - push!(vals.args, :(_link_metadata!(model, varinfo, metadata.$f, vns.$f))) + push!(vals.args, :(_link_metadata!!(model, varinfo, metadata.$f, vns.$f))) else push!(vals.args, :(metadata.$f)) end @@ -1237,7 +1379,7 @@ end return :(NamedTuple{$names}($vals)) end -function _link_metadata!(model::Model, varinfo::VarInfo, metadata::Metadata, target_vns) +function _link_metadata!!(model::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns # Construct the new transformed values, and keep track of their lengths. @@ -1257,7 +1399,7 @@ function _link_metadata!(model::Model, varinfo::VarInfo, metadata::Metadata, tar yvec = tovec(y) # Accumulate the log-abs-det jacobian correction. acclogp!!(varinfo, -logjac) - # Mark as no longer transformed. + # Mark as transformed. settrans!!(varinfo, true, vn) # Return the vectorized transformed value. return yvec @@ -1285,6 +1427,29 @@ function _link_metadata!(model::Model, varinfo::VarInfo, metadata::Metadata, tar ) end +function _link_metadata!!( + model::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns +) + vns = target_vns === nothing ? keys(metadata) : target_vns + dists = extract_priors(model, varinfo) + for vn in vns + # First transform from however the variable is stored in vnv to the model + # representation. 
+ transform_to_orig = gettransform(metadata, vn) + val_old = getindex_internal(metadata, vn) + val_orig, logjac1 = with_logabsdet_jacobian(transform_to_orig, val_old) + # Then transform from the model representation to the linked representation. + transform_from_linked = from_linked_vec_transform(dists[vn]) + transform_to_linked = inverse(transform_from_linked) + val_new, logjac2 = with_logabsdet_jacobian(transform_to_linked, val_orig) + # TODO(mhauru) We are calling a !! function but ignoring the return value. + acclogp!!(varinfo, -logjac1 - logjac2) + metadata = update!!(metadata, vn, val_new, transform_from_linked) + settrans!(metadata, true, vn) + end + return metadata +end + function invlink( ::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model ) @@ -1304,7 +1469,7 @@ end function _invlink(model::Model, varinfo::VarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _invlink_metadata!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), + _invlink_metadata!!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) @@ -1328,7 +1493,7 @@ end vals = Expr(:tuple) for f in names if inspace(f, space) || length(space) == 0 - push!(vals.args, :(_invlink_metadata!(model, varinfo, metadata.$f, vns.$f))) + push!(vals.args, :(_invlink_metadata!!(model, varinfo, metadata.$f, vns.$f))) else push!(vals.args, :(metadata.$f)) end @@ -1336,13 +1501,13 @@ end return :(NamedTuple{$names}($vals)) end -function _invlink_metadata!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) +function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn # Return early if we're already in constrained space OR if we're not - # supposed to touch this `vn`, e.g. when `vn` does not belong to the current sampler. 
+ # supposed to touch this `vn`, e.g. when `vn` does not belong to the current sampler. # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. if !istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)] @@ -1385,14 +1550,30 @@ function _invlink_metadata!(::Model, varinfo::VarInfo, metadata::Metadata, targe ) end +function _invlink_metadata!!( + model::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns +) + vns = target_vns === nothing ? keys(metadata) : target_vns + for vn in vns + transform = gettransform(metadata, vn) + old_val = getindex_internal(metadata, vn) + new_val, logjac = with_logabsdet_jacobian(transform, old_val) + # TODO(mhauru) We are calling a !! function but ignoring the return value. + acclogp!!(varinfo, -logjac) + metadata = update!!(metadata, vn, new_val) + settrans!(metadata, false, vn) + end + return metadata +end + """ islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior}) Check whether `vi` is in the transformed space for a particular sampler `spl`. -Turing's Hamiltonian samplers use the `link` and `invlink` functions from +Turing's Hamiltonian samplers use the `link` and `invlink` functions from [Bijectors.jl](https://github.com/TuringLang/Bijectors.jl) to map a constrained variable -(for example, one bounded to the space `[0, 1]`) from its constrained space to the set of +(for example, one bounded to the space `[0, 1]`) from its constrained space to the set of real numbers. `islinked` checks if the number is in the constrained space or the real space. """ function islinked(vi::UntypedVarInfo, spl::Union{Sampler,SampleFromPrior}) @@ -1423,9 +1604,11 @@ function nested_setindex_maybe!( nothing end end -function _nested_setindex_maybe!(vi::VarInfo, md::Metadata, val, vn::VarName) +function _nested_setindex_maybe!( + vi::VarInfo, md::Union{Metadata,VarNamedVector}, val, vn::VarName +) # If `vn` is in `vns`, then we can just use the standard `setindex!`. 
- vns = md.vns + vns = Base.keys(md) if vn in vns setindex!(vi, val, vn) return vn @@ -1436,8 +1619,7 @@ function _nested_setindex_maybe!(vi::VarInfo, md::Metadata, val, vn::VarName) i === nothing && return nothing vn_parent = vns[i] - dist = getdist(md, vn_parent) - val_parent = getindex(vi, vn_parent, dist) # TODO: Ensure that we're working with a view here. + val_parent = getindex(vi, vn_parent) # TODO: Ensure that we're working with a view here. # Split the varname into its tail optic. optic = remove_parent_optic(vn_parent, vn) # Update the value for the parent. @@ -1448,7 +1630,10 @@ end # The default getindex & setindex!() for get & set values # NOTE: vi[vn] will always transform the variable to its original space and Julia type -getindex(vi::VarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn)) +function getindex(vi::VarInfo, vn::VarName) + return from_maybe_linked_internal_transform(vi, vn)(getindex_internal(vi, vn)) +end + function getindex(vi::VarInfo, vn::VarName, dist::Distribution) @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" val = getindex_internal(vi, vn) @@ -1456,13 +1641,34 @@ function getindex(vi::VarInfo, vn::VarName, dist::Distribution) end function getindex(vi::VarInfo, vns::Vector{<:VarName}) - vals_linked = mapreduce(vcat, vns) do vn - getindex(vi, vn) + vals = map(vn -> getindex(vi, vn), vns) + + et = eltype(vals) + # This will catch type unstable cases, where vals has mixed types. + if !isconcretetype(et) + throw(ArgumentError("All variables must have the same type.")) + end + + if et <: Vector + all_of_equal_dimension = all(x -> length(x) == length(vals[1]), vals) + if !all_of_equal_dimension + throw(ArgumentError("All variables must have the same dimension.")) + end + end + + # TODO(mhauru) I'm not very pleased with the return type varying like this, even though + # this should be type stable. 
+ vec_vals = reduce(vcat, vals) + if et <: Vector + # The individual variables are multivariate, and thus we return the values as a + # matrix. + return reshape(vec_vals, (:, length(vns))) + else + # The individual variables are univariate, and thus we return a vector of scalars. + return vec_vals end - # HACK: I don't like this. - dist = getdist(vi, vns[1]) - return recombine(dist, vals_linked, length(vns)) end + function getindex(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" vals_linked = mapreduce(vcat, vns) do vn @@ -1660,6 +1866,7 @@ function setorder!(metadata::Metadata, vn::VarName, index::Int) metadata.orders[metadata.idcs[vn]] = index return metadata end +setorder!(vnv::VarNamedVector, ::VarName, ::Int) = vnv """ getorder(vi::VarInfo, vn::VarName) @@ -1685,21 +1892,45 @@ end function is_flagged(metadata::Metadata, vn::VarName, flag::String) return metadata.flags[flag][getidx(metadata, vn)] end +function is_flagged(::VarNamedVector, ::VarName, flag::String) + if flag == "del" + return true + else + throw(ErrorException("Flag $flag not valid for VarNamedVector")) + end +end +# TODO(mhauru) The "ignorable" argument is a temporary hack while developing VarNamedVector, +# but still having to support the interface based on Metadata too """ - unset_flag!(vi::VarInfo, vn::VarName, flag::String) + unset_flag!(vi::VarInfo, vn::VarName, flag::String, ignorable::Bool=false Set `vn`'s value for `flag` to `false` in `vi`. + +Setting some flags for some `VarInfo` types is not possible, and by default attempting to do +so will error. If `ignorable` is set to `true` then this will silently be ignored instead. 
""" -function unset_flag!(vi::VarInfo, vn::VarName, flag::String) - unset_flag!(getmetadata(vi, vn), vn, flag) +function unset_flag!(vi::VarInfo, vn::VarName, flag::String, ignorable::Bool=false) + unset_flag!(getmetadata(vi, vn), vn, flag, ignorable) return vi end -function unset_flag!(metadata::Metadata, vn::VarName, flag::String) +function unset_flag!(metadata::Metadata, vn::VarName, flag::String, ignorable::Bool=false) metadata.flags[flag][getidx(metadata, vn)] = false return metadata end +function unset_flag!(vnv::VarNamedVector, ::VarName, flag::String, ignorable::Bool=false) + if ignorable + return vnv + end + if flag == "del" + throw(ErrorException("The \"del\" flag cannot be unset for VarNamedVector")) + else + throw(ErrorException("Flag $flag not valid for VarNamedVector")) + end + return vnv +end + """ set_retained_vns_del_by_spl!(vi::VarInfo, spl::Sampler) @@ -1804,7 +2035,7 @@ end ) where {names} updates = map(names) do n quote - for vn in metadata.$n.vns + for vn in Base.keys(metadata.$n) indices_found = kernel!(vi, vn, values, keys_strings) if indices_found !== nothing num_indices_seen += length(indices_found) @@ -1886,14 +2117,6 @@ julia> DynamicPPL.setval!(var_info, (m = 100.0, )); # set `m` and and keep `x[1] julia> var_info[@varname(m)] # [✓] changed 100.0 -julia> var_info[@varname(x[1])] # [✓] unchanged --0.22312984965118443 - -julia> m(rng, var_info); # rerun model - -julia> var_info[@varname(m)] # [✓] unchanged -100.0 - julia> var_info[@varname(x[1])] # [✓] unchanged -0.22312984965118443 ``` @@ -1923,9 +2146,9 @@ end Set the values in `vi` to the provided values and those which are not present in `x` or `chains` to *be* resampled. -Note that this does *not* resample the values not provided! It will call `setflag!(vi, vn, "del")` -for variables `vn` for which no values are provided, which means that the next time we call `model(vi)` these -variables will be resampled. +Note that this does *not* resample the values not provided! 
It will call +`setflag!(vi, vn, "del")` for variables `vn` for which no values are provided, which means +that the next time we call `model(vi)` these variables will be resampled. ## Note - This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info. @@ -1945,7 +2168,7 @@ julia> rng = StableRNG(42); julia> m = demo([missing]); -julia> var_info = DynamicPPL.VarInfo(rng, m); +julia> var_info = DynamicPPL.VarInfo(rng, m, SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata); # Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set. julia> var_info[@varname(m)] -0.6702516921145671 @@ -2043,6 +2266,9 @@ function values_as( return ConstructionBase.constructorof(D)(iter) end +values_as(vi::VectorVarInfo, args...) = values_as(vi.metadata, args...) +values_as(vi::VectorVarInfo, T::Type{Vector}) = values_as(vi.metadata, T) + function values_from_metadata(md::Metadata) return ( # `copy` to avoid accidentally mutation of internal representation. @@ -2052,6 +2278,8 @@ function values_from_metadata(md::Metadata) ) end +values_from_metadata(md::VarNamedVector) = pairs(md) + # Transforming from internal representation to distribution representation. # Without `dist` argument: base on `dist` extracted from self. function from_internal_transform(vi::VarInfo, vn::VarName) @@ -2060,11 +2288,17 @@ end function from_internal_transform(md::Metadata, vn::VarName) return from_internal_transform(md, vn, getdist(md, vn)) end +function from_internal_transform(md::VarNamedVector, vn::VarName) + return gettransform(md, vn) +end # With both `vn` and `dist` arguments: base on provided `dist`. 
function from_internal_transform(vi::VarInfo, vn::VarName, dist) return from_internal_transform(getmetadata(vi, vn), vn, dist) end from_internal_transform(::Metadata, ::VarName, dist) = from_vec_transform(dist) +function from_internal_transform(::VarNamedVector, ::VarName, dist) + return from_vec_transform(dist) +end # Without `dist` argument: base on `dist` extracted from self. function from_linked_internal_transform(vi::VarInfo, vn::VarName) @@ -2073,6 +2307,9 @@ end function from_linked_internal_transform(md::Metadata, vn::VarName) return from_linked_internal_transform(md, vn, getdist(md, vn)) end +function from_linked_internal_transform(md::VarNamedVector, vn::VarName) + return gettransform(md, vn) +end # With both `vn` and `dist` arguments: base on provided `dist`. function from_linked_internal_transform(vi::VarInfo, vn::VarName, dist) # Dispatch to metadata in case this alters the behavior. @@ -2081,3 +2318,6 @@ end function from_linked_internal_transform(::Metadata, ::VarName, dist) return from_linked_vec_transform(dist) end +function from_linked_internal_transform(::VarNamedVector, ::VarName, dist) + return from_linked_vec_transform(dist) +end diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl new file mode 100644 index 000000000..abd15ed34 --- /dev/null +++ b/src/varnamedvector.jl @@ -0,0 +1,1259 @@ +""" + VarNamedVector + +A container that stores values in a vectorised form, but indexable by variable names. + +When indexed by integers or `Colon`s, e.g. `vnv[2]` or `vnv[:]`, `VarNamedVector` behaves +like a `Vector`, and returns the values as they are stored. The stored form is always +vectorised, for instance matrix variables have been flattened, and may be further +transformed to achieve linking. + +When indexed by `VarName`s, e.g. `vnv[@varname(x)]`, `VarNamedVector` returns the values +in the original space. For instance, a linked matrix variable is first inverse linked and +then reshaped to its original form before returning it to the caller. 
+
+`VarNamedVector` also stores a boolean for whether a variable has been transformed to
+unconstrained Euclidean space or not.
+
+# Fields
+$(FIELDS)
+"""
+struct VarNamedVector{
+    K<:VarName,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector
+}
+    """
+    mapping from a `VarName` to its integer index in `varnames`, `ranges` and `transforms`
+    """
+    varname_to_index::OrderedDict{K,Int}
+
+    """
+    vector of `VarNames` for the variables, where `varnames[varname_to_index[vn]] == vn`
+    """
+    varnames::TVN # AbstractVector{<:VarName}
+
+    """
+    vector of index ranges in `vals` corresponding to `varnames`; each `VarName` `vn` has
+    a single index or a set of contiguous indices, such that the values of `vn` can be found
+    at `vals[ranges[varname_to_index[vn]]]`
+    """
+    ranges::Vector{UnitRange{Int}}
+
+    """
+    vector of values of all variables; the value(s) of `vn` is/are
+    `vals[ranges[varname_to_index[vn]]]`
+    """
+    vals::TVal # AbstractVector{<:Real}
+
+    """
+    vector of transformations, so that `transforms[varname_to_index[vn]]` is a callable
+    that transforms the value of `vn` back to its original space, undoing any linking and
+    vectorisation
+    """
+    transforms::TTrans
+
+    """
+    vector of booleans indicating whether a variable has been transformed to unconstrained
+    Euclidean space or not, i.e. whether its domain is all of `ℝ^ⁿ`. Having
+    `is_unconstrained[varname_to_index[vn]] == false` does not necessarily mean that a
+    variable is constrained, but rather that it's not guaranteed to not be.
+    """
+    is_unconstrained::BitVector
+
+    """
+    mapping from a variable index to the number of inactive entries for that variable.
+    Inactive entries are elements in `vals` that are not part of the value of any variable.
+    They arise when a variable is set to a new value with a different dimension, in-place.
+    Inactive entries always come after the last active entry for the given variable.
+    """
+    num_inactive::OrderedDict{Int,Int}
+
+    # Inner constructor: checks that the pieces of metadata are mutually consistent and
+    # throws an `ArgumentError` describing the first inconsistency it finds.
+    function VarNamedVector(
+        varname_to_index,
+        varnames::TVN,
+        ranges,
+        vals::TVal,
+        transforms::TTrans,
+        is_unconstrained,
+        num_inactive,
+    ) where {K,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector}
+        # One entry per variable in each of the per-variable containers.
+        if length(varnames) != length(ranges) ||
+            length(varnames) != length(transforms) ||
+            length(varnames) != length(is_unconstrained) ||
+            length(varnames) != length(varname_to_index)
+            msg = "Inputs to VarNamedVector have inconsistent lengths. Got lengths varnames: $(length(varnames)), ranges: $(length(ranges)), transforms: $(length(transforms)), is_unconstrained: $(length(is_unconstrained)), varname_to_index: $(length(varname_to_index))."
+            throw(ArgumentError(msg))
+        end
+
+        # Every element of `vals` must be either active (covered by a range) or inactive.
+        num_vals = mapreduce(length, (+), ranges; init=0) + sum(values(num_inactive))
+        if num_vals != length(vals)
+            msg = "The total number of elements in `vals` ($(length(vals))) does not match the sum of the lengths of the ranges and the number of inactive entries ($(num_vals))."
+            throw(ArgumentError(msg))
+        end
+
+        # The indices stored in the dictionary must be exactly 1:length(varnames).
+        if Set(values(varname_to_index)) != Set(1:length(varnames))
+            msg = "The values of `varname_to_index` are not valid indices."
+            throw(ArgumentError(msg))
+        end
+
+        if !issubset(Set(keys(num_inactive)), Set(values(varname_to_index)))
+            msg = "The keys of `num_inactive` are not valid indices."
+            throw(ArgumentError(msg))
+        end
+
+        # Check that the varnames don't overlap. The time cost is quadratic in number of
+        # variables. If this ever becomes an issue, we should be able to go down to at least
+        # N log N by sorting based on subsumes-order.
+        for vn1 in keys(varname_to_index)
+            for vn2 in keys(varname_to_index)
+                vn1 === vn2 && continue
+                if subsumes(vn1, vn2)
+                    msg = "Variables in a VarNamedVector should not subsume each other, but $vn1 subsumes $vn2."
+                    throw(ArgumentError(msg))
+                end
+            end
+        end
+
+        # We could also have a test to check that the ranges don't overlap, but that sounds
+        # unlikely to occur, and implementing it in linear time would require a tiny bit of
+        # thought.
+
+        return new{K,V,TVN,TVal,TTrans}(
+            varname_to_index,
+            varnames,
+            ranges,
+            vals,
+            transforms,
+            is_unconstrained,
+            num_inactive,
+        )
+    end
+end
+
+# Default values for is_unconstrained (all false) and num_inactive (empty).
+function VarNamedVector(
+    varname_to_index,
+    varnames,
+    ranges,
+    vals,
+    transforms,
+    is_unconstrained=falses(length(varnames)),
+)
+    return VarNamedVector(
+        varname_to_index,
+        varnames,
+        ranges,
+        vals,
+        transforms,
+        is_unconstrained,
+        OrderedDict{Int,Int}(),
+    )
+end
+
+# TODO(mhauru) Are we sure we want the last one to be of type Any[]? Might this cause
+# unnecessary type instability?
+function VarNamedVector{K,V}() where {K,V}
+    return VarNamedVector(OrderedDict{K,Int}(), K[], UnitRange{Int}[], V[], Any[])
+end
+
+# TODO(mhauru) I would like for this to be VarNamedVector(Union{}, Union{}). This would
+# allow expanding the VarName and element types only as necessary, which would help keep
+# them concrete. However, making that change here opens some other cans of worms related to
+# how VarInfo uses BangBang, that I don't want to deal with right now.
+VarNamedVector() = VarNamedVector{VarName,Real}()
+# Convenience constructors: from `vn => value` pairs, from a dictionary, or from two
+# iterables of names and values. All of them end up in the vector-based method below.
+VarNamedVector(xs::Pair...) = VarNamedVector(OrderedDict(xs...))
+VarNamedVector(x::AbstractDict) = VarNamedVector(keys(x), values(x))
+function VarNamedVector(varnames, vals)
+    return VarNamedVector(collect_maybe(varnames), collect_maybe(vals))
+end
+# Build a `VarNamedVector` from parallel vectors of names, values and transforms. Each value
+# is vectorised with `tovec`, and its transform is composed with `from_vec_transform(val)`
+# so that the stored transform also undoes the vectorisation.
+function VarNamedVector(
+    varnames::AbstractVector,
+    vals::AbstractVector,
+    transforms=fill(identity, length(varnames)),
+)
+    # Convert `vals` into a vector of vectors.
+    vals_vecs = map(tovec, vals)
+    transforms = map(
+        (t, val) -> _compose_no_identity(t, from_vec_transform(val)), transforms, vals
+    )
+    # TODO: Is this really the way to do this?
+    if !(eltype(varnames) <: VarName)
+        varnames = convert(Vector{VarName}, varnames)
+    end
+    varname_to_index = OrderedDict{eltype(varnames),Int}(
+        vn => i for (i, vn) in enumerate(varnames)
+    )
+    vals = reduce(vcat, vals_vecs)
+    # Make the ranges.
+    ranges = Vector{UnitRange{Int}}()
+    offset = 0
+    for x in vals_vecs
+        r = (offset + 1):(offset + length(x))
+        push!(ranges, r)
+        # Advance by the length rather than reading `r[end]`: for a zero-length value `r`
+        # is empty and indexing it would throw a `BoundsError`.
+        offset += length(x)
+    end
+
+    return VarNamedVector(varname_to_index, varnames, ranges, vals, transforms)
+end
+
+# Field-wise equality: all seven fields have to compare equal.
+function ==(vnv_left::VarNamedVector, vnv_right::VarNamedVector)
+    return vnv_left.varname_to_index == vnv_right.varname_to_index &&
+           vnv_left.varnames == vnv_right.varnames &&
+           vnv_left.ranges == vnv_right.ranges &&
+           vnv_left.vals == vnv_right.vals &&
+           vnv_left.transforms == vnv_right.transforms &&
+           vnv_left.is_unconstrained == vnv_right.is_unconstrained &&
+           vnv_left.num_inactive == vnv_right.num_inactive
+end
+
+# Integer index of `vn` into the per-variable containers.
+getidx(vnv::VarNamedVector, vn::VarName) = vnv.varname_to_index[vn]
+
+# Range of `vals` holding the (active) entries of a variable, by index or by name.
+getrange(vnv::VarNamedVector, idx::Int) = vnv.ranges[idx]
+getrange(vnv::VarNamedVector, vn::VarName) = getrange(vnv, getidx(vnv, vn))
+
+# Stored transform of a variable, by index or by name.
+gettransform(vnv::VarNamedVector, idx::Int) = vnv.transforms[idx]
+gettransform(vnv::VarNamedVector, vn::VarName) = gettransform(vnv, getidx(vnv, vn))
+
+# TODO(mhauru) Eventually I would like to rename the istrans function to is_unconstrained,
+# but that's significantly breaking.
+"""
+    istrans(vnv::VarNamedVector, vn::VarName)
+
+Return a boolean for whether `vn` is guaranteed to have been transformed so that its domain
+is all of Euclidean space.
+"""
+istrans(vnv::VarNamedVector, vn::VarName) = vnv.is_unconstrained[getidx(vnv, vn)]
+
+"""
+    settrans!(vnv::VarNamedVector, val::Bool, vn::VarName)
+
+Set the value for whether `vn` is guaranteed to have been transformed so that all of
+Euclidean space is its domain.
+"""
+function settrans!(vnv::VarNamedVector, val::Bool, vn::VarName)
+    return vnv.is_unconstrained[vnv.varname_to_index[vn]] = val
+end
+
+# Same as `settrans!`, but returns `vnv` so that it can be used in `!!`-style code.
+function settrans!!(vnv::VarNamedVector, val::Bool, vn::VarName)
+    settrans!(vnv, val, vn)
+    return vnv
+end
+
+"""
+    has_inactive(vnv::VarNamedVector)
+
+Returns `true` if `vnv` has inactive ranges.
+"""
+has_inactive(vnv::VarNamedVector) = !isempty(vnv.num_inactive)
+
+"""
+    num_inactive(vnv::VarNamedVector)
+
+Return the number of inactive entries in `vnv`.
+"""
+num_inactive(vnv::VarNamedVector) = sum(values(vnv.num_inactive))
+
+"""
+    num_inactive(vnv::VarNamedVector, vn::VarName)
+
+Returns the number of inactive entries for `vn` in `vnv`.
+"""
+num_inactive(vnv::VarNamedVector, vn::VarName) = num_inactive(vnv, getidx(vnv, vn))
+num_inactive(vnv::VarNamedVector, idx::Int) = get(vnv.num_inactive, idx, 0)
+
+"""
+    num_allocated(vnv::VarNamedVector)
+
+Returns the number of allocated entries in `vnv`, both active and inactive.
+"""
+num_allocated(vnv::VarNamedVector) = length(vnv.vals)
+
+"""
+    num_allocated(vnv::VarNamedVector, vn::VarName)
+
+Returns the number of allocated entries for `vn` in `vnv`, both active and inactive.
+"""
+num_allocated(vnv::VarNamedVector, vn::VarName) = num_allocated(vnv, getidx(vnv, vn))
+function num_allocated(vnv::VarNamedVector, idx::Int)
+    return length(getrange(vnv, idx)) + num_inactive(vnv, idx)
+end
+
+# Basic array interface.
+Base.eltype(vnv::VarNamedVector) = eltype(vnv.vals)
+function Base.length(vnv::VarNamedVector)
+    # With inactive entries `vals` is longer than the data it holds, so add up the ranges.
+    return has_inactive(vnv) ? sum(length, vnv.ranges) : length(vnv.vals)
+end
+Base.size(vnv::VarNamedVector) = (length(vnv),)
+Base.isempty(vnv::VarNamedVector) = isempty(vnv.varnames)
+
+Base.IndexStyle(::Type{<:VarNamedVector}) = IndexLinear()
+
+# Dictionary-style queries, keyed by `VarName`. `values` and `pairs` are lazy and return
+# the values via `vnv[vn]`, i.e. with the stored transform applied.
+Base.keys(vnv::VarNamedVector) = vnv.varnames
+Base.values(vnv::VarNamedVector) = Iterators.map(vn -> getindex(vnv, vn), vnv.varnames)
+Base.pairs(vnv::VarNamedVector) = Iterators.map(vn -> (vn => vnv[vn]), keys(vnv))
+
+Base.haskey(vnv::VarNamedVector, vn::VarName) = haskey(vnv.varname_to_index, vn)
+
+# Reading values: an `Int` index returns the stored entry, a `VarName` index returns the
+# stored entries passed through the variable's transform.
+Base.getindex(vnv::VarNamedVector, i::Int) = getindex_internal(vnv, i)
+function Base.getindex(vnv::VarNamedVector, vn::VarName)
+    stored = getindex_internal(vnv, vn)
+    return gettransform(vnv, vn)(stored)
+end
+
+"""
+    find_containing_range(ranges::AbstractVector{<:AbstractRange}, x)
+
+Return the position in `ranges` of the first range that contains `x`.
+
+Throws an `ArgumentError` when none of the ranges contains `x`.
+"""
+function find_containing_range(ranges::AbstractVector{<:AbstractRange}, x)
+    # TODO: Assume `ranges` to be sorted and contiguous, and use `searchsortedfirst`
+    # for a more efficient approach.
+    hit = findfirst(r -> x ∈ r, ranges)
+    # No range holds `x`, i.e. the index is out of bounds.
+    isnothing(hit) && throw(ArgumentError("Value $x is not in any of the ranges."))
+    return hit
+end
+
+"""
+    adjusted_ranges(vnv::VarNamedVector)
+
+Return the ranges `vnv.ranges` would consist of if `vnv` had no inactive entries.
+"""
+function adjusted_ranges(vnv::VarNamedVector)
+    shifted = similar(vnv.ranges)
+    # Running count of the inactive entries of all variables visited so far; every later
+    # range moves left by this amount.
+    skipped = 0
+    for idx in eachindex(vnv.ranges)
+        shifted[idx] = vnv.ranges[idx] .- skipped
+        skipped += get(vnv.num_inactive, idx, 0)
+    end
+    return shifted
+end
+
+"""
+    index_to_vals_index(vnv::VarNamedVector, i::Int)
+
+Translate `i`, an index that does not count inactive entries, into the matching index of
+`vnv.vals`, which does.
+
+This lets a caller index `vnv` like a plain vector without knowing about the inactive
+entries stored in `vnv.vals`.
+"""
+function index_to_vals_index(vnv::VarNamedVector, i::Int)
+    # Without inactive entries the two index spaces coincide.
+    if !has_inactive(vnv)
+        return i
+    end
+
+    compact = adjusted_ranges(vnv)
+    # Which variable does position `i` of the compact layout fall into?
+    which = find_containing_range(compact, i)
+    # Number of compact positions taken up by the variables before it.
+    before = if which == 1
+        0
+    else
+        sum(length, compact[1:(which - 1)])
+    end
+    # The remainder is the position inside that variable's stored range.
+    return vnv.ranges[which][i - before]
+end
+
+"""
+    getindex_internal(vnv::VarNamedVector, i::Int)
+    getindex_internal(vnv::VarNamedVector, vn::VarName)
+
+Like `getindex`, but return the values exactly as `vnv` stores them, with no transform
+applied.
+
+With an integer index this coincides with `getindex`; with a `VarName` it does not.
+"""
+function getindex_internal(vnv::VarNamedVector, i::Int)
+    return vnv.vals[index_to_vals_index(vnv, i)]
+end
+function getindex_internal(vnv::VarNamedVector, vn::VarName)
+    return vnv.vals[getrange(vnv, vn)]
+end
+
+# `getindex` for `Colon`
+function Base.getindex(vnv::VarNamedVector, ::Colon)
+    # No inactive entries: the stored vector already is the flat representation, and it is
+    # returned as is (not copied).
+    has_inactive(vnv) || return vnv.vals
+    # Otherwise concatenate the active ranges, which leaves the inactive entries out.
+    return mapreduce(r -> vnv.vals[r], vcat, vnv.ranges)
+end
+
+getindex_internal(vnv::VarNamedVector, ::Colon) = getindex(vnv, :)
+
+# TODO(mhauru): Remove this as soon as possible. Only needed because of the old Gibbs
+# sampler.
+function Base.getindex(vnv::VarNamedVector, spl::AbstractSampler)
+    throw(ErrorException("Cannot index a VarNamedVector with a sampler."))
+end
+
+Base.setindex!(vnv::VarNamedVector, val, i::Int) = setindex_internal!(vnv, val, i)
+function Base.setindex!(vnv::VarNamedVector, val, vn::VarName)
+    # Since setindex! does not change the transform, we need to apply it to `val`.
+    f = inverse(gettransform(vnv, vn))
+    return setindex_internal!(vnv, f(val), vn)
+end
+
+"""
+    setindex_internal!(vnv::VarNamedVector, val, i::Int)
+    setindex_internal!(vnv::VarNamedVector, val, vn::VarName)
+
+Like `setindex!`, but sets the values as they are stored in `vnv` without transforming.
+
+For integer indices this is the same as `setindex!`, but for `VarName`s this is different.
+"""
+function setindex_internal!(vnv::VarNamedVector, val, i::Int)
+    return vnv.vals[index_to_vals_index(vnv, i)] = val
+end
+
+function setindex_internal!(vnv::VarNamedVector, val::AbstractVector, vn::VarName)
+    return vnv.vals[getrange(vnv, vn)] = val
+end
+
+# Remove all variables from `vnv` by emptying every field in place. Returns `nothing`; use
+# `BangBang.empty!!` to get `vnv` back.
+function Base.empty!(vnv::VarNamedVector)
+    # TODO: Or should the semantics be different, e.g. keeping `varnames`?
+    empty!(vnv.varname_to_index)
+    empty!(vnv.varnames)
+    empty!(vnv.ranges)
+    empty!(vnv.vals)
+    empty!(vnv.transforms)
+    empty!(vnv.is_unconstrained)
+    empty!(vnv.num_inactive)
+    return nothing
+end
+BangBang.empty!!(vnv::VarNamedVector) = (empty!(vnv); return vnv)
+
+"""
+    replace_values(vnv::VarNamedVector, vals::AbstractVector)
+
+Replace the values in `vnv` with `vals`, as they are stored internally.
+
+This is useful when we want to update the entire underlying vector of values in one go or if
+we want to change how the values are stored, e.g. alter the `eltype`.
+
+!!! warning
+    This replaces the raw underlying values, and so care should be taken when using this
+    function. For example, if `vnv` has any inactive entries, then the provided `vals`
+    should also contain the inactive entries to avoid unexpected behavior.
+
+# Examples
+
+```jldoctest varnamedvector-replace-values
+julia> using DynamicPPL: VarNamedVector, replace_values
+
+julia> vnv = VarNamedVector(@varname(x) => [1.0]);
+
+julia> replace_values(vnv, [2.0])[@varname(x)] == [2.0]
+true
+```
+
+This is also useful when we want to differentiate wrt. the values using automatic
+differentiation, e.g. ForwardDiff.jl.
+
+```jldoctest varnamedvector-replace-values
+julia> using ForwardDiff: ForwardDiff
+
+julia> f(x) = sum(abs2, replace_values(vnv, x)[@varname(x)])
+f (generic function with 1 method)
+
+julia> ForwardDiff.gradient(f, [1.0])
+1-element Vector{Float64}:
+ 2.0
+```
+"""
+replace_values(vnv::VarNamedVector, vals) = Accessors.@set vnv.vals = vals
+
+# TODO(mhauru) The space argument is used by the old Gibbs sampler. To be removed.
+function replace_values(vnv::VarNamedVector, ::Val{space}, vals) where {space}
+    if length(space) > 0
+        msg = "Selecting values in a VarNamedVector with a space is not supported."
+        throw(ArgumentError(msg))
+    end
+    return replace_values(vnv, vals)
+end
+
+"""
+    unflatten(vnv::VarNamedVector, vals::AbstractVector)
+
+Return a new instance of `vnv` with the values of `vals` assigned to the variables.
+
+This assumes that `vals` have been transformed by the same transformations that the
+values in `vnv` have been transformed by. However, unlike [`replace_values`](@ref),
+`unflatten` does account for inactive entries in `vnv`, so that the user does not have to
+care about them.
+
+This is in a sense the reverse operation of `vnv[:]`.
+
+Unflatten recontiguifies the internal storage, getting rid of any inactive entries.
+
+The `is_unconstrained` flags of `vnv` are carried over, so `istrans` gives the same answers
+for the result as for `vnv`. The name-to-index mapping, the variable names, the transforms
+and the flags are reused from `vnv` rather than copied.
+
+# Examples
+
+```jldoctest varnamedvector-unflatten
+julia> using DynamicPPL: VarNamedVector, unflatten
+
+julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0], @varname(y) => [3.0]);
+
+julia> unflatten(vnv, vnv[:]) == vnv
+true
+```
+"""
+function unflatten(vnv::VarNamedVector, vals::AbstractVector)
+    # The elements are immutable `UnitRange{Int}`s, so a shallow copy is enough to keep
+    # `vnv.ranges` untouched.
+    new_ranges = copy(vnv.ranges)
+    recontiguify_ranges!(new_ranges)
+    # Pass `is_unconstrained` explicitly: leaving it out falls back to the constructor's
+    # default of all `false`, which would silently forget which variables are linked.
+    return VarNamedVector(
+        vnv.varname_to_index,
+        vnv.varnames,
+        new_ranges,
+        vals,
+        vnv.transforms,
+        vnv.is_unconstrained,
+    )
+end
+
+# TODO(mhauru) To be removed once the old Gibbs sampler is removed.
+function unflatten(vnv::VarNamedVector, spl::AbstractSampler, vals::AbstractVector)
+    if length(getspace(spl)) > 0
+        msg = "Selecting values in a VarNamedVector with a space is not supported."
+        throw(ArgumentError(msg))
+    end
+    return unflatten(vnv, vals)
+end
+
+# Merge two `VarNamedVector`s into a new one. Variables keep the order of
+# `union(left names, right names)`; for a variable present in both, the value, transform and
+# `is_unconstrained` flag of `right_vnv` win. Values are read by range, so inactive entries
+# of the inputs are not carried over.
+function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector)
+    # Return early if possible.
+    isempty(left_vnv) && return deepcopy(right_vnv)
+    isempty(right_vnv) && return deepcopy(left_vnv)
+
+    # Determine varnames.
+    vns_left = left_vnv.varnames
+    vns_right = right_vnv.varnames
+    vns_both = union(vns_left, vns_right)
+
+    # Determine `eltype` of `vals`.
+    T_left = eltype(left_vnv.vals)
+    T_right = eltype(right_vnv.vals)
+    T = promote_type(T_left, T_right)
+
+    # Determine `eltype` of `varnames`.
+    V_left = eltype(left_vnv.varnames)
+    V_right = eltype(right_vnv.varnames)
+    V = promote_type(V_left, V_right)
+    if !(V <: VarName)
+        V = VarName
+    end
+
+    # Determine `eltype` of `transforms`.
+    F_left = eltype(left_vnv.transforms)
+    F_right = eltype(right_vnv.transforms)
+    F = promote_type(F_left, F_right)
+
+    # Allocate.
+    varname_to_index = OrderedDict{V,Int}()
+    ranges = UnitRange{Int}[]
+    vals = T[]
+    transforms = F[]
+    is_unconstrained = BitVector(undef, length(vns_both))
+
+    # Range offset.
+    offset = 0
+
+    for (idx, vn) in enumerate(vns_both)
+        varname_to_index[vn] = idx
+        # Extract the necessary information from `left` or `right`.
+        # Every `vn` comes from the union, so "not in `right`" already implies "in `left`".
+        # `haskey` is a dictionary lookup instead of a linear scan over the name vectors.
+        if !haskey(right_vnv, vn)
+            # `vn` is only in `left`.
+            val = getindex_internal(left_vnv, vn)
+            f = gettransform(left_vnv, vn)
+            is_unconstrained[idx] = istrans(left_vnv, vn)
+        else
+            # `vn` is either in both or just `right`.
+            # Note that in a `merge` the right value has precedence.
+            val = getindex_internal(right_vnv, vn)
+            f = gettransform(right_vnv, vn)
+            is_unconstrained[idx] = istrans(right_vnv, vn)
+        end
+        n = length(val)
+        r = (offset + 1):(offset + n)
+        # Update.
+        append!(vals, val)
+        push!(ranges, r)
+        push!(transforms, f)
+        # Increment `offset`.
+        offset += n
+    end
+
+    return VarNamedVector(
+        varname_to_index, vns_both, ranges, vals, transforms, is_unconstrained
+    )
+end
+
+"""
+    subset(vnv::VarNamedVector, vns::AbstractVector{<:VarName})
+
+Return a new `VarNamedVector` containing the values from `vnv` for variables in `vns`.
+
+Which variables to include is determined by the `VarName`'s `subsumes` relation, meaning
+that e.g. `subset(vnv, [@varname(x)])` will include variables like `@varname(x.a[1])`.
+
+A variable of `vnv` that is matched by several entries of `vns` is included once.
+
+# Examples
+
+```jldoctest varnamedvector-subset
+julia> using DynamicPPL: VarNamedVector, @varname, subset
+
+julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0], @varname(y) => [3.0]);
+
+julia> subset(vnv, [@varname(x)]) == VarNamedVector(@varname(x) => [1.0, 2.0])
+true
+
+julia> subset(vnv, [@varname(x[2])]) == VarNamedVector(@varname(x[2]) => [2.0])
+true
+```
+"""
+function subset(vnv::VarNamedVector, vns_given::AbstractVector{VN}) where {VN<:VarName}
+    # NOTE: This does not specialize types when possible.
+    # NOTE(review): selection only tests `subsumes(given, stored)`, so a given name that is
+    # more specific than a stored one (e.g. `x[2]` against a stored `x`) selects nothing.
+    # The second docstring example looks inconsistent with that -- confirm.
+    vns = mapreduce(vcat, vns_given; init=VN[]) do vn
+        filter(Base.Fix1(subsumes, vn), vnv.varnames)
+    end
+    # A stored variable can be matched by more than one given name (e.g. both `x` and
+    # `x[1]` match a stored `x[1]`). Keep the first occurrence only, because `push!` below
+    # throws for a variable that is already present.
+    unique!(vns)
+    vnv_new = similar(vnv)
+    # Return early if possible.
+    isempty(vnv) && return vnv_new
+
+    for vn in vns
+        push!(vnv_new, vn, getindex_internal(vnv, vn), gettransform(vnv, vn))
+        settrans!(vnv_new, istrans(vnv, vn), vn)
+    end
+
+    return vnv_new
+end
+
+"""
+    similar(vnv::VarNamedVector)
+
+Return a new `VarNamedVector` with the same structure as `vnv`, but with empty values.
+
+In this respect `vnv` behaves more like a dictionary than an array: `similar(vnv)` will
+be entirely empty, rather than have `undef` values in it.
+
+# Examples
+
+```julia
+julia> using DynamicPPL: VarNamedVector, @varname, similar
+
+julia> vnv = VarNamedVector(@varname(x[1]) => [1.0, 2.0], @varname(x[3]) => [3.0]);
+
+julia> similar(vnv) == VarNamedVector{VarName{:x}, Float64}()
+true
+```
+"""
+function Base.similar(vnv::VarNamedVector)
+    # NOTE: Whether or not we should empty the underlying containers or not
+    # is somewhat ambiguous. For example, `similar(vnv.varname_to_index)` will
+    # result in an empty `AbstractDict`, while the vectors, e.g. `vnv.ranges`,
+    # will result in non-empty vectors but with entries as `undef`. But it's
+    # much easier to write the rest of the code assuming that `undef` is not
+    # present, and so for now we empty the underlying containers, thus differing
+    # from the behavior of `similar` for `AbstractArray`s.
+    return VarNamedVector(
+        empty(vnv.varname_to_index),
+        similar(vnv.varnames, 0),
+        similar(vnv.ranges, 0),
+        similar(vnv.vals, 0),
+        similar(vnv.transforms, 0),
+        BitVector(),
+        empty(vnv.num_inactive),
+    )
+end
+
+"""
+    is_contiguous(vnv::VarNamedVector)
+
+Returns `true` if the underlying data of `vnv` is stored in a contiguous array.
+
+This is equivalent to negating [`has_inactive(vnv)`](@ref).
+"""
+is_contiguous(vnv::VarNamedVector) = !has_inactive(vnv)
+
+"""
+    nextrange(vnv::VarNamedVector, x)
+
+Return the range of `length(x)` from the end of current data in `vnv`.
+"""
+function nextrange(vnv::VarNamedVector, x)
+    offset = length(vnv.vals)
+    return (offset + 1):(offset + length(x))
+end
+
+# TODO(mhauru) Might add another specialisation to _compose_no_identity, where if
+# ReshapeTransforms are composed with each other or with an UnwrapSingletonTransform, only
+# the latter one would be kept.
+"""
+    _compose_no_identity(f, g)
+
+Like `f ∘ g`, but if `f` or `g` is `identity` it is omitted.
+
+This helps avoid trivial cases of `ComposedFunction` that would cause unnecessary type
+conflicts.
+"""
+_compose_no_identity(f, g) = f ∘ g
+_compose_no_identity(::typeof(identity), g) = g
+_compose_no_identity(f, ::typeof(identity)) = f
+_compose_no_identity(::typeof(identity), ::typeof(identity)) = identity
+
+"""
+    push!(vnv::VarNamedVector, vn::VarName, val[, transform])
+    push!(vnv::VarNamedVector, vn => val[, transform])
+
+Add a variable with given value to `vnv`.
+
+`transform` should be a function that converts `val` to the original representation, by
+default it's `identity`.
+
+Throws an `ArgumentError` if `vn` is already present. Returns `nothing`.
+"""
+function Base.push!(vnv::VarNamedVector, vn::VarName, val, transform=identity)
+    # Error if we already have the variable.
+    haskey(vnv, vn) && throw(ArgumentError("variable name $vn already exists"))
+    # NOTE: We need to compute the `nextrange` BEFORE we start mutating the underlying
+    # storage.
+    if !(val isa AbstractVector)
+        # Store the value in vectorised form, and extend the transform so that it also
+        # undoes the vectorisation.
+        val_vec = tovec(val)
+        transform = _compose_no_identity(transform, from_vec_transform(val))
+    else
+        val_vec = val
+    end
+    r_new = nextrange(vnv, val_vec)
+    vnv.varname_to_index[vn] = length(vnv.varname_to_index) + 1
+    push!(vnv.varnames, vn)
+    push!(vnv.ranges, r_new)
+    append!(vnv.vals, val_vec)
+    push!(vnv.transforms, transform)
+    # A newly pushed variable is never flagged as unconstrained.
+    push!(vnv.is_unconstrained, false)
+    return nothing
+end
+
+function Base.push!(vnv::VarNamedVector, pair, transform=identity)
+    vn, val = pair
+    return push!(vnv, vn, val, transform)
+end
+
+# TODO(mhauru) The gidset and num_produce arguments are used by the old Gibbs sampler.
+# Remove this method as soon as possible.
+function Base.push!(vnv::VarNamedVector, vn, val, dist, gidset, num_produce)
+    # `gidset` and `num_produce` are accepted but not used; the transform comes from `dist`.
+    dist_transform = from_vec_transform(dist)
+    return push!(vnv, vn, tovec(val), dist_transform)
+end
+
+"""
+    loosen_types!!(vnv::VarNamedVector{K,V,TVN,TVal,TTrans}, ::Type{KNew}, ::Type{TransNew})
+
+Widen the types of `vnv` so that it accepts varnames of type `KNew` and transformations of
+type `TransNew`.
+
+When `KNew` is already a subtype of `K` and `TransNew` is already a subtype of the element
+type of `TTrans`, nothing needs to change and `vnv` itself is returned. Otherwise the result
+is a new `VarNamedVector` holding the same data under more abstract types, so that
+variables of type `KNew` and transformations of type `TransNew` can be pushed to it. Part of
+the underlying storage is shared between `vnv` and the result, so mutating one may affect
+the other.
+
+# See also
+[`tighten_types`](@ref)
+
+# Examples
+
+```jldoctest varnamedvector-loosen-types
+julia> using DynamicPPL: VarNamedVector, @varname, loosen_types!!
+
+julia> vnv = VarNamedVector(@varname(x) => [1.0]);
+
+julia> vnv_new = loosen_types!!(vnv, VarName{:x}, Real);
+
+julia> push!(vnv, @varname(y), Float32[2.0])
+ERROR: MethodError: Cannot `convert` an object of type
+  VarName{y,typeof(identity)} to an object of type
+  VarName{x,typeof(identity)}
+[...]
+
+julia> vnv_loose = DynamicPPL.loosen_types!!(vnv, typeof(@varname(y)), Float32);
+
+julia> push!(vnv_loose, @varname(y), Float32[2.0]); vnv_loose # Passes without issues.
+VarNamedVector{VarName{sym, typeof(identity)} where sym, Float64, Vector{VarName{sym, typeof(identity)} where sym}, Vector{Float64}, Vector{Any}}(OrderedDict{VarName{sym, typeof(identity)} where sym, Int64}(x => 1, y => 2), VarName{sym, typeof(identity)} where sym[x, y], UnitRange{Int64}[1:1, 2:2], [1.0, 2.0], Any[identity, identity], Bool[0, 0], OrderedDict{Int64, Int64}())
+```
+"""
+function loosen_types!!(
+    vnv::VarNamedVector, ::Type{KNew}, ::Type{TransNew}
+) where {KNew,TransNew}
+    K_old = eltype(vnv.varnames)
+    Trans_old = eltype(vnv.transforms)
+    # Both requested types already fit: hand back the very same object.
+    (KNew <: K_old && TransNew <: Trans_old) && return vnv
+
+    K_wide = promote_type(K_old, KNew)
+    Trans_wide = promote_type(Trans_old, TransNew)
+    # Only the containers whose element type changes are rebuilt; `ranges`, `vals`,
+    # `is_unconstrained` and `num_inactive` are passed on as they are.
+    index_wide = OrderedDict{K_wide,Int}(vnv.varname_to_index)
+    varnames_wide = Vector{K_wide}(vnv.varnames)
+    transforms_wide = Vector{Trans_wide}(vnv.transforms)
+    return VarNamedVector(
+        index_wide,
+        varnames_wide,
+        vnv.ranges,
+        vnv.vals,
+        transforms_wide,
+        vnv.is_unconstrained,
+        vnv.num_inactive,
+    )
+end
+
+"""
+    tighten_types(vnv::VarNamedVector)
+
+Return a copy of `vnv` whose types are as concrete as its contents allow.
+
+For example, a `vnv` with element type `Real` whose values are all `Float64`s yields a
+`tighten_types(vnv)` with element type `Float64`.
+
+# See also
+[`loosen_types!!`](@ref)
+"""
+function tighten_types(vnv::VarNamedVector)
+    # Rebuilding a container from its elements lets its element type be re-derived from
+    # what is actually stored.
+    index_tight = OrderedDict(vnv.varname_to_index...)
+    varnames_tight = map(identity, vnv.varnames)
+    vals_tight = map(identity, vnv.vals)
+    transforms_tight = map(identity, vnv.transforms)
+    return VarNamedVector(
+        index_tight,
+        varnames_tight,
+        copy(vnv.ranges),
+        vals_tight,
+        transforms_tight,
+        copy(vnv.is_unconstrained),
+        copy(vnv.num_inactive),
+    )
+end
+
+# Like `push!`, but first widens the types of `vnv` as far as needed for `vn` and its
+# transform, and returns the (possibly new) container.
+function BangBang.push!!(vnv::VarNamedVector, vn::VarName, val, transform=identity)
+    composed = _compose_no_identity(transform, from_vec_transform(val))
+    vnv_loose = loosen_types!!(vnv, typeof(vn), typeof(composed))
+    push!(vnv_loose, vn, val, transform)
+    return vnv_loose
+end
+
+# TODO(mhauru) The gidset and num_produce arguments are used by the old Gibbs sampler.
+# Remove this method as soon as possible.
"""
    shift_right!(x::AbstractVector{<:Real}, start::Int, n::Int)

Move every element of `x` from position `start` onwards `n` positions to the right,
in place, and return `x`.

The length of `x` is not changed: the last `n` elements are overwritten, and the
elements in `start:(start + n - 1)` keep their previous values.
"""
function shift_right!(x::AbstractVector{<:Real}, start::Int, n::Int)
    stop = lastindex(x)
    # Indexing with a range copies, so source and destination may overlap safely.
    moved = x[start:(stop - n)]
    x[(start + n):stop] = moved
    return x
end
+ n_new = length(val_vec) + start_new = start_old + end_new = start_old + n_new - 1 + r_new = start_new:end_new + + #= + Suppose we currently have the following: + + | x | x | o | o | o | y | y | y | <- Current entries + + where 'O' denotes an inactive entry, and we're going to + update the variable `x` to be of size `k` instead of 2. + + We then have a few different scenarios: + 1. `k > 5`: All inactive entries become active + need to shift `y` to the right. + E.g. if `k = 7`, then + + | x | x | o | o | o | y | y | y | <- Current entries + | x | x | x | x | x | x | x | y | y | y | <- New entries + + 2. `k = 5`: All inactive entries become active. + Then + + | x | x | o | o | o | y | y | y | <- Current entries + | x | x | x | x | x | y | y | y | <- New entries + + 3. `k < 5`: Some inactive entries become active, some remain inactive. + E.g. if `k = 3`, then + + | x | x | o | o | o | y | y | y | <- Current entries + | x | x | x | o | o | y | y | y | <- New entries + + 4. `k = 2`: No inactive entries become active. + Then + + | x | x | o | o | o | y | y | y | <- Current entries + | x | x | o | o | o | y | y | y | <- New entries + + 5. `k < 2`: More entries become inactive. + E.g. if `k = 1`, then + + | x | x | o | o | o | y | y | y | <- Current entries + | x | o | o | o | o | y | y | y | <- New entries + =# + + # Compute the allocated space for `vn`. + had_inactive = haskey(vnv.num_inactive, idx) + n_allocated = had_inactive ? n_old + vnv.num_inactive[idx] : n_old + + if n_new > n_allocated + # Then we need to grow the underlying vector. + n_extra = n_new - n_allocated + # Allocate. + resize!(vnv.vals, length(vnv.vals) + n_extra) + # Shift current values. + shift_right!(vnv.vals, end_old + 1, n_extra) + # No more inactive entries. + had_inactive && delete!(vnv.num_inactive, idx) + # Update the ranges for all variables after this one. + shift_subsequent_ranges_by!(vnv, idx, n_extra) + elseif n_new == n_allocated + # => No more inactive entries. 
"""
    contiguify!(vnv::VarNamedVector)

Re-contiguify the underlying vector and shrink if possible.

All inactive entries are dropped, so that the ranges of the variables in `vnv` become
back-to-back and `vnv.vals` is exactly as long as the active values require. An empty
`vnv` is left empty.

# Examples

```jldoctest varnamedvector-contiguify
julia> using DynamicPPL: VarNamedVector, @varname, contiguify!, update!, has_inactive

julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0, 3.0], @varname(y) => [3.0]);

julia> update!(vnv, @varname(x), [23.0, 24.0]);

julia> has_inactive(vnv)
true

julia> length(vnv.vals)
4

julia> contiguify!(vnv);

julia> has_inactive(vnv)
false

julia> length(vnv.vals)
3

julia> vnv[@varname(x)]  # All the values are still there.
2-element Vector{Float64}:
 23.0
 24.0
```
"""
function contiguify!(vnv::VarNamedVector)
    # Snapshot the values and ranges.
    # NOTE: We need to do this before we update the ranges.
    old_vals = copy(vnv.vals)
    old_ranges = copy(vnv.ranges)
    # And then we re-contiguify the ranges.
    recontiguify_ranges!(vnv.ranges)
    # Clear the inactive ranges.
    empty!(vnv.num_inactive)
    # Now we move the values to their new positions.
    for (old_range, new_range) in zip(old_ranges, vnv.ranges)
        vnv.vals[new_range] = old_vals[old_range]
    end
    # And (potentially) shrink the underlying vector. `last` rather than `[end]` is used
    # so that a trailing zero-length range does not throw, and an empty `vnv` (no ranges
    # at all) simply ends up with no values.
    n_active = isempty(vnv.ranges) ? 0 : last(last(vnv.ranges))
    resize!(vnv.vals, n_active)
    # The rest should be left as is.
    return vnv
end
"""
    shift_index_left!(vnv::VarNamedVector, idx::Int)

Move the bookkeeping for the variable at position `idx` one position to the left.

Only `vnv.varname_to_index` and `vnv.num_inactive` are touched, so this is only valid as
a helper function for [`shift_subsequent_indices_left!`](@ref).

!!! warning
    No check is made that position `idx - 1` is unoccupied.
"""
function shift_index_left!(vnv::VarNamedVector, idx::Int)
    target = idx - 1
    # Point the varname currently stored at `idx` to its new position.
    vnv.varname_to_index[vnv.varnames[idx]] = target
    # Carry over the inactive-entry count, if there is one. Callers process indices in
    # increasing order, so `target` has already been vacated and nothing is moved twice.
    inactive = vnv.num_inactive
    if haskey(inactive, idx)
        inactive[target] = pop!(inactive, idx)
    end
end
# Remove `vn`, its values (including any inactive entries allocated to it) and all of its
# per-variable metadata from `vnv`. Throws an `ArgumentError` if `vn` is not present.
# Returns the mutated `vnv`.
function Base.delete!(vnv::VarNamedVector, vn::VarName)
    # Error if we don't have the variable.
    !haskey(vnv, vn) && throw(ArgumentError("variable name $vn does not exist"))

    # Get the index of the variable.
    idx = getidx(vnv, vn)

    # Delete the values.
    r_start = first(getrange(vnv, idx))
    n_allocated = num_allocated(vnv, idx)
    # NOTE: `deleteat!` also results in a `resize!` so we don't need to do that.
    deleteat!(vnv.vals, r_start:(r_start + n_allocated - 1))

    # Delete `vn` from the lookup table.
    delete!(vnv.varname_to_index, vn)

    # Delete any inactive ranges corresponding to `vn`.
    haskey(vnv.num_inactive, idx) && delete!(vnv.num_inactive, idx)

    # Re-adjust the indices for varnames occurring after `vn` so
    # that they point to the correct indices after the deletions below.
    shift_subsequent_indices_left!(vnv, idx)

    # Re-adjust the ranges for varnames occurring after `vn`.
    shift_subsequent_ranges_by!(vnv, idx, -n_allocated)

    # Delete references from vector fields, thus shifting the indices of
    # varnames occurring after `vn` by one to the left, as we adjusted for above.
    deleteat!(vnv.varnames, idx)
    deleteat!(vnv.ranges, idx)
    deleteat!(vnv.transforms, idx)
    # `is_unconstrained` holds one flag per variable, so its entry has to go as well;
    # otherwise the flags of all later variables would be off by one.
    deleteat!(vnv.is_unconstrained, idx)

    return vnv
end
+ +# Examples + +```jldoctest +julia> using DynamicPPL: VarNamedVector + +julia> vnv = VarNamedVector(@varname(x) => 1, @varname(y) => [2.0]); + +julia> values_as(vnv) == [1.0, 2.0] +true + +julia> values_as(vnv, Vector{Float32}) == Vector{Float32}([1.0, 2.0]) +true + +julia> values_as(vnv, OrderedDict) == OrderedDict(@varname(x) => 1.0, @varname(y) => [2.0]) +true + +julia> values_as(vnv, NamedTuple) == (x = 1.0, y = [2.0]) +true +``` +""" +values_as(vnv::VarNamedVector) = values_as(vnv, Vector) +values_as(vnv::VarNamedVector, ::Type{Vector}) = vnv[:] +function values_as(vnv::VarNamedVector, ::Type{Vector{T}}) where {T} + return convert(Vector{T}, values_as(vnv, Vector)) +end +function values_as(vnv::VarNamedVector, ::Type{NamedTuple}) + return NamedTuple(zip(map(Symbol, keys(vnv)), values(vnv))) +end +function values_as(vnv::VarNamedVector, ::Type{D}) where {D<:AbstractDict} + return ConstructionBase.constructorof(D)(pairs(vnv)) +end + +# See the docstring of `getvalue` for the semantics of `hasvalue` and `getvalue`, and how +# they differ from `haskey` and `getindex`. They can be found in src/utils.jl. + +# TODO(mhauru) This is tricky to implement in the general case, and the below implementation +# only covers some simple cases. It's probably sufficient in most situations though. +function hasvalue(vnv::VarNamedVector, vn::VarName) + haskey(vnv, vn) && return true + any(subsumes(vn, k) for k in keys(vnv)) && return true + # Handle the easy case where the right symbol isn't even present. + !any(k -> getsym(k) == getsym(vn), keys(vnv)) && return false + + optic = getoptic(vn) + if optic isa Accessors.IndexLens || optic isa Accessors.ComposedOptic + # If vn is of the form @varname(somesymbol[someindex]), we check whether we store + # @varname(somesymbol) and can index into it with someindex. If we rather have a + # composed optic with the last part being an index lens, we do a similar check but + # stripping out the last index lens part. 
# TODO(mhauru) Like hasvalue, this is only partially implemented.
# Return the value associated with `vn` in `vnv`. Throws a `KeyError` if `hasvalue`
# reports that there is no such value, and an `ErrorException` for `VarName`s whose
# optic is not covered by any of the cases below.
function getvalue(vnv::VarNamedVector, vn::VarName)
    !hasvalue(vnv, vn) && throw(KeyError(vn))
    # Exact match: return the stored value directly.
    haskey(vnv, vn) && return getindex(vnv, vn)

    # `vn` is not stored itself, but some stored keys are subsumed by it: concatenate
    # their values.
    subsumed_keys = filter(k -> subsumes(vn, k), keys(vnv))
    if length(subsumed_keys) > 0
        # TODO(mhauru) What happens if getindex returns e.g. matrices, and we vcat them?
        return mapreduce(k -> getindex(vnv, k), vcat, subsumed_keys)
    end

    optic = getoptic(vn)
    # See hasvalue for some comments on the logic of this if block.
    if optic isa Accessors.IndexLens || optic isa Accessors.ComposedOptic
        # Split off the first component of the decomposed optic (`head`); the remainder
        # (`tail`) identifies the parent variable that is looked up in `vnv`.
        head, tail = if optic isa Accessors.ComposedOptic
            decomp_optic = Accessors.decompose(optic)
            first(decomp_optic), Accessors.compose(decomp_optic[2:end]...)
        else
            optic, identity
        end
        parent_varname = VarName{getsym(vn)}(tail)
        valvec = getindex(vnv, parent_varname)
        return head(valvec)
    end
    throw(ErrorException("getvalue has not been fully implemented for this VarName: $(vn)"))
end
DynamicPPL.TestUtils.logjoint_true(model, samples[:s], samples[:m]) end - println("\n model $(model) passed !!! \n") end end @@ -200,10 +199,10 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true @test !any(map(x -> x isa DynamicPPL.AbstractVarInfo, call_retval)) end - @testset "Dynamic constraints" begin + @testset "Dynamic constraints, Metadata" begin model = DynamicPPL.TestUtils.demo_dynamic_constraint() - vi = VarInfo(model) spl = SampleFromPrior() + vi = VarInfo(model, spl, DefaultContext(), DynamicPPL.Metadata) link!!(vi, spl, model) for i in 1:10 @@ -216,6 +215,14 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end end + @testset "Dynamic constraints, VectorVarInfo" begin + model = DynamicPPL.TestUtils.demo_dynamic_constraint() + for i in 1:10 + vi = VarInfo(model) + @test vi[@varname(x)] >= vi[@varname(m)] + end + end + @testset "rand" begin model = gdemo_default @@ -324,7 +331,6 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true chain = MCMCChains.Chains( permutedims(stack(vals)), syms; info=(varname_to_symbol=vns_to_syms,) ) - display(chain) # Test! results = generated_quantities(model, chain) @@ -345,7 +351,6 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true vcat(syms, [:y]); info=(varname_to_symbol=vns_to_syms_with_extra,), ) - display(chain_with_extra) # Test! 
results = generated_quantities(model, chain_with_extra) for (x_true, result) in zip(xs, results) @@ -358,6 +363,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true models_to_test = [ DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2) ] + context = DefaultContext() @testset "$(model.f)" for model in models_to_test vns = DynamicPPL.TestUtils.varnames(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) @@ -366,18 +372,16 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns), ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - @test ( - @inferred(DynamicPPL.evaluate!!(model, varinfo, DefaultContext())); + @test begin + @inferred(DynamicPPL.evaluate!!(model, varinfo, context)) true - ) + end varinfo_linked = DynamicPPL.link(varinfo, model) - @test ( - @inferred( - DynamicPPL.evaluate!!(model, varinfo_linked, DefaultContext()) - ); + @test begin + @inferred(DynamicPPL.evaluate!!(model, varinfo_linked, context)) true - ) + end end end end diff --git a/test/runtests.jl b/test/runtests.jl index aa0883708..9596067eb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,6 +23,8 @@ using Random using Serialization using Test +using Combinatorics: combinations + using DynamicPPL: getargs_dottilde, getargs_tilde, Selector const DIRECTORY_DynamicPPL = dirname(dirname(pathof(DynamicPPL))) @@ -38,6 +40,7 @@ include("test_util.jl") @testset "interface" begin include("utils.jl") include("compiler.jl") + include("varnamedvector.jl") include("varinfo.jl") include("simple_varinfo.jl") include("model.jl") diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 5ce112941..f5b97dbbc 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -56,13 +56,41 @@ @test !haskey(svi, @varname(m.a[2])) @test !haskey(svi, @varname(m.a.b)) end + + @testset "VarNamedVector" begin + svi = 
SimpleVarInfo(push!!(VarNamedVector(), @varname(m), 1.0)) + @test getlogp(svi) == 0.0 + @test haskey(svi, @varname(m)) + @test !haskey(svi, @varname(m[1])) + + svi = SimpleVarInfo(push!!(VarNamedVector(), @varname(m), [1.0])) + @test getlogp(svi) == 0.0 + @test haskey(svi, @varname(m)) + @test haskey(svi, @varname(m[1])) + @test !haskey(svi, @varname(m[2])) + @test svi[@varname(m)][1] == svi[@varname(m[1])] + + svi = SimpleVarInfo(push!!(VarNamedVector(), @varname(m.a), [1.0])) + @test haskey(svi, @varname(m)) + @test haskey(svi, @varname(m.a)) + @test haskey(svi, @varname(m.a[1])) + @test !haskey(svi, @varname(m.a[2])) + @test !haskey(svi, @varname(m.a.b)) + # The implementation of haskey and getvalue for VarNamedVector is incomplete; the + # next test is here to remind us of that. + svi = SimpleVarInfo(push!!(VarNamedVector(), @varname(m.a.b), [1.0])) + @test_broken !haskey(svi, @varname(m.a.b.c.d)) + end end @testset "link!! & invlink!! on $(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS values_constrained = DynamicPPL.TestUtils.rand_prior_true(model) @testset "$(typeof(vi))" for vi in ( - SimpleVarInfo(Dict()), SimpleVarInfo(values_constrained), VarInfo(model) + SimpleVarInfo(Dict()), + SimpleVarInfo(values_constrained), + SimpleVarInfo(VarNamedVector()), + VarInfo(model), ) for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) @@ -115,12 +143,19 @@ # to see whether this is the case. 
svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model)) svi_dict = SimpleVarInfo(VarInfo(model), Dict) + vnv = VarNamedVector() + for (k, v) in pairs(DynamicPPL.TestUtils.rand_prior_true(model)) + vnv = push!!(vnv, VarName{k}(), v) + end + svi_vnv = SimpleVarInfo(vnv) @testset "$(nameof(typeof(DynamicPPL.values_as(svi))))" for svi in ( svi_nt, svi_dict, - DynamicPPL.settrans!!(svi_nt, true), - DynamicPPL.settrans!!(svi_dict, true), + svi_vnv, + DynamicPPL.settrans!!(deepcopy(svi_nt), true), + DynamicPPL.settrans!!(deepcopy(svi_dict), true), + DynamicPPL.settrans!!(deepcopy(svi_vnv), true), ) # RandOM seed is set in each `@testset`, so we need to sample # a new realization for `m` here. @@ -195,30 +230,34 @@ model = DynamicPPL.TestUtils.demo_dynamic_constraint() # Initialize. - svi = DynamicPPL.settrans!!(SimpleVarInfo(), true) - svi = last(DynamicPPL.evaluate!!(model, svi, SamplingContext())) - - # Sample with large variations in unconstrained space. - for i in 1:10 - for vn in keys(svi) - svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn) - end - retval, svi = DynamicPPL.evaluate!!(model, svi, DefaultContext()) - @test retval.m == svi[@varname(m)] # `m` is unconstrained - @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m` + svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true) + svi_nt = last(DynamicPPL.evaluate!!(model, svi_nt, SamplingContext())) + svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(VarNamedVector()), true) + svi_vnv = last(DynamicPPL.evaluate!!(model, svi_vnv, SamplingContext())) + + for svi in (svi_nt, svi_vnv) + # Sample with large variations in unconstrained space. 
+ for i in 1:10 + for vn in keys(svi) + svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn) + end + retval, svi = DynamicPPL.evaluate!!(model, svi, DefaultContext()) + @test retval.m == svi[@varname(m)] # `m` is unconstrained + @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m` + + retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + model, retval.m, retval.x + ) - retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, retval.m, retval.x - ) + # Realizations from model should all be equal to the unconstrained realization. + for vn in DynamicPPL.TestUtils.varnames(model) + @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 + end - # Realizations from model should all be equal to the unconstrained realization. - for vn in DynamicPPL.TestUtils.varnames(model) - @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 + # `getlogp` should be equal to the logjoint with log-absdet-jac correction. + lp = getlogp(svi) + @test lp ≈ lp_true end - - # `getlogp` should be equal to the logjoint with log-absdet-jac correction. - lp = getlogp(svi) - @test lp ≈ lp_true end end diff --git a/test/test_util.jl b/test/test_util.jl index 64832f51e..0c7949e48 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -84,10 +84,15 @@ Return string representing a short description of `vi`. 
""" short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = "threadsafe($(short_varinfo_name(vi.varinfo)))" -short_varinfo_name(::TypedVarInfo) = "TypedVarInfo" +function short_varinfo_name(vi::TypedVarInfo) + DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector" + return "TypedVarInfo" +end short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" +short_varinfo_name(::VectorVarInfo) = "VectorVarInfo" short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" +short_varinfo_name(::SimpleVarInfo{<:VarNamedVector}) = "SimpleVarInfo{<:VarNamedVector}" # convenient functions for testing model.jl # function to modify the representation of values based on their length diff --git a/test/varinfo.jl b/test/varinfo.jl index 6a3d8d2bc..382eb7e58 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -19,7 +19,7 @@ struct MySAlg end DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @testset "varinfo.jl" begin - @testset "TypedVarInfo" begin + @testset "TypedVarInfo with Metadata" begin @model gdemo(x, y) = begin s ~ InverseGamma(2, 3) m ~ truncated(Normal(0.0, sqrt(s)), 0.0, 2.0) @@ -28,7 +28,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end model = gdemo(1.0, 2.0) - vi = VarInfo() + vi = VarInfo(DynamicPPL.Metadata()) model(vi, SampleFromUniform()) tvi = TypedVarInfo(vi) @@ -51,6 +51,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end end end + @testset "Base" begin # Test Base functions: # string, Symbol, ==, hash, in, keys, haskey, isempty, push!!, empty!!, @@ -120,6 +121,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) test_base!!(TypedVarInfo(vi)) test_base!!(SimpleVarInfo()) test_base!!(SimpleVarInfo(Dict())) + test_base!!(SimpleVarInfo(VarNamedVector())) end @testset "flags" begin # Test flag setting: @@ -141,12 +143,12 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) 
unset_flag!(vi, vn_x, "del") @test !is_flagged(vi, vn_x, "del") end - vi = VarInfo() + vi = VarInfo(DynamicPPL.Metadata()) test_varinfo!(vi) test_varinfo!(empty!!(TypedVarInfo(vi))) end @testset "setgid!" begin - vi = VarInfo() + vi = VarInfo(DynamicPPL.Metadata()) meta = vi.metadata vn = @varname x dist = Normal(0, 1) @@ -196,18 +198,36 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) m_vns = model == model_uv ? [@varname(m[i]) for i in 1:5] : @varname(m) s_vns = @varname(s) - vi_typed = VarInfo(model) - vi_untyped = VarInfo() + vi_typed = VarInfo( + model, SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata + ) + vi_untyped = VarInfo(DynamicPPL.Metadata()) + vi_vnv = VarInfo(VarNamedVector()) + vi_vnv_typed = VarInfo( + model, SampleFromPrior(), DefaultContext(), DynamicPPL.VarNamedVector + ) model(vi_untyped, SampleFromPrior()) + model(vi_vnv, SampleFromPrior()) - for vi in [vi_untyped, vi_typed] + model_name = model == model_uv ? "univariate" : "multivariate" + @testset "$(model_name), $(short_varinfo_name(vi))" for vi in [ + vi_untyped, vi_typed, vi_vnv, vi_vnv_typed + ] + Random.seed!(23) vicopy = deepcopy(vi) ### `setval` ### - DynamicPPL.setval!(vicopy, (m=zeros(5),)) + # TODO(mhauru) The interface here seems inconsistent between Metadata and + # VarNamedVector. I'm lazy to fix it though, because I think we need to + # rework it soon anyway. + if vi in [vi_vnv, vi_vnv_typed] + DynamicPPL.setval!(vicopy, zeros(5), m_vns) + else + DynamicPPL.setval!(vicopy, (m=zeros(5),)) + end # Setting `m` fails for univariate due to limitations of `setval!` # and `setval_and_resample!`. See docstring of `setval!` for more info. 
- if model == model_uv + if model == model_uv && vi in [vi_untyped, vi_typed] @test_broken vicopy[m_vns] == zeros(5) else @test vicopy[m_vns] == zeros(5) @@ -240,6 +260,13 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) continue end + if vi in [vi_vnv, vi_vnv_typed] + # `setval_and_resample!` works differently for `VarNamedVector`: All + # values will be resampled when model(vicopy) is called. Hence the below + # tests are not applicable. + continue + end + vicopy = deepcopy(vi) DynamicPPL.setval_and_resample!(vicopy, (m=zeros(5),)) model(vicopy) @@ -338,6 +365,14 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + + ## `SimpleVarInfo{<:VarNamedVector}` + vi = DynamicPPL.settrans!!(SimpleVarInfo(VarNamedVector()), true) + # Sample in unconstrained space. + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) + x = f(DynamicPPL.getindex_internal(vi, vn)) + @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) end @testset "values_as" begin @@ -409,6 +444,12 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) continue end + if DynamicPPL.has_varnamedvector(varinfo) && mutating + # NOTE: Can't handle mutating `link!` and `invlink!` `VarNamedVector`. + @test_broken false + continue + end + # Evaluate the model once to update the logp of the varinfo. 
# Test helper: decide whether the varname containers of `vnv` must be relaxed to an
# abstract element type before `vn` can be stored. `val` is accepted only so that the
# signature matches the other `need_*_relaxation` helpers; it plays no role here.
function need_varnames_relaxation(vnv::VarNamedVector, vn::VarName, val)
    if isconcretetype(eltype(vnv.varnames))
        # If the container is concrete, we need to make sure that the varname types match.
        # E.g. if `vnv.varnames` has `eltype` `VarName{:x, IndexLens{Tuple{Int64}}}` then
        # we need `vn` to also be of this type.
        # => If the varname types don't match, we need to relax the container type.
        # The comparison is against the type of the *varname* `vn`, not of the value.
        return any(keys(vnv)) do vn_present
            typeof(vn_present) !== typeof(vn)
        end
    end

    return false
end
# Collection variant: relaxation is needed as soon as any single `(vn, val)` pair needs it.
function need_varnames_relaxation(vnv::VarNamedVector, vns, vals)
    return any(need_varnames_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals))
end
+ +Similarly: +- If `val` is not compatible with the current values in `vnv`, then + the underlying value type will be changed to `Real`. +- If `val` requires a transformation that is not compatible with the current + transformation type in `vnv`, then the underlying transformation type will + be changed to `Any`. +""" +function relax_container_types(vnv::VarNamedVector, vn::VarName, val) + return relax_container_types(vnv, [vn], [val]) +end +function relax_container_types(vnv::VarNamedVector, vns, vals) + if need_varnames_relaxation(vnv, vns, vals) + varname_to_index_new = convert(OrderedDict{VarName,Int}, vnv.varname_to_index) + varnames_new = convert(Vector{VarName}, vnv.varnames) + else + varname_to_index_new = vnv.varname_to_index + varnames_new = vnv.varnames + end + + transforms_new = if need_transforms_relaxation(vnv, vns, vals) + convert(Vector{Any}, vnv.transforms) + else + vnv.transforms + end + + vals_new = if need_values_relaxation(vnv, vns, vals) + convert(Vector{Real}, vnv.vals) + else + vnv.vals + end + + return VarNamedVector( + varname_to_index_new, + varnames_new, + vnv.ranges, + vals_new, + transforms_new, + vnv.is_unconstrained, + vnv.num_inactive, + ) +end + +@testset "VarNamedVector" begin + # Test element-related operations: + # - `getindex` + # - `setindex!` + # - `push!` + # - `update!` + # + # These are all tested for different types of values: + # - scalar + # - vector + # - matrix + + # Test operations on `VarNamedVector`: + # - `empty!` + # - `iterate` + # - `convert` to + # - `AbstractDict` + test_pairs = OrderedDict( + @varname(x[1]) => rand(), + @varname(x[2]) => rand(2), + @varname(x[3]) => rand(2, 3), + @varname(y[1]) => rand(), + @varname(y[2]) => rand(2), + @varname(y[3]) => rand(2, 3), + @varname(z[1]) => rand(1:10), + @varname(z[2]) => rand(1:10, 2), + @varname(z[3]) => rand(1:10, 2, 3), + ) + test_vns = collect(keys(test_pairs)) + test_vals = collect(values(test_pairs)) + + @testset "constructor: no args" begin + #
Empty. + vnv = VarNamedVector() + @test isempty(vnv) + @test eltype(vnv) == Real + + # Empty with types. + vnv = VarNamedVector{VarName,Float64}() + @test isempty(vnv) + @test eltype(vnv) == Float64 + end + + test_varnames_iter = combinations(test_vns, 2) + @testset "$(vn_left) and $(vn_right)" for (vn_left, vn_right) in test_varnames_iter + val_left = test_pairs[vn_left] + val_right = test_pairs[vn_right] + vnv_base = VarNamedVector([vn_left, vn_right], [val_left, val_right]) + + # We'll need the transformations later. + # TODO: Should we test other transformations than just `ReshapeTransform`? + from_vec_left = DynamicPPL.from_vec_transform(val_left) + from_vec_right = DynamicPPL.from_vec_transform(val_right) + to_vec_left = inverse(from_vec_left) + to_vec_right = inverse(from_vec_right) + + # Compare to alternative constructors. + vnv_from_dict = VarNamedVector( + OrderedDict(vn_left => val_left, vn_right => val_right) + ) + @test vnv_base == vnv_from_dict + + # We want the types of fields such as `varnames` and `transforms` to specialize + # whenever possible; also, some functionality, e.g. `push!`, is only sensible + # if the underlying containers can support it. + # Expected behavior + should_have_restricted_varname_type = typeof(vn_left) == typeof(vn_right) + should_have_restricted_transform_type = size(val_left) == size(val_right) + # Actual behavior + has_restricted_transform_type = isconcretetype(eltype(vnv_base.transforms)) + has_restricted_varname_type = isconcretetype(eltype(vnv_base.varnames)) + + @testset "type specialization" begin + @test !should_have_restricted_varname_type || has_restricted_varname_type + @test !should_have_restricted_transform_type || has_restricted_transform_type + end + + # `eltype` + @test eltype(vnv_base) == promote_type(eltype(val_left), eltype(val_right)) + # `length` + @test length(vnv_base) == length(val_left) + length(val_right) + + # `isempty` + @test !isempty(vnv_base) + + # `empty!` + @testset "empty!"
begin + vnv = deepcopy(vnv_base) + empty!(vnv) + @test isempty(vnv) + end + + # `similar` + @testset "similar" begin + vnv = similar(vnv_base) + @test isempty(vnv) + @test typeof(vnv) == typeof(vnv_base) + end + + # `getindex` + @testset "getindex" begin + # With `VarName` index. + @test vnv_base[vn_left] == val_left + @test vnv_base[vn_right] == val_right + + # With `Int` index. + val_vec = vcat(to_vec_left(val_left), to_vec_right(val_right)) + @test all(vnv_base[i] == val_vec[i] for i in 1:length(val_vec)) + end + + # `setindex!` + @testset "setindex!" begin + vnv = deepcopy(vnv_base) + vnv[vn_left] = val_left .+ 100 + @test vnv[vn_left] == val_left .+ 100 + vnv[vn_right] = val_right .+ 100 + @test vnv[vn_right] == val_right .+ 100 + end + + # `getindex_internal` + @testset "getindex_internal" begin + # With `VarName` index. + @test DynamicPPL.getindex_internal(vnv_base, vn_left) == to_vec_left(val_left) + @test DynamicPPL.getindex_internal(vnv_base, vn_right) == + to_vec_right(val_right) + # With `Int` index. + val_vec = vcat(to_vec_left(val_left), to_vec_right(val_right)) + @test all( + DynamicPPL.getindex_internal(vnv_base, i) == val_vec[i] for + i in 1:length(val_vec) + ) + end + + # `setindex_internal!` + @testset "setindex_internal!" begin + vnv = deepcopy(vnv_base) + DynamicPPL.setindex_internal!(vnv, to_vec_left(val_left .+ 100), vn_left) + @test vnv[vn_left] == val_left .+ 100 + DynamicPPL.setindex_internal!(vnv, to_vec_right(val_right .+ 100), vn_right) + @test vnv[vn_right] == val_right .+ 100 + end + + # `delete!` + @testset "delete!" begin + vnv = deepcopy(vnv_base) + delete!(vnv, vn_left) + @test !haskey(vnv, vn_left) + @test haskey(vnv, vn_right) + delete!(vnv, vn_right) + @test !haskey(vnv, vn_right) + end + + # `merge` + @testset "merge" begin + # When there are no inactive entries, `merge` with itself results in the same. + @test merge(vnv_base, vnv_base) == vnv_base + + # Merging with empty should result in the same.
+ @test merge(vnv_base, similar(vnv_base)) == vnv_base + @test merge(similar(vnv_base), vnv_base) == vnv_base + + # With differences. + vnv_left_only = deepcopy(vnv_base) + delete!(vnv_left_only, vn_right) + vnv_right_only = deepcopy(vnv_base) + delete!(vnv_right_only, vn_left) + + # `(x,)` and `(x, y)` should be `(x, y)`. + @test merge(vnv_left_only, vnv_base) == vnv_base + # `(x, y)` and `(x,)` should be `(x, y)`. + @test merge(vnv_base, vnv_left_only) == vnv_base + # `(x, y)` and `(y,)` should be `(x, y)`. + @test merge(vnv_base, vnv_right_only) == vnv_base + # `(y,)` and `(x, y)` should be `(y, x)`. + vnv_merged = merge(vnv_right_only, vnv_base) + @test vnv_merged != vnv_base + @test collect(keys(vnv_merged)) == [vn_right, vn_left] + end + + # `push!` & `update!` + @testset "push!" begin + vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) + @testset "$vn" for vn in test_vns + val = test_pairs[vn] + if vn == vn_left || vn == vn_right + # It should not be possible to `push!` an existing varname. + @test_throws ArgumentError push!(vnv, vn, val) + else + push!(vnv, vn, val) + @test vnv[vn] == val + end + end + end + + @testset "update!" begin + vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) + @testset "$vn" for vn in test_vns + val = test_pairs[vn] + expected_length = if haskey(vnv, vn) + # If it's already present, the resulting length will be unchanged. + length(vnv) + else + length(vnv) + length(val) + end + + DynamicPPL.update!(vnv, vn, val .+ 1) + x = vnv[:] + @test vnv[vn] == val .+ 1 + @test length(vnv) == expected_length + @test length(x) == length(vnv) + + # There should be no redundant values in the underlying vector. + @test !DynamicPPL.has_inactive(vnv) + + # `getindex` with `Int` index.
+ @test all(vnv[i] == x[i] for i in 1:length(x)) + end + + vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) + @testset "$vn (increased size)" for vn in test_vns + val_original = test_pairs[vn] + val = increase_size_for_test(val_original) + vn_already_present = haskey(vnv, vn) + expected_length = if vn_already_present + # If it's already present, the length changes by the difference in size. + length(vnv) + length(val) - length(val_original) + else + length(vnv) + length(val) + end + + DynamicPPL.update!(vnv, vn, val .+ 1) + x = vnv[:] + @test vnv[vn] == val .+ 1 + @test length(vnv) == expected_length + @test length(x) == length(vnv) + + # `getindex` with `Int` index. + @test all(vnv[i] == x[i] for i in 1:length(x)) + end + + vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) + @testset "$vn (decreased size)" for vn in test_vns + val_original = test_pairs[vn] + val = decrease_size_for_test(val_original) + vn_already_present = haskey(vnv, vn) + expected_length = if vn_already_present + # If it's already present, the length changes by the difference in size. + length(vnv) + length(val) - length(val_original) + else + length(vnv) + length(val) + end + DynamicPPL.update!(vnv, vn, val .+ 1) + x = vnv[:] + @test vnv[vn] == val .+ 1 + @test length(vnv) == expected_length + @test length(x) == length(vnv) + + # `getindex` with `Int` index. + @test all(vnv[i] == x[i] for i in 1:length(x)) + end + end + end + + @testset "growing and shrinking" begin + @testset "deterministic" begin + n = 5 + vn = @varname(x) + vnv = VarNamedVector(OrderedDict(vn => [true])) + @test !DynamicPPL.has_inactive(vnv) + # Growing should not create inactive ranges. + for i in 1:n + x = fill(true, i) + DynamicPPL.update!(vnv, vn, x) + @test !DynamicPPL.has_inactive(vnv) + end + + # Same size should not create inactive ranges. + x = fill(true, n) + DynamicPPL.update!(vnv, vn, x) + @test !DynamicPPL.has_inactive(vnv) + + # Shrinking should create inactive ranges.
+ for i in (n - 1):-1:1 + x = fill(true, i) + DynamicPPL.update!(vnv, vn, x) + @test DynamicPPL.has_inactive(vnv) + @test DynamicPPL.num_inactive(vnv, vn) == n - i + end + end + + @testset "random" begin + n = 5 + vn = @varname(x) + vnv = VarNamedVector(OrderedDict(vn => [true])) + @test !DynamicPPL.has_inactive(vnv) + + # Insert a bunch of random-length vectors. + for i in 1:100 + x = fill(true, rand(1:n)) + DynamicPPL.update!(vnv, vn, x) + end + # Should never be allocating more than `n` elements. + @test DynamicPPL.num_allocated(vnv, vn) ≤ n + + # If we contiguify, then the allocation should always match the size just inserted. + for i in 1:10 + x = fill(true, rand(1:n)) + DynamicPPL.update!(vnv, vn, x) + DynamicPPL.contiguify!(vnv) + @test DynamicPPL.num_allocated(vnv, vn) == length(x) + end + end + end + + @testset "subset" begin + vnv = VarNamedVector(test_pairs) + @test subset(vnv, test_vns) == vnv + @test subset(vnv, VarName[]) == VarNamedVector() + @test merge(subset(vnv, test_vns[1:3]), subset(vnv, test_vns[4:end])) == vnv + + # Test that subset preserves transformations and unconstrainedness. + vn = @varname(t[1]) + vns = vcat(test_vns, [vn]) + vnv = push!!(vnv, vn, 2.0, x -> x^2) + DynamicPPL.settrans!(vnv, true, @varname(t[1])) + @test vnv[@varname(t[1])] == 4.0 + @test istrans(vnv, @varname(t[1])) + @test subset(vnv, vns) == vnv + end +end + +@testset "VarInfo + VarNamedVector" begin + models = DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in models + # NOTE: Need to set random seed explicitly to avoid using the same seed + # for initialization as for sampling in the inner testset below. + Random.seed!(42) + value_true = DynamicPPL.TestUtils.rand_prior_true(model) + vns = DynamicPPL.TestUtils.varnames(model) + varnames = DynamicPPL.TestUtils.varnames(model) + varinfos = DynamicPPL.TestUtils.setup_varinfos( + model, value_true, varnames; include_threadsafe=false + ) + # Filter out those which are not based on `VarNamedVector`.
+ varinfos = filter(DynamicPPL.has_varnamedvector, varinfos) + # Get the true log joint. + logp_true = DynamicPPL.TestUtils.logjoint_true(model, value_true...) + + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + # Need to make sure we're using a different random seed from the + # one used in the above call to `rand_prior_true`. + Random.seed!(43) + + # Are values correct? + DynamicPPL.TestUtils.test_values(varinfo, value_true, vns) + + # Is evaluation correct? + varinfo_eval = last( + DynamicPPL.evaluate!!(model, deepcopy(varinfo), DefaultContext()) + ) + # Log density should match the true log joint. + @test getlogp(varinfo_eval) ≈ logp_true + # Values should still equal `value_true`. + DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns) + + # Is sampling correct? + varinfo_sample = last( + DynamicPPL.evaluate!!(model, deepcopy(varinfo), SamplingContext()) + ) + # Log density should differ from that of the original `varinfo`. + @test getlogp(varinfo_sample) != getlogp(varinfo) + # Values should differ from `value_true`. + DynamicPPL.TestUtils.test_values( + varinfo_sample, value_true, vns; compare=!isequal + ) + end + end +end