Compare commits
No commits in common. "master" and "Gary_Lu" have entirely different histories.
.gitignore (vendored) — 9 lines changed
@@ -2,11 +2,4 @@
/dataset
/out
wandb/
**/*.log
models/sentence_transformers/
models/sentence_transformers_cache/
**/*.pyc
qwen2-1.7B/
images/
cache/
.venv/
**/*.log
.vscode/launch.json (vendored) — 124 lines removed
@@ -1,124 +0,0 @@
{
    "version": "0.2.0",
    "configurations": [
        {
            "name": "MiniMind Training (Direct Python)",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/train_pretrain_accelerate.py",
            "args": [
                "--out_dir", "out",
                "--epochs", "3",
                "--embedding_epoch", "2",
                "--batch_size", "128",
                "--learning_rate", "8e-5",
                "--dtype", "bfloat16",
                "--use_swanlab",
                "--swanlab_project", "MiniMind-Pretrain",
                "--num_workers", "1",
                "--accumulation_steps", "16",
                "--grad_clip", "0.5",
                "--warmup_iters", "0",
                "--log_interval", "1",
                "--save_interval", "10000",
                "--dim", "512",
                "--n_layers", "8",
                "--max_seq_len", "512",
                "--data_path", "./dataset/stable/merged_pretrain.jsonl",
                "--profile",
                "--profile_interval", "10",
                "--use_flash_attn",
                "--knowledge_num", "1048576",
                "--knowledge_length", "32",
                "--database_init_path", "./dataset/stable/sentence_trex_data.json",
                "--fast_clustering",
                "--cluster_cache_path", "./cache/cluster_tokens_single.pt",
                "--memory_monitor_interval", "10",
                "--model_type", "model",
                "--model_size", "538"
            ],
            "env": {
                "CUDA_VISIBLE_DEVICES": "0",
                "NCCL_DEBUG": "INFO",
                "PYTHONFAULTHANDLER": "1"
            },
            "cwd": "${workspaceFolder}",
            "console": "integratedTerminal",
            "justMyCode": false,
            "stopOnEntry": false,
            "python": "${workspaceFolder}/.venv/bin/python"
        },
        {
            "name": "MiniMind Training (Direct Python - Simple)",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/train_pretrain_accelerate.py",
            "args": [
                "--epochs", "1",
                "--batch_size", "32",
                "--learning_rate", "1e-4",
                "--log_interval", "10",
                "--profile_interval", "2",
                "--model_type", "model_original"
            ],
            "env": {
                "CUDA_VISIBLE_DEVICES": "0"
            },
            "cwd": "${workspaceFolder}",
            "console": "integratedTerminal",
            "justMyCode": false,
            "stopOnEntry": false,
            "python": "${workspaceFolder}/.venv/bin/python"
        },
        {
            "name": "MiniMind Test (Direct Python)",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/test.py",
            "env": {
                "CUDA_VISIBLE_DEVICES": "0"
            },
            "cwd": "${workspaceFolder}",
            "console": "integratedTerminal",
            "justMyCode": false,
            "python": "${workspaceFolder}/.venv/bin/python"
        },
        {
            "name": "MiniMind Training Debug (Accelerate)",
            "type": "python",
            "request": "launch",
            "module": "accelerate.commands.launch",
            "args": [
                "--num_processes=1",
                "--mixed_precision=bf16",
                "${workspaceFolder}/train_pretrain_accelerate.py",
                "--epochs", "1",
                "--batch_size", "32",
                "--learning_rate", "1e-4",
                "--log_interval", "10",
                "--profile_interval", "2",
                "--model_type", "model_original"
            ],
            "env": {
                "CUDA_VISIBLE_DEVICES": "0"
            },
            "cwd": "${workspaceFolder}",
            "console": "integratedTerminal",
            "justMyCode": false,
            "stopOnEntry": false,
            "python": "${workspaceFolder}/.venv/bin/python"
        },
        {
            "name": "MiniMind Test Only",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/test.py",
            "env": {
                "CUDA_VISIBLE_DEVICES": "0"
            },
            "cwd": "${workspaceFolder}",
            "console": "integratedTerminal",
            "justMyCode": false
        }
    ]
}
.vscode/settings.json (vendored) — 18 lines removed
@@ -1,18 +0,0 @@
{
    "python.pythonPath": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
    "python.defaultInterpreterPath": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
    "python.terminal.activateEnvironment": true,
    "python.terminal.activateEnvInCurrentTerminal": true,
    "python.linting.enabled": true,
    "python.linting.pylintEnabled": false,
    "python.linting.flake8Enabled": true,
    "python.formatting.provider": "black",
    "python.analysis.autoImportCompletions": true,
    "python.analysis.typeCheckingMode": "off",
    "files.exclude": {
        "**/__pycache__": true,
        "**/*.pyc": true,
        "**/.git": false,
        "**/wandb": false
    }
}
README.md — 199 lines removed
@@ -1,199 +0,0 @@
<div align="center">



</div>

<div align="center">


[](https://github.com/jingyaogong/minimind/stargazers)
[](LICENSE)
[](https://github.com/jingyaogong/minimind/commits/master)
[](https://github.com/jingyaogong/minimind/pulls)
[](https://huggingface.co/collections/jingyaogong/minimind-66caf8d999f5c7fa64f399e5)

</div>


# 📌 Data Introduction

## Ⅰ Tokenizer

A tokenizer maps words of natural language to numbers such as `0, 1, 36` via a "dictionary"; you can think of the number as the page on which that word sits in the dictionary.
You can build your own vocabulary and train a tokenizer yourself (see `./scripts/train_tokenizer.py`; it is for study only — unless strictly necessary there is no need to retrain, since MiniMind ships with its own tokenizer).
Alternatively, you can adopt the tokenizer of a well-known open-source model.
Using a ready-made "dictionary" like Xinhua's or Oxford's gives an excellent token compression ratio, but such dictionaries are huge, easily hundreds of thousands of words and phrases;
a self-trained tokenizer gives full control over the vocabulary's length and contents, at the cost of a low compression ratio (for example, "hello" might be split into five separate tokens, "h e l l o") and weak coverage of rare words.
The choice of "dictionary" matters: an LLM's output is essentially a SoftMax multi-class classification over the dictionary's N words, decoded back to natural language through that "dictionary".
Because MiniMind's size must be strictly controlled, and to keep the model from being top-heavy (the embedding layer taking too large a share of the LLM's parameters), the shorter the vocabulary the better.

<details style="color:rgb(128,128,128)">
<summary>About the tokenizer</summary>

Vocabulary sizes of powerful third-party open-source tokenizers such as Yi, qwen, chatglm, mistral, and Llama3:

<table>
<tr><th>Tokenizer model</th><th>Vocab size</th><th>Source</th></tr>
<tr><td>yi tokenizer</td><td>64,000</td><td>01.AI (China)</td></tr>
<tr><td>qwen2 tokenizer</td><td>151,643</td><td>Alibaba Cloud (China)</td></tr>
<tr><td>glm tokenizer</td><td>151,329</td><td>Zhipu AI (China)</td></tr>
<tr><td>mistral tokenizer</td><td>32,000</td><td>Mistral AI (France)</td></tr>
<tr><td>llama3 tokenizer</td><td>128,000</td><td>Meta (USA)</td></tr>
<tr><td>minimind tokenizer</td><td>6,400</td><td>Custom</td></tr>
</table>

> 👉 Update 2024-09-17: to avoid ambiguity with older versions and to control size, all MiniMind models use the minimind_tokenizer; every mistral_tokenizer version is deprecated.

```
# Some notes to self
> Although the minimind_tokenizer is small and its encoding/decoding efficiency is weaker than that of Chinese-friendly tokenizers such as qwen2 or glm,
> the MiniMind models use the self-trained minimind_tokenizer to keep the overall parameters light and avoid an imbalance between the encoding layer and the compute layers (a top-heavy model), since MiniMind's vocabulary has only 6,400 entries.
> In practical tests MiniMind has never failed to decode a rare word; the results are good.
> Because the custom vocabulary is compressed to 6,400 entries, the LLM's total parameter count can be as low as 25.8M.
> The training data `tokenizer_train.jsonl` comes entirely from the "Jiangshu large-model dataset"; this data is relatively minor, and you may pick other data freely if you want to train one.
```

</details>
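
To make the vocabulary figures above concrete, here is a minimal sketch of loading the bundled tokenizer and inspecting it; the `./model/minimind_tokenizer` path is an assumption about the repository layout, not something this diff confirms.

```python
# Minimal sketch (assumed path): load the MiniMind tokenizer and inspect it.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")
print(len(tokenizer))                    # expected to print 6400, per the table above
print(tokenizer.tokenize("hello 你好"))  # see how raw text splits into tokens
```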

## Ⅱ Pretraining Data

After the lesson of MiniMind-V1, where low-quality pretraining data left the model babbling, we decided after `2025-02-05` to stop using large-scale unsupervised datasets for pretraining.
Instead, the Chinese portion of the [Jiangshu large-model dataset](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data) was extracted,
cleaned to keep only entries shorter than 512 characters, and concatenated into roughly 1.6GB of pretraining data, `pretrain_hq.jsonl`. The "hq" stands for high
quality (not truly high yet — improving data quality never ends).

The format of `pretrain_hq.jsonl` is:

```bash
{"text": "如何才能摆脱拖延症? 治愈拖延症并不容易,但以下建议可能有所帮助..."}
```

## Ⅲ SFT Data

The [Jiangshu large-model SFT dataset](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
"is a complete, uniformly formatted, and safe resource for large-model training and research.
It collects and curates a large number of open-source datasets from public web sources, unifies their format, and cleans the data;
it contains a Chinese dataset of 10M entries and an English dataset of 2M entries."
That is the official description. After downloading, the total comes to roughly 4B tokens, which is certainly suitable as SFT data for a Chinese LLM.
However, the official format is quite messy, and using all of it for SFT would be too costly.
I re-cleaned the official dataset, removing entries with symbol pollution and noise; again, only content with total length under 512
characters was kept. The goal at this stage is to supplement, through a large volume of dialogue, the knowledge missing from pretraining.
The exported file is `sft_512.jsonl` (~7.5GB).

The [Magpie-SFT dataset](https://www.modelscope.cn/organization/Magpie-Align)
collects ~1M high-quality dialogues generated by Qwen2/2.5. I cleaned this data further and exported the entries with total length under 2048 as `sft_2048.jsonl` (~9GB),
and those under 1024 as `sft_1024.jsonl` (~5.5GB). Doing SFT directly on large-model dialogue data falls under "black-box distillation".

The data from the first two SFT steps was then cleaned once more (keeping only content with a high share of Chinese characters), and dialogues shorter than 512 characters were selected to produce `sft_mini_512.jsonl` (~1.2GB).

All SFT files `sft_X.jsonl` share the format:

```text
{
    "conversations": [
        {"role": "user", "content": "你好"},
        {"role": "assistant", "content": "你好!"},
        {"role": "user", "content": "再见"},
        {"role": "assistant", "content": "再见!"}
    ]
}
```
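
To illustrate how a `conversations` record becomes trainable text, the sketch below renders one record with the tokenizer's built-in chat template (the template string itself appears in `tokenizer_config.json` further down in this diff); both file paths are assumptions.

```python
# Minimal sketch: render one SFT record with the tokenizer's chat template.
import json
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")  # assumed path

with open("./dataset/sft_mini_512.jsonl", "r", encoding="utf-8") as f:   # assumed path
    sample = json.loads(f.readline())

# tokenize=False returns the formatted string instead of token ids
text = tokenizer.apply_chat_template(sample["conversations"], tokenize=False)
print(text)
```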

## Ⅳ RLHF Data

From the [Magpie-DPO dataset](https://www.modelscope.cn/datasets/Magpie-Align/MagpieLM-DPO-Data-v0.1):
roughly 200k preference entries (all in English) generated by Llama3.1-70B/8B. They can be used to train a reward model and optimize response quality so it better matches human preferences.
Here the entries with total length under 3000 are repackaged as `dpo.jsonl` (~0.9GB), with two fields, `chosen` and `rejected`: `chosen`
holds the preferred reply and `rejected` the rejected one.

The format of `dpo.jsonl` is:

```text
{
    "chosen": [
        {"content": "Q", "role": "user"},
        {"content": "good answer", "role": "assistant"}
    ],
    "rejected": [
        {"content": "Q", "role": "user"},
        {"content": "bad answer", "role": "assistant"}
    ]
}
```
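
A minimal sketch of streaming these preference pairs (the path is an assumption):

```python
# Minimal sketch: iterate over chosen/rejected preference pairs in dpo.jsonl.
import json

with open("./dataset/dpo.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        pair = json.loads(line)
        chosen, rejected = pair["chosen"], pair["rejected"]
        # Each side is a list of {"role", "content"} messages ending with an assistant reply.
        assert chosen[-1]["role"] == "assistant"
        assert rejected[-1]["role"] == "assistant"
```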

## Ⅴ Reasoning Dataset

Frankly, nothing could out-trend DeepSeek in February 2025...
It also sparked my strong interest in RL-guided reasoning models; I have already reproduced R1-Zero with Qwen2.5.
If time permits and it works (though 99% likely the base model's capability is insufficient), I will later update MiniMind with a reasoning model trained via RL rather than distillation.
Time is limited, and the fastest low-cost route remains direct (black-box) distillation.
R1 is so popular that within days several R1 distillation datasets appeared: [R1-Llama-70B](https://www.modelscope.cn/datasets/Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B), [R1-Distill-SFT](https://www.modelscope.cn/datasets/AI-ModelScope/R1-Distill-SFT),
[Alpaca-Distill-R1](https://huggingface.co/datasets/shareAI/Alpaca-Distill-R1-ZH),
[deepseek_r1_zh](https://huggingface.co/datasets/jinliuxi/deepseek_r1_zh), and so on; purely Chinese data is still rather scarce.
They were finally merged and exported as `r1_mix_1024.jsonl`, with the same format as `sft_X.jsonl`.

## Ⅵ More Datasets

[HqWu-HITCS/Awesome-Chinese-LLM](https://github.com/HqWu-HITCS/Awesome-Chinese-LLM)
is already collecting and organizing open-source models, applications, datasets, and tutorials around Chinese LLMs, and keeps tracking the latest progress in this area. Comprehensive and professional — respect!

---

## Ⅷ Dataset Download

> [!NOTE]
> Since 2025-02-05, all datasets used for MiniMind's final training are open-sourced, so there is no need to preprocess large-scale data yourself; this avoids repetitive data-processing work.

MiniMind training datasets ([ModelScope](https://www.modelscope.cn/datasets/gongjy/minimind-dataset/files) | [HuggingFace](https://huggingface.co/datasets/jingyaogong))

> No need to clone everything; download only the files you need

Place the downloaded dataset files under `./dataset/` (✨ marks the recommended must-haves):

```bash
./dataset/
├── dpo.jsonl (909MB)
├── lora_identity.jsonl (22.8KB)
├── lora_medical.jsonl (34MB)
├── pretrain_hq.jsonl (1.6GB, ✨)
├── r1_mix_1024.jsonl (340MB)
├── sft_1024.jsonl (5.6GB)
├── sft_2048.jsonl (9GB)
├── sft_512.jsonl (7.5GB)
├── sft_mini_512.jsonl (1.2GB, ✨)
└── tokenizer_train.jsonl (1GB)
```

<details style="color:rgb(128,128,128)">
<summary>Note: dataset overview</summary>

* `dpo.jsonl` --dataset for the RLHF stage
* `lora_identity.jsonl` --self-identity dataset (e.g. "Who are you?" "I am minimind..."), recommended for LoRA training (also usable for full-parameter SFT; don't be limited by the name)
* `lora_medical.jsonl` --medical Q&A dataset, recommended for LoRA training (also usable for full-parameter SFT; don't be limited by the name)
* `pretrain_hq.jsonl`✨ --pretraining dataset, curated from Jiangshu Technology
* `r1_mix_1024.jsonl` --DeepSeek-R1-1.5B distillation data, at most 1024 characters per entry (so set max_seq_len=1024 for training)
* `sft_1024.jsonl` --curated Qwen2.5 distillation data (a subset of sft_2048), at most 1024 characters per entry (so set max_seq_len=1024 for training)
* `sft_2048.jsonl` --curated Qwen2.5 distillation data, at most 2048 characters per entry (so set max_seq_len=2048 for training)
* `sft_512.jsonl` --curated Jiangshu SFT data, at most 512 characters per entry (so set max_seq_len=512 for training)
* `sft_mini_512.jsonl`✨ --minimal mix of Jiangshu SFT data + Qwen2.5 distillation data (for quickly training a Zero model), at most 512 characters per entry (so set max_seq_len=512 for training)
* `tokenizer_train.jsonl` --entirely from the "Jiangshu large-model dataset"; this part is relatively minor (retraining the tokenizer yourself is not recommended, for the reasons above), but you may pick any dataset if you do want to train one.

</details>



<details style="color:rgb(128,128,128)">
<summary>Notes & recommended training plans</summary>

* The MiniMind2 series was trained on roughly 20GB of corpus in total, about 4B tokens, i.e. the data combination above (cost: 💰💰💰💰💰💰💰💰, result: 😊😊😊😊😊😊)

* For the fastest from-scratch Zero model, the `pretrain_hq.jsonl` + `sft_mini_512.jsonl` combination is recommended; see the tables below for cost and results (cost: 💰, result: 😊😊)

* Those with some compute or who care more about quality should consider the former, a full MiniMind2 reproduction; those with a single GPU or who want a quick reproduction are strongly advised to take the latter;

* [Middle ground] You can also freely combine medium-scale data such as `sft_mini_512.jsonl` and `sft_1024.jsonl` (cost: 💰💰💰, result: 😊😊😊😊).

</details>
README_accelerate.md — new file, 126 lines
@@ -0,0 +1,126 @@
# Distributed Training with Accelerate + DeepSpeed

This document describes how to train the MiniMind model in a distributed setup with Accelerate and DeepSpeed.

## Environment Setup

First, make sure the required dependencies are installed:

```bash
pip install accelerate deepspeed
```

## Configuration Files

### 1. DeepSpeed configuration file (ds_config.json)

The DeepSpeed configuration file defines the optimizer, the learning-rate scheduler, and the ZeRO optimization parameters. The main settings are (an illustrative sketch follows this list):

- **ZeRO optimization**: uses ZeRO-2, which reduces GPU memory usage
- **Optimizer settings**: uses the AdamW optimizer
- **Mixed-precision training**: supports FP16 and BF16
- **Gradient accumulation**: set to "auto" so it stays consistent with the training-script argument
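
Below is an illustrative sketch of such a file, not the repository's actual ds_config.json; apart from the four points above, every value shown is an assumption:

```json
{
    "zero_optimization": { "stage": 2 },
    "optimizer": { "type": "AdamW", "params": { "lr": "auto" } },
    "fp16": { "enabled": "auto" },
    "bf16": { "enabled": "auto" },
    "gradient_accumulation_steps": "auto",
    "train_micro_batch_size_per_gpu": "auto"
}
```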

### 2. Accelerate configuration file (accelerate_config.yaml)

The Accelerate configuration file defines the basic distributed-training settings (an illustrative sketch follows this list), including:

- **Distributed type**: uses DeepSpeed
- **Mixed precision**: uses BF16
- **Number of processes**: set to 4 (adjust to the number of GPUs)
- **DeepSpeed config**: points to the ds_config.json file
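
An illustrative sketch covering the four points above; the field names follow Accelerate's config schema, and anything beyond those points is an assumption:

```yaml
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
mixed_precision: bf16
num_processes: 4
deepspeed_config:
  deepspeed_config_file: ds_config.json
```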
## Training Script

The new training script `train_pretrain_accelerate.py` is adapted from the original `train_pretrain.py`. The main changes are:

1. Uses the Accelerator in place of PyTorch's native distributed functionality
2. Removes the torchrun-related distributed initialization code
3. Uses the Accelerator API to prepare the model, optimizer, and data loaders
4. Uses the Accelerator API for backpropagation and gradient clipping
5. Handles the positional-encoding and unused-parameter issues

## Launching Training

There are two ways to launch training:

### Method 1: use a pre-built accelerate configuration file

```bash
accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
    --epochs 3 \
    --batch_size 24 \
    --learning_rate 2e-4 \
    --dtype bfloat16 \
    --accumulation_steps 32 \
    --grad_clip 1.0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 1024 \
    --n_layers 32 \
    --max_seq_len 1024 \
    --use_flash_attn \
    --profile \
    --profile_interval 10
```

### Method 2: configure accelerate directly via command-line arguments

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
    --multi_gpu \
    --num_processes=4 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    --deepspeed_config_file ds_config.json \
    train_pretrain_accelerate.py \
    --epochs 3 \
    --batch_size 24 \
    --learning_rate 2e-4 \
    --dtype bfloat16 \
    --accumulation_steps 32 \
    --grad_clip 1.0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 1024 \
    --n_layers 32 \
    --max_seq_len 1024 \
    --use_flash_attn \
    --profile \
    --profile_interval 10
```

You can also use the provided script directly:

```bash
bash run_accelerate.sh
```

## How the Accelerate and DeepSpeed Configurations Relate

1. **Accelerate** is a high-level API that simplifies setting up and launching distributed training; it can work with several distributed backends (DeepSpeed, FSDP, etc.).

2. **DeepSpeed** is an optimization library focused on memory optimization and performance for large-scale model training, providing features such as ZeRO.

3. **How the configs relate**:
    - The Accelerate config file (YAML) defines which distributed backend to use and the basic distributed settings
    - The DeepSpeed config file (JSON) defines DeepSpeed-specific optimization parameters
    - Accelerate references the DeepSpeed config file through the `deepspeed_config_file` parameter

## Notes

1. **Positional-encoding handling**:
    - In the model, `pos_cis` is a complex tensor that needs special treatment in distributed training
    - The new training script uses the Accelerator API to handle this, so `_ddp_params_and_buffers_to_ignore` is no longer needed

2. **Unused-parameter handling**:
    - The original code used `find_unused_parameters=True` to handle unused parameters
    - The new training script uses the Accelerator API directly, which handles this automatically

3. **Mixed-precision training**:
    - `fp16` and `bf16` in the DeepSpeed config file are set to `"auto"`
    - The precision actually used is decided by Accelerate's `--mixed_precision` argument

4. **Gradient accumulation**:
    - `gradient_accumulation_steps` in the DeepSpeed config file is set to `"auto"`
    - The actual number of accumulation steps is decided by the training script's `--accumulation_steps` argument
ReadMe.md — new file, 22 lines
@@ -0,0 +1,22 @@
## Environment Setup

1. Create a conda environment

```bash
conda create -n accelerate python=3.10
conda activate accelerate
```

2. Install torch, torchvision, and torchaudio matching your system's CUDA version (see the example after this list)

3. Install accelerate and deepspeed matching the torch and torchvision in the environment

4. Install the remaining packages

```bash
pip install -r requirements.txt
```
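
For example, step 2 might look like the following; the `cu121` wheel index is an assumption, so substitute the index matching your installed CUDA version:

```bash
# Example only: adjust cu121 to your installed CUDA version
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
```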

## Modifying the Model

1. In general, only the files in the `model` folder need to be changed

## Running

1. On a 4090 or 4070 Ti, run `bash run_file/DynamicKV-LLM_Mini_Minimind.sh`
2. On four A800s, run `bash run_file/DynamicKV-LLM_Small_Minimind.sh`
@@ -1,26 +0,0 @@
# 1. Metadata: edit this — give the experiment a name and description
name: ycz-minimind-test
description: 测试minimind-test

# 2. Runtime environment: usually unchanged; replace with a specific image manually if needed
environment:
  image: determinedai/pytorch-ngc:0.38.0  # no need to change

# 3. Dataset on the NAS: edit this — only the bind_mounts field; container_path and read_only stay unchanged
# Replace <YOUR_DATASET_FOLDER_NAME> with the name of your dataset folder stored under Volume1/Share/datasets/ on the NAS
# Please double-check that the <YOUR_DATASET_FOLDER_NAME> dataset is present under Volume1/Share/datasets/ on the NAS


# 4. Compute resources: no need to change
resources:
  slots_per_trial: 1  # no need to change
  resource_pool: rtx4090  # no need to change

# 5. Searcher: no need to change
searcher:
  name: single
  metric: test_accuracy
  smaller_is_better: false

# 6. Entrypoint: no need to change
entrypoint: sh startup.sh
main.py — 6 lines removed
@@ -1,6 +0,0 @@
def main():
    print("Hello from minimind!")


if __name__ == "__main__":
    main()
@@ -19,7 +19,6 @@ class LMConfig(PretrainedConfig):
            rope_theta: int = 1e6,
            dropout: float = 0.0,
            flash_attn: bool = True,
            embeddings_epoch: int = 2,
            ####################################################
            # DB related configurations
            ####################################################
@@ -40,13 +39,6 @@ class LMConfig(PretrainedConfig):
            ####################################################
            knowledge_num: int = 64*64,
            knowledge_length: int = 8,
            knowledge_dim: int = 128,
            ####################################################
            # Triple extraction related configurations
            ####################################################
            max_subject_len: int = 8,
            max_predicate_len: int = 4,
            max_object_len: int = 8,
            **kwargs,
    ):
        self.dim = dim
@@ -61,7 +53,6 @@ class LMConfig(PretrainedConfig):
        self.rope_theta = rope_theta
        self.dropout = dropout
        self.flash_attn = flash_attn
        self.embeddings_epoch = embeddings_epoch
        ####################################################
        # DB related configurations
        ####################################################
@@ -81,11 +72,4 @@ class LMConfig(PretrainedConfig):
        ####################################################
        self.knowledge_num = knowledge_num
        self.knowledge_length = knowledge_length
        self.knowledge_dim = knowledge_dim
        ####################################################
        # Triple extraction related configurations
        ####################################################
        self.max_subject_len = max_subject_len
        self.max_predicate_len = max_predicate_len
        self.max_object_len = max_object_len
        super().__init__(**kwargs)
model/dataset.py — 323 lines changed
@@ -9,75 +9,10 @@ import torch
from sklearn.model_selection import train_test_split
import os
import ast
from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "true"


def process_sample_filter(data_args):
    """Filtering logic for a single sample."""
    sample, valid_predicates = data_args
    if 'target' in sample and isinstance(sample['target'], list):
        # Filter low-frequency predicates out of the targets
        valid_targets = []
        for triple in sample['target']:
            if isinstance(triple, dict) and 'predicate' in triple:
                if triple['predicate'] in valid_predicates:
                    valid_targets.append(triple)

        # Keep the sample if it still has valid targets
        if valid_targets:
            sample['target'] = valid_targets
            return sample
        else:
            return None
    else:
        # Keep the sample if it has no target information
        return sample


def process_sample_validation(data_args):
    """Validation logic for a single sample."""
    sample, predicate_vocab = data_args
    if not isinstance(sample, dict) or 'text' not in sample:
        return None

    targets = sample.get('target', [])
    if not isinstance(targets, list) or len(targets) == 0:
        # No valid target: create a default one
        selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
    else:
        # Validate and pick one target, preferring low-share predicates
        selected_target = None
        min_percentage = float('inf')

        for triple in targets:
            if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
                predicate = triple['predicate']

                # Use the statistics stored in predicate_vocab
                if predicate in predicate_vocab:
                    stats = predicate_vocab[predicate]
                    if isinstance(stats, dict) and 'percentage' in stats:
                        percentage = stats['percentage']
                        if percentage < min_percentage:
                            min_percentage = percentage
                            selected_target = triple
                    elif selected_target is None:
                        selected_target = triple
                elif selected_target is None:
                    selected_target = triple

        # Fall back to the default if no valid target was found
        if selected_target is None:
            selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}

    return {
        'text': sample['text'],
        'target': selected_target  # keep exactly one target
    }


class PretrainDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
@@ -98,14 +33,9 @@ class PretrainDataset(Dataset):

    def __getitem__(self, index):
        sample = self.samples[index]
        text = str(sample['text'])

        # Add <|im_start|> and <|im_end|> if they are missing
        if not text.startswith(self.tokenizer.bos_token):
            text = f"{self.tokenizer.bos_token}{text}"
        if not text.endswith(self.tokenizer.eos_token):
            text = f"{text}{self.tokenizer.eos_token}"


        # Build the input text
        text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
@@ -128,8 +58,8 @@ class SFTDataset(Dataset):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)
@@ -196,8 +126,8 @@ class DPODataset(Dataset):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
        with open(file_path, 'r', encoding='utf-8') as f:
            self.data = []
            for line in f:
@@ -266,249 +196,14 @@
        return loss_mask


class TriplePretrainDataset(Dataset):
    """
    Optimized triple pretraining dataset
    - keeps only one target triple per sample
    - tokenizes all data in advance
    - shows processing progress with a progress bar
    """
    def __init__(self, data_path=None, predicate_vocab_path=None, samples=None, tokenizer=None, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.val_samples = None
        self.predicate_to_id = {}  # initialization
        if samples is None:
            self.predicate_vocab = self.load_predicate_vocab(predicate_vocab_path)
            print("🚀 Loading and preprocessing triple data...")
            self.samples, self.val_samples = self.load_and_preprocess_data(data_path)
            print("🚀 Finished loading and preprocessing triple data")
        else:
            cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
            data_filename = os.path.basename(data_path).split('.')[0]
            predicate_to_id_path = os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json')
            self.predicate_to_id = self.load_predicate_vocab(predicate_to_id_path)
            self.samples = samples
            print("🚀 Finished loading and preprocessing triple data")

    def load_predicate_vocab(self, path):
        with open(path, 'r', encoding='utf-8') as f:
            predicate_vocab = json.load(f)
        return predicate_vocab

    def get_val_samples(self):
        return self.val_samples

    def clear_cache(self, data_path):
        """Remove the cache files."""
        cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
        data_filename = os.path.basename(data_path).split('.')[0]
        cache_files = [
            os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
            os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
            os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
            os.path.join(cache_dir, f'{data_filename}_val_samples.json')
        ]

        for cache_file in cache_files:
            if os.path.exists(cache_file):
                os.remove(cache_file)
                print(f"🗑️ Removed cache file: {cache_file}")

        if os.path.exists(cache_dir) and not os.listdir(cache_dir):
            os.rmdir(cache_dir)
            print(f"🗑️ Removed empty cache directory: {cache_dir}")

    def load_and_preprocess_data(self, path):
        """Load and preprocess the triple data."""
        # Build the cache file names (based on the data file path)
        cache_dir = os.path.join(os.path.dirname(path), 'cache')
        os.makedirs(cache_dir, exist_ok=True)

        data_filename = os.path.basename(path).split('.')[0]
        cache_files = {
            'predicate_vocab': os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
            'predicate_to_id': os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
            'train_samples': os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
            'val_samples': os.path.join(cache_dir, f'{data_filename}_val_samples.json')
        }

        # Check whether all cache files exist
        cache_exists = all(os.path.exists(cache_file) for cache_file in cache_files.values())

        if cache_exists:
            print("📁 Cache files found; loading directly...")
            # Load from cache
            with open(cache_files['predicate_vocab'], 'r', encoding='utf-8') as f:
                self.predicate_vocab = json.load(f)

            with open(cache_files['predicate_to_id'], 'r', encoding='utf-8') as f:
                self.predicate_to_id = json.load(f)

            with open(cache_files['train_samples'], 'r', encoding='utf-8') as f:
                train_samples = json.load(f)

            with open(cache_files['val_samples'], 'r', encoding='utf-8') as f:
                val_samples = json.load(f)

            print(f"✅ Loaded from cache:")
            print(f"✅ Predicate vocab size: {len(self.predicate_vocab)}")
            print(f"✅ Training-set size: {len(train_samples)}")
            print(f"✅ Test-set size: {len(val_samples)}")

            return train_samples, val_samples

        # No cache: process the data from scratch
        print("📂 No cache found; loading and processing the raw data...")

        # 1. Load the raw data
        print("📂 Loading raw data...")
        if path.endswith('.json'):
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        elif path.endswith('.jsonl'):
            data = []
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    if line.strip():
                        data.append(json.loads(line.strip()))
        else:
            raise ValueError(f"Unsupported file format: {path}")

        print(f"📊 Raw data size: {len(data)} samples")

        # 2. Use self.predicate_vocab to filter out data whose predicate share is below 0.01%
        print("🔍 Filtering low-frequency predicate data...")
        print(f"📊 Predicate statistics: {len(self.predicate_vocab)} predicates in total")

        # 3. Collect the predicates whose share is >= 0.01%
        valid_predicates = set()
        for predicate, stats in self.predicate_vocab.items():
            if isinstance(stats, dict) and 'percentage' in stats:
                if stats['percentage'] >= 0.01:
                    valid_predicates.add(predicate)
            else:
                # If it is not in statistics format, assume the predicate is valid
                valid_predicates.add(predicate)

        print(f"📊 Predicates with share >= 0.01%: {len(valid_predicates)}")

        # 4. Filter the data: drop samples containing low-frequency predicates (single process)
        original_count = len(data)
        filtered_data = []

        print("🚀 Filtering low-frequency predicate data...")
        for sample in tqdm(data, desc="Filtering low-frequency predicates"):
            result = process_sample_filter((sample, valid_predicates))
            if result is not None:
                filtered_data.append(result)

        data = filtered_data
        print(f"✅ Filtering done: {original_count} samples before, {len(data)} after")

        # 5. Drop predicates with share < 0.01% from self.predicate_vocab and build a predicate-to-index map
        print("🔍 Updating the predicate vocab and building the index map...")
        original_vocab_size = len(self.predicate_vocab)
        filtered_predicate_vocab = {}

        for predicate, stats in self.predicate_vocab.items():
            if isinstance(stats, dict) and 'percentage' in stats:
                if stats['percentage'] >= 0.01:
                    filtered_predicate_vocab[predicate] = stats
            else:
                # If it is not in statistics format, keep it
                filtered_predicate_vocab[predicate] = stats

        # Build the predicate-to-index mapping
        self.predicate_to_id = {predicate: idx for idx, predicate in enumerate(filtered_predicate_vocab.keys())}
        self.predicate_vocab = filtered_predicate_vocab
        print(f"✅ Predicate vocab updated: {original_vocab_size} before, {len(self.predicate_vocab)} after")
        print(f"✅ Predicate map created: {len(self.predicate_to_id)} predicates mapped to indices")

        # 6. Validate the data and select a single target per sample, preferring low-share predicates to balance the data (single process)
        print("🔍 Validating the data format and selecting one target (balancing the data)...")
        valid_samples = []

        print("🚀 Validating the data format...")
        for sample in tqdm(data, desc="Validating data format"):
            result = process_sample_validation((sample, self.predicate_vocab))
            if result is not None:
                valid_samples.append(result)

        print(f"✅ Valid samples: {len(valid_samples)}")

        # 7. Split into training and test sets
        import random
        random.seed(42)
        val_samples = random.sample(valid_samples, min(1000, len(valid_samples)))
        train_samples = [sample for sample in valid_samples if sample not in val_samples]
        print(f"✅ Training-set size: {len(train_samples)}")
        print(f"✅ Test-set size: {len(val_samples)}")

        # 8. Save to the cache files
        print("💾 Saving results to cache files...")
        with open(cache_files['predicate_vocab'], 'w', encoding='utf-8') as f:
            json.dump(self.predicate_vocab, f, ensure_ascii=False, indent=2)

        with open(cache_files['predicate_to_id'], 'w', encoding='utf-8') as f:
            json.dump(self.predicate_to_id, f, ensure_ascii=False, indent=2)

        with open(cache_files['train_samples'], 'w', encoding='utf-8') as f:
            json.dump(train_samples, f, ensure_ascii=False, indent=2)

        with open(cache_files['val_samples'], 'w', encoding='utf-8') as f:
            json.dump(val_samples, f, ensure_ascii=False, indent=2)

        print("✅ Cache files saved")

        return train_samples, val_samples

    def __len__(self):
        return len(self.samples)

    def _triple_to_sentence(self, triple):
        """Convert a triple into sentence form."""
        return f"{triple['subject']} {triple['predicate']} {triple['object']}"

    def __getitem__(self, index):
        """Return one item for the predicate-classification task."""
        sample = self.samples[index]

        # Tokenize the input text at runtime
        input_text = f"{self.tokenizer.bos_token}{sample['text']}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding.input_ids.squeeze()
        loss_mask = (input_ids != self.tokenizer.pad_token_id)

        # Get the predicate-classification label
        target_predicate = sample['target']['predicate']
        predicate_label = self.predicate_to_id.get(target_predicate)  # None if the predicate is not found

        # Build the training data
        X = input_ids[:-1]
        loss_mask = loss_mask[1:]

        return {
            'input_ids': X,
            'labels': torch.tensor(predicate_label, dtype=torch.long),  # predicate-classification label
            'loss_mask': loss_mask
        }


class RLAIFDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)
@@ -14,7 +14,7 @@
    },
    {
        "id": 1,
        "content": "<|im_start|>",
        "content": "<s>",
        "single_word": false,
        "lstrip": false,
        "rstrip": false,
@@ -23,7 +23,7 @@
    },
    {
        "id": 2,
        "content": "<|im_end|>",
        "content": "</s>",
        "single_word": false,
        "lstrip": false,
        "rstrip": false,
@@ -56,8 +56,8 @@
    "ignore_merges": false,
    "vocab": {
        "<unk>": 0,
        "<|im_start|>": 1,
        "<|im_end|>": 2,
        "<s>": 1,
        "</s>": 2,
        "!": 3,
        "\"": 4,
        "#": 5,
@@ -12,7 +12,7 @@
        "special": true
    },
    "1": {
        "content": "<|im_start|>",
        "content": "<s>",
        "lstrip": false,
        "normalized": false,
        "rstrip": false,
@@ -20,7 +20,7 @@
        "special": true
    },
    "2": {
        "content": "<|im_end|>",
        "content": "</s>",
        "lstrip": false,
        "normalized": false,
        "rstrip": false,
@@ -29,9 +29,9 @@
    }
    },
    "additional_special_tokens": [],
    "bos_token": "<|im_start|>",
    "bos_token": "<s>",
    "clean_up_tokenization_spaces": false,
    "eos_token": "<|im_end|>",
    "eos_token": "</s>",
    "legacy": true,
    "model_max_length": 32768,
    "pad_token": "<unk>",
@@ -39,5 +39,5 @@
    "spaces_between_special_tokens": false,
    "tokenizer_class": "PreTrainedTokenizerFast",
    "unk_token": "<unk>",
    "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\n你是 MiniMind,是一个有用的人工智能助手。<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"
    "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<s>system\\n' + system_message + '</s>\\n' }}{% else %}{{ '<s>system\\n你是 MiniMind,是一个有用的人工智能助手。</s>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
}
File diff suppressed because one or more lines are too long
model/model.py — 710 lines changed
@@ -2,8 +2,7 @@ import math
import struct
import inspect
import time
import gc
# Subspace two-dimensional decomposition + gradient updates

from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
@@ -12,9 +11,14 @@ import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch import nn, einsum
from einops import rearrange, repeat

def exists(val):
    return val is not None


# The RMSNorm class defines a module that normalizes the input tensor.
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
@@ -27,7 +31,7 @@ class RMSNorm(torch.nn.Module):
    def forward(self, x):
        return self.weight * self._norm(x.float()).type_as(x)


# precompute_pos_cis precomputes the positional encodings (complex version).
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
@@ -35,7 +39,7 @@ def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis


# apply_rotary_emb applies the rotary positional encoding (complex version).
def apply_rotary_emb(xq, xk, pos_cis):
    def unite_shape(pos_cis, x):
        ndim = x.ndim
@@ -51,244 +55,104 @@
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

class KnowledgeDataset(nn.Module):
    def __init__(self, params, tok_embeddings, is_train=True):
        super().__init__()
        self.is_train = is_train
        self.params = params
        self.tok_embeddings = tok_embeddings
# precompute_pos_cis_real precomputes the positional encodings (real-valued version).
def precompute_pos_cis_real(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    """Implements the positional encoding with real tensors, avoiding complex tensors.

        # Embedding parameters
        self.knowledge_dim = params.knowledge_dim
        self.key_dim = self.knowledge_dim // 2
        self.to_queries = nn.Sequential(
            nn.Linear(params.dim, self.knowledge_dim, bias=False),
        )
    This function is fully equivalent to precompute_pos_cis, but uses real instead of complex tensors.
    The original function produces a complex tensor of shape [seq_len, dim//2] whose real input is all ones and whose angle is the rotation angle.
    This function produces a real tensor of shape [seq_len, dim] where even indices hold cos(angle) and odd indices hold sin(angle).
    """
    # Make sure dim is even
    if dim % 2 != 0:
        raise ValueError(f"dim must be even, but got {dim}")

        ## Database parameters
        self.knowledge_num = params.knowledge_num
        self.knowledge_length = params.knowledge_length

        # Change the key storage into a two-dimensional decomposed space, as trainable parameters
        self.num_keys = int(math.sqrt(self.knowledge_num))
        # Make sure the keys are trainable parameters
        self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
        self.product_key_topk = min(16, self.num_keys)

        # Knowledge-base storage - uses register_buffer because these are integer indices and need no gradient
        self.register_buffer('knowledge_dataset',
            torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
    # Copy the frequency computation of the original function
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()

        # Step counter, used to adjust weights dynamically
        self.step_counter = 0
    # Compute the cos and sin values
    # In the complex version, pos_cis = torch.polar(torch.ones_like(freqs), freqs)
    # which equals cos(freqs) + i*sin(freqs)
    cos = torch.cos(freqs)
    sin = torch.sin(freqs)

        # Removed the batch counter and update-frequency code
    # Build a real tensor with cos and sin interleaved
    pos_emb = torch.zeros((end, dim), device=freqs.device)
    pos_emb[:, 0::2] = cos  # even indices hold cos
    pos_emb[:, 1::2] = sin  # odd indices hold sin

    def intelligent_selection(self, query, all_scores, all_indices):
        """Intelligent hierarchical selection strategy"""
        if self.is_train == False:
            return all_scores, all_indices

        batch_size = all_scores.size(0)
        device = all_scores.device
        dtype = all_scores.dtype
    return pos_emb

        # Record the memory state before entering intelligent selection
        if hasattr(self, 'step_counter'):
            self.step_counter += 1
        # GPU memory-monitoring logs disabled for performance
        # if self.step_counter % 50 == 0:  # log once every 50 calls
        #     if torch.cuda.is_available():
        #         allocated_before = torch.cuda.memory_allocated() / (1024**3)
        #         print(f"[INTEL_SELECT_ENTER] Step {self.step_counter}: GPU Memory: {allocated_before:.2f}GB")
# apply_rotary_emb_real applies the rotary positional encoding (real-valued version).
def apply_rotary_emb_real(xq, xk, pos_emb):
    """Implements rotary positional encoding with real tensors, avoiding complex tensors.

        # Hierarchical selection per batch
        enhanced_scores = all_scores.clone()
        query_features = query.mean(dim=1)  # [batch_size, dim]
    This function is fully equivalent to apply_rotary_emb, but uses real instead of complex tensors.
    The original function converts the inputs to complex form, multiplies by the positional encoding, and converts back to real form.
    This function performs the same rotation directly with real arithmetic.
    """
    # Shape information
    bsz, seq_len, n_heads, head_dim = xq.shape

        # Precompute the embeddings of all candidate entries (batched optimization)
        all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
        unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
    # Make sure pos_emb has the right shape
    assert pos_emb.shape[0] >= seq_len, f"positional-encoding length {pos_emb.shape[0]} is smaller than the sequence length {seq_len}"
    assert pos_emb.shape[1] == head_dim, f"positional-encoding dim {pos_emb.shape[1]} does not match the head dim {head_dim}"

        # Compute embeddings for the unique candidate entries in one batch
        candidate_tokens = self.knowledge_dataset[unique_indices]
        flat_tokens = candidate_tokens.view(-1)
        flat_embeddings = self.tok_embeddings(flat_tokens)

        # Indices corresponding to flat_tokens (kept so other code can use them)
        pre_update_indices = unique_indices.view(-1)
        pre_update_embeddings = flat_embeddings.view(
            len(unique_indices), self.knowledge_length, -1
        )
    # Slice off the needed positional-encoding length
    pos_emb = pos_emb[:seq_len]

        unique_candidate_features = flat_embeddings.view(
            len(unique_indices), self.knowledge_length, -1
        ).mean(dim=1)  # [num_unique_candidates, dim]

        # Normalize the candidate features (optimized similarity computation)
        normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
        normalized_queries = F.normalize(query_features, dim=-1)
    # Reshape pos_emb for broadcasting: [1, seq_len, 1, head_dim]
    pos_emb = pos_emb.unsqueeze(0).unsqueeze(2)

        # Collect best_tokens for every batch
        batch_best_tokens = []
        batch_best_tokens_embeddings = []
    # Split head_dim into two halves
    half_head_dim = head_dim // 2

        for batch_idx in range(batch_size):
            indices = all_indices[batch_idx]

            # Feature indices of the current batch's candidate entries
            start_idx = batch_idx * len(indices)
            end_idx = start_idx + len(indices)
            batch_inverse_indices = inverse_indices[start_idx:end_idx]

            # Optimized similarity computation using the precomputed normalized features
            batch_candidate_features = normalized_candidates[batch_inverse_indices]
            query_feature = normalized_queries[batch_idx]

            # Cosine similarity via matrix-vector multiplication
            similarity_scores = torch.mv(batch_candidate_features, query_feature)

            # Index of the highest similarity score
            max_similarity_idx = torch.argmax(similarity_scores)

            # Candidate-entry index with the highest similarity
            best_candidate_idx = indices[max_similarity_idx]

            # Fetch the corresponding tokens
            best_tokens = self.knowledge_dataset[best_candidate_idx]
            best_tokens_embeddings = self.tok_embeddings(best_tokens)

            # Append the current batch's best_tokens to the lists
            batch_best_tokens.append(best_tokens)
            batch_best_tokens_embeddings.append(best_tokens_embeddings)
    # Extract the cos and sin values (even indices are cos, odd indices are sin)
    cos = pos_emb[..., 0::2]
    sin = pos_emb[..., 1::2]

        # Stack all batches' best_tokens into one tensor
        # [batch_size, knowledge_length]
        all_best_tokens = torch.stack(batch_best_tokens, dim=0)
        all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
    # Rearrange xq and xk for the rotation
    # In the original complex version, xq and xk are reshaped into complex tensors with real and imaginary parts interleaved
    # In the real version we treat even and odd indices separately

        # Free intermediate tensors to prevent memory leaks
        del all_candidate_indices, unique_indices, inverse_indices
        del unique_candidate_features, normalized_candidates, normalized_queries
        del batch_best_tokens, batch_best_tokens_embeddings
        del flat_tokens, flat_embeddings, pre_update_embeddings

        # Record the memory state after leaving intelligent selection (disabled for performance)
        # if hasattr(self, 'step_counter') and self.step_counter % 50 == 0:
        #     if torch.cuda.is_available():
        #         allocated_after = torch.cuda.memory_allocated() / (1024**3)
        #         print(f"[INTEL_SELECT_EXIT] Step {self.step_counter}: GPU Memory: {allocated_after:.2f}GB")

        # Force garbage collection (only on monitoring steps)
        if hasattr(self, 'step_counter') and self.step_counter % 100 == 0:
            gc.collect()
            # if torch.cuda.is_available():
            #     torch.cuda.empty_cache()
    # Separate the even and odd indices
    xq_even = xq[..., 0::2]  # even indices, the real part of the complex view
    xq_odd = xq[..., 1::2]   # odd indices, the imaginary part of the complex view
    xk_even = xk[..., 0::2]
    xk_odd = xk[..., 1::2]

        return all_best_tokens, all_best_tokens_embeddings

    def search_index(self, x):
        batch_size, seq_len, dim = x.shape
    # Apply the rotation (equivalent to complex multiplication)
    # (a + bi)(cos + sin*i) = (a*cos - b*sin) + (a*sin + b*cos)i
    # where a is the even index and b the odd index
    xq_out_even = xq_even * cos - xq_odd * sin  # new even indices (real part)
    xq_out_odd = xq_even * sin + xq_odd * cos   # new odd indices (imaginary part)
    xk_out_even = xk_even * cos - xk_odd * sin
    xk_out_odd = xk_even * sin + xk_odd * cos

        # 1. Average over the sequence dimension
        x_flat = x.mean(dim=1)  # [batch_size, dim]
    # Recombine the even and odd indices
    xq_out = torch.zeros_like(xq)
    xk_out = torch.zeros_like(xk)
    xq_out[..., 0::2] = xq_out_even
    xq_out[..., 1::2] = xq_out_odd
    xk_out[..., 0::2] = xk_out_even
    xk_out[..., 1::2] = xk_out_odd

        # 2. Generate the query vector and reshape it into two sub-queries
        queries = self.to_queries(x_flat)  # [batch_size, knowledge_dim]
        queries = queries.reshape(batch_size, 2, self.key_dim)  # [batch_size, 2, key_dim]
        # Reorder the dimensions so the subspace dimension comes first
        queries = queries.permute(1, 0, 2)  # [2, batch_size, key_dim]
    return xq_out.type_as(xq), xk_out.type_as(xk)

        # 3. Compute the similarity in each subspace
        sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# repeat_kv repeats the key/value heads.
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

        # 4. Top-k separately in each of the two subspaces
        scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
        scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
        indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]

        # 5. Combine the results of the two subspaces
        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)  # [batch_size, topk, topk]
        all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)  # [batch_size, topk, topk]

        # 6. Reshape the results to two dimensions
        all_scores = all_scores.reshape(batch_size, -1)  # [batch_size, topk*topk]
        all_indices = all_indices.reshape(batch_size, -1)  # [batch_size, topk*topk]

        # 7. Select the final top-k results
        scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
        indices = torch.gather(all_indices, 1, indices_of_indices)

        # 8. Apply the intelligent hierarchical selection strategy
        best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)


        return best_tokens, best_tokens_embeddings

class CrossAttention(nn.Module):
    def __init__(
            self,
            config
    ):
        super().__init__()
        self.config = config
        self.num_heads = 8
        self.head_dim = self.config.dim // self.num_heads
        self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)

        self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)

    def forward(self, x, db, context_mask=None, pos_emb=None):
        batch_size = x.size(0)

        # Monitor memory at the start of cross attention (disabled for performance)
        if not hasattr(self, 'call_counter'):
            self.call_counter = 0
        self.call_counter += 1

        # GPU memory-monitoring logs disabled for performance
        # if self.call_counter % 100 == 0 and torch.cuda.is_available():
        #     allocated_before = torch.cuda.memory_allocated() / (1024**3)
        #     print(f"[CROSS_ATTN_ENTER] Call {self.call_counter}: GPU Memory: {allocated_before:.2f}GB")

        # Split into multiple heads
        q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if pos_emb is not None:
            pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
            q = q + pos_emb
            k = k + pos_emb
            v = v + pos_emb

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if context_mask is not None:
            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)

        attn_weights = F.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_weights, v)

        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)

        context = self.to_out(context)

        # Free intermediate tensors
        del q, k, v, attn_scores, attn_weights

        # Monitor memory at the end of cross attention (disabled for performance)
        # if self.call_counter % 100 == 0 and torch.cuda.is_available():
        #     allocated_after = torch.cuda.memory_allocated() / (1024**3)
        #     print(f"[CROSS_ATTN_EXIT] Call {self.call_counter}: GPU Memory: {allocated_after:.2f}GB")

        return context

class Attention(nn.Module):
    def __init__(self, args: LMConfig):
@@ -314,14 +178,56 @@ class Attention(nn.Module):

    def forward(self,
                x: torch.Tensor,
                pos_cis: torch.Tensor):
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
                pos_cis: torch.Tensor,
                db_value=None):
        bsz, seq_len, _ = x.shape  # bsz: batch size, seq_len: sequence length, _: hidden dim
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)  # project x through the linear layers wq, wk, wv to get queries, keys, and values
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)  # reshape xq to (bsz, seq_len, n_local_heads, head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)  # reshape xk to (bsz, seq_len, n_local_kv_heads, head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)  # reshape xv to (bsz, seq_len, n_local_kv_heads, head_dim)

        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
        # Apply the rotary positional encoding (real-valued version)
        xq, xk = apply_rotary_emb_real(xq, xk, pos_cis)
        # kv_cache implementation REMOVED
        # if past_key_value is not None:
        #     xk = torch.cat([past_key_value[0], xk], dim=1)
        #     xv = torch.cat([past_key_value[1], xv], dim=1)
        # past_kv = (xk, xv) if use_cache else None

        # Repeat the key/value heads
        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2)
        )

        # If db_value is given, adapt its shape to the number of heads and merge it with xv
        if db_value is not None:
            # Make sure db_value is compatible with xv; assume db_value has shape [B, N, H, D]
            if db_value.ndim == 4:  # [B, N, H, D]
                db_value = db_value.transpose(1, 2)  # -> [B, H, N, D]

            # Check whether the D dimension needs adjusting
            if db_value.shape[-1] != xv.shape[-1]:
                # If db_value's dim differs from xv's, a projection layer could be added,
                # or a simple adjustment can be used here instead:
                # we adapt the dimension by mean pooling or by repetition
                if db_value.shape[-1] > xv.shape[-1]:
                    # Reduce the dimension
                    factor = db_value.shape[-1] // xv.shape[-1]
                    db_value = db_value.view(bsz, self.n_local_heads, seq_len, factor, xv.shape[-1])
                    db_value = db_value.mean(dim=3)
                else:
                    # Expand the dimension
                    factor = xv.shape[-1] // db_value.shape[-1]
                    db_value = db_value.unsqueeze(-1).repeat(1, 1, 1, 1, factor)
                    db_value = db_value.view(bsz, self.n_local_heads, seq_len, xv.shape[-1])

            # Add db_value to xv (or fuse them)
            # Here they are simply added, but other fusion methods could be used
            xv = xv + db_value

        # Use Flash Attention
        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
@@ -342,6 +248,53 @@ class Attention(nn.Module):
        return output




class CrossAttention(nn.Module):
    def __init__(
            self,
            config
    ):
        super().__init__()
        self.config = config
        self.num_heads = 8
        self.head_dim = self.config.dim // self.num_heads
        self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)

        self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)

    def forward(self, x, db, context_mask=None, pos_emb=None):
        batch_size = x.size(0)

        # Split into multiple heads
        q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if pos_emb is not None:
            pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
            q = q + pos_emb
            k = k + pos_emb
            v = v + pos_emb

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if context_mask is not None:
            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)

        attn_weights = F.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_weights, v)

        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)

        context = self.to_out(context)

        return context

class FeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
@ -474,30 +427,169 @@ class MOEFeedForward(nn.Module):
|
||||
|
||||
|
||||
class MiniMindBlock(nn.Module):
    def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
    def __init__(self, layer_id: int, config: LMConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.self_attention = Attention(config)
        self.cross_attention = CrossAttention(config)
        self.knowledge_dataset = knowledge_dataset
        self.attention = Attention(config)
        self.cross_att = CrossAttention(config)

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)

    def forward(self, x, pos_cis):
        h_attn = self.self_attention(
        # Assumes num_experts is the square root of the total number of experts defined

        # Parameters for query generation

        # Build the query-generation module
        # if weight_down_embed is not None:
        #     self.to_queries = nn.Sequential(
        #         nn.Linear(config.dim, self.dim_key * 2, bias=False),
        #         # nn.Unflatten(2, (2, self.n_heads, self.dim_key))  # replaces Rearrange
        #     )

        # # Hyperparameters
        # self.product_key_topk = min(16, self.num_keys)  # make sure we do not exceed num_keys
        # self.num_experts_per_head_topk = 1  # number of experts finally selected per head

    def forward(self, x, db_value, pos_cis):
        # import pdb;pdb.set_trace()
        # db_value = None

        # # If weight_down_embed is present, use the Product Key mechanism
        # if self.weight_down_embed is not None:
        #     # 1. Generate queries
        #     batch_size, seq_len, dim = x.shape

        #     # collapse sequence dimension by averaging
        #     x_flat = x.mean(dim=1)  # [batch_size, dim]
        #     queries = self.to_queries(x_flat)  # [batch_size, 2*dim_key]
        #     queries = queries.reshape(batch_size, 2, self.dim_key)  # [batch_size, 2, dim_key]
        #     queries = queries.permute(1, 0, 2)  # [2, batch_size, dim_key]

        #     # 2. Compute similarity between queries and keys
        #     sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)

        #     # 3. Top-k within each of the two subspaces
        #     scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
        #     scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
        #     indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]

        #     # 4. Combine the scores and indices of the two subspaces
        #     all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
        #     all_scores = all_scores.view(*all_scores.shape[:-2], -1)

        #     all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
        #     all_indices = all_indices.view(*all_indices.shape[:-2], -1)

        #     # 5. Final top-k selection
        #     scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
        #     indices = all_indices.gather(-1, pk_indices)

        #     # 6. Fetch expert values from the embedding
        #     flat_indices = indices.view(-1)  # flatten indices into a 1-D tensor
        #     db_values = self.weight_down_embed(flat_indices)

        #     # reshape back to the original shape
        #     db_value = db_values.view(batch_size, -1, dim)

        # Attention computation
        h_attn = self.attention(
            self.attention_norm(x),
            pos_cis
            pos_cis,
            db_value=db_value
        )
        db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
        h_attn = self.cross_attention(h_attn, db_embeddings)

        h_attn = self.cross_att(h_attn, db_value)

        # Residual connection
        h = x + h_attn

        # Feed-forward network
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
        return out

class ExtractDB(nn.Module):
    def __init__(self, params):
        # Adjust the expert count and knowledge dim so the count has an integer square root
        super().__init__()
        self.batch_size = None
        self.dim = params.dim
        self.dim_key = self.dim // 2
        self.knowledge_num = params.knowledge_num  # e.g. 100 experts; must be a perfect square
        # Set knowledge_dim equal to head_dim so it can be used directly inside attention
        self.head_dim = params.dim // params.n_heads
        self.knowledge_length = params.knowledge_length

        # Use register_buffer instead of nn.Parameter to avoid gradient issues
        # self.register_buffer('weight_down_embed', torch.randn(self.knowledge_num, self.knowledge_length) * 0.02)
        self.register_buffer('weight_down_embed', torch.randint(low=0, high=6400, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))

        self.num_keys = int(math.sqrt(self.knowledge_num)) if self.knowledge_num > 0 else 0
        self.product_key_topk = min(16, self.num_keys)
        self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02)
        self.num_experts_per_head_topk = 1
        self.to_queries = nn.Sequential(
            nn.Linear(params.dim, self.dim_key * 2, bias=False),
        )

    def q_to_k(self, x):
        # 1. Generate queries
        self.batch_size, seq_len, dim = x.shape

        # collapse sequence dimension by averaging
        x_flat = x.mean(dim=1)  # [batch_size, dim]

        queries = self.to_queries(x_flat)  # [batch_size, 2*dim_key]
        queries = queries.reshape(self.batch_size, 2, self.dim_key)  # [batch_size, 2, dim_key]
        queries = queries.permute(1, 0, 2)  # [2, batch_size, dim_key]

        # 2. Compute similarity between queries and keys
        sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)

        # 3. Top-k within each of the two subspaces
        scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
        scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
        indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]

        # 4. Combine the scores and indices of the two subspaces
        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
        all_scores = all_scores.view(*all_scores.shape[:-2], -1)

        all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
        all_indices = all_indices.view(*all_indices.shape[:-2], -1)

        # 5. Final top-k selection
        scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
        indices = all_indices.gather(-1, pk_indices)
        flat_indices = indices.view(-1)
        return flat_indices

    def get_data(self, index):
        # Fetch the embedding directly on the GPU
        db_values = self.weight_down_embed[index]  # these are token ids now; embedded later
        # db_value = db_values.view(self.batch_size, -1)
        return db_values

    @torch.no_grad()
    def updata_value(self, k, v):  # TODO: add a step that maps vectors back to indices
        # Update buffer values in place (no gradients needed)
        v_reshaped = v.view(v.size(0), -1)
        # Make sure the dtypes match
        v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
        self.weight_down_embed[k] = v_reshaped


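# --- editor's note: illustrative sketch, not part of the diff ---
# ExtractDB.q_to_k uses the product-key trick: two top-k searches over
# num_keys = sqrt(N) keys each replace one top-k over N keys, and a pair
# (ix, iy) maps to the flat entry id ix * num_keys + iy. Toy numbers below
# are assumptions for illustration only.
import torch

num_keys, topk = 4, 2                            # real code: num_keys = sqrt(knowledge_num)
scores_x = torch.tensor([0.9, 0.1, 0.5, 0.3])    # subspace-1 scores for one query
scores_y = torch.tensor([0.2, 0.8, 0.4, 0.6])    # subspace-2 scores

sx, ix = scores_x.topk(topk)                     # candidates per subspace
sy, iy = scores_y.topk(topk)
all_scores = (sx.unsqueeze(-1) + sy.unsqueeze(-2)).reshape(-1)          # topk*topk sums
all_indices = (ix.unsqueeze(-1) * num_keys + iy.unsqueeze(-2)).reshape(-1)

best_score, pos = all_scores.topk(1)
print(int(all_indices[pos]))                     # flat index into the N = num_keys**2 entries
# --- end editor's note ---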
class MiniMindLM(PreTrainedModel):
@@ -509,39 +601,113 @@ class MiniMindLM(PreTrainedModel):
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
        self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
        # Remove the old weight_down_embed declaration
        self.extract_db = ExtractDB(self.params)

        # Pass self.weight_down_embed to every MiniMindBlock
        self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.database_output = nn.Linear(params.dim, params.knowledge_length, bias=False)
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
        self.database_output.weight = self.output.weight

        # Calculate input dimension
        input_dim = (self.params.max_seq_len - 1) * self.params.n_layers
        # Use a bottleneck architecture to reduce parameters
        bottleneck_dim = 256  # Significantly smaller bottleneck dimension

        # Factorized shared downsampling using two smaller convolutions
        self.shared_downsample = nn.Sequential(
            # First reduce input dimension to bottleneck
            nn.Conv1d(input_dim, bottleneck_dim, kernel_size=1, padding='same'),
            nn.ReLU(),  # Non-linearity to improve representation capacity
            # Then expand to target dimension
            nn.Conv1d(bottleneck_dim, 128 * 8, kernel_size=1, padding='same')
        )

        # Specific layers for v path
        self.downsample_v_specific = nn.Sequential(
            nn.Conv1d(128 * 8, 128, kernel_size=1, padding='same'),
            nn.Conv1d(128, self.params.knowledge_length, kernel_size=1, padding='same')
        )

        # Specific layers for q path
        self.downsample_q_specific = nn.Sequential(
            nn.Conv1d(128 * 8, 512, kernel_size=1, padding='same')
        )
        # Use the real-valued positional encoding to avoid segfaults that complex tensors can trigger
        self.register_buffer("pos_cis_real",
                             precompute_pos_cis_real(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
        self.OUT = CausalLMOutputWithPast()
        self.freeze_embedding = False
        self.params = params

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                step: int = 0,
                **args):
        start_pos = args.get('start_pos', 0)
        # if self.freeze_embedding and step == 0:
        #     self.tok_embeddings.weight.requires_grad = False
        #     # Key updates are driven by batch_counter instead of knowledge_dataset.freeze_embedding
        #     # self.knowledge_dataset.freeze_embedding = True
        #     print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        pos_cis_real = self.pos_cis_real[start_pos:start_pos + input_ids.size(1)]
        h_list = []

        for l, layer in enumerate(self.layers):
            # Database disabled: replace database lookups with a fixed value
            if self.params.disable_db:
                # Create a tensor of shape [batch_size, n_layers, dim] filled with 1e-4
                batch_size = h.size(0)
                db_value = torch.full((batch_size, self.n_layers, self.params.dim), 1e-4,
                                      dtype=h.dtype, device=h.device)
            else:
                # Normal mode: query the database
                # import pdb;pdb.set_trace()
                index = self.extract_db.q_to_k(h)

                token_idx = self.extract_db.get_data(index)  # these are indices

                db_value = self.tok_embeddings(token_idx)

            h = layer(
                h, pos_cis
                h, db_value, pos_cis_real
            )

            h_list.append(h.unsqueeze(0))

        h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3)

        # Run the database update logic only when the database is enabled
        if not self.params.disable_db:
            # detach() to cut the computation graph and avoid backpropagating more than once
            h_tensor_detached = h_tensor.detach()
            h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0], -1, self.params.dim)

            # Keep the database update separate from the main computation graph
            with torch.no_grad():

                # Compute shared downsampling layer once
                shared_features = self.shared_downsample(h_tensor_detached)

                # Get features from v path - now we output embedding-dimension vectors
                z_v_features = self.downsample_v_specific(shared_features)
                batch_z, seq_len, dim_z = z_v_features.shape

                # Reshape to batch_size * knowledge_length, dim
                z_v_flat = z_v_features.reshape(-1, dim_z)

                # Direct token prediction - like the main language model head
                token_logits = self.database_output(z_v_flat)  # [batch_z * seq_len, vocab_size]
                # Get token indices directly from logits
                token_indices_flat = torch.argmax(token_logits, dim=-1)
                token_indices = token_indices_flat.reshape(batch_z, -1)

                # Process query path as before
                z_q = self.downsample_q_specific(shared_features)
                z_k = self.extract_db.q_to_k(z_q)
                self.extract_db.updata_value(z_k, token_indices)

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(self.norm(h)[:, slice_indices, :])
        # aux_loss is uniformly disabled
        aux_loss = 0
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))

        # Simplified further: keep only the required fields
        output = CausalLMOutputWithPast(
@@ -551,6 +717,12 @@ class MiniMindLM(PreTrainedModel):

        output.aux_loss = aux_loss

        # Try attaching extra attributes (if supported)
        # try:
        #     output.hidden_states = h
        # except:
        #     pass

        return output

    @torch.inference_mode()
@@ -583,19 +755,15 @@ class MiniMindLM(PreTrainedModel):
        return res

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        start = input_ids.shape[1]
        for _ in range(max_new_tokens):
            # Pass the full input_ids every time; no KV cache
            out = self(input_ids, **args)
            logits = out.logits[:, -1, :]  # logits at the last position

            # Repetition penalty
        start, first_seq = input_ids.shape[1], True
        while input_ids.shape[1] < max_new_tokens - 1:
            if first_seq:
                out, first_seq = self(input_ids, **args), False
            else:
                out = self(input_ids[:, -1:], start_pos=input_ids.shape[1] - 1, **args)
            logits = out.logits[:, -1, :]
            logits[:, list(set(input_ids.tolist()[0]))] /= rp

            # Temperature scaling
            logits /= (temperature + 1e-9)

            # Top-p sampling
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
@@ -605,14 +773,8 @@ class MiniMindLM(PreTrainedModel):
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')

            # Sample the next token
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)

            # Yield the newly generated part
            yield input_ids[:, start:]

            # Stop when the EOS token appears
            if input_ids_next.item() == eos_token_id:
                break

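# --- editor's note: illustrative sketch, not part of the diff ---
# The top-p (nucleus) filter used in _stream above, run on a tiny toy
# distribution; it masks the low-probability tail, then samples from the
# renormalized remainder. The numbers are assumed for illustration.
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
top_p = 0.9
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumulative = torch.cumsum(sorted_probs, dim=-1)
remove = cumulative > top_p
remove[:, 1:] = remove[:, :-1].clone()   # shift so the boundary token survives
remove[:, 0] = False                     # always keep the most likely token
indices_to_remove = remove.scatter(1, sorted_indices, remove)
logits[indices_to_remove] = -float('Inf')
next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
# --- end editor's note ---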
@@ -1,732 +0,0 @@
import math
import struct
import inspect
import time
import gc
# 2-D subspace decomposition + gradient updates
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self.weight * self._norm(x.float()).type_as(x)


def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis


def apply_rotary_emb(xq, xk, pos_cis):
    def unite_shape(pos_cis, x):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

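# --- editor's note: illustrative check, not part of the diff ---
# RoPE as implemented above multiplies complex pairs by unit-modulus
# rotations, so it preserves per-vector norms; a quick sanity check,
# assuming the precompute_pos_cis/apply_rotary_emb definitions above:
import torch

dim, seq = 8, 16
pos_cis = precompute_pos_cis(dim=dim, end=seq)   # [seq, dim // 2], complex64
xq = torch.randn(1, seq, 1, dim)                 # [bsz, seq, heads, head_dim]
xk = torch.randn(1, seq, 1, dim)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, pos_cis)
assert torch.allclose(xq.norm(dim=-1), xq_rot.norm(dim=-1), atol=1e-5)
# --- end editor's note ---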
class KnowledgeDataset(nn.Module):
    def __init__(self, params, tok_embeddings, is_train=True):
        super().__init__()
        self.is_train = is_train
        self.params = params
        self.tok_embeddings = tok_embeddings

        # Embedding parameters
        self.knowledge_dim = params.knowledge_dim
        self.key_dim = self.knowledge_dim // 2
        self.to_queries = nn.Sequential(
            nn.Linear(params.dim, self.knowledge_dim, bias=False),
        )

        ## Database parameters
        self.knowledge_num = params.knowledge_num
        self.knowledge_length = params.knowledge_length

        # Store the keys in a 2-D factorized space, as trainable parameters
        self.num_keys = int(math.sqrt(self.knowledge_num))
        # Make sure the keys are trainable
        self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
        self.product_key_topk = min(16, self.num_keys)

        # Knowledge store - register_buffer because these are integer indices and need no gradients
        self.register_buffer('knowledge_dataset',
                             torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))

        # Step counter, used to adjust weights dynamically
        self.step_counter = 0

        # Batch counter and update-frequency code removed

    def intelligent_selection(self, query, all_scores, all_indices):
        """Intelligent hierarchical selection strategy."""
        if self.is_train == False:
            return all_scores, all_indices

        batch_size = all_scores.size(0)
        device = all_scores.device
        dtype = all_scores.dtype

        # Record the memory state before entering intelligent selection
        if hasattr(self, 'step_counter'):
            self.step_counter += 1
            # GPU memory logging disabled for performance
            # if self.step_counter % 50 == 0:  # log every 50 calls
            #     if torch.cuda.is_available():
            #         allocated_before = torch.cuda.memory_allocated() / (1024**3)
            #         print(f"[INTEL_SELECT_ENTER] Step {self.step_counter}: GPU Memory: {allocated_before:.2f}GB")

        # Hierarchical selection for each batch element
        enhanced_scores = all_scores.clone()
        query_features = query.mean(dim=1)  # [batch_size, dim]

        # Precompute embeddings for all candidate entries (batched optimization)
        all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
        unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)

        # Batch-compute embeddings for the unique candidates
        candidate_tokens = self.knowledge_dataset[unique_indices]
        flat_tokens = candidate_tokens.view(-1)
        flat_embeddings = self.tok_embeddings(flat_tokens)

        # Indices corresponding to flat_tokens (kept so they can be used elsewhere)
        pre_update_indices = unique_indices.view(-1)
        pre_update_embeddings = flat_embeddings.view(
            len(unique_indices), self.knowledge_length, -1
        )

        unique_candidate_features = flat_embeddings.view(
            len(unique_indices), self.knowledge_length, -1
        ).mean(dim=1)  # [num_unique_candidates, dim]

        # Normalize candidate features (speeds up the similarity computation)
        normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
        normalized_queries = F.normalize(query_features, dim=-1)

        # Collect best_tokens for every batch element
        batch_best_tokens = []
        batch_best_tokens_embeddings = []

        for batch_idx in range(batch_size):
            indices = all_indices[batch_idx]

            # Feature indices of this batch element's candidates
            start_idx = batch_idx * len(indices)
            end_idx = start_idx + len(indices)
            batch_inverse_indices = inverse_indices[start_idx:end_idx]

            # Use the precomputed normalized features for the similarity computation
            batch_candidate_features = normalized_candidates[batch_inverse_indices]
            query_feature = normalized_queries[batch_idx]

            # Cosine similarity via a matrix-vector product
            similarity_scores = torch.mv(batch_candidate_features, query_feature)

            # Index of the highest similarity score
            max_similarity_idx = torch.argmax(similarity_scores)

            # Candidate entry index with the highest similarity
            best_candidate_idx = indices[max_similarity_idx]

            # Fetch the corresponding tokens
            best_tokens = self.knowledge_dataset[best_candidate_idx]
            best_tokens_embeddings = self.tok_embeddings(best_tokens)

            # Append this batch element's best_tokens
            batch_best_tokens.append(best_tokens)
            batch_best_tokens_embeddings.append(best_tokens_embeddings)

        # Stack best_tokens across the batch into one tensor
        # [batch_size, knowledge_length]
        all_best_tokens = torch.stack(batch_best_tokens, dim=0)
        all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)

        # Free intermediate tensors to prevent memory leaks
        del all_candidate_indices, unique_indices, inverse_indices
        del unique_candidate_features, normalized_candidates, normalized_queries
        del batch_best_tokens, batch_best_tokens_embeddings
        del flat_tokens, flat_embeddings, pre_update_embeddings

        # Memory state on exit (disabled for performance)
        # if hasattr(self, 'step_counter') and self.step_counter % 50 == 0:
        #     if torch.cuda.is_available():
        #         allocated_after = torch.cuda.memory_allocated() / (1024**3)
        #         print(f"[INTEL_SELECT_EXIT] Step {self.step_counter}: GPU Memory: {allocated_after:.2f}GB")

        # Force garbage collection (only on monitored steps)
        if hasattr(self, 'step_counter') and self.step_counter % 100 == 0:
            gc.collect()
            # if torch.cuda.is_available():
            #     torch.cuda.empty_cache()

        return all_best_tokens, all_best_tokens_embeddings

    def search_index(self, x):
        batch_size, seq_len, dim = x.shape

        # 1. Average over the sequence dimension
        x_flat = x.mean(dim=1)  # [batch_size, dim]

        # 2. Generate the query vector and reshape it into two sub-queries
        queries = self.to_queries(x_flat)  # [batch_size, knowledge_dim]
        queries = queries.reshape(batch_size, 2, self.key_dim)  # [batch_size, 2, key_dim]
        # Move the subspace dimension first
        queries = queries.permute(1, 0, 2)  # [2, batch_size, key_dim]

        # 3. Similarity within each subspace
        sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)

        # 4. Top-k within each of the two subspaces
        scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
        scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
        indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]

        # 5. Combine the results of the two subspaces
        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)  # [batch_size, topk, topk]
        all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)  # [batch_size, topk, topk]

        # 6. Reshape the results to 2-D
        all_scores = all_scores.reshape(batch_size, -1)  # [batch_size, topk*topk]
        all_indices = all_indices.reshape(batch_size, -1)  # [batch_size, topk*topk]

        # 7. Select the final top-k
        scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
        indices = torch.gather(all_indices, 1, indices_of_indices)

        # 8. Apply the intelligent hierarchical selection strategy
        best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)

        return best_tokens, best_tokens_embeddings

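# --- editor's note: illustrative sketch, not part of the diff ---
# The re-ranking idea inside intelligent_selection above: among the
# product-key candidates, pick the entry whose mean token embedding is
# most cosine-similar to the mean-pooled query. Toy tensors, assumed
# dimensions (dim=512, 16 candidates).
import torch
import torch.nn.functional as F

query_feature = F.normalize(torch.randn(512), dim=-1)        # mean-pooled query
candidate_feats = F.normalize(torch.randn(16, 512), dim=-1)  # mean-pooled candidate entries
best = torch.argmax(torch.mv(candidate_feats, query_feature))
print(f"selected candidate {int(best)}")
# --- end editor's note ---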
class CrossAttention(nn.Module):
    def __init__(
            self,
            config
    ):
        super().__init__()
        self.config = config
        self.num_heads = 8
        self.head_dim = self.config.dim // self.num_heads
        self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)

        self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)

    def forward(self, x, db, context_mask=None, pos_emb=None):
        batch_size = x.size(0)

        # Memory monitoring at cross-attention entry (disabled for performance)
        if not hasattr(self, 'call_counter'):
            self.call_counter = 0
        self.call_counter += 1

        # GPU memory logging disabled for performance
        # if self.call_counter % 100 == 0 and torch.cuda.is_available():
        #     allocated_before = torch.cuda.memory_allocated() / (1024**3)
        #     print(f"[CROSS_ATTN_ENTER] Call {self.call_counter}: GPU Memory: {allocated_before:.2f}GB")

        # Split into multiple heads
        q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if pos_emb is not None:
            pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
            q = q + pos_emb
            k = k + pos_emb
            v = v + pos_emb

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if context_mask is not None:
            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)

        attn_weights = F.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_weights, v)

        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)

        context = self.to_out(context)

        # Free intermediate tensors
        del q, k, v, attn_scores, attn_weights

        # Memory monitoring at cross-attention exit (disabled for performance)
        # if self.call_counter % 100 == 0 and torch.cuda.is_available():
        #     allocated_after = torch.cuda.memory_allocated() / (1024**3)
        #     print(f"[CROSS_ATTN_EXIT] Call {self.call_counter}: GPU Memory: {allocated_after:.2f}GB")

        return context


class Attention(nn.Module):
    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self,
                x: torch.Tensor,
                pos_cis: torch.Tensor):
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output


class FeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        if config.hidden_dim is None:
            hidden_dim = 4 * config.dim
            hidden_dim = int(2 * hidden_dim / 3)
            config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))

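# --- editor's note: illustrative sketch, not part of the diff ---
# Worked example of the hidden_dim rule in FeedForward above, using
# dim=512 and multiple_of=64 as assumed sample values:
dim, multiple_of = 512, 64
hidden_dim = 4 * dim                  # 2048
hidden_dim = int(2 * hidden_dim / 3)  # 1365  (SwiGLU uses roughly 2/3 of 4*dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
print(hidden_dim)                     # 1408, rounded up to a multiple of 64
# --- end editor's note ---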
class MoEGate(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss,
                                torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
                    seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = 0
        return topk_idx, topk_weight, aux_loss

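# --- editor's note: illustrative sketch, not part of the diff ---
# The routing step performed by MoEGate above: a softmax over per-expert
# logits, then top-k per token, optionally renormalized. Toy sizes are
# assumed sample values.
import torch
import torch.nn.functional as F

tokens, n_experts, top_k = 4, 8, 2
scores = F.softmax(torch.randn(tokens, n_experts), dim=-1)
topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1, sorted=False)
topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)  # norm_topk_prob
print(topk_idx.shape, topk_weight.shape)  # torch.Size([4, 2]) twice
# --- end editor's note ---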
class MOEFeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            self.shared_experts = FeedForward(config)

    def forward(self, x):
        identity = x
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape
        # Select experts via the gating mechanism
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            y = torch.empty_like(x, dtype=torch.float16)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # keep dtypes consistent
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.config.num_experts_per_tok
        # When tokens_per_expert = [6, 15, 20, 26], tokens_per_expert.shape[0] is the number of experts (4 here),
        # and token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...];
        # token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the 6 positions handled by expert 0
        # (a token may be processed by several experts, depending on num_experts_per_tok),
        # and the next 9 positions token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...] belong to expert 1, and so on.
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache

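# --- editor's note: illustrative sketch, not part of the diff ---
# moe_infer's dispatch trick: sort token slots by expert id so each expert
# processes one contiguous slice. Toy sizes, num_experts_per_tok = 1.
import torch

flat_expert_indices = torch.tensor([1, 0, 1, 2, 0, 1])   # expert chosen per token slot
idxs = flat_expert_indices.argsort()                      # slots grouped by expert
boundaries = flat_expert_indices.bincount().cumsum(0)     # tensor([2, 5, 6])
start = 0
for expert_id, end in enumerate(boundaries.tolist()):
    print(f"expert {expert_id} handles token slots {idxs[start:end].tolist()}")
    start = end
# --- end editor's note ---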
class TripleExtractionHead(nn.Module):
    """Triple-extraction task head."""
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config

        # Triple length hyperparameters
        self.max_subject_len = config.max_subject_len
        self.max_predicate_len = config.max_predicate_len
        self.max_object_len = config.max_object_len

        # Self-attention
        self.self_attention = Attention(config)
        self.self_attn_norm = RMSNorm(config.dim, eps=config.norm_eps)

        # Cross-attention (for subject and object extraction)
        # self.cross_attention_subject = CrossAttention(config)
        # self.cross_attention_object = CrossAttention(config)

        # Normalization layers
        self.subject_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.object_norm = RMSNorm(config.dim, eps=config.norm_eps)

        # Feed-forward networks
        self.predicate_ff = FeedForward(config)
        # self.subject_ff = FeedForward(config)
        # self.object_ff = FeedForward(config)

        # Output projections - modified to support sequence prediction
        self.predicate_output = nn.Linear(config.dim, 264, bias=False)
        # self.subject_output = nn.Linear(config.dim, self.max_subject_len * config.dim, bias=False)
        # self.object_output = nn.Linear(config.dim, self.max_object_len * config.dim, bias=False)

        print("Triple-extraction head config:")
        print(f"- max subject length: {self.max_subject_len}")
        print(f"- max predicate length: {self.max_predicate_len}")
        print(f"- max object length: {self.max_object_len}")

    def forward(self, h, pos_cis):
        """
        Args:
            h: [batch_size, seq_len, dim] - hidden states from the transformer layers
            pos_cis: positional encoding
        Returns:
            predicate_logits: [batch_size, seq_len, max_predicate_len, vocab_size] - predicate sequence prediction
            subject_logits: [batch_size, seq_len, max_subject_len, vocab_size] - subject sequence prediction
            object_logits: [batch_size, seq_len, max_object_len, vocab_size] - object sequence prediction
        """
        batch_size, seq_len, dim = h.shape

        # 1. h goes through self-attention to produce h1
        h1 = self.self_attention(self.self_attn_norm(h), pos_cis)
        h1 = h + h1  # residual connection

        # 2. h1 goes through feed_forward to produce the predicate output
        predicate_features = self.predicate_ff(h1)
        predicate_features = predicate_features.mean(dim=1)
        predicate_class = self.predicate_output(predicate_features)  # [batch_size, max_predicate_len * vocab_size]

        # # 3. h1 goes through cross-attention (k, v both = h) to produce h2
        # h2 = self.cross_attention_subject(h1, h)  # query is h1; key and value are both h
        # h2 = h1 + h2  # residual connection

        # # 4. h2 goes through feed_forward to produce the subject output
        # subject_features = self.subject_ff(self.subject_norm(h2))
        # subject_features = subject_features.mean(dim=1)
        # subject_raw = self.subject_output(subject_features)  # [batch_size, max_subject_len * vocab_size]
        # subject_logits = subject_raw.view(batch_size, self.max_subject_len, -1)

        # # 5. h2 goes through cross-attention (k, v both = h) to produce h3
        # h3 = self.cross_attention_object(h2, h)  # query is h2; key and value are both h
        # h3 = h2 + h3  # residual connection

        # # 6. h3 goes through feed_forward to produce the object output
        # object_features = self.object_ff(self.object_norm(h3))
        # object_features = object_features.mean(dim=1)
        # object_raw = self.object_output(object_features)  # [batch_size, max_object_len * vocab_size]
        # object_logits = object_raw.view(batch_size, self.max_object_len, -1)

        return predicate_class


class MiniMindBlock(nn.Module):
    def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.self_attention = Attention(config)
        self.cross_attention = CrossAttention(config)
        self.knowledge_dataset = knowledge_dataset

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)

    def forward(self, x, pos_cis):
        h_attn = self.self_attention(
            self.attention_norm(x),
            pos_cis
        )
        db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
        h_attn = self.cross_attention(h_attn, db_embeddings)
        h = x + h_attn
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class MiniMindLM(PreTrainedModel):
    config_class = LMConfig

    def __init__(self, params: LMConfig = None, mode="triple"):
        self.params = params or LMConfig()
        super().__init__(self.params)
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
        self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.tok_embeddings.weight = self.output.weight

        # Add the triple-extraction task head (trainable)
        self.triple_extraction_head = TripleExtractionHead(params)
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
        self.OUT = CausalLMOutputWithPast()
        self.freeze_embedding = False

        self.mode = mode

        # Freeze the weights of all designated components
        self._freeze_components()

    def _freeze_components(self):
        """Freeze the weights of the designated components."""
        # Freeze the token embeddings
        for param in self.tok_embeddings.parameters():
            param.requires_grad = False

        # Freeze the knowledge database
        for param in self.knowledge_dataset.parameters():
            param.requires_grad = False

        # Freeze all transformer layers
        for param in self.layers.parameters():
            param.requires_grad = False

        # Freeze the output layer
        for param in self.output.parameters():
            param.requires_grad = False

        # pos_cis is a buffer and never needs gradients, but set it explicitly
        # (buffers default to requires_grad=False)
        if hasattr(self, 'pos_cis'):
            self.pos_cis.requires_grad = False

        print("Frozen components:")
        print("- tok_embeddings")
        print("- knowledge_dataset")
        print("- layers (all transformer layers)")
        print("- output")
        print("- pos_cis")
        print("Note: triple_extraction_head remains trainable")

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                step: int = 0,
                **args):
        start_pos = args.get('start_pos', 0)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        for l, layer in enumerate(self.layers):
            h = layer(
                h, pos_cis
            )

        # Apply the triple-extraction head
        predicate_class = self.triple_extraction_head(h, pos_cis)

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(self.norm(h)[:, slice_indices, :])
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))

        # Simplified further: keep only the required fields
        output = CausalLMOutputWithPast(
            logits=logits,
        )
        output.hidden_states = h
        output.aux_loss = aux_loss

        # Attach the triple-extraction result
        # Note: the shape is now [batch_size, seq_len, max_len, vocab_size]
        output.predicate_class = predicate_class

        return output

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        # Streaming generation
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)

        # Direct generation
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                tokens_list = [tokens[:, -1:] for tokens in out]
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                full_sequence = torch.cat([non_pad, gen], dim=-1)
                generated.append(full_sequence)

        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        start, first_seq, past_kvs = input_ids.shape[1], True, None
        while input_ids.shape[1] < max_new_tokens - 1:
            if first_seq:
                out, first_seq = self(input_ids, **args), False
            else:
                out = self(input_ids[:, -1:],
                           start_pos=input_ids.shape[1] - 1, **args)
            logits, past_kvs = out.logits[:, -1, :], out.past_key_values
            logits[:, list(set(input_ids.tolist()[0]))] /= rp
            logits /= (temperature + 1e-9)
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break

@ -1,488 +0,0 @@
|
||||
import math
|
||||
import struct
|
||||
import inspect
|
||||
import time
|
||||
import gc
|
||||
#子空间二维分解+梯度更新
|
||||
from .LMConfig import LMConfig
|
||||
from typing import Any, Optional, Tuple, List, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
return self.weight * self._norm(x.float()).type_as(x)
|
||||
|
||||
|
||||
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device) # type: ignore
|
||||
freqs = torch.outer(t, freqs).float() # type: ignore
|
||||
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
return pos_cis
|
||||
|
||||
|
||||
def apply_rotary_emb(xq, xk, pos_cis):
|
||||
def unite_shape(pos_cis, x):
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
assert pos_cis.shape == (x.shape[1], x.shape[-1])
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return pos_cis.view(*shape)
|
||||
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
pos_cis = unite_shape(pos_cis, xq_)
|
||||
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
class KnowledgeDataset(nn.Module):
|
||||
def __init__(self, params, tok_embeddings, is_train=True):
|
||||
super().__init__()
|
||||
self.is_train = is_train
|
||||
self.params = params
|
||||
self.tok_embeddings = tok_embeddings
|
||||
|
||||
# 嵌入参数
|
||||
self.knowledge_dim = params.knowledge_dim
|
||||
self.key_dim = self.knowledge_dim // 2
|
||||
self.to_queries = nn.Sequential(
|
||||
nn.Linear(params.dim, self.knowledge_dim, bias=False),
|
||||
)
|
||||
|
||||
## 数据库参数
|
||||
self.knowledge_num = params.knowledge_num
|
||||
self.knowledge_length = params.knowledge_length
|
||||
|
||||
# 修改键存储为二维分解空间,设置为可训练参数
|
||||
self.num_keys = int(math.sqrt(self.knowledge_num))
|
||||
# 确保keys是可训练参数
|
||||
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
|
||||
self.product_key_topk = min(16, self.num_keys)
|
||||
|
||||
# 知识库存储 - 使用register_buffer因为这是整数索引,不需要梯度
|
||||
self.register_buffer('knowledge_dataset',
|
||||
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
|
||||
|
||||
# 计算step数目,用于动态调整权重
|
||||
self.step_counter = 0
|
||||
|
||||
# 移除批次计数器和更新频率相关代码
|
||||
|
||||
def intelligent_selection(self, query, all_scores, all_indices):
|
||||
"""智能分层选择策略"""
|
||||
if self.is_train == False:
|
||||
return all_scores, all_indices
|
||||
|
||||
batch_size = all_scores.size(0)
|
||||
device = all_scores.device
|
||||
dtype = all_scores.dtype
|
||||
|
||||
# 记录进入智能选择前的内存状态
|
||||
if hasattr(self, 'step_counter'):
|
||||
self.step_counter += 1
|
||||
# 禁用GPU内存监控记录以提高性能
|
||||
# if self.step_counter % 50 == 0: # 每50次调用记录一次
|
||||
# if torch.cuda.is_available():
|
||||
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
|
||||
# print(f"[INTEL_SELECT_ENTER] Step {self.step_counter}: GPU Memory: {allocated_before:.2f}GB")
|
||||
|
||||
# 对每个batch进行分层选择
|
||||
enhanced_scores = all_scores.clone()
|
||||
query_features = query.mean(dim=1) # [batch_size, dim]
|
||||
|
||||
# 预先计算所有候选条目的嵌入(批量优化)
|
||||
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
|
||||
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
|
||||
|
||||
# 批量计算唯一候选条目的嵌入
|
||||
candidate_tokens = self.knowledge_dataset[unique_indices]
|
||||
flat_tokens = candidate_tokens.view(-1)
|
||||
flat_embeddings = self.tok_embeddings(flat_tokens)
|
||||
|
||||
# 获取flat_tokens对应的index(保留这些变量以便其他地方使用)
|
||||
pre_update_indices = unique_indices.view(-1)
|
||||
pre_update_embeddings = flat_embeddings.view(
|
||||
len(unique_indices), self.knowledge_length, -1
|
||||
)
|
||||
|
||||
unique_candidate_features = flat_embeddings.view(
|
||||
len(unique_indices), self.knowledge_length, -1
|
||||
).mean(dim=1) # [num_unique_candidates, dim]
|
||||
|
||||
# 归一化候选特征(优化相似度计算)
|
||||
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
|
||||
normalized_queries = F.normalize(query_features, dim=-1)
|
||||
|
||||
# 收集所有batch的best_tokens
|
||||
batch_best_tokens = []
|
||||
batch_best_tokens_embeddings = []
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
indices = all_indices[batch_idx]
|
||||
|
||||
# 获取当前batch候选条目对应的特征索引
|
||||
start_idx = batch_idx * len(indices)
|
||||
end_idx = start_idx + len(indices)
|
||||
batch_inverse_indices = inverse_indices[start_idx:end_idx]
|
||||
|
||||
# 使用预计算的归一化特征进行优化相似度计算
|
||||
batch_candidate_features = normalized_candidates[batch_inverse_indices]
|
||||
query_feature = normalized_queries[batch_idx]
|
||||
|
||||
# 使用矩阵乘法计算余弦相似度
|
||||
similarity_scores = torch.mv(batch_candidate_features, query_feature)
|
||||
|
||||
# 找到最大相似度分数的索引
|
||||
max_similarity_idx = torch.argmax(similarity_scores)
|
||||
|
||||
# 获取最大相似度对应的候选条目索引
|
||||
best_candidate_idx = indices[max_similarity_idx]
|
||||
|
||||
# 获取对应的tokens
|
||||
best_tokens = self.knowledge_dataset[best_candidate_idx]
|
||||
best_tokens_embeddings = self.tok_embeddings(best_tokens)
|
||||
|
||||
# 将当前batch的best_tokens添加到列表中
|
||||
batch_best_tokens.append(best_tokens)
|
||||
batch_best_tokens_embeddings.append(best_tokens_embeddings)
|
||||
|
||||
# 将所有batch的best_tokens堆叠成一个张量
|
||||
# [batch_size, knowledge_length]
|
||||
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
|
||||
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
|
||||
|
||||
# 清理中间张量以防止内存泄漏
|
||||
del all_candidate_indices, unique_indices, inverse_indices
|
||||
del unique_candidate_features, normalized_candidates, normalized_queries
|
||||
del batch_best_tokens, batch_best_tokens_embeddings
|
||||
del flat_tokens, flat_embeddings, pre_update_embeddings
|
||||
|
||||
# 记录退出智能选择后的内存状态(已禁用以提高性能)
|
||||
# if hasattr(self, 'step_counter') and self.step_counter % 50 == 0:
|
||||
# if torch.cuda.is_available():
|
||||
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
|
||||
# print(f"[INTEL_SELECT_EXIT] Step {self.step_counter}: GPU Memory: {allocated_after:.2f}GB")
|
||||
|
||||
# 强制垃圾回收(仅在监控步骤)
|
||||
if hasattr(self, 'step_counter') and self.step_counter % 100 == 0:
|
||||
gc.collect()
|
||||
# if torch.cuda.is_available():
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
return all_best_tokens, all_best_tokens_embeddings
|
||||
|
||||
|
||||
|
||||
def search_index(self, x):
|
||||
batch_size, seq_len, dim = x.shape
|
||||
|
||||
# 1. 序列维度平均
|
||||
x_flat = x.mean(dim=1) # [batch_size, dim]
|
||||
|
||||
# 2. 生成查询向量并重塑为两个子查询
|
||||
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
|
||||
queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim]
|
||||
# 调整维度顺序,使子空间维度位于首位
|
||||
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
|
||||
|
||||
# 3. 计算每个子空间的相似度
|
||||
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
|
||||
|
||||
# 4. 在两个子空间分别做top-k
|
||||
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
|
||||
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
|
||||
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
|
||||
|
||||
# 5. 组合两个子空间的结果
|
||||
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
|
||||
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
|
||||
|
||||
# 6. 将结果重塑为二维
|
||||
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
|
||||
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
|
||||
|
||||
# 7. 选择最终的top-k结果
|
||||
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
|
||||
indices = torch.gather(all_indices, 1, indices_of_indices)
|
||||
|
||||
# 8. 应用智能分层选择策略
|
||||
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
|
||||
|
||||
|
||||
return best_tokens, best_tokens_embeddings
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_heads = 8
|
||||
self.head_dim = self.config.dim // self.num_heads
|
||||
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
|
||||
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
|
||||
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
|
||||
|
||||
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
|
||||
|
||||
def forward(self, x, db, context_mask=None, pos_emb=None):
|
||||
batch_size = x.size(0)
|
||||
|
||||
# 监控交叉注意力开始时的内存(已禁用以提高性能)
|
||||
if not hasattr(self, 'call_counter'):
|
||||
self.call_counter = 0
|
||||
self.call_counter += 1
|
||||
|
||||
# 禁用GPU内存监控记录以提高性能
|
||||
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
|
||||
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
|
||||
# print(f"[CROSS_ATTN_ENTER] Call {self.call_counter}: GPU Memory: {allocated_before:.2f}GB")
|
||||
|
||||
# 分离多头
|
||||
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if pos_emb is not None:
|
||||
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q = q + pos_emb
|
||||
k = k + pos_emb
|
||||
v = v + pos_emb
|
||||
|
||||
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
||||
|
||||
if context_mask is not None:
|
||||
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
|
||||
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
|
||||
|
||||
attn_weights = F.softmax(attn_scores, dim=-1)
|
||||
|
||||
context = torch.matmul(attn_weights, v)
|
||||
|
||||
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
|
||||
|
||||
context = self.to_out(context)
|
||||
|
||||
# 清理中间张量
|
||||
del q, k, v, attn_scores, attn_weights
|
||||
|
||||
# 监控交叉注意力结束时的内存(已禁用以提高性能)
|
||||
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
|
||||
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
|
||||
# print(f"[CROSS_ATTN_EXIT] Call {self.call_counter}: GPU Memory: {allocated_after:.2f}GB")
|
||||
|
||||
return context
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: LMConfig):
|
||||
super().__init__()
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
assert args.n_heads % self.n_kv_heads == 0
|
||||
self.n_local_heads = args.n_heads
|
||||
self.n_local_kv_heads = self.n_kv_heads
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
|
||||
self.attn_dropout = nn.Dropout(args.dropout)
|
||||
self.resid_dropout = nn.Dropout(args.dropout)
|
||||
self.dropout = args.dropout
|
||||
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
|
||||
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
|
||||
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
self.register_buffer("mask", mask, persistent=False)
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
pos_cis: torch.Tensor):
|
||||
bsz, seq_len, _ = x.shape
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
|
||||
|
||||
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
|
||||
if self.flash and seq_len != 1:
|
||||
dropout_p = self.dropout if self.training else 0.0
|
||||
output = F.scaled_dot_product_attention(
|
||||
xq, xk, xv,
|
||||
attn_mask=None,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=True
|
||||
)
|
||||
else:
|
||||
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
||||
scores += self.mask[:, :, :seq_len, :seq_len]
|
||||
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
||||
scores = self.attn_dropout(scores)
|
||||
output = scores @ xv
|
||||
|
||||
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
|
||||
output = self.resid_dropout(self.wo(output))
|
||||
return output
|
||||
|
||||
|
||||
|
||||
|
||||
class MiniMindBlock(nn.Module):
    def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.self_attention = Attention(config)
        self.cross_attention = CrossAttention(config)
        self.knowledge_dataset = knowledge_dataset

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        # ffn_norm and feed_forward removed: the FeedForward layer is no longer used

    def forward(self, x, pos_cis):
        h_attn = self.self_attention(
            self.attention_norm(x),
            pos_cis
        )
        db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
        h_attn = self.cross_attention(h_attn, db_embeddings)
        h = x + h_attn
        # No FeedForward layer; return the attention output directly
        return h

class MiniMindLM(PreTrainedModel):
    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        params = params or LMConfig()
        self.params = params
        super().__init__(self.params)
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
        self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
        self.OUT = CausalLMOutputWithPast()
        self.freeze_embedding = False

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                step: int = 0,
                **args):
        start_pos = args.get('start_pos', 0)
        # if self.freeze_embedding and step == 0:
        #     self.tok_embeddings.weight.requires_grad = False
        #     # Do not set knowledge_dataset.freeze_embedding here; key updates are driven by batch_counter
        #     # self.knowledge_dataset.freeze_embedding = True
        #     print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        for l, layer in enumerate(self.layers):
            h = layer(
                h, pos_cis
            )

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(self.norm(h)[:, slice_indices, :])
        # aux_loss computation removed along with the FeedForward layer
        aux_loss = 0

        # Simplified further: keep only the fields that are actually needed
        output = CausalLMOutputWithPast(
            logits=logits,
        )
        output.hidden_states = h

        output.aux_loss = aux_loss

        return output

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        # Streaming generation
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)

        # Non-streaming generation
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                tokens_list = [tokens[:, -1:] for tokens in out]
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                full_sequence = torch.cat([non_pad, gen], dim=-1)
                generated.append(full_sequence)

        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        start = input_ids.shape[1]
        for _ in range(max_new_tokens):
            # Always feed the full input_ids; no KV cache is used
            out = self(input_ids, **args)
            logits = out.logits[:, -1, :]  # logits at the last position

            # Repetition penalty
            logits[:, list(set(input_ids.tolist()[0]))] /= rp

            # Temperature scaling
            logits /= (temperature + 1e-9)

            # Top-p (nucleus) sampling
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')

            # Sample the next token
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)

            # Yield only the newly generated part
            yield input_ids[:, start:]

            # Stop on the end-of-sequence token
            if input_ids_next.item() == eos_token_id:
                break
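
# Illustrative top-p check (a sketch added for this writeup, not part of the
# original file): with logits [2.0, 1.0, 0.0, -1.0] the softmax probabilities
# are roughly [0.64, 0.24, 0.09, 0.03] and their cumulative sums are
# [0.64, 0.88, 0.97, 1.00], so with top_p=0.9 the shifted mask in _stream
# keeps the first three tokens and sets the last one to -inf before sampling:
# probs = F.softmax(torch.tensor([[2.0, 1.0, 0.0, -1.0]]), dim=-1)
# torch.cumsum(probs, dim=-1)  # tensor([[0.6439, 0.8808, 0.9679, 1.0000]])
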
@ -1,386 +0,0 @@
import math
import struct
import inspect
import time

from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self.weight * self._norm(x.float()).type_as(x)
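
# Quick numerical check (illustrative, not in the original file): RMSNorm
# scales by the reciprocal root mean square of the features, so [3.0, 4.0]
# has mean square (9 + 16) / 2 = 12.5 and normalizes to x / sqrt(12.5):
# x = torch.tensor([3.0, 4.0])
# x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)  # ≈ [0.8485, 1.1314]
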
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis

def apply_rotary_emb(xq, xk, pos_cis):
    def unite_shape(pos_cis, x):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
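
# Shape walk-through (illustrative, not in the original file): with
# head_dim=8 and 4 positions, precompute_pos_cis returns a complex tensor of
# shape (end, dim // 2); apply_rotary_emb views each (bsz, seq, heads, 8)
# query/key as (bsz, seq, heads, 4) complex pairs and rotates them by the
# per-position phases:
# pos_cis = precompute_pos_cis(dim=8, end=4)     # (4, 4) complex64
# xq = torch.randn(1, 4, 2, 8)                   # (bsz, seq, heads, head_dim)
# xq_rot, _ = apply_rotary_emb(xq, xq, pos_cis)  # same shape, rotated
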
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

class Attention(nn.Module):
    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self,
                x: torch.Tensor,
                pos_cis: torch.Tensor,
                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                use_cache=False):
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
        # KV cache: prepend the cached keys/values from previous steps
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)
        past_kv = (xk, xv) if use_cache else None

        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2)
        )
        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output, past_kv
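
# How the cache above is used (explanatory note, not in the original file):
# on the first forward pass the full prompt is given and past_kv holds K/V
# for every position; on later passes only the newest token is given, its
# K/V are concatenated onto the cache along dim=1 (the sequence axis), and
# attention still covers the whole history while computing a single query.
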
class FeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        if config.hidden_dim is None:
            hidden_dim = 4 * config.dim
            hidden_dim = int(2 * hidden_dim / 3)
            config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
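
# Worked example (illustrative, not in the original file): with dim=512 and
# multiple_of=64, hidden_dim starts at 4 * 512 = 2048, is scaled to
# int(2 * 2048 / 3) = 1365, then rounded up to a multiple of 64:
# 64 * ((1365 + 63) // 64) = 64 * 22 = 1408.
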
class MoEGate(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss,
                                torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
                    seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = 0
        return topk_idx, topk_weight, aux_loss

class MOEFeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            self.shared_experts = FeedForward(config)

    def forward(self, x):
        identity = x
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape
        # Route tokens to experts via the gating network
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            y = torch.empty_like(x, dtype=torch.float16)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # keep dtypes consistent
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.config.num_experts_per_tok
        # Example: if tokens_per_expert = [6, 15, 20, 26], then
        # tokens_per_expert.shape[0] is the number of experts (4 here), and with
        # token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12, ...]
        # token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the tokens handled by expert 0
        # (a token may be routed to several experts, depending on num_experts_per_tok);
        # the next nine positions token_idxs[6:15] -> [4, 5, 6, 10, 11, 12, ...]
        # belong to expert 1, and so on.
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache
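
# Toy routing example (illustrative, not in the original file): with top-1
# routing of four tokens to experts [0, 1, 1, 0], argsort groups the token
# positions by expert and the cumulative bincount marks the slice boundaries
# that moe_infer iterates over:
# flat_idx = torch.tensor([0, 1, 1, 0])
# order = flat_idx.argsort()              # e.g. tensor([0, 3, 1, 2])
# ends = flat_idx.bincount().cumsum(0)    # tensor([2, 4])
# order[:2] are expert 0's tokens, order[2:4] are expert 1's tokens
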
class MiniMindBlock(nn.Module):
    def __init__(self, layer_id: int, config: LMConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.attention = Attention(config)

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)

    def forward(self, x, pos_cis, past_key_value=None, use_cache=False):
        h_attn, past_kv = self.attention(
            self.attention_norm(x),
            pos_cis,
            past_key_value=past_key_value,
            use_cache=use_cache
        )
        h = x + h_attn
        out = h + self.feed_forward(self.ffn_norm(h))
        return out, past_kv

class MiniMindLM(PreTrainedModel):
    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        params = params or LMConfig()
        self.params = params
        super().__init__(self.params)
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
        self.OUT = CausalLMOutputWithPast()

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                **args):
        past_key_values = past_key_values or [None] * len(self.layers)
        start_pos = args.get('start_pos', 0)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        past_kvs = []
        for l, layer in enumerate(self.layers):
            h, past_kv = layer(
                h, pos_cis,
                past_key_value=past_key_values[l],
                use_cache=use_cache
            )
            past_kvs.append(past_kv)

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(self.norm(h)[:, slice_indices, :])
        # aux_loss is uniformly disabled here
        aux_loss = 0
        self.OUT.__setitem__('last_hidden_state', h)
        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('aux_loss', aux_loss)
        self.OUT.__setitem__('past_key_values', past_kvs)
        return self.OUT

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., use_cache=True, pad_token_id=0, num_return_sequences=1, **args):
        # Streaming generation
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)

        # Non-streaming generation
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
                tokens_list = [tokens[:, -1:] for tokens in out]
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                full_sequence = torch.cat([non_pad, gen], dim=-1)
                generated.append(full_sequence)

        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
        start, first_seq, past_kvs = input_ids.shape[1], True, None
        while input_ids.shape[1] < max_new_tokens - 1:
            if first_seq or not use_cache:
                out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False
            else:
                out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
                           start_pos=input_ids.shape[1] - 1, **args)
            logits, past_kvs = out.logits[:, -1, :], out.past_key_values
            logits[:, list(set(input_ids.tolist()[0]))] /= rp
            logits /= (temperature + 1e-9)
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break
File diff suppressed because it is too large
@ -1,43 +0,0 @@
{
  "add_bos_token": false,
  "add_eos_token": false,
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "0": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<|im_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "<|im_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "additional_special_tokens": [],
  "bos_token": "<|im_start|>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "<|im_end|>",
  "legacy": true,
  "model_max_length": 32768,
  "pad_token": "<unk>",
  "sp_model_kwargs": {},
  "spaces_between_special_tokens": false,
  "tokenizer_class": "PreTrainedTokenizerFast",
  "unk_token": "<unk>",
  "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\n你是 MiniMind,是一个有用的人工智能助手。<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"
}
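
# Illustrative use of the chat_template above (a sketch, not part of the
# original repo; the path and messages are made up for the example):
# from transformers import AutoTokenizer
# tok = AutoTokenizer.from_pretrained("model/minimind_tokenizer")
# tok.apply_chat_template([{"role": "user", "content": "hi"}], tokenize=False)
# -> '<|im_start|>system\n你是 MiniMind,是一个有用的人工智能助手。<|im_end|>\n'
#    '<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\n'
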
File diff suppressed because one or more lines are too long
@ -1,133 +0,0 @@
import json
import os
import datetime
from typing import List, Dict, Any

# Configuration
json_path = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/dataset/combined.json"
prepare_num = 1048576  # number of records for database_init.json; adjust as needed
output_dir = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/dataset/"


def convert_to_database_init_format(sentences: List[str], importance_score: float = 10.0) -> Dict[str, Any]:
    """
    Convert a list of sentences into the database_init.json format.

    Args:
        sentences: list of sentences
        importance_score: importance score, 10.0 by default

    Returns:
        The converted dictionary
    """
    # Build the sentence records
    sentence_data = []
    for sentence in sentences:
        sentence_item = {
            "original_sentence": sentence,
            "corrected_sentence": sentence,  # same as original_sentence
            "importance_score": importance_score
        }
        sentence_data.append(sentence_item)

    # Build the full data structure
    result = {
        "metadata": {
            "batch_number": 1,
            "batch_size": len(sentences),
            "total_processed_count": len(sentences),
            "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "total_sentences": len(sentences),
            "duplicates_removed": 0  # this function does no deduplication
        },
        "sentences": sentence_data
    }

    return result

def preprocess_combined_json():
    # Read the raw data
    print("Reading combined.json...")
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    total_count = len(data)
    print(f"Loaded {total_count} records in total")

    # Process every record: join subject, predicate and object into a sentence,
    # keeping a mapping back to the original record
    print("Building sentences from triples...")
    sentence_to_original = {}  # maps each sentence to its original record
    all_sentences = []

    for i, item in enumerate(data):
        # Join subject, predicate and object into one sentence
        sentence = f"{item['subject']} {item['predicate']} {item['object']}"
        all_sentences.append(sentence)

        # Record the sentence-to-original mapping (keep the first occurrence on duplicates)
        if sentence not in sentence_to_original:
            sentence_to_original[sentence] = item

        if (i + 1) % 100000 == 0:
            print(f"Processed {i + 1}/{total_count} records")

    print(f"Sentence building finished: {len(all_sentences)} sentences")

    # Deduplicate
    print("Deduplicating...")
    unique_sentences = list(set(all_sentences))
    duplicates_removed = len(all_sentences) - len(unique_sentences)
    print(f"Deduplication finished: {len(all_sentences)} before, {len(unique_sentences)} after, {duplicates_removed} duplicates removed")

    # Check whether there is enough deduplicated data
    if len(unique_sentences) < prepare_num:
        print(f"Warning: deduplicated data ({len(unique_sentences)}) is less than the requested amount ({prepare_num})")
        print(f"Using all {len(unique_sentences)} deduplicated sentences")
        selected_sentences = unique_sentences
    else:
        print(f"Selecting the first {prepare_num} deduplicated sentences")
        selected_sentences = unique_sentences[:prepare_num]

    # Convert to the database_init.json format
    print("Converting to the database_init.json format...")
    database_init_data = convert_to_database_init_format(selected_sentences, importance_score=10.0)

    # Update duplicates_removed in the metadata
    database_init_data["metadata"]["duplicates_removed"] = duplicates_removed

    # Save database_init.json
    database_output_path = os.path.join(output_dir, "database_init_from_combined.json")
    print(f"Saving {database_output_path}...")
    with open(database_output_path, "w", encoding="utf-8") as f:
        json.dump(database_init_data, f, ensure_ascii=False, indent=2)

    print(f"database_init_from_combined.json saved with {len(selected_sentences)} records")

    # Save the remaining data as a training set (original format)
    remaining_sentences = unique_sentences[prepare_num:] if len(unique_sentences) > prepare_num else []
    if remaining_sentences:
        # Convert the remaining sentences back to the original format
        print(f"Converting the remaining {len(remaining_sentences)} records back to the original format...")
        remaining_original_data = []
        for sentence in remaining_sentences:
            if sentence in sentence_to_original:
                remaining_original_data.append(sentence_to_original[sentence])

        print(f"Saving the remaining {len(remaining_original_data)} records as a training set...")
        train_output_path = os.path.join(output_dir, "combined_train.json")
        with open(train_output_path, "w", encoding="utf-8") as f:
            json.dump(remaining_original_data, f, ensure_ascii=False, indent=2)
        print("combined_train.json saved")
    else:
        print("No remaining data for a training set")
        remaining_original_data = []

    print("\nData processing finished!")
    print(f"Raw records: {total_count}")
    print(f"Sentences built: {len(all_sentences)}")
    print(f"After deduplication: {len(unique_sentences)}")
    print(f"Used for database_init: {len(selected_sentences)}")
    print(f"Remaining training records: {len(remaining_original_data) if remaining_sentences else 0}")


if __name__ == "__main__":
    preprocess_combined_json()
@ -1,741 +0,0 @@
import json
import os
import pandas as pd
import tarfile
import tempfile
import shutil
from pathlib import Path
import re
import langdetect
from tqdm import tqdm
import logging
import random
import hashlib
from transformers import AutoTokenizer

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
BASE_DIR = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind"
OUTPUT_FILE = os.path.join(BASE_DIR, "dataset", "merged_pretrain.jsonl")
OUTPUT_FILE_EXTRA = os.path.join(BASE_DIR, "dataset", "merged_pretrain_extra.jsonl")

# Data source paths
PRETRAIN_HQ_PATH = os.path.join(BASE_DIR, "dataset", "pretrain_hq.jsonl")
WIKIPEDIA_PATH = "/home/pci/nas/share/datasets/wikipedia/data/20220301.en"
GUTENBERG_PATH = "/home/pci/nas/share/datasets/gutenberg/data"
OPENWEBTEXT_PATH = "/home/pci/nas/share/datasets/openwebtext/subsets"

# Token length limits
MIN_TOKENS = 410
MAX_TOKENS = 490

# Quality and sampling configuration - main file
DATASET_CONFIG = {
    "pretrain_hq": {"quality": "high", "sample_ratio": 1.0, "max_samples": None},  # high quality, use all of it
    "wikipedia": {"quality": "high", "sample_ratio": 1.0, "max_samples": 5000000},  # high quality, use all, capped at 5M records
    "gutenberg": {"quality": "medium", "sample_ratio": 0.8, "max_samples": 3000000},  # medium quality, use 80%, capped at 3M records
    "openwebtext": {"quality": "low", "sample_ratio": 0.2, "max_samples": 2000000}  # low quality, use 20%, capped at 2M records
}

# Configuration for the extra file - leftover data
DATASET_CONFIG_EXTRA = {
    "wikipedia": {"quality": "high", "sample_ratio": 1.0, "max_samples": None},  # all of the remainder
    "gutenberg": {"quality": "medium", "sample_ratio": 0.8, "max_samples": 5000000},  # 80% of the remainder, capped at 5M records
    "openwebtext": {"quality": "low", "sample_ratio": 0.6, "max_samples": 4000000}  # 60% of the remainder, capped at 4M records
}

# Global state: hashes of the data already selected for the main file
selected_data_hashes = {
    "wikipedia": set(),
    "gutenberg": set(),
    "openwebtext": set()
}

# The tokenizer is initialized lazily
tokenizer = None

def init_tokenizer():
    """Initialize the tokenizer"""
    global tokenizer
    try:
        # Prefer the local minimind tokenizer
        local_tokenizer_path = os.path.join(BASE_DIR, "model", "minimind_tokenizer")
        if os.path.exists(local_tokenizer_path):
            tokenizer = AutoTokenizer.from_pretrained(local_tokenizer_path)
            logger.info("Local MiniMind tokenizer initialized successfully")
        else:
            # Fall back to GPT-2 if the local tokenizer is missing (offline mode)
            tokenizer = AutoTokenizer.from_pretrained("gpt2", local_files_only=True)
            logger.info("GPT-2 tokenizer initialized successfully (offline)")
    except Exception as e:
        logger.error(f"Error initializing tokenizer: {e}")
        logger.info("Trying to use a simple fallback tokenizer...")
        # Fall back to simple whitespace tokenization
        tokenizer = None
        logger.warning("Using simple word-based tokenization as fallback")


def count_tokens(text):
    """Count the number of tokens in a text"""
    if tokenizer is None:
        init_tokenizer()

    if tokenizer is not None:
        try:
            tokens = tokenizer.encode(text, add_special_tokens=False)
            return len(tokens)
        except Exception:
            pass

    # If tokenization failed or no tokenizer is available, use a rough estimate
    return int(len(text.split()) * 1.3)  # rough estimate; always returns an int
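
# Fallback estimate example (illustrative, not in the original file):
# "the quick brown fox" splits into 4 whitespace-separated words, so the
# fallback path returns int(4 * 1.3) = 5 tokens, a rough stand-in for real
# BPE token counts.
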
def is_english_text(text, threshold=0.8):
    """Detect whether a text is English"""
    try:
        if len(text) < 50:  # skip detection for very short texts
            return True
        detected_lang = langdetect.detect(text)
        return detected_lang == 'en'
    except Exception:
        # If detection fails, fall back to a simple ASCII character ratio
        english_chars = sum(1 for char in text if char.isascii() and (char.isalpha() or char in ' .,!?-'))
        total_chars = len(text.replace('\n', ' ').replace('\t', ' '))
        return (english_chars / max(total_chars, 1)) > threshold

def truncate_to_token_limit(text, target_tokens=MAX_TOKENS):
    """Truncate a text to the target number of tokens"""
    if tokenizer is None:
        init_tokenizer()

    if tokenizer is not None:
        try:
            tokens = tokenizer.encode(text, add_special_tokens=False)
            if len(tokens) <= target_tokens:
                return text

            # Truncate to the target length
            truncated_tokens = tokens[:target_tokens]
            truncated_text = tokenizer.decode(truncated_tokens, skip_special_tokens=True)

            # Try to cut at a period to keep sentences intact
            sentences = truncated_text.split('.')
            if len(sentences) > 1:
                # Keep every sentence except the trailing incomplete one
                truncated_text = '.'.join(sentences[:-1]) + '.'

            return truncated_text
        except Exception:
            pass

    # If processing failed or no tokenizer is available, estimate by characters
    estimated_chars = int(target_tokens / 1.3 * 4)  # rough estimate
    text = text[:estimated_chars]

    # Try to cut at a period to keep sentences intact
    sentences = text.split('.')
    if len(sentences) > 1:
        text = '.'.join(sentences[:-1]) + '.'

    return text

def split_text_into_chunks(text, target_chunk_size=1500):
    """Split a long text into medium-sized paragraph chunks"""
    # Clean the text
    text = text.strip()
    if not text:
        return []

    # Collapse excessive newlines and spaces
    text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
    text = re.sub(r' +', ' ', text)

    chunks = []

    # Split by paragraph
    paragraphs = text.split('\n\n')
    current_chunk = ""

    for paragraph in paragraphs:
        paragraph = paragraph.strip()
        if not paragraph:
            continue

        # If the current chunk plus this paragraph is still a reasonable size, append it
        if len(current_chunk) + len(paragraph) < target_chunk_size:
            if current_chunk:
                current_chunk += "\n\n" + paragraph
            else:
                current_chunk = paragraph
        else:
            # Flush the current chunk if it is non-empty
            if current_chunk:
                chunks.append(current_chunk)

            # If the paragraph itself is very long, split it further
            if len(paragraph) > target_chunk_size * 2:
                # Split the long paragraph on sentence boundaries
                sentences = re.split(r'(?<=[.!?])\s+', paragraph)
                temp_chunk = ""

                for sentence in sentences:
                    if len(temp_chunk) + len(sentence) < target_chunk_size:
                        if temp_chunk:
                            temp_chunk += " " + sentence
                        else:
                            temp_chunk = sentence
                    else:
                        if temp_chunk:
                            chunks.append(temp_chunk)
                        temp_chunk = sentence

                if temp_chunk:
                    current_chunk = temp_chunk
                else:
                    current_chunk = ""
            else:
                current_chunk = paragraph

    # Append the final chunk
    if current_chunk:
        chunks.append(current_chunk)

    return chunks
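
# Usage sketch (illustrative, not in the original file): paragraphs are
# greedily packed until a chunk would pass target_chunk_size; paragraphs
# longer than twice the target are re-split on sentence boundaries.
# chunks = split_text_into_chunks("short para\n\n" + "A sentence. " * 400,
#                                 target_chunk_size=1500)
# each resulting chunk ends up near or below the 1500-character target
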
def format_text_for_pretrain(text):
    """Format a text into the pretraining format and check its token length"""
    # Clean the text
    text = text.strip()
    if not text:
        return None

    # Collapse excessive newlines and spaces
    text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
    text = re.sub(r' +', ' ', text)

    # Check the token length
    token_count = count_tokens(text)

    # Too short: skip
    if token_count < MIN_TOKENS:
        return None

    # Too long: truncate
    if token_count > MAX_TOKENS:
        text = truncate_to_token_limit(text, MAX_TOKENS)
        token_count = count_tokens(text)

    # Re-check that the length is within range
    if token_count < MIN_TOKENS or token_count > MAX_TOKENS:
        return None

    # Wrap in the pretraining format
    formatted_text = f"<|im_start|>{text}<|im_end|>"
    return formatted_text

def get_text_hash(text):
    """Hash a text for deduplication"""
    return hashlib.md5(text.encode('utf-8')).hexdigest()


def should_sample(dataset_name, current_count, config_dict=None, is_extra_mode=False):
    """Decide from the configuration whether to sample the current record"""
    if config_dict is None:
        config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG

    config = config_dict[dataset_name]

    # Stop once the maximum sample count is reached
    if config["max_samples"] and current_count >= config["max_samples"]:
        return False

    # Otherwise decide randomly according to the sampling ratio
    return random.random() < config["sample_ratio"]
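
# Behaviour sketch (illustrative, not in the original file): for openwebtext
# in the main configuration (sample_ratio=0.2, max_samples=2000000) roughly
# one candidate in five is admitted until two million records have been
# taken, after which should_sample always returns False.
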
def process_pretrain_hq():
    """Process the existing high-quality pretraining data - passed through unmodified"""
    logger.info("Processing pretrain_hq.jsonl...")
    count = 0

    with open(PRETRAIN_HQ_PATH, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc="Processing pretrain_hq"):
            try:
                data = json.loads(line.strip())
                text = data.get('text', '').strip()

                if text:  # any non-empty text is emitted directly, with no further checks
                    if should_sample("pretrain_hq", count):
                        yield {"text": text}
                        count += 1

            except json.JSONDecodeError:
                continue

    logger.info(f"Processed {count} records from pretrain_hq.jsonl")

def process_wikipedia(is_extra_mode=False):
    """Process the Wikipedia data"""
    mode_text = "extra" if is_extra_mode else "main"
    logger.info(f"Processing Wikipedia data ({mode_text} mode)...")
    count = 0
    config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG

    # Collect all English Wikipedia files
    wiki_files = list(Path(WIKIPEDIA_PATH).glob("*.parquet"))

    for file_path in tqdm(wiki_files, desc=f"Processing Wikipedia files ({mode_text})"):
        if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
            break

        try:
            df = pd.read_parquet(file_path)
            for _, row in df.iterrows():
                if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
                    break

                text = row.get('text', '').strip()
                if text and len(text) > 200:  # pre-filter texts that are too short
                    # First split the long text into medium-sized chunks
                    chunks = split_text_into_chunks(text, target_chunk_size=2000)

                    for chunk in chunks:
                        if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
                            break

                        chunk_hash = get_text_hash(chunk)

                        # In extra mode, skip chunks already selected for the main file
                        if is_extra_mode and chunk_hash in selected_data_hashes["wikipedia"]:
                            continue

                        formatted_text = format_text_for_pretrain(chunk)
                        if formatted_text and should_sample("wikipedia", count, config_dict, is_extra_mode):
                            # In main mode, record the hash
                            if not is_extra_mode:
                                selected_data_hashes["wikipedia"].add(chunk_hash)

                            yield {"text": formatted_text}
                            count += 1
        except Exception as e:
            logger.error(f"Error processing {file_path}: {e}")
            continue

    logger.info(f"Processed {count} records from Wikipedia ({mode_text} mode)")

def process_gutenberg(is_extra_mode=False):
    """Process the Gutenberg data"""
    mode_text = "extra" if is_extra_mode else "main"
    logger.info(f"Processing Gutenberg data ({mode_text} mode)...")
    count = 0
    config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG

    # Collect all Gutenberg training files
    gutenberg_files = list(Path(GUTENBERG_PATH).glob("train-*.parquet"))

    for file_path in tqdm(gutenberg_files, desc=f"Processing Gutenberg files ({mode_text})"):
        if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
            break

        try:
            df = pd.read_parquet(file_path)
            for _, row in df.iterrows():
                if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
                    break

                text = row.get('text', '').strip()
                if text and len(text) > 300 and is_english_text(text):  # pre-filter
                    # First split the long text into medium-sized chunks
                    chunks = split_text_into_chunks(text, target_chunk_size=1800)

                    for chunk in chunks:
                        if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
                            break

                        chunk_hash = get_text_hash(chunk)

                        # In extra mode, skip chunks already selected for the main file
                        if is_extra_mode and chunk_hash in selected_data_hashes["gutenberg"]:
                            continue

                        formatted_text = format_text_for_pretrain(chunk)
                        if formatted_text and should_sample("gutenberg", count, config_dict, is_extra_mode):
                            # In main mode, record the hash
                            if not is_extra_mode:
                                selected_data_hashes["gutenberg"].add(chunk_hash)

                            yield {"text": formatted_text}
                            count += 1
        except Exception as e:
            logger.error(f"Error processing {file_path}: {e}")
            continue

    logger.info(f"Processed {count} records from Gutenberg ({mode_text} mode)")

def process_openwebtext(is_extra_mode=False):
    """Process the OpenWebText data"""
    mode_text = "extra" if is_extra_mode else "main"
    logger.info(f"Processing OpenWebText data ({mode_text} mode)...")
    count = 0
    config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
    max_files = 5  # limit the number of files processed to keep runtime reasonable

    # Collect the tar files
    tar_files = list(Path(OPENWEBTEXT_PATH).glob("*.tar"))[:max_files]

    for tar_path in tqdm(tar_files, desc=f"Processing OpenWebText files ({mode_text})"):
        if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
            break

        try:
            with tarfile.open(tar_path, 'r') as outer_tar:
                # Extract the outer tar into a temporary directory
                with tempfile.TemporaryDirectory() as temp_dir:
                    outer_tar.extractall(temp_dir)

                    # Process the extracted xz files
                    for root, dirs, files in os.walk(temp_dir):
                        if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
                            break

                        for file in files:
                            if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
                                break

                            if file.endswith('.xz'):
                                xz_path = os.path.join(root, file)

                                # Use another temporary directory for the xz file
                                with tempfile.TemporaryDirectory() as xz_temp_dir:
                                    try:
                                        # Decompress the xz file
                                        import subprocess
                                        decompressed_path = os.path.join(xz_temp_dir, file[:-3])  # strip the .xz suffix
                                        with open(decompressed_path, 'wb') as out_f:
                                            subprocess.run(['xz', '-dc', xz_path],
                                                           stdout=out_f,
                                                           check=True)

                                        # Check whether the decompressed file is itself a tar archive
                                        if tarfile.is_tarfile(decompressed_path):
                                            # Process the inner tar file
                                            with tarfile.open(decompressed_path, 'r') as inner_tar:
                                                with tempfile.TemporaryDirectory() as inner_temp_dir:
                                                    inner_tar.extractall(inner_temp_dir)

                                                    # Process the final txt files
                                                    for inner_root, inner_dirs, inner_files in os.walk(inner_temp_dir):
                                                        if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
                                                            break

                                                        for txt_file in inner_files:
                                                            if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
                                                                break

                                                            if txt_file.endswith('.txt'):
                                                                txt_path = os.path.join(inner_root, txt_file)
                                                                try:
                                                                    with open(txt_path, 'r', encoding='utf-8', errors='ignore') as f:
                                                                        text = f.read().strip()
                                                                        if text and len(text) > 500 and is_english_text(text):
                                                                            # First split the long text into medium-sized chunks
                                                                            chunks = split_text_into_chunks(text, target_chunk_size=1600)

                                                                            for chunk in chunks:
                                                                                if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
                                                                                    break

                                                                                chunk_hash = get_text_hash(chunk)

                                                                                # In extra mode, skip chunks already selected for the main file
                                                                                if is_extra_mode and chunk_hash in selected_data_hashes["openwebtext"]:
                                                                                    continue

                                                                                formatted_text = format_text_for_pretrain(chunk)
                                                                                if formatted_text and should_sample("openwebtext", count, config_dict, is_extra_mode):
                                                                                    # In main mode, record the hash
                                                                                    if not is_extra_mode:
                                                                                        selected_data_hashes["openwebtext"].add(chunk_hash)

                                                                                    yield {"text": formatted_text}
                                                                                    count += 1
                                                                except Exception as e:
                                                                    logger.debug(f"Error reading txt file {txt_path}: {e}")
                                                                    continue
                                        else:
                                            # Not a tar archive: treat it as plain text
                                            try:
                                                with open(decompressed_path, 'r', encoding='utf-8', errors='ignore') as f:
                                                    text = f.read().strip()
                                                    if text and len(text) > 500 and is_english_text(text):
                                                        chunks = split_text_into_chunks(text, target_chunk_size=1600)

                                                        for chunk in chunks:
                                                            if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
                                                                break

                                                            chunk_hash = get_text_hash(chunk)

                                                            # In extra mode, skip chunks already selected for the main file
                                                            if is_extra_mode and chunk_hash in selected_data_hashes["openwebtext"]:
                                                                continue

                                                            formatted_text = format_text_for_pretrain(chunk)
                                                            if formatted_text and should_sample("openwebtext", count, config_dict, is_extra_mode):
                                                                # In main mode, record the hash
                                                                if not is_extra_mode:
                                                                    selected_data_hashes["openwebtext"].add(chunk_hash)

                                                                yield {"text": formatted_text}
                                                                count += 1
                                            except Exception as e:
                                                logger.debug(f"Error reading decompressed file {decompressed_path}: {e}")
                                                continue

                                    except Exception as e:
                                        logger.debug(f"Error processing xz file {xz_path}: {e}")
                                        continue
        except Exception as e:
            logger.error(f"Error processing {tar_path}: {e}")
            continue

    logger.info(f"Processed {count} records from OpenWebText ({mode_text} mode)")

def merge_datasets():
    """Merge all datasets, producing a main file and an extra file"""
    logger.info("Starting dataset merging...")
    logger.info("Main dataset configuration:")
    for name, config in DATASET_CONFIG.items():
        logger.info(f"  {name}: quality={config['quality']}, ratio={config['sample_ratio']}, max={config['max_samples']}")

    logger.info("Extra dataset configuration:")
    for name, config in DATASET_CONFIG_EXTRA.items():
        logger.info(f"  {name}: quality={config['quality']}, ratio={config['sample_ratio']}, max={config['max_samples']}")

    # Make sure the output directories exist
    os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True)
    os.makedirs(os.path.dirname(OUTPUT_FILE_EXTRA), exist_ok=True)

    # Statistics
    main_dataset_stats = {}
    extra_dataset_stats = {}

    # Phase 1: generate the main file
    logger.info("="*50)
    logger.info("PHASE 1: Generating main dataset file")
    logger.info("="*50)

    with open(OUTPUT_FILE, 'w', encoding='utf-8') as outfile:
        main_total_count = 0

        # Process each dataset (main mode)
        main_datasets = [
            ("pretrain_hq", process_pretrain_hq),
            ("wikipedia", lambda: process_wikipedia(is_extra_mode=False)),
            ("gutenberg", lambda: process_gutenberg(is_extra_mode=False)),
            ("openwebtext", lambda: process_openwebtext(is_extra_mode=False))
        ]

        for dataset_name, dataset_func in main_datasets:
            logger.info(f"Processing {dataset_name} for main file...")
            dataset_count = 0

            try:
                for record in dataset_func():
                    json.dump(record, outfile, ensure_ascii=False)
                    outfile.write('\n')
                    dataset_count += 1
                    main_total_count += 1

                    # Log progress every 5000 records
                    if main_total_count % 5000 == 0:
                        logger.info(f"Main file: Processed {main_total_count} total records")

                # Save the statistics
                main_dataset_stats[dataset_name] = {
                    'selected': dataset_count,
                    'config': DATASET_CONFIG[dataset_name]
                }

            except Exception as e:
                logger.error(f"Error processing {dataset_name} for main file: {e}")
                main_dataset_stats[dataset_name] = {
                    'selected': dataset_count,
                    'config': DATASET_CONFIG[dataset_name]
                }

            logger.info(f"Main file - Completed {dataset_name}: {dataset_count} records")

    logger.info(f"Main file generation completed. Total records: {main_total_count}")

    # Phase 2: generate the extra file
    logger.info("="*50)
    logger.info("PHASE 2: Generating extra dataset file")
    logger.info("="*50)

    with open(OUTPUT_FILE_EXTRA, 'w', encoding='utf-8') as outfile:
        extra_total_count = 0

        # Process each dataset (extra mode) - pretrain_hq is excluded
        extra_datasets = [
            ("wikipedia", lambda: process_wikipedia(is_extra_mode=True)),
            ("gutenberg", lambda: process_gutenberg(is_extra_mode=True)),
            ("openwebtext", lambda: process_openwebtext(is_extra_mode=True))
        ]

        for dataset_name, dataset_func in extra_datasets:
            logger.info(f"Processing {dataset_name} for extra file...")
            dataset_count = 0

            try:
                for record in dataset_func():
                    json.dump(record, outfile, ensure_ascii=False)
                    outfile.write('\n')
                    dataset_count += 1
                    extra_total_count += 1

                    # Log progress every 5000 records
                    if extra_total_count % 5000 == 0:
                        logger.info(f"Extra file: Processed {extra_total_count} total records")

                # Save the statistics
                extra_dataset_stats[dataset_name] = {
                    'selected': dataset_count,
                    'config': DATASET_CONFIG_EXTRA[dataset_name]
                }

            except Exception as e:
                logger.error(f"Error processing {dataset_name} for extra file: {e}")
                extra_dataset_stats[dataset_name] = {
                    'selected': dataset_count,
                    'config': DATASET_CONFIG_EXTRA[dataset_name]
                }

            logger.info(f"Extra file - Completed {dataset_name}: {dataset_count} records")

    logger.info(f"Extra file generation completed. Total records: {extra_total_count}")

    # Print detailed statistics
    print_detailed_statistics(main_dataset_stats, main_total_count, extra_dataset_stats, extra_total_count)

    logger.info("All dataset processing completed successfully!")
    logger.info(f"Main file saved to: {OUTPUT_FILE}")
    logger.info(f"Extra file saved to: {OUTPUT_FILE_EXTRA}")

def print_detailed_statistics(main_dataset_stats, main_total_count, extra_dataset_stats, extra_total_count):
    """Print detailed statistics"""
    print("\n" + "="*100)
    print("DATASET PROCESSING SUMMARY")
    print("="*100)

    # Main file statistics
    print("\nMAIN FILE (merged_pretrain.jsonl):")
    print("-" * 90)
    print(f"{'Dataset':<15} {'Selected':<10} {'Sample Ratio':<12} {'Max Limit':<12} {'% of Main':<12} {'Quality':<8}")
    print("-" * 90)

    for dataset_name, stats in main_dataset_stats.items():
        selected = stats['selected']
        config = stats['config']
        ratio = config['sample_ratio']
        max_limit = config['max_samples'] if config['max_samples'] else "No limit"
        percentage = (selected / main_total_count * 100) if main_total_count > 0 else 0
        quality = config['quality']

        print(f"{dataset_name:<15} {selected:<10,} {ratio:<12.1%} {str(max_limit):<12} {percentage:<12.2f}% {quality:<8}")

    print("-" * 90)
    print(f"{'MAIN TOTAL':<15} {main_total_count:<10,} {'':<12} {'':<12} {'100.00%':<12} {'':<8}")

    # Extra file statistics
    print("\nEXTRA FILE (merged_pretrain_extra.jsonl):")
    print("-" * 90)
    print(f"{'Dataset':<15} {'Selected':<10} {'Sample Ratio':<12} {'Max Limit':<12} {'% of Extra':<12} {'Quality':<8}")
    print("-" * 90)

    for dataset_name, stats in extra_dataset_stats.items():
        selected = stats['selected']
        config = stats['config']
        ratio = config['sample_ratio']
        max_limit = config['max_samples'] if config['max_samples'] else "No limit"
        percentage = (selected / extra_total_count * 100) if extra_total_count > 0 else 0
        quality = config['quality']

        print(f"{dataset_name:<15} {selected:<10,} {ratio:<12.1%} {str(max_limit):<12} {percentage:<12.2f}% {quality:<8}")

    print("-" * 90)
    print(f"{'EXTRA TOTAL':<15} {extra_total_count:<10,} {'':<12} {'':<12} {'100.00%':<12} {'':<8}")

    # Overall statistics
    total_records = main_total_count + extra_total_count
    print("\nOVERALL STATISTICS:")
    print("-" * 50)
    print(f"Main file records: {main_total_count:>10,}")
    print(f"Extra file records: {extra_total_count:>10,}")
    print(f"Total records: {total_records:>10,}")
    print(f"Token range per sample: {MIN_TOKENS}-{MAX_TOKENS} tokens")

    # Quality distribution statistics
    quality_stats = {}
    for dataset_name, stats in main_dataset_stats.items():
        quality = stats['config']['quality']
        if quality not in quality_stats:
            quality_stats[quality] = {'main': 0, 'extra': 0}
        quality_stats[quality]['main'] += stats['selected']

    for dataset_name, stats in extra_dataset_stats.items():
        quality = stats['config']['quality']
        if quality not in quality_stats:
            quality_stats[quality] = {'main': 0, 'extra': 0}
        quality_stats[quality]['extra'] += stats['selected']

    print("\nQUALITY DISTRIBUTION:")
    print("-" * 60)
    print(f"{'Quality':<12} {'Main File':<12} {'Extra File':<12} {'Total':<12} {'%':<8}")
    print("-" * 60)
    for quality in sorted(quality_stats.keys()):
        main_count = quality_stats[quality]['main']
        extra_count = quality_stats[quality]['extra']
        total_count = main_count + extra_count
        percentage = (total_count / total_records * 100) if total_records > 0 else 0
        print(f"{quality.capitalize():<12} {main_count:<12,} {extra_count:<12,} {total_count:<12,} {percentage:<8.2f}%")
    print("-" * 60)
    print(f"{'Total':<12} {main_total_count:<12,} {extra_total_count:<12,} {total_records:<12,} {'100.00%':<8}")

    print(f"\nFiles saved to:")
    print(f"  Main file: {OUTPUT_FILE}")
    print(f"  Extra file: {OUTPUT_FILE_EXTRA}")
    print("="*100)

def main():
    """Entry point"""
    try:
        # Fix the random seed for reproducibility
        random.seed(42)

        # Check dependencies
        try:
            import langdetect
            from transformers import AutoTokenizer
        except ImportError as e:
            logger.error(f"Missing dependencies: {e}")
            logger.error("Please install: pip install langdetect transformers")
            return

        # Initialize the tokenizer
        init_tokenizer()

        # Check the input file
        if not os.path.exists(PRETRAIN_HQ_PATH):
            logger.error(f"pretrain_hq.jsonl not found at {PRETRAIN_HQ_PATH}")
            return

        # Merge the datasets
        merge_datasets()
        logger.info("All processing completed successfully!")

    except Exception as e:
        logger.error(f"Error in main process: {e}")
        raise


if __name__ == "__main__":
    main()
@ -1,442 +0,0 @@
import json
import os
import argparse
from typing import List, Dict, Any, Optional
from collections import defaultdict
import pickle
from pathlib import Path

class WikidataRelationManager:
    """Wikidata relation manager with dynamic lookup and caching"""

    def __init__(self, cache_file: str = "wikidata_relations_cache.pkl",
                 mapping_file: str = None):
        self.cache_file = cache_file
        self.mapping_file = mapping_file
        self.relations = {}
        # API-related attributes removed

        # Initial base relation mapping
        self.base_relations = {
            # # Basic relations
            # 'P31': 'instance of',
            # 'P279': 'subclass of',
            # 'P17': 'country',
            # 'P159': 'headquarters location',
            # 'P571': 'inception',

            # # Person relations
            # 'P19': 'place of birth',
            # 'P20': 'place of death',
            # 'P27': 'country of citizenship',
            # 'P106': 'occupation',
            # 'P22': 'father',
            # 'P25': 'mother',
            # 'P26': 'spouse',
            # 'P40': 'child',
            # 'P69': 'educated at',
            # 'P108': 'employer',

            # # Geographic relations
            # 'P36': 'capital',
            # 'P131': 'located in',
            # 'P47': 'shares border with',
            # 'P206': 'located on terrain feature',
            # 'P1376': 'capital of',

            # # Organization relations
            # 'P112': 'founded by',
            # 'P127': 'owned by',
            # 'P169': 'chief executive officer',
            # 'P488': 'chairperson',
            # 'P749': 'parent organization',

            # # Work relations
            # 'P50': 'author',
            # 'P57': 'director',
            # 'P58': 'screenwriter',
            # 'P161': 'cast member',
            # 'P175': 'performer',
            # 'P577': 'publication date',
            # 'P123': 'publisher',
            # 'P136': 'genre',

            # # Temporal relations
            # 'P155': 'follows',
            # 'P156': 'followed by',
            # 'P580': 'start time',
            # 'P582': 'end time',

            # # Sports relations
            # 'P54': 'member of sports team',
            # 'P413': 'position played on team',
            # 'P118': 'league',

            # # Science relations
            # 'P275': 'copyright license',
            # 'P170': 'creator',
            # 'P398': 'child astronomical body',
            # 'P397': 'parent astronomical body',

            # # Other common relations
            # 'P37': 'official language',
            # 'P1923': 'place of marriage',
            # 'P737': 'influenced by',
            # 'P463': 'member of',
            # 'P39': 'position held',
            # 'P276': 'location',
            # 'P1441': 'present in work',
        }

        self.load_cache()

    def load_cache(self):
        """Load the cached relation mapping, preferring the JSON mapping file"""
        try:
            # Try the JSON mapping file first
            if self.mapping_file and os.path.exists(self.mapping_file):
                with open(self.mapping_file, 'r', encoding='utf-8') as f:
                    self.relations = json.load(f)
                print(f"Loaded {len(self.relations)} relation mappings from the JSON mapping file")
                return

            # Then try the pickle cache file
            if os.path.exists(self.cache_file):
                with open(self.cache_file, 'rb') as f:
                    self.relations = pickle.load(f)
                print(f"Loaded {len(self.relations)} relation mappings from the pickle cache")
            else:
                self.relations = self.base_relations.copy()
                print(f"Initialized the base relation mapping: {len(self.relations)} entries")
        except Exception as e:
            print(f"Failed to load the cache: {e}")
            self.relations = self.base_relations.copy()

    def save_cache(self):
        """Save the relation mapping to the cache"""
        try:
            with open(self.cache_file, 'wb') as f:
                pickle.dump(self.relations, f)
            print(f"Saved {len(self.relations)} relation mappings to the cache")
        except Exception as e:
            print(f"Failed to save the cache: {e}")

    # Web scraping removed; the manager now works fully offline

    def get_relation_name(self, property_id: str) -> Optional[str]:
        """Look up a relation name using only the local mapping"""
        if property_id in self.relations:
            return self.relations[property_id]

        # Not found locally: return None (meaning this relation is skipped)
        return None

    # Batch fetching and preloading over the network removed

class TRexProcessor:
|
||||
"""T-REx数据集处理器"""
|
||||
|
||||
def __init__(self, relation_manager: WikidataRelationManager):
|
||||
self.relation_manager = relation_manager
|
||||
|
||||
def extract_predicate_id(self, uri: str) -> str:
|
||||
"""从URI中提取属性ID"""
|
||||
if uri and 'prop/direct/' in uri:
|
||||
return uri.split('/')[-1]
|
||||
elif uri and uri.startswith('P') and uri[1:].isdigit():
|
||||
return uri
|
||||
return uri if uri else 'unknown'
|
||||
|
||||
def get_relation_name(self, predicate_uri: str) -> Optional[str]:
|
||||
"""获取关系的可读名称"""
|
||||
predicate_id = self.extract_predicate_id(predicate_uri)
|
||||
return self.relation_manager.get_relation_name(predicate_id)
|
||||
|
||||
# 删除了谓词收集功能,因为不再需要预加载
|
||||
|
||||
def is_valid_triple(self, triple: Dict[str, Any], confidence_threshold: float,
|
||||
boundary_threshold: int) -> bool:
|
||||
"""检查三元组是否满足过滤条件"""
|
||||
try:
|
||||
# 检查triple是否为字典
|
||||
if not isinstance(triple, dict):
|
||||
return False
|
||||
|
||||
# 检查必要字段
|
||||
if not all(key in triple for key in ['subject', 'predicate', 'object']):
|
||||
return False
|
||||
|
||||
subject = triple['subject']
|
||||
predicate = triple['predicate']
|
||||
object_info = triple['object']
|
||||
|
||||
# 检查subject、predicate、object是否都为字典
|
||||
if not isinstance(subject, dict) or not isinstance(predicate, dict) or not isinstance(object_info, dict):
|
||||
return False
|
||||
|
||||
# 检查主语和宾语是否有有效的URI和surfaceform
|
||||
if not (subject.get('uri') and subject.get('surfaceform')):
|
||||
return False
|
||||
if not (object_info.get('uri') and object_info.get('surfaceform')):
|
||||
return False
|
||||
if not predicate.get('uri'):
|
||||
return False
|
||||
|
||||
# 检查置信度(如果存在)
|
||||
confidence = triple.get('confidence')
|
||||
if confidence is not None and confidence < confidence_threshold:
|
||||
return False
|
||||
|
||||
# 检查边界信息(如果设置了阈值)
|
||||
if boundary_threshold > 0:
|
||||
subject_boundaries = subject.get('boundaries')
|
||||
object_boundaries = object_info.get('boundaries')
|
||||
|
||||
if not subject_boundaries or not object_boundaries:
|
||||
return False
|
||||
|
||||
# 检查边界是否为列表且长度至少为2
|
||||
if not (isinstance(subject_boundaries, list) and len(subject_boundaries) >= 2):
|
||||
return False
|
||||
if not (isinstance(object_boundaries, list) and len(object_boundaries) >= 2):
|
||||
return False
|
||||
|
||||
try:
|
||||
# 检查边界长度是否合理
|
||||
subject_length = subject_boundaries[1] - subject_boundaries[0]
|
||||
object_length = object_boundaries[1] - object_boundaries[0]
|
||||
|
||||
if subject_length < boundary_threshold or object_length < boundary_threshold:
|
||||
return False
|
||||
except (TypeError, IndexError):
|
||||
return False
|
||||
|
||||
# 检查文本内容是否合理
|
||||
subject_text = subject.get('surfaceform', '').strip()
|
||||
object_text = object_info.get('surfaceform', '').strip()
|
||||
|
||||
if not subject_text or not object_text:
|
||||
return False
|
||||
|
||||
# 过滤掉过长或过短的实体
|
||||
if len(subject_text) > 100 or len(object_text) > 100:
|
||||
return False
|
||||
if len(subject_text) < 2 or len(object_text) < 2:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except (KeyError, TypeError, AttributeError):
|
||||
return False
|
||||
|
||||
def process_single_file(self, file_path: str, confidence_threshold: float,
|
||||
boundary_threshold: int) -> List[Dict[str, Any]]:
|
||||
"""处理单个JSON文件"""
|
||||
print(f"Processing file: {file_path}")
|
||||
|
||||
processed_data = []
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
# 读取整个文件作为JSON数组
|
||||
print(f"正在加载JSON数组文件: {file_path}")
|
||||
data_list = json.load(f)
|
||||
print(f"文件包含 {len(data_list)} 个条目")
|
||||
|
||||
for idx, data in enumerate(data_list):
|
||||
try:
|
||||
# 获取基本信息
|
||||
text = data.get('text', '').strip()
|
||||
if not text:
|
||||
continue
|
||||
|
||||
# 处理三元组
|
||||
triples = data.get('triples', [])
|
||||
if not triples:
|
||||
continue
|
||||
|
||||
valid_targets = []
|
||||
|
||||
for triple in triples:
|
||||
if self.is_valid_triple(triple, confidence_threshold, boundary_threshold):
|
||||
# 获取关系名称,如果无法解析则跳过这个三元组
|
||||
relation_name = self.get_relation_name(triple['predicate']['uri'])
|
||||
if relation_name is None:
|
||||
continue # 跳过无法解析的关系
|
||||
|
||||
target = {
|
||||
'subject': triple['subject']['surfaceform'].strip(),
|
||||
'predicate': relation_name,
|
||||
'object': triple['object']['surfaceform'].strip()
|
||||
}
|
||||
valid_targets.append(target)
|
||||
|
||||
# 如果有有效的三元组,添加到结果中
|
||||
if valid_targets:
|
||||
processed_data.append({
|
||||
'text': text,
|
||||
'target': valid_targets
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
if idx <= 10: # 只打印前10个错误
|
||||
print(f"处理条目时出错 in {file_path} at index {idx}: {e}")
|
||||
continue
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"文件未找到: {file_path}")
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON解析错误 in {file_path}: {e}")
|
||||
except Exception as e:
|
||||
print(f"处理文件时出错 {file_path}: {e}")
|
||||
|
||||
print(f"从 {file_path} 提取了 {len(processed_data)} 个有效样本")
|
||||
return processed_data
|
||||
|
||||
def process_folder(self, folder_path: str, confidence_threshold: float,
|
||||
boundary_threshold: int) -> List[Dict[str, Any]]:
|
||||
"""处理文件夹中的所有JSON文件"""
|
||||
all_processed_data = []
|
||||
|
||||
if not os.path.exists(folder_path):
|
||||
raise FileNotFoundError(f"文件夹不存在: {folder_path}")
|
||||
|
||||
# 获取所有JSON文件
|
||||
json_files = [f for f in os.listdir(folder_path) if f.endswith('.json')]
|
||||
|
||||
if not json_files:
|
||||
raise ValueError(f"在 {folder_path} 中没有找到JSON文件")
|
||||
|
||||
print(f"找到 {len(json_files)} 个JSON文件")
|
||||
|
||||
for filename in sorted(json_files):
|
||||
file_path = os.path.join(folder_path, filename)
|
||||
processed_data = self.process_single_file(file_path, confidence_threshold, boundary_threshold)
|
||||
all_processed_data.extend(processed_data)
|
||||
|
||||
# 保存最终的关系缓存
|
||||
self.relation_manager.save_cache()
|
||||
|
||||
return all_processed_data
|
||||
|
||||
def generate_statistics(self, processed_data: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""生成数据统计信息"""
|
||||
total_samples = len(processed_data)
|
||||
total_triples = sum(len(sample['target']) for sample in processed_data)
|
||||
|
||||
# 统计关系类型
|
||||
relation_counts = defaultdict(int)
|
||||
for sample in processed_data:
|
||||
for target in sample['target']:
|
||||
relation_counts[target['predicate']] += 1
|
||||
|
||||
# 统计文本长度
|
||||
text_lengths = [len(sample['text']) for sample in processed_data]
|
||||
avg_text_length = sum(text_lengths) / len(text_lengths) if text_lengths else 0
|
||||
|
||||
# 统计每个文本的三元组数量
|
||||
triples_per_text = [len(sample['target']) for sample in processed_data]
|
||||
avg_triples_per_text = sum(triples_per_text) / len(triples_per_text) if triples_per_text else 0
|
||||
|
||||
return {
|
||||
'total_samples': total_samples,
|
||||
'total_triples': total_triples,
|
||||
'avg_text_length': round(avg_text_length, 2),
|
||||
'avg_triples_per_text': round(avg_triples_per_text, 2),
|
||||
'relation_distribution': dict(sorted(relation_counts.items(),
|
||||
key=lambda x: x[1], reverse=True)),
|
||||
'top_10_relations': dict(list(sorted(relation_counts.items(),
|
||||
key=lambda x: x[1], reverse=True))[:10]),
|
||||
'total_unique_relations': len(relation_counts),
|
||||
'cached_relations': len(self.relation_manager.relations)
|
||||
}
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='处理T-REx数据集(支持动态关系获取)')
|
||||
parser.add_argument('--folder_path', type=str,default='/home/pci/ycz/Code/Minimind/dataset/trex', help='包含JSON文件的文件夹路径')
|
||||
parser.add_argument('--confidence_threshold', type=float, default=0.5,
|
||||
help='置信度阈值 (默认: 0.0)')
|
||||
parser.add_argument('--boundary_threshold', type=int, default=0,
|
||||
help='边界长度阈值 (默认: 0, 不过滤)')
|
||||
parser.add_argument('--output', type=str, default='./processed_trex_data.json',
|
||||
help='输出文件名 (默认: processed_trex_data.json)')
|
||||
parser.add_argument('--stats', type=str, default='trex_statistics.json',
|
||||
help='统计信息输出文件名 (默认: trex_statistics.json)')
|
||||
parser.add_argument('--cache_file', type=str, default='wikidata_relations_cache.pkl',
|
||||
help='关系缓存文件名 (默认: wikidata_relations_cache.pkl)')
|
||||
parser.add_argument('--mapping_file', type=str, default="/home/pci/ycz/Code/Minimind/preprocessing/sample_property_mappings.json",
|
||||
help='JSON映射文件路径 (必须提供,用于关系名称映射)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("T-REx数据集处理器(支持动态关系获取)")
|
||||
print("=" * 60)
|
||||
print(f"输入文件夹: {args.folder_path}")
|
||||
print(f"置信度阈值: {args.confidence_threshold}")
|
||||
print(f"边界长度阈值: {args.boundary_threshold}")
|
||||
print(f"输出文件: {args.output}")
|
||||
print(f"关系缓存文件: {args.cache_file}")
|
||||
print(f"JSON映射文件: {args.mapping_file if args.mapping_file else '未指定'}")
|
||||
print("=" * 60)
|
||||
|
||||
# 检查映射文件是否存在
|
||||
if not args.mapping_file or not os.path.exists(args.mapping_file):
|
||||
print(f"错误: 映射文件不存在或未指定: {args.mapping_file}")
|
||||
print("请确保提供有效的JSON映射文件。")
|
||||
return 1
|
||||
|
||||
# 创建关系管理器
|
||||
relation_manager = WikidataRelationManager(
|
||||
cache_file=args.cache_file,
|
||||
mapping_file=args.mapping_file
|
||||
)
|
||||
|
||||
# 创建处理器
|
||||
processor = TRexProcessor(relation_manager)
|
||||
|
||||
try:
|
||||
# 处理数据
|
||||
processed_data = processor.process_folder(
|
||||
args.folder_path,
|
||||
args.confidence_threshold,
|
||||
args.boundary_threshold
|
||||
)
|
||||
|
||||
print(f"\n处理完成!总共处理了 {len(processed_data)} 个样本")
|
||||
|
||||
# 生成统计信息
|
||||
stats = processor.generate_statistics(processed_data)
|
||||
|
||||
# 保存处理后的数据
|
||||
with open(args.output, 'w', encoding='utf-8') as f:
|
||||
json.dump(processed_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 保存统计信息
|
||||
with open(args.stats, 'w', encoding='utf-8') as f:
|
||||
json.dump(stats, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"\n数据已保存到: {args.output}")
|
||||
print(f"统计信息已保存到: {args.stats}")
|
||||
print(f"关系缓存已保存到: {args.cache_file}")
|
||||
|
||||
# 打印统计摘要
|
||||
print("\n数据统计摘要:")
|
||||
print("=" * 30)
|
||||
print(f"总样本数: {stats['total_samples']}")
|
||||
print(f"总三元组数: {stats['total_triples']}")
|
||||
print(f"唯一关系数: {stats['total_unique_relations']}")
|
||||
print(f"缓存关系数: {stats['cached_relations']}")
|
||||
print(f"平均文本长度: {stats['avg_text_length']}")
|
||||
print(f"平均每文本三元组数: {stats['avg_triples_per_text']}")
|
||||
print("\n前10个最常见关系:")
|
||||
for relation, count in stats['top_10_relations'].items():
|
||||
print(f" {relation}: {count}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理过程中出错: {e}")
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
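For reference, each record that process_single_file above appends to processed_data has the following shape; the values here are illustrative, not taken from the real T-REx corpus:

{
  "text": "Albert Einstein was born in Ulm.",
  "target": [
    {"subject": "Albert Einstein", "predicate": "place of birth", "object": "Ulm"}
  ]
}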
@ -1,441 +0,0 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import json
import re
import asyncio
import aiofiles
from concurrent.futures import ThreadPoolExecutor
from preprocessing.agent_system.extractor_agent.agent import DepartmentAgent
from typing import Dict, List, Tuple
import gc
import time
import psutil
from tqdm.asyncio import tqdm as async_tqdm
from tqdm import tqdm

json_path = "dataset/merged_pretrain_extra.jsonl"
output_path = "dataset/processed_triples.jsonl"

# Tuned configuration parameters - reduce resource consumption
BATCH_SIZE = 5000      # batch size: 5,000 records per batch
MAX_CONCURRENT = 200   # concurrency limit: at most 200 in-flight tasks
AGENT_POOL_SIZE = 20   # agent pool size: create 20 agent instances


def get_memory_usage():
    """Return the current memory usage of this process in MB."""
    process = psutil.Process(os.getpid())
    memory_info = process.memory_info()
    memory_mb = memory_info.rss / 1024 / 1024
    return memory_mb


def print_memory_info(stage=""):
    """Print memory usage information."""
    memory_mb = get_memory_usage()
    print(f"🔧 {stage} - memory usage: {memory_mb:.1f} MB")


# Create a pool of extractor agents to avoid concurrency conflicts
def create_extractor_pool(pool_size: int = 5):
    """Create the extractor_agent pool."""
    print(f"Creating {pool_size} agent instances...")
    agents = []
    for i in range(pool_size):
        try:
            agent = DepartmentAgent(model_type="deepseek")
            agents.append(agent)
            print(f"  ✓ Agent {i+1}/{pool_size} created")
        except Exception as e:
            print(f"  ✗ Agent {i+1} creation failed: {e}")
    print(f"Agent pool ready; {len(agents)} instances actually created")
    return agents


# Lazily initialized agent pool
AGENT_POOL = None
agent_pool_index = 0


def get_agent_pool():
    """Return the agent pool, initializing it lazily."""
    global AGENT_POOL
    if AGENT_POOL is None:
        print_memory_info("Before creating agent pool")
        AGENT_POOL = create_extractor_pool(pool_size=AGENT_POOL_SIZE)
        print_memory_info("After creating agent pool")
    return AGENT_POOL


def get_next_agent():
    """Round-robin over the pool to get the next available agent."""
    global agent_pool_index
    pool = get_agent_pool()
    agent = pool[agent_pool_index % len(pool)]
    agent_pool_index += 1
    return agent


def clean_and_split_text(text):
    """
    Strip the leading/trailing markers from the text and split it into sentences.
    """
    # Remove a leading <|im_start|> and a trailing <|im_end|>
    text = text.strip()
    if text.startswith('<|im_start|>'):
        text = text[len('<|im_start|>'):]
    if text.endswith('<|im_end|>'):
        text = text[:-len('<|im_end|>')]

    # Clean up extra whitespace
    text = text.strip()

    # Split into sentences on end-of-sentence punctuation
    # (periods, question marks, exclamation marks; ASCII and fullwidth)
    sentence_endings = r'[.!?。!?]'
    sentences = re.split(sentence_endings, text)

    # Trim each sentence and drop empty ones
    cleaned_sentences = []
    for sentence in sentences:
        sentence = sentence.strip()
        if sentence and len(sentence) > 5:  # keep only non-empty, meaningful sentences
            cleaned_sentences.append(sentence)

    return cleaned_sentences


async def extract_triple_from_sentence_async(sentence: str, context: str = None) -> Dict:
    """
    Asynchronously extract a triple from a sentence using an extractor agent.
    """
    try:
        # Grab an agent instance
        agent = get_next_agent()
        result = await agent.async_run(sentence=sentence, context=context)
        return {
            "sentence": sentence,
            "triple": {
                "subject": result.triple.subject,
                "predicate": result.triple.predicate,
                "object": result.triple.object
            },
            "confidence": result.confidence
        }
    except Exception as e:
        return {
            "sentence": sentence,
            "triple": {
                "subject": "",
                "predicate": "",
                "object": ""
            },
            "confidence": 0.0,
            "error": str(e)
        }


async def process_paragraph_async(line_num: int, original_text: str, semaphore: asyncio.Semaphore) -> Dict:
    """
    Asynchronously process a single paragraph.
    """
    async with semaphore:  # bound the number of concurrent tasks
        try:
            # Clean and split the text
            sentences = clean_and_split_text(original_text)

            if not sentences:
                return None

            # Build the result record for this paragraph
            paragraph_result = {
                "source_line": line_num,
                "original_paragraph": original_text,
                "sentences": [],
                "triples": []
            }

            # Process all sentences concurrently
            tasks = []
            for sentence in sentences:
                task = extract_triple_from_sentence_async(sentence, context=original_text)
                tasks.append(task)

            # Wait for every sentence to finish
            triple_results = await asyncio.gather(*tasks)

            # Collect the results
            for i, sentence in enumerate(sentences):
                paragraph_result["sentences"].append(sentence)
                paragraph_result["triples"].append(triple_results[i])

            return paragraph_result

        except Exception as e:
            print(f"Error processing line {line_num}: {e}")
            return None


async def process_batch_async(batch_data: List[Tuple[int, str]], batch_num: int) -> List[Dict]:
    """
    Asynchronously process one batch of data, with a progress bar and memory monitoring.
    """
    print(f"\n=== Async processing of batch {batch_num} ===")
    print(f"Batch size: {len(batch_data)} records")
    print_memory_info(f"Before batch {batch_num}")

    start_time = time.time()

    # Semaphore that bounds the concurrency
    semaphore = asyncio.Semaphore(MAX_CONCURRENT)

    # Process tasks in chunks to avoid creating too many tasks at once
    chunk_size = 1000  # 1000 tasks per chunk
    all_results = []

    for chunk_start in range(0, len(batch_data), chunk_size):
        chunk_end = min(chunk_start + chunk_size, len(batch_data))
        chunk_data = batch_data[chunk_start:chunk_end]

        print(f"Processing chunk {chunk_start//chunk_size + 1}/{(len(batch_data)-1)//chunk_size + 1} ({len(chunk_data)} records)")

        # Create the async tasks for this chunk
        tasks = []
        for line_num, original_text in chunk_data:
            task = process_paragraph_async(line_num, original_text, semaphore)
            tasks.append(task)

        # Progress bar for the current chunk
        progress_bar = tqdm(total=len(tasks), desc=f"batch{batch_num}-chunk{chunk_start//chunk_size + 1}", unit="paragraph", ncols=100)

        chunk_results = []
        completed_tasks = 0

        # Use as_completed to collect finished tasks and update the progress bar
        for coro in asyncio.as_completed(tasks):
            try:
                result = await coro
                chunk_results.append(result)
                completed_tasks += 1

                # Advance the progress bar
                progress_bar.update(1)

                # Refresh the postfix every 50 completed tasks
                if completed_tasks % 50 == 0:
                    valid_results = [r for r in chunk_results if r is not None]
                    progress_bar.set_postfix({
                        'valid': len(valid_results),
                        'done': completed_tasks,
                        'success': f"{len(valid_results)/completed_tasks*100:.1f}%"
                    })
            except Exception as e:
                print(f"Task failed: {e}")
                completed_tasks += 1
                progress_bar.update(1)

        progress_bar.close()
        all_results.extend(chunk_results)

        # Free memory after each chunk
        del tasks, chunk_results
        gc.collect()

        print_memory_info(f"After batch {batch_num} chunk {chunk_start//chunk_size + 1}")

    # Filter out None results
    valid_results = [result for result in all_results if result is not None]

    # Statistics
    batch_sentences = sum(len(result["sentences"]) for result in valid_results)
    batch_triples = sum(
        sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
        for result in valid_results
    )

    end_time = time.time()
    processing_time = end_time - start_time

    print(f"Batch {batch_num} finished:")
    print(f"  - valid paragraphs: {len(valid_results)}/{len(batch_data)} ({len(valid_results)/len(batch_data)*100:.1f}%)")
    print(f"  - total sentences: {batch_sentences}")
    print(f"  - successful triples: {batch_triples}")
    print(f"  - triple success rate: {batch_triples/batch_sentences*100:.1f}%" if batch_sentences > 0 else "no sentences")
    print(f"  - processing time: {processing_time:.2f}s")
    print(f"  - throughput: {len(batch_data)/processing_time:.2f} paragraphs/s")

    print_memory_info(f"After batch {batch_num}")

    return valid_results


async def write_results_batch(results: List[Dict], output_path: str):
    """
    Write results asynchronously in bulk, with progress messages.
    """
    try:
        print(f"Bulk-writing {len(results)} results...")

        # Prepare the lines to write
        content_lines = []
        for result in results:
            content_lines.append(json.dumps(result, ensure_ascii=False))

        # Async bulk write
        async with aiofiles.open(output_path, "a", encoding="utf-8") as f:
            await f.write("\n".join(content_lines) + "\n")

        print(f"✓ Successfully wrote {len(results)} results to {output_path}")

    except Exception as e:
        print(f"✗ Bulk write failed: {e}")
        print("Falling back to writing one record at a time...")

        # If the bulk write fails, fall back to per-record writes (with a progress bar)
        async with aiofiles.open(output_path, "a", encoding="utf-8") as f:
            for result in tqdm(results, desc="per-record write", unit="record"):
                await f.write(json.dumps(result, ensure_ascii=False) + "\n")
        print(f"✓ Per-record writing finished")


# Main processing pipeline
async def main_async():
    total_processed = 0
    total_sentences = 0
    total_triples = 0
    batch_num = 0

    print("=== Starting async batched processing of the JSONL file ===")
    print(f"Configuration:")
    print(f"  - batch size: {BATCH_SIZE:,} records")
    print(f"  - max concurrency: {MAX_CONCURRENT}")
    print(f"  - agent pool size: {AGENT_POOL_SIZE}")
    print(f"  - input file: {json_path}")
    print(f"  - output file: {output_path}")
    print()

    print_memory_info("Program start")

    # Truncate the output file
    async with aiofiles.open(output_path, "w", encoding="utf-8") as f:
        pass

    # Read and process the data
    with open(json_path, "r", encoding="utf-8") as f_in:
        batch_data = []

        for line_num, line in enumerate(f_in):
            if line.strip():  # skip empty lines
                try:
                    item = json.loads(line)
                    original_text = item.get("text", "")

                    if original_text:
                        batch_data.append((line_num + 1, original_text))

                    # When the batch reaches the configured size, process it
                    if len(batch_data) >= BATCH_SIZE:
                        batch_num += 1

                        # Process the batch asynchronously
                        batch_results = await process_batch_async(batch_data, batch_num)

                        # Bulk-write the results
                        if batch_results:
                            await write_results_batch(batch_results, output_path)

                            # Statistics
                            batch_sentences = sum(len(result["sentences"]) for result in batch_results)
                            batch_triples = sum(
                                sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
                                for result in batch_results
                            )

                            total_processed += len(batch_data)
                            total_sentences += batch_sentences
                            total_triples += batch_triples

                            print(f"\n📊 Cumulative statistics after batch {batch_num}:")
                            print(f"  - paragraphs processed: {total_processed:,}")
                            print(f"  - sentences: {total_sentences:,}")
                            print(f"  - triples: {total_triples:,}")
                            print(f"  - overall success rate: {total_triples/total_sentences*100:.1f}%")
                            print("-" * 80)

                        # Clear the batch data to free memory
                        batch_data.clear()
                        batch_results.clear()
                        gc.collect()  # force garbage collection

                        print_memory_info(f"After cleaning up batch {batch_num}")

                except json.JSONDecodeError as e:
                    print(f"JSON decode error on line {line_num + 1}: {e}")
                except Exception as e:
                    print(f"Error processing line {line_num + 1}: {e}")

        # Process the final, possibly incomplete batch
        if batch_data:
            batch_num += 1
            batch_results = await process_batch_async(batch_data, batch_num)

            if batch_results:
                await write_results_batch(batch_results, output_path)

                batch_sentences = sum(len(result["sentences"]) for result in batch_results)
                batch_triples = sum(
                    sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
                    for result in batch_results
                )

                total_processed += len(batch_data)
                total_sentences += batch_sentences
                total_triples += batch_triples

    # Final statistics
    print(f"\n{'='*80}")
    print(f"🎉 All batches processed!")
    print(f"{'='*80}")
    print(f"Final statistics:")
    print(f"  - total batches: {batch_num}")
    print(f"  - total paragraphs: {total_processed:,}")
    print(f"  - total sentences: {total_sentences:,}")
    print(f"  - total triples: {total_triples:,}")
    print(f"  - overall success rate: {total_triples/total_sentences*100:.1f}%" if total_sentences > 0 else "no valid sentences")
    print(f"  - output file: {output_path}")
    print(f"{'='*80}")

    print_memory_info("Before program end")

    # Show a few sample results
    await show_sample_results()


async def show_sample_results():
    """Show the first few processed results as examples."""
    print("\n📋 First 3 processed results:")
    try:
        async with aiofiles.open(output_path, "r", encoding="utf-8") as f:
            i = 0
            async for line in f:
                if i >= 3:
                    break
                item = json.loads(line)
                print(f"\n--- Paragraph {i+1} (source line: {item['source_line']}) ---")
                print(f"Original paragraph: {item['original_paragraph'][:100]}...")
                print(f"Sentence count: {len(item['sentences'])}")
                if item['triples']:
                    for j, triple in enumerate(item['triples'][:2]):  # show only the first 2 triples
                        print(f"  Sentence {j+1}: {triple['sentence'][:50]}...")
                        if triple['confidence'] > 0:
                            print(f"    Triple: {triple['triple']['subject']} -> {triple['triple']['predicate']} -> {triple['triple']['object']}")
                            print(f"    Confidence: {triple['confidence']:.2f}")
                        else:
                            print(f"    Extraction failed: {triple.get('error', 'unknown error')}")
                i += 1
    except Exception as e:
        print(f"Error reading sample results: {e}")


def main():
    """Main entry point."""
    try:
        # Run the async main function
        asyncio.run(main_async())
    except KeyboardInterrupt:
        print("\n⚠️ Interrupted by user")
    except Exception as e:
        print(f"❌ Error during processing: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
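Each line of processed_triples.jsonl serializes one paragraph_result built above; a sketch of the shape, with illustrative values:

{"source_line": 1, "original_paragraph": "Paris is the capital of France.", "sentences": ["Paris is the capital of France"], "triples": [{"sentence": "Paris is the capital of France", "triple": {"subject": "Paris", "predicate": "capital of", "object": "France"}, "confidence": 0.95}]}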
176
pyproject.toml
@ -1,176 +0,0 @@
[project]
name = "minimind"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "accelerate==1.7.0",
    "aiohappyeyeballs==2.6.1",
    "aiohttp==3.11.17",
    "aiosignal==1.3.2",
    "altair==5.5.0",
    "annotated-types==0.7.0",
    "anyio==4.9.0",
    "async-timeout==5.0.1",
    "attrs==25.3.0",
    "blinker==1.9.0",
    "boto3==1.38.41",
    "botocore==1.38.41",
    "cachetools==5.5.2",
    "certifi==2025.1.31",
    "charset-normalizer==3.4.1",
    "click==8.1.8",
    "contourpy==1.3.2",
    "cycler==0.12.1",
    "datasets==2.21.0",
    "datasketch==1.6.4",
    "deepspeed==0.17.0",
    "determined>=0.37.0",
    "dill==0.3.8",
    "distro==1.9.0",
    "docker-pycreds==0.4.0",
    "einops==0.8.1",
    "exceptiongroup==1.2.2",
    "filelock==3.18.0",
    "Flask==3.0.3",
    "Flask-Cors==4.0.0",
    "fonttools==4.57.0",
    "frozenlist==1.6.0",
    "fsspec==2024.6.1",
    "gitdb==4.0.12",
    "GitPython==3.1.44",
    "h11==0.14.0",
    "hjson==3.1.0",
    "httpcore==1.0.8",
    "httpx==0.28.1",
    "huggingface-hub==0.30.2",
    "importlib_metadata==7.2.1",
    "itsdangerous==2.2.0",
    "jieba==0.42.1",
    "Jinja2==3.1.2",
    "jiter==0.9.0",
    "jmespath==1.0.1",
    "joblib==1.4.2",
    "jsonlines==4.0.0",
    "jsonpointer==2.1",
    "jsonschema==4.23.0",
    "jsonschema-specifications==2024.10.1",
    "kiwisolver==1.4.8",
    "langdetect==1.0.9",
    "markdown-it-py==3.0.0",
    "MarkupSafe==3.0.2",
    "marshmallow==3.22.0",
    "matplotlib==3.10.0",
    "mdurl==0.1.2",
    "modelscope==1.25.0",
    "mpi4py>=4.0.3",
    "mpmath==1.3.0",
    "msgpack==1.1.0",
    "multidict==6.4.3",
    "multiprocess==0.70.16",
    "narwhals==1.35.0",
    "networkx==3.4.2",
    "ngrok==1.4.0",
    "ninja==1.11.1.4",
    "nltk==3.8",
    "numpy==1.26.4",
    "nvidia-cublas-cu11==11.11.3.6",
    "nvidia-cublas-cu12==12.6.4.1",
    "nvidia-cuda-cupti-cu11==11.8.87",
    "nvidia-cuda-cupti-cu12==12.6.80",
    "nvidia-cuda-nvrtc-cu11==11.8.89",
    "nvidia-cuda-nvrtc-cu12==12.6.77",
    "nvidia-cuda-runtime-cu11==11.8.89",
    "nvidia-cuda-runtime-cu12==12.6.77",
    "nvidia-cudnn-cu11==9.1.0.70",
    "nvidia-cudnn-cu12==9.5.1.17",
    "nvidia-cufft-cu11==10.9.0.58",
    "nvidia-cufft-cu12==11.3.0.4",
    "nvidia-cufile-cu12==1.11.1.6",
    "nvidia-curand-cu11==10.3.0.86",
    "nvidia-curand-cu12==10.3.7.77",
    "nvidia-cusolver-cu11==11.4.1.48",
    "nvidia-cusolver-cu12==11.7.1.2",
    "nvidia-cusparse-cu11==11.7.5.86",
    "nvidia-cusparse-cu12==12.5.4.2",
    "nvidia-cusparselt-cu12==0.6.3",
    "nvidia-ml-py==12.575.51",
    "nvidia-nccl-cu11==2.21.5",
    "nvidia-nccl-cu12==2.26.2",
    "nvidia-nvjitlink-cu12==12.6.85",
    "nvidia-nvtx-cu11==11.8.86",
    "nvidia-nvtx-cu12==12.6.77",
    "openai==1.59.6",
    "packaging==23.2",
    "pandas>=2.0.0",
    "peft==0.7.1",
    "pillow==10.4.0",
    "platformdirs==4.3.7",
    "prettytable==3.16.0",
    "propcache==0.3.1",
    "protobuf==4.25.6",
    "psutil==5.9.8",
    "py-cpuinfo==9.0.0",
    "pyarrow==19.0.1",
    "pydantic==2.11.7",
    "pydantic_core==2.33.2",
    "pydeck==0.9.1",
    "pyecharts==2.0.8",
    "Pygments==2.19.1",
    "pynvml==12.0.0",
    "pyparsing==3.2.3",
    "python-dateutil==2.9.0.post0",
    "pytz==2025.2",
    "PyYAML==6.0.2",
    "referencing==0.36.2",
    "regex==2024.11.6",
    "requests==2.32.3",
    "rich==13.7.1",
    "rouge-score>=0.1.2",
    "rpds-py==0.24.0",
    "s3transfer==0.13.0",
    "safetensors==0.5.3",
    "scikit-learn==1.5.1",
    "scipy==1.15.2",
    "sentence-transformers==2.3.1",
    "sentencepiece==0.2.0",
    "sentry-sdk==2.26.1",
    "setproctitle==1.3.5",
    "simhash==2.1.2",
    "simplejson==3.20.1",
    "six==1.17.0",
    "smmap==5.0.2",
    "sniffio==1.3.1",
    "streamlit==1.30.0",
    "swankit==0.2.4",
    "swanlab==0.6.4",
    "sympy==1.13.3",
    "tenacity==8.5.0",
    "threadpoolctl==3.6.0",
    "tiktoken>=0.8.0",
    "tokenizers==0.21.1",
    "toml==0.10.2",
    "torch==2.7.1",
    "torchaudio==2.7.1",
    "torchvision==0.22.1",
    "tornado==6.4.2",
    "tqdm==4.67.1",
    "transformers==4.52.4",
    "triton==3.3.1",
    "trl==0.13.0",
    "typing-inspection==0.4.1",
    "typing_extensions==4.13.2",
    "tzlocal==5.3.1",
    "ujson==5.1.0",
    "urllib3==2.4.0",
    "validators==0.34.0",
    "wandb==0.18.3",
    "watchdog==6.0.0",
    "wcwidth==0.2.13",
    "Werkzeug==3.1.3",
    "wrapt==1.17.2",
    "xxhash==3.5.0",
    "yarl==1.20.0",
    "zipp==3.21.0",
]
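Assuming uv is installed, these pinned dependencies are resolved into a local .venv with `uv sync` (the same call startup.sh makes later in this diff); `uv run python ...` then executes against that environment.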
@ -1,4 +1,3 @@
accelerate==1.7.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.17
aiosignal==1.3.2
@ -8,8 +7,6 @@ anyio==4.9.0
async-timeout==5.0.1
attrs==25.3.0
blinker==1.9.0
boto3==1.38.41
botocore==1.38.41
cachetools==5.5.2
certifi==2025.1.31
charset-normalizer==3.4.1
@ -18,7 +15,6 @@ contourpy==1.3.2
cycler==0.12.1
datasets==2.21.0
datasketch==1.6.4
deepspeed==0.17.0
dill==0.3.8
distro==1.9.0
docker-pycreds==0.4.0
@ -37,19 +33,17 @@ hjson==3.1.0
httpcore==1.0.8
httpx==0.28.1
huggingface-hub==0.30.2
idna==3.10
importlib_metadata==7.2.1
itsdangerous==2.2.0
jieba==0.42.1
Jinja2==3.1.2
jiter==0.9.0
jmespath==1.0.1
joblib==1.4.2
jsonlines==4.0.0
jsonpointer==2.1
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
kiwisolver==1.4.8
langdetect==1.0.9
markdown-it-py==3.0.0
MarkupSafe==3.0.2
marshmallow==3.22.0
@ -66,50 +60,21 @@ ngrok==1.4.0
ninja==1.11.1.4
nltk==3.8
numpy==1.26.4
nvidia-cublas-cu11==11.11.3.6
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu11==9.1.0.70
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu11==10.9.0.58
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu11==10.3.0.86
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu11==11.7.5.86
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-ml-py==12.575.51
nvidia-nccl-cu11==2.21.5
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu11==11.8.86
nvidia-nvtx-cu12==12.6.77
openai==1.59.6
packaging==23.2
pandas==1.5.3
peft==0.7.1
pillow==10.4.0
platformdirs==4.3.7
prettytable==3.16.0
propcache==0.3.1
protobuf==4.25.6
psutil==5.9.8
py-cpuinfo==9.0.0
pyarrow==19.0.1
pydantic==2.11.7
pydantic_core==2.33.2
pydantic==2.8.2
pydantic_core==2.20.1
pydeck==0.9.1
pyecharts==2.0.8
Pygments==2.19.1
pynvml==12.0.0
pyparsing==3.2.3
python-dateutil==2.9.0.post0
pytz==2025.2
@ -119,7 +84,6 @@ regex==2024.11.6
requests==2.32.3
rich==13.7.1
rpds-py==0.24.0
s3transfer==0.13.0
safetensors==0.5.3
scikit-learn==1.5.1
scipy==1.15.2
@ -128,28 +92,21 @@ sentencepiece==0.2.0
sentry-sdk==2.26.1
setproctitle==1.3.5
simhash==2.1.2
simplejson==3.20.1
six==1.17.0
smmap==5.0.2
sniffio==1.3.1
streamlit==1.30.0
swankit==0.2.4
swanlab==0.6.4
sympy==1.13.3
tenacity==8.5.0
threadpoolctl==3.6.0
tiktoken==0.5.1
tokenizers==0.21.1
toml==0.10.2
torch==2.7.1
torchaudio==2.7.1
torchvision==0.22.1
tornado==6.4.2
tqdm==4.67.1
transformers==4.52.4
triton==3.3.1
transformers==4.48.0
triton==3.3.0
trl==0.13.0
typing-inspection==0.4.1
typing_extensions==4.13.2
tzlocal==5.3.1
ujson==5.1.0
@ -157,9 +114,7 @@ urllib3==2.4.0
validators==0.34.0
wandb==0.18.3
watchdog==6.0.0
wcwidth==0.2.13
Werkzeug==3.1.3
wrapt==1.17.2
xxhash==3.5.0
yarl==1.20.0
zipp==3.21.0
@ -2,7 +2,7 @@
# Activate the conda environment
# source $(conda info --base)/etc/profile.d/conda.sh
# conda activate mini
# conda activate ycz_accelerate

# Set environment variables to help with debugging
export NCCL_DEBUG=INFO
@ -26,27 +26,24 @@ export PYTHONFAULTHANDLER=1
# --profile_interval 10

# Method 2: configure accelerate directly from command-line arguments
# Memory-leak debugging configuration - reduce memory usage
CUDA_VISIBLE_DEVICES=0 uv run -p .venv python -m accelerate.commands.launch \
CUDA_VISIBLE_DEVICES=0 accelerate launch \
    --num_processes=1 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    train_pretrain_accelerate.py
    # --batch_size 128 \
    # --num_workers 1
    # --knowledge_num 48020 \
    # --num_workers 1 \
    # --epochs 4 \
    # --learning_rate 2e-4 \
    # --dtype bfloat16 \
    # --accumulation_steps 32 \
    # --grad_clip 1.0 \
    # --log_interval 50 \
    # --save_interval 10000 \
    # --dim 512 \
    # --n_layers 8 \
    # --max_seq_len 512 \
    # --use_flash_attn \
    # --profile \
    # --profile_interval 10

train_pretrain_accelerate.py \
    --epochs 3 \
    --batch_size 24 \
    --learning_rate 2e-4 \
    --dtype bfloat16 \
    --accumulation_steps 32 \
    --grad_clip 1.0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 512 \
    --n_layers 12 \
    --max_seq_len 512 \
    --use_flash_attn \
    --profile \
    --profile_interval 10 \
    --knowledge_num 4096 \
    --knowledge_length 8
@ -1,46 +0,0 @@
#!/bin/bash

# Activate the conda environment
source $(conda info --base)/etc/profile.d/conda.sh
conda activate ycz_accelerate

# Set environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

# Experiment 1.3.0 - configure accelerate directly from command-line arguments
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
    --multi_gpu \
    --num_processes=4 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    train_pretrain_accelerate.py \
    --out_dir "out" \
    --epochs 3 \
    --embedding_epoch 2 \
    --batch_size 64 \
    --learning_rate 8e-5 \
    --dtype bfloat16 \
    --use_swanlab \
    --swanlab_project "MiniMind-Pretrain" \
    --num_workers 1 \
    --accumulation_steps 16 \
    --grad_clip 0.5 \
    --warmup_iters 0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 1024 \
    --n_layers 48 \
    --max_seq_len 512 \
    --data_path "./dataset/stable/merged_pretrain.jsonl" \
    --profile \
    --profile_interval 10 \
    --use_flash_attn \
    --knowledge_num 1048576 \
    --knowledge_length 32 \
    --database_init_path "./dataset/stable/sentence_trex_data.json" \
    --fast_clustering \
    --cluster_cache_path "./cache/cluster_tokens_single.pt" \
    --memory_monitor_interval 10 \
    --model_type "model_original" \
    --model_size 538
@ -1,47 +0,0 @@
#!/bin/bash

# Activate the conda environment
source $(conda info --base)/etc/profile.d/conda.sh
conda activate ycz_accelerate

# Set environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

# Experiment 1.3.0 - configure accelerate directly from command-line arguments
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
    --multi_gpu \
    --num_processes=4 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    train_pretrain_accelerate.py \
    --out_dir "out" \
    --epochs 3 \
    --embedding_epoch 2 \
    --batch_size 48 \
    --learning_rate 2e-4 \
    --dtype bfloat16 \
    --use_swanlab \
    --swanlab_project "MiniMind-Pretrain" \
    --num_workers 1 \
    --accumulation_steps 32 \
    --grad_clip 1.0 \
    --warmup_iters 0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 1024 \
    --n_layers 18 \
    --max_seq_len 512 \
    --use_moe False \
    --data_path "./dataset/stable/merged_pretrain.jsonl" \
    --profile \
    --profile_interval 10 \
    --use_flash_attn \
    --knowledge_num 1048576 \
    --knowledge_length 32 \
    --database_init_path "./dataset/stable/sentence_trex_data.json" \
    --fast_clustering \
    --cluster_cache_path "./cache/cluster_tokens_single.pt" \
    --memory_monitor_interval 10 \
    --model_type "model_no_feed" \
    --model_size 814.724
@ -1,47 +0,0 @@
#!/bin/bash

# Activate the conda environment
source $(conda info --base)/etc/profile.d/conda.sh
conda activate ycz_accelerate

# Set environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

# Experiment 1.3.0 - configure accelerate directly from command-line arguments
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
    --multi_gpu \
    --num_processes=4 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    train_pretrain_accelerate.py \
    --out_dir "out" \
    --epochs 3 \
    --embedding_epoch 2 \
    --batch_size 48 \
    --learning_rate 2e-4 \
    --dtype bfloat16 \
    --use_swanlab \
    --swanlab_project "MiniMind-Pretrain" \
    --num_workers 1 \
    --accumulation_steps 32 \
    --grad_clip 1.0 \
    --warmup_iters 0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 1024 \
    --n_layers 18 \
    --max_seq_len 512 \
    --use_moe False \
    --data_path "./dataset/stable/merged_pretrain.jsonl" \
    --profile \
    --profile_interval 10 \
    --use_flash_attn \
    --knowledge_num 1048576 \
    --knowledge_length 32 \
    --database_init_path "./dataset/stable/sentence_trex_data.json" \
    --fast_clustering \
    --cluster_cache_path "./cache/cluster_tokens_single.pt" \
    --memory_monitor_interval 10 \
    --model_type "model" \
    --model_size 814.724
@ -1,45 +0,0 @@
#!/bin/bash

# Activate the conda environment
# source $(conda info --base)/etc/profile.d/conda.sh
# conda activate ycz_accelerate

# Set environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

# Experiment 1.3.0 - configure accelerate directly from command-line arguments
CUDA_VISIBLE_DEVICES=0 accelerate launch \
    --num_processes=1 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    train_pretrain_accelerate.py \
    --out_dir "out" \
    --epochs 3 \
    --embedding_epoch 2 \
    --batch_size 128 \
    --learning_rate 8e-5 \
    --dtype bfloat16 \
    --use_swanlab \
    --swanlab_project "MiniMind-Pretrain" \
    --num_workers 1 \
    --accumulation_steps 16 \
    --grad_clip 0.5 \
    --warmup_iters 0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 512 \
    --n_layers 8 \
    --max_seq_len 512 \
    --data_path "./dataset/stable/merged_pretrain.jsonl" \
    --profile \
    --profile_interval 10 \
    --use_flash_attn \
    --knowledge_num 1048576 \
    --knowledge_length 32 \
    --database_init_path "./dataset/stable/sentence_trex_data.json" \
    --fast_clustering \
    --cluster_cache_path "./cache/cluster_tokens_single.pt" \
    --memory_monitor_interval 10 \
    --model_type "model" \
    --model_size 538
@ -32,7 +32,7 @@ def train_tokenizer():
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

    # Define the special tokens
    special_tokens = ["<unk>", "<|im_start|>", "<|im_end|>"]
    special_tokens = ["<unk>", "<s>", "</s>"]

    # Set up the trainer and add the special tokens
    trainer = trainers.BpeTrainer(
@ -53,8 +53,8 @@ def train_tokenizer():

    # Check the indices of the special tokens
    assert tokenizer.token_to_id("<unk>") == 0
    assert tokenizer.token_to_id("<|im_start|>") == 1
    assert tokenizer.token_to_id("<|im_end|>") == 2
    assert tokenizer.token_to_id("<s>") == 1
    assert tokenizer.token_to_id("</s>") == 2

    # Save the tokenizer
    tokenizer_dir = "../model/minimind_tokenizer"
@ -77,7 +77,7 @@ def train_tokenizer():
            "special": True
        },
        "1": {
            "content": "<|im_start|>",
            "content": "<s>",
            "lstrip": False,
            "normalized": False,
            "rstrip": False,
@ -85,7 +85,7 @@ def train_tokenizer():
            "special": True
        },
        "2": {
            "content": "<|im_end|>",
            "content": "</s>",
            "lstrip": False,
            "normalized": False,
            "rstrip": False,
@ -94,9 +94,9 @@ def train_tokenizer():
        }
    },
    "additional_special_tokens": [],
    "bos_token": "<|im_start|>",
    "bos_token": "<s>",
    "clean_up_tokenization_spaces": False,
    "eos_token": "<|im_end|>",
    "eos_token": "</s>",
    "legacy": True,
    "model_max_length": 32768,
    "pad_token": "<unk>",
@ -104,7 +104,7 @@ def train_tokenizer():
    "spaces_between_special_tokens": False,
    "tokenizer_class": "PreTrainedTokenizerFast",
    "unk_token": "<unk>",
    "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\n你是 MiniMind,是一个有用的人工智能助手。<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"
    "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<s>system\\n' + system_message + '</s>\\n' }}{% else %}{{ '<s>system\\n你是 MiniMind,是一个有用的人工智能助手。</s>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
}

    # Save the configuration file
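A minimal sketch of how the <|im_start|> chat template above renders once the tokenizer is saved (assumes the tokenizer config has been written to the tokenizer_dir used above; apply_chat_template is the standard transformers API):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("../model/minimind_tokenizer")
messages = [{"role": "user", "content": "Hello"}]
# With this template, the output is the default system block, the user block,
# and a trailing '<|im_start|>assistant\n' generation prefix:
text = tokenizer.apply_chat_template(messages, tokenize=False)
print(text)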
33
startup.sh
@ -1,33 +0,0 @@
#!/bin/bash
set -e

# After the container starts, first install all dependencies from requirements.txt
# pip install -r requirements.txt

# bash install.sh -y
python3 -m pip install --upgrade pip
pip install uv -i https://pypi.tuna.tsinghua.edu.cn/simple

# Switch to the project directory
cd /ycz/Minimind

# Check and repair the virtual environment
if [ ! -f .venv/bin/python ] || [ ! -x .venv/bin/python ]; then
    echo "Virtual environment is broken or missing, recreating with uv..."
    rm -rf .venv
    uv venv .venv
fi

# Do not activate the virtual environment manually; let uv manage it
# . ./.venv/bin/activate

# Sync the dependencies with uv
uv sync

# After installation, run the main training script.
# "$@" forwards the arguments defined by the entrypoint in experiment.yaml to the Python script.
CUDA_VISIBLE_DEVICES=0 uv run python -m accelerate.commands.launch \
    --num_processes=1 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    train_pretrain_accelerate.py "$@"
97
test_real_rope.py
Normal file
@ -0,0 +1,97 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Tests for the real-valued version of the positional encoding.
"""

import torch
from model.model import precompute_pos_cis, precompute_pos_cis_real, apply_rotary_emb, apply_rotary_emb_real
from model.LMConfig import LMConfig
from model.model import MiniMindLM


def test_pos_encoding_equivalence():
    """Check that the complex and real versions of the positional encoding are equivalent."""
    print("Testing positional encoding equivalence...")

    # Parameters
    dim = 64
    seq_len = 10

    # Complex-valued positional encoding
    pos_cis = precompute_pos_cis(dim=dim, end=seq_len)

    # Real-valued positional encoding
    pos_cis_real = precompute_pos_cis_real(dim=dim, end=seq_len)

    # Random queries and keys
    batch_size = 2
    n_heads = 4
    head_dim = dim

    xq = torch.randn(batch_size, seq_len, n_heads, head_dim)
    xk = torch.randn(batch_size, seq_len, n_heads, head_dim)

    # Apply the complex-valued rotary positional encoding
    xq_complex, xk_complex = apply_rotary_emb(xq, xk, pos_cis)

    # Apply the real-valued rotary positional encoding
    xq_real, xk_real = apply_rotary_emb_real(xq, xk, pos_cis_real)

    # Measure the differences
    q_diff = torch.abs(xq_complex - xq_real).mean().item()
    k_diff = torch.abs(xk_complex - xk_real).mean().item()

    print(f"Query difference: {q_diff:.6f}")
    print(f"Key difference: {k_diff:.6f}")

    # Check whether the differences are within tolerance
    tolerance = 1e-5
    if q_diff < tolerance and k_diff < tolerance:
        print("✅ Test passed: the complex and real positional encodings are numerically equivalent")
    else:
        print("❌ Test failed: the complex and real positional encodings differ significantly")


def test_model_forward():
    """Test the model forward pass."""
    print("\nTesting model forward pass...")

    # Model configuration
    config = LMConfig(
        dim=128,
        n_layers=2,
        n_heads=4,
        n_kv_heads=4,  # make sure n_kv_heads is set and n_heads is divisible by it
        vocab_size=1000,
        max_seq_len=128,
        disable_db=True  # disable the database feature to avoid extra complexity
    )

    # Build the model
    try:
        model = MiniMindLM(config)
        print(f"✅ Model initialized successfully")
    except Exception as e:
        print(f"❌ Model initialization failed: {str(e)}")
        return

    # Build the input
    batch_size = 2
    seq_len = 10
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

    # Forward pass
    try:
        with torch.no_grad():
            outputs = model(input_ids)
        print(f"✅ Model forward pass succeeded")
        print(f"Output shape: {outputs.logits.shape}")
    except Exception as e:
        print(f"❌ Model forward pass failed: {str(e)}")


if __name__ == "__main__":
    # Test positional encoding equivalence
    test_pos_encoding_equivalence()

    # Test the model forward pass
    test_model_forward()
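The equivalence the first test checks is the standard rotary-embedding identity: multiplying the pair (x1, x2), viewed as the complex number x1 + i*x2, by cos θ + i*sin θ gives the same result as applying a 2x2 rotation matrix to the pair. A self-contained sketch of that identity (independent of model.model):

import torch

theta = torch.tensor(0.3)
x1, x2 = torch.tensor(1.0), torch.tensor(2.0)

# Complex form: (x1 + i*x2) * (cos θ + i*sin θ)
z = torch.complex(x1, x2) * torch.complex(torch.cos(theta), torch.sin(theta))

# Real form: 2x2 rotation applied to (x1, x2)
y1 = x1 * torch.cos(theta) - x2 * torch.sin(theta)
y2 = x1 * torch.sin(theta) + x2 * torch.cos(theta)

assert torch.allclose(z.real, y1) and torch.allclose(z.imag, y2)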
File diff suppressed because it is too large
@ -1,9 +1,8 @@
import os
# Set environment variables - replace wandb with SwanLab
# os.environ["SWANLAB_MODE"] = "online"  # SwanLab uses online mode
# Set environment variables
os.environ["WANDB_MODE"] = "offline"  # or use "dryrun"
import platform
import argparse
from tqdm import tqdm
import time
import math
import warnings
@ -19,75 +18,13 @@ from accelerate.utils import set_seed
from accelerate.utils import DeepSpeedPlugin
from accelerate.utils import DistributedDataParallelKwargs
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import swanlab  # replaces the wandb import
import gc  # garbage-collection module
import psutil  # system resource monitoring module


from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import PretrainDataset

warnings.filterwarnings('ignore')

# Memory-monitoring helpers
def get_memory_usage():
    """Return the current memory usage of this process."""
    process = psutil.Process()
    memory_info = process.memory_info()
    return {
        'rss_mb': memory_info.rss / 1024 / 1024,  # resident set size (MB)
        'vms_mb': memory_info.vms / 1024 / 1024,  # virtual memory size (MB)
    }

def get_cuda_memory_usage():
    """Return the CUDA memory usage."""
    if torch.cuda.is_available():
        return {
            'cuda_allocated_mb': torch.cuda.memory_allocated() / 1024 / 1024,
            'cuda_reserved_mb': torch.cuda.memory_reserved() / 1024 / 1024,
            'cuda_max_allocated_mb': torch.cuda.max_memory_allocated() / 1024 / 1024,
        }
    return {}

def get_tensor_memory_size(tensor_list):
    """Compute the total memory footprint of a list of tensors (MB)."""
    total_size = 0
    for batch in tensor_list:
        if isinstance(batch, (list, tuple)):
            for tensor in batch:
                if isinstance(tensor, torch.Tensor):
                    total_size += tensor.numel() * tensor.element_size()
        elif isinstance(batch, torch.Tensor):
            total_size += batch.numel() * batch.element_size()
    return total_size / 1024 / 1024  # convert to MB

def log_memory_status(step, prefetch_batches, accelerator, stage="", detailed=False):
    """Log the current memory status."""
    if not accelerator.is_main_process:
        return

    memory_info = get_memory_usage()
    cuda_info = get_cuda_memory_usage()
    prefetch_memory = get_tensor_memory_size(prefetch_batches)

    log_msg = f"[Memory Monitor] Step {step} {stage} - "
    log_msg += f"Prefetch batches: {len(prefetch_batches)}, "
    log_msg += f"Prefetch memory: {prefetch_memory:.2f}MB, "
    log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB"

    if cuda_info:
        log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
        log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB"

    if detailed:
        log_msg += f", System VMS: {memory_info['vms_mb']:.2f}MB"
        if cuda_info:
            log_msg += f", CUDA max allocated: {cuda_info['cuda_max_allocated_mb']:.2f}MB"

    Logger(log_msg, accelerator)

# Logging helper
def Logger(msg, accelerator=None):
    # If no accelerator is given, print only on the main process
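A call site for log_memory_status above is expected to look roughly like this inside the training loop (step and prefetch_batches are hypothetical local names; memory_monitor_interval matches the CLI flag used in the run scripts):

if step % args.memory_monitor_interval == 0:
    log_memory_status(step, prefetch_batches, accelerator, stage="after_forward")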
@ -104,420 +41,27 @@ def get_lr(it, num_iters, learning_rate):
|
||||
return learning_rate * 0.5 * (1.0 + math.cos(math.pi * it / num_iters))
|
||||
|
||||
# 初始化模型函数
|
||||
def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None):
|
||||
if args.model_type == "model":
|
||||
Logger(f"Using model type: {args.model_type}")
|
||||
from model.model import MiniMindLM, RMSNorm
|
||||
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
|
||||
model = MiniMindLM(lm_config)
|
||||
|
||||
# 默认模型初始化
|
||||
Logger("Performing default model initialization...")
|
||||
|
||||
# 初始化嵌入层权重
|
||||
nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02)
|
||||
|
||||
# 初始化输出层权重(如果不共享权重的话)
|
||||
if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight:
|
||||
nn.init.normal_(model.output.weight, mean=0.0, std=0.02)
|
||||
|
||||
# 初始化所有线性层
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
# 使用Xavier/Glorot初始化
|
||||
nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
# 嵌入层使用正态分布初始化
|
||||
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
elif isinstance(module, RMSNorm):
|
||||
# RMSNorm的权重初始化为1
|
||||
if hasattr(module, 'weight'):
|
||||
nn.init.ones_(module.weight)
|
||||
|
||||
# 初始化位置编码相关参数
|
||||
if hasattr(model.knowledge_dataset, 'keys'):
|
||||
nn.init.normal_(model.knowledge_dataset.keys, mean=0.0, std=0.02)
|
||||
|
||||
Logger("Default model initialization completed")
def init_model(lm_config, pretrained_embedding_path=None):
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
model = MiniMindLM(lm_config)

# If pretrained embedding weights were provided, load them
if pretrained_embedding_path:
Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
pretrained_embeddings = torch.load(pretrained_embedding_path)
model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
model.output.weight.data.copy_(pretrained_embeddings)  # tied weights

if database_init_path:
import json
import os

# Database parameters
knowledge_num = args.knowledge_num
knowledge_length = args.knowledge_length

# Check whether a cache can be used
cache_dir = os.path.dirname(args.cluster_cache_path)
if cache_dir:
os.makedirs(cache_dir, exist_ok=True)

processed_tensor = None

# Try to load cached processing results
if not args.recompute_clusters and os.path.exists(args.cluster_cache_path):
try:
Logger(f"Loading cached processed results from {args.cluster_cache_path}")
processed_tensor = torch.load(args.cluster_cache_path)

# Check whether the shape of the cached file is usable
cached_knowledge_num, cached_knowledge_length = processed_tensor.shape

if cached_knowledge_length == knowledge_length:
if cached_knowledge_num >= knowledge_num:
# The cache is large enough; truncate it for use
processed_tensor = processed_tensor[:knowledge_num, :]
Logger(f"Successfully loaded cached data with shape {processed_tensor.shape}")
Logger(f"Truncated from cached shape ({cached_knowledge_num}, {cached_knowledge_length}) to required shape ({knowledge_num}, {knowledge_length})")
Logger("Skipping database initialization - using cached results")
else:
# The cache is too small; recompute
Logger(f"Cached knowledge_num ({cached_knowledge_num}) < required knowledge_num ({knowledge_num}), recomputing...")
processed_tensor = None
else:
# knowledge_length does not match; recompute
Logger(f"Cached knowledge_length ({cached_knowledge_length}) != required knowledge_length ({knowledge_length}), recomputing...")
processed_tensor = None
except Exception as e:
Logger(f"Failed to load cached data: {e}, recomputing...")
processed_tensor = None
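# Editor's note (added commentary): the cache contract implied above is
# "reusable iff cached_knowledge_length == knowledge_length and
# cached_knowledge_num >= knowledge_num"; rows are truncated, never padded,
# so a cache built for a larger knowledge base also serves smaller runs.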

# Run database initialization and processing only when there is no valid cache
if processed_tensor is None:
Logger(f"Loading database initialization data from {database_init_path}")

# 1. Load the JSON file
with open(database_init_path, 'r', encoding='utf-8') as f:
database_data = json.load(f)

sentences_data = []
for data in database_data:
sentences_data.append(data['target'][0]['sentence'])

# Extract the list of sentences
# sentences_data = database_data.get('sentences', [])
Logger(f"Loaded {len(sentences_data)} sentences from database")

# 2. Sort by importance_score (highest first)
try:
sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")
except:
sorted_sentences = sentences_data
# 3. Process each entry individually, without clustering
Logger("Processing individual sentences...")
processed_rows = []

# Get the pad token id (used for padding)
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

# Process the required number of sentences
num_to_process = min(knowledge_num, len(sorted_sentences))

# Truncation statistics
total_sentences = 0
truncated_sentences = 0

for i in range(num_to_process):
sentence_data = sorted_sentences[i]
try:
sentence = sentence_data.get('corrected_sentence')
except:
sentence = sentence_data

# Convert the sentence to tokens
sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)

# Truncate or pad to knowledge_length
total_sentences += 1
if len(sentence_tokens) > knowledge_length:
# Too long: truncate
truncated_sentences += 1
sentence_tokens = sentence_tokens[:knowledge_length]
Logger(f"Sentence {i+1} truncated from {len(tokenizer.encode(sentence, add_special_tokens=False))} to {knowledge_length} tokens")
else:
# Too short: pad with the pad token
original_length = len(sentence_tokens)
sentence_tokens.extend([pad_token_id] * (knowledge_length - len(sentence_tokens)))
if original_length < knowledge_length:
Logger(f"Sentence {i+1} padded from {original_length} to {knowledge_length} tokens")

processed_rows.append(sentence_tokens)

if (i + 1) % 1000 == 0:
Logger(f"Processed {i + 1}/{num_to_process} sentences")

# If there are not enough sentences, fill the remaining slots with pad tokens
while len(processed_rows) < knowledge_num:
empty_tokens = [pad_token_id] * knowledge_length
processed_rows.append(empty_tokens)
if len(processed_rows) % 1000 == 0:
Logger(f"Added empty entry {len(processed_rows)}/{knowledge_num}")

Logger(f"Finished adding empty entries. Total: {len(processed_rows)}/{knowledge_num}")

# Convert to a tensor
processed_tensor = torch.tensor(processed_rows, dtype=torch.long)

# Compute and log the share of truncated sentences
truncation_ratio = truncated_sentences / total_sentences if total_sentences > 0 else 0.0
Logger(f"Truncation statistics:")
Logger(f" - Total sentences: {total_sentences}")
Logger(f" - Truncated sentences: {truncated_sentences}")
Logger(f" - Truncation ratio: {truncation_ratio:.4f} ({truncation_ratio*100:.2f}%)")

Logger(f"Data processing completed:")
Logger(f" - Processed {num_to_process} sentences")
Logger(f" - Added {knowledge_num - num_to_process} empty entries")
Logger(f" - Final shape: {processed_tensor.shape}")
Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")

# Save the processed results to the cache file
try:
torch.save(processed_tensor, args.cluster_cache_path)
Logger(f"Processed results saved to {args.cluster_cache_path}")
except Exception as e:
Logger(f"Failed to save processed results: {e}")

# 4. Initialize the model's knowledge_dataset
if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
model.knowledge_dataset.knowledge_dataset.data.copy_(processed_tensor)
Logger("Successfully initialized model.knowledge_dataset.knowledge_dataset with processed data")
else:
Logger("Warning: Could not find model.knowledge_dataset.knowledge_dataset to initialize")
# Store it in a global variable as a fallback
globals()['processed_database'] = processed_tensor

Logger(f"Database embeddings and sentences stored in model")

Logger(f'Total LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
elif args.model_type == "model_original":
Logger(f"Using model type: {args.model_type}")
from model.model_original import MiniMindLM, RMSNorm
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
model = MiniMindLM(lm_config)
Logger(f'Total LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
elif args.model_type == "model_no_feed":
Logger(f"Using model type: {args.model_type}")
from model.model_no_feed import MiniMindLM, RMSNorm
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
model = MiniMindLM(lm_config)

# Default model initialization
Logger("Performing default model initialization...")

# Initialize the embedding weights
nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02)

# Initialize the output weights (only if they are not tied to the embeddings)
if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight:
nn.init.normal_(model.output.weight, mean=0.0, std=0.02)

# Initialize all linear layers
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
# Use Xavier/Glorot initialization
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
# Embedding layers use normal-distribution initialization
nn.init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, RMSNorm):
# RMSNorm weights are initialized to 1
if hasattr(module, 'weight'):
nn.init.ones_(module.weight)

# Initialize the knowledge-base key vectors
if hasattr(model.knowledge_dataset, 'keys'):
nn.init.normal_(model.knowledge_dataset.keys, mean=0.0, std=0.02)

Logger("Default model initialization completed")

# If pretrained embedding weights were provided, load them
if pretrained_embedding_path:
Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
pretrained_embeddings = torch.load(pretrained_embedding_path)
model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
model.output.weight.data.copy_(pretrained_embeddings)  # tied weights

if database_init_path:
import json
import os

# Database parameters
knowledge_num = args.knowledge_num
knowledge_length = args.knowledge_length

# Check whether a cache can be used
cache_dir = os.path.dirname(args.cluster_cache_path)
if cache_dir:
os.makedirs(cache_dir, exist_ok=True)

processed_tensor = None

# Try to load cached processing results
if not args.recompute_clusters and os.path.exists(args.cluster_cache_path):
try:
Logger(f"Loading cached processed results from {args.cluster_cache_path}")
processed_tensor = torch.load(args.cluster_cache_path)

# Check whether the shape of the cached file is usable
cached_knowledge_num, cached_knowledge_length = processed_tensor.shape

if cached_knowledge_length == knowledge_length:
if cached_knowledge_num >= knowledge_num:
# The cache is large enough; truncate it for use
processed_tensor = processed_tensor[:knowledge_num, :]
Logger(f"Successfully loaded cached data with shape {processed_tensor.shape}")
Logger(f"Truncated from cached shape ({cached_knowledge_num}, {cached_knowledge_length}) to required shape ({knowledge_num}, {knowledge_length})")
Logger("Skipping database initialization - using cached results")
else:
# The cache is too small; recompute
Logger(f"Cached knowledge_num ({cached_knowledge_num}) < required knowledge_num ({knowledge_num}), recomputing...")
processed_tensor = None
else:
# knowledge_length does not match; recompute
Logger(f"Cached knowledge_length ({cached_knowledge_length}) != required knowledge_length ({knowledge_length}), recomputing...")
processed_tensor = None
except Exception as e:
Logger(f"Failed to load cached data: {e}, recomputing...")
processed_tensor = None

# Run database initialization and processing only when there is no valid cache
if processed_tensor is None:
Logger(f"Loading database initialization data from {database_init_path}")

# 1. Load the JSON file
with open(database_init_path, 'r', encoding='utf-8') as f:
database_data = json.load(f)

sentences_data = []
for data in database_data:
sentences_data.append(data['target'][0]['sentence'])

# Extract the list of sentences
# sentences_data = database_data.get('sentences', [])
Logger(f"Loaded {len(sentences_data)} sentences from database")

# 2. Sort by importance_score (highest first)
try:
sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")
except:
sorted_sentences = sentences_data
# 3. Process each entry individually, without clustering
Logger("Processing individual sentences...")
processed_rows = []

# Get the pad token id (used for padding)
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

# Process the required number of sentences
num_to_process = min(knowledge_num, len(sorted_sentences))

# Truncation statistics
total_sentences = 0
truncated_sentences = 0

for i in range(num_to_process):
sentence_data = sorted_sentences[i]
try:
sentence = sentence_data.get('corrected_sentence')
except:
sentence = sentence_data

# Convert the sentence to tokens
sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)

# Truncate or pad to knowledge_length
total_sentences += 1
if len(sentence_tokens) > knowledge_length:
# Too long: truncate
truncated_sentences += 1
sentence_tokens = sentence_tokens[:knowledge_length]
Logger(f"Sentence {i+1} truncated from {len(tokenizer.encode(sentence, add_special_tokens=False))} to {knowledge_length} tokens")
else:
# Too short: pad with the pad token
original_length = len(sentence_tokens)
sentence_tokens.extend([pad_token_id] * (knowledge_length - len(sentence_tokens)))
if original_length < knowledge_length:
Logger(f"Sentence {i+1} padded from {original_length} to {knowledge_length} tokens")

processed_rows.append(sentence_tokens)

if (i + 1) % 1000 == 0:
Logger(f"Processed {i + 1}/{num_to_process} sentences")

# If there are not enough sentences, fill the remaining slots with pad tokens
while len(processed_rows) < knowledge_num:
empty_tokens = [pad_token_id] * knowledge_length
processed_rows.append(empty_tokens)
if len(processed_rows) % 1000 == 0:
Logger(f"Added empty entry {len(processed_rows)}/{knowledge_num}")

Logger(f"Finished adding empty entries. Total: {len(processed_rows)}/{knowledge_num}")

# Convert to a tensor
processed_tensor = torch.tensor(processed_rows, dtype=torch.long)

# Compute and log the share of truncated sentences
truncation_ratio = truncated_sentences / total_sentences if total_sentences > 0 else 0.0
Logger(f"Truncation statistics:")
Logger(f" - Total sentences: {total_sentences}")
Logger(f" - Truncated sentences: {truncated_sentences}")
Logger(f" - Truncation ratio: {truncation_ratio:.4f} ({truncation_ratio*100:.2f}%)")

Logger(f"Data processing completed:")
Logger(f" - Processed {num_to_process} sentences")
Logger(f" - Added {knowledge_num - num_to_process} empty entries")
Logger(f" - Final shape: {processed_tensor.shape}")
Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")

# Save the processed results to the cache file
try:
torch.save(processed_tensor, args.cluster_cache_path)
Logger(f"Processed results saved to {args.cluster_cache_path}")
except Exception as e:
Logger(f"Failed to save processed results: {e}")

# 4. Initialize the model's knowledge_dataset
if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
model.knowledge_dataset.knowledge_dataset.data.copy_(processed_tensor)
Logger("Successfully initialized model.knowledge_dataset.knowledge_dataset with processed data")
else:
Logger("Warning: Could not find model.knowledge_dataset.knowledge_dataset to initialize")
# Store it in a global variable as a fallback
globals()['processed_database'] = processed_tensor

Logger(f"Database embeddings and sentences stored in model")

Logger(f'Total LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
# If pretrained embedding weights were provided, load them
if pretrained_embedding_path:
Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
pretrained_embeddings = torch.load(pretrained_embedding_path)
model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
model.output.weight.data.copy_(pretrained_embeddings)  # tied weights

Logger(f'Total LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
return model, tokenizer
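# Editor's note (hypothetical usage; LMConfig and its field names are assumptions
# inferred from the surrounding diff, not a confirmed API):
#   lm_config = LMConfig(dim=512, n_layers=8, max_seq_len=512,
#                        knowledge_num=1048576, knowledge_length=32)
#   model, tokenizer = init_model(lm_config, None,
#                                 "./dataset/stable/sentence_trex_data.json", args)
# The tokenizer always comes from ./model/minimind_tokenizer, and the knowledge base
# is seeded from the JSON file (or its cached tensor) before training starts.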

def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer):
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb):
loss_fct = nn.CrossEntropyLoss(reduction='none')
epoch_start_time = time.time()
total_steps_in_epoch = len(train_loader)
total_training_steps = args.epochs * total_steps_in_epoch
moe_path = '_moe' if args.use_moe else ''
best_loss = 10000.0  # sentinel "best loss so far"

# Initialize the CUDA event variables
data_start = data_end = forward_start = forward_end = None
backward_start = backward_end = optimizer_start = optimizer_end = None

# Create CUDA events for performance profiling (main process only)
if args.profile and accelerator.is_main_process:
data_start = torch.cuda.Event(enable_timing=True)
@@ -530,67 +74,44 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
optimizer_end = torch.cuda.Event(enable_timing=True)
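# Editor's note (background on the API, not in the original): CUDA events created
# with enable_timing=True are recorded around a region via start.record() / end.record();
# start.elapsed_time(end) then returns milliseconds, but only once both events have
# completed -- hence the torch.cuda.synchronize() before the reads further down.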

# Prefetch data
prefetch_factor = 8  # number of batches to prefetch
prefetch_factor = 2  # number of batches to prefetch
data_iter = iter(train_loader)
prefetch_batches = []

# Log the initial memory state
if args.memory_monitor:
log_memory_status(-1, prefetch_batches, accelerator, "before_prefetch", detailed=True)

# Prefetch the initial batches
for i in range(min(prefetch_factor, len(train_loader))):
for _ in range(min(prefetch_factor, len(train_loader))):
try:
batch = next(data_iter)
prefetch_batches.append(batch)
# Log the memory change after each batch is added
if args.memory_monitor and accelerator.is_main_process:
log_memory_status(-1, prefetch_batches, accelerator, f"after_adding_batch_{i+1}")
except StopIteration:
break

# Log the memory state once the initial prefetch is done
if args.memory_monitor:
log_memory_status(-1, prefetch_batches, accelerator, "after_initial_prefetch", detailed=True)

# Initialize the variables needed for logging before entering the loop
last_log_time = epoch_start_time
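# Editor's note (observation on the design, not in the original): this is an eager,
# same-thread prefetch queue -- it holds up to prefetch_factor batches in host memory
# so a slow fetch does not stall the step, while the DataLoader's own num_workers
# still does the actual parallel loading.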

for step in range(total_steps_in_epoch):
try:
# Time data loading (main process only)
if args.profile and accelerator.is_main_process and data_start is not None:
if args.profile and accelerator.is_main_process:
data_start.record()

# Log the memory state before consuming a batch (detailed at the configured interval)
if args.memory_monitor and step % args.memory_monitor_interval == 0:
log_memory_status(step, prefetch_batches, accelerator, "before_use_batch", detailed=True)

# Consume prefetched data
if prefetch_batches:
X, Y, loss_mask = prefetch_batches.pop(0)
# Log the memory change after popping a batch
if args.memory_monitor and step % args.memory_monitor_interval == 0:
log_memory_status(step, prefetch_batches, accelerator, "after_pop_batch")
else:
# If the prefetch queue is empty, load directly
X, Y, loss_mask = next(data_iter)
if args.memory_monitor and accelerator.is_main_process:
Logger(f"[Memory Monitor] Step {step} - Prefetch queue empty, loading directly!", accelerator)

# Prefetch the next batch ahead of time
if step + prefetch_factor < len(train_loader):
try:
batch = next(data_iter)
prefetch_batches.append(batch)
# Log the memory change after adding the new batch
if args.memory_monitor and step % args.memory_monitor_interval == 0:
log_memory_status(step, prefetch_batches, accelerator, "after_add_batch")
except StopIteration:
pass

# End data-loading timing (main process only)
if args.profile and accelerator.is_main_process and data_end is not None:
if args.profile and accelerator.is_main_process:
data_end.record()

# Update the learning rate
@@ -598,31 +119,33 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
scheduler.step()

# Time the forward pass (main process only)
if args.profile and accelerator.is_main_process and forward_start is not None:
if args.profile and accelerator.is_main_process:
forward_start.record()

# Forward pass
with ctx:
if step == 0 and args.embedding_epoch == epoch:
# freeze_embedding must be set on the original model, not the wrapped one
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.freeze_embedding = True
Logger(f"Set freeze_embedding=True for epoch {epoch}, step {step}", accelerator)
res = model(X, step=step)
res = model(X)
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
).view(Y.size())
loss = (loss * loss_mask).sum() / loss_mask.sum()
# Auxiliary-loss computation removed; aux_loss is no longer used anywhere
# Add the auxiliary loss, if present
try:
aux_loss = sum(l.feed_forward.aux_loss for l in model.module.layers
if hasattr(l.feed_forward, 'aux_loss'))
loss += aux_loss
except Exception as e:
Logger(f"Warning: Could not add auxiliary loss: {e}")
# On error, skip the auxiliary loss
loss = loss / args.accumulation_steps
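# Editor's note (illustration, not in the original): with reduction='none' the loss
# is per token; multiplying by loss_mask zeroes the padding positions and dividing
# by loss_mask.sum() averages over real tokens only. E.g. per-token losses
# [2.0, 4.0, 6.0] with mask [1, 1, 0] give (2.0 + 4.0) / 2 = 3.0.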

# End forward-pass timing (main process only)
if args.profile and accelerator.is_main_process and forward_end is not None:
if args.profile and accelerator.is_main_process:
forward_end.record()

# Time the backward pass (main process only)
if args.profile and accelerator.is_main_process and backward_start is not None:
if args.profile and accelerator.is_main_process:
backward_start.record()

# Backward pass
@@ -630,11 +153,11 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
accelerator.backward(loss)

# End backward-pass timing (main process only)
if args.profile and accelerator.is_main_process and backward_end is not None:
if args.profile and accelerator.is_main_process:
backward_end.record()

# Time the optimizer step (main process only)
if args.profile and accelerator.is_main_process and optimizer_start is not None:
if args.profile and accelerator.is_main_process:
optimizer_start.record()

# Optimizer step - with DeepSpeed, gradient accumulation and gradient clipping are handled automatically
@@ -647,111 +170,40 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
optimizer.zero_grad()

# End optimizer-step timing (main process only)
if args.profile and accelerator.is_main_process and optimizer_end is not None:
if args.profile and accelerator.is_main_process:
optimizer_end.record()

# Log training info (main process only)
if (step + 1) % args.log_interval == 0 and accelerator.is_main_process:
current_time = time.time()

# Log a detailed memory snapshot at log time
if args.memory_monitor:
log_memory_status(step, prefetch_batches, accelerator, "at_log_interval", detailed=True)

# Force garbage collection and log the memory change
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
log_memory_status(step, prefetch_batches, accelerator, "after_gc", detailed=True)

# Compute performance metrics
if args.profile and accelerator.is_main_process:
if args.profile:
torch.cuda.synchronize()

# Compute elapsed_time only once all events have been recorded
try:
data_time = data_start.elapsed_time(data_end) if data_start is not None and data_end is not None else 0
forward_time = forward_start.elapsed_time(forward_end) if forward_start is not None and forward_end is not None else 0
backward_time = backward_start.elapsed_time(backward_end) if backward_start is not None and backward_end is not None else 0
optimizer_time = optimizer_start.elapsed_time(optimizer_end) if optimizer_start is not None and optimizer_end is not None else 0
iter_time = (current_time - last_log_time) * 1000 / args.log_interval  # avg ms per iteration since last log
# total_time_ms = data_time + forward_time + backward_time + optimizer_time
# Use the time since the last log for the metrics, not the total time
data_time = data_start.elapsed_time(data_end)
forward_time = forward_start.elapsed_time(forward_end)
backward_time = backward_start.elapsed_time(backward_end)
optimizer_time = optimizer_start.elapsed_time(optimizer_end)
iter_time = (current_time - last_log_time) * 1000 / args.log_interval  # avg ms per iteration since last log
# total_time_ms = data_time + forward_time + backward_time + optimizer_time

# Log the profiling breakdown
if (step + 1) % (args.log_interval * args.profile_interval) == 0:
Logger(f"Profiling (avg/iter over last {args.log_interval} steps) - "
f"Data: {data_time/args.log_interval:.2f}ms, "
f"Fwd: {forward_time/args.log_interval:.2f}ms, "
f"Bwd: {backward_time/args.log_interval:.2f}ms, "
f"Optim: {optimizer_time/args.log_interval:.2f}ms, "
f"Iter Time: {iter_time:.2f}ms", accelerator)

# Generate a text sample
try:
# Randomly pick one sample
random_idx = torch.randint(0, X.size(0), (1,)).item()
sample_input = X[random_idx:random_idx+1]  # [1, seq_len]
sample_target = Y[random_idx:random_idx+1]  # [1, seq_len]

# Use the first part as the prompt, making sure 10 tokens remain as ground truth
prompt_len = sample_input.size(1) // 2
prompt_input = sample_input[:, :prompt_len]

# Get the true next 10 tokens
true_next_tokens = sample_target[:, prompt_len-1:prompt_len-1+10]  # the true next 10 tokens

# Generate 10 tokens
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.eval()  # switch to evaluation mode

with torch.no_grad():
generated = unwrapped_model.generate(
prompt_input,
max_new_tokens=10,
temperature=0.7,
top_p=0.9,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id
)

# Convert to human-readable text
prompt_text = tokenizer.decode(prompt_input[0], skip_special_tokens=True)
true_text = tokenizer.decode(true_next_tokens[0], skip_special_tokens=True)

# Extract the newly generated tokens
prompt_tokens = prompt_input[0].tolist()
generated_tokens = generated[0].tolist()

if len(generated_tokens) > len(prompt_tokens):
new_tokens = generated_tokens[len(prompt_tokens):len(prompt_tokens)+10]  # keep only the first 10
generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
else:
generated_text = "[no new tokens generated]"

Logger(f"Text generation comparison:", accelerator)
Logger(f"  Prompt: {prompt_text}", accelerator)
Logger(f"  True continuation: {true_text}", accelerator)
Logger(f"  Model continuation: {generated_text}", accelerator)

unwrapped_model.train()  # restore training mode

except Exception as e:
Logger(f"Failed to generate text sample: {e}", accelerator)

# Reset the events so the next measurement starts from zero
data_start = torch.cuda.Event(enable_timing=True)
data_end = torch.cuda.Event(enable_timing=True)
forward_start = torch.cuda.Event(enable_timing=True)
forward_end = torch.cuda.Event(enable_timing=True)
backward_start = torch.cuda.Event(enable_timing=True)
backward_end = torch.cuda.Event(enable_timing=True)
optimizer_start = torch.cuda.Event(enable_timing=True)
optimizer_end = torch.cuda.Event(enable_timing=True)
except RuntimeError as e:
if "Both events must be recorded" in str(e):
Logger(f"Warning: CUDA events not properly recorded, skipping performance analysis: {e}", accelerator)
else:
raise e
# Log the profiling breakdown
if (step + 1) % (args.log_interval * args.profile_interval) == 0:
Logger(f"Profiling (avg/iter over last {args.log_interval} steps) - "
f"Data: {data_time/args.log_interval:.2f}ms, "
f"Fwd: {forward_time/args.log_interval:.2f}ms, "
f"Bwd: {backward_time/args.log_interval:.2f}ms, "
f"Optim: {optimizer_time/args.log_interval:.2f}ms, "
f"Iter Time: {iter_time:.2f}ms", accelerator)
# Reset the events so the next measurement starts from zero
data_start = torch.cuda.Event(enable_timing=True)
data_end = torch.cuda.Event(enable_timing=True)
forward_start = torch.cuda.Event(enable_timing=True)
forward_end = torch.cuda.Event(enable_timing=True)
backward_start = torch.cuda.Event(enable_timing=True)
backward_end = torch.cuda.Event(enable_timing=True)
optimizer_start = torch.cuda.Event(enable_timing=True)
optimizer_end = torch.cuda.Event(enable_timing=True)


# Compute the current learning rate
@@ -792,13 +244,11 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
f"Total Time Left: {format_time(total_remaining_time)}", accelerator)

if args.use_swanlab and accelerator.is_main_process and swanlab_run:
swanlab_run.log(log_dict)
if args.use_wandb and accelerator.is_main_process and wandb:
wandb.log(log_dict)

# Save the model (main process only)
loss_total = loss.item() * args.accumulation_steps
if epoch > 1 and best_loss > loss_total and accelerator.is_main_process:
best_loss = loss_total
if (step + 1) % args.save_interval == 0 and accelerator.is_main_process:
# Use the moe_path variable defined at the start of the function
ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'

@@ -811,73 +261,38 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a

except Exception as e:
Logger(f"Error in training step: {e}", accelerator)
# Log the memory state at the time of the exception
if args.memory_monitor:
log_memory_status(step, prefetch_batches, accelerator, "at_exception", detailed=True)
import traceback
Logger(traceback.format_exc(), accelerator)

# Clear prefetch_batches to prevent a memory leak
if args.memory_monitor and accelerator.is_main_process:
Logger(f"[Memory Monitor] Clearing prefetch_batches due to exception. Current length: {len(prefetch_batches)}", accelerator)
prefetch_batches.clear()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if args.memory_monitor:
log_memory_status(step, prefetch_batches, accelerator, "after_exception_cleanup", detailed=True)

# Clear prefetch_batches when the training epoch ends
if args.memory_monitor:
if accelerator.is_main_process:
Logger(f"[Memory Monitor] Epoch {epoch+1} finished. Clearing prefetch_batches. Final length: {len(prefetch_batches)}", accelerator)
log_memory_status(total_steps_in_epoch-1, prefetch_batches, accelerator, "before_epoch_end_cleanup", detailed=True)
prefetch_batches.clear()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if args.memory_monitor:
log_memory_status(total_steps_in_epoch-1, prefetch_batches, accelerator, "after_epoch_end_cleanup", detailed=True)

def main():
parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
parser.add_argument("--out_dir", type=str, default="out")
parser.add_argument("--epochs", type=int, default=4)
parser.add_argument("--embedding_epoch", type=int, default=2, help="number of epochs for embedding training")
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=24)
parser.add_argument("--learning_rate", type=float, default=2e-4)
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_swanlab", default=True, action="store_true")  # replaces the wandb flag
parser.add_argument("--swanlab_project", type=str, default="MiniMind-Pretrain")  # replaces the wandb flag
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--use_wandb", default=True, action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
parser.add_argument("--num_workers", type=int, default=48)
parser.add_argument("--accumulation_steps", type=int, default=32)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--warmup_iters", type=int, default=0)
parser.add_argument("--log_interval", type=int, default=1)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--save_interval", type=int, default=10000)
parser.add_argument('--dim', default=512, type=int)
parser.add_argument('--n_layers', default=8, type=int)
parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument('--dim', default=1024, type=int)
parser.add_argument('--n_layers', default=32, type=int)
parser.add_argument('--max_seq_len', default=1024, type=int)
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument('--disable_db', action='store_true', help="disable the database feature and use a fixed value of 1e-4 instead")
parser.add_argument("--data_path", type=str, default="./dataset/stable/merged_pretrain.jsonl")
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
parser.add_argument("--profile", action="store_true", default=True, help="enable performance profiling")
parser.add_argument("--profile_interval", type=int, default=10, help="profiling print interval (in steps)")
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="enable FlashAttention")
parser.add_argument("--knowledge_num", type=int, default=960400, help="number of entries in the knowledge base")
parser.add_argument("--knowledge_length", type=int, default=32, help="sentence length of knowledge-base entries")
parser.add_argument("--database_init_path", type=str, default="./dataset/stable/sentence_trex_data.json", help="database initialization path")
parser.add_argument("--fast_clustering", action="store_true", default=True, help="use a fast approximate clustering algorithm (for large datasets)")
parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="path to the clustering-result cache file")
parser.add_argument("--recompute_clusters", action="store_true", default=False, help="force clusters to be recomputed, ignoring any cache file")
parser.add_argument("--memory_monitor", action="store_true", default=False, help="enable memory monitoring")
parser.add_argument("--memory_monitor_interval", type=int, default=10, help="memory monitoring interval (in steps)")
parser.add_argument("--model_type", type=str, default="model", help="which model to train")  # model, model_original, model_no_feed
parser.add_argument("--model_size", type=float, default=50.0, help="model size")
parser.add_argument("--swanlab_online", type=bool, default=False, help="whether to use the online SwanLab service")
parser.add_argument("--knowledge_num", type=int, default=64*64, help="number of entries in the knowledge base")
parser.add_argument("--knowledge_length", type=int, default=8, help="sentence length of knowledge-base entries")
args = parser.parse_args()
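# Editor's note (hypothetical invocation, flags as defined above):
#   accelerate launch --num_processes=1 --mixed_precision=bf16 \
#       train_pretrain_accelerate.py --model_type model --batch_size 128 \
#       --knowledge_num 1048576 --knowledge_length 32
# Note that --use_swanlab and --profile pair default=True with action="store_true",
# so they are effectively always on and cannot be switched off from the CLI.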


#########################################################
# Initialize accelerator and DeepSpeed
#########################################################
@@ -888,7 +303,7 @@ def main():
gradient_accumulation_steps=args.accumulation_steps,
gradient_clipping=args.grad_clip,
zero_stage=2,  # use ZeRO-2 optimization
offload_optimizer_device="none",  # do not offload the optimizer state
offload_optimizer_device="cpu",  # offload the optimizer state to the CPU
offload_param_device="none",  # do not offload parameters to the CPU
)
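# Editor's note (background, not in the original): ZeRO stage 2 shards optimizer
# state and gradients across ranks; offload_optimizer_device="cpu" additionally
# parks the optimizer state in host RAM, trading GPU memory for PCIe traffic on
# every step, while "none" keeps it on the GPU.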
accelerator = Accelerator(
@@ -913,8 +328,7 @@ def main():
disable_db=args.disable_db,
flash_attn=args.use_flash_attn,
knowledge_num=args.knowledge_num,
knowledge_length=args.knowledge_length,
embeddings_epoch=args.embedding_epoch
knowledge_length=args.knowledge_length
)

#########################################################
@@ -932,37 +346,18 @@ def main():


#########################################################
# Configure SwanLab
# Configure wandb
#########################################################
# Set the SwanLab run name
args.swanlab_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"

# Merge args and lm_config into a single dict (needed for printing the config whether or not SwanLab is used)
config_dict = vars(args).copy()
config_dict.update(vars(lm_config))

# Initialize the SwanLab experiment instance
swanlab_run = None
if args.use_swanlab and accelerator.is_main_process:
if args.swanlab_online:
# Use the online SwanLab service
# Initialize SwanLab
swanlab_run = swanlab.init(
project=args.swanlab_project,
experiment_name=args.swanlab_run_name,
description="MiniMind pretraining experiment, visualized with a locally deployed SwanLab",
config=config_dict
)
else:
swanlab_run = swanlab.init(
project=args.swanlab_project,
experiment_name=args.swanlab_run_name,
description="MiniMind pretraining experiment, visualized with a locally deployed SwanLab",
config=config_dict,
mode="offline"
)
# Set the wandb run name
args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
if args.use_wandb and accelerator.is_main_process:
import wandb
# Merge args and lm_config into a single dict
config_dict = vars(args).copy()
config_dict.update(vars(lm_config))
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config_dict)
else:
swanlab_run = None
wandb = None

#########################################################
# Print info
@@ -984,7 +379,7 @@ def main():
#########################################################
# Initialize the model and tokenizer
#########################################################
model, tokenizer = init_model(lm_config, args.pretrained_embedding_path, args.database_init_path, args)
model, tokenizer = init_model(lm_config, args.pretrained_embedding_path)
# Pass accelerator through for the Logger calls inside init_model
Logger(f'Model initialization complete', accelerator)

@@ -1044,31 +439,13 @@ def main():
#########################################################
overall_start_time = time.time()  # Record overall start time
for epoch in range(args.epochs):
Logger(f"Starting training epoch {epoch+1}", accelerator)
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer)  # Pass tokenizer

# Clean up memory at the end of each epoch
Logger(f"Epoch {epoch+1} finished; cleaning up memory", accelerator)
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()

# Log the memory state at the end of the epoch
if accelerator.is_main_process:
memory_info = get_memory_usage()
cuda_info = get_cuda_memory_usage()
log_msg = f"[Memory Monitor] Epoch {epoch+1} completed - "
log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB"
if cuda_info:
log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB"
Logger(log_msg, accelerator)
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb)  # Pass overall start time

#########################################################
# Shut down SwanLab
# Shut down wandb
#########################################################
if args.use_swanlab and accelerator.is_main_process and swanlab_run:
swanlab_run.finish()
if args.use_wandb and accelerator.is_main_process:
wandb.finish()

if __name__ == "__main__":
main()