API walkthrough
The function derivative_estimate transforms a stochastic program containing discrete randomness into a new program whose average is the derivative of the original.
StochasticAD.derivative_estimate — Function
derivative_estimate(X, p, alg::AbstractStochasticADAlgorithm = ForwardAlgorithm(PrunedFIsBackend()); direction=nothing, alg_data::NamedTuple = (;))
Compute an unbiased estimate of $\frac{\mathrm{d}\mathbb{E}[X(p)]}{\mathrm{d}p}$, the derivative of the expectation of the random function X(p) with respect to its input p.
Both p and X(p) can be any object supported by Functors.jl, e.g. scalars or abstract arrays. The output of derivative_estimate has the same outer structure as p, but with each scalar in p replaced by a derivative estimate of X(p) with respect to that entry. For example, if X(p) <: AbstractMatrix and p <: Real, then the output would be a matrix.
The alg argument specifies the algorithm used to compute the derivative estimate. For backward compatibility, an additional signature derivative_estimate(X, p; backend, direction=nothing) is supported, which uses ForwardAlgorithm by default with the supplied backend. The alg_data keyword argument can specify any additional data that specific algorithms accept or require.
When direction is provided, the output is only differentiated with respect to a perturbation of p in that direction.
Example
julia> using Distributions, Random, StochasticAD; Random.seed!(4321);
julia> derivative_estimate(rand ∘ Bernoulli, 0.5) # A random quantity that averages to the true derivative.
2.0
julia> derivative_estimate(x -> [rand(Bernoulli(x * i/4)) for i in 1:3], 0.5)
3-element Vector{Float64}:
0.2857142857142857
0.6666666666666666
0.0
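Because each call is a single random sample, one estimate is noisy; averaging many independent estimates approaches the true derivative. A minimal sketch (the sample count is arbitrary; the analytic derivative of E[Bernoulli(p)] with respect to p is 1):
using Distributions, Statistics, StochasticAD
# Average many unbiased single-sample estimates; the mean converges to 1.0.
mean(derivative_estimate(rand ∘ Bernoulli, 0.5) for _ in 1:10_000)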
While derivative_estimate is self-contained, we can also use the functions below to work with stochastic triples directly.
StochasticAD.stochastic_triple — Function
stochastic_triple(X, p; backend=PrunedFIsBackend(), direction=nothing)
stochastic_triple(p; backend=PrunedFIsBackend(), direction=nothing)
For any p that is supported by Functors.jl, e.g. scalars or abstract arrays, differentiate the output with respect to each value of p, returning an output of similar structure to p, where a particular value contains the stochastic-triple output of X when perturbing the corresponding value in p (i.e. replacing the original value x with x + ε).
When direction is provided, return only the stochastic-triple output of X with respect to a perturbation of p in that particular direction. When X is not provided, the identity function is used.
The backend keyword argument describes the algorithm used by the third component of the stochastic triple; see the technical details section for more information.
Example
julia> using Distributions, Random, StochasticAD; Random.seed!(4321);
julia> stochastic_triple(rand ∘ Bernoulli, 0.5)
StochasticTriple of Int64:
0 + 0ε + (1 with probability 2.0ε)
StochasticAD.derivative_contribution — Function
derivative_contribution(st::StochasticTriple)
Return the derivative estimate given by combining the dual and triple components of st.
StochasticAD.value — Function
value(st::StochasticTriple)
Return the primal value of st.
StochasticAD.delta — Function
delta(st::StochasticTriple)
Return the almost-sure derivative of st, i.e. the rate of infinitesimal change.
StochasticAD.perturbations — Function
perturbations(st::StochasticTriple)
Return the finite perturbation(s) of st, in a format dependent on the backend used for storing perturbations.
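To see how these accessors fit together, here is a small sketch building on the Bernoulli example above (assuming the accessors are exported as documented; otherwise qualify them with StochasticAD.):
using Distributions, StochasticAD
st = stochastic_triple(rand ∘ Bernoulli, 0.5)
value(st)                   # primal value: 0 or 1
delta(st)                   # almost-sure (infinitesimal) derivative: 0 here, since the output is discrete
perturbations(st)           # finite perturbation(s); format depends on the chosen backend
derivative_contribution(st) # single-sample derivative estimate, averaging to 1 over many runs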
Note that derivative_estimate is simply the composition of stochastic_triple and derivative_contribution. We also provide a convenience function for mimicking the behaviour of standard AD, where derivatives of discrete random steps are dropped:
StochasticAD.dual_number — Function
dual_number(X, p; backend=PrunedFIsBackend(), direction=nothing)
dual_number(p; backend=PrunedFIsBackend(), direction=nothing)
A lightweight wrapper around stochastic_triple that entirely ignores the derivative contribution of all discrete random components, so that it behaves like a regular dual number. Mostly for fun; this, of course, leads to a useless derivative estimate for discrete random functions!
Algorithms
StochasticAD.ForwardAlgorithm — Type
ForwardAlgorithm(backend::StochasticAD.AbstractFIsBackend) <: AbstractStochasticADAlgorithm
A differentiation algorithm relying on forward propagation of stochastic triples.
The backend argument controls the algorithm used by the third component of the stochastic triples.
The required computation time for forward-mode AD scales linearly with the number of parameters in p (but is unaffected by the number of parameters in X(p)).
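For instance, the default behaviour of derivative_estimate can be spelled out by passing the algorithm explicitly; a minimal sketch, assuming ForwardAlgorithm and PrunedFIsBackend are exported as in the signatures above:
using Distributions, StochasticAD
# Equivalent to the default: forward propagation of stochastic triples with the pruned backend.
derivative_estimate(rand ∘ Bernoulli, 0.5, ForwardAlgorithm(PrunedFIsBackend()))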
StochasticAD.EnzymeReverseAlgorithm — Type
EnzymeReverseAlgorithm(backend::StochasticAD.AbstractFIsBackend) <: AbstractStochasticADAlgorithm
A differentiation algorithm relying on transposing the propagation of stochastic triples to produce a reverse-mode algorithm. The transposition is performed by Enzyme.jl, which must be loaded for the algorithm to run.
Currently, only real- and vector-valued inputs and real-valued outputs are supported.
The backend argument controls the algorithm used by the third component of the stochastic triples.
In the call to derivative_estimate, this algorithm optionally accepts alg_data with the field forward_u, which specifies the directional derivative used in the forward pass that will be transposed. If forward_u is not provided, it is randomly generated.
For the reverse-mode algorithm to yield correct results, the employed backend cannot use input-dependent pruning strategies. A suggested reverse-mode compatible backend is PrunedFIsBackend(Val(:wins)).
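As a hedged sketch of what such a call might look like (assuming Enzyme.jl is installed and loaded so the extension is available, and using the suggested backend; the toy function X is only for illustration):
using Distributions, StochasticAD
using Enzyme  # must be loaded for EnzymeReverseAlgorithm to run

# Vector-valued input, real-valued output.
X(p) = sum(rand(Bernoulli(prob)) for prob in p)
# Returns a gradient estimate with the same structure as p.
derivative_estimate(X, [0.3, 0.5, 0.7], EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins))))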
Additionally, this algorithm relies on the ability of Enzyme.jl to differentiate the forward stochastic triple run. It is recommended to check that the primal function X is type stable for its input p, using a tool such as JET.jl, with all code executed in a function with no global state. In addition, sometimes X may be type stable but stochastic triples introduce additional type instabilities. This can be debugged by checking type stability of Enzyme's target, which is Base.get_extension(StochasticAD, :StochasticADEnzymeExt).enzyme_target(u, X, p, backend), where u is a test direction.
For more details on the reverse-mode approach, see the following papers and talks:
- "You Only Linearize Once: Tangents Transpose to Gradients", Radul et al. 2022.
- "Reverse mode ADEV via YOLO: tangent estimators transpose to gradient estimators", Becker et al. 2024
- "Probabilistic Programming with Programmable Variational Inference", Becker et al. 2024
Smoothing
What if we were to run derivative_contribution after each step, instead of only at the end? This is smoothing, which combines the second and third components of a single stochastic triple into a single dual component. Smoothing no longer has a guarantee of unbiasedness, but is surprisingly accurate in a number of situations. For example, the popular straight-through gradient estimator can be viewed as a special case of smoothing. Forward smoothing rules are provided through ForwardDiff, and backward rules through ChainRules, so that e.g. Zygote.gradient and ForwardDiff.derivative will use smoothed rules for discrete random variables rather than dropping their gradients entirely. Currently, special discrete-to-discrete constructs such as array indexing are not supported for smoothing.
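As a minimal sketch of what this looks like in practice (assuming the smoothing rules apply to Bernoulli sampling as described; the returned value is a smoothed, possibly biased estimate rather than the unbiased one from derivative_estimate):
using Distributions, ForwardDiff, StochasticAD
# With StochasticAD loaded, ForwardDiff applies smoothed rules to the discrete
# random step instead of dropping its derivative entirely.
ForwardDiff.derivative(p -> rand(Bernoulli(p)), 0.5)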
Optimization
We also provide utilities to make it easier to get started with forming and training a model via stochastic gradient descent:
StochasticAD.StochasticModel — Type
StochasticModel(X, p)
Combine stochastic program X with parameter p into a trainable model using Functors, where p <: AbstractArray. Formulate as a minimization problem, i.e. find $p$ that minimizes $\mathbb{E}[X(p)]$.
StochasticAD.stochastic_gradient — Function
stochastic_gradient(m::StochasticModel)
Compute gradient with respect to the trainable parameter p of StochasticModel(X, p).
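A minimal sketch of how these utilities might be combined with an optimizer; the use of Optimisers.jl, the toy objective, and the field access m.p are illustrative assumptions rather than part of the API above:
using Distributions, Optimisers, StochasticAD

# Toy objective with an interior minimum: E[(Binomial(10, p[1]) - 3)^2] is minimized near p[1] ≈ 0.28.
X(p) = (rand(Binomial(10, p[1])) - 3)^2

m = StochasticModel(X, [0.5])            # trainable model, p <: AbstractArray
state = Optimisers.setup(Adam(0.01), m)
for _ in 1:100
    # stochastic_gradient returns an unbiased gradient estimate of E[X(p)]
    state, m = Optimisers.update!(state, m, stochastic_gradient(m))
end
m.p  # parameter after a few steps of stochastic gradient descent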
These are used in the tutorial on stochastic optimization.