AI 大模型的训练过程

一、数据准备 - 模型的原材料奠基

数据准备的核心目标是构建干净、全面、合规的训练数据集,主要包含数据采集、数据清洗、数据预处理三个核心环节。

1. 数据采集

(1)数据来源

数据来源分为三类,覆盖文本与多模态场景:

  • 公开文本库:学术论文、百科全书、新闻资讯、书籍、论坛、博客等
  • 多模态数据:图片、视频、音频、表格等
  • 私有数据:企业内部文档、行业数据集等

(2)低质来源过滤

需排除垃圾邮件、重复内容、极端言论、错误信息(如伪科学、谣言),避免模型学习偏见。

2. 数据清洗

核心目的是剔除杂质、统一格式,解决原始数据的噪声、冗余、偏见问题:

(1)格式标准化

  • 文本:统一为 UTF-8 编码
  • 图片:统一分辨率、格式(jpg/png)
  • 音频:统一采样率

(2)噪声过滤

  • 用正则表达式删除特殊符号、乱码
  • 通过语法检查工具(如 Grammarly API)修正错别字

(3)去重与去冗余

  • 用哈希算法识别重复文本
  • 合并相似内容

(4)偏见修正

通过人工标注或算法检测(如关键词匹配)删除/平衡含性别、地域歧视等偏见的内容。

3. 数据预处理

核心是将原始数据转换为模型可理解的格式,并完成数据集划分:

(1)模型可理解格式转换

数据类型 转换方式
文本数据 先分词拆分为最小语义单元(token),再通过嵌入层将 token 映射为高维向量
图像数据 将像素值归一化(如 0-255 缩放到 0-1),通过卷积层提取特征向量

(2)数据集划分

数据集类型 占比 作用
训练集 80%-90% 用于模型学习核心规律
验证集 5%-10% 用于调整模型参数
测试集 5%-10% 用于最终性能评估(不参与训练)

二、模型架构设计 - 模型的骨架搭建

模型架构决定数据的流动方式和学习逻辑,是大模型能力的核心载体,主要包含核心网络架构选择、超参数设定、辅助模块设计三部分。

1. 核心网络架构选择

(1)Transformer 架构(主流)

Transformer 是 2017 年 Google 团队在论文《Attention Is All You Need》中提出的神经网络架构,是当前大语言模型(如 GPT、BERT)、多模态模型(如 DALL·E)的核心基础。其核心优势是基于自注意力机制实现并行计算,突破了传统 RNN/LSTM 的序列依赖限制。

① 核心优势
优势项 具体说明
并行计算能力 RNN 需按序列顺序处理,Transformer 可同时处理所有 token,训练效率提升数十倍
长距离依赖捕捉 RNN 易因梯度消失导致长文本(1000 词以上)依赖捕捉能力弱;自注意力机制可直接计算任意两个 token 的关联,无距离限制
灵活性与扩展性 适配 NLP 任务(翻译、生成、理解),也可扩展至多模态领域(图像、音频),仅需修改输入嵌入方式
② 自注意力机制

能捕捉数据中的长距离依赖关系,例如“小明今天去公园,他很开心”中,“他”与“小明”的关联。

③ 核心架构(编码器-解码器)

整体流程:

输入文本 ->> 编码器 ->> 上下文特征 ->> 解码器 ->> 输出文本
  • 编码器(Encoder):由 N 个相同的编码器层堆叠而成,将输入序列转化为含上下文信息的特征向量。

    每个编码器层包含两个核心子层,且子层均配备残差连接(解决梯度消失)和层归一化(加速收敛):

    • 多头自注意力层:让每个 token 关注序列中其他相关 token,捕捉上下文依赖;“多头”指将注意力拆分到多个并行子空间,捕捉不同维度关联(语义、语法),最后拼接增强表达。
    • 前馈神经网络:对自注意力层输出做非线性变换(两层线性映射 + ReLU 激活函数),为每个 token 添加局部特征(无跨 token 交互)。
  • 解码器(Decoder):由 N 个相同的解码器层堆叠而成,基于编码器输出的上下文特征生成目标序列。

    每个解码器层包含三个核心子层,同样配备残差连接和层归一化:

    • 掩码多头自注意力层:与编码器自注意力类似,添加掩码机制确保生成第 i 个 token 时,仅能看到前 i-1 个 token(避免泄露后续信息)。
    • 编码器-解码器自注意力层:让解码器的每个 token 关注编码器输出中相关的 token,建立输入与输出的关联。
    • 前馈神经网络:与编码器的前馈神经网络结构一致,处理局部特征。
