0%

深度学习三大支柱

深度学习三大支柱:CNN、RNN 与预训练语言模型

核心认知:深度学习并非单一技术,而是由不同架构组成的“武器库”——CNN 擅长捕捉空间特征(图像),RNN 擅长处理时序依赖(文本、语音),而预训练语言模型(PLM)则通过“预训练+微调”范式,彻底改变了自然语言处理的方式。

本文将从原理到应用,系统讲解三种核心深度学习架构,帮助你建立完整的知识体系。


目录

  1. 卷积神经网络(CNN)
  2. 循环神经网络(RNN)与双向RNN
  3. 预训练语言模型与微调
  4. 三者对比与选型

卷积神经网络(CNN)

为什么需要 CNN?

传统的全连接神经网络在处理图像时会遇到两个致命问题:

问题 说明 后果
参数爆炸 一张 224×224 的彩色图片有 150,528 个像素,全连接层的参数数量 = 输入维度 × 输出维度 需要海量数据和计算资源
丢失空间结构 全连接层将像素展平为一维向量,忽略像素间的空间关系 无法识别边缘、形状等局部模式

CNN 的解决方案:通过局部连接权重共享,大幅减少参数数量,同时保留空间结构信息。

CNN 的核心组件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
┌─────────────────────────────────────────────────────────────────────┐
│ CNN 整体架构 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ 输入图像 卷积层 池化层 卷积层 池化层 全连接层 输出 │
│ │
│ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ ┌─┐ │
│ │ │ ──→ │ │ ──→ │ │ ──→ │ │ ──→ │ │ ──→ │ │ ──→ │ │ │
│ │ 224 │ │ 222 │ │ 111 │ │ 110 │ │ 55 │ │ 4096│ │10│ │
│ │ ×224│ │ ×222│ │ ×111│ │ ×110│ │ ×55 │ │ │ │ │ │
│ │ ×3 │ │ ×32 │ │ ×32 │ │ ×64 │ │ ×64 │ │ │ └─┘ │
│ └─────┘ └─────┘ └─────┘ └─────┘ └─────┘ └─────┘ │
3 32 32 64 64 4096 10
│ (RGB通道) (特征图) (降维后) (特征图) (降维后) (特征向量) (类别) │
│ │
└─────────────────────────────────────────────────────────────────────┘

1. 卷积层(Convolutional Layer)

卷积核(Filter/Kernel):一个小的权重矩阵(如 3×3 或 5×5),在输入图像上滑动,计算局部区域的点积,生成特征图。

1
2
3
4
5
6
7
8
9
10
11
12
卷积操作可视化(步长=1,无填充)

输入图像 (5×5) 卷积核 (3×3) 输出特征图 (3×3)
┌─────────────────┐ ┌─────────┐ ┌─────────────────┐
1 1 1 0 0 │ │ 1 0 1 │ │ 4 3 4
0 1 1 1 0 │ ⊙ │ 0 1 0 │ = │ 2 4 3
0 0 1 1 1 │ │ 1 0 1 │ │ 2 3 4
0 0 1 1 0 │ └─────────┘ └─────────────────┘
0 1 1 0 0
└─────────────────┘

每个输出位置 = 卷积核与输入对应区域的逐元素乘积之和

关键参数

  • 卷积核大小:常见 3×3、5×5、7×7
  • 步长(Stride):卷积核滑动的步长(通常为 1 或 2)
  • 填充(Padding):在输入边缘补零,控制输出尺寸
  • 通道数:输入 RGB 有 3 个通道,输出可以有多个卷积核(如 32、64、128 个)

为什么有效

  • 每个卷积核专门检测一种局部模式(边缘、角点、纹理)
  • 多个卷积核堆叠,从低级到高级逐层抽象

2. 激活函数(Activation Function)

激活函数就是把线性函数转化为非线性函数。引入非线性,使神经网络能够学习复杂模式。

