Internal Docstrings
Docstrings listed here are not part of the public Mooncake.jl interface. Consequently, they can change in non-breaking releases of Mooncake.jl without warning.
The purpose of this is to make it easy for developers to find docstrings straightforwardly via the docs, as opposed to having to ctrl+f through Mooncake.jl's source code, or looking at the docstrings via the Julia REPL.
Mooncake.GLOBAL_INTERPRETER
— Constantconst GLOBAL_INTERPRETER
Globally cached interpreter. Should only be accessed via get_interpreter
.
Mooncake.Terminator
— TypeTerminator = Union{Switch, IDGotoIfNot, IDGotoNode, ReturnNode}
A Union of the possible types of a terminator node.
Core.Compiler.IRCode
— MethodIRCode(bb_code::BBCode)
Produce an IRCode
instance which is equivalent to bb_code
. The resulting IRCode
shares no memory with bb_code
, so can be safely mutated without modifying bb_code
.
All IDPhiNode
s, IDGotoIfNot
s, and IDGotoNode
s are converted into PhiNode
s, GotoIfNot
s, and GotoNode
s respectively.
In the resulting IRCode
, any Switch
nodes are lowered into a semantically-equivalent collection of GotoIfNot
nodes.
Mooncake.ADInfo
— TypeADInfo
This data structure is used to hold "global" information associated to a particular call to build_rrule
. It is used as a means of communication between make_ad_stmts!
and the codegen which produces the forwards- and reverse-passes.
- interp: a MooncakeInterpreter.
- block_stack_id: the ID associated to the block stack – the stack which keeps track of which blocks we visited during the forwards-pass, and which is used on the reverse-pass to determine which blocks to visit.
- block_stack: the block stack. Can always be found at block_stack_id in the forwards- and reverse-passes.
- entry_id: ID associated to the block inserted at the start of execution in the forwards-pass, and the end of execution in the pullback.
- shared_data_pairs: the SharedDataPairs used to define the captured variables passed to both the forwards- and reverse-passes.
- arg_types: a map from Argument to its static type.
- ssa_insts: a map from the ID associated to each line to the primal NewInstruction. This contains the line of code, its static / inferred type, and some other details. See Core.Compiler.NewInstruction for a full list of fields.
- arg_rdata_ref_ids: the dict mapping from arguments to the ID which creates and initialises the Ref which contains the reverse data associated to that argument. Recall that the heap allocations associated to this Ref are always optimised away in the final programme.
- ssa_rdata_ref_ids: the same as arg_rdata_ref_ids, but for each ID associated to an ssa rather than each argument.
- debug_mode: if true, run in "debug mode" – wraps all rule calls in DebugRRule. This is applied recursively, so that debug mode is also switched on in derived rules.
- is_used_dict: for each ID associated to a line of code, is false if the line is not used anywhere in any other line of code.
- lazy_zero_rdata_ref_id: for any arguments whose type doesn't permit the construction of a zero-valued rdata directly from the type alone (e.g. a struct with an abstractly-typed field), we need to have a zero-valued rdata available on the reverse-pass so that this zero-valued rdata can be returned if the argument (or a part of it) is never used during the forwards-pass and consequently doesn't obtain a value on the reverse-pass. To achieve this, we construct a LazyZeroRData for each of the arguments on the forwards-pass, and make use of it on the reverse-pass. This field is the ID that will be associated to this information.
Mooncake.ADStmtInfo
— TypeADStmtInfo
Data structure which contains the result of make_ad_stmts!
. Fields are:
- line: the ID associated to the primal line from which this is derived.
- comms_id: an ID from one of the lines in fwds, whose value will be made available on the reverse-pass in the same ID. Nothing is asserted about how this value is made available on the reverse-pass of AD, so this package is free to do this in whichever way is most efficient, in particular to group these communication IDs on a per-block basis.
- fwds: the instructions which run the forwards-pass of AD.
- rvs: the instructions which run the reverse-pass of AD / the pullback.
Mooncake.BBCode
— MethodBBCode(ir::IRCode)
Convert an ir
into a BBCode
. Creates a completely independent data structure, so mutating the BBCode
returned will not mutate ir
.
All PhiNode
s, GotoIfNot
s, and GotoNode
s will be replaced with the IDPhiNode
s, IDGotoIfNot
s, and IDGotoNode
s respectively.
See IRCode
for conversion back to IRCode
.
Note that IRCode(BBCode(ir))
should be equivalent to ir, i.e. the round trip should act as the identity function.
Mooncake.BBCode
— MethodBBCode(ir::Union{IRCode, BBCode}, new_blocks::Vector{Block})
Make a new BBCode
whose blocks
is given by new_blocks
, and fresh copies are made of all other fields from ir
.
Mooncake.BBCode
— TypeBBCode(
blocks::Vector{BBlock}
argtypes::Vector{Any}
sptypes::Vector{CC.VarState}
linetable::Vector{Core.LineInfoNode}
meta::Vector{Expr}
)
A BBCode
is a data structure which is similar to IRCode
, but adds additional structure.
In particular, a BBCode
comprises a sequence of basic blocks (BBlock
s), each of which comprises a sequence of statements. Moreover, each BBlock
has its own unique ID
, as does each statement.
The consequence of this is that new basic blocks can be inserted into a BBCode
. This is distinct from IRCode
, in which to create a new basic block, one must insert additional statements which are known to create a new basic block – this is generally quite an unreliable process, while inserting a new BBlock
into BBCode
is entirely predictable. Furthermore, inserting a new BBlock
does not change the ID
associated to the other blocks, meaning that you can safely assume that references from existing basic block terminators / phi nodes to other blocks will not be modified by inserting a new basic block.
Additionally, since each statement in each basic block has its own unique ID
, new statements can be inserted without changing references between other blocks. IRCode
also has some support for this via its new_nodes
field, but eventually all statements will be renamed upon compact!
ing the IRCode
, meaning that the name of any given statement will eventually change.
Finally, note that the basic blocks in a BBCode
support the custom Switch
statement. This statement is not valid in IRCode
, and is therefore lowered into a collection of GotoIfNot
s and GotoNode
s when a BBCode
is converted back into an IRCode
.
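The benefit of naming blocks by ID rather than by position can be sketched with a toy model. The ToyBlock type and the use of Symbols as block names below are hypothetical stand-ins for illustration, not Mooncake's actual BBlock / ID types:

```julia
# Hypothetical toy model: each block has a unique name and a jump target.
struct ToyBlock
    id::Symbol      # stands in for Mooncake's `ID`
    target::Symbol  # name of the block this one jumps to
end

blocks = [ToyBlock(:a, :c), ToyBlock(:c, :c)]

# Positional numbering: block :a would refer to its target as "position 2".
positional_target = findfirst(b -> b.id == :c, blocks)

# Insert a new block between the two existing ones.
insert!(blocks, 2, ToyBlock(:b, :c))

# The name-based reference still resolves to the intended block...
name_still_valid = blocks[findfirst(b -> b.id == blocks[1].target, blocks)].id == :c

# ...whereas the old positional reference now points at the new block.
position_now_points_at = blocks[positional_target].id
```

After the insertion, the positional reference silently refers to a different block, while the name-based reference is unaffected.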
Mooncake.BBlock
— MethodBBlock(id::ID, inst_pairs::Vector{IDInstPair})
Convenience constructor – splits inst_pairs
into a Vector{ID}
and InstVector
in order to build a BBlock
.
Mooncake.BBlock
— TypeBBlock(id::ID, stmt_ids::Vector{ID}, stmts::InstVector)
A basic block data structure (not called BasicBlock
to avoid accidental confusion with CC.BasicBlock
). Forms a single basic block.
Each BBlock
has an ID
(a unique name). This makes it possible to refer to blocks in a way that does not change when additional BBlocks
are inserted into a BBCode
. This differs from the positional block numbering found in IRCode
, in which the number associated to a basic block changes when new blocks are inserted.
The n
th line of code in a BBlock
is associated to ID
stmt_ids[n]
, and the n
th instruction from stmts
.
Note that PhiNode
s, GotoIfNot
s, and GotoNode
s should not appear in a BBlock
– instead an IDPhiNode
, IDGotoIfNot
, or IDGotoNode
should be used.
Mooncake.BlockStack
— TypeThe block stack is the stack used to keep track of which basic blocks are visited on the forwards pass, and therefore which blocks need to be visited on the reverse pass. There is one block stack per derived rule. By using Int32, we assume that there aren't more than typemax(Int32)
unique basic blocks in a given function, which ought to be reasonable.
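As a rough illustration of the idea (a plain Vector{Int32} is used here as an assumed stand-in; the actual BlockStack type may differ), the forwards pass records each visited block, and the reverse pass pops them to recover the visit order in reverse:

```julia
# Hypothetical sketch: record visited block IDs on the forwards pass,
# then replay them in reverse order on the reverse pass.
block_stack = Int32[]

# Forwards pass: suppose control flow visits blocks 1, 3, 3, 4 (a short loop).
for blk in Int32[1, 3, 3, 4]
    push!(block_stack, blk)
end

# Reverse pass: popping yields the blocks in the opposite order.
reverse_order = Int32[]
while !isempty(block_stack)
    push!(reverse_order, pop!(block_stack))
end
```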
Mooncake.CannotProduceZeroRDataFromType
— TypeCannotProduceZeroRDataFromType()
Returned by zero_rdata_from_type
if it is not possible to construct the zero rdata element for a given type. See zero_rdata_from_type
for more info.
Mooncake.Config
— TypeConfig(; debug_mode=false, silence_debug_messages=false)
Configuration struct for use with ADTypes.AutoMooncake.
Mooncake.DebugPullback
— Method(pb::DebugPullback)(dy)
Apply type checking to enforce pre- and post-conditions on pb.pb
. See the docstring for DebugPullback
for details.
Mooncake.DebugPullback
— TypeDebugPullback(pb, y, x)
Construct a callable which is equivalent to pb
, but which enforces type-based pre- and post-conditions to pb
. Let dx = pb.pb(dy)
, for some rdata dy
, then this function
- checks that dy has the correct rdata type for y, and
- checks that each element of dx has the correct rdata type for x.
Reverse-pass counterpart to DebugRRule.
Mooncake.DebugRRule
— Method(rule::DebugRRule)(x::CoDual...)
Apply type checking to enforce pre- and post-conditions on rule.rule
. See the docstring for DebugRRule
for details.
Mooncake.DebugRRule
— TypeDebugRRule(rule)
Construct a callable which is equivalent to rule
, but inserts additional type checking. In particular:
- check that the fdata in each argument is of the correct type for the primal
- check that the fdata in the CoDual returned from the rule is of the correct type for the primal.
This happens recursively. For example, each element of a Vector{Any}
is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.
Some additional dynamic checks are also performed (e.g. that an fdata array is of the same size as its primal).
Let rule
return y, pb!!
, then DebugRRule(rule)
returns y, DebugPullback(pb!!)
. DebugPullback
inserts the same kind of checks as DebugRRule
, but on the reverse-pass. See the docstring for details.
Note: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.
Mooncake.DefaultCtx
— Typestruct DefaultCtx end
Context for all usually used AD primitives. Anything which is a primitive in a MinimalCtx is a primitive in the DefaultCtx automatically. If you are adding a rule for the sake of performance, it should be a primitive in the DefaultCtx, but not the MinimalCtx.
Mooncake.DynamicDerivedRule
— TypeDynamicDerivedRule(interp::MooncakeInterpreter, debug_mode::Bool)
For internal use only.
A callable data structure which, when invoked, calls an rrule specific to the dynamic types of its arguments. Stores rules in an internal cache to avoid re-deriving.
This is used to implement dynamic dispatch.
Mooncake.FData
— TypeFData(data::NamedTuple)
The component of a struct
which is propagated alongside the primal on the forwards-pass of AD. For example, the tangents for Float64
s do not need to be propagated on the forwards- pass of reverse-mode AD, so any Float64
fields of Tangent
do not need to appear in the associated FData
.
Mooncake.ID
— TypeID()
An ID
(read: unique name) is just a wrapper around an Int32
. Uniqueness is ensured via a global counter, which is incremented each time that an ID
is created.
This counter can be reset using seed_id!
if you need to ensure deterministic ID
s are produced, in the same way that the seed for random number generators can be set.
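The counter-based scheme can be sketched as follows. ToyID, TOY_COUNTER, and toy_seed_id! are hypothetical names chosen for illustration; this is an assumed re-creation of the idea rather than Mooncake's implementation:

```julia
# Hypothetical re-implementation of the idea; names are illustrative only.
const TOY_COUNTER = Ref{Int32}(0)

struct ToyID
    id::Int32
    # Each construction takes the next value from the global counter.
    function ToyID()
        TOY_COUNTER[] += Int32(1)
        return new(TOY_COUNTER[])
    end
end

# Reset the counter, analogous in spirit to `seed_id!`.
toy_seed_id!(x::Integer) = (TOY_COUNTER[] = Int32(x); nothing)

toy_seed_id!(0)
a, b = ToyID(), ToyID()   # two distinct names
toy_seed_id!(0)
c = ToyID()               # same name as `a`, because the counter was reset
```

Resetting the counter trades uniqueness across resets for reproducibility, which is typically what is wanted in tests.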
Mooncake.IDGotoIfNot
— TypeIDGotoIfNot(cond::Any, dest::ID)
Like a GotoIfNot
, but dest
is an ID
rather than an Int64
.
Mooncake.IDGotoNode
— TypeIDGotoNode(label::ID)
Like a GotoNode
, but label
is an ID
rather than an Int64
.
Mooncake.IDInstPair
— Typeconst IDInstPair = Tuple{ID, NewInstruction}
Mooncake.IDPhiNode
— TypeIDPhiNode(edges::Vector{ID}, values::Vector{Any})
Like a PhiNode
, but edges
are ID
s rather than Int32
s.
Mooncake.InstVector
— Typeconst InstVector = Vector{NewInstruction}
Note: the CC.NewInstruction
type is used to represent instructions because it has the correct fields. While it is only used to represent new instructions in Core.Compiler
, it is used to represent all instructions in BBCode
.
Mooncake.InvalidFDataException
— TypeInvalidFDataException(msg::String)
Exception indicating that there is a problem with the fdata associated to a primal.
Mooncake.InvalidRDataException
— TypeInvalidRDataException(msg::String)
Exception indicating that there is a problem with the rdata associated to a primal.
Mooncake.LazyDerivedRule
— TypeLazyDerivedRule(interp, mi::Core.MethodInstance, debug_mode::Bool)
For internal use only.
A type-stable wrapper around a DerivedRule
, which only instantiates the DerivedRule
when it is first called. This is useful, as it means that if a rule does not get run, it does not have to be derived.
If debug_mode
is true
, then the rule constructed will be a DebugRRule
. This is useful when debugging, but should usually be switched off for production code as it (in general) incurs some runtime overhead.
Note: the signature of the primal for which this is a rule is stored in the type. The only reason to keep this around is for debugging – it is very helpful to have this type visible in the stack trace when something goes wrong, as it allows you to trivially determine which bit of your code is the culprit.
Mooncake.LazyZeroRData
— TypeLazyZeroRData{P, Tdata}()
This type is a lazy placeholder for zero_like_rdata_from_type
. This is used to defer construction of zero data to the reverse pass. Calling instantiate
on an instance of this will construct a zero data.
Users should construct using LazyZeroRData(p)
, where p
is a value of type P
. This constructor, and instantiate
, are specialised to minimise the amount of data which must be stored. For example, Float64
s do not need any data, so LazyZeroRData(0.0)
produces an instance of a singleton type, meaning that various important optimisations can be performed in AD.
Mooncake.MinimalCtx
— Typestruct MinimalCtx end
Functions should only be primitives in this context if not making them so would cause AD to fail. In particular, do not add primitives to this context if you are writing them for performance only – instead, make these primitives in the DefaultCtx.
Mooncake.NoPullback
— MethodNoPullback(args::CoDual...)
Construct a NoPullback
from the arguments passed to an rrule!!
. For each argument, extracts the primal value, and constructs a LazyZeroRData
. These are stored in a NoPullback
which, in the reverse-pass of AD, instantiates these LazyZeroRData
s and returns them in order to perform the reverse-pass of AD.
The advantage of this approach is that if it is possible to construct the zero rdata element for each of the arguments lazily, the NoPullback
generated will be a singleton type. This means that AD can avoid generating a stack to store this pullback, which can result in significant performance improvements.
Mooncake.RRuleZeroWrapper
— TypeRRuleZeroWrapper(rule)
This struct is used to ensure that ZeroRData
s, which are used as placeholder zero elements whenever an actual instance of a zero rdata for a particular primal type cannot be constructed without also having an instance of said type, never reach rules. On the pullback, we increment the cotangent dy by an amount equal to zero. This ensures that if it is a ZeroRData
, we instead get an actual zero of the correct type. If it is not a zero rdata, the computation should be elided via inlining + constant prop.
Mooncake.SharedDataPairs
— TypeSharedDataPairs()
A data structure used to manage the captured data in the OpaqueClosures
which implement the bulk of the forwards- and reverse-passes of AD. An entry (id, data)
at element n
of the pairs
field of this data structure means that data
will be available at register id
during the forwards- and reverse-passes of AD
.
This is achieved by storing all of the data in the pairs
field in the captured tuple which is passed to an OpaqueClosure
, and extracting this data into registers associated to the corresponding ID
s.
Mooncake.Stack
— TypeStack{T}()
A stack specialised for reverse-mode AD.
Semantically equivalent to a usual stack, but never de-allocates memory once allocated.
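One way to obtain this behaviour, sketched under the assumption of a backing vector plus a logical position (ToyStack, toy_push!, and toy_pop! are hypothetical names, and the real Stack may be implemented differently):

```julia
# Hypothetical sketch: keep a backing vector and a logical position.
mutable struct ToyStack{T}
    memory::Vector{T}
    position::Int
    ToyStack{T}() where {T} = new{T}(Vector{T}(undef, 0), 0)
end

function toy_push!(s::ToyStack{T}, x::T) where {T}
    s.position += 1
    if s.position > length(s.memory)
        push!(s.memory, x)          # grow only when capacity is exhausted
    else
        s.memory[s.position] = x    # otherwise reuse existing storage
    end
    return nothing
end

function toy_pop!(s::ToyStack)
    x = s.memory[s.position]
    s.position -= 1                 # shrink logically; memory is retained
    return x
end

s = ToyStack{Float64}()
toy_push!(s, 1.0); toy_push!(s, 2.0)
top = toy_pop!(s)
capacity_after_pop = length(s.memory)
```

Because popping only moves the position, repeated push / pop cycles of the kind seen in loops reuse the same storage.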
Mooncake.Switch
— TypeSwitch(conds::Vector{Any}, dests::Vector{ID}, fallthrough_dest::ID)
A switch-statement node. These can be inserted in the BBCode
representation of Julia IR. Switch
has the following semantics:
goto dests[1] if not conds[1]
goto dests[2] if not conds[2]
...
goto dests[N] if not conds[N]
goto fallthrough_dest
where the value associated to each element of conds
is a Bool
, and dests
indicate which block to jump to. If none of the conditions are met, then we go to whichever block is specified by fallthrough_dest
.
Switch
statements are lowered into the above sequence of GotoIfNot
s and GotoNode
s when converting BBCode
back into IRCode
, because Switch
statements are not valid nodes in regular Julia IR.
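The semantics above can be mimicked by a small function which returns the destination that would be jumped to. toy_switch_dest is a hypothetical helper, and Symbols stand in for block IDs:

```julia
# Hypothetical helper: returns the destination a Switch would jump to.
function toy_switch_dest(conds::Vector{Bool}, dests::Vector{Symbol}, fallthrough::Symbol)
    for (cond, dest) in zip(conds, dests)
        # "goto dest if not cond": jump as soon as a condition is false.
        cond || return dest
    end
    return fallthrough
end

d1 = toy_switch_dest([true, false, true], [:a, :b, :c], :z)   # second cond is false
d2 = toy_switch_dest([true, true, true], [:a, :b, :c], :z)    # no jump taken
```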
Mooncake.UnhandledLanguageFeatureException
— TypeUnhandledLanguageFeatureException(message::String)
An exception used to indicate that some aspect of the Julia language which AD cannot handle has been encountered.
Mooncake.ZeroRData
— TypeZeroRData()
Singleton type indicating zero-valued rdata. This should only ever appear as an intermediate quantity in the reverse-pass of AD when the type of the primal is not fully inferable, or a field of a type is abstractly typed.
If you see this anywhere in actual code, or if it appears in a hand-written rule, this is an error – please open an issue in such a situation.
Base.insert!
— MethodBase.insert!(bb::BBlock, n::Int, id::ID, stmt::CC.NewInstruction)::Nothing
Inserts stmt
and id
into bb
immediately before the n
th instruction.
Mooncake.__deref_and_zero
— Method__deref_and_zero(::Type{P}, x::Ref) where {P}
Helper, used in conclude_rvs_block.
Mooncake.__flatten_varargs
— Method__flatten_varargs(isva::Bool, args, ::Val{nvargs}) where {nvargs}
If isva, inputs (5.0, (4.0, 3.0)) are transformed into (5.0, 4.0, 3.0).
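A minimal sketch of this transformation, ignoring the Val{nvargs} argument (toy_flatten_varargs is a hypothetical name, not the function documented here):

```julia
# Hypothetical sketch: splat the final (vararg) tuple into the argument list.
toy_flatten_varargs(isva::Bool, args::Tuple) =
    isva ? (args[1:end-1]..., args[end]...) : args

flat = toy_flatten_varargs(true, (5.0, (4.0, 3.0)))
unchanged = toy_flatten_varargs(false, (5.0, (4.0, 3.0)))
```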
Mooncake.__get_value
— Method__get_value(edge::ID, x::IDPhiNode)
Helper functionality for conclude_rvs_block.
Mooncake.__insts_to_instruction_stream
— Method__insts_to_instruction_stream(insts::Vector{Any})
Produces an instruction stream whose
- stmt (v1.11 and up) / inst (v1.10) field is insts,
- type field is all Any,
- info field is all Core.Compiler.NoCallInfo,
- line field is all Int32(1), and
- flag field is all Core.Compiler.IR_FLAG_REFINED.
As such, if you wish to ensure that your IRCode
prints nicely, you should ensure that its linetable field has at least one element.
Mooncake.__line_numbers_to_block_numbers!
— Method__line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG)
Converts any edges in GotoNode
s, GotoIfNot
s, PhiNode
s, and :enter
expressions which refer to line numbers into references to block numbers. The cfg
provides the information required to perform this conversion.
For context, CodeInfo
objects have references to line numbers, while IRCode
uses block numbers.
This code is copied over directly from the body of Core.Compiler.inflate_ir!
.
Mooncake.__make_ref
— Method__make_ref(p::Type{P}) where {P}
Helper for reverse_data_ref_stmts
. Constructs a Ref
whose element type is the zero_like_rdata_type
for P
, and whose element is the zero-like rdata for P
.
Mooncake.__pop_blk_stack!
— Method__pop_blk_stack!(block_stack::BlockStack)
Equivalent to pop!(block_stack)
. Going via this function, rather than just calling pop!
directly, makes it easy to figure out how much time is spent popping the block stack when profiling performance, and to know that this function was hit when debugging.
Mooncake.__push_blk_stack!
— Method__push_blk_stack!(block_stack::BlockStack, id::Int32)
Equivalent to push!(block_stack, id)
. Going via this function, rather than just calling push! directly, is helpful for debugging and performance analysis – it makes it very straightforward to figure out how much time is spent pushing to the block stack when profiling.
Mooncake.__run_rvs_pass!
— Method__run_rvs_pass!(
P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs...
) where {sig}
Used in make_ad_stmts!
method for Expr(:call, ...)
and Expr(:invoke, ...)
.
Mooncake.__switch_case
— Method__switch_case(id::Int32, predecessor_id::Int32)
Helper function emitted by make_switch_stmts
.
Mooncake.__unflatten_codual_varargs
— Method__unflatten_codual_varargs(isva::Bool, args, ::Val{nargs}) where {nargs}
If isva and nargs=2, then inputs (CoDual(5.0, 0.0), CoDual(4.0, 0.0), CoDual(3.0, 0.0))
are transformed into (CoDual(5.0, 0.0), CoDual((4.0, 3.0), (0.0, 0.0)))
.
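A sketch of the grouping step, assuming that the trailing arguments (from position nargs onwards) are the ones collected into a single tuple-valued CoDual, mirroring __flatten_varargs in reverse. ToyCoDual and toy_unflatten are hypothetical stand-ins, and nargs is passed as a plain Int rather than a Val:

```julia
# Hypothetical stand-in for Mooncake's CoDual: a primal paired with extra data.
struct ToyCoDual{P,D}
    primal::P
    data::D
end

# Keep the first `nargs - 1` arguments, and group the rest into one ToyCoDual
# whose primal and data are tuples.
function toy_unflatten(isva::Bool, args::Tuple, nargs::Int)
    isva || return args
    head = args[1:(nargs - 1)]
    tail = args[nargs:end]
    grouped = ToyCoDual(map(a -> a.primal, tail), map(a -> a.data, tail))
    return (head..., grouped)
end

out = toy_unflatten(
    true, (ToyCoDual(5.0, 0.0), ToyCoDual(4.0, 0.0), ToyCoDual(3.0, 0.0)), 2
)
```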
Mooncake.__value_and_gradient!!
— Method__value_and_gradient!!(rule, f::CoDual, x::CoDual...)
Note: this is not part of the public Mooncake.jl interface, and may change without warning.
Equivalent to __value_and_pullback!!(rule, 1.0, f, x...)
– assumes f
returns a Float64
.
# Set up the problem.
f(x, y) = sum(x .* y)
x = [2.0, 2.0]
y = [1.0, 1.0]
rule = build_rrule(f, x, y)
# Allocate tangents. These will be written to in-place. You are free to re-use these if you
# compute gradients multiple times.
tf = zero_tangent(f)
tx = zero_tangent(x)
ty = zero_tangent(y)
# Do AD.
Mooncake.__value_and_gradient!!(
    rule, Mooncake.CoDual(f, tf), Mooncake.CoDual(x, tx), Mooncake.CoDual(y, ty)
)
# output
(4.0, (NoTangent(), [1.0, 1.0], [2.0, 2.0]))
Mooncake.__value_and_pullback!!
— Method__value_and_pullback!!(rule, ȳ, f::CoDual, x::CoDual...; y_cache=nothing)
Note: this is not part of the public Mooncake.jl interface, and may change without warning.
In-place version of value_and_pullback!!
in which the arguments have been wrapped in CoDual
s. Note that any mutable data in f
and x
will be incremented in-place. As such, if calling this function multiple times with different values of x
, you should be careful to ensure that you zero-out the tangent fields of x
each time.
Mooncake._block_nums_to_ids
— Method_block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID}, InstVector}
Assign to each basic block in cfg
an ID
. Replace all integers referencing block numbers in insts
with the corresponding ID
. Return the ID
s and the updated instructions.
Mooncake._build_graph_of_cfg
— Method_build_graph_of_cfg(blks::Vector{BBlock})::Tuple{SimpleDiGraph, Dict{ID, Int}}
Builds a SimpleDiGraph
, g
, representing the CFG associated to blks
, where blks
comprises the collection of basic blocks associated to a BBCode
. This is a type from Graphs.jl, so constructing g
makes it straightforward to analyse the control flow structure of ir
using algorithms from Graphs.jl.
Returns a 2-tuple, whose first element is g
, and whose second element is a map from the ID
associated to each basic block in ir
, to the Int
corresponding to its node index in g
.
Mooncake._compute_all_predecessors
— Method_compute_all_predecessors(blks::Vector{BBlock})::Dict{ID, Vector{ID}}
Internal method implementing compute_all_predecessors
. This method is easier to construct test cases for because it only requires the collection of BBlocks
, not all of the other stuff that goes into a BBCode
.
Mooncake._compute_all_successors
— Method_compute_all_successors(blks::Vector{BBlock})::Dict{ID, Vector{ID}}
Internal method implementing compute_all_successors
. This method is easier to construct test cases for because it only requires the collection of BBlocks
, not all of the other stuff that goes into a BBCode
.
Mooncake._control_flow_graph
— Method_control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG
Internal function, used to implement control_flow_graph
. Easier to write test cases for because there is no need to construct an entire BBCode object, just the BBlock
s.
Mooncake._distance_to_entry
— Method_distance_to_entry(blks::Vector{BBlock})::Vector{Int}
For each basic block in blks
, compute the distance from it to the entry point (the first block). The distance is typemax(Int)
if there is no path from the entry point to a given node.
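One standard way to compute such distances is a breadth-first search from the entry block. The sketch below works on a plain adjacency list rather than on BBlocks, and takes the entry's own distance to be 0; both are assumptions made for illustration, and toy_distance_to_entry is a hypothetical name:

```julia
# Hypothetical sketch: `succs[n]` lists the successors of block `n`; block 1 is the entry.
function toy_distance_to_entry(succs::Vector{Vector{Int}})
    dist = fill(typemax(Int), length(succs))
    dist[1] = 0
    queue = [1]
    while !isempty(queue)
        n = popfirst!(queue)
        for m in succs[n]
            if dist[m] == typemax(Int)   # not yet visited
                dist[m] = dist[n] + 1
                push!(queue, m)
            end
        end
    end
    return dist
end

# Block 4 has no incoming path from the entry block.
dists = toy_distance_to_entry([[2, 3], [3], Int[], [3]])
```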
Mooncake._find_id_uses!
— Method_find_id_uses!(d::Dict{ID, Bool}, x)
Helper function used in characterise_used_ids
. For all uses of ID
s in x
, set the corresponding value of d
to true
.
For example, if x = ReturnNode(ID(5))
, then this function sets d[ID(5)] = true
.
Mooncake._foreigncall_
— Methodfunction _foreigncall_(
::Val{name}, ::Val{RT}, AT::Tuple, ::Val{nreq}, ::Val{calling_convention}, x...
) where {name, RT, nreq, calling_convention}
:foreigncall nodes get translated into calls to this function. For example,
Expr(:foreigncall, :foo, Tout, (A, B), nreq, :ccall, args...)
becomes
_foreigncall_(Val(:foo), Val(Tout), (Val(A), Val(B)), Val(nreq), Val(:ccall), args...)
Please consult the Julia documentation for more information on how foreigncall nodes work, and consult this package's tests for examples.
Credit: Umlaut.jl has the original implementation of this function. This is largely copied over from there.
Mooncake._ids_to_line_numbers
— Method_ids_to_line_numbers(bb_code::BBCode)::InstVector
For each statement in bb_code
, returns a NewInstruction
in which every ID
is replaced by either an SSAValue
, or an Int64
/ Int32
which refers to an SSAValue
.
Mooncake._is_reachable
— Method_is_reachable(blks::Vector{BBlock})::Vector{Bool}
Computes a Vector
whose length is length(blks)
. The n
th element is true
iff it is possible for control flow to reach the n
th block.
Mooncake._lines_to_blocks
— Method_lines_to_blocks(insts::InstVector, cfg::CC.CFG)::InstVector
Pulls out the instructions from insts
, and calls __line_numbers_to_block_numbers!
.
Mooncake._lower_switch_statements
— Method_lower_switch_statements(bb_code::BBCode)
Converts all Switch
s into a semantically-equivalent collection of GotoIfNot
s. See the Switch
docstring for an explanation of what is going on here.
Mooncake._map
— Method_map(f, x...)
Same as map
but requires all elements of x
to have equal length. The usual function map
doesn't enforce this for Array
s.
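A length-checked map could look roughly like this. toy_map is a hypothetical name, and the choice of ArgumentError is an assumption; the error actually thrown by _map is not stated here:

```julia
# Hypothetical sketch of a length-checked map (assumes at least one collection).
function toy_map(f, x...)
    lens = map(length, x)
    all(==(first(lens)), lens) ||
        throw(ArgumentError("all arguments must have equal length"))
    return map(f, x...)
end

ok = toy_map(+, [1, 2, 3], [10, 20, 30])

# Mismatched lengths are rejected rather than silently handled.
threw = try
    toy_map(+, [1, 2, 3], [10, 20])
    false
catch err
    err isa ArgumentError
end
```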
Mooncake._map_if_assigned!
— Method_map_if_assigned!(f::F, y::DenseArray, x1::DenseArray{P}, x2::DenseArray)
Similar to the other method of _map_if_assigned!
– for all n
, if x1[n]
is assigned, writes f(x1[n], x2[n])
to y[n]
, otherwise leaves y[n]
unchanged.
Requires that y
, x1
, and x2
have the same size.
Mooncake._map_if_assigned!
— Method_map_if_assigned!(f, y::DenseArray, x::DenseArray{P}) where {P}
For all n
, if x[n]
is assigned, then writes the value returned by f(x[n])
to y[n]
, otherwise leaves y[n]
unchanged.
Equivalent to map!(f, y, x)
if P
is a bits type, as every element will always be assigned.
Requires that y
and x
have the same size.
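The single-input behaviour can be sketched with Base's isassigned. toy_map_if_assigned! is a hypothetical name, restricted to Vectors for brevity:

```julia
# Hypothetical sketch of the single-input method.
function toy_map_if_assigned!(f, y::Vector, x::Vector)
    size(y) == size(x) || throw(DimensionMismatch("y and x must have the same size"))
    for n in eachindex(x)
        if isassigned(x, n)
            y[n] = f(x[n])   # only touch y[n] when x[n] holds a value
        end
    end
    return y
end

x = Vector{Vector{Float64}}(undef, 3)   # elements start unassigned
x[1] = [1.0]
x[3] = [3.0, 4.0]
y = Any[:untouched, :untouched, :untouched]
toy_map_if_assigned!(sum, y, x)
```

Here y[2] keeps its previous value because x[2] was never assigned.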
Mooncake._new_
— Method_new_(::Type{T}, x::Vararg{Any, N}) where {T, N}
One-liner which calls the :new
instruction with type T
with arguments x
.
Mooncake._remove_double_edges
— Method_remove_double_edges(ir::BBCode)::BBCode
If the dest
field of an IDGotoIfNot
node in block n
of ir
points towards the n+1
th block then we have two edges from block n
to block n+1
. This transformation replaces all such IDGotoIfNot
nodes with unconditional IDGotoNode
s pointing towards the n+1
th block in ir
.
Mooncake._sort_blocks!
— Method_sort_blocks!(ir::BBCode)::BBCode
Ensure that blocks appear in order of distance-from-entry-point, where the distance from block b to the entry point is defined to be the minimum number of basic blocks that must be passed through in order to reach b.
For reasons unknown (to me, Will), the compiler / optimiser needs this for inference to succeed. Since we do quite a lot of re-ordering on the reverse-pass of AD, this is a problem there.
WARNING: use with care. Only use if you are confident that arbitrary re-ordering of basic blocks in ir
is valid. Notably, this does not hold if you have any IDGotoIfNot
nodes in ir
.
Mooncake._splat_new_
— Method_splat_new_(::Type{P}, x::Tuple) where {P}
Function which replaces instances of :splatnew
.
Mooncake._ssa_to_ids
— Method_ssa_to_ids(d::SSAToIdDict, inst::NewInstruction)
Produce a new instance of inst
in which all instances of SSAValue
s are replaced with the ID
s prescribed by d
, all basic block numbers are replaced with the ID
s prescribed by d
, and GotoIfNot
, GotoNode
, and PhiNode
instances are replaced with the corresponding ID
versions.
Mooncake._ssas_to_ids
— Method_ssas_to_ids(insts::InstVector)::Tuple{Vector{ID}, InstVector}
Assigns an ID to each line in insts
, and replaces each instance of an SSAValue
in each line with the corresponding ID
. For example, a call statement of the form Expr(:call, :f, %4)
is replaced with Expr(:call, :f, id_assigned_to_%4)
.
Mooncake._to_ssas
— Method_to_ssas(d::Dict, inst::NewInstruction)
Like _ssas_to_ids
, but in reverse. Converts IDs to SSAValues / (integers corresponding to ssas).
Mooncake._typeof
— Method_typeof(x)
Central definition of typeof, which is specific to the use-required in this package.
Mooncake.ad_stmt_info
— Methodad_stmt_info(line::ID, comms_id::Union{ID, Nothing}, fwds, rvs)
Convenient constructor for ADStmtInfo
. If either fwds
or rvs
is not a vector, __vec
promotes it to a single-element Vector
.
Mooncake.add_data!
— Methodadd_data!(info::ADInfo, data)::ID
Equivalent to add_data!(info.shared_data_pairs, data)
.
Mooncake.add_data!
— Methodadd_data!(p::SharedDataPairs, data)::ID
Puts data
into p
, and returns the id
associated to it. This id
should be assumed to be available during the forwards- and reverse-passes of AD, and it should further be assumed that the value associated to this id
is always data
.
Mooncake.add_data_if_not_singleton!
— Methodadd_data_if_not_singleton!(p::Union{ADInfo, SharedDataPairs}, x)
Returns x
if it is a singleton, or the ID
of the ssa which will contain it on the forwards- and reverse-passes. The reason for this is that if something is a singleton, it can be inserted directly into the IR.
Mooncake.arrayify
— Methodarrayify(x::CoDual{<:AbstractArray{<:BlasRealFloat}})
Return the primal field of x
, and convert its fdata into an array of the same type as the primal. This operation is not guaranteed to be possible for all array types, but seems to be possible for all array types of interest so far.
Mooncake.can_produce_zero_rdata_from_type
— Methodcan_produce_zero_rdata_from_type(::Type{P}) where {P}
Returns whether or not the zero element of the rdata type for primal type P
can be obtained from P
alone.
Mooncake.characterise_unique_predecessor_blocks
— Methodcharacterise_unique_predecessor_blocks(blks::Vector{BBlock}) ->
Tuple{Dict{ID, Bool}, Dict{ID, Bool}}
We call a block b
a unique predecessor in the control flow graph associated to blks
if it is the only predecessor to all of its successors. Put differently we call b
a unique predecessor if, whenever control flow arrives in any of the successors of b
, we know for certain that the previous block must have been b
.
Returns two Dict
s. A value in the first Dict
is true
if the block associated to its key is a unique predecessor, and is false
if not. A value in the second Dict
is true
if it has a single predecessor, and that predecessor is a unique predecessor.
Context:
This information is important for optimising AD because knowing that b
is a unique predecessor means that
- on the forwards-pass, there is no need to push the ID of b to the block stack when passing through it, and
- on the reverse-pass, there is no need to pop the block stack when passing through one of the successors to b.
Utilising this reduces the overhead associated to doing AD. It is quite important when working with cheap loops – loops where the operations performed at each iteration are inexpensive – for which minimising memory pressure is critical to performance. It is also important for single-block functions, because it can be used to entirely avoid using a block stack at all.
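The two definitions can be sketched on a toy control flow graph given as a successor map. toy_unique_predecessors is a hypothetical name, Symbols stand in for block IDs, and the treatment of blocks with no successors (vacuously unique predecessors here) is an assumption:

```julia
# Hypothetical sketch: `succs` maps each block name to its successors.
function toy_unique_predecessors(succs::Dict{Symbol,Vector{Symbol}})
    # Invert the successor map to obtain predecessors.
    preds = Dict(k => Symbol[] for k in keys(succs))
    for (b, ss) in succs, s in ss
        push!(preds[s], b)
    end
    # `b` is a unique predecessor if every successor has `b` as its only predecessor.
    is_unique_pred = Dict(b => all(s -> preds[s] == [b], ss) for (b, ss) in succs)
    # A block "has a unique predecessor" if it has exactly one predecessor, and
    # that predecessor is itself a unique predecessor.
    has_unique_pred = Dict(
        b => (length(ps) == 1 && is_unique_pred[only(ps)]) for (b, ps) in preds
    )
    return is_unique_pred, has_unique_pred
end

# a -> b -> d and a -> c -> d: block `d` has two predecessors.
succs = Dict(:a => [:b, :c], :b => [:d], :c => [:d], :d => Symbol[])
is_unique_pred, has_unique_pred = toy_unique_predecessors(succs)
```

In this example :a is a unique predecessor, so under the scheme described above neither a push at :a nor a pop at :b or :c would be needed, whereas :d still needs the block stack to tell which of :b and :c preceded it.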
Mooncake.characterise_used_ids
— Methodcharacterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID, Bool}
For each line in stmts
, determine whether it is referenced anywhere else in the code. Returns a dictionary containing the results. An element is false
if the corresponding ID
is unused, and true
if it is used.
Mooncake.collect_stmts
— Methodcollect_stmts(ir::BBCode)::Vector{IDInstPair}
Produce a Vector
containing all of the statements in ir
. These are returned in order, so it is safe to assume that element n
refers to the nth
element of the IRCode
associated to ir
.
Mooncake.collect_stmts
— Methodcollect_stmts(bb::BBlock)::Vector{IDInstPair}
Returns a Vector
containing the ID
s and instructions associated to each line in bb
. These should be assumed to be ordered.
Mooncake.comms_channel
— Methodcomms_channel(info::ADStmtInfo)
Return the element of fwds
whose ID
is the communication ID
. Returns Nothing
if comms_id
is nothing
.
Mooncake.compute_all_predecessors
— Methodcompute_all_predecessors(ir::BBCode)::Dict{ID, Vector{ID}}
Compute a map from the ID
of each BBlock
in ir
to its possible predecessors.
Mooncake.compute_all_successors
— Methodcompute_all_successors(ir::BBCode)::Dict{ID, Vector{ID}}
Compute a map from the ID
of each BBlock
in ir
to its possible successors.
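To give a feel for what these two maps contain, here is a small sketch which derives successors from the final statement of each block, and then inverts the result to get predecessors. It is an assumption-laden toy: blocks are numbered with Ints instead of IDs, each block is represented only by its last statement, and toy_successors is a hypothetical helper rather than anything from Mooncake.

```julia
# Each toy block is represented only by its final statement. Blocks are numbered 1..4.
terminators = Any[
    Core.GotoIfNot(true, 3),   # block 1: jump to block 3 if the condition is false, else fall through
    Core.GotoNode(4),          # block 2: unconditional jump to block 4
    nothing,                   # block 3: no terminator, falls through to block 4
    Core.ReturnNode(nothing),  # block 4: returns, so has no successors
]

toy_successors(t::Core.GotoNode, n::Int) = [t.label]
toy_successors(t::Core.GotoIfNot, n::Int) = [n + 1, t.dest]
toy_successors(::Core.ReturnNode, n::Int) = Int[]
toy_successors(::Any, n::Int) = [n + 1]  # fall through to the next block

succs = Dict(n => toy_successors(t, n) for (n, t) in enumerate(terminators))

# Predecessors are obtained by inverting the successor map.
preds = Dict(n => Int[] for n in keys(succs))
for n in sort(collect(keys(succs))), s in succs[n]
    push!(preds[s], n)
end
```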
Mooncake.conclude_rvs_block
— Methodconclude_rvs_block(
blk::BBlock, pred_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo
)
Generates code which is inserted at the end of each counterpart block in the reverse-pass. Handles phi nodes, and choosing the correct next block to switch to.
Mooncake.const_ad_stmt
— Methodconst_ad_stmt(stmt, line::ID, info::ADInfo)
Implementation of make_ad_stmts!
used for constants.
Mooncake.const_codual
— Methodconst_codual(stmt, info::ADInfo)
Build a CoDual
from stmt
, with zero / uninitialised fdata. If the resulting CoDual is a bits type, then it is returned. If it is not, then the CoDual is put into shared data, and the ID associated to it in the forwards- and reverse-passes returned.
Mooncake.const_codual_stmt
— Methodconst_codual_stmt(stmt, info::ADInfo)
Returns a :call
expression which will return a CoDual
whose primal is stmt
, and whose tangent is whatever uninit_tangent
returns.
Mooncake.control_flow_graph
— Methodcontrol_flow_graph(bb_code::BBCode)::Core.Compiler.CFG
Computes the Core.Compiler.CFG
object associated to this bb_code
.
Mooncake.create_comms_insts!
— Methodcreate_comms_insts!(ad_stmts_blocks::ADStmts, info::ADInfo)
This function produces code which can be inserted into the forwards-pass and reverse-pass at specific locations to implement the promise associated to the comms_id
field of the ADStmtInfo
type – namely that if you assign a value to comms_id
on the forwards-pass, the same value will be available at comms_id
on the reverse-pass.
For each basic block represented in ADStmts
:
- create a stack containing a Tuple which can hold all of the values associated to the comms_ids for each statement. Put this stack in shared data.
- create instructions which can be inserted at the end of the block generated to perform the forwards-pass (in forwards_pass_ir) which will put all of the data associated to the comms_ids into shared data, and
- create instructions which can be inserted at the start of the block generated to perform the reverse-pass (in pullback_ir), which will extract all of the data put into shared data by the instructions generated by the previous point, and assign them to the comms_ids.
Returns a Tuple{Vector{IDInstPair}, Vector{IDInstPair}}
. The nth element of each Vector
corresponds to the instructions to be inserted into the forwards- and reverse passes resp. for the nth block in ad_stmts_blocks
.
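The promise described above can be pictured with an ordinary stack. The sketch below is only an analogy written in plain Julia: toy_fwds_block! and toy_rvs_block! are hypothetical stand-ins for the instructions inserted at the end of a forwards-pass block and the start of the corresponding reverse-pass block, and the element type of the stack is chosen arbitrarily.

```julia
# One stack per block. Each entry is a tuple holding every value that the block
# communicates from the forwards-pass to the reverse-pass.
comms_stack = Tuple{Float64,Int}[]

# End of the forwards-pass block: save the values.
toy_fwds_block!(stack, a, b) = push!(stack, (a, b))
# Start of the reverse-pass block: restore them.
toy_rvs_block!(stack) = pop!(stack)

# A block which is executed twice (e.g. inside a loop) pushes twice...
toy_fwds_block!(comms_stack, 1.0, 1)
toy_fwds_block!(comms_stack, 2.0, 2)

# ...and the reverse-pass sees the values in reverse order of execution.
from_second_visit = toy_rvs_block!(comms_stack)
from_first_visit = toy_rvs_block!(comms_stack)
```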
Mooncake.fcodual_type
— Methodfcodual_type(P::Type)
The type of the CoDual
which contains instances of P
and its fdata.
Mooncake.fdata_field_type
— Methodfdata_field_type(::Type{P}, n::Int) where {P}
Returns the type of the nth field of the fdata type associated to P
. Will be a PossiblyUninitTangent
if said field can be undefined.
Mooncake.fix_up_invoke_inference!
— Methodfix_up_invoke_inference!(ir::IRCode)
The Problem
Consider the following:
@noinline function bar!(x)
x .*= 2
end
function foo!(x)
bar!(x)
return nothing
end
In this case, the IR associated to Tuple{typeof(foo!), Vector{Float64}}
will be something along the lines of
julia> Base.code_ircode_by_type(Tuple{typeof(foo!), Vector{Float64}})
1-element Vector{Any}:
2 1 ─ invoke Main.bar!(_2::Vector{Float64})::Any
3 └── return Main.nothing
=> Nothing
Observe that the type inferred for the first line is Any
. Inference is at liberty to do this without any risk of performance problems because the first line is not used anywhere else in the function. Had this line been used elsewhere in the function, inference would have inferred its type to be Vector{Float64}
.
This causes performance problems for Mooncake, because it uses the return type to do various things, including allocating storage for quantities required on the reverse-pass. Consequently, inference inferring Any
rather than Vector{Float64}
causes type instabilities in the code that Mooncake generates, which can have catastrophic consequences for performance.
The Solution
:invoke
expressions contain the Core.MethodInstance
associated to them, which contains a Core.CodeCache
, which contains the return type of the :invoke
. This function looks for :invoke
statements whose return type is inferred to be Any
in ir
, and modifies it to be the return type given by the code cache.
Mooncake.flat_product
— Methodflat_product(xs...)
Equivalent to vec(collect(Iterators.product(xs...)))
.
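For a concrete picture of the ordering this produces, the sketch below defines a local stand-in with the behaviour stated above. toy_flat_product is a name used only for this example, so that it runs without Mooncake loaded.

```julia
# Stand-in with the behaviour described above: every combination of elements,
# flattened into a Vector in column-major order (the first iterator varies fastest).
toy_flat_product(xs...) = vec(collect(Iterators.product(xs...)))

combos = toy_flat_product([1, 2], (:a, :b))
```

Here combos contains four tuples, starting with (1, :a) and (2, :a).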
Mooncake.foreigncall_to_call
— Methodforeigncall_to_call(inst, sp_map::Dict{Symbol, CC.VarState})
If inst
is a :foreigncall
expression translate it into an equivalent :call
expression. If anything else, just return inst
. See Mooncake._foreigncall_
for details.
sp_map
maps the names of the static parameters to their values. This function is intended to be called in the context of an IRCode
, in which case the values of sp_map
are given by the sptypes
field of said IRCode
. The keys should generally be obtained from the Method
from which the IRCode
is derived. See Mooncake.normalise!
for more details.
The purpose of this transformation is to make it possible to differentiate :foreigncall
expressions in the same way as a primitive :call
expression, i.e. via an rrule!!
.
Mooncake.forwards_pass_ir
— Methodforwards_pass_ir(ir::BBCode, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)
Produce the IR associated to the OpaqueClosure
which runs most of the forwards-pass.
Mooncake.fwd_ir
— Methodfwd_ir(
sig::Type{<:Tuple};
interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true
)::IRCode
Warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.
Generate the Core.Compiler.IRCode
used to construct the forwards-pass of AD. Take a look at how build_rrule
makes use of generate_ir
to see exactly how this is used in practice.
For example, if you wanted to get the IR associated to the forwards pass for the call map(sin, randn(10))
, you could do either of the following:
julia> Mooncake.fwd_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
true
julia> Mooncake.fwd_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
true
Arguments
sig::Type{<:Tuple}
: the signature of the call to be differentiated.
Keyword Arguments
interp
: the interpreter to use to obtain the primal IR.
debug_mode::Bool
: whether the generated IR should make use of Mooncake's debug mode.
do_inline::Bool
: whether to apply an inlining pass prior to returning the IR generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
Mooncake.gc_preserve
— Methodgc_preserve(xs...)
A no-op function. Its rrule!!
ensures that the memory associated to xs
is not freed until the pullback that it returns is run.
Mooncake.generate_ir
— Methodgenerate_ir(
interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true
)
Used by build_rrule
, and the various debugging tools: primal_ir, fwds_ir, adjoint_ir.
Mooncake.get_const_primal_value
— Methodget_const_primal_value(x::GlobalRef)
Get the value associated to x
. For GlobalRef
s, verify that x
is indeed a constant.
Mooncake.get_primal_type
— Methodget_primal_type(info::ADInfo, x)
Returns the static / inferred type associated to x
.
Mooncake.get_rev_data_id
— Methodget_rev_data_id(info::ADInfo, x)
Returns the ID
associated to the line in the reverse pass which will contain the reverse data for x
. If x
is not an Argument
or ID
, then nothing
is returned.
Mooncake.get_tangent_field
— Methodget_tangent_field(t::Union{MutableTangent, Tangent}, i::Int)
Gets the i
th field of data in t
.
Has the same semantics that getfield
would have if the data in the fields
field of t
were actually fields of t
. This is the moral equivalent of getfield
for MutableTangent
.
Mooncake.id_to_line_map
— Methodid_to_line_map(ir::BBCode)
Produces a Dict
mapping from each ID
associated with a line in ir
to its line number. This is isomorphic to mapping to its SSAValue
in IRCode
. Terminators do not have ID
s associated to them, so not every line in the original IRCode
is mapped to.
Mooncake.inc_args
— Methodinc_args(stmt)
Increment by 1
the n
field of any Argument
s present in stmt
.
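As a rough illustration of the transformation described above, the following sketch shifts every Argument in a statement by one. toy_inc_args is a name invented for this example, and recursing into nested Exprs is an assumption made here, not a statement about how Mooncake implements inc_args.

```julia
# Shift `Core.Argument`s by one, leaving everything else untouched.
toy_inc_args(x::Core.Argument) = Core.Argument(x.n + 1)
toy_inc_args(x::Expr) = Expr(x.head, map(toy_inc_args, x.args)...)
toy_inc_args(x) = x  # SSAValues, literals, symbols, etc. are returned as-is

stmt = Expr(:call, :foo, Core.Argument(2), Core.SSAValue(3), 5.0)
shifted = toy_inc_args(stmt)
```

The original stmt is not modified, because a new Expr is constructed.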
Mooncake.increment_and_get_rdata!
— Methodincrement_and_get_rdata!(fdata, zero_rdata, cr_tangent)
Increment fdata
by the fdata component of the ChainRules.jl-style tangent, cr_tangent
, and return the rdata component of cr_tangent
by adding it to zero_rdata
.
Mooncake.increment_field!!
— Methodincrement_field!!(x::T, y::V, f) where {T, V}
increment!!
the field f
of x
by y
, and return the updated x
.
Mooncake.increment_rdata!!
— Methodincrement_rdata!!(t::T, r)::T where {T}
Increment the rdata component of tangent t
by r
, and return the updated tangent. Useful for implementing getfield-like rules for mutable structs, pointers, dicts, etc.
Mooncake.infer_ir!
— Methodinfer_ir!(ir::IRCode) -> IRCode
Runs type inference on ir
, which mutates ir
, and returns it.
Note: the compiler will not infer the types of anything where the corresponding element of ir.stmts.flag
is not set to Core.Compiler.IR_FLAG_REFINED
. Nor will it attempt to refine the type of the value returned by an :invoke
expression. Consequently, if you find that the types in your IR are not being refined, you may wish to check that neither of these things is happening.
Mooncake.insert_before_terminator!
— Methodinsert_before_terminator!(bb::BBlock, id::ID, inst::NewInstruction)::Nothing
If the final instruction in bb
is a Terminator
, insert inst
immediately before it. Otherwise, insert inst
at the end of the block.
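The rule above can be sketched on a plain vector of statements. This is a simplified analogue: it uses Core's GotoNode, GotoIfNot, and ReturnNode in place of Mooncake's Terminator union, a Vector{Any} in place of a BBlock, and the names toy_is_terminator and toy_insert_before_terminator! are hypothetical.

```julia
# Stand-in for the `Terminator` check, using node types from Core.
toy_is_terminator(x) = x isa Union{Core.GotoNode,Core.GotoIfNot,Core.ReturnNode}

function toy_insert_before_terminator!(stmts::Vector{Any}, inst)
    if !isempty(stmts) && toy_is_terminator(stmts[end])
        # Insert at the index of the terminator, which shifts the terminator back by one.
        insert!(stmts, length(stmts), inst)
    else
        push!(stmts, inst)
    end
    return nothing
end

with_terminator = Any[:(a = 1), Core.ReturnNode(nothing)]
toy_insert_before_terminator!(with_terminator, :(b = 2))

without_terminator = Any[:(a = 1)]
toy_insert_before_terminator!(without_terminator, :(b = 2))
```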
Mooncake.interpolate_boundschecks!
— Methodinterpolate_boundschecks!(ir::IRCode)
For every x = Expr(:boundscheck, value)
in ir
, interpolate value
into all uses of x
. This is only required in order to ensure that literal versions of memoryrefget, memoryrefset!, getfield, and setfield! work effectively. If they are removed through improvements to the way that we handle constant propagation inside Mooncake, then this functionality can be removed.
Mooncake.intrinsic_to_function
— Methodintrinsic_to_function(inst)
If inst
is a :call
expression to a Core.IntrinsicFunction
, replace it with a call to the corresponding function
from Mooncake.IntrinsicsWrappers
, else return inst
.
cglobal
is a special case – it requires that its first argument be static in exactly the same way as :foreigncall
. See IntrinsicsWrappers.__cglobal
for more info.
The purpose of this transformation is to make it possible to write rules for intrinsic calls using dispatch in a type-stable way. See IntrinsicsWrappers
for more context.
Mooncake.ircode
— Functionircode(
inst::Vector{Any},
argtypes::Vector{Any},
sptypes::Vector{CC.VarState}=CC.VarState[],
) -> IRCode
Constructs an instance of an IRCode
. This is useful for constructing test cases with known properties.
No optimisations or type inference are performed on the resulting IRCode
, so that the IRCode
contains exactly what is intended by the caller. Please make use of infer_types!
if you require the types to be inferred.
Edges in PhiNode
s, GotoIfNot
s, and GotoNode
s found in inst
must refer to lines (as in CodeInfo
). In the IRCode
returned by this function, these line references are translated into block references.
Mooncake.is_always_fully_initialised
— Methodis_always_fully_initialised(P::DataType)::Bool
True if all fields in P
are always initialised. Put differently, there are no inner constructors which permit partial initialisation.
Mooncake.is_always_initialised
— Methodis_always_initialised(P::DataType, n::Int)::Bool
True if the n
th field of P
is always initialised. If the n
th fieldtype of P
isbitstype
, then this is distinct from asking whether the n
th field is always defined. An isbits field is always defined, but is not always explicitly initialised.
Mooncake.is_primitive
— Methodis_primitive(::Type{Ctx}, sig) where {Ctx}
Returns a Bool
specifying whether the methods specified by sig
are considered primitives in contexts of type Ctx
.
is_primitive(DefaultCtx, Tuple{typeof(sin), Float64})
will return whether calling sin(5.0)
should be treated as primitive when the context is a DefaultCtx
.
Observe that this information means that whether or not something is a primitive in a particular context depends only on static information, not any run-time information that might live in a particular instance of Ctx
.
Mooncake.is_reachable_return_node
— Methodis_reachable_return_node(x::ReturnNode)
Determine whether x
is a ReturnNode
, and if it is, if it is also reachable. This is purely a function of whether or not its val
field is defined.
Mooncake.is_unreachable_return_node
— Methodis_unreachable_return_node(x::ReturnNode)
Determine whether x
is a ReturnNode
, and if it is, if it is also unreachable. This is purely a function of whether or not its val
field is defined.
Mooncake.is_used
— Methodis_used(info::ADInfo, id::ID)::Bool
Returns true
if id
is used by any of the lines in the ir, false otherwise.
Mooncake.is_vararg_and_sparam_names
— Methodis_vararg_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}}
Finds the method associated to sig
, and calls is_vararg_and_sparam_names
on it.
Mooncake.is_vararg_and_sparam_names
— Methodis_vararg_and_sparam_names(mi::Core.MethodInstance)
Calls is_vararg_and_sparam_names
on mi.def::Method
.
Mooncake.is_vararg_and_sparam_names
— Methodis_vararg_and_sparam_names(m::Method)
Returns a 2-tuple. The first element is true if m
is a vararg method, and false if not. The second element contains the names of the static parameters associated to m
.
Mooncake.lgetfield
— Methodlgetfield(x, f::Val)
An implementation of getfield
in which the field f
is specified statically via a Val
. This enables the implementation to be type-stable even when it is not possible to constant-propagate f
. Moreover, it enables the pullback to also be type-stable.
It will always be the case that
getfield(x, :f) === lgetfield(x, Val(:f))
getfield(x, 2) === lgetfield(x, Val(2))
This approach is identical to the one taken by Zygote.jl
to circumvent the same problem. Zygote.jl
calls the function literal_getfield
, while we call it lgetfield
.
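The idea of moving the field into the type domain can be sketched in a couple of lines. toy_lgetfield below is a hypothetical stand-in written for this example, and a NamedTuple is used in place of a struct purely for brevity.

```julia
# The field name (or index) is a type parameter of `Val`, so it is known from the
# argument types alone, without relying on constant propagation.
toy_lgetfield(x, ::Val{f}) where {f} = getfield(x, f)

p = (x = 1.5, label = "a")
by_name = toy_lgetfield(p, Val(:label))
by_index = toy_lgetfield(p, Val(2))
```

Both calls return the same field, mirroring the two identities stated above.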
Mooncake.lgetfield
— Methodlgetfield(x, ::Val{f}, ::Val{order}) where {f, order}
Like getfield
, but with the field and access order encoded as types.
Mooncake.lift_gc_preservation
— Methodlift_gc_preservation(inst)
Expressions of the form
y = GC.@preserve x1 x2 foo(args...)
get lowered to
token = Expr(:gc_preserve_begin, x1, x2)
y = expr
Expr(:gc_preserve_end, token)
These expressions guarantee that any memory associated to x1
and x2
not be freed until the :gc_preserve_end
expression is reached.
In the context of reverse-mode AD, we must ensure that the memory associated to x1
, x2
and their fdata is available during the reverse pass code associated to expr
. We do this by preventing the memory from being freed until the :gc_preserve_begin
is reached on the reverse pass.
To achieve this, we replace the primal code with
# store `x` in `pb_gc_preserve` to prevent it from being freed.
_, pb_gc_preserve = rrule!!(zero_fcodual(gc_preserve), x1, x2)
# Differentiate the `:call` expression in the usual way.
y, foo_pb = rrule!!(zero_fcodual(foo), args...)
# Do not permit the GC to free `x` here.
nothing
The pullback should be something along the lines of
# no pullback associated to `nothing`.
nothing
# Run the pullback associated to `foo` in the usual manner. `x` must be available.
_, dargs... = foo_pb(dy)
# No-op pullback associated to `gc_preserve`.
pb_gc_preserve(NoRData())
Mooncake.lift_getfield_and_others
— Methodlift_getfield_and_others(inst)
Converts expressions of the form getfield(x, :a)
into lgetfield(x, Val(:a))
. This has identical semantics, but is performant in the absence of proper constant propagation.
Does the same for...
Mooncake.lookup_ir
— Methodlookup_ir(
interp::AbstractInterpreter,
sig_or_mi::Union{Type{<:Tuple}, Core.MethodInstance},
)::Tuple{IRCode, T}
Get the unique IR associated to sig_or_mi
under interp
. Throws ArgumentError
s if there is no code found, or if more than one IRCode
instance is returned.
Returns a tuple containing the IRCode
and its return type.
Mooncake.lsetfield!
— Methodlsetfield!(value, name::Val, x, [order::Val])
This function is to setfield!
what lgetfield
is to getfield
. It will always hold that
setfield!(copy(x), :f, v) == lsetfield!(copy(x), Val(:f), v)
setfield!(copy(x), 2, v) == lsetfield!(copy(x), Val(2), v)
Mooncake.make_ad_stmts!
— Functionmake_ad_stmts!(inst::NewInstruction, line::ID, info::ADInfo)::ADStmtInfo
Every line in the primal code is associated to one or more lines in the forwards-pass of AD, and one or more lines in the pullback. This function has a method specific to every node type in the Julia SSAIR.
Translates the instruction inst
, associated to line
in the primal, into a specification of what should happen for this instruction in the forwards- and reverse-passes of AD, and what data should be shared between the forwards- and reverse-passes. Returns this in the form of an ADStmtInfo
.
info
is a data structure containing various bits of global information that certain types of nodes need access to.
Mooncake.make_switch_stmts
— Methodmake_switch_stmts(
pred_ids::Vector{ID}, target_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo
)
pred_ids
comprises the ID
s associated to all possible predecessor blocks to the primal block under consideration. Suppose its value is [ID(1), ID(2), ID(3)]
, then make_switch_stmts
emits code along the lines of
prev_block = pop!(block_stack)
not_pred_was_1 = !(prev_block == ID(1))
not_pred_was_2 = !(prev_block == ID(2))
switch(
not_pred_was_1 => ID(1),
not_pred_was_2 => ID(2),
ID(3)
)
In words: make_switch_stmts
emits code which jumps to whichever block preceded the current block during the forwards-pass.
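In ordinary Julia, the control flow that this switch encodes looks roughly like the sketch below. It is an analogue for intuition only: toy_next_block is a hypothetical function, Ints stand in for IDs, and returning a block number stands in for jumping to that block.

```julia
function toy_next_block(block_stack::Vector{Int}, pred_ids::Vector{Int})
    prev_block = pop!(block_stack)
    # Test every predecessor except the last...
    for id in pred_ids[1:(end - 1)]
        prev_block == id && return id
    end
    # ...and fall through to the last one, which needs no test.
    return pred_ids[end]
end

block_stack = [1, 3, 2]  # blocks recorded during the forwards-pass, oldest first
visited = [toy_next_block(block_stack, [1, 2, 3]) for _ in 1:3]
```

The reverse-pass visits the predecessors in the opposite order to the one in which they were recorded.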
Mooncake.map_prod
— Methodmap_prod(f, xs...)
Equivalent to map(f, flat_product(xs...))
.
Mooncake.misty_closure
— Methodmisty_closure(
ret_type::Type,
ir::IRCode,
@nospecialize env...;
isva::Bool=false,
do_compile::Bool=true,
)
Identical to Mooncake.opaque_closure
, but returns a MistyClosure
rather than a Core.OpaqueClosure
.
Mooncake.new_inst
— Functionnew_inst(stmt, type=Any, flag=CC.IR_FLAG_REFINED)::NewInstruction
Create a NewInstruction
with fields:
stmt
=stmt
type
=type
info
=CC.NoCallInfo()
line
=Int32(1)
flag
=flag
Mooncake.new_inst_vec
— Methodnew_inst_vec(x::CC.InstructionStream)
Convert a Compiler.InstructionStream
into a list of Compiler.NewInstruction
s.
Mooncake.new_to_call
— Methodnew_to_call(x)
If instruction x
is a :new
expression, replace it with a :call
to Mooncake._new_
. Otherwise, return x
.
The purpose of this transformation is to make it possible to differentiate :new
expressions in the same way as a primitive :call
expression, i.e. via an rrule!!
.
Mooncake.normalise!
— Methodnormalise!(ir::IRCode, spnames::Vector{Symbol})
Apply a sequence of standardising transformations to ir
which leaves its semantics unchanged, but makes AD more straightforward. In particular, replace
- :foreigncall Exprs with :calls to Mooncake._foreigncall_,
- :new Exprs with :calls to Mooncake._new_,
- :splatnew Exprs with :calls to Mooncake._splat_new_,
- Core.IntrinsicFunctions with counterparts from Mooncake.IntrinsicsWrappers,
- getfield(x, 1) with lgetfield(x, Val(1)), and related transformations,
- memoryrefget calls with lmemoryrefget calls, and related transformations,
- gc_preserve_begin / gc_preserve_end exprs so that memory release is delayed.
spnames
are the names associated to the static parameters of ir
. These are needed when handling :foreigncall
expressions, in which it is not necessarily the case that all static parameter names have been translated into either types, or :static_parameter
expressions.
Unfortunately, the static parameter names are not retained in IRCode
, and the Method
from which the IRCode
is derived must be consulted. Mooncake.is_vararg_and_sparam_names
provides a convenient way to do this.
Mooncake.opaque_closure
— Methodopaque_closure(
ret_type::Type,
ir::IRCode,
@nospecialize env...;
isva::Bool=false,
do_compile::Bool=true,
)::Core.OpaqueClosure{<:Tuple, ret_type}
Construct a Core.OpaqueClosure
. Almost equivalent to Core.OpaqueClosure(ir, env...; isva, do_compile)
, but instead of letting Core.compute_oc_rettype
figure out the return type from ir
, impose ret_type
as the return type.
Warning
User beware: if the Core.OpaqueClosure
produced by this function ever returns anything which is not an instance of a subtype of ret_type
, you should expect all kinds of awful things to happen, such as segfaults. You have been warned!
Extended Help
This is needed in Mooncake.jl because we make extensive use of our ability to know the return type of a couple of specific OpaqueClosure
s without actually having constructed them – see LazyDerivedRule
. Without the capability to specify the return type, we have to guess what type compute_ir_rettype
will return for a given IRCode
before we have constructed the IRCode
and run type inference on it. This exposes us to details of type inference, which are not part of the public interface of the language, and can therefore vary from Julia version to Julia version (including patch versions). Moreover, even for a fixed Julia version it can be extremely hard to predict exactly what type inference will infer to be the return type of a function.
Failing to correctly guess the return type can happen for a number of reasons, and the kinds of errors that tend to be generated when this fails tell you very little about the underlying cause of the problem.
By specifying the return type ourselves, we remove this dependence. The price we pay for this is the potential for segfaults etc if we fail to specify ret_type
correctly.
Mooncake.optimise_ir!
— Methodoptimise_ir!(ir::IRCode, show_ir=false)
Run a fairly standard optimisation pass on ir
. If show_ir
is true
, displays the IR to stdout
at various points in the pipeline – this is sometimes useful for debugging.
Mooncake.phi_nodes
— Methodphi_nodes(bb::BBlock)::Tuple{Vector{ID}, Vector{IDPhiNode}}
Returns all of the IDPhiNode
s at the start of bb
, along with their ID
s. If there are no IDPhiNode
s at the start of bb
, then both vectors will be empty.
Mooncake.prepare_gradient_cache
— Methodprepare_gradient_cache(f, x...)
WARNING: experimental functionality. Interface subject to change without warning!
Returns a cache
which can be passed to value_and_gradient!!
. See the docstring for Mooncake.value_and_gradient!!
for more info.
Mooncake.prepare_pullback_cache
— Methodprepare_pullback_cache(f, x...)
WARNING: experimental functionality. Interface subject to change without warning!
Returns a cache
which can be passed to value_and_pullback!!
. See the docstring for Mooncake.value_and_pullback!!
for more info.
Mooncake.primal_ir
— Methodprimal_ir(sig::Type{<:Tuple}; interp=get_interpreter())::IRCode
Warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.
Get the Core.Compiler.IRCode
associated to sig
from which a rule can be derived. Roughly equivalent to Base.code_ircode_by_type(sig; interp)
.
For example, if you wanted to get the IR associated to the call map(sin, randn(10))
, you could do one of the following calls:
julia> Mooncake.primal_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
true
julia> Mooncake.primal_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
true
Mooncake.pullback_ir
— Methodpullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)
Produce the IR associated to the OpaqueClosure
which runs most of the pullback.
Mooncake.pullback_type
— Methodpullback_type(Trule, arg_types)
Get a bound on the pullback type, given a rule and associated primal types.
Mooncake.rdata_field_type
— Methodrdata_field_type(::Type{P}, n::Int) where {P}
Returns the type of the nth field of the rdata type associated to P
. Will be a PossiblyUninitTangent
if said field can be undefined.
Mooncake.remove_unreachable_blocks!
— Methodremove_unreachable_blocks!(ir::BBCode)::BBCode
If a basic block in ir
cannot possibly be reached during execution, then it can be safely removed from ir
without changing its functionality. A block is unreachable if either:
- it has no predecessors and it is not the first block, or
- all of its predecessors are themselves unreachable.
For example, consider the following IR:
julia> ir = Mooncake.ircode(
Any[Core.ReturnNode(nothing), Expr(:call, sin, 5), Core.ReturnNode(Core.SSAValue(2))],
Any[Any, Any, Any],
);
There is no possible way to reach the second basic block (lines 2 and 3). Applying this function will therefore remove it, yielding the following:
julia> Mooncake.IRCode(Mooncake.remove_unreachable_blocks!(Mooncake.BBCode(ir)))
1 1 ─ return nothing
In the blocks which have not been removed, there may be references to blocks which have been removed. For example, the edge
s in a PhiNode
may contain a reference to a removed block. These references are removed in-place from these remaining blocks, so this function will (in general) modify ir
.
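The reachability criterion above amounts to a graph traversal from the first block. The following sketch shows that traversal on a toy graph. It is not the BBCode-based implementation: blocks are Ints, the graph is a hypothetical Dict of successor lists, and toy_reachable_blocks is a name made up for this example.

```julia
# Toy CFG: block => successors. Block 3 has no predecessors, and block 4 is only
# reachable from block 3, so neither can be reached from the entry block 1.
succs = Dict(1 => [2], 2 => Int[], 3 => [4], 4 => [2])

function toy_reachable_blocks(succs::Dict{Int,Vector{Int}}, entry::Int)
    reachable = Set{Int}()
    worklist = [entry]
    while !isempty(worklist)
        b = pop!(worklist)
        b in reachable && continue
        push!(reachable, b)
        append!(worklist, succs[b])
    end
    return reachable
end

reachable = toy_reachable_blocks(succs, 1)

# Rebuild the predecessor map using only reachable blocks. The edge from the
# removed block 4 into block 2 disappears, much like a stale PhiNode edge would.
pruned_preds = Dict(b => Int[] for b in reachable)
for b in reachable, s in succs[b]
    push!(pruned_preds[s], b)
end
```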
Mooncake.replace_captures
— Methodreplace_captures(mc::Tmc, new_captures) where {Tmc<:MistyClosure}
Same as replace_captures
for Core.OpaqueClosure
s, but returns a new MistyClosure
.
Mooncake.replace_captures
— Methodreplace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure}
Given an OpaqueClosure
oc
, create a new OpaqueClosure
of the same type, but with new captured variables. This is needed for efficiency reasons – if build_rrule
is called repeatedly with the same signature and interpreter, it is important to avoid recompiling the OpaqueClosure
s that it produces multiple times, because it can be quite expensive to do so.
Mooncake.replace_uses_with!
— Methodreplace_uses_with!(stmt, def::Union{Argument, SSAValue}, val)
Replace all uses of def
with val
in the single statement stmt
. Note: this function is highly incomplete, really only working correctly for a specific function in ir_normalisation.jl
. You probably do not want to use it.
Mooncake.reverse_data_ref_stmts
— Methodreverse_data_ref_stmts(info::ADInfo)
Create the statements which initialise the reverse-data Ref
s.
Mooncake.rrule_wrapper
— Methodrrule_wrapper(f::CoDual, args::CoDual...)
Used to implement rrule!!
s via ChainRulesCore.rrule
.
Given a function foo
, argument types arg_types
, and a method of ChainRulesCore.rrule
which applies to these, you can make use of this function as follows:
Mooncake.@is_primitive DefaultCtx Tuple{typeof(foo), arg_types...}
function Mooncake.rrule!!(f::CoDual{typeof(foo)}, args::CoDual...)
return rrule_wrapper(f, args...)
end
Assumes that methods of to_cr_tangent
and to_mooncake_tangent
are defined such that you can convert between the different representations of tangents that Mooncake and ChainRulesCore expect.
Furthermore, it is essential that
- f(args) does not mutate f or args, and
- the result of f(args) does not alias any data stored in f or args.
Subject to some constraints, you can use the @from_rrule
macro to reduce the amount of boilerplate code that you are required to write even further.
Mooncake.rule_type
— Methodrule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C}
Compute the concrete type of the rule that will be returned from build_rrule
. This is important for performance in dynamic dispatch, and to ensure that recursion works properly.
Mooncake.rvs_ir
— Methodrvs_ir(
sig::Type{<:Tuple};
interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true
)::IRCode
Warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.
Generate the Core.Compiler.IRCode
used to construct the reverse-pass of AD. Take a look at how build_rrule
makes use of generate_ir
to see exactly how this is used in practice.
For example, if you wanted to get the IR associated to the reverse pass for the call map(sin, randn(10))
, you could do either of the following:
julia> Mooncake.rvs_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
true
julia> Mooncake.rvs_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
true
Arguments
sig::Type{<:Tuple}
: the signature of the call to be differentiated.
Keyword Arguments
interp
: the interpreter to use to obtain the primal IR.
debug_mode::Bool
: whether the generated IR should make use of Mooncake's debug mode.
do_inline::Bool
: whether to apply an inlining pass prior to returning the IR generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
Mooncake.rvs_phi_block
— Methodrvs_phi_block(pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo)
Produces a BBlock
which runs the reverse-pass for the edge associated to pred_id
in a collection of IDPhiNode
s, and then goes to the block associated to pred_id
.
For example, suppose that we encounter the following collection of PhiNode
s at the start of some block:
%6 = φ (#2 => _1, #3 => %5)
%7 = φ (#2 => 5., #3 => _2)
Let the tangent refs associated to %6, %7, and _1 be denoted t%6, t%7, and t_1 resp., and let pred_id be #2, then this function will produce a basic block of the form
increment_ref!(t_1, t%6)
nothing
goto #2
The call to increment_ref!
appears because _1
is the value associated to %6
when the primal code comes from #2
. Similarly, the goto #2
statement appears because we came from #2
on the forwards-pass. There is no increment_ref!
associated to %7
because 5.
is a constant. We emit a nothing
statement, which the compiler will happily optimise away later on.
The same ideas apply if pred_id
were #3
. The block would end with goto #3
, and there would be two increment_ref!
calls because both %5
and _2
are not constants.
Mooncake.seed_id!
— Methodseed_id!()
Set the global counter used to ensure ID uniqueness to 0. This is useful when you want to ensure determinism between two runs of the same function which makes use of ID
s.
This is akin to setting the random seed associated to a random number generator globally.
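The effect of resetting the counter can be sketched with a toy version of the mechanism. Everything here is hypothetical: toy_id_counter, toy_new_id, and toy_seed_id! are invented for illustration, and plain Ints stand in for IDs.

```julia
# A global counter from which fresh identifiers are drawn.
toy_id_counter = Ref(0)
toy_new_id() = (toy_id_counter[] += 1)
toy_seed_id!() = (toy_id_counter[] = 0; nothing)

# Resetting the counter before each run makes the generated identifiers repeatable.
toy_seed_id!()
first_run = [toy_new_id() for _ in 1:3]
toy_seed_id!()
second_run = [toy_new_id() for _ in 1:3]
```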
Mooncake.set_tangent_field!
— Methodset_tangent_field!(t::MutableTangent{Tfields}, i::Int, x) where {Tfields}
Sets the value of the i
th field of the data in t
to value x
.
Has the same semantics that setfield!
would have if the data in the fields
field of t
were actually fields of t
. This is the moral equivalent of setfield!
for MutableTangent
.
Mooncake.shared_data_stmts
— Methodshared_data_stmts(p::SharedDataPairs)::Vector{IDInstPair}
Produce a sequence of id-statement pairs which will extract the data from shared_data_tuple(p)
such that the correct value is associated to the correct ID
.
For example, if p.pairs
is
[(ID(5), 5.0), (ID(3), "hello")]
then the output of this function is
IDInstPair[
(ID(5), new_inst(:(getfield(_1, 1)))),
(ID(3), new_inst(:(getfield(_1, 2)))),
]
Mooncake.shared_data_tuple
— Methodshared_data_tuple(p::SharedDataPairs)::Tuple
Create the tuple that will constitute the captured variables in the forwards- and reverse-pass OpaqueClosure
s.
For example, if p.pairs
is
[(ID(5), 5.0), (ID(3), "hello")]
then the output of this function is
(5.0, "hello")
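The relationship between shared_data_tuple and shared_data_stmts can be sketched without Mooncake. In this hypothetical version, Symbols stand in for IDs, plain Expr objects stand in for instructions, and the demo_ functions are invented for illustration.

```julia
# Hypothetical data: Symbols stand in for Mooncake's ID type.
demo_pairs = [(:id_5, 5.0), (:id_3, "hello")]

# Collect the values, in order, into the tuple that would be captured.
demo_shared_data_tuple(ps) = Tuple(map(last, ps))

# For the n-th pair, emit an expression reading field n of the captured tuple (_1).
demo_shared_data_stmts(ps) = [(first(p), :(getfield(_1, $n))) for (n, p) in enumerate(ps)]

captured = demo_shared_data_tuple(demo_pairs)
stmts = demo_shared_data_stmts(demo_pairs)

# Interpret each statement by hand: args[3] is the field index passed to getfield.
recovered = Dict(id => getfield(captured, ex.args[3]) for (id, ex) in stmts)
```

Here recovered maps each identifier back to the value it was paired with, which is the property the two functions are meant to guarantee together.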
Mooncake.sparam_names
— Methodsparam_names(m::Core.Method)::Vector{Symbol}
Returns the names of all of the static parameters in m
.
Mooncake.splatnew_to_call
— Methodsplatnew_to_call(x)
If instruction x
is a :splatnew
expression, replace it with a :call
to Mooncake._splat_new_
. Otherwise return x
.
The purpose of this transformation is to make it possible to differentiate :splatnew
expressions in the same way as a primitive :call
expression, i.e. via an rrule!!
.
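A rough sketch of this kind of rewrite follows. The function demo_splat_new is a hypothetical stand-in for Mooncake._splat_new_, and it calls the default constructor, which is an assumption made for simplicity; a real :splatnew bypasses constructors.

```julia
# Hypothetical stand-in for the function targeted by the :call expression.
# Assumption: calling the default constructor is close enough for illustration.
demo_splat_new(::Type{T}, args::Tuple) where {T} = T(args...)

# Rewrite :splatnew expressions into calls; leave everything else untouched.
function demo_splatnew_to_call(x)
    if x isa Expr && x.head === :splatnew
        return Expr(:call, demo_splat_new, x.args...)
    end
    return x
end

struct DemoPoint
    a::Float64
    b::Float64
end

ex = Expr(:splatnew, DemoPoint, (1.0, 2.0))
call_ex = demo_splatnew_to_call(ex)
pt = eval(call_ex)   # evaluates demo_splat_new(DemoPoint, (1.0, 2.0))
```

Once the expression is an ordinary :call, it can be handled by whatever machinery already deals with calls, which is the motivation given above.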
Mooncake.stmt
— Methodstmt(ir::CC.InstructionStream)
Get the field containing the instructions in ir
. The field was renamed in Julia 1.11, from inst
to stmt
.
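One way to write an accessor that tolerates such a rename is sketched below. DemoOldStream and DemoNewStream are hypothetical stand-ins for the two layouts, and checking with hasfield is just one possible approach; it is not claimed to be how Mooncake implements stmt.

```julia
# Hypothetical stand-ins for the older and newer field layouts.
struct DemoOldStream
    inst::Vector{Any}
end
struct DemoNewStream
    stmt::Vector{Any}
end

# Return the instruction vector, whichever field name the container uses.
function demo_stmt(ir)
    return hasfield(typeof(ir), :stmt) ? getfield(ir, :stmt) : getfield(ir, :inst)
end

old_insts = demo_stmt(DemoOldStream(Any[:a, :b]))
new_insts = demo_stmt(DemoNewStream(Any[:a, :b]))
```

Callers then use demo_stmt everywhere and never mention the field name directly.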
Mooncake.tangent_field_type
— Methodtangent_field_type(::Type{P}, n::Int) where {P}
Returns the type that lives in the nth element of fields
in a Tangent
/ MutableTangent
. Will either be the tangent_type
of the nth fieldtype of P
, or the tangent_type
wrapped in a PossiblyUninitTangent
. The latter case only occurs if it is possible for the field to be undefined.
Mooncake.tangent_test_cases
— Methodtangent_test_cases()
Constructs a Vector
of Tuple
s containing test cases for the tangent infrastructure.
If the returned tuple has 2 elements, the elements should be interpreted as follows:
1 - interface_only
2 - primal value
interface_only is a Bool which will be used to determine which subset of tests to run.
If the returned tuple has 5 elements, then the elements are interpreted as follows:
1 - interface_only
2 - primal value
3, 4, 5 - tangents, where <5> == increment!!(<3>, <4>).
Test cases in the first format make use of zero_tangent
/ randn_tangent
etc to generate tangents, but they're unable to check that increment!!
is correct in an absolute sense.
Mooncake.terminator
— Methodterminator(bb::BBlock)
Returns the terminator associated to bb
. If the last instruction in bb
isa Terminator
then that is returned, otherwise nothing
is returned.
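The behaviour described can be sketched with simplified, hypothetical types. DemoGoto, DemoReturn, DemoTerminator, and DemoBlock are invented here and are much simpler than Mooncake's BBlock; returning nothing for an empty block is an assumption of this sketch.

```julia
# Hypothetical, simplified stand-ins for terminator nodes and basic blocks.
struct DemoGoto
    dest::Int
end
struct DemoReturn
    val::Any
end
const DemoTerminator = Union{DemoGoto,DemoReturn}

struct DemoBlock
    insts::Vector{Any}
end

# Return the last instruction if it is a terminator, otherwise nothing.
function demo_terminator(bb::DemoBlock)
    isempty(bb.insts) && return nothing   # assumption: empty blocks have no terminator
    last_inst = last(bb.insts)
    return last_inst isa DemoTerminator ? last_inst : nothing
end

with_goto = demo_terminator(DemoBlock(Any[:(x + 1), DemoGoto(3)]))
fallthrough = demo_terminator(DemoBlock(Any[:(x + 1)]))
```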
Mooncake.to_cr_tangent
— Methodto_cr_tangent(t)
Convert a Mooncake tangent into a type that ChainRules.jl rrule
s expect to see.
Mooncake.tuple_map
— Methodtuple_map(f::F, x::Tuple) where {F}
This function is largely equivalent to map(f, x)
, but always specialises on all of the element types of x
, regardless of the length of x
. This contrasts with map
, in which the number of element types specialised upon is a fixed constant in the compiler.
As a consequence, if x
is very long, this function may have very large compile times.
tuple_map(f::F, x::Tuple, y::Tuple) where {F}
Binary extension of tuple_map
. Nearly equivalent to map(f, x, y)
, but guaranteed to specialise on all element types of x
and y
. Furthermore, errors if x
and y
aren't the same length, while map
will just produce a new tuple whose length is equal to the shorter of x
and y
.
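The semantics described above can be approximated in a few lines. This is a sketch under the assumption that ntuple with a Val length is an acceptable way to obtain full specialisation; demo_tuple_map is a hypothetical name and is not Mooncake's implementation.

```julia
# Unary version: the length N is a static parameter, so the result is built
# with a statically-known length.
function demo_tuple_map(f::F, x::Tuple{Vararg{Any,N}}) where {F,N}
    return ntuple(i -> f(x[i]), Val(N))
end

# Binary version: refuse mismatched lengths rather than silently truncating.
function demo_tuple_map(
    f::F, x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,M}}
) where {F,N,M}
    N == M || throw(ArgumentError("tuples must have the same length"))
    return ntuple(i -> f(x[i], y[i]), Val(N))
end

doubled = demo_tuple_map(x -> 2x, (1, 2.0))
summed = demo_tuple_map(+, (1, 2), (10, 20))
```

Calling the binary version with tuples of different lengths throws an ArgumentError in this sketch, whereas the docstring notes that map would return a tuple as long as the shorter input.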
Mooncake.unhandled_feature
— Methodunhandled_feature(msg::String)
Throw an UnhandledLanguageFeatureException
with message msg
.
Mooncake.uninit_codual
— Methoduninit_codual(x)
Equivalent to CoDual(x, uninit_tangent(x))
.
Mooncake.uninit_fcodual
— Methoduninit_fcodual(x)
Like zero_fcodual
, but doesn't guarantee that the value of the fdata is initialised. See implementation for details, as this function is subject to change.
Mooncake.uninit_tangent
— Methoduninit_tangent(x)
Related to zero_tangent
, but a bit different. Check current implementation for details – this docstring is intentionally non-specific in order to avoid becoming outdated.
Mooncake.verify_fdata_type
— Methodverify_fdata_type(P::Type, F::Type)::Nothing
Check that F
is a valid type for fdata associated to a primal of type P
. Returns nothing
if valid, throws an InvalidFDataException
if a problem is found.
This applies to both concrete and non-concrete P
. For example, if P
is the type inferred for a primal q::Q
, such that Q <: P
, then this method is still applicable.
Mooncake.verify_fdata_value
— Methodverify_fdata_value(p, f)::Nothing
Check that f
cannot be proven to be invalid fdata for p
.
This method attempts to provide some confidence that f
is valid fdata for p
by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.
Put differently, we cannot prove that f
is valid fdata, only that it is not obviously invalid.
Mooncake.verify_rdata_type
— Methodverify_rdata_type(P::Type, R::Type)::Nothing
Check that R
is a valid type for rdata associated to a primal of type P
. Returns nothing
if valid, throws an InvalidRDataException
if a problem is found.
This applies to both concrete and non-concrete P
. For example, if P
is the type inferred for a primal q::Q
, such that Q <: P
, then this method is still applicable.
Mooncake.verify_rdata_value
— Methodverify_rdata_value(p, r)::Nothing
Check that r
cannot be proven to be invalid rdata for p
.
This method attempts to provide some confidence that r
is valid rdata for p
by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.
Put differently, we cannot prove that r
is valid rdata, only that it is not obviously invalid.
Mooncake.zero_adjoint
— Methodzero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}
Utility functionality for constructing rrule!!
s for functions which produce adjoints which always return zero.
NOTE: you should only make use of this function if you cannot make use of the @zero_adjoint
macro.
You make use of this functionality by writing a method of Mooncake.rrule!!
, and passing all of its arguments (including the function itself) to this function. For example:
julia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual
julia> foo(x::Vararg{Int}) = 5
foo (generic function with 1 method)
julia> is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(foo), Vararg{Int}}}) = true;
julia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...);
julia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData())
(NoRData(), NoRData(), NoRData())
WARNING: this is only correct if the output of primal(f)(map(primal, x)...)
does not alias anything in f
or x
. This is always the case if the result is a bits type, but more care may be required if it is not.
Mooncake.zero_like_rdata_from_type
— Methodzero_like_rdata_from_type(::Type{P}) where {P}
This is an internal implementation detail – you should generally not use this function.
Returns either the zero element of type rdata_type(tangent_type(P))
, or a ZeroRData
. It is always valid to return a ZeroRData
.
Mooncake.zero_like_rdata_type
— Methodzero_like_rdata_type(::Type{P}) where {P}
Indicates the type which will be returned by zero_like_rdata_from_type
. Will be the rdata type for P
if we can produce the zero rdata element given only P
, and will be the union of the rdata type R for P
and ZeroRData
if an instance of P
is needed.
Mooncake.zero_rdata
— Methodzero_rdata(p)
Given value p
, return the zero element associated to its reverse data type.
Mooncake.zero_rdata_from_type
— Methodzero_rdata_from_type(::Type{P}) where {P}
Returns the zero element of rdata_type(tangent_type(P))
if this is possible given only P
. If not possible, returns an instance of CannotProduceZeroRDataFromType
.
For example, the zero rdata associated to any primal of type Float64
is 0.0
, so for Float64
s this function is simple. Similarly, if the rdata type for P
is NoRData
, that can simply be returned.
However, it is not possible to return the zero rdata element for abstract types e.g. Real
as the type does not uniquely determine the zero element – the rdata type for Real
is Any
.
These considerations apply recursively to tuples / namedtuples / structs, etc.
If you encounter a type for which this function returns CannotProduceZeroRDataFromType
, but you believe this is done in error, please open an issue. This kind of problem does not constitute a correctness problem, but can be detrimental to performance, so should be dealt with.
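The recursive reasoning above can be illustrated with a toy version that only understands IEEE floats and tuples. DemoCannotProduceZero and demo_zero_from_type are hypothetical names, and this sketch deliberately ignores every other kind of type that Mooncake handles.

```julia
# Hypothetical sentinel, playing the role of CannotProduceZeroRDataFromType.
struct DemoCannotProduceZero end

# Toy rule set: concrete IEEE floats have an obvious zero, concrete tuples
# recurse over their element types, and everything else is "unknown".
function demo_zero_from_type(::Type{P}) where {P}
    P <: Base.IEEEFloat && isconcretetype(P) && return zero(P)
    if P <: Tuple && isconcretetype(P)
        zs = map(demo_zero_from_type, fieldtypes(P))
        any(z -> z isa DemoCannotProduceZero, zs) || return zs
    end
    return DemoCannotProduceZero()
end

z_float = demo_zero_from_type(Float64)
z_real = demo_zero_from_type(Real)
z_tuple = demo_zero_from_type(Tuple{Float64,Float32})
z_mixed = demo_zero_from_type(Tuple{Float64,Real})
```

In this sketch an abstract element type anywhere inside a tuple causes the whole result to fall back to the sentinel, mirroring the statement that the considerations apply recursively.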
Mooncake.@from_rrule
— Macro@from_rrule ctx sig [has_kwargs=false]
Convenience functionality to assist in using ChainRulesCore.rrule
s to write rrule!!
s.
Arguments
- ctx: a Mooncake context type.
- sig: the signature which you wish to assert should be a primitive in Mooncake.jl, and use an existing ChainRulesCore.rrule to implement this functionality.
- has_kwargs: a Bool stating whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule: the derivative w.r.t. all kwargs must be zero.
Example Usage
A Basic Example
julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils
julia> using ChainRulesCore, Random
julia> foo(x::Real) = 5x;
julia> function ChainRulesCore.rrule(::typeof(foo), x::Real)
foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), 5Ω
return foo(x), foo_pb
end;
julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat}
julia> rrule!!(zero_fcodual(foo), zero_fcodual(5.0))[2](1.0)
(NoRData(), 5.0)
julia> # Check that the rule works as intended.
TestUtils.test_rule(Xoshiro(123), foo, 5.0; is_primitive=true)
Test Passed
An Example with Keyword Arguments
julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils
julia> using ChainRulesCore, Random
julia> foo(x::Real; cond::Bool) = cond ? 5x : 4x;
julia> function ChainRulesCore.rrule(::typeof(foo), x::Real; cond::Bool)
foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), cond ? 5Ω : 4Ω
return foo(x; cond), foo_pb
end;
julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} true
julia> _, pb = rrule!!(
zero_fcodual(Core.kwcall),
zero_fcodual((cond=false, )),
zero_fcodual(foo),
zero_fcodual(5.0),
);
julia> pb(3.0)
(NoRData(), NoRData(), NoRData(), 12.0)
julia> # Check that the rule works as intended.
TestUtils.test_rule(
Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true
)
Test Passed
Notice that, in order to access the kwarg method we must call the method of Core.kwcall
, as Mooncake's rrule!!
does not itself permit the use of kwargs.
Limitations
It is your responsibility to ensure that
- calls with signature sig do not mutate their arguments,
- the output of calls with signature sig does not alias any of the inputs.
As with all hand-written rules, you should definitely make use of TestUtils.test_rule
to verify correctness on some test cases.
Argument Type Constraints
Many methods of ChainRulesCore.rrule
are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature
Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}
There are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.
Suffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule
will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat
argument, i.e. Union{Float16, Float32, Float64}
, but it is usually not possible to know that the rule is correct for all possible subtypes of Real
that someone might define.
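To see which argument types a narrowed signature covers, the subtype relation on Tuple types can be checked directly. The function demo_foo and the variable names below are hypothetical; only Base functionality is used.

```julia
# Hypothetical function whose rule we only trust for IEEE floats.
demo_foo(x::Real) = 5x

# The narrowed signature, as it would be passed to a macro such as @from_rrule.
demo_sig = Tuple{typeof(demo_foo),Base.IEEEFloat}

# Calls with Float32 arguments fall under the signature...
covered = Tuple{typeof(demo_foo),Float32} <: demo_sig

# ...whereas calls with other Real subtypes do not, so no hand-written rule applies.
not_covered_big = Tuple{typeof(demo_foo),BigFloat} <: demo_sig
not_covered_rational = Tuple{typeof(demo_foo),Rational{Int}} <: demo_sig
```

Types outside the signature are the ones left for the package to derive rules for, as described above.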
Conversions Between Different Tangent Type Systems
Under the hood, this functionality relies on two functions: Mooncake.to_cr_tangent
, and Mooncake.increment_and_get_rdata!
. These two functions handle conversion to / from Mooncake
tangent types and ChainRulesCore
tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule
does not work in your case because the required method of either of these functions does not exist, please open an issue.
Mooncake.@is_primitive
— Macro@is_primitive context_type signature
Creates a method of is_primitive
which always returns true
for the context_type and signature
provided. For example
@is_primitive MinimalCtx Tuple{typeof(foo), Float64}
is equivalent to
is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(foo), Float64}}) = true
You should implement more complicated methods of is_primitive
in the usual way.
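The dispatch pattern in the equivalent definition above can be tried out in isolation. DemoCtx, demo_bar, and demo_is_primitive are hypothetical names, and the false fallback is an assumption of this sketch rather than a description of Mooncake's defaults.

```julia
# Hypothetical context type and function.
struct DemoCtx end
demo_bar(x) = 2x

# Assumed fallback for this sketch: signatures are not primitive by default.
demo_is_primitive(::Type{DemoCtx}, ::Type{<:Tuple}) = false

# The more specific method marks one signature, and its subtypes, as primitive.
demo_is_primitive(::Type{DemoCtx}, ::Type{<:Tuple{typeof(demo_bar),Float64}}) = true

prim_float = demo_is_primitive(DemoCtx, Tuple{typeof(demo_bar),Float64})
prim_int = demo_is_primitive(DemoCtx, Tuple{typeof(demo_bar),Int})
```

The `<:` in the second method is what lets a single definition cover every signature that is a subtype of the one given.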
Mooncake.@mooncake_overlay
— Macro@mooncake_overlay method_expr
Define a method of a function which only Mooncake can see. This can be used to write versions of methods which can be successfully differentiated by Mooncake if the original cannot be.
For example, suppose that you have a function
julia> foo(x::Float64) = bar(x)
foo (generic function with 1 method)
where Mooncake.jl fails to differentiate bar
for some reason. If you have access to another function baz
, which does the same thing as bar
, but does so in a way which Mooncake.jl can differentiate, you can simply write:
julia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x)
When looking up the code for foo(::Float64)
, Mooncake.jl will see this method, rather than the original, and differentiate it instead.
A Worked Example
To demonstrate how to use @mooncake_overlay
s in practice, we show how the answer that Mooncake.jl gives changes if you change the definition of a function using a @mooncake_overlay
. Do not do this in practice – this is just a simple way to demonstrate how to use overlays!
First, consider a simple example:
julia> scale(x) = 2x
scale (generic function with 1 method)
julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});
julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(10.0, (NoTangent(), 2.0))
We can use @mooncake_overlay
to change the definition which Mooncake.jl sees:
julia> Mooncake.@mooncake_overlay scale(x) = 3x
julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});
julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(15.0, (NoTangent(), 3.0))
As can be seen from the output, the result of differentiating using Mooncake.jl has changed to reflect the overlay-ed definition of the method.
Additionally, it is possible to use the usual multi-line syntax to declare an overlay:
julia> Mooncake.@mooncake_overlay function scale(x)
return 4x
end
julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});
julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(20.0, (NoTangent(), 4.0))
Mooncake.@zero_adjoint
— Macro@zero_adjoint ctx sig
Defines is_primitive(context_type, sig) = true
, and defines a method of Mooncake.rrule!!
which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable
.
For example:
julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive
julia> foo(x) = 5
foo (generic function with 1 method)
julia> @zero_adjoint DefaultCtx Tuple{typeof(foo), Any}
julia> is_primitive(DefaultCtx, Tuple{typeof(foo), Any})
true
julia> rrule!!(zero_fcodual(foo), zero_fcodual(3.0))[2](NoRData())
(NoRData(), 0.0)
Limited support for Vararg
s is also available. For example
julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive
julia> foo_varargs(x...) = 5
foo_varargs (generic function with 1 method)
julia> @zero_adjoint DefaultCtx Tuple{typeof(foo_varargs), Vararg}
julia> is_primitive(DefaultCtx, Tuple{typeof(foo_varargs), Any, Float64, Int})
true
julia> rrule!!(zero_fcodual(foo_varargs), zero_fcodual(3.0), zero_fcodual(5))[2](NoRData())
(NoRData(), 0.0, NoRData())
Be aware that it is not currently possible to specify any of the type parameters of the Vararg
. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}}
will not work with this macro.
WARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x
will yield incorrect results.
As always, you should use TestUtils.test_rule
to ensure that you've not made a mistake.
Signatures Unsupported By This Macro
If the signature you wish to apply @zero_adjoint
to is not supported, for example because it uses a Vararg
with a type parameter, you can still make use of zero_adjoint
.
Mooncake.IntrinsicsWrappers
— Modulemodule IntrinsicsWrappers
The purpose of this module
is to associate to each function in Core.Intrinsics
a regular Julia function.
To understand the rationale for this observe that, unlike regular Julia functions, each Core.IntrinsicFunction
in Core.Intrinsics
does not have its own type. Rather, they are instances of Core.IntrinsicFunction
. To see this, observe that
julia> typeof(Core.Intrinsics.add_float)
Core.IntrinsicFunction
julia> typeof(Core.Intrinsics.sub_float)
Core.IntrinsicFunction
While we could simply write a rule for Core.IntrinsicFunction
, this would (naively) lead to a large list of conditionals of the form
if f === Core.Intrinsics.add_float
# return add_float and its pullback
elseif f === Core.Intrinsics.sub_float
# return sub_float and its pullback
elseif
...
end
which has the potential to cause quite substantial type instabilities. (This might not be true anymore – see extended help for more context).
Instead, we map each Core.IntrinsicFunction
to one of the regular Julia functions in Mooncake.IntrinsicsWrappers
, to which we can dispatch in the usual way.
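The idea can be sketched with two hand-written wrappers. The names demo_add_float, demo_sub_float, and demo_rule_label are hypothetical and are not the functions defined in Mooncake.IntrinsicsWrappers.

```julia
# Both intrinsics share a single type, so dispatch cannot tell them apart.
same_type = typeof(Core.Intrinsics.add_float) === typeof(Core.Intrinsics.sub_float)

# Wrapping each intrinsic in a regular function gives each its own type.
demo_add_float(a, b) = Core.Intrinsics.add_float(a, b)
demo_sub_float(a, b) = Core.Intrinsics.sub_float(a, b)
distinct_types = typeof(demo_add_float) !== typeof(demo_sub_float)

# Methods can now be selected per wrapper, with no chain of conditionals.
demo_rule_label(::typeof(demo_add_float)) = :add_float_rule
demo_rule_label(::typeof(demo_sub_float)) = :sub_float_rule

label = demo_rule_label(demo_sub_float)
total = demo_add_float(1.0, 2.0)
```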
Extended Help
It is possible that owing to improvements in constant propagation in the Julia compiler in version 1.10, we actually could get away with just writing a single method of rrule!!
to handle all intrinsics, so this dispatch-based mechanism might be unnecessary. Someone should investigate this. Discussed at https://github.com/compintell/Mooncake.jl/issues/387 .