Direct tail-call threading in Rust

2021-05-06

With Clang gaining guaranteed tail calls, Rust will hopefully catch up soon and finally get guaranteed tail-call elimination (TCE) as well. While the become keyword is not usable at the moment, I figured it's still worth looking into the potentially fastest implementation technique for bytecode interpreters.

The benchmarks are performed on a Ryzen 2700X.

A simple interpreter

To keep things simple, the interpreter only has to be able to execute this pseudocode:

    let a = 0
    let b = 1
    let c = 100_000_000

loop:
    a += b
    if a != c
        goto loop

    print a

Implementation in C

While this blog post is about Rust, the first implementation is easier to write in C, since C permits the required unsafe operations by default. It also allows using clang-13, which has the musttail attribute.

Using a program counter

A naive implementation using a program counter may look like this:

#include <stdio.h>

/* Declares a bytecode handler with the calling convention shared by all
 * ops: the instruction stream, the register file, and the current pc. */
#define DEF_OP(_name_) static void _name_( \
        const instruction_t *restrict instructions, \
        int *restrict memory, \
        unsigned int pc \
)

/* Dispatches the handler of the instruction at `pc` as a guaranteed tail
 * call (clang's `musttail`), so the dispatch chain runs in constant stack
 * space. Ops must update `pc` before invoking this. */
#define NEXT_OP do { \
    __attribute__((musttail)) \
    return instructions[pc].handler(instructions, memory, pc); \
} while(0)

/* Binds `_instr_` to the instruction currently being executed. */
#define GET_OP(_instr_) \
    const instruction_t *_instr_ = &instructions[pc];

/* A decoded bytecode instruction. The handler pointer comes first so that
 * dispatch is a single load from the instruction's base address. */
typedef struct instruction {
    /* Op implementation; tail-called by NEXT_OP. */
    void (*handler)(const struct instruction *restrict, int *restrict, unsigned int);
    union {
        struct {
            unsigned char c; /* third register operand (add) */
            unsigned char d; /* unused in this interpreter */
        };
        unsigned int pc;     /* jump target index (jmpnif) */
        int imm;             /* immediate value (load) */
    };
    unsigned char a;         /* first register operand */
    unsigned char b;         /* second register operand */
} instruction_t;

DEF_OP(load) {
    GET_OP(i);
    pc++;
    memory[i->a] = i->imm;
    NEXT_OP;
}

DEF_OP(add) {
    GET_OP(i);
    pc++;
    memory[i->a] = memory[i->b] + memory[i->c];
    NEXT_OP;
}

DEF_OP(jmpnif) { 
    GET_OP(i);
    if (memory[i->a] != memory[i->b]) {
        pc = i->pc;
    } else {
        pc++;
    }
    NEXT_OP;
}

DEF_OP(print) {
    GET_OP(i);
    pc++;
    printf("%d\n", memory[i->a]);
    NEXT_OP;
}

/* ret: returning normally (no NEXT_OP) unwinds the tail-call chain and
 * ends execution, handing control back to main. */
DEF_OP(ret) {
}

int main() {
    /* 256 int registers, zero-initialized. */
    int memory[256] = {0};
    /* Hand-assembled bytecode for the pseudocode program:
     * r0 = 0; r1 = 1; r2 = 100_000_000; loop { r0 += r1 } while r0 != r2. */
    instruction_t instructions[] = {
        // Init loop
        [0] = { .handler = load, .a = 0, .imm = 0 },
        [1] = { .handler = load, .a = 1, .imm = 1 },
        [2] = { .handler = load, .a = 2, .imm = 100 * 1000 * 1000 },

        // Loop
        [3] = { .handler = add, .a = 0, .b = 0, .c = 1 },
        [4] = { .handler = jmpnif, .a = 0, .b = 2, .pc = 3 },

        // Finish
        [5] = { .handler = print, .a = 0 },
        [6] = { .handler = ret },
    };
    /* Kick off dispatch at instruction 0; ret returns control here. */
    instructions[0].handler(instructions, memory, 0);

    return 0;
}

