Minimind/model/LMConfig.py
Yu Chengzhang cf9acb2064 Experiment 1.4.6: Token-based Memory architecture implementation
Completes the token-based memory architecture for Experiment 1.4.6, with the following improvements:
- Memory bank storage changed from continuous feature vectors to discrete token IDs
- Bidirectional encode/decode path implemented (embedding → features → output → token)
- EMA update parameters tuned: ema_decay=0.9, ema_update_freq=5
- GPU memory usage significantly reduced: from 23 GB to 13 GB (-43%)
- Inference loss reduced from 2.6382 to 2.6142 (a 0.9% improvement)

Technical highlights:
- Effective representation dimension raised from 128 to 4096 (32x)
- Sparse caching avoids memory blow-up
- Immediate-compression strategy balances GPU memory against performance
- Human-interpretable memory contents
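
The token-ID storage described above can be sketched abstractly: instead of keeping a float feature vector per memory slot, the bank keeps `knowledge_length` token IDs per slot and re-embeds them on read. A toy illustration (all names here are hypothetical; the real encode/decode path goes through the model's embedding and output layers, not a raw lookup table):

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, dim, knowledge_length = 6400, 512, 8  # values from LMConfig
embedding = rng.standard_normal((vocab_size, dim)).astype(np.float32)

# A memory slot stored as 8 discrete token IDs (int32)...
slot_tokens = np.array([5, 17, 99, 3, 800, 42, 7, 0], dtype=np.int32)

# ...and expanded to features only when read: 8 tokens x 512 dims,
# i.e. 8 * 512 = 4096 effective dimensions vs. the old knowledge_dim=128.
slot_features = embedding[slot_tokens]  # shape (8, 512)

# Storage cost per slot: 8 ints (32 bytes) vs. 8*512 floats (16384 bytes)
print(slot_tokens.nbytes, slot_features.nbytes)
```

The 4096 effective dimensions match the "128 → 4096" highlight, and the per-slot byte counts illustrate where the GPU-memory savings come from.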

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-14 23:04:52 +08:00
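
The EMA settings in this commit (ema_decay=0.9, ema_update_freq=5) follow the VQ-VAE-style rule `new = decay * old + (1 - decay) * target`. A minimal pure-Python sketch, interpreting `ema_update_freq=5` as "apply the blend every 5 steps" (an assumption; the repository's training loop may define the frequency differently, and `memory_bank`/`targets` are illustrative names):

```python
def ema_step(memory_bank, targets, step, decay=0.9, update_freq=5):
    """Blend targets into memory_bank every `update_freq` steps."""
    if step % update_freq != 0:
        return memory_bank  # off-cycle step: bank left untouched
    # VQ-VAE-style exponential moving average toward the targets
    return [decay * m + (1.0 - decay) * t
            for m, t in zip(memory_bank, targets)]

bank = [1.0, 2.0]
bank = ema_step(bank, [0.0, 0.0], step=5)  # update fires: bank is now [0.9, 1.8]
bank = ema_step(bank, [0.0, 0.0], step=6)  # skipped: bank unchanged
```

With decay lowered from 0.999 to 0.9, each applied update moves the bank 100x more aggressively toward the current targets, which is what the commit message means by "more aggressive updates".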


from transformers import PretrainedConfig
from typing import Optional


class LMConfig(PretrainedConfig):
    model_type = "minimind"

    def __init__(
        self,
        dim: int = 512,
        n_layers: int = 8,
        n_heads: int = 32,
        n_kv_heads: int = 8,
        vocab_size: int = 6400,
        hidden_dim: Optional[int] = None,
        multiple_of: int = 64,
        norm_eps: float = 1e-5,
        max_seq_len: int = 8192,
        rope_theta: float = 1e6,
        dropout: float = 0.0,
        flash_attn: bool = True,
        embeddings_epoch: int = 2,
        ####################################################
        # DB related configurations
        ####################################################
        disable_db: bool = False,  # special mode: disable the database feature
        ####################################################
        # Here are the specific configurations of MOE
        # When use_moe is false, the following is invalid
        ####################################################
        use_moe: bool = False,
        ####################################################
        num_experts_per_tok: int = 2,
        n_routed_experts: int = 4,
        n_shared_experts: bool = True,
        scoring_func: str = 'softmax',
        aux_loss_alpha: float = 0.1,
        seq_aux: bool = True,
        norm_topk_prob: bool = True,
        ####################################################
        knowledge_num: int = 64 * 64,
        knowledge_length: int = 8,
        knowledge_dim: int = 128,
        ####################################################
        # EMA update related configurations (inspired by VQ-VAE)
        ####################################################
        use_ema_update: bool = True,    # whether to update memory_bank via EMA
        ema_decay: float = 0.9,         # 🔥 1.4.6: decay lowered further to allow more aggressive updates (0.999 → 0.9)
        ema_update_freq: int = 5,       # 🔥 1.4.6: update frequency raised further (1 → 5)
        use_token_memory: bool = True,  # 🔥 1.4.6: new token-based memory flag
        ####################################################
        # Triple extraction related configurations
        ####################################################
        max_subject_len: int = 8,
        max_predicate_len: int = 4,
        max_object_len: int = 8,
        **kwargs,
    ):
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.max_seq_len = max_seq_len
        self.rope_theta = rope_theta
        self.dropout = dropout
        self.flash_attn = flash_attn
        self.embeddings_epoch = embeddings_epoch
        ####################################################
        # DB related configurations
        ####################################################
        self.disable_db = disable_db  # whether the database feature is disabled
        ####################################################
        # Here are the specific configurations of MOE
        # When use_moe is false, the following is invalid
        ####################################################
        self.use_moe = use_moe
        self.num_experts_per_tok = num_experts_per_tok  # number of experts selected per token
        self.n_routed_experts = n_routed_experts        # total number of routed experts
        self.n_shared_experts = n_shared_experts        # shared experts
        self.scoring_func = scoring_func                # scoring function, defaults to 'softmax'
        self.aux_loss_alpha = aux_loss_alpha            # alpha weight of the auxiliary loss
        self.seq_aux = seq_aux                          # whether to compute the auxiliary loss at sequence level
        self.norm_topk_prob = norm_topk_prob            # whether to normalize the top-k probabilities
        ####################################################
        self.knowledge_num = knowledge_num
        self.knowledge_length = knowledge_length
        self.knowledge_dim = knowledge_dim
        ####################################################
        # EMA update related configurations (inspired by VQ-VAE)
        ####################################################
        self.use_ema_update = use_ema_update
        self.ema_decay = ema_decay
        self.ema_update_freq = ema_update_freq
        self.use_token_memory = use_token_memory  # 🔥 1.4.6: token-based memory flag
        ####################################################
        # Triple extraction related configurations
        ####################################################
        self.max_subject_len = max_subject_len
        self.max_predicate_len = max_predicate_len
        self.max_object_len = max_object_len
        super().__init__(**kwargs)
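
`hidden_dim` defaults to `None` above. Llama-style models, which MiniMind follows, usually derive the FFN width in the model code as roughly 8/3 · `dim`, rounded up to a multiple of `multiple_of`. A hedged sketch of that convention (the actual derivation lives in the model, not this config, and may differ):

```python
def default_hidden_dim(dim: int, multiple_of: int = 64) -> int:
    """Llama-style FFN width when hidden_dim is None.

    ~8/3 * dim (2/3 of the classic 4*dim), rounded UP to a multiple
    of `multiple_of`. Mirrors the common convention; illustrative only.
    """
    hidden = int(8 * dim / 3)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

print(default_hidden_dim(512))  # → 1408 under this convention
```

For the defaults here (`dim=512`, `multiple_of=64`) this yields 1408: int(8·512/3) = 1365, rounded up to the next multiple of 64.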