Skip to content

Commit

Permalink
Merge pull request #49 from JuliaReinforcementLearning/jpsl/fix
Browse files Browse the repository at this point in the history
Bug fix for EpisodesBuffer sampling
  • Loading branch information
jeremiahpslewis committed Jul 29, 2023
2 parents ccb43cd + b06c4a3 commit 40fefb2
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ReinforcementLearningTrajectories"
uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
version = "0.3.1"
version = "0.3.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
2 changes: 1 addition & 1 deletion src/episodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ function pad!(trace::Trace)
return nothing
end

pad!(buf::CircularArrayBuffer{T}) where {T,N,A} = push!(buf, zero(T))
pad!(buf::CircularArrayBuffer{T,N,A}) where {T,N,A} = push!(buf, zero(T))
pad!(vect::Vector{T}) where {T} = push!(vect, zero(T))

#push a duplicate of last element as a dummy element for all 'trace' objects, ignores multiplex traces, should never be sampled.
Expand Down
15 changes: 14 additions & 1 deletion src/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,25 @@ end
NStepBatchSampler(; kw...) = NStepBatchSampler{SS′ART}(; kw...)
NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names} = NStepBatchSampler{names}(n, γ, batch_size, stack_size, rng)


function valid_range_nbatchsampler(s::NStepBatchSampler, ts)
# think about the extreme case where s.stack_size == 1 and s.n == 1
isnothing(s.stack_size) ? (1:(length(ts)-s.n+1)) : (s.stack_size:(length(ts)-s.n+1))
end
function StatsBase.sample(s::NStepBatchSampler{names}, ts) where {names}
valid_range = isnothing(s.stack_size) ? (1:(length(ts)-s.n+1)) : (s.stack_size:(length(ts)-s.n+1))# think about the exteme case where s.stack_size == 1 and s.n == 1
valid_range = valid_range_nbatchsampler(s, ts)
inds = rand(s.rng, valid_range, s.batch_size)
StatsBase.sample(s, ts, Val(names), inds)
end

function StatsBase.sample(s::NStepBatchSampler{names}, ts::EpisodesBuffer) where {names}
valid_range = valid_range_nbatchsampler(s, ts)
valid_range = valid_range[valid_range .∈ (findall(ts.sampleable_inds),)] # Ensure that the valid range is within the sampleable indices, probably could be done more efficiently by refactoring `valid_range_nbatchsampler`
inds = rand(s.rng, valid_range, s.batch_size)
StatsBase.sample(s, ts, Val(names), inds)
end


function StatsBase.sample(nbs::NStepBatchSampler, ts, ::Val{SS′ART}, inds)
if isnothing(nbs.stack_size)
s = ts[:state][inds]
Expand Down
33 changes: 32 additions & 1 deletion test/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ end
end
#! format: on

@testset "Trajectory with CircularArraySARTSTraces and NStepBatchSampler" begin
@testset "Trajectory with CircularPrioritizedTraces and NStepBatchSampler" begin
n=1
γ=0.99f0

Expand Down Expand Up @@ -168,4 +168,35 @@ end

b = RLTrajectories.StatsBase.sample(t)
@test haskey(b, :priority)
@test sum(b.action .== 0) == 0
end


@testset "Trajectory with CircularArraySARTSTraces and NStepBatchSampler" begin
n=1
γ=0.99f0

t = Trajectory(
container=CircularArraySARTSTraces(
capacity=5,
state=Float32 => (4,),
),
sampler=NStepBatchSampler{SS′ART}(
n=n,
γ=γ,
batch_size=32,
),
controller=InsertSampleRatioController(
threshold=100,
n_inserted=-1
)
)

push!(t, (state = 1, action = true))
for i = 1:9
push!(t, (state = i+1, action = true, reward = i, terminal = false))
end

b = RLTrajectories.StatsBase.sample(t)
@test sum(b.action .== 0) == 0
end

0 comments on commit 40fefb2

Please sign in to comment.