gongjy 2025-02-13 20:56:14 +08:00
parent 416cc90b58
commit b5d10d9a7d
4 changed files with 6 additions and 6 deletions


@@ -225,7 +225,7 @@ git clone https://huggingface.co/jingyaogong/MiniMind2
 ```bash
 # load=0: load from pytorch model, load=1: load from transformers-hf model
-python eval_model.py --load 1
+python eval_model.py --load 1 --model_mode 2
 ```
 ### 4. Or Start WebUI


@@ -239,7 +239,7 @@ git clone https://huggingface.co/jingyaogong/MiniMind2
 ```bash
 # load=0: load from pytorch model, load=1: load from transformers-hf model
-python eval_model.py --load 1
+python eval_model.py --load 1 --model_mode 2
 ```
 ### 4. Or Start WebUI


@@ -25,7 +25,7 @@ def train_tokenizer():
                 data = json.loads(line)
                 yield data['text']

-    data_path = '../dataset/tokenizer_train.jsonl'
+    data_path = '../dataset/pretrain_hq.jsonl'

     # initialize the tokenizer
     tokenizer = Tokenizer(models.BPE())
@@ -139,12 +139,12 @@ def eval_tokenizer():
     print('encoder length', len(model_inputs['input_ids']))
     input_ids = model_inputs['input_ids']
-    response = tokenizer.decode(input_ids, skip_special_tokens=True)
+    response = tokenizer.decode(input_ids, skip_special_tokens=False)
     print('decoder matches original text', response == new_prompt)
 def main():
-    # train_tokenizer()
+    train_tokenizer()
     eval_tokenizer()
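
Taken together, this file's changes retrain the tokenizer from `pretrain_hq.jsonl` (the old `tokenizer_train.jsonl` path is dropped and `train_tokenizer()` is re-enabled in `main()`). Below is a minimal sketch of that training path, assuming the HuggingFace `tokenizers` library; the vocab size, special tokens, and byte-level settings are stand-ins rather than the repository's exact configuration.

```python
import json
from tokenizers import Tokenizer, decoders, models, pre_tokenizers, trainers

def read_texts(data_path):
    """Yield the 'text' field of every JSONL line, mirroring train_tokenizer()."""
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            yield json.loads(line)['text']

# Byte-level BPE tokenizer trained from the pretraining corpus
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tokenizer.decoder = decoders.ByteLevel()

trainer = trainers.BpeTrainer(
    vocab_size=6400,                          # assumed size, not necessarily the repo's value
    special_tokens=['<unk>', '<s>', '</s>'],  # assumed special tokens
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
)
tokenizer.train_from_iterator(read_texts('../dataset/pretrain_hq.jsonl'), trainer=trainer)
tokenizer.save('tokenizer.json')
```

On the evaluation side, switching to `skip_special_tokens=False` keeps whatever special tokens the chat template inserted into `new_prompt`, so the `response == new_prompt` check compares a full round trip instead of a stripped version that would not match whenever the template adds special tokens.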


@@ -35,7 +35,7 @@ def train_epoch(epoch, wandb):
     # token-id placeholders for the think/answer tags
     start_of_think_ids = tokenizer('<think>').input_ids
     end_of_think_ids = tokenizer('</think>').input_ids
-    start_of_answer_ids = tokenizer('</answer>').input_ids
+    start_of_answer_ids = tokenizer('<answer>').input_ids
     end_of_answer_ids = tokenizer('</answer>').input_ids
     loss_fct = nn.CrossEntropyLoss(reduction='none')
     start_time = time.time()
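
These tag ids feed into the per-token loss set up by `loss_fct = nn.CrossEntropyLoss(reduction='none')`. Below is a minimal, illustrative sketch (not the repository's exact code) of how per-token losses can be re-weighted at the tag positions so the model reliably emits the `<think>`/`<answer>` structure; the shapes, stand-in ids, and the 10x weight are assumptions.

```python
import torch
import torch.nn as nn

# Illustrative shapes and values only
vocab_size = 6400
logits = torch.randn(2, 8, vocab_size)             # (batch, seq_len, vocab)
Y = torch.randint(0, vocab_size, (2, 8))           # target token ids
loss_mask = torch.ones_like(Y, dtype=torch.float)  # 1 where loss is counted

# Stand-ins for tokenizer('<think>').input_ids etc.
start_of_think_ids, end_of_think_ids = [1], [2]
start_of_answer_ids, end_of_answer_ids = [3], [4]  # '<answer>' vs '</answer>' after the fix

loss_fct = nn.CrossEntropyLoss(reduction='none')
per_token_loss = loss_fct(logits.view(-1, vocab_size), Y.view(-1))

# Up-weight positions whose target is one of the tag tokens (the weight is illustrative)
tag_ids = torch.tensor(start_of_think_ids + end_of_think_ids +
                       start_of_answer_ids + end_of_answer_ids)
is_tag = torch.isin(Y.view(-1), tag_ids)
weights = loss_mask.view(-1).clone()
weights[is_tag] *= 10.0
loss = (per_token_loss * weights).sum() / loss_mask.sum()
```

With the old `</answer>` typo, `start_of_answer_ids` and `end_of_answer_ids` held the same ids, so the opening `<answer>` tag never received the extra weight; the one-character fix restores it.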