Minimind/train_pretrain_accelerate.py

848 lines
43 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
# 设置环境变量
os.environ["WANDB_MODE"] = "offline" # 或者使用 "dryrun"
import platform
import argparse
from tqdm import tqdm
import time
import math
import warnings
import pandas as pd
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from contextlib import nullcontext
from typing import Optional
import datetime # Add datetime for time formatting
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.utils import DeepSpeedPlugin
from accelerate.utils import DistributedDataParallelKwargs
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from model.model import MiniMindLM, RMSNorm
from model.LMConfig import LMConfig
from model.dataset import PretrainDataset
warnings.filterwarnings('ignore')
# 日志记录函数
def Logger(msg, accelerator=None):
    """Print ``msg`` prefixed with the current local time.

    Args:
        msg: anything printable.
        accelerator: optional object exposing ``is_main_process``. When given,
            only the main process prints. When omitted, every caller prints.
    """
    if accelerator is not None and not accelerator.is_main_process:
        return
    timestamp = time.strftime('%Y-%m-%d %H:%M:%S')
    print(f"[{timestamp}] {msg}")
# Helper function to format seconds into HH:MM:SS
def format_time(seconds):
    """Render a duration given in seconds as a ``datetime.timedelta`` string.

    Fractional seconds are dropped with ``int()`` (truncation toward zero), so
    ``3661.9`` renders as ``"1:01:01"``. Durations of a day or more carry a
    day prefix, e.g. ``"1 day, 1:01:01"``.
    """
    whole_seconds = int(seconds)
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)
# 获取学习率函数
def get_lr(it, num_iters, learning_rate):
    """Return the cosine-decayed learning rate for iteration ``it`` of ``num_iters``.

    The value equals ``learning_rate`` at ``it == 0`` and decays to 0 at
    ``it == num_iters``. ``num_iters`` must be non-zero.
    """
    angle = math.pi * it / num_iters
    shifted_cosine = 1.0 + math.cos(angle)
    return learning_rate * 0.5 * shifted_cosine
# 初始化模型函数
def _pad_token_id(tokenizer):
    """Return ``tokenizer.pad_token_id``, or 0 when the tokenizer defines no pad token."""
    return tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0


def _to_fixed_length_tokens(tokenizer, text, knowledge_length):
    """Encode ``text`` without special tokens, then truncate or right-pad it to ``knowledge_length`` ids."""
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) > knowledge_length:
        return tokens[:knowledge_length]
    tokens.extend([_pad_token_id(tokenizer)] * (knowledge_length - len(tokens)))
    return tokens


def _default_initialize(model):
    """Re-initialize ``model``'s weights in place with the default scheme.

    - ``tok_embeddings`` (and an untied ``output`` projection) get N(0, 0.02).
    - Every ``nn.Linear`` gets Xavier-uniform weights and zero bias.
    - Every ``nn.Embedding`` gets N(0, 0.02).
    - Every ``RMSNorm`` weight is set to 1.
    - ``model.extract_db.keys``, when present, gets N(0, 0.02).
    """
    Logger("Performing default model initialization...")
    nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02)
    # Initialize the output projection separately only when it is not tied to the embedding.
    if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight:
        nn.init.normal_(model.output.weight, mean=0.0, std=0.02)
    # NOTE(review): the loop below re-initializes every nn.Linear with Xavier. If
    # model.output is an nn.Linear, this overwrites the normal init above. With tied
    # weights it may also overwrite the token embedding, depending on module order.
    # Confirm this is intended.
    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm):
            if hasattr(module, 'weight'):
                nn.init.ones_(module.weight)
    # Guarded: models built without the knowledge database have no extract_db.
    extract_db = getattr(model, 'extract_db', None)
    if extract_db is not None and hasattr(extract_db, 'keys'):
        nn.init.normal_(extract_db.keys, mean=0.0, std=0.02)
    Logger("Default model initialization completed")


