Lab 06: Code introspection and metaprogramming
In this lab we will first look at tooling that helps you understand what Julia does under the hood, such as:
- looking at the code at different levels
- understanding what method is being called
- showing different levels of code optimization
Second, we will start playing with the metaprogramming side of Julia, mainly covering:
- how to view the abstract syntax tree (AST) of Julia code
- how to manipulate AST
These topics will be extended in the next lecture/lab, where we are going to use metaprogramming to manipulate code with macros.
We will again be getting a little ahead of ourselves, as we are going to use quite a few macros, which will also be properly explained in the next lecture. For now, the important thing to know is that a macro is just a special function that accepts Julia code as its argument and can modify it.
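To make the "macro is a function on code" idea concrete, here is a minimal sketch. The macro @reverse_args is hypothetical and exists only for illustration: it receives the unevaluated expression of a call and returns a new expression with the arguments in reverse order.

```julia
# Hypothetical macro, only for illustration: swap the order of a call's arguments.
macro reverse_args(ex)
    # `ex` is the unevaluated code, e.g. :(10 - 3) for `@reverse_args 10 - 3`
    ex isa Expr && ex.head == :call || error("@reverse_args expects a function call")
    new_ex = Expr(:call, ex.args[1], reverse(ex.args[2:end])...)
    return esc(new_ex)  # esc: resolve names in the caller's scope
end

@reverse_args 10 - 3      # the code that actually runs is 3 - 10, i.e. -7
```

The macro never sees the values 10 and 3 being subtracted; it only rearranges the expression before it is compiled.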
Quick reminder of introspection tooling
Let's start with the topic of code inspection. For example, we may ask: what happens when Julia evaluates [i for i in 1:10]?
- parsing
julia> :([i for i in 1:10]) |> dump
Expr
  head: Symbol comprehension
  args: Array{Any}((1,))
    1: Expr
      head: Symbol generator
      args: Array{Any}((2,))
        1: Symbol i
        2: Expr
          head: Symbol =
          args: Array{Any}((2,))
            1: Symbol i
            2: Expr
              head: Symbol call
              args: Array{Any}((3,))
                1: Symbol :
                2: Int64 1
                3: Int64 10
- lowering
julia> Meta.@lower [i for i in 1:10]
:($(Expr(:thunk, CodeInfo(
    @ none within `top-level scope`
1 ─ %1 = 1:10
│   %2 = Base.Generator(Base.identity, %1)
│   %3 = Base.collect(%2)
└──      return %3
))))
- typing
julia> f() = [i for i in 1:10]
f (generic function with 1 method)

julia> @code_typed f()
CodeInfo(
1 ── goto #3 if not true
2 ── nothing::Nothing
3 ┄─ %3 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Vector{Int64}, svec(Any, Int64), 0, :(:ccall), Vector{Int64}, 10, 10))::Vector{Int64}
└─── Base.arrayset(true, %3, 1, 1)::Vector{Int64}
4 ┄─ %5 = φ (#3 => 2, #13 => %23)::Int64
│    %6 = φ (#3 => 1, #13 => %15)::Int64
└─── goto #14 if not true
5 ── %8 = (%6 === 10)::Bool
└─── goto #7 if not %8
6 ── goto #8
7 ── %11 = Base.add_int(%6, 1)::Int64
└─── goto #8
8 ┄─ %13 = φ (#6 => true, #7 => false)::Bool
│    %14 = φ (#7 => %11)::Int64
│    %15 = φ (#7 => %11)::Int64
└─── goto #10 if not %13
9 ── goto #11
10 ─ goto #11
11 ┄ %19 = φ (#9 => true, #10 => false)::Bool
└─── goto #13 if not %19
12 ─ goto #14
13 ─ Base.arrayset(false, %3, %14, %5)::Vector{Int64}
│    %23 = Base.add_int(%5, 1)::Int64
└─── goto #4
14 ┄ goto #15
15 ─ goto #16
16 ─ goto #17
17 ─ return %3
) => Vector{Int64}
- LLVM code generation
julia> @code_llvm f()
;  @ REPL[1]:1 within `f`
define nonnull {}* @julia_f_7544() #0 {
top:
; ┌ @ array.jl:787 within `collect`
; │┌ @ array.jl:671 within `_array_for`
; ││┌ @ abstractarray.jl:881 within `similar` @ abstractarray.jl:882
; │││┌ @ boot.jl:486 within `Array` @ boot.jl:477
     %0 = call nonnull {}* inttoptr (i64 139758018971088 to {}* ({}*, i64)*)({}* inttoptr (i64 139757680387296 to {}*), i64 10)
; │└└└
; │ @ array.jl:792 within `collect`
; │┌ @ array.jl:817 within `collect_to_with_first!`
; ││┌ @ array.jl:969 within `setindex!`
     %1 = bitcast {}* %0 to { i8*, i64, i16, i16, i32 }*
     %2 = getelementptr inbounds { i8*, i64, i16, i16, i32 }, { i8*, i64, i16, i16, i32 }* %1, i64 0, i32 1
     %3 = load i64, i64* %2, align 8
     %.not = icmp eq i64 %3, 0
     br i1 %.not, label %oob, label %idxend

oob:                                              ; preds = %top
     %4 = alloca i64, align 8
     store i64 1, i64* %4, align 8
     call void @ijl_bounds_error_ints({}* %0, i64* nonnull %4, i64 1)
     unreachable

idxend:                                           ; preds = %top
     %5 = bitcast {}* %0 to i64**
     %6 = load i64*, i64** %5, align 8
     %7 = bitcast i64* %6 to <4 x i64>*
     store <4 x i64> <i64 1, i64 2, i64 3, i64 4>, <4 x i64>* %7, align 8
; ││└
; ││ @ array.jl:818 within `collect_to_with_first!`
; ││┌ @ array.jl:844 within `collect_to!`
; │││┌ @ array.jl:969 within `setindex!`
      %8 = getelementptr inbounds i64, i64* %6, i64 4
      %9 = bitcast i64* %8 to <4 x i64>*
      store <4 x i64> <i64 5, i64 6, i64 7, i64 8>, <4 x i64>* %9, align 8
      %10 = getelementptr inbounds i64, i64* %6, i64 8
      %11 = bitcast i64* %10 to <2 x i64>*
      store <2 x i64> <i64 9, i64 10>, <2 x i64>* %11, align 8
; └└└└
  ret {}* %0
}
- native code generation
julia> @code_native f()
        .text
        .file   "f"
        .section        .rodata.cst32,"aM",@progbits,32
        .p2align        5               # -- Begin function julia_f_7562
.LCPI0_0:
        .quad   1                       # 0x1
        .quad   2                       # 0x2
        .quad   3                       # 0x3
        .quad   4                       # 0x4
.LCPI0_1:
        .quad   5                       # 0x5
        .quad   6                       # 0x6
        .quad   7                       # 0x7
        .quad   8                       # 0x8
        .section        .rodata.cst16,"aM",@progbits,16
        .p2align        4
.LCPI0_2:
        .quad   9                       # 0x9
        .quad   10                      # 0xa
        .text
        .globl  julia_f_7562
        .p2align        4, 0x90
        .type   julia_f_7562,@function
julia_f_7562:                           # @julia_f_7562
; ┌ @ REPL[1]:1 within `f`
        .cfi_startproc
# %bb.0:                                # %top
        pushq   %rbp
        .cfi_def_cfa_offset 16
        .cfi_offset %rbp, -16
        movq    %rsp, %rbp
        .cfi_def_cfa_register %rbp
        movabsq $139757680387296, %rdi  # imm = 0x7F1BDEE4E4E0
        movabsq $139758018971088, %rax  # imm = 0x7F1BF31345D0
; │┌ @ array.jl:787 within `collect`
; ││┌ @ array.jl:671 within `_array_for`
; │││┌ @ abstractarray.jl:881 within `similar` @ abstractarray.jl:882
; ││││┌ @ boot.jl:486 within `Array` @ boot.jl:477
        movl    $10, %esi
        callq   *%rax
; ││└└└
; ││ @ array.jl:792 within `collect`
; ││┌ @ array.jl:817 within `collect_to_with_first!`
; │││┌ @ array.jl:969 within `setindex!`
        cmpq    $0, 8(%rax)
        je      .LBB0_1
# %bb.5:                                # %idxend
        movq    (%rax), %rcx
        movabsq $.LCPI0_0, %rdx
        vmovaps (%rdx), %ymm0
        vmovups %ymm0, (%rcx)
        movabsq $.LCPI0_1, %rdx
; │││└
; │││ @ array.jl:818 within `collect_to_with_first!`
; │││┌ @ array.jl:844 within `collect_to!`
; ││││┌ @ array.jl:969 within `setindex!`
        vmovaps (%rdx), %ymm0
        vmovups %ymm0, 32(%rcx)
        movabsq $.LCPI0_2, %rdx
        vmovaps (%rdx), %xmm0
        vmovups %xmm0, 64(%rcx)
; │└└└└
        movq    %rbp, %rsp
        popq    %rbp
        .cfi_def_cfa %rsp, 8
        vzeroupper
        retq
.LBB0_1:                                # %oob
; │┌ @ array.jl:792 within `collect`
; ││┌ @ array.jl:817 within `collect_to_with_first!`
; │││┌ @ array.jl:969 within `setindex!`
        .cfi_def_cfa %rbp, 16
        movq    %rsp, %rsi
        movl    $16, %ecx
        subq    %rcx, %rsi
        cmpq    %rsp, %rsi
        jge     .LBB0_4
.LBB0_3:                                # %oob
                                        # =>This Inner Loop Header: Depth=1
        xorq    $0, (%rsp)
        subq    $4096, %rsp             # imm = 0x1000
        cmpq    %rsp, %rsi
        jl      .LBB0_3
.LBB0_4:                                # %oob
        movq    %rsi, %rsp
        movq    $1, (%rsi)
        movabsq $ijl_bounds_error_ints, %rcx
        movl    $1, %edx
        movq    %rax, %rdi
        callq   *%rcx
.Lfunc_end0:
        .size   julia_f_7562, .Lfunc_end0-julia_f_7562
        .cfi_endproc
; └└└└
                                        # -- End function
        .type   .L_j_const1,@object     # @_j_const1
        .section        .rodata.cst8,"aM",@progbits,8
        .p2align        3
.L_j_const1:
        .quad   1                       # 0x1
        .size   .L_j_const1, 8
        .section        ".note.GNU-stack","",@progbits
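The same stages can also be reached programmatically, which is handy when you want to search or compare the output instead of reading it. A minimal sketch, assuming a recent Julia 1.x, where code_typed lives in Base, code_llvm and code_native live in InteractiveUtils, and both of the latter accept an IO as their first argument (so sprint can capture their output as a String):

```julia
using InteractiveUtils   # code_llvm, code_native (loaded automatically in the REPL)

f() = [i for i in 1:10]

ex = Meta.parse("[i for i in 1:10]")          # parsing: String -> Expr
lowered = Meta.lower(@__MODULE__, ex)         # lowering: Expr -> Expr(:thunk, CodeInfo)
typed, rettype = only(code_typed(f, ()))      # type inference: CodeInfo => inferred return type
llvm_ir = sprint(code_llvm, f, ())            # LLVM IR captured as a String
asm = sprint(code_native, f, ())              # native assembly captured as a String

println(ex.head, " | ", lowered.head, " | ", rettype)
```

The variable names (typed, llvm_ir, asm) are ours; only the functions called are Julia's.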
Let's see how these tools can help us understand some of Julia's internals, using examples from previous labs and lectures.
Understanding the runtime dispatch and type instabilities
We will start with a question: can we spot any internal difference between type-stable and type-unstable code?
Inspect the following two functions using @code_lowered, @code_typed, @code_llvm and @code_native.
x = rand(10^5)
function explicit_len(x)
length(x)
end
function implicit_len()
length(x)
end
For now, do not try to understand the details; focus on the overall differences, such as the length of the code.
If the output of the method introspection tools is too long, you can use a general way of redirecting the standard output stdout to a file:
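open("./llvm_fun.ll", "w") do file
    original_stdout = stdout
    redirect_stdout(file)          # everything printed to stdout now goes to the file
    try
        @code_llvm implicit_len()  # the introspection call whose output we want to keep
    finally
        redirect_stdout(original_stdout)  # restore stdout even if the call above throws
    end
end
In case of @code_llvm and @code_native there is a simpler option that works out of the box: the underlying functions code_llvm and code_native accept an IO stream as their first argument (see ?code_llvm and ?code_native). If you don't mind adding dependencies, there is also the @capture_out macro from Suppressor.jl.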
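A minimal sketch of the IO-argument route for the functions from this exercise; the file names llvm_fun.ll and native_fun.s are arbitrary choices:

```julia
using InteractiveUtils   # provides code_llvm and code_native

x = rand(10^5)
implicit_len() = length(x)   # type-unstable: reads the non-const global x

# code_llvm/code_native take the target IO as first argument, so stdout is never touched
open("llvm_fun.ll", "w") do io
    code_llvm(io, implicit_len, ())      # () is the tuple of argument types
end
open("native_fun.s", "w") do io
    code_native(io, implicit_len, ())
end
```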
Solution:
@code_warntype explicit_len(x)
@code_warntype implicit_len()
@code_typed explicit_len(x)
@code_typed implicit_len()
@code_llvm explicit_len(x)
@code_llvm implicit_len()
@code_native explicit_len(x)
@code_native implicit_len()
In this case, the generated code for such a simple operation is much longer in the type-unstable case, which results in longer run times. However, in the next example we will see that longer code is not always a bad thing.
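Besides reading the printed code, the inferred return types make the difference explicit. A minimal sketch using Base.return_types; the const-global variant implicit_len_const is a hypothetical addition for comparison, not part of the exercise:

```julia
x = rand(10^5)                 # non-const global, as in the exercise
explicit_len(x) = length(x)    # x is an argument, so its type is known
implicit_len() = length(x)     # x is a global whose type may change at any time

const xc = rand(10^5)          # hypothetical const global, for comparison
implicit_len_const() = length(xc)

Base.return_types(explicit_len, (Vector{Float64},))   # [Int64]
Base.return_types(implicit_len, ())                    # [Any]
Base.return_types(implicit_len_const, ())              # [Int64]
```

The Any in the second case is what forces the dynamic dispatch that makes the unstable code longer.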
Loop unrolling
In some cases the compiler uses the loop unrolling[1] optimization to speed up loops at the expense of binary size. This optimization removes the loop-control instructions and rewrites the loop into a repeated sequence of independent statements.
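To make the idea concrete, here is a hand-written sketch of what unrolling by a factor of 4 means. The compiler does this on the LLVM level; the functions sum_rolled and sum_unrolled4 are hypothetical and only mimic the shape of the result, assuming a 1-based vector:

```julia
# Plain loop: one loop-control check per element.
function sum_rolled(v)
    s = zero(eltype(v))
    for i in eachindex(v)
        s += v[i]
    end
    return s
end

# The same loop unrolled by hand by a factor of 4:
# one loop-control check per four elements, plus a remainder loop.
function sum_unrolled4(v)
    s = zero(eltype(v))
    n = length(v)
    i = 1
    while i + 3 <= n
        s += v[i]; s += v[i+1]; s += v[i+2]; s += v[i+3]
        i += 4
    end
    while i <= n          # leftover elements when n is not a multiple of 4
        s += v[i]
        i += 1
    end
    return s
end

v = collect(1:10)
sum_rolled(v) == sum_unrolled4(v)   # true, both give 55
```

When the length is known at compile time (as for a tuple), the compiler can go further and remove the loop entirely.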
Inspect under what conditions the compiler unrolls the for loop in the polynomial function from the last lab.
function polynomial(a, x)
accumulator = a[end] * one(x)
for i in length(a)-1:-1:1
accumulator = accumulator * x + a[i]
end
accumulator
end
Compare the speed of execution with and without loop unrolling.
HINTS:
- this kind of optimization happens at a lower level than the intermediate representation
- loop unrolling is possible when the compiler knows the length of the input
Solution:
using BenchmarkTools
a = Tuple(ones(20)) # tuple has known size
ac = collect(a)
x = 2.0
@code_lowered polynomial(a,x) # cannot be seen here as optimizations are not applied
@code_typed polynomial(a,x) # loop unrolling is not part of type inference optimization

julia> @code_llvm polynomial(a,x)
;  @ lab.md:113 within `polynomial`
define double @julia_polynomial_7606([20 x double]* nocapture noundef nonnull readonly align 8 dereferenceable(160) %0, double %1) #0 {
pass.18:
;  @ lab.md:114 within `polynomial`
; ┌ @ tuple.jl:29 within `getindex`
   %2 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 19
; └
; ┌ @ float.jl:410 within `*`
   %3 = load double, double* %2, align 8
; └
;  @ lab.md:116 within `polynomial`
; ┌ @ float.jl:410 within `*`
   %4 = fmul double %3, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %5 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 18
; └
; ┌ @ float.jl:408 within `+`
   %6 = load double, double* %5, align 8
   %7 = fadd double %4, %6
; └
; ┌ @ float.jl:410 within `*`
   %8 = fmul double %7, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %9 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 17
; └
; ┌ @ float.jl:408 within `+`
   %10 = load double, double* %9, align 8
   %11 = fadd double %8, %10
; └
; ┌ @ float.jl:410 within `*`
   %12 = fmul double %11, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %13 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 16
; └
; ┌ @ float.jl:408 within `+`
   %14 = load double, double* %13, align 8
   %15 = fadd double %12, %14
; └
; ┌ @ float.jl:410 within `*`
   %16 = fmul double %15, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %17 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 15
; └
; ┌ @ float.jl:408 within `+`
   %18 = load double, double* %17, align 8
   %19 = fadd double %16, %18
; └
; ┌ @ float.jl:410 within `*`
   %20 = fmul double %19, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %21 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 14
; └
; ┌ @ float.jl:408 within `+`
   %22 = load double, double* %21, align 8
   %23 = fadd double %20, %22
; └
; ┌ @ float.jl:410 within `*`
   %24 = fmul double %23, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %25 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 13
; └
; ┌ @ float.jl:408 within `+`
   %26 = load double, double* %25, align 8
   %27 = fadd double %24, %26
; └
; ┌ @ float.jl:410 within `*`
   %28 = fmul double %27, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %29 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 12
; └
; ┌ @ float.jl:408 within `+`
   %30 = load double, double* %29, align 8
   %31 = fadd double %28, %30
; └
; ┌ @ float.jl:410 within `*`
   %32 = fmul double %31, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %33 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 11
; └
; ┌ @ float.jl:408 within `+`
   %34 = load double, double* %33, align 8
   %35 = fadd double %32, %34
; └
; ┌ @ float.jl:410 within `*`
   %36 = fmul double %35, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %37 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 10
; └
; ┌ @ float.jl:408 within `+`
   %38 = load double, double* %37, align 8
   %39 = fadd double %36, %38
; └
; ┌ @ float.jl:410 within `*`
   %40 = fmul double %39, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %41 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 9
; └
; ┌ @ float.jl:408 within `+`
   %42 = load double, double* %41, align 8
   %43 = fadd double %40, %42
; └
; ┌ @ float.jl:410 within `*`
   %44 = fmul double %43, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %45 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 8
; └
; ┌ @ float.jl:408 within `+`
   %46 = load double, double* %45, align 8
   %47 = fadd double %44, %46
; └
; ┌ @ float.jl:410 within `*`
   %48 = fmul double %47, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %49 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 7
; └
; ┌ @ float.jl:408 within `+`
   %50 = load double, double* %49, align 8
   %51 = fadd double %48, %50
; └
; ┌ @ float.jl:410 within `*`
   %52 = fmul double %51, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %53 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 6
; └
; ┌ @ float.jl:408 within `+`
   %54 = load double, double* %53, align 8
   %55 = fadd double %52, %54
; └
; ┌ @ float.jl:410 within `*`
   %56 = fmul double %55, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %57 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 5
; └
; ┌ @ float.jl:408 within `+`
   %58 = load double, double* %57, align 8
   %59 = fadd double %56, %58
; └
; ┌ @ float.jl:410 within `*`
   %60 = fmul double %59, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %61 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 4
; └
; ┌ @ float.jl:408 within `+`
   %62 = load double, double* %61, align 8
   %63 = fadd double %60, %62
; └
; ┌ @ float.jl:410 within `*`
   %64 = fmul double %63, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %65 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 3
; └
; ┌ @ float.jl:408 within `+`
   %66 = load double, double* %65, align 8
   %67 = fadd double %64, %66
; └
; ┌ @ float.jl:410 within `*`
   %68 = fmul double %67, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %69 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 2
; └
; ┌ @ float.jl:408 within `+`
   %70 = load double, double* %69, align 8
   %71 = fadd double %68, %70
; └
; ┌ @ float.jl:410 within `*`
   %72 = fmul double %71, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %73 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 1
; └
; ┌ @ float.jl:408 within `+`
   %74 = load double, double* %73, align 8
   %75 = fadd double %72, %74
; └
; ┌ @ float.jl:410 within `*`
   %76 = fmul double %75, %1
; └
; ┌ @ tuple.jl:29 within `getindex`
   %77 = getelementptr inbounds [20 x double], [20 x double]* %0, i64 0, i64 0
; └
; ┌ @ float.jl:408 within `+`
   %78 = load double, double* %77, align 8
   %79 = fadd double %76, %78
; └
;  @ lab.md:118 within `polynomial`
  ret double %79
}

julia> @code_llvm polynomial(ac,x)
;  @ lab.md:113 within `polynomial`
define double @julia_polynomial_7608({}* noundef nonnull align 16 dereferenceable(40) %0, double %1) #0 {
top:
;  @ lab.md:114 within `polynomial`
; ┌ @ abstractarray.jl:419 within `lastindex`
; │┌ @ abstractarray.jl:382 within `eachindex`
; ││┌ @ abstractarray.jl:133 within `axes1`
; │││┌ @ abstractarray.jl:98 within `axes`
; ││││┌ @ array.jl:149 within `size`
       %2 = bitcast {}* %0 to { i8*, i64, i16, i16, i32 }*
       %3 = getelementptr inbounds { i8*, i64, i16, i16, i32 }, { i8*, i64, i16, i16, i32 }* %2, i64 0, i32 1
       %4 = load i64, i64* %3, align 8
; └└└└└
; ┌ @ essentials.jl:13 within `getindex`
   %5 = add nsw i64 %4, -1
   %.not19 = icmp eq i64 %4, 0
   br i1 %.not19, label %oob, label %idxend

L23:                                              ; preds = %idxend7, %L23.preheader
   %value_phi3 = phi i64 [ %6, %idxend7 ], [ %5, %L23.preheader ]
   %value_phi5 = phi double [ %22, %idxend7 ], [ %12, %L23.preheader ]
; └
;  @ lab.md:116 within `polynomial`
; ┌ @ essentials.jl:13 within `getindex`
   %6 = add nsw i64 %value_phi3, -1
   %7 = icmp ult i64 %6, %16
   br i1 %7, label %idxend7, label %oob6

L40:                                              ; preds = %idxend7, %idxend
   %value_phi11 = phi double [ %12, %idxend ], [ %22, %idxend7 ]
; └
;  @ lab.md:118 within `polynomial`
  ret double %value_phi11

oob:                                              ; preds = %top
;  @ lab.md:114 within `polynomial`
; ┌ @ essentials.jl:13 within `getindex`
   %8 = alloca i64, align 8
   store i64 0, i64* %8, align 8
   call void @ijl_bounds_error_ints({}* %0, i64* nonnull %8, i64 1)
   unreachable

idxend:                                           ; preds = %top
   %9 = bitcast {}* %0 to double**
   %10 = load double*, double** %9, align 8
   %11 = getelementptr inbounds double, double* %10, i64 %5
   %12 = load double, double* %11, align 8
; └
;  @ lab.md:115 within `polynomial`
; ┌ @ range.jl:22 within `Colon`
; │┌ @ range.jl:24 within `_colon`
; ││┌ @ range.jl:373 within `StepRange` @ range.jl:320
     %13 = call i64 @j_steprange_last_7610(i64 signext %5, i64 signext -1, i64 signext 1) #0
; └└└
; ┌ @ range.jl:887 within `iterate`
; │┌ @ range.jl:659 within `isempty`
; ││┌ @ operators.jl:269 within `!=`
; │││┌ @ promotion.jl:499 within `==`
       %14 = icmp eq i64 %5, %13
; ││└└
; ││┌ @ operators.jl:369 within `>`
; │││┌ @ int.jl:83 within `<`
       %.not = icmp sgt i64 %4, %13
; │└└└
   %15 = or i1 %14, %.not
; └
  br i1 %15, label %L23.preheader, label %L40

L23.preheader:                                    ; preds = %idxend
  %16 = load i64, i64* %3, align 8
  %17 = load double*, double** %9, align 8
;  @ lab.md:116 within `polynomial`
; ┌ @ essentials.jl:13 within `getindex`
  br label %L23

oob6:                                             ; preds = %L23
   %18 = alloca i64, align 8
   store i64 %value_phi3, i64* %18, align 8
   call void @ijl_bounds_error_ints({}* %0, i64* nonnull %18, i64 1)
   unreachable

idxend7:                                          ; preds = %L23
; └
; ┌ @ float.jl:410 within `*`
   %19 = fmul double %value_phi5, %1
; └
; ┌ @ essentials.jl:13 within `getindex`
   %20 = getelementptr inbounds double, double* %17, i64 %6
   %21 = load double, double* %20, align 8
; └
; ┌ @ float.jl:408 within `+`
   %22 = fadd double %19, %21
; └
;  @ lab.md:117 within `polynomial`
; ┌ @ range.jl:891 within `iterate`
; │┌ @ promotion.jl:499 within `==`
    %.not15 = icmp eq i64 %value_phi3, %13
; └└
  br i1 %.not15, label %L40, label %L23
}
The unrolled tuple version gives more than a 2x speedup:
julia> @btime polynomial($a,$x)
  15.832 ns (0 allocations: 0 bytes)
1.048575e6

julia> @btime polynomial($ac,$x)
  37.362 ns (0 allocations: 0 bytes)
1.048575e6
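Instead of scrolling through the whole IR, you can capture it as a String and look for traces of a loop. The sketch below relies on a rough heuristic of ours, not on any official API: a loop that survives optimization shows up in LLVM IR as phi nodes (as in the vector version above), while a fully unrolled body is straight-line code. The exact output depends on the Julia and LLVM versions.

```julia
using InteractiveUtils   # code_llvm

function polynomial(a, x)
    accumulator = a[end] * one(x)
    for i in length(a)-1:-1:1
        accumulator = accumulator * x + a[i]
    end
    accumulator
end

a = Tuple(ones(20)); ac = collect(a); x = 2.0

ir_tuple  = sprint(code_llvm, polynomial, (typeof(a), typeof(x)))
ir_vector = sprint(code_llvm, polynomial, (typeof(ac), typeof(x)))
for (name, ir) in ("tuple" => ir_tuple, "vector" => ir_vector)
    # heuristic: " phi " marks an SSA merge point, typical for a remaining loop
    println(name, ": ", count(==('\n'), ir), " lines, contains phi: ", occursin(" phi ", ir))
end
```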
Recursion inlining depth
Inlining[2] is another compiler optimization that allows us to speed up the code by avoiding function calls. Where applicable, the compiler can replace f(args) directly with the body of f, thus removing the need to modify the stack to transfer control flow to a different place. This is yet another optimization that may improve speed at the expense of binary size.
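The effect is easy to observe in typed code. A minimal sketch with hypothetical functions: @inline and @noinline are real Julia macros that hint the compiler, and has_invoke is our own helper, which assumes that a statically dispatched call that was not inlined appears in the optimized typed IR as an Expr with head :invoke.

```julia
@inline   add1_inline(x)   = x + 1
@noinline add1_noinline(x) = x + 1

g_inline(x)   = 2 * add1_inline(x)
g_noinline(x) = 2 * add1_noinline(x)

# Hypothetical helper: does the optimized typed IR still contain a (non-inlined) call?
function has_invoke(f, types)
    ci, _ = only(code_typed(f, types))
    return any(st -> st isa Expr && st.head === :invoke, ci.code)
end

has_invoke(g_inline, (Int,)), has_invoke(g_noinline, (Int,))   # (false, true)
```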
Rewrite the polynomial function from the last lab using recursion and find the number of coefficients at which inlining of the recursive calls stops occurring.
function polynomial(a, x)
accumulator = a[end] * one(x)
for i in length(a)-1:-1:1
accumulator = accumulator * x + a[i]
end
accumulator
end
The operator ... serves two purposes inside function calls [3][4]:
- combines multiple arguments into one
function printargs(args...)
println(typeof(args))
for (i, arg) in enumerate(args)
println("Arg #$i = $arg")
end
end
printargs(1, 2, 3)
- splits one argument into many different arguments
function threeargs(a, b, c)
println("a = $a::$(typeof(a))")
println("b = $b::$(typeof(b))")
println("c = $c::$(typeof(c))")
end
threeargs([1,2,3]...) # or with a variable threeargs(x...)
HINTS:
- define two methods _polynomial!(ac, x, a...) and _polynomial!(ac, x, a) for the case of ≥2 coefficients and for the last coefficient
- use splatting together with range indexing a[1:end-1]...
- the correctness can be checked using the built-in evalpoly
- recall that this kind of optimization happens around the type inference stage
- use a container of known length to store the coefficients
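Before writing the recursive version, it helps to check the coefficient convention and the tuple operations from the hints. A minimal sketch; polynomial_loop is just the loop version from above under a hypothetical new name, so that the snippet stays self-contained:

```julia
# Loop version from above, renamed so it does not clash with the recursive solution
function polynomial_loop(a, x)
    accumulator = a[end] * one(x)
    for i in length(a)-1:-1:1
        accumulator = accumulator * x + a[i]
    end
    accumulator
end

a = (1, 2, 3)              # coefficients in ascending powers: 1 + 2x + 3x^2
polynomial_loop(a, 2)      # 17
evalpoly(2, a)             # 17; note the argument order evalpoly(x, coefficients)
a[1:end-1]                 # (1, 2): range indexing a tuple returns a tuple
```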
Solution:
_polynomial!(ac, x, a...) = _polynomial!(x * ac + a[end], x, a[1:end-1]...)
_polynomial!(ac, x, a) = x * ac + a
polynomial(a, x) = _polynomial!(a[end] * one(x), x, a[1:end-1]...)
# the coefficients have to be a tuple
a = Tuple(ones(Int, 21)) # everything less than 22 gets inlined
x = 2
polynomial(a,x) == evalpoly(x,a) # compare with built-in function
# @code_llvm polynomial(a,x) # seen here too, but code_typed is a better option
@code_lowered polynomial(a,x) # cannot be seen here as optimizations are not applied

julia> @code_typed polynomial(a,x)
CodeInfo(
1 ─ %1  = Base.getfield(a, 21, true)::Int64
│   %2  = Base.mul_int(%1, 1)::Int64
│   %3  = Core.getfield(a, 1)::Int64
│   %4  = Core.getfield(a, 2)::Int64
│   %5  = Core.getfield(a, 3)::Int64
│   %6  = Core.getfield(a, 4)::Int64
│   %7  = Core.getfield(a, 5)::Int64
│   %8  = Core.getfield(a, 6)::Int64
│   %9  = Core.getfield(a, 7)::Int64
│   %10 = Core.getfield(a, 8)::Int64
│   %11 = Core.getfield(a, 9)::Int64
│   %12 = Core.getfield(a, 10)::Int64
│   %13 = Core.getfield(a, 11)::Int64
│   %14 = Core.getfield(a, 12)::Int64
│   %15 = Core.getfield(a, 13)::Int64
│   %16 = Core.getfield(a, 14)::Int64
│   %17 = Core.getfield(a, 15)::Int64
│   %18 = Core.getfield(a, 16)::Int64
│   %19 = Core.getfield(a, 17)::Int64
│   %20 = Core.getfield(a, 18)::Int64
│   %21 = Core.getfield(a, 19)::Int64
│   %22 = Core.getfield(a, 20)::Int64
│   %23 = Base.mul_int(x, %2)::Int64
│   %24 = Base.add_int(%23, %22)::Int64
│   %25 = Base.mul_int(x, %24)::Int64
│   %26 = Base.add_int(%25, %21)::Int64
│   %27 = Base.mul_int(x, %26)::Int64
│   %28 = Base.add_int(%27, %20)::Int64
│   %29 = Base.mul_int(x, %28)::Int64
│   %30 = Base.add_int(%29, %19)::Int64
│   %31 = Base.mul_int(x, %30)::Int64
│   %32 = Base.add_int(%31, %18)::Int64
│   %33 = Base.mul_int(x, %32)::Int64
│   %34 = Base.add_int(%33, %17)::Int64
│   %35 = Base.mul_int(x, %34)::Int64
│   %36 = Base.add_int(%35, %16)::Int64
│   %37 = Base.mul_int(x, %36)::Int64
│   %38 = Base.add_int(%37, %15)::Int64
│   %39 = Base.mul_int(x, %38)::Int64
│   %40 = Base.add_int(%39, %14)::Int64
│   %41 = Base.mul_int(x, %40)::Int64
│   %42 = Base.add_int(%41, %13)::Int64
│   %43 = Base.mul_int(x, %42)::Int64
│   %44 = Base.add_int(%43, %12)::Int64
│   %45 = Base.mul_int(x, %44)::Int64
│   %46 = Base.add_int(%45, %11)::Int64
│   %47 = Base.mul_int(x, %46)::Int64
│   %48 = Base.add_int(%47, %10)::Int64
│   %49 = Base.mul_int(x, %48)::Int64
│   %50 = Base.add_int(%49, %9)::Int64
│   %51 = Base.mul_int(x, %50)::Int64
│   %52 = Base.add_int(%51, %8)::Int64
│   %53 = Base.mul_int(x, %52)::Int64
│   %54 = Base.add_int(%53, %7)::Int64
│   %55 = Base.mul_int(x, %54)::Int64
│   %56 = Base.add_int(%55, %6)::Int64
│   %57 = Base.mul_int(x, %56)::Int64
│   %58 = Base.add_int(%57, %5)::Int64
│   %59 = Base.mul_int(x, %58)::Int64
│   %60 = Base.add_int(%59, %4)::Int64
│   %61 = Base.mul_int(x, %60)::Int64
│   %62 = Base.add_int(%61, %3)::Int64
└──       return %62
) => Int64
AST manipulation: The first steps to metaprogramming
Julia is a so-called homoiconic language: it represents its own code as a data structure of the language itself, so programs can inspect and manipulate code. This capability is inspired by years of development in other languages such as Lisp, Clojure or Prolog.
There are two easy ways to extract/construct the code structure [5]:
- parsing code stored in a string with the internal Meta.parse
julia> code_parse = Meta.parse("x = 2") # for single line expressions (additional spaces are ignored)
:(x = 2)

julia> code_parse_block = Meta.parse("""
       begin
           x = 2
           y = 3
           x + y
       end
       """) # for multiline expressions
quote
    #= none:2 =#
    x = 2
    #= none:3 =#
    y = 3
    #= none:4 =#
    x + y
end
- constructing an expression using quote ... end or the simple :() syntax
julia> code_expr = :(x = 2) # for single line expressions (additional spaces are ignored)
:(x = 2)

julia> code_expr_block = quote
           x = 2
           y = 3
           x + y
       end # for multiline expressions
quote
    #= REPL[2]:2 =#
    x = 2
    #= REPL[2]:3 =#
    y = 3
    #= REPL[2]:4 =#
    x + y
end
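Both routes produce the same kind of object, which a quick comparison confirms. The sketch also shows interpolation with $, a standard part of the quoting syntax that splices a value into a quoted expression:

```julia
code_parse = Meta.parse("x = 2")
code_expr  = :(x = 2)
code_parse == code_expr          # true: both routes give an equal Expr

val = 5
:(x = $val) == :(x = 5)          # true: $ inserts the value of val into the expression
```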
The results can be stored in variables, which we can inspect further.
julia> typeof(code_parse)
Expr

julia> dump(code_parse)
Expr
  head: Symbol =
  args: Array{Any}((2,))
    1: Symbol x
    2: Int64 2

julia> typeof(code_parse_block)
Expr

julia> dump(code_parse_block)
Expr
  head: Symbol block
  args: Array{Any}((6,))
    1: LineNumberNode
      line: Int64 2
      file: Symbol none
    2: Expr
      head: Symbol =
      args: Array{Any}((2,))
        1: Symbol x
        2: Int64 2
    3: LineNumberNode
      line: Int64 3
      file: Symbol none
    4: Expr
      head: Symbol =
      args: Array{Any}((2,))
        1: Symbol y
        2: Int64 3
    5: LineNumberNode
      line: Int64 4
      file: Symbol none
    6: Expr
      head: Symbol call
      args: Array{Any}((3,))
        1: Symbol +
        2: Symbol x
        3: Symbol y
Both the single-line and the multiline expression are of type Expr, with the fields head and args. Notice that the Expr type is recursive: its args can store other expressions, resulting in a tree structure, the abstract syntax tree (AST). It can be visualized, for example, with the combination of the GraphRecipes and Plots packages.
using GraphRecipes, Plots  # GraphRecipes provides plot recipes for Julia ASTs
plot(code_expr_block, fontsize=12, shorten=0.01, axis_buffer=0.15, nodeshape=:rect)
This recursive structure has some major performance drawbacks: the args field is a Vector{Any}, therefore manipulating the AST at this level won't be type stable. The building blocks of expressions are Symbols and literal values (numbers).
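Because an Expr is an ordinary object, it can also be built by hand from these building blocks and walked with recursion and multiple dispatch. A minimal sketch; count_symbols is a hypothetical helper for illustration:

```julia
# Build :(x + 2 * y) by hand from Symbols, a literal and a nested Expr
ex = Expr(:call, :+, :x, Expr(:call, :*, 2, :y))
ex == :(x + 2 * y)        # true
ex.head                   # :call
ex.args isa Vector{Any}   # true: the untyped container mentioned above

# Recursive walk over the tree: count the Symbols (operators included)
count_symbols(s::Symbol) = 1
count_symbols(e::Expr) = sum(count_symbols, e.args; init = 0)
count_symbols(_) = 0      # literals such as 2, LineNumberNodes, ...
count_symbols(ex)         # 4: +, x, *, y
```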
A possible nuisance of working with multiline expressions is the presence of LineNumberNodes, which can be removed with the Base.remove_linenums! function.
julia> Base.remove_linenums!(code_parse_block)
quote
    x = 2
    y = 3
    x + y
end
Parsed expressions can be evaluated using the eval function.
julia> eval(code_parse) # evaluation of :(x = 2)
2

julia> x # should be defined
2
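Note that eval always works in the global scope of the current module, which is why x above ends up as a global variable. A minimal sketch with a hypothetical function showing that local variables are invisible to it:

```julia
z = 1                        # global variable
function eval_in_function()
    z = 100                  # local variable of this function
    return eval(:z)          # eval looks z up in the module's global scope
end
eval_in_function()           # 1, not 100
```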
Before doing anything more fancy let's start with some simple manipulation of ASTs.
- Define a variable code to be the result of parsing the string "j = i^2".
- Copy code into a variable code2. Modify it to replace the power 2 with a power 3. Make sure that the original code variable is not also modified.
- Copy code2 to a variable code3. Replace i with i + 1 in code3.
- Define a variable i with the value 4. Evaluate the different code expressions using the eval function and check the value of the variable j.
Solution:
julia> code = Meta.parse("j = i^2")
:(j = i ^ 2)

julia> code2 = copy(code)
:(j = i ^ 2)

julia> code2.args[2].args[3] = 3
3

julia> code3 = copy(code2)
:(j = i ^ 3)

julia> code3.args[2].args[2] = :(i + 1)
:(i + 1)

julia> i = 4
4

julia> eval(code), eval(code2), eval(code3)
(16, 64, 125)
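The copy is what keeps the original intact: plain assignment only creates a second name for the same mutable Expr. A minimal sketch contrasting the two (the variable names are ours):

```julia
orig1 = Meta.parse("j = i^2")
alias = orig1                     # no copy: both names refer to the same Expr object
alias.args[2].args[3] = 3
orig1                             # :(j = i ^ 3), the original changed as well

orig2 = Meta.parse("j = i^2")
code2 = copy(orig2)               # copy(::Expr) also copies the nested Exprs in args
code2.args[2].args[3] = 3
orig2                             # still :(j = i ^ 2)
```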
Following up on the more general substitution of variables in an expression from the lecture, let's see how the situation becomes more complicated when we are dealing with strings instead of a parsed AST.
using Test  # provides the @test macro used below
replace_i(s::Symbol) = s == :i ? :k : s
replace_i(e::Expr) = Expr(e.head, map(replace_i, e.args)...)
replace_i(u) = u
Given the function replace_i, which replaces the variable i with k in an expression like the following
julia> ex = :(i + i*i + y*i - sin(z))
:((i + i * i + y * i) - sin(z))

julia> @test replace_i(ex) == :(k + k*k + y*k - sin(z))
Test Passed
write a different function sreplace_i(s) that does the same thing, but manipulates a string instead of a parsed expression (AST), such as
julia> s = string(ex)
"(i + i * i + y * i) - sin(z)"
HINTS:
- Use Meta.parse in combination with replace_i ONLY for checking correctness.
- You can use the replace function in combination with regular expressions.
- Think of some corner cases that the method may not handle properly.
Solution:
The naive solution
julia> sreplace_i(s) = replace(s, 'i' => 'k')
sreplace_i (generic function with 1 method)

julia> @test Meta.parse(sreplace_i(s)) == replace_i(Meta.parse(s))
Test Failed at REPL[2]:1
  Expression: Meta.parse(sreplace_i(s)) == replace_i(Meta.parse(s))
   Evaluated: (k + k * k + y * k) - skn(z) == (k + k * k + y * k) - sin(z)
ERROR: There was an error during testing
does not work even in this simple case, because it also replaces the "i" inside the sin(z) expression. We can play with regular expressions to obtain something more robust
julia> sreplace_i(s) = replace(s, r"([^\w]|\b)i(?=[^\w]|\z)" => s"\1k")
sreplace_i (generic function with 1 method)

julia> @test Meta.parse(sreplace_i(s)) == replace_i(Meta.parse(s))
Test Passed
however, the code is now harder to read. Thus it is preferable to use the parsed AST when manipulating Julia code.
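One corner case the regex still gets wrong is a string literal: the text "i" inside a literal is not a variable, but a string-based replacement cannot tell the difference. A minimal sketch reusing the definitions from above (ex2 and s2 are our own example inputs):

```julia
replace_i(s::Symbol) = s == :i ? :k : s
replace_i(e::Expr) = Expr(e.head, map(replace_i, e.args)...)
replace_i(u) = u
sreplace_i(s) = replace(s, r"([^\w]|\b)i(?=[^\w]|\z)" => s"\1k")

ex2 = :(println("i = ", i))     # the literal "i = " contains no variable
s2 = string(ex2)                 # "println(\"i = \", i)"
Meta.parse(sreplace_i(s2))       # :(println("k = ", k)), the literal was changed too
replace_i(ex2)                   # :(println("i = ", k)), only the Symbol i is replaced
```

On the AST the literal is a String, not a Symbol, so the dispatch-based version leaves it alone for free.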
If the exercises so far did not feel very useful, let's focus on one that is similar to a part of the IntervalArithmetic.jl package.
Write a function wrap!(ex::Expr) that wraps literal values (numbers) with a call to f(). You can test it on the following example:
f = x -> convert(Float64, x)
ex = :(x*x + 2*y*x + y*y) # original expression
rex = :(x*x + f(2)*y*x + y*y) # result expression
HINTS:
- use recursion and multiple dispatch
- dispatch on ::Number to detect numbers in an expression
- for testing purposes, create a copy of ex before mutating
Solution:
julia> function wrap!(ex::Expr)
           args = ex.args
           for i in 1:length(args)
               args[i] = wrap!(args[i])
           end
           return ex
       end
wrap! (generic function with 1 method)

julia> wrap!(ex::Number) = Expr(:call, :f, ex)
wrap! (generic function with 2 methods)

julia> wrap!(ex) = ex
wrap! (generic function with 3 methods)

julia> ext, x, y = copy(ex), 2, 3
(:(x * x + 2 * y * x + y * y), 2, 3)

julia> @test wrap!(ex) == :(x*x + f(2)*y*x + y*y)
Test Passed

julia> eval(ext)
25

julia> eval(ex)
25.0
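If you prefer not to mutate the input (and thus not to need the copy), the same dispatch structure can rebuild the tree instead, in the style of replace_i. A minimal sketch; the name wrap is hypothetical:

```julia
f = x -> convert(Float64, x)

# Non-mutating variant: rebuild every Expr instead of overwriting ex.args in place
wrap(ex::Expr) = Expr(ex.head, map(wrap, ex.args)...)
wrap(ex::Number) = Expr(:call, :f, ex)
wrap(ex) = ex                    # Symbols, LineNumberNodes, ... are kept as they are

ex = :(x*x + 2*y*x + y*y)
rex = wrap(ex)
rex == :(x*x + f(2)*y*x + y*y)   # true
ex == :(x*x + 2*y*x + y*y)       # true: the input was not modified

x, y = 2, 3
eval(rex)                        # 25.0
```

The trade-off is an extra allocation per node in exchange for never touching the caller's expression.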
This kind of manipulation is at the core of some packages, such as the aforementioned IntervalArithmetic.jl, where every number is replaced with a narrow interval in order to obtain bounds on the result of a computation.
Resources
- Julia's manual on metaprogramming
- David P. Sanders' workshop @ JuliaCon 2021
- Steven Johnson's keynote talk @ JuliaCon 2019
- Andy Ferris's workshop @ JuliaCon 2018
- From Macros to DSL by John Myles White
- Notes on JuliaCompilerPlugin
- [1] https://en.wikipedia.org/wiki/Loop_unrolling
- [2] https://en.wikipedia.org/wiki/Inline_expansion
- [3] https://docs.julialang.org/en/v1/manual/faq/#What-does-the-...-operator-do?
- [4] https://docs.julialang.org/en/v1/manual/functions/#Varargs-Functions
- [5] Once you understand the recursive structure of expressions, the AST can be constructed manually like any other type.