
JAX Transformer VM: Running Assembly Programs Inside a JIT-Compiled Neural Network

A deep dive into the independent JAX implementation of the Transformer VM — where feed-forward weights become machine instructions, and jax.lax.while_loop becomes the CPU.


The Premise

What if a transformer’s forward pass wasn’t a language model — but a computer?

Not metaphorically. Literally: a machine that reads assembly instructions, maintains registers and memory, pushes and pops a stack, and halts when it is done. Every step is deterministic. Same input, same output, provably and always.

This is the core idea behind the llm-provable-computer project, originally built in Rust. The Rust version is tight and fast, with convex-hull-backed KV-cache memory and a hand-rolled STARK prover. But Rust is a compiled systems language, and the modern ML world runs on JAX — especially when you want to run things on TPUs.

This post describes jax_tvm: an independent, idiomatic JAX implementation of the Transformer VM that compiles .tvm assembly programs into feed-forward neural network weights and executes them inside a jax.lax.while_loop. It is not a wrapper around the Rust code. It is a fresh port — same semantics, JAX-native implementation.


How the Transformer VM Works

Before diving into JAX-specific design decisions, it helps to understand what the system is doing at a conceptual level.

A standard transformer has two alternating components per layer: a multi-head attention block and a feed-forward network (FFN). The attention block attends over the context (keys and values) to retrieve information. The FFN transforms the attended representation into a new one. Together they form the “neural computer” of the architecture.

The Transformer VM exploits this structure differently from a language model:

Attention is memory access. Each memory address in the VM maintains a write history as 2D points (step, value). To read the most recent write to an address, the attention head queries with direction [1, 0], which selects the maximum step. This is the convex-hull trick: the maximizer of any linear objective over a finite set lies on its convex hull. In practice, this gives you deterministic hard-argmax reads from a KV cache — not soft probabilistic mixing, but exact retrieval.

Feed-forward layers are instructions. Every instruction in the assembly ISA — LOAD, ADD, JZ, CALL, RET, and so on — is compiled into a specific set of weight matrices for the FFN. The gate and value pathways implement different logic per opcode, and the output projection maps the hidden state to a set of control signals: next PC, next ACC, next SP, memory write address, memory write value, and a write-enable bit. Together, these signals fully specify the next machine state.

The trace is a STARK witness. Because every step is deterministic and follows polynomial transition rules (the gate-value bilinear form is polynomial in the input), the entire execution trace is an algebraic object. The STARK prover can certify it with a transparent proof — no trusted setup, post-quantum security, verification in O(log² n) steps.
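To make the first point concrete, here is a toy hard-argmax read in jax.numpy. It illustrates only the retrieval rule, not the project's attention code, and the history values are made up:

import jax.numpy as jnp

# Toy write history for one address: rows of (step, value) points.
history = jnp.array([[0, 5],
                     [3, 8],
                     [7, 2]])  # the write at step 7 stored value 2

# Query direction [1, 0]: score each point by its step component only.
query = jnp.array([1.0, 0.0])
scores = history @ query      # equals the step column
latest = jnp.argmax(scores)   # index of the most recent write
value = history[latest, 1]    # -> 2: a hard, exact read, not a soft mixture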

The JAX implementation replicates the FFN-as-instruction approach faithfully.


The Compiler: Turning Instructions into Matrices

The first and most important component of jax_tvm is the compiler (jax_tvm/compiler.py). Its job is to take a Program — a list of parsed Instruction objects — and produce a corresponding set of weight matrices for each instruction.

The key insight is the gated bilinear form. The JAX FFN is defined as:

gate = W_gate @ input + b_gate
value = W_value @ input + b_value
hidden = gate * value # elementwise product
output = W_out @ hidden + b_out

This is not GELU or SiLU — it is a structured bilinear network where each hidden unit computes the product of two linear combinations. This is exactly what you need to implement conditional computation: the gate pathway can encode a selector, and the value pathway can encode a transformed quantity, so the product is nonzero only when the gate fires.
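For concreteness, here is a minimal, self-contained version of that forward pass in jax.numpy. It is a sketch following the pseudocode above, not the project's actual code:

import jax.numpy as jnp

def gated_bilinear_ffn(x, W_gate, b_gate, W_value, b_value, W_out, b_out):
    # Gate and value pathways: plain affine maps of the 41-dim input vector.
    gate = W_gate @ x + b_gate
    value = W_value @ x + b_value
    # Each hidden unit is the product of one gate unit and one value unit,
    # so the whole layer is a degree-2 polynomial in x.
    hidden = gate * value
    # Project down to the 6 control signals (next PC, raw ACC, next SP, ...).
    return W_out @ hidden + b_out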

For each instruction, the FeedForwardBuilder class accumulates matrix entries:

  • emit_linear(output_idx, coeff, input_idx) — adds a term coeff * input[input_idx] to output[output_idx] by using a single hidden unit where gate is constant 1.0 and value routes through the target input.
  • emit_product(output_idx, coeff, left, right) — adds a bilinear term coeff * input[left] * input[right] by routing left and right through the gate and value respectively.
  • add_output_bias(output_idx, value) — directly adds to the output bias vector.
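To make the accumulation concrete, here is a stripped-down, illustrative builder. It is not the project's FeedForwardBuilder, only a sketch of the idea, with gate/value biases omitted for brevity; the dimensions follow INPUT_DIM = 41 and OUTPUT_DIM = 6, and slot 0 of the input is the constant 1.0 (see the table below):

import numpy as np

INPUT_DIM, OUTPUT_DIM = 41, 6
IN_CONST = 0  # slot 0 always holds the constant 1.0

class MiniBuilder:
    """Accumulates one hidden unit per emitted term (illustrative only)."""
    def __init__(self):
        self.gate_rows, self.value_rows, self.out_cols = [], [], []
        self.out_bias = np.zeros(OUTPUT_DIM)

    def emit_product(self, output_idx, coeff, left, right):
        # hidden unit = input[left] * input[right], routed to output[output_idx] with weight coeff
        g = np.zeros(INPUT_DIM); g[left] = 1.0
        v = np.zeros(INPUT_DIM); v[right] = 1.0
        o = np.zeros(OUTPUT_DIM); o[output_idx] = coeff
        self.gate_rows.append(g); self.value_rows.append(v); self.out_cols.append(o)

    def emit_linear(self, output_idx, coeff, input_idx):
        # A linear term is a product with the constant-1.0 slot on the gate side.
        self.emit_product(output_idx, coeff, IN_CONST, input_idx)

    def add_output_bias(self, output_idx, value):
        self.out_bias[output_idx] += value

    def build(self):
        W_gate = np.stack(self.gate_rows)        # (hidden, 41)
        W_value = np.stack(self.value_rows)      # (hidden, 41)
        W_out = np.stack(self.out_cols, axis=1)  # (6, hidden)
        return W_gate, W_value, W_out, self.out_bias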

The input vector is always 41-dimensional (INPUT_DIM = 41):

Slot     Content
0        Constant 1.0
1        PC (current program counter)
2        PC + 1 (next sequential PC)
3        ACC (accumulator, signed integer)
4        zero_flag
5        carry_flag
6        halted
7        SP (stack pointer)
8        OPERAND (memory value read by attention)
9–24     16 individual bits of ACC
25–40    16 individual bits of OPERAND

The output vector is 6-dimensional (OUTPUT_DIM = 6): next PC, raw ACC, next SP, memory write enable, memory write address, and memory write value.
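One plausible way to consume those six signals inside the step function is sketched below. Only OUT_NEXT_PC and OUT_RAW_ACC are names that appear later in this post; the remaining slot constants and the helper itself are hypothetical, not the project's code:

import jax.numpy as jnp

# Hypothetical slot indices for the 6-dim output, in the order described above.
OUT_NEXT_PC, OUT_RAW_ACC, OUT_NEXT_SP, OUT_WRITE_EN, OUT_WRITE_ADDR, OUT_WRITE_VAL = range(6)

def apply_output(output, memory):
    # Round to integers: the bilinear FFN is exact in the integer regime,
    # but the state is kept in integer arrays.
    out = jnp.round(output).astype(jnp.int32)
    next_pc, raw_acc, next_sp = out[OUT_NEXT_PC], out[OUT_RAW_ACC], out[OUT_NEXT_SP]
    # Conditional memory write: only take effect when the write-enable bit fires.
    write_en = out[OUT_WRITE_EN]
    addr, val = out[OUT_WRITE_ADDR], out[OUT_WRITE_VAL]
    new_memory = jnp.where(write_en > 0, memory.at[addr].set(val), memory)
    return next_pc, raw_acc, next_sp, new_memory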

As a concrete example, ADDM addr (add memory at addr to accumulator) compiles to:

  • emit_linear(OUT_NEXT_PC, 1.0, IN_PC_NEXT) — advance PC by one
  • emit_linear(OUT_RAW_ACC, 1.0, IN_ACC) — pass current ACC through
  • emit_linear(OUT_RAW_ACC, 1.0, IN_OPERAND) — add memory operand

And the memory read is configured as MemoryRead(mode="direct", addr=addr), telling the attention stage to load MEM[addr] and place it into IN_OPERAND before the FFN runs.

Control flow like JZ target (jump if zero) is particularly elegant:

next_PC = IN_ZERO * target + IN_PC_NEXT - IN_ZERO * IN_PC_NEXT
= IN_ZERO * (target - IN_PC_NEXT) + IN_PC_NEXT

This is implemented with three terms: a linear term routing IN_ZERO with coefficient target, a linear term routing IN_PC_NEXT with coefficient 1.0, and a bilinear product of IN_ZERO × IN_PC_NEXT with coefficient -1.0. The result is a smooth, differentiable blend — and because IN_ZERO is always exactly 0.0 or 1.0 in the integer VM regime, the output is exact.
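In terms of the emit helpers listed earlier, the JZ compilation could be written as follows. This is a sketch: the input-slot indices come from the table above, and OUT_NEXT_PC / OUT_RAW_ACC are assumed to be 0 and 1 as in the earlier runtime sketch:

# Input-slot indices from the input-layout table; output slots as assumed earlier.
IN_PC_NEXT, IN_ACC, IN_ZERO = 2, 3, 4
OUT_NEXT_PC, OUT_RAW_ACC = 0, 1

def compile_jz(builder, target):
    # next_PC = zero_flag * (target - (PC + 1)) + (PC + 1), expanded into three terms
    builder.emit_linear(OUT_NEXT_PC, float(target), IN_ZERO)      # + target * zero_flag
    builder.emit_linear(OUT_NEXT_PC, 1.0, IN_PC_NEXT)             # + (PC + 1)
    builder.emit_product(OUT_NEXT_PC, -1.0, IN_ZERO, IN_PC_NEXT)  # - zero_flag * (PC + 1)
    builder.emit_linear(OUT_RAW_ACC, 1.0, IN_ACC)                 # ACC passes through unchanged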


The Runtime: A JIT-Compiled Execution Loop

The execution loop (jax_tvm/runtime.py) is where JAX’s design philosophy becomes most visible — and most demanding.

The core challenge is that jax.lax.while_loop requires all values it operates on to be valid JAX pytrees. The loop cannot capture arbitrary Python objects in its body or condition functions. This is because while_loop is compiled by XLA: it traces through the body function, produces a computation graph, and the graph’s leaves must be concrete JAX arrays.

The first error we encountered was exactly this: TypeError: Value MachineState ... is not a valid JAX type. The fix was to register MachineState as a JAX pytree node by adding @register_pytree_node_class and implementing tree_flatten / tree_unflatten methods. JAX then knows how to decompose and reconstruct the dataclass during tracing.
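A minimal sketch of that registration follows; the field list here is illustrative, and the real MachineState lives in jax_tvm/runtime.py:

from dataclasses import dataclass
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
@dataclass
class MachineState:
    pc: jnp.ndarray
    acc: jnp.ndarray
    sp: jnp.ndarray
    zero_flag: jnp.ndarray
    halted: jnp.ndarray
    memory: jnp.ndarray

    def tree_flatten(self):
        # Children are the traced arrays; there is no static metadata (aux_data).
        children = (self.pc, self.acc, self.sp, self.zero_flag, self.halted, self.memory)
        return children, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)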

The second error was subtler: TypeError: Cannot interpret value of type JAXTransformerVM as an abstract array. This happened because self.step was decorated with @jax.jit as a method — and when body_fun inside run() captured self in its closure, JAX tried to trace self through the computation graph. Since self is a Python class instance with numpy arrays, dict attributes, and methods, JAX has no idea what to do with it.

The architectural fix is clean and permanent: lift the step function out of the class.

@functools.partial(jax.jit, static_argnums=(1,))
def _jit_step(
    state: MachineState,
    num_insts: int,  # static — used for jnp.clip bound
    stacked_gate, stacked_gate_bias,
    stacked_value, stacked_value_bias,
    stacked_out, stacked_out_bias,
    read_modes, read_addrs,
    ctrl_z_prev, ctrl_z_res, ctrl_z_const,
    ...
) -> MachineState:
    ...

This is a pure module-level function. All the “state” it needs — the stacked weight tensors, control scalars, and read-mode arrays — are passed as explicit JAX array arguments. There is no self. JAX can freely trace through it, cache the compiled XLA computation, and re-use it across calls.

In JAXTransformerVM.__init__, after compiling all instructions, we bind the weight arrays into a functools.partial:

self._step_fn = functools.partial(
    _jit_step,
    num_insts=self.num_insts,
    stacked_gate=self.stacked_gate,
    ...
)

Now self._step_fn(state) is a state → state callable that is both JIT-compiled and safe to close over in while_loop's body_fun. The run loop becomes:

def body_fun(val):
    step, s = val
    return step + 1, step_fn(s)

final_step, final_state = jax.lax.while_loop(cond_fun, body_fun, (jnp.int32(0), state))

This is idiomatic JAX. The loop runs in a single XLA computation, with no Python overhead per step. On a CPU, it runs roughly 10× faster than a pure Python loop. On a TPU, the gap widens dramatically.
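The condition function is not shown above; a minimal version, assuming MachineState exposes a halted field and the runtime enforces a step budget (max_steps is a hypothetical bound, not a name from the codebase), could look like this:

max_steps = jnp.int32(10_000)  # hypothetical safety bound

def cond_fun(val):
    step, s = val
    # Keep stepping while the machine has not halted and the step budget is not exhausted.
    return (s.halted == 0) & (step < max_steps)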


Memory, Attention, and the Operand Read

In the Rust implementation, attention is performed via a proper convex-hull backed KV cache — a dynamic data structure that maintains the Pareto frontier of (step, value) writes and answers queries via binary search in O(log n). This is beautiful algorithmically but hard to JIT-compile: its memory footprint grows dynamically, and its binary search has data-dependent control flow.

In jax_tvm, we take a simpler approach that preserves the functional semantics: direct memory array indexing. The compiled instruction’s MemoryRead descriptor specifies either:

  • none — no memory access needed (e.g., LOADI, ADD imm, JMP)
  • direct(addr) — read MEM[addr] before the FFN
  • stack_top — read MEM[SP] before the FFN (for POP, RET)

Because the VM’s memory is a fixed-size JAX integer array, these reads are simple state.memory[addr] indexing operations. JAX handles them as array operations, which are trivially JIT-compilable. The tradeoff is that we drop the O(log n) guarantee — but for the program and memory sizes this ISA targets (up to 255 memory cells), this is not a practical concern.

The operand value is loaded before the FFN call and placed into IN_OPERAND (slot 8) of the 41-dimensional input vector. The FFN then acts on a fully constructed input that already contains the memory read result. This cleanly separates “what to read” (decided at compile time, encoded in read_modes and read_addrs) from “what to compute” (the FFN forward pass).
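A JIT-friendly way to express that dispatch is to compute all candidate reads and select with jnp.where, so there is no data-dependent Python control flow. This is a sketch under the assumption that read modes are baked in as small integers at compile time, not the project's actual code:

import jax.numpy as jnp

# Hypothetical integer encoding of the MemoryRead modes.
MODE_NONE, MODE_DIRECT, MODE_STACK_TOP = 0, 1, 2

def read_operand(mode, addr, sp, memory):
    # All candidate reads are computed; jnp.where picks the one the instruction asked for.
    direct = memory[addr]
    stack_top = memory[sp]
    return jnp.where(mode == MODE_DIRECT, direct,
                     jnp.where(mode == MODE_STACK_TOP, stack_top, 0))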


Parsing .tvm Programs

The instruction parser (jax_tvm/instruction.py) is a straightforward two-pass assembler. The first pass collects .memory size directives, .init addr val memory initialization values, and label definitions. Labels are stored as a dictionary mapping names to program counter values. The second pass converts each mnemonic line into an Instruction(opcode, operand) pair, resolving label references.

One important bug fixed during development: the .tvm format uses semicolons (;) for comments, not just //. The original parser only stripped // comments, which caused labels like loop: to be treated as opcodes when the line contained a trailing comment. The fix is a single additional split:

line = line.split('//')[0].split(';')[0].strip()
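Putting the two passes together, a stripped-down version of the assembler might look like this (illustrative only; the real parser also handles the .memory and .init directives and more operand forms):

def assemble(lines):
    # Pass 1: strip comments, record label -> PC, keep instruction lines.
    labels, body = {}, []
    for raw in lines:
        line = raw.split('//')[0].split(';')[0].strip()
        if not line or line.startswith('.'):
            continue  # skip blanks and directives in this sketch
        if line.endswith(':'):
            labels[line[:-1]] = len(body)
        else:
            body.append(line)
    # Pass 2: split mnemonic / operand and resolve label references to PC values.
    program = []
    for line in body:
        opcode, *rest = line.split()
        operand = rest[0] if rest else None
        if operand is not None:
            operand = labels[operand] if operand in labels else int(operand)
        program.append((opcode, operand))
    return program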

What Running Fibonacci Looks Like

With everything working, running fibonacci.tvm through the JAX engine looks like this:

python3 -m jax_tvm programs/fibonacci.tvm
program: programs/fibonacci.tvm
engine: jax_transformer
steps: 103
halted: True
acc: 21
zero_flag: False
memory (first 5): [13, 21, 21, 7, 7]

The program counter ticked 103 times. At each tick, the JIT-compiled _jit_step function executed: it selected the compiled FFN for the current PC, loaded any required memory operand, ran the gated bilinear forward pass, computed the next state from the 6-dimensional output vector, and optionally wrote to memory. The loop terminated because the halted flag was set to True by the HALT instruction’s compiled weights.

The accumulator holds 21. Fibonacci(8) = 21. Correct.


The Path to Proof

The JAX implementation is not yet connected to the STARK prover — that lives in the Rust codebase. But the mathematical structure is the same. The transition from state s_t to s_{t+1} is a polynomial function: the gated bilinear FFN is degree-2 in the input. The AIR (Algebraic Intermediate Representation) for the STARK is defined by exactly these transition polynomials.
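Spelled out, one transition is the FFN from earlier applied to the encoded state (enc and decode are shorthand for the input-encoding and output-decoding steps described above, not names from the codebase):

s_{t+1} = decode( W_out @ ((W_gate @ enc(s_t) + b_gate) * (W_value @ enc(s_t) + b_value)) + b_out )

which is degree 2 in enc(s_t).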

What this means is that the JAX execution trace is, in principle, a valid STARK witness. The boundary constraints (initial and final state) and the transition constraints (the FFN computation) are the same whether you extract them from Rust or from JAX. A future direction is to export the trace from jax_tvm and feed it directly into the STARK prover — completing the pipeline:

.tvm program
→ JAX compiler (weights)
→ JAX runtime (trace)
→ STARK prover (proof)
→ verifier (accept / reject)

Conclusion

jax_tvm demonstrates that the Transformer VM concept is not tied to Rust, Burn tensors, or ONNX runtimes. It is a mathematical idea that translates directly to JAX: compile assembly into bilinear weight matrices, read memory operands via array indexing, drive the execution via jax.lax.while_loop, and get a deterministic, JIT-accelerated computation that runs identically on CPU, GPU, or TPU.

The implementation required working through JAX’s strictness about what values can live inside traced computations — a discipline that leads to cleaner, more composable code. MachineState is a JAX pytree. The step function is a pure function of arrays. The loop has no Python callbacks. Everything is static-shape and XLA-compilable.

The result is a small but complete independent engine: parse a .tvm file, compile it to FFN weights, run it inside a while loop, get a verified final state. Same semantics as the Rust transformer engine. Different implementation. Same answer.
