国产精品久久久aaaa,日日干夜夜操天天插,亚洲乱熟女香蕉一区二区三区少妇,99精品国产高清一区二区三区,国产成人精品一区二区色戒,久久久国产精品成人免费,亚洲精品毛片久久久久,99久久婷婷国产综合精品电影,国产一区二区三区任你鲁

0
  • 聊天消息
  • 系統消息
  • 評論與回復
登錄后你可以
  • 下載海量資料
  • 學習在線課程
  • 觀看技術視頻
  • 寫文章/發帖/加入社區
會員中心
創作中心

完善資料讓更多小伙伴認識你,還能領取20積分哦,立即完善>

3天內不再提示

如何在TPU上使用JAX訓練GPT-2模型

谷歌開發者 ? 來源:谷歌開發者 ? 2025-09-03 11:39 ? 次閱讀
加入交流群
微信小助手二維碼

掃碼添加小助手

加入工程師交流群

作者 / 魏巍,開發技術推廣工程師

如果您對如何使用 JAX 從頭開始構建語言模型感到好奇,那么本文非常適合您。我們在 2025 年 Google Cloud Next 大會上舉辦了一場關于此主題的研討會,并獲得了一些很好的反饋,我們也為所有無法參會的開發者編寫了這份指南。

本文和代碼示例將引導您構建并預訓練 GPT-2 模型,了解 JAX 如何直接利用 Google TPU 的強大能力。您可以使用 Colab 或 Kaggle 中的 TPU 免費運行整個項目,并獲取完整的Notebook。

Notebook

https://github.com/windmaple/LLM_from_scratch.JAX/tree/main/02.GPT2-pretraining

這是一個實踐教程,如果您還不熟悉 JAX,我們建議您從《PyTorch 開發者指南: JAX 基礎知識》入手。

PyTorch 開發者指南: JAX 基礎知識

https://cloud.google.com/blog/products/ai-machine-learning/guide-to-jax-for-pytorch-developers

首先,讓我們快速了解一下將要用到的工具。

JAX 生態系統

在開始構建模型之前,讓我們先簡要介紹一下 JAX 生態系統。JAX 生態系統采用模塊化方法,通過 JAX 核心提供核心數值處理能力,而一系列豐富的庫則在此基礎上構建而成,以滿足不同應用的特定需求,如用于構建神經網絡的Flax、用于檢查點和模型持久性的Orbax以及用于優化的Optax(在本文中,這 3 個工具都將被用到)。內置函數轉換,如 autograd、矢量化和 JIT 編譯,加上強大的性能和易于使用的 API,使 JAX 非常適合訓練大語言模型。

JAX 生態系統

https://docs.jax.dev/en/latest/#ecosystem

Flax

https://github.com/google/flax

Orbax

https://github.com/google/orbax

Optax

https://github.com/google-deepmind/optax

入門指南: 構建您的 GPT-2 模型

OpenAI 此前發布了GPT-2 模型代碼和權重,這為我們提供了寶貴的參考資料,并且社區也付出了很多努力來復現該模型,例如nanoGPT。以下是 GPT-2 的高層級模型架構圖:

dedd83ce-84bb-11f0-a18e-92fbcf53809c.png

GPT-2 模型代碼和權重

https://github.com/openai/gpt-2

nanoGPT

https://github.com/karpathy/nanoGPT

我們將使用NNX (新的 Flax 接口)來構建 GPT-2 模型。簡潔起見,我們重點關注 Transformer Block,這是現代大語言模型的關鍵所在。Transformer Block 會捕獲任何序列的長程依賴關系,并構建豐富的上下文理解。GPT-2 Transformer Block 由 2 個 LayerNorm 層、1 個多頭注意力 (MHA) 層、2 個 Dropout 層、2 個線性投影層和 2 個殘差連接組成。因此,我們首先需要在TransformerBlock類的__init__函數中定義這些層:

classTransformerBlock(nnx.Module):
 def__init__(
    self,
    embed_dim:int,
    num_heads:int,
    ff_dim:int,
    dropout_rate:float,
    rngs: nnx.Rngs,
  ):
    self.layer_norm1 = nnx.LayerNorm(
      epsilon=1e-6, num_features=embed_dim, rngs=rngs
    )
    self.mha = nnx.MultiHeadAttention(
      num_heads=num_heads, in_features=embed_dim, rngs=rngs
    )
    self.dropout1 = nnx.Dropout(rate=dropout_rate)
    self.layer_norm2 = nnx.LayerNorm(
      epsilon=1e-6, num_features=embed_dim, rngs=rngs
    )
    self.linear1 = nnx.Linear(
      in_features=embed_dim, out_features=ff_dim, rngs=rngs
    )
    self.linear2 = nnx.Linear(
      in_features=ff_dim, out_features=embed_dim, rngs=rngs
    )
    self.dropout2 = nnx.Dropout(rate=dropout_rate)

