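"""
Pretraining script for MiniMindLM built on HuggingFace Accelerate with a DeepSpeed (ZeRO-2) plugin.

Overview (derived from the code below):
  * init_model() builds the model and tokenizer, applies default weight initialization, optionally
    loads pretrained token embeddings, and optionally initializes the model's knowledge dataset by
    clustering sentences from a JSON database (with an on-disk cache of the clustered token ids).
  * train_epoch() runs one epoch with mixed precision, optional CUDA-event profiling, periodic
    logging and checkpointing, and wandb logging (offline mode by default).
  * main() parses arguments, sets up Accelerate/DeepSpeed, builds the dataset, optimizer and
    cosine schedule with warmup, and runs the training loop.
"""
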
import os

# Environment variables: run wandb in offline mode by default
os.environ["WANDB_MODE"] = "offline"  # or use "dryrun"

import platform
import argparse
from tqdm import tqdm
import time
import math
import warnings
import pandas as pd
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from contextlib import nullcontext
from typing import Optional
import datetime  # for time formatting
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.utils import DeepSpeedPlugin
from accelerate.utils import DistributedDataParallelKwargs
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

from model.model import MiniMindLM, RMSNorm
from model.LMConfig import LMConfig
from model.dataset import PretrainDataset

warnings.filterwarnings('ignore')

# Logging helper
def Logger(msg, accelerator=None):
    # Print when no accelerator is provided, otherwise only on the main process
    if accelerator is None or accelerator.is_main_process:
        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}")


# Helper function to format seconds into HH:MM:SS
def format_time(seconds):
    return str(datetime.timedelta(seconds=int(seconds)))


# Learning rate helper: plain cosine decay
def get_lr(it, num_iters, learning_rate):
    return learning_rate * 0.5 * (1.0 + math.cos(math.pi * it / num_iters))
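
# Note: get_lr() is not referenced elsewhere in this script; the training loop relies on the
# transformers get_cosine_schedule_with_warmup scheduler created in main() instead.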


# Model initialization
def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None):
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model = MiniMindLM(lm_config)

    # Default model initialization
    Logger("Performing default model initialization...")

    # Initialize the token embedding weights
    nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02)

    # Initialize the output projection (only if it is not weight-tied to the embeddings)
    if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight:
        nn.init.normal_(model.output.weight, mean=0.0, std=0.02)

    # Initialize all submodules by type
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            # Xavier/Glorot initialization for linear layers
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            # Normal initialization for embedding layers
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm):
            # RMSNorm weights start at 1
            if hasattr(module, 'weight'):
                nn.init.ones_(module.weight)

    # Initialize the knowledge-dataset keys
    if hasattr(model.knowledge_dataset, 'keys'):
        nn.init.normal_(model.knowledge_dataset.keys, mean=0.0, std=0.02)

    Logger("Default model initialization completed")

    # Load pretrained token embeddings if provided
    if pretrained_embedding_path:
        Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
        pretrained_embeddings = torch.load(pretrained_embedding_path)
        model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
        model.output.weight.data.copy_(pretrained_embeddings)  # keep the output projection tied

    if database_init_path:
        import json
        import numpy as np
        from sentence_transformers import SentenceTransformer
        import os

        # Clustering parameters (defined early so the cache check can use them)
        knowledge_num = args.knowledge_num
        knowledge_length = args.knowledge_length

        # Prepare the cache directory (checked early to avoid unnecessary data processing)
        cache_dir = os.path.dirname(args.cluster_cache_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)

        clustered_tensor = None

        # Try to load cached clustering results
        if not args.recompute_clusters and os.path.exists(args.cluster_cache_path):
            try:
                Logger(f"Loading cached cluster results from {args.cluster_cache_path}")
                clustered_tensor = torch.load(args.cluster_cache_path)

                # Validate the shape of the cached tensor
                cached_knowledge_num, cached_knowledge_length = clustered_tensor.shape

                if cached_knowledge_length == knowledge_length:
                    if cached_knowledge_num >= knowledge_num:
                        # The cache is large enough; truncate it to the required size
                        clustered_tensor = clustered_tensor[:knowledge_num, :]
                        Logger(f"Successfully loaded cached clusters with shape {clustered_tensor.shape}")
                        Logger(f"Truncated from cached shape ({cached_knowledge_num}, {cached_knowledge_length}) to required shape ({knowledge_num}, {knowledge_length})")
                        Logger("Skipping database initialization and clustering - using cached results")
                    else:
                        # The cache is too small; recompute
                        Logger(f"Cached knowledge_num ({cached_knowledge_num}) < required knowledge_num ({knowledge_num}), recomputing...")
                        clustered_tensor = None
                else:
                    # knowledge_length mismatch; recompute
                    Logger(f"Cached knowledge_length ({cached_knowledge_length}) != required knowledge_length ({knowledge_length}), recomputing...")
                    clustered_tensor = None
            except Exception as e:
                Logger(f"Failed to load cached clusters: {e}, recomputing...")
                clustered_tensor = None
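
        # Cache layout: a LongTensor of token ids with shape (knowledge_num, knowledge_length),
        # written with torch.save(). A cached file is reusable as long as its knowledge_length
        # matches and its knowledge_num is at least as large as the requested one.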

        # Only run database loading and clustering when there is no valid cache
        if clustered_tensor is None:
            Logger(f"Loading database initialization data from {database_init_path}")

            # 1. Load the JSON file
            with open(database_init_path, 'r', encoding='utf-8') as f:
                database_data = json.load(f)

            # Extract the sentences list
            sentences_data = database_data.get('sentences', [])
            Logger(f"Loaded {len(sentences_data)} sentences from database")

            # 2. Sort by importance_score (descending)
            sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
            Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")

            # 3. Load the local sentence-embedding model
            embedding_model_name = "sentence-transformers/all-mpnet-base-v2"  # compact general-purpose embedding model
            embedding_model_dir = "./models/sentence_transformers/models--sentence-transformers--all-mpnet-base-v2"
            embedding_cache_dir = "./models/sentence_transformers/cache"
            os.makedirs(embedding_cache_dir, exist_ok=True)

            Logger(f"Loading embedding model: {embedding_model_name}")
            try:
                embedding_model = SentenceTransformer(embedding_model_dir, cache_folder=embedding_cache_dir)
                Logger("Embedding model loaded successfully")
            except Exception as e:
                Logger(f"Failed to load embedding model: {e}")
                Logger("Falling back to random embeddings")
                embedding_model = None

            # 4. Compute an embedding and a token length for each corrected_sentence
            Logger("Processing sentences for embeddings and token lengths...")

            # Extract all sentences
            sentences = [sentence_data.get('corrected_sentence', '') for sentence_data in sorted_sentences]

            # Compute token lengths
            Logger("Computing token lengths...")
            token_lengths = []
            for sentence in sentences:
                tokens = tokenizer.encode(sentence, add_special_tokens=False)
                token_lengths.append(len(tokens))

            # Compute embeddings in batches - much faster than one sentence at a time
            Logger("Computing embeddings in batches...")
            embeddings_list = []
            batch_size = 256  # adjust to available GPU memory

            if embedding_model is not None:
                try:
                    for i in range(0, len(sentences), batch_size):
                        batch_sentences = sentences[i:i+batch_size]
                        batch_embeddings = embedding_model.encode(
                            batch_sentences,
                            convert_to_tensor=False,
                            show_progress_bar=True if i == 0 else False,
                            batch_size=batch_size
                        )
                        embeddings_list.extend(batch_embeddings)

                        if (i + batch_size) % (batch_size * 10) == 0:
                            Logger(f"Processed {min(i + batch_size, len(sentences))}/{len(sentences)} sentences")

                    Logger("Batch embedding computation completed")
                except Exception as e:
                    Logger(f"Error in batch encoding: {e}")
                    Logger("Falling back to random embeddings")
                    embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
            else:
                # Random-embedding fallback
                embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]

            # Build the processed sentence records
            processed_sentences = []
            for i, (sentence_data, embedding, token_length) in enumerate(zip(sorted_sentences, embeddings_list, token_lengths)):
                processed_sentences.append({
                    'sentence': sentence_data.get('corrected_sentence', ''),
                    'importance_score': sentence_data.get('importance_score', 0.0),
                    'token_length': token_length,
                    'embedding': embedding,  # stored as a numpy array
                    'original_index': i
                })

            # Convert to numpy arrays for later processing
            embeddings_array = np.array(embeddings_list)
            token_lengths_array = np.array(token_lengths)

            Logger("Embedding processing completed:")
            Logger(f"  - Total sentences: {len(processed_sentences)}")
            Logger(f"  - Embedding shape: {embeddings_array.shape}")
            Logger(f"  - Average token length: {np.mean(token_lengths_array):.2f}")
            Logger(f"  - Token length range: {np.min(token_lengths_array)} - {np.max(token_lengths_array)}")

            # Per-cluster token budget
            min_tokens = int(0.85 * knowledge_length)
            max_tokens = int(0.95 * knowledge_length)

            # Optimization 1: precompute the full similarity matrix when the dataset is small enough
            if len(processed_sentences) <= 10000:
                Logger("Pre-computing similarity matrix for faster clustering...")
                embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
                similarity_matrix = cosine_similarity(embeddings_matrix)
                Logger(f"Similarity matrix computed: {similarity_matrix.shape}")
            else:
                similarity_matrix = None
                embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])

            clustered_rows = []
            remaining_indices = list(range(len(processed_sentences)))  # track indices instead of whole records

            Logger(f"Target: {knowledge_num} clusters, each with {min_tokens}-{max_tokens} tokens")
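
            # Two clustering strategies are implemented below:
            #   * Fast approximate clustering (used when --fast_clustering is set and there are more
            #     than 5000 sentences): seeds are drawn at random from importance tiers and clusters
            #     are filled from a random candidate sample, so no similarity search is needed.
            #   * Greedy similarity clustering (for smaller datasets): each cluster is seeded with
            #     the most important remaining sentence and grown with its most similar neighbours
            #     until the [min_tokens, max_tokens] budget is reached.
            # In both cases every cluster is tokenized, then truncated or padded to knowledge_length.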
            # Choose the clustering algorithm
            if args.fast_clustering and len(processed_sentences) > 5000:
                Logger("Using ultra-fast approximate clustering algorithm...")

                # Ultra-fast clustering: random sampling + batch processing
                import random
                random.seed(42)  # for reproducibility

                # Stratified sampling by importance
                high_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] > 0.7]
                medium_importance = [i for i, s in enumerate(processed_sentences) if 0.3 <= s['importance_score'] <= 0.7]
                low_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] < 0.3]

                Logger(f"Importance distribution: High={len(high_importance)}, Medium={len(medium_importance)}, Low={len(low_importance)}")

                for cluster_idx in tqdm(range(knowledge_num)):
                    # Stratified seed selection: prefer high-importance sentences
                    if high_importance:
                        seed_pool = high_importance
                    elif medium_importance:
                        seed_pool = medium_importance
                    else:
                        seed_pool = low_importance if low_importance else list(range(len(processed_sentences)))

                    if not seed_pool:
                        break

                    # Pick a random seed within the chosen importance tier
                    seed_global_idx = random.choice(seed_pool)
                    seed_sentence = processed_sentences[seed_global_idx]

                    # Remove the seed from every pool
                    for pool in [high_importance, medium_importance, low_importance]:
                        if seed_global_idx in pool:
                            pool.remove(seed_global_idx)

                    current_cluster_indices = [seed_global_idx]
                    current_tokens = seed_sentence['token_length']

                    if current_tokens < max_tokens:
                        # Fast selection: sample random candidates instead of computing similarities
                        all_remaining = high_importance + medium_importance + low_importance
                        if all_remaining:
                            # Randomly sample candidate sentences
                            sample_size = min(2000, len(all_remaining))
                            candidates = random.sample(all_remaining, sample_size)

                            # Select simply by token length (importance only enters via the tiers)
                            for candidate_idx in candidates:
                                candidate = processed_sentences[candidate_idx]
                                candidate_tokens = candidate['token_length']

                                if current_tokens + candidate_tokens + 1 <= max_tokens:
                                    current_cluster_indices.append(candidate_idx)
                                    current_tokens += candidate_tokens + 1

                                    # Remove the candidate from its pool
                                    for pool in [high_importance, medium_importance, low_importance]:
                                        if candidate_idx in pool:
                                            pool.remove(candidate_idx)
                                            break

                                if current_tokens >= min_tokens:
                                    break

                    # Build the cluster text
                    cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
                    cluster_text = '\n '.join(cluster_sentences)

                    # Tokenize, then truncate or pad to knowledge_length
                    cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
                    if len(cluster_tokens) > knowledge_length:
                        cluster_tokens = cluster_tokens[:knowledge_length]
                    else:
                        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
                        cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))

                    clustered_rows.append(cluster_tokens)

                    if (cluster_idx + 1) % 1000 == 0:
                        total_remaining = len(high_importance) + len(medium_importance) + len(low_importance)
                        Logger(f"Fast clustering: {cluster_idx + 1}/{knowledge_num} clusters, {total_remaining} sentences remaining")

            else:
                # Original optimized algorithm (suitable for medium-sized datasets)
                # Optimization 2: batch processing and more efficient data structures
                for cluster_idx in tqdm(range(knowledge_num)):
                    if not remaining_indices:
                        Logger(f"No more sentences available. Created {cluster_idx} clusters.")
                        break

                    # 2.1 Use the remaining sentence with the highest importance_score as the seed
                    remaining_sentences_subset = [processed_sentences[i] for i in remaining_indices]
                    seed_idx_in_subset = max(range(len(remaining_sentences_subset)),
                                             key=lambda i: remaining_sentences_subset[i]['importance_score'])
                    seed_global_idx = remaining_indices[seed_idx_in_subset]
                    seed_sentence = processed_sentences[seed_global_idx]

                    # Remove the seed from the remaining indices
                    remaining_indices.remove(seed_global_idx)

                    # Current cluster
                    current_cluster_indices = [seed_global_idx]
                    current_tokens = seed_sentence['token_length']

                    if current_tokens >= max_tokens:
                        # The seed alone already exceeds the budget; it becomes a single-sentence cluster
                        cluster_text = seed_sentence['sentence']
                    else:
                        # 2.2 Optimized similarity computation and selection
                        if remaining_indices:
                            if similarity_matrix is not None:
                                # Use the precomputed similarity matrix
                                similarities = similarity_matrix[seed_global_idx][remaining_indices]
                            else:
                                # Compute similarities on the fly (batched)
                                seed_embedding = embeddings_matrix[seed_global_idx:seed_global_idx+1]
                                remaining_embeddings = embeddings_matrix[remaining_indices]
                                similarities = cosine_similarity(seed_embedding, remaining_embeddings)[0]

                            # Build (similarity, global index, position in remaining_indices) tuples
                            similarity_tuples = [(similarities[i], remaining_indices[i], i)
                                                 for i in range(len(remaining_indices))]

                            # Sort by similarity (descending)
                            similarity_tuples.sort(key=lambda x: x[0], reverse=True)

                            # Optimization 3: greedy selection restricted to the top candidates for speed
                            max_candidates = min(len(similarity_tuples), 500)  # only the 500 most similar sentences

                            selected_indices_in_remaining = []
                            for sim_score, global_idx, pos_in_remaining in similarity_tuples[:max_candidates]:
                                candidate = processed_sentences[global_idx]
                                candidate_tokens = candidate['token_length']

                                if current_tokens + candidate_tokens + 1 <= max_tokens:  # +1 for the newline separator
                                    current_cluster_indices.append(global_idx)
                                    selected_indices_in_remaining.append(pos_in_remaining)
                                    current_tokens += candidate_tokens + 1

                                if current_tokens >= min_tokens:
                                    break

                            # Remove the selected sentences in bulk (back to front to keep positions valid)
                            for pos in sorted(selected_indices_in_remaining, reverse=True):
                                remaining_indices.pop(pos)

                        # Join the selected sentences
                        cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
                        cluster_text = '\n'.join(cluster_sentences)

                    # Tokenize the cluster text
                    cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)

                    # Truncate or pad to knowledge_length
                    if len(cluster_tokens) > knowledge_length:
                        cluster_tokens = cluster_tokens[:knowledge_length]
                    else:
                        # Pad with pad_token_id
                        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
                        cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))

                    clustered_rows.append(cluster_tokens)

                    # Optimization 4: reduce logging frequency
                    if (cluster_idx + 1) % 500 == 0:
                        Logger(f"Created {cluster_idx + 1}/{knowledge_num} clusters, {len(remaining_indices)} sentences remaining")

            # If there are not enough clusters, pad with rows of pad tokens
            while len(clustered_rows) < knowledge_num:
                pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
                random_tokens = [pad_token_id] * knowledge_length
                clustered_rows.append(random_tokens)

            # Convert to a tensor
            clustered_tensor = torch.tensor(clustered_rows, dtype=torch.long)

            Logger("Clustering completed:")
            Logger(f"  - Created {len(clustered_rows)} clusters")
            Logger(f"  - Cluster shape: {clustered_tensor.shape}")
            Logger(f"  - Expected shape: ({knowledge_num}, {knowledge_length})")

            # Save the clustering results to the cache file
            try:
                torch.save(clustered_tensor, args.cluster_cache_path)
                Logger(f"Cluster results saved to {args.cluster_cache_path}")
            except Exception as e:
                Logger(f"Failed to save cluster results: {e}")

        # 3. Initialize the model's knowledge dataset with the clustered token ids
        if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
            model.knowledge_dataset.knowledge_dataset.data.copy_(clustered_tensor)
            Logger("Successfully initialized model.knowledge_dataset.knowledge_dataset with clustered data")
        else:
            Logger("Warning: Could not find model.knowledge_dataset.knowledge_dataset to initialize")
            # Fall back to storing the tensor in a global variable
            globals()['clustered_database'] = clustered_tensor

        Logger("Database embeddings and sentences stored in model")

    Logger(f'Total trainable LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
    return model, tokenizer


def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb):
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    epoch_start_time = time.time()
    total_steps_in_epoch = len(train_loader)
    total_training_steps = args.epochs * total_steps_in_epoch
    moe_path = '_moe' if args.use_moe else ''

    # CUDA events for profiling (main process only)
    if args.profile and accelerator.is_main_process:
        data_start = torch.cuda.Event(enable_timing=True)
        data_end = torch.cuda.Event(enable_timing=True)
        forward_start = torch.cuda.Event(enable_timing=True)
        forward_end = torch.cuda.Event(enable_timing=True)
        backward_start = torch.cuda.Event(enable_timing=True)
        backward_end = torch.cuda.Event(enable_timing=True)
        optimizer_start = torch.cuda.Event(enable_timing=True)
        optimizer_end = torch.cuda.Event(enable_timing=True)

    # Data prefetching
    prefetch_factor = 2  # number of batches to keep prefetched
    data_iter = iter(train_loader)
    prefetch_batches = []

    # Prefetch the initial batches
    for _ in range(min(prefetch_factor, len(train_loader))):
        try:
            batch = next(data_iter)
            prefetch_batches.append(batch)
        except StopIteration:
            break
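
    # Note: when num_workers > 0 the DataLoader already prefetches batches in its worker processes
    # (see prefetch_factor in main()), so this manual queue mainly overlaps the host-side handoff
    # with the request for the next batch rather than the actual data loading.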

    # Initialize logging bookkeeping before the loop starts
    last_log_time = epoch_start_time

    for step in range(total_steps_in_epoch):
        try:
            # Time data loading (main process only)
            if args.profile and accelerator.is_main_process:
                data_start.record()

            # Use a prefetched batch if available
            if prefetch_batches:
                X, Y, loss_mask = prefetch_batches.pop(0)
            else:
                # Prefetch queue is empty; load directly
                X, Y, loss_mask = next(data_iter)

            # Asynchronously prefetch the next batch
            if step + prefetch_factor < len(train_loader):
                try:
                    batch = next(data_iter)
                    prefetch_batches.append(batch)
                except StopIteration:
                    pass

            # End data-loading timing (main process only)
            if args.profile and accelerator.is_main_process:
                data_end.record()

            # Update the learning rate
            if scheduler is not None:
                scheduler.step()

            # Time the forward pass (main process only)
            if args.profile and accelerator.is_main_process:
                forward_start.record()

            # Forward pass
            with ctx:
                res = model(X)
                loss = loss_fct(
                    res.logits.view(-1, res.logits.size(-1)),
                    Y.view(-1)
                ).view(Y.size())
                loss = (loss * loss_mask).sum() / loss_mask.sum()
                # Add the auxiliary loss if any layer provides one
                try:
                    aux_loss = sum(l.feed_forward.aux_loss for l in model.module.layers
                                   if hasattr(l.feed_forward, 'aux_loss'))
                    loss += aux_loss
                except Exception as e:
                    Logger(f"Warning: Could not add auxiliary loss: {e}")
                    # Continue without the auxiliary loss
                loss = loss / args.accumulation_steps

            # End forward timing (main process only)
            if args.profile and accelerator.is_main_process:
                forward_end.record()

            # Time the backward pass (main process only)
            if args.profile and accelerator.is_main_process:
                backward_start.record()

            # Backward pass
            # When DeepSpeed is active it handles gradient accumulation and gradient clipping itself
            accelerator.backward(loss)

            # End backward timing (main process only)
            if args.profile and accelerator.is_main_process:
                backward_end.record()

            # Time the optimizer step (main process only)
            if args.profile and accelerator.is_main_process:
                optimizer_start.record()

            # Optimizer step - with DeepSpeed, gradient accumulation and clipping are handled
            # automatically, so parameters are only updated at accumulation boundaries and there is
            # no need to check step % accumulation_steps here
            optimizer.step()

            # With DeepSpeed, gradients are effectively reset after step(),
            # but zero_grad() is called explicitly to be safe
            optimizer.zero_grad()

            # End optimizer timing (main process only)
            if args.profile and accelerator.is_main_process:
                optimizer_end.record()

            # Print training info (main process only)
            if (step + 1) % args.log_interval == 0 and accelerator.is_main_process:
                current_time = time.time()
                # Collect profiling metrics
                if args.profile:
                    torch.cuda.synchronize()
                    # The CUDA events only capture the most recent iteration; iter_time below uses
                    # the wall-clock time since the last log instead of the total elapsed time
                    data_time = data_start.elapsed_time(data_end)
                    forward_time = forward_start.elapsed_time(forward_end)
                    backward_time = backward_start.elapsed_time(backward_end)
                    optimizer_time = optimizer_start.elapsed_time(optimizer_end)
                    iter_time = (current_time - last_log_time) * 1000 / args.log_interval  # avg ms per iteration since last log
                    # total_time_ms = data_time + forward_time + backward_time + optimizer_time

                    # Print the profiling summary
                    if (step + 1) % (args.log_interval * args.profile_interval) == 0:
                        Logger(f"Profiling (avg/iter over last {args.log_interval} steps) - "
                               f"Data: {data_time/args.log_interval:.2f}ms, "
                               f"Fwd: {forward_time/args.log_interval:.2f}ms, "
                               f"Bwd: {backward_time/args.log_interval:.2f}ms, "
                               f"Optim: {optimizer_time/args.log_interval:.2f}ms, "
                               f"Iter Time: {iter_time:.2f}ms", accelerator)
                        # Recreate the events so the next measurement starts from scratch
                        data_start = torch.cuda.Event(enable_timing=True)
                        data_end = torch.cuda.Event(enable_timing=True)
                        forward_start = torch.cuda.Event(enable_timing=True)
                        forward_end = torch.cuda.Event(enable_timing=True)
                        backward_start = torch.cuda.Event(enable_timing=True)
                        backward_end = torch.cuda.Event(enable_timing=True)
                        optimizer_start = torch.cuda.Event(enable_timing=True)
                        optimizer_end = torch.cuda.Event(enable_timing=True)

                # Current learning rate
                current_lr = optimizer.param_groups[0]['lr']

                # Time accounting
                epoch_elapsed_time = current_time - epoch_start_time
                epoch_steps_done = step + 1
                epoch_avg_step_time = epoch_elapsed_time / epoch_steps_done
                epoch_remaining_time = epoch_avg_step_time * (total_steps_in_epoch - epoch_steps_done)

                total_elapsed_time = current_time - overall_start_time
                total_steps_done = epoch * total_steps_in_epoch + epoch_steps_done
                total_avg_step_time = total_elapsed_time / total_steps_done if total_steps_done > 0 else 0
                total_remaining_time = total_avg_step_time * (total_training_steps - total_steps_done) if total_steps_done > 0 else 0

                # Training throughput (based on the most recent log_interval)
                interval_elapsed_time = current_time - last_log_time
                tokens_processed_interval = args.log_interval * args.batch_size * args.max_seq_len
                tokens_per_sec = tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0
                last_log_time = current_time  # update the last log time

                log_dict = {
                    "epoch": epoch + 1,
                    "step": step + 1,
                    "total_steps_in_epoch": total_steps_in_epoch,
                    "loss": loss.item() * args.accumulation_steps,
                    "lr": current_lr,
                    "tokens_per_sec": tokens_per_sec,
                    "epoch_time_left_seconds": epoch_remaining_time,
                    "total_time_left_seconds": total_remaining_time
                }

                Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
                       f"Loss: {log_dict['loss']:.4f}, "
                       f"LR: {log_dict['lr']:.6f}, "
                       f"Speed: {log_dict['tokens_per_sec']:.2f} tokens/sec | "
                       f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
                       f"Total Time Left: {format_time(total_remaining_time)}", accelerator)

                if args.use_wandb and accelerator.is_main_process and wandb:
                    wandb.log(log_dict)

            # Save a checkpoint (main process only)
            if (step + 1) % args.save_interval == 0 and accelerator.is_main_process:
                # Use the moe_path defined at the top of this function
                ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'

                # Unwrap the model from the accelerator/DeepSpeed wrapper
                unwrapped_model = accelerator.unwrap_model(model)

                # Save the model parameters
                accelerator.save(unwrapped_model.state_dict(), ckp)
                Logger(f"Model saved to {ckp}", accelerator)

        except Exception as e:
            Logger(f"Error in training step: {e}", accelerator)
            import traceback
            Logger(traceback.format_exc(), accelerator)


def main():
    parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
    parser.add_argument("--out_dir", type=str, default="out")
    parser.add_argument("--epochs", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=48)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_wandb", default=True, action="store_true")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--accumulation_steps", type=int, default=32)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--warmup_iters", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=100)
    parser.add_argument("--save_interval", type=int, default=10000)
    parser.add_argument('--dim', default=512, type=int)
    parser.add_argument('--n_layers', default=8, type=int)
    parser.add_argument('--max_seq_len', default=512, type=int)
    # A boolean flag is used here because argparse's type=bool treats any non-empty string as True
    parser.add_argument('--use_moe', default=False, action='store_true')
    parser.add_argument('--disable_db', action='store_true', help="Disable the knowledge database and use a fixed value of 1e-4 instead")
    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
    parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
    parser.add_argument("--profile", action="store_true", default=True, help="Enable performance profiling")
    parser.add_argument("--profile_interval", type=int, default=10, help="Profiling print interval (in steps)")
    parser.add_argument("--use_flash_attn", action="store_true", default=True, help="Enable FlashAttention")
    parser.add_argument("--knowledge_num", type=int, default=4096, help="Number of entries in the knowledge database")
    parser.add_argument("--knowledge_length", type=int, default=16, help="Token length of each knowledge entry")
    parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="Path to the database initialization JSON")
    parser.add_argument("--fast_clustering", action="store_true", default=True, help="Use the fast approximate clustering algorithm (for large datasets)")
    parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="Path to the cluster result cache file")
    parser.add_argument("--recompute_clusters", action="store_true", default=False, help="Force recomputation of clusters, ignoring any cache file")
    args = parser.parse_args()

    #########################################################
    # Initialize the accelerator and DeepSpeed
    #########################################################
    # Configure DDP kwargs to tolerate unused parameters
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    # Create the DeepSpeed plugin
    ds_plugin = DeepSpeedPlugin(
        gradient_accumulation_steps=args.accumulation_steps,
        gradient_clipping=args.grad_clip,
        zero_stage=2,                    # use ZeRO-2 optimization
        offload_optimizer_device="cpu",  # offload optimizer state to CPU
        offload_param_device="none",     # do not offload parameters
    )
    accelerator = Accelerator(
        kwargs_handlers=[ddp_kwargs],
        deepspeed_plugin=ds_plugin,
        mixed_precision="bf16" if args.dtype == "bfloat16" else "fp16" if args.dtype == "float16" else "no"
    )
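
    # With the DeepSpeed plugin attached, accelerator.backward() and optimizer.step() are routed
    # through the DeepSpeed engine, which applies the gradient accumulation and clipping configured
    # above; note that train_epoch additionally divides the loss by accumulation_steps before backward.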

    #########################################################
    # Set the random seed
    #########################################################
    set_seed(1337 + accelerator.process_index)

    #########################################################
    # Configure the model
    #########################################################
    lm_config = LMConfig(
        dim=args.dim,
        n_layers=args.n_layers,
        max_seq_len=args.max_seq_len,
        use_moe=args.use_moe,
        disable_db=args.disable_db,
        flash_attn=args.use_flash_attn,
        knowledge_num=args.knowledge_num,
        knowledge_length=args.knowledge_length
    )

    #########################################################
    # Create the save directory
    #########################################################
    args.save_dir = args.out_dir
    if accelerator.is_main_process:
        os.makedirs(args.save_dir, exist_ok=True)

    #########################################################
    # Set the data type
    #########################################################
    pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]

    #########################################################
    # Configure wandb
    #########################################################
    # Set the wandb run name
    args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"

    # Merge args and lm_config into a single config dict
    # (defined unconditionally because the configuration printout below also uses it)
    config_dict = vars(args).copy()
    config_dict.update(vars(lm_config))

    if args.use_wandb and accelerator.is_main_process:
        import wandb
        wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config_dict)
    else:
        wandb = None

    #########################################################
    # Print run information
    #########################################################
    # Tokens processed per iteration
    tokens_per_iter = args.batch_size * lm_config.max_seq_len
    if accelerator.is_main_process:
        Logger(f"tokens_per_iter: {tokens_per_iter}", accelerator)
        Logger("Configuration:", accelerator)
        for key, value in config_dict.items():
            Logger(f"  {key}: {value}", accelerator)

    #########################################################
    # Set up the automatic mixed-precision context
    #########################################################
    ctx = nullcontext() if accelerator.device.type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)

    #########################################################
    # Initialize the model and tokenizer
    #########################################################
    model, tokenizer = init_model(lm_config, args.pretrained_embedding_path, args.database_init_path, args)
    Logger('Model initialization complete', accelerator)

    #########################################################
    # Handle the positional-encoding tensors
    #########################################################
    if hasattr(model, "pos_cis_real"):
        Logger('Detected real-valued pos_cis_real tensor; it will take part in distributed training', accelerator)
        # Previously this buffer was excluded from DDP synchronization:
        # model._ddp_params_and_buffers_to_ignore = {"pos_cis_real"}
    # Backwards compatibility: check whether the complex-valued pos_cis is still present
    elif hasattr(model, "pos_cis"):
        Logger('Detected complex pos_cis tensor; excluding it from distributed training', accelerator)
        # Tell DDP to ignore this buffer
        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}

    #########################################################
    # Create the dataset and data loader
    #########################################################
    train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=False,
        shuffle=True,
        num_workers=args.num_workers,
        persistent_workers=True if args.num_workers > 0 else False,
        prefetch_factor=2 if args.num_workers > 0 else None
    )

    #########################################################
    # Create the optimizer
    #########################################################
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    #########################################################
    # Create the learning rate scheduler
    #########################################################
    total_steps = len(train_loader) * args.epochs
    warmup_steps = args.warmup_iters if args.warmup_iters > 0 else int(0.1 * total_steps)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )
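
    # Note on step counting: total_steps counts micro-batches (len(train_loader) * epochs) and
    # train_epoch calls scheduler.step() once per micro-batch, so warmup and cosine decay are
    # expressed in micro-steps even though parameter updates happen every accumulation_steps batches.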

    #########################################################
    # Prepare everything for distributed training
    #########################################################
    model, optimizer, train_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, scheduler
    )

    #########################################################
    # Training loop
    #########################################################
    overall_start_time = time.time()  # record the overall start time
    for epoch in range(args.epochs):
        train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb)

    #########################################################
    # Finish wandb
    #########################################################
    if args.use_wandb and accelerator.is_main_process:
        wandb.finish()


if __name__ == "__main__":
    main()
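
# Example launch (assumption: this script is saved as pretrain_accelerate.py; adjust the file name,
# data paths and process count to your setup):
#
#   accelerate launch --num_processes 1 pretrain_accelerate.py \
#       --epochs 4 --batch_size 48 --learning_rate 2e-4 \
#       --data_path ./dataset/pretrain_hq.jsonl \
#       --database_init_path ./dataset/database_init.json \
#       --cluster_cache_path ./cache/cluster_tokens_single.pt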