디퓨전 언어 모델: JAX/TPU 기반의 이산형 생성 구조와 원리

대규모 언어 모델LLM, Large Language Models은 토큰을 하나씩 순서대로 예측하는 자동 회귀AR, Autoregressive 방식을 주로 채택한다. 이는 문법적으로 유려한 문장을 생성하는 데 탁월하지만, 한 번 확정된 토큰을 사후에 수정할 수 없다는 구조적 결함이 존재한다. 본 포스트에서는 문장을 선형적으로 써 내려가는 대신, 사람이 초안을 다듬듯 전체 시퀀스를 동시다발적으로 정교화하는 디퓨전Diffusion 메커니즘을 텍스트 생성에 적용해 본 과정을 공유한다.

프로젝트 코드는 아래 링크에 공개해 두었다. JAX와 TPU의 강력한 결합이 선사하는 연산 효율성을 직접 확인해 보길 권장한다.

KennethanCeyer/dllm-on-jax-tpu

Discrete masked language diffusion model using JAX and Flax NNX on Google TPU V6E

iconhttps://github.com/KennethanCeyer/dllm-on-jax-tpu

Google Colab

Preview unavailable.

iconhttps://colab.research.google.com/drive/1631kiR9OQXjIoZjtYBV8k1NehsreGIf8
preview

1. 자동 회귀 모델의 구조적 한계

자동 회귀 모델은 이전 시점의 토큰들을 조건부로 다음 토큰을 예측한다. 이 방식은 필연적으로 노출 편향Exposure Bias[1] 문제를 야기한다. 생성 초기 단계의 미세한 오차가 뒤로 갈수록 증폭되어 전체 문맥의 일관성을 훼손하는 '할루시네이션Hallucination'의 원인이 되기도 한다.

예를 들어, 냉장고 안에서 음식을 찾는 상황을 생성한다고 가정해 보자.

주인공은 배가 고파서 주방으로 향했다. 그는 냉장고를 열었고, 그 안에서 [공]을 발견했다. 그는 그 [공]을 한 입 베어 물었다.

자동 회귀 모델은 이미 생성된 [공]이라는 토큰을 스스로 부정하거나 수정할 수 없다.[2] 따라서 모델은 비논리적인 상황임에도 불구하고 '공을 먹는다'는 흐름을 억지로 이어가게 되며, 이는 전체 서사의 붕괴로 이어진다.

반면 디퓨전 모델은 문장 전체를 하나의 상태 공간State Space으로 간주하고 동시에 완성한다. 대부분의 토큰이 마스킹된 상태에서 시작하여, 모든 위치의 토큰을 병렬로 업데이트하며 문맥을 다듬는다. 양방향 트랜스포머Bidirectional Transformer가 문장 전체의 상관관계를 실시간으로 참조하므로, 특정 토큰이 전체 흐름과 상충할 경우 다음 스텝에서 이를 스스로 교정할 기회를 얻는다.

Diffusion Language Models Architecture
그림 1: 이산형 마스킹 언어 디퓨전의 아키텍처 구조(Source: arXiv 2502.09992)

2. 디퓨전 언어 모델DLLM의 핵심 가치

이산형 언어 디퓨전 모델DLLM, Discrete Language Diffusion Models을 구축하며 확인한 주요 특징과 구조적 이점은 다음과 같다.

  • 전역 문맥의 일관성: 토큰 단위가 아닌 시퀀스 단위로 노이즈를 제거하므로 서사의 일관성이 뛰어나다.
  • 유연한 연산 비용: 추론 시 디퓨전 스텝 수를 조절하여 생성 품질과 속도 사이의 트레이드오프Trade-off를 실시간으로 제어할 수 있다.
  • 부분 교정In-filling: 문장의 특정 구간만 마스킹하고 주변 문맥에 맞춰 재생성하는 등 정교한 편집 기능에 특화되어 있다.

3. 이산 상태 공간의 복원Discrete State Space Recovery

자연어는 이산적Discrete[3] 성격을 띠므로, 연속적인 가우시안 노이즈 대신 토큰을 [MASK]라는 특수 토큰으로 교환하는 마스킹 디퓨전Masking Diffusion 방식을 사용한다.