def _load_cached_clusters(cache_path, knowledge_num, knowledge_length):
    """Return the first ``knowledge_num`` rows of a cached cluster tensor, or ``None`` if it is unusable.

    The cache is usable when it loads, has exactly two dimensions, has
    ``knowledge_length`` columns and has at least ``knowledge_num`` rows.
    """
    try:
        Logger(f"Loading cached cluster results from {cache_path}")
        cached_tensor = torch.load(cache_path)
        cached_knowledge_num, cached_knowledge_length = cached_tensor.shape
    except Exception as e:
        Logger(f"Failed to load cached clusters: {e}, recomputing...")
        return None
    if cached_knowledge_length != knowledge_length:
        Logger(f"Cached knowledge_length ({cached_knowledge_length}) != required knowledge_length ({knowledge_length}), recomputing...")
        return None
    if cached_knowledge_num < knowledge_num:
        Logger(f"Cached knowledge_num ({cached_knowledge_num}) < required knowledge_num ({knowledge_num}), recomputing...")
        return None
    clustered_tensor = cached_tensor[:knowledge_num, :]
    Logger(f"Successfully loaded cached clusters with shape {clustered_tensor.shape}")
    Logger(f"Truncated from cached shape ({cached_knowledge_num}, {cached_knowledge_length}) to required shape ({knowledge_num}, {knowledge_length})")
    Logger("Skipping database initialization and clustering - using cached results")
    return clustered_tensor


def _load_embedding_model():
    """Load the local SentenceTransformer used for clustering. Return ``None`` if loading fails."""
    # Imported lazily so that runs hitting the cluster cache do not need sentence_transformers.
    from sentence_transformers import SentenceTransformer
    embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
    embedding_model_dir = "./models/sentence_transformers/models--sentence-transformers--all-mpnet-base-v2"
    embedding_cache_dir = "./models/sentence_transformers/cache"
    os.makedirs(embedding_cache_dir, exist_ok=True)
    Logger(f"Loading embedding model: {embedding_model_name}")
    try:
        embedding_model = SentenceTransformer(embedding_model_dir, cache_folder=embedding_cache_dir)
    except Exception as e:
        Logger(f"Failed to load embedding model: {e}")
        Logger("Falling back to random embeddings")
        return None
    Logger("Embedding model loaded successfully")
    return embedding_model


def _embed_sentences(embedding_model, sentences, batch_size=256):
    """Embed ``sentences`` in batches of ``batch_size``.

    Falls back to random 384-d float32 vectors (numpy global RNG) when there is no
    model or encoding fails.
    """
    # NOTE(review): all-mpnet-base-v2 produces 768-d vectors. The 384-d random fallback
    # is only self-consistent within a single run. Confirm the dimension is irrelevant downstream.
    if embedding_model is None:
        return [np.random.randn(384).astype(np.float32) for _ in sentences]
    embeddings_list = []
    try:
        for i in range(0, len(sentences), batch_size):
            batch_embeddings = embedding_model.encode(
                sentences[i:i + batch_size],
                convert_to_tensor=False,
                show_progress_bar=(i == 0),
                batch_size=batch_size,
            )
            embeddings_list.extend(batch_embeddings)
            if (i + batch_size) % (batch_size * 10) == 0:
                Logger(f"Processed {min(i + batch_size, len(sentences))}/{len(sentences)} sentences")
        Logger("Batch embedding computation completed")
    except Exception as e:
        Logger(f"Error in batch encoding: {e}")
        Logger("Falling back to random embeddings")
        embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
    return embeddings_list


