During pretraining, print 10 generated tokens to make the output easier to inspect

Yu Chengzhang 2025-07-17 00:05:34 +08:00
parent 2797b76939
commit d701003f8a
10 changed files with 12738 additions and 33 deletions

View File

@@ -14,7 +14,7 @@
     },
     {
       "id": 1,
-      "content": "<s>",
+      "content": "<|im_start|>",
       "single_word": false,
       "lstrip": false,
       "rstrip": false,
@@ -23,7 +23,7 @@
     },
     {
       "id": 2,
-      "content": "</s>",
+      "content": "<|im_end|>",
       "single_word": false,
       "lstrip": false,
       "rstrip": false,
@@ -56,8 +56,8 @@
   "ignore_merges": false,
   "vocab": {
     "<unk>": 0,
-    "<s>": 1,
-    "</s>": 2,
+    "<|im_start|>": 1,
+    "<|im_end|>": 2,
     "!": 3,
     "\"": 4,
     "#": 5,

View File

@@ -12,7 +12,7 @@
       "special": true
     },
     "1": {
-      "content": "<s>",
+      "content": "<|im_start|>",
       "lstrip": false,
       "normalized": false,
       "rstrip": false,
@@ -20,7 +20,7 @@
       "special": true
     },
     "2": {
-      "content": "</s>",
+      "content": "<|im_end|>",
       "lstrip": false,
       "normalized": false,
       "rstrip": false,
@@ -29,9 +29,9 @@
     }
   },
   "additional_special_tokens": [],
-  "bos_token": "<s>",
+  "bos_token": "<|im_start|>",
   "clean_up_tokenization_spaces": false,
-  "eos_token": "</s>",
+  "eos_token": "<|im_end|>",
   "legacy": true,
   "model_max_length": 32768,
   "pad_token": "<unk>",
@@ -39,5 +39,5 @@
   "spaces_between_special_tokens": false,
   "tokenizer_class": "PreTrainedTokenizerFast",
   "unk_token": "<unk>",
-  "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<s>system\\n' + system_message + '</s>\\n' }}{% else %}{{ '<s>system\\n你是 MiniMind，是一个有用的人工智能助手。</s>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
+  "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\n你是 MiniMind，是一个有用的人工智能助手。<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"
 }
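For reference, a sketch of what the updated chat_template renders for a single user turn (the tokenizer path is an assumption; the default system prompt appears because no system message is supplied):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")  # assumed path
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": "你好"}], tokenize=False
    )
    print(prompt)
    # <|im_start|>system
    # 你是 MiniMind，是一个有用的人工智能助手。<|im_end|>
    # <|im_start|>user
    # 你好<|im_end|>
    # <|im_start|>assistant

Note that the template emits the trailing <|im_start|>assistant\n as part of the user turn, so the model always continues from an open assistant block.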

File diff suppressed because one or more lines are too long

View File

@@ -583,16 +583,19 @@ class MiniMindLM(PreTrainedModel):
         return res

     def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
-        start, first_seq, past_kvs = input_ids.shape[1], True, None
-        while input_ids.shape[1] < max_new_tokens - 1:
-            if first_seq:
-                out, first_seq = self(input_ids, **args), False
-            else:
-                out = self(input_ids[:, -1:],
-                           start_pos=input_ids.shape[1] - 1, **args)
-            logits, past_kvs = out.logits[:, -1, :], out.past_key_values
+        start = input_ids.shape[1]
+        for _ in range(max_new_tokens):
+            # Pass the full input_ids every step; the KV cache is not used
+            out = self(input_ids, **args)
+            logits = out.logits[:, -1, :]  # logits at the last position
+
+            # Repetition penalty
             logits[:, list(set(input_ids.tolist()[0]))] /= rp
+            # Temperature scaling
             logits /= (temperature + 1e-9)
+            # Top-p sampling
             if top_p is not None and top_p < 1.0:
                 sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                 sorted_probs = F.softmax(sorted_logits, dim=-1)
@@ -602,8 +605,14 @@ class MiniMindLM(PreTrainedModel):
                 sorted_indices_to_remove[:, 0] = False
                 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                 logits[indices_to_remove] = -float('Inf')
+
+            # Sample the next token
             input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
             input_ids = torch.cat((input_ids, input_ids_next), dim=1)
+            # Yield only the newly generated part
             yield input_ids[:, start:]
+            # Stop once the EOS token is generated
             if input_ids_next.item() == eos_token_id:
                 break
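The hunk elides the cumulative-probability step between the softmax and the mask. Not the repo's code, but a self-contained sketch of the same top-p (nucleus) filtering on toy logits, to make the masking visible:

    import torch
    import torch.nn.functional as F

    def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        # Drop tokens whose cumulative mass *before* them already exceeds top_p;
        # the highest-probability token is always kept.
        to_remove = (cumulative - sorted_probs) > top_p
        to_remove[:, 0] = False
        # Map the mask from sorted order back to vocabulary order.
        indices_to_remove = to_remove.scatter(1, sorted_indices, to_remove)
        return logits.masked_fill(indices_to_remove, -float('Inf'))

    print(top_p_filter(torch.tensor([[2.0, 1.0, 0.5, -1.0]]), top_p=0.9))
    # tensor([[2.0000, 1.0000, 0.5000,   -inf]])  — the tail token can no longer be sampled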

View File

@@ -453,16 +453,19 @@ class MiniMindLM(PreTrainedModel):
         return res

     def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
-        start, first_seq, past_kvs = input_ids.shape[1], True, None
-        while input_ids.shape[1] < max_new_tokens - 1:
-            if first_seq:
-                out, first_seq = self(input_ids, **args), False
-            else:
-                out = self(input_ids[:, -1:],
-                           start_pos=input_ids.shape[1] - 1, **args)
-            logits, past_kvs = out.logits[:, -1, :], out.past_key_values
+        start = input_ids.shape[1]
+        for _ in range(max_new_tokens):
+            # Pass the full input_ids every step; the KV cache is not used
+            out = self(input_ids, **args)
+            logits = out.logits[:, -1, :]  # logits at the last position
+
+            # Repetition penalty
             logits[:, list(set(input_ids.tolist()[0]))] /= rp
+            # Temperature scaling
             logits /= (temperature + 1e-9)
+            # Top-p sampling
             if top_p is not None and top_p < 1.0:
                 sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                 sorted_probs = F.softmax(sorted_logits, dim=-1)
@@ -472,8 +475,14 @@ class MiniMindLM(PreTrainedModel):
                 sorted_indices_to_remove[:, 0] = False
                 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                 logits[indices_to_remove] = -float('Inf')
+
+            # Sample the next token
             input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
             input_ids = torch.cat((input_ids, input_ids_next), dim=1)
+            # Yield only the newly generated part
             yield input_ids[:, start:]
+            # Stop once the EOS token is generated
             if input_ids_next.item() == eos_token_id:
                 break
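Dropping the KV cache makes each step recompute attention over the full prefix: simpler, but quadratic in generated length, which is acceptable here since only short samples are drawn. Because each yielded chunk is the whole continuation so far, a caller prints only the fresh suffix. A hedged sketch (model and tokenizer setup assumed; _stream is the private generator shown above, called directly only for illustration):

    prompt_ids = tokenizer("讲个故事", return_tensors="pt").input_ids
    printed = 0
    for chunk in model._stream(prompt_ids, eos_token_id=tokenizer.eos_token_id,
                               max_new_tokens=10, temperature=0.7, top_p=0.9, rp=1.0):
        text = tokenizer.decode(chunk[0], skip_special_tokens=True)
        print(text[printed:], end="", flush=True)  # only the newly added text
        printed = len(text)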

File diff suppressed because it is too large

View File

@@ -0,0 +1,43 @@
+{
+  "add_bos_token": false,
+  "add_eos_token": false,
+  "add_prefix_space": false,
+  "added_tokens_decoder": {
+    "0": {
+      "content": "<unk>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "1": {
+      "content": "<|im_start|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "2": {
+      "content": "<|im_end|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    }
+  },
+  "additional_special_tokens": [],
+  "bos_token": "<|im_start|>",
+  "clean_up_tokenization_spaces": false,
+  "eos_token": "<|im_end|>",
+  "legacy": true,
+  "model_max_length": 32768,
+  "pad_token": "<unk>",
+  "sp_model_kwargs": {},
+  "spaces_between_special_tokens": false,
+  "tokenizer_class": "PreTrainedTokenizerFast",
+  "unk_token": "<unk>",
+  "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\n你是 MiniMind，是一个有用的人工智能助手。<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"
+}

File diff suppressed because one or more lines are too long

View File

@@ -18,21 +18,20 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
     --out_dir "out" \
     --epochs 3 \
     --embedding_epoch 2 \
-    --batch_size 48 \
-    --learning_rate 2e-4 \
+    --batch_size 64 \
+    --learning_rate 8e-5 \
     --dtype bfloat16 \
     --use_swanlab \
     --swanlab_project "MiniMind-Pretrain" \
     --num_workers 1 \
-    --accumulation_steps 32 \
-    --grad_clip 1.0 \
+    --accumulation_steps 16 \
+    --grad_clip 0.5 \
     --warmup_iters 0 \
     --log_interval 100 \
     --save_interval 10000 \
     --dim 1024 \
-    --n_layers 18 \
+    --n_layers 48 \
     --max_seq_len 512 \
-    --use_moe False \
     --data_path "./dataset/stable/merged_pretrain.jsonl" \
     --profile \
     --profile_interval 10 \
@@ -44,4 +43,4 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
     --cluster_cache_path "./cache/cluster_tokens_single.pt" \
     --memory_monitor_interval 10 \
     --model_type "model_original" \
-    --model_size 814.724
+    --model_size 538
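The batch_size/accumulation_steps swap also changes the effective batch per optimizer step. A quick back-of-envelope, assuming all 4 GPUs from the CUDA_VISIBLE_DEVICES line participate:

    gpus = 4
    old_seqs = 48 * 32 * gpus   # 6144 sequences per optimizer step
    new_seqs = 64 * 16 * gpus   # 4096 sequences per optimizer step
    print(old_seqs * 512, new_seqs * 512)  # ~3.1M vs ~2.1M tokens at max_seq_len 512

So the update roughly halves the effective batch while the learning rate drops from 2e-4 to 8e-5, a consistent direction for both changes.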

View File

@@ -685,6 +685,47 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
                    f"Bwd: {backward_time/args.log_interval:.2f}ms, "
                    f"Optim: {optimizer_time/args.log_interval:.2f}ms, "
                    f"Iter Time: {iter_time:.2f}ms", accelerator)
+
+            # Generate a short text sample for inspection
+            try:
+                # Pick a random sample from the current batch
+                random_idx = torch.randint(0, X.size(0), (1,)).item()
+                sample_input = X[random_idx:random_idx+1]  # [1, seq_len]
+
+                # Use a prefix (e.g. the first half) as the prompt
+                prompt_len = min(sample_input.size(1) // 2, sample_input.size(1) - 10)
+                prompt_input = sample_input[:, :prompt_len]
+
+                # Generate 10 tokens
+                unwrapped_model = accelerator.unwrap_model(model)
+                unwrapped_model.eval()  # switch to eval mode
+                with torch.no_grad():
+                    generated = unwrapped_model.generate(
+                        prompt_input,
+                        max_new_tokens=10,
+                        temperature=0.7,
+                        top_p=0.9,
+                        eos_token_id=tokenizer.eos_token_id,
+                        pad_token_id=tokenizer.pad_token_id
+                    )
+
+                # Convert ids back to human-readable text
+                original_text = tokenizer.decode(sample_input[0], skip_special_tokens=True)
+                prompt_text = tokenizer.decode(prompt_input[0], skip_special_tokens=True)
+                generated_text = tokenizer.decode(generated[0], skip_special_tokens=True)
+                new_tokens_text = generated_text[len(prompt_text):]
+
+                Logger(f"Generated text sample:", accelerator)
+                Logger(f"  Original text: {original_text[:100]}...", accelerator)
+                Logger(f"  Prompt: {prompt_text[-50:]}", accelerator)
+                Logger(f"  Continuation: {new_tokens_text}", accelerator)
+
+                unwrapped_model.train()  # back to train mode
+            except Exception as e:
+                Logger(f"Failed to generate text sample: {e}", accelerator)
+
             # Reset events so the next measurement starts from 0
             data_start = torch.cuda.Event(enable_timing=True)
             data_end = torch.cuda.Event(enable_timing=True)
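One caveat with the added sampling code: new_tokens_text is recovered by string-slicing the decoded text, which can mis-split when the tokenizer merges tokens across the prompt boundary (decode of the concatenation is not guaranteed to start with decode of the prompt). A sketch of a more robust variant that decodes only the newly generated ids (names as in the diff):

    new_token_ids = generated[0][prompt_input.size(1):]
    new_tokens_text = tokenizer.decode(new_token_ids, skip_special_tokens=True)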