Debug Mode

The Problem

A major source of potential problems in AD systems is rules returning the wrong type of tangent / fdata / rdata for a given primal value. For example, if someone writes a rule like

function rrule!!(::CoDual{typeof(+)}, x::CoDual{<:Real}, y::CoDual{<:Real})
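    # Bug: dz is passed back unchanged, so its type need not match the types of x and y.
    plus_reverse_pass(dz::Real) = NoRData(), dz, dz
    return zero_fcodual(primal(x) + primal(y)), plus_reverse_pass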
end

and calls

rrule!!(zero_fcodual(+), zero_fcodual(5.0), zero_fcodual(4f0))

then the type of dz on the reverse pass will be Float64 (assuming everything happens correctly), and this rule will return a Float64 as the rdata for y. However, the primal value of y is a Float32, and rdata_type(Float32) is Float32, so returning a Float64 is incorrect. This error might cause the reverse pass to fail loudly and immediately, but it might equally fail silently, or cause an error much later in the reverse pass, making it hard to trace the problem back to the above rule. In some cases it could plausibly even cause a segfault, which is about the worst outcome possible.
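The mismatch comes from ordinary type promotion, so it can be reproduced without Mooncake at all. The following plain-Julia sketch mimics the pullback above; buggy_plus_pullback is a hypothetical name, and nothing stands in for NoRData(). It shows that passing dz straight through yields a Float64 where the Float32 input requires a Float32.

```julia
# Plain-Julia stand-in for the buggy pullback above: it passes the output
# cotangent dz straight through to both inputs (nothing plays the role of NoRData()).
buggy_plus_pullback(dz::Real) = (nothing, dz, dz)

x, y = 5.0, 4f0     # a Float64 and a Float32 primal
z = x + y           # promotion makes z a Float64
dz = one(z)         # so the cotangent seeding the reverse pass is a Float64 too

_, dx, dy = buggy_plus_pullback(dz)
println(typeof(dx), " for a ", typeof(x), " input")  # Float64 for a Float64 input
println(typeof(dy), " for a ", typeof(y), " input")  # Float64 for a Float32 input: wrong
```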

The Solution

Check that the types of the fdata / rdata associated with arguments are exactly what tangent_type / fdata_type / rdata_type require, both upon entry to and upon exit from rules and pullbacks.
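As a rough illustration of the kind of check meant here, the sketch below compares the type of some rdata against the type required for its primal using ===, so a Float64 is rejected for a Float32 primal even though it can represent any Float32 value. This is not Mooncake's implementation: expected_rdata_type and check_rdata are hypothetical helpers, and expected_rdata_type only covers IEEE floats, for which the rdata type is the primal type itself (as stated above for Float32).

```julia
# Hypothetical stand-in for Mooncake.rdata_type that only covers IEEE floats,
# for which the rdata type is the primal type itself (e.g. Float32 -> Float32).
expected_rdata_type(::Type{T}) where {T<:Base.IEEEFloat} = T

# Throw unless the rdata r produced for primal p has exactly the required type.
function check_rdata(p, r)
    T = expected_rdata_type(typeof(p))
    typeof(r) === T || error("rdata for a $(typeof(p)) primal must be a $T, got a $(typeof(r))")
    return nothing
end

check_rdata(4f0, 1f0)       # passes: Float32 rdata for a Float32 primal
try
    check_rdata(4f0, 1.0)   # Float64 rdata for a Float32 primal
catch err
    println(sprint(showerror, err))
end
```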

This is implemented via DebugRRule:

Mooncake.DebugRRule (Type)
DebugRRule(rule)

Construct a callable which is equivalent to rule, but inserts additional type checking. In particular:

  • check that the fdata in each argument is of the correct type for the primal;
  • check that the fdata in the CoDual returned from the rule is of the correct type for the primal.

This happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.

Some additional dynamic checks are also performed (e.g. that an fdata array is the same size as its primal).
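A recursive check of this sort might look roughly like the following sketch. The function check_tangent is hypothetical and deliberately simplified: it only handles IEEE floats, whose associated data must have exactly the primal's type, and arrays, which must match in size and are then checked element by element. The elementwise recursion is what catches the Vector{Any} case, where the static type says nothing about the elements.

```julia
# Hypothetical, simplified recursive checker (not Mooncake's implementation).
# IEEE floats: the associated data must have exactly the primal's type.
check_tangent(p::Base.IEEEFloat, t) =
    typeof(t) === typeof(p) || error("expected $(typeof(p)), got $(typeof(t))")

# Arrays: sizes must match, then every element pair is checked recursively,
# since a static type such as Vector{Any} says nothing about the elements.
function check_tangent(p::AbstractArray, t)
    t isa AbstractArray || error("expected an array, got $(typeof(t))")
    size(t) == size(p) || error("size mismatch: primal $(size(p)), tangent $(size(t))")
    foreach(check_tangent, p, t)
    return true
end

p = Any[1.0, 2f0, [3.0, 4.0]]
check_tangent(p, Any[0.0, 0f0, [0.0, 0.0]])    # passes
# check_tangent(p, Any[0.0, 0.0, [0.0, 0.0]])  # would throw: expected Float32, got Float64
# check_tangent(p, Any[0.0, 0f0, [0.0]])       # would throw: size mismatch
```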

If rule returns y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse pass. See its docstring for details.
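The wrapping pattern can be sketched in plain Julia as follows. CheckedRule and CheckedPullback are hypothetical analogues of DebugRRule and DebugPullback, not Mooncake's types, and the sketch assumes every input is an IEEE float, so that the cotangent returned for each input must have exactly that input's type. Converting dz to each input's type, as fixed_plus does, is one way to make the buggy + rule from earlier pass the check.

```julia
# Hypothetical analogues of DebugRRule / DebugPullback (not Mooncake's types).
# A "rule" maps primal inputs to (output, pullback); a pullback maps the output
# cotangent to a tuple whose first entry is for the function itself.
# Assumption: all inputs are IEEE floats, whose cotangent type is their own type.
struct CheckedPullback{P,T<:Tuple}
    pb::P
    primals::T  # primal inputs, kept to validate the pullback's outputs
end

function (c::CheckedPullback)(dz)
    grads = c.pb(dz)
    n = length(c.primals) + 1
    length(grads) == n || error("pullback returned $(length(grads)) cotangents, expected $n")
    for (p, g) in zip(c.primals, Base.tail(grads))
        typeof(g) === typeof(p) ||
            error("cotangent for a $(typeof(p)) input has type $(typeof(g))")
    end
    return grads
end

struct CheckedRule{R}
    rule::R
end

# Run the wrapped rule, then wrap its pullback so the reverse pass is checked too.
function (c::CheckedRule)(args...)
    y, pb = c.rule(args...)
    return y, CheckedPullback(pb, args)
end

# Buggy: passes dz straight through. Fixed: converts dz to each input's type.
buggy_plus(x, y) = (x + y, dz -> (nothing, dz, dz))
fixed_plus(x, y) = (x + y, dz -> (nothing, convert(typeof(x), dz), convert(typeof(y), dz)))

z, pb = CheckedRule(fixed_plus)(5.0, 4f0)
println(pb(one(z)))  # (nothing, 1.0, 1.0f0)

z, pb = CheckedRule(buggy_plus)(5.0, 4f0)
try
    pb(one(z))
catch err
    println(sprint(showerror, err))
end
```

Because the checks live in the wrappers, the rules themselves are left untouched, which is what lets debug mode be switched on without editing any rule definitions.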

Note: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.

You can enable it when building a rule via the debug_mode kwarg of build_rrule:

Mooncake.build_rrule (Function)
build_rrule(args...; debug_mode=false)

Helper method. Only uses static information from args.

build_rrule(sig::Type{<:Tuple})

Equivalent to build_rrule(Mooncake.get_interpreter(), sig).

build_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) where {C}

Returns a DerivedRule which is an rrule!! for sig_or_mi in context C. See the docstring for rrule!! for more info.

If debug_mode is true, then all calls to rules are replaced with calls to DebugRRules.

When using ADTypes.jl, you can choose whether or not to use it via the debug_mode kwarg of Mooncake.Config, passed to AutoMooncake:

julia> AutoMooncake(; config=Mooncake.Config(; debug_mode=true))
AutoMooncake{Mooncake.Config}(Mooncake.Config(true, false))

When Should You Use Debug Mode?

Only use debug_mode when debugging a problem, because it has substantial performance implications.