Article 02 留给我们两种把一次 matmul 切上两张 GPU 的方式。直接看它们在做什么,比记 paper 里那些名字省事得多。先把两种摆一起对一下:

Strategy A —— fxStrategy B —— 切行
切的是矩阵的什么列(每列是一个 fx行(每行是一个基向量)
每张 GPU 拿到的输入每个 token 的完整输入向量每个 token 输入特征的一半
每张 GPU 算出的输出每个 token 输出特征的一半每个 token 完整输出的一个 partial sum
怎么合起来concatenate(免费)all-reduce(一次通信)
别名column-parallelrow-parallel

每一列最精简的读法:A = “full in, half out”B = “half in, sum out”。后面所有东西都建立在这两行字上。

但一个真实的 transformer block 不是单次 matmul,是 四次,再加一些 pointwise 的胶水。下一个自然冒出来的问题就是:一整个 block 该怎么切上两张 GPU?

第一反应有一个非常自然的方案,差点就跑通了。我们先把它搭出来,看清楚它在哪儿坏掉,再让"补这个洞"的过程把我们带到 Megatron 那个经典 pattern。整篇都用一组小数字盯着 shape 走,免得讨论变空气。


1. 起手:一组装得下的小数

两张 GPU,叫 G1G2。一个小 batch 配一个小模型:

batch n4 个 token
model dim d512
heads h8
per-head dim d_head64
attention dim k = h · d_head512
FFN hidden4d = 2048

每个 token 是一行 512 个数。整个 batch 就是 [n × d] = [4 × 512]

把一个 transformer block 拍平:

        ← input: [4 × 512]
  LayerNorm
  QKV projection      d → 3k       ← matmul   weight  [d × 3k] = [512 × 1536]
  attention                          (Q 和 K 互相混;不是一次新的 matmul)
  output projection   k → d        ← matmul   weight  [k × d]  = [512 × 512]
    + residual
  LayerNorm
  FFN up-projection   d → 4d       ← matmul   weight  [d × 4d] = [512 × 2048]
  activation (GeLU)                  (pointwise)
  FFN down-projection 4d → d       ← matmul   weight  [4d × d] = [2048 × 512]
    + residual

四次 matmul,加一些胶水。

顺嘴说一下 pointwise 这些胶水:为什么两张 GPU 都要算一遍? LayerNorm、activation、residual add 这些都是 pointwise 的。它们不在乎数据怎么分布在 GPU 上,只要本地有它要的那份就行。在 TP 里我们偷个懒:只要某份数据完整地在两张 GPU 上都有,就让两张 GPU 各跑一遍这个 pointwise op。同样输入、同样输出,重复算。为什么不让一张 GPU 算完再 broadcast?因为 瓶颈一直是通信,不是算力。pointwise 在 GPU 上对几千个数做一遍,几乎不花时间;跨 GPU 发数据是真金白银的 latency 和 bandwidth。多花点便宜的算力比省那点通信划算。把这条记着 —— 后面的 trace 表里你会看到每个 LN 和 residual 步骤都标着 “(冗余)",就是这个意思。

所以这个 block 整个 TP 故事,全在那四次 matmul 上。两张 GPU,四个 cut 要做。开搞。


2. v1 —— 把 Strategy A(full → half)用到每个 matmul 上

最自然的第一步是什么?回到 article 02,Strategy A 长这样:

  • 便宜的那种 cut(concatenate,matmul 内部不需要 all-reduce);
  • 用在 QKV 上正好落在 head 边界k = 8 · 64 = 512,每张 GPU 256,正好一人 4 个 head;
  • 而且"完整输入进,半个输出出"也是更好脑补的那一版。

那就把 A 一路用到底,四次 matmul 全是 A。一步一步走过这个 block,盯着每张 GPU 手上有什么 —— 它的 weight 切片输入输出

StepGPU 1GPU 2
input[4×512] 完整[4×512] 完整
LayerNorm (冗余)in [4×512] → out [4×512]in [4×512] → out [4×512]
QKV proj (A)W [512×768] (heads 1–4)
in [4×512] → out [4×768]
= heads 1–4 的 Q+K+V,每份 [4×256]
W [512×768] (heads 5–8)
in [4×512] → out [4×768]
= heads 5–8 的 Q+K+V,每份 [4×256]
attentionheads 1–4
in [4×768] → out [4×256]
heads 5–8
in [4×768] → out [4×256]
GATHER #1 —— output proj 要的是完整 k=512,每张 GPU 只有 256,凑出 [4×512]
output proj (A)W [512×256]
in [4×512] → out [4×256]
W [512×256]
in [4×512] → out [4×256]
GATHER #2 —— residual 要完整 d=512,输出只有半个 d=256,凑出 [4×512]
+ residual[4×512][4×512][4×512][4×512]
LayerNorm (冗余)in [4×512] → out [4×512]in [4×512] → out [4×512]
FFN-up (A)W [512×1024]
in [4×512] → out [4×1024]
W [512×1024]
in [4×512] → out [4×1024]
activation (pointwise)[4×1024][4×1024][4×1024][4×1024]
GATHER #3 —— FFN-down 要完整 4d=2048,每张 GPU 只有 1024,凑出 [4×2048]
FFN-down (A)W [2048×256]
in [4×2048] → out [4×256]
W [2048×256]
in [4×2048] → out [4×256]
GATHER #4 —— residual 要完整 d=512,输出只有半个 d=256,凑出 [4×512]
+ residual[4×512][4×512][4×512][4×512]

每个 block 四次跨 GPU 的 gather。

其中两次发生在下一个 A 风格的 matmul 之前 —— 它要的是完整输入。另外两次在 residual add 之前 —— 它要完整向量,但我们刚算出来的是半个。本质都是同一个原因:A 产出半个输出,下游几乎所有东西要的都是完整输入。


3. v1 的代价

跨 GPU 通信是分布式计算里 的那一环。整个 TP 设计的目标,就是把它压到尽可能少。v1 的现状是:几乎每一个需要完整 feature 的算子前面,都得付一次 gather。

一个 32 层的模型,光一次 forward 就 ~130 次跨 GPU 通信。太多了。

问题就变成了:

能不能把 gather 省掉?

每次 gather 的存在都是一个原因:下一个算子要完整向量,但 Strategy A 给的是半个。我们真正想要的,是一个 愿意 直接吃半个输出的 matmul。

article 02 已经把它递到我们手上了。


4. v2 —— 让 Strategy A 配上 Strategy B(half → sum)

换一个角度看这两种 strategy:

  • Strategy A 输出 一个 half
  • Strategy B 输入 一个 half

形状一样。 A 的输出正好是 B 想要的输入。两者咬合上,中间一点通信都不需要。

把 v1 里的 “A → gather → A” 换成 “A → B”。B 直接吃下半个输出,而通信代价只剩下 B 末尾那一次 —— 把 partial sum 加成完整输出的 all-reduce,给后面的 residual 和 LN 用。

把这个套路用到整个 block,每个 A 配一个 B:

StepGPU 1GPU 2
input[4×512] 完整[4×512] 完整
LayerNorm (冗余)in [4×512] → out [4×512]in [4×512] → out [4×512]
QKV proj (A)W [512×768] (heads 1–4)
in [4×512] → out [4×768]
= heads 1–4 的 Q+K+V,每份 [4×256]
W [512×768] (heads 5–8)
in [4×512] → out [4×768]
= heads 5–8 的 Q+K+V,每份 [4×256]
attentionheads 1–4
in [4×768] → out [4×256]
heads 5–8
in [4×768] → out [4×256]
output proj (B)W [256×512]
in [4×256] → out [4×512] (部分和)
W [256×512]
in [4×256] → out [4×512] (部分和)
ALL-REDUCE #1 —— 把两张 GPU 上的两个 [4×512] 部分和加起来,凑出完整 [4×512](residual + LN 要用)
+ residual[4×512][4×512][4×512][4×512]
LayerNorm (冗余)in [4×512] → out [4×512]in [4×512] → out [4×512]
FFN-up (A)W [512×1024]
in [4×512] → out [4×1024]
W [512×1024]
in [4×512] → out [4×1024]
activation (pointwise)[4×1024][4×1024][4×1024][4×1024]
FFN-down (B)W [1024×512]
in [4×1024] → out [4×512] (部分和)
W [1024×512]
in [4×1024] → out [4×512] (部分和)
ALL-REDUCE #2 —— 把两张 GPU 上的两个 [4×512] 部分和加起来,凑出完整 [4×512](residual + LN 要用)
+ residual[4×512][4×512][4×512][4×512]

每个 block 两次 all-reduce。

这就是 Megatron pattern。没人告诉我们答案,我们自己一步一步走进去的。


5. 没料到的对偶性

article 02 把 A 和 B 当两种独立 strategy 介绍 —— 同一个矩阵的两种读法。但把它们并排放着看,盯着每种的输入输出形状:

  • A完整输入,吐 半个输出
  • B半个输入,吐 完整 sum 作为输出。

它们不是两种 strategy,是 同一次往返的两半。A 的输出 shape 就是 B 的输入 shape;B 的输出 shape(all-reduce 之后)就是 A 的输入 shape。你发明 A 的同时,其实就把 B 当成它的回程一起发明出来了。

再回头看这个 block 在做什么:

  • Attentionwiden(QKV:d → k)后跟 narrow(output proj:k → d)。
  • FFN 也是 widend → 4d)后跟 narrow4d → d)。

