Minimind/train_pretrain_accelerate.py

1619 lines
82 KiB
Python
Raw Normal View History

2025-05-14 00:01:40 +08:00
import os
# 设置环境变量 - 将wandb替换为SwanLab
# os.environ["SWANLAB_MODE"] = "online" # SwanLab使用在线模式
2025-09-06 15:12:05 +08:00
# 🔥 强制禁用输出缓冲,确保日志立即写入
os.environ['PYTHONUNBUFFERED'] = '1' # Python 解释器不缓冲输出
os.environ['PYTHONIOENCODING'] = 'utf-8' # 确保编码一致性
2025-05-14 00:01:40 +08:00
import platform
import argparse
2025-05-26 23:09:03 +08:00
from tqdm import tqdm
2025-05-14 00:01:40 +08:00
import time
import math
import warnings
import pandas as pd
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from contextlib import nullcontext
from typing import Optional
import datetime # Add datetime for time formatting
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.utils import DeepSpeedPlugin
from accelerate.utils import DistributedDataParallelKwargs
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
2025-05-26 23:09:03 +08:00
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import swanlab # 替换wandb导入
import gc # 添加垃圾回收模块
import psutil # 添加系统资源监控模块
2025-08-07 11:43:23 +08:00
import json # 添加JSON支持
2025-07-12 18:00:53 +08:00
2025-05-14 00:01:40 +08:00
from model.LMConfig import LMConfig
from model.dataset import PretrainDataset
warnings.filterwarnings('ignore')
# 内存监控辅助函数
def get_memory_usage():
"""获取当前内存使用情况"""
process = psutil.Process()
memory_info = process.memory_info()
return {
'rss_mb': memory_info.rss / 1024 / 1024, # 物理内存使用量MB
'vms_mb': memory_info.vms / 1024 / 1024, # 虚拟内存使用量MB
}
def get_cuda_memory_usage():
"""获取CUDA内存使用情况"""
if torch.cuda.is_available():
return {
'cuda_allocated_mb': torch.cuda.memory_allocated() / 1024 / 1024,
'cuda_reserved_mb': torch.cuda.memory_reserved() / 1024 / 1024,
'cuda_max_allocated_mb': torch.cuda.max_memory_allocated() / 1024 / 1024,
}
return {}
def get_tensor_memory_size(tensor_list):
"""计算tensor列表的总内存占用MB"""
total_size = 0
for batch in tensor_list:
if isinstance(batch, (list, tuple)):
for tensor in batch:
if isinstance(tensor, torch.Tensor):
total_size += tensor.numel() * tensor.element_size()
elif isinstance(batch, torch.Tensor):
total_size += batch.numel() * batch.element_size()
return total_size / 1024 / 1024 # 转换为MB
def log_memory_status(step, prefetch_batches, accelerator, stage="", detailed=False):
"""记录内存状态"""
if not accelerator.is_main_process:
return
memory_info = get_memory_usage()
cuda_info = get_cuda_memory_usage()
prefetch_memory = get_tensor_memory_size(prefetch_batches)
log_msg = f"[Memory Monitor] Step {step} {stage} - "
log_msg += f"Prefetch batches: {len(prefetch_batches)}, "
log_msg += f"Prefetch memory: {prefetch_memory:.2f}MB, "
log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB"
if cuda_info:
log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB"
if detailed:
log_msg += f", System VMS: {memory_info['vms_mb']:.2f}MB"
if cuda_info:
log_msg += f", CUDA max allocated: {cuda_info['cuda_max_allocated_mb']:.2f}MB"
Logger(log_msg, accelerator)
2025-05-14 00:01:40 +08:00
# 日志记录函数
def Logger(msg, accelerator=None):
# 如果没有提供accelerator则只在主进程打印
if accelerator is None or accelerator.is_main_process:
2025-09-06 15:12:05 +08:00
print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}", flush=True) # 强制刷新输出缓冲
import sys
sys.stdout.flush() # 确保立即写入
2025-05-14 00:01:40 +08:00
# Helper function to format seconds into HH:MM:SS
def format_time(seconds):
return str(datetime.timedelta(seconds=int(seconds)))
2025-08-07 11:43:23 +08:00
def create_validation_dataset(val_data_path, tokenizer, max_length, num_samples=200):
"""
创建验证数据集
Args:
val_data_path: 验证数据文件路径
tokenizer: tokenizer实例
max_length: 最大序列长度
num_samples: 验证样本数量
Returns:
val_dataset: 验证数据集
"""
if not os.path.exists(val_data_path):
Logger(f"警告:验证数据文件不存在: {val_data_path},跳过验证评估")
return None
# 读取验证数据
val_data = []
with open(val_data_path, 'r', encoding='utf-8') as f:
for i, line in enumerate(f):
if i >= num_samples: # 限制验证样本数量
break
line = line.strip()
if line:
try:
sample = json.loads(line)
val_data.append(sample['text'])
except json.JSONDecodeError:
continue
# 创建临时验证文件
temp_val_file = "/tmp/temp_val.jsonl"
with open(temp_val_file, 'w', encoding='utf-8') as f:
for text in val_data:
f.write(json.dumps({'text': text}) + '\n')
# 使用PretrainDataset创建验证集
val_dataset = PretrainDataset(temp_val_file, tokenizer, max_length=max_length)
Logger(f"创建验证数据集成功,包含 {len(val_data)} 个样本")
return val_dataset
def validate_model(model, val_loader, loss_fct, ctx, accelerator):
"""
执行模型验证
Args:
model: 模型实例
val_loader: 验证数据加载器
loss_fct: 损失函数
ctx: 上下文管理器
accelerator: Accelerator实例
Returns:
avg_val_loss: 平均验证损失
"""
model.eval()
total_loss = 0
num_batches = 0
with torch.no_grad():
for batch in val_loader:
X, Y, loss_mask = batch
with ctx:
res = model(X)
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
).view(Y.size())
loss = (loss * loss_mask).sum() / loss_mask.sum()
total_loss += loss.item()
num_batches += 1
model.train()
avg_val_loss = total_loss / num_batches if num_batches > 0 else float('inf')
return avg_val_loss
2025-05-14 00:01:40 +08:00
# 获取学习率函数
def get_lr(it, num_iters, learning_rate):
# 余弦学习率衰减
return learning_rate * 0.5 * (1.0 + math.cos(math.pi * it / num_iters))
# 初始化模型函数
2025-05-26 23:09:03 +08:00
def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None):
2025-07-12 18:00:53 +08:00
if args.model_type == "model":
Logger(f"Using model type: {args.model_type}")
from model.model import MiniMindLM, RMSNorm
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
model = MiniMindLM(lm_config)
2025-05-26 23:09:03 +08:00
2025-07-12 18:00:53 +08:00
# 默认模型初始化
Logger("Performing default model initialization...")
2025-05-26 23:09:03 +08:00
2025-07-12 18:00:53 +08:00
# 初始化嵌入层权重
nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02)
2025-05-26 23:09:03 +08:00
2025-07-12 18:00:53 +08:00
# 初始化输出层权重(如果不共享权重的话)
if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight:
nn.init.normal_(model.output.weight, mean=0.0, std=0.02)
2025-05-26 23:09:03 +08:00
2025-07-12 18:00:53 +08:00
# 初始化所有线性层
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
# 使用Xavier/Glorot初始化
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
# 嵌入层使用正态分布初始化
nn.init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, RMSNorm):
# RMSNorm的权重初始化为1
if hasattr(module, 'weight'):
nn.init.ones_(module.weight)
2025-05-26 23:09:03 +08:00
2025-07-12 18:00:53 +08:00
# 初始化位置编码相关参数
if hasattr(model.knowledge_dataset, 'keys'):
nn.init.normal_(model.knowledge_dataset.keys, mean=0.0, std=0.02)
Logger("Default model initialization completed")
# 如果提供了预训练的嵌入权重,加载它们
if pretrained_embedding_path:
Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
pretrained_embeddings = torch.load(pretrained_embedding_path)
model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
model.output.weight.data.copy_(pretrained_embeddings) # 共享权重
if database_init_path:
import json
2025-05-29 20:29:45 +08:00
2025-07-12 18:00:53 +08:00
# 数据库参数
knowledge_num = args.knowledge_num
knowledge_length = args.knowledge_length
2025-05-29 20:29:45 +08:00
2025-07-12 18:00:53 +08:00
# 检查是否使用缓存
cache_dir = os.path.dirname(args.cluster_cache_path)
if cache_dir:
os.makedirs(cache_dir, exist_ok=True)
2025-05-29 20:29:45 +08:00
2025-07-12 18:00:53 +08:00
processed_tensor = None
2025-05-26 23:09:03 +08:00
2025-07-12 18:00:53 +08:00
# 尝试加载缓存的处理结果
if not args.recompute_clusters and os.path.exists(args.cluster_cache_path):
try:
Logger(f"Loading cached processed results from {args.cluster_cache_path}")
processed_tensor = torch.load(args.cluster_cache_path)
# 验证缓存文件的形状是否可用
cached_knowledge_num, cached_knowledge_length = processed_tensor.shape
if cached_knowledge_length == knowledge_length:
if cached_knowledge_num >= knowledge_num:
# 缓存足够大,可以截取使用
processed_tensor = processed_tensor[:knowledge_num, :]
Logger(f"Successfully loaded cached data with shape {processed_tensor.shape}")
Logger(f"Truncated from cached shape ({cached_knowledge_num}, {cached_knowledge_length}) to required shape ({knowledge_num}, {knowledge_length})")
Logger("Skipping database initialization - using cached results")
else:
# 缓存太小,需要重新计算
Logger(f"Cached knowledge_num ({cached_knowledge_num}) < required knowledge_num ({knowledge_num}), recomputing...")
processed_tensor = None
else:
# knowledge_length不匹配需要重新计算
Logger(f"Cached knowledge_length ({cached_knowledge_length}) != required knowledge_length ({knowledge_length}), recomputing...")
processed_tensor = None
except Exception as e:
Logger(f"Failed to load cached data: {e}, recomputing...")
processed_tensor = None
2025-05-26 23:09:03 +08:00
2025-07-12 18:00:53 +08:00
# 只有在没有有效缓存时才进行数据库初始化和处理
if processed_tensor is None:
Logger(f"Loading database initialization data from {database_init_path}")
2025-05-26 23:09:03 +08:00
2025-07-12 18:00:53 +08:00
# 1. 加载JSON文件
with open(database_init_path, 'r', encoding='utf-8') as f:
database_data = json.load(f)
2025-05-26 23:09:03 +08:00
2025-07-13 21:28:46 +08:00
sentences_data = []
for data in database_data:
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 保存句子和对应的uuid信息
sentence_info = {
'sentence': data['target'][0]['sentence'],
'uuid': data['target'][0]['uuid'],
'subject': data['target'][0].get('subject', ''),
'predicate': data['target'][0].get('predicate', ''),
'object': data['target'][0].get('object', '')
}
sentences_data.append(sentence_info)
2025-07-13 21:28:46 +08:00
2025-07-12 18:00:53 +08:00
# 提取sentences列表
2025-07-13 21:28:46 +08:00
# sentences_data = database_data.get('sentences', [])
2025-07-12 18:00:53 +08:00
Logger(f"Loaded {len(sentences_data)} sentences from database")
2025-05-26 23:09:03 +08:00
2025-07-12 18:00:53 +08:00
# 2. 按照importance_score进行排序从高到低
2025-07-13 21:28:46 +08:00
try:
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 注意现在sentences_data中的每个元素都是字典不再有importance_score字段
# 如果需要按重要性排序,需要从原始数据中获取该信息
sorted_sentences = sentences_data # 暂时不排序,保持原始顺序
Logger(f"Loaded {len(sorted_sentences)} sentences (no importance_score sorting applied)")
2025-07-13 21:28:46 +08:00
except:
sorted_sentences = sentences_data
2025-07-12 18:00:53 +08:00
# 3. 处理每条数据,不进行聚类
Logger("Processing individual sentences...")
processed_rows = []
# 获取空token的id用于填充
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
# 处理所需数量的句子
num_to_process = min(knowledge_num, len(sorted_sentences))
2025-07-13 21:28:46 +08:00
# 添加截断统计变量
total_sentences = 0
truncated_sentences = 0
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 用于记录映射关系的列表
database_mapping = []
2025-07-12 18:00:53 +08:00
for i in range(num_to_process):
sentence_data = sorted_sentences[i]
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 现在sentence_data是一个字典包含sentence和uuid
sentence = sentence_data['sentence']
uuid = sentence_data['uuid']
2025-07-12 18:00:53 +08:00
# 将句子转换为tokens
sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
# 截断或填充到knowledge_length
2025-07-13 21:28:46 +08:00
total_sentences += 1
2025-07-12 18:00:53 +08:00
if len(sentence_tokens) > knowledge_length:
# 如果超过长度,截断
2025-07-13 21:28:46 +08:00
truncated_sentences += 1
2025-07-12 18:00:53 +08:00
sentence_tokens = sentence_tokens[:knowledge_length]
Logger(f"Sentence {i+1} truncated from {len(tokenizer.encode(sentence, add_special_tokens=False))} to {knowledge_length} tokens")
else:
# 如果不足长度用空token填充
original_length = len(sentence_tokens)
sentence_tokens.extend([pad_token_id] * (knowledge_length - len(sentence_tokens)))
if original_length < knowledge_length:
Logger(f"Sentence {i+1} padded from {original_length} to {knowledge_length} tokens")
processed_rows.append(sentence_tokens)
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 记录映射关系:数据库索引 -> 原始数据信息
mapping_entry = {
'database_index': i, # 在数据库中的索引位置
'uuid': uuid, # 原始uuid
'sentence': sentence, # 原始句子
'subject': sentence_data.get('subject', ''),
'predicate': sentence_data.get('predicate', ''),
'object': sentence_data.get('object', ''),
'token_count': len(sentence_tokens),
'is_truncated': len(tokenizer.encode(sentence, add_special_tokens=False)) > knowledge_length
}
database_mapping.append(mapping_entry)
2025-07-12 18:00:53 +08:00
if (i + 1) % 1000 == 0:
Logger(f"Processed {i + 1}/{num_to_process} sentences")
# 如果句子数量不足用空token填充剩余位置
while len(processed_rows) < knowledge_num:
empty_tokens = [pad_token_id] * knowledge_length
processed_rows.append(empty_tokens)
if len(processed_rows) % 1000 == 0:
Logger(f"Added empty entry {len(processed_rows)}/{knowledge_num}")
Logger(f"Finished adding empty entries. Total: {len(processed_rows)}/{knowledge_num}")
# 转换为tensor
processed_tensor = torch.tensor(processed_rows, dtype=torch.long)
2025-07-13 21:28:46 +08:00
# 计算并打印截断句子的占比
truncation_ratio = truncated_sentences / total_sentences if total_sentences > 0 else 0.0
Logger(f"截断句子统计:")
Logger(f" - 总句子数: {total_sentences}")
Logger(f" - 截断句子数: {truncated_sentences}")
Logger(f" - 截断句子占比: {truncation_ratio:.4f} ({truncation_ratio*100:.2f}%)")
2025-07-12 18:00:53 +08:00
Logger(f"Data processing completed:")
Logger(f" - Processed {num_to_process} sentences")
Logger(f" - Added {knowledge_num - num_to_process} empty entries")
Logger(f" - Final shape: {processed_tensor.shape}")
Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")
# 保存处理结果到缓存文件
try:
torch.save(processed_tensor, args.cluster_cache_path)
Logger(f"Processed results saved to {args.cluster_cache_path}")
except Exception as e:
Logger(f"Failed to save processed results: {e}")
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 保存数据库映射文件
try:
mapping_file_path = args.cluster_cache_path.replace('.pt', '_mapping.json')
mapping_data = {
'metadata': {
'total_entries': len(database_mapping),
'knowledge_num': knowledge_num,
'knowledge_length': knowledge_length,
'source_file': database_init_path,
'generation_time': time.strftime('%Y-%m-%d %H:%M:%S')
},
'mappings': database_mapping
}
with open(mapping_file_path, 'w', encoding='utf-8') as f:
json.dump(mapping_data, f, ensure_ascii=False, indent=2)
Logger(f"Database mapping saved to {mapping_file_path}")
except Exception as e:
Logger(f"Failed to save database mapping: {e}")
2025-05-29 20:29:45 +08:00
2025-07-12 18:00:53 +08:00
# 4. 初始化模型的knowledge_dataset
if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
model.knowledge_dataset.knowledge_dataset.data.copy_(processed_tensor)
Logger("Successfully initialized model.knowledge_dataset.knowledge_dataset with processed data")
else:
Logger("Warning: Could not find model.knowledge_dataset.knowledge_dataset to initialize")
# 存储为全局变量作为备选
globals()['processed_database'] = processed_tensor
Logger(f"Database embeddings and sentences stored in model")
2025-05-26 23:09:03 +08:00
2025-07-12 18:00:53 +08:00
Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
elif args.model_type == "model_original":
Logger(f"Using model type: {args.model_type}")
from model.model_original import MiniMindLM, RMSNorm
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
model = MiniMindLM(lm_config)
Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
2025-07-13 21:28:46 +08:00
elif args.model_type == "model_no_feed":
Logger(f"Using model type: {args.model_type}")
from model.model_no_feed import MiniMindLM, RMSNorm
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
model = MiniMindLM(lm_config)
# 默认模型初始化
Logger("Performing default model initialization...")
# 初始化嵌入层权重
nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02)
# 初始化输出层权重(如果不共享权重的话)
if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight:
nn.init.normal_(model.output.weight, mean=0.0, std=0.02)
# 初始化所有线性层
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
# 使用Xavier/Glorot初始化
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
# 嵌入层使用正态分布初始化
nn.init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, RMSNorm):
# RMSNorm的权重初始化为1
if hasattr(module, 'weight'):
nn.init.ones_(module.weight)
# 初始化位置编码相关参数
if hasattr(model.knowledge_dataset, 'keys'):
nn.init.normal_(model.knowledge_dataset.keys, mean=0.0, std=0.02)
Logger("Default model initialization completed")
# 如果提供了预训练的嵌入权重,加载它们
if pretrained_embedding_path:
Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
pretrained_embeddings = torch.load(pretrained_embedding_path)
model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
model.output.weight.data.copy_(pretrained_embeddings) # 共享权重
if database_init_path:
import json
# 数据库参数
knowledge_num = args.knowledge_num
knowledge_length = args.knowledge_length
# 检查是否使用缓存
cache_dir = os.path.dirname(args.cluster_cache_path)
if cache_dir:
os.makedirs(cache_dir, exist_ok=True)
processed_tensor = None
# 尝试加载缓存的处理结果
if not args.recompute_clusters and os.path.exists(args.cluster_cache_path):
try:
Logger(f"Loading cached processed results from {args.cluster_cache_path}")
processed_tensor = torch.load(args.cluster_cache_path)
# 验证缓存文件的形状是否可用
cached_knowledge_num, cached_knowledge_length = processed_tensor.shape
if cached_knowledge_length == knowledge_length:
if cached_knowledge_num >= knowledge_num:
# 缓存足够大,可以截取使用
processed_tensor = processed_tensor[:knowledge_num, :]
Logger(f"Successfully loaded cached data with shape {processed_tensor.shape}")
Logger(f"Truncated from cached shape ({cached_knowledge_num}, {cached_knowledge_length}) to required shape ({knowledge_num}, {knowledge_length})")
Logger("Skipping database initialization - using cached results")
else:
# 缓存太小,需要重新计算
Logger(f"Cached knowledge_num ({cached_knowledge_num}) < required knowledge_num ({knowledge_num}), recomputing...")
processed_tensor = None
else:
# knowledge_length不匹配需要重新计算
Logger(f"Cached knowledge_length ({cached_knowledge_length}) != required knowledge_length ({knowledge_length}), recomputing...")
processed_tensor = None
except Exception as e:
Logger(f"Failed to load cached data: {e}, recomputing...")
processed_tensor = None
# 只有在没有有效缓存时才进行数据库初始化和处理
if processed_tensor is None:
Logger(f"Loading database initialization data from {database_init_path}")
# 1. 加载JSON文件
with open(database_init_path, 'r', encoding='utf-8') as f:
database_data = json.load(f)
sentences_data = []
for data in database_data:
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 保存句子和对应的uuid信息
sentence_info = {
'sentence': data['target'][0]['sentence'],
'uuid': data['target'][0]['uuid'],
'subject': data['target'][0].get('subject', ''),
'predicate': data['target'][0].get('predicate', ''),
'object': data['target'][0].get('object', '')
}
sentences_data.append(sentence_info)
2025-07-13 21:28:46 +08:00
# 提取sentences列表
# sentences_data = database_data.get('sentences', [])
Logger(f"Loaded {len(sentences_data)} sentences from database")
# 2. 按照importance_score进行排序从高到低
try:
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 注意现在sentences_data中的每个元素都是字典不再有importance_score字段
# 如果需要按重要性排序,需要从原始数据中获取该信息
sorted_sentences = sentences_data # 暂时不排序,保持原始顺序
Logger(f"Loaded {len(sorted_sentences)} sentences (no importance_score sorting applied)")
2025-07-13 21:28:46 +08:00
except:
sorted_sentences = sentences_data
# 3. 处理每条数据,不进行聚类
Logger("Processing individual sentences...")
processed_rows = []
# 获取空token的id用于填充
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
# 处理所需数量的句子
num_to_process = min(knowledge_num, len(sorted_sentences))
# 添加截断统计变量
total_sentences = 0
truncated_sentences = 0
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 用于记录映射关系的列表
database_mapping = []
2025-07-13 21:28:46 +08:00
for i in range(num_to_process):
sentence_data = sorted_sentences[i]
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 现在sentence_data是一个字典包含sentence和uuid
sentence = sentence_data['sentence']
uuid = sentence_data['uuid']
2025-07-13 21:28:46 +08:00
# 将句子转换为tokens
sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
# 截断或填充到knowledge_length
total_sentences += 1
if len(sentence_tokens) > knowledge_length:
# 如果超过长度,截断
truncated_sentences += 1
sentence_tokens = sentence_tokens[:knowledge_length]
Logger(f"Sentence {i+1} truncated from {len(tokenizer.encode(sentence, add_special_tokens=False))} to {knowledge_length} tokens")
else:
# 如果不足长度用空token填充
original_length = len(sentence_tokens)
sentence_tokens.extend([pad_token_id] * (knowledge_length - len(sentence_tokens)))
if original_length < knowledge_length:
Logger(f"Sentence {i+1} padded from {original_length} to {knowledge_length} tokens")
processed_rows.append(sentence_tokens)
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 记录映射关系:数据库索引 -> 原始数据信息
mapping_entry = {
'database_index': i, # 在数据库中的索引位置
'uuid': uuid, # 原始uuid
'sentence': sentence, # 原始句子
'subject': sentence_data.get('subject', ''),
'predicate': sentence_data.get('predicate', ''),
'object': sentence_data.get('object', ''),
'token_count': len(sentence_tokens),
'is_truncated': len(tokenizer.encode(sentence, add_special_tokens=False)) > knowledge_length
}
database_mapping.append(mapping_entry)
2025-07-13 21:28:46 +08:00
if (i + 1) % 1000 == 0:
Logger(f"Processed {i + 1}/{num_to_process} sentences")
# 如果句子数量不足用空token填充剩余位置
while len(processed_rows) < knowledge_num:
empty_tokens = [pad_token_id] * knowledge_length
processed_rows.append(empty_tokens)
if len(processed_rows) % 1000 == 0:
Logger(f"Added empty entry {len(processed_rows)}/{knowledge_num}")
Logger(f"Finished adding empty entries. Total: {len(processed_rows)}/{knowledge_num}")
# 转换为tensor
processed_tensor = torch.tensor(processed_rows, dtype=torch.long)
# 计算并打印截断句子的占比
truncation_ratio = truncated_sentences / total_sentences if total_sentences > 0 else 0.0
Logger(f"截断句子统计:")
Logger(f" - 总句子数: {total_sentences}")
Logger(f" - 截断句子数: {truncated_sentences}")
Logger(f" - 截断句子占比: {truncation_ratio:.4f} ({truncation_ratio*100:.2f}%)")
Logger(f"Data processing completed:")
Logger(f" - Processed {num_to_process} sentences")
Logger(f" - Added {knowledge_num - num_to_process} empty entries")
Logger(f" - Final shape: {processed_tensor.shape}")
Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")
# 保存处理结果到缓存文件
try:
torch.save(processed_tensor, args.cluster_cache_path)
Logger(f"Processed results saved to {args.cluster_cache_path}")
except Exception as e:
Logger(f"Failed to save processed results: {e}")
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 保存数据库映射文件
try:
mapping_file_path = args.cluster_cache_path.replace('.pt', '_mapping.json')
mapping_data = {
'metadata': {
'total_entries': len(database_mapping),
'knowledge_num': knowledge_num,
'knowledge_length': knowledge_length,
'source_file': database_init_path,
'generation_time': time.strftime('%Y-%m-%d %H:%M:%S')
},
'mappings': database_mapping
}
with open(mapping_file_path, 'w', encoding='utf-8') as f:
json.dump(mapping_data, f, ensure_ascii=False, indent=2)
Logger(f"Database mapping saved to {mapping_file_path}")
except Exception as e:
Logger(f"Failed to save database mapping: {e}")
2025-07-13 21:28:46 +08:00
# 4. 初始化模型的knowledge_dataset
if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
model.knowledge_dataset.knowledge_dataset.data.copy_(processed_tensor)
Logger("Successfully initialized model.knowledge_dataset.knowledge_dataset with processed data")
else:
Logger("Warning: Could not find model.knowledge_dataset.knowledge_dataset to initialize")
# 存储为全局变量作为备选
globals()['processed_database'] = processed_tensor
Logger(f"Database embeddings and sentences stored in model")
Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
2025-08-03 14:25:26 +08:00
elif args.model_type == "model_memory":
Logger(f"Using model type: {args.model_type}")
from model.model_memory import MiniMindLM, RMSNorm
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
model = MiniMindLM(lm_config)
# 默认模型初始化
Logger("Performing model_memory initialization...")
# 初始化嵌入层权重
nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02)
# 初始化输出层权重(如果不共享权重的话)
if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight:
nn.init.normal_(model.output.weight, mean=0.0, std=0.02)
# 初始化所有线性层
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
# 使用Xavier/Glorot初始化
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
# 嵌入层使用正态分布初始化
nn.init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, RMSNorm):
# RMSNorm的权重初始化为1
if hasattr(module, 'weight'):
nn.init.ones_(module.weight)
# 记忆库初始化
if database_init_path and os.path.exists(database_init_path):
Logger(f"Initializing memory_bank with text data from {database_init_path}")
import json
# 数据库参数
knowledge_num = args.knowledge_num
knowledge_length = args.knowledge_length
# 缓存文件路径
memory_cache_path = args.cluster_cache_path or f"cache/memory_bank_init_{knowledge_num}_{knowledge_length}.pt"
os.makedirs(os.path.dirname(memory_cache_path) if os.path.dirname(memory_cache_path) else '.', exist_ok=True)
# 检查是否有缓存
if os.path.exists(memory_cache_path):
Logger(f"Loading memory_bank initialization from cache: {memory_cache_path}")
processed_tensor = torch.load(memory_cache_path)
Logger(f"Loaded memory_bank data with shape: {processed_tensor.shape}")
else:
Logger(f"Processing text data from {database_init_path} for memory_bank initialization")
# 加载数据
with open(database_init_path, 'r', encoding='utf-8') as f:
data = json.load(f)
Logger(f"Loaded {len(data)} sentences from {database_init_path}")
# 处理句子到token序列
processed_rows = []
total_sentences = len(data)
truncated_sentences = 0
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 用于记录映射关系的列表
database_mapping = []
# 控制处理的句子数量
num_to_process = min(len(data), knowledge_num)
Logger(f"Processing {num_to_process} out of {total_sentences} sentences")
# 处理句子到token ID序列
for idx, item in enumerate(data[:num_to_process]):
if idx % 1000 == 0:
Logger(f"Processing sentence {idx+1}/{num_to_process}")
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 获取句子文本和uuid
if isinstance(item, dict):
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 如果是字典格式尝试提取target数组中的数据
if 'target' in item and len(item['target']) > 0:
sentence = item['target'][0].get('sentence', '')
uuid = item['target'][0].get('uuid', '')
subject = item['target'][0].get('subject', '')
predicate = item['target'][0].get('predicate', '')
object_name = item['target'][0].get('object', '')
else:
sentence = item.get('sentence', '') or item.get('text', '') or str(item)
uuid = item.get('uuid', '')
subject = item.get('subject', '')
predicate = item.get('predicate', '')
object_name = item.get('object', '')
else:
sentence = str(item)
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
uuid = ''
subject = ''
predicate = ''
object_name = ''
# 使用tokenizer编码句子
try:
tokens = tokenizer(
sentence,
add_special_tokens=True,
truncation=True,
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
max_length=len(sentence),
padding=False,
return_tensors="pt"
)['input_ids'].squeeze().tolist()
# 确保是列表
if not isinstance(tokens, list):
tokens = [tokens]
# 检查长度
if len(tokens) > knowledge_length:
tokens = tokens[:knowledge_length]
truncated_sentences += 1
elif len(tokens) < knowledge_length:
# 用padding token填充
tokens.extend([pad_token_id] * (knowledge_length - len(tokens)))
processed_rows.append(tokens)
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 记录映射关系:数据库索引 -> 原始数据信息
mapping_entry = {
'database_index': idx, # 在数据库中的索引位置
'uuid': uuid, # 原始uuid
'sentence': sentence, # 原始句子
'subject': subject,
'predicate': predicate,
'object': object_name,
'token_count': len(tokens),
'is_truncated': len(tokens) > knowledge_length
}
database_mapping.append(mapping_entry)
except Exception as e:
Logger(f"Error processing sentence {idx}: {e}")
# 使用空tokens作为fallback
empty_tokens = [pad_token_id] * knowledge_length
processed_rows.append(empty_tokens)
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 为失败的句子也记录映射关系
mapping_entry = {
'database_index': idx,
'uuid': uuid,
'sentence': sentence,
'subject': subject,
'predicate': predicate,
'object': object_name,
'token_count': knowledge_length,
'is_truncated': False,
'processing_error': str(e)
}
database_mapping.append(mapping_entry)
# 如果句子数量不足用空token填充剩余位置
while len(processed_rows) < knowledge_num:
empty_tokens = [pad_token_id] * knowledge_length
processed_rows.append(empty_tokens)
if len(processed_rows) % 1000 == 0:
Logger(f"Added empty entry {len(processed_rows)}/{knowledge_num}")
# 转换为tensor
processed_tensor = torch.tensor(processed_rows, dtype=torch.long)
# 计算并打印截断句子的占比
truncation_ratio = truncated_sentences / total_sentences if total_sentences > 0 else 0.0
Logger(f"截断句子统计:")
Logger(f" - 总句子数: {total_sentences}")
Logger(f" - 截断句子数: {truncated_sentences}")
Logger(f" - 截断句子占比: {truncation_ratio:.4f} ({truncation_ratio*100:.2f}%)")
Logger(f"Memory_bank data processing completed:")
Logger(f" - Processed {num_to_process} sentences")
Logger(f" - Added {knowledge_num - num_to_process} empty entries")
Logger(f" - Final shape: {processed_tensor.shape}")
Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")
# 保存处理结果到缓存文件
try:
torch.save(processed_tensor, memory_cache_path)
Logger(f"Processed results saved to {memory_cache_path}")
except Exception as e:
Logger(f"Failed to save processed results: {e}")
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 保存数据库映射文件
try:
mapping_file_path = memory_cache_path.replace('.pt', '_mapping.json')
mapping_data = {
'metadata': {
'total_entries': len(database_mapping),
'knowledge_num': knowledge_num,
'knowledge_length': knowledge_length,
'source_file': database_init_path,
'generation_time': time.strftime('%Y-%m-%d %H:%M:%S')
},
'mappings': database_mapping
}
with open(mapping_file_path, 'w', encoding='utf-8') as f:
json.dump(mapping_data, f, ensure_ascii=False, indent=2)
Logger(f"Database mapping saved to {mapping_file_path}")
except Exception as e:
Logger(f"Failed to save database mapping: {e}")
# 初始化模型的memory_bank
if hasattr(model, 'memory_bank'):
model.memory_bank.data.copy_(processed_tensor)
Logger("Successfully initialized memory_bank with processed text data")
else:
Logger("Warning: Could not find memory_bank to initialize")
else:
Logger(f"Memory bank initialized with random values, shape: {model.memory_bank.shape}")
2025-08-03 14:25:26 +08:00
Logger("Model_memory initialization completed")
Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
2025-07-12 18:00:53 +08:00
2025-05-14 00:01:40 +08:00
return model, tokenizer
2025-08-07 11:43:23 +08:00
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer, val_loader=None):
2025-05-14 00:01:40 +08:00
loss_fct = nn.CrossEntropyLoss(reduction='none')
epoch_start_time = time.time()
total_steps_in_epoch = len(train_loader)
total_training_steps = args.epochs * total_steps_in_epoch
moe_path = '_moe' if args.use_moe else ''
2025-06-08 02:20:36 +00:00
best_loss = float('10000')
2025-05-14 00:01:40 +08:00
# 初始化CUDA事件变量
data_start = data_end = forward_start = forward_end = None
backward_start = backward_end = optimizer_start = optimizer_end = None
2025-05-14 00:01:40 +08:00
# 添加CUDA事件来分析性能 (只在主进程进行)
if args.profile and accelerator.is_main_process:
data_start = torch.cuda.Event(enable_timing=True)
data_end = torch.cuda.Event(enable_timing=True)
forward_start = torch.cuda.Event(enable_timing=True)
forward_end = torch.cuda.Event(enable_timing=True)
backward_start = torch.cuda.Event(enable_timing=True)
backward_end = torch.cuda.Event(enable_timing=True)
optimizer_start = torch.cuda.Event(enable_timing=True)
optimizer_end = torch.cuda.Event(enable_timing=True)
# 预取数据
2025-07-13 21:28:46 +08:00
prefetch_factor = 8 # 预取的批次数
2025-05-14 00:01:40 +08:00
data_iter = iter(train_loader)
prefetch_batches = []
# 记录初始内存状态
if args.memory_monitor:
log_memory_status(-1, prefetch_batches, accelerator, "before_prefetch", detailed=True)
2025-05-14 00:01:40 +08:00
# 预取初始批次
for i in range(min(prefetch_factor, len(train_loader))):
2025-05-14 00:01:40 +08:00
try:
batch = next(data_iter)
prefetch_batches.append(batch)
# 每次添加batch后记录内存变化
if args.memory_monitor and accelerator.is_main_process:
log_memory_status(-1, prefetch_batches, accelerator, f"after_adding_batch_{i+1}")
2025-05-14 00:01:40 +08:00
except StopIteration:
break
# 记录预取完成后的内存状态
if args.memory_monitor:
log_memory_status(-1, prefetch_batches, accelerator, "after_initial_prefetch", detailed=True)
2025-05-14 00:01:40 +08:00
# 在开始循环前初始化日志记录所需变量
last_log_time = epoch_start_time
for step in range(total_steps_in_epoch):
try:
# 计时数据加载 (只在主进程进行)
if args.profile and accelerator.is_main_process and data_start is not None:
2025-05-14 00:01:40 +08:00
data_start.record()
# 记录使用batch前的内存状态根据配置间隔记录详细信息
if args.memory_monitor and step % args.memory_monitor_interval == 0:
log_memory_status(step, prefetch_batches, accelerator, "before_use_batch", detailed=True)
2025-05-14 00:01:40 +08:00
# 使用预取的数据
if prefetch_batches:
X, Y, loss_mask = prefetch_batches.pop(0)
# 记录使用batch后的内存变化
if args.memory_monitor and step % args.memory_monitor_interval == 0:
log_memory_status(step, prefetch_batches, accelerator, "after_pop_batch")
2025-05-14 00:01:40 +08:00
else:
# 如果预取队列为空,直接加载
X, Y, loss_mask = next(data_iter)
if args.memory_monitor and accelerator.is_main_process:
Logger(f"[Memory Monitor] Step {step} - Prefetch queue empty, loading directly!", accelerator)
2025-05-14 00:01:40 +08:00
# 异步预取下一批数据
if step + prefetch_factor < len(train_loader):
try:
batch = next(data_iter)
prefetch_batches.append(batch)
# 记录添加新batch后的内存变化
if args.memory_monitor and step % args.memory_monitor_interval == 0:
log_memory_status(step, prefetch_batches, accelerator, "after_add_batch")
2025-05-14 00:01:40 +08:00
except StopIteration:
pass
# 计时数据加载结束 (只在主进程进行)
if args.profile and accelerator.is_main_process and data_end is not None:
2025-05-14 00:01:40 +08:00
data_end.record()
# 更新学习率
if scheduler is not None:
scheduler.step()
# 计时前向传播 (只在主进程进行)
if args.profile and accelerator.is_main_process and forward_start is not None:
2025-05-14 00:01:40 +08:00
forward_start.record()
# 前向传播
with ctx:
2025-06-08 02:20:36 +00:00
if step == 0 and args.embedding_epoch == epoch:
# 需要设置原始模型的freeze_embedding属性而不是包装后的模型
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.freeze_embedding = True
Logger(f"Set freeze_embedding=True for epoch {epoch}, step {step}", accelerator)
res = model(X, step=step)
2025-08-07 11:43:23 +08:00
# 计算主要损失(交叉熵损失)
ce_loss = loss_fct(
2025-05-14 00:01:40 +08:00
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
).view(Y.size())
2025-08-07 11:43:23 +08:00
ce_loss = (ce_loss * loss_mask).sum() / loss_mask.sum()
2025-09-06 12:12:08 +08:00
# 🔥 实验1.4.9: 四损失系统处理
2025-08-07 11:43:23 +08:00
balance_loss = 0
2025-09-06 12:12:08 +08:00
similarity_loss = 0
diversity_loss = 0
2025-08-07 11:43:23 +08:00
if hasattr(res, 'aux_loss') and res.aux_loss is not None:
2025-09-06 12:12:08 +08:00
aux_loss = res.aux_loss
if isinstance(aux_loss, dict):
# 新的四损失结构
balance_loss = aux_loss.get('balance_loss', 0)
similarity_loss = aux_loss.get('similarity_loss', 0)
diversity_loss = aux_loss.get('diversity_loss', 0)
else:
# 向后兼容旧的单一aux_loss
balance_loss = aux_loss
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
# 获取余弦相似度统计信息(如果模型支持)
cosine_stats = {}
2025-09-06 12:12:08 +08:00
avg_selected_similarity = 0.0
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
if hasattr(res, 'cosine_stats') and res.cosine_stats is not None:
cosine_stats = res.cosine_stats
2025-09-06 12:12:08 +08:00
# 🔥 使用选中记忆的平均相似度(更精确的指标)
selected_similarities = [v for k, v in cosine_stats.items() if k.endswith('_selected_avg_similarity')]
if selected_similarities:
avg_selected_similarity = np.mean(selected_similarities)
# 🔥 四损失系统CE + Balance + Similarity + Diversity
# 损失系数可以通过命令行参数调整
balance_coef = getattr(args, 'balance_loss_coef', 0.01)
similarity_coef = getattr(args, 'similarity_loss_coef', 0.1)
diversity_coef = getattr(args, 'diversity_loss_coef', 0.05)
2025-08-07 11:43:23 +08:00
2025-09-06 12:12:08 +08:00
total_loss = (ce_loss +
balance_coef * balance_loss +
similarity_coef * similarity_loss +
diversity_coef * diversity_loss)
2025-08-07 11:43:23 +08:00
loss = total_loss / args.accumulation_steps
2025-05-14 00:01:40 +08:00
# 计时前向传播结束 (只在主进程进行)
if args.profile and accelerator.is_main_process and forward_end is not None:
2025-05-14 00:01:40 +08:00
forward_end.record()
# 计时反向传播 (只在主进程进行)
if args.profile and accelerator.is_main_process and backward_start is not None:
2025-05-14 00:01:40 +08:00
backward_start.record()
# 反向传播
# 当使用DeepSpeed时它会自动处理梯度累积和梯度裁剪
accelerator.backward(loss)
# 计时反向传播结束 (只在主进程进行)
if args.profile and accelerator.is_main_process and backward_end is not None:
2025-05-14 00:01:40 +08:00
backward_end.record()
# 计时优化器步骤 (只在主进程进行)
if args.profile and accelerator.is_main_process and optimizer_start is not None:
2025-05-14 00:01:40 +08:00
optimizer_start.record()
# 优化器步骤 - 当使用DeepSpeed时它会自动处理梯度累积和梯度裁剪
# 只有在达到累积步数时才会执行优化器步骤
# 注意当使用DeepSpeed时它会自动处理梯度累积所以我们不需要检查step % accumulation_steps
optimizer.step()
# 当使用DeepSpeed时zero_grad()会在step()之后自动调用
# 但为了安全起见,我们仍然显式调用它
optimizer.zero_grad()
# VQ-VAE风格的EMA更新仅在启用时执行
if hasattr(res, 'ema_stats') and res.ema_stats is not None:
unwrapped_model = accelerator.unwrap_model(model)
if hasattr(unwrapped_model, 'apply_ema_update'):
ema_update_stats = unwrapped_model.apply_ema_update(res.ema_stats)
# 记录EMA更新统计信息
if step % args.log_interval == 0 and accelerator.is_main_process and ema_update_stats.get('ema_update_applied', False):
total_memories = args.knowledge_num
Logger(f"EMA Update - Step: {ema_update_stats['ema_step']}, "
f"Updated memories: {ema_update_stats['updated_memories']}/{total_memories} "
f"({ema_update_stats['update_ratio']:.4f}), "
f"Coverage: {ema_update_stats['selected_memory_coverage']:.4f}", accelerator)
2025-05-14 00:01:40 +08:00
# 计时优化器步骤结束 (只在主进程进行)
if args.profile and accelerator.is_main_process and optimizer_end is not None:
2025-05-14 00:01:40 +08:00
optimizer_end.record()
2025-08-07 11:43:23 +08:00
# 验证评估和日志记录 (只在主进程进行)
if (step + 1) % args.val_interval == 0 and accelerator.is_main_process:
2025-05-14 00:01:40 +08:00
current_time = time.time()
# 记录日志输出时的详细内存状态
if args.memory_monitor:
log_memory_status(step, prefetch_batches, accelerator, "at_log_interval", detailed=True)
# 强制垃圾回收并记录内存变化
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
log_memory_status(step, prefetch_batches, accelerator, "after_gc", detailed=True)
2025-05-14 00:01:40 +08:00
# 计算性能指标
if args.profile and accelerator.is_main_process:
2025-05-14 00:01:40 +08:00
torch.cuda.synchronize()
# 确保所有事件都已记录才计算elapsed_time
try:
data_time = data_start.elapsed_time(data_end) if data_start is not None and data_end is not None else 0
forward_time = forward_start.elapsed_time(forward_end) if forward_start is not None and forward_end is not None else 0
backward_time = backward_start.elapsed_time(backward_end) if backward_start is not None and backward_end is not None else 0
optimizer_time = optimizer_start.elapsed_time(optimizer_end) if optimizer_start is not None and optimizer_end is not None else 0
iter_time = (current_time - last_log_time) * 1000 / args.log_interval # avg ms per iteration since last log
# total_time_ms = data_time + forward_time + backward_time + optimizer_time
# 打印性能分析
if (step + 1) % (args.log_interval * args.profile_interval) == 0:
Logger(f"性能分析 (Avg/iter over last {args.log_interval} steps) - "
f"Data: {data_time/args.log_interval:.2f}ms, "
f"Fwd: {forward_time/args.log_interval:.2f}ms, "
f"Bwd: {backward_time/args.log_interval:.2f}ms, "
f"Optim: {optimizer_time/args.log_interval:.2f}ms, "
f"Iter Time: {iter_time:.2f}ms", accelerator)
# 生成文本示例
try:
# 随机选择一个样本
random_idx = torch.randint(0, X.size(0), (1,)).item()
sample_input = X[random_idx:random_idx+1] # [1, seq_len]
2025-07-17 12:06:28 +08:00
sample_target = Y[random_idx:random_idx+1] # [1, seq_len]
2025-07-17 12:06:28 +08:00
# 取前面的部分作为prompt确保后面有10个token作为真实值
prompt_len = sample_input.size(1) // 2
prompt_input = sample_input[:, :prompt_len]
2025-07-17 12:06:28 +08:00
# 获取真实的后10个token
true_next_tokens = sample_target[:, prompt_len-1:prompt_len-1+10] # 真实的接下来10个token
# 生成10个token
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.eval() # 设置为评估模式
with torch.no_grad():
generated = unwrapped_model.generate(
prompt_input,
max_new_tokens=10,
temperature=0.7,
top_p=0.9,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id
)
# 转换为人类可读文本
prompt_text = tokenizer.decode(prompt_input[0], skip_special_tokens=True)
2025-07-17 12:06:28 +08:00
true_text = tokenizer.decode(true_next_tokens[0], skip_special_tokens=True)
# 获取新生成的token
prompt_tokens = prompt_input[0].tolist()
generated_tokens = generated[0].tolist()
if len(generated_tokens) > len(prompt_tokens):
new_tokens = generated_tokens[len(prompt_tokens):len(prompt_tokens)+10] # 只取前10个
generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
else:
generated_text = "[未生成新token]"
2025-07-17 12:06:28 +08:00
Logger(f"文本生成对比:", accelerator)
Logger(f" 输入提示: {prompt_text}", accelerator)
Logger(f" 真实续写: {true_text}", accelerator)
Logger(f" 模型生成: {generated_text}", accelerator)
unwrapped_model.train() # 恢复训练模式
except Exception as e:
Logger(f"生成文本示例失败: {e}", accelerator)
# 重置事件以便下次测量从0开始
data_start = torch.cuda.Event(enable_timing=True)
data_end = torch.cuda.Event(enable_timing=True)
forward_start = torch.cuda.Event(enable_timing=True)
forward_end = torch.cuda.Event(enable_timing=True)
backward_start = torch.cuda.Event(enable_timing=True)
backward_end = torch.cuda.Event(enable_timing=True)
optimizer_start = torch.cuda.Event(enable_timing=True)
optimizer_end = torch.cuda.Event(enable_timing=True)
except RuntimeError as e:
if "Both events must be recorded" in str(e):
Logger(f"Warning: CUDA events not properly recorded, skipping performance analysis: {e}", accelerator)
else:
raise e
2025-05-14 00:01:40 +08:00
# 计算当前学习率
current_lr = optimizer.param_groups[0]['lr']
# 计算时间
epoch_elapsed_time = current_time - epoch_start_time
epoch_steps_done = step + 1
epoch_avg_step_time = epoch_elapsed_time / epoch_steps_done
epoch_remaining_time = epoch_avg_step_time * (total_steps_in_epoch - epoch_steps_done)
total_elapsed_time = current_time - overall_start_time
total_steps_done = epoch * total_steps_in_epoch + epoch_steps_done
total_avg_step_time = total_elapsed_time / total_steps_done if total_steps_done > 0 else 0
total_remaining_time = total_avg_step_time * (total_training_steps - total_steps_done) if total_steps_done > 0 else 0
# 计算训练速度 (基于最近的log_interval)
interval_elapsed_time = current_time - last_log_time
tokens_processed_interval = args.log_interval * args.batch_size * args.max_seq_len
tokens_per_sec = tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0
last_log_time = current_time # 更新上次日志时间
2025-08-07 11:43:23 +08:00
# 执行验证评估
val_loss = None
if val_loader is not None:
try:
val_loss = validate_model(model, val_loader, loss_fct, ctx, accelerator)
Logger(f"验证损失: {val_loss:.4f}", accelerator)
except Exception as e:
Logger(f"验证评估失败: {e}", accelerator)
val_loss = None
# 获取记忆库更新统计(如果模型支持)
memory_update_stats = {}
if hasattr(model, 'get_memory_update_stats'):
try:
unwrapped_model = accelerator.unwrap_model(model)
if hasattr(unwrapped_model, 'get_memory_update_stats'):
memory_update_stats = unwrapped_model.get_memory_update_stats()
except Exception as e:
Logger(f"获取记忆更新统计失败: {e}", accelerator)
# 获取层级统计信息(如果模型支持)
layer_stats = {}
if hasattr(res, 'layer_stats') and res.layer_stats is not None:
layer_stats = res.layer_stats
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
2025-08-07 11:43:23 +08:00
2025-09-06 12:12:08 +08:00
# 🔥 构建四损失系统的日志字典
2025-05-14 00:42:50 +08:00
log_dict = {
"epoch": epoch + 1,
"step": step + 1,
"total_steps_in_epoch": total_steps_in_epoch,
2025-08-07 11:43:23 +08:00
"train/loss_ce": ce_loss.item(),
"train/loss_balance": balance_loss.item() if isinstance(balance_loss, torch.Tensor) else balance_loss,
2025-09-06 12:12:08 +08:00
"train/loss_similarity": similarity_loss.item() if isinstance(similarity_loss, torch.Tensor) else similarity_loss,
"train/loss_diversity": diversity_loss.item() if isinstance(diversity_loss, torch.Tensor) else diversity_loss,
2025-08-07 11:43:23 +08:00
"train/loss_total": total_loss.item(),
2025-05-14 00:42:50 +08:00
"lr": current_lr,
"tokens_per_sec": tokens_per_sec,
"epoch_time_left_seconds": epoch_remaining_time,
"total_time_left_seconds": total_remaining_time
}
2025-08-07 11:43:23 +08:00
# 添加验证损失
if val_loss is not None:
log_dict["val/loss"] = val_loss
# 添加记忆库更新统计
log_dict.update(memory_update_stats)
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
2025-08-07 11:43:23 +08:00
# 添加层级统计信息(选择性添加关键指标)
if layer_stats:
# 计算所有层的平均统计
avg_gini = np.mean([v for k, v in layer_stats.items() if k.endswith('_gini_coefficient')])
avg_coverage = np.mean([v for k, v in layer_stats.items() if k.endswith('_coverage_rate')])
total_dead = sum([v for k, v in layer_stats.items() if k.endswith('_dead_memories')])
total_hot = sum([v for k, v in layer_stats.items() if k.endswith('_hot_memories')])
log_dict.update({
'memory/avg_gini_coefficient': avg_gini,
'memory/avg_coverage_rate': avg_coverage,
'memory/total_dead_memories': total_dead,
'memory/total_hot_memories': total_hot,
2025-09-06 12:12:08 +08:00
'train/avg_selected_similarity': avg_selected_similarity, # 🔥 使用选中记忆的相似度
2025-08-07 11:43:23 +08:00
})
2025-09-06 12:12:08 +08:00
# 🔥 四损失系统的控制台输出
2025-05-14 00:01:40 +08:00
Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
2025-09-06 12:12:08 +08:00
f"CE: {log_dict['train/loss_ce']:.4f}, "
f"Bal: {log_dict['train/loss_balance']:.4f}, "
f"Sim: {log_dict['train/loss_similarity']:.4f}, "
f"Div: {log_dict['train/loss_diversity']:.4f}, "
f"Total: {log_dict['train/loss_total']:.4f}, "
f"Val: {log_dict.get('val/loss', 'N/A')}, "
2025-05-14 00:42:50 +08:00
f"LR: {log_dict['lr']:.6f}, "
f"Speed: {log_dict['tokens_per_sec']:.2f} tokens/sec | "
2025-09-06 12:12:08 +08:00
f"Sel.Sim: {avg_selected_similarity:.4f} | "
2025-05-14 00:01:40 +08:00
f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
f"Total Time Left: {format_time(total_remaining_time)}", accelerator)
if args.use_swanlab and accelerator.is_main_process and swanlab_run:
swanlab_run.log(log_dict)
2025-05-14 00:42:50 +08:00
2025-05-14 00:01:40 +08:00
# 保存模型 (只在主进程进行)
2025-06-08 02:20:36 +00:00
loss_total = loss.item() * args.accumulation_steps
2025-08-07 11:43:23 +08:00
if epoch >= 0 and best_loss > loss_total and accelerator.is_main_process:
2025-06-08 02:20:36 +00:00
best_loss = loss_total
2025-05-14 00:01:40 +08:00
# 使用函数开始处定义的moe_path变量
ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
# 获取解包后的模型
unwrapped_model = accelerator.unwrap_model(model)
# 保存模型参数
accelerator.save(unwrapped_model.state_dict(), ckp)
Logger(f"Model saved to {ckp}", accelerator)
except Exception as e:
Logger(f"Error in training step: {e}", accelerator)
# 记录异常时的内存状态
if args.memory_monitor:
log_memory_status(step, prefetch_batches, accelerator, "at_exception", detailed=True)
2025-05-14 00:01:40 +08:00
import traceback
Logger(traceback.format_exc(), accelerator)
# 清理prefetch_batches防止内存泄漏
if args.memory_monitor and accelerator.is_main_process:
Logger(f"[Memory Monitor] Clearing prefetch_batches due to exception. Current length: {len(prefetch_batches)}", accelerator)
prefetch_batches.clear()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if args.memory_monitor:
log_memory_status(step, prefetch_batches, accelerator, "after_exception_cleanup", detailed=True)
# 训练epoch结束时清理prefetch_batches
if args.memory_monitor:
if accelerator.is_main_process:
Logger(f"[Memory Monitor] Epoch {epoch+1} finished. Clearing prefetch_batches. Final length: {len(prefetch_batches)}", accelerator)
log_memory_status(total_steps_in_epoch-1, prefetch_batches, accelerator, "before_epoch_end_cleanup", detailed=True)
prefetch_batches.clear()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if args.memory_monitor:
log_memory_status(total_steps_in_epoch-1, prefetch_batches, accelerator, "after_epoch_end_cleanup", detailed=True)
2025-05-14 00:01:40 +08:00
def main():
parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
parser.add_argument("--out_dir", type=str, default="out")
parser.add_argument("--epochs", type=int, default=4)
2025-06-08 02:20:36 +00:00
parser.add_argument("--embedding_epoch", type=int, default=2, help="embedding训练的epoch数")
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
parser.add_argument("--batch_size", type=int, default=20)
2025-05-14 00:01:40 +08:00
parser.add_argument("--learning_rate", type=float, default=2e-4)
parser.add_argument("--dtype", type=str, default="bfloat16")
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
parser.add_argument("--use_swanlab", default=False, action="store_true") # 替换wandb参数
parser.add_argument("--swanlab_project", type=str, default="MiniMind-Pretrain") # 替换wandb参数
parser.add_argument("--num_workers", type=int, default=1)
2025-05-14 00:01:40 +08:00
parser.add_argument("--accumulation_steps", type=int, default=32)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--warmup_iters", type=int, default=0)
2025-07-17 12:06:28 +08:00
parser.add_argument("--log_interval", type=int, default=1)
2025-05-14 00:01:40 +08:00
parser.add_argument("--save_interval", type=int, default=10000)
parser.add_argument('--dim', default=512, type=int)
parser.add_argument('--n_layers', default=8, type=int)
2025-08-01 15:54:21 +08:00
parser.add_argument('--n_heads', default=32, type=int)
parser.add_argument('--max_seq_len', default=512, type=int)
2025-05-14 00:01:40 +08:00
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能使用固定值1e-4替代")
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
parser.add_argument("--data_path", type=str, default="/home/iomgaa/Code/Minimind/dataset/stable/merged_pretrain.jsonl")
2025-05-14 00:01:40 +08:00
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
parser.add_argument("--knowledge_num", type=int, default=960400,help="知识库的数据数目")
parser.add_argument("--knowledge_length", type=int, default=8,help="知识库的句子长度")
2025-08-03 14:25:26 +08:00
parser.add_argument("--knowledge_dim", type=int, default=128,help="知识库的向量维度")
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
parser.add_argument("--database_init_path", type=str, default="/home/iomgaa/Code/Minimind/dataset/stable/sentence_trex_data.json", help="数据库初始化路径")
2025-05-26 23:09:03 +08:00
parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")
2025-09-06 12:12:08 +08:00
parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="聚类结果缓存文件路径")
2025-05-29 20:29:45 +08:00
parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件")
parser.add_argument("--memory_monitor", action="store_true", default=False, help="启用内存监控")
parser.add_argument("--memory_monitor_interval", type=int, default=10, help="内存监控间隔(步数)")
parser.add_argument("--model_type", type=str, default="model_memory", help="使用什么模型训练") #model,model_original,model_no_feed
2025-07-13 21:28:46 +08:00
parser.add_argument("--model_size", type=float, default=50.0, help="模型大小")
parser.add_argument("--swanlab_online", type=bool, default=False, help="是否使用在线SwanLab服务")
2025-08-07 11:43:23 +08:00
parser.add_argument("--balance_loss_coef", type=float, default=0.01, help="平衡损失系数")
2025-09-06 12:12:08 +08:00
parser.add_argument("--similarity_loss_coef", type=float, default=0.1, help="相似度损失系数实验1.4.9")
parser.add_argument("--diversity_loss_coef", type=float, default=0.05, help="多样性损失系数实验1.4.9")
Experiment 1.4.9: Memory Bank优化 - 顺序冻结 + 相似度Loss + 维度修复 🔬 实验基础: 基于实验1.4.7的重要改进 🎯 研究目标: 提升Memory Bank的知识保护和检索准确性 🚀 三大核心创新: 1️⃣ 智能冻结策略改进 • 从随机冻结 → 顺序冻结前20%记忆条目 • 保护重要知识: 假设前面的记忆条目更重要,需要优先保护 • freeze_ratio=0.2: 冻结前20%的memory_bank条目 2️⃣ 查询-知识相似度Loss • 新增相似度监督信号: 衡量查询向量与选中知识的匹配度 • 余弦相似度计算: F.cosine_similarity(query, selected_memory) • 相似度统计: 平均值、最大值、最小值、标准差全方位监控 3️⃣ 维度截断问题修复 • 统一维度处理: knowledge_dim → dim,避免信息截断 • concat_dim修正: dim + num_selected * dim (之前是knowledge_dim) • 记忆向量完整保留: 解决查询结果维度被不当压缩的问题 🏗️ 架构优化细节: • GatedMemoryFusion维度一致性: 统一使用dim维度 • 记忆池化策略: 使用平均池化压缩knowledge_length维度 • 残差连接增强: 改进memory_output与主路径的融合 📊 实验配置: • experiment_1_4_9-02: 8层网络完整测试 • experiment_1_4_9-04: 1层网络最小验证 • EMA更新机制: decay=0.9, update_freq=5 • 数据库初始化: sentence_trex_data.json文本数据 💡 技术假设: 顺序冻结策略能更好地保护重要知识,相似度Loss能提升检索精度, 维度统一能减少信息丢失,三者结合将显著改善Memory Bank性能。 🛠️ 基础设施改进: • UUID映射系统: 跟踪记忆条目的原始数据源 • 增强缓存机制: 支持映射文件自动生成 • 监控系统升级: 相似度统计信息实时追踪 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-05 14:24:48 +08:00
parser.add_argument("--val_data_path", type=str, default="/home/zym/Code/stable/eval_data.json", help="验证数据集路径")
2025-08-07 11:43:23 +08:00
parser.add_argument("--val_interval", type=int, default=100, help="验证评估间隔")
2025-09-06 12:12:08 +08:00
parser.add_argument("--freeze_ratio", type=float, default=0.2, help="冻结率")
2025-05-14 00:01:40 +08:00
args = parser.parse_args()
2025-05-14 00:01:40 +08:00
#########################################################
# 初始化accelerator和deepspeed
#########################################################
# 设置ddp_kwargs以处理未使用的参数
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# 创建DeepSpeedPlugin对象
ds_plugin = DeepSpeedPlugin(
gradient_accumulation_steps=args.accumulation_steps,
gradient_clipping=args.grad_clip,
zero_stage=2, # 使用ZeRO-2优化
offload_optimizer_device="none", # 将优化器状态卸载到CPU
2025-05-14 00:01:40 +08:00
offload_param_device="none", # 不将参数卸载到CPU
)
accelerator = Accelerator(
kwargs_handlers=[ddp_kwargs],
deepspeed_plugin=ds_plugin,
mixed_precision="bf16" if args.dtype == "bfloat16" else "fp16" if args.dtype == "float16" else "no"
)
#########################################################
# 设置随机种子
#########################################################
set_seed(1337 + accelerator.process_index)
#########################################################
# 配置模型
#########################################################
lm_config = LMConfig(
dim=args.dim,
n_layers=args.n_layers,
2025-09-06 17:25:46 +08:00
n_heads=args.n_heads,
2025-05-14 00:01:40 +08:00
max_seq_len=args.max_seq_len,
use_moe=args.use_moe,
disable_db=args.disable_db,
flash_attn=args.use_flash_attn,
2025-05-16 08:38:59 +00:00
knowledge_num=args.knowledge_num,
2025-06-08 02:20:36 +00:00
knowledge_length=args.knowledge_length,
2025-09-06 17:25:46 +08:00
knowledge_dim=args.knowledge_dim,
2025-09-06 12:12:08 +08:00
embeddings_epoch=args.embedding_epoch,
freeze_ratio=args.freeze_ratio
2025-05-14 00:01:40 +08:00
)
#########################################################
# 创建保存目录
#########################################################
args.save_dir = os.path.join(args.out_dir)
if accelerator.is_main_process:
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)
#########################################################
# 设置数据类型
#########################################################
pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
#########################################################
# 配置SwanLab
2025-05-14 00:01:40 +08:00
#########################################################
# 设置SwanLab运行名称
args.swanlab_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
# 合并args和lm_config为一个字典无论是否使用SwanLab都需要用于打印配置信息
config_dict = vars(args).copy()
config_dict.update(vars(lm_config))
# 初始化SwanLab实验实例
swanlab_run = None
if args.use_swanlab and accelerator.is_main_process:
2025-07-13 21:28:46 +08:00
if args.swanlab_online:
# 使用在线SwanLab服务
# 初始化SwanLab
swanlab_run = swanlab.init(
project=args.swanlab_project,
experiment_name=args.swanlab_run_name,
description="MiniMind预训练实验使用本地部署的SwanLab进行可视化",
config=config_dict
)
else:
swanlab_run = swanlab.init(
project=args.swanlab_project,
experiment_name=args.swanlab_run_name,
description="MiniMind预训练实验使用本地部署的SwanLab进行可视化",
config=config_dict,
mode="offline"
)
2025-05-14 00:01:40 +08:00
else:
swanlab_run = None
2025-05-14 00:01:40 +08:00
#########################################################
# 打印信息
#########################################################
# 计算每次迭代的token数量
tokens_per_iter = args.batch_size * lm_config.max_seq_len
if accelerator.is_main_process:
Logger(f"tokens_per_iter: {tokens_per_iter}", accelerator)
Logger("Configuration:", accelerator)
for key, value in config_dict.items():
Logger(f" {key}: {value}", accelerator)
#########################################################
# 设置自动混合精度上下文
#########################################################
ctx = nullcontext() if accelerator.device.type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)
#########################################################
# 初始化模型和tokenizer
#########################################################
2025-05-26 23:09:03 +08:00
model, tokenizer = init_model(lm_config, args.pretrained_embedding_path, args.database_init_path, args)
2025-05-14 00:01:40 +08:00
# 将accelerator传递给init_model函数中的Logger调用
Logger(f'模型初始化完成', accelerator)
#########################################################
# 处理位置编码张量问题
#########################################################
if hasattr(model, "pos_cis_real"):
Logger(f'检测到pos_cis_real实数张量将其设置为参与分布式训练', accelerator)
# 设置模型的_ddp_params_and_buffers_to_ignore属性
# model._ddp_params_and_buffers_to_ignore = {"pos_cis_real"}
# 兼容旧版本检查是否仍有pos_cis
elif hasattr(model, "pos_cis"):
Logger(f'检测到pos_cis复数张量将其设置为不参与分布式训练', accelerator)
# 设置模型的_ddp_params_and_buffers_to_ignore属性
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
#########################################################
# 创建数据集和数据加载器
#########################################################
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=True,
drop_last=False,
shuffle=True,
num_workers=args.num_workers,
persistent_workers=True if args.num_workers > 0 else False,
prefetch_factor=2 if args.num_workers > 0 else None
)
2025-08-07 11:43:23 +08:00
# 创建验证数据集和加载器
val_loader = None
val_ds = create_validation_dataset(args.val_data_path, tokenizer, lm_config.max_seq_len)
if val_ds is not None:
val_loader = DataLoader(
val_ds,
batch_size=args.batch_size // 2, # 验证时使用较小批次
pin_memory=True,
drop_last=False,
shuffle=False,
num_workers=0, # 验证时不使用多进程
)
2025-05-14 00:01:40 +08:00
#########################################################
# 创建优化器
#########################################################
# 如果启用EMA更新需要过滤掉memory_bank参数因为它不再需要梯度更新
if hasattr(model.params, 'use_ema_update') and model.params.use_ema_update:
# 只包含requires_grad=True的参数
optimizer_params = [p for p in model.parameters() if p.requires_grad]
Logger(f"EMA更新模式优化器包含 {len(optimizer_params)} 个参数过滤掉memory_bank")
Logger(f"总参数:{sum(p.numel() for p in model.parameters())} | 可训练参数:{sum(p.numel() for p in optimizer_params)}")
optimizer = optim.AdamW(optimizer_params, lr=args.learning_rate)
else:
# 传统模式:所有参数都使用梯度更新
Logger("传统梯度更新模式:优化器包含所有模型参数")
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
2025-05-14 00:01:40 +08:00
#########################################################
# 创建学习率调度器
#########################################################
total_steps = len(train_loader) * args.epochs
warmup_steps = args.warmup_iters if args.warmup_iters > 0 else int(0.1 * total_steps)
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_steps
)
#########################################################
# 准备训练
#########################################################
2025-08-07 11:43:23 +08:00
if val_loader is not None:
model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
model, optimizer, train_loader, val_loader, scheduler
)
else:
model, optimizer, train_loader, scheduler = accelerator.prepare(
model, optimizer, train_loader, scheduler
)
2025-05-14 00:01:40 +08:00
#########################################################
# 训练循环
#########################################################
overall_start_time = time.time() # Record overall start time
for epoch in range(args.epochs):
Logger(f"开始第{epoch+1}轮训练", accelerator)
2025-08-07 11:43:23 +08:00
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer, val_loader) # Pass tokenizer and val_loader
# 每个epoch结束后进行内存清理
Logger(f"{epoch+1}轮训练完成,进行内存清理", accelerator)
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# 记录epoch结束时的内存状态
if accelerator.is_main_process:
memory_info = get_memory_usage()
cuda_info = get_cuda_memory_usage()
log_msg = f"[Memory Monitor] Epoch {epoch+1} completed - "
log_msg += f"System RSS: {memory_info['rss_mb']:.2f}MB"
if cuda_info:
log_msg += f", CUDA allocated: {cuda_info['cuda_allocated_mb']:.2f}MB"
log_msg += f", CUDA reserved: {cuda_info['cuda_reserved_mb']:.2f}MB"
Logger(log_msg, accelerator)
#########################################################
# 关闭SwanLab
#########################################################
if args.use_swanlab and accelerator.is_main_process and swanlab_run:
swanlab_run.finish()
2025-05-14 00:01:40 +08:00
if __name__ == "__main__":
main()