fix

This commit is contained in:
parent 1ddfd310ec
commit e3120f5e62

model/model.py (254 lines changed)
@@ -11,6 +11,12 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
+from torch import nn, einsum
+from einops import rearrange, repeat
+
+
+def exists(val):
+    return val is not None
 
 
 # The RMSNorm class defines a module that normalizes its input tensor.
 class RMSNorm(torch.nn.Module):
@@ -158,6 +164,42 @@ class Attention(nn.Module):
         return output, past_kv
 
 
+class CrossAttention(nn.Module):
+    def __init__(
+            self,
+            config
+    ):
+        super().__init__()
+        self.config = config
+        self.to_q = nn.Linear(768, 768, bias=False)
+        self.to_k = nn.Linear(768, 768, bias=False)
+        self.to_v = nn.Linear(768, 768, bias=False)
+
+    def forward(self, x, db, context_mask=None, pos_emb=None):
+        # db = db.permute(0, 2, 1)
+        q = self.to_q(x)
+        k = self.to_k(db)
+        v = self.to_v(db)
+
+        if pos_emb is not None:
+            q = q + pos_emb
+            k = k + pos_emb
+            v = v + pos_emb
+
+        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1))
+
+        if context_mask is not None:
+            attn_scores = attn_scores.masked_fill(context_mask == 0, -1e10)
+
+        attn_weights = F.softmax(attn_scores, dim=-1)
+        context = torch.matmul(attn_weights, v)
+
+        return context
+
+
 class FeedForward(nn.Module):
     def __init__(self, config: LMConfig):
         super().__init__()
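The added CrossAttention is single-headed with a hard-coded width of 768: queries come from the token stream `x`, keys and values from the retrieved knowledge `db`. A minimal shape-check sketch (the `SimpleNamespace` stands in for the config object, which the module only stores; the `model.model` import path is assumed from the changed file):

    import torch
    from types import SimpleNamespace
    from model.model import CrossAttention

    attn = CrossAttention(SimpleNamespace())
    x = torch.randn(2, 16, 768)    # token states: [batch, seq_len, dim]
    db = torch.randn(2, 32, 768)   # retrieved knowledge: [batch, n_entries, dim]
    out = attn(x, db)              # q from x; k, v from db
    print(out.shape)               # torch.Size([2, 16, 768])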
@@ -290,51 +332,136 @@ class MOEFeedForward(nn.Module):
 
 
 class MiniMindBlock(nn.Module):
-    def __init__(self, layer_id: int, config: LMConfig, weight_down_embed=None):
+    def __init__(self, layer_id: int, config: LMConfig):
         super().__init__()
         self.n_heads = config.n_heads
         self.dim = config.dim
         self.head_dim = config.dim // config.n_heads
         self.attention = Attention(config)
+        self.cross_att = CrossAttention(config)
 
         self.layer_id = layer_id
         self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
         self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
         self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
 
-        # Product Key related parameters
-        self.weight_down_embed = weight_down_embed
-        # num_keys is assumed to be the square root of the total number of experts
-        self.num_keys = int(math.sqrt(self.weight_down_embed.num_embeddings)) if weight_down_embed is not None else 0
-
-        # parameters for query generation
-        self.dim_key = config.dim // 2  # usually half of the feature dimension
-
-        # build the query-generation module
-        if weight_down_embed is not None:
-            self.to_queries = nn.Sequential(
-                nn.Linear(config.dim, self.dim_key * self.n_heads * 2, bias=False),
-                nn.Unflatten(2, (2, self.n_heads, self.dim_key))  # in place of Rearrange
-            )
-
-            # store the Product Keys
-            self.keys = nn.Parameter(torch.randn(self.n_heads, self.num_keys, 2, self.dim_key) * 0.02)
-
-            # hyperparameters
-            self.product_key_topk = min(16, self.num_keys)  # must not exceed num_keys
-            self.num_experts_per_head_topk = 1  # number of experts finally selected per head
-
-    def forward(self, x, pos_cis, past_key_value=None, use_cache=False):
-        db_value = None
-
-        # if weight_down_embed is present, use the Product Key mechanism
-        if self.weight_down_embed is not None:
-            # 1. generate queries
-            queries = self.to_queries(x)  # [b, n, 2, h, d]
-            queries = queries.permute(2, 0, 1, 3, 4)  # [2, b, n, h, d]
-
-            # 2. compute the similarity between queries and keys
-            sim = torch.einsum('p b n h d, h k p d -> p b n h k', queries, self.keys)
-
-            # 3. top-k within each of the two subspaces
-            scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
+        # if weight_down_embed is not None:
+        #     self.to_queries = nn.Sequential(
+        #         nn.Linear(config.dim, self.dim_key * 2, bias=False),
+        #         # nn.Unflatten(2, (2, self.n_heads, self.dim_key))  # in place of Rearrange
+        #     )
+
+        # # hyperparameters
+        # self.product_key_topk = min(16, self.num_keys)  # must not exceed num_keys
+        # self.num_experts_per_head_topk = 1  # number of experts finally selected per head
+
+    def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=False):
+        # import pdb;pdb.set_trace()
+        # db_value = None
+
+        # # if weight_down_embed is present, use the Product Key mechanism
+        # if self.weight_down_embed is not None:
+        #     # 1. generate queries
+        #     batch_size, seq_len, dim = x.shape
+
+        #     # collapse the sequence dimension by averaging
+        #     x_flat = x.mean(dim=1)  # [batch_size, dim]
+        #     queries = self.to_queries(x_flat)  # [batch_size, 2*dim_key]
+        #     queries = queries.reshape(batch_size, 2, self.dim_key)  # [batch_size, 2, dim_key]
+        #     queries = queries.permute(1, 0, 2)  # [2, batch_size, dim_key]
+
+        #     # 2. compute the similarity between queries and keys
+        #     sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
+
+        #     # 3. top-k within each of the two subspaces
+        #     scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
+        #     scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
+        #     indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
+
+        #     # 4. combine the scores and indices of the two subspaces
+        #     all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
+        #     all_scores = all_scores.view(*all_scores.shape[:-2], -1)
+
+        #     all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
+        #     all_indices = all_indices.view(*all_indices.shape[:-2], -1)
+
+        #     # 5. final top-k selection
+        #     scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
+        #     indices = all_indices.gather(-1, pk_indices)
+
+        #     # 6. fetch the expert values from the embedding
+        #     flat_indices = indices.view(-1)  # flatten the indices to 1D
+        #     db_values = self.weight_down_embed(flat_indices)
+
+        #     # reshape back to the original shape
+        #     db_value = db_values.view(batch_size, -1, dim)
+
+        # attention computation
+        h_attn, past_kv = self.attention(
+            self.attention_norm(x),
+            pos_cis,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            db_value=db_value
+        )
+        h_attn = self.cross_att(h_attn, db_value)
+
+        # residual connection
+        h = x + h_attn
+
+        # feed-forward network
+        out = h + self.feed_forward(self.ffn_norm(h))
+        return out, past_kv
+
+
+class ExtractDB(nn.Module):
+    def __init__(self, params):
+        # adjust the expert count and knowledge dimension so the count has an integer square root
+        super().__init__()
+        self.batch_size = None
+        self.dim = params.dim
+        self.dim_key = self.dim // 2
+        self.num_experts = 10 * 10  # 10 * 10 = 100 experts; must be a perfect square
+        # size knowledge_dim so the retrieved values can be used directly in attention
+        self.head_dim = params.dim // params.n_heads
+        self.knowledge_dim = 8 * params.dim
+        # use a plain CPU tensor instead of nn.Embedding
+        self.register_buffer('weight_down_embed_cpu',
+                             torch.randn(self.num_experts, self.knowledge_dim,
+                                         dtype=torch.float32,
+                                         device='cpu') * 0.02,
+                             persistent=True)
+
+        self.num_keys = int(math.sqrt(self.num_experts)) if self.num_experts > 0 else 0
+        self.product_key_topk = min(16, self.num_keys)
+        self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02)
+        self.num_experts_per_head_topk = 1
+        self.to_queries = nn.Sequential(
+            nn.Linear(params.dim, self.dim_key * 2, bias=False),
+        )
+
+    def q_to_k(self, x):
+        # 1. generate queries
+        self.batch_size, seq_len, dim = x.shape
+
+        # collapse the sequence dimension by averaging
+        x_flat = x.mean(dim=1)  # [batch_size, dim]
+
+        queries = self.to_queries(x_flat)  # [batch_size, 2*dim_key]
+        queries = queries.reshape(self.batch_size, 2, self.dim_key)  # [batch_size, 2, dim_key]
+        queries = queries.permute(1, 0, 2)  # [2, batch_size, dim_key]
+
+        # 2. compute the similarity between queries and keys
+        sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
+
+        # 3. top-k within each of the two subspaces
+        scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
@@ -351,40 +478,25 @@ class MiniMindBlock(nn.Module):
-            # 5. final top-k selection
-            scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
-            indices = all_indices.gather(-1, pk_indices)
-
-            # 6. fetch the expert values from the embedding
-            # [b, n, h, k] -> [b, n, h, k, 1] -> [b, n, h, k, d]
-            indices_expanded = indices.unsqueeze(-1).expand(-1, -1, -1, -1, self.weight_down_embed.embedding_dim)
-
-            # flatten the 3D indices to 1D for the gather operation
-            batch_size, seq_len = x.shape[0], x.shape[1]
-            flat_indices = indices.view(-1)
-
-            # fetch the values from the embedding
-            db_values = self.weight_down_embed(flat_indices)
-
-            # reshape back to the original shape
-            db_value = db_values.view(batch_size, seq_len, self.n_heads,
-                                      self.num_experts_per_head_topk, -1)
-
-            # weight by the scores
-            db_value = db_value * F.relu(scores.unsqueeze(-1))
-
-            # merge the outputs of multiple experts (if each head selects more than one)
-            if self.num_experts_per_head_topk > 1:
-                db_value = db_value.sum(dim=3)  # [b, n, h, d]
-
-        # attention computation
-        h_attn, past_kv = self.attention(
-            self.attention_norm(x),
-            pos_cis,
-            past_key_value=past_key_value,
-            use_cache=use_cache,
-            db_value=db_value
-        )
-        h = x + h_attn
-        out = h + self.feed_forward(self.ffn_norm(h))
-        return out, past_kv
+        # 5. final top-k selection
+        scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
+        indices = all_indices.gather(-1, pk_indices)
+
+        flat_indices = indices.view(-1)
+        return flat_indices
+
+    def get_data(self, index):
+        # move the needed embeddings from the CPU to the current device
+        device = index.device
+        # fetch the corresponding embeddings by index
+        db_values = self.weight_down_embed_cpu[index.cpu()].to(device)
+        db_value = db_values.view(self.batch_size, -1, self.dim)
+        return db_value
+
+    def updata_value(self, k, v):
+        # update the tensor values stored on the CPU
+        k_cpu = k.cpu()
+        v_cpu = v.view(v.size(0), -1).cpu()
+
+        # write the new values directly into the in-memory tensor
+        self.weight_down_embed_cpu[k_cpu] = v_cpu
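The product-key scheme scores only 2 * num_keys key vectors yet addresses num_keys² experts: top-k is taken in each half-space, and the winning pair indices are fused into one flat row index. A self-contained sketch of that index arithmetic, under the commit's 10 x 10 layout (values here are illustrative, not the model's):

    import torch

    num_keys, topk = 10, 2                  # 10 * 10 = 100 addressable experts
    scores_x, indices_x = torch.randn(3, num_keys).topk(topk, dim=-1)
    scores_y, indices_y = torch.randn(3, num_keys).topk(topk, dim=-1)

    # combine the two subspaces: topk * topk candidate pairs per sample
    all_scores = (scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)).view(3, -1)
    all_indices = (indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2)).view(3, -1)

    # final top-1, as in q_to_k with num_experts_per_head_topk = 1
    best, pos = all_scores.topk(1, dim=-1)
    expert_id = all_indices.gather(-1, pos)  # flat row index into the knowledge bank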
 class MiniMindLM(PreTrainedModel):
@@ -396,23 +508,23 @@ class MiniMindLM(PreTrainedModel):
         self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
         self.dropout = nn.Dropout(params.dropout)
 
-        # adjust the expert count and knowledge dimension so the count has an integer square root
-        self.num_experts = 1000 * 1000  # 1M experts; must be a perfect square
-        # set knowledge_dim equal to head_dim so it can be used directly in attention
-        self.head_dim = params.dim // params.n_heads
-        self.knowledge_dim = self.head_dim
-
-        # define weight_down_embed, which stores the expert knowledge
-        self.weight_down_embed = nn.Embedding(self.num_experts, self.knowledge_dim)
-        # initialize the embedding weights
-        nn.init.normal_(self.weight_down_embed.weight, std=0.02)
-
-        # pass self.weight_down_embed to every MiniMindBlock
-        self.layers = nn.ModuleList([MiniMindBlock(l, params, self.weight_down_embed) for l in range(self.n_layers)])
+        # the old weight_down_embed declaration is removed
+        self.extract_db = ExtractDB(self.params)
+
+        self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.tok_embeddings.weight = self.output.weight
+        self.downsample_v = nn.Sequential(
+            nn.Conv1d(511 * 8, 128 * 8, kernel_size=1, padding='same'),
+            nn.Conv1d(128 * 8, 128, kernel_size=1, padding='same'),
+            nn.Conv1d(128, 8, kernel_size=1, padding='same')
+        )
+        self.downsample_q = nn.Sequential(
+            nn.Conv1d(511 * 8, 128 * 8, kernel_size=1, padding='same'),
+            nn.Conv1d(128 * 8, 512, kernel_size=1, padding='same')
+        )
         self.register_buffer("pos_cis",
                              precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                              persistent=False)
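The two Conv1d stacks use kernel_size=1 over the 768-wide feature axis and hard-code their channel counts: 511 * 8 input channels assume 8 layers and 511-token sequences (presumably max_seq_len of 512 minus the one-token shift used in pretraining; that rationale is an assumption, as the diff does not show the dataloader). Note that downsample_v ends at 8 channels, so each sample flattens to 8 * 768 = 6144 values, matching ExtractDB's knowledge_dim = 8 * params.dim. A quick shape-check sketch:

    import torch
    from torch import nn

    downsample_v = nn.Sequential(
        nn.Conv1d(511 * 8, 128 * 8, kernel_size=1, padding='same'),
        nn.Conv1d(128 * 8, 128, kernel_size=1, padding='same'),
        nn.Conv1d(128, 8, kernel_size=1, padding='same'),
    )
    h_tensor = torch.randn(2, 511 * 8, 768)  # [batch, n_layers * seq_len, dim]
    z_v = downsample_v(h_tensor)
    print(z_v.shape)                         # torch.Size([2, 8, 768]); 8 * 768 = knowledge_dim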
@@ -429,13 +541,29 @@ class MiniMindLM(PreTrainedModel):
         h = self.dropout(self.tok_embeddings(input_ids))
         pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
         past_kvs = []
+        h_list = []
 
         for l, layer in enumerate(self.layers):
+            index = self.extract_db.q_to_k(h)
+            db_value = self.extract_db.get_data(index)
             h, past_kv = layer(
-                h, pos_cis,
+                h, db_value, pos_cis,
                 past_key_value=past_key_values[l],
                 use_cache=use_cache
             )
 
             past_kvs.append(past_kv)
+            h_list.append(h.unsqueeze(0))
+
+        h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3)
+        h_tensor = h_tensor.reshape(h_tensor.shape[0], -1, 768)
+        z_v = self.downsample_v(h_tensor)
+        z_q = self.downsample_q(h_tensor)
+
+        z_k = self.extract_db.q_to_k(z_q)
+        # update the database
+        self.extract_db.updata_value(z_k, z_v)
+        # q, v = f(h_list)
 
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.output(self.norm(h)[:, slice_indices, :])
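Taken together, every layer step is now a retrieve-then-attend pass against the shared bank, and the tail of forward compresses the stacked hidden states into a write-back. The cycle can be exercised on ExtractDB in isolation; a sketch assuming dim=768 and n_heads=8 (the `SimpleNamespace` stands in for the params object, and the `model.model` import path is assumed):

    import torch
    from types import SimpleNamespace
    from model.model import ExtractDB

    db = ExtractDB(SimpleNamespace(dim=768, n_heads=8))

    h = torch.randn(2, 511, 768)   # hidden states entering a layer
    index = db.q_to_k(h)           # product-key lookup: one flat expert index per sample
    db_value = db.get_data(index)  # [2, 8, 768]: one knowledge_dim row per sample, chunked to dim
    db.updata_value(index, torch.randn(2, 8, 768))  # write new values into the CPU bank

Note that get_data relies on the batch size cached by the preceding q_to_k call, so the two must be called in that order, exactly as the forward loop above does.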
@@ -138,7 +138,7 @@ if __name__ == "__main__":
     parser.add_argument("--learning_rate", type=float, default=5e-4)
     parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")  # use the GPU if available, otherwise the CPU
     parser.add_argument("--dtype", type=str, default="bfloat16")
-    parser.add_argument("--use_wandb", default=True, action="store_true")
+    parser.add_argument("--use_wandb", default=False, action="store_true")
     parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
     parser.add_argument("--num_workers", type=int, default=8)
     parser.add_argument("--ddp", action="store_true")
@@ -148,9 +148,9 @@ if __name__ == "__main__":
     parser.add_argument("--log_interval", type=int, default=100)  # logging interval, controls how often logs are printed
     parser.add_argument("--save_interval", type=int, default=100)  # checkpoint interval, controls how often the model is saved
     parser.add_argument('--local_rank', type=int, default=-1)  # local process rank for distributed training
-    parser.add_argument('--dim', default=1024, type=int)  # model width, controls the model size
-    parser.add_argument('--n_layers', default=24, type=int)  # number of layers
-    parser.add_argument('--max_seq_len', default=1024, type=int)  # maximum input sequence length
+    parser.add_argument('--dim', default=768, type=int)  # model width, controls the model size
+    parser.add_argument('--n_layers', default=8, type=int)  # number of layers
+    parser.add_argument('--max_seq_len', default=512, type=int)  # maximum input sequence length
     parser.add_argument('--use_moe', default=False, type=bool)  # whether to use MoE
     parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")  # path to the dataset
     args = parser.parse_args()