Internal Docstrings
Docstrings listed here are not part of the public Mooncake.jl interface. Consequently, they can change between non-breaking releases of Mooncake.jl without warning.
The purpose of this page is to make it easy for developers to find docstrings via the docs, rather than having to ctrl+f through Mooncake.jl's source code or look them up in the Julia REPL.
Mooncake.GLOBAL_INTERPRETER (Constant)
const GLOBAL_INTERPRETER
Globally cached interpreter. Should only be accessed via get_interpreter.
Mooncake.ADInfo (Type)
ADInfo
This data structure is used to hold "global" information associated to a particular call to build_rrule. It is used as a means of communication between make_ad_stmts! and the codegen which produces the forwards- and reverse-passes.
- interp: a MooncakeInterpreter.
- block_stack_id: the ID associated to the block stack, i.e. the stack which keeps track of which blocks we visited during the forwards-pass, and which is used on the reverse-pass to determine which blocks to visit.
- block_stack: the block stack. Can always be found at block_stack_id in the forwards- and reverse-passes.
- entry_id: ID associated to the block inserted at the start of execution in the forwards-pass, and the end of execution in the pullback.
- shared_data_pairs: the SharedDataPairs used to define the captured variables passed to both the forwards- and reverse-passes.
- arg_types: a map from Argument to its static type.
- ssa_insts: a map from ID associated to lines to the primal NewInstruction. This contains the line of code, its static / inferred type, and some other details. See Core.Compiler.NewInstruction for a full list of fields.
- arg_rdata_ref_ids: the dict mapping from arguments to the ID which creates and initialises the Ref which contains the reverse data associated to that argument. Recall that the heap allocations associated to this Ref are always optimised away in the final programme.
- ssa_rdata_ref_ids: the same as arg_rdata_ref_ids, but for each ID associated to an ssa rather than each argument.
- debug_mode: if true, run in "debug mode", which wraps all rule calls in DebugRRule. This is applied recursively, so that debug mode is also switched on in derived rules.
- is_used_dict: for each ID associated to a line of code, is false if the line is not used anywhere in any other line of code.
- lazy_zero_rdata_ref_id: for any arguments whose type doesn't permit the construction of a zero-valued rdata directly from the type alone (e.g. a struct with an abstractly-typed field), we need to have a zero-valued rdata available on the reverse-pass, so that it can be returned if the argument (or a part of it) is never used during the forwards-pass and consequently doesn't obtain a value on the reverse-pass. To achieve this, we construct a LazyZeroRData for each of the arguments on the forwards-pass, and make use of it on the reverse-pass. This field is the ID that will be associated to this information.
Mooncake.ADStmtInfo (Type)
ADStmtInfo
Data structure which contains the result of make_ad_stmts!. Fields are
- line: the ID associated to the primal line from which this is derived.
- comms_id: an ID from one of the lines in fwds, whose value will be made available on the reverse-pass in the same ID. Nothing is asserted about how this value is made available on the reverse-pass of AD, so this package is free to do this in whichever way is most efficient, in particular to group these communication IDs on a per-block basis.
- fwds: the instructions which run the forwards-pass of AD.
- rvs: the instructions which run the reverse-pass of AD / the pullback.
Mooncake.BlockStack (Type)
The block stack is the stack used to keep track of which basic blocks are visited on the forwards-pass, and therefore which blocks need to be visited on the reverse-pass. There is one block stack per derived rule. By using Int32, we assume that there aren't more than typemax(Int32) unique basic blocks in a given function, which ought to be reasonable.
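A minimal, self-contained sketch of the idea, assuming nothing about Mooncake.jl's actual BlockStack: the names ToyBlockStack, forwards_visit! and reverse_order are hypothetical, and in a real derived rule the pushes and pops are emitted into the generated forwards- and reverse-pass IR rather than driven by an explicit path like this.

```julia
# Hypothetical stand-in for a block stack. Block IDs are stored as Int32,
# mirroring the assumption that a function has at most typemax(Int32) blocks.
const ToyBlockStack = Vector{Int32}

# Forwards-pass: record each basic block as it is entered.
function forwards_visit!(stack::ToyBlockStack, path)
    for blk in path
        push!(stack, Int32(blk))
    end
    return stack
end

# Reverse-pass: pop blocks off the stack, so they are visited in reverse order.
function reverse_order(stack::ToyBlockStack)
    visited = Int32[]
    while !isempty(stack)
        push!(visited, pop!(stack))
    end
    return visited
end

blocks = forwards_visit!(ToyBlockStack(), [1, 2, 3, 2, 4])
reverse_order(blocks)  # Int32[4, 2, 3, 2, 1]
```

Note that a block visited twice on the forwards-pass (block 2 above) is also visited twice on the reverse-pass.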
Mooncake.CannotProduceZeroRDataFromType (Type)
CannotProduceZeroRDataFromType()
Returned by zero_rdata_from_type if it is not possible to construct the zero rdata element for a given type. See zero_rdata_from_type for more info.
Mooncake.DebugPullback (Type)
DebugPullback(pb, y, x)
Construct a callable which is equivalent to pb, but which enforces type-based pre- and post-conditions to pb. Let dx = pb.pb(dy), for some rdata dy; then this function
- checks that dy has the correct rdata type for y, and
- checks that each element of dx has the correct rdata type for x.
Reverse-pass counterpart to DebugRRule.
Mooncake.DebugPullback (Method)
(pb::DebugPullback)(dy)
Apply type checking to enforce pre- and post-conditions on pb.pb. See the docstring for DebugPullback for details.
Mooncake.DebugRRule (Type)
DebugRRule(rule)
Construct a callable which is equivalent to rule, but inserts additional type checking. In particular:
- check that the fdata in each argument is of the correct type for the primal, and
- check that the fdata in the CoDual returned from the rule is of the correct type for the primal.
This happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.
Some additional dynamic checks are also performed (e.g. that an fdata array is the same size as its primal).
Let rule return y, pb!!; then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.
Note: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.
Mooncake.DebugRRule (Method)
(rule::DebugRRule)(x::CoDual...)
Apply type checking to enforce pre- and post-conditions on rule.rule. See the docstring for DebugRRule for details.
Mooncake.DefaultCtx (Type)
struct DefaultCtx end
Context for all commonly used AD primitives. Anything which is a primitive in a MinimalCtx is automatically a primitive in the DefaultCtx. If you are adding a rule for the sake of performance, it should be a primitive in the DefaultCtx, but not the MinimalCtx.
Mooncake.DynamicDerivedRule (Type)
DynamicDerivedRule(interp::MooncakeInterpreter, debug_mode::Bool)
For internal use only.
A callable data structure which, when invoked, calls an rrule specific to the dynamic types of its arguments. Stores rules in an internal cache to avoid re-deriving.
This is used to implement dynamic dispatch.
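A rough sketch of caching rules on the dynamic types of the arguments. ToyDynamicRule and toy_derive_rule are hypothetical names; the stand-in "rule" merely reports the signature it was built for, rather than running a forwards-pass or returning a pullback.

```julia
# Hypothetical callable which looks up (or builds) a rule keyed on the runtime
# types of its arguments, so that each rule is built at most once.
struct ToyDynamicRule
    cache::Dict{Any,Any}
end
ToyDynamicRule() = ToyDynamicRule(Dict{Any,Any}())

# Stand-in for deriving a rule: it just reports the signature it was built for.
toy_derive_rule(sig::Type{<:Tuple}) = args -> (sig, args)

function (r::ToyDynamicRule)(args...)
    sig = Tuple{map(typeof, args)...}                 # dynamic signature
    rule = get!(() -> toy_derive_rule(sig), r.cache, sig)
    return rule(args)
end
```

Calling such an object with (1, 2.0) and then (3, 4.0) reuses one cached rule, since both calls share the signature Tuple{Int, Float64}.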
Mooncake.FData (Type)
FData(data::NamedTuple)
The component of a struct which is propagated alongside the primal on the forwards-pass of AD. For example, the tangents for Float64s do not need to be propagated on the forwards-pass of reverse-mode AD, so any Float64 fields of Tangent do not need to appear in the associated FData.
Mooncake.InvalidFDataException (Type)
InvalidFDataException(msg::String)
Exception indicating that there is a problem with the fdata associated to a primal.
Mooncake.InvalidRDataException (Type)
InvalidRDataException(msg::String)
Exception indicating that there is a problem with the rdata associated to a primal.
Mooncake.LazyDerivedRule (Type)
LazyDerivedRule(interp, mi::Core.MethodInstance, debug_mode::Bool)
For internal use only.
A type-stable wrapper around a DerivedRule, which only instantiates the DerivedRule when it is first called. This is useful, as it means that if a rule does not get run, it does not have to be derived.
If debug_mode is true, then the rule constructed will be a DebugRRule. This is useful when debugging, but should usually be switched off for production code as it (in general) incurs some runtime overhead.
Note: the signature of the primal for which this is a rule is stored in the type. The only reason to keep this around is for debugging: it is very helpful to have this type visible in the stack trace when something goes wrong, as it allows you to trivially determine which bit of your code is the culprit.
Extended Help
There are two main reasons why deferring the construction of a DerivedRule until we need to use it is crucial.
The first is to do with recursion. Consider the following function:
f(x) = x > 0 ? f(x - 1) : x
If we generate the IRCode for this function, we will see something like the following:
julia> Base.code_ircode_by_type(Tuple{typeof(f), Float64})[1][1]
1 1 ─ %1 = Base.lt_float(0.0, _2)::Bool
│ %2 = Base.or_int(%1, false)::Bool
└── goto #6 if not %2
2 ─ %4 = Base.sub_float(_2, 1.0)::Float64
│ %5 = Base.lt_float(0.0, %4)::Bool
│ %6 = Base.or_int(%5, false)::Bool
└── goto #4 if not %6
3 ─ %8 = Base.sub_float(%4, 1.0)::Float64
│ %9 = invoke Main.f(%8::Float64)::Float64
└── goto #5
4 ─ goto #5
5 ┄ %12 = φ (#3 => %9, #4 => %4)::Float64
└── return %12
6 ─ return _2
Suppose that we decide to construct a DerivedRule immediately whenever we find an :invoke statement in a rule that we're currently building a DerivedRule for. In the above example, we produce an infinite recursion when we attempt to produce a DerivedRule for %9, because it has the same signature as the call which generates this IR. By instead adopting a policy of constructing a LazyDerivedRule whenever we encounter an :invoke statement, we avoid this problem.
The second reason that delaying the construction of a DerivedRule is essential is that it ensures that we don't derive rules for method instances which aren't run. Suppose that function B contains code for which we can't derive a rule, perhaps because it contains an unsupported language feature like a PhiCNode or an UpsilonNode. Suppose that function A contains an :invoke which refers to function B, but that this call is on a branch which deals with error handling, and doesn't get run unless something goes wrong. By deferring the derivation of the rule for B, we only ever attempt to derive it if we land on this error handling branch. Conversely, if we attempted to derive the rule for B when we derive the rule for A, we would be unable to complete the derivation of the rule for A.
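Both points can be seen in a toy version of lazy rule construction. ToyLazyRule and build_toy_f_rule are hypothetical; unlike LazyDerivedRule, this sketch neither caches rules across instances nor derives anything from IR. It only shows that building a rule which refers to itself does not recurse at construction time, and that a rule which is never called is never built.

```julia
# Hypothetical wrapper which derives its rule on first call only.
mutable struct ToyLazyRule{F}
    build::F                          # zero-argument function which derives the rule
    rule::Union{Nothing,Function}     # `nothing` until the first call
end
ToyLazyRule(build) = ToyLazyRule(build, nothing)

function (lr::ToyLazyRule)(args...)
    if lr.rule === nothing
        lr.rule = lr.build()          # derive only when actually needed
    end
    return lr.rule(args...)
end

# Toy "rule" for f(x) = x > 0 ? f(x - 1) : x. The recursive call is wrapped in
# a ToyLazyRule, so building this rule does not immediately build another one.
function build_toy_f_rule()
    inner = ToyLazyRule(build_toy_f_rule)
    return x -> x > 0 ? inner(x - 1) : x
end

build_toy_f_rule()(3.0)  # 0.0

# A rule which cannot be derived is harmless as long as it is never called.
never_called = ToyLazyRule(() -> error("cannot derive this rule"))
```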
Mooncake.LazyZeroRData (Type)
LazyZeroRData{P, Tdata}()
This type is a lazy placeholder for zero_like_rdata_from_type. It is used to defer construction of zero data to the reverse-pass. Calling instantiate on an instance of this type will construct a zero data.
Users should construct using LazyZeroRData(p), where p is a value of type P. This constructor, and instantiate, are specialised to minimise the amount of data which must be stored. For example, Float64s do not need any data, so LazyZeroRData(0.0) produces an instance of a singleton type, meaning that various important optimisations can be performed in AD.
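To illustrate why Float64 needs no stored data, here is a sketch of a lazy zero placeholder for that one type. ToyLazyZero, toy_lazy_zero and toy_instantiate are hypothetical names, not Mooncake.jl's API; Base.issingletontype is used only to confirm that the placeholder carries no data.

```julia
# Hypothetical placeholder for a lazily-constructed zero rdata of primal type P.
# It has no fields, so each ToyLazyZero{P} is a singleton type.
struct ToyLazyZero{P} end

# For Float64 the zero can be recovered from the type alone, so nothing is stored.
toy_lazy_zero(::Float64) = ToyLazyZero{Float64}()
toy_instantiate(::ToyLazyZero{Float64}) = 0.0

z = toy_lazy_zero(5.0)
Base.issingletontype(typeof(z))  # true
toy_instantiate(z)               # 0.0
```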
Mooncake.MinimalCtx (Type)
struct MinimalCtx end
Functions should only be primitives in this context if not making them so would cause AD to fail. In particular, do not add primitives to this context if you are writing them for performance only; instead, make these primitives in the DefaultCtx.
Mooncake.MutableTangent (Type)
MutableTangent{Tfields<:NamedTuple}
Default type used to represent the tangent of a mutable struct. See tangent_type for more info.
Mooncake.NoFData (Type)
NoFData
Singleton type which indicates that there is nothing to be propagated on the forwards-pass in addition to the primal data.
Mooncake.NoPullback (Method)
NoPullback(args::CoDual...)
Construct a NoPullback from the arguments passed to an rrule!!. For each argument, extracts the primal value, and constructs a LazyZeroRData. These are stored in a NoPullback which, in the reverse-pass of AD, instantiates these LazyZeroRDatas and returns them in order to perform the reverse-pass of AD.
The advantage of this approach is that if it is possible to construct the zero rdata element for each of the arguments lazily, the NoPullback generated will be a singleton type. This means that AD can avoid generating a stack to store this pullback, which can result in significant performance improvements.
Mooncake.NoRData (Type)
NoRData()
Nothing to propagate backwards on the reverse-pass.
Mooncake.NoTangent (Type)
NoTangent
The type in question has no meaningful notion of a tangent space. Generally, you shouldn't use this; just let the default recursive tangent construction work. You might need to use this for primitive types though.
Mooncake.PossiblyUninitTangent (Type)
PossiblyUninitTangent{T}
Represents a T which may or may not be present. Does not distinguish between 0 and not being present.
Mooncake.RRuleZeroWrapper (Type)
RRuleZeroWrapper(rule)
This struct is used to ensure that ZeroRDatas never reach rules. ZeroRDatas are used as placeholder zero elements whenever an actual instance of a zero rdata for a particular primal type cannot be constructed without also having an instance of said type. On the pullback, we increment the cotangent dy by an amount equal to zero. This ensures that if dy is a ZeroRData, we instead get an actual zero of the correct type. If it is not a zero rdata, the computation should be elided via inlining and constant propagation.
Mooncake.SharedDataPairs (Type)
SharedDataPairs()
A data structure used to manage the captured data in the OpaqueClosures which implement the bulk of the forwards- and reverse-passes of AD. An entry (id, data) at element n of the pairs field of this data structure means that data will be available at register id during the forwards- and reverse-passes of AD.
This is achieved by storing all of the data in the pairs field in the captured tuple which is passed to an OpaqueClosure, and extracting this data into registers associated to the corresponding IDs.
Mooncake.Stack (Type)
Stack{T}()
A stack specialised for reverse-mode AD.
Semantically equivalent to a usual stack, but never de-allocates memory once allocated.
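One way to get a stack which never de-allocates is to keep a backing vector plus a separate position index, so that popping only moves the index and later pushes reuse existing slots. This is a sketch under that assumption; ToyStack is a hypothetical name, and Mooncake.jl's Stack may be implemented differently.

```julia
# Hypothetical stack which never shrinks its backing memory.
mutable struct ToyStack{T}
    memory::Vector{T}
    position::Int          # index of the current top element (0 when empty)
end
ToyStack{T}() where {T} = ToyStack{T}(Vector{T}(), 0)

function Base.push!(s::ToyStack, x)
    s.position += 1
    if s.position > length(s.memory)
        push!(s.memory, x)            # grow only when no free slot exists
    else
        s.memory[s.position] = x      # reuse a previously allocated slot
    end
    return s
end

function Base.pop!(s::ToyStack)
    s.position == 0 && throw(ArgumentError("stack is empty"))
    x = s.memory[s.position]
    s.position -= 1                   # memory is kept, only the index moves
    return x
end

Base.isempty(s::ToyStack) = s.position == 0
```

On the reverse-pass of AD, values are popped in the order opposite to which they were pushed, and repeated calls to the same rule can reuse the memory allocated by earlier calls.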
Mooncake.Tangent (Type)
Tangent{Tfields<:NamedTuple}
Default type used to represent the tangent of a struct. See tangent_type for more info.
Mooncake.UnhandledLanguageFeatureException (Type)
UnhandledLanguageFeatureException(message::String)
An exception used to indicate that some aspect of the Julia language which AD cannot handle has been encountered.
Mooncake.ZeroRData (Type)
ZeroRData()
Singleton type indicating zero-valued rdata. This should only ever appear as an intermediate quantity in the reverse-pass of AD when the type of the primal is not fully inferable, or a field of a type is abstractly typed.
If you see this anywhere in actual code, or if it appears in a hand-written rule, this is an error; please open an issue in such a situation.
Mooncake.__exclude_unsupported_output (Method)
__exclude_unsupported_output(y)
Required for the robust design of value_and_pullback and prepare_pullback_cache. Ensures that y contains no aliasing, circular references, Ptrs or non-differentiable datatypes. In the forwards-pass, f(args...) may only return a "tree"-like data structure whose leaf nodes are primitive types. Refer to https://github.com/compintell/Mooncake.jl/issues/517#issuecomment-2715202789 and the related issue for details. Internally calls __exclude_unsupported_output_internal!. The design is modelled after zero_tangent.
Mooncake.__exclude_unsupported_output_internal! (Method)
__exclude_unsupported_output_internal!(y::T, address_set::Set{UInt}) where {T}
Checks whether the output y is a valid mutable / immutable composite or a primitive type. Performs a recursive depth-first search over the function output y, with an isbitstype() check as the base case. The visited memory addresses are stored inside address_set. If the set already contains a newly visited address, it errors, indicating an aliased or circular reference. It also errors if y is or contains a pointer. It is called internally by __exclude_unsupported_output(y).
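A simplified sketch of this kind of depth-first check. toy_check_tree is hypothetical, handles only Arrays, tuples and structs, and identifies mutable objects by pointer_from_objref. One deliberate difference from the description above: the base case is isprimitivetype rather than isbitstype, so that a Ptr nested inside a bits type such as Tuple{Float64, Ptr{Cvoid}} is still reached.

```julia
# Hypothetical depth-first check that an output is a "tree": no pointers,
# no aliased mutable objects, and no circular references.
function toy_check_tree(x, seen::Set{UInt}=Set{UInt}())
    # Pointers are rejected outright.
    x isa Ptr && throw(ArgumentError("output contains a Ptr"))
    # Base case: primitive values (Float64, Int, Bool, ...) are leaves.
    isprimitivetype(typeof(x)) && return nothing
    # Mutable objects are identified by address; seeing one twice means
    # aliasing or a circular reference.
    if ismutable(x)
        addr = UInt(pointer_from_objref(x))
        addr in seen && throw(ArgumentError("output contains aliasing or a circular reference"))
        push!(seen, addr)
    end
    # Recurse into array elements, or into struct / tuple fields.
    if x isa Array
        for i in eachindex(x)
            isassigned(x, i) && toy_check_tree(x[i], seen)
        end
    else
        for n in 1:nfields(x)
            isdefined(x, n) && toy_check_tree(getfield(x, n), seen)
        end
    end
    return nothing
end
```

For example, toy_check_tree((1.0, [2.0])) returns nothing, whereas passing a tuple which contains the same vector twice throws an ArgumentError.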
Mooncake.__flatten_varargs (Method)
__flatten_varargs(isva::Bool, args, ::Val{nvargs}) where {nvargs}
If isva, inputs (5.0, (4.0, 3.0)) are transformed into (5.0, 4.0, 3.0).
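Ignoring the Val{nvargs} argument, the transformation itself amounts to splatting the last element of args. toy_flatten_varargs is a hypothetical name, operating on plain tuples only.

```julia
# Hypothetical sketch: when the method is varargs (isva), the last element of
# args holds the varargs as a tuple, which is splatted into the outer tuple.
toy_flatten_varargs(isva::Bool, args::Tuple) = isva ? (args[1:end-1]..., args[end]...) : args

toy_flatten_varargs(true, (5.0, (4.0, 3.0)))  # (5.0, 4.0, 3.0)
```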
Mooncake.__get_value (Method)
__get_value(edge::ID, x::IDPhiNode)
Helper functionality for conclude_rvs_block.
Mooncake.__insts_to_instruction_stream (Method)
__insts_to_instruction_stream(insts::Vector{Any})
Produces an instruction stream whose
- stmt (v1.11 and up) / inst (v1.10) field is insts,
- type field is all Any,
- info field is all Core.Compiler.NoCallInfo,
- line field is all Int32(1), and
- flag field is all Core.Compiler.IR_FLAG_REFINED.
As such, if you wish to ensure that your IRCode prints nicely, you should ensure that its linetable field has at least one element.
Mooncake.__pop_blk_stack! (Method)
__pop_blk_stack!(block_stack::BlockStack)
Equivalent to pop!(block_stack). Going via this function, rather than just calling pop! directly, makes it easy to figure out how much time is spent popping the block stack when profiling performance, and to know that this function was hit when debugging.
Mooncake.__push_blk_stack! (Method)
__push_blk_stack!(block_stack::BlockStack, id::Int32)
Equivalent to push!(block_stack, id). Going via this function, rather than just calling push! directly, is helpful for debugging and performance analysis: it makes it very straightforward to figure out how much time is spent pushing to the block stack when profiling.
Mooncake.__switch_case (Method)
__switch_case(id::Int32, predecessor_id::Int32)
Helper function emitted by make_switch_stmts.
Mooncake.__unflatten_codual_varargs (Method)
__unflatten_codual_varargs(isva::Bool, args, ::Val{nargs}) where {nargs}
If isva and nargs=2, then inputs (CoDual(5.0, 0.0), CoDual(4.0, 0.0), CoDual(3.0, 0.0)) are transformed into (CoDual(5.0, 0.0), CoDual((4.0, 3.0), (0.0, 0.0))).
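On plain tuples (no CoDuals), the grouping can be sketched as follows. toy_unflatten_varargs is a hypothetical name; the real function additionally has to construct a CoDual for the grouped arguments.

```julia
# Hypothetical sketch: keep the first nargs - 1 arguments as they are, and
# group everything from position nargs onwards into a single tuple.
toy_unflatten_varargs(isva::Bool, args::Tuple, ::Val{nargs}) where {nargs} =
    isva ? (args[1:nargs-1]..., args[nargs:end]) : args

toy_unflatten_varargs(true, (5.0, 4.0, 3.0), Val(2))  # (5.0, (4.0, 3.0))
```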
Mooncake.__value_and_gradient!! (Method)
__value_and_gradient!!(rule, f::CoDual, x::CoDual...)
Note: this is not part of the public Mooncake.jl interface, and may change without warning.
Equivalent to __value_and_pullback!!(rule, 1.0, f, x...); assumes f returns a Float64.
using Mooncake
using Mooncake: build_rrule, zero_tangent
# Set up the problem.
f(x, y) = sum(x .* y)
x = [2.0, 2.0]
y = [1.0, 1.0]
rule = build_rrule(f, x, y)
# Allocate tangents. These will be written to in-place. You are free to re-use these if you
# compute gradients multiple times.
tf = zero_tangent(f)
tx = zero_tangent(x)
ty = zero_tangent(y)
# Do AD.
Mooncake.__value_and_gradient!!(
rule, Mooncake.CoDual(f, tf), Mooncake.CoDual(x, tx), Mooncake.CoDual(y, ty)
)
# output
(4.0, (NoTangent(), [1.0, 1.0], [2.0, 2.0]))
Mooncake.__value_and_pullback!! (Method)
__value_and_pullback!!(rule, ȳ, f::CoDual, x::CoDual...; y_cache=nothing)
Note: this is not part of the public Mooncake.jl interface, and may change without warning.
In-place version of value_and_pullback!! in which the arguments have been wrapped in CoDuals. Note that any mutable data in f and x will be incremented in-place. As such, if calling this function multiple times with different values of x, you should be careful to ensure that you zero-out the tangent fields of x each time.
Mooncake._add_to_primal (Function)
_add_to_primal(p::P, t::T, unsafe::Bool=false) where {P, T}
Adds t to p, returning a P. It must be the case that tangent_type(P) == T.
If unsafe is true and P is a composite type, then _add_to_primal will construct a new instance of P by directly invoking the :new instruction for P, rather than attempting to use the default constructor for P. This is fine if you are confident that the new P constructed by adding t to p will always be a valid instance of P, but could cause problems if you are not confident of this.
This is, for example, fine for the following type:
struct Foo{T}
x::Vector{T}
y::Vector{T}
function Foo(x::Vector{T}, y::Vector{T}) where {T}
@assert length(x) == length(y)
return new{T}(x, y)
end
end
Here, the value returned by _add_to_primal will satisfy the invariant asserted in the inner constructor for Foo.
Mooncake._diff (Method)
_diff(p::P, q::P) where {P}
Required for testing.
Computes the difference between p and q, which must be of the same type, P. Returns a tangent of type tangent_type(P).
Mooncake._dot (Method)
_dot(t::T, s::T)::Float64 where {T}
Required for testing. Should be defined for all standard tangent types.
Inner product between tangents t and s. Must return a Float64. Always available because all tangent types correspond to finite-dimensional vector spaces.
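As an illustration of how such an inner product can be defined recursively, here is a sketch for a few tangent-like types: Float64, arrays of Float64, Tuples, NamedTuples, and a stand-in for NoTangent which contributes zero. The names toy_dot and ToyNoTangent are hypothetical.

```julia
# Hypothetical stand-in for NoTangent: a tangent with no components.
struct ToyNoTangent end

# Hypothetical recursive inner product: composite tangents reduce to a sum of
# the inner products of their parts.
toy_dot(t::Float64, s::Float64) = t * s
toy_dot(::ToyNoTangent, ::ToyNoTangent) = 0.0
toy_dot(t::Array{Float64}, s::Array{Float64}) = sum(map(*, t, s); init=0.0)
toy_dot(t::Tuple, s::Tuple) = sum(map(toy_dot, t, s); init=0.0)
toy_dot(t::NamedTuple{N}, s::NamedTuple{N}) where {N} = toy_dot(Tuple(t), Tuple(s))

toy_dot((1.0, [1.0, 2.0], ToyNoTangent()), (2.0, [3.0, 4.0], ToyNoTangent()))  # 13.0
```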
Mooncake._findall (Method)
_findall(cond, x::Tuple)
Type-stable version of findall for Tuples. Should constant-fold if cond can be determined from the type of x.
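A sketch of one way to write findall over a tuple by recursion, so that the result is itself a tuple of indices. toy_findall is a hypothetical name. When cond depends only on the types of the elements (for example v -> v isa Float64), the compiler may be able to fold the whole computation, though this sketch makes no guarantee of that.

```julia
# Hypothetical recursive findall for tuples, returning a tuple of indices.
toy_findall(cond, x::Tuple) = _toy_findall(cond, x, 1)

_toy_findall(cond, ::Tuple{}, i::Int) = ()
function _toy_findall(cond, x::Tuple, i::Int)
    rest = _toy_findall(cond, Base.tail(x), i + 1)
    return cond(first(x)) ? (i, rest...) : rest
end

toy_findall(v -> v isa Float64, (1, 2.0, "a", 3.0))  # (2, 4)
```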
Mooncake._foreigncall_ (Method)
function _foreigncall_(
::Val{name}, ::Val{RT}, AT::Tuple, ::Val{nreq}, ::Val{calling_convention}, x...
) where {name, RT, nreq, calling_convention}
:foreigncall nodes get translated into calls to this function. For example,
Expr(:foreigncall, :foo, Tout, (A, B), nreq, :ccall, args...)
becomes
_foreigncall_(Val(:foo), Val(Tout), (Val(A), Val(B)), Val(nreq), Val(:ccall), args...)
Please consult the Julia documentation for more information on how foreigncall nodes work, and consult this package's tests for examples.
Credit: Umlaut.jl has the original implementation of this function. This is largely copied over from there.
Mooncake._map (Method)
_map(f, x...)
Same as map, but requires all elements of x to have equal length. The usual function map doesn't enforce this for Arrays.
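A sketch of the length check. toy_map is a hypothetical name, and throwing an ArgumentError is an assumption about how a mismatch should be reported.

```julia
# Hypothetical sketch: like map, but reject inputs of different lengths
# instead of silently stopping at the shortest one.
function toy_map(f, x...)
    n = length(first(x))
    all(xi -> length(xi) == n, x) ||
        throw(ArgumentError("all inputs to toy_map must have equal length"))
    return map(f, x...)
end

toy_map(+, [1, 2], [3, 4])  # [4, 6]
```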
Mooncake._map_if_assigned! (Method)
_map_if_assigned!(f::F, y::DenseArray, x1::DenseArray{P}, x2::DenseArray)
Similar to the other method of _map_if_assigned!: for all n, if x1[n] is assigned, writes f(x1[n], x2[n]) to y[n], otherwise leaves y[n] unchanged.
Requires that y, x1, and x2 have the same size.
Mooncake._map_if_assigned! (Method)
_map_if_assigned!(f, y::DenseArray, x::DenseArray{P}) where {P}
For all n, if x[n] is assigned, then writes the value returned by f(x[n]) to y[n], otherwise leaves y[n] unchanged.
Equivalent to map!(f, y, x) if P is a bits type, as every element will always be assigned.
Requires that y and x have the same size.
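The single-input behaviour can be sketched with isassigned. toy_map_if_assigned! is a hypothetical name, restricted to Vectors for brevity.

```julia
# Hypothetical sketch: write f(x[n]) into y[n] only where x[n] is assigned.
function toy_map_if_assigned!(f, y::Vector, x::Vector)
    size(y) == size(x) || throw(DimensionMismatch("y and x must have the same size"))
    for n in eachindex(y, x)
        if isassigned(x, n)
            y[n] = f(x[n])
        end
    end
    return y
end

x = Vector{Any}(undef, 3)  # all elements start unassigned
x[1] = 1.0
x[3] = 3.0
toy_map_if_assigned!(v -> 2v, zeros(3), x)  # [2.0, 0.0, 6.0]
```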
Mooncake._new_ (Method)
_new_(::Type{T}, x::Vararg{Any, N}) where {T, N}
One-liner which calls the :new instruction with type T with arguments x.
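A sketch of how a generated function can emit a raw :new instruction. toy_new and ToyPositive are hypothetical names. The second call shows why unsafe construction (as discussed for _add_to_primal) can produce instances which violate invariants enforced by an inner constructor.

```julia
# Hypothetical sketch: emit Expr(:new, T, x[1], ..., x[N]), which builds a T
# directly from its fields without running any constructor.
@generated function toy_new(::Type{T}, x::Vararg{Any,N}) where {T,N}
    return Expr(:new, :T, map(n -> :(x[$n]), 1:N)...)
end

# A type whose inner constructor enforces an invariant.
struct ToyPositive
    a::Int
    ToyPositive(a::Int) = a > 0 ? new(a) : throw(ArgumentError("a must be positive"))
end

toy_new(ToyPositive, 5).a   # 5
toy_new(ToyPositive, -1).a  # -1: the inner constructor's check is bypassed
```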
Mooncake._scale (Method)
_scale(a::Float64, t::T) where {T}
Required for testing. Should be defined for all standard tangent types.
Multiply tangent t by scalar a. Always possible because any given tangent type must correspond to a vector space. Not using * in order to avoid piracy.
Mooncake._splat_new_ (Method)
_splat_new_(::Type{P}, x::Tuple) where {P}
Function which replaces instances of :splatnew.
Mooncake._typeof (Method)
_typeof(x)
Central definition of typeof, which is specific to the use required in this package.
Mooncake.ad_stmt_info (Method)
ad_stmt_info(line::ID, comms_id::Union{ID, Nothing}, fwds, rvs)
Convenient constructor for ADStmtInfo. If either fwds or rvs is not a vector, __vec promotes it to a single-element Vector.
Mooncake.add_data! (Method)
add_data!(info::ADInfo, data)::ID
Equivalent to add_data!(info.shared_data_pairs, data).
Mooncake.add_data! (Method)
add_data!(p::SharedDataPairs, data)::ID
Puts data into p, and returns the id associated to it. This id should be assumed to be available during the forwards- and reverse-passes of AD, and it should further be assumed that the value associated to this id is always data.
Mooncake.add_data_if_not_singleton! (Method)
add_data_if_not_singleton!(p::Union{ADInfo, SharedDataPairs}, x)
Returns x if it is a singleton, or the ID of the ssa which will contain it on the forwards- and reverse-passes. The reason for this is that if something is a singleton, it can be inserted directly into the IR.
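The id-returning pattern and the singleton shortcut can be sketched together. ToySharedData, toy_add_data! and toy_add_data_if_not_singleton! are hypothetical; integer ids stand in for the IDs used in the IR, and Base.issingletontype is used as the singleton test.

```julia
# Hypothetical registry: each stored value gets an integer id under which it
# can later be retrieved.
struct ToySharedData
    pairs::Vector{Pair{Int,Any}}
end
ToySharedData() = ToySharedData(Pair{Int,Any}[])

function toy_add_data!(p::ToySharedData, data)
    id = length(p.pairs) + 1
    push!(p.pairs, id => data)
    return id
end

# Singletons carry no data, so they are returned as-is rather than stored.
function toy_add_data_if_not_singleton!(p::ToySharedData, x)
    return Base.issingletontype(typeof(x)) ? x : toy_add_data!(p, x)
end
```

For instance, passing the function sin returns sin itself (its type is a singleton), while passing a Vector stores it and returns its id.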
Mooncake.always_initialised (Method)
always_initialised(::Type{P}) where {P}
Returns a tuple with number of fields equal to the number of fields in P. The nth field is set to true if the nth field of P is initialised, and false otherwise.
Mooncake.arrayify (Method)
arrayify(x::CoDual{<:AbstractArray{<:BlasFloat}})
Return the primal field of x, and convert its fdata into an array of the same type as the primal. This operation is not guaranteed to be possible for all array types, but seems to be possible for all array types of interest so far.
Mooncake.build_primitive_rrule (Method)
build_primitive_rrule(sig::Type{<:Tuple})
Construct an rrule for signature sig. For this function to be called in build_rrule, you must also ensure that is_primitive(context_type, sig) is true. The callable returned by this must obey the rrule interface, but there are no restrictions on the type of callable itself. For example, you might return a callable struct. By default, this function returns rrule!! so, most of the time, you should just implement a method of rrule!!.
Extended Help
The purpose of this function is to permit computation at rule construction time, which can be re-used at runtime. For example, you might wish to derive some information from sig which you use at runtime (e.g. the fdata type of one of the arguments). While constant propagation will often optimise this kind of computation away, it will sometimes fail to do so in hard-to-predict circumstances. Consequently, if you need certain computations not to happen at runtime in order to guarantee good performance, you might wish to e.g. emit a callable struct with type parameters which are the result of this computation. In this context, the motivation for using this function is the same as that of using staged programming (e.g. via @generated functions) more generally.
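A small sketch of the idea, assuming the quantity computed at construction time is simply the number of arguments in sig. ToyNArgRule and toy_build_primitive_rule are hypothetical; a real rule would compute something more useful, such as an fdata type, and would obey the rrule interface.

```julia
# Hypothetical callable struct whose type parameter N is computed once from
# `sig` when the rule is built, rather than on every call.
struct ToyNArgRule{N} end

function toy_build_primitive_rule(sig::Type{<:Tuple})
    N = fieldcount(sig)        # computed at rule-construction time
    return ToyNArgRule{N}()
end

# At call time, N is known from the type of the rule itself.
(::ToyNArgRule{N})(args...) where {N} = length(args) == N

r = toy_build_primitive_rule(Tuple{typeof(sin), Float64})
r(sin, 1.0)  # true
```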
Mooncake.build_rrule (Method)
build_rrule(args...; kwargs...)
Helper method: equivalent to extracting the signature from args and calling build_rrule(sig; kwargs...).
Mooncake.build_rrule (Method)
build_rrule(sig::Type{<:Tuple}; kwargs...)
Helper method: equivalent to build_rrule(Mooncake.get_interpreter(), sig; kwargs...).
Mooncake.build_rrule (Method)
build_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) where {C}
Returns a DerivedRule which is an rrule!! for sig_or_mi in context C. See the docstring for rrule!! for more info.
If debug_mode is true, then all calls to rules are replaced with calls to DebugRRules.
Mooncake.can_produce_zero_rdata_from_type (Method)
can_produce_zero_rdata_from_type(::Type{P}) where {P}
Returns whether or not the zero element of the rdata type for primal type P can be obtained from P alone.
Mooncake.codual_type (Method)
codual_type(P::Type)
The type of the CoDual which contains instances of P and associated tangents.
Mooncake.comms_channel (Method)
comms_channel(info::ADStmtInfo)
Return the element of fwds whose ID is the communication ID. Returns Nothing if comms_id is nothing.
Mooncake.conclude_rvs_block (Method)
conclude_rvs_block(
blk::BBlock, pred_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo
)
Generates code which is inserted at the end of each counterpart block in the reverse-pass. Handles phi nodes, and choosing the correct next block to switch to.
Mooncake.const_ad_stmt (Method)
const_ad_stmt(stmt, line::ID, info::ADInfo)
Implementation of make_ad_stmts! used for constants.
Mooncake.const_codual (Method)
const_codual(stmt, info::ADInfo)
Build a CoDual from stmt, with zero / uninitialised fdata. If the resulting CoDual is a bits type, then it is returned. If it is not, then the CoDual is put into shared data, and the ID associated to it in the forwards- and reverse-passes is returned.
Mooncake.const_codual_stmt (Method)
const_codual_stmt(stmt, info::ADInfo)
Returns a :call expression which will return a CoDual whose primal is stmt, and whose tangent is whatever uninit_tangent returns.
Mooncake.const_prop_gotoifnots! (Method)
const_prop_gotoifnots!(ir::IRCode)
Replace all occurrences in ir of goto %n if not true in block b with a goto b + 1, and all occurrences of goto %n if not false with goto n, and make the adjustments to ir that this necessitates.
Mooncake.create_comms_insts! (Method)
create_comms_insts!(ad_stmts_blocks::ADStmts, info::ADInfo)
This function produces code which can be inserted into the forwards-pass and reverse-pass at specific locations to implement the promise associated to the comms_id field of the ADStmtInfo type, namely that if you assign a value to comms_id on the forwards-pass, the same value will be available at comms_id on the reverse-pass.
For each basic block represented in ADStmts:
- create a stack containing a Tuple which can hold all of the values associated to the comms_ids for each statement. Put this stack in shared data.
- create instructions which can be inserted at the end of the block generated to perform the forwards-pass (in forwards_pass_ir) which will put all of the data associated to the comms_ids into shared data, and
- create instructions which can be inserted at the start of the block generated to perform the reverse-pass (in pullback_ir), which will extract all of the data put into shared data by the instructions generated by the previous point, and assign them to the comms_ids.
Returns a Tuple{Vector{IDInstPair}, Vector{IDInstPair}}. The nth element of each Vector corresponds to the instructions to be inserted into the forwards- and reverse-passes, respectively, for the nth block in ad_stmts_blocks.
Mooncake.deref_and_zero_stmts (Method)
deref_and_zero_stmts(P, ref_id, val_id)
Equivalent to something like
val = ref[]
ref[] = zero_rdata_from_type(P)
Mooncake.fcodual_type (Method)
fcodual_type(P::Type)
The type of the CoDual which contains instances of P and its fdata.
Mooncake.fdata (Method)
fdata(t)::fdata_type(typeof(t))
Extract the forwards data from tangent t.
Mooncake.fdata_field_type (Method)
fdata_field_type(::Type{P}, n::Int) where {P}
Returns the type of the nth field of the fdata type associated to P. Will be a PossiblyUninitTangent if said field can be undefined.
Mooncake.fdata_type (Method)
fdata_type(T)
Returns the type of the forwards data associated to a tangent of type T.
Extended help
Rules in Mooncake.jl do not operate on tangents directly. Rather, functionality is defined to split each tangent into two components, which we call fdata (forwards-pass data) and rdata (reverse-pass data). In short, any component of a tangent which is identified by its address (e.g. a mutable struct or an Array) gets passed around on the forwards-pass of AD and is incremented in-place on the reverse-pass, while components of tangents identified by their value get propagated and accumulated only on the reverse-pass.
Given a tangent type T, you can find out what type its fdata and rdata must be with fdata_type(T) and rdata_type(T) respectively. A consequence of this is that there is exactly one valid fdata type and rdata type for each primal type.
Given a tangent t, you can get its fdata and rdata using f = fdata(t) and r = rdata(t) respectively. f and r can be re-combined to recover the original tangent using the binary version of tangent: tangent(f, r). It must always hold that
tangent(fdata(t), rdata(t)) === t
The need for all of this is explained in the docs, but for now it suffices to consider our running examples again, and to see what their fdata and rdata look like.
Int
Ints are non-differentiable types, so there is nothing to pass around on the forwards- or reverse-pass. Therefore
julia> fdata_type(tangent_type(Int)), rdata_type(tangent_type(Int))
(NoFData, NoRData)
Float64
The tangent type of Float64 is Float64. Float64s are identified by their value / have no fixed address, so
julia> (fdata_type(Float64), rdata_type(Float64))
(NoFData, Float64)
Vector{Float64}
The tangent type of Vector{Float64} is Vector{Float64}. A Vector{Float64} is identified by its address, so
julia> (fdata_type(Vector{Float64}), rdata_type(Vector{Float64}))
(Vector{Float64}, NoRData)
Tuple{Float64, Vector{Float64}, Int}
This is an example of a type which has both fdata and rdata. The tangent type for Tuple{Float64, Vector{Float64}, Int} is Tuple{Float64, Vector{Float64}, NoTangent}. Tuples have no fixed memory address, so we interrogate each field on its own. We have already established the fdata and rdata types for each element, so we recurse to obtain:
julia> T = tangent_type(Tuple{Float64, Vector{Float64}, Int})
Tuple{Float64, Vector{Float64}, NoTangent}
julia> (fdata_type(T), rdata_type(T))
(Tuple{NoFData, Vector{Float64}, NoFData}, Tuple{Float64, NoRData, NoRData})
The zero tangent for (5.0, [5.0]) is t = (0.0, [0.0]). fdata(t) returns (NoFData(), [0.0]), where the second element is === to the second element of t. rdata(t) returns (0.0, NoRData()). In this example, t contains a mixture of data, some of which is identified by its value and some of which is identified by its address, so there is some fdata and some rdata.
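The split can be mimicked in a few lines of plain Julia. The sketch below is not Mooncake's implementation: ToyNoFData, ToyNoRData, toy_fdata, toy_rdata and toy_tangent are hypothetical stand-ins that only cover Float64, Vector{Float64} and tuples of them, but they reproduce the behaviour described above, including the round-trip property.

```julia
# Hypothetical stand-ins for NoFData / NoRData; these are not Mooncake's types.
struct ToyNoFData end
struct ToyNoRData end

# Float64 is identified by its value, so all of it is rdata.
toy_fdata(::Float64) = ToyNoFData()
toy_rdata(t::Float64) = t

# Vector{Float64} is identified by its address, so all of it is fdata.
toy_fdata(t::Vector{Float64}) = t
toy_rdata(::Vector{Float64}) = ToyNoRData()

# Tuples have no address, so split them element-wise.
toy_fdata(t::Tuple) = map(toy_fdata, t)
toy_rdata(t::Tuple) = map(toy_rdata, t)

# Recombine: each element has exactly one non-trivial component.
toy_tangent(::ToyNoFData, r::Float64) = r
toy_tangent(f::Vector{Float64}, ::ToyNoRData) = f
toy_tangent(f::Tuple, r::Tuple) = map(toy_tangent, f, r)

t = (0.0, [0.0])
f, r = toy_fdata(t), toy_rdata(t)
```

Because the vector is passed through by reference, f[2] is the very same array as t[2], which is what allows it to be incremented in-place on the reverse-pass.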
Structs
Structs are handled in more-or-less the same way as Tuples, albeit with the possibility of undefined fields needing to be explicitly handled. For example, a struct such as
julia> struct Foo
           x::Float64
           y
           z::Int
       end
has tangent type
julia> tangent_type(Foo)
Tangent{@NamedTuple{x::Float64, y, z::NoTangent}}
Its fdata and rdata are given by special FData and RData types:
julia> (fdata_type(tangent_type(Foo)), rdata_type(tangent_type(Foo)))
(Mooncake.FData{@NamedTuple{x::NoFData, y, z::NoFData}}, Mooncake.RData{@NamedTuple{x::Float64, y, z::NoRData}})
Practically speaking, FData and RData both have the same structure as Tangents and are just used in different contexts.
Mutable Structs
The fdata for a mutable struct is its tangent, and it has no rdata. This is because mutable structs have fixed memory addresses, and can therefore be incremented in-place. For example,
julia> mutable struct Bar
           x::Float64
           y
           z::Int
       end
has tangent type
julia> tangent_type(Bar)
MutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}
and fdata / rdata types
julia> (fdata_type(tangent_type(Bar)), rdata_type(tangent_type(Bar)))
(MutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}, NoRData)
Primitive Types
As with tangents, each primitive type must specify what its fdata and rdata is. See specific examples for details.
Mooncake.fix_up_invoke_inference!
— Method
fix_up_invoke_inference!(ir::IRCode)
The Problem
Consider the following:
@noinline function bar!(x)
    x .*= 2
end

function foo!(x)
    bar!(x)
    return nothing
end
In this case, the IR associated to Tuple{typeof(foo!), Vector{Float64}} will be something along the lines of
julia> Base.code_ircode_by_type(Tuple{typeof(foo!), Vector{Float64}})
1-element Vector{Any}:
2 1 ─ invoke Main.bar!(_2::Vector{Float64})::Any
3 └── return Main.nothing
=> Nothing
Observe that the type inferred for the first line is Any. Inference is at liberty to do this without any risk of performance problems because the first line is not used anywhere else in the function. Had this line been used elsewhere in the function, inference would have inferred its type to be Vector{Float64}.
This causes performance problems for Mooncake, because it uses the return type to do various things, including allocating storage for quantities required on the reverse-pass. Consequently, inference inferring Any rather than Vector{Float64} causes type instabilities in the code that Mooncake generates, which can have catastrophic consequences for performance.
The Solution
:invoke expressions contain the Core.MethodInstance associated to them, which contains a Core.CodeCache, which contains the return type of the :invoke. This function looks for :invoke statements in ir whose return type is inferred to be Any, and modifies it to be the return type given by the code cache.
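The gap described in The Problem can be observed without Mooncake. The following standalone check (not part of Mooncake) asks inference for the return type of bar! on its own; the concrete result is the kind of information that the code cache holds, even though the unused call site inside foo! may be typed as Any. The exact IR printed for foo! can vary between Julia versions.

```julia
@noinline function bar!(x)
    x .*= 2
end

function foo!(x)
    bar!(x)
    return nothing
end

# Inference of bar! in isolation yields a concrete return type, because the
# broadcast assignment `x .*= 2` returns the mutated array `x`.
rts = Base.return_types(bar!, (Vector{Float64},))
println(rts)
```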
Mooncake.flat_product
— Method
flat_product(xs...)
Equivalent to vec(collect(Iterators.product(xs...))).
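To make the ordering of the result concrete, here is a hypothetical stand-in, toy_flat_product, written directly from the stated equivalence (it is not Mooncake's source). The first argument varies fastest, because collect produces an array whose first dimension corresponds to the first iterator, and vec flattens in column-major order.

```julia
# Hypothetical stand-in built from the stated equivalence; not Mooncake's code.
toy_flat_product(xs...) = vec(collect(Iterators.product(xs...)))

# Every combination appears once, with the first iterator varying fastest.
p = toy_flat_product(1:2, [:a, :b])
```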
Mooncake.foreigncall_to_call
— Method
foreigncall_to_call(inst, sp_map::Dict{Symbol, CC.VarState})
If inst is a :foreigncall expression, translate it into an equivalent :call expression. If it is anything else, just return inst. See Mooncake._foreigncall_ for details.
sp_map maps the names of the static parameters to their values. This function is intended to be called in the context of an IRCode, in which case the values of sp_map are given by the sptypes field of said IRCode. The keys should generally be obtained from the Method from which the IRCode is derived. See Mooncake.normalise! for more details.
The purpose of this transformation is to make it possible to differentiate :foreigncall expressions in the same way as a primitive :call expression, i.e. via an rrule!!.
Mooncake.forwards_pass_ir
— Method
forwards_pass_ir(ir::BBCode, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)
Produce the IR associated to the OpaqueClosure which runs most of the forwards-pass.
Mooncake.fwd_ir
— Method
fwd_ir(
    sig::Type{<:Tuple};
    interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true
)::IRCode
Warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.
Generate the Core.Compiler.IRCode used to construct the forwards-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.
For example, if you wanted to get the IR associated to the forwards pass for the call map(sin, randn(10)), you could do either of the following:
julia> Mooncake.fwd_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
true
julia> Mooncake.fwd_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
true
Arguments
sig::Type{<:Tuple}: the signature of the call to be differentiated.
Keyword Arguments
interp: the interpreter to use to obtain the primal IR.
debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
do_inline::Bool: whether to apply an inlining pass prior to returning the IR generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
Mooncake.gc_preserve
— Method
gc_preserve(xs...)
A no-op function. Its rrule!! ensures that the memory associated to xs is not freed until the pullback that it returns is run.
Mooncake.generate_ir
— Method
generate_ir(
    interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true
)
Used by build_rrule, and the various debugging tools: primal_ir, fwd_ir, and rvs_ir.
Mooncake.get_const_primal_value
— Method
get_const_primal_value(x::GlobalRef)
Get the value associated to x. For GlobalRefs, verify that x is indeed a constant.
Mooncake.get_interpreter
— Method
get_interpreter()
Returns a MooncakeInterpreter appropriate for the current world age. Will use a cached interpreter if one already exists for the current world age, otherwise creates a new one.
This should be preferred over constructing a MooncakeInterpreter directly.
Mooncake.get_primal_type
— Method
get_primal_type(info::ADInfo, x)
Returns the static / inferred type associated to x.
Mooncake.get_rev_data_id
— Method
get_rev_data_id(info::ADInfo, x)
Returns the ID associated to the line in the reverse pass which will contain the reverse data for x. If x is not an Argument or ID, then nothing is returned.
Mooncake.get_tangent_field
— Method
get_tangent_field(t::Union{MutableTangent, Tangent}, i::Int)
Gets the ith field of data in t.
Has the same semantics that getfield would have if the data in the fields field of t were actually fields of t. This is the moral equivalent of getfield for MutableTangent.
Mooncake.inc_args
— Method
inc_args(stmt)
Increment by 1 the n field of any Arguments present in stmt. Used in make_ad_stmts!.
Mooncake.increment!!
— Method
increment!!(x::T, y::T) where {T}
Add x to y. If ismutabletype(T), then increment!!(x, y) === x must hold. That is, increment!! will mutate x. This must apply recursively if T is a composite type whose fields are mutable.
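The contract can be illustrated with a hypothetical toy_increment!! covering three cases (it is not Mooncake's implementation): an immutable scalar, a mutable array, and an immutable tuple containing a mutable array, where the recursion keeps the mutable element updated in place.

```julia
# Hypothetical illustration of the increment!! contract; not Mooncake's code.

# Immutable values (e.g. Float64): return a new value.
toy_increment!!(x::Float64, y::Float64) = x + y

# Mutable values (e.g. Vector{Float64}): mutate `x` in place and return it,
# so that `toy_increment!!(x, y) === x` holds.
function toy_increment!!(x::Vector{Float64}, y::Vector{Float64})
    x .+= y
    return x
end

# Immutable containers of mutable data: recurse element-wise, so any mutable
# elements are still updated in place.
toy_increment!!(x::Tuple, y::Tuple) = map(toy_increment!!, x, y)

a = [1.0, 2.0]
b = toy_increment!!(a, [10.0, 20.0])
t = toy_increment!!((1.0, a), (2.0, [1.0, 1.0]))
```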
Mooncake.increment_and_get_rdata!
— Method
increment_and_get_rdata!(fdata, zero_rdata, cr_tangent)
Increment fdata by the fdata component of the ChainRules.jl-style tangent, cr_tangent, and return the rdata component of cr_tangent by adding it to zero_rdata.
Mooncake.increment_field!!
— Method
increment_field!!(x::T, y::V, f) where {T, V}
increment!! the field f of x by y, and return the updated x.
Mooncake.increment_rdata!!
— Method
increment_rdata!!(t::T, r)::T where {T}
Increment the rdata component of tangent t by r, and return the updated tangent. Useful for implementing getfield-like rules for mutable structs, pointers, dicts, etc.
Mooncake.increment_ref_stmts
— Method
increment_ref_stmts(ref_id::ID, inc_data)::Vector{IDInstPair}
Equivalent to ref[] = increment!!(ref[], inc_data), where ref and inc_data are the values associated to ref_id and inc_data respectively.
Mooncake.infer_ir!
— Method
infer_ir!(ir::IRCode) -> IRCode
Runs type inference on ir, which mutates ir, and returns it.
Note: the compiler will not infer the types of anything where the corresponding element of ir.stmts.flag is not set to Core.Compiler.IR_FLAG_REFINED. Nor will it attempt to refine the type of the value returned by an :invoke expression. Consequently, if you find that the types in your IR are not being refined, you may wish to check that neither of these things is happening.
Mooncake.interpolate_boundschecks!
— Method
interpolate_boundschecks!(ir::IRCode)
For every x = Expr(:boundscheck, value) in ir, interpolate value into all uses of x. This is only required in order to ensure that literal versions of memoryrefget, memoryrefset!, getfield, and setfield! work effectively. If they are removed through improvements to the way that we handle constant propagation inside Mooncake, then this functionality can be removed.
Mooncake.intrinsic_to_function
— Method
intrinsic_to_function(inst)
If inst is a :call expression to a Core.IntrinsicFunction, replace it with a call to the corresponding function from Mooncake.IntrinsicsWrappers, else return inst.
cglobal is a special case: it requires that its first argument be static in exactly the same way as :foreigncall. See IntrinsicsWrappers.__cglobal for more info.
The purpose of this transformation is to make it possible to write rules for intrinsic calls using dispatch in a type-stable way. See IntrinsicsWrappers for more context.
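A short Base-only illustration of why the wrappers are needed: every intrinsic shares the single type Core.IntrinsicFunction, so a method cannot dispatch on which intrinsic is being called, whereas each ordinary Julia function has its own singleton type. The wrapper names below (toy_add_float, toy_mul_float, toy_rule_name) are hypothetical and are not the contents of Mooncake.IntrinsicsWrappers.

```julia
# All intrinsics have the same type, so dispatch cannot tell them apart.
same_type = typeof(Core.Intrinsics.add_float) === typeof(Core.Intrinsics.mul_float)

# Hypothetical wrappers: each Julia function gets its own singleton type.
toy_add_float(x::Float64, y::Float64) = Core.Intrinsics.add_float(x, y)
toy_mul_float(x::Float64, y::Float64) = Core.Intrinsics.mul_float(x, y)

# Something rule-like can now dispatch on the wrapper's type.
toy_rule_name(::typeof(toy_add_float)) = "add_float"
toy_rule_name(::typeof(toy_mul_float)) = "mul_float"
```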
Mooncake.ircode
— Function
ircode(
    inst::Vector{Any},
    argtypes::Vector{Any},
    sptypes::Vector{CC.VarState}=CC.VarState[],
) -> IRCode
Constructs an instance of an IRCode. This is useful for constructing test cases with known properties.
No optimisations or type inference are performed on the resulting IRCode, so that the IRCode contains exactly what is intended by the caller. Please make use of infer_types! if you require the types to be inferred.
Edges in PhiNodes, GotoIfNots, and GotoNodes found in inst must refer to lines (as in CodeInfo). In the IRCode returned by this function, these line references are translated into block references.
Mooncake.is_always_fully_initialised
— Method
is_always_fully_initialised(P::DataType)::Bool
True if all fields in P are always initialised. Put differently, there are no inner constructors which permit partial initialisation.
Mooncake.is_always_initialised
— Method
is_always_initialised(P::DataType, n::Int)::Bool
True if the nth field of P is always initialised. If the nth fieldtype of P isbitstype, then this is distinct from asking whether the nth field is always defined. An isbits field is always defined, but is not always explicitly initialised.
Mooncake.is_primitive
— Method
is_primitive(::Type{Ctx}, sig) where {Ctx}
Returns a Bool specifying whether the methods specified by sig are considered primitives in contexts of type Ctx.
is_primitive(DefaultCtx, Tuple{typeof(sin), Float64}) will return true if calling sin(5.0) should be treated as primitive when the context is a DefaultCtx.
Observe that this means that whether or not something is a primitive in a particular context depends only on static information, not on any run-time information that might live in a particular instance of Ctx.
Mooncake.is_unreachable_return_node
— Method
is_unreachable_return_node(x::ReturnNode)
Determine whether x is a ReturnNode, and if it is, whether it is also unreachable. This is purely a function of whether or not its val field is defined.
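The check can be reproduced with Core.ReturnNode directly. The function below is a hypothetical re-implementation written from the description above, not Mooncake's source: an unreachable ReturnNode is one constructed without a value, which leaves its val field undefined.

```julia
# Hypothetical re-implementation of the described check; not Mooncake's code.
toy_is_unreachable_return_node(x) = x isa Core.ReturnNode && !isdefined(x, :val)

unreachable = Core.ReturnNode()       # no value: marks unreachable code
ordinary = Core.ReturnNode(nothing)   # returns `nothing`
```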
Mooncake.is_used
— Method
is_used(info::ADInfo, id::ID)::Bool
Returns true if id is used by any of the lines in the IR, false otherwise.
Mooncake.is_vararg_and_sparam_names
— Method
is_vararg_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}}
Finds the method associated to sig, and calls is_vararg_and_sparam_names on it.
Mooncake.is_vararg_and_sparam_names
— Method
is_vararg_and_sparam_names(mi::Core.MethodInstance)
Calls is_vararg_and_sparam_names on mi.def::Method.
Mooncake.is_vararg_and_sparam_names
— Method
is_vararg_and_sparam_names(m::Method)
Returns a 2-tuple. The first element is true if m is a vararg method, and false if not. The second element contains the names of the static parameters associated to m.
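Both pieces of information are available from a Method object. The sketch below is a hypothetical re-implementation, not Mooncake's: it assumes the vararg flag can be read from the isva field of Method, and it collects static parameter names by unwrapping the where clauses (UnionAll layers) of the method's signature, outermost first.

```julia
# Hypothetical sketch, not Mooncake's implementation.
function toy_is_vararg_and_sparam_names(m::Method)
    names = Symbol[]
    sig = m.sig
    # Each `where` clause wraps the signature in a UnionAll; peel them off.
    while sig isa UnionAll
        push!(names, sig.var.name)
        sig = sig.body
    end
    return m.isva, names
end

g(x::T, y::Vector{S}) where {T, S} = x
h(xs...) = xs

rg = toy_is_vararg_and_sparam_names(which(g, (Int, Vector{Float64})))
rh = toy_is_vararg_and_sparam_names(which(h, (Int, Int)))
```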
Mooncake.lgetfield
— Method
lgetfield(x, f::Val)
An implementation of getfield in which the field f is specified statically via a Val. This enables the implementation to be type-stable even when it is not possible to constant-propagate f. Moreover, it enables the pullback to also be type-stable.
It will always be the case that
getfield(x, :f) === lgetfield(x, Val(:f))
getfield(x, 2) === lgetfield(x, Val(2))
This approach is identical to the one taken by Zygote.jl to circumvent the same problem. Zygote.jl calls the function literal_getfield, while we call it lgetfield.
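The idea can be sketched in one line: because the field is a type parameter of the Val, the method is specialised on it, so the compiler knows which field is being read. toy_lgetfield and ToyPoint are hypothetical names used only for illustration; Mooncake's lgetfield may be implemented differently.

```julia
# Hypothetical `toy_lgetfield`, not Mooncake's code: the field name or index is
# carried in the type of the `Val`, so it is known at compile time.
toy_lgetfield(x, ::Val{f}) where {f} = getfield(x, f)

struct ToyPoint
    a::Float64
    b::Int
end

p = ToyPoint(1.5, 2)
```

Inference can then determine the result type from the argument types alone, e.g. Base.return_types(toy_lgetfield, (ToyPoint, Val{:b})) gives Int.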
Mooncake.lgetfield
— Method
lgetfield(x, ::Val{f}, ::Val{order}) where {f, order}
Like getfield, but with the field and access order encoded as types.
Mooncake.lift_gc_preservation
— Method
lift_gc_preserve(inst)
Expressions of the form
y = GC.@preserve x1 x2 foo(args...)
get lowered to
token = Expr(:gc_preserve_begin, x1, x2)
y = expr
Expr(:gc_preserve_end, token)
These expressions guarantee that any memory associated to x1 and x2 is not freed until the :gc_preserve_end expression is reached.
In the context of reverse-mode AD, we must ensure that the memory associated to x1, x2 and their fdata is available during the reverse pass code associated to expr. We do this by preventing the memory from being freed until the :gc_preserve_begin is reached on the reverse pass.
To achieve this, we replace the primal code with
# store `x` in `pb_gc_preserve` to prevent it from being freed.
_, pb_gc_preserve = rrule!!(zero_fcodual(gc_preserve), x1, x2)
# Differentiate the `:call` expression in the usual way.
y, foo_pb = rrule!!(zero_fcodual(foo), args...)
# Do not permit the GC to free `x` here.
nothing
The pullback should be something along the lines of
# no pullback associated to `nothing`.
nothing
# Run the pullback associated to `foo` in the usual manner. `x` must be available.
_, dargs... = foo_pb(dy)
# No-op pullback associated to `gc_preserve`.
pb_gc_preserve(NoRData())
Mooncake.lift_getfield_and_others
— Method
lift_getfield_and_others(inst)
Converts expressions of the form getfield(x, :a) into lgetfield(x, Val(:a)). This has identical semantics, but is performant in the absence of proper constant propagation.
Does the same for...
Mooncake.lookup_ir
— Method
lookup_ir(
    interp::AbstractInterpreter,
    sig_or_mi::Union{Type{<:Tuple}, Core.MethodInstance},
)::Tuple{IRCode, T}
Get the unique IR associated to sig_or_mi under interp. Throws ArgumentErrors if there is no code found, or if more than one IRCode instance is returned.
Returns a tuple containing the IRCode and its return type.
Mooncake.lsetfield!
— Method
lsetfield!(value, name::Val, x, [order::Val])
This function is to setfield! what lgetfield is to getfield. It will always hold that
setfield!(copy(x), :f, v) == lsetfield!(copy(x), Val(:f), v)
setfield!(copy(x), 2, v) == lsetfield!(copy(x), Val(2), v)
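A matching sketch for the setter, under the same assumptions as a Val-based getter: toy_lsetfield! and Counter are hypothetical names, and this is not Mooncake's implementation. Like setfield!, it returns the value that was written, so the equalities above compare the same values.

```julia
# Hypothetical `toy_lsetfield!`, not Mooncake's code: the field name or index
# is carried in the type of the `Val`.
toy_lsetfield!(x, ::Val{f}, v) where {f} = setfield!(x, f, v)

mutable struct Counter
    n::Int
    label::String
end

c1 = Counter(0, "a")
c2 = Counter(0, "a")
r1 = setfield!(c1, :n, 5)            # returns the value written
r2 = toy_lsetfield!(c2, Val(:n), 5)  # same result via a static field name
toy_lsetfield!(c2, Val(2), "b")      # fields can also be addressed by index
```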
Mooncake.make_ad_stmts!
— Function
make_ad_stmts!(inst::NewInstruction, line::ID, info::ADInfo)::ADStmtInfo
Every line in the primal code is associated to one or more lines in the forwards-pass of AD, and one or more lines in the pullback. This function has a method specific to every node type in the Julia SSAIR.
Translates the instruction inst, associated to line in the primal, into a specification of what should happen for this instruction in the forwards- and reverse-passes of AD, and what data should be shared between the forwards- and reverse-passes. Returns this in the form of an ADStmtInfo.
info is a data structure containing various bits of global information that certain types of nodes need access to.
Mooncake.make_switch_stmts
— Method
make_switch_stmts(
    pred_ids::Vector{ID}, target_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo
)
pred_ids comprises the IDs associated to all possible predecessor blocks to the primal block under consideration. Suppose its value is [ID(1), ID(2), ID(3)]; then make_switch_stmts emits code along the lines of
prev_block = pop!(block_stack)
not_pred_was_1 = !(prev_block == ID(1))
not_pred_was_2 = !(prev_block == ID(2))
switch(
not_pred_was_1 => ID(1),
not_pred_was_2 => ID(2),
ID(3)
)
In words: make_switch_stmts emits code which jumps to whichever block preceded the current block during the forwards-pass.
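The run-time effect of that code can be simulated with an ordinary vector used as the block stack. This is a hypothetical analogue, not Mooncake's code generation: block IDs are plain Ints, the forwards-pass pushes the ID of the predecessor it came from, and the reverse-pass pops it and picks the matching predecessor, with the last entry of pred_ids serving as the fallback, as in the final branch of the switch.

```julia
# Hypothetical runtime analogue of the emitted switch; not Mooncake's codegen.
function toy_reverse_target(block_stack::Vector{Int}, pred_ids::Vector{Int})
    prev_block = pop!(block_stack)
    for id in pred_ids[1:end-1]
        prev_block == id && return id
    end
    return pred_ids[end]  # fallback: must be the remaining predecessor
end

block_stack = Int[]
push!(block_stack, 1)  # forwards-pass: reached the block from block 1 ...
push!(block_stack, 3)  # ... and later from block 3

t1 = toy_reverse_target(block_stack, [1, 2, 3])  # reverse-pass sees 3 first
t2 = toy_reverse_target(block_stack, [1, 2, 3])  # then 1
```

The stack is last-in-first-out, so the reverse-pass revisits predecessors in the opposite order to the forwards-pass.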
Mooncake.map_prod
— Method
map_prod(f, xs...)
Equivalent to map(f, flat_product(xs...)).
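A hypothetical stand-in written from the stated equivalence (not Mooncake's source) makes one detail explicit: f is called once per combination and receives that combination as a single tuple, in the same order that flat_product produces them.

```julia
# Hypothetical stand-ins following the stated equivalences; not Mooncake's code.
toy_flat_product(xs...) = vec(collect(Iterators.product(xs...)))
toy_map_prod(f, xs...) = map(f, toy_flat_product(xs...))

# `f` receives each combination as one tuple; the first argument varies fastest.
sums = toy_map_prod(t -> t[1] + t[2], 1:2, [10, 20])
```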
Mooncake.misty_closure
— Method
misty_closure(
    ret_type::Type,
    ir::IRCode,
    @nospecialize env...;
    isva::Bool=false,
    do_compile::Bool=true,
)
Identical to Mooncake.opaque_closure, but returns a MistyClosure rather than a Core.OpaqueClosure.
Mooncake.new_to_call
— Method
new_to_call(x)
If instruction x is a :new expression, replace it with a :call to Mooncake._new_. Otherwise, return x.
The purpose of this transformation is to make it possible to differentiate :new expressions in the same way as a primitive :call expression, i.e. via an rrule!!.
Mooncake.normalise!
— Method
normalise!(ir::IRCode, spnames::Vector{Symbol})
Apply a sequence of standardising transformations to ir which leaves its semantics unchanged, but makes AD more straightforward. In particular, replace
- :foreigncall Exprs with :calls to Mooncake._foreigncall_,
- :new Exprs with :calls to Mooncake._new_,
- :splatnew Exprs with :calls to Mooncake._splat_new_,
- Core.IntrinsicFunctions with counterparts from Mooncake.IntrinsicsWrappers,
- getfield(x, 1) with lgetfield(x, Val(1)), and related transformations,
- memoryrefget calls with lmemoryrefget calls, and related transformations, and
- gc_preserve_begin / gc_preserve_end exprs so that memory release is delayed.
spnames are the names associated to the static parameters of ir. These are needed when handling :foreigncall expressions, in which it is not necessarily the case that all static parameter names have been translated into either types or :static_parameter expressions.
Unfortunately, the static parameter names are not retained in IRCode, and the Method from which the IRCode is derived must be consulted. Mooncake.is_vararg_and_sparam_names provides a convenient way to do this.
Mooncake.opaque_closure
— Method
opaque_closure(
    ret_type::Type,
    ir::IRCode,
    @nospecialize env...;
    isva::Bool=false,
    do_compile::Bool=true,
)::Core.OpaqueClosure{<:Tuple, ret_type}
Construct a Core.OpaqueClosure. Almost equivalent to Core.OpaqueClosure(ir, env...; isva, do_compile), but instead of letting Core.compute_oc_rettype figure out the return type from ir, impose ret_type as the return type.
Warning
User beware: if the Core.OpaqueClosure produced by this function ever returns anything which is not an instance of a subtype of ret_type, you should expect all kinds of awful things to happen, such as segfaults. You have been warned!
Extended Help
This is needed in Mooncake.jl because we make extensive use of our ability to know the return type of a couple of specific OpaqueClosures without actually having constructed them (see LazyDerivedRule). Without the capability to specify the return type, we have to guess what type compute_ir_rettype will return for a given IRCode before we have constructed the IRCode and run type inference on it. This exposes us to details of type inference, which are not part of the public interface of the language, and can therefore vary from Julia version to Julia version (including patch versions). Moreover, even for a fixed Julia version it can be extremely hard to predict exactly what type inference will infer to be the return type of a function.
Failing to correctly guess the return type can happen for a number of reasons, and the kinds of errors that tend to be generated when this fails tell you very little about the underlying cause of the problem.
By specifying the return type ourselves, we remove this dependence. The price we pay for this is the potential for segfaults etc. if we fail to specify ret_type correctly.
Mooncake.optimise_ir!
— Method
optimise_ir!(ir::IRCode, show_ir=false)
Run a fairly standard optimisation pass on ir. If show_ir is true, displays the IR to stdout at various points in the pipeline; this is sometimes useful for debugging.
Mooncake.primal_ir
— Method
primal_ir(sig::Type{<:Tuple}; interp=get_interpreter())::IRCode
Warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.
Get the Core.Compiler.IRCode associated to sig from which a rule can be derived. Roughly equivalent to Base.code_ircode_by_type(sig; interp).
For example, if you wanted to get the IR associated to the call map(sin, randn(10)), you could do one of the following calls:
julia> Mooncake.primal_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
true
julia> Mooncake.primal_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
true
Mooncake.pullback_ir
— Method
pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)
Produce the IR associated to the OpaqueClosure which runs most of the pullback.
Mooncake.pullback_type
— Method
pullback_type(Trule, arg_types)
Get a bound on the pullback type, given a rule and associated primal types.
Mooncake.randn_tangent
— Method
randn_tangent(rng::AbstractRNG, x::T) where {T}
Required for testing. Generate a randomly-chosen tangent to x. The design is closely modelled after zero_tangent.
Mooncake.rdata
— Method
rdata(t)::rdata_type(typeof(t))
Extract the reverse data from tangent t.
Extended help
See extended help section of fdata_type.
Mooncake.rdata_backing_type
— Method
rdata_backing_type(::Type{P}) where {P}
The type of the field of RData for P.
Mooncake.rdata_field_type
— Method
rdata_field_type(::Type{P}, n::Int) where {P}
Returns the type of the nth field of the rdata type associated to P. Will be a PossiblyUninitTangent if said field can be undefined.
Mooncake.rdata_field_types_exprs
— Method
rdata_field_types_exprs(::Type{P}) where {P}
Tuple of expressions. The nth expression computes the rdata backing type of the nth field of P.
Mooncake.rdata_type
— Method
rdata_type(T)
Returns the type of the reverse data of a tangent of type T.
Extended help
See the extended help in the fdata_type docstring.
Mooncake.remove_edge!
— Method
remove_edge!(ir::IRCode, from::Int, to::Int)
Removes the edge in ir from block from to block to. See implementation for what this entails.
Note: this is slightly different from Core.Compiler.kill_edge!, in that it also updates PhiNodes in the to block. Moreover, the available methods of remove_edge! differ between 1.10 and 1.11, so we need something which is stable across both.
Mooncake.replace_captures
— Method
replace_captures(mc::Tmc, new_captures) where {Tmc<:MistyClosure}
Same as replace_captures for Core.OpaqueClosures, but returns a new MistyClosure.
Mooncake.replace_captures
— Method
replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure}
Given an OpaqueClosure oc, create a new OpaqueClosure of the same type, but with new captured variables. This is needed for efficiency reasons: if build_rrule is called repeatedly with the same signature and interpreter, it is important to avoid recompiling the OpaqueClosures that it produces multiple times, because it can be quite expensive to do so.
Mooncake.replace_uses_with!
— Method
replace_uses_with!(stmt, def::Union{Argument, SSAValue}, val)
Replace all uses of def with val in the single statement stmt. Note: this function is highly incomplete, really only working correctly for a specific function in ir_normalisation.jl. You probably do not want to use it.
Mooncake.reverse_data_ref_stmts
— Method
reverse_data_ref_stmts(info::ADInfo)
Create the :new statements which initialise the reverse-data Refs. Interpolates the initial rdata directly into the statement, which is safe because it is always a bits type.
Mooncake.rrule!!
— Function
rrule!!(f::CoDual, x::CoDual...)
Performs the forwards-pass of AD. The tangent field of f and each x should contain the forwards tangent data (fdata) associated to each corresponding primal field.
Returns a 2-tuple. The first element, y, is a CoDual whose primal field is the value associated to running f.primal(map(x -> x.primal, x)...), and whose tangent field is its associated fdata. The second element contains the pullback, which runs the reverse-pass. It maps from the rdata associated to y to the rdata associated to f and each x.
using Mooncake: zero_fcodual, CoDual, NoFData, rrule!!
y, pb!! = rrule!!(zero_fcodual(sin), CoDual(5.0, NoFData()))
pb!!(1.0)
# output
(NoRData(), 0.28366218546322625)
Mooncake.rrule_wrapper
— Method
rrule_wrapper(f::CoDual, args::CoDual...)
Used to implement rrule!!s via ChainRulesCore.rrule.
Given a function foo, argument types arg_types, and a method of ChainRulesCore.rrule which applies to these, you can make use of this function as follows:
Mooncake.@is_primitive DefaultCtx Tuple{typeof(foo), arg_types...}
function Mooncake.rrule!!(f::CoDual{typeof(foo)}, args::CoDual...)
    return rrule_wrapper(f, args...)
end
Assumes that methods of to_cr_tangent and to_mooncake_tangent are defined such that you can convert between the different representations of tangents that Mooncake and ChainRulesCore expect.
Furthermore, it is essential that
- f(args) does not mutate f or args, and
- the result of f(args) does not alias any data stored in f or args.
Subject to some constraints, you can use the @from_rrule macro to reduce the amount of boilerplate code that you are required to write even further.
Mooncake.rule_type
— Method
rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C}
Compute the concrete type of the rule that will be returned from build_rrule. This is important for performance in dynamic dispatch, and to ensure that recursion works properly.
Mooncake.rvs_ir
— Method
rvs_ir(
    sig::Type{<:Tuple};
    interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true
)::IRCode
Warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.
Generate the Core.Compiler.IRCode used to construct the reverse-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.
For example, if you wanted to get the IR associated to the reverse pass for the call map(sin, randn(10)), you could do either of the following:
julia> Mooncake.rvs_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
true
julia> Mooncake.rvs_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
true
Arguments
sig::Type{<:Tuple}: the signature of the call to be differentiated.
Keyword Arguments
interp: the interpreter to use to obtain the primal IR.
debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
do_inline::Bool: whether to apply an inlining pass prior to returning the IR generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
Mooncake.rvs_phi_block
— Method
rvs_phi_block(pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo)
Produces a BBlock which runs the reverse-pass for the edge associated to pred_id in a collection of IDPhiNodes, and then goes to the block associated to pred_id.
For example, suppose that we encounter the following collection of PhiNodes at the start of some block:
%6 = φ (#2 => _1, #3 => %5)
%7 = φ (#2 => 5., #3 => _2)
Let the rdata refs associated to %6, %7, and _1 be denoted r%6, r%7, and r_1 respectively, let pred_id be #2, and let increment_ref! be the following function,
increment_ref!(ref, x) = ref[] = increment!!(ref[], x)
then this rvs_phi_block will produce a basic block of the form
increment_ref!(r_1, r%6)
nothing
goto #2
The call to increment_ref! appears because _1 is the value associated to %6 when the primal code comes from #2. Similarly, the goto #2 statement appears because we came from #2 on the forwards-pass. There is no increment_ref! associated to %7 because 5. is a constant. We emit a nothing statement, which the compiler will happily optimise away later on.
The same ideas apply if pred_id were #3. The block would end with goto #3, and there would be two increment_ref! calls because neither %5 nor _2 is a constant.
In practice, code which is equivalent to increment_ref! is created directly, rather than inserting a call to a generic Julia function. This is because we need to be certain that the getfield and setfield! calls applied to any references are visible to the SROA optimisation pass. If we insert a call to a function like increment_ref!, it might not be inlined away, making such references opaque.
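The effect of the two possible blocks can be simulated at run time with plain Refs. This is a hypothetical analogue, not the code Mooncake generates: Float64 addition stands in for increment!!, the refs r6, r7, r_1, r5 and r_2 hold the rdata for %6, %7, _1, %5 and _2, and the function returns the predecessor ID in place of the goto.

```julia
# Hypothetical runtime simulation of rvs_phi_block; not Mooncake's generated code.
increment_ref!(ref, x) = (ref[] = ref[] + x)  # `+` stands in for increment!!

r6, r7 = Ref(1.0), Ref(2.0)                   # incoming rdata for %6 and %7
r_1, r5, r_2 = Ref(0.0), Ref(0.0), Ref(0.0)   # rdata for _1, %5 and _2

function toy_rvs_phi_block(pred_id::Int)
    if pred_id == 2
        increment_ref!(r_1, r6[])  # %6 came from _1; %7 came from the constant 5.
    elseif pred_id == 3
        increment_ref!(r5, r6[])   # %6 came from %5
        increment_ref!(r_2, r7[])  # %7 came from _2
    end
    return pred_id                 # stands in for `goto #pred_id`
end
```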
Mooncake.set_tangent_field!
— Method
set_tangent_field!(t::MutableTangent{Tfields}, i::Int, x) where {Tfields}
Sets the value of the ith field of the data in t to value x.
Has the same semantics that setfield! would have if the data in the fields field of t were actually fields of t. This is the moral equivalent of setfield! for MutableTangent.
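The "as if they were fields" semantics of get_tangent_field and set_tangent_field! can be sketched with a small mutable wrapper around a NamedTuple. ToyMutableTangent and its helper functions are hypothetical and are not Mooncake's MutableTangent; the setter below assumes the new value has the same type as the old one, so that the NamedTuple type is unchanged.

```julia
# Hypothetical stand-in for MutableTangent; not Mooncake's type. The tangent's
# data lives in a NamedTuple stored in the `fields` field.
mutable struct ToyMutableTangent{T<:NamedTuple}
    fields::T
end

# Index into `fields` as if its entries were fields of the tangent itself.
toy_get_tangent_field(t::ToyMutableTangent, i::Int) = getfield(t.fields, i)

function toy_set_tangent_field!(t::ToyMutableTangent, i::Int, x)
    k = keys(t.fields)[i]
    # Rebuild the NamedTuple with entry `k` replaced (same value type assumed).
    t.fields = merge(t.fields, NamedTuple{(k,)}((x,)))
    return x  # like setfield!, return the value written
end

t = ToyMutableTangent((x = 1.0, y = [2.0]))
toy_set_tangent_field!(t, 1, 5.0)
```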
Mooncake.set_to_zero!!
— Method
set_to_zero!!(x)
Set x to its zero element (x should be a tangent, so the zero must exist).
Mooncake.shared_data_stmts
— Method
shared_data_stmts(p::SharedDataPairs)::Vector{IDInstPair}
Produce a sequence of id-statement pairs which will extract the data from shared_data_tuple(p) such that the correct value is associated to the correct ID.
For example, if p.pairs is
[(ID(5), 5.0), (ID(3), "hello")]
then the output of this function is
IDInstPair[
(ID(5), new_inst(:(getfield(_1, 1)))),
(ID(3), new_inst(:(getfield(_1, 2)))),
]
Mooncake.shared_data_tuple
— Method
shared_data_tuple(p::SharedDataPairs)::Tuple
Create the tuple that will constitute the captured variables in the forwards- and reverse-pass OpaqueClosures.
For example, if p.pairs is
[(ID(5), 5.0), (ID(3), "hello")]
then the output of this function is
(5.0, "hello")
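The two functions work as a pair: one packs the values into the captures tuple, the other emits one getfield per ID to unpack them by position. The sketch below reproduces the two examples above using plain Ints in place of Mooncake's ID type and bare Exprs in place of new_inst; toy_shared_data_tuple, toy_shared_data_stmts and id_value_pairs are hypothetical names.

```julia
# Hypothetical sketch; Ints stand in for Mooncake's ID type.
id_value_pairs = [(5, 5.0), (3, "hello")]

# The captured values, in the order of the pairs.
toy_shared_data_tuple(ps) = Tuple(map(last, ps))

# For each ID, the expression which extracts its value from the captures
# (referred to as `_1` inside the closure's IR).
toy_shared_data_stmts(ps) = [(id, :(getfield(_1, $n))) for (n, (id, _)) in enumerate(ps)]

captures = toy_shared_data_tuple(id_value_pairs)
stmts = toy_shared_data_stmts(id_value_pairs)
```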
Mooncake.sparam_names
— Method
sparam_names(m::Core.Method)::Vector{Symbol}
Returns the names of all of the static parameters in m.
Mooncake.splatnew_to_call
— Method
splatnew_to_call(x)
If instruction x is a :splatnew expression, replace it with a :call to Mooncake._splat_new_. Otherwise, return x.
The purpose of this transformation is to make it possible to differentiate :splatnew expressions in the same way as a primitive :call expression, i.e. via an rrule!!.
Mooncake.stable_all
— Method
stable_all(x::NTuple{N, Bool}) where {N}
all(x::NTuple{N, Bool}) does not constant-fold nicely on 1.10 if the values of x are known statically. This implementation constant-folds nicely on both 1.10 and 1.11, so can be used in its place in situations where this is important.
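One common way to write such a function is to recurse on the tuple structure, so that the compiler sees a fixed chain of operations for each tuple length. The sketch below shows that technique under a hypothetical name; it is not necessarily how Mooncake's stable_all is written, and the tests only check that it computes all correctly, not that it constant-folds.

```julia
# Hypothetical recursive `all` over a tuple of Bools; not Mooncake's code.
# Recursing on the tuple type gives one method body per tuple length, which
# the compiler can unroll.
toy_stable_all(::Tuple{}) = true
toy_stable_all(x::Tuple{Bool, Vararg{Bool}}) = x[1] & toy_stable_all(Base.tail(x))
```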
Mooncake.stmt
— Method
stmt(ir::CC.InstructionStream)
Get the field containing the instructions in ir. This changed name in 1.11 from inst to stmt.
Mooncake.tangent
— Method
tangent(f, r)
Reconstruct the tangent t for which fdata(t) == f and rdata(t) == r.
Mooncake.tangent_test_cases
— Method
tangent_test_cases()
Constructs a Vector of Tuples containing test cases for the tangent infrastructure.
If the returned tuple has 2 elements, the elements should be interpreted as follows:
1 - interface_only
2 - primal value
interface_only is a Bool which will be used to determine which subset of tests to run.
If the returned tuple has 5 elements, then the elements are interpreted as follows:
1 - interface_only
2 - primal value
3, 4, 5 - tangents, where <5> == increment!!(<3>, <4>).
Test cases in the first format make use of zero_tangent / randn_tangent etc. to generate tangents, but they're unable to check that increment!! is correct in an absolute sense.
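The two formats can be told apart by length alone. The sketch below shows how a test harness might consume them; check_increment_sketch and increment_sketch are hypothetical, and the sketch assumes Float64 tangents, for which increment!! amounts to addition.

```julia
# Hypothetical stand-in for increment!! on Float64 tangents.
increment_sketch(a::Float64, b::Float64) = a + b

function check_increment_sketch(case::Tuple)
    # 2-element format: (interface_only, primal). No reference tangents are
    # supplied, so increment!! cannot be checked in an absolute sense.
    length(case) == 2 && return nothing
    length(case) == 5 || throw(ArgumentError("expected a 2- or 5-element test case"))
    # 5-element format: (interface_only, primal, t3, t4, t5), t5 == t3 + t4.
    _, _, t3, t4, t5 = case
    return increment_sketch(t3, t4) == t5
end
```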
Mooncake.tangent_type
— Method
tangent_type(P)
There must be a single type used to represent tangents of primals of type P, and it must be given by tangent_type(P).
Warning: this function assumes the effects :removable and :consistent. This is necessary to ensure good performance, but imposes precise constraints on your implementation. If adding new methods to tangent_type, you should consult the extended help of Base.@assume_effects to see what this imposes upon your implementation.
Extended help
The tangent types which Mooncake.jl uses are quite similar in spirit to those of ChainRules.jl. For example, tangent "vectors" for Float64s are Float64s, for Vector{Float64}s are Vector{Float64}s, and for structs are another (special) struct with field types specified recursively.
There are, however, some major differences. Firstly, while it is certainly true that the above tangent types are permissible in ChainRules.jl, they are not the only permissible types. For example, ZeroTangent is also a permissible type of tangent for any of them, and Float32 is permissible for Float64. This is a general theme in ChainRules.jl: it intentionally declines to place restrictions on what type can be used to represent the tangent of a given type.
Mooncake.jl differs from this. It insists that each primal type is associated to a single tangent type. Furthermore, this type is always given by the function Mooncake.tangent_type(primal_type).
Consider some more worked examples.
Int
Int is not a differentiable type, so its tangent type is NoTangent:
julia> tangent_type(Int)
NoTangent
Tuples
The tangent type of a Tuple is defined recursively based on its field types. For example
julia> tangent_type(Tuple{Float64, Vector{Float64}, Int})
Tuple{Float64, Vector{Float64}, NoTangent}
There is one edge case to be aware of: if all of the fields of a Tuple are non-differentiable, then the tangent type is NoTangent. For example,
julia> tangent_type(Tuple{Int, Int})
NoTangent
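The recursive rule for tuples, including the all-non-differentiable edge case, can be mimicked in a few lines of plain Julia. This toy covers only Float64, Vector{Float64} and Int leaves, and uses a hypothetical ToyNoTangent in place of Mooncake's NoTangent; it is an illustration, not Mooncake's implementation.

```julia
# Hypothetical stand-in for NoTangent.
struct ToyNoTangent end

# Leaf cases, mirroring the worked examples above.
toy_tangent_type(::Type{Float64}) = Float64
toy_tangent_type(::Type{Vector{Float64}}) = Vector{Float64}
toy_tangent_type(::Type{Int}) = ToyNoTangent

# Tuples: recurse into the field types, and collapse to ToyNoTangent if
# every field is non-differentiable.
function toy_tangent_type(::Type{P}) where {P<:Tuple}
    field_tangents = map(toy_tangent_type, fieldtypes(P))
    all(==(ToyNoTangent), field_tangents) && return ToyNoTangent
    return Tuple{field_tangents...}
end

toy_tangent_type(Tuple{Float64, Vector{Float64}, Int})  # Tuple{Float64, Vector{Float64}, ToyNoTangent}
toy_tangent_type(Tuple{Int, Int})                       # ToyNoTangent
```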
Structs
As with Tuples, the tangent type of a struct is, by default, given recursively. In particular, the tangent type of a struct type is Tangent. This type contains a NamedTuple containing the tangent to each field in the primal struct.
As with Tuples, if all field types are non-differentiable, the tangent type of the entire struct is NoTangent.
There are a couple of additional subtleties to consider compared with Tuples, though. Firstly, not all fields of a struct have to be defined. Fortunately, Julia makes it easy to determine how many of the fields might possibly not be defined. The tangent associated to any field which might possibly not be defined is wrapped in a PossiblyUninitTangent.
Furthermore, structs can have fields whose static type is abstract. For example
julia> struct Foo
           x
       end
If you ask for the tangent type of Foo, you will see that it is
julia> tangent_type(Foo)
Tangent{@NamedTuple{x}}
Observe that the field type associated to x is Any. The way to understand this result is to observe that
- x could have literally any type at runtime, so we know nothing about what its tangent type must be until runtime, and
- we require that the tangent type of Foo be unique.
The consequence of these two considerations is that the tangent type of Foo must be able to contain any type of tangent in its x field. It follows that the field type of the x field of Foo's tangent must be Any.
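The premise of this argument can be checked in plain Julia. This sketch uses a hypothetical FooSketch mirroring Foo: its declared field type is Any, and two instances can hold values whose tangent types (Float64 and Vector{Float64}, per the worked examples above) differ, so only Any can accommodate both.

```julia
struct FooSketch
    x  # no annotation, so the declared field type is Any
end

foo_a = FooSketch(1.0)          # x holds a Float64
foo_b = FooSketch([1.0, 2.0])   # x holds a Vector{Float64}

fieldtype(FooSketch, :x)        # Any
(typeof(foo_a.x), typeof(foo_b.x))  # (Float64, Vector{Float64})
```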
Mutable Structs
The tangent type for mutable structs is subject to the same considerations as for structs. The only difference is that it must itself be mutable. Consequently, we use a type called MutableTangent to represent their tangents. It is a mutable struct with the same structure as Tangent.
For example, if you ask for the tangent_type of
julia> mutable struct Bar
           x::Float64
       end
you will find that it is
julia> tangent_type(Bar)
MutableTangent{@NamedTuple{x::Float64}}
Primitive Types
We've already seen a couple of primitive types (Float64 and Int). The basic story here is that all primitive types require an explicit specification of what their tangent type must be.
One interesting case is Ptr types. The tangent type of a Ptr{P} is Ptr{T}, where T = tangent_type(P). For example
julia> tangent_type(Ptr{Float64})
Ptr{Float64}
Mooncake.tangent_type
— Method
tangent_type(F::Type, R::Type)::Type
Given the types of the fdata and rdata (F and R, respectively) for some primal type, compute its tangent type. This method must be equivalent to tangent_type(_typeof(primal)).
Mooncake.to_cr_tangent
— Method
to_cr_tangent(t)
Convert a Mooncake tangent into a type that ChainRules.jl rrules expect to see.
Mooncake.tuple_map
— Method
tuple_map(f::F, x::Tuple) where {F}
This function is largely equivalent to map(f, x), but always specialises on all of the element types of x, regardless of the length of x. This contrasts with map, in which the number of element types specialised upon is a fixed constant in the compiler.
As a consequence, if x is very long, this function may have very large compile times.
tuple_map(f::F, x::Tuple, y::Tuple) where {F}
Binary extension of tuple_map. Nearly equivalent to map(f, x, y), but guaranteed to specialise on all element types of x and y. Furthermore, it errors if x and y aren't the same length, while map will just produce a new tuple whose length is equal to the shorter of x and y.
Mooncake.unhandled_feature
— Method
unhandled_feature(msg::String)
Throw an UnhandledLanguageFeatureException with message msg.
Mooncake.uninit_codual
— Method
uninit_codual(x)
Equivalent to CoDual(x, uninit_tangent(x)).
Mooncake.uninit_fcodual
— Method
uninit_fcodual(x)
Like zero_fcodual, but doesn't guarantee that the value of the fdata is initialised. See implementation for details, as this function is subject to change.
Mooncake.uninit_tangent
— Method
uninit_tangent(x)
Related to zero_tangent, but a bit different. Check the current implementation for details; this docstring is intentionally non-specific in order to avoid becoming outdated.
Mooncake.verify_fdata_type
— Method
verify_fdata_type(P::Type, F::Type)::Nothing
Check that F is a valid type for fdata associated to a primal of type P. Returns nothing if valid, and throws an InvalidFDataException if a problem is found.
This applies to both concrete and non-concrete P. For example, if P is the type inferred for a primal q::Q, such that Q <: P, then this method is still applicable.
Mooncake.verify_fdata_value
— Method
verify_fdata_value(p, f)::Nothing
Check that f cannot be proven to be invalid fdata for p.
This method attempts to provide some confidence that f is valid fdata for p by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.
Put differently, we cannot prove that f is valid fdata, only that it is not obviously invalid.
Mooncake.verify_no_constant_gotoifnots
— Method
verify_no_constant_gotoifnots(ir::IRCode)
Verify that we have successfully removed all instances of goto %n if not true and goto %n if not false, as these can be reduced to simpler nodes (namely, GotoNodes or "fallthrough"s). Moreover, removing them tends to yield performance improvements by reducing the amount of information Mooncake must keep in its block stacks.
This is essentially just testing functionality for const_prop_constant_gotoifnots. It is usually run each time a rule is compiled, as it is cheap, and because it is hard to construct a convincing set of test cases which, if passed at test-time, would indicate we were done.
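A goto %n if not true never jumps, and a goto %n if not false always jumps. The hypothetical sketch below encodes that reduction on Core.GotoIfNot nodes (fields cond and dest), returning nothing to stand for a fallthrough. It illustrates the simplification these nodes admit; it is not the implementation of const_prop_constant_gotoifnots.

```julia
# A GotoIfNot whose condition is a literal Bool is "constant".
is_constant_gotoifnot(stmt) = stmt isa Core.GotoIfNot && stmt.cond isa Bool

function simplify_gotoifnot_sketch(stmt::Core.GotoIfNot)
    # `goto dest if not true`: the branch is never taken, so fall through.
    stmt.cond === true && return nothing
    # `goto dest if not false`: the branch is always taken.
    stmt.cond === false && return Core.GotoNode(stmt.dest)
    # Non-constant condition: leave the node alone.
    return stmt
end
```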
Mooncake.verify_rdata_type
— Method
verify_rdata_type(P::Type, R::Type)::Nothing
Check that R is a valid type for rdata associated to a primal of type P. Returns nothing if valid, and throws an InvalidRDataException if a problem is found.
This applies to both concrete and non-concrete P. For example, if P is the type inferred for a primal q::Q, such that Q <: P, then this method is still applicable.
Mooncake.verify_rdata_value
— Method
verify_rdata_value(p, r)::Nothing
Check that r cannot be proven to be invalid rdata for p.
This method attempts to provide some confidence that r is valid rdata for p by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.
Put differently, we cannot prove that r is valid rdata, only that it is not obviously invalid.
Mooncake.zero_adjoint
— Method
zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}
Utility functionality for constructing rrule!!s for functions which produce adjoints which always return zero.
NOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.
You make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:
julia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual, NoRData
julia> foo(x::Vararg{Int}) = 5
foo (generic function with 1 method)
julia> is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(foo), Vararg{Int}}}) = true;
julia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...);
julia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData())
(NoRData(), NoRData(), NoRData())
WARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not.
Mooncake.zero_codual
— Method
zero_codual(x)
Equivalent to CoDual(x, zero_tangent(x)).
Mooncake.zero_like_rdata_from_type
— Method
zero_like_rdata_from_type(::Type{P}) where {P}
This is an internal implementation detail; you should generally not use this function.
Returns either the zero element of type rdata_type(tangent_type(P)), or a ZeroRData. It is always valid to return a ZeroRData.
Mooncake.zero_like_rdata_type
— Method
zero_like_rdata_type(::Type{P}) where {P}
Indicates the type which will be returned by zero_like_rdata_from_type. This will be the rdata type R for P if we can produce the zero rdata element given only P, and will be the union of R and ZeroRData if an instance of P is needed.
Mooncake.zero_rdata
— Method
zero_rdata(p)
Given value p, return the zero element associated to its reverse data type.
Mooncake.zero_rdata_from_type
— Method
zero_rdata_from_type(::Type{P}) where {P}
Returns the zero element of rdata_type(tangent_type(P)) if this is possible given only P. If not possible, returns an instance of CannotProduceZeroRDataFromType.
For example, the zero rdata associated to any primal of type Float64 is 0.0, so for Float64s this function is simple. Similarly, if the rdata type for P is NoRData, that can simply be returned.
However, it is not possible to return the zero rdata element for abstract types, e.g. Real, as the type does not uniquely determine the zero element: the rdata type for Real is Any.
These considerations apply recursively to tuples / namedtuples / structs, etc.
If you encounter a type for which this function returns CannotProduceZeroRDataFromType, but you believe this is done in error, please open an issue. This kind of problem does not constitute a correctness problem, but can be detrimental to performance, so should be dealt with.
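To make the recursive structure concrete, here is a toy version covering only IEEE floats and tuples. CannotProduceZeroSketch and zero_rdata_sketch are hypothetical stand-ins; unlike the real function, this sketch does not handle NoRData, structs or namedtuples.

```julia
# Hypothetical sentinel mirroring CannotProduceZeroRDataFromType.
struct CannotProduceZeroSketch end

# Fallback: assume the zero element cannot be determined from the type alone.
zero_rdata_sketch(::Type) = CannotProduceZeroSketch()

# Concrete IEEE floats: the zero element is fully determined by the type.
function zero_rdata_sketch(::Type{P}) where {P<:Base.IEEEFloat}
    return isconcretetype(P) ? zero(P) : CannotProduceZeroSketch()
end

# Tuples: recurse into the field types, and give up if any field gives up.
function zero_rdata_sketch(::Type{P}) where {P<:Tuple}
    isconcretetype(P) || return CannotProduceZeroSketch()
    zs = map(zero_rdata_sketch, fieldtypes(P))
    any(z -> z isa CannotProduceZeroSketch, zs) && return CannotProduceZeroSketch()
    return zs
end
```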
Mooncake.zero_tangent
— Method
zero_tangent(primal, fdata)
Equivalent to tangent(fdata, rdata(zero_tangent(primal))).
Mooncake.zero_tangent
— Method
zero_tangent(x)
Returns the unique zero element of the tangent space of x. It is an error for the zero element of the tangent space of x to be represented by anything other than that which this function returns.
Internally, zero_tangent calls zero_tangent_internal, which handles different types of inputs differently. zero_tangent_internal has two variants:
- For isbitstype types, zero_tangent_internal takes one argument.
- Otherwise, zero_tangent_internal takes an additional argument, an IdDict, which handles both circular references and aliasing correctly.
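The role of the IdDict can be illustrated with a toy version that only understands Float64 and Vector{Any}. All names are hypothetical; the point is that registering each object's tangent in an IdDict before recursing makes aliased inputs share one tangent and lets self-referential inputs terminate.

```julia
# Bits-type variant: one argument suffices, since there is nothing to alias.
toy_zero_tangent(x::Float64) = 0.0

# Non-bits entry point: create the IdDict shared across the whole traversal.
toy_zero_tangent(x::Vector{Any}) = toy_zero_tangent_internal(x, IdDict{Any,Any}())

toy_zero_tangent_internal(x::Float64, ::IdDict) = 0.0
function toy_zero_tangent_internal(x::Vector{Any}, seen::IdDict)
    # Aliasing / cycles: reuse the tangent already built for this exact object.
    haskey(seen, x) && return seen[x]
    t = Vector{Any}(undef, length(x))
    seen[x] = t  # register before recursing so that cycles terminate
    for i in eachindex(x)
        t[i] = toy_zero_tangent_internal(x[i], seen)
    end
    return t
end
```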
Mooncake.@foldable
— Macro
macro foldable def
Shorthand for Base.@assume_effects :foldable function f(x)...
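Assuming the macro simply forwards its argument, an equivalent can be sketched in a few lines. foldable_sketch and square_sketch are hypothetical names; the sketch relies only on Base.@assume_effects (available from Julia 1.8) and is not Mooncake's actual source.

```julia
macro foldable_sketch(def)
    # Forward the definition unchanged to Base.@assume_effects with :foldable.
    return esc(:(Base.@assume_effects :foldable $def))
end

@foldable_sketch function square_sketch(x::Int)
    return x * x
end
```

Asserting :foldable is a promise to the compiler, so the caveats in the extended help of Base.@assume_effects apply to any function defined this way.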
Mooncake.@from_rrule
— Macro
@from_rrule ctx sig [has_kwargs=false]
Convenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.
Arguments
- ctx: a Mooncake context type.
- sig: the signature which you wish to assert should be a primitive in Mooncake.jl, and which should use an existing ChainRulesCore.rrule to implement this functionality.
- has_kwargs: a Bool stating whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule: the derivative w.r.t. all kwargs must be zero.
Example Usage
A Basic Example
julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils
julia> using Random: Xoshiro
julia> import ChainRulesCore
julia> foo(x::Real) = 5x;
julia> function ChainRulesCore.rrule(::typeof(foo), x::Real)
           foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), 5Ω
           return foo(x), foo_pb
       end;
julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat}
julia> rrule!!(zero_fcodual(foo), zero_fcodual(5.0))[2](1.0)
(NoRData(), 5.0)
julia> # Check that the rule works as intended.
       TestUtils.test_rule(Xoshiro(123), foo, 5.0; is_primitive=true)
Test Passed
An Example with Keyword Arguments
julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils
julia> using Random: Xoshiro
julia> import ChainRulesCore
julia> foo(x::Real; cond::Bool) = cond ? 5x : 4x;
julia> function ChainRulesCore.rrule(::typeof(foo), x::Real; cond::Bool)
           foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), cond ? 5Ω : 4Ω
           return foo(x; cond), foo_pb
       end;
julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} true
julia> _, pb = rrule!!(
           zero_fcodual(Core.kwcall),
           zero_fcodual((cond=false, )),
           zero_fcodual(foo),
           zero_fcodual(5.0),
       );
julia> pb(3.0)
(NoRData(), NoRData(), NoRData(), 12.0)
julia> # Check that the rule works as intended.
       TestUtils.test_rule(
           Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true
       )
Test Passed
Notice that, in order to access the kwarg method, we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.
Limitations
It is your responsibility to ensure that
- calls with signature sig do not mutate their arguments, and
- the output of calls with signature sig does not alias any of the inputs.
As with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.
Argument Type Constraints
Many methods of ChainRulesCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature
Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}
There are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.
Suffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.
Conversions Between Different Tangent Type Systems
Under the hood, this functionality relies on two functions: Mooncake.to_cr_tangent and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.
Mooncake.@is_primitive
— Macro
@is_primitive context_type signature
Creates a method of is_primitive which always returns true for the context_type and signature provided. For example
@is_primitive MinimalCtx Tuple{typeof(foo), Float64}
is equivalent to
is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(foo), Float64}}) = true
You should implement more complicated methods of is_primitive in the usual way.
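The equivalence above suggests how such a macro can work: splice the context type and signature into a one-line method definition. The sketch below is self-contained, so it uses a hypothetical is_primitive_sketch function and MyCtxSketch context rather than Mooncake's is_primitive and MinimalCtx; it is an illustration, not Mooncake's source.

```julia
# Fallback: nothing is primitive unless a more specific method says so.
is_primitive_sketch(::Type, ::Type) = false

macro is_primitive_sketch(ctx, sig)
    # Expands to: is_primitive_sketch(::Type{ctx}, ::Type{<:sig}) = true
    return esc(:(is_primitive_sketch(::Type{$ctx}, ::Type{<:$sig}) = true))
end

struct MyCtxSketch end
foo_sketch(x::Float64) = 2x

@is_primitive_sketch MyCtxSketch Tuple{typeof(foo_sketch), Float64}

is_primitive_sketch(MyCtxSketch, Tuple{typeof(foo_sketch), Float64})  # true
is_primitive_sketch(MyCtxSketch, Tuple{typeof(foo_sketch), Int})      # false
```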
Mooncake.@mooncake_overlay
— Macro
@mooncake_overlay method_expr
Define a method of a function which only Mooncake can see. This can be used to write versions of methods which can be successfully differentiated by Mooncake if the original cannot be.
For example, suppose that you have a function
julia> foo(x::Float64) = bar(x)
foo (generic function with 1 method)
where Mooncake.jl fails to differentiate bar for some reason. If you have access to another function baz, which does the same thing as bar, but does so in a way which Mooncake.jl can differentiate, you can simply write:
julia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x)
When looking up the code for foo(::Float64), Mooncake.jl will see this method, rather than the original, and differentiate it instead.
A Worked Example
To demonstrate how to use @mooncake_overlays in practice, we here show how the answer that Mooncake.jl gives changes if you change the definition of a function using a @mooncake_overlay. Do not do this in practice; this is just a simple way to demonstrate how to use overlays!
First, consider a simple example:
julia> scale(x) = 2x
scale (generic function with 1 method)
julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});
julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(10.0, (NoTangent(), 2.0))
We can use @mooncake_overlay to change the definition which Mooncake.jl sees:
julia> Mooncake.@mooncake_overlay scale(x) = 3x
julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});
julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(15.0, (NoTangent(), 3.0))
As can be seen from the output, the result of differentiating using Mooncake.jl has changed to reflect the overlay-ed definition of the method.
Additionally, it is possible to use the usual multi-line syntax to declare an overlay:
julia> Mooncake.@mooncake_overlay function scale(x)
           return 4x
       end
julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});
julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(20.0, (NoTangent(), 4.0))
Mooncake.@zero_adjoint
— Macro
@zero_adjoint ctx sig
Defines is_primitive(context_type, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality; it is morally the same as ChainRulesCore.@non_differentiable.
For example:
julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, NoRData
julia> foo(x) = 5
foo (generic function with 1 method)
julia> @zero_adjoint DefaultCtx Tuple{typeof(foo), Any}
julia> is_primitive(DefaultCtx, Tuple{typeof(foo), Any})
true
julia> rrule!!(zero_fcodual(foo), zero_fcodual(3.0))[2](NoRData())
(NoRData(), 0.0)
Limited support for Varargs is also available. For example
julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, NoRData
julia> foo_varargs(x...) = 5
foo_varargs (generic function with 1 method)
julia> @zero_adjoint DefaultCtx Tuple{typeof(foo_varargs), Vararg}
julia> is_primitive(DefaultCtx, Tuple{typeof(foo_varargs), Any, Float64, Int})
true
julia> rrule!!(zero_fcodual(foo_varargs), zero_fcodual(3.0), zero_fcodual(5))[2](NoRData())
(NoRData(), 0.0, NoRData())
Be aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.
WARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.
As always, you should use TestUtils.test_rule to ensure that you've not made a mistake.
Signatures Unsupported By This Macro
If the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.
Mooncake.IntrinsicsWrappers
— Module
module IntrinsicsWrappers
The purpose of this module is to associate to each function in Core.Intrinsics a regular Julia function.
To understand the rationale for this, observe that, unlike regular Julia functions, the functions in Core.Intrinsics do not each have their own type. Rather, they are all instances of Core.IntrinsicFunction. To see this, observe that
julia> typeof(Core.Intrinsics.add_float)
Core.IntrinsicFunction
julia> typeof(Core.Intrinsics.sub_float)
Core.IntrinsicFunction
While we could simply write a rule for Core.IntrinsicFunction, this would (naively) lead to a large list of conditionals of the form
if f === Core.Intrinsics.add_float
    # return add_float and its pullback
elseif f === Core.Intrinsics.sub_float
    # return sub_float and its pullback
elseif ...
    ...
end
which has the potential to cause quite substantial type instabilities. (This might not be true anymore; see Extended Help for more context.)
Instead, we map each Core.IntrinsicFunction to one of the regular Julia functions in Mooncake.IntrinsicsWrappers, to which we can dispatch in the usual way.
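A minimal sketch of the idea, assuming just two intrinsics: each wrapper is an ordinary function and therefore has its own type, and a Val-based lookup (translate_sketch, a hypothetical name) maps an intrinsic value to its wrapper. Mooncake's actual module covers many more intrinsics and may organise the mapping differently.

```julia
module IntrinsicsWrappersSketch
# One ordinary function (and hence one distinct type) per wrapped intrinsic.
add_float(x, y) = Core.Intrinsics.add_float(x, y)
sub_float(x, y) = Core.Intrinsics.sub_float(x, y)
end

# Lift each intrinsic into the type domain via Val, then dispatch on it.
translate_sketch(::Val{Core.Intrinsics.add_float}) = IntrinsicsWrappersSketch.add_float
translate_sketch(::Val{Core.Intrinsics.sub_float}) = IntrinsicsWrappersSketch.sub_float

translate_sketch(Val(Core.Intrinsics.sub_float))(3.0, 1.0)  # 2.0
```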
Extended Help
It is possible that, owing to improvements in constant propagation in the Julia compiler in version 1.10, we could actually get away with writing just a single method of rrule!! to handle all intrinsics, so this dispatch-based mechanism might be unnecessary. Someone should investigate this. This is discussed at https://github.com/compintell/Mooncake.jl/issues/387.