分类头、池化、CLS

#2026/01/01 #id

目录

一、什么是“分类头“?

1. 基本概念

分类头(Classification Head) 就是接在BERT模型后面的一个小型神经网络,专门用来做分类任务。

2. 形象比喻

BERT模型 → 分类头 → 输出类别  
  ↓          ↓         ↓  
 大脑      决策器    最终判断  

想象一下:

  • BERT 是一个理解文字的“大脑“
  • 分类头 是根据大脑理解的内容做出“决策“的部分
  • 比如判断:“这是好评还是差评?”

3. 具体结构

BERT模型  
  ↓  
前馈神经网络(分类头)  
  ↓  
输出:25% 负例 / 75% 正例  

分类头通常是一个前馈神经网络,包含:

  • 一层或多层全连接层
  • 激活函数
  • 最后的softmax层输出概率

4. 代码示例

from transformers import AutoModelForSequenceClassification  

# 加载带分类头的模型  
model = AutoModelForSequenceClassification.from_pretrained(  
    "bert-base-cased",   
    num_labels=2  # 分类头输出2个类别  
)  

这里的 num_labels=2 就是在BERT上加一个分类头,输出2个类别的概率。


二、什么是“池化“(Pooling)?

1. 基本概念

池化(Pooling) 就是把多个数字压缩成一个数字的过程,目的是提取最重要的信息。

2. 为什么需要池化?

BERT处理一个句子后,会为每个词元生成一个向量:

输入:  "I  love  cats"  
        ↓    ↓     ↓  
BERT输出:[向量1, 向量2, 向量3]  # 3个词元 = 3个向量  

但我们想要整个句子的一个向量,怎么办?
用池化把多个向量合并成一个!

3. 平均池化(Mean Pooling)

平均池化是最常用的方法,就是取所有词元向量的平均值

句子:"I love cats"  

词元向量:  
I    → [0.1, 0.2, 0.3]  
love → [0.5, 0.6, 0.7]  
cats → [0.9, 1.0, 1.1]  

平均池化后:  
句子向量 → [(0.1+0.5+0.9)/3, (0.2+0.6+1.0)/3, (0.3+0.7+1.1)/3]  
        = [0.5, 0.6, 0.7]  

4. 形象比喻

想象你有一本书的每页内容,池化就是:

  • 平均池化:把所有页的内容“混合“成一个摘要
  • 最大池化:只保留最重要的那几页
  • [CLS]池化:只看目录页(第一个特殊词元)

5. SBERT中的池化流程

输入句子:"My dog is cute"  
    ↓  
BERT处理  
    ↓  
词元嵌入:[向量1, 向量2, 向量3, 向量4]  
    ↓  
平均池化层  
    ↓  
句子嵌入:一个固定维度的向量(比如768维)  

6. 代码示例

from sentence_transformers import models, SentenceTransformer  

# 创建词嵌入模型  
word_embedding_model = models.Transformer("bert-base-uncased")  

# 添加池化层(使用平均池化)  
pooling_model = models.Pooling(  
    word_embedding_model.get_word_embedding_dimension(),  
    pooling_mode_mean_tokens=True  # 平均池化  
)  

# 组合成完整的sentence-transformers模型  
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])  

三、SBERT的改进总结

原来的BERT做分类

句子 → BERT → [CLS]词元 → 分类头 → 输出类别  

SBERT的改进

句子 → BERT → 所有词元向量 → 池化层 → 句子嵌入向量  

关键区别

特征原BERTSBERT
后续层分类头池化层
输出类别标签句子向量
目的分类任务生成可复用的嵌入
灵活性只能做训练时的任务嵌入可用于多种任务

四、为什么这样改进?

1. 分类头的问题

  • 只能做一件事
    • 训练时是情感分析,就只能做情感分析
  • 不能比较句子
    • 无法直接计算两个句子的相似度

2. 池化层的优势

  • 生成通用表示
    • 一个向量可以用于搜索、聚类、分类等多种任务
  • 支持句子比较
    • 可以直接计算两个句子向量的余弦相似度
  • 可以预计算
    • 句子向量可以提前算好存起来,查询时超快

3. 实际效果

# 使用SBERT  
from sentence_transformers import SentenceTransformer  

model = SentenceTransformer('all-mpnet-base-v2')  

# 生成句子嵌入(池化后的结果)  
embeddings = model.encode([  
    "I love cats",  
    "I adore felines",   
    "The weather is nice"  
])  

# 可以直接比较相似度  
from sklearn.metrics.pairwise import cosine_similarity  
print(cosine_similarity(embeddings))  

# 前两个句子(意思相近)相似度高  
# 第三个句子(无关)相似度低  

总结一句话

  • 分类头 = 专门做分类决策的小网络
  • 池化 = 把多个向量压缩成一个向量的技术

SBERT用池化替代分类头,让模型从“只会分类“变成“能生成通用句子表示“,这样就能用于搜索、聚类等更多任务!

CLS ?

根据书中的内容,[CLS]是一个特殊的词元(token),其全称是 Classification Token,主要用于分类任务。让我为你详细解释:

  1. 定义和作用
  • [CLS]Classification 的缩写
  • 是在 BERT 模型输入序列开头添加的特殊词元
  • 主要用于表示整个输入序列的语义信息
  1. 具体使用方式
  • 通常用于在特定任务(如分类)上进行模型微调
  • 可以作为输入嵌入(input embedding),用于后续的分类操作
  • 模型会为 [CLS]词元生成一个特殊的向量表示
  1. 举例说明
输入序列:`[CLS]`My name is Maarten [SEP]  
                ↑  
                [CLS]词元位于序列开头  
  1. 使用场景
  • 文本分类
  • 句子对相似度判断
  • 命名实体识别等任务

简单来说,[CLS]就是模型用来“总结“整个输入文本语义信息的特殊标记。