# Minimind/train_extra_accelerate.py
# Snapshot: 2025-07-05 03:03:43 +00:00 (1042 lines, 50 KiB, Python)
# Note from the hosting viewer: this file contains Unicode characters that might be
# confused with other characters (e.g. full-width punctuation in the Chinese comments
# and strings). This is intentional.

import os
# 设置环境变量 - 将wandb替换为SwanLab
# os.environ["SWANLAB_MODE"] = "online" # SwanLab使用在线模式
import platform
import argparse
from tqdm import tqdm
import time
import math
import warnings
import pandas as pd
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from contextlib import nullcontext
from typing import Optional
import datetime # Add datetime for time formatting
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.utils import DeepSpeedPlugin
from accelerate.utils import DistributedDataParallelKwargs
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import swanlab # 替换wandb导入
import gc # 添加垃圾回收模块
import psutil # 添加系统资源监控模块
import os
import json
os.environ['CUDA_VISIBLE_DEVICES']='2'
from model.model_extra import MiniMindLM, RMSNorm # 使用model_extra
from model.LMConfig import LMConfig
from model.dataset import TriplePretrainDataset # 只需要三元组数据集
warnings.filterwarnings('ignore')
# Embedding-based cosine-similarity loss
def compute_embedding_cosine_loss(subject_logits, predicate_logits, object_logits,
                                  target_triples, tokenizer, tok_embeddings,
                                  pooling_method='mean', max_targets=5, temperature=1.0):
    """
    Cosine-similarity loss between a soft-decoded triple and its best-matching target sentence.

    The predicted triple is embedded with soft embeddings (softmax over the vocabulary
    times the embedding matrix). Every target sentence is embedded with the same table.
    For each sample, the most similar *real* target is used. Slots that only exist
    because the sample's list was padded up to ``max_targets`` never take part in the max.

    Args:
        subject_logits: [batch_size, max_subject_len, vocab_size]
        predicate_logits: [batch_size, max_predicate_len, vocab_size]
        object_logits: [batch_size, max_object_len, vocab_size]
        target_triples: List[List[str]] - several target sentences per sample
        tokenizer: tokenizer used to encode the target sentences
        tok_embeddings: the model's token embedding layer
        pooling_method: sentence pooling method ('mean', 'max', 'cls')
        max_targets: maximum number of target sentences kept per sample
        temperature: softmax temperature controlling how smooth the predictions are
    Returns:
        torch.Tensor: scalar ``1 - mean(best cosine similarity)`` over samples that have at
        least one target, clamped to [0, 2]. When no sample has any target, returns a
        constant 1.0 that stays attached to the autograd graph of ``subject_logits``.
    """
    # Number of real (non-padding) targets per sample, capped like get_target_embeddings caps them.
    target_counts = [min(len(targets), max_targets) for targets in (target_triples or [])]
    if sum(target_counts) == 0:
        # Multiplying by 0.0 keeps the result connected to the graph.
        return subject_logits.sum() * 0.0 + 1.0

    # 1. Sentence embedding of the prediction: [batch_size, embed_dim]
    pred_embeddings = get_prediction_embeddings(
        subject_logits, predicate_logits, object_logits,
        tok_embeddings, pooling_method, temperature
    )
    # 2. Sentence embeddings of the (padded) targets: [batch_size, max_targets, embed_dim]
    target_embeddings = get_target_embeddings(
        target_triples, tokenizer, tok_embeddings, pooling_method, max_targets
    )
    # 3. Cosine similarity per target slot: [batch_size, max_targets]
    similarities = compute_cosine_similarity_batch(pred_embeddings, target_embeddings)

    # 4. Hide padding slots so they can never win the max.
    device = similarities.device
    counts = torch.tensor(target_counts, device=device)
    slot_index = torch.arange(similarities.shape[-1], device=device).unsqueeze(0)
    valid_slots = slot_index < counts.unsqueeze(1)
    best_similarities = similarities.masked_fill(~valid_slots, float('-inf')).max(dim=-1)[0]

    # 5. Average only over samples that actually have a target (others would be -inf).
    has_target = counts > 0
    loss = 1.0 - best_similarities[has_target].mean()
    # Keep the value in the theoretical [0, 2] range of 1 - cosine.
    loss = torch.clamp(loss, min=0.0, max=2.0)
    return loss
def get_prediction_embeddings(subject_logits, predicate_logits, object_logits,
                              tok_embeddings, pooling_method='mean', temperature=1.0):
    """
    Turn triple logits into one differentiable sentence embedding per sample.

    Each part's logits are converted to a probability distribution (softmax at the given
    temperature). That distribution mixes the rows of the embedding matrix, giving
    "soft" token embeddings that keep gradients instead of an argmax lookup. The
    subject, predicate and object sequences are concatenated and then pooled.

    Args:
        subject_logits / predicate_logits / object_logits: [batch_size, part_len, vocab_size]
        tok_embeddings: embedding layer whose ``weight`` is [vocab_size, embed_dim]
        pooling_method: 'max' (element-wise max), 'cls' (first position); anything else averages
        temperature: softmax temperature
    Returns:
        torch.Tensor: [batch_size, embed_dim]
    """
    vocab_matrix = tok_embeddings.weight
    soft_parts = [
        torch.matmul(torch.softmax(part_logits / temperature, dim=-1), vocab_matrix)
        for part_logits in (subject_logits, predicate_logits, object_logits)
    ]
    token_sequence = torch.cat(soft_parts, dim=1)  # [batch_size, total_len, embed_dim]

    poolers = {
        'max': lambda seq: seq.max(dim=1)[0],
        'cls': lambda seq: seq[:, 0, :],
    }
    pool = poolers.get(pooling_method, lambda seq: seq.mean(dim=1))
    return pool(token_sequence)
