DynamicKV-LLM Pretrain v1.2.2:新数据集;使用uv;消除内存泄漏
This commit is contained in:
parent
770c34f0e3
commit
d6617702a5
1
.gitignore
vendored
@ -9,3 +9,4 @@ models/sentence_transformers_cache/
|
||||
qwen2-1.7B/
|
||||
images/
|
||||
cache/
|
||||
.venv/
|
112
.vscode/launch.json
vendored
@ -2,101 +2,39 @@
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Debug Train Pretrain Accelerate",
|
||||
"name": "DynamicKV-LLM Mini Minimind Debug",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/train_pretrain_accelerate.py",
|
||||
"console": "integratedTerminal",
|
||||
"python": "/opt/conda/envs/mini/bin/python",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"module": "accelerate.commands.launch",
|
||||
"args": [
|
||||
"--num_processes=1",
|
||||
"--mixed_precision=bf16",
|
||||
"--main_process_port=29500",
|
||||
"train_pretrain_accelerate.py",
|
||||
"--batch_size", "16",
|
||||
"--knowledge_num", "48020",
|
||||
"--num_workers", "1",
|
||||
"--epochs", "4",
|
||||
"--learning_rate", "2e-4",
|
||||
"--dtype", "bfloat16",
|
||||
"--accumulation_steps", "32",
|
||||
"--grad_clip", "1.0",
|
||||
"--log_interval", "50",
|
||||
"--save_interval", "10000",
|
||||
"--dim", "512",
|
||||
"--n_layers", "8",
|
||||
"--max_seq_len", "512",
|
||||
"--use_flash_attn",
|
||||
"--profile",
|
||||
"--profile_interval", "10"
|
||||
],
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}",
|
||||
"CUDA_VISIBLE_DEVICES": "0"
|
||||
},
|
||||
"justMyCode": false,
|
||||
"stopOnEntry": false,
|
||||
"redirectOutput": true
|
||||
},
|
||||
{
|
||||
"name": "Debug Train Pretrain Accelerate (Multi-GPU)",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/train_pretrain_accelerate.py",
|
||||
"console": "integratedTerminal",
|
||||
"python": "/opt/conda/envs/mini/bin/python",
|
||||
"args": [
|
||||
"--hidden_size", "512",
|
||||
"--max_seq_len", "512",
|
||||
"--n_layers", "8",
|
||||
"--batch_size", "8",
|
||||
"--epochs", "1"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}",
|
||||
"CUDA_VISIBLE_DEVICES": "0,1"
|
||||
},
|
||||
"justMyCode": false,
|
||||
"stopOnEntry": false,
|
||||
"redirectOutput": true
|
||||
},
|
||||
{
|
||||
"name": "Debug Train Pretrain Accelerate (Small Test)",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/train_pretrain_accelerate.py",
|
||||
"console": "integratedTerminal",
|
||||
"python": "/opt/conda/envs/mini/bin/python",
|
||||
"args": [
|
||||
"--hidden_size", "512",
|
||||
"--max_seq_len", "512",
|
||||
"--n_layers", "8",
|
||||
"--batch_size", "2",
|
||||
"--epochs", "1",
|
||||
"--log_interval", "10",
|
||||
"--save_interval", "100",
|
||||
"--accumulation_steps", "4"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}",
|
||||
"CUDA_VISIBLE_DEVICES": "0",
|
||||
"WANDB_MODE": "offline"
|
||||
},
|
||||
"justMyCode": false,
|
||||
"stopOnEntry": false,
|
||||
"redirectOutput": true
|
||||
},
|
||||
{
|
||||
"name": "Debug ExtractDB Comparison",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/train_pretrain_accelerate.py",
|
||||
"console": "integratedTerminal",
|
||||
"python": "/opt/conda/envs/mini/bin/python",
|
||||
"args": [
|
||||
"--hidden_size", "512",
|
||||
"--max_seq_len", "256",
|
||||
"--n_layers", "4",
|
||||
"--batch_size", "2",
|
||||
"--epochs", "1",
|
||||
"--log_interval", "10",
|
||||
"--save_interval", "200",
|
||||
"--accumulation_steps", "2",
|
||||
"--comparison_mode",
|
||||
"--knowledge_num", "256",
|
||||
"--knowledge_length", "64",
|
||||
"--comparison_mode"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}",
|
||||
"CUDA_VISIBLE_DEVICES": "0",
|
||||
"WANDB_MODE": "offline"
|
||||
},
|
||||
"justMyCode": false,
|
||||
"stopOnEntry": false,
|
||||
"redirectOutput": true
|
||||
"stopOnEntry": false
|
||||
}
|
||||
]
|
||||
}
|
199
README.md
Normal file
@ -0,0 +1,199 @@
|
||||
<div align="center">
|
||||
|
||||

|
||||
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
|
||||

|
||||
[](https://github.com/jingyaogong/minimind/stargazers)
|
||||
[](LICENSE)
|
||||
[](https://github.com/jingyaogong/minimind/commits/master)
|
||||
[](https://github.com/jingyaogong/minimind/pulls)
|
||||
[](https://huggingface.co/collections/jingyaogong/minimind-66caf8d999f5c7fa64f399e5)
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
# 📌 数据介绍
|
||||
|
||||
## Ⅰ Tokenizer
|
||||
|
||||
分词器将单词从自然语言通过“词典”映射到`0, 1, 36`这样的数字,可以理解为数字就代表了单词在“词典”中的页码。
|
||||
可以选择自己构造词表训练一个“词典”,代码可见`./scripts/train_tokenizer.py`(仅供学习参考,若非必要无需再自行训练,MiniMind已自带tokenizer)。
|
||||
或者选择比较出名的开源大模型分词器,
|
||||
正如同直接用新华/牛津词典的优点是token编码压缩率很好,缺点是页数太多,动辄数十万个词汇短语;
|
||||
自己训练的分词器,优点是词表长度和内容随意控制,缺点是压缩率很低(例如"hello"也许会被拆分为"h e l l o"
|
||||
五个独立的token),且生僻词难以覆盖。
|
||||
“词典”的选择固然很重要,LLM的输出本质上是SoftMax到词典N个词的多分类问题,然后通过“词典”解码到自然语言。
|
||||
因为MiniMind体积需要严格控制,为了避免模型头重脚轻(词嵌入embedding层参数在LLM占比太高),所以词表长度短短益善。
|
||||
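下面给出一个最小示意,演示“词典”如何把文本映射为数字(token id)再解码回自然语言。这里假设项目自带的分词器位于 `./model/minimind_tokenizer/`(与本仓库 `preprocessing/preprocess_pretrain.py` 中的默认路径一致),仅作参考,并非官方示例:

```python
# 最小示意:文本 <-> token id 的相互映射(假设本地存在 minimind_tokenizer)
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")

text = "你好,MiniMind"
ids = tokenizer.encode(text, add_special_tokens=False)  # 一串整数,相当于“词典页码”
print(ids)
print(tokenizer.decode(ids))   # 解码还原为自然语言
print(len(tokenizer))          # minimind_tokenizer 词表大小约为 6400
```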
|
||||
<details style="color:rgb(128,128,128)">
|
||||
<summary>Tokenizer介绍</summary>
|
||||
|
||||
第三方强大的开源模型例如Yi、qwen、chatglm、mistral、Llama3的tokenizer词表长度如下:
|
||||
|
||||
<table>
|
||||
<tr><th>Tokenizer模型</th><th>词表大小</th><th>来源</th></tr>
|
||||
<tr><td>yi tokenizer</td><td>64,000</td><td>01万物(中国)</td></tr>
|
||||
<tr><td>qwen2 tokenizer</td><td>151,643</td><td>阿里云(中国)</td></tr>
|
||||
<tr><td>glm tokenizer</td><td>151,329</td><td>智谱AI(中国)</td></tr>
|
||||
<tr><td>mistral tokenizer</td><td>32,000</td><td>Mistral AI(法国)</td></tr>
|
||||
<tr><td>llama3 tokenizer</td><td>128,000</td><td>Meta(美国)</td></tr>
|
||||
<tr><td>minimind tokenizer</td><td>6,400</td><td>自定义</td></tr>
|
||||
</table>
|
||||
|
||||
> 👉2024-09-17更新:为了防止过去的版本歧义&控制体积,minimind所有模型均使用minimind_tokenizer分词,废弃所有mistral_tokenizer版本。
|
||||
|
||||
```
|
||||
# 一些自言自语
|
||||
> 尽管minimind_tokenizer长度很小,编解码效率弱于qwen2、glm等中文友好型分词器。
|
||||
> 但minimind模型选择了自己训练的minimind_tokenizer作为分词器,以保持整体参数轻量,避免编码层和计算层占比失衡,头重脚轻,因为minimind的词表大小只有6400。
|
||||
> 且minimind在实际测试中没有出现过生僻词汇解码失败的情况,效果良好。
|
||||
> 由于自定义词表压缩长度到6400,使得LLM总参数量最低只有25.8M。
|
||||
> 训练数据`tokenizer_train.jsonl`均来自于`匠数大模型数据集`,这部分数据相对次要,如需训练可以自由选择。
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Ⅱ Pretrain数据
|
||||
|
||||
经历了MiniMind-V1的低质量预训练数据,导致模型胡言乱语的教训,`2025-02-05` 之后决定不再采用大规模无监督的数据集做预训练。
|
||||
进而尝试把[匠数大模型数据集](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)的中文部分提取出来,
|
||||
清洗出字符`<512`长度的大约1.6GB的语料直接拼接成预训练数据 `pretrain_hq.jsonl`,hq即为high
|
||||
quality(当然也还不算high,提升数据质量无止尽)。
|
||||
|
||||
文件`pretrain_hq.jsonl` 数据格式为
|
||||
|
||||
```bash
|
||||
{"text": "如何才能摆脱拖延症? 治愈拖延症并不容易,但以下建议可能有所帮助..."}
|
||||
```
|
||||
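该文件每行是一个独立的 JSON 对象,逐行解析即可。下面是一个读取示意(仅用于查看数据,文件按上文放在 `./dataset/` 目录下):

```python
import json

# 逐行读取 pretrain_hq.jsonl,每行形如 {"text": "..."}
with open("./dataset/pretrain_hq.jsonl", "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        sample = json.loads(line)
        print(len(sample["text"]), sample["text"][:40])
        if i >= 2:  # 只看前几条
            break
```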
|
||||
## Ⅲ SFT数据
|
||||
|
||||
[匠数大模型SFT数据集](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
|
||||
“是一个完整、格式统一、安全的大模型训练和研究资源。
|
||||
从网络上的公开数据源收集并整理了大量开源数据集,对其进行了格式统一,数据清洗,
|
||||
包含10M条数据的中文数据集和包含2M条数据的英文数据集。”
|
||||
以上是官方介绍,下载文件后的数据总量大约在4B tokens,肯定是适合作为中文大语言模型的SFT数据的。
|
||||
但是官方提供的数据格式很乱,全部用来sft代价太大。
|
||||
我对官方数据集进行了二次清洗,去除了含有符号污染和噪声的条目;另外依然只保留了总长度`<512`
|
||||
的内容,此阶段希望通过大量对话补充预训练阶段欠缺的知识。
|
||||
导出文件为`sft_512.jsonl`(~7.5GB)。
|
||||
|
||||
[Magpie-SFT数据集](https://www.modelscope.cn/organization/Magpie-Align)
|
||||
收集了~1M条来自Qwen2/2.5的高质量对话,我将这部分数据进一步清洗,把总长度`<2048`的部分导出为`sft_2048.jsonl`(~9GB)。
|
||||
长度`<1024`的部分导出为`sft_1024.jsonl`(~5.5GB),用大模型对话数据直接进行sft就属于“黑盒蒸馏”的范畴。
|
||||
|
||||
进一步清洗前两步sft的数据(只保留中文字符占比高的内容),筛选长度`<512`的对话,得到`sft_mini_512.jsonl`(~1.2GB)。
|
||||
|
||||
所有sft文件 `sft_X.jsonl` 数据格式均为
|
||||
|
||||
```text
|
||||
{
|
||||
"conversations": [
|
||||
{"role": "user", "content": "你好"},
|
||||
{"role": "assistant", "content": "你好!"},
|
||||
{"role": "user", "content": "再见"},
|
||||
{"role": "assistant", "content": "再见!"}
|
||||
]
|
||||
}
|
||||
```
|
||||
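训练前需要把 `conversations` 拼接为模型输入文本。下面是一个手工拼接 ChatML 风格标记的最小示意(仅用于说明数据结构;实际训练时请以训练脚本或 tokenizer 的对话模板为准,`build_chat_text` 为本文示例函数名):

```python
def build_chat_text(conversations):
    # 将多轮对话拼成 ChatML 风格:<|im_start|>role + 换行 + content + <|im_end|>
    parts = [f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>"
             for turn in conversations]
    return "\n".join(parts)

sample = {"conversations": [
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "你好!"},
]}
print(build_chat_text(sample["conversations"]))
```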
|
||||
## Ⅳ RLHF数据
|
||||
|
||||
来自[Magpie-DPO数据集](https://www.modelscope.cn/datasets/Magpie-Align/MagpieLM-DPO-Data-v0.1)
|
||||
大约200k条偏好数据(均是英文)生成自Llama3.1-70B/8B,可以用于训练奖励模型,优化模型回复质量,使其更加符合人类偏好。
|
||||
这里将数据总长度`<3000`的内容重组为`dpo.jsonl`(~0.9GB),包含`chosen`和`rejected`两个字段,`chosen`
|
||||
为偏好的回复,`rejected`为拒绝的回复。
|
||||
|
||||
文件 `dpo.jsonl` 数据格式为
|
||||
|
||||
```text
|
||||
{
|
||||
"chosen": [
|
||||
{"content": "Q", "role": "user"},
|
||||
{"content": "good answer", "role": "assistant"}
|
||||
],
|
||||
"rejected": [
|
||||
{"content": "Q", "role": "user"},
|
||||
{"content": "bad answer", "role": "assistant"}
|
||||
]
|
||||
}
|
||||
```
|
||||
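`chosen` 与 `rejected` 共用同一个用户提问,仅 assistant 回复不同。下面是一个读取偏好对的最小示意(只做数据检查,不涉及 DPO 训练本身):

```python
import json

with open("./dataset/dpo.jsonl", "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        pair = json.loads(line)
        prompt = pair["chosen"][0]["content"]       # 用户提问(两侧相同)
        chosen = pair["chosen"][-1]["content"]      # 偏好回复
        rejected = pair["rejected"][-1]["content"]  # 被拒绝的回复
        print(prompt[:30], "||", chosen[:30], "||", rejected[:30])
        if i >= 2:
            break
```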
|
||||
## Ⅴ Reason数据集:
|
||||
|
||||
不得不说2025年2月谁能火得过DeepSeek...
|
||||
也激发了我对RL引导的推理模型的浓厚兴趣,目前已经用Qwen2.5复现了R1-Zero。
|
||||
如果有时间+效果work(但99%基模能力不足)我会在之后更新MiniMind基于RL训练的推理模型而不是蒸馏模型。
|
||||
时间有限,最快的低成本方案依然是直接蒸馏(黑盒方式)。
|
||||
耐不住R1太火,短短几天就已经存在一些R1的蒸馏数据集[R1-Llama-70B](https://www.modelscope.cn/datasets/Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B)、[R1-Distill-SFT](https://www.modelscope.cn/datasets/AI-ModelScope/R1-Distill-SFT)、
|
||||
[Alpaca-Distill-R1](https://huggingface.co/datasets/shareAI/Alpaca-Distill-R1-ZH)、
|
||||
[deepseek_r1_zh](https://huggingface.co/datasets/jinliuxi/deepseek_r1_zh)等等,纯中文的数据可能比较少。
|
||||
最终整合它们,导出文件为`r1_mix_1024.jsonl`,数据格式和`sft_X.jsonl`一致。
|
||||
|
||||
## Ⅵ 更多数据集
|
||||
|
||||
目前已经有[HqWu-HITCS/Awesome-Chinese-LLM](https://github.com/HqWu-HITCS/Awesome-Chinese-LLM)
|
||||
在收集和梳理中文LLM相关的开源模型、应用、数据集及教程等资料,并持续更新这方面的最新进展。全面且专业,Respect!
|
||||
|
||||
---
|
||||
|
||||
## Ⅶ 数据集下载
|
||||
|
||||
> [!NOTE]
|
||||
> 2025-02-05后,开源MiniMind最终训练所用的所有数据集,因此无需再自行预处理大规模数据集,避免重复性的数据处理工作。
|
||||
|
||||
MiniMind训练数据集 ([ModelScope](https://www.modelscope.cn/datasets/gongjy/minimind-dataset/files) | [HuggingFace](https://huggingface.co/datasets/jingyaogong))
|
||||
|
||||
> 无需全部clone,可单独下载所需的文件
|
||||
|
||||
将下载的数据集文件放到`./dataset/`目录下(✨为推荐的必须项)
|
||||
|
||||
```bash
|
||||
./dataset/
|
||||
├── dpo.jsonl (909MB)
|
||||
├── lora_identity.jsonl (22.8KB)
|
||||
├── lora_medical.jsonl (34MB)
|
||||
├── pretrain_hq.jsonl (1.6GB, ✨)
|
||||
├── r1_mix_1024.jsonl (340MB)
|
||||
├── sft_1024.jsonl (5.6GB)
|
||||
├── sft_2048.jsonl (9GB)
|
||||
├── sft_512.jsonl (7.5GB)
|
||||
├── sft_mini_512.jsonl (1.2GB, ✨)
|
||||
└── tokenizer_train.jsonl (1GB)
|
||||
```
|
||||
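下载后可先用一小段脚本确认文件是否就位(文件名与上面的目录结构一致,这里只检查两个 ✨ 必须项,属于示例而非官方工具):

```python
from pathlib import Path

# 快速复现 Zero 模型所需的最小组合(✨ 项)
required = ["pretrain_hq.jsonl", "sft_mini_512.jsonl"]
dataset_dir = Path("./dataset")

for name in required:
    status = "OK" if (dataset_dir / name).exists() else "缺失"
    print(f"{name}: {status}")
```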
|
||||
<details style="color:rgb(128,128,128)">
|
||||
<summary>注:各数据集简介</summary>
|
||||
|
||||
* `dpo.jsonl` --RLHF阶段数据集
|
||||
* `lora_identity.jsonl` --自我认知数据集(例如:你是谁?我是minimind...),推荐用于lora训练(亦可用于全参SFT,勿被名字局限)
|
||||
* `lora_medical.jsonl` --医疗问答数据集,推荐用于lora训练(亦可用于全参SFT,勿被名字局限)
|
||||
* `pretrain_hq.jsonl`✨ --预训练数据集,整合自jiangshu科技
|
||||
* `r1_mix_1024.jsonl` --DeepSeek-R1-1.5B蒸馏数据,每条数据字符最大长度为1024(因此训练时设置max_seq_len=1024)
|
||||
* `sft_1024.jsonl` --整合自Qwen2.5蒸馏数据(是sft_2048的子集),每条数据字符最大长度为1024(因此训练时设置max_seq_len=1024)
|
||||
* `sft_2048.jsonl` --整合自Qwen2.5蒸馏数据,每条数据字符最大长度为2048(因此训练时设置max_seq_len=2048)
|
||||
* `sft_512.jsonl` --整合自匠数科技SFT数据,每条数据字符最大长度为512(因此训练时设置max_seq_len=512)
|
||||
* `sft_mini_512.jsonl`✨ --极简整合自匠数科技SFT数据+Qwen2.5蒸馏数据(用于快速训练Zero模型),每条数据字符最大长度为512(因此训练时设置max_seq_len=512)
|
||||
* `tokenizer_train.jsonl` --均来自于`匠数大模型数据集`,这部分数据相对次要,(不推荐自己重复训练tokenizer,理由如上)如需自己训练tokenizer可以自由选择数据集。
|
||||
|
||||
</details>
|
||||
|
||||
|
||||

|
||||
|
||||
<details style="color:rgb(128,128,128)">
|
||||
<summary>说明 & 推荐训练方案</summary>
|
||||
|
||||
* MiniMind2 Series均经过共约20GB语料训练,大约4B tokens,即对应上面的数据组合训练结果(开销:💰💰💰💰💰💰💰💰,效果:😊😊😊😊😊😊)
|
||||
|
||||
* 想要最快速度从0实现Zero模型,推荐使用`pretrain_hq.jsonl` + `sft_mini_512.jsonl` 的数据组合,具体花销和效果可查看下文表格(开销:💰,效果:😊😊)
|
||||
|
||||
* 推荐具备一定算力资源或更在意效果的朋友可以考虑前者完整复现MiniMind2;仅有单卡GPU或在乎短时间快速复现的朋友强烈推荐后者;
|
||||
|
||||
* 【折中方案】亦可选择例如`sft_mini_512.jsonl`、`sft_1024.jsonl`中等规模数据进行自由组合训练(开销:💰💰💰,效果:😊😊😊😊)。
|
||||
|
||||
</details>
|
@ -1,126 +0,0 @@
|
||||
# 使用Accelerate+DeepSpeed进行分布式训练
|
||||
|
||||
本文档介绍如何使用Accelerate和DeepSpeed进行MiniMind模型的分布式训练。
|
||||
|
||||
## 环境准备
|
||||
|
||||
首先,确保安装了必要的依赖:
|
||||
|
||||
```bash
|
||||
pip install accelerate deepspeed
|
||||
```
|
||||
|
||||
## 配置文件说明
|
||||
|
||||
### 1. DeepSpeed配置文件 (ds_config.json)
|
||||
|
||||
DeepSpeed配置文件定义了优化器、学习率调度器和ZeRO优化等参数。主要配置包括:
|
||||
|
||||
- **ZeRO优化**:使用ZeRO-2进行优化,可以减少GPU内存使用
|
||||
- **优化器设置**:使用AdamW优化器
|
||||
- **混合精度训练**:支持FP16和BF16
|
||||
- **梯度累积**:通过"auto"自动设置,与训练脚本参数保持一致
|
||||
|
||||
### 2. Accelerate配置文件 (accelerate_config.yaml)
|
||||
|
||||
Accelerate配置文件定义了分布式训练的基本设置,包括:
|
||||
|
||||
- **分布式类型**:使用DeepSpeed
|
||||
- **混合精度**:使用BF16
|
||||
- **进程数量**:设置为4(可根据GPU数量调整)
|
||||
- **DeepSpeed配置**:指向ds_config.json文件
|
||||
|
||||
## 训练脚本说明
|
||||
|
||||
新的训练脚本`train_pretrain_accelerate.py`基于原有的`train_pretrain.py`修改而来,主要变化包括:
|
||||
|
||||
1. 使用Accelerator替代了PyTorch原生的分布式功能
|
||||
2. 移除了torchrun相关的分布式初始化代码
|
||||
3. 使用Accelerator的API进行模型、优化器和数据加载器的准备
|
||||
4. 使用Accelerator的API进行反向传播和梯度裁剪
|
||||
5. 处理了位置编码和未使用参数的问题
|
||||
|
||||
## 启动训练
|
||||
|
||||
有两种方式启动训练:
|
||||
|
||||
### 方法1:使用预先配置的accelerate配置文件
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
|
||||
--epochs 3 \
|
||||
--batch_size 24 \
|
||||
--learning_rate 2e-4 \
|
||||
--dtype bfloat16 \
|
||||
--accumulation_steps 32 \
|
||||
--grad_clip 1.0 \
|
||||
--log_interval 100 \
|
||||
--save_interval 10000 \
|
||||
--dim 1024 \
|
||||
--n_layers 32 \
|
||||
--max_seq_len 1024 \
|
||||
--use_flash_attn \
|
||||
--profile \
|
||||
--profile_interval 10
|
||||
```
|
||||
|
||||
### 方法2:使用命令行参数直接配置accelerate
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||
--multi_gpu \
|
||||
--num_processes=4 \
|
||||
--mixed_precision=bf16 \
|
||||
--main_process_port=29500 \
|
||||
--deepspeed_config_file ds_config.json \
|
||||
train_pretrain_accelerate.py \
|
||||
--epochs 3 \
|
||||
--batch_size 24 \
|
||||
--learning_rate 2e-4 \
|
||||
--dtype bfloat16 \
|
||||
--accumulation_steps 32 \
|
||||
--grad_clip 1.0 \
|
||||
--log_interval 100 \
|
||||
--save_interval 10000 \
|
||||
--dim 1024 \
|
||||
--n_layers 32 \
|
||||
--max_seq_len 1024 \
|
||||
--use_flash_attn \
|
||||
--profile \
|
||||
--profile_interval 10
|
||||
```
|
||||
|
||||
也可以直接使用提供的脚本:
|
||||
|
||||
```bash
|
||||
bash run_accelerate.sh
|
||||
```
|
||||
|
||||
## Accelerate与DeepSpeed配置的关系
|
||||
|
||||
1. **Accelerate**是一个高级API,用于简化分布式训练的设置和启动,它可以与多种分布式训练后端(如DeepSpeed、FSDP等)一起使用。
|
||||
|
||||
2. **DeepSpeed**是一个优化库,专注于大规模模型训练的内存优化和性能提升,提供了ZeRO优化等功能。
|
||||
|
||||
3. **配置关系**:
|
||||
- Accelerate配置文件(YAML)定义了使用哪种分布式后端以及基本的分布式设置
|
||||
- DeepSpeed配置文件(JSON)定义了DeepSpeed特有的优化参数
|
||||
- Accelerate通过`deepspeed_config_file`参数引用DeepSpeed配置文件
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **位置编码处理**:
|
||||
- 在模型中,`pos_cis`是一个复数张量,在分布式训练中需要特别处理
|
||||
- 在新的训练脚本中,我们使用Accelerator的API来处理这个问题,不再需要`_ddp_params_and_buffers_to_ignore`
|
||||
|
||||
2. **未使用参数处理**:
|
||||
- 原代码中使用`find_unused_parameters=True`来处理未使用的参数
|
||||
- 在新的训练脚本中,我们直接使用Accelerator的API,它会自动处理这个问题
|
||||
|
||||
3. **混合精度训练**:
|
||||
- DeepSpeed配置文件中的`fp16`和`bf16`设置为`"auto"`
|
||||
- 实际使用的精度由Accelerate的`--mixed_precision`参数决定
|
||||
|
||||
4. **梯度累积**:
|
||||
- DeepSpeed配置文件中的`gradient_accumulation_steps`设置为`"auto"`
|
||||
- 实际的梯度累积步数由训练脚本的`--accumulation_steps`参数决定
|
22
ReadMe.md
@ -1,22 +0,0 @@
|
||||
## 安装环境
|
||||
1. 创建conda环境
|
||||
```bash
|
||||
conda create -n accelerate python=3.10
|
||||
conda activate accelerate
|
||||
```
|
||||
|
||||
2. 根据当前系统的cuda版本安装对应的torch、torchvision和torchaudio
|
||||
|
||||
3. 根据当前环境的torch和torchvision安装accelerate和deepspeed
|
||||
|
||||
4. 安装其他包
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## 修改模型
|
||||
1. 一般情况只修改 `model`文件夹的文件
|
||||
|
||||
## 运行
|
||||
1. 如果在4090或者4070ti上运行 `bash run_file/DynamicKV-LLM_Mini_Minimind.sh`
|
||||
2. 如果在4张A800上运行 `bash run_file/DynamicKV-LLM_Small_Minimind.sh`
|
26
experiment.yaml
Normal file
@ -0,0 +1,26 @@
|
||||
# 1. 元数据:需要修改,请为该实验配置名称和描述
|
||||
name: ycz-minimind-test
|
||||
description: 测试minimind-test
|
||||
|
||||
# 2. 运行环境:一般不修改,如有需求可以手动替换为指定镜像
|
||||
environment:
|
||||
image: determinedai/pytorch-ngc:0.38.0 # 此项无需修改
|
||||
|
||||
# 3. 指定NAS上的数据集: 需要修改,仅修改bind_mounts字段,container_path和read_only无需修改
|
||||
#将<YOUR_DATASET_FOLDER_NAME>替换为您存放在NAS上Volume1/Share/datasets/的数据集文件夹名称
|
||||
# 请再次确保您已在 NAS上的Volume1/Share/datasets/存放了<YOUR_DATASET_FOLDER_NAME>数据集
|
||||
|
||||
|
||||
# 4. 计算资源:无需修改
|
||||
resources:
|
||||
slots_per_trial: 1 # 此项无需修改
|
||||
resource_pool: rtx4090 # 此项无需修改
|
||||
|
||||
# 5. 搜索器:无需修改
|
||||
searcher:
|
||||
name: single
|
||||
metric: test_accuracy
|
||||
smaller_is_better: false
|
||||
|
||||
# 6. 启动入口:无需修改
|
||||
entrypoint: sh startup.sh
|
6
main.py
Normal file
@ -0,0 +1,6 @@
|
||||
def main():
|
||||
print("Hello from minimind!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
166
model/model.py
@ -2,7 +2,8 @@ import math
|
||||
import struct
|
||||
import inspect
|
||||
import time
|
||||
|
||||
import gc
|
||||
#子空间二维分解+梯度更新
|
||||
from .LMConfig import LMConfig
|
||||
from typing import Any, Optional, Tuple, List, Union
|
||||
import numpy as np
|
||||
@ -67,23 +68,21 @@ class KnowledgeDataset(nn.Module):
|
||||
## 数据库参数
|
||||
self.knowledge_num = params.knowledge_num
|
||||
self.knowledge_length = params.knowledge_length
|
||||
self.keys = nn.Parameter(torch.randn(self.knowledge_num, self.knowledge_dim) * 0.02, requires_grad=True)
|
||||
self.product_key_topk = min(16, self.knowledge_num)
|
||||
|
||||
# 使用频率统计 - 使用register_buffer以便在GPU/CPU间正确移动
|
||||
self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))
|
||||
# 修改键存储为二维分解空间,设置为可训练参数
|
||||
self.num_keys = int(math.sqrt(self.knowledge_num))
|
||||
# 确保keys是可训练参数
|
||||
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
|
||||
self.product_key_topk = min(16, self.num_keys)
|
||||
|
||||
# 知识库存储 - 使用register_buffer因为这是整数索引,不需要梯度
|
||||
self.register_buffer('knowledge_dataset',
|
||||
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
|
||||
)
|
||||
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
|
||||
|
||||
# 计算step数目,用于动态调整权重
|
||||
self.step_counter = 0
|
||||
|
||||
self.freeze_embedding = False
|
||||
|
||||
|
||||
# 移除批次计数器和更新频率相关代码
|
||||
|
||||
def intelligent_selection(self, query, all_scores, all_indices):
|
||||
"""智能分层选择策略"""
|
||||
@ -94,6 +93,15 @@ class KnowledgeDataset(nn.Module):
|
||||
device = all_scores.device
|
||||
dtype = all_scores.dtype
|
||||
|
||||
# 记录进入智能选择前的内存状态
|
||||
if hasattr(self, 'step_counter'):
|
||||
self.step_counter += 1
|
||||
# 禁用GPU内存监控记录以提高性能
|
||||
# if self.step_counter % 50 == 0: # 每50次调用记录一次
|
||||
# if torch.cuda.is_available():
|
||||
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
|
||||
# print(f"[INTEL_SELECT_ENTER] Step {self.step_counter}: GPU Memory: {allocated_before:.2f}GB")
|
||||
|
||||
# 对每个batch进行分层选择
|
||||
enhanced_scores = all_scores.clone()
|
||||
query_features = query.mean(dim=1) # [batch_size, dim]
|
||||
@ -106,7 +114,8 @@ class KnowledgeDataset(nn.Module):
|
||||
candidate_tokens = self.knowledge_dataset[unique_indices]
|
||||
flat_tokens = candidate_tokens.view(-1)
|
||||
flat_embeddings = self.tok_embeddings(flat_tokens)
|
||||
#获取flat_tokens对应的index
|
||||
|
||||
# 获取flat_tokens对应的index(保留这些变量以便其他地方使用)
|
||||
pre_update_indices = unique_indices.view(-1)
|
||||
pre_update_embeddings = flat_embeddings.view(
|
||||
len(unique_indices), self.knowledge_length, -1
|
||||
@ -158,84 +167,63 @@ class KnowledgeDataset(nn.Module):
|
||||
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
|
||||
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
|
||||
|
||||
# 获取
|
||||
# 清理中间张量以防止内存泄漏
|
||||
del all_candidate_indices, unique_indices, inverse_indices
|
||||
del unique_candidate_features, normalized_candidates, normalized_queries
|
||||
del batch_best_tokens, batch_best_tokens_embeddings
|
||||
del flat_tokens, flat_embeddings, pre_update_embeddings
|
||||
|
||||
# 使用重新计算的embeddings更新self.keys
|
||||
if self.is_train:
|
||||
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
|
||||
# 记录退出智能选择后的内存状态(已禁用以提高性能)
|
||||
# if hasattr(self, 'step_counter') and self.step_counter % 50 == 0:
|
||||
# if torch.cuda.is_available():
|
||||
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
|
||||
# print(f"[INTEL_SELECT_EXIT] Step {self.step_counter}: GPU Memory: {allocated_after:.2f}GB")
|
||||
|
||||
# 更新被修改过的key
|
||||
with torch.no_grad():
|
||||
self.has_update_keys[pre_update_indices] = 1
|
||||
# 强制垃圾回收(仅在监控步骤)
|
||||
if hasattr(self, 'step_counter') and self.step_counter % 100 == 0:
|
||||
gc.collect()
|
||||
# if torch.cuda.is_available():
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
return all_best_tokens, all_best_tokens_embeddings
|
||||
|
||||
def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
|
||||
if self.freeze_embedding:
|
||||
return
|
||||
# 使用pre_update_embeddings更新self.keys
|
||||
with torch.no_grad():
|
||||
pre_update_embeddings = pre_update_embeddings.mean(dim=1) # [num_selected, dim],对knowledge_length维取平均
|
||||
pre_update_embeddings = self.to_queries(pre_update_embeddings)
|
||||
self.keys[pre_update_indices] = pre_update_embeddings
|
||||
|
||||
def search_index(self,x):
|
||||
|
||||
def search_index(self, x):
|
||||
batch_size, seq_len, dim = x.shape
|
||||
|
||||
# collapse sequence dimension by averaging
|
||||
# 1. 序列维度平均
|
||||
x_flat = x.mean(dim=1) # [batch_size, dim]
|
||||
|
||||
queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
|
||||
# queries = queries.reshape(batch_size, 2, self.key_dim)
|
||||
# queries = queries.permute(1, 0, 2)
|
||||
# 2. 生成查询向量并重塑为两个子查询
|
||||
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
|
||||
queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim]
|
||||
# 调整维度顺序,使子空间维度位于首位
|
||||
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
|
||||
|
||||
# 2. 计算queries与keys的相似度
|
||||
sim = torch.einsum('b d, k d -> b k', queries, self.keys)
|
||||
# 3. 计算每个子空间的相似度
|
||||
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
|
||||
|
||||
# 3. 在两个子空间分别做top-k
|
||||
scores_and_indices = sim.topk(self.product_key_topk, dim=-1)
|
||||
scores, indices = scores_and_indices[0], scores_and_indices[1]
|
||||
# 4. 在两个子空间分别做top-k
|
||||
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
|
||||
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
|
||||
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
|
||||
|
||||
# 5. 应用智能分层选择策略
|
||||
# 5. 组合两个子空间的结果
|
||||
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
|
||||
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
|
||||
|
||||
# 6. 将结果重塑为二维
|
||||
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
|
||||
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
|
||||
|
||||
# 7. 选择最终的top-k结果
|
||||
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
|
||||
indices = torch.gather(all_indices, 1, indices_of_indices)
|
||||
|
||||
# 8. 应用智能分层选择策略
|
||||
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
|
||||
|
||||
# 6. 更新1%的keys
|
||||
if self.is_train:
|
||||
# 获取未更新过的keys的索引
|
||||
not_updated_indices = torch.where(self.has_update_keys == 0)[0]
|
||||
|
||||
# 如果有未更新的keys,随机选择num_update_keys个进行更新
|
||||
if len(not_updated_indices) > 0:
|
||||
num_update_keys = int(self.knowledge_num * 0.01)
|
||||
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
|
||||
perm_num = perm.shape[0]
|
||||
pre_update_indices = not_updated_indices[perm]
|
||||
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
|
||||
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
|
||||
pre_update_embeddings = pre_update_embeddings.view(perm_num, self.knowledge_length, -1)
|
||||
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
|
||||
# 更新被修改过的key
|
||||
with torch.no_grad():
|
||||
self.has_update_keys[pre_update_indices] = 1
|
||||
else:
|
||||
print("all keys are updated")
|
||||
# 重置所有keys的更新状态
|
||||
self.has_update_keys.zero_()
|
||||
# 重新获取所有可更新的索引
|
||||
not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)
|
||||
num_update_keys = int(self.knowledge_num * 0.01)
|
||||
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
|
||||
pre_update_indices = not_updated_indices[perm]
|
||||
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
|
||||
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
|
||||
pre_update_embeddings = pre_update_embeddings.view(num_update_keys, self.knowledge_length, -1)
|
||||
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
|
||||
# 更新被修改过的key
|
||||
with torch.no_grad():
|
||||
self.has_update_keys[pre_update_indices] = 1
|
||||
|
||||
|
||||
|
||||
|
||||
return best_tokens, best_tokens_embeddings
|
||||
|
||||
@ -257,6 +245,16 @@ class CrossAttention(nn.Module):
|
||||
def forward(self, x, db, context_mask=None, pos_emb=None):
|
||||
batch_size = x.size(0)
|
||||
|
||||
# 监控交叉注意力开始时的内存(已禁用以提高性能)
|
||||
if not hasattr(self, 'call_counter'):
|
||||
self.call_counter = 0
|
||||
self.call_counter += 1
|
||||
|
||||
# 禁用GPU内存监控记录以提高性能
|
||||
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
|
||||
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
|
||||
# print(f"[CROSS_ATTN_ENTER] Call {self.call_counter}: GPU Memory: {allocated_before:.2f}GB")
|
||||
|
||||
# 分离多头
|
||||
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
@ -282,6 +280,14 @@ class CrossAttention(nn.Module):
|
||||
|
||||
context = self.to_out(context)
|
||||
|
||||
# 清理中间张量
|
||||
del q, k, v, attn_scores, attn_weights
|
||||
|
||||
# 监控交叉注意力结束时的内存(已禁用以提高性能)
|
||||
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
|
||||
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
|
||||
# print(f"[CROSS_ATTN_EXIT] Call {self.call_counter}: GPU Memory: {allocated_after:.2f}GB")
|
||||
|
||||
return context
|
||||
|
||||
class Attention(nn.Module):
|
||||
@ -520,12 +526,11 @@ class MiniMindLM(PreTrainedModel):
|
||||
step: int = 0,
|
||||
**args):
|
||||
start_pos = args.get('start_pos', 0)
|
||||
if self.freeze_embedding and step == 0:
|
||||
self.tok_embeddings.weight.requires_grad = False
|
||||
# 同时冻结KnowledgeDataset的嵌入更新
|
||||
self.knowledge_dataset.freeze_embedding = True
|
||||
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
|
||||
print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding)
|
||||
# if self.freeze_embedding and step == 0:
|
||||
# self.tok_embeddings.weight.requires_grad = False
|
||||
# # 移除对knowledge_dataset.freeze_embedding的设置,让键更新由batch_counter控制
|
||||
# # self.knowledge_dataset.freeze_embedding = True
|
||||
# print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
|
||||
h = self.dropout(self.tok_embeddings(input_ids))
|
||||
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
|
||||
for l, layer in enumerate(self.layers):
|
||||
@ -601,3 +606,4 @@ class MiniMindLM(PreTrainedModel):
|
||||
yield input_ids[:, start:]
|
||||
if input_ids_next.item() == eos_token_id:
|
||||
break
|
||||
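上面 `search_index` 的核心是 Product Key 式的二维分解检索:查询被拆成两个子查询,先在两个子空间各取 top-k,再把得分相加、索引按 `x * num_keys + y` 组合,最后在 topk×topk 个候选中取最终 top-k。下面是一个独立的最小示意(随机张量、变量名均为示例,不是 `model.py` 的原始实现):

```python
import torch

batch_size, key_dim, topk = 4, 64, 16
num_keys = 219  # int(sqrt(knowledge_num=48020)) ≈ 219

keys = torch.randn(num_keys, 2, key_dim)                 # 两个子空间各自的键
queries = torch.randn(batch_size, 2 * key_dim)           # 相当于 to_queries 的输出
queries = queries.reshape(batch_size, 2, key_dim).permute(1, 0, 2)  # [2, B, key_dim]

sim = torch.einsum('p b d, k p d -> p b k', queries, keys)          # [2, B, num_keys]
(sx, ix), (sy, iy) = sim[0].topk(topk, dim=-1), sim[1].topk(topk, dim=-1)

all_scores = sx.unsqueeze(-1) + sy.unsqueeze(-2)                    # [B, topk, topk]
all_indices = ix.unsqueeze(-1) * num_keys + iy.unsqueeze(-2)        # 组合为全局索引
all_scores = all_scores.reshape(batch_size, -1)
all_indices = all_indices.reshape(batch_size, -1)

scores, pos = all_scores.topk(topk, dim=-1)
indices = torch.gather(all_indices, 1, pos)   # 最终选中的知识条目索引
print(indices.shape)                          # torch.Size([4, 16])
```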
|
||||
|
@ -1,154 +0,0 @@
|
||||
# TREx 数据集处理工具使用说明
|
||||
|
||||
这个工具支持两步骤处理 TREx 数据集:
|
||||
1. **句子提取**:从 TREx 数据集提取三元组并转换为自然语言句子
|
||||
2. **LLM 处理**:使用 ollama qwen3:4b 模型进行句子修正和重要性评分
|
||||
|
||||
## 🆕 防卡死机制
|
||||
|
||||
为了解决LLM处理时可能出现的卡死问题,新增了以下功能:
|
||||
|
||||
### 超时和重试机制
|
||||
- **超时时间**:每个LLM请求60秒超时
|
||||
- **重试机制**:失败后最多重试2次,采用指数退避策略
|
||||
- **并发控制**:降低并发数至4个,减少服务器压力
|
||||
|
||||
### 心跳监控系统
|
||||
- **实时监控**:每30秒检查一次LLM响应状态
|
||||
- **异常警告**:超过30秒无成功响应时发出警告
|
||||
- **服务检测**:自动检查ollama服务状态
|
||||
- **详细统计**:实时显示成功率、超时率等统计信息
|
||||
|
||||
### 日志系统
|
||||
- **详细日志**:所有操作都记录在 `logs/` 目录下
|
||||
- **双重输出**:同时输出到日志文件和控制台
|
||||
- **时间戳标记**:日志文件包含启动时间戳
|
||||
|
||||
### 改进的错误处理
|
||||
- **异常恢复**:LLM处理失败时使用原句子和默认评分
|
||||
- **状态监控**:处理前检查ollama服务状态
|
||||
- **批次间休息**:批次之间休息5秒,避免过度压力
|
||||
|
||||
## 安装依赖
|
||||
|
||||
```bash
|
||||
pip install agno asyncio pydantic requests
|
||||
```
|
||||
|
||||
确保已安装并启动 ollama,并下载 qwen3:4b 模型:
|
||||
```bash
|
||||
ollama pull qwen3:4b
|
||||
```
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 1. 完整流程(两步骤连续执行)
|
||||
|
||||
```bash
|
||||
python trex_to_sentences_simple.py --step all --input_dir dataset/TREx --max_files 2
|
||||
```
|
||||
|
||||
### 2. 分步骤执行
|
||||
|
||||
#### 步骤1:仅提取句子
|
||||
```bash
|
||||
python trex_to_sentences_simple.py --step extract --input_dir dataset/TREx --sentences_json my_sentences.json --max_files 2
|
||||
```
|
||||
|
||||
#### 步骤2:仅LLM处理
|
||||
```bash
|
||||
python trex_to_sentences_simple.py --step llm --sentences_json my_sentences.json --output_file final_output.txt
|
||||
```
|
||||
|
||||
## 主要参数说明
|
||||
|
||||
- `--step`: 运行步骤
|
||||
- `extract`: 仅提取句子
|
||||
- `llm`: 仅LLM处理
|
||||
- `all`: 完整流程(默认)
|
||||
|
||||
- `--input_dir`: TREx数据集目录(默认:`dataset/TREx`)
|
||||
- `--sentences_json`: 提取的句子JSON文件(默认:`extracted_sentences.json`)
|
||||
- `--output_file`: 最终输出文件(默认:`trex_sentences_enhanced.txt`)
|
||||
- `--max_files`: 最大处理文件数(用于测试)
|
||||
- `--no_llm`: 禁用LLM处理
|
||||
|
||||
## 输出文件
|
||||
|
||||
**注意:所有输出文件都会自动保存在相应目录中**
|
||||
|
||||
### 句子提取输出
|
||||
- `output/extracted_sentences.json`: 提取的原始句子,包含元数据
|
||||
|
||||
### LLM处理输出
|
||||
- `output/{output_file}.txt`: 修正后的句子文本文件
|
||||
- `output/{output_file}.json`: 完整的处理结果(包含原句、修正句、评分)
|
||||
- `output/{output_file}_sorted_by_importance.txt`: 按重要性评分排序的句子
|
||||
|
||||
### 检查点文件
|
||||
- `output/{output_file}_checkpoint_{数量}.json`: 每1000条句子自动保存的检查点
|
||||
|
||||
### 日志文件
|
||||
- `logs/trex_processor_{时间戳}.log`: 详细的处理日志
|
||||
|
||||
## 🆕 故障诊断
|
||||
|
||||
### 如果遇到卡死问题:
|
||||
|
||||
1. **检查日志文件**:查看 `logs/` 目录下的最新日志
|
||||
2. **观察心跳监控**:注意控制台的心跳警告信息
|
||||
3. **检查ollama服务**:
|
||||
```bash
|
||||
ps aux | grep ollama
|
||||
curl http://localhost:11434/api/tags
|
||||
```
|
||||
4. **重启ollama服务**(如果需要):
|
||||
```bash
|
||||
pkill ollama
|
||||
ollama serve &
|
||||
```
|
||||
|
||||
### 常见警告信息:
|
||||
|
||||
- `⚠️ 心跳检测`: 30秒无成功响应(正常情况下会自动恢复)
|
||||
- `❌ 严重警告`: 90秒无成功响应(可能需要检查服务)
|
||||
- `💀 Ollama服务异常`: ollama服务可能已停止
|
||||
- `💀 致命错误`: 连续多次警告(建议重启程序)
|
||||
|
||||
## 检查点恢复机制
|
||||
|
||||
- 步骤2会自动检测已有的检查点文件(在 `output/` 目录中)
|
||||
- 只处理尚未处理的句子,避免重复工作
|
||||
- 如果所有句子都已处理,会直接生成最终输出文件
|
||||
- 中断后重新运行会自动从最新检查点继续
|
||||
|
||||
## 示例工作流
|
||||
|
||||
```bash
|
||||
# 1. 先提取句子(可以快速完成)
|
||||
python trex_to_sentences_simple.py --step extract --max_files 5
|
||||
|
||||
# 2. 后续进行LLM处理(耗时较长,支持断点续传)
|
||||
python trex_to_sentences_simple.py --step llm
|
||||
|
||||
# 如果中途中断,再次运行步骤2会自动从检查点恢复
|
||||
python trex_to_sentences_simple.py --step llm
|
||||
```
|
||||
|
||||
## 性能特点
|
||||
|
||||
- **保守的并发**: 最大4个并发LLM请求(降低卡死风险)
|
||||
- **检查点保存**: 每1000条句子自动保存,支持断点续传
|
||||
- **智能监控**: 详细的处理进度和时间预估
|
||||
- **健壮的错误处理**: LLM请求失败时使用原句子和默认评分
|
||||
- **服务监控**: 自动检测ollama服务状态
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. 首次运行步骤2前,必须先完成步骤1
|
||||
2. 检查点文件会占用额外磁盘空间(每个都包含所有已处理数据)
|
||||
3. LLM处理速度取决于模型性能和网络状况
|
||||
4. 建议先用`--max_files`参数测试小批量数据
|
||||
5. **新增**:如果遇到卡死,查看日志文件和心跳监控信息
|
||||
6. **新增**:程序会自动检测并报告ollama服务状态
|
||||
7. **新增**:所有处理过程都有详细日志记录,便于问题诊断
|
133
preprocessing/preprocess_combined_json.py
Normal file
@ -0,0 +1,133 @@
|
||||
import json
|
||||
import os
|
||||
import datetime
|
||||
from typing import List, Dict, Any
|
||||
|
||||
# 配置参数
|
||||
json_path = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/dataset/combined.json"
|
||||
prepare_num = 1048576 # database_init.json的数据条数,可以根据需要修改
|
||||
output_dir = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/dataset/"
|
||||
|
||||
def convert_to_database_init_format(sentences: List[str], importance_score: float = 10.0) -> Dict[str, Any]:
|
||||
"""
|
||||
将句子列表转换为 database_init.json 格式
|
||||
|
||||
Args:
|
||||
sentences: 句子列表
|
||||
importance_score: 重要性评分,默认为10.0
|
||||
|
||||
Returns:
|
||||
转换后的字典格式数据
|
||||
"""
|
||||
# 构建句子数据
|
||||
sentence_data = []
|
||||
for sentence in sentences:
|
||||
sentence_item = {
|
||||
"original_sentence": sentence,
|
||||
"corrected_sentence": sentence, # 与original_sentence相同
|
||||
"importance_score": importance_score
|
||||
}
|
||||
sentence_data.append(sentence_item)
|
||||
|
||||
# 构建完整的数据结构
|
||||
result = {
|
||||
"metadata": {
|
||||
"batch_number": 1,
|
||||
"batch_size": len(sentences),
|
||||
"total_processed_count": len(sentences),
|
||||
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"total_sentences": len(sentences),
|
||||
"duplicates_removed": 0 # 在此函数中不涉及去重,所以设为0
|
||||
},
|
||||
"sentences": sentence_data
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def preprocess_combined_json():
|
||||
# 读取原始数据
|
||||
print("正在读取combined.json...")
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
total_count = len(data)
|
||||
print(f"总共有 {total_count} 条数据")
|
||||
|
||||
# 处理所有数据:将subject、predicate、object拼接成句子,同时记录原始数据
|
||||
print("正在处理数据并拼接句子...")
|
||||
sentence_to_original = {} # 记录句子到原始数据的映射
|
||||
all_sentences = []
|
||||
|
||||
for i, item in enumerate(data):
|
||||
# 拼接subject、predicate、object为一句话
|
||||
sentence = f"{item['subject']} {item['predicate']} {item['object']}"
|
||||
all_sentences.append(sentence)
|
||||
|
||||
# 记录句子到原始数据的映射(如果句子重复,保留第一次出现的原始数据)
|
||||
if sentence not in sentence_to_original:
|
||||
sentence_to_original[sentence] = item
|
||||
|
||||
if (i + 1) % 100000 == 0:
|
||||
print(f"已处理 {i + 1}/{total_count} 条数据")
|
||||
|
||||
print(f"完成句子拼接,共 {len(all_sentences)} 条句子")
|
||||
|
||||
# 去重处理
|
||||
print("正在进行去重处理...")
|
||||
unique_sentences = list(set(all_sentences))
|
||||
duplicates_removed = len(all_sentences) - len(unique_sentences)
|
||||
print(f"去重完成,去重前: {len(all_sentences)} 条,去重后: {len(unique_sentences)} 条,移除重复: {duplicates_removed} 条")
|
||||
|
||||
# 检查是否有足够的去重数据
|
||||
if len(unique_sentences) < prepare_num:
|
||||
print(f"警告: 去重后的数据量 ({len(unique_sentences)}) 少于所需数量 ({prepare_num})")
|
||||
print(f"将使用全部 {len(unique_sentences)} 条去重数据")
|
||||
selected_sentences = unique_sentences
|
||||
else:
|
||||
print(f"选择前 {prepare_num} 条去重数据")
|
||||
selected_sentences = unique_sentences[:prepare_num]
|
||||
|
||||
# 转换为database_init.json格式
|
||||
print("正在转换为database_init.json格式...")
|
||||
database_init_data = convert_to_database_init_format(selected_sentences, importance_score=10.0)
|
||||
|
||||
# 更新metadata中的duplicates_removed信息
|
||||
database_init_data["metadata"]["duplicates_removed"] = duplicates_removed
|
||||
|
||||
# 保存database_init.json
|
||||
database_output_path = os.path.join(output_dir, "database_init_from_combined.json")
|
||||
print(f"正在保存 {database_output_path}...")
|
||||
with open(database_output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(database_init_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"database_init_from_combined.json 保存完成,包含 {len(selected_sentences)} 条数据")
|
||||
|
||||
# 保存剩余数据作为训练集(保持原格式)
|
||||
remaining_sentences = unique_sentences[prepare_num:] if len(unique_sentences) > prepare_num else []
|
||||
if remaining_sentences:
|
||||
# 将剩余的句子转换回原始格式
|
||||
print(f"正在转换剩余 {len(remaining_sentences)} 条数据为原始格式...")
|
||||
remaining_original_data = []
|
||||
for sentence in remaining_sentences:
|
||||
if sentence in sentence_to_original:
|
||||
remaining_original_data.append(sentence_to_original[sentence])
|
||||
|
||||
print(f"保存剩余 {len(remaining_original_data)} 条数据作为训练集...")
|
||||
train_output_path = os.path.join(output_dir, "combined_train.json")
|
||||
with open(train_output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(remaining_original_data, f, ensure_ascii=False, indent=2)
|
||||
print(f"combined_train.json 保存完成")
|
||||
else:
|
||||
print("没有剩余数据用于训练集")
|
||||
remaining_original_data = []
|
||||
|
||||
print("\n数据处理完成!")
|
||||
print(f"原始数据: {total_count} 条")
|
||||
print(f"拼接后: {len(all_sentences)} 条句子")
|
||||
print(f"去重后: {len(unique_sentences)} 条句子")
|
||||
print(f"用于database_init: {len(selected_sentences)} 条")
|
||||
print(f"剩余训练数据: {len(remaining_original_data) if remaining_sentences else 0} 条")
|
||||
|
||||
if __name__ == "__main__":
|
||||
preprocess_combined_json()
|
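该文件中的 `convert_to_database_init_format` 也可以单独复用。下面是一个最小调用示意(句子为虚构示例,假设从项目根目录运行且 `preprocessing` 可作为包导入):

```python
from preprocessing.preprocess_combined_json import convert_to_database_init_format

result = convert_to_database_init_format(
    ["Beijing is the capital of China.", "Water boils at 100 degrees Celsius."],
    importance_score=10.0,
)
print(result["metadata"]["total_sentences"])          # 2
print(result["sentences"][0]["corrected_sentence"])   # 与 original_sentence 相同
```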
741
preprocessing/preprocess_pretrain.py
Normal file
@ -0,0 +1,741 @@
|
||||
import json
|
||||
import os
|
||||
import pandas as pd
|
||||
import tarfile
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import re
|
||||
import langdetect
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
import random
|
||||
import hashlib
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 配置参数
|
||||
BASE_DIR = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind"
|
||||
OUTPUT_FILE = os.path.join(BASE_DIR, "dataset", "merged_pretrain.jsonl")
|
||||
OUTPUT_FILE_EXTRA = os.path.join(BASE_DIR, "dataset", "merged_pretrain_extra.jsonl")
|
||||
|
||||
# 数据源路径
|
||||
PRETRAIN_HQ_PATH = os.path.join(BASE_DIR, "dataset", "pretrain_hq.jsonl")
|
||||
WIKIPEDIA_PATH = "/home/pci/nas/share/datasets/wikipedia/data/20220301.en"
|
||||
GUTENBERG_PATH = "/home/pci/nas/share/datasets/gutenberg/data"
|
||||
OPENWEBTEXT_PATH = "/home/pci/nas/share/datasets/openwebtext/subsets"
|
||||
|
||||
# Token长度限制
|
||||
MIN_TOKENS = 410
|
||||
MAX_TOKENS = 490
|
||||
|
||||
# 数据集质量和采样比例配置 - 主文件
|
||||
DATASET_CONFIG = {
|
||||
"pretrain_hq": {"quality": "high", "sample_ratio": 1.0, "max_samples": None}, # 高质量,全部使用
|
||||
"wikipedia": {"quality": "high", "sample_ratio": 1.0, "max_samples": 5000000}, # 高质量,使用全部,最多500万条
|
||||
"gutenberg": {"quality": "medium", "sample_ratio": 0.8, "max_samples": 3000000}, # 中质量,使用80%,最多300万条
|
||||
"openwebtext": {"quality": "low", "sample_ratio": 0.2, "max_samples": 2000000} # 低质量,使用20%,最多200万条
|
||||
}
|
||||
|
||||
# 额外文件的配置 - 剩余数据
|
||||
DATASET_CONFIG_EXTRA = {
|
||||
"wikipedia": {"quality": "high", "sample_ratio": 1.0, "max_samples": None}, # 剩余的全部
|
||||
"gutenberg": {"quality": "medium", "sample_ratio": 0.8, "max_samples": 5000000}, # 剩余的80%,最多500万条
|
||||
"openwebtext": {"quality": "low", "sample_ratio": 0.6, "max_samples": 4000000} # 剩余的60%,最多400万条
|
||||
}
|
||||
|
||||
# 全局变量:记录已选择的数据
|
||||
selected_data_hashes = {
|
||||
"wikipedia": set(),
|
||||
"gutenberg": set(),
|
||||
"openwebtext": set()
|
||||
}
|
||||
|
||||
# 初始化tokenizer
|
||||
tokenizer = None
|
||||
|
||||
def init_tokenizer():
|
||||
"""初始化tokenizer"""
|
||||
global tokenizer
|
||||
try:
|
||||
# 首先尝试使用本地的minimind tokenizer
|
||||
local_tokenizer_path = os.path.join(BASE_DIR, "model", "minimind_tokenizer")
|
||||
if os.path.exists(local_tokenizer_path):
|
||||
tokenizer = AutoTokenizer.from_pretrained(local_tokenizer_path)
|
||||
logger.info("Local MiniMind tokenizer initialized successfully")
|
||||
else:
|
||||
# 如果本地tokenizer不存在,使用GPT-2(但设置离线模式)
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2", local_files_only=True)
|
||||
logger.info("GPT-2 tokenizer initialized successfully (offline)")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing tokenizer: {e}")
|
||||
logger.info("Trying to use a simple fallback tokenizer...")
|
||||
# 使用简单的分词方法作为备选
|
||||
tokenizer = None
|
||||
logger.warning("Using simple word-based tokenization as fallback")
|
||||
|
||||
def count_tokens(text):
|
||||
"""计算文本的token数量"""
|
||||
if tokenizer is None:
|
||||
init_tokenizer()
|
||||
|
||||
if tokenizer is not None:
|
||||
try:
|
||||
tokens = tokenizer.encode(text, add_special_tokens=False)
|
||||
return len(tokens)
|
||||
except:
|
||||
pass
|
||||
|
||||
# 如果tokenization失败或tokenizer为None,使用简单估算
|
||||
return int(len(text.split()) * 1.3) # 大概估算,确保返回整数
|
||||
|
||||
def is_english_text(text, threshold=0.8):
|
||||
"""检测文本是否为英文"""
|
||||
try:
|
||||
if len(text) < 50: # 太短的文本跳过检测
|
||||
return True
|
||||
detected_lang = langdetect.detect(text)
|
||||
return detected_lang == 'en'
|
||||
except:
|
||||
# 如果检测失败,使用简单的英文字符比例判断
|
||||
english_chars = sum(1 for char in text if char.isascii() and (char.isalpha() or char in ' .,!?-'))
|
||||
total_chars = len(text.replace('\n', ' ').replace('\t', ' '))
|
||||
return (english_chars / max(total_chars, 1)) > threshold
|
||||
|
||||
def truncate_to_token_limit(text, target_tokens=MAX_TOKENS):
|
||||
"""将文本截断到目标token数量"""
|
||||
if tokenizer is None:
|
||||
init_tokenizer()
|
||||
|
||||
if tokenizer is not None:
|
||||
try:
|
||||
tokens = tokenizer.encode(text, add_special_tokens=False)
|
||||
if len(tokens) <= target_tokens:
|
||||
return text
|
||||
|
||||
# 截断到目标长度
|
||||
truncated_tokens = tokens[:target_tokens]
|
||||
truncated_text = tokenizer.decode(truncated_tokens, skip_special_tokens=True)
|
||||
|
||||
# 尝试在句号处截断以保持完整性
|
||||
sentences = truncated_text.split('.')
|
||||
if len(sentences) > 1:
|
||||
# 保留除最后一个不完整句子外的所有句子
|
||||
truncated_text = '.'.join(sentences[:-1]) + '.'
|
||||
|
||||
return truncated_text
|
||||
except:
|
||||
pass
|
||||
|
||||
# 如果处理失败或tokenizer为None,使用字符数估算
|
||||
estimated_chars = int(target_tokens / 1.3 * 4) # 大概估算
|
||||
text = text[:estimated_chars]
|
||||
|
||||
# 尝试在句号处截断以保持完整性
|
||||
sentences = text.split('.')
|
||||
if len(sentences) > 1:
|
||||
text = '.'.join(sentences[:-1]) + '.'
|
||||
|
||||
return text
|
||||
|
||||
def split_text_into_chunks(text, target_chunk_size=1500):
|
||||
"""将长文本分割成多个中等长度的段落块"""
|
||||
# 清理文本
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
|
||||
# 移除过多的换行符和空格
|
||||
text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
|
||||
text = re.sub(r' +', ' ', text)
|
||||
|
||||
chunks = []
|
||||
|
||||
# 按段落分割
|
||||
paragraphs = text.split('\n\n')
|
||||
current_chunk = ""
|
||||
|
||||
for paragraph in paragraphs:
|
||||
paragraph = paragraph.strip()
|
||||
if not paragraph:
|
||||
continue
|
||||
|
||||
# 如果当前块加上新段落长度适中,就添加
|
||||
if len(current_chunk) + len(paragraph) < target_chunk_size:
|
||||
if current_chunk:
|
||||
current_chunk += "\n\n" + paragraph
|
||||
else:
|
||||
current_chunk = paragraph
|
||||
else:
|
||||
# 如果当前块不为空,保存它
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
|
||||
# 如果段落本身就很长,需要进一步分割
|
||||
if len(paragraph) > target_chunk_size * 2:
|
||||
# 按句子分割长段落
|
||||
sentences = re.split(r'(?<=[.!?])\s+', paragraph)
|
||||
temp_chunk = ""
|
||||
|
||||
for sentence in sentences:
|
||||
if len(temp_chunk) + len(sentence) < target_chunk_size:
|
||||
if temp_chunk:
|
||||
temp_chunk += " " + sentence
|
||||
else:
|
||||
temp_chunk = sentence
|
||||
else:
|
||||
if temp_chunk:
|
||||
chunks.append(temp_chunk)
|
||||
temp_chunk = sentence
|
||||
|
||||
if temp_chunk:
|
||||
current_chunk = temp_chunk
|
||||
else:
|
||||
current_chunk = ""
|
||||
else:
|
||||
current_chunk = paragraph
|
||||
|
||||
# 添加最后一个块
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
|
||||
return chunks
|
||||
|
||||
def format_text_for_pretrain(text):
|
||||
"""将文本格式化为预训练格式并检查token长度"""
|
||||
# 清理文本
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return None
|
||||
|
||||
# 移除过多的换行符和空格
|
||||
text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
|
||||
text = re.sub(r' +', ' ', text)
|
||||
|
||||
# 检查token长度
|
||||
token_count = count_tokens(text)
|
||||
|
||||
# 如果太短,跳过
|
||||
if token_count < MIN_TOKENS:
|
||||
return None
|
||||
|
||||
# 如果太长,截断
|
||||
if token_count > MAX_TOKENS:
|
||||
text = truncate_to_token_limit(text, MAX_TOKENS)
|
||||
token_count = count_tokens(text)
|
||||
|
||||
# 再次检查是否在合理范围内
|
||||
if token_count < MIN_TOKENS or token_count > MAX_TOKENS:
|
||||
return None
|
||||
|
||||
# 格式化为预训练格式
|
||||
formatted_text = f"<|im_start|>{text}<|im_end|>"
|
||||
return formatted_text
|
||||
|
||||
def get_text_hash(text):
|
||||
"""获取文本的哈希值,用于去重"""
|
||||
return hashlib.md5(text.encode('utf-8')).hexdigest()
|
||||
|
||||
def should_sample(dataset_name, current_count, config_dict=None, is_extra_mode=False):
|
||||
"""根据配置决定是否采样当前记录"""
|
||||
if config_dict is None:
|
||||
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
|
||||
|
||||
config = config_dict[dataset_name]
|
||||
|
||||
# 检查是否达到最大样本数
|
||||
if config["max_samples"] and current_count >= config["max_samples"]:
|
||||
return False
|
||||
|
||||
# 根据采样比例随机决定
|
||||
return random.random() < config["sample_ratio"]
|
||||
|
||||
def process_pretrain_hq():
|
||||
"""处理已有的高质量预训练数据 - 直接输出,不做任何处理"""
|
||||
logger.info("Processing pretrain_hq.jsonl...")
|
||||
count = 0
|
||||
|
||||
with open(PRETRAIN_HQ_PATH, 'r', encoding='utf-8') as f:
|
||||
for line in tqdm(f, desc="Processing pretrain_hq"):
|
||||
try:
|
||||
data = json.loads(line.strip())
|
||||
text = data.get('text', '').strip()
|
||||
|
||||
if text: # 只要有文本就直接输出,不做任何检测
|
||||
if should_sample("pretrain_hq", count):
|
||||
yield {"text": text}
|
||||
count += 1
|
||||
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
logger.info(f"Processed {count} records from pretrain_hq.jsonl")
|
||||
|
||||
def process_wikipedia(is_extra_mode=False):
|
||||
"""处理Wikipedia数据"""
|
||||
mode_text = "extra" if is_extra_mode else "main"
|
||||
logger.info(f"Processing Wikipedia data ({mode_text} mode)...")
|
||||
count = 0
|
||||
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
|
||||
|
||||
# 获取所有英文Wikipedia文件
|
||||
wiki_files = list(Path(WIKIPEDIA_PATH).glob("*.parquet"))
|
||||
|
||||
for file_path in tqdm(wiki_files, desc=f"Processing Wikipedia files ({mode_text})"):
|
||||
if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
|
||||
break
|
||||
|
||||
try:
|
||||
df = pd.read_parquet(file_path)
|
||||
for _, row in df.iterrows():
|
||||
if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
|
||||
break
|
||||
|
||||
text = row.get('text', '').strip()
|
||||
if text and len(text) > 200: # 预过滤太短的文本
|
||||
# 先将长文本分割成中等大小的块
|
||||
chunks = split_text_into_chunks(text, target_chunk_size=2000)
|
||||
|
||||
for chunk in chunks:
|
||||
if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
|
||||
break
|
||||
|
||||
chunk_hash = get_text_hash(chunk)
|
||||
|
||||
# 在额外模式下,跳过已经被主文件选中的数据
|
||||
if is_extra_mode and chunk_hash in selected_data_hashes["wikipedia"]:
|
||||
continue
|
||||
|
||||
formatted_text = format_text_for_pretrain(chunk)
|
||||
if formatted_text and should_sample("wikipedia", count, config_dict, is_extra_mode):
|
||||
# 在主模式下记录哈希值
|
||||
if not is_extra_mode:
|
||||
selected_data_hashes["wikipedia"].add(chunk_hash)
|
||||
|
||||
yield {"text": formatted_text}
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Processed {count} records from Wikipedia ({mode_text} mode)")
|
||||
|
||||
def process_gutenberg(is_extra_mode=False):
|
||||
"""处理Gutenberg数据"""
|
||||
mode_text = "extra" if is_extra_mode else "main"
|
||||
logger.info(f"Processing Gutenberg data ({mode_text} mode)...")
|
||||
count = 0
|
||||
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
|
||||
|
||||
# 获取所有Gutenberg训练文件
|
||||
gutenberg_files = list(Path(GUTENBERG_PATH).glob("train-*.parquet"))
|
||||
|
||||
for file_path in tqdm(gutenberg_files, desc=f"Processing Gutenberg files ({mode_text})"):
|
||||
if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
|
||||
break
|
||||
|
||||
try:
|
||||
df = pd.read_parquet(file_path)
|
||||
for _, row in df.iterrows():
|
||||
if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
|
||||
break
|
||||
|
||||
text = row.get('text', '').strip()
|
||||
if text and len(text) > 300 and is_english_text(text): # 预过滤
|
||||
# 先将长文本分割成中等大小的块
|
||||
chunks = split_text_into_chunks(text, target_chunk_size=1800)
|
||||
|
||||
for chunk in chunks:
|
||||
if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
|
||||
break
|
||||
|
||||
chunk_hash = get_text_hash(chunk)
|
||||
|
||||
# 在额外模式下,跳过已经被主文件选中的数据
|
||||
if is_extra_mode and chunk_hash in selected_data_hashes["gutenberg"]:
|
||||
continue
|
||||
|
||||
formatted_text = format_text_for_pretrain(chunk)
|
||||
if formatted_text and should_sample("gutenberg", count, config_dict, is_extra_mode):
|
||||
# 在主模式下记录哈希值
|
||||
if not is_extra_mode:
|
||||
selected_data_hashes["gutenberg"].add(chunk_hash)
|
||||
|
||||
yield {"text": formatted_text}
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Processed {count} records from Gutenberg ({mode_text} mode)")
|
||||
|
||||
def process_openwebtext(is_extra_mode=False):
|
||||
"""处理OpenWebText数据"""
|
||||
mode_text = "extra" if is_extra_mode else "main"
|
||||
logger.info(f"Processing OpenWebText data ({mode_text} mode)...")
|
||||
count = 0
|
||||
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
|
||||
max_files = 5 # 减少处理的文件数量以避免过长处理时间
|
||||
|
||||
# 获取tar文件列表
|
||||
tar_files = list(Path(OPENWEBTEXT_PATH).glob("*.tar"))[:max_files]
|
||||
|
||||
for tar_path in tqdm(tar_files, desc=f"Processing OpenWebText files ({mode_text})"):
|
||||
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
|
||||
break
|
||||
|
||||
try:
|
||||
with tarfile.open(tar_path, 'r') as outer_tar:
|
||||
# 创建临时目录处理外层tar
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
outer_tar.extractall(temp_dir)
|
||||
|
||||
# 处理解压后的xz文件
|
||||
for root, dirs, files in os.walk(temp_dir):
|
||||
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
|
||||
break
|
||||
|
||||
for file in files:
|
||||
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
|
||||
break
|
||||
|
||||
if file.endswith('.xz'):
|
||||
xz_path = os.path.join(root, file)
|
||||
|
||||
# 创建另一个临时目录处理xz文件
|
||||
with tempfile.TemporaryDirectory() as xz_temp_dir:
|
||||
try:
|
||||
# 解压xz文件
|
||||
import subprocess
|
||||
decompressed_path = os.path.join(xz_temp_dir, file[:-3]) # 移除.xz后缀
|
||||
with open(decompressed_path, 'wb') as decompressed_out:  # 用with确保文件句柄被关闭
|
||||
subprocess.run(['xz', '-dc', xz_path], stdout=decompressed_out, check=True)
|
||||
|
||||
# 检查解压后的文件是否是tar格式
|
||||
if tarfile.is_tarfile(decompressed_path):
|
||||
# 处理内层tar文件
|
||||
with tarfile.open(decompressed_path, 'r') as inner_tar:
|
||||
with tempfile.TemporaryDirectory() as inner_temp_dir:
|
||||
inner_tar.extractall(inner_temp_dir)
|
||||
|
||||
# 处理最终的txt文件
|
||||
for inner_root, inner_dirs, inner_files in os.walk(inner_temp_dir):
|
||||
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
|
||||
break
|
||||
|
||||
for txt_file in inner_files:
|
||||
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
|
||||
break
|
||||
|
||||
if txt_file.endswith('.txt'):
|
||||
txt_path = os.path.join(inner_root, txt_file)
|
||||
try:
|
||||
with open(txt_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
text = f.read().strip()
|
||||
if text and len(text) > 500 and is_english_text(text):
|
||||
# 先将长文本分割成中等大小的块
|
||||
chunks = split_text_into_chunks(text, target_chunk_size=1600)
|
||||
|
||||
for chunk in chunks:
|
||||
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
|
||||
break
|
||||
|
||||
chunk_hash = get_text_hash(chunk)
|
||||
|
||||
# 在额外模式下,跳过已经被主文件选中的数据
|
||||
if is_extra_mode and chunk_hash in selected_data_hashes["openwebtext"]:
|
||||
continue
|
||||
|
||||
formatted_text = format_text_for_pretrain(chunk)
|
||||
if formatted_text and should_sample("openwebtext", count, config_dict, is_extra_mode):
|
||||
# 在主模式下记录哈希值
|
||||
if not is_extra_mode:
|
||||
selected_data_hashes["openwebtext"].add(chunk_hash)
|
||||
|
||||
yield {"text": formatted_text}
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.debug(f"Error reading txt file {txt_path}: {e}")
|
||||
continue
|
||||
else:
|
||||
# 如果不是tar文件,直接作为文本处理
|
||||
try:
|
||||
with open(decompressed_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
text = f.read().strip()
|
||||
if text and len(text) > 500 and is_english_text(text):
|
||||
chunks = split_text_into_chunks(text, target_chunk_size=1600)
|
||||
|
||||
for chunk in chunks:
|
||||
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
|
||||
break
|
||||
|
||||
chunk_hash = get_text_hash(chunk)
|
||||
|
||||
# 在额外模式下,跳过已经被主文件选中的数据
|
||||
if is_extra_mode and chunk_hash in selected_data_hashes["openwebtext"]:
|
||||
continue
|
||||
|
||||
formatted_text = format_text_for_pretrain(chunk)
|
||||
if formatted_text and should_sample("openwebtext", count, config_dict, is_extra_mode):
|
||||
# 在主模式下记录哈希值
|
||||
if not is_extra_mode:
|
||||
selected_data_hashes["openwebtext"].add(chunk_hash)
|
||||
|
||||
yield {"text": formatted_text}
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.debug(f"Error reading decompressed file {decompressed_path}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error processing xz file {xz_path}: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {tar_path}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Processed {count} records from OpenWebText ({mode_text} mode)")
|
||||
|
||||
def merge_datasets():
|
||||
"""合并所有数据集,生成主文件和额外文件"""
|
||||
logger.info("Starting dataset merging...")
|
||||
logger.info("Main dataset configuration:")
|
||||
for name, config in DATASET_CONFIG.items():
|
||||
logger.info(f" {name}: quality={config['quality']}, ratio={config['sample_ratio']}, max={config['max_samples']}")
|
||||
|
||||
logger.info("Extra dataset configuration:")
|
||||
for name, config in DATASET_CONFIG_EXTRA.items():
|
||||
logger.info(f" {name}: quality={config['quality']}, ratio={config['sample_ratio']}, max={config['max_samples']}")
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True)
|
||||
os.makedirs(os.path.dirname(OUTPUT_FILE_EXTRA), exist_ok=True)
|
||||
|
||||
# 统计信息
|
||||
main_dataset_stats = {}
|
||||
extra_dataset_stats = {}
|
||||
|
||||
# 第一阶段:生成主文件
|
||||
logger.info("="*50)
|
||||
logger.info("PHASE 1: Generating main dataset file")
|
||||
logger.info("="*50)
|
||||
|
||||
with open(OUTPUT_FILE, 'w', encoding='utf-8') as outfile:
|
||||
main_total_count = 0
|
||||
|
||||
# 处理各个数据集(主模式)
|
||||
main_datasets = [
|
||||
("pretrain_hq", process_pretrain_hq),
|
||||
("wikipedia", lambda: process_wikipedia(is_extra_mode=False)),
|
||||
("gutenberg", lambda: process_gutenberg(is_extra_mode=False)),
|
||||
("openwebtext", lambda: process_openwebtext(is_extra_mode=False))
|
||||
]
|
||||
|
||||
for dataset_name, dataset_func in main_datasets:
|
||||
logger.info(f"Processing {dataset_name} for main file...")
|
||||
dataset_count = 0
|
||||
|
||||
try:
|
||||
for record in dataset_func():
|
||||
json.dump(record, outfile, ensure_ascii=False)
|
||||
outfile.write('\n')
|
||||
dataset_count += 1
|
||||
main_total_count += 1
|
||||
|
||||
# 每5000条记录输出一次进度
|
||||
if main_total_count % 5000 == 0:
|
||||
logger.info(f"Main file: Processed {main_total_count} total records")
|
||||
|
||||
# 保存统计信息
|
||||
main_dataset_stats[dataset_name] = {
|
||||
'selected': dataset_count,
|
||||
'config': DATASET_CONFIG[dataset_name]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {dataset_name} for main file: {e}")
|
||||
main_dataset_stats[dataset_name] = {
|
||||
'selected': dataset_count,
|
||||
'config': DATASET_CONFIG[dataset_name]
|
||||
}
|
||||
|
||||
logger.info(f"Main file - Completed {dataset_name}: {dataset_count} records")
|
||||
|
||||
logger.info(f"Main file generation completed. Total records: {main_total_count}")
|
||||
|
||||
# 第二阶段:生成额外文件
|
||||
logger.info("="*50)
|
||||
logger.info("PHASE 2: Generating extra dataset file")
|
||||
logger.info("="*50)
|
||||
|
||||
with open(OUTPUT_FILE_EXTRA, 'w', encoding='utf-8') as outfile:
|
||||
extra_total_count = 0
|
||||
|
||||
# 处理各个数据集(额外模式)- 不包括pretrain_hq
|
||||
extra_datasets = [
|
||||
("wikipedia", lambda: process_wikipedia(is_extra_mode=True)),
|
||||
("gutenberg", lambda: process_gutenberg(is_extra_mode=True)),
|
||||
("openwebtext", lambda: process_openwebtext(is_extra_mode=True))
|
||||
]
|
||||
|
||||
for dataset_name, dataset_func in extra_datasets:
|
||||
logger.info(f"Processing {dataset_name} for extra file...")
|
||||
dataset_count = 0
|
||||
|
||||
try:
|
||||
for record in dataset_func():
|
||||
json.dump(record, outfile, ensure_ascii=False)
|
||||
outfile.write('\n')
|
||||
dataset_count += 1
|
||||
extra_total_count += 1
|
||||
|
||||
# 每5000条记录输出一次进度
|
||||
if extra_total_count % 5000 == 0:
|
||||
logger.info(f"Extra file: Processed {extra_total_count} total records")
|
||||
|
||||
# 保存统计信息
|
||||
extra_dataset_stats[dataset_name] = {
|
||||
'selected': dataset_count,
|
||||
'config': DATASET_CONFIG_EXTRA[dataset_name]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {dataset_name} for extra file: {e}")
|
||||
extra_dataset_stats[dataset_name] = {
|
||||
'selected': dataset_count,
|
||||
'config': DATASET_CONFIG_EXTRA[dataset_name]
|
||||
}
|
||||
|
||||
logger.info(f"Extra file - Completed {dataset_name}: {dataset_count} records")
|
||||
|
||||
logger.info(f"Extra file generation completed. Total records: {extra_total_count}")
|
||||
|
||||
# 打印详细统计信息
|
||||
print_detailed_statistics(main_dataset_stats, main_total_count, extra_dataset_stats, extra_total_count)
|
||||
|
||||
logger.info("All dataset processing completed successfully!")
|
||||
logger.info(f"Main file saved to: {OUTPUT_FILE}")
|
||||
logger.info(f"Extra file saved to: {OUTPUT_FILE_EXTRA}")
|
||||
|
||||
def print_detailed_statistics(main_dataset_stats, main_total_count, extra_dataset_stats, extra_total_count):
|
||||
"""打印详细统计信息"""
|
||||
print("\n" + "="*100)
|
||||
print("DATASET PROCESSING SUMMARY")
|
||||
print("="*100)
|
||||
|
||||
# 主文件统计
|
||||
print("\nMAIN FILE (merged_pretrain.jsonl):")
|
||||
print("-" * 90)
|
||||
print(f"{'Dataset':<15} {'Selected':<10} {'Sample Ratio':<12} {'Max Limit':<12} {'% of Main':<12} {'Quality':<8}")
|
||||
print("-" * 90)
|
||||
|
||||
for dataset_name, stats in main_dataset_stats.items():
|
||||
selected = stats['selected']
|
||||
config = stats['config']
|
||||
ratio = config['sample_ratio']
|
||||
max_limit = config['max_samples'] if config['max_samples'] else "No limit"
|
||||
percentage = (selected / main_total_count * 100) if main_total_count > 0 else 0
|
||||
quality = config['quality']
|
||||
|
||||
print(f"{dataset_name:<15} {selected:<10,} {ratio:<12.1%} {str(max_limit):<12} {percentage:<12.2f}% {quality:<8}")
|
||||
|
||||
print("-" * 90)
|
||||
print(f"{'MAIN TOTAL':<15} {main_total_count:<10,} {'':<12} {'':<12} {'100.00%':<12} {'':<8}")
|
||||
|
||||
# 额外文件统计
|
||||
print("\nEXTRA FILE (merged_pretrain_extra.jsonl):")
|
||||
print("-" * 90)
|
||||
print(f"{'Dataset':<15} {'Selected':<10} {'Sample Ratio':<12} {'Max Limit':<12} {'% of Extra':<12} {'Quality':<8}")
|
||||
print("-" * 90)
|
||||
|
||||
for dataset_name, stats in extra_dataset_stats.items():
|
||||
selected = stats['selected']
|
||||
config = stats['config']
|
||||
ratio = config['sample_ratio']
|
||||
max_limit = config['max_samples'] if config['max_samples'] else "No limit"
|
||||
percentage = (selected / extra_total_count * 100) if extra_total_count > 0 else 0
|
||||
quality = config['quality']
|
||||
|
||||
print(f"{dataset_name:<15} {selected:<10,} {ratio:<12.1%} {str(max_limit):<12} {percentage:<12.2f}% {quality:<8}")
|
||||
|
||||
print("-" * 90)
|
||||
print(f"{'EXTRA TOTAL':<15} {extra_total_count:<10,} {'':<12} {'':<12} {'100.00%':<12} {'':<8}")
|
||||
|
||||
# 总体统计
|
||||
total_records = main_total_count + extra_total_count
|
||||
print("\nOVERALL STATISTICS:")
|
||||
print("-" * 50)
|
||||
print(f"Main file records: {main_total_count:>10,}")
|
||||
print(f"Extra file records: {extra_total_count:>10,}")
|
||||
print(f"Total records: {total_records:>10,}")
|
||||
print(f"Token range per sample: {MIN_TOKENS}-{MAX_TOKENS} tokens")
|
||||
|
||||
# 质量分布统计
|
||||
quality_stats = {}
|
||||
for dataset_name, stats in main_dataset_stats.items():
|
||||
quality = stats['config']['quality']
|
||||
if quality not in quality_stats:
|
||||
quality_stats[quality] = {'main': 0, 'extra': 0}
|
||||
quality_stats[quality]['main'] += stats['selected']
|
||||
|
||||
for dataset_name, stats in extra_dataset_stats.items():
|
||||
quality = stats['config']['quality']
|
||||
if quality not in quality_stats:
|
||||
quality_stats[quality] = {'main': 0, 'extra': 0}
|
||||
quality_stats[quality]['extra'] += stats['selected']
|
||||
|
||||
print("\nQUALITY DISTRIBUTION:")
|
||||
print("-" * 60)
|
||||
print(f"{'Quality':<12} {'Main File':<12} {'Extra File':<12} {'Total':<12} {'%':<8}")
|
||||
print("-" * 60)
|
||||
for quality in sorted(quality_stats.keys()):
|
||||
main_count = quality_stats[quality]['main']
|
||||
extra_count = quality_stats[quality]['extra']
|
||||
total_count = main_count + extra_count
|
||||
percentage = (total_count / total_records * 100) if total_records > 0 else 0
|
||||
print(f"{quality.capitalize():<12} {main_count:<12,} {extra_count:<12,} {total_count:<12,} {percentage:<8.2f}%")
|
||||
print("-" * 60)
|
||||
print(f"{'Total':<12} {main_total_count:<12,} {extra_total_count:<12,} {total_records:<12,} {'100.00%':<8}")
|
||||
|
||||
print(f"\nFiles saved to:")
|
||||
print(f" Main file: {OUTPUT_FILE}")
|
||||
print(f" Extra file: {OUTPUT_FILE_EXTRA}")
|
||||
print("="*100)
|
||||
|
||||
def main():
    """Main entry point."""
    try:
        # Fix the random seed so the sampling is reproducible
        random.seed(42)

        # Check required third-party packages
        try:
            import langdetect
            from transformers import AutoTokenizer
        except ImportError as e:
            logger.error(f"Missing dependencies: {e}")
            logger.error("Please install: pip install langdetect transformers")
            return

        # Initialize the tokenizer
        init_tokenizer()

        # Check that the input file exists
        if not os.path.exists(PRETRAIN_HQ_PATH):
            logger.error(f"pretrain_hq.jsonl not found at {PRETRAIN_HQ_PATH}")
            return

        # Merge the datasets into the main and extra files
        merge_datasets()
        logger.info("All processing completed successfully!")

    except Exception as e:
        logger.error(f"Error in main process: {e}")
        raise


if __name__ == "__main__":
    main()
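# --- Sanity-check sketch (not part of the commit): each record written by merge_datasets() is one
# JSON object per line with a single "text" field. The default path below matches the --data_path
# default used by train_pretrain_accelerate.py; adjust it if OUTPUT_FILE points elsewhere.
import json

def _sketch_peek_merged(path="./dataset/merged_pretrain.jsonl", n=3):
    with open(path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            record = json.loads(line)
            assert "text" in record and isinstance(record["text"], str)
            if i + 1 >= n:
                break
    print(f"first {n} records of {path} look well-formed")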
|
61
preprocessing/test_preprocess_small.py
Normal file
@ -0,0 +1,61 @@
|
||||
#!/usr/bin/env python3
"""
Small-scale smoke test for the preprocessing script.
"""

import sys
import os

# Make the preprocessing directory importable
sys.path.append('/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/preprocessing')

# Import the main module (module import for config overrides, star import for the helpers)
import preprocess_pretrain as pp
from preprocess_pretrain import *

# Shrink the configuration for a quick test.
# These dict entries are mutated in place, so merge_datasets() sees the new limits.
DATASET_CONFIG["wikipedia"]["max_samples"] = 100
DATASET_CONFIG["gutenberg"]["max_samples"] = 50
DATASET_CONFIG["openwebtext"]["max_samples"] = 20

DATASET_CONFIG_EXTRA["wikipedia"]["max_samples"] = 50
DATASET_CONFIG_EXTRA["gutenberg"]["max_samples"] = 30
DATASET_CONFIG_EXTRA["openwebtext"]["max_samples"] = 15

# Redirect the output paths. These must be set on the module itself:
# rebinding the star-imported names in this file would not affect merge_datasets().
pp.OUTPUT_FILE = "/tmp/test_main.jsonl"
pp.OUTPUT_FILE_EXTRA = "/tmp/test_extra.jsonl"


def test_small_scale():
    """Run the full pipeline on a tiny sample and report the output sizes."""
    print("Starting small scale test...")

    # Fix the random seed
    random.seed(42)

    try:
        # Initialize the tokenizer
        init_tokenizer()

        # Merge the datasets
        merge_datasets()

        # Check the output files
        if os.path.exists(pp.OUTPUT_FILE):
            with open(pp.OUTPUT_FILE, 'r') as f:
                main_lines = len(f.readlines())
            print(f"Main file created: {main_lines} lines")

        if os.path.exists(pp.OUTPUT_FILE_EXTRA):
            with open(pp.OUTPUT_FILE_EXTRA, 'r') as f:
                extra_lines = len(f.readlines())
            print(f"Extra file created: {extra_lines} lines")

        print("Small scale test completed successfully!")

    except Exception as e:
        print(f"Test failed: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    test_small_scale()
|
175
pyproject.toml
Normal file
@ -0,0 +1,175 @@
|
||||
[project]
|
||||
name = "minimind"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"accelerate==1.7.0",
|
||||
"aiohappyeyeballs==2.6.1",
|
||||
"aiohttp==3.11.17",
|
||||
"aiosignal==1.3.2",
|
||||
"altair==5.5.0",
|
||||
"annotated-types==0.7.0",
|
||||
"anyio==4.9.0",
|
||||
"async-timeout==5.0.1",
|
||||
"attrs==25.3.0",
|
||||
"blinker==1.9.0",
|
||||
"boto3==1.38.41",
|
||||
"botocore==1.38.41",
|
||||
"cachetools==5.5.2",
|
||||
"certifi==2025.1.31",
|
||||
"charset-normalizer==3.4.1",
|
||||
"click==8.1.8",
|
||||
"contourpy==1.3.2",
|
||||
"cycler==0.12.1",
|
||||
"datasets==2.21.0",
|
||||
"datasketch==1.6.4",
|
||||
"deepspeed==0.17.0",
|
||||
"determined>=0.37.0",
|
||||
"dill==0.3.8",
|
||||
"distro==1.9.0",
|
||||
"docker-pycreds==0.4.0",
|
||||
"einops==0.8.1",
|
||||
"exceptiongroup==1.2.2",
|
||||
"filelock==3.18.0",
|
||||
"Flask==3.0.3",
|
||||
"Flask-Cors==4.0.0",
|
||||
"fonttools==4.57.0",
|
||||
"frozenlist==1.6.0",
|
||||
"fsspec==2024.6.1",
|
||||
"gitdb==4.0.12",
|
||||
"GitPython==3.1.44",
|
||||
"h11==0.14.0",
|
||||
"hjson==3.1.0",
|
||||
"httpcore==1.0.8",
|
||||
"httpx==0.28.1",
|
||||
"huggingface-hub==0.30.2",
|
||||
"importlib_metadata==7.2.1",
|
||||
"itsdangerous==2.2.0",
|
||||
"jieba==0.42.1",
|
||||
"Jinja2==3.1.2",
|
||||
"jiter==0.9.0",
|
||||
"jmespath==1.0.1",
|
||||
"joblib==1.4.2",
|
||||
"jsonlines==4.0.0",
|
||||
"jsonpointer==2.1",
|
||||
"jsonschema==4.23.0",
|
||||
"jsonschema-specifications==2024.10.1",
|
||||
"kiwisolver==1.4.8",
|
||||
"langdetect==1.0.9",
|
||||
"markdown-it-py==3.0.0",
|
||||
"MarkupSafe==3.0.2",
|
||||
"marshmallow==3.22.0",
|
||||
"matplotlib==3.10.0",
|
||||
"mdurl==0.1.2",
|
||||
"modelscope==1.25.0",
|
||||
"mpi4py>=4.0.3",
|
||||
"mpmath==1.3.0",
|
||||
"msgpack==1.1.0",
|
||||
"multidict==6.4.3",
|
||||
"multiprocess==0.70.16",
|
||||
"narwhals==1.35.0",
|
||||
"networkx==3.4.2",
|
||||
"ngrok==1.4.0",
|
||||
"ninja==1.11.1.4",
|
||||
"nltk==3.8",
|
||||
"numpy==1.26.4",
|
||||
"nvidia-cublas-cu11==11.11.3.6",
|
||||
"nvidia-cublas-cu12==12.6.4.1",
|
||||
"nvidia-cuda-cupti-cu11==11.8.87",
|
||||
"nvidia-cuda-cupti-cu12==12.6.80",
|
||||
"nvidia-cuda-nvrtc-cu11==11.8.89",
|
||||
"nvidia-cuda-nvrtc-cu12==12.6.77",
|
||||
"nvidia-cuda-runtime-cu11==11.8.89",
|
||||
"nvidia-cuda-runtime-cu12==12.6.77",
|
||||
"nvidia-cudnn-cu11==9.1.0.70",
|
||||
"nvidia-cudnn-cu12==9.5.1.17",
|
||||
"nvidia-cufft-cu11==10.9.0.58",
|
||||
"nvidia-cufft-cu12==11.3.0.4",
|
||||
"nvidia-cufile-cu12==1.11.1.6",
|
||||
"nvidia-curand-cu11==10.3.0.86",
|
||||
"nvidia-curand-cu12==10.3.7.77",
|
||||
"nvidia-cusolver-cu11==11.4.1.48",
|
||||
"nvidia-cusolver-cu12==11.7.1.2",
|
||||
"nvidia-cusparse-cu11==11.7.5.86",
|
||||
"nvidia-cusparse-cu12==12.5.4.2",
|
||||
"nvidia-cusparselt-cu12==0.6.3",
|
||||
"nvidia-ml-py==12.575.51",
|
||||
"nvidia-nccl-cu11==2.21.5",
|
||||
"nvidia-nccl-cu12==2.26.2",
|
||||
"nvidia-nvjitlink-cu12==12.6.85",
|
||||
"nvidia-nvtx-cu11==11.8.86",
|
||||
"nvidia-nvtx-cu12==12.6.77",
|
||||
"openai==1.59.6",
|
||||
"packaging==23.2",
|
||||
"pandas>=2.0.0",
|
||||
"peft==0.7.1",
|
||||
"pillow==10.4.0",
|
||||
"platformdirs==4.3.7",
|
||||
"prettytable==3.16.0",
|
||||
"propcache==0.3.1",
|
||||
"protobuf==4.25.6",
|
||||
"psutil==5.9.8",
|
||||
"py-cpuinfo==9.0.0",
|
||||
"pyarrow==19.0.1",
|
||||
"pydantic==2.11.7",
|
||||
"pydantic_core==2.33.2",
|
||||
"pydeck==0.9.1",
|
||||
"pyecharts==2.0.8",
|
||||
"Pygments==2.19.1",
|
||||
"pynvml==12.0.0",
|
||||
"pyparsing==3.2.3",
|
||||
"python-dateutil==2.9.0.post0",
|
||||
"pytz==2025.2",
|
||||
"PyYAML==6.0.2",
|
||||
"referencing==0.36.2",
|
||||
"regex==2024.11.6",
|
||||
"requests==2.32.3",
|
||||
"rich==13.7.1",
|
||||
"rpds-py==0.24.0",
|
||||
"s3transfer==0.13.0",
|
||||
"safetensors==0.5.3",
|
||||
"scikit-learn==1.5.1",
|
||||
"scipy==1.15.2",
|
||||
"sentence-transformers==2.3.1",
|
||||
"sentencepiece==0.2.0",
|
||||
"sentry-sdk==2.26.1",
|
||||
"setproctitle==1.3.5",
|
||||
"simhash==2.1.2",
|
||||
"simplejson==3.20.1",
|
||||
"six==1.17.0",
|
||||
"smmap==5.0.2",
|
||||
"sniffio==1.3.1",
|
||||
"streamlit==1.30.0",
|
||||
"swankit==0.2.4",
|
||||
"swanlab==0.6.4",
|
||||
"sympy==1.13.3",
|
||||
"tenacity==8.5.0",
|
||||
"threadpoolctl==3.6.0",
|
||||
"tiktoken>=0.8.0",
|
||||
"tokenizers==0.21.1",
|
||||
"toml==0.10.2",
|
||||
"torch==2.7.1",
|
||||
"torchaudio==2.7.1",
|
||||
"torchvision==0.22.1",
|
||||
"tornado==6.4.2",
|
||||
"tqdm==4.67.1",
|
||||
"transformers==4.52.4",
|
||||
"triton==3.3.1",
|
||||
"trl==0.13.0",
|
||||
"typing-inspection==0.4.1",
|
||||
"typing_extensions==4.13.2",
|
||||
"tzlocal==5.3.1",
|
||||
"ujson==5.1.0",
|
||||
"urllib3==2.4.0",
|
||||
"validators==0.34.0",
|
||||
"wandb==0.18.3",
|
||||
"watchdog==6.0.0",
|
||||
"wcwidth==0.2.13",
|
||||
"Werkzeug==3.1.3",
|
||||
"wrapt==1.17.2",
|
||||
"xxhash==3.5.0",
|
||||
"yarl==1.20.0",
|
||||
"zipp==3.21.0",
|
||||
]
|
@ -1,3 +1,4 @@
|
||||
accelerate==1.7.0
|
||||
aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.11.17
|
||||
aiosignal==1.3.2
|
||||
@ -7,6 +8,8 @@ anyio==4.9.0
|
||||
async-timeout==5.0.1
|
||||
attrs==25.3.0
|
||||
blinker==1.9.0
|
||||
boto3==1.38.41
|
||||
botocore==1.38.41
|
||||
cachetools==5.5.2
|
||||
certifi==2025.1.31
|
||||
charset-normalizer==3.4.1
|
||||
@ -15,6 +18,7 @@ contourpy==1.3.2
|
||||
cycler==0.12.1
|
||||
datasets==2.21.0
|
||||
datasketch==1.6.4
|
||||
deepspeed==0.17.0
|
||||
dill==0.3.8
|
||||
distro==1.9.0
|
||||
docker-pycreds==0.4.0
|
||||
@ -33,17 +37,19 @@ hjson==3.1.0
|
||||
httpcore==1.0.8
|
||||
httpx==0.28.1
|
||||
huggingface-hub==0.30.2
|
||||
idna==3.10
|
||||
importlib_metadata==7.2.1
|
||||
itsdangerous==2.2.0
|
||||
jieba==0.42.1
|
||||
Jinja2==3.1.2
|
||||
jiter==0.9.0
|
||||
jmespath==1.0.1
|
||||
joblib==1.4.2
|
||||
jsonlines==4.0.0
|
||||
jsonpointer==2.1
|
||||
jsonschema==4.23.0
|
||||
jsonschema-specifications==2024.10.1
|
||||
kiwisolver==1.4.8
|
||||
langdetect==1.0.9
|
||||
markdown-it-py==3.0.0
|
||||
MarkupSafe==3.0.2
|
||||
marshmallow==3.22.0
|
||||
@ -60,21 +66,50 @@ ngrok==1.4.0
|
||||
ninja==1.11.1.4
|
||||
nltk==3.8
|
||||
numpy==1.26.4
|
||||
nvidia-cublas-cu11==11.11.3.6
|
||||
nvidia-cublas-cu12==12.6.4.1
|
||||
nvidia-cuda-cupti-cu11==11.8.87
|
||||
nvidia-cuda-cupti-cu12==12.6.80
|
||||
nvidia-cuda-nvrtc-cu11==11.8.89
|
||||
nvidia-cuda-nvrtc-cu12==12.6.77
|
||||
nvidia-cuda-runtime-cu11==11.8.89
|
||||
nvidia-cuda-runtime-cu12==12.6.77
|
||||
nvidia-cudnn-cu11==9.1.0.70
|
||||
nvidia-cudnn-cu12==9.5.1.17
|
||||
nvidia-cufft-cu11==10.9.0.58
|
||||
nvidia-cufft-cu12==11.3.0.4
|
||||
nvidia-cufile-cu12==1.11.1.6
|
||||
nvidia-curand-cu11==10.3.0.86
|
||||
nvidia-curand-cu12==10.3.7.77
|
||||
nvidia-cusolver-cu11==11.4.1.48
|
||||
nvidia-cusolver-cu12==11.7.1.2
|
||||
nvidia-cusparse-cu11==11.7.5.86
|
||||
nvidia-cusparse-cu12==12.5.4.2
|
||||
nvidia-cusparselt-cu12==0.6.3
|
||||
nvidia-ml-py==12.575.51
|
||||
nvidia-nccl-cu11==2.21.5
|
||||
nvidia-nccl-cu12==2.26.2
|
||||
nvidia-nvjitlink-cu12==12.6.85
|
||||
nvidia-nvtx-cu11==11.8.86
|
||||
nvidia-nvtx-cu12==12.6.77
|
||||
openai==1.59.6
|
||||
packaging==23.2
|
||||
pandas==1.5.3
|
||||
peft==0.7.1
|
||||
pillow==10.4.0
|
||||
platformdirs==4.3.7
|
||||
prettytable==3.16.0
|
||||
propcache==0.3.1
|
||||
protobuf==4.25.6
|
||||
psutil==5.9.8
|
||||
py-cpuinfo==9.0.0
|
||||
pyarrow==19.0.1
|
||||
pydantic==2.8.2
|
||||
pydantic_core==2.20.1
|
||||
pydantic==2.11.7
|
||||
pydantic_core==2.33.2
|
||||
pydeck==0.9.1
|
||||
pyecharts==2.0.8
|
||||
Pygments==2.19.1
|
||||
pynvml==12.0.0
|
||||
pyparsing==3.2.3
|
||||
python-dateutil==2.9.0.post0
|
||||
pytz==2025.2
|
||||
@ -84,6 +119,7 @@ regex==2024.11.6
|
||||
requests==2.32.3
|
||||
rich==13.7.1
|
||||
rpds-py==0.24.0
|
||||
s3transfer==0.13.0
|
||||
safetensors==0.5.3
|
||||
scikit-learn==1.5.1
|
||||
scipy==1.15.2
|
||||
@ -92,21 +128,28 @@ sentencepiece==0.2.0
|
||||
sentry-sdk==2.26.1
|
||||
setproctitle==1.3.5
|
||||
simhash==2.1.2
|
||||
simplejson==3.20.1
|
||||
six==1.17.0
|
||||
smmap==5.0.2
|
||||
sniffio==1.3.1
|
||||
streamlit==1.30.0
|
||||
swankit==0.2.4
|
||||
swanlab==0.6.4
|
||||
sympy==1.13.3
|
||||
tenacity==8.5.0
|
||||
threadpoolctl==3.6.0
|
||||
tiktoken==0.5.1
|
||||
tokenizers==0.21.1
|
||||
toml==0.10.2
|
||||
torch==2.7.1
|
||||
torchaudio==2.7.1
|
||||
torchvision==0.22.1
|
||||
tornado==6.4.2
|
||||
tqdm==4.67.1
|
||||
transformers==4.48.0
|
||||
triton==3.3.0
|
||||
transformers==4.52.4
|
||||
triton==3.3.1
|
||||
trl==0.13.0
|
||||
typing-inspection==0.4.1
|
||||
typing_extensions==4.13.2
|
||||
tzlocal==5.3.1
|
||||
ujson==5.1.0
|
||||
@ -114,7 +157,9 @@ urllib3==2.4.0
|
||||
validators==0.34.0
|
||||
wandb==0.18.3
|
||||
watchdog==6.0.0
|
||||
wcwidth==0.2.13
|
||||
Werkzeug==3.1.3
|
||||
wrapt==1.17.2
|
||||
xxhash==3.5.0
|
||||
yarl==1.20.0
|
||||
zipp==3.21.0
|
||||
|
@ -1,8 +1,8 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 激活conda环境
|
||||
source $(conda info --base)/etc/profile.d/conda.sh
|
||||
conda activate mini
|
||||
# source $(conda info --base)/etc/profile.d/conda.sh
|
||||
# conda activate mini
|
||||
|
||||
# 设置环境变量以帮助调试
|
||||
export NCCL_DEBUG=INFO
|
||||
@ -26,9 +26,27 @@ export PYTHONFAULTHANDLER=1
|
||||
# --profile_interval 10
|
||||
|
||||
# 方法2: 使用命令行参数直接配置accelerate
|
||||
CUDA_VISIBLE_DEVICES=0 /opt/conda/envs/mini/bin/python -m accelerate.commands.launch \
|
||||
# 内存泄漏调试配置 - 减少内存使用
|
||||
CUDA_VISIBLE_DEVICES=0 uv run -p .venv python -m accelerate.commands.launch \
|
||||
--num_processes=1 \
|
||||
--mixed_precision=bf16 \
|
||||
--main_process_port=29500 \
|
||||
train_pretrain_accelerate.py \
|
||||
train_pretrain_accelerate.py
|
||||
# --batch_size 128 \
|
||||
# --num_workers 1
|
||||
# --knowledge_num 48020 \
|
||||
# --num_workers 1 \
|
||||
# --epochs 4 \
|
||||
# --learning_rate 2e-4 \
|
||||
# --dtype bfloat16 \
|
||||
# --accumulation_steps 32 \
|
||||
# --grad_clip 1.0 \
|
||||
# --log_interval 50 \
|
||||
# --save_interval 10000 \
|
||||
# --dim 512 \
|
||||
# --n_layers 8 \
|
||||
# --max_seq_len 512 \
|
||||
# --use_flash_attn \
|
||||
# --profile \
|
||||
# --profile_interval 10
|
||||
|
||||
|
33
startup.sh
Normal file
@ -0,0 +1,33 @@
|
||||
#!/bin/bash
set -e

# Previously, dependencies were installed from requirements.txt after the container started
# pip install -r requirements.txt

# bash install.sh -y
python3 -m pip install --upgrade pip
pip install uv -i https://pypi.tuna.tsinghua.edu.cn/simple

# Switch to the project directory
cd /ycz/Minimind

# Check the virtual environment and recreate it if it is broken or missing
if [ ! -f .venv/bin/python ] || [ ! -x .venv/bin/python ]; then
    echo "Virtual environment is broken or missing, recreating with uv..."
    rm -rf .venv
    uv venv .venv
fi

# Do not activate the virtual environment manually; let uv manage it
# . ./.venv/bin/activate

# Sync dependencies with uv
uv sync

# Once installation finishes, run the main training script.
# "$@" forwards the arguments defined by the entrypoint in experiment.yaml to the Python script.
CUDA_VISIBLE_DEVICES=0 uv run python -m accelerate.commands.launch \
    --num_processes=1 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    train_pretrain_accelerate.py "$@"
|
@ -1,97 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
测试实数版本的位置编码
|
||||
"""
|
||||
|
||||
import torch
|
||||
from model.model import precompute_pos_cis, precompute_pos_cis_real, apply_rotary_emb, apply_rotary_emb_real
|
||||
from model.LMConfig import LMConfig
|
||||
from model.model import MiniMindLM
|
||||
|
||||
def test_pos_encoding_equivalence():
|
||||
"""测试复数版本和实数版本的位置编码是否等价"""
|
||||
print("测试位置编码等价性...")
|
||||
|
||||
# 参数设置
|
||||
dim = 64
|
||||
seq_len = 10
|
||||
|
||||
# 生成复数版本的位置编码
|
||||
pos_cis = precompute_pos_cis(dim=dim, end=seq_len)
|
||||
|
||||
# 生成实数版本的位置编码
|
||||
pos_cis_real = precompute_pos_cis_real(dim=dim, end=seq_len)
|
||||
|
||||
# 创建随机查询和键
|
||||
batch_size = 2
|
||||
n_heads = 4
|
||||
head_dim = dim
|
||||
|
||||
xq = torch.randn(batch_size, seq_len, n_heads, head_dim)
|
||||
xk = torch.randn(batch_size, seq_len, n_heads, head_dim)
|
||||
|
||||
# 应用复数版本的旋转位置编码
|
||||
xq_complex, xk_complex = apply_rotary_emb(xq, xk, pos_cis)
|
||||
|
||||
# 应用实数版本的旋转位置编码
|
||||
xq_real, xk_real = apply_rotary_emb_real(xq, xk, pos_cis_real)
|
||||
|
||||
# 计算差异
|
||||
q_diff = torch.abs(xq_complex - xq_real).mean().item()
|
||||
k_diff = torch.abs(xk_complex - xk_real).mean().item()
|
||||
|
||||
print(f"查询差异: {q_diff:.6f}")
|
||||
print(f"键差异: {k_diff:.6f}")
|
||||
|
||||
# 检查差异是否在可接受范围内
|
||||
tolerance = 1e-5
|
||||
if q_diff < tolerance and k_diff < tolerance:
|
||||
print("✅ 测试通过: 复数版本和实数版本的位置编码在数值上等价")
|
||||
else:
|
||||
print("❌ 测试失败: 复数版本和实数版本的位置编码存在显著差异")
|
||||
|
||||
def test_model_forward():
|
||||
"""测试模型前向传播"""
|
||||
print("\n测试模型前向传播...")
|
||||
|
||||
# 创建模型配置
|
||||
config = LMConfig(
|
||||
dim=128,
|
||||
n_layers=2,
|
||||
n_heads=4,
|
||||
n_kv_heads=4, # 确保n_kv_heads被设置,且n_heads能被n_kv_heads整除
|
||||
vocab_size=1000,
|
||||
max_seq_len=128,
|
||||
disable_db=True # 禁用数据库功能,避免额外的复杂性
|
||||
)
|
||||
|
||||
# 创建模型
|
||||
try:
|
||||
model = MiniMindLM(config)
|
||||
print(f"✅ 模型初始化成功")
|
||||
except Exception as e:
|
||||
print(f"❌ 模型初始化失败: {str(e)}")
|
||||
return
|
||||
|
||||
# 创建输入
|
||||
batch_size = 2
|
||||
seq_len = 10
|
||||
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
|
||||
|
||||
# 前向传播
|
||||
try:
|
||||
with torch.no_grad():
|
||||
outputs = model(input_ids)
|
||||
print(f"✅ 模型前向传播成功")
|
||||
print(f"输出形状: {outputs.logits.shape}")
|
||||
except Exception as e:
|
||||
print(f"❌ 模型前向传播失败: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试位置编码等价性
|
||||
test_pos_encoding_equivalence()
|
||||
|
||||
# 测试模型前向传播
|
||||
test_model_forward()
|
@ -1,6 +1,6 @@
|
||||
import os
|
||||
# 设置环境变量
|
||||
os.environ["WANDB_MODE"] = "offline" # 或者使用 "dryrun"
|
||||
# 设置环境变量 - 将wandb替换为SwanLab
|
||||
# os.environ["SWANLAB_MODE"] = "online" # SwanLab使用在线模式
|
||||
import platform
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
@ -21,6 +21,9 @@ from accelerate.utils import DistributedDataParallelKwargs
|
||||
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
|
||||
import numpy as np
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
import swanlab # 替换wandb导入
|
||||
import gc # 添加垃圾回收模块
|
||||
import psutil # 添加系统资源监控模块
|
||||
|
||||
from model.model import MiniMindLM, RMSNorm
|
||||
from model.LMConfig import LMConfig
|
||||
@ -28,6 +31,63 @@ from model.dataset import PretrainDataset
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
# Memory monitoring helpers
def get_memory_usage():
    """Return the current host memory usage of this process."""
    process = psutil.Process()
    memory_info = process.memory_info()
    return {
        'rss_mb': memory_info.rss / 1024 / 1024,  # resident set size (MB)
        'vms_mb': memory_info.vms / 1024 / 1024,  # virtual memory size (MB)
    }

def get_cuda_memory_usage():
    """Return the current CUDA memory usage."""
    if torch.cuda.is_available():
        return {
            'cuda_allocated_mb': torch.cuda.memory_allocated() / 1024 / 1024,
            'cuda_reserved_mb': torch.cuda.memory_reserved() / 1024 / 1024,
            'cuda_max_allocated_mb': torch.cuda.max_memory_allocated() / 1024 / 1024,
        }
    return {}

def get_tensor_memory_size(tensor_list):
    """Compute the total memory footprint (MB) of a list of batches of tensors."""
    total_size = 0
    for batch in tensor_list:
        if isinstance(batch, (list, tuple)):
            for tensor in batch:
                if isinstance(tensor, torch.Tensor):
                    total_size += tensor.numel() * tensor.element_size()
        elif isinstance(batch, torch.Tensor):
            total_size += batch.numel() * batch.element_size()
    return total_size / 1024 / 1024  # convert to MB

def log_memory_status(step, prefetch_batches, accelerator, stage="", detailed=False):
    """Log the current memory status (main process only)."""
    if not accelerator.is_main_process:
        return

    memory_info = get_memory_usage()
    cuda_info = get_cuda_memory_usage()
    prefetch_memory = get_tensor_memory_size(prefetch_batches)

    log_msg = f"[Memory Monitor] Step {step} {stage} - "
    log_msg += f"Prefetch batches: {len(prefetch_batches)}, "
    log_msg += f"Prefetch memory: {prefetch_memory:.2f}MB, "
    log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB"

    if cuda_info:
        log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
        log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB"

    if detailed:
        log_msg += f", System VMS: {memory_info['vms_mb']:.2f}MB"
        if cuda_info:
            log_msg += f", CUDA max allocated: {cuda_info['cuda_max_allocated_mb']:.2f}MB"

    Logger(log_msg, accelerator)
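# --- Usage sketch (not part of the commit): the helpers above can be exercised outside the training
# loop; the Accelerator constructed here exists only for illustration.
def _sketch_memory_probe():
    from accelerate import Accelerator  # local import to keep the sketch self-contained
    acc = Accelerator()
    log_memory_status(step=0, prefetch_batches=[], accelerator=acc, stage="standalone_probe", detailed=True)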
|
||||
|
||||
# 日志记录函数
|
||||
def Logger(msg, accelerator=None):
|
||||
# 如果没有提供accelerator,则只在主进程打印
|
||||
@ -218,7 +278,7 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
||||
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
return model, tokenizer
|
||||
|
||||
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb):
|
||||
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run):
|
||||
loss_fct = nn.CrossEntropyLoss(reduction='none')
|
||||
epoch_start_time = time.time()
|
||||
total_steps_in_epoch = len(train_loader)
|
||||
@ -226,6 +286,10 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
||||
moe_path = '_moe' if args.use_moe else ''
|
||||
best_loss = float('10000')
|
||||
|
||||
# 初始化CUDA事件变量
|
||||
data_start = data_end = forward_start = forward_end = None
|
||||
backward_start = backward_end = optimizer_start = optimizer_end = None
|
||||
|
||||
# 添加CUDA事件来分析性能 (只在主进程进行)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
data_start = torch.cuda.Event(enable_timing=True)
|
||||
@ -242,40 +306,63 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
||||
data_iter = iter(train_loader)
|
||||
prefetch_batches = []
|
||||
|
||||
# 记录初始内存状态
|
||||
if args.memory_monitor:
|
||||
log_memory_status(-1, prefetch_batches, accelerator, "before_prefetch", detailed=True)
|
||||
|
||||
# 预取初始批次
|
||||
for _ in range(min(prefetch_factor, len(train_loader))):
|
||||
for i in range(min(prefetch_factor, len(train_loader))):
|
||||
try:
|
||||
batch = next(data_iter)
|
||||
prefetch_batches.append(batch)
|
||||
# 每次添加batch后记录内存变化
|
||||
if args.memory_monitor and accelerator.is_main_process:
|
||||
log_memory_status(-1, prefetch_batches, accelerator, f"after_adding_batch_{i+1}")
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
# 记录预取完成后的内存状态
|
||||
if args.memory_monitor:
|
||||
log_memory_status(-1, prefetch_batches, accelerator, "after_initial_prefetch", detailed=True)
|
||||
|
||||
# 在开始循环前初始化日志记录所需变量
|
||||
last_log_time = epoch_start_time
|
||||
|
||||
for step in range(total_steps_in_epoch):
|
||||
try:
|
||||
# 计时数据加载 (只在主进程进行)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
if args.profile and accelerator.is_main_process and data_start is not None:
|
||||
data_start.record()
|
||||
|
||||
# 记录使用batch前的内存状态(根据配置间隔记录详细信息)
|
||||
if args.memory_monitor and step % args.memory_monitor_interval == 0:
|
||||
log_memory_status(step, prefetch_batches, accelerator, "before_use_batch", detailed=True)
|
||||
|
||||
# 使用预取的数据
|
||||
if prefetch_batches:
|
||||
X, Y, loss_mask = prefetch_batches.pop(0)
|
||||
# 记录使用batch后的内存变化
|
||||
if args.memory_monitor and step % args.memory_monitor_interval == 0:
|
||||
log_memory_status(step, prefetch_batches, accelerator, "after_pop_batch")
|
||||
else:
|
||||
# 如果预取队列为空,直接加载
|
||||
X, Y, loss_mask = next(data_iter)
|
||||
if args.memory_monitor and accelerator.is_main_process:
|
||||
Logger(f"[Memory Monitor] Step {step} - Prefetch queue empty, loading directly!", accelerator)
|
||||
|
||||
# 异步预取下一批数据
|
||||
if step + prefetch_factor < len(train_loader):
|
||||
try:
|
||||
batch = next(data_iter)
|
||||
prefetch_batches.append(batch)
|
||||
# 记录添加新batch后的内存变化
|
||||
if args.memory_monitor and step % args.memory_monitor_interval == 0:
|
||||
log_memory_status(step, prefetch_batches, accelerator, "after_add_batch")
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
# 计时数据加载结束 (只在主进程进行)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
if args.profile and accelerator.is_main_process and data_end is not None:
|
||||
data_end.record()
|
||||
|
||||
# 更新学习率
|
||||
@ -283,7 +370,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
||||
scheduler.step()
|
||||
|
||||
# 计时前向传播 (只在主进程进行)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
if args.profile and accelerator.is_main_process and forward_start is not None:
|
||||
forward_start.record()
|
||||
|
||||
# 前向传播
|
||||
@ -310,11 +397,11 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
# 计时前向传播结束 (只在主进程进行)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
if args.profile and accelerator.is_main_process and forward_end is not None:
|
||||
forward_end.record()
|
||||
|
||||
# 计时反向传播 (只在主进程进行)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
if args.profile and accelerator.is_main_process and backward_start is not None:
|
||||
backward_start.record()
|
||||
|
||||
# 反向传播
|
||||
@ -322,11 +409,11 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
||||
accelerator.backward(loss)
|
||||
|
||||
# 计时反向传播结束 (只在主进程进行)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
if args.profile and accelerator.is_main_process and backward_end is not None:
|
||||
backward_end.record()
|
||||
|
||||
# 计时优化器步骤 (只在主进程进行)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
if args.profile and accelerator.is_main_process and optimizer_start is not None:
|
||||
optimizer_start.record()
|
||||
|
||||
# 优化器步骤 - 当使用DeepSpeed时,它会自动处理梯度累积和梯度裁剪
|
||||
@ -339,40 +426,58 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
||||
optimizer.zero_grad()
|
||||
|
||||
# 计时优化器步骤结束 (只在主进程进行)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
if args.profile and accelerator.is_main_process and optimizer_end is not None:
|
||||
optimizer_end.record()
|
||||
|
||||
# 打印训练信息 (只在主进程进行)
|
||||
if (step + 1) % args.log_interval == 0 and accelerator.is_main_process:
|
||||
current_time = time.time()
|
||||
# 计算性能指标
|
||||
if args.profile:
|
||||
torch.cuda.synchronize()
|
||||
# 使用自上次日志以来的时间计算性能指标,而不是总时间
|
||||
data_time = data_start.elapsed_time(data_end)
|
||||
forward_time = forward_start.elapsed_time(forward_end)
|
||||
backward_time = backward_start.elapsed_time(backward_end)
|
||||
optimizer_time = optimizer_start.elapsed_time(optimizer_end)
|
||||
iter_time = (current_time - last_log_time) * 1000 / args.log_interval # avg ms per iteration since last log
|
||||
# total_time_ms = data_time + forward_time + backward_time + optimizer_time
|
||||
|
||||
# 打印性能分析
|
||||
if (step + 1) % (args.log_interval * args.profile_interval) == 0:
|
||||
Logger(f"性能分析 (Avg/iter over last {args.log_interval} steps) - "
|
||||
f"Data: {data_time/args.log_interval:.2f}ms, "
|
||||
f"Fwd: {forward_time/args.log_interval:.2f}ms, "
|
||||
f"Bwd: {backward_time/args.log_interval:.2f}ms, "
|
||||
f"Optim: {optimizer_time/args.log_interval:.2f}ms, "
|
||||
f"Iter Time: {iter_time:.2f}ms", accelerator)
|
||||
# 重置事件以便下次测量从0开始
|
||||
data_start = torch.cuda.Event(enable_timing=True)
|
||||
data_end = torch.cuda.Event(enable_timing=True)
|
||||
forward_start = torch.cuda.Event(enable_timing=True)
|
||||
forward_end = torch.cuda.Event(enable_timing=True)
|
||||
backward_start = torch.cuda.Event(enable_timing=True)
|
||||
backward_end = torch.cuda.Event(enable_timing=True)
|
||||
optimizer_start = torch.cuda.Event(enable_timing=True)
|
||||
optimizer_end = torch.cuda.Event(enable_timing=True)
|
||||
# 记录日志输出时的详细内存状态
|
||||
if args.memory_monitor:
|
||||
log_memory_status(step, prefetch_batches, accelerator, "at_log_interval", detailed=True)
|
||||
|
||||
# 强制垃圾回收并记录内存变化
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
log_memory_status(step, prefetch_batches, accelerator, "after_gc", detailed=True)
|
||||
|
||||
# 计算性能指标
|
||||
if args.profile and accelerator.is_main_process:
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# 确保所有事件都已记录才计算elapsed_time
|
||||
try:
|
||||
data_time = data_start.elapsed_time(data_end) if data_start is not None and data_end is not None else 0
|
||||
forward_time = forward_start.elapsed_time(forward_end) if forward_start is not None and forward_end is not None else 0
|
||||
backward_time = backward_start.elapsed_time(backward_end) if backward_start is not None and backward_end is not None else 0
|
||||
optimizer_time = optimizer_start.elapsed_time(optimizer_end) if optimizer_start is not None and optimizer_end is not None else 0
|
||||
iter_time = (current_time - last_log_time) * 1000 / args.log_interval # avg ms per iteration since last log
|
||||
# total_time_ms = data_time + forward_time + backward_time + optimizer_time
|
||||
|
||||
# 打印性能分析
|
||||
if (step + 1) % (args.log_interval * args.profile_interval) == 0:
|
||||
Logger(f"性能分析 (Avg/iter over last {args.log_interval} steps) - "
|
||||
f"Data: {data_time/args.log_interval:.2f}ms, "
|
||||
f"Fwd: {forward_time/args.log_interval:.2f}ms, "
|
||||
f"Bwd: {backward_time/args.log_interval:.2f}ms, "
|
||||
f"Optim: {optimizer_time/args.log_interval:.2f}ms, "
|
||||
f"Iter Time: {iter_time:.2f}ms", accelerator)
|
||||
# 重置事件以便下次测量从0开始
|
||||
data_start = torch.cuda.Event(enable_timing=True)
|
||||
data_end = torch.cuda.Event(enable_timing=True)
|
||||
forward_start = torch.cuda.Event(enable_timing=True)
|
||||
forward_end = torch.cuda.Event(enable_timing=True)
|
||||
backward_start = torch.cuda.Event(enable_timing=True)
|
||||
backward_end = torch.cuda.Event(enable_timing=True)
|
||||
optimizer_start = torch.cuda.Event(enable_timing=True)
|
||||
optimizer_end = torch.cuda.Event(enable_timing=True)
|
||||
except RuntimeError as e:
|
||||
if "Both events must be recorded" in str(e):
|
||||
Logger(f"Warning: CUDA events not properly recorded, skipping performance analysis: {e}", accelerator)
|
||||
else:
|
||||
raise e
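# --- Timing-pattern sketch (not part of the commit): why the events are re-created each window and
# why the except branch above guards elapsed_time(). Both events of a pair must have been recorded
# and the device synchronized before elapsed_time() is valid; otherwise it raises the RuntimeError
# handled above.
def _sketch_cuda_event_timing():
    import torch
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    x = torch.randn(1024, 1024, device="cuda")
    start.record()
    y = x @ x                       # stand-in for real kernel work
    end.record()
    torch.cuda.synchronize()        # both events must have completed before elapsed_time()
    return start.elapsed_time(end)  # elapsed milliseconds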
|
||||
|
||||
|
||||
# 计算当前学习率
|
||||
@ -413,12 +518,12 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
||||
f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
|
||||
f"Total Time Left: {format_time(total_remaining_time)}", accelerator)
|
||||
|
||||
if args.use_wandb and accelerator.is_main_process and wandb:
|
||||
wandb.log(log_dict)
|
||||
if args.use_swanlab and accelerator.is_main_process and swanlab_run:
|
||||
swanlab_run.log(log_dict)
|
||||
|
||||
# 保存模型 (只在主进程进行)
|
||||
loss_total = loss.item() * args.accumulation_steps
|
||||
if best_loss > loss_total and accelerator.is_main_process:
|
||||
if epoch > 1 and best_loss > loss_total and accelerator.is_main_process:
|
||||
best_loss = loss_total
|
||||
# 使用函数开始处定义的moe_path变量
|
||||
ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
|
||||
@ -432,20 +537,45 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
||||
|
||||
except Exception as e:
|
||||
Logger(f"Error in training step: {e}", accelerator)
|
||||
# 记录异常时的内存状态
|
||||
if args.memory_monitor:
|
||||
log_memory_status(step, prefetch_batches, accelerator, "at_exception", detailed=True)
|
||||
import traceback
|
||||
Logger(traceback.format_exc(), accelerator)
|
||||
|
||||
# 清理prefetch_batches,防止内存泄漏
|
||||
if args.memory_monitor and accelerator.is_main_process:
|
||||
Logger(f"[Memory Monitor] Clearing prefetch_batches due to exception. Current length: {len(prefetch_batches)}", accelerator)
|
||||
prefetch_batches.clear()
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
if args.memory_monitor:
|
||||
log_memory_status(step, prefetch_batches, accelerator, "after_exception_cleanup", detailed=True)
|
||||
|
||||
# 训练epoch结束时清理prefetch_batches
|
||||
if args.memory_monitor:
|
||||
if accelerator.is_main_process:
|
||||
Logger(f"[Memory Monitor] Epoch {epoch+1} finished. Clearing prefetch_batches. Final length: {len(prefetch_batches)}", accelerator)
|
||||
log_memory_status(total_steps_in_epoch-1, prefetch_batches, accelerator, "before_epoch_end_cleanup", detailed=True)
|
||||
prefetch_batches.clear()
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
if args.memory_monitor:
|
||||
log_memory_status(total_steps_in_epoch-1, prefetch_batches, accelerator, "after_epoch_end_cleanup", detailed=True)
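# --- Pattern sketch (not part of the commit): the manual prefetch queue above holds full CPU batches,
# so anything left in it at the end of an epoch (or after an exception) keeps tensors alive. Clearing
# the list and collecting, as done above, is what releases that memory. Minimal standalone version:
def _sketch_prefetch_loop(train_loader, prefetch_factor=2):
    import gc
    import torch
    data_iter = iter(train_loader)
    prefetch = []
    for _ in range(prefetch_factor):
        try:
            prefetch.append(next(data_iter))
        except StopIteration:
            break
    for _ in range(len(train_loader)):
        batch = prefetch.pop(0) if prefetch else next(data_iter)
        # ... forward / backward / optimizer step on `batch` ...
        try:
            prefetch.append(next(data_iter))  # keep the queue topped up
        except StopIteration:
            pass
    prefetch.clear()                          # drop the last cached batches
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()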
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
|
||||
parser.add_argument("--out_dir", type=str, default="out")
|
||||
parser.add_argument("--epochs", type=int, default=4)
|
||||
parser.add_argument("--embedding_epoch", type=int, default=2, help="embedding训练的epoch数")
|
||||
parser.add_argument("--batch_size", type=int, default=64)
|
||||
parser.add_argument("--batch_size", type=int, default=128)
|
||||
parser.add_argument("--learning_rate", type=float, default=2e-4)
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", default=True, action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
parser.add_argument("--use_swanlab", default=True, action="store_true") # 替换wandb参数
|
||||
parser.add_argument("--swanlab_project", type=str, default="MiniMind-Pretrain") # 替换wandb参数
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
parser.add_argument("--accumulation_steps", type=int, default=32)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0)
|
||||
parser.add_argument("--warmup_iters", type=int, default=0)
|
||||
@ -456,17 +586,19 @@ def main():
|
||||
parser.add_argument('--max_seq_len', default=512, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能,使用固定值1e-4替代")
|
||||
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
|
||||
parser.add_argument("--data_path", type=str, default="./dataset/merged_pretrain.jsonl")
|
||||
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
|
||||
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
|
||||
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
|
||||
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
|
||||
parser.add_argument("--knowledge_num", type=int, default=8192,help="知识库的数据数目")
|
||||
parser.add_argument("--knowledge_num", type=int, default=960400,help="知识库的数据数目")
|
||||
parser.add_argument("--knowledge_length", type=int, default=32,help="知识库的句子长度")
|
||||
parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径")
|
||||
parser.add_argument("--database_init_path", type=str, default="./dataset/combined_prepare.json", help="数据库初始化路径")
|
||||
parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")
|
||||
parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="聚类结果缓存文件路径")
|
||||
parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件")
|
||||
parser.add_argument("--memory_monitor", action="store_true", default=False, help="启用内存监控")
|
||||
parser.add_argument("--memory_monitor_interval", type=int, default=10, help="内存监控间隔(步数)")
|
||||
args = parser.parse_args()
|
||||
|
||||
#########################################################
|
||||
@ -479,7 +611,7 @@ def main():
|
||||
gradient_accumulation_steps=args.accumulation_steps,
|
||||
gradient_clipping=args.grad_clip,
|
||||
zero_stage=2, # 使用ZeRO-2优化
|
||||
offload_optimizer_device="cpu", # 将优化器状态卸载到CPU
|
||||
offload_optimizer_device="none", # 将优化器状态卸载到CPU
|
||||
offload_param_device="none", # 不将参数卸载到CPU
|
||||
)
|
||||
accelerator = Accelerator(
|
||||
@ -523,18 +655,30 @@ def main():
|
||||
|
||||
|
||||
#########################################################
|
||||
# 配置wandb
|
||||
# 配置SwanLab
|
||||
#########################################################
|
||||
# 设置wandb运行名称
|
||||
args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
if args.use_wandb and accelerator.is_main_process:
|
||||
import wandb
|
||||
# 合并args和lm_config为一个字典
|
||||
config_dict = vars(args).copy()
|
||||
config_dict.update(vars(lm_config))
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config_dict)
|
||||
# 设置SwanLab运行名称
|
||||
args.swanlab_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
|
||||
# 合并args和lm_config为一个字典(无论是否使用SwanLab都需要,用于打印配置信息)
|
||||
config_dict = vars(args).copy()
|
||||
config_dict.update(vars(lm_config))
|
||||
|
||||
# 初始化SwanLab实验实例
|
||||
swanlab_run = None
|
||||
if args.use_swanlab and accelerator.is_main_process:
|
||||
# 初始化SwanLab
|
||||
swanlab_run = swanlab.init(
|
||||
project=args.swanlab_project,
|
||||
experiment_name=args.swanlab_run_name,
|
||||
description="MiniMind预训练实验,使用本地部署的SwanLab进行可视化",
|
||||
config=config_dict
|
||||
# 设置SwanLab服务器地址和API Key
|
||||
# host="http://100.123.118.114:11071",
|
||||
# api_key="LesBT7HRq23HNBrOPKP8S"
|
||||
)
|
||||
else:
|
||||
wandb = None
|
||||
swanlab_run = None
|
||||
|
||||
#########################################################
|
||||
# 打印信息
|
||||
@ -616,13 +760,31 @@ def main():
|
||||
#########################################################
|
||||
overall_start_time = time.time() # Record overall start time
|
||||
for epoch in range(args.epochs):
|
||||
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb) # Pass overall start time
|
||||
Logger(f"开始第{epoch+1}轮训练", accelerator)
|
||||
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run) # Pass overall start time
|
||||
|
||||
# 每个epoch结束后进行内存清理
|
||||
Logger(f"第{epoch+1}轮训练完成,进行内存清理", accelerator)
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# 记录epoch结束时的内存状态
|
||||
if accelerator.is_main_process:
|
||||
memory_info = get_memory_usage()
|
||||
cuda_info = get_cuda_memory_usage()
|
||||
log_msg = f"[Memory Monitor] Epoch {epoch+1} completed - "
|
||||
log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB"
|
||||
if cuda_info:
|
||||
log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
|
||||
log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB"
|
||||
Logger(log_msg, accelerator)
|
||||
|
||||
#########################################################
|
||||
# 关闭wandb
|
||||
# 关闭SwanLab
|
||||
#########################################################
|
||||
if args.use_wandb and accelerator.is_main_process:
|
||||
wandb.finish()
|
||||
if args.use_swanlab and accelerator.is_main_process and swanlab_run:
|
||||
swanlab_run.finish()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|