作者 / 魏巍,開發技術推廣工程師
如果您對如何使用 JAX 從頭開始構建語言模型感到好奇,那么本文非常適合您。我們在 2025 年 Google Cloud Next 大會上舉辦了一場關于此主題的研討會,并獲得了一些很好的反饋,我們也為所有無法參會的開發者編寫了這份指南。
本文和代碼示例將引導您構建并預訓練 GPT-2 模型,了解 JAX 如何直接利用 Google TPU 的強大能力。您可以使用 Colab 或 Kaggle 中的 TPU 免費運行整個項目,并獲取完整的Notebook。
Notebook
https://github.com/windmaple/LLM_from_scratch.JAX/tree/main/02.GPT2-pretraining
這是一個實踐教程,如果您還不熟悉 JAX,我們建議您從《PyTorch 開發者指南: JAX 基礎知識》入手。
PyTorch 開發者指南: JAX 基礎知識
https://cloud.google.com/blog/products/ai-machine-learning/guide-to-jax-for-pytorch-developers
首先,讓我們快速了解一下將要用到的工具。
JAX 生態系統
在開始構建模型之前,讓我們先簡要介紹一下 JAX 生態系統。JAX 生態系統采用模塊化方法,通過 JAX 核心提供核心數值處理能力,而一系列豐富的庫則在此基礎上構建而成,以滿足不同應用的特定需求,如用于構建神經網絡的Flax、用于檢查點和模型持久性的Orbax以及用于優化的Optax(在本文中,這 3 個工具都將被用到)。內置函數轉換,如 autograd、矢量化和 JIT 編譯,加上強大的性能和易于使用的 API,使 JAX 非常適合訓練大語言模型。
JAX 生態系統
https://docs.jax.dev/en/latest/#ecosystem
Flax
https://github.com/google/flax
Orbax
https://github.com/google/orbax
Optax
https://github.com/google-deepmind/optax
入門指南: 構建您的 GPT-2 模型
OpenAI 此前發布了GPT-2 模型代碼和權重,這為我們提供了寶貴的參考資料,并且社區也付出了很多努力來復現該模型,例如nanoGPT。以下是 GPT-2 的高層級模型架構圖:

GPT-2 模型代碼和權重
https://github.com/openai/gpt-2
nanoGPT
https://github.com/karpathy/nanoGPT
我們將使用NNX (新的 Flax 接口)來構建 GPT-2 模型。簡潔起見,我們重點關注 Transformer Block,這是現代大語言模型的關鍵所在。Transformer Block 會捕獲任何序列的長程依賴關系,并構建豐富的上下文理解。GPT-2 Transformer Block 由 2 個 LayerNorm 層、1 個多頭注意力 (MHA) 層、2 個 Dropout 層、2 個線性投影層和 2 個殘差連接組成。因此,我們首先需要在TransformerBlock類的__init__函數中定義這些層:
classTransformerBlock(nnx.Module):
def__init__(
self,
embed_dim:int,
num_heads:int,
ff_dim:int,
dropout_rate:float,
rngs: nnx.Rngs,
):
self.layer_norm1 = nnx.LayerNorm(
epsilon=1e-6, num_features=embed_dim, rngs=rngs
)
self.mha = nnx.MultiHeadAttention(
num_heads=num_heads, in_features=embed_dim, rngs=rngs
)
self.dropout1 = nnx.Dropout(rate=dropout_rate)
self.layer_norm2 = nnx.LayerNorm(
epsilon=1e-6, num_features=embed_dim, rngs=rngs
)
self.linear1 = nnx.Linear(
in_features=embed_dim, out_features=ff_dim, rngs=rngs
)
self.linear2 = nnx.Linear(
in_features=ff_dim, out_features=embed_dim, rngs=rngs
)
self.dropout2 = nnx.Dropout(rate=dropout_rate)
NNX (新的 Flax 接口)
https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/index.html#
接下來,我們需要在__call__函數中對這些層進行組合:
classTransformerBlock(nnx.Module):
def__call__(self, inputs, training:bool=False):
input_shape = inputs.shape
bs, seq_len, emb_sz = input_shape
attention_output = self.mha(
inputs_q=self.layer_norm1(inputs),
mask=causal_attention_mask(seq_len),
decode=False,
)
x = inputs + self.dropout1(
attention_output, deterministic=nottraining
)
# MLP
mlp_output = self.linear1(self.layer_norm2(x))
mlp_output = nnx.gelu(mlp_output)
mlp_output = self.linear2(mlp_output)
mlp_output = self.dropout2(
mlp_output, deterministic=nottraining
)
returnx + mlp_output
如果您使用過任何其他機器學習框架 (如 PyTorch 或 TensorFlow) 來訓練語言模型,那么您對這段代碼應該非常熟悉。但 JAX 具有通過SPMD(Single Program Multiple Data) 自動并行運行代碼的強大能力。這項功能至關重要,因為我們將在多個加速器 (多個 TPU 核心) 上運行代碼。讓我們來看看它的工作原理。
SPMD
https://docs.jax.dev/en/latest/sharded-computation.html
要執行 SPMD,首先我們需要確保自己使用的是 TPU。如果您使用的是 Colab 或 Kaggle,請選擇 TPU 運行時 (您也可以使用 Cloud TPU 虛擬機)。
import jax jax.devices() # Free-tier Colab offers TPU v2: #[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), # TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), # TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), # TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), # TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), # TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), # TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), # TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
Colab 和 Kaggle 提供 TPU v2 或 v3,其中含有 8 個獨立的 TPU 核心。TPU v3 托盤的外觀如下所示:

訓練您的 GPT-2 模型
為了高效訓練 GPT-2 模型,我們將通過 SPMD 讓所有 TPU 核心協同運行,并利用 JAX 中的數據并行。為此,我們定義了一個硬件網格:
mesh= jax.make_mesh((8,1), ('batch','model'))
數據并行
https://en.wikipedia.org/wiki/Data_parallelism
我們可以將網格視為加速器的 2D 矩陣。在本例中,我們為網格定義了兩個軸:batch軸和model軸。因此,我們總共有 8 x 1 個核心,也就是 8 個核心。這些軸決定了我們如何劃分數據和模型參數。如果之后想嘗試其他并行方案,我們可以對這些軸進行調整。
現在,我們通過告訴 JAX 如何使用 "model" 軸劃分模型參數來更改__init__函數。這是通過在初始化權重張量 (weight tensors) 時添加nnx.with_partitioning來實現的: 對于像 LayerNorm 縮放/偏置張量這樣的 1D 權重張量 (weight tensors),我們直接沿著 "model" 軸對它們進行分片;對于像 MHA 和線性內核張量這樣的 2D 權重張量,我們沿著model軸對第二維度進行分片。
classTransformerBlock(nnx.Module):
def__init__(
self,
embed_dim:int,
num_heads:int,
ff_dim:int,
dropout_rate:float,
rngs: nnx.Rngs,
):
self.layer_norm1 = nnx.LayerNorm(
epsilon=1e-6, num_features=embed_dim,rngs=rngs, rngs=rngs,
scale_init=nnx.with_partitioning(
nnx.initializers.ones_init(),
("model"),
),
bias_init=nnx.with_partitioning(
nnx.initializers.zeros_init(),
("model"),
),
)
self.mha = nnx.MultiHeadAttention(
num_heads=num_heads, in_features=embed_dim,
kernel_init=nnx.with_partitioning(
nnx.initializers.xavier_uniform(),
(None,"model"),
),
bias_init=nnx.with_partitioning(
nnx.initializers.zeros_init(),
("model"),
),
)
# Other layers in the block are omitted for brevity
我們需要像這樣劃分其他層,以便為整個 GPT-2 模型啟用模型張量并行。即使我們在本教程中不會使用模型張量并行,實現這一功能仍然是比較好的做法,因為隨著模型規模的增長,我們將來可能需要對模型參數進行分區。實現后,我們只需更改一行代碼即可立即運行更大的模型。例如:
mesh= jax.make_mesh((4,2), ('batch','model'))
接下來,我們需要定義loss_fn和train_step函數,與此前文章類似。train_step()函數會計算交叉熵損失函數的梯度,并通過優化器更新權重,然后在循環中被調用來訓練模型。為了獲得最佳性能,我們使用@nnx.jit裝飾器對這兩個函數進行 JIT 編譯,因為它們屬于計算密集型函數。
@nnx.jit
defloss_fn(model, batch):
logits = model(batch[0])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch[1]
).mean()
returnloss, logits
@nnx.jit
deftrain_step(
model: nnx.Module,
optimizer: nnx.Optimizer,
metrics: nnx.MultiMetric,
batch,
):
grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(model, batch)
metrics.update(loss=loss, logits=logits, lables=batch[1])
optimizer.update(grads)
此前文章
https://cloud.google.com/blog/products/ai-machine-learning/guide-to-jax-for-pytorch-developers
對于優化器,我們使用 Optax 中的 AdamW 以及余弦衰減調度。您也可以在 Optax 中試用其他優化器或調度計劃。
schedule = optax.cosine_decay_schedule( init_value=init_learning_rate, decay_steps=max_steps ) optax_chain = optax.chain( optax.adamw(learning_rate=schedule, weight_decay=weight_decay) ) optimizer = nnx.Optimizer(model, optax_chain)
其他優化器
https://optax.readthedocs.io/en/latest/api/optimizers.html
調度計劃
https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html
最后,我們需要創建一個簡單的訓練循環。
while True:
input_batch, target_batch =get_batch("train")
train_step(
model,
optimizer,
train_metrics,
jax.device_put(
(input_batch, target_batch),
NamedSharding(mesh,P("batch", None)),
),
)
step +=1
if step > max_steps:
break
請注意我們使用jax.device_put函數沿著 batch 軸對輸入數據進行分區。在這種情況下,JAX 將啟用數據并行,并通過自動插入通信集合 (AllReduce) 將所有內容整合在一起,同時盡可能多地實現計算與通信的重疊。有關并行計算更深入的討論,請參閱 JAX 的并行編程入門文檔。
并行編程入門
https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#intro-and-a-quick-example
模型此時應處于訓練狀態,如果使用權重和偏差來跟蹤運行情況,我們便可以觀察訓練損失。以下是訓練 GPT-2 124M 模型的測試運行結果:

