Discrete Language Diffusion: Global Refinement on JAX/TPU
Most modern Large Language Models (LLMs) generate text by predicting one token at a time, a method known as auto-regressive (AR) generation. While this approach is effective for producing locally fluent text, it has a structural limitation: once a token is generated, it is fixed and cannot be corrected based on future context. In this post, I explore applying diffusion mechanisms to text generation, which lets a model refine a sequence globally, much like a human drafts and polishes a paragraph across multiple iterations.
The full source code and a live demo for this project are available below. I highly recommend trying it out to see the powerful synergy between JAX and TPU v6e.
- Source code: KennethanCeyer/dllm-on-jax-tpu (Discrete masked language diffusion model using JAX and Flax NNX on Google TPU v6e)
- Live demo: Google Colab
1. Structural Limitations of Auto-regressive Models
Auto-regressive models predict the next token conditioned on the previous sequence. This paradigm inherently leads to Exposure Bias[1]: even a minor generative slip early in a sequence can snowball into a full contextual breakdown, which is one driver of 'hallucination'.
Consider a scenario where a character is searching for food in a kitchen, and the model samples the word [ball] where a food item was expected.
An auto-regressive model cannot retroactively correct the word [ball] once it has been sampled.[2] It is forced to maintain superficial grammatical consistency by continuing the illogical narrative of eating the ball, and the story's coherence collapses.
In contrast, diffusion models treat the entire sentence as a single state and refine it as a whole. Starting from a state where most tokens are masked, the model updates all positions in parallel. A bidirectional Transformer lets the model re-evaluate the relationship between all tokens at every refinement step. If a specific token conflicts with the global context, it can be corrected in a subsequent step.
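The difference between the two attention patterns can be made concrete with a small sketch. The sequence length and the helper name `visible_positions` are illustrative assumptions, not taken from the project code.

```python
import jax.numpy as jnp

seq_len = 5  # toy length, chosen only for illustration

# AR decoder: position i may attend only to positions <= i.
causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

# Diffusion denoiser: every position may attend to every other position.
bidirectional_mask = jnp.ones((seq_len, seq_len), dtype=bool)


def visible_positions(mask, i):
    """Return the positions that query position i is allowed to attend to."""
    return [int(j) for j in jnp.nonzero(mask[i])[0]]


print(visible_positions(causal_mask, 1))
print(visible_positions(bidirectional_mask, 1))
```

Under the causal mask, position 1 never sees positions 2 to 4, so a later token cannot influence it. Under the bidirectional mask it sees the whole sequence, which is what makes later correction possible.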

2. Core Value of Discrete Language Diffusion (DLLM)
Building the Discrete Language Diffusion Model (DLLM) revealed several structural advantages over traditional sequential generation:
- Global Coherence: By removing noise at the sequence level rather than the token level, the model can revise any position in light of the whole sequence, which helps it maintain narrative flow.
- Flexible Compute Cost: The quality-vs-latency tradeoff can be tuned at inference time by adjusting the number of diffusion steps.
- Surgical Editing: The architecture is natively suited for in-filling tasks, allowing for the re-generation of specific segments based on their surrounding context.
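As a rough illustration of the in-filling idea, the sketch below masks only a chosen span and accepts model proposals only inside it. `MASK_ID`, `make_infill_state`, and `apply_update` are hypothetical names introduced here, and the proposal array stands in for real model output.

```python
import jax.numpy as jnp

MASK_ID = 0  # hypothetical id reserved for the [MASK] token


def make_infill_state(tokens, start, end):
    """Mask positions [start, end) and return (state, editable) arrays."""
    positions = jnp.arange(tokens.shape[-1])
    editable = (positions >= start) & (positions < end)
    state = jnp.where(editable, MASK_ID, tokens)
    return state, editable


def apply_update(state, proposal, editable):
    """Accept the proposal only inside the editable span; keep context fixed."""
    return jnp.where(editable, proposal, state)


tokens = jnp.array([11, 12, 13, 14, 15])
state, editable = make_infill_state(tokens, 1, 3)
proposal = jnp.array([91, 92, 93, 94, 95])  # stand-in for model predictions
updated = apply_update(state, proposal, editable)
```

The surrounding tokens act as fixed context, so the same denoising loop used for generation can be reused for editing.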
3. Recovery in Discrete State Space
Since natural language is discrete[3], we employ masking diffusion instead of continuous Gaussian noise: corruption replaces tokens with a [MASK] token rather than adding noise to real-valued inputs.
The model iteratively recovers original tokens from a masked state by referencing the entire sequence. Because this recovery is bidirectional, each prediction takes both preceding and succeeding tokens into account.
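For reference, one common way to write a masking (absorbing-state) diffusion process is shown below. The post does not state the exact schedule or loss it uses, so the linear schedule in $t$ and the unweighted masked cross-entropy are assumptions; some formulations add a time-dependent weight to the loss.

```latex
% Forward (corruption) process for token i at noise level t in [0, 1]:
q(x_t^i \mid x_0^i) =
\begin{cases}
1 - t, & x_t^i = x_0^i \\
t,     & x_t^i = \texttt{[MASK]}
\end{cases}

% Training objective: cross-entropy on masked positions only.
\mathcal{L}(\theta) = -\,\mathbb{E}_{t,\,x_0,\,x_t}
\left[ \sum_{i} \mathbf{1}\!\left[x_t^i = \texttt{[MASK]}\right]
\log p_\theta\!\left(x_0^i \mid x_t\right) \right]
```

Here $x_0$ is the clean sequence, $x_t$ is its corrupted version, and $p_\theta$ is the bidirectional denoiser.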
4. Hardware Acceleration and Architectural Synergy
The choice between sequential stacking and parallel global updates has profound implications for hardware utilization.
| Feature | Auto-regressive (AR) | Diffusion (DLLM) |
|---|---|---|
| Generation | Sequential | Parallel Global Refinement |
| Context | Uni-directional | Bi-directional |
| Hardware Bottleneck | Memory Bandwidth (HBM) | Compute Units (MXU) |
| Key Optimization | KV Cache | Sequence Parallelism |
Because diffusion models rely on iterative compute loops, aligning the architecture with the underlying hardware accelerator and software stack is critical for achieving practical performance.
5. Infrastructure: JAX and TPU v6e (Trillium)
This project is optimized for Google’s latest AI infrastructure, specifically TPU v6e (Trillium) and the JAX ecosystem.
5.1 JAX Functional Paradigm and XLA Compilation
JAX uses the XLA (Accelerated Linear Algebra) compiler to generate optimized programs for TPU hardware. JAX's functional structure maps well onto the systolic array design of TPUs. By compiling the entire diffusion loop with jax.jit, we avoid host-to-device round trips between steps, which helps keep the Matrix Multiply Units (MXUs) saturated with high-throughput operations.
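A minimal sketch of compiling a whole iterative loop into one XLA program, assuming `jax.lax.fori_loop` as the loop primitive. `toy_step` is a placeholder for a real refinement step, and the project's actual loop may be structured differently.

```python
import jax
import jax.numpy as jnp
from jax import lax

NUM_STEPS = 128  # diffusion step count listed in section 5.3


def toy_step(i, x):
    # Stand-in for one refinement step; a real step would call the model.
    return x * 0.5 + 1.0


@jax.jit
def run_loop(x0):
    # fori_loop keeps all iterations inside one compiled program,
    # so control does not return to Python between steps.
    return lax.fori_loop(0, NUM_STEPS, toy_step, x0)


out = run_loop(jnp.zeros((4,)))
```

A plain Python `for` loop that calls a jitted step function would also work, but it dispatches from the host on every step, which is the overhead the paragraph above describes avoiding.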
5.2 Flax NNX: Harmonizing OOP and Functional Purity
Model development was powered by Flax NNX, which allows for object-oriented state management without sacrificing JAX's functional requirements. This drastically simplifies the implementation of complex iterative loops, as weights and variables are handled as internal object properties, keeping the diffusion logic clean and maintainable.
5.3 Model Architecture Specifications
The following specifications were chosen to maximize compute efficiency on a single TPU v6e-1 chip.
| Parameter | Value | Rationale |
|---|---|---|
| Hidden Size | 1280 | Multiple of 128 for optimal MXU utilization. |
| Layers | 16 | Depth required for global context propagation. |
| Diffusion Steps | 128 | Convergence sweet spot for quality and speed. |
| Precision | BF16/FP16 | Native TPU hardware acceleration. |
6. Training Strategy and Implementation
I utilized the TinyStories dataset, which consists of short, syntactically simple narratives ideal for validating architectural concepts in low-resource environments.
roneneldan/TinyStories · Datasets at Hugging Face
6.1 BPE Tokenizer Optimization
I trained a custom BPE (Byte Pair Encoding) tokenizer with a focused vocabulary of 4,096 tokens. Keeping the vocabulary small significantly reduces the parameter footprint of the embedding layer. This "parameter efficiency" frees room for a deeper Transformer backbone, giving the model more capacity within a limited memory budget.
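The saving is easy to quantify with back-of-the-envelope arithmetic. The 32,000-token reference vocabulary below is a hypothetical comparison point, not a figure from this project.

```python
HIDDEN_SIZE = 1280        # from the specification table in section 5.3
VOCAB_SIZE = 4096         # custom BPE vocabulary described above
REFERENCE_VOCAB = 32_000  # hypothetical larger vocabulary, for comparison only


def embedding_params(vocab_size, hidden_size):
    """Number of weights in a token embedding table."""
    return vocab_size * hidden_size


small = embedding_params(VOCAB_SIZE, HIDDEN_SIZE)
large = embedding_params(REFERENCE_VOCAB, HIDDEN_SIZE)
saved = large - small

print(f"4,096-token table:  {small:,} parameters")
print(f"32,000-token table: {large:,} parameters")
print(f"difference:         {saved:,} parameters")
```

If the output projection is not tied to the embedding table, the same count applies a second time.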
6.2 TPU-Optimized Corruption Logic
To maintain hardware throughput, I avoided data-dependent branching (if/else) in the corruption process. Instead, I used jnp.where for vectorized masking, which keeps the computation branch-free and traceable under jax.jit.
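A minimal sketch of branch-free corruption with `jnp.where`, assuming independent per-token masking with one mask ratio per sequence. `MASK_ID` and the function name `corrupt` are hypothetical, and the project's actual corruption logic may differ.

```python
import jax
import jax.numpy as jnp

MASK_ID = 0  # hypothetical id reserved for the [MASK] token


@jax.jit
def corrupt(key, tokens, mask_ratio):
    """Replace each token with MASK_ID independently with prob mask_ratio."""
    u = jax.random.uniform(key, tokens.shape)  # values in [0, 1)
    is_masked = u < mask_ratio[:, None]        # one ratio per sequence
    noisy = jnp.where(is_masked, MASK_ID, tokens)
    return noisy, is_masked


key = jax.random.PRNGKey(0)
tokens = jnp.arange(1, 17).reshape(2, 8)
# Ratio 0.0 leaves the first row untouched; ratio 1.0 masks the second row.
noisy, is_masked = corrupt(key, tokens, jnp.array([0.0, 1.0]))
```

Because the selection is expressed as an elementwise operation over the whole batch, the same compiled program handles every mask ratio without retracing.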
6.3 Confidence-based Re-masking
During inference, the model preserves only the tokens it is most "confident" about (those with the highest predicted probabilities) and re-masks the rest. This iterative process keeps the model from committing early to greedy local optima and lets the global context emerge over 128 refinement steps.
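One re-masking step could look roughly like the sketch below. It assumes confidence is the maximum softmax probability and that a fixed number of positions is kept per step. `MASK_ID`, `remask_step`, and the toy logits are illustrative and not taken from the project.

```python
import jax
import jax.numpy as jnp

MASK_ID = 3  # hypothetical: last id of a 4-token toy vocabulary


def remask_step(logits, num_keep):
    """Keep the num_keep most confident predictions and re-mask the rest."""
    probs = jax.nn.softmax(logits, axis=-1)
    pred = jnp.argmax(probs, axis=-1)  # predicted token per position
    conf = jnp.max(probs, axis=-1)     # confidence per position
    # argsort of argsort turns scores into ranks; rank 0 is most confident.
    rank = jnp.argsort(jnp.argsort(-conf))
    keep = rank < num_keep
    return jnp.where(keep, pred, MASK_ID), keep


logits = jnp.array([
    [4.0, 0.0, 0.0, 0.0],  # position 0: confident about token 0
    [0.0, 1.0, 0.0, 0.0],  # position 1: weak preference for token 1
    [0.0, 0.0, 3.0, 0.0],  # position 2: fairly confident about token 2
    [0.5, 0.0, 0.0, 0.0],  # position 3: nearly uniform
])
tokens, keep = remask_step(logits, num_keep=2)
```

In a full sampler, `num_keep` would typically grow from step to step until no masked positions remain. That schedule is a design choice this sketch leaves open.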
6.3.1 Serving Performance: Auto-regressive model (AR) vs Diffusion (DLLM)
| Feature | Auto-regressive model (AR) | Diffusion (DLLM) |
|---|---|---|
| Complexity (sequential forward passes) | O(sequence length) | O(diffusion steps) |
| Hardware Bottleneck | Memory Bandwidth (HBM) | Compute Units (MXU) |
| Optimization | KV Cache | Parallel sequence updates |
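With diffusion, the number of sequential forward passes is set by the step count and is independent of sequence length (within the context window), which makes latency more predictable. With auto-regressive decoding, the number of passes scales linearly with the number of generated tokens. The cost of each individual pass still grows with sequence length in both cases.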
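The pass-count comparison can be sketched in a few lines. This counts sequential forward passes only and deliberately ignores the cost of each pass; a diffusion pass processes the full sequence, while an AR pass with a KV cache processes a single new token. The function name is hypothetical.

```python
DIFFUSION_STEPS = 128  # step count from the specification table


def sequential_passes(seq_len, mode):
    """Count sequential forward passes needed to produce seq_len tokens."""
    if mode == "ar":
        return seq_len          # one decoder pass per generated token
    if mode == "diffusion":
        return DIFFUSION_STEPS  # one full-sequence pass per refinement step
    raise ValueError(f"unknown mode: {mode}")


for length in (64, 128, 512):
    ar = sequential_passes(length, "ar")
    diff = sequential_passes(length, "diffusion")
    print(f"length={length:4d}  AR passes={ar:4d}  diffusion passes={diff}")
```

By this count alone, diffusion needs fewer sequential passes only once the output is longer than the step count.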
7. Results and Visualization
The resulting model can self-correct during generation: inconsistencies sampled in early steps are often resolved in later iterations. Qualitatively, the resulting narratives read as more coherent and logically sound than those of comparable small-scale AR models.

The underlying infrastructure architecture used for these experiments is shown below.

8. Final Thoughts
Implementing Discrete Language Diffusion on JAX/TPU highlights a shift from sequential to global optimization in NLP. The architectural synergy between bidirectional attention and TPU's parallel compute capabilities offers a compelling alternative to the standard auto-regressive paradigm.
I hope this exploration provides valuable insights for engineers looking to push the boundaries of generative architectures.
TPU v6e Resource Management: TPU v6e is tailor-made for Transformer workloads. You can explore these capabilities using TPU Flex-start on Google Cloud for a cost-effective and powerful acceleration experience.
Footnotes
- 1: Exposure Bias: A phenomenon where the model is trained on ground-truth tokens but must rely on its own previous (potentially erroneous) predictions during inference, causing errors to accumulate.
- 2: AR models rely on unidirectional causality (each token is conditioned only on the tokens before it, $p(x_t \mid x_{<t})$), making it impossible to revise past decisions during the current generation step.
- 3: Discrete: Data that consists of distinct, separate values rather than a continuous range. Unlike images, which are continuous, text is composed of individual tokens (words or characters).