Simple reverse mode example
Load our packages
using StochasticAD
using Distributions
using Enzyme
using LinearAlgebra
Let us define our target function.
"""
    string_value(strings, index)

Toy `StochasticAD`-differentiable function for computing an integer value from a string:
return the sum of the Unicode code points of `strings[index]`, converted to `Int`.
"""
function string_value(strings, index)
    chosen = strings[index]
    return Int(sum(codepoint, chosen))
end
# Method for a `StochasticTriple` index: delegate to `StochasticAD.propagate`, handing it a
# closure that calls `string_value` on whatever index `propagate` supplies.
function string_value(strings, index::StochasticTriple)
    return StochasticAD.propagate(index) do plain_index
        return string_value(strings, plain_index)
    end
end
"""
    f(θ; derivative_coupling = StochasticAD.InversionMethodDerivativeCoupling())

Sample an index from `Categorical(θ)` with `randst` (forwarding `derivative_coupling`) and
return `string_value` of the correspondingly indexed entry of a fixed list of four strings.
"""
function f(θ; derivative_coupling = StochasticAD.InversionMethodDerivativeCoupling())
    words = ["cat", "dog", "meow", "woofs"]
    choice = randst(Categorical(θ); derivative_coupling = derivative_coupling)
    return string_value(words, choice)
end
# Probability vector passed to `Categorical` inside `f`; it has one entry per string in `f`.
θ = [0.1, 0.5, 0.3, 0.1]
# Evaluate `f` once (a single random draw) and print the expression together with its value.
@show f(θ)
# NOTE(review): the trailing `nothing` presumably keeps the cell's return value out of the
# generated page — confirm against the Literate.jl setup.
nothing
f(θ) = 314
First, let's compute the sensitivity of f
in a particular direction via forward-mode Stochastic AD.
# Direction passed as `direction` to the forward-mode estimator; same length as `θ`.
u = [1.0, 2.0, 4.0, -7.0]
# Print a single forward-mode estimate of the derivative of `f` at `θ` along `u`.
@show derivative_estimate(f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u)
# NOTE(review): the trailing `nothing` presumably keeps the cell's return value out of the
# generated page — confirm against the Literate.jl setup.
nothing
derivative_estimate(f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u) = -4.0
Now, let's use reverse-mode to estimate the full gradient of f with respect to θ, rather than a single directional derivative.
# Print a single reverse-mode estimate. No `direction` is passed here; the output shown
# below is a 4-element vector, one entry per component of `θ`.
@show derivative_estimate(f, θ, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins))))
4-element Vector{Float64}:
-420.0
-420.0
0.0
0.0
Let's verify that our reverse-mode gradient is consistent with our forward-mode directional derivative.
# Monte Carlo consistency check: the forward-mode directional derivative along `u` and the
# reverse-mode gradient contracted with `u` should agree on average.

# One forward-mode sample of the directional derivative of `f` at `θ` along `u`.
forward() = derivative_estimate(
    f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u
)
# One reverse-mode sample of the gradient estimate of `f` at `θ`.
reverse() = derivative_estimate(
    f, θ, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins)))
)

N = 40000
directional_derivs_fwd = [forward() for i in 1:N]
derivs_bwd = [reverse() for i in 1:N]
# Contract each gradient sample with `u` to turn it into a directional-derivative sample.
directional_derivs_bwd = [dot(u, δ) for δ in derivs_bwd]

# Sample means and their standard errors, computed once and reused below.
mean_fwd = mean(directional_derivs_fwd)
mean_bwd = mean(directional_derivs_bwd)
sem_fwd = std(directional_derivs_fwd) / sqrt(N)
sem_bwd = std(directional_derivs_bwd) / sqrt(N)
println("Forward mode: $(mean_fwd) ± $(sem_fwd)")
println("Reverse mode: $(mean_bwd) ± $(sem_bwd)")

# Both means are unseeded Monte Carlo estimates, so their difference fluctuates from run to
# run. A bare `rtol = 3e-2` is only about two combined standard errors wide for the
# magnitudes printed by this example, which makes the check fail spuriously in a few
# percent of runs. Also accept any difference within four combined standard errors;
# `isapprox` passes when either the `atol` or the `rtol` bound is met, so every run that
# passed before still passes.
@assert isapprox(
    mean_fwd, mean_bwd; rtol = 3e-2, atol = 4 * hypot(sem_fwd, sem_bwd)
) "forward- and reverse-mode estimates disagree beyond statistical error"
nothing
Forward mode: -1205.9251 ± 12.082082198756398
Reverse mode: -1198.9259 ± 12.04394975354312
This page was generated using Literate.jl.