import os

# 🔥 Force-unbuffered output so log lines are written immediately.
os.environ['PYTHONUNBUFFERED'] = '1'      # interpreter does not buffer stdout/stderr
os.environ['PYTHONIOENCODING'] = 'utf-8'  # keep a consistent text encoding
# os.environ["SWANLAB_MODE"] = "online"   # SwanLab online mode (wandb replacement)

import argparse
import datetime  # for time formatting
import gc        # explicit garbage collection
import json
import math
import platform
import time
import warnings
from contextlib import nullcontext
from typing import Optional

import numpy as np
import pandas as pd
import psutil    # system resource monitoring
import swanlab   # experiment tracking (replaces wandb)
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.utils import DeepSpeedPlugin
from accelerate.utils import DistributedDataParallelKwargs
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup

from model.LMConfig import LMConfig
from model.dataset import PretrainDataset

warnings.filterwarnings('ignore')
# Memory-monitoring helpers
def get_memory_usage():
    """Return the current process memory usage.

    Returns:
        dict: ``'rss_mb'`` (resident set size, MB) and
              ``'vms_mb'`` (virtual memory size, MB).
    """
    mem = psutil.Process().memory_info()
    mb = 1024 * 1024
    return {
        'rss_mb': mem.rss / mb,
        'vms_mb': mem.vms / mb,
    }
def get_cuda_memory_usage():
    """Return CUDA memory statistics in MB, or an empty dict when CUDA is absent."""
    if not torch.cuda.is_available():
        return {}
    mb = 1024 * 1024
    return {
        'cuda_allocated_mb': torch.cuda.memory_allocated() / mb,
        'cuda_reserved_mb': torch.cuda.memory_reserved() / mb,
        'cuda_max_allocated_mb': torch.cuda.max_memory_allocated() / mb,
    }
def get_tensor_memory_size(tensor_list):
    """Total memory footprint (MB) of the tensors in *tensor_list*.

    Each item may be a tensor, or a list/tuple of tensors (one level deep);
    anything that is not a ``torch.Tensor`` is ignored.
    """
    total_bytes = 0
    for item in tensor_list:
        candidates = item if isinstance(item, (list, tuple)) else (item,)
        for t in candidates:
            if isinstance(t, torch.Tensor):
                total_bytes += t.numel() * t.element_size()
    return total_bytes / 1024 / 1024  # bytes -> MB
def log_memory_status(step, prefetch_batches, accelerator, stage="", detailed=False):
    """Log system / CUDA / prefetch-buffer memory usage (main process only).

    Args:
        step: training step number (use -1 for pre-training snapshots).
        prefetch_batches: list of prefetched batches whose footprint is measured.
        accelerator: Accelerator; only its main process logs.
        stage: free-form label appended to the step number.
        detailed: when True, also report VMS and CUDA peak allocation.
    """
    if not accelerator.is_main_process:
        return
    sys_mem = get_memory_usage()
    cuda_mem = get_cuda_memory_usage()
    prefetch_mb = get_tensor_memory_size(prefetch_batches)

    parts = [
        f"[Memory Monitor] Step {step} {stage} - ",
        f"Prefetch batches: {len(prefetch_batches)}, ",
        f"Prefetch memory: {prefetch_mb:.2f}MB, ",
        f"System RSS: {sys_mem['rss_mb']:.2f}MB",
    ]
    if cuda_mem:
        parts.append(f", CUDA allocated: {cuda_mem['cuda_allocated_mb']:.2f}MB")
        parts.append(f", CUDA reserved: {cuda_mem['cuda_reserved_mb']:.2f}MB")
    if detailed:
        parts.append(f", System VMS: {sys_mem['vms_mb']:.2f}MB")
        if cuda_mem:
            parts.append(f", CUDA max allocated: {cuda_mem['cuda_max_allocated_mb']:.2f}MB")
    Logger("".join(parts), accelerator)
# Logging helper
def Logger(msg, accelerator=None):
    """Print *msg* prefixed with a timestamp, on the main process only.

    Args:
        msg: message to print.
        accelerator: optional ``accelerate.Accelerator``; when provided, only
            its main process prints. When ``None``, always print.
    """
    if accelerator is None or accelerator.is_main_process:
        # flush=True already forces the line out immediately; the previous
        # per-call `import sys; sys.stdout.flush()` was redundant.
        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}", flush=True)
# Helper to render a duration in H:MM:SS form.
def format_time(seconds):
    """Format *seconds* (truncated to an int) as ``H:MM:SS`` via ``timedelta``."""
    whole_seconds = int(seconds)
    return str(datetime.timedelta(seconds=whole_seconds))
def create_validation_dataset(val_data_path, tokenizer, max_length, num_samples=200):
    """
    创建验证数据集 (build a validation dataset from a JSONL file).

    Args:
        val_data_path: path to a JSONL file with one ``{"text": ...}`` object per line.
        tokenizer: tokenizer instance passed through to ``PretrainDataset``.
        max_length: maximum sequence length.
        num_samples: cap on the number of validation samples read.
    Returns:
        ``PretrainDataset`` over the sampled texts, or ``None`` when the
        source file does not exist.
    """
    import tempfile

    if not os.path.exists(val_data_path):
        Logger(f"警告:验证数据文件不存在: {val_data_path},跳过验证评估")
        return None

    # Read up to num_samples texts from the JSONL file.
    val_data = []
    with open(val_data_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if i >= num_samples:  # limit the number of validation samples
                break
            line = line.strip()
            if not line:
                continue
            try:
                sample = json.loads(line)
                val_data.append(sample['text'])
            except (json.JSONDecodeError, KeyError):
                # Skip malformed lines and lines without a 'text' field.
                continue

    # Write a temporary validation file. A unique mkstemp path (instead of a
    # fixed /tmp/temp_val.jsonl) avoids collisions between concurrent runs
    # and works on non-POSIX systems.
    fd, temp_val_file = tempfile.mkstemp(suffix='.jsonl', prefix='temp_val_')
    with os.fdopen(fd, 'w', encoding='utf-8') as f:
        for text in val_data:
            f.write(json.dumps({'text': text}) + '\n')

    # Build the validation set through the same dataset class used for training.
    val_dataset = PretrainDataset(temp_val_file, tokenizer, max_length=max_length)
    Logger(f"创建验证数据集成功,包含{len(val_data)}个样本")
    return val_dataset
def validate_model(model, val_loader, loss_fct, ctx, accelerator):
    """
    执行模型验证 — run the model over *val_loader* and return the average
    masked cross-entropy loss.

    Args:
        model: model instance (switched to eval() during validation, back to train() after).
        val_loader: iterable of (X, Y, loss_mask) batches.
        loss_fct: per-token loss (e.g. CrossEntropyLoss(reduction='none')).
        ctx: autocast/null context manager used around the forward pass.
        accelerator: Accelerator instance (currently unused here).
    Returns:
        Average validation loss, or ``inf`` when the loader is empty.
    """
    model.eval()
    loss_sum = 0.0
    batch_count = 0
    with torch.no_grad():
        for X, Y, loss_mask in val_loader:
            with ctx:
                out = model(X)
                flat_logits = out.logits.view(-1, out.logits.size(-1))
                per_token = loss_fct(flat_logits, Y.view(-1)).view(Y.size())
                # Normalize by the number of unmasked tokens.
                batch_loss = (per_token * loss_mask).sum() / loss_mask.sum()
            loss_sum += batch_loss.item()
            batch_count += 1
    model.train()
    if batch_count == 0:
        return float('inf')
    return loss_sum / batch_count
# Learning-rate schedule
def get_lr(it, num_iters, learning_rate):
    """Cosine decay: returns *learning_rate* at it=0, 0 at it=num_iters."""
    progress = it / num_iters
    return 0.5 * learning_rate * (1.0 + math.cos(math.pi * progress))
# ---------------------------------------------------------------------------
# Model initialization. The original function duplicated the entire
# knowledge-database pipeline verbatim for the "model" and "model_no_feed"
# branches (and a near-copy for "model_memory"); the shared logic now lives in
# private helpers and each branch only wires them together.
# ---------------------------------------------------------------------------

def _init_default_weights(model, RMSNorm, init_knowledge_keys=True):
    """Default weight initialization shared by the custom model variants."""
    Logger("Performing default model initialization...")

    # Token embedding.
    nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02)

    # Output projection — only when it does not share the embedding weights.
    if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight:
        nn.init.normal_(model.output.weight, mean=0.0, std=0.02)

    # Per-layer-type initialization.
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)  # Xavier/Glorot for linear layers
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm):
            if hasattr(module, 'weight'):
                nn.init.ones_(module.weight)  # RMSNorm scale starts at 1

    # Knowledge-dataset keys exist only on some variants; guard both levels.
    if (init_knowledge_keys
            and hasattr(model, 'knowledge_dataset')
            and hasattr(model.knowledge_dataset, 'keys')):
        nn.init.normal_(model.knowledge_dataset.keys, mean=0.0, std=0.02)
    Logger("Default model initialization completed")


