Skip to content

Commit

Permalink
Tiny style improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru committed Sep 2, 2024
1 parent 1a4eadb commit 5c2a625
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
16 changes: 8 additions & 8 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(conte
"""
struct DefaultContext <: AbstractContext end
The `DefaultContext` is used by default to compute the log joint probability of the data
The `DefaultContext` is used by default to compute the log joint probability of the data
and parameters when running the model.
"""
struct DefaultContext <: AbstractContext end
Expand All @@ -199,7 +199,7 @@ NodeTrait(context::DefaultContext) = IsLeaf()
vars::Tvars
end
The `PriorContext` enables the computation of the log prior of the parameters `vars` when
The `PriorContext` enables the computation of the log prior of the parameters `vars` when
running the model.
"""
struct PriorContext{Tvars} <: AbstractContext
Expand All @@ -213,8 +213,8 @@ NodeTrait(context::PriorContext) = IsLeaf()
vars::Tvars
end
The `LikelihoodContext` enables the computation of the log likelihood of the parameters when
running the model. `vars` can be used to evaluate the log likelihood for specific values
The `LikelihoodContext` enables the computation of the log likelihood of the parameters when
running the model. `vars` can be used to evaluate the log likelihood for specific values
of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default.
"""
struct LikelihoodContext{Tvars} <: AbstractContext
Expand All @@ -229,10 +229,10 @@ NodeTrait(context::LikelihoodContext) = IsLeaf()
loglike_scalar::T
end
The `MiniBatchContext` enables the computation of
`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the
`loglike_scalar` field, typically equal to `the number of data points / batch size`.
This is useful in batch-based stochastic gradient descent algorithms to be optimizing
The `MiniBatchContext` enables the computation of
`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the
`loglike_scalar` field, typically equal to `the number of data points / batch size`.
This is useful in batch-based stochastic gradient descent algorithms to be optimizing
`log(prior) + log(likelihood of all the data points)` in the expectation.
"""
struct MiniBatchContext{Tctx,T} <: AbstractContext
Expand Down
14 changes: 7 additions & 7 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
f::F
args::NamedTuple{argnames,Targs}
defaults::NamedTuple{defaultnames,Tdefaults}
context::Ctx
context::Ctx=DefaultContext()
end
A `Model` struct with model evaluation function of type `F`, arguments of names `argnames`
Expand Down Expand Up @@ -1079,7 +1079,7 @@ end
Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`.
# Examples
```jldoctest
julia> using MCMCChains, Distributions
Expand All @@ -1095,7 +1095,7 @@ julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);
julia> logjoint(demo_model([1., 2.]), chain);
```
```
"""
function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
Expand Down Expand Up @@ -1126,7 +1126,7 @@ end
Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`.
# Examples
```jldoctest
julia> using MCMCChains, Distributions
Expand All @@ -1142,7 +1142,7 @@ julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);
julia> logprior(demo_model([1., 2.]), chain);
```
```
"""
function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
Expand Down Expand Up @@ -1173,7 +1173,7 @@ end
Return an array of log likelihoods evaluated at each sample in an MCMC `chain`.
# Examples
```jldoctest
julia> using MCMCChains, Distributions
Expand All @@ -1189,7 +1189,7 @@ julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);
julia> loglikelihood(demo_model([1., 2.]), chain);
```
```
"""
function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
Expand Down

0 comments on commit 5c2a625

Please sign in to comment.