From 047e898a0b9dcce97e4090dfc3fd08b459a3b524 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <4462211+jeremiahpslewis@users.noreply.github.com> Date: Sat, 29 Jul 2023 19:13:35 +0200 Subject: [PATCH 1/6] Fix type signature --- src/episodes.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/episodes.jl b/src/episodes.jl index 0abe648..0113a1c 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -90,7 +90,7 @@ function pad!(trace::Trace) return nothing end -pad!(buf::CircularArrayBuffer{T}) where {T,N,A} = push!(buf, zero(T)) +pad!(buf::CircularArrayBuffer{T,N,A}) where {T,N,A} = push!(buf, zero(T)) pad!(vect::Vector{T}) where {T} = push!(vect, zero(T)) #push a duplicate of last element as a dummy element for all 'trace' objects, ignores multiplex traces, should never be sampled. From 31ce64877ddff82d07fe2174dc252da66ed53beb Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <4462211+jeremiahpslewis@users.noreply.github.com> Date: Sat, 29 Jul 2023 19:39:13 +0200 Subject: [PATCH 2/6] Attempt to fix NStepBatchSampler... --- src/samplers.jl | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/samplers.jl b/src/samplers.jl index 98847ab..b0e0a71 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -173,12 +173,25 @@ end NStepBatchSampler(; kw...) = NStepBatchSampler{SS′ART}(; kw...) NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names} = NStepBatchSampler{names}(n, γ, batch_size, stack_size, rng) + +function valid_range_nbatchsampler(s::NStepBatchSampler, ts) + # think about the extreme case where s.stack_size == 1 and s.n == 1 + isnothing(s.stack_size) ? (1:(length(ts)-s.n+1)) : (s.stack_size:(length(ts)-s.n+1)) +end function StatsBase.sample(s::NStepBatchSampler{names}, ts) where {names} - valid_range = isnothing(s.stack_size) ? (1:(length(ts)-s.n+1)) : (s.stack_size:(length(ts)-s.n+1))# think about the exteme case where s.stack_size == 1 and s.n == 1 + valid_range = valid_range_nbatchsampler(s, ts) inds = rand(s.rng, valid_range, s.batch_size) StatsBase.sample(s, ts, Val(names), inds) end +function StatsBase.sample(s::NStepBatchSampler{names}, ts::EpisodesBuffer) where {names} + valid_range = valid_range_nbatchsampler(s, ts) + valid_range = valid_range[valid_range ∈ findall(ts.sampleable_inds)] # Ensure that the valid range is within the sampleable indices + inds = rand(s.rng, valid_range, s.batch_size) + StatsBase.sample(s, ts, Val(names), inds) +end + + function StatsBase.sample(nbs::NStepBatchSampler, ts, ::Val{SS′ART}, inds) if isnothing(nbs.stack_size) s = ts[:state][inds] From a36ddc3335a22e4ab9129c04599917a0b5ff1af8 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <4462211+jeremiahpslewis@users.noreply.github.com> Date: Sat, 29 Jul 2023 19:46:35 +0200 Subject: [PATCH 3/6] Tweak code --- src/samplers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/samplers.jl b/src/samplers.jl index b0e0a71..d3d37a5 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -186,7 +186,7 @@ end function StatsBase.sample(s::NStepBatchSampler{names}, ts::EpisodesBuffer) where {names} valid_range = valid_range_nbatchsampler(s, ts) - valid_range = valid_range[valid_range ∈ findall(ts.sampleable_inds)] # Ensure that the valid range is within the sampleable indices + valid_range = valid_range[valid_range .∈ (findall(ts.sampleable_inds),)] # Ensure that the valid range is within the sampleable indices, probably could be done more efficiently by refactoring `valid_range_nbatchsampler` inds = rand(s.rng, valid_range, s.batch_size) StatsBase.sample(s, ts, Val(names), inds) end From ca99dd1634d6a1ccecbcb124aca5293b532f7858 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <4462211+jeremiahpslewis@users.noreply.github.com> Date: Sat, 29 Jul 2023 20:10:59 +0200 Subject: [PATCH 4/6] Add test for bug --- test/samplers.jl | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/test/samplers.jl b/test/samplers.jl index fb4c903..e562b41 100644 --- a/test/samplers.jl +++ b/test/samplers.jl @@ -138,7 +138,7 @@ end end #! format: on -@testset "Trajectory with CircularArraySARTSTraces and NStepBatchSampler" begin +@testset "Trajectory with CircularPrioritizedTraces and NStepBatchSampler" begin n=1 γ=0.99f0 @@ -169,3 +169,33 @@ end b = RLTrajectories.StatsBase.sample(t) @test haskey(b, :priority) end + + +@testset "Trajectory with CircularArraySARTSTraces and NStepBatchSampler" begin + n=1 + γ=0.99f0 + + t = Trajectory( + container=CircularArraySARTSTraces( + capacity=5, + state=Float32 => (4,), + ), + sampler=NStepBatchSampler{SS′ART}( + n=n, + γ=γ, + batch_size=32, + ), + controller=InsertSampleRatioController( + threshold=100, + n_inserted=-1 + ) + ) + + push!(t, (state = 1, action = true)) + for i = 1:9 + push!(t, (state = i+1, action = true, reward = i, terminal = false)) + end + + b = RLTrajectories.StatsBase.sample(t) + @test haskey(b, :priority) +end From 824fd8eda244a0721d4f2798cf93a67a4846d83e Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <4462211+jeremiahpslewis@users.noreply.github.com> Date: Sat, 29 Jul 2023 20:11:13 +0200 Subject: [PATCH 5/6] bug fix --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 218f701..ddd850d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ReinforcementLearningTrajectories" uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c" -version = "0.3.1" +version = "0.3.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From b06c4a3c59ef0f5cd746adfb0003ad6552c3fbd4 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <4462211+jeremiahpslewis@users.noreply.github.com> Date: Sat, 29 Jul 2023 21:14:10 +0200 Subject: [PATCH 6/6] Fix sampler test --- test/samplers.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/samplers.jl b/test/samplers.jl index e562b41..3a6fb20 100644 --- a/test/samplers.jl +++ b/test/samplers.jl @@ -168,6 +168,7 @@ end b = RLTrajectories.StatsBase.sample(t) @test haskey(b, :priority) + @test sum(b.action .== 0) == 0 end @@ -197,5 +198,5 @@ end end b = RLTrajectories.StatsBase.sample(t) - @test haskey(b, :priority) + @test sum(b.action .== 0) == 0 end