函数 公式 特点
ReLU max(0, x) 计算快、缓解梯度消失(最常用)
Sigmoid 1/(1+e^{-x}) 输出范围 (0,1),适合二分类输出层
Tanh (e^x-e^{-x})/(e^x+e^{-x}) 输出范围 (-1,1),零中心化
1
2
3
# ReLU 激活前后对比
输入特征图: [-2, -1, 0, 1, 2, 3]
ReLU 之后: [0, 0, 0, 1, 2, 3] # 负值归零,正值保留

3. 池化层(Pooling Layer)

下采样操作,减少特征图尺寸,降低计算量,增强平移不变性。

1
2
3
4
5
6
7
8
9
10
11
最大池化(Max Pooling)- 2×2 池化核,步长=2

输入 (4×4) 输出 (2×2)
┌─────────────────┐ ┌─────────┐
1 3 2 4 │ │ 3 4
5 6 1 2 │ ──→ │ 7 8
2 1 7 8 │ └─────────┘
3 4 5 6
└─────────────────┘

每个 2×2 区域取最大值
池化类型 操作 特点
最大池化 取区域内最大值 保留最显著特征(最常用)
平均池化 取区域内平均值 保留整体信息
全局平均池化 整个特征图取平均 替代全连接层,减少参数

4. 全连接层(Fully Connected Layer)

将高维特征图展平,通过全连接网络映射到最终的分类结果。

经典 CNN 架构演进

1
2
3
4
5
6
7
8
9
10
11
12
时间线 ──────────────────────────────────────────────────────────────────→

AlexNet (2012) VGG (2014) ResNet (2015) EfficientNet (2019)
│ │ │ │
▼ ▼ ▼ ▼
┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐
│ 8层 │ │ 19层│ │152层│ │ 缩放│
└─────┘ └─────┘ └─────┘ └─────┘
│ │ │ │
▼ ▼ ▼ ▼
ILSVRC冠军 结构规整 残差连接 复合缩放
引入ReLU+Dropout 3×3卷积堆叠 解决梯度消失 精度/速度平衡
架构 年份 层数 核心创新 ImageNet Top-5 错误率
AlexNet 2012 8 ReLU、Dropout、GPU 训练 15.3%
VGG 2014 16-19 小卷积核(3×3)堆叠 7.3%
GoogLeNet 2014 22 Inception 模块(多尺度卷积) 6.7%
ResNet 2015 50-152 残差连接(解决梯度消失) 3.6%
DenseNet 2017 121-264 密集连接(特征复用) 3.7%
EfficientNet 2019 可变 复合缩放(深度×宽度×分辨率) 2.5%

残差连接(Residual Connection)原理解析

这是 ResNet 的核心创新,解决了深层网络难以训练的问题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
普通连接(无残差)              残差连接(有残差)

x x
│ │
▼ ▼
┌───────┐ ┌───────┐
│ 层 │ │ 层 │
└───────┘ └───────┘
│ │
▼ │
┌───────┐ │
│ 激活 │ ▼
└───────┘ ┌───────────┐
│ │ + (逐元素相加) │
▼ └───────────┘
y

┌───────┐
│ 激活 │
└───────┘


y

普通层:y = F(x) 残差层:y = F(x) + x

恒等映射(短路连接)

学习目标:从学习完整映射 F(x) 降为学习残差 F(x) - x

为什么有效

  • 底层网络至少可以学习到恒等映射(残差为0)
  • 梯度可以直接通过短路连接传播,解决梯度消失
  • 可以训练超过 1000 层的超深网络

CNN 的应用场景

领域 应用 说明
计算机视觉 图像分类、目标检测、语义分割 自动驾驶、医疗影像
视频分析 动作识别、视频分类 监控分析、体育分析
医学影像 X光、CT、MRI 诊断 肺结节检测、骨折识别
人脸识别 人脸检测、人脸验证 手机解锁、安防监控
OCR 文字识别 文档扫描、车牌识别

循环神经网络(RNN)与双向 RNN

为什么需要 RNN?

传统神经网络(包括 CNN)假设输入之间相互独立,无法处理序列数据(文本、语音、时间序列)。