def get_target_embeddings(target_triples, tokenizer, tok_embeddings, pooling_method='mean', max_targets=5):
    """
    Embed every sample's target sentences in a single tokenizer/embedding pass.

    Each sample's list is truncated to ``max_targets``, or padded with empty strings up to
    it. Sentences with no attended token (e.g. a padding '') come out as all-zero vectors
    for every pooling method, so they never introduce -inf/NaN downstream.

    Args:
        target_triples: List[List[str]] - target sentences per sample
        tokenizer: callable returning 'input_ids' and 'attention_mask' tensors
        tok_embeddings: token embedding layer (exposes ``weight`` and ``embedding_dim``)
        pooling_method: 'mean' (masked mean), 'max' (masked max), 'cls' (first attended
            token, for either padding side); anything else falls back to the masked mean
        max_targets: number of target slots per sample
    Returns:
        torch.Tensor: [batch_size, max_targets, embed_dim]
    """
    batch_size = len(target_triples) if target_triples else 0
    if not target_triples:
        # Zero embeddings derived from the weight so the result stays in the graph.
        embed_dim = tok_embeddings.embedding_dim
        return tok_embeddings.weight[:1, :].expand(batch_size, max_targets, embed_dim) * 0.0

    # Flatten to exactly batch_size * max_targets sentences.
    normalized_targets = []
    for targets in target_triples:
        kept = targets[:max_targets]
        normalized_targets.extend(kept)
        normalized_targets.extend([''] * (max_targets - len(kept)))

    tokenized = tokenizer(
        normalized_targets,
        padding=True,
        truncation=True,
        return_tensors='pt',
        max_length=128  # adjustable
    )
    device = tok_embeddings.weight.device
    input_ids = tokenized['input_ids'].to(device)
    attention_mask = tokenized['attention_mask'].to(device)

    token_embeddings = tok_embeddings(input_ids)  # [batch_size * max_targets, seq_len, embed_dim]
    has_tokens = attention_mask.sum(dim=1, keepdim=True) > 0  # [batch_size * max_targets, 1]

    if pooling_method == 'max':
        # Pad positions become -inf so they never win the max.
        masked_embeddings = token_embeddings.masked_fill(
            ~attention_mask.unsqueeze(-1).bool(), float('-inf')
        )
        sentence_embeddings = masked_embeddings.max(dim=1)[0]
    elif pooling_method == 'cls':
        # argmax returns the first maximal index, i.e. the first attended position.
        first_index = attention_mask.argmax(dim=1)
        rows = torch.arange(token_embeddings.shape[0], device=device)
        sentence_embeddings = token_embeddings[rows, first_index]
    else:
        masked_embeddings = token_embeddings * attention_mask.unsqueeze(-1)
        sentence_embeddings = masked_embeddings.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True).clamp(min=1e-8)

    # Empty sentences (all padding) -> zero vector, whatever the pooling method.
    sentence_embeddings = sentence_embeddings.masked_fill(~has_tokens, 0.0)

    embed_dim = sentence_embeddings.shape[-1]
    return sentence_embeddings.view(batch_size, max_targets, embed_dim)
def compute_cosine_similarity_batch(pred_embeddings, target_embeddings):
    """
    Cosine similarity of each sample's prediction against each of its targets.

    Both inputs are L2-normalized along the last axis. An all-zero vector normalizes to
    zero, so its similarity is 0.

    Args:
        pred_embeddings: [batch_size, embed_dim]
        target_embeddings: [batch_size, max_targets, embed_dim]
    Returns:
        similarities: [batch_size, max_targets]
    """
    unit = torch.nn.functional.normalize
    pred_unit = unit(pred_embeddings, p=2, dim=-1)[:, None]  # [batch_size, 1, embed_dim]
    target_unit = unit(target_embeddings, p=2, dim=-1)       # [batch_size, max_targets, embed_dim]
    return (pred_unit * target_unit).sum(dim=-1)
def triple_to_sentence(subject_logits, predicate_logits, object_logits, tokenizer,
                       predicate_cls_logits=None, predicate_vocab=None):
    """
    Greedily decode triple logits into "subject predicate object" sentences.

    Args:
        subject_logits: [batch_size, max_subject_len, vocab_size]
        predicate_logits: [batch_size, max_predicate_len, vocab_size]; ignored (may be None)
            when ``predicate_cls_logits`` is given
        object_logits: [batch_size, max_object_len, vocab_size]
        tokenizer: tokenizer exposing ``batch_decode``
        predicate_cls_logits: optional [batch_size, num_predicates]. If given, the predicate
            text is looked up from the argmax class instead of decoded from ``predicate_logits``.
        predicate_vocab: optional list of predicate strings, indexed by class id, used with
            ``predicate_cls_logits``. Defaults to the module-level ``PREDICATE_LIST``.
    Returns:
        List[str]: one sentence per sample. The sentence is "" when any of the three parts
        decodes to an empty string. Class ids outside the vocabulary become "<UNK>".
    """
    # argmax over the vocabulary axis -> token ids [batch_size, part_len]
    subject_ids = torch.argmax(subject_logits, dim=-1)
    object_ids = torch.argmax(object_logits, dim=-1)

    if predicate_cls_logits is not None:
        vocab = PREDICATE_LIST if predicate_vocab is None else predicate_vocab
        class_ids = torch.argmax(predicate_cls_logits, dim=-1).tolist()  # [batch_size]
        predicate_texts = [vocab[idx] if idx < len(vocab) else "<UNK>" for idx in class_ids]
    else:
        # Fallback: decode the generated predicate token sequence.
        predicate_ids = torch.argmax(predicate_logits, dim=-1)
        predicate_texts = tokenizer.batch_decode(predicate_ids, skip_special_tokens=True)

    subject_texts = tokenizer.batch_decode(subject_ids, skip_special_tokens=True)
    object_texts = tokenizer.batch_decode(object_ids, skip_special_tokens=True)

    sentences = []
    for subject_text, predicate_text, object_text in zip(subject_texts, predicate_texts, object_texts):
        subject = subject_text.strip()
        predicate = predicate_text.strip() if isinstance(predicate_text, str) else str(predicate_text)
        object_ = object_text.strip()
        if subject and predicate and object_:
            sentences.append(f"{subject} {predicate} {object_}")
        else:
            sentences.append("")
    return sentences
