diff --git a/model/model.py b/model/model.py index 2873a77..b9fdf14 100644 --- a/model/model.py +++ b/model/model.py @@ -116,13 +116,13 @@ class Attention(nn.Module): repeat_kv(xk, self.n_rep).transpose(1, 2), repeat_kv(xv, self.n_rep).transpose(1, 2) ) - + # 如果提供了db_value,根据头的数量调整它的形状并与xv合并 if db_value is not None: # 确保db_value的形状与xv兼容,假设db_value形状为[B, N, H, D] if db_value.ndim == 4: # [B, N, H, D] db_value = db_value.transpose(1, 2) # -> [B, H, N, D] - + # 检查是否需要调整D维度 if db_value.shape[-1] != xv.shape[-1]: # 如果db_value的维度与xv不同,可以添加一个投影层 @@ -138,11 +138,11 @@ class Attention(nn.Module): factor = xv.shape[-1] // db_value.shape[-1] db_value = db_value.unsqueeze(-1).repeat(1, 1, 1, 1, factor) db_value = db_value.view(bsz, self.n_local_heads, seq_len, xv.shape[-1]) - + # 将db_value与xv相加或融合 # 这里我们简单地将它们相加,但你也可以使用其他融合方法 xv = xv + db_value - + # 使用Flash Attention if self.flash and seq_len != 1: dropout_p = self.dropout if self.training else 0.0 @@ -173,42 +173,42 @@ class CrossAttention(nn.Module): ): super().__init__() self.config = config - self.num_heads = 8 - self.head_dim = self.config.dim // self.num_heads + self.num_heads = 8 + self.head_dim = self.config.dim // self.num_heads self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False) self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False) self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False) - + self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False) - + def forward(self, x, db, context_mask=None, pos_emb=None): batch_size = x.size(0) - + # 分离多头 q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) - + if pos_emb is not None: pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) q = q + pos_emb k = k + pos_emb v = v + pos_emb - + attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) - + if context_mask is not None: expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1) attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10) - + attn_weights = F.softmax(attn_scores, dim=-1) context = torch.matmul(attn_weights, v) - + context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim) context = self.to_out(context) - + return context class FeedForward(nn.Module): @@ -350,25 +350,25 @@ class MiniMindBlock(nn.Module): self.head_dim = config.dim // config.n_heads self.attention = Attention(config) self.cross_att = CrossAttention(config) - + self.layer_id = layer_id self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps) self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps) self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config) - + # 假设num_experts是已定义的总专家数量的平方根 - - + + # 查询生成的参数 - - + + # 创建查询生成模块 # if weight_down_embed is not None: # self.to_queries = nn.Sequential( # nn.Linear(config.dim, self.dim_key * 2, bias=False), # # nn.Unflatten(2, (2, self.n_heads, self.dim_key)) # 替代Rearrange # ) - + # # 超参数 # self.product_key_topk = min(16, self.num_keys) # 确保不超过num_keys # self.num_experts_per_head_topk = 1 # 最终每个头选取的专家数 @@ -376,47 +376,47 @@ class MiniMindBlock(nn.Module): def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=True): # import pdb;pdb.set_trace() # db_value = None - + # # 如果有weight_down_embed,使用Product Key机制 # if self.weight_down_embed is not None: # # 1. 生成queries # batch_size, seq_len, dim = x.shape - + # # collapse sequence dimension by averaging # x_flat = x.mean(dim=1) # [batch_size, dim] # queries = self.to_queries(x_flat) # [batch_size, 2*dim_key] # queries = queries.reshape(batch_size, 2, self.dim_key) # [batch_size, 2, dim_key] # queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key] - + # # 2. 计算queries与keys的相似度 # sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys) - + # # 3. 在两个子空间分别做top-k # scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)] # scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0] # indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1] - + # # 4. 组合两个子空间的分数和索引 # all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # all_scores = all_scores.view(*all_scores.shape[:-2], -1) - + # all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # all_indices = all_indices.view(*all_indices.shape[:-2], -1) - + # # 5. 最终top-k选择 # scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1) # indices = all_indices.gather(-1, pk_indices) - + # # 6. 从embedding中获取专家值 - + # # 从embedding中获取值 # flat_indices = indices.view(-1) # 将索引展平为一维张量 # db_values = self.weight_down_embed(flat_indices) - + # # 重塑回原始形状 # db_value = db_values.view(batch_size, -1, dim) - - + + # 注意力计算 h_attn, past_kv = self.attention( self.attention_norm(x), @@ -428,7 +428,7 @@ class MiniMindBlock(nn.Module): h_attn = self.cross_att(h_attn, db_value) - # 残差连接 + # 残差连接 h = x + h_attn # 前馈神经网络 @@ -441,15 +441,15 @@ class ExtractDB(nn.Module): super().__init__() self.batch_size = None self.dim = params.dim - self.dim_key = self.dim // 2 + self.dim_key = self.dim // 2 self.num_experts = 10 * 10 # 100专家,确保是完全平方数 # 将knowledge_dim设置为与head_dim相同,以便在attention中直接使用 self.head_dim = params.dim // params.n_heads self.knowledge_dim = 8*params.dim - + # 使用register_buffer代替nn.Parameter,避免梯度问题 self.register_buffer('weight_down_embed', torch.randn(self.num_experts, self.knowledge_dim) * 0.02) - + self.num_keys = int(math.sqrt(self.num_experts)) if self.num_experts > 0 else 0 self.product_key_topk = min(16, self.num_keys) self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02) @@ -457,45 +457,45 @@ class ExtractDB(nn.Module): self.to_queries = nn.Sequential( nn.Linear(params.dim, self.dim_key * 2, bias=False), ) - + def q_to_k(self,x): # 1. 生成queries self.batch_size, seq_len, dim = x.shape - + # collapse sequence dimension by averaging x_flat = x.mean(dim=1) # [batch_size, dim] queries = self.to_queries(x_flat) # [batch_size, 2*dim_key] queries = queries.reshape(self.batch_size, 2, self.dim_key) # [batch_size, 2, dim_key] queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key] - + # 2. 计算queries与keys的相似度 sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys) - + # 3. 在两个子空间分别做top-k scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)] scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0] indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1] - + # 4. 组合两个子空间的分数和索引 all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) all_scores = all_scores.view(*all_scores.shape[:-2], -1) - + all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) all_indices = all_indices.view(*all_indices.shape[:-2], -1) - + # 5. 最终top-k选择 scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1) indices = all_indices.gather(-1, pk_indices) flat_indices = indices.view(-1) return flat_indices - + def get_data(self, index): # 直接从GPU获取embedding db_values = self.weight_down_embed[index] db_value = db_values.view(self.batch_size, -1, self.dim) return db_value - + @torch.no_grad() def updata_value(self, k, v): # 直接更新buffer上的值 (不需要梯度) @@ -504,7 +504,7 @@ class ExtractDB(nn.Module): v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype) self.weight_down_embed[k] = v_reshaped - + class MiniMindLM(PreTrainedModel): config_class = LMConfig @@ -523,12 +523,12 @@ class MiniMindLM(PreTrainedModel): self.norm = RMSNorm(params.dim, eps=params.norm_eps) self.output = nn.Linear(params.dim, params.vocab_size, bias=False) self.tok_embeddings.weight = self.output.weight - + # Calculate input dimension input_dim = (self.params.max_seq_len-1)*self.params.n_layers # Use a bottleneck architecture to reduce parameters bottleneck_dim = 256 # Significantly smaller bottleneck dimension - + # Factorized shared downsampling using two smaller convolutions self.shared_downsample = nn.Sequential( # First reduce input dimension to bottleneck @@ -537,13 +537,13 @@ class MiniMindLM(PreTrainedModel): # Then expand to target dimension nn.Conv1d(bottleneck_dim, 128*8, kernel_size=1, padding='same') ) - + # Specific layers for v path self.downsample_v_specific = nn.Sequential( nn.Conv1d(128*8, 128, kernel_size=1, padding='same'), nn.Conv1d(128, 8, kernel_size=1, padding='same') ) - + # Specific layers for q path self.downsample_q_specific = nn.Sequential( nn.Conv1d(128*8, 512, kernel_size=1, padding='same') @@ -551,7 +551,6 @@ class MiniMindLM(PreTrainedModel): self.register_buffer("pos_cis", precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta), persistent=False) - self.OUT = CausalLMOutputWithPast() self.params = params def forward(self, @@ -572,13 +571,13 @@ class MiniMindLM(PreTrainedModel): if self.params.disable_db: # 创建一个形状为[batch_size, n_layers, dim]的tensor,所有元素值为1e-4 batch_size = h.size(0) - db_value = torch.full((batch_size, self.n_layers, self.params.dim), 1e-4, + db_value = torch.full((batch_size, self.n_layers, self.params.dim), 1e-4, dtype=h.dtype, device=h.device) else: # 正常模式,使用数据库查询 index = self.extract_db.q_to_k(h) db_value = self.extract_db.get_data(index) - + h, past_kv = layer( h, db_value, pos_cis, past_key_value=past_key_values[l], @@ -587,15 +586,15 @@ class MiniMindLM(PreTrainedModel): past_kvs.append(past_kv) h_list.append(h.unsqueeze(0)) - + h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3) - + # 只在非禁用数据库模式下执行数据库更新逻辑 if not self.params.disable_db: # 使用detach()分离计算图,避免多次反向传播 h_tensor_detached = h_tensor.detach() h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0], -1, self.params.dim) - + # 数据库更新逻辑与主计算图分离 with torch.no_grad(): # Compute shared downsampling layer once @@ -604,15 +603,24 @@ class MiniMindLM(PreTrainedModel): z_q = self.downsample_q_specific(shared_features) z_k = self.extract_db.q_to_k(z_q) self.extract_db.updata_value(z_k, z_v) - + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.output(self.norm(h)[:, slice_indices, :]) aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward)) - self.OUT.__setitem__('last_hidden_state', h) - self.OUT.__setitem__('logits', logits) - self.OUT.__setitem__('aux_loss', aux_loss) - self.OUT.__setitem__('past_key_values', past_kvs) - return self.OUT + + # 进一步简化,只保留必要的参数 + output = CausalLMOutputWithPast( + logits=logits, + past_key_values=past_kvs, + ) + + # 尝试添加其他属性(如果支持的话) + try: + output.hidden_states = h + except: + pass + + return output @torch.inference_mode() def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90, diff --git a/train_pretrain.py b/train_pretrain.py index d2a86fa..da25834 100644 --- a/train_pretrain.py +++ b/train_pretrain.py @@ -40,6 +40,8 @@ def get_lr(current_step, total_steps, lr): def train_epoch(epoch, wandb): loss_fct = nn.CrossEntropyLoss(reduction='none') start_time = time.time() + # 在函数开始处定义moe_path,避免在异常处理中引用未定义变量 + moe_path = '_moe' if lm_config.use_moe else '' for step, (X, Y, loss_mask) in enumerate(train_loader): try: # 将数据加载到设备上 @@ -59,7 +61,20 @@ def train_epoch(epoch, wandb): Y.view(-1) ).view(Y.size()) loss = (loss * loss_mask).sum() / loss_mask.sum() - loss += res.aux_loss + # 添加辅助损失,如果存在的话 + try: + if hasattr(model, 'module'): + # DDP情况 + aux_loss = sum(l.feed_forward.aux_loss for l in model.module.layers + if hasattr(l.feed_forward, 'aux_loss')) + else: + # 非DDP情况 + aux_loss = sum(l.feed_forward.aux_loss for l in model.layers + if hasattr(l.feed_forward, 'aux_loss')) + loss += aux_loss + except Exception as e: + Logger(f"Warning: Could not add auxiliary loss: {e}") + # 如果出错,不添加辅助损失 loss = loss / args.accumulation_steps # Print data types for debugging @@ -106,7 +121,7 @@ def train_epoch(epoch, wandb): # 保存模型 if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0): model.eval() - moe_path = '_moe' if lm_config.use_moe else '' + # 使用函数开始处定义的moe_path变量 ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth' if isinstance(model, torch.nn.parallel.DistributedDataParallel): @@ -122,9 +137,9 @@ def train_epoch(epoch, wandb): save_path = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}_nanERROR.pth' if os.path.exists(save_path): os.remove(save_path) - + if isinstance(model, torch.nn.parallel.DistributedDataParallel): - state_dict = model.module.state_dict() + state_dict = model.module.state_dict() else: state_dict = model.state_dict() torch.save(state_dict, save_path) @@ -132,12 +147,12 @@ def train_epoch(epoch, wandb): for name, param in model.named_parameters(): if param.grad is not None and torch.isnan(param.grad).any(): print(f"NaN gradient in parameter: {name}") - + for name, param in model.named_parameters(): if param.grad is not None and torch.isnan(param.grad).any(): print(f"Parameter {name} values: {param.data}") print(f"Parameter {name} gradients: {param.grad}") - + raise ValueError("NaN gradient detected") @@ -179,7 +194,7 @@ if __name__ == "__main__": parser.add_argument("--out_dir", type=str, default="out") # 若要以最快速度实现zero则epochs设置为1轮;否则应当利用有限的数据训练2~6个epochs。 parser.add_argument("--epochs", type=int, default=3) - parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--batch_size", type=int, default=8) parser.add_argument("--learning_rate", type=float, default=2e-4) parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") #如果GPU可用,则使用GPU,否则使用CPU。 parser.add_argument("--dtype", type=str, default="bfloat16") @@ -193,9 +208,9 @@ if __name__ == "__main__": parser.add_argument("--log_interval", type=int, default=100) #日志打印间隔,用于控制日志打印的频率。 parser.add_argument("--save_interval", type=int, default=100) #模型保存间隔,用于控制模型保存的频率。 parser.add_argument('--local_rank', type=int, default=-1) #本地进程编号,用于分布式训练。 - parser.add_argument('--dim', default=4096, type=int) #模型维度,用于控制模型的大小。 + parser.add_argument('--dim', default=2048, type=int) #模型维度,用于控制模型的大小。 parser.add_argument('--n_layers', default=32, type=int) #层数,用于控制模型层数。 - parser.add_argument('--max_seq_len', default=2048, type=int) #最大序列长度,用于控制输入序列的最大长度。 + parser.add_argument('--max_seq_len', default=1024, type=int) #最大序列长度,用于控制输入序列的最大长度。 parser.add_argument('--use_moe', default=False, type=bool) #是否使用MOE,用于控制是否使用MOE。 parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能,使用固定值1e-4替代") #禁用数据库功能,启用特殊模式 parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl") #数据路径,用于控制数据集的路径。 @@ -203,9 +218,9 @@ if __name__ == "__main__": args = parser.parse_args() lm_config = LMConfig( - dim=args.dim, - n_layers=args.n_layers, - max_seq_len=args.max_seq_len, + dim=args.dim, + n_layers=args.n_layers, + max_seq_len=args.max_seq_len, use_moe=args.use_moe, disable_db=args.disable_db # 添加禁用数据库参数 ) #创建LMConfig对象,用于控制模型配置。 @@ -240,11 +255,11 @@ if __name__ == "__main__": if args.use_wandb and (not ddp or ddp_local_rank == 0): import wandb - + # Merge args and lm_config parameters for wandb config config = vars(args).copy() config.update(lm_config.__dict__) - + wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config) else: wandb = None @@ -267,8 +282,9 @@ if __name__ == "__main__": if ddp: model._ddp_params_and_buffers_to_ignore = {"pos_cis"} - model = DistributedDataParallel(model, device_ids=[ddp_local_rank]) - + # 添加find_unused_parameters=True参数,解决未使用参数的问题 + model = DistributedDataParallel(model, device_ids=[ddp_local_rank], find_unused_parameters=True) + torch.autograd.set_detect_anomaly(True) iter_per_epoch = len(train_loader) for epoch in range(args.epochs):