Tools for Rules
Most of the time, Mooncake.jl can just differentiate your code, but you will need to intervene if you make use of a language feature which is unsupported. However, this does not always necessitate writing your own rrule!! from scratch. In this section, we detail some useful strategies which can help you avoid having to write rrule!!s in many situations.
Simplifying Code via Overlays
Mooncake.@mooncake_overlay — Macro
@mooncake_overlay method_expr
Define a method of a function which only Mooncake can see. This can be used to write versions of methods which can be successfully differentiated by Mooncake if the original cannot be.
For example, suppose that you have a function
julia> foo(x::Float64) = bar(x)
foo (generic function with 1 method)
where Mooncake.jl fails to differentiate bar for some reason. If you have access to another function baz, which does the same thing as bar, but does so in a way which Mooncake.jl can differentiate, you can simply write:
julia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x)
When looking up the code for foo(::Float64), Mooncake.jl will see this method, rather than the original, and differentiate it instead.
A Worked Example
To demonstrate how to use @mooncake_overlays in practice, we show how the answer that Mooncake.jl gives changes when you redefine a function using a @mooncake_overlay. Do not do this in real code: it is just a simple way to demonstrate how overlays work!
First, consider a simple example:
julia> scale(x) = 2x
scale (generic function with 1 method)
julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});
julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(10.0, (NoTangent(), 2.0))
We can use @mooncake_overlay to change the definition which Mooncake.jl sees:
julia> Mooncake.@mooncake_overlay scale(x) = 3x
julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});
julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(15.0, (NoTangent(), 3.0))
As the output shows, the result of differentiating with Mooncake.jl has changed to reflect the overlaid definition of the method.
Additionally, it is possible to use the usual multi-line syntax to declare an overlay:
julia> Mooncake.@mooncake_overlay function scale(x)
           return 4x
       end
julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});
julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(20.0, (NoTangent(), 4.0))
Functions with Zero Adjoint
If the above strategy does not work, but you find yourself in the surprisingly common situation that the adjoint of the derivative of your function is always zero, you can very straightforwardly write a rule by making use of the following:
Mooncake.@zero_adjoint — Macro
@zero_adjoint ctx sig
Defines is_primitive(context_type, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality: it is morally the same as ChainRulesCore.@non_differentiable.
For example:
julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, NoRData
julia> foo(x) = 5
foo (generic function with 1 method)
julia> @zero_adjoint DefaultCtx Tuple{typeof(foo), Any}
julia> is_primitive(DefaultCtx, Tuple{typeof(foo), Any})
true
julia> rrule!!(zero_fcodual(foo), zero_fcodual(3.0))[2](NoRData())
(NoRData(), 0.0)
Limited support for Varargs is also available. For example:
julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, NoRData
julia> foo_varargs(x...) = 5
foo_varargs (generic function with 1 method)
julia> @zero_adjoint DefaultCtx Tuple{typeof(foo_varargs), Vararg}
julia> is_primitive(DefaultCtx, Tuple{typeof(foo_varargs), Any, Float64, Int})
true
julia> rrule!!(zero_fcodual(foo_varargs), zero_fcodual(3.0), zero_fcodual(5))[2](NoRData())
(NoRData(), 0.0, NoRData())
Be aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.
WARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.
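To make the aliasing condition in this warning concrete, here is a minimal sketch in plain Julia (it does not use Mooncake at all). The function names are hypothetical and chosen only for illustration. Note that === on mutable values only detects the simplest form of aliasing, where the output is literally one of the inputs; it will not catch outputs that merely share memory with an input.

```julia
x = [1.0, 2.0, 3.0]

# Hypothetical functions, for illustration only.
identity_like(v) = v        # returns its argument, so the output aliases the input
flatten_like(v) = vec(v)    # for a Vector, vec returns the same object
constant_output(v) = 5      # returns a fresh bits value, so no aliasing is possible

aliases_identity = identity_like(x) === x     # true: unsafe for @zero_adjoint
aliases_flatten = flatten_like(x) === x       # true: unsafe for @zero_adjoint
aliases_constant = constant_output(x) === x   # false
```

Functions like the first two pass data straight through, so declaring their adjoint to be zero would discard gradient information that should flow back to x.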
As always, you should use TestUtils.test_rule to ensure that you've not made a mistake.
Signatures Unsupported By This Macro
If the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.
Mooncake.zero_adjoint — Function
zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}
Utility functionality for constructing rrule!!s for functions which produce adjoints which always return zero.
NOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.
You make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:
julia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual, NoRData
julia> foo(x::Vararg{Int}) = 5
foo (generic function with 1 method)
julia> is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(foo), Vararg{Int}}}) = true;
julia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...);
julia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData())
(NoRData(), NoRData(), NoRData())
WARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not.
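As a rough aid for the bits-type condition in this warning, Julia's built-in isbitstype and isbits functions report whether a type or value is a bits type. The sketch below uses only Base Julia; the function count_entries is hypothetical and exists only to have something to inspect.

```julia
# Bits types are immutable and contain no references to other objects.
int_is_bits = isbitstype(Int)                        # true
tuple_is_bits = isbitstype(Tuple{Int, Float64})      # true
vector_is_bits = isbitstype(Vector{Float64})         # false: arrays are mutable

# Hypothetical function whose output we want to inspect.
count_entries(x::Vararg{Int}) = length(x)

# isbits checks a value rather than a type.
output_is_bits = isbits(count_entries(3, 2))         # true
```

When the output is not a bits type, this check says nothing either way, and the aliasing question has to be answered by reading the function itself.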
Using ChainRules.jl
ChainRules.jl provides a large number of rules for differentiating functions in reverse-mode. These rules are methods of the ChainRulesCore.rrule function. There are some instances where it is most convenient to implement a Mooncake.rrule!! by wrapping an existing ChainRulesCore.rrule.
There is enough similarity between these two systems that most of the boilerplate code can be avoided.
Mooncake.@from_rrule — Macro
@from_rrule ctx sig [has_kwargs=false]
Convenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.
Arguments
- ctx: a Mooncake context type.
- sig: the signature which you wish to assert should be a primitive in Mooncake.jl, and use an existing ChainRulesCore.rrule to implement this functionality.
- has_kwargs: a Bool stating whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule: the derivative w.r.t. all kwargs must be zero.
Example Usage
A Basic Example
julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils
julia> using ChainRulesCore
julia> using Random: Xoshiro
julia> foo(x::Real) = 5x;
julia> function ChainRulesCore.rrule(::typeof(foo), x::Real)
           foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), 5Ω
           return foo(x), foo_pb
       end;
julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat}
julia> rrule!!(zero_fcodual(foo), zero_fcodual(5.0))[2](1.0)
(NoRData(), 5.0)
julia> # Check that the rule works as intended.
       TestUtils.test_rule(Xoshiro(123), foo, 5.0; is_primitive=true)
Test Passed
An Example with Keyword Arguments
julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils
julia> using ChainRulesCore
julia> using Random: Xoshiro
julia> foo(x::Real; cond::Bool) = cond ? 5x : 4x;
julia> function ChainRulesCore.rrule(::typeof(foo), x::Real; cond::Bool)
           foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), cond ? 5Ω : 4Ω
           return foo(x; cond), foo_pb
       end;
julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} true
julia> _, pb = rrule!!(
           zero_fcodual(Core.kwcall),
           zero_fcodual((cond=false, )),
           zero_fcodual(foo),
           zero_fcodual(5.0),
       );
julia> pb(3.0)
(NoRData(), NoRData(), NoRData(), 12.0)
julia> # Check that the rule works as intended.
       TestUtils.test_rule(
           Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true
       )
Test Passed
Notice that, in order to access the kwarg method we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.
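The role of Core.kwcall can be seen without Mooncake. The following sketch assumes Julia 1.9 or later, where keyword calls are lowered to positional calls of Core.kwcall with the keyword arguments packed into a NamedTuple. The function pick_scale is a hypothetical stand-in for foo above.

```julia
# Hypothetical function with a keyword argument, mirroring the example above.
pick_scale(x::Real; cond::Bool) = cond ? 5x : 4x

# The usual keyword-call syntax.
y_sugar = pick_scale(5.0; cond=false)

# The equivalent positional form: kwargs as a NamedTuple, then the function,
# then the positional arguments. This is the argument order passed to rrule!!.
y_kwcall = Core.kwcall((cond=false,), pick_scale, 5.0)

same_result = y_sugar == y_kwcall   # true, both are 20.0
```

This is why the pullback in the example returns four values: one each for Core.kwcall, the NamedTuple of kwargs, the function, and the positional argument.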
Limitations
It is your responsibility to ensure that
- calls with signature sig do not mutate their arguments,
- the output of calls with signature sig does not alias any of the inputs.
As with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.
Argument Type Constraints
Many methods of ChainRulesCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature
Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}
There are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.
Suffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.
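To see what such a restricted signature does and does not cover, the following plain-Julia sketch checks a few concrete call signatures against one written with Base.IEEEFloat. The function triple is hypothetical and used only to build the signature types.

```julia
# Base.IEEEFloat is exactly the union named in the text.
ieee_matches = Base.IEEEFloat == Union{Float16, Float32, Float64}   # true

# Hypothetical function, used only to form signature types.
triple(x::Real) = 3x

# A signature restricted to the types for which a rule is believed correct.
narrow_sig = Tuple{typeof(triple), Base.IEEEFloat}

covers_float64 = Tuple{typeof(triple), Float64} <: narrow_sig     # true
covers_float32 = Tuple{typeof(triple), Float32} <: narrow_sig     # true
covers_bigfloat = Tuple{typeof(triple), BigFloat} <: narrow_sig   # false
covers_int = Tuple{typeof(triple), Int} <: narrow_sig             # false
```

Calls whose signatures fall outside narrow_sig, such as the BigFloat and Int cases here, are the ones left for this package to differentiate by deriving a rule itself.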
Conversions Between Different Tangent Type Systems
Under the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.