def _cluster_fast(processed_sentences, tokenizer, knowledge_num, knowledge_length, min_tokens, max_tokens):
    """Approximate clustering: pick random seeds by importance tier, then fill each cluster from random candidates.

    Sentences are split into three pools: high (> 0.7), medium (0.3-0.7) and low
    (< 0.3) importance. Each cluster works as follows:
      - Take a random seed from the highest non-empty pool.
      - Scan up to 2000 randomly sampled remaining sentences.
      - Add each candidate whose estimated length (token count + 1 per separator)
        stays within ``max_tokens``.
      - Stop once the cluster reaches ``min_tokens``.
    Returns at most ``knowledge_num`` rows of ``knowledge_length`` token ids.
    """
    import random
    # A private RNG seeded exactly as the former global ``random.seed(42)`` call.
    # It yields the same clusters without reseeding the process-wide ``random``
    # module (which set_seed configured).
    rng = random.Random(42)
    high_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] > 0.7]
    medium_importance = [i for i, s in enumerate(processed_sentences) if 0.3 <= s['importance_score'] <= 0.7]
    low_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] < 0.3]
    pools = (high_importance, medium_importance, low_importance)
    Logger(f"Importance distribution: High={len(high_importance)}, Medium={len(medium_importance)}, Low={len(low_importance)}")

    clustered_rows = []
    for cluster_idx in tqdm(range(knowledge_num)):
        if high_importance:
            seed_pool = high_importance
        elif medium_importance:
            seed_pool = medium_importance
        elif low_importance:
            seed_pool = low_importance
        else:
            # NOTE(review): once every pool is exhausted, this falls back to *all*
            # sentences, so already-used sentences are reused as seeds. The greedy
            # path pads instead. Confirm this recycling is intended.
            seed_pool = list(range(len(processed_sentences)))
        if not seed_pool:
            break
        seed_global_idx = rng.choice(seed_pool)
        for pool in pools:
            if seed_global_idx in pool:
                pool.remove(seed_global_idx)
        current_cluster_indices = [seed_global_idx]
        current_tokens = processed_sentences[seed_global_idx]['token_length']

        if current_tokens < max_tokens:
            all_remaining = high_importance + medium_importance + low_importance
            if all_remaining:
                # Sample candidates instead of scoring similarity against everything.
                candidates = rng.sample(all_remaining, min(2000, len(all_remaining)))
                for candidate_idx in candidates:
                    candidate_tokens = processed_sentences[candidate_idx]['token_length']
                    if current_tokens + candidate_tokens + 1 <= max_tokens:
                        current_cluster_indices.append(candidate_idx)
                        current_tokens += candidate_tokens + 1
                        for pool in pools:
                            if candidate_idx in pool:
                                pool.remove(candidate_idx)
                                break
                        if current_tokens >= min_tokens:
                            break

        # NOTE(review): this path joins with '\n ' while the greedy path joins with
        # '\n'. Both budget +1 token per separator. Kept as-is so cached clusters stay
        # reproducible; confirm which separator is intended.
        cluster_text = '\n '.join(processed_sentences[idx]['sentence'] for idx in current_cluster_indices)
        clustered_rows.append(_to_fixed_length_tokens(tokenizer, cluster_text, knowledge_length))

        if (cluster_idx + 1) % 1000 == 0:
            total_remaining = len(high_importance) + len(medium_importance) + len(low_importance)
            Logger(f"Fast clustering: {cluster_idx + 1}/{knowledge_num} clusters, {total_remaining} sentences remaining")
    return clustered_rows


def _cluster_greedy(processed_sentences, embeddings_matrix, similarity_matrix, tokenizer,
                    knowledge_num, knowledge_length, min_tokens, max_tokens):
    """Greedy similarity clustering for moderate-size datasets.

    Each cluster works as follows:
      - Seed with the remaining sentence of highest importance (first one wins ties).
      - Take up to 500 of the most cosine-similar remaining sentences, most similar first.
      - Add each one that keeps the estimated length within ``max_tokens``.
      - Stop once the cluster reaches ``min_tokens``.
    ``similarity_matrix`` may be ``None``, in which case similarities are computed
    per seed from ``embeddings_matrix``. Stops early when sentences run out.
    """
    clustered_rows = []
    remaining_indices = list(range(len(processed_sentences)))
    for cluster_idx in tqdm(range(knowledge_num)):
        if not remaining_indices:
            Logger(f"No more sentences available. Created {cluster_idx} clusters.")
            break
        # max() returns the first maximal element, matching the previous
        # subset-based search without materializing a list each iteration.
        seed_global_idx = max(remaining_indices, key=lambda idx: processed_sentences[idx]['importance_score'])
        remaining_indices.remove(seed_global_idx)
        current_cluster_indices = [seed_global_idx]
        current_tokens = processed_sentences[seed_global_idx]['token_length']

        # A seed that already reaches max_tokens forms a cluster on its own.
        if current_tokens < max_tokens and remaining_indices:
            if similarity_matrix is not None:
                similarities = similarity_matrix[seed_global_idx][remaining_indices]
            else:
                seed_embedding = embeddings_matrix[seed_global_idx:seed_global_idx + 1]
                similarities = cosine_similarity(seed_embedding, embeddings_matrix[remaining_indices])[0]
            # (similarity, global index, position in remaining_indices), most similar first.
            similarity_tuples = [(similarities[pos], global_idx, pos) for pos, global_idx in enumerate(remaining_indices)]
            similarity_tuples.sort(key=lambda item: item[0], reverse=True)

            selected_positions = []
            for _, global_idx, pos_in_remaining in similarity_tuples[:min(len(similarity_tuples), 500)]:
                candidate_tokens = processed_sentences[global_idx]['token_length']
                if current_tokens + candidate_tokens + 1 <= max_tokens:  # +1 for the separator
                    current_cluster_indices.append(global_idx)
                    selected_positions.append(pos_in_remaining)
                    current_tokens += candidate_tokens + 1
                    if current_tokens >= min_tokens:
                        break
            # Pop from the back so the earlier positions stay valid.
            for pos in sorted(selected_positions, reverse=True):
                remaining_indices.pop(pos)

        cluster_text = '\n'.join(processed_sentences[idx]['sentence'] for idx in current_cluster_indices)
        clustered_rows.append(_to_fixed_length_tokens(tokenizer, cluster_text, knowledge_length))

        if (cluster_idx + 1) % 500 == 0:
            Logger(f"Created {cluster_idx + 1}/{knowledge_num} clusters, {len(remaining_indices)} sentences remaining")
    return clustered_rows