NNX (新的 Flax 接口)

https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/index.html#

接下來,我們需要在__call__函數中對這些層進行組合:

classTransformerBlock(nnx.Module):
 def__call__(self, inputs, training:bool=False):
    input_shape = inputs.shape
    bs, seq_len, emb_sz = input_shape


    attention_output = self.mha(
      inputs_q=self.layer_norm1(inputs),
      mask=causal_attention_mask(seq_len),
      decode=False,
    )
    x = inputs + self.dropout1(
      attention_output, deterministic=nottraining
    )


   # MLP
    mlp_output = self.linear1(self.layer_norm2(x))
    mlp_output = nnx.gelu(mlp_output)
    mlp_output = self.linear2(mlp_output)
    mlp_output = self.dropout2(
      mlp_output, deterministic=nottraining
    )


   returnx + mlp_output

如果您使用過任何其他機器學習框架 (如 PyTorch 或 TensorFlow) 來訓練語言模型,那么您對這段代碼應該非常熟悉。但 JAX 具有通過SPMD(Single Program Multiple Data) 自動并行運行代碼的強大能力。這項功能至關重要,因為我們將在多個加速器 (多個 TPU 核心) 上運行代碼。讓我們來看看它的工作原理

SPMD

https://docs.jax.dev/en/latest/sharded-computation.html

要執行 SPMD,首先我們需要確保自己使用的是 TPU。如果您使用的是 Colab 或 Kaggle,請選擇 TPU 運行時 (您也可以使用 Cloud TPU 虛擬機)。

import jax
jax.devices()


# Free-tier Colab offers TPU v2:
#[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
# TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
# TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
# TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
# TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
# TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
# TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
# TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

Colab 和 Kaggle 提供 TPU v2 或 v3,其中含有 8 個獨立的 TPU 核心。TPU v3 托盤的外觀如下所示:

def0e464-84bb-11f0-a18e-92fbcf53809c.png

訓練您的 GPT-2 模型

為了高效訓練 GPT-2 模型,我們將通過 SPMD 讓所有 TPU 核心協同運行,并利用 JAX 中的數據并行。為此,我們定義了一個硬件網格:

mesh= jax.make_mesh((8,1), ('batch','model'))

數據并行

https://en.wikipedia.org/wiki/Data_parallelism

我們可以將網格視為加速器的 2D 矩陣。在本例中,我們為網格定義了兩個軸:batch軸和model軸。因此,我們總共有 8 x 1 個核心,也就是 8 個核心。這些軸決定了我們如何劃分數據和模型參數。如果之后想嘗試其他并行方案,我們可以對這些軸進行調整。

現在,我們通過告訴 JAX 如何使用 "model" 軸劃分模型參數來更改__init__函數。這是通過在初始化權重張量 (weight tensors) 時添加nnx.with_partitioning來實現的: 對于像 LayerNorm 縮放/偏置張量這樣的 1D 權重張量 (weight tensors),我們直接沿著 "model" 軸對它們進行分片;對于像 MHA 和線性內核張量這樣的 2D 權重張量,我們沿著model軸對第二維度進行分片。

classTransformerBlock(nnx.Module):
 def__init__(
    self,
    embed_dim:int,
    num_heads:int,
    ff_dim:int,
    dropout_rate:float,
    rngs: nnx.Rngs,
  ):
    self.layer_norm1 = nnx.LayerNorm(
      epsilon=1e-6, num_features=embed_dim,rngs=rngs, rngs=rngs,
      scale_init=nnx.with_partitioning(
        nnx.initializers.ones_init(),
        ("model"),
      ),
      bias_init=nnx.with_partitioning(
        nnx.initializers.zeros_init(),
       ("model"),
      ),
    )
    self.mha = nnx.MultiHeadAttention(
      num_heads=num_heads, in_features=embed_dim,
      kernel_init=nnx.with_partitioning(
        nnx.initializers.xavier_uniform(),
       (None,"model"),
      ),
      bias_init=nnx.with_partitioning(
        nnx.initializers.zeros_init(),
       ("model"),
      ),
    )
   # Other layers in the block are omitted for brevity

