API walkthrough

The function derivative_estimate transforms a stochastic program containing discrete randomness into a new program whose expectation is the derivative of the original program's expectation.

StochasticAD.derivative_estimate - Function
derivative_estimate(X, p, alg::AbstractStochasticADAlgorithm = ForwardAlgorithm(PrunedFIsBackend()); direction=nothing, alg_data::NamedTuple = (;))

Compute an unbiased estimate of $\frac{\mathrm{d}\mathbb{E}[X(p)]}{\mathrm{d}p}$, the derivative of the expectation of the random function X(p) with respect to its input p.

Both p and X(p) can be any object supported by Functors.jl, e.g. scalars or abstract arrays. The output of derivative_estimate has the same outer structure as p, but with each scalar in p replaced by a derivative estimate of X(p) with respect to that entry. For example, if X(p) <: AbstractMatrix and p <: Real, then the output would be a matrix.

The alg argument specifies the algorithm used to compute the derivative estimate. It is positional and defaults to ForwardAlgorithm(PrunedFIsBackend()). For backward compatibility, an additional signature derivative_estimate(X, p; backend, direction=nothing) is supported, which uses ForwardAlgorithm with the supplied backend. The alg_data keyword argument passes any additional data that a specific algorithm accepts or requires.

When direction is provided, the output is only differentiated with respect to a perturbation of p in that direction.

Example

julia> using Distributions, Random, StochasticAD; Random.seed!(4321);

julia> derivative_estimate(rand ∘ Bernoulli, 0.5) # A random quantity that averages to the true derivative.
2.0

julia> derivative_estimate(x -> [rand(Bernoulli(x * i/4)) for i in 1:3], 0.5)
3-element Vector{Float64}:
 0.2857142857142857
 0.6666666666666666
 0.0
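As a sanity check on the unbiasedness claim, the sketch below averages many independent estimates. It assumes the Statistics standard library and relies on E[X(p)] = p for a Bernoulli(p) draw, so the true derivative is 1. The second part is an illustrative use of the documented direction keyword on a hypothetical two-parameter program Y, which is not part of the package.

```julia
using Distributions, Statistics, StochasticAD

# Bernoulli example from the docstring: E[X(p)] = p, so d/dp E[X(p)] = 1.
X(p) = rand(Bernoulli(p))

# Individual estimates are noisy; their sample mean should settle near 1.
estimates = [derivative_estimate(X, 0.5) for _ in 1:20_000]
avg = mean(estimates)

# Hypothetical two-parameter program (an assumption for illustration):
# E[Y(p)] = p[1] + 3 * p[2].
Y(p) = rand(Bernoulli(p[1])) + 3 * p[2]

# Differentiate only along the first coordinate direction. The directional
# derivative of E[Y(p)] along [1, 0] is 1, so the mean should again be near 1.
directional = [derivative_estimate(Y, [0.5, 2.0]; direction = [1.0, 0.0])
               for _ in 1:20_000]
avg_dir = mean(directional)
```

A single call is rarely informative on its own. In practice the estimate is averaged over samples or fed into a stochastic optimizer.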

While derivative_estimate is self-contained, we can also use the functions below to work with stochastic triples directly.

StochasticAD.stochastic_triple - Function
stochastic_triple(X, p; backend=PrunedFIsBackend(), direction=nothing)
stochastic_triple(p; backend=PrunedFIsBackend(), direction=nothing)

For any p supported by Functors.jl, e.g. scalars or abstract arrays, differentiate the output of X with respect to each value in p. The result has a structure similar to p: each entry holds the stochastic-triple output of X obtained by perturbing the corresponding value in p (i.e. replacing the original value x with x + ε).

When direction is provided, return only the stochastic-triple output of X with respect to a perturbation of p in that particular direction. When X is not provided, the identity function is used.

The backend keyword argument selects the algorithm used by the third component of the stochastic triple; see the technical details for more information.

Example

julia> using Distributions, Random, StochasticAD; Random.seed!(4321);

julia> stochastic_triple(rand ∘ Bernoulli, 0.5)
StochasticTriple of Int64:
0 + 0ε + (1 with probability 2.0ε)
StochasticAD.delta - Function
delta(st::StochasticTriple)

Return the almost-sure derivative of st, i.e. the rate of infinitesimal change.

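The components of a triple can be inspected individually, as in the sketch below. It reuses the Bernoulli program from the example above. StochasticAD.value is assumed here to be the accessor for the primal component; it is not described in this section, so treat that call as an assumption about the package.

```julia
using Distributions, StochasticAD

st = stochastic_triple(rand ∘ Bernoulli, 0.5)

# Primal sample of the program, here 0 or 1.
# (StochasticAD.value is an assumed accessor, see the note above.)
primal = StochasticAD.value(st)

# Almost-sure derivative: a Bernoulli draw is piecewise constant in p,
# so the infinitesimal part is zero.
smooth = StochasticAD.delta(st)

# Fold all components of the triple into a single derivative estimate.
estimate = derivative_contribution(st)
```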

Note that derivative_estimate is simply the composition of stochastic_triple and derivative_contribution. We also provide a convenience function for mimicking the behaviour of standard AD, where derivatives of discrete random steps are dropped:

StochasticAD.dual_number - Function
dual_number(X, p; backend=PrunedFIsBackend(), direction=nothing)
dual_number(p; backend=PrunedFIsBackend(), direction=nothing)

A lightweight wrapper around stochastic_triple that entirely ignores the derivative contribution of all discrete random components, so that it behaves like a regular dual number. Mostly for fun – this, of course, leads to a useless derivative estimate for discrete random functions!

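To see what is lost when the discrete contribution is ignored, consider a hypothetical program with both a smooth and a discrete part. The sketch assumes that the result of dual_number can be read with delta, which follows from it being a wrapper around stochastic_triple. The call is written with the module prefix so that it does not depend on whether the name is exported.

```julia
using Distributions, StochasticAD

# Hypothetical program: smooth part p * p plus a discrete Bernoulli draw.
# E[X(p)] = p^2 + p, so the true derivative is 2p + 1.
X(p) = p * p + rand(Bernoulli(p))

# Full estimate: includes the contribution of the Bernoulli jump.
full = derivative_estimate(X, 0.5)

# dual_number ignores the discrete contribution, so only the derivative of
# the smooth part (2p) is expected to remain in the dual component.
dual = StochasticAD.dual_number(X, 0.5)
smooth_only = StochasticAD.delta(dual)
```

Averaging full over many runs should approach 2p + 1, whereas smooth_only misses the + 1 term entirely. That gap is the bias the docstring warns about.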

Algorithms

StochasticAD.ForwardAlgorithm - Type
ForwardAlgorithm(backend::StochasticAD.AbstractFIsBackend) <: AbstractStochasticADAlgorithm

A differentiation algorithm relying on forward propagation of stochastic triples.

The backend argument controls the algorithm used by the third component of the stochastic triples.

Note

The computation time required by forward-mode AD scales linearly with the number of parameters in p (but is unaffected by the size of the output X(p)).

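A minimal sketch of passing the algorithm explicitly, following the derivative_estimate signature given earlier. It spells out what that signature lists as the default, so it should behave like the two-argument call. The names are module-qualified here only to avoid assuming they are exported.

```julia
using Distributions, StochasticAD

X(p) = rand(Bernoulli(p))

# Construct the forward algorithm with the pruning backend named in the
# default argument of derivative_estimate.
alg = StochasticAD.ForwardAlgorithm(StochasticAD.PrunedFIsBackend())

# The algorithm is the third positional argument.
est = derivative_estimate(X, 0.5, alg)
```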
StochasticAD.EnzymeReverseAlgorithm - Type
EnzymeReverseAlgorithm(backend::StochasticAD.AbstractFIsBackend) <: AbstractStochasticADAlgorithm

A differentiation algorithm relying on transposing the propagation of stochastic triples to produce a reverse-mode algorithm. The transposition is performed by Enzyme.jl, which must be loaded for the algorithm to run.

Currently, only real- and vector-valued inputs are supported, and only real-valued outputs are supported.

The backend argument controls the algorithm used by the third component of the stochastic triples.

In the call to derivative_estimate, this algorithm optionally accepts alg_data with the field forward_u, which specifies the directional derivative used in the forward pass that will be transposed. If forward_u is not provided, it is randomly generated.

Warning

For the reverse-mode algorithm to yield correct results, the employed backend cannot use input-dependent pruning strategies. A suggested reverse-mode compatible backend is PrunedFIsBackend(Val(:wins)).

Additionally, this algorithm relies on the ability of Enzyme.jl to differentiate the forward stochastic triple run. It is recommended to check that the primal function X is type stable for its input p using a tool such as JET.jl, with all code executed in a function with no global state. In addition, X may sometimes be type stable while stochastic triples introduce additional type instabilities. This can be debugged by checking the type stability of Enzyme's target, which is Base.get_extension(StochasticAD, :StochasticADEnzymeExt).enzyme_target(u, X, p, backend), where u is a test direction.
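As a lightweight first check before reaching for JET.jl, the return type of the primal function can be tested with the Test standard library. This is only a sketch: @inferred inspects the outermost call alone and is not a substitute for the fuller analysis recommended above. The program X is a hypothetical stand-in.

```julia
using Distributions, Test

# Hypothetical primal program.
X(p) = rand(Bernoulli(p))

# @inferred throws an error if the return type of the call cannot be
# inferred to a concrete type; otherwise it returns the call's value.
sample = @inferred X(0.5)
```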

Note

For more details on the reverse-mode approach, see the following papers and talks:


Smoothing

What happens if we run derivative_contribution after each step instead of only at the end? This is smoothing, which combines the second and third components of a stochastic triple into a single dual component. Smoothing is no longer guaranteed to be unbiased, but is surprisingly accurate in a number of situations. For example, the popular straight-through gradient estimator can be viewed as a special case of smoothing. Forward smoothing rules are provided through ForwardDiff and backward rules through ChainRules, so that e.g. Zygote.gradient and ForwardDiff.derivative use smoothed rules for discrete random variables rather than dropping the gradients entirely. Currently, special discrete->discrete constructs such as array indexing are not supported for smoothing.

Optimization

We also provide utilities to make it easier to get started with forming and training a model via stochastic gradient descent:

StochasticAD.StochasticModel - Type
StochasticModel(X, p)

Combine the stochastic program X with the parameter p into a trainable model using Functors, where p <: AbstractArray. The model is formulated as a minimization problem, i.e. find $p$ that minimizes $\mathbb{E}[X(p)]$.

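To make the minimization problem concrete, here is a hand-rolled stochastic gradient descent loop built only on derivative_estimate. It deliberately bypasses StochasticModel and uses a scalar parameter, so it sketches the underlying idea rather than the package's training utilities. The objective X, the helper sgd, the step size and the iteration count are all hypothetical choices.

```julia
using Distributions, StochasticAD

# Hypothetical objective: squared distance of a Binomial(10, p) draw from 3.
# E[X(p)] = 10p(1 - p) + (10p - 3)^2, which is minimized at p = 5/18.
function X(p)
    d = rand(Binomial(10, p)) - 3
    return d * d
end

# Plain SGD on a scalar probability (hypothetical helper, not a package API).
function sgd(X, p0; η = 5e-4, steps = 5_000)
    p = p0
    for _ in 1:steps
        g = derivative_estimate(X, p)       # noisy estimate of d/dp E[X(p)]
        p = clamp(p - η * g, 0.01, 0.99)    # keep p a valid probability
    end
    return p
end

p_opt = sgd(X, 0.5)
```

Because each gradient is a single noisy sample, p_opt fluctuates around the minimizer rather than converging exactly. A smaller step size or averaging over the last iterates reduces that noise.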

These are used in the tutorial on stochastic optimization.