However, this does not generate efficient assembly: inspecting the add and jmpnif functions with objdump -SC shows this:

0000000000401210 <add>:
  401210:       89 d0                   mov    %edx,%eax
  401212:       83 c2 01                add    $0x1,%edx
  401215:       48 c1 e0 04             shl    $0x4,%rax
  401219:       44 0f b6 44 07 0d       movzbl 0xd(%rdi,%rax,1),%r8d
  40121f:       0f b6 4c 07 08          movzbl 0x8(%rdi,%rax,1),%ecx
  401224:       8b 0c 8e                mov    (%rsi,%rcx,4),%ecx
  401227:       42 03 0c 86             add    (%rsi,%r8,4),%ecx
  40122b:       0f b6 44 07 0c          movzbl 0xc(%rdi,%rax,1),%eax
  401230:       89 0c 86                mov    %ecx,(%rsi,%rax,4)
  401233:       48 89 d0                mov    %rdx,%rax
  401236:       48 c1 e0 04             shl    $0x4,%rax
  40123a:       48 8b 04 07             mov    (%rdi,%rax,1),%rax
  40123e:       ff e0                   jmpq   *%rax

0000000000401240 <jmpnif>:
  401240:       89 d0                   mov    %edx,%eax
  401242:       48 c1 e0 04             shl    $0x4,%rax
  401246:       0f b6 4c 07 0c          movzbl 0xc(%rdi,%rax,1),%ecx
  40124b:       44 8b 04 8e             mov    (%rsi,%rcx,4),%r8d
  40124f:       0f b6 4c 07 0d          movzbl 0xd(%rdi,%rax,1),%ecx
  401254:       44 3b 04 8e             cmp    (%rsi,%rcx,4),%r8d
  401258:       75 0f                   jne    401269 <jmpnif+0x29>
  40125a:       83 c2 01                add    $0x1,%edx
  40125d:       89 d0                   mov    %edx,%eax
  40125f:       48 c1 e0 04             shl    $0x4,%rax
  401263:       48 8b 04 07             mov    (%rdi,%rax,1),%rax
  401267:       ff e0                   jmpq   *%rax
  401269:       8b 54 07 08             mov    0x8(%rdi,%rax,1),%edx
  40126d:       89 d0                   mov    %edx,%eax
  40126f:       48 c1 e0 04             shl    $0x4,%rax
  401273:       48 8b 04 07             mov    (%rdi,%rax,1),%rax
  401277:       ff e0                   jmpq   *%rax
  401279:       0f 1f 80 00 00 00 00    nopl   0x0(%rax)

That's an awful lot of work for what should be two relatively simple operations.

Using direct tail call threading

An implementation in C using what I dub direct tail-call threading could look like this:

#include <stdio.h>

/* Declares a bytecode handler. Unlike the pc version, each op receives a
 * direct pointer to its own instruction instead of an index. */
#define DEF_OP(_name_) static void _name_( \
        const instruction_t *restrict instruction, \
        int *restrict memory \
)

/* Tail-calls the handler of the current `instruction` pointer; ops advance
 * or redirect `instruction` themselves before invoking this. */
#define NEXT_OP do { \
    __attribute__((musttail)) \
    return instruction->handler(instruction, memory); \
} while(0)

/* A decoded bytecode instruction for the threaded version. Jump targets
 * are stored as direct instruction pointers rather than pc indices. */
typedef struct instruction {
    /* Op implementation; tail-called by NEXT_OP. */
    void (*handler)(const struct instruction *restrict, int *restrict);
    union {
        unsigned char c;                     /* third register operand (add) */
        struct instruction *jmp_instruction; /* branch target (jmpnif) */
        int imm;                             /* immediate value (load) */
    };
    unsigned char a;                         /* first register operand */
    unsigned char b;                         /* second register operand */
} instruction_t;

DEF_OP(load) {
    memory[instruction->a] = instruction->imm;
    instruction++;
    NEXT_OP;
}

