DynamicKV-LLM Pretrain v1.1.0
Commit 089afd6728

.gitignore (vendored, new file, 5 lines)
@@ -0,0 +1,5 @@
/model/__pycache__
/dataset
/out
wandb/
**/*.log
LICENSE (new file, 201 lines, Apache License 2.0)
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
README_accelerate.md (new file, 126 lines)
@@ -0,0 +1,126 @@
# Distributed training with Accelerate + DeepSpeed

This document describes how to run distributed training of the MiniMind model with Accelerate and DeepSpeed.

## Environment setup

First, make sure the required dependencies are installed:

```bash
pip install accelerate deepspeed
```

## Configuration files

### 1. DeepSpeed configuration file (ds_config.json)

The DeepSpeed configuration file defines the optimizer, learning-rate scheduler, and ZeRO optimization settings. The main options are:

- **ZeRO optimization**: ZeRO stage 2 is used to reduce GPU memory usage
- **Optimizer**: AdamW
- **Mixed-precision training**: both FP16 and BF16 are supported
- **Gradient accumulation**: set to "auto" so it stays consistent with the training-script argument

### 2. Accelerate configuration file (accelerate_config.yaml)

The Accelerate configuration file defines the basic distributed-training setup, including:

- **Distributed type**: DeepSpeed
- **Mixed precision**: BF16
- **Number of processes**: 4 (adjust to the number of available GPUs)
- **DeepSpeed configuration**: points to the ds_config.json file

## Training script

The new training script `train_pretrain_accelerate.py` is adapted from the original `train_pretrain.py`. The main changes are:

1. The Accelerator replaces PyTorch's native distributed functionality
2. The torchrun-specific distributed initialization code was removed
3. The Accelerator API prepares the model, optimizer, and data loaders
4. The Accelerator API performs the backward pass and gradient clipping (a minimal sketch follows below)
5. Positional-encoding and unused-parameter issues are handled explicitly
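The sketch below illustrates this Accelerate pattern. It is a simplified illustration only, not the actual contents of `train_pretrain_accelerate.py`; `model`, `optimizer`, `train_loader`, `compute_loss`, and `args` are placeholders assumed to be defined elsewhere.

```python
# Minimal sketch of the Accelerate training pattern (illustrative only).
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=args.accumulation_steps)

# Replaces the manual DDP/torchrun setup: wraps the model, optimizer, and dataloader.
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

for X, Y, loss_mask in train_loader:
    with accelerator.accumulate(model):
        loss = compute_loss(model(X), Y, loss_mask)  # placeholder loss function
        accelerator.backward(loss)                   # instead of loss.backward()
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()
        optimizer.zero_grad()
```

Because `accelerator.prepare` handles device placement and data sharding, no manual process-group setup is required in the script itself.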
## Launching training

There are two ways to launch training:

### Method 1: use the prepared accelerate configuration file

```bash
accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
    --epochs 3 \
    --batch_size 24 \
    --learning_rate 2e-4 \
    --dtype bfloat16 \
    --accumulation_steps 32 \
    --grad_clip 1.0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 1024 \
    --n_layers 32 \
    --max_seq_len 1024 \
    --use_flash_attn \
    --profile \
    --profile_interval 10
```

### Method 2: configure accelerate directly on the command line

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
    --multi_gpu \
    --num_processes=4 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    --deepspeed_config_file ds_config.json \
    train_pretrain_accelerate.py \
    --epochs 3 \
    --batch_size 24 \
    --learning_rate 2e-4 \
    --dtype bfloat16 \
    --accumulation_steps 32 \
    --grad_clip 1.0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 1024 \
    --n_layers 32 \
    --max_seq_len 1024 \
    --use_flash_attn \
    --profile \
    --profile_interval 10
```

You can also use the provided launch script directly:

```bash
bash run_accelerate.sh
```
## How the Accelerate and DeepSpeed configurations relate

1. **Accelerate** is a high-level API that simplifies setting up and launching distributed training; it can work with several distributed backends (DeepSpeed, FSDP, and others).

2. **DeepSpeed** is an optimization library focused on memory efficiency and throughput for large-scale model training; it provides features such as ZeRO optimization.

3. **How the configurations fit together**:
   - The Accelerate configuration file (YAML) selects the distributed backend and the basic distributed settings
   - The DeepSpeed configuration file (JSON) defines the DeepSpeed-specific optimization parameters
   - Accelerate references the DeepSpeed configuration through the `deepspeed_config_file` field

## Notes

1. **Positional encoding**:
   - In the model, `pos_cis` is a complex-valued tensor that needs special handling in distributed training
   - The new training script handles this through the Accelerator API, so `_ddp_params_and_buffers_to_ignore` is no longer needed

2. **Unused parameters**:
   - The original code used `find_unused_parameters=True` to deal with unused parameters
   - The new training script uses the Accelerator API directly, which handles this automatically

3. **Mixed-precision training**:
   - `fp16` and `bf16` are set to `"auto"` in the DeepSpeed configuration file
   - The precision actually used is determined by Accelerate's `--mixed_precision` setting

4. **Gradient accumulation**:
   - `gradient_accumulation_steps` is set to `"auto"` in the DeepSpeed configuration file
   - The actual number of accumulation steps is determined by the training script's `--accumulation_steps` argument (see the example below)
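As a concrete illustration of notes 3 and 4 (a hypothetical sanity check, not a file in this repository): with the launch commands above, the `"auto"` fields resolve to the script arguments, and the effective global batch size is the per-GPU micro batch times the accumulation steps times the number of processes.

```python
# Hypothetical sanity check (not part of the repository): effective DeepSpeed
# batch settings implied by the launch commands above.
batch_size = 24           # --batch_size, per-GPU micro batch -> train_micro_batch_size_per_gpu
accumulation_steps = 32   # --accumulation_steps -> gradient_accumulation_steps
num_processes = 4         # --num_processes (number of GPUs)

# DeepSpeed requires train_batch_size == micro_batch * accumulation_steps * world_size
train_batch_size = batch_size * accumulation_steps * num_processes
print(train_batch_size)  # 3072 samples per optimizer step
```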
ReadMe.md (new file, 22 lines)
@@ -0,0 +1,22 @@
## Environment setup
1. Create a conda environment
```bash
conda create -n accelerate python=3.10
conda activate accelerate
```

2. Install the torch, torchvision, and torchaudio builds that match your system's CUDA version

3. Install accelerate and deepspeed against the torch and torchvision versions in the environment

4. Install the remaining packages
```bash
pip install -r requirements.txt
```

## Modifying the model
1. In most cases only the files in the `model` folder need to be changed

## Running
1. On an RTX 4090 or 4070 Ti, run `bash run_file/DynamicKV-LLM_Mini_Minimind.sh`
2. On 4x A800 GPUs, run `bash run_file/DynamicKV-LLM_Small_Minimind.sh`
accelerate_config.yaml (new file, 17 lines)
@@ -0,0 +1,17 @@
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: ds_config.json
  zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
ds_config.json (new file, 49 lines)
@@ -0,0 +1,49 @@
{
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 5e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true
    },
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },
    "steps_per_print": 100,
    "wall_clock_breakdown": false
}
eval_model.py (new file, 181 lines)
@@ -0,0 +1,181 @@
|
||||
import argparse
|
||||
import random
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
import warnings
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from model.model import MiniMindLM
|
||||
from model.LMConfig import LMConfig
|
||||
from model.model_lora import *
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def init_model(args):
|
||||
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
|
||||
if args.load == 0:
|
||||
moe_path = '_moe' if args.use_moe else ''
|
||||
modes = {0: 'pretrain', 1: 'full_sft', 2: 'rlhf', 3: 'reason', 4: 'grpo'}
|
||||
ckp = f'./{args.out_dir}/{modes[args.model_mode]}_{args.dim}{moe_path}.pth'
|
||||
|
||||
model = MiniMindLM(LMConfig(
|
||||
dim=args.dim,
|
||||
n_layers=args.n_layers,
|
||||
max_seq_len=args.max_seq_len,
|
||||
use_moe=args.use_moe
|
||||
))
|
||||
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
model.load_state_dict({k: v for k, v in state_dict.items() if 'mask' not in k}, strict=True)
|
||||
|
||||
if args.lora_name != 'None':
|
||||
apply_lora(model)
|
||||
load_lora(model, f'./{args.out_dir}/lora/{args.lora_name}_{args.dim}.pth')
|
||||
else:
|
||||
transformers_model_path = './MiniMind2'
|
||||
tokenizer = AutoTokenizer.from_pretrained(transformers_model_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(transformers_model_path, trust_remote_code=True)
|
||||
print(f'MiniMind模型参数量: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M(illion)')
|
||||
return model.eval().to(args.device), tokenizer
|
||||
|
||||
|
||||
def get_prompt_datas(args):
|
||||
if args.model_mode == 0:
|
||||
# pretrain模型的接龙能力(无法对话)
|
||||
prompt_datas = [
|
||||
'马克思主义基本原理',
|
||||
'人类大脑的主要功能',
|
||||
'万有引力原理是',
|
||||
'世界上最高的山峰是',
|
||||
'二氧化碳在空气中',
|
||||
'地球上最大的动物有',
|
||||
'杭州市的美食有'
|
||||
]
|
||||
else:
|
||||
if args.lora_name == 'None':
|
||||
# 通用对话问题
|
||||
prompt_datas = [
|
||||
'请介绍一下自己。',
|
||||
'你更擅长哪一个学科?',
|
||||
'鲁迅的《狂人日记》是如何批判封建礼教的?',
|
||||
'我咳嗽已经持续了两周,需要去医院检查吗?',
|
||||
'详细的介绍光速的物理概念。',
|
||||
'推荐一些杭州的特色美食吧。',
|
||||
'请为我讲解“大语言模型”这个概念。',
|
||||
'如何理解ChatGPT?',
|
||||
'Introduce the history of the United States, please.'
|
||||
]
|
||||
else:
|
||||
# 特定领域问题
|
||||
lora_prompt_datas = {
|
||||
'lora_identity': [
|
||||
"你是ChatGPT吧。",
|
||||
"你叫什么名字?",
|
||||
"你和openai是什么关系?"
|
||||
],
|
||||
'lora_medical': [
|
||||
'我最近经常感到头晕,可能是什么原因?',
|
||||
'我咳嗽已经持续了两周,需要去医院检查吗?',
|
||||
'服用抗生素时需要注意哪些事项?',
|
||||
'体检报告中显示胆固醇偏高,我该怎么办?',
|
||||
'孕妇在饮食上需要注意什么?',
|
||||
'老年人如何预防骨质疏松?',
|
||||
'我最近总是感到焦虑,应该怎么缓解?',
|
||||
'如果有人突然晕倒,应该如何急救?'
|
||||
],
|
||||
}
|
||||
prompt_datas = lora_prompt_datas[args.lora_name]
|
||||
|
||||
return prompt_datas
|
||||
|
||||
|
||||
# 设置可复现的随机种子
|
||||
def setup_seed(seed):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Chat with MiniMind")
|
||||
parser.add_argument('--lora_name', default='None', type=str)
|
||||
parser.add_argument('--out_dir', default='out', type=str)
|
||||
parser.add_argument('--temperature', default=0.85, type=float)
|
||||
parser.add_argument('--top_p', default=0.85, type=float)
|
||||
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str)
|
||||
# 此处max_seq_len(最大允许输入长度)并不意味模型具有对应的长文本的性能,仅防止QA出现被截断的问题
|
||||
# MiniMind2-moe (145M):(dim=640, n_layers=8, use_moe=True)
|
||||
# MiniMind2-Small (26M):(dim=512, n_layers=8)
|
||||
# MiniMind2 (104M):(dim=768, n_layers=16)
|
||||
parser.add_argument('--dim', default=512, type=int)
|
||||
parser.add_argument('--n_layers', default=8, type=int)
|
||||
parser.add_argument('--max_seq_len', default=8192, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
# 携带历史对话上下文条数
|
||||
# history_cnt需要设为偶数,即【用户问题, 模型回答】为1组;设置为0时,即当前query不携带历史上文
|
||||
# 模型未经过外推微调时,在更长的上下文的chat_template时难免出现性能的明显退化,因此需要注意此处设置
|
||||
parser.add_argument('--history_cnt', default=0, type=int)
|
||||
parser.add_argument('--stream', default=True, type=bool)
|
||||
parser.add_argument('--load', default=0, type=int, help="0: 原生torch权重,1: transformers加载")
|
||||
parser.add_argument('--model_mode', default=1, type=int,
|
||||
help="0: 预训练模型,1: SFT-Chat模型,2: RLHF-Chat模型,3: Reason模型,4: RLAIF-Chat模型")
|
||||
args = parser.parse_args()
|
||||
|
||||
model, tokenizer = init_model(args)
|
||||
|
||||
prompts = get_prompt_datas(args)
|
||||
test_mode = int(input('[0] 自动测试\n[1] 手动输入\n'))
|
||||
messages = []
|
||||
for idx, prompt in enumerate(prompts if test_mode == 0 else iter(lambda: input('👶: '), '')):
|
||||
setup_seed(random.randint(0, 2048))
|
||||
# setup_seed(2025) # 如需固定每次输出则换成【固定】的随机种子
|
||||
if test_mode == 0: print(f'👶: {prompt}')
|
||||
|
||||
messages = messages[-args.history_cnt:] if args.history_cnt else []
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
new_prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)[-args.max_seq_len - 1:] if args.model_mode != 0 else (tokenizer.bos_token + prompt)
|
||||
|
||||
answer = new_prompt
|
||||
with torch.no_grad():
|
||||
x = torch.tensor(tokenizer(new_prompt)['input_ids'], device=args.device).unsqueeze(0)
|
||||
outputs = model.generate(
|
||||
x,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
max_new_tokens=args.max_seq_len,
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
stream=args.stream,
|
||||
pad_token_id=tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
print('🤖️: ', end='')
|
||||
try:
|
||||
if not args.stream:
|
||||
print(tokenizer.decode(outputs.squeeze()[x.shape[1]:].tolist(), skip_special_tokens=True), end='')
|
||||
else:
|
||||
history_idx = 0
|
||||
for y in outputs:
|
||||
answer = tokenizer.decode(y[0].tolist(), skip_special_tokens=True)
|
||||
if (answer and answer[-1] == '\ufffd') or not answer:
|
||||
continue
|
||||
print(answer[history_idx:], end='', flush=True)
|
||||
history_idx = len(answer)
|
||||
except StopIteration:
|
||||
print("No answer")
|
||||
print('\n')
|
||||
|
||||
messages.append({"role": "assistant", "content": answer})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
model/LMConfig.py (new file, 75 lines)
@@ -0,0 +1,75 @@
|
||||
from transformers import PretrainedConfig
|
||||
from typing import List
|
||||
|
||||
|
||||
class LMConfig(PretrainedConfig):
|
||||
model_type = "minimind"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 512,
|
||||
n_layers: int = 8,
|
||||
n_heads: int = 32,
|
||||
n_kv_heads: int = 8,
|
||||
vocab_size: int = 6400,
|
||||
hidden_dim: int = None,
|
||||
multiple_of: int = 64,
|
||||
norm_eps: float = 1e-5,
|
||||
max_seq_len: int = 8192,
|
||||
rope_theta: int = 1e6,
|
||||
dropout: float = 0.0,
|
||||
flash_attn: bool = True,
|
||||
####################################################
|
||||
# DB related configurations
|
||||
####################################################
|
||||
disable_db: bool = False, # 特殊模式:禁用数据库功能
|
||||
####################################################
|
||||
# Here are the specific configurations of MOE
|
||||
# When use_moe is false, the following is invalid
|
||||
####################################################
|
||||
use_moe: bool = False,
|
||||
####################################################
|
||||
num_experts_per_tok: int = 2,
|
||||
n_routed_experts: int = 4,
|
||||
n_shared_experts: bool = True,
|
||||
scoring_func: str = 'softmax',
|
||||
aux_loss_alpha: float = 0.1,
|
||||
seq_aux: bool = True,
|
||||
norm_topk_prob: bool = True,
|
||||
####################################################
|
||||
knowlwdge_num: int = 64*64,
|
||||
knowlwdge_length: int = 8,
|
||||
**kwargs,
|
||||
):
|
||||
self.dim = dim
|
||||
self.n_layers = n_layers
|
||||
self.n_heads = n_heads
|
||||
self.n_kv_heads = n_kv_heads
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_dim = hidden_dim
|
||||
self.multiple_of = multiple_of
|
||||
self.norm_eps = norm_eps
|
||||
self.max_seq_len = max_seq_len
|
||||
self.rope_theta = rope_theta
|
||||
self.dropout = dropout
|
||||
self.flash_attn = flash_attn
|
||||
####################################################
|
||||
# DB related configurations
|
||||
####################################################
|
||||
self.disable_db = disable_db # 设置是否禁用数据库
|
||||
####################################################
|
||||
# Here are the specific configurations of MOE
|
||||
# When use_moe is false, the following is invalid
|
||||
####################################################
|
||||
self.use_moe = use_moe
|
||||
self.num_experts_per_tok = num_experts_per_tok # 每个token选择的专家数量
|
||||
self.n_routed_experts = n_routed_experts # 总的专家数量
|
||||
self.n_shared_experts = n_shared_experts # 共享专家
|
||||
self.scoring_func = scoring_func # 评分函数,默认为'softmax'
|
||||
self.aux_loss_alpha = aux_loss_alpha # 辅助损失的alpha参数
|
||||
self.seq_aux = seq_aux # 是否在序列级别上计算辅助损失
|
||||
self.norm_topk_prob = norm_topk_prob # 是否标准化top-k概率
|
||||
####################################################
|
||||
self.knowlwdge_num = knowlwdge_num
|
||||
self.knowlwdge_length = knowlwdge_length
|
||||
super().__init__(**kwargs)
|
model/dataset.py (new file, 245 lines)
@@ -0,0 +1,245 @@
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch
|
||||
from sklearn.model_selection import train_test_split
|
||||
import os
|
||||
import ast
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
|
||||
class PretrainDataset(Dataset):
|
||||
def __init__(self, data_path, tokenizer, max_length=512):
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.samples = self.load_data(data_path)
|
||||
|
||||
def load_data(self, path):
|
||||
samples = []
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
data = json.loads(line.strip())
|
||||
samples.append(data)
|
||||
return samples
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.samples[index]
|
||||
|
||||
# 构建输入文本
|
||||
text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
|
||||
encoding = self.tokenizer(
|
||||
text,
|
||||
max_length=self.max_length,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
return_tensors='pt'
|
||||
)
|
||||
input_ids = encoding.input_ids.squeeze()
|
||||
loss_mask = (input_ids != self.tokenizer.pad_token_id)
|
||||
|
||||
X = torch.tensor(input_ids[:-1], dtype=torch.long)
|
||||
Y = torch.tensor(input_ids[1:], dtype=torch.long)
|
||||
loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)
|
||||
return X, Y, loss_mask
|
||||
|
||||
|
||||
class SFTDataset(Dataset):
|
||||
def __init__(self, jsonl_path, tokenizer, max_length=1024):
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.samples = self.load_data(jsonl_path)
|
||||
self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
|
||||
self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def load_data(self, path):
|
||||
samples = []
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
data = json.loads(line.strip())
|
||||
samples.append(data)
|
||||
return samples
|
||||
|
||||
def _create_chat_prompt(self, conversations):
|
||||
"""构建符合ChatML格式的对话"""
|
||||
messages = []
|
||||
for i, turn in enumerate(conversations):
|
||||
role = 'user' if i % 2 == 0 else 'assistant'
|
||||
messages.append({"role": role, "content": turn['content']})
|
||||
return self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=False
|
||||
)
|
||||
|
||||
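# Only assistant replies contribute to the loss: the scan below finds each
# '<s>assistant' ... '</s>' span in the token ids and sets loss_mask to 1 for the
# response tokens, while prompt and padding positions stay 0.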
def _generate_loss_mask(self, input_ids):
|
||||
loss_mask = [0] * len(input_ids)
|
||||
i = 0
|
||||
while i < len(input_ids):
|
||||
if input_ids[i:i + len(self.bos_id)] == self.bos_id:
|
||||
start = i + len(self.bos_id)
|
||||
end = start
|
||||
while end < len(input_ids):
|
||||
if input_ids[end:end + len(self.eos_id)] == self.eos_id:
|
||||
break
|
||||
end += 1
|
||||
for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
|
||||
loss_mask[j] = 1
|
||||
i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
|
||||
else:
|
||||
i += 1
|
||||
return loss_mask
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.samples[index]
|
||||
# 构建对话提示
|
||||
prompt = self._create_chat_prompt(sample['conversations'])
|
||||
input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
|
||||
input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
|
||||
|
||||
# 生成动态损失掩码
|
||||
loss_mask = self._generate_loss_mask(input_ids)
|
||||
|
||||
# 构建训练数据
|
||||
X = torch.tensor(input_ids[:-1], dtype=torch.long)
|
||||
Y = torch.tensor(input_ids[1:], dtype=torch.long)
|
||||
loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long) # 对齐预测位置
|
||||
|
||||
return X, Y, loss_mask
|
||||
|
||||
|
||||
class DPODataset(Dataset):
|
||||
def __init__(self, file_path, tokenizer, max_length=4096):
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||
self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
|
||||
self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
self.data = []
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
obj = json.loads(line)
|
||||
self.data.append(obj)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
item = self.data[index]
|
||||
chosen = item['chosen'] # 是一个 list,里面包含若干 {role, content}
|
||||
rejected = item['rejected'] # 同上
|
||||
chosen_prompt = self.tokenizer.apply_chat_template(
|
||||
chosen, tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
|
||||
rejected_prompt = self.tokenizer.apply_chat_template(
|
||||
rejected, tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
chosen_encoding = self.tokenizer(
|
||||
chosen_prompt, truncation=True, max_length=self.max_length, padding='max_length'
|
||||
)
|
||||
rejected_encoding = self.tokenizer(
|
||||
rejected_prompt, truncation=True, max_length=self.max_length, padding='max_length'
|
||||
)
|
||||
|
||||
chosen_input_ids = chosen_encoding['input_ids']
|
||||
chosen_loss_mask = self._generate_loss_mask(chosen_input_ids)
|
||||
|
||||
rejected_input_ids = rejected_encoding['input_ids']
|
||||
rejected_loss_mask = self._generate_loss_mask(rejected_input_ids)
|
||||
x_chosen = torch.tensor(chosen_input_ids[:-1], dtype=torch.long)
|
||||
y_chosen = torch.tensor(chosen_input_ids[1:], dtype=torch.long)
|
||||
mask_chosen = torch.tensor(chosen_loss_mask[1:], dtype=torch.long)
|
||||
x_rejected = torch.tensor(rejected_input_ids[:-1], dtype=torch.long)
|
||||
y_rejected = torch.tensor(rejected_input_ids[1:], dtype=torch.long)
|
||||
mask_rejected = torch.tensor(rejected_loss_mask[1:], dtype=torch.long)
|
||||
|
||||
return {
|
||||
'x_chosen': x_chosen,
|
||||
'y_chosen': y_chosen,
|
||||
'mask_chosen': mask_chosen,
|
||||
'x_rejected': x_rejected,
|
||||
'y_rejected': y_rejected,
|
||||
'mask_rejected': mask_rejected
|
||||
}
|
||||
|
||||
def _generate_loss_mask(self, input_ids):
|
||||
loss_mask = [0] * len(input_ids)
|
||||
i = 0
|
||||
while i < len(input_ids):
|
||||
if input_ids[i:i + len(self.bos_id)] == self.bos_id:
|
||||
start = i + len(self.bos_id)
|
||||
end = start
|
||||
while end < len(input_ids):
|
||||
if input_ids[end:end + len(self.eos_id)] == self.eos_id:
|
||||
break
|
||||
end += 1
|
||||
for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
|
||||
loss_mask[j] = 1
|
||||
i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
|
||||
else:
|
||||
i += 1
|
||||
return loss_mask
|
||||
|
||||
|
||||
class RLAIFDataset(Dataset):
|
||||
def __init__(self, jsonl_path, tokenizer, max_length=1024):
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.samples = self.load_data(jsonl_path)
|
||||
self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
|
||||
self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def load_data(self, path):
|
||||
samples = []
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
data = json.loads(line.strip())
|
||||
samples.append(data)
|
||||
return samples
|
||||
|
||||
def _create_chat_prompt(self, conversations):
|
||||
"""构建符合ChatML格式的对话"""
|
||||
messages = []
|
||||
answer = ''
|
||||
for i, turn in enumerate(conversations):
|
||||
role = 'user' if i % 2 == 0 else 'assistant'
|
||||
messages.append({"role": role, "content": turn['content']})
|
||||
answer = turn['content']
|
||||
return self.tokenizer.apply_chat_template(
|
||||
messages[:-1],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
), answer
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.samples[index]
|
||||
# 构建对话提示
|
||||
prompt, answer = self._create_chat_prompt(sample['conversations'])
|
||||
|
||||
return {
|
||||
'prompt': prompt,
|
||||
'answer': answer
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
model/minimind_tokenizer/merges.txt (new file, 6142 lines; diff suppressed because it is too large)
model/minimind_tokenizer/tokenizer.json (new file, 12603 lines; diff suppressed because it is too large)
model/minimind_tokenizer/tokenizer_config.json (new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
{
|
||||
"add_bos_token": false,
|
||||
"add_eos_token": false,
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"1": {
|
||||
"content": "<s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"2": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [],
|
||||
"bos_token": "<s>",
|
||||
"clean_up_tokenization_spaces": false,
|
||||
"eos_token": "</s>",
|
||||
"legacy": true,
|
||||
"model_max_length": 32768,
|
||||
"pad_token": "<unk>",
|
||||
"sp_model_kwargs": {},
|
||||
"spaces_between_special_tokens": false,
|
||||
"tokenizer_class": "PreTrainedTokenizerFast",
|
||||
"unk_token": "<unk>",
|
||||
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<s>system\\n' + system_message + '</s>\\n' }}{% else %}{{ '<s>system\\n你是 MiniMind,是一个有用的人工智能助手。</s>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
|
||||
}
|
model/minimind_tokenizer/vocab.json (new file, 1 line; diff suppressed because one or more lines are too long)
model/model.py (new file, 755 lines)
@@ -0,0 +1,755 @@
|
||||
import math
|
||||
import struct
|
||||
import inspect
|
||||
import time
|
||||
|
||||
from .LMConfig import LMConfig
|
||||
from typing import Any, Optional, Tuple, List, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
# RMSNorm 类定义了一个用于归一化输入张量的模块。
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
return self.weight * self._norm(x.float()).type_as(x)
|
||||
|
||||
# precompute_pos_cis 函数用于预计算位置编码(复数版本)。
|
||||
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device) # type: ignore
|
||||
freqs = torch.outer(t, freqs).float() # type: ignore
|
||||
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
return pos_cis
|
||||
|
||||
# apply_rotary_emb 函数用于应用旋转位置编码(复数版本)。
|
||||
def apply_rotary_emb(xq, xk, pos_cis):
|
||||
def unite_shape(pos_cis, x):
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
assert pos_cis.shape == (x.shape[1], x.shape[-1])
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return pos_cis.view(*shape)
|
||||
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
pos_cis = unite_shape(pos_cis, xq_)
|
||||
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
# precompute_pos_cis_real 函数用于预计算位置编码(实数版本)。
|
||||
def precompute_pos_cis_real(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
|
||||
"""使用实数张量实现位置编码,避免使用复数张量
|
||||
|
||||
这个函数与precompute_pos_cis完全等价,但使用实数张量而非复数张量。
|
||||
原始函数生成形状为[seq_len, dim//2]的复数张量,其中实部全为1,虚部为旋转角度。
|
||||
这个函数生成形状为[seq_len, dim]的实数张量,其中偶数索引是cos(角度),奇数索引是sin(角度)。
|
||||
"""
|
||||
# 确保dim是偶数
|
||||
if dim % 2 != 0:
|
||||
raise ValueError(f"维度必须是偶数,但得到了 {dim}")
|
||||
|
||||
# 复制原始函数的频率计算逻辑
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device)
|
||||
freqs = torch.outer(t, freqs).float()
|
||||
|
||||
# 计算cos和sin值
|
||||
# 在复数版本中,pos_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
# 等价于 cos(freqs) + i*sin(freqs)
|
||||
cos = torch.cos(freqs)
|
||||
sin = torch.sin(freqs)
|
||||
|
||||
# 创建实数张量,交错排列cos和sin
|
||||
pos_emb = torch.zeros((end, dim), device=freqs.device)
|
||||
pos_emb[:, 0::2] = cos # 偶数索引放cos
|
||||
pos_emb[:, 1::2] = sin # 奇数索引放sin
|
||||
|
||||
return pos_emb
|
||||
|
||||
# apply_rotary_emb_real 函数用于应用旋转位置编码(实数版本)。
|
||||
def apply_rotary_emb_real(xq, xk, pos_emb):
|
||||
"""使用实数张量实现旋转位置编码,避免使用复数张量
|
||||
|
||||
这个函数与apply_rotary_emb完全等价,但使用实数张量而非复数张量。
|
||||
原始函数将输入张量转换为复数形式,与位置编码相乘,然后再转回实数形式。
|
||||
这个函数直接使用实数运算实现相同的旋转操作。
|
||||
"""
|
||||
# 获取形状信息
|
||||
bsz, seq_len, n_heads, head_dim = xq.shape
|
||||
|
||||
# 确保pos_emb形状正确
|
||||
assert pos_emb.shape[0] >= seq_len, f"位置编码长度 {pos_emb.shape[0]} 小于序列长度 {seq_len}"
|
||||
assert pos_emb.shape[1] == head_dim, f"位置编码维度 {pos_emb.shape[1]} 与头维度 {head_dim} 不匹配"
|
||||
|
||||
# 截取需要的位置编码长度
|
||||
pos_emb = pos_emb[:seq_len]
|
||||
|
||||
# 将pos_emb调整为广播形状 [1, seq_len, 1, head_dim]
|
||||
pos_emb = pos_emb.unsqueeze(0).unsqueeze(2)
|
||||
|
||||
# 将head_dim分成两半
|
||||
half_head_dim = head_dim // 2
|
||||
|
||||
# 提取cos和sin值(偶数索引是cos,奇数索引是sin)
|
||||
cos = pos_emb[..., 0::2]
|
||||
sin = pos_emb[..., 1::2]
|
||||
|
||||
# 将xq和xk重新排列,以便进行旋转操作
|
||||
# 原始复数版本中,xq和xk被重塑为复数张量,其中实部和虚部交错排列
|
||||
# 在实数版本中,我们需要将偶数索引和奇数索引分开处理
|
||||
|
||||
# 分离偶数和奇数索引
|
||||
xq_even = xq[..., 0::2] # 偶数索引,对应复数的实部
|
||||
xq_odd = xq[..., 1::2] # 奇数索引,对应复数的虚部
|
||||
xk_even = xk[..., 0::2]
|
||||
xk_odd = xk[..., 1::2]
|
||||
|
||||
# 应用旋转(等价于复数乘法)
|
||||
# (a + bi)(cos + sin*i) = (a*cos - b*sin) + (a*sin + b*cos)i
|
||||
# 其中a是偶数索引,b是奇数索引
|
||||
xq_out_even = xq_even * cos - xq_odd * sin # 新的偶数索引(实部)
|
||||
xq_out_odd = xq_even * sin + xq_odd * cos # 新的奇数索引(虚部)
|
||||
xk_out_even = xk_even * cos - xk_odd * sin
|
||||
xk_out_odd = xk_even * sin + xk_odd * cos
|
||||
|
||||
# 重新组合偶数和奇数索引
|
||||
xq_out = torch.zeros_like(xq)
|
||||
xk_out = torch.zeros_like(xk)
|
||||
xq_out[..., 0::2] = xq_out_even
|
||||
xq_out[..., 1::2] = xq_out_odd
|
||||
xk_out[..., 0::2] = xk_out_even
|
||||
xk_out[..., 1::2] = xk_out_odd
|
||||
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
# repeat_kv 函数用于重复键值对。
|
||||
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
|
||||
bs, slen, n_kv_heads, head_dim = x.shape
|
||||
if n_rep == 1:
|
||||
return x
|
||||
return (
|
||||
x[:, :, :, None, :]
|
||||
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
|
||||
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
||||
)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: LMConfig):
|
||||
super().__init__()
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
assert args.n_heads % self.n_kv_heads == 0
|
||||
self.n_local_heads = args.n_heads
|
||||
self.n_local_kv_heads = self.n_kv_heads
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
|
||||
self.attn_dropout = nn.Dropout(args.dropout)
|
||||
self.resid_dropout = nn.Dropout(args.dropout)
|
||||
self.dropout = args.dropout
|
||||
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
|
||||
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
|
||||
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
self.register_buffer("mask", mask, persistent=False)
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
pos_cis: torch.Tensor,
|
||||
db_value=None):
|
||||
bsz, seq_len, _ = x.shape #bsz: 批量大小, seq_len: 序列长度, _: 隐藏维度
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) #将输入张量x分别通过线性层wq, wk, wv进行变换,得到查询、键和值。
|
||||
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim) #将变换后的张量xq重塑为形状为(bsz, seq_len, n_local_heads, head_dim)的形状。
|
||||
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xk重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。
|
||||
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xv重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。
|
||||
|
||||
# 应用旋转位置编码(使用实数版本)
|
||||
xq, xk = apply_rotary_emb_real(xq, xk, pos_cis)
|
||||
# kv_cache实现 REMOVED
|
||||
# if past_key_value is not None:
|
||||
# xk = torch.cat([past_key_value[0], xk], dim=1)
|
||||
# xv = torch.cat([past_key_value[1], xv], dim=1)
|
||||
# past_kv = (xk, xv) if use_cache else None
|
||||
|
||||
# 重复键值对
|
||||
xq, xk, xv = (
|
||||
xq.transpose(1, 2),
|
||||
repeat_kv(xk, self.n_rep).transpose(1, 2),
|
||||
repeat_kv(xv, self.n_rep).transpose(1, 2)
|
||||
)
|
||||
|
||||
# 如果提供了db_value,根据头的数量调整它的形状并与xv合并
|
||||
if db_value is not None:
|
||||
# 确保db_value的形状与xv兼容,假设db_value形状为[B, N, H, D]
|
||||
if db_value.ndim == 4: # [B, N, H, D]
|
||||
db_value = db_value.transpose(1, 2) # -> [B, H, N, D]
|
||||
|
||||
# 检查是否需要调整D维度
|
||||
if db_value.shape[-1] != xv.shape[-1]:
|
||||
# 如果db_value的维度与xv不同,可以添加一个投影层
|
||||
# 或者在这里使用简单的调整方法
|
||||
# 这里我们简单地通过均值池化或重复来调整维度
|
||||
if db_value.shape[-1] > xv.shape[-1]:
|
||||
# 降维
|
||||
factor = db_value.shape[-1] // xv.shape[-1]
|
||||
db_value = db_value.view(bsz, self.n_local_heads, seq_len, factor, xv.shape[-1])
|
||||
db_value = db_value.mean(dim=3)
|
||||
else:
|
||||
# 升维
|
||||
factor = xv.shape[-1] // db_value.shape[-1]
|
||||
db_value = db_value.unsqueeze(-1).repeat(1, 1, 1, 1, factor)
|
||||
db_value = db_value.view(bsz, self.n_local_heads, seq_len, xv.shape[-1])
|
||||
|
||||
# 将db_value与xv相加或融合
|
||||
# 这里我们简单地将它们相加,但你也可以使用其他融合方法
|
||||
xv = xv + db_value
|
||||
|
||||
# 使用Flash Attention
|
||||
if self.flash and seq_len != 1:
|
||||
dropout_p = self.dropout if self.training else 0.0
|
||||
output = F.scaled_dot_product_attention(
|
||||
xq, xk, xv,
|
||||
attn_mask=None,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=True
|
||||
)
|
||||
else:
|
||||
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
||||
scores += self.mask[:, :, :seq_len, :seq_len]
|
||||
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
||||
scores = self.attn_dropout(scores)
|
||||
output = scores @ xv
|
||||
|
||||
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
|
||||
output = self.resid_dropout(self.wo(output))
|
||||
return output
|
||||
|
||||
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_heads = 8
|
||||
self.head_dim = self.config.dim // self.num_heads
|
||||
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
|
||||
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
|
||||
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
|
||||
|
||||
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
|
||||
|
||||
def forward(self, x, db, context_mask=None, pos_emb=None):
|
||||
batch_size = x.size(0)
|
||||
|
||||
# 分离多头
|
||||
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if pos_emb is not None:
|
||||
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q = q + pos_emb
|
||||
k = k + pos_emb
|
||||
v = v + pos_emb
|
||||
|
||||
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
||||
|
||||
if context_mask is not None:
|
||||
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
|
||||
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
|
||||
|
||||
attn_weights = F.softmax(attn_scores, dim=-1)
|
||||
|
||||
context = torch.matmul(attn_weights, v)
|
||||
|
||||
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
|
||||
|
||||
context = self.to_out(context)
|
||||
|
||||
return context
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, config: LMConfig):
|
||||
super().__init__()
|
||||
if config.hidden_dim is None:
|
||||
hidden_dim = 4 * config.dim
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
|
||||
self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
|
||||
self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
|
||||
self.dropout = nn.Dropout(config.dropout)
|
||||
|
||||
def forward(self, x):
|
||||
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
|
||||
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, config: LMConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
|
||||
self.scoring_func = config.scoring_func
|
||||
self.alpha = config.aux_loss_alpha
|
||||
self.seq_aux = config.seq_aux
|
||||
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
self.gating_dim = config.dim
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
import torch.nn.init as init
|
||||
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
bsz, seq_len, h = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, h)
|
||||
logits = F.linear(hidden_states, self.weight, None)
|
||||
if self.scoring_func == 'softmax':
|
||||
scores = logits.softmax(dim=-1)
|
||||
else:
|
||||
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
|
||||
|
||||
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
|
||||
|
||||
if self.top_k > 1 and self.norm_topk_prob:
|
||||
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
||||
topk_weight = topk_weight / denominator
|
||||
|
||||
if self.training and self.alpha > 0.0:
|
||||
scores_for_aux = scores
|
||||
aux_topk = self.top_k
|
||||
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
|
||||
if self.seq_aux:
|
||||
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
|
||||
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
|
||||
ce.scatter_add_(1, topk_idx_for_aux_loss,
|
||||
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
|
||||
seq_len * aux_topk / self.n_routed_experts)
|
||||
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
|
||||
else:
|
||||
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
|
||||
ce = mask_ce.float().mean(0)
|
||||
Pi = scores_for_aux.mean(0)
|
||||
fi = ce * self.n_routed_experts
|
||||
aux_loss = (Pi * fi).sum() * self.alpha
|
||||
else:
|
||||
aux_loss = 0
|
||||
return topk_idx, topk_weight, aux_loss
|
||||
|
||||
|
||||
class MOEFeedForward(nn.Module):
|
||||
def __init__(self, config: LMConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.experts = nn.ModuleList([
|
||||
FeedForward(config)
|
||||
for _ in range(config.n_routed_experts)
|
||||
])
|
||||
self.gate = MoEGate(config)
|
||||
if config.n_shared_experts is not None:
|
||||
self.shared_experts = FeedForward(config)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
orig_shape = x.shape
|
||||
bsz, seq_len, _ = x.shape
|
||||
# 使用门控机制选择专家
|
||||
topk_idx, topk_weight, aux_loss = self.gate(x)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
flat_topk_idx = topk_idx.view(-1)
|
||||
if self.training:
|
||||
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
|
||||
y = torch.empty_like(x, dtype=torch.float16)
|
||||
for i, expert in enumerate(self.experts):
|
||||
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致
|
||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||
y = y.view(*orig_shape)
|
||||
else:
|
||||
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
|
||||
if self.config.n_shared_experts is not None:
|
||||
y = y + self.shared_experts(identity)
|
||||
self.aux_loss = aux_loss
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
|
||||
expert_cache = torch.zeros_like(x)
|
||||
idxs = flat_expert_indices.argsort()
|
||||
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
|
||||
token_idxs = idxs // self.config.num_experts_per_tok
|
||||
# 当tokens_per_expert = [6, 15, 20, 26],tokens_per_expert.shape[0]即为专家数量(此时为4)
|
||||
# 且token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 时
|
||||
# 意味token_idxs[:6] -> [3, 7, 19, 21, 24, 25]这6个位置属于专家0处理的token(每个token有可能被多个专家处理,这取决于num_experts_per_tok)
|
||||
# 接下来9个位置token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...]属于专家1处理的token...依此类推
|
||||
for i, end_idx in enumerate(tokens_per_expert):
|
||||
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
|
||||
if start_idx == end_idx:
|
||||
continue
|
||||
expert = self.experts[i]
|
||||
exp_token_idx = token_idxs[start_idx:end_idx]
|
||||
expert_tokens = x[exp_token_idx]
|
||||
expert_out = expert(expert_tokens).to(expert_cache.dtype)
|
||||
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
|
||||
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
|
||||
|
||||
return expert_cache
|
||||
|
||||
|
||||
class MiniMindBlock(nn.Module):
|
||||
def __init__(self, layer_id: int, config: LMConfig):
|
||||
super().__init__()
|
||||
self.n_heads = config.n_heads
|
||||
self.dim = config.dim
|
||||
self.head_dim = config.dim // config.n_heads
|
||||
self.attention = Attention(config)
|
||||
self.cross_att = CrossAttention(config)
|
||||
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
||||
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
||||
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
|
||||
|
||||
# 假设num_experts是已定义的总专家数量的平方根
|
||||
|
||||
|
||||
# 查询生成的参数
|
||||
|
||||
|
||||
# 创建查询生成模块
|
||||
# if weight_down_embed is not None:
|
||||
# self.to_queries = nn.Sequential(
|
||||
# nn.Linear(config.dim, self.dim_key * 2, bias=False),
|
||||
# # nn.Unflatten(2, (2, self.n_heads, self.dim_key)) # 替代Rearrange
|
||||
# )
|
||||
|
||||
# # 超参数
|
||||
# self.product_key_topk = min(16, self.num_keys) # 确保不超过num_keys
|
||||
# self.num_experts_per_head_topk = 1 # 最终每个头选取的专家数
|
||||
|
||||
def forward(self, x, db_value, pos_cis):
|
||||
# import pdb;pdb.set_trace()
|
||||
# db_value = None
|
||||
|
||||
# # 如果有weight_down_embed,使用Product Key机制
|
||||
# if self.weight_down_embed is not None:
|
||||
# # 1. 生成queries
|
||||
# batch_size, seq_len, dim = x.shape
|
||||
|
||||
# # collapse sequence dimension by averaging
|
||||
# x_flat = x.mean(dim=1) # [batch_size, dim]
|
||||
# queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
|
||||
# queries = queries.reshape(batch_size, 2, self.dim_key) # [batch_size, 2, dim_key]
|
||||
# queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key]
|
||||
|
||||
# # 2. 计算queries与keys的相似度
|
||||
# sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
|
||||
|
||||
# # 3. 在两个子空间分别做top-k
|
||||
# scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
|
||||
# scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
|
||||
# indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
|
||||
|
||||
# # 4. 组合两个子空间的分数和索引
|
||||
# all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
|
||||
# all_scores = all_scores.view(*all_scores.shape[:-2], -1)
|
||||
|
||||
# all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
|
||||
# all_indices = all_indices.view(*all_indices.shape[:-2], -1)
|
||||
|
||||
# # 5. 最终top-k选择
|
||||
# scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
|
||||
# indices = all_indices.gather(-1, pk_indices)
|
||||
|
||||
# # 6. 从embedding中获取专家值
|
||||
|
||||
# # 从embedding中获取值
|
||||
# flat_indices = indices.view(-1) # 将索引展平为一维张量
|
||||
# db_values = self.weight_down_embed(flat_indices)
|
||||
|
||||
# # 重塑回原始形状
|
||||
# db_value = db_values.view(batch_size, -1, dim)
|
||||
|
||||
|
||||
# Attention computation
|
||||
h_attn = self.attention(
|
||||
self.attention_norm(x),
|
||||
pos_cis,
|
||||
db_value=db_value
|
||||
)
|
||||
|
||||
h_attn = self.cross_att(h_attn, db_value)
|
||||
|
||||
# Residual connection
|
||||
h = x + h_attn
|
||||
|
||||
# Feed-forward network
|
||||
out = h + self.feed_forward(self.ffn_norm(h))
|
||||
return out
|
||||
|
||||
class ExtractDB(nn.Module):
|
||||
def __init__(self,params):
|
||||
# The expert count and knowledge dimension are chosen so that their square roots are integers
|
||||
super().__init__()
|
||||
self.batch_size = None
|
||||
self.dim = params.dim
|
||||
self.dim_key = self.dim // 2
|
||||
self.knowlwdge_num = params.knowlwdge_num  # e.g. 100 experts; must be a perfect square
|
||||
# knowledge_dim is set equal to head_dim so it can be used directly inside attention
|
||||
self.head_dim = params.dim // params.n_heads
|
||||
self.knowledge_length = params.knowlwdge_length*params.dim
|
||||
|
||||
# Use register_buffer instead of nn.Parameter to keep this tensor out of the autograd graph
|
||||
self.register_buffer('weight_down_embed', torch.randn(self.knowlwdge_num, self.knowledge_length) * 0.02)
|
||||
|
||||
self.num_keys = int(math.sqrt(self.knowlwdge_num)) if self.knowlwdge_num > 0 else 0
|
||||
self.product_key_topk = min(16, self.num_keys)
|
||||
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02)
|
||||
self.num_experts_per_head_topk = 1
|
||||
self.to_queries = nn.Sequential(
|
||||
nn.Linear(params.dim, self.dim_key * 2, bias=False),
|
||||
)
|
||||
|
||||
def q_to_k(self,x):
|
||||
# 1. Generate queries
|
||||
self.batch_size, seq_len, dim = x.shape
|
||||
|
||||
# collapse sequence dimension by averaging
|
||||
x_flat = x.mean(dim=1) # [batch_size, dim]
|
||||
|
||||
queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
|
||||
queries = queries.reshape(self.batch_size, 2, self.dim_key) # [batch_size, 2, dim_key]
|
||||
queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key]
|
||||
|
||||
# 2. Compute the similarity between queries and keys
|
||||
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
|
||||
|
||||
# 3. Top-k within each of the two key subspaces
|
||||
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
|
||||
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
|
||||
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
|
||||
|
||||
# 4. Combine the scores and indices from the two subspaces
|
||||
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
|
||||
all_scores = all_scores.view(*all_scores.shape[:-2], -1)
|
||||
|
||||
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
|
||||
all_indices = all_indices.view(*all_indices.shape[:-2], -1)
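# Composite slot id = index_x * num_keys + index_y, so the two per-subspace top-k lists
# combine into product_key_topk ** 2 candidates drawn from num_keys ** 2 (= knowlwdge_num) slots.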
|
||||
|
||||
# 5. Final top-k selection
|
||||
scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
|
||||
indices = all_indices.gather(-1, pk_indices)
|
||||
flat_indices = indices.view(-1)
|
||||
return flat_indices
|
||||
|
||||
def get_data(self, index):
|
||||
# Fetch the embedding rows directly on the GPU
|
||||
db_values = self.weight_down_embed[index]
|
||||
db_value = db_values.view(self.batch_size, -1, self.dim)
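# With num_experts_per_head_topk = 1 each sample selects one row of length
# knowlwdge_length * dim, so db_value has shape [batch_size, knowlwdge_length, dim].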
|
||||
return db_value
|
||||
|
||||
@torch.no_grad()
|
||||
def updata_value(self, k, v):
|
||||
# Update the buffer values in place (no gradients needed)
|
||||
v_reshaped = v.view(v.size(0), -1)
|
||||
# Make sure the dtypes match
|
||||
v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
|
||||
self.weight_down_embed[k] = v_reshaped
|
||||
|
||||
|
||||
|
||||
class MiniMindLM(PreTrainedModel):
|
||||
config_class = LMConfig
|
||||
|
||||
def __init__(self, params: LMConfig = None):
|
||||
self.params = params or LMConfig()
|
||||
super().__init__(self.params)
params = self.params
|
||||
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
|
||||
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
|
||||
self.dropout = nn.Dropout(params.dropout)
|
||||
# The old weight_down_embed declaration has been removed (it now lives in ExtractDB)
|
||||
self.extract_db = ExtractDB(self.params)
|
||||
|
||||
# Each MiniMindBlock consumes the db_value looked up from extract_db in forward()
|
||||
self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
|
||||
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
|
||||
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
|
||||
self.tok_embeddings.weight = self.output.weight
|
||||
|
||||
# Calculate input dimension
|
||||
input_dim = (self.params.max_seq_len-1)*self.params.n_layers
|
||||
# Use a bottleneck architecture to reduce parameters
|
||||
bottleneck_dim = 256 # Significantly smaller bottleneck dimension
|
||||
|
||||
# Factorized shared downsampling using two smaller convolutions
|
||||
self.shared_downsample = nn.Sequential(
|
||||
# First reduce input dimension to bottleneck
|
||||
nn.Conv1d(input_dim, bottleneck_dim, kernel_size=1, padding='same'),
|
||||
nn.ReLU(), # Non-linearity to improve representation capacity
|
||||
# Then expand to target dimension
|
||||
nn.Conv1d(bottleneck_dim, 128*8, kernel_size=1, padding='same')
|
||||
)
|
||||
|
||||
# Specific layers for v path
|
||||
self.downsample_v_specific = nn.Sequential(
|
||||
nn.Conv1d(128*8, 128, kernel_size=1, padding='same'),
|
||||
nn.Conv1d(128, 8, kernel_size=1, padding='same')
|
||||
)
|
||||
|
||||
# Specific layers for q path
|
||||
self.downsample_q_specific = nn.Sequential(
|
||||
nn.Conv1d(128*8, 512, kernel_size=1, padding='same')
|
||||
)
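# Layout note (a sketch of the intended shapes, assuming training inputs of length max_seq_len - 1):
# the detached per-layer hidden states are reshaped to [batch, (max_seq_len - 1) * n_layers, dim],
# so these kernel_size=1 Conv1d stacks mix across the (seq_len * n_layers) channel axis while
# treating dim as the length axis. The v path ends at 8 channels, which has to match
# knowlwdge_length for updata_value to flatten it into a single knowledge row; the q path ends at 512.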
|
||||
# Use the real-valued positional encoding to avoid segfaults that complex tensors can trigger
|
||||
self.register_buffer("pos_cis_real",
|
||||
precompute_pos_cis_real(dim=params.dim // params.n_heads, theta=params.rope_theta),
|
||||
persistent=False)
|
||||
self.params = params
|
||||
|
||||
def forward(self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**args):
|
||||
start_pos = args.get('start_pos', 0)
|
||||
h = self.dropout(self.tok_embeddings(input_ids))
|
||||
pos_cis_real = self.pos_cis_real[start_pos:start_pos + input_ids.size(1)]
|
||||
h_list = []
|
||||
|
||||
for l, layer in enumerate(self.layers):
|
||||
# Database disabled: use a fixed value instead of the database lookup
|
||||
if self.params.disable_db:
|
||||
# Create a tensor of shape [batch_size, n_layers, dim] with every element set to 1e-4
|
||||
batch_size = h.size(0)
|
||||
db_value = torch.full((batch_size, self.n_layers, self.params.dim), 1e-4,
|
||||
dtype=h.dtype, device=h.device)
|
||||
else:
|
||||
# Normal mode: query the knowledge database
|
||||
index = self.extract_db.q_to_k(h)
|
||||
db_value = self.extract_db.get_data(index)
|
||||
|
||||
h = layer(
|
||||
h, db_value, pos_cis_real
|
||||
)
|
||||
|
||||
h_list.append(h.unsqueeze(0))
|
||||
|
||||
h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3)
|
||||
|
||||
# Only run the database update logic when the database is not disabled
|
||||
if not self.params.disable_db:
|
||||
# detach() cuts the computation graph so these tensors are not backpropagated through twice
|
||||
h_tensor_detached = h_tensor.detach()
|
||||
h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0], -1, self.params.dim)
|
||||
|
||||
# The database update is kept separate from the main computation graph
|
||||
with torch.no_grad():
|
||||
# Compute shared downsampling layer once
|
||||
shared_features = self.shared_downsample(h_tensor_detached)
|
||||
z_v = self.downsample_v_specific(shared_features)
|
||||
z_q = self.downsample_q_specific(shared_features)
|
||||
z_k = self.extract_db.q_to_k(z_q)
|
||||
self.extract_db.updata_value(z_k, z_v)
|
||||
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.output(self.norm(h)[:, slice_indices, :])
|
||||
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
|
||||
|
||||
# Simplified further: keep only the fields that are actually needed
|
||||
output = CausalLMOutputWithPast(
|
||||
logits=logits,
|
||||
)
|
||||
output.hidden_states = h
|
||||
|
||||
output.aux_loss = aux_loss
|
||||
|
||||
# Try to attach extra attributes (if supported)
|
||||
# try:
|
||||
# output.hidden_states = h
|
||||
# except:
|
||||
# pass
|
||||
|
||||
return output
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
|
||||
stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
|
||||
# Streaming generation
|
||||
if stream:
|
||||
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
|
||||
|
||||
# Non-streaming generation
|
||||
generated = []
|
||||
for i in range(input_ids.size(0)):
|
||||
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
|
||||
for _ in range(num_return_sequences):
|
||||
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
|
||||
tokens_list = [tokens[:, -1:] for tokens in out]
|
||||
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
|
||||
full_sequence = torch.cat([non_pad, gen], dim=-1)
|
||||
generated.append(full_sequence)
|
||||
|
||||
max_length = max(seq.size(1) for seq in generated)
|
||||
generated = [
|
||||
torch.cat(
|
||||
[seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
|
||||
dim=-1)
|
||||
for seq in generated
|
||||
]
|
||||
output = torch.cat(generated, dim=0)
|
||||
res = output.view(input_ids.size(0) * num_return_sequences, -1)
|
||||
return res
|
||||
|
||||
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
|
||||
start, first_seq = input_ids.shape[1], True
|
||||
while input_ids.shape[1] < max_new_tokens - 1:
|
||||
if first_seq:
|
||||
out, first_seq = self(input_ids, **args), False
|
||||
else:
|
||||
out = self(input_ids[:, -1:], start_pos=input_ids.shape[1] - 1, **args)
|
||||
logits = out.logits[:, -1, :]
|
||||
logits[:, list(set(input_ids.tolist()[0]))] /= rp
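# Simple repetition penalty: logits of tokens already in the sequence are divided by rp;
# with rp > 1 this suppresses repeats whose logits are positive.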
|
||||
logits /= (temperature + 1e-9)
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
sorted_probs = F.softmax(sorted_logits, dim=-1)
|
||||
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
|
||||
sorted_indices_to_remove[:, 0] = False
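# Shift the removal mask right by one so the single most probable token always survives,
# then scatter the mask back from sorted order to vocabulary order before applying it.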
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
logits[indices_to_remove] = -float('Inf')
|
||||
input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
|
||||
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
|
||||
yield input_ids[:, start:]
|
||||
if input_ids_next.item() == eos_token_id:
|
||||
break
|
49
model/model_lora.py
Normal file
49
model/model_lora.py
Normal file
@ -0,0 +1,49 @@
|
||||
import torch
|
||||
from torch import optim, nn
|
||||
|
||||
|
||||
# LoRA network definition
|
||||
class LoRA(nn.Module):
|
||||
def __init__(self, in_features, out_features, rank):
|
||||
super().__init__()
|
||||
self.rank = rank  # LoRA rank, controls the size of the low-rank matrices
|
||||
self.A = nn.Linear(in_features, rank, bias=False)  # low-rank matrix A
|
||||
self.B = nn.Linear(rank, out_features, bias=False)  # low-rank matrix B
|
||||
# A is initialized from a Gaussian
|
||||
self.A.weight.data.normal_(mean=0.0, std=0.02)
|
||||
# B is initialized to all zeros
|
||||
self.B.weight.data.zero_()
|
||||
|
||||
def forward(self, x):
|
||||
return self.B(self.A(x))
|
||||
|
||||
|
||||
def apply_lora(model, rank=16):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]:
|
||||
lora = LoRA(module.weight.shape[0], module.weight.shape[1], rank=rank).to(model.device)
|
||||
setattr(module, "lora", lora)
|
||||
original_forward = module.forward
|
||||
|
||||
# Bind the original forward and the LoRA branch explicitly
|
||||
def forward_with_lora(x, layer1=original_forward, layer2=lora):
|
||||
return layer1(x) + layer2(x)
|
||||
|
||||
module.forward = forward_with_lora
|
||||
|
||||
|
||||
def load_lora(model, path):
|
||||
state_dict = torch.load(path, map_location=model.device)
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, 'lora'):
|
||||
lora_state = {k.replace(f'{name}.lora.', ''): v for k, v in state_dict.items() if f'{name}.lora.' in k}
|
||||
module.lora.load_state_dict(lora_state)
|
||||
|
||||
|
||||
def save_lora(model, path):
|
||||
state_dict = {}
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, 'lora'):
|
||||
lora_state = {f'{name}.lora.{k}': v for k, v in module.lora.state_dict().items()}
|
||||
state_dict.update(lora_state)
|
||||
torch.save(state_dict, path)
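# Illustrative usage sketch (paths and hyperparameters below are hypothetical, not taken from this repo):
#
#   from model.LMConfig import LMConfig
#   from model.model import MiniMindLM
#
#   model = MiniMindLM(LMConfig(dim=512, n_layers=8))
#   apply_lora(model, rank=16)                      # wraps every square nn.Linear with a LoRA branch
#   # ... train with only the *.lora.* parameters unfrozen ...
#   save_lora(model, './out/lora_demo_512.pth')     # persists just the LoRA A/B matrices
#   load_lora(model, './out/lora_demo_512.pth')     # restores them into the wrapped modules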
|
120
requirements.txt
Normal file
120
requirements.txt
Normal file
@ -0,0 +1,120 @@
|
||||
aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.11.17
|
||||
aiosignal==1.3.2
|
||||
altair==5.5.0
|
||||
annotated-types==0.7.0
|
||||
anyio==4.9.0
|
||||
async-timeout==5.0.1
|
||||
attrs==25.3.0
|
||||
blinker==1.9.0
|
||||
cachetools==5.5.2
|
||||
certifi==2025.1.31
|
||||
charset-normalizer==3.4.1
|
||||
click==8.1.8
|
||||
contourpy==1.3.2
|
||||
cycler==0.12.1
|
||||
datasets==2.21.0
|
||||
datasketch==1.6.4
|
||||
dill==0.3.8
|
||||
distro==1.9.0
|
||||
docker-pycreds==0.4.0
|
||||
einops==0.8.1
|
||||
exceptiongroup==1.2.2
|
||||
filelock==3.18.0
|
||||
Flask==3.0.3
|
||||
Flask-Cors==4.0.0
|
||||
fonttools==4.57.0
|
||||
frozenlist==1.6.0
|
||||
fsspec==2024.6.1
|
||||
gitdb==4.0.12
|
||||
GitPython==3.1.44
|
||||
h11==0.14.0
|
||||
hjson==3.1.0
|
||||
httpcore==1.0.8
|
||||
httpx==0.28.1
|
||||
huggingface-hub==0.30.2
|
||||
idna==3.10
|
||||
importlib_metadata==7.2.1
|
||||
itsdangerous==2.2.0
|
||||
jieba==0.42.1
|
||||
Jinja2==3.1.2
|
||||
jiter==0.9.0
|
||||
joblib==1.4.2
|
||||
jsonlines==4.0.0
|
||||
jsonschema==4.23.0
|
||||
jsonschema-specifications==2024.10.1
|
||||
kiwisolver==1.4.8
|
||||
markdown-it-py==3.0.0
|
||||
MarkupSafe==3.0.2
|
||||
marshmallow==3.22.0
|
||||
matplotlib==3.10.0
|
||||
mdurl==0.1.2
|
||||
modelscope==1.25.0
|
||||
mpmath==1.3.0
|
||||
msgpack==1.1.0
|
||||
multidict==6.4.3
|
||||
multiprocess==0.70.16
|
||||
narwhals==1.35.0
|
||||
networkx==3.4.2
|
||||
ngrok==1.4.0
|
||||
ninja==1.11.1.4
|
||||
nltk==3.8
|
||||
numpy==1.26.4
|
||||
openai==1.59.6
|
||||
packaging==23.2
|
||||
pandas==1.5.3
|
||||
peft==0.7.1
|
||||
pillow==10.4.0
|
||||
platformdirs==4.3.7
|
||||
propcache==0.3.1
|
||||
protobuf==4.25.6
|
||||
psutil==5.9.8
|
||||
py-cpuinfo==9.0.0
|
||||
pyarrow==19.0.1
|
||||
pydantic==2.8.2
|
||||
pydantic_core==2.20.1
|
||||
pydeck==0.9.1
|
||||
Pygments==2.19.1
|
||||
pyparsing==3.2.3
|
||||
python-dateutil==2.9.0.post0
|
||||
pytz==2025.2
|
||||
PyYAML==6.0.2
|
||||
referencing==0.36.2
|
||||
regex==2024.11.6
|
||||
requests==2.32.3
|
||||
rich==13.7.1
|
||||
rpds-py==0.24.0
|
||||
safetensors==0.5.3
|
||||
scikit-learn==1.5.1
|
||||
scipy==1.15.2
|
||||
sentence-transformers==2.3.1
|
||||
sentencepiece==0.2.0
|
||||
sentry-sdk==2.26.1
|
||||
setproctitle==1.3.5
|
||||
simhash==2.1.2
|
||||
six==1.17.0
|
||||
smmap==5.0.2
|
||||
sniffio==1.3.1
|
||||
streamlit==1.30.0
|
||||
sympy==1.13.3
|
||||
tenacity==8.5.0
|
||||
threadpoolctl==3.6.0
|
||||
tiktoken==0.5.1
|
||||
tokenizers==0.21.1
|
||||
toml==0.10.2
|
||||
tornado==6.4.2
|
||||
tqdm==4.67.1
|
||||
transformers==4.48.0
|
||||
triton==3.3.0
|
||||
trl==0.13.0
|
||||
typing_extensions==4.13.2
|
||||
tzlocal==5.3.1
|
||||
ujson==5.1.0
|
||||
urllib3==2.4.0
|
||||
validators==0.34.0
|
||||
wandb==0.18.3
|
||||
watchdog==6.0.0
|
||||
Werkzeug==3.1.3
|
||||
xxhash==3.5.0
|
||||
yarl==1.20.0
|
||||
zipp==3.21.0
|
47
run_file/DynamicKV-LLM_Mini_Minimind.sh
Normal file
47
run_file/DynamicKV-LLM_Mini_Minimind.sh
Normal file
@ -0,0 +1,47 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Activate the conda environment
|
||||
# source $(conda info --base)/etc/profile.d/conda.sh
|
||||
# conda activate ycz_accelerate
|
||||
|
||||
# Environment variables to help with debugging
|
||||
export NCCL_DEBUG=INFO
|
||||
export PYTHONFAULTHANDLER=1
|
||||
|
||||
# Option 1: launch with a pre-built accelerate config file
|
||||
# accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
|
||||
# --epochs 3 \
|
||||
# --batch_size 24 \
|
||||
# --learning_rate 2e-4 \
|
||||
# --dtype bfloat16 \
|
||||
# --accumulation_steps 32 \
|
||||
# --grad_clip 1.0 \
|
||||
# --log_interval 100 \
|
||||
# --save_interval 10000 \
|
||||
# --dim 1024 \
|
||||
# --n_layers 32 \
|
||||
# --max_seq_len 1024 \
|
||||
# --use_flash_attn \
|
||||
# --profile \
|
||||
# --profile_interval 10
|
||||
|
||||
# Option 2: configure accelerate directly via command-line flags
|
||||
CUDA_VISIBLE_DEVICES=0 accelerate launch \
|
||||
--num_processes=1 \
|
||||
--mixed_precision=bf16 \
|
||||
--main_process_port=29500 \
|
||||
train_pretrain_accelerate.py \
|
||||
--epochs 3 \
|
||||
--batch_size 24 \
|
||||
--learning_rate 2e-4 \
|
||||
--dtype bfloat16 \
|
||||
--accumulation_steps 32 \
|
||||
--grad_clip 1.0 \
|
||||
--log_interval 100 \
|
||||
--save_interval 10000 \
|
||||
--dim 512 \
|
||||
--n_layers 12 \
|
||||
--max_seq_len 512 \
|
||||
--use_flash_attn \
|
||||
--profile \
|
||||
--profile_interval 10
|
48
run_file/DynamicKV-LLM_Small_Minimind.sh
Normal file
48
run_file/DynamicKV-LLM_Small_Minimind.sh
Normal file
@ -0,0 +1,48 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Activate the conda environment
|
||||
source $(conda info --base)/etc/profile.d/conda.sh
|
||||
conda activate ycz_accelerate
|
||||
|
||||
# Environment variables to help with debugging
|
||||
export NCCL_DEBUG=INFO
|
||||
export PYTHONFAULTHANDLER=1
|
||||
|
||||
# Option 1: launch with a pre-built accelerate config file
|
||||
# accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
|
||||
# --epochs 3 \
|
||||
# --batch_size 24 \
|
||||
# --learning_rate 2e-4 \
|
||||
# --dtype bfloat16 \
|
||||
# --accumulation_steps 32 \
|
||||
# --grad_clip 1.0 \
|
||||
# --log_interval 100 \
|
||||
# --save_interval 10000 \
|
||||
# --dim 1024 \
|
||||
# --n_layers 32 \
|
||||
# --max_seq_len 1024 \
|
||||
# --use_flash_attn \
|
||||
# --profile \
|
||||
# --profile_interval 10
|
||||
|
||||
# Option 2: configure accelerate directly via command-line flags
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||
--multi_gpu \
|
||||
--num_processes=4 \
|
||||
--mixed_precision=bf16 \
|
||||
--main_process_port=29500 \
|
||||
train_pretrain_accelerate.py \
|
||||
--epochs 3 \
|
||||
--batch_size 24 \
|
||||
--learning_rate 2e-4 \
|
||||
--dtype bfloat16 \
|
||||
--accumulation_steps 32 \
|
||||
--grad_clip 1.0 \
|
||||
--log_interval 100 \
|
||||
--save_interval 10000 \
|
||||
--dim 1024 \
|
||||
--n_layers 32 \
|
||||
--max_seq_len 1024 \
|
||||
--use_flash_attn \
|
||||
--profile \
|
||||
--profile_interval 10
|
30
scripts/chat_openai_api.py
Normal file
30
scripts/chat_openai_api.py
Normal file
@ -0,0 +1,30 @@
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
api_key="none",
|
||||
base_url="http://localhost:8998/v1"
|
||||
)
|
||||
stream = True
|
||||
conversation_history_origin = []
|
||||
conversation_history = conversation_history_origin.copy()
|
||||
history_messages_num = 2  # use an even number (Q+A pairs); 0 means every turn is an independent QA with no history
|
||||
while True:
|
||||
query = input('[Q]: ')
|
||||
conversation_history.append({"role": "user", "content": query})
|
||||
response = client.chat.completions.create(
|
||||
model="minimind",
|
||||
messages=conversation_history[-history_messages_num:],
|
||||
stream=stream
|
||||
)
|
||||
if not stream:
|
||||
assistant_res = response.choices[0].message.content
|
||||
print('[A]: ', assistant_res)
|
||||
else:
|
||||
print('[A]: ', end='')
|
||||
assistant_res = ''
|
||||
for chunk in response:
|
||||
print(chunk.choices[0].delta.content or "", end="")
|
||||
assistant_res += chunk.choices[0].delta.content or ""
|
||||
|
||||
conversation_history.append({"role": "assistant", "content": assistant_res})
|
||||
print('\n\n')
|
62
scripts/convert_model.py
Normal file
62
scripts/convert_model.py
Normal file
@ -0,0 +1,62 @@
|
||||
import torch
|
||||
import warnings
|
||||
import sys
|
||||
import os
|
||||
|
||||
__package__ = "scripts"
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from model.LMConfig import LMConfig
|
||||
from model.model import MiniMindLM
|
||||
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
|
||||
def convert_torch2transformers(torch_path, transformers_path):
|
||||
def export_tokenizer(transformers_path):
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/minimind_tokenizer')
|
||||
tokenizer.save_pretrained(transformers_path)
|
||||
|
||||
LMConfig.register_for_auto_class()
|
||||
MiniMindLM.register_for_auto_class("AutoModelForCausalLM")
|
||||
lm_model = MiniMindLM(lm_config)
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
state_dict = torch.load(torch_path, map_location=device)
|
||||
lm_model.load_state_dict(state_dict, strict=False)
|
||||
model_params = sum(p.numel() for p in lm_model.parameters() if p.requires_grad)
|
||||
print(f'模型参数: {model_params / 1e6} 百万 = {model_params / 1e9} B (Billion)')
|
||||
lm_model.save_pretrained(transformers_path, safe_serialization=False)
|
||||
export_tokenizer(transformers_path)
|
||||
print(f"模型已保存为 Transformers 格式: {transformers_path}")
|
||||
|
||||
|
||||
def convert_transformers2torch(transformers_path, torch_path):
|
||||
model = AutoModelForCausalLM.from_pretrained(transformers_path, trust_remote_code=True)
|
||||
torch.save(model.state_dict(), torch_path)
|
||||
print(f"模型已保存为 PyTorch 格式: {torch_path}")
|
||||
|
||||
|
||||
# don't need to use
|
||||
def push_to_hf(export_model_path):
|
||||
def init_model():
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/minimind_tokenizer')
|
||||
model = AutoModelForCausalLM.from_pretrained(export_model_path, trust_remote_code=True)
|
||||
return model, tokenizer
|
||||
|
||||
model, tokenizer = init_model()
|
||||
# model.push_to_hub(model_path)
|
||||
# tokenizer.push_to_hub(model_path, safe_serialization=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
lm_config = LMConfig(dim=512, n_layers=8, max_seq_len=8192, use_moe=False)
|
||||
|
||||
torch_path = f"../out/rlhf_{lm_config.dim}{'_moe' if lm_config.use_moe else ''}.pth"
|
||||
|
||||
transformers_path = '../MiniMind2-Small'
|
||||
|
||||
# convert torch to transformers model
|
||||
convert_torch2transformers(torch_path, transformers_path)
|
||||
|
||||
# # convert transformers to torch model
|
||||
# convert_transformers2torch(transformers_path, torch_path)
|
164
scripts/serve_openai_api.py
Normal file
164
scripts/serve_openai_api.py
Normal file
@ -0,0 +1,164 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
__package__ = "scripts"
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
import time
|
||||
import torch
|
||||
import warnings
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from model.LMConfig import LMConfig
|
||||
from model.model import MiniMindLM
|
||||
from model.model_lora import apply_lora, load_lora
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def init_model(args):
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/minimind_tokenizer')
|
||||
if args.load == 0:
|
||||
moe_path = '_moe' if args.use_moe else ''
|
||||
modes = {0: 'pretrain', 1: 'full_sft', 2: 'rlhf', 3: 'reason'}
|
||||
ckp = f'../{args.out_dir}/{modes[args.model_mode]}_{args.dim}{moe_path}.pth'
|
||||
|
||||
model = MiniMindLM(LMConfig(
|
||||
dim=args.dim,
|
||||
n_layers=args.n_layers,
|
||||
max_seq_len=args.max_seq_len,
|
||||
use_moe=args.use_moe
|
||||
))
|
||||
|
||||
state_dict = torch.load(ckp, map_location=device)
|
||||
model.load_state_dict({k: v for k, v in state_dict.items() if 'mask' not in k}, strict=True)
|
||||
|
||||
if args.lora_name != 'None':
|
||||
apply_lora(model)
|
||||
load_lora(model, f'../{args.out_dir}/{args.lora_name}_{args.dim}.pth')
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
'./MiniMind2',
|
||||
trust_remote_code=True
|
||||
)
|
||||
print(f'MiniMind模型参数量: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M(illion)')
|
||||
return model.eval().to(device), tokenizer
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
model: str
|
||||
messages: list
|
||||
temperature: float = 0.7
|
||||
top_p: float = 0.92
|
||||
max_tokens: int = 8192
|
||||
stream: bool = False
|
||||
|
||||
|
||||
def generate_stream_response(messages, temperature, top_p, max_tokens):
|
||||
try:
|
||||
new_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)[-max_tokens:]
|
||||
x = tokenizer(new_prompt).data['input_ids']
|
||||
x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])
|
||||
with torch.no_grad():
|
||||
res_y = model.generate(
|
||||
x,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
stream=True,
|
||||
rp=1.,
|
||||
pad_token_id=tokenizer.pad_token_id
|
||||
)
|
||||
history_idx = 0
|
||||
for y in res_y:
|
||||
answer = tokenizer.decode(y[0].tolist(), skip_special_tokens=True)
|
||||
if (answer and answer[-1] == '�') or not answer:
|
||||
continue
|
||||
delta = answer[history_idx:]
|
||||
history_idx = len(answer)
|
||||
json_data = {
|
||||
'id': f'chatcmpl-{int(time.time())}',
|
||||
'object': 'chat.completion.chunk',
|
||||
'created': int(time.time()),
|
||||
'model': 'minimind',
|
||||
'choices': [{'index': 0, 'delta': {'content': delta}, 'finish_reason': None}]
|
||||
}
|
||||
yield f"data: {json.dumps(json_data)}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: ChatRequest):
|
||||
try:
|
||||
if request.stream:
|
||||
return StreamingResponse(
|
||||
generate_stream_response(
|
||||
messages=request.messages,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_tokens=request.max_tokens
|
||||
),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
else:
|
||||
new_prompt = tokenizer.apply_chat_template(
|
||||
request.messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)[-request.max_tokens:]
|
||||
x = tokenizer(new_prompt).data['input_ids']
|
||||
x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])
|
||||
with torch.no_grad():
|
||||
res_y = model.generate(
|
||||
x,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
max_new_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
stream=False,
|
||||
rp=1.,
|
||||
pad_token_id=tokenizer.pad_token_id
|
||||
)
|
||||
answer = tokenizer.decode(res_y.squeeze()[x.shape[1]:].tolist(), skip_special_tokens=True)
|
||||
return {
|
||||
"id": f"chatcmpl-{int(time.time())}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": "minimind",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": answer},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Server for MiniMind")
|
||||
parser.add_argument('--out_dir', default='out', type=str)
|
||||
parser.add_argument('--lora_name', default='None', type=str)
|
||||
parser.add_argument('--dim', default=512, type=int)
|
||||
parser.add_argument('--n_layers', default=8, type=int)
|
||||
parser.add_argument('--max_seq_len', default=8192, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument('--load', default=0, type=int, help="0: 从原生torch权重,1: 利用transformers加载")
|
||||
parser.add_argument('--model_mode', default=1, type=int, help="0: 预训练模型,1: SFT-Chat模型,2: RLHF-Chat模型,3: Reason模型")
|
||||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
model, tokenizer = init_model(parser.parse_args())
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8998)
|
152
scripts/train_tokenizer.py
Normal file
152
scripts/train_tokenizer.py
Normal file
@ -0,0 +1,152 @@
|
||||
import random
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer
|
||||
import json
|
||||
from datasets import load_dataset
|
||||
from tokenizers import (
|
||||
decoders,
|
||||
models,
|
||||
normalizers,
|
||||
pre_tokenizers,
|
||||
processors,
|
||||
trainers,
|
||||
Tokenizer,
|
||||
)
|
||||
import os
|
||||
|
||||
random.seed(42)
|
||||
|
||||
|
||||
def train_tokenizer():
|
||||
# Read the JSONL file and extract the text field
|
||||
def read_texts_from_jsonl(file_path):
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
yield data['text']
|
||||
|
||||
data_path = '../dataset/pretrain_hq.jsonl'
|
||||
|
||||
# Initialize the tokenizer
|
||||
tokenizer = Tokenizer(models.BPE())
|
||||
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
||||
|
||||
# 定义特殊token
|
||||
special_tokens = ["<unk>", "<s>", "</s>"]
|
||||
|
||||
# Configure the trainer and register the special tokens
|
||||
trainer = trainers.BpeTrainer(
|
||||
vocab_size=6400,
|
||||
special_tokens=special_tokens,  # make sure these three tokens are included
|
||||
show_progress=True,
|
||||
initial_alphabet=pre_tokenizers.ByteLevel.alphabet()
|
||||
)
|
||||
|
||||
# 读取文本数据
|
||||
texts = read_texts_from_jsonl(data_path)
|
||||
|
||||
# 训练tokenizer
|
||||
tokenizer.train_from_iterator(texts, trainer=trainer)
|
||||
|
||||
# 设置解码器
|
||||
tokenizer.decoder = decoders.ByteLevel()
|
||||
|
||||
# Check the indices of the special tokens
|
||||
assert tokenizer.token_to_id("<unk>") == 0
|
||||
assert tokenizer.token_to_id("<s>") == 1
|
||||
assert tokenizer.token_to_id("</s>") == 2
|
||||
|
||||
# Save the tokenizer
|
||||
tokenizer_dir = "../model/minimind_tokenizer"
|
||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||
tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||
tokenizer.model.save("../model/minimind_tokenizer")
|
||||
|
||||
# Manually create the tokenizer config file
|
||||
config = {
|
||||
"add_bos_token": False,
|
||||
"add_eos_token": False,
|
||||
"add_prefix_space": False,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "<unk>",
|
||||
"lstrip": False,
|
||||
"normalized": False,
|
||||
"rstrip": False,
|
||||
"single_word": False,
|
||||
"special": True
|
||||
},
|
||||
"1": {
|
||||
"content": "<s>",
|
||||
"lstrip": False,
|
||||
"normalized": False,
|
||||
"rstrip": False,
|
||||
"single_word": False,
|
||||
"special": True
|
||||
},
|
||||
"2": {
|
||||
"content": "</s>",
|
||||
"lstrip": False,
|
||||
"normalized": False,
|
||||
"rstrip": False,
|
||||
"single_word": False,
|
||||
"special": True
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [],
|
||||
"bos_token": "<s>",
|
||||
"clean_up_tokenization_spaces": False,
|
||||
"eos_token": "</s>",
|
||||
"legacy": True,
|
||||
"model_max_length": 32768,
|
||||
"pad_token": "<unk>",
|
||||
"sp_model_kwargs": {},
|
||||
"spaces_between_special_tokens": False,
|
||||
"tokenizer_class": "PreTrainedTokenizerFast",
|
||||
"unk_token": "<unk>",
|
||||
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<s>system\\n' + system_message + '</s>\\n' }}{% else %}{{ '<s>system\\n你是 MiniMind,是一个有用的人工智能助手。</s>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
|
||||
}
|
||||
|
||||
# Save the config file
|
||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w", encoding="utf-8") as config_file:
|
||||
json.dump(config, config_file, ensure_ascii=False, indent=4)
|
||||
|
||||
print("Tokenizer training completed and saved.")
|
||||
|
||||
|
||||
def eval_tokenizer():
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# 加载预训练的tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("../model/minimind_tokenizer")
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个优秀的聊天机器人,总是给我正确的回应!"},
|
||||
{"role": "user", "content": '你来自哪里?'},
|
||||
{"role": "assistant", "content": '我来自地球'}
|
||||
]
|
||||
new_prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False
|
||||
)
|
||||
print(new_prompt)
|
||||
|
||||
# 获取实际词汇表长度(包括特殊符号)
|
||||
actual_vocab_size = len(tokenizer)
|
||||
print('tokenizer实际词表长度:', actual_vocab_size)
|
||||
|
||||
model_inputs = tokenizer(new_prompt)
|
||||
print('encoder长度:', len(model_inputs['input_ids']))
|
||||
|
||||
input_ids = model_inputs['input_ids']
|
||||
response = tokenizer.decode(input_ids, skip_special_tokens=False)
|
||||
print('decoder和原始文本是否一致:', response == new_prompt)
|
||||
|
||||
|
||||
def main():
|
||||
train_tokenizer()
|
||||
eval_tokenizer()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
293
scripts/web_demo.py
Normal file
293
scripts/web_demo.py
Normal file
@ -0,0 +1,293 @@
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import streamlit as st
|
||||
import torch
|
||||
|
||||
st.set_page_config(page_title="MiniMind", initial_sidebar_state="collapsed")
|
||||
|
||||
# Button style overrides in the page-level CSS
|
||||
st.markdown("""
|
||||
<style>
|
||||
/* 添加操作按钮样式 */
|
||||
.stButton button {
|
||||
border-radius: 50% !important; /* 改为圆形 */
|
||||
width: 32px !important; /* 固定宽度 */
|
||||
height: 32px !important; /* 固定高度 */
|
||||
padding: 0 !important; /* 移除内边距 */
|
||||
background-color: transparent !important;
|
||||
border: 1px solid #ddd !important;
|
||||
display: flex !important;
|
||||
align-items: center !important;
|
||||
justify-content: center !important;
|
||||
font-size: 14px !important;
|
||||
color: #666 !important; /* 更柔和的颜色 */
|
||||
margin: 5px 10px 5px 0 !important; /* 调整按钮间距 */
|
||||
}
|
||||
.stButton button:hover {
|
||||
border-color: #999 !important;
|
||||
color: #333 !important;
|
||||
background-color: #f5f5f5 !important;
|
||||
}
|
||||
.stMainBlockContainer > div:first-child {
|
||||
margin-top: -50px !important;
|
||||
}
|
||||
.stApp > div:last-child {
|
||||
margin-bottom: -35px !important;
|
||||
}
|
||||
|
||||
/* 重置按钮基础样式 */
|
||||
.stButton > button {
|
||||
all: unset !important; /* 重置所有默认样式 */
|
||||
box-sizing: border-box !important;
|
||||
border-radius: 50% !important;
|
||||
width: 18px !important;
|
||||
height: 18px !important;
|
||||
min-width: 18px !important;
|
||||
min-height: 18px !important;
|
||||
max-width: 18px !important;
|
||||
max-height: 18px !important;
|
||||
padding: 0 !important;
|
||||
background-color: transparent !important;
|
||||
border: 1px solid #ddd !important;
|
||||
display: flex !important;
|
||||
align-items: center !important;
|
||||
justify-content: center !important;
|
||||
font-size: 14px !important;
|
||||
color: #888 !important;
|
||||
cursor: pointer !important;
|
||||
transition: all 0.2s ease !important;
|
||||
margin: 0 2px !important; /* 调整这里的 margin 值 */
|
||||
}
|
||||
|
||||
</style>
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
system_prompt = []
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def process_assistant_content(content):
|
||||
if 'R1' not in MODEL_PATHS[selected_model][1]:
|
||||
return content
|
||||
|
||||
if '<think>' in content and '</think>' in content:
|
||||
content = re.sub(r'(<think>)(.*?)(</think>)',
|
||||
r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理内容(展开)</summary>\2</details>',
|
||||
content,
|
||||
flags=re.DOTALL)
|
||||
|
||||
if '<think>' in content and '</think>' not in content:
|
||||
content = re.sub(r'<think>(.*?)$',
|
||||
r'<details open style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理中...</summary>\1</details>',
|
||||
content,
|
||||
flags=re.DOTALL)
|
||||
|
||||
if '<think>' not in content and '</think>' in content:
|
||||
content = re.sub(r'(.*?)</think>',
|
||||
r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理内容(展开)</summary>\1</details>',
|
||||
content,
|
||||
flags=re.DOTALL)
|
||||
|
||||
return content
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
def load_model_tokenizer(model_path):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
trust_remote_code=True
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
trust_remote_code=True
|
||||
)
|
||||
model = model.eval().to(device)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def clear_chat_messages():
|
||||
del st.session_state.messages
|
||||
del st.session_state.chat_messages
|
||||
|
||||
|
||||
def init_chat_messages():
|
||||
if "messages" in st.session_state:
|
||||
for i, message in enumerate(st.session_state.messages):
|
||||
if message["role"] == "assistant":
|
||||
with st.chat_message("assistant", avatar=image_url):
|
||||
st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
|
||||
# Action button rendered below the message content
|
||||
if st.button("🗑", key=f"delete_{i}"):
|
||||
st.session_state.messages.pop(i)
|
||||
st.session_state.messages.pop(i - 1)
|
||||
st.session_state.chat_messages.pop(i)
|
||||
st.session_state.chat_messages.pop(i - 1)
|
||||
st.rerun()
|
||||
else:
|
||||
st.markdown(
|
||||
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: #ddd; border-radius: 10px; color: black;">{message["content"]}</div></div>',
|
||||
unsafe_allow_html=True)
|
||||
|
||||
else:
|
||||
st.session_state.messages = []
|
||||
st.session_state.chat_messages = []
|
||||
|
||||
return st.session_state.messages
|
||||
|
||||
|
||||
# 添加这两个辅助函数
|
||||
def regenerate_answer(index):
|
||||
st.session_state.messages.pop()
|
||||
st.session_state.chat_messages.pop()
|
||||
st.rerun()
|
||||
|
||||
|
||||
def delete_conversation(index):
|
||||
st.session_state.messages.pop(index)
|
||||
st.session_state.messages.pop(index - 1)
|
||||
st.session_state.chat_messages.pop(index)
|
||||
st.session_state.chat_messages.pop(index - 1)
|
||||
st.rerun()
|
||||
|
||||
|
||||
# Sidebar model selection
|
||||
st.sidebar.title("模型设定调整")
|
||||
|
||||
st.sidebar.text("【注】训练数据偏差,增加上下文记忆时\n多轮对话(较单轮)容易出现能力衰减")
|
||||
st.session_state.history_chat_num = st.sidebar.slider("Number of Historical Dialogues", 0, 6, 0, step=2)
|
||||
# st.session_state.history_chat_num = 0
|
||||
st.session_state.max_new_tokens = st.sidebar.slider("Max Sequence Length", 256, 8192, 8192, step=1)
|
||||
st.session_state.top_p = st.sidebar.slider("Top-P", 0.8, 0.99, 0.85, step=0.01)
|
||||
st.session_state.temperature = st.sidebar.slider("Temperature", 0.6, 1.2, 0.85, step=0.01)
|
||||
|
||||
# Model path mapping
|
||||
MODEL_PATHS = {
|
||||
"MiniMind2-R1 (0.1B)": ["../MiniMind2-R1", "MiniMind2-R1"],
|
||||
"MiniMind2-Small-R1 (0.02B)": ["../MiniMind2-Small-R1", "MiniMind2-Small-R1"],
|
||||
"MiniMind2 (0.1B)": ["../MiniMind2", "MiniMind2"],
|
||||
"MiniMind2-MoE (0.15B)": ["../MiniMind2-MoE", "MiniMind2-MoE"],
|
||||
"MiniMind2-Small (0.02B)": ["../MiniMind2-Small", "MiniMind2-Small"],
|
||||
"MiniMind-V1 (0.1B)": ["../minimind-v1", "MiniMind-V1"],
|
||||
"MiniMind-V1-MoE (0.1B)": ["../minimind-v1-moe", "MiniMind-V1-MoE"],
|
||||
"MiniMind-V1-Small (0.02B)": ["../minimind-v1-small", "MiniMind-V1-Small"],
|
||||
}
|
||||
|
||||
selected_model = st.sidebar.selectbox('Models', list(MODEL_PATHS.keys()), index=2) # 默认选择 MiniMind2
|
||||
model_path = MODEL_PATHS[selected_model][0]
|
||||
|
||||
slogan = f"Hi, I'm {MODEL_PATHS[selected_model][1]}"
|
||||
|
||||
image_url = "https://www.modelscope.cn/api/v1/studio/gongjy/MiniMind/repo?Revision=master&FilePath=images%2Flogo2.png&View=true"
|
||||
|
||||
st.markdown(
|
||||
f'<div style="display: flex; flex-direction: column; align-items: center; text-align: center; margin: 0; padding: 0;">'
|
||||
'<div style="font-style: italic; font-weight: 900; margin: 0; padding-top: 4px; display: flex; align-items: center; justify-content: center; flex-wrap: wrap; width: 100%;">'
|
||||
f'<img src="{image_url}" style="width: 45px; height: 45px; "> '
|
||||
f'<span style="font-size: 26px; margin-left: 10px;">{slogan}</span>'
|
||||
'</div>'
|
||||
'<span style="color: #bbb; font-style: italic; margin-top: 6px; margin-bottom: 10px;">内容完全由AI生成,请务必仔细甄别<br>Content AI-generated, please discern with care</span>'
|
||||
'</div>',
|
||||
unsafe_allow_html=True
|
||||
)
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def main():
|
||||
model, tokenizer = load_model_tokenizer(model_path)
|
||||
|
||||
# Initialize the message lists
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
st.session_state.chat_messages = []
|
||||
|
||||
# Use session state messages
|
||||
messages = st.session_state.messages
|
||||
|
||||
# 在显示历史消息的循环中
|
||||
for i, message in enumerate(messages):
|
||||
if message["role"] == "assistant":
|
||||
with st.chat_message("assistant", avatar=image_url):
|
||||
st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
|
||||
if st.button("×", key=f"delete_{i}"):
|
||||
# Delete this message, its user prompt, and everything after them
|
||||
st.session_state.messages = st.session_state.messages[:i - 1]
|
||||
st.session_state.chat_messages = st.session_state.chat_messages[:i - 1]
|
||||
st.rerun()
|
||||
else:
|
||||
st.markdown(
|
||||
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: gray; border-radius: 10px; color:white; ">{message["content"]}</div></div>',
|
||||
unsafe_allow_html=True)
|
||||
|
||||
# Handle new input or a regeneration request
|
||||
prompt = st.chat_input(key="input", placeholder="给 MiniMind 发送消息")
|
||||
|
||||
# Check whether a regeneration was requested
|
||||
if hasattr(st.session_state, 'regenerate') and st.session_state.regenerate:
|
||||
prompt = st.session_state.last_user_message
|
||||
regenerate_index = st.session_state.regenerate_index # 获取重新生成的位置
|
||||
# Clear all regeneration-related state
|
||||
delattr(st.session_state, 'regenerate')
|
||||
delattr(st.session_state, 'last_user_message')
|
||||
delattr(st.session_state, 'regenerate_index')
|
||||
|
||||
if prompt:
|
||||
st.markdown(
|
||||
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: gray; border-radius: 10px; color:white; ">{prompt}</div></div>',
|
||||
unsafe_allow_html=True)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
st.session_state.chat_messages.append({"role": "user", "content": prompt})
|
||||
|
||||
with st.chat_message("assistant", avatar=image_url):
|
||||
placeholder = st.empty()
|
||||
random_seed = random.randint(0, 2 ** 32 - 1)
|
||||
setup_seed(random_seed)
|
||||
|
||||
st.session_state.chat_messages = system_prompt + st.session_state.chat_messages[
|
||||
-(st.session_state.history_chat_num + 1):]
|
||||
new_prompt = tokenizer.apply_chat_template(
|
||||
st.session_state.chat_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)[-(st.session_state.max_new_tokens - 1):]
|
||||
|
||||
x = torch.tensor(tokenizer(new_prompt)['input_ids'], device=device).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
res_y = model.generate(x, tokenizer.eos_token_id, max_new_tokens=st.session_state.max_new_tokens,
|
||||
temperature=st.session_state.temperature,
|
||||
top_p=st.session_state.top_p, stream=True)
|
||||
try:
|
||||
for y in res_y:
|
||||
answer = tokenizer.decode(y[0].tolist(), skip_special_tokens=True)
|
||||
if (answer and answer[-1] == '�') or not answer:
|
||||
continue
|
||||
placeholder.markdown(process_assistant_content(answer), unsafe_allow_html=True)
|
||||
except StopIteration:
|
||||
print("No answer")
|
||||
|
||||
assistant_answer = answer.replace(new_prompt, "")
|
||||
messages.append({"role": "assistant", "content": assistant_answer})
|
||||
st.session_state.chat_messages.append({"role": "assistant", "content": assistant_answer})
|
||||
|
||||
with st.empty():
|
||||
if st.button("×", key=f"delete_{len(messages) - 1}"):
|
||||
st.session_state.messages = st.session_state.messages[:-2]
|
||||
st.session_state.chat_messages = st.session_state.chat_messages[:-2]
|
||||
st.rerun()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
main()
|
97
test_real_rope.py
Normal file
97
test_real_rope.py
Normal file
@ -0,0 +1,97 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Test the real-valued positional encoding.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from model.model import precompute_pos_cis, precompute_pos_cis_real, apply_rotary_emb, apply_rotary_emb_real
|
||||
from model.LMConfig import LMConfig
|
||||
from model.model import MiniMindLM
|
||||
|
||||
def test_pos_encoding_equivalence():
|
||||
"""测试复数版本和实数版本的位置编码是否等价"""
|
||||
print("测试位置编码等价性...")
|
||||
|
||||
# 参数设置
|
||||
dim = 64
|
||||
seq_len = 10
|
||||
|
||||
# 生成复数版本的位置编码
|
||||
pos_cis = precompute_pos_cis(dim=dim, end=seq_len)
|
||||
|
||||
# 生成实数版本的位置编码
|
||||
pos_cis_real = precompute_pos_cis_real(dim=dim, end=seq_len)
|
||||
|
||||
# 创建随机查询和键
|
||||
batch_size = 2
|
||||
n_heads = 4
|
||||
head_dim = dim
|
||||
|
||||
xq = torch.randn(batch_size, seq_len, n_heads, head_dim)
|
||||
xk = torch.randn(batch_size, seq_len, n_heads, head_dim)
|
||||
|
||||
# 应用复数版本的旋转位置编码
|
||||
xq_complex, xk_complex = apply_rotary_emb(xq, xk, pos_cis)
|
||||
|
||||
# 应用实数版本的旋转位置编码
|
||||
xq_real, xk_real = apply_rotary_emb_real(xq, xk, pos_cis_real)
|
||||
|
||||
# 计算差异
|
||||
q_diff = torch.abs(xq_complex - xq_real).mean().item()
|
||||
k_diff = torch.abs(xk_complex - xk_real).mean().item()
|
||||
|
||||
print(f"查询差异: {q_diff:.6f}")
|
||||
print(f"键差异: {k_diff:.6f}")
|
||||
|
||||
# 检查差异是否在可接受范围内
|
||||
tolerance = 1e-5
|
||||
if q_diff < tolerance and k_diff < tolerance:
|
||||
print("✅ 测试通过: 复数版本和实数版本的位置编码在数值上等价")
|
||||
else:
|
||||
print("❌ 测试失败: 复数版本和实数版本的位置编码存在显著差异")
|
||||
|
||||
def test_model_forward():
|
||||
"""测试模型前向传播"""
|
||||
print("\n测试模型前向传播...")
|
||||
|
||||
# 创建模型配置
|
||||
config = LMConfig(
|
||||
dim=128,
|
||||
n_layers=2,
|
||||
n_heads=4,
|
||||
n_kv_heads=4, # make sure n_kv_heads is set and n_heads is divisible by it
|
||||
vocab_size=1000,
|
||||
max_seq_len=128,
|
||||
disable_db=True # disable the database path to keep the test simple
|
||||
)
|
||||
|
||||
# 创建模型
|
||||
try:
|
||||
model = MiniMindLM(config)
|
||||
print(f"✅ 模型初始化成功")
|
||||
except Exception as e:
|
||||
print(f"❌ 模型初始化失败: {str(e)}")
|
||||
return
|
||||
|
||||
# 创建输入
|
||||
batch_size = 2
|
||||
seq_len = 10
|
||||
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
|
||||
|
||||
# 前向传播
|
||||
try:
|
||||
with torch.no_grad():
|
||||
outputs = model(input_ids)
|
||||
print(f"✅ 模型前向传播成功")
|
||||
print(f"输出形状: {outputs.logits.shape}")
|
||||
except Exception as e:
|
||||
print(f"❌ 模型前向传播失败: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试位置编码等价性
|
||||
test_pos_encoding_equivalence()
|
||||
|
||||
# 测试模型前向传播
|
||||
test_model_forward()
|
215
train_distill_reason.py
Normal file
215
train_distill_reason.py
Normal file
@ -0,0 +1,215 @@
|
||||
import os
|
||||
import platform
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
|
||||
from torch import optim, nn
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from model.model import MiniMindLM
|
||||
from model.LMConfig import LMConfig
|
||||
from model.dataset import SFTDataset
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def Logger(content):
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
print(content)
|
||||
|
||||
|
||||
def get_lr(current_step, total_steps, lr):
|
||||
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
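# Cosine schedule: starts at lr / 10 + lr (1.1x the base lr) at step 0 and decays to lr / 10 at the final step.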
|
||||
|
||||
|
||||
def train_epoch(epoch, wandb):
|
||||
# Token ids for the reasoning-tag placeholders
|
||||
start_of_think_ids = tokenizer('<think>').input_ids
|
||||
end_of_think_ids = tokenizer('</think>').input_ids
|
||||
start_of_answer_ids = tokenizer('<answer>').input_ids
|
||||
end_of_answer_ids = tokenizer('</answer>').input_ids
|
||||
loss_fct = nn.CrossEntropyLoss(reduction='none')
|
||||
start_time = time.time()
|
||||
for step, (X, Y, loss_mask) in enumerate(train_loader):
|
||||
X = X.to(args.device)
|
||||
Y = Y.to(args.device)
|
||||
loss_mask = loss_mask.to(args.device)
|
||||
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
with ctx:
|
||||
res = model(X)
|
||||
loss = loss_fct(
|
||||
res.logits.view(-1, res.logits.size(-1)),
|
||||
Y.view(-1)
|
||||
).view(Y.size())
|
||||
sp_ids = torch.isin(Y.view(-1),
|
||||
torch.tensor(start_of_think_ids + end_of_think_ids
|
||||
+ start_of_answer_ids + end_of_answer_ids
|
||||
).to(args.device))
|
||||
# Apply extra weight at the positions matching sp_ids
|
||||
loss_mask = loss_mask.view(-1)
|
||||
loss_mask_sum = loss_mask.sum()
|
||||
loss_mask[sp_ids] = 10
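# Positions holding the <think>/<answer> boundary tokens get a 10x weight in the CE loss, while the
# denominator keeps the original mask sum, so tag tokens are emphasised without rescaling the rest.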
|
||||
loss_mask = loss_mask.view(Y.size())
|
||||
loss = (loss * loss_mask).sum() / loss_mask_sum
|
||||
loss += res.aux_loss
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if step % args.log_interval == 0:
|
||||
spend_time = time.time() - start_time
|
||||
Logger(
|
||||
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch + 1,
|
||||
args.epochs,
|
||||
step,
|
||||
iter_per_epoch,
|
||||
loss.item(),
|
||||
optimizer.param_groups[-1]['lr'],
|
||||
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
|
||||
|
||||
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
|
||||
wandb.log({"loss": loss,
|
||||
"lr": optimizer.param_groups[-1]['lr'],
|
||||
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
|
||||
|
||||
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
|
||||
model.eval()
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/reason_{lm_config.dim}{moe_path}.pth'
|
||||
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
|
||||
torch.save(state_dict, ckp)
|
||||
model.train()
|
||||
|
||||
|
||||
def init_model(lm_config):
|
||||
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
|
||||
model = MiniMindLM(lm_config)
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'./out/rlhf_{lm_config.dim}{moe_path}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
model = model.to(args.device)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def init_distributed_mode():
|
||||
if not ddp: return
|
||||
global ddp_local_rank, DEVICE
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
ddp_rank = int(os.environ["RANK"])
|
||||
ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
||||
ddp_world_size = int(os.environ["WORLD_SIZE"])
|
||||
DEVICE = f"cuda:{ddp_local_rank}"
|
||||
torch.cuda.set_device(DEVICE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind Distill Reasoning")
|
||||
parser.add_argument("--out_dir", type=str, default="out")
|
||||
parser.add_argument("--epochs", type=int, default=1)
|
||||
parser.add_argument("--batch_size", type=int, default=8)
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-6)
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
parser.add_argument("--ddp", action="store_true")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0)
|
||||
parser.add_argument("--warmup_iters", type=int, default=0)
|
||||
parser.add_argument("--log_interval", type=int, default=1)
|
||||
parser.add_argument("--save_interval", type=int, default=50)
|
||||
parser.add_argument('--local_rank', type=int, default=-1)
|
||||
parser.add_argument('--dim', default=512, type=int)
|
||||
parser.add_argument('--n_layers', default=8, type=int)
|
||||
parser.add_argument('--max_seq_len', default=1024, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument("--data_path", type=str, default="./dataset/r1_mix_1024.jsonl")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
lm_config = LMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe)
|
||||
args.save_dir = os.path.join(args.out_dir)
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
tokens_per_iter = args.batch_size * lm_config.max_seq_len
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
|
||||
args.wandb_run_name = f"MiniMind-Distill-Reasoning-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
|
||||
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
||||
ddp_local_rank, DEVICE = 0, "cuda:0"
|
||||
base_seed = 1337
|
||||
torch.manual_seed(base_seed)
|
||||
torch.cuda.manual_seed(base_seed)
|
||||
|
||||
if ddp:
|
||||
init_distributed_mode()
|
||||
args.device = torch.device(DEVICE)
|
||||
rank = dist.get_rank()
|
||||
torch.manual_seed(base_seed + rank)
|
||||
# Also seed the CUDA RNG
|
||||
torch.cuda.manual_seed(base_seed + rank)
|
||||
|
||||
if args.use_wandb and (not ddp or ddp_local_rank == 0):
|
||||
import wandb
|
||||
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
|
||||
else:
|
||||
wandb = None
|
||||
|
||||
model, tokenizer = init_model(lm_config)
|
||||
|
||||
train_ds = SFTDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if ddp else None
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
if ddp:
|
||||
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
|
||||
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
|
||||
|
||||
iter_per_epoch = len(train_loader)
|
||||
for epoch in range(args.epochs):
|
||||
train_epoch(epoch, wandb)
|
263
train_distillation.py
Normal file
@ -0,0 +1,263 @@
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
|
||||
from torch import optim, nn
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from model.model import MiniMindLM
|
||||
from model.LMConfig import LMConfig
|
||||
from model.dataset import SFTDataset
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def Logger(content):
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
print(content)
|
||||
|
||||
|
||||
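# Cosine schedule with a floor (same formula noted in the LaTeX comments of the other
# training scripts in this commit): get_lr(c, t, l) = l/10 + 0.5 * l * (1 + cos(pi * c / t))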
def get_lr(current_step, total_steps, lr):
|
||||
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
|
||||
|
||||
|
||||
def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
|
||||
with torch.no_grad():
|
||||
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).detach()
|
||||
|
||||
student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
|
||||
|
||||
kl = F.kl_div(
|
||||
student_log_probs,
|
||||
teacher_probs,
|
||||
reduction=reduction
|
||||
)
|
||||
return (temperature ** 2) * kl
|
||||
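# Minimal sketch (added for illustration, not part of the original script): how
# distillation_loss_fn behaves on toy logits. The shapes are assumptions; the real call
# below flattens (batch, seq, vocab) to (N, vocab) and keeps only unmasked positions.
#
#   student = torch.randn(4, 6400)   # (N, vocab) student logits
#   teacher = torch.randn(4, 6400)   # (N, vocab) teacher logits
#   kd = distillation_loss_fn(student, teacher, temperature=2.0)
#   # equals T^2 * KL( softmax(teacher/T) || softmax(student/T) ), batch-mean reduced
#   print(kd.item())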
|
||||
|
||||
def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
|
||||
start_time = time.time()
|
||||
|
||||
if teacher_model is not None:
|
||||
teacher_model.eval()
|
||||
teacher_model.requires_grad_(False)
|
||||
|
||||
for step, (X, Y, loss_mask) in enumerate(train_loader):
|
||||
X = X.to(args.device)
|
||||
Y = Y.to(args.device)
|
||||
loss_mask = loss_mask.to(args.device)
|
||||
lr = get_lr(epoch * iter_per_epoch + step,
|
||||
args.epochs * iter_per_epoch,
|
||||
args.learning_rate)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
# Forward pass (student model)
|
||||
with ctx:
|
||||
res = model(X)
|
||||
student_logits = res.logits
|
||||
|
||||
# Teacher forward pass (eval mode, under no_grad only)
|
||||
if teacher_model is not None:
|
||||
with torch.no_grad():
|
||||
teacher_logits = teacher_model(X).logits
|
||||
vocab_size_student = student_logits.size(-1) # N
|
||||
teacher_logits = teacher_logits[..., :vocab_size_student]
|
||||
|
||||
# ========== Loss computation ==========
|
||||
# 1) Ground-truth CE loss (optional)
|
||||
loss_mask_flat = loss_mask.view(-1)
|
||||
ce_loss = F.cross_entropy(
|
||||
student_logits.view(-1, student_logits.size(-1)),
|
||||
Y.view(-1),
|
||||
ignore_index=0,
|
||||
reduction='none'
|
||||
)
|
||||
ce_loss = torch.sum(ce_loss * loss_mask_flat) / loss_mask_flat.sum()
|
||||
if lm_config_student.use_moe:
|
||||
ce_loss += res.aux_loss
|
||||
|
||||
# 2) Distillation loss (optional)
|
||||
if teacher_model is not None:
|
||||
# Distill only at valid (unmasked) token positions
|
||||
distill_loss = distillation_loss_fn(
|
||||
student_logits.view(-1, student_logits.size(-1))[loss_mask_flat == 1],
|
||||
teacher_logits.view(-1, teacher_logits.size(-1))[loss_mask_flat == 1],
|
||||
temperature=temperature
|
||||
)
|
||||
else:
|
||||
distill_loss = torch.tensor(0.0, device=args.device)
|
||||
|
||||
# 3) Total loss = alpha * CE + (1 - alpha) * distill (alpha defaults to 0.0, i.e. pure distillation)
|
||||
loss = alpha * ce_loss + (1 - alpha) * distill_loss
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if step % args.log_interval == 0:
|
||||
spend_time = time.time() - start_time
|
||||
Logger(
|
||||
'Epoch:[{}/{}]({}/{}) loss:{:.4f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch,
|
||||
args.epochs - 1,
|
||||
step,
|
||||
iter_per_epoch,
|
||||
loss.item(),
|
||||
optimizer.param_groups[-1]['lr'],
|
||||
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
|
||||
)
|
||||
)
|
||||
|
||||
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
|
||||
wandb.log({
|
||||
"loss": loss.item(),
|
||||
"ce_loss": ce_loss.item(),
|
||||
"distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
|
||||
"lr": optimizer.param_groups[-1]['lr'],
|
||||
"last-time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
|
||||
})
|
||||
|
||||
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
|
||||
model.eval()
|
||||
moe_path = '_moe' if lm_config_student.use_moe else ''
|
||||
ckp = f'{args.save_dir}/full_dist_{lm_config_student.dim}{moe_path}.pth'
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
torch.save(state_dict, ckp)
|
||||
model.train()
|
||||
|
||||
|
||||
def init_student_model(lm_config):
|
||||
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
|
||||
model = MiniMindLM(lm_config)
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'./out/full_sft_{lm_config.dim}{moe_path}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
Logger(f'Student model (LLM) total trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
|
||||
model = model.to(args.device)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def init_teacher_model(lm_config):
|
||||
model = MiniMindLM(lm_config)
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'./out/full_sft_{lm_config.dim}{moe_path}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
Logger(f'Teacher model (LLM) total trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
|
||||
model = model.to(args.device)
|
||||
return model
|
||||
|
||||
|
||||
def init_distributed_mode():
|
||||
if not ddp: return
|
||||
global ddp_local_rank, DEVICE
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
ddp_rank = int(os.environ["RANK"])
|
||||
ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
||||
ddp_world_size = int(os.environ["WORLD_SIZE"])
|
||||
DEVICE = f"cuda:{ddp_local_rank}"
|
||||
torch.cuda.set_device(DEVICE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind Full SFT")
|
||||
parser.add_argument("--out_dir", type=str, default="out")
|
||||
parser.add_argument("--epochs", type=int, default=6)
|
||||
parser.add_argument("--batch_size", type=int, default=32)
|
||||
parser.add_argument("--learning_rate", type=float, default=5e-6)
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
parser.add_argument("--ddp", action="store_true")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0)
|
||||
parser.add_argument("--warmup_iters", type=int, default=0)
|
||||
parser.add_argument("--log_interval", type=int, default=100)
|
||||
parser.add_argument("--save_interval", type=int, default=100)
|
||||
parser.add_argument('--local_rank', type=int, default=-1)
|
||||
parser.add_argument("--data_path", type=str, default="./dataset/sft_data.jsonl")
|
||||
|
||||
args = parser.parse_args()
|
||||
# Define the student and teacher model configurations
|
||||
lm_config_student = LMConfig(dim=512, n_layers=8, max_seq_len=512)
|
||||
lm_config_teacher = LMConfig(dim=768, n_layers=16, max_seq_len=512)
|
||||
max_seq_len = lm_config_student.max_seq_len
|
||||
args.save_dir = os.path.join(args.out_dir)
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
tokens_per_iter = args.batch_size * max_seq_len
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
|
||||
args.wandb_run_name = f"MiniMind-Dist-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
|
||||
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
||||
ddp_local_rank, DEVICE = 0, "cuda:0"
|
||||
base_seed = 1337
|
||||
torch.manual_seed(base_seed)
|
||||
torch.cuda.manual_seed(base_seed)
|
||||
|
||||
if ddp:
|
||||
init_distributed_mode()
|
||||
args.device = torch.device(DEVICE)
|
||||
rank = dist.get_rank()
|
||||
torch.manual_seed(base_seed + rank)
|
||||
# Also seed the CUDA RNG
|
||||
torch.cuda.manual_seed(base_seed + rank)
|
||||
|
||||
if args.use_wandb and (not ddp or ddp_local_rank == 0):
|
||||
import wandb
|
||||
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
|
||||
else:
|
||||
wandb = None
|
||||
|
||||
# Initialize the student and teacher models
|
||||
model, tokenizer = init_student_model(lm_config_student)
|
||||
teacher_model = init_teacher_model(lm_config_teacher)
|
||||
|
||||
train_ds = SFTDataset(args.data_path, tokenizer, max_length=max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if ddp else None
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
if ddp:
|
||||
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
|
||||
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
|
||||
|
||||
iter_per_epoch = len(train_loader)
|
||||
for epoch in range(args.epochs):
|
||||
train_epoch(epoch, wandb)
|
247
train_dpo.py
Normal file
@ -0,0 +1,247 @@
|
||||
import os
|
||||
import platform
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
|
||||
from torch import optim, nn
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from model.model import MiniMindLM
|
||||
from model.LMConfig import LMConfig
|
||||
from model.dataset import DPODataset
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def Logger(content):
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
print(content)
|
||||
|
||||
|
||||
def get_lr(current_step, total_steps, lr):
|
||||
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
|
||||
|
||||
|
||||
def logits_to_probs(logits, labels):
|
||||
# logits shape: (batch_size, seq_len, vocab_size)
|
||||
# labels shape: (batch_size, seq_len)
|
||||
# probs shape: (batch_size, seq_len)
|
||||
log_probs = F.log_softmax(logits, dim=2)
|
||||
probs = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1)
|
||||
return probs
|
||||
|
||||
|
||||
def dpo_loss(ref_probs, probs, mask, beta):
|
||||
# ref_probs and probs both have shape (batch_size, seq_len)
|
||||
# https://github.com/jingyaogong/minimind/issues/298
|
||||
seq_lengths = mask.sum(dim=1, keepdim=True) # (batch_size, 1)
|
||||
ref_probs = (ref_probs * mask).sum(dim=1) / seq_lengths.squeeze()
|
||||
probs = (probs * mask).sum(dim=1) / seq_lengths.squeeze()
|
||||
|
||||
# Split the batch into chosen and rejected halves
|
||||
batch_size = ref_probs.shape[0]
|
||||
chosen_ref_probs = ref_probs[:batch_size // 2]
|
||||
reject_ref_probs = ref_probs[batch_size // 2:]
|
||||
chosen_probs = probs[:batch_size // 2]
|
||||
reject_probs = probs[batch_size // 2:]
|
||||
|
||||
pi_logratios = chosen_probs - reject_probs
|
||||
ref_logratios = chosen_ref_probs - reject_ref_probs
|
||||
logits = pi_logratios - ref_logratios
|
||||
loss = -F.logsigmoid(beta * logits)
|
||||
return loss.mean()
|
||||
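# Illustrative sketch (added, not part of the original script): the DPO loss above on a
# toy batch. The first half of the batch is "chosen", the second half "rejected"; the
# per-token log-probs are mean-pooled over valid positions before the log-ratio is formed.
#
#   batch, seq = 4, 8                                    # 2 chosen + 2 rejected pairs
#   logits = torch.randn(batch, seq, 6400)               # policy logits (vocab size assumed)
#   ref_logits = torch.randn(batch, seq, 6400)            # frozen reference logits
#   labels = torch.randint(0, 6400, (batch, seq))
#   mask = torch.ones(batch, seq)
#   probs = logits_to_probs(logits, labels) * mask         # (batch, seq) token log-probs
#   ref_probs = logits_to_probs(ref_logits, labels) * mask
#   loss = dpo_loss(ref_probs, probs, mask, beta=0.1)      # -log sigmoid(beta * (pi_logratio - ref_logratio))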
|
||||
|
||||
def train_epoch(epoch, wandb):
|
||||
start_time = time.time()
|
||||
for step, batch in enumerate(train_loader):
|
||||
x_chosen = batch['x_chosen'].to(args.device)
|
||||
x_rejected = batch['x_rejected'].to(args.device)
|
||||
y_chosen = batch['y_chosen'].to(args.device)
|
||||
y_rejected = batch['y_rejected'].to(args.device)
|
||||
mask_chosen = batch['mask_chosen'].to(args.device)
|
||||
mask_rejected = batch['mask_rejected'].to(args.device)
|
||||
x = torch.cat([x_chosen, x_rejected], dim=0)
|
||||
y = torch.cat([y_chosen, y_rejected], dim=0)
|
||||
mask = torch.cat([mask_chosen, mask_rejected], dim=0)
|
||||
|
||||
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
with ctx:
|
||||
with torch.no_grad():
|
||||
ref_outputs = ref_model(x)
|
||||
ref_logits = ref_outputs.logits
|
||||
ref_probs = logits_to_probs(ref_logits, y)
|
||||
ref_probs = ref_probs * mask
|
||||
outputs = model(x)
|
||||
logits = outputs.logits
|
||||
probs = logits_to_probs(logits, y)
|
||||
probs = probs * mask
|
||||
loss = dpo_loss(ref_probs, probs, mask, beta=0.1)
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if step % args.log_interval == 0:
|
||||
spend_time = time.time() - start_time
|
||||
Logger(
|
||||
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch + 1,
|
||||
args.epochs,
|
||||
step,
|
||||
iter_per_epoch,
|
||||
loss.item(),
|
||||
optimizer.param_groups[-1]['lr'],
|
||||
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
|
||||
|
||||
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
|
||||
wandb.log({"loss": loss,
|
||||
"lr": optimizer.param_groups[-1]['lr'],
|
||||
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
|
||||
|
||||
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
|
||||
model.eval()
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/rlhf_{lm_config.dim}{moe_path}.pth'
|
||||
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
|
||||
torch.save(state_dict, ckp)
|
||||
model.train()
|
||||
|
||||
|
||||
def init_model(lm_config):
|
||||
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
|
||||
model = MiniMindLM(lm_config)
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'./out/full_sft_{lm_config.dim}{moe_path}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
# Initialize the frozen reference model
|
||||
ref_model = MiniMindLM(lm_config)
|
||||
ref_model.load_state_dict(state_dict, strict=False)
|
||||
ref_model.eval()
|
||||
ref_model.requires_grad_(False)
|
||||
|
||||
Logger(f'Total trainable LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
|
||||
model = model.to(args.device)
|
||||
ref_model = ref_model.to(args.device)
|
||||
|
||||
return model, ref_model, tokenizer
|
||||
|
||||
|
||||
def init_distributed_mode():
|
||||
if not ddp: return
|
||||
global ddp_local_rank, DEVICE
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
ddp_rank = int(os.environ["RANK"])
|
||||
ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
||||
ddp_world_size = int(os.environ["WORLD_SIZE"])
|
||||
DEVICE = f"cuda:{ddp_local_rank}"
|
||||
torch.cuda.set_device(DEVICE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind RLHF")
|
||||
parser.add_argument("--out_dir", type=str, default="out")
|
||||
parser.add_argument("--epochs", type=int, default=2)
|
||||
parser.add_argument("--batch_size", type=int, default=8)
|
||||
# During SFT the learning rate decays from 5e-6 to 5e-7 at sequence length 512; for offline chosen/rejected preference alignment (length ~3000) keep lr <= 1e-8, otherwise the model easily forgets and degrades
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-8)
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-RLHF-SFT")
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
parser.add_argument("--ddp", action="store_true")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0)
|
||||
parser.add_argument("--warmup_iters", type=int, default=0)
|
||||
parser.add_argument("--log_interval", type=int, default=100)
|
||||
parser.add_argument("--save_interval", type=int, default=100)
|
||||
parser.add_argument('--local_rank', type=int, default=-1)
|
||||
parser.add_argument('--dim', default=512, type=int)
|
||||
parser.add_argument('--n_layers', default=8, type=int)
|
||||
parser.add_argument('--max_seq_len', default=1024, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument("--data_path", type=str, default="./dataset/dpo.jsonl")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
lm_config = LMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe)
|
||||
args.save_dir = os.path.join(args.out_dir)
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
tokens_per_iter = args.batch_size * lm_config.max_seq_len
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
|
||||
args.wandb_run_name = f"MiniMind-Full-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
|
||||
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
||||
ddp_local_rank, DEVICE = 0, "cuda:0"
|
||||
base_seed = 1337
|
||||
torch.manual_seed(base_seed)
|
||||
torch.cuda.manual_seed(base_seed)
|
||||
|
||||
if ddp:
|
||||
init_distributed_mode()
|
||||
args.device = torch.device(DEVICE)
|
||||
rank = dist.get_rank()
|
||||
torch.manual_seed(base_seed + rank)
|
||||
# Also seed the CUDA RNG
|
||||
torch.cuda.manual_seed(base_seed + rank)
|
||||
|
||||
if args.use_wandb and (not ddp or ddp_local_rank == 0):
|
||||
import wandb
|
||||
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
|
||||
else:
|
||||
wandb = None
|
||||
|
||||
model, ref_model, tokenizer = init_model(lm_config)
|
||||
|
||||
train_ds = DPODataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if ddp else None
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
if ddp:
|
||||
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
|
||||
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
|
||||
|
||||
iter_per_epoch = len(train_loader)
|
||||
for epoch in range(args.epochs):
|
||||
train_epoch(epoch, wandb)
|
418
train_embedding.py
Normal file
@ -0,0 +1,418 @@
|
||||
import os
|
||||
# Set environment variables
|
||||
os.environ["WANDB_MODE"] = "offline"  # or use "dryrun"
|
||||
import platform
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
import warnings
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import optim, nn
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from torch.utils.data import DataLoader, DistributedSampler, Dataset
|
||||
from contextlib import nullcontext
|
||||
import random
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Removed: from model.model import MiniMindLM
|
||||
from model.LMConfig import LMConfig
|
||||
# from model.dataset import PretrainDataset
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
# Define a Word2Vec-style CBOW model
|
||||
class CBOWModel(nn.Module):
|
||||
def __init__(self, config: LMConfig):
|
||||
super().__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embedding_dim = config.dim
|
||||
|
||||
# Input embeddings (context words)
|
||||
self.embeddings = nn.Embedding(config.vocab_size, config.dim)
|
||||
|
||||
# Output weights for target prediction
|
||||
self.output_weights = nn.Linear(config.dim, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
# Xavier initialization for better convergence
|
||||
nn.init.xavier_uniform_(self.embeddings.weight)
|
||||
nn.init.xavier_uniform_(self.output_weights.weight)
|
||||
|
||||
def forward(self, context_words):
|
||||
# context_words shape: [batch_size, context_size]; context_size may vary per batch
|
||||
|
||||
# Get embeddings for all context words
|
||||
embeds = self.embeddings(context_words) # [batch_size, context_size, embedding_dim]
|
||||
|
||||
# Average the context word embeddings along context dimension
|
||||
embeds = torch.mean(embeds, dim=1) # [batch_size, embedding_dim]
|
||||
|
||||
# Predict the target word
|
||||
output = self.output_weights(embeds) # [batch_size, vocab_size]
|
||||
|
||||
return output
|
||||
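# Illustrative sketch (added, not in the original file): a CBOW forward pass on toy
# context ids. The config values here are assumptions for the example only.
#
#   cfg = LMConfig(dim=768, vocab_size=6400, max_seq_len=512, n_layers=1, n_heads=1, n_kv_heads=1)
#   cbow = CBOWModel(cfg)
#   context = torch.randint(0, cfg.vocab_size, (2, 10))    # [batch=2, context_size=10]
#   logits = cbow(context)                                  # [2, vocab_size]: scores for the center word
#   target = torch.randint(0, cfg.vocab_size, (2,))
#   loss = nn.CrossEntropyLoss()(logits, target)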
|
||||
|
||||
# Word2Vec CBOW dataset
|
||||
class CBOWDataset(Dataset):
|
||||
def __init__(self, data_path, tokenizer, max_length=512, window_size=5):
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
self.window_size = window_size
|
||||
self.max_length = max_length
|
||||
self.samples = self.load_data(data_path)
|
||||
|
||||
def load_data(self, path):
|
||||
samples = []
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
data = json.loads(line.strip())
|
||||
samples.append(data)
|
||||
return samples
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.samples[index]
|
||||
|
||||
# Build the input text
|
||||
text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
|
||||
encoding = self.tokenizer(
|
||||
text,
|
||||
max_length=self.max_length,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
return_tensors='pt'
|
||||
)
|
||||
|
||||
# Get the token ids
|
||||
input_ids = encoding.input_ids.squeeze()
|
||||
# Filter out padding positions
|
||||
attention_mask = encoding.attention_mask.squeeze()
|
||||
valid_indices = torch.where(attention_mask == 1)[0]
|
||||
valid_input_ids = input_ids[valid_indices]
|
||||
|
||||
# Make sure there are enough tokens for CBOW training
|
||||
if len(valid_input_ids) <= 2 * self.window_size + 1:
|
||||
# Not enough tokens: retry with a randomly chosen different sample
|
||||
return self.__getitem__(random.randint(0, len(self.samples) - 1))
|
||||
|
||||
# Randomly pick a center position (excluding the special tokens at both ends)
|
||||
# Make sure there are at least window_size tokens on each side of the center
|
||||
min_center_pos = self.window_size + 1  # keep clear of the BOS token
|
||||
max_center_pos = len(valid_input_ids) - self.window_size - 1  # keep clear of the EOS token
|
||||
|
||||
if max_center_pos <= min_center_pos:
|
||||
return self.__getitem__(random.randint(0, len(self.samples) - 1))
|
||||
|
||||
center_pos = random.randint(min_center_pos, max_center_pos)
|
||||
|
||||
# Target (center) word
|
||||
target = valid_input_ids[center_pos].unsqueeze(0)
|
||||
|
||||
# Context words (the tokens before and after the center word)
|
||||
context = torch.cat([
|
||||
valid_input_ids[center_pos - self.window_size:center_pos],
|
||||
valid_input_ids[center_pos + 1:center_pos + self.window_size + 1]
|
||||
])
|
||||
|
||||
return context, target
|
||||
|
||||
|
||||
def Logger(content):
|
||||
# Print only when DDP is not used, or on the main DDP process
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
print(content)
|
||||
|
||||
|
||||
def get_lr(current_step, total_steps, lr):
|
||||
# Cosine learning-rate schedule
|
||||
# \text{get\_lr}(c, t, l) = \frac{l}{10} + 0.5 \cdot l \cdot \left(1 + \cos\left(\frac{\pi \cdot c}{t}\right)\right)
|
||||
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
|
||||
|
||||
|
||||
def train_epoch(epoch, wandb):
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
start_time = time.time()
|
||||
total_loss = 0
|
||||
total_samples = 0
|
||||
|
||||
for step, (context, target) in enumerate(train_loader):
|
||||
try:
|
||||
# Move the batch to the target device
|
||||
context = context.to(args.device)
|
||||
target = target.to(args.device)
|
||||
|
||||
# Update the learning rate
|
||||
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
with ctx:
|
||||
# Forward pass
|
||||
logits = model(context) # [batch_size, vocab_size]
|
||||
# target is [batch_size, 1]; squeeze to [batch_size] to match what CrossEntropyLoss expects
|
||||
loss = loss_fct(logits, target.squeeze())
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
# Print data types for debugging
|
||||
if step == 0 and (not ddp or dist.get_rank() == 0):
|
||||
Logger("---- Data Type Check ----")
|
||||
Logger(f"context.dtype: {context.dtype}")
|
||||
Logger(f"context.shape: {context.shape}")
|
||||
Logger(f"target.dtype: {target.dtype}")
|
||||
Logger(f"target.shape: {target.shape}")
|
||||
if hasattr(model, 'module'): # DDP case
|
||||
Logger(f"Model parameter dtype: {next(model.module.parameters()).dtype}")
|
||||
else: # Non-DDP case
|
||||
Logger(f"Model parameter dtype: {next(model.parameters()).dtype}")
|
||||
Logger(f"logits.dtype: {logits.dtype}")
|
||||
Logger(f"logits.shape: {logits.shape}")
|
||||
Logger(f"loss.dtype: {loss.dtype}")
|
||||
Logger("-------------------------")
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
total_loss += loss.item() * args.accumulation_steps
|
||||
total_samples += 1
|
||||
|
||||
# Logging
|
||||
if step % args.log_interval == 0:
|
||||
spend_time = time.time() - start_time
|
||||
avg_loss = total_loss / total_samples if total_samples > 0 else 0
|
||||
Logger(
|
||||
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch + 1,
|
||||
args.epochs,
|
||||
step,
|
||||
iter_per_epoch,
|
||||
avg_loss,
|
||||
optimizer.param_groups[-1]['lr'],
|
||||
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
|
||||
|
||||
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
|
||||
wandb.log({"loss": avg_loss,
|
||||
"lr": optimizer.param_groups[-1]['lr'],
|
||||
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error occurred: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
# Modified checkpoint path for error
|
||||
save_path = f'{args.save_dir}/word2vec_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_ERROR.pth'
|
||||
if os.path.exists(save_path):
|
||||
os.remove(save_path)
|
||||
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
state_dict = model.module.embeddings.state_dict()
|
||||
else:
|
||||
state_dict = model.embeddings.state_dict()
|
||||
torch.save(state_dict, save_path)
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if param.grad is not None and torch.isnan(param.grad).any():
|
||||
print(f"NaN gradient in parameter: {name}")
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if param.grad is not None and torch.isnan(param.grad).any():
|
||||
print(f"Parameter {name} values: {param.data}")
|
||||
print(f"Parameter {name} gradients: {param.grad}")
|
||||
|
||||
raise ValueError("NaN gradient detected")
|
||||
|
||||
# Save model once at the end of each epoch
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
model.eval()
|
||||
ckp = f'{args.save_dir}/word2vec_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_epoch{epoch+1}.pth'
|
||||
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
embedding_state_dict = model.module.embeddings.state_dict()
|
||||
else:
|
||||
embedding_state_dict = model.embeddings.state_dict()
|
||||
|
||||
torch.save(embedding_state_dict, ckp)
|
||||
Logger(f"Saved word2vec embedding for epoch {epoch+1} to {ckp}")
|
||||
model.train()
|
||||
|
||||
|
||||
def init_model(lm_config_params: LMConfig):
|
||||
# Load the tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
|
||||
# Update vocab_size in lm_config if tokenizer has a different one
|
||||
if tokenizer.vocab_size != lm_config_params.vocab_size:
|
||||
Logger(f"Updating lm_config.vocab_size from {lm_config_params.vocab_size} to {tokenizer.vocab_size} based on tokenizer.")
|
||||
lm_config_params.vocab_size = tokenizer.vocab_size
|
||||
|
||||
# Build the Word2Vec CBOW model
|
||||
model = CBOWModel(lm_config_params).to(args.device)
|
||||
# Report the parameter count
|
||||
Logger(f'CBOW Model total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} Million')
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def init_distributed_mode():
|
||||
if not ddp: return  # If distributed data parallel (DDP) is not enabled, return without doing anything.
|
||||
global ddp_local_rank, DEVICE  # Declare these as globals so they are also visible outside this function.
|
||||
|
||||
dist.init_process_group(backend="nccl")  # Initialize the process group with the NCCL backend (NVIDIA Collective Communications Library), optimized for GPU-to-GPU communication.
|
||||
ddp_rank = int(os.environ["RANK"])  # Global rank of this process, from the environment.
|
||||
ddp_local_rank = int(os.environ["LOCAL_RANK"])  # Local rank of this process on the current node.
|
||||
ddp_world_size = int(os.environ["WORLD_SIZE"])  # Total number of processes in the group.
|
||||
DEVICE = f"cuda:{ddp_local_rank}"  # Pick the GPU device based on the local rank.
|
||||
torch.cuda.set_device(DEVICE)  # Bind this process to its GPU.
|
||||
|
||||
|
||||
# torchrun --nproc_per_node 2 train_embedding.py
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind Word2Vec Embedding Training")
|
||||
parser.add_argument("--out_dir", type=str, default="out_word2vec")
|
||||
parser.add_argument("--epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=256)
|
||||
parser.add_argument("--learning_rate", type=float, default=5e-4)
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Word2Vec-Training")
|
||||
parser.add_argument("--num_workers", type=int, default=32)
|
||||
parser.add_argument("--ddp", action="store_true")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0)
|
||||
parser.add_argument("--log_interval", type=int, default=100)
|
||||
parser.add_argument("--save_interval", type=int, default=100)
|
||||
parser.add_argument('--local_rank', type=int, default=-1)
|
||||
parser.add_argument('--dim', default=768, type=int)
|
||||
parser.add_argument('--max_seq_len', default=512, type=int)
|
||||
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
|
||||
parser.add_argument('--vocab_size', default=6400, type=int)
|
||||
parser.add_argument('--window_size', default=5, type=int)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create LMConfig with relevant parameters for embedding
|
||||
lm_config = LMConfig(
|
||||
dim=args.dim,
|
||||
vocab_size=args.vocab_size, # Will be updated by tokenizer
|
||||
max_seq_len=args.max_seq_len,
|
||||
n_layers=1, # Minimal
|
||||
n_heads=1, # Minimal
|
||||
n_kv_heads=1 #Minimal
|
||||
)
|
||||
args.save_dir = os.path.join(args.out_dir)
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
tokens_per_iter = args.batch_size * lm_config.max_seq_len
|
||||
print(f"tokens_per_iter: {tokens_per_iter}")
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
|
||||
# Determine the torch dtype
|
||||
pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
|
||||
|
||||
args.wandb_run_name = f"MiniMind-Word2Vec-Dim-{args.dim}-Vocab-{lm_config.vocab_size}-Window-{args.window_size}"
|
||||
|
||||
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)
|
||||
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
||||
ddp_local_rank, DEVICE = 0, "cuda:0" # Default values, will be overwritten in DDP
|
||||
|
||||
base_seed = 1337
|
||||
torch.manual_seed(base_seed)
|
||||
torch.cuda.manual_seed(base_seed)
|
||||
|
||||
if ddp:
|
||||
init_distributed_mode() # This sets DEVICE and ddp_local_rank
|
||||
args.device = torch.device(DEVICE) # Ensure args.device is updated
|
||||
rank = dist.get_rank()
|
||||
torch.manual_seed(base_seed + rank)
|
||||
# Also seed the CUDA RNG
|
||||
torch.cuda.manual_seed_all(base_seed + rank) # Use seed_all for DDP
|
||||
|
||||
if args.use_wandb and (not ddp or dist.get_rank() == 0): # Check rank for DDP wandb init
|
||||
import wandb
|
||||
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args)
|
||||
else:
|
||||
wandb = None
|
||||
|
||||
model, tokenizer = init_model(lm_config) # Pass the lm_config instance
|
||||
|
||||
# Update lm_config vocab_size again after tokenizer to ensure consistency for save path name
|
||||
if lm_config.vocab_size != tokenizer.vocab_size:
|
||||
lm_config.vocab_size = tokenizer.vocab_size
|
||||
args.wandb_run_name = f"MiniMind-Word2Vec-Dim-{args.dim}-Vocab-{lm_config.vocab_size}-Window-{args.window_size}"
|
||||
if wandb is not None and (not ddp or dist.get_rank() == 0):
|
||||
wandb.config.update({'vocab_size': lm_config.vocab_size, 'wandb_run_name': args.wandb_run_name}, allow_val_change=True)
|
||||
|
||||
# Collate function that handles variable-length context sequences
|
||||
def collate_cbow_batch(batch):
|
||||
# Unpack contexts and targets
|
||||
contexts, targets = zip(*batch)
|
||||
|
||||
# Longest context length in the current batch
|
||||
max_len = max([ctx.size(0) for ctx in contexts])
|
||||
|
||||
# Allocate the padded tensor
|
||||
padded_contexts = torch.zeros(len(contexts), max_len, dtype=torch.long)
|
||||
|
||||
# Copy each context into the padded tensor
|
||||
for i, ctx in enumerate(contexts):
|
||||
ctx_len = ctx.size(0)
|
||||
padded_contexts[i, :ctx_len] = ctx
|
||||
|
||||
# Stack the targets into a single tensor
|
||||
stacked_targets = torch.stack(targets)
|
||||
|
||||
return padded_contexts, stacked_targets
|
||||
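# Illustrative example (added, not in the original file): what collate_cbow_batch does
# with two contexts of different lengths. Zero acts as the pad id, matching the
# zeros-initialised padded tensor above.
#
#   batch = [(torch.tensor([5, 6, 7]), torch.tensor([9])),
#            (torch.tensor([1, 2]),    torch.tensor([3]))]
#   ctx, tgt = collate_cbow_batch(batch)
#   # ctx -> tensor([[5, 6, 7],
#   #                [1, 2, 0]])   shape [2, 3], shorter context right-padded with 0
#   # tgt -> tensor([[9], [3]])    shape [2, 1]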
|
||||
# Create Word2Vec CBOW dataset
|
||||
train_ds = CBOWDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len, window_size=args.window_size)
|
||||
train_sampler = DistributedSampler(train_ds, shuffle=True, seed=base_seed) if ddp else None
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=True,
|
||||
shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler,
|
||||
collate_fn=collate_cbow_batch
|
||||
)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
if ddp:
|
||||
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
|
||||
|
||||
iter_per_epoch = len(train_loader)
|
||||
Logger(f"Starting Word2Vec CBOW training for {args.epochs} epochs with {iter_per_epoch} iterations per epoch.")
|
||||
for epoch in range(args.epochs):
|
||||
if ddp:
|
||||
train_sampler.set_epoch(epoch)
|
||||
train_epoch(epoch, wandb)
|
||||
|
||||
if wandb is not None and (not ddp or dist.get_rank() == 0):
|
||||
wandb.finish()
|
||||
|
||||
Logger("Word2Vec embedding training finished.")
|
214
train_full_sft.py
Normal file
@ -0,0 +1,214 @@
|
||||
import os
|
||||
# Set environment variables
|
||||
os.environ["WANDB_MODE"] = "offline"  # or use "dryrun"
|
||||
import platform
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
|
||||
from torch import optim, nn
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from model.model import MiniMindLM
|
||||
from model.LMConfig import LMConfig
|
||||
from model.dataset import SFTDataset
|
||||
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
# Logging helper for printing training information.
|
||||
def Logger(content):
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
print(content)
|
||||
|
||||
# Compute the learning rate for the current step.
|
||||
def get_lr(current_step, total_steps, lr):
|
||||
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
|
||||
|
||||
# Train the model for one epoch.
|
||||
def train_epoch(epoch, wandb):
|
||||
loss_fct = nn.CrossEntropyLoss(reduction='none')  # Per-token cross-entropy loss, reduced manually with the loss mask below.
|
||||
start_time = time.time()
|
||||
for step, (X, Y, loss_mask) in enumerate(train_loader):
|
||||
# Move the batch to the target device.
|
||||
X = X.to(args.device)
|
||||
Y = Y.to(args.device)
|
||||
loss_mask = loss_mask.to(args.device)
|
||||
# Compute the learning rate for the current step.
|
||||
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
|
||||
# Apply it to every parameter group.
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
with ctx:
|
||||
res = model(X)  # Forward pass
|
||||
loss = loss_fct(
|
||||
res.logits.view(-1, res.logits.size(-1)),
|
||||
Y.view(-1)
|
||||
).view(Y.size())  # Per-token loss, reshaped back to (batch, seq)
|
||||
|
||||
# Masked mean over the valid (non-padding) tokens
|
||||
loss = (loss * loss_mask).sum() / loss_mask.sum()
|
||||
loss += res.aux_loss
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
scaler.scale(loss).backward()  # Mixed-precision training: the loss is scaled automatically to avoid numerical instability when computing in low precision (e.g. FP16).
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)  # Part of PyTorch automatic mixed precision (AMP): "unscales" the gradients that were scaled to prevent underflow.
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)  # Gradient clipping to prevent exploding gradients; rescales so the norm does not exceed args.grad_clip.
|
||||
|
||||
scaler.step(optimizer)  # Update the model weights via the optimizer, mediated by the scaler for mixed-precision training.
|
||||
scaler.update()  # Adjust the scale factor for the next iteration based on whether this one overflowed.
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)  # Clear the gradients.
|
||||
|
||||
# Log at the configured interval.
|
||||
if step % args.log_interval == 0:
|
||||
spend_time = time.time() - start_time
|
||||
Logger(
|
||||
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch + 1,
|
||||
args.epochs,
|
||||
step,
|
||||
iter_per_epoch,
|
||||
loss.item(),
|
||||
optimizer.param_groups[-1]['lr'],
|
||||
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
|
||||
|
||||
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
|
||||
wandb.log({"loss": loss,
|
||||
"lr": optimizer.param_groups[-1]['lr'],
|
||||
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
|
||||
|
||||
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
|
||||
model.eval()
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/full_sft_{lm_config.dim}{moe_path}.pth'
|
||||
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
|
||||
torch.save(state_dict, ckp)
|
||||
model.train()
|
||||
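# Worked example (added for illustration, not in the original file): the masked
# cross-entropy used in the loop above, on toy tensors. The loss mask zeroes out
# padding / prompt positions so only response tokens contribute to the mean.
#
#   logits = torch.randn(2, 4, 6400)             # (batch, seq, vocab), vocab size assumed
#   labels = torch.randint(0, 6400, (2, 4))
#   mask = torch.tensor([[1, 1, 0, 0],
#                        [1, 1, 1, 0]], dtype=torch.float)
#   per_tok = nn.CrossEntropyLoss(reduction='none')(
#       logits.view(-1, logits.size(-1)), labels.view(-1)).view(labels.size())
#   loss = (per_tok * mask).sum() / mask.sum()    # mean over the 5 valid tokens only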
|
||||
# Initialize the model and tokenizer.
|
||||
def init_model(lm_config):
|
||||
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
|
||||
model = MiniMindLM(lm_config)
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
Logger(f'Total trainable LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
|
||||
model = model.to(args.device)
|
||||
return model, tokenizer
|
||||
|
||||
# Initialize distributed (DDP) mode.
|
||||
def init_distributed_mode():
|
||||
if not ddp: return
|
||||
global ddp_local_rank, DEVICE
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
ddp_rank = int(os.environ["RANK"])
|
||||
ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
||||
ddp_world_size = int(os.environ["WORLD_SIZE"])
|
||||
DEVICE = f"cuda:{ddp_local_rank}"
|
||||
torch.cuda.set_device(DEVICE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind Full SFT")
|
||||
parser.add_argument("--out_dir", type=str, default="out")
|
||||
parser.add_argument("--epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=32)
|
||||
parser.add_argument("--learning_rate", type=float, default=5e-5)
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", default=True, action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
parser.add_argument("--ddp", action="store_true")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0)
|
||||
parser.add_argument("--warmup_iters", type=int, default=0)
|
||||
parser.add_argument("--log_interval", type=int, default=100)
|
||||
parser.add_argument("--save_interval", type=int, default=100)
|
||||
parser.add_argument('--local_rank', type=int, default=-1)
|
||||
parser.add_argument('--dim', default=1024, type=int)  # Model hidden dimension (controls model size).
|
||||
parser.add_argument('--n_layers', default=24, type=int)  # Number of transformer layers.
|
||||
parser.add_argument('--max_seq_len', default=1024, type=int)  # Maximum input sequence length.
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument("--data_path", type=str, default="./dataset/sft_1024.jsonl")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
lm_config = LMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe)
|
||||
args.save_dir = os.path.join(args.out_dir)
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
tokens_per_iter = args.batch_size * lm_config.max_seq_len
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
|
||||
args.wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
|
||||
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
||||
ddp_local_rank, DEVICE = 0, "cuda:0"
|
||||
base_seed = 1337
|
||||
torch.manual_seed(base_seed)
|
||||
torch.cuda.manual_seed(base_seed)
|
||||
|
||||
# Initialize distributed mode when running under DDP.
|
||||
if ddp:
|
||||
init_distributed_mode()
|
||||
args.device = torch.device(DEVICE)
|
||||
rank = dist.get_rank()
|
||||
torch.manual_seed(base_seed + rank)
|
||||
# Also seed the CUDA RNG
|
||||
torch.cuda.manual_seed(base_seed + rank)
|
||||
|
||||
# Initialize WandB if enabled.
|
||||
if args.use_wandb and (not ddp or ddp_local_rank == 0):
|
||||
import wandb
|
||||
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
|
||||
else:
|
||||
wandb = None
|
||||
|
||||
# Initialize the model.
|
||||
model, tokenizer = init_model(lm_config)
|
||||
|
||||
# Initialize the dataset.
|
||||
train_ds = SFTDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if ddp else None
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))  # GradScaler for mixed-precision training: enabled when training in half precision (float16 or bfloat16), it prevents gradient underflow.
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)  # AdamW optimizer over all model parameters; a variant of Adam with decoupled weight decay.
|
||||
|
||||
if ddp:
|
||||
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
|
||||
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
|
||||
|
||||
iter_per_epoch = len(train_loader)
|
||||
for epoch in range(args.epochs):
|
||||
train_epoch(epoch, wandb)
|
201
train_lora.py
Normal file
@ -0,0 +1,201 @@
|
||||
import os
|
||||
import platform
|
||||
import argparse
|
||||
import random
|
||||
import time
|
||||
import math
|
||||
import warnings
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from model.model import MiniMindLM
|
||||
from model.LMConfig import LMConfig
|
||||
from model.dataset import SFTDataset
|
||||
from model.model_lora import *
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
# Logger function
|
||||
def Logger(content):
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
print(content)
|
||||
|
||||
|
||||
def get_lr(current_step, total_steps, lr):
|
||||
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
|
||||
|
||||
|
||||
# The code is almost identical to full_sft
|
||||
def train_epoch(epoch, wandb):
|
||||
loss_fct = nn.CrossEntropyLoss(reduction='none')
|
||||
start_time = time.time()
|
||||
for step, (X, Y, loss_mask) in enumerate(train_loader):
|
||||
X = X.to(args.device)
|
||||
Y = Y.to(args.device)
|
||||
loss_mask = loss_mask.to(args.device)
|
||||
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
with ctx:
|
||||
res = model(X)
|
||||
loss = loss_fct(
|
||||
res.logits.view(-1, res.logits.size(-1)),
|
||||
Y.view(-1)
|
||||
).view(Y.size())
|
||||
loss = (loss * loss_mask).sum() / loss_mask.sum()
|
||||
loss += res.aux_loss
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if step % args.log_interval == 0:
|
||||
spend_time = time.time() - start_time
|
||||
Logger(
|
||||
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch + 1,
|
||||
args.epochs,
|
||||
step,
|
||||
iter_per_epoch,
|
||||
loss.item(),
|
||||
optimizer.param_groups[-1]['lr'],
|
||||
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
|
||||
|
||||
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
|
||||
wandb.log({"loss": loss,
|
||||
"lr": optimizer.param_groups[-1]['lr'],
|
||||
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
|
||||
|
||||
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
|
||||
model.eval()
|
||||
# [Difference 1] Only the LoRA weights need to be saved
|
||||
save_lora(model, f'{args.save_dir}/lora/{args.lora_name}_{lm_config.dim}.pth')
|
||||
model.train()
|
||||
|
||||
|
||||
def init_model(lm_config):
|
||||
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
|
||||
model = MiniMindLM(lm_config)
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'./out/rlhf_{lm_config.dim}{moe_path}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
return model.to(args.device), tokenizer
|
||||
|
||||
|
||||
def init_distributed_mode():
|
||||
if not ddp: return
|
||||
global ddp_local_rank, DEVICE
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
ddp_rank = int(os.environ["RANK"])
|
||||
ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
||||
ddp_world_size = int(os.environ["WORLD_SIZE"])
|
||||
DEVICE = f"cuda:{ddp_local_rank}"
|
||||
torch.cuda.set_device(DEVICE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind SFT with LoRA")
|
||||
parser.add_argument("--out_dir", type=str, default="out")
|
||||
parser.add_argument("--epochs", type=int, default=50)
|
||||
parser.add_argument("--batch_size", type=int, default=16)
|
||||
parser.add_argument("--learning_rate", type=float, default=5e-5)
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA-SFT")
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
parser.add_argument("--ddp", action="store_true")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0)
|
||||
parser.add_argument("--warmup_iters", type=int, default=0)
|
||||
parser.add_argument("--log_interval", type=int, default=100)
|
||||
parser.add_argument("--save_interval", type=int, default=1)
|
||||
parser.add_argument('--local_rank', type=int, default=-1)
|
||||
parser.add_argument('--dim', default=512, type=int)
|
||||
parser.add_argument('--n_layers', default=8, type=int)
|
||||
parser.add_argument('--max_seq_len', default=512, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument("--data_path", type=str, default="./dataset/lora_identity.jsonl")
|
||||
parser.add_argument("--lora_name", type=str, default="lora_identity", help="根据任务保存成lora_(英文/医学/心理...)")
|
||||
args = parser.parse_args()
|
||||
|
||||
lm_config = LMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe)
|
||||
args.save_dir = os.path.join(args.out_dir)
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
tokens_per_iter = args.batch_size * lm_config.max_seq_len
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
|
||||
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
||||
ddp_local_rank, DEVICE = 0, "cuda:0"
|
||||
base_seed = 1337
|
||||
torch.manual_seed(base_seed)
|
||||
torch.cuda.manual_seed(base_seed)
|
||||
|
||||
if ddp:
|
||||
init_distributed_mode()
|
||||
args.device = torch.device(DEVICE)
|
||||
rank = dist.get_rank()
|
||||
torch.manual_seed(base_seed + rank)
|
||||
# Also seed the CUDA RNG
|
||||
torch.cuda.manual_seed(base_seed + rank)
|
||||
|
||||
args.wandb_run_name = f"MiniMind-Lora-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
if args.use_wandb and (not ddp or ddp_local_rank == 0):
|
||||
import wandb
|
||||
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
|
||||
else:
|
||||
wandb = None
|
||||
|
||||
model, tokenizer = init_model(lm_config)
|
||||
apply_lora(model)
|
||||
|
||||
total_params = sum(p.numel() for p in model.parameters())  # Total number of parameters
|
||||
lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name)  # Number of LoRA parameters
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
print(f"LLM 总参数量: {total_params}")
|
||||
print(f"LoRA 参数量: {lora_params_count}")
|
||||
print(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%")
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if 'lora' not in name:
|
||||
param.requires_grad = False
|
||||
lora_params = []
|
||||
for name, param in model.named_parameters():
|
||||
if 'lora' in name:
|
||||
lora_params.append(param)
|
||||
|
||||
# Optimize only the LoRA parameters
|
||||
optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
|
||||
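# Sketch (added for illustration): apply_lora / save_lora come from model.model_lora,
# which is not shown in this diff. A typical adapter that such a helper would attach to
# each nn.Linear looks roughly like this (an assumption, not the repo's actual code):
#
#   class LoRA(nn.Module):
#       def __init__(self, in_features, out_features, rank=16, alpha=16):
#           super().__init__()
#           self.A = nn.Linear(in_features, rank, bias=False)    # down-projection
#           self.B = nn.Linear(rank, out_features, bias=False)   # up-projection
#           self.scale = alpha / rank
#           nn.init.zeros_(self.B.weight)                        # start as a no-op
#       def forward(self, x):
#           return self.B(self.A(x)) * self.scale
#
# apply_lora(model) would then wrap each target Linear so its output becomes
# base(x) + lora(x), and only the adapter weights (names containing 'lora') are trained.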
train_ds = SFTDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if ddp else None
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
|
||||
iter_per_epoch = len(train_loader)
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
train_epoch(epoch, wandb)
|
441
train_pretrain.py
Normal file
@ -0,0 +1,441 @@
|
||||
import os
|
||||
# Set environment variables
|
||||
os.environ["WANDB_MODE"] = "offline"  # or use "dryrun"
|
||||
import platform
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
import warnings
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import optim, nn
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
# (Communication-profiling imports removed)
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from model.model import MiniMindLM
|
||||
from model.LMConfig import LMConfig
|
||||
from model.dataset import PretrainDataset
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def Logger(content):
|
||||
# Print only when DDP is not used, or on the main DDP process
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
print(content)
|
||||
|
||||
|
||||
def get_lr(current_step, total_steps, lr):
|
||||
# Cosine learning-rate schedule
|
||||
# \text{get\_lr}(c, t, l) = \frac{l}{10} + 0.5 \cdot l \cdot \left(1 + \cos\left(\frac{\pi \cdot c}{t}\right)\right)
|
||||
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
|
||||
|
||||
|
||||
def train_epoch(epoch, wandb):
|
||||
loss_fct = nn.CrossEntropyLoss(reduction='none')
|
||||
start_time = time.time()
|
||||
# Define moe_path at the top of the function so it is not undefined inside exception handlers
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
|
||||
# CUDA events for performance profiling
|
||||
if args.profile and (not ddp or dist.get_rank() == 0):
|
||||
data_start = torch.cuda.Event(enable_timing=True)
|
||||
data_end = torch.cuda.Event(enable_timing=True)
|
||||
forward_start = torch.cuda.Event(enable_timing=True)
|
||||
forward_end = torch.cuda.Event(enable_timing=True)
|
||||
backward_start = torch.cuda.Event(enable_timing=True)
|
||||
backward_end = torch.cuda.Event(enable_timing=True)
|
||||
optimizer_start = torch.cuda.Event(enable_timing=True)
|
||||
optimizer_end = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
# CUDA-graph optimization code removed
|
||||
|
||||
# Data prefetching
|
||||
prefetch_factor = 2  # number of batches to prefetch
|
||||
data_iter = iter(train_loader)
|
||||
prefetch_batches = []
|
||||
|
||||
# Prefetch the initial batches
|
||||
for _ in range(min(prefetch_factor, len(train_loader))):
|
||||
try:
|
||||
batch = next(data_iter)
|
||||
prefetch_batches.append([t.to(args.device, non_blocking=True) for t in batch])
|
||||
except StopIteration:
|
||||
break
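# Note on the prefetching above: with pin_memory=True in the DataLoader, the non_blocking=True
# host-to-device copies can overlap with GPU compute, so keeping a couple of batches staged
# ahead of time hides part of the data-loading latency from the training loop.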
|
||||
|
||||
for step in range(len(train_loader)):
|
||||
try:
|
||||
# Time data loading
|
||||
if args.profile and (not ddp or dist.get_rank() == 0):
|
||||
data_start.record()
|
||||
|
||||
# Use a prefetched batch
|
||||
if prefetch_batches:
|
||||
X, Y, loss_mask = prefetch_batches.pop(0)
|
||||
else:
|
||||
# If the prefetch queue is empty, load directly
|
||||
X, Y, loss_mask = [t.to(args.device) for t in next(data_iter)]
|
||||
|
||||
# Asynchronously prefetch the next batch
|
||||
if step + prefetch_factor < len(train_loader):
|
||||
try:
|
||||
batch = next(data_iter)
|
||||
prefetch_batches.append([t.to(args.device, non_blocking=True) for t in batch])
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
if args.profile and (not ddp or dist.get_rank() == 0):
|
||||
data_end.record()
|
||||
|
||||
# Update the learning rate
|
||||
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
# Time the forward pass
|
||||
if args.profile and (not ddp or dist.get_rank() == 0):
|
||||
forward_start.record()
|
||||
|
||||
# Standard forward pass
|
||||
with ctx:
|
||||
res = model(X)
|
||||
loss = loss_fct(
|
||||
res.logits.view(-1, res.logits.size(-1)),
|
||||
Y.view(-1)
|
||||
).view(Y.size())
|
||||
loss = (loss * loss_mask).sum() / loss_mask.sum()
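# loss_mask zeroes out padded positions, so this is the mean cross-entropy over real tokens only
# (dividing by loss_mask.sum() rather than by the total number of positions).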
|
||||
# Add the auxiliary loss, if any layer provides one
|
||||
try:
|
||||
if hasattr(model, 'module'):
|
||||
# DDP case
|
||||
aux_loss = sum(l.feed_forward.aux_loss for l in model.module.layers
|
||||
if hasattr(l.feed_forward, 'aux_loss'))
|
||||
else:
|
||||
# Non-DDP case
|
||||
aux_loss = sum(l.feed_forward.aux_loss for l in model.layers
|
||||
if hasattr(l.feed_forward, 'aux_loss'))
|
||||
loss += aux_loss
|
||||
except Exception as e:
|
||||
Logger(f"Warning: Could not add auxiliary loss: {e}")
|
||||
# If this fails, skip the auxiliary loss
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
# Backward pass
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
if args.profile and (not ddp or dist.get_rank() == 0):
|
||||
forward_end.record()
|
||||
backward_start.record()
|
||||
|
||||
# Print data types for debugging
|
||||
if step == 0 and (not ddp or dist.get_rank() == 0):  # Print only for the first step of each epoch on the main process
|
||||
Logger("---- Data Type Check ----")
|
||||
Logger(f"X.dtype: {X.dtype}")
|
||||
if hasattr(model, 'module'): # DDP case
|
||||
Logger(f"Model parameter dtype: {next(model.module.parameters()).dtype}")
|
||||
else: # Non-DDP case
|
||||
Logger(f"Model parameter dtype: {next(model.parameters()).dtype}")
|
||||
Logger(f"res.logits.dtype: {res.logits.dtype}")
|
||||
Logger(f"loss.dtype: {loss.dtype}")
|
||||
Logger("-------------------------")
|
||||
|
||||
if args.profile and (not ddp or dist.get_rank() == 0):
|
||||
backward_end.record()
|
||||
|
||||
# Profile on every step, not only when a gradient-accumulation cycle completes
|
||||
if (step + 1) % args.profile_interval == 0:
|
||||
# Record optimizer timing (if this is a gradient-accumulation boundary)
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
optimizer_start.record()
|
||||
|
||||
# Optimizer step
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
if args.profile and (not ddp or dist.get_rank() == 0):
|
||||
if (step + 1) % args.profile_interval != 0:
|
||||
optimizer_start.record()
|
||||
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if args.profile and (not ddp or dist.get_rank() == 0):
|
||||
optimizer_end.record()
|
||||
|
||||
# Profiling output (every profile_interval steps)
|
||||
if args.profile and (not ddp or dist.get_rank() == 0) and (step + 1) % args.profile_interval == 0:
|
||||
# Synchronize CUDA events to get accurate timings
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Compute the time spent in each phase
|
||||
data_time = data_start.elapsed_time(data_end)
|
||||
forward_time = forward_start.elapsed_time(forward_end)
|
||||
backward_time = backward_start.elapsed_time(backward_end)
|
||||
|
||||
# Optimizer time only exists when a gradient-accumulation step has just completed
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
optimizer_time = optimizer_start.elapsed_time(optimizer_end)
|
||||
total_compute_time = forward_time + backward_time + optimizer_time
|
||||
Logger(f"Profiling - step {step+1}:")
|
||||
Logger(f"  Data loading time: {data_time:.2f} ms")
|
||||
Logger(f"  Forward time: {forward_time:.2f} ms")
|
||||
Logger(f"  Backward time: {backward_time:.2f} ms")
|
||||
Logger(f"  Optimizer time: {optimizer_time:.2f} ms")
|
||||
Logger(f"  Total compute time: {total_compute_time:.2f} ms")
|
||||
Logger(f"  Compute/data ratio: {total_compute_time / data_time:.2f}")
|
||||
else:
|
||||
# Not a gradient-accumulation boundary, so there is no optimizer time
|
||||
total_compute_time = forward_time + backward_time
|
||||
Logger(f"Profiling - step {step+1} (accumulating gradients):")
|
||||
Logger(f"  Data loading time: {data_time:.2f} ms")
|
||||
Logger(f"  Forward time: {forward_time:.2f} ms")
|
||||
Logger(f"  Backward time: {backward_time:.2f} ms")
|
||||
Logger(f"  Total compute time: {total_compute_time:.2f} ms")
|
||||
Logger(f"  Compute/data ratio: {total_compute_time / data_time:.2f}")
|
||||
|
||||
# Print the log line
|
||||
if step % args.log_interval == 0:
|
||||
spend_time = time.time() - start_time
|
||||
Logger(
|
||||
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch + 1,
|
||||
args.epochs,
|
||||
step,
|
||||
iter_per_epoch,
|
||||
loss.item() * args.accumulation_steps,
|
||||
optimizer.param_groups[-1]['lr'],
|
||||
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
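# ETA note: spend_time / (step + 1) * iter_per_epoch projects the total time for this epoch from the
# average time per step so far; subtracting the elapsed minutes leaves the estimated minutes remaining.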
|
||||
|
||||
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
|
||||
log_dict = {
|
||||
"loss": loss.item() * args.accumulation_steps,
|
||||
"lr": optimizer.param_groups[-1]['lr'],
|
||||
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
|
||||
}
|
||||
|
||||
# If profiling is enabled, also log the performance metrics
|
||||
if args.profile and (step + 1) % args.profile_interval == 0:
|
||||
# Basic performance metrics
|
||||
perf_dict = {
|
||||
"data_time_ms": data_time,
|
||||
"forward_time_ms": forward_time,
|
||||
"backward_time_ms": backward_time
|
||||
}
|
||||
|
||||
# Optimizer time only exists when a gradient-accumulation step has just completed
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
total_compute_time = forward_time + backward_time + optimizer_time
|
||||
perf_dict.update({
|
||||
"optimizer_time_ms": optimizer_time,
|
||||
"compute_time_ms": total_compute_time
|
||||
})
|
||||
else:
|
||||
total_compute_time = forward_time + backward_time
|
||||
perf_dict.update({
|
||||
"compute_time_ms": total_compute_time
|
||||
})
|
||||
|
||||
log_dict.update(perf_dict)
|
||||
|
||||
wandb.log(log_dict)
|
||||
|
||||
# Communication-profiling code removed
|
||||
|
||||
# Save the model
|
||||
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
|
||||
model.eval()
|
||||
# Use the moe_path variable defined at the start of the function
|
||||
ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
|
||||
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
state_dict = model.module.state_dict()  # unwrap DDP to get the model parameters
|
||||
else:
|
||||
state_dict = model.state_dict()  # get the model parameters
|
||||
|
||||
torch.save(state_dict, ckp)  # save the parameters only
|
||||
model.train()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error occurred: {str(e)}")
|
||||
save_path = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}_nanERROR.pth'
|
||||
if os.path.exists(save_path):
|
||||
os.remove(save_path)
|
||||
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
torch.save(state_dict, save_path)
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if param.grad is not None and torch.isnan(param.grad).any():
|
||||
print(f"NaN gradient in parameter: {name}")
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if param.grad is not None and torch.isnan(param.grad).any():
|
||||
print(f"Parameter {name} values: {param.data}")
|
||||
print(f"Parameter {name} gradients: {param.grad}")
|
||||
|
||||
raise ValueError("NaN gradient detected")
|
||||
|
||||
|
||||
def init_model(lm_config, pretrained_embedding_path: Optional[str] = None):
|
||||
# Load the tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
|
||||
# Load the model
|
||||
model = MiniMindLM(lm_config).to(args.device)
|
||||
|
||||
# Load pretrained token embeddings if path is provided
|
||||
if pretrained_embedding_path and os.path.exists(pretrained_embedding_path):
|
||||
Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
|
||||
embedding_weights = torch.load(pretrained_embedding_path, map_location=args.device)
|
||||
model.tok_embeddings.load_state_dict(embedding_weights)
|
||||
Logger("Successfully loaded pretrained token embeddings.")
|
||||
elif pretrained_embedding_path:
|
||||
Logger(f"Warning: Pretrained embedding path {pretrained_embedding_path} provided but file does not exist. Initializing embeddings from scratch.")
|
||||
|
||||
# Print the model's parameter count
|
||||
Logger(f'Total trainable LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
# Communication-profiling helpers removed
|
||||
|
||||
|
||||
def init_distributed_mode():
|
||||
if not ddp: return  # If distributed data parallel (DDP) is not enabled, do nothing.
|
||||
global ddp_local_rank, DEVICE  # Declare these as globals so they can be read outside this function.
|
||||
|
||||
dist.init_process_group(backend="nccl")  # Initialize the process group with the NCCL backend, NVIDIA's optimized library for GPU-to-GPU communication.
|
||||
ddp_rank = int(os.environ["RANK"])  # Global rank of this process, from the environment.
|
||||
ddp_local_rank = int(os.environ["LOCAL_RANK"])  # Local rank of this process on its node, from the environment.
|
||||
ddp_world_size = int(os.environ["WORLD_SIZE"])  # Total number of processes in the group, from the environment.
|
||||
DEVICE = f"cuda:{ddp_local_rank}"  # Pick the GPU that matches the local rank.
|
||||
torch.cuda.set_device(DEVICE)  # Bind this process to that GPU.
|
||||
|
||||
|
||||
# torchrun --nproc_per_node 2 1-pretrain.py
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind Pretraining")
|
||||
parser.add_argument("--out_dir", type=str, default="out")
|
||||
# To reproduce a minimal "zero" model as fast as possible set epochs to 1; otherwise train on the limited data for 2-6 epochs.
|
||||
parser.add_argument("--epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=24)
|
||||
parser.add_argument("--learning_rate", type=float, default=2e-4)
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") #如果GPU可用,则使用GPU,否则使用CPU。
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", default=True, action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
|
||||
parser.add_argument("--num_workers", type=int, default=48)
|
||||
parser.add_argument("--ddp", action="store_true")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=32) #梯度累积步数,用于控制梯度更新频率。
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0) #梯度裁剪阈值,用于防止梯度爆炸。
|
||||
parser.add_argument("--warmup_iters", type=int, default=0) #预热迭代次数,用于控制学习率预热过程。
|
||||
parser.add_argument("--log_interval", type=int, default=100) #日志打印间隔,用于控制日志打印的频率。
|
||||
parser.add_argument("--save_interval", type=int, default=10000) #模型保存间隔,用于控制模型保存的频率。
|
||||
parser.add_argument('--local_rank', type=int, default=-1) #本地进程编号,用于分布式训练。
|
||||
parser.add_argument('--dim', default=1024, type=int) #模型维度,用于控制模型的大小。
|
||||
parser.add_argument('--n_layers', default=32, type=int) #层数,用于控制模型层数。
|
||||
parser.add_argument('--max_seq_len', default=1024, type=int) #最大序列长度,用于控制输入序列的最大长度。
|
||||
parser.add_argument('--use_moe', default=False, type=bool) #是否使用MOE,用于控制是否使用MOE。
|
||||
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能,使用固定值1e-4替代") #禁用数据库功能,启用特殊模式
|
||||
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl") #数据路径,用于控制数据集的路径。
|
||||
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
|
||||
# Profiling-related arguments
|
||||
parser.add_argument("--profile", action="store_true", default=True, help="Enable performance profiling")
|
||||
parser.add_argument("--profile_interval", type=int, default=10, help="Profiling print interval (in steps)")
|
||||
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="Enable FlashAttention")
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
|
||||
lm_config = LMConfig(
|
||||
dim=args.dim,
|
||||
n_layers=args.n_layers,
|
||||
max_seq_len=args.max_seq_len,
|
||||
use_moe=args.use_moe,
|
||||
disable_db=args.disable_db,  # pass through the database-disable flag
|
||||
flash_attn=args.use_flash_attn  # pass through FlashAttention support
|
||||
)  # Build the LMConfig object that controls the model configuration.
|
||||
args.save_dir = os.path.join(args.out_dir)  # Directory for checkpoints.
|
||||
os.makedirs(args.save_dir, exist_ok=True)  # Create the save directory.
|
||||
os.makedirs(args.out_dir, exist_ok=True)  # Create the output directory.
|
||||
tokens_per_iter = args.batch_size * lm_config.max_seq_len  # Tokens processed per iteration.
|
||||
print(f"tokens_per_iter: {tokens_per_iter}")
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"  # Determine the device type.
|
||||
|
||||
# Determine the torch dtype
|
||||
pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
|
||||
|
||||
args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
|
||||
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)
|
||||
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
||||
ddp_local_rank, DEVICE = 0, "cuda:0"
|
||||
|
||||
base_seed = 1337
|
||||
torch.manual_seed(base_seed)
|
||||
torch.cuda.manual_seed(base_seed)
|
||||
|
||||
if ddp:
|
||||
init_distributed_mode()
|
||||
args.device = torch.device(DEVICE)
|
||||
rank = dist.get_rank()
|
||||
torch.manual_seed(base_seed + rank)
|
||||
# Also seed the CUDA RNG
|
||||
torch.cuda.manual_seed(base_seed + rank)
|
||||
|
||||
if args.use_wandb and (not ddp or ddp_local_rank == 0):
|
||||
import wandb
|
||||
|
||||
# Merge args and lm_config parameters for wandb config
|
||||
config = vars(args).copy()
|
||||
config.update(lm_config.__dict__)
|
||||
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config)
|
||||
else:
|
||||
wandb = None
|
||||
|
||||
model, tokenizer = init_model(lm_config, args.pretrained_embedding_path)
|
||||
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if ddp else None
|
||||
# Tuned DataLoader configuration
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
pin_memory_device=f"cuda:{ddp_local_rank}" if ddp else "cuda:0",  # device that pinned memory is staged for
|
||||
drop_last=False,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler,
|
||||
persistent_workers=True if args.num_workers > 0 else False,  # keep worker processes alive between epochs
|
||||
prefetch_factor=2 if args.num_workers > 0 else None  # batches prefetched per worker
|
||||
)
|
||||
|
||||
# Enable GradScaler only for float16; bfloat16 does not need loss scaling
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
if ddp:
|
||||
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
|
||||
# Keep find_unused_parameters=True because the model really does have unused parameters
|
||||
model = DistributedDataParallel(model, device_ids=[ddp_local_rank], find_unused_parameters=True)
|
||||
|
||||
# Keep set_detect_anomaly for now to help debugging
|
||||
# Once training is stable this line can be commented out to speed things up
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
iter_per_epoch = len(train_loader)
|
||||
for epoch in range(args.epochs):
|
||||
train_epoch(epoch, wandb)
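The update pattern inside train_epoch() can be hard to see between the profiling hooks. The following condensed sketch (standalone, with a placeholder model and random data, not the repo's code) shows the same mechanics: an autocast forward pass, the loss scaled down by the accumulation factor, and an optimizer step with unscale + clipping only every accumulation_steps micro-batches:

import torch
from torch import nn, optim

model = nn.Linear(16, 16).cuda()
optimizer = optim.AdamW(model.parameters(), lr=2e-4)
scaler = torch.cuda.amp.GradScaler(enabled=True)        # only needed for float16; bfloat16 can skip it
accumulation_steps, grad_clip = 4, 1.0

for step in range(8):
    x = torch.randn(32, 16, device="cuda")
    with torch.cuda.amp.autocast(dtype=torch.float16):
        loss = model(x).pow(2).mean() / accumulation_steps   # divide so accumulated grads average out
    scaler.scale(loss).backward()                            # gradients accumulate across micro-batches
    if (step + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)                           # unscale so clipping sees true gradient norms
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)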
|
437
train_pretrain_accelerate.py
Normal file
437
train_pretrain_accelerate.py
Normal file
@ -0,0 +1,437 @@
|
||||
import os
|
||||
# Set environment variables
|
||||
os.environ["WANDB_MODE"] = "offline"  # or use "dryrun"
|
||||
import platform
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
import warnings
|
||||
import pandas as pd
|
||||
import torch
|
||||
from torch import optim, nn
|
||||
from torch.utils.data import DataLoader
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
import datetime # Add datetime for time formatting
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import set_seed
|
||||
from accelerate.utils import DeepSpeedPlugin
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
|
||||
|
||||
from model.model import MiniMindLM
|
||||
from model.LMConfig import LMConfig
|
||||
from model.dataset import PretrainDataset
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
# Logging helper
|
||||
def Logger(msg, accelerator=None):
|
||||
# Print on the main process only (or unconditionally when no accelerator is given)
|
||||
if accelerator is None or accelerator.is_main_process:
|
||||
print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}")
|
||||
|
||||
# Helper function to format seconds into HH:MM:SS
|
||||
def format_time(seconds):
|
||||
return str(datetime.timedelta(seconds=int(seconds)))
|
||||
|
||||
# Learning-rate schedule
|
||||
def get_lr(it, num_iters, learning_rate):
|
||||
# Cosine learning-rate decay
|
||||
return learning_rate * 0.5 * (1.0 + math.cos(math.pi * it / num_iters))
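# Unlike the schedule in train_pretrain.py (which floors at lr/10), this variant decays all the way to 0.
# It also does not appear to be called below: main() builds a get_cosine_schedule_with_warmup scheduler
# and train_epoch() steps that scheduler instead.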
|
||||
|
||||
# Model initialization
|
||||
def init_model(lm_config, pretrained_embedding_path=None):
|
||||
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
|
||||
model = MiniMindLM(lm_config)
|
||||
|
||||
# If pretrained embedding weights are provided, load them
|
||||
if pretrained_embedding_path:
|
||||
Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
|
||||
pretrained_embeddings = torch.load(pretrained_embedding_path)
|
||||
model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
|
||||
model.output.weight.data.copy_(pretrained_embeddings)  # tied weights: the output head shares the embedding matrix
|
||||
|
||||
Logger(f'Total trainable LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
|
||||
return model, tokenizer
|
||||
|
||||
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time):
|
||||
loss_fct = nn.CrossEntropyLoss(reduction='none')
|
||||
epoch_start_time = time.time()
|
||||
total_steps_in_epoch = len(train_loader)
|
||||
total_training_steps = args.epochs * total_steps_in_epoch
|
||||
moe_path = '_moe' if args.use_moe else ''
|
||||
|
||||
# CUDA events for profiling (main process only)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
data_start = torch.cuda.Event(enable_timing=True)
|
||||
data_end = torch.cuda.Event(enable_timing=True)
|
||||
forward_start = torch.cuda.Event(enable_timing=True)
|
||||
forward_end = torch.cuda.Event(enable_timing=True)
|
||||
backward_start = torch.cuda.Event(enable_timing=True)
|
||||
backward_end = torch.cuda.Event(enable_timing=True)
|
||||
optimizer_start = torch.cuda.Event(enable_timing=True)
|
||||
optimizer_end = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
# Data prefetching
|
||||
prefetch_factor = 2  # number of batches to prefetch
|
||||
data_iter = iter(train_loader)
|
||||
prefetch_batches = []
|
||||
|
||||
# Prefetch the initial batches
|
||||
for _ in range(min(prefetch_factor, len(train_loader))):
|
||||
try:
|
||||
batch = next(data_iter)
|
||||
prefetch_batches.append(batch)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
# Initialize the variables needed for logging before the loop starts
|
||||
last_log_time = epoch_start_time
|
||||
|
||||
for step in range(total_steps_in_epoch):
|
||||
try:
|
||||
# 计时数据加载 (只在主进程进行)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
data_start.record()
|
||||
|
||||
# 使用预取的数据
|
||||
if prefetch_batches:
|
||||
X, Y, loss_mask = prefetch_batches.pop(0)
|
||||
else:
|
||||
# 如果预取队列为空,直接加载
|
||||
X, Y, loss_mask = next(data_iter)
|
||||
|
||||
# 异步预取下一批数据
|
||||
if step + prefetch_factor < len(train_loader):
|
||||
try:
|
||||
batch = next(data_iter)
|
||||
prefetch_batches.append(batch)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
# 计时数据加载结束 (只在主进程进行)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
data_end.record()
|
||||
|
||||
# 更新学习率
|
||||
if scheduler is not None:
|
||||
scheduler.step()
|
||||
|
||||
# 计时前向传播 (只在主进程进行)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
forward_start.record()
|
||||
|
||||
# 前向传播
|
||||
with ctx:
|
||||
res = model(X)
|
||||
loss = loss_fct(
|
||||
res.logits.view(-1, res.logits.size(-1)),
|
||||
Y.view(-1)
|
||||
).view(Y.size())
|
||||
loss = (loss * loss_mask).sum() / loss_mask.sum()
|
||||
# 添加辅助损失,如果存在的话
|
||||
try:
|
||||
aux_loss = sum(l.feed_forward.aux_loss for l in model.module.layers
|
||||
if hasattr(l.feed_forward, 'aux_loss'))
|
||||
loss += aux_loss
|
||||
except Exception as e:
|
||||
Logger(f"Warning: Could not add auxiliary loss: {e}")
|
||||
# 如果出错,不添加辅助损失
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
# 计时前向传播结束 (只在主进程进行)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
forward_end.record()
|
||||
|
||||
# 计时反向传播 (只在主进程进行)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
backward_start.record()
|
||||
|
||||
# Backward pass
|
||||
# With DeepSpeed, gradient accumulation and gradient clipping are handled automatically
|
||||
accelerator.backward(loss)
|
||||
|
||||
# End backward timing (main process only)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
backward_end.record()
|
||||
|
||||
# Time the optimizer step (main process only)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
optimizer_start.record()
|
||||
|
||||
# Optimizer step - with DeepSpeed, gradient accumulation and clipping are handled automatically
|
||||
# The actual parameter update only happens once the accumulation step count is reached
|
||||
# Note: because DeepSpeed manages accumulation itself, there is no need to check step % accumulation_steps here
|
||||
optimizer.step()
|
||||
|
||||
# With DeepSpeed, zero_grad() is called automatically after step()
|
||||
# but we still call it explicitly to be safe
|
||||
optimizer.zero_grad()
|
||||
|
||||
# 计时优化器步骤结束 (只在主进程进行)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
optimizer_end.record()
|
||||
|
||||
# 打印训练信息 (只在主进程进行)
|
||||
if (step + 1) % args.log_interval == 0 and accelerator.is_main_process:
|
||||
current_time = time.time()
|
||||
# Compute performance metrics
|
||||
if args.profile:
|
||||
torch.cuda.synchronize()
|
||||
# Use the time since the last log line, not the total time, for the performance metrics
|
||||
data_time = data_start.elapsed_time(data_end)
|
||||
forward_time = forward_start.elapsed_time(forward_end)
|
||||
backward_time = backward_start.elapsed_time(backward_end)
|
||||
optimizer_time = optimizer_start.elapsed_time(optimizer_end)
|
||||
iter_time = (current_time - last_log_time) * 1000 / args.log_interval # avg ms per iteration since last log
|
||||
# total_time_ms = data_time + forward_time + backward_time + optimizer_time
|
||||
|
||||
# Print the profiling summary
|
||||
if (step + 1) % (args.log_interval * args.profile_interval) == 0:
|
||||
Logger(f"Profiling (avg/iter over last {args.log_interval} steps) - "
|
||||
f"Data: {data_time/args.log_interval:.2f}ms, "
|
||||
f"Fwd: {forward_time/args.log_interval:.2f}ms, "
|
||||
f"Bwd: {backward_time/args.log_interval:.2f}ms, "
|
||||
f"Optim: {optimizer_time/args.log_interval:.2f}ms, "
|
||||
f"Iter Time: {iter_time:.2f}ms", accelerator)
|
||||
# Re-create the events so the next measurement starts from zero
|
||||
data_start = torch.cuda.Event(enable_timing=True)
|
||||
data_end = torch.cuda.Event(enable_timing=True)
|
||||
forward_start = torch.cuda.Event(enable_timing=True)
|
||||
forward_end = torch.cuda.Event(enable_timing=True)
|
||||
backward_start = torch.cuda.Event(enable_timing=True)
|
||||
backward_end = torch.cuda.Event(enable_timing=True)
|
||||
optimizer_start = torch.cuda.Event(enable_timing=True)
|
||||
optimizer_end = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
|
||||
# Current learning rate
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
|
||||
# Timing estimates
|
||||
epoch_elapsed_time = current_time - epoch_start_time
|
||||
epoch_steps_done = step + 1
|
||||
epoch_avg_step_time = epoch_elapsed_time / epoch_steps_done
|
||||
epoch_remaining_time = epoch_avg_step_time * (total_steps_in_epoch - epoch_steps_done)
|
||||
|
||||
total_elapsed_time = current_time - overall_start_time
|
||||
total_steps_done = epoch * total_steps_in_epoch + epoch_steps_done
|
||||
total_avg_step_time = total_elapsed_time / total_steps_done if total_steps_done > 0 else 0
|
||||
total_remaining_time = total_avg_step_time * (total_training_steps - total_steps_done) if total_steps_done > 0 else 0
|
||||
|
||||
# Training throughput over the last log_interval steps
|
||||
interval_elapsed_time = current_time - last_log_time
|
||||
tokens_processed_interval = args.log_interval * args.batch_size * args.max_seq_len
|
||||
tokens_per_sec = tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0
|
||||
last_log_time = current_time  # update the last-log timestamp
|
||||
|
||||
Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
|
||||
f"Loss: {loss.item()*args.accumulation_steps:.4f}, "
|
||||
f"LR: {current_lr:.6f}, "
|
||||
f"Speed: {tokens_per_sec:.2f} tokens/sec | "
|
||||
f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
|
||||
f"Total Time Left: {format_time(total_remaining_time)}", accelerator)
|
||||
|
||||
# Save the model (main process only)
|
||||
if (step + 1) % args.save_interval == 0 and accelerator.is_main_process:
|
||||
# Use the moe_path variable defined at the start of the function
|
||||
ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
|
||||
|
||||
# Get the unwrapped model
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
|
||||
# Save the model parameters
|
||||
accelerator.save(unwrapped_model.state_dict(), ckp)
|
||||
Logger(f"Model saved to {ckp}", accelerator)
|
||||
|
||||
except Exception as e:
|
||||
Logger(f"Error in training step: {e}", accelerator)
|
||||
import traceback
|
||||
Logger(traceback.format_exc(), accelerator)
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
|
||||
parser.add_argument("--out_dir", type=str, default="out")
|
||||
parser.add_argument("--epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=24)
|
||||
parser.add_argument("--learning_rate", type=float, default=2e-4)
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", default=True, action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
|
||||
parser.add_argument("--num_workers", type=int, default=48)
|
||||
parser.add_argument("--accumulation_steps", type=int, default=32)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0)
|
||||
parser.add_argument("--warmup_iters", type=int, default=0)
|
||||
parser.add_argument("--log_interval", type=int, default=100)
|
||||
parser.add_argument("--save_interval", type=int, default=10000)
|
||||
parser.add_argument('--dim', default=1024, type=int)
|
||||
parser.add_argument('--n_layers', default=32, type=int)
|
||||
parser.add_argument('--max_seq_len', default=1024, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument('--disable_db', action='store_true', help="Disable the database feature and use a fixed value of 1e-4 instead")
|
||||
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
|
||||
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
|
||||
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
|
||||
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
|
||||
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
|
||||
parser.add_argument("--knowlwdge_num", type=int, default=64*64,help="知识库的数据数目")
|
||||
parser.add_argument("--knowlwdge_length", type=int, default=8,help="知识库的句子长度")
|
||||
args = parser.parse_args()
|
||||
|
||||
#########################################################
|
||||
# Initialize accelerator and DeepSpeed
|
||||
#########################################################
|
||||
# Configure ddp_kwargs to tolerate unused parameters
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# Create the DeepSpeedPlugin object
|
||||
ds_plugin = DeepSpeedPlugin(
|
||||
gradient_accumulation_steps=args.accumulation_steps,
|
||||
gradient_clipping=args.grad_clip,
|
||||
zero_stage=2,  # use ZeRO-2 optimization
|
||||
offload_optimizer_device="cpu",  # offload optimizer state to the CPU
|
||||
offload_param_device="none",  # do not offload parameters to the CPU
|
||||
)
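# ZeRO stage 2 shards optimizer state and gradients across ranks while keeping a full copy of the
# parameters on every GPU; offloading the optimizer state to CPU trades GPU memory for extra PCIe
# traffic during the optimizer step. Parameters themselves stay on the GPU (offload device "none").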
|
||||
accelerator = Accelerator(
|
||||
kwargs_handlers=[ddp_kwargs],
|
||||
deepspeed_plugin=ds_plugin,
|
||||
mixed_precision="bf16" if args.dtype == "bfloat16" else "fp16" if args.dtype == "float16" else "no"
|
||||
)
|
||||
|
||||
#########################################################
|
||||
# Set the random seed
|
||||
#########################################################
|
||||
set_seed(1337 + accelerator.process_index)
|
||||
|
||||
#########################################################
|
||||
# Configure the model
|
||||
#########################################################
|
||||
lm_config = LMConfig(
|
||||
dim=args.dim,
|
||||
n_layers=args.n_layers,
|
||||
max_seq_len=args.max_seq_len,
|
||||
use_moe=args.use_moe,
|
||||
disable_db=args.disable_db,
|
||||
flash_attn=args.use_flash_attn,
|
||||
knowlwdge_num=args.knowlwdge_num,
|
||||
knowlwdge_length=args.knowlwdge_length
|
||||
)
|
||||
|
||||
#########################################################
|
||||
# Create the save directories
|
||||
#########################################################
|
||||
args.save_dir = os.path.join(args.out_dir)
|
||||
if accelerator.is_main_process:
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
|
||||
#########################################################
|
||||
# Set the data type
|
||||
#########################################################
|
||||
pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
|
||||
|
||||
|
||||
#########################################################
|
||||
# Configure wandb
|
||||
#########################################################
|
||||
# Set the wandb run name
|
||||
args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
if args.use_wandb and accelerator.is_main_process:
|
||||
import wandb
|
||||
# Merge args and lm_config into a single dict
|
||||
config_dict = vars(args).copy()
|
||||
config_dict.update(vars(lm_config))
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config_dict)
|
||||
else:
|
||||
wandb = None
|
||||
|
||||
#########################################################
|
||||
# Print configuration info
|
||||
#########################################################
|
||||
# Tokens processed per iteration
|
||||
tokens_per_iter = args.batch_size * lm_config.max_seq_len
|
||||
if accelerator.is_main_process:
|
||||
Logger(f"tokens_per_iter: {tokens_per_iter}", accelerator)
|
||||
Logger("Configuration:", accelerator)
|
||||
for key, value in config_dict.items():
|
||||
Logger(f" {key}: {value}", accelerator)
|
||||
|
||||
|
||||
#########################################################
|
||||
# Set up the automatic mixed-precision context
|
||||
#########################################################
|
||||
ctx = nullcontext() if accelerator.device.type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)
|
||||
|
||||
#########################################################
|
||||
# Initialize the model and tokenizer
|
||||
#########################################################
|
||||
model, tokenizer = init_model(lm_config, args.pretrained_embedding_path)
|
||||
# Pass the accelerator to this Logger call (init_model's own Logger calls run without it)
|
||||
Logger(f'Model initialization complete', accelerator)
|
||||
|
||||
#########################################################
|
||||
# Handle the positional-encoding tensors
|
||||
#########################################################
|
||||
if hasattr(model, "pos_cis_real"):
|
||||
Logger(f'Detected real-valued pos_cis_real tensor; it will participate in distributed training', accelerator)
|
||||
# Set the model's _ddp_params_and_buffers_to_ignore attribute
|
||||
# model._ddp_params_and_buffers_to_ignore = {"pos_cis_real"}
|
||||
# Backwards compatibility: check whether the old pos_cis still exists
|
||||
elif hasattr(model, "pos_cis"):
|
||||
Logger(f'Detected complex-valued pos_cis tensor; excluding it from distributed training', accelerator)
|
||||
# Set the model's _ddp_params_and_buffers_to_ignore attribute
|
||||
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
|
||||
|
||||
#########################################################
|
||||
# Create the dataset and dataloader
|
||||
#########################################################
|
||||
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
shuffle=True,
|
||||
num_workers=args.num_workers,
|
||||
persistent_workers=True if args.num_workers > 0 else False,
|
||||
prefetch_factor=2 if args.num_workers > 0 else None
|
||||
)
|
||||
|
||||
#########################################################
|
||||
# Create the optimizer
|
||||
#########################################################
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
#########################################################
|
||||
# Create the learning-rate scheduler
|
||||
#########################################################
|
||||
total_steps = len(train_loader) * args.epochs
|
||||
warmup_steps = args.warmup_iters if args.warmup_iters > 0 else int(0.1 * total_steps)
|
||||
scheduler = get_cosine_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=warmup_steps,
|
||||
num_training_steps=total_steps
|
||||
)
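# get_cosine_schedule_with_warmup ramps the learning rate linearly from 0 over warmup_steps
# (10% of total steps unless --warmup_iters is set) and then follows a cosine decay towards 0
# over the remaining training steps.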
|
||||
|
||||
#########################################################
|
||||
# Prepare for training
|
||||
#########################################################
|
||||
model, optimizer, train_loader, scheduler = accelerator.prepare(
|
||||
model, optimizer, train_loader, scheduler
|
||||
)
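# accelerator.prepare wraps the model in the DeepSpeed engine, shards the dataloader across
# processes and moves each batch to the right device, which is why train_epoch() never calls
# .to(device) on X, Y or loss_mask.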
|
||||
|
||||
#########################################################
|
||||
# Training loop
|
||||
#########################################################
|
||||
overall_start_time = time.time() # Record overall start time
|
||||
for epoch in range(args.epochs):
|
||||
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time) # Pass overall start time
|
||||
|
||||
#########################################################
|
||||
# Close wandb
|
||||
#########################################################
|
||||
if args.use_wandb and accelerator.is_main_process:
|
||||
wandb.finish()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
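# Typical launch (illustrative only; exact flags depend on your accelerate/DeepSpeed setup):
#   accelerate launch train_pretrain_accelerate.py --epochs 3 --batch_size 24
# With the DeepSpeedPlugin configured above, accelerate handles process spawning, ZeRO-2
# sharding and mixed precision, so no separate torchrun invocation is needed.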
|