DynamicKV-LLM Pretrain v1.2.2: new datasets; switch to uv; fix memory leaks

This commit is contained in:
iomgaa 2025-06-25 20:27:28 +08:00
parent 770c34f0e3
commit d6617702a5
19 changed files with 6601 additions and 644 deletions

1
.gitignore vendored

@@ -9,3 +9,4 @@ models/sentence_transformers_cache/
qwen2-1.7B/
images/
cache/
.venv/

112
.vscode/launch.json vendored

@@ -2,101 +2,39 @@
"version": "0.2.0",
"configurations": [
{
"name": "Debug Train Pretrain Accelerate",
"name": "DynamicKV-LLM Mini Minimind Debug",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/train_pretrain_accelerate.py",
"console": "integratedTerminal",
"python": "/opt/conda/envs/mini/bin/python",
"cwd": "${workspaceFolder}",
"module": "accelerate.commands.launch",
"args": [
"--num_processes=1",
"--mixed_precision=bf16",
"--main_process_port=29500",
"train_pretrain_accelerate.py",
"--batch_size", "16",
"--knowledge_num", "48020",
"--num_workers", "1",
"--epochs", "4",
"--learning_rate", "2e-4",
"--dtype", "bfloat16",
"--accumulation_steps", "32",
"--grad_clip", "1.0",
"--log_interval", "50",
"--save_interval", "10000",
"--dim", "512",
"--n_layers", "8",
"--max_seq_len", "512",
"--use_flash_attn",
"--profile",
"--profile_interval", "10"
],
"env": {
"PYTHONPATH": "${workspaceFolder}",
"CUDA_VISIBLE_DEVICES": "0"
},
"justMyCode": false,
"stopOnEntry": false,
"redirectOutput": true
},
{
"name": "Debug Train Pretrain Accelerate (Multi-GPU)",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/train_pretrain_accelerate.py",
"console": "integratedTerminal",
"python": "/opt/conda/envs/mini/bin/python",
"args": [
"--hidden_size", "512",
"--max_seq_len", "512",
"--n_layers", "8",
"--batch_size", "8",
"--epochs", "1"
],
"cwd": "${workspaceFolder}",
"env": {
"PYTHONPATH": "${workspaceFolder}",
"CUDA_VISIBLE_DEVICES": "0,1"
},
"justMyCode": false,
"stopOnEntry": false,
"redirectOutput": true
},
{
"name": "Debug Train Pretrain Accelerate (Small Test)",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/train_pretrain_accelerate.py",
"console": "integratedTerminal",
"python": "/opt/conda/envs/mini/bin/python",
"args": [
"--hidden_size", "512",
"--max_seq_len", "512",
"--n_layers", "8",
"--batch_size", "2",
"--epochs", "1",
"--log_interval", "10",
"--save_interval", "100",
"--accumulation_steps", "4"
],
"cwd": "${workspaceFolder}",
"env": {
"PYTHONPATH": "${workspaceFolder}",
"CUDA_VISIBLE_DEVICES": "0",
"WANDB_MODE": "offline"
},
"justMyCode": false,
"stopOnEntry": false,
"redirectOutput": true
},
{
"name": "Debug ExtractDB Comparison",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/train_pretrain_accelerate.py",
"console": "integratedTerminal",
"python": "/opt/conda/envs/mini/bin/python",
"args": [
"--hidden_size", "512",
"--max_seq_len", "256",
"--n_layers", "4",
"--batch_size", "2",
"--epochs", "1",
"--log_interval", "10",
"--save_interval", "200",
"--accumulation_steps", "2",
"--comparison_mode",
"--knowledge_num", "256",
"--knowledge_length", "64",
"--comparison_mode"
],
"cwd": "${workspaceFolder}",
"env": {
"PYTHONPATH": "${workspaceFolder}",
"CUDA_VISIBLE_DEVICES": "0",
"WANDB_MODE": "offline"
},
"justMyCode": false,
"stopOnEntry": false,
"redirectOutput": true
"stopOnEntry": false
}
]
}

199
README.md Normal file

@@ -0,0 +1,199 @@
<div align="center">
![logo](./images/logo.png)
</div>
<div align="center">
![visitors](https://visitor-badge.laobi.icu/badge?page_id=jingyaogong/minimind)
[![GitHub Repo stars](https://img.shields.io/github/stars/jingyaogong/minimind?style=social)](https://github.com/jingyaogong/minimind/stargazers)
[![GitHub Code License](https://img.shields.io/github/license/jingyaogong/minimind)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/jingyaogong/minimind)](https://github.com/jingyaogong/minimind/commits/master)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/jingyaogong/minimind/pulls)
[![Collection](https://img.shields.io/badge/🤗-MiniMind%20%20Collection-blue)](https://huggingface.co/collections/jingyaogong/minimind-66caf8d999f5c7fa64f399e5)
</div>
# 📌 Data Overview
## Ⅰ Tokenizer
The tokenizer maps words from natural language to numbers such as `0, 1, 36` via a "dictionary"; you can think of each number as the page on which that word appears in the dictionary.
You can build your own vocabulary and train a tokenizer with `./scripts/train_tokenizer.py` (for learning purposes only; unless strictly necessary there is no need to retrain one, since MiniMind ships with its own tokenizer).
Alternatively, you can adopt the tokenizer of a well-known open-source model.
This is like using the Xinhua or Oxford dictionary directly: the token compression ratio is excellent, but the "dictionary" is huge, often hundreds of thousands of words and phrases.
A self-trained tokenizer lets you control the vocabulary size and content freely, at the cost of a much lower compression ratio (for example, "hello" might be split into five separate tokens,
"h e l l o"), and rare words may be poorly covered.
The choice of "dictionary" certainly matters: an LLM's output is essentially an N-way softmax classification over the N words of the dictionary, decoded back to natural language through that same dictionary.
Because MiniMind's size must be strictly controlled, and to keep the model from becoming top-heavy (the embedding layer taking an outsized share of the parameters), the shorter the vocabulary, the better.
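For a quick feel of this mapping, a minimal sketch of round-tripping text through the bundled tokenizer (the `./model/minimind_tokenizer` path follows this repo's layout, as used by the preprocessing scripts):
```python
from transformers import AutoTokenizer

# Load the bundled MiniMind tokenizer (vocab size 6,400)
tokenizer = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")

ids = tokenizer.encode("hello", add_special_tokens=False)
print(ids)                    # token ids: the "page numbers" in the dictionary
print(tokenizer.decode(ids))  # decode the ids back to text
```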
<details style="color:rgb(128,128,128)">
<summary>Tokenizer details</summary>
The vocabulary sizes of the tokenizers of strong third-party open-source models (e.g. Yi, qwen, chatglm, mistral, Llama3) are:
<table>
<tr><th>Tokenizer</th><th>Vocab size</th><th>Source</th></tr>
<tr><td>yi tokenizer</td><td>64,000</td><td>01.AI (China)</td></tr>
<tr><td>qwen2 tokenizer</td><td>151,643</td><td>Alibaba Cloud (China)</td></tr>
<tr><td>glm tokenizer</td><td>151,329</td><td>Zhipu AI (China)</td></tr>
<tr><td>mistral tokenizer</td><td>32,000</td><td>Mistral AI (France)</td></tr>
<tr><td>llama3 tokenizer</td><td>128,000</td><td>Meta (USA)</td></tr>
<tr><td>minimind tokenizer</td><td>6,400</td><td>Custom</td></tr>
</table>
> 👉 Update 2024-09-17: to avoid ambiguity with earlier versions and to control model size, every MiniMind model now uses minimind_tokenizer; all mistral_tokenizer variants are deprecated.
```
# Some musings
> Although minimind_tokenizer has a very small vocabulary and its encode/decode efficiency is weaker than Chinese-friendly tokenizers such as qwen2 or glm,
> MiniMind uses its self-trained minimind_tokenizer to keep the overall parameter count light and to avoid an imbalance between the embedding and compute layers (a top-heavy model), since MiniMind's vocabulary size is only 6,400.
> In practical tests MiniMind has never failed to decode a rare word; it performs well.
> Because the custom vocabulary is compressed to 6,400 entries, the LLM's total parameter count bottoms out at just 25.8M.
> The training data `tokenizer_train.jsonl` comes entirely from the `Jiangshu large-model dataset`; this data is relatively minor, and you may freely choose another corpus if you want to retrain.
```
</details>
## Ⅱ Pretrain Data
After the lesson of MiniMind-V1, whose low-quality pretraining data left the model babbling, on `2025-02-05` we decided to stop using large-scale unsupervised datasets for pretraining.
Instead, the Chinese portion of the [Jiangshu large-model dataset](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data) was extracted,
and roughly 1.6GB of text with character length `<512` was cleaned and concatenated directly into the pretraining file `pretrain_hq.jsonl` ("hq" for high
quality, though not truly high yet; improving data quality is a never-ending task).
The file `pretrain_hq.jsonl` has the format:
```bash
{"text": "如何才能摆脱拖延症? 治愈拖延症并不容易,但以下建议可能有所帮助..."}
```
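For reference, a minimal sketch of streaming this file (standard library only; the path matches the download layout below):
```python
import json

# Each line of pretrain_hq.jsonl is one JSON object with a "text" field
with open("./dataset/pretrain_hq.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        text = json.loads(line)["text"]
        # ... tokenize `text` and feed it to the pretraining pipeline
```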
## Ⅲ SFT Data
The [Jiangshu large-model SFT dataset](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
"is a complete, uniformly formatted, and safe resource for large-model training and research.
It collects and organizes a large number of open datasets from public sources, with unified formatting and data cleaning,
and contains a Chinese dataset with 10M samples and an English dataset with 2M samples."
That is the official description. The download totals roughly 4B tokens, which certainly makes it suitable SFT data for a Chinese LLM.
However, the official data format is quite messy, and using all of it for SFT would cost too much.
I ran a second cleaning pass over the official dataset, dropping entries polluted with symbols or noise; as before, only content with a total length `<512`
was kept, the goal at this stage being to supplement knowledge missing from pretraining with a large amount of dialogue.
The export is `sft_512.jsonl` (~7.5GB).
The [Magpie-SFT dataset](https://www.modelscope.cn/organization/Magpie-Align)
collects ~1M high-quality conversations from Qwen2/2.5; I cleaned this data further and exported the part with total length `<2048` as `sft_2048.jsonl` (~9GB).
The part with length `<1024` is exported as `sft_1024.jsonl` (~5.5GB); doing SFT directly on conversations from a large model is a form of "black-box distillation".
A further cleaning pass over the two SFT exports above, keeping only content with a high share of Chinese characters and filtering dialogues of length `<512`, yields `sft_mini_512.jsonl` (~1.2GB).
All SFT files `sft_X.jsonl` share the format:
```text
{
"conversations": [
{"role": "user", "content": "你好"},
{"role": "assistant", "content": "你好!"},
{"role": "user", "content": "再见"},
{"role": "assistant", "content": "再见!"}
]
}
```
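A minimal sketch of walking these records and pairing each user turn with the assistant reply that follows it (field names exactly as above; the helper name is illustrative):
```python
import json

def iter_pairs(path):
    """Yield (user, assistant) turn pairs from an sft_X.jsonl file."""
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            turns = json.loads(line)["conversations"]
            # Conversations alternate strictly: user, assistant, user, assistant, ...
            for u, a in zip(turns[::2], turns[1::2]):
                assert u["role"] == "user" and a["role"] == "assistant"
                yield u["content"], a["content"]
```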
## Ⅳ RLHF Data
From the [Magpie-DPO dataset](https://www.modelscope.cn/datasets/Magpie-Align/MagpieLM-DPO-Data-v0.1):
roughly 200k preference records, all in English, generated by Llama3.1-70B/8B; they can be used to train a reward model and steer responses toward human preferences.
Here, records with total length `<3000` are repacked into `dpo.jsonl` (~0.9GB), with two fields, `chosen` and `rejected`: `chosen`
is the preferred reply and `rejected` the rejected one.
The file `dpo.jsonl` has the format:
```text
{
"chosen": [
{"content": "Q", "role": "user"},
{"content": "good answer", "role": "assistant"}
],
"rejected": [
{"content": "Q", "role": "user"},
{"content": "bad answer", "role": "assistant"}
]
}
```
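And a matching sketch for loading the preference pairs (structure exactly as above):
```python
import json

with open("./dataset/dpo.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        pair = json.loads(line)
        prompt = pair["chosen"][0]["content"]  # both branches share the user prompt
        good = pair["chosen"][-1]["content"]   # preferred reply
        bad = pair["rejected"][-1]["content"]  # rejected reply
```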
## Ⅴ Reasoning Data
Let's face it: in February 2025 nothing was hotter than DeepSeek...
It also sparked my strong interest in RL-guided reasoning models; I have already reproduced R1-Zero with Qwen2.5.
If time allows and it works (99% likely the base model is simply too weak), I will later release a MiniMind reasoning model trained with RL rather than by distillation.
Time being short, the fastest low-cost route is still direct (black-box) distillation.
R1 is so popular that within days several R1 distillation datasets appeared: [R1-Llama-70B](https://www.modelscope.cn/datasets/Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B), [R1-Distill-SFT](https://www.modelscope.cn/datasets/AI-ModelScope/R1-Distill-SFT),
[Alpaca-Distill-R1](https://huggingface.co/datasets/shareAI/Alpaca-Distill-R1-ZH),
[deepseek_r1_zh](https://huggingface.co/datasets/jinliuxi/deepseek_r1_zh), and so on; purely Chinese data remains relatively scarce.
They are merged and exported as `r1_mix_1024.jsonl`, in the same format as `sft_X.jsonl`.
## Ⅵ More Datasets
[HqWu-HITCS/Awesome-Chinese-LLM](https://github.com/HqWu-HITCS/Awesome-Chinese-LLM)
is already collecting and curating open-source models, applications, datasets, and tutorials for Chinese LLMs, and keeps tracking the latest progress. Comprehensive and professional; respect!
---
## Ⅶ Dataset Download
> [!NOTE]
> Since 2025-02-05 all datasets used for MiniMind's final training are open-sourced, so there is no need to preprocess large-scale datasets yourself or to repeat that work.
MiniMind training datasets ([ModelScope](https://www.modelscope.cn/datasets/gongjy/minimind-dataset/files) | [HuggingFace](https://huggingface.co/datasets/jingyaogong))
> No need to clone everything; download only the files you need
Place the downloaded dataset files under `./dataset/` (✨ marks the recommended essentials):
```bash
./dataset/
├── dpo.jsonl (909MB)
├── lora_identity.jsonl (22.8KB)
├── lora_medical.jsonl (34MB)
├── pretrain_hq.jsonl (1.6GB, ✨)
├── r1_mix_1024.jsonl (340MB)
├── sft_1024.jsonl (5.6GB)
├── sft_2048.jsonl (9GB)
├── sft_512.jsonl (7.5GB)
├── sft_mini_512.jsonl (1.2GB, ✨)
└── tokenizer_train.jsonl (1GB)
```
<details style="color:rgb(128,128,128)">
<summary>Note: dataset overviews</summary>
* `dpo.jsonl` -- RLHF-stage dataset
* `lora_identity.jsonl` -- self-identity dataset (e.g. "Who are you?" "I am minimind..."), recommended for LoRA training (also usable for full-parameter SFT; don't be boxed in by the name)
* `lora_medical.jsonl` -- medical Q&A dataset, recommended for LoRA training (also usable for full-parameter SFT; don't be boxed in by the name)
* `pretrain_hq.jsonl`✨ -- pretraining dataset, compiled from Jiangshu Tech
* `r1_mix_1024.jsonl` -- DeepSeek-R1-1.5B distillation data; each sample is at most 1024 characters, so set max_seq_len=1024 for training
* `sft_1024.jsonl` -- compiled from Qwen2.5 distillation data (a subset of sft_2048); each sample is at most 1024 characters, so set max_seq_len=1024
* `sft_2048.jsonl` -- compiled from Qwen2.5 distillation data; each sample is at most 2048 characters, so set max_seq_len=2048
* `sft_512.jsonl` -- compiled from the Jiangshu SFT data; each sample is at most 512 characters, so set max_seq_len=512
* `sft_mini_512.jsonl`✨ -- a minimal mix of the Jiangshu SFT data and the Qwen2.5 distillation data (for quickly training a Zero model); each sample is at most 512 characters, so set max_seq_len=512
* `tokenizer_train.jsonl` -- entirely from the `Jiangshu large-model dataset`; this data is relatively minor, and retraining the tokenizer is not recommended (see the reasons above); if you do want to train your own tokenizer, feel free to choose any corpus
</details>
![dataset](./images/dataset.jpg)
<details style="color:rgb(128,128,128)">
<summary>Notes & recommended training plans</summary>
* The MiniMind2 series was trained on roughly 20GB of corpus in total, about 4B tokens, i.e. exactly the data mix above (cost: 💰💰💰💰💰💰💰💰, result: 😊😊😊😊😊😊).
* For the fastest path to a Zero model from scratch, use the `pretrain_hq.jsonl` + `sft_mini_512.jsonl` combination; see the table below for costs and results (cost: 💰, result: 😊😊).
* If you have some compute, or care more about the end result, the former (a full MiniMind2 reproduction) is worth considering; if you only have a single GPU or want a quick reproduction, the latter is strongly recommended!
* [Middle ground] A free mix of medium-sized files such as `sft_mini_512.jsonl` and `sft_1024.jsonl` also works (cost: 💰💰💰, result: 😊😊😊😊).
</details>


@@ -1,126 +0,0 @@
# Distributed training with Accelerate + DeepSpeed
This document describes how to run distributed training of the MiniMind model with Accelerate and DeepSpeed.
## Environment
First, make sure the required dependencies are installed:
```bash
pip install accelerate deepspeed
```
## Configuration files
### 1. DeepSpeed configuration (ds_config.json)
The DeepSpeed config defines the optimizer, the LR scheduler, and the ZeRO settings. The main items are:
- **ZeRO optimization**: ZeRO-2 is used, which reduces GPU memory usage
- **Optimizer**: AdamW
- **Mixed precision**: both FP16 and BF16 are supported
- **Gradient accumulation**: set to "auto" so it stays consistent with the training-script argument (see the sketch after this list)
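For orientation, a hedged sketch of what such a ds_config.json might contain, generated from Python; the values here are illustrative assumptions, not a copy of the repo's actual file:
```python
import json

# Illustrative ZeRO-2 config matching the bullets above; "auto" values are
# resolved at launch time by Accelerate / the training script.
ds_config = {
    "zero_optimization": {"stage": 2},
    "optimizer": {"type": "AdamW", "params": {"lr": "auto"}},
    "fp16": {"enabled": "auto"},
    "bf16": {"enabled": "auto"},
    "gradient_accumulation_steps": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}

with open("ds_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)
```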
### 2. Accelerate configuration (accelerate_config.yaml)
The Accelerate config defines the basic distributed setup, including:
- **Distributed type**: DeepSpeed
- **Mixed precision**: BF16
- **Number of processes**: 4 (adjust to your GPU count)
- **DeepSpeed config**: points to the ds_config.json file
## Training script
The new training script `train_pretrain_accelerate.py` is adapted from the original `train_pretrain.py`; the main changes are:
1. Uses Accelerator instead of PyTorch's native distributed facilities
2. Removes the torchrun-specific distributed initialization code
3. Prepares the model, optimizer, and dataloaders through the Accelerator API
4. Runs backpropagation and gradient clipping through the Accelerator API
5. Handles the positional-encoding and unused-parameter issues
(a skeleton of this pattern is sketched below)
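A minimal, hedged skeleton of that pattern; `model`, `train_loader`, and `args` stand in for the script's actual objects, and `compute_loss` is a placeholder:
```python
from accelerate import Accelerator

accelerator = Accelerator()  # picks up mixed precision / DeepSpeed from the launch config

model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

for step, (X, Y) in enumerate(train_loader):
    loss = compute_loss(model, X, Y) / args.accumulation_steps
    accelerator.backward(loss)  # replaces loss.backward()
    if (step + 1) % args.accumulation_steps == 0:
        accelerator.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()
        optimizer.zero_grad()
```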
## Launching training
There are two ways to launch training:
### Method 1: use a pre-made accelerate config file
```bash
accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
--epochs 3 \
--batch_size 24 \
--learning_rate 2e-4 \
--dtype bfloat16 \
--accumulation_steps 32 \
--grad_clip 1.0 \
--log_interval 100 \
--save_interval 10000 \
--dim 1024 \
--n_layers 32 \
--max_seq_len 1024 \
--use_flash_attn \
--profile \
--profile_interval 10
```
### Method 2: configure accelerate directly via command-line arguments
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--multi_gpu \
--num_processes=4 \
--mixed_precision=bf16 \
--main_process_port=29500 \
--deepspeed_config_file ds_config.json \
train_pretrain_accelerate.py \
--epochs 3 \
--batch_size 24 \
--learning_rate 2e-4 \
--dtype bfloat16 \
--accumulation_steps 32 \
--grad_clip 1.0 \
--log_interval 100 \
--save_interval 10000 \
--dim 1024 \
--n_layers 32 \
--max_seq_len 1024 \
--use_flash_attn \
--profile \
--profile_interval 10
```
You can also use the provided script directly:
```bash
bash run_accelerate.sh
```
## How the Accelerate and DeepSpeed configs relate
1. **Accelerate** is a high-level API that simplifies setting up and launching distributed training; it works with several backends (DeepSpeed, FSDP, and others).
2. **DeepSpeed** is an optimization library focused on memory savings and performance for large-scale model training, providing features such as ZeRO.
3. **How the configs fit together**:
   - The Accelerate config file (YAML) decides which distributed backend is used and the basic distributed settings
   - The DeepSpeed config file (JSON) holds the DeepSpeed-specific optimization parameters
   - Accelerate references the DeepSpeed config through the `deepspeed_config_file` parameter
## Notes
1. **Positional encoding**:
   - `pos_cis` in the model is a complex-valued tensor and needs special handling in distributed training
   - The new training script handles this through the Accelerator API; `_ddp_params_and_buffers_to_ignore` is no longer needed
2. **Unused parameters**:
   - The old code used `find_unused_parameters=True` to handle unused parameters
   - The new script uses the Accelerator API directly, which takes care of this automatically
3. **Mixed precision**:
   - `fp16` and `bf16` in the DeepSpeed config are set to `"auto"`
   - The precision actually used is decided by Accelerate's `--mixed_precision` argument
4. **Gradient accumulation**:
   - `gradient_accumulation_steps` in the DeepSpeed config is set to `"auto"`
   - The actual number of accumulation steps comes from the training script's `--accumulation_steps` argument


@@ -1,22 +0,0 @@
## Environment setup
1. Create a conda environment
```bash
conda create -n accelerate python=3.10
conda activate accelerate
```
2. Install torch, torchvision, and torchaudio matching your system's CUDA version
3. Install accelerate and deepspeed matching your torch/torchvision versions
4. Install the remaining packages
```bash
pip install -r requirements.txt
```
## Modifying the model
1. Normally only the files in the `model` folder need changes
## Running
1. On a 4090 or 4070 Ti, run `bash run_file/DynamicKV-LLM_Mini_Minimind.sh`
2. On 4x A800, run `bash run_file/DynamicKV-LLM_Small_Minimind.sh`

26
experiment.yaml Normal file

@@ -0,0 +1,26 @@
# 1. Metadata: edit this; give the experiment a name and a description
name: ycz-minimind-test
description: test run for minimind-test
# 2. Runtime environment: usually left unchanged; swap in a specific image if needed
environment:
  image: determinedai/pytorch-ngc:0.38.0  # no change needed
# 3. Dataset on the NAS: edit this; change only the bind_mounts field (container_path and read_only stay as-is)
#    Replace <YOUR_DATASET_FOLDER_NAME> with the name of your dataset folder under Volume1/Share/datasets/ on the NAS
#    Double-check that the <YOUR_DATASET_FOLDER_NAME> dataset actually exists under Volume1/Share/datasets/ on the NAS
# 4. Compute resources: no change needed
resources:
  slots_per_trial: 1  # no change needed
  resource_pool: rtx4090  # no change needed
# 5. Searcher: no change needed
searcher:
  name: single
  metric: test_accuracy
  smaller_is_better: false
# 6. Entrypoint: no change needed
entrypoint: sh startup.sh

6
main.py Normal file

@@ -0,0 +1,6 @@
def main():
print("Hello from minimind!")
if __name__ == "__main__":
main()


@@ -2,7 +2,8 @@ import math
import struct
import inspect
import time
import gc
# Two-dimensional subspace decomposition + gradient-based key updates
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
@@ -67,23 +68,21 @@ class KnowledgeDataset(nn.Module):
        ## Knowledge-base parameters
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
self.keys = nn.Parameter(torch.randn(self.knowledge_num, self.knowledge_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.knowledge_num)
        # Usage statistics - stored via register_buffer so they move correctly between GPU/CPU
self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))
        # Store the keys in a 2-D decomposed space, as trainable parameters
self.num_keys = int(math.sqrt(self.knowledge_num))
        # Make sure the keys are trainable parameters
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.num_keys)
        # Knowledge storage - a register_buffer, since these are integer indices that need no gradients
self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
)
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
        # Step counter, used to adjust weights dynamically
self.step_counter = 0
self.freeze_embedding = False
        # Batch counter and update-frequency logic removed
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
@@ -94,6 +93,15 @@
device = all_scores.device
dtype = all_scores.dtype
        # Track memory state before entering intelligent selection
if hasattr(self, 'step_counter'):
self.step_counter += 1
        # GPU memory logging disabled for performance
        # if self.step_counter % 50 == 0:  # log once every 50 calls
        #     if torch.cuda.is_available():
        #         allocated_before = torch.cuda.memory_allocated() / (1024**3)
        #         print(f"[INTEL_SELECT_ENTER] Step {self.step_counter}: GPU Memory: {allocated_before:.2f}GB")
        # Hierarchical selection for each batch
enhanced_scores = all_scores.clone()
query_features = query.mean(dim=1) # [batch_size, dim]
@@ -106,7 +114,8 @@
candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
            # Get the indices corresponding to flat_tokens
            # Get the indices corresponding to flat_tokens (variables kept for use elsewhere)
pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
@@ -158,84 +167,63 @@
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
        # Get
        # Free the intermediate tensors to prevent memory leaks
del all_candidate_indices, unique_indices, inverse_indices
del unique_candidate_features, normalized_candidates, normalized_queries
del batch_best_tokens, batch_best_tokens_embeddings
del flat_tokens, flat_embeddings, pre_update_embeddings
        # Update self.keys with the recomputed embeddings
if self.is_train:
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
        # Track memory state after intelligent selection (disabled for performance)
        # if hasattr(self, 'step_counter') and self.step_counter % 50 == 0:
        #     if torch.cuda.is_available():
        #         allocated_after = torch.cuda.memory_allocated() / (1024**3)
        #         print(f"[INTEL_SELECT_EXIT] Step {self.step_counter}: GPU Memory: {allocated_after:.2f}GB")
        # Mark the keys that were modified
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
        # Force garbage collection (only on monitored steps)
if hasattr(self, 'step_counter') and self.step_counter % 100 == 0:
gc.collect()
# if torch.cuda.is_available():
# torch.cuda.empty_cache()
return all_best_tokens, all_best_tokens_embeddings
def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
if self.freeze_embedding:
return
        # Update self.keys using pre_update_embeddings
with torch.no_grad():
            pre_update_embeddings = pre_update_embeddings.mean(dim=1)  # shape e.g. [337, 512]
pre_update_embeddings = self.to_queries(pre_update_embeddings)
self.keys[pre_update_indices] = pre_update_embeddings
def search_index(self, x):
batch_size, seq_len, dim = x.shape
# collapse sequence dimension by averaging
        # 1. Average over the sequence dimension
x_flat = x.mean(dim=1) # [batch_size, dim]
queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
# queries = queries.reshape(batch_size, 2, self.key_dim)
# queries = queries.permute(1, 0, 2)
        # 2. Build the query vector and reshape it into two sub-queries
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim]
        # Reorder dimensions so the subspace dimension comes first
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
        # 2. Compute similarity between queries and keys
sim = torch.einsum('b d, k d -> b k', queries, self.keys)
        # 3. Compute similarity within each subspace
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
        # 3. Top-k separately in each of the two subspaces
scores_and_indices = sim.topk(self.product_key_topk, dim=-1)
scores, indices = scores_and_indices[0], scores_and_indices[1]
        # 4. Top-k separately in each of the two subspaces
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
        # 5. Apply the intelligent hierarchical selection strategy
        # 5. Combine the results of the two subspaces
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
        # 6. Reshape the result to 2-D
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
        # 7. Select the final top-k
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
indices = torch.gather(all_indices, 1, indices_of_indices)
        # 8. Apply the intelligent hierarchical selection strategy
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
        # 6. Update 1% of the keys
if self.is_train:
            # Indices of keys that have not been updated yet
not_updated_indices = torch.where(self.has_update_keys == 0)[0]
            # If some keys are still un-updated, randomly pick num_update_keys of them to update
if len(not_updated_indices) > 0:
num_update_keys = int(self.knowledge_num * 0.01)
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
perm_num = perm.shape[0]
pre_update_indices = not_updated_indices[perm]
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
pre_update_embeddings = pre_update_embeddings.view(perm_num, self.knowledge_length, -1)
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
                # Mark the keys that were modified
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
else:
print("all keys are updated")
                # Reset the update status of all keys
self.has_update_keys.zero_()
                # Re-collect all updatable indices
not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)
num_update_keys = int(self.knowledge_num * 0.01)
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
pre_update_indices = not_updated_indices[perm]
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
pre_update_embeddings = pre_update_embeddings.view(num_update_keys, self.knowledge_length, -1)
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
                # Mark the keys that were modified
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
return best_tokens, best_tokens_embeddings
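        # ------------------------------------------------------------------
        # For intuition, a hedged standalone sketch of the product-key lookup
        # implemented above (dimensions illustrative; assumes knowledge_num has
        # an integer square root so that num_keys = sqrt(knowledge_num)):
        #
        #   batch, key_dim, num_keys, topk = 4, 64, 219, 16
        #   queries = torch.randn(2, batch, key_dim)      # two sub-queries
        #   keys    = torch.randn(num_keys, 2, key_dim)   # one table per subspace
        #   sim = torch.einsum('p b d, k p d -> p b k', queries, keys)
        #   (sx, ix), (sy, iy) = sim[0].topk(topk, -1), sim[1].topk(topk, -1)
        #   # a candidate's score is the sum of its two subspace scores, and its
        #   # flat index is x * num_keys + y: topk*topk combinations to re-rank
        #   all_scores  = (sx.unsqueeze(-1) + sy.unsqueeze(-2)).reshape(batch, -1)
        #   all_indices = (ix.unsqueeze(-1) * num_keys + iy.unsqueeze(-2)).reshape(batch, -1)
        #   scores, pos = all_scores.topk(topk, -1)
        #   indices = torch.gather(all_indices, 1, pos)
        # ------------------------------------------------------------------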
@@ -257,6 +245,16 @@
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
        # Track memory at cross-attention entry (disabled for performance)
if not hasattr(self, 'call_counter'):
self.call_counter = 0
self.call_counter += 1
        # GPU memory logging disabled for performance
        # if self.call_counter % 100 == 0 and torch.cuda.is_available():
        #     allocated_before = torch.cuda.memory_allocated() / (1024**3)
        #     print(f"[CROSS_ATTN_ENTER] Call {self.call_counter}: GPU Memory: {allocated_before:.2f}GB")
        # Split into multiple heads
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
@@ -282,6 +280,14 @@
context = self.to_out(context)
        # Free the intermediate tensors
del q, k, v, attn_scores, attn_weights
        # Track memory at cross-attention exit (disabled for performance)
        # if self.call_counter % 100 == 0 and torch.cuda.is_available():
        #     allocated_after = torch.cuda.memory_allocated() / (1024**3)
        #     print(f"[CROSS_ATTN_EXIT] Call {self.call_counter}: GPU Memory: {allocated_after:.2f}GB")
return context
class Attention(nn.Module):
@@ -520,12 +526,11 @@
step: int = 0,
**args):
start_pos = args.get('start_pos', 0)
if self.freeze_embedding and step == 0:
self.tok_embeddings.weight.requires_grad = False
            # Also freeze embedding updates in KnowledgeDataset
self.knowledge_dataset.freeze_embedding = True
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding)
# if self.freeze_embedding and step == 0:
# self.tok_embeddings.weight.requires_grad = False
        # # Setting knowledge_dataset.freeze_embedding removed; key updates are driven by batch_counter instead
# # self.knowledge_dataset.freeze_embedding = True
# print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
for l, layer in enumerate(self.layers):
@@ -601,3 +606,4 @@
yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id:
break


@@ -1,154 +0,0 @@
# TREx dataset processing tool
This tool processes the TREx dataset in two steps:
1. **Sentence extraction**: extract triples from the TREx dataset and convert them into natural-language sentences
2. **LLM processing**: use the ollama qwen3:4b model to correct sentences and score their importance
## 🆕 Anti-hang mechanisms
To address possible hangs during LLM processing, the following features were added:
### Timeout and retry
- **Timeout**: each LLM request times out after 60 seconds
- **Retry**: at most 2 retries on failure, with exponential backoff
- **Concurrency control**: concurrency lowered to 4 to reduce server load (see the sketch after this list)
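A hedged sketch of the timeout-plus-backoff pattern described above; the endpoint and payload follow ollama's standard `/api/generate` HTTP API, and the numbers mirror the bullets:
```python
import time
import requests

def call_llm(prompt, timeout=60, max_retries=2):
    """POST a prompt to the local ollama server, retrying with exponential backoff."""
    for attempt in range(max_retries + 1):
        try:
            r = requests.post(
                "http://localhost:11434/api/generate",
                json={"model": "qwen3:4b", "prompt": prompt, "stream": False},
                timeout=timeout,
            )
            r.raise_for_status()
            return r.json()["response"]
        except requests.RequestException:
            if attempt == max_retries:
                return None           # caller falls back to the original sentence
            time.sleep(2 ** attempt)  # exponential backoff: 1s, 2s, ...
```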
### Heartbeat monitoring
- **Live monitoring**: checks the LLM response status every 30 seconds
- **Warnings**: raises a warning after more than 30 seconds without a successful response
- **Service checks**: automatically checks the ollama service status
- **Detailed stats**: live success rate, timeout rate, and other statistics
### Logging
- **Detailed logs**: every operation is recorded under the `logs/` directory
- **Dual output**: logs go to both the log file and the console
- **Timestamps**: log file names include the start timestamp
### Improved error handling
- **Recovery**: when LLM processing fails, the original sentence and a default score are used
- **Status monitoring**: the ollama service status is checked before processing
- **Inter-batch rest**: a 5-second pause between batches to avoid overload
## Installing dependencies
```bash
pip install agno asyncio pydantic requests
```
Make sure ollama is installed and running, and pull the qwen3:4b model:
```bash
ollama pull qwen3:4b
```
## Usage
### 1. Full pipeline (both steps in sequence)
```bash
python trex_to_sentences_simple.py --step all --input_dir dataset/TREx --max_files 2
```
### 2. Step by step
#### Step 1: extract sentences only
```bash
python trex_to_sentences_simple.py --step extract --input_dir dataset/TREx --sentences_json my_sentences.json --max_files 2
```
#### Step 2: LLM processing only
```bash
python trex_to_sentences_simple.py --step llm --sentences_json my_sentences.json --output_file final_output.txt
```
## Main arguments
- `--step`: which step to run
  - `extract`: extract sentences only
  - `llm`: LLM processing only
  - `all`: full pipeline (default)
- `--input_dir`: TREx dataset directory (default: `dataset/TREx`)
- `--sentences_json`: JSON file of extracted sentences (default: `extracted_sentences.json`)
- `--output_file`: final output file (default: `trex_sentences_enhanced.txt`)
- `--max_files`: maximum number of files to process (for testing)
- `--no_llm`: disable LLM processing
## Output files
**Note: all output files are saved automatically in the corresponding directories**
### Sentence extraction
- `output/extracted_sentences.json`: the raw extracted sentences, with metadata
### LLM processing
- `output/{output_file}.txt`: the corrected sentences as plain text
- `output/{output_file}.json`: the full results (original sentence, corrected sentence, score)
- `output/{output_file}_sorted_by_importance.txt`: sentences sorted by importance score
### Checkpoints
- `output/{output_file}_checkpoint_{count}.json`: a checkpoint saved automatically every 1000 sentences
### Logs
- `logs/trex_processor_{timestamp}.log`: the detailed processing log
## 🆕 Troubleshooting
### If the run hangs:
1. **Check the log file**: look at the newest log under `logs/`
2. **Watch the heartbeat**: note any heartbeat warnings on the console
3. **Check the ollama service**:
```bash
ps aux | grep ollama
curl http://localhost:11434/api/tags
```
4. **Restart the ollama service** (if needed):
```bash
pkill ollama
ollama serve &
```
### Common warning messages:
- `⚠️ 心跳检测`: no successful response for 30s (normally recovers on its own)
- `❌ 严重警告`: no successful response for 90s (the service may need a check)
- `💀 Ollama服务异常`: the ollama service may have stopped
- `💀 致命错误`: several warnings in a row (restarting the program is advised)
## Checkpoint recovery
- Step 2 automatically detects existing checkpoint files (in the `output/` directory)
- Only sentences not yet processed are handled, avoiding duplicate work
- If every sentence has already been processed, the final output file is generated directly
- After an interruption, rerunning resumes automatically from the latest checkpoint (a sketch of this follows)
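A hedged sketch of that resume logic; the glob pattern follows the checkpoint naming above, and the `original_sentence` field name is an assumption:
```python
import glob
import json
import os

def load_processed(output_file):
    """Collect the sentences already present in the newest checkpoint, if any."""
    ckpts = glob.glob(f"output/{output_file}_checkpoint_*.json")
    if not ckpts:
        return set()
    newest = max(ckpts, key=os.path.getmtime)
    with open(newest, "r", encoding="utf-8") as f:
        done = json.load(f)
    return {item["original_sentence"] for item in done}  # field name assumed

# Step 2 then sends only sentences not in load_processed(...) to the LLM.
```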
## Example workflow
```bash
# 1. Extract the sentences first (fast)
python trex_to_sentences_simple.py --step extract --max_files 5
# 2. Then run LLM processing (slow; supports resuming)
python trex_to_sentences_simple.py --step llm
# If interrupted, rerunning step 2 resumes from the checkpoint automatically
python trex_to_sentences_simple.py --step llm
```
## Performance characteristics
- **Conservative concurrency**: at most 4 concurrent LLM requests, lowering the risk of hangs
- **Checkpointing**: automatic saves every 1000 sentences; supports resuming
- **Smart monitoring**: detailed progress reporting and time estimates
- **Robust error handling**: failed LLM requests fall back to the original sentence with a default score
- **Service monitoring**: automatic ollama service status checks
## Notes
1. Step 1 must be completed before running step 2 for the first time
2. Checkpoint files take extra disk space (each contains all data processed so far)
3. LLM processing speed depends on model performance and network conditions
4. Test with a small batch first using `--max_files`
5. **New**: if the run hangs, check the log file and the heartbeat messages
6. **New**: the program automatically detects and reports the ollama service status
7. **New**: every processing step is logged in detail, simplifying diagnosis


@@ -0,0 +1,133 @@
import json
import os
import datetime
from typing import List, Dict, Any
# Configuration
json_path = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/dataset/combined.json"
prepare_num = 1048576  # number of records for database_init.json; adjust as needed
output_dir = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/dataset/"
def convert_to_database_init_format(sentences: List[str], importance_score: float = 10.0) -> Dict[str, Any]:
    """
    Convert a list of sentences into the database_init.json format.
    Args:
        sentences: the list of sentences
        importance_score: the importance score (default 10.0)
    Returns:
        The converted data as a dictionary.
    """
    # Build the per-sentence records
    sentence_data = []
    for sentence in sentences:
        sentence_item = {
            "original_sentence": sentence,
            "corrected_sentence": sentence,  # identical to original_sentence
            "importance_score": importance_score
        }
        sentence_data.append(sentence_item)
    # Build the full data structure
    result = {
        "metadata": {
            "batch_number": 1,
            "batch_size": len(sentences),
            "total_processed_count": len(sentences),
            "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "total_sentences": len(sentences),
            "duplicates_removed": 0  # this function does no deduplication, so 0
        },
        "sentences": sentence_data
    }
    return result
def preprocess_combined_json():
    # Read the raw data
    print("Reading combined.json...")
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    total_count = len(data)
    print(f"{total_count} records in total")
    # Process every record: join subject, predicate, and object into a sentence, keeping the original record
    print("Processing data and joining sentences...")
    sentence_to_original = {}  # maps each sentence to its original record
    all_sentences = []
    for i, item in enumerate(data):
        # Join subject, predicate, and object into one sentence
        sentence = f"{item['subject']} {item['predicate']} {item['object']}"
        all_sentences.append(sentence)
        # Record the sentence-to-original mapping (for duplicates, keep the first occurrence)
        if sentence not in sentence_to_original:
            sentence_to_original[sentence] = item
        if (i + 1) % 100000 == 0:
            print(f"Processed {i + 1}/{total_count} records")
    print(f"Sentence joining finished: {len(all_sentences)} sentences")
    # Deduplicate
    print("Deduplicating...")
    unique_sentences = list(set(all_sentences))
    duplicates_removed = len(all_sentences) - len(unique_sentences)
    print(f"Deduplication finished: {len(all_sentences)} before, {len(unique_sentences)} after, {duplicates_removed} duplicates removed")
    # Check whether there is enough deduplicated data
    if len(unique_sentences) < prepare_num:
        print(f"Warning: deduplicated count ({len(unique_sentences)}) is below the requested amount ({prepare_num})")
        print(f"Using all {len(unique_sentences)} deduplicated records")
        selected_sentences = unique_sentences
    else:
        print(f"Selecting the first {prepare_num} deduplicated records")
        selected_sentences = unique_sentences[:prepare_num]
    # Convert to the database_init.json format
    print("Converting to the database_init.json format...")
    database_init_data = convert_to_database_init_format(selected_sentences, importance_score=10.0)
    # Update duplicates_removed in the metadata
    database_init_data["metadata"]["duplicates_removed"] = duplicates_removed
    # Save database_init.json
    database_output_path = os.path.join(output_dir, "database_init_from_combined.json")
    print(f"Saving {database_output_path}...")
    with open(database_output_path, "w", encoding="utf-8") as f:
        json.dump(database_init_data, f, ensure_ascii=False, indent=2)
    print(f"database_init_from_combined.json saved with {len(selected_sentences)} records")
    # Save the remaining data as a training set (in the original format)
    remaining_sentences = unique_sentences[prepare_num:] if len(unique_sentences) > prepare_num else []
    if remaining_sentences:
        # Convert the remaining sentences back to the original format
        print(f"Converting the remaining {len(remaining_sentences)} records back to the original format...")
        remaining_original_data = []
        for sentence in remaining_sentences:
            if sentence in sentence_to_original:
                remaining_original_data.append(sentence_to_original[sentence])
        print(f"Saving the remaining {len(remaining_original_data)} records as the training set...")
        train_output_path = os.path.join(output_dir, "combined_train.json")
        with open(train_output_path, "w", encoding="utf-8") as f:
            json.dump(remaining_original_data, f, ensure_ascii=False, indent=2)
        print("combined_train.json saved")
    else:
        print("No remaining data for a training set")
        remaining_original_data = []
    print("\nData processing finished!")
    print(f"Original records: {total_count}")
    print(f"After joining: {len(all_sentences)} sentences")
    print(f"After dedup: {len(unique_sentences)} sentences")
    print(f"Used for database_init: {len(selected_sentences)}")
    print(f"Remaining training records: {len(remaining_original_data) if remaining_sentences else 0}")
if __name__ == "__main__":
    preprocess_combined_json()


@@ -0,0 +1,741 @@
import json
import os
import pandas as pd
import tarfile
import tempfile
import shutil
from pathlib import Path
import re
import langdetect
from tqdm import tqdm
import logging
import random
import hashlib
from transformers import AutoTokenizer
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration
BASE_DIR = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind"
OUTPUT_FILE = os.path.join(BASE_DIR, "dataset", "merged_pretrain.jsonl")
OUTPUT_FILE_EXTRA = os.path.join(BASE_DIR, "dataset", "merged_pretrain_extra.jsonl")
# Data source paths
PRETRAIN_HQ_PATH = os.path.join(BASE_DIR, "dataset", "pretrain_hq.jsonl")
WIKIPEDIA_PATH = "/home/pci/nas/share/datasets/wikipedia/data/20220301.en"
GUTENBERG_PATH = "/home/pci/nas/share/datasets/gutenberg/data"
OPENWEBTEXT_PATH = "/home/pci/nas/share/datasets/openwebtext/subsets"
# Token length limits
MIN_TOKENS = 410
MAX_TOKENS = 490
# Dataset quality and sampling-ratio configuration - main file
DATASET_CONFIG = {
    "pretrain_hq": {"quality": "high", "sample_ratio": 1.0, "max_samples": None},    # high quality, use all of it
    "wikipedia": {"quality": "high", "sample_ratio": 1.0, "max_samples": 5000000},   # high quality, use all, at most 5M records
    "gutenberg": {"quality": "medium", "sample_ratio": 0.8, "max_samples": 3000000}, # medium quality, use 80%, at most 3M records
    "openwebtext": {"quality": "low", "sample_ratio": 0.2, "max_samples": 2000000}   # low quality, use 20%, at most 2M records
}
# Configuration for the extra file - the remaining data
DATASET_CONFIG_EXTRA = {
    "wikipedia": {"quality": "high", "sample_ratio": 1.0, "max_samples": None},      # all of the remainder
    "gutenberg": {"quality": "medium", "sample_ratio": 0.8, "max_samples": 5000000}, # 80% of the remainder, at most 5M records
    "openwebtext": {"quality": "low", "sample_ratio": 0.6, "max_samples": 4000000}   # 60% of the remainder, at most 4M records
}
# Global state: hashes of data already selected
selected_data_hashes = {
    "wikipedia": set(),
    "gutenberg": set(),
    "openwebtext": set()
}
# The tokenizer is initialized lazily
tokenizer = None
def init_tokenizer():
    """Initialize the tokenizer."""
    global tokenizer
    try:
        # Try the local minimind tokenizer first
        local_tokenizer_path = os.path.join(BASE_DIR, "model", "minimind_tokenizer")
        if os.path.exists(local_tokenizer_path):
            tokenizer = AutoTokenizer.from_pretrained(local_tokenizer_path)
            logger.info("Local MiniMind tokenizer initialized successfully")
        else:
            # If the local tokenizer is missing, fall back to GPT-2 in offline mode
            tokenizer = AutoTokenizer.from_pretrained("gpt2", local_files_only=True)
            logger.info("GPT-2 tokenizer initialized successfully (offline)")
    except Exception as e:
        logger.error(f"Error initializing tokenizer: {e}")
        logger.info("Trying to use a simple fallback tokenizer...")
        # Use simple word-based splitting as a fallback
        tokenizer = None
        logger.warning("Using simple word-based tokenization as fallback")
def count_tokens(text):
    """Count the number of tokens in a text."""
    if tokenizer is None:
        init_tokenizer()
    if tokenizer is not None:
        try:
            tokens = tokenizer.encode(text, add_special_tokens=False)
            return len(tokens)
        except:
            pass
    # If tokenization fails or no tokenizer is available, use a rough estimate
    return int(len(text.split()) * 1.3)  # rough estimate; always returns an integer
def is_english_text(text, threshold=0.8):
    """Detect whether a text is English."""
    try:
        if len(text) < 50:  # skip detection for very short texts
            return True
        detected_lang = langdetect.detect(text)
        return detected_lang == 'en'
    except:
        # If detection fails, fall back to a simple English-character ratio test
        english_chars = sum(1 for char in text if char.isascii() and (char.isalpha() or char in ' .,!?-'))
        total_chars = len(text.replace('\n', ' ').replace('\t', ' '))
        return (english_chars / max(total_chars, 1)) > threshold
def truncate_to_token_limit(text, target_tokens=MAX_TOKENS):
    """Truncate a text to the target number of tokens."""
    if tokenizer is None:
        init_tokenizer()
    if tokenizer is not None:
        try:
            tokens = tokenizer.encode(text, add_special_tokens=False)
            if len(tokens) <= target_tokens:
                return text
            # Truncate to the target length
            truncated_tokens = tokens[:target_tokens]
            truncated_text = tokenizer.decode(truncated_tokens, skip_special_tokens=True)
            # Try to cut at a period so sentences stay intact
            sentences = truncated_text.split('.')
            if len(sentences) > 1:
                # Keep every sentence except the last, incomplete one
                truncated_text = '.'.join(sentences[:-1]) + '.'
            return truncated_text
        except:
            pass
    # If processing fails or no tokenizer is available, estimate by character count
    estimated_chars = int(target_tokens / 1.3 * 4)  # rough estimate
    text = text[:estimated_chars]
    # Try to cut at a period so sentences stay intact
    sentences = text.split('.')
    if len(sentences) > 1:
        text = '.'.join(sentences[:-1]) + '.'
    return text
def split_text_into_chunks(text, target_chunk_size=1500):
    """Split a long text into several medium-sized paragraph chunks."""
    # Clean up the text
    text = text.strip()
    if not text:
        return []
    # Collapse excessive newlines and spaces
    text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
    text = re.sub(r' +', ' ', text)
    chunks = []
    # Split by paragraph
    paragraphs = text.split('\n\n')
    current_chunk = ""
for paragraph in paragraphs:
paragraph = paragraph.strip()
if not paragraph:
continue
        # If the current chunk plus the new paragraph is still a reasonable size, append it
        if len(current_chunk) + len(paragraph) < target_chunk_size:
            if current_chunk:
                current_chunk += "\n\n" + paragraph
            else:
                current_chunk = paragraph
        else:
            # Save the current chunk if it is non-empty
            if current_chunk:
                chunks.append(current_chunk)
            # If the paragraph itself is very long, split it further
            if len(paragraph) > target_chunk_size * 2:
                # Split the long paragraph by sentence
sentences = re.split(r'(?<=[.!?])\s+', paragraph)
temp_chunk = ""
for sentence in sentences:
if len(temp_chunk) + len(sentence) < target_chunk_size:
if temp_chunk:
temp_chunk += " " + sentence
else:
temp_chunk = sentence
else:
if temp_chunk:
chunks.append(temp_chunk)
temp_chunk = sentence
if temp_chunk:
current_chunk = temp_chunk
else:
current_chunk = ""
else:
current_chunk = paragraph
    # Append the final chunk
if current_chunk:
chunks.append(current_chunk)
return chunks
def format_text_for_pretrain(text):
    """Format a text for pretraining and check its token length."""
    # Clean up the text
    text = text.strip()
    if not text:
        return None
    # Collapse excessive newlines and spaces
    text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
    text = re.sub(r' +', ' ', text)
    # Check the token length
    token_count = count_tokens(text)
    # Too short: skip
    if token_count < MIN_TOKENS:
        return None
    # Too long: truncate
    if token_count > MAX_TOKENS:
        text = truncate_to_token_limit(text, MAX_TOKENS)
        token_count = count_tokens(text)
    # Re-check that the length is within range
    if token_count < MIN_TOKENS or token_count > MAX_TOKENS:
        return None
    # Wrap in the pretraining format
    formatted_text = f"<|im_start|>{text}<|im_end|>"
    return formatted_text
def get_text_hash(text):
    """Hash a text for deduplication."""
    return hashlib.md5(text.encode('utf-8')).hexdigest()
def should_sample(dataset_name, current_count, config_dict=None, is_extra_mode=False):
    """Decide from the configuration whether to sample the current record."""
    if config_dict is None:
        config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
    config = config_dict[dataset_name]
    # Stop once the maximum sample count is reached
    if config["max_samples"] and current_count >= config["max_samples"]:
        return False
    # Randomly decide according to the sampling ratio
    return random.random() < config["sample_ratio"]
def process_pretrain_hq():
    """Process the existing high-quality pretraining data - emitted as-is, with no processing."""
    logger.info("Processing pretrain_hq.jsonl...")
    count = 0
    with open(PRETRAIN_HQ_PATH, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc="Processing pretrain_hq"):
            try:
                data = json.loads(line.strip())
                text = data.get('text', '').strip()
                if text:  # any non-empty text is emitted directly, with no checks
                    if should_sample("pretrain_hq", count):
                        yield {"text": text}
                        count += 1
            except json.JSONDecodeError:
                continue
    logger.info(f"Processed {count} records from pretrain_hq.jsonl")
def process_wikipedia(is_extra_mode=False):
    """Process the Wikipedia data."""
    mode_text = "extra" if is_extra_mode else "main"
    logger.info(f"Processing Wikipedia data ({mode_text} mode)...")
    count = 0
    config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
    # Collect all English Wikipedia files
    wiki_files = list(Path(WIKIPEDIA_PATH).glob("*.parquet"))
for file_path in tqdm(wiki_files, desc=f"Processing Wikipedia files ({mode_text})"):
if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
break
try:
df = pd.read_parquet(file_path)
for _, row in df.iterrows():
if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
break
text = row.get('text', '').strip()
                if text and len(text) > 200:  # pre-filter overly short texts
                    # First split the long text into medium-sized chunks
chunks = split_text_into_chunks(text, target_chunk_size=2000)
for chunk in chunks:
if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
break
chunk_hash = get_text_hash(chunk)
                        # In extra mode, skip data already selected for the main file
if is_extra_mode and chunk_hash in selected_data_hashes["wikipedia"]:
continue
formatted_text = format_text_for_pretrain(chunk)
if formatted_text and should_sample("wikipedia", count, config_dict, is_extra_mode):
                            # In main mode, record the hash
if not is_extra_mode:
selected_data_hashes["wikipedia"].add(chunk_hash)
yield {"text": formatted_text}
count += 1
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
continue
logger.info(f"Processed {count} records from Wikipedia ({mode_text} mode)")
def process_gutenberg(is_extra_mode=False):
    """Process the Gutenberg data."""
    mode_text = "extra" if is_extra_mode else "main"
    logger.info(f"Processing Gutenberg data ({mode_text} mode)...")
    count = 0
    config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
    # Collect all Gutenberg training files
    gutenberg_files = list(Path(GUTENBERG_PATH).glob("train-*.parquet"))
for file_path in tqdm(gutenberg_files, desc=f"Processing Gutenberg files ({mode_text})"):
if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
break
try:
df = pd.read_parquet(file_path)
for _, row in df.iterrows():
if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
break
text = row.get('text', '').strip()
                if text and len(text) > 300 and is_english_text(text):  # pre-filter
                    # First split the long text into medium-sized chunks
chunks = split_text_into_chunks(text, target_chunk_size=1800)
for chunk in chunks:
if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
break
chunk_hash = get_text_hash(chunk)
                        # In extra mode, skip data already selected for the main file
if is_extra_mode and chunk_hash in selected_data_hashes["gutenberg"]:
continue
formatted_text = format_text_for_pretrain(chunk)
if formatted_text and should_sample("gutenberg", count, config_dict, is_extra_mode):
                            # In main mode, record the hash
if not is_extra_mode:
selected_data_hashes["gutenberg"].add(chunk_hash)
yield {"text": formatted_text}
count += 1
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
continue
logger.info(f"Processed {count} records from Gutenberg ({mode_text} mode)")
def process_openwebtext(is_extra_mode=False):
    """Process the OpenWebText data."""
    mode_text = "extra" if is_extra_mode else "main"
    logger.info(f"Processing OpenWebText data ({mode_text} mode)...")
    count = 0
    config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
    max_files = 5  # limit the number of files processed to keep runtime reasonable
    # Collect the tar files
    tar_files = list(Path(OPENWEBTEXT_PATH).glob("*.tar"))[:max_files]
for tar_path in tqdm(tar_files, desc=f"Processing OpenWebText files ({mode_text})"):
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
try:
with tarfile.open(tar_path, 'r') as outer_tar:
                # Temporary directory for the outer tar
with tempfile.TemporaryDirectory() as temp_dir:
outer_tar.extractall(temp_dir)
                    # Process the extracted xz files
for root, dirs, files in os.walk(temp_dir):
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
for file in files:
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
if file.endswith('.xz'):
xz_path = os.path.join(root, file)
                                # Another temporary directory for the xz file
with tempfile.TemporaryDirectory() as xz_temp_dir:
try:
                                        # Decompress the xz file
                                        import subprocess
                                        decompressed_path = os.path.join(xz_temp_dir, file[:-3])  # strip the .xz suffix
subprocess.run(['xz', '-dc', xz_path],
stdout=open(decompressed_path, 'wb'),
check=True)
                                        # Check whether the decompressed file is itself a tar
if tarfile.is_tarfile(decompressed_path):
                                            # Process the inner tar file
with tarfile.open(decompressed_path, 'r') as inner_tar:
with tempfile.TemporaryDirectory() as inner_temp_dir:
inner_tar.extractall(inner_temp_dir)
                                                    # Process the final txt files
for inner_root, inner_dirs, inner_files in os.walk(inner_temp_dir):
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
for txt_file in inner_files:
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
if txt_file.endswith('.txt'):
txt_path = os.path.join(inner_root, txt_file)
try:
with open(txt_path, 'r', encoding='utf-8', errors='ignore') as f:
text = f.read().strip()
if text and len(text) > 500 and is_english_text(text):
                                                                        # First split the long text into medium-sized chunks
chunks = split_text_into_chunks(text, target_chunk_size=1600)
for chunk in chunks:
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
chunk_hash = get_text_hash(chunk)
                                                                            # In extra mode, skip data already selected for the main file
if is_extra_mode and chunk_hash in selected_data_hashes["openwebtext"]:
continue
formatted_text = format_text_for_pretrain(chunk)
if formatted_text and should_sample("openwebtext", count, config_dict, is_extra_mode):
                                                                                # In main mode, record the hash
if not is_extra_mode:
selected_data_hashes["openwebtext"].add(chunk_hash)
yield {"text": formatted_text}
count += 1
except Exception as e:
logger.debug(f"Error reading txt file {txt_path}: {e}")
continue
else:
                                            # Not a tar file: treat it directly as text
try:
with open(decompressed_path, 'r', encoding='utf-8', errors='ignore') as f:
text = f.read().strip()
if text and len(text) > 500 and is_english_text(text):
chunks = split_text_into_chunks(text, target_chunk_size=1600)
for chunk in chunks:
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
chunk_hash = get_text_hash(chunk)
                                                        # In extra mode, skip data already selected for the main file
if is_extra_mode and chunk_hash in selected_data_hashes["openwebtext"]:
continue
formatted_text = format_text_for_pretrain(chunk)
if formatted_text and should_sample("openwebtext", count, config_dict, is_extra_mode):
                                                            # In main mode, record the hash
if not is_extra_mode:
selected_data_hashes["openwebtext"].add(chunk_hash)
yield {"text": formatted_text}
count += 1
except Exception as e:
logger.debug(f"Error reading decompressed file {decompressed_path}: {e}")
continue
except Exception as e:
logger.debug(f"Error processing xz file {xz_path}: {e}")
continue
except Exception as e:
logger.error(f"Error processing {tar_path}: {e}")
continue
logger.info(f"Processed {count} records from OpenWebText ({mode_text} mode)")
def merge_datasets():
"""合并所有数据集,生成主文件和额外文件"""
logger.info("Starting dataset merging...")
logger.info("Main dataset configuration:")
for name, config in DATASET_CONFIG.items():
logger.info(f" {name}: quality={config['quality']}, ratio={config['sample_ratio']}, max={config['max_samples']}")
logger.info("Extra dataset configuration:")
for name, config in DATASET_CONFIG_EXTRA.items():
logger.info(f" {name}: quality={config['quality']}, ratio={config['sample_ratio']}, max={config['max_samples']}")
    # Make sure the output directories exist
os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True)
os.makedirs(os.path.dirname(OUTPUT_FILE_EXTRA), exist_ok=True)
    # Statistics
main_dataset_stats = {}
extra_dataset_stats = {}
    # Phase 1: generate the main file
logger.info("="*50)
logger.info("PHASE 1: Generating main dataset file")
logger.info("="*50)
with open(OUTPUT_FILE, 'w', encoding='utf-8') as outfile:
main_total_count = 0
        # Process each dataset (main mode)
main_datasets = [
("pretrain_hq", process_pretrain_hq),
("wikipedia", lambda: process_wikipedia(is_extra_mode=False)),
("gutenberg", lambda: process_gutenberg(is_extra_mode=False)),
("openwebtext", lambda: process_openwebtext(is_extra_mode=False))
]
for dataset_name, dataset_func in main_datasets:
logger.info(f"Processing {dataset_name} for main file...")
dataset_count = 0
try:
for record in dataset_func():
json.dump(record, outfile, ensure_ascii=False)
outfile.write('\n')
dataset_count += 1
main_total_count += 1
                    # Report progress every 5000 records
if main_total_count % 5000 == 0:
logger.info(f"Main file: Processed {main_total_count} total records")
                # Save the statistics
main_dataset_stats[dataset_name] = {
'selected': dataset_count,
'config': DATASET_CONFIG[dataset_name]
}
except Exception as e:
logger.error(f"Error processing {dataset_name} for main file: {e}")
main_dataset_stats[dataset_name] = {
'selected': dataset_count,
'config': DATASET_CONFIG[dataset_name]
}
logger.info(f"Main file - Completed {dataset_name}: {dataset_count} records")
logger.info(f"Main file generation completed. Total records: {main_total_count}")
    # Phase 2: generate the extra file
logger.info("="*50)
logger.info("PHASE 2: Generating extra dataset file")
logger.info("="*50)
with open(OUTPUT_FILE_EXTRA, 'w', encoding='utf-8') as outfile:
extra_total_count = 0
        # Process each dataset (extra mode) - pretrain_hq excluded
extra_datasets = [
("wikipedia", lambda: process_wikipedia(is_extra_mode=True)),
("gutenberg", lambda: process_gutenberg(is_extra_mode=True)),
("openwebtext", lambda: process_openwebtext(is_extra_mode=True))
]
for dataset_name, dataset_func in extra_datasets:
logger.info(f"Processing {dataset_name} for extra file...")
dataset_count = 0
try:
for record in dataset_func():
json.dump(record, outfile, ensure_ascii=False)
outfile.write('\n')
dataset_count += 1
extra_total_count += 1
                    # Report progress every 5000 records
if extra_total_count % 5000 == 0:
logger.info(f"Extra file: Processed {extra_total_count} total records")
                # Save the statistics
extra_dataset_stats[dataset_name] = {
'selected': dataset_count,
'config': DATASET_CONFIG_EXTRA[dataset_name]
}
except Exception as e:
logger.error(f"Error processing {dataset_name} for extra file: {e}")
extra_dataset_stats[dataset_name] = {
'selected': dataset_count,
'config': DATASET_CONFIG_EXTRA[dataset_name]
}
logger.info(f"Extra file - Completed {dataset_name}: {dataset_count} records")
logger.info(f"Extra file generation completed. Total records: {extra_total_count}")
    # Print detailed statistics
print_detailed_statistics(main_dataset_stats, main_total_count, extra_dataset_stats, extra_total_count)
logger.info("All dataset processing completed successfully!")
logger.info(f"Main file saved to: {OUTPUT_FILE}")
logger.info(f"Extra file saved to: {OUTPUT_FILE_EXTRA}")
def print_detailed_statistics(main_dataset_stats, main_total_count, extra_dataset_stats, extra_total_count):
"""打印详细统计信息"""
print("\n" + "="*100)
print("DATASET PROCESSING SUMMARY")
print("="*100)
    # Main-file statistics
print("\nMAIN FILE (merged_pretrain.jsonl):")
print("-" * 90)
print(f"{'Dataset':<15} {'Selected':<10} {'Sample Ratio':<12} {'Max Limit':<12} {'% of Main':<12} {'Quality':<8}")
print("-" * 90)
for dataset_name, stats in main_dataset_stats.items():
selected = stats['selected']
config = stats['config']
ratio = config['sample_ratio']
max_limit = config['max_samples'] if config['max_samples'] else "No limit"
percentage = (selected / main_total_count * 100) if main_total_count > 0 else 0
quality = config['quality']
print(f"{dataset_name:<15} {selected:<10,} {ratio:<12.1%} {str(max_limit):<12} {percentage:<12.2f}% {quality:<8}")
print("-" * 90)
print(f"{'MAIN TOTAL':<15} {main_total_count:<10,} {'':<12} {'':<12} {'100.00%':<12} {'':<8}")
    # Extra-file statistics
print("\nEXTRA FILE (merged_pretrain_extra.jsonl):")
print("-" * 90)
print(f"{'Dataset':<15} {'Selected':<10} {'Sample Ratio':<12} {'Max Limit':<12} {'% of Extra':<12} {'Quality':<8}")
print("-" * 90)
for dataset_name, stats in extra_dataset_stats.items():
selected = stats['selected']
config = stats['config']
ratio = config['sample_ratio']
max_limit = config['max_samples'] if config['max_samples'] else "No limit"
percentage = (selected / extra_total_count * 100) if extra_total_count > 0 else 0
quality = config['quality']
print(f"{dataset_name:<15} {selected:<10,} {ratio:<12.1%} {str(max_limit):<12} {percentage:<12.2f}% {quality:<8}")
print("-" * 90)
print(f"{'EXTRA TOTAL':<15} {extra_total_count:<10,} {'':<12} {'':<12} {'100.00%':<12} {'':<8}")
    # Overall statistics
total_records = main_total_count + extra_total_count
print("\nOVERALL STATISTICS:")
print("-" * 50)
print(f"Main file records: {main_total_count:>10,}")
print(f"Extra file records: {extra_total_count:>10,}")
print(f"Total records: {total_records:>10,}")
print(f"Token range per sample: {MIN_TOKENS}-{MAX_TOKENS} tokens")
    # Quality distribution
quality_stats = {}
for dataset_name, stats in main_dataset_stats.items():
quality = stats['config']['quality']
if quality not in quality_stats:
quality_stats[quality] = {'main': 0, 'extra': 0}
quality_stats[quality]['main'] += stats['selected']
for dataset_name, stats in extra_dataset_stats.items():
quality = stats['config']['quality']
if quality not in quality_stats:
quality_stats[quality] = {'main': 0, 'extra': 0}
quality_stats[quality]['extra'] += stats['selected']
print("\nQUALITY DISTRIBUTION:")
print("-" * 60)
print(f"{'Quality':<12} {'Main File':<12} {'Extra File':<12} {'Total':<12} {'%':<8}")
print("-" * 60)
for quality in sorted(quality_stats.keys()):
main_count = quality_stats[quality]['main']
extra_count = quality_stats[quality]['extra']
total_count = main_count + extra_count
percentage = (total_count / total_records * 100) if total_records > 0 else 0
print(f"{quality.capitalize():<12} {main_count:<12,} {extra_count:<12,} {total_count:<12,} {percentage:<8.2f}%")
print("-" * 60)
print(f"{'Total':<12} {main_total_count:<12,} {extra_total_count:<12,} {total_records:<12,} {'100.00%':<8}")
print(f"\nFiles saved to:")
print(f" Main file: {OUTPUT_FILE}")
print(f" Extra file: {OUTPUT_FILE_EXTRA}")
print("="*100)
def main():
    """Entry point."""
    try:
        # Fix the random seed so results are reproducible
        random.seed(42)
        # Check dependencies
        try:
            import langdetect
            from transformers import AutoTokenizer
        except ImportError as e:
            logger.error(f"Missing dependencies: {e}")
            logger.error("Please install: pip install langdetect transformers")
            return
        # Initialize the tokenizer
        init_tokenizer()
        # Check the input file
        if not os.path.exists(PRETRAIN_HQ_PATH):
            logger.error(f"pretrain_hq.jsonl not found at {PRETRAIN_HQ_PATH}")
            return
        # Start merging the datasets
        merge_datasets()
        logger.info("All processing completed successfully!")
    except Exception as e:
        logger.error(f"Error in main process: {e}")
        raise
if __name__ == "__main__":
main()


@@ -0,0 +1,61 @@
#!/usr/bin/env python3
"""
Small-scale test of the preprocessing script
"""
import sys
import os
# Add the module path
sys.path.append('/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/preprocessing')
# Import the main module
from preprocess_pretrain import *
# Shrink the configuration for a small-scale test
DATASET_CONFIG["wikipedia"]["max_samples"] = 100
DATASET_CONFIG["gutenberg"]["max_samples"] = 50
DATASET_CONFIG["openwebtext"]["max_samples"] = 20
DATASET_CONFIG_EXTRA["wikipedia"]["max_samples"] = 50
DATASET_CONFIG_EXTRA["gutenberg"]["max_samples"] = 30
DATASET_CONFIG_EXTRA["openwebtext"]["max_samples"] = 15
# Redirect the output paths
OUTPUT_FILE = "/tmp/test_main.jsonl"
OUTPUT_FILE_EXTRA = "/tmp/test_extra.jsonl"
def test_small_scale():
    """Small-scale test."""
    print("Starting small scale test...")
    # Fix the random seed
    random.seed(42)
    try:
        # Initialize the tokenizer
        init_tokenizer()
        # Merge the datasets
        merge_datasets()
        # Check the output files
if os.path.exists(OUTPUT_FILE):
with open(OUTPUT_FILE, 'r') as f:
main_lines = len(f.readlines())
print(f"Main file created: {main_lines} lines")
if os.path.exists(OUTPUT_FILE_EXTRA):
with open(OUTPUT_FILE_EXTRA, 'r') as f:
extra_lines = len(f.readlines())
print(f"Extra file created: {extra_lines} lines")
print("Small scale test completed successfully!")
except Exception as e:
print(f"Test failed: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
test_small_scale()

175
pyproject.toml Normal file

@@ -0,0 +1,175 @@
[project]
name = "minimind"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"accelerate==1.7.0",
"aiohappyeyeballs==2.6.1",
"aiohttp==3.11.17",
"aiosignal==1.3.2",
"altair==5.5.0",
"annotated-types==0.7.0",
"anyio==4.9.0",
"async-timeout==5.0.1",
"attrs==25.3.0",
"blinker==1.9.0",
"boto3==1.38.41",
"botocore==1.38.41",
"cachetools==5.5.2",
"certifi==2025.1.31",
"charset-normalizer==3.4.1",
"click==8.1.8",
"contourpy==1.3.2",
"cycler==0.12.1",
"datasets==2.21.0",
"datasketch==1.6.4",
"deepspeed==0.17.0",
"determined>=0.37.0",
"dill==0.3.8",
"distro==1.9.0",
"docker-pycreds==0.4.0",
"einops==0.8.1",
"exceptiongroup==1.2.2",
"filelock==3.18.0",
"Flask==3.0.3",
"Flask-Cors==4.0.0",
"fonttools==4.57.0",
"frozenlist==1.6.0",
"fsspec==2024.6.1",
"gitdb==4.0.12",
"GitPython==3.1.44",
"h11==0.14.0",
"hjson==3.1.0",
"httpcore==1.0.8",
"httpx==0.28.1",
"huggingface-hub==0.30.2",
"importlib_metadata==7.2.1",
"itsdangerous==2.2.0",
"jieba==0.42.1",
"Jinja2==3.1.2",
"jiter==0.9.0",
"jmespath==1.0.1",
"joblib==1.4.2",
"jsonlines==4.0.0",
"jsonpointer==2.1",
"jsonschema==4.23.0",
"jsonschema-specifications==2024.10.1",
"kiwisolver==1.4.8",
"langdetect==1.0.9",
"markdown-it-py==3.0.0",
"MarkupSafe==3.0.2",
"marshmallow==3.22.0",
"matplotlib==3.10.0",
"mdurl==0.1.2",
"modelscope==1.25.0",
"mpi4py>=4.0.3",
"mpmath==1.3.0",
"msgpack==1.1.0",
"multidict==6.4.3",
"multiprocess==0.70.16",
"narwhals==1.35.0",
"networkx==3.4.2",
"ngrok==1.4.0",
"ninja==1.11.1.4",
"nltk==3.8",
"numpy==1.26.4",
"nvidia-cublas-cu11==11.11.3.6",
"nvidia-cublas-cu12==12.6.4.1",
"nvidia-cuda-cupti-cu11==11.8.87",
"nvidia-cuda-cupti-cu12==12.6.80",
"nvidia-cuda-nvrtc-cu11==11.8.89",
"nvidia-cuda-nvrtc-cu12==12.6.77",
"nvidia-cuda-runtime-cu11==11.8.89",
"nvidia-cuda-runtime-cu12==12.6.77",
"nvidia-cudnn-cu11==9.1.0.70",
"nvidia-cudnn-cu12==9.5.1.17",
"nvidia-cufft-cu11==10.9.0.58",
"nvidia-cufft-cu12==11.3.0.4",
"nvidia-cufile-cu12==1.11.1.6",
"nvidia-curand-cu11==10.3.0.86",
"nvidia-curand-cu12==10.3.7.77",
"nvidia-cusolver-cu11==11.4.1.48",
"nvidia-cusolver-cu12==11.7.1.2",
"nvidia-cusparse-cu11==11.7.5.86",
"nvidia-cusparse-cu12==12.5.4.2",
"nvidia-cusparselt-cu12==0.6.3",
"nvidia-ml-py==12.575.51",
"nvidia-nccl-cu11==2.21.5",
"nvidia-nccl-cu12==2.26.2",
"nvidia-nvjitlink-cu12==12.6.85",
"nvidia-nvtx-cu11==11.8.86",
"nvidia-nvtx-cu12==12.6.77",
"openai==1.59.6",
"packaging==23.2",
"pandas>=2.0.0",
"peft==0.7.1",
"pillow==10.4.0",
"platformdirs==4.3.7",
"prettytable==3.16.0",
"propcache==0.3.1",
"protobuf==4.25.6",
"psutil==5.9.8",
"py-cpuinfo==9.0.0",
"pyarrow==19.0.1",
"pydantic==2.11.7",
"pydantic_core==2.33.2",
"pydeck==0.9.1",
"pyecharts==2.0.8",
"Pygments==2.19.1",
"pynvml==12.0.0",
"pyparsing==3.2.3",
"python-dateutil==2.9.0.post0",
"pytz==2025.2",
"PyYAML==6.0.2",
"referencing==0.36.2",
"regex==2024.11.6",
"requests==2.32.3",
"rich==13.7.1",
"rpds-py==0.24.0",
"s3transfer==0.13.0",
"safetensors==0.5.3",
"scikit-learn==1.5.1",
"scipy==1.15.2",
"sentence-transformers==2.3.1",
"sentencepiece==0.2.0",
"sentry-sdk==2.26.1",
"setproctitle==1.3.5",
"simhash==2.1.2",
"simplejson==3.20.1",
"six==1.17.0",
"smmap==5.0.2",
"sniffio==1.3.1",
"streamlit==1.30.0",
"swankit==0.2.4",
"swanlab==0.6.4",
"sympy==1.13.3",
"tenacity==8.5.0",
"threadpoolctl==3.6.0",
"tiktoken>=0.8.0",
"tokenizers==0.21.1",
"toml==0.10.2",
"torch==2.7.1",
"torchaudio==2.7.1",
"torchvision==0.22.1",
"tornado==6.4.2",
"tqdm==4.67.1",
"transformers==4.52.4",
"triton==3.3.1",
"trl==0.13.0",
"typing-inspection==0.4.1",
"typing_extensions==4.13.2",
"tzlocal==5.3.1",
"ujson==5.1.0",
"urllib3==2.4.0",
"validators==0.34.0",
"wandb==0.18.3",
"watchdog==6.0.0",
"wcwidth==0.2.13",
"Werkzeug==3.1.3",
"wrapt==1.17.2",
"xxhash==3.5.0",
"yarl==1.20.0",
"zipp==3.21.0",
]


@@ -1,3 +1,4 @@
accelerate==1.7.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.17
aiosignal==1.3.2
@@ -7,6 +8,8 @@ anyio==4.9.0
async-timeout==5.0.1
attrs==25.3.0
blinker==1.9.0
boto3==1.38.41
botocore==1.38.41
cachetools==5.5.2
certifi==2025.1.31
charset-normalizer==3.4.1
@@ -15,6 +18,7 @@ contourpy==1.3.2
cycler==0.12.1
datasets==2.21.0
datasketch==1.6.4
deepspeed==0.17.0
dill==0.3.8
distro==1.9.0
docker-pycreds==0.4.0
@@ -33,17 +37,19 @@ hjson==3.1.0
httpcore==1.0.8
httpx==0.28.1
huggingface-hub==0.30.2
idna==3.10
importlib_metadata==7.2.1
itsdangerous==2.2.0
jieba==0.42.1
Jinja2==3.1.2
jiter==0.9.0
jmespath==1.0.1
joblib==1.4.2
jsonlines==4.0.0
jsonpointer==2.1
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
kiwisolver==1.4.8
langdetect==1.0.9
markdown-it-py==3.0.0
MarkupSafe==3.0.2
marshmallow==3.22.0
@ -60,21 +66,50 @@ ngrok==1.4.0
ninja==1.11.1.4
nltk==3.8
numpy==1.26.4
nvidia-cublas-cu11==11.11.3.6
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu11==9.1.0.70
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu11==10.9.0.58
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu11==10.3.0.86
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu11==11.7.5.86
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-ml-py==12.575.51
nvidia-nccl-cu11==2.21.5
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu11==11.8.86
nvidia-nvtx-cu12==12.6.77
openai==1.59.6
packaging==23.2
pandas==1.5.3
peft==0.7.1
pillow==10.4.0
platformdirs==4.3.7
prettytable==3.16.0
propcache==0.3.1
protobuf==4.25.6
psutil==5.9.8
py-cpuinfo==9.0.0
pyarrow==19.0.1
pydantic==2.8.2
pydantic_core==2.20.1
pydantic==2.11.7
pydantic_core==2.33.2
pydeck==0.9.1
pyecharts==2.0.8
Pygments==2.19.1
pynvml==12.0.0
pyparsing==3.2.3
python-dateutil==2.9.0.post0
pytz==2025.2
@ -84,6 +119,7 @@ regex==2024.11.6
requests==2.32.3
rich==13.7.1
rpds-py==0.24.0
s3transfer==0.13.0
safetensors==0.5.3
scikit-learn==1.5.1
scipy==1.15.2
@ -92,21 +128,28 @@ sentencepiece==0.2.0
sentry-sdk==2.26.1
setproctitle==1.3.5
simhash==2.1.2
simplejson==3.20.1
six==1.17.0
smmap==5.0.2
sniffio==1.3.1
streamlit==1.30.0
swankit==0.2.4
swanlab==0.6.4
sympy==1.13.3
tenacity==8.5.0
threadpoolctl==3.6.0
tiktoken==0.5.1
tokenizers==0.21.1
toml==0.10.2
torch==2.7.1
torchaudio==2.7.1
torchvision==0.22.1
tornado==6.4.2
tqdm==4.67.1
transformers==4.48.0
triton==3.3.0
transformers==4.52.4
triton==3.3.1
trl==0.13.0
typing-inspection==0.4.1
typing_extensions==4.13.2
tzlocal==5.3.1
ujson==5.1.0
@ -114,7 +157,9 @@ urllib3==2.4.0
validators==0.34.0
wandb==0.18.3
watchdog==6.0.0
wcwidth==0.2.13
Werkzeug==3.1.3
wrapt==1.17.2
xxhash==3.5.0
yarl==1.20.0
zipp==3.21.0

View File

@ -1,8 +1,8 @@
#!/bin/bash
# Activate the conda environment
source $(conda info --base)/etc/profile.d/conda.sh
conda activate mini
# source $(conda info --base)/etc/profile.d/conda.sh
# conda activate mini
# Set environment variables to aid debugging
export NCCL_DEBUG=INFO
@ -26,9 +26,27 @@ export PYTHONFAULTHANDLER=1
# --profile_interval 10
# Method 2: configure accelerate directly via command-line arguments
CUDA_VISIBLE_DEVICES=0 /opt/conda/envs/mini/bin/python -m accelerate.commands.launch \
# Memory-leak debugging configuration - reduce memory usage
CUDA_VISIBLE_DEVICES=0 uv run -p .venv python -m accelerate.commands.launch \
--num_processes=1 \
--mixed_precision=bf16 \
--main_process_port=29500 \
train_pretrain_accelerate.py \
train_pretrain_accelerate.py
# --batch_size 128 \
# --num_workers 1
# --knowledge_num 48020 \
# --num_workers 1 \
# --epochs 4 \
# --learning_rate 2e-4 \
# --dtype bfloat16 \
# --accumulation_steps 32 \
# --grad_clip 1.0 \
# --log_interval 50 \
# --save_interval 10000 \
# --dim 512 \
# --n_layers 8 \
# --max_seq_len 512 \
# --use_flash_attn \
# --profile \
# --profile_interval 10

33
startup.sh Normal file
View File

@ -0,0 +1,33 @@
#!/bin/bash
set -e
# After the container starts, install all dependencies from requirements.txt first
# pip install -r requirements.txt
# bash install.sh -y
python3 -m pip install --upgrade pip
pip install uv -i https://pypi.tuna.tsinghua.edu.cn/simple
# Change to the project directory
cd /ycz/Minimind
# Check and repair the virtual environment
if [ ! -f .venv/bin/python ] || [ ! -x .venv/bin/python ]; then
echo "Virtual environment is broken or missing, recreating with uv..."
rm -rf .venv
uv venv .venv
fi
# Do not activate the virtual environment manually; let uv manage it
# . ./.venv/bin/activate
# Sync dependencies with uv
uv sync
# After installation, run the main training script
# "$@" forwards the arguments defined in the experiment.yaml entrypoint to the python script
CUDA_VISIBLE_DEVICES=0 uv run python -m accelerate.commands.launch \
--num_processes=1 \
--mixed_precision=bf16 \
--main_process_port=29500 \
train_pretrain_accelerate.py "$@"

View File

@ -1,97 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Test the real-valued version of the rotary positional encoding
"""
import torch
from model.model import precompute_pos_cis, precompute_pos_cis_real, apply_rotary_emb, apply_rotary_emb_real
from model.LMConfig import LMConfig
from model.model import MiniMindLM
def test_pos_encoding_equivalence():
"""测试复数版本和实数版本的位置编码是否等价"""
print("测试位置编码等价性...")
# 参数设置
dim = 64
seq_len = 10
# 生成复数版本的位置编码
pos_cis = precompute_pos_cis(dim=dim, end=seq_len)
# 生成实数版本的位置编码
pos_cis_real = precompute_pos_cis_real(dim=dim, end=seq_len)
# 创建随机查询和键
batch_size = 2
n_heads = 4
head_dim = dim
xq = torch.randn(batch_size, seq_len, n_heads, head_dim)
xk = torch.randn(batch_size, seq_len, n_heads, head_dim)
# 应用复数版本的旋转位置编码
xq_complex, xk_complex = apply_rotary_emb(xq, xk, pos_cis)
# 应用实数版本的旋转位置编码
xq_real, xk_real = apply_rotary_emb_real(xq, xk, pos_cis_real)
# 计算差异
q_diff = torch.abs(xq_complex - xq_real).mean().item()
k_diff = torch.abs(xk_complex - xk_real).mean().item()
print(f"查询差异: {q_diff:.6f}")
print(f"键差异: {k_diff:.6f}")
# 检查差异是否在可接受范围内
tolerance = 1e-5
if q_diff < tolerance and k_diff < tolerance:
print("✅ 测试通过: 复数版本和实数版本的位置编码在数值上等价")
else:
print("❌ 测试失败: 复数版本和实数版本的位置编码存在显著差异")
def test_model_forward():
"""测试模型前向传播"""
print("\n测试模型前向传播...")
# 创建模型配置
config = LMConfig(
dim=128,
n_layers=2,
n_heads=4,
n_kv_heads=4, # 确保n_kv_heads被设置且n_heads能被n_kv_heads整除
vocab_size=1000,
max_seq_len=128,
disable_db=True # 禁用数据库功能,避免额外的复杂性
)
# 创建模型
try:
model = MiniMindLM(config)
print(f"✅ 模型初始化成功")
except Exception as e:
print(f"❌ 模型初始化失败: {str(e)}")
return
# 创建输入
batch_size = 2
seq_len = 10
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
# 前向传播
try:
with torch.no_grad():
outputs = model(input_ids)
print(f"✅ 模型前向传播成功")
print(f"输出形状: {outputs.logits.shape}")
except Exception as e:
print(f"❌ 模型前向传播失败: {str(e)}")
if __name__ == "__main__":
# 测试位置编码等价性
test_pos_encoding_equivalence()
# 测试模型前向传播
test_model_forward()
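Note: the equivalence this deleted test covered can still be checked standalone. Below is a minimal sketch with self-contained stand-ins for the two rotary-embedding variants — hypothetical reimplementations, not the repo's precompute_pos_cis / apply_rotary_emb — assuming an even head_dim and float32 inputs:

import torch

def rope_complex(x, theta=1e4):
    # rotate channel pairs via complex multiplication: (a+bi) * e^{i·t·f}
    b, seq, h, d = x.shape
    freqs = 1.0 / (theta ** (torch.arange(0, d, 2).float() / d))
    pos_cis = torch.polar(torch.ones(seq, d // 2), torch.outer(torch.arange(seq).float(), freqs))
    x_ = torch.view_as_complex(x.float().reshape(b, seq, h, d // 2, 2))
    return torch.view_as_real(x_ * pos_cis[None, :, None, :]).flatten(3)

def rope_real(x, theta=1e4):
    # the same rotation with explicit cos/sin, no complex dtype
    b, seq, h, d = x.shape
    freqs = 1.0 / (theta ** (torch.arange(0, d, 2).float() / d))
    angles = torch.outer(torch.arange(seq).float(), freqs)
    cos, sin = angles.cos()[None, :, None, :], angles.sin()[None, :, None, :]
    x1, x2 = x[..., 0::2].float(), x[..., 1::2].float()
    return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(3)

xq = torch.randn(2, 10, 4, 64)
assert torch.allclose(rope_complex(xq), rope_real(xq), atol=1e-5)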

View File

@ -1,6 +1,6 @@
import os
# Set environment variables
os.environ["WANDB_MODE"] = "offline" # or use "dryrun"
# Set environment variables - replace wandb with SwanLab
# os.environ["SWANLAB_MODE"] = "online" # SwanLab online mode
import platform
import argparse
from tqdm import tqdm
@ -21,6 +21,9 @@ from accelerate.utils import DistributedDataParallelKwargs
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import swanlab # replaces the wandb import
import gc # garbage collection, for explicit memory cleanup
import psutil # system-resource monitoring
from model.model import MiniMindLM, RMSNorm
from model.LMConfig import LMConfig
@ -28,6 +31,63 @@ from model.dataset import PretrainDataset
warnings.filterwarnings('ignore')
# Memory-monitoring helper functions
def get_memory_usage():
"""Get the current memory usage of this process"""
process = psutil.Process()
memory_info = process.memory_info()
return {
'rss_mb': memory_info.rss / 1024 / 1024, # resident (physical) memory, MB
'vms_mb': memory_info.vms / 1024 / 1024, # virtual memory, MB
}
def get_cuda_memory_usage():
"""获取CUDA内存使用情况"""
if torch.cuda.is_available():
return {
'cuda_allocated_mb': torch.cuda.memory_allocated() / 1024 / 1024,
'cuda_reserved_mb': torch.cuda.memory_reserved() / 1024 / 1024,
'cuda_max_allocated_mb': torch.cuda.max_memory_allocated() / 1024 / 1024,
}
return {}
def get_tensor_memory_size(tensor_list):
"""计算tensor列表的总内存占用MB"""
total_size = 0
for batch in tensor_list:
if isinstance(batch, (list, tuple)):
for tensor in batch:
if isinstance(tensor, torch.Tensor):
total_size += tensor.numel() * tensor.element_size()
elif isinstance(batch, torch.Tensor):
total_size += batch.numel() * batch.element_size()
return total_size / 1024 / 1024 # convert to MB
def log_memory_status(step, prefetch_batches, accelerator, stage="", detailed=False):
"""记录内存状态"""
if not accelerator.is_main_process:
return
memory_info = get_memory_usage()
cuda_info = get_cuda_memory_usage()
prefetch_memory = get_tensor_memory_size(prefetch_batches)
log_msg = f"[Memory Monitor] Step {step} {stage} - "
log_msg += f"Prefetch batches: {len(prefetch_batches)}, "
log_msg += f"Prefetch memory: {prefetch_memory:.2f}MB, "
log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB"
if cuda_info:
log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB"
if detailed:
log_msg += f", System VMS: {memory_info['vms_mb']:.2f}MB"
if cuda_info:
log_msg += f", CUDA max allocated: {cuda_info['cuda_max_allocated_mb']:.2f}MB"
Logger(log_msg, accelerator)
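A usage sketch for the helpers above (hypothetical placement; the memory_monitor flags are defined in main() below, and the printed numbers are machine-dependent):

mem = get_memory_usage()                  # {'rss_mb': ..., 'vms_mb': ...}
cuda = get_cuda_memory_usage()            # {} on a CPU-only machine
queue_mb = get_tensor_memory_size([])     # 0.0 for an empty prefetch queue
print(f"RSS {mem['rss_mb']:.1f}MB, prefetch {queue_mb:.2f}MB, CUDA {cuda.get('cuda_allocated_mb', 0.0):.1f}MB")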
# Logging function
def Logger(msg, accelerator=None):
# If no accelerator is provided, print only on the main process
@ -218,7 +278,7 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
Logger(f'Total trainable LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
return model, tokenizer
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb):
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run):
loss_fct = nn.CrossEntropyLoss(reduction='none')
epoch_start_time = time.time()
total_steps_in_epoch = len(train_loader)
@ -226,6 +286,10 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
moe_path = '_moe' if args.use_moe else ''
best_loss = float('10000')
# Initialize the CUDA event variables
data_start = data_end = forward_start = forward_end = None
backward_start = backward_end = optimizer_start = optimizer_end = None
# Add CUDA events for performance profiling (main process only)
if args.profile and accelerator.is_main_process:
data_start = torch.cuda.Event(enable_timing=True)
@ -242,40 +306,63 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
data_iter = iter(train_loader)
prefetch_batches = []
# Log the initial memory state
if args.memory_monitor:
log_memory_status(-1, prefetch_batches, accelerator, "before_prefetch", detailed=True)
# Prefetch the initial batches
for _ in range(min(prefetch_factor, len(train_loader))):
for i in range(min(prefetch_factor, len(train_loader))):
try:
batch = next(data_iter)
prefetch_batches.append(batch)
# Log the memory change after each batch is added
if args.memory_monitor and accelerator.is_main_process:
log_memory_status(-1, prefetch_batches, accelerator, f"after_adding_batch_{i+1}")
except StopIteration:
break
# Log the memory state once the initial prefetch is complete
if args.memory_monitor:
log_memory_status(-1, prefetch_batches, accelerator, "after_initial_prefetch", detailed=True)
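Note: the queue above is a plain list popped from the front, which is O(n) per pop; at this prefetch depth the cost is negligible, but the same pattern reads naturally with collections.deque and O(1) pops. A standalone sketch (hypothetical helper, not a drop-in for this loop):

from collections import deque

def prefetched(data_iter, total_steps, prefetch_factor=2):
    queue = deque()
    for _ in range(prefetch_factor):       # warm up, mirroring the loop above
        try:
            queue.append(next(data_iter))
        except StopIteration:
            break
    for step in range(total_steps):
        batch = queue.popleft() if queue else next(data_iter)
        try:
            queue.append(next(data_iter))  # keep the queue topped up
        except StopIteration:
            pass
        yield step, batch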
# Initialize the variables needed for logging before entering the loop
last_log_time = epoch_start_time
for step in range(total_steps_in_epoch):
try:
# Time data loading (main process only)
if args.profile and accelerator.is_main_process:
if args.profile and accelerator.is_main_process and data_start is not None:
data_start.record()
# Log the memory state before consuming a batch (detailed at the configured interval)
if args.memory_monitor and step % args.memory_monitor_interval == 0:
log_memory_status(step, prefetch_batches, accelerator, "before_use_batch", detailed=True)
# Use prefetched data
if prefetch_batches:
X, Y, loss_mask = prefetch_batches.pop(0)
# Log the memory change after popping a batch
if args.memory_monitor and step % args.memory_monitor_interval == 0:
log_memory_status(step, prefetch_batches, accelerator, "after_pop_batch")
else:
# If the prefetch queue is empty, load directly
X, Y, loss_mask = next(data_iter)
if args.memory_monitor and accelerator.is_main_process:
Logger(f"[Memory Monitor] Step {step} - Prefetch queue empty, loading directly!", accelerator)
# Asynchronously prefetch the next batch
if step + prefetch_factor < len(train_loader):
try:
batch = next(data_iter)
prefetch_batches.append(batch)
# Log the memory change after adding the new batch
if args.memory_monitor and step % args.memory_monitor_interval == 0:
log_memory_status(step, prefetch_batches, accelerator, "after_add_batch")
except StopIteration:
pass
# End data-loading timing (main process only)
if args.profile and accelerator.is_main_process:
if args.profile and accelerator.is_main_process and data_end is not None:
data_end.record()
# Update the learning rate
@ -283,7 +370,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
scheduler.step()
# Time the forward pass (main process only)
if args.profile and accelerator.is_main_process:
if args.profile and accelerator.is_main_process and forward_start is not None:
forward_start.record()
# 前向传播
@ -310,11 +397,11 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
loss = loss / args.accumulation_steps
# End forward-pass timing (main process only)
if args.profile and accelerator.is_main_process:
if args.profile and accelerator.is_main_process and forward_end is not None:
forward_end.record()
# Time the backward pass (main process only)
if args.profile and accelerator.is_main_process:
if args.profile and accelerator.is_main_process and backward_start is not None:
backward_start.record()
# Backward pass
@ -322,11 +409,11 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
accelerator.backward(loss)
# End backward-pass timing (main process only)
if args.profile and accelerator.is_main_process:
if args.profile and accelerator.is_main_process and backward_end is not None:
backward_end.record()
# Time the optimizer step (main process only)
if args.profile and accelerator.is_main_process:
if args.profile and accelerator.is_main_process and optimizer_start is not None:
optimizer_start.record()
# Optimizer step - with DeepSpeed, gradient accumulation and clipping are handled automatically
@ -339,20 +426,33 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
optimizer.zero_grad()
# End optimizer-step timing (main process only)
if args.profile and accelerator.is_main_process:
if args.profile and accelerator.is_main_process and optimizer_end is not None:
optimizer_end.record()
# Print training info (main process only)
if (step + 1) % args.log_interval == 0 and accelerator.is_main_process:
current_time = time.time()
# Log the detailed memory state at log-output time
if args.memory_monitor:
log_memory_status(step, prefetch_batches, accelerator, "at_log_interval", detailed=True)
# Force garbage collection and log the memory change
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
log_memory_status(step, prefetch_batches, accelerator, "after_gc", detailed=True)
# Compute performance metrics
if args.profile:
if args.profile and accelerator.is_main_process:
torch.cuda.synchronize()
# Compute metrics from the time since the last log, not the total time
data_time = data_start.elapsed_time(data_end)
forward_time = forward_start.elapsed_time(forward_end)
backward_time = backward_start.elapsed_time(backward_end)
optimizer_time = optimizer_start.elapsed_time(optimizer_end)
# Only compute elapsed_time once all events have been recorded
try:
data_time = data_start.elapsed_time(data_end) if data_start is not None and data_end is not None else 0
forward_time = forward_start.elapsed_time(forward_end) if forward_start is not None and forward_end is not None else 0
backward_time = backward_start.elapsed_time(backward_end) if backward_start is not None and backward_end is not None else 0
optimizer_time = optimizer_start.elapsed_time(optimizer_end) if optimizer_start is not None and optimizer_end is not None else 0
iter_time = (current_time - last_log_time) * 1000 / args.log_interval # avg ms per iteration since last log
# total_time_ms = data_time + forward_time + backward_time + optimizer_time
@ -373,6 +473,11 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
backward_end = torch.cuda.Event(enable_timing=True)
optimizer_start = torch.cuda.Event(enable_timing=True)
optimizer_end = torch.cuda.Event(enable_timing=True)
except RuntimeError as e:
if "Both events must be recorded" in str(e):
Logger(f"Warning: CUDA events not properly recorded, skipping performance analysis: {e}", accelerator)
else:
raise e
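The None-guards and the try/except exist because Event.elapsed_time raises a RuntimeError unless both events were recorded and have completed. The underlying idiom, reduced to a standalone sketch (assumes a CUDA device is available):

import torch

x = torch.randn(4096, 4096, device="cuda")
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
y = x @ x                    # the work being timed
end.record()
torch.cuda.synchronize()     # wait until both events have actually completed
print(f"matmul: {start.elapsed_time(end):.2f} ms")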
# Compute the current learning rate
@ -413,12 +518,12 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
f"Total Time Left: {format_time(total_remaining_time)}", accelerator)
if args.use_wandb and accelerator.is_main_process and wandb:
wandb.log(log_dict)
if args.use_swanlab and accelerator.is_main_process and swanlab_run:
swanlab_run.log(log_dict)
# Save the model (main process only)
loss_total = loss.item() * args.accumulation_steps
if best_loss > loss_total and accelerator.is_main_process:
if epoch > 1 and best_loss > loss_total and accelerator.is_main_process:
best_loss = loss_total
# Use the moe_path variable defined at the start of the function
ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
@ -432,20 +537,45 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
except Exception as e:
Logger(f"Error in training step: {e}", accelerator)
# Log the memory state at the time of the exception
if args.memory_monitor:
log_memory_status(step, prefetch_batches, accelerator, "at_exception", detailed=True)
import traceback
Logger(traceback.format_exc(), accelerator)
# Clear prefetch_batches to prevent a memory leak
if args.memory_monitor and accelerator.is_main_process:
Logger(f"[Memory Monitor] Clearing prefetch_batches due to exception. Current length: {len(prefetch_batches)}", accelerator)
prefetch_batches.clear()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if args.memory_monitor:
log_memory_status(step, prefetch_batches, accelerator, "after_exception_cleanup", detailed=True)
# Clear prefetch_batches at the end of the training epoch
if args.memory_monitor:
if accelerator.is_main_process:
Logger(f"[Memory Monitor] Epoch {epoch+1} finished. Clearing prefetch_batches. Final length: {len(prefetch_batches)}", accelerator)
log_memory_status(total_steps_in_epoch-1, prefetch_batches, accelerator, "before_epoch_end_cleanup", detailed=True)
prefetch_batches.clear()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if args.memory_monitor:
log_memory_status(total_steps_in_epoch-1, prefetch_batches, accelerator, "after_epoch_end_cleanup", detailed=True)
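The same release-then-flush sequence now appears at every cleanup site (exception handler, epoch end, and after each epoch in main). A small helper capturing the idiom, were one to factor it out (hypothetical, not part of this diff):

import gc
import torch

def release_memory():
    gc.collect()                      # drop unreachable Python objects first
    if torch.cuda.is_available():
        torch.cuda.empty_cache()      # then return cached CUDA blocks to the driver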
def main():
parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
parser.add_argument("--out_dir", type=str, default="out")
parser.add_argument("--epochs", type=int, default=4)
parser.add_argument("--embedding_epoch", type=int, default=2, help="embedding训练的epoch数")
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--learning_rate", type=float, default=2e-4)
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", default=True, action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--use_swanlab", default=True, action="store_true") # 替换wandb参数
parser.add_argument("--swanlab_project", type=str, default="MiniMind-Pretrain") # 替换wandb参数
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--accumulation_steps", type=int, default=32)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--warmup_iters", type=int, default=0)
@ -456,17 +586,19 @@ def main():
parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument('--disable_db', action='store_true', help="disable the database feature and use a fixed value of 1e-4 instead")
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
parser.add_argument("--data_path", type=str, default="./dataset/merged_pretrain.jsonl")
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
parser.add_argument("--knowledge_num", type=int, default=8192,help="知识库的数据数目")
parser.add_argument("--knowledge_num", type=int, default=960400,help="知识库的数据数目")
parser.add_argument("--knowledge_length", type=int, default=32,help="知识库的句子长度")
parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径")
parser.add_argument("--database_init_path", type=str, default="./dataset/combined_prepare.json", help="数据库初始化路径")
parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")
parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="聚类结果缓存文件路径")
parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件")
parser.add_argument("--memory_monitor", action="store_true", default=False, help="启用内存监控")
parser.add_argument("--memory_monitor_interval", type=int, default=10, help="内存监控间隔(步数)")
args = parser.parse_args()
#########################################################
@ -479,7 +611,7 @@ def main():
gradient_accumulation_steps=args.accumulation_steps,
gradient_clipping=args.grad_clip,
zero_stage=2, # use ZeRO-2 optimization
offload_optimizer_device="cpu", # offload optimizer state to the CPU
offload_optimizer_device="none", # keep optimizer state on the GPU (no CPU offload)
offload_param_device="none", # do not offload parameters to the CPU
)
accelerator = Accelerator(
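Pieced together, the plugin wiring amounts to the following sketch, using only the kwargs visible in this hunk (the bf16 setting is assumed from the launch scripts above):

from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

ds_plugin = DeepSpeedPlugin(
    gradient_accumulation_steps=32,
    gradient_clipping=1.0,
    zero_stage=2,                     # ZeRO-2 shards optimizer state and gradients
    offload_optimizer_device="none",  # keep optimizer state on the GPU
    offload_param_device="none",
)
accelerator = Accelerator(deepspeed_plugin=ds_plugin, mixed_precision="bf16")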
@ -523,18 +655,30 @@ def main():
#########################################################
# Configure wandb
# Configure SwanLab
#########################################################
# Set the wandb run name
args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
if args.use_wandb and accelerator.is_main_process:
import wandb
# Merge args and lm_config into a single dict
# Set the SwanLab run name
args.swanlab_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
# Merge args and lm_config into a single dict (needed for printing the config whether or not SwanLab is used)
config_dict = vars(args).copy()
config_dict.update(vars(lm_config))
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config_dict)
# Initialize the SwanLab experiment instance
swanlab_run = None
if args.use_swanlab and accelerator.is_main_process:
# Initialize SwanLab
swanlab_run = swanlab.init(
project=args.swanlab_project,
experiment_name=args.swanlab_run_name,
description="MiniMind pretraining experiment, visualized with a locally deployed SwanLab",
config=config_dict
# Set the SwanLab server address and API key
# host="http://100.123.118.114:11071",
# api_key="LesBT7HRq23HNBrOPKP8S"
)
else:
wandb = None
swanlab_run = None
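For reference, the SwanLab lifecycle used here reduces to three calls; a minimal sketch built only from the calls visible in this diff (project and metric names illustrative):

import swanlab

run = swanlab.init(project="MiniMind-Pretrain", experiment_name="demo", config={"lr": 2e-4})
for step in range(10):
    run.log({"loss": 1.0 / (step + 1)})   # same dict-style logging as train_epoch
run.finish()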
#########################################################
# Print info
@ -616,13 +760,31 @@ def main():
#########################################################
overall_start_time = time.time() # Record overall start time
for epoch in range(args.epochs):
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb) # Pass overall start time
Logger(f"开始第{epoch+1}轮训练", accelerator)
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run) # Pass overall start time
# Clean up memory after each epoch
Logger(f"Epoch {epoch+1} finished, running memory cleanup", accelerator)
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Log the memory state at the end of the epoch
if accelerator.is_main_process:
memory_info = get_memory_usage()
cuda_info = get_cuda_memory_usage()
log_msg = f"[Memory Monitor] Epoch {epoch+1} completed - "
log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB"
if cuda_info:
log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB"
Logger(log_msg, accelerator)
#########################################################
# Shut down wandb
# Shut down SwanLab
#########################################################
if args.use_wandb and accelerator.is_main_process:
wandb.finish()
if args.use_swanlab and accelerator.is_main_process and swanlab_run:
swanlab_run.finish()
if __name__ == "__main__":
main()

4812
uv.lock generated Normal file

File diff suppressed because it is too large