def _load_pretrained_embeddings(model, pretrained_embedding_path):
    """Load pretrained token embeddings into the (weight-tied) embedding/output layers."""
    Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
    pretrained_embeddings = torch.load(pretrained_embedding_path)
    model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
    model.output.weight.data.copy_(pretrained_embeddings)  # shared weights


def _save_database_mapping(mapping_file_path, database_mapping,
                           knowledge_num, knowledge_length, database_init_path):
    """Persist the database-index -> source-sentence mapping as JSON (best-effort)."""
    try:
        mapping_data = {
            'metadata': {
                'total_entries': len(database_mapping),
                'knowledge_num': knowledge_num,
                'knowledge_length': knowledge_length,
                'source_file': database_init_path,
                'generation_time': time.strftime('%Y-%m-%d %H:%M:%S')
            },
            'mappings': database_mapping
        }
        with open(mapping_file_path, 'w', encoding='utf-8') as f:
            json.dump(mapping_data, f, ensure_ascii=False, indent=2)
        Logger(f"Database mapping saved to {mapping_file_path}")
    except Exception as e:
        Logger(f"Failed to save database mapping: {e}")


def _load_cached_tensor(cache_path, knowledge_num, knowledge_length, recompute):
    """Load a cached (knowledge_num, knowledge_length) token tensor, or None on miss."""
    if recompute or not os.path.exists(cache_path):
        return None
    try:
        Logger(f"Loading cached processed results from {cache_path}")
        cached = torch.load(cache_path)
        cached_num, cached_len = cached.shape
        if cached_len != knowledge_length:
            Logger(f"Cached knowledge_length ({cached_len}) != required knowledge_length ({knowledge_length}), recomputing...")
            return None
        if cached_num < knowledge_num:
            Logger(f"Cached knowledge_num ({cached_num}) < required knowledge_num ({knowledge_num}), recomputing...")
            return None
        # Cache is at least as large as required: truncate and use it.
        cached = cached[:knowledge_num, :]
        Logger(f"Successfully loaded cached data with shape {cached.shape}")
        Logger(f"Truncated from cached shape ({cached_num}, {cached_len}) to required shape ({knowledge_num}, {knowledge_length})")
        Logger("Skipping database initialization - using cached results")
        return cached
    except Exception as e:
        Logger(f"Failed to load cached data: {e}, recomputing...")
        return None


