Added support for multiple model types
parent 75265f6652
commit 5e464bbd3f
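The commit routes model construction through a new --model_type argument ("model" or "model_original"; see the argparse hunk at the end of this diff). Below is a minimal, illustrative sketch of that dispatch pattern, not the actual script: everything except --model_type, MiniMindLM, model.model and model.model_original is a placeholder, and the real init_model() additionally loads the tokenizer, initializes weights, and fills the knowledge database.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model_type", type=str, default="model")  # "model" or "model_original"
args = parser.parse_args(["--model_type", "model_original"])    # example invocation

# Select the module that provides MiniMindLM based on the requested model type
if args.model_type == "model":
    module_path = "model.model"            # variant with the knowledge-database initialization path
elif args.model_type == "model_original":
    module_path = "model.model_original"   # baseline variant without the knowledge-database path
else:
    raise ValueError(f"Unsupported model_type: {args.model_type}")

print(f"init_model() would import MiniMindLM from {module_path}")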
@@ -25,7 +25,7 @@ import swanlab  # replaces the wandb import
import gc  # garbage-collection module
import psutil  # system resource monitoring module

from model.model import MiniMindLM, RMSNorm
from model.LMConfig import LMConfig
from model.dataset import PretrainDataset
@@ -105,177 +105,187 @@ def get_lr(it, num_iters, learning_rate):
# Model initialization function
def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None):
    if args.model_type == "model":
        Logger(f"Using model type: {args.model_type}")
        from model.model import MiniMindLM, RMSNorm
        tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
        model = MiniMindLM(lm_config)

        # Default model initialization
        Logger("Performing default model initialization...")

        # Initialize the token-embedding weights
        nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02)

        # Initialize the output-projection weights (only when they are not tied to the embeddings)
        if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight:
            nn.init.normal_(model.output.weight, mean=0.0, std=0.02)

        # Initialize all linear layers
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                # Xavier/Glorot initialization
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                # Embedding layers use a normal distribution
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, RMSNorm):
                # RMSNorm weights are initialized to 1
                if hasattr(module, 'weight'):
                    nn.init.ones_(module.weight)

        # Initialize the knowledge-dataset key parameters
        if hasattr(model.knowledge_dataset, 'keys'):
            nn.init.normal_(model.knowledge_dataset.keys, mean=0.0, std=0.02)

        Logger("Default model initialization completed")
        # If pretrained embedding weights were provided, load them
        if pretrained_embedding_path:
            Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
            pretrained_embeddings = torch.load(pretrained_embedding_path)
            model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
            model.output.weight.data.copy_(pretrained_embeddings)  # tied (shared) weights
        if database_init_path:
            import json
            import os

            # Database parameters
            knowledge_num = args.knowledge_num
            knowledge_length = args.knowledge_length

            # Make sure the cache directory exists
            cache_dir = os.path.dirname(args.cluster_cache_path)
            if cache_dir:
                os.makedirs(cache_dir, exist_ok=True)

            processed_tensor = None

            # Try to load cached processing results
            if not args.recompute_clusters and os.path.exists(args.cluster_cache_path):
                try:
                    Logger(f"Loading cached processed results from {args.cluster_cache_path}")
                    processed_tensor = torch.load(args.cluster_cache_path)

                    # Check whether the shape of the cached file is usable
                    cached_knowledge_num, cached_knowledge_length = processed_tensor.shape

                    if cached_knowledge_length == knowledge_length:
                        if cached_knowledge_num >= knowledge_num:
                            # The cache is large enough; truncate it to the required size
                            processed_tensor = processed_tensor[:knowledge_num, :]
                            Logger(f"Successfully loaded cached data with shape {processed_tensor.shape}")
                            Logger(f"Truncated from cached shape ({cached_knowledge_num}, {cached_knowledge_length}) to required shape ({knowledge_num}, {knowledge_length})")
                            Logger("Skipping database initialization - using cached results")
                        else:
                            # The cache is too small; recompute
                            Logger(f"Cached knowledge_num ({cached_knowledge_num}) < required knowledge_num ({knowledge_num}), recomputing...")
                            processed_tensor = None
                    else:
                        # knowledge_length does not match; recompute
                        Logger(f"Cached knowledge_length ({cached_knowledge_length}) != required knowledge_length ({knowledge_length}), recomputing...")
                        processed_tensor = None
                except Exception as e:
                    Logger(f"Failed to load cached data: {e}, recomputing...")
                    processed_tensor = None
            # Only run database initialization and processing when there is no valid cache
            if processed_tensor is None:
                Logger(f"Loading database initialization data from {database_init_path}")

                # 1. Load the JSON file
                with open(database_init_path, 'r', encoding='utf-8') as f:
                    database_data = json.load(f)

                # Extract the list of sentences
                sentences_data = database_data.get('sentences', [])
                Logger(f"Loaded {len(sentences_data)} sentences from database")

                # 2. Sort by importance_score (highest first)
                sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
                Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")

                # 3. Process each entry individually, without clustering
                Logger("Processing individual sentences...")
                processed_rows = []

                # Id of the padding token (used to fill short sentences)
                pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

                # Process only as many sentences as needed
                num_to_process = min(knowledge_num, len(sorted_sentences))

                for i in range(num_to_process):
                    sentence_data = sorted_sentences[i]
                    sentence = sentence_data.get('corrected_sentence', '')

                    # Convert the sentence to tokens
                    sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)

                    # Truncate or pad to knowledge_length
                    if len(sentence_tokens) > knowledge_length:
                        # Too long: truncate
                        sentence_tokens = sentence_tokens[:knowledge_length]
                        Logger(f"Sentence {i+1} truncated from {len(tokenizer.encode(sentence, add_special_tokens=False))} to {knowledge_length} tokens")
                    else:
                        # Too short: pad with the padding token
                        original_length = len(sentence_tokens)
                        sentence_tokens.extend([pad_token_id] * (knowledge_length - len(sentence_tokens)))
                        if original_length < knowledge_length:
                            Logger(f"Sentence {i+1} padded from {original_length} to {knowledge_length} tokens")

                    processed_rows.append(sentence_tokens)

                    if (i + 1) % 1000 == 0:
                        Logger(f"Processed {i + 1}/{num_to_process} sentences")
                # If there are not enough sentences, fill the remaining rows with padding tokens
                while len(processed_rows) < knowledge_num:
                    empty_tokens = [pad_token_id] * knowledge_length
                    processed_rows.append(empty_tokens)
                    if len(processed_rows) % 1000 == 0:
                        Logger(f"Added empty entry {len(processed_rows)}/{knowledge_num}")

                Logger(f"Finished adding empty entries. Total: {len(processed_rows)}/{knowledge_num}")

                # Convert to a tensor
                processed_tensor = torch.tensor(processed_rows, dtype=torch.long)

                Logger(f"Data processing completed:")
                Logger(f" - Processed {num_to_process} sentences")
                Logger(f" - Added {knowledge_num - num_to_process} empty entries")
                Logger(f" - Final shape: {processed_tensor.shape}")
                Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")

                # Save the processed results to the cache file
                try:
                    torch.save(processed_tensor, args.cluster_cache_path)
                    Logger(f"Processed results saved to {args.cluster_cache_path}")
                except Exception as e:
                    Logger(f"Failed to save processed results: {e}")
            # 4. Initialize the model's knowledge_dataset
            if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
                model.knowledge_dataset.knowledge_dataset.data.copy_(processed_tensor)
                Logger("Successfully initialized model.knowledge_dataset.knowledge_dataset with processed data")
            else:
                Logger("Warning: Could not find model.knowledge_dataset.knowledge_dataset to initialize")
                # Keep the tensor in a global variable as a fallback
                globals()['processed_database'] = processed_tensor
Logger(f"Database embeddings and sentences stored in model")
|
||||
|
||||
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
elif args.model_type == "model_original":
|
||||
Logger(f"Using model type: {args.model_type}")
|
||||
from model.model_original import MiniMindLM, RMSNorm
|
||||
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
|
||||
model = MiniMindLM(lm_config)
|
||||
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
|
||||
Logger(f"Database embeddings and sentences stored in model")
|
||||
|
||||
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
return model, tokenizer
|
||||
|
||||
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run):
|
||||
@@ -389,7 +399,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run):
             # Add the auxiliary loss, if it exists
             try:
                 aux_loss = sum(l.feed_forward.aux_loss for l in model.module.layers
-                               if hasattr(l.feed_forward, 'aux_loss'))
+                               if hasattr(l, 'feed_forward') and hasattr(l.feed_forward, 'aux_loss'))
                 loss += aux_loss
             except Exception as e:
                 Logger(f"Warning: Could not add auxiliary loss: {e}")
@@ -586,7 +596,7 @@ def main():
     parser.add_argument('--max_seq_len', default=512, type=int)
     parser.add_argument('--use_moe', default=False, type=bool)
     parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能,使用固定值1e-4替代")
-    parser.add_argument("--data_path", type=str, default="./dataset/merged_pretrain.jsonl")
+    parser.add_argument("--data_path", type=str, default="./dataset/stable/merged_pretrain.jsonl")
     parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
     parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
     parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
@@ -599,6 +609,7 @@ def main():
     parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件")
     parser.add_argument("--memory_monitor", action="store_true", default=False, help="启用内存监控")
     parser.add_argument("--memory_monitor_interval", type=int, default=10, help="内存监控间隔(步数)")
+    parser.add_argument("--model_type", type=str, default="model", help="使用什么模型训练")  # "model" or "model_original"
     args = parser.parse_args()

     #########################################################
Loading…
x
Reference in New Issue
Block a user