def _build_clustered_tensor(database_init_path, tokenizer, args, knowledge_num, knowledge_length):
    """Cluster the database sentences into a ``(knowledge_num, knowledge_length)`` LongTensor and cache it.

    Reads ``{"sentences": [{"corrected_sentence", "importance_score"}, ...]}`` from
    ``database_init_path``, sorted by importance (descending). Each cluster targets
    85-95 % of ``knowledge_length`` tokens. Missing rows are filled with pad ids. The
    result is saved to ``args.cluster_cache_path`` on a best-effort basis.

    Raises:
        ValueError: if the database contains no sentences.
    """
    import json
    Logger(f"Loading database initialization data from {database_init_path}")
    with open(database_init_path, 'r', encoding='utf-8') as f:
        database_data = json.load(f)
    sentences_data = database_data.get('sentences', [])
    Logger(f"Loaded {len(sentences_data)} sentences from database")
    if not sentences_data:
        raise ValueError(f"No sentences found in {database_init_path}; cannot initialize the knowledge database")

    sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
    Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")

    embedding_model = _load_embedding_model()
    Logger("Processing sentences for embeddings and token lengths...")
    sentences = [sentence_data.get('corrected_sentence', '') for sentence_data in sorted_sentences]
    Logger("Computing token lengths...")
    token_lengths = [len(tokenizer.encode(sentence, add_special_tokens=False)) for sentence in sentences]
    Logger("Computing embeddings in batches...")
    embeddings_list = _embed_sentences(embedding_model, sentences)

    processed_sentences = [
        {
            'sentence': sentence_data.get('corrected_sentence', ''),
            'importance_score': sentence_data.get('importance_score', 0.0),
            'token_length': token_length,
            'embedding': embedding,
            'original_index': i,
        }
        for i, (sentence_data, embedding, token_length)
        in enumerate(zip(sorted_sentences, embeddings_list, token_lengths))
    ]
    embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
    token_lengths_array = np.array(token_lengths)
    Logger(f"Embedding processing completed:")
    Logger(f" - Total sentences: {len(processed_sentences)}")
    Logger(f" - Embedding shape: {embeddings_matrix.shape}")
    Logger(f" - Average token length: {np.mean(token_lengths_array):.2f}")
    Logger(f" - Token length range: {np.min(token_lengths_array)} - {np.max(token_lengths_array)}")

    min_tokens = int(0.85 * knowledge_length)
    max_tokens = int(0.95 * knowledge_length)
    use_fast = args.fast_clustering and len(processed_sentences) > 5000

    # The full similarity matrix is only useful to the greedy path, and only affordable for small datasets.
    similarity_matrix = None
    if not use_fast and len(processed_sentences) <= 10000:
        Logger("Pre-computing similarity matrix for faster clustering...")
        similarity_matrix = cosine_similarity(embeddings_matrix)
        Logger(f"Similarity matrix computed: {similarity_matrix.shape}")

    Logger(f"Target: {knowledge_num} clusters, each with {min_tokens}-{max_tokens} tokens")
    if use_fast:
        Logger("Using ultra-fast approximate clustering algorithm...")
        clustered_rows = _cluster_fast(processed_sentences, tokenizer, knowledge_num,
                                       knowledge_length, min_tokens, max_tokens)
    else:
        clustered_rows = _cluster_greedy(processed_sentences, embeddings_matrix, similarity_matrix, tokenizer,
                                         knowledge_num, knowledge_length, min_tokens, max_tokens)

    # Fill any missing clusters with all-pad rows so the tensor has exactly knowledge_num rows.
    pad_token_id = _pad_token_id(tokenizer)
    while len(clustered_rows) < knowledge_num:
        clustered_rows.append([pad_token_id] * knowledge_length)

    clustered_tensor = torch.tensor(clustered_rows, dtype=torch.long)
    Logger(f"Clustering completed:")
    Logger(f" - Created {len(clustered_rows)} clusters")
    Logger(f" - Cluster shape: {clustered_tensor.shape}")
    Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")
    try:
        torch.save(clustered_tensor, args.cluster_cache_path)
        Logger(f"Cluster results saved to {args.cluster_cache_path}")
    except Exception as e:
        Logger(f"Failed to save cluster results: {e}")
    return clustered_tensor


