From b6bd97aaaabc6eba54acdc0a6189d64324e99a70 Mon Sep 17 00:00:00 2001 From: iomgaa Date: Fri, 9 May 2025 15:01:06 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8A=BD=E5=8F=96self.downsample=5Fv=E4=B8=8Es?= =?UTF-8?q?elf.downsample=5Fq=E7=9A=84=E5=85=B1=E5=90=8C=E9=83=A8=E5=88=86?= =?UTF-8?q?=EF=BC=8C=E5=B9=B6=E4=BD=BF=E7=94=A8=E5=8F=AF=E5=88=86=E7=A6=BB?= =?UTF-8?q?=E5=8D=B7=E7=A7=AF=E9=99=8D=E4=BD=8E=E5=8F=82=E6=95=B0=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/model.py | 51 ++++++++++++++++++++++++++++++++--------------- train_pretrain.py | 6 +++--- 2 files changed, 38 insertions(+), 19 deletions(-) diff --git a/model/model.py b/model/model.py index 3dc48b2..878c675 100644 --- a/model/model.py +++ b/model/model.py @@ -174,12 +174,12 @@ class CrossAttention(nn.Module): super().__init__() self.config = config self.num_heads = 8 - self.head_dim = 768 // self.num_heads - self.to_q = nn.Linear(768, 768, bias=False) - self.to_k = nn.Linear(768, 768, bias=False) - self.to_v = nn.Linear(768, 768, bias=False) + self.head_dim = self.config.dim // self.num_heads + self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False) + self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False) + self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False) - self.to_out = nn.Linear(768, 768, bias=False) + self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False) def forward(self, x, db, context_mask=None, pos_emb=None): batch_size = x.size(0) @@ -205,7 +205,7 @@ class CrossAttention(nn.Module): context = torch.matmul(attn_weights, v) - context = context.transpose(1, 2).contiguous().view(batch_size, -1, 768) + context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim) context = self.to_out(context) @@ -523,19 +523,36 @@ class MiniMindLM(PreTrainedModel): self.norm = RMSNorm(params.dim, eps=params.norm_eps) self.output = nn.Linear(params.dim, params.vocab_size, bias=False) self.tok_embeddings.weight = self.output.weight - self.downsample_v = nn.Sequential( - nn.Conv1d(511*8,128*8,kernel_size=1,padding='same'), - nn.Conv1d(128*8,128,kernel_size=1,padding='same'), - nn.Conv1d(128,8,kernel_size=1,padding='same') + + # Calculate input dimension + input_dim = (self.params.max_seq_len-1)*self.params.n_layers + # Use a bottleneck architecture to reduce parameters + bottleneck_dim = 256 # Significantly smaller bottleneck dimension + + # Factorized shared downsampling using two smaller convolutions + self.shared_downsample = nn.Sequential( + # First reduce input dimension to bottleneck + nn.Conv1d(input_dim, bottleneck_dim, kernel_size=1, padding='same'), + nn.ReLU(), # Non-linearity to improve representation capacity + # Then expand to target dimension + nn.Conv1d(bottleneck_dim, 128*8, kernel_size=1, padding='same') ) - self.downsample_q = nn.Sequential( - nn.Conv1d(511*8,128*8,kernel_size=1,padding='same'), - nn.Conv1d(128*8,512,kernel_size=1,padding='same') + + # Specific layers for v path + self.downsample_v_specific = nn.Sequential( + nn.Conv1d(128*8, 128, kernel_size=1, padding='same'), + nn.Conv1d(128, 8, kernel_size=1, padding='same') + ) + + # Specific layers for q path + self.downsample_q_specific = nn.Sequential( + nn.Conv1d(128*8, 512, kernel_size=1, padding='same') ) self.register_buffer("pos_cis", precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta), persistent=False) self.OUT = CausalLMOutputWithPast() + self.params = params def forward(self, input_ids: Optional[torch.Tensor] = None, @@ -565,12 +582,14 @@ class MiniMindLM(PreTrainedModel): # 使用detach()分离计算图,避免多次反向传播 h_tensor = torch.cat(h_list,dim=0).permute(1,0,2,3) h_tensor_detached = h_tensor.detach() - h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0],-1,768) + h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0],-1,self.params.dim) # 数据库更新逻辑与主计算图分离 with torch.no_grad(): - z_v = self.downsample_v(h_tensor_detached) - z_q = self.downsample_q(h_tensor_detached) + # Compute shared downsampling layer once + shared_features = self.shared_downsample(h_tensor_detached) + z_v = self.downsample_v_specific(shared_features) + z_q = self.downsample_q_specific(shared_features) z_k = self.extract_db.q_to_k(z_q) self.extract_db.updata_value(z_k,z_v) diff --git a/train_pretrain.py b/train_pretrain.py index 1c9995b..92311fa 100644 --- a/train_pretrain.py +++ b/train_pretrain.py @@ -179,7 +179,7 @@ if __name__ == "__main__": parser.add_argument("--out_dir", type=str, default="out") # 若要以最快速度实现zero则epochs设置为1轮;否则应当利用有限的数据训练2~6个epochs。 parser.add_argument("--epochs", type=int, default=3) - parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--learning_rate", type=float, default=5e-4) parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") #如果GPU可用,则使用GPU,否则使用CPU。 parser.add_argument("--dtype", type=str, default="bfloat16") @@ -193,8 +193,8 @@ if __name__ == "__main__": parser.add_argument("--log_interval", type=int, default=100) #日志打印间隔,用于控制日志打印的频率。 parser.add_argument("--save_interval", type=int, default=100) #模型保存间隔,用于控制模型保存的频率。 parser.add_argument('--local_rank', type=int, default=-1) #本地进程编号,用于分布式训练。 - parser.add_argument('--dim', default=768, type=int) #模型维度,用于控制模型的大小。 - parser.add_argument('--n_layers', default=8, type=int) #层数,用于控制模型层数。 + parser.add_argument('--dim', default=592, type=int) #模型维度,用于控制模型的大小。 + parser.add_argument('--n_layers', default=4, type=int) #层数,用于控制模型层数。 parser.add_argument('--max_seq_len', default=512, type=int) #最大序列长度,用于控制输入序列的最大长度。 parser.add_argument('--use_moe', default=False, type=bool) #是否使用MOE,用于控制是否使用MOE。 parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl") #数据路径,用于控制数据集的路径。