Compare commits


No commits in common. "master" and "Gary_Lu" have entirely different histories.

37 changed files with 824 additions and 24128 deletions

9
.gitignore vendored
View File

@ -2,11 +2,4 @@
/dataset
/out
wandb/
**/*.log
models/sentence_transformers/
models/sentence_transformers_cache/
**/*.pyc
qwen2-1.7B/
images/
cache/
.venv/
**/*.log

124
.vscode/launch.json vendored
View File

@ -1,124 +0,0 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "MiniMind Training (Direct Python)",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/train_pretrain_accelerate.py",
"args": [
"--out_dir", "out",
"--epochs", "3",
"--embedding_epoch", "2",
"--batch_size", "128",
"--learning_rate", "8e-5",
"--dtype", "bfloat16",
"--use_swanlab",
"--swanlab_project", "MiniMind-Pretrain",
"--num_workers", "1",
"--accumulation_steps", "16",
"--grad_clip", "0.5",
"--warmup_iters", "0",
"--log_interval", "1",
"--save_interval", "10000",
"--dim", "512",
"--n_layers", "8",
"--max_seq_len", "512",
"--data_path", "./dataset/stable/merged_pretrain.jsonl",
"--profile",
"--profile_interval", "10",
"--use_flash_attn",
"--knowledge_num", "1048576",
"--knowledge_length", "32",
"--database_init_path", "./dataset/stable/sentence_trex_data.json",
"--fast_clustering",
"--cluster_cache_path", "./cache/cluster_tokens_single.pt",
"--memory_monitor_interval", "10",
"--model_type", "model",
"--model_size", "538"
],
"env": {
"CUDA_VISIBLE_DEVICES": "0",
"NCCL_DEBUG": "INFO",
"PYTHONFAULTHANDLER": "1"
},
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"justMyCode": false,
"stopOnEntry": false,
"python": "${workspaceFolder}/.venv/bin/python"
},
{
"name": "MiniMind Training (Direct Python - Simple)",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/train_pretrain_accelerate.py",
"args": [
"--epochs", "1",
"--batch_size", "32",
"--learning_rate", "1e-4",
"--log_interval", "10",
"--profile_interval", "2",
"--model_type", "model_original"
],
"env": {
"CUDA_VISIBLE_DEVICES": "0"
},
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"justMyCode": false,
"stopOnEntry": false,
"python": "${workspaceFolder}/.venv/bin/python"
},
{
"name": "MiniMind Test (Direct Python)",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/test.py",
"env": {
"CUDA_VISIBLE_DEVICES": "0"
},
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"justMyCode": false,
"python": "${workspaceFolder}/.venv/bin/python"
},
{
"name": "MiniMind Training Debug (Accelerate)",
"type": "python",
"request": "launch",
"module": "accelerate.commands.launch",
"args": [
"--num_processes=1",
"--mixed_precision=bf16",
"${workspaceFolder}/train_pretrain_accelerate.py",
"--epochs", "1",
"--batch_size", "32",
"--learning_rate", "1e-4",
"--log_interval", "10",
"--profile_interval", "2",
"--model_type", "model_original"
],
"env": {
"CUDA_VISIBLE_DEVICES": "0"
},
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"justMyCode": false,
"stopOnEntry": false,
"python": "${workspaceFolder}/.venv/bin/python"
},
{
"name": "MiniMind Test Only",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/test.py",
"env": {
"CUDA_VISIBLE_DEVICES": "0"
},
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"justMyCode": false
}
]
}

18
.vscode/settings.json vendored
View File

@ -1,18 +0,0 @@
{
"python.pythonPath": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
"python.defaultInterpreterPath": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
"python.terminal.activateEnvironment": true,
"python.terminal.activateEnvInCurrentTerminal": true,
"python.linting.enabled": true,
"python.linting.pylintEnabled": false,
"python.linting.flake8Enabled": true,
"python.formatting.provider": "black",
"python.analysis.autoImportCompletions": true,
"python.analysis.typeCheckingMode": "off",
"files.exclude": {
"**/__pycache__": true,
"**/*.pyc": true,
"**/.git": false,
"**/wandb": false
}
}

199
README.md
View File

@ -1,199 +0,0 @@
<div align="center">
![logo](./images/logo.png)
</div>
<div align="center">
![visitors](https://visitor-badge.laobi.icu/badge?page_id=jingyaogong/minimind)
[![GitHub Repo stars](https://img.shields.io/github/stars/jingyaogong/minimind?style=social)](https://github.com/jingyaogong/minimind/stargazers)
[![GitHub Code License](https://img.shields.io/github/license/jingyaogong/minimind)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/jingyaogong/minimind)](https://github.com/jingyaogong/minimind/commits/master)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/jingyaogong/minimind/pulls)
[![Collection](https://img.shields.io/badge/🤗-MiniMind%20%20Collection-blue)](https://huggingface.co/collections/jingyaogong/minimind-66caf8d999f5c7fa64f399e5)
</div>
# 📌 Data Introduction

## Ⅰ Tokenizer

A tokenizer maps words from natural language to numbers such as `0, 1, 36` through a "dictionary"; you can think of each number as the page on which the word appears in that dictionary.
You can build your own vocabulary and train a "dictionary" yourself (the code is in `./scripts/train_tokenizer.py`, provided for learning only; unless really necessary there is no need to retrain one, since MiniMind already ships with its own tokenizer),
or you can pick the tokenizer of a well-known open-source LLM.
The latter is like using the Xinhua or Oxford dictionary directly: the advantage is an excellent token compression ratio, the drawback is a huge number of pages (easily hundreds of thousands of words and phrases).
A self-trained tokenizer gives you full control over vocabulary size and content, but its compression ratio is low (for example, "hello" might be split into the five separate tokens "h e l l o"), and rare words are hard to cover.
The choice of "dictionary" certainly matters: an LLM's output is essentially a softmax over the dictionary's N words, i.e. an N-way classification, which is then decoded back into natural language through that dictionary.
Because MiniMind's size must be kept strictly under control, and to avoid a top-heavy model (the embedding layer taking too large a share of the LLM's parameters), the shorter the vocabulary, the better.
<details style="color:rgb(128,128,128)">
<summary>Tokenizer overview</summary>

The vocabulary sizes of the tokenizers of strong third-party open-source models (e.g. Yi, qwen, chatglm, mistral, Llama3) are:

<table>
<tr><th>Tokenizer model</th><th>Vocabulary size</th><th>Source</th></tr>
<tr><td>yi tokenizer</td><td>64,000</td><td>01.AI (China)</td></tr>
<tr><td>qwen2 tokenizer</td><td>151,643</td><td>Alibaba Cloud (China)</td></tr>
<tr><td>glm tokenizer</td><td>151,329</td><td>Zhipu AI (China)</td></tr>
<tr><td>mistral tokenizer</td><td>32,000</td><td>Mistral AI (France)</td></tr>
<tr><td>llama3 tokenizer</td><td>128,000</td><td>Meta (USA)</td></tr>
<tr><td>minimind tokenizer</td><td>6,400</td><td>Custom</td></tr>
</table>

> 👉 Updated 2024-09-17: to avoid ambiguity with past versions and to keep the model small, all MiniMind models now use minimind_tokenizer; all mistral_tokenizer versions are deprecated.

```
# Some rambling notes
> Although minimind_tokenizer has a small vocabulary and its encoding/decoding efficiency is weaker than Chinese-friendly tokenizers such as qwen2 or glm,
> the MiniMind models use the self-trained minimind_tokenizer to keep the overall parameter count light and avoid an imbalance between the embedding layer and the compute layers (a top-heavy model), since MiniMind's vocabulary size is only 6,400.
> In practical tests MiniMind has never failed to decode a rare word; it works well.
> Because the custom vocabulary is compressed to 6,400 entries, the total LLM parameter count can be as low as 25.8M.
> The training data `tokenizer_train.jsonl` all comes from the `匠数大模型数据集` (Jiangshu large-model dataset); this part is relatively minor, and you are free to choose other data if you want to train a tokenizer.
```
</details>
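
For reference, a minimal sketch of loading the bundled tokenizer and round-tripping a sentence; the path `./model/minimind_tokenizer` is an assumption about the repo layout, adjust it to your checkout:

```python
from transformers import AutoTokenizer

# Load the MiniMind tokenizer shipped with the repo (path is an assumption).
tokenizer = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")

text = "你好,世界"
ids = tokenizer(text)["input_ids"]           # words -> "page numbers" in the dictionary
print(ids)
print(tokenizer.decode(ids))                 # back to natural language
print("vocab size:", tokenizer.vocab_size)   # ~6,400 for minimind_tokenizer
```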
## Ⅱ Pretraining data

After the lesson of MiniMind-V1, where low-quality pretraining data led to a model that talked nonsense, on `2025-02-05` it was decided not to use large-scale unsupervised datasets for pretraining anymore.
Instead, the Chinese portion of the [Jiangshu large-model dataset](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data) was extracted,
cleaned down to entries shorter than 512 characters, and concatenated directly into roughly 1.6GB of pretraining corpus, `pretrain_hq.jsonl` ("hq" stands for high quality, though it is still not truly "high"; improving data quality is a never-ending task).

The format of `pretrain_hq.jsonl` is:

```bash
{"text": "如何才能摆脱拖延症? 治愈拖延症并不容易,但以下建议可能有所帮助..."}
```
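
A minimal sketch of streaming this file, assuming it sits under `./dataset/` as described in the download section below (standard library only):

```python
import json

# Each line of pretrain_hq.jsonl is one JSON object with a "text" field.
with open("./dataset/pretrain_hq.jsonl", "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        sample = json.loads(line)
        print(len(sample["text"]), sample["text"][:30])
        if i >= 2:   # only peek at the first few samples
            break
```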
## Ⅲ SFT data

The [Jiangshu large-model SFT dataset](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
"is a complete, uniformly formatted, and safe resource for large-model training and research.
It collects and curates a large number of open datasets from public sources, unifies their format, and cleans the data;
it contains a Chinese dataset of 10M samples and an English dataset of 2M samples."

That is the official description. After downloading, the total volume is about 4B tokens, which certainly makes it suitable SFT data for a Chinese LLM.
However, the official data format is quite messy, and using all of it for SFT would be too costly.

I re-cleaned the official dataset, removing entries with symbol pollution and noise, and again kept only content with total length `<512`;
at this stage, the goal is to use large amounts of dialogue to supplement the knowledge the pretraining stage lacks.
The exported file is `sft_512.jsonl` (~7.5GB).

The [Magpie-SFT dataset](https://www.modelscope.cn/organization/Magpie-Align)
collects ~1M high-quality conversations generated by Qwen2/2.5. I cleaned this data further and exported the entries with total length `<2048` as `sft_2048.jsonl` (~9GB)
and the entries with length `<1024` as `sft_1024.jsonl` (~5.5GB). Doing SFT directly on dialogue data produced by a large model falls under "black-box distillation".

The data from the two SFT steps above was cleaned once more, keeping only content with a high proportion of Chinese characters and conversations of length `<512`, yielding `sft_mini_512.jsonl` (~1.2GB).

All SFT files `sft_X.jsonl` share the format:

```text
{
    "conversations": [
        {"role": "user", "content": "你好"},
        {"role": "assistant", "content": "你好!"},
        {"role": "user", "content": "再见"},
        {"role": "assistant", "content": "再见!"}
    ]
}
```
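
A minimal sketch of turning one such record into a training string with the tokenizer's chat template (`apply_chat_template` is the Hugging Face `transformers` API; the tokenizer path is an assumption):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")  # path is an assumption

record = {
    "conversations": [
        {"role": "user", "content": "你好"},
        {"role": "assistant", "content": "你好!"},
    ]
}

# Render the conversation with the chat template defined in tokenizer_config.json.
prompt = tokenizer.apply_chat_template(record["conversations"], tokenize=False)
print(prompt)
```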
## Ⅳ RLHF data

From the [Magpie-DPO dataset](https://www.modelscope.cn/datasets/Magpie-Align/MagpieLM-DPO-Data-v0.1):
roughly 200k preference pairs, all in English, generated from Llama3.1-70B/8B. It can be used to train a reward model and optimize response quality so that outputs better match human preferences.

Here, entries with total length `<3000` were repackaged as `dpo.jsonl` (~0.9GB), containing the two fields `chosen` and `rejected`: `chosen`
is the preferred reply and `rejected` is the rejected one.

The format of `dpo.jsonl` is:

```text
{
    "chosen": [
        {"content": "Q", "role": "user"},
        {"content": "good answer", "role": "assistant"}
    ],
    "rejected": [
        {"content": "Q", "role": "user"},
        {"content": "bad answer", "role": "assistant"}
    ]
}
```
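
A minimal sketch of loading one preference pair and rendering both sides with the chat template, roughly the way a DPO data pipeline consumes it (this is an illustrative usage, not the repo's exact code; the tokenizer path is an assumption):

```python
import json
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")  # path is an assumption

with open("./dataset/dpo.jsonl", "r", encoding="utf-8") as f:
    pair = json.loads(f.readline())

# Each record holds the same prompt with a preferred and a rejected completion.
chosen_text = tokenizer.apply_chat_template(pair["chosen"], tokenize=False)
rejected_text = tokenizer.apply_chat_template(pair["rejected"], tokenize=False)
print("CHOSEN:\n", chosen_text)
print("REJECTED:\n", rejected_text)
```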
## Ⅴ Reason dataset

It has to be said: in February 2025, nothing is hotter than DeepSeek...
It has also sparked my strong interest in RL-guided reasoning models. I have already reproduced R1-Zero with Qwen2.5.
If time permits and it works (though with 99% probability the base model's capability is insufficient), I will later release a MiniMind reasoning model trained with RL rather than a distilled one.

Time is limited, and the fastest low-cost route is still direct (black-box) distillation.
R1 is so popular that within a few days several R1 distillation datasets appeared: [R1-Llama-70B](https://www.modelscope.cn/datasets/Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B), [R1-Distill-SFT](https://www.modelscope.cn/datasets/AI-ModelScope/R1-Distill-SFT),
[Alpaca-Distill-R1](https://huggingface.co/datasets/shareAI/Alpaca-Distill-R1-ZH),
[deepseek_r1_zh](https://huggingface.co/datasets/jinliuxi/deepseek_r1_zh), and so on; purely Chinese data is still relatively scarce.

They were merged and exported as `r1_mix_1024.jsonl`, with the same format as `sft_X.jsonl`.

## Ⅵ More datasets

[HqWu-HITCS/Awesome-Chinese-LLM](https://github.com/HqWu-HITCS/Awesome-Chinese-LLM)
is already collecting and curating open-source models, applications, datasets, and tutorials related to Chinese LLMs, and keeps tracking the latest progress in this area. Comprehensive and professional, respect!

---
## Ⅷ Dataset download

> [!NOTE]
> Since 2025-02-05, all datasets used for MiniMind's final training are open-sourced, so there is no need to preprocess large-scale data yourself and no repetitive data-processing work.

MiniMind training datasets ([ModelScope](https://www.modelscope.cn/datasets/gongjy/minimind-dataset/files) | [HuggingFace](https://huggingface.co/datasets/jingyaogong))

> You do not need to clone everything; download only the files you need.

Put the downloaded dataset files into the `./dataset/` directory (✨ marks the recommended, required files):
```bash
./dataset/
├── dpo.jsonl (909MB)
├── lora_identity.jsonl (22.8KB)
├── lora_medical.jsonl (34MB)
├── pretrain_hq.jsonl (1.6GB, ✨)
├── r1_mix_1024.jsonl (340MB)
├── sft_1024.jsonl (5.6GB)
├── sft_2048.jsonl (9GB)
├── sft_512.jsonl (7.5GB)
├── sft_mini_512.jsonl (1.2GB, ✨)
└── tokenizer_train.jsonl (1GB)
```
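
A small sketch to check whether the recommended files from the listing above are already in place (the pair checked here is an assumption based on the ✨ marks):

```python
from pathlib import Path

# Files marked ✨ above are the minimum needed for the fast "Zero" reproduction.
recommended = ["pretrain_hq.jsonl", "sft_mini_512.jsonl"]

dataset_dir = Path("./dataset")
for name in recommended:
    path = dataset_dir / name
    status = f"{path.stat().st_size / 1e9:.2f} GB" if path.exists() else "MISSING"
    print(f"{name}: {status}")
```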
<details style="color:rgb(128,128,128)">
<summary>Note: short description of each dataset</summary>

* `dpo.jsonl` -- dataset for the RLHF stage
* `lora_identity.jsonl` -- self-identity dataset (e.g. "Who are you?" "I am minimind..."), recommended for LoRA training (also usable for full-parameter SFT; don't be limited by the name)
* `lora_medical.jsonl` -- medical Q&A dataset, recommended for LoRA training (also usable for full-parameter SFT; don't be limited by the name)
* `pretrain_hq.jsonl`✨ -- pretraining dataset, compiled from the Jiangshu (deepctrl) data
* `r1_mix_1024.jsonl` -- DeepSeek-R1-1.5B distillation data, each sample at most 1024 characters (so set max_seq_len=1024 for training)
* `sft_1024.jsonl` -- compiled from Qwen2.5 distillation data (a subset of sft_2048), each sample at most 1024 characters (so set max_seq_len=1024 for training)
* `sft_2048.jsonl` -- compiled from Qwen2.5 distillation data, each sample at most 2048 characters (so set max_seq_len=2048 for training)
* `sft_512.jsonl` -- compiled from the Jiangshu SFT data, each sample at most 512 characters (so set max_seq_len=512 for training)
* `sft_mini_512.jsonl`✨ -- a minimal compilation of the Jiangshu SFT data + Qwen2.5 distillation data, intended for quickly training the Zero model; each sample at most 512 characters (so set max_seq_len=512 for training)
* `tokenizer_train.jsonl` -- all from the `匠数大模型数据集`; this part is relatively minor (retraining the tokenizer yourself is not recommended, for the reasons above), but if you do want to train your own tokenizer you are free to pick any dataset.

</details>
![dataset](./images/dataset.jpg)
<details style="color:rgb(128,128,128)">
<summary>Notes & recommended training plans</summary>

* The MiniMind2 series was trained on roughly 20GB of corpus in total, about 4B tokens, i.e. exactly the data combinations above (cost: 💰💰💰💰💰💰💰💰, result: 😊😊😊😊😊😊).
* For the fastest possible from-scratch Zero model, use the `pretrain_hq.jsonl` + `sft_mini_512.jsonl` combination; see the table below for the exact cost and quality (cost: 💰, result: 😊😊).
* If you have some compute or care more about quality, consider the former (fully reproducing MiniMind2); if you only have a single GPU or want a quick reproduction, the latter is strongly recommended.
* As a middle ground, you can also freely combine medium-scale data such as `sft_mini_512.jsonl` and `sft_1024.jsonl` (cost: 💰💰💰, result: 😊😊😊😊).

</details>

126
README_accelerate.md Normal file
View File

@ -0,0 +1,126 @@
# Distributed training with Accelerate + DeepSpeed

This document describes how to run distributed training of the MiniMind model with Accelerate and DeepSpeed.

## Environment setup

First, make sure the required dependencies are installed:

```bash
pip install accelerate deepspeed
```

## Configuration files

### 1. DeepSpeed configuration file (ds_config.json)

The DeepSpeed configuration file defines the optimizer, learning-rate scheduler, and ZeRO optimization parameters (a sketch follows below). The main settings are:

- **ZeRO optimization**: ZeRO-2 is used to reduce GPU memory usage
- **Optimizer**: AdamW
- **Mixed-precision training**: both FP16 and BF16 are supported
- **Gradient accumulation**: set to "auto" so it stays consistent with the training-script argument
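
A minimal sketch of such a file, written out from Python; the field values are illustrative assumptions, and the repo's actual ds_config.json may differ:

```python
import json

# Illustrative DeepSpeed config: ZeRO-2, AdamW, precision and accumulation left to "auto"
# so that Accelerate and the training script control them.
ds_config = {
    "zero_optimization": {"stage": 2, "overlap_comm": True, "contiguous_gradients": True},
    "optimizer": {"type": "AdamW", "params": {"lr": "auto", "weight_decay": "auto"}},
    "fp16": {"enabled": "auto"},
    "bf16": {"enabled": "auto"},
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}

with open("ds_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)
```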
### 2. Accelerate configuration file (accelerate_config.yaml)

The Accelerate configuration file defines the basic distributed-training setup (a sketch follows below), including:

- **Distributed type**: DeepSpeed
- **Mixed precision**: BF16
- **Number of processes**: 4 (adjust to the number of GPUs)
- **DeepSpeed config**: points to the ds_config.json file
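
A minimal sketch of what such a YAML might contain, emitted from Python for convenience; the values are illustrative assumptions, and `accelerate config` can generate the real file interactively:

```python
# Illustrative accelerate_config.yaml: DeepSpeed backend, bf16, 4 processes,
# deferring optimizer/ZeRO details to ds_config.json.
config_yaml = """\
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
mixed_precision: bf16
num_machines: 1
num_processes: 4
deepspeed_config:
  deepspeed_config_file: ds_config.json
"""

with open("accelerate_config.yaml", "w") as f:
    f.write(config_yaml)
```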
## Training script

The new training script `train_pretrain_accelerate.py` is adapted from the original `train_pretrain.py`. The main changes (sketched below) are:

1. The Accelerator replaces PyTorch's native distributed functionality
2. The torchrun-specific distributed initialization code is removed
3. The Accelerator API prepares the model, optimizer, and data loaders
4. The Accelerator API performs backpropagation and gradient clipping
5. Positional encodings and unused parameters are handled explicitly
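
A minimal, self-contained sketch of that Accelerator flow; the tiny model, optimizer, and loader below are placeholders, not the script's actual objects:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# Tiny stand-in model/data so the Accelerator flow itself is runnable;
# in the real script these are MiniMindLM, AdamW, and the pretraining dataset.
model = nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
loader = DataLoader(dataset, batch_size=8)

accelerator = Accelerator(gradient_accumulation_steps=4)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for x, y in loader:
    with accelerator.accumulate(model):              # gradient accumulation
        loss = nn.functional.cross_entropy(model(x), y)
        accelerator.backward(loss)                    # replaces loss.backward()
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping
        optimizer.step()
        optimizer.zero_grad()
```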
## Launching training

There are two ways to launch training:

### Method 1: use a pre-made accelerate config file
```bash
accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
--epochs 3 \
--batch_size 24 \
--learning_rate 2e-4 \
--dtype bfloat16 \
--accumulation_steps 32 \
--grad_clip 1.0 \
--log_interval 100 \
--save_interval 10000 \
--dim 1024 \
--n_layers 32 \
--max_seq_len 1024 \
--use_flash_attn \
--profile \
--profile_interval 10
```
### Method 2: configure accelerate directly via command-line arguments
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--multi_gpu \
--num_processes=4 \
--mixed_precision=bf16 \
--main_process_port=29500 \
--deepspeed_config_file ds_config.json \
train_pretrain_accelerate.py \
--epochs 3 \
--batch_size 24 \
--learning_rate 2e-4 \
--dtype bfloat16 \
--accumulation_steps 32 \
--grad_clip 1.0 \
--log_interval 100 \
--save_interval 10000 \
--dim 1024 \
--n_layers 32 \
--max_seq_len 1024 \
--use_flash_attn \
--profile \
--profile_interval 10
```
You can also simply use the provided script:
```bash
bash run_accelerate.sh
```
## How the Accelerate and DeepSpeed configs relate

1. **Accelerate** is a high-level API that simplifies setting up and launching distributed training; it can work with several distributed backends (DeepSpeed, FSDP, etc.).
2. **DeepSpeed** is an optimization library focused on memory optimization and performance for large-scale model training, providing features such as ZeRO.
3. **Relationship between the configs**:
   - The Accelerate config file (YAML) defines which distributed backend to use and the basic distributed settings
   - The DeepSpeed config file (JSON) defines DeepSpeed-specific optimization parameters
   - Accelerate references the DeepSpeed config via the `deepspeed_config_file` parameter

## Notes

1. **Positional-encoding handling**:
   - In the model, `pos_cis` is a complex tensor, which needs special handling in distributed training
   - The new training script relies on the Accelerator API for this, so `_ddp_params_and_buffers_to_ignore` is no longer needed
2. **Unused parameters**:
   - The original code used `find_unused_parameters=True` to handle unused parameters
   - The new training script uses the Accelerator API directly, which handles this automatically
3. **Mixed-precision training**:
   - In the DeepSpeed config, `fp16` and `bf16` are set to `"auto"`
   - The precision actually used is decided by Accelerate's `--mixed_precision` argument
4. **Gradient accumulation**:
   - In the DeepSpeed config, `gradient_accumulation_steps` is set to `"auto"`
   - The actual number of accumulation steps is decided by the training script's `--accumulation_steps` argument

22
ReadMe.md Normal file
View File

@ -0,0 +1,22 @@
## Environment setup

1. Create a conda environment:
```bash
conda create -n accelerate python=3.10
conda activate accelerate
```
2. Install torch, torchvision, and torchaudio matching your system's CUDA version
3. Install accelerate and deepspeed matching the torch/torchvision in the environment
4. Install the remaining packages:
```bash
pip install -r requirements.txt
```
## Modifying the model

1. In general, only the files in the `model` folder need to be modified.

## Running

1. On a 4090 or 4070 Ti, run `bash run_file/DynamicKV-LLM_Mini_Minimind.sh`
2. On 4x A800, run `bash run_file/DynamicKV-LLM_Small_Minimind.sh`

View File

@ -1,26 +0,0 @@
# 1. Metadata: needs editing; set a name and description for this experiment
name: ycz-minimind-test
description: minimind-test trial run
# 2. Runtime environment: usually unchanged; replace with a specific image manually if needed
environment:
  image: determinedai/pytorch-ngc:0.38.0 # do not modify
# 3. Dataset on the NAS: needs editing (only the bind_mounts field; container_path and read_only stay unchanged)
# Replace <YOUR_DATASET_FOLDER_NAME> with the name of your dataset folder stored under Volume1/Share/datasets/ on the NAS
# Please double-check that the <YOUR_DATASET_FOLDER_NAME> dataset really exists under Volume1/Share/datasets/ on the NAS
# 4. Compute resources: do not modify
resources:
  slots_per_trial: 1 # do not modify
  resource_pool: rtx4090 # do not modify
# 5. Searcher: do not modify
searcher:
  name: single
  metric: test_accuracy
  smaller_is_better: false
# 6. Entrypoint: do not modify
entrypoint: sh startup.sh

View File

@ -1,6 +0,0 @@
def main():
print("Hello from minimind!")
if __name__ == "__main__":
main()

View File

@ -19,7 +19,6 @@ class LMConfig(PretrainedConfig):
rope_theta: int = 1e6,
dropout: float = 0.0,
flash_attn: bool = True,
embeddings_epoch: int = 2,
####################################################
# DB related configurations
####################################################
@ -40,13 +39,6 @@ class LMConfig(PretrainedConfig):
####################################################
knowledge_num: int = 64*64,
knowledge_length: int = 8,
knowledge_dim: int = 128,
####################################################
# Triple extraction related configurations
####################################################
max_subject_len: int = 8,
max_predicate_len: int = 4,
max_object_len: int = 8,
**kwargs,
):
self.dim = dim
@ -61,7 +53,6 @@ class LMConfig(PretrainedConfig):
self.rope_theta = rope_theta
self.dropout = dropout
self.flash_attn = flash_attn
self.embeddings_epoch = embeddings_epoch
####################################################
# DB related configurations
####################################################
@ -81,11 +72,4 @@ class LMConfig(PretrainedConfig):
####################################################
self.knowledge_num = knowledge_num
self.knowledge_length = knowledge_length
self.knowledge_dim = knowledge_dim
####################################################
# Triple extraction related configurations
####################################################
self.max_subject_len = max_subject_len
self.max_predicate_len = max_predicate_len
self.max_object_len = max_object_len
super().__init__(**kwargs)

View File

@ -9,75 +9,10 @@ import torch
from sklearn.model_selection import train_test_split
import os
import ast
from tqdm import tqdm
os.environ["TOKENIZERS_PARALLELISM"] = "true"
def process_sample_filter(data_args):
"""处理单个样本的过滤逻辑"""
sample, valid_predicates = data_args
if 'target' in sample and isinstance(sample['target'], list):
# 过滤target中的低频谓词
valid_targets = []
for triple in sample['target']:
if isinstance(triple, dict) and 'predicate' in triple:
if triple['predicate'] in valid_predicates:
valid_targets.append(triple)
# 如果还有有效的target保留这个样本
if valid_targets:
sample['target'] = valid_targets
return sample
else:
return None
else:
# 如果没有target信息保留样本
return sample
def process_sample_validation(data_args):
"""处理单个样本的验证逻辑"""
sample, predicate_vocab = data_args
if not isinstance(sample, dict) or 'text' not in sample:
return None
targets = sample.get('target', [])
if not isinstance(targets, list) or len(targets) == 0:
# 如果没有有效的target创建一个默认的
selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
else:
# 验证并选择target优先选择占比小的谓词
selected_target = None
min_percentage = float('inf')
for triple in targets:
if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
predicate = triple['predicate']
# 使用predicate_vocab中的统计信息
if predicate in predicate_vocab:
stats = predicate_vocab[predicate]
if isinstance(stats, dict) and 'percentage' in stats:
percentage = stats['percentage']
if percentage < min_percentage:
min_percentage = percentage
selected_target = triple
elif selected_target is None:
selected_target = triple
elif selected_target is None:
selected_target = triple
# 如果没有找到有效的target使用默认值
if selected_target is None:
selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
return {
'text': sample['text'],
'target': selected_target # 只保留一个target
}
class PretrainDataset(Dataset):
def __init__(self, data_path, tokenizer, max_length=512):
super().__init__()
@ -98,14 +33,9 @@ class PretrainDataset(Dataset):
def __getitem__(self, index):
sample = self.samples[index]
text = str(sample['text'])
# 检查并添加<|im_start|>和<|im_end|>如果不存在
if not text.startswith(self.tokenizer.bos_token):
text = f"{self.tokenizer.bos_token}{text}"
if not text.endswith(self.tokenizer.eos_token):
text = f"{text}{self.tokenizer.eos_token}"
# 构建输入文本
text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
encoding = self.tokenizer(
text,
max_length=self.max_length,
@ -128,8 +58,8 @@ class SFTDataset(Dataset):
self.tokenizer = tokenizer
self.max_length = max_length
self.samples = self.load_data(jsonl_path)
self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids
self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
def __len__(self):
return len(self.samples)
@ -196,8 +126,8 @@ class DPODataset(Dataset):
self.tokenizer = tokenizer
self.max_length = max_length
self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids
self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
with open(file_path, 'r', encoding='utf-8') as f:
self.data = []
for line in f:
@ -266,249 +196,14 @@ class DPODataset(Dataset):
return loss_mask
class TriplePretrainDataset(Dataset):
"""
Optimized triple pretraining dataset
- keeps only one target triple per sample
- pre-tokenizes all data
- shows a progress bar during processing
"""
def __init__(self, data_path=None, predicate_vocab_path=None, samples = None,tokenizer=None, max_length=512):
super().__init__()
self.tokenizer = tokenizer
self.max_length = max_length
self.val_samples = None
self.predicate_to_id = {} # 初始化
if samples is None:
self.predicate_vocab = self.load_predicate_vocab(predicate_vocab_path)
print("🚀 开始加载和预处理三元组数据...")
self.samples,self.val_samples = self.load_and_preprocess_data(data_path)
print("🚀 加载和预处理三元组数据完成")
else:
cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
data_filename = os.path.basename(data_path).split('.')[0]
predicate_to_id_path = os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json')
self.predicate_to_id = self.load_predicate_vocab(predicate_to_id_path)
self.samples = samples
print("🚀 加载和预处理三元组数据完成")
def load_predicate_vocab(self, path):
with open(path, 'r', encoding='utf-8') as f:
predicate_vocab = json.load(f)
return predicate_vocab
def get_val_samples(self):
return self.val_samples
def clear_cache(self, data_path):
"""清除缓存文件"""
cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
data_filename = os.path.basename(data_path).split('.')[0]
cache_files = [
os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
os.path.join(cache_dir, f'{data_filename}_val_samples.json')
]
for cache_file in cache_files:
if os.path.exists(cache_file):
os.remove(cache_file)
print(f"🗑️ 已删除缓存文件: {cache_file}")
if os.path.exists(cache_dir) and not os.listdir(cache_dir):
os.rmdir(cache_dir)
print(f"🗑️ 已删除空的缓存目录: {cache_dir}")
def load_and_preprocess_data(self, path):
"""加载并预处理三元组数据"""
# 生成缓存文件名(基于数据文件路径)
cache_dir = os.path.join(os.path.dirname(path), 'cache')
os.makedirs(cache_dir, exist_ok=True)
data_filename = os.path.basename(path).split('.')[0]
cache_files = {
'predicate_vocab': os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
'predicate_to_id': os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
'train_samples': os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
'val_samples': os.path.join(cache_dir, f'{data_filename}_val_samples.json')
}
# 检查缓存文件是否存在
cache_exists = all(os.path.exists(cache_file) for cache_file in cache_files.values())
if cache_exists:
print("📁 发现缓存文件,直接加载...")
# 从缓存加载
with open(cache_files['predicate_vocab'], 'r', encoding='utf-8') as f:
self.predicate_vocab = json.load(f)
with open(cache_files['predicate_to_id'], 'r', encoding='utf-8') as f:
self.predicate_to_id = json.load(f)
with open(cache_files['train_samples'], 'r', encoding='utf-8') as f:
train_samples = json.load(f)
with open(cache_files['val_samples'], 'r', encoding='utf-8') as f:
val_samples = json.load(f)
print(f"✅ 从缓存加载完成:")
print(f"✅ 谓词词表大小: {len(self.predicate_vocab)}")
print(f"✅ 训练集大小: {len(train_samples)}")
print(f"✅ 测试集大小: {len(val_samples)}")
return train_samples, val_samples
# 缓存不存在,重新处理数据
print("📂 缓存不存在,开始加载和处理原始数据...")
# 1. 加载原始数据
print("📂 加载原始数据...")
if path.endswith('.json'):
with open(path, 'r', encoding='utf-8') as f:
data = json.load(f)
elif path.endswith('.jsonl'):
data = []
with open(path, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
data.append(json.loads(line.strip()))
else:
raise ValueError(f"Unsupported file format: {path}")
print(f"📊 原始数据量: {len(data)} 个样本")
# 2. 使用self.predicate_vocab过滤占比小于0.01%的谓词数据
print("🔍 过滤低频谓词数据...")
print(f"📊 谓词统计数据: 总共{len(self.predicate_vocab)}个谓词")
# 3.获取占比大于等于0.01%的谓词
valid_predicates = set()
for predicate, stats in self.predicate_vocab.items():
if isinstance(stats, dict) and 'percentage' in stats:
if stats['percentage'] >= 0.01:
valid_predicates.add(predicate)
else:
# 如果不是统计格式,假设是有效谓词
valid_predicates.add(predicate)
print(f"📊 占比≥0.01%的谓词: {len(valid_predicates)}")
# 4.过滤数据:去除包含低频谓词的数据(单进程处理)
original_count = len(data)
filtered_data = []
print("🚀 开始过滤低频谓词数据...")
for sample in tqdm(data, desc="过滤低频谓词"):
result = process_sample_filter((sample, valid_predicates))
if result is not None:
filtered_data.append(result)
data = filtered_data
print(f"✅ 过滤完成: 去除前{original_count}条,去除后{len(data)}")
# 5. 去除self.predicate_vocab中占比小于0.01%的谓词,并创建谓词到序号的映射
print("🔍 更新谓词词表并创建序号映射...")
original_vocab_size = len(self.predicate_vocab)
filtered_predicate_vocab = {}
for predicate, stats in self.predicate_vocab.items():
if isinstance(stats, dict) and 'percentage' in stats:
if stats['percentage'] >= 0.01:
filtered_predicate_vocab[predicate] = stats
else:
# 如果不是统计格式,保留
filtered_predicate_vocab[predicate] = stats
# 创建谓词到序号的映射字典
self.predicate_to_id = {predicate: idx for idx, predicate in enumerate(filtered_predicate_vocab.keys())}
self.predicate_vocab = filtered_predicate_vocab
print(f"✅ 谓词词表更新: 去除前{original_vocab_size}个,去除后{len(self.predicate_vocab)}")
print(f"✅ 谓词映射创建: {len(self.predicate_to_id)}个谓词对应序号")
# 6. 数据验证和筛选只保留一个target优先选择占比小的谓词以平衡数据单进程处理
print("🔍 验证数据格式并选择单个target平衡数据...")
valid_samples = []
print("🚀 开始验证数据格式...")
for sample in tqdm(data, desc="验证数据格式"):
result = process_sample_validation((sample, self.predicate_vocab))
if result is not None:
valid_samples.append(result)
print(f"✅ 有效样本数: {len(valid_samples)}")
# 7.拆分训练集合与测试集合
import random
random.seed(42)
val_samples = random.sample(valid_samples, min(1000, len(valid_samples)))
train_samples = [sample for sample in valid_samples if sample not in val_samples]
print(f"✅ 训练集大小: {len(train_samples)}")
print(f"✅ 测试集大小: {len(val_samples)}")
# 8. 保存到缓存文件
print("💾 保存处理结果到缓存文件...")
with open(cache_files['predicate_vocab'], 'w', encoding='utf-8') as f:
json.dump(self.predicate_vocab, f, ensure_ascii=False, indent=2)
with open(cache_files['predicate_to_id'], 'w', encoding='utf-8') as f:
json.dump(self.predicate_to_id, f, ensure_ascii=False, indent=2)
with open(cache_files['train_samples'], 'w', encoding='utf-8') as f:
json.dump(train_samples, f, ensure_ascii=False, indent=2)
with open(cache_files['val_samples'], 'w', encoding='utf-8') as f:
json.dump(val_samples, f, ensure_ascii=False, indent=2)
print("✅ 缓存文件保存完成")
return train_samples, val_samples
def __len__(self):
return len(self.samples)
def _triple_to_sentence(self, triple):
"""将三元组转换为句子格式"""
return f"{triple['subject']} {triple['predicate']} {triple['object']}"
def __getitem__(self, index):
"""返回数据,用于谓词分类任务"""
sample = self.samples[index]
# 在运行时tokenize输入文本
input_text = f"{self.tokenizer.bos_token}{sample['text']}{self.tokenizer.eos_token}"
encoding = self.tokenizer(
input_text,
max_length=self.max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
input_ids = encoding.input_ids.squeeze()
loss_mask = (input_ids != self.tokenizer.pad_token_id)
# 获取谓词分类标签
target_predicate = sample['target']['predicate']
predicate_label = self.predicate_to_id.get(target_predicate) # 默认为0如果找不到
# 构建训练数据
X = input_ids[:-1]
loss_mask = loss_mask[1:]
return {
'input_ids': X,
'labels': torch.tensor(predicate_label, dtype=torch.long), # 谓词分类标签
'loss_mask': loss_mask
}
class RLAIFDataset(Dataset):
def __init__(self, jsonl_path, tokenizer, max_length=1024):
super().__init__()
self.tokenizer = tokenizer
self.max_length = max_length
self.samples = self.load_data(jsonl_path)
self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids
self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
def __len__(self):
return len(self.samples)

View File

@ -14,7 +14,7 @@
},
{
"id": 1,
"content": "<|im_start|>",
"content": "<s>",
"single_word": false,
"lstrip": false,
"rstrip": false,
@ -23,7 +23,7 @@
},
{
"id": 2,
"content": "<|im_end|>",
"content": "</s>",
"single_word": false,
"lstrip": false,
"rstrip": false,
@ -56,8 +56,8 @@
"ignore_merges": false,
"vocab": {
"<unk>": 0,
"<|im_start|>": 1,
"<|im_end|>": 2,
"<s>": 1,
"</s>": 2,
"!": 3,
"\"": 4,
"#": 5,

View File

@ -12,7 +12,7 @@
"special": true
},
"1": {
"content": "<|im_start|>",
"content": "<s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
@ -20,7 +20,7 @@
"special": true
},
"2": {
"content": "<|im_end|>",
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
@ -29,9 +29,9 @@
}
},
"additional_special_tokens": [],
"bos_token": "<|im_start|>",
"bos_token": "<s>",
"clean_up_tokenization_spaces": false,
"eos_token": "<|im_end|>",
"eos_token": "</s>",
"legacy": true,
"model_max_length": 32768,
"pad_token": "<unk>",
@ -39,5 +39,5 @@
"spaces_between_special_tokens": false,
"tokenizer_class": "PreTrainedTokenizerFast",
"unk_token": "<unk>",
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\n你是 MiniMind是一个有用的人工智能助手。<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<s>system\\n' + system_message + '</s>\\n' }}{% else %}{{ '<s>system\\n你是 MiniMind是一个有用的人工智能助手。</s>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
}

File diff suppressed because one or more lines are too long

View File

@ -2,8 +2,7 @@ import math
import struct
import inspect
import time
import gc
#子空间二维分解+梯度更新
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
@ -12,9 +11,14 @@ import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch import nn, einsum
from einops import rearrange, repeat
def exists(val):
return val is not None
# RMSNorm defines a module that normalizes the input tensor.
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
@ -27,7 +31,7 @@ class RMSNorm(torch.nn.Module):
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
# precompute_pos_cis precomputes the positional encodings (complex version).
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
@ -35,7 +39,7 @@ def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis
# apply_rotary_emb applies rotary positional encodings (complex version).
def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x):
ndim = x.ndim
@ -51,244 +55,104 @@ def apply_rotary_emb(xq, xk, pos_cis):
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class KnowledgeDataset(nn.Module):
def __init__(self, params, tok_embeddings, is_train=True):
super().__init__()
self.is_train = is_train
self.params = params
self.tok_embeddings = tok_embeddings
# precompute_pos_cis_real precomputes the positional encodings (real-valued version).
def precompute_pos_cis_real(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
"""使用实数张量实现位置编码,避免使用复数张量
# 嵌入参数
self.knowledge_dim = params.knowledge_dim
self.key_dim = self.knowledge_dim // 2
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.knowledge_dim, bias=False),
)
This function is exactly equivalent to precompute_pos_cis, but uses real tensors instead of complex ones.
The original function produces a complex tensor of shape [seq_len, dim//2] built from unit magnitudes and the rotation angles.
This function produces a real tensor of shape [seq_len, dim] where even indices hold cos(angle) and odd indices hold sin(angle).
"""
# 确保dim是偶数
if dim % 2 != 0:
raise ValueError(f"维度必须是偶数,但得到了 {dim}")
## 数据库参数
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
# 修改键存储为二维分解空间,设置为可训练参数
self.num_keys = int(math.sqrt(self.knowledge_num))
# 确保keys是可训练参数
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.num_keys)
# 知识库存储 - 使用register_buffer因为这是整数索引不需要梯度
self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
# 复制原始函数的频率计算逻辑
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
# 计算step数目用于动态调整权重
self.step_counter = 0
# 计算cos和sin值
# 在复数版本中pos_cis = torch.polar(torch.ones_like(freqs), freqs)
# 等价于 cos(freqs) + i*sin(freqs)
cos = torch.cos(freqs)
sin = torch.sin(freqs)
# 移除批次计数器和更新频率相关代码
# 创建实数张量交错排列cos和sin
pos_emb = torch.zeros((end, dim), device=freqs.device)
pos_emb[:, 0::2] = cos # 偶数索引放cos
pos_emb[:, 1::2] = sin # 奇数索引放sin
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
if self.is_train == False:
return all_scores, all_indices
batch_size = all_scores.size(0)
device = all_scores.device
dtype = all_scores.dtype
return pos_emb
# 记录进入智能选择前的内存状态
if hasattr(self, 'step_counter'):
self.step_counter += 1
# 禁用GPU内存监控记录以提高性能
# if self.step_counter % 50 == 0: # 每50次调用记录一次
# if torch.cuda.is_available():
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
# print(f"[INTEL_SELECT_ENTER] Step {self.step_counter}: GPU Memory: {allocated_before:.2f}GB")
# apply_rotary_emb_real applies rotary positional encodings (real-valued version).
def apply_rotary_emb_real(xq, xk, pos_emb):
"""使用实数张量实现旋转位置编码,避免使用复数张量
# 对每个batch进行分层选择
enhanced_scores = all_scores.clone()
query_features = query.mean(dim=1) # [batch_size, dim]
This function is exactly equivalent to apply_rotary_emb, but uses real tensors instead of complex ones.
The original function converts the inputs to complex form, multiplies them by the positional encoding, and converts back to real form.
This function implements the same rotation directly with real-valued arithmetic.
"""
# 获取形状信息
bsz, seq_len, n_heads, head_dim = xq.shape
# 预先计算所有候选条目的嵌入(批量优化)
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
# 确保pos_emb形状正确
assert pos_emb.shape[0] >= seq_len, f"位置编码长度 {pos_emb.shape[0]} 小于序列长度 {seq_len}"
assert pos_emb.shape[1] == head_dim, f"位置编码维度 {pos_emb.shape[1]} 与头维度 {head_dim} 不匹配"
# 批量计算唯一候选条目的嵌入
candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
# 获取flat_tokens对应的index保留这些变量以便其他地方使用
pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
)
# 截取需要的位置编码长度
pos_emb = pos_emb[:seq_len]
unique_candidate_features = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
).mean(dim=1) # [num_unique_candidates, dim]
# 归一化候选特征(优化相似度计算)
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
normalized_queries = F.normalize(query_features, dim=-1)
# 将pos_emb调整为广播形状 [1, seq_len, 1, head_dim]
pos_emb = pos_emb.unsqueeze(0).unsqueeze(2)
# 收集所有batch的best_tokens
batch_best_tokens = []
batch_best_tokens_embeddings = []
# 将head_dim分成两半
half_head_dim = head_dim // 2
for batch_idx in range(batch_size):
indices = all_indices[batch_idx]
# 获取当前batch候选条目对应的特征索引
start_idx = batch_idx * len(indices)
end_idx = start_idx + len(indices)
batch_inverse_indices = inverse_indices[start_idx:end_idx]
# 使用预计算的归一化特征进行优化相似度计算
batch_candidate_features = normalized_candidates[batch_inverse_indices]
query_feature = normalized_queries[batch_idx]
# 使用矩阵乘法计算余弦相似度
similarity_scores = torch.mv(batch_candidate_features, query_feature)
# 找到最大相似度分数的索引
max_similarity_idx = torch.argmax(similarity_scores)
# 获取最大相似度对应的候选条目索引
best_candidate_idx = indices[max_similarity_idx]
# 获取对应的tokens
best_tokens = self.knowledge_dataset[best_candidate_idx]
best_tokens_embeddings = self.tok_embeddings(best_tokens)
# 将当前batch的best_tokens添加到列表中
batch_best_tokens.append(best_tokens)
batch_best_tokens_embeddings.append(best_tokens_embeddings)
# 提取cos和sin值偶数索引是cos奇数索引是sin
cos = pos_emb[..., 0::2]
sin = pos_emb[..., 1::2]
# 将所有batch的best_tokens堆叠成一个张量
# [batch_size, knowledge_length]
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
# 将xq和xk重新排列以便进行旋转操作
# 原始复数版本中xq和xk被重塑为复数张量其中实部和虚部交错排列
# 在实数版本中,我们需要将偶数索引和奇数索引分开处理
# 清理中间张量以防止内存泄漏
del all_candidate_indices, unique_indices, inverse_indices
del unique_candidate_features, normalized_candidates, normalized_queries
del batch_best_tokens, batch_best_tokens_embeddings
del flat_tokens, flat_embeddings, pre_update_embeddings
# 记录退出智能选择后的内存状态(已禁用以提高性能)
# if hasattr(self, 'step_counter') and self.step_counter % 50 == 0:
# if torch.cuda.is_available():
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
# print(f"[INTEL_SELECT_EXIT] Step {self.step_counter}: GPU Memory: {allocated_after:.2f}GB")
# 强制垃圾回收(仅在监控步骤)
if hasattr(self, 'step_counter') and self.step_counter % 100 == 0:
gc.collect()
# if torch.cuda.is_available():
# torch.cuda.empty_cache()
# 分离偶数和奇数索引
xq_even = xq[..., 0::2] # 偶数索引,对应复数的实部
xq_odd = xq[..., 1::2] # 奇数索引,对应复数的虚部
xk_even = xk[..., 0::2]
xk_odd = xk[..., 1::2]
return all_best_tokens, all_best_tokens_embeddings
def search_index(self, x):
batch_size, seq_len, dim = x.shape
# 应用旋转(等价于复数乘法)
# (a + bi)(cos + sin*i) = (a*cos - b*sin) + (a*sin + b*cos)i
# 其中a是偶数索引b是奇数索引
xq_out_even = xq_even * cos - xq_odd * sin # 新的偶数索引(实部)
xq_out_odd = xq_even * sin + xq_odd * cos # 新的奇数索引(虚部)
xk_out_even = xk_even * cos - xk_odd * sin
xk_out_odd = xk_even * sin + xk_odd * cos
# 1. 序列维度平均
x_flat = x.mean(dim=1) # [batch_size, dim]
# 重新组合偶数和奇数索引
xq_out = torch.zeros_like(xq)
xk_out = torch.zeros_like(xk)
xq_out[..., 0::2] = xq_out_even
xq_out[..., 1::2] = xq_out_odd
xk_out[..., 0::2] = xk_out_even
xk_out[..., 1::2] = xk_out_odd
# 2. 生成查询向量并重塑为两个子查询
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim]
# 调整维度顺序,使子空间维度位于首位
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
return xq_out.type_as(xq), xk_out.type_as(xk)
# 3. 计算每个子空间的相似度
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# repeat_kv repeats the key/value heads.
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
# 4. 在两个子空间分别做top-k
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 5. 组合两个子空间的结果
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
# 6. 将结果重塑为二维
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
# 7. 选择最终的top-k结果
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
indices = torch.gather(all_indices, 1, indices_of_indices)
# 8. 应用智能分层选择策略
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
return best_tokens, best_tokens_embeddings
class CrossAttention(nn.Module):
def __init__(
self,
config
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = self.config.dim // self.num_heads
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
# 监控交叉注意力开始时的内存(已禁用以提高性能)
if not hasattr(self, 'call_counter'):
self.call_counter = 0
self.call_counter += 1
# 禁用GPU内存监控记录以提高性能
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
# print(f"[CROSS_ATTN_ENTER] Call {self.call_counter}: GPU Memory: {allocated_before:.2f}GB")
# 分离多头
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
context = self.to_out(context)
# 清理中间张量
del q, k, v, attn_scores, attn_weights
# 监控交叉注意力结束时的内存(已禁用以提高性能)
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
# print(f"[CROSS_ATTN_EXIT] Call {self.call_counter}: GPU Memory: {allocated_after:.2f}GB")
return context
class Attention(nn.Module):
def __init__(self, args: LMConfig):
@ -314,14 +178,56 @@ class Attention(nn.Module):
def forward(self,
x: torch.Tensor,
pos_cis: torch.Tensor):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
pos_cis: torch.Tensor,
db_value=None):
bsz, seq_len, _ = x.shape #bsz: 批量大小, seq_len: 序列长度, _: 隐藏维度
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) #将输入张量x分别通过线性层wq, wk, wv进行变换得到查询、键和值。
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim) #将变换后的张量xq重塑为形状为(bsz, seq_len, n_local_heads, head_dim)的形状。
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xk重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xv重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
# 应用旋转位置编码(使用实数版本)
xq, xk = apply_rotary_emb_real(xq, xk, pos_cis)
# kv_cache实现 REMOVED
# if past_key_value is not None:
# xk = torch.cat([past_key_value[0], xk], dim=1)
# xv = torch.cat([past_key_value[1], xv], dim=1)
# past_kv = (xk, xv) if use_cache else None
# 重复键值对
xq, xk, xv = (
xq.transpose(1, 2),
repeat_kv(xk, self.n_rep).transpose(1, 2),
repeat_kv(xv, self.n_rep).transpose(1, 2)
)
# 如果提供了db_value根据头的数量调整它的形状并与xv合并
if db_value is not None:
# 确保db_value的形状与xv兼容假设db_value形状为[B, N, H, D]
if db_value.ndim == 4: # [B, N, H, D]
db_value = db_value.transpose(1, 2) # -> [B, H, N, D]
# 检查是否需要调整D维度
if db_value.shape[-1] != xv.shape[-1]:
# 如果db_value的维度与xv不同可以添加一个投影层
# 或者在这里使用简单的调整方法
# 这里我们简单地通过均值池化或重复来调整维度
if db_value.shape[-1] > xv.shape[-1]:
# 降维
factor = db_value.shape[-1] // xv.shape[-1]
db_value = db_value.view(bsz, self.n_local_heads, seq_len, factor, xv.shape[-1])
db_value = db_value.mean(dim=3)
else:
# 升维
factor = xv.shape[-1] // db_value.shape[-1]
db_value = db_value.unsqueeze(-1).repeat(1, 1, 1, 1, factor)
db_value = db_value.view(bsz, self.n_local_heads, seq_len, xv.shape[-1])
# 将db_value与xv相加或融合
# 这里我们简单地将它们相加,但你也可以使用其他融合方法
xv = xv + db_value
# 使用Flash Attention
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
@ -342,6 +248,53 @@ class Attention(nn.Module):
return output
class CrossAttention(nn.Module):
def __init__(
self,
config
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = self.config.dim // self.num_heads
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
# 分离多头
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
context = self.to_out(context)
return context
class FeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
@ -474,30 +427,169 @@ class MOEFeedForward(nn.Module):
class MiniMindBlock(nn.Module):
def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
def __init__(self, layer_id: int, config: LMConfig):
super().__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
self.self_attention = Attention(config)
self.cross_attention = CrossAttention(config)
self.knowledge_dataset = knowledge_dataset
self.attention = Attention(config)
self.cross_att = CrossAttention(config)
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
def forward(self, x, pos_cis):
h_attn = self.self_attention(
# 假设num_experts是已定义的总专家数量的平方根
# 查询生成的参数
# 创建查询生成模块
# if weight_down_embed is not None:
# self.to_queries = nn.Sequential(
# nn.Linear(config.dim, self.dim_key * 2, bias=False),
# # nn.Unflatten(2, (2, self.n_heads, self.dim_key)) # 替代Rearrange
# )
# # 超参数
# self.product_key_topk = min(16, self.num_keys) # 确保不超过num_keys
# self.num_experts_per_head_topk = 1 # 最终每个头选取的专家数
def forward(self, x, db_value, pos_cis):
# import pdb;pdb.set_trace()
# db_value = None
# # 如果有weight_down_embed使用Product Key机制
# if self.weight_down_embed is not None:
# # 1. 生成queries
# batch_size, seq_len, dim = x.shape
# # collapse sequence dimension by averaging
# x_flat = x.mean(dim=1) # [batch_size, dim]
# queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
# queries = queries.reshape(batch_size, 2, self.dim_key) # [batch_size, 2, dim_key]
# queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key]
# # 2. 计算queries与keys的相似度
# sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# # 3. 在两个子空间分别做top-k
# scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
# scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
# indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# # 4. 组合两个子空间的分数和索引
# all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
# all_scores = all_scores.view(*all_scores.shape[:-2], -1)
# all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
# all_indices = all_indices.view(*all_indices.shape[:-2], -1)
# # 5. 最终top-k选择
# scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
# indices = all_indices.gather(-1, pk_indices)
# # 6. 从embedding中获取专家值
# # 从embedding中获取值
# flat_indices = indices.view(-1) # 将索引展平为一维张量
# db_values = self.weight_down_embed(flat_indices)
# # 重塑回原始形状
# db_value = db_values.view(batch_size, -1, dim)
# 注意力计算
h_attn = self.attention(
self.attention_norm(x),
pos_cis
pos_cis,
db_value=db_value
)
db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
h_attn = self.cross_attention(h_attn, db_embeddings)
h_attn = self.cross_att(h_attn, db_value)
# 残差连接
h = x + h_attn
# 前馈神经网络
out = h + self.feed_forward(self.ffn_norm(h))
return out
return out
class ExtractDB(nn.Module):
def __init__(self,params):
# 修改专家数量和知识维度,确保能开方
super().__init__()
self.batch_size = None
self.dim = params.dim
self.dim_key = self.dim // 2
self.knowledge_num = params.knowledge_num # 100专家确保是完全平方数
# 将knowledge_dim设置为与head_dim相同以便在attention中直接使用
self.head_dim = params.dim // params.n_heads
self.knowledge_length = params.knowledge_length
# 使用register_buffer代替nn.Parameter避免梯度问题
# self.register_buffer('weight_down_embed', torch.randn(self.knowledge_num, self.knowledge_length) * 0.02)
self.register_buffer('weight_down_embed',torch.randint(low=0,high=6400, size=(self.knowledge_num, self.knowledge_length),dtype=torch.long))
self.num_keys = int(math.sqrt(self.knowledge_num)) if self.knowledge_num > 0 else 0
self.product_key_topk = min(16, self.num_keys)
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02)
self.num_experts_per_head_topk = 1
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.dim_key * 2, bias=False),
)
def q_to_k(self,x):
# 1. 生成queries
self.batch_size, seq_len, dim = x.shape
# collapse sequence dimension by averaging
x_flat = x.mean(dim=1) # [batch_size, dim]
queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
queries = queries.reshape(self.batch_size, 2, self.dim_key) # [batch_size, 2, dim_key]
queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key]
# 2. 计算queries与keys的相似度
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 3. 在两个子空间分别做top-k
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 4. 组合两个子空间的分数和索引
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
all_scores = all_scores.view(*all_scores.shape[:-2], -1)
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
all_indices = all_indices.view(*all_indices.shape[:-2], -1)
# 5. 最终top-k选择
scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
indices = all_indices.gather(-1, pk_indices)
flat_indices = indices.view(-1)
return flat_indices
def get_data(self, index):
# 直接从GPU获取embedding
db_values = self.weight_down_embed[index]#变成token了所以是1,后续再过emb
# db_value = db_values.view(self.batch_size,-1)
return db_values
@torch.no_grad()
def updata_value(self, k, v):#要加一个从向量返回index的过程
# 直接更新buffer上的值 (不需要梯度)
v_reshaped = v.view(v.size(0), -1)
# 确保数据类型匹配
v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
self.weight_down_embed[k] = v_reshaped
class MiniMindLM(PreTrainedModel):
@ -509,39 +601,113 @@ class MiniMindLM(PreTrainedModel):
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
# 移除旧的weight_down_embed声明
self.extract_db = ExtractDB(self.params)
# 将self.weight_down_embed传递给每个MiniMindBlock
self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.database_output = nn.Linear(params.dim, params.knowledge_length, bias=False)
self.tok_embeddings.weight = self.output.weight
self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
self.database_output.weight = self.output.weight
# Calculate input dimension
input_dim = (self.params.max_seq_len-1)*self.params.n_layers
# Use a bottleneck architecture to reduce parameters
bottleneck_dim = 256 # Significantly smaller bottleneck dimension
# Factorized shared downsampling using two smaller convolutions
self.shared_downsample = nn.Sequential(
# First reduce input dimension to bottleneck
nn.Conv1d(input_dim, bottleneck_dim, kernel_size=1, padding='same'),
nn.ReLU(), # Non-linearity to improve representation capacity
# Then expand to target dimension
nn.Conv1d(bottleneck_dim, 128*8, kernel_size=1, padding='same')
)
# Specific layers for v path
self.downsample_v_specific = nn.Sequential(
nn.Conv1d(128*8, 128, kernel_size=1, padding='same'),
nn.Conv1d(128, self.params.knowledge_length, kernel_size=1, padding='same')
)
# Specific layers for q path
self.downsample_q_specific = nn.Sequential(
nn.Conv1d(128*8, 512, kernel_size=1, padding='same')
)
# 使用实数版本的位置编码,避免复数张量可能导致的段错误
self.register_buffer("pos_cis_real",
precompute_pos_cis_real(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.OUT = CausalLMOutputWithPast()
self.freeze_embedding = False
self.params = params
def forward(self,
input_ids: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
step: int = 0,
**args):
start_pos = args.get('start_pos', 0)
# if self.freeze_embedding and step == 0:
# self.tok_embeddings.weight.requires_grad = False
# # 移除对knowledge_dataset.freeze_embedding的设置让键更新由batch_counter控制
# # self.knowledge_dataset.freeze_embedding = True
# print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
pos_cis_real = self.pos_cis_real[start_pos:start_pos + input_ids.size(1)]
h_list = []
for l, layer in enumerate(self.layers):
# 禁用数据库模式,使用固定值替代数据库查询
if self.params.disable_db:
# 创建一个形状为[batch_size, n_layers, dim]的tensor所有元素值为1e-4
batch_size = h.size(0)
db_value = torch.full((batch_size, self.n_layers, self.params.dim), 1e-4,
dtype=h.dtype, device=h.device)
else:
# 正常模式,使用数据库查询
# import pdb;pdb.set_trace()
index = self.extract_db.q_to_k(h)
token_idx = self.extract_db.get_data(index) #这里是index
db_value =self.tok_embeddings(token_idx)
h = layer(
h, pos_cis
h, db_value, pos_cis_real
)
h_list.append(h.unsqueeze(0))
h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3)
# 只在非禁用数据库模式下执行数据库更新逻辑
if not self.params.disable_db:
# 使用detach()分离计算图,避免多次反向传播
h_tensor_detached = h_tensor.detach()
h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0], -1, self.params.dim)
# 数据库更新逻辑与主计算图分离
with torch.no_grad():
# Compute shared downsampling layer once
shared_features = self.shared_downsample(h_tensor_detached)
# Get features from v path - now we output embedding-dimension vectors
z_v_features = self.downsample_v_specific(shared_features)
batch_z, seq_len, dim_z = z_v_features.shape
# Reshape to batch_size * knowledge_length, dim
z_v_flat = z_v_features.reshape(-1, dim_z)
# Direct token prediction - like the main language model head
token_logits = self.database_output(z_v_flat) # [batch_z * seq_len, vocab_size]
# Get token indices directly from logits
token_indices_flat = torch.argmax(token_logits, dim=-1)
token_indices = token_indices_flat.reshape(batch_z, -1)
# Process query path as before
z_q = self.downsample_q_specific(shared_features)
z_k = self.extract_db.q_to_k(z_q)
self.extract_db.updata_value(z_k, token_indices)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
# 统一不使用 aux_loss
aux_loss = 0
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
# 进一步简化,只保留必要的参数
output = CausalLMOutputWithPast(
@ -551,6 +717,12 @@ class MiniMindLM(PreTrainedModel):
output.aux_loss = aux_loss
# 尝试添加其他属性(如果支持的话)
# try:
# output.hidden_states = h
# except:
# pass
return output
@torch.inference_mode()
@ -583,19 +755,15 @@ class MiniMindLM(PreTrainedModel):
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
start = input_ids.shape[1]
for _ in range(max_new_tokens):
# 每次都传入完整的input_ids不使用KV缓存
out = self(input_ids, **args)
logits = out.logits[:, -1, :] # 取最后一个位置的logits
# 重复惩罚
start, first_seq = input_ids.shape[1], True
while input_ids.shape[1] < max_new_tokens - 1:
if first_seq:
out, first_seq = self(input_ids, **args), False
else:
out = self(input_ids[:, -1:], start_pos=input_ids.shape[1] - 1, **args)
logits = out.logits[:, -1, :]
logits[:, list(set(input_ids.tolist()[0]))] /= rp
# 温度采样
logits /= (temperature + 1e-9)
# Top-p采样
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
@ -605,14 +773,8 @@ class MiniMindLM(PreTrainedModel):
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
# Sample the next token
input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
# Yield only the newly generated portion
yield input_ids[:, start:]
# Stop once the end-of-sequence token is generated
if input_ids_next.item() == eos_token_id:
break

View File

@ -1,732 +0,0 @@
import math
import struct
import inspect
import time
import gc
# Two-dimensional subspace decomposition of the keys + gradient-based key updates
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis
def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x):
ndim = x.ndim
assert 0 <= 1 < ndim
assert pos_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return pos_cis.view(*shape)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
pos_cis = unite_shape(pos_cis, xq_)
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
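# --- Illustrative sketch (not part of the original file): how the RoPE helpers above
# are typically wired together. The shapes below are assumptions chosen for the demo;
# apply_rotary_emb expects per-head tensors [bsz, seq_len, n_heads, head_dim] and a
# pos_cis slice of shape [seq_len, head_dim // 2].
def _rope_demo():
    bsz, seq_len, n_heads, head_dim = 2, 16, 8, 64
    pos_cis_demo = precompute_pos_cis(dim=head_dim, end=seq_len)  # [16, 32], complex64
    xq = torch.randn(bsz, seq_len, n_heads, head_dim)
    xk = torch.randn(bsz, seq_len, n_heads, head_dim)
    xq_rot, xk_rot = apply_rotary_emb(xq, xk, pos_cis_demo)
    assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape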
class KnowledgeDataset(nn.Module):
def __init__(self, params, tok_embeddings, is_train=True):
super().__init__()
self.is_train = is_train
self.params = params
self.tok_embeddings = tok_embeddings
# 嵌入参数
self.knowledge_dim = params.knowledge_dim
self.key_dim = self.knowledge_dim // 2
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.knowledge_dim, bias=False),
)
## 数据库参数
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
# 修改键存储为二维分解空间,设置为可训练参数
self.num_keys = int(math.sqrt(self.knowledge_num))
# 确保keys是可训练参数
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.num_keys)
# 知识库存储 - 使用register_buffer因为这是整数索引不需要梯度
self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
# 计算step数目用于动态调整权重
self.step_counter = 0
# 移除批次计数器和更新频率相关代码
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
if self.is_train == False:
return all_scores, all_indices
batch_size = all_scores.size(0)
device = all_scores.device
dtype = all_scores.dtype
# 记录进入智能选择前的内存状态
if hasattr(self, 'step_counter'):
self.step_counter += 1
# 禁用GPU内存监控记录以提高性能
# if self.step_counter % 50 == 0: # 每50次调用记录一次
# if torch.cuda.is_available():
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
# print(f"[INTEL_SELECT_ENTER] Step {self.step_counter}: GPU Memory: {allocated_before:.2f}GB")
# 对每个batch进行分层选择
enhanced_scores = all_scores.clone()
query_features = query.mean(dim=1) # [batch_size, dim]
# 预先计算所有候选条目的嵌入(批量优化)
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
# 批量计算唯一候选条目的嵌入
candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
# 获取flat_tokens对应的index保留这些变量以便其他地方使用
pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
)
unique_candidate_features = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
).mean(dim=1) # [num_unique_candidates, dim]
# 归一化候选特征(优化相似度计算)
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
normalized_queries = F.normalize(query_features, dim=-1)
# 收集所有batch的best_tokens
batch_best_tokens = []
batch_best_tokens_embeddings = []
for batch_idx in range(batch_size):
indices = all_indices[batch_idx]
# 获取当前batch候选条目对应的特征索引
start_idx = batch_idx * len(indices)
end_idx = start_idx + len(indices)
batch_inverse_indices = inverse_indices[start_idx:end_idx]
# 使用预计算的归一化特征进行优化相似度计算
batch_candidate_features = normalized_candidates[batch_inverse_indices]
query_feature = normalized_queries[batch_idx]
# 使用矩阵乘法计算余弦相似度
similarity_scores = torch.mv(batch_candidate_features, query_feature)
# 找到最大相似度分数的索引
max_similarity_idx = torch.argmax(similarity_scores)
# 获取最大相似度对应的候选条目索引
best_candidate_idx = indices[max_similarity_idx]
# 获取对应的tokens
best_tokens = self.knowledge_dataset[best_candidate_idx]
best_tokens_embeddings = self.tok_embeddings(best_tokens)
# 将当前batch的best_tokens添加到列表中
batch_best_tokens.append(best_tokens)
batch_best_tokens_embeddings.append(best_tokens_embeddings)
# 将所有batch的best_tokens堆叠成一个张量
# [batch_size, knowledge_length]
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
# 清理中间张量以防止内存泄漏
del all_candidate_indices, unique_indices, inverse_indices
del unique_candidate_features, normalized_candidates, normalized_queries
del batch_best_tokens, batch_best_tokens_embeddings
del flat_tokens, flat_embeddings, pre_update_embeddings
# 记录退出智能选择后的内存状态(已禁用以提高性能)
# if hasattr(self, 'step_counter') and self.step_counter % 50 == 0:
# if torch.cuda.is_available():
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
# print(f"[INTEL_SELECT_EXIT] Step {self.step_counter}: GPU Memory: {allocated_after:.2f}GB")
# 强制垃圾回收(仅在监控步骤)
if hasattr(self, 'step_counter') and self.step_counter % 100 == 0:
gc.collect()
# if torch.cuda.is_available():
# torch.cuda.empty_cache()
return all_best_tokens, all_best_tokens_embeddings
def search_index(self, x):
batch_size, seq_len, dim = x.shape
# 1. 序列维度平均
x_flat = x.mean(dim=1) # [batch_size, dim]
# 2. 生成查询向量并重塑为两个子查询
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim]
# 调整维度顺序,使子空间维度位于首位
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
# 3. 计算每个子空间的相似度
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 4. 在两个子空间分别做top-k
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 5. 组合两个子空间的结果
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
# 6. 将结果重塑为二维
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
# 7. 选择最终的top-k结果
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
indices = torch.gather(all_indices, 1, indices_of_indices)
# 8. 应用智能分层选择策略
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
return best_tokens, best_tokens_embeddings
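# --- Illustrative sketch (not part of the original file): how the product-key pair
# (ix, iy) produced in search_index maps to a single knowledge_dataset row and back.
# num_keys = 1000 is an assumed example value; in the class above it is
# int(sqrt(knowledge_num)), and the candidate set is the Cartesian product of the
# two per-subspace top-k lists.
def _product_key_row_demo(ix: int = 37, iy: int = 512, num_keys: int = 1000):
    row = ix * num_keys + iy                   # composition used for all_indices above
    ix_back, iy_back = divmod(row, num_keys)   # inverse mapping back to the two sub-keys
    assert (ix_back, iy_back) == (ix, iy)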
class CrossAttention(nn.Module):
def __init__(
self,
config
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = self.config.dim // self.num_heads
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
# 监控交叉注意力开始时的内存(已禁用以提高性能)
if not hasattr(self, 'call_counter'):
self.call_counter = 0
self.call_counter += 1
# 禁用GPU内存监控记录以提高性能
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
# print(f"[CROSS_ATTN_ENTER] Call {self.call_counter}: GPU Memory: {allocated_before:.2f}GB")
# 分离多头
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
context = self.to_out(context)
# 清理中间张量
del q, k, v, attn_scores, attn_weights
# 监控交叉注意力结束时的内存(已禁用以提高性能)
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
# print(f"[CROSS_ATTN_EXIT] Call {self.call_counter}: GPU Memory: {allocated_after:.2f}GB")
return context
class Attention(nn.Module):
def __init__(self, args: LMConfig):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
def forward(self,
x: torch.Tensor,
pos_cis: torch.Tensor):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
xq, xk, xv,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True
)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output
class FeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
if config.hidden_dim is None:
hidden_dim = 4 * config.dim
hidden_dim = int(2 * hidden_dim / 3)
config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class MoEGate(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.scoring_func = config.scoring_func
self.alpha = config.aux_loss_alpha
self.seq_aux = config.seq_aux
self.norm_topk_prob = config.norm_topk_prob
self.gating_dim = config.dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
self.reset_parameters()
def reset_parameters(self) -> None:
import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(1, topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = 0
return topk_idx, topk_weight, aux_loss
class MOEFeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.experts = nn.ModuleList([
FeedForward(config)
for _ in range(config.n_routed_experts)
])
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
self.shared_experts = FeedForward(config)
def forward(self, x):
identity = x
orig_shape = x.shape
bsz, seq_len, _ = x.shape
# 使用门控机制选择专家
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
y = torch.empty_like(x, dtype=torch.float16)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
self.aux_loss = aux_loss
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.config.num_experts_per_tok
# Example: if tokens_per_expert = [6, 15, 20, 26], then tokens_per_expert.shape[0] is the number of experts (4 here).
# With token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12, ...],
# token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the token positions handled by expert 0
# (a token may be routed to several experts, depending on num_experts_per_tok);
# the next nine positions, token_idxs[6:15] -> [4, 5, 6, 10, 11, 12, ...], belong to expert 1, and so on.
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens).to(expert_cache.dtype)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
return expert_cache
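# --- Illustrative sketch (not part of the original file): the routing bookkeeping used
# by moe_infer above, on a tiny assumed example (4 tokens, 2 experts,
# num_experts_per_tok = 2). Sorting the flat expert indices groups token slots by expert;
# integer division by num_experts_per_tok maps each slot back to its token position.
def _moe_routing_demo():
    num_experts_per_tok = 2
    flat_expert_indices = torch.tensor([0, 1, 1, 0, 0, 1, 1, 0])  # 4 tokens * 2 experts
    idxs = flat_expert_indices.argsort()                          # slots grouped by expert
    tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)  # e.g. [4, 8]
    token_idxs = idxs // num_experts_per_tok                      # token position per slot
    # token_idxs[:4] are the tokens routed to expert 0, token_idxs[4:8] those routed to expert 1
    assert tokens_per_expert[-1] == flat_expert_indices.numel()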
class TripleExtractionHead(nn.Module):
"""三元组提取任务头"""
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
# 三元组长度超参数
self.max_subject_len = config.max_subject_len
self.max_predicate_len = config.max_predicate_len
self.max_object_len = config.max_object_len
# 自注意力机制
self.self_attention = Attention(config)
self.self_attn_norm = RMSNorm(config.dim, eps=config.norm_eps)
# 交叉注意力机制(用于主语和宾语提取)
# self.cross_attention_subject = CrossAttention(config)
# self.cross_attention_object = CrossAttention(config)
# 归一化层
self.subject_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.object_norm = RMSNorm(config.dim, eps=config.norm_eps)
# Feed Forward 网络
self.predicate_ff = FeedForward(config)
# self.subject_ff = FeedForward(config)
# self.object_ff = FeedForward(config)
# 输出投影层 - 修改为支持序列预测
self.predicate_output = nn.Linear(config.dim, 264, bias=False)
# self.subject_output = nn.Linear(config.dim, self.max_subject_len * config.dim, bias=False)
# self.object_output = nn.Linear(config.dim, self.max_object_len * config.dim, bias=False)
print(f"三元组提取任务头配置:")
print(f"- 主语最大长度: {self.max_subject_len}")
print(f"- 谓语最大长度: {self.max_predicate_len}")
print(f"- 宾语最大长度: {self.max_object_len}")
def forward(self, h, pos_cis):
"""
Args:
h: [batch_size, seq_len, dim] - 来自transformer层的隐藏状态
pos_cis: 位置编码
Returns:
predicate_logits: [batch_size, seq_len, max_predicate_len, vocab_size] - 谓语序列预测
subject_logits: [batch_size, seq_len, max_subject_len, vocab_size] - 主语序列预测
object_logits: [batch_size, seq_len, max_object_len, vocab_size] - 宾语序列预测
"""
batch_size, seq_len, dim = h.shape
# 1. Pass h through self-attention to obtain h1
h1 = self.self_attention(self.self_attn_norm(h), pos_cis)
h1 = h + h1  # residual connection
# 2. Pass h1 through the feed-forward network to get the predicate output
predicate_features = self.predicate_ff(h1)
predicate_features = predicate_features.mean(dim=1)
predicate_class = self.predicate_output(predicate_features)  # [batch_size, 264]
# # 3. h1通过交叉注意力k,v都是h得到h2
# h2 = self.cross_attention_subject(h1, h) # query是h1key和value都是h
# h2 = h1 + h2 # 残差连接
# # 4. h2通过feed_forward得到主语输出
# subject_features = self.subject_ff(self.subject_norm(h2))
# subject_features = subject_features.mean(dim=1)
# subject_raw = self.subject_output(subject_features) # [batch_size, max_subject_len * vocab_size]
# subject_logits = subject_raw.view(batch_size, self.max_subject_len, -1)
# # 5. h2通过交叉注意力k,v都是h得到h3
# h3 = self.cross_attention_object(h2, h) # query是h2key和value都是h
# h3 = h2 + h3 # 残差连接
# # 6. h3通过feed_forward得到宾语输出
# object_features = self.object_ff(self.object_norm(h3))
# object_features = object_features.mean(dim=1)
# object_raw = self.object_output(object_features) # [batch_size, max_object_len * vocab_size]
# object_logits = object_raw.view(batch_size, self.max_object_len, -1)
return predicate_class
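# --- Illustrative note (not part of the original file): how the head above would
# typically be trained. predicate_class comes out as [batch_size, 264] (one relation
# class per sequence), so a plain cross-entropy against an assumed relation-id label
# vector works, e.g.:
#
#   predicate_class = head(h, pos_cis)                        # [B, 264]
#   loss = F.cross_entropy(predicate_class, relation_ids)     # relation_ids: LongTensor [B]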
class MiniMindBlock(nn.Module):
def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
super().__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
self.self_attention = Attention(config)
self.cross_attention = CrossAttention(config)
self.knowledge_dataset = knowledge_dataset
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
def forward(self, x, pos_cis):
h_attn = self.self_attention(
self.attention_norm(x),
pos_cis
)
db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
h_attn = self.cross_attention(h_attn, db_embeddings)
h = x + h_attn
out = h + self.feed_forward(self.ffn_norm(h))
return out
class MiniMindLM(PreTrainedModel):
config_class = LMConfig
def __init__(self, params: LMConfig = None, mode="triple"):
self.params = params or LMConfig()
super().__init__(self.params)
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.tok_embeddings.weight = self.output.weight
# 添加三元组提取任务头(可训练)
self.triple_extraction_head = TripleExtractionHead(params)
self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.OUT = CausalLMOutputWithPast()
self.freeze_embedding = False
self.mode = mode
# 冻结所有指定组件的权重
self._freeze_components()
def _freeze_components(self):
"""冻结指定组件的权重"""
# 冻结词嵌入层
for param in self.tok_embeddings.parameters():
param.requires_grad = False
# 冻结知识数据库
for param in self.knowledge_dataset.parameters():
param.requires_grad = False
# 冻结所有transformer层
for param in self.layers.parameters():
param.requires_grad = False
# 冻结输出层
for param in self.output.parameters():
param.requires_grad = False
# pos_cis是buffer本身就不需要梯度但为了明确起见
# (实际上buffer默认就是requires_grad=False)
if hasattr(self, 'pos_cis'):
self.pos_cis.requires_grad = False
print("已冻结以下组件的权重:")
print("- tok_embeddings")
print("- knowledge_dataset")
print("- layers (所有transformer层)")
print("- output")
print("- pos_cis")
print("注意triple_extraction_head 保持可训练状态")
def forward(self,
input_ids: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
step: int = 0,
**args):
start_pos = args.get('start_pos', 0)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
for l, layer in enumerate(self.layers):
h = layer(
h, pos_cis
)
# 应用三元组提取任务头
predicate_class = self.triple_extraction_head(h, pos_cis)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
# 进一步简化,只保留必要的参数
output = CausalLMOutputWithPast(
logits=logits,
)
output.hidden_states = h
output.aux_loss = aux_loss
# Attach the triple-extraction result
# Note: predicate_class has shape [batch_size, 264] (one relation class per sequence)
output.predicate_class = predicate_class
return output
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
# 流式生成
if stream:
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
# 直接生成
generated = []
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
for _ in range(num_return_sequences):
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1)
generated.append(full_sequence)
max_length = max(seq.size(1) for seq in generated)
generated = [
torch.cat(
[seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
dim=-1)
for seq in generated
]
output = torch.cat(generated, dim=0)
res = output.view(input_ids.size(0) * num_return_sequences, -1)
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
start, first_seq, past_kvs = input_ids.shape[1], True, None
while input_ids.shape[1] < max_new_tokens - 1:
if first_seq:
out, first_seq = self(input_ids, **args), False
else:
out = self(input_ids[:, -1:],
start_pos=input_ids.shape[1] - 1, **args)
logits, past_kvs = out.logits[:, -1, :], out.past_key_values
logits[:, list(set(input_ids.tolist()[0]))] /= rp
logits /= (temperature + 1e-9)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id:
break
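# --- Illustrative usage sketch (not part of the original file). The prompt ids are an
# assumption; the call pattern matches generate() above: the non-stream path returns
# padded full sequences, while stream=True returns a generator that yields only the
# newly generated tokens so far.
#
#   model = MiniMindLM(LMConfig())
#   prompt_ids = torch.tensor([[1, 42, 7]])
#   out_ids = model.generate(prompt_ids, max_new_tokens=32, temperature=0.7, top_p=0.9)
#   for partial in model.generate(prompt_ids, stream=True, max_new_tokens=32):
#       ...  # partial holds the tokens generated after the prompt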

View File

@ -1,488 +0,0 @@
import math
import struct
import inspect
import time
import gc
# Two-dimensional subspace decomposition of the keys + gradient-based key updates
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis
def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x):
ndim = x.ndim
assert 0 <= 1 < ndim
assert pos_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return pos_cis.view(*shape)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
pos_cis = unite_shape(pos_cis, xq_)
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class KnowledgeDataset(nn.Module):
def __init__(self, params, tok_embeddings, is_train=True):
super().__init__()
self.is_train = is_train
self.params = params
self.tok_embeddings = tok_embeddings
# 嵌入参数
self.knowledge_dim = params.knowledge_dim
self.key_dim = self.knowledge_dim // 2
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.knowledge_dim, bias=False),
)
## 数据库参数
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
# 修改键存储为二维分解空间,设置为可训练参数
self.num_keys = int(math.sqrt(self.knowledge_num))
# 确保keys是可训练参数
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.num_keys)
# 知识库存储 - 使用register_buffer因为这是整数索引不需要梯度
self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
# 计算step数目用于动态调整权重
self.step_counter = 0
# 移除批次计数器和更新频率相关代码
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
if self.is_train == False:
return all_scores, all_indices
batch_size = all_scores.size(0)
device = all_scores.device
dtype = all_scores.dtype
# 记录进入智能选择前的内存状态
if hasattr(self, 'step_counter'):
self.step_counter += 1
# 禁用GPU内存监控记录以提高性能
# if self.step_counter % 50 == 0: # 每50次调用记录一次
# if torch.cuda.is_available():
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
# print(f"[INTEL_SELECT_ENTER] Step {self.step_counter}: GPU Memory: {allocated_before:.2f}GB")
# 对每个batch进行分层选择
enhanced_scores = all_scores.clone()
query_features = query.mean(dim=1) # [batch_size, dim]
# 预先计算所有候选条目的嵌入(批量优化)
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
# 批量计算唯一候选条目的嵌入
candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
# 获取flat_tokens对应的index保留这些变量以便其他地方使用
pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
)
unique_candidate_features = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
).mean(dim=1) # [num_unique_candidates, dim]
# 归一化候选特征(优化相似度计算)
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
normalized_queries = F.normalize(query_features, dim=-1)
# 收集所有batch的best_tokens
batch_best_tokens = []
batch_best_tokens_embeddings = []
for batch_idx in range(batch_size):
indices = all_indices[batch_idx]
# 获取当前batch候选条目对应的特征索引
start_idx = batch_idx * len(indices)
end_idx = start_idx + len(indices)
batch_inverse_indices = inverse_indices[start_idx:end_idx]
# 使用预计算的归一化特征进行优化相似度计算
batch_candidate_features = normalized_candidates[batch_inverse_indices]
query_feature = normalized_queries[batch_idx]
# 使用矩阵乘法计算余弦相似度
similarity_scores = torch.mv(batch_candidate_features, query_feature)
# 找到最大相似度分数的索引
max_similarity_idx = torch.argmax(similarity_scores)
# 获取最大相似度对应的候选条目索引
best_candidate_idx = indices[max_similarity_idx]
# 获取对应的tokens
best_tokens = self.knowledge_dataset[best_candidate_idx]
best_tokens_embeddings = self.tok_embeddings(best_tokens)
# 将当前batch的best_tokens添加到列表中
batch_best_tokens.append(best_tokens)
batch_best_tokens_embeddings.append(best_tokens_embeddings)
# 将所有batch的best_tokens堆叠成一个张量
# [batch_size, knowledge_length]
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
# 清理中间张量以防止内存泄漏
del all_candidate_indices, unique_indices, inverse_indices
del unique_candidate_features, normalized_candidates, normalized_queries
del batch_best_tokens, batch_best_tokens_embeddings
del flat_tokens, flat_embeddings, pre_update_embeddings
# 记录退出智能选择后的内存状态(已禁用以提高性能)
# if hasattr(self, 'step_counter') and self.step_counter % 50 == 0:
# if torch.cuda.is_available():
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
# print(f"[INTEL_SELECT_EXIT] Step {self.step_counter}: GPU Memory: {allocated_after:.2f}GB")
# 强制垃圾回收(仅在监控步骤)
if hasattr(self, 'step_counter') and self.step_counter % 100 == 0:
gc.collect()
# if torch.cuda.is_available():
# torch.cuda.empty_cache()
return all_best_tokens, all_best_tokens_embeddings
def search_index(self, x):
batch_size, seq_len, dim = x.shape
# 1. 序列维度平均
x_flat = x.mean(dim=1) # [batch_size, dim]
# 2. 生成查询向量并重塑为两个子查询
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim]
# 调整维度顺序,使子空间维度位于首位
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
# 3. 计算每个子空间的相似度
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 4. 在两个子空间分别做top-k
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 5. 组合两个子空间的结果
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
# 6. 将结果重塑为二维
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
# 7. 选择最终的top-k结果
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
indices = torch.gather(all_indices, 1, indices_of_indices)
# 8. 应用智能分层选择策略
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
return best_tokens, best_tokens_embeddings
class CrossAttention(nn.Module):
def __init__(
self,
config
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = self.config.dim // self.num_heads
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
# 监控交叉注意力开始时的内存(已禁用以提高性能)
if not hasattr(self, 'call_counter'):
self.call_counter = 0
self.call_counter += 1
# 禁用GPU内存监控记录以提高性能
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
# print(f"[CROSS_ATTN_ENTER] Call {self.call_counter}: GPU Memory: {allocated_before:.2f}GB")
# 分离多头
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
context = self.to_out(context)
# 清理中间张量
del q, k, v, attn_scores, attn_weights
# 监控交叉注意力结束时的内存(已禁用以提高性能)
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
# print(f"[CROSS_ATTN_EXIT] Call {self.call_counter}: GPU Memory: {allocated_after:.2f}GB")
return context
class Attention(nn.Module):
def __init__(self, args: LMConfig):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
def forward(self,
x: torch.Tensor,
pos_cis: torch.Tensor):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
xq, xk, xv,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True
)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output
class MiniMindBlock(nn.Module):
def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
super().__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
self.self_attention = Attention(config)
self.cross_attention = CrossAttention(config)
self.knowledge_dataset = knowledge_dataset
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
# 移除 ffn_norm 和 feed_forward因为不再使用 FeedForward 层
def forward(self, x, pos_cis):
h_attn = self.self_attention(
self.attention_norm(x),
pos_cis
)
db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
h_attn = self.cross_attention(h_attn, db_embeddings)
h = x + h_attn
# 移除 FeedForward 层,直接返回注意力输出
return h
class MiniMindLM(PreTrainedModel):
config_class = LMConfig
def __init__(self, params: LMConfig = None):
self.params = params or LMConfig()
super().__init__(self.params)
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.tok_embeddings.weight = self.output.weight
self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.OUT = CausalLMOutputWithPast()
self.freeze_embedding = False
def forward(self,
input_ids: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
step: int = 0,
**args):
start_pos = args.get('start_pos', 0)
# if self.freeze_embedding and step == 0:
# self.tok_embeddings.weight.requires_grad = False
# # 移除对knowledge_dataset.freeze_embedding的设置让键更新由batch_counter控制
# # self.knowledge_dataset.freeze_embedding = True
# print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
for l, layer in enumerate(self.layers):
h = layer(
h, pos_cis
)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
# 移除 aux_loss 计算,因为不再使用 FeedForward 层
aux_loss = 0
# 进一步简化,只保留必要的参数
output = CausalLMOutputWithPast(
logits=logits,
)
output.hidden_states = h
output.aux_loss = aux_loss
return output
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
# 流式生成
if stream:
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
# 直接生成
generated = []
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
for _ in range(num_return_sequences):
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1)
generated.append(full_sequence)
max_length = max(seq.size(1) for seq in generated)
generated = [
torch.cat(
[seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
dim=-1)
for seq in generated
]
output = torch.cat(generated, dim=0)
res = output.view(input_ids.size(0) * num_return_sequences, -1)
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
start = input_ids.shape[1]
for _ in range(max_new_tokens):
# 每次都传入完整的input_ids不使用KV缓存
out = self(input_ids, **args)
logits = out.logits[:, -1, :] # 取最后一个位置的logits
# 重复惩罚
logits[:, list(set(input_ids.tolist()[0]))] /= rp
# 温度采样
logits /= (temperature + 1e-9)
# Top-p采样
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
# 采样下一个token
input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
# 返回新生成的部分
yield input_ids[:, start:]
# 如果遇到结束token停止生成
if input_ids_next.item() == eos_token_id:
break

View File

@ -1,386 +0,0 @@
import math
import struct
import inspect
import time
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis
def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x):
ndim = x.ndim
assert 0 <= 1 < ndim
assert pos_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return pos_cis.view(*shape)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
pos_cis = unite_shape(pos_cis, xq_)
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
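# --- Illustrative sketch (not part of the original file): what repeat_kv does for
# grouped-query attention, on assumed small shapes. Each KV head is duplicated n_rep
# times so the key/value tensors line up with the query heads.
def _repeat_kv_demo():
    bs, slen, n_kv_heads, head_dim, n_rep = 2, 4, 2, 8, 3
    x = torch.randn(bs, slen, n_kv_heads, head_dim)
    y = repeat_kv(x, n_rep)
    assert y.shape == (bs, slen, n_kv_heads * n_rep, head_dim)
    # output head j is a copy of input head j // n_rep
    assert torch.equal(y[:, :, 0], x[:, :, 0]) and torch.equal(y[:, :, n_rep - 1], x[:, :, 0])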
class Attention(nn.Module):
def __init__(self, args: LMConfig):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
def forward(self,
x: torch.Tensor,
pos_cis: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache=False):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
# kv_cache实现
if past_key_value is not None:
xk = torch.cat([past_key_value[0], xk], dim=1)
xv = torch.cat([past_key_value[1], xv], dim=1)
past_kv = (xk, xv) if use_cache else None
xq, xk, xv = (
xq.transpose(1, 2),
repeat_kv(xk, self.n_rep).transpose(1, 2),
repeat_kv(xv, self.n_rep).transpose(1, 2)
)
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
xq, xk, xv,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True
)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output, past_kv
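# --- Illustrative sketch (not part of the original file): incremental decoding with the
# KV cache returned by Attention.forward above. x_prompt, x_next and t are assumptions;
# the first call processes the whole prompt, later calls pass only the newest token plus
# the cached keys/values.
#
#   attn = Attention(LMConfig())
#   out, past_kv = attn(x_prompt, pos_cis[:t], past_key_value=None, use_cache=True)
#   out, past_kv = attn(x_next, pos_cis[t:t + 1], past_key_value=past_kv, use_cache=True)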
class FeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
if config.hidden_dim is None:
hidden_dim = 4 * config.dim
hidden_dim = int(2 * hidden_dim / 3)
config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class MoEGate(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.scoring_func = config.scoring_func
self.alpha = config.aux_loss_alpha
self.seq_aux = config.seq_aux
self.norm_topk_prob = config.norm_topk_prob
self.gating_dim = config.dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
self.reset_parameters()
def reset_parameters(self) -> None:
import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(1, topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = 0
return topk_idx, topk_weight, aux_loss
class MOEFeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.experts = nn.ModuleList([
FeedForward(config)
for _ in range(config.n_routed_experts)
])
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
self.shared_experts = FeedForward(config)
def forward(self, x):
identity = x
orig_shape = x.shape
bsz, seq_len, _ = x.shape
# 使用门控机制选择专家
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
y = torch.empty_like(x, dtype=torch.float16)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
self.aux_loss = aux_loss
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.config.num_experts_per_tok
# Example: if tokens_per_expert = [6, 15, 20, 26], then tokens_per_expert.shape[0] is the number of experts (4 here).
# With token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12, ...],
# token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the token positions handled by expert 0
# (a token may be routed to several experts, depending on num_experts_per_tok);
# the next nine positions, token_idxs[6:15] -> [4, 5, 6, 10, 11, 12, ...], belong to expert 1, and so on.
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens).to(expert_cache.dtype)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
return expert_cache
class MiniMindBlock(nn.Module):
def __init__(self, layer_id: int, config: LMConfig):
super().__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
self.attention = Attention(config)
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
def forward(self, x, pos_cis, past_key_value=None, use_cache=False):
h_attn, past_kv = self.attention(
self.attention_norm(x),
pos_cis,
past_key_value=past_key_value,
use_cache=use_cache
)
h = x + h_attn
out = h + self.feed_forward(self.ffn_norm(h))
return out, past_kv
class MiniMindLM(PreTrainedModel):
config_class = LMConfig
def __init__(self, params: LMConfig = None):
self.params = params or LMConfig()
super().__init__(self.params)
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.tok_embeddings.weight = self.output.weight
self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.OUT = CausalLMOutputWithPast()
def forward(self,
input_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
logits_to_keep: Union[int, torch.Tensor] = 0,
**args):
past_key_values = past_key_values or [None] * len(self.layers)
start_pos = args.get('start_pos', 0)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
past_kvs = []
for l, layer in enumerate(self.layers):
h, past_kv = layer(
h, pos_cis,
past_key_value=past_key_values[l],
use_cache=use_cache
)
past_kvs.append(past_kv)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
# 统一不使用 aux_loss
aux_loss = 0
self.OUT.__setitem__('last_hidden_state', h)
self.OUT.__setitem__('logits', logits)
self.OUT.__setitem__('aux_loss', aux_loss)
self.OUT.__setitem__('past_key_values', past_kvs)
return self.OUT
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
stream=False, rp=1., use_cache=True, pad_token_id=0, num_return_sequences=1, **args):
# 流式生成
if stream:
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
# 直接生成
generated = []
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
for _ in range(num_return_sequences):
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1)
generated.append(full_sequence)
max_length = max(seq.size(1) for seq in generated)
generated = [
torch.cat(
[seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
dim=-1)
for seq in generated
]
output = torch.cat(generated, dim=0)
res = output.view(input_ids.size(0) * num_return_sequences, -1)
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
start, first_seq, past_kvs = input_ids.shape[1], True, None
while input_ids.shape[1] < max_new_tokens - 1:
if first_seq or not use_cache:
out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False
else:
out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
start_pos=input_ids.shape[1] - 1, **args)
logits, past_kvs = out.logits[:, -1, :], out.past_key_values
logits[:, list(set(input_ids.tolist()[0]))] /= rp
logits /= (temperature + 1e-9)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id:
break
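# --- Illustrative sketch (not part of the original file): the nucleus (top-p) filtering
# used inside _stream above, pulled out as a standalone helper on 2-D logits.
def _top_p_filter(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # shift right so the first token that crosses the threshold is still kept
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = False
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(indices_to_remove, -float('Inf'))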

File diff suppressed because it is too large

View File

@ -1,43 +0,0 @@
{
"add_bos_token": false,
"add_eos_token": false,
"add_prefix_space": false,
"added_tokens_decoder": {
"0": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1": {
"content": "<|im_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2": {
"content": "<|im_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [],
"bos_token": "<|im_start|>",
"clean_up_tokenization_spaces": false,
"eos_token": "<|im_end|>",
"legacy": true,
"model_max_length": 32768,
"pad_token": "<unk>",
"sp_model_kwargs": {},
"spaces_between_special_tokens": false,
"tokenizer_class": "PreTrainedTokenizerFast",
"unk_token": "<unk>",
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\n你是 MiniMind是一个有用的人工智能助手。<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"
}
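# --- Illustrative sketch (not part of the config above): loading a tokenizer that uses
# this tokenizer_config.json and rendering its chat_template. The directory path is an
# assumption; any folder holding this config plus the tokenizer files would work.
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")
#   text = tokenizer.apply_chat_template(
#       [{"role": "user", "content": "Hello"}],
#       tokenize=False, add_generation_prompt=True)
#   # -> "<|im_start|>system\n...<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"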

File diff suppressed because one or more lines are too long

View File

@ -1,133 +0,0 @@
import json
import os
import datetime
from typing import List, Dict, Any
# Configuration parameters
json_path = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/dataset/combined.json"
prepare_num = 1048576  # number of entries for database_init.json; adjust as needed
output_dir = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/dataset/"
def convert_to_database_init_format(sentences: List[str], importance_score: float = 10.0) -> Dict[str, Any]:
"""
将句子列表转换为 database_init.json 格式
Args:
sentences: 句子列表
importance_score: 重要性评分默认为10.0
Returns:
转换后的字典格式数据
"""
# 构建句子数据
sentence_data = []
for sentence in sentences:
sentence_item = {
"original_sentence": sentence,
"corrected_sentence": sentence, # 与original_sentence相同
"importance_score": importance_score
}
sentence_data.append(sentence_item)
# 构建完整的数据结构
result = {
"metadata": {
"batch_number": 1,
"batch_size": len(sentences),
"total_processed_count": len(sentences),
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"total_sentences": len(sentences),
"duplicates_removed": 0 # 在此函数中不涉及去重所以设为0
},
"sentences": sentence_data
}
return result
def preprocess_combined_json():
# 读取原始数据
print("正在读取combined.json...")
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f)
total_count = len(data)
print(f"总共有 {total_count} 条数据")
# 处理所有数据将subject、predicate、object拼接成句子同时记录原始数据
print("正在处理数据并拼接句子...")
sentence_to_original = {} # 记录句子到原始数据的映射
all_sentences = []
for i, item in enumerate(data):
# 拼接subject、predicate、object为一句话
sentence = f"{item['subject']} {item['predicate']} {item['object']}"
all_sentences.append(sentence)
# 记录句子到原始数据的映射(如果句子重复,保留第一次出现的原始数据)
if sentence not in sentence_to_original:
sentence_to_original[sentence] = item
if (i + 1) % 100000 == 0:
print(f"已处理 {i + 1}/{total_count} 条数据")
print(f"完成句子拼接,共 {len(all_sentences)} 条句子")
# 去重处理
print("正在进行去重处理...")
unique_sentences = list(set(all_sentences))
duplicates_removed = len(all_sentences) - len(unique_sentences)
print(f"去重完成,去重前: {len(all_sentences)} 条,去重后: {len(unique_sentences)} 条,移除重复: {duplicates_removed}")
# 检查是否有足够的去重数据
if len(unique_sentences) < prepare_num:
print(f"警告: 去重后的数据量 ({len(unique_sentences)}) 少于所需数量 ({prepare_num})")
print(f"将使用全部 {len(unique_sentences)} 条去重数据")
selected_sentences = unique_sentences
else:
print(f"选择前 {prepare_num} 条去重数据")
selected_sentences = unique_sentences[:prepare_num]
# 转换为database_init.json格式
print("正在转换为database_init.json格式...")
database_init_data = convert_to_database_init_format(selected_sentences, importance_score=10.0)
# 更新metadata中的duplicates_removed信息
database_init_data["metadata"]["duplicates_removed"] = duplicates_removed
# 保存database_init.json
database_output_path = os.path.join(output_dir, "database_init_from_combined.json")
print(f"正在保存 {database_output_path}...")
with open(database_output_path, "w", encoding="utf-8") as f:
json.dump(database_init_data, f, ensure_ascii=False, indent=2)
print(f"database_init_from_combined.json 保存完成,包含 {len(selected_sentences)} 条数据")
# 保存剩余数据作为训练集(保持原格式)
remaining_sentences = unique_sentences[prepare_num:] if len(unique_sentences) > prepare_num else []
if remaining_sentences:
# 将剩余的句子转换回原始格式
print(f"正在转换剩余 {len(remaining_sentences)} 条数据为原始格式...")
remaining_original_data = []
for sentence in remaining_sentences:
if sentence in sentence_to_original:
remaining_original_data.append(sentence_to_original[sentence])
print(f"保存剩余 {len(remaining_original_data)} 条数据作为训练集...")
train_output_path = os.path.join(output_dir, "combined_train.json")
with open(train_output_path, "w", encoding="utf-8") as f:
json.dump(remaining_original_data, f, ensure_ascii=False, indent=2)
print(f"combined_train.json 保存完成")
else:
print("没有剩余数据用于训练集")
remaining_original_data = []
print("\n数据处理完成!")
print(f"原始数据: {total_count}")
print(f"拼接后: {len(all_sentences)} 条句子")
print(f"去重后: {len(unique_sentences)} 条句子")
print(f"用于database_init: {len(selected_sentences)}")
print(f"剩余训练数据: {len(remaining_original_data) if remaining_sentences else 0}")
if __name__ == "__main__":
preprocess_combined_json()
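A quick way to sanity-check the structure produced by convert_to_database_init_format is to call it on a toy list; the snippet below is illustrative and not part of the original script:
sample = convert_to_database_init_format(["Alice works at Acme"], importance_score=10.0)
print(sample["metadata"]["total_sentences"])         # 1
print(sample["sentences"][0]["original_sentence"])   # Alice works at Acme
print(sample["sentences"][0]["corrected_sentence"])  # identical to original_sentence
print(sample["sentences"][0]["importance_score"])    # 10.0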

View File

@ -1,741 +0,0 @@
import json
import os
import pandas as pd
import tarfile
import tempfile
import shutil
from pathlib import Path
import re
import langdetect
from tqdm import tqdm
import logging
import random
import hashlib
from transformers import AutoTokenizer
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 配置参数
BASE_DIR = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind"
OUTPUT_FILE = os.path.join(BASE_DIR, "dataset", "merged_pretrain.jsonl")
OUTPUT_FILE_EXTRA = os.path.join(BASE_DIR, "dataset", "merged_pretrain_extra.jsonl")
# 数据源路径
PRETRAIN_HQ_PATH = os.path.join(BASE_DIR, "dataset", "pretrain_hq.jsonl")
WIKIPEDIA_PATH = "/home/pci/nas/share/datasets/wikipedia/data/20220301.en"
GUTENBERG_PATH = "/home/pci/nas/share/datasets/gutenberg/data"
OPENWEBTEXT_PATH = "/home/pci/nas/share/datasets/openwebtext/subsets"
# Token长度限制
MIN_TOKENS = 410
MAX_TOKENS = 490
# 数据集质量和采样比例配置 - 主文件
DATASET_CONFIG = {
"pretrain_hq": {"quality": "high", "sample_ratio": 1.0, "max_samples": None}, # 高质量,全部使用
"wikipedia": {"quality": "high", "sample_ratio": 1.0, "max_samples": 5000000}, # 高质量使用全部最多500万条
"gutenberg": {"quality": "medium", "sample_ratio": 0.8, "max_samples": 3000000}, # 中质量使用80%最多300万条
"openwebtext": {"quality": "low", "sample_ratio": 0.2, "max_samples": 2000000} # 低质量使用20%最多200万条
}
# 额外文件的配置 - 剩余数据
DATASET_CONFIG_EXTRA = {
"wikipedia": {"quality": "high", "sample_ratio": 1.0, "max_samples": None}, # 剩余的全部
"gutenberg": {"quality": "medium", "sample_ratio": 0.8, "max_samples": 5000000}, # 剩余的80%最多500万条
"openwebtext": {"quality": "low", "sample_ratio": 0.6, "max_samples": 4000000} # 剩余的60%最多400万条
}
# 全局变量:记录已选择的数据
selected_data_hashes = {
"wikipedia": set(),
"gutenberg": set(),
"openwebtext": set()
}
# 初始化tokenizer
tokenizer = None
def init_tokenizer():
"""初始化tokenizer"""
global tokenizer
try:
# 首先尝试使用本地的minimind tokenizer
local_tokenizer_path = os.path.join(BASE_DIR, "model", "minimind_tokenizer")
if os.path.exists(local_tokenizer_path):
tokenizer = AutoTokenizer.from_pretrained(local_tokenizer_path)
logger.info("Local MiniMind tokenizer initialized successfully")
else:
# 如果本地tokenizer不存在使用GPT-2但设置离线模式
tokenizer = AutoTokenizer.from_pretrained("gpt2", local_files_only=True)
logger.info("GPT-2 tokenizer initialized successfully (offline)")
except Exception as e:
logger.error(f"Error initializing tokenizer: {e}")
logger.info("Trying to use a simple fallback tokenizer...")
# 使用简单的分词方法作为备选
tokenizer = None
logger.warning("Using simple word-based tokenization as fallback")
def count_tokens(text):
"""计算文本的token数量"""
if tokenizer is None:
init_tokenizer()
if tokenizer is not None:
try:
tokens = tokenizer.encode(text, add_special_tokens=False)
return len(tokens)
except:
pass
# 如果tokenization失败或tokenizer为None使用简单估算
return int(len(text.split()) * 1.3) # 大概估算,确保返回整数
def is_english_text(text, threshold=0.8):
"""检测文本是否为英文"""
try:
if len(text) < 50: # 太短的文本跳过检测
return True
detected_lang = langdetect.detect(text)
return detected_lang == 'en'
except:
# 如果检测失败,使用简单的英文字符比例判断
english_chars = sum(1 for char in text if char.isascii() and (char.isalpha() or char in ' .,!?-'))
total_chars = len(text.replace('\n', ' ').replace('\t', ' '))
return (english_chars / max(total_chars, 1)) > threshold
def truncate_to_token_limit(text, target_tokens=MAX_TOKENS):
"""将文本截断到目标token数量"""
if tokenizer is None:
init_tokenizer()
if tokenizer is not None:
try:
tokens = tokenizer.encode(text, add_special_tokens=False)
if len(tokens) <= target_tokens:
return text
# 截断到目标长度
truncated_tokens = tokens[:target_tokens]
truncated_text = tokenizer.decode(truncated_tokens, skip_special_tokens=True)
# 尝试在句号处截断以保持完整性
sentences = truncated_text.split('.')
if len(sentences) > 1:
# 保留除最后一个不完整句子外的所有句子
truncated_text = '.'.join(sentences[:-1]) + '.'
return truncated_text
except:
pass
# 如果处理失败或tokenizer为None使用字符数估算
estimated_chars = int(target_tokens / 1.3 * 4) # 大概估算
text = text[:estimated_chars]
# 尝试在句号处截断以保持完整性
sentences = text.split('.')
if len(sentences) > 1:
text = '.'.join(sentences[:-1]) + '.'
return text
def split_text_into_chunks(text, target_chunk_size=1500):
"""将长文本分割成多个中等长度的段落块"""
# 清理文本
text = text.strip()
if not text:
return []
# 移除过多的换行符和空格
text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
text = re.sub(r' +', ' ', text)
chunks = []
# 按段落分割
paragraphs = text.split('\n\n')
current_chunk = ""
for paragraph in paragraphs:
paragraph = paragraph.strip()
if not paragraph:
continue
# 如果当前块加上新段落长度适中,就添加
if len(current_chunk) + len(paragraph) < target_chunk_size:
if current_chunk:
current_chunk += "\n\n" + paragraph
else:
current_chunk = paragraph
else:
# 如果当前块不为空,保存它
if current_chunk:
chunks.append(current_chunk)
# 如果段落本身就很长,需要进一步分割
if len(paragraph) > target_chunk_size * 2:
# 按句子分割长段落
sentences = re.split(r'(?<=[.!?])\s+', paragraph)
temp_chunk = ""
for sentence in sentences:
if len(temp_chunk) + len(sentence) < target_chunk_size:
if temp_chunk:
temp_chunk += " " + sentence
else:
temp_chunk = sentence
else:
if temp_chunk:
chunks.append(temp_chunk)
temp_chunk = sentence
if temp_chunk:
current_chunk = temp_chunk
else:
current_chunk = ""
else:
current_chunk = paragraph
# 添加最后一个块
if current_chunk:
chunks.append(current_chunk)
return chunks
def format_text_for_pretrain(text):
"""将文本格式化为预训练格式并检查token长度"""
# 清理文本
text = text.strip()
if not text:
return None
# 移除过多的换行符和空格
text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
text = re.sub(r' +', ' ', text)
# 检查token长度
token_count = count_tokens(text)
# 如果太短,跳过
if token_count < MIN_TOKENS:
return None
# 如果太长,截断
if token_count > MAX_TOKENS:
text = truncate_to_token_limit(text, MAX_TOKENS)
token_count = count_tokens(text)
# 再次检查是否在合理范围内
if token_count < MIN_TOKENS or token_count > MAX_TOKENS:
return None
# 格式化为预训练格式
formatted_text = f"<|im_start|>{text}<|im_end|>"
return formatted_text
def get_text_hash(text):
"""获取文本的哈希值,用于去重"""
return hashlib.md5(text.encode('utf-8')).hexdigest()
def should_sample(dataset_name, current_count, config_dict=None, is_extra_mode=False):
"""根据配置决定是否采样当前记录"""
if config_dict is None:
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
config = config_dict[dataset_name]
# 检查是否达到最大样本数
if config["max_samples"] and current_count >= config["max_samples"]:
return False
# 根据采样比例随机决定
return random.random() < config["sample_ratio"]
def process_pretrain_hq():
"""处理已有的高质量预训练数据 - 直接输出,不做任何处理"""
logger.info("Processing pretrain_hq.jsonl...")
count = 0
with open(PRETRAIN_HQ_PATH, 'r', encoding='utf-8') as f:
for line in tqdm(f, desc="Processing pretrain_hq"):
try:
data = json.loads(line.strip())
text = data.get('text', '').strip()
if text: # 只要有文本就直接输出,不做任何检测
if should_sample("pretrain_hq", count):
yield {"text": text}
count += 1
except json.JSONDecodeError:
continue
logger.info(f"Processed {count} records from pretrain_hq.jsonl")
def process_wikipedia(is_extra_mode=False):
"""处理Wikipedia数据"""
mode_text = "extra" if is_extra_mode else "main"
logger.info(f"Processing Wikipedia data ({mode_text} mode)...")
count = 0
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
# 获取所有英文Wikipedia文件
wiki_files = list(Path(WIKIPEDIA_PATH).glob("*.parquet"))
for file_path in tqdm(wiki_files, desc=f"Processing Wikipedia files ({mode_text})"):
if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
break
try:
df = pd.read_parquet(file_path)
for _, row in df.iterrows():
if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
break
text = row.get('text', '').strip()
if text and len(text) > 200: # 预过滤太短的文本
# 先将长文本分割成中等大小的块
chunks = split_text_into_chunks(text, target_chunk_size=2000)
for chunk in chunks:
if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
break
chunk_hash = get_text_hash(chunk)
# 在额外模式下,跳过已经被主文件选中的数据
if is_extra_mode and chunk_hash in selected_data_hashes["wikipedia"]:
continue
formatted_text = format_text_for_pretrain(chunk)
if formatted_text and should_sample("wikipedia", count, config_dict, is_extra_mode):
# 在主模式下记录哈希值
if not is_extra_mode:
selected_data_hashes["wikipedia"].add(chunk_hash)
yield {"text": formatted_text}
count += 1
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
continue
logger.info(f"Processed {count} records from Wikipedia ({mode_text} mode)")
def process_gutenberg(is_extra_mode=False):
"""处理Gutenberg数据"""
mode_text = "extra" if is_extra_mode else "main"
logger.info(f"Processing Gutenberg data ({mode_text} mode)...")
count = 0
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
# 获取所有Gutenberg训练文件
gutenberg_files = list(Path(GUTENBERG_PATH).glob("train-*.parquet"))
for file_path in tqdm(gutenberg_files, desc=f"Processing Gutenberg files ({mode_text})"):
if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
break
try:
df = pd.read_parquet(file_path)
for _, row in df.iterrows():
if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
break
text = row.get('text', '').strip()
if text and len(text) > 300 and is_english_text(text): # 预过滤
# 先将长文本分割成中等大小的块
chunks = split_text_into_chunks(text, target_chunk_size=1800)
for chunk in chunks:
if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
break
chunk_hash = get_text_hash(chunk)
# 在额外模式下,跳过已经被主文件选中的数据
if is_extra_mode and chunk_hash in selected_data_hashes["gutenberg"]:
continue
formatted_text = format_text_for_pretrain(chunk)
if formatted_text and should_sample("gutenberg", count, config_dict, is_extra_mode):
# 在主模式下记录哈希值
if not is_extra_mode:
selected_data_hashes["gutenberg"].add(chunk_hash)
yield {"text": formatted_text}
count += 1
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
continue
logger.info(f"Processed {count} records from Gutenberg ({mode_text} mode)")
def process_openwebtext(is_extra_mode=False):
"""处理OpenWebText数据"""
mode_text = "extra" if is_extra_mode else "main"
logger.info(f"Processing OpenWebText data ({mode_text} mode)...")
count = 0
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
max_files = 5 # 减少处理的文件数量以避免过长处理时间
# 获取tar文件列表
tar_files = list(Path(OPENWEBTEXT_PATH).glob("*.tar"))[:max_files]
for tar_path in tqdm(tar_files, desc=f"Processing OpenWebText files ({mode_text})"):
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
try:
with tarfile.open(tar_path, 'r') as outer_tar:
# 创建临时目录处理外层tar
with tempfile.TemporaryDirectory() as temp_dir:
outer_tar.extractall(temp_dir)
# 处理解压后的xz文件
for root, dirs, files in os.walk(temp_dir):
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
for file in files:
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
if file.endswith('.xz'):
xz_path = os.path.join(root, file)
# 创建另一个临时目录处理xz文件
with tempfile.TemporaryDirectory() as xz_temp_dir:
try:
# 解压xz文件
import subprocess
decompressed_path = os.path.join(xz_temp_dir, file[:-3]) # 移除.xz后缀
subprocess.run(['xz', '-dc', xz_path],
stdout=open(decompressed_path, 'wb'),
check=True)
# 检查解压后的文件是否是tar格式
if tarfile.is_tarfile(decompressed_path):
# 处理内层tar文件
with tarfile.open(decompressed_path, 'r') as inner_tar:
with tempfile.TemporaryDirectory() as inner_temp_dir:
inner_tar.extractall(inner_temp_dir)
# 处理最终的txt文件
for inner_root, inner_dirs, inner_files in os.walk(inner_temp_dir):
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
for txt_file in inner_files:
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
if txt_file.endswith('.txt'):
txt_path = os.path.join(inner_root, txt_file)
try:
with open(txt_path, 'r', encoding='utf-8', errors='ignore') as f:
text = f.read().strip()
if text and len(text) > 500 and is_english_text(text):
# 先将长文本分割成中等大小的块
chunks = split_text_into_chunks(text, target_chunk_size=1600)
for chunk in chunks:
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
chunk_hash = get_text_hash(chunk)
# 在额外模式下,跳过已经被主文件选中的数据
if is_extra_mode and chunk_hash in selected_data_hashes["openwebtext"]:
continue
formatted_text = format_text_for_pretrain(chunk)
if formatted_text and should_sample("openwebtext", count, config_dict, is_extra_mode):
# 在主模式下记录哈希值
if not is_extra_mode:
selected_data_hashes["openwebtext"].add(chunk_hash)
yield {"text": formatted_text}
count += 1
except Exception as e:
logger.debug(f"Error reading txt file {txt_path}: {e}")
continue
else:
# 如果不是tar文件直接作为文本处理
try:
with open(decompressed_path, 'r', encoding='utf-8', errors='ignore') as f:
text = f.read().strip()
if text and len(text) > 500 and is_english_text(text):
chunks = split_text_into_chunks(text, target_chunk_size=1600)
for chunk in chunks:
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
chunk_hash = get_text_hash(chunk)
# 在额外模式下,跳过已经被主文件选中的数据
if is_extra_mode and chunk_hash in selected_data_hashes["openwebtext"]:
continue
formatted_text = format_text_for_pretrain(chunk)
if formatted_text and should_sample("openwebtext", count, config_dict, is_extra_mode):
# 在主模式下记录哈希值
if not is_extra_mode:
selected_data_hashes["openwebtext"].add(chunk_hash)
yield {"text": formatted_text}
count += 1
except Exception as e:
logger.debug(f"Error reading decompressed file {decompressed_path}: {e}")
continue
except Exception as e:
logger.debug(f"Error processing xz file {xz_path}: {e}")
continue
except Exception as e:
logger.error(f"Error processing {tar_path}: {e}")
continue
logger.info(f"Processed {count} records from OpenWebText ({mode_text} mode)")
def merge_datasets():
"""合并所有数据集,生成主文件和额外文件"""
logger.info("Starting dataset merging...")
logger.info("Main dataset configuration:")
for name, config in DATASET_CONFIG.items():
logger.info(f" {name}: quality={config['quality']}, ratio={config['sample_ratio']}, max={config['max_samples']}")
logger.info("Extra dataset configuration:")
for name, config in DATASET_CONFIG_EXTRA.items():
logger.info(f" {name}: quality={config['quality']}, ratio={config['sample_ratio']}, max={config['max_samples']}")
# 确保输出目录存在
os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True)
os.makedirs(os.path.dirname(OUTPUT_FILE_EXTRA), exist_ok=True)
# 统计信息
main_dataset_stats = {}
extra_dataset_stats = {}
# 第一阶段:生成主文件
logger.info("="*50)
logger.info("PHASE 1: Generating main dataset file")
logger.info("="*50)
with open(OUTPUT_FILE, 'w', encoding='utf-8') as outfile:
main_total_count = 0
# 处理各个数据集(主模式)
main_datasets = [
("pretrain_hq", process_pretrain_hq),
("wikipedia", lambda: process_wikipedia(is_extra_mode=False)),
("gutenberg", lambda: process_gutenberg(is_extra_mode=False)),
("openwebtext", lambda: process_openwebtext(is_extra_mode=False))
]
for dataset_name, dataset_func in main_datasets:
logger.info(f"Processing {dataset_name} for main file...")
dataset_count = 0
try:
for record in dataset_func():
json.dump(record, outfile, ensure_ascii=False)
outfile.write('\n')
dataset_count += 1
main_total_count += 1
# 每5000条记录输出一次进度
if main_total_count % 5000 == 0:
logger.info(f"Main file: Processed {main_total_count} total records")
# 保存统计信息
main_dataset_stats[dataset_name] = {
'selected': dataset_count,
'config': DATASET_CONFIG[dataset_name]
}
except Exception as e:
logger.error(f"Error processing {dataset_name} for main file: {e}")
main_dataset_stats[dataset_name] = {
'selected': dataset_count,
'config': DATASET_CONFIG[dataset_name]
}
logger.info(f"Main file - Completed {dataset_name}: {dataset_count} records")
logger.info(f"Main file generation completed. Total records: {main_total_count}")
# 第二阶段:生成额外文件
logger.info("="*50)
logger.info("PHASE 2: Generating extra dataset file")
logger.info("="*50)
with open(OUTPUT_FILE_EXTRA, 'w', encoding='utf-8') as outfile:
extra_total_count = 0
# 处理各个数据集(额外模式)- 不包括pretrain_hq
extra_datasets = [
("wikipedia", lambda: process_wikipedia(is_extra_mode=True)),
("gutenberg", lambda: process_gutenberg(is_extra_mode=True)),
("openwebtext", lambda: process_openwebtext(is_extra_mode=True))
]
for dataset_name, dataset_func in extra_datasets:
logger.info(f"Processing {dataset_name} for extra file...")
dataset_count = 0
try:
for record in dataset_func():
json.dump(record, outfile, ensure_ascii=False)
outfile.write('\n')
dataset_count += 1
extra_total_count += 1
# 每5000条记录输出一次进度
if extra_total_count % 5000 == 0:
logger.info(f"Extra file: Processed {extra_total_count} total records")
# 保存统计信息
extra_dataset_stats[dataset_name] = {
'selected': dataset_count,
'config': DATASET_CONFIG_EXTRA[dataset_name]
}
except Exception as e:
logger.error(f"Error processing {dataset_name} for extra file: {e}")
extra_dataset_stats[dataset_name] = {
'selected': dataset_count,
'config': DATASET_CONFIG_EXTRA[dataset_name]
}
logger.info(f"Extra file - Completed {dataset_name}: {dataset_count} records")
logger.info(f"Extra file generation completed. Total records: {extra_total_count}")
# 打印详细统计信息
print_detailed_statistics(main_dataset_stats, main_total_count, extra_dataset_stats, extra_total_count)
logger.info("All dataset processing completed successfully!")
logger.info(f"Main file saved to: {OUTPUT_FILE}")
logger.info(f"Extra file saved to: {OUTPUT_FILE_EXTRA}")
def print_detailed_statistics(main_dataset_stats, main_total_count, extra_dataset_stats, extra_total_count):
"""打印详细统计信息"""
print("\n" + "="*100)
print("DATASET PROCESSING SUMMARY")
print("="*100)
# 主文件统计
print("\nMAIN FILE (merged_pretrain.jsonl):")
print("-" * 90)
print(f"{'Dataset':<15} {'Selected':<10} {'Sample Ratio':<12} {'Max Limit':<12} {'% of Main':<12} {'Quality':<8}")
print("-" * 90)
for dataset_name, stats in main_dataset_stats.items():
selected = stats['selected']
config = stats['config']
ratio = config['sample_ratio']
max_limit = config['max_samples'] if config['max_samples'] else "No limit"
percentage = (selected / main_total_count * 100) if main_total_count > 0 else 0
quality = config['quality']
print(f"{dataset_name:<15} {selected:<10,} {ratio:<12.1%} {str(max_limit):<12} {percentage:<12.2f}% {quality:<8}")
print("-" * 90)
print(f"{'MAIN TOTAL':<15} {main_total_count:<10,} {'':<12} {'':<12} {'100.00%':<12} {'':<8}")
# 额外文件统计
print("\nEXTRA FILE (merged_pretrain_extra.jsonl):")
print("-" * 90)
print(f"{'Dataset':<15} {'Selected':<10} {'Sample Ratio':<12} {'Max Limit':<12} {'% of Extra':<12} {'Quality':<8}")
print("-" * 90)
for dataset_name, stats in extra_dataset_stats.items():
selected = stats['selected']
config = stats['config']
ratio = config['sample_ratio']
max_limit = config['max_samples'] if config['max_samples'] else "No limit"
percentage = (selected / extra_total_count * 100) if extra_total_count > 0 else 0
quality = config['quality']
print(f"{dataset_name:<15} {selected:<10,} {ratio:<12.1%} {str(max_limit):<12} {percentage:<12.2f}% {quality:<8}")
print("-" * 90)
print(f"{'EXTRA TOTAL':<15} {extra_total_count:<10,} {'':<12} {'':<12} {'100.00%':<12} {'':<8}")
# 总体统计
total_records = main_total_count + extra_total_count
print("\nOVERALL STATISTICS:")
print("-" * 50)
print(f"Main file records: {main_total_count:>10,}")
print(f"Extra file records: {extra_total_count:>10,}")
print(f"Total records: {total_records:>10,}")
print(f"Token range per sample: {MIN_TOKENS}-{MAX_TOKENS} tokens")
# 质量分布统计
quality_stats = {}
for dataset_name, stats in main_dataset_stats.items():
quality = stats['config']['quality']
if quality not in quality_stats:
quality_stats[quality] = {'main': 0, 'extra': 0}
quality_stats[quality]['main'] += stats['selected']
for dataset_name, stats in extra_dataset_stats.items():
quality = stats['config']['quality']
if quality not in quality_stats:
quality_stats[quality] = {'main': 0, 'extra': 0}
quality_stats[quality]['extra'] += stats['selected']
print("\nQUALITY DISTRIBUTION:")
print("-" * 60)
print(f"{'Quality':<12} {'Main File':<12} {'Extra File':<12} {'Total':<12} {'%':<8}")
print("-" * 60)
for quality in sorted(quality_stats.keys()):
main_count = quality_stats[quality]['main']
extra_count = quality_stats[quality]['extra']
total_count = main_count + extra_count
percentage = (total_count / total_records * 100) if total_records > 0 else 0
print(f"{quality.capitalize():<12} {main_count:<12,} {extra_count:<12,} {total_count:<12,} {percentage:<8.2f}%")
print("-" * 60)
print(f"{'Total':<12} {main_total_count:<12,} {extra_total_count:<12,} {total_records:<12,} {'100.00%':<8}")
print(f"\nFiles saved to:")
print(f" Main file: {OUTPUT_FILE}")
print(f" Extra file: {OUTPUT_FILE_EXTRA}")
print("="*100)
def main():
"""主函数"""
try:
# 设置随机种子以确保结果可重现
random.seed(42)
# 检查依赖包
try:
import langdetect
from transformers import AutoTokenizer
except ImportError as e:
logger.error(f"Missing dependencies: {e}")
logger.error("Please install: pip install langdetect transformers")
return
# 初始化tokenizer
init_tokenizer()
# 检查输入文件
if not os.path.exists(PRETRAIN_HQ_PATH):
logger.error(f"pretrain_hq.jsonl not found at {PRETRAIN_HQ_PATH}")
return
# 开始合并数据集
merge_datasets()
logger.info("All processing completed successfully!")
except Exception as e:
logger.error(f"Error in main process: {e}")
raise
if __name__ == "__main__":
main()
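Each record written by merge_datasets() is one JSON object per line with a single "text" field; the non-pretrain_hq sources are additionally constrained to the MIN_TOKENS–MAX_TOKENS range. A minimal sketch for spot-checking the output (assumes it runs in this same module after OUTPUT_FILE has been produced):
init_tokenizer()
with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        record = json.loads(line)
        # Print the token count and a preview of each of the first five records.
        print(count_tokens(record["text"]), record["text"][:60])
        if i >= 4:
            break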

View File

@ -1,442 +0,0 @@
import json
import os
import argparse
from typing import List, Dict, Any, Optional
from collections import defaultdict
import pickle
from pathlib import Path
class WikidataRelationManager:
"""Wikidata关系管理器支持动态获取和缓存"""
def __init__(self, cache_file: str = "wikidata_relations_cache.pkl",
mapping_file: str = None):
self.cache_file = cache_file
self.mapping_file = mapping_file
self.relations = {}
# 删除了API相关属性
# 初始的基础关系映射
self.base_relations = {
# # 基本关系
# 'P31': 'instance of',
# 'P279': 'subclass of',
# 'P17': 'country',
# 'P159': 'headquarters location',
# 'P571': 'inception',
# # 人物关系
# 'P19': 'place of birth',
# 'P20': 'place of death',
# 'P27': 'country of citizenship',
# 'P106': 'occupation',
# 'P22': 'father',
# 'P25': 'mother',
# 'P26': 'spouse',
# 'P40': 'child',
# 'P69': 'educated at',
# 'P108': 'employer',
# # 地理关系
# 'P36': 'capital',
# 'P131': 'located in',
# 'P47': 'shares border with',
# 'P206': 'located on terrain feature',
# 'P1376': 'capital of',
# # 组织关系
# 'P112': 'founded by',
# 'P127': 'owned by',
# 'P169': 'chief executive officer',
# 'P488': 'chairperson',
# 'P749': 'parent organization',
# # 作品关系
# 'P50': 'author',
# 'P57': 'director',
# 'P58': 'screenwriter',
# 'P161': 'cast member',
# 'P175': 'performer',
# 'P577': 'publication date',
# 'P123': 'publisher',
# 'P136': 'genre',
# # 时间关系
# 'P155': 'follows',
# 'P156': 'followed by',
# 'P580': 'start time',
# 'P582': 'end time',
# # 体育关系
# 'P54': 'member of sports team',
# 'P413': 'position played on team',
# 'P118': 'league',
# # 科学关系
# 'P275': 'copyright license',
# 'P170': 'creator',
# 'P398': 'child astronomical body',
# 'P397': 'parent astronomical body',
# # 其他常见关系
# 'P37': 'official language',
# 'P1923': 'place of marriage',
# 'P737': 'influenced by',
# 'P463': 'member of',
# 'P39': 'position held',
# 'P276': 'location',
# 'P1441': 'present in work',
}
self.load_cache()
def load_cache(self):
"""加载缓存的关系映射优先使用JSON映射文件"""
try:
# 优先尝试加载JSON映射文件
if self.mapping_file and os.path.exists(self.mapping_file):
with open(self.mapping_file, 'r', encoding='utf-8') as f:
self.relations = json.load(f)
print(f"从JSON映射文件加载了 {len(self.relations)} 个关系映射")
return
# 尝试加载pickle缓存文件
if os.path.exists(self.cache_file):
with open(self.cache_file, 'rb') as f:
self.relations = pickle.load(f)
print(f"从pickle缓存加载了 {len(self.relations)} 个关系映射")
else:
self.relations = self.base_relations.copy()
print(f"初始化基础关系映射: {len(self.relations)}")
except Exception as e:
print(f"加载缓存失败: {e}")
self.relations = self.base_relations.copy()
def save_cache(self):
"""保存关系映射到缓存"""
try:
with open(self.cache_file, 'wb') as f:
pickle.dump(self.relations, f)
print(f"已保存 {len(self.relations)} 个关系映射到缓存")
except Exception as e:
print(f"保存缓存失败: {e}")
# 删除了网络抓取功能,改为纯离线模式
def get_relation_name(self, property_id: str) -> Optional[str]:
"""获取关系名称,仅使用本地映射"""
if property_id in self.relations:
return self.relations[property_id]
# 如果本地映射中没有找到返回None表示跳过这个关系
return None
# 删除了网络请求相关的批量获取和预加载功能
class TRexProcessor:
"""T-REx数据集处理器"""
def __init__(self, relation_manager: WikidataRelationManager):
self.relation_manager = relation_manager
def extract_predicate_id(self, uri: str) -> str:
"""从URI中提取属性ID"""
if uri and 'prop/direct/' in uri:
return uri.split('/')[-1]
elif uri and uri.startswith('P') and uri[1:].isdigit():
return uri
return uri if uri else 'unknown'
def get_relation_name(self, predicate_uri: str) -> Optional[str]:
"""获取关系的可读名称"""
predicate_id = self.extract_predicate_id(predicate_uri)
return self.relation_manager.get_relation_name(predicate_id)
# 删除了谓词收集功能,因为不再需要预加载
def is_valid_triple(self, triple: Dict[str, Any], confidence_threshold: float,
boundary_threshold: int) -> bool:
"""检查三元组是否满足过滤条件"""
try:
# 检查triple是否为字典
if not isinstance(triple, dict):
return False
# 检查必要字段
if not all(key in triple for key in ['subject', 'predicate', 'object']):
return False
subject = triple['subject']
predicate = triple['predicate']
object_info = triple['object']
# 检查subject、predicate、object是否都为字典
if not isinstance(subject, dict) or not isinstance(predicate, dict) or not isinstance(object_info, dict):
return False
# 检查主语和宾语是否有有效的URI和surfaceform
if not (subject.get('uri') and subject.get('surfaceform')):
return False
if not (object_info.get('uri') and object_info.get('surfaceform')):
return False
if not predicate.get('uri'):
return False
# 检查置信度(如果存在)
confidence = triple.get('confidence')
if confidence is not None and confidence < confidence_threshold:
return False
# 检查边界信息(如果设置了阈值)
if boundary_threshold > 0:
subject_boundaries = subject.get('boundaries')
object_boundaries = object_info.get('boundaries')
if not subject_boundaries or not object_boundaries:
return False
# 检查边界是否为列表且长度至少为2
if not (isinstance(subject_boundaries, list) and len(subject_boundaries) >= 2):
return False
if not (isinstance(object_boundaries, list) and len(object_boundaries) >= 2):
return False
try:
# 检查边界长度是否合理
subject_length = subject_boundaries[1] - subject_boundaries[0]
object_length = object_boundaries[1] - object_boundaries[0]
if subject_length < boundary_threshold or object_length < boundary_threshold:
return False
except (TypeError, IndexError):
return False
# 检查文本内容是否合理
subject_text = subject.get('surfaceform', '').strip()
object_text = object_info.get('surfaceform', '').strip()
if not subject_text or not object_text:
return False
# 过滤掉过长或过短的实体
if len(subject_text) > 100 or len(object_text) > 100:
return False
if len(subject_text) < 2 or len(object_text) < 2:
return False
return True
except (KeyError, TypeError, AttributeError):
return False
def process_single_file(self, file_path: str, confidence_threshold: float,
boundary_threshold: int) -> List[Dict[str, Any]]:
"""处理单个JSON文件"""
print(f"Processing file: {file_path}")
processed_data = []
try:
with open(file_path, 'r', encoding='utf-8') as f:
# 读取整个文件作为JSON数组
print(f"正在加载JSON数组文件: {file_path}")
data_list = json.load(f)
print(f"文件包含 {len(data_list)} 个条目")
for idx, data in enumerate(data_list):
try:
# 获取基本信息
text = data.get('text', '').strip()
if not text:
continue
# 处理三元组
triples = data.get('triples', [])
if not triples:
continue
valid_targets = []
for triple in triples:
if self.is_valid_triple(triple, confidence_threshold, boundary_threshold):
# 获取关系名称,如果无法解析则跳过这个三元组
relation_name = self.get_relation_name(triple['predicate']['uri'])
if relation_name is None:
continue # 跳过无法解析的关系
target = {
'subject': triple['subject']['surfaceform'].strip(),
'predicate': relation_name,
'object': triple['object']['surfaceform'].strip()
}
valid_targets.append(target)
# 如果有有效的三元组,添加到结果中
if valid_targets:
processed_data.append({
'text': text,
'target': valid_targets
})
except Exception as e:
if idx <= 10: # 只打印前10个错误
print(f"处理条目时出错 in {file_path} at index {idx}: {e}")
continue
except FileNotFoundError:
print(f"文件未找到: {file_path}")
except json.JSONDecodeError as e:
print(f"JSON解析错误 in {file_path}: {e}")
except Exception as e:
print(f"处理文件时出错 {file_path}: {e}")
print(f"{file_path} 提取了 {len(processed_data)} 个有效样本")
return processed_data
def process_folder(self, folder_path: str, confidence_threshold: float,
boundary_threshold: int) -> List[Dict[str, Any]]:
"""处理文件夹中的所有JSON文件"""
all_processed_data = []
if not os.path.exists(folder_path):
raise FileNotFoundError(f"文件夹不存在: {folder_path}")
# 获取所有JSON文件
json_files = [f for f in os.listdir(folder_path) if f.endswith('.json')]
if not json_files:
raise ValueError(f"{folder_path} 中没有找到JSON文件")
print(f"找到 {len(json_files)} 个JSON文件")
for filename in sorted(json_files):
file_path = os.path.join(folder_path, filename)
processed_data = self.process_single_file(file_path, confidence_threshold, boundary_threshold)
all_processed_data.extend(processed_data)
# 保存最终的关系缓存
self.relation_manager.save_cache()
return all_processed_data
def generate_statistics(self, processed_data: List[Dict[str, Any]]) -> Dict[str, Any]:
"""生成数据统计信息"""
total_samples = len(processed_data)
total_triples = sum(len(sample['target']) for sample in processed_data)
# 统计关系类型
relation_counts = defaultdict(int)
for sample in processed_data:
for target in sample['target']:
relation_counts[target['predicate']] += 1
# 统计文本长度
text_lengths = [len(sample['text']) for sample in processed_data]
avg_text_length = sum(text_lengths) / len(text_lengths) if text_lengths else 0
# 统计每个文本的三元组数量
triples_per_text = [len(sample['target']) for sample in processed_data]
avg_triples_per_text = sum(triples_per_text) / len(triples_per_text) if triples_per_text else 0
return {
'total_samples': total_samples,
'total_triples': total_triples,
'avg_text_length': round(avg_text_length, 2),
'avg_triples_per_text': round(avg_triples_per_text, 2),
'relation_distribution': dict(sorted(relation_counts.items(),
key=lambda x: x[1], reverse=True)),
'top_10_relations': dict(list(sorted(relation_counts.items(),
key=lambda x: x[1], reverse=True))[:10]),
'total_unique_relations': len(relation_counts),
'cached_relations': len(self.relation_manager.relations)
}
def main():
parser = argparse.ArgumentParser(description='处理T-REx数据集支持动态关系获取')
parser.add_argument('--folder_path', type=str, default='/home/pci/ycz/Code/Minimind/dataset/trex', help='包含JSON文件的文件夹路径')
parser.add_argument('--confidence_threshold', type=float, default=0.5,
help='置信度阈值 (默认: 0.5)')
parser.add_argument('--boundary_threshold', type=int, default=0,
help='边界长度阈值 (默认: 0, 不过滤)')
parser.add_argument('--output', type=str, default='./processed_trex_data.json',
help='输出文件名 (默认: processed_trex_data.json)')
parser.add_argument('--stats', type=str, default='trex_statistics.json',
help='统计信息输出文件名 (默认: trex_statistics.json)')
parser.add_argument('--cache_file', type=str, default='wikidata_relations_cache.pkl',
help='关系缓存文件名 (默认: wikidata_relations_cache.pkl)')
parser.add_argument('--mapping_file', type=str, default="/home/pci/ycz/Code/Minimind/preprocessing/sample_property_mappings.json",
help='JSON映射文件路径 (必须提供,用于关系名称映射)')
args = parser.parse_args()
print("T-REx数据集处理器支持动态关系获取")
print("=" * 60)
print(f"输入文件夹: {args.folder_path}")
print(f"置信度阈值: {args.confidence_threshold}")
print(f"边界长度阈值: {args.boundary_threshold}")
print(f"输出文件: {args.output}")
print(f"关系缓存文件: {args.cache_file}")
print(f"JSON映射文件: {args.mapping_file if args.mapping_file else '未指定'}")
print("=" * 60)
# 检查映射文件是否存在
if not args.mapping_file or not os.path.exists(args.mapping_file):
print(f"错误: 映射文件不存在或未指定: {args.mapping_file}")
print("请确保提供有效的JSON映射文件。")
return 1
# 创建关系管理器
relation_manager = WikidataRelationManager(
cache_file=args.cache_file,
mapping_file=args.mapping_file
)
# 创建处理器
processor = TRexProcessor(relation_manager)
try:
# 处理数据
processed_data = processor.process_folder(
args.folder_path,
args.confidence_threshold,
args.boundary_threshold
)
print(f"\n处理完成!总共处理了 {len(processed_data)} 个样本")
# 生成统计信息
stats = processor.generate_statistics(processed_data)
# 保存处理后的数据
with open(args.output, 'w', encoding='utf-8') as f:
json.dump(processed_data, f, ensure_ascii=False, indent=2)
# 保存统计信息
with open(args.stats, 'w', encoding='utf-8') as f:
json.dump(stats, f, ensure_ascii=False, indent=2)
print(f"\n数据已保存到: {args.output}")
print(f"统计信息已保存到: {args.stats}")
print(f"关系缓存已保存到: {args.cache_file}")
# 打印统计摘要
print("\n数据统计摘要:")
print("=" * 30)
print(f"总样本数: {stats['total_samples']}")
print(f"总三元组数: {stats['total_triples']}")
print(f"唯一关系数: {stats['total_unique_relations']}")
print(f"缓存关系数: {stats['cached_relations']}")
print(f"平均文本长度: {stats['avg_text_length']}")
print(f"平均每文本三元组数: {stats['avg_triples_per_text']}")
print("\n前10个最常见关系:")
for relation, count in stats['top_10_relations'].items():
print(f" {relation}: {count}")
except Exception as e:
print(f"处理过程中出错: {e}")
return 1
return 0
if __name__ == "__main__":
exit(main())
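Each entry of the output JSON pairs the source text with its surviving triples ("subject", "predicate", "object"). An illustrative peek at the result, assuming the default --output path was used:
import json

with open("./processed_trex_data.json", "r", encoding="utf-8") as f:
    samples = json.load(f)
print(samples[0]["text"][:80])
for target in samples[0]["target"][:3]:
    print(target["subject"], "--", target["predicate"], "-->", target["object"])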

View File

@ -1,441 +0,0 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
import re
import asyncio
import aiofiles
from concurrent.futures import ThreadPoolExecutor
from preprocessing.agent_system.extractor_agent.agent import DepartmentAgent
from typing import Dict, List, Tuple
import gc
import time
import psutil
from tqdm.asyncio import tqdm as async_tqdm
from tqdm import tqdm
json_path = "dataset/merged_pretrain_extra.jsonl"
output_path = "dataset/processed_triples.jsonl"
# 优化后的配置参数 - 降低资源消耗
BATCH_SIZE = 5000  # 批次大小每批5000条数据
MAX_CONCURRENT = 200  # 最大并发数最多200条并发处理
AGENT_POOL_SIZE = 20  # agent池大小创建20个agent实例
def get_memory_usage():
"""获取当前内存使用情况"""
process = psutil.Process(os.getpid())
memory_info = process.memory_info()
memory_mb = memory_info.rss / 1024 / 1024
return memory_mb
def print_memory_info(stage=""):
"""打印内存使用信息"""
memory_mb = get_memory_usage()
print(f"🔧 {stage} - 内存使用: {memory_mb:.1f} MB")
# 创建extractor_agent池避免并发冲突
def create_extractor_pool(pool_size: int = 5):
"""创建extractor_agent池"""
print(f"正在创建 {pool_size} 个agent实例...")
agents = []
for i in range(pool_size):
try:
agent = DepartmentAgent(model_type="deepseek")
agents.append(agent)
print(f" ✓ Agent {i+1}/{pool_size} 创建成功")
except Exception as e:
print(f" ✗ Agent {i+1} 创建失败: {e}")
print(f"Agent池创建完成实际创建了 {len(agents)} 个实例")
return agents
# 延迟初始化agent池
AGENT_POOL = None
agent_pool_index = 0
def get_agent_pool():
"""获取agent池延迟初始化"""
global AGENT_POOL
if AGENT_POOL is None:
print_memory_info("创建Agent池前")
AGENT_POOL = create_extractor_pool(pool_size=AGENT_POOL_SIZE)
print_memory_info("创建Agent池后")
return AGENT_POOL
def get_next_agent():
"""轮询获取下一个可用的agent"""
global agent_pool_index
pool = get_agent_pool()
agent = pool[agent_pool_index % len(pool)]
agent_pool_index += 1
return agent
def clean_and_split_text(text):
"""
去除文本开头结尾的标记并按句子分割
"""
# 去除开头的<|im_start|>和结尾的<|im_end|>
text = text.strip()
if text.startswith('<|im_start|>'):
text = text[len('<|im_start|>'):]
if text.endswith('<|im_end|>'):
text = text[:-len('<|im_end|>')]
# 清理文本,去除多余的空白字符
text = text.strip()
# 按句子分割(根据句号、问号、感叹号等标点符号)
# 使用正则表达式匹配句子结束标志
sentence_endings = r'[.!?。!?]'
sentences = re.split(sentence_endings, text)
# 清理每个句子,去除空白和空句子
cleaned_sentences = []
for sentence in sentences:
sentence = sentence.strip()
if sentence and len(sentence) > 5: # 只保留非空且有意义的句子
cleaned_sentences.append(sentence)
return cleaned_sentences
async def extract_triple_from_sentence_async(sentence: str, context: str = None) -> Dict:
"""
异步使用extractor_agent从句子中提取三元组
"""
try:
# 获取一个agent实例
agent = get_next_agent()
result = await agent.async_run(sentence=sentence, context=context)
return {
"sentence": sentence,
"triple": {
"subject": result.triple.subject,
"predicate": result.triple.predicate,
"object": result.triple.object
},
"confidence": result.confidence
}
except Exception as e:
return {
"sentence": sentence,
"triple": {
"subject": "",
"predicate": "",
"object": ""
},
"confidence": 0.0,
"error": str(e)
}
async def process_paragraph_async(line_num: int, original_text: str, semaphore: asyncio.Semaphore) -> Dict:
"""
异步处理单个段落
"""
async with semaphore: # 控制并发数量
try:
# 清理并分割文本
sentences = clean_and_split_text(original_text)
if not sentences:
return None
# 构建当前段落的结果
paragraph_result = {
"source_line": line_num,
"original_paragraph": original_text,
"sentences": [],
"triples": []
}
# 异步处理所有句子
tasks = []
for sentence in sentences:
task = extract_triple_from_sentence_async(sentence, context=original_text)
tasks.append(task)
# 等待所有句子处理完成
triple_results = await asyncio.gather(*tasks)
# 整理结果
for i, sentence in enumerate(sentences):
paragraph_result["sentences"].append(sentence)
paragraph_result["triples"].append(triple_results[i])
return paragraph_result
except Exception as e:
print(f"处理第 {line_num} 行时出错: {e}")
return None
async def process_batch_async(batch_data: List[Tuple[int, str]], batch_num: int) -> List[Dict]:
"""
异步处理一个批次的数据带进度条和内存监控
"""
print(f"\n=== 异步处理批次 {batch_num} ===")
print(f"批次大小: {len(batch_data)} 条记录")
print_memory_info(f"批次 {batch_num} 开始前")
start_time = time.time()
# 创建信号量控制并发数量
semaphore = asyncio.Semaphore(MAX_CONCURRENT)
# 分块处理任务,避免一次性创建太多任务
chunk_size = 1000 # 每次处理1000个任务
all_results = []
for chunk_start in range(0, len(batch_data), chunk_size):
chunk_end = min(chunk_start + chunk_size, len(batch_data))
chunk_data = batch_data[chunk_start:chunk_end]
print(f"处理子块 {chunk_start//chunk_size + 1}/{(len(batch_data)-1)//chunk_size + 1} ({len(chunk_data)} 条记录)")
# 创建当前块的异步任务
tasks = []
for line_num, original_text in chunk_data:
task = process_paragraph_async(line_num, original_text, semaphore)
tasks.append(task)
# 使用进度条处理当前块
progress_bar = tqdm(total=len(tasks), desc=f"批次{batch_num}-块{chunk_start//chunk_size + 1}", unit="段落", ncols=100)
chunk_results = []
completed_tasks = 0
# 使用as_completed来获取完成的任务并更新进度条
for coro in asyncio.as_completed(tasks):
try:
result = await coro
chunk_results.append(result)
completed_tasks += 1
# 更新进度条
progress_bar.update(1)
# 每完成50个任务更新一次描述
if completed_tasks % 50 == 0:
valid_results = [r for r in chunk_results if r is not None]
progress_bar.set_postfix({
'有效': len(valid_results),
'完成': completed_tasks,
'成功率': f"{len(valid_results)/completed_tasks*100:.1f}%"
})
except Exception as e:
print(f"任务执行失败: {e}")
completed_tasks += 1
progress_bar.update(1)
progress_bar.close()
all_results.extend(chunk_results)
# 每个块完成后清理内存
del tasks, chunk_results
gc.collect()
print_memory_info(f"批次 {batch_num}{chunk_start//chunk_size + 1} 完成后")
# 过滤None结果
valid_results = [result for result in all_results if result is not None]
# 统计信息
batch_sentences = sum(len(result["sentences"]) for result in valid_results)
batch_triples = sum(
sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
for result in valid_results
)
end_time = time.time()
processing_time = end_time - start_time
print(f"批次 {batch_num} 异步处理完成:")
print(f" - 有效段落: {len(valid_results)}/{len(batch_data)} ({len(valid_results)/len(batch_data)*100:.1f}%)")
print(f" - 总句子数: {batch_sentences}")
print(f" - 成功三元组: {batch_triples}")
print(f" - 三元组成功率: {batch_triples/batch_sentences*100:.1f}%" if batch_sentences > 0 else "无句子")
print(f" - 处理时间: {processing_time:.2f}")
print(f" - 处理速度: {len(batch_data)/processing_time:.2f}段落/秒")
print_memory_info(f"批次 {batch_num} 完成后")
return valid_results
async def write_results_batch(results: List[Dict], output_path: str):
"""
异步批量写入结果带进度提示
"""
try:
print(f"开始批量写入 {len(results)} 条结果...")
# 准备写入内容
content_lines = []
for result in results:
content_lines.append(json.dumps(result, ensure_ascii=False))
# 异步批量写入
async with aiofiles.open(output_path, "a", encoding="utf-8") as f:
await f.write("\n".join(content_lines) + "\n")
print(f"✓ 成功批量写入 {len(results)} 条结果到 {output_path}")
except Exception as e:
print(f"✗ 批量写入失败: {e}")
print("尝试逐条写入...")
# 如果批量写入失败,回退到逐条写入(带进度条)
async with aiofiles.open(output_path, "a", encoding="utf-8") as f:
for result in tqdm(results, desc="逐条写入", unit=""):
await f.write(json.dumps(result, ensure_ascii=False) + "\n")
print(f"✓ 逐条写入完成")
# 主处理流程
async def main_async():
total_processed = 0
total_sentences = 0
total_triples = 0
batch_num = 0
print("=== 开始异步批次处理JSONL文件 ===")
print(f"优化后的配置信息:")
print(f" - 批次大小: {BATCH_SIZE:,} 条记录")
print(f" - 最大并发数: {MAX_CONCURRENT}")
print(f" - Agent池大小: {AGENT_POOL_SIZE}")
print(f" - 输入文件: {json_path}")
print(f" - 输出文件: {output_path}")
print()
print_memory_info("程序开始")
# 清空输出文件
async with aiofiles.open(output_path, "w", encoding="utf-8") as f:
pass
# 读取并处理数据
with open(json_path, "r", encoding="utf-8") as f_in:
batch_data = []
for line_num, line in enumerate(f_in):
if line.strip(): # 跳过空行
try:
item = json.loads(line)
original_text = item.get("text", "")
if original_text:
batch_data.append((line_num + 1, original_text))
# 当批次达到指定大小时,异步处理这个批次
if len(batch_data) >= BATCH_SIZE:
batch_num += 1
# 异步处理批次
batch_results = await process_batch_async(batch_data, batch_num)
# 批量写入结果
if batch_results:
await write_results_batch(batch_results, output_path)
# 统计信息
batch_sentences = sum(len(result["sentences"]) for result in batch_results)
batch_triples = sum(
sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
for result in batch_results
)
total_processed += len(batch_data)
total_sentences += batch_sentences
total_triples += batch_triples
print(f"\n📊 批次 {batch_num} 累计统计:")
print(f" - 累计处理段落: {total_processed:,}")
print(f" - 累计句子数: {total_sentences:,}")
print(f" - 累计三元组: {total_triples:,}")
print(f" - 整体成功率: {total_triples/total_sentences*100:.1f}%")
print("-" * 80)
# 清理批次数据,释放内存
batch_data.clear()
batch_results.clear()
gc.collect() # 强制垃圾回收
print_memory_info(f"批次 {batch_num} 清理后")
except json.JSONDecodeError as e:
print(f"{line_num + 1} 行JSON解析错误: {e}")
except Exception as e:
print(f"处理第 {line_num + 1} 行时出错: {e}")
# 处理最后一个不完整的批次
if batch_data:
batch_num += 1
batch_results = await process_batch_async(batch_data, batch_num)
if batch_results:
await write_results_batch(batch_results, output_path)
batch_sentences = sum(len(result["sentences"]) for result in batch_results)
batch_triples = sum(
sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
for result in batch_results
)
total_processed += len(batch_data)
total_sentences += batch_sentences
total_triples += batch_triples
# 最终统计
print(f"\n{'='*80}")
print(f"🎉 所有批次异步处理完成!")
print(f"{'='*80}")
print(f"最终统计:")
print(f" - 总批次数: {batch_num}")
print(f" - 总段落数: {total_processed:,}")
print(f" - 总句子数: {total_sentences:,}")
print(f" - 总三元组: {total_triples:,}")
print(f" - 整体成功率: {total_triples/total_sentences*100:.1f}%" if total_sentences > 0 else "无有效句子")
print(f" - 输出文件: {output_path}")
print(f"{'='*80}")
print_memory_info("程序结束前")
# 显示示例结果
await show_sample_results()
async def show_sample_results():
"""显示前几个处理结果作为示例"""
print("\n📋 前3个处理结果示例:")
try:
async with aiofiles.open(output_path, "r", encoding="utf-8") as f:
i = 0
async for line in f:
if i >= 3:
break
item = json.loads(line)
print(f"\n--- 段落 {i+1} (来源行: {item['source_line']}) ---")
print(f"原始段落: {item['original_paragraph'][:100]}...")
print(f"句子数量: {len(item['sentences'])}")
if item['triples']:
for j, triple in enumerate(item['triples'][:2]): # 只显示前2个三元组
print(f" 句子 {j+1}: {triple['sentence'][:50]}...")
if triple['confidence'] > 0:
print(f" 三元组: {triple['triple']['subject']} -> {triple['triple']['predicate']} -> {triple['triple']['object']}")
print(f" 置信度: {triple['confidence']:.2f}")
else:
print(f" 提取失败: {triple.get('error', '未知错误')}")
i += 1
except Exception as e:
print(f"读取示例结果时出错: {e}")
def main():
"""主入口函数"""
try:
# 运行异步主函数
asyncio.run(main_async())
except KeyboardInterrupt:
print("\n⚠️ 用户中断处理")
except Exception as e:
print(f"❌ 处理过程中出现错误: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()
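The script caps in-flight work with an asyncio.Semaphore shared by all tasks before gathering the results. The standalone sketch below shows the same pattern with a hypothetical worker in place of the DepartmentAgent calls:
import asyncio

async def worker(i: int, semaphore: asyncio.Semaphore) -> int:
    async with semaphore:          # at most max_concurrent workers enter this section
        await asyncio.sleep(0.01)  # stand-in for a real API call
        return i

async def run_all(n: int, max_concurrent: int) -> list[int]:
    semaphore = asyncio.Semaphore(max_concurrent)
    return await asyncio.gather(*(worker(i, semaphore) for i in range(n)))

results = asyncio.run(run_all(100, 10))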

View File

@ -1,176 +0,0 @@
[project]
name = "minimind"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"accelerate==1.7.0",
"aiohappyeyeballs==2.6.1",
"aiohttp==3.11.17",
"aiosignal==1.3.2",
"altair==5.5.0",
"annotated-types==0.7.0",
"anyio==4.9.0",
"async-timeout==5.0.1",
"attrs==25.3.0",
"blinker==1.9.0",
"boto3==1.38.41",
"botocore==1.38.41",
"cachetools==5.5.2",
"certifi==2025.1.31",
"charset-normalizer==3.4.1",
"click==8.1.8",
"contourpy==1.3.2",
"cycler==0.12.1",
"datasets==2.21.0",
"datasketch==1.6.4",
"deepspeed==0.17.0",
"determined>=0.37.0",
"dill==0.3.8",
"distro==1.9.0",
"docker-pycreds==0.4.0",
"einops==0.8.1",
"exceptiongroup==1.2.2",
"filelock==3.18.0",
"Flask==3.0.3",
"Flask-Cors==4.0.0",
"fonttools==4.57.0",
"frozenlist==1.6.0",
"fsspec==2024.6.1",
"gitdb==4.0.12",
"GitPython==3.1.44",
"h11==0.14.0",
"hjson==3.1.0",
"httpcore==1.0.8",
"httpx==0.28.1",
"huggingface-hub==0.30.2",
"importlib_metadata==7.2.1",
"itsdangerous==2.2.0",
"jieba==0.42.1",
"Jinja2==3.1.2",
"jiter==0.9.0",
"jmespath==1.0.1",
"joblib==1.4.2",
"jsonlines==4.0.0",
"jsonpointer==2.1",
"jsonschema==4.23.0",
"jsonschema-specifications==2024.10.1",
"kiwisolver==1.4.8",
"langdetect==1.0.9",
"markdown-it-py==3.0.0",
"MarkupSafe==3.0.2",
"marshmallow==3.22.0",
"matplotlib==3.10.0",
"mdurl==0.1.2",
"modelscope==1.25.0",
"mpi4py>=4.0.3",
"mpmath==1.3.0",
"msgpack==1.1.0",
"multidict==6.4.3",
"multiprocess==0.70.16",
"narwhals==1.35.0",
"networkx==3.4.2",
"ngrok==1.4.0",
"ninja==1.11.1.4",
"nltk==3.8",
"numpy==1.26.4",
"nvidia-cublas-cu11==11.11.3.6",
"nvidia-cublas-cu12==12.6.4.1",
"nvidia-cuda-cupti-cu11==11.8.87",
"nvidia-cuda-cupti-cu12==12.6.80",
"nvidia-cuda-nvrtc-cu11==11.8.89",
"nvidia-cuda-nvrtc-cu12==12.6.77",
"nvidia-cuda-runtime-cu11==11.8.89",
"nvidia-cuda-runtime-cu12==12.6.77",
"nvidia-cudnn-cu11==9.1.0.70",
"nvidia-cudnn-cu12==9.5.1.17",
"nvidia-cufft-cu11==10.9.0.58",
"nvidia-cufft-cu12==11.3.0.4",
"nvidia-cufile-cu12==1.11.1.6",
"nvidia-curand-cu11==10.3.0.86",
"nvidia-curand-cu12==10.3.7.77",
"nvidia-cusolver-cu11==11.4.1.48",
"nvidia-cusolver-cu12==11.7.1.2",
"nvidia-cusparse-cu11==11.7.5.86",
"nvidia-cusparse-cu12==12.5.4.2",
"nvidia-cusparselt-cu12==0.6.3",
"nvidia-ml-py==12.575.51",
"nvidia-nccl-cu11==2.21.5",
"nvidia-nccl-cu12==2.26.2",
"nvidia-nvjitlink-cu12==12.6.85",
"nvidia-nvtx-cu11==11.8.86",
"nvidia-nvtx-cu12==12.6.77",
"openai==1.59.6",
"packaging==23.2",
"pandas>=2.0.0",
"peft==0.7.1",
"pillow==10.4.0",
"platformdirs==4.3.7",
"prettytable==3.16.0",
"propcache==0.3.1",
"protobuf==4.25.6",
"psutil==5.9.8",
"py-cpuinfo==9.0.0",
"pyarrow==19.0.1",
"pydantic==2.11.7",
"pydantic_core==2.33.2",
"pydeck==0.9.1",
"pyecharts==2.0.8",
"Pygments==2.19.1",
"pynvml==12.0.0",
"pyparsing==3.2.3",
"python-dateutil==2.9.0.post0",
"pytz==2025.2",
"PyYAML==6.0.2",
"referencing==0.36.2",
"regex==2024.11.6",
"requests==2.32.3",
"rich==13.7.1",
"rouge-score>=0.1.2",
"rpds-py==0.24.0",
"s3transfer==0.13.0",
"safetensors==0.5.3",
"scikit-learn==1.5.1",
"scipy==1.15.2",
"sentence-transformers==2.3.1",
"sentencepiece==0.2.0",
"sentry-sdk==2.26.1",
"setproctitle==1.3.5",
"simhash==2.1.2",
"simplejson==3.20.1",
"six==1.17.0",
"smmap==5.0.2",
"sniffio==1.3.1",
"streamlit==1.30.0",
"swankit==0.2.4",
"swanlab==0.6.4",
"sympy==1.13.3",
"tenacity==8.5.0",
"threadpoolctl==3.6.0",
"tiktoken>=0.8.0",
"tokenizers==0.21.1",
"toml==0.10.2",
"torch==2.7.1",
"torchaudio==2.7.1",
"torchvision==0.22.1",
"tornado==6.4.2",
"tqdm==4.67.1",
"transformers==4.52.4",
"triton==3.3.1",
"trl==0.13.0",
"typing-inspection==0.4.1",
"typing_extensions==4.13.2",
"tzlocal==5.3.1",
"ujson==5.1.0",
"urllib3==2.4.0",
"validators==0.34.0",
"wandb==0.18.3",
"watchdog==6.0.0",
"wcwidth==0.2.13",
"Werkzeug==3.1.3",
"wrapt==1.17.2",
"xxhash==3.5.0",
"yarl==1.20.0",
"zipp==3.21.0",
]

View File

@ -1,4 +1,3 @@
accelerate==1.7.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.17
aiosignal==1.3.2
@ -8,8 +7,6 @@ anyio==4.9.0
async-timeout==5.0.1
attrs==25.3.0
blinker==1.9.0
boto3==1.38.41
botocore==1.38.41
cachetools==5.5.2
certifi==2025.1.31
charset-normalizer==3.4.1
@ -18,7 +15,6 @@ contourpy==1.3.2
cycler==0.12.1
datasets==2.21.0
datasketch==1.6.4
deepspeed==0.17.0
dill==0.3.8
distro==1.9.0
docker-pycreds==0.4.0
@ -37,19 +33,17 @@ hjson==3.1.0
httpcore==1.0.8
httpx==0.28.1
huggingface-hub==0.30.2
idna==3.10
importlib_metadata==7.2.1
itsdangerous==2.2.0
jieba==0.42.1
Jinja2==3.1.2
jiter==0.9.0
jmespath==1.0.1
joblib==1.4.2
jsonlines==4.0.0
jsonpointer==2.1
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
kiwisolver==1.4.8
langdetect==1.0.9
markdown-it-py==3.0.0
MarkupSafe==3.0.2
marshmallow==3.22.0
@ -66,50 +60,21 @@ ngrok==1.4.0
ninja==1.11.1.4
nltk==3.8
numpy==1.26.4
nvidia-cublas-cu11==11.11.3.6
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu11==9.1.0.70
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu11==10.9.0.58
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu11==10.3.0.86
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu11==11.7.5.86
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-ml-py==12.575.51
nvidia-nccl-cu11==2.21.5
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu11==11.8.86
nvidia-nvtx-cu12==12.6.77
openai==1.59.6
packaging==23.2
pandas==1.5.3
peft==0.7.1
pillow==10.4.0
platformdirs==4.3.7
prettytable==3.16.0
propcache==0.3.1
protobuf==4.25.6
psutil==5.9.8
py-cpuinfo==9.0.0
pyarrow==19.0.1
pydantic==2.11.7
pydantic_core==2.33.2
pydantic==2.8.2
pydantic_core==2.20.1
pydeck==0.9.1
pyecharts==2.0.8
Pygments==2.19.1
pynvml==12.0.0
pyparsing==3.2.3
python-dateutil==2.9.0.post0
pytz==2025.2
@ -119,7 +84,6 @@ regex==2024.11.6
requests==2.32.3
rich==13.7.1
rpds-py==0.24.0
s3transfer==0.13.0
safetensors==0.5.3
scikit-learn==1.5.1
scipy==1.15.2
@ -128,28 +92,21 @@ sentencepiece==0.2.0
sentry-sdk==2.26.1
setproctitle==1.3.5
simhash==2.1.2
simplejson==3.20.1
six==1.17.0
smmap==5.0.2
sniffio==1.3.1
streamlit==1.30.0
swankit==0.2.4
swanlab==0.6.4
sympy==1.13.3
tenacity==8.5.0
threadpoolctl==3.6.0
tiktoken==0.5.1
tokenizers==0.21.1
toml==0.10.2
torch==2.7.1
torchaudio==2.7.1
torchvision==0.22.1
tornado==6.4.2
tqdm==4.67.1
transformers==4.52.4
triton==3.3.1
transformers==4.48.0
triton==3.3.0
trl==0.13.0
typing-inspection==0.4.1
typing_extensions==4.13.2
tzlocal==5.3.1
ujson==5.1.0
@ -157,9 +114,7 @@ urllib3==2.4.0
validators==0.34.0
wandb==0.18.3
watchdog==6.0.0
wcwidth==0.2.13
Werkzeug==3.1.3
wrapt==1.17.2
xxhash==3.5.0
yarl==1.20.0
zipp==3.21.0

View File

@ -2,7 +2,7 @@
# 激活conda环境
# source $(conda info --base)/etc/profile.d/conda.sh
# conda activate mini
# conda activate ycz_accelerate
# 设置环境变量以帮助调试
export NCCL_DEBUG=INFO
@ -26,27 +26,24 @@ export PYTHONFAULTHANDLER=1
# --profile_interval 10
# 方法2: 使用命令行参数直接配置accelerate
# 内存泄漏调试配置 - 减少内存使用
CUDA_VISIBLE_DEVICES=0 uv run -p .venv python -m accelerate.commands.launch \
CUDA_VISIBLE_DEVICES=0 accelerate launch \
--num_processes=1 \
--mixed_precision=bf16 \
--main_process_port=29500 \
train_pretrain_accelerate.py
# --batch_size 128 \
# --num_workers 1
# --knowledge_num 48020 \
# --num_workers 1 \
# --epochs 4 \
# --learning_rate 2e-4 \
# --dtype bfloat16 \
# --accumulation_steps 32 \
# --grad_clip 1.0 \
# --log_interval 50 \
# --save_interval 10000 \
# --dim 512 \
# --n_layers 8 \
# --max_seq_len 512 \
# --use_flash_attn \
# --profile \
# --profile_interval 10
train_pretrain_accelerate.py \
--epochs 3 \
--batch_size 24 \
--learning_rate 2e-4 \
--dtype bfloat16 \
--accumulation_steps 32 \
--grad_clip 1.0 \
--log_interval 100 \
--save_interval 10000 \
--dim 512 \
--n_layers 12 \
--max_seq_len 512 \
--use_flash_attn \
--profile \
--profile_interval 10\
--knowledge_num 4096 \
--knowledge_length 8

View File

@ -1,46 +0,0 @@
#!/bin/bash
# 激活conda环境
source $(conda info --base)/etc/profile.d/conda.sh
conda activate ycz_accelerate
# 设置环境变量以帮助调试
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
# 实验1.3.0 - 使用命令行参数直接配置accelerate
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--multi_gpu \
--num_processes=4 \
--mixed_precision=bf16 \
--main_process_port=29500 \
train_pretrain_accelerate.py \
--out_dir "out" \
--epochs 3 \
--embedding_epoch 2 \
--batch_size 64 \
--learning_rate 8e-5 \
--dtype bfloat16 \
--use_swanlab \
--swanlab_project "MiniMind-Pretrain" \
--num_workers 1 \
--accumulation_steps 16 \
--grad_clip 0.5 \
--warmup_iters 0 \
--log_interval 100 \
--save_interval 10000 \
--dim 1024 \
--n_layers 48 \
--max_seq_len 512 \
--data_path "./dataset/stable/merged_pretrain.jsonl" \
--profile \
--profile_interval 10 \
--use_flash_attn \
--knowledge_num 1048576 \
--knowledge_length 32 \
--database_init_path "./dataset/stable/sentence_trex_data.json" \
--fast_clustering \
--cluster_cache_path "./cache/cluster_tokens_single.pt" \
--memory_monitor_interval 10 \
--model_type "model_original" \
--model_size 538

View File

@ -1,47 +0,0 @@
#!/bin/bash
# 激活conda环境
source $(conda info --base)/etc/profile.d/conda.sh
conda activate ycz_accelerate
# 设置环境变量以帮助调试
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
# 实验1.3.0 - 使用命令行参数直接配置accelerate
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--multi_gpu \
--num_processes=4 \
--mixed_precision=bf16 \
--main_process_port=29500 \
train_pretrain_accelerate.py \
--out_dir "out" \
--epochs 3 \
--embedding_epoch 2 \
--batch_size 48 \
--learning_rate 2e-4 \
--dtype bfloat16 \
--use_swanlab \
--swanlab_project "MiniMind-Pretrain" \
--num_workers 1 \
--accumulation_steps 32 \
--grad_clip 1.0 \
--warmup_iters 0 \
--log_interval 100 \
--save_interval 10000 \
--dim 1024 \
--n_layers 18 \
--max_seq_len 512 \
--use_moe False \
--data_path "./dataset/stable/merged_pretrain.jsonl" \
--profile \
--profile_interval 10 \
--use_flash_attn \
--knowledge_num 1048576 \
--knowledge_length 32 \
--database_init_path "./dataset/stable/sentence_trex_data.json" \
--fast_clustering \
--cluster_cache_path "./cache/cluster_tokens_single.pt" \
--memory_monitor_interval 10 \
--model_type "model_no_feed" \
--model_size 814.724

View File

@ -1,47 +0,0 @@
#!/bin/bash
# 激活conda环境
source $(conda info --base)/etc/profile.d/conda.sh
conda activate ycz_accelerate
# 设置环境变量以帮助调试
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
# 实验1.3.0 - 使用命令行参数直接配置accelerate
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--multi_gpu \
--num_processes=4 \
--mixed_precision=bf16 \
--main_process_port=29500 \
train_pretrain_accelerate.py \
--out_dir "out" \
--epochs 3 \
--embedding_epoch 2 \
--batch_size 48 \
--learning_rate 2e-4 \
--dtype bfloat16 \
--use_swanlab \
--swanlab_project "MiniMind-Pretrain" \
--num_workers 1 \
--accumulation_steps 32 \
--grad_clip 1.0 \
--warmup_iters 0 \
--log_interval 100 \
--save_interval 10000 \
--dim 1024 \
--n_layers 18 \
--max_seq_len 512 \
--use_moe False \
--data_path "./dataset/stable/merged_pretrain.jsonl" \
--profile \
--profile_interval 10 \
--use_flash_attn \
--knowledge_num 1048576 \
--knowledge_length 32 \
--database_init_path "./dataset/stable/sentence_trex_data.json" \
--fast_clustering \
--cluster_cache_path "./cache/cluster_tokens_single.pt" \
--memory_monitor_interval 10 \
--model_type "model" \
--model_size 814.724

View File

@ -1,45 +0,0 @@
#!/bin/bash
# 激活conda环境
# source $(conda info --base)/etc/profile.d/conda.sh
# conda activate ycz_accelerate
# 设置环境变量以帮助调试
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
# 实验1.3.0 - 使用命令行参数直接配置accelerate
CUDA_VISIBLE_DEVICES=0 accelerate launch \
--num_processes=1 \
--mixed_precision=bf16 \
--main_process_port=29500 \
train_pretrain_accelerate.py \
--out_dir "out" \
--epochs 3 \
--embedding_epoch 2 \
--batch_size 128 \
--learning_rate 8e-5 \
--dtype bfloat16 \
--use_swanlab \
--swanlab_project "MiniMind-Pretrain" \
--num_workers 1 \
--accumulation_steps 16 \
--grad_clip 0.5 \
--warmup_iters 0 \
--log_interval 100 \
--save_interval 10000 \
--dim 512 \
--n_layers 8 \
--max_seq_len 512 \
--data_path "./dataset/stable/merged_pretrain.jsonl" \
--profile \
--profile_interval 10 \
--use_flash_attn \
--knowledge_num 1048576 \
--knowledge_length 32 \
--database_init_path "./dataset/stable/sentence_trex_data.json" \
--fast_clustering \
--cluster_cache_path "./cache/cluster_tokens_single.pt" \
--memory_monitor_interval 10 \
--model_type "model" \
--model_size 538

View File

@ -32,7 +32,7 @@ def train_tokenizer():
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
# 定义特殊token
special_tokens = ["<unk>", "<|im_start|>", "<|im_end|>"]
special_tokens = ["<unk>", "<s>", "</s>"]
# 设置训练器并添加特殊token
trainer = trainers.BpeTrainer(
@ -53,8 +53,8 @@ def train_tokenizer():
# 检查特殊token的索引
assert tokenizer.token_to_id("<unk>") == 0
assert tokenizer.token_to_id("<|im_start|>") == 1
assert tokenizer.token_to_id("<|im_end|>") == 2
assert tokenizer.token_to_id("<s>") == 1
assert tokenizer.token_to_id("</s>") == 2
# 保存tokenizer
tokenizer_dir = "../model/minimind_tokenizer"
@ -77,7 +77,7 @@ def train_tokenizer():
"special": True
},
"1": {
"content": "<|im_start|>",
"content": "<s>",
"lstrip": False,
"normalized": False,
"rstrip": False,
@ -85,7 +85,7 @@ def train_tokenizer():
"special": True
},
"2": {
"content": "<|im_end|>",
"content": "</s>",
"lstrip": False,
"normalized": False,
"rstrip": False,
@ -94,9 +94,9 @@ def train_tokenizer():
}
},
"additional_special_tokens": [],
"bos_token": "<|im_start|>",
"bos_token": "<s>",
"clean_up_tokenization_spaces": False,
"eos_token": "<|im_end|>",
"eos_token": "</s>",
"legacy": True,
"model_max_length": 32768,
"pad_token": "<unk>",
@ -104,7 +104,7 @@ def train_tokenizer():
"spaces_between_special_tokens": False,
"tokenizer_class": "PreTrainedTokenizerFast",
"unk_token": "<unk>",
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\n你是 MiniMind是一个有用的人工智能助手。<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<s>system\\n' + system_message + '</s>\\n' }}{% else %}{{ '<s>system\\n你是 MiniMind是一个有用的人工智能助手。</s>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
}
# 保存配置文件
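For reference, the chat_template above is plain Jinja, which transformers evaluates through tokenizer.apply_chat_template. A minimal standalone sketch (assuming jinja2 is installed; the template string is copied from the <s>/</s> variant above, and only the harness around it is new) shows the prompt it produces:

from jinja2 import Template

chat_template = (
    "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}"
    "{{ '<s>system\\n' + system_message + '</s>\\n' }}"
    "{% else %}{{ '<s>system\\n你是 MiniMind是一个有用的人工智能助手。</s>\\n' }}{% endif %}"
    "{% for message in messages %}{% set content = message['content'] %}"
    "{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}"
    "{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
)

messages = [
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "你好，有什么可以帮你？"},
]
print(Template(chat_template).render(messages=messages))
# Output: the default system block, then "<s>user\n你好</s>\n<s>assistant\n",
# then the assistant reply terminated by "</s>".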

View File

@ -1,33 +0,0 @@
#!/bin/bash
set -e
# 在容器启动后,首先从 requirements.txt 安装所有依赖包
# pip install -r requirements.txt
# bash install.sh -y
python3 -m pip install --upgrade pip
pip install uv -i https://pypi.tuna.tsinghua.edu.cn/simple
# 切换到项目目录
cd /ycz/Minimind
# 检查并修复虚拟环境
if [ ! -f .venv/bin/python ] || [ ! -x .venv/bin/python ]; then
echo "Virtual environment is broken or missing, recreating with uv..."
rm -rf .venv
uv venv .venv
fi
# 不要手动激活虚拟环境让uv自动管理
# . ./.venv/bin/activate
# 使用uv同步依赖
uv sync
# 安装完成后,执行主训练脚本
# "$@" 会将 experiment.yaml 中 entrypoint 定义的参数传递给 python 脚本
CUDA_VISIBLE_DEVICES=0 uv run python -m accelerate.commands.launch \
--num_processes=1 \
--mixed_precision=bf16 \
--main_process_port=29500 \
train_pretrain_accelerate.py "$@"

97
test_real_rope.py Normal file
View File

@ -0,0 +1,97 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
测试实数版本的位置编码
"""
import torch
from model.model import precompute_pos_cis, precompute_pos_cis_real, apply_rotary_emb, apply_rotary_emb_real
from model.LMConfig import LMConfig
from model.model import MiniMindLM
def test_pos_encoding_equivalence():
"""测试复数版本和实数版本的位置编码是否等价"""
print("测试位置编码等价性...")
# 参数设置
dim = 64
seq_len = 10
# 生成复数版本的位置编码
pos_cis = precompute_pos_cis(dim=dim, end=seq_len)
# 生成实数版本的位置编码
pos_cis_real = precompute_pos_cis_real(dim=dim, end=seq_len)
# 创建随机查询和键
batch_size = 2
n_heads = 4
head_dim = dim
xq = torch.randn(batch_size, seq_len, n_heads, head_dim)
xk = torch.randn(batch_size, seq_len, n_heads, head_dim)
# 应用复数版本的旋转位置编码
xq_complex, xk_complex = apply_rotary_emb(xq, xk, pos_cis)
# 应用实数版本的旋转位置编码
xq_real, xk_real = apply_rotary_emb_real(xq, xk, pos_cis_real)
# 计算差异
q_diff = torch.abs(xq_complex - xq_real).mean().item()
k_diff = torch.abs(xk_complex - xk_real).mean().item()
print(f"查询差异: {q_diff:.6f}")
print(f"键差异: {k_diff:.6f}")
# 检查差异是否在可接受范围内
tolerance = 1e-5
if q_diff < tolerance and k_diff < tolerance:
print("✅ 测试通过: 复数版本和实数版本的位置编码在数值上等价")
else:
print("❌ 测试失败: 复数版本和实数版本的位置编码存在显著差异")
def test_model_forward():
"""测试模型前向传播"""
print("\n测试模型前向传播...")
# 创建模型配置
config = LMConfig(
dim=128,
n_layers=2,
n_heads=4,
n_kv_heads=4, # 确保n_kv_heads被设置且n_heads能被n_kv_heads整除
vocab_size=1000,
max_seq_len=128,
disable_db=True # 禁用数据库功能,避免额外的复杂性
)
# 创建模型
try:
model = MiniMindLM(config)
print(f"✅ 模型初始化成功")
except Exception as e:
print(f"❌ 模型初始化失败: {str(e)}")
return
# 创建输入
batch_size = 2
seq_len = 10
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
# 前向传播
try:
with torch.no_grad():
outputs = model(input_ids)
print(f"✅ 模型前向传播成功")
print(f"输出形状: {outputs.logits.shape}")
except Exception as e:
print(f"❌ 模型前向传播失败: {str(e)}")
if __name__ == "__main__":
# 测试位置编码等价性
test_pos_encoding_equivalence()
# 测试模型前向传播
test_model_forward()
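The helpers exercised here (precompute_pos_cis_real, apply_rotary_emb_real) live in model/model.py, which this diff does not show. As orientation only, a generic real-valued rotary embedding of the kind this test compares against the complex version can be sketched as below; this is an assumption-laden illustration (conventional base 10000, adjacent even/odd channel pairing), not the repository's implementation:

import torch

def precompute_cos_sin(dim: int, end: int, base: float = 10000.0):
    # One frequency per channel pair, outer-multiplied with the positions.
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end).float()
    angles = torch.outer(t, freqs)        # (end, dim // 2)
    return torch.cos(angles), torch.sin(angles)

def apply_rope_real(x, cos, sin):
    # x: (batch, seq_len, n_heads, head_dim); rotate each (even, odd) channel pair.
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    cos = cos[None, :, None, :]           # broadcast over batch and heads
    sin = sin[None, :, None, :]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

# Shape check mirroring test_pos_encoding_equivalence above; values will differ
# from the repo's functions if they pair channels differently.
q = torch.randn(2, 10, 4, 64)
cos, sin = precompute_cos_sin(dim=64, end=10)
print(apply_rope_real(q, cos, sin).shape)  # torch.Size([2, 10, 4, 64])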

File diff suppressed because it is too large

View File

@ -1,9 +1,8 @@
import os
# 设置环境变量 - 将wandb替换为SwanLab
# os.environ["SWANLAB_MODE"] = "online" # SwanLab使用在线模式
# 设置环境变量
os.environ["WANDB_MODE"] = "offline" # 或者使用 "dryrun"
import platform
import argparse
from tqdm import tqdm
import time
import math
import warnings
@ -19,75 +18,13 @@ from accelerate.utils import set_seed
from accelerate.utils import DeepSpeedPlugin
from accelerate.utils import DistributedDataParallelKwargs
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import swanlab # 替换wandb导入
import gc # 添加垃圾回收模块
import psutil # 添加系统资源监控模块
from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import PretrainDataset
warnings.filterwarnings('ignore')
# 内存监控辅助函数
def get_memory_usage():
"""获取当前内存使用情况"""
process = psutil.Process()
memory_info = process.memory_info()
return {
'rss_mb': memory_info.rss / 1024 / 1024, # 物理内存使用量MB
'vms_mb': memory_info.vms / 1024 / 1024, # 虚拟内存使用量MB
}
def get_cuda_memory_usage():
"""获取CUDA内存使用情况"""
if torch.cuda.is_available():
return {
'cuda_allocated_mb': torch.cuda.memory_allocated() / 1024 / 1024,
'cuda_reserved_mb': torch.cuda.memory_reserved() / 1024 / 1024,
'cuda_max_allocated_mb': torch.cuda.max_memory_allocated() / 1024 / 1024,
}
return {}
def get_tensor_memory_size(tensor_list):
"""计算tensor列表的总内存占用MB"""
total_size = 0
for batch in tensor_list:
if isinstance(batch, (list, tuple)):
for tensor in batch:
if isinstance(tensor, torch.Tensor):
total_size += tensor.numel() * tensor.element_size()
elif isinstance(batch, torch.Tensor):
total_size += batch.numel() * batch.element_size()
return total_size / 1024 / 1024 # 转换为MB
def log_memory_status(step, prefetch_batches, accelerator, stage="", detailed=False):
"""记录内存状态"""
if not accelerator.is_main_process:
return
memory_info = get_memory_usage()
cuda_info = get_cuda_memory_usage()
prefetch_memory = get_tensor_memory_size(prefetch_batches)
log_msg = f"[Memory Monitor] Step {step} {stage} - "
log_msg += f"Prefetch batches: {len(prefetch_batches)}, "
log_msg += f"Prefetch memory: {prefetch_memory:.2f}MB, "
log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB"
if cuda_info:
log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB"
if detailed:
log_msg += f", System VMS: {memory_info['vms_mb']:.2f}MB"
if cuda_info:
log_msg += f", CUDA max allocated: {cuda_info['cuda_max_allocated_mb']:.2f}MB"
Logger(log_msg, accelerator)
# 日志记录函数
def Logger(msg, accelerator=None):
# 如果没有提供accelerator则只在主进程打印
@ -104,420 +41,27 @@ def get_lr(it, num_iters, learning_rate):
return learning_rate * 0.5 * (1.0 + math.cos(math.pi * it / num_iters))
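# Worked example (illustrative numbers, not from this diff): with
# learning_rate=2e-4 and num_iters=1000, this cosine schedule gives
# get_lr(0)=2e-4, get_lr(500)=1e-4 and get_lr(1000)=0, i.e. plain cosine
# decay with no warmup term.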
# 初始化模型函数
def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None):
if args.model_type == "model":
Logger(f"Using model type: {args.model_type}")
from model.model import MiniMindLM, RMSNorm
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
model = MiniMindLM(lm_config)
# 默认模型初始化
Logger("Performing default model initialization...")
# 初始化嵌入层权重
nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02)
# 初始化输出层权重(如果不共享权重的话)
if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight:
nn.init.normal_(model.output.weight, mean=0.0, std=0.02)
# 初始化所有线性层
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
# 使用Xavier/Glorot初始化
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
# 嵌入层使用正态分布初始化
nn.init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, RMSNorm):
# RMSNorm的权重初始化为1
if hasattr(module, 'weight'):
nn.init.ones_(module.weight)
# 初始化位置编码相关参数
if hasattr(model.knowledge_dataset, 'keys'):
nn.init.normal_(model.knowledge_dataset.keys, mean=0.0, std=0.02)
Logger("Default model initialization completed")
def init_model(lm_config, pretrained_embedding_path=None):
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
model = MiniMindLM(lm_config)
# 如果提供了预训练的嵌入权重,加载它们
if pretrained_embedding_path:
Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
pretrained_embeddings = torch.load(pretrained_embedding_path)
model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
model.output.weight.data.copy_(pretrained_embeddings) # 共享权重
if database_init_path:
import json
import os
# 数据库参数
knowledge_num = args.knowledge_num
knowledge_length = args.knowledge_length
# 检查是否使用缓存
cache_dir = os.path.dirname(args.cluster_cache_path)
if cache_dir:
os.makedirs(cache_dir, exist_ok=True)
processed_tensor = None
# 尝试加载缓存的处理结果
if not args.recompute_clusters and os.path.exists(args.cluster_cache_path):
try:
Logger(f"Loading cached processed results from {args.cluster_cache_path}")
processed_tensor = torch.load(args.cluster_cache_path)
# 验证缓存文件的形状是否可用
cached_knowledge_num, cached_knowledge_length = processed_tensor.shape
if cached_knowledge_length == knowledge_length:
if cached_knowledge_num >= knowledge_num:
# 缓存足够大,可以截取使用
processed_tensor = processed_tensor[:knowledge_num, :]
Logger(f"Successfully loaded cached data with shape {processed_tensor.shape}")
Logger(f"Truncated from cached shape ({cached_knowledge_num}, {cached_knowledge_length}) to required shape ({knowledge_num}, {knowledge_length})")
Logger("Skipping database initialization - using cached results")
else:
# 缓存太小,需要重新计算
Logger(f"Cached knowledge_num ({cached_knowledge_num}) < required knowledge_num ({knowledge_num}), recomputing...")
processed_tensor = None
else:
# knowledge_length不匹配需要重新计算
Logger(f"Cached knowledge_length ({cached_knowledge_length}) != required knowledge_length ({knowledge_length}), recomputing...")
processed_tensor = None
except Exception as e:
Logger(f"Failed to load cached data: {e}, recomputing...")
processed_tensor = None
# 只有在没有有效缓存时才进行数据库初始化和处理
if processed_tensor is None:
Logger(f"Loading database initialization data from {database_init_path}")
# 1. 加载JSON文件
with open(database_init_path, 'r', encoding='utf-8') as f:
database_data = json.load(f)
sentences_data = []
for data in database_data:
sentences_data.append(data['target'][0]['sentence'])
# 提取sentences列表
# sentences_data = database_data.get('sentences', [])
Logger(f"Loaded {len(sentences_data)} sentences from database")
# 2. 按照importance_score进行排序从高到低
try:
sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")
except:
sorted_sentences = sentences_data
# 3. 处理每条数据,不进行聚类
Logger("Processing individual sentences...")
processed_rows = []
# 获取空token的id用于填充
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
# 处理所需数量的句子
num_to_process = min(knowledge_num, len(sorted_sentences))
# 添加截断统计变量
total_sentences = 0
truncated_sentences = 0
for i in range(num_to_process):
sentence_data = sorted_sentences[i]
try:
sentence = sentence_data.get('corrected_sentence')
except:
sentence = sentence_data
# 将句子转换为tokens
sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
# 截断或填充到knowledge_length
total_sentences += 1
if len(sentence_tokens) > knowledge_length:
# 如果超过长度,截断
truncated_sentences += 1
sentence_tokens = sentence_tokens[:knowledge_length]
Logger(f"Sentence {i+1} truncated from {len(tokenizer.encode(sentence, add_special_tokens=False))} to {knowledge_length} tokens")
else:
# 如果不足长度用空token填充
original_length = len(sentence_tokens)
sentence_tokens.extend([pad_token_id] * (knowledge_length - len(sentence_tokens)))
if original_length < knowledge_length:
Logger(f"Sentence {i+1} padded from {original_length} to {knowledge_length} tokens")
processed_rows.append(sentence_tokens)
if (i + 1) % 1000 == 0:
Logger(f"Processed {i + 1}/{num_to_process} sentences")
# 如果句子数量不足用空token填充剩余位置
while len(processed_rows) < knowledge_num:
empty_tokens = [pad_token_id] * knowledge_length
processed_rows.append(empty_tokens)
if len(processed_rows) % 1000 == 0:
Logger(f"Added empty entry {len(processed_rows)}/{knowledge_num}")
Logger(f"Finished adding empty entries. Total: {len(processed_rows)}/{knowledge_num}")
# 转换为tensor
processed_tensor = torch.tensor(processed_rows, dtype=torch.long)
# 计算并打印截断句子的占比
truncation_ratio = truncated_sentences / total_sentences if total_sentences > 0 else 0.0
Logger(f"截断句子统计:")
Logger(f" - 总句子数: {total_sentences}")
Logger(f" - 截断句子数: {truncated_sentences}")
Logger(f" - 截断句子占比: {truncation_ratio:.4f} ({truncation_ratio*100:.2f}%)")
Logger(f"Data processing completed:")
Logger(f" - Processed {num_to_process} sentences")
Logger(f" - Added {knowledge_num - num_to_process} empty entries")
Logger(f" - Final shape: {processed_tensor.shape}")
Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")
# 保存处理结果到缓存文件
try:
torch.save(processed_tensor, args.cluster_cache_path)
Logger(f"Processed results saved to {args.cluster_cache_path}")
except Exception as e:
Logger(f"Failed to save processed results: {e}")
# 4. 初始化模型的knowledge_dataset
if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
model.knowledge_dataset.knowledge_dataset.data.copy_(processed_tensor)
Logger("Successfully initialized model.knowledge_dataset.knowledge_dataset with processed data")
else:
Logger("Warning: Could not find model.knowledge_dataset.knowledge_dataset to initialize")
# 存储为全局变量作为备选
globals()['processed_database'] = processed_tensor
Logger(f"Database embeddings and sentences stored in model")
Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
elif args.model_type == "model_original":
Logger(f"Using model type: {args.model_type}")
from model.model_original import MiniMindLM, RMSNorm
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
model = MiniMindLM(lm_config)
Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
elif args.model_type == "model_no_feed":
Logger(f"Using model type: {args.model_type}")
from model.model_no_feed import MiniMindLM, RMSNorm
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
model = MiniMindLM(lm_config)
# 默认模型初始化
Logger("Performing default model initialization...")
# 初始化嵌入层权重
nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02)
# 初始化输出层权重(如果不共享权重的话)
if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight:
nn.init.normal_(model.output.weight, mean=0.0, std=0.02)
# 初始化所有线性层
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
# 使用Xavier/Glorot初始化
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
# 嵌入层使用正态分布初始化
nn.init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, RMSNorm):
# RMSNorm的权重初始化为1
if hasattr(module, 'weight'):
nn.init.ones_(module.weight)
# 初始化位置编码相关参数
if hasattr(model.knowledge_dataset, 'keys'):
nn.init.normal_(model.knowledge_dataset.keys, mean=0.0, std=0.02)
Logger("Default model initialization completed")
# 如果提供了预训练的嵌入权重,加载它们
if pretrained_embedding_path:
Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
pretrained_embeddings = torch.load(pretrained_embedding_path)
model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
model.output.weight.data.copy_(pretrained_embeddings) # 共享权重
if database_init_path:
import json
import os
# 数据库参数
knowledge_num = args.knowledge_num
knowledge_length = args.knowledge_length
# 检查是否使用缓存
cache_dir = os.path.dirname(args.cluster_cache_path)
if cache_dir:
os.makedirs(cache_dir, exist_ok=True)
processed_tensor = None
# 尝试加载缓存的处理结果
if not args.recompute_clusters and os.path.exists(args.cluster_cache_path):
try:
Logger(f"Loading cached processed results from {args.cluster_cache_path}")
processed_tensor = torch.load(args.cluster_cache_path)
# 验证缓存文件的形状是否可用
cached_knowledge_num, cached_knowledge_length = processed_tensor.shape
if cached_knowledge_length == knowledge_length:
if cached_knowledge_num >= knowledge_num:
# 缓存足够大,可以截取使用
processed_tensor = processed_tensor[:knowledge_num, :]
Logger(f"Successfully loaded cached data with shape {processed_tensor.shape}")
Logger(f"Truncated from cached shape ({cached_knowledge_num}, {cached_knowledge_length}) to required shape ({knowledge_num}, {knowledge_length})")
Logger("Skipping database initialization - using cached results")
else:
# 缓存太小,需要重新计算
Logger(f"Cached knowledge_num ({cached_knowledge_num}) < required knowledge_num ({knowledge_num}), recomputing...")
processed_tensor = None
else:
# knowledge_length不匹配需要重新计算
Logger(f"Cached knowledge_length ({cached_knowledge_length}) != required knowledge_length ({knowledge_length}), recomputing...")
processed_tensor = None
except Exception as e:
Logger(f"Failed to load cached data: {e}, recomputing...")
processed_tensor = None
# 只有在没有有效缓存时才进行数据库初始化和处理
if processed_tensor is None:
Logger(f"Loading database initialization data from {database_init_path}")
# 1. 加载JSON文件
with open(database_init_path, 'r', encoding='utf-8') as f:
database_data = json.load(f)
sentences_data = []
for data in database_data:
sentences_data.append(data['target'][0]['sentence'])
# 提取sentences列表
# sentences_data = database_data.get('sentences', [])
Logger(f"Loaded {len(sentences_data)} sentences from database")
# 2. 按照importance_score进行排序从高到低
try:
sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")
except:
sorted_sentences = sentences_data
# 3. 处理每条数据,不进行聚类
Logger("Processing individual sentences...")
processed_rows = []
# 获取空token的id用于填充
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
# 处理所需数量的句子
num_to_process = min(knowledge_num, len(sorted_sentences))
# 添加截断统计变量
total_sentences = 0
truncated_sentences = 0
for i in range(num_to_process):
sentence_data = sorted_sentences[i]
try:
sentence = sentence_data.get('corrected_sentence')
except:
sentence = sentence_data
# 将句子转换为tokens
sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
# 截断或填充到knowledge_length
total_sentences += 1
if len(sentence_tokens) > knowledge_length:
# 如果超过长度,截断
truncated_sentences += 1
sentence_tokens = sentence_tokens[:knowledge_length]
Logger(f"Sentence {i+1} truncated from {len(tokenizer.encode(sentence, add_special_tokens=False))} to {knowledge_length} tokens")
else:
# 如果不足长度用空token填充
original_length = len(sentence_tokens)
sentence_tokens.extend([pad_token_id] * (knowledge_length - len(sentence_tokens)))
if original_length < knowledge_length:
Logger(f"Sentence {i+1} padded from {original_length} to {knowledge_length} tokens")
processed_rows.append(sentence_tokens)
if (i + 1) % 1000 == 0:
Logger(f"Processed {i + 1}/{num_to_process} sentences")
# 如果句子数量不足用空token填充剩余位置
while len(processed_rows) < knowledge_num:
empty_tokens = [pad_token_id] * knowledge_length
processed_rows.append(empty_tokens)
if len(processed_rows) % 1000 == 0:
Logger(f"Added empty entry {len(processed_rows)}/{knowledge_num}")
Logger(f"Finished adding empty entries. Total: {len(processed_rows)}/{knowledge_num}")
# 转换为tensor
processed_tensor = torch.tensor(processed_rows, dtype=torch.long)
# 计算并打印截断句子的占比
truncation_ratio = truncated_sentences / total_sentences if total_sentences > 0 else 0.0
Logger(f"截断句子统计:")
Logger(f" - 总句子数: {total_sentences}")
Logger(f" - 截断句子数: {truncated_sentences}")
Logger(f" - 截断句子占比: {truncation_ratio:.4f} ({truncation_ratio*100:.2f}%)")
Logger(f"Data processing completed:")
Logger(f" - Processed {num_to_process} sentences")
Logger(f" - Added {knowledge_num - num_to_process} empty entries")
Logger(f" - Final shape: {processed_tensor.shape}")
Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")
# 保存处理结果到缓存文件
try:
torch.save(processed_tensor, args.cluster_cache_path)
Logger(f"Processed results saved to {args.cluster_cache_path}")
except Exception as e:
Logger(f"Failed to save processed results: {e}")
# 4. 初始化模型的knowledge_dataset
if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
model.knowledge_dataset.knowledge_dataset.data.copy_(processed_tensor)
Logger("Successfully initialized model.knowledge_dataset.knowledge_dataset with processed data")
else:
Logger("Warning: Could not find model.knowledge_dataset.knowledge_dataset to initialize")
# 存储为全局变量作为备选
globals()['processed_database'] = processed_tensor
Logger(f"Database embeddings and sentences stored in model")
Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
# 如果提供了预训练的嵌入权重,加载它们
if pretrained_embedding_path:
Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
pretrained_embeddings = torch.load(pretrained_embedding_path)
model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
model.output.weight.data.copy_(pretrained_embeddings) # 共享权重
Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
return model, tokenizer
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer):
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb):
loss_fct = nn.CrossEntropyLoss(reduction='none')
epoch_start_time = time.time()
total_steps_in_epoch = len(train_loader)
total_training_steps = args.epochs * total_steps_in_epoch
moe_path = '_moe' if args.use_moe else ''
best_loss = float('10000')
# 初始化CUDA事件变量
data_start = data_end = forward_start = forward_end = None
backward_start = backward_end = optimizer_start = optimizer_end = None
# 添加CUDA事件来分析性能 (只在主进程进行)
if args.profile and accelerator.is_main_process:
data_start = torch.cuda.Event(enable_timing=True)
@ -530,67 +74,44 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
optimizer_end = torch.cuda.Event(enable_timing=True)
# 预取数据
prefetch_factor = 8 # 预取的批次数
prefetch_factor = 2 # 预取的批次数
data_iter = iter(train_loader)
prefetch_batches = []
# 记录初始内存状态
if args.memory_monitor:
log_memory_status(-1, prefetch_batches, accelerator, "before_prefetch", detailed=True)
# 预取初始批次
for i in range(min(prefetch_factor, len(train_loader))):
for _ in range(min(prefetch_factor, len(train_loader))):
try:
batch = next(data_iter)
prefetch_batches.append(batch)
# 每次添加batch后记录内存变化
if args.memory_monitor and accelerator.is_main_process:
log_memory_status(-1, prefetch_batches, accelerator, f"after_adding_batch_{i+1}")
except StopIteration:
break
# 记录预取完成后的内存状态
if args.memory_monitor:
log_memory_status(-1, prefetch_batches, accelerator, "after_initial_prefetch", detailed=True)
# 在开始循环前初始化日志记录所需变量
last_log_time = epoch_start_time
for step in range(total_steps_in_epoch):
try:
# 计时数据加载 (只在主进程进行)
if args.profile and accelerator.is_main_process and data_start is not None:
if args.profile and accelerator.is_main_process:
data_start.record()
# 记录使用batch前的内存状态根据配置间隔记录详细信息
if args.memory_monitor and step % args.memory_monitor_interval == 0:
log_memory_status(step, prefetch_batches, accelerator, "before_use_batch", detailed=True)
# 使用预取的数据
if prefetch_batches:
X, Y, loss_mask = prefetch_batches.pop(0)
# 记录使用batch后的内存变化
if args.memory_monitor and step % args.memory_monitor_interval == 0:
log_memory_status(step, prefetch_batches, accelerator, "after_pop_batch")
else:
# 如果预取队列为空,直接加载
X, Y, loss_mask = next(data_iter)
if args.memory_monitor and accelerator.is_main_process:
Logger(f"[Memory Monitor] Step {step} - Prefetch queue empty, loading directly!", accelerator)
# 异步预取下一批数据
if step + prefetch_factor < len(train_loader):
try:
batch = next(data_iter)
prefetch_batches.append(batch)
# 记录添加新batch后的内存变化
if args.memory_monitor and step % args.memory_monitor_interval == 0:
log_memory_status(step, prefetch_batches, accelerator, "after_add_batch")
except StopIteration:
pass
# 计时数据加载结束 (只在主进程进行)
if args.profile and accelerator.is_main_process and data_end is not None:
if args.profile and accelerator.is_main_process:
data_end.record()
# 更新学习率
@ -598,31 +119,33 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
scheduler.step()
# 计时前向传播 (只在主进程进行)
if args.profile and accelerator.is_main_process and forward_start is not None:
if args.profile and accelerator.is_main_process:
forward_start.record()
# 前向传播
with ctx:
if step == 0 and args.embedding_epoch == epoch:
# 需要设置原始模型的freeze_embedding属性而不是包装后的模型
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.freeze_embedding = True
Logger(f"Set freeze_embedding=True for epoch {epoch}, step {step}", accelerator)
res = model(X, step=step)
res = model(X)
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
).view(Y.size())
loss = (loss * loss_mask).sum() / loss_mask.sum()
# 移除辅助损失计算,统一不使用 aux_loss
# 添加辅助损失,如果存在的话
try:
aux_loss = sum(l.feed_forward.aux_loss for l in model.module.layers
if hasattr(l.feed_forward, 'aux_loss'))
loss += aux_loss
except Exception as e:
Logger(f"Warning: Could not add auxiliary loss: {e}")
# 如果出错,不添加辅助损失
loss = loss / args.accumulation_steps
# 计时前向传播结束 (只在主进程进行)
if args.profile and accelerator.is_main_process and forward_end is not None:
if args.profile and accelerator.is_main_process:
forward_end.record()
# 计时反向传播 (只在主进程进行)
if args.profile and accelerator.is_main_process and backward_start is not None:
if args.profile and accelerator.is_main_process:
backward_start.record()
# 反向传播
@ -630,11 +153,11 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
accelerator.backward(loss)
# 计时反向传播结束 (只在主进程进行)
if args.profile and accelerator.is_main_process and backward_end is not None:
if args.profile and accelerator.is_main_process:
backward_end.record()
# 计时优化器步骤 (只在主进程进行)
if args.profile and accelerator.is_main_process and optimizer_start is not None:
if args.profile and accelerator.is_main_process:
optimizer_start.record()
# 优化器步骤 - 当使用DeepSpeed时它会自动处理梯度累积和梯度裁剪
@ -647,111 +170,40 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
optimizer.zero_grad()
# 计时优化器步骤结束 (只在主进程进行)
if args.profile and accelerator.is_main_process and optimizer_end is not None:
if args.profile and accelerator.is_main_process:
optimizer_end.record()
# 打印训练信息 (只在主进程进行)
if (step + 1) % args.log_interval == 0 and accelerator.is_main_process:
current_time = time.time()
# 记录日志输出时的详细内存状态
if args.memory_monitor:
log_memory_status(step, prefetch_batches, accelerator, "at_log_interval", detailed=True)
# 强制垃圾回收并记录内存变化
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
log_memory_status(step, prefetch_batches, accelerator, "after_gc", detailed=True)
# 计算性能指标
if args.profile and accelerator.is_main_process:
if args.profile:
torch.cuda.synchronize()
# 确保所有事件都已记录才计算elapsed_time
try:
data_time = data_start.elapsed_time(data_end) if data_start is not None and data_end is not None else 0
forward_time = forward_start.elapsed_time(forward_end) if forward_start is not None and forward_end is not None else 0
backward_time = backward_start.elapsed_time(backward_end) if backward_start is not None and backward_end is not None else 0
optimizer_time = optimizer_start.elapsed_time(optimizer_end) if optimizer_start is not None and optimizer_end is not None else 0
iter_time = (current_time - last_log_time) * 1000 / args.log_interval # avg ms per iteration since last log
# total_time_ms = data_time + forward_time + backward_time + optimizer_time
# 使用自上次日志以来的时间计算性能指标,而不是总时间
data_time = data_start.elapsed_time(data_end)
forward_time = forward_start.elapsed_time(forward_end)
backward_time = backward_start.elapsed_time(backward_end)
optimizer_time = optimizer_start.elapsed_time(optimizer_end)
iter_time = (current_time - last_log_time) * 1000 / args.log_interval # avg ms per iteration since last log
# total_time_ms = data_time + forward_time + backward_time + optimizer_time
# 打印性能分析
if (step + 1) % (args.log_interval * args.profile_interval) == 0:
Logger(f"性能分析 (Avg/iter over last {args.log_interval} steps) - "
f"Data: {data_time/args.log_interval:.2f}ms, "
f"Fwd: {forward_time/args.log_interval:.2f}ms, "
f"Bwd: {backward_time/args.log_interval:.2f}ms, "
f"Optim: {optimizer_time/args.log_interval:.2f}ms, "
f"Iter Time: {iter_time:.2f}ms", accelerator)
# 生成文本示例
try:
# 随机选择一个样本
random_idx = torch.randint(0, X.size(0), (1,)).item()
sample_input = X[random_idx:random_idx+1] # [1, seq_len]
sample_target = Y[random_idx:random_idx+1] # [1, seq_len]
# 取前面的部分作为prompt确保后面有10个token作为真实值
prompt_len = sample_input.size(1) // 2
prompt_input = sample_input[:, :prompt_len]
# 获取真实的后10个token
true_next_tokens = sample_target[:, prompt_len-1:prompt_len-1+10] # 真实的接下来10个token
# 生成10个token
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.eval() # 设置为评估模式
with torch.no_grad():
generated = unwrapped_model.generate(
prompt_input,
max_new_tokens=10,
temperature=0.7,
top_p=0.9,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id
)
# 转换为人类可读文本
prompt_text = tokenizer.decode(prompt_input[0], skip_special_tokens=True)
true_text = tokenizer.decode(true_next_tokens[0], skip_special_tokens=True)
# 获取新生成的token
prompt_tokens = prompt_input[0].tolist()
generated_tokens = generated[0].tolist()
if len(generated_tokens) > len(prompt_tokens):
new_tokens = generated_tokens[len(prompt_tokens):len(prompt_tokens)+10] # 只取前10个
generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
else:
generated_text = "[未生成新token]"
Logger(f"文本生成对比:", accelerator)
Logger(f" 输入提示: {prompt_text}", accelerator)
Logger(f" 真实续写: {true_text}", accelerator)
Logger(f" 模型生成: {generated_text}", accelerator)
unwrapped_model.train() # 恢复训练模式
except Exception as e:
Logger(f"生成文本示例失败: {e}", accelerator)
# 重置事件以便下次测量从0开始
data_start = torch.cuda.Event(enable_timing=True)
data_end = torch.cuda.Event(enable_timing=True)
forward_start = torch.cuda.Event(enable_timing=True)
forward_end = torch.cuda.Event(enable_timing=True)
backward_start = torch.cuda.Event(enable_timing=True)
backward_end = torch.cuda.Event(enable_timing=True)
optimizer_start = torch.cuda.Event(enable_timing=True)
optimizer_end = torch.cuda.Event(enable_timing=True)
except RuntimeError as e:
if "Both events must be recorded" in str(e):
Logger(f"Warning: CUDA events not properly recorded, skipping performance analysis: {e}", accelerator)
else:
raise e
# 打印性能分析
if (step + 1) % (args.log_interval * args.profile_interval) == 0:
Logger(f"性能分析 (Avg/iter over last {args.log_interval} steps) - "
f"Data: {data_time/args.log_interval:.2f}ms, "
f"Fwd: {forward_time/args.log_interval:.2f}ms, "
f"Bwd: {backward_time/args.log_interval:.2f}ms, "
f"Optim: {optimizer_time/args.log_interval:.2f}ms, "
f"Iter Time: {iter_time:.2f}ms", accelerator)
# 重置事件以便下次测量从0开始
data_start = torch.cuda.Event(enable_timing=True)
data_end = torch.cuda.Event(enable_timing=True)
forward_start = torch.cuda.Event(enable_timing=True)
forward_end = torch.cuda.Event(enable_timing=True)
backward_start = torch.cuda.Event(enable_timing=True)
backward_end = torch.cuda.Event(enable_timing=True)
optimizer_start = torch.cuda.Event(enable_timing=True)
optimizer_end = torch.cuda.Event(enable_timing=True)
# 计算当前学习率
@ -792,13 +244,11 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
f"Total Time Left: {format_time(total_remaining_time)}", accelerator)
if args.use_swanlab and accelerator.is_main_process and swanlab_run:
swanlab_run.log(log_dict)
if args.use_wandb and accelerator.is_main_process and wandb:
wandb.log(log_dict)
# 保存模型 (只在主进程进行)
loss_total = loss.item() * args.accumulation_steps
if epoch > 1 and best_loss > loss_total and accelerator.is_main_process:
best_loss = loss_total
if (step + 1) % args.save_interval == 0 and accelerator.is_main_process:
# 使用函数开始处定义的moe_path变量
ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
@ -811,73 +261,38 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
except Exception as e:
Logger(f"Error in training step: {e}", accelerator)
# 记录异常时的内存状态
if args.memory_monitor:
log_memory_status(step, prefetch_batches, accelerator, "at_exception", detailed=True)
import traceback
Logger(traceback.format_exc(), accelerator)
# 清理prefetch_batches防止内存泄漏
if args.memory_monitor and accelerator.is_main_process:
Logger(f"[Memory Monitor] Clearing prefetch_batches due to exception. Current length: {len(prefetch_batches)}", accelerator)
prefetch_batches.clear()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if args.memory_monitor:
log_memory_status(step, prefetch_batches, accelerator, "after_exception_cleanup", detailed=True)
# 训练epoch结束时清理prefetch_batches
if args.memory_monitor:
if accelerator.is_main_process:
Logger(f"[Memory Monitor] Epoch {epoch+1} finished. Clearing prefetch_batches. Final length: {len(prefetch_batches)}", accelerator)
log_memory_status(total_steps_in_epoch-1, prefetch_batches, accelerator, "before_epoch_end_cleanup", detailed=True)
prefetch_batches.clear()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if args.memory_monitor:
log_memory_status(total_steps_in_epoch-1, prefetch_batches, accelerator, "after_epoch_end_cleanup", detailed=True)
def main():
parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
parser.add_argument("--out_dir", type=str, default="out")
parser.add_argument("--epochs", type=int, default=4)
parser.add_argument("--embedding_epoch", type=int, default=2, help="embedding训练的epoch数")
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=24)
parser.add_argument("--learning_rate", type=float, default=2e-4)
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_swanlab", default=True, action="store_true") # 替换wandb参数
parser.add_argument("--swanlab_project", type=str, default="MiniMind-Pretrain") # 替换wandb参数
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--use_wandb", default=True, action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
parser.add_argument("--num_workers", type=int, default=48)
parser.add_argument("--accumulation_steps", type=int, default=32)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--warmup_iters", type=int, default=0)
parser.add_argument("--log_interval", type=int, default=1)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--save_interval", type=int, default=10000)
parser.add_argument('--dim', default=512, type=int)
parser.add_argument('--n_layers', default=8, type=int)
parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument('--dim', default=1024, type=int)
parser.add_argument('--n_layers', default=32, type=int)
parser.add_argument('--max_seq_len', default=1024, type=int)
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能使用固定值1e-4替代")
parser.add_argument("--data_path", type=str, default="./dataset/stable/merged_pretrain.jsonl")
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
parser.add_argument("--knowledge_num", type=int, default=960400,help="知识库的数据数目")
parser.add_argument("--knowledge_length", type=int, default=32,help="知识库的句子长度")
parser.add_argument("--database_init_path", type=str, default="./dataset/stable/sentence_trex_data.json", help="数据库初始化路径")
parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")
parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="聚类结果缓存文件路径")
parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件")
parser.add_argument("--memory_monitor", action="store_true", default=False, help="启用内存监控")
parser.add_argument("--memory_monitor_interval", type=int, default=10, help="内存监控间隔(步数)")
parser.add_argument("--model_type", type=str, default="model", help="使用什么模型训练") #model,model_original,model_no_feed
parser.add_argument("--model_size", type=float, default=50.0, help="模型大小")
parser.add_argument("--swanlab_online", type=bool, default=False, help="是否使用在线SwanLab服务")
parser.add_argument("--knowledge_num", type=int, default=64*64,help="知识库的数据数目")
parser.add_argument("--knowledge_length", type=int, default=8,help="知识库的句子长度")
args = parser.parse_args()
#########################################################
# 初始化accelerator和deepspeed
#########################################################
@ -888,7 +303,7 @@ def main():
gradient_accumulation_steps=args.accumulation_steps,
gradient_clipping=args.grad_clip,
zero_stage=2, # 使用ZeRO-2优化
offload_optimizer_device="none", # 将优化器状态卸载到CPU
offload_optimizer_device="cpu", # 将优化器状态卸载到CPU
offload_param_device="none", # 不将参数卸载到CPU
)
accelerator = Accelerator(
@ -913,8 +328,7 @@ def main():
disable_db=args.disable_db,
flash_attn=args.use_flash_attn,
knowledge_num=args.knowledge_num,
knowledge_length=args.knowledge_length,
embeddings_epoch=args.embedding_epoch
knowledge_length=args.knowledge_length
)
#########################################################
@ -932,37 +346,18 @@ def main():
#########################################################
# 配置SwanLab
# 配置wandb
#########################################################
# 设置SwanLab运行名称
args.swanlab_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
# 合并args和lm_config为一个字典无论是否使用SwanLab都需要用于打印配置信息
config_dict = vars(args).copy()
config_dict.update(vars(lm_config))
# 初始化SwanLab实验实例
swanlab_run = None
if args.use_swanlab and accelerator.is_main_process:
if args.swanlab_online:
# 使用在线SwanLab服务
# 初始化SwanLab
swanlab_run = swanlab.init(
project=args.swanlab_project,
experiment_name=args.swanlab_run_name,
description="MiniMind预训练实验使用本地部署的SwanLab进行可视化",
config=config_dict
)
else:
swanlab_run = swanlab.init(
project=args.swanlab_project,
experiment_name=args.swanlab_run_name,
description="MiniMind预训练实验使用本地部署的SwanLab进行可视化",
config=config_dict,
mode="offline"
)
# 设置wandb运行名称
args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
if args.use_wandb and accelerator.is_main_process:
import wandb
# 合并args和lm_config为一个字典
config_dict = vars(args).copy()
config_dict.update(vars(lm_config))
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config_dict)
else:
swanlab_run = None
wandb = None
#########################################################
# 打印信息
@ -984,7 +379,7 @@ def main():
#########################################################
# 初始化模型和tokenizer
#########################################################
model, tokenizer = init_model(lm_config, args.pretrained_embedding_path, args.database_init_path, args)
model, tokenizer = init_model(lm_config, args.pretrained_embedding_path)
# 将accelerator传递给init_model函数中的Logger调用
Logger(f'模型初始化完成', accelerator)
@ -1044,31 +439,13 @@ def main():
#########################################################
overall_start_time = time.time() # Record overall start time
for epoch in range(args.epochs):
Logger(f"开始第{epoch+1}轮训练", accelerator)
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer) # Pass tokenizer
# 每个epoch结束后进行内存清理
Logger(f"{epoch+1}轮训练完成,进行内存清理", accelerator)
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# 记录epoch结束时的内存状态
if accelerator.is_main_process:
memory_info = get_memory_usage()
cuda_info = get_cuda_memory_usage()
log_msg = f"[Memory Monitor] Epoch {epoch+1} completed - "
log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB"
if cuda_info:
log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB"
Logger(log_msg, accelerator)
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb) # Pass overall start time
#########################################################
# 关闭SwanLab
# 关闭wandb
#########################################################
if args.use_swanlab and accelerator.is_main_process and swanlab_run:
swanlab_run.finish()
if args.use_wandb and accelerator.is_main_process:
wandb.finish()
if __name__ == "__main__":
main()
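The profile branch of train_epoch above brackets the data, forward, backward and optimizer phases with CUDA events. In isolation the pattern looks like the following sketch (the toy linear model and sizes are hypothetical, not the repo's MiniMindLM; it needs a CUDA device):

import torch

def timed_step(model, batch, optimizer):
    # elapsed_time is only defined once both events have been recorded and the
    # device synchronized, which is why the training loop synchronizes before reading them.
    fwd_start = torch.cuda.Event(enable_timing=True)
    fwd_end = torch.cuda.Event(enable_timing=True)
    bwd_start = torch.cuda.Event(enable_timing=True)
    bwd_end = torch.cuda.Event(enable_timing=True)

    fwd_start.record()
    loss = model(batch).pow(2).mean()
    fwd_end.record()

    bwd_start.record()
    loss.backward()
    bwd_end.record()

    optimizer.step()
    optimizer.zero_grad()

    torch.cuda.synchronize()
    return fwd_start.elapsed_time(fwd_end), bwd_start.elapsed_time(bwd_end)  # ms

if torch.cuda.is_available():
    model = torch.nn.Linear(512, 512).cuda()
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    fwd_ms, bwd_ms = timed_step(model, torch.randn(8, 512, device="cuda"), opt)
    print(f"forward: {fwd_ms:.2f} ms, backward: {bwd_ms:.2f} ms")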

4835
uv.lock generated

File diff suppressed because it is too large