이 과정에서 모델은 [MASK] 주변 정보를 토대로 원래 토큰을 점진적으로 복원해 나간다. 특히 양방향 트랜스포머를 활용함으로써, 시퀀스의 시작과 끝을 동시에 고려하는 고차원적인 문맥 파악이 가능해진다.

이산형 디퓨전 다듬기 과정
그림 2: 문장 전체를 동시에 다듬어가는 언어 디퓨전 프로세스

4. 하드웨어 가속과 아키텍처의 상관관계

토큰을 순차적으로 쌓는 것과 전체를 한꺼번에 업데이트하는 것은 연산 하드웨어 활용 측면에서 극명한 차이를 보인다.

항목자동 회귀 방식디퓨전 방식
생성 방식순차적Sequential전체 병렬Parallel
정보 참조단방향Uni-directional양방향Bi-directional
하드웨어 병목메모리 대역폭HBM 중심연산 장치MXU 중심
주요 최적화KV 캐시KV Cache시퀀스 병렬화
표 1: 아키텍처별 엔지니어링 특성 비교

디퓨전 모델은 반복적인 연산 루프가 발생하므로, 이를 고속으로 처리할 수 있는 하드웨어 가속기와 소프트웨어 스택의 정렬이 무엇보다 중요하다.

5. JAX와 TPU v6e (Trillium) 인프라

본 프로젝트는 구글의 최신 가속기인 TPU v6eTrillium와 JAX 생태계를 기반으로 최적화되었다.

5.1 JAX의 함수형 paradigm과 XLA 컴파일

JAX는 XLAAccelerated Linear Algebra 컴파일러를 통해 TPU 아키텍처에 최적화된 바이너리를 생성한다. JAX의 순수 함수형 구조는 TPU의 시스톨릭 어레이Systolic Array 설계와 완벽하게 일치한다. 특히 디퓨전의 반복 루프를 jax.jit으로 컴파일하면 단계 간 오버헤드를 제로에 가깝게 줄일 수 있어, MXUMatrix Execution Unit 가동률을 극대화할 수 있다.

5.2 Flax NNX: 객체 지향과 함수형의 조화

모델 설계에는 Flax NNX를 활용했다. NNX는 상태State와 가중치를 객체 내부에서 관리하면서도 JAX의 순수 함수 성질을 훼손하지 않는다. 이는 반복 루프가 많은 디퓨전 모델의 구현 복잡도를 획기적으로 낮춰주며, TPU v6e의 고속 HBM 메모리 자원을 효율적으로 관리할 수 있게 해준다.

5.3 실험 모델 사양

단일 TPU v6e 칩에서 연산 효율을 극대화하도록 설계된 모델 사양이다.

파라미터설정값비고
Hidden Size1280TPU MXU 가동률을 고려한 배수 설정
Layers16양방향 문맥 전파의 깊이 확보
Diffusion Steps128품질과 추론 속도의 최적 지점
PrecisionBF16/FP16TPU v6e의 하드웨어 가속 활용
표 2: 이산형 디퓨전 모델(320M) 사양

6. 구현 및 데이터 전략

학습 데이터로는 3~4세 수준의 짧은 서사 구조를 가진 TinyStories 데이터셋을 선택했다. 모델 규모는 작지만 문장의 인과 관계를 익히는 데 매우 효율적이다.

roneneldan/TinyStories · Datasets at Hugging Face

We’re on a journey to advance and democratize artificial intelligence through open source and open science.

iconhttps://huggingface.co/datasets/roneneldan/TinyStories
preview

6.1 토크나이저 최적화 (BPE)

어휘 사전의 크기를 4,096개로 한정한 BPEByte Pair Encoding[4] 방식을 직접 학습했다. 이를 통해 임베딩 레이어의 파라미터를 절약하고, 해당 자원을 트랜스포머 레이어의 깊이를 더하는 데 재배치하여 추론 능력을 강화했다.

6.2 TPU 최적화 Corrupt 로직