def compute_triple_rouge_loss_optimized(subject_logits, predicate_logits, object_logits,
                                        target_input_ids, target_attention_mask, tok_embeddings, temperature=1.0):
    """
    Single-target embedding cosine loss between a soft-decoded triple and a pre-tokenized target.

    Despite the name, no ROUGE score is computed. The prediction is the mean of the
    soft embeddings (softmax(logits / temperature) @ embedding matrix) over the
    concatenated subject/predicate/object positions. The target is the attention-masked
    mean of its token embeddings. The loss is 1 minus their mean cosine similarity.

    Args:
        subject_logits: [batch_size, max_subject_len, vocab_size]
        predicate_logits: [batch_size, max_predicate_len, vocab_size]
        object_logits: [batch_size, max_object_len, vocab_size]
        target_input_ids: [batch_size, target_seq_len] - pre-tokenized target sentence
        target_attention_mask: [batch_size, target_seq_len] - its attention mask
        tok_embeddings: the model's token embedding layer
        temperature: softmax temperature controlling how smooth the predictions are
    Returns:
        torch.Tensor: scalar loss clamped to [0, 2]
    """
    embedding_matrix = tok_embeddings.weight
    device = embedding_matrix.device
    # Targets may arrive on another device (e.g. CPU from the DataLoader).
    ids = target_input_ids.to(device)
    mask = target_attention_mask.to(device)

    # Prediction: soft embeddings of all three parts, mean-pooled -> [batch_size, embed_dim]
    soft_sequence = torch.cat(
        [torch.matmul(torch.softmax(logits / temperature, dim=-1), embedding_matrix)
         for logits in (subject_logits, predicate_logits, object_logits)],
        dim=1,
    )
    predicted = soft_sequence.mean(dim=1)

    # Target: masked mean of token embeddings -> [batch_size, embed_dim]
    token_vectors = tok_embeddings(ids)
    weighted = token_vectors * mask.unsqueeze(-1)
    target = weighted.sum(dim=1) / mask.sum(dim=1, keepdim=True).clamp(min=1e-8)

    normalize = torch.nn.functional.normalize
    cosine = (normalize(predicted, p=2, dim=-1) * normalize(target, p=2, dim=-1)).sum(dim=-1)  # [batch_size]
    return torch.clamp(1.0 - cosine.mean(), min=0.0, max=2.0)
def compute_triple_rouge_loss(subject_logits, predicate_logits, object_logits, target_triples, tokenizer, tok_embeddings, max_targets=5, temperature=1.0):
    """
    Legacy multi-target triple loss, kept for compatibility.

    Forwards to ``compute_embedding_cosine_loss`` with mean pooling. No ROUGE score is
    computed; the result is an embedding cosine-similarity loss.

    Args:
        subject_logits: [batch_size, max_subject_len, vocab_size]
        predicate_logits: [batch_size, max_predicate_len, vocab_size]
        object_logits: [batch_size, max_object_len, vocab_size]
        target_triples: List[List[str]] - several ground-truth triple sentences per sample
        tokenizer: tokenizer used for the targets
        tok_embeddings: the model's token embedding layer
        max_targets: maximum number of target sentences per sample
        temperature: softmax temperature controlling how smooth the predictions are
    Returns:
        torch.Tensor: scalar embedding cosine-similarity loss
    """
    loss = compute_embedding_cosine_loss(
        subject_logits,
        predicate_logits,
        object_logits,
        target_triples,
        tokenizer,
        tok_embeddings,
        pooling_method='mean',
        max_targets=max_targets,
        temperature=temperature,
    )
    return loss
# Memory-monitoring helpers
def get_memory_usage():
    """
    Report this process's memory footprint.

    Returns:
        dict: 'rss_mb' (resident/physical memory) and 'vms_mb' (virtual memory), both in MB.
    """
    mem = psutil.Process().memory_info()
    bytes_per_mb = 1024 ** 2
    return {
        'rss_mb': mem.rss / bytes_per_mb,
        'vms_mb': mem.vms / bytes_per_mb,
    }
