JAX와 TPU를 이용한 Decoder-based Tiny LLM 사전학습

LLM, Transformer 업무에 종사해보거나 공부해본 AI 작업자라면 AutoModel.from_pretrained를 사용해본 적이 있을 것이다. 사전 학습된 모델을 불러오는 허깅페이스 구현체인데, 과연 이 뒤에는 무엇이 있을까? 이 글은 원시 텍스트 데이터가 디스크에서 읽혀 토크나이징 되고, TPU라는 하드웨어를 거쳐 의미를 가진 문장으로 재탄생하기까지의 전 과정을 해부한다. 최신 Llama 모델의 설계를 JAX로 직접 구현하며 모델의 사용자에서 모델의 설계자로 거듭나는 과정을 함께해 보자.

전체 코드는 오픈소스로 공개되어 있다. 아래 GitHub 프리뷰에서 저장소를 확인할 수 있다.

KennethanCeyer/decoder-transformer-on-jax-tpu

The notebook for training and evaluating an autoregressive decoder-only Transformer for Korean text using JAX/Flax on TPU.

iconhttps://github.com/KennethanCeyer/decoder-transformer-on-jax-tpu

1. 들어가며

HuggingFace와 같은 고수준 라이브러리 덕분에 누구나 몇 줄의 코드로 GPT-4급 모델을 호출하고 파인튜닝할 수 있다. 하지만 편리함은 때로 엔지니어의 눈을 가린다. 모델이 왜 예상치 못한 결과를 내놓는지, 왜 특정 배치 사이즈에서 OOMOut of Memory이 발생하는지, 왜 학습 손실Loss이 줄지 않는지 이해하려면 결국 블랙박스를 열어봐야 한다.

이 프로젝트는 거창한 SOTAState-of-the-art 모델을 목표로 하지 않는다. 대신 다음과 같은 엔지니어링 호기심을 밑바닥부터 해결하는 데 집중한다.

  • Framework Transition: 구글과 DeepMind가 왜 PyTorch 대신 JAX를 쓰는지, 함수형 프로그래밍 기반 병렬화의 이점을 직접 체감해 본다.
  • Modern Architecture: 2017년 Transformer와 2024년 Llama 3의 차이를 구현을 통해 살펴보고 RMSNorm과 SwiGLU가 주는 안정성·표현력 개선 효과를 확인한다.
  • Data Engineering: RAM보다 큰 텍스트 데이터를 어떻게 학습시키는지, OS 가상 메모리 시스템을 활용한 고성능 I/O 파이프라인을 구축해 본다.

2. JAX와 TPU

딥러닝 프레임워크의 대세는 분명 PyTorch와 NVIDIA GPU다. 하지만 이번 프로젝트에서는 JAXCloud TPU v6e를 선택했다. 이는 단순히 새로운 툴을 시도해 보려는 선택이 아니다. LLM 학습의 가장 큰 적인 메모리 병목일명 폰 노이만 병목[1]과 연산 효율성 문제를 근본적인 아키텍처 레벨에서 해결해 보고 싶었기 때문이다.

2.1 TPU v6e

GPU가 그래픽 처리를 위해 다수의 코어를 집적한 방식이라면, TPUTensor Processing Unit는 딥러닝의 핵심인 행렬 곱Matrix Multiplication에 맞춰 설계된 ASIC주문형 반도체이다. 특히 최신 TPU v6eTrillium는 트랜스포머 아키텍처에 최적화된 설계를 보여준다.

  • 시스톨릭 배열Systolic Array[2]: 기존 CPU/GPU는 연산할 때마다 레지스터와 메모리를 오가며 데이터를 읽고 쓴다. 이 과정에서 메모리 속도가 연산 속도를 따라가지 못하는 병목이 발생한다. 반면 TPU는 수천 개의 연산 유닛ALU을 직접 연결하여, 데이터가 칩 내부를 효율적으로 흐르며 연산 되도록 설계했다. 중간 결과를 메모리에 쓰지 않고 다음 유닛으로 바로 넘기기 때문에 대규모 행렬 연산에서 압도적인 처리량Throughput을 확보할 수 있다.
  • 고대역폭 메모리HBM[3]로: LLM 학습은 종종 연산 속도가 아니라 메모리에서 데이터를 퍼 나르는 속도에 의해 결정된다. TPU v6e는 이전 세대 대비 비약적으로 상승한 HBM 대역폭을 제공하여, 수십억 파라미터 모델의 가중치와 데이터를 지연 없이 공급할 수 있다.
Cloud TPU v6e Trillium
그림 1: Cloud TPU v6e Trillium(Source: Google Blog)

