Compare commits

No commits in common. "archive/SLM" and "master" have entirely different histories.

Comparing archive/SLM ... master
.gitignore (vendored) — 8 changed lines

```diff
@@ -2,4 +2,10 @@
 /dataset
 /out
 wandb/
 **/*.log
+models/sentence_transformers/
+models/sentence_transformers_cache/
+**/*.pyc
+qwen2-1.7B/
+images/
+cache/
```
.vscode/launch.json (vendored, new file) — 102 lines added

```json
{
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Debug Train Pretrain Accelerate",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/train_pretrain_accelerate.py",
            "console": "integratedTerminal",
            "python": "/opt/conda/envs/mini/bin/python",
            "cwd": "${workspaceFolder}",
            "env": {
                "PYTHONPATH": "${workspaceFolder}",
                "CUDA_VISIBLE_DEVICES": "0"
            },
            "justMyCode": false,
            "stopOnEntry": false,
            "redirectOutput": true
        },
        {
            "name": "Debug Train Pretrain Accelerate (Multi-GPU)",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/train_pretrain_accelerate.py",
            "console": "integratedTerminal",
            "python": "/opt/conda/envs/mini/bin/python",
            "args": [
                "--hidden_size", "512",
                "--max_seq_len", "512",
                "--n_layers", "8",
                "--batch_size", "8",
                "--epochs", "1"
            ],
            "cwd": "${workspaceFolder}",
            "env": {
                "PYTHONPATH": "${workspaceFolder}",
                "CUDA_VISIBLE_DEVICES": "0,1"
            },
            "justMyCode": false,
            "stopOnEntry": false,
            "redirectOutput": true
        },
        {
            "name": "Debug Train Pretrain Accelerate (Small Test)",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/train_pretrain_accelerate.py",
            "console": "integratedTerminal",
            "python": "/opt/conda/envs/mini/bin/python",
            "args": [
                "--hidden_size", "512",
                "--max_seq_len", "512",
                "--n_layers", "8",
                "--batch_size", "2",
                "--epochs", "1",
                "--log_interval", "10",
                "--save_interval", "100",
                "--accumulation_steps", "4"
            ],
            "cwd": "${workspaceFolder}",
            "env": {
                "PYTHONPATH": "${workspaceFolder}",
                "CUDA_VISIBLE_DEVICES": "0",
                "WANDB_MODE": "offline"
            },
            "justMyCode": false,
            "stopOnEntry": false,
            "redirectOutput": true
        },
        {
            "name": "Debug ExtractDB Comparison",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/train_pretrain_accelerate.py",
            "console": "integratedTerminal",
            "python": "/opt/conda/envs/mini/bin/python",
            "args": [
                "--hidden_size", "512",
                "--max_seq_len", "256",
                "--n_layers", "4",
                "--batch_size", "2",
                "--epochs", "1",
                "--log_interval", "10",
                "--save_interval", "200",
                "--accumulation_steps", "2",
                "--knowledge_num", "256",
                "--knowledge_length", "64",
                "--comparison_mode"
            ],
            "cwd": "${workspaceFolder}",
            "env": {
                "PYTHONPATH": "${workspaceFolder}",
                "CUDA_VISIBLE_DEVICES": "0",
                "WANDB_MODE": "offline"
            },
            "justMyCode": false,
            "stopOnEntry": false,
            "redirectOutput": true
        }
    ]
}
```
.vscode/settings.json (vendored, new file) — 18 lines added

```json
{
    "python.pythonPath": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
    "python.defaultInterpreterPath": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
    "python.terminal.activateEnvironment": true,
    "python.terminal.activateEnvInCurrentTerminal": true,
    "python.linting.enabled": true,
    "python.linting.pylintEnabled": false,
    "python.linting.flake8Enabled": true,
    "python.formatting.provider": "black",
    "python.analysis.autoImportCompletions": true,
    "python.analysis.typeCheckingMode": "off",
    "files.exclude": {
        "**/__pycache__": true,
        "**/*.pyc": true,
        "**/.git": false,
        "**/wandb": false
    }
}
```
Deleted file — 128 lines removed (Contributor Covenant Code of Conduct):

# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.

Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at . All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series of actions.

**Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within the community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.

Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity).

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations.
README.md (deleted) — 199 lines removed

<div align="center">

![logo](...)

</div>

<div align="center">

![visitors](...)
[![GitHub Repo stars](https://img.shields.io/github/stars/jingyaogong/minimind?style=social)](https://github.com/jingyaogong/minimind/stargazers)
[![GitHub Code License](https://img.shields.io/github/license/jingyaogong/minimind)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/jingyaogong/minimind)](https://github.com/jingyaogong/minimind/commits/master)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/jingyaogong/minimind/pulls)
[![Collection](https://img.shields.io/badge/🤗-MiniMind%20%20Collection-blue)](https://huggingface.co/collections/jingyaogong/minimind-66caf8d999f5c7fa64f399e5)

</div>

# 📌 Data Introduction

## Ⅰ Tokenizer

The tokenizer maps words from natural language to numbers such as `0, 1, 36` through a "dictionary"; you can think of each number as the word's page number in that dictionary.
You can build your own vocabulary and train a tokenizer — see `./scripts/train_tokenizer.py` (for study only; unless necessary there is no need to retrain, as MiniMind ships with its own tokenizer).
Or you can pick the tokenizer of a well-known open-source model:
just as using the Xinhua or Oxford dictionary directly gives excellent token compression, at the cost of a huge page count — easily hundreds of thousands of words and phrases;
a self-trained tokenizer lets you control vocabulary size and content freely, at the cost of poor compression (e.g. "hello" might be split into the five separate tokens "h e l l o") and weak coverage of rare words.
The choice of "dictionary" matters: an LLM's output is essentially an N-way classification over the dictionary via softmax, decoded back into natural language through that dictionary.
Because MiniMind's size must be strictly controlled, and to avoid a top-heavy model (the word-embedding layer taking too large a share of the LLM's parameters), the shorter the vocabulary, the better.

<details style="color:rgb(128,128,128)">
<summary>About the tokenizer</summary>

Vocabulary sizes of the tokenizers of strong third-party open-source models such as Yi, qwen, chatglm, mistral, and Llama3:

<table>
<tr><th>Tokenizer</th><th>Vocabulary size</th><th>Source</th></tr>
<tr><td>yi tokenizer</td><td>64,000</td><td>01.AI (China)</td></tr>
<tr><td>qwen2 tokenizer</td><td>151,643</td><td>Alibaba Cloud (China)</td></tr>
<tr><td>glm tokenizer</td><td>151,329</td><td>Zhipu AI (China)</td></tr>
<tr><td>mistral tokenizer</td><td>32,000</td><td>Mistral AI (France)</td></tr>
<tr><td>llama3 tokenizer</td><td>128,000</td><td>Meta (USA)</td></tr>
<tr><td>minimind tokenizer</td><td>6,400</td><td>Custom</td></tr>
</table>

> 👉 Update 2024-09-17: to avoid ambiguity with earlier versions and to control model size, all MiniMind models now use the minimind_tokenizer; all mistral_tokenizer versions are deprecated.

```
# Some musings
> Although the minimind_tokenizer is small and its encode/decode efficiency is weaker than Chinese-friendly tokenizers such as qwen2 and glm,
> the MiniMind models use the self-trained minimind_tokenizer to keep the overall parameter count light and avoid an imbalance between the embedding and compute layers, since MiniMind's vocabulary size is only 6,400.
> In practice, minimind has never failed to decode a rare word in testing; the results are good.
> Because the custom vocabulary is compressed down to 6,400 entries, the LLM's total parameter count can be as low as 25.8M.
> The training data `tokenizer_train.jsonl` all comes from the "Jiangshu large-model dataset"; this part of the data is relatively secondary, and you may choose freely if you want to train one.
```

</details>
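
As a quick illustration (a minimal sketch, not part of the original README): loading the bundled tokenizer and round-tripping a sentence. It assumes the tokenizer files live at `./model/minimind_tokenizer`, the path used elsewhere in this repo.

```python
# Minimal sketch: inspect the bundled minimind_tokenizer.
# Assumes the tokenizer files live at ./model/minimind_tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")
print(tokenizer.vocab_size)                 # expected to be 6,400 per the table above

ids = tokenizer.encode("hello, minimind")   # text -> "page numbers" in the dictionary
print(ids)
print(tokenizer.decode(ids))                # page numbers -> text
```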

## Ⅱ Pretrain data

After the lesson of MiniMind-V1, whose low-quality pretraining data left the model spouting nonsense, after `2025-02-05` it was decided to stop using large-scale unsupervised datasets for pretraining.
Instead, the Chinese portion of the [Jiangshu large-model dataset](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data) was extracted and
cleaned into roughly 1.6GB of corpus entries shorter than 512 characters, concatenated directly into the pretraining data `pretrain_hq.jsonl` — hq for high
quality (though not truly high yet; improving data quality never ends).

The format of `pretrain_hq.jsonl` is

```bash
{"text": "如何才能摆脱拖延症? 治愈拖延症并不容易,但以下建议可能有所帮助..."}
```

## Ⅲ SFT data

The [Jiangshu large-model SFT dataset](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
"is a complete, uniformly formatted, and safe resource for large-model training and research.
It collects and curates a large number of open-source datasets from public sources on the web, unifies their format, cleans the data,
and contains a Chinese dataset of 10M entries and an English dataset of 2M entries."
That is the official description; the downloaded data totals roughly 4B tokens, which is certainly suitable as SFT data for a Chinese LLM.
But the official data format is messy, and using all of it for SFT would cost too much.
I re-cleaned the official dataset, removing entries with symbol pollution and noise; as before, only content with total length `<512`
was kept — at this stage the hope is that a large volume of dialogue supplements the knowledge missing from pretraining.
The exported file is `sft_512.jsonl` (~7.5GB).

The [Magpie-SFT dataset](https://www.modelscope.cn/organization/Magpie-Align)
collects ~1M high-quality dialogues from Qwen2/2.5; I cleaned this data further and exported the portion with total length `<2048` as `sft_2048.jsonl` (~9GB).
The portion with length `<1024` was exported as `sft_1024.jsonl` (~5.5GB). Doing SFT directly on a large model's dialogue data falls under "black-box distillation".

Further cleaning the SFT data from the previous two steps (keeping only content with a high proportion of Chinese characters) and filtering dialogues of length `<512` yields `sft_mini_512.jsonl` (~1.2GB).

All SFT files `sft_X.jsonl` share the format

```text
{
    "conversations": [
        {"role": "user", "content": "你好"},
        {"role": "assistant", "content": "你好!"},
        {"role": "user", "content": "再见"},
        {"role": "assistant", "content": "再见!"}
    ]
}
```
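
For concreteness, a minimal sketch (not from the repo) of streaming one of these `sft_X.jsonl` files and flattening a conversation into a single string; the chosen file name and the simple role tags are illustrative assumptions.

```python
# Minimal sketch: stream an sft_X.jsonl file and flatten each conversation.
# Assumes the format shown above; the <role>...</role> tags are illustrative only.
import json

with open("./dataset/sft_mini_512.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        sample = json.loads(line)
        text = ""
        for turn in sample["conversations"]:
            text += f'<{turn["role"]}>{turn["content"]}</{turn["role"]}>'
        print(text)
        break  # just show the first sample
```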

## Ⅳ RLHF data

From the [Magpie-DPO dataset](https://www.modelscope.cn/datasets/Magpie-Align/MagpieLM-DPO-Data-v0.1):
about 200k preference pairs (all in English), generated from Llama3.1-70B/8B, that can be used to train a reward model and optimize response quality so it better matches human preferences.
Here, entries with total length `<3000` were repacked into `dpo.jsonl` (~0.9GB), containing the two fields `chosen` and `rejected` — `chosen`
is the preferred response, `rejected` the rejected one.

The format of `dpo.jsonl` is

```text
{
    "chosen": [
        {"content": "Q", "role": "user"},
        {"content": "good answer", "role": "assistant"}
    ],
    "rejected": [
        {"content": "Q", "role": "user"},
        {"content": "bad answer", "role": "assistant"}
    ]
}
```
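
As background on how such `chosen`/`rejected` pairs are consumed, here is a minimal sketch of the standard DPO objective on per-sequence log-probabilities — a generic illustration, not the repo's training code; the tensor values are placeholders.

```python
# Minimal sketch of the DPO loss on chosen/rejected pairs (generic, not this repo's code).
# logp_* are summed log-probabilities of the response tokens under policy / reference models.
import torch
import torch.nn.functional as F

def dpo_loss(logp_chosen, logp_rejected, ref_logp_chosen, ref_logp_rejected, beta=0.1):
    # Preference margin of the policy relative to the frozen reference model
    pi_logratio = logp_chosen - logp_rejected
    ref_logratio = ref_logp_chosen - ref_logp_rejected
    # -log(sigmoid(beta * margin)): pushes chosen above rejected
    return -F.logsigmoid(beta * (pi_logratio - ref_logratio)).mean()

loss = dpo_loss(torch.tensor([-12.3]), torch.tensor([-15.1]),
                torch.tensor([-13.0]), torch.tensor([-14.2]))
print(loss)
```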

## Ⅴ Reason dataset

In February 2025, nothing could out-trend DeepSeek...
It also sparked my strong interest in RL-guided reasoning models; I have already reproduced R1-Zero with Qwen2.5.
If time allows and it works (though the base model is 99% likely too weak), I will later update MiniMind with a reasoning model trained via RL rather than distillation.
Time is limited, and the fastest low-cost option remains direct (black-box) distillation.
R1 is so hot that within days several R1 distillation datasets appeared: [R1-Llama-70B](https://www.modelscope.cn/datasets/Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B), [R1-Distill-SFT](https://www.modelscope.cn/datasets/AI-ModelScope/R1-Distill-SFT),
[Alpaca-Distill-R1](https://huggingface.co/datasets/shareAI/Alpaca-Distill-R1-ZH),
[deepseek_r1_zh](https://huggingface.co/datasets/jinliuxi/deepseek_r1_zh), and so on — purely Chinese data may be relatively scarce.
They were merged and exported as `r1_mix_1024.jsonl`, in the same format as `sft_X.jsonl`.

## Ⅵ More datasets

[HqWu-HITCS/Awesome-Chinese-LLM](https://github.com/HqWu-HITCS/Awesome-Chinese-LLM)
is already collecting and organizing open-source models, applications, datasets, and tutorials related to Chinese LLMs, and keeps tracking the latest progress. Comprehensive and professional — respect!

---

## Ⅷ Dataset download

> [!NOTE]
> After 2025-02-05, all datasets used for MiniMind's final training are open-sourced, so there is no need to preprocess large-scale datasets yourself; this avoids repetitive data-processing work.

MiniMind training datasets ([ModelScope](https://www.modelscope.cn/datasets/gongjy/minimind-dataset/files) | [HuggingFace](https://huggingface.co/datasets/jingyaogong))

> No need to clone everything — download only the files you need.

Place the downloaded dataset files under `./dataset/` (✨ marks the recommended essentials):

```bash
./dataset/
├── dpo.jsonl (909MB)
├── lora_identity.jsonl (22.8KB)
├── lora_medical.jsonl (34MB)
├── pretrain_hq.jsonl (1.6GB, ✨)
├── r1_mix_1024.jsonl (340MB)
├── sft_1024.jsonl (5.6GB)
├── sft_2048.jsonl (9GB)
├── sft_512.jsonl (7.5GB)
├── sft_mini_512.jsonl (1.2GB, ✨)
└── tokenizer_train.jsonl (1GB)
```

<details style="color:rgb(128,128,128)">
<summary>Note: dataset descriptions</summary>

* `dpo.jsonl` -- RLHF-stage dataset
* `lora_identity.jsonl` -- self-identity dataset (e.g. Who are you? I am minimind...), recommended for LoRA training (also usable for full-parameter SFT; don't be limited by the name)
* `lora_medical.jsonl` -- medical Q&A dataset, recommended for LoRA training (also usable for full-parameter SFT; don't be limited by the name)
* `pretrain_hq.jsonl` ✨ -- pretraining dataset, compiled from Jiangshu Technology
* `r1_mix_1024.jsonl` -- DeepSeek-R1-1.5B distillation data, up to 1024 characters per entry (hence set max_seq_len=1024 for training)
* `sft_1024.jsonl` -- compiled from Qwen2.5 distillation data (a subset of sft_2048), up to 1024 characters per entry (hence set max_seq_len=1024 for training)
* `sft_2048.jsonl` -- compiled from Qwen2.5 distillation data, up to 2048 characters per entry (hence set max_seq_len=2048 for training)
* `sft_512.jsonl` -- compiled from Jiangshu SFT data, up to 512 characters per entry (hence set max_seq_len=512 for training)
* `sft_mini_512.jsonl` ✨ -- a minimal mix of Jiangshu SFT data + Qwen2.5 distillation data (for quickly training a Zero model), up to 512 characters per entry (hence set max_seq_len=512 for training)
* `tokenizer_train.jsonl` -- all from the "Jiangshu large-model dataset"; this data is relatively secondary (retraining the tokenizer yourself is not recommended, for the reasons above), but if you do, you may choose datasets freely.

</details>


![dataset](...)

<details style="color:rgb(128,128,128)">
<summary>Notes & recommended training plans</summary>

* The MiniMind2 series was trained on about 20GB of corpus in total, roughly 4B tokens, i.e. the result of the full data combination above (cost: 💰💰💰💰💰💰💰💰, quality: 😊😊😊😊😊😊)

* To build a Zero model from scratch as fast as possible, the `pretrain_hq.jsonl` + `sft_mini_512.jsonl` combination is recommended; see the tables below for concrete cost and quality (cost: 💰, quality: 😊😊)

* For those with some compute or who care more about results, the former — fully reproducing MiniMind2 — is worth considering; for those with only a single GPU or who want a quick reproduction, the latter is strongly recommended;

* [Middle ground] You can also freely combine medium-scale data such as `sft_mini_512.jsonl` and `sft_1024.jsonl` for training (cost: 💰💰💰, quality: 😊😊😊😊).

</details>
README_accelerate.md (new file) — 126 lines added

# Distributed training with Accelerate + DeepSpeed

This document describes how to run distributed training of the MiniMind model with Accelerate and DeepSpeed.

## Environment setup

First, make sure the required dependencies are installed:

```bash
pip install accelerate deepspeed
```

## Configuration files

### 1. DeepSpeed config (ds_config.json)

The DeepSpeed config defines the optimizer, learning-rate scheduler, and ZeRO optimization parameters. Key settings:

- **ZeRO optimization**: ZeRO-2 is used, which reduces GPU memory usage
- **Optimizer**: AdamW
- **Mixed precision**: FP16 and BF16 are supported
- **Gradient accumulation**: set to "auto", kept consistent with the training-script argument

### 2. Accelerate config (accelerate_config.yaml)

The Accelerate config defines the basic distributed-training settings:

- **Distributed type**: DeepSpeed
- **Mixed precision**: BF16
- **Number of processes**: 4 (adjust to your GPU count)
- **DeepSpeed config**: points to the ds_config.json file

## Training script

The new training script `train_pretrain_accelerate.py` is adapted from the original `train_pretrain.py`. The main changes (a minimal sketch of the resulting pattern follows below):

1. Uses Accelerator instead of PyTorch's native distributed functionality
2. Removes the torchrun-related distributed initialization code
3. Prepares the model, optimizer, and data loaders through the Accelerator API
4. Runs backpropagation and gradient clipping through the Accelerator API
5. Handles the positional-encoding and unused-parameter issues
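
A minimal sketch of that pattern — illustrative only; the model, data, and hyperparameters are placeholders, not the actual `train_pretrain_accelerate.py`:

```python
# Minimal sketch of the Accelerator pattern described above (not the actual training script).
# Model, dataset, and hyperparameters are placeholders.
import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=32)  # replaces torchrun init + DDP wrapping

model = torch.nn.Linear(512, 6400)                         # placeholder for MiniMindLM
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
loader = torch.utils.data.DataLoader(torch.randn(64, 512), batch_size=8)

# prepare() adapts model/optimizer/dataloader to the chosen backend (e.g. DeepSpeed)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for batch in loader:
    with accelerator.accumulate(model):                    # gradient accumulation
        loss = model(batch).pow(2).mean()                  # placeholder loss
        accelerator.backward(loss)                         # replaces loss.backward()
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
```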

## Launching training

There are two ways to launch training:

### Method 1: use the pre-made accelerate config file

```bash
accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
    --epochs 3 \
    --batch_size 24 \
    --learning_rate 2e-4 \
    --dtype bfloat16 \
    --accumulation_steps 32 \
    --grad_clip 1.0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 1024 \
    --n_layers 32 \
    --max_seq_len 1024 \
    --use_flash_attn \
    --profile \
    --profile_interval 10
```

### Method 2: configure accelerate directly on the command line

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
    --multi_gpu \
    --num_processes=4 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    --deepspeed_config_file ds_config.json \
    train_pretrain_accelerate.py \
    --epochs 3 \
    --batch_size 24 \
    --learning_rate 2e-4 \
    --dtype bfloat16 \
    --accumulation_steps 32 \
    --grad_clip 1.0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 1024 \
    --n_layers 32 \
    --max_seq_len 1024 \
    --use_flash_attn \
    --profile \
    --profile_interval 10
```

You can also simply use the provided script:

```bash
bash run_accelerate.sh
```

## How the Accelerate and DeepSpeed configs relate

1. **Accelerate** is a high-level API that simplifies setting up and launching distributed training; it can be used with several distributed backends (DeepSpeed, FSDP, etc.).

2. **DeepSpeed** is an optimization library focused on memory optimization and performance for large-scale model training, providing features such as ZeRO.

3. **How they connect**:
   - The Accelerate config file (YAML) selects which distributed backend to use and the basic distributed settings
   - The DeepSpeed config file (JSON) defines DeepSpeed-specific optimization parameters
   - Accelerate references the DeepSpeed config file via the `deepspeed_config_file` parameter

## Notes

1. **Positional encoding**:
   - In the model, `pos_cis` is a complex-valued tensor that needs special handling in distributed training
   - The new training script handles this through the Accelerator API, so `_ddp_params_and_buffers_to_ignore` is no longer needed

2. **Unused parameters**:
   - The original code used `find_unused_parameters=True` to handle unused parameters
   - The new script uses the Accelerator API directly, which handles this automatically

3. **Mixed precision**:
   - `fp16` and `bf16` in the DeepSpeed config are set to `"auto"`
   - The precision actually used is decided by Accelerate's `--mixed_precision` flag

4. **Gradient accumulation**:
   - `gradient_accumulation_steps` in the DeepSpeed config is set to `"auto"`
   - The actual number of accumulation steps is decided by the training script's `--accumulation_steps` argument
README_en.md — 1509 changed lines (diff not expanded)
ReadMe.md (new file) — 22 lines added

## Environment setup

1. Create a conda environment

```bash
conda create -n accelerate python=3.10
conda activate accelerate
```

2. Install torch, torchvision, and torchaudio matching your system's CUDA version

3. Install accelerate and deepspeed matching the torch/torchvision in your environment

4. Install the remaining packages (a quick sanity check is sketched below)

```bash
pip install -r requirements.txt
```
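
A quick sanity check for steps 2–4 — a sketch that only verifies the stack imports and CUDA is visible:

```python
# Sanity-check sketch: verify torch/accelerate/deepspeed are installed and CUDA is visible.
import torch
import accelerate
import deepspeed

print("torch", torch.__version__, "| cuda available:", torch.cuda.is_available())
print("cuda build:", torch.version.cuda)
print("accelerate", accelerate.__version__)
print("deepspeed", deepspeed.__version__)
```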

## Modifying the model

1. In most cases, only modify the files in the `model` folder

## Running

1. On a 4090 or 4070 Ti, run `bash run_file/DynamicKV-LLM_Mini_Minimind.sh`
2. On 4× A800, run `bash run_file/DynamicKV-LLM_Small_Minimind.sh`
accelerate_config.yaml (new file) — 17 lines added

```yaml
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: ds_config.json
  zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
dataset_decoder.py (new file) — 144 lines added

```python
import os
import argparse
import torch
from transformers import AutoTokenizer
from model.model import MiniMindLM, ExtractDB
from model.LMConfig import LMConfig


def decode_dataset(model_path, output_path, device="cuda"):
    """
    Decode the weight_down_embed buffer in the model to readable text

    Args:
        model_path: Path to the model checkpoint
        output_path: Path to save the decoded text
        device: Device to load the model on
    """
    print(f"Loading tokenizer from ./model/minimind_tokenizer")
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')

    print(f"Setting up model configuration")
    # Create model configuration matching the training parameters
    lm_config = LMConfig(
        dim=1024,
        n_layers=32,
        max_seq_len=1024,
        use_flash_attn=True,
        knowledge_num=16384,  # From the script parameters
        knowledge_length=64   # From the script parameters
    )

    print(f"Initializing model")
    model = MiniMindLM(lm_config).to(device)

    print(f"Loading model weights from {model_path}")
    state_dict = torch.load(model_path, map_location=device)

    # Get model parameters
    model_state = dict(model.named_parameters())
    model_state.update(dict(model.named_buffers()))

    # Find parameters with matching names but different shapes
    shape_mismatch = {}
    for name, param in model_state.items():
        if name in state_dict and param.shape != state_dict[name].shape:
            shape_mismatch[name] = (param.shape, state_dict[name].shape)

    # Find parameters in model but not in state_dict and vice versa
    model_only = set(model_state.keys()) - set(state_dict.keys())
    state_dict_only = set(state_dict.keys()) - set(model_state.keys())

    # Create filtered state_dict with only compatible parameters
    filtered_state_dict = {}
    for name, param in state_dict.items():
        if name in model_state and param.shape == model_state[name].shape:
            filtered_state_dict[name] = param

    # Print parameter differences
    if shape_mismatch:
        print(f"Parameters with shape mismatches: {len(shape_mismatch)}")
        for name, (model_shape, state_shape) in shape_mismatch.items():
            print(f"  {name}: model={model_shape}, checkpoint={state_shape}")

    if model_only:
        print(f"Parameters in model but not in checkpoint: {len(model_only)}")
        for name in sorted(model_only):
            print(f"  {name}: {model_state[name].shape}")

            # Special handling for the pos_cis_real parameter
            if name == "pos_cis_real":
                print(f"Detected pos_cis_real parameter. This is a position encoding that will be initialized automatically.")

    if state_dict_only:
        print(f"Parameters in checkpoint but not in model: {len(state_dict_only)}")
        for name in sorted(state_dict_only):
            print(f"  {name}: {state_dict[name].shape}")

            # If the checkpoint has output.weight but the model does not, try loading it into tok_embeddings
            if name == "output.weight" and "tok_embeddings.weight" in model_state:
                print(f"Found output.weight in checkpoint but not in model. Will try to map it to tok_embeddings.weight")
                if model_state["tok_embeddings.weight"].shape == state_dict["output.weight"].shape:
                    filtered_state_dict["tok_embeddings.weight"] = state_dict["output.weight"]

    # Load only the compatible parameters
    print(f"Loading {len(filtered_state_dict)}/{len(state_dict)} parameters")
    model.load_state_dict(filtered_state_dict, strict=False)

    # Check that extract_db and weight_down_embed exist
    if not hasattr(model, "extract_db"):
        print("ERROR: Model does not have extract_db attribute. This is required for decoding.")
        return

    print("Accessing weight_down_embed buffer")
    # Get the weight_down_embed buffer from the model
    try:
        weight_down_embed = model.extract_db.weight_down_embed
        print(f"Successfully accessed weight_down_embed buffer")
    except Exception as e:
        print(f"ERROR: Failed to access weight_down_embed buffer: {e}")
        print(f"Model structure: {model.__class__.__name__}")
        print(f"ExtractDB attributes: {dir(model.extract_db)}")
        return

    print(f"Shape of weight_down_embed: {weight_down_embed.shape}")
    print(f"Data type of weight_down_embed: {weight_down_embed.dtype}")

    # Create output directory if it doesn't exist
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    print(f"Decoding knowledge and writing to {output_path}")
    knowledge_num, knowledge_length = weight_down_embed.shape

    with open(output_path, 'w', encoding='utf-8') as f:
        for i in range(knowledge_num):
            try:
                # Get token IDs for this knowledge entry
                token_ids = weight_down_embed[i].cpu().tolist()

                # Decode tokens to text
                text = tokenizer.decode(token_ids, skip_special_tokens=True)

                # Write to file
                f.write(f"Knowledge_{i}: {text}\n")

                # Print progress periodically
                if (i + 1) % 100 == 0:
                    print(f"Decoded {i + 1}/{knowledge_num} knowledge entries")
            except Exception as e:
                print(f"Error decoding knowledge entry {i}: {e}")
                f.write(f"Knowledge_{i}: [ERROR DECODING]\n")

    print(f"Decoding completed. Output saved to {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Decode MiniMind model's knowledge database")
    parser.add_argument("--model_path", type=str, default="out/pretrain_1024.pth",
                        help="Path to the model checkpoint")
    parser.add_argument("--output_path", type=str, default="out/knowledge_db.txt",
                        help="Path to save the decoded text file")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to load the model on")

    args = parser.parse_args()

    decode_dataset(args.model_path, args.output_path, args.device)
```
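
Example programmatic invocation — a sketch using the script's own default paths; the checkpoint must match what your training run actually produced:

```python
# Sketch: call the decoder directly; paths are the script defaults and must exist locally.
from dataset_decoder import decode_dataset

decode_dataset(model_path="out/pretrain_1024.pth",
               output_path="out/knowledge_db.txt",
               device="cuda")
```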
ds_config.json (new file) — 49 lines added

```json
{
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 5e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true
    },
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },
    "steps_per_print": 100,
    "wall_clock_breakdown": false
}
```
BIN — 19 binary image files deleted under images/ (66 KiB–3.8 MiB each), including images/logo.png (495 KiB) and images/logo2.png (615 KiB)
model/LMConfig.py

```diff
@@ -9,8 +9,8 @@ class LMConfig(PretrainedConfig):
             self,
             dim: int = 512,
             n_layers: int = 8,
-            n_heads: int = 8,
-            n_kv_heads: int = 2,
+            n_heads: int = 32,
+            n_kv_heads: int = 8,
             vocab_size: int = 6400,
             hidden_dim: int = None,
             multiple_of: int = 64,
@@ -19,6 +19,7 @@
             rope_theta: int = 1e6,
             dropout: float = 0.0,
             flash_attn: bool = True,
+            embeddings_epoch: int = 2,
             ####################################################
             # DB related configurations
             ####################################################
@@ -36,6 +37,10 @@
             aux_loss_alpha: float = 0.1,
             seq_aux: bool = True,
             norm_topk_prob: bool = True,
+            ####################################################
+            knowledge_num: int = 64*64,
+            knowledge_length: int = 8,
+            knowledge_dim: int = 128,
             **kwargs,
     ):
         self.dim = dim
@@ -50,6 +55,7 @@
         self.rope_theta = rope_theta
         self.dropout = dropout
         self.flash_attn = flash_attn
+        self.embeddings_epoch = embeddings_epoch
         ####################################################
         # DB related configurations
         ####################################################
@@ -66,4 +72,8 @@
         self.aux_loss_alpha = aux_loss_alpha  # alpha for the auxiliary loss
         self.seq_aux = seq_aux  # whether to compute the auxiliary loss at sequence level
         self.norm_topk_prob = norm_topk_prob  # whether to normalize top-k probabilities
+        ####################################################
+        self.knowledge_num = knowledge_num
+        self.knowledge_length = knowledge_length
+        self.knowledge_dim = knowledge_dim
         super().__init__(**kwargs)
```
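
The new knowledge-base fields can be set like any other config field — a sketch using the defaults added in the diff above:

```python
# Sketch: instantiate LMConfig with the new knowledge-base fields (defaults from the diff above).
from model.LMConfig import LMConfig

config = LMConfig(
    dim=512,
    n_layers=8,
    n_heads=32,
    n_kv_heads=8,
    knowledge_num=64 * 64,   # number of knowledge entries
    knowledge_length=8,      # tokens per entry
    knowledge_dim=128,       # key/query dimension
)
print(config.knowledge_num, config.knowledge_length, config.knowledge_dim)
```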
```diff
@@ -10,7 +10,7 @@ from sklearn.model_selection import train_test_split
 import os
 import ast
 
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
 
 class PretrainDataset(Dataset):
```
model/model.py — 629 changed lines

```diff
@@ -11,14 +9,9 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from torch import nn, einsum
-from einops import rearrange, repeat
-
-def exists(val):
-    return val is not None
-
-
-# RMSNorm defines a module that normalizes the input tensor.
+
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
```
```diff
@@ -31,7 +26,7 @@ class RMSNorm(torch.nn.Module):
     def forward(self, x):
         return self.weight * self._norm(x.float()).type_as(x)
 
-# precompute_pos_cis precomputes the rotary position encodings.
+
 def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)  # type: ignore
```
```diff
@@ -39,7 +34,7 @@ def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
     pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
     return pos_cis
 
-# apply_rotary_emb applies the rotary position embeddings.
+
 def apply_rotary_emb(xq, xk, pos_cis):
     def unite_shape(pos_cis, x):
         ndim = x.ndim
```
```diff
@@ -55,18 +50,239 @@ def apply_rotary_emb(xq, xk, pos_cis):
     xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
     return xq_out.type_as(xq), xk_out.type_as(xk)
 
-# repeat_kv repeats key/value pairs.
-def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
-    bs, slen, n_kv_heads, head_dim = x.shape
-    if n_rep == 1:
-        return x
-    return (
-        x[:, :, :, None, :]
-        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
-        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
-    )
+class KnowledgeDataset(nn.Module):
+    def __init__(self, params, tok_embeddings, is_train=True):
+        super().__init__()
+        self.is_train = is_train
+        self.params = params
+        self.tok_embeddings = tok_embeddings
+
+        # Embedding parameters
+        self.knowledge_dim = params.knowledge_dim
+        self.key_dim = self.knowledge_dim // 2
+        self.to_queries = nn.Sequential(
+            nn.Linear(params.dim, self.knowledge_dim, bias=False),
+        )
+
+        ## Database parameters
+        self.knowledge_num = params.knowledge_num
+        self.knowledge_length = params.knowledge_length
+        self.keys = nn.Parameter(torch.randn(self.knowledge_num, self.knowledge_dim) * 0.02, requires_grad=True)
+        self.product_key_topk = min(16, self.knowledge_num)
+
+        # Usage statistics - registered as a buffer so it moves correctly between GPU/CPU
+        self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))
+
+        # Knowledge store - a buffer, since these are integer indices that need no gradients
+        self.register_buffer('knowledge_dataset',
+            torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
+        )
+
+        # Step counter, used to adjust weights dynamically
+        self.step_counter = 0
+
+        self.freeze_embedding = False
+
+    def intelligent_selection(self, query, all_scores, all_indices):
+        """Layered selection strategy"""
+        if self.is_train == False:
+            return all_scores, all_indices
+
+        batch_size = all_scores.size(0)
+        device = all_scores.device
+        dtype = all_scores.dtype
+
+        # Layered selection for each batch element
+        enhanced_scores = all_scores.clone()
+        query_features = query.mean(dim=1)  # [batch_size, dim]
+
+        # Precompute embeddings for all candidate entries (batched optimization)
+        all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
+        unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
+
+        # Batch-compute embeddings of the unique candidate entries
+        candidate_tokens = self.knowledge_dataset[unique_indices]
+        flat_tokens = candidate_tokens.view(-1)
+        flat_embeddings = self.tok_embeddings(flat_tokens)
+        # Indices corresponding to flat_tokens
+        pre_update_indices = unique_indices.view(-1)
+        pre_update_embeddings = flat_embeddings.view(
+            len(unique_indices), self.knowledge_length, -1
+        )
+
+        unique_candidate_features = flat_embeddings.view(
+            len(unique_indices), self.knowledge_length, -1
+        ).mean(dim=1)  # [num_unique_candidates, dim]
+
+        # Normalize candidate features (optimizes the similarity computation)
+        normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
+        normalized_queries = F.normalize(query_features, dim=-1)
+
+        # Collect best_tokens across the batch
+        batch_best_tokens = []
+        batch_best_tokens_embeddings = []
+
+        for batch_idx in range(batch_size):
+            indices = all_indices[batch_idx]
+
+            # Feature indices of the current batch element's candidate entries
+            start_idx = batch_idx * len(indices)
+            end_idx = start_idx + len(indices)
+            batch_inverse_indices = inverse_indices[start_idx:end_idx]
+
+            # Use the precomputed normalized features for the similarity computation
+            batch_candidate_features = normalized_candidates[batch_inverse_indices]
+            query_feature = normalized_queries[batch_idx]
+
+            # Cosine similarity via a matrix-vector product
+            similarity_scores = torch.mv(batch_candidate_features, query_feature)
+
+            # Index of the highest similarity score
+            max_similarity_idx = torch.argmax(similarity_scores)
+
+            # Candidate entry index with the highest similarity
+            best_candidate_idx = indices[max_similarity_idx]
+
+            # Corresponding tokens
+            best_tokens = self.knowledge_dataset[best_candidate_idx]
+            best_tokens_embeddings = self.tok_embeddings(best_tokens)
+
+            # Append this batch element's best_tokens to the lists
+            batch_best_tokens.append(best_tokens)
+            batch_best_tokens_embeddings.append(best_tokens_embeddings)
+
+        # Stack all batch elements' best_tokens into one tensor
+        # [batch_size, knowledge_length]
+        all_best_tokens = torch.stack(batch_best_tokens, dim=0)
+        all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
+
+        # Update self.keys with the recomputed embeddings
+        if self.is_train:
+            self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
+            # Mark the modified keys as updated
+            with torch.no_grad():
+                self.has_update_keys[pre_update_indices] = 1
+
+        return all_best_tokens, all_best_tokens_embeddings
+
+    def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
+        if self.freeze_embedding:
+            return
+        # Update self.keys using pre_update_embeddings
+        with torch.no_grad():
+            pre_update_embeddings = pre_update_embeddings.mean(dim=1)  # [num_indices, dim]
+            pre_update_embeddings = self.to_queries(pre_update_embeddings)
+            self.keys[pre_update_indices] = pre_update_embeddings
+
+    def search_index(self, x):
+        batch_size, seq_len, dim = x.shape
+
+        # Collapse the sequence dimension by averaging
+        x_flat = x.mean(dim=1)  # [batch_size, dim]
+
+        queries = self.to_queries(x_flat)  # [batch_size, knowledge_dim]
+        # queries = queries.reshape(batch_size, 2, self.key_dim)
+        # queries = queries.permute(1, 0, 2)
+
+        # 2. Similarity between queries and keys
+        sim = torch.einsum('b d, k d -> b k', queries, self.keys)
+
+        # 3. Top-k selection
+        scores_and_indices = sim.topk(self.product_key_topk, dim=-1)
+        scores, indices = scores_and_indices[0], scores_and_indices[1]
+
+        # 5. Apply the layered selection strategy
+        best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
+
+        # 6. Update 1% of the keys
+        if self.is_train:
+            # Indices of keys that have not been updated yet
+            not_updated_indices = torch.where(self.has_update_keys == 0)[0]
+
+            # If there are un-updated keys, randomly pick num_update_keys of them to update
+            if len(not_updated_indices) > 0:
+                num_update_keys = int(self.knowledge_num * 0.01)
+                perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
+                perm_num = perm.shape[0]
+                pre_update_indices = not_updated_indices[perm]
+                pre_update_tokens = self.knowledge_dataset[pre_update_indices]
+                pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
+                pre_update_embeddings = pre_update_embeddings.view(perm_num, self.knowledge_length, -1)
+                self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
+                # Mark the modified keys as updated
+                with torch.no_grad():
+                    self.has_update_keys[pre_update_indices] = 1
+            else:
+                print("all keys are updated")
+                # Reset the update status of all keys
+                self.has_update_keys.zero_()
+                # Re-fetch all updatable indices
+                not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)
+                num_update_keys = int(self.knowledge_num * 0.01)
+                perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
+                pre_update_indices = not_updated_indices[perm]
+                pre_update_tokens = self.knowledge_dataset[pre_update_indices]
+                pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
+                pre_update_embeddings = pre_update_embeddings.view(num_update_keys, self.knowledge_length, -1)
+                self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
+                # Mark the modified keys as updated
+                with torch.no_grad():
+                    self.has_update_keys[pre_update_indices] = 1
+
+        return best_tokens, best_tokens_embeddings
+
+
+class CrossAttention(nn.Module):
+    def __init__(
+            self,
+            config
+    ):
+        super().__init__()
+        self.config = config
+        self.num_heads = 8
+        self.head_dim = self.config.dim // self.num_heads
+        self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
+        self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
+        self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
+
+        self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
+
+    def forward(self, x, db, context_mask=None, pos_emb=None):
+        batch_size = x.size(0)
+
+        # Split into heads
+        q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+        if pos_emb is not None:
+            pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+            q = q + pos_emb
+            k = k + pos_emb
+            v = v + pos_emb
+
+        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+
+        if context_mask is not None:
+            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
+            attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
+
+        attn_weights = F.softmax(attn_scores, dim=-1)
+
+        context = torch.matmul(attn_weights, v)
+
+        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
+
+        context = self.to_out(context)
+
+        return context
+
 class Attention(nn.Module):
     def __init__(self, args: LMConfig):
```
```diff
@@ -92,58 +308,14 @@
 
     def forward(self,
                 x: torch.Tensor,
-                pos_cis: torch.Tensor,
-                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-                use_cache=True,
-                db_value=None):
-        bsz, seq_len, _ = x.shape  # bsz: batch size, seq_len: sequence length, _: hidden dim
-        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)  # project x through wq, wk, wv to get queries, keys, and values
-        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)  # reshape xq to (bsz, seq_len, n_local_heads, head_dim)
-        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)  # reshape xk to (bsz, seq_len, n_local_kv_heads, head_dim)
-        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)  # reshape xv to (bsz, seq_len, n_local_kv_heads, head_dim)
-
-        # Apply rotary position embeddings
+                pos_cis: torch.Tensor):
+        bsz, seq_len, _ = x.shape
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
+        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
+        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
+
         xq, xk = apply_rotary_emb(xq, xk, pos_cis)
-        # KV cache
-        if past_key_value is not None:
-            xk = torch.cat([past_key_value[0], xk], dim=1)
-            xv = torch.cat([past_key_value[1], xv], dim=1)
-        past_kv = (xk, xv) if use_cache else None
-
-        # Repeat key/value heads
-        xq, xk, xv = (
-            xq.transpose(1, 2),
-            repeat_kv(xk, self.n_rep).transpose(1, 2),
-            repeat_kv(xv, self.n_rep).transpose(1, 2)
-        )
-
-        # If db_value is provided, adjust its shape to the number of heads and merge it into xv
-        if db_value is not None:
-            # Ensure db_value's shape is compatible with xv; assume db_value is [B, N, H, D]
-            if db_value.ndim == 4:  # [B, N, H, D]
-                db_value = db_value.transpose(1, 2)  # -> [B, H, N, D]
-
-            # Check whether the D dimension needs adjusting
-            if db_value.shape[-1] != xv.shape[-1]:
-                # If db_value's dimension differs from xv's, a projection layer could be added,
-                # or a simple adjustment can be used here;
-                # we simply adjust the dimension by mean pooling or repetition
-                if db_value.shape[-1] > xv.shape[-1]:
-                    # Reduce the dimension
-                    factor = db_value.shape[-1] // xv.shape[-1]
-                    db_value = db_value.view(bsz, self.n_local_heads, seq_len, factor, xv.shape[-1])
-                    db_value = db_value.mean(dim=3)
-                else:
-                    # Expand the dimension
-                    factor = xv.shape[-1] // db_value.shape[-1]
-                    db_value = db_value.unsqueeze(-1).repeat(1, 1, 1, 1, factor)
-                    db_value = db_value.view(bsz, self.n_local_heads, seq_len, xv.shape[-1])
-
-            # Add db_value to xv (simple fusion; other fusion methods are possible)
-            xv = xv + db_value
-
-        # Use Flash Attention
+
         if self.flash and seq_len != 1:
             dropout_p = self.dropout if self.training else 0.0
             output = F.scaled_dot_product_attention(
```
```diff
@@ -161,56 +333,9 @@
 
         output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
         output = self.resid_dropout(self.wo(output))
-        return output, past_kv
-
-
-class CrossAttention(nn.Module):
-    def __init__(
-            self,
-            config
-    ):
-        super().__init__()
-        self.config = config
-        self.num_heads = 8
-        self.head_dim = self.config.dim // self.num_heads
-        self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
-        self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
-        self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
-
-        self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
-
-    def forward(self, x, db, context_mask=None, pos_emb=None):
-        batch_size = x.size(0)
-
-        # Split into heads
-        q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
-        k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
-        v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
-
-        if pos_emb is not None:
-            pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
-            q = q + pos_emb
-            k = k + pos_emb
-            v = v + pos_emb
-
-        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
-
-        if context_mask is not None:
-            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
-            attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
-
-        attn_weights = F.softmax(attn_scores, dim=-1)
-
-        context = torch.matmul(attn_weights, v)
-
-        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
-
-        context = self.to_out(context)
-
-        return context
+        return output
 
 
 class FeedForward(nn.Module):
     def __init__(self, config: LMConfig):
         super().__init__()
```
@ -343,168 +468,31 @@ class MOEFeedForward(nn.Module):
|
|||||||
 class MiniMindBlock(nn.Module):
-    def __init__(self, layer_id: int, config: LMConfig):
+    def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
         super().__init__()
         self.n_heads = config.n_heads
         self.dim = config.dim
         self.head_dim = config.dim // config.n_heads
-        self.attention = Attention(config)
-        self.cross_att = CrossAttention(config)
+        self.self_attention = Attention(config)
+        self.cross_attention = CrossAttention(config)
+        self.knowledge_dataset = knowledge_dataset

         self.layer_id = layer_id
         self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
         self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
         self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)

-        # Assume num_experts is the square root of the defined total number of experts
-
-        # Parameters for query generation
-
-        # Build the query-generation module
-        # if weight_down_embed is not None:
-        #     self.to_queries = nn.Sequential(
-        #         nn.Linear(config.dim, self.dim_key * 2, bias=False),
-        #         # nn.Unflatten(2, (2, self.n_heads, self.dim_key))  # replaces Rearrange
-        #     )
-
-        #     # Hyperparameters
-        #     self.product_key_topk = min(16, self.num_keys)  # never more than num_keys
-        #     self.num_experts_per_head_topk = 1  # number of experts finally selected per head

-    def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=True):
-        # import pdb;pdb.set_trace()
-        # db_value = None
-
-        # # If weight_down_embed is present, use the Product Key mechanism
-        # if self.weight_down_embed is not None:
-        #     # 1. Generate queries
-        #     batch_size, seq_len, dim = x.shape
-
-        #     # collapse sequence dimension by averaging
-        #     x_flat = x.mean(dim=1)  # [batch_size, dim]
-        #     queries = self.to_queries(x_flat)  # [batch_size, 2*dim_key]
-        #     queries = queries.reshape(batch_size, 2, self.dim_key)  # [batch_size, 2, dim_key]
-        #     queries = queries.permute(1, 0, 2)  # [2, batch_size, dim_key]
-
-        #     # 2. Compute the similarity between queries and keys
-        #     sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
-
-        #     # 3. Top-k within each of the two subspaces
-        #     scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
-        #     scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
-        #     indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
-
-        #     # 4. Combine the scores and indices of the two subspaces
-        #     all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
-        #     all_scores = all_scores.view(*all_scores.shape[:-2], -1)
-
-        #     all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
-        #     all_indices = all_indices.view(*all_indices.shape[:-2], -1)
-
-        #     # 5. Final top-k selection
-        #     scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
-        #     indices = all_indices.gather(-1, pk_indices)
-
-        #     # 6. Fetch the expert values from the embedding
-        #     flat_indices = indices.view(-1)  # flatten the indices into a 1-D tensor
-        #     db_values = self.weight_down_embed(flat_indices)
-
-        #     # Reshape back to the original shape
-        #     db_value = db_values.view(batch_size, -1, dim)

-        # Attention computation
-        h_attn, past_kv = self.attention(
+    def forward(self, x, pos_cis):
+        h_attn = self.self_attention(
             self.attention_norm(x),
-            pos_cis,
-            past_key_value=past_key_value,
-            use_cache=use_cache,
-            db_value=db_value
+            pos_cis
         )
+        db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
-        h_attn = self.cross_att(h_attn, db_value)
+        h_attn = self.cross_attention(h_attn, db_embeddings)

-        # Residual connection
         h = x + h_attn

-        # Feed-forward network
         out = h + self.feed_forward(self.ffn_norm(h))
-        return out, past_kv
+        return out
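Read together, the `+` lines give the new per-layer data flow. Assembled in one place (a sketch reconstructed from the added lines above, not the verbatim file):

```python
# New MiniMindBlock.forward, assembled from the added lines above (sketch)
def forward(self, x, pos_cis):
    h_attn = self.self_attention(self.attention_norm(x), pos_cis)    # pre-norm self-attention
    db, db_embeddings = self.knowledge_dataset.search_index(h_attn)  # look up knowledge entries
    h_attn = self.cross_attention(h_attn, db_embeddings)             # attend over the retrieved entries
    h = x + h_attn                                                   # residual connection
    out = h + self.feed_forward(self.ffn_norm(h))                    # feed-forward with residual
    return out
```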
-class ExtractDB(nn.Module):
-    def __init__(self, params):
-        # Adjust the expert count and knowledge dimension so the count has an integer square root
-        super().__init__()
-        self.batch_size = None
-        self.dim = params.dim
-        self.dim_key = self.dim // 2
-        self.num_experts = 10 * 10  # 100 experts; must be a perfect square
-        # Set knowledge_dim equal to head_dim so it can be used directly inside attention
-        self.head_dim = params.dim // params.n_heads
-        self.knowledge_dim = 8 * params.dim
-
-        # Use register_buffer instead of nn.Parameter to avoid gradient issues
-        self.register_buffer('weight_down_embed', torch.randn(self.num_experts, self.knowledge_dim) * 0.02)
-
-        self.num_keys = int(math.sqrt(self.num_experts)) if self.num_experts > 0 else 0
-        self.product_key_topk = min(16, self.num_keys)
-        self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02)
-        self.num_experts_per_head_topk = 1
-        self.to_queries = nn.Sequential(
-            nn.Linear(params.dim, self.dim_key * 2, bias=False),
-        )
-
-    def q_to_k(self, x):
-        # 1. Generate queries
-        self.batch_size, seq_len, dim = x.shape
-
-        # collapse sequence dimension by averaging
-        x_flat = x.mean(dim=1)  # [batch_size, dim]
-
-        queries = self.to_queries(x_flat)  # [batch_size, 2*dim_key]
-        queries = queries.reshape(self.batch_size, 2, self.dim_key)  # [batch_size, 2, dim_key]
-        queries = queries.permute(1, 0, 2)  # [2, batch_size, dim_key]
-
-        # 2. Compute the similarity between queries and keys
-        sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
-
-        # 3. Top-k within each of the two subspaces
-        scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
-        scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
-        indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
-
-        # 4. Combine the scores and indices of the two subspaces
-        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
-        all_scores = all_scores.view(*all_scores.shape[:-2], -1)
-
-        all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
-        all_indices = all_indices.view(*all_indices.shape[:-2], -1)
-
-        # 5. Final top-k selection
-        scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
-        indices = all_indices.gather(-1, pk_indices)
-        flat_indices = indices.view(-1)
-        return flat_indices
-
-    def get_data(self, index):
-        # Fetch the embeddings directly on the GPU
-        db_values = self.weight_down_embed[index]
-        db_value = db_values.view(self.batch_size, -1, self.dim)
-        return db_value
-
-    @torch.no_grad()
-    def updata_value(self, k, v):
-        # Update the buffer values in place (no gradients needed)
-        v_reshaped = v.view(v.size(0), -1)
-        # Make sure the dtypes match
-        v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
-        self.weight_down_embed[k] = v_reshaped
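Although ExtractDB is removed on master, the product-key lookup that `q_to_k` implemented is worth stating precisely, since KnowledgeDataset plays the analogous role. With $N$ experts and $K=\sqrt{N}$ keys per subspace, the query is split in two and each half is scored independently:

```latex
q = [\, q^{(1)}; q^{(2)} \,], \qquad
s_{ij} = \langle q^{(1)}, k^{(1)}_i \rangle + \langle q^{(2)}, k^{(2)}_j \rangle,
\qquad 1 \le i, j \le K
```

so only $2K$ inner products are computed instead of $N = K^2$. Taking the top-$t$ scores in each subspace leaves $t^2$ candidate pairs, and the flat expert index of a pair $(i, j)$ is $iK + j$, which is exactly the `all_indices` computation above.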
 class MiniMindLM(PreTrainedModel):
     config_class = LMConfig
@@ -515,130 +503,63 @@ class MiniMindLM(PreTrainedModel):
         self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
         self.dropout = nn.Dropout(params.dropout)
-        # Old weight_down_embed declaration removed
-        self.extract_db = ExtractDB(self.params)
-
-        # Pass self.weight_down_embed into every MiniMindBlock
-        self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
+        self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
+        self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.tok_embeddings.weight = self.output.weight

-        # Calculate input dimension
-        input_dim = (self.params.max_seq_len-1)*self.params.n_layers
-        # Use a bottleneck architecture to reduce parameters
-        bottleneck_dim = 256  # Significantly smaller bottleneck dimension
-
-        # Factorized shared downsampling using two smaller convolutions
-        self.shared_downsample = nn.Sequential(
-            # First reduce input dimension to bottleneck
-            nn.Conv1d(input_dim, bottleneck_dim, kernel_size=1, padding='same'),
-            nn.ReLU(),  # Non-linearity to improve representation capacity
-            # Then expand to target dimension
-            nn.Conv1d(bottleneck_dim, 128*8, kernel_size=1, padding='same')
-        )
-
-        # Specific layers for v path
-        self.downsample_v_specific = nn.Sequential(
-            nn.Conv1d(128*8, 128, kernel_size=1, padding='same'),
-            nn.Conv1d(128, 8, kernel_size=1, padding='same')
-        )
-
-        # Specific layers for q path
-        self.downsample_q_specific = nn.Sequential(
-            nn.Conv1d(128*8, 512, kernel_size=1, padding='same')
-        )
         self.register_buffer("pos_cis",
                              precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                              persistent=False)
-        self.params = params
+        self.OUT = CausalLMOutputWithPast()
+        self.freeze_embedding = False

     def forward(self,
                 input_ids: Optional[torch.Tensor] = None,
-                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
-                use_cache: bool = False,
                 logits_to_keep: Union[int, torch.Tensor] = 0,
+                step: int = 0,
                 **args):
-        past_key_values = past_key_values or [None] * len(self.layers)
         start_pos = args.get('start_pos', 0)
+        if self.freeze_embedding and step == 0:
+            self.tok_embeddings.weight.requires_grad = False
+            # Also freeze the KnowledgeDataset embedding updates
+            self.knowledge_dataset.freeze_embedding = True
+            print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
+            print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding)
         h = self.dropout(self.tok_embeddings(input_ids))
         pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
-        past_kvs = []
-        h_list = []

         for l, layer in enumerate(self.layers):
-            # Database disabled: replace the database lookup with a fixed value
-            if self.params.disable_db:
-                # Build a [batch_size, n_layers, dim] tensor whose entries are all 1e-4
-                batch_size = h.size(0)
-                db_value = torch.full((batch_size, self.n_layers, self.params.dim), 1e-4,
-                                      dtype=h.dtype, device=h.device)
-            else:
-                # Normal mode: query the database
-                index = self.extract_db.q_to_k(h)
-                db_value = self.extract_db.get_data(index)
-
-            h, past_kv = layer(
-                h, db_value, pos_cis,
-                past_key_value=past_key_values[l],
-                use_cache=use_cache
+            h = layer(
+                h, pos_cis
             )

-            past_kvs.append(past_kv)
-            h_list.append(h.unsqueeze(0))
-
-        h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3)
-
-        # Run the database-update logic only when the database is not disabled
-        if not self.params.disable_db:
-            # Use detach() to cut the computation graph and avoid backpropagating twice
-            h_tensor_detached = h_tensor.detach()
-            h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0], -1, self.params.dim)
-
-            # The database update runs separately from the main computation graph
-            with torch.no_grad():
-                # Compute shared downsampling layer once
-                shared_features = self.shared_downsample(h_tensor_detached)
-                z_v = self.downsample_v_specific(shared_features)
-                z_q = self.downsample_q_specific(shared_features)
-                z_k = self.extract_db.q_to_k(z_q)
-                self.extract_db.updata_value(z_k, z_v)

         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.output(self.norm(h)[:, slice_indices, :])
         aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))

         # Simplified further: keep only the essential parameters
         output = CausalLMOutputWithPast(
             logits=logits,
-            past_key_values=past_kvs,
         )
         output.hidden_states = h
         output.aux_loss = aux_loss
-
-        # Try to attach other attributes (if supported)
-        # try:
-        #     output.hidden_states = h
-        #     output.aux_loss = aux_loss
-        # except:
-        #     pass
         return output

     @torch.inference_mode()
     def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
-                 stream=False, rp=1., use_cache=True, pad_token_id=0, num_return_sequences=1, **args):
+                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
         # Streaming generation
         if stream:
-            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
+            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)

         # Direct generation
         generated = []
         for i in range(input_ids.size(0)):
             non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
             for _ in range(num_return_sequences):
-                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
+                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                 tokens_list = [tokens[:, -1:] for tokens in out]
                 gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                 full_sequence = torch.cat([non_pad, gen], dim=-1)
@@ -655,13 +576,13 @@ class MiniMindLM(PreTrainedModel):
         res = output.view(input_ids.size(0) * num_return_sequences, -1)
         return res

-    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
+    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
         start, first_seq, past_kvs = input_ids.shape[1], True, None
         while input_ids.shape[1] < max_new_tokens - 1:
-            if first_seq or not use_cache:
-                out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False
+            if first_seq:
+                out, first_seq = self(input_ids, **args), False
             else:
-                out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
+                out = self(input_ids[:, -1:],
                            start_pos=input_ids.shape[1] - 1, **args)
             logits, past_kvs = out.logits[:, -1, :], out.past_key_values
             logits[:, list(set(input_ids.tolist()[0]))] /= rp
@@ -679,4 +600,4 @@ class MiniMindLM(PreTrainedModel):
             input_ids = torch.cat((input_ids, input_ids_next), dim=1)
             yield input_ids[:, start:]
             if input_ids_next.item() == eos_token_id:
                 break
preprocessing/README_trex_processor.md (new file, 154 lines)
@@ -0,0 +1,154 @@
# TREx Dataset Processing Tool: Usage Guide

This tool processes the TREx dataset in two steps:
1. **Sentence extraction**: extract triples from the TREx dataset and turn them into natural-language sentences
2. **LLM processing**: use the ollama qwen3:4b model to correct the sentences and score their importance

## 🆕 Anti-hang mechanisms

To deal with hangs that can occur during LLM processing, the following features were added:

### Timeout and retry
- **Timeout**: each LLM request times out after 60 seconds
- **Retries**: up to 2 retries on failure, with exponential backoff
- **Concurrency control**: concurrency lowered to 4 requests to reduce load on the server
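The implementation is not shown here, but the timeout-and-retry behaviour above amounts to the following pattern (a sketch; `client.generate` and its signature are illustrative assumptions, not this tool's actual API):

```python
import asyncio

# Cap concurrency as described above (at most 4 in-flight requests)
LLM_SEMAPHORE = asyncio.Semaphore(4)

async def call_llm_with_retry(client, prompt, timeout=60, max_retries=2):
    """Per-request timeout with exponential backoff.

    `client.generate` is a hypothetical coroutine standing in for the real
    ollama call; only the timeout/retry structure mirrors this README.
    """
    async with LLM_SEMAPHORE:
        for attempt in range(max_retries + 1):
            try:
                return await asyncio.wait_for(client.generate(prompt), timeout=timeout)
            except (asyncio.TimeoutError, ConnectionError):
                if attempt == max_retries:
                    raise  # give up; the caller falls back to the original sentence
                await asyncio.sleep(2 ** attempt)  # exponential backoff: 1 s, 2 s, ...
```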
### Heartbeat monitoring
- **Live monitoring**: checks the LLM response status every 30 seconds
- **Anomaly warnings**: warns when there has been no successful response for more than 30 seconds
- **Service checks**: automatically checks the status of the ollama service
- **Detailed statistics**: live success rate, timeout rate, and other statistics
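The heartbeat monitor can be pictured as a small background thread (a sketch using the intervals stated above; the variable and function names are illustrative):

```python
import threading
import time

last_success = time.monotonic()  # workers update this after every successful LLM response

def heartbeat(interval=30, warn_after=30, fatal_after=90):
    """Warn when no successful LLM response has been seen recently."""
    while True:
        time.sleep(interval)
        silent_for = time.monotonic() - last_success
        if silent_for > fatal_after:
            print(f"❌ severe warning: no successful response for {silent_for:.0f}s")
        elif silent_for > warn_after:
            print(f"⚠️ heartbeat: no successful response for {silent_for:.0f}s")

threading.Thread(target=heartbeat, daemon=True).start()
```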
### Logging
- **Detailed logs**: every operation is recorded under the `logs/` directory
- **Dual output**: messages go to both the log file and the console
- **Timestamps**: log file names include the startup timestamp

### Improved error handling
- **Recovery**: when LLM processing fails, the original sentence and a default score are used
- **Status monitoring**: the ollama service status is checked before processing starts
- **Inter-batch rest**: a 5-second pause between batches avoids overloading the service

## Installing dependencies

```bash
pip install agno asyncio pydantic requests
```

Make sure ollama is installed and running, and pull the qwen3:4b model:
```bash
ollama pull qwen3:4b
```

## Usage

### 1. Full pipeline (both steps in sequence)

```bash
python trex_to_sentences_simple.py --step all --input_dir dataset/TREx --max_files 2
```

### 2. Step by step

#### Step 1: sentence extraction only
```bash
python trex_to_sentences_simple.py --step extract --input_dir dataset/TREx --sentences_json my_sentences.json --max_files 2
```

#### Step 2: LLM processing only
```bash
python trex_to_sentences_simple.py --step llm --sentences_json my_sentences.json --output_file final_output.txt
```

## Main parameters

- `--step`: which step to run
  - `extract`: sentence extraction only
  - `llm`: LLM processing only
  - `all`: full pipeline (default)

- `--input_dir`: TREx dataset directory (default: `dataset/TREx`)
- `--sentences_json`: JSON file of extracted sentences (default: `extracted_sentences.json`)
- `--output_file`: final output file (default: `trex_sentences_enhanced.txt`)
- `--max_files`: maximum number of files to process (for testing)
- `--no_llm`: disable LLM processing

## Output files

**Note: all output files are written automatically to the corresponding directories.**

### Sentence extraction
- `output/extracted_sentences.json`: raw extracted sentences, with metadata

### LLM processing
- `output/{output_file}.txt`: corrected sentences as plain text
- `output/{output_file}.json`: full processing results (original sentence, corrected sentence, score)
- `output/{output_file}_sorted_by_importance.txt`: sentences sorted by importance score

### Checkpoints
- `output/{output_file}_checkpoint_{count}.json`: checkpoint saved automatically every 1000 sentences

### Logs
- `logs/trex_processor_{timestamp}.log`: detailed processing log

## 🆕 Troubleshooting

### If the process hangs:

1. **Check the log file**: look at the newest log under `logs/`
2. **Watch the heartbeat monitor**: note any heartbeat warnings on the console
3. **Check the ollama service**:
   ```bash
   ps aux | grep ollama
   curl http://localhost:11434/api/tags
   ```
4. **Restart the ollama service** (if needed):
   ```bash
   pkill ollama
   ollama serve &
   ```

### Common warnings:

- `⚠️ 心跳检测` (heartbeat check): no successful response for 30 seconds (normally recovers on its own)
- `❌ 严重警告` (severe warning): no successful response for 90 seconds (the service may need attention)
- `💀 Ollama服务异常` (ollama service anomaly): the ollama service may have stopped
- `💀 致命错误` (fatal error): several consecutive warnings (restarting the program is recommended)

## Checkpoint recovery

- Step 2 automatically detects existing checkpoint files (in the `output/` directory)
- Only sentences that have not been processed yet are handled, avoiding duplicate work
- If every sentence has already been processed, the final output file is generated directly
- After an interruption, re-running resumes automatically from the latest checkpoint (see the sketch below)
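The resume logic only needs the set of already-processed sentences from the newest checkpoint. A sketch (the checkpoint layout, a JSON list of result dicts, is an assumption consistent with the file names above):

```python
import glob
import json
import os

def load_processed(output_dir="output", prefix="trex_sentences_enhanced"):
    """Return the items from the newest checkpoint file, or an empty list."""
    checkpoints = glob.glob(os.path.join(output_dir, f"{prefix}_checkpoint_*.json"))
    if not checkpoints:
        return []
    latest = max(checkpoints, key=os.path.getmtime)
    with open(latest, "r", encoding="utf-8") as f:
        return json.load(f)

processed = load_processed()
done = {item.get("original_sentence") for item in processed if isinstance(item, dict)}
# Step 2 then only sends sentences whose text is not in `done` to the LLM.
```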
## Example workflow

```bash
# 1. Extract the sentences first (fast)
python trex_to_sentences_simple.py --step extract --max_files 5

# 2. Run the LLM processing afterwards (slow, but resumable)
python trex_to_sentences_simple.py --step llm

# If interrupted, running step 2 again resumes from the latest checkpoint
python trex_to_sentences_simple.py --step llm
```

## Performance characteristics

- **Conservative concurrency**: at most 4 concurrent LLM requests (lower hang risk)
- **Checkpointing**: automatic save every 1000 sentences, with resume support
- **Monitoring**: detailed progress reporting and time estimates
- **Robust error handling**: failed LLM requests fall back to the original sentence and a default score
- **Service monitoring**: automatic detection of the ollama service status

## Notes

1. Step 1 must be completed before running step 2 for the first time
2. Checkpoint files take extra disk space (each one contains all data processed so far)
3. LLM processing speed depends on model performance and network conditions
4. Test on a small batch first with `--max_files`
5. **New**: if the process hangs, check the log file and the heartbeat messages
6. **New**: the program automatically detects and reports the ollama service status
7. **New**: every processing step is logged in detail to help diagnosis
preprocessing/merge_output_json.py (new file, 225 lines)
@@ -0,0 +1,225 @@
#!/usr/bin/env python3
"""
JSON merge script.
Reads several JSON files and merges them into a single JSON file.
"""

import json
import os
from typing import Dict, List, Any, Union

# List of JSON files to merge
JSON_FILES_TO_MERGE = [
    "output/trex_sentences_enhanced_checkpoint_360000.json"
]
for i in range(1, 1010):
    JSON_FILES_TO_MERGE.append(f"output/trex_sentences_enhanced_batch_{i}.json")

def load_json_file(file_path: str) -> Union[Dict, List, None]:
    """Load a JSON file."""
    if not os.path.exists(file_path):
        print(f"Warning: file {file_path} does not exist")
        return None

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"Loaded: {file_path}")
        return data
    except json.JSONDecodeError as e:
        print(f"Error: cannot parse JSON file {file_path} - {e}")
        return None
    except Exception as e:
        print(f"Error: failed to read file {file_path} - {e}")
        return None

def merge_json_data(data1: Union[Dict, List], data2: Union[Dict, List]) -> Union[Dict, List]:
    """Merge two JSON data structures."""

    # If both are lists, simply concatenate them
    if isinstance(data1, list) and isinstance(data2, list):
        print(f"Merging two lists: {len(data1)} + {len(data2)} = {len(data1) + len(data2)} items")
        return data1 + data2

    # If both are dicts
    elif isinstance(data1, dict) and isinstance(data2, dict):
        print("Merging two dict structures")
        merged = data1.copy()

        # Special case: if both have a 'sentences' list, merge the sentences
        if 'sentences' in data1 and 'sentences' in data2:
            if isinstance(data1['sentences'], list) and isinstance(data2['sentences'], list):
                print(f"Merging 'sentences': {len(data1['sentences'])} + {len(data2['sentences'])} = {len(data1['sentences']) + len(data2['sentences'])} items")
                merged['sentences'] = data1['sentences'] + data2['sentences']

            # Update the metadata if it exists
            if 'metadata' in merged:
                if isinstance(merged['metadata'], dict):
                    merged['metadata']['total_sentences'] = len(merged['sentences'])
                    merged['metadata']['merged_from'] = [os.path.basename(f) for f in JSON_FILES_TO_MERGE if os.path.exists(f)]

            # Merge the remaining fields
            for key, value in data2.items():
                if key != 'sentences' and key not in merged:
                    merged[key] = value

            return merged

        # Plain dict merge
        for key, value in data2.items():
            if key in merged:
                # Duplicate key and both values are lists: concatenate
                if isinstance(merged[key], list) and isinstance(value, list):
                    merged[key] = merged[key] + value
                # Duplicate key and both values are dicts: merge recursively
                elif isinstance(merged[key], dict) and isinstance(value, dict):
                    merged[key] = merge_json_data(merged[key], value)
                else:
                    # Otherwise keep the value from the second file
                    merged[key] = value
                    print(f"Field '{key}' was overwritten")
            else:
                merged[key] = value

        return merged

    # Type mismatch: build a new structure that contains both
    else:
        print("Data types do not match; building a new structure containing both")
        return {
            "data_from_save.json": data1,
            "data_from_save2.json": data2,
            "merged_at": "test.py"
        }

def save_merged_json(data: Union[Dict, List], output_path: str):
    """Save the merged JSON data."""
    try:
        # Make sure the output directory exists
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        print(f"Merged result saved to: {output_path}")

        # Print statistics
        if isinstance(data, dict):
            if 'sentences' in data and isinstance(data['sentences'], list):
                print(f"Total sentences: {len(data['sentences'])}")
            print(f"Total fields: {len(data)}")
        elif isinstance(data, list):
            print(f"Total list items: {len(data)}")

    except Exception as e:
        print(f"Error: failed to save file - {e}")

def remove_duplicates_from_sentences(data: Union[Dict, List]) -> Union[Dict, List]:
    """Remove duplicate sentences from the merged result (based on sentence content)."""
    if isinstance(data, dict) and 'sentences' in data:
        if isinstance(data['sentences'], list):
            original_count = len(data['sentences'])
            seen_sentences = set()
            unique_sentences = []

            for item in data['sentences']:
                if isinstance(item, dict):
                    # For dicts, use the sentence / corrected_sentence / original_sentence field as the key
                    sentence_key = item.get('sentence') or item.get('corrected_sentence') or item.get('original_sentence')
                elif isinstance(item, str):
                    sentence_key = item
                else:
                    sentence_key = str(item)

                if sentence_key and sentence_key not in seen_sentences:
                    seen_sentences.add(sentence_key)
                    unique_sentences.append(item)

            data['sentences'] = unique_sentences

            # Update the metadata
            if 'metadata' in data and isinstance(data['metadata'], dict):
                data['metadata']['total_sentences'] = len(unique_sentences)
                data['metadata']['duplicates_removed'] = original_count - len(unique_sentences)

            print(f"Deduplication done: {original_count} -> {len(unique_sentences)} ({original_count - len(unique_sentences)} duplicates removed)")

    return data

def merge_multiple_json_data(data_list: List[Union[Dict, List]]) -> Union[Dict, List]:
    """Merge several JSON data structures."""
    if not data_list:
        return {}

    if len(data_list) == 1:
        return data_list[0]

    print(f"Preparing to merge {len(data_list)} JSON data structures")

    # Start from the first structure and fold the rest in one by one
    merged_data = data_list[0]

    for i, data in enumerate(data_list[1:], 1):
        print(f"Merging structure {i+1}...")
        merged_data = merge_json_data(merged_data, data)

    return merged_data

def main():
    """Entry point."""
    print("=== JSON merge script ===")

    # Output path
    output_path = "output/merged.json"

    print("Files to merge:")
    for i, file_path in enumerate(JSON_FILES_TO_MERGE, 1):
        print(f"  {i}. {file_path}")
    print(f"Output file: {output_path}")
    print()

    # Load every file
    loaded_data = []
    successfully_loaded = []

    for file_path in JSON_FILES_TO_MERGE:
        data = load_json_file(file_path)
        if data is not None:
            loaded_data.append(data)
            successfully_loaded.append(file_path)

    # Make sure at least one file loaded
    if not loaded_data:
        print("Error: no file could be loaded, exiting")
        return

    print(f"Successfully loaded {len(loaded_data)} files:")
    for file_path in successfully_loaded:
        print(f"  ✓ {file_path}")

    if len(loaded_data) < len(JSON_FILES_TO_MERGE):
        failed_count = len(JSON_FILES_TO_MERGE) - len(loaded_data)
        print(f"Warning: {failed_count} files failed to load")
    print()

    # Merge everything
    if len(loaded_data) == 1:
        print("Only one file is available; using it directly...")
        merged_data = loaded_data[0]
    else:
        print("Merging all files...")
        merged_data = merge_multiple_json_data(loaded_data)

    # Deduplicate
    print("\nChecking for and removing duplicates...")
    merged_data = remove_duplicates_from_sentences(merged_data)

    # Save the merged result
    print("\nSaving the merged result...")
    save_merged_json(merged_data, output_path)

    print("\n=== Merge complete ===")
    print(f"Merged data from {len(successfully_loaded)} files")

if __name__ == "__main__":
    main()
preprocessing/trex_to_sentences_simple.py (new file, 1238 lines; content not shown)

requirements.txt (114 lines changed)
@@ -1,30 +1,120 @@
+aiohappyeyeballs==2.6.1
+aiohttp==3.11.17
+aiosignal==1.3.2
+altair==5.5.0
+annotated-types==0.7.0
+anyio==4.9.0
+async-timeout==5.0.1
+attrs==25.3.0
+blinker==1.9.0
+cachetools==5.5.2
+certifi==2025.1.31
+charset-normalizer==3.4.1
+click==8.1.8
+contourpy==1.3.2
+cycler==0.12.1
 datasets==2.21.0
 datasketch==1.6.4
+dill==0.3.8
+distro==1.9.0
+docker-pycreds==0.4.0
+einops==0.8.1
+exceptiongroup==1.2.2
+filelock==3.18.0
 Flask==3.0.3
-Flask_Cors==4.0.0
+Flask-Cors==4.0.0
+fonttools==4.57.0
+frozenlist==1.6.0
+fsspec==2024.6.1
+gitdb==4.0.12
+GitPython==3.1.44
+h11==0.14.0
+hjson==3.1.0
+httpcore==1.0.8
+httpx==0.28.1
+huggingface-hub==0.30.2
+idna==3.10
+importlib_metadata==7.2.1
+itsdangerous==2.2.0
 jieba==0.42.1
+Jinja2==3.1.2
+jiter==0.9.0
+joblib==1.4.2
 jsonlines==4.0.0
+jsonschema==4.23.0
+jsonschema-specifications==2024.10.1
+kiwisolver==1.4.8
+markdown-it-py==3.0.0
+MarkupSafe==3.0.2
 marshmallow==3.22.0
 matplotlib==3.10.0
+mdurl==0.1.2
+modelscope==1.25.0
+mpmath==1.3.0
+msgpack==1.1.0
+multidict==6.4.3
+multiprocess==0.70.16
+narwhals==1.35.0
+networkx==3.4.2
 ngrok==1.4.0
+ninja==1.11.1.4
 nltk==3.8
 numpy==1.26.4
 openai==1.59.6
+packaging==23.2
 pandas==1.5.3
 peft==0.7.1
+pillow==10.4.0
+platformdirs==4.3.7
+propcache==0.3.1
+protobuf==4.25.6
 psutil==5.9.8
+py-cpuinfo==9.0.0
+pyarrow==19.0.1
 pydantic==2.8.2
+pydantic_core==2.20.1
+pydeck==0.9.1
+Pygments==2.19.1
+pyparsing==3.2.3
+python-dateutil==2.9.0.post0
+pytz==2025.2
+PyYAML==6.0.2
+referencing==0.36.2
+regex==2024.11.6
+requests==2.32.3
 rich==13.7.1
-scikit_learn==1.5.1
-sentence_transformers==2.3.1
-tiktoken==0.5.1
-transformers==4.48.0
-jinja2==3.1.2
-jsonlines==4.0.0
-trl==0.13.0
-ujson==5.1.0
-wandb==0.18.3
+rpds-py==0.24.0
+safetensors==0.5.3
+scikit-learn==1.5.1
+scipy==1.15.2
+sentence-transformers==2.3.1
+sentencepiece==0.2.0
+sentry-sdk==2.26.1
+setproctitle==1.3.5
 simhash==2.1.2
+six==1.17.0
+smmap==5.0.2
+sniffio==1.3.1
 streamlit==1.30.0
-torch==2.2.2
-torchvision==0.17.2
+sympy==1.13.3
+tenacity==8.5.0
+threadpoolctl==3.6.0
+tiktoken==0.5.1
+tokenizers==0.21.1
+toml==0.10.2
+tornado==6.4.2
+tqdm==4.67.1
+transformers==4.48.0
+triton==3.3.0
+trl==0.13.0
+typing_extensions==4.13.2
+tzlocal==5.3.1
+ujson==5.1.0
+urllib3==2.4.0
+validators==0.34.0
+wandb==0.18.3
+watchdog==6.0.0
+Werkzeug==3.1.3
+xxhash==3.5.0
+yarl==1.20.0
+zipp==3.21.0
run_file/DynamicKV-LLM_Mini_Minimind.sh (new file, 34 lines)
@@ -0,0 +1,34 @@
#!/bin/bash

# Activate the conda environment
source $(conda info --base)/etc/profile.d/conda.sh
conda activate mini

# Environment variables that help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

# Option 1: use a pre-built accelerate config file
# accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
#     --epochs 3 \
#     --batch_size 24 \
#     --learning_rate 2e-4 \
#     --dtype bfloat16 \
#     --accumulation_steps 32 \
#     --grad_clip 1.0 \
#     --log_interval 100 \
#     --save_interval 10000 \
#     --dim 1024 \
#     --n_layers 32 \
#     --max_seq_len 1024 \
#     --use_flash_attn \
#     --profile \
#     --profile_interval 10

# Option 2: configure accelerate directly on the command line
CUDA_VISIBLE_DEVICES=0 /opt/conda/envs/mini/bin/python -m accelerate.commands.launch \
    --num_processes=1 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    train_pretrain_accelerate.py \
run_file/DynamicKV-LLM_Small_Minimind.sh (new file, 50 lines)
@@ -0,0 +1,50 @@
#!/bin/bash

# Activate the conda environment
source $(conda info --base)/etc/profile.d/conda.sh
conda activate ycz_accelerate

# Environment variables that help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

# Option 1: use a pre-built accelerate config file
# accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
#     --epochs 3 \
#     --batch_size 24 \
#     --learning_rate 2e-4 \
#     --dtype bfloat16 \
#     --accumulation_steps 32 \
#     --grad_clip 1.0 \
#     --log_interval 100 \
#     --save_interval 10000 \
#     --dim 1024 \
#     --n_layers 32 \
#     --max_seq_len 1024 \
#     --use_flash_attn \
#     --profile \
#     --profile_interval 10

# Option 2: configure accelerate directly on the command line
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
    --multi_gpu \
    --num_processes=4 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    train_pretrain_accelerate.py \
    --epochs 3 \
    --batch_size 24 \
    --learning_rate 2e-4 \
    --dtype bfloat16 \
    --accumulation_steps 32 \
    --grad_clip 1.0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 1024 \
    --n_layers 32 \
    --max_seq_len 1024 \
    --use_flash_attn \
    --profile \
    --profile_interval 10 \
    --knowledge_num 16384 \
    --knowledge_length 64
test_real_rope.py (new file, 97 lines)
@@ -0,0 +1,97 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Test the real-valued version of the positional encoding.
"""

import torch
from model.model import precompute_pos_cis, precompute_pos_cis_real, apply_rotary_emb, apply_rotary_emb_real
from model.LMConfig import LMConfig
from model.model import MiniMindLM

def test_pos_encoding_equivalence():
    """Check that the complex and real-valued positional encodings are equivalent."""
    print("Testing positional-encoding equivalence...")

    # Settings
    dim = 64
    seq_len = 10

    # Complex-valued positional encoding
    pos_cis = precompute_pos_cis(dim=dim, end=seq_len)

    # Real-valued positional encoding
    pos_cis_real = precompute_pos_cis_real(dim=dim, end=seq_len)

    # Random queries and keys
    batch_size = 2
    n_heads = 4
    head_dim = dim

    xq = torch.randn(batch_size, seq_len, n_heads, head_dim)
    xk = torch.randn(batch_size, seq_len, n_heads, head_dim)

    # Apply the complex-valued rotary embedding
    xq_complex, xk_complex = apply_rotary_emb(xq, xk, pos_cis)

    # Apply the real-valued rotary embedding
    xq_real, xk_real = apply_rotary_emb_real(xq, xk, pos_cis_real)

    # Measure the difference
    q_diff = torch.abs(xq_complex - xq_real).mean().item()
    k_diff = torch.abs(xk_complex - xk_real).mean().item()

    print(f"Query difference: {q_diff:.6f}")
    print(f"Key difference: {k_diff:.6f}")

    # Check whether the difference is within tolerance
    tolerance = 1e-5
    if q_diff < tolerance and k_diff < tolerance:
        print("✅ Passed: the complex and real-valued positional encodings are numerically equivalent")
    else:
        print("❌ Failed: the complex and real-valued positional encodings differ significantly")

def test_model_forward():
    """Test a model forward pass."""
    print("\nTesting model forward pass...")

    # Model configuration
    config = LMConfig(
        dim=128,
        n_layers=2,
        n_heads=4,
        n_kv_heads=4,  # n_kv_heads must be set, and n_heads must be divisible by n_kv_heads
        vocab_size=1000,
        max_seq_len=128,
        disable_db=True  # disable the database feature to avoid extra complexity
    )

    # Build the model
    try:
        model = MiniMindLM(config)
        print("✅ Model initialized successfully")
    except Exception as e:
        print(f"❌ Model initialization failed: {str(e)}")
        return

    # Inputs
    batch_size = 2
    seq_len = 10
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

    # Forward pass
    try:
        with torch.no_grad():
            outputs = model(input_ids)
        print("✅ Forward pass succeeded")
        print(f"Output shape: {outputs.logits.shape}")
    except Exception as e:
        print(f"❌ Forward pass failed: {str(e)}")

if __name__ == "__main__":
    # Positional-encoding equivalence
    test_pos_encoding_equivalence()

    # Model forward pass
    test_model_forward()
@@ -13,6 +13,7 @@ from torch import optim, nn
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim.lr_scheduler import CosineAnnealingLR
 from torch.utils.data import DataLoader, DistributedSampler
+# Communication-analysis tooling import removed
 from contextlib import nullcontext
 from typing import Optional
@@ -40,19 +41,69 @@ def get_lr(current_step, total_steps, lr):
 def train_epoch(epoch, wandb):
     loss_fct = nn.CrossEntropyLoss(reduction='none')
     start_time = time.time()
+    # Define moe_path at the start of the function so exception handlers never reference an undefined variable
     moe_path = '_moe' if lm_config.use_moe else ''
-    for step, (X, Y, loss_mask) in enumerate(train_loader):
+    # CUDA events for performance profiling
+    if args.profile and (not ddp or dist.get_rank() == 0):
+        data_start = torch.cuda.Event(enable_timing=True)
+        data_end = torch.cuda.Event(enable_timing=True)
+        forward_start = torch.cuda.Event(enable_timing=True)
+        forward_end = torch.cuda.Event(enable_timing=True)
+        backward_start = torch.cuda.Event(enable_timing=True)
+        backward_end = torch.cuda.Event(enable_timing=True)
+        optimizer_start = torch.cuda.Event(enable_timing=True)
+        optimizer_end = torch.cuda.Event(enable_timing=True)
+
+    # CUDA-graph optimization code removed
+
+    # Data prefetching
+    prefetch_factor = 2  # number of batches to prefetch
+    data_iter = iter(train_loader)
+    prefetch_batches = []
+
+    # Prefetch the initial batches
+    for _ in range(min(prefetch_factor, len(train_loader))):
         try:
-            # Move the batch onto the device
-            X = X.to(args.device)
-            Y = Y.to(args.device)
-            loss_mask = loss_mask.to(args.device)
+            batch = next(data_iter)
+            prefetch_batches.append([t.to(args.device, non_blocking=True) for t in batch])
+        except StopIteration:
+            break
+
+    for step in range(len(train_loader)):
+        try:
+            # Time the data loading
+            if args.profile and (not ddp or dist.get_rank() == 0):
+                data_start.record()
+
+            # Use a prefetched batch
+            if prefetch_batches:
+                X, Y, loss_mask = prefetch_batches.pop(0)
+            else:
+                # If the prefetch queue is empty, load directly
+                X, Y, loss_mask = [t.to(args.device) for t in next(data_iter)]
+
+            # Asynchronously prefetch the next batch
+            if step + prefetch_factor < len(train_loader):
+                try:
+                    batch = next(data_iter)
+                    prefetch_batches.append([t.to(args.device, non_blocking=True) for t in batch])
+                except StopIteration:
+                    pass
+
+            if args.profile and (not ddp or dist.get_rank() == 0):
+                data_end.record()
+
             # Update the learning rate
             lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
             for param_group in optimizer.param_groups:
                 param_group['lr'] = lr
+
+            # Time the forward pass
+            if args.profile and (not ddp or dist.get_rank() == 0):
+                forward_start.record()
+
+            # Regular forward pass
             with ctx:
                 res = model(X)
                 loss = loss_fct(
@@ -76,6 +127,13 @@ def train_epoch(epoch, wandb):
                 # On error, skip adding the auxiliary loss
                 loss = loss / args.accumulation_steps
+
+            # Backward pass
+            scaler.scale(loss).backward()
+
+            if args.profile and (not ddp or dist.get_rank() == 0):
+                forward_end.record()
+                backward_start.record()
+
             # Print data types for debugging
             if step == 0 and (not ddp or dist.get_rank() == 0):  # Print only for the first step of the first epoch on the main process
                 Logger("---- Data Type Check ----")
@@ -88,9 +146,21 @@ def train_epoch(epoch, wandb):
                 Logger(f"loss.dtype: {loss.dtype}")
                 Logger("-------------------------")

-            scaler.scale(loss).backward()
+            if args.profile and (not ddp or dist.get_rank() == 0):
+                backward_end.record()
+
+                # Profile on every step, not only when gradient accumulation completes
+                if (step + 1) % args.profile_interval == 0:
+                    # Record the optimizer timer (if this is a gradient-accumulation step)
+                    if (step + 1) % args.accumulation_steps == 0:
+                        optimizer_start.record()
+
+            # Optimizer step
             if (step + 1) % args.accumulation_steps == 0:
+                if args.profile and (not ddp or dist.get_rank() == 0):
+                    if (step + 1) % args.profile_interval != 0:
+                        optimizer_start.record()
+
                 scaler.unscale_(optimizer)
                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
@@ -99,6 +169,40 @@ def train_epoch(epoch, wandb):
                 optimizer.zero_grad(set_to_none=True)
+
+                if args.profile and (not ddp or dist.get_rank() == 0):
+                    optimizer_end.record()
+
+            # Profiling output (every profile_interval steps)
+            if args.profile and (not ddp or dist.get_rank() == 0) and (step + 1) % args.profile_interval == 0:
+                # Synchronize the CUDA events so the timings are accurate
+                torch.cuda.synchronize()
+
+                # Per-phase timings
+                data_time = data_start.elapsed_time(data_end)
+                forward_time = forward_start.elapsed_time(forward_end)
+                backward_time = backward_start.elapsed_time(backward_end)
+
+                # Optimizer timing only exists when a gradient-accumulation step completed
+                if (step + 1) % args.accumulation_steps == 0:
+                    optimizer_time = optimizer_start.elapsed_time(optimizer_end)
+                    total_compute_time = forward_time + backward_time + optimizer_time
+                    Logger(f"Profiling - step {step+1}:")
+                    Logger(f"  data loading: {data_time:.2f} ms")
+                    Logger(f"  forward pass: {forward_time:.2f} ms")
+                    Logger(f"  backward pass: {backward_time:.2f} ms")
+                    Logger(f"  optimizer: {optimizer_time:.2f} ms")
+                    Logger(f"  total compute: {total_compute_time:.2f} ms")
+                    Logger(f"  compute/data ratio: {total_compute_time / data_time:.2f}")
+                else:
+                    # Not a gradient-accumulation step, so there is no optimizer timing
+                    total_compute_time = forward_time + backward_time
+                    Logger(f"Profiling - step {step+1} (accumulating gradients):")
+                    Logger(f"  data loading: {data_time:.2f} ms")
+                    Logger(f"  forward pass: {forward_time:.2f} ms")
+                    Logger(f"  backward pass: {backward_time:.2f} ms")
+                    Logger(f"  total compute: {total_compute_time:.2f} ms")
+                    Logger(f"  compute/data ratio: {total_compute_time / data_time:.2f}")
+
             # Logging
             if step % args.log_interval == 0:
                 spend_time = time.time() - start_time
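The profiling added in this patch is built on CUDA events; isolated from the training loop, the essential pattern is the one below. Note that `elapsed_time` is only valid after both events have completed, which is why the patch calls `torch.cuda.synchronize()` before reading the timers.

```python
import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()   # enqueue a marker on the current CUDA stream
# ... GPU work to be timed ...
end.record()     # enqueue the closing marker

torch.cuda.synchronize()              # wait until both markers have been reached
elapsed_ms = start.elapsed_time(end)  # wall time between the markers, in milliseconds
```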
@@ -113,14 +217,44 @@ def train_epoch(epoch, wandb):
                                spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))

                 if (wandb is not None) and (not ddp or dist.get_rank() == 0):
-                    wandb.log({"loss": loss.item() * args.accumulation_steps,
-                               "lr": optimizer.param_groups[-1]['lr'],
-                               "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
+                    log_dict = {
+                        "loss": loss.item() * args.accumulation_steps,
+                        "lr": optimizer.param_groups[-1]['lr'],
+                        "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
+                    }
+
+                    # If profiling is enabled, also log the performance metrics
+                    if args.profile and (step + 1) % args.profile_interval == 0:
+                        # Basic performance metrics
+                        perf_dict = {
+                            "data_time_ms": data_time,
+                            "forward_time_ms": forward_time,
+                            "backward_time_ms": backward_time
+                        }
+
+                        # Optimizer timing only exists when a gradient-accumulation step completed
+                        if (step + 1) % args.accumulation_steps == 0:
+                            total_compute_time = forward_time + backward_time + optimizer_time
+                            perf_dict.update({
+                                "optimizer_time_ms": optimizer_time,
+                                "compute_time_ms": total_compute_time
+                            })
+                        else:
+                            total_compute_time = forward_time + backward_time
+                            perf_dict.update({
+                                "compute_time_ms": total_compute_time
+                            })
+
+                        log_dict.update(perf_dict)
+
+                    wandb.log(log_dict)
+
+            # Communication-analysis code removed

             # Save the model
             if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
                 model.eval()
-                # moe_path = '_moe' if lm_config.use_moe else ''
+                # Use the moe_path variable defined at the top of the function
                 ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'

                 if isinstance(model, torch.nn.parallel.DistributedDataParallel):
@@ -136,7 +270,7 @@ def train_epoch(epoch, wandb):
             save_path = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}_nanERROR.pth'
             if os.path.exists(save_path):
                 os.remove(save_path)

             if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                 state_dict = model.module.state_dict()
             else:
@@ -146,18 +280,18 @@ def train_epoch(epoch, wandb):
             for name, param in model.named_parameters():
                 if param.grad is not None and torch.isnan(param.grad).any():
                     print(f"NaN gradient in parameter: {name}")

             for name, param in model.named_parameters():
                 if param.grad is not None and torch.isnan(param.grad).any():
                     print(f"Parameter {name} values: {param.data}")
                     print(f"Parameter {name} gradients: {param.grad}")

             raise ValueError("NaN gradient detected")


 def init_model(lm_config, pretrained_embedding_path: Optional[str] = None):
     # Load the tokenizer
-    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
+    tokenizer = AutoTokenizer.from_pretrained('/mnt/lzn/Minimind/Minimind/model/minimind_tokenizer')
     # Load the model
     model = MiniMindLM(lm_config).to(args.device)
@ -175,6 +309,9 @@ def init_model(lm_config, pretrained_embedding_path: Optional[str] = None):
|
|||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
# 移除通信分析函数
|
||||||
|
|
||||||
|
|
||||||
def init_distributed_mode():
|
def init_distributed_mode():
|
||||||
if not ddp: return #如果没有启用分布式数据并行(DDP),直接返回,不执行任何操作。
|
if not ddp: return #如果没有启用分布式数据并行(DDP),直接返回,不执行任何操作。
|
||||||
global ddp_local_rank, DEVICE #声明这两个变量为全局变量,以便在函数外部也能访问它们。
|
global ddp_local_rank, DEVICE #声明这两个变量为全局变量,以便在函数外部也能访问它们。
|
||||||
@ -193,35 +330,42 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--out_dir", type=str, default="out")
|
parser.add_argument("--out_dir", type=str, default="out")
|
||||||
     # To reach "zero" as fast as possible, set epochs to 1; otherwise make full use of the limited data and train for 2-6 epochs.
     parser.add_argument("--epochs", type=int, default=3)
-    parser.add_argument("--batch_size", type=int, default=32)
+    parser.add_argument("--batch_size", type=int, default=24)
-    parser.add_argument("--learning_rate", type=float, default=5e-4)
+    parser.add_argument("--learning_rate", type=float, default=2e-4)
     parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")  # use the GPU if available, otherwise the CPU
     parser.add_argument("--dtype", type=str, default="bfloat16")
     parser.add_argument("--use_wandb", default=True, action="store_true")
     parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
-    parser.add_argument("--num_workers", type=int, default=8)
+    parser.add_argument("--num_workers", type=int, default=48)
     parser.add_argument("--ddp", action="store_true")
-    parser.add_argument("--accumulation_steps", type=int, default=8)  # gradient accumulation steps, controls how often gradients are applied
+    parser.add_argument("--accumulation_steps", type=int, default=32)  # gradient accumulation steps, controls how often gradients are applied
     parser.add_argument("--grad_clip", type=float, default=1.0)  # gradient clipping threshold, guards against exploding gradients
     parser.add_argument("--warmup_iters", type=int, default=0)  # number of warmup iterations for the learning-rate schedule
     parser.add_argument("--log_interval", type=int, default=100)  # how often (in steps) to print logs
-    parser.add_argument("--save_interval", type=int, default=100)  # how often (in steps) to save the model
+    parser.add_argument("--save_interval", type=int, default=10000)  # how often (in steps) to save the model
     parser.add_argument('--local_rank', type=int, default=-1)  # local process rank for distributed training
-    parser.add_argument('--dim', default=768, type=int)  # model width, controls model size
+    parser.add_argument('--dim', default=1024, type=int)  # model width, controls model size
-    parser.add_argument('--n_layers', default=8, type=int)  # number of layers
+    parser.add_argument('--n_layers', default=32, type=int)  # number of layers
-    parser.add_argument('--max_seq_len', default=512, type=int)  # maximum input sequence length
+    parser.add_argument('--max_seq_len', default=1024, type=int)  # maximum input sequence length
     parser.add_argument('--use_moe', default=False, type=bool)  # whether to use MoE
     parser.add_argument('--disable_db', action='store_true', help="disable the database feature and use a fixed value of 1e-4 instead")  # disables the database and enables a special mode
-    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")  # dataset path
+    parser.add_argument("--data_path", type=str, default="/mnt/lzn/Minimind/dataset/dir/pretrain_hq.jsonl")  # dataset path
     parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
+    # Profiling-related arguments
+    parser.add_argument("--profile", action="store_true", default=True, help="enable profiling")
+    parser.add_argument("--profile_interval", type=int, default=10, help="profiling print interval (steps)")
+    parser.add_argument("--use_flash_attn", action="store_true", default=True, help="enable FlashAttention")
     args = parser.parse_args()
+    print(args)

     lm_config = LMConfig(
         dim=args.dim,
         n_layers=args.n_layers,
         max_seq_len=args.max_seq_len,
         use_moe=args.use_moe,
-        disable_db=args.disable_db  # pass the disable-database flag
+        disable_db=args.disable_db,  # pass the disable-database flag
+        flash_attn=args.use_flash_attn  # enable FlashAttention support
     )  # build the LMConfig object that controls the model configuration
     args.save_dir = os.path.join(args.out_dir)  # build the save-directory path
     os.makedirs(args.save_dir, exist_ok=True)  # create the save directory
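For a sense of scale, the new defaults above shift the effective optimizer batch considerably. A minimal sketch of the arithmetic, assuming the usual definition (batch_size × accumulation_steps) and a single GPU:

# Sketch with the new defaults above; single-GPU training assumed.
batch_size = 24
accumulation_steps = 32
max_seq_len = 1024
effective_batch = batch_size * accumulation_steps   # 768 sequences per optimizer step
tokens_per_update = effective_batch * max_seq_len   # 786,432 tokens per optimizer step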
@@ -254,35 +398,42 @@ if __name__ == "__main__":
     if args.use_wandb and (not ddp or ddp_local_rank == 0):
         import wandb

         # Merge args and lm_config parameters for wandb config
         config = vars(args).copy()
         config.update(lm_config.__dict__)

         wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config)
     else:
         wandb = None

     model, tokenizer = init_model(lm_config, args.pretrained_embedding_path)
     train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
     train_sampler = DistributedSampler(train_ds) if ddp else None
+    # Tune the DataLoader configuration
     train_loader = DataLoader(
         train_ds,
         batch_size=args.batch_size,
         pin_memory=True,
+        pin_memory_device=f"cuda:{ddp_local_rank}" if ddp else "cuda:0",  # device that pinned memory is bound to
         drop_last=False,
         shuffle=False,
         num_workers=args.num_workers,
-        sampler=train_sampler
+        sampler=train_sampler,
+        persistent_workers=True if args.num_workers > 0 else False,  # keep worker processes alive between epochs
+        prefetch_factor=2 if args.num_workers > 0 else None  # prefetch factor
     )

-    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16']))
+    # Enable GradScaler only for float16; bfloat16 does not need it
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
     optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

     if ddp:
         model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
-        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
+        # Keep find_unused_parameters=True: the model genuinely has unused parameters
+        model = DistributedDataParallel(model, device_ids=[ddp_local_rank], find_unused_parameters=True)

+    # Keep set_detect_anomaly for debugging for now;
+    # once training is stable, comment this out for speed
     torch.autograd.set_detect_anomaly(True)
     iter_per_epoch = len(train_loader)
     for epoch in range(args.epochs):
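The GradScaler change above only matters on the float16 path (bfloat16 skips loss scaling entirely). For reference, a sketch of the standard torch.cuda.amp pattern such a scaler slots into; the names are placeholders for this script's real loop:

# Sketch of a standard float16 step with a GradScaler (compute_loss, X, Y are placeholders).
with torch.cuda.amp.autocast(dtype=torch.float16):
    loss = compute_loss(model, X, Y)     # placeholder for the real loss computation
scaler.scale(loss).backward()            # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)                   # unscales gradients, steps only if they are finite
scaler.update()                          # adapt the scale factor for the next step
optimizer.zero_grad()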
628
train_pretrain_accelerate.py
Normal file
@@ -0,0 +1,628 @@
import os
# Set environment variables
os.environ["WANDB_MODE"] = "offline"  # or use "dryrun"
import platform
import argparse
from tqdm import tqdm
import time
import math
import warnings
import pandas as pd
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from contextlib import nullcontext
from typing import Optional
import datetime  # Add datetime for time formatting
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.utils import DeepSpeedPlugin
from accelerate.utils import DistributedDataParallelKwargs
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

from model.model import MiniMindLM, RMSNorm
from model.LMConfig import LMConfig
from model.dataset import PretrainDataset

warnings.filterwarnings('ignore')
# Logging helper
def Logger(msg, accelerator=None):
    # Print only on the main process when an accelerator is provided; print unconditionally otherwise
    if accelerator is None or accelerator.is_main_process:
        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}")

# Helper function to format seconds into HH:MM:SS
def format_time(seconds):
    return str(datetime.timedelta(seconds=int(seconds)))
# Learning-rate helper
def get_lr(it, num_iters, learning_rate):
    # Cosine learning-rate decay
    return learning_rate * 0.5 * (1.0 + math.cos(math.pi * it / num_iters))
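# Sketch, assuming lr=2e-4 and num_iters=1000: the cosine decay above gives the
# full LR at it=0, half of it at the midpoint, and 0 at it=num_iters.
def _get_lr_endpoints_sketch(lr=2e-4, num_iters=1000):
    assert abs(get_lr(0, num_iters, lr) - lr) < 1e-12                  # start: full LR
    assert abs(get_lr(num_iters // 2, num_iters, lr) - lr / 2) < 1e-12 # midpoint: half LR
    assert abs(get_lr(num_iters, num_iters, lr)) < 1e-12               # end: 0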
# Model initialization
def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None):
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model = MiniMindLM(lm_config)

    # Default model initialization
    Logger("Performing default model initialization...")

    # Initialize the embedding weights
    nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02)

    # Initialize the output projection (when its weights are not tied to the embeddings)
    if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight:
        nn.init.normal_(model.output.weight, mean=0.0, std=0.02)

    # Initialize every linear layer
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            # Xavier/Glorot initialization
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            # Embedding layers use a normal distribution
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm):
            # RMSNorm weights are initialized to 1
            if hasattr(module, 'weight'):
                nn.init.ones_(module.weight)

    # Initialize positional-encoding-related parameters
    if hasattr(model.knowledge_dataset, 'keys'):
        nn.init.normal_(model.knowledge_dataset.keys, mean=0.0, std=0.02)

    Logger("Default model initialization completed")
    # Load pretrained embedding weights if provided
    if pretrained_embedding_path:
        Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
        pretrained_embeddings = torch.load(pretrained_embedding_path)
        model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
        model.output.weight.data.copy_(pretrained_embeddings)  # tied weights
    if database_init_path:
        import json
        import os

        # Database parameters
        knowledge_num = args.knowledge_num
        knowledge_length = args.knowledge_length

        # Check whether a cache can be used
        cache_dir = os.path.dirname(args.cluster_cache_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)

        processed_tensor = None

        # Try to load cached processing results
        if not args.recompute_clusters and os.path.exists(args.cluster_cache_path):
            try:
                Logger(f"Loading cached processed results from {args.cluster_cache_path}")
                processed_tensor = torch.load(args.cluster_cache_path)

                # Validate that the cached tensor's shape is usable
                cached_knowledge_num, cached_knowledge_length = processed_tensor.shape

                if cached_knowledge_length == knowledge_length:
                    if cached_knowledge_num >= knowledge_num:
                        # The cache is large enough; truncate it to size
                        processed_tensor = processed_tensor[:knowledge_num, :]
                        Logger(f"Successfully loaded cached data with shape {processed_tensor.shape}")
                        Logger(f"Truncated from cached shape ({cached_knowledge_num}, {cached_knowledge_length}) to required shape ({knowledge_num}, {knowledge_length})")
                        Logger("Skipping database initialization - using cached results")
                    else:
                        # The cache is too small; recompute
                        Logger(f"Cached knowledge_num ({cached_knowledge_num}) < required knowledge_num ({knowledge_num}), recomputing...")
                        processed_tensor = None
                else:
                    # knowledge_length mismatch; recompute
                    Logger(f"Cached knowledge_length ({cached_knowledge_length}) != required knowledge_length ({knowledge_length}), recomputing...")
                    processed_tensor = None
            except Exception as e:
                Logger(f"Failed to load cached data: {e}, recomputing...")
                processed_tensor = None
        # Only run database initialization and processing when no valid cache exists
        if processed_tensor is None:
            Logger(f"Loading database initialization data from {database_init_path}")

            # 1. Load the JSON file
            with open(database_init_path, 'r', encoding='utf-8') as f:
                database_data = json.load(f)

            # Extract the sentences list
            sentences_data = database_data.get('sentences', [])
            Logger(f"Loaded {len(sentences_data)} sentences from database")

            # 2. Sort by importance_score (highest first)
            sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
            Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")

            # 3. Process each entry individually, without clustering
            Logger("Processing individual sentences...")
            processed_rows = []

            # Get the padding token id
            pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

            # Process as many sentences as needed
            num_to_process = min(knowledge_num, len(sorted_sentences))

            for i in range(num_to_process):
                sentence_data = sorted_sentences[i]
                sentence = sentence_data.get('corrected_sentence', '')

                # Convert the sentence to tokens
                sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)

                # Truncate or pad to knowledge_length
                original_length = len(sentence_tokens)
                if original_length > knowledge_length:
                    # Too long: truncate
                    sentence_tokens = sentence_tokens[:knowledge_length]
                    Logger(f"Sentence {i+1} truncated from {original_length} to {knowledge_length} tokens")
                else:
                    # Too short: pad with the padding token
                    sentence_tokens.extend([pad_token_id] * (knowledge_length - len(sentence_tokens)))
                    if original_length < knowledge_length:
                        Logger(f"Sentence {i+1} padded from {original_length} to {knowledge_length} tokens")

                processed_rows.append(sentence_tokens)

                if (i + 1) % 1000 == 0:
                    Logger(f"Processed {i + 1}/{num_to_process} sentences")

            # If there are not enough sentences, fill the remaining rows with padding tokens
            while len(processed_rows) < knowledge_num:
                empty_tokens = [pad_token_id] * knowledge_length
                processed_rows.append(empty_tokens)
                if len(processed_rows) % 1000 == 0:
                    Logger(f"Added empty entry {len(processed_rows)}/{knowledge_num}")

            Logger(f"Finished adding empty entries. Total: {len(processed_rows)}/{knowledge_num}")

            # Convert to a tensor
            processed_tensor = torch.tensor(processed_rows, dtype=torch.long)

            Logger("Data processing completed:")
            Logger(f"  - Processed {num_to_process} sentences")
            Logger(f"  - Added {knowledge_num - num_to_process} empty entries")
            Logger(f"  - Final shape: {processed_tensor.shape}")
            Logger(f"  - Expected shape: ({knowledge_num}, {knowledge_length})")
            # Save the processed results to the cache file
            try:
                torch.save(processed_tensor, args.cluster_cache_path)
                Logger(f"Processed results saved to {args.cluster_cache_path}")
            except Exception as e:
                Logger(f"Failed to save processed results: {e}")

        # 4. Initialize the model's knowledge_dataset
        if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
            model.knowledge_dataset.knowledge_dataset.data.copy_(processed_tensor)
            Logger("Successfully initialized model.knowledge_dataset.knowledge_dataset with processed data")
        else:
            Logger("Warning: Could not find model.knowledge_dataset.knowledge_dataset to initialize")
            # Keep a global reference as a fallback
            globals()['processed_database'] = processed_tensor

        Logger("Database embeddings and sentences stored in model")

    Logger(f'Total trainable LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
    return model, tokenizer
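# Sketch of a typical init_model invocation with the knowledge base wired in;
# the Namespace fields mirror the CLI flags defined in main() below, and the
# LMConfig keywords shown are assumptions about its accepted arguments.
def _init_model_usage_sketch():
    from argparse import Namespace
    demo_args = Namespace(
        knowledge_num=8192,
        knowledge_length=32,
        cluster_cache_path="./cache/cluster_tokens_single.pt",
        recompute_clusters=False,
    )
    demo_config = LMConfig(dim=512, n_layers=8, max_seq_len=512, use_moe=False)
    return init_model(demo_config, database_init_path="./dataset/database_init.json", args=demo_args)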
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb):
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    epoch_start_time = time.time()
    total_steps_in_epoch = len(train_loader)
    total_training_steps = args.epochs * total_steps_in_epoch
    moe_path = '_moe' if args.use_moe else ''
    best_loss = float('inf')

    # CUDA events for profiling (main process only)
    if args.profile and accelerator.is_main_process:
        data_start = torch.cuda.Event(enable_timing=True)
        data_end = torch.cuda.Event(enable_timing=True)
        forward_start = torch.cuda.Event(enable_timing=True)
        forward_end = torch.cuda.Event(enable_timing=True)
        backward_start = torch.cuda.Event(enable_timing=True)
        backward_end = torch.cuda.Event(enable_timing=True)
        optimizer_start = torch.cuda.Event(enable_timing=True)
        optimizer_end = torch.cuda.Event(enable_timing=True)

    # Data prefetching
    prefetch_factor = 2  # number of batches to prefetch
    data_iter = iter(train_loader)
    prefetch_batches = []

    # Prefetch the initial batches
    for _ in range(min(prefetch_factor, len(train_loader))):
        try:
            batch = next(data_iter)
            prefetch_batches.append(batch)
        except StopIteration:
            break

    # Initialize the variables needed for logging before entering the loop
    last_log_time = epoch_start_time
    for step in range(total_steps_in_epoch):
        try:
            # Time data loading (main process only)
            if args.profile and accelerator.is_main_process:
                data_start.record()

            # Use prefetched data when available
            if prefetch_batches:
                X, Y, loss_mask = prefetch_batches.pop(0)
            else:
                # Prefetch queue is empty: load directly
                X, Y, loss_mask = next(data_iter)

            # Asynchronously prefetch the next batch
            if step + prefetch_factor < len(train_loader):
                try:
                    batch = next(data_iter)
                    prefetch_batches.append(batch)
                except StopIteration:
                    pass

            # End data-loading timing (main process only)
            if args.profile and accelerator.is_main_process:
                data_end.record()

            # Update the learning rate
            if scheduler is not None:
                scheduler.step()

            # Time the forward pass (main process only)
            if args.profile and accelerator.is_main_process:
                forward_start.record()

            # Forward pass
            with ctx:
                if step == 0 and args.embedding_epoch == epoch:
                    # Set freeze_embedding on the original model, not the wrapped one
                    unwrapped_model = accelerator.unwrap_model(model)
                    unwrapped_model.freeze_embedding = True
                    Logger(f"Set freeze_embedding=True for epoch {epoch}, step {step}", accelerator)
                res = model(X, step=step)
                loss = loss_fct(
                    res.logits.view(-1, res.logits.size(-1)),
                    Y.view(-1)
                ).view(Y.size())
                loss = (loss * loss_mask).sum() / loss_mask.sum()
                # Add the auxiliary loss if present
                try:
                    aux_loss = sum(l.feed_forward.aux_loss for l in model.module.layers
                                   if hasattr(l.feed_forward, 'aux_loss'))
                    loss += aux_loss
                except Exception as e:
                    Logger(f"Warning: Could not add auxiliary loss: {e}")
                    # On failure, skip the auxiliary loss
                loss = loss / args.accumulation_steps
            # End forward timing (main process only)
            if args.profile and accelerator.is_main_process:
                forward_end.record()

            # Time the backward pass (main process only)
            if args.profile and accelerator.is_main_process:
                backward_start.record()

            # Backward pass
            # With DeepSpeed, gradient accumulation and clipping are handled automatically
            accelerator.backward(loss)

            # End backward timing (main process only)
            if args.profile and accelerator.is_main_process:
                backward_end.record()

            # Time the optimizer step (main process only)
            if args.profile and accelerator.is_main_process:
                optimizer_start.record()

            # Optimizer step - with DeepSpeed, gradient accumulation and clipping are
            # handled automatically and the update only applies once the accumulation
            # count is reached, so there is no need to check step % accumulation_steps here
            optimizer.step()

            # With DeepSpeed, zero_grad() is called automatically after step(),
            # but we still call it explicitly to be safe
            optimizer.zero_grad()

            # End optimizer timing (main process only)
            if args.profile and accelerator.is_main_process:
                optimizer_end.record()
            # Print training info (main process only)
            if (step + 1) % args.log_interval == 0 and accelerator.is_main_process:
                current_time = time.time()
                # Compute performance metrics
                if args.profile:
                    torch.cuda.synchronize()
                    # Compute metrics from the time since the last log, not the total time
                    data_time = data_start.elapsed_time(data_end)
                    forward_time = forward_start.elapsed_time(forward_end)
                    backward_time = backward_start.elapsed_time(backward_end)
                    optimizer_time = optimizer_start.elapsed_time(optimizer_end)
                    iter_time = (current_time - last_log_time) * 1000 / args.log_interval  # avg ms per iteration since last log
                    # total_time_ms = data_time + forward_time + backward_time + optimizer_time

                    # Print the profiling summary
                    if (step + 1) % (args.log_interval * args.profile_interval) == 0:
                        Logger(f"Profiling (Avg/iter over last {args.log_interval} steps) - "
                               f"Data: {data_time/args.log_interval:.2f}ms, "
                               f"Fwd: {forward_time/args.log_interval:.2f}ms, "
                               f"Bwd: {backward_time/args.log_interval:.2f}ms, "
                               f"Optim: {optimizer_time/args.log_interval:.2f}ms, "
                               f"Iter Time: {iter_time:.2f}ms", accelerator)
                        # Reset the events so the next measurement starts from zero
                        data_start = torch.cuda.Event(enable_timing=True)
                        data_end = torch.cuda.Event(enable_timing=True)
                        forward_start = torch.cuda.Event(enable_timing=True)
                        forward_end = torch.cuda.Event(enable_timing=True)
                        backward_start = torch.cuda.Event(enable_timing=True)
                        backward_end = torch.cuda.Event(enable_timing=True)
                        optimizer_start = torch.cuda.Event(enable_timing=True)
                        optimizer_end = torch.cuda.Event(enable_timing=True)

                # Current learning rate
                current_lr = optimizer.param_groups[0]['lr']

                # Time bookkeeping
                epoch_elapsed_time = current_time - epoch_start_time
                epoch_steps_done = step + 1
                epoch_avg_step_time = epoch_elapsed_time / epoch_steps_done
                epoch_remaining_time = epoch_avg_step_time * (total_steps_in_epoch - epoch_steps_done)

                total_elapsed_time = current_time - overall_start_time
                total_steps_done = epoch * total_steps_in_epoch + epoch_steps_done
                total_avg_step_time = total_elapsed_time / total_steps_done if total_steps_done > 0 else 0
                total_remaining_time = total_avg_step_time * (total_training_steps - total_steps_done) if total_steps_done > 0 else 0

                # Training throughput (over the most recent log_interval)
                interval_elapsed_time = current_time - last_log_time
                tokens_processed_interval = args.log_interval * args.batch_size * args.max_seq_len
                tokens_per_sec = tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0
                last_log_time = current_time  # update the last log time

                log_dict = {
                    "epoch": epoch + 1,
                    "step": step + 1,
                    "total_steps_in_epoch": total_steps_in_epoch,
                    "loss": loss.item() * args.accumulation_steps,
                    "lr": current_lr,
                    "tokens_per_sec": tokens_per_sec,
                    "epoch_time_left_seconds": epoch_remaining_time,
                    "total_time_left_seconds": total_remaining_time
                }
                Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
                       f"Loss: {log_dict['loss']:.4f}, "
                       f"LR: {log_dict['lr']:.6f}, "
                       f"Speed: {log_dict['tokens_per_sec']:.2f} tokens/sec | "
                       f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
                       f"Total Time Left: {format_time(total_remaining_time)}", accelerator)

                if args.use_wandb and accelerator.is_main_process and wandb:
                    wandb.log(log_dict)

            # Save the model (main process only)
            loss_total = loss.item() * args.accumulation_steps
            if best_loss > loss_total and accelerator.is_main_process:
                best_loss = loss_total
                # moe_path is defined at the top of the function
                ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'

                # Unwrap the model
                unwrapped_model = accelerator.unwrap_model(model)

                # Save the model parameters
                accelerator.save(unwrapped_model.state_dict(), ckp)
                Logger(f"Model saved to {ckp}", accelerator)

        except Exception as e:
            Logger(f"Error in training step: {e}", accelerator)
            import traceback
            Logger(traceback.format_exc(), accelerator)
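# Sketch of the throughput formula used in the logging block above, with
# assumed timings: 100 steps of 64 x 512 tokens in 40 s is ~82k tokens/sec.
def _throughput_sketch(log_interval=100, batch_size=64, max_seq_len=512, interval_elapsed_time=40.0):
    tokens = log_interval * batch_size * max_seq_len   # 3,276,800 tokens
    return tokens / interval_elapsed_time              # 81,920.0 tokens/sec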
def main():
    parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
    parser.add_argument("--out_dir", type=str, default="out")
    parser.add_argument("--epochs", type=int, default=4)
    parser.add_argument("--embedding_epoch", type=int, default=2, help="number of epochs for embedding training")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_wandb", default=True, action="store_true")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--accumulation_steps", type=int, default=32)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--warmup_iters", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=100)
    parser.add_argument("--save_interval", type=int, default=10000)
    parser.add_argument('--dim', default=512, type=int)
    parser.add_argument('--n_layers', default=8, type=int)
    parser.add_argument('--max_seq_len', default=512, type=int)
    parser.add_argument('--use_moe', default=False, type=bool)
    parser.add_argument('--disable_db', action='store_true', help="disable the database feature and use a fixed value of 1e-4 instead")
    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
    parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
    parser.add_argument("--profile", action="store_true", default=True, help="enable profiling")
    parser.add_argument("--profile_interval", type=int, default=10, help="profiling print interval (steps)")
    parser.add_argument("--use_flash_attn", action="store_true", default=True, help="enable FlashAttention")
    parser.add_argument("--knowledge_num", type=int, default=8192, help="number of entries in the knowledge base")
    parser.add_argument("--knowledge_length", type=int, default=32, help="sentence length of each knowledge-base entry")
    parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="database initialization path")
    parser.add_argument("--fast_clustering", action="store_true", default=True, help="use the fast approximate clustering algorithm (for large datasets)")
    parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="path of the clustering-result cache file")
    parser.add_argument("--recompute_clusters", action="store_true", default=False, help="force recomputation of clusters, ignoring the cache file")
    args = parser.parse_args()
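    # For scale, under the defaults above the knowledge tensor built by
    # init_model is small: 8192 entries x 32 tokens x 8 bytes (torch.long)
    # = 2,097,152 bytes, i.e. 2.0 MiB.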
    #########################################################
    # Initialize the accelerator and DeepSpeed
    #########################################################
    # Configure ddp_kwargs to handle unused parameters
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    # Create the DeepSpeedPlugin object
    ds_plugin = DeepSpeedPlugin(
        gradient_accumulation_steps=args.accumulation_steps,
        gradient_clipping=args.grad_clip,
        zero_stage=2,                    # use ZeRO-2 optimization
        offload_optimizer_device="cpu",  # offload optimizer state to the CPU
        offload_param_device="none",     # do not offload parameters to the CPU
    )
    accelerator = Accelerator(
        kwargs_handlers=[ddp_kwargs],
        deepspeed_plugin=ds_plugin,
        mixed_precision="bf16" if args.dtype == "bfloat16" else "fp16" if args.dtype == "float16" else "no"
    )
    #########################################################
    # Set the random seed
    #########################################################
    set_seed(1337 + accelerator.process_index)

    #########################################################
    # Configure the model
    #########################################################
    lm_config = LMConfig(
        dim=args.dim,
        n_layers=args.n_layers,
        max_seq_len=args.max_seq_len,
        use_moe=args.use_moe,
        disable_db=args.disable_db,
        flash_attn=args.use_flash_attn,
        knowledge_num=args.knowledge_num,
        knowledge_length=args.knowledge_length,
        embeddings_epoch=args.embedding_epoch
    )

    #########################################################
    # Create the save directory
    #########################################################
    args.save_dir = os.path.join(args.out_dir)
    if accelerator.is_main_process:
        os.makedirs(args.save_dir, exist_ok=True)
        os.makedirs(args.out_dir, exist_ok=True)

    #########################################################
    # Set the data type
    #########################################################
    pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]

    #########################################################
    # Configure wandb
    #########################################################
    # Set the wandb run name
    args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
    # Merge args and lm_config into one dict (also needed by the configuration printout below)
    config_dict = vars(args).copy()
    config_dict.update(vars(lm_config))
    if args.use_wandb and accelerator.is_main_process:
        import wandb
        wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config_dict)
    else:
        wandb = None
    #########################################################
    # Print info
    #########################################################
    # Number of tokens per iteration
    tokens_per_iter = args.batch_size * lm_config.max_seq_len
    if accelerator.is_main_process:
        Logger(f"tokens_per_iter: {tokens_per_iter}", accelerator)
        Logger("Configuration:", accelerator)
        for key, value in config_dict.items():
            Logger(f"  {key}: {value}", accelerator)
    #########################################################
    # Set up the automatic mixed-precision context
    #########################################################
    ctx = nullcontext() if accelerator.device.type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)

    #########################################################
    # Initialize the model and tokenizer
    #########################################################
    model, tokenizer = init_model(lm_config, args.pretrained_embedding_path, args.database_init_path, args)
    # The accelerator is passed through to the Logger calls inside init_model
    Logger('Model initialization complete', accelerator)

    #########################################################
    # Handle the positional-encoding tensor
    #########################################################
    if hasattr(model, "pos_cis_real"):
        Logger('Detected the real-valued pos_cis_real tensor; letting it participate in distributed training', accelerator)
        # Set the model's _ddp_params_and_buffers_to_ignore attribute
        # model._ddp_params_and_buffers_to_ignore = {"pos_cis_real"}
    # Backward compatibility: check whether pos_cis still exists
    elif hasattr(model, "pos_cis"):
        Logger('Detected the complex-valued pos_cis tensor; excluding it from distributed training', accelerator)
        # Set the model's _ddp_params_and_buffers_to_ignore attribute
        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
    #########################################################
    # Create the dataset and data loader
    #########################################################
    train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=False,
        shuffle=True,
        num_workers=args.num_workers,
        persistent_workers=True if args.num_workers > 0 else False,
        prefetch_factor=2 if args.num_workers > 0 else None
    )

    #########################################################
    # Create the optimizer
    #########################################################
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    #########################################################
    # Create the learning-rate scheduler
    #########################################################
    total_steps = len(train_loader) * args.epochs
    warmup_steps = args.warmup_iters if args.warmup_iters > 0 else int(0.1 * total_steps)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )
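    # Worked example with assumed sizes: if len(train_loader) were 5000, then
    # total_steps = 5000 * 4 = 20000 and, since warmup_iters defaults to 0,
    # warmup_steps = int(0.1 * 20000) = 2000.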
    #########################################################
    # Prepare for training
    #########################################################
    model, optimizer, train_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, scheduler
    )

    #########################################################
    # Training loop
    #########################################################
    overall_start_time = time.time()  # Record overall start time
    for epoch in range(args.epochs):
        train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb)  # Pass overall start time

    #########################################################
    # Shut down wandb
    #########################################################
    if args.use_wandb and accelerator.is_main_process:
        wandb.finish()

if __name__ == "__main__":
    main()