DEF_OP(add) {
    memory[instruction->a] = memory[instruction->b] + memory[instruction->c];
    instruction++;
    NEXT_OP;
}

DEF_OP(jmpnif) { 
    if (memory[instruction->a] != memory[instruction->b]) {
        instruction = instruction->jmp_instruction;
    } else {
        instruction++;
    }
    NEXT_OP;
}

/* print: write memory[a] to stdout followed by a newline. */
DEF_OP(print) {
    const int value = memory[instruction->a];
    printf("%d\n", value);
    instruction += 1;
    NEXT_OP;
}

/* ret: returning normally (no NEXT_OP) unwinds the tail-call chain and
 * ends execution. */
DEF_OP(ret) {
}

int main() {
    /* 256 int registers, zero-initialized. */
    int memory[256] = {0};
    /* Same program as the pc version, but the jump stores a direct pointer
     * into the array being initialized (valid C: the array's address is
     * known at this point). */
    instruction_t instructions[] = {
        // Init loop
        [0] = { .handler = load, .a = 0, .imm = 0 },
        [1] = { .handler = load, .a = 1, .imm = 1 },
        [2] = { .handler = load, .a = 2, .imm = 100 * 1000 * 1000 },

        // Loop
        [3] = { .handler = add, .a = 0, .b = 0, .c = 1 },
        [4] = { .handler = jmpnif, .a = 0, .b = 2, .jmp_instruction = &instructions[3] },

        // Finish
        [5] = { .handler = print, .a = 0 },
        [6] = { .handler = ret },
    };
    /* Kick off dispatch at instruction 0; ret returns control here. */
    instructions[0].handler(instructions, memory);

    return 0;
}

The generated assembly is much more efficient (and looks ideal at first sight):

0000000000401270 <add>:
  401270:       0f b6 47 11             movzbl 0x11(%rdi),%eax
  401274:       0f b6 4f 08             movzbl 0x8(%rdi),%ecx
  401278:       8b 0c 8e                mov    (%rsi,%rcx,4),%ecx
  40127b:       03 0c 86                add    (%rsi,%rax,4),%ecx
  40127e:       0f b6 47 10             movzbl 0x10(%rdi),%eax
  401282:       89 0c 86                mov    %ecx,(%rsi,%rax,4)
  401285:       48 8b 47 18             mov    0x18(%rdi),%rax
  401289:       48 83 c7 18             add    $0x18,%rdi
  40128d:       ff e0                   jmpq   *%rax
  40128f:       90                      nop

0000000000401290 <jmpnif>:
  401290:       0f b6 47 10             movzbl 0x10(%rdi),%eax
  401294:       8b 04 86                mov    (%rsi,%rax,4),%eax
  401297:       0f b6 4f 11             movzbl 0x11(%rdi),%ecx
  40129b:       3b 04 8e                cmp    (%rsi,%rcx,4),%eax
  40129e:       75 09                   jne    4012a9 <jmpnif+0x19>
  4012a0:       48 83 c7 18             add    $0x18,%rdi
  4012a4:       48 8b 07                mov    (%rdi),%rax
  4012a7:       ff e0                   jmpq   *%rax
  4012a9:       48 8b 7f 08             mov    0x8(%rdi),%rdi
  4012ad:       48 8b 07                mov    (%rdi),%rax
  4012b0:       ff e0                   jmpq   *%rax
  4012b2:       66 2e 0f 1f 84 00 00    nopw   %cs:0x0(%rax,%rax,1)
  4012b9:       00 00 00
  4012bc:       0f 1f 40 00             nopl   0x0(%rax)

The difference is clearly visible when running both versions with perf stat:

$ perf stat ./with_direct_threading
100000000

 Performance counter stats for './build/main':

            229.35 msec task-clock:u              #    0.991 CPUs utilized          
                 0      context-switches:u        #    0.000 K/sec                  
                 0      cpu-migrations:u          #    0.000 K/sec                  
                50      page-faults:u             #    0.218 K/sec                  
       957,568,185      cycles:u                  #    4.175 GHz                      (82.76%)
           575,004      stalled-cycles-frontend:u #    0.06% frontend cycles idle     (84.37%)
       538,644,702      stalled-cycles-backend:u  #   56.25% backend cycles idle      (80.98%)
     1,691,695,590      instructions:u            #    1.77  insn per cycle         
                                                  #    0.32  stalled cycles per insn  (84.43%)
       298,083,096      branches:u                # 1299.694 M/sec                    (82.70%)
               303      branch-misses:u           #    0.00% of all branches          (84.75%)

       0.231515959 seconds time elapsed

       0.225605000 seconds user
       0.004256000 seconds sys

$ perf stat ./with_program_counter
100000000

 Performance counter stats for './build/with_program_counter':

            263.86 msec task-clock:u              #    0.998 CPUs utilized          
                 0      context-switches:u        #    0.000 K/sec                  
                 0      cpu-migrations:u          #    0.000 K/sec                  
                52      page-faults:u             #    0.197 K/sec                  
     1,089,623,673      cycles:u                  #    4.129 GHz                      (83.33%)
           223,565      stalled-cycles-frontend:u #    0.02% frontend cycles idle     (83.33%)
       613,199,854      stalled-cycles-backend:u  #   56.28% backend cycles idle      (83.34%)
     2,491,824,728      instructions:u            #    2.29  insn per cycle         
                                                  #    0.25  stalled cycles per insn  (83.33%)
       299,860,727      branches:u                # 1136.422 M/sec                    (83.34%)
               822      branch-misses:u           #    0.00% of all branches          (83.33%)

       0.264392820 seconds time elapsed

       0.264333000 seconds user
       0.000000000 seconds sys

Porting the C code to Rust

While Rust doesn't guarantee TCE, rustc is still able to optimize simple functions to use tail calls.

The Rust equivalent of the threaded C implementation looks like this:

/// Declares a bytecode handler `unsafe fn` with the shared signature:
/// a raw pointer to the current instruction plus the register file.
/// Mirrors the C `DEF_OP` macro.
macro_rules! def_op {
    ($name:ident($instruction:ident, $memory:ident) $code:block) => {
        unsafe fn $name($instruction: *const Instruction, $memory: &mut [i32]) {
            $code
        }
    };
}

/// Register-file access without bounds checks (`get_unchecked`), to match
/// the C version; enabled with `--cfg 'feature="unchecked_memory"'`.
/// Note the macro performs the `as usize` index cast itself.
#[cfg(feature = "unchecked_memory")]
macro_rules! get_mem {
    ($mem:ident[$r:expr]) => {
        $mem.get_unchecked($r as usize)
    };
    (mut $mem:ident[$r:expr]) => {
        $mem.get_unchecked_mut($r as usize)
    };
}

/// Bounds-checked fallback used when the `unchecked_memory` feature is off.
/// Same interface as the unchecked variant, including the `as usize` cast.
#[cfg(not(feature = "unchecked_memory"))]
macro_rules! get_mem {
    ($mem:ident[$r:expr]) => {
        &$mem[$r as usize]
    };
    (mut $mem:ident[$r:expr]) => {
        &mut $mem[$r as usize]
    };
}

/// Dispatches the handler of `instruction`; the Rust equivalent of the C
/// `NEXT_OP` macro. `#[inline(always)]` keeps the indirect call in tail
/// position inside each op so rustc's optimizer can emit a tail call
/// (not guaranteed without `become`).
///
/// ## SAFETY
///
/// `instruction` is a valid pointer.
#[inline(always)]
unsafe fn next_op(instruction: *const Instruction, memory: &mut [i32]) {
    ((*instruction).handler)(instruction, memory)
}

/// Per-op payload; mirrors the anonymous union in the C version.
union ExtraParam {
    c: u8,                   // third register operand (add)
    jmp: *const Instruction, // branch target (jmpnif)
    imm: i32,                // immediate value (load)
}

/// A decoded bytecode instruction; the handler pointer is first so that
/// dispatch is a single load from the instruction's base address.
struct Instruction {
    handler: unsafe fn(*const Self, &mut [i32]),
    param: ExtraParam,
    a: u8, // first register operand
    b: u8, // second register operand
}

