Homework 08
In this homework you will write additional rules for our scalar reverse-mode AD from the lab. Please write all your code in a single file, hw.jl, which you zip and upload to BRUTE as usual. The solution to the lab is below.
mutable struct TrackedReal{T<:Real}
    data::T
    grad::Union{Nothing,T}
    children::Dict
    # this field is only needed for printing the graph. you can safely remove it.
    name::String
end

track(x::Real,name="") = TrackedReal(x,nothing,Dict(),name)

function Base.show(io::IO, x::TrackedReal)
    t = isempty(x.name) ? "(tracked)" : "(tracked $(x.name))"
    print(io, "$(x.data) $t")
end

# Accumulate the gradient of x from its children: sum of (dz/dx) * (gradient of z).
function accum!(x::TrackedReal)
    if isnothing(x.grad)
        x.grad = sum(w*accum!(v) for (v,w) in x.children)
    end
    x.grad
end

function Base.:*(a::TrackedReal, b::TrackedReal)
    z = track(a.data * b.data, "*")
    a.children[z] = b.data  # dz/da=b
    b.children[z] = a.data  # dz/db=a
    z
end

function Base.:+(a::TrackedReal{T}, b::TrackedReal{T}) where T
    z = track(a.data + b.data, "+")
    a.children[z] = one(T)  # dz/da=1
    b.children[z] = one(T)  # dz/db=1
    z
end

function Base.sin(x::TrackedReal)
    z = track(sin(x.data), "sin")
    x.children[z] = cos(x.data)  # dz/dx=cos(x)
    z
end

function gradient(f, args::Real...)
    ts = track.(args)
    y  = f(ts...)
    y.grad = 1.0   # seed the output gradient dy/dy=1
    accum!.(ts)
end
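As a quick sanity check of the lab code, the following sketch repeats the definitions above (without the printing helper) and differentiates a small test function. The function f and the names ga and gb are made up for illustration. The reference values come from applying the product rule and d/da sin(a) = cos(a) by hand.

```julia
# Lab definitions, repeated so that this snippet runs on its own.
mutable struct TrackedReal{T<:Real}
    data::T
    grad::Union{Nothing,T}
    children::Dict
    name::String
end
track(x::Real, name="") = TrackedReal(x, nothing, Dict(), name)

function accum!(x::TrackedReal)
    if isnothing(x.grad)
        x.grad = sum(w*accum!(v) for (v,w) in x.children)
    end
    x.grad
end

function Base.:*(a::TrackedReal, b::TrackedReal)
    z = track(a.data * b.data, "*")
    a.children[z] = b.data
    b.children[z] = a.data
    z
end
function Base.:+(a::TrackedReal{T}, b::TrackedReal{T}) where T
    z = track(a.data + b.data, "+")
    a.children[z] = one(T)
    b.children[z] = one(T)
    z
end
function Base.sin(x::TrackedReal)
    z = track(sin(x.data), "sin")
    x.children[z] = cos(x.data)
    z
end
function gradient(f, args::Real...)
    ts = track.(args)
    y  = f(ts...)
    y.grad = 1.0
    accum!.(ts)
end

# Hypothetical test function (not part of the assignment).
f(a, b) = a*b + sin(a)

ga, gb = gradient(f, 2.0, 3.0)
println("df/da = ", ga, "  (by hand: b + cos(a) = ", 3.0 + cos(2.0), ")")
println("df/db = ", gb, "  (by hand: a = 2.0)")
```

Note that a appears in two operations here, so its gradient is the sum of two contributions, which is what the sum in accum! takes care of.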
We will use it to compute the derivative of the Babylonian square root.
babysqrt(x, t=(1+x)/2, n=10) = n==0 ? t : babysqrt(x, (t+x/t)/2, n-1)
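To see what babysqrt computes before differentiating it, the short sketch below runs it on plain floats and compares the result against Base's sqrt. The variable names approx and exact are only for illustration.

```julia
# Babylonian iteration for the square root, as defined in the text.
babysqrt(x, t=(1+x)/2, n=10) = n==0 ? t : babysqrt(x, (t+x/t)/2, n-1)

approx = babysqrt(2.0)   # ten iterations starting from t = (1+x)/2
exact  = sqrt(2.0)

println("babysqrt(2.0) = ", approx)
println("sqrt(2.0)     = ", exact)
println("abs. error    = ", abs(approx - exact))
```

The only operations involved are + and / (with the integer constants 1 and 2 mixed in), which is why these are the rules the tracked type has to support.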
In order to differentiate through babysqrt you will need a reverse rule for /, both for Base.:/(TrackedReal,TrackedReal) and for the cases where constants are involved (e.g. Base.:/(TrackedReal,Real)).
Write the reverse rules for / and the missing rules for + such that you can differentiate through division and addition with and without constants.
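As a reminder of the calculus involved (a worked equation, not an implementation), the local derivatives of a division z = a/b are given below. They play the same role as the values stored in children by the * and + rules of the lab code.

```latex
z = \frac{a}{b}
\quad\Longrightarrow\quad
\frac{\partial z}{\partial a} = \frac{1}{b},
\qquad
\frac{\partial z}{\partial b} = -\frac{a}{b^{2}}
```

When one operand is a plain constant, it has no gradient to accumulate, so only the derivative with respect to the tracked operand is relevant.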
You can verify your solution with the gradient function.
julia> gradient(babysqrt, 2.0)
(0.35355339059327373,)
julia> 1/(2babysqrt(2.0))
0.3535533905932738
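The reference value 1/(2babysqrt(2.0)) comes from the derivative of the square root, d/dx sqrt(x) = 1/(2 sqrt(x)). An independent check that does not need any AD rules is a central finite difference, sketched below. The helper central_diff and its step size h are assumptions made for illustration and are not part of the lab code.

```julia
babysqrt(x, t=(1+x)/2, n=10) = n==0 ? t : babysqrt(x, (t+x/t)/2, n-1)

# Central finite difference; `h` is a hypothetical step size chosen for illustration.
central_diff(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

fd  = central_diff(babysqrt, 2.0)   # numerical estimate of the derivative
ref = 1/(2sqrt(2.0))                # analytic derivative of sqrt at x = 2

println("finite difference = ", fd)
println("analytic value    = ", ref)
```

A finite difference only agrees with the analytic value up to truncation and rounding error, so expect fewer matching digits than with the AD result.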