def _build_knowledge_tensor(database_init_path, tokenizer, args):
    """Tokenize the sentence database into a (knowledge_num, knowledge_length) LongTensor.

    Uses/refreshes the cache at ``args.cluster_cache_path`` and writes a
    ``*_mapping.json`` file mapping each database row back to its source
    sentence / uuid / (subject, predicate, object) triple.
    """
    knowledge_num = args.knowledge_num
    knowledge_length = args.knowledge_length

    cache_dir = os.path.dirname(args.cluster_cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    processed_tensor = _load_cached_tensor(
        args.cluster_cache_path, knowledge_num, knowledge_length, args.recompute_clusters)
    if processed_tensor is not None:
        return processed_tensor

    Logger(f"Loading database initialization data from {database_init_path}")
    with open(database_init_path, 'r', encoding='utf-8') as f:
        database_data = json.load(f)

    # Keep each sentence together with its uuid / triple fields.
    sentences_data = [
        {
            'sentence': data['target'][0]['sentence'],
            'uuid': data['target'][0]['uuid'],
            'subject': data['target'][0].get('subject', ''),
            'predicate': data['target'][0].get('predicate', ''),
            'object': data['target'][0].get('object', '')
        }
        for data in database_data
    ]
    Logger(f"Loaded {len(sentences_data)} sentences from database")

    # NOTE(review): importance_score sorting is intentionally disabled —
    # sentences are processed in their original order.
    sorted_sentences = sentences_data
    Logger(f"Loaded {len(sorted_sentences)} sentences (no importance_score sorting applied)")

    Logger("Processing individual sentences...")
    processed_rows = []
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    num_to_process = min(knowledge_num, len(sorted_sentences))
    total_sentences = 0
    truncated_sentences = 0
    database_mapping = []

    for i in range(num_to_process):
        sentence_data = sorted_sentences[i]
        sentence = sentence_data['sentence']
        sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
        original_length = len(sentence_tokens)  # encode once; reuse the length
        total_sentences += 1
        if original_length > knowledge_length:
            truncated_sentences += 1
            sentence_tokens = sentence_tokens[:knowledge_length]
            Logger(f"Sentence {i+1} truncated from {original_length} to {knowledge_length} tokens")
        elif original_length < knowledge_length:
            sentence_tokens.extend([pad_token_id] * (knowledge_length - original_length))
            Logger(f"Sentence {i+1} padded from {original_length} to {knowledge_length} tokens")
        processed_rows.append(sentence_tokens)

        # Record database index -> source data mapping.
        database_mapping.append({
            'database_index': i,
            'uuid': sentence_data['uuid'],
            'sentence': sentence,
            'subject': sentence_data.get('subject', ''),
            'predicate': sentence_data.get('predicate', ''),
            'object': sentence_data.get('object', ''),
            'token_count': len(sentence_tokens),
            'is_truncated': original_length > knowledge_length
        })

        if (i + 1) % 1000 == 0:
            Logger(f"Processed {i+1}/{num_to_process} sentences")

    # Pad the remaining rows with empty (all-pad) entries.
    while len(processed_rows) < knowledge_num:
        processed_rows.append([pad_token_id] * knowledge_length)
        if len(processed_rows) % 1000 == 0:
            Logger(f"Added empty entry {len(processed_rows)}/{knowledge_num}")
    Logger(f"Finished adding empty entries. Total: {len(processed_rows)}/{knowledge_num}")

    processed_tensor = torch.tensor(processed_rows, dtype=torch.long)

    # Truncation statistics.
    truncation_ratio = truncated_sentences / total_sentences if total_sentences > 0 else 0.0
    Logger("截断句子统计:")
    Logger(f"  - 总句子数: {total_sentences}")
    Logger(f"  - 截断句子数: {truncated_sentences}")
    Logger(f"  - 截断句子占比: {truncation_ratio:.4f} ({truncation_ratio*100:.2f}%)")
    Logger("Data processing completed:")
    Logger(f"  - Processed {num_to_process} sentences")
    Logger(f"  - Added {knowledge_num - num_to_process} empty entries")
    Logger(f"  - Final shape: {processed_tensor.shape}")
    Logger(f"  - Expected shape: ({knowledge_num}, {knowledge_length})")

    # Cache the tensor (best-effort) and the mapping file next to it.
    try:
        torch.save(processed_tensor, args.cluster_cache_path)
        Logger(f"Processed results saved to {args.cluster_cache_path}")
    except Exception as e:
        Logger(f"Failed to save processed results: {e}")
    _save_database_mapping(args.cluster_cache_path.replace('.pt', '_mapping.json'),
                           database_mapping, knowledge_num, knowledge_length, database_init_path)
    return processed_tensor


def _apply_knowledge_tensor(model, processed_tensor):
    """Copy the processed token tensor into model.knowledge_dataset.knowledge_dataset."""
    if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
        model.knowledge_dataset.knowledge_dataset.data.copy_(processed_tensor)
        Logger("Successfully initialized model.knowledge_dataset.knowledge_dataset with processed data")
    else:
        Logger("Warning: Could not find model.knowledge_dataset.knowledge_dataset to initialize")
        # Keep the tensor reachable as a module-level fallback.
        globals()['processed_database'] = processed_tensor
    Logger(f"Database embeddings and sentences stored in model")


def _enable_gradient_checkpointing(model):
    """Enable gradient checkpointing: built-in if available, else wrap each layer."""
    if hasattr(model, 'gradient_checkpointing_enable'):
        model.gradient_checkpointing_enable()
        Logger("✅ 梯度检查点已启用 - 预计减少激活显存占用60-80%")
        return
    from torch.utils.checkpoint import checkpoint
    if hasattr(model, 'layers'):
        def make_checkpoint_forward(original_forward):
            # Capture original_forward per layer to avoid late-binding issues.
            def checkpoint_forward(*fargs, **fkwargs):
                return checkpoint(original_forward, *fargs, **fkwargs, use_reentrant=False)
            return checkpoint_forward
        for layer in model.layers:
            layer.forward = make_checkpoint_forward(layer.forward)
        Logger("✅ 手动梯度检查点已启用 - 预计减少激活显存占用60-80%")


def _extract_sentence_fields(item):
    """Pull (sentence, uuid, subject, predicate, object) out of one database record."""
    if isinstance(item, dict):
        if 'target' in item and len(item['target']) > 0:
            tgt = item['target'][0]
            return (tgt.get('sentence', ''), tgt.get('uuid', ''), tgt.get('subject', ''),
                    tgt.get('predicate', ''), tgt.get('object', ''))
        sentence = item.get('sentence', '') or item.get('text', '') or str(item)
        return (sentence, item.get('uuid', ''), item.get('subject', ''),
                item.get('predicate', ''), item.get('object', ''))
    return (str(item), '', '', '', '')


def _init_memory_bank(model, tokenizer, database_init_path, args):
    """Initialize ``model.memory_bank`` from the JSON sentence database (with caching)."""
    knowledge_num = args.knowledge_num
    knowledge_length = args.knowledge_length
    memory_cache_path = args.cluster_cache_path or f"cache/memory_bank_init_{knowledge_num}_{knowledge_length}.pt"
    cache_dir = os.path.dirname(memory_cache_path)
    os.makedirs(cache_dir if cache_dir else '.', exist_ok=True)

    if os.path.exists(memory_cache_path):
        Logger(f"Loading memory_bank initialization from cache: {memory_cache_path}")
        processed_tensor = torch.load(memory_cache_path)
        Logger(f"Loaded memory_bank data with shape: {processed_tensor.shape}")
    else:
        Logger(f"Processing text data from {database_init_path} for memory_bank initialization")
        with open(database_init_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        Logger(f"Loaded {len(data)} sentences from {database_init_path}")

        processed_rows = []
        total_sentences = len(data)
        truncated_sentences = 0
        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        database_mapping = []

        num_to_process = min(len(data), knowledge_num)
        Logger(f"Processing {num_to_process} out of {total_sentences} sentences")

        for idx, item in enumerate(data[:num_to_process]):
            if idx % 1000 == 0:
                Logger(f"Processing sentence {idx+1}/{num_to_process}")
            sentence, uuid, subject, predicate, object_name = _extract_sentence_fields(item)
            try:
                # NOTE(review): max_length=len(sentence) caps the token count at
                # the *character* count of the sentence — confirm this is intended.
                tokens = tokenizer(
                    sentence,
                    add_special_tokens=True,
                    truncation=True,
                    max_length=len(sentence),
                    padding=False,
                    return_tensors="pt"
                )['input_ids'].squeeze().tolist()
                if not isinstance(tokens, list):  # single-token sentences squeeze to a scalar
                    tokens = [tokens]
                if len(tokens) > knowledge_length:
                    tokens = tokens[:knowledge_length]
                    truncated_sentences += 1
                elif len(tokens) < knowledge_length:
                    tokens.extend([pad_token_id] * (knowledge_length - len(tokens)))
                processed_rows.append(tokens)
                database_mapping.append({
                    'database_index': idx,
                    'uuid': uuid,
                    'sentence': sentence,
                    'subject': subject,
                    'predicate': predicate,
                    'object': object_name,
                    'token_count': len(tokens),
                    'is_truncated': len(tokens) > knowledge_length
                })
            except Exception as e:
                Logger(f"Error processing sentence {idx}: {e}")
                # Fall back to an all-pad row, but still record the mapping.
                processed_rows.append([pad_token_id] * knowledge_length)
                database_mapping.append({
                    'database_index': idx,
                    'uuid': uuid,
                    'sentence': sentence,
                    'subject': subject,
                    'predicate': predicate,
                    'object': object_name,
                    'token_count': knowledge_length,
                    'is_truncated': False,
                    'processing_error': str(e)
                })

        # Pad the remaining rows with empty entries.
        while len(processed_rows) < knowledge_num:
            processed_rows.append([pad_token_id] * knowledge_length)
            if len(processed_rows) % 1000 == 0:
                Logger(f"Added empty entry {len(processed_rows)}/{knowledge_num}")

        processed_tensor = torch.tensor(processed_rows, dtype=torch.long)

        truncation_ratio = truncated_sentences / total_sentences if total_sentences > 0 else 0.0
        Logger("截断句子统计:")
        Logger(f"  - 总句子数: {total_sentences}")
        Logger(f"  - 截断句子数: {truncated_sentences}")
        Logger(f"  - 截断句子占比: {truncation_ratio:.4f} ({truncation_ratio*100:.2f}%)")
        Logger("Memory_bank data processing completed:")
        Logger(f"  - Processed {num_to_process} sentences")
        Logger(f"  - Added {knowledge_num - num_to_process} empty entries")
        Logger(f"  - Final shape: {processed_tensor.shape}")
        Logger(f"  - Expected shape: ({knowledge_num}, {knowledge_length})")

        try:
            torch.save(processed_tensor, memory_cache_path)
            Logger(f"Processed results saved to {memory_cache_path}")
        except Exception as e:
            Logger(f"Failed to save processed results: {e}")
        _save_database_mapping(memory_cache_path.replace('.pt', '_mapping.json'),
                               database_mapping, knowledge_num, knowledge_length, database_init_path)

    if hasattr(model, 'memory_bank'):
        model.memory_bank.data.copy_(processed_tensor)
        Logger("Successfully initialized memory_bank with processed text data")
    else:
        Logger("Warning: Could not find memory_bank to initialize")


def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None):
    """Build the requested model variant and its tokenizer.

    Args:
        lm_config: ``LMConfig`` used to construct the model.
        pretrained_embedding_path: optional ``.pt`` tensor of token embeddings
            copied into the (weight-tied) embedding/output layers.
        database_init_path: optional JSON sentence database for initializing
            the knowledge dataset / memory bank.
        args: parsed CLI args; reads ``model_type``, ``knowledge_num``,
            ``knowledge_length``, ``cluster_cache_path``, ``recompute_clusters``.
    Returns:
        (model, tokenizer) tuple.
    """
    Logger(f"Using model type: {args.model_type}")
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')

    if args.model_type == "model":
        from model.model import MiniMindLM, RMSNorm
        model = MiniMindLM(lm_config)
        _init_default_weights(model, RMSNorm)
        if pretrained_embedding_path:
            _load_pretrained_embeddings(model, pretrained_embedding_path)
        if database_init_path:
            _apply_knowledge_tensor(model, _build_knowledge_tensor(database_init_path, tokenizer, args))
    elif args.model_type == "model_original":
        # Baseline variant: no custom weight init or database setup.
        from model.model_original import MiniMindLM, RMSNorm
        model = MiniMindLM(lm_config)
    elif args.model_type == "model_no_feed":
        from model.model_no_feed import MiniMindLM, RMSNorm
        model = MiniMindLM(lm_config)
        _init_default_weights(model, RMSNorm)
        if pretrained_embedding_path:
            _load_pretrained_embeddings(model, pretrained_embedding_path)
        if database_init_path:
            _apply_knowledge_tensor(model, _build_knowledge_tensor(database_init_path, tokenizer, args))
    elif args.model_type == "model_memory":
        from model.model_memory import MiniMindLM, RMSNorm
        model = MiniMindLM(lm_config)
        _init_default_weights(model, RMSNorm, init_knowledge_keys=False)
        # 🔥 实验1.4.10优化: gradient checkpointing to cut activation memory.
        _enable_gradient_checkpointing(model)
        if database_init_path and os.path.exists(database_init_path):
            Logger(f"Initializing memory_bank with text data from {database_init_path}")
            _init_memory_bank(model, tokenizer, database_init_path, args)
        else:
            Logger(f"Memory bank initialized with random values, shape: {model.memory_bank.shape}")
        Logger("Model_memory initialization completed")

    Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
    return model, tokenizer
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx,
                overall_start_time, swanlab_run, tokenizer, val_loader=None):
    """Run one training epoch with prefetching, profiling, validation and checkpointing.

    Args:
        epoch: 0-based index of the current epoch.
        accelerator: ``accelerate.Accelerator`` handling DDP/DeepSpeed, mixed
            precision and distributed logging.
        model: the (possibly wrapped) language model; called as ``model(X, step=step)``.
        train_loader: DataLoader yielding ``(X, Y, loss_mask)`` batches.
            NOTE(review): assumes every batch is exactly this 3-tuple — confirm
            against ``PretrainDataset``.
        optimizer: optimizer prepared by ``accelerator.prepare``.
        scheduler: LR scheduler stepped once per batch (may be ``None``).
        args: parsed CLI namespace (coefficients, intervals, monitoring flags, ...).
        ctx: autocast context manager (or ``nullcontext`` on CPU).
        overall_start_time: wall-clock start of the whole run, for ETA estimates.
        swanlab_run: SwanLab run handle for metric logging (may be ``None``).
        tokenizer: tokenizer used for the periodic text-generation preview.
        val_loader: optional validation DataLoader evaluated every ``args.val_interval`` steps.

    Side effects: updates model weights, logs metrics, and saves a checkpoint to
    ``{args.save_dir}/pretrain_{args.dim}{moe_path}.pth`` whenever the current
    (main-process) loss improves on the best seen so far.
    """
    loss_fct = nn.CrossEntropyLoss(reduction='none')  # per-token loss; masked/averaged manually below
    epoch_start_time = time.time()
    total_steps_in_epoch = len(train_loader)
    total_training_steps = args.epochs * total_steps_in_epoch
    moe_path = '_moe' if args.use_moe else ''
    # Best loss seen so far (checkpoint trigger). 10000.0 is an effectively-infinite
    # sentinel; any realistic first loss is below it.
    best_loss = 10000.0

    # CUDA timing events (created lazily; stay None when profiling is off or on
    # non-main processes, so every later use is guarded with an `is not None` check).
    data_start = data_end = forward_start = forward_end = None
    backward_start = backward_end = optimizer_start = optimizer_end = None
    if args.profile and accelerator.is_main_process:
        data_start = torch.cuda.Event(enable_timing=True)
        data_end = torch.cuda.Event(enable_timing=True)
        forward_start = torch.cuda.Event(enable_timing=True)
        forward_end = torch.cuda.Event(enable_timing=True)
        backward_start = torch.cuda.Event(enable_timing=True)
        backward_end = torch.cuda.Event(enable_timing=True)
        optimizer_start = torch.cuda.Event(enable_timing=True)
        optimizer_end = torch.cuda.Event(enable_timing=True)

    # Manual batch prefetching: keep up to `prefetch_factor` batches queued ahead
    # of consumption to hide data-loading latency.
    prefetch_factor = 8
    data_iter = iter(train_loader)
    prefetch_batches = []

    # Record memory state before any prefetching happens.
    if args.memory_monitor:
        log_memory_status(-1, prefetch_batches, accelerator, "before_prefetch", detailed=True)

    # Fill the initial prefetch queue.
    for i in range(min(prefetch_factor, len(train_loader))):
        try:
            batch = next(data_iter)
            prefetch_batches.append(batch)
            # Track the memory impact of each queued batch.
            if args.memory_monitor and accelerator.is_main_process:
                log_memory_status(-1, prefetch_batches, accelerator, f"after_adding_batch_{i+1}")
        except StopIteration:
            break

    if args.memory_monitor:
        log_memory_status(-1, prefetch_batches, accelerator, "after_initial_prefetch", detailed=True)

    # Timestamp of the last metrics log; used for throughput / iter-time estimates.
    last_log_time = epoch_start_time

    for step in range(total_steps_in_epoch):
        try:
            # ---- data loading (timed on the main process) ----
            if args.profile and accelerator.is_main_process and data_start is not None:
                data_start.record()

            if args.memory_monitor and step % args.memory_monitor_interval == 0:
                log_memory_status(step, prefetch_batches, accelerator, "before_use_batch", detailed=True)

            if prefetch_batches:
                X, Y, loss_mask = prefetch_batches.pop(0)
                if args.memory_monitor and step % args.memory_monitor_interval == 0:
                    log_memory_status(step, prefetch_batches, accelerator, "after_pop_batch")
            else:
                # Queue drained — fall back to loading synchronously.
                X, Y, loss_mask = next(data_iter)
                if args.memory_monitor and accelerator.is_main_process:
                    Logger(f"[Memory Monitor] Step {step} - Prefetch queue empty, loading directly!", accelerator)

            # Top the queue back up for a future step.
            if step + prefetch_factor < len(train_loader):
                try:
                    batch = next(data_iter)
                    prefetch_batches.append(batch)
                    if args.memory_monitor and step % args.memory_monitor_interval == 0:
                        log_memory_status(step, prefetch_batches, accelerator, "after_add_batch")
                except StopIteration:
                    pass

            if args.profile and accelerator.is_main_process and data_end is not None:
                data_end.record()

            # LR schedule is advanced once per batch, before the forward pass.
            if scheduler is not None:
                scheduler.step()

            # ---- forward pass ----
            if args.profile and accelerator.is_main_process and forward_start is not None:
                forward_start.record()

            with ctx:
                if step == 0 and args.embedding_epoch == epoch:
                    # freeze_embedding must be set on the unwrapped model, not the
                    # DDP/DeepSpeed wrapper.
                    unwrapped_model = accelerator.unwrap_model(model)
                    unwrapped_model.freeze_embedding = True
                    Logger(f"Set freeze_embedding=True for epoch {epoch}, step {step}", accelerator)
                res = model(X, step=step)

                # Primary objective: token-level cross entropy, masked and averaged.
                ce_loss = loss_fct(
                    res.logits.view(-1, res.logits.size(-1)),
                    Y.view(-1)
                ).view(Y.size())
                ce_loss = (ce_loss * loss_mask).sum() / loss_mask.sum()

                # Auxiliary losses (four-loss scheme). `res.aux_loss` is either a
                # dict of named components or, for older models, a single scalar
                # treated as the balance loss.
                balance_loss = 0
                similarity_loss = 0
                diversity_loss = 0
                if hasattr(res, 'aux_loss') and res.aux_loss is not None:
                    aux_loss = res.aux_loss
                    if isinstance(aux_loss, dict):
                        balance_loss = aux_loss.get('balance_loss', 0)
                        similarity_loss = aux_loss.get('similarity_loss', 0)
                        diversity_loss = aux_loss.get('diversity_loss', 0)
                    else:
                        # Backward compatibility with the old single aux_loss.
                        balance_loss = aux_loss

                # Cosine-similarity diagnostics for selected memories, if exposed.
                cosine_stats = {}
                avg_selected_similarity = 0.0
                if hasattr(res, 'cosine_stats') and res.cosine_stats is not None:
                    cosine_stats = res.cosine_stats
                    selected_similarities = [
                        v for k, v in cosine_stats.items() if k.endswith('_selected_avg_similarity')
                    ]
                    if selected_similarities:
                        avg_selected_similarity = np.mean(selected_similarities)

                # Total = CE + weighted auxiliary terms; coefficients are CLI-tunable.
                balance_coef = getattr(args, 'balance_loss_coef', 0.01)
                similarity_coef = getattr(args, 'similarity_loss_coef', 0.1)
                diversity_coef = getattr(args, 'diversity_loss_coef', 0.05)
                total_loss = (ce_loss +
                              balance_coef * balance_loss +
                              similarity_coef * similarity_loss +
                              diversity_coef * diversity_loss)
                # Scale down for gradient accumulation.
                loss = total_loss / args.accumulation_steps

            if args.profile and accelerator.is_main_process and forward_end is not None:
                forward_end.record()

            # ---- backward pass ----
            if args.profile and accelerator.is_main_process and backward_start is not None:
                backward_start.record()
            # With DeepSpeed, accumulation and gradient clipping are handled inside
            # accelerator.backward / optimizer.step.
            accelerator.backward(loss)
            if args.profile and accelerator.is_main_process and backward_end is not None:
                backward_end.record()

            # ---- optimizer step ----
            if args.profile and accelerator.is_main_process and optimizer_start is not None:
                optimizer_start.record()
            # DeepSpeed only applies the update when the accumulation boundary is
            # reached, so no explicit `step % accumulation_steps` check is needed.
            optimizer.step()
            # zero_grad is redundant under DeepSpeed (called after step()), but kept
            # explicitly for safety with other backends.
            optimizer.zero_grad()

            # VQ-VAE-style EMA update of the memory bank (only when supported).
            if hasattr(res, 'ema_stats') and res.ema_stats is not None:
                unwrapped_model = accelerator.unwrap_model(model)
                if hasattr(unwrapped_model, 'apply_ema_update'):
                    ema_update_stats = unwrapped_model.apply_ema_update(res.ema_stats)
                    if (step % args.log_interval == 0 and accelerator.is_main_process
                            and ema_update_stats.get('ema_update_applied', False)):
                        total_memories = args.knowledge_num
                        Logger(f"EMA Update - Step: {ema_update_stats['ema_step']}, "
                               f"Updated memories: {ema_update_stats['updated_memories']}/{total_memories} "
                               f"({ema_update_stats['update_ratio']:.4f}), "
                               f"Coverage: {ema_update_stats['selected_memory_coverage']:.4f}", accelerator)

            if args.profile and accelerator.is_main_process and optimizer_end is not None:
                optimizer_end.record()

            # ---- periodic validation / logging (main process only) ----
            if (step + 1) % args.val_interval == 0 and accelerator.is_main_process:
                current_time = time.time()

                if args.memory_monitor:
                    log_memory_status(step, prefetch_batches, accelerator, "at_log_interval", detailed=True)
                    # Force a GC pass and record the delta.
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    gc.collect()
                    log_memory_status(step, prefetch_batches, accelerator, "after_gc", detailed=True)

                if args.profile:
                    torch.cuda.synchronize()
                    try:
                        # elapsed_time raises if an event was never recorded; the
                        # None-guards plus the RuntimeError handler below cover that.
                        data_time = data_start.elapsed_time(data_end) if data_start is not None and data_end is not None else 0
                        forward_time = forward_start.elapsed_time(forward_end) if forward_start is not None and forward_end is not None else 0
                        backward_time = backward_start.elapsed_time(backward_end) if backward_start is not None and backward_end is not None else 0
                        optimizer_time = optimizer_start.elapsed_time(optimizer_end) if optimizer_start is not None and optimizer_end is not None else 0
                        # NOTE(review): divides by log_interval although this branch
                        # fires every val_interval steps — the per-iter averages are
                        # only nominal; confirm intended denominator.
                        iter_time = (current_time - last_log_time) * 1000 / args.log_interval

                        if (step + 1) % (args.log_interval * args.profile_interval) == 0:
                            Logger(f"性能分析 (Avg/iter over last {args.log_interval} steps) - "
                                   f"Data: {data_time/args.log_interval:.2f}ms, "
                                   f"Fwd: {forward_time/args.log_interval:.2f}ms, "
                                   f"Bwd: {backward_time/args.log_interval:.2f}ms, "
                                   f"Optim: {optimizer_time/args.log_interval:.2f}ms, "
                                   f"Iter Time: {iter_time:.2f}ms", accelerator)

                        # Text-generation preview: continue a random sample's first
                        # half and compare against the true continuation.
                        try:
                            random_idx = torch.randint(0, X.size(0), (1,)).item()
                            sample_input = X[random_idx:random_idx+1]   # [1, seq_len]
                            sample_target = Y[random_idx:random_idx+1]  # [1, seq_len]

                            # First half as prompt, leaving room for 10 reference tokens.
                            prompt_len = sample_input.size(1) // 2
                            prompt_input = sample_input[:, :prompt_len]
                            # Y is X shifted by one, hence the prompt_len-1 offset.
                            true_next_tokens = sample_target[:, prompt_len-1:prompt_len-1+10]

                            unwrapped_model = accelerator.unwrap_model(model)
                            unwrapped_model.eval()  # switch to eval for generation
                            with torch.no_grad():
                                generated = unwrapped_model.generate(
                                    prompt_input,
                                    max_new_tokens=10,
                                    temperature=0.7,
                                    top_p=0.9,
                                    eos_token_id=tokenizer.eos_token_id,
                                    pad_token_id=tokenizer.pad_token_id
                                )

                            prompt_text = tokenizer.decode(prompt_input[0], skip_special_tokens=True)
                            true_text = tokenizer.decode(true_next_tokens[0], skip_special_tokens=True)
                            prompt_tokens = prompt_input[0].tolist()
                            generated_tokens = generated[0].tolist()
                            if len(generated_tokens) > len(prompt_tokens):
                                # Keep only the first 10 newly generated tokens.
                                new_tokens = generated_tokens[len(prompt_tokens):len(prompt_tokens)+10]
                                generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
                            else:
                                generated_text = "[未生成新token]"

                            Logger("文本生成对比:", accelerator)
                            Logger(f"  输入提示: {prompt_text}", accelerator)
                            Logger(f"  真实续写: {true_text}", accelerator)
                            Logger(f"  模型生成: {generated_text}", accelerator)

                            unwrapped_model.train()  # restore training mode
                        except Exception as e:
                            Logger(f"生成文本示例失败: {e}", accelerator)

                        # Recreate the events so the next window starts from zero.
                        data_start = torch.cuda.Event(enable_timing=True)
                        data_end = torch.cuda.Event(enable_timing=True)
                        forward_start = torch.cuda.Event(enable_timing=True)
                        forward_end = torch.cuda.Event(enable_timing=True)
                        backward_start = torch.cuda.Event(enable_timing=True)
                        backward_end = torch.cuda.Event(enable_timing=True)
                        optimizer_start = torch.cuda.Event(enable_timing=True)
                        optimizer_end = torch.cuda.Event(enable_timing=True)
                    except RuntimeError as e:
                        if "Both events must be recorded" in str(e):
                            Logger(f"Warning: CUDA events not properly recorded, skipping performance analysis: {e}", accelerator)
                        else:
                            raise e

                # Progress / ETA bookkeeping.
                current_lr = optimizer.param_groups[0]['lr']
                epoch_elapsed_time = current_time - epoch_start_time
                epoch_steps_done = step + 1
                epoch_avg_step_time = epoch_elapsed_time / epoch_steps_done
                epoch_remaining_time = epoch_avg_step_time * (total_steps_in_epoch - epoch_steps_done)
                total_elapsed_time = current_time - overall_start_time
                total_steps_done = epoch * total_steps_in_epoch + epoch_steps_done
                total_avg_step_time = total_elapsed_time / total_steps_done if total_steps_done > 0 else 0
                total_remaining_time = total_avg_step_time * (total_training_steps - total_steps_done) if total_steps_done > 0 else 0
                # Throughput over the window since the last log.
                interval_elapsed_time = current_time - last_log_time
                tokens_processed_interval = args.log_interval * args.batch_size * args.max_seq_len
                tokens_per_sec = tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0
                last_log_time = current_time

                # Validation pass (best effort — failures are logged, not fatal).
                val_loss = None
                if val_loader is not None:
                    try:
                        val_loss = validate_model(model, val_loader, loss_fct, ctx, accelerator)
                        Logger(f"验证损失: {val_loss:.4f}", accelerator)
                    except Exception as e:
                        Logger(f"验证评估失败: {e}", accelerator)
                        val_loss = None

                # Memory-bank update statistics, if the model exposes them.
                memory_update_stats = {}
                if hasattr(model, 'get_memory_update_stats'):
                    try:
                        unwrapped_model = accelerator.unwrap_model(model)
                        if hasattr(unwrapped_model, 'get_memory_update_stats'):
                            memory_update_stats = unwrapped_model.get_memory_update_stats()
                    except Exception as e:
                        Logger(f"获取记忆更新统计失败: {e}", accelerator)

                # Per-layer statistics, if available.
                layer_stats = {}
                if hasattr(res, 'layer_stats') and res.layer_stats is not None:
                    layer_stats = res.layer_stats

                # Metric dict for the four-loss scheme; aux losses may be raw
                # numbers (0) or tensors, hence the isinstance checks.
                log_dict = {
                    "epoch": epoch + 1,
                    "step": step + 1,
                    "total_steps_in_epoch": total_steps_in_epoch,
                    "train/loss_ce": ce_loss.item(),
                    "train/loss_balance": balance_loss.item() if isinstance(balance_loss, torch.Tensor) else balance_loss,
                    "train/loss_similarity": similarity_loss.item() if isinstance(similarity_loss, torch.Tensor) else similarity_loss,
                    "train/loss_diversity": diversity_loss.item() if isinstance(diversity_loss, torch.Tensor) else diversity_loss,
                    "train/loss_total": total_loss.item(),
                    "lr": current_lr,
                    "tokens_per_sec": tokens_per_sec,
                    "epoch_time_left_seconds": epoch_remaining_time,
                    "total_time_left_seconds": total_remaining_time
                }
                if val_loss is not None:
                    log_dict["val/loss"] = val_loss
                log_dict.update(memory_update_stats)

                # Aggregate the key per-layer metrics across all layers.
                if layer_stats:
                    avg_gini = np.mean([v for k, v in layer_stats.items() if k.endswith('_gini_coefficient')])
                    avg_coverage = np.mean([v for k, v in layer_stats.items() if k.endswith('_coverage_rate')])
                    total_dead = sum([v for k, v in layer_stats.items() if k.endswith('_dead_memories')])
                    total_hot = sum([v for k, v in layer_stats.items() if k.endswith('_hot_memories')])
                    log_dict.update({
                        'memory/avg_gini_coefficient': avg_gini,
                        'memory/avg_coverage_rate': avg_coverage,
                        'memory/total_dead_memories': total_dead,
                        'memory/total_hot_memories': total_hot,
                        # Average similarity of the memories actually selected.
                        'train/avg_selected_similarity': avg_selected_similarity,
                    })

                # Console summary of the four-loss breakdown.
                Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
                       f"CE: {log_dict['train/loss_ce']:.4f}, "
                       f"Bal: {log_dict['train/loss_balance']:.4f}, "
                       f"Sim: {log_dict['train/loss_similarity']:.4f}, "
                       f"Div: {log_dict['train/loss_diversity']:.4f}, "
                       f"Total: {log_dict['train/loss_total']:.4f}, "
                       f"Val: {log_dict.get('val/loss', 'N/A')}, "
                       f"LR: {log_dict['lr']:.6f}, "
                       f"Speed: {log_dict['tokens_per_sec']:.2f} tokens/sec | "
                       f"Sel.Sim: {avg_selected_similarity:.4f} | "
                       f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
                       f"Total Time Left: {format_time(total_remaining_time)}", accelerator)

                if args.use_swanlab and accelerator.is_main_process and swanlab_run:
                    swanlab_run.log(log_dict)

            # ---- best-loss checkpointing (main process only) ----
            loss_total = loss.item() * args.accumulation_steps  # undo accumulation scaling
            if epoch >= 0 and best_loss > loss_total and accelerator.is_main_process:
                best_loss = loss_total
                ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
                unwrapped_model = accelerator.unwrap_model(model)
                accelerator.save(unwrapped_model.state_dict(), ckp)
                Logger(f"Model saved to {ckp}", accelerator)

        except Exception as e:
            Logger(f"Error in training step: {e}", accelerator)
            if args.memory_monitor:
                log_memory_status(step, prefetch_batches, accelerator, "at_exception", detailed=True)
            import traceback
            Logger(traceback.format_exc(), accelerator)

            # Drop the prefetch queue to avoid leaking batches after a failure.
            if args.memory_monitor and accelerator.is_main_process:
                Logger(f"[Memory Monitor] Clearing prefetch_batches due to exception. Current length: {len(prefetch_batches)}", accelerator)
            prefetch_batches.clear()
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            if args.memory_monitor:
                log_memory_status(step, prefetch_batches, accelerator, "after_exception_cleanup", detailed=True)

    # End-of-epoch cleanup of the prefetch queue.
    if args.memory_monitor and accelerator.is_main_process:
        Logger(f"[Memory Monitor] Epoch {epoch+1} finished. Clearing prefetch_batches. Final length: {len(prefetch_batches)}", accelerator)
        log_memory_status(total_steps_in_epoch - 1, prefetch_batches, accelerator, "before_epoch_end_cleanup", detailed=True)
    prefetch_batches.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if args.memory_monitor:
        log_memory_status(total_steps_in_epoch - 1, prefetch_batches, accelerator, "after_epoch_end_cleanup", detailed=True)