widen 那里输出 feature 多,正好分到多张 GPU —— A 派得上用场。narrow 那里输入 feature 多,正好把输入切到多张 GPU,输出小再加回来 —— B 派得上用场。

这个 block 不是"碰巧"对 A→B 友好,是 结构上就对 A→B 友好:两对 widen-narrow,中间用 pointwise 粘起来。“Megatron pattern” 不是哪个人坐下来设计出来的算法,它就是唯一一个尊重架构本身做的事的通信 pattern。A、B 对偶和 widen-narrow 节奏,是同一件事讲两遍。

提一句通信代价:一次 gather 和一次 all-reduce 在每张 GPU 上搬的数据量差不多(all-reduce 内部大致是 reduce-scatter 加 all-gather)。v1 每个 block 4 次 gather;v2 每个 block 2 次 all-reduce —— 通信砍一半,模型一行没改。


6. 为什么 cut 必须落在 head 边界

v2 trace 偷偷假设了一件事:QKV 的 column cut 是沿着 head 边界k = 512 切成两块各 256,每张 GPU 正好拿 4 整个 head。这个假设干的活比看起来多得多。试一下反事实就知道。

想象一下 single-head attention —— 同样 k = 512,但只有一个 head,没有 head 结构。Strategy A 照旧用到 QKV 上:每张 GPU 拿到形状都是 [4 × 256]Q, K, V。开始算 attention。

