Minimind/model/LMConfig.py
Yu Chengzhang 44fe6259ec Experiment 1.4.7: Memory Bank text initialization + partial-freeze mechanism
## Key improvements
- 🔥 Memory Bank text initialization: uses real text data from sentence_trex_data.json
- 🔥 Partial-freeze mechanism: new freeze_ratio=0.2, protecting 20% of important memory entries
- 📊 Performance: inference loss improved by 5.5% (2.4699 vs 2.6142)

## Core changes
### model/LMConfig.py
- Added the freeze_ratio parameter to control Memory Bank entry freezing

### model/model_memory.py
- Implemented the freeze_mask mechanism, randomly freezing 20% of memory entries
- EMA update filtering: only unfrozen entries are updated, protecting important knowledge
- Richer statistics: added monitoring of the count and ratio of frozen entries
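
The freeze-plus-EMA scheme above can be sketched as follows. This is a hypothetical illustration: the names (`memory_bank`, `freeze_mask`) follow the commit text, but the actual implementation in model/model_memory.py may differ.

```python
import torch

# Assumed shapes taken from the LMConfig defaults below.
knowledge_num, knowledge_dim = 64 * 64, 128
freeze_ratio, ema_decay = 0.2, 0.9

memory_bank = torch.randn(knowledge_num, knowledge_dim)
# Draw the freeze mask once at init: ~20% of entries are never updated.
freeze_mask = torch.rand(knowledge_num) < freeze_ratio

def ema_update(bank, new_values, frozen, decay):
    """EMA-update only the unfrozen memory entries."""
    updated = decay * bank + (1 - decay) * new_values
    # Frozen rows keep their old value; the rest take the EMA result.
    return torch.where(frozen.unsqueeze(1), bank, updated)

orig = memory_bank.clone()
memory_bank = ema_update(memory_bank, torch.randn_like(memory_bank),
                         freeze_mask, ema_decay)
```

`torch.where` keeps the update fully vectorized, so filtering frozen entries adds no Python-level loop to the training step.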

### train_pretrain_accelerate.py
- Full model_memory initialization support: text-data processing and caching
- Tokenization and length handling for sentence_trex_data.json text
- memory_bank_init cache optimization for faster repeated experiments
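
A hedged sketch of what the tokenize-truncate-pad-cache pipeline above could look like. The `ToyTokenizer`, the function signature, and the cache file name are illustrative assumptions; train_pretrain_accelerate.py may differ.

```python
import json
import os
import tempfile

import torch

class ToyTokenizer:
    """Stand-in for the project tokenizer (assumption): one id per char."""
    pad_token_id = 0
    def encode(self, text):
        return [ord(c) % 100 + 1 for c in text]

def init_memory_bank(tokenizer, data_path, cache_path,
                     knowledge_num=16, knowledge_length=8):
    if os.path.exists(cache_path):      # cache hit: skip re-tokenization
        return torch.load(cache_path)
    with open(data_path, encoding="utf-8") as f:
        sentences = json.load(f)
    bank = torch.zeros(knowledge_num, knowledge_length, dtype=torch.long)
    for i in range(knowledge_num):
        # Truncate to knowledge_length, then pad short sentences.
        ids = tokenizer.encode(sentences[i % len(sentences)])[:knowledge_length]
        ids += [tokenizer.pad_token_id] * (knowledge_length - len(ids))
        bank[i] = torch.tensor(ids)
    torch.save(bank, cache_path)        # persist for repeated experiments
    return bank

# Demo with a tiny stand-in for sentence_trex_data.json.
tmp = tempfile.mkdtemp()
data_path = os.path.join(tmp, "sentence_trex_data.json")
cache_path = os.path.join(tmp, "memory_bank_init.pt")
with open(data_path, "w", encoding="utf-8") as f:
    json.dump(["the sky is blue", "water boils at sea level"], f)
bank = init_memory_bank(ToyTokenizer(), data_path, cache_path)
```

The second call with the same `cache_path` loads the saved tensor instead of re-tokenizing, which is the repeated-experiment speedup the commit describes.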

### Experiment documentation
- experiment/EXPERIMENT_1_4_7.md: full experiment record and result analysis
- run_file/experiment_1_4_7.sh: experiment launch script
- CLAUDE.md: architecture-design guard rules and model versioning conventions

## Experiment results
- Text initialization validated: 5.5% loss improvement
- Freeze mechanism implemented: 209,715 of 1,048,576 entries successfully frozen
- Generation coherence still needs work: architecture-level issues remain
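
A quick arithmetic check of the frozen-entry count above. Note that 1,048,576 corresponds to a 1024*1024 memory bank, which is this experiment's setting; the LMConfig default below is knowledge_num = 64*64.

```python
# 20% of a 1024*1024 memory bank, matching the reported 209,715/1,048,576.
knowledge_num = 1024 * 1024
freeze_ratio = 0.2
frozen = int(knowledge_num * freeze_ratio)
print(frozen, knowledge_num)  # 209715 1048576
```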

## Next steps
- Fix EOS-token control
- Optimize cross-attention weights
- Tune generation parameters (temperature/top_p)
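
For reference, the two generation knobs listed above interact roughly as in this minimal nucleus-sampling sketch; the repo's actual generation loop may be implemented differently.

```python
import torch

def sample_next_token(logits, temperature=0.8, top_p=0.9):
    """Temperature scaling followed by top-p (nucleus) filtering."""
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop tokens whose preceding cumulative mass already exceeds top_p;
    # the top-1 token is always kept.
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs /= sorted_probs.sum()
    choice = torch.multinomial(sorted_probs, 1)
    return sorted_idx[choice].item()

token = sample_next_token(torch.randn(6400))  # vocab_size = 6400
```

Lower temperature sharpens the distribution before the nucleus cut, so the two parameters should be tuned together.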

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-19 19:32:52 +08:00


from typing import Optional

from transformers import PretrainedConfig


class LMConfig(PretrainedConfig):
    model_type = "minimind"

    def __init__(
        self,
        dim: int = 512,
        n_layers: int = 8,
        n_heads: int = 32,
        n_kv_heads: int = 8,
        vocab_size: int = 6400,
        hidden_dim: Optional[int] = None,
        multiple_of: int = 64,
        norm_eps: float = 1e-5,
        max_seq_len: int = 8192,
        rope_theta: float = 1e6,
        dropout: float = 0.0,
        flash_attn: bool = True,
        embeddings_epoch: int = 2,
        ####################################################
        # DB related configurations
        ####################################################
        disable_db: bool = False,  # special mode: disable the database feature
        ####################################################
        # Here are the specific configurations of MOE
        # When use_moe is false, the following is invalid
        ####################################################
        use_moe: bool = False,
        ####################################################
        num_experts_per_tok: int = 2,
        n_routed_experts: int = 4,
        n_shared_experts: bool = True,
        scoring_func: str = 'softmax',
        aux_loss_alpha: float = 0.1,
        seq_aux: bool = True,
        norm_topk_prob: bool = True,
        ####################################################
        knowledge_num: int = 64 * 64,
        knowledge_length: int = 8,
        knowledge_dim: int = 128,
        ####################################################
        # EMA update related configurations (inspired by VQ-VAE)
        ####################################################
        use_ema_update: bool = True,  # whether to update the memory_bank via EMA
        ema_decay: float = 0.9,  # 🔥 1.4.6: lowered decay further to allow more aggressive updates (0.999 → 0.8)
        ema_update_freq: int = 5,  # 🔥 1.4.6: raised the update frequency further (1 → 5)
        use_token_memory: bool = True,  # 🔥 1.4.6: new token-based memory flag
        freeze_ratio: float = 0.2,  # 🔥 new: memory_bank freeze ratio (0.0 = no freezing, 0.2 = 20% of entries are never updated)
        ####################################################
        # Triple extraction related configurations
        ####################################################
        max_subject_len: int = 8,
        max_predicate_len: int = 4,
        max_object_len: int = 8,
        **kwargs,
    ):
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.max_seq_len = max_seq_len
        self.rope_theta = rope_theta
        self.dropout = dropout
        self.flash_attn = flash_attn
        self.embeddings_epoch = embeddings_epoch
        ####################################################
        # DB related configurations
        ####################################################
        self.disable_db = disable_db  # whether the database feature is disabled
        ####################################################
        # Here are the specific configurations of MOE
        # When use_moe is false, the following is invalid
        ####################################################
        self.use_moe = use_moe
        self.num_experts_per_tok = num_experts_per_tok  # number of experts selected per token
        self.n_routed_experts = n_routed_experts  # total number of routed experts
        self.n_shared_experts = n_shared_experts  # shared experts
        self.scoring_func = scoring_func  # scoring function, 'softmax' by default
        self.aux_loss_alpha = aux_loss_alpha  # alpha weight of the auxiliary loss
        self.seq_aux = seq_aux  # whether to compute the auxiliary loss at the sequence level
        self.norm_topk_prob = norm_topk_prob  # whether to normalize the top-k probabilities
        ####################################################
        self.knowledge_num = knowledge_num
        self.knowledge_length = knowledge_length
        self.knowledge_dim = knowledge_dim
        ####################################################
        # EMA update related configurations (inspired by VQ-VAE)
        ####################################################
        self.use_ema_update = use_ema_update
        self.ema_decay = ema_decay
        self.ema_update_freq = ema_update_freq
        self.use_token_memory = use_token_memory  # 🔥 1.4.6: token-based memory flag
        self.freeze_ratio = freeze_ratio  # 🔥 new: memory_bank freeze ratio
        ####################################################
        # Triple extraction related configurations
        ####################################################
        self.max_subject_len = max_subject_len
        self.max_predicate_len = max_predicate_len
        self.max_object_len = max_object_len
        super().__init__(**kwargs)