최근 VSCode와 Colab의 TPU v6e 런타임, 그리고 JAX를 결합한 워크플로우를 테스트해 보았는데, 연산 효율 향상이 뚜렷했다.

VSCode의 새로운 Colab 확장 프로그램을 사용하면 로컬 VSCode에서 .ipynb 노트북을 실행하면서 Colab 백엔드에 연결할 수 있다. 이는 Codex나 Antigravity 같은 도구들을 노트북 워크플로우에 직접 통합할 수 있다는 의미이며, 개발 경험의 유연성을 높여준다.

Colab은 약 1년 전부터 TPU v6e를 지원해 왔으므로, Compute Unit이 있다면 꼭 시도해 볼 가치가 있다. 시간당 3 CU로 32GB의 HBM과 약 918 TFLOPs(bf16)의 피크 연산 능력을 사용할 수 있다.

Colab의 A100 런타임 유형이 PCIe 기반인지 SXM 기반인지 확실하지 않지만, 약 989 TFLOPs(bf16)를 제공하는 SXM 버전을 가정했을 때 비용 효율성은 꽤 차이가 난다. TPU v6e 런타임은 시간당 약 3.37 CU인 반면, A100 40GB 런타임은 시간당 약 5.37 CU이다. 비용 대비 연산량을 정규화해보면, TPU v6e는 Compute Unit 당 약 272 TFLOPs를 제공하는 반면, A100은 약 197 TFLOPs에 그친다.

즉, 비용 대비 성능 면에서 TPU v6e가 앞선다.

이러한 비교는 자주 언급되지 않는 부분이라 공유하고 싶었다. JAX로 Llama 4 스타일의 디코더 전용 모델을 밑바닥부터 사전 학습시키는 테스트를 진행하면서, Colab에서의 v6e 가치가 기대 이상임을 확인할 수 있었다.

2.2 JAX

그렇다면 왜 PyTorch가 아닌 JAX일까? JAX는 단순한 라이브러리가 아니라 Autograd자동 미분XLAAccelerated Linear Algebra 컴파일러의 결합체다. 구글 DeepMind가 AlphaFold나 Gemini 같은 거대 모델을 개발할 때 JAX를 사용하는 데는 이유가 있다.

  • XLA를 통한 극한의 최적화: PyTorchEager execution는 파이썬 코드를 한 줄씩 실행하며 GPU 커널을 호출한다. 이는 직관적이지만 오버헤드가 크다. 반면 JAX는 jitJust-In-Time 컴파일을 통해 전체 연산 그래프를 분석하고, 여러 연산을 하나의 커널로 합치는 커널 융합Kernel Fusion[4]을 수행한다. 이는 메모리 I/O를 획기적으로 줄여 TPU의 성능을 극대화할 수 있게 해 준다.
  • 함수형 프로그래밍과 병렬화: JAX는 numpy와 유사한 API를 제공하지만, 상태State를 관리하지 않는 순수 함수형 패러다임을 지향한다. 초반 학습 곡선은 가파르지만, vmap(자동 벡터화)이나 pmap(병렬화) 같은 변환 기능을 통해 단 한 줄의 코드로 멀티 디바이스 분산 학습을 구현할 수 있다는 장점이 있다.

3. 환경 설정

TPU 사용을 위해 jax[tpu]를 설치하고, 데이터 처리와 모델링에 필요한 라이브러리를 준비한다.

install.sh
!pip install -q "pydantic>=2" flax optax pydantic-settings datasets sentencepiece huggingface_hub !pip install -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
config.py
class ModelSettings(BaseSettings): d_model: int = 1024 n_heads: int = 8 n_layers: int = 8 d_ff: int = 2048 max_seq_len: int = 256 class TrainSettings(BaseSettings): seq_len: int = 256 batch_size: int = 256 num_epochs: int = 2 lr: float = 1e-4 seed: int = 0 tok_cfg = TokenizerSettings(character_coverage=0.9995, model_type="unigram") model_cfg = ModelSettings(max_seq_len=256) train_cfg = TrainSettings(seq_len=model_cfg.max_seq_len)

4. 데이터 엔지니어링

모델 학습 효율을 극대화하기 위해 데이터를 단순히 수집하는 것을 넘어 학습하기 가장 좋은 형태로 가공하는 프로세스를 구축해 보자.

4.1 데이터셋 구성

이번 프로젝트에서는 한국어 모델의 성능 확보를 위해 다음과 같은 공개 데이터셋을 조합하여 사용했다.