第一步是 Q Kᵀ。每张 GPU 算 Q_half @ K_halfᵀ,得到一个 [4 × 4] 的矩阵 —— 但这个矩阵是它手上那 256 个 feature 的 partial sum。真正的 scores 得把两张 GPU 上的 partial 加起来才行。

接下来麻烦了:下一步是 softmax。Softmax 是非线性的,本地算完再合不行 —— softmax(a) + softmax(b) ≠ softmax(a + b)。reduction 必须发生在 softmax 之前。也就是说,attention 中间会硬塞进一次同步:

★ ALL-REDUCE,对 [n × n] 的 scores,发生在 softmax 之前。

这是 v2 那两次之外的第三次 all-reduce。Megatron pattern 直接塌成 次同步,而新增的这次的 tensor 大小随 sequence length 平方增长 —— 你最不想让它变大的那个。

修法不在算法层,在结构层:不要让 cut 跨过 head。每个 head 的 Q Kᵀ 必须完整地落在一张 GPU 上,partial sum 的问题就不会出现。multi-head attention 把这个白送给我们了:head 在定义上就互相独立,head 边界天然是切点;只要 h 能被 GPU 数整除,对 k = h · d_head 的 column split 就正好落在两个 head 之间。

所以 multi-head 不是系统人捡了 modeler 的便宜。它是 v2 能存在的 结构性前提。cut 落到 head 内部,softmax 立刻逼出一次毁掉一切的同步;cut 落在 head 之间,非线性就保持本地。Megatron pattern 不是 碰巧 在 multi-head 架构上能跑 —— 它要求 multi-head 架构。


7. 这一篇打开了哪些门

到这里,一个 block 跑在两张 GPU 上,每次 forward 两次 all-reduce。下一轮"等下,那……“的问题就开始了:

  • 如果 block 很多、GPU 也很多呢? TP 切的是 block 内部。切 block —— 把整个 block 摆到不同 GPU 上,让 microbatch 流过去 —— 是另一回事。Pipeline parallelism,下一篇见。
  • 如果 FFN 被换成 expert 呢? column-then-row 这一套对每个 expert 的 matmul 还是适用,但把 token 路由到正确的 expert 又引入一种新通信。MoE,再下一篇。
  • 如果 batch 里 sequence 长度差异巨大呢? 通信 pattern 不变,但 attention 那边要处理变长 sequence —— continuous batching 由此登场。

同样的语法。每条都自成一篇 walk-through。