Compare commits

..

4 Commits

Author SHA1 Message Date
f750edd9ba update 2025-06-23 23:47:10 +08:00
5f19adcffa update 2025-06-23 23:05:47 +08:00
44cd7b4d72 update 2025-06-23 22:15:51 +08:00
83b91859ce update 2025-06-20 12:43:21 +08:00
2 changed files with 313 additions and 145 deletions

View File

@ -2,7 +2,8 @@ import math
import struct import struct
import inspect import inspect
import time import time
import gc
#子空间二维分解+梯度更新
from .LMConfig import LMConfig from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union from typing import Any, Optional, Tuple, List, Union
import numpy as np import numpy as np
@ -67,23 +68,21 @@ class KnowledgeDataset(nn.Module):
## 数据库参数 ## 数据库参数
self.knowledge_num = params.knowledge_num self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length self.knowledge_length = params.knowledge_length
self.keys = nn.Parameter(torch.randn(self.knowledge_num, self.knowledge_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.knowledge_num)
# 使用频率统计 - 使用register_buffer以便在GPU/CPU间正确移动 # 修改键存储为二维分解空间,设置为可训练参数
self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num)) self.num_keys = int(math.sqrt(self.knowledge_num))
# 确保keys是可训练参数
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.num_keys)
# 知识库存储 - 使用register_buffer因为这是整数索引不需要梯度 # 知识库存储 - 使用register_buffer因为这是整数索引不需要梯度
self.register_buffer('knowledge_dataset', self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long) torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
)
# 计算step数目用于动态调整权重 # 计算step数目用于动态调整权重
self.step_counter = 0 self.step_counter = 0
self.freeze_embedding = False # 移除批次计数器和更新频率相关代码
def intelligent_selection(self, query, all_scores, all_indices): def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略""" """智能分层选择策略"""
@ -94,6 +93,15 @@ class KnowledgeDataset(nn.Module):
device = all_scores.device device = all_scores.device
dtype = all_scores.dtype dtype = all_scores.dtype
# 记录进入智能选择前的内存状态
if hasattr(self, 'step_counter'):
self.step_counter += 1
# 禁用GPU内存监控记录以提高性能
# if self.step_counter % 50 == 0: # 每50次调用记录一次
# if torch.cuda.is_available():
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
# print(f"[INTEL_SELECT_ENTER] Step {self.step_counter}: GPU Memory: {allocated_before:.2f}GB")
# 对每个batch进行分层选择 # 对每个batch进行分层选择
enhanced_scores = all_scores.clone() enhanced_scores = all_scores.clone()
query_features = query.mean(dim=1) # [batch_size, dim] query_features = query.mean(dim=1) # [batch_size, dim]
@ -106,7 +114,8 @@ class KnowledgeDataset(nn.Module):
candidate_tokens = self.knowledge_dataset[unique_indices] candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1) flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens) flat_embeddings = self.tok_embeddings(flat_tokens)
#获取flat_tokens对应的index
# 获取flat_tokens对应的index保留这些变量以便其他地方使用
pre_update_indices = unique_indices.view(-1) pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view( pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1 len(unique_indices), self.knowledge_length, -1
@ -158,84 +167,63 @@ class KnowledgeDataset(nn.Module):
all_best_tokens = torch.stack(batch_best_tokens, dim=0) all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0) all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
# 获取 # 清理中间张量以防止内存泄漏
del all_candidate_indices, unique_indices, inverse_indices
del unique_candidate_features, normalized_candidates, normalized_queries
del batch_best_tokens, batch_best_tokens_embeddings
del flat_tokens, flat_embeddings, pre_update_embeddings
# 使用重新计算的embeddings更新self.keys # 记录退出智能选择后的内存状态(已禁用以提高性能)
if self.is_train: # if hasattr(self, 'step_counter') and self.step_counter % 50 == 0:
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings) # if torch.cuda.is_available():
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
# print(f"[INTEL_SELECT_EXIT] Step {self.step_counter}: GPU Memory: {allocated_after:.2f}GB")
# 更新被修改过的key # 强制垃圾回收(仅在监控步骤)
with torch.no_grad(): if hasattr(self, 'step_counter') and self.step_counter % 100 == 0:
self.has_update_keys[pre_update_indices] = 1 gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return all_best_tokens, all_best_tokens_embeddings return all_best_tokens, all_best_tokens_embeddings
def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
if self.freeze_embedding:
return
# 使用pre_update_embeddings更新self.keys
with torch.no_grad():
pre_update_embeddings = pre_update_embeddings.mean(dim=1) # [337, 512]
pre_update_embeddings = self.to_queries(pre_update_embeddings)
self.keys[pre_update_indices] = pre_update_embeddings
def search_index(self,x):
def search_index(self, x):
batch_size, seq_len, dim = x.shape batch_size, seq_len, dim = x.shape
# collapse sequence dimension by averaging # 1. 序列维度平均
x_flat = x.mean(dim=1) # [batch_size, dim] x_flat = x.mean(dim=1) # [batch_size, dim]
queries = self.to_queries(x_flat) # [batch_size, 2*dim_key] # 2. 生成查询向量并重塑为两个子查询
# queries = queries.reshape(batch_size, 2, self.key_dim) queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
# queries = queries.permute(1, 0, 2) queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim]
# 调整维度顺序,使子空间维度位于首位
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
# 2. 计算queries与keys的相似度 # 3. 计算每个子空间的相似度
sim = torch.einsum('b d, k d -> b k', queries, self.keys) sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 3. 在两个子空间分别做top-k # 4. 在两个子空间分别做top-k
scores_and_indices = sim.topk(self.product_key_topk, dim=-1) scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores, indices = scores_and_indices[0], scores_and_indices[1] scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 5. 应用智能分层选择策略 # 5. 组合两个子空间的结果
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
# 6. 将结果重塑为二维
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
# 7. 选择最终的top-k结果
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
indices = torch.gather(all_indices, 1, indices_of_indices)
# 8. 应用智能分层选择策略
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices) best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
# 6. 更新1%的keys
if self.is_train:
# 获取未更新过的keys的索引
not_updated_indices = torch.where(self.has_update_keys == 0)[0]
# 如果有未更新的keys随机选择num_update_keys个进行更新
if len(not_updated_indices) > 0:
num_update_keys = int(self.knowledge_num * 0.01)
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
perm_num = perm.shape[0]
pre_update_indices = not_updated_indices[perm]
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
pre_update_embeddings = pre_update_embeddings.view(perm_num, self.knowledge_length, -1)
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
# 更新被修改过的key
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
else:
print("all keys are updated")
# 重置所有keys的更新状态
self.has_update_keys.zero_()
# 重新获取所有可更新的索引
not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)
num_update_keys = int(self.knowledge_num * 0.01)
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
pre_update_indices = not_updated_indices[perm]
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
pre_update_embeddings = pre_update_embeddings.view(num_update_keys, self.knowledge_length, -1)
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
# 更新被修改过的key
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
return best_tokens, best_tokens_embeddings return best_tokens, best_tokens_embeddings
@ -257,6 +245,16 @@ class CrossAttention(nn.Module):
def forward(self, x, db, context_mask=None, pos_emb=None): def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0) batch_size = x.size(0)
# 监控交叉注意力开始时的内存(已禁用以提高性能)
if not hasattr(self, 'call_counter'):
self.call_counter = 0
self.call_counter += 1
# 禁用GPU内存监控记录以提高性能
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
# print(f"[CROSS_ATTN_ENTER] Call {self.call_counter}: GPU Memory: {allocated_before:.2f}GB")
# 分离多头 # 分离多头
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
@ -282,6 +280,14 @@ class CrossAttention(nn.Module):
context = self.to_out(context) context = self.to_out(context)
# 清理中间张量
del q, k, v, attn_scores, attn_weights
# 监控交叉注意力结束时的内存(已禁用以提高性能)
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
# print(f"[CROSS_ATTN_EXIT] Call {self.call_counter}: GPU Memory: {allocated_after:.2f}GB")
return context return context
class Attention(nn.Module): class Attention(nn.Module):
@ -520,12 +526,11 @@ class MiniMindLM(PreTrainedModel):
step: int = 0, step: int = 0,
**args): **args):
start_pos = args.get('start_pos', 0) start_pos = args.get('start_pos', 0)
if self.freeze_embedding and step == 0: # if self.freeze_embedding and step == 0:
self.tok_embeddings.weight.requires_grad = False # self.tok_embeddings.weight.requires_grad = False
# 同时冻结KnowledgeDataset的嵌入更新 # # 移除对knowledge_dataset.freeze_embedding的设置让键更新由batch_counter控制
self.knowledge_dataset.freeze_embedding = True # # self.knowledge_dataset.freeze_embedding = True
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad) # print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding)
h = self.dropout(self.tok_embeddings(input_ids)) h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)] pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
for l, layer in enumerate(self.layers): for l, layer in enumerate(self.layers):
@ -601,3 +606,4 @@ class MiniMindLM(PreTrainedModel):
yield input_ids[:, start:] yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id: if input_ids_next.item() == eos_token_id:
break break

View File

@ -1,6 +1,6 @@
import os import os
# 设置环境变量 # 设置环境变量 - 将wandb替换为SwanLab
os.environ["WANDB_MODE"] = "offline" # 或者使用 "dryrun" # os.environ["SWANLAB_MODE"] = "online" # SwanLab使用在线模式
import platform import platform
import argparse import argparse
from tqdm import tqdm from tqdm import tqdm
@ -21,6 +21,9 @@ from accelerate.utils import DistributedDataParallelKwargs
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
import numpy as np import numpy as np
from sklearn.metrics.pairwise import cosine_similarity from sklearn.metrics.pairwise import cosine_similarity
import swanlab # 替换wandb导入
import gc # 添加垃圾回收模块
import psutil # 添加系统资源监控模块
from model.model import MiniMindLM, RMSNorm from model.model import MiniMindLM, RMSNorm
from model.LMConfig import LMConfig from model.LMConfig import LMConfig
@ -28,6 +31,63 @@ from model.dataset import PretrainDataset
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
# 内存监控辅助函数
def get_memory_usage():
    """Return the current process's host-memory usage in MiB.

    Returns:
        dict: Two float entries read from ``psutil.Process().memory_info()``
        and converted from bytes to MiB:

        * ``'rss_mb'`` - the ``rss`` field (physical memory in use).
        * ``'vms_mb'`` - the ``vms`` field (virtual memory in use).
    """
    bytes_per_mb = 1024 * 1024
    usage = psutil.Process().memory_info()
    return {
        'rss_mb': usage.rss / bytes_per_mb,
        'vms_mb': usage.vms / bytes_per_mb,
    }
def get_cuda_memory_usage(device=None):
    """Return CUDA memory statistics in MiB, or an empty dict without CUDA.

    Args:
        device: Optional device selector forwarded unchanged to the
            ``torch.cuda`` memory queries. The default ``None`` keeps the
            original behaviour of querying the current device.

    Returns:
        dict: Empty when ``torch.cuda.is_available()`` is false. Otherwise
        three float entries, each converted from bytes to MiB:

        * ``'cuda_allocated_mb'`` - ``torch.cuda.memory_allocated``
        * ``'cuda_reserved_mb'`` - ``torch.cuda.memory_reserved``
        * ``'cuda_max_allocated_mb'`` - ``torch.cuda.max_memory_allocated``
    """
    # Guard first so CPU-only runs never touch the CUDA runtime.
    if not torch.cuda.is_available():
        return {}

    bytes_per_mb = 1024 * 1024
    return {
        'cuda_allocated_mb': torch.cuda.memory_allocated(device) / bytes_per_mb,
        'cuda_reserved_mb': torch.cuda.memory_reserved(device) / bytes_per_mb,
        'cuda_max_allocated_mb': torch.cuda.max_memory_allocated(device) / bytes_per_mb,
    }
def get_tensor_memory_size(tensor_list):
    """Return the total size in MiB of all tensors held in ``tensor_list``.

    Each element of ``tensor_list`` may be a tensor or an arbitrarily nested
    list/tuple/dict of tensors (for example a ``(X, Y, loss_mask)`` batch or
    a dict-shaped batch). Anything that is not a tensor or one of those
    containers contributes nothing to the total.

    Args:
        tensor_list: Iterable of batches.

    Returns:
        float: Sum of ``numel() * element_size()`` over every tensor found,
        converted from bytes to MiB.
    """
    def _nbytes(item):
        # Size in bytes of one item, descending into nested containers.
        if isinstance(item, torch.Tensor):
            return item.numel() * item.element_size()
        if isinstance(item, dict):
            return sum(_nbytes(value) for value in item.values())
        if isinstance(item, (list, tuple)):
            return sum(_nbytes(value) for value in item)
        return 0  # non-tensor payloads (strings, ints, None, ...) are ignored

    # Iterate the top level directly so any iterable of batches is accepted,
    # not only list/tuple.
    total_bytes = sum(_nbytes(batch) for batch in tensor_list)
    return total_bytes / 1024 / 1024  # bytes -> MiB
def log_memory_status(step, prefetch_batches, accelerator, stage="", detailed=False):
    """Log a one-line memory summary for the given training step.

    Does nothing unless ``accelerator.is_main_process`` is true.

    Args:
        step: Step number printed in the message.
        prefetch_batches: Prefetch queue; its length and the tensor size
            reported by ``get_tensor_memory_size`` are included.
        accelerator: Object exposing ``is_main_process``; also forwarded to
            ``Logger``.
        stage: Free-form label printed after the step number.
        detailed: When true, also report virtual memory and, if CUDA
            statistics are available, the peak CUDA allocation.
    """
    if not accelerator.is_main_process:
        return

    host = get_memory_usage()
    cuda = get_cuda_memory_usage()
    prefetch_mb = get_tensor_memory_size(prefetch_batches)

    parts = [
        f"[Memory Monitor] Step {step} {stage} - ",
        f"Prefetch batches: {len(prefetch_batches)}, ",
        f"Prefetch memory: {prefetch_mb:.2f}MB, ",
        f"System RSS: {host['rss_mb']:.2f}MB",
    ]

    if cuda:
        parts.append(f", CUDA allocated: {cuda['cuda_allocated_mb']:.2f}MB")
        parts.append(f", CUDA reserved: {cuda['cuda_reserved_mb']:.2f}MB")

    # NOTE(review): the original indentation was lost in this view; the peak
    # CUDA figure is assumed to be part of the detailed output - confirm.
    if detailed:
        parts.append(f", System VMS: {host['vms_mb']:.2f}MB")
        if cuda:
            parts.append(f", CUDA max allocated: {cuda['cuda_max_allocated_mb']:.2f}MB")

    Logger("".join(parts), accelerator)
# 日志记录函数 # 日志记录函数
def Logger(msg, accelerator=None): def Logger(msg, accelerator=None):
# 如果没有提供accelerator则只在主进程打印 # 如果没有提供accelerator则只在主进程打印
@ -218,7 +278,7 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万') Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
return model, tokenizer return model, tokenizer
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb): def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run):
loss_fct = nn.CrossEntropyLoss(reduction='none') loss_fct = nn.CrossEntropyLoss(reduction='none')
epoch_start_time = time.time() epoch_start_time = time.time()
total_steps_in_epoch = len(train_loader) total_steps_in_epoch = len(train_loader)
@ -226,6 +286,10 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
moe_path = '_moe' if args.use_moe else '' moe_path = '_moe' if args.use_moe else ''
best_loss = float('10000') best_loss = float('10000')
# 初始化CUDA事件变量
data_start = data_end = forward_start = forward_end = None
backward_start = backward_end = optimizer_start = optimizer_end = None
# 添加CUDA事件来分析性能 (只在主进程进行) # 添加CUDA事件来分析性能 (只在主进程进行)
if args.profile and accelerator.is_main_process: if args.profile and accelerator.is_main_process:
data_start = torch.cuda.Event(enable_timing=True) data_start = torch.cuda.Event(enable_timing=True)
@ -242,40 +306,63 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
data_iter = iter(train_loader) data_iter = iter(train_loader)
prefetch_batches = [] prefetch_batches = []
# 记录初始内存状态
if args.memory_monitor:
log_memory_status(-1, prefetch_batches, accelerator, "before_prefetch", detailed=True)
# 预取初始批次 # 预取初始批次
for _ in range(min(prefetch_factor, len(train_loader))): for i in range(min(prefetch_factor, len(train_loader))):
try: try:
batch = next(data_iter) batch = next(data_iter)
prefetch_batches.append(batch) prefetch_batches.append(batch)
# 每次添加batch后记录内存变化
if args.memory_monitor and accelerator.is_main_process:
log_memory_status(-1, prefetch_batches, accelerator, f"after_adding_batch_{i+1}")
except StopIteration: except StopIteration:
break break
# 记录预取完成后的内存状态
if args.memory_monitor:
log_memory_status(-1, prefetch_batches, accelerator, "after_initial_prefetch", detailed=True)
# 在开始循环前初始化日志记录所需变量 # 在开始循环前初始化日志记录所需变量
last_log_time = epoch_start_time last_log_time = epoch_start_time
for step in range(total_steps_in_epoch): for step in range(total_steps_in_epoch):
try: try:
# 计时数据加载 (只在主进程进行) # 计时数据加载 (只在主进程进行)
if args.profile and accelerator.is_main_process: if args.profile and accelerator.is_main_process and data_start is not None:
data_start.record() data_start.record()
# 记录使用batch前的内存状态根据配置间隔记录详细信息
if args.memory_monitor and step % args.memory_monitor_interval == 0:
log_memory_status(step, prefetch_batches, accelerator, "before_use_batch", detailed=True)
# 使用预取的数据 # 使用预取的数据
if prefetch_batches: if prefetch_batches:
X, Y, loss_mask = prefetch_batches.pop(0) X, Y, loss_mask = prefetch_batches.pop(0)
# 记录使用batch后的内存变化
if args.memory_monitor and step % args.memory_monitor_interval == 0:
log_memory_status(step, prefetch_batches, accelerator, "after_pop_batch")
else: else:
# 如果预取队列为空,直接加载 # 如果预取队列为空,直接加载
X, Y, loss_mask = next(data_iter) X, Y, loss_mask = next(data_iter)
if args.memory_monitor and accelerator.is_main_process:
Logger(f"[Memory Monitor] Step {step} - Prefetch queue empty, loading directly!", accelerator)
# 异步预取下一批数据 # 异步预取下一批数据
if step + prefetch_factor < len(train_loader): if step + prefetch_factor < len(train_loader):
try: try:
batch = next(data_iter) batch = next(data_iter)
prefetch_batches.append(batch) prefetch_batches.append(batch)
# 记录添加新batch后的内存变化
if args.memory_monitor and step % args.memory_monitor_interval == 0:
log_memory_status(step, prefetch_batches, accelerator, "after_add_batch")
except StopIteration: except StopIteration:
pass pass
# 计时数据加载结束 (只在主进程进行) # 计时数据加载结束 (只在主进程进行)
if args.profile and accelerator.is_main_process: if args.profile and accelerator.is_main_process and data_end is not None:
data_end.record() data_end.record()
# 更新学习率 # 更新学习率
@ -283,7 +370,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
scheduler.step() scheduler.step()
# 计时前向传播 (只在主进程进行) # 计时前向传播 (只在主进程进行)
if args.profile and accelerator.is_main_process: if args.profile and accelerator.is_main_process and forward_start is not None:
forward_start.record() forward_start.record()
# 前向传播 # 前向传播
@ -310,11 +397,11 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
loss = loss / args.accumulation_steps loss = loss / args.accumulation_steps
# 计时前向传播结束 (只在主进程进行) # 计时前向传播结束 (只在主进程进行)
if args.profile and accelerator.is_main_process: if args.profile and accelerator.is_main_process and forward_end is not None:
forward_end.record() forward_end.record()
# 计时反向传播 (只在主进程进行) # 计时反向传播 (只在主进程进行)
if args.profile and accelerator.is_main_process: if args.profile and accelerator.is_main_process and backward_start is not None:
backward_start.record() backward_start.record()
# 反向传播 # 反向传播
@ -322,11 +409,11 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
accelerator.backward(loss) accelerator.backward(loss)
# 计时反向传播结束 (只在主进程进行) # 计时反向传播结束 (只在主进程进行)
if args.profile and accelerator.is_main_process: if args.profile and accelerator.is_main_process and backward_end is not None:
backward_end.record() backward_end.record()
# 计时优化器步骤 (只在主进程进行) # 计时优化器步骤 (只在主进程进行)
if args.profile and accelerator.is_main_process: if args.profile and accelerator.is_main_process and optimizer_start is not None:
optimizer_start.record() optimizer_start.record()
# 优化器步骤 - 当使用DeepSpeed时它会自动处理梯度累积和梯度裁剪 # 优化器步骤 - 当使用DeepSpeed时它会自动处理梯度累积和梯度裁剪
@ -339,20 +426,33 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
optimizer.zero_grad() optimizer.zero_grad()
# 计时优化器步骤结束 (只在主进程进行) # 计时优化器步骤结束 (只在主进程进行)
if args.profile and accelerator.is_main_process: if args.profile and accelerator.is_main_process and optimizer_end is not None:
optimizer_end.record() optimizer_end.record()
# 打印训练信息 (只在主进程进行) # 打印训练信息 (只在主进程进行)
if (step + 1) % args.log_interval == 0 and accelerator.is_main_process: if (step + 1) % args.log_interval == 0 and accelerator.is_main_process:
current_time = time.time() current_time = time.time()
# 记录日志输出时的详细内存状态
if args.memory_monitor:
log_memory_status(step, prefetch_batches, accelerator, "at_log_interval", detailed=True)
# 强制垃圾回收并记录内存变化
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
log_memory_status(step, prefetch_batches, accelerator, "after_gc", detailed=True)
# 计算性能指标 # 计算性能指标
if args.profile: if args.profile and accelerator.is_main_process:
torch.cuda.synchronize() torch.cuda.synchronize()
# 使用自上次日志以来的时间计算性能指标,而不是总时间
data_time = data_start.elapsed_time(data_end) # 确保所有事件都已记录才计算elapsed_time
forward_time = forward_start.elapsed_time(forward_end) try:
backward_time = backward_start.elapsed_time(backward_end) data_time = data_start.elapsed_time(data_end) if data_start is not None and data_end is not None else 0
optimizer_time = optimizer_start.elapsed_time(optimizer_end) forward_time = forward_start.elapsed_time(forward_end) if forward_start is not None and forward_end is not None else 0
backward_time = backward_start.elapsed_time(backward_end) if backward_start is not None and backward_end is not None else 0
optimizer_time = optimizer_start.elapsed_time(optimizer_end) if optimizer_start is not None and optimizer_end is not None else 0
iter_time = (current_time - last_log_time) * 1000 / args.log_interval # avg ms per iteration since last log iter_time = (current_time - last_log_time) * 1000 / args.log_interval # avg ms per iteration since last log
# total_time_ms = data_time + forward_time + backward_time + optimizer_time # total_time_ms = data_time + forward_time + backward_time + optimizer_time
@ -373,6 +473,11 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
backward_end = torch.cuda.Event(enable_timing=True) backward_end = torch.cuda.Event(enable_timing=True)
optimizer_start = torch.cuda.Event(enable_timing=True) optimizer_start = torch.cuda.Event(enable_timing=True)
optimizer_end = torch.cuda.Event(enable_timing=True) optimizer_end = torch.cuda.Event(enable_timing=True)
except RuntimeError as e:
if "Both events must be recorded" in str(e):
Logger(f"Warning: CUDA events not properly recorded, skipping performance analysis: {e}", accelerator)
else:
raise e
# 计算当前学习率 # 计算当前学习率
@ -413,12 +518,12 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
f"Epoch Time Left: {format_time(epoch_remaining_time)} | " f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
f"Total Time Left: {format_time(total_remaining_time)}", accelerator) f"Total Time Left: {format_time(total_remaining_time)}", accelerator)
if args.use_wandb and accelerator.is_main_process and wandb: if args.use_swanlab and accelerator.is_main_process and swanlab_run:
wandb.log(log_dict) swanlab_run.log(log_dict)
# 保存模型 (只在主进程进行) # 保存模型 (只在主进程进行)
loss_total = loss.item() * args.accumulation_steps loss_total = loss.item() * args.accumulation_steps
if best_loss > loss_total and accelerator.is_main_process: if epoch > 1 and best_loss > loss_total and accelerator.is_main_process:
best_loss = loss_total best_loss = loss_total
# 使用函数开始处定义的moe_path变量 # 使用函数开始处定义的moe_path变量
ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth' ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
@ -432,9 +537,34 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
except Exception as e: except Exception as e:
Logger(f"Error in training step: {e}", accelerator) Logger(f"Error in training step: {e}", accelerator)
# 记录异常时的内存状态
if args.memory_monitor:
log_memory_status(step, prefetch_batches, accelerator, "at_exception", detailed=True)
import traceback import traceback
Logger(traceback.format_exc(), accelerator) Logger(traceback.format_exc(), accelerator)
# 清理prefetch_batches防止内存泄漏
if args.memory_monitor and accelerator.is_main_process:
Logger(f"[Memory Monitor] Clearing prefetch_batches due to exception. Current length: {len(prefetch_batches)}", accelerator)
prefetch_batches.clear()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if args.memory_monitor:
log_memory_status(step, prefetch_batches, accelerator, "after_exception_cleanup", detailed=True)
# 训练epoch结束时清理prefetch_batches
if args.memory_monitor:
if accelerator.is_main_process:
Logger(f"[Memory Monitor] Epoch {epoch+1} finished. Clearing prefetch_batches. Final length: {len(prefetch_batches)}", accelerator)
log_memory_status(total_steps_in_epoch-1, prefetch_batches, accelerator, "before_epoch_end_cleanup", detailed=True)
prefetch_batches.clear()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if args.memory_monitor:
log_memory_status(total_steps_in_epoch-1, prefetch_batches, accelerator, "after_epoch_end_cleanup", detailed=True)
def main(): def main():
parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate") parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
parser.add_argument("--out_dir", type=str, default="out") parser.add_argument("--out_dir", type=str, default="out")
@ -443,8 +573,8 @@ def main():
parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--learning_rate", type=float, default=2e-4) parser.add_argument("--learning_rate", type=float, default=2e-4)
parser.add_argument("--dtype", type=str, default="bfloat16") parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", default=True, action="store_true") parser.add_argument("--use_swanlab", default=True, action="store_true") # 替换wandb参数
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain") parser.add_argument("--swanlab_project", type=str, default="MiniMind-Pretrain") # 替换wandb参数
parser.add_argument("--num_workers", type=int, default=8) parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--accumulation_steps", type=int, default=32) parser.add_argument("--accumulation_steps", type=int, default=32)
parser.add_argument("--grad_clip", type=float, default=1.0) parser.add_argument("--grad_clip", type=float, default=1.0)
@ -456,17 +586,19 @@ def main():
parser.add_argument('--max_seq_len', default=512, type=int) parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument('--use_moe', default=False, type=bool) parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能使用固定值1e-4替代") parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能使用固定值1e-4替代")
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl") parser.add_argument("--data_path", type=str, default="./dataset/merged_pretrain.jsonl")
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)") parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析") parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)") parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention") parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
parser.add_argument("--knowledge_num", type=int, default=8192,help="知识库的数据数目") parser.add_argument("--knowledge_num", type=int, default=960400,help="知识库的数据数目")
parser.add_argument("--knowledge_length", type=int, default=32,help="知识库的句子长度") parser.add_argument("--knowledge_length", type=int, default=32,help="知识库的句子长度")
parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径") parser.add_argument("--database_init_path", type=str, default="./dataset/combined_prepare.json", help="数据库初始化路径")
parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)") parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")
parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="聚类结果缓存文件路径") parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="聚类结果缓存文件路径")
parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件") parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件")
parser.add_argument("--memory_monitor", action="store_true", default=False, help="启用内存监控")
parser.add_argument("--memory_monitor_interval", type=int, default=10, help="内存监控间隔(步数)")
args = parser.parse_args() args = parser.parse_args()
######################################################### #########################################################
@ -479,7 +611,7 @@ def main():
gradient_accumulation_steps=args.accumulation_steps, gradient_accumulation_steps=args.accumulation_steps,
gradient_clipping=args.grad_clip, gradient_clipping=args.grad_clip,
zero_stage=2, # 使用ZeRO-2优化 zero_stage=2, # 使用ZeRO-2优化
offload_optimizer_device="cpu", # 将优化器状态卸载到CPU offload_optimizer_device="none", # 将优化器状态卸载到CPU
offload_param_device="none", # 不将参数卸载到CPU offload_param_device="none", # 不将参数卸载到CPU
) )
accelerator = Accelerator( accelerator = Accelerator(
@ -523,18 +655,30 @@ def main():
######################################################### #########################################################
# 配置wandb # 配置SwanLab
######################################################### #########################################################
# 设置wandb运行名称 # 设置SwanLab运行名称
args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}" args.swanlab_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
if args.use_wandb and accelerator.is_main_process:
import wandb # 合并args和lm_config为一个字典无论是否使用SwanLab都需要用于打印配置信息
# 合并args和lm_config为一个字典
config_dict = vars(args).copy() config_dict = vars(args).copy()
config_dict.update(vars(lm_config)) config_dict.update(vars(lm_config))
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config_dict)
# 初始化SwanLab实验实例
swanlab_run = None
if args.use_swanlab and accelerator.is_main_process:
# 初始化SwanLab
swanlab_run = swanlab.init(
project=args.swanlab_project,
experiment_name=args.swanlab_run_name,
description="MiniMind预训练实验使用本地部署的SwanLab进行可视化",
config=config_dict
# 设置SwanLab服务器地址和API Key
# host="http://100.123.118.114:11071",
# api_key="LesBT7HRq23HNBrOPKP8S"
)
else: else:
wandb = None swanlab_run = None
######################################################### #########################################################
# 打印信息 # 打印信息
@ -616,13 +760,31 @@ def main():
######################################################### #########################################################
overall_start_time = time.time() # Record overall start time overall_start_time = time.time() # Record overall start time
for epoch in range(args.epochs): for epoch in range(args.epochs):
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb) # Pass overall start time Logger(f"开始第{epoch+1}轮训练", accelerator)
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run) # Pass overall start time
# 每个epoch结束后进行内存清理
Logger(f"{epoch+1}轮训练完成,进行内存清理", accelerator)
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# 记录epoch结束时的内存状态
if accelerator.is_main_process:
memory_info = get_memory_usage()
cuda_info = get_cuda_memory_usage()
log_msg = f"[Memory Monitor] Epoch {epoch+1} completed - "
log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB"
if cuda_info:
log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB"
Logger(log_msg, accelerator)
######################################################### #########################################################
# 关闭wandb # 关闭SwanLab
######################################################### #########################################################
if args.use_wandb and accelerator.is_main_process: if args.use_swanlab and accelerator.is_main_process and swanlab_run:
wandb.finish() swanlab_run.finish()
if __name__ == "__main__": if __name__ == "__main__":
main() main()