TPU 연산 성능을 극대화하기 위해 조건문(if) 대신 벡터 연산인 jnp.where를 적극 활용했다. TPU 하드웨어는 제어 흐름Control Flow이 복잡할수록 연산 파이프라인이 멈추는 병목이 발생하기 쉬운데, 이를 병렬 벡터 연산으로 대체하여 MXU 연산량을 높였다.

corrupt_logic.py
def corrupt_batch(batch, rng, mask_token_id, t_steps): # 코사인 스케줄에 따른 유지 확률 계산 survival_prob = jnp.cos((t / t_steps) * (jnp.pi / 2)) # 마스크가 False인 위치를 mask_token_id로 교체 (TPU 최적화) return jnp.where(mask, batch, mask_token_id), t

6.3 신뢰도 기반 재마스킹Confidence-based Re-masking

추론 과정에서는 모델이 예측한 토큰 중 확률값이 높은(신뢰도가 높은) 토큰만 남기고 나머지는 다시 가리는 과정을 반복한다. 이는 단번에 정답을 도출하기보다, 여러 단계에 걸쳐 확률 분포를 그리디Greedy[5]하게 고착되지 않도록 점진적으로 수렴시키는 방식이다.

7. 결과 및 시각화 분석

학습 결과, 모델은 빈칸을 앞뒤 문맥에 맞춰 유연하게 채워 넣는 능력을 보여주었다. 특히 자동 회귀 모델이 흔히 저지르는 문맥 단절 현상이 현저히 줄어들었음을 확인했다.

Diffusion Generation Trace
그림 3: 타임스텝에 따른 텍스트 복원 과정 시각화

학습에 활용된 Cloud TPU 인프라의 전체적인 구조는 아래와 같다.

Cloud TPU v5e Architecture
그림 4: 학습에 활용된 Cloud TPU 인프라 아키텍처(Source: Google Cloud Blog)

8. 마치며

언어 모델에 디퓨전을 적용하는 것은 고정관념을 깨는 도전적인 접근이다. 문장 전체를 하나의 유기체처럼 다듬는 이 방식은 기존 자동 회귀 모델의 한계를 보완할 수 있는 강력한 대안이다. 특히 JAX와 TPU v6e가 제공하는 압도적인 연산 성능은 이러한 반복적 생성 알고리즘을 실용적인 수준으로 끌어올리는 핵심 동력이다.

본 프로젝트의 경험이 새로운 생성 아키텍처를 고민하는 엔지니어들에게 영감이 되기를 바란다.


TPU v6e 활용 가이드: TPU v6e는 트랜스포머 연산에 특화된 시스톨릭 어레이 구조를 갖추고 있다. 구글 클라우드의 TPU Flex-start를 활용하면 합리적인 비용으로 고성능 가속기를 직접 경험해 볼 수 있다.


각주


  • 1: 노출 편향Exposure Bias: 학습 시에는 실제 정답(Ground Truth)을 입력으로 사용하지만, 추론 시에는 이전 시점의 모델 예측값을 입력으로 사용함에 따라 오차가 누적되는 현상이다. [↩︎]
  • 2: 자동 회귀 모델은 단방향 인과 관계(t1tt-1 \to t)를 기반으로 하기에, 현재 시점에서 과거의 결정을 번복하는 역방향 최적화가 불가능하다. [↩︎]
  • 3: 이산적Discrete: 데이터의 값이 연속적이지 않고 서로 떨어져 있는 상태를 말한다. 자연어는 이미지 같은 연속적인 데이터와 달리, 단어나 문자와 같이 하나하나 구분되는 토큰으로 구성된 대표적인 이산적 데이터다. [↩︎]
  • 4: BPEByte Pair Encoding: 빈번하게 등장하는 글자 쌍을 하나의 토큰으로 합쳐 나가는 서브워드 토크나이징 알고리즘이다. 어휘 사전을 효율적으로 구성할 수 있어 최신 언어 모델에서 표준적으로 사용된다. [↩︎]
  • 5: 그리디Greedy 현상: 모델이 전체 문맥을 살피기보다 당장 해당 위치에서 확률이 가장 높은 토큰만을 선택하려 하여, 문장이 단조로워지거나 특정 표현이 반복되는 현상을 말한다. [↩︎]

추천 아티클