Debug Mode
The Problem
A major source of potential problems in AD systems is rules returning the wrong type of tangent / fdata / rdata for a given primal value. For example, if someone writes a rule like
function rrule!!(::CoDual{typeof(+)}, x::CoDual{<:Real}, y::CoDual{<:Real})
    # Pullback: returns rdata for the function itself, then for x and y.
    plus_reverse_pass(dz::Real) = NoRData(), dz, dz
    return zero_fcodual(primal(x) + primal(y)), plus_reverse_pass
end
and calls
rrule!!(zero_fcodual(+), zero_fcodual(5.0), zero_fcodual(4f0))
then the type of dz on the reverse pass will be Float64 (assuming everything happens correctly), and this rule will return a Float64 as the rdata for y. However, the primal value of y is a Float32, and rdata_type(Float32) is Float32, so returning a Float64 is incorrect. This error might cause the reverse pass to fail loudly and immediately, but it might also fail silently. It might cause an error much later in the reverse pass, making it hard to trace the failure back to the rule above. In some cases it could plausibly cause a segfault, which is about the worst possible outcome.
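The mismatch described above can be reproduced in plain Julia, without Mooncake. The following is a minimal sketch: type_matches is a hypothetical helper, and it assumes (as the text states for Float32) that the expected rdata type of a float is the float type itself.

```julia
# Plain Julia, no Mooncake required: mimic the arithmetic in the rule above.
x, y = 5.0, 4f0          # Float64 and Float32 primals, as in the example call
z = x + y                # Julia promotes the mixed sum to Float64
dz = one(z)              # so the cotangent seen on the reverse pass is a Float64

# The faulty pullback hands `dz` back unchanged for both arguments.
rdata_x, rdata_y = dz, dz

# Hypothetical stand-in for the check: for floats, the rdata type
# is assumed to equal the primal type.
type_matches(primal, rdata) = typeof(rdata) === typeof(primal)

println(type_matches(x, rdata_x))  # true
println(type_matches(y, rdata_y))  # false

# One possible repair: convert the cotangent to the type of its primal.
fixed_rdata_y = oftype(y, dz)
println(type_matches(y, fixed_rdata_y))  # true
```

Converting with oftype is only one way to repair such a rule; the point of the sketch is that the promoted Float64 does not match the Float32 primal.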
The Solution
Check that the types of the fdata / rdata associated with arguments are exactly what tangent_type / fdata_type / rdata_type require upon entry to / exit from rules and pullbacks.
This is implemented via DebugRRule:
Mooncake.DebugRRule — Type
DebugRRule(rule)
Construct a callable which is equivalent to rule, but inserts additional type checking. In particular:
- check that the fdata in each argument is of the correct type for the primal
- check that the fdata in the CoDual returned from the rule is of the correct type for the primal.
This happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.
Some additional dynamic checks are also performed (e.g. that an fdata array is of the same size as its primal).
If rule returns y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse pass. See the docstring for details.
Note: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.
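The wrapping idea behind DebugRRule and DebugPullback can be illustrated with a toy model in plain Julia. This is a sketch only, not Mooncake's implementation: checked_rule, checked_pullback, expected_rdata_type and faulty_add are hypothetical names, the toy rule omits CoDuals and the rdata for the function itself, and only the reverse-pass check is modelled.

```julia
# Toy model only: these names are hypothetical and are not Mooncake's API.
# For this sketch, the expected rdata type of a float is the float type itself.
expected_rdata_type(::Type{P}) where {P<:AbstractFloat} = P

# Wrap a pullback so each returned rdata is checked against its primal's type.
function checked_pullback(pb, primals...)
    return function (dz)
        rdata = pb(dz)
        for (p, r) in zip(primals, rdata)
            T = expected_rdata_type(typeof(p))
            r isa T || throw(ArgumentError(
                "got rdata of type $(typeof(r)) for primal of type $(typeof(p)); expected $T",
            ))
        end
        return rdata
    end
end

# Wrap a rule so that the pullback it returns is checked.
function checked_rule(rule)
    return function (args...)
        y, pb = rule(args...)
        return y, checked_pullback(pb, args...)
    end
end

# A faulty rule for addition: it returns `dz` unchanged for both arguments.
faulty_add(x, y) = (x + y, dz -> (dz, dz))

y, pb = checked_rule(faulty_add)(5.0, 4f0)
err = try
    pb(1.0)   # Float64 rdata for the Float32 argument triggers the check
    nothing
catch e
    e
end
println(err isa ArgumentError)  # true
```

With the wrapper in place, the error surfaces at the faulty pullback itself rather than somewhere later in the reverse pass.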
You can enable DebugRRule when building a rule via the debug_mode kwarg of build_rrule:
Mooncake.build_rrule — Function
build_rrule(args...; debug_mode=false)
Helper method. Only uses static information from args.
build_rrule(sig::Type{<:Tuple})
Equivalent to build_rrule(Mooncake.get_interpreter(), sig).
build_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) where {C}
Returns a DerivedRule which is an rrule!! for sig_or_mi in context C. See the docstring for rrule!! for more info.
If debug_mode is true, then all calls to rules are replaced with calls to DebugRRules.
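As a rough usage sketch, a rule with debug mode enabled might be built and called as follows. The build_rrule call uses the args... method documented above. The call to Mooncake.value_and_gradient!! is an assumption about your installed Mooncake version, so check its docstring before relying on it.

```julia
using Mooncake

# Build a rule from example arguments; only their types are used.
rule = Mooncake.build_rrule(sin, 5.0; debug_mode=true)

# Assumption: `value_and_gradient!!` accepts a rule built this way
# in your Mooncake version.
val, grads = Mooncake.value_and_gradient!!(rule, sin, 5.0)
```

If every rule involved returns correctly typed data, the result should match what you get without debug mode, only more slowly.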
When using ADTypes.jl, you can choose whether or not to use debug mode via the debug_mode kwarg of Mooncake.Config:
julia> AutoMooncake(; config=Mooncake.Config(; debug_mode=true))
AutoMooncake{Mooncake.Config}(Mooncake.Config(true, false))
When Should You Use Debug Mode?
Only use debug_mode when debugging a problem, because it has substantial performance implications.