데이터셋출처특징 및 역할
HAERAE-HUB/KOREAN-WEBTEXT-1BHAERAE-HUB다양한 웹 문서로 구성되어 있어 일반적인 한국어 문법과 어휘 학습에 적합.
HAERAE-HUB/KOREAN-SyntheticText-1.5BHAERAE-HUBLLM이 생성한 합성 데이터로, 문장 전개와 추론 패턴을 보강하는 데 사용.
wikimedia/wikipedia (ko)Wikimedia정제된 지식과 사실 정보를 포함하고 있어 모델의 환각을 줄이고 지식 기반을 다지는 데 유용.
표 1: 학습에 사용한 한국어 데이터셋 구성

4.2 No-Mix 청킹 전략

텍스트를 단순히 max_seq_len 길이로 잘라 이어 붙이면 하나의 시퀀스Attention Window 안에 서로 다른 주제의 문서가 섞이게 된다. 예를 들어 정치 뉴스의 끝부분과 전혀 다른 문맥이 연결되는 식이다. 모델은 이 두 문맥 사이의 관계를 학습하려 할 것이고, 결과적으로 환각Hallucination이 발생하거나 논리적 일관성이 떨어진다.

우리는 No-Mix 전략을 통해 이를 방지할 수 있다.

  1. 하나의 시퀀스에는 단 하나의 문서만 포함한다.
  2. 문서가 끝나면 시퀀스를 종료하고 패딩Padding을 채운다.

이는 학습 데이터의 낭비처럼 보일 수 있지만, 모델의 수렴 속도와 품질 측면에서 더 안정적인 결과를 준다.

data_loader.py
def load_hf_text_column( dataset_name: str, *, subset: str | None = None, split: str = "train", text_col: str = "text", token: str | None = None, data_dir: str | None = None, ) -> list[str]: if token is None: token = os.getenv("HF_TOKEN") if subset is not None: ds = load_dataset( dataset_name, subset, split=split, data_dir=data_dir, token=token, ) else: ds = load_dataset( dataset_name, split=split, data_dir=data_dir, token=token, ) texts: list[str] = [] for row in ds: text = str(row[text_col]).strip() if text: texts.append(text) return texts def load_dialog_csv(path: str) -> list[str]: df = pd.read_csv(path) if "req" not in df.columns or "res" not in df.columns: raise ValueError("CSV must contain 'req' and 'res' columns.") df = df[["req", "res"]].fillna("") dialogs: list[str] = [] for req, res in zip(df["req"], df["res"]): req_str = str(req).strip() res_str = str(res).strip() if not req_str and not res_str: continue text = f"{req_str}\n{res_str}\n" dialogs.append(text) return dialogs def chunk_single_text( text: str, max_chars: int = 2000, min_chars: int = 200, ) -> Iterator[str]: text = text.replace("\r\n", "\n").strip() if not text: return if len(text) <= max_chars: yield text return lines = text.split("\n") buffer: list[str] = [] length = 0 for line in lines: line_text = line while True: if length == 0 and len(line_text) > max_chars: seg = line_text[:max_chars] seg = seg.strip() if seg: yield seg line_text = line_text[max_chars:] if not line_text: break continue piece_len = len(line_text) + (1 if length > 0 else 0) if length + piece_len > max_chars: if buffer: chunk = "\n".join(buffer).strip() if chunk: yield chunk buffer = [] length = 0 else: seg = line_text[:max_chars] seg = seg.strip() if seg: yield seg line_text = line_text[max_chars:] if not line_text: break continue else: if length == 0: buffer = [line_text] length = len(line_text) else: buffer.append(line_text) length += len(line_text) + 1 break if buffer: chunk = "\n".join(buffer).strip() if chunk: yield chunk def chunk_texts_no_mix( texts: Iterable[str], max_chars: int = 2000, min_chars: int = 200, ) -> Iterator[str]: for t in texts: yield from chunk_single_text(t, max_chars=max_chars, min_chars=min_chars) def write_corpus(chunks: Iterable[str], path: str) -> None: with open(path, "w", encoding="utf-8") as f: for text in chunks: normalized = text.replace("\r\n", "\n").strip() if not normalized: continue f.write(normalized + "\n\n") webtext = load_hf_text_column( "HAERAE-HUB/KOREAN-WEBTEXT-1B", split="train", text_col="text", ) synthetictext = load_hf_text_column( "HAERAE-HUB/KOREAN-SyntheticText-1.5B", split="train", text_col="text", ) wikitext = load_hf_text_column( "wikimedia/wikipedia", subset="20231101.ko", split="train", text_col="text", ) all_texts: list[str] = [] all_texts.extend(webtext) # all_texts.extend(synthetictext) # all_texts.extend(wikitext) # 학습 시간을 고려하여 우선 webtext만 사용하여 학습을 진행한다. # 만약 의도대로 잘 동작한다면 synthetictext나 wikitext 같은 다른 데이터를 추가로 사용해도 좋다. # 여기서 사용되는 데이터는 의미적 정보를 풍부하게 담고 있으며, 위키 콘텐츠 등으로 학습되었고 셔플된다는 점이 중요하다. rng = np.random.default_rng(seed=0) perm = rng.permutation(len(all_texts)) shuffled_texts = [all_texts[i] for i in perm] corpus_chunks: list[str] = list( chunk_texts_no_mix( shuffled_texts, max_chars=2000, min_chars=200, ) ) write_corpus(corpus_chunks, "corpus.txt") print(len(corpus_chunks), corpus_chunks[0][:200])