数据类型 序列特性 传统网络的问题
文本 单词顺序决定语义 “猫追老鼠” ≠ “老鼠追猫”
语音 时间顺序至关重要 无法捕捉音频时序依赖
股票价格 前后依赖 无法建模趋势和周期

RNN 的解决方案:引入隐藏状态(Hidden State),让网络具有“记忆”能力——当前输出依赖历史输入。

RNN 核心原理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
RNN 单元展开图(时间步)

t=0 t=1 t=2 t=3
│ │ │ │
x₀ ──┐│ x₁ ──┐│ x₂ ──┐│ x₃ ──┐│
▼▼ ▼▼ ▼▼ ▼▼
┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐
│ RNN │ │ RNN │ │ RNN │ │ RNN │
│ 单元 │←─│ 单元 │←─│ 单元 │←─│ 单元 │
└─────────┘ └─────────┘ └─────────┘ └─────────┘
│ │ │ │
▼ ▼ ▼ ▼
yyyy

隐藏状态 h_t 在时间步之间传递
h_t = f(W_h·h_{t-1} + W_x·x_t + b)

核心公式

1
2
h_t = tanh(W_hh · h_{t-1} + W_xh · x_t + b_h)    # 隐藏状态更新
y_t = W_hy · h_t + b_y # 输出(可选)

关键概念

  • 隐藏状态(h_t):网络的“记忆”,编码历史信息
  • 权重共享:所有时间步使用相同的 W_hh、W_xh 矩阵
  • 循环连接:h_{t-1} → h_t 形成循环,实现记忆

RNN 的变体

1. 标准 RNN(Vanilla RNN)

1
2
3
4
特点:
结构简单,参数量少
梯度消失/爆炸(难以学习长距离依赖)
记忆容量有限(约 10-20 个时间步)
1
2
3
4
# 标准 RNN 的局限性示例
输入: "我 在 北京 长大,... (省略500字) ... 我 喜欢 吃 ?"

难以记住前面的"北京"

2. LSTM(长短期记忆网络)

LSTM 通过门控机制解决了梯度消失问题,能够学习长达数百个时间步的依赖。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
LSTM 单元结构图

┌─────────────────────────────────────────┐
│ LSTM 单元 │
│ │
h_{t-1} ────────┼───┐ ┌──────┐ ┌──────┐ ┌──────┐ │
│ │ │ │ │ │ │ │ │
C_{t-1} ────────┼───┼────│遗忘门│────│输入门│────│输出门│──┼──→ C_t
│ │ │ │ │ │ │ │ │
│ │ └──────┘ └──────┘ └──────┘ │
│ │ │ │ │ │
x_t ────────────┼───┴────────┴───────────┴───────────┴──────┼──→ h_t
│ │
└─────────────────────────────────────────┘

三个门控机制:
┌─────────────────────────────────────────────────────────────┐
│ 遗忘门:决定从细胞状态中丢弃什么信息 │
f_t = σ(W_f·[h_{t-1}, x_t] + b_f) │
├─────────────────────────────────────────────────────────────┤
│ 输入门:决定将哪些新信息存入细胞状态 │
i_t = σ(W_i·[h_{t-1}, x_t] + b_i) │
│ C̃_t = tanh(W_c·[h_{t-1}, x_t] + b_c) │
├─────────────────────────────────────────────────────────────┤
│ 输出门:决定输出哪些信息 │
o_t = σ(W_o·[h_{t-1}, x_t] + b_o) │
h_t = o_t * tanh(C_t) │
└─────────────────────────────────────────────────────────────┘

LSTM vs 标准 RNN

特性 标准 RNN LSTM
长距离依赖 难以学习(>20步) 可以学习(>500步)
梯度消失 严重 大幅缓解
参数量 多(约4倍)
计算复杂度 较高
适用场景 短序列、简单任务 长序列、复杂任务

3. GRU(门控循环单元)

LSTM 的简化版本,保留核心门控机制,参数更少。

