Minimind/README.md
2025-08-01 15:54:21 +08:00

7.2 KiB
Raw Blame History

MiniMind 预训练项目开发文档

项目概述

MiniMind 是一个基于 Transformer 架构的大语言模型预训练项目集成了先进的知识图谱技术和混合专家模型MOE架构。项目采用 PyTorch 实现,支持分布式训练和高效的内存管理。

核心架构

1. 主训练入口

train_pretrain_accelerate.py - 主训练脚本,包含完整的训练流程:

  • 内存监控系统: 实时监控系统内存和 GPU 内存使用情况
  • 分布式训练: 基于 Accelerate 和 DeepSpeed 的分布式训练支持
  • 知识库初始化: 从 JSON 数据文件初始化知识库,支持缓存机制
  • 训练循环: 包含梯度累积、学习率调度、损失计算等完整训练逻辑

2. 模型架构

model/model.py - 核心模型实现:

class MiniMindLM(PreTrainedModel):
    """主要的 Transformer 模型类"""
    - 标准 Transformer 架构decoder-only
    - RMSNorm 归一化层
    - 旋转位置编码RoPE
    - Flash Attention 支持
    - 知识库集成

model/LMConfig.py - 模型配置类:

class LMConfig(PretrainedConfig):
    """模型配置管理"""
    - 基础模型参数dim, n_layers, n_heads 
    - MOE 相关配置
    - 知识图谱配置
    - 数据库功能配置

3. 知识库系统

KnowledgeDataset 类(在 model/model.py 中):

  • 二维分解键空间: 使用 Product Key 方法优化大规模知识库检索
  • 智能选择策略: 动态调整知识库访问模式
  • 可训练参数: 键向量支持梯度更新
  • 缓存机制: 支持知识库预处理结果缓存

4. 数据处理

model/dataset.py - 数据集处理:

class PretrainDataset(Dataset):
    """预训练数据集类"""
    - JSONL 格式数据加载
    - 自动添加 BOS/EOS 标记
    - 序列填充和截断
    - 损失掩码生成

核心功能模块

1. 内存管理

项目实现了完善的内存监控系统:

def get_memory_usage():
    """获取系统内存使用情况"""
    
def get_cuda_memory_usage():
    """获取 GPU 内存使用情况"""
    
def log_memory_status():
    """记录详细的内存状态"""

2. 知识库初始化

知识库初始化流程:

  1. 数据加载: 从 JSON 文件加载句子数据
  2. 重要性排序: 根据 importance_score 对句子排序
  3. 分词处理: 使用 tokenizer 将句子转换为 token 序列
  4. 长度处理: 截断或填充到指定长度
  5. 缓存机制: 支持处理结果缓存以加速后续训练

3. 分布式训练配置

Accelerate 配置 (accelerate_config.yaml):

compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
mixed_precision: bf16
num_processes: 4
deepspeed_config:
  deepspeed_config_file: ds_config.json

DeepSpeed 配置 (ds_config.json):

{
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    }
  },
  "optimizer": {
    "type": "AdamW"
  },
  "scheduler": {
    "type": "WarmupLR"
  }
}

主要配置参数

模型配置

  • dim: 隐藏层维度(默认 512
  • n_layers: Transformer 层数(默认 8
  • n_heads: 注意力头数(默认 32
  • n_kv_heads: KV 注意力头数(默认 8
  • max_seq_len: 最大序列长度(默认 512
  • vocab_size: 词汇表大小(默认 6400

知识库配置

  • knowledge_num: 知识库条目数量(默认 1048576
  • knowledge_length: 每个知识条目的长度(默认 32
  • knowledge_dim: 知识向量维度(默认 128

训练配置

  • batch_size: 批次大小(默认 128
  • learning_rate: 学习率(默认 8e-5
  • accumulation_steps: 梯度累积步数(默认 16
  • warmup_iters: 预热迭代次数

数据格式

预训练数据格式

{"text": "这是一个训练样本的文本内容"}

知识库数据格式

[
  {
    "target": [
      {
        "sentence": "知识库中的句子内容",
        "importance_score": 0.95
      }
    ]
  }
]

工具脚本

数据预处理脚本

  • preprocessing/preprocess_pretrain.py: 预训练数据预处理
  • preprocessing/preprocess_trex.py: 三元组数据预处理
  • preprocessing/preprocess_combined_json.py: 组合数据预处理

模型工具

  • dataset_decoder.py: 解码模型中的知识库内容

运行脚本

  • run_file/experiment_*.sh: 各种实验配置的运行脚本

依赖管理

项目使用 pyproject.toml 管理依赖:

核心依赖

  • torch >= 2.7.1: 深度学习框架
  • transformers >= 4.52.4: Transformer 模型库
  • accelerate >= 1.7.0: 分布式训练
  • deepspeed >= 0.17.0: 深度学习优化
  • swanlab >= 0.6.4: 实验监控

开发工具

  • tokenizers >= 0.21.1: 高效分词
  • datasets >= 2.21.0: 数据集处理
  • numpy >= 1.26.4: 数值计算
  • pandas >= 2.0.0: 数据处理

内存优化策略

  1. 梯度累积: 通过累积梯度减少内存占用
  2. 混合精度训练: 使用 bf16 减少内存使用
  3. ZeRO 优化: DeepSpeed ZeRO Stage 2 优化器状态分片
  4. 知识库缓存: 预处理结果缓存避免重复计算
  5. 垃圾回收: 定期清理未使用的内存

监控和日志

SwanLab 集成

  • 训练损失监控
  • 学习率变化追踪
  • 内存使用情况记录
  • 训练速度统计

日志系统

  • 时间戳格式化输出
  • 多进程日志同步
  • 内存状态详细记录
  • 训练进度追踪

目录结构详解

.
├── train_pretrain_accelerate.py    # 主训练脚本
├── dataset_decoder.py              # 知识库解码工具
├── model/                          # 模型定义目录
│   ├── LMConfig.py                 # 模型配置类
│   ├── model.py                    # 主模型实现
│   ├── dataset.py                  # 数据集处理
│   ├── model_no_feed.py            # 无反馈模型变体
│   ├── model_original.py           # 原始模型变体
│   └── minimind_tokenizer/         # 分词器文件
├── preprocessing/                   # 数据预处理脚本
├── run_file/                       # 实验运行脚本
├── models/                         # 模型检查点存储
├── accelerate_config.yaml          # Accelerate 配置
├── ds_config.json                  # DeepSpeed 配置
├── pyproject.toml                  # 项目依赖配置
└── uv.lock                         # 依赖锁定文件

开发注意事项

  1. 模型变体: 项目包含多个模型变体,选择合适的模型类型
  2. 知识库大小: 根据可用内存调整知识库参数
  3. 分布式配置: 根据硬件配置调整并行参数
  4. 缓存管理: 合理使用缓存机制避免重复计算
  5. 内存监控: 关注内存使用情况,及时调整批次大小

扩展点

  1. 新模型架构: 通过继承 PreTrainedModel 实现新的模型变体
  2. 数据处理: 扩展 PretrainDataset 支持新的数据格式
  3. 知识库优化: 改进 KnowledgeDataset 的检索策略
  4. 训练策略: 在主训练循环中添加新的训练技巧
  5. 监控扩展: 集成更多监控指标和可视化工具