A deep dive into the independent JAX implementation of the Transformer VM — where feed-forward weights become machine instructions, and jax.lax.while_loop becomes the CPU.
The Premise
What if a transformer’s forward pass wasn’t a language model — but a computer?
Not metaphorically. Literally: a machine that reads assembly instructions, maintains registers and memory, pushes and pops a stack, and halts when it is done. Every step is deterministic. Same input, same output, provably and always.
This is the core idea behind the llm-provable-computer project, originally built in Rust. The Rust version is tight and fast, with convex-hull-backed KV-cache memory and a hand-rolled STARK prover. But Rust is a compiled systems language, and the modern ML world runs on JAX — especially when you want to run things on TPUs.
This post describes jax_tvm: an independent, idiomatic JAX implementation of the Transformer VM that compiles .tvm assembly programs into feed-forward neural network weights and executes them inside a jax.lax.while_loop. It is not a wrapper around the Rust code. It is a fresh port — same semantics, JAX-native implementation.
How the Transformer VM Works
Before diving into JAX-specific design decisions, it helps to understand what the system is doing at a conceptual level.
A standard transformer has two alternating components per layer: a multi-head attention block and a feed-forward network (FFN). The attention block attends over the context (keys and values) to retrieve information. The FFN transforms the attended representation into a new one. Together they form the “neural computer” of the architecture.
The Transformer VM exploits this structure differently from a language model:
Attention is memory access. Each memory address in the VM maintains a write history as 2D points (step, value). To read the most recent write to an address, the attention head queries with direction [1, 0], which selects the maximum step. This is the convex-hull trick: the maximizer of any linear objective over a finite set lies on its convex hull. In practice, this gives you deterministic hard-argmax reads from a KV cache — not soft probabilistic mixing, but exact retrieval.
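In JAX terms, that read reduces to a hard argmax over the write history. A minimal sketch of the idea (array names are illustrative, not the project's API):

```python
import jax.numpy as jnp

def read_latest(steps, values, valid):
    # query direction [1, 0]: maximize the step coordinate over the valid
    # (step, value) writes, i.e. retrieve the most recent write exactly
    masked = jnp.where(valid, steps, -1)
    idx = jnp.argmax(masked)  # deterministic hard argmax, not a softmax
    return values[idx]

steps = jnp.array([0, 3, 7, 2])       # when each write happened
values = jnp.array([10, 20, 30, 40])  # what was written
valid = jnp.array([True, True, True, True])
print(int(read_latest(steps, values, valid)))  # 30: the step-7 write wins
```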
Feed-forward layers are instructions. Every instruction in the assembly ISA — LOAD, ADD, JZ, CALL, RET, and so on — is compiled into a specific set of weight matrices for the FFN. The gate and value pathways implement different logic per opcode, and the output projection maps the hidden state to a set of control signals: next PC, next ACC, next SP, memory write address, memory write value, and a write-enable bit. Together, these signals fully specify the next machine state.
The trace is a STARK witness. Because every step is deterministic and follows polynomial transition rules (the gate-value bilinear form is polynomial in the input), the entire execution trace is an algebraic object. The STARK prover can certify it with a transparent proof — no trusted setup, post-quantum security, verification in O(log² n) steps.
The JAX implementation replicates the FFN-as-instruction approach faithfully.
The Compiler: Turning Instructions into Matrices
The first and most important component of jax_tvm is the compiler (jax_tvm/compiler.py). Its job is to take a Program — a list of parsed Instruction objects — and produce a corresponding set of weight matrices for each instruction.
The key insight is the gated bilinear form. The JAX FFN is defined as:
```python
gate = W_gate @ input + b_gate
value = W_value @ input + b_value
hidden = gate * value  # elementwise product
output = W_out @ hidden + b_out
```
This is not GELU or SiLU — it is a structured bilinear network where each hidden unit computes the product of two linear combinations. This is exactly what you need to implement conditional computation: the gate pathway can encode a selector, and the value pathway can encode a transformed quantity, so the product is nonzero only when the gate fires.
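A toy numerical check of the conditional-computation claim, in plain NumPy with a single hidden unit (dimensions from the post, weights invented for illustration):

```python
import numpy as np

INPUT_DIM, OUTPUT_DIM = 41, 6  # dimensions from the post
HIDDEN = 1                     # one hidden unit for the illustration

# the gate selects input[1], the value routes input[3], so this hidden
# unit contributes input[1] * input[3] and is silent when input[1] == 0
W_gate = np.zeros((HIDDEN, INPUT_DIM)); W_gate[0, 1] = 1.0
W_value = np.zeros((HIDDEN, INPUT_DIM)); W_value[0, 3] = 1.0
W_out = np.zeros((OUTPUT_DIM, HIDDEN)); W_out[0, 0] = 1.0
b_gate = np.zeros(HIDDEN); b_value = np.zeros(HIDDEN); b_out = np.zeros(OUTPUT_DIM)

x = np.zeros(INPUT_DIM); x[1], x[3] = 5.0, 7.0
gate = W_gate @ x + b_gate
value = W_value @ x + b_value
output = W_out @ (gate * value) + b_out
print(output[0])  # 35.0 = x[1] * x[3]; 0.0 whenever the gate input is 0
```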
For each instruction, the FeedForwardBuilder class accumulates matrix entries:
- `emit_linear(output_idx, coeff, input_idx)` — adds a term `coeff * input[input_idx]` to `output[output_idx]` by using a single hidden unit where the gate is the constant `1.0` and the value routes through the target input.
- `emit_product(output_idx, coeff, left, right)` — adds a bilinear term `coeff * input[left] * input[right]` by routing left and right through the gate and value respectively.
- `add_output_bias(output_idx, value)` — directly adds to the output bias vector.
The input vector is always 41-dimensional (INPUT_DIM = 41):
| Slot | Content |
|---|---|
| 0 | Constant 1.0 |
| 1 | PC (current program counter) |
| 2 | PC + 1 (next sequential PC) |
| 3 | ACC (accumulator, signed integer) |
| 4 | zero_flag |
| 5 | carry_flag |
| 6 | halted |
| 7 | SP (stack pointer) |
| 8 | OPERAND (memory value read by attention) |
| 9–24 | 16 individual bits of ACC |
| 25–40 | 16 individual bits of OPERAND |
The output vector is 6-dimensional (OUTPUT_DIM = 6): next PC, raw ACC, next SP, memory write enable, memory write address, and memory write value.
As a concrete example, ADDM addr (add memory at addr to accumulator) compiles to:
- `emit_linear(OUT_NEXT_PC, 1.0, IN_PC_NEXT)` — advance PC by one
- `emit_linear(OUT_RAW_ACC, 1.0, IN_ACC)` — pass current ACC through
- `emit_linear(OUT_RAW_ACC, 1.0, IN_OPERAND)` — add memory operand
And the memory read is configured as MemoryRead(mode="direct", addr=addr), telling the attention stage to load MEM[addr] and place it into IN_OPERAND before the FFN runs.
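Putting the pieces together, here is a hypothetical miniature of the builder (method names from the post; the internal representation and output-slot order are assumptions) that compiles ADDM and runs one forward pass:

```python
import numpy as np

# Hypothetical miniature of FeedForwardBuilder: method names come from the
# post, the internal representation is guessed for illustration.
class FeedForwardBuilder:
    def __init__(self, input_dim=41, output_dim=6):
        self.input_dim, self.output_dim = input_dim, output_dim
        self.gates, self.values, self.outs = [], [], []
        self.b_out = np.zeros(output_dim)

    def _unit(self, gate_idx, gate_coeff, value_idx, value_coeff, output_idx):
        # each emitted term becomes one hidden unit of the gated bilinear FFN
        g = np.zeros(self.input_dim); g[gate_idx] = gate_coeff
        v = np.zeros(self.input_dim); v[value_idx] = value_coeff
        o = np.zeros(self.output_dim); o[output_idx] = 1.0
        self.gates.append(g); self.values.append(v); self.outs.append(o)

    def emit_linear(self, output_idx, coeff, input_idx):
        self._unit(0, 1.0, input_idx, coeff, output_idx)  # gate = constant slot

    def emit_product(self, output_idx, coeff, left, right):
        self._unit(left, 1.0, right, coeff, output_idx)

    def forward(self, x):
        W_gate, W_value = np.stack(self.gates), np.stack(self.values)
        W_out = np.stack(self.outs, axis=1)
        return W_out @ ((W_gate @ x) * (W_value @ x)) + self.b_out

# input slots as numbered in the post; output slot order assumed from its listing
IN_CONST, IN_PC_NEXT, IN_ACC, IN_OPERAND = 0, 2, 3, 8
OUT_NEXT_PC, OUT_RAW_ACC = 0, 1

b = FeedForwardBuilder()
b.emit_linear(OUT_NEXT_PC, 1.0, IN_PC_NEXT)   # advance PC
b.emit_linear(OUT_RAW_ACC, 1.0, IN_ACC)       # pass ACC through
b.emit_linear(OUT_RAW_ACC, 1.0, IN_OPERAND)   # add the memory operand

x = np.zeros(41)
x[IN_CONST], x[IN_PC_NEXT], x[IN_ACC], x[IN_OPERAND] = 1.0, 5.0, 10.0, 7.0
out = b.forward(x)
print(out[OUT_NEXT_PC], out[OUT_RAW_ACC])  # 5.0 17.0
```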
Control flow like JZ target (jump if zero) is particularly elegant:
```
next_PC = IN_ZERO * target + IN_PC_NEXT - IN_ZERO * IN_PC_NEXT
        = IN_ZERO * (target - IN_PC_NEXT) + IN_PC_NEXT
```
This is implemented with three terms: a linear term routing IN_ZERO with coefficient target, a linear term routing IN_PC_NEXT with coefficient 1.0, and a bilinear product of IN_ZERO × IN_PC_NEXT with coefficient -1.0. The result is a smooth, differentiable blend, and because IN_ZERO is always exactly 0.0 or 1.0 in the integer VM regime, the output is exact.
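A quick self-contained check that the blend gives exact discrete behaviour when the zero flag is restricted to 0.0 or 1.0:

```python
# the JZ blend from above: smooth in zero, exact at the endpoints
def next_pc(zero, pc_next, target):
    return zero * (target - pc_next) + pc_next

print(next_pc(1.0, 6.0, 12.0))  # 12.0: zero flag set, jump to target
print(next_pc(0.0, 6.0, 12.0))  # 6.0: zero flag clear, fall through
```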
The Runtime: A JIT-Compiled Execution Loop
The execution loop (jax_tvm/runtime.py) is where JAX’s design philosophy becomes most visible — and most demanding.
The core challenge is that jax.lax.while_loop requires all values it operates on to be valid JAX pytrees. The loop cannot capture arbitrary Python objects in its body or condition functions. This is because while_loop is compiled by XLA: it traces through the body function, produces a computation graph, and the graph’s leaves must be concrete JAX arrays.
The first error we encountered was exactly this: TypeError: Value MachineState ... is not a valid JAX type. The fix was to register MachineState as a JAX pytree node by adding @register_pytree_node_class and implementing tree_flatten / tree_unflatten methods. JAX then knows how to decompose and reconstruct the dataclass during tracing.
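A minimal sketch of this registration pattern (the real MachineState carries more fields than shown here):

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class MachineState:
    # sketch only: the real class also has flags, SP, step counter, ...
    def __init__(self, pc, acc, memory):
        self.pc, self.acc, self.memory = pc, acc, memory

    def tree_flatten(self):
        # (dynamic leaves JAX may trace, static auxiliary data)
        return (self.pc, self.acc, self.memory), None

    @classmethod
    def tree_unflatten(cls, aux_data, leaves):
        return cls(*leaves)

s = MachineState(jnp.int32(0), jnp.int32(0), jnp.zeros(8, jnp.int32))
s2 = jax.tree_util.tree_map(lambda x: x + 1, s)  # JAX can now traverse it
print(int(s2.pc))  # 1
```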
The second error was subtler: TypeError: Cannot interpret value of type JAXTransformerVM as an abstract array. This happened because self.step was decorated with @jax.jit as a method — and when body_fun inside run() captured self in its closure, JAX tried to trace self through the computation graph. Since self is a Python class instance with numpy arrays, dict attributes, and methods, JAX has no idea what to do with it.
The architectural fix is clean and permanent: lift the step function out of the class.
```python
@functools.partial(jax.jit, static_argnums=(1,))
def _jit_step(
    state: MachineState,
    num_insts: int,  # static — used for jnp.clip bound
    stacked_gate, stacked_gate_bias,
    stacked_value, stacked_value_bias,
    stacked_out, stacked_out_bias,
    read_modes, read_addrs,
    ctrl_z_prev, ctrl_z_res, ctrl_z_const,
    ...
) -> MachineState:
    ...
```
This is a pure module-level function. All the “state” it needs — the stacked weight tensors, control scalars, and read-mode arrays — are passed as explicit JAX array arguments. There is no self. JAX can freely trace through it, cache the compiled XLA computation, and re-use it across calls.
In JAXTransformerVM.__init__, after compiling all instructions, we bind the weight arrays into a functools.partial:
```python
self._step_fn = functools.partial(
    _jit_step,
    num_insts=self.num_insts,
    stacked_gate=self.stacked_gate,
    ...
)
```
Now self._step_fn(state) is a state → state callable that is both JIT-compiled and safe to close over in while_loop's body_fun. The run loop becomes:
```python
def body_fun(val):
    step, s = val
    return step + 1, step_fn(s)

final_step, final_state = jax.lax.while_loop(
    cond_fun, body_fun, (jnp.int32(0), state)
)
```
This is idiomatic JAX. The loop runs in a single XLA computation, with no Python overhead per step. On a CPU, it runs roughly 10× faster than a pure Python loop. On a TPU, the gap widens dramatically.
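The same loop shape in a runnable toy form, with a stand-in step function instead of the compiled FFN:

```python
import jax
import jax.numpy as jnp

def step_fn(s):
    return s + 2  # stand-in for the compiled FFN step

def cond_fun(val):
    step, _ = val
    return step < 5

def body_fun(val):
    step, s = val
    return step + 1, step_fn(s)

# the loop carry is a tuple of JAX arrays, so XLA compiles the whole loop
final_step, final_state = jax.lax.while_loop(
    cond_fun, body_fun, (jnp.int32(0), jnp.int32(0))
)
print(int(final_step), int(final_state))  # 5 10
```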
Memory, Attention, and the Operand Read
In the Rust implementation, attention is performed via a proper convex-hull backed KV cache — a dynamic data structure that maintains the Pareto frontier of (step, value) writes and answers queries via binary search in O(log n). This is beautiful algorithmically but hard to JIT-compile: its memory footprint grows dynamically, and its binary search has data-dependent control flow.
In jax_tvm, we take a simpler approach that preserves the functional semantics: direct memory array indexing. The compiled instruction’s MemoryRead descriptor specifies either:
- `none` — no memory access needed (e.g., `LOADI`, `ADD imm`, `JMP`)
- `direct(addr)` — read `MEM[addr]` before the FFN
- `stack_top` — read `MEM[SP]` before the FFN (for `POP`, `RET`)
Because the VM’s memory is a fixed-size JAX integer array, these reads are simple state.memory[addr] indexing operations. JAX handles them as array operations, which are trivially JIT-compilable. The tradeoff is that we drop the O(log n) guarantee — but for the sizes of programs in the ISA (memory sizes up to 255 cells), this is not a practical concern.
The operand value is loaded before the FFN call and placed into IN_OPERAND (slot 8) of the 41-dimensional input vector. The FFN then acts on a fully constructed input that already contains the memory read result. This cleanly separates “what to read” (decided at compile time, encoded in read_modes and read_addrs) from “what to compute” (the FFN forward pass).
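A sketch of how such a mode-dispatched read can stay JIT-friendly (the integer mode encoding is an assumption, not the project's actual one):

```python
import jax.numpy as jnp

# hypothetical mode encoding: 0 = none, 1 = direct, 2 = stack_top
def read_operand(mode, addr, sp, memory):
    last = memory.shape[0] - 1
    direct_val = memory[jnp.clip(addr, 0, last)]  # MEM[addr]
    stack_val = memory[jnp.clip(sp, 0, last)]     # MEM[SP]
    return jnp.where(mode == 1, direct_val,
                     jnp.where(mode == 2, stack_val, jnp.int32(0)))

mem = jnp.array([7, 8, 9, 10], dtype=jnp.int32)
print(int(read_operand(1, 2, 0, mem)))  # 9: direct read of MEM[2]
print(int(read_operand(2, 0, 3, mem)))  # 10: stack_top read of MEM[SP=3]
```

Both branches are computed and selected with `jnp.where`, so there is no data-dependent control flow for XLA to trace around.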
Parsing .tvm Programs
The instruction parser (jax_tvm/instruction.py) is a straightforward two-pass assembler. The first pass collects .memory size directives, .init addr val memory initialization values, and label definitions. Labels are stored as a dictionary mapping names to program counter values. The second pass converts each mnemonic line into an Instruction(opcode, operand) pair, resolving label references.
One important bug fixed during development: the .tvm format uses semicolons (;) for comments, not just //. The original parser only stripped // comments, which caused labels like loop: to be treated as opcodes when the line contained a trailing comment. The fix is a single additional split:
line = line.split('//')[0].split(';')[0].strip()
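A toy two-pass assembler mirroring this structure (directive and operand handling deliberately simplified; not the project's parser):

```python
def parse_tvm(text):
    lines = []
    for raw in text.splitlines():
        line = raw.split('//')[0].split(';')[0].strip()  # both comment styles
        if line:
            lines.append(line)
    labels, inits, mem_size, pc = {}, {}, 0, 0
    for line in lines:  # pass 1: directives and label addresses
        if line.startswith('.memory'):
            mem_size = int(line.split()[1])
        elif line.startswith('.init'):
            _, addr, val = line.split()
            inits[int(addr)] = int(val)
        elif line.endswith(':'):
            labels[line[:-1]] = pc
        else:
            pc += 1
    program = []
    for line in lines:  # pass 2: instructions, resolving label references
        if line.startswith('.') or line.endswith(':'):
            continue
        parts = line.split()
        operand = parts[1] if len(parts) > 1 else None
        if operand in labels:
            operand = labels[operand]
        elif operand is not None:
            operand = int(operand)
        program.append((parts[0], operand))
    return program, labels, mem_size, inits

prog, labels, mem_size, inits = parse_tvm(
    ".memory 16\nloop:\nADDM 3 ; trailing comment\nJZ loop\nHALT"
)
print(prog)  # [('ADDM', 3), ('JZ', 0), ('HALT', None)]
```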
What Running Fibonacci Looks Like
With everything working, running fibonacci.tvm through the JAX engine looks like this:
```
python3 -m jax_tvm programs/fibonacci.tvm
```

```
program: programs/fibonacci.tvm
engine: jax_transformer
steps: 103
halted: True
acc: 21
zero_flag: False
memory (first 5): [13, 21, 21, 7, 7]
```
The program counter ticked 103 times. At each tick, the JIT-compiled _jit_step function executed: it selected the compiled FFN for the current PC, loaded any required memory operand, ran the gated bilinear forward pass, computed the next state from the 6-dimensional output vector, and optionally wrote to memory. The loop terminated because the halted flag was set to True by the HALT instruction’s compiled weights.
The accumulator holds 21. Fibonacci(8) = 21. Correct.
The Path to Proof
The JAX implementation is not yet connected to the STARK prover — that lives in the Rust codebase. But the mathematical structure is the same. The transition from state s_t to s_{t+1} is a polynomial function: the gated bilinear FFN is degree-2 in the input. The AIR (Algebraic Intermediate Representation) for the STARK is defined by exactly these transition polynomials.
What this means is that the JAX execution trace is, in principle, a valid STARK witness. The boundary constraints (initial and final state) and the transition constraints (the FFN computation) are the same whether you extract them from Rust or from JAX. A future direction is to export the trace from jax_tvm and feed it directly into the STARK prover — completing the pipeline:
.tvm program → JAX compiler (weights) → JAX runtime (trace) → STARK prover (proof) → verifier (accept / reject)
Conclusion
jax_tvm demonstrates that the Transformer VM concept is not tied to Rust, Burn tensors, or ONNX runtimes. It is a mathematical idea that translates directly to JAX: compile assembly into bilinear weight matrices, read memory operands via array indexing, drive the execution via jax.lax.while_loop, and get a deterministic, JIT-accelerated computation that runs identically on CPU, GPU, or TPU.
The implementation required working through JAX’s strictness about what values can live inside traced computations — a discipline that leads to cleaner, more composable code. MachineState is a JAX pytree. The step function is a pure function of arrays. The loop has no Python callbacks. Everything is static-shape and XLA-compilable.
The result is a small but complete independent engine: parse a .tvm file, compile it to FFN weights, run it inside a while loop, get a verified final state. Same semantics as the Rust transformer engine. Different implementation. Same answer.