
ProjectTo causes scalar indexing when taking adjoints of complex CuArray #624

Open
DomCRose opened this issue Jul 21, 2023 · 1 comment · May be fixed by #630
Labels: bug (Something isn't working), ProjectTo (related to the projection functionality)

Comments


DomCRose commented Jul 21, 2023

As the title says. Code to reproduce:

using CUDA, Zygote
function test_func(a, b)
    return sum(abs2, a .+ b')
end
a = CUDA.rand(ComplexF64, 3)
b = CUDA.rand(3)
gradient(test_func, a, b)

Produces:

ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore C:\Users\domin\.julia\packages\GPUArraysCore\uOYfN\src\GPUArraysCore.jl:103
  [3] getindex(::CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}, ::Int64, ::Int64)
    @ GPUArrays C:\Users\domin\.julia\packages\GPUArrays\5XhED\src\host\indexing.jl:9
  [4] getindex
    @ C:\Users\domin\.julia\juliaup\julia-1.9.0+0.x64.w64.mingw32\share\julia\stdlib\v1.9\LinearAlgebra\src\adjtrans.jl:303 [inlined]
  [5] _unsafe_getindex_rs
    @ .\reshapedarray.jl:251 [inlined]
  [6] _unsafe_getindex
    @ .\reshapedarray.jl:248 [inlined]
  [7] getindex
    @ .\reshapedarray.jl:236 [inlined]
  [8] iterate
    @ .\abstractarray.jl:1220 [inlined]
  [9] iterate
    @ .\abstractarray.jl:1218 [inlined]
 [10] iterate
    @ .\generator.jl:44 [inlined]
 [11] _collect(c::Base.ReshapedArray{ComplexF64, 1, LinearAlgebra.Adjoint{ComplexF64, CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, itr::Base.Generator{Base.ReshapedArray{ComplexF64, 1, LinearAlgebra.Adjoint{ComplexF64, CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
    @ Base .\array.jl:802
 [12] collect_similar
    @ .\array.jl:711 [inlined]
 [13] map
    @ .\abstractarray.jl:3261 [inlined]
 [14] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::LinearAlgebra.Adjoint{ComplexF64, CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}})
    @ ChainRulesCore C:\Users\domin\.julia\packages\ChainRulesCore\0t04l\src\projection.jl:236
 [15] ProjectTo
    @ C:\Users\domin\.julia\packages\ChainRulesCore\0t04l\src\projection.jl:414 [inlined]
 [16] _project
    @ C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\chainrules.jl:189 [inlined]
 [17] unbroadcast(x::LinearAlgebra.Adjoint{Float64, CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}}, x̄::CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\lib\broadcast.jl:62
 [18] #1172
    @ C:\Users\domin\.julia\packages\Zygote\JeHtr\src\lib\broadcast.jl:83 [inlined]
 [19] map
    @ .\tuple.jl:274 [inlined]
 [20] #1171
    @ C:\Users\domin\.julia\packages\Zygote\JeHtr\src\lib\broadcast.jl:83 [inlined]
 [21] #3754#back
    @ C:\Users\domin\.julia\packages\ZygoteRules\OgCVT\src\adjoint.jl:71 [inlined]
 [22] Pullback
    @ .\REPL[1]:2 [inlined]
 [23] (::Zygote.Pullback{Tuple{typeof(test_func), CuArray{ComplexF64, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.ZBack{typeof(ChainRules._adjoint_vec_pullback)}, ComposedFunction{Zygote.Pullback{Tuple{Zygote.var"#1441#1442", typeof(abs2), CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#4197#back#1437"{Zygote.var"#1433#1436"{CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#3978#back#1283"{Zygote.var"#1279#1282"{CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}}}}, typeof(ZygoteRules.unthunk_tangent)}, Zygote.var"#3754#back#1177"{Zygote.var"#1171#1175"{Tuple{CuArray{ComplexF64, 1, CUDA.Mem.DeviceBuffer}, LinearAlgebra.Adjoint{Float64, CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}}})(Δ::Float64)
    @ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\interface2.jl:0
 [24] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(test_func), CuArray{ComplexF64, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.ZBack{typeof(ChainRules._adjoint_vec_pullback)}, ComposedFunction{Zygote.Pullback{Tuple{Zygote.var"#1441#1442", typeof(abs2), CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#4197#back#1437"{Zygote.var"#1433#1436"{CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#3978#back#1283"{Zygote.var"#1279#1282"{CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}}}}, typeof(ZygoteRules.unthunk_tangent)}, Zygote.var"#3754#back#1177"{Zygote.var"#1171#1175"{Tuple{CuArray{ComplexF64, 1, CUDA.Mem.DeviceBuffer}, LinearAlgebra.Adjoint{Float64, CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}}}})(Δ::Float64)
    @ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\interface.jl:45
 [25] gradient(::Function, ::CuArray{ComplexF64, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{Any})
    @ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\interface.jl:97
 [26] top-level scope
    @ REPL[11]:1
 [27] top-level scope
    @ C:\Users\domin\.julia\packages\CUDA\tVtYo\src\initialization.jl:185

Interestingly, making a real and b complex allows the gradient call to run, but it then errors on display: the output type for the b gradient becomes Base.ReshapedArray{ComplexF64, 1, LinearAlgebra.Adjoint{ComplexF64, CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, which it refuses to print. Collecting that array produces a CuArray with the correct gradient.
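
For concreteness, a sketch of that case (the wrapper type is copied from the error output above; the collect behaviour is per the description, not re-verified):

a = CUDA.rand(3)                    # real primal
b = CUDA.rand(ComplexF64, 3)        # complex primal
grads = gradient(test_func, a, b);  # succeeds; displaying grads is what errors
typeof(grads[2])                    # Base.ReshapedArray{ComplexF64, 1, LinearAlgebra.Adjoint{...}, ...}
collect(grads[2])                   # materializes to a CuArray holding the correct gradient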

The issue (at least with a complex and b real) seems to stem from

dy = eltype(dx) <: Real ? vec(dx) : adjoint(dx)

creating an Adjoint{ComplexF64, CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}, which, when reshaped at

reshape(dx, project.axes)

becomes a Base.ReshapedArray{ComplexF64, 1, Adjoint{ComplexF64, CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}. Dispatch then sees this as a generic AbstractArray and sends it down Base paths rather than CUDA paths when map is called at

S <: T ? dy : map(project.element, dy)

resulting in scalar indexing.
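
The same dispatch failure can be reproduced without Zygote; a minimal sketch of the wrapper chain described above:

using CUDA, LinearAlgebra
dx = CUDA.rand(ComplexF64, 1, 3)  # row-shaped cotangent, like the broadcast output above
dy = adjoint(dx)                  # Adjoint{ComplexF64, CuArray{ComplexF64, 2, ...}}
rdy = reshape(dy, 3)              # Base.ReshapedArray wrapping the Adjoint: a depth-2 wrapper
map(conj, rdy)                    # misses the GPU methods and hits Base's iterating map -> scalar indexing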

When a is real and b is complex, the element type of the gradient S matches the element type of the primal T, so the map in

S <: T ? dy : map(project.element, dy)

is not hit; instead the reshaped adjoint CuArray escapes, which I assume then hits scalar-indexing show methods when dispatched for printing. In a more complicated function, however, I expect this tangent would enter later pullbacks and cause scalar indexing before gradient returns.

As far as I understand it, this would ideally be fixed by better wrapper-array handling in Base / CUDA, but that seems like a hard, long-lived issue. In the meantime I'm not sure what the best way to fix this would be, nor whether the responsibility lies with CUDA or ChainRulesCore. Given that the wrapped array leaks out as the gradient of b in the a-real, b-complex case, perhaps there could be some tweaks to the wrapped-array handling here. For instance, when dx is an Adjoint, the reshape could be replaced by an adjoint followed by a broadcast of conj, or the earlier adjoint call in the ProjectTo{Adjoint} method could be a conj broadcast instead (see the sketch below)? Not sure what would be correct.
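
For concreteness, a sketch of that last suggestion, assuming the ProjectTo{Adjoint} line quoted above (the eager broadcast costs a copy, but the result stays a plain CuArray):

# hypothetical replacement for: dy = eltype(dx) <: Real ? vec(dx) : adjoint(dx)
dy = eltype(dx) <: Real ? vec(dx) : vec(conj.(dx))  # conj.(dx) broadcasts on-device; vec of a CuArray is a CuArray

For the 1×n arrays this method handles, vec(conj.(dx)) contains the same elements as adjoint(dx), just without the lazy wrapper.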

@oxinabox added the bug and ProjectTo labels on Jul 25, 2023

DomCRose commented Aug 15, 2023

xref: JuliaGPU/Adapt.jl#21, JuliaGPU/CUDA.jl#228

Since forward passes with nested wrappers will hit scalar indexing anyway, I think the best short-term solution here would be to ensure that forward passes containing only depth-1 wrappers also produce only depth-1 wrappers on the reverse pass for GPU arrays.

Materializing lazy array wrappers unnecessarily could hamper CPU performance, so would it be possible to add GPUArraysCore as a dependency, allowing specialized methods to be added to ProjectTo that guarantee this wrapper-depth behaviour (see the sketch below)? Or would that increase load time too much?
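
A minimal sketch of such a specialization, assuming the ProjectTo{Adjoint} method keeps its current shape (the project.parent field and the surrounding adjoint call are taken from ChainRulesCore's projection.jl; the size checks of the generic method are omitted, and this is untested):

using ChainRulesCore, LinearAlgebra
using GPUArraysCore: AbstractGPUArray

# hypothetical method: eagerly conjugate complex GPU cotangents so that no lazy
# Adjoint (and hence no ReshapedArray) wrapper is created on the reverse pass
function (project::ProjectTo{Adjoint})(dx::AbstractGPUArray)
    dy = eltype(dx) <: Real ? vec(dx) : vec(conj.(dx))  # stays a device array
    return adjoint(project.parent(dy))
end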

Alternatively, perhaps the adjoint at

dy = eltype(dx) <: Real ? vec(dx) : adjoint(dx)
could be replaced by a transpose and some sort of lazy conjugation, if such a thing is implemented anywhere?
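
As far as I know, nothing like a lazy conjugate wrapper exists in Base or LinearAlgebra. A hypothetical one would look roughly like the sketch below, though it would need its own broadcast/Adapt rules, or it would just recreate the depth-2 wrapper problem on GPU arrays:

# illustrative only; not from any existing package
struct LazyConj{T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N}
    parent::P
end
Base.size(x::LazyConj) = size(x.parent)
Base.getindex(x::LazyConj, i::Int...) = conj(x.parent[i...])  # scalar fallback; GPU use needs broadcast specializations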
