This commit is contained in:
Jax922 2025-05-12 11:53:10 +08:00
parent a3ea93597c
commit d93889194d
2 changed files with 104 additions and 80 deletions

View File

@ -551,7 +551,6 @@ class MiniMindLM(PreTrainedModel):
self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.OUT = CausalLMOutputWithPast()
self.params = params
def forward(self,
@ -608,11 +607,20 @@ class MiniMindLM(PreTrainedModel):
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
self.OUT.__setitem__('last_hidden_state', h)
self.OUT.__setitem__('logits', logits)
self.OUT.__setitem__('aux_loss', aux_loss)
self.OUT.__setitem__('past_key_values', past_kvs)
return self.OUT
# 进一步简化,只保留必要的参数
output = CausalLMOutputWithPast(
logits=logits,
past_key_values=past_kvs,
)
# 尝试添加其他属性(如果支持的话)
try:
output.hidden_states = h
except:
pass
return output
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,

View File

@ -40,6 +40,8 @@ def get_lr(current_step, total_steps, lr):
def train_epoch(epoch, wandb):
loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
# 在函数开始处定义moe_path避免在异常处理中引用未定义变量
moe_path = '_moe' if lm_config.use_moe else ''
for step, (X, Y, loss_mask) in enumerate(train_loader):
try:
# 将数据加载到设备上
@ -59,7 +61,20 @@ def train_epoch(epoch, wandb):
Y.view(-1)
).view(Y.size())
loss = (loss * loss_mask).sum() / loss_mask.sum()
loss += res.aux_loss
# 添加辅助损失,如果存在的话
try:
if hasattr(model, 'module'):
# DDP情况
aux_loss = sum(l.feed_forward.aux_loss for l in model.module.layers
if hasattr(l.feed_forward, 'aux_loss'))
else:
# 非DDP情况
aux_loss = sum(l.feed_forward.aux_loss for l in model.layers
if hasattr(l.feed_forward, 'aux_loss'))
loss += aux_loss
except Exception as e:
Logger(f"Warning: Could not add auxiliary loss: {e}")
# 如果出错,不添加辅助损失
loss = loss / args.accumulation_steps
# Print data types for debugging
@ -106,7 +121,7 @@ def train_epoch(epoch, wandb):
# 保存模型
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
model.eval()
moe_path = '_moe' if lm_config.use_moe else ''
# 使用函数开始处定义的moe_path变量
ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
@ -179,7 +194,7 @@ if __name__ == "__main__":
parser.add_argument("--out_dir", type=str, default="out")
# 若要以最快速度实现zero则epochs设置为1轮否则应当利用有限的数据训练2~6个epochs。
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--learning_rate", type=float, default=2e-4)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") #如果GPU可用则使用GPU否则使用CPU。
parser.add_argument("--dtype", type=str, default="bfloat16")
@ -193,9 +208,9 @@ if __name__ == "__main__":
parser.add_argument("--log_interval", type=int, default=100) #日志打印间隔,用于控制日志打印的频率。
parser.add_argument("--save_interval", type=int, default=100) #模型保存间隔,用于控制模型保存的频率。
parser.add_argument('--local_rank', type=int, default=-1) #本地进程编号,用于分布式训练。
parser.add_argument('--dim', default=4096, type=int) #模型维度,用于控制模型的大小。
parser.add_argument('--dim', default=2048, type=int) #模型维度,用于控制模型的大小。
parser.add_argument('--n_layers', default=32, type=int) #层数,用于控制模型层数。
parser.add_argument('--max_seq_len', default=2048, type=int) #最大序列长度,用于控制输入序列的最大长度。
parser.add_argument('--max_seq_len', default=1024, type=int) #最大序列长度,用于控制输入序列的最大长度。
parser.add_argument('--use_moe', default=False, type=bool) #是否使用MOE用于控制是否使用MOE。
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能使用固定值1e-4替代") #禁用数据库功能,启用特殊模式
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl") #数据路径,用于控制数据集的路径。
@ -267,7 +282,8 @@ if __name__ == "__main__":
if ddp:
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
# 添加find_unused_parameters=True参数解决未使用参数的问题
model = DistributedDataParallel(model, device_ids=[ddp_local_rank], find_unused_parameters=True)
torch.autograd.set_detect_anomaly(True)
iter_per_epoch = len(train_loader)