④ 输入输出处理
环节 处理方式
输入(文本→特征向量) 1. 嵌入层:将每个 token 转化为固定维度向量(如 512 维),理解词的语义;
2. 位置编码:通过正弦/余弦函数为 token 添加位置信息,区分“我爱你”和“你爱我”的语义差异(弥补 Transformer 无时序结构的缺陷)
输出(特征向量→文本) 解码器最终输出经线性层(映射到词表维度)+ softmax 层(生成词概率分布),通过采样(如贪心算法)选择概率最高的词,逐步生成完整序列
⑤ 基于 Transformer 的经典模型
模型 核心架构 适用场景
BERT 仅使用编码器 文本理解任务(文本分类、问答)
GPT 仅使用解码器 文本生成任务(文本续写、对话)
T5 编码器+解码器 兼顾理解与生成任务

(2)RNN/LSTM

适用于序列数据,但长距离依赖捕捉能力弱,已逐渐被 Transformer 替代。

(3)卷积神经网络(CNN)

多用于图像、音频的局部特征提取。

2. 超参数设定

超参数是训练前人工设定的全局参数,直接影响模型复杂度和训练成本:

(1)模型规模相关参数

  • 层数(L)
  • 注意力头数(H)
  • 隐藏层维度(D)

(2)训练相关参数

参数 说明
批次大小(Batch Size) 一次训练的样本数量,需匹配 GPU 内存
学习率(Learning Rate) 控制参数更新幅度,通常从 0.001 开始衰减
训练轮次(Epochs) 数据集重复训练的次数,避免过拟合

3. 辅助模块设计

模块 作用
位置编码(Positional Encoding) 为 token 添加位置信息,区分语序导致的语义差异(如“我吃苹果”和“苹果吃我”),常用正弦/余弦函数或可学习向量实现
残差连接(Residual Connection) 将前一层输出直接叠加到后一层输入,解决深层模型的梯度消失问题
层归一化(Layer Normalization) 标准化每一层的输入数据分布,加速模型收敛

三、预训练 - 模型的通识教育

预训练是大模型的基础学习阶段,目标是让模型从海量数据中学习通用知识(语法、常识、逻辑),而非针对特定任务,是通用智能的核心来源。

1. 预训练任务设计

通过间接目标引导模型学习数据规律,核心任务如下:

任务类型 核心逻辑
掩码语言模型(MLM) 随机遮盖文本中的部分 token(如“小明 [MASK] 天去公园”),让模型预测被遮盖的 token,强制学习上下文语义
因果语言模型(CLM) 给定前文文本(如“小明今天去公园”),让模型预测下一个 token(如“他很”),模拟人类续写逻辑,适配生成任务
对比学习(Contrastive Learning) 对同一文本生成相似样本(同义词替换)和不相似样本(随机打乱),让模型区分差异,提升语义理解能力
多模态预训练任务 给定“图片 + 文本描述”对,让模型学习图文匹配(判断文本是否描述图片)或文本生成图片特征,实现跨模态理解

2. 算力支撑

预训练需海量算力,是大模型训练的核心门槛:

(1)硬件需求

通常使用 GPU 集群或 TPU(Google 定制芯片),单台 GPU 内存需 ≥ 40 GB。

(2)分布式训练策略

通过以下策略实现大规模模型训练:

  • 数据并行:将数据拆分到多 GPU,同步更新参数
  • 模型并行:将模型层拆分到多 GPU,解决单 GPU 内存不足问题
  • 流水线并行:多 GPU 按层分工,提升训练效率

3. 训练过程监控

通过多维度监控避免模型训练“走偏”:

(1)损失函数(Loss Function)

监控模型预测值与真实值的差异(如交叉熵损失),损失持续下降说明模型有效学习。

(2)梯度检查

  • 梯度爆炸:通过梯度裁剪限制梯度最大值
  • 梯度消失:通过残差连接缓解

(3)验证集性能

定期用验证集测试模型(如预测准确率),若验证集性能下降,可能出现过拟合,需及时干预。

① 过拟合定义

模型过度学习训练数据的特征,甚至将噪声、异常值也当作通用规律,导致训练集表现极好、测试集/新数据表现糟糕,失去泛化能力。

② 过拟合本质原因

模型复杂度与数据质量/数量不匹配,具体分为两类:

