import json
import os
import random

import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "true"


def process_sample_filter(data_args):
    """Filter a single sample, keeping only triples with valid predicates."""
    sample, valid_predicates = data_args
    if 'target' in sample and isinstance(sample['target'], list):
        # Drop triples whose predicate fell below the frequency threshold
        valid_targets = []
        for triple in sample['target']:
            if isinstance(triple, dict) and 'predicate' in triple:
                if triple['predicate'] in valid_predicates:
                    valid_targets.append(triple)

        # Keep the sample only if at least one valid triple remains
        if valid_targets:
            sample['target'] = valid_targets
            return sample
        else:
            return None
    else:
        # No target information: keep the sample as-is
        return sample
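
# Illustrative behavior of process_sample_filter (the sample data is hypothetical):
#   sample = {"text": "...", "target": [{"subject": "A", "predicate": "出生于", "object": "B"}]}
#   process_sample_filter((sample, {"出生于"}))  # -> sample kept, triple retained
#   process_sample_filter((sample, {"毕业于"}))  # -> None (no valid triple remains)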


def process_sample_validation(data_args):
    """Validate a single sample and select exactly one target triple."""
    sample, predicate_vocab = data_args
    if not isinstance(sample, dict) or 'text' not in sample:
        return None

    targets = sample.get('target', [])
    if not isinstance(targets, list) or len(targets) == 0:
        # No usable target: fall back to a default placeholder triple
        # (the Chinese values read "no / found / triple")
        selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
    else:
        # Validate the triples and pick one target, preferring the predicate
        # with the smallest corpus share to balance the data
        selected_target = None
        min_percentage = float('inf')

        for triple in targets:
            if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
                predicate = triple['predicate']

                # Use the statistics stored in predicate_vocab
                if predicate in predicate_vocab:
                    stats = predicate_vocab[predicate]
                    if isinstance(stats, dict) and 'percentage' in stats:
                        percentage = stats['percentage']
                        if percentage < min_percentage:
                            min_percentage = percentage
                            selected_target = triple
                    elif selected_target is None:
                        selected_target = triple
                elif selected_target is None:
                    selected_target = triple

        # If no valid target was found, use the default placeholder
        if selected_target is None:
            selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}

    return {
        'text': sample['text'],
        'target': selected_target  # keep exactly one target
    }
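
# Illustrative behavior of process_sample_validation (vocab and triples are
# hypothetical):
#   vocab = {"出生于": {"count": 10, "percentage": 0.5},
#            "首都":   {"count": 2,  "percentage": 0.1}}
#   sample = {"text": "...", "target": [triple_出生于, triple_首都]}
#   process_sample_validation((sample, vocab))
#   # -> {"text": "...", "target": triple_首都}  (smallest percentage wins)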


class PretrainDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(data_path)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]

        # Build the input text with BOS/EOS markers
        text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding.input_ids.squeeze()
        loss_mask = (input_ids != self.tokenizer.pad_token_id)

        # Shift by one position for next-token prediction; the mask excludes
        # padding. input_ids is already a tensor, so slice and clone instead
        # of re-wrapping it with torch.tensor (which warns and copies twice).
        X = input_ids[:-1].clone()
        Y = input_ids[1:].clone()
        loss_mask = loss_mask[1:].long()
        return X, Y, loss_mask
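
# Minimal usage sketch (assumes a HuggingFace tokenizer and a JSONL file with
# one {"text": ...} object per line; paths and names are illustrative):
#   ds = PretrainDataset("pretrain.jsonl", tokenizer, max_length=512)
#   loader = DataLoader(ds, batch_size=8, shuffle=True)
#   X, Y, loss_mask = next(iter(loader))  # each of shape (8, 511)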


class SFTDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        # Token id sequences that delimit assistant replies in the template
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def _create_chat_prompt(self, conversations):
        """Build a ChatML-style conversation string."""
        messages = []
        for i, turn in enumerate(conversations):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn['content']})
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
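
    # Illustrative output of _create_chat_prompt (the exact markup depends on
    # the tokenizer's chat template; a MiniMind-style template is assumed):
    #   <s>user
    #   你好</s>
    #   <s>assistant
    #   你好!</s>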

    def _generate_loss_mask(self, input_ids):
        """Mask that is 1 only on assistant-reply tokens (loss is computed there)."""
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            # Look for the '<s>assistant' start marker
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                # Scan forward to the matching '</s>'
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                # Mark the reply tokens (offset by one for the next-token
                # shift), including the closing '</s>', capped at max_length
                for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask
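
    # Mask alignment, illustrated on a made-up token sequence:
    #   ids:  [..prompt.., <s>assistant ids, r1, r2, ..., </s> ids, <pad>]
    #   mask: [..0......., 0..............., 0,  1,  ..., 1......., 0    ]
    # Only assistant-reply positions (after the one-step offset) contribute
    # to the loss; prompt and padding stay masked out.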

    def __getitem__(self, index):
        sample = self.samples[index]
        # Build the conversation prompt
        prompt = self._create_chat_prompt(sample['conversations'])
        input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))

        # Generate the dynamic loss mask
        loss_mask = self._generate_loss_mask(input_ids)

        # Build the training tensors
        X = torch.tensor(input_ids[:-1], dtype=torch.long)
        Y = torch.tensor(input_ids[1:], dtype=torch.long)
        loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)  # align with predicted positions

        return X, Y, loss_mask
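
# Minimal usage sketch (assumes a JSONL file where each line holds
# {"conversations": [{"content": ...}, {"content": ...}, ...]} with
# alternating user/assistant turns; names are illustrative):
#   ds = SFTDataset("sft.jsonl", tokenizer, max_length=1024)
#   X, Y, loss_mask = ds[0]  # each of shape (1023,)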


class DPODataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=4096):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
        with open(file_path, 'r', encoding='utf-8') as f:
            self.data = []
            for line in f:
                line = line.strip()
                obj = json.loads(line)
                self.data.append(obj)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        chosen = item['chosen']      # a list of {role, content} messages
        rejected = item['rejected']  # same structure as chosen
        chosen_prompt = self.tokenizer.apply_chat_template(
            chosen, tokenize=False, add_generation_prompt=False
        )

        rejected_prompt = self.tokenizer.apply_chat_template(
            rejected, tokenize=False, add_generation_prompt=False
        )
        chosen_encoding = self.tokenizer(
            chosen_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )
        rejected_encoding = self.tokenizer(
            rejected_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )

        chosen_input_ids = chosen_encoding['input_ids']
        chosen_loss_mask = self._generate_loss_mask(chosen_input_ids)

        rejected_input_ids = rejected_encoding['input_ids']
        rejected_loss_mask = self._generate_loss_mask(rejected_input_ids)
        x_chosen = torch.tensor(chosen_input_ids[:-1], dtype=torch.long)
        y_chosen = torch.tensor(chosen_input_ids[1:], dtype=torch.long)
        mask_chosen = torch.tensor(chosen_loss_mask[1:], dtype=torch.long)
        x_rejected = torch.tensor(rejected_input_ids[:-1], dtype=torch.long)
        y_rejected = torch.tensor(rejected_input_ids[1:], dtype=torch.long)
        mask_rejected = torch.tensor(rejected_loss_mask[1:], dtype=torch.long)

        return {
            'x_chosen': x_chosen,
            'y_chosen': y_chosen,
            'mask_chosen': mask_chosen,
            'x_rejected': x_rejected,
            'y_rejected': y_rejected,
            'mask_rejected': mask_rejected
        }

    def _generate_loss_mask(self, input_ids):
        """Same masking scheme as SFTDataset: 1 only on assistant-reply tokens."""
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask
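
# Minimal usage sketch (assumes a JSONL file where each line holds
# {"chosen": [...], "rejected": [...]} message lists; names are illustrative):
#   ds = DPODataset("dpo.jsonl", tokenizer, max_length=4096)
#   batch = ds[0]
#   batch['x_chosen'].shape  # (4095,)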


class TriplePretrainDataset(Dataset):
    """
    Optimized triple-extraction pretraining dataset:
    - keeps exactly one target triple per sample
    - preprocesses and caches all data up front
    - shows progress bars while processing
    """
    def __init__(self, data_path=None, predicate_vocab_path=None, samples=None, tokenizer=None, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.val_samples = None
        self.predicate_to_id = {}  # initialized here, filled during preprocessing
        if samples is None:
            self.predicate_vocab = self.load_predicate_vocab(predicate_vocab_path)
            print("🚀 Loading and preprocessing triple data...")
            self.samples, self.val_samples = self.load_and_preprocess_data(data_path)
            print("🚀 Finished loading and preprocessing triple data")
        else:
            # Samples were supplied directly: only load the cached predicate map
            cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
            data_filename = os.path.basename(data_path).split('.')[0]
            predicate_to_id_path = os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json')
            self.predicate_to_id = self.load_predicate_vocab(predicate_to_id_path)
            self.samples = samples
            print("🚀 Finished loading and preprocessing triple data")

    def load_predicate_vocab(self, path):
        with open(path, 'r', encoding='utf-8') as f:
            predicate_vocab = json.load(f)
        return predicate_vocab

    def get_val_samples(self):
        return self.val_samples

    def clear_cache(self, data_path):
        """Delete the cache files produced by load_and_preprocess_data."""
        cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
        data_filename = os.path.basename(data_path).split('.')[0]
        cache_files = [
            os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
            os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
            os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
            os.path.join(cache_dir, f'{data_filename}_val_samples.json')
        ]

        for cache_file in cache_files:
            if os.path.exists(cache_file):
                os.remove(cache_file)
                print(f"🗑️ Removed cache file: {cache_file}")

        if os.path.exists(cache_dir) and not os.listdir(cache_dir):
            os.rmdir(cache_dir)
            print(f"🗑️ Removed empty cache directory: {cache_dir}")

    def load_and_preprocess_data(self, path):
        """Load the triple data and preprocess it, with on-disk caching."""
        # Cache file names are derived from the data file path
        cache_dir = os.path.join(os.path.dirname(path), 'cache')
        os.makedirs(cache_dir, exist_ok=True)

        data_filename = os.path.basename(path).split('.')[0]
        cache_files = {
            'predicate_vocab': os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
            'predicate_to_id': os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
            'train_samples': os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
            'val_samples': os.path.join(cache_dir, f'{data_filename}_val_samples.json')
        }

        # If every cache file exists, load from cache and return early
        cache_exists = all(os.path.exists(cache_file) for cache_file in cache_files.values())

        if cache_exists:
            print("📁 Cache files found, loading directly...")
            with open(cache_files['predicate_vocab'], 'r', encoding='utf-8') as f:
                self.predicate_vocab = json.load(f)

            with open(cache_files['predicate_to_id'], 'r', encoding='utf-8') as f:
                self.predicate_to_id = json.load(f)

            with open(cache_files['train_samples'], 'r', encoding='utf-8') as f:
                train_samples = json.load(f)

            with open(cache_files['val_samples'], 'r', encoding='utf-8') as f:
                val_samples = json.load(f)

            print("✅ Loaded from cache:")
            print(f"✅ Predicate vocab size: {len(self.predicate_vocab)}")
            print(f"✅ Training set size: {len(train_samples)}")
            print(f"✅ Validation set size: {len(val_samples)}")

            return train_samples, val_samples

        # No cache: process the raw data from scratch
        print("📂 No cache found, loading and processing raw data...")

        # 1. Load the raw data
        print("📂 Loading raw data...")
        if path.endswith('.json'):
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        elif path.endswith('.jsonl'):
            data = []
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    if line.strip():
                        data.append(json.loads(line.strip()))
        else:
            raise ValueError(f"Unsupported file format: {path}")

        print(f"📊 Raw data size: {len(data)} samples")

        # 2. Use self.predicate_vocab to filter out data whose predicate share
        #    is below 0.01%
        print("🔍 Filtering low-frequency predicate data...")
        print(f"📊 Predicate statistics: {len(self.predicate_vocab)} predicates in total")

        # 3. Collect the predicates whose share is >= 0.01%
        valid_predicates = set()
        for predicate, stats in self.predicate_vocab.items():
            if isinstance(stats, dict) and 'percentage' in stats:
                if stats['percentage'] >= 0.01:
                    valid_predicates.add(predicate)
            else:
                # Not in statistics format: assume the predicate is valid
                valid_predicates.add(predicate)

        print(f"📊 Predicates with share >= 0.01%: {len(valid_predicates)}")

        # 4. Filter the data: drop samples that only contain low-frequency
        #    predicates (single-process)
        original_count = len(data)
        filtered_data = []

        print("🚀 Filtering low-frequency predicate data...")
        for sample in tqdm(data, desc="Filtering low-frequency predicates"):
            result = process_sample_filter((sample, valid_predicates))
            if result is not None:
                filtered_data.append(result)

        data = filtered_data
        print(f"✅ Filtering done: {original_count} samples before, {len(data)} after")

        # 5. Remove predicates with share < 0.01% from self.predicate_vocab
        #    and build the predicate-to-index mapping
        print("🔍 Updating predicate vocab and building the index mapping...")
        original_vocab_size = len(self.predicate_vocab)
        filtered_predicate_vocab = {}

        for predicate, stats in self.predicate_vocab.items():
            if isinstance(stats, dict) and 'percentage' in stats:
                if stats['percentage'] >= 0.01:
                    filtered_predicate_vocab[predicate] = stats
            else:
                # Not in statistics format: keep it
                filtered_predicate_vocab[predicate] = stats

        # Map each surviving predicate to a class index
        self.predicate_to_id = {predicate: idx for idx, predicate in enumerate(filtered_predicate_vocab.keys())}
        self.predicate_vocab = filtered_predicate_vocab
        print(f"✅ Predicate vocab updated: {original_vocab_size} before, {len(self.predicate_vocab)} after")
        print(f"✅ Predicate mapping built: {len(self.predicate_to_id)} predicates indexed")

        # 6. Validate samples and keep a single target each, preferring rare
        #    predicates to balance the data (single-process)
        print("🔍 Validating data format and selecting one target per sample...")
        valid_samples = []

        print("🚀 Validating data format...")
        for sample in tqdm(data, desc="Validating data format"):
            result = process_sample_validation((sample, self.predicate_vocab))
            if result is not None:
                valid_samples.append(result)

        print(f"✅ Valid samples: {len(valid_samples)}")

        # 7. Split into training and validation sets (random is imported at
        #    module level; fixed seed for reproducibility)
        random.seed(42)
        val_samples = random.sample(valid_samples, min(1000, len(valid_samples)))
        train_samples = [sample for sample in valid_samples if sample not in val_samples]
        print(f"✅ Training set size: {len(train_samples)}")
        print(f"✅ Validation set size: {len(val_samples)}")

        # 8. Save the processed results to the cache files
        print("💾 Saving processed results to cache files...")
        with open(cache_files['predicate_vocab'], 'w', encoding='utf-8') as f:
            json.dump(self.predicate_vocab, f, ensure_ascii=False, indent=2)

        with open(cache_files['predicate_to_id'], 'w', encoding='utf-8') as f:
            json.dump(self.predicate_to_id, f, ensure_ascii=False, indent=2)

        with open(cache_files['train_samples'], 'w', encoding='utf-8') as f:
            json.dump(train_samples, f, ensure_ascii=False, indent=2)

        with open(cache_files['val_samples'], 'w', encoding='utf-8') as f:
            json.dump(val_samples, f, ensure_ascii=False, indent=2)

        print("✅ Cache files saved")

        return train_samples, val_samples

    def __len__(self):
        return len(self.samples)

    def _triple_to_sentence(self, triple):
        """Render a triple as a plain "subject predicate object" sentence."""
        return f"{triple['subject']} {triple['predicate']} {triple['object']}"

    def __getitem__(self, index):
        """Return one sample for the predicate-classification task."""
        sample = self.samples[index]

        # Tokenize the input text at runtime
        input_text = f"{self.tokenizer.bos_token}{sample['text']}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding.input_ids.squeeze()
        loss_mask = (input_ids != self.tokenizer.pad_token_id)

        # Look up the predicate-classification label; 0 is the fallback for
        # unknown predicates (dict.get needs the explicit default, otherwise
        # it returns None)
        target_predicate = sample['target']['predicate']
        predicate_label = self.predicate_to_id.get(target_predicate, 0)

        # Build the training tensors
        X = input_ids[:-1]
        loss_mask = loss_mask[1:]

        return {
            'input_ids': X,
            'labels': torch.tensor(predicate_label, dtype=torch.long),  # predicate class id
            'loss_mask': loss_mask
        }
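
# Minimal usage sketch (paths and tokenizer are illustrative; the predicate
# vocab JSON is assumed to map predicates to percentage statistics):
#   ds = TriplePretrainDataset(data_path="triples.jsonl",
#                              predicate_vocab_path="predicate_vocab.json",
#                              tokenizer=tokenizer, max_length=512)
#   item = ds[0]
#   item['input_ids'].shape, item['labels']  # (511,), scalar class id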


class RLAIFDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def _create_chat_prompt(self, conversations):
        """Build a ChatML-style prompt and return it with the reference answer."""
        messages = []
        answer = ''
        for i, turn in enumerate(conversations):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn['content']})
            answer = turn['content']  # ends up holding the last turn's content
        # Drop the final turn from the prompt and ask the model to generate it
        return self.tokenizer.apply_chat_template(
            messages[:-1],
            tokenize=False,
            add_generation_prompt=True
        ), answer

    def __getitem__(self, index):
        sample = self.samples[index]
        # Build the conversation prompt and the reference answer
        prompt, answer = self._create_chat_prompt(sample['conversations'])

        return {
            'prompt': prompt,
            'answer': answer
        }
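
# Minimal usage sketch (the file layout is assumed to match SFTDataset's
# "conversations" format; names are illustrative):
#   ds = RLAIFDataset("rlaif.jsonl", tokenizer)
#   ds[0]  # -> {'prompt': '...generation prompt...', 'answer': '...'}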


if __name__ == "__main__":
    pass