120 lines
5.9 KiB
Python
120 lines
5.9 KiB
Python
from typing import List, Optional

from transformers import PretrainedConfig
|
||
|
||
|
||
class LMConfig(PretrainedConfig):
    """Configuration container for the ``"minimind"`` model type.

    Every constructor argument is stored unchanged as an instance attribute
    of the same name. Any additional keyword arguments are forwarded to
    ``PretrainedConfig.__init__`` after the attributes have been set.

    Args:
        dim, n_layers, n_heads, n_kv_heads, vocab_size, hidden_dim,
        multiple_of, norm_eps, max_seq_len, rope_theta, dropout,
        flash_attn, embeddings_epoch: Core model settings.
            NOTE(review): the names suggest the usual transformer
            hyperparameters; this class only stores them, so confirm their
            exact meaning against the model code. ``hidden_dim`` may be
            ``None``.
        disable_db: Special mode that disables the database functionality.
        use_moe: Enables the mixture-of-experts settings below; they are
            documented as having no effect when this is ``False``.
        num_experts_per_tok: Number of experts selected for each token.
        n_routed_experts: Total number of experts.
        n_shared_experts: Shared-experts setting.
        scoring_func: Scoring function; defaults to ``'softmax'``.
        aux_loss_alpha: Alpha parameter of the auxiliary loss.
        seq_aux: Whether the auxiliary loss is computed at sequence level.
        norm_topk_prob: Whether the top-k probabilities are normalized.
        knowledge_num, knowledge_length, knowledge_dim: Knowledge-store
            settings. NOTE(review): presumably entry count, entry length
            and entry width; confirm against the model code.
        use_ema_update: Whether the memory bank is updated with EMA.
        ema_decay: EMA decay rate.
        ema_update_freq: EMA update frequency.
        use_token_memory: Token-based memory flag (added in 1.4.6).
        freeze_ratio: Fraction of memory-bank entries that are frozen
            (0.0 means nothing is frozen, 0.2 means 20% of the entries are
            not updated).
        num_candidates: Number of candidate memory entries
            (experiment 1.4.9).
        num_selected: Number of selected memory entries (experiment 1.4.9;
            currently only the single best entry is selected).
        gumbel_temperature: Gumbel-Softmax temperature (experiment 1.4.9).
        max_subject_len, max_predicate_len, max_object_len: Triple
            extraction settings. NOTE(review): presumably maximum lengths
            of each triple component; confirm the unit against the caller.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "minimind"

    def __init__(
        self,
        dim: int = 512,
        n_layers: int = 8,
        n_heads: int = 16,
        n_kv_heads: int = 8,
        vocab_size: int = 6400,
        hidden_dim: Optional[int] = None,
        multiple_of: int = 64,
        norm_eps: float = 1e-5,
        max_seq_len: int = 512,
        rope_theta: float = 1e6,
        dropout: float = 0.0,
        flash_attn: bool = True,
        embeddings_epoch: int = 2,
        # ---- Database related configuration ----
        disable_db: bool = False,  # special mode: disable the database functionality
        # ---- Mixture-of-experts configuration ----
        # The settings below are invalid (unused) when use_moe is False.
        use_moe: bool = False,
        num_experts_per_tok: int = 2,
        n_routed_experts: int = 4,
        n_shared_experts: bool = True,
        scoring_func: str = 'softmax',
        aux_loss_alpha: float = 0.1,
        seq_aux: bool = True,
        norm_topk_prob: bool = True,
        # ---- Knowledge store configuration ----
        knowledge_num: int = 1024 * 1024,
        knowledge_length: int = 16,
        knowledge_dim: int = 128,
        # ---- EMA update configuration (inspired by VQ-VAE) ----
        use_ema_update: bool = True,  # whether to update memory_bank with EMA
        # 1.4.6: decay lowered to allow more aggressive updates.
        # NOTE(review): the change note said "0.999 -> 0.8" but the default
        # here is 0.9 -- confirm which value is intended.
        ema_decay: float = 0.9,
        ema_update_freq: int = 5,  # 1.4.6: update frequency raised (1 -> 5)
        use_token_memory: bool = True,  # 1.4.6: token-based memory flag
        # Freeze ratio of memory_bank: 0.0 freezes nothing, 0.2 means 20% of
        # the entries are not updated.
        freeze_ratio: float = 0.2,
        # ---- Experiment 1.4.9: Gumbel-Softmax + diversity loss ----
        num_candidates: int = 32,  # number of candidate memory entries
        num_selected: int = 1,  # number of selected entries (only the best one now)
        gumbel_temperature: float = 1.0,  # Gumbel-Softmax temperature
        # ---- Triple extraction configuration ----
        max_subject_len: int = 8,
        max_predicate_len: int = 4,
        max_object_len: int = 8,
        **kwargs,
    ):
        # Core model settings.
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.max_seq_len = max_seq_len
        self.rope_theta = rope_theta
        self.dropout = dropout
        self.flash_attn = flash_attn
        self.embeddings_epoch = embeddings_epoch

        # Database related configuration.
        self.disable_db = disable_db  # whether the database is disabled

        # Mixture-of-experts configuration (invalid when use_moe is False).
        self.use_moe = use_moe
        self.num_experts_per_tok = num_experts_per_tok  # experts selected per token
        self.n_routed_experts = n_routed_experts  # total number of experts
        self.n_shared_experts = n_shared_experts  # shared experts
        self.scoring_func = scoring_func  # scoring function, 'softmax' by default
        self.aux_loss_alpha = aux_loss_alpha  # alpha parameter of the auxiliary loss
        self.seq_aux = seq_aux  # compute the auxiliary loss at sequence level?
        self.norm_topk_prob = norm_topk_prob  # normalize the top-k probabilities?

        # Knowledge store configuration.
        self.knowledge_num = knowledge_num
        self.knowledge_length = knowledge_length
        self.knowledge_dim = knowledge_dim

        # EMA update configuration (inspired by VQ-VAE).
        self.use_ema_update = use_ema_update
        self.ema_decay = ema_decay
        self.ema_update_freq = ema_update_freq
        self.use_token_memory = use_token_memory  # 1.4.6: token-based memory flag
        self.freeze_ratio = freeze_ratio  # memory_bank freeze ratio

        # Experiment 1.4.9: Gumbel-Softmax + diversity loss.
        self.num_candidates = num_candidates
        self.num_selected = num_selected
        self.gumbel_temperature = gumbel_temperature

        # Triple extraction configuration.
        self.max_subject_len = max_subject_len
        self.max_predicate_len = max_predicate_len
        self.max_object_len = max_object_len

        # Remaining keyword arguments are handled by the base class; this
        # call is kept last, as in the original.
        super().__init__(**kwargs)
|