我們需要像這樣劃分其他層,以便為整個 GPT-2 模型啟用模型張量并行。即使我們在本教程中不會使用模型張量并行,實現這一功能仍然是比較好的做法,因為隨著模型規模的增長,我們將來可能需要對模型參數進行分區。實現后,我們只需更改一行代碼即可立即運行更大的模型。例如:

mesh= jax.make_mesh((4,2), ('batch','model'))

接下來,我們需要定義loss_fn和train_step函數,與此前文章類似。train_step()函數會計算交叉熵損失函數的梯度,并通過優化器更新權重,然后在循環中被調用來訓練模型。為了獲得最佳性能,我們使用@nnx.jit裝飾器對這兩個函數進行 JIT 編譯,因為它們屬于計算密集型函數。

@nnx.jit
defloss_fn(model, batch):
  logits = model(batch[0])
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch[1]
  ).mean()
 returnloss, logits




@nnx.jit
deftrain_step(
  model: nnx.Module,
  optimizer: nnx.Optimizer,
  metrics: nnx.MultiMetric,
  batch,
):
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)
  metrics.update(loss=loss, logits=logits, lables=batch[1])
  optimizer.update(grads)

此前文章

https://cloud.google.com/blog/products/ai-machine-learning/guide-to-jax-for-pytorch-developers

對于優化器,我們使用 Optax 中的 AdamW 以及余弦衰減調度。您也可以在 Optax 中試用其他優化器或調度計劃。

schedule = optax.cosine_decay_schedule(
  init_value=init_learning_rate, decay_steps=max_steps
)
optax_chain = optax.chain(
  optax.adamw(learning_rate=schedule, weight_decay=weight_decay)
)
optimizer = nnx.Optimizer(model, optax_chain)

其他優化器

https://optax.readthedocs.io/en/latest/api/optimizers.html

調度計劃

https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html

最后,我們需要創建一個簡單的訓練循環。

while True:
  input_batch, target_batch =get_batch("train")


 train_step(
    model,
    optimizer,
    train_metrics,
    jax.device_put(
      (input_batch, target_batch),
     NamedSharding(mesh,P("batch", None)),
    ),
  )


  step +=1
  if step > max_steps:
    break

請注意我們使用jax.device_put函數沿著 batch 軸對輸入數據進行分區。在這種情況下,JAX 將啟用數據并行,并通過自動插入通信集合 (AllReduce) 將所有內容整合在一起,同時盡可能多地實現計算與通信的重疊。有關并行計算更深入的討論,請參閱 JAX 的并行編程入門文檔。

并行編程入門

https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#intro-and-a-quick-example

模型此時應處于訓練狀態,如果使用權重和偏差來跟蹤運行情況,我們便可以觀察訓練損失。以下是訓練 GPT-2 124M 模型的測試運行結果:

df146da8-84bb-11f0-a18e-92fbcf53809c.png

權重和偏差

https://wandb.ai/site

如果使用 Kaggle TPU v3,訓練時間大約為 7 個小時 (我們可以不中斷地使用 Kaggle TPU v3 9 個小時);但如果使用Trillium,訓練時間將縮短至約 1.5 個小時 (請注意,Trillium 的每個芯片配備 32G 高帶寬內存 (HBM),因此我們可以將批量大小加倍,并將訓練步數減半)。

Trillium

https://cloud.google.com/blog/products/compute/trillium-tpu-is-ga

最終的損失情況與nanoGPT 的損失情況大致相符。我們在編寫此代碼示例時對 nanoGPT 進行了研究。

df270288-84bb-11f0-a18e-92fbcf53809c.png

nanoGPT 的損失情況

https://github.com/karpathy/nanoGPT/tree/master?tab=readme-ov-file#baselines

如果使用 Cloud TPU,我們還可以通過 "tpu-info" 命令 (Cloud TPU 監控調試包的一部分) 或權重和偏差儀表盤監控 TPU 利用率。我們的 TPU 正在全力運行!

df3f1f4e-84bb-11f0-a18e-92fbcf53809c.png

Cloud TPU 監控調試

https://github.com/AI-Hypercomputer/cloud-tpu-monitoring-debugging

