# API walkthrough

The function `derivative_estimate` transforms a stochastic program containing discrete randomness into a new program whose average is the derivative of the original.
`StochasticAD.derivative_estimate` — Function

    derivative_estimate(X, p; backend=PrunedFIsBackend(), direction=nothing)

Compute an unbiased estimate of $\frac{\mathrm{d}\mathbb{E}[X(p)]}{\mathrm{d}p}$, the derivative of the expectation of the random function `X(p)` with respect to its input `p`.

Both `p` and `X(p)` can be any object supported by Functors.jl, e.g. scalars or abstract arrays. The output of `derivative_estimate` has the same outer structure as `p`, but with each scalar in `p` replaced by a derivative estimate of `X(p)` with respect to that entry. For example, if `X(p) <: AbstractMatrix` and `p <: Real`, then the output would be a matrix.

The `backend` keyword argument describes the algorithm used by the third component of the stochastic triple; see the technical details for more information. When `direction` is provided, the output is only differentiated with respect to a perturbation of `p` in that direction.

Since `derivative_estimate` performs forward-mode AD, the required computation time scales linearly with the number of parameters in `p` (but is unaffected by the number of parameters in `X(p)`).

**Example**

```julia
julia> using Distributions, Random, StochasticAD; Random.seed!(4321);

julia> derivative_estimate(rand ∘ Bernoulli, 0.5) # A random quantity that averages to the true derivative.
2.0

julia> derivative_estimate(x -> [rand(Bernoulli(x * i/4)) for i in 1:3], 0.5)
3-element Vector{Float64}:
 0.2857142857142857
 0.6666666666666666
 0.0
```
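Each call produces a single random sample of the derivative; averaging many independent calls recovers the true derivative. As a sketch (assuming Distributions and StochasticAD are installed): for `rand ∘ Bernoulli`, $\mathbb{E}[X(p)] = p$, so the true derivative is $1$ for any $p$.

```julia
using Distributions, Statistics, StochasticAD

# Average many single-sample estimates of d/dp E[Bernoulli(p)] at p = 0.5.
# Each estimate is random (e.g. 0.0 or 2.0 under the pruning backend),
# but the mean converges to the true derivative, which is 1.
samples = [derivative_estimate(rand ∘ Bernoulli, 0.5) for _ in 1:10_000]
println(mean(samples))  # close to 1.0 (standard error ≈ 0.01)
```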
While `derivative_estimate` is self-contained, we can also use the functions below to work with stochastic triples directly.
`StochasticAD.stochastic_triple` — Function

    stochastic_triple(X, p; backend=PrunedFIsBackend(), direction=nothing)
    stochastic_triple(p; backend=PrunedFIsBackend(), direction=nothing)

For any `p` that is supported by Functors.jl, e.g. scalars or abstract arrays, differentiate the output with respect to each value of `p`, returning an output of similar structure to `p`, where a particular value contains the stochastic-triple output of `X` when perturbing the corresponding value in `p` (i.e. replacing the original value `x` with `x + ε`).

When `direction` is provided, return only the stochastic-triple output of `X` with respect to a perturbation of `p` in that particular direction. When `X` is not provided, the identity function is used.

The `backend` keyword argument describes the algorithm used by the third component of the stochastic triple; see the technical details for more information.

**Example**

```julia
julia> using Distributions, Random, StochasticAD; Random.seed!(4321);

julia> stochastic_triple(rand ∘ Bernoulli, 0.5)
StochasticTriple of Int64:
0 + 0ε + (1 with probability 2.0ε)
```
`StochasticAD.derivative_contribution` — Function

    derivative_contribution(st::StochasticTriple)

Return the derivative estimate given by combining the dual and triple components of `st`.

`StochasticAD.value` — Function

    value(st::StochasticTriple)

Return the primal value of `st`.

`StochasticAD.delta` — Function

    delta(st::StochasticTriple)

Return the almost-sure derivative of `st`, i.e. the rate of infinitesimal change.

`StochasticAD.perturbations` — Function

    perturbations(st::StochasticTriple)

Return the finite perturbation(s) of `st`, in a format dependent on the backend used for storing perturbations.
Note that `derivative_estimate` is simply the composition of `stochastic_triple` and `derivative_contribution`. We also provide a convenience function for mimicking the behaviour of standard AD, where derivatives of discrete random steps are dropped:
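To sketch this composition concretely (imports as in the examples above): building the triple explicitly and then collapsing it with `derivative_contribution` reproduces the `derivative_estimate` result when the random seed is held fixed, and the accessor functions expose the triple's individual components along the way.

```julia
using Distributions, Random, StochasticAD

# Build the stochastic triple explicitly...
Random.seed!(4321)
st = stochastic_triple(rand ∘ Bernoulli, 0.5)

StochasticAD.value(st)  # primal value of the Bernoulli draw (here 0)
StochasticAD.delta(st)  # almost-sure derivative: 0 for a pure jump

# ...then collapse dual and triple components into a derivative estimate.
est1 = derivative_contribution(st)  # 2.0 with this seed

# The same seed through derivative_estimate gives the same result,
# since derivative_estimate = derivative_contribution ∘ stochastic_triple.
Random.seed!(4321)
est2 = derivative_estimate(rand ∘ Bernoulli, 0.5)
println(est1 == est2)  # true
```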
`StochasticAD.dual_number` — Function

    dual_number(X, p; backend=PrunedFIsBackend(), direction=nothing)
    dual_number(p; backend=PrunedFIsBackend(), direction=nothing)

A lightweight wrapper around `stochastic_triple` that entirely ignores the derivative contribution of all discrete random components, so that it behaves like a regular dual number. Mostly for fun – this, of course, leads to a useless derivative estimate for discrete random functions!
## Smoothing

What happens if we were to run `derivative_contribution` after each step, instead of only at the end? This is smoothing, which combines the second and third components of a single stochastic triple into a single dual component. Smoothing no longer has a guarantee of unbiasedness, but is surprisingly accurate in a number of situations. For example, the popular straight-through gradient estimator can be viewed as a special case of smoothing. Forward smoothing rules are provided through ForwardDiff, and backward rules through ChainRules, so that e.g. `Zygote.gradient` and `ForwardDiff.derivative` will use smoothed rules for discrete random variables rather than dropping their gradients entirely. Currently, special discrete-to-discrete constructs such as array indexing are not supported by smoothing.
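As a brief sketch (assuming ForwardDiff and StochasticAD are both installed and loaded), differentiating a discrete random function directly with `ForwardDiff.derivative` then yields a smoothed estimate rather than an identically zero derivative:

```julia
using Distributions, ForwardDiff, StochasticAD

# With StochasticAD loaded, ForwardDiff's dual numbers pick up smoothed
# rules for discrete random variables instead of dropping their gradients.
g = ForwardDiff.derivative(p -> rand(Bernoulli(p)), 0.5)
println(g)  # a smoothed (possibly biased) derivative estimate
```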
## Optimization

We also provide utilities to make it easier to get started with forming and training a model via stochastic gradient descent:

`StochasticAD.StochasticModel` — Type

    StochasticModel(X, p)

Combine stochastic program `X` with parameter `p` into a trainable model using Functors, where `p <: AbstractArray`. Formulate as a minimization problem, i.e. find $p$ that minimizes $\mathbb{E}[X(p)]$.

`StochasticAD.stochastic_gradient` — Function

    stochastic_gradient(m::StochasticModel)

Compute the gradient with respect to the trainable parameter `p` of `StochasticModel(X, p)`.
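A minimal training-loop sketch, assuming Optimisers.jl is installed and that the model's parameter array is reachable as `m.p` (the toy program, learning rate, and clamping below are illustrative choices, not the canonical setup from the tutorial): we minimize $\mathbb{E}[X(p)] = 10\,p_1$, which pushes $p_1$ toward its lower bound.

```julia
using Distributions, Optimisers, StochasticAD

# Toy stochastic program with an array parameter, as StochasticModel requires.
# E[X(p)] = 10 * p[1], so minimization drives p[1] downward.
X(p) = rand(Binomial(10, p[1]))

m = StochasticModel(X, [0.5])
o = Optimisers.setup(Optimisers.Descent(0.01), m)
for _ in 1:100
    Optimisers.update!(o, m, stochastic_gradient(m))  # one SGD step
    clamp!(m.p, 0.05, 0.95)  # keep p[1] a valid Binomial probability (illustrative)
end
println(m.p)  # driven toward the lower bound, since E[X(p)] increases with p[1]
```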
These are used in the tutorial on stochastic optimization.