4.3 데이터 분포 확인

Corpus chunk length histogram
그림 2: 말뭉치 청크 길이 히스토그램
analysis.py
lengths = np.fromiter((len(s) for s in corpus_chunks), dtype=np.int32) print("num chunks:", len(lengths)) print("min len :", lengths.min()) print("max len :", lengths.max()) print("mean len :", lengths.mean()) plt.figure(figsize=(8, 4)) plt.hist(lengths, bins=50) plt.xlabel("chunk length (chars)") plt.ylabel("count") plt.title("Corpus chunk length histogram") plt.grid(True) plt.show()
tokenizer.py
class SentencePieceTokenizer: def __init__(self, cfg: TokenizerConfig): self.cfg = cfg self.sp = None @property def pad_id(self) -> int: return self.cfg.pad_id @property def bos_id(self) -> int: return self.cfg.bos_id @property def eos_id(self) -> int: return self.cfg.eos_id @property def unk_id(self) -> int: return self.cfg.unk_id @property def vocab_size(self) -> int: if self.sp is None: return 0 return int(self.sp.GetPieceSize()) def train(self, corpus_path: str) -> None: spm.SentencePieceTrainer.Train( input=corpus_path, model_prefix=self.cfg.model_prefix, vocab_size=self.cfg.vocab_size, model_type=self.cfg.model_type, character_coverage=self.cfg.character_coverage, input_sentence_size=self.cfg.input_sentence_size, shuffle_input_sentence=self.cfg.shuffle_input_sentence, train_extremely_large_corpus=self.cfg.train_extremely_large_corpus, pad_id=self.cfg.pad_id, bos_id=self.cfg.bos_id, eos_id=self.cfg.eos_id, unk_id=self.cfg.unk_id, pad_piece="<pad>", bos_piece="<bos>", eos_piece="<eos>", unk_piece="<unk>", ) self.sp = spm.SentencePieceProcessor() self.sp.Load(f"{self.cfg.model_prefix}.model") def load(self) -> None: self.sp = spm.SentencePieceProcessor() self.sp.Load(f"{self.cfg.model_prefix}.model") def encode(self, text: str, add_special: bool = True) -> list[int]: if self.sp is None: raise RuntimeError("SentencePiece model is not loaded.") ids = list(self.sp.EncodeAsIds(text)) if add_special: return [self.bos_id] + ids + [self.eos_id] return ids def encode_fixed(self, text: str, seq_len: int) -> list[int]: ids = self.encode(text, add_special=True) if len(ids) >= seq_len: return ids[:seq_len] pad_len = seq_len - len(ids) return ids + [self.pad_id] * pad_len def decode(self, ids: Iterable[int], skip_special: bool = True) -> str: if self.sp is None: raise RuntimeError("SentencePiece model is not loaded.") filtered: list[int] = [] for i in ids: if skip_special and i in (self.pad_id, self.bos_id, self.eos_id): continue filtered.append(int(i)) return self.sp.DecodeIds(filtered) def iter_corpus(paths: list[str]) -> Iterator[str]: for path in paths: with open(path, "r", encoding="utf-8") as f: for line in f: text = line.strip() if text: yield text def count_samples(paths: list[str]) -> int: return sum(1 for _ in iter_corpus(paths)) def encode_and_save_memmap_streaming( paths: list[str], tokenizer: SentencePieceTokenizer, seq_len: int, out_path: str, ) -> tuple[np.memmap, int]: n_samples = count_samples(paths) arr = np.memmap(out_path, mode="w+", dtype=np.int32, shape=(n_samples, seq_len)) i = 0 for text in iter_corpus(paths): ids = tokenizer.encode_fixed(text, seq_len) arr[i, :] = np.asarray(ids, dtype=np.int32) if (i + 1) % 10000 == 0: print(f"encoded {i + 1}/{n_samples}") i += 1 arr.flush() return arr, n_samples tokenizer = SentencePieceTokenizer(cfg=tok_cfg) tokenizer.train("corpus.txt") seq_len = model_cfg.max_seq_len corpus_paths = ["corpus.txt"] encoded_mm, n_samples = encode_and_save_memmap_streaming( corpus_paths, tokenizer, seq_len, out_path="encoded_int32.dat", ) vocab_size = tokenizer.vocab_size print(f"n_samples={n_samples}, vocab_size={vocab_size}")

