Compare commits

...

4 Commits

Author SHA1 Message Date
f750edd9ba update 2025-06-23 23:47:10 +08:00
5f19adcffa update 2025-06-23 23:05:47 +08:00
44cd7b4d72 update 2025-06-23 22:15:51 +08:00
83b91859ce update 2025-06-20 12:43:21 +08:00
2 changed files with 313 additions and 145 deletions

View File

@@ -2,7 +2,8 @@ import math
import struct
import inspect
import time
import gc
# Two-dimensional subspace decomposition + gradient updates
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
@@ -67,23 +68,21 @@ class KnowledgeDataset(nn.Module):
## Database parameters
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
self.keys = nn.Parameter(torch.randn(self.knowledge_num, self.knowledge_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.knowledge_num)
# Usage-frequency statistics - stored via register_buffer so they move correctly between GPU/CPU
self.register_buffer('has_update_keys', torch.zeros(self.knowledge_num))
# Store the keys in a 2-D decomposed space, as a trainable parameter
self.num_keys = int(math.sqrt(self.knowledge_num))
# Ensure keys is a trainable parameter
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.num_keys)
# Knowledge-base storage - register_buffer because these are integer indices and need no gradients
self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
)
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
# Track the step count for dynamically adjusting weights
self.step_counter = 0
self.freeze_embedding = False
# Removed the batch counter and update-frequency related code
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
@ -94,6 +93,15 @@ class KnowledgeDataset(nn.Module):
device = all_scores.device
dtype = all_scores.dtype
# Record the memory state before entering intelligent selection
if hasattr(self, 'step_counter'):
self.step_counter += 1
# GPU memory monitoring disabled to improve performance
# if self.step_counter % 50 == 0: # log once every 50 calls
# if torch.cuda.is_available():
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
# print(f"[INTEL_SELECT_ENTER] Step {self.step_counter}: GPU Memory: {allocated_before:.2f}GB")
# Hierarchical selection for each batch
enhanced_scores = all_scores.clone()
query_features = query.mean(dim=1) # [batch_size, dim]
@@ -106,7 +114,8 @@ class KnowledgeDataset(nn.Module):
candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
#Get the indices corresponding to flat_tokens
# Get the indices corresponding to flat_tokens; keep these variables for use elsewhere
pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
@@ -158,85 +167,64 @@ class KnowledgeDataset(nn.Module):
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
# Get
# Update self.keys using the recomputed embeddings
if self.is_train:
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
# Mark the modified keys as updated
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
# Free intermediate tensors to prevent memory leaks
del all_candidate_indices, unique_indices, inverse_indices
del unique_candidate_features, normalized_candidates, normalized_queries
del batch_best_tokens, batch_best_tokens_embeddings
del flat_tokens, flat_embeddings, pre_update_embeddings
# Record the memory state after leaving intelligent selection (disabled to improve performance)
# if hasattr(self, 'step_counter') and self.step_counter % 50 == 0:
# if torch.cuda.is_available():
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
# print(f"[INTEL_SELECT_EXIT] Step {self.step_counter}: GPU Memory: {allocated_after:.2f}GB")
# Force garbage collection (only on monitored steps)
if hasattr(self, 'step_counter') and self.step_counter % 100 == 0:
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return all_best_tokens, all_best_tokens_embeddings
def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
if self.freeze_embedding:
return
# Update self.keys using pre_update_embeddings
with torch.no_grad():
pre_update_embeddings = pre_update_embeddings.mean(dim=1) # [337, 512]
pre_update_embeddings = self.to_queries(pre_update_embeddings)
self.keys[pre_update_indices] = pre_update_embeddings
def search_index(self,x):
def search_index(self, x):
batch_size, seq_len, dim = x.shape
# collapse sequence dimension by averaging
# 1. Average over the sequence dimension
x_flat = x.mean(dim=1) # [batch_size, dim]
queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
# queries = queries.reshape(batch_size, 2, self.key_dim)
# queries = queries.permute(1, 0, 2)
# 2. Generate the query vector and reshape it into two sub-queries
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim]
# Reorder dimensions so the subspace dimension comes first
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
# 2. Compute the similarity between queries and keys
sim = torch.einsum('b d, k d -> b k', queries, self.keys)
# 3. Compute the similarity in each subspace
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 3. Top-k separately in each of the two subspaces
scores_and_indices = sim.topk(self.product_key_topk, dim=-1)
scores, indices = scores_and_indices[0], scores_and_indices[1]
# 4. Top-k separately in each of the two subspaces
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 5. Apply the intelligent hierarchical selection strategy
# 5. Combine the results from the two subspaces
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
# 6. Reshape the results to 2-D
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
# 7. Select the final top-k results
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
indices = torch.gather(all_indices, 1, indices_of_indices)
# 8. Apply the intelligent hierarchical selection strategy
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
# 6. Update 1% of the keys
if self.is_train:
# Get the indices of keys that have not been updated yet
not_updated_indices = torch.where(self.has_update_keys == 0)[0]
# If there are un-updated keys, randomly pick num_update_keys of them to update
if len(not_updated_indices) > 0:
num_update_keys = int(self.knowledge_num * 0.01)
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
perm_num = perm.shape[0]
pre_update_indices = not_updated_indices[perm]
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
pre_update_embeddings = pre_update_embeddings.view(perm_num, self.knowledge_length, -1)
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
# Mark the modified keys as updated
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
else:
print("all keys are updated")
# Reset the update status of all keys
self.has_update_keys.zero_()
# Re-fetch all updatable indices
not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)
num_update_keys = int(self.knowledge_num * 0.01)
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
pre_update_indices = not_updated_indices[perm]
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
pre_update_embeddings = pre_update_embeddings.view(num_update_keys, self.knowledge_length, -1)
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
# Mark the modified keys as updated
with torch.no_grad():
self.has_update_keys[pre_update_indices] = 1
return best_tokens, best_tokens_embeddings
class CrossAttention(nn.Module):
@@ -256,6 +244,16 @@ class CrossAttention(nn.Module):
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
# Monitor memory at the start of cross-attention (disabled to improve performance)
if not hasattr(self, 'call_counter'):
self.call_counter = 0
self.call_counter += 1
# GPU memory monitoring disabled to improve performance
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
# print(f"[CROSS_ATTN_ENTER] Call {self.call_counter}: GPU Memory: {allocated_before:.2f}GB")
# Split into multiple heads
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
@@ -281,6 +279,14 @@ class CrossAttention(nn.Module):
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
context = self.to_out(context)
# Free intermediate tensors
del q, k, v, attn_scores, attn_weights
# Monitor memory at the end of cross-attention (disabled to improve performance)
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
# print(f"[CROSS_ATTN_EXIT] Call {self.call_counter}: GPU Memory: {allocated_after:.2f}GB")
return context
@@ -520,12 +526,11 @@ class MiniMindLM(PreTrainedModel):
step: int = 0,
**args):
start_pos = args.get('start_pos', 0)
if self.freeze_embedding and step == 0:
self.tok_embeddings.weight.requires_grad = False
# Also freeze the KnowledgeDataset embedding updates
self.knowledge_dataset.freeze_embedding = True
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding)
# if self.freeze_embedding and step == 0:
# self.tok_embeddings.weight.requires_grad = False
# # Removed the knowledge_dataset.freeze_embedding setting; key updates are now controlled by batch_counter
# # self.knowledge_dataset.freeze_embedding = True
# print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
for l, layer in enumerate(self.layers):
@@ -600,4 +605,5 @@ class MiniMindLM(PreTrainedModel):
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id:
break
break
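Reviewer note: the core change in this file replaces a flat [knowledge_num, knowledge_dim] key table with a product-key decomposition, keys of shape [num_keys, 2, key_dim] with num_keys = int(sqrt(knowledge_num)), and a lookup that combines per-subspace top-k results. A minimal standalone sketch of that lookup logic, with made-up toy dimensions and a plain Linear standing in for self.to_queries:

import math
import torch

# Toy dimensions (illustrative assumptions, not values from this diff)
knowledge_num, dim, key_dim, batch_size, topk = 1024, 128, 64, 4, 16
num_keys = int(math.sqrt(knowledge_num)) # 32 keys per subspace

keys = torch.randn(num_keys, 2, key_dim) * 0.02 # same shape as self.keys above
to_queries = torch.nn.Linear(dim, 2 * key_dim) # stand-in for self.to_queries

x_flat = torch.randn(batch_size, dim) # sequence dimension already averaged
queries = to_queries(x_flat).reshape(batch_size, 2, key_dim).permute(1, 0, 2) # [2, B, key_dim]

sim = torch.einsum('p b d, k p d -> p b k', queries, keys) # per-subspace similarities
(scores_x, indices_x), (scores_y, indices_y) = sim[0].topk(topk, -1), sim[1].topk(topk, -1)

# Cartesian combination: pair (i, j) scores scores_x[i] + scores_y[j]
# and maps to the flat knowledge index i * num_keys + j
all_scores = (scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)).reshape(batch_size, -1)
all_indices = (indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2)).reshape(batch_size, -1)

scores, pos = all_scores.topk(topk, dim=-1)
indices = torch.gather(all_indices, 1, pos) # final top-k flat indices
print(scores.shape, indices.shape) # torch.Size([4, 16]) for both

Scoring thus touches 2*sqrt(N) keys per query instead of N, while the Cartesian sum still ranks all topk^2 candidate combinations.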

View File

@@ -1,6 +1,6 @@
import os
# Set environment variables
os.environ["WANDB_MODE"] = "offline" # or use "dryrun"
# Set environment variables - swap wandb for SwanLab
# os.environ["SWANLAB_MODE"] = "online" # SwanLab in online mode
import platform
import argparse
from tqdm import tqdm
@@ -21,6 +21,9 @@ from accelerate.utils import DistributedDataParallelKwargs
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import swanlab # replaces the wandb import
import gc # garbage-collection module
import psutil # system resource monitoring module
from model.model import MiniMindLM, RMSNorm
from model.LMConfig import LMConfig
@@ -28,6 +31,63 @@ from model.dataset import PretrainDataset
warnings.filterwarnings('ignore')
# Memory-monitoring helper functions
def get_memory_usage():
"""Return the current process memory usage"""
process = psutil.Process()
memory_info = process.memory_info()
return {
'rss_mb': memory_info.rss / 1024 / 1024, # resident set size, MB
'vms_mb': memory_info.vms / 1024 / 1024, # virtual memory size, MB
}
def get_cuda_memory_usage():
"""获取CUDA内存使用情况"""
if torch.cuda.is_available():
return {
'cuda_allocated_mb': torch.cuda.memory_allocated() / 1024 / 1024,
'cuda_reserved_mb': torch.cuda.memory_reserved() / 1024 / 1024,
'cuda_max_allocated_mb': torch.cuda.max_memory_allocated() / 1024 / 1024,
}
return {}
def get_tensor_memory_size(tensor_list):
"""计算tensor列表的总内存占用MB"""
total_size = 0
for batch in tensor_list:
if isinstance(batch, (list, tuple)):
for tensor in batch:
if isinstance(tensor, torch.Tensor):
total_size += tensor.numel() * tensor.element_size()
elif isinstance(batch, torch.Tensor):
total_size += batch.numel() * batch.element_size()
return total_size / 1024 / 1024 # convert to MB
def log_memory_status(step, prefetch_batches, accelerator, stage="", detailed=False):
"""记录内存状态"""
if not accelerator.is_main_process:
return
memory_info = get_memory_usage()
cuda_info = get_cuda_memory_usage()
prefetch_memory = get_tensor_memory_size(prefetch_batches)
log_msg = f"[Memory Monitor] Step {step} {stage} - "
log_msg += f"Prefetch batches: {len(prefetch_batches)}, "
log_msg += f"Prefetch memory: {prefetch_memory:.2f}MB, "
log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB"
if cuda_info:
log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB"
if detailed:
log_msg += f", System VMS: {memory_info['vms_mb']:.2f}MB"
if cuda_info:
log_msg += f", CUDA max allocated: {cuda_info['cuda_max_allocated_mb']:.2f}MB"
Logger(log_msg, accelerator)
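A quick smoke test for the helpers above (a sketch; it assumes get_memory_usage, get_cuda_memory_usage and get_tensor_memory_size are in scope, and the batch shapes are arbitrary):

import torch

# One prefetched batch shaped like (X, Y, loss_mask); sizes are made up
batches = [(torch.zeros(64, 512, dtype=torch.long),
torch.zeros(64, 512, dtype=torch.long),
torch.ones(64, 512))]
print(get_memory_usage()) # {'rss_mb': ..., 'vms_mb': ...}
print(get_cuda_memory_usage()) # {} on a CPU-only machine
print(f"{get_tensor_memory_size(batches):.3f} MB") # 0.25 + 0.25 + 0.125 = 0.625 MB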
# Logging helper
def Logger(msg, accelerator=None):
# If no accelerator is provided, print only on the main process
@@ -218,7 +278,7 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
Logger(f'Total trainable LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
return model, tokenizer
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb):
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run):
loss_fct = nn.CrossEntropyLoss(reduction='none')
epoch_start_time = time.time()
total_steps_in_epoch = len(train_loader)
@@ -226,6 +286,10 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
moe_path = '_moe' if args.use_moe else ''
best_loss = float('10000')
# Initialize the CUDA event variables
data_start = data_end = forward_start = forward_end = None
backward_start = backward_end = optimizer_start = optimizer_end = None
# Add CUDA events for profiling (main process only)
if args.profile and accelerator.is_main_process:
data_start = torch.cuda.Event(enable_timing=True)
@@ -242,40 +306,63 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
data_iter = iter(train_loader)
prefetch_batches = []
# Record the initial memory state
if args.memory_monitor:
log_memory_status(-1, prefetch_batches, accelerator, "before_prefetch", detailed=True)
# Prefetch the initial batches
for _ in range(min(prefetch_factor, len(train_loader))):
for i in range(min(prefetch_factor, len(train_loader))):
try:
batch = next(data_iter)
prefetch_batches.append(batch)
# Log the memory change after each batch is added
if args.memory_monitor and accelerator.is_main_process:
log_memory_status(-1, prefetch_batches, accelerator, f"after_adding_batch_{i+1}")
except StopIteration:
break
# Record the memory state after the initial prefetch
if args.memory_monitor:
log_memory_status(-1, prefetch_batches, accelerator, "after_initial_prefetch", detailed=True)
# Initialize the variables needed for logging before entering the loop
last_log_time = epoch_start_time
for step in range(total_steps_in_epoch):
try:
# Time data loading (main process only)
if args.profile and accelerator.is_main_process:
if args.profile and accelerator.is_main_process and data_start is not None:
data_start.record()
# Record the memory state before consuming a batch (detailed info at the configured interval)
if args.memory_monitor and step % args.memory_monitor_interval == 0:
log_memory_status(step, prefetch_batches, accelerator, "before_use_batch", detailed=True)
# Use prefetched data
if prefetch_batches:
X, Y, loss_mask = prefetch_batches.pop(0)
# Log the memory change after popping a batch
if args.memory_monitor and step % args.memory_monitor_interval == 0:
log_memory_status(step, prefetch_batches, accelerator, "after_pop_batch")
else:
# If the prefetch queue is empty, load directly
X, Y, loss_mask = next(data_iter)
if args.memory_monitor and accelerator.is_main_process:
Logger(f"[Memory Monitor] Step {step} - Prefetch queue empty, loading directly!", accelerator)
# Asynchronously prefetch the next batch
if step + prefetch_factor < len(train_loader):
try:
batch = next(data_iter)
prefetch_batches.append(batch)
# Log the memory change after adding a new batch
if args.memory_monitor and step % args.memory_monitor_interval == 0:
log_memory_status(step, prefetch_batches, accelerator, "after_add_batch")
except StopIteration:
pass
# End data-loading timing (main process only)
if args.profile and accelerator.is_main_process:
if args.profile and accelerator.is_main_process and data_end is not None:
data_end.record()
# Update the learning rate
@@ -283,7 +370,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
scheduler.step()
# Time the forward pass (main process only)
if args.profile and accelerator.is_main_process:
if args.profile and accelerator.is_main_process and forward_start is not None:
forward_start.record()
# Forward pass
@@ -310,11 +397,11 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
loss = loss / args.accumulation_steps
# End forward-pass timing (main process only)
if args.profile and accelerator.is_main_process:
if args.profile and accelerator.is_main_process and forward_end is not None:
forward_end.record()
# Time the backward pass (main process only)
if args.profile and accelerator.is_main_process:
if args.profile and accelerator.is_main_process and backward_start is not None:
backward_start.record()
# Backward pass
@@ -322,11 +409,11 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
accelerator.backward(loss)
# End backward-pass timing (main process only)
if args.profile and accelerator.is_main_process:
if args.profile and accelerator.is_main_process and backward_end is not None:
backward_end.record()
# Time the optimizer step (main process only)
if args.profile and accelerator.is_main_process:
if args.profile and accelerator.is_main_process and optimizer_start is not None:
optimizer_start.record()
# Optimizer step - with DeepSpeed, gradient accumulation and clipping are handled automatically
@@ -339,40 +426,58 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
optimizer.zero_grad()
# End optimizer-step timing (main process only)
if args.profile and accelerator.is_main_process:
if args.profile and accelerator.is_main_process and optimizer_end is not None:
optimizer_end.record()
# Print training info (main process only)
if (step + 1) % args.log_interval == 0 and accelerator.is_main_process:
current_time = time.time()
# Record the detailed memory state at log output
if args.memory_monitor:
log_memory_status(step, prefetch_batches, accelerator, "at_log_interval", detailed=True)
# Force garbage collection and log the memory change
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
log_memory_status(step, prefetch_batches, accelerator, "after_gc", detailed=True)
# Compute performance metrics
if args.profile:
if args.profile and accelerator.is_main_process:
torch.cuda.synchronize()
# Compute performance metrics from the time since the last log, not the total time
data_time = data_start.elapsed_time(data_end)
forward_time = forward_start.elapsed_time(forward_end)
backward_time = backward_start.elapsed_time(backward_end)
optimizer_time = optimizer_start.elapsed_time(optimizer_end)
iter_time = (current_time - last_log_time) * 1000 / args.log_interval # avg ms per iteration since last log
# total_time_ms = data_time + forward_time + backward_time + optimizer_time
# Only compute elapsed_time once all events have been recorded
try:
data_time = data_start.elapsed_time(data_end) if data_start is not None and data_end is not None else 0
forward_time = forward_start.elapsed_time(forward_end) if forward_start is not None and forward_end is not None else 0
backward_time = backward_start.elapsed_time(backward_end) if backward_start is not None and backward_end is not None else 0
optimizer_time = optimizer_start.elapsed_time(optimizer_end) if optimizer_start is not None and optimizer_end is not None else 0
iter_time = (current_time - last_log_time) * 1000 / args.log_interval # avg ms per iteration since last log
# total_time_ms = data_time + forward_time + backward_time + optimizer_time
# Print the profiling analysis
if (step + 1) % (args.log_interval * args.profile_interval) == 0:
Logger(f"Profiling (Avg/iter over last {args.log_interval} steps) - "
f"Data: {data_time/args.log_interval:.2f}ms, "
f"Fwd: {forward_time/args.log_interval:.2f}ms, "
f"Bwd: {backward_time/args.log_interval:.2f}ms, "
f"Optim: {optimizer_time/args.log_interval:.2f}ms, "
f"Iter Time: {iter_time:.2f}ms", accelerator)
# Reset the events so the next measurement starts from zero
data_start = torch.cuda.Event(enable_timing=True)
data_end = torch.cuda.Event(enable_timing=True)
forward_start = torch.cuda.Event(enable_timing=True)
forward_end = torch.cuda.Event(enable_timing=True)
backward_start = torch.cuda.Event(enable_timing=True)
backward_end = torch.cuda.Event(enable_timing=True)
optimizer_start = torch.cuda.Event(enable_timing=True)
optimizer_end = torch.cuda.Event(enable_timing=True)
# Print the profiling analysis
if (step + 1) % (args.log_interval * args.profile_interval) == 0:
Logger(f"Profiling (Avg/iter over last {args.log_interval} steps) - "
f"Data: {data_time/args.log_interval:.2f}ms, "
f"Fwd: {forward_time/args.log_interval:.2f}ms, "
f"Bwd: {backward_time/args.log_interval:.2f}ms, "
f"Optim: {optimizer_time/args.log_interval:.2f}ms, "
f"Iter Time: {iter_time:.2f}ms", accelerator)
# Reset the events so the next measurement starts from zero
data_start = torch.cuda.Event(enable_timing=True)
data_end = torch.cuda.Event(enable_timing=True)
forward_start = torch.cuda.Event(enable_timing=True)
forward_end = torch.cuda.Event(enable_timing=True)
backward_start = torch.cuda.Event(enable_timing=True)
backward_end = torch.cuda.Event(enable_timing=True)
optimizer_start = torch.cuda.Event(enable_timing=True)
optimizer_end = torch.cuda.Event(enable_timing=True)
except RuntimeError as e:
if "Both events must be recorded" in str(e):
Logger(f"Warning: CUDA events not properly recorded, skipping performance analysis: {e}", accelerator)
else:
raise e
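For context on the guard above: torch.cuda.Event timing only yields elapsed_time after both events have been recorded and the GPU has synchronized; otherwise it raises the "Both events must be recorded" RuntimeError handled here. A minimal sketch of the pattern on a toy workload:

import torch

if torch.cuda.is_available():
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
x = torch.randn(1024, 1024, device='cuda')
start.record() # enqueue the start marker on the current stream
y = x @ x # GPU work being timed
end.record() # enqueue the end marker
torch.cuda.synchronize() # both events must be recorded before reading
print(f"matmul: {start.elapsed_time(end):.2f} ms") # elapsed_time is in milliseconds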
# Compute the current learning rate
@@ -413,12 +518,12 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
f"Total Time Left: {format_time(total_remaining_time)}", accelerator)
if args.use_wandb and accelerator.is_main_process and wandb:
wandb.log(log_dict)
if args.use_swanlab and accelerator.is_main_process and swanlab_run:
swanlab_run.log(log_dict)
# Save the model (main process only)
loss_total = loss.item() * args.accumulation_steps
if best_loss > loss_total and accelerator.is_main_process:
if epoch > 1 and best_loss > loss_total and accelerator.is_main_process:
best_loss = loss_total
# Use the moe_path variable defined at the start of the function
ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
@@ -432,8 +537,33 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
except Exception as e:
Logger(f"Error in training step: {e}", accelerator)
# Record the memory state at the time of the exception
if args.memory_monitor:
log_memory_status(step, prefetch_batches, accelerator, "at_exception", detailed=True)
import traceback
Logger(traceback.format_exc(), accelerator)
# Clear prefetch_batches to prevent memory leaks
if args.memory_monitor and accelerator.is_main_process:
Logger(f"[Memory Monitor] Clearing prefetch_batches due to exception. Current length: {len(prefetch_batches)}", accelerator)
prefetch_batches.clear()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if args.memory_monitor:
log_memory_status(step, prefetch_batches, accelerator, "after_exception_cleanup", detailed=True)
# Clear prefetch_batches at the end of the training epoch
if args.memory_monitor:
if accelerator.is_main_process:
Logger(f"[Memory Monitor] Epoch {epoch+1} finished. Clearing prefetch_batches. Final length: {len(prefetch_batches)}", accelerator)
log_memory_status(total_steps_in_epoch-1, prefetch_batches, accelerator, "before_epoch_end_cleanup", detailed=True)
prefetch_batches.clear()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if args.memory_monitor:
log_memory_status(total_steps_in_epoch-1, prefetch_batches, accelerator, "after_epoch_end_cleanup", detailed=True)
def main():
parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
@@ -443,8 +573,8 @@ def main():
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--learning_rate", type=float, default=2e-4)
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", default=True, action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
parser.add_argument("--use_swanlab", default=True, action="store_true") # 替换wandb参数
parser.add_argument("--swanlab_project", type=str, default="MiniMind-Pretrain") # 替换wandb参数
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--accumulation_steps", type=int, default=32)
parser.add_argument("--grad_clip", type=float, default=1.0)
@@ -456,19 +586,21 @@ def main():
parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument('--disable_db', action='store_true', help="Disable the database feature and use a fixed value of 1e-4 instead")
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
parser.add_argument("--data_path", type=str, default="./dataset/merged_pretrain.jsonl")
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
parser.add_argument("--knowledge_num", type=int, default=8192,help="知识库的数据数目")
parser.add_argument("--knowledge_num", type=int, default=960400,help="知识库的数据数目")
parser.add_argument("--knowledge_length", type=int, default=32,help="知识库的句子长度")
parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径")
parser.add_argument("--database_init_path", type=str, default="./dataset/combined_prepare.json", help="数据库初始化路径")
parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")
parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="聚类结果缓存文件路径")
parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件")
parser.add_argument("--memory_monitor", action="store_true", default=False, help="启用内存监控")
parser.add_argument("--memory_monitor_interval", type=int, default=10, help="内存监控间隔(步数)")
args = parser.parse_args()
#########################################################
# 初始化accelerator和deepspeed
#########################################################
@@ -479,7 +611,7 @@ def main():
gradient_accumulation_steps=args.accumulation_steps,
gradient_clipping=args.grad_clip,
zero_stage=2, # use ZeRO-2 optimization
offload_optimizer_device="cpu", # offload optimizer state to CPU
offload_optimizer_device="none", # do not offload optimizer state
offload_param_device="none", # do not offload parameters to CPU
)
accelerator = Accelerator(
@@ -523,18 +655,30 @@ def main():
#########################################################
# Configure wandb
# Configure SwanLab
#########################################################
# Set the wandb run name
args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
if args.use_wandb and accelerator.is_main_process:
import wandb
# Merge args and lm_config into a single dict
config_dict = vars(args).copy()
config_dict.update(vars(lm_config))
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config_dict)
# Set the SwanLab run name
args.swanlab_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
# Merge args and lm_config into a single dict (needed for printing the config whether or not SwanLab is used)
config_dict = vars(args).copy()
config_dict.update(vars(lm_config))
# Initialize the SwanLab experiment instance
swanlab_run = None
if args.use_swanlab and accelerator.is_main_process:
# Initialize SwanLab
swanlab_run = swanlab.init(
project=args.swanlab_project,
experiment_name=args.swanlab_run_name,
description="MiniMind预训练实验使用本地部署的SwanLab进行可视化",
config=config_dict
# Set the SwanLab server address and API key
# host="http://100.123.118.114:11071",
# api_key="LesBT7HRq23HNBrOPKP8S"
)
else:
wandb = None
swanlab_run = None
#########################################################
# Print info
@@ -616,13 +760,31 @@ def main():
#########################################################
overall_start_time = time.time() # Record overall start time
for epoch in range(args.epochs):
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb) # Pass overall start time
Logger(f"开始第{epoch+1}轮训练", accelerator)
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run) # Pass overall start time
# Memory cleanup after each epoch
Logger(f"Epoch {epoch+1} finished, running memory cleanup", accelerator)
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Record the memory state at the end of the epoch
if accelerator.is_main_process:
memory_info = get_memory_usage()
cuda_info = get_cuda_memory_usage()
log_msg = f"[Memory Monitor] Epoch {epoch+1} completed - "
log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB"
if cuda_info:
log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB"
Logger(log_msg, accelerator)
#########################################################
# Close wandb
# Close SwanLab
#########################################################
if args.use_wandb and accelerator.is_main_process:
wandb.finish()
if args.use_swanlab and accelerator.is_main_process and swanlab_run:
swanlab_run.finish()
if __name__ == "__main__":
main()