Vector of inputs in streaming_inference #318

Open
wouterwln opened this issue Jun 18, 2024 · 0 comments
Labels: bug (Something isn't working)

If we pass a vector (or tensor) of inputs to streaming inference and create a data variable for every entry, inference throws the following error:

MethodError: no method matching is_data(::Vector{RxInfer.GraphVariableRef})

Closest candidates are:
  is_data(!Matched::RxInfer.GraphVariableRef)
   @ RxInfer ~/.julia/packages/RxInfer/SROpQ/src/model/plugins/reactivemp_inference.jl:229
  is_data(!Matched::GraphPPL.VariableNodeProperties)
   @ GraphPPL ~/.julia/packages/GraphPPL/ke7hR/src/graph_engine.jl:696

MWE:

using RxInfer

@model function test_model(x, y, mx, vx)
    for i in 1:3
        x[i] ~ NormalMeanVariance(mx, vx)
    end
    my ~ NormalMeanVariance(0, 1)
    y ~ NormalMeanVariance(my, 1.0)
end

# Ten observations, each with a 3-element vector x and a scalar y
d = [(x = rand(3), y = rand()) for i in 1:10]
datastream = from(d) |> map(NamedTuple{(:x, :y), Tuple{Vector{Float64}, Float64}}, (d) -> d)

foo(x) = 1.0

autoupdates = @autoupdates begin
    mx = foo(q(my))
    vx = foo(q(my))
end

Static (batch) inference on the same model runs and returns a result:

infer(model = test_model(mx = 1.0, vx = 1.0), data=(x = rand(3), y = 0.0), iterations=10, showprogress=true)

Running streaming inference on the same model throws the error above:

infer(model = test_model(), datastream=datastream, autoupdates = autoupdates, initialization = @initialization begin q(my) = NormalMeanVariance(1.0, 1.0) end)

The following fixes this, but might not be the most rigorous fix:

RxInfer.is_data(vector::Vector{RxInfer.GraphVariableRef}) = all(RxInfer.is_data.(vector))
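
The dispatch pattern behind this workaround can be sketched outside RxInfer with a stand-in type. Everything below is illustrative: `FakeVariableRef` and this local `is_data` are hypothetical and are not part of RxInfer, and widening the method to `AbstractArray` (to also cover the tensor case) is an assumption, not necessarily what a proper fix would do.

```julia
# Hypothetical stand-in for RxInfer.GraphVariableRef (illustration only).
struct FakeVariableRef
    isdata::Bool
end

# Scalar method, analogous to the existing is_data(::GraphVariableRef).
is_data(ref::FakeVariableRef) = ref.isdata

# Collection method: a container counts as data only if every entry does.
# If this method were missing, calling is_data on a Vector would raise
# a MethodError, as in the report above.
is_data(refs::AbstractArray{<:FakeVariableRef}) = all(is_data, refs)

all_data = [FakeVariableRef(true), FakeVariableRef(true), FakeVariableRef(true)]
mixed    = [FakeVariableRef(true), FakeVariableRef(false)]

println(is_data(all_data))
println(is_data(mixed))
```

Note that `all` returns `true` for an empty collection, so an empty vector of references would be treated as data under this definition; whether that is the desired behavior is one reason the workaround may not be the most rigorous fix.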

bvdmitri self-assigned this on Jun 18, 2024
bvdmitri added the bug label on Jun 18, 2024