Discrete Language Diffusion: Global Refinement on JAX/TPU

Most modern Large Language Models (LLMs) generate text by predicting one token at a time, a method known as Auto-regressive (AR) generation. While this approach produces locally fluent text, it has a structural limitation: once a token is generated, it is fixed and cannot be corrected based on later context. In this post, I explore applying Diffusion mechanisms to text generation, which lets a model refine a sequence globally, much like a human drafts and polishes a paragraph over multiple iterations.

The full source code and a live demo for this project are available below. I highly recommend trying it out to see the powerful synergy between JAX and TPU v6e.

KennethanCeyer/dllm-on-jax-tpu

Discrete masked language diffusion model using JAX and Flax NNX on Google TPU V6E

https://github.com/KennethanCeyer/dllm-on-jax-tpu

Google Colab

https://colab.research.google.com/drive/1631kiR9OQXjIoZjtYBV8k1NehsreGIf8

1. Structural Limitations of Auto-regressive Models

Auto-regressive models predict the next token conditioned on the previous sequence. Because they are trained on ground-truth prefixes but generate from their own outputs, they are prone to Exposure Bias[1]. Even a minor generative slip early in a sequence can snowball into a full contextual breakdown, which is one driver of 'hallucination'.

Consider a scenario where a character is searching for food in a kitchen:

The character went to the kitchen because they were hungry. They opened the fridge and found a [ball] inside. They took a big bite of the [ball].

An auto-regressive model cannot retroactively correct the word [ball] once it has been sampled.[2] The model is forced to maintain superficial grammatical consistency by continuing the illogical narrative of eating the ball, leading to a collapse of the story's coherence.

In contrast, diffusion models treat the entire sequence as a single state and refine all of it together. Starting from a state where most tokens are masked, the model updates all positions in parallel. A Bidirectional Transformer lets the model re-evaluate the relationship between all tokens at every refinement step, so if a specific token conflicts with the global context, it can be surgically corrected in a subsequent step.
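The difference between the two paradigms comes down to which positions each token may attend to. The minimal JAX sketch below builds both boolean attention masks; the helper name `attention_mask` is hypothetical and not taken from the repository.

```python
import jax.numpy as jnp

def attention_mask(seq_len: int, causal: bool) -> jnp.ndarray:
    """Boolean (L, L) mask; entry [i, j] is True if position i may attend to position j."""
    full = jnp.ones((seq_len, seq_len), dtype=bool)
    # AR decoders keep only the lower triangle, so no token can see the future.
    return jnp.tril(full) if causal else full

ar_mask = attention_mask(4, causal=True)      # token i sees tokens 0..i
dllm_mask = attention_mask(4, causal=False)   # every token sees every token
```

With the causal mask, position 0 can never look at position 3, which is exactly why an AR model cannot revisit [ball] after reading "took a big bite". The bidirectional mask removes that restriction.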

Figure 1: Architecture of Discrete Masked Language Diffusion (Source: arXiv 2502.09992)

2. Core Value of Discrete Language Diffusion (DLLM)

Building the Discrete Language Diffusion Model (DLLM) revealed several structural advantages over traditional sequential generation:

  • Global Coherence: By removing noise at the sequence level rather than the token level, the model can maintain a noticeably stronger narrative flow.
  • Flexible Compute Cost: The quality-vs-latency tradeoff can be adjusted at inference time by changing the number of diffusion steps.
  • Surgical Editing: The architecture is natively suited for in-filling tasks, allowing specific segments to be re-generated based on their surrounding context.
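In-filling falls out of the masking formulation: you mask only the span to be rewritten and leave the context intact. The sketch below prepares such an input; the helper `make_infill_input` and the mask id 99 are hypothetical, chosen for illustration only.

```python
import jax.numpy as jnp

def make_infill_input(token_ids, start, end, mask_token_id):
    """Mask positions [start, end) and return (masked ids, boolean editable-span mask)."""
    positions = jnp.arange(token_ids.shape[-1])
    span = (positions >= start) & (positions < end)
    # Only the span is replaced; the surrounding context is left untouched.
    return jnp.where(span, mask_token_id, token_ids), span

ids = jnp.array([10, 11, 12, 13, 14, 15])
masked, span = make_infill_input(ids, start=2, end=4, mask_token_id=99)
```

Here `masked` becomes `[10, 11, 99, 99, 14, 15]`. During refinement, a sampler would only write new predictions where `span` is True (for example `jnp.where(span, new_ids, ids)`), so the bidirectional context on both sides conditions the regenerated segment.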

3. Recovery in Discrete State Space

Since natural language is discrete[3], we use masked diffusion instead of continuous Gaussian noise: the forward process replaces tokens with a special mask token rather than perturbing continuous values.

The model then iteratively recovers the original tokens from this masked state by referencing the entire sequence. Because attention is bidirectional, each prediction conditions on both preceding and succeeding tokens at the same time.
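The model learns this recovery by predicting the original token at every masked position. The sketch below is a minimal, unweighted version of that objective, assuming logits of shape (B, L, V); the function name is hypothetical, and some formulations additionally reweight the loss by the masking ratio, which is omitted here.

```python
import jax
import jax.numpy as jnp

def masked_diffusion_loss(logits, targets, corrupted, mask_token_id):
    """Mean cross-entropy over masked positions only.

    logits: (B, L, V) predictions for the corrupted input
    targets: (B, L) original token ids
    corrupted: (B, L) input ids after masking
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    token_ll = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    is_masked = (corrupted == mask_token_id).astype(logits.dtype)
    # Unmasked positions are already known, so they contribute nothing to the loss.
    return -(token_ll * is_masked).sum() / jnp.maximum(is_masked.sum(), 1.0)

V, MASK = 4, 3
logits = jnp.zeros((1, 3, V))              # uniform predictions over 4 tokens
targets = jnp.array([[0, 1, 2]])
corrupted = jnp.array([[0, MASK, MASK]])   # positions 1 and 2 were masked
loss = masked_diffusion_loss(logits, targets, corrupted, MASK)
```

With uniform logits over 4 tokens, the loss equals log 4, the cross-entropy of a model that has learned nothing yet.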

Figure 2: Global sequence polishing process in language diffusion.

4. Hardware Acceleration and Architectural Synergy

The choice between sequential stacking and parallel global updates has profound implications for hardware utilization.

| Feature | Auto-regressive (AR) | Diffusion (DLLM) |
|---|---|---|
| Generation | Sequential | Parallel Global Refinement |
| Context | Uni-directional | Bi-directional |
| Hardware Bottleneck | Memory Bandwidth (HBM) | Compute Units (MXU) |
| Key Optimization | KV Cache | Sequence Parallelism |

Table 1: Engineering Comparison: AR vs. Diffusion

Because diffusion models rely on iterative compute loops, aligning the architecture with the underlying hardware accelerator and software stack is critical for achieving practical performance.

5. Infrastructure: JAX and TPU v6e (Trillium)

This project targets Google's TPU v6e (Trillium) accelerators and the JAX ecosystem.

5.1 JAX Functional Paradigm and XLA Compilation

JAX uses the XLA (Accelerated Linear Algebra) compiler to generate highly optimized programs for TPU hardware. JAX's functional, side-effect-free style maps naturally onto the Systolic Array design of TPUs. By compiling the entire diffusion loop with jax.jit, we avoid host-to-device round trips between steps, keeping the Matrix Multiply Units (MXUs) saturated with high-throughput operations.
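As a hedged sketch of the compile-the-whole-loop idea, the snippet below wraps a fixed number of refinement steps in `jax.lax.fori_loop` inside `jax.jit`. The step body is a hypothetical stand-in (a bf16 matmul followed by `tanh`), not the repository's denoiser; only the loop structure matters here.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="num_steps")
def run_refinement_loop(x, w, num_steps):
    def step(i, x):
        # Stand-in for one denoising pass: a bf16 matmul the MXU can execute.
        return jnp.tanh(x @ w)
    # fori_loop keeps every step inside one compiled XLA program (no per-step host sync).
    return jax.lax.fori_loop(0, num_steps, step, x)

x = jnp.ones((8, 128), dtype=jnp.bfloat16)
w = 0.5 * jnp.eye(128, dtype=jnp.bfloat16)
y = run_refinement_loop(x, w, num_steps=128)
```

A plain Python `for` loop under `jax.jit` would also stay on device, but it is unrolled at trace time into one copy of the step per iteration, which inflates compile time for 128 steps. `fori_loop` traces the body once.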

5.2 Flax NNX: Harmonizing OOP and Functional Purity

Model development was powered by Flax NNX, which allows for object-oriented state management without sacrificing JAX's functional requirements. This drastically simplifies the implementation of complex iterative loops, as weights and variables are handled as internal object properties, keeping the diffusion logic clean and maintainable.

5.3 Model Architecture Specifications

The following specifications were chosen to maximize compute efficiency on a single TPU v6e-1 chip.

| Parameter | Value | Rationale |
|---|---|---|
| Hidden Size | 1280 | Multiple of 128 for optimal MXU utilization. |
| Layers | 16 | Depth required for global context propagation. |
| Diffusion Steps | 128 | Convergence sweet spot for quality and speed. |
| Precision | BF16/FP16 | BF16 is natively accelerated by the TPU MXU. |

Table 2: Discrete Diffusion Model (320M) Specs
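As a sanity check on the 320M figure, here is a rough parameter count from the Table 2 values. The MLP expansion ratio of 4 and the tied/untied output head are assumptions not stated in the table, and biases, norms and position embeddings are ignored.

```python
def estimate_params(vocab: int, hidden: int, layers: int,
                    mlp_ratio: int = 4, tied_head: bool = True) -> int:
    """Rough Transformer parameter count, ignoring biases, norms and position embeddings."""
    attn = 4 * hidden * hidden                  # Q, K, V and output projections
    mlp = 2 * mlp_ratio * hidden * hidden       # up- and down-projection
    embed = vocab * hidden                      # token embedding table
    head = 0 if tied_head else vocab * hidden   # separate output projection, if untied
    return layers * (attn + mlp) + embed + head

tied = estimate_params(vocab=4096, hidden=1280, layers=16)                     # 319,815,680
untied = estimate_params(vocab=4096, hidden=1280, layers=16, tied_head=False)  # 325,058,560
```

Under these assumptions, the 16 Transformer blocks account for about 314.6M parameters, while the 4,096-token embedding table adds only about 5.2M.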

6. Training Strategy and Implementation

I utilized the TinyStories dataset, which consists of short, syntactically simple narratives ideal for validating architectural concepts in low-resource environments.

roneneldan/TinyStories · Datasets at Hugging Face

https://huggingface.co/datasets/roneneldan/TinyStories

6.1 BPE Tokenizer Optimization

I trained a custom BPE (Byte Pair Encoding) tokenizer with a focused vocabulary of 4,096 tokens. Because the embedding layer holds vocabulary size × hidden size parameters, keeping the vocabulary small significantly reduced its footprint. The parameters saved this way went into a deeper Transformer backbone, increasing the model's capacity within a limited memory budget.
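A tokenizer like this can be trained in a few lines. The sketch below assumes the Hugging Face `tokenizers` library, byte-level pre-tokenization and the special-token names `[PAD]`, `[UNK]` and `[MASK]`; none of these choices is confirmed by the repository. In practice, you would stream the TinyStories texts instead of the two toy sentences.

```python
from tokenizers import Tokenizer, decoders, models, pre_tokenizers, trainers

def train_bpe(texts, vocab_size: int = 4096) -> Tokenizer:
    """Train a byte-level BPE tokenizer capped at `vocab_size` entries."""
    tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    tokenizer.decoder = decoders.ByteLevel()
    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        # [MASK] serves as the absorbing token for the diffusion corruption step.
        special_tokens=["[PAD]", "[UNK]", "[MASK]"],
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    )
    tokenizer.train_from_iterator(texts, trainer=trainer)
    return tokenizer

corpus = ["Once upon a time, there was a little cat.",
          "The cat liked to play with a red ball."]
tok = train_bpe(corpus)
mask_token_id = tok.token_to_id("[MASK]")
ids = tok.encode("The little cat").ids
```

On a corpus this small, the trainer stops well below 4,096 entries; on the full dataset, it fills the vocabulary up to the cap.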

6.2 TPU-Optimized Corruption Logic

To keep the corruption step fully compilable and fast, I avoided Python branching (if/else) on per-token values, which cannot be traced under jax.jit anyway. Instead, I used jnp.where for elementwise, vectorized masking, so the whole batch is processed in a single branch-free operation with static shapes.

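training_logic.py

```python
import jax
import jax.numpy as jnp

def corrupt_batch(batch, rng, mask_token_id, t_steps):
    t_rng, mask_rng = jax.random.split(rng)
    t = jax.random.randint(t_rng, (batch.shape[0], 1), 1, t_steps + 1)  # assumed: one uniform t per sequence
    # Calculate survival probability via cosine schedule
    survival_prob = jnp.cos((t / t_steps) * (jnp.pi / 2))
    mask = jax.random.uniform(mask_rng, batch.shape) < survival_prob  # True = keep original token
    # Branch-free elementwise masking (no Python if/else on traced values)
    return jnp.where(mask, batch, mask_token_id), t
```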

6.3 Confidence-based Re-masking

During inference, the model keeps only the tokens it is most "confident" about (those with the highest predicted probabilities) and re-masks the rest. Repeating this over 128 refinement steps reduces the risk of committing early to greedy local optima and lets the global context emerge gradually.
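The sampler below is a minimal sketch of confidence-based re-masking under stated assumptions: `denoise_fn` is a hypothetical function mapping (B, L) token ids to (B, L, V) logits, decoding is greedy (argmax), and every position is re-predicted at each step, so earlier choices can still be revised. Some samplers instead freeze tokens once they are unmasked. The number of kept tokens follows the cosine survival schedule from `corrupt_batch`, evaluated at t = T - (step + 1).

```python
from functools import partial

import jax
import jax.numpy as jnp

def remask_step(logits, num_keep, mask_token_id):
    """Keep the `num_keep` most confident predictions per sequence and re-mask the rest."""
    logits = logits.at[..., mask_token_id].set(-jnp.inf)  # never predict [MASK] itself
    probs = jax.nn.softmax(logits, axis=-1)
    pred = jnp.argmax(probs, axis=-1).astype(jnp.int32)   # (B, L) greedy predictions
    conf = jnp.max(probs, axis=-1)                          # (B, L) confidence per position
    ranks = jnp.argsort(jnp.argsort(-conf, axis=-1), axis=-1)  # 0 = most confident
    return jnp.where(ranks < num_keep, pred, mask_token_id)

@partial(jax.jit, static_argnums=(0, 1, 2, 3))
def sample(denoise_fn, seq_shape, num_steps, mask_token_id):
    seq_len = seq_shape[1]
    tokens = jnp.full(seq_shape, mask_token_id, dtype=jnp.int32)  # start fully masked

    def body(step, tokens):
        # Kept fraction = cos(t/T * pi/2) at t = T - (step + 1), i.e. sin((step + 1)/T * pi/2)
        frac_kept = jnp.sin(((step + 1) / num_steps) * (jnp.pi / 2))
        num_keep = jnp.round(seq_len * frac_kept).astype(jnp.int32)
        return remask_step(denoise_fn(tokens), num_keep, mask_token_id)

    return jax.lax.fori_loop(0, num_steps, body, tokens)

VOCAB, MASK_ID = 8, 7

def toy_denoiser(tokens):
    # Hypothetical stand-in for the model: always predicts id (position % 7).
    target = jnp.arange(tokens.shape[1]) % MASK_ID
    logits = 5.0 * jax.nn.one_hot(target, VOCAB)
    return jnp.broadcast_to(logits, tokens.shape + (VOCAB,))

out = sample(toy_denoiser, (2, 16), 128, MASK_ID)
```

At the final step the kept fraction reaches 1, so no mask tokens remain. Sampling from `probs` with a temperature instead of taking the argmax is a common variation.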

6.3.1 Serving Performance: Auto-regressive model (AR) vs Diffusion (DLLM)

| Feature | Auto-regressive model (AR) | Diffusion (DLLM) |
|---|---|---|
| Sequential forward passes | $O(N)$ (sequence length) | $O(T)$ (diffusion steps) |
| Hardware Bottleneck | Memory Bandwidth (HBM) | Compute Units (MXU) |
| Optimization | KV Cache | Parallel sequence updates |

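In a diffusion model, the number of sequential forward passes is set by the step count T rather than by the output length, so inference time is more predictable within the context window. Each pass still processes the full sequence, so the per-step cost grows with sequence length. An AR model, by contrast, needs one sequential pass per generated token, so its latency scales linearly with output length.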

7. Results and Visualization

The resulting model shows a clear ability to self-correct during generation: inconsistencies sampled in early steps are often resolved in later iterations. In my runs, this produced more coherent and logically consistent narratives than comparable small-scale AR models.

Figure 3: Visualizing the global refinement of text over diffusion timesteps

The underlying infrastructure architecture used for these experiments is shown below.

Figure 4: Infrastructure architecture of Cloud TPU used for training (Source: Google Cloud Blog; the original diagram is titled "Cloud TPU v5e Architecture")

8. Final Thoughts

Implementing Discrete Language Diffusion on JAX/TPU highlights a shift from sequential to global optimization in NLP. The architectural synergy between bidirectional attention and TPU's parallel compute capabilities offers a compelling alternative to the standard auto-regressive paradigm.

I hope this exploration provides valuable insights for engineers looking to push the boundaries of generative architectures.


TPU v6e Resource Management: TPU v6e is tailor-made for Transformer workloads. You can explore these capabilities using TPU Flex-start on Google Cloud for a cost-effective and powerful acceleration experience.


Footnotes


  • 1: Exposure Bias: A phenomenon where the model is trained on ground-truth tokens but must rely on its own previous (potentially erroneous) predictions during inference, causing errors to accumulate.
  • 2: AR models rely on unidirectional causality ($t-1 \to t$), so past decisions cannot be revised during the current generation step.
  • 3: Discrete: Data that consists of distinct, separate values rather than a continuous range. Unlike image pixel intensities, which diffusion models typically treat as continuous values, text is composed of individual tokens (words, subwords, or characters).