5. 모델 아키텍처

우리는 2017년의 Attention Is All You Need 논문이 아닌, Meta가 공개한 오픈형 LLM인 Llama 아키텍처를 참고하여 구성하였다.

Pre-training사전 학습은 모델이 다음 토큰 예측Next Token Prediction 작업으로 언어의 통계적 구조와 지식을 습득하는 과정이다. 이 과정에서 모델은 입력된 문맥을 바탕으로 뒤에 올 단어의 확률 분포를 학습한다.

Llama-style Decoder-only Architecture Diagram
그림 3: Llama 스타일 Decoder-only 아키텍처 다이어그램

5.1 RMSNorm

기존 LayerNorm은 평균과 분산을 모두 계산하여 정규화한다. 하지만 최근 연구들은 정규화의 핵심이 평균 맞추기Centering가 아니라 스케일 조정Scaling에 있음을 밝혀냈다. 참고 링크

RMSNormRoot Mean Square Normalization은 평균 계산을 제거하여 연산 효율을 높인다.

aˉi=a_iRMS(a)g_i,where RMS(a)=1ni=1na_i2\bar{a}*i = \frac{a\_i}{\text{RMS}(a)} g\_i, \quad \text{where} \ \text{RMS}(a) = \sqrt{\frac{1}{n} \sum*{i=1}^{n} a\_i^2}

JAX 구현에서는 rsqrtReciprocal Square Root를 사용하여 나눗셈을 곱셈으로 대체, 속도를 더욱 최적화할 수 있다.

model.py
class RMSNorm(nn.Module): eps: float = 1e-6 def setup(self) -> None: self.scale = self.param("scale", nn.initializers.ones, (1,)) def __call__(self, x: jnp.ndarray) -> jnp.ndarray: var = jnp.mean(x * x, axis=-1, keepdims=True) norm_x = x * jax.lax.rsqrt(var + self.eps) return norm_x * self.scale

5.2 Decoder-only Self-Attention과 Causal Masking

Decoder-only Transformer의 핵심은 Self-Attention과 Causal Masking이다.

  • Self-Attention: 시퀀스 내의 모든 토큰이 서로를 참조(Query-Key 내적)하여 문맥적 연관성을 계산한다. 이를 통해 대명사나 지칭어가 가리키는 대상을 추적하고, 문장 전체의 의미를 계산한다.
  • Causal Masking: 학습 시 모델이 미래의 정답Next Token을 미리 컨닝하는 것을 방지하기 위해 사용된다. T×TT \times T 어텐션 행렬의 상삼각Upper Triangular 부분을 마스킹Masking하여, 현재 시점 tt에서는 tt 이전의 토큰들만 볼 수 있게 강제한다.
Scaled Dot Product Attention
그림 4: Scaled Dot Product Attention 메커니즘

특히 Decoder-only 구조에서는 Masked Multi-Head Attention이 필수적이다. 아래 그림과 같이 행렬의 대각선 윗부분(미래 시점)을 -\infty로 마스킹하면, Softmax를 거친 후 확률이 0이 되어 정보가 차단된다. 이를 통해 모델은 오직 과거의 정보만을 바탕으로 다음 단어를 예측하도록 훈련된다.

