How to predict from prior-sampled model (+ how to have choice and rt as a separate input) #19

Closed
DominiqueMakowski opened this issue Jun 17, 2023 · 6 comments


@DominiqueMakowski

I feel like this question lies somewhere between SequentialSamplingModels and Turing itself, but let's assume the following basic LBA model:

using Turing
using SequentialSamplingModels
using Random
using LinearAlgebra
using DataFrames
using StatsPlots

# Generate data with different drifts for two conditions A vs. B
Random.seed!(254)
data = DataFrame(rand(LBA(ν=[3.0, 2.0], A=0.8, k=0.2, τ=0.3), 500))



@model function model_lba(data)
    data = (choice=data.choice, rt=data.rt)
    min_rt = minimum(data[2])

    # Priors
    drift ~ filldist(MvNormal(zeros(2), I * 2), 2)
    A ~ truncated(Normal(0.8, 0.4), 0.0, Inf)
    k ~ truncated(Normal(0.2, 0.2), 0.0, Inf)
    tau ~ Uniform(0.0, min_rt)

    # Likelihood
    data ~ LBA(; τ=tau, A=A, k=k, ν=drift)
    return (; data, drift, A, k, tau)
end


chain = sample(model_lba(data), Prior(), 50000)

(Note that the model is written to accept a DataFrame as input, for convenience.)

I can easily sample from the priors. What I would like to do next is generate predictions from these priors (to visualize a prior predictive check).

My first instinct was to simply run:

predict(model_lba(data), chain)

But it returns nothing, which I thought was a bug. The Turing team clarified that the solution is to set the outcome variable to missing, which is straightforward in a linear model but not here.
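For reference, the `missing` pattern in a linear model can be sketched as follows. This is a minimal illustration, not code from this thread: `model_linear`, its priors, and the simulated `x` and `y` are hypothetical, and the outcome is observed element by element so that an all-missing vector can be passed in its place.

```julia
using Turing

# Hypothetical toy regression, used only to illustrate the `missing` pattern
@model function model_linear(x, y)
    intercept ~ Normal(0, 1)
    slope ~ Normal(0, 1)
    sigma ~ truncated(Normal(0, 1), 0.0, Inf)
    for i in eachindex(x)
        y[i] ~ Normal(intercept + slope * x[i], sigma)
    end
end

# Simulated data (assumed values, for illustration only)
x = collect(range(-1, 1; length=20))
y = 0.5 .+ 2.0 .* x .+ 0.3 .* randn(20)

# Sample from the priors only
chain = sample(model_linear(x, y), Prior(), 1000)

# Re-instantiate the model with an all-missing outcome, then predict
y_missing = Vector{Union{Missing,Float64}}(missing, length(x))
prior_pred = predict(model_linear(x, y_missing), chain)
```

The difficulty with the LBA model is that its outcome is a (choice, rt) pair rather than a plain vector of reals, so there is no equally obvious way to mark it as missing.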

Attempt 1: Re-instantiating the model on data with choice set to missing:

predict(model_lba(DataFrames.transform(data, choice -> (choice=missing,))), chain)
ERROR: TypeError: non-boolean (Missing) used in boolean context
Stacktrace:
  [1] pdf
    @ C:\Users\domma\.julia\packages\SequentialSamplingModels\hMsCP\src\LBA.jl:108 [inlined]
  ...

Attempt 2: Re-write the model to have choice and rt as separate inputs, which would also make it more flexible (for instance to add predictor variables in the future).

@model function model_lba(choice, rt)
    min_rt = minimum(rt)

    # Priors
    drift ~ filldist(MvNormal(zeros(2), I * 2), 2)
    A ~ truncated(Normal(0.8, 0.4), 0.0, Inf)
    k ~ truncated(Normal(0.2, 0.2), 0.0, Inf)
    tau ~ Uniform(0.0, min_rt)

    # Likelihood
    (choice, rt) ~ LBA(; τ=tau, A=A, k=k, ν=drift)
    return (; choice, rt, drift, A, k, tau)
end
ERROR: LoadError: Malformed variable name (choice, rt)!

Turing doesn't seem to like the tuple output. So maybe we can work around this by creating the data tuple inside the model?

@model function model_lba(choice, rt)
    data = (choice=choice, rt=rt)
    min_rt = minimum(rt)

    # Priors
    drift ~ filldist(MvNormal(zeros(2), I * 2), 2)
    A ~ truncated(Normal(0.8, 0.4), 0.0, Inf)
    k ~ truncated(Normal(0.2, 0.2), 0.0, Inf)
    tau ~ Uniform(0.0, min_rt)

    # Likelihood
    data ~ LBA(; τ=tau, A=A, k=k, ν=drift)
    return (; data.choice, data.rt, drift, A, k, tau)
end

chain = sample(model_lba(data.choice, data.rt), Prior(), 50000)
ERROR: MethodError: no method matching iterate(::LBA{Matrix{Float64}, Float64, Float64, Float64})
  ...

Unfortunately no. I am not sure what else to try, and any pointers and thoughts are more than welcome ☺️
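Independently of `predict`, a prior predictive check can also be simulated by hand. The sketch below uses only the Julia standard library and rests on several assumptions that are not from this thread: the helper names are hypothetical, `min_rt=0.35` is a placeholder for the observed minimum RT, each drift gets an independent Normal(0, √2) prior instead of the 2×2 matrix produced by `filldist`, and the trial-to-trial drift noise is fixed at a standard deviation of 1.

```julia
using Random

# Draw from Normal(mu, sd) truncated to [0, Inf) by rejection sampling
function rand_halfline_normal(rng, mu, sd)
    while true
        x = mu + sd * randn(rng)
        x >= 0 && return x
    end
end

# Simulate one LBA trial: each accumulator starts at Uniform(0, A), rises
# linearly at a rate drawn from Normal(ν[i], 1) (assumed noise sd of 1), and
# the first to reach the threshold A + k sets the choice and response time.
function simulate_lba_trial(rng, ν, A, k, τ)
    while true
        rates = ν .+ randn(rng, length(ν))
        any(>(0), rates) || continue   # resample if no accumulator can finish
        starts = A .* rand(rng, length(ν))
        times = [r > 0 ? (A + k - s) / r : Inf for (r, s) in zip(rates, starts)]
        t, choice = findmin(times)
        return (choice=choice, rt=τ + t)
    end
end

# Prior predictive draws: sample parameters from the priors, then one trial each
function prior_predictive(rng, n_draws; min_rt=0.35)  # min_rt is a placeholder
    map(1:n_draws) do _
        ν = sqrt(2) .* randn(rng, 2)              # Normal(0, √2) per drift
        A = rand_halfline_normal(rng, 0.8, 0.4)   # truncated Normal(0.8, 0.4)
        k = rand_halfline_normal(rng, 0.2, 0.2)   # truncated Normal(0.2, 0.2)
        τ = min_rt * rand(rng)                    # Uniform(0, min_rt)
        simulate_lba_trial(rng, ν, A, k, τ)
    end
end

rng = MersenneTwister(254)
draws = prior_predictive(rng, 1000)
```

Each element of `draws` is a named tuple with a `choice` and an `rt` field, so the response times can be collected with `[d.rt for d in draws]` and plotted per choice.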

@itsdfish

I'm not quite sure how to fix the issue. It might be good to ask the Turing team for advice for this specific model. Yesterday, Kiante and I made changes that I believe were supposed to prevent the no matching method error. Have you upgraded to the most recent version?

@itsdfish

Also, if you don't mind, I will add your solution to the documentation.

@DominiqueMakowski

Have you upgraded to the most recent version?

I did, but the iterate() error still exists. I'm not sure why iterate() is called here (and causes a problem) when it is not in the working version...

It might be good to ask the Turing team for advice

posted on discourse

if you don't mind, I will add your solution

Of course :)

Thanks!

@itsdfish

itsdfish commented Jun 17, 2023

@DominiqueMakowski, the following code runs without a method error. However, although predict runs, it does not work correctly. If you run predict a second time, you can see that the data are generated in model_lba. I tried using Tuple and Array as the data container, but both produced the same behavior, so I am not sure what is happening. Do you have any insights?

using Turing
using SequentialSamplingModels
using Random
using LinearAlgebra


# Generate data with different drifts for two conditions A vs. B
Random.seed!(254)
data = rand(LBA(ν=[3.0, 2.0], A=0.8, k=0.2, τ=0.3), 500)

@model function model_lba(data)
    min_rt = minimum(data[2])

    # Priors
    drift ~ filldist(MvNormal(zeros(2), I * 2), 2)
    A ~ truncated(Normal(0.8, 0.4), 0.0, Inf)
    k ~ truncated(Normal(0.2, 0.2), 0.0, Inf)
    tau ~ Uniform(0.0, min_rt)

    # Likelihood
    data ~ LBA(; τ=tau, A=A, k=k, ν=drift)
    println("data $data")
    return (; data..., drift, A, k, tau)
end

chain = sample(model_lba(data), Prior(), 100)
# this does not work, but if you rerun predict, it does generate data correctly
predictions = predict(model_lba(data), chain)

@itsdfish

itsdfish commented Jun 18, 2023

I have been wondering whether the solution to this problem is to define a type called MixedMultivariateDistribution which works for these types of models. We would implement methods for MixedMultivariateDistribution as needed. I'm not sure exactly what this would entail, but it seems like the proper way to make things work. I think there is new functionality in Julia 1.9 which would allow these methods to be exported conditionally when Turing is being used.
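A rough sketch of what such a type could look like is given below. It assumes Distributions.jl's `Distribution{F,S}` type parameters; the names `Mixed`, `MixedMultivariateDistribution`, and `ToyChoiceRT` are all hypothetical, and the toy model (a Bernoulli choice with an exponential response time) merely stands in for a real sequential sampling model.

```julia
using Distributions
using Random

# Hypothetical value-support tag for outcomes that mix a discrete choice
# with a continuous response time
struct Mixed <: ValueSupport end

# Hypothetical abstract supertype for such models
const MixedMultivariateDistribution = Distribution{Multivariate,Mixed}

# Toy stand-in model: choice 1 with probability p, exponential rt
struct ToyChoiceRT <: MixedMultivariateDistribution
    p::Float64      # probability of choice 1
    rate::Float64   # rate of the exponential response time
end

Base.length(::ToyChoiceRT) = 2

# A single draw is a (choice, rt) named tuple
function Base.rand(rng::AbstractRNG, d::ToyChoiceRT)
    choice = rand(rng) < d.p ? 1 : 2
    rt = rand(rng, Exponential(1 / d.rate))
    return (; choice, rt)
end

# Joint log density of one (choice, rt) observation
function Distributions.logpdf(d::ToyChoiceRT, x::NamedTuple)
    log_choice = x.choice == 1 ? log(d.p) : log1p(-d.p)
    return log_choice + logpdf(Exponential(1 / d.rate), x.rt)
end

d = ToyChoiceRT(0.7, 2.0)
x = rand(Random.default_rng(), d)
ll = logpdf(d, x)
```

The Julia 1.9 functionality mentioned above is presumably package extensions, which would let Turing-specific methods for such a type be loaded only when Turing is present. Whether Turing's `~` and `predict` would accept this kind of type without further methods is not established here.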

@DominiqueMakowski

I agree with your vision of striving to implement things "the proper way", without hacks and workarounds, when possible. Unfortunately I cannot help you much directly with the code (still a Julia newbie), but I can assist with testing things, trying to get external help and all that ☺️
