Transformer 架构的最新改进

#2025/12/29 #ai

从2017年Transformer诞生到现在,科学家们做了哪些优化,让模型跑得更快、更省内存、效果更好。


目录

一、为什么需要改进?

原始Transformer虽然很强,但有两个大问题

  1. 计算太慢:尤其是注意力机制,随着输入变长,计算量呈平方级增长
  2. 吃显存: 训练和推理都需要巨大的GPU显存

二、改进方向1:让注意力更高效

1. 稀疏注意力(Sparse Attention)

{%}

图 3-22:稀疏注意力通过只关注少量前序位置来提升性能

  • 原始问题:
    • 全注意力要求每个词元关注所有前面的词元,比如处理第1000个词时,要看前面999个词。
  • 解决方案:
    • 只看一部分重要的词元,比如:
      • 只看最近的100个词
      • 每隔几个词看一个(跳跃式)
      • 只看固定位置的词
  • 形象比喻:
    • 就像做阅读理解,不用逐字看完全文,只需要:
    • 重点看上一段(局部注意力)
    • 偶尔回顾一下开头(全局注意力)
  • 实际应用:
    • GPT-3就交替使用“全注意力块“和”稀疏注意力块

{%}

图 3-23:全注意力与稀疏注意力的对比。

2. 分组查询注意力(GQA)和多查询注意力(MQA)

  • 原始问题:
    • 多头注意力中,每个头都有自己的 Q、K、V 矩阵,太占内存
  • 解决方案:
    • 让多个头 共享 K 和 V 矩阵:
  • 效果:
    • 显存占用大幅下降,推理速度显著提升,性能损失很小

{%}

图 3-25:不同类型注意力的比较:原始的多头注意力、分组查询注意力和多查询注意力

原始多头注意力(假设8个头):  
头1: Q1, K1, V1  
头2: Q2, K2, V2  

...  
头8: Q8, K8, V8  
→ 需要24个矩阵  

多查询注意力(MQA):  
头1: Q1 ┐  
头2: Q2 ├→ 共享同一个 K, V  
...     │  
头8: Q8 ┘  
→ 只需要10个矩阵(8个Q + 1个K + 1个V)  

分组查询注意力(GQA):  
头1-4: Q1,Q2,Q3,Q4 → 共享 K1, V1  
头5-8: Q5,Q6,Q7,Q8 → 共享 K2, V2  
→ 介于两者之间,平衡性能和效率  

{%}

注意力机制通过查询矩阵、键矩阵和值矩阵来实现。在多头注意力中,每个注意力头都有一组独立的查询矩阵、键矩阵和值矩阵

{%}

分组查询注意力利用多组共享的键矩阵和值矩阵,牺牲了一些多查询注意力的效率来换取质量的大幅提升。每个分组都有其对应的注意力头集合

3. Flash Attention

  • 原始问题:
    • GPU有两种内存:
      • HBM(高带宽内存):容量大但慢
      • SRAM(共享内存):容量小但超快
  • 传统注意力计算需要频繁在两者间搬运数据,浪费时间。
    • 解决方案:
      • 优化数据读取顺序,尽量让计算都在SRAM上完成,减少数据搬运
      • GPU 共享内存(GPU’s shared memory,SRAM)和高带宽内存(high bandwidth memory,HBM)之间的数据加载和迁移来加速注意力计算
  • 效果:
    • 速度提升2-4倍,显存占用减少
  • 形象比喻:
    • 就像做饭:
      • 原来:从冰箱拿菜→切→放回冰箱→再拿→炒→放回…(反复搬运)
      • 现在:一次性拿够食材放操作台,切完直接炒(减少往返)

三、改进方向2:优化Transformer块内部结构

1. 预归一化(Pre-normalization)

原来的顺序:

输入 → 注意力 → 加残差 → 归一化 → 前馈层 → 加残差 → 归一化 → 输出  

{%}|504

图 3-29:原始 Transformer 论文中的 Transformer 块

现在的顺序:

输入 → 归一化 → 注意力 → 加残差  →  归一化 → 前馈层 → 加残差 → 输出  

{%}

图 3-30:2024 年的 Transformer(如 Llama 3)的 Transformer 块有一些新的改进,如预归一化(通过 RMSNorm 实现),以及通过分组查询注意力和旋转位置嵌入优化的注意力机制

好处: 训练更稳定,收敛更快


2. RMSNorm

  • 替代 LayerNorm
  • 改进点:
    • 计算更简单,速度更快,效果差不多
  • 原理简化:
    • LayerNorm:计算均值和方差,然后标准化
    • RMSNorm:只用方差(均方根),跳过均值计算

3. 更好的激活函数:SwiGLU

  • 原来用:
    • ReLU
  • 现在用:
    • SwiGLU
  • 效果:
    • 模型表达能力更强,在相同参数量下性能更好

四、改进方向3:更好的位置编码 —— RoPE

原始问题

  • Transformer需要知道词的位置信息(第1个词、第2个词…)。
  • 原始方案是绝对位置编码:
    • 词1 → 位置1
    • 词2 → 位置2
  • 局限性:
    • 如果训练时最长见过512个词,遇到800个词的文本就不知道怎么办
    • 多个文档拼接训练时,位置信息会混乱

解决方案:旋转位置嵌入(RoPE)

核心思想:

  • 不是在开头加位置信息,而是在注意力计算时,通过“旋转“向量来编码位置
    优点:
  1. 相对位置感知:模型更关注“词A在词B前面3个位置“,而不是“词A在第5个位置“
  2. 更好的外推性:训练时见过2K长度,推理时处理4K也能工作
  3. 适配打包训练:多个文档拼一起训练时,每个文档的第一个词都能正确标记为“位置1“
    应用:
  • Llama 2、Llama 3等新模型都使用RoPE

{%}

旋转位置嵌入是应用在注意力步骤中的,而不是应用在前向传播的开始

{%}

旋转位置嵌入在自注意力中的相关性评分步骤之前,被添加到词元的表示中

五、核心改进对比表

组件原始Transformer2024年改进版好处
注意力全注意力稀疏注意力/GQA速度↑,显存↓
键值矩阵每头独立多头共享(GQA)显存↓50%+
数据读取普通方式Flash Attention速度↑2-4倍
归一化位置后归一化预归一化训练更稳
归一化方法LayerNormRMSNorm速度↑
激活函数ReLUSwiGLU性能↑
位置编码绝对位置RoPE更灵活、可外推

六、完整演进图

2017年原始Transformer  
    ↓  
├─ 注意力太慢 → 稀疏注意力、GQA、Flash Attention  
├─ 训练不稳定 → 预归一化 + RMSNorm  
├─ 表达力不够 → SwiGLU激活函数  
└─ 位置编码弱 → RoPE  

    ↓  
2024年现代Transformer  
(如Llama 3)  

七、一句话总结

这些改进让Transformer在保持强大能力的同时:

  • 训练快了(预归一化+RMSNorm)
  • 推理快了(Flash Attention+GQA)
  • 省显存了(GQA+稀疏注意力)
  • 适应性更强了(RoPE)