update
@@ -1,128 +0,0 @@
# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
  overall community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or
  advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
  address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
.
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series
of actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within
the community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.

Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.
199	README.md
@@ -1,199 +0,0 @@
<div align="center">

</div>

<div align="center">

[](https://github.com/jingyaogong/minimind/stargazers)
[](LICENSE)
[](https://github.com/jingyaogong/minimind/commits/master)
[](https://github.com/jingyaogong/minimind/pulls)
[](https://huggingface.co/collections/jingyaogong/minimind-66caf8d999f5c7fa64f399e5)

</div>


# 📌 Data Introduction

## Ⅰ Tokenizer

The tokenizer maps words from natural language to numbers such as `0, 1, 36` via a "dictionary"; you can think of each number as the page on which the word appears in that dictionary.
You can build your own vocabulary and train a tokenizer yourself (see `./scripts/train_tokenizer.py`; it is provided for learning only, and since MiniMind already ships with a tokenizer there is no need to retrain one unless you really must),
or you can adopt the tokenizer of a well-known open-source model.
Using a ready-made tokenizer is like using the Xinhua or Oxford dictionary directly: the advantage is a very good token compression rate, the drawback is a huge vocabulary, easily hundreds of thousands of words and phrases.
A self-trained tokenizer lets you control the vocabulary size and content freely, but its compression rate is poor (for example, "hello" might be split into the five separate tokens "h e l l o") and rare words are hard to cover.
The choice of "dictionary" certainly matters: an LLM's output is essentially an N-way classification, a softmax over the N entries of the dictionary, which is then decoded back to natural language through that dictionary.
Because MiniMind's size must be kept strictly under control, and to avoid a top-heavy model (the embedding layer would otherwise take up far too large a share of the LLM's parameters), the shorter the vocabulary, the better.
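
To make the mapping concrete, here is a minimal sketch of encoding and decoding with a Hugging Face-style tokenizer; the local path `./model/minimind_tokenizer` is an assumption and may differ in your checkout.

```python
# Minimal sketch: encode text to ids and back, assuming the tokenizer
# is stored in ./model/minimind_tokenizer (path is an assumption).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")
ids = tokenizer.encode("如何才能摆脱拖延症?")  # "page numbers" in the small vocabulary
print(ids)
print(tokenizer.decode(ids))                    # maps the ids back to text
print(tokenizer.vocab_size)                     # expected to be small (~6,400) for MiniMind
```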

<details style="color:rgb(128,128,128)">
<summary>About the tokenizer</summary>

The tokenizer vocabulary sizes of strong third-party open-source models such as Yi, qwen, chatglm, mistral and Llama3 are as follows:

<table>
<tr><th>Tokenizer</th><th>Vocabulary size</th><th>Source</th></tr>
<tr><td>yi tokenizer</td><td>64,000</td><td>01.AI (China)</td></tr>
<tr><td>qwen2 tokenizer</td><td>151,643</td><td>Alibaba Cloud (China)</td></tr>
<tr><td>glm tokenizer</td><td>151,329</td><td>Zhipu AI (China)</td></tr>
<tr><td>mistral tokenizer</td><td>32,000</td><td>Mistral AI (France)</td></tr>
<tr><td>llama3 tokenizer</td><td>128,000</td><td>Meta (USA)</td></tr>
<tr><td>minimind tokenizer</td><td>6,400</td><td>Custom</td></tr>
</table>

> 👉 Update 2024-09-17: to avoid ambiguity with earlier releases and to keep the model small, all MiniMind models use minimind_tokenizer; every mistral_tokenizer version is deprecated.

```
# A few personal notes
> Although minimind_tokenizer is small and its encode/decode efficiency is weaker than Chinese-friendly tokenizers such as qwen2 or glm,
> the MiniMind models use the self-trained minimind_tokenizer to keep the overall parameter count light and to avoid an imbalance between the embedding layer and the compute layers (a top-heavy model), since MiniMind's vocabulary size is only 6,400.
> In practical tests, minimind has never failed to decode a rare word, so the results are good.
> Because the custom vocabulary is compressed to 6,400 entries, the total LLM parameter count can be as low as 25.8M.
> The training data `tokenizer_train.jsonl` comes entirely from the `匠数大模型数据集`; this portion of the data is relatively minor, and you may choose other data freely if you do want to train a tokenizer.
```

</details>

## Ⅱ Pretrain data

After the lesson of MiniMind-V1, whose low-quality pretraining data led the model to talk nonsense, we decided after `2025-02-05` to stop using large-scale unsupervised datasets for pretraining.
Instead, the Chinese portion of the [匠数大模型数据集](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data) was extracted,
cleaned to keep only samples of `<512` characters (about 1.6GB of text), and concatenated directly into the pretraining data `pretrain_hq.jsonl`; "hq" stands for high
quality (it is of course still not truly high quality; improving data quality is a never-ending task).

The data format of `pretrain_hq.jsonl` is:

```bash
{"text": "如何才能摆脱拖延症? 治愈拖延症并不容易,但以下建议可能有所帮助..."}
```
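
As a rough illustration of how this file might be consumed, here is a minimal sketch of a pretraining data loader; the helper name, tokenizer handling, and `max_seq_len` default are assumptions, not the project's actual `PretrainDataset` implementation.

```python
# Minimal sketch (not the repo's actual dataset class): read pretrain_hq.jsonl,
# tokenize the "text" field, and truncate each sample to a fixed length.
import json

def load_pretrain_samples(path, tokenizer, max_seq_len=512):
    samples = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            text = json.loads(line)["text"]
            ids = tokenizer.encode(text)[:max_seq_len]  # truncate to max_seq_len
            samples.append(ids)
    return samples
```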

## Ⅲ SFT data

The [匠数大模型SFT数据集](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
"is a complete, uniformly formatted, and safe resource for large-model training and research.
It collects and consolidates a large number of open-source datasets from public web sources, unifies their format, and cleans the data;
it contains a Chinese dataset of 10M samples and an English dataset of 2M samples."
That is the official description. The downloaded files amount to roughly 4B tokens in total, which certainly makes them suitable SFT data for a Chinese LLM.
However, the official data format is quite messy, and using all of it for SFT would cost too much.
I cleaned the official dataset a second time, removing entries contaminated by symbols and noise, and again kept only content with a total length of `<512`
characters; the aim of this stage is to use a large volume of dialogue to fill in knowledge that the pretraining stage lacks.
The exported file is `sft_512.jsonl` (~7.5GB).

The [Magpie-SFT数据集](https://www.modelscope.cn/organization/Magpie-Align)
collects ~1M high-quality conversations from Qwen2/2.5. I cleaned this data further and exported the portion with total length `<2048` as `sft_2048.jsonl` (~9GB),
and the portion with length `<1024` as `sft_1024.jsonl` (~5.5GB). Running SFT directly on conversations produced by a large model falls under "black-box distillation".

The data from the two SFT steps above was cleaned once more (keeping only content with a high proportion of Chinese characters) and filtered to conversations of length `<512`, producing `sft_mini_512.jsonl` (~1.2GB).

All SFT files `sft_X.jsonl` share the format:

```text
{
    "conversations": [
        {"role": "user", "content": "你好"},
        {"role": "assistant", "content": "你好!"},
        {"role": "user", "content": "再见"},
        {"role": "assistant", "content": "再见!"}
    ]
}
```
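
The sketch below shows one way such a record could be flattened into a single token sequence; `apply_chat_template` assumes the MiniMind tokenizer ships a chat template, which is an assumption on my part, and the helper name is hypothetical.

```python
# Minimal sketch: turn one sft_X.jsonl record into a token sequence.
# Assumes the tokenizer defines a chat template; otherwise a manual
# concatenation of role-tagged turns would be needed instead.
import json

def build_sft_example(line, tokenizer, max_seq_len=512):
    conversations = json.loads(line)["conversations"]
    text = tokenizer.apply_chat_template(conversations, tokenize=False)
    return tokenizer.encode(text)[:max_seq_len]
```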

## Ⅳ RLHF data

The [Magpie-DPO数据集](https://www.modelscope.cn/datasets/Magpie-Align/MagpieLM-DPO-Data-v0.1)
contains roughly 200k preference pairs (all in English) generated from Llama3.1-70B/8B. It can be used to train a reward model and to optimize response quality so that replies better match human preferences.
Here, samples with a total length of `<3000` were reorganized into `dpo.jsonl` (~0.9GB), which contains the two fields `chosen` and `rejected`: `chosen`
is the preferred response and `rejected` is the rejected one.

The data format of `dpo.jsonl` is:

```text
{
    "chosen": [
        {"content": "Q", "role": "user"},
        {"content": "good answer", "role": "assistant"}
    ],
    "rejected": [
        {"content": "Q", "role": "user"},
        {"content": "bad answer", "role": "assistant"}
    ]
}
```
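
A minimal sketch of turning these records into the (prompt, chosen, rejected) triples that DPO-style trainers typically expect; it assumes the single-turn structure shown above, and the function name is hypothetical.

```python
# Minimal sketch: convert one dpo.jsonl record into a (prompt, chosen, rejected) triple.
# Assumes the single-turn structure shown above: one user message followed by one
# assistant message in each of the "chosen" and "rejected" lists.
import json

def to_dpo_triple(line):
    record = json.loads(line)
    prompt = record["chosen"][0]["content"]        # the user question is shared
    chosen = record["chosen"][-1]["content"]       # preferred assistant reply
    rejected = record["rejected"][-1]["content"]   # rejected assistant reply
    return prompt, chosen, rejected
```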

## Ⅴ Reasoning dataset

It has to be said that nothing in February 2025 was hotter than DeepSeek...
It also sparked my strong interest in RL-guided reasoning models; I have already reproduced R1-Zero with Qwen2.5.
If time allows and the results work out (though there is a 99% chance the base model's capability is insufficient), I will later update MiniMind with a reasoning model trained via RL rather than distillation.
Time is limited, and the fastest low-cost option is still direct (black-box) distillation.
Since R1 is so popular, several R1 distillation datasets appeared within just a few days, such as [R1-Llama-70B](https://www.modelscope.cn/datasets/Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B), [R1-Distill-SFT](https://www.modelscope.cn/datasets/AI-ModelScope/R1-Distill-SFT),
[Alpaca-Distill-R1](https://huggingface.co/datasets/shareAI/Alpaca-Distill-R1-ZH), and
[deepseek_r1_zh](https://huggingface.co/datasets/jinliuxi/deepseek_r1_zh); purely Chinese data is still relatively scarce.
These were finally merged and exported as `r1_mix_1024.jsonl`, whose format is identical to `sft_X.jsonl`.

## Ⅵ More datasets

[HqWu-HITCS/Awesome-Chinese-LLM](https://github.com/HqWu-HITCS/Awesome-Chinese-LLM)
is already collecting and organizing open-source models, applications, datasets, and tutorials related to Chinese LLMs, and keeps tracking the latest progress in this area. Comprehensive and professional. Respect!

---

## Ⅶ Dataset download

> [!NOTE]
> Since 2025-02-05, all datasets used for MiniMind's final training are open-sourced, so there is no need to preprocess large-scale datasets yourself, and repetitive data-processing work can be avoided.

MiniMind training datasets ([ModelScope](https://www.modelscope.cn/datasets/gongjy/minimind-dataset/files) | [HuggingFace](https://huggingface.co/datasets/jingyaogong))

> There is no need to clone everything; you can download just the files you need.

Place the downloaded dataset files in the `./dataset/` directory (✨ marks the recommended, required files):

```bash
./dataset/
├── dpo.jsonl (909MB)
├── lora_identity.jsonl (22.8KB)
├── lora_medical.jsonl (34MB)
├── pretrain_hq.jsonl (1.6GB, ✨)
├── r1_mix_1024.jsonl (340MB)
├── sft_1024.jsonl (5.6GB)
├── sft_2048.jsonl (9GB)
├── sft_512.jsonl (7.5GB)
├── sft_mini_512.jsonl (1.2GB, ✨)
└── tokenizer_train.jsonl (1GB)
```
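
For downloading individual files rather than cloning the whole repository, a sketch like the following could work; the HuggingFace `repo_id` below is an assumption and should be checked against the links above.

```python
# Minimal sketch: fetch only the files you need into ./dataset/.
# The repo_id is an assumption; verify it against the HuggingFace/ModelScope links above.
from huggingface_hub import hf_hub_download

for name in ["pretrain_hq.jsonl", "sft_mini_512.jsonl"]:  # the two ✨ recommended files
    hf_hub_download(
        repo_id="jingyaogong/minimind_dataset",  # assumed dataset repo id
        filename=name,
        repo_type="dataset",
        local_dir="./dataset",
    )
```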

<details style="color:rgb(128,128,128)">
<summary>Note: brief description of each dataset</summary>

* `dpo.jsonl` -- dataset for the RLHF stage
* `lora_identity.jsonl` -- self-identity data (e.g. "Who are you?" "I am minimind..."), recommended for LoRA training (it can also be used for full-parameter SFT; don't be constrained by the name)
* `lora_medical.jsonl` -- medical Q&A data, recommended for LoRA training (it can also be used for full-parameter SFT; don't be constrained by the name)
* `pretrain_hq.jsonl`✨ -- pretraining data, consolidated from 匠数科技 (Jiangshu Technology)
* `r1_mix_1024.jsonl` -- DeepSeek-R1-1.5B distillation data; each sample is at most 1024 characters long (so set max_seq_len=1024 for training)
* `sft_1024.jsonl` -- consolidated Qwen2.5 distillation data (a subset of sft_2048); each sample is at most 1024 characters long (so set max_seq_len=1024 for training)
* `sft_2048.jsonl` -- consolidated Qwen2.5 distillation data; each sample is at most 2048 characters long (so set max_seq_len=2048 for training)
* `sft_512.jsonl` -- consolidated 匠数科技 SFT data; each sample is at most 512 characters long (so set max_seq_len=512 for training)
* `sft_mini_512.jsonl`✨ -- a minimal consolidation of 匠数科技 SFT data plus Qwen2.5 distillation data (for quickly training a Zero model); each sample is at most 512 characters long (so set max_seq_len=512 for training)
* `tokenizer_train.jsonl` -- all taken from the `匠数大模型数据集`; this portion of the data is relatively minor (retraining the tokenizer yourself is not recommended, for the reasons above), but you may choose any dataset freely if you do want to train one

</details>


<details style="color:rgb(128,128,128)">
<summary>Notes & recommended training plans</summary>

* The MiniMind2 series was trained on roughly 20GB of corpus in total, about 4B tokens, i.e. the result of training on the full data combination above (cost: 💰💰💰💰💰💰💰💰, result: 😊😊😊😊😊😊)

* To build a Zero model from scratch as fast as possible, the recommended combination is `pretrain_hq.jsonl` + `sft_mini_512.jsonl`; see the tables below for the concrete cost and results (cost: 💰, result: 😊😊)

* Friends with some compute, or who care more about quality, should consider the former and fully reproduce MiniMind2; those with only a single GPU, or who want a quick reproduction in a short time, are strongly advised to use the latter

* As a compromise, you can also freely combine medium-sized data such as `sft_mini_512.jsonl` and `sft_1024.jsonl` for training (cost: 💰💰💰, result: 😊😊😊😊)

</details>
1509	README_en.md
BIN	images/logo.png (removed)
BIN	images/logo2.png (removed)
(several other binary image files were also removed; only their sizes were shown in the original diff view)
@@ -36,6 +36,9 @@ class LMConfig(PretrainedConfig):
         aux_loss_alpha: float = 0.1,
         seq_aux: bool = True,
         norm_topk_prob: bool = True,
+        ####################################################
+        knowlwdge_num: int = 64*64,
+        knowlwdge_length: int = 8,
         **kwargs,
     ):
         self.dim = dim
@@ -66,4 +69,7 @@ class LMConfig(PretrainedConfig):
         self.aux_loss_alpha = aux_loss_alpha  # alpha coefficient of the auxiliary loss
         self.seq_aux = seq_aux  # whether to compute the auxiliary loss at the sequence level
         self.norm_topk_prob = norm_topk_prob  # whether to normalize the top-k probabilities
+        ####################################################
+        self.knowlwdge_num = knowlwdge_num  # number of entries in the knowledge base
+        self.knowlwdge_length = knowlwdge_length  # length of each knowledge-base entry
         super().__init__(**kwargs)
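
As a quick illustration, a config carrying the new fields might be constructed as below; this is a sketch rather than code from the PR, and the import path `model.LMConfig` is an assumption about where the class lives.

```python
# Sketch only: instantiate the config with the new knowledge-base fields.
# The import path is an assumption; adjust it to wherever LMConfig is defined.
from model.LMConfig import LMConfig

config = LMConfig(
    dim=512,
    n_layers=12,
    max_seq_len=512,
    knowlwdge_num=64 * 64,   # should stay a perfect square (see ExtractDB below)
    knowlwdge_length=8,      # per-entry length of the knowledge base
)
assert int(config.knowlwdge_num ** 0.5) ** 2 == config.knowlwdge_num
```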
@@ -528,15 +528,15 @@ class ExtractDB(nn.Module):
         self.batch_size = None
         self.dim = params.dim
         self.dim_key = self.dim // 2
-        self.num_experts = 10 * 10  # 100 experts; kept a perfect square
+        self.knowlwdge_num = params.knowlwdge_num  # number of knowledge entries; must be a perfect square
         # set knowledge_dim equal to head_dim so it can be used directly in attention
         self.head_dim = params.dim // params.n_heads
-        self.knowledge_dim = 8*params.dim
+        self.knowledge_length = params.knowlwdge_length*params.dim

         # use register_buffer instead of nn.Parameter to avoid gradient issues
-        self.register_buffer('weight_down_embed', torch.randn(self.num_experts, self.knowledge_dim) * 0.02)
+        self.register_buffer('weight_down_embed', torch.randn(self.knowlwdge_num, self.knowledge_length) * 0.02)

-        self.num_keys = int(math.sqrt(self.num_experts)) if self.num_experts > 0 else 0
+        self.num_keys = int(math.sqrt(self.knowlwdge_num)) if self.knowlwdge_num > 0 else 0
         self.product_key_topk = min(16, self.num_keys)
         self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02)
         self.num_experts_per_head_topk = 1
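
To show why knowlwdge_num has to be a perfect square, here is a hedged sketch of a product-key lookup in the style this class appears to use: the query is split into two halves, each half is scored against one column of `keys`, and the best combined index pairs address the `weight_down_embed` table. This illustrates the general product-key technique and is not the repository's exact forward pass.

```python
# Sketch of product-key retrieval (illustrative, not the repo's exact code).
# knowlwdge_num = num_keys ** 2, so two 1-D top-k searches over num_keys candidates
# replace a single search over all knowlwdge_num entries.
import torch

def product_key_lookup(query, keys, weight_down_embed, topk=16, final_topk=1):
    # query: (batch, dim); keys: (num_keys, 2, dim_key) with dim_key = dim // 2
    num_keys, _, dim_key = keys.shape
    q1, q2 = query[:, :dim_key], query[:, dim_key:]   # split the query in half
    s1 = q1 @ keys[:, 0, :].t()                       # (batch, num_keys) scores, first half
    s2 = q2 @ keys[:, 1, :].t()                       # (batch, num_keys) scores, second half
    v1, i1 = s1.topk(topk, dim=-1)                    # best half-keys on each side
    v2, i2 = s2.topk(topk, dim=-1)
    combined = v1.unsqueeze(-1) + v2.unsqueeze(-2)    # (batch, topk, topk) combined scores
    flat_idx = combined.flatten(1).topk(final_topk, dim=-1).indices
    row = torch.gather(i1, 1, flat_idx // topk)       # recover the 2-D key coordinates
    col = torch.gather(i2, 1, flat_idx % topk)
    entry_idx = row * num_keys + col                  # index into the knowledge table
    return weight_down_embed[entry_idx]               # (batch, final_topk, knowledge_length)
```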
141	requirements.txt
@@ -1,30 +1,147 @@
+accelerate==1.6.0
+aiohappyeyeballs==2.6.1
+aiohttp==3.11.17
+aiosignal==1.3.2
+altair==5.5.0
+annotated-types==0.7.0
+anyio==4.9.0
+async-timeout==5.0.1
+attrs==25.3.0
+blinker==1.9.0
+cachetools==5.5.2
+certifi==2025.1.31
+charset-normalizer==3.4.1
+click==8.1.8
+contourpy==1.3.2
+cycler==0.12.1
 datasets==2.21.0
 datasketch==1.6.4
+deepspeed==0.16.7
+dill==0.3.8
+distro==1.9.0
+docker-pycreds==0.4.0
+einops==0.8.1
+exceptiongroup==1.2.2
+filelock==3.18.0
 Flask==3.0.3
-Flask_Cors==4.0.0
+Flask-Cors==4.0.0
+fonttools==4.57.0
+frozenlist==1.6.0
+fsspec==2024.6.1
+gitdb==4.0.12
+GitPython==3.1.44
+h11==0.14.0
+hjson==3.1.0
+httpcore==1.0.8
+httpx==0.28.1
+huggingface-hub==0.30.2
+idna==3.10
+importlib_metadata==7.2.1
+itsdangerous==2.2.0
 jieba==0.42.1
+Jinja2==3.1.2
+jiter==0.9.0
+joblib==1.4.2
 jsonlines==4.0.0
+jsonschema==4.23.0
+jsonschema-specifications==2024.10.1
+kiwisolver==1.4.8
+markdown-it-py==3.0.0
+MarkupSafe==3.0.2
 marshmallow==3.22.0
 matplotlib==3.10.0
+mdurl==0.1.2
+modelscope==1.25.0
+mpmath==1.3.0
+msgpack==1.1.0
+multidict==6.4.3
+multiprocess==0.70.16
+narwhals==1.35.0
+networkx==3.4.2
 ngrok==1.4.0
+ninja==1.11.1.4
 nltk==3.8
 numpy==1.26.4
+nvidia-cublas-cu11==11.11.3.6
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu11==11.8.87
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu11==11.8.89
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu11==11.8.89
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu11==9.1.0.70
+nvidia-cudnn-cu12==8.9.2.26
+nvidia-cufft-cu11==10.9.0.58
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu11==10.3.0.86
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu11==11.4.1.48
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu11==11.7.5.86
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-nccl-cu11==2.21.5
+nvidia-nccl-cu12==2.19.3
+nvidia-nvjitlink-cu12==12.8.93
+nvidia-nvtx-cu11==11.8.86
+nvidia-nvtx-cu12==12.1.105
 openai==1.59.6
+packaging==23.2
 pandas==1.5.3
 peft==0.7.1
+pillow==10.4.0
+platformdirs==4.3.7
+propcache==0.3.1
+protobuf==4.25.6
 psutil==5.9.8
+py-cpuinfo==9.0.0
+pyarrow==19.0.1
 pydantic==2.8.2
+pydantic_core==2.20.1
+pydeck==0.9.1
+Pygments==2.19.1
+pyparsing==3.2.3
+python-dateutil==2.9.0.post0
+pytz==2025.2
+PyYAML==6.0.2
+referencing==0.36.2
+regex==2024.11.6
+requests==2.32.3
 rich==13.7.1
-scikit_learn==1.5.1
+rpds-py==0.24.0
-sentence_transformers==2.3.1
+safetensors==0.5.3
+scikit-learn==1.5.1
+scipy==1.15.2
+sentence-transformers==2.3.1
+sentencepiece==0.2.0
+sentry-sdk==2.26.1
+setproctitle==1.3.5
 simhash==2.1.2
-tiktoken==0.5.1
+six==1.17.0
-transformers==4.48.0
+smmap==5.0.2
-jinja2==3.1.2
+sniffio==1.3.1
-jsonlines==4.0.0
-trl==0.13.0
-ujson==5.1.0
-wandb==0.18.3
 streamlit==1.30.0
-torch==2.2.2
+sympy==1.13.3
-torchvision==0.17.2
+tenacity==8.5.0
+threadpoolctl==3.6.0
+tiktoken==0.5.1
+tokenizers==0.21.1
+toml==0.10.2
+torch==2.7.0+cu118
+torchvision==0.22.0+cu118
+tornado==6.4.2
+tqdm==4.67.1
+transformers==4.48.0
+triton==3.3.0
+trl==0.13.0
+typing_extensions==4.13.2
+tzlocal==5.3.1
+ujson==5.1.0
+urllib3==2.4.0
+validators==0.34.0
+wandb==0.18.3
+watchdog==6.0.0
+Werkzeug==3.1.3
+xxhash==3.5.0
+yarl==1.20.0
+zipp==3.21.0
48	run_file/DynamicKV-LLM_Mini_Minimind.sh (new file)
@@ -0,0 +1,48 @@
#!/bin/bash

# Activate the conda environment
source $(conda info --base)/etc/profile.d/conda.sh
conda activate ycz_accelerate

# Set environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

# Option 1: launch with a pre-configured accelerate config file
# accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
#     --epochs 3 \
#     --batch_size 24 \
#     --learning_rate 2e-4 \
#     --dtype bfloat16 \
#     --accumulation_steps 32 \
#     --grad_clip 1.0 \
#     --log_interval 100 \
#     --save_interval 10000 \
#     --dim 1024 \
#     --n_layers 32 \
#     --max_seq_len 1024 \
#     --use_flash_attn \
#     --profile \
#     --profile_interval 10

# Option 2: configure accelerate directly via command-line arguments
# NOTE: --num_processes=4 expects four visible GPUs, but CUDA_VISIBLE_DEVICES=0 exposes
# only one device; adjust one of the two settings to match your machine.
CUDA_VISIBLE_DEVICES=0 accelerate launch \
    --multi_gpu \
    --num_processes=4 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    train_pretrain_accelerate.py \
    --epochs 3 \
    --batch_size 24 \
    --learning_rate 2e-4 \
    --dtype bfloat16 \
    --accumulation_steps 32 \
    --grad_clip 1.0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 512 \
    --n_layers 12 \
    --max_seq_len 512 \
    --use_flash_attn \
    --profile \
    --profile_interval 10
@@ -275,6 +275,8 @@ def main():
     parser.add_argument("--profile", action="store_true", default=True, help="enable profiling")
     parser.add_argument("--profile_interval", type=int, default=10, help="profiling print interval (steps)")
     parser.add_argument("--use_flash_attn", action="store_true", default=True, help="enable FlashAttention")
+    parser.add_argument("--knowlwdge_num", type=int, default=64*64, help="number of entries in the knowledge base")
+    parser.add_argument("--knowlwdge_length", type=int, default=8, help="per-entry length of the knowledge base")
     args = parser.parse_args()

     # Initialize the accelerator
@@ -304,7 +306,9 @@ def main():
         max_seq_len=args.max_seq_len,
         use_moe=args.use_moe,
         disable_db=args.disable_db,
-        flash_attn=args.use_flash_attn
+        flash_attn=args.use_flash_attn,
+        knowlwdge_num=args.knowlwdge_num,
+        knowlwdge_length=args.knowlwdge_length
     )

     # Create the save directory