# Minimind/model/LMConfig.py
#
# Note carried over from the web file viewer this file was copied from
# (90 lines, 4.4 KiB, Python): the file contains Unicode characters that
# might be confused with other characters. If that is intentional, the
# warning can safely be ignored.

from transformers import PretrainedConfig
from typing import List, Optional, Union
class LMConfig(PretrainedConfig):
    """Hyper-parameter container for the MiniMind language model.

    Every constructor argument is stored verbatim as an instance attribute of
    the same name; this class derives and validates nothing. Any additional
    keyword arguments are forwarded untouched to ``PretrainedConfig.__init__``,
    which is invoked last, after all MiniMind-specific attributes are set.

    The arguments fall into four groups:

    * Transformer backbone: ``dim``, ``n_layers``, ``n_heads``,
      ``n_kv_heads``, ``vocab_size``, ``hidden_dim``, ``multiple_of``,
      ``norm_eps``, ``max_seq_len``, ``rope_theta``, ``dropout``,
      ``flash_attn``.
    * Database (DB) options: ``disable_db``, ``use_direct_semantic``,
      ``realtime_steps``, ``db_intelligent_balance``,
      ``db_relevance_threshold``, ``db_balance_strength``, ``db_momentum``,
      ``db_adaptive_weights``.
    * Mixture-of-Experts: ``use_moe`` plus the routing options that follow
      it; those routing options only take effect when ``use_moe`` is True.
    * Knowledge sizing: ``knowledge_num`` and ``knowledge_length``.
    """

    model_type = "minimind"

    def __init__(
            self,
            # --- Transformer backbone ---
            dim: int = 512,
            n_layers: int = 8,
            n_heads: int = 32,
            n_kv_heads: int = 8,
            vocab_size: int = 6400,
            hidden_dim: Optional[int] = None,
            multiple_of: int = 64,
            norm_eps: float = 1e-5,
            max_seq_len: int = 8192,
            rope_theta: float = 1e6,
            dropout: float = 0.0,
            flash_attn: bool = True,
            # --- Database (DB) options ---
            disable_db: bool = False,
            use_direct_semantic: bool = False,
            realtime_steps: int = 2000,
            db_intelligent_balance: bool = True,
            db_relevance_threshold: float = 0.7,
            db_balance_strength: float = 0.3,
            db_momentum: float = 0.9,
            db_adaptive_weights: bool = True,
            # --- Mixture-of-Experts (ignored unless use_moe is True) ---
            use_moe: bool = False,
            num_experts_per_tok: int = 2,
            n_routed_experts: int = 4,
            n_shared_experts: bool = True,
            scoring_func: str = 'softmax',
            aux_loss_alpha: float = 0.1,
            seq_aux: bool = True,
            norm_topk_prob: bool = True,
            # --- Knowledge sizing ---
            knowledge_num: int = 64 * 64,
            knowledge_length: int = 8,
            **kwargs,
    ):
        # ------------------------------------------------------------------
        # Transformer backbone
        # ------------------------------------------------------------------
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        # Presumably the number of key/value heads for grouped-query
        # attention; the ratio to n_heads is not checked here.
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        # Stored as given; a None value is not resolved by this class
        # (presumably the model derives it from dim/multiple_of — confirm).
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.max_seq_len = max_seq_len
        self.rope_theta = rope_theta
        self.dropout = dropout
        self.flash_attn = flash_attn

        # ------------------------------------------------------------------
        # Database (DB) options
        # ------------------------------------------------------------------
        # Special mode: switches the database functionality off entirely.
        self.disable_db = disable_db
        # Use direct semantic matching instead of the Product Key lookup.
        self.use_direct_semantic = use_direct_semantic
        # Number of initial steps computed in real time; later steps use
        # progressive caching.
        self.realtime_steps = realtime_steps
        # Enable intelligent load balancing.
        self.db_intelligent_balance = db_intelligent_balance
        # Relevance threshold used by the first filtering stage.
        self.db_relevance_threshold = db_relevance_threshold
        # Base value of the balancing weight.
        self.db_balance_strength = db_balance_strength
        # Momentum applied to the usage-frequency statistics.
        self.db_momentum = db_momentum
        # Enable dynamic (adaptive) weight adjustment.
        self.db_adaptive_weights = db_adaptive_weights

        # ------------------------------------------------------------------
        # Mixture-of-Experts; the fields after use_moe have no effect when
        # use_moe is False.
        # ------------------------------------------------------------------
        self.use_moe = use_moe
        # Number of experts selected for each token.
        self.num_experts_per_tok = num_experts_per_tok
        # Total number of routed experts.
        self.n_routed_experts = n_routed_experts
        # Shared-expert setting; annotated as a bool flag despite the
        # count-like name — NOTE(review): confirm how the model reads it.
        self.n_shared_experts = n_shared_experts
        # Expert scoring function ('softmax' by default).
        self.scoring_func = scoring_func
        # Weight (alpha) of the auxiliary load-balancing loss.
        self.aux_loss_alpha = aux_loss_alpha
        # Compute the auxiliary loss at the sequence level.
        self.seq_aux = seq_aux
        # Normalise the top-k routing probabilities.
        self.norm_topk_prob = norm_topk_prob

        # ------------------------------------------------------------------
        # Knowledge sizing (semantics defined by the model code — presumably
        # the number of knowledge entries and tokens per entry; confirm).
        # ------------------------------------------------------------------
        self.knowledge_num = knowledge_num
        self.knowledge_length = knowledge_length

        # Hand everything else (generation/tokenizer options, etc.) to the
        # Hugging Face base class last, exactly as before.
        super().__init__(**kwargs)