Minimind/train_pretrain_accelerate.py
Yu Chengzhang 44fe6259ec Experiment 1.4.7: Memory Bank text initialization + partial freeze mechanism
## Key Improvements
- 🔥 Memory Bank text initialization: uses real text data from sentence_trex_data.json
- 🔥 Partial freeze mechanism: new freeze_ratio=0.2 protects 20% of the important memory entries (see the sketch below)
- 📊 Performance gain: inference loss improved by 5.5% (2.4699 vs 2.6142)
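
A minimal sketch of the partial-freeze idea, assuming the memory bank has `knowledge_num` rows and that frozen entries are chosen at random; the mask name `freeze_mask` follows the change list below, but the actual selection logic lives in `model/model_memory.py`:

```python
import torch

knowledge_num, freeze_ratio = 1_048_576, 0.2

# Mark 20% of the memory entries as frozen (True = protected from EMA updates).
num_frozen = int(knowledge_num * freeze_ratio)  # 209,715 entries at these sizes
freeze_mask = torch.zeros(knowledge_num, dtype=torch.bool)
freeze_mask[torch.randperm(knowledge_num)[:num_frozen]] = True
```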

## Core Changes
### model/LMConfig.py
- Added a freeze_ratio parameter to control Memory Bank entry freezing

### model/model_memory.py
- Implemented the freeze_mask mechanism, randomly freezing 20% of memory entries
- EMA update filtering: only unfrozen entries are updated, protecting important knowledge (see the sketch below)
- Richer statistics: added monitoring of the frozen-entry count and ratio
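
How the freeze mask gates the EMA update, as a hedged sketch: it assumes a float-valued memory representation and a plain exponential blend, whereas the real `apply_ema_update` in `model/model_memory.py` also tracks the update statistics reported during training:

```python
import torch

@torch.no_grad()
def ema_update(memory, new_values, freeze_mask, decay=0.9):
    """Blend new values into the memory, skipping frozen entries."""
    blended = decay * memory + (1.0 - decay) * new_values
    memory[~freeze_mask] = blended[~freeze_mask]  # frozen entries keep their old values
    return int((~freeze_mask).sum())  # number of entries actually updated
```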

### train_pretrain_accelerate.py
- Full model_memory initialization support: text-data processing and caching
- Tokenization and length handling for sentence_trex_data.json
- memory_bank_init cache optimization, speeding up repeated experiments (reduced to a sketch below)
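
The initialization path in the training script boils down to a tokenize-then-cache loop; a reduced sketch (the helper name `load_or_build_memory_init` is illustrative, and the full version in the script below also logs truncation statistics):

```python
import json
import os
import torch

def load_or_build_memory_init(cache_path, json_path, tokenizer, knowledge_num, knowledge_length):
    if os.path.exists(cache_path):
        return torch.load(cache_path)  # reuse across repeated experiments
    with open(json_path, 'r', encoding='utf-8') as f:
        sentences = json.load(f)
    pad = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    rows = []
    for s in sentences[:knowledge_num]:
        ids = tokenizer.encode(str(s), add_special_tokens=True)[:knowledge_length]
        rows.append(ids + [pad] * (knowledge_length - len(ids)))
    while len(rows) < knowledge_num:  # pad out missing entries
        rows.append([pad] * knowledge_length)
    tensor = torch.tensor(rows, dtype=torch.long)
    torch.save(tensor, cache_path)
    return tensor
```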

### Experiment Documentation
- experiment/EXPERIMENT_1_4_7.md: complete experiment record and result analysis
- run_file/experiment_1_4_7.sh: experiment launch script
- CLAUDE.md: architecture-design guard rules and model-version management conventions

## Experiment Results
- Text-initialization effect verified: loss improved by 5.5%
- Freeze mechanism implemented: 209,715 of 1,048,576 entries successfully frozen
- Generation coherence still needs work: architecture-level issues remain

## Next Steps
- Fix EOS token control
- Optimize cross-attention weights
- Tune generation parameters (temperature/top_p); a sweep sketch follows
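
For the planned sampling-parameter tuning, a hedged example of a sweep over the two knobs; the `generate` call mirrors the text-generation preview already in the training script, and `prompt_ids` is assumed to be a tokenized prompt:

```python
# Candidate settings to sweep; the current preview uses temperature=0.7, top_p=0.9.
for temperature, top_p in [(0.5, 0.85), (0.7, 0.9), (0.9, 0.95)]:
    output_ids = model.generate(
        prompt_ids,
        max_new_tokens=32,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    print(temperature, top_p, tokenizer.decode(output_ids[0], skip_special_tokens=True))
```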

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-19 19:32:52 +08:00


import os
# Environment variables - wandb has been replaced with SwanLab
# os.environ["SWANLAB_MODE"] = "online"  # SwanLab online mode
import platform
import argparse
from tqdm import tqdm
import time
import math
import warnings
import pandas as pd
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from contextlib import nullcontext
from typing import Optional
import datetime  # for time formatting
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.utils import DeepSpeedPlugin
from accelerate.utils import DistributedDataParallelKwargs
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import swanlab  # replaces the wandb import
import gc  # garbage collection
import psutil  # system resource monitoring
import json  # JSON support
from model.LMConfig import LMConfig
from model.dataset import PretrainDataset

warnings.filterwarnings('ignore')

# Memory-monitoring helpers
def get_memory_usage():
    """Return the current process memory usage."""
    process = psutil.Process()
    memory_info = process.memory_info()
    return {
        'rss_mb': memory_info.rss / 1024 / 1024,  # resident set size in MB
        'vms_mb': memory_info.vms / 1024 / 1024,  # virtual memory size in MB
    }


def get_cuda_memory_usage():
    """Return CUDA memory usage, or an empty dict when CUDA is unavailable."""
    if torch.cuda.is_available():
        return {
            'cuda_allocated_mb': torch.cuda.memory_allocated() / 1024 / 1024,
            'cuda_reserved_mb': torch.cuda.memory_reserved() / 1024 / 1024,
            'cuda_max_allocated_mb': torch.cuda.max_memory_allocated() / 1024 / 1024,
        }
    return {}


def get_tensor_memory_size(tensor_list):
    """Compute the total memory footprint, in MB, of a list of tensors (or tensor tuples)."""
    total_size = 0
    for batch in tensor_list:
        if isinstance(batch, (list, tuple)):
            for tensor in batch:
                if isinstance(tensor, torch.Tensor):
                    total_size += tensor.numel() * tensor.element_size()
        elif isinstance(batch, torch.Tensor):
            total_size += batch.numel() * batch.element_size()
    return total_size / 1024 / 1024  # convert to MB

def log_memory_status(step, prefetch_batches, accelerator, stage="", detailed=False):
    """Log the current memory status on the main process."""
    if not accelerator.is_main_process:
        return
    memory_info = get_memory_usage()
    cuda_info = get_cuda_memory_usage()
    prefetch_memory = get_tensor_memory_size(prefetch_batches)
    log_msg = f"[Memory Monitor] Step {step} {stage} - "
    log_msg += f"Prefetch batches: {len(prefetch_batches)}, "
    log_msg += f"Prefetch memory: {prefetch_memory:.2f}MB, "
    log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB"
    if cuda_info:
        log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
        log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB"
    if detailed:
        log_msg += f", System VMS: {memory_info['vms_mb']:.2f}MB"
        if cuda_info:
            log_msg += f", CUDA max allocated: {cuda_info['cuda_max_allocated_mb']:.2f}MB"
    Logger(log_msg, accelerator)


def Logger(msg, accelerator=None):
    """Print a timestamped log message; only on the main process when an accelerator is given."""
    if accelerator is None or accelerator.is_main_process:
        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}")


# Helper to format seconds into HH:MM:SS
def format_time(seconds):
    return str(datetime.timedelta(seconds=int(seconds)))

def create_validation_dataset(val_data_path, tokenizer, max_length, num_samples=200):
    """
    Create the validation dataset.

    Args:
        val_data_path: path to the validation data file
        tokenizer: tokenizer instance
        max_length: maximum sequence length
        num_samples: number of validation samples

    Returns:
        val_dataset: the validation dataset, or None when the file is missing
    """
    if not os.path.exists(val_data_path):
        Logger(f"Warning: validation data file not found: {val_data_path}, skipping validation")
        return None

    # Read the validation data
    val_data = []
    with open(val_data_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if i >= num_samples:  # cap the number of validation samples
                break
            line = line.strip()
            if line:
                try:
                    sample = json.loads(line)
                    val_data.append(sample['text'])
                except json.JSONDecodeError:
                    continue

    # Write a temporary validation file
    temp_val_file = "/tmp/temp_val.jsonl"
    with open(temp_val_file, 'w', encoding='utf-8') as f:
        for text in val_data:
            f.write(json.dumps({'text': text}) + '\n')

    # Build the validation set with PretrainDataset
    val_dataset = PretrainDataset(temp_val_file, tokenizer, max_length=max_length)
    Logger(f"Validation dataset created with {len(val_data)} samples")
    return val_dataset

def validate_model(model, val_loader, loss_fct, ctx, accelerator):
    """
    Run model validation.

    Args:
        model: the model
        val_loader: validation data loader
        loss_fct: loss function
        ctx: autocast context manager
        accelerator: Accelerator instance

    Returns:
        avg_val_loss: average validation loss
    """
    model.eval()
    total_loss = 0
    num_batches = 0
    with torch.no_grad():
        for batch in val_loader:
            X, Y, loss_mask = batch
            with ctx:
                res = model(X)
                loss = loss_fct(
                    res.logits.view(-1, res.logits.size(-1)),
                    Y.view(-1)
                ).view(Y.size())
                loss = (loss * loss_mask).sum() / loss_mask.sum()
            total_loss += loss.item()
            num_batches += 1
    model.train()
    avg_val_loss = total_loss / num_batches if num_batches > 0 else float('inf')
    return avg_val_loss

# Cosine learning-rate decay; kept for reference, since the training loop uses
# get_cosine_schedule_with_warmup from transformers instead.
def get_lr(it, num_iters, learning_rate):
    return learning_rate * 0.5 * (1.0 + math.cos(math.pi * it / num_iters))

# Model-initialization helpers. The "model" and "model_no_feed" branches shared
# identical weight-init and database-init code, so it is factored out here.
def _default_weight_init(model, rmsnorm_cls, init_knowledge_keys=True):
    """Default weight initialization shared by all model variants."""
    Logger("Performing default model initialization...")
    # Embedding weights
    nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02)
    # Output weights (when not tied to the embeddings)
    if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight:
        nn.init.normal_(model.output.weight, mean=0.0, std=0.02)
    # All linear, embedding, and norm layers
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            # Xavier/Glorot initialization
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            # Normal initialization for embeddings
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, rmsnorm_cls):
            # RMSNorm weights start at 1
            if hasattr(module, 'weight'):
                nn.init.ones_(module.weight)
    # Knowledge-dataset keys (not present on model_memory)
    if init_knowledge_keys and hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'keys'):
        nn.init.normal_(model.knowledge_dataset.keys, mean=0.0, std=0.02)
    Logger("Default model initialization completed")


def _load_pretrained_embeddings(model, pretrained_embedding_path):
    """Load pretrained token embeddings into the (tied) embedding and output layers."""
    Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
    pretrained_embeddings = torch.load(pretrained_embedding_path)
    model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
    model.output.weight.data.copy_(pretrained_embeddings)  # tied weights


def _init_knowledge_dataset_from_database(model, tokenizer, args, database_init_path):
    """Initialize model.knowledge_dataset from the sentence database, reusing the
    cluster cache file when its shape is compatible."""
    # Database parameters
    knowledge_num = args.knowledge_num
    knowledge_length = args.knowledge_length

    # Prepare the cache directory
    cache_dir = os.path.dirname(args.cluster_cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    processed_tensor = None

    # Try to load cached results
    if not args.recompute_clusters and os.path.exists(args.cluster_cache_path):
        try:
            Logger(f"Loading cached processed results from {args.cluster_cache_path}")
            processed_tensor = torch.load(args.cluster_cache_path)
            # Validate the cached shape
            cached_knowledge_num, cached_knowledge_length = processed_tensor.shape
            if cached_knowledge_length == knowledge_length:
                if cached_knowledge_num >= knowledge_num:
                    # Cache is large enough; truncate to the required size
                    processed_tensor = processed_tensor[:knowledge_num, :]
                    Logger(f"Successfully loaded cached data with shape {processed_tensor.shape}")
                    Logger(f"Truncated from cached shape ({cached_knowledge_num}, {cached_knowledge_length}) to required shape ({knowledge_num}, {knowledge_length})")
                    Logger("Skipping database initialization - using cached results")
                else:
                    # Cache too small; recompute
                    Logger(f"Cached knowledge_num ({cached_knowledge_num}) < required knowledge_num ({knowledge_num}), recomputing...")
                    processed_tensor = None
            else:
                # knowledge_length mismatch; recompute
                Logger(f"Cached knowledge_length ({cached_knowledge_length}) != required knowledge_length ({knowledge_length}), recomputing...")
                processed_tensor = None
        except Exception as e:
            Logger(f"Failed to load cached data: {e}, recomputing...")
            processed_tensor = None

    # Only process the database when no valid cache exists
    if processed_tensor is None:
        Logger(f"Loading database initialization data from {database_init_path}")

        # 1. Load the JSON file
        with open(database_init_path, 'r', encoding='utf-8') as f:
            database_data = json.load(f)
        sentences_data = [data['target'][0]['sentence'] for data in database_data]
        Logger(f"Loaded {len(sentences_data)} sentences from database")

        # 2. Sort by importance_score, highest first (falls back to the original
        #    order when entries are plain strings without scores)
        try:
            sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
            Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")
        except (AttributeError, TypeError):
            sorted_sentences = sentences_data

        # 3. Process each sentence individually, without clustering
        Logger("Processing individual sentences...")
        processed_rows = []
        # Padding token id
        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        # Number of sentences to process
        num_to_process = min(knowledge_num, len(sorted_sentences))
        # Truncation statistics
        total_sentences = 0
        truncated_sentences = 0

        for i in range(num_to_process):
            sentence_data = sorted_sentences[i]
            if isinstance(sentence_data, dict):
                sentence = sentence_data.get('corrected_sentence')
            else:
                sentence = sentence_data
            # Tokenize the sentence
            sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
            original_length = len(sentence_tokens)
            # Truncate or pad to knowledge_length
            total_sentences += 1
            if original_length > knowledge_length:
                truncated_sentences += 1
                sentence_tokens = sentence_tokens[:knowledge_length]
                Logger(f"Sentence {i+1} truncated from {original_length} to {knowledge_length} tokens")
            else:
                sentence_tokens.extend([pad_token_id] * (knowledge_length - original_length))
                if original_length < knowledge_length:
                    Logger(f"Sentence {i+1} padded from {original_length} to {knowledge_length} tokens")
            processed_rows.append(sentence_tokens)
            if (i + 1) % 1000 == 0:
                Logger(f"Processed {i + 1}/{num_to_process} sentences")

        # Fill any remaining slots with empty (pad-only) entries
        while len(processed_rows) < knowledge_num:
            processed_rows.append([pad_token_id] * knowledge_length)
            if len(processed_rows) % 1000 == 0:
                Logger(f"Added empty entry {len(processed_rows)}/{knowledge_num}")
        Logger(f"Finished adding empty entries. Total: {len(processed_rows)}/{knowledge_num}")

        # Convert to tensor
        processed_tensor = torch.tensor(processed_rows, dtype=torch.long)

        # Report truncation statistics
        truncation_ratio = truncated_sentences / total_sentences if total_sentences > 0 else 0.0
        Logger("Truncation statistics:")
        Logger(f" - Total sentences: {total_sentences}")
        Logger(f" - Truncated sentences: {truncated_sentences}")
        Logger(f" - Truncation ratio: {truncation_ratio:.4f} ({truncation_ratio*100:.2f}%)")
        Logger("Data processing completed:")
        Logger(f" - Processed {num_to_process} sentences")
        Logger(f" - Added {knowledge_num - num_to_process} empty entries")
        Logger(f" - Final shape: {processed_tensor.shape}")
        Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")

        # Save the processed results to the cache file
        try:
            torch.save(processed_tensor, args.cluster_cache_path)
            Logger(f"Processed results saved to {args.cluster_cache_path}")
        except Exception as e:
            Logger(f"Failed to save processed results: {e}")

    # 4. Initialize the model's knowledge_dataset
    if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
        model.knowledge_dataset.knowledge_dataset.data.copy_(processed_tensor)
        Logger("Successfully initialized model.knowledge_dataset.knowledge_dataset with processed data")
    else:
        Logger("Warning: Could not find model.knowledge_dataset.knowledge_dataset to initialize")
        # Keep a module-level reference as a fallback
        globals()['processed_database'] = processed_tensor
    Logger("Database embeddings and sentences stored in model")


def _init_memory_bank_from_text(model, tokenizer, args, database_init_path):
    """Initialize model.memory_bank from raw sentence text, with caching."""
    Logger(f"Initializing memory_bank with text data from {database_init_path}")
    # Database parameters
    knowledge_num = args.knowledge_num
    knowledge_length = args.knowledge_length

    # Cache file path
    memory_cache_path = args.cluster_cache_path or f"cache/memory_bank_init_{knowledge_num}_{knowledge_length}.pt"
    os.makedirs(os.path.dirname(memory_cache_path) if os.path.dirname(memory_cache_path) else '.', exist_ok=True)

    # Use the cache when available
    if os.path.exists(memory_cache_path):
        Logger(f"Loading memory_bank initialization from cache: {memory_cache_path}")
        processed_tensor = torch.load(memory_cache_path)
        Logger(f"Loaded memory_bank data with shape: {processed_tensor.shape}")
    else:
        Logger(f"Processing text data from {database_init_path} for memory_bank initialization")
        # Load the data
        with open(database_init_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        Logger(f"Loaded {len(data)} sentences from {database_init_path}")

        # Tokenize sentences into fixed-length token-id rows
        processed_rows = []
        total_sentences = len(data)
        truncated_sentences = 0
        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

        # Cap the number of sentences to process
        num_to_process = min(len(data), knowledge_num)
        Logger(f"Processing {num_to_process} out of {total_sentences} sentences")

        for idx, item in enumerate(data[:num_to_process]):
            if idx % 1000 == 0:
                Logger(f"Processing sentence {idx+1}/{num_to_process}")
            # Extract the sentence text
            if isinstance(item, dict):
                sentence = item.get('sentence', '') or item.get('text', '') or str(item)
            else:
                sentence = str(item)
            # Encode with the tokenizer
            try:
                tokens = tokenizer(
                    sentence,
                    add_special_tokens=True,
                    truncation=True,
                    max_length=knowledge_length,
                    padding=False,
                    return_tensors="pt"
                )['input_ids'].squeeze().tolist()
                # Ensure a list (squeeze() yields a scalar for single-token input)
                if not isinstance(tokens, list):
                    tokens = [tokens]
                # Truncate or pad to knowledge_length
                if len(tokens) > knowledge_length:
                    tokens = tokens[:knowledge_length]
                    truncated_sentences += 1
                elif len(tokens) < knowledge_length:
                    tokens.extend([pad_token_id] * (knowledge_length - len(tokens)))
                processed_rows.append(tokens)
            except Exception as e:
                Logger(f"Error processing sentence {idx}: {e}")
                # Fall back to an empty (pad-only) row
                processed_rows.append([pad_token_id] * knowledge_length)

        # Fill any remaining slots with pad-only entries
        while len(processed_rows) < knowledge_num:
            processed_rows.append([pad_token_id] * knowledge_length)
            if len(processed_rows) % 1000 == 0:
                Logger(f"Added empty entry {len(processed_rows)}/{knowledge_num}")

        # Convert to tensor
        processed_tensor = torch.tensor(processed_rows, dtype=torch.long)

        # Report truncation statistics
        truncation_ratio = truncated_sentences / total_sentences if total_sentences > 0 else 0.0
        Logger("Truncation statistics:")
        Logger(f" - Total sentences: {total_sentences}")
        Logger(f" - Truncated sentences: {truncated_sentences}")
        Logger(f" - Truncation ratio: {truncation_ratio:.4f} ({truncation_ratio*100:.2f}%)")
        Logger("Memory_bank data processing completed:")
        Logger(f" - Processed {num_to_process} sentences")
        Logger(f" - Added {knowledge_num - num_to_process} empty entries")
        Logger(f" - Final shape: {processed_tensor.shape}")
        Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")

        # Save the processed results to the cache file
        try:
            torch.save(processed_tensor, memory_cache_path)
            Logger(f"Processed results saved to {memory_cache_path}")
        except Exception as e:
            Logger(f"Failed to save processed results: {e}")

    # Copy into the model's memory_bank
    if hasattr(model, 'memory_bank'):
        model.memory_bank.data.copy_(processed_tensor)
        Logger("Successfully initialized memory_bank with processed text data")
    else:
        Logger("Warning: Could not find memory_bank to initialize")


def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None):
    """Build the requested model variant and its tokenizer, applying weight
    initialization and the optional knowledge/memory initialization."""
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')

    if args.model_type == "model":
        Logger(f"Using model type: {args.model_type}")
        from model.model import MiniMindLM, RMSNorm
        model = MiniMindLM(lm_config)
        _default_weight_init(model, RMSNorm)
        # Load pretrained token embeddings when provided
        if pretrained_embedding_path:
            _load_pretrained_embeddings(model, pretrained_embedding_path)
        if database_init_path:
            _init_knowledge_dataset_from_database(model, tokenizer, args, database_init_path)
    elif args.model_type == "model_original":
        Logger(f"Using model type: {args.model_type}")
        from model.model_original import MiniMindLM, RMSNorm
        model = MiniMindLM(lm_config)
    elif args.model_type == "model_no_feed":
        Logger(f"Using model type: {args.model_type}")
        from model.model_no_feed import MiniMindLM, RMSNorm
        model = MiniMindLM(lm_config)
        _default_weight_init(model, RMSNorm)
        # Load pretrained token embeddings when provided
        if pretrained_embedding_path:
            _load_pretrained_embeddings(model, pretrained_embedding_path)
        if database_init_path:
            _init_knowledge_dataset_from_database(model, tokenizer, args, database_init_path)
    elif args.model_type == "model_memory":
        Logger(f"Using model type: {args.model_type}")
        from model.model_memory import MiniMindLM, RMSNorm
        model = MiniMindLM(lm_config)
        Logger("Performing model_memory initialization...")
        _default_weight_init(model, RMSNorm, init_knowledge_keys=False)
        # Memory-bank initialization
        if database_init_path and os.path.exists(database_init_path):
            _init_memory_bank_from_text(model, tokenizer, args, database_init_path)
        else:
            Logger(f"Memory bank initialized with random values, shape: {model.memory_bank.shape}")
        Logger("Model_memory initialization completed")

    Logger(f'Total trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
    return model, tokenizer

def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer, val_loader=None):
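    """Run one training epoch: prefetch batches, run forward/backward with optional
    CUDA-event profiling, let DeepSpeed handle accumulation and clipping, and
    periodically validate, log, and checkpoint on the best loss."""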
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    epoch_start_time = time.time()
    total_steps_in_epoch = len(train_loader)
    total_training_steps = args.epochs * total_steps_in_epoch
    moe_path = '_moe' if args.use_moe else ''
    best_loss = 10000.0  # initial best-loss threshold

    # CUDA timing events for profiling (main process only)
    data_start = data_end = forward_start = forward_end = None
    backward_start = backward_end = optimizer_start = optimizer_end = None
    if args.profile and accelerator.is_main_process:
        data_start = torch.cuda.Event(enable_timing=True)
        data_end = torch.cuda.Event(enable_timing=True)
        forward_start = torch.cuda.Event(enable_timing=True)
        forward_end = torch.cuda.Event(enable_timing=True)
        backward_start = torch.cuda.Event(enable_timing=True)
        backward_end = torch.cuda.Event(enable_timing=True)
        optimizer_start = torch.cuda.Event(enable_timing=True)
        optimizer_end = torch.cuda.Event(enable_timing=True)

    # Data prefetching
    prefetch_factor = 8  # number of batches to prefetch
    data_iter = iter(train_loader)
    prefetch_batches = []

    # Record the initial memory state
    if args.memory_monitor:
        log_memory_status(-1, prefetch_batches, accelerator, "before_prefetch", detailed=True)

    # Prefetch the initial batches
    for i in range(min(prefetch_factor, len(train_loader))):
        try:
            batch = next(data_iter)
            prefetch_batches.append(batch)
            # Track memory growth after each prefetched batch
            if args.memory_monitor and accelerator.is_main_process:
                log_memory_status(-1, prefetch_batches, accelerator, f"after_adding_batch_{i+1}")
        except StopIteration:
            break

    # Memory state after the initial prefetch
    if args.memory_monitor:
        log_memory_status(-1, prefetch_batches, accelerator, "after_initial_prefetch", detailed=True)

    # Variables for periodic logging
    last_log_time = epoch_start_time

    for step in range(total_steps_in_epoch):
        try:
            # Time data loading (main process only)
            if args.profile and accelerator.is_main_process and data_start is not None:
                data_start.record()

            # Memory state before consuming a batch (at the configured interval)
            if args.memory_monitor and step % args.memory_monitor_interval == 0:
                log_memory_status(step, prefetch_batches, accelerator, "before_use_batch", detailed=True)

            # Use a prefetched batch when available
            if prefetch_batches:
                X, Y, loss_mask = prefetch_batches.pop(0)
                # Memory change after popping a batch
                if args.memory_monitor and step % args.memory_monitor_interval == 0:
                    log_memory_status(step, prefetch_batches, accelerator, "after_pop_batch")
            else:
                # Prefetch queue is empty; load directly
                X, Y, loss_mask = next(data_iter)
                if args.memory_monitor and accelerator.is_main_process:
                    Logger(f"[Memory Monitor] Step {step} - Prefetch queue empty, loading directly!", accelerator)

            # Asynchronously prefetch the next batch
            if step + prefetch_factor < len(train_loader):
                try:
                    batch = next(data_iter)
                    prefetch_batches.append(batch)
                    # Memory change after adding a new batch
                    if args.memory_monitor and step % args.memory_monitor_interval == 0:
                        log_memory_status(step, prefetch_batches, accelerator, "after_add_batch")
                except StopIteration:
                    pass

            # End data-loading timing (main process only)
            if args.profile and accelerator.is_main_process and data_end is not None:
                data_end.record()

            # Update the learning rate
            if scheduler is not None:
                scheduler.step()

            # Time the forward pass (main process only)
            if args.profile and accelerator.is_main_process and forward_start is not None:
                forward_start.record()

            # Forward pass
            with ctx:
                if step == 0 and args.embedding_epoch == epoch:
                    # Set freeze_embedding on the unwrapped model, not the DDP wrapper
                    unwrapped_model = accelerator.unwrap_model(model)
                    unwrapped_model.freeze_embedding = True
                    Logger(f"Set freeze_embedding=True for epoch {epoch}, step {step}", accelerator)
                res = model(X, step=step)
                # Primary loss (cross entropy)
                ce_loss = loss_fct(
                    res.logits.view(-1, res.logits.size(-1)),
                    Y.view(-1)
                ).view(Y.size())
                ce_loss = (ce_loss * loss_mask).sum() / loss_mask.sum()
                # Balance loss, when the model provides one
                balance_loss = 0
                if hasattr(res, 'aux_loss') and res.aux_loss is not None:
                    balance_loss = res.aux_loss
                # Total loss
                total_loss = ce_loss + args.balance_loss_coef * balance_loss
                loss = total_loss / args.accumulation_steps

            # End forward timing (main process only)
            if args.profile and accelerator.is_main_process and forward_end is not None:
                forward_end.record()

            # Time the backward pass (main process only)
            if args.profile and accelerator.is_main_process and backward_start is not None:
                backward_start.record()

            # Backward pass
            # With DeepSpeed, gradient accumulation and clipping are handled automatically
            accelerator.backward(loss)

            # End backward timing (main process only)
            if args.profile and accelerator.is_main_process and backward_end is not None:
                backward_end.record()

            # Time the optimizer step (main process only)
            if args.profile and accelerator.is_main_process and optimizer_start is not None:
                optimizer_start.record()

            # Optimizer step - DeepSpeed only applies it at the accumulation
            # boundary, so no explicit step % accumulation_steps check is needed
            optimizer.step()
            # With DeepSpeed, zero_grad() is called automatically after step(),
            # but we call it explicitly to be safe
            optimizer.zero_grad()

            # VQ-VAE-style EMA update (only when the model supports it)
            if hasattr(res, 'ema_stats') and res.ema_stats is not None:
                unwrapped_model = accelerator.unwrap_model(model)
                if hasattr(unwrapped_model, 'apply_ema_update'):
                    ema_update_stats = unwrapped_model.apply_ema_update(res.ema_stats)
                    # Log EMA update statistics
                    if step % args.log_interval == 0 and accelerator.is_main_process and ema_update_stats.get('ema_update_applied', False):
                        total_memories = args.knowledge_num
                        Logger(f"EMA Update - Step: {ema_update_stats['ema_step']}, "
                               f"Updated memories: {ema_update_stats['updated_memories']}/{total_memories} "
                               f"({ema_update_stats['update_ratio']:.4f}), "
                               f"Coverage: {ema_update_stats['selected_memory_coverage']:.4f}", accelerator)

            # End optimizer timing (main process only)
            if args.profile and accelerator.is_main_process and optimizer_end is not None:
                optimizer_end.record()

            # Validation and logging (main process only)
            if (step + 1) % args.val_interval == 0 and accelerator.is_main_process:
                current_time = time.time()

                # Detailed memory state at the logging point
                if args.memory_monitor:
                    log_memory_status(step, prefetch_batches, accelerator, "at_log_interval", detailed=True)
                    # Force garbage collection and log the change
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    gc.collect()
                    log_memory_status(step, prefetch_batches, accelerator, "after_gc", detailed=True)

                # Performance metrics
                if args.profile and accelerator.is_main_process:
                    torch.cuda.synchronize()
                    # elapsed_time requires both events of a pair to be recorded
                    try:
                        data_time = data_start.elapsed_time(data_end) if data_start is not None and data_end is not None else 0
                        forward_time = forward_start.elapsed_time(forward_end) if forward_start is not None and forward_end is not None else 0
                        backward_time = backward_start.elapsed_time(backward_end) if backward_start is not None and backward_end is not None else 0
                        optimizer_time = optimizer_start.elapsed_time(optimizer_end) if optimizer_start is not None and optimizer_end is not None else 0
                        iter_time = (current_time - last_log_time) * 1000 / args.log_interval  # avg ms per iteration since last log

                        # Print the profiling summary
                        if (step + 1) % (args.log_interval * args.profile_interval) == 0:
                            Logger(f"Profiling (avg/iter over last {args.log_interval} steps) - "
                                   f"Data: {data_time/args.log_interval:.2f}ms, "
                                   f"Fwd: {forward_time/args.log_interval:.2f}ms, "
                                   f"Bwd: {backward_time/args.log_interval:.2f}ms, "
                                   f"Optim: {optimizer_time/args.log_interval:.2f}ms, "
                                   f"Iter Time: {iter_time:.2f}ms", accelerator)

                        # Text-generation preview
                        try:
                            # Pick a random sample from the batch
                            random_idx = torch.randint(0, X.size(0), (1,)).item()
                            sample_input = X[random_idx:random_idx+1]   # [1, seq_len]
                            sample_target = Y[random_idx:random_idx+1]  # [1, seq_len]
                            # Use the first half as the prompt, leaving at least 10 ground-truth tokens
                            prompt_len = sample_input.size(1) // 2
                            prompt_input = sample_input[:, :prompt_len]
                            # Ground-truth continuation (Y is shifted one step relative to X)
                            true_next_tokens = sample_target[:, prompt_len-1:prompt_len-1+10]

                            # Generate 10 tokens
                            unwrapped_model = accelerator.unwrap_model(model)
                            unwrapped_model.eval()  # switch to evaluation mode
                            with torch.no_grad():
                                generated = unwrapped_model.generate(
                                    prompt_input,
                                    max_new_tokens=10,
                                    temperature=0.7,
                                    top_p=0.9,
                                    eos_token_id=tokenizer.eos_token_id,
                                    pad_token_id=tokenizer.pad_token_id
                                )

                            # Decode to human-readable text
                            prompt_text = tokenizer.decode(prompt_input[0], skip_special_tokens=True)
                            true_text = tokenizer.decode(true_next_tokens[0], skip_special_tokens=True)
                            # Extract only the newly generated tokens
                            prompt_tokens = prompt_input[0].tolist()
                            generated_tokens = generated[0].tolist()
                            if len(generated_tokens) > len(prompt_tokens):
                                new_tokens = generated_tokens[len(prompt_tokens):len(prompt_tokens)+10]  # first 10 only
                                generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
                            else:
                                generated_text = "[no new tokens generated]"

                            Logger("Text generation comparison:", accelerator)
                            Logger(f"  Prompt:            {prompt_text}", accelerator)
                            Logger(f"  True continuation: {true_text}", accelerator)
                            Logger(f"  Model generation:  {generated_text}", accelerator)
                            unwrapped_model.train()  # back to training mode
                        except Exception as e:
                            Logger(f"Text-generation preview failed: {e}", accelerator)

                        # Reset the events so the next measurement starts from zero
                        data_start = torch.cuda.Event(enable_timing=True)
                        data_end = torch.cuda.Event(enable_timing=True)
                        forward_start = torch.cuda.Event(enable_timing=True)
                        forward_end = torch.cuda.Event(enable_timing=True)
                        backward_start = torch.cuda.Event(enable_timing=True)
                        backward_end = torch.cuda.Event(enable_timing=True)
                        optimizer_start = torch.cuda.Event(enable_timing=True)
                        optimizer_end = torch.cuda.Event(enable_timing=True)
                    except RuntimeError as e:
                        if "Both events must be recorded" in str(e):
                            Logger(f"Warning: CUDA events not properly recorded, skipping performance analysis: {e}", accelerator)
                        else:
                            raise e

                # Current learning rate
                current_lr = optimizer.param_groups[0]['lr']

                # Timing statistics
                epoch_elapsed_time = current_time - epoch_start_time
                epoch_steps_done = step + 1
                epoch_avg_step_time = epoch_elapsed_time / epoch_steps_done
                epoch_remaining_time = epoch_avg_step_time * (total_steps_in_epoch - epoch_steps_done)
                total_elapsed_time = current_time - overall_start_time
                total_steps_done = epoch * total_steps_in_epoch + epoch_steps_done
                total_avg_step_time = total_elapsed_time / total_steps_done if total_steps_done > 0 else 0
                total_remaining_time = total_avg_step_time * (total_training_steps - total_steps_done) if total_steps_done > 0 else 0

                # Throughput over the last log interval
                interval_elapsed_time = current_time - last_log_time
                tokens_processed_interval = args.log_interval * args.batch_size * args.max_seq_len
                tokens_per_sec = tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0
                last_log_time = current_time  # update the last-log timestamp

                # Run validation
                val_loss = None
                if val_loader is not None:
                    try:
                        val_loss = validate_model(model, val_loader, loss_fct, ctx, accelerator)
                        Logger(f"Validation loss: {val_loss:.4f}", accelerator)
                    except Exception as e:
                        Logger(f"Validation failed: {e}", accelerator)
                        val_loss = None

                # Memory-bank update statistics (when the model supports them)
                memory_update_stats = {}
                if hasattr(model, 'get_memory_update_stats'):
                    try:
                        unwrapped_model = accelerator.unwrap_model(model)
                        if hasattr(unwrapped_model, 'get_memory_update_stats'):
                            memory_update_stats = unwrapped_model.get_memory_update_stats()
                    except Exception as e:
                        Logger(f"Failed to fetch memory update stats: {e}", accelerator)

                # Per-layer statistics (when the model provides them)
                layer_stats = {}
                if hasattr(res, 'layer_stats') and res.layer_stats is not None:
                    layer_stats = res.layer_stats

                # Build the log dictionary
                log_dict = {
                    "epoch": epoch + 1,
                    "step": step + 1,
                    "total_steps_in_epoch": total_steps_in_epoch,
                    "train/loss_ce": ce_loss.item(),
                    "train/loss_balance": balance_loss.item() if isinstance(balance_loss, torch.Tensor) else balance_loss,
                    "train/loss_total": total_loss.item(),
                    "lr": current_lr,
                    "tokens_per_sec": tokens_per_sec,
                    "epoch_time_left_seconds": epoch_remaining_time,
                    "total_time_left_seconds": total_remaining_time
                }
                # Validation loss
                if val_loss is not None:
                    log_dict["val/loss"] = val_loss
                # Memory-bank update statistics
                log_dict.update(memory_update_stats)
                # Key per-layer metrics, averaged or summed across layers
                if layer_stats:
                    avg_gini = np.mean([v for k, v in layer_stats.items() if k.endswith('_gini_coefficient')])
                    avg_coverage = np.mean([v for k, v in layer_stats.items() if k.endswith('_coverage_rate')])
                    total_dead = sum([v for k, v in layer_stats.items() if k.endswith('_dead_memories')])
                    total_hot = sum([v for k, v in layer_stats.items() if k.endswith('_hot_memories')])
                    log_dict.update({
                        'memory/avg_gini_coefficient': avg_gini,
                        'memory/avg_coverage_rate': avg_coverage,
                        'memory/total_dead_memories': total_dead,
                        'memory/total_hot_memories': total_hot,
                    })

                Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
                       f"CE Loss: {log_dict['train/loss_ce']:.4f}, "
                       f"Balance Loss: {log_dict['train/loss_balance']:.4f}, "
                       f"Total Loss: {log_dict['train/loss_total']:.4f}, "
                       f"Val Loss: {log_dict.get('val/loss', 'N/A')}, "
                       f"LR: {log_dict['lr']:.6f}, "
                       f"Speed: {log_dict['tokens_per_sec']:.2f} tokens/sec | "
                       f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
                       f"Total Time Left: {format_time(total_remaining_time)}", accelerator)

                if args.use_swanlab and accelerator.is_main_process and swanlab_run:
                    swanlab_run.log(log_dict)

            # Save the model (main process only)
            loss_total = loss.item() * args.accumulation_steps
            if epoch >= 0 and best_loss > loss_total and accelerator.is_main_process:
                best_loss = loss_total
                # moe_path is defined at the top of this function
                ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
                # Save the unwrapped model's parameters
                unwrapped_model = accelerator.unwrap_model(model)
                accelerator.save(unwrapped_model.state_dict(), ckp)
                Logger(f"Model saved to {ckp}", accelerator)

        except Exception as e:
            Logger(f"Error in training step: {e}", accelerator)
            # Memory state at the time of the exception
            if args.memory_monitor:
                log_memory_status(step, prefetch_batches, accelerator, "at_exception", detailed=True)
            import traceback
            Logger(traceback.format_exc(), accelerator)

            # Clear prefetch_batches to avoid leaking memory
            if args.memory_monitor and accelerator.is_main_process:
                Logger(f"[Memory Monitor] Clearing prefetch_batches due to exception. Current length: {len(prefetch_batches)}", accelerator)
            prefetch_batches.clear()
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            if args.memory_monitor:
                log_memory_status(step, prefetch_batches, accelerator, "after_exception_cleanup", detailed=True)

    # Clear prefetch_batches at the end of the epoch
    if args.memory_monitor:
        if accelerator.is_main_process:
            Logger(f"[Memory Monitor] Epoch {epoch+1} finished. Clearing prefetch_batches. Final length: {len(prefetch_batches)}", accelerator)
        log_memory_status(total_steps_in_epoch-1, prefetch_batches, accelerator, "before_epoch_end_cleanup", detailed=True)
    prefetch_batches.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if args.memory_monitor:
        log_memory_status(total_steps_in_epoch-1, prefetch_batches, accelerator, "after_epoch_end_cleanup", detailed=True)

def main():
    parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
    parser.add_argument("--out_dir", type=str, default="out")
    parser.add_argument("--epochs", type=int, default=4)
    parser.add_argument("--embedding_epoch", type=int, default=2, help="epoch at which embedding freezing kicks in")
    parser.add_argument("--batch_size", type=int, default=60)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_swanlab", default=True, action="store_true")  # replaces the wandb flag; default True means it is effectively always on
    parser.add_argument("--swanlab_project", type=str, default="MiniMind-Pretrain")  # replaces the wandb project
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--accumulation_steps", type=int, default=32)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--warmup_iters", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=1)
    parser.add_argument("--save_interval", type=int, default=10000)
    parser.add_argument('--dim', default=512, type=int)
    parser.add_argument('--n_layers', default=8, type=int)
    parser.add_argument('--n_heads', default=32, type=int)
    parser.add_argument('--max_seq_len', default=512, type=int)
    parser.add_argument('--use_moe', default=False, action='store_true')  # was type=bool, which parses any non-empty string as True
    parser.add_argument('--disable_db', action='store_true', help="disable the database and use a fixed value of 1e-4 instead")
    parser.add_argument("--data_path", type=str, default="/home/pci/ycz/Code/Minimind/dataset/stable/merged_pretrain.jsonl")
    parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
    parser.add_argument("--profile", action="store_true", default=True, help="enable profiling")
    parser.add_argument("--profile_interval", type=int, default=10, help="profiling print interval (steps)")
    parser.add_argument("--use_flash_attn", action="store_true", default=True, help="enable FlashAttention")
    parser.add_argument("--knowledge_num", type=int, default=960400, help="number of knowledge-base entries")
    parser.add_argument("--knowledge_length", type=int, default=8, help="token length of each knowledge entry")
    parser.add_argument("--knowledge_dim", type=int, default=128, help="knowledge vector dimension")
    parser.add_argument("--database_init_path", type=str, default="/home/pci/ycz/Code/Minimind/dataset/stable/sentence_trex_data.json", help="database initialization path")
    parser.add_argument("--fast_clustering", action="store_true", default=True, help="use fast approximate clustering (for large datasets)")
    parser.add_argument("--cluster_cache_path", type=str, default="/home/pci/ycz/Code/Minimind/cache/cluster_tokens_single.pt", help="cluster cache file path")
    parser.add_argument("--recompute_clusters", action="store_true", default=False, help="force recomputation of clusters, ignoring the cache file")
    parser.add_argument("--memory_monitor", action="store_true", default=False, help="enable memory monitoring")
    parser.add_argument("--memory_monitor_interval", type=int, default=10, help="memory monitoring interval (steps)")
    parser.add_argument("--model_type", type=str, default="model_memory", help="which model variant to train")  # model, model_original, model_no_feed, model_memory
    parser.add_argument("--model_size", type=float, default=50.0, help="model size")
    parser.add_argument("--swanlab_online", default=False, action='store_true', help="use the hosted SwanLab service")  # was type=bool, same parsing pitfall as above
    parser.add_argument("--balance_loss_coef", type=float, default=0.01, help="balance loss coefficient")
    parser.add_argument("--val_data_path", type=str, default="dataset/stable/eval_data.json", help="validation dataset path")
    parser.add_argument("--val_interval", type=int, default=100, help="validation interval (steps)")
    args = parser.parse_args()
    #########################################################
    # Initialize accelerator and DeepSpeed
    #########################################################
    # DDP kwargs to handle unused parameters
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    # Create the DeepSpeedPlugin
    ds_plugin = DeepSpeedPlugin(
        gradient_accumulation_steps=args.accumulation_steps,
        gradient_clipping=args.grad_clip,
        zero_stage=2,                     # ZeRO-2 optimization
        offload_optimizer_device="none",  # keep optimizer states on GPU
        offload_param_device="none",      # keep parameters on GPU
    )
    accelerator = Accelerator(
        kwargs_handlers=[ddp_kwargs],
        deepspeed_plugin=ds_plugin,
        mixed_precision="bf16" if args.dtype == "bfloat16" else "fp16" if args.dtype == "float16" else "no"
    )

    #########################################################
    # Set the random seed
    #########################################################
    set_seed(1337 + accelerator.process_index)

    #########################################################
    # Configure the model
    #########################################################
    lm_config = LMConfig(
        dim=args.dim,
        n_layers=args.n_layers,
        max_seq_len=args.max_seq_len,
        use_moe=args.use_moe,
        disable_db=args.disable_db,
        flash_attn=args.use_flash_attn,
        knowledge_num=args.knowledge_num,
        knowledge_length=args.knowledge_length,
        embeddings_epoch=args.embedding_epoch
    )

    #########################################################
    # Create the save directory
    #########################################################
    args.save_dir = os.path.join(args.out_dir)
    if accelerator.is_main_process:
        os.makedirs(args.save_dir, exist_ok=True)
        os.makedirs(args.out_dir, exist_ok=True)

    #########################################################
    # Set the data type
    #########################################################
    pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]

    #########################################################
    # Configure SwanLab
    #########################################################
    # SwanLab run name
    args.swanlab_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"

    # Merge args and lm_config into one dict (also used for printing the configuration,
    # whether or not SwanLab is enabled)
    config_dict = vars(args).copy()
    config_dict.update(vars(lm_config))

    # Initialize the SwanLab experiment
    swanlab_run = None
    if args.use_swanlab and accelerator.is_main_process:
        if args.swanlab_online:
            # Hosted SwanLab service
            swanlab_run = swanlab.init(
                project=args.swanlab_project,
                experiment_name=args.swanlab_run_name,
                description="MiniMind pretraining experiment, visualized with SwanLab",
                config=config_dict
            )
        else:
            # Locally deployed SwanLab, offline mode
            swanlab_run = swanlab.init(
                project=args.swanlab_project,
                experiment_name=args.swanlab_run_name,
                description="MiniMind pretraining experiment, visualized with a locally deployed SwanLab",
                config=config_dict,
                mode="offline"
            )
    else:
        swanlab_run = None

    #########################################################
    # Print configuration
    #########################################################
    # Tokens consumed per iteration
    tokens_per_iter = args.batch_size * lm_config.max_seq_len
    if accelerator.is_main_process:
        Logger(f"tokens_per_iter: {tokens_per_iter}", accelerator)
        Logger("Configuration:", accelerator)
        for key, value in config_dict.items():
            Logger(f"  {key}: {value}", accelerator)

    #########################################################
    # Automatic mixed-precision context
    #########################################################
    ctx = nullcontext() if accelerator.device.type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)

    #########################################################
    # Initialize the model and tokenizer
    #########################################################
    model, tokenizer = init_model(lm_config, args.pretrained_embedding_path, args.database_init_path, args)
    Logger('Model initialization completed', accelerator)

    #########################################################
    # Handle positional-encoding tensors under DDP
    #########################################################
    if hasattr(model, "pos_cis_real"):
        Logger('Found pos_cis_real (real-valued tensor); it participates in distributed training', accelerator)
        # model._ddp_params_and_buffers_to_ignore = {"pos_cis_real"}
    # Backward compatibility: check for the complex-valued pos_cis
    elif hasattr(model, "pos_cis"):
        Logger('Found pos_cis (complex tensor); excluding it from distributed training', accelerator)
        # Tell DDP to ignore this buffer
        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}

    #########################################################
    # Create datasets and data loaders
    #########################################################
    train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=False,
        shuffle=True,
        num_workers=args.num_workers,
        persistent_workers=True if args.num_workers > 0 else False,
        prefetch_factor=2 if args.num_workers > 0 else None
    )

    # Validation dataset and loader
    val_loader = None
    val_ds = create_validation_dataset(args.val_data_path, tokenizer, lm_config.max_seq_len)
    if val_ds is not None:
        val_loader = DataLoader(
            val_ds,
            batch_size=args.batch_size // 2,  # smaller batches for validation
            pin_memory=True,
            drop_last=False,
            shuffle=False,
            num_workers=0,  # no worker processes for validation
        )

    #########################################################
    # Create the optimizer
    #########################################################
    # With EMA updates enabled, filter out the memory_bank parameters, which no
    # longer receive gradient updates
    if hasattr(model.params, 'use_ema_update') and model.params.use_ema_update:
        # Only parameters with requires_grad=True
        optimizer_params = [p for p in model.parameters() if p.requires_grad]
        Logger(f"EMA update mode: optimizer covers {len(optimizer_params)} parameter tensors (memory_bank filtered out)")
        Logger(f"Total parameters: {sum(p.numel() for p in model.parameters())} | Trainable parameters: {sum(p.numel() for p in optimizer_params)}")
        optimizer = optim.AdamW(optimizer_params, lr=args.learning_rate)
    else:
        # Conventional mode: all parameters receive gradient updates
        Logger("Conventional gradient-update mode: optimizer covers all model parameters")
        optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    #########################################################
    # Create the learning-rate scheduler
    #########################################################
    total_steps = len(train_loader) * args.epochs
    warmup_steps = args.warmup_iters if args.warmup_iters > 0 else int(0.1 * total_steps)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    #########################################################
    # Prepare for training
    #########################################################
    if val_loader is not None:
        model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
            model, optimizer, train_loader, val_loader, scheduler
        )
    else:
        model, optimizer, train_loader, scheduler = accelerator.prepare(
            model, optimizer, train_loader, scheduler
        )

    #########################################################
    # Training loop
    #########################################################
    overall_start_time = time.time()  # record the overall start time
    for epoch in range(args.epochs):
        Logger(f"Starting epoch {epoch+1}", accelerator)
        train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer, val_loader)

        # Clean up memory after each epoch
        Logger(f"Epoch {epoch+1} finished, cleaning up memory", accelerator)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Log the memory state at the end of the epoch
        if accelerator.is_main_process:
            memory_info = get_memory_usage()
            cuda_info = get_cuda_memory_usage()
            log_msg = f"[Memory Monitor] Epoch {epoch+1} completed - "
            log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB"
            if cuda_info:
                log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
                log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB"
            Logger(log_msg, accelerator)

    #########################################################
    # Shut down SwanLab
    #########################################################
    if args.use_swanlab and accelerator.is_main_process and swanlab_run:
        swanlab_run.finish()


if __name__ == "__main__":
    main()