Skip to content

Commit

Permalink
make sd in LBA a vector
Browse files Browse the repository at this point in the history
  • Loading branch information
itsdfish committed Jul 30, 2023
1 parent 495a8f1 commit 3b3f3a5
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SequentialSamplingModels"
uuid = "0e71a2a6-2b30-4447-8742-d083a85e82d1"
authors = ["itsdfish"]
version = "0.5.5"
version = "0.6.0"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ Plots = "1.0.0"
StatsBase = "0.34.0"
StatsModels = "0.7.0"
StatsPlots = "0.15.0"
Turing = "0.26.0"
Turing = "0.26.0,0.27.0,0.28.0"
32 changes: 17 additions & 15 deletions src/LBA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ A model object for the linear ballistic accumulator.
- `ν`: a vector of drift rates
- `A`: max start point
- `k`: A + k = b, where b is the decision threshold
- `σ=1`: drift rate standard deviation
- `σ=fill(1.0, length(ν))`: drift rate standard deviation
- `τ`: a encoding-response offset
# Constructors
LBA(ν, A, k, τ, σ)
LBA(;τ=.3, A=.8, k=.5, ν=[2.0,1.75], σ=1.0)
LBA(;τ=.3, A=.8, k=.5, ν=[2.0,1.75], σ=[1.0,1.0])
# Example
Expand All @@ -36,20 +36,21 @@ mutable struct LBA{T<:Real} <: AbstractLBA
A::T
k::T
τ::T
σ::T
σ::Vector{T}
end

function LBA(ν, A, k, τ, σ)
_, A, k, τ, σ = promote(ν[1], A, k, τ, σ)
_, A, k, τ, _ = promote(ν[1], A, k, τ, σ[1])
ν = convert(Vector{typeof(k)}, ν)
σ = convert(Vector{typeof(k)}, σ)
return LBA(ν, A, k, τ, σ)
end

function params(d::LBA)
return (d.ν,d.A,d.k,d.τ,d.σ)
end

LBA(;τ=.3, A=.8, k=.5, ν=[2.0,1.75], σ=1.0) = LBA(ν, A, k, τ, σ)
LBA(;τ=.3, A=.8, k=.5, ν=[2.0,1.75], σ=fill(1.0, length(ν))) = LBA(ν, A, k, τ, σ)

function select_winner(dt)
if any(x -> x > 0, dt)
Expand All @@ -71,8 +72,9 @@ sample_drift_rates(ν, σ) = sample_drift_rates(Random.default_rng(), ν, σ)
function sample_drift_rates(rng::AbstractRNG, ν, σ)
negative = true
v = similar(ν)
n_options = length(ν)
while negative
v = [rand(rng, Normal(d, σ)) for d in ν]
v = [rand(rng, Normal(ν[i], σ[i])) for i 1:n_options]
negative = any(x -> x > 0, v) ? false : true
end
return v
Expand All @@ -96,11 +98,11 @@ function pdf(d::AbstractLBA, c, rt)
(;τ,A,k,ν,σ) = d
b = A + k; den = 1.0
rt < τ ? (return 1e-10) : nothing
for (i,v) in enumerate(ν)
for i 1:length(ν)
if c == i
den *= dens(d, v, rt)
den *= dens(d, ν[i], σ[i], rt)
else
den *= (1 - cummulative(d, v, rt))
den *= (1 - cummulative(d, ν[i], σ[i], rt))
end
end
pneg = pnegative(d)
Expand All @@ -109,8 +111,8 @@ function pdf(d::AbstractLBA, c, rt)
isnan(den) ? (return 0.0) : (return den)
end

function dens(d::AbstractLBA, v, rt)
(;τ,A,k,ν,σ) = d
function dens(d::AbstractLBA, v, σ, rt)
(;τ,A,k) = d
dt = rt - τ; b = A + k
n1 = (b - A - dt * v) / (dt * σ)
n2 = (b - dt * v) / (dt * σ)
Expand All @@ -119,8 +121,8 @@ function dens(d::AbstractLBA, v, rt)
return dens
end

function cummulative(d::AbstractLBA, v, rt)
(;τ,A,k,ν,σ) = d
function cummulative(d::AbstractLBA, v, σ, rt)
(;τ,A,k) = d
dt = rt - τ; b = A + k
n1 = (b - A - dt * v) / (dt * σ)
n2 = (b - dt * v) / (dt * σ)
Expand All @@ -133,8 +135,8 @@ end
function pnegative(d::AbstractLBA)
(;ν,σ) = d
p = 1.0
for v in ν
p *= cdf(Normal(0, 1), -v / σ)
for i 1:length(ν)
p *= cdf(Normal(0, 1), -ν[i] / σ[i])
end
return p
end
Expand Down
27 changes: 27 additions & 0 deletions test/lba_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,33 @@
@test y′ y rtol = .03
end

@safetestset "LBA Test3" begin
using SequentialSamplingModels, Test, Random
include("KDE.jl")
Random.seed!(851)

# note for some values, tests will fail
# this is because kde is sensitive to outliers
# density overlay on histograms are valid
dist = LBA=[2.0,2.7], A = .4, k = .20, τ = .4, σ=[1.0,0.5])
choice,rt = rand(dist, 10^5)
rt1 = rt[choice .== 1]
p1 = mean(x -> x == 1, choice)
p2 = 1 - p1
approx_pdf = kde(rt1)
x = .2:.01:1.5
y′ = pdf(approx_pdf, x) * p1
y = pdf.(dist, (1,), x)
@test y′ y rtol = .03

rt2 = rt[choice .== 2]
approx_pdf = kde(rt2)
x = .2:.01:1.5
y′ = pdf(approx_pdf, x) * p2
y = pdf.(dist, (2,), x)
@test y′ y rtol = .03
end

@safetestset "LBA loglikelihood" begin
using SequentialSamplingModels
using Test
Expand Down

0 comments on commit 3b3f3a5

Please sign in to comment.