Lab 09 - Generated Functions & IR

In this lab you will practice two advanced metaprogramming techniques:

  • Generated functions can help you write specialized code for certain kinds of parametric types with more flexibility and/or less code.
  • IRTools.jl is a package that simplifies the manipulation of lowered and typed Julia code.

@generated Functions

Remember the three most important things about generated functions:

  • They return quoted expressions (like macros).
  • You have access to type information of your input variables.
  • They have to be pure (free of side effects), because the compiler decides when and how often the generator body is run.
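As a minimal illustration of the second point (the function name typename_of is hypothetical), the body of a generated function only sees the types of its arguments, and the expression it returns becomes the method body specialized for those types:

```julia
# Inside the body of a @generated function, `x` is bound to the argument's
# *type* (e.g. Int64 or Float64), not to its value.
@generated function typename_of(x)
    name = string(x)   # computed once per argument type, at compile time
    return :($name)    # quoted (here simply literal) expression = method body
end

typename_of(1)     # string(Int), e.g. "Int64" on a 64-bit machine
typename_of(2.5)   # "Float64"
```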

A faster polynomial

Throughout this course we have come back to our polynomial function, which evaluates a polynomial using the Horner scheme. Below you can find a version of the function that operates on a tuple of length $N$.

function polynomial(x, p::NTuple{N}) where N
    acc = p[N]
    for i in N-1:-1:1
        acc = x*acc + p[i]
    end
    acc
end
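Written out for a tuple $p = (p_1, \dots, p_N)$, where $p_1$ is the constant term, the loop above evaluates the nested Horner form

$$
p(x) = p_1 + x\Bigl(p_2 + x\bigl(p_3 + \dots + x\,(p_{N-1} + x\,p_N)\bigr)\Bigr),
$$

which needs only $N-1$ multiplications and $N-1$ additions, one of each per loop iteration.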

Julia has its own implementation of this function called evalpoly. If we compare the performance of our polynomial and Julia's evalpoly, we can observe a pretty big difference:

julia> using BenchmarkTools
julia> x = 2.0
2.0
julia> p = ntuple(float,20);
julia> @btime polynomial($x,$p)
  8.926 ns (0 allocations: 0 bytes)
1.9922945e7
julia> @btime evalpoly($x,$p)
  4.528 ns (0 allocations: 0 bytes)
1.9922945e7

Julia's implementation uses a generated function that specializes on the tuple length (i.e. it unrolls the loop) and thereby eliminates the (small) overhead of looping over the tuple. This is possible because the length of the tuple is known at compile time. You can check the difference between polynomial and evalpoly yourself with the introspection tools you already know, e.g. @code_lowered.
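To make "unrolling" concrete: for a 3-tuple, the loop-free Horner evaluation is a single nested expression. A hand-written sketch (the name poly3 is hypothetical), checked against the loop version and evalpoly:

```julia
# Loop-based Horner scheme from above
function polynomial(x, p::NTuple{N}) where N
    acc = p[N]
    for i in N-1:-1:1
        acc = x*acc + p[i]
    end
    acc
end

# Loop-free version for N = 3: no loop counter, just nested
# multiply-adds with constant indices.
poly3(x, p::NTuple{3}) = x*(x*p[3] + p[2]) + p[1]

p = (1.0, 2.0, 3.0)     # represents 1 + 2x + 3x^2
poly3(2.0, p)           # 17.0
polynomial(2.0, p)      # 17.0
evalpoly(2.0, p)        # 17.0
```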

Exercise

Rewrite the polynomial function as a generated function with the signature

genpoly(x::Number, p::NTuple{N}) where N

Hints:

  • Remember that you have to generate a quoted expression inside your generated function, so you will need things like :($expr1 + $expr2).
  • You can debug the expression you are generating by omitting the @generated macro from your function.
Solution:

@generated function genpoly(x::Number, p::NTuple{N}) where N
    ex = :(p[$N])
    for i in N-1:-1:1
        ex = :(x*$ex + p[$i])
    end
    ex
end

You should get the same performance as evalpoly (and as @poly from Lab 7, with the added convenience of not having to spell out all the coefficients in your code, as in p = @poly 1 2 3 ...).

julia> @btime genpoly($x,$p)
  8.976 ns (0 allocations: 0 bytes)
1.9922945e7
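Following the debugging hint, one way to inspect what genpoly generates is to move its body into an ordinary function of N (the name genpoly_expr is hypothetical) that returns the expression instead of compiling it. For N = 3 it yields the same nested expression one would write by hand:

```julia
# Same body as the generated function, but as a plain function of N
function genpoly_expr(N)
    ex = :(p[$N])
    for i in N-1:-1:1
        ex = :(x*$ex + p[$i])
    end
    ex
end

genpoly_expr(3)   # :(x * (x * p[3] + p[2]) + p[1])

# The generated function simply returns that expression for the tuple length N
@generated function genpoly(x::Number, p::NTuple{N}) where N
    genpoly_expr(N)
end

genpoly(2.0, (1.0, 2.0, 3.0))   # 17.0
```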

Fast, Static Matrices

Another great example that makes heavy use of generated functions is static arrays. A static array is an array of fixed size, which can be implemented via an NTuple. Because its size is part of its type, it can be kept on the stack, which can buy us a lot of performance for small arrays. We define a StaticMatrix{T,R,C,L}, where the type parameters represent the element type T (e.g. Float32), the number of rows R, the number of columns C, and the total length of the matrix L=C*R (which we need to specify the size of the NTuple).

struct StaticMatrix{T,R,C,L} <: AbstractArray{T,2}
    data::NTuple{L,T}
end

function StaticMatrix(x::AbstractMatrix{T}) where T
    (R,C) = size(x)
    StaticMatrix{T,R,C,C*R}(x |> Tuple)
end
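The constructor relies on Tuple(x), which collects the elements in Julia's column-major order. A small sketch of what that means for indexing, plus a check that the resulting type is an isbits type (such values are candidates for stack allocation, although whether that happens depends on how the compiler handles the surrounding code):

```julia
struct StaticMatrix{T,R,C,L} <: AbstractArray{T,2}
    data::NTuple{L,T}
end

x = [1 3 5;
     2 4 6]           # 2×3 Matrix{Int}
t = Tuple(x)          # (1, 2, 3, 4, 5, 6): column by column

# Element (r, c) of an R×C matrix sits at linear index R*(c-1) + r;
# compare the formula with Julia's own LinearIndices.
R, C = size(x)
formula_ok = all(R*(c-1) + r == LinearIndices((R, C))[r, c] for r in 1:R, c in 1:C)

# A StaticMatrix of Float64 contains no pointers and has a fixed size
isbitstype(StaticMatrix{Float64,2,3,6})   # true
```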
Exercise

As a warm-up, overload the Base functions size, length, getindex(x::StaticMatrix,i::Int), and getindex(x::StaticMatrix,r::Int,c::Int).

Solution:

Base.size(x::StaticMatrix{T,R,C}) where {T,R,C} = (R,C)
Base.length(x::StaticMatrix{T,R,C,L}) where {T,R,C,L} = L
Base.getindex(x::StaticMatrix, i::Int) = x.data[i]
Base.getindex(x::StaticMatrix{T,R,C}, r::Int, c::Int) where {T,R,C} = x.data[R*(c-1) + r]

You can check if everything works correctly by comparing to a normal Matrix:

julia> x = rand(2,3)
2×3 Matrix{Float64}:
 0.644112  0.494664  0.486665
 0.278853  0.803221  0.355822
julia> x[1,2]
0.4946638963137471
julia> a = StaticMatrix(x)
2×3 Main.StaticMatrix{Float64, 2, 3, 6}:
 0.644112  0.494664  0.486665
 0.278853  0.803221  0.355822
julia> a[1,2]
0.4946638963137471
Exercise

Overload matrix multiplication between two static matrices

Base.:*(x::StaticMatrix{T,K,M},y::StaticMatrix{T,M,N})

with a generated function that creates an expression without loops. Below you can see an example for an expression that would be generated from multiplying two $2\times 2$ matrices.

:(StaticMatrix{T,2,2,4}((
    (x[1,1]*y[1,1] + x[1,2]*y[2,1]),
    (x[2,1]*y[1,1] + x[2,2]*y[2,1]),
    (x[1,1]*y[1,2] + x[1,2]*y[2,2]),
    (x[2,1]*y[1,2] + x[2,2]*y[2,2])
)))

Hints:

  • You can get output like above by leaving out the @generated in front of your overload.
  • It might be helpful to implement matrix multiplication in a normal Julia function first.
  • You can construct an expression for a sum of multiple elements like below.
julia> Expr(:call,:+,1,2,3)
:(1 + 2 + 3)
julia> Expr(:call,:+,1,2,3) |> eval
6
Solution:

@generated function Base.:*(x::StaticMatrix{T,K,M}, y::StaticMatrix{T,M,N}) where {T,K,M,N}
    zs = map(Iterators.product(1:K, 1:N) |> collect |> vec) do (k,n)
        Expr(:call, :+, [:(x[$k,$m] * y[$m,$n]) for m=1:M]...)
    end
    z = Expr(:tuple, zs...)
    :(StaticMatrix{$T,$K,$N,$(K*N)}($z))
end
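Following the first two hints, the body of the solution can be tried as a plain function (the name matmul_tuple_expr is hypothetical) that returns only the tuple of sums. For 2×2 inputs it reproduces the example expression above, and evaluating it for concrete matrices agrees with Base's *, including the column-major element order:

```julia
# Plain-function version of the generated body: builds the tuple of sums for a
# K×M times M×N product. Iterators.product varies k fastest (column-major).
function matmul_tuple_expr(K, M, N)
    zs = map(Iterators.product(1:K, 1:N) |> collect |> vec) do (k, n)
        Expr(:call, :+, [:(x[$k,$m] * y[$m,$n]) for m in 1:M]...)
    end
    Expr(:tuple, zs...)
end

ex22 = matmul_tuple_expr(2, 2, 2)   # the tuple from the 2×2 example

# Evaluate the unrolled expression with concrete integer matrices
A = [1 2 3; 4 5 6]        # 2×3
B = [1 0; 0 1; 2 2]       # 3×2
unrolled = eval(:(let x = $A, y = $B; $(matmul_tuple_expr(2, 3, 2)); end))
unrolled == Tuple(A * B)  # true, both are (7, 16, 8, 17)
```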

You can check that your matrix multiplication works by multiplying two random matrices. Which one is faster?

julia> a = rand(2,3)
2×3 Matrix{Float64}:
 0.940869  0.666663  0.638214
 0.429085  0.933749  0.0453683
julia> b = rand(3,4)
3×4 Matrix{Float64}:
 0.71851   0.376276   0.336816    0.973923
 0.201484  0.0494671  0.566261    0.684967
 0.542826  0.191345   0.00384357  0.395016
julia> c = StaticMatrix(a)
2×3 Main.StaticMatrix{Float64, 2, 3, 6}:
 0.940869  0.666663  0.638214
 0.429085  0.933749  0.0453683
julia> d = StaticMatrix(b)
3×4 Main.StaticMatrix{Float64, 3, 4, 12}:
 0.71851   0.376276   0.336816    0.973923
 0.201484  0.0494671  0.566261    0.684967
 0.542826  0.191345   0.00384357  0.395016
julia> a*b
2×4 Matrix{Float64}:
 1.15678   0.509123  0.696858  1.62508
 0.521064  0.216325  0.673442  1.0754
julia> c*d
2×4 Main.StaticMatrix{Float64, 2, 4, 8}:
 1.15678   0.509123  0.696858  1.62508
 0.521064  0.216325  0.673442  1.0754

OptionalArgChecks.jl

The package OptionalArgChecks.jl makes it possible to add checks to a function which can then be removed by calling the function with the @skip macro. For example, we can check if the input to a function f is an even number:

function f(x::Number)
    iseven(x) || error("Input has to be an even number!")
    x
end

If you are doing more involved argument checking, it can take quite some time to perform all your checks. However, if you want to be fast and are completely sure that you always pass correct inputs to your function, you might want to remove the checks in some cases. Hence, we would like to transform the IR of the function above

julia> using IRTools
julia> using IRTools: @code_ir
julia> @code_ir f(1)
1: (%1, %2)
  %3 = Main.iseven(%2)
  br 2 unless %3
  br 3
2:
  %4 = Main.error("Input has to be an even number!")
3:
  return %2

into something like this:

julia> transformed_f(x::Number) = x
transformed_f (generic function with 1 method)
julia> @code_ir transformed_f(1)
1: (%1, %2)
  return %2

Marking Argument Checks

As a first step we will implement a macro that marks checks we might want to remove later by surrounding them with :meta expressions. This makes it easy to detect which part of the code can be removed. A :meta expression can be created like this:

julia> Expr(:meta, :mark_begin)
:($(Expr(:meta, :mark_begin)))
julia> Expr(:meta, :mark_end)
:($(Expr(:meta, :mark_end)))

Such expressions are not evaluated but remain in your IR. To surround an expression with two other expressions (below, two print calls stand in for the meta expressions) you can use a :block expression:

julia> ex = :(x+x)
:(x + x)
julia> Expr(:block, :(print(x)), ex, :(print(x)))
quote
    print(x)
    x + x
    print(x)
end
Exercise

Define a macro @mark that takes an expression and surrounds it with two meta expressions marking the beginning and the end of a check.

Hints:

  • Defining a function _mark(ex::Expr) that manipulates your expressions can help a lot with debugging your macro.
Solution:

function _mark(ex::Expr)
    return Expr(
        :block,
        Expr(:meta, :mark_begin),
        esc(ex),
        Expr(:meta, :mark_end),
    )
end

macro mark(ex)
    _mark(ex)
end

If you have defined a _mark function, you can test that it works like this:

julia> _mark(:(println(x)))
quote
    $(Expr(:meta, :mark_begin))
    $(Expr(:escape, :(println(x))))
    $(Expr(:meta, :mark_end))
end

The complete macro should work like below

julia> function f(x::Number)
           @mark @show x
           x
       end;
julia> @code_ir f(2)
1: (%1, %2)
  %3 = $(Expr(:meta, :mark_begin))
  %4 = Base.repr(%2)
  %5 = Base.println("x = ", %4)
  %6 = $(Expr(:meta, :mark_end))
  return %2
julia> f(2)
x = 2
2

Removing Argument Checks

Now comes the tricky part, for which we need IRTools.jl. We want to remove all lines between our two meta expressions. You can delete the line that defines a certain variable with the delete! and var functions. E.g. deleting the line that defines variable %4 works like this:

julia> using IRTools: delete!, var
julia> ir = @code_ir f(2)
1: (%1, %2)
  %3 = $(Expr(:meta, :mark_begin))
  %4 = Base.repr(%2)
  %5 = Base.println("x = ", %4)
  %6 = $(Expr(:meta, :mark_end))
  return %2
julia> delete!(ir, var(4))
1: (%1, %2)
  %3 = $(Expr(:meta, :mark_begin))
  %5 = Base.println("x = ", %4)
  %6 = $(Expr(:meta, :mark_end))
  return %2
Exercise

Write a function skip(ir::IR) which deletes all lines between the meta expressions :mark_begin and :mark_end (including the two markers themselves).

Hint: you can check whether a statement is one of our meta expressions like this:

julia> ismarkbegin(e::Expr) = Meta.isexpr(e,:meta) && e.args[1]===:mark_begin
ismarkbegin (generic function with 1 method)
julia> ismarkbegin(Expr(:meta,:mark_begin))
true
Solution:

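ismarkbegin(e::Expr) = Meta.isexpr(e,:meta) && e.args[1]===:mark_begin
ismarkend(e::Expr) = Meta.isexpr(e,:meta) && e.args[1]===:mark_end
# statements that are not Exprs (literals, GlobalRefs, ...) are never markers
ismarkbegin(e) = false
ismarkend(e) = false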

function skip(ir)
    delete_line = false
    for (x,st) in ir
        isbegin = ismarkbegin(st.expr)
        isend   = ismarkend(st.expr)

        if isbegin
            delete_line = true
        end

        if delete_line
            delete!(ir,x)
        end

        if isend
            delete_line = false
        end
    end
    ir
end
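The begin/end logic of skip is a small state machine that does not depend on IRTools itself. A sketch of the same loop over a plain Vector of statements (the helper skip_stmts and the sample statements are hypothetical) shows which lines survive; like the solution, it also drops the two marker lines:

```julia
ismarkbegin(e) = Meta.isexpr(e, :meta) && e.args[1] === :mark_begin
ismarkend(e)   = Meta.isexpr(e, :meta) && e.args[1] === :mark_end

# Keep every statement except those from a :mark_begin up to and including
# the next :mark_end (same flag logic as skip(ir) above).
function skip_stmts(stmts::Vector)
    kept = Any[]
    delete_line = false
    for st in stmts
        ismarkbegin(st) && (delete_line = true)
        delete_line || push!(kept, st)
        ismarkend(st) && (delete_line = false)
    end
    kept
end

stmts = Any[
    :(y = 2x),
    Expr(:meta, :mark_begin),
    :(iseven(x) || error("odd")),
    Expr(:meta, :mark_end),
    :(return y),
]
skip_stmts(stmts)   # Any[:(y = 2x), :(return y)]
```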

Your function should transform the IR of f like below.

julia> ir = @code_ir f(2)
1: (%1, %2)
  %3 = $(Expr(:meta, :mark_begin))
  %4 = Base.repr(%2)
  %5 = Base.println("x = ", %4)
  %6 = $(Expr(:meta, :mark_end))
  return %2
julia> ir = skip(ir)
1: (%1, %2)
  return %2
julia> using IRTools: func
julia> func(ir)(nothing, 2) # no output from @show!
2

However, if we have a slightly more complicated IR like the one below, this version of our function will fail. It actually fails so badly that running func(ir)(nothing,2) after skip would crash the build of this page, so we cannot show you the output here ;).

julia> function g(x)
           @mark iseven(x) && println("even")
           x
       end
g (generic function with 1 method)
julia> ir = @code_ir g(2)
1: (%1, %2)
  %3 = $(Expr(:meta, :mark_begin))
  %4 = Main.iseven(%2)
  br 3 unless %4
2:
  %5 = Main.println("even")
  br 3
3:
  %6 = $(Expr(:meta, :mark_end))
  return %2
julia> ir = skip(ir)
1: (%1, %2)
  br 3 unless %4
2:
  br 3
3:
  return %2

The crash is due to %4 not existing anymore: the branch br 3 unless %4 still refers to it. We can fix this by removing all branches of the block in which we found the :mark_begin expression and branching directly to the block that contains :mark_end instead (unless they are the same block). Any (branching) code that remains in between is then unreachable and should be removed by the compiler.

Exercise

Use the functions IRTools.block, IRTools.branches, IRTools.empty!, and IRTools.branch! to modify skip such that it also removes all branches of the :mark_begin block and adds a branch to the :mark_end block (unless they are the same block).

Hints:

  • block returns the block of the IR that contains a given variable, e.g. block(ir,var(4)).
  • empty! removes all elements of a collection; applied to branches(b) it removes all branches of block b.
  • branches returns all branches of a block.
  • branch!(a,b) creates a branch from the end of block a to the beginning of block b.
Solution:

using IRTools: block, branch!, empty!, branches
function skip(ir)
    delete_line = false
    orig = nothing
    for (x,st) in ir
        isbegin = ismarkbegin(st.expr)
        isend   = ismarkend(st.expr)

        if isbegin
            delete_line = true
        end

        # this part is new
        if isbegin
            orig = block(ir,x)
        elseif isend
            dest = block(ir,x)
            if orig != dest
                empty!(branches(orig))
                branch!(orig,dest)
            end
        end

        if delete_line
            delete!(ir,x)
        end

        if isend
            delete_line = false
        end
    end
    ir
end

The result should be valid IR for our g function.

julia> g(2)
even
2
julia> ir = @code_ir g(2)
1: (%1, %2)
  %3 = $(Expr(:meta, :mark_begin))
  %4 = Main.iseven(%2)
  br 3 unless %4
2:
  %5 = Main.println("even")
  br 3
3:
  %6 = $(Expr(:meta, :mark_end))
  return %2
julia> ir = skip(ir)
1: (%1, %2)
  br 3
2:
  br 3
3:
  return %2
julia> func(ir)(nothing,2)
2

And it should not break when applying it to f.

julia> f(2)
x = 2
2
julia> ir = @code_ir f(2)
1: (%1, %2)
  %3 = $(Expr(:meta, :mark_begin))
  %4 = Base.repr(%2)
  %5 = Base.println("x = ", %4)
  %6 = $(Expr(:meta, :mark_end))
  return %2
julia> ir = skip(ir)
1: (%1, %2)
  return %2
julia> func(ir)(nothing,2)
2

Recursively Removing Argument Checks

The last step to finalize the skip function is to make it work recursively. The current version handles functions that contain @mark statements themselves, but it does not go any deeper than that: functions called from within the transformed function are not touched.

foo(x) = bar(baz(x))

function bar(x)
    @mark iseven(x) && println("The input is even.")
    x
end

function baz(x)
    @mark x<0 && println("The input is negative.")
    x
end

julia> ir = @code_ir foo(-2)
1: (%1, %2)
  %3 = Main.baz(%2)
  %4 = Main.bar(%3)
  return %4
julia> ir = skip(ir)
1: (%1, %2)
  %3 = Main.baz(%2)
  %4 = Main.bar(%3)
  return %4
julia> func(ir)(nothing,-2)
The input is negative.
The input is even.
-2

For recursion we will use the macro IRTools.@dynamo, which makes recursion of our skip function a lot easier. Additionally, it saves us from writing all the func(ir)(nothing, args...) calls. To use @dynamo we have to slightly modify how we call skip:

using IRTools: @dynamo, IR

@dynamo function skip(args...)
    ir = IR(args...)
    
    # same code as before that modifies `ir`
    # ...

    return ir
end

# now we can call `skip` like this
skip(f,2)

Now we can easily call skip recursively, because we can just pass on the arguments of each call expression like this:

using IRTools: xcall

for (x,st) in ir
    Meta.isexpr(st.expr,:call) || continue
    ir[x] = xcall(skip, st.expr.args...)
end

The function xcall creates an expression that calls skip with the given arguments, i.e. xcall(skip, args...) returns Expr(:call, skip, args...). Note that you can replace the expression of a given variable in the IR via setindex!, as in ir[x] = ... above.
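Without IRTools, the rewrite performed by xcall(skip, st.expr.args...) can be mimicked on an ordinary expression: the original callee becomes the first argument of the wrapper. A sketch with a hypothetical logging wrapper logcall standing in for skip:

```julia
const called = Any[]   # records which functions went through the wrapper

# Hypothetical stand-in for skip: note the callee, then forward the call
function logcall(f, args...)
    push!(called, f)
    return f(args...)
end

# f(args...)  ==>  logcall(f, args...), i.e. Expr(:call, logcall, ex.args...)
wrapcall(ex) = Meta.isexpr(ex, :call) ? Expr(:call, logcall, ex.args...) : ex

wrapped = wrapcall(:(max(3, 7)))   # calls logcall with arguments max, 3, 7
result  = eval(wrapped)            # 7
```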

Exercise

Modify skip such that it uses @dynamo and apply it recursively to all :call expressions that you encounter while looping over the given IR. This will dive all the way down to builtin and intrinsic functions (Core.Builtin and Core.IntrinsicFunction), which you cannot manipulate anymore (because they are implemented in C). You have to end the recursion at these places, which can be done via multiple dispatch of skip on Core.Builtin and Core.IntrinsicFunction.
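The recursion stop works through ordinary dispatch: Core.IntrinsicFunction is itself a subtype of Core.Builtin, so the more specific method wins for intrinsics, and a generic fallback catches everything else. A sketch with a hypothetical classifier callkind:

```julia
callkind(f::Core.IntrinsicFunction) = :intrinsic   # e.g. Core.Intrinsics.add_int
callkind(f::Core.Builtin)           = :builtin     # e.g. getfield, tuple
callkind(f)                         = :generic     # ordinary Julia functions

callkind(Core.Intrinsics.add_int)   # :intrinsic
callkind(getfield)                  # :builtin
callkind(+)                         # :generic (has Julia methods to recurse into)
```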

Once you are done with this you can also define a macro such that you can conveniently call @skip with an expression:

skip(f,2)
@skip f(2)
Solution:

using IRTools: @dynamo, xcall, IR

# this is where we want to stop recursion
skip(f::Core.IntrinsicFunction, args...) = f(args...)
skip(f::Core.Builtin, args...) = f(args...)

@dynamo function skip(args...)
    ir = IR(args...)
    ir === nothing && return   # no IR available: call the function unchanged
    delete_line = false
    orig = nothing
    for (x,st) in ir
        isbegin = ismarkbegin(st.expr)
        isend   = ismarkend(st.expr)

        if isbegin
            delete_line = true
        end

        if isbegin
            orig = block(ir,x)
        elseif isend
            dest = block(ir,x)
            if orig != dest
                empty!(branches(orig))
                branch!(orig,dest)
            end
        end

        if delete_line
            delete!(ir,x)
        end

        if isend
            delete_line = false
        end

        # this part is new
        if haskey(ir,x) && Meta.isexpr(st.expr,:call)
            ir[x] = xcall(skip, st.expr.args...)
        end
    end
    return ir
end

macro skip(ex)
    ex.head == :call || error("Input expression has to be a `:call`.")
    # escape so that the function and its arguments resolve in the caller's scope
    return esc(xcall(skip, ex.args...))
end

julia> @code_ir foo(2)
1: (%1, %2)
  %3 = Main.baz(%2)
  %4 = Main.bar(%3)
  return %4
julia> @code_ir skip(foo,2)
1: (%1, %2)
  %3 = Base.getfield(%2, 1)
  %4 = Base.getfield(%2, 2)
  %5 = (Main.skip)(Main.baz, %4)
  %6 = (Main.skip)(Main.bar, %5)
  return %6
julia> foo(-2)
The input is negative.
The input is even.
-2
julia> skip(foo,-2)
-2
julia> @skip foo(-2)
-2