def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None):
    """Build the MiniMind model and tokenizer and initialize the model's weights.

    Steps:
      1. Load the tokenizer from ``./model/minimind_tokenizer`` and build ``MiniMindLM(lm_config)``.
      2. Apply the default weight initialization.
      3. If ``pretrained_embedding_path`` is given, copy that tensor into both
         ``tok_embeddings.weight`` and ``output.weight``.
      4. If ``database_init_path`` is given, obtain a ``(args.knowledge_num,
         args.knowledge_length)`` token tensor and copy it into
         ``model.extract_db.weight_down_embed``. The tensor comes from
         ``args.cluster_cache_path`` when a compatible cache exists and
         ``args.recompute_clusters`` is false. Otherwise it is built by clustering
         the database sentences.

    NOTE(review): Logger is called without an accelerator here, so every process logs.
    In multi-process runs every process may also cluster and write the same cache
    file. Consider computing on the main process only.

    Args:
        lm_config: model configuration passed to ``MiniMindLM``.
        pretrained_embedding_path: optional path to a saved embedding tensor.
        database_init_path: optional path to the knowledge-database JSON.
        args: parsed CLI namespace. Required when ``database_init_path`` is given.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: if ``database_init_path`` is given without ``args``, or the database is empty.
    """
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model = MiniMindLM(lm_config)
    _default_initialize(model)

    if pretrained_embedding_path:
        Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
        pretrained_embeddings = torch.load(pretrained_embedding_path)
        model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
        model.output.weight.data.copy_(pretrained_embeddings)  # keep the output projection in sync

    if database_init_path:
        if args is None:
            raise ValueError("init_model requires `args` when `database_init_path` is provided")
        knowledge_num = args.knowledge_num
        knowledge_length = args.knowledge_length
        cache_dir = os.path.dirname(args.cluster_cache_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)

        clustered_tensor = None
        if not args.recompute_clusters and os.path.exists(args.cluster_cache_path):
            clustered_tensor = _load_cached_clusters(args.cluster_cache_path, knowledge_num, knowledge_length)
        if clustered_tensor is None:
            clustered_tensor = _build_clustered_tensor(database_init_path, tokenizer, args,
                                                       knowledge_num, knowledge_length)

        extract_db = getattr(model, 'extract_db', None)
        if extract_db is not None and hasattr(extract_db, 'weight_down_embed'):
            extract_db.weight_down_embed.data.copy_(clustered_tensor)
            Logger("Successfully initialized model.extract_db.weight_down_embed with clustered data")
        else:
            Logger("Warning: Could not find model.extract_db.weight_down_embed to initialize")
            # Fallback: expose the tensor as a module-level global.
            globals()['clustered_database'] = clustered_tensor
        Logger(f"Database embeddings and sentences stored in model")

    Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
    return model, tokenizer
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb):
    """Run one training epoch over ``train_loader``.

    Each step does the following:
      - Fetch a batch ``(X, Y, loss_mask)``.
      - Call ``scheduler.step()`` when a scheduler is given.
      - Compute the masked token-level cross-entropy, plus any ``feed_forward.aux_loss``
        exposed by the layers.
      - Divide the loss by ``args.accumulation_steps``.
      - Call ``accelerator.backward``, then ``optimizer.step()`` and ``optimizer.zero_grad()``.

    On the main process it also:
      - Logs progress every ``args.log_interval`` steps, optionally with CUDA-event timings.
      - Saves the unwrapped ``state_dict`` every ``args.save_interval`` steps.

    Args:
        epoch: zero-based epoch index (used for ETA and logging).
        accelerator: the ``accelerate.Accelerator`` that prepared the objects below.
        model, train_loader, optimizer: objects returned by ``accelerator.prepare``.
        scheduler: prepared LR scheduler, or ``None``.
        args: parsed CLI namespace.
        ctx: autocast (or null) context wrapping the forward pass.
        overall_start_time: ``time.time()`` at the start of training, for the total ETA.
        wandb: the imported ``wandb`` module, or ``None`` when logging is disabled.

    An exception raised inside a step is logged with its traceback and the loop
    continues with the next step (best-effort training).
    """
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    epoch_start_time = time.time()
    total_steps_in_epoch = len(train_loader)
    total_training_steps = args.epochs * total_steps_in_epoch
    moe_path = '_moe' if args.use_moe else ''
    # CUDA events only exist when CUDA is available; profile only on the main process.
    profiling = args.profile and accelerator.is_main_process and torch.cuda.is_available()
    # The unwrapped model shares its submodules with the wrapped one. It is used to
    # read per-layer aux losses and to save the state_dict, whatever wrapper is in use.
    unwrapped_model = accelerator.unwrap_model(model)

    # One (start, end) CUDA event pair per profiled phase. The events are re-recorded
    # every step, so a reading always reflects the most recent step.
    timers = None
    if profiling:
        timers = {
            phase: (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
            for phase in ('data', 'forward', 'backward', 'optimizer')
        }

    # Small look-ahead buffer of batches. Note: batches are fetched synchronously.
    prefetch_factor = 2
    data_iter = iter(train_loader)
    prefetch_batches = []
    for _ in range(min(prefetch_factor, total_steps_in_epoch)):
        try:
            prefetch_batches.append(next(data_iter))
        except StopIteration:
            break

    last_log_time = epoch_start_time
    for step in range(total_steps_in_epoch):
        try:
            if profiling:
                timers['data'][0].record()
            if prefetch_batches:
                X, Y, loss_mask = prefetch_batches.pop(0)
            else:
                X, Y, loss_mask = next(data_iter)
            if step + prefetch_factor < total_steps_in_epoch:
                try:
                    prefetch_batches.append(next(data_iter))
                except StopIteration:
                    pass
            if profiling:
                timers['data'][1].record()

            if scheduler is not None:
                scheduler.step()

            if profiling:
                timers['forward'][0].record()
            with ctx:
                res = model(X)
                loss = loss_fct(
                    res.logits.view(-1, res.logits.size(-1)),
                    Y.view(-1)
                ).view(Y.size())
                # Clamp the denominator so a batch whose mask is all zeros yields 0
                # instead of NaN. Assumes loss_mask is a 0/1 mask, so any non-empty
                # mask already sums to >= 1 and is unaffected. TODO confirm.
                loss = (loss * loss_mask).sum() / loss_mask.sum().clamp(min=1)
                try:
                    aux_loss = sum(layer.feed_forward.aux_loss for layer in unwrapped_model.layers
                                   if hasattr(layer.feed_forward, 'aux_loss'))
                    loss += aux_loss
                except Exception as e:
                    Logger(f"Warning: Could not add auxiliary loss: {e}")
                loss = loss / args.accumulation_steps
            if profiling:
                timers['forward'][1].record()

            if profiling:
                timers['backward'][0].record()
            # Under DeepSpeed, gradient accumulation and clipping are configured on the plugin.
            accelerator.backward(loss)
            if profiling:
                timers['backward'][1].record()

            if profiling:
                timers['optimizer'][0].record()
            optimizer.step()
            optimizer.zero_grad()
            if profiling:
                timers['optimizer'][1].record()

            if (step + 1) % args.log_interval == 0 and accelerator.is_main_process:
                current_time = time.time()

                if profiling and (step + 1) % (args.log_interval * args.profile_interval) == 0:
                    torch.cuda.synchronize()
                    # The events hold only the most recent step, so report them as such.
                    # Dividing them by log_interval (as before) did not give an average.
                    phase_ms = {phase: start.elapsed_time(end) for phase, (start, end) in timers.items()}
                    iter_time = (current_time - last_log_time) * 1000 / args.log_interval
                    Logger(f"性能分析 (last step) - "
                           f"Data: {phase_ms['data']:.2f}ms, "
                           f"Fwd: {phase_ms['forward']:.2f}ms, "
                           f"Bwd: {phase_ms['backward']:.2f}ms, "
                           f"Optim: {phase_ms['optimizer']:.2f}ms, "
                           f"Iter Time (avg over last {args.log_interval} steps): {iter_time:.2f}ms", accelerator)

                current_lr = optimizer.param_groups[0]['lr']

                epoch_steps_done = step + 1
                epoch_avg_step_time = (current_time - epoch_start_time) / epoch_steps_done
                epoch_remaining_time = epoch_avg_step_time * (total_steps_in_epoch - epoch_steps_done)
                total_elapsed_time = current_time - overall_start_time
                total_steps_done = epoch * total_steps_in_epoch + epoch_steps_done
                total_avg_step_time = total_elapsed_time / total_steps_done if total_steps_done > 0 else 0
                total_remaining_time = total_avg_step_time * (total_training_steps - total_steps_done) if total_steps_done > 0 else 0

                # Throughput of this process over the last log window.
                interval_elapsed_time = current_time - last_log_time
                tokens_processed_interval = args.log_interval * args.batch_size * args.max_seq_len
                tokens_per_sec = tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0
                last_log_time = current_time

                log_dict = {
                    "epoch": epoch + 1,
                    "step": step + 1,
                    "total_steps_in_epoch": total_steps_in_epoch,
                    "loss": loss.item() * args.accumulation_steps,  # undo the accumulation scaling
                    "lr": current_lr,
                    "tokens_per_sec": tokens_per_sec,
                    "epoch_time_left_seconds": epoch_remaining_time,
                    "total_time_left_seconds": total_remaining_time
                }
                Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
                       f"Loss: {log_dict['loss']:.4f}, "
                       f"LR: {log_dict['lr']:.6f}, "
                       f"Speed: {log_dict['tokens_per_sec']:.2f} tokens/sec | "
                       f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
                       f"Total Time Left: {format_time(total_remaining_time)}", accelerator)
                if args.use_wandb and accelerator.is_main_process and wandb:
                    wandb.log(log_dict)

            if (step + 1) % args.save_interval == 0 and accelerator.is_main_process:
                ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
                accelerator.save(unwrapped_model.state_dict(), ckp)
                Logger(f"Model saved to {ckp}", accelerator)

        except Exception as e:
            Logger(f"Error in training step: {e}", accelerator)
            import traceback
            Logger(traceback.format_exc(), accelerator)