def get_cuda_memory_usage():
    """
    Report PyTorch CUDA allocator statistics for the current device, in MB.

    Returns:
        dict: 'cuda_allocated_mb', 'cuda_reserved_mb' and 'cuda_max_allocated_mb', or an
        empty dict when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return {}
    readers = {
        'cuda_allocated_mb': torch.cuda.memory_allocated,
        'cuda_reserved_mb': torch.cuda.memory_reserved,
        'cuda_max_allocated_mb': torch.cuda.max_memory_allocated,
    }
    return {key: read() / 1024 / 1024 for key, read in readers.items()}
def get_tensor_memory_size(tensor_list):
    """
    Total storage size, in MB, of the tensors in ``tensor_list``.

    Each entry may be a tensor or a list/tuple whose tensor items are counted. Only one
    level of nesting is inspected. Anything that is not a tensor contributes 0.
    """
    def _nbytes(obj):
        # Tensor payload size in bytes; 0 for non-tensors.
        return obj.numel() * obj.element_size() if isinstance(obj, torch.Tensor) else 0

    total_size = 0
    for entry in tensor_list:
        if isinstance(entry, (list, tuple)):
            total_size += sum(_nbytes(item) for item in entry)
        else:
            total_size += _nbytes(entry)
    return total_size / 1024 / 1024
def log_memory_status(step, accelerator, stage="", detailed=False):
    """
    Log process RAM and (when available) CUDA memory usage, from the main process only.

    Args:
        step: step number shown in the message
        accelerator: object exposing ``is_main_process``; other processes return immediately
        stage: free-form label inserted after the step number
        detailed: also include virtual memory and the peak CUDA allocation
    """
    if not accelerator.is_main_process:
        return
    ram = get_memory_usage()
    gpu = get_cuda_memory_usage()

    fields = [f"System RSS: {ram['rss_mb']:.2f}MB"]
    if gpu:
        fields.append(f"CUDA allocated: {gpu['cuda_allocated_mb']:.2f}MB")
        fields.append(f"CUDA reserved: {gpu['cuda_reserved_mb']:.2f}MB")
    if detailed:
        fields.append(f"System VMS: {ram['vms_mb']:.2f}MB")
        if gpu:
            fields.append(f"CUDA max allocated: {gpu['cuda_max_allocated_mb']:.2f}MB")

    Logger(f"[Memory Monitor] Step {step} {stage} - " + ", ".join(fields), accelerator)
# Logging helper
def Logger(msg, accelerator=None):
    """
    Print ``msg`` prefixed with a local ``[YYYY-mm-dd HH:MM:SS]`` timestamp.

    If an accelerator is given, only its main process prints. Without one, the message is
    printed unconditionally.
    """
    if accelerator is not None and not accelerator.is_main_process:
        return
    timestamp = time.strftime('%Y-%m-%d %H:%M:%S')
    print(f"[{timestamp}] {msg}")
# Format a duration given in seconds as H:MM:SS
def format_time(seconds):
    """
    Render a duration as ``H:MM:SS`` (with a leading "N day(s), " part once it reaches 24h).

    Fractional seconds are truncated toward zero before formatting.
    """
    whole_seconds = int(seconds)
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)
# Learning-rate schedule helper
def get_lr(it, num_iters, learning_rate):
    """
    Cosine-decayed learning rate.

    Returns ``learning_rate`` at ``it == 0``, decaying along a half cosine to 0 at
    ``it == num_iters``.
    """
    angle = math.pi * it / num_iters
    return learning_rate * 0.5 * (1.0 + math.cos(angle))
# Load the predicate class vocabulary at import time.
# The file is expected to be a JSON array; the position of each predicate is its class id
# (see PREDICATE2ID below). The location can be overridden with the PREDICATE_VOCAB_PATH
# environment variable; otherwise the original hard-coded path is used.
PREDICATE_VOCAB_PATH = os.environ.get(
    'PREDICATE_VOCAB_PATH',
    '/home/rwkv/RWKV-TS/RETRO_TEST/extract/predicate_vocab.json',
)
with open(PREDICATE_VOCAB_PATH, 'r', encoding='utf-8') as f:
    PREDICATE_LIST = json.load(f)
PREDICATE2ID = {p: i for i, p in enumerate(PREDICATE_LIST)}
NUM_PREDICATES = len(PREDICATE_LIST)
# Model initialization
def _init_model_load_pretrained(model, pretrained_path):
    """Copy every name- and shape-compatible tensor from a checkpoint into ``model`` and log a summary."""
    pretrained_state_dict = torch.load(pretrained_path, map_location='cpu')
    Logger(f"Successfully loaded pretrained state_dict with {len(pretrained_state_dict)} parameters")

    # In PyTorch, state_dict() tensors are detached views of the live parameters/buffers,
    # so copy_ updates the model in place.
    model_state_dict = model.state_dict()
    loaded_params = []
    skipped_params = []
    for name, param in pretrained_state_dict.items():
        if name not in model_state_dict:
            skipped_params.append(f"{name} (not found in model)")
            continue
        if model_state_dict[name].shape != param.shape:
            Logger(f"Warning: Shape mismatch for {name}, expected {model_state_dict[name].shape}, got {param.shape}")
            skipped_params.append(f"{name} (shape mismatch)")
            continue
        model_state_dict[name].copy_(param)
        loaded_params.append(name)

    Logger(f"Loaded {len(loaded_params)} parameters from pretrained weights")
    Logger(f"Skipped {len(skipped_params)} parameters")

    # Show a few of the important loaded parameters.
    key_loaded = [name for name in loaded_params
                  if any(key in name for key in ['tok_embeddings', 'layers.0', 'knowledge_dataset', 'output', 'norm'])]
    if key_loaded:
        Logger("Key loaded parameters:")
        for name in key_loaded[:5]:
            Logger(f"{name}")
        if len(key_loaded) > 5:
            Logger(f" ... and {len(key_loaded) - 5} more")

    # Model entries that received nothing from the checkpoint keep their fresh initialization.
    # Report the triple_extraction_head ones, which the checkpoint is not expected to provide.
    # They must be looked up on the model side: they are absent from the checkpoint keys.
    loaded_set = set(loaded_params)
    triple_new = [name for name in model_state_dict
                  if name not in loaded_set and 'triple_extraction_head' in name]
    if triple_new:
        Logger("Triple extraction head parameters (newly initialized):")
        for name in triple_new[:3]:
            Logger(f" 🆕 {name}")
        if len(triple_new) > 3:
            Logger(f" ... and {len(triple_new) - 3} more")


def _init_model_default_weights(model):
    """Fallback initialization used when the pretrained checkpoint cannot be loaded."""
    Logger("Performing default model initialization...")
    nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02)
    # Re-init the output projection only when it is not tied to the token embedding.
    if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight:
        nn.init.normal_(model.output.weight, mean=0.0, std=0.02)

    for _, module in model.named_modules():
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm):
            if hasattr(module, 'weight'):
                nn.init.ones_(module.weight)

    # Knowledge-base keys, when the knowledge module exposes them.
    if hasattr(model.knowledge_dataset, 'keys'):
        nn.init.normal_(model.knowledge_dataset.keys, mean=0.0, std=0.02)
    Logger("Default model initialization completed")


def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None):
    """
    Create the tokenizer and a triple-mode MiniMindLM, then initialize its weights.

    Weights come from a fixed pretrained checkpoint; only entries whose name and shape
    match the new model are copied. If loading raises for any reason, the whole model is
    re-initialized with a default scheme instead. Finally, when ``pretrained_embedding_path``
    is given, the saved tensor is copied into both ``tok_embeddings.weight`` and
    ``output.weight``.

    Args:
        lm_config: LMConfig passed to MiniMindLM.
        pretrained_embedding_path: optional path to a saved embedding tensor (must be
            copy-compatible with ``tok_embeddings.weight``).
        database_init_path: not used by this function; kept for call compatibility.
        args: not used by this function; kept for call compatibility.
    Returns:
        tuple: (model, tokenizer)
    """
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model = MiniMindLM(lm_config, mode="triple", num_predicates=NUM_PREDICATES)

    pretrained_path = "/home/rwkv/RWKV-TS/RETRO_TEST/extract/Experiment_1_2_2_pretrain_512.pth"
    Logger(f"Loading pretrained weights from {pretrained_path}")
    try:
        _init_model_load_pretrained(model, pretrained_path)
    except Exception as e:
        # Deliberately best-effort: any loading failure falls back to a fresh initialization.
        Logger(f"Error loading pretrained weights: {e}")
        Logger("Falling back to default initialization...")
        _init_model_default_weights(model)

    if pretrained_embedding_path:
        Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
        pretrained_embeddings = torch.load(pretrained_embedding_path)
        model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
        model.output.weight.data.copy_(pretrained_embeddings)  # keep the output tied to the embedding

    Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
    return model, tokenizer
def _train_epoch_cuda_events():
    """Return 8 fresh timing-enabled CUDA events (forward/loss/backward/optimizer start and end)."""
    return tuple(torch.cuda.Event(enable_timing=True) for _ in range(8))


def _event_elapsed_ms(start, end):
    """Milliseconds between two recorded CUDA events; 0 when either event is missing."""
    if start is None or end is None:
        return 0
    return start.elapsed_time(end)


def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer):
    """
    Run one epoch of triple-extraction training.

    Per step:
      - forward pass;
      - loss = 0.99 * embedding-cosine triple loss + 0.01 * predicate cross-entropy,
        scaled by ``1 / args.accumulation_steps``;
      - backward through the accelerator, then optimizer/scheduler steps.

    Periodically (main process only) it also:
      - logs losses, LR, throughput and ETA to the console and to SwanLab;
      - reports CUDA-event timings of the most recent step when ``args.profile`` is set;
      - logs memory usage every ``args.memory_monitor_interval`` steps when
        ``args.memory_monitor`` is set.

    The unwrapped model is saved whenever the step loss improves on the best seen in this
    epoch, from epoch index 2 onward. Any exception inside a step is logged, and that step
    is skipped.

    Args:
        epoch: zero-based epoch index.
        accelerator: accelerate.Accelerator used for backward, unwrapping and saving.
        model: the prepare()d (possibly wrapped) model.
        train_loader: DataLoader yielding dict batches.
        optimizer: prepared optimizer.
        scheduler: prepared LR scheduler, or None.
        args: parsed CLI arguments.
        ctx: autocast (or null) context wrapping the forward pass.
        overall_start_time: wall-clock start of the whole run, used for ETA estimates.
        swanlab_run: SwanLab run object, or None.
        tokenizer: tokenizer used to decode example predictions.
    """
    import traceback

    epoch_start_time = time.time()
    total_steps_in_epoch = len(train_loader)
    total_training_steps = args.epochs * total_steps_in_epoch
    moe_path = '_moe' if args.use_moe else ''
    best_loss = float('inf')  # best unscaled step loss seen in this epoch

    # GPU-side timing via CUDA events: main process only, and only when profiling.
    profiling = args.profile and accelerator.is_main_process
    if profiling:
        (forward_start, forward_end, loss_start, loss_end,
         backward_start, backward_end, optimizer_start, optimizer_end) = _train_epoch_cuda_events()
    else:
        forward_start = forward_end = loss_start = loss_end = None
        backward_start = backward_end = optimizer_start = optimizer_end = None

    if args.memory_monitor:
        memory_info = get_memory_usage()
        cuda_info = get_cuda_memory_usage()
        log_msg = f"[Memory Monitor] Training start - System RSS: {memory_info['rss_mb']:.2f}MB"
        if cuda_info:
            log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
        Logger(log_msg, accelerator)

    last_log_time = epoch_start_time
    # Read the embedding table from the unwrapped module: wrappers such as DDP do not
    # forward attribute access, so model.tok_embeddings is not reliable after prepare().
    embedding_layer = accelerator.unwrap_model(model).tok_embeddings
    criterion_predicate = nn.CrossEntropyLoss()

    for step, batch_data in enumerate(train_loader):
        try:
            # === 1. Batch unpacking ===
            if not isinstance(batch_data, dict):
                raise ValueError("期望字典格式的批次数据,请确保使用 TriplePretrainDataset")
            X = batch_data['input_ids']
            target_input_ids = batch_data['target_input_ids']
            target_attention_mask = batch_data['target_attention_mask']
            target_sentences = batch_data['target_sentences']  # only used for debug output

            # === 2. Learning-rate update ===
            if scheduler is not None:
                scheduler.step()

            # === 3. Forward pass ===
            if profiling and forward_start is not None:
                forward_start.record()
            with ctx:
                if step == 0 and args.embedding_epoch == epoch:
                    unwrapped_model = accelerator.unwrap_model(model)
                    unwrapped_model.freeze_embedding = True
                    Logger(f"Set freeze_embedding=True for epoch {epoch}, step {step}", accelerator)
                res = model(X, step=step)
            if profiling and forward_end is not None:
                forward_end.record()

            # === 4. Loss ===
            if step == 0:
                Logger("三元组提取训练模式", accelerator)
            if not (hasattr(res, 'predicate_logits') and hasattr(res, 'subject_logits') and hasattr(res, 'object_logits')):
                raise ValueError("模型没有输出三元组logits请检查模型配置")
            if target_input_ids is None:
                raise ValueError("没有三元组目标数据,请检查数据格式")
            try:
                if step == 0:
                    Logger("使用预tokenized三元组目标数据", accelerator)
                if profiling and loss_start is not None:
                    loss_start.record()
                loss_triple = compute_triple_rouge_loss_optimized(
                    res.subject_logits, res.predicate_logits, res.object_logits,
                    target_input_ids, target_attention_mask, embedding_layer, temperature=args.temperature
                )
                if profiling and loss_end is not None:
                    loss_end.record()
            except Exception as e:
                Logger(f"Error: ROUGE loss computation failed: {e}", accelerator)
                Logger(f"Traceback: {traceback.format_exc()}", accelerator)
                # Constant loss that stays attached to the graph.
                loss_triple = res.logits.sum() * 0.0 + 1.0

            loss_predicate = criterion_predicate(res.predicate_cls_logits, batch_data['predicate_label'].to(accelerator.device))
            loss = 0.99 * loss_triple + 0.01 * loss_predicate
            # Only the combined loss is scaled for gradient accumulation.
            loss = loss / args.accumulation_steps

            # === 5. Backward ===
            if profiling and backward_start is not None:
                backward_start.record()
            accelerator.backward(loss)
            if profiling and backward_end is not None:
                backward_end.record()

            # === 6. Optimizer step ===
            if profiling and optimizer_start is not None:
                optimizer_start.record()
            optimizer.step()
            optimizer.zero_grad()
            if profiling and optimizer_end is not None:
                optimizer_end.record()

            # === 7. Logging (main process only) ===
            if (step + 1) % args.log_interval == 0 and accelerator.is_main_process:
                current_time = time.time()

                if profiling:
                    torch.cuda.synchronize()
                    try:
                        # Events are re-recorded every step, so each pair holds the timing
                        # of the most recent step only (not an accumulated sum).
                        forward_time = _event_elapsed_ms(forward_start, forward_end)
                        loss_time = _event_elapsed_ms(loss_start, loss_end)
                        backward_time = _event_elapsed_ms(backward_start, backward_end)
                        optimizer_time = _event_elapsed_ms(optimizer_start, optimizer_end)
                        # Average wall-clock ms per iteration since the last log.
                        iter_time = (current_time - last_log_time) * 1000 / args.log_interval

                        if (step + 1) % (args.log_interval * args.profile_interval) == 0:
                            gpu_time_total = forward_time + loss_time + backward_time + optimizer_time
                            gpu_util = gpu_time_total / iter_time * 100 if iter_time > 0 else 0.0
                            Logger(f"=== GPU性能分析 (最近一步) ===", accelerator)
                            Logger(f"前向传播: {forward_time:.2f}ms, "
                                   f"损失计算: {loss_time:.2f}ms, "
                                   f"反向传播: {backward_time:.2f}ms, "
                                   f"优化器: {optimizer_time:.2f}ms", accelerator)
                            Logger(f"GPU总时间: {gpu_time_total:.2f}ms, "
                                   f"实际迭代时间: {iter_time:.2f}ms, "
                                   f"GPU利用率: {gpu_util:.1f}%", accelerator)
                            Logger("=" * 50, accelerator)

                            Logger("=== 三元组预测示例 ===", accelerator)
                            predict_sentences = triple_to_sentence(res.subject_logits, res.predicate_logits, res.object_logits, tokenizer)
                            for i, target_sentence in enumerate(target_sentences[:2]):
                                Logger(f"样本{i+1}目标: {target_sentence}", accelerator)
                                Logger(f"样本{i+1}预测: {predict_sentences[i]}", accelerator)
                            Logger("==================", accelerator)

                            (forward_start, forward_end, loss_start, loss_end,
                             backward_start, backward_end, optimizer_start, optimizer_end) = _train_epoch_cuda_events()
                    except RuntimeError as e:
                        if "Both events must be recorded" in str(e):
                            Logger(f"Warning: CUDA events not properly recorded, skipping performance analysis: {e}", accelerator)
                        else:
                            raise e

                current_lr = optimizer.param_groups[0]['lr']
                epoch_elapsed_time = current_time - epoch_start_time
                epoch_steps_done = step + 1
                epoch_avg_step_time = epoch_elapsed_time / epoch_steps_done
                epoch_remaining_time = epoch_avg_step_time * (total_steps_in_epoch - epoch_steps_done)
                total_elapsed_time = current_time - overall_start_time
                total_steps_done = epoch * total_steps_in_epoch + epoch_steps_done
                total_avg_step_time = total_elapsed_time / total_steps_done if total_steps_done > 0 else 0
                total_remaining_time = total_avg_step_time * (total_training_steps - total_steps_done) if total_steps_done > 0 else 0

                interval_elapsed_time = current_time - last_log_time
                tokens_processed_interval = args.log_interval * args.batch_size * args.max_seq_len
                tokens_per_sec = tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0
                last_log_time = current_time

                # loss_triple / loss_predicate were never divided by accumulation_steps,
                # so they are reported as-is.
                triple_loss_value = loss_triple.item()
                predicate_loss_value = loss_predicate.item()
                Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
                       f"Loss(triple): {triple_loss_value:.6f}, "
                       f"Loss(predicate): {predicate_loss_value:.6f}, "
                       f"LR: {current_lr:.6f}, "
                       f"Speed: {tokens_per_sec:.2f} tokens/sec | "
                       f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
                       f"Total Time Left: {format_time(total_remaining_time)}", accelerator)

                if args.use_swanlab and accelerator.is_main_process and swanlab_run:
                    log_dict = {
                        "epoch": epoch + 1,
                        "step": step + 1,
                        "total_steps_in_epoch": total_steps_in_epoch,
                        "triple_embedding_cosine_loss": triple_loss_value,
                        "predicate_cross_entropy_loss": predicate_loss_value,
                        "lr": current_lr,
                        "tokens_per_sec": tokens_per_sec,
                        "epoch_time_left_seconds": epoch_remaining_time,
                        "total_time_left_seconds": total_remaining_time
                    }
                    swanlab_run.log(log_dict)

            # Periodic memory monitoring (log_memory_status itself restricts to the main process).
            if args.memory_monitor and args.memory_monitor_interval > 0 \
                    and (step + 1) % args.memory_monitor_interval == 0:
                log_memory_status(step + 1, accelerator, stage="train")

            # === 8. Checkpointing (main process only) ===
            # Undo the accumulation scaling to compare the real step loss.
            loss_total = loss.item() * args.accumulation_steps
            if epoch > 1 and best_loss > loss_total and accelerator.is_main_process:
                best_loss = loss_total
                ckp = f'{args.save_dir}/pretrain_cls{args.dim}{moe_path}.pth'
                unwrapped_model = accelerator.unwrap_model(model)
                accelerator.save(unwrapped_model.state_dict(), ckp)
                Logger(f"Model saved to {ckp}", accelerator)

        except Exception as e:
            Logger(f"Error in training step: {e}", accelerator)
            Logger(traceback.format_exc(), accelerator)
            # Release memory held by the failed step before continuing.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # End-of-epoch memory cleanup.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def main():
parser = argparse.ArgumentParser(description="MiniMind Triple Extraction Training with Accelerate")
parser.add_argument("--out_dir", type=str, default="out")
parser.add_argument("--epochs", type=int, default=4)
parser.add_argument("--embedding_epoch", type=int, default=2, help="embedding训练的epoch数")
parser.add_argument("--batch_size", type=int, default=192)
parser.add_argument("--learning_rate", type=float, default=2e-4)
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_swanlab", default=True, action="store_true") # 替换wandb参数
parser.add_argument("--swanlab_project", type=str, default="MiniMind-TripleExtraction") # 替换wandb参数
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--accumulation_steps", type=int, default=32)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--warmup_iters", type=int, default=0)
parser.add_argument("--log_interval", type=int, default=50)
parser.add_argument("--save_interval", type=int, default=10000)
parser.add_argument('--dim', default=512, type=int)
parser.add_argument('--n_layers', default=8, type=int)
parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能使用固定值1e-4替代")
parser.add_argument("--data_path", type=str, default="/home/rwkv/RWKV-TS/RETRO_TEST/extract/processed_trex_data.json")
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
parser.add_argument("--knowledge_num", type=int, default=960400,help="知识库的数据数目")
parser.add_argument("--knowledge_length", type=int, default=32,help="知识库的句子长度")
parser.add_argument("--database_init_path", type=str, default="./dataset/combined_prepare.json", help="数据库初始化路径")
parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")
parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="聚类结果缓存文件路径")
parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件")
parser.add_argument("--memory_monitor", action="store_true", default=False, help="启用内存监控")
parser.add_argument("--memory_monitor_interval", type=int, default=10, help="内存监控间隔(步数)")
parser.add_argument("--max_targets", type=int, default=5, help="每个样本最大目标句子数量,用于批处理优化")
parser.add_argument("--temperature", type=float, default=1.0, help="Softmax温度参数用于控制预测的平滑度")
parser.add_argument("--detailed_timing", action="store_true", default=True, help="启用详细的时间追踪分析")
# 移除dataset_type参数此训练脚本专用于三元组提取训练
# parser.add_argument("--dataset_type", type=str, default="pretrain", choices=["pretrain", "triple"], help="数据集类型pretrain标准预训练或triple三元组")
args = parser.parse_args()
#########################################################
# 初始化accelerator和deepspeed
#########################################################
# 设置ddp_kwargs以处理未使用的参数
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# 创建DeepSpeedPlugin对象
ds_plugin = DeepSpeedPlugin(
gradient_accumulation_steps=args.accumulation_steps,
gradient_clipping=args.grad_clip,
zero_stage=2, # 使用ZeRO-2优化
offload_optimizer_device="none", # 将优化器状态卸载到CPU
offload_param_device="none", # 不将参数卸载到CPU
)
accelerator = Accelerator(
kwargs_handlers=[ddp_kwargs],
deepspeed_plugin=ds_plugin,
mixed_precision="bf16" if args.dtype == "bfloat16" else "fp16" if args.dtype == "float16" else "no"
)
#########################################################
# 设置随机种子
#########################################################
set_seed(1337 + accelerator.process_index)
#########################################################
# 配置模型
#########################################################
lm_config = LMConfig(
dim=args.dim,
n_layers=args.n_layers,
max_seq_len=args.max_seq_len,
use_moe=args.use_moe,
disable_db=args.disable_db,
flash_attn=args.use_flash_attn,
knowledge_num=args.knowledge_num,
knowledge_length=args.knowledge_length,
embeddings_epoch=args.embedding_epoch
)
#########################################################
# 创建保存目录
#########################################################
args.save_dir = os.path.join(args.out_dir)
if accelerator.is_main_process:
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)
#########################################################
# 设置数据类型
#########################################################
pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
#########################################################
# 配置SwanLab
#########################################################
# 设置SwanLab运行名称
args.swanlab_run_name = f"MiniMind-TripleExtraction-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
# 合并args和lm_config为一个字典无论是否使用SwanLab都需要用于打印配置信息
config_dict = vars(args).copy()
config_dict.update(vars(lm_config))
# 初始化SwanLab实验实例
swanlab_run = None
if args.use_swanlab and accelerator.is_main_process:
# 初始化SwanLab
swanlab_run = swanlab.init(
project=args.swanlab_project,
experiment_name=args.swanlab_run_name,
description="MiniMind三元组提取训练实验使用ROUGE损失优化三元组抽取性能",
config=config_dict
# 设置SwanLab服务器地址和API Key
# host="http://100.123.118.114:11071",
# api_key="LesBT7HRq23HNBrOPKP8S"
)
else:
swanlab_run = None
#########################################################
# 打印信息
#########################################################
# 计算每次迭代的token数量
tokens_per_iter = args.batch_size * lm_config.max_seq_len
if accelerator.is_main_process:
Logger(f"tokens_per_iter: {tokens_per_iter}", accelerator)
Logger("Configuration:", accelerator)
for key, value in config_dict.items():
Logger(f" {key}: {value}", accelerator)
#########################################################
# 设置自动混合精度上下文
#########################################################
ctx = nullcontext() if accelerator.device.type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)
#########################################################
# 初始化模型和tokenizer
#########################################################
model, tokenizer = init_model(lm_config, args.pretrained_embedding_path, args.database_init_path, args)
# 将accelerator传递给init_model函数中的Logger调用
Logger(f'模型初始化完成', accelerator)
#########################################################
# 处理位置编码张量问题
#########################################################
if hasattr(model, "pos_cis_real"):
Logger(f'检测到pos_cis_real实数张量将其设置为参与分布式训练', accelerator)
# 设置模型的_ddp_params_and_buffers_to_ignore属性
# model._ddp_params_and_buffers_to_ignore = {"pos_cis_real"}
# 兼容旧版本检查是否仍有pos_cis
elif hasattr(model, "pos_cis"):
Logger(f'检测到pos_cis复数张量将其设置为不参与分布式训练', accelerator)
# 设置模型的_ddp_params_and_buffers_to_ignore属性
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
#########################################################
# 创建数据集和数据加载器(专用于三元组提取训练)
#########################################################
Logger("三元组提取训练:使用 TriplePretrainDataset", accelerator)
train_ds = TriplePretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
# 创建自定义collate_fn来处理优化后的数据格式
def triple_collate_fn(batch):
    """Merge a list of TriplePretrainDataset samples into one batch dict.

    Every tensor field is combined with ``torch.stack`` along a new leading
    batch dimension. All samples must therefore carry same-shaped tensors
    for each field. NOTE(review): presumably the dataset pads to a fixed
    length; confirm this.

    The per-sample ``target_sentence`` values are not stacked. They are
    collected into a plain list under ``target_sentences`` for debugging.

    Args:
        batch: list of per-sample dicts yielded by the dataset.

    Returns:
        dict with keys, in this order: ``input_ids``, ``labels``,
        ``loss_mask``, ``target_input_ids``, ``target_attention_mask``,
        ``target_sentences`` (list of the samples' ``target_sentence``
        values), and ``predicate_label``.
    """
    def _stack(field):
        # Gather one field across the batch and stack it into a [B, ...] tensor.
        return torch.stack([sample[field] for sample in batch])

    collated = {}
    # Insertion order here fixes the key order of the returned dict.
    for field in ('input_ids', 'labels', 'loss_mask',
                  'target_input_ids', 'target_attention_mask'):
        collated[field] = _stack(field)
    # Raw targets are kept as a Python list rather than a tensor (debugging aid).
    collated['target_sentences'] = [sample['target_sentence'] for sample in batch]
    collated['predicate_label'] = _stack('predicate_label')
    return collated
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=False, # ✅ 实验禁用pin_memory避免内存固定问题
drop_last=True, # 修复:避免边界条件导致的死锁
shuffle=True,
num_workers=0, # ✅ 实验禁用多进程避免worker死锁
# persistent_workers 和 prefetch_factor 在 num_workers=0 时自动禁用
collate_fn=triple_collate_fn
)
#########################################################
# 创建优化器
#########################################################
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
#########################################################
# 创建学习率调度器
#########################################################
total_steps = len(train_loader) * args.epochs
warmup_steps = args.warmup_iters if args.warmup_iters > 0 else int(0.1 * total_steps)
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_steps
)
#########################################################
# 准备训练
#########################################################
model, optimizer, train_loader, scheduler = accelerator.prepare(
model, optimizer, train_loader, scheduler
)
#########################################################
# 训练循环
#########################################################
overall_start_time = time.time() # Record overall start time
for epoch in range(args.epochs):
Logger(f"开始第{epoch+1}轮训练", accelerator)
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer) # Pass tokenizer
# 每个epoch结束后进行内存清理
Logger(f"{epoch+1}轮训练完成,进行内存清理", accelerator)
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# 记录epoch结束时的内存状态
if accelerator.is_main_process:
memory_info = get_memory_usage()
cuda_info = get_cuda_memory_usage()
log_msg = f"[Memory Monitor] Epoch {epoch+1} completed - "
log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB"
if cuda_info:
log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB"
Logger(log_msg, accelerator)
#########################################################
# 关闭SwanLab
#########################################################
if args.use_swanlab and accelerator.is_main_process and swanlab_run:
swanlab_run.finish()
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()