權重和偏差
https://wandb.ai/site
如果使用 Kaggle TPU v3,訓練時間大約為 7 個小時 (我們可以不中斷地使用 Kaggle TPU v3 9 個小時);但如果使用Trillium,訓練時間將縮短至約 1.5 個小時 (請注意,Trillium 的每個芯片配備 32G 高帶寬內存 (HBM),因此我們可以將批量大小加倍,并將訓練步數減半)。
Trillium
https://cloud.google.com/blog/products/compute/trillium-tpu-is-ga
最終的損失情況與nanoGPT 的損失情況大致相符。我們在編寫此代碼示例時對 nanoGPT 進行了研究。

nanoGPT 的損失情況
https://github.com/karpathy/nanoGPT/tree/master?tab=readme-ov-file#baselines
如果使用 Cloud TPU,我們還可以通過 "tpu-info" 命令 (Cloud TPU 監控調試包的一部分) 或權重和偏差儀表盤監控 TPU 利用率。我們的 TPU 正在全力運行!

Cloud TPU 監控調試
https://github.com/AI-Hypercomputer/cloud-tpu-monitoring-debugging
完成模型訓練后,我們可以使用Orbax保存模型:
checkpointer = orbax.PyTreeCheckpointer() train_state = nnx.pure(nnx.state(model)) checkpointer.save(checkpoint_path, train_state)
Orbax
https://github.com/google/orbax
后續步驟: 探索高級 LLM 訓練和擴展
這基本上就是我們訓練 GPT-2 模型所需了解的全部內容。您可以在完整的Notebook中找到其他詳細信息,如數據加載、超參數、指標等。
Notebook
https://github.com/windmaple/LLM_from_scratch.JAX/tree/main/02.GPT2-pretraining
當然,GPT-2 如今還是一個小模型,許多前沿實驗室正在訓練擁有數千億參數的模型。但是,現在您已經學習了如何使用 JAX 和 TPU 構建小語言模型,為深入了解如何擴展模型做好了準備。
如何擴展模型
https://jax-ml.github.io/scaling-book/
此外,您既可以使用MaxText來訓練預構建的前沿 LLM,也可以通過參考JAX LLM 示例或Stanford Marin 模型來學習如何從頭開始構建最新的模型。
MaxText
https://github.com/AI-Hypercomputer/maxtext
JAX LLM 示例
https://github.com/jax-ml/jax-llm-examples/
Stanford Marin 模型
https://developers.googleblog.com/en/stanfords-marin-foundation-model-first-fully-open-model-developed-using-jax/
我們期待看到您使用 JAX 和 TPU 構建的出色模型!
-
模型
+關注
關注
1文章
3658瀏覽量
51804 -
代碼
+關注
關注
30文章
4947瀏覽量
73291 -
TPU
+關注
關注
0文章
166瀏覽量
21553 -
pytorch
+關注
關注
2文章
813瀏覽量
14736
原文標題:實戰指南|手把手教您在 TPU 上免費使用 JAX 訓練 GPT-2 模型
文章出處:【微信號:Google_Developers,微信公眾號:谷歌開發者】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
如何利用Google Colab的云TPU加速Keras模型訓練
OpenAI發布了一個“逆天”的AI模型——GPT2整個模型包含15億個參數
OpenAI發布一款令人印象深刻的語言模型GPT-2
布朗大學90后研究生:我們復現了15億參數GPT-2模型,你也行!
OpenAI宣布,發布了7.74億參數GPT-2語言模型
和AI聊天,自然語言模型 GPT-2可能會推出個人信息
GPT系列的“高仿” 最大可達GPT-3大小 自主訓練
使用NVIDIA TensorRT優化T5和GPT-2
基于OpenAI的GPT-2的語言模型ProtGPT2可生成新的蛋白質序列
GPT/GPT-2/GPT-3/InstructGPT進化之路
ELMER: 高效強大的非自回歸預訓練文本生成模型
DeepSpeed里面和Zero相關技術教程
DeepSpeed結合Megatron-LM訓練GPT2模型筆記
用PaddleNLP為GPT-2模型制作FineWeb二進制預訓練數據集

如何在TPU上使用JAX訓練GPT-2模型
評論