Simple reverse-mode example

Load the required packages.

using StochasticAD
using Distributions
using Enzyme
using LinearAlgebra

Let us define our target function.

# Define a toy `StochasticAD`-differentiable function for computing an integer value from a string:
# the value of the string selected by `index` is the sum of its Unicode code points.
function string_value(strings, index)
    selected = strings[index]
    return Int(sum(codepoint, selected))
end

# For a `StochasticTriple` index, apply the plain method through `StochasticAD.propagate`.
function string_value(strings, index::StochasticTriple)
    return StochasticAD.propagate(i -> string_value(strings, i), index)
end

# Target function: draw an index from `Categorical(θ)` with `randst` (using the given
# derivative coupling) and return the integer value of the string at that index.
function f(θ; derivative_coupling = StochasticAD.InversionMethodDerivativeCoupling())
    strings = ["cat", "dog", "meow", "woofs"]
    dist = Categorical(θ)
    index = randst(dist; derivative_coupling = derivative_coupling)
    return string_value(strings, index)
end

# Probability of selecting each of the four strings (entries sum to 1).
θ = Float64[0.1, 0.5, 0.3, 0.1]
@show f(θ)
nothing
f(θ) = 314

First, let's compute the directional derivative of `f` along a direction `u` using forward-mode stochastic AD.

# Perturbation direction in θ-space; its entries sum to 0, so θ + εu still sums to 1.
u = Float64[1, 2, 4, -7]
@show derivative_estimate(f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u)
nothing
derivative_estimate(f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u) = -4.0

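Now let's do the same with reverse mode, which returns an estimate of the full gradient with respect to `θ` rather than a single directional derivative.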

# No `direction` is passed here: reverse mode yields a gradient estimate with one entry per
# component of θ.
@show derivative_estimate(f, θ, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins))))
4-element Vector{Float64}:
 -420.0
 -420.0
    0.0
    0.0

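Let's verify that our reverse-mode gradient is consistent with our forward-mode directional derivative. Individual samples are noisy (compare the single samples above), so we average many independent samples. The mean of `u ⋅ ∇f` from reverse mode should agree with the mean forward-mode directional derivative to within Monte Carlo error.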

# Draw one derivative sample with each mode. These are deliberately not named `reverse`,
# which would shadow `Base.reverse` for the rest of the page.
forward_estimate() = derivative_estimate(f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u)
reverse_estimate() = derivative_estimate(f, θ, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins))))

# Collect N independent samples of the directional derivative along `u` from each mode.
# Reverse mode gives gradient samples, which are projected onto `u` with `dot`.
N = 40000
directional_derivs_fwd = [forward_estimate() for _ in 1:N]
derivs_bwd = [reverse_estimate() for _ in 1:N]
directional_derivs_bwd = [dot(u, δ) for δ in derivs_bwd]

# Sample means and their standard errors, computed once and reused below.
mean_fwd, se_fwd = mean(directional_derivs_fwd), std(directional_derivs_fwd) / sqrt(N)
mean_bwd, se_bwd = mean(directional_derivs_bwd), std(directional_derivs_bwd) / sqrt(N)
println("Forward mode: $(mean_fwd) ± $(se_fwd)")
println("Reverse mode: $(mean_bwd) ± $(se_bwd)")

# The two means come from independent samples, so their difference has standard error
# sqrt(se_fwd² + se_bwd²). A fixed `rtol = 3e-2` is only about 2σ at the standard errors
# this page reports (~12 each, mean ~1200), so it would fail roughly one build in 30.
# A 4σ bound keeps spurious failures negligible while still catching real mismatches.
# `@assert` can be disabled, so an explicit `error` is used instead.
tol = 4 * sqrt(se_fwd^2 + se_bwd^2)
abs(mean_fwd - mean_bwd) <= tol ||
    error("forward/reverse mismatch: |$(mean_fwd) - $(mean_bwd)| > $(tol)")

nothing
Forward mode: -1205.9251 ± 12.082082198756398
Reverse mode: -1198.9259 ± 12.04394975354312

This page was generated using Literate.jl.