diff --git a/model/LMConfig.py b/model/LMConfig.py
index e9cc3f5..f5e0012 100644
--- a/model/LMConfig.py
+++ b/model/LMConfig.py
@@ -19,6 +19,7 @@ class LMConfig(PretrainedConfig):
             rope_theta: int = 1e6,
             dropout: float = 0.0,
             flash_attn: bool = True,
+            embeddings_epoch: int = 2,
             ####################################################
             # DB related configurations
             ####################################################
@@ -54,6 +55,7 @@ class LMConfig(PretrainedConfig):
         self.rope_theta = rope_theta
         self.dropout = dropout
         self.flash_attn = flash_attn
+        self.embeddings_epoch = embeddings_epoch
         ####################################################
         # DB related configurations
         ####################################################
diff --git a/model/model.py b/model/model.py
index 625be9a..c94fd2c 100644
--- a/model/model.py
+++ b/model/model.py
@@ -81,6 +81,8 @@ class KnowledgeDataset(nn.Module):
         # Count the number of steps, used to dynamically adjust the weights
         self.step_counter = 0
 
+        self.freeze_embedding = False
+
     def intelligent_selection(self, query, all_scores, all_indices):
 
@@ -169,6 +171,8 @@
         return all_best_tokens, all_best_tokens_embeddings
 
     def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
+        if self.freeze_embedding:
+            return
         # Update self.keys using pre_update_embeddings
         with torch.no_grad():
             pre_update_embeddings = pre_update_embeddings.mean(dim=1)  # [337, 512]
@@ -199,8 +203,26 @@
 
         if self.is_train:
             # Get the indices of keys that have not been updated yet
             not_updated_indices = torch.where(self.has_update_keys == 0)[0]
+            # If there are un-updated keys, randomly pick num_update_keys of them to update
             if len(not_updated_indices) > 0:
+                num_update_keys = int(self.knowledge_num * 0.01)
+                perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
+                perm_num = perm.shape[0]
+                pre_update_indices = not_updated_indices[perm]
+                pre_update_tokens = self.knowledge_dataset[pre_update_indices]
+                pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
+                pre_update_embeddings = pre_update_embeddings.view(perm_num, self.knowledge_length, -1)
+                self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
+                # Mark the modified keys as updated
+                with torch.no_grad():
+                    self.has_update_keys[pre_update_indices] = 1
+            else:
+                print("all keys are updated")
+                # Reset the update status of all keys
+                self.has_update_keys.zero_()
+                # Re-fetch all updatable indices
+                not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)
                 num_update_keys = int(self.knowledge_num * 0.01)
                 perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
                 pre_update_indices = not_updated_indices[perm]
@@ -208,6 +230,12 @@
                 pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
                 pre_update_embeddings = pre_update_embeddings.view(num_update_keys, self.knowledge_length, -1)
                 self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
+                # Mark the modified keys as updated
+                with torch.no_grad():
+                    self.has_update_keys[pre_update_indices] = 1
+
+
 
         return best_tokens, best_tokens_embeddings
 
@@ -484,12 +512,20 @@ class MiniMindLM(PreTrainedModel):
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
         self.OUT = CausalLMOutputWithPast()
+        self.freeze_embedding = False
 
     def forward(self,
                 input_ids: Optional[torch.Tensor] = None,
                 logits_to_keep: Union[int, torch.Tensor] = 0,
+                step: int = 0,
                 **args):
         start_pos = args.get('start_pos', 0)
+        if self.freeze_embedding and step == 0:
+            self.tok_embeddings.weight.requires_grad = False
+            # Also freeze the KnowledgeDataset embedding updates
+            self.knowledge_dataset.freeze_embedding = True
+            print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
+            print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding)
         h = self.dropout(self.tok_embeddings(input_ids))
         pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
         for l, layer in enumerate(self.layers):
diff --git a/run_file/DynamicKV-LLM_Mini_Minimind.sh b/run_file/DynamicKV-LLM_Mini_Minimind.sh
index 1ca761f..45fe379 100644
--- a/run_file/DynamicKV-LLM_Mini_Minimind.sh
+++ b/run_file/DynamicKV-LLM_Mini_Minimind.sh
@@ -1,8 +1,8 @@
 #!/bin/bash
 
 # Activate the conda environment
-# source $(conda info --base)/etc/profile.d/conda.sh
-# conda activate ycz_accelerate
+source $(conda info --base)/etc/profile.d/conda.sh
+conda activate mini
 
 # Set environment variables to help with debugging
 export NCCL_DEBUG=INFO
@@ -26,7 +26,7 @@ export PYTHONFAULTHANDLER=1
 # --profile_interval 10
 
 # Method 2: configure accelerate directly via command-line arguments
-CUDA_VISIBLE_DEVICES=0 accelerate launch \
+CUDA_VISIBLE_DEVICES=0 /opt/conda/envs/mini/bin/python -m accelerate.commands.launch \
     --num_processes=1 \
     --mixed_precision=bf16 \
     --main_process_port=29500 \
diff --git a/train_pretrain_accelerate.py b/train_pretrain_accelerate.py
index 5606ad8..e00e6a4 100644
--- a/train_pretrain_accelerate.py
+++ b/train_pretrain_accelerate.py
@@ -224,6 +224,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
     total_steps_in_epoch = len(train_loader)
     total_training_steps = args.epochs * total_steps_in_epoch
     moe_path = '_moe' if args.use_moe else ''
+    best_loss = float('10000')
 
     # Add CUDA events for performance profiling (main process only)
     if args.profile and accelerator.is_main_process:
@@ -287,7 +288,12 @@
         # Forward pass
         with ctx:
-            res = model(X)
+            if step == 0 and args.embedding_epoch == epoch:
+                # freeze_embedding must be set on the original model, not the wrapped one
+                unwrapped_model = accelerator.unwrap_model(model)
+                unwrapped_model.freeze_embedding = True
+                Logger(f"Set freeze_embedding=True for epoch {epoch}, step {step}", accelerator)
+            res = model(X, step=step)
             loss = loss_fct(
                 res.logits.view(-1, res.logits.size(-1)),
                 Y.view(-1)
@@ -411,7 +417,9 @@
                 wandb.log(log_dict)
 
         # Save the model (main process only)
-        if (step + 1) % args.save_interval == 0 and accelerator.is_main_process:
+        loss_total = loss.item() * args.accumulation_steps
+        if best_loss > loss_total and accelerator.is_main_process:
+            best_loss = loss_total
             # Use the moe_path variable defined at the start of the function
             ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
 
@@ -431,6 +439,7 @@ def main():
     parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
     parser.add_argument("--out_dir", type=str, default="out")
     parser.add_argument("--epochs", type=int, default=4)
+    parser.add_argument("--embedding_epoch", type=int, default=2, help="number of epochs to train the embeddings")
     parser.add_argument("--batch_size", type=int, default=64)
     parser.add_argument("--learning_rate", type=float, default=2e-4)
     parser.add_argument("--dtype", type=str, default="bfloat16")
@@ -495,7 +504,8 @@
         disable_db=args.disable_db,
         flash_attn=args.use_flash_attn,
         knowledge_num=args.knowledge_num,
-        knowledge_length=args.knowledge_length
+        knowledge_length=args.knowledge_length,
+        embeddings_epoch=args.embedding_epoch
    )
     #########################################################
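For reference, below is a minimal, self-contained sketch (not part of the commit) of the freeze mechanic this patch introduces: at the first step of --embedding_epoch the trainer flips freeze_embedding on the unwrapped model, the next forward pass stops gradient flow into the token-embedding table, and KnowledgeDataset._update_keys_with_embeddings becomes a no-op. Only the names freeze_embedding, step, and embedding_epoch mirror the diff; TinyLM, its shapes, and the toy training loop are illustrative stand-ins.

import torch
import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size=100, dim=16):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)
        self.freeze_embedding = False  # the trainer flips this at embedding_epoch

    def forward(self, input_ids, step=0):
        if self.freeze_embedding and step == 0:
            # same gating as MiniMindLM.forward in the diff above
            self.tok_embeddings.weight.requires_grad = False
        return self.head(self.tok_embeddings(input_ids))

model = TinyLM()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
embedding_epoch = 2  # mirrors --embedding_epoch
for epoch in range(4):
    for step in range(3):
        if step == 0 and epoch == embedding_epoch:
            model.freeze_embedding = True  # what train_epoch does via unwrap_model
        logits = model(torch.randint(0, 100, (2, 5)), step=step)
        loss = logits.mean()
        optimizer.zero_grad()  # set_to_none=True by default
        loss.backward()
        optimizer.step()
    # once frozen, backward() no longer repopulates the embedding grad,
    # so it stays None after zero_grad and AdamW skips the parameter
    print(epoch, "embedding grad exists:", model.tok_embeddings.weight.grad is not None)

The unwrap step in train_epoch matters because accelerate wraps the module (e.g. in DDP); an attribute written on the wrapper would never reach the module whose forward reads the flag.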