DynamicKV-LLM Pretrain v1.1.0

iomgaa 2025-05-14 00:01:40 +08:00
commit 089afd6728
32 changed files with 23963 additions and 0 deletions

5
.gitignore vendored Normal file

@ -0,0 +1,5 @@
/model/__pycache__
/dataset
/out
wandb/
**/*.log

201
LICENSE Normal file

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

126
README_accelerate.md Normal file

@ -0,0 +1,126 @@
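# Distributed Training with Accelerate + DeepSpeed
This document describes how to run distributed training of the MiniMind model with Accelerate and DeepSpeed.
## Environment Setup
First, make sure the required dependencies are installed: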
```bash
pip install accelerate deepspeed
```
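## Configuration Files
### 1. DeepSpeed configuration file (ds_config.json)
The DeepSpeed configuration file defines the optimizer, the learning-rate scheduler, ZeRO optimization, and related parameters. The main settings are:
- **ZeRO optimization**: uses ZeRO stage 2, which reduces GPU memory usage
- **Optimizer**: uses the AdamW optimizer
- **Mixed-precision training**: supports FP16 and BF16
- **Gradient accumulation**: set to `"auto"` so that it stays consistent with the training script's arguments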
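To see which settings are left for the launcher to fill in, it can help to list every `"auto"` placeholder in the file. The sketch below is only an illustration: `find_auto_fields` is a hypothetical helper (it is not part of DeepSpeed or Accelerate), and the inline dictionary is an excerpt of `ds_config.json`, not the full file.

```python
import json

# Excerpt of ds_config.json for illustration; the real file has more keys.
DS_CONFIG_EXCERPT = json.loads("""
{
  "gradient_accumulation_steps": "auto",
  "zero_optimization": {"stage": 2},
  "fp16": {"enabled": "auto", "loss_scale": 0},
  "bf16": {"enabled": "auto"},
  "optimizer": {"type": "AdamW", "params": {"lr": "auto", "weight_decay": "auto"}}
}
""")


def find_auto_fields(config, prefix=""):
    """Return the dotted path of every value equal to the string "auto"."""
    paths = []
    for key, value in config.items():
        path = f"{prefix}{key}"
        if isinstance(value, dict):
            # Recurse into nested sections such as "fp16" or "optimizer.params".
            paths.extend(find_auto_fields(value, prefix=f"{path}."))
        elif value == "auto":
            paths.append(path)
    return paths


if __name__ == "__main__":
    for path in find_auto_fields(DS_CONFIG_EXCERPT):
        print(path)
```

Pointing the same helper at the full file (via `json.load`) gives a quick checklist of the values that must come from the Accelerate launcher or the training script.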
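### 2. Accelerate configuration file (accelerate_config.yaml)
The Accelerate configuration file defines the basic distributed-training settings:
- **Distributed type**: DeepSpeed
- **Mixed precision**: BF16
- **Number of processes**: 4 (adjust to match the number of GPUs)
- **DeepSpeed configuration**: points to the ds_config.json file
## Training Script
The new training script `train_pretrain_accelerate.py` is derived from the original `train_pretrain.py`. The main changes are:
1. Accelerator replaces PyTorch's native distributed functionality
2. The torchrun-related distributed initialization code has been removed
3. The model, optimizer, and data loader are prepared through the Accelerator API
4. Backpropagation and gradient clipping go through the Accelerator API
5. The positional-encoding and unused-parameter issues are handled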
## Starting Training
There are two ways to start training:
### Method 1: Use a prepared Accelerate configuration file
```bash
accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
--epochs 3 \
--batch_size 24 \
--learning_rate 2e-4 \
--dtype bfloat16 \
--accumulation_steps 32 \
--grad_clip 1.0 \
--log_interval 100 \
--save_interval 10000 \
--dim 1024 \
--n_layers 32 \
--max_seq_len 1024 \
--use_flash_attn \
--profile \
--profile_interval 10
```
### Method 2: Configure Accelerate directly with command-line arguments
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--multi_gpu \
--num_processes=4 \
--mixed_precision=bf16 \
--main_process_port=29500 \
--deepspeed_config_file ds_config.json \
train_pretrain_accelerate.py \
--epochs 3 \
--batch_size 24 \
--learning_rate 2e-4 \
--dtype bfloat16 \
--accumulation_steps 32 \
--grad_clip 1.0 \
--log_interval 100 \
--save_interval 10000 \
--dim 1024 \
--n_layers 32 \
--max_seq_len 1024 \
--use_flash_attn \
--profile \
--profile_interval 10
```
You can also run the provided script directly:
```bash
bash run_accelerate.sh
```
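For orientation, the effective batch size implied by the launch commands above can be worked out as follows. This is a back-of-the-envelope sketch that assumes `--batch_size` is the per-process batch size and that every sequence is padded to `--max_seq_len`; `effective_batch` is a hypothetical helper, not part of the training script.

```python
def effective_batch(per_process_batch: int, accumulation_steps: int,
                    num_processes: int, max_seq_len: int) -> tuple[int, int]:
    """Return (sequences, tokens) consumed per optimizer step."""
    sequences = per_process_batch * accumulation_steps * num_processes
    tokens = sequences * max_seq_len  # upper bound: assumes full-length sequences
    return sequences, tokens


if __name__ == "__main__":
    # Values taken from the launch commands: batch 24, 32 accumulation steps,
    # 4 processes, sequence length 1024.
    sequences, tokens = effective_batch(24, 32, 4, 1024)
    print(f"{sequences} sequences, up to {tokens} tokens per optimizer step")
```

When changing the number of GPUs, adjusting `--accumulation_steps` in the opposite direction keeps this product roughly constant.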
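## How the Accelerate and DeepSpeed Configurations Relate
1. **Accelerate** is a high-level API that simplifies setting up and launching distributed training. It works with several distributed-training backends, such as DeepSpeed and FSDP.
2. **DeepSpeed** is an optimization library focused on memory optimization and performance for large-scale model training; it provides features such as ZeRO.
3. **How the configurations relate**:
   - The Accelerate configuration file (YAML) selects the distributed backend and defines the basic distributed settings
   - The DeepSpeed configuration file (JSON) defines DeepSpeed-specific optimization parameters
   - Accelerate references the DeepSpeed configuration file through the `deepspeed_config_file` parameter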
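## Notes
1. **Positional encoding**:
   - In the model, `pos_cis` is a complex tensor, which needs special handling in distributed training
   - The new training script handles this through the Accelerator API, so `_ddp_params_and_buffers_to_ignore` is no longer needed
2. **Unused parameters**:
   - The original code used `find_unused_parameters=True` to handle unused parameters
   - The new training script uses the Accelerator API directly, which handles this automatically
3. **Mixed-precision training**:
   - `fp16` and `bf16` in the DeepSpeed configuration file are set to `"auto"`
   - The precision actually used is determined by Accelerate's `--mixed_precision` argument
4. **Gradient accumulation**:
   - `gradient_accumulation_steps` in the DeepSpeed configuration file is set to `"auto"`
   - The actual number of accumulation steps is determined by the training script's `--accumulation_steps` argument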
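On the positional-encoding note: multiplying by a unit-magnitude complex number is the same as a 2-D rotation, which is why a rotary embedding can also be written with real `cos`/`sin` values instead of a complex buffer. The sketch below checks that identity on a single pair of scalars with the standard library only; `rotate_complex` and `rotate_real` are illustrative names, and the real model operates on tensors, not scalars.

```python
import cmath
import math


def rotate_complex(a: float, b: float, theta: float) -> tuple[float, float]:
    """Rotate the pair (a, b) by multiplying a + ib with e^{i*theta}."""
    z = complex(a, b) * cmath.exp(1j * theta)
    return z.real, z.imag


def rotate_real(a: float, b: float, theta: float) -> tuple[float, float]:
    """Same rotation using only real arithmetic (cos/sin)."""
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    return a * cos_t - b * sin_t, a * sin_t + b * cos_t


if __name__ == "__main__":
    for theta in (0.0, 0.3, 1.7, math.pi):
        c = rotate_complex(0.5, -1.25, theta)
        r = rotate_real(0.5, -1.25, theta)
        print(theta, c, r)
```

Both functions return the same pair up to floating-point rounding, so the real-valued form can stand in wherever complex buffers are inconvenient.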

22
ReadMe.md Normal file

@ -0,0 +1,22 @@
## Environment Setup
1. Create a conda environment
```bash
conda create -n accelerate python=3.10
conda activate accelerate
```
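2. Install the torch, torchvision, and torchaudio builds that match the CUDA version on your system
3. Install accelerate and deepspeed versions that match the torch and torchvision installed in the environment
4. Install the remaining packages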
```bash
pip install -r requirements.txt
```
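After installation, a quick way to confirm what ended up in the environment is to query package metadata. The snippet below is a small sketch using only the standard library; `report_packages` and the `REQUIRED` tuple are illustrative and not part of this repository.

```python
from importlib import metadata, util

# Packages named in the installation steps above.
REQUIRED = ("torch", "torchvision", "torchaudio", "accelerate", "deepspeed")


def report_packages(names=REQUIRED):
    """Map each package name to its installed version, or None if it is missing."""
    report = {}
    for name in names:
        if util.find_spec(name) is None:
            report[name] = None  # not importable in this environment
            continue
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "unknown"  # importable, but no distribution metadata
    return report


if __name__ == "__main__":
    for name, version in report_packages().items():
        print(f"{name:12s} {version or 'NOT INSTALLED'}")
```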
## Modifying the Model
1. In most cases, only the files in the `model` folder need to be changed
## Running
1. On a 4090 or 4070 Ti, run `bash run_file/DynamicKV-LLM_Mini_Minimind.sh`
2. On four A800 GPUs, run `bash run_file/DynamicKV-LLM_Small_Minimind.sh`

17
accelerate_config.yaml Normal file

@ -0,0 +1,17 @@
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: ds_config.json
  zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

49
ds_config.json Normal file

@ -0,0 +1,49 @@
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true
},
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"steps_per_print": 100,
"wall_clock_breakdown": false
}

181
eval_model.py Normal file

@ -0,0 +1,181 @@
import argparse
import random
import time
import numpy as np
import torch
import warnings
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.model_lora import *
warnings.filterwarnings('ignore')


def init_model(args):
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    if args.load == 0:
        # Load native torch weights from the output directory
        moe_path = '_moe' if args.use_moe else ''
        modes = {0: 'pretrain', 1: 'full_sft', 2: 'rlhf', 3: 'reason', 4: 'grpo'}
        ckp = f'./{args.out_dir}/{modes[args.model_mode]}_{args.dim}{moe_path}.pth'
        model = MiniMindLM(LMConfig(
            dim=args.dim,
            n_layers=args.n_layers,
            max_seq_len=args.max_seq_len,
            use_moe=args.use_moe
        ))
        state_dict = torch.load(ckp, map_location=args.device)
        model.load_state_dict({k: v for k, v in state_dict.items() if 'mask' not in k}, strict=True)
        if args.lora_name != 'None':
            apply_lora(model)
            load_lora(model, f'./{args.out_dir}/lora/{args.lora_name}_{args.dim}.pth')
    else:
        # Load a transformers-format checkpoint
        transformers_model_path = './MiniMind2'
        tokenizer = AutoTokenizer.from_pretrained(transformers_model_path)
        model = AutoModelForCausalLM.from_pretrained(transformers_model_path, trust_remote_code=True)
    print(f'MiniMind parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M(illion)')
    return model.eval().to(args.device), tokenizer


def get_prompt_datas(args):
    if args.model_mode == 0:
        # A pretrained model can only continue text; it cannot hold a conversation
        prompt_datas = [
            '马克思主义基本原理',
            '人类大脑的主要功能',
            '万有引力原理是',
            '世界上最高的山峰是',
            '二氧化碳在空气中',
            '地球上最大的动物有',
            '杭州市的美食有'
        ]
    else:
        if args.lora_name == 'None':
            # General conversation questions
            prompt_datas = [
                '请介绍一下自己。',
                '你更擅长哪一个学科?',
                '鲁迅的《狂人日记》是如何批判封建礼教的?',
                '我咳嗽已经持续了两周,需要去医院检查吗?',
                '详细的介绍光速的物理概念。',
                '推荐一些杭州的特色美食吧。',
                '请为我讲解“大语言模型”这个概念。',
                '如何理解ChatGPT',
                'Introduce the history of the United States, please.'
            ]
        else:
            # Domain-specific questions
            lora_prompt_datas = {
                'lora_identity': [
                    "你是ChatGPT吧。",
                    "你叫什么名字?",
                    "你和openai是什么关系"
                ],
                'lora_medical': [
                    '我最近经常感到头晕,可能是什么原因?',
                    '我咳嗽已经持续了两周,需要去医院检查吗?',
                    '服用抗生素时需要注意哪些事项?',
                    '体检报告中显示胆固醇偏高,我该怎么办?',
                    '孕妇在饮食上需要注意什么?',
                    '老年人如何预防骨质疏松?',
                    '我最近总是感到焦虑,应该怎么缓解?',
                    '如果有人突然晕倒,应该如何急救?'
                ],
            }
            prompt_datas = lora_prompt_datas[args.lora_name]
    return prompt_datas


# Set random seeds for reproducibility
def setup_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def str2bool(value):
    # argparse's type=bool treats any non-empty string (even "False") as True,
    # so boolean options are parsed explicitly instead.
    return str(value).lower() in ('1', 'true', 'yes', 'y')


def main():
    parser = argparse.ArgumentParser(description="Chat with MiniMind")
    parser.add_argument('--lora_name', default='None', type=str)
    parser.add_argument('--out_dir', default='out', type=str)
    parser.add_argument('--temperature', default=0.85, type=float)
    parser.add_argument('--top_p', default=0.85, type=float)
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str)
    # max_seq_len here is the maximum allowed input length. It does not mean the model
    # performs well on text of that length; it only keeps QA pairs from being truncated.
    # MiniMind2-moe (145M): (dim=640, n_layers=8, use_moe=True)
    # MiniMind2-Small (26M): (dim=512, n_layers=8)
    # MiniMind2 (104M): (dim=768, n_layers=16)
    parser.add_argument('--dim', default=512, type=int)
    parser.add_argument('--n_layers', default=8, type=int)
    parser.add_argument('--max_seq_len', default=8192, type=int)
    parser.add_argument('--use_moe', default=False, type=str2bool)
    # Number of history messages carried as context.
    # history_cnt must be even: one [user question, model answer] pair forms one group.
    # A value of 0 means the current query carries no history.
    # Without extrapolation fine-tuning, quality degrades noticeably on longer
    # chat_template contexts, so choose this value with care.
    parser.add_argument('--history_cnt', default=0, type=int)
    parser.add_argument('--stream', default=True, type=str2bool)
    parser.add_argument('--load', default=0, type=int, help="0: native torch weights, 1: load via transformers")
    parser.add_argument('--model_mode', default=1, type=int,
                        help="0: pretrained model, 1: SFT-Chat model, 2: RLHF-Chat model, 3: Reason model, 4: RLAIF-Chat model")
    args = parser.parse_args()
    model, tokenizer = init_model(args)
    prompts = get_prompt_datas(args)
    test_mode = int(input('[0] Automatic test\n[1] Manual input\n'))
    messages = []
    for idx, prompt in enumerate(prompts if test_mode == 0 else iter(lambda: input('👶: '), '')):
        setup_seed(random.randint(0, 2048))
        # setup_seed(2025)  # use a fixed seed instead to make every run produce the same output
        if test_mode == 0: print(f'👶: {prompt}')
        messages = messages[-args.history_cnt:] if args.history_cnt else []
        messages.append({"role": "user", "content": prompt})
        new_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )[-args.max_seq_len - 1:] if args.model_mode != 0 else (tokenizer.bos_token + prompt)
        answer = new_prompt
        with torch.no_grad():
            x = torch.tensor(tokenizer(new_prompt)['input_ids'], device=args.device).unsqueeze(0)
            outputs = model.generate(
                x,
                eos_token_id=tokenizer.eos_token_id,
                max_new_tokens=args.max_seq_len,
                temperature=args.temperature,
                top_p=args.top_p,
                stream=args.stream,
                pad_token_id=tokenizer.pad_token_id
            )
            print('🤖️: ', end='')
            try:
                if not args.stream:
                    print(tokenizer.decode(outputs.squeeze()[x.shape[1]:].tolist(), skip_special_tokens=True), end='')
                else:
                    history_idx = 0
                    for y in outputs:
                        answer = tokenizer.decode(y[0].tolist(), skip_special_tokens=True)
                        # Skip partial output that ends in the Unicode replacement character
                        # (an incomplete multi-byte sequence) or is empty
                        if (answer and answer[-1] == '\ufffd') or not answer:
                            continue
                        print(answer[history_idx:], end='', flush=True)
                        history_idx = len(answer)
            except StopIteration:
                print("No answer")
            print('\n')
        messages.append({"role": "assistant", "content": answer})


if __name__ == "__main__":
    main()

75
model/LMConfig.py Normal file

@ -0,0 +1,75 @@
from transformers import PretrainedConfig
from typing import List, Optional


class LMConfig(PretrainedConfig):
    model_type = "minimind"

    def __init__(
            self,
            dim: int = 512,
            n_layers: int = 8,
            n_heads: int = 32,
            n_kv_heads: int = 8,
            vocab_size: int = 6400,
            hidden_dim: Optional[int] = None,
            multiple_of: int = 64,
            norm_eps: float = 1e-5,
            max_seq_len: int = 8192,
            rope_theta: float = 1e6,
            dropout: float = 0.0,
            flash_attn: bool = True,
            ####################################################
            # DB related configurations
            ####################################################
            disable_db: bool = False,  # special mode: disable the database feature
            ####################################################
            # Here are the specific configurations of MOE
            # When use_moe is false, the following is invalid
            ####################################################
            use_moe: bool = False,
            ####################################################
            num_experts_per_tok: int = 2,
            n_routed_experts: int = 4,
            n_shared_experts: bool = True,
            scoring_func: str = 'softmax',
            aux_loss_alpha: float = 0.1,
            seq_aux: bool = True,
            norm_topk_prob: bool = True,
            ####################################################
            knowlwdge_num: int = 64*64,
            knowlwdge_length: int = 8,
            **kwargs,
    ):
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.max_seq_len = max_seq_len
        self.rope_theta = rope_theta
        self.dropout = dropout
        self.flash_attn = flash_attn
        ####################################################
        # DB related configurations
        ####################################################
        self.disable_db = disable_db  # whether the database is disabled
        ####################################################
        # Here are the specific configurations of MOE
        # When use_moe is false, the following is invalid
        ####################################################
        self.use_moe = use_moe
        self.num_experts_per_tok = num_experts_per_tok  # number of experts selected per token
        self.n_routed_experts = n_routed_experts  # total number of experts
        self.n_shared_experts = n_shared_experts  # shared experts
        self.scoring_func = scoring_func  # scoring function, 'softmax' by default
        self.aux_loss_alpha = aux_loss_alpha  # alpha coefficient of the auxiliary loss
        self.seq_aux = seq_aux  # whether to compute the auxiliary loss at the sequence level
        self.norm_topk_prob = norm_topk_prob  # whether to normalize the top-k probabilities
        ####################################################
        self.knowlwdge_num = knowlwdge_num
        self.knowlwdge_length = knowlwdge_length
        super().__init__(**kwargs)

245
model/dataset.py Normal file

@ -0,0 +1,245 @@
import json
import random
import re
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from sklearn.model_selection import train_test_split
import os
import ast
os.environ["TOKENIZERS_PARALLELISM"] = "true"


class PretrainDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(data_path)

    def load_data(self, path):
        # Each line of the file is one JSON object (JSON Lines format)
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        # Build the input text
        text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding.input_ids.squeeze()
        loss_mask = (input_ids != self.tokenizer.pad_token_id)
        # input_ids is already a tensor, so slice it instead of re-wrapping it in torch.tensor()
        X = input_ids[:-1].long()
        Y = input_ids[1:].long()
        loss_mask = loss_mask[1:].long()
        return X, Y, loss_mask


class SFTDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def _create_chat_prompt(self, conversations):
        """Build a conversation in ChatML format."""
        messages = []
        for i, turn in enumerate(conversations):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn['content']})
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )

    def _generate_loss_mask(self, input_ids):
        # Mark only assistant responses (between '<s>assistant' and '</s>') for the loss
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask

    def __getitem__(self, index):
        sample = self.samples[index]
        # Build the conversation prompt
        prompt = self._create_chat_prompt(sample['conversations'])
        input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
        # Generate the dynamic loss mask
        loss_mask = self._generate_loss_mask(input_ids)
        # Build the training data
        X = torch.tensor(input_ids[:-1], dtype=torch.long)
        Y = torch.tensor(input_ids[1:], dtype=torch.long)
        loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)  # align with the predicted positions
        return X, Y, loss_mask


class DPODataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=4096):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
        with open(file_path, 'r', encoding='utf-8') as f:
            self.data = []
            for line in f:
                line = line.strip()
                obj = json.loads(line)
                self.data.append(obj)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        chosen = item['chosen']  # a list of {role, content} entries
        rejected = item['rejected']  # same structure as above
        chosen_prompt = self.tokenizer.apply_chat_template(
            chosen, tokenize=False, add_generation_prompt=False
        )
        rejected_prompt = self.tokenizer.apply_chat_template(
            rejected, tokenize=False, add_generation_prompt=False
        )
        chosen_encoding = self.tokenizer(
            chosen_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )
        rejected_encoding = self.tokenizer(
            rejected_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )
        chosen_input_ids = chosen_encoding['input_ids']
        chosen_loss_mask = self._generate_loss_mask(chosen_input_ids)
        rejected_input_ids = rejected_encoding['input_ids']
        rejected_loss_mask = self._generate_loss_mask(rejected_input_ids)
        x_chosen = torch.tensor(chosen_input_ids[:-1], dtype=torch.long)
        y_chosen = torch.tensor(chosen_input_ids[1:], dtype=torch.long)
        mask_chosen = torch.tensor(chosen_loss_mask[1:], dtype=torch.long)
        x_rejected = torch.tensor(rejected_input_ids[:-1], dtype=torch.long)
        y_rejected = torch.tensor(rejected_input_ids[1:], dtype=torch.long)
        mask_rejected = torch.tensor(rejected_loss_mask[1:], dtype=torch.long)
        return {
            'x_chosen': x_chosen,
            'y_chosen': y_chosen,
            'mask_chosen': mask_chosen,
            'x_rejected': x_rejected,
            'y_rejected': y_rejected,
            'mask_rejected': mask_rejected
        }

    def _generate_loss_mask(self, input_ids):
        # Mark only assistant responses (between '<s>assistant' and '</s>') for the loss
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask


class RLAIFDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def _create_chat_prompt(self, conversations):
        """Build a conversation in ChatML format and return (prompt, final answer)."""
        messages = []
        answer = ''
        for i, turn in enumerate(conversations):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn['content']})
            answer = turn['content']
        # The last turn is held out as the answer; the prompt ends with a generation prompt
        return self.tokenizer.apply_chat_template(
            messages[:-1],
            tokenize=False,
            add_generation_prompt=True
        ), answer

    def __getitem__(self, index):
        sample = self.samples[index]
        # Build the conversation prompt
        prompt, answer = self._create_chat_prompt(sample['conversations'])
        return {
            'prompt': prompt,
            'answer': answer
        }


if __name__ == "__main__":
    pass

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff


@ -0,0 +1,43 @@
{
"add_bos_token": false,
"add_eos_token": false,
"add_prefix_space": false,
"added_tokens_decoder": {
"0": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1": {
"content": "<s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [],
"bos_token": "<s>",
"clean_up_tokenization_spaces": false,
"eos_token": "</s>",
"legacy": true,
"model_max_length": 32768,
"pad_token": "<unk>",
"sp_model_kwargs": {},
"spaces_between_special_tokens": false,
"tokenizer_class": "PreTrainedTokenizerFast",
"unk_token": "<unk>",
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<s>system\\n' + system_message + '</s>\\n' }}{% else %}{{ '<s>system\\n你是 MiniMind是一个有用的人工智能助手。</s>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
}

File diff suppressed because one or more lines are too long

755
model/model.py Normal file

@ -0,0 +1,755 @@
import math
import struct
import inspect
import time
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, einsum
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from einops import rearrange, repeat


def exists(val):
    return val is not None


# RMSNorm is a module that normalizes the input tensor by its root mean square.
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self.weight * self._norm(x.float()).type_as(x)


# precompute_pos_cis precomputes the positional encoding (complex version).
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis


# apply_rotary_emb applies the rotary positional encoding (complex version).
def apply_rotary_emb(xq, xk, pos_cis):
    def unite_shape(pos_cis, x):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
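# precompute_pos_cis_real: precompute the rotary position encoding (real-valued version).
def precompute_pos_cis_real(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    """Real-valued position encoding that avoids complex tensors.

    Equivalent to precompute_pos_cis, but returns a real tensor.
    The complex version returns a tensor of shape [seq_len, dim // 2] whose entries
    have modulus 1 and argument equal to the rotation angle, i.e. cos(angle) + i*sin(angle).
    This version returns a real tensor of shape [seq_len, dim] with cos(angle) at the
    even indices and sin(angle) at the odd indices.
    """
    # dim must be even.
    if dim % 2 != 0:
        raise ValueError(f"dim must be even, got {dim}")
    # Same frequency computation as the complex version.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    # Compute cos and sin.
    # In the complex version, pos_cis = torch.polar(torch.ones_like(freqs), freqs),
    # which equals cos(freqs) + i*sin(freqs).
    cos = torch.cos(freqs)
    sin = torch.sin(freqs)
    # Build a real tensor with cos and sin interleaved.
    pos_emb = torch.zeros((end, dim), device=freqs.device)
    pos_emb[:, 0::2] = cos  # even indices hold cos
    pos_emb[:, 1::2] = sin  # odd indices hold sin
    return pos_emb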
# apply_rotary_emb_real: apply the rotary position encoding (real-valued version).
def apply_rotary_emb_real(xq, xk, pos_emb):
    """Real-valued rotary position encoding that avoids complex tensors.

    Equivalent to apply_rotary_emb. The complex version views the inputs as complex
    numbers, multiplies them by the position encoding, and converts back to real.
    This version performs the same rotation with real arithmetic only.
    """
    # Read the shape.
    bsz, seq_len, n_heads, head_dim = xq.shape
    # Check the shape of pos_emb.
    assert pos_emb.shape[0] >= seq_len, f"position encoding length {pos_emb.shape[0]} is smaller than sequence length {seq_len}"
    assert pos_emb.shape[1] == head_dim, f"position encoding dim {pos_emb.shape[1]} does not match head dim {head_dim}"
    # Keep only the positions that are needed.
    pos_emb = pos_emb[:seq_len]
    # Reshape pos_emb for broadcasting: [1, seq_len, 1, head_dim].
    pos_emb = pos_emb.unsqueeze(0).unsqueeze(2)
    # Half of head_dim (one entry per rotated pair).
    half_head_dim = head_dim // 2
    # Extract cos and sin: even indices hold cos, odd indices hold sin.
    cos = pos_emb[..., 0::2]
    sin = pos_emb[..., 1::2]
    # The complex version reshapes xq and xk into complex tensors whose real and
    # imaginary parts are interleaved. Here the even and odd indices are handled separately.
    xq_even = xq[..., 0::2]  # even indices: real part
    xq_odd = xq[..., 1::2]  # odd indices: imaginary part
    xk_even = xk[..., 0::2]
    xk_odd = xk[..., 1::2]
    # Apply the rotation (equivalent to complex multiplication):
    # (a + bi)(cos + sin*i) = (a*cos - b*sin) + (a*sin + b*cos)i
    # where a is the even-index value and b is the odd-index value.
    xq_out_even = xq_even * cos - xq_odd * sin  # new even indices (real part)
    xq_out_odd = xq_even * sin + xq_odd * cos  # new odd indices (imaginary part)
    xk_out_even = xk_even * cos - xk_odd * sin
    xk_out_odd = xk_even * sin + xk_odd * cos
    # Interleave the even and odd indices again.
    xq_out = torch.zeros_like(xq)
    xk_out = torch.zeros_like(xk)
    xq_out[..., 0::2] = xq_out_even
    xq_out[..., 1::2] = xq_out_odd
    xk_out[..., 0::2] = xk_out_even
    xk_out[..., 1::2] = xk_out_odd
    return xq_out.type_as(xq), xk_out.type_as(xk)
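
The docstrings above state that the real-valued rotary functions are equivalent to the complex ones. The identity behind that claim can be checked for a single position with the standard library alone. This is a minimal sketch, not part of model.py: `rope_angles`, `rotate_complex`, and `rotate_real` are illustrative names, and plain Python lists stand in for tensors.

```python
import cmath
import math


def rope_angles(head_dim, pos, theta=1e6):
    """Rotation angle of each (even, odd) pair at position `pos`."""
    return [pos / (theta ** (2 * i / head_dim)) for i in range(head_dim // 2)]


def rotate_complex(x, angles):
    """Treat (x[2i], x[2i+1]) as a complex number and multiply by e^{i*angle}."""
    out = []
    for i, angle in enumerate(angles):
        z = complex(x[2 * i], x[2 * i + 1]) * cmath.rect(1.0, angle)
        out.extend([z.real, z.imag])
    return out


def rotate_real(x, angles):
    """The same rotation written with cos/sin only, as in the real-valued path."""
    out = []
    for i, angle in enumerate(angles):
        even, odd = x[2 * i], x[2 * i + 1]
        cos, sin = math.cos(angle), math.sin(angle)
        out.extend([even * cos - odd * sin, even * sin + odd * cos])
    return out


x = [0.5, -1.0, 2.0, 0.25]
angles = rope_angles(head_dim=4, pos=3)
max_diff = max(abs(a - b) for a, b in zip(rotate_complex(x, angles), rotate_real(x, angles)))
print(max_diff)  # difference should be at floating-point rounding level
```

Both paths rotate each pair by the same angle, so any remaining difference comes only from rounding.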
# repeat_kv: repeat the key/value heads so they match the number of query heads.
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
class Attention(nn.Module):
    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)
    def forward(self,
                x: torch.Tensor,
                pos_cis: torch.Tensor,
                db_value=None):
        bsz, seq_len, _ = x.shape  # batch size, sequence length, hidden size
        # Project the input into queries, keys, and values.
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)  # (bsz, seq_len, n_local_heads, head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)  # (bsz, seq_len, n_local_kv_heads, head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)  # (bsz, seq_len, n_local_kv_heads, head_dim)
        # Apply the rotary position encoding (real-valued version).
        xq, xk = apply_rotary_emb_real(xq, xk, pos_cis)
        # kv_cache implementation REMOVED
        # if past_key_value is not None:
        #     xk = torch.cat([past_key_value[0], xk], dim=1)
        #     xv = torch.cat([past_key_value[1], xv], dim=1)
        # past_kv = (xk, xv) if use_cache else None
        # Repeat the key/value heads.
        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2)
        )
        # If db_value is given, reshape it per head and merge it into xv.
        if db_value is not None:
            # Make db_value shape-compatible with xv; it is assumed to be [B, N, H, D].
            if db_value.ndim == 4:  # [B, N, H, D]
                db_value = db_value.transpose(1, 2)  # -> [B, H, N, D]
                # Check whether the D dimension needs adjusting.
                if db_value.shape[-1] != xv.shape[-1]:
                    # A projection layer could be added when the dimensions differ.
                    # Here the size is adjusted by simple mean pooling or repetition.
                    if db_value.shape[-1] > xv.shape[-1]:
                        # Reduce the dimension.
                        factor = db_value.shape[-1] // xv.shape[-1]
                        db_value = db_value.view(bsz, self.n_local_heads, seq_len, factor, xv.shape[-1])
                        db_value = db_value.mean(dim=3)
                    else:
                        # Expand the dimension.
                        factor = xv.shape[-1] // db_value.shape[-1]
                        db_value = db_value.unsqueeze(-1).repeat(1, 1, 1, 1, factor)
                        db_value = db_value.view(bsz, self.n_local_heads, seq_len, xv.shape[-1])
                # Fuse db_value into xv. Plain addition is used here; other fusion methods are possible.
                xv = xv + db_value
        # Use Flash Attention when available.
        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv
        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output
class CrossAttention(nn.Module):
    def __init__(
            self,
            config
    ):
        super().__init__()
        self.config = config
        self.num_heads = 8
        self.head_dim = self.config.dim // self.num_heads
        self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)

    def forward(self, x, db, context_mask=None, pos_emb=None):
        batch_size = x.size(0)
        # Split into heads: queries come from x, keys and values from db.
        q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        if pos_emb is not None:
            pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
            q = q + pos_emb
            k = k + pos_emb
            v = v + pos_emb
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if context_mask is not None:
            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
        attn_weights = F.softmax(attn_scores, dim=-1)
        context = torch.matmul(attn_weights, v)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
        context = self.to_out(context)
        return context
class FeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        if config.hidden_dim is None:
            hidden_dim = 4 * config.dim
            hidden_dim = int(2 * hidden_dim / 3)
            config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class MoEGate(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss,
                                torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
                    seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = 0
        return topk_idx, topk_weight, aux_loss
class MOEFeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            self.shared_experts = FeedForward(config)

    def forward(self, x):
        identity = x
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape
        # Use the gate to select experts.
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            y = torch.empty_like(x, dtype=torch.float16)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # keep dtypes consistent
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.config.num_experts_per_tok
        # Example: tokens_per_expert = [6, 15, 20, 26]; tokens_per_expert.shape[0] is the
        # number of experts (4 here), and token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...].
        # Then token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the 6 token positions handled by expert 0
        # (a token may be handled by several experts, depending on num_experts_per_tok).
        # The next 9 positions, token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...], belong to expert 1, and so on.
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
        return expert_cache
class MiniMindBlock(nn.Module):
    def __init__(self, layer_id: int, config: LMConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.attention = Attention(config)
        self.cross_att = CrossAttention(config)
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
        # Assume num_experts is the square root of the total number of experts defined.
        # Parameters for query generation.
        # Create the query-generation module.
        # if weight_down_embed is not None:
        #     self.to_queries = nn.Sequential(
        #         nn.Linear(config.dim, self.dim_key * 2, bias=False),
        #         # nn.Unflatten(2, (2, self.n_heads, self.dim_key))  # replaces Rearrange
        #     )
        #     # Hyperparameters
        #     self.product_key_topk = min(16, self.num_keys)  # must not exceed num_keys
        #     self.num_experts_per_head_topk = 1  # experts finally selected per head

    def forward(self, x, db_value, pos_cis):
        # import pdb;pdb.set_trace()
        # db_value = None
        # # If weight_down_embed exists, use the product-key mechanism.
        # if self.weight_down_embed is not None:
        #     # 1. Generate queries
        #     batch_size, seq_len, dim = x.shape
        #     # collapse sequence dimension by averaging
        #     x_flat = x.mean(dim=1)  # [batch_size, dim]
        #     queries = self.to_queries(x_flat)  # [batch_size, 2*dim_key]
        #     queries = queries.reshape(batch_size, 2, self.dim_key)  # [batch_size, 2, dim_key]
        #     queries = queries.permute(1, 0, 2)  # [2, batch_size, dim_key]
        #     # 2. Compute the similarity between queries and keys
        #     sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
        #     # 3. Take the top-k in each of the two subspaces
        #     scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
        #     scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
        #     indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
        #     # 4. Combine the scores and indices of the two subspaces
        #     all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
        #     all_scores = all_scores.view(*all_scores.shape[:-2], -1)
        #     all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
        #     all_indices = all_indices.view(*all_indices.shape[:-2], -1)
        #     # 5. Final top-k selection
        #     scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
        #     indices = all_indices.gather(-1, pk_indices)
        #     # 6. Fetch the expert values from the embedding
        #     flat_indices = indices.view(-1)  # flatten the indices to a 1-D tensor
        #     db_values = self.weight_down_embed(flat_indices)
        #     # Reshape back to the original shape
        #     db_value = db_values.view(batch_size, -1, dim)
        # Self-attention.
        h_attn = self.attention(
            self.attention_norm(x),
            pos_cis,
            db_value=db_value
        )
        h_attn = self.cross_att(h_attn, db_value)
        # Residual connection.
        h = x + h_attn
        # Feed-forward network.
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
class ExtractDB(nn.Module):
    def __init__(self, params):
        # Choose the number of entries and the knowledge dimension so the count has an integer square root.
        super().__init__()
        self.batch_size = None
        self.dim = params.dim
        self.dim_key = self.dim // 2
        self.knowlwdge_num = params.knowlwdge_num  # e.g. 100 entries; must be a perfect square
        # Set knowledge_dim equal to head_dim so it can be used directly in attention.
        self.head_dim = params.dim // params.n_heads
        self.knowledge_length = params.knowlwdge_length * params.dim
        # Use register_buffer instead of nn.Parameter to avoid gradient issues.
        self.register_buffer('weight_down_embed', torch.randn(self.knowlwdge_num, self.knowledge_length) * 0.02)
        self.num_keys = int(math.sqrt(self.knowlwdge_num)) if self.knowlwdge_num > 0 else 0
        self.product_key_topk = min(16, self.num_keys)
        self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02)
        self.num_experts_per_head_topk = 1
        self.to_queries = nn.Sequential(
            nn.Linear(params.dim, self.dim_key * 2, bias=False),
        )

    def q_to_k(self, x):
        # 1. Generate queries.
        self.batch_size, seq_len, dim = x.shape
        # collapse sequence dimension by averaging
        x_flat = x.mean(dim=1)  # [batch_size, dim]
        queries = self.to_queries(x_flat)  # [batch_size, 2*dim_key]
        queries = queries.reshape(self.batch_size, 2, self.dim_key)  # [batch_size, 2, dim_key]
        queries = queries.permute(1, 0, 2)  # [2, batch_size, dim_key]
        # 2. Compute the similarity between queries and keys.
        sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
        # 3. Take the top-k in each of the two subspaces.
        scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
        scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
        indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
        # 4. Combine the scores and indices of the two subspaces.
        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
        all_scores = all_scores.view(*all_scores.shape[:-2], -1)
        all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
        all_indices = all_indices.view(*all_indices.shape[:-2], -1)
        # 5. Final top-k selection.
        scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
        indices = all_indices.gather(-1, pk_indices)
        flat_indices = indices.view(-1)
        return flat_indices

    def get_data(self, index):
        # Fetch the embeddings directly from the buffer on the GPU.
        db_values = self.weight_down_embed[index]
        db_value = db_values.view(self.batch_size, -1, self.dim)
        return db_value

    @torch.no_grad()
    def updata_value(self, k, v):
        # Update the buffer in place (no gradient needed).
        v_reshaped = v.view(v.size(0), -1)
        # Make sure the dtypes match.
        v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
        self.weight_down_embed[k] = v_reshaped
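
The product-key lookup in `q_to_k` can be hard to follow in tensor form: each half of the query is scored against `num_keys` sub-keys, the best candidates per half are paired, and a pair `(i, j)` is mapped to the flat row index `i * num_keys + j`. A minimal pure-Python sketch for one query is given here. `product_key_lookup` and the score values are hypothetical and only illustrate the indexing; they are not part of model.py.

```python
def product_key_lookup(scores_x, scores_y, topk_per_axis, final_topk=1):
    """Pick flat indices from a num_keys x num_keys grid of key pairs.

    scores_x / scores_y: similarity of the two query halves to each sub-key.
    """
    num_keys = len(scores_x)
    # Step 3 in q_to_k: top-k separately in each subspace.
    top_x = sorted(range(num_keys), key=lambda i: scores_x[i], reverse=True)[:topk_per_axis]
    top_y = sorted(range(num_keys), key=lambda j: scores_y[j], reverse=True)[:topk_per_axis]
    # Step 4: every (i, j) pair gets the summed score and the flat index i * num_keys + j.
    candidates = [
        (scores_x[i] + scores_y[j], i * num_keys + j)
        for i in top_x
        for j in top_y
    ]
    # Step 5: final top-k over the combined candidates.
    candidates.sort(key=lambda pair: pair[0], reverse=True)
    return [flat_index for _, flat_index in candidates[:final_topk]]


scores_x = [0.1, 0.9, 0.3]
scores_y = [0.7, 0.2, 0.8]
best = product_key_lookup(scores_x, scores_y, topk_per_axis=2)
print(best)  # row 1, column 2 of a 3 x 3 grid -> [5]
```

Only `topk_per_axis ** 2` pairs are scored instead of all `num_keys ** 2` entries, which is the reason the entry count has to be a perfect square.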
class MiniMindLM(PreTrainedModel):
    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        # Fall back to the default config so the attribute reads below also work when params is None.
        params = params or LMConfig()
        self.params = params
        super().__init__(self.params)
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        # The old weight_down_embed declaration has been removed; the store now lives in ExtractDB.
        self.extract_db = ExtractDB(self.params)
        # Build the transformer blocks (weight_down_embed is no longer passed to them).
        self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.tok_embeddings.weight = self.output.weight
        # Calculate input dimension
        input_dim = (self.params.max_seq_len - 1) * self.params.n_layers
        # Use a bottleneck architecture to reduce parameters
        bottleneck_dim = 256  # Significantly smaller bottleneck dimension
        # Factorized shared downsampling using two smaller convolutions
        self.shared_downsample = nn.Sequential(
            # First reduce input dimension to bottleneck
            nn.Conv1d(input_dim, bottleneck_dim, kernel_size=1, padding='same'),
            nn.ReLU(),  # Non-linearity to improve representation capacity
            # Then expand to target dimension
            nn.Conv1d(bottleneck_dim, 128 * 8, kernel_size=1, padding='same')
        )
        # Specific layers for v path
        self.downsample_v_specific = nn.Sequential(
            nn.Conv1d(128 * 8, 128, kernel_size=1, padding='same'),
            nn.Conv1d(128, 8, kernel_size=1, padding='same')
        )
        # Specific layers for q path
        self.downsample_q_specific = nn.Sequential(
            nn.Conv1d(128 * 8, 512, kernel_size=1, padding='same')
        )
        # Use the real-valued position encoding to avoid segmentation faults that complex tensors may cause.
        self.register_buffer("pos_cis_real",
                             precompute_pos_cis_real(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                **args):
        start_pos = args.get('start_pos', 0)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis_real = self.pos_cis_real[start_pos:start_pos + input_ids.size(1)]
        h_list = []
        for l, layer in enumerate(self.layers):
            # Database disabled: use a constant tensor instead of a database lookup.
            if self.params.disable_db:
                # Tensor of shape [batch_size, n_layers, dim] with every element set to 1e-4.
                batch_size = h.size(0)
                db_value = torch.full((batch_size, self.n_layers, self.params.dim), 1e-4,
                                      dtype=h.dtype, device=h.device)
            else:
                # Normal mode: query the database.
                index = self.extract_db.q_to_k(h)
                db_value = self.extract_db.get_data(index)
            h = layer(
                h, db_value, pos_cis_real
            )
            h_list.append(h.unsqueeze(0))
        h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3)
        # Run the database update only when the database is enabled.
        if not self.params.disable_db:
            # detach() cuts the computation graph, avoiding repeated backpropagation.
            h_tensor_detached = h_tensor.detach()
            h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0], -1, self.params.dim)
            # The database update is kept separate from the main computation graph.
            with torch.no_grad():
                # Compute shared downsampling layer once
                shared_features = self.shared_downsample(h_tensor_detached)
                z_v = self.downsample_v_specific(shared_features)
                z_q = self.downsample_q_specific(shared_features)
                z_k = self.extract_db.q_to_k(z_q)
                self.extract_db.updata_value(z_k, z_v)
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(self.norm(h)[:, slice_indices, :])
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
        # Simplified further: only the required fields are passed to the constructor.
        output = CausalLMOutputWithPast(
            logits=logits,
        )
        output.hidden_states = h
        output.aux_loss = aux_loss
        # Try to attach other attributes (if supported).
        # try:
        #     output.hidden_states = h
        # except:
        #     pass
        return output
    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        # Streaming generation.
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
        # Non-streaming generation.
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                tokens_list = [tokens[:, -1:] for tokens in out]
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                full_sequence = torch.cat([non_pad, gen], dim=-1)
                generated.append(full_sequence)
        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        start, first_seq = input_ids.shape[1], True
        # Note: this bound applies to the total sequence length, prompt included.
        while input_ids.shape[1] < max_new_tokens - 1:
            if first_seq:
                out, first_seq = self(input_ids, **args), False
            else:
                # Later steps feed only the newest token; Attention.forward no longer keeps a KV cache.
                out = self(input_ids[:, -1:], start_pos=input_ids.shape[1] - 1, **args)
            logits = out.logits[:, -1, :]
            # Repetition penalty on tokens already present, then temperature scaling.
            logits[:, list(set(input_ids.tolist()[0]))] /= rp
            logits /= (temperature + 1e-9)
            if top_p is not None and top_p < 1.0:
                # Nucleus (top-p) filtering: the token that crosses the threshold is still kept.
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break
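
The top-p branch of `_stream` shifts the removal mask one position to the right, so the first token whose cumulative probability exceeds `top_p` is still kept. The same rule is sketched here for a single list of logits using only the standard library. `top_p_keep` is a hypothetical helper for illustration, not part of model.py.

```python
import math


def top_p_keep(logits, top_p):
    """Return the indices kept by nucleus filtering, most probable first."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    # Numerically stable softmax.
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    kept, cumulative = [], 0.0
    for index in order:
        kept.append(index)  # the token that crosses the threshold is kept as well
        cumulative += probs[index]
        if cumulative > top_p:
            break
    return kept


logits = [2.0, 1.0, 0.0, -1.0]
print(top_p_keep(logits, 0.9))  # probabilities ~0.64, 0.24, 0.09, 0.03 -> [0, 1, 2]
print(top_p_keep(logits, 0.5))  # the top token alone already exceeds 0.5 -> [0]
```

Keeping the crossing token guarantees that at least one candidate always survives, which matches forcing the first mask entry to `False` in the tensor code.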

49
model/model_lora.py Normal file

@ -0,0 +1,49 @@
import torch
from torch import optim, nn


# LoRA module definition.
class LoRA(nn.Module):
    def __init__(self, in_features, out_features, rank):
        super().__init__()
        self.rank = rank  # LoRA rank; controls the size of the low-rank matrices
        self.A = nn.Linear(in_features, rank, bias=False)  # low-rank matrix A
        self.B = nn.Linear(rank, out_features, bias=False)  # low-rank matrix B
        # Gaussian initialization for A.
        self.A.weight.data.normal_(mean=0.0, std=0.02)
        # Zero initialization for B.
        self.B.weight.data.zero_()

    def forward(self, x):
        return self.B(self.A(x))


def apply_lora(model, rank=16):
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]:
            # nn.Linear stores its weight as (out_features, in_features).
            lora = LoRA(module.weight.shape[1], module.weight.shape[0], rank=rank).to(model.device)
            setattr(module, "lora", lora)
            original_forward = module.forward

            # Bind explicitly through default arguments so each closure keeps its own layer.
            def forward_with_lora(x, layer1=original_forward, layer2=lora):
                return layer1(x) + layer2(x)

            module.forward = forward_with_lora


def load_lora(model, path):
    state_dict = torch.load(path, map_location=model.device)
    for name, module in model.named_modules():
        if hasattr(module, 'lora'):
            lora_state = {k.replace(f'{name}.lora.', ''): v for k, v in state_dict.items() if f'{name}.lora.' in k}
            module.lora.load_state_dict(lora_state)


def save_lora(model, path):
    state_dict = {}
    for name, module in model.named_modules():
        if hasattr(module, 'lora'):
            lora_state = {f'{name}.lora.{k}': v for k, v in module.lora.state_dict().items()}
            state_dict.update(lora_state)
    torch.save(state_dict, path)
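
Two properties of the LoRA module above are easy to confirm by hand: zero-initializing `B` leaves the wrapped layer's output unchanged at the start, and the adapter adds `rank * (in_features + out_features)` parameters. A minimal sketch with nested lists in place of tensors follows. `matvec`, `lora_forward`, and `lora_param_count` are illustrative names and not part of model_lora.py.

```python
def matvec(matrix, vector):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x):
    """Base layer output plus the low-rank update B(A(x))."""
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))
    return [b + d for b, d in zip(base, delta)]


def lora_param_count(in_features, out_features, rank):
    """Parameters added by A (rank x in) and B (out x rank)."""
    return rank * in_features + out_features * rank


W = [[1.0, 2.0], [3.0, 4.0]]  # frozen base weight, shape (out, in)
A = [[0.5, -0.5]]             # shape (rank, in), rank 1
B = [[0.0], [0.0]]            # shape (out, rank), zero-initialized
x = [1.0, 1.0]

y = lora_forward(W, A, B, x)
print(y)                                # equals the base output [3.0, 7.0]
print(lora_param_count(512, 512, 16))   # 16384, versus 262144 for a full 512 x 512 matrix
```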

120
requirements.txt Normal file

@ -0,0 +1,120 @@
aiohappyeyeballs==2.6.1
aiohttp==3.11.17
aiosignal==1.3.2
altair==5.5.0
annotated-types==0.7.0
anyio==4.9.0
async-timeout==5.0.1
attrs==25.3.0
blinker==1.9.0
cachetools==5.5.2
certifi==2025.1.31
charset-normalizer==3.4.1
click==8.1.8
contourpy==1.3.2
cycler==0.12.1
datasets==2.21.0
datasketch==1.6.4
dill==0.3.8
distro==1.9.0
docker-pycreds==0.4.0
einops==0.8.1
exceptiongroup==1.2.2
filelock==3.18.0
Flask==3.0.3
Flask-Cors==4.0.0
fonttools==4.57.0
frozenlist==1.6.0
fsspec==2024.6.1
gitdb==4.0.12
GitPython==3.1.44
h11==0.14.0
hjson==3.1.0
httpcore==1.0.8
httpx==0.28.1
huggingface-hub==0.30.2
idna==3.10
importlib_metadata==7.2.1
itsdangerous==2.2.0
jieba==0.42.1
Jinja2==3.1.2
jiter==0.9.0
joblib==1.4.2
jsonlines==4.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
kiwisolver==1.4.8
markdown-it-py==3.0.0
MarkupSafe==3.0.2
marshmallow==3.22.0
matplotlib==3.10.0
mdurl==0.1.2
modelscope==1.25.0
mpmath==1.3.0
msgpack==1.1.0
multidict==6.4.3
multiprocess==0.70.16
narwhals==1.35.0
networkx==3.4.2
ngrok==1.4.0
ninja==1.11.1.4
nltk==3.8
numpy==1.26.4
openai==1.59.6
packaging==23.2
pandas==1.5.3
peft==0.7.1
pillow==10.4.0
platformdirs==4.3.7
propcache==0.3.1
protobuf==4.25.6
psutil==5.9.8
py-cpuinfo==9.0.0
pyarrow==19.0.1
pydantic==2.8.2
pydantic_core==2.20.1
pydeck==0.9.1
Pygments==2.19.1
pyparsing==3.2.3
python-dateutil==2.9.0.post0
pytz==2025.2
PyYAML==6.0.2
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
rich==13.7.1
rpds-py==0.24.0
safetensors==0.5.3
scikit-learn==1.5.1
scipy==1.15.2
sentence-transformers==2.3.1
sentencepiece==0.2.0
sentry-sdk==2.26.1
setproctitle==1.3.5
simhash==2.1.2
six==1.17.0
smmap==5.0.2
sniffio==1.3.1
streamlit==1.30.0
sympy==1.13.3
tenacity==8.5.0
threadpoolctl==3.6.0
tiktoken==0.5.1
tokenizers==0.21.1
toml==0.10.2
tornado==6.4.2
tqdm==4.67.1
transformers==4.48.0
triton==3.3.0
trl==0.13.0
typing_extensions==4.13.2
tzlocal==5.3.1
ujson==5.1.0
urllib3==2.4.0
validators==0.34.0
wandb==0.18.3
watchdog==6.0.0
Werkzeug==3.1.3
xxhash==3.5.0
yarl==1.20.0
zipp==3.21.0


@ -0,0 +1,47 @@
#!/bin/bash
# Activate the conda environment
# source $(conda info --base)/etc/profile.d/conda.sh
# conda activate ycz_accelerate
# Set environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
# Method 1: use a pre-configured accelerate config file
# accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
# --epochs 3 \
# --batch_size 24 \
# --learning_rate 2e-4 \
# --dtype bfloat16 \
# --accumulation_steps 32 \
# --grad_clip 1.0 \
# --log_interval 100 \
# --save_interval 10000 \
# --dim 1024 \
# --n_layers 32 \
# --max_seq_len 1024 \
# --use_flash_attn \
# --profile \
# --profile_interval 10
# Method 2: configure accelerate directly through command-line arguments
CUDA_VISIBLE_DEVICES=0 accelerate launch \
--num_processes=1 \
--mixed_precision=bf16 \
--main_process_port=29500 \
train_pretrain_accelerate.py \
--epochs 3 \
--batch_size 24 \
--learning_rate 2e-4 \
--dtype bfloat16 \
--accumulation_steps 32 \
--grad_clip 1.0 \
--log_interval 100 \
--save_interval 10000 \
--dim 512 \
--n_layers 12 \
--max_seq_len 512 \
--use_flash_attn \
--profile \
--profile_interval 10


@ -0,0 +1,48 @@
#!/bin/bash
# Activate the conda environment
source $(conda info --base)/etc/profile.d/conda.sh
conda activate ycz_accelerate
# Set environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
# Method 1: use a pre-configured accelerate config file
# accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
# --epochs 3 \
# --batch_size 24 \
# --learning_rate 2e-4 \
# --dtype bfloat16 \
# --accumulation_steps 32 \
# --grad_clip 1.0 \
# --log_interval 100 \
# --save_interval 10000 \
# --dim 1024 \
# --n_layers 32 \
# --max_seq_len 1024 \
# --use_flash_attn \
# --profile \
# --profile_interval 10
# Method 2: configure accelerate directly through command-line arguments
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--multi_gpu \
--num_processes=4 \
--mixed_precision=bf16 \
--main_process_port=29500 \
train_pretrain_accelerate.py \
--epochs 3 \
--batch_size 24 \
--learning_rate 2e-4 \
--dtype bfloat16 \
--accumulation_steps 32 \
--grad_clip 1.0 \
--log_interval 100 \
--save_interval 10000 \
--dim 1024 \
--n_layers 32 \
--max_seq_len 1024 \
--use_flash_attn \
--profile \
--profile_interval 10
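
The flags in the two launch scripts above imply very different amounts of data per optimizer step. A rough calculation is sketched here under two assumptions that the scripts themselves do not confirm: that `--batch_size` is the per-process batch size, and that every sequence is padded or packed to `--max_seq_len`. `effective_batch` is a hypothetical helper, not part of the repository.

```python
def effective_batch(per_process_batch, accumulation_steps, num_processes, max_seq_len):
    """Sequences and (upper-bound) tokens consumed per optimizer step."""
    sequences = per_process_batch * accumulation_steps * num_processes
    return sequences, sequences * max_seq_len


# Single-GPU script: --batch_size 24 --accumulation_steps 32 --max_seq_len 512
single_gpu = effective_batch(24, 32, 1, 512)
# Four-GPU script: --batch_size 24 --accumulation_steps 32 --max_seq_len 1024
multi_gpu = effective_batch(24, 32, 4, 1024)

print(single_gpu)  # (768, 393216)
print(multi_gpu)   # (3072, 3145728)
```

If the learning rate is tuned on one setup, this eightfold gap in tokens per step is worth keeping in mind when moving to the other.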


@ -0,0 +1,30 @@
from openai import OpenAI

client = OpenAI(
    api_key="none",
    base_url="http://localhost:8998/v1"
)
stream = True
conversation_history_origin = []
conversation_history = conversation_history_origin.copy()
# Use an even number (question + answer pairs); 0 sends each query on its own, without history.
history_messages_num = 2
while True:
    query = input('[Q]: ')
    conversation_history.append({"role": "user", "content": query})
    response = client.chat.completions.create(
        model="minimind",
        # The +1 keeps the current query; a plain [-0:] slice would return the whole history.
        messages=conversation_history[-(history_messages_num + 1):],
        stream=stream
    )
    if not stream:
        assistant_res = response.choices[0].message.content
        print('[A]: ', assistant_res)
    else:
        print('[A]: ', end='')
        assistant_res = ''
        for chunk in response:
            print(chunk.choices[0].delta.content or "", end="")
            assistant_res += chunk.choices[0].delta.content or ""
    conversation_history.append({"role": "assistant", "content": assistant_res})
    print('\n\n')
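
One detail of the history window in the client above is worth spelling out: in Python, `lst[-0:]` is the same as `lst[0:]`, so a window size of 0 does not drop the history by itself. A minimal sketch of a window that always includes the newest query follows. `history_window` is a hypothetical helper, not part of the script.

```python
def history_window(history, history_messages_num):
    """Newest message plus up to `history_messages_num` earlier messages."""
    return history[-(history_messages_num + 1):]


history = [
    {"role": "user", "content": "q1"},
    {"role": "assistant", "content": "a1"},
    {"role": "user", "content": "q2"},
]

naive = history[-0:]                   # same as history[0:], so nothing is dropped
only_query = history_window(history, 0)
with_pair = history_window(history, 2)

print(len(naive), len(only_query), len(with_pair))  # 3 1 3
```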

62
scripts/convert_model.py Normal file

@ -0,0 +1,62 @@
import torch
import warnings
import sys
import os

__package__ = "scripts"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.LMConfig import LMConfig
from model.model import MiniMindLM

warnings.filterwarnings('ignore', category=UserWarning)


def convert_torch2transformers(torch_path, transformers_path):
    def export_tokenizer(transformers_path):
        tokenizer = AutoTokenizer.from_pretrained('../model/minimind_tokenizer')
        tokenizer.save_pretrained(transformers_path)

    LMConfig.register_for_auto_class()
    MiniMindLM.register_for_auto_class("AutoModelForCausalLM")
    # lm_config is the module-level config created in the __main__ block below.
    lm_model = MiniMindLM(lm_config)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    state_dict = torch.load(torch_path, map_location=device)
    lm_model.load_state_dict(state_dict, strict=False)
    model_params = sum(p.numel() for p in lm_model.parameters() if p.requires_grad)
    print(f'Model parameters: {model_params / 1e6} M = {model_params / 1e9} B (billion)')
    lm_model.save_pretrained(transformers_path, safe_serialization=False)
    export_tokenizer(transformers_path)
    print(f"Model saved in Transformers format: {transformers_path}")


def convert_transformers2torch(transformers_path, torch_path):
    model = AutoModelForCausalLM.from_pretrained(transformers_path, trust_remote_code=True)
    torch.save(model.state_dict(), torch_path)
    print(f"Model saved in PyTorch format: {torch_path}")


# don't need to use
def push_to_hf(export_model_path):
    def init_model():
        tokenizer = AutoTokenizer.from_pretrained('../model/minimind_tokenizer')
        model = AutoModelForCausalLM.from_pretrained(export_model_path, trust_remote_code=True)
        return model, tokenizer

    model, tokenizer = init_model()
    # model.push_to_hub(model_path)
    # tokenizer.push_to_hub(model_path, safe_serialization=False)


if __name__ == '__main__':
    lm_config = LMConfig(dim=512, n_layers=8, max_seq_len=8192, use_moe=False)
    torch_path = f"../out/rlhf_{lm_config.dim}{'_moe' if lm_config.use_moe else ''}.pth"
    transformers_path = '../MiniMind2-Small'
    # convert torch to transformers model
    convert_torch2transformers(torch_path, transformers_path)
    # # convert transformers to torch model
    # convert_transformers2torch(transformers_path, torch_path)

164
scripts/serve_openai_api.py Normal file

@ -0,0 +1,164 @@
import argparse
import json
import os
import sys
__package__ = "scripts"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import time
import torch
import warnings
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.LMConfig import LMConfig
from model.model import MiniMindLM
from model.model_lora import apply_lora, load_lora
warnings.filterwarnings('ignore')
app = FastAPI()
def init_model(args):
tokenizer = AutoTokenizer.from_pretrained('../model/minimind_tokenizer')
if args.load == 0:
moe_path = '_moe' if args.use_moe else ''
modes = {0: 'pretrain', 1: 'full_sft', 2: 'rlhf', 3: 'reason'}
ckp = f'../{args.out_dir}/{modes[args.model_mode]}_{args.dim}{moe_path}.pth'
model = MiniMindLM(LMConfig(
dim=args.dim,
n_layers=args.n_layers,
max_seq_len=args.max_seq_len,
use_moe=args.use_moe
))
state_dict = torch.load(ckp, map_location=device)
model.load_state_dict({k: v for k, v in state_dict.items() if 'mask' not in k}, strict=True)
if args.lora_name != 'None':
apply_lora(model)
load_lora(model, f'../{args.out_dir}/{args.lora_name}_{args.dim}.pth')
else:
model = AutoModelForCausalLM.from_pretrained(
'./MiniMind2',
trust_remote_code=True
)
print(f'MiniMind model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M(illion)')
return model.eval().to(device), tokenizer
class ChatRequest(BaseModel):
model: str
messages: list
temperature: float = 0.7
top_p: float = 0.92
max_tokens: int = 8192
stream: bool = False
def generate_stream_response(messages, temperature, top_p, max_tokens):
try:
new_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)[-max_tokens:]
x = tokenizer(new_prompt).data['input_ids']
x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])
with torch.no_grad():
res_y = model.generate(
x,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stream=True,
rp=1.,
pad_token_id=tokenizer.pad_token_id
)
history_idx = 0
for y in res_y:
answer = tokenizer.decode(y[0].tolist(), skip_special_tokens=True)
if (answer and answer[-1] == '\ufffd') or not answer:
continue
delta = answer[history_idx:]
history_idx = len(answer)
json_data = {
'id': f'chatcmpl-{int(time.time())}',
'object': 'chat.completion.chunk',
'created': int(time.time()),
'model': 'minimind',
'choices': [{'index': 0, 'delta': {'content': delta}, 'finish_reason': None}]
}
yield f"data: {json.dumps(json_data)}\n\n"
except Exception as e:
yield f"data: {json.dumps({'error': str(e)})}\n\n"
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatRequest):
try:
if request.stream:
return StreamingResponse(
generate_stream_response(
messages=request.messages,
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens
),
media_type="text/event-stream"
)
else:
new_prompt = tokenizer.apply_chat_template(
request.messages,
tokenize=False,
add_generation_prompt=True
)[-request.max_tokens:]
x = tokenizer(new_prompt).data['input_ids']
x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])
with torch.no_grad():
res_y = model.generate(
x,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
stream=False,
rp=1.,
pad_token_id=tokenizer.pad_token_id
)
answer = tokenizer.decode(res_y.squeeze()[x.shape[1]:].tolist(), skip_special_tokens=True)
return {
"id": f"chatcmpl-{int(time.time())}",
"object": "chat.completion",
"created": int(time.time()),
"model": "minimind",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": answer},
"finish_reason": "stop"
}
]
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Server for MiniMind")
parser.add_argument('--out_dir', default='out', type=str)
parser.add_argument('--lora_name', default='None', type=str)
parser.add_argument('--dim', default=512, type=int)
parser.add_argument('--n_layers', default=8, type=int)
parser.add_argument('--max_seq_len', default=8192, type=int)
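parser.add_argument('--use_moe', default=False, type=lambda s: s.lower() in ('1', 'true', 'yes'))  # type=bool would turn any non-empty string, including "False", into True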
parser.add_argument('--load', default=0, type=int, help="0: load native torch weights, 1: load via transformers")
parser.add_argument('--model_mode', default=1, type=int, help="0: pretrained model, 1: SFT-Chat model, 2: RLHF-Chat model, 3: Reason model")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, tokenizer = init_model(parser.parse_args())
uvicorn.run(app, host="0.0.0.0", port=8998)
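
A minimal sketch of the incremental streaming logic in `generate_stream_response`, assuming (as the `history_idx` slicing suggests) that each generation step yields the full decoded text so far. The names `sse_chunks` and `collect_deltas` are hypothetical and not part of the repository; the model and tokenizer are replaced by a plain list of strings so the example runs on its own.

```python
import json
import time


def sse_chunks(cumulative_answers, model_name="minimind"):
    """Yield OpenAI-style SSE lines from cumulative decoded strings (hypothetical helper)."""
    history_idx = 0
    for answer in cumulative_answers:
        # Skip empty output and text ending in U+FFFD: the newest token then
        # holds only part of a multi-byte character, so wait for the next step.
        if not answer or answer[-1] == "\ufffd":
            continue
        delta = answer[history_idx:]  # only the text not yet sent
        history_idx = len(answer)
        payload = {
            "id": f"chatcmpl-{int(time.time())}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model_name,
            "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}],
        }
        yield f"data: {json.dumps(payload)}\n\n"


def collect_deltas(sse_lines):
    """Parse the SSE lines back into the list of content deltas (hypothetical helper)."""
    deltas = []
    for line in sse_lines:
        body = json.loads(line[len("data: "):])
        deltas.append(body["choices"][0]["delta"]["content"])
    return deltas


# Simulated decoder output: one step ends in an incomplete character.
steps = ["", "Hi", "Hi \ufffd", "Hi 你", "Hi 你好"]
deltas = collect_deltas(sse_chunks(steps))
print(deltas)
```

Because the incomplete step is skipped rather than sent, the client never sees the replacement character, and the concatenated deltas reproduce the final answer.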

152
scripts/train_tokenizer.py Normal file

@ -0,0 +1,152 @@
import random
from tqdm import tqdm
from transformers import AutoTokenizer
import json
from datasets import load_dataset
from tokenizers import (
decoders,
models,
normalizers,
pre_tokenizers,
processors,
trainers,
Tokenizer,
)
import os
random.seed(42)
def train_tokenizer():
# Read the JSONL file and yield the text of each record
def read_texts_from_jsonl(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
data = json.loads(line)
yield data['text']
data_path = '../dataset/pretrain_hq.jsonl'
# Initialize the tokenizer
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
# Define the special tokens
special_tokens = ["<unk>", "<s>", "</s>"]
# Configure the trainer and register the special tokens
trainer = trainers.BpeTrainer(
vocab_size=6400,
special_tokens=special_tokens, # make sure these three tokens are included
show_progress=True,
initial_alphabet=pre_tokenizers.ByteLevel.alphabet()
)
# Read the text data
texts = read_texts_from_jsonl(data_path)
# Train the tokenizer
tokenizer.train_from_iterator(texts, trainer=trainer)
# Set the decoder
tokenizer.decoder = decoders.ByteLevel()
# Check the special-token indices
assert tokenizer.token_to_id("<unk>") == 0
assert tokenizer.token_to_id("<s>") == 1
assert tokenizer.token_to_id("</s>") == 2
# Save the tokenizer
tokenizer_dir = "../model/minimind_tokenizer"
os.makedirs(tokenizer_dir, exist_ok=True)
tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
tokenizer.model.save("../model/minimind_tokenizer")
# Build the config file by hand
config = {
"add_bos_token": False,
"add_eos_token": False,
"add_prefix_space": False,
"added_tokens_decoder": {
"0": {
"content": "<unk>",
"lstrip": False,
"normalized": False,
"rstrip": False,
"single_word": False,
"special": True
},
"1": {
"content": "<s>",
"lstrip": False,
"normalized": False,
"rstrip": False,
"single_word": False,
"special": True
},
"2": {
"content": "</s>",
"lstrip": False,
"normalized": False,
"rstrip": False,
"single_word": False,
"special": True
}
},
"additional_special_tokens": [],
"bos_token": "<s>",
"clean_up_tokenization_spaces": False,
"eos_token": "</s>",
"legacy": True,
"model_max_length": 32768,
"pad_token": "<unk>",
"sp_model_kwargs": {},
"spaces_between_special_tokens": False,
"tokenizer_class": "PreTrainedTokenizerFast",
"unk_token": "<unk>",
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<s>system\\n' + system_message + '</s>\\n' }}{% else %}{{ '<s>system\\n你是 MiniMind是一个有用的人工智能助手。</s>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
}
# Save the config file
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w", encoding="utf-8") as config_file:
json.dump(config, config_file, ensure_ascii=False, indent=4)
print("Tokenizer training completed and saved.")
def eval_tokenizer():
from transformers import AutoTokenizer
# Load the trained tokenizer
tokenizer = AutoTokenizer.from_pretrained("../model/minimind_tokenizer")
messages = [
{"role": "system", "content": "You are an excellent chatbot that always gives me the correct response!"},
{"role": "user", "content": 'Where are you from?'},
{"role": "assistant", "content": 'I am from Earth'}
]
new_prompt = tokenizer.apply_chat_template(
messages,
tokenize=False
)
print(new_prompt)
# Get the actual vocabulary size (including special tokens)
actual_vocab_size = len(tokenizer)
print('Actual tokenizer vocab size:', actual_vocab_size)
model_inputs = tokenizer(new_prompt)
print('Encoded length:', len(model_inputs['input_ids']))
input_ids = model_inputs['input_ids']
response = tokenizer.decode(input_ids, skip_special_tokens=False)
print('Decoded text matches the original:', response == new_prompt)
def main():
train_tokenizer()
eval_tokenizer()
if __name__ == '__main__':
main()
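
A plain-Python sketch of what the `chat_template` string in `tokenizer_config.json` renders, written out as ordinary control flow so the prompt layout is easier to read. The function name `render_chat` and the `default_system` parameter are hypothetical; the real rendering is done by the Jinja template through `tokenizer.apply_chat_template`, and the guard for an empty message list is an addition of this sketch.

```python
def render_chat(messages, default_system):
    """Mirror the branches of the Jinja chat template (hypothetical helper)."""
    parts = []
    # A leading system message replaces the default system prompt.
    if messages and messages[0]["role"] == "system":
        parts.append("<s>system\n" + messages[0]["content"] + "</s>\n")
    else:
        parts.append("<s>system\n" + default_system + "</s>\n")
    for message in messages:
        content = message["content"]
        if message["role"] == "user":
            # Each user turn also opens the assistant turn that follows it.
            parts.append("<s>user\n" + content + "</s>\n<s>assistant\n")
        elif message["role"] == "assistant":
            parts.append(content + "</s>" + "\n")
        # Other roles (including the system message) are skipped in the loop.
    return "".join(parts)


example = render_chat(
    [
        {"role": "system", "content": "S"},
        {"role": "user", "content": "Q"},
        {"role": "assistant", "content": "A"},
    ],
    default_system="(default system prompt)",
)
print(example)
```

Since the assistant header is emitted as part of every user turn, a conversation that ends on a user message is already a generation prompt.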

293
scripts/web_demo.py Normal file

@ -0,0 +1,293 @@
import random
import re
import time
import numpy as np
import streamlit as st
import torch
st.set_page_config(page_title="MiniMind", initial_sidebar_state="collapsed")
# Custom CSS at the top of the file: button styling
st.markdown("""
<style>
/* Action button styles */
.stButton button {
border-radius: 50% !important; /* make circular */
width: 32px !important; /* fixed width */
height: 32px !important; /* fixed height */
padding: 0 !important; /* remove padding */
background-color: transparent !important;
border: 1px solid #ddd !important;
display: flex !important;
align-items: center !important;
justify-content: center !important;
font-size: 14px !important;
color: #666 !important; /* softer color */
margin: 5px 10px 5px 0 !important; /* adjust button spacing */
}
.stButton button:hover {
border-color: #999 !important;
color: #333 !important;
background-color: #f5f5f5 !important;
}
.stMainBlockContainer > div:first-child {
margin-top: -50px !important;
}
.stApp > div:last-child {
margin-bottom: -35px !important;
}
/* Reset base button styles */
.stButton > button {
all: unset !important; /* reset all default styles */
box-sizing: border-box !important;
border-radius: 50% !important;
width: 18px !important;
height: 18px !important;
min-width: 18px !important;
min-height: 18px !important;
max-width: 18px !important;
max-height: 18px !important;
padding: 0 !important;
background-color: transparent !important;
border: 1px solid #ddd !important;
display: flex !important;
align-items: center !important;
justify-content: center !important;
font-size: 14px !important;
color: #888 !important;
cursor: pointer !important;
transition: all 0.2s ease !important;
margin: 0 2px !important; /* adjust the margin here */
}
</style>
""", unsafe_allow_html=True)
system_prompt = []
device = "cuda" if torch.cuda.is_available() else "cpu"
def process_assistant_content(content):
if 'R1' not in MODEL_PATHS[selected_model][1]:
return content
if '<think>' in content and '</think>' in content:
content = re.sub(r'(<think>)(.*?)(</think>)',
r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">Reasoning (expand)</summary>\2</details>',
content,
flags=re.DOTALL)
if '<think>' in content and '</think>' not in content:
content = re.sub(r'<think>(.*?)$',
r'<details open style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">Reasoning...</summary>\1</details>',
content,
flags=re.DOTALL)
if '<think>' not in content and '</think>' in content:
content = re.sub(r'(.*?)</think>',
r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">Reasoning (expand)</summary>\1</details>',
content,
flags=re.DOTALL)
return content
@st.cache_resource
def load_model_tokenizer(model_path):
model = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True
)
model = model.eval().to(device)
return model, tokenizer
def clear_chat_messages():
del st.session_state.messages
del st.session_state.chat_messages
def init_chat_messages():
if "messages" in st.session_state:
for i, message in enumerate(st.session_state.messages):
if message["role"] == "assistant":
with st.chat_message("assistant", avatar=image_url):
st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
# Add a button below the message content
if st.button("🗑", key=f"delete_{i}"):
st.session_state.messages.pop(i)
st.session_state.messages.pop(i - 1)
st.session_state.chat_messages.pop(i)
st.session_state.chat_messages.pop(i - 1)
st.rerun()
else:
st.markdown(
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: #ddd; border-radius: 10px; color: black;">{message["content"]}</div></div>',
unsafe_allow_html=True)
else:
st.session_state.messages = []
st.session_state.chat_messages = []
return st.session_state.messages
# Helper functions
def regenerate_answer(index):
st.session_state.messages.pop()
st.session_state.chat_messages.pop()
st.rerun()
def delete_conversation(index):
st.session_state.messages.pop(index)
st.session_state.messages.pop(index - 1)
st.session_state.chat_messages.pop(index)
st.session_state.chat_messages.pop(index - 1)
st.rerun()
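# Sidebar: model selection
st.sidebar.title("Model Settings")
st.sidebar.text("[Note] Because of training-data bias, adding context memory\nmakes multi-turn dialogue (vs. single-turn) prone to degraded quality")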
st.session_state.history_chat_num = st.sidebar.slider("Number of Historical Dialogues", 0, 6, 0, step=2)
# st.session_state.history_chat_num = 0
st.session_state.max_new_tokens = st.sidebar.slider("Max Sequence Length", 256, 8192, 8192, step=1)
st.session_state.top_p = st.sidebar.slider("Top-P", 0.8, 0.99, 0.85, step=0.01)
st.session_state.temperature = st.sidebar.slider("Temperature", 0.6, 1.2, 0.85, step=0.01)
# Model path mapping
MODEL_PATHS = {
"MiniMind2-R1 (0.1B)": ["../MiniMind2-R1", "MiniMind2-R1"],
"MiniMind2-Small-R1 (0.02B)": ["../MiniMind2-Small-R1", "MiniMind2-Small-R1"],
"MiniMind2 (0.1B)": ["../MiniMind2", "MiniMind2"],
"MiniMind2-MoE (0.15B)": ["../MiniMind2-MoE", "MiniMind2-MoE"],
"MiniMind2-Small (0.02B)": ["../MiniMind2-Small", "MiniMind2-Small"],
"MiniMind-V1 (0.1B)": ["../minimind-v1", "MiniMind-V1"],
"MiniMind-V1-MoE (0.1B)": ["../minimind-v1-moe", "MiniMind-V1-MoE"],
"MiniMind-V1-Small (0.02B)": ["../minimind-v1-small", "MiniMind-V1-Small"],
}
selected_model = st.sidebar.selectbox('Models', list(MODEL_PATHS.keys()), index=2) # default: MiniMind2
model_path = MODEL_PATHS[selected_model][0]
slogan = f"Hi, I'm {MODEL_PATHS[selected_model][1]}"
image_url = "https://www.modelscope.cn/api/v1/studio/gongjy/MiniMind/repo?Revision=master&FilePath=images%2Flogo2.png&View=true"
st.markdown(
f'<div style="display: flex; flex-direction: column; align-items: center; text-align: center; margin: 0; padding: 0;">'
'<div style="font-style: italic; font-weight: 900; margin: 0; padding-top: 4px; display: flex; align-items: center; justify-content: center; flex-wrap: wrap; width: 100%;">'
f'<img src="{image_url}" style="width: 45px; height: 45px; "> '
f'<span style="font-size: 26px; margin-left: 10px;">{slogan}</span>'
'</div>'
'<span style="color: #bbb; font-style: italic; margin-top: 6px; margin-bottom: 10px;">Content AI-generated, please discern with care</span>'
'</div>',
unsafe_allow_html=True
)
def setup_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def main():
model, tokenizer = load_model_tokenizer(model_path)
# Initialize the message lists
if "messages" not in st.session_state:
st.session_state.messages = []
st.session_state.chat_messages = []
# Use session state messages
messages = st.session_state.messages
# Render the message history
for i, message in enumerate(messages):
if message["role"] == "assistant":
with st.chat_message("assistant", avatar=image_url):
st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
if st.button("×", key=f"delete_{i}"):
# Delete this exchange (the preceding user message and this reply) and everything after it
st.session_state.messages = st.session_state.messages[:i - 1]
st.session_state.chat_messages = st.session_state.chat_messages[:i - 1]
st.rerun()
else:
st.markdown(
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: gray; border-radius: 10px; color:white; ">{message["content"]}</div></div>',
unsafe_allow_html=True)
# Handle new input or a regeneration request
prompt = st.chat_input(key="input", placeholder="Send a message to MiniMind")
# Check whether a regeneration was requested
if hasattr(st.session_state, 'regenerate') and st.session_state.regenerate:
prompt = st.session_state.last_user_message
regenerate_index = st.session_state.regenerate_index # position to regenerate from
# Clear all regeneration-related state
delattr(st.session_state, 'regenerate')
delattr(st.session_state, 'last_user_message')
delattr(st.session_state, 'regenerate_index')
if prompt:
st.markdown(
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: gray; border-radius: 10px; color:white; ">{prompt}</div></div>',
unsafe_allow_html=True)
messages.append({"role": "user", "content": prompt})
st.session_state.chat_messages.append({"role": "user", "content": prompt})
with st.chat_message("assistant", avatar=image_url):
placeholder = st.empty()
random_seed = random.randint(0, 2 ** 32 - 1)
setup_seed(random_seed)
st.session_state.chat_messages = system_prompt + st.session_state.chat_messages[
-(st.session_state.history_chat_num + 1):]
new_prompt = tokenizer.apply_chat_template(
st.session_state.chat_messages,
tokenize=False,
add_generation_prompt=True
)[-(st.session_state.max_new_tokens - 1):]
x = torch.tensor(tokenizer(new_prompt)['input_ids'], device=device).unsqueeze(0)
with torch.no_grad():
res_y = model.generate(x, tokenizer.eos_token_id, max_new_tokens=st.session_state.max_new_tokens,
temperature=st.session_state.temperature,
top_p=st.session_state.top_p, stream=True)
try:
for y in res_y:
answer = tokenizer.decode(y[0].tolist(), skip_special_tokens=True)
if (answer and answer[-1] == '\ufffd') or not answer:
continue
placeholder.markdown(process_assistant_content(answer), unsafe_allow_html=True)
except StopIteration:
print("No answer")
assistant_answer = answer.replace(new_prompt, "")
messages.append({"role": "assistant", "content": assistant_answer})
st.session_state.chat_messages.append({"role": "assistant", "content": assistant_answer})
with st.empty():
if st.button("×", key=f"delete_{len(messages) - 1}"):
st.session_state.messages = st.session_state.messages[:-2]
st.session_state.chat_messages = st.session_state.chat_messages[:-2]
st.rerun()
if __name__ == "__main__":
from transformers import AutoModelForCausalLM, AutoTokenizer
main()

97
test_real_rope.py Normal file

@ -0,0 +1,97 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Test the real-valued version of the positional encoding
"""
import torch
from model.model import precompute_pos_cis, precompute_pos_cis_real, apply_rotary_emb, apply_rotary_emb_real
from model.LMConfig import LMConfig
from model.model import MiniMindLM
def test_pos_encoding_equivalence():
"""Check that the complex and real-valued positional encodings are equivalent."""
print("Testing positional-encoding equivalence...")
# Parameters
dim = 64
seq_len = 10
# Build the complex-valued positional encoding
pos_cis = precompute_pos_cis(dim=dim, end=seq_len)
# Build the real-valued positional encoding
pos_cis_real = precompute_pos_cis_real(dim=dim, end=seq_len)
# Create random queries and keys
batch_size = 2
n_heads = 4
head_dim = dim
xq = torch.randn(batch_size, seq_len, n_heads, head_dim)
xk = torch.randn(batch_size, seq_len, n_heads, head_dim)
# Apply the complex-valued rotary embedding
xq_complex, xk_complex = apply_rotary_emb(xq, xk, pos_cis)
# Apply the real-valued rotary embedding
xq_real, xk_real = apply_rotary_emb_real(xq, xk, pos_cis_real)
# Compute the differences
q_diff = torch.abs(xq_complex - xq_real).mean().item()
k_diff = torch.abs(xk_complex - xk_real).mean().item()
print(f"Query difference: {q_diff:.6f}")
print(f"Key difference: {k_diff:.6f}")
# Check that the differences are within tolerance
tolerance = 1e-5
if q_diff < tolerance and k_diff < tolerance:
print("✅ Test passed: the complex and real-valued positional encodings are numerically equivalent")
else:
print("❌ Test failed: the complex and real-valued positional encodings differ significantly")
def test_model_forward():
"""Test the model forward pass."""
print("\nTesting the model forward pass...")
# Build the model config
config = LMConfig(
dim=128,
n_layers=2,
n_heads=4,
n_kv_heads=4, # set n_kv_heads explicitly; n_heads must be divisible by n_kv_heads
vocab_size=1000,
max_seq_len=128,
disable_db=True # disable the database feature to avoid extra complexity
)
# Create the model
try:
model = MiniMindLM(config)
print(f"✅ Model initialized successfully")
except Exception as e:
print(f"❌ Model initialization failed: {str(e)}")
return
# Create the input
batch_size = 2
seq_len = 10
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
# Forward pass
try:
with torch.no_grad():
outputs = model(input_ids)
print(f"✅ Model forward pass succeeded")
print(f"Output shape: {outputs.logits.shape}")
except Exception as e:
print(f"❌ Model forward pass failed: {str(e)}")
if __name__ == "__main__":
# Test positional-encoding equivalence
test_pos_encoding_equivalence()
# Test the model forward pass
test_model_forward()
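
The equivalence that `test_pos_encoding_equivalence` checks can be illustrated without torch. The sketch below rotates one vector both ways using the standard RoPE formulation; it is not the repository's `precompute_pos_cis_real` or `apply_rotary_emb_real` (their source is not shown here). The function names, the pairing of adjacent elements, and the base `theta=10000.0` are assumptions made for illustration.

```python
import cmath
import math


def rope_complex(x, pos, theta=10000.0):
    """Rotate adjacent pairs of x by multiplying complex numbers (illustrative)."""
    dim = len(x)
    out = []
    for i in range(dim // 2):
        angle = pos * theta ** (-2 * i / dim)
        z = complex(x[2 * i], x[2 * i + 1]) * cmath.exp(1j * angle)
        out += [z.real, z.imag]
    return out


def rope_real(x, pos, theta=10000.0):
    """Same rotation written with real cos/sin arithmetic only (illustrative)."""
    dim = len(x)
    out = []
    for i in range(dim // 2):
        angle = pos * theta ** (-2 * i / dim)
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[2 * i], x[2 * i + 1]
        # (a + bi)(c + si) = (ac - bs) + (as + bc)i
        out += [a * c - b * s, a * s + b * c]
    return out


vec = [0.5, -1.0, 2.0, 0.25]
max_diff = max(abs(p - q) for p, q in zip(rope_complex(vec, 7), rope_real(vec, 7)))
print(max_diff)
```

The real-valued form avoids complex tensors entirely, which is a common motivation for such a rewrite when an export or compilation target lacks complex-number support.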

215
train_distill_reason.py Normal file

@ -0,0 +1,215 @@
import os
import platform
import argparse
import time
import math
import warnings
import pandas as pd
import torch
import torch.nn.functional as F
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import SFTDataset
warnings.filterwarnings('ignore')
def Logger(content):
if not ddp or dist.get_rank() == 0:
print(content)
def get_lr(current_step, total_steps, lr):
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
def train_epoch(epoch, wandb):
# Token ids of the think/answer tags
start_of_think_ids = tokenizer('<think>').input_ids
end_of_think_ids = tokenizer('</think>').input_ids
start_of_answer_ids = tokenizer('<answer>').input_ids
end_of_answer_ids = tokenizer('</answer>').input_ids
loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
for step, (X, Y, loss_mask) in enumerate(train_loader):
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with ctx:
res = model(X)
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
).view(Y.size())
sp_ids = torch.isin(Y.view(-1),
torch.tensor(start_of_think_ids + end_of_think_ids
+ start_of_answer_ids + end_of_answer_ids
).to(args.device))
# Apply extra loss weight at the positions flagged by sp_ids
loss_mask = loss_mask.view(-1)
loss_mask_sum = loss_mask.sum()
loss_mask[sp_ids] = 10
loss_mask = loss_mask.view(Y.size())
loss = (loss * loss_mask).sum() / loss_mask_sum
loss += res.aux_loss
loss = loss / args.accumulation_steps
scaler.scale(loss).backward()
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0:
spend_time = time.time() - start_time
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch + 1,
args.epochs,
step,
iter_per_epoch,
loss.item(),
optimizer.param_groups[-1]['lr'],
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
wandb.log({"loss": loss,
"lr": optimizer.param_groups[-1]['lr'],
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
model.eval()
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/reason_{lm_config.dim}{moe_path}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save(state_dict, ckp)
model.train()
def init_model(lm_config):
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
model = MiniMindLM(lm_config)
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'./out/rlhf_{lm_config.dim}{moe_path}.pth'
state_dict = torch.load(ckp, map_location=args.device)
model.load_state_dict(state_dict, strict=False)
Logger(f'Total trainable LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
model = model.to(args.device)
return model, tokenizer
def init_distributed_mode():
if not ddp: return
global ddp_local_rank, DEVICE
dist.init_process_group(backend="nccl")
ddp_rank = int(os.environ["RANK"])
ddp_local_rank = int(os.environ["LOCAL_RANK"])
ddp_world_size = int(os.environ["WORLD_SIZE"])
DEVICE = f"cuda:{ddp_local_rank}"
torch.cuda.set_device(DEVICE)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Distill Reasoning")
parser.add_argument("--out_dir", type=str, default="out")
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--learning_rate", type=float, default=1e-6)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--ddp", action="store_true")
parser.add_argument("--accumulation_steps", type=int, default=1)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--warmup_iters", type=int, default=0)
parser.add_argument("--log_interval", type=int, default=1)
parser.add_argument("--save_interval", type=int, default=50)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--dim', default=512, type=int)
parser.add_argument('--n_layers', default=8, type=int)
parser.add_argument('--max_seq_len', default=1024, type=int)
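parser.add_argument('--use_moe', default=False, type=lambda s: s.lower() in ('1', 'true', 'yes'))  # type=bool would turn any non-empty string, including "False", into True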
parser.add_argument("--data_path", type=str, default="./dataset/r1_mix_1024.jsonl")
args = parser.parse_args()
lm_config = LMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe)
args.save_dir = os.path.join(args.out_dir)
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)
tokens_per_iter = args.batch_size * lm_config.max_seq_len
device_type = "cuda" if "cuda" in args.device else "cpu"
args.wandb_run_name = f"MiniMind-Distill-Reasoning-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
ddp_local_rank, DEVICE = 0, "cuda:0"
base_seed = 1337
torch.manual_seed(base_seed)
torch.cuda.manual_seed(base_seed)
if ddp:
init_distributed_mode()
args.device = torch.device(DEVICE)
rank = dist.get_rank()
torch.manual_seed(base_seed + rank)
# Also seed the CUDA RNG
torch.cuda.manual_seed(base_seed + rank)
if args.use_wandb and (not ddp or ddp_local_rank == 0):
import wandb
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
else:
wandb = None
model, tokenizer = init_model(lm_config)
train_ds = SFTDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_sampler = DistributedSampler(train_ds) if ddp else None
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=True,
drop_last=False,
shuffle=False,
num_workers=args.num_workers,
sampler=train_sampler
)
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
if ddp:
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
iter_per_epoch = len(train_loader)
for epoch in range(args.epochs):
train_epoch(epoch, wandb)
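
A short worked example of the `get_lr` schedule used by the training scripts, copied as-is and evaluated at a few steps. The step counts and the base rate `1e-3` are arbitrary illustration values, not settings from the repository.

```python
import math


def get_lr(current_step, total_steps, lr):
    # Cosine decay plus a constant floor of lr / 10 (same formula as the scripts).
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))


base_lr = 1e-3
start = get_lr(0, 100, base_lr)     # cosine term at its maximum
middle = get_lr(50, 100, base_lr)   # cosine term at half
end = get_lr(100, 100, base_lr)     # cosine term at zero, only the floor remains
print(start, middle, end)
```

Because of the added floor, the schedule starts at 1.1 times the nominal rate and ends at 0.1 times it rather than decaying to zero; there is no warmup phase in this formula, so the `--warmup_iters` argument does not affect it.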

263
train_distillation.py Normal file

@ -0,0 +1,263 @@
import os
import argparse
import time
import math
import warnings
import pandas as pd
import torch
import torch.nn.functional as F
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import SFTDataset
warnings.filterwarnings('ignore')
def Logger(content):
if not ddp or dist.get_rank() == 0:
print(content)
def get_lr(current_step, total_steps, lr):
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
with torch.no_grad():
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).detach()
student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
kl = F.kl_div(
student_log_probs,
teacher_probs,
reduction=reduction
)
return (temperature ** 2) * kl
def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
start_time = time.time()
if teacher_model is not None:
teacher_model.eval()
teacher_model.requires_grad_(False)
for step, (X, Y, loss_mask) in enumerate(train_loader):
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
lr = get_lr(epoch * iter_per_epoch + step,
args.epochs * iter_per_epoch,
args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# Forward pass (student model)
with ctx:
res = model(X)
student_logits = res.logits
# Teacher forward pass (eval mode, under no_grad only)
if teacher_model is not None:
with torch.no_grad():
teacher_logits = teacher_model(X).logits
vocab_size_student = student_logits.size(-1) # N
teacher_logits = teacher_logits[..., :vocab_size_student]
# ========== Compute the loss ==========
# 1) Ground-truth CE loss (optional)
loss_mask_flat = loss_mask.view(-1)
ce_loss = F.cross_entropy(
student_logits.view(-1, student_logits.size(-1)),
Y.view(-1),
ignore_index=0,
reduction='none'
)
ce_loss = torch.sum(ce_loss * loss_mask_flat) / loss_mask_flat.sum()
if lm_config_student.use_moe:
ce_loss += res.aux_loss
# 2) Distillation loss (optional)
if teacher_model is not None:
# Distill only at valid token positions
distill_loss = distillation_loss_fn(
student_logits.view(-1, student_logits.size(-1))[loss_mask_flat == 1],
teacher_logits.view(-1, teacher_logits.size(-1))[loss_mask_flat == 1],
temperature=temperature
)
else:
distill_loss = torch.tensor(0.0, device=args.device)
# 3) Total loss = alpha * CE + (1 - alpha) * Distill
loss = alpha * ce_loss + (1 - alpha) * distill_loss
scaler.scale(loss).backward()
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0:
spend_time = time.time() - start_time
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.4f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch,
args.epochs - 1,
step,
iter_per_epoch,
loss.item(),
optimizer.param_groups[-1]['lr'],
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
)
)
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
wandb.log({
"loss": loss.item(),
"ce_loss": ce_loss.item(),
"distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
"lr": optimizer.param_groups[-1]['lr'],
"last-time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
})
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
model.eval()
moe_path = '_moe' if lm_config_student.use_moe else ''
ckp = f'{args.save_dir}/full_dist_{lm_config_student.dim}{moe_path}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save(state_dict, ckp)
model.train()
def init_student_model(lm_config):
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model = MiniMindLM(lm_config)
    moe_path = '_moe' if lm_config.use_moe else ''
    ckp = f'./out/full_sft_{lm_config.dim}{moe_path}.pth'
    state_dict = torch.load(ckp, map_location=args.device)
    model.load_state_dict(state_dict, strict=False)
    Logger(f'Student model (LLM) total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
    model = model.to(args.device)
    return model, tokenizer


def init_teacher_model(lm_config):
    model = MiniMindLM(lm_config)
    moe_path = '_moe' if lm_config.use_moe else ''
    ckp = f'./out/full_sft_{lm_config.dim}{moe_path}.pth'
    state_dict = torch.load(ckp, map_location=args.device)
    model.load_state_dict(state_dict, strict=False)
    Logger(f'Teacher model (LLM) total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
    model = model.to(args.device)
    return model


def init_distributed_mode():
    if not ddp:
        return
    global ddp_local_rank, DEVICE
    dist.init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    ddp_world_size = int(os.environ["WORLD_SIZE"])
    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Full SFT")
parser.add_argument("--out_dir", type=str, default="out")
parser.add_argument("--epochs", type=int, default=6)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--learning_rate", type=float, default=5e-6)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--ddp", action="store_true")
parser.add_argument("--accumulation_steps", type=int, default=1)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--warmup_iters", type=int, default=0)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--save_interval", type=int, default=100)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument("--data_path", type=str, default="./dataset/sft_data.jsonl")
args = parser.parse_args()
# Define the student and teacher model configs
lm_config_student = LMConfig(dim=512, n_layers=8, max_seq_len=512)
lm_config_teacher = LMConfig(dim=768, n_layers=16, max_seq_len=512)
max_seq_len = lm_config_student.max_seq_len
args.save_dir = os.path.join(args.out_dir)
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)
tokens_per_iter = args.batch_size * max_seq_len
device_type = "cuda" if "cuda" in args.device else "cpu"
args.wandb_run_name = f"MiniMind-Dist-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
ddp_local_rank, DEVICE = 0, "cuda:0"
base_seed = 1337
torch.manual_seed(base_seed)
torch.cuda.manual_seed(base_seed)
if ddp:
init_distributed_mode()
args.device = torch.device(DEVICE)
rank = dist.get_rank()
torch.manual_seed(base_seed + rank)
# Also seed the CUDA RNG
torch.cuda.manual_seed(base_seed + rank)
if args.use_wandb and (not ddp or ddp_local_rank == 0):
import wandb
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
else:
wandb = None
# Initialize the student and teacher models
model, tokenizer = init_student_model(lm_config_student)
teacher_model = init_teacher_model(lm_config_teacher)
train_ds = SFTDataset(args.data_path, tokenizer, max_length=max_seq_len)
train_sampler = DistributedSampler(train_ds) if ddp else None
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=True,
drop_last=False,
shuffle=False,
num_workers=args.num_workers,
sampler=train_sampler
)
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
if ddp:
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
iter_per_epoch = len(train_loader)
for epoch in range(args.epochs):
train_epoch(epoch, wandb)
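
The training loop above mixes a cross-entropy term with a distillation term as `alpha * CE + (1 - alpha) * Distill`. `distillation_loss_fn` is defined outside this section, so its exact form is not shown here. The sketch below assumes the common temperature-scaled KL divergence (teacher softmax against student softmax, multiplied by T squared) and works on plain Python lists for a single token position. The names `softmax`, `kd_loss` and `mix_loss` are illustrative and not part of the repository.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits, teacher_logits, temperature=1.0):
    """KL(teacher || student) at one token position, scaled by T^2.

    Assumed form; the repository's distillation_loss_fn may differ.
    """
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(
        pt * (math.log(pt) - math.log(ps))
        for pt, ps in zip(p_teacher, p_student)
        if pt > 0.0
    )
    return kl * temperature ** 2


def mix_loss(ce_loss, distill_loss, alpha):
    """Total loss = alpha * CE + (1 - alpha) * Distill, as in the loop above."""
    return alpha * ce_loss + (1 - alpha) * distill_loss
```

With identical student and teacher logits the KL term is zero, so the mixed loss reduces to `alpha * ce_loss`; setting `alpha = 1` disables distillation entirely.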

247
train_dpo.py Normal file

@ -0,0 +1,247 @@
import os
import platform
import argparse
import time
import math
import warnings
import pandas as pd
import torch
import torch.nn.functional as F
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import DPODataset
warnings.filterwarnings('ignore')
def Logger(content):
    if not ddp or dist.get_rank() == 0:
        print(content)


def get_lr(current_step, total_steps, lr):
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))


def logits_to_probs(logits, labels):
    # logits shape: (batch_size, seq_len, vocab_size)
    # labels shape: (batch_size, seq_len)
    # returns per-token log-probabilities, shape: (batch_size, seq_len)
    log_probs = F.log_softmax(logits, dim=2)
    probs = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1)
    return probs


def dpo_loss(ref_probs, probs, mask, beta):
    # ref_probs and probs both have shape (batch_size, seq_len)
    # https://github.com/jingyaogong/minimind/issues/298
    seq_lengths = mask.sum(dim=1, keepdim=True)  # (batch_size, 1)
    ref_probs = (ref_probs * mask).sum(dim=1) / seq_lengths.squeeze()
    probs = (probs * mask).sum(dim=1) / seq_lengths.squeeze()

    # Split the batch into its chosen (first half) and rejected (second half) parts
    batch_size = ref_probs.shape[0]
    chosen_ref_probs = ref_probs[:batch_size // 2]
    reject_ref_probs = ref_probs[batch_size // 2:]
    chosen_probs = probs[:batch_size // 2]
    reject_probs = probs[batch_size // 2:]

    pi_logratios = chosen_probs - reject_probs
    ref_logratios = chosen_ref_probs - reject_ref_probs
    logits = pi_logratios - ref_logratios
    loss = -F.logsigmoid(beta * logits)
    return loss.mean()


def train_epoch(epoch, wandb):
    start_time = time.time()
    for step, batch in enumerate(train_loader):
        x_chosen = batch['x_chosen'].to(args.device)
        x_rejected = batch['x_rejected'].to(args.device)
        y_chosen = batch['y_chosen'].to(args.device)
        y_rejected = batch['y_rejected'].to(args.device)
        mask_chosen = batch['mask_chosen'].to(args.device)
        mask_rejected = batch['mask_rejected'].to(args.device)
        # Stack chosen and rejected samples so one forward pass covers both
        x = torch.cat([x_chosen, x_rejected], dim=0)
        y = torch.cat([y_chosen, y_rejected], dim=0)
        mask = torch.cat([mask_chosen, mask_rejected], dim=0)

        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        with ctx:
            with torch.no_grad():
                ref_outputs = ref_model(x)
                ref_logits = ref_outputs.logits
            ref_probs = logits_to_probs(ref_logits, y)
            ref_probs = ref_probs * mask
            outputs = model(x)
            logits = outputs.logits
            probs = logits_to_probs(logits, y)
            probs = probs * mask
            loss = dpo_loss(ref_probs, probs, mask, beta=0.1)
            loss = loss / args.accumulation_steps

        scaler.scale(loss).backward()

        if (step + 1) % args.accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        if step % args.log_interval == 0:
            spend_time = time.time() - start_time
            Logger(
                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
                    epoch + 1,
                    args.epochs,
                    step,
                    iter_per_epoch,
                    loss.item(),
                    optimizer.param_groups[-1]['lr'],
                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))

            if (wandb is not None) and (not ddp or dist.get_rank() == 0):
                # Log a Python float rather than a tensor that still holds the graph
                wandb.log({"loss": loss.item(),
                           "lr": optimizer.param_groups[-1]['lr'],
                           "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})

        if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
            model.eval()
            moe_path = '_moe' if lm_config.use_moe else ''
            ckp = f'{args.save_dir}/rlhf_{lm_config.dim}{moe_path}.pth'
            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            torch.save(state_dict, ckp)
            model.train()


def init_model(lm_config):
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model = MiniMindLM(lm_config)
    moe_path = '_moe' if lm_config.use_moe else ''
    ckp = f'./out/full_sft_{lm_config.dim}{moe_path}.pth'
    state_dict = torch.load(ckp, map_location=args.device)
    model.load_state_dict(state_dict, strict=False)
    # Initialize the frozen reference model from the same checkpoint
    ref_model = MiniMindLM(lm_config)
    ref_model.load_state_dict(state_dict, strict=False)
    ref_model.eval()
    ref_model.requires_grad_(False)
    Logger(f'LLM total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
    model = model.to(args.device)
    ref_model = ref_model.to(args.device)
    return model, ref_model, tokenizer


def init_distributed_mode():
    if not ddp:
        return
    global ddp_local_rank, DEVICE
    dist.init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    ddp_world_size = int(os.environ["WORLD_SIZE"])
    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind RLHF")
    parser.add_argument("--out_dir", type=str, default="out")
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=8)
    # The SFT stage uses a learning rate of 5e-6 -> 5e-7 at length 512. For the offline
    # chosen/rejected preference-alignment stage, lr <= 1e-8 at length 3000 is recommended;
    # otherwise the model easily forgets what it learned and training degrades.
    parser.add_argument("--learning_rate", type=float, default=1e-8)
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_wandb", action="store_true")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-RLHF-SFT")
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--ddp", action="store_true")
    parser.add_argument("--accumulation_steps", type=int, default=1)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--warmup_iters", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=100)
    parser.add_argument("--save_interval", type=int, default=100)
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--dim', default=512, type=int)
    parser.add_argument('--n_layers', default=8, type=int)
    parser.add_argument('--max_seq_len', default=1024, type=int)
    # type=bool would treat any non-empty string (even "False") as True, so parse the text explicitly
    parser.add_argument('--use_moe', default=False, type=lambda s: s.lower() in ('1', 'true', 'yes'))
    parser.add_argument("--data_path", type=str, default="./dataset/dpo.jsonl")
    args = parser.parse_args()

    lm_config = LMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe)
    args.save_dir = os.path.join(args.out_dir)
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.out_dir, exist_ok=True)
    tokens_per_iter = args.batch_size * lm_config.max_seq_len
    device_type = "cuda" if "cuda" in args.device else "cpu"

    args.wandb_run_name = f"MiniMind-Full-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"

    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
    ddp_local_rank, DEVICE = 0, "cuda:0"

    base_seed = 1337
    torch.manual_seed(base_seed)
    torch.cuda.manual_seed(base_seed)

    if ddp:
        init_distributed_mode()
        args.device = torch.device(DEVICE)
        rank = dist.get_rank()
        torch.manual_seed(base_seed + rank)
        # Also seed the CUDA RNG
        torch.cuda.manual_seed(base_seed + rank)

    if args.use_wandb and (not ddp or ddp_local_rank == 0):
        import wandb
        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
    else:
        wandb = None

    model, ref_model, tokenizer = init_model(lm_config)

    train_ds = DPODataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if ddp else None
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=False,
        shuffle=False,
        num_workers=args.num_workers,
        sampler=train_sampler
    )

    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    if ddp:
        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])

    iter_per_epoch = len(train_loader)
    for epoch in range(args.epochs):
        train_epoch(epoch, wandb)
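
The `dpo_loss` function in this file averages per-token log-probabilities over each sequence, splits the batch into chosen and rejected halves, and applies `-logsigmoid(beta * (pi_logratio - ref_logratio))`. The following is a minimal scalar sketch of that last step for one preference pair, using only the standard library. The helper names `log_sigmoid` and `dpo_pair_loss` are hypothetical, and the inputs are assumed to be already length-averaged log-probabilities.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_pair_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    """DPO loss for one pair of length-averaged log-probabilities.

    pi_*  : log-probabilities under the policy being trained
    ref_* : log-probabilities under the frozen reference model
    """
    pi_logratio = pi_chosen - pi_rejected
    ref_logratio = ref_chosen - ref_rejected
    return -log_sigmoid(beta * (pi_logratio - ref_logratio))
```

When the policy and the reference assign identical log-probabilities, the argument of the sigmoid is zero and the loss equals log 2 (about 0.693). The loss falls below that value only once the policy prefers the chosen answer more strongly than the reference does.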

418
train_embedding.py Normal file

@ -0,0 +1,418 @@
import os
# Set environment variables
os.environ["WANDB_MODE"] = "offline"  # or use "dryrun"
import platform
import argparse
import time
import math
import warnings
import pandas as pd
import torch
import torch.distributed as dist
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, DistributedSampler, Dataset
from contextlib import nullcontext
import random
import numpy as np
import json
from transformers import AutoTokenizer
# Removed: from model.model import MiniMindLM
from model.LMConfig import LMConfig
# from model.dataset import PretrainDataset
warnings.filterwarnings('ignore')
# Define a Word2Vec-style CBOW model
class CBOWModel(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.embedding_dim = config.dim

        # Input embeddings (context words)
        self.embeddings = nn.Embedding(config.vocab_size, config.dim)

        # Output weights for target prediction
        self.output_weights = nn.Linear(config.dim, config.vocab_size, bias=False)

        # Initialize weights
        self.init_weights()

    def init_weights(self):
        # Xavier initialization for better convergence
        nn.init.xavier_uniform_(self.embeddings.weight)
        nn.init.xavier_uniform_(self.output_weights.weight)

    def forward(self, context_words):
        # context_words shape: [batch_size, context_size]; context_size may vary
        # Get embeddings for all context words
        embeds = self.embeddings(context_words)  # [batch_size, context_size, embedding_dim]

        # Average the context word embeddings along context dimension
        embeds = torch.mean(embeds, dim=1)  # [batch_size, embedding_dim]

        # Predict the target word
        output = self.output_weights(embeds)  # [batch_size, vocab_size]

        return output


# Word2Vec CBOW dataset
class CBOWDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512, window_size=5):
        super().__init__()
        self.tokenizer = tokenizer
        self.window_size = window_size
        self.max_length = max_length
        self.samples = self.load_data(data_path)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]

        # Build the input text
        text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Get the token ids
        input_ids = encoding.input_ids.squeeze()
        # Filter out padding
        attention_mask = encoding.attention_mask.squeeze()
        valid_indices = torch.where(attention_mask == 1)[0]
        valid_input_ids = input_ids[valid_indices]

        # Make sure there are enough tokens for CBOW training
        if len(valid_input_ids) <= 2 * self.window_size + 1:
            # Too few tokens: fall back to a randomly chosen sample
            return self.__getitem__(random.randint(0, len(self.samples) - 1))

        # Pick a random center position, excluding the special tokens at both ends,
        # with at least window_size tokens available on each side
        min_center_pos = self.window_size + 1  # skip the start token
        max_center_pos = len(valid_input_ids) - self.window_size - 1  # skip the end token

        if max_center_pos <= min_center_pos:
            return self.__getitem__(random.randint(0, len(self.samples) - 1))

        center_pos = random.randint(min_center_pos, max_center_pos)

        # Target word (the center word)
        target = valid_input_ids[center_pos].unsqueeze(0)

        # Context words (the tokens before and after the center word)
        context = torch.cat([
            valid_input_ids[center_pos - self.window_size:center_pos],
            valid_input_ids[center_pos + 1:center_pos + self.window_size + 1]
        ])

        return context, target


def Logger(content):
    # Print only when DDP is not in use, or on the DDP main process (rank 0)
    if not ddp or dist.get_rank() == 0:
        print(content)


def get_lr(current_step, total_steps, lr):
    # Compute the current learning rate (cosine decay with a floor of lr / 10)
    # \text{get\_lr}(c, t, l) = \frac{l}{10} + 0.5 \cdot l \cdot \left(1 + \cos\left(\frac{\pi \cdot c}{t}\right)\right)
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
def train_epoch(epoch, wandb):
loss_fct = nn.CrossEntropyLoss()
start_time = time.time()
total_loss = 0
total_samples = 0
for step, (context, target) in enumerate(train_loader):
try:
# Move the batch to the target device
context = context.to(args.device)
target = target.to(args.device)
# Update the learning rate
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with ctx:
# Forward pass
logits = model(context) # [batch_size, vocab_size]
# target is [batch_size, 1]; squeeze dim 1 to get the [batch_size] shape CrossEntropyLoss expects
loss = loss_fct(logits, target.squeeze(1))
loss = loss / args.accumulation_steps
# Print data types for debugging
if step == 0 and (not ddp or dist.get_rank() == 0):
Logger("---- Data Type Check ----")
Logger(f"context.dtype: {context.dtype}")
Logger(f"context.shape: {context.shape}")
Logger(f"target.dtype: {target.dtype}")
Logger(f"target.shape: {target.shape}")
if hasattr(model, 'module'): # DDP case
Logger(f"Model parameter dtype: {next(model.module.parameters()).dtype}")
else: # Non-DDP case
Logger(f"Model parameter dtype: {next(model.parameters()).dtype}")
Logger(f"logits.dtype: {logits.dtype}")
Logger(f"logits.shape: {logits.shape}")
Logger(f"loss.dtype: {loss.dtype}")
Logger("-------------------------")
scaler.scale(loss).backward()
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
total_loss += loss.item() * args.accumulation_steps
total_samples += 1
# Log progress
if step % args.log_interval == 0:
spend_time = time.time() - start_time
avg_loss = total_loss / total_samples if total_samples > 0 else 0
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch + 1,
args.epochs,
step,
iter_per_epoch,
avg_loss,
optimizer.param_groups[-1]['lr'],
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
wandb.log({"loss": avg_loss,
"lr": optimizer.param_groups[-1]['lr'],
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
except Exception as e:
print(f"Error occurred: {str(e)}")
import traceback
traceback.print_exc()
# Modified checkpoint path for error
save_path = f'{args.save_dir}/word2vec_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_ERROR.pth'
if os.path.exists(save_path):
os.remove(save_path)
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.embeddings.state_dict()
else:
state_dict = model.embeddings.state_dict()
torch.save(state_dict, save_path)
for name, param in model.named_parameters():
if param.grad is not None and torch.isnan(param.grad).any():
print(f"NaN gradient in parameter: {name}")
for name, param in model.named_parameters():
if param.grad is not None and torch.isnan(param.grad).any():
print(f"Parameter {name} values: {param.data}")
print(f"Parameter {name} gradients: {param.grad}")
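# Re-raise with the original cause; not every failure here is a NaN gradient
raise RuntimeError("Training step failed; an emergency checkpoint was saved") from e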
# Save model once at the end of each epoch
if not ddp or dist.get_rank() == 0:
model.eval()
ckp = f'{args.save_dir}/word2vec_embedding_dim{lm_config.dim}_vocab{lm_config.vocab_size}_epoch{epoch+1}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
embedding_state_dict = model.module.embeddings.state_dict()
else:
embedding_state_dict = model.embeddings.state_dict()
torch.save(embedding_state_dict, ckp)
Logger(f"Saved word2vec embedding for epoch {epoch+1} to {ckp}")
model.train()
def init_model(lm_config_params: LMConfig):
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
# Update vocab_size in lm_config if tokenizer has a different one
if tokenizer.vocab_size != lm_config_params.vocab_size:
Logger(f"Updating lm_config.vocab_size from {lm_config_params.vocab_size} to {tokenizer.vocab_size} based on tokenizer.")
lm_config_params.vocab_size = tokenizer.vocab_size
# Build the word2vec CBOW model
model = CBOWModel(lm_config_params).to(args.device)
# Report the number of trainable parameters
Logger(f'CBOW Model total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} Million')
return model, tokenizer
def init_distributed_mode():
    if not ddp:
        return  # DDP is not enabled: nothing to do
    global ddp_local_rank, DEVICE  # module-level globals, so callers can read them afterwards
    # Initialize the process group with the NCCL backend
    # (NVIDIA Collective Communications Library, optimized for GPU-to-GPU communication)
    dist.init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])  # global rank of this process
    ddp_local_rank = int(os.environ["LOCAL_RANK"])  # rank of this process on its node
    ddp_world_size = int(os.environ["WORLD_SIZE"])  # total number of processes in the group
    DEVICE = f"cuda:{ddp_local_rank}"  # pick the GPU by local rank
    torch.cuda.set_device(DEVICE)  # bind this process to that GPU


# torchrun --nproc_per_node 2 train_embedding.py
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Word2Vec Embedding Training")
parser.add_argument("--out_dir", type=str, default="out_word2vec")
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--learning_rate", type=float, default=5e-4)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Word2Vec-Training")
parser.add_argument("--num_workers", type=int, default=32)
parser.add_argument("--ddp", action="store_true")
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--save_interval", type=int, default=100)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--dim', default=768, type=int)
parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
parser.add_argument('--vocab_size', default=6400, type=int)
parser.add_argument('--window_size', default=5, type=int)
args = parser.parse_args()
# Create LMConfig with relevant parameters for embedding
lm_config = LMConfig(
dim=args.dim,
vocab_size=args.vocab_size, # Will be updated by tokenizer
max_seq_len=args.max_seq_len,
n_layers=1, # Minimal
n_heads=1, # Minimal
n_kv_heads=1 #Minimal
)
args.save_dir = os.path.join(args.out_dir)
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)
tokens_per_iter = args.batch_size * lm_config.max_seq_len
print(f"tokens_per_iter: {tokens_per_iter}")
device_type = "cuda" if "cuda" in args.device else "cpu"
# Determine the torch dtype
pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
args.wandb_run_name = f"MiniMind-Word2Vec-Dim-{args.dim}-Vocab-{lm_config.vocab_size}-Window-{args.window_size}"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
ddp_local_rank, DEVICE = 0, "cuda:0" # Default values, will be overwritten in DDP
base_seed = 1337
torch.manual_seed(base_seed)
torch.cuda.manual_seed(base_seed)
if ddp:
init_distributed_mode() # This sets DEVICE and ddp_local_rank
args.device = torch.device(DEVICE) # Ensure args.device is updated
rank = dist.get_rank()
torch.manual_seed(base_seed + rank)
# Also seed the CUDA RNG
torch.cuda.manual_seed_all(base_seed + rank) # Use seed_all for DDP
if args.use_wandb and (not ddp or dist.get_rank() == 0): # Check rank for DDP wandb init
import wandb
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args)
else:
wandb = None
model, tokenizer = init_model(lm_config) # Pass the lm_config instance
# Update lm_config vocab_size again after tokenizer to ensure consistency for save path name
if lm_config.vocab_size != tokenizer.vocab_size:
lm_config.vocab_size = tokenizer.vocab_size
args.wandb_run_name = f"MiniMind-Word2Vec-Dim-{args.dim}-Vocab-{lm_config.vocab_size}-Window-{args.window_size}"
if wandb is not None and (not ddp or dist.get_rank() == 0):
wandb.config.update({'vocab_size': lm_config.vocab_size, 'wandb_run_name': args.wandb_run_name}, allow_val_change=True)
# Collate function that handles context sequences of different lengths
def collate_cbow_batch(batch):
    # Split the batch into contexts and targets
    contexts, targets = zip(*batch)
    # Length of the longest context in this batch
    max_len = max([ctx.size(0) for ctx in contexts])
    # Allocate the zero-padded tensor
    padded_contexts = torch.zeros(len(contexts), max_len, dtype=torch.long)
    # Copy each context into its row
    for i, ctx in enumerate(contexts):
        ctx_len = ctx.size(0)
        padded_contexts[i, :ctx_len] = ctx
    # Stack the targets into one tensor
    stacked_targets = torch.stack(targets)
    return padded_contexts, stacked_targets
# Create Word2Vec CBOW dataset
train_ds = CBOWDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len, window_size=args.window_size)
train_sampler = DistributedSampler(train_ds, shuffle=True, seed=base_seed) if ddp else None
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=True,
drop_last=True,
shuffle=(train_sampler is None),
num_workers=args.num_workers,
sampler=train_sampler,
collate_fn=collate_cbow_batch
)
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
if ddp:
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
iter_per_epoch = len(train_loader)
Logger(f"Starting Word2Vec CBOW training for {args.epochs} epochs with {iter_per_epoch} iterations per epoch.")
for epoch in range(args.epochs):
if ddp:
train_sampler.set_epoch(epoch)
train_epoch(epoch, wandb)
if wandb is not None and (not ddp or dist.get_rank() == 0):
wandb.finish()
Logger("Word2Vec embedding training finished.")
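
`CBOWDataset.__getitem__` above picks a center token and gathers `window_size` tokens on each side as context. The sketch below reproduces just that index arithmetic on a plain Python list so the bounds are easy to check. The function name `cbow_pair` and its `center_pos` and `rng` parameters are illustrative additions, and returning `None` stands in for the dataset's fallback of resampling another example.

```python
import random


def cbow_pair(token_ids, window_size, center_pos=None, rng=random):
    """Return (context, target) for one CBOW sample, or None if too short.

    Uses the same bounds as the dataset: the center stays at least
    window_size + 1 positions from the start and never lands on the
    last token.
    """
    n = len(token_ids)
    min_center = window_size + 1
    max_center = n - window_size - 1
    if n <= 2 * window_size + 1 or max_center <= min_center:
        return None
    if center_pos is None:
        center_pos = rng.randint(min_center, max_center)
    # window_size tokens before the center, then window_size tokens after it
    context = (token_ids[center_pos - window_size:center_pos]
               + token_ids[center_pos + 1:center_pos + window_size + 1])
    return context, token_ids[center_pos]
```

Because every valid center has a full window on both sides, each context holds exactly `2 * window_size` tokens, which is why the padding branch of `collate_cbow_batch` is a safeguard rather than the common path.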

214
train_full_sft.py Normal file

@ -0,0 +1,214 @@
import os
# Set environment variables
os.environ["WANDB_MODE"] = "offline"  # or use "dryrun"
import platform
import argparse
import time
import math
import warnings
import pandas as pd
import torch
import torch.nn.functional as F
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import SFTDataset
warnings.filterwarnings('ignore')
# Logging helper: prints training information on the main process only.
def Logger(content):
    if not ddp or dist.get_rank() == 0:
        print(content)


# Learning-rate schedule: cosine decay from 1.1 * lr down to a floor of lr / 10.
def get_lr(current_step, total_steps, lr):
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))


# Train the model for one epoch.
def train_epoch(epoch, wandb):
    loss_fct = nn.CrossEntropyLoss(reduction='none')  # unreduced cross-entropy, so it can be masked per token
    start_time = time.time()
    for step, (X, Y, loss_mask) in enumerate(train_loader):
        # Move the batch to the target device.
        X = X.to(args.device)
        Y = Y.to(args.device)
        loss_mask = loss_mask.to(args.device)
        # Compute the current learning rate.
        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
        # Apply it to every parameter group.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        with ctx:
            res = model(X)  # forward pass
            loss = loss_fct(
                res.logits.view(-1, res.logits.size(-1)),
                Y.view(-1)
            ).view(Y.size())  # per-token loss

            # Average the loss over the unmasked tokens only
            loss = (loss * loss_mask).sum() / loss_mask.sum()
            loss += res.aux_loss
            loss = loss / args.accumulation_steps

        # Mixed precision: scale the loss so low-precision (e.g. FP16) gradients do not underflow.
        scaler.scale(loss).backward()

        if (step + 1) % args.accumulation_steps == 0:
            # Part of PyTorch AMP: undo the scaling that was applied to avoid underflow.
            scaler.unscale_(optimizer)
            # Clip gradients so their norm does not exceed args.grad_clip (guards against exploding gradients).
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            # Optimizer step, routed through the scaler for mixed-precision training.
            scaler.step(optimizer)
            # Adjust the scale factor for the next iteration depending on whether gradients overflowed.
            scaler.update()
            # Clear the gradients.
            optimizer.zero_grad(set_to_none=True)

        # Log progress every log_interval steps.
        if step % args.log_interval == 0:
            spend_time = time.time() - start_time
            Logger(
                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
                    epoch + 1,
                    args.epochs,
                    step,
                    iter_per_epoch,
                    loss.item(),
                    optimizer.param_groups[-1]['lr'],
                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))

            if (wandb is not None) and (not ddp or dist.get_rank() == 0):
                # Log a Python float rather than a tensor that still holds the graph
                wandb.log({"loss": loss.item(),
                           "lr": optimizer.param_groups[-1]['lr'],
                           "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})

        if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
            model.eval()
            moe_path = '_moe' if lm_config.use_moe else ''
            ckp = f'{args.save_dir}/full_sft_{lm_config.dim}{moe_path}.pth'
            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            torch.save(state_dict, ckp)
            model.train()


# Build the model and load the pretrained checkpoint.
def init_model(lm_config):
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model = MiniMindLM(lm_config)
    moe_path = '_moe' if lm_config.use_moe else ''
    ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
    state_dict = torch.load(ckp, map_location=args.device)
    model.load_state_dict(state_dict, strict=False)
    Logger(f'LLM total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
    model = model.to(args.device)
    return model, tokenizer


# Set up the process group and device for a DDP run.
def init_distributed_mode():
    if not ddp:
        return
    global ddp_local_rank, DEVICE
    dist.init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    ddp_world_size = int(os.environ["WORLD_SIZE"])
    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Full SFT")
parser.add_argument("--out_dir", type=str, default="out")
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--learning_rate", type=float, default=5e-5)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
parser.add_argument("--dtype", type=str, default="bfloat16")
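# default=True with store_true could never be switched off; BooleanOptionalAction (Python 3.9+) adds --no-use_wandb
parser.add_argument("--use_wandb", default=True, action=argparse.BooleanOptionalAction)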
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--ddp", action="store_true")
parser.add_argument("--accumulation_steps", type=int, default=1)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--warmup_iters", type=int, default=0)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--save_interval", type=int, default=100)
parser.add_argument('--local_rank', type=int, default=-1)
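parser.add_argument('--dim', default=1024, type=int)  # model dimension; controls the model size
parser.add_argument('--n_layers', default=24, type=int)  # number of layers
parser.add_argument('--max_seq_len', default=1024, type=int)  # maximum input sequence length
# type=bool would treat any non-empty string (even "False") as True, so parse the text explicitly
parser.add_argument('--use_moe', default=False, type=lambda s: s.lower() in ('1', 'true', 'yes'))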
parser.add_argument("--data_path", type=str, default="./dataset/sft_1024.jsonl")
args = parser.parse_args()
lm_config = LMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe)
args.save_dir = os.path.join(args.out_dir)
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)
tokens_per_iter = args.batch_size * lm_config.max_seq_len
device_type = "cuda" if "cuda" in args.device else "cpu"
args.wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
ddp_local_rank, DEVICE = 0, "cuda:0"
base_seed = 1337
torch.manual_seed(base_seed)
torch.cuda.manual_seed(base_seed)
# 如果使用分布式模式,则初始化分布式模式。
if ddp:
init_distributed_mode()
args.device = torch.device(DEVICE)
rank = dist.get_rank()
torch.manual_seed(base_seed + rank)
# 同时设置 CUDA 的随机种子
torch.cuda.manual_seed(base_seed + rank)
# 如果使用WandB则初始化WandB。
if args.use_wandb and (not ddp or ddp_local_rank == 0):
import wandb
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
else:
wandb = None
# 初始化模型。
model, tokenizer = init_model(lm_config)
# 初始化数据集。
train_ds = SFTDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_sampler = DistributedSampler(train_ds) if ddp else None
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=True,
drop_last=False,
shuffle=False,
num_workers=args.num_workers,
sampler=train_sampler
)
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16'])) #创建一个梯度缩放器(GradScaler),用于混合精度训练。当模型使用半精度格式(float16或bfloat16)训练时启用,它帮助防止梯度下溢并提高训练效率。
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) # 创建AdamW优化器实例负责更新模型参数。它接收模型的所有参数和指定的学习率作为输入。AdamW是Adam优化器的变体增加了权重衰减的正则化。
if ddp:
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
iter_per_epoch = len(train_loader)
for epoch in range(args.epochs):
train_epoch(epoch, wandb)
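
Boolean command-line options in these scripts deserve some care. With argparse, `type=bool` treats any non-empty string (including "False") as True, and `action="store_true"` combined with `default=True` can never be switched off. The following is a minimal standard-library sketch of one possible alternative; `str2bool` and `build_parser` are hypothetical helper names that do not exist in this repository, and `argparse.BooleanOptionalAction` requires Python 3.9 or newer.

```python
import argparse


def str2bool(value: str) -> bool:
    """Parse common true/false spellings (hypothetical helper, not in the repo)."""
    lowered = value.strip().lower()
    if lowered in {"1", "true", "yes", "y", "on"}:
        return True
    if lowered in {"0", "false", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="flag-parsing sketch")
    # BooleanOptionalAction registers both --use_wandb and --no-use_wandb,
    # so a default of True can actually be turned off.
    parser.add_argument("--use_wandb", action=argparse.BooleanOptionalAction, default=True)
    # str2bool parses the text instead of calling bool() on it.
    parser.add_argument("--use_moe", type=str2bool, default=False)
    return parser


parser = build_parser()
defaults = parser.parse_args([])
overridden = parser.parse_args(["--no-use_wandb", "--use_moe", "False"])
print(defaults.use_wandb, defaults.use_moe)
print(overridden.use_wandb, overridden.use_moe)
print(bool("False"))  # the pitfall: a non-empty string is truthy
```

With this shape, `--use_moe False` yields False, whereas `type=bool` would have produced True for the same command line.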

201
train_lora.py Normal file

@ -0,0 +1,201 @@
import os
import platform
import argparse
import random
import time
import math
import warnings
import torch
from torch import optim, nn
import torch.distributed as dist
from contextlib import nullcontext
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import SFTDataset
from model.model_lora import *
warnings.filterwarnings('ignore')
# Logger function: print only on the main process (or when DDP is not in use)
def Logger(content):
    if not ddp or dist.get_rank() == 0:
        print(content)


def get_lr(current_step, total_steps, lr):
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))


# The code is "almost" identical to full_sft
def train_epoch(epoch, wandb):
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    start_time = time.time()
    for step, (X, Y, loss_mask) in enumerate(train_loader):
        X = X.to(args.device)
        Y = Y.to(args.device)
        loss_mask = loss_mask.to(args.device)
        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        with ctx:
            res = model(X)
            loss = loss_fct(
                res.logits.view(-1, res.logits.size(-1)),
                Y.view(-1)
            ).view(Y.size())
            # Average the per-token loss over the positions selected by loss_mask
            loss = (loss * loss_mask).sum() / loss_mask.sum()
            loss += res.aux_loss
            loss = loss / args.accumulation_steps
        scaler.scale(loss).backward()
        if (step + 1) % args.accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
        if step % args.log_interval == 0:
            spend_time = time.time() - start_time
            Logger(
                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
                    epoch + 1,
                    args.epochs,
                    step,
                    iter_per_epoch,
                    loss.item() * args.accumulation_steps,  # undo the accumulation scaling for display
                    optimizer.param_groups[-1]['lr'],
                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
            if (wandb is not None) and (not ddp or dist.get_rank() == 0):
                wandb.log({"loss": loss.item() * args.accumulation_steps,
                           "lr": optimizer.param_groups[-1]['lr'],
                           "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
        if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
            model.eval()
            # [Difference 1] Only the LoRA weights need to be saved
            save_lora(model, f'{args.save_dir}/lora/{args.lora_name}_{lm_config.dim}.pth')
            model.train()
def init_model(lm_config):
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model = MiniMindLM(lm_config)
    moe_path = '_moe' if lm_config.use_moe else ''
    ckp = f'./out/rlhf_{lm_config.dim}{moe_path}.pth'
    state_dict = torch.load(ckp, map_location=args.device)
    model.load_state_dict(state_dict, strict=False)
    return model.to(args.device), tokenizer


def init_distributed_mode():
    if not ddp:
        return
    global ddp_local_rank, DEVICE
    dist.init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    ddp_world_size = int(os.environ["WORLD_SIZE"])
    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind SFT with LoRA")
    parser.add_argument("--out_dir", type=str, default="out")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_wandb", action="store_true")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA-SFT")
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--ddp", action="store_true")
    parser.add_argument("--accumulation_steps", type=int, default=1)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--warmup_iters", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=100)
    parser.add_argument("--save_interval", type=int, default=1)
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--dim', default=512, type=int)
    parser.add_argument('--n_layers', default=8, type=int)
    parser.add_argument('--max_seq_len', default=512, type=int)
    # type=bool would treat any non-empty string (even "False") as True, so parse the text explicitly
    parser.add_argument('--use_moe', default=False, type=lambda s: s.lower() in ('1', 'true', 'yes'))
    parser.add_argument("--data_path", type=str, default="./dataset/lora_identity.jsonl")
    parser.add_argument("--lora_name", type=str, default="lora_identity", help="Name the adapter after its task, e.g. lora_(english/medical/psychology...)")
    args = parser.parse_args()
    lm_config = LMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe)
    args.save_dir = os.path.join(args.out_dir)
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.out_dir, exist_ok=True)
    tokens_per_iter = args.batch_size * lm_config.max_seq_len
    device_type = "cuda" if "cuda" in args.device else "cpu"
    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
    ddp_local_rank, DEVICE = 0, "cuda:0"
    base_seed = 1337
    torch.manual_seed(base_seed)
    torch.cuda.manual_seed(base_seed)
    if ddp:
        init_distributed_mode()
        args.device = torch.device(DEVICE)
        rank = dist.get_rank()
        torch.manual_seed(base_seed + rank)
        # Seed CUDA as well
        torch.cuda.manual_seed(base_seed + rank)
    args.wandb_run_name = f"MiniMind-Lora-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
    if args.use_wandb and (not ddp or ddp_local_rank == 0):
        import wandb
        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
    else:
        wandb = None
    model, tokenizer = init_model(lm_config)
    apply_lora(model)
    total_params = sum(p.numel() for p in model.parameters())  # total parameter count
    lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name)  # LoRA parameter count
    if not ddp or dist.get_rank() == 0:
        print(f"Total LLM parameters: {total_params}")
        print(f"LoRA parameters: {lora_params_count}")
        print(f"LoRA parameter share: {lora_params_count / total_params * 100:.2f}%")
    # Freeze everything that is not a LoRA parameter
    for name, param in model.named_parameters():
        if 'lora' not in name:
            param.requires_grad = False
    lora_params = []
    for name, param in model.named_parameters():
        if 'lora' in name:
            lora_params.append(param)
    # Optimize the LoRA parameters only
    optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
    train_ds = SFTDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if ddp else None
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=False,
        shuffle=False,
        num_workers=args.num_workers,
        sampler=train_sampler
    )
    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
    iter_per_epoch = len(train_loader)
    for epoch in range(args.epochs):
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)  # give each epoch a different shuffle under DDP
        train_epoch(epoch, wandb)
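
The `get_lr` schedule used by these scripts is a half-cosine decay added on top of a constant floor of `lr / 10`. One consequence that is easy to miss is that the schedule starts at 1.1 times the configured learning rate and ends at 0.1 times it. The sketch below copies the formula and evaluates it at a few steps; `base_lr` and `total` are illustrative values chosen for this example only.

```python
import math


def get_lr(current_step: int, total_steps: int, lr: float) -> float:
    """Same formula as the training scripts: a floor of lr / 10 plus a half-cosine."""
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))


base_lr = 5e-5   # illustrative; matches the --learning_rate default above
total = 1000     # illustrative total number of optimizer iterations

for step in (0, 250, 500, 750, 1000):
    ratio = get_lr(step, total, base_lr) / base_lr
    print(f"step {step:4d}: lr = {get_lr(step, total, base_lr):.3e} ({ratio:.2f} x base)")
```

If the configured value is meant to be the true peak, the floor term would have to be subtracted from the cosine amplitude; the scripts as written do not do that.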

441
train_pretrain.py Normal file

@ -0,0 +1,441 @@
import os
# Set environment variables
os.environ["WANDB_MODE"] = "offline"  # alternatively, use "dryrun"
import platform
import argparse
import time
import math
import warnings
import pandas as pd
import torch
import torch.distributed as dist
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, DistributedSampler
# (communication-profiling tool import removed)
from contextlib import nullcontext
from typing import Optional
from transformers import AutoTokenizer
from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import PretrainDataset
warnings.filterwarnings('ignore')
def Logger(content):
    # Print only when DDP is not in use, or on the main (rank 0) process
    if not ddp or dist.get_rank() == 0:
        print(content)


def get_lr(current_step, total_steps, lr):
    # Compute the learning rate for the current step (cosine decay with a floor of lr / 10)
    # \text{get\_lr}(c, t, l) = \frac{l}{10} + 0.5 \cdot l \cdot \left(1 + \cos\left(\frac{\pi \cdot c}{t}\right)\right)
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
def train_epoch(epoch, wandb):
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    start_time = time.time()
    # Define moe_path up front so the exception handler never references an undefined variable
    moe_path = '_moe' if lm_config.use_moe else ''
    # CUDA events for performance profiling
    if args.profile and (not ddp or dist.get_rank() == 0):
        data_start = torch.cuda.Event(enable_timing=True)
        data_end = torch.cuda.Event(enable_timing=True)
        forward_start = torch.cuda.Event(enable_timing=True)
        forward_end = torch.cuda.Event(enable_timing=True)
        backward_start = torch.cuda.Event(enable_timing=True)
        backward_end = torch.cuda.Event(enable_timing=True)
        optimizer_start = torch.cuda.Event(enable_timing=True)
        optimizer_end = torch.cuda.Event(enable_timing=True)
    # (CUDA-graph optimization code removed)
    # Prefetch data
    prefetch_factor = 2  # number of batches to prefetch
    data_iter = iter(train_loader)
    prefetch_batches = []
    # Prefetch the initial batches
    for _ in range(min(prefetch_factor, len(train_loader))):
        try:
            batch = next(data_iter)
            prefetch_batches.append([t.to(args.device, non_blocking=True) for t in batch])
        except StopIteration:
            break
    for step in range(len(train_loader)):
        try:
            # Time data loading
            if args.profile and (not ddp or dist.get_rank() == 0):
                data_start.record()
            # Use the prefetched data
            if prefetch_batches:
                X, Y, loss_mask = prefetch_batches.pop(0)
            else:
                # The prefetch queue is empty: load directly
                X, Y, loss_mask = [t.to(args.device) for t in next(data_iter)]
            # Asynchronously prefetch the next batch
            if step + prefetch_factor < len(train_loader):
                try:
                    batch = next(data_iter)
                    prefetch_batches.append([t.to(args.device, non_blocking=True) for t in batch])
                except StopIteration:
                    pass
            if args.profile and (not ddp or dist.get_rank() == 0):
                data_end.record()
            # Update the learning rate
            lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            # Time the forward pass
            if args.profile and (not ddp or dist.get_rank() == 0):
                forward_start.record()
            # Regular forward pass
            with ctx:
                res = model(X)
                loss = loss_fct(
                    res.logits.view(-1, res.logits.size(-1)),
                    Y.view(-1)
                ).view(Y.size())
                loss = (loss * loss_mask).sum() / loss_mask.sum()
                # Add the auxiliary loss, if present
                try:
                    if hasattr(model, 'module'):
                        # DDP case
                        aux_loss = sum(l.feed_forward.aux_loss for l in model.module.layers
                                       if hasattr(l.feed_forward, 'aux_loss'))
                    else:
                        # Non-DDP case
                        aux_loss = sum(l.feed_forward.aux_loss for l in model.layers
                                       if hasattr(l.feed_forward, 'aux_loss'))
                    loss += aux_loss
                except Exception as e:
                    Logger(f"Warning: Could not add auxiliary loss: {e}")
                    # On error, skip the auxiliary loss
                loss = loss / args.accumulation_steps
            # Stop the forward timer and start the backward timer before calling backward(),
            # so that the two phases are measured separately
            if args.profile and (not ddp or dist.get_rank() == 0):
                forward_end.record()
                backward_start.record()
            # Backward pass
            scaler.scale(loss).backward()
            # Print data types for debugging
            if step == 0 and (not ddp or dist.get_rank() == 0):  # first step of each epoch, main process only
                Logger("---- Data Type Check ----")
                Logger(f"X.dtype: {X.dtype}")
                if hasattr(model, 'module'):  # DDP case
                    Logger(f"Model parameter dtype: {next(model.module.parameters()).dtype}")
                else:  # Non-DDP case
                    Logger(f"Model parameter dtype: {next(model.parameters()).dtype}")
                Logger(f"res.logits.dtype: {res.logits.dtype}")
                Logger(f"loss.dtype: {loss.dtype}")
                Logger("-------------------------")
            if args.profile and (not ddp or dist.get_rank() == 0):
                backward_end.record()
                # Profile on every step, not only when gradient accumulation completes
                if (step + 1) % args.profile_interval == 0:
                    # Record the optimizer start time (only on accumulation-boundary steps)
                    if (step + 1) % args.accumulation_steps == 0:
                        optimizer_start.record()
            # Optimizer step
            if (step + 1) % args.accumulation_steps == 0:
                if args.profile and (not ddp or dist.get_rank() == 0):
                    if (step + 1) % args.profile_interval != 0:
                        optimizer_start.record()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                if args.profile and (not ddp or dist.get_rank() == 0):
                    optimizer_end.record()
            # Profiling output (every profile_interval steps)
            if args.profile and (not ddp or dist.get_rank() == 0) and (step + 1) % args.profile_interval == 0:
                # Synchronize so that the CUDA event timings are accurate
                torch.cuda.synchronize()
                # Compute the time spent in each phase
                data_time = data_start.elapsed_time(data_end)
                forward_time = forward_start.elapsed_time(forward_end)
                backward_time = backward_start.elapsed_time(backward_end)
                # Optimizer time exists only when an accumulation window completes
                if (step + 1) % args.accumulation_steps == 0:
                    optimizer_time = optimizer_start.elapsed_time(optimizer_end)
                    total_compute_time = forward_time + backward_time + optimizer_time
                    Logger(f"Profiling - step {step+1}:")
                    Logger(f"  Data loading time: {data_time:.2f} ms")
                    Logger(f"  Forward time: {forward_time:.2f} ms")
                    Logger(f"  Backward time: {backward_time:.2f} ms")
                    Logger(f"  Optimizer time: {optimizer_time:.2f} ms")
                    Logger(f"  Total compute time: {total_compute_time:.2f} ms")
                    Logger(f"  Compute/data ratio: {total_compute_time / data_time:.2f}")
                else:
                    # Mid-accumulation step: there is no optimizer time
                    total_compute_time = forward_time + backward_time
                    Logger(f"Profiling - step {step+1} (accumulating gradients):")
                    Logger(f"  Data loading time: {data_time:.2f} ms")
                    Logger(f"  Forward time: {forward_time:.2f} ms")
                    Logger(f"  Backward time: {backward_time:.2f} ms")
                    Logger(f"  Total compute time: {total_compute_time:.2f} ms")
                    Logger(f"  Compute/data ratio: {total_compute_time / data_time:.2f}")
            # Print the training log
            if step % args.log_interval == 0:
                spend_time = time.time() - start_time
                Logger(
                    'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
                        epoch + 1,
                        args.epochs,
                        step,
                        iter_per_epoch,
                        loss.item() * args.accumulation_steps,
                        optimizer.param_groups[-1]['lr'],
                        spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
                if (wandb is not None) and (not ddp or dist.get_rank() == 0):
                    log_dict = {
                        "loss": loss.item() * args.accumulation_steps,
                        "lr": optimizer.param_groups[-1]['lr'],
                        "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
                    }
                    # When profiling is enabled, also log the performance metrics
                    if args.profile and (step + 1) % args.profile_interval == 0:
                        # Basic performance metrics
                        perf_dict = {
                            "data_time_ms": data_time,
                            "forward_time_ms": forward_time,
                            "backward_time_ms": backward_time
                        }
                        # Optimizer time exists only when an accumulation window completes
                        if (step + 1) % args.accumulation_steps == 0:
                            total_compute_time = forward_time + backward_time + optimizer_time
                            perf_dict.update({
                                "optimizer_time_ms": optimizer_time,
                                "compute_time_ms": total_compute_time
                            })
                        else:
                            total_compute_time = forward_time + backward_time
                            perf_dict.update({
                                "compute_time_ms": total_compute_time
                            })
                        log_dict.update(perf_dict)
                    wandb.log(log_dict)
            # (communication-analysis code removed)
            # Save the model
            if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
                model.eval()
                # Uses the moe_path defined at the top of the function
                ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
                if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                    state_dict = model.module.state_dict()  # get the model parameters
                else:
                    state_dict = model.state_dict()  # get the model parameters
                torch.save(state_dict, ckp)  # save the parameters only
                model.train()
        except Exception as e:
            print(f"Error occurred: {str(e)}")
            # Dump the current weights so the failure can be inspected afterwards
            save_path = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}_nanERROR.pth'
            if os.path.exists(save_path):
                os.remove(save_path)
            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            torch.save(state_dict, save_path)
            for name, param in model.named_parameters():
                if param.grad is not None and torch.isnan(param.grad).any():
                    print(f"NaN gradient in parameter: {name}")
            for name, param in model.named_parameters():
                if param.grad is not None and torch.isnan(param.grad).any():
                    print(f"Parameter {name} values: {param.data}")
                    print(f"Parameter {name} gradients: {param.grad}")
            # Chain the original exception so its traceback is not lost
            raise ValueError("NaN gradient detected") from e
def init_model(lm_config, pretrained_embedding_path: Optional[str] = None):
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    # Build the model
    model = MiniMindLM(lm_config).to(args.device)
    # Load pretrained token embeddings if path is provided
    if pretrained_embedding_path and os.path.exists(pretrained_embedding_path):
        Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
        embedding_weights = torch.load(pretrained_embedding_path, map_location=args.device)
        model.tok_embeddings.load_state_dict(embedding_weights)
        Logger("Successfully loaded pretrained token embeddings.")
    elif pretrained_embedding_path:
        Logger(f"Warning: Pretrained embedding path {pretrained_embedding_path} provided but file does not exist. Initializing embeddings from scratch.")
    # Print the trainable parameter count
    Logger(f'LLM trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
    return model, tokenizer


# (communication-analysis function removed)
def init_distributed_mode():
    if not ddp:
        return  # DDP is not enabled: nothing to do
    global ddp_local_rank, DEVICE  # declared global so they are visible outside the function
    # Initialize the process group with the NCCL backend (NVIDIA Collective Communications
    # Library), which is optimized for communication between NVIDIA GPUs
    dist.init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])  # global rank of this process
    ddp_local_rank = int(os.environ["LOCAL_RANK"])  # rank of this process on the local node
    ddp_world_size = int(os.environ["WORLD_SIZE"])  # total number of processes in the group
    DEVICE = f"cuda:{ddp_local_rank}"  # choose the GPU from the local rank
    torch.cuda.set_device(DEVICE)  # bind this process to that GPU
# torchrun --nproc_per_node 2 1-pretrain.py
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind Pretraining")
    parser.add_argument("--out_dir", type=str, default="out")
    # To reproduce the "zero" model as quickly as possible, set epochs to 1;
    # otherwise train for 2 to 6 epochs to make the most of the limited data.
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=24)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")  # use the GPU when available, otherwise the CPU
    parser.add_argument("--dtype", type=str, default="bfloat16")
    # store_true with default=True could never be switched off; BooleanOptionalAction
    # (Python 3.9+) keeps --use_wandb and also accepts --no-use_wandb
    parser.add_argument("--use_wandb", default=True, action=argparse.BooleanOptionalAction)
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
    parser.add_argument("--num_workers", type=int, default=48)
    parser.add_argument("--ddp", action="store_true")
    parser.add_argument("--accumulation_steps", type=int, default=32)  # gradient accumulation steps; controls how often weights are updated
    parser.add_argument("--grad_clip", type=float, default=1.0)  # gradient clipping threshold; guards against exploding gradients
    parser.add_argument("--warmup_iters", type=int, default=0)  # number of learning-rate warmup iterations
    parser.add_argument("--log_interval", type=int, default=100)  # how often (in steps) to print the training log
    parser.add_argument("--save_interval", type=int, default=10000)  # how often (in steps) to save a checkpoint
    parser.add_argument('--local_rank', type=int, default=-1)  # local process rank, for distributed training
    parser.add_argument('--dim', default=1024, type=int)  # model dimension; controls model size
    parser.add_argument('--n_layers', default=32, type=int)  # number of transformer layers
    parser.add_argument('--max_seq_len', default=1024, type=int)  # maximum input sequence length
    # Whether to use MoE. type=bool would treat any non-empty string (even "False") as True,
    # so parse the text explicitly
    parser.add_argument('--use_moe', default=False, type=lambda s: s.lower() in ('1', 'true', 'yes'))
    parser.add_argument('--disable_db', action='store_true', help="Disable the database feature and use the fixed value 1e-4 instead")  # disables the database feature (special mode)
    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")  # path to the training dataset
    parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
    # Profiling-related arguments (--no-profile / --no-use_flash_attn turn them off)
    parser.add_argument("--profile", action=argparse.BooleanOptionalAction, default=True, help="Enable performance profiling")
    parser.add_argument("--profile_interval", type=int, default=10, help="Profiling print interval (in steps)")
    parser.add_argument("--use_flash_attn", action=argparse.BooleanOptionalAction, default=True, help="Enable FlashAttention")
    args = parser.parse_args()
    print(args)
    lm_config = LMConfig(
        dim=args.dim,
        n_layers=args.n_layers,
        max_seq_len=args.max_seq_len,
        use_moe=args.use_moe,
        disable_db=args.disable_db,  # pass through the disable-database flag
        flash_attn=args.use_flash_attn  # pass through FlashAttention support
    )  # model configuration object
    args.save_dir = os.path.join(args.out_dir)  # checkpoint directory (same as out_dir)
    os.makedirs(args.save_dir, exist_ok=True)  # create the checkpoint directory
    os.makedirs(args.out_dir, exist_ok=True)  # create the output directory
    tokens_per_iter = args.batch_size * lm_config.max_seq_len  # tokens processed per iteration
    print(f"tokens_per_iter: {tokens_per_iter}")
    device_type = "cuda" if "cuda" in args.device else "cpu"  # determine the device type
    # Determine the torch dtype
    pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
    args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)
    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
    ddp_local_rank, DEVICE = 0, "cuda:0"
    base_seed = 1337
    torch.manual_seed(base_seed)
    torch.cuda.manual_seed(base_seed)
    if ddp:
        init_distributed_mode()
        args.device = torch.device(DEVICE)
        rank = dist.get_rank()
        torch.manual_seed(base_seed + rank)
        # Seed CUDA as well
        torch.cuda.manual_seed(base_seed + rank)
    if args.use_wandb and (not ddp or ddp_local_rank == 0):
        import wandb
        # Merge args and lm_config parameters for wandb config
        config = vars(args).copy()
        config.update(lm_config.__dict__)
        wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config)
    else:
        wandb = None
    model, tokenizer = init_model(lm_config, args.pretrained_embedding_path)
    train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if ddp else None
    # Tuned DataLoader configuration
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        pin_memory_device=f"cuda:{ddp_local_rank}" if ddp else "cuda:0",  # device to pin memory for
        drop_last=False,
        shuffle=False,
        num_workers=args.num_workers,
        sampler=train_sampler,
        persistent_workers=True if args.num_workers > 0 else False,  # keep worker processes alive between epochs
        prefetch_factor=2 if args.num_workers > 0 else None  # batches prefetched per worker
    )
    # Enable GradScaler only for float16; bfloat16 does not need loss scaling
    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
    if ddp:
        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
        # Keep find_unused_parameters=True because the model really does have unused parameters
        model = DistributedDataParallel(model, device_ids=[ddp_local_rank], find_unused_parameters=True)
    # Keep set_detect_anomaly for now to help with debugging;
    # once training is stable, comment this line out to speed things up
    torch.autograd.set_detect_anomaly(True)
    iter_per_epoch = len(train_loader)
    for epoch in range(args.epochs):
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)  # give each epoch a different shuffle under DDP
        train_epoch(epoch, wandb)
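
With gradient accumulation, the training loop above steps the optimizer only when `(step + 1) % args.accumulation_steps == 0`. When the number of batches per epoch is not a multiple of the accumulation window, the trailing batches do not trigger an optimizer step within that epoch. The sketch below only does the index arithmetic, with no PyTorch involved; `optimizer_step_indices` and its `flush_last` option are hypothetical and describe one possible remedy, not something the script currently does.

```python
from typing import List


def optimizer_step_indices(num_batches: int, accumulation_steps: int,
                           flush_last: bool = False) -> List[int]:
    """Return the 0-based batch indices after which the optimizer would step."""
    # Same trigger condition as the training loop: (step + 1) % accumulation_steps == 0
    steps = [i for i in range(num_batches) if (i + 1) % accumulation_steps == 0]
    # Optional (hypothetical) extra step so trailing batches are not left unapplied
    if flush_last and num_batches % accumulation_steps != 0:
        steps.append(num_batches - 1)
    return steps


num_batches, accumulation_steps = 10, 4
plain = optimizer_step_indices(num_batches, accumulation_steps)
flushed = optimizer_step_indices(num_batches, accumulation_steps, flush_last=True)
leftover = num_batches % accumulation_steps
print(plain)     # [3, 7]
print(flushed)   # [3, 7, 9]
print(leftover)  # 2 batches whose gradients get no step of their own
```

With the default `--accumulation_steps 32`, up to 31 batches per epoch can fall into this tail, so it may be worth checking `len(train_loader) % args.accumulation_steps` before a long run.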


@ -0,0 +1,437 @@
import os
# Set environment variables
os.environ["WANDB_MODE"] = "offline"  # alternatively, use "dryrun"
import platform
import argparse
import time
import math
import warnings
import pandas as pd
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from contextlib import nullcontext
from typing import Optional
import datetime # Add datetime for time formatting
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.utils import DeepSpeedPlugin
from accelerate.utils import DistributedDataParallelKwargs
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import PretrainDataset
warnings.filterwarnings('ignore')
# Logging helper
def Logger(msg, accelerator=None):
    # Print unconditionally when no accelerator is given; otherwise only on the main process
    if accelerator is None or accelerator.is_main_process:
        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}")


# Helper function to format seconds into HH:MM:SS
def format_time(seconds):
    return str(datetime.timedelta(seconds=int(seconds)))


# Learning-rate helper
def get_lr(it, num_iters, learning_rate):
    # Cosine learning-rate decay
    return learning_rate * 0.5 * (1.0 + math.cos(math.pi * it / num_iters))


# Model initialization
def init_model(lm_config, pretrained_embedding_path=None):
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model = MiniMindLM(lm_config)
    # Load the pretrained embedding weights if a path was provided
    if pretrained_embedding_path:
        Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
        pretrained_embeddings = torch.load(pretrained_embedding_path)
        model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
        model.output.weight.data.copy_(pretrained_embeddings)  # output layer shares the embedding weights
    Logger(f'LLM trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
    return model, tokenizer
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time):
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    epoch_start_time = time.time()
    total_steps_in_epoch = len(train_loader)
    total_training_steps = args.epochs * total_steps_in_epoch
    moe_path = '_moe' if args.use_moe else ''
    # CUDA events for performance profiling (main process only)
    if args.profile and accelerator.is_main_process:
        data_start = torch.cuda.Event(enable_timing=True)
        data_end = torch.cuda.Event(enable_timing=True)
        forward_start = torch.cuda.Event(enable_timing=True)
        forward_end = torch.cuda.Event(enable_timing=True)
        backward_start = torch.cuda.Event(enable_timing=True)
        backward_end = torch.cuda.Event(enable_timing=True)
        optimizer_start = torch.cuda.Event(enable_timing=True)
        optimizer_end = torch.cuda.Event(enable_timing=True)
    # Prefetch data
    prefetch_factor = 2  # number of batches to prefetch
    data_iter = iter(train_loader)
    prefetch_batches = []
    # Prefetch the initial batches
    for _ in range(min(prefetch_factor, len(train_loader))):
        try:
            batch = next(data_iter)
            prefetch_batches.append(batch)
        except StopIteration:
            break
    # Initialize the variables needed for logging before the loop starts
    last_log_time = epoch_start_time
    for step in range(total_steps_in_epoch):
        try:
            # Time data loading (main process only)
            if args.profile and accelerator.is_main_process:
                data_start.record()
            # Use the prefetched data
            if prefetch_batches:
                X, Y, loss_mask = prefetch_batches.pop(0)
            else:
                # The prefetch queue is empty: load directly
                X, Y, loss_mask = next(data_iter)
            # Prefetch the next batch
            if step + prefetch_factor < len(train_loader):
                try:
                    batch = next(data_iter)
                    prefetch_batches.append(batch)
                except StopIteration:
                    pass
            # End of data-loading timing (main process only)
            if args.profile and accelerator.is_main_process:
                data_end.record()
            # Update the learning rate
            if scheduler is not None:
                scheduler.step()
            # Time the forward pass (main process only)
            if args.profile and accelerator.is_main_process:
                forward_start.record()
            # Forward pass
            with ctx:
                res = model(X)
                loss = loss_fct(
                    res.logits.view(-1, res.logits.size(-1)),
                    Y.view(-1)
                ).view(Y.size())
                loss = (loss * loss_mask).sum() / loss_mask.sum()
                # Add the auxiliary loss, if present
                try:
                    aux_loss = sum(l.feed_forward.aux_loss for l in model.module.layers
                                   if hasattr(l.feed_forward, 'aux_loss'))
                    loss += aux_loss
                except Exception as e:
                    Logger(f"Warning: Could not add auxiliary loss: {e}")
                    # On error, skip the auxiliary loss
                loss = loss / args.accumulation_steps
            # End of forward timing (main process only)
            if args.profile and accelerator.is_main_process:
                forward_end.record()
            # Time the backward pass (main process only)
            if args.profile and accelerator.is_main_process:
                backward_start.record()
            # Backward pass
            # With DeepSpeed, gradient accumulation and gradient clipping are handled automatically
            accelerator.backward(loss)
            # End of backward timing (main process only)
            if args.profile and accelerator.is_main_process:
                backward_end.record()
            # Time the optimizer step (main process only)
            if args.profile and accelerator.is_main_process:
                optimizer_start.record()
            # Optimizer step. With DeepSpeed, gradient accumulation and clipping are handled
            # automatically, and the actual update runs only once the accumulation count is
            # reached, so there is no need to check step % accumulation_steps here.
            optimizer.step()
            # With DeepSpeed, zero_grad() is called automatically after step(),
            # but it is still called explicitly here to be safe
            optimizer.zero_grad()
            # End of optimizer timing (main process only)
            if args.profile and accelerator.is_main_process:
                optimizer_end.record()
            # Print training information (main process only)
            if (step + 1) % args.log_interval == 0 and accelerator.is_main_process:
                current_time = time.time()
                # Compute performance metrics
                if args.profile:
                    torch.cuda.synchronize()
                    # Use the time since the last log, not the total time, for the metrics
                    data_time = data_start.elapsed_time(data_end)
                    forward_time = forward_start.elapsed_time(forward_end)
                    backward_time = backward_start.elapsed_time(backward_end)
                    optimizer_time = optimizer_start.elapsed_time(optimizer_end)
                    iter_time = (current_time - last_log_time) * 1000 / args.log_interval  # avg ms per iteration since last log
                    # total_time_ms = data_time + forward_time + backward_time + optimizer_time
                    # Print the profiling summary
                    if (step + 1) % (args.log_interval * args.profile_interval) == 0:
                        Logger(f"Profiling (Avg/iter over last {args.log_interval} steps) - "
                               f"Data: {data_time/args.log_interval:.2f}ms, "
                               f"Fwd: {forward_time/args.log_interval:.2f}ms, "
                               f"Bwd: {backward_time/args.log_interval:.2f}ms, "
                               f"Optim: {optimizer_time/args.log_interval:.2f}ms, "
                               f"Iter Time: {iter_time:.2f}ms", accelerator)
                    # Reset the events so the next measurement starts from zero
                    data_start = torch.cuda.Event(enable_timing=True)
                    data_end = torch.cuda.Event(enable_timing=True)
                    forward_start = torch.cuda.Event(enable_timing=True)
                    forward_end = torch.cuda.Event(enable_timing=True)
                    backward_start = torch.cuda.Event(enable_timing=True)
                    backward_end = torch.cuda.Event(enable_timing=True)
                    optimizer_start = torch.cuda.Event(enable_timing=True)
                    optimizer_end = torch.cuda.Event(enable_timing=True)
                # Current learning rate
                current_lr = optimizer.param_groups[0]['lr']
                # Elapsed and remaining time
                epoch_elapsed_time = current_time - epoch_start_time
                epoch_steps_done = step + 1
                epoch_avg_step_time = epoch_elapsed_time / epoch_steps_done
                epoch_remaining_time = epoch_avg_step_time * (total_steps_in_epoch - epoch_steps_done)
                total_elapsed_time = current_time - overall_start_time
                total_steps_done = epoch * total_steps_in_epoch + epoch_steps_done
                total_avg_step_time = total_elapsed_time / total_steps_done if total_steps_done > 0 else 0
                total_remaining_time = total_avg_step_time * (total_training_steps - total_steps_done) if total_steps_done > 0 else 0
                # Training throughput (based on the most recent log_interval)
                interval_elapsed_time = current_time - last_log_time
                tokens_processed_interval = args.log_interval * args.batch_size * args.max_seq_len
                tokens_per_sec = tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0
                last_log_time = current_time  # update the time of the last log
                Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
                       f"Loss: {loss.item()*args.accumulation_steps:.4f}, "
                       f"LR: {current_lr:.6f}, "
                       f"Speed: {tokens_per_sec:.2f} tokens/sec | "
                       f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
                       f"Total Time Left: {format_time(total_remaining_time)}", accelerator)
            # Save the model (main process only)
            if (step + 1) % args.save_interval == 0 and accelerator.is_main_process:
                # Uses the moe_path defined at the top of the function
                ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
                # Get the unwrapped model
                unwrapped_model = accelerator.unwrap_model(model)
                # Save the model parameters
                accelerator.save(unwrapped_model.state_dict(), ckp)
                Logger(f"Model saved to {ckp}", accelerator)
        except Exception as e:
            Logger(f"Error in training step: {e}", accelerator)
            import traceback
            Logger(traceback.format_exc(), accelerator)
def main():
parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
parser.add_argument("--out_dir", type=str, default="out")
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=24)
parser.add_argument("--learning_rate", type=float, default=2e-4)
parser.add_argument("--dtype", type=str, default="bfloat16")
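parser.add_argument("--use_wandb", action=argparse.BooleanOptionalAction, default=True)  # store_true with default=True could never be switched off; --no-use_wandb now disables it (Python 3.9+)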
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
parser.add_argument("--num_workers", type=int, default=48)
parser.add_argument("--accumulation_steps", type=int, default=32)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--warmup_iters", type=int, default=0)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--save_interval", type=int, default=10000)
parser.add_argument('--dim', default=1024, type=int)
parser.add_argument('--n_layers', default=32, type=int)
parser.add_argument('--max_seq_len', default=1024, type=int)
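parser.add_argument('--use_moe', default=False, type=lambda s: s.lower() in ('1', 'true', 'yes'))  # type=bool treats any non-empty string, including "False", as True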
parser.add_argument('--disable_db', action='store_true', help="Disable the database feature and use the fixed value 1e-4 instead")
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
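parser.add_argument("--profile", action=argparse.BooleanOptionalAction, default=True, help="Enable performance profiling (use --no-profile to disable)")
parser.add_argument("--profile_interval", type=int, default=10, help="Interval, in steps, between profiling printouts")
parser.add_argument("--use_flash_attn", action=argparse.BooleanOptionalAction, default=True, help="Enable FlashAttention (use --no-use_flash_attn to disable)")
parser.add_argument("--knowlwdge_num", type=int, default=64*64, help="Number of entries in the knowledge base")
parser.add_argument("--knowlwdge_length", type=int, default=8, help="Sentence length of each knowledge-base entry")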
args = parser.parse_args()
#########################################################
# Initialize accelerator and DeepSpeed
#########################################################
# Set ddp_kwargs to handle unused parameters
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# Create the DeepSpeedPlugin object
ds_plugin = DeepSpeedPlugin(
gradient_accumulation_steps=args.accumulation_steps,
gradient_clipping=args.grad_clip,
zero_stage=2, # Use ZeRO stage 2 optimization
offload_optimizer_device="cpu", # Offload optimizer state to the CPU
offload_param_device="none", # Do not offload parameters to the CPU
)
accelerator = Accelerator(
kwargs_handlers=[ddp_kwargs],
deepspeed_plugin=ds_plugin,
mixed_precision="bf16" if args.dtype == "bfloat16" else "fp16" if args.dtype == "float16" else "no"
)
#########################################################
# Set the random seed
#########################################################
set_seed(1337 + accelerator.process_index)
#########################################################
# Configure the model
#########################################################
lm_config = LMConfig(
dim=args.dim,
n_layers=args.n_layers,
max_seq_len=args.max_seq_len,
use_moe=args.use_moe,
disable_db=args.disable_db,
flash_attn=args.use_flash_attn,
knowlwdge_num=args.knowlwdge_num,
knowlwdge_length=args.knowlwdge_length
)
#########################################################
# Create the output directory
#########################################################
args.save_dir = args.out_dir  # Checkpoints are written directly into out_dir
if accelerator.is_main_process:
os.makedirs(args.save_dir, exist_ok=True)
#########################################################
# Set the data type
#########################################################
pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
#########################################################
# Configure wandb
#########################################################
# Set the wandb run name
args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
# Merge args and lm_config into a single dict. It is built unconditionally
# because the configuration printout below uses it even when wandb is disabled.
config_dict = vars(args).copy()
config_dict.update(vars(lm_config))
if args.use_wandb and accelerator.is_main_process:
import wandb
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config_dict)
else:
wandb = None
#########################################################
# Print run information
#########################################################
# Compute the number of tokens per iteration
tokens_per_iter = args.batch_size * lm_config.max_seq_len
if accelerator.is_main_process:
Logger(f"tokens_per_iter: {tokens_per_iter}", accelerator)
Logger("Configuration:", accelerator)
for key, value in config_dict.items():
Logger(f" {key}: {value}", accelerator)
#########################################################
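# Set up the automatic mixed precision context
#########################################################
ctx = nullcontext() if accelerator.device.type == "cpu" else torch.autocast(device_type="cuda", dtype=pt_dtype)  # torch.cuda.amp.autocast is deprecated in recent PyTorch releases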
#########################################################
# Initialize the model and tokenizer
#########################################################
model, tokenizer = init_model(lm_config, args.pretrained_embedding_path)
# Pass the accelerator to the Logger calls in init_model
Logger('Model initialization complete', accelerator)
#########################################################
# Handle the positional-encoding tensor
#########################################################
if hasattr(model, "pos_cis_real"):
Logger('Detected real-valued pos_cis_real tensor; it will participate in distributed training', accelerator)
# Set the model's _ddp_params_and_buffers_to_ignore attribute
# model._ddp_params_and_buffers_to_ignore = {"pos_cis_real"}
# Backward compatibility with older versions: check whether pos_cis is still present
elif hasattr(model, "pos_cis"):
Logger('Detected complex-valued pos_cis tensor; excluding it from distributed training', accelerator)
# Set the model's _ddp_params_and_buffers_to_ignore attribute
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
#########################################################
# Create the dataset and data loader
#########################################################
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=True,
drop_last=False,
shuffle=True,
num_workers=args.num_workers,
persistent_workers=args.num_workers > 0,
prefetch_factor=2 if args.num_workers > 0 else None
)
#########################################################
# Create the optimizer
#########################################################
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
#########################################################
# Create the learning rate scheduler
#########################################################
total_steps = len(train_loader) * args.epochs
warmup_steps = args.warmup_iters if args.warmup_iters > 0 else int(0.1 * total_steps)
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_steps
)
#########################################################
# Prepare for training
#########################################################
model, optimizer, train_loader, scheduler = accelerator.prepare(
model, optimizer, train_loader, scheduler
)
#########################################################
# Training loop
#########################################################
overall_start_time = time.time() # Record overall start time
for epoch in range(args.epochs):
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time) # Pass overall start time
#########################################################
# Shut down wandb
#########################################################
if args.use_wandb and accelerator.is_main_process:
wandb.finish()
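# main() falls back to 10% of the total steps for warmup when --warmup_iters
# is 0, then hands the counts to get_cosine_schedule_with_warmup. The sketch
# below shows the usual shape of such a schedule (linear warmup followed by a
# half-cosine decay) as a multiplier on the base learning rate. It is an
# illustration under that assumption, not the library's code, and the names
# resolve_warmup_steps and lr_multiplier are hypothetical.

```python
import math


def resolve_warmup_steps(warmup_iters: int, total_steps: int) -> int:
    """Mirror the fallback in main(): use 10% of total steps when no warmup is given."""
    return warmup_iters if warmup_iters > 0 else int(0.1 * total_steps)


def lr_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    """Factor applied to the base learning rate at a given scheduler step."""
    if step < warmup_steps:
        # Linear ramp from 0 up to 1 over the warmup phase.
        return step / max(1, warmup_steps)
    # Fraction of the decay phase completed, in [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    # Half cosine: 1 at the start of decay, 0 at the end.
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


if __name__ == "__main__":
    total = 1000
    warmup = resolve_warmup_steps(0, total)
    base_lr = 2e-4  # the script's default --learning_rate
    for step in (0, warmup // 2, warmup, (warmup + total) // 2, total):
        print(step, base_lr * lr_multiplier(step, warmup, total))
```

# With the defaults this means the first tenth of scheduler steps runs below
# the configured learning rate, which is worth remembering when reading the
# LR value in the training log.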
if __name__ == "__main__":
main()
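# Several flags in main() are booleans, and argparse has two well-known
# pitfalls here: type=bool turns any non-empty string (including "False") into
# True, and action="store_true" combined with default=True can never be turned
# off. The standalone sketch below shows one way to get switchable boolean
# flags. str2bool and build_parser are hypothetical helper names, and
# argparse.BooleanOptionalAction requires Python 3.9 or newer.

```python
import argparse


def str2bool(value: str) -> bool:
    """Parse common true/false spellings; reject anything else."""
    v = value.strip().lower()
    if v in {"1", "true", "yes", "y"}:
        return True
    if v in {"0", "false", "no", "n"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Boolean flag sketch")
    # Default-on flag that can be disabled with --no-use_wandb.
    parser.add_argument("--use_wandb", action=argparse.BooleanOptionalAction, default=True)
    # Value-style flag: --use_moe true / --use_moe false.
    parser.add_argument("--use_moe", type=str2bool, default=False)
    return parser


if __name__ == "__main__":
    parser = build_parser()
    print(parser.parse_args([]))
    print(parser.parse_args(["--no-use_wandb", "--use_moe", "true"]))
    # The pitfall being avoided: bool() of a non-empty string is always True.
    print(bool("False"))
```

# The value-style form keeps existing invocations such as "--use_moe True"
# working, while the optional-action form suits flags that default to on.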