
WIP: Ratcliff Diffusion Model pdf and simulators #90

Open
wants to merge 3 commits into master

Conversation

kiante-fernandez
Contributor

The last PR was quite stale and did not keep up with the new type interface. Here is the new WIP for the Ratcliff DDM.

@kiante-fernandez
Contributor Author

kiante-fernandez commented Jul 12, 2024

Could use some feedback. I might have botched the rejection sampling translation.
Also, after talking with folks at JuliaCon, I think we are safe using HCubature here.

Lastly, I'm not sure what we want to do about the CDF. I am aware of some existing implementations, but I'm not sure which, if any, we should adapt to use HCubature.
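One hedged option for the CDF (purely a sketch, not something we have settled on) would be to integrate whatever joint choice-RT density we adopt over decision time numerically; ddm_pdf below is a hypothetical stand-in name:

using HCubature

# Sketch only: ddm_pdf(d, choice, rt) stands in for whichever pdf we end up with.
function ddm_cdf(d, choice, t)
    t <= d.τ && return zero(t)   # no mass before the non-decision time τ
    val, _ = hcubature(rt -> ddm_pdf(d, choice, rt[1]), [d.τ], [t])
    return val
end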

Contributor

github-actions bot commented Jul 12, 2024

Benchmark Results

| Benchmark | master | 80791c2... | master/80791c22871b99... |
|---|---|---|---|
| logpdf/("SequentialSamplingModels.DDM", 10) | 1.7 ± 0.18 μs | 0.088 ± 0.027 ms | 0.0193 |
| logpdf/("SequentialSamplingModels.DDM", 100) | 17.3 ± 0.78 μs | 0.886 ± 0.11 ms | 0.0195 |
| logpdf/("SequentialSamplingModels.LBA", 10) | 2.5 ± 0.2 μs | 2.5 ± 0.19 μs | 0.999 |
| logpdf/("SequentialSamplingModels.LBA", 100) | 23.9 ± 0.57 μs | 23.8 ± 0.62 μs | 1 |
| logpdf/("SequentialSamplingModels.LNR", 10) | 1.02 ± 0.18 μs | 1.02 ± 0.17 μs | 1 |
| logpdf/("SequentialSamplingModels.LNR", 100) | 8.65 ± 0.27 μs | 8.65 ± 0.28 μs | 1 |
| logpdf/("SequentialSamplingModels.RDM", 10) | 2.64 ± 0.24 μs | 2.63 ± 0.25 μs | 1 |
| logpdf/("SequentialSamplingModels.RDM", 100) | 25.2 ± 0.69 μs | 25.1 ± 0.71 μs | 1 |
| logpdf/("SequentialSamplingModels.Wald", 10) | 0.228 ± 0.17 μs | 0.228 ± 0.17 μs | 1 |
| logpdf/("SequentialSamplingModels.Wald", 100) | 2.02 ± 0.056 μs | 2.01 ± 0.041 μs | 1 |
| logpdf/("SequentialSamplingModels.WaldMixture", 10) | 1.12 ± 0.17 μs | 1.12 ± 0.16 μs | 1 |
| logpdf/("SequentialSamplingModels.WaldMixture", 100) | 10.9 ± 0.16 μs | 10.9 ± 0.16 μs | 0.999 |
| rand/("SequentialSamplingModels.DDM", 10) | 2.91 ± 0.38 μs | 5.43 ± 0.71 μs | 0.537 |
| rand/("SequentialSamplingModels.DDM", 100) | 27.8 ± 1.3 μs | 0.0531 ± 0.0023 ms | 0.523 |
| rand/("SequentialSamplingModels.LBA", 10) | 3.23 ± 1.3 μs | 3.25 ± 1.3 μs | 0.993 |
| rand/("SequentialSamplingModels.LBA", 100) | 16.6 ± 0.39 μs | 16.9 ± 0.35 μs | 0.987 |
| rand/("SequentialSamplingModels.LCA", 10) | 0.589 ± 0.2 ms | 0.582 ± 0.2 ms | 1.01 |
| rand/("SequentialSamplingModels.LCA", 100) | 6.42 ± 0.26 ms | 6.33 ± 0.23 ms | 1.02 |
| rand/("SequentialSamplingModels.LNR", 10) | 1.07 ± 0.17 μs | 1.09 ± 0.17 μs | 0.977 |
| rand/("SequentialSamplingModels.LNR", 100) | 7.45 ± 3.7 μs | 7.47 ± 3.7 μs | 0.997 |
| rand/("SequentialSamplingModels.RDM", 10) | 1.47 ± 0.34 μs | 1.48 ± 0.33 μs | 0.997 |
| rand/("SequentialSamplingModels.RDM", 100) | 10.8 ± 3.7 μs | 11 ± 3.8 μs | 0.986 |
| rand/("SequentialSamplingModels.Wald", 10) | 0.464 ± 0.16 μs | 0.471 ± 0.16 μs | 0.986 |
| rand/("SequentialSamplingModels.Wald", 100) | 2.88 ± 0.23 μs | 2.9 ± 0.18 μs | 0.996 |
| rand/("SequentialSamplingModels.WaldMixture", 10) | 1.22 ± 0.17 μs | 1.22 ± 0.17 μs | 1 |
| rand/("SequentialSamplingModels.WaldMixture", 100) | 11.8 ± 0.19 μs | 11.8 ± 0.19 μs | 1.01 |
| simulate/SequentialSamplingModels.DDM | 3.69 ± 1.5 μs | 3.75 ± 1.7 μs | 0.983 |
| simulate/SequentialSamplingModels.LBA | 3.69 ± 0.37 μs | 3.69 ± 0.37 μs | 1 |
| simulate/SequentialSamplingModels.LCA | 0.0797 ± 0.015 ms | 0.0807 ± 0.016 ms | 0.988 |
| simulate/SequentialSamplingModels.RDM | 0.0729 ± 0.025 ms | 0.075 ± 0.023 ms | 0.972 |
| simulate/SequentialSamplingModels.Wald | 9.08 ± 4.6 μs | 9.4 ± 5.2 μs | 0.966 |
| simulate/SequentialSamplingModels.WaldMixture | 4 ± 1.5 μs | 4.05 ± 1.5 μs | 0.988 |
| simulate/mdft | 0.177 ± 0.071 ms | 0.179 ± 0.068 ms | 0.987 |
| time_to_load | 1.54 ± 0.0061 s | 1.54 ± 0.0088 s | 1 |

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR.
Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).

@itsdfish
Owner

Sounds good. I'll see if I can help.

Let me verify whether I understand the basic approach. We want to integrate over nu via _pdf_sv, and integrate over st and sz numerically with hcubature. Is that correct?
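Roughly, I am imagining something like the sketch below (hypothetical names, not the PR code; it assumes _pdf_sv accepts z and τ as keyword overrides and that sz and st are the full widths of the uniform distributions):

using HCubature

# Sketch: _pdf_sv handles the integral over nu analytically; sz and st are integrated
# numerically. Mixed cases (only one of sz/st equal to zero) would need separate handling.
function _pdf_full(d, choice, rt)
    (; z, τ, sz, st) = d
    sz == 0 && st == 0 && return _pdf_sv(d, choice, rt; z = z, τ = τ)
    lo = [z - sz / 2, τ - st / 2]
    hi = [z + sz / 2, τ + st / 2]
    val, _ = hcubature(x -> _pdf_sv(d, choice, rt; z = x[1], τ = x[2]), lo, hi)
    return val / (sz * st)
end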

@itsdfish
Owner

Here is my plan. I will start by testing _small_time_pdf and _large_time_pdf because they are part of the baseline pdf function _pdf. Once _pdf is working, I will test _pdf_sv, and once that is working, I will be able to test each component computed via hcubature. I should be able to start over the weekend. I will post my changes to full_DDM on the primary fork so you can track progress.

@itsdfish
Owner

itsdfish commented Jul 12, 2024

Before digging into the minutiae of the code, I wanted to investigate the feasibility of differentiating through a numeric integrator. So far I have not had success. I'm not sure if the devs you talked to at JuliaCon might have some insights for making it work. Here is my example:

Code

using Distributions 
import Distributions: insupport
import Distributions: logpdf 
import Distributions: loglikelihood
import Distributions: maximum
import Distributions: minimum
import Distributions: rand 
import Distributions: pdf
using HCubature 
using Turing 

struct NormalMixture{T<:Real} <: ContinuousUnivariateDistribution
    μ::T 
    σmin::T
    σmax::T
end


minimum(d::NormalMixture) = -Inf
maximum(d::NormalMixture) = Inf

insupport(d::NormalMixture, rt::Real) = true

Base.broadcastable(d::NormalMixture) = Ref(d)

function NormalMixture(μ, σmin, σmax)
    return NormalMixture(promote(μ, σmin, σmax)...)
end

NormalMixture(; μ, σmin = 0, σmax= 2) = NormalMixture(μ, σmin, σmax)

function rand(dist::NormalMixture, n::Int)
    x = fill(0.0, n)
    for i ∈ 1:n 
        x[i] = rand(dist)
    end
    return x 
end

function rand(dist::NormalMixture)
    (; μ, σmin, σmax) = dist 
    σ = rand(Uniform(σmin, σmax))
    return rand(Normal(μ, σ))
end

function pdf(dist::NormalMixture, x::Real)
    (; μ, σmin, σmax) = dist 
    Δ = σmax - σmin 
    return hcubature(σ -> pdf(Normal(μ, σ[1]), x), [σmin], [σmax])[1] / Δ
end

logpdf(dist::NormalMixture, x::Real) = log(pdf(dist, x))
loglikelihood(dist::NormalMixture, x::Real) = logpdf(dist, x)

data = rand(NormalMixture(μ = 0.0), 100)

logpdf(NormalMixture(μ = 0.5), 0.0)

@model function model(data)
    μ ~ Normal(0, 10)
    data ~ NormalMixture(; μ)
end

chains = sample(model(data), NUTS(), 1000)

Error

ERROR: MethodError: no method matching kronrod(::Type{ForwardDiff.Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 1}}, ::Int64)

Closest candidates are:
  kronrod(::Any, ::Integer, ::Real, ::Real; rtol, quad)
   @ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/weightedgauss.jl:90
  kronrod(::Type{T}, ::Integer) where T<:AbstractFloat
   @ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/gausskronrod.jl:316
  kronrod(::AbstractMatrix{<:Real}, ::Integer, ::Real, ::Pair{<:Tuple{Real, Real}, <:Tuple{Real, Real}})
   @ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/gausskronrod.jl:390
  ...

Stacktrace:
  [1] GaussKronrod
    @ ~/.julia/packages/HCubature/gOo1d/src/gauss-kronrod.jl:16 [inlined]
  [2] cubrule
    @ ~/.julia/packages/HCubature/gOo1d/src/HCubature.jl:38 [inlined]
  [3] hcubature_(f::var"#54#55"{…}, a::StaticArraysCore.SVector{…}, b::StaticArraysCore.SVector{…}, norm::typeof(LinearAlgebra.norm), rtol_::Int64, atol::Int64, maxevals::Int64, initdiv::Int64, buf::Nothing)
    @ HCubature ~/.julia/packages/HCubature/gOo1d/src/HCubature.jl:105
  [4] hcubature_(f::Function, a::Vector{…}, b::Vector{…}, norm::Function, rtol::Int64, atol::Int64, maxevals::Int64, initdiv::Int64, buf::Nothing)
    @ HCubature ~/.julia/packages/HCubature/gOo1d/src/HCubature.jl:179
  [5] hcubature
    @ ~/.julia/packages/HCubature/gOo1d/src/HCubature.jl:234 [inlined]
  [6] pdf
    @ ~/.julia/dev/sandbox/DDM/hcubature_turing/hcubature_turing.jl:52 [inlined]
  [7] logpdf(dist::NormalMixture{ForwardDiff.Dual{ForwardDiff.Tag{…}, Float64, 1}}, x::Float64)
    @ Main ~/.julia/dev/sandbox/DDM/hcubature_turing/hcubature_turing.jl:55
  [8] Fix1
    @ ./operators.jl:1118 [inlined]
  [9] mapreduce_impl(f::Base.Fix1{…}, op::typeof(Base.add_sum), A::Vector{…}, ifirst::Int64, ilast::Int64, blksize::Int64)
    @ Base ./reduce.jl:262
 [10] mapreduce_impl
    @ ./reduce.jl:277 [inlined]
 [11] _mapreduce(f::Base.Fix1{typeof(logpdf), NormalMixture{…}}, op::typeof(Base.add_sum), ::IndexLinear, A::Vector{Float64})
    @ Base ./reduce.jl:447
 [12] _mapreduce_dim
    @ ./reducedim.jl:365 [inlined]
 [13] mapreduce
    @ ./reducedim.jl:357 [inlined]
 [14] _sum
    @ ./reducedim.jl:1015 [inlined]
 [15] sum
    @ ./reducedim.jl:1011 [inlined]
 [16] loglikelihood
    @ ~/.julia/packages/Distributions/ji8PW/src/common.jl:458 [inlined]
 [17] observe
    @ ~/.julia/packages/DynamicPPL/ACaKr/src/context_implementations.jl:266 [inlined]
 [18] observe
    @ ~/.julia/packages/Turing/duwEY/src/mcmc/hmc.jl:529 [inlined]
 [19] tilde_observe
    @ ~/.julia/packages/DynamicPPL/ACaKr/src/context_implementations.jl:158 [inlined]
 [20] tilde_observe
    @ ~/.julia/packages/DynamicPPL/ACaKr/src/context_implementations.jl:156 [inlined]
 [21] tilde_observe
    @ ~/.julia/packages/DynamicPPL/ACaKr/src/context_implementations.jl:151 [inlined]
 [22] tilde_observe!!(context::DynamicPPL.SamplingContext{…}, right::NormalMixture{…}, left::Vector{…}, vi::DynamicPPL.ThreadSafeVarInfo{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/ACaKr/src/context_implementations.jl:207
 [23] tilde_observe!!
    @ ~/.julia/packages/DynamicPPL/ACaKr/src/context_implementations.jl:194 [inlined]
 [24] macro expansion
    @ ~/.julia/packages/DynamicPPL/ACaKr/src/compiler.jl:579 [inlined]
 [25] model
    @ ~/.julia/dev/sandbox/DDM/hcubature_turing/hcubature_turing.jl:64 [inlined]
 [26] _evaluate!!
    @ ~/.julia/packages/DynamicPPL/ACaKr/src/model.jl:968 [inlined]
 [27] evaluate_threadsafe!!(model::DynamicPPL.Model{…}, varinfo::DynamicPPL.TypedVarInfo{…}, context::DynamicPPL.SamplingContext{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/ACaKr/src/model.jl:957
 [28] evaluate!!
    @ ~/.julia/packages/DynamicPPL/ACaKr/src/model.jl:892 [inlined]
 [29] logdensity(f::LogDensityFunction{…}, θ::Vector{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/ACaKr/src/logdensityfunction.jl:100
 [30] Fix1
    @ ./operators.jl:1118 [inlined]
 [31] vector_mode_dual_eval!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/apiutils.jl:24 [inlined]
 [32] vector_mode_gradient!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:96 [inlined]
 [33] gradient!(result::DiffResults.MutableDiffResult{…}, f::Base.Fix1{…}, x::Vector{…}, cfg::ForwardDiff.GradientConfig{…}, ::Val{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:37
 [34] gradient!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:35 [inlined]
 [35] logdensity_and_gradient
    @ ~/.julia/packages/LogDensityProblemsAD/rBlLq/ext/LogDensityProblemsADForwardDiffExt.jl:118 [inlined]
 [36] ∂logπ∂θ
    @ ~/.julia/packages/Turing/duwEY/src/mcmc/hmc.jl:180 [inlined]
 [37] ∂H∂θ
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/hamiltonian.jl:38 [inlined]
 [38] phasepoint(h::AdvancedHMC.Hamiltonian{…}, θ::Vector{…}, r::Vector{…})
    @ AdvancedHMC ~/.julia/packages/AdvancedHMC/AlvV4/src/hamiltonian.jl:74
 [39] phasepoint(rng::Random.TaskLocalRNG, θ::Vector{…}, h::AdvancedHMC.Hamiltonian{…})
    @ AdvancedHMC ~/.julia/packages/AdvancedHMC/AlvV4/src/hamiltonian.jl:155
 [40] initialstep(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}, vi_original::DynamicPPL.TypedVarInfo{…}; initial_params::Nothing, nadapts::Int64, kwargs::@Kwargs{})
    @ Turing.Inference ~/.julia/packages/Turing/duwEY/src/mcmc/hmc.jl:184
 [41] step(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}; initial_params::Nothing, kwargs::@Kwargs{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/ACaKr/src/sampler.jl:116
 [42] step
    @ ~/.julia/packages/DynamicPPL/ACaKr/src/sampler.jl:99 [inlined]
 [43] macro expansion
    @ ~/.julia/packages/AbstractMCMC/YrmkI/src/sample.jl:130 [inlined]
 [44] macro expansion
    @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
 [45] macro expansion
    @ ~/.julia/packages/AbstractMCMC/YrmkI/src/logging.jl:9 [inlined]
 [46] mcmcsample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{…})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/YrmkI/src/sample.jl:120
 [47] sample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; chain_type::Type, resume_from::Nothing, initial_state::Nothing, progress::Bool, nadapts::Int64, discard_adapt::Bool, discard_initial::Int64, kwargs::@Kwargs{})
    @ Turing.Inference ~/.julia/packages/Turing/duwEY/src/mcmc/hmc.jl:123
 [48] sample
    @ ~/.julia/packages/Turing/duwEY/src/mcmc/hmc.jl:92 [inlined]
 [49] #sample#4
    @ ~/.julia/packages/Turing/duwEY/src/mcmc/Inference.jl:272 [inlined]
 [50] sample
    @ ~/.julia/packages/Turing/duwEY/src/mcmc/Inference.jl:263 [inlined]
 [51] #sample#3
    @ ~/.julia/packages/Turing/duwEY/src/mcmc/Inference.jl:260 [inlined]
 [52] sample(model::DynamicPPL.Model{…}, alg::NUTS{…}, N::Int64)
    @ Turing.Inference ~/.julia/packages/Turing/duwEY/src/mcmc/Inference.jl:254
Some type information was truncated. Use `show(err)` to see complete types.

@itsdfish
Copy link
Owner

I think the source of the problem can be traced to this line in QuadGK. My guess is that the parametric restriction on kronrod(::Type{T}, n::Integer) where T<:AbstractFloat needs to be relaxed to T<:Real because ForwardDiff.Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 1} <: Real. Of course, there might be other points where AD will fail. Another thing I am concerned about is AD blowing up as it goes through the integration code.
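A quick check of why relaxing the constraint could be enough (a Dual is a subtype of Real, but not of AbstractFloat):

using ForwardDiff

# The tag type does not matter for this check; Nothing is just a placeholder.
ForwardDiff.Dual{Nothing, Float64, 1} <: Real           # true
ForwardDiff.Dual{Nothing, Float64, 1} <: AbstractFloat  # false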

@kiante-fernandez
Contributor Author

> Sounds good. I'll see if I can help.
>
> Let me verify whether I understand the basic approach. We want to integrate over nu via _pdf_sv, and integrate over st and sz numerically with hcubature. Is that correct?

Yes, that is correct.

@kiante-fernandez
Contributor Author

> Here is my plan. I will start by testing _small_time_pdf and _large_time_pdf because they are part of the baseline pdf function _pdf. Once _pdf is working, I will test _pdf_sv, and once that is working, I will be able to test each component computed via hcubature. I should be able to start over the weekend. I will post my changes to full_DDM on the primary fork so you can track progress.

Sounds good. I have discussed this exact issue regarding hcubature with @gdalle. I will bring up the MWE and touch base tomorrow. From my understanding, it does indeed have to do with the interaction with AD, and he discussed some solutions.

@kiante-fernandez
Contributor Author

I can get it to work with a hard-coded version. It is something about the type; I will keep looking into it.

using Distributions 
import Distributions: insupport
import Distributions: logpdf 
import Distributions: loglikelihood
import Distributions: maximum
import Distributions: minimum
import Distributions: rand 
import Distributions: pdf
using StaticArrays: SVector
using LinearAlgebra: norm
using HCubature 

using ForwardDiff
using Optim
using Turing 

struct GaussKronrod{T<:Real}
    x::Vector{T}
    w::Vector{T}
    wg::Vector{T}
end

# Hardcoded Gauss-Kronrod rule for n=7 (15-point rule)
function GaussKronrod(::Type{T}) where T<:Real
    x = T[
        0.0,
        0.2077849550078985,
        0.4058451513773972,
        0.5860872354676911,
        0.7415311855993944,
        0.8648644233597691,
        0.9491079123427585,
        0.9914553711208126,
    ]
    w = T[
        0.2094821410847278,
        0.2044329400752989,
        0.1903505780647854,
        0.1690047266392679,
        0.1406532597155259,
        0.1047900103222502,
        0.06309209262997855,
        0.02293532201052922,
    ]
    wg = T[
        0.4179591836734694,
        0.3818300505051189,
        0.3818300505051189,
        0.2797053914892767,
        0.2797053914892767,
        0.1294849661688697,
        0.1294849661688697,
    ]
    return GaussKronrod{T}(x, w, wg)
end

const gk_float64 = GaussKronrod(Float64)
GaussKronrod(::Type{Float64}) = gk_float64

countevals(g::GaussKronrod) = 15

function (g::GaussKronrod{T})(f::F, a_::SVector{1}, b_::SVector{1}, norm=norm) where {F,T}
    a = a_[1]
    b = b_[1]
    c = (a + b) * T(0.5)
    Δ = (b - a) * T(0.5)
    fx⁰ = f(SVector(c))                # f(0)
    I = fx⁰ * g.w[1]
    I′ = fx⁰ * g.wg[1]
    @inbounds for i = 2:length(g.x)
        Δx = Δ * g.x[i]
        fx = f(SVector(c + Δx)) + f(SVector(c - Δx))
        I += fx * g.w[i]
        if i <= length(g.wg)
            I′ += fx * g.wg[i]
        end
    end
    I *= Δ
    I′ *= Δ
    return I, norm(I - I′), 1
end

# Function to perform the integration
function custom_hcubature(f, a, b; norm=norm)
    gk = GaussKronrod(eltype(a))
    I, E, _ = gk(f, SVector(a...), SVector(b...), norm)
    return I, E
end

struct NormalMixture{T<:Real} <: ContinuousUnivariateDistribution
    μ::T 
    σmin::T
    σmax::T
end

function NormalMixture(μ, σmin, σmax)
    return NormalMixture(promote(μ, σmin, σmax)...)
end

NormalMixture(; μ, σmin = 0, σmax= 2) = NormalMixture(μ, σmin, σmax)

minimum(d::NormalMixture) = -Inf
maximum(d::NormalMixture) = Inf
insupport(d::NormalMixture, x::Real) = true

Base.broadcastable(d::NormalMixture) = Ref(d)

function rand(dist::NormalMixture, n::Int)
    x = fill(0.0, n)
    for i ∈ 1:n
        x[i] = rand(dist)
    end
    return x 
end

function rand(dist::NormalMixture)
    (; μ, σmin, σmax) = dist 
    σ = rand(Uniform(σmin, σmax))
    return rand(Normal(μ, σ))
end

function pdf(dist::NormalMixture, x::Real)
    (; μ, σmin, σmax) = dist 
    Δ = σmax - σmin 
    integral, _ = custom_hcubature(σ -> pdf(Normal(μ, σ[1]), x), [σmin], [σmax])
    return integral / Δ
end

logpdf(dist::NormalMixture, x::Real) = log(pdf(dist, x))
loglikelihood(dist::NormalMixture, x::Real) = logpdf(dist, x)

data = rand(NormalMixture(μ = 0.0), 100)

#try a few different test cases

## example 1
function test_normalmixture(params)
    d = NormalMixture(μ = params[1], σmin = params[2], σmax = params[3])
    return logpdf(d, 0.5) 
end

ForwardDiff.gradient(test_normalmixture, [0.0, 1.0, 2.0])

## example 2
function normalmixture_vector(params)
    d = NormalMixture(μ = params[1], σmin = params[2], σmax = params[3])
    return [logpdf(d, 0.5), pdf(d, 0.5)]
end

ForwardDiff.jacobian(normalmixture_vector, [0.0, 1.0, 2.0])

## example 3
function log_likelihood(μ, data)
    return sum(logpdf(NormalMixture(μ = μ[1]), x) for x in data)
end

neg_log_likelihood(μ) = -log_likelihood(μ, data)
# Use ForwardDiff to compute the gradient
function grad_neg_log_likelihood!(storage, μ)
    ForwardDiff.gradient!(storage, neg_log_likelihood, μ)
end

## example 4
function test_gradient()
    # Choose a point to evaluate the gradient
    test_point = [1.0]

    # Create a storage vector for the gradient
    gradient = similar(test_point)
    
    # Compute the gradient
    grad_neg_log_likelihood!(gradient, test_point)

end

test_gradient()

## example 5
# Perform optimization
optimize(neg_log_likelihood, grad_neg_log_likelihood!, [0.0], BFGS())

## example 6
@model function model(data)
    μ ~ Normal(0, 10)
    data ~ NormalMixture(; μ)
end

chains = sample(model(data), NUTS(;adtype = AutoForwardDiff()), 1000)

@kiante-fernandez
Contributor Author

kiante-fernandez commented Jul 12, 2024

Yeah, it is indeed about the type. The key change is the extension of the kronrod function to work with Dual numbers, which allows HCubature to use these types internally. In other words, we have the "let it breathe" approach: see the talk on AD.

function QuadGK.kronrod(::Type{ForwardDiff.Dual{T,V,N}}, n::Integer) where {T,V,N}
    x, w, wg = QuadGK.kronrod(V, n)
    return map(y -> ForwardDiff.Dual{T,V,N}(y), x),
           map(y -> ForwardDiff.Dual{T,V,N}(y), w),
           map(y -> ForwardDiff.Dual{T,V,N}(y), wg)
end
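A quick way to see what this unlocks: the Gauss-Kronrod nodes and weights can now be constructed for Dual element types, so the MethodError from the stack trace above should no longer be thrown.

using ForwardDiff, QuadGK

# With the method above defined, this call succeeds instead of erroring:
x, w, wg = QuadGK.kronrod(ForwardDiff.Dual{Nothing, Float64, 1}, 7)
eltype(x)   # ForwardDiff.Dual{Nothing, Float64, 1}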

@itsdfish
Owner

itsdfish commented Jul 12, 2024

Awesome. This is helpful. I have a few questions. First, is this something that should be changed at QuadGK, or should we add it as a package extension? Second, does the function above work for arbitrary inputs?

Also, I didn't know Julia highlighting works on GitHub.

Update

This might be type piracy. So it might be something that needs to be changed at QuadGK, or maybe ChainRules.
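For concreteness, the package-extension route would look roughly like the sketch below (hypothetical extension name, with [weakdeps]/[extensions] entries for ForwardDiff and QuadGK in Project.toml). Defining the method in our own code this way would still be pirating a function and a type we don't own.

# ext/QuadGKForwardDiffExt.jl (hypothetical name), loaded automatically once both
# QuadGK and ForwardDiff are in the environment.
module QuadGKForwardDiffExt

using QuadGK, ForwardDiff

function QuadGK.kronrod(::Type{ForwardDiff.Dual{T,V,N}}, n::Integer) where {T,V,N}
    x, w, wg = QuadGK.kronrod(V, n)
    D = ForwardDiff.Dual{T,V,N}
    return map(D, x), map(D, w), map(D, wg)
end

end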

Second Update

Unfortunately, this confirms my concern about numeric integration killing performance. Integration is not slow in an absolute sense, but the relative speeds are vastly different:

julia> @btime logpdf($Normal(0, 1), $2)
  150.229 ns (1 allocation: 16 bytes)
-2.9189385332046727

julia> @btime logpdf($NormalMixture(0, 0, 2), $2)
  4.140 μs (30 allocations: 1.55 KiB)
-2.885455766369351

Or about 26 times slower. It seems like the performance hit is larger than expected, unless the gap for the gradients is even bigger. Maybe there is a way to improve performance.

@DominiqueMakowski
Contributor

(Just to say, though you probably agree with me, that at this stage, especially for complicated models, performance should probably not be a priority that stops us from implementing any models. Even if unusably slow, I think a working and correct implementation in Julia can then be used as a target to test against and set the stage for future improvements and interesting optimization challenges ☺️ )

@gdalle

gdalle commented Jul 13, 2024

Just popping in here to say that I'm happy to help with any AD-related bugs, as long as you can provide an MWE! I had a few chats with @kiante-fernandez at JuliaCon, but I can always clarify some stuff.

@itsdfish
Owner

itsdfish commented Jul 13, 2024

@gdalle, thank you for your willingness to help with AD-related issues. I really appreciate it. As some brief background, Kiante has encountered difficulty implementing the PDF of a model with several parameters that are integrated out. His first attempts were to use analytic solutions provided by others, but that has proved to be challenging. Now we are exploring the feasibility of brute-forcing it with HCubature, using a normal mixture as a simple example where the mu parameter is fixed, but the sigma parameter is uniformly distributed across observations between 0 and 2 and must be integrated out. Here is an MWE of the problem we encountered:

using Distributions
using ForwardDiff
using HCubature

f(Θ) = hcubature(σ -> pdf(Normal(Θ[1], σ[1]), 1), [Θ[2]], [Θ[3]])[1]
ForwardDiff.gradient(f, [0, 0, 2])

Our understanding is that either this method needs to relax its parametric restrictions, or we need a new method akin to Kiante's solution:

function QuadGK.kronrod(::Type{ForwardDiff.Dual{T,V,N}}, n::Integer) where {T,V,N}
    x, w, wg = QuadGK.kronrod(V, n)
    return map(y -> ForwardDiff.Dual{T,V,N}(y), x),
           map(y -> ForwardDiff.Dual{T,V,N}(y), w),
           map(y -> ForwardDiff.Dual{T,V,N}(y), wg)
end

One potential issue is that of type piracy. I don't know much about AD, but maybe a solution needs to go into ChainRules or QuadGK. Another issue is performance. Estimating mu for a simple normal model requires about 0.90 seconds, but the required time for the normal mixture is estimated to be about 20 minutes with Kiante's solution. I'm not sure whether it is inherently this slow, or whether the problem is that AD is propagating into the integration algorithm without needing to. Thanks again for your help.

@itsdfish
Owner

> performance should probably not be a priority that stops us from implementing any models

Given that we have limited developer hours, I recommend approaching this strategically. My rationale for developing the simple normal mixture model was to gauge the performance implications of using numeric integration so we can potentially avoid wasting time on a non-viable approach. At this point, adding numeric integration to a simple model increased the execution time by more than three orders of magnitude, which makes me think twice about proceeding. This will only be worse for a more complex model with two integrals. If there is not a way to significantly improve performance, I don't think it is worth pursuing.

On a related note, I also think it might be worth revisiting the status of the analytic approach Kiante was working on. If I remember correctly, I added tests for numerous sub-functions of the pdf which produced the same results as existing implementations. Typically, the bugs can be found by systematically testing the components and their integration, which suggests we might have been close. At this point, I'm not suggesting that we should go in one direction or the other, but I think we need more information from gdalle.

@gdalle

gdalle commented Jul 13, 2024

Your diagnosis is correct: something would probably need to be fixed in GaussKronrod to accommodate Dual numbers. If you do it yourself in a third-party package, it will amount to type piracy because you own neither the Dual type nor the kronrod function.
But that's just a symptom of a more generic problem, which I discussed with @kiante-fernandez at JuliaCon: sometimes, solvers are just not differentiable, or differentiating through them is slow. I indeed suggested that as a first approach, but the more efficient and principled answer is to use something like Integrals.jl, where custom rules are defined for various AD backends to cleverly differentiate through integration methods. I'll try to cook up an example

@itsdfish
Owner

> Just a stupid question but if your true problem actually involves Gaussian integrals, can you sidestep the numerical integration and just use the Gaussian cumulative distribution function, possibly after a change of variables to account for the different σ?

Thanks for your replies. I used a normal distribution as a toy example to gauge the feasibility of the approach. The decision-making model is a type of 1D stochastic differential equation which evolves until it hits a threshold for the first time, so the PDF is much more complex. In case it is useful, equation 5 shows the pdfs and the integrals. Maybe there is a clever mathematical trick with a change of variables, but I'm not knowledgeable enough to know.

@gdalle

gdalle commented Jul 13, 2024

Yeah my bad, I think even with the MWE it doesn't work because you're integrating by varying $\sigma$ itself.

I coded a small example with Integrals.jl. With Zygote.jl I obtain a gradient but the slowdown is still two orders of magnitude.

using BenchmarkTools
using Distributions
using Integrals
using Zygote

integrand(x::T1, p::T2) where {T1<:Real,T2<:Real} = pdf(Normal(p, x), one(promote_type(T1, T2)))

function f(Θ)
    domain = (Θ[2], Θ[3])
    p = Θ[1]
    prob = IntegralProblem(integrand, domain, p)
    sol = solve(prob, QuadGKJL(); reltol=1e-3, abstol=1e-3)
    return sol.u
end

Θ = [0.0, 0.0, 2.0]
f(Θ)
Zygote.gradient(f, Θ)[1]

@btime f($Θ);
@btime Zygote.gradient($f, $Θ);

With ForwardDiff.jl the exact same error pops up, which is weird because Integrals.jl should never try to differentiate through the solver. I suggest you open an issue in Integrals.jl with the same MWE and ask them. My best guess is that rules for differentiation with respect to integral bounds are still missing:

https://github.com/SciML/Integrals.jl/blob/26f0f739778b8a1256535acd14cc38c2a4e21e28/ext/IntegralsForwardDiffExt.jl#L45
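For context, the rules in question are just the Leibniz-rule terms for the bounds, ∂/∂a ∫_a^b f(x) dx = -f(a) and ∂/∂b ∫_a^b f(x) dx = f(b), which you can sanity-check against finite differences:

using Distributions, QuadGK

f(x) = pdf(Normal(0.0, 1.0), x)
I(a, b) = quadgk(f, a, b)[1]

a, b, h = 0.0, 2.0, 1e-6
(I(a + h, b) - I(a, b)) / h   # ≈ -0.3989 = -f(a)
(I(a, b + h) - I(a, b)) / h   # ≈  0.0540 =  f(b)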

@gdalle

gdalle commented Jul 13, 2024

Actually @kiante-fernandez that would be a very meaningful and useful contribution to the ecosystem as a whole. Wanna wrestle with autodiff for real? Try contributing the differentiation rules with respect to bounds!

@itsdfish
Owner

Thanks for the example above. I replaced Zygote with ForwardDiff, but it worked for me. Here is the version I am using.

(hcubature_turing) pkg> st Integrals ForwardDiff
Status `~/.julia/dev/sandbox/DDM/hcubature_turing/Project.toml`
  [f6369f11] ForwardDiff v0.10.36
  [de52edbc] Integrals v4.4.1

Were you using a different version by chance?

@gdalle

gdalle commented Jul 13, 2024

That's very weird, I have the same version. How did you do the replacement? What happens when you run the following code in a fresh REPL (regardless of the initial environment)?

using Pkg
Pkg.activate(temp=true)
Pkg.add(["BenchmarkTools", "Distributions", "ForwardDiff", "Integrals", "Zygote"])

using BenchmarkTools
using Distributions
using ForwardDiff
using Integrals
using Zygote

integrand(x::T1, p::T2) where {T1<:Real,T2<:Real} = pdf(Normal(p, x), one(promote_type(T1, T2)))

function f(Θ)
    domain = (Θ[2], Θ[3])
    p = Θ[1]
    prob = IntegralProblem(integrand, domain, p)
    sol = solve(prob, QuadGKJL(); reltol=1e-3, abstol=1e-3)
    return sol.u
end

Θ = [0.0, 0.0, 2.0]
f(Θ)
Zygote.gradient(f, Θ)[1]
ForwardDiff.gradient(f, Θ)

@btime f($Θ);
@btime Zygote.gradient($f, $Θ);
@btime ForwardDiff.gradient($f, $Θ);

@itsdfish
Owner

My apologies. I was just about to amend my comment. I accidentally ran Kiante's method for kronrod when running your example. False alarm. I will open an issue with Integrals.jl. Thanks again!

@gdalle

gdalle commented Jul 13, 2024

My pleasure! And feel free to tag the maintainers if your issue doesn't get any attention within a few days. It's completely okay in the open source community, as long as you're respectful and considerate.

@itsdfish
Owner

For our record keeping, the issue at Integrals.jl can be found here.

@itsdfish
Owner

I did a preliminary investigation this morning. One thing I learned was that the gradient works with the numeric pdf of the DDM. I was surprised by that. I'm not sure whether ForwardDiff actually went through hcubature, but it did not error. The preliminary benchmarks are 17.70ms (numeric) vs 1.33ms (analytic) for logpdf (based on 1000 data points) and 9.5mu (numeric) vs 5mu (analytic) for the gradient.

I also discovered that the pdf of the DDM without any across-trial variability (_pdf) is off by a large amount, as shown in the figure below. So if we end up going with that version, fixing _pdf will be one of the first steps. We'll need to be sure that we add unit tests for each of these component functions so that bugs are easy to isolate.

[figure: bad_density]

@itsdfish
Owner

Thanks for the updated code, Kiante. The pdf looks better, but something is still not right with either the pdf or the simulated data.

using Plots
using SequentialSamplingModels

fixed_parms = (
    ν = 1.00,
    α = 0.80,
    τ = 0.30,
    z = 0.25,
    η = 0,
    sz = 0,
    st = 0,
    σ = 1.0
)

dist = DDM(; fixed_parms...)

choice, rt = rand(dist, 100_000)
pr_1 = mean(choice .== 1) 
p1 = histogram(rt[choice .== 1], norm = true)
p1[1][1][:y] .*= pr_1
x = range(.25, 1.25, length = 200)
dens1 = pdf.(dist, 1, x)
plot!(x, dens1, leg=false)

[figure: bad_density]

@kiante-fernandez
Contributor Author

kiante-fernandez commented Jul 16, 2024

I think it might be the sampling method.

Also, something I wanted to note about one of our current tests in ddm_tests.jl, using the MATLAB code from Navarro, D. J., & Fuss, I. G. (2009):

function p=wfpt(t,v,a,z,err)
tt=t/(a^2); % use normalized time
w=z/a; % convert to relative start point
% calculate number of terms needed for large t
if pi*tt*err<1 % if error threshold is set low enough
kl=sqrt(-2*log(pi*tt*err)./(pi^2*tt)); % bound
kl=max(kl,1/(pi*sqrt(tt))); % ensure boundary conditions met
else % if error threshold set too high
kl=1/(pi*sqrt(tt)); % set to boundary condition
end
% calculate number of terms needed for small t
if 2*sqrt(2*pi*tt)*err<1 % if error threshold is set low enough
ks=2+sqrt(-2*tt.*log(2*sqrt(2*pi*tt)*err)); % bound
ks=max(ks,sqrt(tt)+1); % ensure boundary conditions are met
else % if error threshold was set too high
ks=2; % minimal kappa for that case
end
% compute f(tt|0,1,w)
p=0; %initialize density
if ks<kl % if small t is better...
K=ceil(ks); % round to smallest integer meeting error
for k=-floor((K-1)/2):ceil((K-1)/2) % loop over k
p=p+(w+2*k)*exp(-((w+2*k)^2)/2/tt); % increment sum
end
p=p/sqrt(2*pi*tt^3); % add constant term
else % if large t is better...
K=ceil(kl); % round to smallest integer meeting error
for k=1:K
p=p+k*exp(-(k^2)*(pi^2)*tt/2)*sin(k*pi*w); % increment sum
end
p=p*pi; % add constant term
end
% convert to f(t|v,a,w)
p=p*exp(-v*a*w -(v^2)*t/2)/(a^2);

Here is the output from the R package rtdists:

rtdists::ddiffusion(0.35, "upper", a = 0.50, v = 0.80, t0 =  0.2, z = 0.3 * 0.50, sz = 0, sv = 0, st0 = 0.00, s = 1)
#0.6635714
rtdists::ddiffusion(0.35, "lower", a = 0.50, v = 0.80, t0 =  0.2, z = 0.3 * 0.50, sz = 0, sv = 0, st0 = 0.00, s = 1)
#0.4450956

Running the code above (as well as our current implementation)

wfpt(.35 - 0.2, -0.8, 0.5,1 - 0.3, 1.0e-12)
%-1.0323
wfpt(.35 - 0.2, 0.8, 0.5, 0.3, 1.0e-12)
%0.4638

I wonder if I am missing something. Note that our current implementation would give something closer to the wfpt values.

@kiante-fernandez
Contributor Author

Another example to consider is the following, which tests against a popular Python implementation (HDDM).

It uses the following repo: https://github.com/brown-ccv/hddm-wfpt

import hddm_wfpt
import numpy as np


def ddm_loglik(data, v, a, z, t, err=1e-12):
    
    data = data[:, 0] * data[:, 1]
    # print(v.shape)
    # print(v)
    # print(a.shape)
    # print(a)
    # print(z)
    # print(t)
    # print(data)
    # print(data.shape[0])

    v_ = np.zeros(data.shape[0])
    v_[:] = np.float64(v)
    a_ = np.zeros(data.shape[0])
    a_[:] = 2 * np.float64(a)
    z_ = np.zeros(data.shape[0])
    z_[:] = np.float64(z)
    t_ = np.zeros(data.shape[0])
    t_[:] = np.float64(t)
 
    # Our function expects inputs as float64, but they are not guaranteed to
    # come in as such --> we type convert

    out_logp = hddm_wfpt.wfpt.wiener_logp_array(np.float64(data),
                                            v_, # v
                                            np.zeros(data.shape[0]), # sv
                                            a_, # a
                                            z_, # z
                                            np.zeros(data.shape[0]), # sz
                                            t_, # t
                                            np.zeros(data.shape[0]), # st
                                            err,
                                            )
    
    # print(out_logp)
    return out_logp

# Example usage
if __name__ == "__main__":
    # Create sample data
    # Assuming data is a 2D array where each row is [reaction_time, choice]
    # reaction_time is in seconds, choice is 1 or -1
    data = np.array([
        # [0.71, 1],
        # [0.71, -1],
        [0.35, 1],
        [0.35, -1]
    ])

    # Set parameters
    v = 0.80  # drift rate
    a = 0.50  # boundary separation
    z = 0.30  # starting point
    t = 0.20 # non-decision time

    # Call the function
    log_likelihoods = ddm_loglik(data, v, a, z, t)
    # Print results
    print("Log-likelihoods:")
    print((log_likelihoods))

#> Log-likelihoods:
#> [0.41412645 0.13426766]

@itsdfish
Owner

itsdfish commented Jul 16, 2024

Hmmm. It is really strange that the results differ so much. One possibility is that they have different parameterizations, but I'm guessing you have accounted for that already. The diffusion coefficient is another possible point of difference. Some use 0.1 and others use 1.

One way to adjudicate between the three versions is to figure out which overlays correctly with data simulated from the SDE. The implementation of the SDE is straightforward, which would reduce the chance of errors.
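For what it's worth, a minimal Euler-Maruyama sketch of that check might look like this (hypothetical code, not the package implementation; it assumes z is the starting point expressed as a proportion of α, and σ is the diffusion coefficient):

using Random

# Euler-Maruyama approximation of a single DDM trial (no across-trial variability).
# Returns (choice, rt): choice 1 = upper boundary α, choice 2 = lower boundary 0.
function simulate_ddm_sde(; ν, α, τ, z, σ = 1.0, Δt = 1e-5)
    x = z * α
    t = 0.0
    while 0 < x < α
        x += ν * Δt + σ * √Δt * randn()
        t += Δt
    end
    return (x ≥ α ? 1 : 2), t + τ
end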

Edit

One thing I realized is that the hddm_wfpt example returns log likelihoods, but the other two are densities. That alone doesn't make them agree, but it does mean they are not the same quantities.

@itsdfish
Owner

Also, I want to note that the SDE and rejection sampling methods produce similar results, as shown below. I used Δt = 0.00001 for the SDE approximation. This is some evidence that rand is correct, but the pdf is not. Nonetheless, it would be good to have some unit tests for rand which compare against one of the packages above, or the SDE, in terms of KDEs or quantiles.
[figure: dist]

Note that these are conditioned on the same response, but the response probabilities were nearly identical.

@DominiqueMakowski
Contributor

DominiqueMakowski commented Sep 13, 2024

As a test of the new ChatGPT version that's supposedly amazing with code and complex problems, I asked it for insights on this issue. It said:

ChatGPT's solution

[six screenshots of ChatGPT's suggested solution]

Unfortunately, I can't say if it's pure nonsense or not 😅

@itsdfish
Owner

itsdfish commented Sep 13, 2024

> Unfortunately, I can't say if it's pure nonsense or not 😅

Haha. That is the fundamental problem with the current AI approach. In the case of (ν, α, z, τ, η, σ) = params(d) in function _pdf_sv, that is indeed an error. One of the reasons I don't use params is because the output is order-dependent and has a fixed size. Instead, I recommend using Julia's built-in functionality for unpacking variables. Consider this as an example. By using ;, Julia will match the variables by name rather than by position. The solution GPT offered seems to be correct, but it is inelegant and more challenging from a human factors perspective. Here are the solutions side by side:

# in any order and any length
(;ν, α, z, τ, η, σ) = d 

vs.

# in a fixed order and size
(ν, α, z, τ, η, _, _, σ) = params(d) 

GPT and other LLMs "hallucinate" because each token is generated sequentially by sampling from a distribution. So you need to be careful when using code generated by LLMs. It could easily swap the order of parameters in (ν, α, z, τ, η, _, _, σ) = params(d) because the parameters are close in embedding space. An example can be found in the last output above:

 (v_mean, α, z, τ, η, _, _, σ) = params(d) 

This code is still correct, but it will introduce inconsistency into your code base. If you used the built-in Julia unpacking functionality, it would introduce an error because v_mean is not the name of a field in DDM. Unfortunately, this is an error that it could generate because it's sampling from a token distribution conditioned on context. The fact that it identified an error explains its insidious allure.

@DominiqueMakowski
Contributor

At least I learned something :)

@itsdfish
Owner

I think AI will be very helpful for code review once the field decides to incorporate more symbolic methods rather than continuing with the cash-grab, hype train of scaling data and hardware. LLMs are fundamentally flawed. In the meantime, it's just useful enough to wreak havoc. I'm glad I don't have to grade papers =)