2025-05-14 00:01:40 +08:00
def main ( ) :
parser = argparse . ArgumentParser ( description = " MiniMind Pretraining with Accelerate " )
parser . add_argument ( " --out_dir " , type = str , default = " out " )
2025-06-06 11:25:59 +08:00
parser . add_argument ( " --epochs " , type = int , default = 4 )
2025-06-08 02:20:36 +00:00
parser . add_argument ( " --embedding_epoch " , type = int , default = 2 , help = " embedding训练的epoch数 " )
2025-09-05 14:24:48 +08:00
parser . add_argument ( " --batch_size " , type = int , default = 20 )
2025-05-14 00:01:40 +08:00
parser . add_argument ( " --learning_rate " , type = float , default = 2e-4 )
parser . add_argument ( " --dtype " , type = str , default = " bfloat16 " )
2025-09-05 14:24:48 +08:00
parser . add_argument ( " --use_swanlab " , default = False , action = " store_true " ) # 替换wandb参数
2025-06-25 20:27:28 +08:00
parser . add_argument ( " --swanlab_project " , type = str , default = " MiniMind-Pretrain " ) # 替换wandb参数
parser . add_argument ( " --num_workers " , type = int , default = 1 )
2025-05-14 00:01:40 +08:00
parser . add_argument ( " --accumulation_steps " , type = int , default = 32 )
parser . add_argument ( " --grad_clip " , type = float , default = 1.0 )
parser . add_argument ( " --warmup_iters " , type = int , default = 0 )
2025-07-17 12:06:28 +08:00
parser . add_argument ( " --log_interval " , type = int , default = 1 )
2025-05-14 00:01:40 +08:00
parser . add_argument ( " --save_interval " , type = int , default = 10000 )
2025-06-06 11:25:59 +08:00
parser . add_argument ( ' --dim ' , default = 512 , type = int )
parser . add_argument ( ' --n_layers ' , default = 8 , type = int )
2025-08-01 15:54:21 +08:00
parser . add_argument ( ' --n_heads ' , default = 32 , type = int )
2025-06-06 11:25:59 +08:00
parser . add_argument ( ' --max_seq_len ' , default = 512 , type = int )
2025-05-14 00:01:40 +08:00
parser . add_argument ( ' --use_moe ' , default = False , type = bool )
parser . add_argument ( ' --disable_db ' , action = ' store_true ' , help = " 禁用数据库功能, 使用固定值1e-4替代 " )
2025-09-05 14:24:48 +08:00
parser . add_argument ( " --data_path " , type = str , default = " /home/iomgaa/Code/Minimind/dataset/stable/merged_pretrain.jsonl " )
2025-05-14 00:01:40 +08:00
parser . add_argument ( " --pretrained_embedding_path " , type = str , default = None , help = " Path to pretrained token embedding weights (.pth file) " )
parser . add_argument ( " --profile " , action = " store_true " , default = True , help = " 启用性能分析 " )
parser . add_argument ( " --profile_interval " , type = int , default = 10 , help = " 性能分析打印间隔(步数) " )
parser . add_argument ( " --use_flash_attn " , action = " store_true " , default = True , help = " 启用FlashAttention " )
2025-06-25 20:27:28 +08:00
parser . add_argument ( " --knowledge_num " , type = int , default = 960400 , help = " 知识库的数据数目 " )
2025-08-14 23:04:52 +08:00
parser . add_argument ( " --knowledge_length " , type = int , default = 8 , help = " 知识库的句子长度 " )
2025-08-03 14:25:26 +08:00
parser . add_argument ( " --knowledge_dim " , type = int , default = 128 , help = " 知识库的向量维度 " )
2025-09-05 14:24:48 +08:00
parser . add_argument ( " --database_init_path " , type = str , default = " /home/iomgaa/Code/Minimind/dataset/stable/sentence_trex_data.json " , help = " 数据库初始化路径 " )
2025-05-26 23:09:03 +08:00
parser . add_argument ( " --fast_clustering " , action = " store_true " , default = True , help = " 使用快速近似聚类算法(适用于大数据集) " )
2025-09-06 12:12:08 +08:00
parser . add_argument ( " --cluster_cache_path " , type = str , default = " ./cache/cluster_tokens_single.pt " , help = " 聚类结果缓存文件路径 " )
2025-05-29 20:29:45 +08:00
parser . add_argument ( " --recompute_clusters " , action = " store_true " , default = False , help = " 强制重新计算聚类,忽略缓存文件 " )
2025-06-25 20:27:28 +08:00
parser . add_argument ( " --memory_monitor " , action = " store_true " , default = False , help = " 启用内存监控 " )
parser . add_argument ( " --memory_monitor_interval " , type = int , default = 10 , help = " 内存监控间隔(步数) " )
2025-08-09 10:47:35 +08:00
parser . add_argument ( " --model_type " , type = str , default = " model_memory " , help = " 使用什么模型训练 " ) #model,model_original,model_no_feed
2025-07-13 21:28:46 +08:00
parser . add_argument ( " --model_size " , type = float , default = 50.0 , help = " 模型大小 " )
parser . add_argument ( " --swanlab_online " , type = bool , default = False , help = " 是否使用在线SwanLab服务 " )
2025-08-07 11:43:23 +08:00
parser . add_argument ( " --balance_loss_coef " , type = float , default = 0.01 , help = " 平衡损失系数 " )
2025-09-06 12:12:08 +08:00
parser . add_argument ( " --similarity_loss_coef " , type = float , default = 0.1 , help = " 相似度损失系数( 实验1.4.9) " )
parser . add_argument ( " --diversity_loss_coef " , type = float , default = 0.05 , help = " 多样性损失系数( 实验1.4.9) " )
2025-09-05 14:24:48 +08:00
parser . add_argument ( " --val_data_path " , type = str , default = " /home/zym/Code/stable/eval_data.json " , help = " 验证数据集路径 " )
2025-08-07 11:43:23 +08:00
parser . add_argument ( " --val_interval " , type = int , default = 100 , help = " 验证评估间隔 " )
2025-09-06 12:12:08 +08:00
parser . add_argument ( " --freeze_ratio " , type = float , default = 0.2 , help = " 冻结率 " )
2025-05-14 00:01:40 +08:00
args = parser . parse_args ( )
2025-06-25 20:27:28 +08:00
2025-05-14 00:01:40 +08:00
#########################################################
# 初始化accelerator和deepspeed
#########################################################
# 设置ddp_kwargs以处理未使用的参数
ddp_kwargs = DistributedDataParallelKwargs ( find_unused_parameters = True )
# 创建DeepSpeedPlugin对象
ds_plugin = DeepSpeedPlugin (
gradient_accumulation_steps = args . accumulation_steps ,
gradient_clipping = args . grad_clip ,
zero_stage = 2 , # 使用ZeRO-2优化
2025-06-25 20:27:28 +08:00
offload_optimizer_device = " none " , # 将优化器状态卸载到CPU
2025-05-14 00:01:40 +08:00
offload_param_device = " none " , # 不将参数卸载到CPU
)
accelerator = Accelerator (
kwargs_handlers = [ ddp_kwargs ] ,
deepspeed_plugin = ds_plugin ,
mixed_precision = " bf16 " if args . dtype == " bfloat16 " else " fp16 " if args . dtype == " float16 " else " no "
)
#########################################################
# 设置随机种子
#########################################################
set_seed ( 1337 + accelerator . process_index )
#########################################################
# 配置模型
#########################################################
lm_config = LMConfig (
dim = args . dim ,
n_layers = args . n_layers ,
2025-09-06 17:25:46 +08:00
n_heads = args . n_heads ,
2025-05-14 00:01:40 +08:00
max_seq_len = args . max_seq_len ,
use_moe = args . use_moe ,
disable_db = args . disable_db ,
flash_attn = args . use_flash_attn ,
2025-05-16 08:38:59 +00:00
knowledge_num = args . knowledge_num ,
2025-06-08 02:20:36 +00:00
knowledge_length = args . knowledge_length ,
2025-09-06 17:25:46 +08:00
knowledge_dim = args . knowledge_dim ,
2025-09-06 12:12:08 +08:00
embeddings_epoch = args . embedding_epoch ,
freeze_ratio = args . freeze_ratio
2025-05-14 00:01:40 +08:00
)
#########################################################
# 创建保存目录
#########################################################
args . save_dir = os . path . join ( args . out_dir )
if accelerator . is_main_process :
os . makedirs ( args . save_dir , exist_ok = True )
os . makedirs ( args . out_dir , exist_ok = True )
#########################################################
# 设置数据类型
#########################################################
pt_dtype = { ' float32 ' : torch . float32 , ' bfloat16 ' : torch . bfloat16 , ' float16 ' : torch . float16 } [ args . dtype ]
#########################################################
2025-06-25 20:27:28 +08:00
# 配置SwanLab
2025-05-14 00:01:40 +08:00
#########################################################
2025-06-25 20:27:28 +08:00
# 设置SwanLab运行名称
args . swanlab_run_name = f " MiniMind-Pretrain-Epoch- { args . epochs } -BatchSize- { args . batch_size } -LearningRate- { args . learning_rate } "
# 合并args和lm_config为一个字典( 无论是否使用SwanLab都需要, 用于打印配置信息)
config_dict = vars ( args ) . copy ( )
config_dict . update ( vars ( lm_config ) )
# 初始化SwanLab实验实例
swanlab_run = None
if args . use_swanlab and accelerator . is_main_process :
2025-07-13 21:28:46 +08:00
if args . swanlab_online :
# 使用在线SwanLab服务
# 初始化SwanLab
swanlab_run = swanlab . init (
project = args . swanlab_project ,
experiment_name = args . swanlab_run_name ,
description = " MiniMind预训练实验, 使用本地部署的SwanLab进行可视化 " ,
config = config_dict
)
else :
swanlab_run = swanlab . init (
project = args . swanlab_project ,
experiment_name = args . swanlab_run_name ,
description = " MiniMind预训练实验, 使用本地部署的SwanLab进行可视化 " ,
config = config_dict ,
mode = " offline "
)
2025-05-14 00:01:40 +08:00
else :
2025-06-25 20:27:28 +08:00
swanlab_run = None
2025-05-14 00:01:40 +08:00
#########################################################
# 打印信息
#########################################################
# 计算每次迭代的token数量
tokens_per_iter = args . batch_size * lm_config . max_seq_len
if accelerator . is_main_process :
Logger ( f " tokens_per_iter: { tokens_per_iter } " , accelerator )
Logger ( " Configuration: " , accelerator )
for key , value in config_dict . items ( ) :
Logger ( f " { key } : { value } " , accelerator )
#########################################################
# 设置自动混合精度上下文
#########################################################
ctx = nullcontext ( ) if accelerator . device . type == " cpu " else torch . cuda . amp . autocast ( dtype = pt_dtype )
#########################################################
# 初始化模型和tokenizer
#########################################################
2025-05-26 23:09:03 +08:00
model , tokenizer = init_model ( lm_config , args . pretrained_embedding_path , args . database_init_path , args )
2025-05-14 00:01:40 +08:00
# 将accelerator传递给init_model函数中的Logger调用
Logger ( f ' 模型初始化完成 ' , accelerator )
#########################################################
# 处理位置编码张量问题
#########################################################
if hasattr ( model , " pos_cis_real " ) :
Logger ( f ' 检测到pos_cis_real实数张量, 将其设置为参与分布式训练 ' , accelerator )
# 设置模型的_ddp_params_and_buffers_to_ignore属性
# model._ddp_params_and_buffers_to_ignore = {"pos_cis_real"}
# 兼容旧版本, 检查是否仍有pos_cis
elif hasattr ( model , " pos_cis " ) :
Logger ( f ' 检测到pos_cis复数张量, 将其设置为不参与分布式训练 ' , accelerator )
# 设置模型的_ddp_params_and_buffers_to_ignore属性
model . _ddp_params_and_buffers_to_ignore = { " pos_cis " }
#########################################################
# 创建数据集和数据加载器
#########################################################
train_ds = PretrainDataset ( args . data_path , tokenizer , max_length = lm_config . max_seq_len )
train_loader = DataLoader (
train_ds ,
batch_size = args . batch_size ,
pin_memory = True ,
drop_last = False ,
shuffle = True ,
num_workers = args . num_workers ,
persistent_workers = True if args . num_workers > 0 else False ,
prefetch_factor = 2 if args . num_workers > 0 else None
)
2025-08-07 11:43:23 +08:00
# 创建验证数据集和加载器
val_loader = None
val_ds = create_validation_dataset ( args . val_data_path , tokenizer , lm_config . max_seq_len )
if val_ds is not None :
val_loader = DataLoader (
val_ds ,
batch_size = args . batch_size / / 2 , # 验证时使用较小批次
pin_memory = True ,
drop_last = False ,
shuffle = False ,
num_workers = 0 , # 验证时不使用多进程
)
2025-05-14 00:01:40 +08:00
#########################################################
# Create the optimizer
#########################################################
# When EMA updates are enabled, memory_bank parameters are updated via EMA
# rather than gradients (requires_grad=False), so they must be filtered out
# of the optimizer's parameter list.
if hasattr(model.params, 'use_ema_update') and model.params.use_ema_update:
    optimizer_params = [p for p in model.parameters() if p.requires_grad]
    # Pass accelerator so these lines follow the same main-process-only
    # logging convention as every other Logger call in this section.
    Logger(f"EMA更新模式:优化器包含 {len(optimizer_params)} 个参数(过滤掉memory_bank)", accelerator)
    Logger(f"总参数: {sum(p.numel() for p in model.parameters())} | 可训练参数: {sum(p.numel() for p in optimizer_params)}", accelerator)
    optimizer = optim.AdamW(optimizer_params, lr=args.learning_rate)
