Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reusable and NonReusable capability #592

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Compat: hasfield, hasproperty

export frule, rrule # core function
# rule configurations
export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode
export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode, Reuseable, NotReuseable
export frule_via_ad, rrule_via_ad
# definition helper macros
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
Expand Down
18 changes: 18 additions & 0 deletions src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,24 @@ that do not support performing forwards mode AD should be `RuleConfig{>:NoForwar
"""
struct NoForwardsMode <: ForwardsModeCapability end

abstract type PullbackCapability end

"""
NotReuseable

This trait indicate that a pullback acquired by `RuleConfig{>:NotReuseable}` can only be called once.
So optimizations like reusing array buffers can be done in the pullback.
"""
struct NotReuseable <: PullbackCapability end

"""
Reuseable

This is the complement to [`NotReuseable`](@ref). If it is set then the pullback must return correct
result when being called multiple times. This is useful for computing jacobian.
"""
struct Reuseable <: PullbackCapability end

"""
frule_via_ad(::RuleConfig{>:HasForwardsMode}, ȧrgs, f, args...; kwargs...)

Expand Down
36 changes: 36 additions & 0 deletions test/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ function ChainRulesCore.rrule_via_ad(config::MockBothConfig, f, args...; kws...)
return f(args...; kws...), pullback_via_ad
end

struct ReuseableConfig <: RuleConfig{Union{NoForwardsMode,HasReverseMode,Reuseable}} end
struct NotReuseableConfig <: RuleConfig{Union{NoForwardsMode,HasReverseMode,NotReuseable}} end

##############################

#define some functions for testing
Expand Down Expand Up @@ -155,6 +158,39 @@ end
@test rconfig.reverse_calls == [(identity, (32.1,))]
end

@testset "pullback capability" begin
f(x) = x .* fill(2, size(x))
function ChainRulesCore.rrule(::RuleConfig{>:NotReuseable}, ::typeof(f), x)
tmp = similar(x)
fill!(tmp, 2)
y = x .* tmp
function pullback(Ȳ)
tmp .*= Ȳ
∂ = tmp
return (NoTangent(), ∂)
end
return y, pullback
end

function ChainRulesCore.rrule(::RuleConfig{>:Reuseable}, ::typeof(f), x)
tmp = similar(x)
fill!(tmp, 2)
y = x .* tmp
function pullback(Ȳ)
∂ = tmp .* Ȳ
return (NoTangent(), ∂)
end
return y, pullback
end

reuseable_pullback = rrule(ReuseableConfig(), f, randn(3))[2]
@test reuseable_pullback([1.0, 2.0, 3.0])[2] == [2.0, 4.0, 6.0]
@test reuseable_pullback([1.0, 2.0, 3.0])[2] == [2.0, 4.0, 6.0]
notreuseable_pullback = rrule(NotReuseableConfig(), f, randn(3))[2]
@test notreuseable_pullback([1.0, 2.0, 3.0])[2] == [2.0, 4.0, 6.0]
@test notreuseable_pullback([1.0, 2.0, 3.0])[2] != [2.0, 4.0, 6.0]
end

@testset "RuleConfig broadcasts like a scaler" begin
@test (MostBoringConfig() .=> (1, 2, 3)) isa NTuple{3,Pair{MostBoringConfig,Int}}
end
Expand Down