LLM, Transformer 업무에 종사해보거나 공부해본 AI 작업자라면 AutoModel.from_pretrained를 사용해본 적이 있을 것이다. 사전 학습된 모델을 불러오는 허깅페이스 구현체인데, 과연 이 뒤에는 무엇이 있을까? 이 글은 원시 텍스트 데이터가 디스크에서 읽혀 토크나이징 되고, TPU라는 하드웨어를 거쳐 의미를 가진 문장으로 재탄생하기까지의 전 과정을 해부한다. 최신 Llama 모델의 설계를 JAX로 직접 구현하며 모델의 사용자에서 모델의 설계자로 거듭나는 과정을 함께해 보자.
전체 코드는 오픈소스로 공개되어 있다. 아래 GitHub 프리뷰에서 저장소를 확인할 수 있다.
HuggingFace와 같은 고수준 라이브러리 덕분에 누구나 몇 줄의 코드로 GPT-4급 모델을 호출하고 파인튜닝할 수 있다. 하지만 편리함은 때로 엔지니어의 눈을 가린다. 모델이 왜 예상치 못한 결과를 내놓는지, 왜 특정 배치 사이즈에서 OOMOut of Memory이 발생하는지, 왜 학습 손실Loss이 줄지 않는지 이해하려면 결국 블랙박스를 열어봐야 한다.
이 프로젝트는 거창한 SOTAState-of-the-art 모델을 목표로 하지 않는다. 대신 다음과 같은 엔지니어링 호기심을 밑바닥부터 해결하는 데 집중한다.
Framework Transition: 구글과 DeepMind가 왜 PyTorch 대신 JAX를 쓰는지, 함수형 프로그래밍 기반 병렬화의 이점을 직접 체감해 본다.
Modern Architecture: 2017년 Transformer와 2024년 Llama 3의 차이를 구현을 통해 살펴보고 RMSNorm과 SwiGLU가 주는 안정성·표현력 개선 효과를 확인한다.
Data Engineering: RAM보다 큰 텍스트 데이터를 어떻게 학습시키는지, OS 가상 메모리 시스템을 활용한 고성능 I/O 파이프라인을 구축해 본다.
2. JAX와 TPU
딥러닝 프레임워크의 대세는 분명 PyTorch와 NVIDIA GPU다. 하지만 이번 프로젝트에서는 JAX와 Cloud TPU v6e를 선택했다. 이는 단순히 새로운 툴을 시도해 보려는 선택이 아니다. LLM 학습의 가장 큰 적인 메모리 병목일명 폰 노이만 병목[1]과 연산 효율성 문제를 근본적인 아키텍처 레벨에서 해결해 보고 싶었기 때문이다.
2.1 TPU v6e
GPU가 그래픽 처리를 위해 다수의 코어를 집적한 방식이라면, TPUTensor Processing Unit는 딥러닝의 핵심인 행렬 곱Matrix Multiplication에 맞춰 설계된 ASIC주문형 반도체이다. 특히 최신 TPU v6eTrillium는 트랜스포머 아키텍처에 최적화된 설계를 보여준다.
시스톨릭 배열Systolic Array[2]: 기존 CPU/GPU는 연산할 때마다 레지스터와 메모리를 오가며 데이터를 읽고 쓴다. 이 과정에서 메모리 속도가 연산 속도를 따라가지 못하는 병목이 발생한다. 반면 TPU는 수천 개의 연산 유닛ALU을 직접 연결하여, 데이터가 칩 내부를 효율적으로 흐르며 연산 되도록 설계했다. 중간 결과를 메모리에 쓰지 않고 다음 유닛으로 바로 넘기기 때문에 대규모 행렬 연산에서 압도적인 처리량Throughput을 확보할 수 있다.
고대역폭 메모리HBM[3]로: LLM 학습은 종종 연산 속도가 아니라 메모리에서 데이터를 퍼 나르는 속도에 의해 결정된다. TPU v6e는 이전 세대 대비 비약적으로 상승한 HBM 대역폭을 제공하여, 수십억 파라미터 모델의 가중치와 데이터를 지연 없이 공급할 수 있다.
최근 VSCode와 Colab의 TPU v6e 런타임, 그리고 JAX를 결합한 워크플로우를 테스트해 보았는데, 연산 효율 향상이 뚜렷했다.
VSCode의 새로운 Colab 확장 프로그램을 사용하면 로컬 VSCode에서 .ipynb 노트북을 실행하면서 Colab 백엔드에 연결할 수 있다. 이는 Codex나 Antigravity 같은 도구들을 노트북 워크플로우에 직접 통합할 수 있다는 의미이며, 개발 경험의 유연성을 높여준다.
Colab은 약 1년 전부터 TPU v6e를 지원해 왔으므로, Compute Unit이 있다면 꼭 시도해 볼 가치가 있다. 시간당 3 CU로 32GB의 HBM과 약 918 TFLOPs(bf16)의 피크 연산 능력을 사용할 수 있다.
Colab의 A100 런타임 유형이 PCIe 기반인지 SXM 기반인지 확실하지 않지만, 약 989 TFLOPs(bf16)를 제공하는 SXM 버전을 가정했을 때 비용 효율성은 꽤 차이가 난다.
TPU v6e 런타임은 시간당 약 3.37 CU인 반면, A100 40GB 런타임은 시간당 약 5.37 CU이다. 비용 대비 연산량을 정규화해보면, TPU v6e는 Compute Unit 당 약 272 TFLOPs를 제공하는 반면, A100은 약 197 TFLOPs에 그친다.
즉, 비용 대비 성능 면에서 TPU v6e가 앞선다.
이러한 비교는 자주 언급되지 않는 부분이라 공유하고 싶었다. JAX로 Llama 4 스타일의 디코더 전용 모델을 밑바닥부터 사전 학습시키는 테스트를 진행하면서, Colab에서의 v6e 가치가 기대 이상임을 확인할 수 있었다.
2.2 JAX
그렇다면 왜 PyTorch가 아닌 JAX일까? JAX는 단순한 라이브러리가 아니라 Autograd자동 미분와 XLAAccelerated Linear Algebra 컴파일러의 결합체다. 구글 DeepMind가 AlphaFold나 Gemini 같은 거대 모델을 개발할 때 JAX를 사용하는 데는 이유가 있다.
XLA를 통한 극한의 최적화: PyTorchEager execution는 파이썬 코드를 한 줄씩 실행하며 GPU 커널을 호출한다. 이는 직관적이지만 오버헤드가 크다. 반면 JAX는 jitJust-In-Time 컴파일을 통해 전체 연산 그래프를 분석하고, 여러 연산을 하나의 커널로 합치는 커널 융합Kernel Fusion[4]을 수행한다. 이는 메모리 I/O를 획기적으로 줄여 TPU의 성능을 극대화할 수 있게 해 준다.
함수형 프로그래밍과 병렬화: JAX는 numpy와 유사한 API를 제공하지만, 상태State를 관리하지 않는 순수 함수형 패러다임을 지향한다. 초반 학습 곡선은 가파르지만, vmap(자동 벡터화)이나 pmap(병렬화) 같은 변환 기능을 통해 단 한 줄의 코드로 멀티 디바이스 분산 학습을 구현할 수 있다는 장점이 있다.
3. 환경 설정
TPU 사용을 위해 jax[tpu]를 설치하고, 데이터 처리와 모델링에 필요한 라이브러리를 준비한다.
정제된 지식과 사실 정보를 포함하고 있어 모델의 환각을 줄이고 지식 기반을 다지는 데 유용.
표 1: 학습에 사용한 한국어 데이터셋 구성
4.2 No-Mix 청킹 전략
텍스트를 단순히 max_seq_len 길이로 잘라 이어 붙이면 하나의 시퀀스Attention Window 안에 서로 다른 주제의 문서가 섞이게 된다. 예를 들어 정치 뉴스의 끝부분과 전혀 다른 문맥이 연결되는 식이다. 모델은 이 두 문맥 사이의 관계를 학습하려 할 것이고, 결과적으로 환각Hallucination이 발생하거나 논리적 일관성이 떨어진다.
우리는 No-Mix 전략을 통해 이를 방지할 수 있다.
하나의 시퀀스에는 단 하나의 문서만 포함한다.
문서가 끝나면 시퀀스를 종료하고 패딩Padding을 채운다.
이는 학습 데이터의 낭비처럼 보일 수 있지만, 모델의 수렴 속도와 품질 측면에서 더 안정적인 결과를 준다.
data_loader.py
defload_hf_text_column( dataset_name:str,*, subset:str|None=None, split:str="train", text_col:str="text", token:str|None=None, data_dir:str|None=None,)->list[str]:if token isNone: token = os.getenv("HF_TOKEN")if subset isnotNone: ds = load_dataset( dataset_name, subset, split=split, data_dir=data_dir, token=token,)else: ds = load_dataset( dataset_name, split=split, data_dir=data_dir, token=token,) texts:list[str]=[]for row in ds: text =str(row[text_col]).strip()if text: texts.append(text)return texts
defload_dialog_csv(path:str)->list[str]: df = pd.read_csv(path)if"req"notin df.columns or"res"notin df.columns:raise ValueError("CSV must contain 'req' and 'res' columns.") df = df[["req","res"]].fillna("") dialogs:list[str]=[]for req, res inzip(df["req"], df["res"]): req_str =str(req).strip() res_str =str(res).strip()ifnot req_str andnot res_str:continue text =f"{req_str}\n{res_str}\n" dialogs.append(text)return dialogs
defchunk_single_text( text:str, max_chars:int=2000, min_chars:int=200,)-> Iterator[str]: text = text.replace("\r\n","\n").strip()ifnot text:returniflen(text)<= max_chars:yield text
return lines = text.split("\n")buffer:list[str]=[] length =0for line in lines: line_text = line
whileTrue:if length ==0andlen(line_text)> max_chars: seg = line_text[:max_chars] seg = seg.strip()if seg:yield seg
line_text = line_text[max_chars:]ifnot line_text:breakcontinue piece_len =len(line_text)+(1if length >0else0)if length + piece_len > max_chars:ifbuffer: chunk ="\n".join(buffer).strip()if chunk:yield chunk
buffer=[] length =0else: seg = line_text[:max_chars] seg = seg.strip()if seg:yield seg
line_text = line_text[max_chars:]ifnot line_text:breakcontinueelse:if length ==0:buffer=[line_text] length =len(line_text)else:buffer.append(line_text) length +=len(line_text)+1breakifbuffer: chunk ="\n".join(buffer).strip()if chunk:yield chunk
defchunk_texts_no_mix( texts: Iterable[str], max_chars:int=2000, min_chars:int=200,)-> Iterator[str]:for t in texts:yieldfrom chunk_single_text(t, max_chars=max_chars, min_chars=min_chars)defwrite_corpus(chunks: Iterable[str], path:str)->None:withopen(path,"w", encoding="utf-8")as f:for text in chunks: normalized = text.replace("\r\n","\n").strip()ifnot normalized:continue f.write(normalized +"\n\n")webtext = load_hf_text_column("HAERAE-HUB/KOREAN-WEBTEXT-1B", split="train", text_col="text",)synthetictext = load_hf_text_column("HAERAE-HUB/KOREAN-SyntheticText-1.5B", split="train", text_col="text",)wikitext = load_hf_text_column("wikimedia/wikipedia", subset="20231101.ko", split="train", text_col="text",)all_texts:list[str]=[]all_texts.extend(webtext)# all_texts.extend(synthetictext)# all_texts.extend(wikitext)# 학습 시간을 고려하여 우선 webtext만 사용하여 학습을 진행한다.# 만약 의도대로 잘 동작한다면 synthetictext나 wikitext 같은 다른 데이터를 추가로 사용해도 좋다.# 여기서 사용되는 데이터는 의미적 정보를 풍부하게 담고 있으며, 위키 콘텐츠 등으로 학습되었고 셔플된다는 점이 중요하다.rng = np.random.default_rng(seed=0)perm = rng.permutation(len(all_texts))shuffled_texts =[all_texts[i]for i in perm]corpus_chunks:list[str]=list( chunk_texts_no_mix( shuffled_texts, max_chars=2000, min_chars=200,))write_corpus(corpus_chunks,"corpus.txt")print(len(corpus_chunks), corpus_chunks[0][:200])
4.3 데이터 분포 확인
그림 2: 말뭉치 청크 길이 히스토그램
analysis.py
lengths = np.fromiter((len(s)for s in corpus_chunks), dtype=np.int32)print("num chunks:",len(lengths))print("min len :", lengths.min())print("max len :", lengths.max())print("mean len :", lengths.mean())plt.figure(figsize=(8,4))plt.hist(lengths, bins=50)plt.xlabel("chunk length (chars)")plt.ylabel("count")plt.title("Corpus chunk length histogram")plt.grid(True)plt.show()
tokenizer.py
classSentencePieceTokenizer:def__init__(self, cfg: TokenizerConfig): self.cfg = cfg
self.sp =None@propertydefpad_id(self)->int:return self.cfg.pad_id
@propertydefbos_id(self)->int:return self.cfg.bos_id
@propertydefeos_id(self)->int:return self.cfg.eos_id
@propertydefunk_id(self)->int:return self.cfg.unk_id
@propertydefvocab_size(self)->int:if self.sp isNone:return0returnint(self.sp.GetPieceSize())deftrain(self, corpus_path:str)->None: spm.SentencePieceTrainer.Train(input=corpus_path, model_prefix=self.cfg.model_prefix, vocab_size=self.cfg.vocab_size, model_type=self.cfg.model_type, character_coverage=self.cfg.character_coverage, input_sentence_size=self.cfg.input_sentence_size, shuffle_input_sentence=self.cfg.shuffle_input_sentence, train_extremely_large_corpus=self.cfg.train_extremely_large_corpus, pad_id=self.cfg.pad_id, bos_id=self.cfg.bos_id, eos_id=self.cfg.eos_id, unk_id=self.cfg.unk_id, pad_piece="<pad>", bos_piece="<bos>", eos_piece="<eos>", unk_piece="<unk>",) self.sp = spm.SentencePieceProcessor() self.sp.Load(f"{self.cfg.model_prefix}.model")defload(self)->None: self.sp = spm.SentencePieceProcessor() self.sp.Load(f"{self.cfg.model_prefix}.model")defencode(self, text:str, add_special:bool=True)->list[int]:if self.sp isNone:raise RuntimeError("SentencePiece model is not loaded.") ids =list(self.sp.EncodeAsIds(text))if add_special:return[self.bos_id]+ ids +[self.eos_id]return ids
defencode_fixed(self, text:str, seq_len:int)->list[int]: ids = self.encode(text, add_special=True)iflen(ids)>= seq_len:return ids[:seq_len] pad_len = seq_len -len(ids)return ids +[self.pad_id]* pad_len
defdecode(self, ids: Iterable[int], skip_special:bool=True)->str:if self.sp isNone:raise RuntimeError("SentencePiece model is not loaded.") filtered:list[int]=[]for i in ids:if skip_special and i in(self.pad_id, self.bos_id, self.eos_id):continue filtered.append(int(i))return self.sp.DecodeIds(filtered)defiter_corpus(paths:list[str])-> Iterator[str]:for path in paths:withopen(path,"r", encoding="utf-8")as f:for line in f: text = line.strip()if text:yield text
defcount_samples(paths:list[str])->int:returnsum(1for _ in iter_corpus(paths))defencode_and_save_memmap_streaming( paths:list[str], tokenizer: SentencePieceTokenizer, seq_len:int, out_path:str,)->tuple[np.memmap,int]: n_samples = count_samples(paths) arr = np.memmap(out_path, mode="w+", dtype=np.int32, shape=(n_samples, seq_len)) i =0for text in iter_corpus(paths): ids = tokenizer.encode_fixed(text, seq_len) arr[i,:]= np.asarray(ids, dtype=np.int32)if(i +1)%10000==0:print(f"encoded {i +1}/{n_samples}") i +=1 arr.flush()return arr, n_samples
tokenizer = SentencePieceTokenizer(cfg=tok_cfg)tokenizer.train("corpus.txt")seq_len = model_cfg.max_seq_len
corpus_paths =["corpus.txt"]encoded_mm, n_samples = encode_and_save_memmap_streaming( corpus_paths, tokenizer, seq_len, out_path="encoded_int32.dat",)vocab_size = tokenizer.vocab_size
print(f"n_samples={n_samples}, vocab_size={vocab_size}")
5. 모델 아키텍처
우리는 2017년의 Attention Is All You Need 논문이 아닌, Meta가 공개한 오픈형 LLM인 Llama 아키텍처를 참고하여 구성하였다.
Pre-training사전 학습은 모델이 다음 토큰 예측Next Token Prediction 작업으로 언어의 통계적 구조와 지식을 습득하는 과정이다. 이 과정에서 모델은 입력된 문맥을 바탕으로 뒤에 올 단어의 확률 분포를 학습한다.
그림 3: Llama 스타일 Decoder-only 아키텍처 다이어그램
5.1 RMSNorm
기존 LayerNorm은 평균과 분산을 모두 계산하여 정규화한다. 하지만 최근 연구들은 정규화의 핵심이 평균 맞추기Centering가 아니라 스케일 조정Scaling에 있음을 밝혀냈다. 참고 링크
RMSNormRoot Mean Square Normalization은 평균 계산을 제거하여 연산 효율을 높인다.
aˉ∗i=RMS(a)a_ig_i,whereRMS(a)=n1∑∗i=1na_i2
JAX 구현에서는 rsqrtReciprocal Square Root를 사용하여 나눗셈을 곱셈으로 대체, 속도를 더욱 최적화할 수 있다.
Self-Attention: 시퀀스 내의 모든 토큰이 서로를 참조(Query-Key 내적)하여 문맥적 연관성을 계산한다. 이를 통해 대명사나 지칭어가 가리키는 대상을 추적하고, 문장 전체의 의미를 계산한다.
Causal Masking: 학습 시 모델이 미래의 정답Next Token을 미리 컨닝하는 것을 방지하기 위해 사용된다. T×T 어텐션 행렬의 상삼각Upper Triangular 부분을 마스킹Masking하여, 현재 시점 t에서는 t 이전의 토큰들만 볼 수 있게 강제한다.
그림 4: Scaled Dot Product Attention 메커니즘
특히 Decoder-only 구조에서는 Masked Multi-Head Attention이 필수적이다. 아래 그림과 같이 행렬의 대각선 윗부분(미래 시점)을 −∞로 마스킹하면, Softmax를 거친 후 확률이 0이 되어 정보가 차단된다. 이를 통해 모델은 오직 과거의 정보만을 바탕으로 다음 단어를 예측하도록 훈련된다.
그림 5: Causal Masking을 위한 삼각 행렬 구조
model.py
classMultiHeadSelfAttention(nn.Module): d_model:int n_heads:intdefsetup(self)->None: self.qkv_proj = nn.Dense(3* self.d_model, use_bias=False) self.out_proj = nn.Dense(self.d_model, use_bias=False)def__call__(self, x: jnp.ndarray)-> jnp.ndarray: b, t, d = x.shape
h = self.n_heads
d_head = d // h
qkv = self.qkv_proj(x) q, k, v = jnp.split(qkv,3, axis=-1) q = q.reshape(b, t, h, d_head) k = k.reshape(b, t, h, d_head) v = v.reshape(b, t, h, d_head) attn_logits = jnp.einsum("bthd,bThd->bhtT", q, k)/ jnp.sqrt(jnp.float32(d_head)) mask = jnp.tril(jnp.ones((t, t), dtype=jnp.bool_)) mask = jnp.reshape(mask,(1,1, t, t)) attn_logits = jnp.where(mask, attn_logits,-1e9) attn = nn.softmax(attn_logits, axis=-1) out = jnp.einsum("bhtT,bThd->bthd", attn, v) out = out.reshape(b, t, d) out = self.out_proj(out)return out
defdecode( self, x: jnp.ndarray,*, cache:dict[str, jnp.ndarray], cache_index: jnp.ndarray,)->tuple[jnp.ndarray,dict[str, jnp.ndarray]]: b, t, d = x.shape
h = self.n_heads
d_head = d // h
qkv = self.qkv_proj(x) q, k, v = jnp.split(qkv,3, axis=-1) q = q.reshape(b, t, h, d_head) k = k.reshape(b, t, h, d_head) v = v.reshape(b, t, h, d_head) k_cache = cache["k"] v_cache = cache["v"] k_cache = jax.lax.dynamic_update_slice(k_cache, k,(0, cache_index,0,0)) v_cache = jax.lax.dynamic_update_slice(v_cache, v,(0, cache_index,0,0)) total_len = cache_index + t
k_used = k_cache[:,:total_len,:,:] v_used = v_cache[:,:total_len,:,:] attn_logits = jnp.einsum("bthd,bThd->bhtT", q, k_used)/ jnp.sqrt(jnp.float32(d_head)) attn = nn.softmax(attn_logits, axis=-1) out = jnp.einsum("bhtT,bThd->bthd", attn, v_used) out = out.reshape(b, t, d) out = self.out_proj(out) new_cache ={"k": k_cache,"v": v_cache}return out, new_cache
5.3 SwiGLU
Llama는 FFNFeed Forward Network에 SwiGLUSiLU Gated Linear Unit를 도입했다. 이는 단순한 활성화 함수ReLU가 아니라, 정보의 흐름을 제어하는 게이트Gate[5] 역할을 한다. 참고 링크
SwiGLU(x)=(SiLU(xW_G)⊙xW_Up)W_Down
게이트 경로(xW_G)의 값이 작으면 정보(xW_Up)를 차단하고, 크면 통과시킨다. 이는 모델이 어떤 정보를 기억하고 어떤 정보를 망각할지 학습하는 능력을 부여한다.
model.py
classFeedForward(nn.Module): d_model:int d_ff:intdefsetup(self)->None: self.gate = nn.Dense(self.d_ff) self.up = nn.Dense(self.d_ff) self.down = nn.Dense(self.d_model)def__call__(self, x: jnp.ndarray)-> jnp.ndarray: gate = self.gate(x) up = self.up(x) x = nn.silu(gate)* up
x = self.down(x)return x
5.4 블록 조립
마지막으로 이 부품들을 조립한다. 여기서 Pre-Norm 구조(Norm → Attention → Add)를 사용하는 것이 중요하다. 이는 깊은 신경망에서 그래디언트 소실을 방지하고 학습 초기 안정성을 획기적으로 개선한다.
model.py
classDecoderBlock(nn.Module): d_model:int n_heads:int d_ff:intdefsetup(self)->None: self.norm1 = RMSNorm() self.norm2 = RMSNorm() self.attn = MultiHeadSelfAttention(self.d_model, self.n_heads) self.ff = FeedForward(self.d_model, self.d_ff)def__call__(self, x: jnp.ndarray)-> jnp.ndarray: h = self.norm1(x) h = self.attn(h) x = x + h
h = self.norm2(x) h = self.ff(h) x = x + h
return x
defdecode( self, x: jnp.ndarray,*, cache:dict[str, jnp.ndarray], cache_index: jnp.ndarray,)->tuple[jnp.ndarray,dict[str, jnp.ndarray]]: h = self.norm1(x) h, new_cache = self.attn.decode(h, cache=cache, cache_index=cache_index) x = x + h
h = self.norm2(x) h = self.ff(h) x = x + h
return x, new_cache
classDecoderTransformer(nn.Module): vocab_size:int d_model:int n_heads:int d_ff:int n_layers:int max_seq_len:intdefsetup(self)->None: self.tok_embed = nn.Embed(num_embeddings=self.vocab_size, features=self.d_model) self.pos_embed = nn.Embed(num_embeddings=self.max_seq_len, features=self.d_model) self.blocks =[ DecoderBlock(self.d_model, self.n_heads, self.d_ff)for _ inrange(self.n_layers)] self.norm = RMSNorm() self.head = nn.Dense(self.vocab_size)def__call__(self, x: jnp.ndarray)-> jnp.ndarray: _, t = x.shape
pos = jnp.arange(t)[None,:] h = self.tok_embed(x)+ self.pos_embed(pos)for blk in self.blocks: h = blk(h) h = self.norm(h) logits = self.head(h)return logits
defdecode( self, x: jnp.ndarray,*, cache:list[dict[str, jnp.ndarray]], cache_index: jnp.ndarray,)->tuple[jnp.ndarray,list[dict[str, jnp.ndarray]]]: _, t = x.shape
pos =(cache_index + jnp.arange(t))[None,:] h = self.tok_embed(x)+ self.pos_embed(pos) new_caches:list[dict[str, jnp.ndarray]]=[]for i, blk inenumerate(self.blocks): h, layer_cache = blk.decode(h, cache=cache[i], cache_index=cache_index) new_caches.append(layer_cache) h = self.norm(h) logits = self.head(h)return logits, new_caches
definit_cache(self, batch_size:int)->list[dict[str, jnp.ndarray]]: d_head = self.d_model // self.n_heads
k = jnp.zeros((batch_size, self.max_seq_len, self.n_heads, d_head), dtype=jnp.float32) v = jnp.zeros_like(k)return[{"k": k,"v": v}for _ inrange(self.n_layers)]
6. 학습
JAX 학습의 핵심은 @jax.jit이다. 파이썬 함수 train_step은 최초 실행 시 추적되어 XLA 그래프로 변환되고, TPU 기계어로 컴파일된다. 이 과정에서 연산들이 융합Fusion되어 메모리 병목을 제거한다.
학습된 모델로 텍스트를 생성할 때, 매번 전체 시퀀스를 다시 계산하는 것은 비효율적이다(O(N2)). KV Cache는 이전 시점의 Key/Value를 저장해두고 재사용하여 복잡도를 O(N)으로 낮추는 핵심 최적화 기법이다.
생성 과정Inference에서는 토큰을 하나씩 생성한다. 100번째 토큰을 생성할 때, 앞선 99개 토큰에 대한 Key와 Value 값은 이미 계산되어 있다. 이를 다시 계산하는 대신 GPU/TPU 메모리Cache에 저장해 두었다가, 100번째 토큰에 대한 Query와 저장된 Key/Value 만을 연산하면 된다. 이렇게 하면 시퀀스 길이가 길어져도 연산량이 선형적으로만 증가하여 빠른 생성이 가능하다. JAX에서는 jax.lax.dynamic_update_slice를 사용하여 이 캐시를 효율적으로 업데이트할 수 있다.