DynamicKV-LLM Pretrain v1.2.1

iomgaa 2025-06-08 02:20:36 +00:00
parent 1678e739b6
commit 770c34f0e3
4 changed files with 54 additions and 6 deletions

View File

@@ -19,6 +19,7 @@ class LMConfig(PretrainedConfig):
            rope_theta: float = 1e6,
            dropout: float = 0.0,
            flash_attn: bool = True,
            embeddings_epoch: int = 2,
            ####################################################
            # DB related configurations
            ####################################################
@@ -54,6 +55,7 @@ class LMConfig(PretrainedConfig):
        self.rope_theta = rope_theta
        self.dropout = dropout
        self.flash_attn = flash_attn
        self.embeddings_epoch = embeddings_epoch
        ####################################################
        # DB related configurations
        ####################################################
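The new field simply rides the existing config plumbing: it arrives as a keyword argument and is stored on the instance. A self-contained sketch of that flow (an illustrative stand-in, not the repo class, with the other fields elided):

from transformers import PretrainedConfig

# Minimal stand-in for LMConfig showing how embeddings_epoch is threaded through.
class LMConfig(PretrainedConfig):
    def __init__(self, flash_attn: bool = True, embeddings_epoch: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.flash_attn = flash_attn
        self.embeddings_epoch = embeddings_epoch

config = LMConfig(embeddings_epoch=2)
print(config.embeddings_epoch)  # 2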

View File

@@ -81,6 +81,8 @@ class KnowledgeDataset(nn.Module):
        # Count training steps, used to dynamically adjust weights
        self.step_counter = 0
        self.freeze_embedding = False

    def intelligent_selection(self, query, all_scores, all_indices):
@@ -169,6 +171,8 @@ class KnowledgeDataset(nn.Module):
        return all_best_tokens, all_best_tokens_embeddings

    def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
        if self.freeze_embedding:
            return  # skip key refreshes entirely while embeddings are frozen
        # Use pre_update_embeddings to refresh self.keys
        with torch.no_grad():
            pre_update_embeddings = pre_update_embeddings.mean(dim=1)  # [num_update_keys, dim], e.g. [337, 512]
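For intuition, the mean above collapses each selected entry's per-token embeddings into a single key vector by averaging over the token axis; in isolation (shapes illustrative, matching the [337, 512] annotation):

import torch

pre_update_embeddings = torch.randn(337, 16, 512)  # [num_update_keys, knowledge_length, dim]
keys = pre_update_embeddings.mean(dim=1)           # [337, 512]: one pooled key vector per entry
print(keys.shape)  # torch.Size([337, 512])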
@@ -199,8 +203,26 @@ class KnowledgeDataset(nn.Module):
        if self.is_train:
            # Get the indices of keys that have not been updated yet
            not_updated_indices = torch.where(self.has_update_keys == 0)[0]
            # If there are un-updated keys, randomly pick num_update_keys of them to update
            if len(not_updated_indices) > 0:
                num_update_keys = int(self.knowledge_num * 0.01)
                perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
                perm_num = perm.shape[0]
                pre_update_indices = not_updated_indices[perm]
                pre_update_tokens = self.knowledge_dataset[pre_update_indices]
                pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
                pre_update_embeddings = pre_update_embeddings.view(perm_num, self.knowledge_length, -1)
                self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
                # Mark the keys that were just refreshed as updated
                with torch.no_grad():
                    self.has_update_keys[pre_update_indices] = 1
            else:
                print("all keys are updated")
                # Reset the update status of all keys
                self.has_update_keys.zero_()
                # Re-fetch the full set of updatable indices
                not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)
                num_update_keys = int(self.knowledge_num * 0.01)
                perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
                pre_update_indices = not_updated_indices[perm]
@@ -208,6 +230,12 @@ class KnowledgeDataset(nn.Module):
                pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
                pre_update_embeddings = pre_update_embeddings.view(num_update_keys, self.knowledge_length, -1)
                self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
                # Mark the keys that were just refreshed as updated
                with torch.no_grad():
                    self.has_update_keys[pre_update_indices] = 1

        return best_tokens, best_tokens_embeddings
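Taken together, the two branches implement a cyclic refresh: each training step re-embeds a random 1% of the keys not yet touched in the current pass, and once every key has been visited the bookkeeping resets and a new pass begins. A self-contained sketch of that selection logic (tensor names from the diff, sizes illustrative):

import torch

knowledge_num = 1000
has_update_keys = torch.zeros(knowledge_num)

def pick_refresh_indices():
    # Prefer keys that have not been refreshed in the current pass.
    not_updated = torch.where(has_update_keys == 0)[0]
    if len(not_updated) == 0:
        # Every key has been refreshed once: reset and start a new pass.
        has_update_keys.zero_()
        not_updated = torch.arange(knowledge_num)
    num_update = int(knowledge_num * 0.01)  # refresh 1% of keys per step
    picked = not_updated[torch.randperm(len(not_updated))[:num_update]]
    has_update_keys[picked] = 1
    return picked

print(pick_refresh_indices().shape)  # torch.Size([10])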
@@ -484,12 +512,20 @@ class MiniMindLM(PreTrainedModel):
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
        self.OUT = CausalLMOutputWithPast()
        self.freeze_embedding = False

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                step: int = 0,
                **args):
        start_pos = args.get('start_pos', 0)
        if self.freeze_embedding and step == 0:
            self.tok_embeddings.weight.requires_grad = False
            # Also freeze the embedding updates inside KnowledgeDataset
            self.knowledge_dataset.freeze_embedding = True
            print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
            print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        for l, layer in enumerate(self.layers):
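The freeze itself is just requires_grad bookkeeping; a minimal demonstration (not the repo model) that a frozen embedding table stops receiving gradients while downstream layers keep training:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)
head = nn.Linear(4, 2)
emb.weight.requires_grad = False  # what forward() does once freeze_embedding is set

loss = head(emb(torch.tensor([1, 2, 3]))).sum()
loss.backward()
print(emb.weight.grad)               # None: the frozen table receives no gradient
print(head.weight.grad is not None)  # True: the rest of the network still trains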

View File

@@ -1,8 +1,8 @@
#!/bin/bash
# Activate the conda environment
# source $(conda info --base)/etc/profile.d/conda.sh
# conda activate ycz_accelerate
source $(conda info --base)/etc/profile.d/conda.sh
conda activate mini
# Set environment variables to aid debugging
export NCCL_DEBUG=INFO
@@ -26,7 +26,7 @@ export PYTHONFAULTHANDLER=1
#   --profile_interval 10
# Method 2: configure accelerate directly via command-line arguments
CUDA_VISIBLE_DEVICES=0 accelerate launch \
CUDA_VISIBLE_DEVICES=0 /opt/conda/envs/mini/bin/python -m accelerate.commands.launch \
    --num_processes=1 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \

View File

@@ -224,6 +224,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
    total_steps_in_epoch = len(train_loader)
    total_training_steps = args.epochs * total_steps_in_epoch
    moe_path = '_moe' if args.use_moe else ''
    best_loss = float('inf')  # track the lowest loss seen so far for checkpointing

    # Add CUDA events for performance profiling (main process only)
    if args.profile and accelerator.is_main_process:
@@ -287,7 +288,12 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
            # Forward pass
            with ctx:
                res = model(X)
                if step == 0 and args.embedding_epoch == epoch:
                    # Set freeze_embedding on the original (unwrapped) model, not the accelerate wrapper
                    unwrapped_model = accelerator.unwrap_model(model)
                    unwrapped_model.freeze_embedding = True
                    Logger(f"Set freeze_embedding=True for epoch {epoch}, step {step}", accelerator)
                res = model(X, step=step)
                loss = loss_fct(
                    res.logits.view(-1, res.logits.size(-1)),
                    Y.view(-1)
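The loss call follows the standard token-level pattern: flatten the batch and sequence axes so cross-entropy sees one prediction per token. In isolation (shapes illustrative):

import torch
import torch.nn as nn

loss_fct = nn.CrossEntropyLoss()
logits = torch.randn(2, 5, 100)    # [batch, seq_len, vocab_size]
Y = torch.randint(0, 100, (2, 5))  # [batch, seq_len] target token ids
loss = loss_fct(logits.view(-1, logits.size(-1)), Y.view(-1))  # [batch*seq, vocab] vs [batch*seq]
print(loss.item())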
@@ -411,7 +417,9 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
                wandb.log(log_dict)

            # Save the model (main process only)
            if (step + 1) % args.save_interval == 0 and accelerator.is_main_process:
                loss_total = loss.item() * args.accumulation_steps
                if best_loss > loss_total and accelerator.is_main_process:
                    best_loss = loss_total
                    # Use the moe_path variable defined at the start of the function
                    ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
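The new guard writes a checkpoint only when the accumulation-corrected loss beats the best value seen so far; the gate in isolation (save callback hypothetical):

best_loss = float('inf')

def maybe_checkpoint(loss_item, accumulation_steps, save_fn):
    global best_loss
    loss_total = loss_item * accumulation_steps  # undo the per-step accumulation scaling
    if loss_total < best_loss:
        best_loss = loss_total
        save_fn()

maybe_checkpoint(0.9, 8, lambda: print("saved"))  # 7.2 < inf  -> saves
maybe_checkpoint(1.0, 8, lambda: print("saved"))  # 8.0 >= 7.2 -> skipped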
@@ -431,6 +439,7 @@ def main():
    parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
    parser.add_argument("--out_dir", type=str, default="out")
    parser.add_argument("--epochs", type=int, default=4)
    parser.add_argument("--embedding_epoch", type=int, default=2, help="number of epochs to train embeddings before freezing them")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--dtype", type=str, default="bfloat16")
@@ -495,7 +504,8 @@ def main():
        disable_db=args.disable_db,
        flash_attn=args.use_flash_attn,
        knowledge_num=args.knowledge_num,
        knowledge_length=args.knowledge_length
        knowledge_length=args.knowledge_length,
        embeddings_epoch=args.embedding_epoch
    )
#########################################################