完成模型訓練后,我們可以使用Orbax保存模型:

checkpointer = orbax.PyTreeCheckpointer()
train_state = nnx.pure(nnx.state(model))
checkpointer.save(checkpoint_path, train_state)

Orbax

https://github.com/google/orbax

后續步驟: 探索高級 LLM 訓練和擴展

這基本上就是我們訓練 GPT-2 模型所需了解的全部內容。您可以在完整的Notebook中找到其他詳細信息,如數據加載、超參數、指標等。

Notebook

https://github.com/windmaple/LLM_from_scratch.JAX/tree/main/02.GPT2-pretraining

當然,GPT-2 如今還是一個小模型,許多前沿實驗室正在訓練擁有數千億參數的模型。但是,現在您已經學習了如何使用 JAX 和 TPU 構建小語言模型,為深入了解如何擴展模型做好了準備。

如何擴展模型

https://jax-ml.github.io/scaling-book/

此外,您既可以使用MaxText來訓練預構建的前沿 LLM,也可以通過參考JAX LLM 示例或Stanford Marin 模型來學習如何從頭開始構建最新的模型。

MaxText

https://github.com/AI-Hypercomputer/maxtext

JAX LLM 示例

https://github.com/jax-ml/jax-llm-examples/

Stanford Marin 模型

https://developers.googleblog.com/en/stanfords-marin-foundation-model-first-fully-open-model-developed-using-jax/

我們期待看到您使用 JAX 和 TPU 構建的出色模型!

聲明:本文內容及配圖由入駐作者撰寫或者入駐合作網站授權轉載。文章觀點僅代表作者本人,不代表電子發燒友網立場。文章及其配圖僅供工程師學習之用,如有內容侵權或者其他違規問題,請聯系本站處理。 舉報投訴
  • 模型
    +關注

    關注

    1

    文章

    3658

    瀏覽量

    51804
  • 代碼
    +關注

    關注

    30

    文章

    4947

    瀏覽量

    73291
  • TPU
    TPU
    +關注

    關注

    0

    文章

    166

    瀏覽量

    21553
  • pytorch
    +關注

    關注

    2

    文章

    813

    瀏覽量

    14736

原文標題:實戰指南|手把手教您在 TPU 上免費使用 JAX 訓練 GPT-2 模型

文章出處:【微信號:Google_Developers,微信公眾號:谷歌開發者】歡迎添加關注!文章轉載請注明出處。

收藏 人收藏
加入交流群
微信小助手二維碼

掃碼添加小助手

