From 6ce28bdda70daa43f99c7eda15362f18b6c8c76f Mon Sep 17 00:00:00 2001 From: chriselrod Date: Thu, 31 Aug 2023 12:08:43 -0400 Subject: [PATCH 1/2] test with 1 thread, fix 1 threaded segfault --- .github/workflows/CI.yml | 1 + src/optimize.jl | 64 +++++++++++++++++++++++++--------------- test/mnist.jl | 30 ++++++++++++++++++- 3 files changed, 71 insertions(+), 24 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index c533927..6a1a9d7 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -41,6 +41,7 @@ jobs: version: - 'nightly' # coverage fast on nightly threads: + - '1' - '3' - '4' steps: diff --git a/src/optimize.jl b/src/optimize.jl index 82d5592..5beaac3 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -379,17 +379,21 @@ function train_unbatched_core!( c::Chain, pu::Ptr{UInt8}, pX, - it, p::AbstractVector{T}, opt, + it, mpt ) where {T} numthreads = _numthreads() - glen = _try_static(numparam(getchain(c)), static_length(params)) + glen = _try_static(numparam(getchain(c)), static_length(p)) aligned_glen = align(glen) g = _alloc_grad(Ptr{T}(pu), glen, numthreads, aligned_glen) offset = static_sizeof(T) * aligned_glen * numthreads - train_unbatched_core!(c, pu + offset, g, pX, it, p, opt, mpt) + if numthreads == 1 + train_unbatched_core!(c, pu + offset, g[:, begin], pX, p, opt, it, mpt) + else + train_unbatched_core!(c, pu + offset, g, pX, p, opt, it, mpt) + end end """ @@ -461,7 +465,7 @@ function train_unbatched!( pX = maybe_static_size_arg(chn.inputdim, X) optoff = optmemsize(opt, p) @unpack layers = chn - glen = _try_static(numparam(chn), static_length(params)) + glen = _try_static(numparam(chn), static_length(p)) numthreads = _numthreads() T = Base.promote_eltype(p, X) @@ -634,18 +638,33 @@ function train_batched_core!( T = Base.promote_eltype(p, pX) g = _alloc_grad(Ptr{T}(pu), glen, numthreads, aligned_glen) offset = static_sizeof(T) * aligned_glen * numthreads - train_batched_core!( - c, - pu + offset, - g, - p, - pX, - opt, - iters, - leaveofflast, - mpt, - N_bs - ) + if numthreads == 1 + train_batched_core!( + c, + pu + offset, + g[:, begin], + p, + pX, + opt, + iters, + leaveofflast, + mpt, + N_bs + ) + else + train_batched_core!( + c, + pu + offset, + g, + p, + pX, + opt, + iters, + leaveofflast, + mpt, + N_bs + ) + end end """ train_batched!(g::AbstractVecOrMat, p, chn, X, opt, iters; batchsize = nothing) @@ -705,15 +724,14 @@ function train_batched!( align(sizeof(eltype(tgt)) * tgt_batch_len) + align(sizeof(eltype(X)) * X_batch_len) perm_mem = align(sizeof(Int) * N) + base_mem = optoff + perm_mem + T = Base.promote_eltype(p, X) if g === nothing - base_mem = - optoff + - perm_mem + - align(_try_static(numparam(chn), static_length(p))) * nthread - else - base_mem = optoff + perm_mem + base_mem += + align(_try_static(numparam(chn), static_length(p))) * + nthread * + static_sizeof(T) end - T = Base.promote_eltype(p, X) mpt, total_bytes = required_bytes(Val{T}(), layers, sxb, base_mem, shuffle_per_thread, nthread) GC.@preserve X begin diff --git a/test/mnist.jl b/test/mnist.jl index db2e4bc..50cd9bd 100644 --- a/test/mnist.jl +++ b/test/mnist.jl @@ -1,7 +1,7 @@ using Test ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" @testset "LeNet" begin - using SimpleChains, MLDatasets + using SimpleChains, MLDatasets, JET lenet = SimpleChain( (static(28), static(28), static(1)), @@ -128,4 +128,32 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" @test a1 > 0.93 @test a3 > 0.95 end + + @time SimpleChains.init_params!(p, lenet; rng = SimpleChains.local_rng()) + @time SimpleChains.train_unbatched!( + p, + lenetloss, + xtrain4, + SimpleChains.ADAM(3e-4), + 10 + ) + if VERSION >= v"1.10" + @test_opt SimpleChains.train_unbatched!( + p, + lenetloss, + xtrain4, + SimpleChains.ADAM(3e-4), + 10 + ) + end + a4, l4 = SimpleChains.accuracy_and_loss(lenetloss, xtrain4, p) + @test l4 ≈ lenetloss(xtrain4, p) + a5, l5 = SimpleChains.accuracy_and_loss(lenetloss, xtest4, ytest1, p) + @test l5 ≈ SimpleChains.add_loss(lenetloss, LogitCrossEntropyLoss(ytest1))( + xtest4, + p + ) + # TODO: unbatched training is currently much less effective... + @test a4 > 0.3 + @test a5 > 0.3 end From a8d8edad305a356f91928d4d2b93a6b2a36bbfef Mon Sep 17 00:00:00 2001 From: chriselrod Date: Thu, 31 Aug 2023 17:24:55 -0400 Subject: [PATCH 2/2] another missing sizeof --- src/optimize.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimize.jl b/src/optimize.jl index 5beaac3..e8a2ad0 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -473,7 +473,7 @@ function train_unbatched!( Val{T}(), layers, static_size(pX), - optoff + align(glen) * numthreads, + optoff + align(glen) * numthreads * static_sizeof(T), static(0), numthreads )