1
2
3
4
GRU vs LSTM:
- LSTM:3个门(遗忘、输入、输出)+ 独立的细胞状态 C
- GRU:2个门(更新门、重置门)+ 隐藏状态 H
- GRU 参数更少,训练更快,效果通常与 LSTM 相当

双向 RNN(Bidirectional RNN)

标准 RNN 只能利用过去的信息,无法看到“未来”。双向 RNN 通过两个方向的信息流解决这个问题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
双向 RNN 结构图

前向 RNN ──────────────────────────────────────→
┌─────────┐ ┌─────────┐ ┌─────────┐
x₀ ───→│ RNN │───→│ RNN │───→│ RNN │───→
└─────────┘ └─────────┘ └─────────┘
│ │ │
▼ ▼ ▼
┌─────┐ ┌─────┐ ┌─────┐
concatconcatconcat
└─────┘ └─────┘ └─────┘
▲ ▲ ▲
┌─────────┐ ┌─────────┐ ┌─────────┐
x₀ ───→│ RNN │───→│ RNN │───→│ RNN │───→
└─────────┘ └─────────┘ └─────────┘
←──────────────────────────────────────────────
反向 RNN


输出 y_t = concat(前向隐藏状态 h_t→, 反向隐藏状态 h_t←)

优势:每个时间步的输出同时利用了上下文信息(过去+未来)

应用场景

任务 为什么需要双向 示例
命名实体识别 需要上下文判断”Apple”是公司还是水果 “Apple stock” vs “apple fruit”
情感分析 语义依赖前后文 “not good” 整体为负面
机器翻译 目标语言同时依赖源语言前后文 句法结构需整体理解
语音识别 音素依赖前后发音 连读、变调需要上下文

RNN 的应用场景

领域 任务 推荐架构
自然语言处理 文本分类、情感分析 BiLSTM + Attention
机器翻译 序列到序列 LSTM/GRU + Attention + Transformer
语音识别 语音转文字 BiLSTM + CTC
时间序列预测 股票价格、天气 LSTM
命名实体识别 人名、地名识别 BiLSTM + CRF

预训练语言模型与微调

NLP 范式的演进

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
时间线 ──────────────────────────────────────────────────────────────────→

规则时代 (1980-1990) 统计时代 (1990-2018) 预训练时代 (2018-至今)
│ │ │
▼ ▼ ▼
┌──────┐ ┌──────┐ ┌──────┐
│ 人工 │ │ 特征 │ │ 预训练 │
│ 规则 │ │ 工程 │ │ +微调 │
└──────┘ └──────┘ └──────┘
│ │ │
▼ ▼ ▼
• 词典匹配 • TF-IDF • BERT
• 正则表达式 • 词向量(Word2Vec) • GPT
• 语法规则 • 传统ML模型 • RoBERTa
• T5
范式 核心思想 代表技术 缺点
规则时代 人工编写语言规则 正则表达式、词典 无法泛化,维护成本高
统计时代 特征工程 + 传统ML TF-IDF、Word2Vec、SVM 特征工程依赖人工
预训练时代 通用语言知识 + 下游微调 BERT、GPT、RoBERTa 计算资源需求大

预训练-微调范式核心思想

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
┌─────────────────────────────────────────────────────────────────────┐
│ 预训练 + 微调范式 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ 阶段1:预训练(Pre-training) │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ 海量无标注文本 │ │
│ │ (Wikipedia, Books, Web Text) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────┐ │ │
│ │ │ 预训练任务 │ │ │
│ │ │ (MLM/NSP/LM) │ │ │
│ │ └─────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────┐ │ │
│ │ │ 基础模型 │ ← 学习通用语言知识 │ │
│ │ └─────────────┘ │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │ │
│ │ 迁移 │
│ ▼ │
│ 阶段2:微调(Fine-tuning) │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ 少量标注数据(下游任务) │ │
│ │ (情感分类、命名实体识别、问答、翻译) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────┐ │ │
│ │ │ 微调过程 │ ← 适配具体任务 │ │
│ │ └─────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────┐ │ │
│ │ │ 任务专用模型 │ │ │
│ │ └─────────────┘ │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────┘

