Developer documentation (WIP)
Writing a custom rule for stochastic triples
via StochasticAD.propagate
To handle a deterministic discrete construct that StochasticAD does not automatically handle (e.g. branching via if, boolean comparisons), it is often sufficient to simply add a dispatch rule that calls out to StochasticAD.propagate.
StochasticAD.propagate — Function
propagate(f, args...; keep_deltas = Val(false))
Propagates args through a function f, handling stochastic triples by independently running f on the primal and the alternatives, rather than by inspecting the internals of f (which may be unsupported by StochasticAD). Currently handles deterministic functions f with any input and output that is fmap-able by Functors.jl. If f has a continuously differentiable component, provide keep_deltas = Val(true).
This functionality is orthogonal to dispatch: the idea is for this function to be the "backend" for operator overloading rules based on dispatch. For example:
using StochasticAD, Distributions
import Random
Random.seed!(4321)
function mybranch(x)
    str = repr(x) # string-valued intermediate!
    if length(str) < 2
        return 3
    else
        return 7
    end
end
function f(x)
    return mybranch(9 + rand(Bernoulli(x)))
end
# stochastic_triple(f, 0.5) # this would fail
# Add a dispatch rule for mybranch using StochasticAD.propagate
mybranch(x::StochasticAD.StochasticTriple) = StochasticAD.propagate(mybranch, x)
stochastic_triple(f, 0.5) # now works
# output
StochasticTriple of Int64:
3 + 0ε + (4 with probability 2.0ε)
This function is experimental and subject to change.
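To make the idea of "running f on the primal and on the alternatives" concrete, here is a toy sketch in plain Julia. ToyTriple and toy_propagate are hypothetical names invented for this illustration; this is not how StochasticAD represents or propagates perturbations internally, only a minimal model of the principle.

```julia
# Hypothetical toy model: a primal value plus a list of
# (alternative offset Δ, infinitesimal weight w) pairs.
struct ToyTriple{T}
    value::T
    alternatives::Vector{Tuple{T, Float64}}
end

# Run f separately on the primal and on each alternative input,
# recording how far each alternative output lands from the primal output.
function toy_propagate(f, x::ToyTriple)
    primal_out = f(x.value)
    new_alts = [(f(x.value + Δ) - primal_out, w) for (Δ, w) in x.alternatives]
    return ToyTriple(primal_out, new_alts)
end

# Same branching logic as mybranch above, with a string-valued intermediate.
mybranch_toy(x) = length(repr(x)) < 2 ? 3 : 7

# Primal input 9, with an alternative input 10 carrying weight 2.0.
out = toy_propagate(mybranch_toy, ToyTriple(9, [(1, 2.0)]))
println(out.value, " ", out.alternatives)
```

Because f is only ever called on ordinary values, the string-valued intermediate causes no difficulty, which is the reason a propagate-style rule can work where operator overloading alone cannot.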
via a custom dispatch
If a function does not meet the conditions of StochasticAD.propagate and is not already supported, a custom dispatch may be necessary. For example, consider the following function, which manually implements a geometric random variable:
import Random
using Distributions
# make rng input explicit
function mygeometric(rng, p)
    x = 0
    while !(rand(rng, Bernoulli(p)))
        x += 1
    end
    return x
end
mygeometric(p) = mygeometric(Random.default_rng(), p)
mygeometric (generic function with 2 methods)
This is equivalent to rand(Geometric(p)), which is already supported, but for pedagogical purposes we will implement our own rule from scratch. Using the stochastic derivative formulas from Automatic Differentiation of Programs with Discrete Randomness, the right stochastic derivative of this program is given by
\[Y_R = X - 1, w_R = \frac{x}{p(1-p)},\]
and the left stochastic derivative of this program is given by
\[Y_L = X + 1, w_L = -\frac{x+1}{p}.\]
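As a sanity check on these formulas (a short derivation added here, writing X for the sample that the weights denote by x): mygeometric counts the failures before the first success, so \(\mathbb{E}[X] = (1-p)/p\) and \(\frac{d}{dp}\mathbb{E}[X] = -1/p^2\). Both one-sided derivatives reproduce this value in expectation:

\[\mathbb{E}[w_R (Y_R - X)] = -\frac{\mathbb{E}[X]}{p(1-p)} = -\frac{1}{p^2}, \qquad \mathbb{E}[w_L (Y_L - X)] = -\frac{\mathbb{E}[X] + 1}{p} = -\frac{1}{p^2}.\]

Note that the right weight vanishes when X = 0, where no decrement is possible.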
Using these expressions, we can now write the dispatch rule for stochastic triples:
using StochasticAD
import StochasticAD: StochasticTriple, similar_new, similar_empty, combine
function mygeometric(rng, p_st::StochasticTriple{T}) where {T}
    p = p_st.value
    rng_copy = copy(rng) # save a copy for coupling later
    x = mygeometric(rng, p)

    # Form the new discrete perturbations (combinations of weight w and perturbation Y - X)
    Δs1 = if p_st.δ > 0
        # right stochastic derivative
        w = p_st.δ * x / (p * (1 - p))
        x > 0 ? similar_new(p_st.Δs, -1, w) : similar_empty(p_st.Δs, Int)
    elseif p_st.δ < 0
        # left stochastic derivative
        w = -p_st.δ * (x + 1) / p # positive since the negativity of p_st.δ cancels out the negativity of w_L
        similar_new(p_st.Δs, 1, w)
    else
        similar_empty(p_st.Δs, Int)
    end

    # Propagate any existing perturbations to p through the function
    function map_func(Δ)
        # Couple the samples by using the same RNG. (A simpler strategy would have been independent sampling, i.e. mygeometric(p + Δ) - x)
        mygeometric(copy(rng_copy), p + Δ) - x
    end
    Δs2 = map(map_func, p_st.Δs)

    # Return the output stochastic triple
    StochasticTriple{T}(x, zero(x), combine((Δs2, Δs1)))
end
mygeometric (generic function with 3 methods)
In the above, we used some of the interface functions supported by a collection of perturbations Δs::StochasticAD.AbstractFIs. These were similar_empty(Δs, V), which creates an empty perturbation of type V; similar_new(Δs, Δ, w), which creates a new perturbation of size Δ and weight w; map(map_func, Δs), which propagates a collection of perturbations through a mapping function; and combine((Δs2, Δs1)), which combines multiple collections of perturbations together.
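The RNG-copy coupling used in map_func can be illustrated without any StochasticAD machinery. The sketch below uses a hypothetical helper, geometric_failures, that draws uniforms directly instead of Bernoulli samples (an assumption made to keep the example dependency-free); it shows that replaying a copied RNG reproduces the original sample, and that replaying it with a larger p gives a sample that stops no later.

```julia
using Random

# Hypothetical stand-in for mygeometric: count failures before the first
# success, where each trial succeeds when a uniform draw falls below p.
function geometric_failures(rng, p)
    x = 0
    while rand(rng) >= p
        x += 1
    end
    return x
end

rng = MersenneTwister(1234)
rng_copy = copy(rng)                               # snapshot of the RNG state

x_primal  = geometric_failures(rng, 0.1)
x_replay  = geometric_failures(copy(rng_copy), 0.1) # same stream, same p
x_coupled = geometric_failures(copy(rng_copy), 0.2) # same stream, larger p

println((x_primal, x_replay, x_coupled))
```

Since both runs consume the same uniforms, any draw that succeeds at p = 0.1 also succeeds at p = 0.2, so the coupled sample can never exceed the primal one. Independent sampling would give no such guarantee and would typically produce a noisier difference.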
We can test out our rule:
@show stochastic_triple(mygeometric, 0.1)
# try feeding an input that already has a perturbation
f(x) = mygeometric(2 * x + 0.1 * rand(Bernoulli(x)))^2
@show stochastic_triple(f, 0.1)
# verify against black-box finite differences
N = 1000000
samples_stochad = [derivative_estimate(f, 0.1) for i in 1:N]
samples_fd = [(f(0.105) - f(0.095)) / 0.01 for i in 1:N]
println("Stochastic AD: $(mean(samples_stochad)) ± $(std(samples_stochad) / sqrt(N))")
println("Finite differences: $(mean(samples_fd)) ± $(std(samples_fd) / sqrt(N))")
stochastic_triple(mygeometric, 0.1) = 12 + 0ε + (-1 with probability 133.33333333333331ε)
stochastic_triple(f, 0.1) = 36 + 0ε + (-11 with probability 74.99999999999999ε)
Stochastic AD: -811.8671274206348 ± 2.3883058607387273
Finite differences: -814.9264 ± 11.828743174223884
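A further check that does not depend on the custom rule is to compare against the closed form. The sketch below differentiates rand(Geometric(p)), which the text above notes is already supported, and compares the averaged estimate with the analytic derivative of \(\mathbb{E}[X] = (1-p)/p\). The sample size and variable names are choices made for this illustration, and the exact numbers will vary from run to run.

```julia
using StochasticAD, Distributions, Statistics

p = 0.1
N = 100_000

# Average many single-sample derivative estimates of E[rand(Geometric(p))].
samples = [derivative_estimate(q -> rand(Geometric(q)), p) for _ in 1:N]
estimate = mean(samples)
stderr_estimate = std(samples) / sqrt(N)

# Closed form: d/dp of (1 - p) / p is -1 / p^2.
analytic = -1 / p^2

println("Stochastic AD: $estimate ± $stderr_estimate (analytic: $analytic)")
```

If the estimate sits within a few standard errors of the analytic value, the same comparison can be repeated with mygeometric in place of rand(Geometric(q)) to test the custom rule directly.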
Distribution-specific customization of differentiation algorithm
StochasticAD.randst — Function
randst(rng, d::Distributions.Sampleable; kwargs...)
When no keyword arguments are provided, randst behaves identically to rand(rng, d) in both ordinary computation and for stochastic triple dispatches. However, randst also allows the user to provide various keyword arguments for customizing the differentiation logic. The set of allowed keyword arguments depends on the type of d: two common ones are derivative_coupling and propagation_coupling.
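The equivalence with rand in ordinary computation can be checked with a small sketch, assuming that "behaves identically" extends to consuming the RNG stream in the same way. The seed and distribution below are arbitrary choices for illustration.

```julia
using StochasticAD, Distributions, Random

d = Bernoulli(0.3)

# Draw once through randst and once through rand, each from a freshly
# seeded RNG so that both see the same random stream.
a = randst(MersenneTwister(1), d)
b = rand(MersenneTwister(1), d)

println(a == b)
```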
For developers: if you wish to accept custom keyword arguments in a stochastic triple dispatch, you should overload randst, and redirect rand to your randst method. If you do not, it suffices to just overload rand.
StochasticAD.InversionMethodDerivativeCoupling — Type
InversionMethodDerivativeCoupling(; mode::Val = Val(:positive_weight), handle_zeroprob::Val = Val(true))
Specifies an inversion method coupling for generating perturbations from a univariate distribution. Valid choices of mode are Val(:positive_weight), Val(:always_right), and Val(:always_left).
Example
julia> using StochasticAD, Distributions, Random; Random.seed!(4321);
julia> function X(p)
           return randst(Bernoulli(1 - p); derivative_coupling = InversionMethodDerivativeCoupling(; mode = Val(:always_right)))
       end
X (generic function with 1 method)
julia> stochastic_triple(X, 0.5)
StochasticTriple of Int64:
0 + 0ε + (1 with probability -2.0ε)
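For intuition about what an inversion-method coupling means, the following plain-Julia sketch couples two Bernoulli samples by pushing the same uniform draw through both inverse CDFs. The helper bernoulli_from_uniform and the finite step h are hypothetical choices for this illustration; the sketch does not reproduce the library's implementation or its mode options.

```julia
using Random

# Inversion sampling of Bernoulli(p): return 1 when the uniform draw u
# lands in the top-p fraction of [0, 1], and 0 otherwise.
bernoulli_from_uniform(u, p) = u > 1 - p ? 1 : 0

rng = MersenneTwister(4321)
us = rand(rng, 100_000)

p, h = 0.5, 0.01
x  = bernoulli_from_uniform.(us, p)      # samples at p
xh = bernoulli_from_uniform.(us, p + h)  # coupled samples at p + h

# Under this coupling, raising p can only flip a sample from 0 to 1,
# and the fraction of flipped samples is close to h.
monotone = all(xh .>= x)
flip_fraction = count(x .!= xh) / length(us)

println((monotone, flip_fraction))
```

Sharing the uniform draw keeps the two samples equal except on a set of probability h, which is what makes the perturbation small and its weight well defined as h shrinks.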