Masked Tokens (Triangular Matrix)
그림 5: Causal Masking을 위한 삼각 행렬 구조
model.py
class MultiHeadSelfAttention(nn.Module): d_model: int n_heads: int def setup(self) -> None: self.qkv_proj = nn.Dense(3 * self.d_model, use_bias=False) self.out_proj = nn.Dense(self.d_model, use_bias=False) def __call__(self, x: jnp.ndarray) -> jnp.ndarray: b, t, d = x.shape h = self.n_heads d_head = d // h qkv = self.qkv_proj(x) q, k, v = jnp.split(qkv, 3, axis=-1) q = q.reshape(b, t, h, d_head) k = k.reshape(b, t, h, d_head) v = v.reshape(b, t, h, d_head) attn_logits = jnp.einsum("bthd,bThd->bhtT", q, k) / jnp.sqrt(jnp.float32(d_head)) mask = jnp.tril(jnp.ones((t, t), dtype=jnp.bool_)) mask = jnp.reshape(mask, (1, 1, t, t)) attn_logits = jnp.where(mask, attn_logits, -1e9) attn = nn.softmax(attn_logits, axis=-1) out = jnp.einsum("bhtT,bThd->bthd", attn, v) out = out.reshape(b, t, d) out = self.out_proj(out) return out def decode( self, x: jnp.ndarray, *, cache: dict[str, jnp.ndarray], cache_index: jnp.ndarray, ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]: b, t, d = x.shape h = self.n_heads d_head = d // h qkv = self.qkv_proj(x) q, k, v = jnp.split(qkv, 3, axis=-1) q = q.reshape(b, t, h, d_head) k = k.reshape(b, t, h, d_head) v = v.reshape(b, t, h, d_head) k_cache = cache["k"] v_cache = cache["v"] k_cache = jax.lax.dynamic_update_slice(k_cache, k, (0, cache_index, 0, 0)) v_cache = jax.lax.dynamic_update_slice(v_cache, v, (0, cache_index, 0, 0)) total_len = cache_index + t k_used = k_cache[:, :total_len, :, :] v_used = v_cache[:, :total_len, :, :] attn_logits = jnp.einsum("bthd,bThd->bhtT", q, k_used) / jnp.sqrt(jnp.float32(d_head)) attn = nn.softmax(attn_logits, axis=-1) out = jnp.einsum("bhtT,bThd->bthd", attn, v_used) out = out.reshape(b, t, d) out = self.out_proj(out) new_cache = {"k": k_cache, "v": v_cache} return out, new_cache

5.3 SwiGLU

Llama는 FFNFeed Forward NetworkSwiGLUSiLU Gated Linear Unit를 도입했다. 이는 단순한 활성화 함수ReLU가 아니라, 정보의 흐름을 제어하는 게이트Gate[5] 역할을 한다. 참고 링크

SwiGLU(x)=(SiLU(xW_G)xW_Up)W_Down\text{SwiGLU}(x) = (\text{SiLU}(xW\_G) \odot xW\_{Up})W\_{Down}

게이트 경로(xW_GxW\_G)의 값이 작으면 정보(xW_UpxW\_{Up})를 차단하고, 크면 통과시킨다. 이는 모델이 어떤 정보를 기억하고 어떤 정보를 망각할지 학습하는 능력을 부여한다.

model.py
class FeedForward(nn.Module): d_model: int d_ff: int def setup(self) -> None: self.gate = nn.Dense(self.d_ff) self.up = nn.Dense(self.d_ff) self.down = nn.Dense(self.d_model) def __call__(self, x: jnp.ndarray) -> jnp.ndarray: gate = self.gate(x) up = self.up(x) x = nn.silu(gate) * up x = self.down(x) return x

5.4 블록 조립

마지막으로 이 부품들을 조립한다. 여기서 Pre-Norm 구조(Norm \to Attention \to Add)를 사용하는 것이 중요하다. 이는 깊은 신경망에서 그래디언트 소실을 방지하고 학습 초기 안정성을 획기적으로 개선한다.

model.py
class DecoderBlock(nn.Module): d_model: int n_heads: int d_ff: int def setup(self) -> None: self.norm1 = RMSNorm() self.norm2 = RMSNorm() self.attn = MultiHeadSelfAttention(self.d_model, self.n_heads) self.ff = FeedForward(self.d_model, self.d_ff) def __call__(self, x: jnp.ndarray) -> jnp.ndarray: h = self.norm1(x) h = self.attn(h) x = x + h h = self.norm2(x) h = self.ff(h) x = x + h return x def decode( self, x: jnp.ndarray, *, cache: dict[str, jnp.ndarray], cache_index: jnp.ndarray, ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]: h = self.norm1(x) h, new_cache = self.attn.decode(h, cache=cache, cache_index=cache_index) x = x + h h = self.norm2(x) h = self.ff(h) x = x + h return x, new_cache class DecoderTransformer(nn.Module): vocab_size: int d_model: int n_heads: int d_ff: int n_layers: int max_seq_len: int def setup(self) -> None: self.tok_embed = nn.Embed(num_embeddings=self.vocab_size, features=self.d_model) self.pos_embed = nn.Embed(num_embeddings=self.max_seq_len, features=self.d_model) self.blocks = [ DecoderBlock(self.d_model, self.n_heads, self.d_ff) for _ in range(self.n_layers) ] self.norm = RMSNorm() self.head = nn.Dense(self.vocab_size) def __call__(self, x: jnp.ndarray) -> jnp.ndarray: _, t = x.shape pos = jnp.arange(t)[None, :] h = self.tok_embed(x) + self.pos_embed(pos) for blk in self.blocks: h = blk(h) h = self.norm(h) logits = self.head(h) return logits def decode( self, x: jnp.ndarray, *, cache: list[dict[str, jnp.ndarray]], cache_index: jnp.ndarray, ) -> tuple[jnp.ndarray, list[dict[str, jnp.ndarray]]]: _, t = x.shape pos = (cache_index + jnp.arange(t))[None, :] h = self.tok_embed(x) + self.pos_embed(pos) new_caches: list[dict[str, jnp.ndarray]] = [] for i, blk in enumerate(self.blocks): h, layer_cache = blk.decode(h, cache=cache[i], cache_index=cache_index) new_caches.append(layer_cache) h = self.norm(h) logits = self.head(h) return logits, new_caches def init_cache(self, batch_size: int) -> list[dict[str, jnp.ndarray]]: d_head = self.d_model // self.n_heads k = jnp.zeros((batch_size, self.max_seq_len, self.n_heads, d_head), dtype=jnp.float32) v = jnp.zeros_like(k) return [{"k": k, "v": v} for _ in range(self.n_layers)]