核心思想

  1. 预训练:在海量无标注文本上学习通用语言知识(语法、语义、常识)
  2. 微调:在少量标注数据上适配特定下游任务
  3. 迁移:将通用知识迁移到具体应用,大幅降低数据需求

核心预训练模型

BERT(Bidirectional Encoder Representations from Transformers)

Google 2018 年提出,革命性的双向编码器模型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
BERT 预训练任务

任务1:掩码语言模型(MLM - Masked Language Model
┌─────────────────────────────────────────────────────────────────┐
│ 输入: [CLS] 我 爱 [MASK] 京 [SEP] 我 爱 中 国 [SEP] │
│ ↑ │
│ 被掩码的词 │
│ │
│ 预测: [MASK] 位置应该是什么词? → "北"
│ │
│ 作用:让模型学习双向上下文理解 │
└─────────────────────────────────────────────────────────────────┘

任务2:下一句预测(NSP - Next Sentence Prediction
┌─────────────────────────────────────────────────────────────────┐
│ 输入: [CLS] 今天天气很好 [SEP] 我们去公园吧 [SEP] │
│ ↑ ↑ │
│ 句子A 句子B
│ │
│ 预测:句子B是否是句子A的下一句? → IsNext(是) │
│ │
│ 作用:让模型学习句子间的关系 │
└─────────────────────────────────────────────────────────────────┘

BERT 的核心创新

  • 双向编码:同时利用左右上下文(区别于 GPT 的单向)
  • Transformer 架构:自注意力机制,并行计算
  • 通用性强:一个模型适配多种 NLP 任务

GPT(Generative Pre-trained Transformer)

OpenAI 开发的生成式预训练模型,擅长文本生成。

对比维度 BERT GPT
架构 仅编码器(Encoder-only) 仅解码器(Decoder-only)
注意力方向 双向(同时看左右) 单向(只看左边,自回归)
预训练任务 MLM + NSP 语言模型(预测下一个词)
擅长任务 理解类(分类、抽取) 生成类(写作、对话)
代表模型 BERT、RoBERTa、ALBERT GPT-2、GPT-3、GPT-4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
架构对比图

BERT(双向): GPT(单向自回归):

[CLS] 我 爱 [MASK] 京 我 爱 北 京
↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
│ │ │ │ │ │ │ │
┌─┴──┐ ┌─┴──┐ ┌─┴──┐ ┌─┴──┐ ┌─┴──┐ ┌─┴──┐ ┌─┴──┐ ┌─┴──┐
│编码│ │编码│ │编码│ │编码│ │解码│ │解码│ │解码│ │解码│
│层 │ │层 │ │层 │ │层 │ │层 │ │层 │ │层 │ │层 │
└─┬──┘ └─┬──┘ └─┬──┘ └─┬──┘ └─┬──┘ └─┬──┘ └─┬──┘ └─┬──┘
│ │ │ │ │ │ │ │
└─────┴─────┴─────┘ └─────┘ │ │
信息双向流动 只能看到左边的词

其他重要模型

模型 发布方 核心特点 适用场景
RoBERTa Facebook BERT 改进版,更大数据+更长时间训练 各类理解任务
ALBERT Google 参数共享,模型更小 资源受限场景
DistilBERT Hugging Face 知识蒸馏,模型缩小40% 轻量级部署
T5 Google Text-to-Text 统一框架 所有任务统一为生成
XLNet CMU/Google 排列语言模型(PLM) 长文本理解

微调(Fine-tuning)流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
┌─────────────────────────────────────────────────────────────────────┐
│ 微调流程示例(文本分类) │
├─────────────────────────────────────────────────────────────────────┤
│ │
1. 加载预训练模型 │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ from transformers import AutoModelForSequenceClassification │ │
│ │ model = AutoModelForSequenceClassification.from_pretrained( │ │
│ │ "bert-base-chinese", # 预训练模型名称 │ │
│ │ num_labels=2 # 分类类别数(正面/负面) │ │
│ │ ) │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
2. 准备下游任务数据 │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ 训练数据: │ │
│ │ "这部电影太棒了!" → 正面 │ │
│ │ "剧情无聊,浪费时间" → 负面 │ │
│ │ "演技出色,特效震撼" → 正面 │ │
│ │ ...(只需要几千条标注数据) │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
3. 添加任务特定层 │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ 预训练 BERT 输出 (768维) → 分类头 (2维) → softmax → 概率 │ │
│ │ │ │
│ │ • 预训练部分:参数更新(学习率较小) │ │
│ │ • 分类头:参数从头训练(学习率较大) │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
4. 训练微调 │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ 超参数设置: │ │
│ │ • 学习率:2e-5(比预训练小10-100倍) │ │
│ │ • 批次大小:16-32 │ │
│ │ • 训练轮数:2-5 epoch │ │
│ │ • 优化器:AdamW │ │
│ │ │ │
│ │ 训练过程: │ │
│ │ for epoch in range(3): │ │
│ │ for batch in data_loader: │ │
│ │ loss = model(batch) │ │
│ │ loss.backward() │ │
│ │ optimizer.step() │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
5. 评估与部署 │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ 模型保存后可用于: │ │
│ │ • API 服务部署 │ │
│ │ • 批量推理 │ │
│ │ • 边缘设备部署 │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────┘

微调的优势

对比维度 传统方法(从头训练) 预训练+微调
标注数据需求 数万到数十万 几百到几千
训练时间 数天到数周 几分钟到几小时
计算资源 多 GPU,大规模集群 单 GPU,普通工作站
模型性能 依赖数据量 通常优于从头训练
可迁移性 每个任务单独训练 一个基础模型适配多任务

应用案例

任务 预训练模型 微调数据 效果
情感分析 BERT-base 10,000 条影评 准确率 94%
命名实体识别 BERT-base 5,000 条标注 F1 88%
问答系统 RoBERTa SQuAD 2.0 EM 86%
文本摘要 T5 CNN/DailyMail ROUGE 42
文本生成 GPT-2 特定领域语料 高质量生成

三者对比与选型

核心对比表

维度 CNN RNN/LSTM 预训练模型 (BERT/GPT)
擅长数据类型 图像、空间数据 时序数据、中等长度文本 长文本、复杂语言理解
核心机制 卷积+池化 循环隐藏状态 自注意力 + 预训练
并行计算 低(序列依赖) 高(Transformer)
长距离依赖 (感受野有限) LSTM 可处理 100-500 步 可处理 512-2048 步
参数量 数百万到数千万 数百万到数千万 数亿到数千亿
训练数据需求 万到百万 万到百万 预训练:千亿级
微调:千到万
硬件需求 GPU(中) GPU(中) GPU(高)
推理速度 中(串行) 中(长文本较慢)
可解释性 特征可视化 较难 很难

选型决策树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
开始


数据类型是什么?

├── 图像/视频
│ │
│ └── 数据量?
│ ├── 小(<1万)→ 迁移学习(使用预训练 CNN)
│ └── 大(>1万)→ 从头训练 CNN

├── 时序数据/中等长度文本(<500词)
│ │
│ └── 需要上下文方向?
│ ├── 只看过去(时间序列预测)→ 标准 RNN/LSTM
│ └── 需要双向理解(文本分类)→ BiLSTM

└── 长文本/复杂语言任务

├── 理解类任务(分类、抽取、QA)
│ └── 资源充足?→ BERT/RoBERTa
│ └── 资源受限?→ DistilBERT/AlBERT

└── 生成类任务(写作、翻译、对话)
└── GPT 系列 / T5

总结

模型 一句话总结 何时使用
CNN “滑动窗口识别局部特征” 图像处理、需要快速推理
RNN/LSTM “时序记忆,依赖过去” 时间序列、语音、中等长度文本
BiRNN “前后文都要看” 需要上下文理解的任务(NER、情感分析)
预训练模型 “先学通用知识,再适配任务” NLP 任务,特别是标注数据有限时

欢迎关注我的其它发布渠道