Homework 08

In this homework you will write additional rules for the scalar reverse-mode AD from the lab. Please write all your code in a single file, hw.jl, then zip it and upload it to BRUTE as usual. The solution to the lab is below.

mutable struct TrackedReal{T<:Real}
    data::T
    grad::Union{Nothing,T}
    children::Dict
    # this field is only needed for printing the graph; you can safely remove it.
    name::String
end

track(x::Real,name="") = TrackedReal(x,nothing,Dict(),name)

function Base.show(io::IO, x::TrackedReal)
    t = isempty(x.name) ? "(tracked)" : "(tracked $(x.name))"
    print(io, "$(x.data) $t")
end

function accum!(x::TrackedReal)
    if isnothing(x.grad)
        x.grad = sum(w*accum!(v) for (v,w) in x.children)
    end
    x.grad
end

function Base.:*(a::TrackedReal, b::TrackedReal)
    z = track(a.data * b.data, "*")
    a.children[z] = b.data  # dz/da=b
    b.children[z] = a.data  # dz/db=a
    z
end

function Base.:+(a::TrackedReal{T}, b::TrackedReal{T}) where T
    z = track(a.data + b.data, "+")
    a.children[z] = one(T)
    b.children[z] = one(T)
    z
end

function Base.sin(x::TrackedReal)
    z = track(sin(x.data), "sin")
    x.children[z] = cos(x.data)
    z
end

function gradient(f, args::Real...)
    ts = track.(args)
    y  = f(ts...)
    y.grad = 1.0
    accum!.(ts)
end

We will use it to compute the derivative of the Babylonian square root.

babysqrt(x, t=(1+x)/2, n=10) = n==0 ? t : babysqrt(x, (t+x/t)/2, n-1)
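
As a quick sanity check on plain floats (no AD involved), the recursion applies the update t ← (t + x/t)/2 ten times by default and should agree with Base's sqrt to floating-point accuracy. The inputs below are arbitrary values chosen for illustration.

```julia
# Babylonian square root, copied from the definition above.
babysqrt(x, t=(1+x)/2, n=10) = n==0 ? t : babysqrt(x, (t+x/t)/2, n-1)

# Compare against Base.sqrt for a few arbitrary inputs.
for x in (2.0, 9.0, 0.25)
    println(x, " => ", babysqrt(x), "  (sqrt: ", sqrt(x), ")")
end

# A single update step starting from t = 1.5 gives a coarser approximation.
one_step = babysqrt(2.0, 1.5, 1)   # ≈ 1.4167 (= 17/12)
```

Note that the default starting point t = (1+x)/2 and the update (t+x/t)/2 are exactly the places where a tracked value meets a constant, which is why the rules with constants are needed.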

In order to differentiate through babysqrt you will need a reverse rule for division, i.e. Base.:/(TrackedReal,TrackedReal), as well as rules for the cases where constants are involved (e.g. Base.:/(TrackedReal,Real)).
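
To illustrate what a rule involving a constant can look like, here is a minimal sketch for multiplication (deliberately not one of the homework rules). It assumes that a constant is simply not recorded in the graph, so only the tracked argument gets a child entry. The module name MixedRuleSketch is hypothetical and only keeps this stripped-down copy of the lab code separate from your hw.jl.

```julia
module MixedRuleSketch  # hypothetical name; isolates this sketch from hw.jl

# Minimal copy of the lab's tracked type (the optional `name` field is dropped).
mutable struct TrackedReal{T<:Real}
    data::T
    grad::Union{Nothing,T}
    children::Dict
end
track(x::Real) = TrackedReal(x, nothing, Dict())

# Same accumulation as in the lab: sum of (local derivative * child gradient).
function accum!(x::TrackedReal)
    if isnothing(x.grad)
        x.grad = sum(w*accum!(v) for (v,w) in x.children)
    end
    x.grad
end

# Tracked * constant: only `a` is part of the graph, and dz/da = b.
function Base.:*(a::TrackedReal, b::Real)
    z = track(a.data * b)
    a.children[z] = b
    z
end

# Constant * tracked: multiplication commutes, so reuse the rule above.
Base.:*(a::Real, b::TrackedReal) = b * a

end # module

x = MixedRuleSketch.track(3.0)
y = 2 * x                       # dispatches to the constant * tracked method
y.grad = 1.0                    # seed the output gradient, as `gradient` does
dx = MixedRuleSketch.accum!(x)  # d(2x)/dx = 2
```

Division does not commute, so the two mixed cases for / need their own local derivatives rather than a shared fallback like the one used here.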

Homework (2 points)

Write the reverse rules for / and the missing rules for + such that you can differentiate through division and addition with and without constants.

You can verify your solution with the gradient function.

julia> gradient(babysqrt, 2.0)
(0.35355339059327373,)

julia> 1/(2babysqrt(2.0))
0.3535533905932738
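
An independent cross-check that does not rely on your AD rules at all is a central finite difference. The sketch below is one possible way to do this; the helper name finitediff and the step size h are assumptions made for illustration and are not part of the lab code.

```julia
# Babylonian square root, copied from the definition above.
babysqrt(x, t=(1+x)/2, n=10) = n==0 ? t : babysqrt(x, (t+x/t)/2, n-1)

# Central finite difference; `finitediff` and `h` are hypothetical choices.
finitediff(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

fd    = finitediff(babysqrt, 2.0)   # numerical derivative at x = 2
exact = 1 / (2sqrt(2.0))            # analytic derivative of sqrt at x = 2

# The two values should agree up to the finite-difference error.
agrees = isapprox(fd, exact; atol=1e-6)
```

The result of gradient(babysqrt, 2.0) from your hw.jl should match fd to roughly the same tolerance.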