6. 학습

JAX 학습의 핵심은 @jax.jit이다. 파이썬 함수 train_step은 최초 실행 시 추적되어 XLA 그래프로 변환되고, TPU 기계어로 컴파일된다. 이 과정에서 연산들이 융합Fusion되어 메모리 병목을 제거한다.

train.py
model = DecoderTransformer( vocab_size=vocab_size, d_model=model_cfg.d_model, n_heads=model_cfg.n_heads, d_ff=model_cfg.d_ff, n_layers=model_cfg.n_layers, max_seq_len=model_cfg.max_seq_len, ) key = jax.random.PRNGKey(train_cfg.seed) dummy_x = jnp.ones((1, model_cfg.max_seq_len), dtype=jnp.int32) params = model.init(key, dummy_x) tx = optax.adamw(learning_rate=train_cfg.lr) state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx) def loss_fn(params: dict, batch: jnp.ndarray) -> jnp.ndarray: x = batch[:, :-1] y = batch[:, 1:] logits = model.apply(params, x) loss = optax.softmax_cross_entropy_with_integer_labels(logits, y) return loss.mean() @jax.jit def train_step(state: train_state.TrainState, batch: jnp.ndarray) -> train_state.TrainState: grads = jax.grad(loss_fn)(state.params, batch) state = state.apply_gradients(grads=grads) return state encoded_mm = np.memmap( "encoded_int32.dat", mode="r", dtype=np.int32, shape=(n_samples, seq_len), ) batch_size: int = train_cfg.batch_size steps_per_epoch: int = n_samples // batch_size if steps_per_epoch == 0: raise ValueError("batch_size is larger than dataset size") num_epochs: int = train_cfg.num_epochs total_steps: int = steps_per_epoch * num_epochs log_every: int = 100 rng = np.random.default_rng(train_cfg.seed) global_step: int = 0 for epoch in range(num_epochs): indices = rng.permutation(n_samples) for step_in_epoch in range(steps_per_epoch): start = step_in_epoch * batch_size end = start + batch_size idx = indices[start:end] batch_np = encoded_mm[idx] batch = jnp.array(batch_np, dtype=jnp.int32) state = train_step(state, batch) global_step += 1 if global_step % log_every == 0 or ( epoch == num_epochs - 1 and step_in_epoch == steps_per_epoch - 1 ): loss_val = float(loss_fn(state.params, batch)) epoch_idx = epoch + 1 step_idx = step_in_epoch + 1 global_progress = global_step / total_steps print( f"[epoch {epoch_idx}/{num_epochs} | " f"step {step_idx}/{steps_per_epoch} | " f"global {global_step}/{total_steps} " f"({global_progress * 100:.1f}%)] " f"loss={loss_val:.4f}" )

7. 결과 및 인사이트

학습된 모델로 텍스트를 생성할 때, 매번 전체 시퀀스를 다시 계산하는 것은 비효율적이다(O(N2)O(N^2)). KV Cache는 이전 시점의 Key/Value를 저장해두고 재사용하여 복잡도를 O(N)O(N)으로 낮추는 핵심 최적화 기법이다.