加入工程師交流群

    評論

    相關推薦
    熱點推薦

    用PaddleNLP在4060單卡實踐大模型訓練技術

    手把手教您如何在單張消費級顯卡,利用PaddleNLP實踐OpenAI的GPT-2模型的預訓練GPT
    的頭像 發表于 02-19 16:10 ?2172次閱讀
    用PaddleNLP在4060單卡<b class='flag-5'>上</b>實踐大<b class='flag-5'>模型</b>預<b class='flag-5'>訓練</b>技術

    如何利用Google Colab的云TPU加速Keras模型訓練

    TPU包含8個TPU核,每個核都作為獨立的處理單元運作。如果沒有用上全部8個核心,那就沒有充分利用TPU。為了充分加速訓練,相比在單GPU
    的頭像 發表于 11-16 09:10 ?1.1w次閱讀

    OpenAI發布了一個“逆天”的AI模型——GPT2整個模型包含15億個參數

    能有這樣出色的表現,不是沒有原因的,GPT-2各種特定領域的語言建模任務中都取得了很好的分數。作為一個沒有經過任何領域數據專門訓練模型,它的表現,比那些專為特定領域數據集(例如維基百科,新聞,書籍)
    的頭像 發表于 03-07 14:45 ?9245次閱讀

    OpenAI發布一款令人印象深刻的語言模型GPT-2

    今年2月,OpenAI發布了一款令人印象深刻的語言模型GPT-2,它可以寫短篇小說、詩歌,甚至輕松辨別《哈利波特》和《指環王》中的角色。最近,一位加拿大工程師用它創建了一個向公眾開放的文本生成器,只需提供一個句子,機器便能自動生
    的頭像 發表于 05-17 18:48 ?5071次閱讀

    布朗大學90后研究生:我們復現了15億參數GPT-2模型,你也行!

    模型的實現基于Grover模型,并修改其代碼庫以匹配GPT-2的語言建模訓練目標。由于他們的模型是在類似的大型語料庫上進行
    的頭像 發表于 09-01 07:11 ?3766次閱讀

    OpenAI宣布,發布了7.74億參數GPT-2語言模型

    就在本周,OpenAI宣布,發布了7.74億參數GPT-2語言模型,15.58億的完整模型也有望于幾個月內發布,并將GPT-2這6個月的進展情況在博客
    的頭像 發表于 09-01 09:10 ?3498次閱讀

    和AI聊天,自然語言模型 GPT-2可能會推出個人信息

    Stroudsburg……” 自然語言模型 GPT-2就像是收到了某種暗號,立刻“送出”一套 個人信息:姓名、電話號碼,還有地址、郵箱和傳真 (部分信息已打碼)。 這可不是GPT-2瞎編的,而是真實存在的個人信息!這些個人信息
    的頭像 發表于 01-02 09:22 ?3031次閱讀

    GPT系列的“高仿” 最大可達GPT-3大小 自主訓練

    雖然GPT-3沒有開源,卻已經有人在復刻GPT系列的模型了。 例如,慕尼黑工業大學的Connor Leahy,此前用200個小時、6000RMB,復現了GPT-2。 又例如,基于150
    的頭像 發表于 02-13 09:24 ?3265次閱讀

    使用NVIDIA TensorRT優化T5和GPT-2

    在這篇文章中,我們向您介紹了如何將擁抱臉 PyTorch T5 和 GPT-2 模型轉換為優化的 TensorRT 推理引擎。 TensorRT 推理機用作原始 HuggingFace T5
    的頭像 發表于 03-31 17:25 ?4615次閱讀
    使用NVIDIA TensorRT優化T5和<b class='flag-5'>GPT-2</b>

    基于OpenAI的GPT-2的語言模型ProtGPT2可生成新的蛋白質序列

    人類語言與蛋白質有很多共同點,至少在計算建模方面。這使得研究團隊將自然語言處理(NLP)的新方法應用于蛋白質設計。其中,德國Bayreuth大學Birte H?cker的蛋白質設計實驗室,描述了基于OpenAI的GPT-2的語言模型ProtGPT
    的頭像 發表于 09-08 16:24 ?3233次閱讀

    GPT/GPT-2/GPT-3/InstructGPT進化之路

    在預訓練階段,GPT 選擇 transformer 的 decoder 部分作為模型的主要模塊,transformer 是 2017年 google 提出的一種特征抽取模型
    的頭像 發表于 03-03 11:14 ?5098次閱讀

    ELMER: 高效強大的非自回歸預訓練文本生成模型

    每個單詞都依賴于輸入文本與之前生成的單詞。自回歸生成模型只建模了前向的單詞依賴關系,依次生成的結構也使得自回歸模型難以并行化。目前大部分預訓練生成模型均采用自回歸方式,包括
    的頭像 發表于 03-13 10:39 ?2176次閱讀

    DeepSpeed里面和Zero相關技術教程

    和NVMe 分配大規模Megatron-LM模型 以內存為中心的分塊優化 提取權重 ZeRO-Offload概述 訓練環境 在單個 V100 GPU 訓練10B的
    的頭像 發表于 06-12 10:25 ?5590次閱讀
    DeepSpeed里面和Zero相關技術教程

    DeepSpeed結合Megatron-LM訓練GPT2模型筆記

    本文基于DeepSpeedExamples倉庫中給出的Megatron相關例子探索一下訓練GPT2模型的流程。主要包含3個部分,第一個部分是基于原始的Megatron如何訓練
    的頭像 發表于 06-19 14:45 ?4726次閱讀
    DeepSpeed結合Megatron-LM<b class='flag-5'>訓練</b><b class='flag-5'>GPT2</b><b class='flag-5'>模型</b>筆記

    用PaddleNLP為GPT-2模型制作FineWeb二進制預訓練數據集

    作者:算力魔方創始人/英特爾創新大使劉力 《用PaddleNLP在4060單卡實踐大模型訓練技術》發布后收到讀者熱烈反響,很多讀者要求進一步講解更多的技術細節。本文主要針對大語言模型
    的頭像 發表于 03-21 18:24 ?3930次閱讀
    用PaddleNLP為<b class='flag-5'>GPT-2</b><b class='flag-5'>模型</b>制作FineWeb二進制預<b class='flag-5'>訓練</b>數據集