Scalar indexing error using prod(...) #747

Open
mentics opened this issue Oct 24, 2023 · 1 comment
mentics commented Oct 24, 2023

This might not be a bug, but at the very least the error it produces is mysterious.

I'm using Flux, and in my loss function (which is differentiated by Zygote and therefore goes through ChainRules.jl) I use prod. The surprising thing is that training ran for several iterations before failing with a scalar indexing error.

I'm guessing the error happens when the rule takes this branch because it finds a zero somewhere in the input. That would explain why several iterations succeed before hitting the error. I don't know whether the code can be changed to avoid scalar indexing, whether a more informative error would help, or whether this is just something I need to understand better.
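
For illustration, here is a minimal sketch of the kind of call I would expect to hit the same branch (this is not my actual loss; the shapes and values are made up):

    using CUDA, Zygote
    CUDA.allowscalar(false)                  # disallow scalar indexing, as in non-interactive code

    f(x) = sum(prod(x; dims=1))

    x = CUDA.rand(Float32, 4, 8) .+ 1f0      # strictly positive, no zeros
    Zygote.gradient(f, x)                    # works

    xz = cu([1f0 0f0; 2f0 3f0])              # one column contains a zero
    Zygote.gradient(f, xz)                   # presumably hits the scalar indexing error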

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore C:\Users\joel\.julia\packages\GPUArraysCore\uOYfN\src\GPUArraysCore.jl:103
  [3] getindex(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Int64, ::Int64)
    @ GPUArrays C:\Users\joel\.julia\packages\GPUArrays\5XhED\src\host\indexing.jl:9
  [4] maybeview
    @ .\views.jl:149 [inlined]
  [5] ∇prod_dims!(dx::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, #unused#::Val{1}, x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, dy::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, y::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ ChainRules C:\Users\joel\.julia\packages\ChainRules\9sNmB\src\rulesets\Base\mapreduce.jl:287
  [6] ∇prod_dims(vald::Val{1}, x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, dy::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, y::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ ChainRules C:\Users\joel\.julia\packages\ChainRules\9sNmB\src\rulesets\Base\mapreduce.jl:278
  [7] (::ChainRules.var"#1683#1686"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Int64, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})()
    @ ChainRules C:\Users\joel\.julia\packages\ChainRules\9sNmB\src\rulesets\Base\mapreduce.jl:265
  [8] unthunk
    @ C:\Users\joel\.julia\packages\ChainRulesCore\0t04l\src\tangent_types\thunks.jl:204 [inlined]
  [9] unthunk
    @ C:\Users\joel\.julia\packages\ChainRulesCore\0t04l\src\tangent_types\thunks.jl:237 [inlined]
 [10] wrap_chainrules_output
    @ C:\Users\joel\.julia\packages\Zygote\4rucm\src\compiler\chainrules.jl:110 [inlined]
 [11] map
    @ .\tuple.jl:274 [inlined]
 [12] wrap_chainrules_output
    @ C:\Users\joel\.julia\packages\Zygote\4rucm\src\compiler\chainrules.jl:111 [inlined]
 [13] ZBack
    @ C:\Users\joel\.julia\packages\Zygote\4rucm\src\compiler\chainrules.jl:211 [inlined]
 [14] (::Zygote.var"#kw_zpullback#53"{ChainRules.var"#prod_pullback#1684"{Int64, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})(dy::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\4rucm\src\compiler\chainrules.jl:237
...
@mcabbott mcabbott added the GPU label Oct 24, 2023
mcabbott (Member) commented:

I forget exactly how this works, but as you observe, the code that is careful about zeros sits behind a test any(iszero, x), precisely so that all-nonzero x works without error on a GPU.

It wouldn't be crazy to add an explicit message there, like x isa AbstractGPUArray && @error "rule for prod found zeros..." maxlog=1. Otherwise the reason for the error is a bit mysterious, since e.g. it may work for the first iteration and only fail later; scalar indexing errors usually depend only on the types, not the values.
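
Very roughly, and only as an illustrative sketch (this is not the actual mapreduce.jl code, and grad_prod_sketch is a made-up name), the shape would be:

    using GPUArraysCore: AbstractGPUArray

    function grad_prod_sketch(x::AbstractVector, dy::Number, y::Number = prod(x))
        if !any(iszero, x)
            return conj.(y ./ x) .* dy       # broadcast-only fast path, GPU-safe
        end
        x isa AbstractGPUArray &&
            @error "rule for prod found zeros; the exact gradient needs scalar indexing" maxlog=1
        dx = similar(x)
        for i in eachindex(x)                # scalar indexing: this is what CuArray refuses
            dx[i] = conj(prod(x[j] for j in eachindex(x) if j != i)) * dy
        end
        return dx
    end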

The alternatives are just to use the broadcasting branch on GPU arrays (and accept that getting NaN is your signal that something is wrong), or to write a GPU kernel to do this correctly (perhaps using KernelAbstractions to be device-agnostic).
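
For illustration, the broadcasting branch on its own would look like the following, and silently gives NaN at the zero rather than the true gradient:

    using CUDA

    x  = cu([1f0 0f0; 2f0 3f0])      # second column contains a zero
    y  = prod(x; dims=1)             # [2.0 0.0]
    dy = CUDA.ones(Float32, 1, 2)

    dx = y ./ x .* dy                # pure broadcasting, no scalar indexing, but
                                     # dx[1,2] is 0/0 = NaN while the true gradient there is 3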

That's for ChainRules. Can you share more about what your actual use case is? Inserting something like clamp.(x, 0.001, 0.99) is one way you might avoid the problem.
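
As a usage sketch of that last suggestion (the model, data, and loss here are all made up; the 0.001/0.99 bounds are just the ones above):

    using Flux, CUDA

    model = Dense(16 => 8, sigmoid) |> gpu
    xdata = CUDA.rand(Float32, 16, 32)

    function loss(m, x)
        p = clamp.(m(x), 0.001f0, 0.99f0)    # keep activations away from exact zero
        return -sum(log.(prod(p; dims=1)))   # prod never sees a zero, so its pullback
    end                                      # stays on the broadcast-only path

    grads = Flux.gradient(m -> loss(m, xdata), model)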