def main():
    """Parse CLI arguments, build model/data/optimizer, and run pretraining with Accelerate + DeepSpeed ZeRO-2."""

    def _str2bool(value):
        """Parse a CLI boolean. ``type=bool`` would turn any non-empty string, even "False", into True."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")

    parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
    parser.add_argument("--out_dir", type=str, default="out")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=24)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--dtype", type=str, default="bfloat16")
    # Flags that default to True cannot be switched off by store_true alone.
    # Each gets a --no_* twin writing False into the same destination.
    parser.add_argument("--use_wandb", default=True, action="store_true")
    parser.add_argument("--no_use_wandb", dest="use_wandb", action="store_false", help="Disable wandb logging")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
    parser.add_argument("--num_workers", type=int, default=48)
    parser.add_argument("--accumulation_steps", type=int, default=32)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--warmup_iters", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=100)
    parser.add_argument("--save_interval", type=int, default=10000)
    parser.add_argument('--dim', default=1024, type=int)
    parser.add_argument('--n_layers', default=32, type=int)
    parser.add_argument('--max_seq_len', default=1024, type=int)
    parser.add_argument('--use_moe', default=False, type=_str2bool)
    parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能使用固定值1e-4替代")
    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
    parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
    parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
    parser.add_argument("--no_profile", dest="profile", action="store_false", help="Disable performance profiling")
    parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
    parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
    parser.add_argument("--no_use_flash_attn", dest="use_flash_attn", action="store_false", help="Disable FlashAttention")
    parser.add_argument("--knowledge_num", type=int, default=65536,help="知识库的数据数目")
    parser.add_argument("--knowledge_length", type=int, default=64,help="知识库的句子长度")
    parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径")
    parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")
    parser.add_argument("--no_fast_clustering", dest="fast_clustering", action="store_false", help="Use the greedy similarity clustering instead")
    parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens.pt", help="聚类结果缓存文件路径")
    parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件")
    args = parser.parse_args()

    #########################################################
    # Accelerator + DeepSpeed setup
    #########################################################
    # Tolerate parameters that receive no gradient in a step (only used by DDP).
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    ds_plugin = DeepSpeedPlugin(
        gradient_accumulation_steps=args.accumulation_steps,
        gradient_clipping=args.grad_clip,
        zero_stage=2,                    # ZeRO-2
        offload_optimizer_device="cpu",  # keep optimizer state on the CPU
        offload_param_device="none",     # keep parameters on the device
    )
    accelerator = Accelerator(
        kwargs_handlers=[ddp_kwargs],
        deepspeed_plugin=ds_plugin,
        mixed_precision="bf16" if args.dtype == "bfloat16" else "fp16" if args.dtype == "float16" else "no"
    )

    # Per-process seed.
    set_seed(1337 + accelerator.process_index)

    #########################################################
    # Model configuration
    #########################################################
    lm_config = LMConfig(
        dim=args.dim,
        n_layers=args.n_layers,
        max_seq_len=args.max_seq_len,
        use_moe=args.use_moe,
        disable_db=args.disable_db,
        flash_attn=args.use_flash_attn,
        knowledge_num=args.knowledge_num,
        knowledge_length=args.knowledge_length
    )

    # Checkpoints are written to out_dir.
    args.save_dir = args.out_dir
    if accelerator.is_main_process:
        os.makedirs(args.save_dir, exist_ok=True)

    pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]

    #########################################################
    # wandb + configuration logging
    #########################################################
    args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
    # Built unconditionally: the configuration printout below needs it even when wandb is off.
    config_dict = vars(args).copy()
    config_dict.update(vars(lm_config))
    if args.use_wandb and accelerator.is_main_process:
        import wandb
        wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config_dict)
    else:
        wandb = None

    tokens_per_iter = args.batch_size * lm_config.max_seq_len
    if accelerator.is_main_process:
        Logger(f"tokens_per_iter: {tokens_per_iter}", accelerator)
        Logger("Configuration:", accelerator)
        for key, value in config_dict.items():
            Logger(f"  {key}: {value}", accelerator)

    # Autocast for the forward pass on accelerators. torch.cuda.amp.autocast is deprecated.
    if accelerator.device.type == "cpu":
        ctx = nullcontext()
    else:
        ctx = torch.autocast(device_type="cuda", dtype=pt_dtype)

    #########################################################
    # Model + tokenizer
    #########################################################
    model, tokenizer = init_model(lm_config, args.pretrained_embedding_path, args.database_init_path, args)
    Logger(f'模型初始化完成', accelerator)

    # Rotary position tensors: the real-valued variant participates in distributed
    # training, while the legacy complex pos_cis is excluded from DDP syncing.
    if hasattr(model, "pos_cis_real"):
        Logger(f'检测到pos_cis_real实数张量将其设置为参与分布式训练', accelerator)
    elif hasattr(model, "pos_cis"):
        Logger(f'检测到pos_cis复数张量将其设置为不参与分布式训练', accelerator)
        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}

    #########################################################
    # Data
    #########################################################
    train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=False,
        shuffle=True,
        num_workers=args.num_workers,
        persistent_workers=args.num_workers > 0,
        prefetch_factor=2 if args.num_workers > 0 else None
    )

    #########################################################
    # Optimizer + LR schedule (warmup defaults to 10% of the total steps)
    #########################################################
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
    total_steps = len(train_loader) * args.epochs
    warmup_steps = args.warmup_iters if args.warmup_iters > 0 else int(0.1 * total_steps)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    model, optimizer, train_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, scheduler
    )

    #########################################################
    # Training
    #########################################################
    overall_start_time = time.time()
    for epoch in range(args.epochs):
        train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb)

    if args.use_wandb and accelerator.is_main_process:
        wandb.finish()
# Script entry point: run training only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()