Lab 09 - Generated Functions & IR
In this lab you will practice two advanced metaprogramming techniques:
- Generated functions & CompTime.jl can help you write specialized code for certain kinds of parametric types. This can offload certain computations to compile time!
- IRTools.jl is a package that simplifies the manipulation of lowered and typed Julia code, enabling relatively easy source-to-source transformations.
@generated Functions
Remember the three most important things about generated functions:
- They return quoted expressions (like macros).
- You have access to type information of your input variables.
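- They have to be pure: the generator body must not have side effects, because Julia may run it an unspecified number of times.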
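A minimal sketch of the first two points using only Base Julia (the names describe and unrolled_sum are made up for illustration). Inside the generator the argument name is bound to the argument's type, and the expression you return is what actually gets compiled for that combination of types. Both generators are pure: they only build an expression from type information.

```julia
# Hypothetical example: in the generator body `x` is bound to the *type* of
# the argument; inside the returned quote, `x` refers to the runtime value.
@generated function describe(x)
    T = x                                  # e.g. Float64 for describe(1.5)
    return :(string($T, " value: ", x))
end

# Hypothetical example: unroll a sum over a tuple whose length N is a type
# parameter and therefore known when the generator runs (assumes N >= 1).
@generated function unrolled_sum(t::NTuple{N,Any}) where N
    terms = [:(t[$i]) for i in 1:N]        # :(t[1]), :(t[2]), ...
    return Expr(:call, :+, terms...)       # for N = 3: :(t[1] + t[2] + t[3])
end

describe(1.5)            # "Float64 value: 1.5"
unrolled_sum((1, 2, 3))  # 6
```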
A faster polynomial
Throughout this course we have come back to our polynomial function which evaluates a polynomial based on the Horner scheme. Below you can find a version of the function that operates on a tuple of length $N$.
function polynomial(x, p::NTuple{N}) where N
    acc = p[N]
    for i in N-1:-1:1
        acc = x*acc + p[i]
    end
    acc
end
Julia has its own implementation of this function called evalpoly. If we compare the performance of our polynomial and Julia's evalpoly, we can observe a pretty big difference:
julia> using BenchmarkTools

julia> x = 2.0
2.0

julia> p = ntuple(float,20);

julia> @btime polynomial($x,$p)
  15.832 ns (0 allocations: 0 bytes)
1.9922945e7

julia> @btime evalpoly($x,$p)
  8.600 ns (0 allocations: 0 bytes)
1.9922945e7
Julia's implementation uses a generated function which specializes on different tuple lengths (i.e. it unrolls the loop) and eliminates the (small) overhead of looping over the tuple. This is possible because the length of the tuple is known at compile time. You can inspect the difference between polynomial and evalpoly yourself via the introspection tools you know, e.g. @code_lowered.
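The value 1.9922945e7 printed by both benchmarks can be checked without BenchmarkTools. With p = (1.0, 2.0, …, 20.0) and x = 2, the polynomial is the sum of i·2^(i-1) for i = 1 to 20, which has the closed form 19·2^20 + 1 = 19922945. A small sketch comparing Base's evalpoly with a naive power sum (all intermediate values are integers below 2^53, so the Float64 results are exact):

```julia
x = 2.0
p = ntuple(float, 20)                        # (1.0, 2.0, ..., 20.0)

# Naive evaluation: p[1] + p[2]*x + p[3]*x^2 + ... + p[20]*x^19
naive = sum(p[i] * x^(i-1) for i in 1:20)

# Base.evalpoly evaluates the same polynomial with Horner's scheme.
horner = evalpoly(x, p)

closed_form = 19 * 2^20 + 1                  # sum of i*2^(i-1) for i = 1..20
```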
Rewrite the polynomial function as a generated function with the signature
genpoly(x::Number, p::NTuple{N}) where N
Hints:
- Remember that you have to generate a quoted expression inside your generated function, so you will need things like :($expr1 + $expr2).
- You can debug the expression you are generating by omitting the @generated macro from your function.
- You have already written a very similar expression in lab 07.
Solution:
@generated function genpoly(x::Number, p::NTuple{N}) where N
    ex = :(p[$N])              # start with the highest coefficient
    for i in N-1:-1:1
        ex = :(x*$ex + p[$i])  # wrap one Horner step around the previous expression
    end
    ex                         # the unrolled expression is what gets compiled
end
You should get the same performance as evalpoly (and as @poly from Lab 7, with the added convenience of not having to spell out all the coefficients in your code like p = @poly 1 2 3 ...).
julia> @btime genpoly($x,$p)
  15.832 ns (0 allocations: 0 bytes)
1.9922945e7
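Following the debugging hint, the same generator body written as a plain function (the name genpoly_expr is made up for illustration, and it takes N directly instead of reading it from a tuple type) makes the unrolled expression visible. Here it is for N = 3:

```julia
# Hypothetical helper: the body of genpoly without @generated, taking N directly.
function genpoly_expr(N)
    ex = :(p[$N])
    for i in N-1:-1:1
        ex = :(x*$ex + p[$i])   # wrap one more Horner step around ex
    end
    return ex
end

genpoly_expr(3)   # :(x * (x * p[3] + p[2]) + p[1])

# Evaluate the generated expression for x = 2.0 and p = (1.0, 2.0, 3.0):
# 1 + 2*2 + 3*2^2 = 17.0
eval(:(let x = 2.0, p = (1.0, 2.0, 3.0); $(genpoly_expr(3)); end))
```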
Fast, Static Matrices
Static arrays are another great example that makes heavy use of generated functions. A static array is an array of fixed size which can be implemented via an NTuple. This means that it can be allocated on the stack, which can buy us a lot of performance for small static arrays. We define a StaticMatrix{T,R,C,L} where the parametric types represent the matrix element type T (e.g. Float32), the number of rows R, the number of columns C, and the total length of the matrix L=C*R (which we need to set the size of the NTuple).
struct StaticMatrix{T,R,C,L} <: AbstractArray{T,2}
    data::NTuple{L,T}
end
function StaticMatrix(x::AbstractMatrix{T}) where T
    (R,C) = size(x)
    StaticMatrix{T,R,C,C*R}(x |> Tuple)   # Tuple(x) collects the entries in column-major order
end
As a warm-up, overload the Base functions size, length, getindex(x::StaticMatrix,i::Int), and getindex(x::StaticMatrix,r::Int,c::Int).
Solution:
Base.size(x::StaticMatrix{T,R,C}) where {T,R,C} = (R,C)
Base.length(x::StaticMatrix{T,R,C,L}) where {T,R,C,L} = L
Base.getindex(x::StaticMatrix, i::Int) = x.data[i]
Base.getindex(x::StaticMatrix{T,R,C}, r::Int, c::Int) where {T,R,C} = x.data[R*(c-1) + r]
You can check if everything works correctly by comparing to a normal Matrix:
julia> x = rand(2,3)
2×3 Matrix{Float64}:
 0.56873   0.760339  0.498291
 0.978606  0.879543  0.132146

julia> x[1,2]
0.760338519587211

julia> a = StaticMatrix(x)
2×3 Main.StaticMatrix{Float64, 2, 3, 6}:
 0.56873   0.760339  0.498291
 0.978606  0.879543  0.132146

julia> a[1,2]
0.760338519587211
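Two properties the warm-up relies on can be checked with plain Julia. First, Tuple(m) collects a matrix column by column, so the offset R*(c-1) + r used in getindex matches Julia's own LinearIndices. Second, because the data lives in an NTuple of bits values, the StaticMatrix type is itself an isbits type, which makes it eligible to be kept on the stack or in registers instead of being heap-allocated. A sketch using a small integer matrix (the variable names are arbitrary):

```julia
# Same layout as the lab's StaticMatrix: an immutable struct wrapping an NTuple.
struct StaticMatrix{T,R,C,L} <: AbstractArray{T,2}
    data::NTuple{L,T}
end

m = [1 3 5; 2 4 6]     # 2×3 Matrix{Int}
t = Tuple(m)           # column-major order: (1, 2, 3, 4, 5, 6)
R = size(m, 1)

# The offset R*(c-1) + r agrees with Julia's own column-major linear indices.
offsets_agree = all(LinearIndices(m)[r, c] == R*(c-1) + r && t[R*(c-1) + r] == m[r, c]
                    for r in 1:2, c in 1:3)

# An NTuple of bits values is itself isbits, unlike a heap-allocated Matrix.
isbitstype(StaticMatrix{Float64,2,3,6})   # true
isbitstype(Matrix{Float64})               # false
```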
Overload matrix multiplication between two static matrices
Base.:*(x::StaticMatrix{T,K,M},y::StaticMatrix{T,M,N})
with a generated function that creates an expression without loops. Below you can see an example for an expression that would be generated from multiplying two $2\times 2$ matrices.
:(StaticMatrix{T,2,2,4}((
(x[1,1]*y[1,1] + x[1,2]*y[2,1]),
(x[2,1]*y[1,1] + x[2,2]*y[2,1]),
(x[1,1]*y[1,2] + x[1,2]*y[2,2]),
(x[2,1]*y[1,2] + x[2,2]*y[2,2])
)))
Hints:
- You can get output like above by leaving out the @generated in front of your overload.
- It might be helpful to implement matrix multiplication in a normal Julia function first.
- You can construct an expression for a sum of multiple elements like below.
julia> Expr(:call,:+,1,2,3)
:(1 + 2 + 3)

julia> Expr(:call,:+,1,2,3) |> eval
6
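Following the second hint, a plain loop-based reference is a useful stepping stone (the name matmul_reference is made up for illustration). It visits the output entries in the same order the generated version has to list them in its tuple, with the row index k varying fastest (column-major order), and each entry is the sum over m of A[k,m]*B[m,n].

```julia
# Hypothetical reference implementation of C = A*B for ordinary matrices.
# Entries are produced column by column (k varies fastest), matching the
# column-major order of the tuple inside a StaticMatrix.
function matmul_reference(A::AbstractMatrix, B::AbstractMatrix)
    K, M = size(A)
    M2, N = size(B)
    M == M2 || throw(DimensionMismatch("inner dimensions $M and $M2 differ"))
    T = promote_type(eltype(A), eltype(B))
    entries = T[]
    for n in 1:N, k in 1:K                  # outer loop over columns n, inner over rows k
        push!(entries, sum(A[k, m] * B[m, n] for m in 1:M))
    end
    return reshape(entries, K, N)
end

A = [1 2 3; 4 5 6]                  # 2×3
B = reshape(collect(1:12), 3, 4)    # 3×4
matmul_reference(A, B) == A * B     # true
```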
Solution:
@generated function Base.:*(x::StaticMatrix{T,K,M}, y::StaticMatrix{T,M,N}) where {T,K,M,N}
    # one sum expression per output entry, with k varying fastest (column-major order)
    zs = map(Iterators.product(1:K, 1:N) |> collect |> vec) do (k,n)
        Expr(:call, :+, [:(x[$k,$m] * y[$m,$n]) for m=1:M]...)
    end
    z = Expr(:tuple, zs...)
    :(StaticMatrix{$T,$K,$N,$(K*N)}($z))
end
You can check that your matrix multiplication works by multiplying two random matrices. Which one is faster?
julia> a = rand(2,3)
2×3 Matrix{Float64}:
 0.92402   0.179175  0.216795
 0.513967  0.491728  0.568149

julia> b = rand(3,4)
3×4 Matrix{Float64}:
 0.538731  0.918312  0.960321  0.449442
 0.898655  0.819317  0.153256  0.868133
 0.978518  0.54134   0.227203  0.928135

julia> c = StaticMatrix(a)
2×3 Main.StaticMatrix{Float64, 2, 3, 6}:
 0.92402   0.179175  0.216795
 0.513967  0.491728  0.568149

julia> d = StaticMatrix(b)
3×4 Main.StaticMatrix{Float64, 3, 4, 12}:
 0.538731  0.918312  0.960321  0.449442
 0.898655  0.819317  0.153256  0.868133
 0.978518  0.54134   0.227203  0.928135

julia> a*b
2×4 Matrix{Float64}:
 0.870953  1.1127   0.964072  0.772056
 1.27473   1.18243  0.698019  1.1852

julia> c*d
2×4 Main.StaticMatrix{Float64, 2, 4, 8}:
 0.870953  1.1127   0.964072  0.772056
 1.27473   1.18243  0.698019  1.1852
OptionalArgChecks.jl
The package OptionalArgChecks.jl makes it possible to add checks to a function which can then be removed by calling the function with the @skip macro. For example, we can check if the input to a function f is an even number:
function f(x::Number)
    iseven(x) || error("Input has to be an even number!")
    x
end
If you are doing more involved argument checking, it can take quite some time to perform all your checks. However, if you want to be fast and are completely sure that you are always passing in the correct inputs to your function, you might want to remove them in some cases. Hence, we would like to transform the IR of the function above
julia> using IRTools

julia> using IRTools: @code_ir

julia> @code_ir f(1)
1: (%1, %2)
  %3 = Main.iseven(%2)
  br 2 unless %3
  br 3
2:
  %4 = Main.error("Input has to be an even number!")
3:
  return %2
to something like this:
julia> transformed_f(x::Number) = x
transformed_f (generic function with 1 method)

julia> @code_ir transformed_f(1)
1: (%1, %2)
  return %2
Marking Argument Checks
As a first step we will implement a macro that marks checks which we might want to remove later by surrounding them with :meta expressions. This will make it easy to detect which part of the code can be removed. A :meta expression can be created like this:
julia> Expr(:meta, :mark_begin)
:($(Expr(:meta, :mark_begin)))

julia> Expr(:meta, :mark_end)
:($(Expr(:meta, :mark_end)))
They will not be evaluated but remain in your IR. To surround an expression with two meta expressions you can use a :block expression:
julia> ex = :(x+x)
:(x + x)

julia> Expr(:block, :(print(x)), ex, :(print(x)))
quote
    print(x)
    x + x
    print(x)
end
Define a macro @mark that takes an expression and surrounds it with two meta expressions marking the beginning and end of a check.
Hints:
- Defining a function _mark(ex::Expr) which manipulates your expressions can help a lot with debugging your macro.
Solution:
function _mark(ex::Expr)
    return Expr(
        :block,
        Expr(:meta, :mark_begin),
        esc(ex),                  # escape so the check refers to the caller's variables
        Expr(:meta, :mark_end),
    )
end
macro mark(ex)
    _mark(ex)
end
If you have defined a _mark function you can test that it works like this:
julia> _mark(:(println(x)))
quote
    $(Expr(:meta, :mark_begin))
    $(Expr(:escape, :(println(x))))
    $(Expr(:meta, :mark_end))
end
The complete macro should work like below:
julia> function f(x::Number)
           @mark @show x
           x
       end;

julia> @code_ir f(2)
1: (%1, %2)
  %3 = $(Expr(:meta, :mark_begin))
  %4 = Base.repr(%2)
  %5 = Base.println("x = ", %4)
  %6 = $(Expr(:meta, :mark_end))
  return %2

julia> f(2)
x = 2
2
Removing Argument Checks
Now comes the tricky part, for which we need IRTools.jl. We want to remove all lines that are between our two meta blocks. You can delete the line that corresponds to a certain variable with the delete! and var functions. E.g. deleting the line that defines variable %4 works like this:
julia> using IRTools: delete!, var

julia> ir = @code_ir f(2)
1: (%1, %2)
  %3 = $(Expr(:meta, :mark_begin))
  %4 = Base.repr(%2)
  %5 = Base.println("x = ", %4)
  %6 = $(Expr(:meta, :mark_end))
  return %2

julia> delete!(ir, var(4))
1: (%1, %2)
  %3 = $(Expr(:meta, :mark_begin))
  %5 = Base.println("x = ", %4)
  %6 = $(Expr(:meta, :mark_end))
  return %2
Write a function skip(ir::IR) which deletes all lines between the meta expressions :mark_begin and :mark_end, including the two markers themselves.
Hints: You can check whether a statement is one of our meta expressions like this:
julia> ismarkbegin(e::Expr) = Meta.isexpr(e,:meta) && e.args[1]===:mark_begin
ismarkbegin (generic function with 1 method)

julia> ismarkbegin(Expr(:meta,:mark_begin))
true
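Before touching IR, the "delete everything between the markers" logic can be prototyped on a plain Julia Expr. The sketch below (strip_marked is a made-up name) works on the surface syntax rather than on IR, so unlike an IRTools-based skip it has no basic blocks or branches to repair, and it only sees code that is literally in the expression, not the bodies of functions it calls. The predicates are left untyped so they also accept Symbols, literals and LineNumberNodes.

```julia
# Untyped predicates: they also accept Symbols, literals, LineNumberNodes, ...
ismarkbegin(e) = Meta.isexpr(e, :meta) && !isempty(e.args) && e.args[1] === :mark_begin
ismarkend(e)   = Meta.isexpr(e, :meta) && !isempty(e.args) && e.args[1] === :mark_end

# Hypothetical AST-level analogue of `skip`: drop the markers and everything
# between them, recursing into nested expressions.
function strip_marked(ex)
    ex isa Expr || return ex
    kept = Any[]
    inside = false
    for a in ex.args
        if ismarkbegin(a)
            inside = true
        elseif ismarkend(a)
            inside = false
        elseif !inside
            push!(kept, strip_marked(a))
        end
    end
    return Expr(ex.head, kept...)
end

# The body that `@mark` would produce for a simple check, written out by hand:
body = Expr(:block,
            Expr(:meta, :mark_begin),
            :(iseven(x) || error("Input has to be an even number!")),
            Expr(:meta, :mark_end),
            :x)

checked = eval(:(x -> $body))
fast    = eval(:(x -> $(strip_marked(body))))
fast(3)      # 3, while checked(3) throws an error
```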
Solution:
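# untyped predicates, since IR statements are not always Exprs (e.g. GlobalRefs)
ismarkbegin(e) = Meta.isexpr(e,:meta) && e.args[1]===:mark_begin
ismarkend(e) = Meta.isexpr(e,:meta) && e.args[1]===:mark_end
function skip(ir)
    delete_line = false
    for (x,st) in ir
        isbegin = ismarkbegin(st.expr)
        isend = ismarkend(st.expr)
        if isbegin
            delete_line = true       # start deleting at :mark_begin ...
        end
        if delete_line
            delete!(ir,x)
        end
        if isend
            delete_line = false      # ... and stop after :mark_end
        end
    end
    ir
end
Your function should transform the IR of f like below.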
julia> ir = @code_ir f(2)
1: (%1, %2)
  %3 = $(Expr(:meta, :mark_begin))
  %4 = Base.repr(%2)
  %5 = Base.println("x = ", %4)
  %6 = $(Expr(:meta, :mark_end))
  return %2

julia> ir = skip(ir)
1: (%1, %2)
  return %2

julia> using IRTools: func

julia> func(ir)(nothing, 2) # no output from @show!
2
However, if we have a slightly more complicated IR like below, this version of our function will fail. It actually fails so badly that running func(ir)(nothing,2) after skip will cause the build of this page to crash, so we cannot show you the output here ;).
julia> function g(x)
           @mark iseven(x) && println("even")
           x
       end
g (generic function with 1 method)

julia> ir = @code_ir g(2)
1: (%1, %2)
  %3 = $(Expr(:meta, :mark_begin))
  %4 = Main.iseven(%2)
  br 3 unless %4
2:
  %5 = Main.println("even")
  br 3
3:
  %6 = $(Expr(:meta, :mark_end))
  return %2

julia> ir = skip(ir)
1: (%1, %2)
  br 3 unless %4
2:
  br 3
3:
  return %2
The crash is due to %4 not existing anymore. We can fix this by emptying the block in which we found the :mark_begin expression and branching to the block that contains :mark_end (unless they are in the same block already). If some (branching) code in between remained, it should then be removed by the compiler because it is never reached.
Use the functions IRTools.block, IRTools.branches, IRTools.empty!, and IRTools.branch! to modify skip such that it also empties the :mark_begin block, and adds a branch to the :mark_end block (unless they are the same block).
Hints:
- block gets you the block of IR in which a given variable is, if you call e.g. block(ir,var(4)).
- empty! removes all statements in a block.
- branches returns all branches of a block.
- branch!(a,b) creates a branch from the end of block a to the beginning of block b.
Solution:
using IRTools: block, branch!, empty!, branches
function skip(ir)
    delete_line = false
    orig = nothing
    for (x,st) in ir
        isbegin = ismarkbegin(st.expr)
        isend = ismarkend(st.expr)
        if isbegin
            delete_line = true
        end
        # this part is new
        if isbegin
            orig = block(ir,x)
        elseif isend
            dest = block(ir,x)
            if orig != dest
                empty!(branches(orig))
                branch!(orig,dest)
            end
        end
        if delete_line
            delete!(ir,x)
        end
        if isend
            delete_line = false
        end
    end
    ir
end
The result should construct valid IR for our g function.
julia> g(2)
even
2

julia> ir = @code_ir g(2)
1: (%1, %2)
  %3 = $(Expr(:meta, :mark_begin))
  %4 = Main.iseven(%2)
  br 3 unless %4
2:
  %5 = Main.println("even")
  br 3
3:
  %6 = $(Expr(:meta, :mark_end))
  return %2

julia> ir = skip(ir)
1: (%1, %2)
  br 3
2:
  br 3
3:
  return %2

julia> func(ir)(nothing,2)
2
And it should not break when applying it to f.
julia> f(2)
x = 2
2

julia> ir = @code_ir f(2)
1: (%1, %2)
  %3 = $(Expr(:meta, :mark_begin))
  %4 = Base.repr(%2)
  %5 = Base.println("x = ", %4)
  %6 = $(Expr(:meta, :mark_end))
  return %2

julia> ir = skip(ir)
1: (%1, %2)
  return %2

julia> func(ir)(nothing,2)
2
Recursively Removing Argument Checks
The last step to finalize the skip function is to make it work recursively. In the current version we can handle functions that contain @mark statements, but we are not going any deeper than that. Functions called from within them will not be touched:
foo(x) = bar(baz(x))
function bar(x)
    @mark iseven(x) && println("The input is even.")
    x
end
function baz(x)
    @mark x<0 && println("The input is negative.")
    x
end
julia> ir = @code_ir foo(-2)
1: (%1, %2)
  %3 = Main.baz(%2)
  %4 = Main.bar(%3)
  return %4

julia> ir = skip(ir)
1: (%1, %2)
  %3 = Main.baz(%2)
  %4 = Main.bar(%3)
  return %4

julia> func(ir)(nothing,-2)
The input is negative.
The input is even.
-2
For recursion we will use the macro IRTools.@dynamo which will make recursion of our skip function a lot easier. Additionally, it will save us from all the func(ir)(nothing, args...) statements. To use @dynamo we have to slightly modify how we call skip:
@dynamo function skip(args...)
    ir = IR(args...)
    # same code as before that modifies `ir`
    # ...
    return ir
end
# now we can call `skip` like this
skip(f,2)
Now we can easily use skip in recursion, because we can just pass the arguments of an expression like this:
using IRTools: xcall
for (x,st) in ir
    Meta.isexpr(st.expr,:call) || continue
    ir[x] = xcall(skip, st.expr.args...)   # f(args...) becomes skip(f, args...)
end
The call xcall(skip, args...) creates the expression Expr(:call, skip, args...), i.e. a call of skip with the original function and its arguments. Note that you can modify the expression of a given variable in the IR via setindex!.
Modify skip such that it uses @dynamo and apply it recursively to all :call expressions that you encounter while looping over the given IR. This will dive all the way down to functions of type Core.Builtin and Core.IntrinsicFunction, which you cannot manipulate anymore (because they are written in C). You have to end the recursion at these places, which can be done via multiple dispatch of skip on Core.Builtin and Core.IntrinsicFunction.
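Which functions count as leaves can be checked directly in the REPL. A small sketch (callkind is a made-up name) dispatching on the same two types as the recursion stop: in Julia's type hierarchy Core.IntrinsicFunction is itself a subtype of Core.Builtin, so the more specific IntrinsicFunction method wins for intrinsics, the Builtin method catches the remaining built-ins such as getfield, and ordinary generic functions like + fall through to the Function method, which is where skip keeps recursing.

```julia
# Hypothetical classifier using the same dispatch pattern as the recursion stop.
callkind(f::Core.IntrinsicFunction) = :intrinsic   # e.g. Core.Intrinsics.add_int
callkind(f::Core.Builtin) = :builtin               # e.g. getfield, tuple
callkind(f::Function) = :generic                   # e.g. +, sin, user functions

Core.IntrinsicFunction <: Core.Builtin   # true
callkind(Core.Intrinsics.add_int)        # :intrinsic
callkind(getfield)                       # :builtin
callkind(+)                              # :generic
```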
Once you are done with this you can also define a macro such that you can conveniently call @skip with an expression:
skip(f,2)
@skip f(2)
Solution:
using IRTools: @dynamo, xcall, IR, block, branch!, branches, empty!, delete!
# this is where we want to stop recursion
skip(f::Core.IntrinsicFunction, args...) = f(args...)
skip(f::Core.Builtin, args...) = f(args...)
@dynamo function skip(args...)
    ir = IR(args...)
    delete_line = false
    orig = nothing
    for (x,st) in ir
        isbegin = ismarkbegin(st.expr)
        isend = ismarkend(st.expr)
        if isbegin
            delete_line = true
        end
        if isbegin
            orig = block(ir,x)
        elseif isend
            dest = block(ir,x)
            if orig != dest
                empty!(branches(orig))
                branch!(orig,dest)
            end
        end
        if delete_line
            delete!(ir,x)
        end
        if isend
            delete_line = false
        end
        # this part is new
        if haskey(ir,x) && Meta.isexpr(st.expr,:call)
            ir[x] = xcall(skip, st.expr.args...)
        end
    end
    return ir
end
macro skip(ex)
    Meta.isexpr(ex, :call) || error("Input expression has to be a `:call`.")
    return esc(xcall(skip, ex.args...))   # escape so the arguments resolve in the caller's scope
end
julia> @code_ir foo(2)
1: (%1, %2)
  %3 = Main.baz(%2)
  %4 = Main.bar(%3)
  return %4

julia> @code_ir skip(foo,2)
1: (%1, %2)
  %3 = Base.getfield(%2, 1)
  %4 = Base.getfield(%2, 2)
  %5 = (Main.skip)(Main.baz, %4)
  %6 = (Main.skip)(Main.bar, %5)
  return %6

julia> foo(-2)
The input is negative.
The input is even.
-2

julia> skip(foo,-2)
-2

julia> @skip foo(-2)
-2
References
- Static matrices with @generated functions (blog post)
- OptionalArgChecks.jl
- IRTools Dynamo