생성 과정Inference에서는 토큰을 하나씩 생성한다. 100번째 토큰을 생성할 때, 앞선 99개 토큰에 대한 Key와 Value 값은 이미 계산되어 있다. 이를 다시 계산하는 대신 GPU/TPU 메모리Cache에 저장해 두었다가, 100번째 토큰에 대한 Query와 저장된 Key/Value 만을 연산하면 된다. 이렇게 하면 시퀀스 길이가 길어져도 연산량이 선형적으로만 증가하여 빠른 생성이 가능하다. JAX에서는 jax.lax.dynamic_update_slice를 사용하여 이 캐시를 효율적으로 업데이트할 수 있다.

generate.py
from functools import partial @partial(jax.jit, static_argnums=0) def decode_step( model: DecoderTransformer, params, x: jnp.ndarray, cache: list[dict[str, jnp.ndarray]], cache_index: jnp.ndarray, ) -> tuple[jnp.ndarray, list[dict[str, jnp.ndarray]]]: logits, new_cache = model.apply( params, x, cache=cache, cache_index=cache_index, method=DecoderTransformer.decode, ) return logits, new_cache def encode_prompt(text: str) -> list[int]: base_ids = tokenizer.encode(text, add_special=False) return [tokenizer.bos_id] + base_ids def generate(prompt: str, max_new_tokens: int = 128) -> str: ids: list[int] = encode_prompt(prompt) if not ids: ids = [tokenizer.bos_id] batch_size = 1 cache = model.init_cache(batch_size) cache_index = 0 for tok in ids: x = jnp.array([[tok]], dtype=jnp.int32) logits, cache = decode_step( model, state.params, x, cache, jnp.array(cache_index, dtype=jnp.int32), ) cache_index += 1 if cache_index >= model.max_seq_len: break for _ in range(max_new_tokens): next_id = int(jnp.argmax(logits[0, -1])) ids.append(next_id) if next_id == tokenizer.eos_id or cache_index >= model.max_seq_len: break x = jnp.array([[next_id]], dtype=jnp.int32) logits, cache = decode_step( model, state.params, x, cache, jnp.array(cache_index, dtype=jnp.int32), ) cache_index += 1 return tokenizer.decode(ids) print(generate("너 좋아하는 차 종류 있어?"))

비록 Tiny 모델[6]이지만, 한국어 문법 구조를 파악하고 간단한 대화 맥락을 따라가는 모습을 확인할 수 있다.

8. 마치며

본 프로젝트를 통해 우리는 JAX와 TPU를 활용하여 LLM 학습 파이프라인의 전 과정을 엔지니어링 관점에서 재구성해 보았다.

  • JAX와 XLA의 효용성: PyTorch의 Eager Execution 대비, XLA 컴파일을 통한 커널 융합Kernel Fusion이 메모리 대역폭 의존도가 높은 LLM 학습에서 장치 활용률MFU을 실질적으로 높이는 것을 확인했다.
  • Modern Architecture의 이점: RMSNorm의 연산 간소화와 SwiGLU의 게이팅 메커니즘이 학습 초기 발산을 억제하고 수렴 속도를 개선하는 데 기여함을 검증했다.
  • 시스템 레벨 최적화: 대규모 텍스트 데이터를 메모리에 모두 올리지 않고, OS의 페이지 캐시와 Memmap을 활용하여 I/O 병목을 최소화하는 파이프라인 구축의 중요성을 확인했다.

이 Tiny LLM 구현체는 거대 모델의 내부를 이해하고, 나아가 새로운 아키텍처를 실험하기 위한 견고한 토대가 될 것이다.


각주


  • 1: 폰 노이만 병목Von Neumann Bottleneck: CPU와 메모리 간의 데이터 전송 속도가 연산 속도를 따라가지 못해 전체 시스템 성능이 저하되는 현상. [↩︎]
  • 2: 시스톨릭 배열Systolic Array: 데이터가 연산 유닛PE 배열을 통과하며 연속적으로 연산되는 구조. 심장 박동Systolic처럼 데이터가 흐른다고 하여 붙여진 이름. [↩︎]
  • 3: HBMHigh Bandwidth Memory: 3D 적층 기술을 이용해 대역폭을 획기적으로 높인 고성능 메모리. [↩︎]
  • 4: 커널 융합Kernel Fusion: 여러 개의 GPU/TPU 연산 커널을 하나로 합쳐 메모리 접근 횟수를 줄이는 최적화 기법. [↩︎]
  • 5: 게이트: 입력마다 요소 단위 가중치를 곱해 유용한 정보만 통과시키고 불필요한 값은 억제하도록 학습되는 경로를 의미한다. [↩︎]
  • 6: tiny: 수천만~1억 파라미터 규모의 실험용 모델을 가리키며, 아키텍처 검증과 학습 파이프라인 점검에 초점을 둔다. [↩︎]

추천 아티클