This commit is contained in:
Jax922 2025-05-13 15:05:56 +08:00
parent 7ba51b8571
commit f31e17030c
28 changed files with 192 additions and 1853 deletions

View File

@ -1,128 +0,0 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

199
README.md
View File

@ -1,199 +0,0 @@
<div align="center">
![logo](./images/logo.png)
</div>
<div align="center">
![visitors](https://visitor-badge.laobi.icu/badge?page_id=jingyaogong/minimind)
[![GitHub Repo stars](https://img.shields.io/github/stars/jingyaogong/minimind?style=social)](https://github.com/jingyaogong/minimind/stargazers)
[![GitHub Code License](https://img.shields.io/github/license/jingyaogong/minimind)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/jingyaogong/minimind)](https://github.com/jingyaogong/minimind/commits/master)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/jingyaogong/minimind/pulls)
[![Collection](https://img.shields.io/badge/🤗-MiniMind%20%20Collection-blue)](https://huggingface.co/collections/jingyaogong/minimind-66caf8d999f5c7fa64f399e5)
</div>
# 📌 数据介绍
## Tokenizer
分词器将单词从自然语言通过“词典”映射到`0, 1, 36`这样的数字,可以理解为数字就代表了单词在“词典”中的页码。
可以选择自己构造词表训练一个“词典”,代码可见`./scripts/train_tokenizer.py`仅供学习参考若非必要无需再自行训练MiniMind已自带tokenizer
或者选择比较出名的开源大模型分词器,
正如同直接用新华/牛津词典的优点是token编码压缩率很好缺点是页数太多动辄数十万个词汇短语
自己训练的分词器,优点是词表长度和内容随意控制,缺点是压缩率很低(例如"hello"也许会被拆分为"h e l l o"
五个独立的token且生僻词难以覆盖。
“词典”的选择固然很重要LLM的输出本质上是SoftMax到词典N个词的多分类问题然后通过“词典”解码到自然语言。
因为MiniMind体积需要严格控制为了避免模型头重脚轻词嵌入embedding层参数在LLM占比太高所以词表长度短短益善。
<details style="color:rgb(128,128,128)">
<summary>Tokenizer介绍</summary>
第三方强大的开源模型例如Yi、qwen、chatglm、mistral、Llama3的tokenizer词表长度如下
<table>
<tr><th>Tokenizer模型</th><th>词表大小</th><th>来源</th></tr>
<tr><td>yi tokenizer</td><td>64,000</td><td>01万物中国</td></tr>
<tr><td>qwen2 tokenizer</td><td>151,643</td><td>阿里云(中国)</td></tr>
<tr><td>glm tokenizer</td><td>151,329</td><td>智谱AI中国</td></tr>
<tr><td>mistral tokenizer</td><td>32,000</td><td>Mistral AI法国</td></tr>
<tr><td>llama3 tokenizer</td><td>128,000</td><td>Meta美国</td></tr>
<tr><td>minimind tokenizer</td><td>6,400</td><td>自定义</td></tr>
</table>
> 👉2024-09-17更新为了防止过去的版本歧义&控制体积minimind所有模型均使用minimind_tokenizer分词废弃所有mistral_tokenizer版本。
```
# 一些自言自语
> 尽管minimind_tokenizer长度很小编解码效率弱于qwen2、glm等中文友好型分词器。
> 但minimind模型选择了自己训练的minimind_tokenizer作为分词器以保持整体参数轻量避免编码层和计算层占比失衡头重脚轻因为minimind的词表大小只有6400。
> 且minimind在实际测试中没有出现过生僻词汇解码失败的情况效果良好。
> 由于自定义词表压缩长度到6400使得LLM总参数量最低只有25.8M。
> 训练数据`tokenizer_train.jsonl`均来自于`匠数大模型数据集`,这部分数据相对次要,如需训练可以自由选择。
```
</details>
## Ⅱ Pretrain数据
经历了MiniMind-V1的低质量预训练数据导致模型胡言乱语的教训`2025-02-05` 之后决定不再采用大规模无监督的数据集做预训练。
进而尝试把[匠数大模型数据集](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)的中文部分提取出来,
清洗出字符`<512`长度的大约1.6GB的语料直接拼接成预训练数据 `pretrain_hq.jsonl`hq即为high
quality当然也还不算high提升数据质量无止尽
文件`pretrain_hq.jsonl` 数据格式为
```bash
{"text": "如何才能摆脱拖延症? 治愈拖延症并不容易,但以下建议可能有所帮助..."}
```
## Ⅲ SFT数据
[匠数大模型SFT数据集](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
“是一个完整、格式统一、安全的大模型训练和研究资源。
从网络上的公开数据源收集并整理了大量开源数据集,对其进行了格式统一,数据清洗,
包含10M条数据的中文数据集和包含2M条数据的英文数据集。”
以上是官方介绍下载文件后的数据总量大约在4B tokens肯定是适合作为中文大语言模型的SFT数据的。
但是官方提供的数据格式很乱全部用来sft代价太大。
我将把官方数据集进行了二次清洗,把含有符号污染和噪声的条目去除;另外依然只保留了总长度`<512`
的内容,此阶段希望通过大量对话补充预训练阶段欠缺的知识。
导出文件为`sft_512.jsonl`(~7.5GB)。
[Magpie-SFT数据集](https://www.modelscope.cn/organization/Magpie-Align)
收集了~1M条来自Qwen2/2.5的高质量对话,我将这部分数据进一步清洗,把总长度`<2048`的部分导出为`sft_2048.jsonl`(~9GB)。
长度`<1024`的部分导出为`sft_1024.jsonl`(~5.5GB)用大模型对话数据直接进行sft就属于“黑盒蒸馏”的范畴。
进一步清洗前两步sft的数据只保留中文字符占比高的内容筛选长度`<512`的对话,得到`sft_mini_512.jsonl`(~1.2GB)。
所有sft文件 `sft_X.jsonl` 数据格式均为
```text
{
"conversations": [
{"role": "user", "content": "你好"},
{"role": "assistant", "content": "你好!"},
{"role": "user", "content": "再见"},
{"role": "assistant", "content": "再见!"}
]
}
```
## Ⅳ RLHF数据
来自[Magpie-DPO数据集](https://www.modelscope.cn/datasets/Magpie-Align/MagpieLM-DPO-Data-v0.1)
大约200k条偏好数据均是英文生成自Llama3.1-70B/8B可以用于训练奖励模型优化模型回复质量使其更加符合人类偏好。
这里将数据总长度`<3000`的内容重组为`dpo.jsonl`(~0.9GB),包含`chosen``rejected`两个字段,`chosen`
为偏好的回复,`rejected`为拒绝的回复。
文件 `dpo.jsonl` 数据格式为
```text
{
"chosen": [
{"content": "Q", "role": "user"},
{"content": "good answer", "role": "assistant"}
],
"rejected": [
{"content": "Q", "role": "user"},
{"content": "bad answer", "role": "assistant"}
]
}
```
## Reason数据集
不得不说2025年2月谁能火的过DeepSeek...
也激发了我对RL引导的推理模型的浓厚兴趣目前已经用Qwen2.5复现了R1-Zero。
如果有时间+效果work但99%基模能力不足我会在之后更新MiniMind基于RL训练的推理模型而不是蒸馏模型。
时间有限,最快的低成本方案依然是直接蒸馏(黑盒方式)。
耐不住R1太火短短几天就已经存在一些R1的蒸馏数据集[R1-Llama-70B](https://www.modelscope.cn/datasets/Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B)、[R1-Distill-SFT](https://www.modelscope.cn/datasets/AI-ModelScope/R1-Distill-SFT)、
[Alpaca-Distill-R1](https://huggingface.co/datasets/shareAI/Alpaca-Distill-R1-ZH)、
[deepseek_r1_zh](https://huggingface.co/datasets/jinliuxi/deepseek_r1_zh)等等,纯中文的数据可能比较少。
最终整合它们,导出文件为`r1_mix_1024.jsonl`,数据格式和`sft_X.jsonl`一致。
## Ⅵ 更多数据集
目前已经有[HqWu-HITCS/Awesome-Chinese-LLM](https://github.com/HqWu-HITCS/Awesome-Chinese-LLM)
在收集和梳理中文LLM相关的开源模型、应用、数据集及教程等资料并持续更新这方面的最新进展。全面且专业Respect
---
## Ⅷ 数据集下载
> [!NOTE]
> 2025-02-05后开源MiniMind最终训练所用的所有数据集因此无需再自行预处理大规模数据集避免重复性的数据处理工作。
MiniMind训练数据集 ([ModelScope](https://www.modelscope.cn/datasets/gongjy/minimind-dataset/files) | [HuggingFace](https://huggingface.co/datasets/jingyaogong))
> 无需全部clone可单独下载所需的文件
将下载的数据集文件放到`./dataset/`目录下(✨为推荐的必须项)
```bash
./dataset/
├── dpo.jsonl (909MB)
├── lora_identity.jsonl (22.8KB)
├── lora_medical.jsonl (34MB)
├── pretrain_hq.jsonl (1.6GB, ✨)
├── r1_mix_1024.jsonl (340MB)
├── sft_1024.jsonl (5.6GB)
├── sft_2048.jsonl (9GB)
├── sft_512.jsonl (7.5GB)
├── sft_mini_512.jsonl (1.2GB, ✨)
└── tokenizer_train.jsonl (1GB)
```
<details style="color:rgb(128,128,128)">
<summary>注:各数据集简介</summary>
* `dpo.jsonl` --RLHF阶段数据集
* `lora_identity.jsonl` --自我认知数据集例如你是谁我是minimind...推荐用于lora训练亦可用于全参SFT勿被名字局限
* `lora_medical.jsonl` --医疗问答数据集推荐用于lora训练亦可用于全参SFT勿被名字局限
* `pretrain_hq.jsonl`✨ --预训练数据集整合自jiangshu科技
* `r1_mix_1024.jsonl` --DeepSeek-R1-1.5B蒸馏数据每条数据字符最大长度为1024因此训练时设置max_seq_len=1024
* `sft_1024.jsonl` --整合自Qwen2.5蒸馏数据是sft_2048的子集每条数据字符最大长度为1024因此训练时设置max_seq_len=1024
* `sft_2048.jsonl` --整合自Qwen2.5蒸馏数据每条数据字符最大长度为2048因此训练时设置max_seq_len=2048
* `sft_512.jsonl` --整合自匠数科技SFT数据每条数据字符最大长度为512因此训练时设置max_seq_len=512
* `sft_mini_512.jsonl`✨ --极简整合自匠数科技SFT数据+Qwen2.5蒸馏数据用于快速训练Zero模型每条数据字符最大长度为512因此训练时设置max_seq_len=512
* `tokenizer_train.jsonl` --均来自于`匠数大模型数据集`这部分数据相对次要不推荐自己重复训练tokenizer理由如上如需自己训练tokenizer可以自由选择数据集。
</details>
![dataset](./images/dataset.jpg)
<details style="color:rgb(128,128,128)">
<summary>说明 & 推荐训练方案</summary>
* MiniMind2 Series均经过共约20GB语料训练大约4B tokens即对应上面的数据组合训练结果开销💰💰💰💰💰💰💰💰效果😊😊😊😊😊😊
* 想要最快速度从0实现Zero模型推荐使用`pretrain_hq.jsonl` + `sft_mini_512.jsonl` 的数据组合,具体花销和效果可查看下文表格(开销:💰,效果:😊😊)
* 推荐具备一定算力资源或更在意效果的朋友可以考虑前者完整复现MiniMind2仅有单卡GPU或在乎短时间快速复现的朋友强烈推荐后者
* 【折中方案】亦可选择例如`sft_mini_512.jsonl``sft_1024.jsonl`中等规模数据进行自由组合训练(开销:💰💰💰,效果:😊😊😊😊)。
</details>

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 136 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 73 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 230 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 104 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 239 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 121 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 372 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 178 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 150 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 519 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 146 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 66 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 495 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 615 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.8 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 559 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 531 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1006 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 943 KiB

View File

@ -36,6 +36,9 @@ class LMConfig(PretrainedConfig):
aux_loss_alpha: float = 0.1, aux_loss_alpha: float = 0.1,
seq_aux: bool = True, seq_aux: bool = True,
norm_topk_prob: bool = True, norm_topk_prob: bool = True,
####################################################
knowlwdge_num: int = 64*64,
knowlwdge_length: int = 8,
**kwargs, **kwargs,
): ):
self.dim = dim self.dim = dim
@ -66,4 +69,7 @@ class LMConfig(PretrainedConfig):
self.aux_loss_alpha = aux_loss_alpha # 辅助损失的alpha参数 self.aux_loss_alpha = aux_loss_alpha # 辅助损失的alpha参数
self.seq_aux = seq_aux # 是否在序列级别上计算辅助损失 self.seq_aux = seq_aux # 是否在序列级别上计算辅助损失
self.norm_topk_prob = norm_topk_prob # 是否标准化top-k概率 self.norm_topk_prob = norm_topk_prob # 是否标准化top-k概率
####################################################
self.knowlwdge_num = knowlwdge_num
self.knowlwdge_length = knowlwdge_length
super().__init__(**kwargs) super().__init__(**kwargs)

View File

@ -528,15 +528,15 @@ class ExtractDB(nn.Module):
self.batch_size = None self.batch_size = None
self.dim = params.dim self.dim = params.dim
self.dim_key = self.dim // 2 self.dim_key = self.dim // 2
self.num_experts = 10 * 10 # 100专家确保是完全平方数 self.knowlwdge_num = params.knowlwdge_num # 100专家确保是完全平方数
# 将knowledge_dim设置为与head_dim相同以便在attention中直接使用 # 将knowledge_dim设置为与head_dim相同以便在attention中直接使用
self.head_dim = params.dim // params.n_heads self.head_dim = params.dim // params.n_heads
self.knowledge_dim = 8*params.dim self.knowledge_length = params.knowlwdge_length*params.dim
# 使用register_buffer代替nn.Parameter避免梯度问题 # 使用register_buffer代替nn.Parameter避免梯度问题
self.register_buffer('weight_down_embed', torch.randn(self.num_experts, self.knowledge_dim) * 0.02) self.register_buffer('weight_down_embed', torch.randn(self.knowlwdge_num, self.knowledge_length) * 0.02)
self.num_keys = int(math.sqrt(self.num_experts)) if self.num_experts > 0 else 0 self.num_keys = int(math.sqrt(self.knowlwdge_num)) if self.knowlwdge_num > 0 else 0
self.product_key_topk = min(16, self.num_keys) self.product_key_topk = min(16, self.num_keys)
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02) self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02)
self.num_experts_per_head_topk = 1 self.num_experts_per_head_topk = 1

View File

@ -1,30 +1,147 @@
accelerate==1.6.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.17
aiosignal==1.3.2
altair==5.5.0
annotated-types==0.7.0
anyio==4.9.0
async-timeout==5.0.1
attrs==25.3.0
blinker==1.9.0
cachetools==5.5.2
certifi==2025.1.31
charset-normalizer==3.4.1
click==8.1.8
contourpy==1.3.2
cycler==0.12.1
datasets==2.21.0 datasets==2.21.0
datasketch==1.6.4 datasketch==1.6.4
deepspeed==0.16.7
dill==0.3.8
distro==1.9.0
docker-pycreds==0.4.0
einops==0.8.1
exceptiongroup==1.2.2
filelock==3.18.0
Flask==3.0.3 Flask==3.0.3
Flask_Cors==4.0.0 Flask-Cors==4.0.0
fonttools==4.57.0
frozenlist==1.6.0
fsspec==2024.6.1
gitdb==4.0.12
GitPython==3.1.44
h11==0.14.0
hjson==3.1.0
httpcore==1.0.8
httpx==0.28.1
huggingface-hub==0.30.2
idna==3.10
importlib_metadata==7.2.1
itsdangerous==2.2.0
jieba==0.42.1 jieba==0.42.1
Jinja2==3.1.2
jiter==0.9.0
joblib==1.4.2
jsonlines==4.0.0 jsonlines==4.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
kiwisolver==1.4.8
markdown-it-py==3.0.0
MarkupSafe==3.0.2
marshmallow==3.22.0 marshmallow==3.22.0
matplotlib==3.10.0 matplotlib==3.10.0
mdurl==0.1.2
modelscope==1.25.0
mpmath==1.3.0
msgpack==1.1.0
multidict==6.4.3
multiprocess==0.70.16
narwhals==1.35.0
networkx==3.4.2
ngrok==1.4.0 ngrok==1.4.0
ninja==1.11.1.4
nltk==3.8 nltk==3.8
numpy==1.26.4 numpy==1.26.4
nvidia-cublas-cu11==11.11.3.6
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu11==9.1.0.70
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu11==10.9.0.58
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu11==10.3.0.86
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu11==11.7.5.86
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu11==2.21.5
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.8.93
nvidia-nvtx-cu11==11.8.86
nvidia-nvtx-cu12==12.1.105
openai==1.59.6 openai==1.59.6
packaging==23.2
pandas==1.5.3 pandas==1.5.3
peft==0.7.1 peft==0.7.1
pillow==10.4.0
platformdirs==4.3.7
propcache==0.3.1
protobuf==4.25.6
psutil==5.9.8 psutil==5.9.8
py-cpuinfo==9.0.0
pyarrow==19.0.1
pydantic==2.8.2 pydantic==2.8.2
pydantic_core==2.20.1
pydeck==0.9.1
Pygments==2.19.1
pyparsing==3.2.3
python-dateutil==2.9.0.post0
pytz==2025.2
PyYAML==6.0.2
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
rich==13.7.1 rich==13.7.1
scikit_learn==1.5.1 rpds-py==0.24.0
sentence_transformers==2.3.1 safetensors==0.5.3
scikit-learn==1.5.1
scipy==1.15.2
sentence-transformers==2.3.1
sentencepiece==0.2.0
sentry-sdk==2.26.1
setproctitle==1.3.5
simhash==2.1.2 simhash==2.1.2
tiktoken==0.5.1 six==1.17.0
transformers==4.48.0 smmap==5.0.2
jinja2==3.1.2 sniffio==1.3.1
jsonlines==4.0.0
trl==0.13.0
ujson==5.1.0
wandb==0.18.3
streamlit==1.30.0 streamlit==1.30.0
torch==2.2.2 sympy==1.13.3
torchvision==0.17.2 tenacity==8.5.0
threadpoolctl==3.6.0
tiktoken==0.5.1
tokenizers==0.21.1
toml==0.10.2
torch==2.7.0+cu118
torchvision==0.22.0+cu118
tornado==6.4.2
tqdm==4.67.1
transformers==4.48.0
triton==3.3.0
trl==0.13.0
typing_extensions==4.13.2
tzlocal==5.3.1
ujson==5.1.0
urllib3==2.4.0
validators==0.34.0
wandb==0.18.3
watchdog==6.0.0
Werkzeug==3.1.3
xxhash==3.5.0
yarl==1.20.0
zipp==3.21.0

View File

@ -0,0 +1,48 @@
#!/bin/bash
# 激活conda环境
source $(conda info --base)/etc/profile.d/conda.sh
conda activate ycz_accelerate
# 设置环境变量以帮助调试
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
# 方法1: 使用预先配置的accelerate配置文件
# accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
# --epochs 3 \
# --batch_size 24 \
# --learning_rate 2e-4 \
# --dtype bfloat16 \
# --accumulation_steps 32 \
# --grad_clip 1.0 \
# --log_interval 100 \
# --save_interval 10000 \
# --dim 1024 \
# --n_layers 32 \
# --max_seq_len 1024 \
# --use_flash_attn \
# --profile \
# --profile_interval 10
# 方法2: 使用命令行参数直接配置accelerate
CUDA_VISIBLE_DEVICES=0 accelerate launch \
--multi_gpu \
--num_processes=4 \
--mixed_precision=bf16 \
--main_process_port=29500 \
train_pretrain_accelerate.py \
--epochs 3 \
--batch_size 24 \
--learning_rate 2e-4 \
--dtype bfloat16 \
--accumulation_steps 32 \
--grad_clip 1.0 \
--log_interval 100 \
--save_interval 10000 \
--dim 512 \
--n_layers 12 \
--max_seq_len 512 \
--use_flash_attn \
--profile \
--profile_interval 10

View File

@ -275,6 +275,8 @@ def main():
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析") parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)") parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention") parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
parser.add_argument("--knowlwdge_num", type=int, default=64*64,help="知识库的数据数目")
parser.add_argument("--knowlwdge_length", type=int, default=8,help="知识库的句子长度")
args = parser.parse_args() args = parser.parse_args()
# 初始化accelerator # 初始化accelerator
@ -304,7 +306,9 @@ def main():
max_seq_len=args.max_seq_len, max_seq_len=args.max_seq_len,
use_moe=args.use_moe, use_moe=args.use_moe,
disable_db=args.disable_db, disable_db=args.disable_db,
flash_attn=args.use_flash_attn flash_attn=args.use_flash_attn,
knowlwdge_num=args.knowlwdge_num,
knowlwdge_length=args.knowlwdge_length
) )
# 创建保存目录 # 创建保存目录