原因类型 具体表现
模型能力过剩 1. 模型结构过于复杂;
2. 参数数量过多(参数量远超数据量,模型强行记忆样本而非学习通用模式)
数据支撑不足 1. 数据量过少(样本无法覆盖所有场景);
2. 数据噪声多(错误标注、异常值);
3. 数据分布不均(训练集与测试集场景差异大)
③ 过拟合解决方法
优化维度 具体措施
模型层面 1. 简化模型结构(减少层数/神经元数、降低多项式次数);
2. 正则化(添加 L1/L2 正则,限制参数过大);
3. 早停(验证误差停止下降时停止训练);
4. 随机去除依赖(训练时随机关闭部分神经元);
5. 集成学习(训练多个简单模型,投票/平均结果)
数据层面 1. 增加数据量;
2. 数据增强(图片旋转、文字同义替换等);
3. 清洗数据(删除异常值、修正错误标注)

四、微调 - 模型的专业培训

预训练模型仅具备通用知识,微调目标是让模型适配医疗问答、法律文档分析等特定场景。

1. 两类主流微调方式

微调方式 核心逻辑 适用场景 缺点
全参数微调(Full Fine-Tuning) 冻结预训练模型部分底层参数(学习通用规律的层),微调上层参数;或直接微调全部参数 有足量任务特定数据(如 10 万+ 医疗问答样本),需深度适配任务 参数规模大,算力成本高
参数高效微调(PEFT) 仅微调少量任务相关参数,降低算力需求 任务数据有限(如 1 万+ 样本)或算力有限(单 GPU 微调) -

(1)PEFT 常见方法

  • LoRA(Low-Rank Adaptation):在注意力层插入低秩矩阵(如 12288×12288 拆分为 12288×64 + 64×12288),仅微调低秩矩阵参数(参数量减少 10-100 倍)。
  • Prefix Tuning:在输入文本前添加可学习的前缀 token,仅微调前缀参数,不修改预训练模型主体。

2. 微调数据与任务设计

  • 数据:需高质量任务特定数据(如医疗微调需“症状-诊断”配对数据),可通过人工标注或数据增强获取。
  • 任务:将具体需求转化为模型可理解的格式(如“输入:患者发烧、咳嗽,输出:可能为感冒,建议多喝热水”)。

五、评估与迭代 - 模型的质量验收

评估目标是全面检测模型性能上限和缺陷,迭代优化针对性解决问题。

1. 评估维度

(1)定量评估(指标衡量)

任务类型 核心指标 指标说明
文本生成 BLEU、ROUGE、CIDEr 衡量生成文本与参考文本的相似度
分类任务(情感分析等) 准确率、F1 分数 衡量分类正确性
问答任务 EM(精确匹配)、F1 分数 衡量回答与标准答案的匹配度
多模态任务 图文匹配准确率、生成图片质量 衡量跨模态理解与生成能力

(2)定性评估(人工检测)

  • 人工标注:邀请领域专家评估模型输出(如医疗问答准确性、法律建议合规性)。
  • 用户反馈:通过小规模内测,收集生成内容流畅度、逻辑连贯性等评价。
  • 安全检测:测试模型是否生成暴力、歧视、虚假信息等有害内容,通过 RLHF 等对齐技术修正。

2. 迭代优化

针对评估发现的缺陷,从多维度优化:

优化维度 具体措施
数据层面 某领域表现差时,补充该领域数据重新微调
模型层面 过拟合时,增加数据量或使用正则化(如 Dropout 随机丢弃部分神经元)
算法层面 生成内容重复时,优化采样策略(如 Top-P 采样替代贪心采样)

六、部署与优化 - 模型的落地应用

训练后的模型需部署到实际场景(APP、API 接口),核心目标是降低响应延迟、减少资源消耗,提升使用效率。

1. 模型压缩(减小模型体积)

压缩方式 核心逻辑
量化(Quantization) 将模型参数从 32 位浮点数(FP32)压缩为 16 位(FP16)、8 位(FP8/INT8)、4 位(INT4),大幅降低模型体积和推理延迟
剪枝(Pruning) 删除模型中冗余的神经元或连接(如对损失贡献小的注意力头),在不影响性能的前提下减少参数量
蒸馏(Knowledge Distillation) 用大模型(教师模型)的输出指导小模型(学生模型)训练,让小模型具备接近大模型的性能

2. 部署框架和环境

(1)框架选择

使用 TensorRT(NVIDIA)、ONNX Runtime(跨平台)、TensorFlow Lite(移动端)等框架优化推理速度。

(2)部署方式

部署方式 说明
云端部署 将模型部署到云服务器,通过 API 接口提供服务
端侧部署 将压缩后的模型部署到终端设备(手机、嵌入式设备),如手机端 AI 助手(无需联网响应)

(3)并发处理

通过负载均衡、批量推理(合并多个用户请求为一批处理),提升模型并发能力。

3. 持续监控和更新

  • 线上监控:监控模型响应延迟、准确率、错误率。
  • 模型更新:定期用新数据(如最新行业知识)微调模型,避免知识过时。