Developer documentation (WIP)

Writing a custom rule for stochastic triples

via StochasticAD.propagate

To handle a deterministic discrete construct that StochasticAD does not automatically handle (e.g. branching via if, boolean comparisons), it is often sufficient to simply add a dispatch rule that calls out to StochasticAD.propagate.

StochasticAD.propagate (Function)
propagate(f, args...; keep_deltas = Val(false))

Propagates args through a function f, handling stochastic triples by independently running f on the primal and the alternatives rather than by inspecting the internals of f (which StochasticAD may not support). Currently handles deterministic functions f with any input and output that is fmap-able by Functors.jl. If f has a continuously differentiable component, provide keep_deltas = Val(true).

This functionality is orthogonal to dispatch: the idea is for this function to be the "backend" for operator overloading rules based on dispatch. For example:

using StochasticAD, Distributions
import Random # hide
Random.seed!(4321) # hide

function mybranch(x)
    str = repr(x) # string-valued intermediate!
    if length(str) < 2
        return 3
    else
        return 7
    end
end

function f(x)
    return mybranch(9 + rand(Bernoulli(x)))
end

# stochastic_triple(f, 0.5) # this would fail

# Add a dispatch rule for mybranch using StochasticAD.propagate
mybranch(x::StochasticAD.StochasticTriple) = StochasticAD.propagate(mybranch, x)

stochastic_triple(f, 0.5) # now works

# output

StochasticTriple of Int64:
3 + 0ε + (4 with probability 2.0ε)
Warning

This function is experimental and subject to change.
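
To make the idea of "running f on the primal and on the alternatives" concrete, here is a minimal toy sketch in plain Julia. ToyTriple and toy_propagate are hypothetical names invented for this illustration. They are not StochasticAD's StochasticTriple or StochasticAD.propagate, and they ignore the infinitesimal δ component entirely.

```julia
# Hypothetical toy type (NOT StochasticAD's StochasticTriple):
# a primal value plus a list of weighted finite perturbations.
struct ToyTriple{V}
    value::V
    alternatives::Vector{Tuple{V, Float64}}  # (perturbation Δ, weight w)
end

# Run f separately on the primal input and on each perturbed input,
# recording how far each alternative output lies from the primal output.
# f is treated as a black box, so it may use strings, branches, etc.
function toy_propagate(f, x::ToyTriple)
    y = f(x.value)
    alts = [(f(x.value + Δ) - y, w) for (Δ, w) in x.alternatives]
    return ToyTriple(y, alts)
end

function mybranch(x)
    str = repr(x)  # string-valued intermediate
    return length(str) < 2 ? 3 : 7
end

x = ToyTriple(9, [(1, 2.0)])    # primal 9, alternative 9 + 1 = 10 with weight 2.0
y = toy_propagate(mybranch, x)
println(y.value, " with alternatives ", y.alternatives)
```

In this toy run the primal output is mybranch(9) and the single alternative output is mybranch(10), so the perturbation carried forward is their difference, with the weight passed through unchanged.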


via a custom dispatch

If a function does not meet the conditions of StochasticAD.propagate and is not already supported, a custom dispatch may be necessary. For example, consider the following function which manually implements a geometric random variable:

import Random
using Distributions
# make rng input explicit
function mygeometric(rng, p)
    x = 0
    while !(rand(rng, Bernoulli(p)))
        x += 1
    end
    return x
end
mygeometric(p) = mygeometric(Random.default_rng(), p)
mygeometric (generic function with 2 methods)

This is equivalent to rand(Geometric(p)), which is already supported, but for pedagogical purposes we will implement our own rule from scratch. Using the stochastic derivative formulas from Automatic Differentiation of Programs with Discrete Randomness, the right stochastic derivative of this program is given by

\[Y_R = X - 1, \quad w_R = \frac{X}{p(1-p)},\]

and the left stochastic derivative of this program is given by

\[Y_L = X + 1, \quad w_L = -\frac{X+1}{p}.\]
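
As a sanity check on the two formulas above, the following sketch compares the derivative of E[f(X)] with respect to p against the expectations E[w_R (f(Y_R) - f(X))] and E[w_L (f(Y_L) - f(X))], using exact (truncated) sums over the geometric distribution instead of sampling. The test function f(k) = k^2, the value p = 0.3, the cutoff K, and the helper names pmf and expect are all assumptions made for this illustration.

```julia
# pmf of the variable sampled by mygeometric: P(X = k) = (1 - p)^k * p, k = 0, 1, 2, ...
pmf(k, p) = (1 - p)^k * p

# E[g(X)] as a truncated series; K is an arbitrary cutoff (the tail is negligible here)
expect(g, p; K = 2000) = sum(g(k) * pmf(k, p) for k in 0:K)

f(k) = k^2   # arbitrary test function (assumption)
p = 0.3
h = 1e-6     # finite-difference step

# Reference value: d/dp E[f(X)] by central differences on the series
reference = (expect(f, p + h) - expect(f, p - h)) / (2h)

# Right rule: weight X / (p (1 - p)), perturbed output X - 1
right = expect(k -> k / (p * (1 - p)) * (f(k - 1) - f(k)), p)

# Left rule: weight -(X + 1) / p, perturbed output X + 1
left = expect(k -> -(k + 1) / p * (f(k + 1) - f(k)), p)

println((reference = reference, right = right, left = left))
```

All three printed numbers should agree up to finite-difference error, which is what makes either rule usable as an unbiased derivative estimator.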

Using these expressions, we can now write the dispatch rule for stochastic triples:

using StochasticAD
import StochasticAD: StochasticTriple, similar_new, similar_empty, combine
function mygeometric(rng, p_st::StochasticTriple{T}) where {T}
    p = p_st.value
    rng_copy = copy(rng) # save a copy for coupling later
    x = mygeometric(rng, p)

    # Form the new discrete perturbations (combinations of weight w and perturbation Y - X)
    Δs1 = if p_st.δ > 0
        # right stochastic derivative
        w = p_st.δ * x / (p * (1 - p))
        x > 0 ? similar_new(p_st.Δs, -1, w) : similar_empty(p_st.Δs, Int)
    elseif p_st.δ < 0
        # left stochastic derivative
        w = -p_st.δ * (x + 1) / p # positive since the negativity of p_st.δ cancels out the negativity of w_L
        similar_new(p_st.Δs, 1, w)
    else
        similar_empty(p_st.Δs, Int)
    end

    # Propagate any existing perturbations to p through the function
    function map_func(Δ)
        # Couple the samples by using the same RNG. (A simpler strategy would have been independent sampling, i.e. mygeometric(p + Δ) - x)
        mygeometric(copy(rng_copy), p + Δ) - x
    end
    Δs2 = map(map_func, p_st.Δs)

    # Return the output stochastic triple
    StochasticTriple{T}(x, zero(x), combine((Δs2, Δs1)))
end
mygeometric (generic function with 3 methods)

In the above, we used some of the interface functions supported by a collection of perturbations Δs::StochasticAD.AbstractFIs. These were similar_empty(Δs, V), which creates an empty perturbation of type V; similar_new(Δs, Δ, w), which creates a new perturbation of size Δ and weight w; map(map_func, Δs), which propagates a collection of perturbations through a mapping function; and combine((Δs2, Δs1)), which combines multiple collections of perturbations together.
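
The comment in map_func above contrasts coupling through a shared RNG with independent sampling. The following standalone sketch illustrates the difference using only the standard library. It is an illustration rather than part of StochasticAD: toy_geometric is a hypothetical stand-in for mygeometric that draws each Bernoulli(p) trial as rand(rng) < p, and the values of p, Δ, and N are arbitrary.

```julia
using Random, Statistics

# Stand-in for mygeometric, with each Bernoulli(p) trial drawn as rand(rng) < p
# (an assumption, made so that the sketch needs no extra packages).
function toy_geometric(rng, p)
    x = 0
    while !(rand(rng) < p)
        x += 1
    end
    return x
end

p, Δ, N = 0.1, 0.01, 100_000
rng = Xoshiro(1234)

coupled = Float64[]
independent = Float64[]
for _ in 1:N
    saved = copy(rng)                # snapshot of the RNG state before the primal draw
    x = toy_geometric(rng, p)
    # Coupled: replay the same uniform draws for the perturbed parameter
    push!(coupled, toy_geometric(copy(saved), p + Δ) - x)
    # Independent: use fresh uniform draws for the perturbed parameter
    push!(independent, toy_geometric(rng, p + Δ) - x)
end

println("coupled:     mean = ", mean(coupled), ", std = ", std(coupled))
println("independent: mean = ", mean(independent), ", std = ", std(independent))
```

Both strategies estimate the same mean difference, but the coupled differences have a much smaller spread because the perturbed run stops no later than the primal run on the same draws.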

We can test out our rule:

@show stochastic_triple(mygeometric, 0.1)

# try feeding an input that already has a perturbation
f(x) = mygeometric(2 * x + 0.1 * rand(Bernoulli(x)))^2
@show stochastic_triple(f, 0.1)

# verify against black-box finite differences
N = 1000000
samples_stochad = [derivative_estimate(f, 0.1) for i in 1:N]
samples_fd = [(f(0.105) - f(0.095)) / 0.01 for i in 1:N]

println("Stochastic AD: $(mean(samples_stochad)) ± $(std(samples_stochad) / sqrt(N))")
println("Finite differences: $(mean(samples_fd)) ± $(std(samples_fd) / sqrt(N))")
stochastic_triple(mygeometric, 0.1) = 12 + 0ε + (-1 with probability 133.33333333333331ε)
stochastic_triple(f, 0.1) = 36 + 0ε + (-11 with probability 74.99999999999999ε)
Stochastic AD: -811.8671274206348 ± 2.3883058607387273
Finite differences: -814.9264 ± 11.828743174223884

Distribution-specific customization of differentiation algorithm

StochasticAD.randst (Function)
randst(rng, d::Distributions.Sampleable; kwargs...)

When no keyword arguments are provided, randst behaves identically to rand(rng, d), both in ordinary computation and in stochastic triple dispatches. However, randst also allows the user to provide keyword arguments for customizing the differentiation logic. The set of allowed keyword arguments depends on the type of d; two common ones are derivative_coupling and propagation_coupling.

For developers: if you wish to accept custom keyword arguments in a stochastic triple dispatch, you should overload randst and redirect rand to your randst method. If you do not need custom keyword arguments, it suffices to overload rand.

StochasticAD.InversionMethodDerivativeCoupling (Type)
InversionMethodDerivativeCoupling(; mode::Val = Val(:positive_weight), handle_zeroprob::Val = Val(true))

Specifies an inversion method coupling for generating perturbations from a univariate distribution. Valid choices of mode are Val(:positive_weight), Val(:always_right), and Val(:always_left).

Example

julia> using StochasticAD, Distributions, Random; Random.seed!(4321);

julia> function X(p)
           return randst(Bernoulli(1 - p); derivative_coupling = InversionMethodDerivativeCoupling(; mode = Val(:always_right)))
       end
X (generic function with 1 method)

julia> stochastic_triple(X, 0.5)
StochasticTriple of Int64:
0 + 0ε + (1 with probability -2.0ε)
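
As a further sketch, the same sampler can be parameterized by the coupling mode so that the three documented choices can be compared side by side. This uses only the calls shown in the example above. The two-argument X(p, mode) wrapper is a hypothetical helper, and no output is shown because the printed triples depend on the random seed.

```julia
using StochasticAD, Distributions

# Same sampler as in the example above, with the coupling mode as a parameter
# (the wrapper is a hypothetical helper, not part of StochasticAD).
function X(p, mode)
    coupling = InversionMethodDerivativeCoupling(; mode = mode)
    return randst(Bernoulli(1 - p); derivative_coupling = coupling)
end

# Print the stochastic triple produced under each documented mode.
for mode in (Val(:positive_weight), Val(:always_right), Val(:always_left))
    st = stochastic_triple(p -> X(p, mode), 0.5)
    println(mode, " => ", st)
end
```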