else:
    # Traditional mode: every model parameter is updated by gradients.
    Logger("传统梯度更新模式:优化器包含所有模型参数", accelerator)
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
#########################################################
# Learning-rate schedule: linear warmup followed by cosine decay
#########################################################
total_steps = len(train_loader) * args.epochs
# An explicit warmup length wins; otherwise warm up over the first 10%
# of all training steps.
if args.warmup_iters > 0:
    warmup_steps = args.warmup_iters
else:
    warmup_steps = int(0.1 * total_steps)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)
#########################################################
# Hand everything to accelerate (device placement, DDP wrapping, ...)
#########################################################
# Keep the ordering model, optimizer, train_loader, [val_loader,] scheduler
# so the unpacking below matches in both cases.
to_prepare = [model, optimizer, train_loader, scheduler]
if val_loader is not None:
    to_prepare.insert(3, val_loader)
prepared = accelerator.prepare(*to_prepare)
if val_loader is not None:
    model, optimizer, train_loader, val_loader, scheduler = prepared
else:
    model, optimizer, train_loader, scheduler = prepared
#########################################################
# Training loop
#########################################################
overall_start_time = time.time()  # wall-clock start, used for ETA logging
for epoch in range(args.epochs):
    Logger(f"开始第{epoch+1}轮训练", accelerator)

    # tokenizer and val_loader are forwarded for sample logging / validation
    train_epoch(epoch, accelerator, model, train_loader, optimizer,
                scheduler, args, ctx, overall_start_time, swanlab_run,
                tokenizer, val_loader)

    # Reclaim Python- and CUDA-side memory between epochs.
    Logger(f"第{epoch+1}轮训练完成,进行内存清理", accelerator)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Report end-of-epoch memory state from the main process only.
    if accelerator.is_main_process:
        memory_info = get_memory_usage()
        cuda_info = get_cuda_memory_usage()
        parts = [
            f"[Memory Monitor] Epoch {epoch+1} completed - ",
            f"System RSS: {memory_info['rss_mb']:.2f}MB",
        ]
        if cuda_info:
            parts.append(f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB")
            parts.append(f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB")
        Logger("".join(parts), accelerator)
#########################################################
# Shut down the SwanLab run (main process only)
#########################################################
if accelerator.is_main_process:
    if args.use_swanlab and swanlab_run:
        swanlab_run.finish()
if __name__ == "__main__":
    # Script entry point: run training only when executed directly,
    # not when this module is imported.
    main()