// load: memory[a] = imm, then dispatch the following instruction.
def_op!(load(instruction, memory) {
    *get_mem!(mut memory[(*instruction).a]) = (*instruction).param.imm;
    next_op(instruction.offset(1), memory)
});

// add: memory[a] = memory[b] + memory[c].
def_op!(add(instruction, memory) {
    *get_mem!(mut memory[(*instruction).a]) =
        *get_mem!(memory[(*instruction).b]) +
        *get_mem!(memory[(*instruction).param.c]);
    next_op(instruction.offset(1), memory)
});

// jmpnif: follow the embedded target pointer when memory[a] != memory[b],
// otherwise fall through to the next instruction.
def_op!(jmpnif(instruction, memory) {
    let instruction = if get_mem!(memory[(*instruction).a]) != get_mem!(memory[(*instruction).b]) {
        (*instruction).param.jmp
    } else {
        instruction.offset(1)
    };
    next_op(instruction, memory)
});

// print: write memory[a] to stdout followed by a newline.
def_op!(print(instruction, memory) {
    // `get_mem!` already applies `as usize` to its index; passing the raw
    // `u8` here keeps this consistent with the other handlers (the original
    // expanded to a redundant `as usize as usize` double cast).
    println!("{}", get_mem!(memory[(*instruction).a]));
    next_op(instruction.offset(1), memory)
});

// ret: returning without calling next_op ends the dispatch chain.
def_op!(ret(_instruction, _memory) {});

fn main() {
    type I = Instruction;
    type P = ExtraParam;
    // 256 i32 registers, zero-initialized.
    let mut memory = [0; 256];
    let mut instructions = [
        // Init loop
        I { handler: load, a: 0, b: 0, param: P { imm: 0 }},
        I { handler: load, a: 1, b: 0, param: P { imm: 1 }},
        I { handler: load, a: 2, b: 0, param: P { imm: 100 * 1000 * 1000 }},

        // Loop
        I { handler: add, a: 0, b: 0, param: P { c: 1 }},
        // Placeholder target; patched below because (unlike the C version)
        // the array cannot reference itself inside its own initializer.
        I { handler: jmpnif, a: 0, b: 2, param: P { jmp: core::ptr::null() }},

        // Finish
        I { handler: print, a: 0, b: 0, param: P { c: 0 }},
        I { handler: ret, a: 0, b: 0, param: P { c: 0 }},
    ];
    // Point the jump at the start of the loop (instruction 3).
    instructions[4].param.jmp = &instructions[3];
    // SAFETY: the bytecode is valid.
    unsafe {
        (instructions[0].handler)(&instructions as *const _, &mut memory);
    }
}

Building with rustc -C opt-level=3 --cfg 'feature="unchecked_memory"' main.rs and inspecting the final assembly, we can see that the assembly is almost equivalent:

0000000000006a20 <main::add>:
    6a20:       0f b6 47 11             movzbl 0x11(%rdi),%eax
    6a24:       0f b6 4f 08             movzbl 0x8(%rdi),%ecx
    6a28:       8b 0c 8e                mov    (%rsi,%rcx,4),%ecx
    6a2b:       03 0c 86                add    (%rsi,%rax,4),%ecx
    6a2e:       0f b6 47 10             movzbl 0x10(%rdi),%eax
    6a32:       89 0c 86                mov    %ecx,(%rsi,%rax,4)
    6a35:       48 8b 47 18             mov    0x18(%rdi),%rax
    6a39:       48 83 c7 18             add    $0x18,%rdi
    6a3d:       ff e0                   jmpq   *%rax
    6a3f:       90                      nop

