分类头、池化、CLS
#2026/01/01 #id
目录
一、什么是“分类头“?
1. 基本概念
分类头(Classification Head) 就是接在BERT模型后面的一个小型神经网络,专门用来做分类任务。
2. 形象比喻
BERT模型 → 分类头 → 输出类别
↓ ↓ ↓
大脑 决策器 最终判断
想象一下:
- BERT 是一个理解文字的“大脑“
- 分类头 是根据大脑理解的内容做出“决策“的部分
- 比如判断:“这是好评还是差评?”
3. 具体结构
BERT模型
↓
前馈神经网络(分类头)
↓
输出:25% 负例 / 75% 正例
分类头通常是一个前馈神经网络,包含:
- 一层或多层全连接层
- 激活函数
- 最后的
softmax层输出概率
4. 代码示例
from transformers import AutoModelForSequenceClassification
# 加载带分类头的模型
model = AutoModelForSequenceClassification.from_pretrained(
"bert-base-cased",
num_labels=2 # 分类头输出2个类别
)
这里的 num_labels=2 就是在BERT上加一个分类头,输出2个类别的概率。
二、什么是“池化“(Pooling)?
1. 基本概念
池化(Pooling) 就是把多个数字压缩成一个数字的过程,目的是提取最重要的信息。
2. 为什么需要池化?
BERT处理一个句子后,会为每个词元生成一个向量:
输入: "I love cats"
↓ ↓ ↓
BERT输出:[向量1, 向量2, 向量3] # 3个词元 = 3个向量
但我们想要整个句子的一个向量,怎么办?
→ 用池化把多个向量合并成一个!
3. 平均池化(Mean Pooling)
平均池化是最常用的方法,就是取所有词元向量的平均值:
句子:"I love cats"
词元向量:
I → [0.1, 0.2, 0.3]
love → [0.5, 0.6, 0.7]
cats → [0.9, 1.0, 1.1]
平均池化后:
句子向量 → [(0.1+0.5+0.9)/3, (0.2+0.6+1.0)/3, (0.3+0.7+1.1)/3]
= [0.5, 0.6, 0.7]
4. 形象比喻
想象你有一本书的每页内容,池化就是:
- 平均池化:把所有页的内容“混合“成一个摘要
- 最大池化:只保留最重要的那几页
[CLS]池化:只看目录页(第一个特殊词元)
5. SBERT中的池化流程
输入句子:"My dog is cute"
↓
BERT处理
↓
词元嵌入:[向量1, 向量2, 向量3, 向量4]
↓
平均池化层
↓
句子嵌入:一个固定维度的向量(比如768维)
6. 代码示例
from sentence_transformers import models, SentenceTransformer
# 创建词嵌入模型
word_embedding_model = models.Transformer("bert-base-uncased")
# 添加池化层(使用平均池化)
pooling_model = models.Pooling(
word_embedding_model.get_word_embedding_dimension(),
pooling_mode_mean_tokens=True # 平均池化
)
# 组合成完整的sentence-transformers模型
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
三、SBERT的改进总结
原来的BERT做分类
句子 → BERT → [CLS]词元 → 分类头 → 输出类别
SBERT的改进
句子 → BERT → 所有词元向量 → 池化层 → 句子嵌入向量
关键区别:
| 特征 | 原BERT | SBERT |
|---|---|---|
| 后续层 | 分类头 | 池化层 |
| 输出 | 类别标签 | 句子向量 |
| 目的 | 分类任务 | 生成可复用的嵌入 |
| 灵活性 | 只能做训练时的任务 | 嵌入可用于多种任务 |
四、为什么这样改进?
1. 分类头的问题
- 只能做一件事:
- 训练时是情感分析,就只能做情感分析
- 不能比较句子:
- 无法直接计算两个句子的相似度
2. 池化层的优势
- 生成通用表示:
- 一个向量可以用于搜索、聚类、分类等多种任务
- 支持句子比较:
- 可以直接计算两个句子向量的余弦相似度
- 可以预计算:
- 句子向量可以提前算好存起来,查询时超快
3. 实际效果
# 使用SBERT
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-mpnet-base-v2')
# 生成句子嵌入(池化后的结果)
embeddings = model.encode([
"I love cats",
"I adore felines",
"The weather is nice"
])
# 可以直接比较相似度
from sklearn.metrics.pairwise import cosine_similarity
print(cosine_similarity(embeddings))
# 前两个句子(意思相近)相似度高
# 第三个句子(无关)相似度低
总结一句话:
- 分类头 = 专门做分类决策的小网络
- 池化 = 把多个向量压缩成一个向量的技术
SBERT用池化替代分类头,让模型从“只会分类“变成“能生成通用句子表示“,这样就能用于搜索、聚类等更多任务!
CLS ?
根据书中的内容,[CLS]是一个特殊的词元(token),其全称是 Classification Token,主要用于分类任务。让我为你详细解释:
- 定义和作用
[CLS]是Classification的缩写- 是在 BERT 模型输入序列开头添加的特殊词元
- 主要用于表示整个输入序列的语义信息
- 具体使用方式
- 通常用于在特定任务(如分类)上进行模型微调
- 可以作为输入嵌入(input embedding),用于后续的分类操作
- 模型会为
[CLS]词元生成一个特殊的向量表示
- 举例说明
输入序列:`[CLS]`My name is Maarten [SEP]
↑
[CLS]词元位于序列开头
- 使用场景
- 文本分类
- 句子对相似度判断
- 命名实体识别等任务
简单来说,[CLS]就是模型用来“总结“整个输入文本语义信息的特殊标记。