Internal Docstrings

Docstrings listed here are not part of the public Mooncake.jl interface. Consequently, they can change in non-breaking releases of Mooncake.jl without warning.

The purpose of this page is to make these docstrings easy for developers to find via the docs, rather than having to ctrl+f through Mooncake.jl's source code or look them up in the Julia REPL.

Mooncake.TerminatorType
Terminator = Union{Switch, IDGotoIfNot, IDGotoNode, ReturnNode}

A Union of the possible types of a terminator node.

source
Core.Compiler.IRCodeMethod
IRCode(bb_code::BBCode)

Produce an IRCode instance which is equivalent to bb_code. The resulting IRCode shares no memory with bb_code, so can be safely mutated without modifying bb_code.

All IDPhiNodes, IDGotoIfNots, and IDGotoNodes are converted into PhiNodes, GotoIfNots, and GotoNodes respectively.

Any Switch nodes in bb_code are lowered into a semantically-equivalent collection of GotoIfNot nodes in the resulting IRCode.

source
Mooncake.ADInfoType
ADInfo

This data structure is used to hold "global" information associated to a particular call to build_rrule. It is used as a means of communication between make_ad_stmts! and the codegen which produces the forwards- and reverse-passes.

  • interp: a MooncakeInterpreter.
  • block_stack_id: the ID associated to the block stack – the stack which keeps track of which blocks we visited during the forwards-pass, and which is used on the reverse-pass to determine which blocks to visit.
  • block_stack: the block stack. Can always be found at block_stack_id in the forwards- and reverse-passes.
  • entry_id: ID associated to the block inserted at the start of execution in the forwards-pass, and the end of execution in the pullback.
  • shared_data_pairs: the SharedDataPairs used to define the captured variables passed to both the forwards- and reverse-passes.
  • arg_types: a map from Argument to its static type.
  • ssa_insts: a map from ID associated to lines to the primal NewInstruction. This contains the line of code, its static / inferred type, and some other details. See Core.Compiler.NewInstruction for a full list of fields.
  • arg_rdata_ref_ids: the dict mapping from arguments to the ID which creates and initialises the Ref which contains the reverse data associated to that argument. Recall that the heap allocations associated to this Ref are always optimised away in the final programme.
  • ssa_rdata_ref_ids: the same as arg_rdata_ref_ids, but for each ID associated to an ssa rather than each argument.
  • debug_mode: if true, run in "debug mode" – wraps all rule calls in DebugRRule. This is applied recursively, so that debug mode is also switched on in derived rules.
  • is_used_dict: for each ID associated to a line of code, is false if the line is not used anywhere in any other line of code.
  • lazy_zero_rdata_ref_id: for any arguments whose type doesn't permit the construction of a zero-valued rdata directly from the type alone (e.g. a struct with an abstractly-typed field), we need to have a zero-valued rdata available on the reverse-pass so that this zero-valued rdata can be returned if the argument (or a part of it) is never used during the forwards-pass and consequently doesn't obtain a value on the reverse-pass. To achieve this, we construct a LazyZeroRData for each of the arguments on the forwards-pass, and make use of it on the reverse-pass. This field is the ID that will be associated to this information.
source
Mooncake.ADStmtInfoType
ADStmtInfo

Data structure which contains the result of make_ad_stmts!. Fields are

  • line: the ID associated to the primal line from which this is derived
  • comms_id: an ID from one of the lines in fwds, whose value will be made available on the reverse-pass in the same ID. Nothing is asserted about how this value is made available on the reverse-pass of AD, so this package is free to do this in whichever way is most efficient, in particular to group these communication IDs on a per-block basis.
  • fwds: the instructions which run the forwards-pass of AD
  • rvs: the instructions which run the reverse-pass of AD / the pullback
source
Mooncake.BBCodeMethod
BBCode(ir::IRCode)

Convert an ir into a BBCode. Creates a completely independent data structure, so mutating the BBCode returned will not mutate ir.

All PhiNodes, GotoIfNots, and GotoNodes will be replaced with the IDPhiNodes, IDGotoIfNots, and IDGotoNodes respectively.

See IRCode for conversion back to IRCode.

Note that IRCode(BBCode(ir)) should be equivalent to ir, i.e. the composition of the two conversions should behave like the identity function.

source
Mooncake.BBCodeMethod
BBCode(ir::Union{IRCode, BBCode}, new_blocks::Vector{Block})

Make a new BBCode whose blocks are given by new_blocks, with fresh copies made of all other fields from ir.

source
Mooncake.BBCodeType
BBCode(
    blocks::Vector{BBlock}
    argtypes::Vector{Any}
    sptypes::Vector{CC.VarState}
    linetable::Vector{Core.LineInfoNode}
    meta::Vector{Expr}
)

A BBCode is a data structure which is similar to IRCode, but adds additional structure.

In particular, a BBCode comprises a sequence of basic blocks (BBlocks), each of which comprises a sequence of statements. Moreover, each BBlock has its own unique ID, as does each statement.

The consequence of this is that new basic blocks can be inserted into a BBCode. This is distinct from IRCode, in which to create a new basic block, one must insert additional statements which you know will create a new basic block – this is generally quite an unreliable process, while inserting a new BBlock into BBCode is entirely predictable. Furthermore, inserting a new BBlock does not change the ID associated to the other blocks, meaning that you can safely assume that references from existing basic block terminators / phi nodes to other blocks will not be modified by inserting a new basic block.

Additionally, since each statement in each basic block has its own unique ID, new statements can be inserted without changing references between other blocks. IRCode also has some support for this via its new_nodes field, but eventually all statements will be renamed upon compact!ing the IRCode, meaning that the name of any given statement will eventually change.

Finally, note that the basic blocks in a BBCode support the custom Switch statement. This statement is not valid in IRCode, and is therefore lowered into a collection of GotoIfNots and GotoNodes when a BBCode is converted back into an IRCode.

source
Mooncake.BBlockMethod
BBlock(id::ID, inst_pairs::Vector{IDInstPair})

Convenience constructor – splits inst_pairs into a Vector{ID} and InstVector in order to build a BBlock.

source
Mooncake.BBlockType
BBlock(id::ID, stmt_ids::Vector{ID}, stmts::InstVector)

A basic block data structure (not called BasicBlock to avoid accidental confusion with CC.BasicBlock). Forms a single basic block.

Each BBlock has an ID (a unique name). This makes it possible to refer to blocks in a way that does not change when additional BBlocks are inserted into a BBCode. This differs from the positional block numbering found in IRCode, in which the number associated to a basic block changes when new blocks are inserted.

The nth line of code in a BBlock is associated to ID stmt_ids[n], and the nth instruction from stmts.

Note that PhiNodes, GotoIfNots, and GotoNodes should not appear in a BBlock – instead an IDPhiNode, IDGotoIfNot, or IDGotoNode should be used.

source
Mooncake.BlockStackType

The block stack is the stack used to keep track of which basic blocks are visited on the forwards pass, and therefore which blocks need to be visited on the reverse pass. There is one block stack per derived rule. By using Int32, we assume that there aren't more than typemax(Int32) unique basic blocks in a given function, which ought to be reasonable.
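The role of the block stack can be illustrated with a small toy model. This is only a sketch: a plain Vector{Int32} stands in for the real block stack, and toy_forwards! / toy_reverse! are hypothetical names that do not exist in Mooncake.jl.

```julia
# Toy model: record visited block numbers on the forwards-pass, then pop them
# on the reverse-pass to revisit the same blocks in the opposite order.
# A Vector{Int32} is an assumed stand-in for the real block stack.
function toy_forwards!(stack::Vector{Int32}, n_iters::Int)
    push!(stack, Int32(1))            # entry block
    for _ in 1:n_iters
        push!(stack, Int32(2))        # loop body block, visited n_iters times
    end
    push!(stack, Int32(3))            # exit block
    return nothing
end

function toy_reverse!(stack::Vector{Int32})
    visited = Int32[]
    while !isempty(stack)
        push!(visited, pop!(stack))   # most recently visited block comes first
    end
    return visited
end

block_stack = Int32[]
toy_forwards!(block_stack, 2)
reverse_order = toy_reverse!(block_stack)
```

In this sketch the reverse-pass visits the exit block first, then the loop body twice, then the entry block.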

source
Mooncake.CannotProduceZeroRDataFromTypeType
CannotProduceZeroRDataFromType()

Returned by zero_rdata_from_type if it is not possible to construct the zero rdata element for a given type. See zero_rdata_from_type for more info.

source
Mooncake.ConfigType
Config(; debug_mode=false, silence_debug_messages=false)

Configuration struct for use with ADTypes.AutoMooncake.

source
Mooncake.DebugPullbackMethod
(pb::DebugPullback)(dy)

Apply type checking to enforce pre- and post-conditions on pb.pb. See the docstring for DebugPullback for details.

source
Mooncake.DebugPullbackType
DebugPullback(pb, y, x)

Construct a callable which is equivalent to pb, but which enforces type-based pre- and post-conditions on pb. Let dx = pb.pb(dy), for some rdata dy, then this function

  • checks that dy has the correct rdata type for y, and
  • checks that each element of dx has the correct rdata type for x.

Reverse-pass counterpart to DebugRRule.

source
Mooncake.DebugRRuleMethod
(rule::DebugRRule)(x::CoDual...)

Apply type checking to enforce pre- and post-conditions on rule.rule. See the docstring for DebugRRule for details.

source
Mooncake.DebugRRuleType
DebugRRule(rule)

Construct a callable which is equivalent to rule, but inserts additional type checking. In particular:

  • check that the fdata in each argument is of the correct type for the primal
  • check that the fdata in the CoDual returned from the rule is of the correct type for the primal.

This happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.

Some additional dynamic checks are also performed (e.g. that an fdata array is of the same size as its primal).

Let rule return y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.

Note: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.

source
Mooncake.DefaultCtxType
struct DefaultCtx end

Context for all usually used AD primitives. Anything which is a primitive in a MinimalCtx is a primitive in the DefaultCtx automatically. If you are adding a rule for the sake of performance, it should be a primitive in the DefaultCtx, but not the MinimalCtx.

source
Mooncake.DynamicDerivedRuleType
DynamicDerivedRule(interp::MooncakeInterpreter, debug_mode::Bool)

For internal use only.

A callable data structure which, when invoked, calls an rrule specific to the dynamic types of its arguments. Stores rules in an internal cache to avoid re-deriving.

This is used to implement dynamic dispatch.
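The caching idea described above can be sketched in plain Julia. Everything here is hypothetical: ToyDynamicRule and toy_derive are illustrative names, and the "rule" is just a closure rather than a real derived rule.

```julia
# Hypothetical stand-in for a rule cache keyed on the runtime types of the arguments.
struct ToyDynamicRule
    cache::Dict{Any,Any}
    n_derivations::Base.RefValue{Int}   # counts how often a "rule" had to be derived
end
ToyDynamicRule() = ToyDynamicRule(Dict{Any,Any}(), Ref(0))

# Pretend "derivation": returns a closure associated to the signature `sig`.
toy_derive(sig::Type{<:Tuple}) = (args...) -> (sig, sum(args))

function (r::ToyDynamicRule)(args...)
    sig = Tuple{map(typeof, args)...}   # dynamic types of the arguments
    rule = get!(r.cache, sig) do
        r.n_derivations[] += 1          # only runs on a cache miss
        toy_derive(sig)
    end
    return rule(args...)
end

rule = ToyDynamicRule()
rule(1.0, 2.0)   # cache miss: derives for Tuple{Float64, Float64}
rule(3.0, 4.0)   # cache hit: same signature
rule(1, 2)       # cache miss: derives for Tuple{Int, Int}
```

Keying the cache on the argument types means that the cost of derivation is paid once per distinct signature, which is the property the docstring describes.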

source
Mooncake.FDataType
FData(data::NamedTuple)

The component of a struct which is propagated alongside the primal on the forwards-pass of AD. For example, the tangents for Float64s do not need to be propagated on the forwards-pass of reverse-mode AD, so any Float64 fields of Tangent do not need to appear in the associated FData.

source
Mooncake.IDType
ID()

An ID (read: unique name) is just a wrapper around an Int32. Uniqueness is ensured via a global counter, which is incremented each time that an ID is created.

This counter can be reset using seed_id! if you need to ensure deterministic IDs are produced, in the same way that the seed for random number generators can be set.
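A minimal re-creation of this idea, assuming nothing about how Mooncake.jl actually implements it, might look as follows. ToyID, TOY_ID_COUNTER, and toy_seed_id! are hypothetical names used only for illustration.

```julia
# Hypothetical re-creation of the idea; not Mooncake's actual implementation.
const TOY_ID_COUNTER = Ref(Int32(0))

struct ToyID
    id::Int32
    function ToyID()
        TOY_ID_COUNTER[] += Int32(1)   # each construction consumes a fresh number
        return new(TOY_ID_COUNTER[])
    end
end

# Reset the counter so that subsequently-created IDs are deterministic.
toy_seed_id!(x::Integer) = (TOY_ID_COUNTER[] = Int32(x); nothing)

toy_seed_id!(0)
a, b = ToyID(), ToyID()   # distinct IDs
toy_seed_id!(0)
c = ToyID()               # same number as `a`, because the counter was reset
```

As with random seeds, resetting the counter trades uniqueness across the reset for reproducibility, so it is best reserved for tests and debugging.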

source
Mooncake.IDPhiNodeType
IDPhiNode(edges::Vector{ID}, values::Vector{Any})

Like a PhiNode, but edges are IDs rather than Int32s.

source
Mooncake.InstVectorType
const InstVector = Vector{NewInstruction}

Note: the CC.NewInstruction type is used to represent instructions because it has the correct fields. While it is only used to represent new instructions in Core.Compiler, it is used to represent all instructions in BBCode.

source
Mooncake.LazyDerivedRuleType
LazyDerivedRule(interp, mi::Core.MethodInstance, debug_mode::Bool)

For internal use only.

A type-stable wrapper around a DerivedRule, which only instantiates the DerivedRule when it is first called. This is useful, as it means that if a rule does not get run, it does not have to be derived.

If debug_mode is true, then the rule constructed will be a DebugRRule. This is useful when debugging, but should usually be switched off for production code as it (in general) incurs some runtime overhead.

Note: the signature of the primal for which this is a rule is stored in the type. The only reason to keep this around is for debugging – it is very helpful to have this type visible in the stack trace when something goes wrong, as it allows you to trivially determine which bit of your code is the culprit.

source
Mooncake.LazyZeroRDataType
LazyZeroRData{P, Tdata}()

This type is a lazy placeholder for zero_like_rdata_from_type. This is used to defer construction of zero data to the reverse pass. Calling instantiate on an instance of this will construct a zero data.

Users should construct using LazyZeroRData(p), where p is a value of type P. This constructor, and instantiate, are specialised to minimise the amount of data which must be stored. For example, Float64s do not need any data, so LazyZeroRData(0.0) produces an instance of a singleton type, meaning that various important optimisations can be performed in AD.

source
Mooncake.MinimalCtxType
struct MinimalCtx end

Functions should only be primitives in this context if not making them so would cause AD to fail. In particular, do not add primitives to this context if you are writing them for performance only – instead, make these primitives in the DefaultCtx.

source
Mooncake.NoPullbackMethod
NoPullback(args::CoDual...)

Construct a NoPullback from the arguments passed to an rrule!!. For each argument, extracts the primal value, and constructs a LazyZeroRData. These are stored in a NoPullback which, in the reverse-pass of AD, instantiates these LazyZeroRDatas and returns them in order to perform the reverse-pass of AD.

The advantage of this approach is that if it is possible to construct the zero rdata element for each of the arguments lazily, the NoPullback generated will be a singleton type. This means that AD can avoid generating a stack to store this pullback, which can result in significant performance improvements.

source
Mooncake.RRuleZeroWrapperType
RRuleZeroWrapper(rule)

This struct is used to ensure that ZeroRDatas, which are used as placeholder zero elements whenever an actual instance of a zero rdata for a particular primal type cannot be constructed without also having an instance of said type, never reach rules. On the pullback, we increment the cotangent dy by an amount equal to zero. This ensures that if it is a ZeroRData, we instead get an actual zero of the correct type. If it is not a zero rdata, the computation should be elided via inlining + constant prop.

source
Mooncake.SharedDataPairsType
SharedDataPairs()

A data structure used to manage the captured data in the OpaqueClosures which implement the bulk of the forwards- and reverse-passes of AD. An entry (id, data) at element n of the pairs field of this data structure means that data will be available at register id during the forwards- and reverse-passes of AD.

This is achieved by storing all of the data in the pairs field in the captured tuple which is passed to an OpaqueClosure, and extracting this data into registers associated to the corresponding IDs.

source
Mooncake.StackType
Stack{T}()

A stack specialised for reverse-mode AD.

Semantically equivalent to a usual stack, but never de-allocates memory once allocated.
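The "never de-allocates" behaviour can be sketched with a toy type. This is an assumption-laden illustration: ToyStack and its fields are hypothetical, and Mooncake.jl's own Stack may be implemented differently.

```julia
# Hypothetical toy; Mooncake's own Stack may be implemented differently.
mutable struct ToyStack{T}
    memory::Vector{T}
    position::Int
    ToyStack{T}() where {T} = new{T}(Vector{T}(undef, 0), 0)
end

function Base.push!(s::ToyStack, x)
    s.position += 1
    if s.position <= length(s.memory)
        s.memory[s.position] = x     # re-use a previously allocated slot
    else
        push!(s.memory, x)           # grow only when necessary
    end
    return s
end

function Base.pop!(s::ToyStack)
    x = s.memory[s.position]
    s.position -= 1                  # the underlying memory is not shrunk
    return x
end

Base.isempty(s::ToyStack) = s.position == 0

s = ToyStack{Float64}()
push!(s, 1.0)
push!(s, 2.0)
top = pop!(s)
```

Because pop! only moves a cursor, repeated forwards- and reverse-passes over the same code can re-use the memory allocated on the first run.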

source
Mooncake.SwitchType
Switch(conds::Vector{Any}, dests::Vector{ID}, fallthrough_dest::ID)

A switch-statement node. These can be inserted in the BBCode representation of Julia IR. Switch has the following semantics:

goto dests[1] if not conds[1]
goto dests[2] if not conds[2]
...
goto dests[N] if not conds[N]
goto fallthrough_dest

where the value associated to each element of conds is a Bool, and dests indicates which block to jump to. If every condition is true, so that none of the jumps above is taken, then we go to whichever block is specified by fallthrough_dest.

Switch statements are lowered into the above sequence of GotoIfNots and GotoNodes when converting BBCode back into IRCode, because Switch statements are not valid nodes in regular Julia IR.
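The semantics listed above can be mimicked by a small function which picks the destination. This is a sketch only: toy_switch_dest is a hypothetical helper, the conditions are assumed to be already-evaluated Bools, and Symbols stand in for block IDs.

```julia
# Hypothetical helper: returns the destination that a Switch-like node would jump to.
# `conds` holds already-evaluated Bools; Symbols are stand-ins for block IDs.
function toy_switch_dest(conds::Vector{Bool}, dests::Vector{Symbol}, fallthrough_dest::Symbol)
    length(conds) == length(dests) ||
        throw(ArgumentError("conds and dests must have equal length"))
    for (cond, dest) in zip(conds, dests)
        cond || return dest          # "goto dest if not cond"
    end
    return fallthrough_dest          # reached only if every condition was true
end

first_false = toy_switch_dest([true, false, false], [:a, :b, :c], :d)
all_true = toy_switch_dest([true, true], [:a, :b], :d)
```

The first call jumps to :b, the destination paired with the first false condition, while the second falls through to :d.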

source
Mooncake.ZeroRDataType
ZeroRData()

Singleton type indicating zero-valued rdata. This should only ever appear as an intermediate quantity in the reverse-pass of AD when the type of the primal is not fully inferable, or a field of a type is abstractly typed.

If you see this anywhere in actual code, or if it appears in a hand-written rule, this is an error – please open an issue in such a situation.

source
Base.insert!Method
Base.insert!(bb::BBlock, n::Int, id::ID, stmt::CC.NewInstruction)::Nothing

Inserts stmt and id into bb immediately before the nth instruction.

source
Mooncake.__flatten_varargsMethod
__flatten_varargs(isva::Bool, args, ::Val{nvargs}) where {nvargs}

If isva, inputs (5.0, (4.0, 3.0)) are transformed into (5.0, 4.0, 3.0).
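The transformation can be sketched for plain tuples as follows. The function toy_flatten_varargs is hypothetical and omits the Val{nvargs} argument of the real method for simplicity.

```julia
# Hypothetical stand-in, written without the Val{nvargs} argument for simplicity.
# If `isva`, the final element of `args` is assumed to be a tuple and is splatted.
function toy_flatten_varargs(isva::Bool, args::Tuple)
    return isva ? (args[1:(end - 1)]..., args[end]...) : args
end

flattened = toy_flatten_varargs(true, (5.0, (4.0, 3.0)))
untouched = toy_flatten_varargs(false, (5.0, (4.0, 3.0)))
```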

source
Mooncake.__insts_to_instruction_streamMethod
__insts_to_instruction_stream(insts::Vector{Any})

Produces an instruction stream whose

  • stmt (v1.11 and up) / inst (v1.10) field is insts,
  • type field is all Any,
  • info field is all Core.Compiler.NoCallInfo,
  • line field is all Int32(1), and
  • flag field is all Core.Compiler.IR_FLAG_REFINED.

As such, if you wish to ensure that your IRCode prints nicely, you should ensure that its linetable field has at least one element.

source
Mooncake.__line_numbers_to_block_numbers!Method
__line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG)

Converts any edges in GotoNodes, GotoIfNots, PhiNodes, and :enter expressions which refer to line numbers into references to block numbers. The cfg provides the information required to perform this conversion.

For context, CodeInfo objects have references to line numbers, while IRCode uses block numbers.

This code is copied over directly from the body of Core.Compiler.inflate_ir!.

source
Mooncake.__pop_blk_stack!Method
__pop_blk_stack!(block_stack::BlockStack)

Equivalent to pop!(block_stack). Going via this function, rather than just calling pop! directly, makes it easy to figure out how much time is spent popping the block stack when profiling performance, and to know that this function was hit when debugging.

source
Mooncake.__push_blk_stack!Method
__push_blk_stack!(block_stack::BlockStack, id::Int32)

Equivalent to push!(block_stack, id). Going via this function, rather than just calling push! directly, is helpful for debugging and performance analysis – it makes it very straightforward to figure out how much time is spent pushing to the block stack when profiling.

source
Mooncake.__run_rvs_pass!Method
__run_rvs_pass!(
    P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs...
) where {sig}

Used in make_ad_stmts! method for Expr(:call, ...) and Expr(:invoke, ...).

source
Mooncake.__unflatten_codual_varargsMethod
__unflatten_codual_varargs(isva::Bool, args, ::Val{nargs}) where {nargs}

If isva and nargs=2, then inputs (CoDual(5.0, 0.0), CoDual(4.0, 0.0), CoDual(3.0, 0.0)) are transformed into (CoDual(5.0, 0.0), CoDual((4.0, 3.0), (0.0, 0.0))).

source
Mooncake.__value_and_gradient!!Method
__value_and_gradient!!(rule, f::CoDual, x::CoDual...)

Note: this is not part of the public Mooncake.jl interface, and may change without warning.

Equivalent to __value_and_pullback!!(rule, 1.0, f, x...) – assumes f returns a Float64.

# Set up the problem.
f(x, y) = sum(x .* y)
x = [2.0, 2.0]
y = [1.0, 1.0]
rule = build_rrule(f, x, y)

# Allocate tangents. These will be written to in-place. You are free to re-use these if you
# compute gradients multiple times.
tf = zero_tangent(f)
tx = zero_tangent(x)
ty = zero_tangent(y)

# Do AD.
Mooncake.__value_and_gradient!!(
    rule, Mooncake.CoDual(f, tf), Mooncake.CoDual(x, tx), Mooncake.CoDual(y, ty)
)
# output

(4.0, (NoTangent(), [1.0, 1.0], [2.0, 2.0]))
source
Mooncake.__value_and_pullback!!Method
__value_and_pullback!!(rule, ȳ, f::CoDual, x::CoDual...; y_cache=nothing)

Note: this is not part of the public Mooncake.jl interface, and may change without warning.

In-place version of value_and_pullback!! in which the arguments have been wrapped in CoDuals. Note that any mutable data in f and x will be incremented in-place. As such, if calling this function multiple times with different values of x, you should be careful to ensure that you zero-out the tangent fields of x each time.

source
Mooncake._block_nums_to_idsMethod
_block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID}, InstVector}

Assign to each basic block in cfg an ID. Replace all integers referencing block numbers in insts with the corresponding ID. Return the IDs and the updated instructions.

source
Mooncake._build_graph_of_cfgMethod
_build_graph_of_cfg(blks::Vector{BBlock})::Tuple{SimpleDiGraph, Dict{ID, Int}}

Builds a SimpleDiGraph, g, representing the CFG associated to blks, where blks comprises the collection of basic blocks associated to a BBCode. SimpleDiGraph is a type from Graphs.jl, so constructing g makes it straightforward to analyse the control flow structure of the BBCode using algorithms from Graphs.jl.

Returns a 2-tuple, whose first element is g, and whose second element is a map from the ID associated to each basic block in blks, to the Int corresponding to its node index in g.

source
Mooncake._compute_all_predecessorsMethod
_compute_all_predecessors(blks::Vector{BBlock})::Dict{ID, Vector{ID}}

Internal method implementing compute_all_predecessors. This method is easier to construct test cases for because it only requires the collection of BBlocks, not all of the other stuff that goes into a BBCode.

source
Mooncake._compute_all_successorsMethod
_compute_all_successors(blks::Vector{BBlock})::Dict{ID, Vector{ID}}

Internal method implementing compute_all_successors. This method is easier to construct test cases for because it only requires the collection of BBlocks, not all of the other stuff that goes into a BBCode.

source
Mooncake._control_flow_graphMethod
_control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG

Internal function, used to implement control_flow_graph. Easier to write test cases for because there is no need to construct an entire BBCode object, just the BBlocks.

source
Mooncake._distance_to_entryMethod
_distance_to_entry(blks::Vector{BBlock})::Vector{Int}

For each basic block in blks, compute the distance from it to the entry point (the first block). The distance is typemax(Int) if there is no path from the entry point to a given node.
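One way to compute such distances is a breadth-first search from the first block. The sketch below is not Mooncake.jl's implementation: toy_distance_to_entry is a hypothetical name, blocks are represented by an adjacency list of successor indices, and the entry block is assumed to be at distance 0.

```julia
# Hypothetical sketch: breadth-first search over an adjacency list.
# successors[n] lists the indices of the blocks that block n can jump to.
# Assumption: the entry block (index 1) is at distance 0.
function toy_distance_to_entry(successors::Vector{Vector{Int}})
    dist = fill(typemax(Int), length(successors))
    isempty(successors) && return dist
    dist[1] = 0
    queue = [1]
    while !isempty(queue)
        b = popfirst!(queue)
        for s in successors[b]
            if dist[s] == typemax(Int)   # not yet visited
                dist[s] = dist[b] + 1
                push!(queue, s)
            end
        end
    end
    return dist
end

# Block 5 jumps to block 4, but nothing jumps to block 5, so it is unreachable.
successors = [[2, 3], [4], [4], Int[], [4]]
dist = toy_distance_to_entry(successors)
reachable = dist .!= typemax(Int)
```

In this sketch, a block is reachable exactly when its distance is finite, which is how the last line derives a reachability mask from the distances.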

source
Mooncake._find_id_uses!Method
_find_id_uses!(d::Dict{ID, Bool}, x)

Helper function used in characterise_used_ids. For all uses of IDs in x, set the corresponding value of d to true.

For example, if x = ReturnNode(ID(5)), then this function sets d[ID(5)] = true.

source
Mooncake._foreigncall_Method
function _foreigncall_(
    ::Val{name}, ::Val{RT}, AT::Tuple, ::Val{nreq}, ::Val{calling_convention}, x...
) where {name, RT, nreq, calling_convention}

:foreigncall nodes get translated into calls to this function. For example,

Expr(:foreigncall, :foo, Tout, (A, B), nreq, :ccall, args...)

becomes

_foreigncall_(Val(:foo), Val(Tout), (Val(A), Val(B)), Val(nreq), Val(:ccall), args...)

Please consult the Julia documentation for more information on how foreigncall nodes work, and consult this package's tests for examples.

Credit: Umlaut.jl has the original implementation of this function. This is largely copied over from there.

source
Mooncake._ids_to_line_numbersMethod
_ids_to_line_numbers(bb_code::BBCode)::InstVector

For each statement in bb_code, returns a NewInstruction in which every ID is replaced by either an SSAValue, or an Int64 / Int32 which refers to an SSAValue.

source
Mooncake._is_reachableMethod
_is_reachable(blks::Vector{BBlock})::Vector{Bool}

Computes a Vector whose length is length(blks). The nth element is true iff it is possible for control flow to reach the nth block.

source
Mooncake._lines_to_blocksMethod
_lines_to_blocks(insts::InstVector, cfg::CC.CFG)::InstVector

Pulls out the instructions from insts, and calls __line_numbers_to_block_numbers!.

source
Mooncake._lower_switch_statementsMethod
_lower_switch_statements(bb_code::BBCode)

Converts all Switch nodes into a semantically-equivalent collection of GotoIfNots. See the Switch docstring for an explanation of what is going on here.

source
Mooncake._mapMethod
_map(f, x...)

Same as map but requires all elements of x to have equal length. The usual function map doesn't enforce this for Arrays.
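A length-checked map along these lines can be sketched as below. The name toy_map and the choice of ArgumentError are assumptions made for illustration, and at least one collection must be passed.

```julia
# Hypothetical sketch of a length-checked map; the error type is an assumption.
# Requires at least one collection in `x`.
function toy_map(f, x...)
    n = length(first(x))
    all(xi -> length(xi) == n, x) ||
        throw(ArgumentError("all arguments must have equal length"))
    return map(f, x...)
end

summed = toy_map(+, [1, 2], [3, 4])
```

Calling toy_map(+, [1, 2], [3]) throws in this sketch rather than silently producing a result from mismatched inputs.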

source
Mooncake._map_if_assigned!Method
_map_if_assigned!(f::F, y::DenseArray, x1::DenseArray{P}, x2::DenseArray)

Similar to the other method of _map_if_assigned! – for all n, if x1[n] is assigned, writes f(x1[n], x2[n]) to y[n], otherwise leaves y[n] unchanged.

Requires that y, x1, and x2 have the same size.

source
Mooncake._map_if_assigned!Method
_map_if_assigned!(f, y::DenseArray, x::DenseArray{P}) where {P}

For all n, if x[n] is assigned, then writes the value returned by f(x[n]) to y[n], otherwise leaves y[n] unchanged.

Equivalent to map!(f, y, x) if P is a bits type, as elements will always be assigned.

Requires that y and x have the same size.
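The behaviour described above can be sketched using Base's isassigned. The function toy_map_if_assigned! is a hypothetical illustration rather than the package's implementation.

```julia
# Hypothetical sketch: apply f only where x has an assigned element.
function toy_map_if_assigned!(f, y::Array, x::Array)
    size(y) == size(x) || throw(DimensionMismatch("y and x must have the same size"))
    for n in eachindex(x)
        isassigned(x, n) && (y[n] = f(x[n]))   # skip #undef elements of x
    end
    return y
end

# An array of non-bits elements starts out with every element unassigned.
x = Vector{Vector{Float64}}(undef, 3)
x[1] = [1.0]
x[3] = [3.0]
y = Any[:untouched, :untouched, :untouched]
toy_map_if_assigned!(first, y, x)
```

Only y[2] keeps its original value, because x[2] was never assigned.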

source
Mooncake._new_Method
_new_(::Type{T}, x::Vararg{Any, N}) where {T, N}

One-liner which calls the :new instruction with type T with arguments x.

source
Mooncake._remove_double_edgesMethod
_remove_double_edges(ir::BBCode)::BBCode

If the dest field of an IDGotoIfNot node in block n of ir points towards the n+1th block then we have two edges from block n to block n+1. This transformation replaces all such IDGotoIfNot nodes with unconditional IDGotoNodes pointing towards the n+1th block in ir.
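The transformation can be illustrated on a toy representation in which blocks are identified by their position. ToyGotoIfNot, ToyGoto, and toy_remove_double_edges are hypothetical names, and integer positions stand in for the IDs used by the real node types.

```julia
# Hypothetical toy terminators; integer block positions stand in for IDs.
struct ToyGotoIfNot
    cond::Symbol
    dest::Int
end

struct ToyGoto
    dest::Int
end

# terminators[n] is the terminator of block n (or `nothing` if it has none).
function toy_remove_double_edges(terminators::Vector{Any})
    out = copy(terminators)
    for (n, t) in enumerate(terminators)
        # A conditional jump to block n + 1 duplicates the fallthrough edge.
        if t isa ToyGotoIfNot && t.dest == n + 1
            out[n] = ToyGoto(n + 1)
        end
    end
    return out
end

terminators = Any[ToyGotoIfNot(:c1, 2), ToyGotoIfNot(:c2, 4), ToyGoto(4), nothing]
cleaned = toy_remove_double_edges(terminators)
```

Only the first terminator is rewritten here: block 1 conditionally jumps to block 2, which is also where it would fall through to.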

source
Mooncake._sort_blocks!Method
_sort_blocks!(ir::BBCode)::BBCode

Ensure that blocks appear in order of distance-from-entry-point, where the distance from block b to the entry point is defined to be the minimum number of basic blocks that must be passed through in order to reach b.

For reasons unknown (to me, Will), the compiler / optimiser needs this for inference to succeed. Since we do quite a lot of re-ordering on the reverse-pass of AD, this is a problem there.

WARNING: use with care. Only use if you are confident that arbitrary re-ordering of basic blocks in ir is valid. Notably, this does not hold if you have any IDGotoIfNot nodes in ir.

source
Mooncake._ssa_to_idsMethod
_ssa_to_ids(d::SSAToIdDict, inst::NewInstruction)

Produce a new instance of inst in which all instances of SSAValues are replaced with the IDs prescribed by d, all basic block numbers are replaced with the IDs prescribed by d, and GotoIfNot, GotoNode, and PhiNode instances are replaced with the corresponding ID versions.

source
Mooncake._ssas_to_idsMethod
_ssas_to_ids(insts::InstVector)::Tuple{Vector{ID}, InstVector}

Assigns an ID to each line in insts, and replaces each instance of an SSAValue in each line with the corresponding ID. For example, a call statement of the form Expr(:call, :f, %4) is replaced with Expr(:call, :f, id_assigned_to_%4).

source
Mooncake._to_ssasMethod
_to_ssas(d::Dict, inst::NewInstruction)

Like _ssas_to_ids, but in reverse. Converts IDs to SSAValues / (integers corresponding to ssas).

source
Mooncake._typeofMethod
_typeof(x)

Central definition of typeof, specialised to the requirements of this package.

source
Mooncake.ad_stmt_infoMethod
ad_stmt_info(line::ID, comms_id::Union{ID, Nothing}, fwds, rvs)

Convenient constructor for ADStmtInfo. If either fwds or rvs is not a vector, __vec promotes it to a single-element Vector.

source
Mooncake.add_data!Method
add_data!(info::ADInfo, data)::ID

Equivalent to add_data!(info.shared_data_pairs, data).

source
Mooncake.add_data!Method
add_data!(p::SharedDataPairs, data)::ID

Puts data into p, and returns the id associated to it. This id should be assumed to be available during the forwards- and reverse-passes of AD, and it should further be assumed that the value associated to this id is always data.

source
Mooncake.add_data_if_not_singleton!Method
add_data_if_not_singleton!(p::Union{ADInfo, SharedDataPairs}, x)

Returns x if it is a singleton, or the ID of the ssa which will contain it on the forwards- and reverse-passes. The reason for this is that if something is a singleton, it can be inserted directly into the IR.
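The singleton check can be sketched with Base.issingletontype. In this illustration a vector of Symbol => data pairs stands in for SharedDataPairs, and toy_add_data_if_not_singleton! is a hypothetical name.

```julia
# Hypothetical sketch: a Vector of `Symbol => data` pairs stands in for SharedDataPairs.
function toy_add_data_if_not_singleton!(pairs::Vector{Pair{Symbol,Any}}, x)
    # Instances of singleton types carry no data, so they can be used directly.
    Base.issingletontype(typeof(x)) && return x
    id = Symbol("id_", length(pairs) + 1)
    push!(pairs, id => x)            # everything else is stored and referred to by name
    return id
end

toy_data = Pair{Symbol,Any}[]
r1 = toy_add_data_if_not_singleton!(toy_data, nothing)   # singleton: returned as-is
r2 = toy_add_data_if_not_singleton!(toy_data, sin)       # functions without captures are singletons
r3 = toy_add_data_if_not_singleton!(toy_data, [1.0])     # not a singleton: stored
```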

source
Mooncake.arrayifyMethod
arrayify(x::CoDual{<:AbstractArray{<:BlasRealFloat}})

Return the primal field of x, and convert its fdata into an array of the same type as the primal. This operation is not guaranteed to be possible for all array types, but seems to be possible for all array types of interest so far.

source
Mooncake.characterise_unique_predecessor_blocksMethod
characterise_unique_predecessor_blocks(blks::Vector{BBlock}) ->
    Tuple{Dict{ID, Bool}, Dict{ID, Bool}}

We call a block b a unique predecessor in the control flow graph associated to blks if it is the only predecessor to all of its successors. Put differently we call b a unique predecessor if, whenever control flow arrives in any of the successors of b, we know for certain that the previous block must have been b.

Returns two Dicts. A value in the first Dict is true if the block associated to its key is a unique predecessor, and is false if not. A value in the second Dict is true if the block associated to its key has a single predecessor, and that predecessor is a unique predecessor.

Context:

This information is important for optimising AD because knowing that b is a unique predecessor means that

  1. on the forwards-pass, there is no need to push the ID of b to the block stack when passing through it, and
  2. on the reverse-pass, there is no need to pop the block stack when passing through one of the successors to b.

Utilising this reduces the overhead associated to doing AD. It is quite important when working with cheap loops – loops where the operations performed at each iteration are inexpensive – for which minimising memory pressure is critical to performance. It is also important for single-block functions, because it can be used to entirely avoid using a block stack at all.
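The definition of a unique predecessor can be made concrete with a toy control flow graph. This is a sketch under stated assumptions: toy_unique_predecessors is a hypothetical function, Symbols stand in for block IDs, and every block (including those without successors) must appear as a key of the successor map.

```julia
# Hypothetical sketch; Symbols stand in for block IDs.
# Every block must appear as a key of `successors`, even if it has no successors.
function toy_unique_predecessors(successors::Dict{Symbol,Vector{Symbol}})
    # Build the predecessor map from the successor map.
    preds = Dict(b => Symbol[] for b in keys(successors))
    for (b, succs) in successors, s in succs
        push!(preds[s], b)
    end
    # b is a unique predecessor if it is the only predecessor of each of its successors.
    # Note: in this sketch, a block without successors satisfies this vacuously.
    is_unique_pred = Dict(b => all(s -> preds[s] == [b], succs) for (b, succs) in successors)
    # True if a block has exactly one predecessor, and that predecessor is unique.
    pred_is_unique_pred = Dict(b => length(ps) == 1 && is_unique_pred[first(ps)] for (b, ps) in preds)
    return is_unique_pred, pred_is_unique_pred
end

# A diamond: a branches to b and c, both of which jump to d.
cfg = Dict(:a => [:b, :c], :b => [:d], :c => [:d], :d => Symbol[])
is_unique_pred, pred_is_unique_pred = toy_unique_predecessors(cfg)
```

In the diamond, a is a unique predecessor, since b and c can only be reached from it, whereas neither b nor c is, because d has two predecessors. Under the optimisation described above, only arrivals at d would need the block stack.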

source
Mooncake.characterise_used_idsMethod
characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID, Bool}

For each line in stmts, determine whether it is referenced anywhere else in the code. Returns a dictionary containing the results. An element is false if the corresponding ID is unused, and true if it is used.

source
Mooncake.collect_stmtsMethod
collect_stmts(ir::BBCode)::Vector{IDInstPair}

Produce a Vector containing all of the statements in ir. These are returned in order, so it is safe to assume that element n refers to the nth element of the IRCode associated to ir.

source
Mooncake.collect_stmtsMethod
collect_stmts(bb::BBlock)::Vector{IDInstPair}

Returns a Vector containing the IDs and instructions associated to each line in bb. These should be assumed to be ordered.

source
Mooncake.comms_channelMethod
comms_channel(info::ADStmtInfo)

Return the element of fwds whose ID is the communication ID. Returns nothing if comms_id is nothing.

source
Mooncake.conclude_rvs_blockMethod
conclude_rvs_block(
    blk::BBlock, pred_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo
)

Generates code which is inserted at the end of each counterpart block in the reverse-pass. Handles phi nodes, and choosing the correct next block to switch to.

source
Mooncake.const_codualMethod
const_codual(stmt, info::ADInfo)

Build a CoDual from stmt, with zero / uninitialised fdata. If the resulting CoDual is a bits type, then it is returned. If it is not, then the CoDual is put into shared data, and the ID associated to it in the forwards- and reverse-passes returned.

source
Mooncake.const_codual_stmtMethod
const_codual_stmt(stmt, info::ADInfo)

Returns a :call expression which will return a CoDual whose primal is stmt, and whose tangent is whatever uninit_tangent returns.

source
Mooncake.create_comms_insts!Method
create_comms_insts!(ad_stmts_blocks::ADStmts, info::ADInfo)

This function produces code which can be inserted into the forwards-pass and reverse-pass at specific locations to implement the promise associated to the comms_id field of the ADStmtInfo type – namely that if you assign a value to comms_id on the forwards-pass, the same value will be available at comms_id on the reverse-pass.

For each basic block represented in ADStmts:

  1. create a stack containing a Tuple which can hold all of the values associated to the comms_ids for each statement. Put this stack in shared data.
  2. create instructions which can be inserted at the end of the block generated to perform the forwards-pass (in forwards_pass_ir) which will put all of the data associated to the comms_ids into shared data, and
  3. create instructions which can be inserted at the start of the block generated to perform the reverse-pass (in pullback_ir), which will extract all of the data put into shared data by the instructions generated by the previous point, and assign them to the comms_ids.

Returns a Tuple{Vector{IDInstPair}, Vector{IDInstPair}}. The nth element of each Vector corresponds to the instructions to be inserted into the forwards- and reverse-passes resp. for the nth block in ad_stmts_blocks.

source
Mooncake.fdata_field_typeMethod
fdata_field_type(::Type{P}, n::Int) where {P}

Returns the type of the nth field of the fdata type associated to P. Will be a PossiblyUninitTangent if said field can be undefined.

source
Mooncake.fix_up_invoke_inference!Method
fix_up_invoke_inference!(ir::IRCode)

The Problem

Consider the following:

@noinline function bar!(x)
    x .*= 2
end

function foo!(x)
    bar!(x)
    return nothing
end

In this case, the IR associated to Tuple{typeof(foo!), Vector{Float64}} will be something along the lines of

julia> Base.code_ircode_by_type(Tuple{typeof(foo!), Vector{Float64}})
1-element Vector{Any}:
2 1 ─     invoke Main.bar!(_2::Vector{Float64})::Any
3 └──     return Main.nothing
   => Nothing

Observe that the type inferred for the first line is Any. Inference is at liberty to do this without any risk of performance problems because the first line is not used anywhere else in the function. Had this line been used elsewhere in the function, inference would have inferred its type to be Vector{Float64}.

This causes performance problems for Mooncake, because it uses the return type to do various things, including allocating storage for quantities required on the reverse-pass. Consequently, inference inferring Any rather than Vector{Float64} causes type instabilities in the code that Mooncake generates, which can have catastrophic consequences for performance.

The Solution

:invoke expressions contain the Core.MethodInstance associated to them, which contains a Core.CodeCache, which contains the return type of the :invoke. This function looks for :invoke statements whose return type is inferred to be Any in ir, and modifies it to be the return type given by the code cache.

source
Mooncake.foreigncall_to_callMethod
foreigncall_to_call(inst, sp_map::Dict{Symbol, CC.VarState})

If inst is a :foreigncall expression translate it into an equivalent :call expression. If anything else, just return inst. See Mooncake._foreigncall_ for details.

sp_map maps the names of the static parameters to their values. This function is intended to be called in the context of an IRCode, in which case the values of sp_map are given by the sptypes field of said IRCode. The keys should generally be obtained from the Method from which the IRCode is derived. See Mooncake.normalise! for more details.

The purpose of this transformation is to make it possible to differentiate :foreigncall expressions in the same way as a primitive :call expression, i.e. via an rrule!!.

source
Mooncake.forwards_pass_irMethod
forwards_pass_ir(ir::BBCode, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)

Produce the IR associated to the OpaqueClosure which runs most of the forwards-pass.

source
Mooncake.fwd_irMethod
fwd_ir(
    sig::Type{<:Tuple};
    interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true
)::IRCode

!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Generate the Core.Compiler.IRCode used to construct the forwards-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.

For example, if you wanted to get the IR associated to the forwards pass for the call map(sin, randn(10)), you could do either of the following:

julia> Mooncake.fwd_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
true
julia> Mooncake.fwd_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
true

Arguments

  • sig::Type{<:Tuple}: the signature of the call to be differentiated.

Keyword Arguments

  • interp: the interpreter to use to obtain the primal IR.
  • debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
  • do_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
source
Mooncake.gc_preserveMethod
gc_preserve(xs...)

A no-op function. Its rrule!! ensures that the memory associated to xs is not freed until the pullback that it returns is run.

source
Mooncake.generate_irMethod
generate_ir(
    interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true
)

Used by build_rrule, and the various debugging tools: primal_ir, fwds_ir, adjoint_ir.

source
Mooncake.get_rev_data_idMethod
get_rev_data_id(info::ADInfo, x)

Returns the ID associated to the line in the reverse pass which will contain the reverse data for x. If x is not an Argument or ID, then nothing is returned.

source
Mooncake.get_tangent_fieldMethod
get_tangent_field(t::Union{MutableTangent, Tangent}, i::Int)

Gets the ith field of data in t.

Has the same semantics that getfield would have if the data in the fields field of t were actually fields of t. This is the moral equivalent of getfield for MutableTangent.

source
Mooncake.id_to_line_mapMethod
id_to_line_map(ir::BBCode)

Produces a Dict mapping from each ID associated with a line in ir to its line number. This is isomorphic to mapping to its SSAValue in IRCode. Terminators do not have IDs associated to them, so not every line in the original IRCode is mapped to.

source
Mooncake.increment_and_get_rdata!Method
increment_and_get_rdata!(fdata, zero_rdata, cr_tangent)

Increment fdata by the fdata component of the ChainRules.jl-style tangent, cr_tangent, and return the rdata component of cr_tangent by adding it to zero_rdata.

source
Mooncake.increment_rdata!!Method
increment_rdata!!(t::T, r)::T where {T}

Increment the rdata component of tangent t by r, and return the updated tangent. Useful for implementing getfield-like rules for mutable structs, pointers, dicts, etc.

source
Mooncake.infer_ir!Method
infer_ir!(ir::IRCode) -> IRCode

Runs type inference on ir, which mutates ir, and returns it.

Note: the compiler will not infer the types of anything where the corresponding element of ir.stmts.flag is not set to Core.Compiler.IR_FLAG_REFINED. Nor will it attempt to refine the type of the value returned by an :invoke expression. Consequently, if you find that the types in your IR are not being refined, you may wish to check that neither of these things are happening.

source
Mooncake.insert_before_terminator!Method
insert_before_terminator!(bb::BBlock, id::ID, inst::NewInstruction)::Nothing

If the final instruction in bb is a Terminator, insert inst immediately before it. Otherwise, insert inst at the end of the block.
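
The insertion rule can be sketched on a plain Vector of statements. This is a minimal illustration, assuming a Vector{Any} in place of a BBlock and Core's own node types in place of Mooncake's Terminator union; ToyTerminator and toy_insert_before_terminator! are hypothetical names.

```julia
# Hypothetical stand-in for Mooncake's Terminator union, built from Core node types.
const ToyTerminator = Union{Core.GotoNode,Core.GotoIfNot,Core.ReturnNode}

function toy_insert_before_terminator!(stmts::Vector{Any}, inst)
    if !isempty(stmts) && last(stmts) isa ToyTerminator
        # Inserting at the last index shifts the terminator back by one position.
        insert!(stmts, lastindex(stmts), inst)
    else
        # No terminator present, so the new statement simply goes at the end.
        push!(stmts, inst)
    end
    return nothing
end

with_term = Any[Expr(:call, sin, 1.0), Core.ReturnNode(Core.SSAValue(1))]
toy_insert_before_terminator!(with_term, :new_stmt)

without_term = Any[Expr(:call, sin, 1.0)]
toy_insert_before_terminator!(without_term, :new_stmt)
```

In both cases the block remains well-formed: a terminator, if present, is still the final statement.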

source
Mooncake.interpolate_boundschecks!Method
interpolate_boundschecks!(ir::IRCode)

For every x = Expr(:boundscheck, value) in ir, interpolate value into all uses of x. This is only required in order to ensure that literal versions of memoryrefget, memoryrefset!, getfield, and setfield! work effectively. If they are removed through improvements to the way that we handle constant propagation inside Mooncake, then this functionality can be removed.

source
Mooncake.intrinsic_to_functionMethod
intrinsic_to_function(inst)

If inst is a :call expression to a Core.IntrinsicFunction, replace it with a call to the corresponding function from Mooncake.IntrinsicsWrappers, else return inst.

cglobal is a special case – it requires that its first argument be static in exactly the same way as :foreigncall. See IntrinsicsWrappers.__cglobal for more info.

The purpose of this transformation is to make it possible to use dispatch to write rules for intrinsic calls in a type-stable way. See IntrinsicsWrappers for more context.

source
Mooncake.ircodeFunction
ircode(
    inst::Vector{Any},
    argtypes::Vector{Any},
    sptypes::Vector{CC.VarState}=CC.VarState[],
) -> IRCode

Constructs an instance of an IRCode. This is useful for constructing test cases with known properties.

No optimisations or type inference are performed on the resulting IRCode, so that the IRCode contains exactly what is intended by the caller. Please make use of infer_ir! if you require the types to be inferred.

Edges in PhiNodes, GotoIfNots, and GotoNodes found in inst must refer to lines (as in CodeInfo). In the IRCode returned by this function, these line references are translated into block references.

source
Mooncake.is_always_fully_initialisedMethod
is_always_fully_initialised(P::DataType)::Bool

True if all fields in P are always initialised. Put differently, there are no inner constructors which permit partial initialisation.

source
Mooncake.is_always_initialisedMethod
is_always_initialised(P::DataType, n::Int)::Bool

True if the nth field of P is always initialised. If the nth fieldtype of P isbitstype, then this is distinct from asking whether the nth field is always defined. An isbits field is always defined, but is not always explicitly initialised.

source
Mooncake.is_primitiveMethod
is_primitive(::Type{Ctx}, sig) where {Ctx}

Returns a Bool specifying whether the methods specified by sig are considered primitives for contexts of type Ctx. For example,

is_primitive(DefaultCtx, Tuple{typeof(sin), Float64})

will return true if calling sin(5.0) should be treated as a primitive when the context is a DefaultCtx, and false otherwise.

Observe that this information means that whether or not something is a primitive in a particular context depends only on static information, not any run-time information that might live in a particular instance of Ctx.
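
The "static information only" property can be illustrated with a small dispatch-based sketch. It is not how Mooncake declares primitives (the @is_primitive macro is the supported route); ToyDefaultCtx and toy_is_primitive are hypothetical names used only here.

```julia
struct ToyDefaultCtx end  # hypothetical context type

# Fallback: nothing is a primitive unless a more specific method says otherwise.
toy_is_primitive(::Type{Ctx}, ::Type{<:Tuple}) where {Ctx} = false

# Opt in for one signature in one context. Both arguments are types, so the
# answer cannot depend on any run-time value held by a context instance.
toy_is_primitive(::Type{ToyDefaultCtx}, ::Type{<:Tuple{typeof(sin),Float64}}) = true

sin_is_primitive = toy_is_primitive(ToyDefaultCtx, Tuple{typeof(sin),Float64})
cos_is_primitive = toy_is_primitive(ToyDefaultCtx, Tuple{typeof(cos),Float64})
```

Because the query is resolved by method dispatch on types alone, the compiler can typically constant-fold the result.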

source
Mooncake.is_reachable_return_nodeMethod
is_reachable_return_node(x::ReturnNode)

Determine whether x is a ReturnNode, and if it is, whether it is also reachable. This is purely a function of whether its val field is defined.

source
Mooncake.is_unreachable_return_nodeMethod
is_unreachable_return_node(x::ReturnNode)

Determine whether x is a ReturnNode, and if it is, whether it is also unreachable. This is purely a function of whether its val field is defined.

source
Mooncake.is_usedMethod
is_used(info::ADInfo, id::ID)::Bool

Returns true if id is used by any of the lines in the ir, false otherwise.

source
Mooncake.is_vararg_and_sparam_namesMethod
is_vararg_and_sparam_names(m::Method)

Returns a 2-tuple. The first element is true if m is a vararg method, and false if not. The second element contains the names of the static parameters associated to m.
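
One way to obtain this information from a Method is sketched below. It is an assumption about how such a query could be written, not a copy of Mooncake's code: it reads the isva field of the Method and walks the UnionAll wrappers of its sig field, both of which are Julia internals. The toy_ names are hypothetical.

```julia
# Hypothetical helper: collect the type-variable names from the `where` clauses
# wrapped around the method signature, outermost first.
function toy_sparam_names(m::Method)
    names = Symbol[]
    sig = m.sig
    while sig isa UnionAll
        push!(names, sig.var.name)
        sig = sig.body
    end
    return names
end

toy_is_vararg_and_sparam_names(m::Method) = (m.isva, toy_sparam_names(m))

# Two toy functions to query: one vararg with a static parameter, one plain.
toy_f(x::T, ys...) where {T} = x
toy_g(x::Float64) = x

info_f = toy_is_vararg_and_sparam_names(first(methods(toy_f)))
info_g = toy_is_vararg_and_sparam_names(first(methods(toy_g)))
```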

source
Mooncake.lgetfieldMethod
lgetfield(x, f::Val)

An implementation of getfield in which the field f is specified statically via a Val. This enables the implementation to be type-stable even when it is not possible to constant-propagate f. Moreover, it enables the pullback to also be type-stable.

It will always be the case that

getfield(x, :f) === lgetfield(x, Val(:f))
getfield(x, 2) === lgetfield(x, Val(2))

This approach is identical to the one taken by Zygote.jl to circumvent the same problem. Zygote.jl calls the function literal_getfield, while we call it lgetfield.
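
A minimal sketch of the technique, assuming a hypothetical toy_lgetfield and a toy struct rather than Mooncake's actual definition: moving the field name into the type domain via Val means that dispatch, rather than constant propagation, determines which field is read.

```julia
struct ToyPoint
    x::Float64
    y::Int
end

# The field is a type parameter, so each field gets its own specialised method instance.
toy_lgetfield(x, ::Val{f}) where {f} = getfield(x, f)

p = ToyPoint(1.5, 2)
by_name = toy_lgetfield(p, Val(:x))
by_index = toy_lgetfield(p, Val(2))
```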

source
Mooncake.lgetfieldMethod
lgetfield(x, ::Val{f}, ::Val{order}) where {f, order}

Like getfield, but with the field and access order encoded as types.

source
Mooncake.lift_gc_preservationMethod
lift_gc_preservation(inst)

Expressions of the form

y = GC.@preserve x1 x2 foo(args...)

get lowered to

token = Expr(:gc_preserve_begin, x1, x2)
y = expr
Expr(:gc_preserve_end, token)

These expressions guarantee that any memory associated to x1 and x2 is not freed until the :gc_preserve_end expression is reached.

In the context of reverse-mode AD, we must ensure that the memory associated to x1, x2 and their fdata is available during the reverse pass code associated to expr. We do this by preventing the memory from being freed until the :gc_preserve_begin is reached on the reverse pass.

To achieve this, we replace the primal code with

# store `x1` and `x2` in `pb_gc_preserve` to prevent them from being freed.
_, pb_gc_preserve = rrule!!(zero_fcodual(gc_preserve), x1, x2)

# Differentiate the `:call` expression in the usual way.
y, foo_pb = rrule!!(zero_fcodual(foo), args...)

# Do not permit the GC to free `x1` or `x2` here.
nothing

The pullback should be something along the lines of

# no pullback associated to `nothing`.
nothing

# Run the pullback associated to `foo` in the usual manner. `x1` and `x2` must be available.
_, dargs... = foo_pb(dy)

# No-op pullback associated to `gc_preserve`.
pb_gc_preserve(NoRData())
source
Mooncake.lift_getfield_and_othersMethod
lift_getfield_and_others(inst)

Converts expressions of the form getfield(x, :a) into lgetfield(x, Val(:a)). This has identical semantics, but is performant in the absence of proper constant propagation.

Does the same for...

source
Mooncake.lookup_irMethod
lookup_ir(
    interp::AbstractInterpreter,
    sig_or_mi::Union{Type{<:Tuple}, Core.MethodInstance},
)::Tuple{IRCode, T}

Get the unique IR associated to sig_or_mi under interp. Throws ArgumentErrors if there is no code found, or if more than one IRCode instance is returned.

Returns a tuple containing the IRCode and its return type.

source
Mooncake.lsetfield!Method
lsetfield!(value, name::Val, x, [order::Val])

This function is to setfield! what lgetfield is to getfield. It will always hold that

setfield!(copy(x), :f, v) == lsetfield!(copy(x), Val(:f), v)
setfield!(copy(x), 2, v) == lsetfield!(copy(x), Val(2), v)
source
Mooncake.make_ad_stmts!Function
make_ad_stmts!(inst::NewInstruction, line::ID, info::ADInfo)::ADStmtInfo

Every line in the primal code is associated to one or more lines in the forwards-pass of AD, and one or more lines in the pullback. This function has a method specific to every node type in the Julia SSAIR.

Translates the instruction inst, associated to line in the primal, into a specification of what should happen for this instruction in the forwards- and reverse-passes of AD, and what data should be shared between the forwards- and reverse-passes. Returns this in the form of an ADStmtInfo.

info is a data structure containing various bits of global information that certain types of nodes need access to.

source
Mooncake.make_switch_stmtsMethod
make_switch_stmts(
    pred_ids::Vector{ID}, target_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo
)

pred_ids comprises the IDs associated to all possible predecessor blocks to the primal block under consideration. Suppose its value is [ID(1), ID(2), ID(3)], then make_switch_stmts emits code along the lines of

prev_block = pop!(block_stack)
not_pred_was_1 = !(prev_block == ID(1))
not_pred_was_2 = !(prev_block == ID(2))
switch(
    not_pred_was_1 => ID(1),
    not_pred_was_2 => ID(2),
    ID(3)
)

In words: make_switch_stmts emits code which jumps to whichever block preceded the current block during the forwards-pass.
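
The control flow that the emitted code implements can be mimicked in ordinary Julia. This is only an analogy, assuming Int block labels and a Vector{Int} as the block stack; toy_next_block! is a hypothetical name, and the real function emits IR rather than running this logic directly.

```julia
# Pop the block visited most recently on the forwards-pass, and pick the matching
# predecessor. The final entry of pred_ids acts as the fall-through case, just
# as ID(3) does in the switch shown above.
function toy_next_block!(block_stack::Vector{Int}, pred_ids::Vector{Int})
    prev_block = pop!(block_stack)
    for id in pred_ids[1:(end - 1)]
        prev_block == id && return id
    end
    return last(pred_ids)
end

block_stack = [1, 3, 2]  # blocks recorded on the forwards-pass, oldest first
first_jump = toy_next_block!(block_stack, [1, 2, 3])
second_jump = toy_next_block!(block_stack, [1, 2, 3])
third_jump = toy_next_block!(block_stack, [1, 2, 3])
```

The reverse-pass therefore revisits the predecessors in the opposite order to the one in which they were pushed.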

source
Mooncake.new_instFunction
new_inst(stmt, type=Any, flag=CC.IR_FLAG_REFINED)::NewInstruction

Create a NewInstruction with fields:

  • stmt = stmt
  • type = type
  • info = CC.NoCallInfo()
  • line = Int32(1)
  • flag = flag
source
Mooncake.new_inst_vecMethod
new_inst_vec(x::CC.InstructionStream)

Convert a Compiler.InstructionStream into a list of Compiler.NewInstructions.

source
Mooncake.new_to_callMethod
new_to_call(x)

If instruction x is a :new expression, replace it with a :call to Mooncake._new_. Otherwise, return x.

The purpose of this transformation is to make it possible to differentiate :new expressions in the same way as a primitive :call expression, i.e. via an rrule!!.

source
Mooncake.normalise!Method
normalise!(ir::IRCode, spnames::Vector{Symbol})

Apply a sequence of standardising transformations to ir which leaves its semantics unchanged, but makes AD more straightforward. In particular, replace

  1. :foreigncall Exprs with :calls to Mooncake._foreigncall_,
  2. :new Exprs with :calls to Mooncake._new_,
  3. :splatnew Exprs with :calls to Mooncake._splat_new_,
  4. Core.IntrinsicFunctions with counterparts from Mooncake.IntrinsicsWrappers,
  5. getfield(x, 1) with lgetfield(x, Val(1)), and related transformations,
  6. memoryrefget calls with lmemoryrefget calls, and related transformations,
  7. gc_preserve_begin / gc_preserve_end exprs so that memory release is delayed.

spnames are the names associated to the static parameters of ir. These are needed when handling :foreigncall expressions, in which it is not necessarily the case that all static parameter names have been translated into either types, or :static_parameter expressions.

Unfortunately, the static parameter names are not retained in IRCode, and the Method from which the IRCode is derived must be consulted. Mooncake.is_vararg_and_sparam_names provides a convenient way to do this.

source
Mooncake.opaque_closureMethod
opaque_closure(
    ret_type::Type,
    ir::IRCode,
    @nospecialize env...;
    isva::Bool=false,
    do_compile::Bool=true,
)::Core.OpaqueClosure{<:Tuple, ret_type}

Construct a Core.OpaqueClosure. Almost equivalent to Core.OpaqueClosure(ir, env...; isva, do_compile), but instead of letting Core.compute_oc_rettype figure out the return type from ir, impose ret_type as the return type.

Warning

User beware: if the Core.OpaqueClosure produced by this function ever returns anything which is not an instance of a subtype of ret_type, you should expect all kinds of awful things to happen, such as segfaults. You have been warned!

Extended Help

This is needed in Mooncake.jl because we make extensive use of our ability to know the return type of a couple of specific OpaqueClosures without actually having constructed them – see LazyDerivedRule. Without the capability to specify the return type, we have to guess what type compute_ir_rettype will return for a given IRCode before we have constructed the IRCode and run type inference on it. This exposes us to details of type inference, which are not part of the public interface of the language, and can therefore vary from Julia version to Julia version (including patch versions). Moreover, even for a fixed Julia version it can be extremely hard to predict exactly what type inference will infer to be the return type of a function.

Failing to correctly guess the return type can happen for a number of reasons, and the kinds of errors that tend to be generated when this fails tell you very little about the underlying cause of the problem.

By specifying the return type ourselves, we remove this dependence. The price we pay for this is the potential for segfaults etc if we fail to specify ret_type correctly.

source
Mooncake.optimise_ir!Method
optimise_ir!(ir::IRCode, show_ir=false)

Run a fairly standard optimisation pass on ir. If show_ir is true, displays the IR to stdout at various points in the pipeline – this is sometimes useful for debugging.

source
Mooncake.phi_nodesMethod
phi_nodes(bb::BBlock)::Tuple{Vector{ID}, Vector{IDPhiNode}}

Returns all of the IDPhiNodes at the start of bb, along with their IDs. If there are no IDPhiNodes at the start of bb, then both vectors will be empty.

source
Mooncake.prepare_gradient_cacheMethod
prepare_gradient_cache(f, x...)

WARNING: experimental functionality. Interface subject to change without warning!

Returns a cache which can be passed to value_and_gradient!!. See the docstring for Mooncake.value_and_gradient!! for more info.

source
Mooncake.prepare_pullback_cacheMethod
prepare_pullback_cache(f, x...)

WARNING: experimental functionality. Interface subject to change without warning!

Returns a cache which can be passed to value_and_pullback!!. See the docstring for Mooncake.value_and_pullback!! for more info.

source
Mooncake.primal_irMethod
primal_ir(sig::Type{<:Tuple}; interp=get_interpreter())::IRCode

!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Get the Core.Compiler.IRCode associated to sig from which a rule can be derived. Roughly equivalent to Base.code_ircode_by_type(sig; interp).

For example, if you wanted to get the IR associated to the call map(sin, randn(10)), you could do one of the following calls:

julia> Mooncake.primal_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
true
julia> Mooncake.primal_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
true
source
Mooncake.pullback_irMethod
pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)

Produce the IR associated to the OpaqueClosure which runs most of the pullback.

source
Mooncake.pullback_typeMethod
pullback_type(Trule, arg_types)

Get a bound on the pullback type, given a rule and associated primal types.

source
Mooncake.rdata_field_typeMethod
rdata_field_type(::Type{P}, n::Int) where {P}

Returns the type of the nth field of the rdata type associated to P. Will be a PossiblyUninitTangent if said field can be undefined.

source
Mooncake.remove_unreachable_blocks!Method
remove_unreachable_blocks!(ir::BBCode)::BBCode

If a basic block in ir cannot possibly be reached during execution, then it can be safely removed from ir without changing its functionality. A block is unreachable if either:

  1. it has no predecessors and it is not the first block, or
  2. all of its predecessors are themselves unreachable.

For example, consider the following IR:

julia> ir = Mooncake.ircode(
           Any[Core.ReturnNode(nothing), Expr(:call, sin, 5), Core.ReturnNode(Core.SSAValue(2))],
           Any[Any, Any, Any],
       );

There is no possible way to reach the second basic block (lines 2 and 3). Applying this function will therefore remove it, yielding the following:

julia> Mooncake.IRCode(Mooncake.remove_unreachable_blocks!(Mooncake.BBCode(ir)))
1 1 ─     return nothing

In the blocks which have not been removed, there may be references to blocks which have been removed. For example, the edges in a PhiNode may contain a reference to a removed block. These references are removed in-place from these remaining blocks, so this function will (in general) modify ir.
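
The reachability criterion can be sketched as a graph traversal from the entry block. The example assumes a control-flow graph given as a vector of successor lists, with block 1 as the entry; toy_reachable_blocks is a hypothetical name, and this is not the algorithm taken from Mooncake's source.

```julia
# succs[b] lists the blocks that block b can jump to. Any block never visited
# when walking from block 1 is unreachable under the two rules stated above.
function toy_reachable_blocks(succs::Vector{Vector{Int}})
    reachable = falses(length(succs))
    worklist = [1]
    while !isempty(worklist)
        b = pop!(worklist)
        reachable[b] && continue  # already visited
        reachable[b] = true
        append!(worklist, succs[b])
    end
    return findall(reachable)
end

# Mirrors the example above: block 1 returns immediately, block 2 has no predecessors.
two_blocks = toy_reachable_blocks([Int[], Int[]])

# Block 4 jumps to block 3, but nothing jumps to block 4, so it is unreachable.
four_blocks = toy_reachable_blocks([[2], [3], Int[], [3]])
```

A traversal of this kind also discards cycles of blocks which only reference one another, since none of them is ever reached from the entry block.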

source
Mooncake.replace_capturesMethod
replace_captures(mc::Tmc, new_captures) where {Tmc<:MistyClosure}

Same as replace_captures for Core.OpaqueClosures, but returns a new MistyClosure.

source
Mooncake.replace_capturesMethod
replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure}

Given an OpaqueClosure oc, create a new OpaqueClosure of the same type, but with new captured variables. This is needed for efficiency reasons – if build_rrule is called repeatedly with the same signature and interpreter, it is important to avoid recompiling the OpaqueClosures that it produces multiple times, because it can be quite expensive to do so.

source
Mooncake.replace_uses_with!Method
replace_uses_with!(stmt, def::Union{Argument, SSAValue}, val)

Replace all uses of def with val in the single statement stmt. Note: this function is highly incomplete, really only working correctly for a specific function in ir_normalisation.jl. You probably do not want to use it.

source
Mooncake.rrule_wrapperMethod
rrule_wrapper(f::CoDual, args::CoDual...)

Used to implement rrule!!s via ChainRulesCore.rrule.

Given a function foo, argument types arg_types, and a method of ChainRulesCore.rrule which applies to these, you can make use of this function as follows:

Mooncake.@is_primitive DefaultCtx Tuple{typeof(foo), arg_types...}
function Mooncake.rrule!!(f::CoDual{typeof(foo)}, args::CoDual...)
    return rrule_wrapper(f, args...)
end

Assumes that methods of to_cr_tangent and to_mooncake_tangent are defined such that you can convert between the different representations of tangents that Mooncake and ChainRulesCore expect.

Furthermore, it is essential that

  1. f(args) does not mutate f or args, and
  2. the result of f(args) does not alias any data stored in f or args.

Subject to some constraints, you can use the @from_rrule macro to reduce the amount of boilerplate code that you are required to write even further.

source
Mooncake.rule_typeMethod
rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C}

Compute the concrete type of the rule that will be returned from build_rrule. This is important for performance in dynamic dispatch, and to ensure that recursion works properly.

source
Mooncake.rvs_irMethod
rvs_ir(
    sig::Type{<:Tuple};
    interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true
)::IRCode

!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Generate the Core.Compiler.IRCode used to construct the reverse-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.

For example, if you wanted to get the IR associated to the reverse pass for the call map(sin, randn(10)), you could do either of the following:

julia> Mooncake.rvs_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
true
julia> Mooncake.rvs_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
true

Arguments

  • sig::Type{<:Tuple}: the signature of the call to be differentiated.

Keyword Arguments

  • interp: the interpreter to use to obtain the primal IR.
  • debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
  • do_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
source
Mooncake.rvs_phi_blockMethod
rvs_phi_block(pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo)

Produces a BBlock which runs the reverse-pass for the edge associated to pred_id in a collection of IDPhiNodes, and then goes to the block associated to pred_id.

For example, suppose that we encounter the following collection of PhiNodes at the start of some block:

%6 = φ (#2 => _1, #3 => %5)
%7 = φ (#2 => 5., #3 => _2)

Let the tangent refs associated to %6, %7, and _1 be denoted t%6, t%7, and t_1 resp., and let pred_id be #2, then this function will produce a basic block of the form

increment_ref!(t_1, t%6)
nothing
goto #2

The call to increment_ref! appears because _1 is the value associated to %6 when the primal code comes from #2. Similarly, the goto #2 statement appears because we came from #2 on the forwards-pass. There is no increment_ref! associated to %7 because 5. is a constant. We emit a nothing statement, which the compiler will happily optimise away later on.

The same ideas apply if pred_id were #3. The block would end with goto #3, and there would be two increment_ref! calls because both %5 and _2 are not constants.

source
Mooncake.seed_id!Method
seed_id!()

Set the global counter used to ensure ID uniqueness to 0. This is useful when you want to ensure determinism between two runs of the same function which makes use of IDs.

This is akin to setting the random seed associated to a random number generator globally.
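
The mechanism can be pictured with a toy global counter. The names below (TOY_ID_COUNTER, ToyID, toy_seed_id!) are hypothetical, and the real ID type may be implemented differently; the sketch only shows why resetting the counter makes runs reproducible.

```julia
# A global counter which hands out a fresh integer for every new ToyID.
const TOY_ID_COUNTER = Ref(0)

struct ToyID
    id::Int
end
ToyID() = ToyID(TOY_ID_COUNTER[] += 1)

# Resetting the counter makes the sequence of IDs start again from the beginning.
toy_seed_id!() = (TOY_ID_COUNTER[] = 0; nothing)

toy_seed_id!()
first_run = [ToyID(), ToyID()]
toy_seed_id!()
second_run = [ToyID(), ToyID()]
```

Both runs produce the same sequence of IDs, which is what makes comparing two pieces of generated code feasible.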

source
Mooncake.set_tangent_field!Method
set_tangent_field!(t::MutableTangent{Tfields}, i::Int, x) where {Tfields}

Sets the value of the ith field of the data in t to value x.

Has the same semantics that setfield! would have if the data in the fields field of t were actually fields of t. This is the moral equivalent of setfield! for MutableTangent.

source
Mooncake.shared_data_stmtsMethod
shared_data_stmts(p::SharedDataPairs)::Vector{IDInstPair}

Produce a sequence of id-statement pairs which will extract the data from shared_data_tuple(p) such that the correct value is associated to the correct ID.

For example, if p.pairs is

[(ID(5), 5.0), (ID(3), "hello")]

then the output of this function is

IDInstPair[
    (ID(5), new_inst(:(getfield(_1, 1)))),
    (ID(3), new_inst(:(getfield(_1, 2)))),
]
source
Mooncake.shared_data_tupleMethod
shared_data_tuple(p::SharedDataPairs)::Tuple

Create the tuple that will constitute the captured variables in the forwards- and reverse- pass OpaqueClosures.

For example, if p.pairs is

[(ID(5), 5.0), (ID(3), "hello")]

then the output of this function is

(5.0, "hello")
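
The relationship between shared_data_tuple and shared_data_stmts can be sketched with plain tuples. The example assumes Int in place of ID and a bare Expr in place of new_inst(...); the toy_ names are hypothetical.

```julia
# (id, value) pairs, in the order in which they were registered.
toy_pairs = Tuple{Int,Any}[(5, 5.0), (3, "hello")]

# The captured tuple simply holds the values, in order.
toy_shared_data_tuple(pairs) = Tuple(map(last, pairs))

# The nth statement reads the nth element of the captured tuple (written `_1`
# here, as in the example above) and associates it with the nth ID.
function toy_shared_data_stmts(pairs)
    return [(id, :(getfield(_1, $n))) for (n, (id, _)) in enumerate(pairs)]
end

captures = toy_shared_data_tuple(toy_pairs)
stmts = toy_shared_data_stmts(toy_pairs)
```

Position n in the captured tuple and statement n always refer to the same pair, which is what keeps the two functions consistent with one another.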
source
Mooncake.sparam_namesMethod
sparam_names(m::Core.Method)::Vector{Symbol}

Returns the names of all of the static parameters in m.

source
Mooncake.splatnew_to_callMethod
splatnew_to_call(x)

If instruction x is a :splatnew expression, replace it with a :call to Mooncake._splat_new_. Otherwise return x.

The purpose of this transformation is to make it possible to differentiate :splatnew expressions in the same way as a primitive :call expression, i.e. via an rrule!!.

source
Mooncake.stmtMethod
stmt(ir::CC.InstructionStream)

Get the field containing the instructions in ir. This changed name in 1.11 from inst to stmt.

source
Mooncake.tangent_field_typeMethod
tangent_field_type(::Type{P}, n::Int) where {P}

Returns the type that lives in the nth element of fields in a Tangent / MutableTangent. Will either be the tangent_type of the nth fieldtype of P, or the tangent_type wrapped in a PossiblyUninitTangent. The latter case only occurs if it is possible for the field to be undefined.

source
Mooncake.tangent_test_casesMethod
tangent_test_cases()

Constructs a Vector of Tuples containing test cases for the tangent infrastructure.

If the returned tuple has 2 elements, the elements should be interpreted as follows:

  • 1: interface_only
  • 2: primal value

interface_only is a Bool which will be used to determine which subset of tests to run.

If the returned tuple has 5 elements, then the elements are interpreted as follows:

  • 1: interface_only
  • 2: primal value
  • 3, 4, 5: tangents, where <5> == increment!!(<3>, <4>).

Test cases in the first format make use of zero_tangent / randn_tangent etc to generate tangents, but they're unable to check that increment!! is correct in an absolute sense.

source
Mooncake.terminatorMethod
terminator(bb::BBlock)

Returns the terminator associated to bb. If the last instruction in bb isa Terminator then that is returned, otherwise nothing is returned.

source
Mooncake.tuple_mapMethod
tuple_map(f::F, x::Tuple) where {F}

This function is largely equivalent to map(f, x), but always specialises on all of the element types of x, regardless of the length of x. This contrasts with map, in which the number of element types specialised upon is a fixed constant in the compiler.

As a consequence, if x is very long, this function may have very large compile times.

tuple_map(f::F, x::Tuple, y::Tuple) where {F}

Binary extension of tuple_map. Nearly equivalent to map(f, x, y), but guaranteed to specialise on all element types of x and y. Furthermore, errors if x and y aren't the same length, while map will just produce a new tuple whose length is equal to the shorter of x and y.
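
A strict binary map over tuples can be sketched as follows. This is an illustration of the behaviour described above, assuming ntuple with a Val length as the unrolling mechanism; it is not Mooncake's implementation, and toy_tuple_map is a hypothetical name.

```julia
# Errors on a length mismatch rather than truncating, and asks ntuple to build
# the result with a statically-known length.
function toy_tuple_map(f::F, x::Tuple, y::Tuple) where {F}
    length(x) == length(y) || throw(ArgumentError("x and y must have the same length"))
    return ntuple(i -> f(x[i], y[i]), Val(length(x)))
end

summed = toy_tuple_map(+, (1, 2.0), (3, 4))

# A length mismatch raises an ArgumentError.
mismatch_errors = try
    toy_tuple_map(+, (1, 2), (3, 4, 5))
    false
catch err
    err isa ArgumentError
end
```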

source
Mooncake.uninit_fcodualMethod
uninit_fcodual(x)

Like zero_fcodual, but doesn't guarantee that the value of the fdata is initialised. See implementation for details, as this function is subject to change.

source
Mooncake.uninit_tangentMethod
uninit_tangent(x)

Related to zero_tangent, but a bit different. Check current implementation for details – this docstring is intentionally non-specific in order to avoid becoming outdated.

source
Mooncake.verify_fdata_typeMethod
verify_fdata_type(P::Type, F::Type)::Nothing

Check that F is a valid type for fdata associated to a primal of type P. Returns nothing if valid, throws an InvalidFDataException if a problem is found.

This applies to both concrete and non-concrete P. For example, if P is the type inferred for a primal q::Q, such that Q <: P, then this method is still applicable.

source
Mooncake.verify_fdata_valueMethod
verify_fdata_value(p, f)::Nothing

Check that f cannot be proven to be invalid fdata for p.

This method attempts to provide some confidence that f is valid fdata for p by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.

Put differently, we cannot prove that f is valid fdata, only that it is not obviously invalid.

source
Mooncake.verify_rdata_typeMethod
verify_rdata_type(P::Type, R::Type)::Nothing

Check that R is a valid type for rdata associated to a primal of type P. Returns nothing if valid, throws an InvalidRDataException if a problem is found.

This applies to both concrete and non-concrete P. For example, if P is the type inferred for a primal q::Q, such that Q <: P, then this method is still applicable.

source
Mooncake.verify_rdata_valueMethod
verify_rdata_value(p, r)::Nothing

Check that r cannot be proven to be invalid rdata for p.

This method attempts to provide some confidence that r is valid rdata for p by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.

Put differently, we cannot prove that r is valid rdata, only that it is not obviously invalid.

source
Mooncake.zero_adjointMethod
zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}

Utility functionality for constructing rrule!!s for functions whose adjoints are always zero.

NOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.

You make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:

julia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual

julia> foo(x::Vararg{Int}) = 5
foo (generic function with 1 method)

julia> is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(foo), Vararg{Int}}}) = true;

julia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...);

julia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData())
(NoRData(), NoRData(), NoRData())

WARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not.

source
Mooncake.zero_like_rdata_from_typeMethod
zero_like_rdata_from_type(::Type{P}) where {P}

This is an internal implementation detail – you should generally not use this function.

Returns either the zero element of type rdata_type(tangent_type(P)), or a ZeroRData. It is always valid to return a ZeroRData.

source
Mooncake.zero_like_rdata_typeMethod
zero_like_rdata_type(::Type{P}) where {P}

Indicates the type which will be returned by zero_like_rdata_from_type. Will be the rdata type for P if we can produce the zero rdata element given only P, and will be the union of the rdata type for P and ZeroRData if an instance of P is needed.

source
Mooncake.zero_rdata_from_typeMethod
zero_rdata_from_type(::Type{P}) where {P}

Returns the zero element of rdata_type(tangent_type(P)) if this is possible given only P. If not possible, returns an instance of CannotProduceZeroRDataFromType.

For example, the zero rdata associated to any primal of type Float64 is 0.0, so for Float64s this function is simple. Similarly, if the rdata type for P is NoRData, that can simply be returned.

However, it is not possible to return the zero rdata element for abstract types e.g. Real as the type does not uniquely determine the zero element – the rdata type for Real is Any.

These considerations apply recursively to tuples / namedtuples / structs, etc.
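
The recursion described above can be sketched as follows. This is a deliberately simplified illustration: sketch_zero_from_type and SketchCannotProduceZero are hypothetical names, only IEEE floats and tuples of them are handled, and every other type falls through to the sentinel (unlike Mooncake, which for example returns NoRData where that is the rdata type):

```julia
# Hypothetical sentinel, standing in for CannotProduceZeroRDataFromType.
struct SketchCannotProduceZero end

# Fallback: the type alone does not determine a zero element (e.g. `Real`).
sketch_zero_from_type(::Type{P}) where {P} = SketchCannotProduceZero()

# For IEEE floats the type uniquely determines the zero element.
sketch_zero_from_type(::Type{P}) where {P<:Base.IEEEFloat} = zero(P)

# Tuples: recurse into each element type, and give up if any element fails.
function sketch_zero_from_type(::Type{P}) where {P<:Tuple}
    isconcretetype(P) || return SketchCannotProduceZero()
    elems = map(sketch_zero_from_type, fieldtypes(P))
    any(e -> e isa SketchCannotProduceZero, elems) && return SketchCannotProduceZero()
    return elems
end

sketch_zero_from_type(Float64)                        # 0.0
sketch_zero_from_type(Tuple{Float64,Tuple{Float32}})  # (0.0, (0.0f0,))
sketch_zero_from_type(Tuple{Float64,Real})            # SketchCannotProduceZero()
```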

If you encounter a type for which this function returns CannotProduceZeroRDataFromType, but you believe this is done in error, please open an issue. This kind of problem does not constitute a correctness problem, but can be detrimental to performance, so should be dealt with.

source
Mooncake.@from_rruleMacro
@from_rrule ctx sig [has_kwargs=false]

Convenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.

Arguments

  • ctx: A Mooncake context type
  • sig: the signature which you wish to declare a primitive in Mooncake.jl, and whose rule should be implemented using an existing ChainRulesCore.rrule.
  • has_kwargs: a Bool stating whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule – the derivative w.r.t. all kwargs must be zero.

Example Usage

A Basic Example

julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils

julia> using ChainRulesCore, Random

julia> foo(x::Real) = 5x;

julia> function ChainRulesCore.rrule(::typeof(foo), x::Real)
           foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), 5Ω
           return foo(x), foo_pb
       end;

julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat}

julia> rrule!!(zero_fcodual(foo), zero_fcodual(5.0))[2](1.0)
(NoRData(), 5.0)

julia> # Check that the rule works as intended.
       TestUtils.test_rule(Xoshiro(123), foo, 5.0; is_primitive=true)
Test Passed

An Example with Keyword Arguments

julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils

julia> using ChainRulesCore, Random

julia> foo(x::Real; cond::Bool) = cond ? 5x : 4x;

julia> function ChainRulesCore.rrule(::typeof(foo), x::Real; cond::Bool)
           foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), cond ? 5Ω : 4Ω
           return foo(x; cond), foo_pb
       end;

julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} true

julia> _, pb = rrule!!(
           zero_fcodual(Core.kwcall),
           zero_fcodual((cond=false, )),
           zero_fcodual(foo),
           zero_fcodual(5.0),
       );

julia> pb(3.0)
(NoRData(), NoRData(), NoRData(), 12.0)

julia> # Check that the rule works as intended.
       TestUtils.test_rule(
           Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true
       )
Test Passed

Notice that, in order to access the kwarg method we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.

Limitations

It is your responsibility to ensure that

  1. calls with signature sig do not mutate their arguments,
  2. the output of calls with signature sig does not alias any of the inputs.

As with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.

Argument Type Constraints

Many methods of ChainRulesCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature

Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}

There are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.

Suffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.
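
The difference between a narrow and a loose signature can be checked directly with subtyping, using only Base. In this small sketch, sin is just a convenient stand-in function, and narrow, loose, and covered_by are names introduced for illustration:

```julia
# Base.IEEEFloat is exactly the three hardware floating point types.
Base.IEEEFloat == Union{Float16,Float32,Float64}  # true

# A narrowly-typed signature, and a loosely-typed one.
narrow = Tuple{typeof(sin),Base.IEEEFloat}
loose = Tuple{typeof(sin),Real}

# Does a call with argument type `T` fall under the signature `sig`?
covered_by(sig, ::Type{T}) where {T} = Tuple{typeof(sin),T} <: sig

covered_by(narrow, Float32)   # true
covered_by(narrow, BigFloat)  # false: left for rule derivation
covered_by(loose, BigFloat)   # true: the rule must be correct for BigFloat too
```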

Conversions Between Different Tangent Type Systems

Under the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.

source
Mooncake.@is_primitiveMacro
@is_primitive context_type signature

Creates a method of is_primitive which always returns true for the context_type and signature provided. For example

@is_primitive MinimalCtx Tuple{typeof(foo), Float64}

is equivalent to

is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(foo), Float64}}) = true

You should implement more complicated methods of is_primitive in the usual way.

source
Mooncake.@mooncake_overlayMacro
@mooncake_overlay method_expr

Define a method of a function which only Mooncake can see. This can be used to write versions of methods which can be successfully differentiated by Mooncake if the original cannot be.

For example, suppose that you have a function

julia> foo(x::Float64) = bar(x)
foo (generic function with 1 method)

where Mooncake.jl fails to differentiate bar for some reason. If you have access to another function baz, which does the same thing as bar, but does so in a way which Mooncake.jl can differentiate, you can simply write:

julia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x)

When looking up the code for foo(::Float64), Mooncake.jl will see this method, rather than the original, and differentiate it instead.

A Worked Example

To show how to use @mooncake_overlays in practice, we here demonstrate how the answer that Mooncake.jl gives changes if you change the definition of a function using a @mooncake_overlay. Do not do this in practice – this is just a simple way to demonstrate how to use overlays!

First, consider a simple example:

julia> scale(x) = 2x
scale (generic function with 1 method)

julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});

julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(10.0, (NoTangent(), 2.0))

We can use @mooncake_overlay to change the definition which Mooncake.jl sees:

julia> Mooncake.@mooncake_overlay scale(x) = 3x

julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});

julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(15.0, (NoTangent(), 3.0))

As can be seen from the output, the result of differentiating using Mooncake.jl has changed to reflect the overlay-ed definition of the method.

Additionally, it is possible to use the usual multi-line syntax to declare an overlay:

julia> Mooncake.@mooncake_overlay function scale(x)
           return 4x
       end

julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});

julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(20.0, (NoTangent(), 4.0))
source
Mooncake.@zero_adjointMacro
@zero_adjoint ctx sig

Defines is_primitive(ctx, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable.

For example:

julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive

julia> foo(x) = 5
foo (generic function with 1 method)

julia> @zero_adjoint DefaultCtx Tuple{typeof(foo), Any}

julia> is_primitive(DefaultCtx, Tuple{typeof(foo), Any})
true

julia> rrule!!(zero_fcodual(foo), zero_fcodual(3.0))[2](NoRData())
(NoRData(), 0.0)

Limited support for Varargs is also available. For example

julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive

julia> foo_varargs(x...) = 5
foo_varargs (generic function with 1 method)

julia> @zero_adjoint DefaultCtx Tuple{typeof(foo_varargs), Vararg}

julia> is_primitive(DefaultCtx, Tuple{typeof(foo_varargs), Any, Float64, Int})
true

julia> rrule!!(zero_fcodual(foo_varargs), zero_fcodual(3.0), zero_fcodual(5))[2](NoRData())
(NoRData(), 0.0, NoRData())

Be aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.

WARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.

As always, you should use TestUtils.test_rule to ensure that you've not made a mistake.

Signatures Unsupported By This Macro

If the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.

source
Mooncake.IntrinsicsWrappersModule
module IntrinsicsWrappers

The purpose of this module is to associate to each function in Core.Intrinsics a regular Julia function.

To understand the rationale for this observe that, unlike regular Julia functions, each Core.IntrinsicFunction in Core.Intrinsics does not have its own type. Rather, they are instances of Core.IntrinsicFunction. To see this, observe that

julia> typeof(Core.Intrinsics.add_float)
Core.IntrinsicFunction

julia> typeof(Core.Intrinsics.sub_float)
Core.IntrinsicFunction

While we could simply write a rule for Core.IntrinsicFunction, this would (naively) lead to a large list of conditionals of the form

if f === Core.Intrinsics.add_float
    # return add_float and its pullback
elseif f === Core.Intrinsics.sub_float
    # return sub_float and its pullback
elseif
    ...
end

which has the potential to cause quite substantial type instabilities. (This might not be true anymore – see extended help for more context).

Instead, we map each Core.IntrinsicFunction to one of the regular Julia functions in Mooncake.IntrinsicsWrappers, to which we can dispatch in the usual way.
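
The idea can be sketched with two intrinsics. SketchWrappers and sketch_rule are hypothetical names used only for this illustration. They are not Mooncake.IntrinsicsWrappers or rrule!!, and the pullbacks below are simplified closures rather than Mooncake's rule interface:

```julia
module SketchWrappers
# Each wrapper is a regular Julia function, so it has its own singleton type.
add_float(a, b) = Core.Intrinsics.add_float(a, b)
sub_float(a, b) = Core.Intrinsics.sub_float(a, b)
end

# Hypothetical rules: one method per wrapper, selected by dispatch rather than
# by an if / elseif chain. Each returns the value and a pullback closure.
sketch_rule(f::typeof(SketchWrappers.add_float), a, b) = (f(a, b), dy -> (dy, dy))
sketch_rule(f::typeof(SketchWrappers.sub_float), a, b) = (f(a, b), dy -> (dy, -dy))

# The wrappers have distinct types, whereas the intrinsics share one type.
typeof(SketchWrappers.add_float) === typeof(SketchWrappers.sub_float)    # false
typeof(Core.Intrinsics.add_float) === typeof(Core.Intrinsics.sub_float)  # true

y, pb = sketch_rule(SketchWrappers.sub_float, 5.0, 3.0)
y        # 2.0
pb(1.0)  # (1.0, -1.0)
```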

Extended Help

It is possible that, owing to improvements in constant propagation in the Julia compiler in version 1.10, we could actually get away with writing a single method of rrule!! to handle all intrinsics, in which case this dispatch-based mechanism might be unnecessary. Someone should investigate this. This is discussed at https://github.com/compintell/Mooncake.jl/issues/387

source