Minimind/model/LMConfig.py


from typing import Optional

from transformers import PretrainedConfig


class LMConfig(PretrainedConfig):
    model_type = "minimind"

    def __init__(
            self,
            dim: int = 512,
            n_layers: int = 8,
            n_heads: int = 32,
            n_kv_heads: int = 8,
            vocab_size: int = 6400,
            hidden_dim: Optional[int] = None,
            multiple_of: int = 64,
            norm_eps: float = 1e-5,
            max_seq_len: int = 8192,
            rope_theta: float = 1e6,
            dropout: float = 0.0,
            flash_attn: bool = True,
            embeddings_epoch: int = 2,
            ####################################################
            # DB related configurations
            ####################################################
            disable_db: bool = False,  # Special mode: disable the database feature
            ####################################################
            # Here are the specific configurations of MoE.
            # When use_moe is False, the following are ignored.
            ####################################################
            use_moe: bool = False,
            ####################################################
            num_experts_per_tok: int = 2,
            n_routed_experts: int = 4,
            n_shared_experts: bool = True,
            scoring_func: str = 'softmax',
            aux_loss_alpha: float = 0.1,
            seq_aux: bool = True,
            norm_topk_prob: bool = True,
            ####################################################
            # Knowledge base related configurations
            ####################################################
            knowledge_num: int = 64 * 64,
            knowledge_length: int = 8,
            knowledge_dim: int = 128,
            ####################################################
            # EMA update related configurations (inspired by VQ-VAE)
            ####################################################
            use_ema_update: bool = True,  # Whether to update memory_bank via EMA
            ema_decay: float = 0.999,  # EMA decay rate, analogous to gamma in VQ-VAE
            ema_update_freq: int = 1,  # EMA update frequency: update once every N training steps
            ####################################################
            # Triple extraction related configurations
            ####################################################
            max_subject_len: int = 8,
            max_predicate_len: int = 4,
            max_object_len: int = 8,
            **kwargs,
    ):
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.max_seq_len = max_seq_len
        self.rope_theta = rope_theta
        self.dropout = dropout
        self.flash_attn = flash_attn
        self.embeddings_epoch = embeddings_epoch
        ####################################################
        # DB related configurations
        ####################################################
        self.disable_db = disable_db  # Whether the database feature is disabled
        ####################################################
        # Here are the specific configurations of MoE.
        # When use_moe is False, the following are ignored.
        ####################################################
        self.use_moe = use_moe
        self.num_experts_per_tok = num_experts_per_tok  # Number of experts selected per token
        self.n_routed_experts = n_routed_experts  # Total number of routed experts
        self.n_shared_experts = n_shared_experts  # Shared experts
        self.scoring_func = scoring_func  # Scoring function, defaults to 'softmax'
        self.aux_loss_alpha = aux_loss_alpha  # Alpha weight of the auxiliary loss
        self.seq_aux = seq_aux  # Whether to compute the auxiliary loss at the sequence level
        self.norm_topk_prob = norm_topk_prob  # Whether to normalize the top-k probabilities
        ####################################################
        # Knowledge base related configurations
        ####################################################
        self.knowledge_num = knowledge_num
        self.knowledge_length = knowledge_length
        self.knowledge_dim = knowledge_dim
        ####################################################
        # EMA update related configurations (inspired by VQ-VAE)
        ####################################################
        self.use_ema_update = use_ema_update
        self.ema_decay = ema_decay
        self.ema_update_freq = ema_update_freq
        ####################################################
        # Triple extraction related configurations
        ####################################################
        self.max_subject_len = max_subject_len
        self.max_predicate_len = max_predicate_len
        self.max_object_len = max_object_len
        super().__init__(**kwargs)
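
Since LMConfig subclasses transformers.PretrainedConfig, a config can be built with keyword overrides and round-tripped to JSON via save_pretrained and from_pretrained. A minimal usage sketch follows (not part of the file); the import path and output directory are assumptions for illustration:

from model.LMConfig import LMConfig  # assumed import path, relative to the repo root

# Build a MoE-enabled config, overriding a few defaults.
config = LMConfig(dim=768, n_layers=16, use_moe=True, n_routed_experts=8)

# PretrainedConfig provides JSON serialization for free.
config.save_pretrained("./minimind_config")  # writes minimind_config/config.json (assumed directory)
restored = LMConfig.from_pretrained("./minimind_config")
assert restored.use_moe and restored.n_routed_experts == 8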