AI 大模型的训练过程
一、数据准备 - 模型的原材料奠基
数据准备的核心目标是构建干净、全面、合规的训练数据集,主要包含数据采集、数据清洗、数据预处理三个核心环节。
1. 数据采集
(1)数据来源
数据来源分为三类,覆盖文本与多模态场景:
- 公开文本库:学术论文、百科全书、新闻资讯、书籍、论坛、博客等
- 多模态数据:图片、视频、音频、表格等
- 私有数据:企业内部文档、行业数据集等
(2)低质来源过滤
需排除垃圾邮件、重复内容、极端言论、错误信息(如伪科学、谣言),避免模型学习偏见。
2. 数据清洗
核心目的是剔除杂质、统一格式,解决原始数据的噪声、冗余、偏见问题:
(1)格式标准化
- 文本:统一为 UTF-8 编码
- 图片:统一分辨率、格式(jpg/png)
- 音频:统一采样率
(2)噪声过滤
- 用正则表达式删除特殊符号、乱码
- 通过语法检查工具(如 Grammarly API)修正错别字
(3)去重与去冗余
- 用哈希算法识别重复文本
- 合并相似内容
(4)偏见修正
通过人工标注或算法检测(如关键词匹配)删除/平衡含性别、地域歧视等偏见的内容。
3. 数据预处理
核心是将原始数据转换为模型可理解的格式,并完成数据集划分:
(1)模型可理解格式转换
| 数据类型 | 转换方式 |
|---|---|
| 文本数据 | 先分词拆分为最小语义单元(token),再通过嵌入层将 token 映射为高维向量 |
| 图像数据 | 将像素值归一化(如 0-255 缩放到 0-1),通过卷积层提取特征向量 |
(2)数据集划分
| 数据集类型 | 占比 | 作用 |
|---|---|---|
| 训练集 | 80%-90% | 用于模型学习核心规律 |
| 验证集 | 5%-10% | 用于调整模型参数 |
| 测试集 | 5%-10% | 用于最终性能评估(不参与训练) |
二、模型架构设计 - 模型的骨架搭建
模型架构决定数据的流动方式和学习逻辑,是大模型能力的核心载体,主要包含核心网络架构选择、超参数设定、辅助模块设计三部分。
1. 核心网络架构选择
(1)Transformer 架构(主流)
Transformer 是 2017 年 Google 团队在论文《Attention Is All You Need》中提出的神经网络架构,是当前大语言模型(如 GPT、BERT)、多模态模型(如 DALL·E)的核心基础。其核心优势是基于自注意力机制实现并行计算,突破了传统 RNN/LSTM 的序列依赖限制。
① 核心优势
| 优势项 | 具体说明 |
|---|---|
| 并行计算能力 | RNN 需按序列顺序处理,Transformer 可同时处理所有 token,训练效率提升数十倍 |
| 长距离依赖捕捉 | RNN 易因梯度消失导致长文本(1000 词以上)依赖捕捉能力弱;自注意力机制可直接计算任意两个 token 的关联,无距离限制 |
| 灵活性与扩展性 | 适配 NLP 任务(翻译、生成、理解),也可扩展至多模态领域(图像、音频),仅需修改输入嵌入方式 |
② 自注意力机制
能捕捉数据中的长距离依赖关系,例如“小明今天去公园,他很开心”中,“他”与“小明”的关联。
③ 核心架构(编码器-解码器)
整体流程:
输入文本 ->> 编码器 ->> 上下文特征 ->> 解码器 ->> 输出文本
-
编码器(Encoder):由 N 个相同的编码器层堆叠而成,将输入序列转化为含上下文信息的特征向量。
每个编码器层包含两个核心子层,且子层均配备残差连接(解决梯度消失)和层归一化(加速收敛):- 多头自注意力层:让每个 token 关注序列中其他相关 token,捕捉上下文依赖;“多头”指将注意力拆分到多个并行子空间,捕捉不同维度关联(语义、语法),最后拼接增强表达。
- 前馈神经网络:对自注意力层输出做非线性变换(两层线性映射 + ReLU 激活函数),为每个 token 添加局部特征(无跨 token 交互)。
-
解码器(Decoder):由 N 个相同的解码器层堆叠而成,基于编码器输出的上下文特征生成目标序列。
每个解码器层包含三个核心子层,同样配备残差连接和层归一化:- 掩码多头自注意力层:与编码器自注意力类似,添加掩码机制确保生成第 i 个 token 时,仅能看到前 i-1 个 token(避免泄露后续信息)。
- 编码器-解码器自注意力层:让解码器的每个 token 关注编码器输出中相关的 token,建立输入与输出的关联。
- 前馈神经网络:与编码器的前馈神经网络结构一致,处理局部特征。
④ 输入输出处理
| 环节 | 处理方式 |
|---|---|
| 输入(文本→特征向量) | 1. 嵌入层:将每个 token 转化为固定维度向量(如 512 维),理解词的语义; 2. 位置编码:通过正弦/余弦函数为 token 添加位置信息,区分“我爱你”和“你爱我”的语义差异(弥补 Transformer 无时序结构的缺陷) |
| 输出(特征向量→文本) | 解码器最终输出经线性层(映射到词表维度)+ softmax 层(生成词概率分布),通过采样(如贪心算法)选择概率最高的词,逐步生成完整序列 |
⑤ 基于 Transformer 的经典模型
| 模型 | 核心架构 | 适用场景 |
|---|---|---|
| BERT | 仅使用编码器 | 文本理解任务(文本分类、问答) |
| GPT | 仅使用解码器 | 文本生成任务(文本续写、对话) |
| T5 | 编码器+解码器 | 兼顾理解与生成任务 |
(2)RNN/LSTM
适用于序列数据,但长距离依赖捕捉能力弱,已逐渐被 Transformer 替代。
(3)卷积神经网络(CNN)
多用于图像、音频的局部特征提取。
2. 超参数设定
超参数是训练前人工设定的全局参数,直接影响模型复杂度和训练成本:
(1)模型规模相关参数
- 层数(L)
- 注意力头数(H)
- 隐藏层维度(D)
(2)训练相关参数
| 参数 | 说明 |
|---|---|
| 批次大小(Batch Size) | 一次训练的样本数量,需匹配 GPU 内存 |
| 学习率(Learning Rate) | 控制参数更新幅度,通常从 0.001 开始衰减 |
| 训练轮次(Epochs) | 数据集重复训练的次数,避免过拟合 |
3. 辅助模块设计
| 模块 | 作用 |
|---|---|
| 位置编码(Positional Encoding) | 为 token 添加位置信息,区分语序导致的语义差异(如“我吃苹果”和“苹果吃我”),常用正弦/余弦函数或可学习向量实现 |
| 残差连接(Residual Connection) | 将前一层输出直接叠加到后一层输入,解决深层模型的梯度消失问题 |
| 层归一化(Layer Normalization) | 标准化每一层的输入数据分布,加速模型收敛 |
三、预训练 - 模型的通识教育
预训练是大模型的基础学习阶段,目标是让模型从海量数据中学习通用知识(语法、常识、逻辑),而非针对特定任务,是通用智能的核心来源。
1. 预训练任务设计
通过间接目标引导模型学习数据规律,核心任务如下:
| 任务类型 | 核心逻辑 |
|---|---|
| 掩码语言模型(MLM) | 随机遮盖文本中的部分 token(如“小明 [MASK] 天去公园”),让模型预测被遮盖的 token,强制学习上下文语义 |
| 因果语言模型(CLM) | 给定前文文本(如“小明今天去公园”),让模型预测下一个 token(如“他很”),模拟人类续写逻辑,适配生成任务 |
| 对比学习(Contrastive Learning) | 对同一文本生成相似样本(同义词替换)和不相似样本(随机打乱),让模型区分差异,提升语义理解能力 |
| 多模态预训练任务 | 给定“图片 + 文本描述”对,让模型学习图文匹配(判断文本是否描述图片)或文本生成图片特征,实现跨模态理解 |
2. 算力支撑
预训练需海量算力,是大模型训练的核心门槛:
(1)硬件需求
通常使用 GPU 集群或 TPU(Google 定制芯片),单台 GPU 内存需 ≥ 40 GB。
(2)分布式训练策略
通过以下策略实现大规模模型训练:
- 数据并行:将数据拆分到多 GPU,同步更新参数
- 模型并行:将模型层拆分到多 GPU,解决单 GPU 内存不足问题
- 流水线并行:多 GPU 按层分工,提升训练效率
3. 训练过程监控
通过多维度监控避免模型训练“走偏”:
(1)损失函数(Loss Function)
监控模型预测值与真实值的差异(如交叉熵损失),损失持续下降说明模型有效学习。
(2)梯度检查
- 梯度爆炸:通过梯度裁剪限制梯度最大值
- 梯度消失:通过残差连接缓解
(3)验证集性能
定期用验证集测试模型(如预测准确率),若验证集性能下降,可能出现过拟合,需及时干预。
① 过拟合定义
模型过度学习训练数据的特征,甚至将噪声、异常值也当作通用规律,导致训练集表现极好、测试集/新数据表现糟糕,失去泛化能力。
② 过拟合本质原因
模型复杂度与数据质量/数量不匹配,具体分为两类:
| 原因类型 | 具体表现 |
|---|---|
| 模型能力过剩 | 1. 模型结构过于复杂; 2. 参数数量过多(参数量远超数据量,模型强行记忆样本而非学习通用模式) |
| 数据支撑不足 | 1. 数据量过少(样本无法覆盖所有场景); 2. 数据噪声多(错误标注、异常值); 3. 数据分布不均(训练集与测试集场景差异大) |
③ 过拟合解决方法
| 优化维度 | 具体措施 |
|---|---|
| 模型层面 | 1. 简化模型结构(减少层数/神经元数、降低多项式次数); 2. 正则化(添加 L1/L2 正则,限制参数过大); 3. 早停(验证误差停止下降时停止训练); 4. 随机去除依赖(训练时随机关闭部分神经元); 5. 集成学习(训练多个简单模型,投票/平均结果) |
| 数据层面 | 1. 增加数据量; 2. 数据增强(图片旋转、文字同义替换等); 3. 清洗数据(删除异常值、修正错误标注) |
四、微调 - 模型的专业培训
预训练模型仅具备通用知识,微调目标是让模型适配医疗问答、法律文档分析等特定场景。
1. 两类主流微调方式
| 微调方式 | 核心逻辑 | 适用场景 | 缺点 |
|---|---|---|---|
| 全参数微调(Full Fine-Tuning) | 冻结预训练模型部分底层参数(学习通用规律的层),微调上层参数;或直接微调全部参数 | 有足量任务特定数据(如 10 万+ 医疗问答样本),需深度适配任务 | 参数规模大,算力成本高 |
| 参数高效微调(PEFT) | 仅微调少量任务相关参数,降低算力需求 | 任务数据有限(如 1 万+ 样本)或算力有限(单 GPU 微调) | - |
(1)PEFT 常见方法
- LoRA(Low-Rank Adaptation):在注意力层插入低秩矩阵(如 12288×12288 拆分为 12288×64 + 64×12288),仅微调低秩矩阵参数(参数量减少 10-100 倍)。
- Prefix Tuning:在输入文本前添加可学习的前缀 token,仅微调前缀参数,不修改预训练模型主体。
2. 微调数据与任务设计
- 数据:需高质量任务特定数据(如医疗微调需“症状-诊断”配对数据),可通过人工标注或数据增强获取。
- 任务:将具体需求转化为模型可理解的格式(如“输入:患者发烧、咳嗽,输出:可能为感冒,建议多喝热水”)。
五、评估与迭代 - 模型的质量验收
评估目标是全面检测模型性能上限和缺陷,迭代优化针对性解决问题。
1. 评估维度
(1)定量评估(指标衡量)
| 任务类型 | 核心指标 | 指标说明 |
|---|---|---|
| 文本生成 | BLEU、ROUGE、CIDEr | 衡量生成文本与参考文本的相似度 |
| 分类任务(情感分析等) | 准确率、F1 分数 | 衡量分类正确性 |
| 问答任务 | EM(精确匹配)、F1 分数 | 衡量回答与标准答案的匹配度 |
| 多模态任务 | 图文匹配准确率、生成图片质量 | 衡量跨模态理解与生成能力 |
(2)定性评估(人工检测)
- 人工标注:邀请领域专家评估模型输出(如医疗问答准确性、法律建议合规性)。
- 用户反馈:通过小规模内测,收集生成内容流畅度、逻辑连贯性等评价。
- 安全检测:测试模型是否生成暴力、歧视、虚假信息等有害内容,通过 RLHF 等对齐技术修正。
2. 迭代优化
针对评估发现的缺陷,从多维度优化:
| 优化维度 | 具体措施 |
|---|---|
| 数据层面 | 某领域表现差时,补充该领域数据重新微调 |
| 模型层面 | 过拟合时,增加数据量或使用正则化(如 Dropout 随机丢弃部分神经元) |
| 算法层面 | 生成内容重复时,优化采样策略(如 Top-P 采样替代贪心采样) |
六、部署与优化 - 模型的落地应用
训练后的模型需部署到实际场景(APP、API 接口),核心目标是降低响应延迟、减少资源消耗,提升使用效率。
1. 模型压缩(减小模型体积)
| 压缩方式 | 核心逻辑 |
|---|---|
| 量化(Quantization) | 将模型参数从 32 位浮点数(FP32)压缩为 16 位(FP16)、8 位(FP8/INT8)、4 位(INT4),大幅降低模型体积和推理延迟 |
| 剪枝(Pruning) | 删除模型中冗余的神经元或连接(如对损失贡献小的注意力头),在不影响性能的前提下减少参数量 |
| 蒸馏(Knowledge Distillation) | 用大模型(教师模型)的输出指导小模型(学生模型)训练,让小模型具备接近大模型的性能 |
2. 部署框架和环境
(1)框架选择
使用 TensorRT(NVIDIA)、ONNX Runtime(跨平台)、TensorFlow Lite(移动端)等框架优化推理速度。
(2)部署方式
| 部署方式 | 说明 |
|---|---|
| 云端部署 | 将模型部署到云服务器,通过 API 接口提供服务 |
| 端侧部署 | 将压缩后的模型部署到终端设备(手机、嵌入式设备),如手机端 AI 助手(无需联网响应) |
(3)并发处理
通过负载均衡、批量推理(合并多个用户请求为一批处理),提升模型并发能力。
3. 持续监控和更新
- 线上监控:监控模型响应延迟、准确率、错误率。
- 模型更新:定期用新数据(如最新行业知识)微调模型,避免知识过时。