import os
# Environment setup - wandb has been replaced with SwanLab
# os.environ["SWANLAB_MODE"] = "online"  # SwanLab online mode
import platform
import argparse
from tqdm import tqdm
import time
import math
import warnings
import pandas as pd
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from contextlib import nullcontext
from typing import Optional
import datetime  # Add datetime for time formatting
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.utils import DeepSpeedPlugin
from accelerate.utils import DistributedDataParallelKwargs
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import swanlab  # replaces the wandb import
import gc  # garbage collection
import psutil  # system resource monitoring

from model.model_extra import MiniMindLM, RMSNorm  # use model_extra
from model.LMConfig import LMConfig
from model.dataset import TriplePretrainDataset  # only the triple dataset is needed

warnings.filterwarnings('ignore')

# Embedding-based cosine similarity loss
def compute_embedding_cosine_loss(subject_logits, predicate_logits, object_logits,
                                  target_triples, tokenizer, tok_embeddings,
                                  pooling_method='mean', max_targets=5, temperature=1.0):
    """
    Embedding-based cosine similarity loss.
    Args:
        subject_logits: [batch_size, max_subject_len, vocab_size]
        predicate_logits: [batch_size, max_predicate_len, vocab_size]
        object_logits: [batch_size, max_object_len, vocab_size]
        target_triples: List[List[str]] - multiple target sentences per sample
        tokenizer: tokenizer
        tok_embeddings: the model's token embedding layer
        pooling_method: pooling method for sentence embeddings ('mean', 'max', 'cls')
        max_targets: int - maximum number of target sentences per sample
        temperature: float - softmax temperature controlling prediction smoothness
    Returns:
        torch.Tensor: cosine similarity loss
    """
    if not target_triples or len(target_triples) == 0:
        # Create a loss tied to the input tensor so it stays in the computation graph
        dummy_loss = subject_logits.sum() * 0.0 + 1.0  # keeps the gradient path alive
        return dummy_loss

    batch_size = subject_logits.shape[0]

    # 1. Embeddings of the predictions
    pred_embeddings = get_prediction_embeddings(
        subject_logits, predicate_logits, object_logits,
        tok_embeddings, pooling_method, temperature
    )  # [batch_size, embed_dim]

    # 2. Embeddings of the targets
    target_embeddings = get_target_embeddings(
        target_triples, tokenizer, tok_embeddings, pooling_method, max_targets
    )  # [batch_size, max_targets, embed_dim]

    # 3. Cosine similarity
    similarities = compute_cosine_similarity_batch(pred_embeddings, target_embeddings)
    # [batch_size, max_targets]

    # 4. Pick the highest similarity (lowest loss) target per sample
    best_similarities = torch.max(similarities, dim=-1)[0]  # [batch_size]

    # 5. Convert to a loss (1 - cosine_similarity)
    loss = 1.0 - best_similarities.mean()

    # Keep the loss in a sensible range (without breaking the computation graph)
    loss = torch.clamp(loss, min=0.0, max=2.0)

    return loss

def get_prediction_embeddings(subject_logits, predicate_logits, object_logits,
                              tok_embeddings, pooling_method='mean', temperature=1.0):
    """
    Build sentence embeddings from the predicted logits (soft embeddings keep the gradient).
    """
    batch_size = subject_logits.shape[0]

    # Use softmax probability distributions instead of argmax
    subject_probs = torch.softmax(subject_logits / temperature, dim=-1)      # [batch_size, max_subject_len, vocab_size]
    predicate_probs = torch.softmax(predicate_logits / temperature, dim=-1)  # [batch_size, max_predicate_len, vocab_size]
    object_probs = torch.softmax(object_logits / temperature, dim=-1)        # [batch_size, max_object_len, vocab_size]

    # Weighted sum of the embedding matrix by the probability distributions -> soft embeddings
    # tok_embeddings.weight: [vocab_size, embed_dim]
    subject_embeddings = torch.matmul(subject_probs, tok_embeddings.weight)      # [batch_size, max_subject_len, embed_dim]
    predicate_embeddings = torch.matmul(predicate_probs, tok_embeddings.weight)  # [batch_size, max_predicate_len, embed_dim]
    object_embeddings = torch.matmul(object_probs, tok_embeddings.weight)        # [batch_size, max_object_len, embed_dim]

    # Concatenate the embeddings of all three parts
    combined_embeddings = torch.cat([subject_embeddings, predicate_embeddings, object_embeddings], dim=1)
    # [batch_size, total_len, embed_dim]

    # Pool into a sentence embedding
    if pooling_method == 'mean':
        # Simple mean pooling
        sentence_embeddings = combined_embeddings.mean(dim=1)
    elif pooling_method == 'max':
        sentence_embeddings = combined_embeddings.max(dim=1)[0]
    elif pooling_method == 'cls':
        # Use the first token as the sentence representation
        sentence_embeddings = combined_embeddings[:, 0, :]
    else:
        sentence_embeddings = combined_embeddings.mean(dim=1)

    return sentence_embeddings  # [batch_size, embed_dim]

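# A minimal illustration of the soft-embedding trick above (a sketch added for clarity; it is
# never called and the numbers are made up). With a low temperature, the probability-weighted
# embedding probs @ W approaches the hard argmax lookup while staying differentiable with
# respect to the logits.
def _soft_embedding_sketch():
    vocab_size, embed_dim = 10, 4
    torch.manual_seed(0)
    W = torch.randn(vocab_size, embed_dim)  # stands in for tok_embeddings.weight
    logits = torch.tensor([[[4.0, 1.0, 0.0, -1.0, 0.5, 0.2, -0.3, 2.0, 1.5, -2.0]]])  # [1, 1, vocab_size]
    soft = torch.softmax(logits / 0.01, dim=-1) @ W   # nearly one-hot at low temperature
    hard = W[logits.argmax(dim=-1)]                   # hard lookup, no gradient path to logits
    return torch.allclose(soft, hard, atol=1e-4)      # True for this example
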
def get_target_embeddings(target_triples, tokenizer, tok_embeddings, pooling_method='mean', max_targets=5):
    """
    Batch-compute embeddings for the target sentences.
    Args:
        target_triples: List[List[str]] - list of target sentences per sample
        max_targets: int - maximum number of targets per sample; shorter lists are padded with
            empty strings, longer lists are truncated to the first max_targets entries
    """
    batch_size = len(target_triples)

    if not target_triples:
        # No target sentences: return zero embeddings tied to the embedding layer
        # (keeps the computation graph connected)
        embed_dim = tok_embeddings.embedding_dim
        zero_embeddings = tok_embeddings.weight[:1, :].expand(batch_size, max_targets, embed_dim) * 0.0
        return zero_embeddings

    # Normalize the number of targets per sample to max_targets
    normalized_targets = []
    for targets in target_triples:
        if len(targets) >= max_targets:
            # Too many: keep the first max_targets
            normalized_targets.extend(targets[:max_targets])
        else:
            # Too few: pad with empty strings
            normalized_targets.extend(targets)
            normalized_targets.extend([''] * (max_targets - len(targets)))

    # normalized_targets now has length batch_size * max_targets
    assert len(normalized_targets) == batch_size * max_targets

    # Tokenize all target sentences in one batch
    tokenized = tokenizer(
        normalized_targets,
        padding=True,
        truncation=True,
        return_tensors='pt',
        max_length=128  # tunable
    )

    # Move to the right device
    input_ids = tokenized['input_ids'].to(tok_embeddings.weight.device)
    attention_mask = tokenized['attention_mask'].to(tok_embeddings.weight.device)

    # Token embeddings
    token_embeddings = tok_embeddings(input_ids)  # [batch_size * max_targets, seq_len, embed_dim]

    # Apply the attention mask and pool
    if pooling_method == 'mean':
        # Mask-weighted mean
        masked_embeddings = token_embeddings * attention_mask.unsqueeze(-1)
        sentence_embeddings = masked_embeddings.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True).clamp(min=1e-8)
    elif pooling_method == 'max':
        # Max over the valid tokens only
        masked_embeddings = token_embeddings.masked_fill(
            ~attention_mask.unsqueeze(-1).bool(), float('-inf')
        )
        sentence_embeddings = masked_embeddings.max(dim=1)[0]
    else:
        sentence_embeddings = token_embeddings.mean(dim=1)

    # Reshape to [batch_size, max_targets, embed_dim]
    embed_dim = sentence_embeddings.shape[-1]
    target_embeddings = sentence_embeddings.view(batch_size, max_targets, embed_dim)

    return target_embeddings

def compute_cosine_similarity_batch(pred_embeddings, target_embeddings):
    """
    Batch cosine similarity.
    Args:
        pred_embeddings: [batch_size, embed_dim]
        target_embeddings: [batch_size, max_targets, embed_dim]
    Returns:
        similarities: [batch_size, max_targets]
    """
    # Normalize
    pred_norm = torch.nn.functional.normalize(pred_embeddings, p=2, dim=-1)      # [batch_size, embed_dim]
    target_norm = torch.nn.functional.normalize(target_embeddings, p=2, dim=-1)  # [batch_size, max_targets, embed_dim]

    # Cosine similarity
    # pred_norm (unsqueezed): [batch_size, 1, embed_dim]
    # target_norm:            [batch_size, max_targets, embed_dim]
    similarities = torch.sum(pred_norm.unsqueeze(1) * target_norm, dim=-1)
    # [batch_size, max_targets]

    return similarities

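# Quick shape sanity check for the loss pipeline above (a sketch for illustration only; the
# sizes are made up and the function is never called). Only compute_cosine_similarity_batch
# is exercised because that step needs neither a tokenizer nor a model.
def _cosine_similarity_shape_check():
    batch_size, max_targets, embed_dim = 2, 5, 8
    pred = torch.randn(batch_size, embed_dim)
    targets = torch.randn(batch_size, max_targets, embed_dim)
    sims = compute_cosine_similarity_batch(pred, targets)  # [2, 5]
    # Steps 4 and 5 of compute_embedding_cosine_loss: best target per sample, then 1 - mean
    loss = 1.0 - sims.max(dim=-1)[0].mean()
    return sims.shape, loss
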
def triple_to_sentence(subject_logits, predicate_logits, object_logits, tokenizer):
    """
    Decode triple logits into sentences.
    Args:
        subject_logits: [batch_size, seq_len, max_subject_len, vocab_size]
        predicate_logits: [batch_size, seq_len, max_predicate_len, vocab_size]
        object_logits: [batch_size, seq_len, max_object_len, vocab_size]
        tokenizer: tokenizer
    Returns:
        List[List[str]]: the triple sentence for each position of each sample
    """
    batch_size = subject_logits.shape[0]
    predicate_seq_len = predicate_logits.shape[1]
    subject_seq_len = subject_logits.shape[1]
    object_seq_len = object_logits.shape[1]

    predicate_logits = predicate_logits.reshape(batch_size * predicate_seq_len, -1)
    subject_logits = subject_logits.reshape(batch_size * subject_seq_len, -1)
    object_logits = object_logits.reshape(batch_size * object_seq_len, -1)

    predicate_logits = torch.argmax(predicate_logits, dim=-1)
    subject_logits = torch.argmax(subject_logits, dim=-1)
    object_logits = torch.argmax(object_logits, dim=-1)

    predicate_logits = predicate_logits.reshape(batch_size, predicate_seq_len)
    subject_logits = subject_logits.reshape(batch_size, subject_seq_len)
    object_logits = object_logits.reshape(batch_size, object_seq_len)

    combined_logits = torch.cat([subject_logits, predicate_logits, object_logits], dim=1)

    sentences = tokenizer.batch_decode(combined_logits, skip_special_tokens=True)

    # sentences = []
    # for batch_idx in range(batch_size):
    #     batch_sentences = []
    #     for seq_idx in range(seq_len):
    #         # Predicted token ids
    #         subject_ids = torch.argmax(subject_logits[batch_idx, seq_idx], dim=-1)
    #         predicate_ids = torch.argmax(predicate_logits[batch_idx, seq_idx], dim=-1)
    #         object_ids = torch.argmax(object_logits[batch_idx, seq_idx], dim=-1)
    #
    #         # Decode to text
    #         subject_text = tokenizer.decode(subject_ids, skip_special_tokens=True).strip()
    #         predicate_text = tokenizer.decode(predicate_ids, skip_special_tokens=True).strip()
    #         object_text = tokenizer.decode(object_ids, skip_special_tokens=True).strip()
    #
    #         # Join into a sentence (subject + predicate + object)
    #         if subject_text and predicate_text and object_text:
    #             sentence = f"{subject_text} {predicate_text} {object_text}"
    #         else:
    #             sentence = ""
    #
    #         batch_sentences.append(sentence)
    #     sentences.append(batch_sentences)

    return sentences

def compute_triple_rouge_loss_optimized(subject_logits, predicate_logits, object_logits,
                                        target_input_ids, target_attention_mask, tok_embeddings, temperature=1.0):
    """
    Optimized embedding cosine similarity loss for triples (single-target version).
    Args:
        subject_logits: [batch_size, max_subject_len, vocab_size]
        predicate_logits: [batch_size, max_predicate_len, vocab_size]
        object_logits: [batch_size, max_object_len, vocab_size]
        target_input_ids: [batch_size, target_seq_len] - pre-tokenized target sentences
        target_attention_mask: [batch_size, target_seq_len] - attention mask of the targets
        tok_embeddings: the model's token embedding layer
        temperature: float - softmax temperature controlling prediction smoothness
    Returns:
        torch.Tensor: embedding cosine similarity loss (scalar)
    """
    batch_size = subject_logits.shape[0]

    # Fix: make sure the target data is on the right device
    device = tok_embeddings.weight.device
    target_input_ids = target_input_ids.to(device)
    target_attention_mask = target_attention_mask.to(device)

    # 1. Embeddings of the predictions (soft embeddings keep the gradient)
    subject_probs = torch.softmax(subject_logits / temperature, dim=-1)
    predicate_probs = torch.softmax(predicate_logits / temperature, dim=-1)
    object_probs = torch.softmax(object_logits / temperature, dim=-1)

    # Weighted sum of the embedding matrix by the probability distributions
    subject_embeddings = torch.matmul(subject_probs, tok_embeddings.weight)
    predicate_embeddings = torch.matmul(predicate_probs, tok_embeddings.weight)
    object_embeddings = torch.matmul(object_probs, tok_embeddings.weight)

    # Concatenate all parts and mean-pool
    combined_embeddings = torch.cat([subject_embeddings, predicate_embeddings, object_embeddings], dim=1)
    pred_embeddings = combined_embeddings.mean(dim=1)  # [batch_size, embed_dim]

    # 2. Embeddings of the targets (directly from the pre-tokenized data)
    target_embeddings = tok_embeddings(target_input_ids)  # [batch_size, target_seq_len, embed_dim]

    # Mask-weighted mean pooling
    masked_embeddings = target_embeddings * target_attention_mask.unsqueeze(-1)
    target_pooled = masked_embeddings.sum(dim=1) / target_attention_mask.sum(dim=1, keepdim=True).clamp(min=1e-8)
    # [batch_size, embed_dim]

    # 3. Cosine similarity
    pred_norm = torch.nn.functional.normalize(pred_embeddings, p=2, dim=-1)
    target_norm = torch.nn.functional.normalize(target_pooled, p=2, dim=-1)
    similarities = torch.sum(pred_norm * target_norm, dim=-1)  # [batch_size]

    # 4. Convert to a loss (1 - cosine_similarity)
    loss = 1.0 - similarities.mean()

    # Keep the loss in a sensible range
    loss = torch.clamp(loss, min=0.0, max=2.0)

    return loss

def compute_triple_rouge_loss(subject_logits, predicate_logits, object_logits, target_triples, tokenizer, tok_embeddings, max_targets=5, temperature=1.0):
    """
    Original version of the triple loss (kept for compatibility).
    Args:
        subject_logits: [batch_size, max_subject_len, vocab_size]
        predicate_logits: [batch_size, max_predicate_len, vocab_size]
        object_logits: [batch_size, max_object_len, vocab_size]
        target_triples: List[List[str]] - multiple ground-truth triple sentences per sample
        tokenizer: tokenizer
        tok_embeddings: the model's token embedding layer
        max_targets: int - maximum number of target sentences per sample
        temperature: float - softmax temperature controlling prediction smoothness
    Returns:
        torch.Tensor: embedding cosine similarity loss (scalar)
    """
    return compute_embedding_cosine_loss(
        subject_logits, predicate_logits, object_logits,
        target_triples, tokenizer, tok_embeddings, pooling_method='mean', max_targets=max_targets, temperature=temperature
    )

# Memory monitoring helpers
def get_memory_usage():
    """Current memory usage of this process."""
    process = psutil.Process()
    memory_info = process.memory_info()
    return {
        'rss_mb': memory_info.rss / 1024 / 1024,  # resident set size (MB)
        'vms_mb': memory_info.vms / 1024 / 1024,  # virtual memory size (MB)
    }

def get_cuda_memory_usage():
    """Current CUDA memory usage."""
    if torch.cuda.is_available():
        return {
            'cuda_allocated_mb': torch.cuda.memory_allocated() / 1024 / 1024,
            'cuda_reserved_mb': torch.cuda.memory_reserved() / 1024 / 1024,
            'cuda_max_allocated_mb': torch.cuda.max_memory_allocated() / 1024 / 1024,
        }
    return {}

def get_tensor_memory_size(tensor_list):
    """Total memory footprint of a list of tensors (MB)."""
    total_size = 0
    for batch in tensor_list:
        if isinstance(batch, (list, tuple)):
            for tensor in batch:
                if isinstance(tensor, torch.Tensor):
                    total_size += tensor.numel() * tensor.element_size()
        elif isinstance(batch, torch.Tensor):
            total_size += batch.numel() * batch.element_size()
    return total_size / 1024 / 1024  # convert to MB

def log_memory_status(step, accelerator, stage="", detailed=False):
    """Log the current memory status."""
    if not accelerator.is_main_process:
        return

    memory_info = get_memory_usage()
    cuda_info = get_cuda_memory_usage()

    log_msg = f"[Memory Monitor] Step {step} {stage} - "
    log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB"

    if cuda_info:
        log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
        log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB"

    if detailed:
        log_msg += f", System VMS: {memory_info['vms_mb']:.2f}MB"
        if cuda_info:
            log_msg += f", CUDA max allocated: {cuda_info['cuda_max_allocated_mb']:.2f}MB"

    Logger(log_msg, accelerator)

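# Example usage (for reference only; log_memory_status appears unused elsewhere in this script):
#   log_memory_status(step, accelerator, stage="after forward", detailed=True)
# would print something like
#   [Memory Monitor] Step 100 after forward - System RSS: 1234.56MB, CUDA allocated: 789.01MB, ...
# where the numbers shown here are purely illustrative.
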
# Validation
def validate_model(model, val_loader, accelerator, ctx, args):
    """
    Evaluate the model on the validation set.
    Args:
        model: the model
        val_loader: validation DataLoader
        accelerator: the Accelerator object
        ctx: autocast context manager
        args: parsed arguments
    Returns:
        dict: average loss and accuracy
    """
    model.eval()

    total_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    num_batches = 0

    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for batch_data in val_loader:
            try:
                # Prepare the data
                X = batch_data['input_ids'].to(accelerator.device)
                Y = batch_data['labels']

                # Forward pass
                with ctx:
                    res = model(X, step=0)  # step is fixed to 0 during validation
                    loss = criterion(res.predicate_class.cpu(), Y.cpu())

                # Accuracy
                predicted_classes = torch.argmax(res.predicate_class, dim=1)
                predicted_classes = predicted_classes.to(Y.device)
                correct_predictions += (predicted_classes == Y).sum().item()
                total_predictions += Y.size(0)

                # Accumulate the loss
                total_loss += loss.item()
                num_batches += 1

            except Exception as e:
                Logger(f"Error during validation: {e}", accelerator)
                continue

    # Averages
    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0

    model.train()  # switch back to training mode

    return {
        'avg_loss': avg_loss,
        'accuracy': accuracy,
        'total_samples': total_predictions
    }

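# The dict returned above feeds the SwanLab logging block in train_epoch, e.g. a result of the
# form {'avg_loss': 1.234, 'accuracy': 0.56, 'total_samples': 2560} (illustrative values only).
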
# Logging helper
def Logger(msg, accelerator=None):
    # If no accelerator is given, print unconditionally; otherwise only on the main process
    if accelerator is None or accelerator.is_main_process:
        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}")

# Helper function to format seconds into HH:MM:SS
def format_time(seconds):
    return str(datetime.timedelta(seconds=int(seconds)))

# Learning rate helper
def get_lr(it, num_iters, learning_rate):
    # Cosine learning rate decay
    return learning_rate * 0.5 * (1.0 + math.cos(math.pi * it / num_iters))

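# Quick illustration of the schedule above (a sketch, never called): the decay starts at the
# full learning rate and reaches zero at the last iteration. Note that main() below actually
# builds its scheduler with transformers' get_cosine_schedule_with_warmup; get_lr is a
# standalone helper kept for reference.
def _cosine_lr_example(learning_rate=2e-4, num_iters=1000):
    first = get_lr(0, num_iters, learning_rate)                 # == learning_rate
    middle = get_lr(num_iters // 2, num_iters, learning_rate)   # == learning_rate / 2
    last = get_lr(num_iters, num_iters, learning_rate)          # == 0.0 (up to float rounding)
    return first, middle, last
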
# Model initialization
def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None):
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model = MiniMindLM(lm_config, mode="triple")  # triple-extraction mode

    # Load pretrained weights
    pretrained_path = "./out/Experiment_1_2_2_pretrain_512.pth"
    Logger(f"Loading pretrained weights from {pretrained_path}")

    try:
        # Load the pretrained state_dict
        pretrained_state_dict = torch.load(pretrained_path, map_location='cpu')
        Logger(f"Successfully loaded pretrained state_dict with {len(pretrained_state_dict)} parameters")

        # Current model state_dict
        model_state_dict = model.state_dict()

        # Bookkeeping
        loaded_params = []
        skipped_params = []

        # Copy every compatible weight
        for name, param in pretrained_state_dict.items():
            if name in model_state_dict:
                if model_state_dict[name].shape == param.shape:
                    model_state_dict[name].copy_(param)
                    loaded_params.append(name)
                else:
                    Logger(f"Warning: Shape mismatch for {name}, expected {model_state_dict[name].shape}, got {param.shape}")
                    skipped_params.append(f"{name} (shape mismatch)")
            else:
                skipped_params.append(f"{name} (not found in model)")

        Logger(f"Loaded {len(loaded_params)} parameters from pretrained weights")
        Logger(f"Skipped {len(skipped_params)} parameters")

        # Show a few of the key loaded parameters
        key_loaded = [name for name in loaded_params if any(key in name for key in ['tok_embeddings', 'layers.0', 'knowledge_dataset', 'output', 'norm'])]
        if key_loaded:
            Logger("Key loaded parameters:")
            for name in key_loaded[:5]:  # show at most 5
                Logger(f"  ✅ {name}")
            if len(key_loaded) > 5:
                Logger(f"  ... and {len(key_loaded) - 5} more")

        # Show skipped parameters (mostly the triple_extraction_head ones)
        triple_skipped = [name for name in skipped_params if 'triple_extraction_head' in name]
        if triple_skipped:
            Logger("Triple extraction head parameters (newly initialized):")
            for name in triple_skipped[:3]:  # show at most 3
                Logger(f"  🆕 {name}")
            if len(triple_skipped) > 3:
                Logger(f"  ... and {len(triple_skipped) - 3} more")

    except Exception as e:
        Logger(f"Error loading pretrained weights: {e}")
        Logger("Falling back to default initialization...")

        # Default model initialization (fallback)
        Logger("Performing default model initialization...")

        # Embedding layer weights
        nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02)

        # Output layer weights (only if they are not tied to the embeddings)
        if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight:
            nn.init.normal_(model.output.weight, mean=0.0, std=0.02)

        # All linear layers
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                # Xavier/Glorot initialization
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                # Normal initialization for embedding layers
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, RMSNorm):
                # RMSNorm weights start at 1
                if hasattr(module, 'weight'):
                    nn.init.ones_(module.weight)

        # Positional-encoding-related parameters
        if hasattr(model.knowledge_dataset, 'keys'):
            nn.init.normal_(model.knowledge_dataset.keys, mean=0.0, std=0.02)

        Logger("Default model initialization completed")

    # If a pretrained embedding file is provided, load it
    if pretrained_embedding_path:
        Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
        pretrained_embeddings = torch.load(pretrained_embedding_path)
        model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
        model.output.weight.data.copy_(pretrained_embeddings)  # tied weights

    Logger("Database embeddings and sentences stored in model")

    Logger(f'Total trainable LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
    return model, tokenizer

def train_epoch(epoch, accelerator, model, train_loader, val_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer):
    # Triple-extraction training mode: the classic next-token cross-entropy loss is not used here
    epoch_start_time = time.time()
    total_steps_in_epoch = len(train_loader)
    total_training_steps = args.epochs * total_steps_in_epoch
    moe_path = '_moe' if args.use_moe else ''
    best_loss = 10000.0

    # CUDA event variables - only GPU compute time is tracked
    forward_start = forward_end = loss_start = loss_end = backward_start = backward_end = optimizer_start = optimizer_end = None

    # CUDA events for GPU profiling (main process only)
    if args.profile and accelerator.is_main_process:
        forward_start = torch.cuda.Event(enable_timing=True)
        forward_end = torch.cuda.Event(enable_timing=True)
        loss_start = torch.cuda.Event(enable_timing=True)
        loss_end = torch.cuda.Event(enable_timing=True)
        backward_start = torch.cuda.Event(enable_timing=True)
        backward_end = torch.cuda.Event(enable_timing=True)
        optimizer_start = torch.cuda.Event(enable_timing=True)
        optimizer_end = torch.cuda.Event(enable_timing=True)

    # The custom prefetching mechanism was removed; rely on the DataLoader's built-in prefetch
    # Log the initial memory state
    if args.memory_monitor:
        memory_info = get_memory_usage()
        cuda_info = get_cuda_memory_usage()
        log_msg = f"[Memory Monitor] Training start - System RSS: {memory_info['rss_mb']:.2f}MB"
        if cuda_info:
            log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
        Logger(log_msg, accelerator)

    # Variables needed for logging, initialized before the loop starts
    last_log_time = epoch_start_time

    # Use the DataLoader's built-in iterator (custom prefetching removed)
    for step, batch_data in enumerate(train_loader):
        # === start of a step ===

        try:
            # === 1. Data preparation ===
            # Use the batch exactly as provided by the DataLoader
            if not isinstance(batch_data, dict):
                raise ValueError("Expected a dict-formatted batch; make sure TriplePretrainDataset is being used")

            X = batch_data['input_ids']
            Y = batch_data['labels']
            loss_mask = batch_data['loss_mask']
            # target_input_ids = batch_data['target_input_ids']
            # target_attention_mask = batch_data['target_attention_mask']
            # target_sentences = batch_data['target_sentences']  # for debug output

            # === 2. Learning rate update ===
            if scheduler is not None:
                scheduler.step()

            # === 3. Forward pass ===
            # Time the GPU forward pass
            if args.profile and accelerator.is_main_process and forward_start is not None:
                forward_start.record()

            # Forward pass
            with ctx:
                if step == 0 and args.embedding_epoch == epoch:
                    unwrapped_model = accelerator.unwrap_model(model)
                    unwrapped_model.freeze_embedding = True
                    Logger(f"Set freeze_embedding=True for epoch {epoch}, step {step}", accelerator)
                res = model(X, step=step)

            # End of the GPU forward-pass timing
            if args.profile and accelerator.is_main_process and forward_end is not None:
                forward_end.record()

            # === 4. Loss computation ===
            # Triple-extraction mode: only the triple loss is used
            # Logger("Triple extraction training mode", accelerator) if step == 0 else None

            # # Make sure the model produced triple outputs
            # if not (hasattr(res, 'predicate_logits') and hasattr(res, 'subject_logits') and hasattr(res, 'object_logits')):
            #     raise ValueError("The model did not output triple logits; check the model configuration")

            # # Make sure target data is present
            # if target_input_ids is None:
            #     raise ValueError("No triple target data; check the data format")

            # Classification loss
            try:
                Logger("Using classification cross-entropy loss", accelerator) if step == 0 else None

                # Time the GPU loss computation
                if args.profile and accelerator.is_main_process and loss_start is not None:
                    loss_start.record()

                # Cross-entropy loss
                criterion = nn.CrossEntropyLoss()
                loss = criterion(res.predicate_class, Y)

                # End of the GPU loss-computation timing
                if args.profile and accelerator.is_main_process and loss_end is not None:
                    loss_end.record()

            except Exception as e:
                Logger(f"Error: classification loss computation failed: {e}", accelerator)
                import traceback
                Logger(f"Traceback: {traceback.format_exc()}", accelerator)
                loss = res.logits.sum() * 0.0 + 1.0

            loss = loss / args.accumulation_steps

            # === 5. Backward pass ===
            # Time the GPU backward pass
            if args.profile and accelerator.is_main_process and backward_start is not None:
                backward_start.record()

            # Backward pass
            accelerator.backward(loss)

            # End of the GPU backward-pass timing
            if args.profile and accelerator.is_main_process and backward_end is not None:
                backward_end.record()

            # === 6. Optimizer step ===
            # Time the GPU optimizer step
            if args.profile and accelerator.is_main_process and optimizer_start is not None:
                optimizer_start.record()

            # Optimizer step
            optimizer.step()
            optimizer.zero_grad()

            # End of the GPU optimizer-step timing
            if args.profile and accelerator.is_main_process and optimizer_end is not None:
                optimizer_end.record()

            # === 7. Logging ===
            # Print training info (main process only)
            if (step + 1) % args.log_interval == 0 and accelerator.is_main_process:
                current_time = time.time()

                # GPU performance metrics
                if args.profile and accelerator.is_main_process:
                    torch.cuda.synchronize()

                    # Elapsed GPU times
                    try:
                        forward_time = forward_start.elapsed_time(forward_end) if forward_start is not None and forward_end is not None else 0
                        loss_time = loss_start.elapsed_time(loss_end) if loss_start is not None and loss_end is not None else 0
                        backward_time = backward_start.elapsed_time(backward_end) if backward_start is not None and backward_end is not None else 0
                        optimizer_time = optimizer_start.elapsed_time(optimizer_end) if optimizer_start is not None and optimizer_end is not None else 0
                        iter_time = (current_time - last_log_time) * 1000 / args.log_interval  # avg ms per iteration since last log

                        # Print the GPU profiling summary
                        if (step + 1) % (args.log_interval * args.profile_interval) == 0:
                            # Total GPU time
                            gpu_time_total = (forward_time + loss_time + backward_time + optimizer_time) / args.log_interval

                            Logger("=== GPU profiling (average per step) ===", accelerator)
                            Logger(f"Forward: {forward_time/args.log_interval:.2f}ms, "
                                   f"Loss: {loss_time/args.log_interval:.2f}ms, "
                                   f"Backward: {backward_time/args.log_interval:.2f}ms, "
                                   f"Optimizer: {optimizer_time/args.log_interval:.2f}ms", accelerator)
                            Logger(f"Total GPU time: {gpu_time_total:.2f}ms, "
                                   f"Actual iteration time: {iter_time:.2f}ms, "
                                   f"GPU utilization: {gpu_time_total/iter_time*100:.1f}%", accelerator)
                            Logger("=" * 50, accelerator)

                            # Logger("=== Triple prediction examples ===", accelerator)
                            # predict_sentences = triple_to_sentence(res.subject_logits, res.predicate_logits, res.object_logits, tokenizer)
                            # # Show the target sentences of the first 2 samples
                            # for i, target_sentence in enumerate(target_sentences[:2]):
                            #     Logger(f"Sample {i+1} target: {target_sentence}", accelerator)
                            #     Logger(f"Sample {i+1} prediction: {predict_sentences[i]}", accelerator)
                            Logger("=======val dataset=========", accelerator)

                        # Reset the GPU events
                        forward_start = torch.cuda.Event(enable_timing=True)
                        forward_end = torch.cuda.Event(enable_timing=True)
                        loss_start = torch.cuda.Event(enable_timing=True)
                        loss_end = torch.cuda.Event(enable_timing=True)
                        backward_start = torch.cuda.Event(enable_timing=True)
                        backward_end = torch.cuda.Event(enable_timing=True)
                        optimizer_start = torch.cuda.Event(enable_timing=True)
                        optimizer_end = torch.cuda.Event(enable_timing=True)
                    except RuntimeError as e:
                        if "Both events must be recorded" in str(e):
                            Logger(f"Warning: CUDA events not properly recorded, skipping performance analysis: {e}", accelerator)
                        else:
                            raise e

                # Basic metrics
                current_lr = optimizer.param_groups[0]['lr']
                epoch_elapsed_time = current_time - epoch_start_time
                epoch_steps_done = step + 1
                epoch_avg_step_time = epoch_elapsed_time / epoch_steps_done
                epoch_remaining_time = epoch_avg_step_time * (total_steps_in_epoch - epoch_steps_done)

                total_elapsed_time = current_time - overall_start_time
                total_steps_done = epoch * total_steps_in_epoch + epoch_steps_done
                total_avg_step_time = total_elapsed_time / total_steps_done if total_steps_done > 0 else 0
                total_remaining_time = total_avg_step_time * (total_training_steps - total_steps_done) if total_steps_done > 0 else 0

                # Training throughput
                interval_elapsed_time = current_time - last_log_time
                tokens_processed_interval = args.log_interval * args.batch_size * args.max_seq_len
                tokens_per_sec = tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0
                last_log_time = current_time

                # Basic training info
                Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
                       f"Loss: {loss.item() * args.accumulation_steps:.6f}, "
                       f"LR: {current_lr:.6f}, "
                       f"Speed: {tokens_per_sec:.2f} tokens/sec | "
                       f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
                       f"Total Time Left: {format_time(total_remaining_time)}", accelerator)

                # SwanLab logging
                if args.use_swanlab and accelerator.is_main_process and swanlab_run:
                    Logger("=======val dataset=========", accelerator)

                    # Validation-set evaluation
                    val_results = validate_model(model, val_loader, accelerator, ctx, args)
                    Logger(f"Validation results - avg loss: {val_results['avg_loss']:.6f}, accuracy: {val_results['accuracy']:.4f}, samples: {val_results['total_samples']}", accelerator)

                    log_dict = {
                        "epoch": epoch + 1,
                        "step": step + 1,
                        "total_steps_in_epoch": total_steps_in_epoch,
                        "train_loss": loss.item() * args.accumulation_steps,
                        "val_loss": val_results['avg_loss'],
                        "val_accuracy": val_results['accuracy'],
                        "val_samples": val_results['total_samples'],
                        "lr": current_lr,
                        "tokens_per_sec": tokens_per_sec,
                        "epoch_time_left_seconds": epoch_remaining_time,
                        "total_time_left_seconds": total_remaining_time
                    }
                    swanlab_run.log(log_dict)

            # === 8. Model saving ===
            # Save the model (main process only)
            loss_total = loss.item() * args.accumulation_steps
            if epoch > 1 and best_loss > loss_total and accelerator.is_main_process:
                best_loss = loss_total
                ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
                unwrapped_model = accelerator.unwrap_model(model)
                accelerator.save(unwrapped_model.state_dict(), ckp)
                Logger(f"Model saved to {ckp}", accelerator)

        except Exception as e:
            Logger(f"Error in training step: {e}", accelerator)
            import traceback
            Logger(traceback.format_exc(), accelerator)

            # Clean up memory to avoid leaks
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Clean up memory at the end of the training epoch
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def main():
    parser = argparse.ArgumentParser(description="MiniMind Triple Extraction Training with Accelerate")
    parser.add_argument("--out_dir", type=str, default="out")
    parser.add_argument("--epochs", type=int, default=4)
    parser.add_argument("--embedding_epoch", type=int, default=2, help="number of epochs for embedding training")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_swanlab", default=True, action="store_true")  # replaces the wandb flag
    parser.add_argument("--swanlab_project", type=str, default="MiniMind-TripleExtraction")  # replaces the wandb project flag
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--accumulation_steps", type=int, default=32)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--warmup_iters", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=100)
    parser.add_argument("--save_interval", type=int, default=10000)
    parser.add_argument('--dim', default=512, type=int)
    parser.add_argument('--n_layers', default=8, type=int)
    parser.add_argument('--max_seq_len', default=512, type=int)
    parser.add_argument('--use_moe', default=False, type=bool)
    parser.add_argument('--disable_db', action='store_true', help="disable the database feature and use a fixed value of 1e-4 instead")
    parser.add_argument("--data_path", type=str, default="./dataset/processed_trex_data.json")
    parser.add_argument("--predicate_vocab_path", type=str, default="./dataset/predicate_stats.json", help="Path to predicate vocabulary/statistics file")
    parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
    parser.add_argument("--profile", action="store_true", default=True, help="enable performance profiling")
    parser.add_argument("--profile_interval", type=int, default=10, help="profiling print interval (in steps)")
    parser.add_argument("--use_flash_attn", action="store_true", default=True, help="enable FlashAttention")
    parser.add_argument("--knowledge_num", type=int, default=960400, help="number of entries in the knowledge base")
    parser.add_argument("--knowledge_length", type=int, default=32, help="sentence length of the knowledge base")
    parser.add_argument("--database_init_path", type=str, default="./dataset/combined_prepare.json", help="database initialization path")
    parser.add_argument("--fast_clustering", action="store_true", default=True, help="use a fast approximate clustering algorithm (for large datasets)")
    parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="path to the clustering result cache file")
    parser.add_argument("--recompute_clusters", action="store_true", default=False, help="force recomputation of clusters, ignoring the cache file")
    parser.add_argument("--memory_monitor", action="store_true", default=False, help="enable memory monitoring")
    parser.add_argument("--memory_monitor_interval", type=int, default=10, help="memory monitoring interval (in steps)")
    parser.add_argument("--max_targets", type=int, default=5, help="maximum number of target sentences per sample, used for batching optimization")
    parser.add_argument("--temperature", type=float, default=1.0, help="softmax temperature controlling prediction smoothness")
    parser.add_argument("--detailed_timing", action="store_true", default=True, help="enable detailed timing analysis")
    # The dataset_type argument was removed; this training script is dedicated to triple extraction
    # parser.add_argument("--dataset_type", type=str, default="pretrain", choices=["pretrain", "triple"], help="dataset type: pretrain (standard pretraining) or triple (triples)")
    args = parser.parse_args()

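    # Example launch (a sketch only: the script filename below is an assumption, and an
    # accelerate/DeepSpeed config must already be set up for your machine; every flag maps to
    # an argparse option defined above):
    #
    #   accelerate launch train_triple_extraction.py \
    #       --epochs 4 --batch_size 256 --learning_rate 2e-4 \
    #       --data_path ./dataset/processed_trex_data.json \
    #       --predicate_vocab_path ./dataset/predicate_stats.json
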
    #########################################################
    # Initialize accelerator and DeepSpeed
    #########################################################
    # ddp_kwargs so that unused parameters are tolerated
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    # DeepSpeed plugin
    ds_plugin = DeepSpeedPlugin(
        gradient_accumulation_steps=args.accumulation_steps,
        gradient_clipping=args.grad_clip,
        zero_stage=2,                      # ZeRO-2 optimization
        offload_optimizer_device="none",   # do not offload optimizer state to CPU
        offload_param_device="none",       # do not offload parameters to CPU
    )
    accelerator = Accelerator(
        kwargs_handlers=[ddp_kwargs],
        deepspeed_plugin=ds_plugin,
        mixed_precision="bf16" if args.dtype == "bfloat16" else "fp16" if args.dtype == "float16" else "no"
    )

    #########################################################
    # Set the random seed
    #########################################################
    set_seed(1337 + accelerator.process_index)

    #########################################################
    # Model configuration
    #########################################################
    lm_config = LMConfig(
        dim=args.dim,
        n_layers=args.n_layers,
        max_seq_len=args.max_seq_len,
        use_moe=args.use_moe,
        disable_db=args.disable_db,
        flash_attn=args.use_flash_attn,
        knowledge_num=args.knowledge_num,
        knowledge_length=args.knowledge_length,
        embeddings_epoch=args.embedding_epoch
    )

    #########################################################
    # Create the output directory
    #########################################################
    args.save_dir = os.path.join(args.out_dir)
    if accelerator.is_main_process:
        os.makedirs(args.save_dir, exist_ok=True)
        os.makedirs(args.out_dir, exist_ok=True)

    #########################################################
    # Data type
    #########################################################
    pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]

    #########################################################
    # SwanLab configuration
    #########################################################
    # SwanLab run name
    args.swanlab_run_name = f"MiniMind-TripleExtraction-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"

    # Merge args and lm_config into one dict (needed for printing the configuration even when
    # SwanLab is not used)
    config_dict = vars(args).copy()
    config_dict.update(vars(lm_config))

    # Initialize the SwanLab experiment
    swanlab_run = None
    if args.use_swanlab and accelerator.is_main_process:
        # Initialize SwanLab
        swanlab_run = swanlab.init(
            project=args.swanlab_project,
            experiment_name=args.swanlab_run_name,
            description="MiniMind triple extraction training experiment, using a ROUGE-style loss to optimize triple extraction performance",
            config=config_dict
            # SwanLab server address and API key
            # host="http://100.123.118.114:11071",
            # api_key="LesBT7HRq23HNBrOPKP8S"
        )
    else:
        swanlab_run = None

    #########################################################
    # Print configuration
    #########################################################
    # Tokens per iteration
    tokens_per_iter = args.batch_size * lm_config.max_seq_len
    if accelerator.is_main_process:
        Logger(f"tokens_per_iter: {tokens_per_iter}", accelerator)
        Logger("Configuration:", accelerator)
        for key, value in config_dict.items():
            Logger(f"  {key}: {value}", accelerator)

    #########################################################
    # Automatic mixed precision context
    #########################################################
    ctx = nullcontext() if accelerator.device.type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)

    #########################################################
    # Initialize model and tokenizer
    #########################################################
    model, tokenizer = init_model(lm_config, args.pretrained_embedding_path, args.database_init_path, args)
    # Pass accelerator on to the Logger calls from here on
    Logger('Model initialization complete', accelerator)

    #########################################################
    # Handle the positional-encoding tensors
    #########################################################
    if hasattr(model, "pos_cis_real"):
        Logger('Detected real-valued pos_cis_real tensor; it participates in distributed training', accelerator)
        # Set the model's _ddp_params_and_buffers_to_ignore attribute
        # model._ddp_params_and_buffers_to_ignore = {"pos_cis_real"}
    # Backward compatibility: check whether a complex pos_cis still exists
    elif hasattr(model, "pos_cis"):
        Logger('Detected complex pos_cis tensor; excluding it from distributed training', accelerator)
        # Set the model's _ddp_params_and_buffers_to_ignore attribute
        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}

    #########################################################
    # Datasets and data loaders (dedicated to triple extraction training)
    #########################################################
    Logger("Triple extraction training: using TriplePretrainDataset", accelerator)
    train_ds = TriplePretrainDataset(data_path=args.data_path, predicate_vocab_path=args.predicate_vocab_path, tokenizer=tokenizer, max_length=lm_config.max_seq_len)
    val_ds = TriplePretrainDataset(data_path=args.data_path, samples=train_ds.get_val_samples(), predicate_vocab_path=args.predicate_vocab_path, tokenizer=tokenizer, max_length=lm_config.max_seq_len)

    # Custom collate_fn for the optimized data format
    def triple_collate_fn(batch):
        # batch is a list of dicts
        input_ids = torch.stack([item['input_ids'] for item in batch])
        labels = torch.stack([item['labels'] for item in batch])
        loss_mask = torch.stack([item['loss_mask'] for item in batch])
        # target_input_ids = torch.stack([item['target_input_ids'] for item in batch])
        # target_attention_mask = torch.stack([item['target_attention_mask'] for item in batch])
        # target_sentences = [item['target_sentence'] for item in batch]  # for debugging

        return {
            'input_ids': input_ids,
            'labels': labels,
            'loss_mask': loss_mask,
            # 'target_input_ids': target_input_ids,
            # 'target_attention_mask': target_attention_mask,
            # 'target_sentences': target_sentences
        }

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=False,   # experiment: disable pin_memory to avoid pinned-memory issues
        drop_last=True,     # fix: avoid boundary conditions that caused deadlocks
        shuffle=True,
        num_workers=0,      # experiment: disable worker processes to avoid worker deadlocks
        # persistent_workers and prefetch_factor are automatically disabled when num_workers=0
        collate_fn=triple_collate_fn
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        pin_memory=False,
        drop_last=True,
        shuffle=False,
        num_workers=0,
        collate_fn=triple_collate_fn
    )

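    # For reference, a batch produced by triple_collate_fn is a dict of stacked tensors, e.g.
    # {'input_ids': [batch_size, max_seq_len], 'labels': [...], 'loss_mask': [...]}; the exact
    # per-sample shapes come from TriplePretrainDataset, so treat these as assumptions.
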
    #########################################################
    # Optimizer
    #########################################################
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    #########################################################
    # Learning rate scheduler
    #########################################################
    total_steps = len(train_loader) * args.epochs
    warmup_steps = args.warmup_iters if args.warmup_iters > 0 else int(0.1 * total_steps)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    #########################################################
    # Prepare for training
    #########################################################
    model, optimizer, train_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, scheduler
    )

    #########################################################
    # Training loop
    #########################################################
    overall_start_time = time.time()  # Record overall start time
    for epoch in range(args.epochs):
        Logger(f"Starting epoch {epoch+1}", accelerator)
        train_epoch(epoch, accelerator, model, train_loader, val_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer)  # Pass tokenizer

        # Memory cleanup after each epoch
        Logger(f"Epoch {epoch+1} finished, running memory cleanup", accelerator)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Log the memory state at the end of the epoch
        if accelerator.is_main_process:
            memory_info = get_memory_usage()
            cuda_info = get_cuda_memory_usage()
            log_msg = f"[Memory Monitor] Epoch {epoch+1} completed - "
            log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB"
            if cuda_info:
                log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
                log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB"
            Logger(log_msg, accelerator)

    #########################################################
    # Shut down SwanLab
    #########################################################
    if args.use_swanlab and accelerator.is_main_process and swanlab_run:
        swanlab_run.finish()

if __name__ == "__main__":
    main()