0000000000006a40 <main::jmpnif>:
    6a40:       0f b6 47 10             movzbl 0x10(%rdi),%eax
    6a44:       0f b6 4f 11             movzbl 0x11(%rdi),%ecx
    6a48:       8b 04 86                mov    (%rsi,%rax,4),%eax
    6a4b:       3b 04 8e                cmp    (%rsi,%rcx,4),%eax
    6a4e:       75 09                   jne    6a59 <main::jmpnif+0x19>
    6a50:       48 83 c7 18             add    $0x18,%rdi
    6a54:       48 8b 07                mov    (%rdi),%rax
    6a57:       ff e0                   jmpq   *%rax
    6a59:       48 8b 7f 08             mov    0x8(%rdi),%rdi
    6a5d:       48 8b 07                mov    (%rdi),%rax
    6a60:       ff e0                   jmpq   *%rax
    6a62:       66 2e 0f 1f 84 00 00    nopw   %cs:0x0(%rax,%rax,1)
    6a69:       00 00 00 
    6a6c:       0f 1f 40 00             nopl   0x0(%rax)

The only difference between the Rust and C version is two swapped instructions in jmpnif:

0000000000006a40 <main::jmpnif>:
    6a40:       0f b6 47 10             movzbl 0x10(%rdi),%eax
    6a44:       0f b6 4f 11             movzbl 0x11(%rdi),%ecx
    6a48:       8b 04 86                mov    (%rsi,%rax,4),%eax
    6a4b:       3b 04 8e                cmp    (%rsi,%rcx,4),%eax

0000000000401290 <jmpnif>:
  401290:       0f b6 47 10             movzbl 0x10(%rdi),%eax
  401294:       8b 04 86                mov    (%rsi,%rax,4),%eax
  401297:       0f b6 4f 11             movzbl 0x11(%rdi),%ecx
  40129b:       3b 04 8e                cmp    (%rsi,%rcx,4),%eax

Rust's version is slightly faster since it works better with pipelining and doesn't stall the frontend as much:

$ perf stat ./with_direct_threading
100000000

 Performance counter stats for './build/main':

            229.35 msec task-clock:u              #    0.991 CPUs utilized          
                 0      context-switches:u        #    0.000 K/sec                  
                 0      cpu-migrations:u          #    0.000 K/sec                  
                50      page-faults:u             #    0.218 K/sec                  
       957,568,185      cycles:u                  #    4.175 GHz                      (82.76%)
           575,004      stalled-cycles-frontend:u #    0.06% frontend cycles idle     (84.37%)
       538,644,702      stalled-cycles-backend:u  #   56.25% backend cycles idle      (80.98%)
     1,691,695,590      instructions:u            #    1.77  insn per cycle         
                                                  #    0.32  stalled cycles per insn  (84.43%)
       298,083,096      branches:u                # 1299.694 M/sec                    (82.70%)
               303      branch-misses:u           #    0.00% of all branches          (84.75%)

       0.231515959 seconds time elapsed

       0.225605000 seconds user
       0.004256000 seconds sys

$ perf stat ./build/threaded_tco_rust
100000000

 Performance counter stats for './build/threaded_tco_rust':

            223.33 msec task-clock:u              #    0.998 CPUs utilized
                 0      context-switches:u        #    0.000 K/sec
                 0      cpu-migrations:u          #    0.000 K/sec
                81      page-faults:u             #    0.363 K/sec
       940,918,264      cycles:u                  #    4.213 GHz                      (82.10%)
            50,970      stalled-cycles-frontend:u #    0.01% frontend cycles idle     (83.13%)
       632,340,038      stalled-cycles-backend:u  #   67.20% backend cycles idle      (83.89%)
     1,697,590,700      instructions:u            #    1.80  insn per cycle
                                                  #    0.37  stalled cycles per insn  (83.89%)
       299,671,562      branches:u                # 1341.808 M/sec                    (83.89%)
               504      branch-misses:u           #    0.00% of all branches          (83.10%)

       0.223767884 seconds time elapsed

       0.223755000 seconds user
       0.000000000 seconds sys

Conclusion

With TCE and direct threading it is possible to create very efficient interpreters in Rust with simple bytecode instructions. Whether the generated assembly will be as ideal for more complex bytecode instructions (which may suffer from many pushes & pops) remains to be seen.