# Minimind/scripts/train_tokenizer.py
# File info from the source view: 153 lines, 4.8 KiB, Python (blame dated 2024-08-28 16:41:44 +08:00).
import random
from tqdm import tqdm
from transformers import AutoTokenizer
import json
from datasets import load_dataset
from tokenizers import (
decoders,
models,
normalizers,
pre_tokenizers,
processors,
trainers,
Tokenizer,
)
import os
# Fix Python's global RNG seed at import time for reproducibility; note that no
# code visible in this file draws from `random` directly.
random.seed(a=42)


def _minimind_tokenizer_config(special_tokens):
    """Build the ``tokenizer_config.json`` mapping written next to ``tokenizer.json``.

    Args:
        special_tokens: Special-token strings in id order; entry ``i`` is
            listed under key ``str(i)`` in ``added_tokens_decoder``.

    Returns:
        A dict ready for ``json.dump``; keys are inserted in a fixed order so
        the dumped JSON is stable across runs.
    """
    added_tokens_decoder = {
        str(token_id): {
            "content": token,
            "lstrip": False,
            "normalized": False,
            "rstrip": False,
            "single_word": False,
            "special": True,
        }
        for token_id, token in enumerate(special_tokens)
    }
    return {
        "add_bos_token": False,
        "add_eos_token": False,
        "add_prefix_space": False,
        "added_tokens_decoder": added_tokens_decoder,
        "additional_special_tokens": [],
        "bos_token": "<s>",
        "clean_up_tokenization_spaces": False,
        "eos_token": "</s>",
        "legacy": True,
        "model_max_length": 32768,
        # <unk> doubles as the padding token; no dedicated pad token is trained.
        "pad_token": "<unk>",
        "sp_model_kwargs": {},
        "spaces_between_special_tokens": False,
        "tokenizer_class": "PreTrainedTokenizerFast",
        "unk_token": "<unk>",
        # Each turn is rendered as "<s>{role}\n...</s>\n". A default system
        # prompt is emitted when messages[0] is not a system message, and every
        # user turn is followed by "<s>assistant\n".
        "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<s>system\\n' + system_message + '</s>\\n' }}{% else %}{{ '<s>system\\n你是 MiniMind是一个有用的人工智能助手。</s>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}",
    }


def train_tokenizer(data_path='../dataset/tokenizer_train.jsonl',
                    tokenizer_dir='../model/minimind_tokenizer',
                    vocab_size=6400):
    """Train a byte-level BPE tokenizer on a JSONL corpus and save it for Hugging Face.

    Args:
        data_path: JSONL file with one JSON object per line, each carrying a
            ``"text"`` field. Relative paths resolve against the current
            working directory (the defaults assume running from ``scripts/``).
        tokenizer_dir: Output directory, created if missing. Receives
            ``tokenizer.json``, the files written by ``tokenizer.model.save``,
            and ``tokenizer_config.json``.
        vocab_size: Target vocabulary size passed to the BPE trainer.

    Raises:
        RuntimeError: If the special tokens did not receive ids 0, 1, 2.
        KeyError: If a non-blank JSONL record has no ``"text"`` field.
    """
    def read_texts_from_jsonl(file_path):
        """Lazily yield the ``"text"`` field of every non-blank line of a JSONL file."""
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing empty line
                yield json.loads(line)['text']

    tokenizer = Tokenizer(models.BPE())
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

    # Registered first, in this order, so they receive ids 0, 1, 2 (verified below).
    special_tokens = ["<unk>", "<s>", "</s>"]
    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=special_tokens,
        show_progress=True,
        # Seed the vocabulary with the full ByteLevel alphabet so every byte is representable.
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    )
    tokenizer.train_from_iterator(read_texts_from_jsonl(data_path), trainer=trainer)
    tokenizer.decoder = decoders.ByteLevel()

    # Explicit check rather than `assert` so it still runs under `python -O`;
    # the config written below hard-codes these ids.
    for expected_id, token in enumerate(special_tokens):
        actual_id = tokenizer.token_to_id(token)
        if actual_id != expected_id:
            raise RuntimeError(
                f"special token {token!r} has id {actual_id}, expected {expected_id}"
            )

    os.makedirs(tokenizer_dir, exist_ok=True)
    tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
    # Write the model's own files into the same directory as tokenizer.json.
    tokenizer.model.save(tokenizer_dir)

    config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
    with open(config_path, "w", encoding="utf-8") as config_file:
        json.dump(_minimind_tokenizer_config(special_tokens), config_file,
                  ensure_ascii=False, indent=4)
    print("Tokenizer training completed and saved.")
def eval_tokenizer(tokenizer_path="../model/minimind_tokenizer"):
    """Smoke-test a saved tokenizer: render a chat prompt and check the encode/decode round trip.

    Prints the rendered prompt, the vocabulary size (``len(tokenizer)``,
    which includes added special tokens), the encoded length, and whether
    decoding the ids reproduces the prompt exactly.

    Args:
        tokenizer_path: Path passed to ``AutoTokenizer.from_pretrained``;
            the default is the directory ``train_tokenizer`` writes to.
    """
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    messages = [
        {"role": "system", "content": "你是一个优秀的聊天机器人,总是给我正确的回应!"},
        {"role": "user", "content": '你来自哪里?'},
        {"role": "assistant", "content": '我来自地球'}
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False)
    print(prompt)

    print('tokenizer实际词表长度', len(tokenizer))

    input_ids = tokenizer(prompt)['input_ids']
    print('encoder长度', len(input_ids))

    # The rendered prompt contains <s>/</s> markers, which encode to special
    # tokens. Decoding with skip_special_tokens=True would drop them and make
    # this round-trip comparison always False, so keep them.
    response = tokenizer.decode(input_ids, skip_special_tokens=False)
    print('decoder和原始文本是否一致', response == prompt)


def main(*, train=False):
    """Script entry point: optionally (re)train the tokenizer, then evaluate it.

    Args:
        train: When True, run ``train_tokenizer()`` before ``eval_tokenizer()``.
            Defaults to False, so a bare ``main()`` only evaluates the
            already-saved tokenizer.
    """
    if train:
        train_tokenizer()
    eval_tokenizer()
# Run the script entry point only when executed directly, not on import.
if __name__ == "__main__":
    main()