DynamicKV-LLM Pretrain v1.2.1
This commit is contained in:
parent
1678e739b6
commit
770c34f0e3
@ -19,6 +19,7 @@ class LMConfig(PretrainedConfig):
|
||||
rope_theta: int = 1e6,
|
||||
dropout: float = 0.0,
|
||||
flash_attn: bool = True,
|
||||
embeddings_epoch: int = 2,
|
||||
####################################################
|
||||
# DB related configurations
|
||||
####################################################
|
||||
@ -54,6 +55,7 @@ class LMConfig(PretrainedConfig):
|
||||
self.rope_theta = rope_theta
|
||||
self.dropout = dropout
|
||||
self.flash_attn = flash_attn
|
||||
self.embeddings_epoch = embeddings_epoch
|
||||
####################################################
|
||||
# DB related configurations
|
||||
####################################################
|
||||
|
@ -81,6 +81,8 @@ class KnowledgeDataset(nn.Module):
|
||||
# 计算step数目,用于动态调整权重
|
||||
self.step_counter = 0
|
||||
|
||||
self.freeze_embedding = False
|
||||
|
||||
|
||||
|
||||
def intelligent_selection(self, query, all_scores, all_indices):
|
||||
@ -169,6 +171,8 @@ class KnowledgeDataset(nn.Module):
|
||||
return all_best_tokens, all_best_tokens_embeddings
|
||||
|
||||
def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
|
||||
if self.freeze_embedding:
|
||||
return
|
||||
# 使用pre_update_embeddings更新self.keys
|
||||
with torch.no_grad():
|
||||
pre_update_embeddings = pre_update_embeddings.mean(dim=1) # [337, 512]
|
||||
@ -199,8 +203,26 @@ class KnowledgeDataset(nn.Module):
|
||||
if self.is_train:
|
||||
# 获取未更新过的keys的索引
|
||||
not_updated_indices = torch.where(self.has_update_keys == 0)[0]
|
||||
|
||||
# 如果有未更新的keys,随机选择num_update_keys个进行更新
|
||||
if len(not_updated_indices) > 0:
|
||||
num_update_keys = int(self.knowledge_num * 0.01)
|
||||
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
|
||||
perm_num = perm.shape[0]
|
||||
pre_update_indices = not_updated_indices[perm]
|
||||
pre_update_tokens = self.knowledge_dataset[pre_update_indices]
|
||||
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
|
||||
pre_update_embeddings = pre_update_embeddings.view(perm_num, self.knowledge_length, -1)
|
||||
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
|
||||
# 更新被修改过的key
|
||||
with torch.no_grad():
|
||||
self.has_update_keys[pre_update_indices] = 1
|
||||
else:
|
||||
print("all keys are updated")
|
||||
# 重置所有keys的更新状态
|
||||
self.has_update_keys.zero_()
|
||||
# 重新获取所有可更新的索引
|
||||
not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)
|
||||
num_update_keys = int(self.knowledge_num * 0.01)
|
||||
perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
|
||||
pre_update_indices = not_updated_indices[perm]
|
||||
@ -208,6 +230,12 @@ class KnowledgeDataset(nn.Module):
|
||||
pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
|
||||
pre_update_embeddings = pre_update_embeddings.view(num_update_keys, self.knowledge_length, -1)
|
||||
self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
|
||||
# 更新被修改过的key
|
||||
with torch.no_grad():
|
||||
self.has_update_keys[pre_update_indices] = 1
|
||||
|
||||
|
||||
|
||||
|
||||
return best_tokens, best_tokens_embeddings
|
||||
|
||||
@ -484,12 +512,20 @@ class MiniMindLM(PreTrainedModel):
|
||||
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
|
||||
persistent=False)
|
||||
self.OUT = CausalLMOutputWithPast()
|
||||
self.freeze_embedding = False
|
||||
|
||||
def forward(self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
step: int = 0,
|
||||
**args):
|
||||
start_pos = args.get('start_pos', 0)
|
||||
if self.freeze_embedding and step == 0:
|
||||
self.tok_embeddings.weight.requires_grad = False
|
||||
# 同时冻结KnowledgeDataset的嵌入更新
|
||||
self.knowledge_dataset.freeze_embedding = True
|
||||
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
|
||||
print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding)
|
||||
h = self.dropout(self.tok_embeddings(input_ids))
|
||||
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
|
||||
for l, layer in enumerate(self.layers):
|
||||
|
@ -1,8 +1,8 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 激活conda环境
|
||||
# source $(conda info --base)/etc/profile.d/conda.sh
|
||||
# conda activate ycz_accelerate
|
||||
source $(conda info --base)/etc/profile.d/conda.sh
|
||||
conda activate mini
|
||||
|
||||
# 设置环境变量以帮助调试
|
||||
export NCCL_DEBUG=INFO
|
||||
@ -26,7 +26,7 @@ export PYTHONFAULTHANDLER=1
|
||||
# --profile_interval 10
|
||||
|
||||
# 方法2: 使用命令行参数直接配置accelerate
|
||||
CUDA_VISIBLE_DEVICES=0 accelerate launch \
|
||||
CUDA_VISIBLE_DEVICES=0 /opt/conda/envs/mini/bin/python -m accelerate.commands.launch \
|
||||
--num_processes=1 \
|
||||
--mixed_precision=bf16 \
|
||||
--main_process_port=29500 \
|
||||
|
@ -224,6 +224,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
||||
total_steps_in_epoch = len(train_loader)
|
||||
total_training_steps = args.epochs * total_steps_in_epoch
|
||||
moe_path = '_moe' if args.use_moe else ''
|
||||
best_loss = float('10000')
|
||||
|
||||
# 添加CUDA事件来分析性能 (只在主进程进行)
|
||||
if args.profile and accelerator.is_main_process:
|
||||
@ -287,7 +288,12 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
||||
|
||||
# 前向传播
|
||||
with ctx:
|
||||
res = model(X)
|
||||
if step == 0 and args.embedding_epoch == epoch:
|
||||
# 需要设置原始模型的freeze_embedding属性,而不是包装后的模型
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.freeze_embedding = True
|
||||
Logger(f"Set freeze_embedding=True for epoch {epoch}, step {step}", accelerator)
|
||||
res = model(X, step=step)
|
||||
loss = loss_fct(
|
||||
res.logits.view(-1, res.logits.size(-1)),
|
||||
Y.view(-1)
|
||||
@ -411,7 +417,9 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
||||
wandb.log(log_dict)
|
||||
|
||||
# 保存模型 (只在主进程进行)
|
||||
if (step + 1) % args.save_interval == 0 and accelerator.is_main_process:
|
||||
loss_total = loss.item() * args.accumulation_steps
|
||||
if best_loss > loss_total and accelerator.is_main_process:
|
||||
best_loss = loss_total
|
||||
# 使用函数开始处定义的moe_path变量
|
||||
ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
|
||||
|
||||
@ -431,6 +439,7 @@ def main():
|
||||
parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
|
||||
parser.add_argument("--out_dir", type=str, default="out")
|
||||
parser.add_argument("--epochs", type=int, default=4)
|
||||
parser.add_argument("--embedding_epoch", type=int, default=2, help="embedding训练的epoch数")
|
||||
parser.add_argument("--batch_size", type=int, default=64)
|
||||
parser.add_argument("--learning_rate", type=float, default=2e-4)
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
@ -495,7 +504,8 @@ def main():
|
||||
disable_db=args.disable_db,
|
||||
flash_attn=args.use_flash_attn,
|
||||
knowledge_num=args.knowledge_num,
|
||||
knowledge_length=args.knowledge_length
|
||||
knowledge_length=args.knowledge_length,
|
||||
embeddings_epoch=args.embedding_epoch
|
||||
)
|
||||
|
||||
#########################################################
|
||||
|
Loading…
x
Reference in New Issue
Block a user