Internal Docstrings

Docstrings listed here are not part of the public Mooncake.jl interface. Consequently, they can change in non-breaking releases of Mooncake.jl without warning.

The purpose of this is to make it easy for developers to find docstrings straightforwardly via the docs, as opposed to having to ctrl+f through Mooncake.jl's source code, or looking at the docstrings via the Julia REPL.

Mooncake.ADInfo - Type
ADInfo

This data structure is used to hold "global" information associated to a particular call to build_rrule. It is used as a means of communication between make_ad_stmts! and the codegen which produces the forwards- and reverse-passes.

  • interp: a MooncakeInterpreter.
  • block_stack_id: the ID associated to the block stack – the stack which keeps track of which blocks we visited during the forwards-pass, and which is used on the reverse-pass to determine which blocks to visit.
  • block_stack: the block stack. Can always be found at block_stack_id in the forwards- and reverse-passes.
  • entry_id: ID associated to the block inserted at the start of execution in the forwards-pass, and the end of execution in the pullback.
  • shared_data_pairs: the SharedDataPairs used to define the captured variables passed to both the forwards- and reverse-passes.
  • arg_types: a map from Argument to its static type.
  • ssa_insts: a map from ID associated to lines to the primal NewInstruction. This contains the line of code, its static / inferred type, and some other details. See Core.Compiler.NewInstruction for a full list of fields.
  • arg_rdata_ref_ids: the dict mapping from arguments to the ID which creates and initialises the Ref which contains the reverse data associated to that argument. Recall that the heap allocations associated to this Ref are always optimised away in the final programme.
  • ssa_rdata_ref_ids: the same as arg_rdata_ref_ids, but for each ID associated to an ssa rather than each argument.
  • debug_mode: if true, run in "debug mode" – wraps all rule calls in DebugRRule. This is applied recursively, so that debug mode is also switched on in derived rules.
  • is_used_dict: for each ID associated to a line of code, is false if that line is not used in any other line of code.
  • lazy_zero_rdata_ref_id: for any arguments whose type doesn't permit the construction of a zero-valued rdata directly from the type alone (e.g. a struct with an abstractly-typed field), we need to have a zero-valued rdata available on the reverse-pass so that this zero-valued rdata can be returned if the argument (or a part of it) is never used during the forwards-pass and consequently doesn't obtain a value on the reverse-pass. To achieve this, we construct a LazyZeroRData for each of the arguments on the forwards-pass, and make use of it on the reverse-pass. This field is the ID that will be associated to this information.

Mooncake.ADStmtInfo - Type
ADStmtInfo

Data structure which contains the result of make_ad_stmts!. Fields are

  • line: the ID associated to the primal line from which this is derived
  • comms_id: an ID from one of the lines in fwds, whose value will be made available on the reverse-pass in the same ID. Nothing is asserted about how this value is made available on the reverse-pass of AD, so this package is free to do this in whichever way is most efficient, in particular to group these communication IDs on a per-block basis.
  • fwds: the instructions which run the forwards-pass of AD
  • rvs: the instructions which run the reverse-pass of AD / the pullback

Mooncake.BlockStack - Type

The block stack is the stack used to keep track of which basic blocks are visited on the forwards pass, and therefore which blocks need to be visited on the reverse pass. There is one block stack per derived rule. By using Int32, we assume that there aren't more than typemax(Int32) unique basic blocks in a given function, which ought to be reasonable.
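
The idea can be sketched in a few lines. This is an illustration only, not Mooncake's implementation: run_forwards! and run_reverse! are hypothetical names, the block IDs are made up, and a plain Vector{Int32} stands in for the block stack.

```julia
# Forwards pass: record the ID of each basic block as it is visited.
# The block IDs (1 to 4) are invented for this sketch.
function run_forwards!(stack::Vector{Int32}, x::Float64)
    push!(stack, Int32(1))        # entry block
    if x > 0
        push!(stack, Int32(2))    # branch taken when x > 0
        y = 2x
    else
        push!(stack, Int32(3))    # branch taken otherwise
        y = -x
    end
    push!(stack, Int32(4))        # exit block
    return y
end

# Reverse pass: pop IDs to revisit the same blocks in reverse order.
function run_reverse!(stack::Vector{Int32})
    visited = Int32[]
    while !isempty(stack)
        push!(visited, pop!(stack))
    end
    return visited
end

stack = Int32[]
run_forwards!(stack, 3.0)
order = run_reverse!(stack)   # blocks 4, 2, 1: block 3 was never visited
```

Only the blocks that actually ran on the forwards pass are replayed, which is what lets the reverse pass follow the control flow backwards.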

Mooncake.CannotProduceZeroRDataFromType - Type
CannotProduceZeroRDataFromType()

Returned by zero_rdata_from_type if it is not possible to construct the zero rdata element for a given type. See zero_rdata_from_type for more info.

Mooncake.DebugPullback - Type
DebugPullback(pb, y, x)

Construct a callable which is equivalent to pb, but which enforces type-based pre- and post-conditions to pb. Let dx = pb.pb(dy), for some rdata dy, then this function

  • checks that dy has the correct rdata type for y, and
  • checks that each element of dx has the correct rdata type for x.

Reverse-pass counterpart to DebugRRule.

Mooncake.DebugPullback - Method
(pb::DebugPullback)(dy)

Apply type checking to enforce pre- and post-conditions on pb.pb. See the docstring for DebugPullback for details.

Mooncake.DebugRRule - Type
DebugRRule(rule)

Construct a callable which is equivalent to rule, but inserts additional type checking. In particular:

  • check that the fdata in each argument is of the correct type for the primal
  • check that the fdata in the CoDual returned from the rule is of the correct type for the primal.

This happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.

Some additional dynamic checks are also performed (e.g. that an fdata array is of the same size as its primal).

Let rule return y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.

Note: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.

Mooncake.DebugRRule - Method
(rule::DebugRRule)(x::CoDual...)

Apply type checking to enforce pre- and post-conditions on rule.rule. See the docstring for DebugRRule for details.

Mooncake.DefaultCtx - Type
struct DefaultCtx end

Context for all usually used AD primitives. Anything which is a primitive in a MinimalCtx is a primitive in the DefaultCtx automatically. If you are adding a rule for the sake of performance, it should be a primitive in the DefaultCtx, but not the MinimalCtx.

Mooncake.DynamicDerivedRule - Type
DynamicDerivedRule(interp::MooncakeInterpreter, debug_mode::Bool)

For internal use only.

A callable data structure which, when invoked, calls an rrule specific to the dynamic types of its arguments. Stores rules in an internal cache to avoid re-deriving.

This is used to implement dynamic dispatch.
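
A rough sketch of the caching pattern described above, under stated assumptions: DynamicRuleSketch and build_rule are hypothetical names, and the "rule" is just a closure rather than a real rrule.

```julia
# Hypothetical sketch: a callable that caches one "rule" per tuple of argument types.
struct DynamicRuleSketch{F}
    build::F                  # builds a rule for a given signature (assumption)
    cache::Dict{Any,Any}      # signature => previously built rule
end
DynamicRuleSketch(build) = DynamicRuleSketch(build, Dict{Any,Any}())

function (r::DynamicRuleSketch)(args...)
    sig = typeof(args)                              # dynamic types of the arguments
    rule = get!(() -> r.build(sig), r.cache, sig)   # build only on a cache miss
    return rule(args...)
end

# Count how many times a rule is built.
n_builds = Ref(0)
build_rule(sig) = (n_builds[] += 1; (args...) -> sum(args))

rule = DynamicRuleSketch(build_rule)
rule(1.0, 2.0)   # builds a rule for Tuple{Float64, Float64}
rule(3.0, 4.0)   # cache hit, nothing is rebuilt
rule(1, 2)       # builds a second rule for Tuple{Int, Int}
```

Keying the cache on the runtime types of the arguments is what lets a single callable serve call sites whose argument types are not known statically.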

Mooncake.FData - Type
FData(data::NamedTuple)

The component of a struct which is propagated alongside the primal on the forwards-pass of AD. For example, the tangents for Float64s do not need to be propagated on the forwards-pass of reverse-mode AD, so any Float64 fields of Tangent do not need to appear in the associated FData.

Mooncake.LazyDerivedRule - Type
LazyDerivedRule(interp, mi::Core.MethodInstance, debug_mode::Bool)

For internal use only.

A type-stable wrapper around a DerivedRule, which only instantiates the DerivedRule when it is first called. This is useful, as it means that if a rule does not get run, it does not have to be derived.

If debug_mode is true, then the rule constructed will be a DebugRRule. This is useful when debugging, but should usually be switched off for production code as it (in general) incurs some runtime overhead.

Note: the signature of the primal for which this is a rule is stored in the type. The only reason to keep this around is for debugging – it is very helpful to have this type visible in the stack trace when something goes wrong, as it allows you to trivially determine which bit of your code is the culprit.

Extended Help

There are two main reasons why deferring the construction of a DerivedRule until we need to use it is crucial.

The first is to do with recursion. Consider the following function:

f(x) = x > 0 ? f(x - 1) : x

If we generate the IRCode for this function, we will see something like the following:

julia> Base.code_ircode_by_type(Tuple{typeof(f), Float64})[1][1]
1 1 ─ %1  = Base.lt_float(0.0, _2)::Bool
  │   %2  = Base.or_int(%1, false)::Bool
  └──       goto #6 if not %2
  2 ─ %4  = Base.sub_float(_2, 1.0)::Float64
  │   %5  = Base.lt_float(0.0, %4)::Bool
  │   %6  = Base.or_int(%5, false)::Bool
  └──       goto #4 if not %6
  3 ─ %8  = Base.sub_float(%4, 1.0)::Float64
  │   %9  = invoke Main.f(%8::Float64)::Float64
  └──       goto #5
  4 ─       goto #5
  5 ┄ %12 = φ (#3 => %9, #4 => %4)::Float64
  └──       return %12
  6 ─       return _2

Suppose that we decide to construct a DerivedRule immediately whenever we find an :invoke statement in a rule that we're currently building a DerivedRule for. In the above example, we produce an infinite recursion when we attempt to produce a DerivedRule for %9, because it has the same signature as the call which generates this IR. By instead adopting a policy of constructing a LazyDerivedRule whenever we encounter an :invoke statement, we avoid this problem.

The second reason that delaying the construction of a DerivedRule is essential is that it ensures that we don't derive rules for method instances which aren't run. Suppose that function B contains code for which we can't derive a rule – perhaps it contains an unsupported language feature like a PhiCNode or an UpsilonNode. Suppose that function A contains an :invoke which refers to function B, but that this call is on a branch which deals with error handling, and doesn't get run unless something goes wrong. By deferring the derivation of the rule for B, we only ever attempt to derive it if we land on this error handling branch. Conversely, if we attempted to derive the rule for B when we derive the rule for A, we would be unable to complete the derivation of the rule for A.
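
The recursion argument can be illustrated with a small sketch. It is not Mooncake's LazyDerivedRule: LazyRuleSketch and derive_rule are hypothetical, and "deriving a rule" is reduced to building a closure.

```julia
# Hypothetical sketch of lazy rule construction.
mutable struct LazyRuleSketch
    build::Function    # thunk that derives the rule (assumption)
    rule::Any          # left undefined until the first call
    LazyRuleSketch(build) = new(build)
end

function (lazy::LazyRuleSketch)(args...)
    # Derive the rule the first time it is needed, then reuse it.
    if !isdefined(lazy, :rule)
        lazy.rule = lazy.build()
    end
    return lazy.rule(args...)
end

# "Deriving" the rule for f(x) = x > 0 ? f(x - 1) : x wraps the recursive call
# in a lazy rule. Calling derive_rule() eagerly at that point would never terminate.
n_derived = Ref(0)
function derive_rule()
    n_derived[] += 1
    inner = LazyRuleSketch(derive_rule)    # nothing is derived yet
    return x -> x > 0 ? inner(x - 1) : x
end

rule = derive_rule()
result = rule(2.0)
```

In this sketch a rule is derived only for recursion depths that are actually reached, so derivation terminates whenever the primal computation does.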

Mooncake.LazyZeroRData - Type
LazyZeroRData{P, Tdata}()

This type is a lazy placeholder for zero_like_rdata_from_type. This is used to defer construction of zero data to the reverse pass. Calling instantiate on an instance of this will construct the zero data.

Users should construct using LazyZeroRData(p), where p is a value of type P. This constructor, and instantiate, are specialised to minimise the amount of data which must be stored. For example, Float64s do not need any data, so LazyZeroRData(0.0) produces an instance of a singleton type, meaning that various important optimisations can be performed in AD.

Mooncake.MinimalCtx - Type
struct MinimalCtx end

Functions should only be primitives in this context if not making them so would cause AD to fail. In particular, do not add primitives to this context if you are writing them for performance only – instead, make these primitives in the DefaultCtx.

Mooncake.NoFData - Type
NoFData

Singleton type which indicates that there is nothing to be propagated on the forwards-pass in addition to the primal data.

Mooncake.NoPullback - Method
NoPullback(args::CoDual...)

Construct a NoPullback from the arguments passed to an rrule!!. For each argument, extracts the primal value, and constructs a LazyZeroRData. These are stored in a NoPullback which, in the reverse-pass of AD, instantiates these LazyZeroRDatas and returns them in order to perform the reverse-pass of AD.

The advantage of this approach is that if it is possible to construct the zero rdata element for each of the arguments lazily, the NoPullback generated will be a singleton type. This means that AD can avoid generating a stack to store this pullback, which can result in significant performance improvements.

Mooncake.NoTangent - Type
NoTangent

The type in question has no meaningful notion of a tangent space. Generally, you shouldn't use this – just let the default recursive tangent construction work. You might need to use this for primitive types though.

Mooncake.RRuleZeroWrapper - Type
RRuleZeroWrapper(rule)

This struct is used to ensure that ZeroRDatas, which are used as placeholder zero elements whenever an actual instance of a zero rdata for a particular primal type cannot be constructed without also having an instance of said type, never reach rules. On the pullback, we increment the cotangent dy by an amount equal to zero. This ensures that if it is a ZeroRData, we instead get an actual zero of the correct type. If it is not a zero rdata, the computation should be elided via inlining + constant prop.

Mooncake.SharedDataPairs - Type
SharedDataPairs()

A data structure used to manage the captured data in the OpaqueClosures which implement the bulk of the forwards- and reverse-passes of AD. An entry (id, data) at element n of the pairs field of this data structure means that data will be available at register id during the forwards- and reverse-passes of AD.

This is achieved by storing all of the data in the pairs field in the captured tuple which is passed to an OpaqueClosure, and extracting this data into registers associated to the corresponding IDs.

Mooncake.Stack - Type
Stack{T}()

A stack specialised for reverse-mode AD.

Semantically equivalent to a usual stack, but never de-allocates memory once allocated.
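
One way such a stack could be written is sketched below. StackSketch and its fields are hypothetical and are not claimed to match Mooncake's Stack; the point is only that pop! moves a position counter instead of shrinking the underlying memory.

```julia
# Hypothetical sketch: a stack that keeps its memory after pop!.
mutable struct StackSketch{T}
    memory::Vector{T}
    position::Int
    StackSketch{T}() where {T} = new{T}(Vector{T}(undef, 0), 0)
end

function Base.push!(s::StackSketch{T}, x::T) where {T}
    s.position += 1
    if s.position <= length(s.memory)
        s.memory[s.position] = x    # reuse memory allocated earlier
    else
        push!(s.memory, x)          # grow only when capacity is exhausted
    end
    return s
end

function Base.pop!(s::StackSketch)
    x = s.memory[s.position]
    s.position -= 1                 # memory is retained, not freed
    return x
end

s = StackSketch{Float64}()
push!(s, 1.0)
push!(s, 2.0)
top = pop!(s)
push!(s, 3.0)   # overwrites the slot that held 2.0 without growing the vector
```

Because the forwards and reverse passes push and pop the same number of elements on every call, retaining memory means repeated calls need no further allocation once the stack has reached its peak size.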

Mooncake.ZeroRData - Type
ZeroRData()

Singleton type indicating zero-valued rdata. This should only ever appear as an intermediate quantity in the reverse-pass of AD when the type of the primal is not fully inferable, or a field of a type is abstractly typed.

If you see this anywhere in actual code, or if it appears in a hand-written rule, this is an error – please open an issue in such a situation.

Mooncake.__exclude_unsupported_output - Method
__exclude_unsupported_output(y)

Required for the robust design of value_and_pullback and prepare_pullback_cache. Ensures that y contains no aliasing, circular references, Ptrs, or non-differentiable data types. On the forward pass, the output of f(args...) must be a tree-like data structure whose leaf nodes are primitive types. Refer to https://github.com/compintell/Mooncake.jl/issues/517#issuecomment-2715202789 and the related issue for details. Internally calls __exclude_unsupported_output_internal!. The design is modelled after zero_tangent.

Mooncake.__exclude_unsupported_output_internal! - Method
__exclude_unsupported_output_internal!(y::T, address_set::Set{UInt}) where {T}

Checks whether the output y is a valid mutable or immutable composite, or a primitive type. Performs a recursive depth-first search over the function output y, with an isbitstype() check as the base case. The visited memory addresses are stored in address_set. If the set already contains a newly visited address, it errors, indicating an alias or circular reference. Also errors if y is or contains a pointer. It is called internally by __exclude_unsupported_output(y).

Mooncake.__flatten_varargs - Method
__flatten_varargs(isva::Bool, args, ::Val{nvargs}) where {nvargs}

If isva, inputs (5.0, (4.0, 3.0)) are transformed into (5.0, 4.0, 3.0).
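
The transformation can be reproduced with a short stand-alone function. flatten_varargs_sketch is a hypothetical re-implementation for illustration and drops the Val{nvargs} argument of the real method.

```julia
# When `isva` is true, the last element of `args` is a tuple of varargs,
# which is splatted into the result. Otherwise `args` is returned unchanged.
function flatten_varargs_sketch(isva::Bool, args::Tuple)
    isva || return args
    return (args[1:end-1]..., args[end]...)
end

flat = flatten_varargs_sketch(true, (5.0, (4.0, 3.0)))    # (5.0, 4.0, 3.0)
same = flatten_varargs_sketch(false, (5.0, (4.0, 3.0)))   # unchanged
```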

Mooncake.__insts_to_instruction_stream - Method
__insts_to_instruction_stream(insts::Vector{Any})

Produces an instruction stream whose

  • stmt (v1.11 and up) / inst (v1.10) field is insts,
  • type field is all Any,
  • info field is all Core.Compiler.NoCallInfo,
  • line field is all Int32(1), and
  • flag field is all Core.Compiler.IR_FLAG_REFINED.

As such, if you wish to ensure that your IRCode prints nicely, you should ensure that its linetable field has at least one element.

Mooncake.__pop_blk_stack! - Method
__pop_blk_stack!(block_stack::BlockStack)

Equivalent to pop!(block_stack). Going via this function, rather than just calling pop! directly, makes it easy to figure out how much time is spent popping the block stack when profiling performance, and to know that this function was hit when debugging.

Mooncake.__push_blk_stack! - Method
__push_blk_stack!(block_stack::BlockStack, id::Int32)

Equivalent to push!(block_stack, id). Going via this function, rather than just calling push! directly, is helpful for debugging and performance analysis – it makes it very straightforward to figure out how much time is spent pushing to the block stack when profiling.

Mooncake.__unflatten_codual_varargs - Method
__unflatten_codual_varargs(isva::Bool, args, ::Val{nargs}) where {nargs}

If isva and nargs=2, then inputs (CoDual(5.0, 0.0), CoDual(4.0, 0.0), CoDual(3.0, 0.0)) are transformed into (CoDual(5.0, 0.0), CoDual((4.0, 3.0), (0.0, 0.0))).
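
A stand-alone sketch of the grouping, assuming that it is the trailing arguments (from position nargs onwards) which are collected into a single CoDual. CoDualSketch and unflatten_sketch are hypothetical, and nargs is passed as a plain Int rather than a Val.

```julia
# Hypothetical stand-in for CoDual: a primal paired with tangent-like data.
struct CoDualSketch{P,T}
    primal::P
    data::T
end

# When `isva` is true, keep the first `nargs - 1` arguments as they are and
# group the rest into a single CoDualSketch holding tuples of primals and data.
function unflatten_sketch(isva::Bool, args::Tuple, nargs::Int)
    isva || return args
    head = args[1:(nargs - 1)]
    tail = args[nargs:end]
    grouped = CoDualSketch(map(a -> a.primal, tail), map(a -> a.data, tail))
    return (head..., grouped)
end

args = (CoDualSketch(5.0, 0.0), CoDualSketch(4.0, 0.0), CoDualSketch(3.0, 0.0))
out = unflatten_sketch(true, args, 2)
```

Here out has two elements: the first argument untouched, and one grouped element whose primal is the tuple of the remaining primals.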

Mooncake.__value_and_gradient!! - Method
__value_and_gradient!!(rule, f::CoDual, x::CoDual...)

Note: this is not part of the public Mooncake.jl interface, and may change without warning.

Equivalent to __value_and_pullback!!(rule, 1.0, f, x...) – assumes f returns a Float64.

# Set up the problem.
f(x, y) = sum(x .* y)
x = [2.0, 2.0]
y = [1.0, 1.0]
rule = build_rrule(f, x, y)

# Allocate tangents. These will be written to in-place. You are free to re-use these if you
# compute gradients multiple times.
tf = zero_tangent(f)
tx = zero_tangent(x)
ty = zero_tangent(y)

# Do AD.
Mooncake.__value_and_gradient!!(
    rule, Mooncake.CoDual(f, tf), Mooncake.CoDual(x, tx), Mooncake.CoDual(y, ty)
)
# output

(4.0, (NoTangent(), [1.0, 1.0], [2.0, 2.0]))

Mooncake.__value_and_pullback!! - Method
__value_and_pullback!!(rule, ȳ, f::CoDual, x::CoDual...; y_cache=nothing)

Note: this is not part of the public Mooncake.jl interface, and may change without warning.

In-place version of value_and_pullback!! in which the arguments have been wrapped in CoDuals. Note that any mutable data in f and x will be incremented in-place. As such, if calling this function multiple times with different values of x, you should be careful to ensure that you zero-out the tangent fields of x each time.

Mooncake._add_to_primal - Function
_add_to_primal(p::P, t::T, unsafe::Bool=false) where {P, T}

Adds t to p, returning a P. It must be the case that tangent_type(P) == T.

If unsafe is true and P is a composite type, then _add_to_primal will construct a new instance of P by directly invoking the :new instruction for P, rather than attempting to use the default constructor for P. This is fine if you are confident that the new P constructed by adding t to p will always be a valid instance of P, but could cause problems if you are not confident of this.

This is, for example, fine for the following type:

struct Foo{T}
    x::Vector{T}
    y::Vector{T}
    function Foo(x::Vector{T}, y::Vector{T}) where {T}
        @assert length(x) == length(y)
        return new{T}(x, y)
    end
end

Here, the value returned by _add_to_primal will satisfy the invariant asserted in the inner constructor for Foo.

Mooncake._diff - Method
_diff(p::P, q::P) where {P}

Required for testing.

Computes the difference between p and q, which must be of the same type, P. Returns a tangent of type tangent_type(P).

Mooncake._dot - Method
_dot(t::T, s::T)::Float64 where {T}

Required for testing. Should be defined for all standard tangent types.

Inner product between tangents t and s. Must return a Float64. Always available because all tangent types correspond to finite-dimensional vector spaces.

Mooncake._findall - Method
_findall(cond, x::Tuple)

Type-stable version of findall for Tuples. Should constant-fold if cond can be determined from the type of x.
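
A recursive formulation along these lines returns a Tuple of indices rather than a Vector. findall_sketch is a hypothetical illustration, and no claim is made that it matches Mooncake's definition or constant-folds in every case.

```julia
# Recurse over the tuple, carrying the current index `n` along.
findall_sketch(cond, x::Tuple) = findall_sketch(cond, x, 1)
findall_sketch(cond, ::Tuple{}, n::Int) = ()
function findall_sketch(cond, x::Tuple, n::Int)
    rest = findall_sketch(cond, Base.tail(x), n + 1)
    return cond(first(x)) ? (n, rest...) : rest
end

idx = findall_sketch(v -> v isa Float64, (1.0, 2, 3.0, "a"))   # (1, 3)
```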

Mooncake._foreigncall_ - Method
function _foreigncall_(
    ::Val{name}, ::Val{RT}, AT::Tuple, ::Val{nreq}, ::Val{calling_convention}, x...
) where {name, RT, nreq, calling_convention}

:foreigncall nodes get translated into calls to this function. For example,

Expr(:foreigncall, :foo, Tout, (A, B), nreq, :ccall, args...)

becomes

_foreigncall_(Val(:foo), Val(Tout), (Val(A), Val(B)), Val(nreq), Val(:ccall), args...)

Please consult the Julia documentation for more information on how foreigncall nodes work, and consult this package's tests for examples.

Credit: Umlaut.jl has the original implementation of this function. This is largely copied over from there.

Mooncake._map - Method
_map(f, x...)

Same as map but requires all elements of x to have equal length. The usual function map doesn't enforce this for Arrays.
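
The difference from Base.map can be seen with a small sketch. map_sketch is hypothetical, and the choice of ArgumentError is an assumption made for this example.

```julia
# Check lengths up front, then defer to `map`.
function map_sketch(f, x...)
    n = length(first(x))
    all(xi -> length(xi) == n, x) || throw(ArgumentError("lengths differ"))
    return map(f, x...)
end

ok = map_sketch(+, [1, 2], [3, 4])

# Base.map stops silently at the end of the shortest array:
truncated = map(+, [1, 2, 3], [10, 20, 30, 400, 5000])   # [11, 22, 33]
```

Calling map_sketch with the mismatched arrays from the last line throws instead of truncating.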

Mooncake._map_if_assigned! - Method
_map_if_assigned!(f::F, y::DenseArray, x1::DenseArray{P}, x2::DenseArray)

Similar to the other method of _map_if_assigned! – for all n, if x1[n] is assigned, writes f(x1[n], x2[n]) to y[n], otherwise leaves y[n] unchanged.

Requires that y, x1, and x2 have the same size.

Mooncake._map_if_assigned! - Method
_map_if_assigned!(f, y::DenseArray, x::DenseArray{P}) where {P}

For all n, if x[n] is assigned, then writes the value returned by f(x[n]) to y[n], otherwise leaves y[n] unchanged.

Equivalent to map!(f, y, x) if P is a bits type as elements will always be assigned.

Requires that y and x have the same size.
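
The single-input behaviour can be sketched with isassigned. map_if_assigned_sketch! is a hypothetical stand-in, not the package's method.

```julia
# Apply f only where x has an assigned element; leave y untouched elsewhere.
function map_if_assigned_sketch!(f, y::DenseArray, x::DenseArray)
    size(y) == size(x) || throw(DimensionMismatch("y and x must have the same size"))
    for n in eachindex(x)
        if isassigned(x, n)
            y[n] = f(x[n])
        end
    end
    return y
end

x = Vector{Vector{Float64}}(undef, 3)   # elements start out unassigned
x[1] = [1.0]
x[3] = [2.0, 3.0]
y = fill(-1, 3)
map_if_assigned_sketch!(length, y, x)   # y[2] keeps its old value
```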

Mooncake._new_ - Method
_new_(::Type{T}, x::Vararg{Any, N}) where {T, N}

One-liner which calls the :new instruction with type T with arguments x.
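
A sketch of how a function can emit the :new instruction directly, using a generated function. new_sketch and Positive are hypothetical names; treat this as an illustration of the technique rather than the package's exact definition.

```julia
# Build `Expr(:new, T, x[1], ..., x[N])`, which constructs a T without calling
# any of its constructors.
@generated function new_sketch(::Type{T}, x::Vararg{Any,N}) where {T,N}
    return Expr(:new, :T, map(n -> :(getfield(x, $n)), 1:N)...)
end

struct Positive
    x::Float64
    function Positive(x::Float64)
        x > 0 || throw(ArgumentError("x must be positive"))
        return new(x)
    end
end

p = new_sketch(Positive, -1.0)   # bypasses the check in the inner constructor
```

This is the trade-off discussed for the unsafe flag of _add_to_primal: going via :new always succeeds in constructing an instance, but any invariants enforced by inner constructors are not checked.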

Mooncake._scale - Method
_scale(a::Float64, t::T) where {T}

Required for testing. Should be defined for all standard tangent types.

Multiply tangent t by scalar a. Always possible because any given tangent type must correspond to a vector space. Not using * in order to avoid piracy.
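
To make the vector-space requirement concrete, here is a sketch of scaling and an inner product over a few tangent-like types. scale_sketch and dot_sketch are hypothetical and cover only Float64, Vector{Float64}, and Tuples of these.

```julia
# Scaling, defined recursively over the structure of the tangent.
scale_sketch(a::Float64, t::Float64) = a * t
scale_sketch(a::Float64, t::Vector{Float64}) = a .* t
scale_sketch(a::Float64, t::Tuple) = map(ti -> scale_sketch(a, ti), t)

# Inner product between two tangents of the same type.
dot_sketch(t::Float64, s::Float64) = t * s
dot_sketch(t::Vector{Float64}, s::Vector{Float64}) = sum(t .* s)
dot_sketch(t::T, s::T) where {T<:Tuple} = sum(map(dot_sketch, t, s); init=0.0)

t = (1.0, [2.0, 3.0])
s = (4.0, [5.0, 6.0])
d = dot_sketch(t, s)                           # 1*4 + 2*5 + 3*6
d_scaled = dot_sketch(scale_sketch(2.0, t), s) # linear in the first argument
```

Checks of this kind (for example, that the inner product is linear under scaling) are the sort of property that testing utilities can rely on once both operations exist for every tangent type.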

Mooncake._typeof - Method
_typeof(x)

Central definition of typeof, specialised to the use required in this package.

Mooncake.ad_stmt_info - Method
ad_stmt_info(line::ID, comms_id::Union{ID, Nothing}, fwds, rvs)

Convenient constructor for ADStmtInfo. If either fwds or rvs is not a vector, __vec promotes it to a single-element Vector.

Mooncake.add_data! - Method
add_data!(info::ADInfo, data)::ID

Equivalent to add_data!(info.shared_data_pairs, data).

Mooncake.add_data! - Method
add_data!(p::SharedDataPairs, data)::ID

Puts data into p, and returns the id associated to it. This id should be assumed to be available during the forwards- and reverse-passes of AD, and it should further be assumed that the value associated to this id is always data.
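
The bookkeeping can be sketched as follows. SharedPairsSketch, add_data_sketch!, and captures are hypothetical, and plain integers stand in for the IDs used by the package.

```julia
# Hypothetical sketch of the (id, data) bookkeeping.
struct SharedPairsSketch
    pairs::Vector{Tuple{Int,Any}}
    SharedPairsSketch() = new(Tuple{Int,Any}[])
end

# Store `data` and return the fresh ID under which it can be found later.
function add_data_sketch!(p::SharedPairsSketch, data)
    id = length(p.pairs) + 1
    push!(p.pairs, (id, data))
    return id
end

# The tuple of captured values which would be handed to both passes.
captures(p::SharedPairsSketch) = Tuple(map(last, p.pairs))

p = SharedPairsSketch()
id_a = add_data_sketch!(p, [1.0, 2.0])
id_b = add_data_sketch!(p, "lookup table")
```

In this sketch the ID doubles as the position of the value in the captured tuple, so captures(p)[id_b] recovers the stored data.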

Mooncake.add_data_if_not_singleton! - Method
add_data_if_not_singleton!(p::Union{ADInfo, SharedDataPairs}, x)

Returns x if it is a singleton, or the ID of the ssa which will contain it on the forwards- and reverse-passes. The reason for this is that if something is a singleton, it can be inserted directly into the IR.

Mooncake.always_initialised - Method
always_initialised(::Type{P}) where {P}

Returns a tuple whose length is equal to the number of fields in P. The nth element is true if the nth field of P is always initialised, and false otherwise.

Mooncake.arrayify - Method
arrayify(x::CoDual{<:AbstractArray{<:BlasFloat}})

Return the primal field of x, and convert its fdata into an array of the same type as the primal. This operation is not guaranteed to be possible for all array types, but seems to be possible for all array types of interest so far.

Mooncake.build_primitive_rrule - Method
build_primitive_rrule(sig::Type{<:Tuple})

Construct an rrule for signature sig. For this function to be called in build_rrule, you must also ensure that is_primitive(context_type, sig) is true. The callable returned by this must obey the rrule interface, but there are no restrictions on the type of callable itself. For example, you might return a callable struct. By default, this function returns rrule!! so, most of the time, you should just implement a method of rrule!!.

Extended Help

The purpose of this function is to permit computation at rule construction time, which can be re-used at runtime. For example, you might wish to derive some information from sig which you use at runtime (e.g. the fdata type of one of the arguments). While constant propagation will often optimise this kind of computation away, it will sometimes fail to do so in hard-to-predict circumstances. Consequently, if you need certain computations not to happen at runtime in order to guarantee good performance, you might wish to e.g. emit a callable struct with type parameters which are the result of this computation. In this context, the motivation for using this function is the same as that of using staged programming (e.g. via @generated functions) more generally.

Mooncake.build_rrule - Method
build_rrule(args...; kwargs...)

Helper method: equivalent to extracting the signature from args and calling build_rrule(sig; kwargs...).

Mooncake.build_rrule - Method
build_rrule(sig::Type{<:Tuple}; kwargs...)

Helper method: Equivalent to build_rrule(Mooncake.get_interpreter(), sig; kwargs...).

Mooncake.build_rrule - Method
build_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) where {C}

Returns a DerivedRule which is an rrule!! for sig_or_mi in context C. See the docstring for rrule!! for more info.

If debug_mode is true, then all calls to rules are replaced with calls to DebugRRules.

Mooncake.codual_type - Method
codual_type(P::Type)

The type of the CoDual which contains instances of P and associated tangents.

Mooncake.comms_channel - Method
comms_channel(info::ADStmtInfo)

Return the element of fwds whose ID is the communication ID. Returns nothing if comms_id is nothing.

Mooncake.conclude_rvs_block - Method
conclude_rvs_block(
    blk::BBlock, pred_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo
)

Generates code which is inserted at the end of each counterpart block in the reverse-pass. Handles phi nodes, and choosing the correct next block to switch to.

Mooncake.const_codual - Method
const_codual(stmt, info::ADInfo)

Build a CoDual from stmt, with zero / uninitialised fdata. If the resulting CoDual is a bits type, then it is returned. If it is not, then the CoDual is put into shared data, and the ID associated to it in the forwards- and reverse-passes returned.

Mooncake.const_codual_stmt - Method
const_codual_stmt(stmt, info::ADInfo)

Returns a :call expression which will return a CoDual whose primal is stmt, and whose tangent is whatever uninit_tangent returns.

Mooncake.const_prop_gotoifnots! - Method
const_prop_gotoifnots!(ir::IRCode)

Replace all occurrences in ir of goto %n if not true in block b with a goto b + 1, and all occurrences of goto %n if not false with goto n, and make the adjustments to ir that this necessitates.

Mooncake.create_comms_insts! - Method
create_comms_insts!(ad_stmts_blocks::ADStmts, info::ADInfo)

This function produces code which can be inserted into the forwards-pass and reverse-pass at specific locations to implement the promise associated to the comms_id field of the ADStmtInfo type – namely that if you assign a value to comms_id on the forwards-pass, the same value will be available at comms_id on the reverse-pass.

For each basic block represented in ADStmts:

  1. create a stack containing a Tuple which can hold all of the values associated to the comms_ids for each statement. Put this stack in shared data.
  2. create instructions which can be inserted at the end of the block generated to perform the forwards-pass (in forwards_pass_ir) which will put all of the data associated to the comms_ids into shared data, and
  3. create instructions which can be inserted at the start of the block generated to perform the reverse-pass (in pullback_ir), which will extract all of the data put into shared data by the instructions generated by the previous point, and assign them to the comms_ids.

Returns a Tuple{Vector{IDInstPair}, Vector{IDInstPair}}. The nth element of each Vector corresponds to the instructions to be inserted into the forwards- and reverse-passes resp. for the nth block in ad_stmts_blocks.

Mooncake.fdata - Method
fdata(t)::fdata_type(typeof(t))

Extract the forwards data from tangent t.

Mooncake.fdata_field_type - Method
fdata_field_type(::Type{P}, n::Int) where {P}

Returns the type of the nth field of the fdata type associated to P. Will be a PossiblyUninitTangent if said field can be undefined.

Mooncake.fdata_type - Method
fdata_type(T)

Returns the type of the forwards data associated to a tangent of type T.

Extended Help

Rules in Mooncake.jl do not operate on tangents directly. Rather, functionality is defined to split each tangent into two components, that we call fdata (forwards-pass data) and rdata (reverse-pass data). In short, any component of a tangent which is identified by its address (e.g. a mutable struct or an Array) gets passed around on the forwards-pass of AD and is incremented in-place on the reverse-pass, while components of tangents identified by their value get propagated and accumulated only on the reverse-pass.

Given a tangent type T, you can find out what type its fdata and rdata must be with fdata_type(T) and rdata_type(T) respectively. A consequence of this is that there is exactly one valid fdata type and rdata type for each primal type.

Given a tangent t, you can get its fdata and rdata using f = fdata(t) and r = rdata(t) respectively. f and r can be re-combined to recover the original tangent using the binary version of tangent: tangent(f, r). It must always hold that

tangent(fdata(t), rdata(t)) === t
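
As a toy illustration of the split and of this round-trip identity, the sketch below handles only Float64, Vector{Float64}, and Tuples of these. Every name in it (NoFSketch, NoRSketch, split_f, split_r, recombine) is hypothetical and stands in for the package's own types and functions.

```julia
# Hypothetical placeholders for "no forwards data" and "no reverse data".
struct NoFSketch end
struct NoRSketch end

# Float64 is identified by its value: everything goes into the reverse data.
split_f(t::Float64) = NoFSketch()
split_r(t::Float64) = t

# A Vector is identified by its address: everything goes into the forwards data.
split_f(t::Vector{Float64}) = t
split_r(t::Vector{Float64}) = NoRSketch()

# Tuples have no address of their own, so each element is split separately.
split_f(t::Tuple) = map(split_f, t)
split_r(t::Tuple) = map(split_r, t)

# Recombine the two halves into the original tangent.
recombine(f::NoFSketch, r::Float64) = r
recombine(f::Vector{Float64}, r::NoRSketch) = f
recombine(f::Tuple, r::Tuple) = map(recombine, f, r)

t = (0.0, [0.0])
f, r = split_f(t), split_r(t)
roundtrip = recombine(f, r)
```

The round trip holds with === because the vector inside f is the very same object as the one inside t, not a copy.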

The need for all of this is explained in the docs, but for now it suffices to consider our running examples again, and to see what their fdata and rdata look like.

Int

Ints are non-differentiable types, so there is nothing to pass around on the forwards- or reverse-pass. Therefore

julia> fdata_type(tangent_type(Int)), rdata_type(tangent_type(Int))
(NoFData, NoRData)

Float64

The tangent type of Float64 is Float64. Float64s are identified by their value / have no fixed address, so

julia> (fdata_type(Float64), rdata_type(Float64))
(NoFData, Float64)

Vector{Float64}

The tangent type of Vector{Float64} is Vector{Float64}. A Vector{Float64} is identified by its address, so

julia> (fdata_type(Vector{Float64}), rdata_type(Vector{Float64}))
(Vector{Float64}, NoRData)

Tuple{Float64, Vector{Float64}, Int}

This is an example of a type which has both fdata and rdata. The tangent type for Tuple{Float64, Vector{Float64}, Int} is Tuple{Float64, Vector{Float64}, NoTangent}. Tuples have no fixed memory address, so we interrogate each field on its own. We have already established the fdata and rdata types for each element, so we recurse to obtain:

julia> T = tangent_type(Tuple{Float64, Vector{Float64}, Int})
Tuple{Float64, Vector{Float64}, NoTangent}

julia> (fdata_type(T), rdata_type(T))
(Tuple{NoFData, Vector{Float64}, NoFData}, Tuple{Float64, NoRData, NoRData})

The zero tangent for (5.0, [5.0]) is t = (0.0, [0.0]). fdata(t) returns (NoFData(), [0.0]), where the second element is === to the second element of t. rdata(t) returns (0.0, NoRData()). In this example, t contains a mixture of data, some of which is identified by its value, and some of which is identified by its address, so there is some fdata and some rdata.

Structs

Structs are handled in more-or-less the same way as Tuples, albeit with the possibility of undefined fields needing to be explicitly handled. For example, a struct such as

julia> struct Foo
           x::Float64
           y
           z::Int
       end

has tangent type

julia> tangent_type(Foo)
Tangent{@NamedTuple{x::Float64, y, z::NoTangent}}

Its fdata and rdata are given by special FData and RData types:

julia> (fdata_type(tangent_type(Foo)), rdata_type(tangent_type(Foo)))
(Mooncake.FData{@NamedTuple{x::NoFData, y, z::NoFData}}, Mooncake.RData{@NamedTuple{x::Float64, y, z::NoRData}})

Practically speaking, FData and RData both have the same structure as Tangents and are just used in different contexts.

Mutable Structs

The fdata for a mutable struct is its tangent, and it has no rdata. This is because mutable structs have fixed memory addresses, and can therefore be incremented in-place. For example,

julia> mutable struct Bar
           x::Float64
           y
           z::Int
       end

has tangent type

julia> tangent_type(Bar)
MutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}

and fdata / rdata types

julia> (fdata_type(tangent_type(Bar)), rdata_type(tangent_type(Bar)))
(MutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}, NoRData)

Primitive Types

As with tangents, each primitive type must specify what its fdata and rdata are. See specific examples for details.

source
Mooncake.fix_up_invoke_inference!Method
fix_up_invoke_inference!(ir::IRCode)

The Problem

Consider the following:

@noinline function bar!(x)
    x .*= 2
end

function foo!(x)
    bar!(x)
    return nothing
end

In this case, the IR associated to Tuple{typeof(foo!), Vector{Float64}} will be something along the lines of

julia> Base.code_ircode_by_type(Tuple{typeof(foo!), Vector{Float64}})
1-element Vector{Any}:
2 1 ─     invoke Main.bar!(_2::Vector{Float64})::Any
3 └──     return Main.nothing
   => Nothing

Observe that the type inferred for the first line is Any. Inference is at liberty to do this without any risk of performance problems because the first line is not used anywhere else in the function. Had this line been used elsewhere in the function, inference would have inferred its type to be Vector{Float64}.

This causes performance problems for Mooncake, because it uses the return type to do various things, including allocating storage for quantities required on the reverse-pass. Consequently, inference inferring Any rather than Vector{Float64} causes type instabilities in the code that Mooncake generates, which can have catastrophic consequences for performance.

The Solution

:invoke expressions contain the Core.MethodInstance associated to them, which contains a Core.CodeCache, which contains the return type of the :invoke. This function looks for :invoke statements whose return type is inferred to be Any in ir, and modifies it to be the return type given by the code cache.

source
Mooncake.foreigncall_to_callMethod
foreigncall_to_call(inst, sp_map::Dict{Symbol, CC.VarState})

If inst is a :foreigncall expression, translate it into an equivalent :call expression. If it is anything else, just return inst. See Mooncake._foreigncall_ for details.

sp_map maps the names of the static parameters to their values. This function is intended to be called in the context of an IRCode, in which case the values of sp_map are given by the sptypes field of said IRCode. The keys should generally be obtained from the Method from which the IRCode is derived. See Mooncake.normalise! for more details.

The purpose of this transformation is to make it possible to differentiate :foreigncall expressions in the same way as a primitive :call expression, i.e. via an rrule!!.

source
Mooncake.forwards_pass_irMethod
forwards_pass_ir(ir::BBCode, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)

Produce the IR associated to the OpaqueClosure which runs most of the forwards-pass.

source
Mooncake.fwd_irMethod
fwd_ir(
    sig::Type{<:Tuple};
    interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true
)::IRCode

Warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Generate the Core.Compiler.IRCode used to construct the forwards-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.

For example, if you wanted to get the IR associated to the forwards pass for the call map(sin, randn(10)), you could do either of the following:

julia> Mooncake.fwd_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
true
julia> Mooncake.fwd_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
true

Arguments

  • sig::Type{<:Tuple}: the signature of the call to be differentiated.

Keyword Arguments

  • interp: the interpreter to use to obtain the primal IR.
  • debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
  • do_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
source
Mooncake.gc_preserveMethod
gc_preserve(xs...)

A no-op function. Its rrule!! ensures that the memory associated to xs is not freed until the pullback that it returns is run.

source
Mooncake.generate_irMethod
generate_ir(
    interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true
)

Used by build_rrule, and the various debugging tools: primal_ir, fwd_ir, rvs_ir.

source
Mooncake.get_interpreterMethod
get_interpreter()

Returns a MooncakeInterpreter appropriate for the current world age. Will use a cached interpreter if one already exists for the current world age, otherwise creates a new one.

This should be preferred over constructing a MooncakeInterpreter directly.

source
Mooncake.get_rev_data_idMethod
get_rev_data_id(info::ADInfo, x)

Returns the ID associated to the line in the reverse pass which will contain the reverse data for x. If x is not an Argument or ID, then nothing is returned.

source
Mooncake.get_tangent_fieldMethod
get_tangent_field(t::Union{MutableTangent, Tangent}, i::Int)

Gets the ith field of data in t.

Has the same semantics that getfield would have if the data in the fields field of t were actually fields of t. This is the moral equivalent of getfield for Tangent and MutableTangent.

source
Mooncake.inc_argsMethod
inc_args(stmt)

Increment by 1 the n field of any Arguments present in stmt. Used in make_ad_stmts!.

source
Mooncake.increment!!Method
increment!!(x::T, y::T) where {T}

Add x to y. If ismutabletype(T), then increment!!(x, y) === x must hold. That is, increment!! will mutate x. This must apply recursively if T is a composite type whose fields are mutable.
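The contract above can be illustrated with a minimal sketch in plain Julia. The function sketch_increment!! is a hypothetical stand-in written for this illustration, covering only Float64, Vector{Float64}, and Tuples of these. It is not Mooncake's implementation.

```julia
# Immutable case: there is nothing to mutate, so return a new value.
sketch_increment!!(x::Float64, y::Float64) = x + y

# Mutable case: update `x` in-place and return that same array, so that
# sketch_increment!!(x, y) === x holds.
function sketch_increment!!(x::Vector{Float64}, y::Vector{Float64})
    x .+= y
    return x
end

# Composite case: apply the rule recursively to each element, so that any
# mutable elements are themselves incremented in-place.
sketch_increment!!(x::T, y::T) where {T<:Tuple} = map(sketch_increment!!, x, y)

x = (1.0, [1.0, 2.0])
y = (2.0, [3.0, 4.0])
z = sketch_increment!!(x, y)
```

The tuple z is a new value because tuples are immutable, but its second element is the very same array as x[2], now holding the summed values.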

source
Mooncake.increment_and_get_rdata!Method
increment_and_get_rdata!(fdata, zero_rdata, cr_tangent)

Increment fdata by the fdata component of the ChainRules.jl-style tangent, cr_tangent, and return the rdata component of cr_tangent by adding it to zero_rdata.

source
Mooncake.increment_rdata!!Method
increment_rdata!!(t::T, r)::T where {T}

Increment the rdata component of tangent t by r, and return the updated tangent. Useful for implementing getfield-like rules for mutable structs, pointers, dicts, etc.

source
Mooncake.increment_ref_stmtsMethod
increment_ref_stmts(ref_id::ID, inc_data)::Vector{IDInstPair}

Equivalent to ref[] = increment!!(ref[], inc_data), where ref and inc_data are the values associated to ref_id and inc_data respectively.

source
Mooncake.infer_ir!Method
infer_ir!(ir::IRCode) -> IRCode

Runs type inference on ir, which mutates ir, and returns it.

Note: the compiler will not infer the types of anything where the corresponding element of ir.stmts.flag is not set to Core.Compiler.IR_FLAG_REFINED. Nor will it attempt to refine the type of the value returned by an :invoke expression. Consequently, if you find that the types in your IR are not being refined, you may wish to check that neither of these things is happening.

source
Mooncake.interpolate_boundschecks!Method
interpolate_boundschecks!(ir::IRCode)

For every x = Expr(:boundscheck, value) in ir, interpolate value into all uses of x. This is only required in order to ensure that literal versions of memoryrefget, memoryrefset!, getfield, and setfield! work effectively. If they are removed through improvements to the way that we handle constant propagation inside Mooncake, then this functionality can be removed.

source
Mooncake.intrinsic_to_functionMethod
intrinsic_to_function(inst)

If inst is a :call expression to a Core.IntrinsicFunction, replace it with a call to the corresponding function from Mooncake.IntrinsicsWrappers, else return inst.

cglobal is a special case – it requires that its first argument be static in exactly the same way as :foreigncall. See IntrinsicsWrappers.__cglobal for more info.

The purpose of this transformation is to make it possible to use dispatch to write rules for intrinsic calls in a type-stable way. See IntrinsicsWrappers for more context.

source
Mooncake.ircodeFunction
ircode(
    inst::Vector{Any},
    argtypes::Vector{Any},
    sptypes::Vector{CC.VarState}=CC.VarState[],
) -> IRCode

Constructs an instance of an IRCode. This is useful for constructing test cases with known properties.

No optimisations or type inference are performed on the resulting IRCode, so that the IRCode contains exactly what is intended by the caller. Please make use of infer_ir! if you require the types to be inferred.

Edges in PhiNodes, GotoIfNots, and GotoNodes found in inst must refer to lines (as in CodeInfo). In the IRCode returned by this function, these line references are translated into block references.

source
Mooncake.is_always_fully_initialisedMethod
is_always_fully_initialised(P::DataType)::Bool

True if all fields in P are always initialised. Put differently, there are no inner constructors which permit partial initialisation.

source
Mooncake.is_always_initialisedMethod
is_always_initialised(P::DataType, n::Int)::Bool

True if the nth field of P is always initialised. If the nth fieldtype of P isbitstype, then this is distinct from asking whether the nth field is always defined. An isbits field is always defined, but is not always explicitly initialised.
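The distinction between "defined" and "initialised" can be seen in plain Julia, without any Mooncake functionality. The struct names below are hypothetical and exist only for this illustration.

```julia
# An inner constructor which leaves the non-isbits field `y` uninitialised.
struct SketchPartial
    x::Float64
    y::Vector{Float64}
    SketchPartial(x) = new(x)
end

# An inner constructor which leaves the isbits field `b` uninitialised.
struct SketchBits
    a::Int
    b::Int
    SketchBits(a) = new(a)
end

p = SketchPartial(1.0)
q = SketchBits(1)

# `y` is neither initialised nor defined, so accessing `p.y` would throw.
p_y_defined = isdefined(p, :y)

# `b` was never explicitly initialised, yet it counts as defined. Its value
# is whatever happened to be in memory.
q_b_defined = isdefined(q, :b)
```

Here p_y_defined is false while q_b_defined is true, even though neither field was explicitly initialised by its constructor.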

source
Mooncake.is_primitiveMethod
is_primitive(::Type{Ctx}, sig) where {Ctx}

Returns a Bool specifying whether the methods specified by sig are considered primitives in contexts of type Ctx.

is_primitive(DefaultCtx, Tuple{typeof(sin), Float64})

will return whether calling sin(5.0) should be treated as primitive when the context is a DefaultCtx.

Observe that this information means that whether or not something is a primitive in a particular context depends only on static information, not any run-time information that might live in a particular instance of Ctx.
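One way to picture a purely static query of this kind is ordinary dispatch on types. The sketch below uses hypothetical names (SketchDefaultCtx, sketch_is_primitive) and is only meant to show that the answer depends on the context type and the signature type, never on an instance of the context. It makes no claim about how Mooncake declares primitives internally.

```julia
# Hypothetical context type, standing in for DefaultCtx.
struct SketchDefaultCtx end

# Fallback: by default, nothing is a primitive.
sketch_is_primitive(::Type, sig) = false

# Declare `sin(::Float64)` to be a primitive in SketchDefaultCtx. Both
# arguments are types, so no run-time instance of the context is consulted.
function sketch_is_primitive(
    ::Type{SketchDefaultCtx}, ::Type{<:Tuple{typeof(sin), Float64}}
)
    return true
end

sin_is_prim = sketch_is_primitive(SketchDefaultCtx, Tuple{typeof(sin), Float64})
cos_is_prim = sketch_is_primitive(SketchDefaultCtx, Tuple{typeof(cos), Float64})
```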

source
Mooncake.is_unreachable_return_nodeMethod
is_unreachable_return_node(x::ReturnNode)

Determine whether x is a ReturnNode and, if it is, whether it is also unreachable. This depends purely on whether its val field is defined.

source
Mooncake.is_usedMethod
is_used(info::ADInfo, id::ID)::Bool

Returns true if id is used by any of the lines in the ir, false otherwise.

source
Mooncake.is_vararg_and_sparam_namesMethod
is_vararg_and_sparam_names(m::Method)

Returns a 2-tuple. The first element is true if m is a vararg method, and false if not. The second element contains the names of the static parameters associated to m.

source
Mooncake.lgetfieldMethod
lgetfield(x, f::Val)

An implementation of getfield in which the field f is specified statically via a Val. This enables the implementation to be type-stable even when it is not possible to constant-propagate f. Moreover, it enables the pullback to also be type-stable.

It will always be the case that

getfield(x, :f) === lgetfield(x, Val(:f))
getfield(x, 2) === lgetfield(x, Val(2))

This approach is identical to the one taken by Zygote.jl to circumvent the same problem. Zygote.jl calls the function literal_getfield, while we call it lgetfield.
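The idea can be sketched in a couple of lines of plain Julia. The function sketch_lgetfield and the struct SketchPoint are hypothetical names used for illustration. This is not Mooncake's definition of lgetfield, only a demonstration of moving the field name into the type domain.

```julia
# The field name (or index) is a type parameter of `Val`, so it is known to
# the compiler from the argument types alone.
sketch_lgetfield(x, ::Val{f}) where {f} = getfield(x, f)

struct SketchPoint
    a::Float64
    b::Int
end

p = SketchPoint(1.5, 2)
by_name = sketch_lgetfield(p, Val(:a))
by_index = sketch_lgetfield(p, Val(2))
```

Because f is part of the type of the second argument, each distinct field gets its own specialisation, and the result type can be inferred without relying on constant propagation.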

source
Mooncake.lgetfieldMethod
lgetfield(x, ::Val{f}, ::Val{order}) where {f, order}

Like getfield, but with the field and access order encoded as types.

source
Mooncake.lift_gc_preservationMethod
lift_gc_preservation(inst)

Expressions of the form

y = GC.@preserve x1 x2 foo(args...)

get lowered to

token = Expr(:gc_preserve_begin, x1, x2)
y = expr
Expr(:gc_preserve_end, token)

These expressions guarantee that any memory associated to x1 and x2 is not freed until the :gc_preserve_end expression is reached.

In the context of reverse-mode AD, we must ensure that the memory associated to x1, x2 and their fdata is available during the reverse pass code associated to expr. We do this by preventing the memory from being freed until the :gc_preserve_begin is reached on the reverse pass.

To achieve this, we replace the primal code with

# Store `x1` and `x2` in `pb_gc_preserve` to prevent them from being freed.
_, pb_gc_preserve = rrule!!(zero_fcodual(gc_preserve), x1, x2)

# Differentiate the `:call` expression in the usual way.
y, foo_pb = rrule!!(zero_fcodual(foo), args...)

# Do not permit the GC to free `x1` or `x2` here.
nothing

The pullback should be something along the lines of

# no pullback associated to `nothing`.
nothing

# Run the pullback associated to `foo` in the usual manner. `x1` and `x2` must be available.
_, dargs... = foo_pb(dy)

# No-op pullback associated to `gc_preserve`.
pb_gc_preserve(NoRData())
source
Mooncake.lift_getfield_and_othersMethod
lift_getfield_and_others(inst)

Converts expressions of the form getfield(x, :a) into lgetfield(x, Val(:a)). This has identical semantics, but is performant in the absence of proper constant propagation.

Does the same for...

source
Mooncake.lookup_irMethod
lookup_ir(
    interp::AbstractInterpreter,
    sig_or_mi::Union{Type{<:Tuple}, Core.MethodInstance},
)::Tuple{IRCode, T}

Get the unique IR associated to sig_or_mi under interp. Throws ArgumentErrors if there is no code found, or if more than one IRCode instance is returned.

Returns a tuple containing the IRCode and its return type.

source
Mooncake.lsetfield!Method
lsetfield!(value, name::Val, x, [order::Val])

This function is to setfield! what lgetfield is to getfield. It will always hold that

setfield!(copy(x), :f, v) == lsetfield!(copy(x), Val(:f), v)
setfield!(copy(x), 2, v) == lsetfield!(copy(x), Val(2), v)
source
Mooncake.make_ad_stmts!Function
make_ad_stmts!(inst::NewInstruction, line::ID, info::ADInfo)::ADStmtInfo

Every line in the primal code is associated to one or more lines in the forwards-pass of AD, and one or more lines in the pullback. This function has a method specific to every node type in the Julia SSAIR.

Translates the instruction inst, associated to line in the primal, into a specification of what should happen for this instruction in the forwards- and reverse-passes of AD, and what data should be shared between the forwards- and reverse-passes. Returns this in the form of an ADStmtInfo.

info is a data structure containing various bits of global information that certain types of nodes need access to.

source
Mooncake.make_switch_stmtsMethod
make_switch_stmts(
    pred_ids::Vector{ID}, target_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo
)

pred_ids comprises the IDs associated to all possible predecessor blocks to the primal block under consideration. Suppose its value is [ID(1), ID(2), ID(3)], then make_switch_stmts emits code along the lines of

prev_block = pop!(block_stack)
not_pred_was_1 = !(prev_block == ID(1))
not_pred_was_2 = !(prev_block == ID(2))
switch(
    not_pred_was_1 => ID(1),
    not_pred_was_2 => ID(2),
    ID(3)
)

In words: make_switch_stmts emits code which jumps to whichever block preceded the current block during the forwards-pass.
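The control flow that this emitted code implements can be simulated in plain Julia. In the sketch below, integers stand in for Mooncake's ID type, and sketch_switch is a hypothetical helper which returns the block that would be jumped to. It is a model of the behaviour described above, not the code that make_switch_stmts generates.

```julia
# Pop the most recently visited predecessor from the block stack, and
# return the block to jump to on the reverse-pass.
function sketch_switch(block_stack::Vector{Int}, pred_ids::Vector{Int})
    prev_block = pop!(block_stack)
    # Test every predecessor except the last...
    for id in pred_ids[1:end-1]
        prev_block == id && return id
    end
    # ...and fall through to the last one if none of them matched.
    return pred_ids[end]
end

# Suppose the forwards-pass entered this block from 1, then 3, then 2.
block_stack = [1, 3, 2]
pred_ids = [1, 2, 3]
visited = [sketch_switch(block_stack, pred_ids) for _ in 1:3]
```

The reverse-pass revisits the predecessors in the opposite order to the forwards-pass, because the block stack is last-in-first-out.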

source
Mooncake.new_to_callMethod
new_to_call(x)

If instruction x is a :new expression, replace it with a :call to Mooncake._new_. Otherwise, return x.

The purpose of this transformation is to make it possible to differentiate :new expressions in the same way as a primitive :call expression, i.e. via an rrule!!.

source
Mooncake.normalise!Method
normalise!(ir::IRCode, spnames::Vector{Symbol})

Apply a sequence of standardising transformations to ir which leaves its semantics unchanged, but makes AD more straightforward. In particular, replace

  1. :foreigncall Exprs with :calls to Mooncake._foreigncall_,
  2. :new Exprs with :calls to Mooncake._new_,
  3. :splatnew Exprs with :calls to Mooncake._splat_new_,
  4. Core.IntrinsicFunctions with counterparts from Mooncake.IntrinsicsWrappers,
  5. getfield(x, 1) with lgetfield(x, Val(1)), and related transformations,
  6. memoryrefget calls with lmemoryrefget calls, and related transformations,
  7. gc_preserve_begin / gc_preserve_end exprs so that memory release is delayed.

spnames are the names associated to the static parameters of ir. These are needed when handling :foreigncall expressions, in which it is not necessarily the case that all static parameter names have been translated into either types, or :static_parameter expressions.

Unfortunately, the static parameter names are not retained in IRCode, and the Method from which the IRCode is derived must be consulted. Mooncake.is_vararg_and_sparam_names provides a convenient way to do this.

source
Mooncake.opaque_closureMethod
opaque_closure(
    ret_type::Type,
    ir::IRCode,
    @nospecialize env...;
    isva::Bool=false,
    do_compile::Bool=true,
)::Core.OpaqueClosure{<:Tuple, ret_type}

Construct a Core.OpaqueClosure. Almost equivalent to Core.OpaqueClosure(ir, env...; isva, do_compile), but instead of letting Core.compute_oc_rettype figure out the return type from ir, impose ret_type as the return type.

Warning

User beware: if the Core.OpaqueClosure produced by this function ever returns anything which is not an instance of a subtype of ret_type, you should expect all kinds of awful things to happen, such as segfaults. You have been warned!

Extended Help

This is needed in Mooncake.jl because we make extensive use of our ability to know the return type of a couple of specific OpaqueClosures without actually having constructed them – see LazyDerivedRule. Without the capability to specify the return type, we have to guess what type compute_ir_rettype will return for a given IRCode before we have constructed the IRCode and run type inference on it. This exposes us to details of type inference, which are not part of the public interface of the language, and can therefore vary from Julia version to Julia version (including patch versions). Moreover, even for a fixed Julia version it can be extremely hard to predict exactly what type inference will infer to be the return type of a function.

Failing to correctly guess the return type can happen for a number of reasons, and the kinds of errors that tend to be generated when this fails tell you very little about the underlying cause of the problem.

By specifying the return type ourselves, we remove this dependence. The price we pay for this is the potential for segfaults etc if we fail to specify ret_type correctly.

source
Mooncake.optimise_ir!Method
optimise_ir!(ir::IRCode, show_ir=false)

Run a fairly standard optimisation pass on ir. If show_ir is true, displays the IR to stdout at various points in the pipeline – this is sometimes useful for debugging.

source
Mooncake.primal_irMethod
primal_ir(sig::Type{<:Tuple}; interp=get_interpreter())::IRCode

Warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Get the Core.Compiler.IRCode associated to sig from which a rule can be derived. Roughly equivalent to Base.code_ircode_by_type(sig; interp).

For example, if you wanted to get the IR associated to the call map(sin, randn(10)), you could do one of the following calls:

julia> Mooncake.primal_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
true
julia> Mooncake.primal_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
true
source
Mooncake.pullback_irMethod
pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)

Produce the IR associated to the OpaqueClosure which runs most of the pullback.

source
Mooncake.pullback_typeMethod
pullback_type(Trule, arg_types)

Get a bound on the pullback type, given a rule and associated primal types.

source
Mooncake.randn_tangentMethod
randn_tangent(rng::AbstractRNG, x::T) where {T}

Required for testing. Generate a randomly-chosen tangent to x. The design is closely modelled after zero_tangent.

source
Mooncake.rdataMethod
rdata(t)::rdata_type(typeof(t))

Extract the reverse data from tangent t.

Extended help

See extended help section of fdata_type.

source
Mooncake.rdata_field_typeMethod
rdata_field_type(::Type{P}, n::Int) where {P}

Returns the type of the nth field of the rdata type associated to P. Will be a PossiblyUninitTangent if said field can be undefined.

source
Mooncake.remove_edge!Method
remove_edge!(ir::IRCode, from::Int, to::Int)

Removes an edge in ir from from to to. See implementation for what this entails.

Note: this is slightly different from Core.Compiler.kill_edge!, in that it also updates PhiNodes in the to block. Moreover, the available methods of remove_edge! differ between 1.10 and 1.11, so we need something which is stable across both.

source
Mooncake.replace_capturesMethod
replace_captures(mc::Tmc, new_captures) where {Tmc<:MistyClosure}

Same as replace_captures for Core.OpaqueClosures, but returns a new MistyClosure.

source
Mooncake.replace_capturesMethod
replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure}

Given an OpaqueClosure oc, create a new OpaqueClosure of the same type, but with new captured variables. This is needed for efficiency reasons – if build_rrule is called repeatedly with the same signature and interpreter, it is important to avoid recompiling the OpaqueClosures that it produces multiple times, because it can be quite expensive to do so.

source
Mooncake.replace_uses_with!Method
replace_uses_with!(stmt, def::Union{Argument, SSAValue}, val)

Replace all uses of def with val in the single statement stmt. Note: this function is highly incomplete, really only working correctly for a specific function in ir_normalisation.jl. You probably do not want to use it.

source
Mooncake.reverse_data_ref_stmtsMethod
reverse_data_ref_stmts(info::ADInfo)

Create the :new statements which initialise the reverse-data Refs. Interpolates the initial rdata directly into the statement, which is safe because it is always a bits type.

source
Mooncake.rrule!!Function
rrule!!(f::CoDual, x::CoDual...)

Performs the forwards-pass of AD. The tangent field of f and each x should contain the forwards tangent data (fdata) associated to each corresponding primal field.

Returns a 2-tuple. The first element, y, is a CoDual whose primal field is the value associated to running f.primal(map(x -> x.primal, x)...), and whose tangent field is its associated fdata. The second element contains the pullback, which runs the reverse-pass. It maps from the rdata associated to y to the rdata associated to f and each x.

using Mooncake: zero_fcodual, CoDual, NoFData, rrule!!
y, pb!! = rrule!!(zero_fcodual(sin), CoDual(5.0, NoFData()))
pb!!(1.0)

# output

(NoRData(), 0.28366218546322625)
source
Mooncake.rrule_wrapperMethod
rrule_wrapper(f::CoDual, args::CoDual...)

Used to implement rrule!!s via ChainRulesCore.rrule.

Given a function foo, argument types arg_types, and a method of ChainRulesCore.rrule which applies to these, you can make use of this function as follows:

Mooncake.@is_primitive DefaultCtx Tuple{typeof(foo), arg_types...}
function Mooncake.rrule!!(f::CoDual{typeof(foo)}, args::CoDual...)
    return rrule_wrapper(f, args...)
end

Assumes that methods of to_cr_tangent and to_mooncake_tangent are defined such that you can convert between the different representations of tangents that Mooncake and ChainRulesCore expect.

Furthermore, it is essential that

  1. f(args) does not mutate f or args, and
  2. the result of f(args) does not alias any data stored in f or args.

Subject to some constraints, you can use the @from_rrule macro to reduce the amount of boilerplate code that you are required to write even further.

source
Mooncake.rule_typeMethod
rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C}

Compute the concrete type of the rule that will be returned from build_rrule. This is important for performance in dynamic dispatch, and to ensure that recursion works properly.

source
Mooncake.rvs_irMethod
rvs_ir(
    sig::Type{<:Tuple};
    interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true
)::IRCode

Warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Generate the Core.Compiler.IRCode used to construct the reverse-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.

For example, if you wanted to get the IR associated to the reverse pass for the call map(sin, randn(10)), you could do either of the following:

julia> Mooncake.rvs_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
true
julia> Mooncake.rvs_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
true

Arguments

  • sig::Type{<:Tuple}: the signature of the call to be differentiated.

Keyword Arguments

  • interp: the interpreter to use to obtain the primal IR.
  • debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
  • do_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
source
Mooncake.rvs_phi_blockMethod
rvs_phi_block(pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo)

Produces a BBlock which runs the reverse-pass for the edge associated to pred_id in a collection of IDPhiNodes, and then goes to the block associated to pred_id.

For example, suppose that we encounter the following collection of PhiNodes at the start of some block:

%6 = φ (#2 => _1, #3 => %5)
%7 = φ (#2 => 5., #3 => _2)

Let the rdata refs associated to %6, %7, and _1 be denoted r%6, r%7, and r_1 resp., and let pred_id be #2, and increment_ref! be the following function,

increment_ref!(ref, x) = ref[] = increment!!(ref[], x)

then this rvs_phi_block will produce a basic block of the form

increment_ref!(r_1, r%6)
nothing
goto #2

The call to increment_ref! appears because _1 is the value associated to %6 when the primal code comes from #2. Similarly, the goto #2 statement appears because we came from #2 on the forwards-pass. There is no increment_ref! associated to %7 because 5. is a constant. We emit a nothing statement, which the compiler will happily optimise away later on.

The same ideas apply if pred_id were #3. The block would end with goto #3, and there would be two increment_ref! calls because both %5 and _2 are not constants.

In practice, code which is equivalent to increment_ref! is created directly, rather than inserting a call to a generic Julia function. This is because we need to be certain that the getfield and setfield! calls applied to any references are visible to the SROA optimisation pass. If we insert a call to a function like increment_ref!, it might not be inlined away, making such references opaque.

source
Mooncake.set_tangent_field!Method
set_tangent_field!(t::MutableTangent{Tfields}, i::Int, x) where {Tfields}

Sets the value of the ith field of the data in t to value x.

Has the same semantics that setfield! would have if the data in the fields field of t were actually fields of t. This is the moral equivalent of setfield! for MutableTangent.

source
Mooncake.shared_data_stmtsMethod
shared_data_stmts(p::SharedDataPairs)::Vector{IDInstPair}

Produce a sequence of id-statement pairs which will extract the data from shared_data_tuple(p) such that the correct value is associated to the correct ID.

For example, if p.pairs is

[(ID(5), 5.0), (ID(3), "hello")]

then the output of this function is

IDInstPair[
    (ID(5), new_inst(:(getfield(_1, 1)))),
    (ID(3), new_inst(:(getfield(_1, 2)))),
]
source
Mooncake.shared_data_tupleMethod
shared_data_tuple(p::SharedDataPairs)::Tuple

Create the tuple that will constitute the captured variables in the forwards- and reverse-pass OpaqueClosures.

For example, if p.pairs is

[(ID(5), 5.0), (ID(3), "hello")]

then the output of this function is

(5.0, "hello")
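
The relationship between this tuple and the statements produced by shared_data_stmts can be sketched in plain Julia. Everything below is a hypothetical model: integers stand in for IDs, bare Exprs stand in for the instructions wrapped by new_inst, and the sketch_ functions are invented for this illustration.

```julia
# (id, value) pairs, with integers standing in for ID(5) and ID(3).
data_pairs = [(5, 5.0), (3, "hello")]

# The captured tuple holds the values, in the order in which they were added.
sketch_shared_data_tuple(ps) = Tuple(map(last, ps))

# The n-th pair is recovered by reading the n-th field of the captures (_1).
function sketch_shared_data_stmts(ps)
    return [(first(p), :(getfield(_1, $n))) for (n, p) in enumerate(ps)]
end

captures = sketch_shared_data_tuple(data_pairs)
stmts = sketch_shared_data_stmts(data_pairs)
```

The position of a value in the tuple, rather than its ID, determines the index used in the corresponding getfield expression.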
source
Mooncake.sparam_namesMethod
sparam_names(m::Core.Method)::Vector{Symbol}

Returns the names of all of the static parameters in m.

source
Mooncake.splatnew_to_callMethod
splatnew_to_call(x)

If instruction x is a :splatnew expression, replace it with a :call to Mooncake._splat_new_. Otherwise return x.

The purpose of this transformation is to make it possible to differentiate :splatnew expressions in the same way as a primitive :call expression, i.e. via an rrule!!.

source
Mooncake.stable_allMethod
stable_all(x::NTuple{N, Bool}) where {N}

all(x::NTuple{N, Bool}) does not constant-fold nicely on 1.10 if the values of x are known statically. This implementation constant-folds nicely on both 1.10 and 1.11, so can be used in its place in situations where this is important.
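One common way to write a reduction over a tuple so that the compiler can unroll it is recursion on the tuple's tail. The sketch below shows that general technique under the hypothetical name sketch_stable_all. It is an assumption that something along these lines would serve the purpose described above, and it is not taken from Mooncake's source.

```julia
# Base case: `all` of an empty collection is true.
sketch_stable_all(::Tuple{}) = true

# Recursive case: peel off the first element. The length of the tuple is
# part of its type, so the recursion depth is known statically.
function sketch_stable_all(x::NTuple{N, Bool}) where {N}
    return x[1] && sketch_stable_all(Base.tail(x))
end

all_true = sketch_stable_all((true, true, true))
one_false = sketch_stable_all((true, false, true))
empty_case = sketch_stable_all(())
```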

source
Mooncake.stmtMethod
stmt(ir::CC.InstructionStream)

Get the field containing the instructions in ir. This changed name in 1.11 from inst to stmt.

source
Mooncake.tangentMethod
tangent(f, r)

Reconstruct the tangent t for which fdata(t) == f and rdata(t) == r.

source
Mooncake.tangent_test_casesMethod
tangent_test_cases()

Constructs a Vector of Tuples containing test cases for the tangent infrastructure.

If the returned tuple has 2 elements, the elements should be interpreted as follows:

  • 1: interface_only
  • 2: primal value

interface_only is a Bool which will be used to determine which subset of tests to run.

If the returned tuple has 5 elements, then the elements are interpreted as follows:

  • 1: interface_only
  • 2: primal value
  • 3, 4, 5: tangents, where <5> == increment!!(<3>, <4>).

Test cases in the first format make use of zero_tangent / randn_tangent etc to generate tangents, but they're unable to check that increment!! is correct in an absolute sense.

source
Mooncake.tangent_typeMethod
tangent_type(P)

There must be a single type used to represent tangents of primals of type P, and it must be given by tangent_type(P).

Warning: this function assumes the effects :removable and :consistent. This is necessary to ensure good performance, but imposes precise constraints on your implementation. If adding new methods to tangent_type, you should consult the extended help of Base.@assume_effects to see what this imposes upon your implementation.

Extended help

The tangent types which Mooncake.jl uses are quite similar in spirit to those of ChainRules.jl. For example, tangent "vectors" for

  1. Float64s are Float64s,
  2. Vector{Float64}s are Vector{Float64}s, and
  3. structs are another (special) struct with field types specified recursively.

There are, however, some major differences. Firstly, while it is certainly true that the above tangent types are permissible in ChainRules.jl, they are not the uniquely permissible types. For example, ZeroTangent is also a permissible type of tangent for any of them, and Float32 is permissible for Float64. This is a general theme in ChainRules.jl – it intentionally declines to place restrictions on what type can be used to represent the tangent of a given type.

Mooncake.jl differs from this. It insists that each primal type is associated to a single tangent type. Furthermore, this type is always given by the function Mooncake.tangent_type(primal_type).

Consider some more worked examples.

Int

Int is not a differentiable type, so its tangent type is NoTangent:

julia> tangent_type(Int)
NoTangent

Tuples

The tangent type of a Tuple is defined recursively based on its field types. For example

julia> tangent_type(Tuple{Float64, Vector{Float64}, Int})
Tuple{Float64, Vector{Float64}, NoTangent}

There is one edge case to be aware of: if all of the fields of a Tuple are non-differentiable, then the tangent type is NoTangent. For example,

julia> tangent_type(Tuple{Int, Int})
NoTangent
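
The recursive rule for Tuples, including the edge case, can be sketched in plain Julia. SketchNoTangent and sketch_tangent_type are hypothetical names used for illustration, and only the three element types used in the examples above are supported. This is a model of the rule as described here, not Mooncake's implementation.

```julia
# Hypothetical placeholder, standing in for NoTangent.
struct SketchNoTangent end

# Each primitive type declares its tangent type explicitly.
sketch_tangent_type(::Type{Float64}) = Float64
sketch_tangent_type(::Type{Int}) = SketchNoTangent
sketch_tangent_type(::Type{Vector{Float64}}) = Vector{Float64}

# Tuples: recurse over the field types, collapsing to SketchNoTangent if
# every field is non-differentiable.
function sketch_tangent_type(::Type{P}) where {P<:Tuple}
    ts = map(sketch_tangent_type, fieldtypes(P))
    all(==(SketchNoTangent), ts) && return SketchNoTangent
    return Tuple{ts...}
end

mixed = sketch_tangent_type(Tuple{Float64, Vector{Float64}, Int})
all_ints = sketch_tangent_type(Tuple{Int, Int})
```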

Structs

As with Tuples, the tangent type of a struct is, by default, given recursively. In particular, the tangent type of a struct type is Tangent. This type contains a NamedTuple containing the tangent to each field in the primal struct.

As with Tuples, if all field types are non-differentiable, the tangent type of the entire struct is NoTangent.

There are a couple of additional subtleties to consider over Tuples though. Firstly, not all fields of a struct have to be defined. Fortunately, Julia makes it easy to determine how many of the fields might possibly not be defined. The tangent associated to any field which might possibly not be defined is wrapped in a PossiblyUninitTangent.
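
As a plain-Julia illustration of what "not defined" means here, consider a struct whose inner constructor passes fewer arguments to new than the struct has fields. This is not Mooncake code, and ToyPartiallyInit is a hypothetical type introduced only for this sketch.

```julia
# Hypothetical example type: the inner constructor leaves `b` uninitialised.
mutable struct ToyPartiallyInit
    a::Float64
    b::Vector{Float64}
    ToyPartiallyInit(a::Float64) = new(a)
end

p = ToyPartiallyInit(1.0)
a_defined = isdefined(p, :a)        # `a` was passed to `new`
b_defined = isdefined(p, :b)        # `b` was never assigned

p.b = [2.0]                         # the field can be initialised later
b_defined_after = isdefined(p, :b)
```

Any tangent for such a type has to be able to represent "no tangent yet" for b, which is the situation the PossiblyUninitTangent wrapper addresses.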

Furthermore, structs can have fields whose static type is abstract. For example

julia> struct Foo
           x
       end

If you ask for the tangent type of Foo, you will see that it is

julia> tangent_type(Foo)
Tangent{@NamedTuple{x}}

Observe that the field type associated to x is Any. The way to understand this result is to observe that

  1. x could have literally any type at runtime, so we know nothing about what its tangent type must be until runtime, and
  2. we require that the tangent type of Foo be unique.

The consequence of these two considerations is that the tangent type of Foo must be able to contain any type of tangent in its x field. It follows that the fieldtype of the x field of Foo's tangent must be Any.
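
The first consideration can be checked in plain Julia. The sketch below uses a hypothetical struct ToyUntyped (a stand-in for Foo, so that the example is self-contained) to show that the declared field type is Any while the runtime type varies from instance to instance.

```julia
# Hypothetical struct with an untyped field, mirroring Foo above.
struct ToyUntyped
    x
end

# The static (declared) type of the field.
declared = fieldtype(ToyUntyped, :x)

# The runtime type of the field depends on the instance.
runtime_types = (
    typeof(ToyUntyped(1.0).x),
    typeof(ToyUntyped([1.0]).x),
    typeof(ToyUntyped(1).x),
)
```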

Mutable Structs

The tangent types of mutable structs are subject to the same considerations as those of structs. The only difference is that they must themselves be mutable. Consequently, we use a type called MutableTangent to represent their tangents. It is a mutable struct with the same structure as Tangent.

For example, if you ask for the tangent_type of

julia> mutable struct Bar
           x::Float64
       end

you will find that it is

julia> tangent_type(Bar)
MutableTangent{@NamedTuple{x::Float64}}

Primitive Types

We've already seen a couple of primitive types (Float64 and Int). The basic story here is that all primitive types require an explicit specification of what their tangent type must be.

One interesting case is Ptr types. The tangent type of a Ptr{P} is Ptr{T}, where T = tangent_type(P). For example

julia> tangent_type(Ptr{Float64})
Ptr{Float64}
source
Mooncake.tangent_typeMethod
tangent_type(F::Type, R::Type)::Type

Given the type of the fdata and rdata, F and R resp., for some primal type, compute its tangent type. This method must be equivalent to tangent_type(_typeof(primal)).

source
Mooncake.tuple_mapMethod
tuple_map(f::F, x::Tuple) where {F}

This function is largely equivalent to map(f, x), but always specialises on all of the element types of x, regardless of the length of x. This contrasts with map, in which the number of element types specialised upon is a fixed constant in the compiler.

As a consequence, if x is very long, this function may have very large compile times.

tuple_map(f::F, x::Tuple, y::Tuple) where {F}

Binary extension of tuple_map. Nearly equivalent to map(f, x, y), but guaranteed to specialise on all element types of x and y. Furthermore, errors if x and y aren't the same length, while map will just produce a new tuple whose length is equal to the shorter of x and y.
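
One way to obtain this kind of behaviour is to unroll over the statically known tuple length using ntuple with a Val argument. The sketch below is a minimal illustration under that assumption. toy_tuple_map is a hypothetical name, and Mooncake's actual tuple_map may be implemented differently.

```julia
# Unary version: N is part of the tuple's type, so ntuple(..., Val(N)) can be
# unrolled for every element of x.
@inline function toy_tuple_map(f::F, x::Tuple{Vararg{Any,N}}) where {F,N}
    return ntuple(i -> f(x[i]), Val(N))
end

# Binary version: refuse tuples of differing lengths rather than truncating.
@inline function toy_tuple_map(
    f::F, x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,M}}
) where {F,N,M}
    N == M || throw(ArgumentError("tuples must have the same length"))
    return ntuple(i -> f(x[i], y[i]), Val(N))
end

doubled = toy_tuple_map(x -> 2x, (1, 2.0))
summed = toy_tuple_map(+, (1, 2), (3, 4))
```

Calling the binary version with tuples of lengths 3 and 2 throws an ArgumentError in this sketch.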

source
Mooncake.uninit_fcodualMethod
uninit_fcodual(x)

Like zero_fcodual, but doesn't guarantee that the value of the fdata is initialised. See implementation for details, as this function is subject to change.

source
Mooncake.uninit_tangentMethod
uninit_tangent(x)

Related to zero_tangent, but a bit different. Check current implementation for details – this docstring is intentionally non-specific in order to avoid becoming outdated.

source
Mooncake.verify_fdata_typeMethod
verify_fdata_type(P::Type, F::Type)::Nothing

Check that F is a valid type for fdata associated to a primal of type P. Returns nothing if valid, throws an InvalidFDataException if a problem is found.

This applies to both concrete and non-concrete P. For example, if P is the type inferred for a primal q::Q, such that Q <: P, then this method is still applicable.

source
Mooncake.verify_fdata_valueMethod
verify_fdata_value(p, f)::Nothing

Check that f cannot be proven to be invalid fdata for p.

This method attempts to provide some confidence that f is valid fdata for p by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.

Put differently, we cannot prove that f is valid fdata, only that it is not obviously invalid.

source
Mooncake.verify_no_constant_gotoifnotsMethod
verify_no_constant_gotoifnots(ir::IRCode)

Verify that we have successfully removed all instances of goto %n if not true and goto %n if not false, as these can be reduced to simpler nodes (namely, GotoNodes or "fallthrough"s). Moreover, removing them tends to yield performance improvements by reducing the amount of information Mooncake must keep in its block stacks.

This is essentially just testing functionality for const_prop_constant_gotoifnots. This is usually run each time a rule is compiled, as it is cheap, and because it is hard to construct a convincing set of test cases which, if passed at test-time, would indicate we were done.

source
Mooncake.verify_rdata_typeMethod
verify_rdata_type(P::Type, R::Type)::Nothing

Check that R is a valid type for rdata associated to a primal of type P. Returns nothing if valid, throws an InvalidRDataException if a problem is found.

This applies to both concrete and non-concrete P. For example, if P is the type inferred for a primal q::Q, such that Q <: P, then this method is still applicable.

source
Mooncake.verify_rdata_valueMethod
verify_rdata_value(p, r)::Nothing

Check that r cannot be proven to be invalid rdata for p.

This method attempts to provide some confidence that r is valid rdata for p by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.

Put differently, we cannot prove that r is valid rdata, only that it is not obviously invalid.

source
Mooncake.zero_adjointMethod
zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}

Utility functionality for constructing rrule!!s for functions which produce adjoints which always return zero.

NOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.

You make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:

julia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual

julia> foo(x::Vararg{Int}) = 5
foo (generic function with 1 method)

julia> is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(foo), Vararg{Int}}}) = true;

julia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...);

julia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData())
(NoRData(), NoRData(), NoRData())

WARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not.

source
Mooncake.zero_like_rdata_from_typeMethod
zero_like_rdata_from_type(::Type{P}) where {P}

This is an internal implementation detail – you should generally not use this function.

Returns either the zero element of type rdata_type(tangent_type(P)), or a ZeroRData. It is always valid to return a ZeroRData.

source
Mooncake.zero_like_rdata_typeMethod
zero_like_rdata_type(::Type{P}) where {P}

Indicates the type which will be returned by zero_like_rdata_from_type. Will be the rdata type for P if we can produce the zero rdata element given only P, and will be the union of that rdata type, R, and ZeroRData if an instance of P is needed.

source
Mooncake.zero_rdata_from_typeMethod
zero_rdata_from_type(::Type{P}) where {P}

Returns the zero element of rdata_type(tangent_type(P)) if this is possible given only P. If not possible, returns an instance of CannotProduceZeroRDataFromType.

For example, the zero rdata associated to any primal of type Float64 is 0.0, so for Float64s this function is simple. Similarly, if the rdata type for P is NoRData, that can simply be returned.

However, it is not possible to return the zero rdata element for abstract types e.g. Real as the type does not uniquely determine the zero element – the rdata type for Real is Any.

These considerations apply recursively to tuples / namedtuples / structs, etc.

If you encounter a type for which this function returns CannotProduceZeroRDataFromType, but you believe this is done in error, please open an issue. This kind of problem does not constitute a correctness problem, but can be detrimental to performance, so should be dealt with.
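
The distinction between types which do and do not determine a zero element can be sketched in plain Julia. Everything in this block is a toy under stated assumptions: toy_zero_from_type and ToyCannotProduceZero are hypothetical names, only floating point types and tuples are handled, and the real function's rules are more extensive.

```julia
# Hypothetical stand-in for CannotProduceZeroRDataFromType.
struct ToyCannotProduceZero end

# Concrete floating point types determine their zero. In this toy, anything
# else (including abstract types such as Real) is treated as "cannot produce".
function toy_zero_from_type(::Type{P}) where {P}
    return isconcretetype(P) && P <: AbstractFloat ? zero(P) : ToyCannotProduceZero()
end

# Tuples: apply the same rule recursively to each field type.
function toy_zero_from_type(::Type{P}) where {P<:Tuple}
    isconcretetype(P) || return ToyCannotProduceZero()
    zs = map(toy_zero_from_type, fieldtypes(P))
    any(z -> z isa ToyCannotProduceZero, zs) && return ToyCannotProduceZero()
    return zs
end

from_float = toy_zero_from_type(Float64)
from_real = toy_zero_from_type(Real)
from_tuple = toy_zero_from_type(Tuple{Float64, Tuple{Float32}})
from_abstract_tuple = toy_zero_from_type(Tuple{Float64, Real})
```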

source
Mooncake.zero_tangentMethod
zero_tangent(x)

Returns the unique zero element of the tangent space of x. It is an error for the zero element of the tangent space of x to be represented by anything other than that which this function returns.

Internally, zero_tangent calls zero_tangent_internal, which handles different types of inputs differently. zero_tangent_internal has two variants:

  1. For isbitstype types, zero_tangent_internal takes one argument.
  2. Otherwise, zero_tangent_internal takes another argument which is an IdDict, which handles both circular references and aliasing correctly.
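
The role an IdDict can play in handling circular references is sketched below. This is not the implementation of zero_tangent_internal: ToyNode, ToyNodeZero and toy_zero are hypothetical names, and the sketch only assumes the general technique of caching results by object identity.

```julia
# Hypothetical primal type which can refer to itself.
mutable struct ToyNode
    value::Float64
    next::Union{ToyNode,Nothing}
end

# Hypothetical tangent-like container with the same shape.
mutable struct ToyNodeZero
    value::Float64
    next::Union{ToyNodeZero,Nothing}
end

function toy_zero(x::ToyNode, cache::IdDict{Any,Any}=IdDict{Any,Any}())
    # If x has been visited already, reuse its zero rather than recursing again.
    haskey(cache, x) && return cache[x]
    z = ToyNodeZero(0.0, nothing)
    cache[x] = z  # register before recursing, so that cycles terminate
    z.next = x.next === nothing ? nothing : toy_zero(x.next, cache)
    return z
end

a = ToyNode(1.0, nothing)
a.next = a          # circular reference
z = toy_zero(a)
```

Because the cache is keyed on object identity, two references to the same primal object map to the same zero object, so the aliasing structure of the primal is reproduced in the result.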

source
Mooncake.@from_rruleMacro
@from_rrule ctx sig [has_kwargs=false]

Convenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.

Arguments

  • ctx: A Mooncake context type
  • sig: the signature which you wish to assert should be a primitive in Mooncake.jl, and use an existing ChainRulesCore.rrule to implement this functionality.
  • has_kwargs: a Bool stating whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule – the derivative w.r.t. all kwargs must be zero.

Example Usage

A Basic Example

julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils

julia> import ChainRulesCore

julia> using Random: Xoshiro

julia> foo(x::Real) = 5x;

julia> function ChainRulesCore.rrule(::typeof(foo), x::Real)
           foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), 5Ω
           return foo(x), foo_pb
       end;

julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat}

julia> rrule!!(zero_fcodual(foo), zero_fcodual(5.0))[2](1.0)
(NoRData(), 5.0)

julia> # Check that the rule works as intended.
       TestUtils.test_rule(Xoshiro(123), foo, 5.0; is_primitive=true)
Test Passed

An Example with Keyword Arguments

julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils

julia> import ChainRulesCore

julia> using Random: Xoshiro

julia> foo(x::Real; cond::Bool) = cond ? 5x : 4x;

julia> function ChainRulesCore.rrule(::typeof(foo), x::Real; cond::Bool)
           foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), cond ? 5Ω : 4Ω
           return foo(x; cond), foo_pb
       end;

julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} true

julia> _, pb = rrule!!(
           zero_fcodual(Core.kwcall),
           zero_fcodual((cond=false, )),
           zero_fcodual(foo),
           zero_fcodual(5.0),
       );

julia> pb(3.0)
(NoRData(), NoRData(), NoRData(), 12.0)

julia> # Check that the rule works as intended.
       TestUtils.test_rule(
           Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true
       )
Test Passed

Notice that, in order to access the kwarg method we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.
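
The relationship between keyword-argument syntax and Core.kwcall can be seen in plain Julia, without Mooncake. The sketch below assumes Julia 1.9 or later, where Core.kwcall is available, and uses a hypothetical function toy_kw.

```julia
# Hypothetical function with a required keyword argument.
toy_kw(x::Real; cond::Bool) = cond ? 5x : 4x

# The usual keyword syntax...
via_keyword_syntax = toy_kw(5.0; cond=false)

# ...and the corresponding explicit call: kwargs as a NamedTuple, then the
# function itself, then the positional arguments.
via_kwcall = Core.kwcall((cond=false,), toy_kw, 5.0)
```

The argument order of the explicit call matches the order of the CoDuals passed to rrule!! in the example above.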

Limitations

It is your responsibility to ensure that

  1. calls with signature sig do not mutate their arguments,
  2. the output of calls with signature sig does not alias any of the inputs.

As with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.

Argument Type Constraints

Many methods of ChainRulesCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature

Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}

There are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.

Suffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.
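
To make concrete which call signatures a narrowed sig covers, the following plain-Julia sketch checks a few signatures against one restricted to Base.IEEEFloat. toy_f is a hypothetical function, and no Mooncake functionality is used.

```julia
# Hypothetical function with a loosely typed method.
toy_f(x::Real) = 5x

# A signature restricted to the IEEE floating point types.
narrow_sig = Tuple{typeof(toy_f), Base.IEEEFloat}

same_union = Base.IEEEFloat == Union{Float16, Float32, Float64}
covers_float32 = Tuple{typeof(toy_f), Float32} <: narrow_sig
covers_bigfloat = Tuple{typeof(toy_f), BigFloat} <: narrow_sig
covers_int = Tuple{typeof(toy_f), Int} <: narrow_sig
```

Calls with BigFloat or Int arguments fall outside narrow_sig, so they would be left for the package to derive rules for.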

Conversions Between Different Tangent Type Systems

Under the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.

source
Mooncake.@is_primitiveMacro
@is_primitive context_type signature

Creates a method of is_primitive which always returns true for the context_type and signature provided. For example

@is_primitive MinimalCtx Tuple{typeof(foo), Float64}

is equivalent to

is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(foo), Float64}}) = true

You should implement more complicated methods of is_primitive in the usual way.
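
The dispatch pattern used by the generated method can be tried out in isolation. The sketch below is a toy under stated assumptions: toy_is_primitive, ToyCtx and toy_foo are hypothetical names, and the fallback returning false is an assumption made so that the example is self-contained.

```julia
# Hypothetical context type and function.
struct ToyCtx end
toy_foo(x) = x

# Assumed fallback: nothing is primitive unless declared so.
toy_is_primitive(::Type, ::Type) = false

# Same shape as the method shown above: dispatch on the context type and on
# any signature which is a subtype of the given Tuple type.
toy_is_primitive(::Type{ToyCtx}, ::Type{<:Tuple{typeof(toy_foo), Float64}}) = true

declared = toy_is_primitive(ToyCtx, Tuple{typeof(toy_foo), Float64})
undeclared = toy_is_primitive(ToyCtx, Tuple{typeof(toy_foo), Int})
```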

source
Mooncake.@mooncake_overlayMacro
@mooncake_overlay method_expr

Define a method of a function which only Mooncake can see. This can be used to write versions of methods which can be successfully differentiated by Mooncake if the original cannot be.

For example, suppose that you have a function

julia> foo(x::Float64) = bar(x)
foo (generic function with 1 method)

where Mooncake.jl fails to differentiate bar for some reason. If you have access to another function baz, which does the same thing as bar, but does so in a way which Mooncake.jl can differentiate, you can simply write:

julia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x)

When looking up the code for foo(::Float64), Mooncake.jl will see this method, rather than the original, and differentiate it instead.

A Worked Example

To demonstrate how to use @mooncake_overlays in practice, we show how the answer that Mooncake.jl gives changes if you change the definition of a function using a @mooncake_overlay. Do not do this in practice – this is just a simple way to demonstrate how to use overlays!

First, consider a simple example:

julia> scale(x) = 2x
scale (generic function with 1 method)

julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});

julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(10.0, (NoTangent(), 2.0))

We can use @mooncake_overlay to change the definition which Mooncake.jl sees:

julia> Mooncake.@mooncake_overlay scale(x) = 3x

julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});

julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(15.0, (NoTangent(), 3.0))

As can be seen from the output, the result of differentiating using Mooncake.jl has changed to reflect the overlay-ed definition of the method.

Additionally, it is possible to use the usual multi-line syntax to declare an overlay:

julia> Mooncake.@mooncake_overlay function scale(x)
           return 4x
       end

julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});

julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(20.0, (NoTangent(), 4.0))
source
Mooncake.@zero_adjointMacro
@zero_adjoint ctx sig

Defines is_primitive(ctx, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable.

For example:

julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive

julia> foo(x) = 5
foo (generic function with 1 method)

julia> @zero_adjoint DefaultCtx Tuple{typeof(foo), Any}

julia> is_primitive(DefaultCtx, Tuple{typeof(foo), Any})
true

julia> rrule!!(zero_fcodual(foo), zero_fcodual(3.0))[2](NoRData())
(NoRData(), 0.0)

Limited support for Varargs is also available. For example

julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive

julia> foo_varargs(x...) = 5
foo_varargs (generic function with 1 method)

julia> @zero_adjoint DefaultCtx Tuple{typeof(foo_varargs), Vararg}

julia> is_primitive(DefaultCtx, Tuple{typeof(foo_varargs), Any, Float64, Int})
true

julia> rrule!!(zero_fcodual(foo_varargs), zero_fcodual(3.0), zero_fcodual(5))[2](NoRData())
(NoRData(), 0.0, NoRData())

Be aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.

WARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.

As always, you should use TestUtils.test_rule to ensure that you've not made a mistake.

Signatures Unsupported By This Macro

If the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.

source
Mooncake.IntrinsicsWrappersModule
module IntrinsicsWrappers

The purpose of this module is to associate to each function in Core.Intrinsics a regular Julia function.

To understand the rationale for this observe that, unlike regular Julia functions, each Core.IntrinsicFunction in Core.Intrinsics does not have its own type. Rather, they are instances of Core.IntrinsicFunction. To see this, observe that

julia> typeof(Core.Intrinsics.add_float)
Core.IntrinsicFunction

julia> typeof(Core.Intrinsics.sub_float)
Core.IntrinsicFunction

While we could simply write a rule for Core.IntrinsicFunction, this would (naively) lead to a large list of conditionals of the form

if f === Core.Intrinsics.add_float
    # return add_float and its pullback
elseif f === Core.Intrinsics.sub_float
    # return sub_float and its pullback
elseif
    ...
end

which has the potential to cause quite substantial type instabilities. (This might not be true anymore – see extended help for more context).

Instead, we map each Core.IntrinsicFunction to one of the regular Julia functions in Mooncake.IntrinsicsWrappers, to which we can dispatch in the usual way.
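
The idea can be sketched in a few lines of plain Julia. The wrapper and rule names below (toy_add_float, toy_sub_float, toy_rule_name) are hypothetical and are not the functions defined in Mooncake.IntrinsicsWrappers. The sketch only illustrates why wrapping each intrinsic in a regular function makes ordinary dispatch possible.

```julia
# All intrinsics share a single type, so dispatch cannot tell them apart...
intrinsics_share_type =
    typeof(Core.Intrinsics.add_float) === typeof(Core.Intrinsics.sub_float)

# ...whereas every regular Julia function has its own type.
regular_functions_differ = typeof(sin) !== typeof(cos)

# Hypothetical wrappers: one regular function per intrinsic.
toy_add_float(a, b) = Core.Intrinsics.add_float(a, b)
toy_sub_float(a, b) = Core.Intrinsics.sub_float(a, b)

# Methods can now be selected by dispatching on the type of each wrapper.
toy_rule_name(::typeof(toy_add_float)) = "rule for add_float"
toy_rule_name(::typeof(toy_sub_float)) = "rule for sub_float"

added = toy_add_float(1.0, 2.0)
subtracted = toy_sub_float(1.0, 2.0)
```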

Extended Help

It is possible that owing to improvements in constant propagation in the Julia compiler in version 1.10, we actually could get away with just writing a single method of rrule!! to handle all intrinsics, so this dispatch-based mechanism might be unnecessary. Someone should investigate this. Discussed at https://github.com/compintell/Mooncake.jl/issues/387 .

source