使用分类来获取谓词
This commit is contained in:
parent
fcab661af9
commit
75265f6652
299
model/dataset.py
299
model/dataset.py
@ -14,6 +14,70 @@ from tqdm import tqdm
|
|||||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
def process_sample_filter(data_args):
|
||||||
|
"""处理单个样本的过滤逻辑"""
|
||||||
|
sample, valid_predicates = data_args
|
||||||
|
if 'target' in sample and isinstance(sample['target'], list):
|
||||||
|
# 过滤target中的低频谓词
|
||||||
|
valid_targets = []
|
||||||
|
for triple in sample['target']:
|
||||||
|
if isinstance(triple, dict) and 'predicate' in triple:
|
||||||
|
if triple['predicate'] in valid_predicates:
|
||||||
|
valid_targets.append(triple)
|
||||||
|
|
||||||
|
# 如果还有有效的target,保留这个样本
|
||||||
|
if valid_targets:
|
||||||
|
sample['target'] = valid_targets
|
||||||
|
return sample
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
# 如果没有target信息,保留样本
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
def process_sample_validation(data_args):
|
||||||
|
"""处理单个样本的验证逻辑"""
|
||||||
|
sample, predicate_vocab = data_args
|
||||||
|
if not isinstance(sample, dict) or 'text' not in sample:
|
||||||
|
return None
|
||||||
|
|
||||||
|
targets = sample.get('target', [])
|
||||||
|
if not isinstance(targets, list) or len(targets) == 0:
|
||||||
|
# 如果没有有效的target,创建一个默认的
|
||||||
|
selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
|
||||||
|
else:
|
||||||
|
# 验证并选择target,优先选择占比小的谓词
|
||||||
|
selected_target = None
|
||||||
|
min_percentage = float('inf')
|
||||||
|
|
||||||
|
for triple in targets:
|
||||||
|
if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
|
||||||
|
predicate = triple['predicate']
|
||||||
|
|
||||||
|
# 使用predicate_vocab中的统计信息
|
||||||
|
if predicate in predicate_vocab:
|
||||||
|
stats = predicate_vocab[predicate]
|
||||||
|
if isinstance(stats, dict) and 'percentage' in stats:
|
||||||
|
percentage = stats['percentage']
|
||||||
|
if percentage < min_percentage:
|
||||||
|
min_percentage = percentage
|
||||||
|
selected_target = triple
|
||||||
|
elif selected_target is None:
|
||||||
|
selected_target = triple
|
||||||
|
elif selected_target is None:
|
||||||
|
selected_target = triple
|
||||||
|
|
||||||
|
# 如果没有找到有效的target,使用默认值
|
||||||
|
if selected_target is None:
|
||||||
|
selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
|
||||||
|
|
||||||
|
return {
|
||||||
|
'text': sample['text'],
|
||||||
|
'target': selected_target # 只保留一个target
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class PretrainDataset(Dataset):
|
class PretrainDataset(Dataset):
|
||||||
def __init__(self, data_path, tokenizer, max_length=512):
|
def __init__(self, data_path, tokenizer, max_length=512):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -204,15 +268,94 @@ class TriplePretrainDataset(Dataset):
|
|||||||
- 预先tokenize所有数据
|
- 预先tokenize所有数据
|
||||||
- 使用进度条显示处理进度
|
- 使用进度条显示处理进度
|
||||||
"""
|
"""
|
||||||
def __init__(self, data_path, tokenizer, max_length=512):
|
def __init__(self, data_path=None, predicate_vocab_path=None, samples = None,tokenizer=None, max_length=512):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
print("🚀 开始加载和预处理三元组数据...")
|
self.val_samples = None
|
||||||
self.samples = self.load_and_preprocess_data(data_path)
|
self.predicate_to_id = {} # 初始化
|
||||||
|
if samples is None:
|
||||||
|
self.predicate_vocab = self.load_predicate_vocab(predicate_vocab_path)
|
||||||
|
print("🚀 开始加载和预处理三元组数据...")
|
||||||
|
self.samples,self.val_samples = self.load_and_preprocess_data(data_path)
|
||||||
|
print("🚀 加载和预处理三元组数据完成")
|
||||||
|
else:
|
||||||
|
cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
|
||||||
|
data_filename = os.path.basename(data_path).split('.')[0]
|
||||||
|
predicate_to_id_path = os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json')
|
||||||
|
self.predicate_to_id = self.load_predicate_vocab(predicate_to_id_path)
|
||||||
|
self.samples = samples
|
||||||
|
print("🚀 加载和预处理三元组数据完成")
|
||||||
|
def load_predicate_vocab(self, path):
|
||||||
|
with open(path, 'r', encoding='utf-8') as f:
|
||||||
|
predicate_vocab = json.load(f)
|
||||||
|
return predicate_vocab
|
||||||
|
|
||||||
|
def get_val_samples(self):
|
||||||
|
return self.val_samples
|
||||||
|
|
||||||
|
def clear_cache(self, data_path):
|
||||||
|
"""清除缓存文件"""
|
||||||
|
cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
|
||||||
|
data_filename = os.path.basename(data_path).split('.')[0]
|
||||||
|
cache_files = [
|
||||||
|
os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
|
||||||
|
os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
|
||||||
|
os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
|
||||||
|
os.path.join(cache_dir, f'{data_filename}_val_samples.json')
|
||||||
|
]
|
||||||
|
|
||||||
|
for cache_file in cache_files:
|
||||||
|
if os.path.exists(cache_file):
|
||||||
|
os.remove(cache_file)
|
||||||
|
print(f"🗑️ 已删除缓存文件: {cache_file}")
|
||||||
|
|
||||||
|
if os.path.exists(cache_dir) and not os.listdir(cache_dir):
|
||||||
|
os.rmdir(cache_dir)
|
||||||
|
print(f"🗑️ 已删除空的缓存目录: {cache_dir}")
|
||||||
|
|
||||||
def load_and_preprocess_data(self, path):
|
def load_and_preprocess_data(self, path):
|
||||||
"""加载并预处理三元组数据"""
|
"""加载并预处理三元组数据"""
|
||||||
|
# 生成缓存文件名(基于数据文件路径)
|
||||||
|
cache_dir = os.path.join(os.path.dirname(path), 'cache')
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
|
||||||
|
data_filename = os.path.basename(path).split('.')[0]
|
||||||
|
cache_files = {
|
||||||
|
'predicate_vocab': os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
|
||||||
|
'predicate_to_id': os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
|
||||||
|
'train_samples': os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
|
||||||
|
'val_samples': os.path.join(cache_dir, f'{data_filename}_val_samples.json')
|
||||||
|
}
|
||||||
|
|
||||||
|
# 检查缓存文件是否存在
|
||||||
|
cache_exists = all(os.path.exists(cache_file) for cache_file in cache_files.values())
|
||||||
|
|
||||||
|
if cache_exists:
|
||||||
|
print("📁 发现缓存文件,直接加载...")
|
||||||
|
# 从缓存加载
|
||||||
|
with open(cache_files['predicate_vocab'], 'r', encoding='utf-8') as f:
|
||||||
|
self.predicate_vocab = json.load(f)
|
||||||
|
|
||||||
|
with open(cache_files['predicate_to_id'], 'r', encoding='utf-8') as f:
|
||||||
|
self.predicate_to_id = json.load(f)
|
||||||
|
|
||||||
|
with open(cache_files['train_samples'], 'r', encoding='utf-8') as f:
|
||||||
|
train_samples = json.load(f)
|
||||||
|
|
||||||
|
with open(cache_files['val_samples'], 'r', encoding='utf-8') as f:
|
||||||
|
val_samples = json.load(f)
|
||||||
|
|
||||||
|
print(f"✅ 从缓存加载完成:")
|
||||||
|
print(f"✅ 谓词词表大小: {len(self.predicate_vocab)}")
|
||||||
|
print(f"✅ 训练集大小: {len(train_samples)}")
|
||||||
|
print(f"✅ 测试集大小: {len(val_samples)}")
|
||||||
|
|
||||||
|
return train_samples, val_samples
|
||||||
|
|
||||||
|
# 缓存不存在,重新处理数据
|
||||||
|
print("📂 缓存不存在,开始加载和处理原始数据...")
|
||||||
|
|
||||||
# 1. 加载原始数据
|
# 1. 加载原始数据
|
||||||
print("📂 加载原始数据...")
|
print("📂 加载原始数据...")
|
||||||
if path.endswith('.json'):
|
if path.endswith('.json'):
|
||||||
@ -228,71 +371,92 @@ class TriplePretrainDataset(Dataset):
|
|||||||
raise ValueError(f"Unsupported file format: {path}")
|
raise ValueError(f"Unsupported file format: {path}")
|
||||||
|
|
||||||
print(f"📊 原始数据量: {len(data)} 个样本")
|
print(f"📊 原始数据量: {len(data)} 个样本")
|
||||||
|
|
||||||
|
# 2. 使用self.predicate_vocab过滤占比小于0.01%的谓词数据
|
||||||
|
print("🔍 过滤低频谓词数据...")
|
||||||
|
print(f"📊 谓词统计数据: 总共{len(self.predicate_vocab)}个谓词")
|
||||||
|
|
||||||
# 2. 数据验证和筛选(只保留一个target)
|
# 3.获取占比大于等于0.01%的谓词
|
||||||
print("🔍 验证数据格式并选择单个target...")
|
valid_predicates = set()
|
||||||
|
for predicate, stats in self.predicate_vocab.items():
|
||||||
|
if isinstance(stats, dict) and 'percentage' in stats:
|
||||||
|
if stats['percentage'] >= 0.01:
|
||||||
|
valid_predicates.add(predicate)
|
||||||
|
else:
|
||||||
|
# 如果不是统计格式,假设是有效谓词
|
||||||
|
valid_predicates.add(predicate)
|
||||||
|
|
||||||
|
print(f"📊 占比≥0.01%的谓词: {len(valid_predicates)}个")
|
||||||
|
|
||||||
|
# 4.过滤数据:去除包含低频谓词的数据(单进程处理)
|
||||||
|
original_count = len(data)
|
||||||
|
filtered_data = []
|
||||||
|
|
||||||
|
print("🚀 开始过滤低频谓词数据...")
|
||||||
|
for sample in tqdm(data, desc="过滤低频谓词"):
|
||||||
|
result = process_sample_filter((sample, valid_predicates))
|
||||||
|
if result is not None:
|
||||||
|
filtered_data.append(result)
|
||||||
|
|
||||||
|
data = filtered_data
|
||||||
|
print(f"✅ 过滤完成: 去除前{original_count}条,去除后{len(data)}条")
|
||||||
|
|
||||||
|
# 5. 去除self.predicate_vocab中占比小于0.01%的谓词,并创建谓词到序号的映射
|
||||||
|
print("🔍 更新谓词词表并创建序号映射...")
|
||||||
|
original_vocab_size = len(self.predicate_vocab)
|
||||||
|
filtered_predicate_vocab = {}
|
||||||
|
|
||||||
|
for predicate, stats in self.predicate_vocab.items():
|
||||||
|
if isinstance(stats, dict) and 'percentage' in stats:
|
||||||
|
if stats['percentage'] >= 0.01:
|
||||||
|
filtered_predicate_vocab[predicate] = stats
|
||||||
|
else:
|
||||||
|
# 如果不是统计格式,保留
|
||||||
|
filtered_predicate_vocab[predicate] = stats
|
||||||
|
|
||||||
|
# 创建谓词到序号的映射字典
|
||||||
|
self.predicate_to_id = {predicate: idx for idx, predicate in enumerate(filtered_predicate_vocab.keys())}
|
||||||
|
self.predicate_vocab = filtered_predicate_vocab
|
||||||
|
print(f"✅ 谓词词表更新: 去除前{original_vocab_size}个,去除后{len(self.predicate_vocab)}个")
|
||||||
|
print(f"✅ 谓词映射创建: {len(self.predicate_to_id)}个谓词对应序号")
|
||||||
|
|
||||||
|
# 6. 数据验证和筛选(只保留一个target),优先选择占比小的谓词以平衡数据(单进程处理)
|
||||||
|
print("🔍 验证数据格式并选择单个target(平衡数据)...")
|
||||||
valid_samples = []
|
valid_samples = []
|
||||||
|
|
||||||
for i, sample in enumerate(tqdm(data, desc="验证数据格式")):
|
print("🚀 开始验证数据格式...")
|
||||||
if not isinstance(sample, dict) or 'text' not in sample:
|
for sample in tqdm(data, desc="验证数据格式"):
|
||||||
continue
|
result = process_sample_validation((sample, self.predicate_vocab))
|
||||||
|
if result is not None:
|
||||||
targets = sample.get('target', [])
|
valid_samples.append(result)
|
||||||
if not isinstance(targets, list) or len(targets) == 0:
|
|
||||||
# 如果没有有效的target,创建一个默认的
|
|
||||||
selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
|
|
||||||
else:
|
|
||||||
# 验证并选择第一个有效的target
|
|
||||||
selected_target = None
|
|
||||||
for triple in targets:
|
|
||||||
if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
|
|
||||||
selected_target = triple
|
|
||||||
break
|
|
||||||
|
|
||||||
# 如果没有找到有效的target,使用默认值
|
|
||||||
if selected_target is None:
|
|
||||||
selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
|
|
||||||
|
|
||||||
valid_samples.append({
|
|
||||||
'text': sample['text'],
|
|
||||||
'target': selected_target # 只保留一个target
|
|
||||||
})
|
|
||||||
|
|
||||||
print(f"✅ 有效样本数: {len(valid_samples)}")
|
print(f"✅ 有效样本数: {len(valid_samples)}")
|
||||||
|
|
||||||
|
# 7.拆分训练集合与测试集合
|
||||||
|
import random
|
||||||
|
random.seed(42)
|
||||||
|
val_samples = random.sample(valid_samples, min(1000, len(valid_samples)))
|
||||||
|
train_samples = [sample for sample in valid_samples if sample not in val_samples]
|
||||||
|
print(f"✅ 训练集大小: {len(train_samples)}")
|
||||||
|
print(f"✅ 测试集大小: {len(val_samples)}")
|
||||||
|
|
||||||
|
# 8. 保存到缓存文件
|
||||||
|
print("💾 保存处理结果到缓存文件...")
|
||||||
|
with open(cache_files['predicate_vocab'], 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(self.predicate_vocab, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
# 3. 分批tokenize目标句子
|
with open(cache_files['predicate_to_id'], 'w', encoding='utf-8') as f:
|
||||||
print("🔤 分批tokenize目标句子...")
|
json.dump(self.predicate_to_id, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
processed_samples = []
|
with open(cache_files['train_samples'], 'w', encoding='utf-8') as f:
|
||||||
batch_size = 1000 # 每批处理1000个句子,避免内存爆炸
|
json.dump(train_samples, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
for i in tqdm(range(0, len(valid_samples), batch_size), desc="分批tokenize目标句子"):
|
with open(cache_files['val_samples'], 'w', encoding='utf-8') as f:
|
||||||
# 获取当前批次
|
json.dump(val_samples, f, ensure_ascii=False, indent=2)
|
||||||
batch_samples = valid_samples[i:i + batch_size]
|
|
||||||
|
|
||||||
# 提取当前批次的目标句子
|
|
||||||
batch_target_sentences = [self._triple_to_sentence(sample['target']) for sample in batch_samples]
|
|
||||||
|
|
||||||
# 批量tokenize当前批次
|
|
||||||
batch_encodings = self.tokenizer(
|
|
||||||
batch_target_sentences,
|
|
||||||
max_length=128, # 目标句子通常较短
|
|
||||||
padding='max_length',
|
|
||||||
truncation=True,
|
|
||||||
return_tensors='pt'
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建当前批次的样本数据
|
|
||||||
for j, sample in enumerate(batch_samples):
|
|
||||||
processed_samples.append({
|
|
||||||
'text': sample['text'], # 保持原始文本,不进行tokenize
|
|
||||||
'target_input_ids': batch_encodings.input_ids[j],
|
|
||||||
'target_attention_mask': batch_encodings.attention_mask[j],
|
|
||||||
'target_sentence': batch_target_sentences[j], # 保留原始句子用于调试
|
|
||||||
})
|
|
||||||
|
|
||||||
print(f"🎉 数据预处理完成! 共处理 {len(processed_samples)} 个样本")
|
print("✅ 缓存文件保存完成")
|
||||||
return processed_samples
|
|
||||||
|
return train_samples, val_samples
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.samples)
|
return len(self.samples)
|
||||||
@ -302,10 +466,10 @@ class TriplePretrainDataset(Dataset):
|
|||||||
return f"{triple['subject']} {triple['predicate']} {triple['object']}"
|
return f"{triple['subject']} {triple['predicate']} {triple['object']}"
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
"""返回数据,输入文本在运行时tokenize,目标已预tokenize"""
|
"""返回数据,用于谓词分类任务"""
|
||||||
sample = self.samples[index]
|
sample = self.samples[index]
|
||||||
|
|
||||||
# 在运行时tokenize输入文本(用于语言建模)
|
# 在运行时tokenize输入文本
|
||||||
input_text = f"{self.tokenizer.bos_token}{sample['text']}{self.tokenizer.eos_token}"
|
input_text = f"{self.tokenizer.bos_token}{sample['text']}{self.tokenizer.eos_token}"
|
||||||
encoding = self.tokenizer(
|
encoding = self.tokenizer(
|
||||||
input_text,
|
input_text,
|
||||||
@ -317,19 +481,18 @@ class TriplePretrainDataset(Dataset):
|
|||||||
input_ids = encoding.input_ids.squeeze()
|
input_ids = encoding.input_ids.squeeze()
|
||||||
loss_mask = (input_ids != self.tokenizer.pad_token_id)
|
loss_mask = (input_ids != self.tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
# 获取谓词分类标签
|
||||||
|
target_predicate = sample['target']['predicate']
|
||||||
|
predicate_label = self.predicate_to_id.get(target_predicate) # 默认为0如果找不到
|
||||||
|
|
||||||
# 构建训练数据
|
# 构建训练数据
|
||||||
X = input_ids[:-1]
|
X = input_ids[:-1]
|
||||||
Y = input_ids[1:]
|
|
||||||
loss_mask = loss_mask[1:]
|
loss_mask = loss_mask[1:]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'input_ids': X,
|
'input_ids': X,
|
||||||
'labels': Y,
|
'labels': torch.tensor(predicate_label, dtype=torch.long), # 谓词分类标签
|
||||||
'loss_mask': loss_mask,
|
'loss_mask': loss_mask
|
||||||
'target_input_ids': sample['target_input_ids'], # 已经是tensor
|
|
||||||
'target_attention_mask': sample['target_attention_mask'], # 已经是tensor
|
|
||||||
'target_sentence': sample['target_sentence'], # 字符串,用于调试
|
|
||||||
'original_text': sample['text']
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -489,8 +489,8 @@ class TripleExtractionHead(nn.Module):
|
|||||||
self.self_attn_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
self.self_attn_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
||||||
|
|
||||||
# 交叉注意力机制(用于主语和宾语提取)
|
# 交叉注意力机制(用于主语和宾语提取)
|
||||||
self.cross_attention_subject = CrossAttention(config)
|
# self.cross_attention_subject = CrossAttention(config)
|
||||||
self.cross_attention_object = CrossAttention(config)
|
# self.cross_attention_object = CrossAttention(config)
|
||||||
|
|
||||||
# 归一化层
|
# 归一化层
|
||||||
self.subject_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
self.subject_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
||||||
@ -498,13 +498,13 @@ class TripleExtractionHead(nn.Module):
|
|||||||
|
|
||||||
# Feed Forward 网络
|
# Feed Forward 网络
|
||||||
self.predicate_ff = FeedForward(config)
|
self.predicate_ff = FeedForward(config)
|
||||||
self.subject_ff = FeedForward(config)
|
# self.subject_ff = FeedForward(config)
|
||||||
self.object_ff = FeedForward(config)
|
# self.object_ff = FeedForward(config)
|
||||||
|
|
||||||
# 输出投影层 - 修改为支持序列预测
|
# 输出投影层 - 修改为支持序列预测
|
||||||
self.predicate_output = nn.Linear(config.dim, self.max_predicate_len *config.dim, bias=False)
|
self.predicate_output = nn.Linear(config.dim, 264, bias=False)
|
||||||
self.subject_output = nn.Linear(config.dim, self.max_subject_len * config.dim, bias=False)
|
# self.subject_output = nn.Linear(config.dim, self.max_subject_len * config.dim, bias=False)
|
||||||
self.object_output = nn.Linear(config.dim, self.max_object_len * config.dim, bias=False)
|
# self.object_output = nn.Linear(config.dim, self.max_object_len * config.dim, bias=False)
|
||||||
|
|
||||||
print(f"三元组提取任务头配置:")
|
print(f"三元组提取任务头配置:")
|
||||||
print(f"- 主语最大长度: {self.max_subject_len}")
|
print(f"- 主语最大长度: {self.max_subject_len}")
|
||||||
@ -530,30 +530,29 @@ class TripleExtractionHead(nn.Module):
|
|||||||
# 2. h1通过feed_forward得到谓语输出
|
# 2. h1通过feed_forward得到谓语输出
|
||||||
predicate_features = self.predicate_ff(h1)
|
predicate_features = self.predicate_ff(h1)
|
||||||
predicate_features = predicate_features.mean(dim=1)
|
predicate_features = predicate_features.mean(dim=1)
|
||||||
predicate_raw = self.predicate_output(predicate_features) # [batch_size, max_predicate_len * vocab_size]
|
predicate_class = self.predicate_output(predicate_features) # [batch_size, max_predicate_len * vocab_size]
|
||||||
predicate_logits = predicate_raw.view(batch_size, self.max_predicate_len, -1)
|
|
||||||
|
|
||||||
# 3. h1通过交叉注意力(k,v都是h)得到h2
|
# # 3. h1通过交叉注意力(k,v都是h)得到h2
|
||||||
h2 = self.cross_attention_subject(h1, h) # query是h1,key和value都是h
|
# h2 = self.cross_attention_subject(h1, h) # query是h1,key和value都是h
|
||||||
h2 = h1 + h2 # 残差连接
|
# h2 = h1 + h2 # 残差连接
|
||||||
|
|
||||||
# 4. h2通过feed_forward得到主语输出
|
# # 4. h2通过feed_forward得到主语输出
|
||||||
subject_features = self.subject_ff(self.subject_norm(h2))
|
# subject_features = self.subject_ff(self.subject_norm(h2))
|
||||||
subject_features = subject_features.mean(dim=1)
|
# subject_features = subject_features.mean(dim=1)
|
||||||
subject_raw = self.subject_output(subject_features) # [batch_size, max_subject_len * vocab_size]
|
# subject_raw = self.subject_output(subject_features) # [batch_size, max_subject_len * vocab_size]
|
||||||
subject_logits = subject_raw.view(batch_size, self.max_subject_len, -1)
|
# subject_logits = subject_raw.view(batch_size, self.max_subject_len, -1)
|
||||||
|
|
||||||
# 5. h2通过交叉注意力(k,v都是h)得到h3
|
# # 5. h2通过交叉注意力(k,v都是h)得到h3
|
||||||
h3 = self.cross_attention_object(h2, h) # query是h2,key和value都是h
|
# h3 = self.cross_attention_object(h2, h) # query是h2,key和value都是h
|
||||||
h3 = h2 + h3 # 残差连接
|
# h3 = h2 + h3 # 残差连接
|
||||||
|
|
||||||
# 6. h3通过feed_forward得到宾语输出
|
# # 6. h3通过feed_forward得到宾语输出
|
||||||
object_features = self.object_ff(self.object_norm(h3))
|
# object_features = self.object_ff(self.object_norm(h3))
|
||||||
object_features = object_features.mean(dim=1)
|
# object_features = object_features.mean(dim=1)
|
||||||
object_raw = self.object_output(object_features) # [batch_size, max_object_len * vocab_size]
|
# object_raw = self.object_output(object_features) # [batch_size, max_object_len * vocab_size]
|
||||||
object_logits = object_raw.view(batch_size, self.max_object_len, -1)
|
# object_logits = object_raw.view(batch_size, self.max_object_len, -1)
|
||||||
|
|
||||||
return predicate_logits, subject_logits, object_logits
|
return predicate_class
|
||||||
|
|
||||||
|
|
||||||
class MiniMindBlock(nn.Module):
|
class MiniMindBlock(nn.Module):
|
||||||
@ -656,18 +655,8 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 应用三元组提取任务头
|
# 应用三元组提取任务头
|
||||||
predicate_logits, subject_logits, object_logits = self.triple_extraction_head(h, pos_cis)
|
predicate_class = self.triple_extraction_head(h, pos_cis)
|
||||||
predicate_logits = predicate_logits.reshape(input_ids.size(0)*self.params.max_predicate_len, -1)
|
|
||||||
subject_logits = subject_logits.reshape(input_ids.size(0)*self.params.max_subject_len, -1)
|
|
||||||
object_logits = object_logits.reshape(input_ids.size(0)*self.params.max_object_len, -1)
|
|
||||||
|
|
||||||
predicate_logits = self.output(predicate_logits)
|
|
||||||
subject_logits = self.output(subject_logits)
|
|
||||||
object_logits = self.output(object_logits)
|
|
||||||
|
|
||||||
predicate_logits = predicate_logits.reshape(input_ids.size(0), self.params.max_predicate_len, -1)
|
|
||||||
subject_logits = subject_logits.reshape(input_ids.size(0), self.params.max_subject_len, -1)
|
|
||||||
object_logits = object_logits.reshape(input_ids.size(0), self.params.max_object_len, -1)
|
|
||||||
|
|
||||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||||
logits = self.output(self.norm(h)[:, slice_indices, :])
|
logits = self.output(self.norm(h)[:, slice_indices, :])
|
||||||
@ -682,9 +671,7 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
|
|
||||||
# 添加三元组提取结果
|
# 添加三元组提取结果
|
||||||
# 注意:现在的维度是 [batch_size, seq_len, max_len, vocab_size]
|
# 注意:现在的维度是 [batch_size, seq_len, max_len, vocab_size]
|
||||||
output.predicate_logits = predicate_logits
|
output.predicate_class = predicate_class
|
||||||
output.subject_logits = subject_logits
|
|
||||||
output.object_logits = object_logits
|
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@ -1,225 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
JSON文件合并脚本
|
|
||||||
读取多个JSON文件并合并为一个JSON文件
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from typing import Dict, List, Any, Union
|
|
||||||
|
|
||||||
# 需要合并的JSON文件列表
|
|
||||||
JSON_FILES_TO_MERGE = [
|
|
||||||
"output/trex_sentences_enhanced_checkpoint_360000.json"
|
|
||||||
]
|
|
||||||
for i in range(1, 1010):
|
|
||||||
JSON_FILES_TO_MERGE.append(f"output/trex_sentences_enhanced_batch_{i}.json")
|
|
||||||
|
|
||||||
def load_json_file(file_path: str) -> Union[Dict, List, None]:
|
|
||||||
"""加载JSON文件"""
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
print(f"警告: 文件 {file_path} 不存在")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
|
||||||
data = json.load(f)
|
|
||||||
print(f"成功加载: {file_path}")
|
|
||||||
return data
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
print(f"错误: 无法解析JSON文件 {file_path} - {e}")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
print(f"错误: 读取文件 {file_path} 失败 - {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def merge_json_data(data1: Union[Dict, List], data2: Union[Dict, List]) -> Union[Dict, List]:
|
|
||||||
"""合并两个JSON数据结构"""
|
|
||||||
|
|
||||||
# 如果两个都是列表,直接合并
|
|
||||||
if isinstance(data1, list) and isinstance(data2, list):
|
|
||||||
print(f"合并两个列表: {len(data1)} + {len(data2)} = {len(data1) + len(data2)} 项")
|
|
||||||
return data1 + data2
|
|
||||||
|
|
||||||
# 如果两个都是字典
|
|
||||||
elif isinstance(data1, dict) and isinstance(data2, dict):
|
|
||||||
print("合并两个字典结构")
|
|
||||||
merged = data1.copy()
|
|
||||||
|
|
||||||
# 特殊处理:如果都有'sentences'字段且为列表,合并sentences
|
|
||||||
if 'sentences' in data1 and 'sentences' in data2:
|
|
||||||
if isinstance(data1['sentences'], list) and isinstance(data2['sentences'], list):
|
|
||||||
print(f"合并sentences字段: {len(data1['sentences'])} + {len(data2['sentences'])} = {len(data1['sentences']) + len(data2['sentences'])} 项")
|
|
||||||
merged['sentences'] = data1['sentences'] + data2['sentences']
|
|
||||||
|
|
||||||
# 更新metadata if exists
|
|
||||||
if 'metadata' in merged:
|
|
||||||
if isinstance(merged['metadata'], dict):
|
|
||||||
merged['metadata']['total_sentences'] = len(merged['sentences'])
|
|
||||||
merged['metadata']['merged_from'] = [os.path.basename(f) for f in JSON_FILES_TO_MERGE if os.path.exists(f)]
|
|
||||||
|
|
||||||
# 合并其他字段
|
|
||||||
for key, value in data2.items():
|
|
||||||
if key != 'sentences' and key not in merged:
|
|
||||||
merged[key] = value
|
|
||||||
|
|
||||||
return merged
|
|
||||||
|
|
||||||
# 普通字典合并
|
|
||||||
for key, value in data2.items():
|
|
||||||
if key in merged:
|
|
||||||
# 如果key重复且都是列表,合并列表
|
|
||||||
if isinstance(merged[key], list) and isinstance(value, list):
|
|
||||||
merged[key] = merged[key] + value
|
|
||||||
# 如果key重复且都是字典,递归合并
|
|
||||||
elif isinstance(merged[key], dict) and isinstance(value, dict):
|
|
||||||
merged[key] = merge_json_data(merged[key], value)
|
|
||||||
else:
|
|
||||||
# 其他情况保留第二个文件的值
|
|
||||||
merged[key] = value
|
|
||||||
print(f"字段 '{key}' 被覆盖")
|
|
||||||
else:
|
|
||||||
merged[key] = value
|
|
||||||
|
|
||||||
return merged
|
|
||||||
|
|
||||||
# 类型不匹配的情况,创建一个包含两者的新结构
|
|
||||||
else:
|
|
||||||
print("数据类型不匹配,创建包含两者的新结构")
|
|
||||||
return {
|
|
||||||
"data_from_save.json": data1,
|
|
||||||
"data_from_save2.json": data2,
|
|
||||||
"merged_at": "test.py"
|
|
||||||
}
|
|
||||||
|
|
||||||
def save_merged_json(data: Union[Dict, List], output_path: str):
|
|
||||||
"""保存合并后的JSON数据"""
|
|
||||||
try:
|
|
||||||
# 确保输出目录存在
|
|
||||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
|
||||||
|
|
||||||
with open(output_path, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
print(f"合并结果已保存到: {output_path}")
|
|
||||||
|
|
||||||
# 显示统计信息
|
|
||||||
if isinstance(data, dict):
|
|
||||||
if 'sentences' in data and isinstance(data['sentences'], list):
|
|
||||||
print(f"总计句子数: {len(data['sentences'])}")
|
|
||||||
print(f"总计字段数: {len(data)}")
|
|
||||||
elif isinstance(data, list):
|
|
||||||
print(f"总计列表项数: {len(data)}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"错误: 保存文件失败 - {e}")
|
|
||||||
|
|
||||||
def remove_duplicates_from_sentences(data: Union[Dict, List]) -> Union[Dict, List]:
|
|
||||||
"""从合并结果中移除重复的句子(基于句子内容)"""
|
|
||||||
if isinstance(data, dict) and 'sentences' in data:
|
|
||||||
if isinstance(data['sentences'], list):
|
|
||||||
original_count = len(data['sentences'])
|
|
||||||
seen_sentences = set()
|
|
||||||
unique_sentences = []
|
|
||||||
|
|
||||||
for item in data['sentences']:
|
|
||||||
if isinstance(item, dict):
|
|
||||||
# 如果是字典,使用sentence字段或corrected_sentence字段作为唯一标识
|
|
||||||
sentence_key = item.get('sentence') or item.get('corrected_sentence') or item.get('original_sentence')
|
|
||||||
elif isinstance(item, str):
|
|
||||||
sentence_key = item
|
|
||||||
else:
|
|
||||||
sentence_key = str(item)
|
|
||||||
|
|
||||||
if sentence_key and sentence_key not in seen_sentences:
|
|
||||||
seen_sentences.add(sentence_key)
|
|
||||||
unique_sentences.append(item)
|
|
||||||
|
|
||||||
data['sentences'] = unique_sentences
|
|
||||||
|
|
||||||
# 更新metadata
|
|
||||||
if 'metadata' in data and isinstance(data['metadata'], dict):
|
|
||||||
data['metadata']['total_sentences'] = len(unique_sentences)
|
|
||||||
data['metadata']['duplicates_removed'] = original_count - len(unique_sentences)
|
|
||||||
|
|
||||||
print(f"去重完成: {original_count} -> {len(unique_sentences)} (移除了 {original_count - len(unique_sentences)} 个重复项)")
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def merge_multiple_json_data(data_list: List[Union[Dict, List]]) -> Union[Dict, List]:
|
|
||||||
"""合并多个JSON数据结构"""
|
|
||||||
if not data_list:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
if len(data_list) == 1:
|
|
||||||
return data_list[0]
|
|
||||||
|
|
||||||
print(f"准备合并 {len(data_list)} 个JSON数据结构")
|
|
||||||
|
|
||||||
# 从第一个数据开始,逐步合并其他数据
|
|
||||||
merged_data = data_list[0]
|
|
||||||
|
|
||||||
for i, data in enumerate(data_list[1:], 1):
|
|
||||||
print(f"正在合并第 {i+1} 个数据结构...")
|
|
||||||
merged_data = merge_json_data(merged_data, data)
|
|
||||||
|
|
||||||
return merged_data
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""主函数"""
|
|
||||||
print("=== JSON文件合并脚本 ===")
|
|
||||||
|
|
||||||
# 输出路径
|
|
||||||
output_path = "output/merged.json"
|
|
||||||
|
|
||||||
print(f"准备合并以下文件:")
|
|
||||||
for i, file_path in enumerate(JSON_FILES_TO_MERGE, 1):
|
|
||||||
print(f" {i}. {file_path}")
|
|
||||||
print(f"输出文件: {output_path}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# 加载所有文件
|
|
||||||
loaded_data = []
|
|
||||||
successfully_loaded = []
|
|
||||||
|
|
||||||
for file_path in JSON_FILES_TO_MERGE:
|
|
||||||
data = load_json_file(file_path)
|
|
||||||
if data is not None:
|
|
||||||
loaded_data.append(data)
|
|
||||||
successfully_loaded.append(file_path)
|
|
||||||
|
|
||||||
# 检查是否至少有一个文件加载成功
|
|
||||||
if not loaded_data:
|
|
||||||
print("错误: 没有文件能够成功加载,退出")
|
|
||||||
return
|
|
||||||
|
|
||||||
print(f"成功加载了 {len(loaded_data)} 个文件:")
|
|
||||||
for file_path in successfully_loaded:
|
|
||||||
print(f" ✓ {file_path}")
|
|
||||||
|
|
||||||
if len(loaded_data) < len(JSON_FILES_TO_MERGE):
|
|
||||||
failed_count = len(JSON_FILES_TO_MERGE) - len(loaded_data)
|
|
||||||
print(f"警告: {failed_count} 个文件加载失败")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# 合并所有数据
|
|
||||||
if len(loaded_data) == 1:
|
|
||||||
print("只有一个文件可用,直接使用...")
|
|
||||||
merged_data = loaded_data[0]
|
|
||||||
else:
|
|
||||||
print("开始合并所有文件...")
|
|
||||||
merged_data = merge_multiple_json_data(loaded_data)
|
|
||||||
|
|
||||||
# 去重处理
|
|
||||||
print("\n检查并去除重复项...")
|
|
||||||
merged_data = remove_duplicates_from_sentences(merged_data)
|
|
||||||
|
|
||||||
# 保存合并结果
|
|
||||||
print("\n保存合并结果...")
|
|
||||||
save_merged_json(merged_data, output_path)
|
|
||||||
|
|
||||||
print("\n=== 合并完成 ===")
|
|
||||||
print(f"合并了 {len(successfully_loaded)} 个文件的数据")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@ -1,61 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
小规模测试预处理脚本
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
# 添加路径
|
|
||||||
sys.path.append('/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/preprocessing')
|
|
||||||
|
|
||||||
# 导入主模块
|
|
||||||
from preprocess_pretrain import *
|
|
||||||
|
|
||||||
# 修改配置为小规模测试
|
|
||||||
DATASET_CONFIG["wikipedia"]["max_samples"] = 100
|
|
||||||
DATASET_CONFIG["gutenberg"]["max_samples"] = 50
|
|
||||||
DATASET_CONFIG["openwebtext"]["max_samples"] = 20
|
|
||||||
|
|
||||||
DATASET_CONFIG_EXTRA["wikipedia"]["max_samples"] = 50
|
|
||||||
DATASET_CONFIG_EXTRA["gutenberg"]["max_samples"] = 30
|
|
||||||
DATASET_CONFIG_EXTRA["openwebtext"]["max_samples"] = 15
|
|
||||||
|
|
||||||
# 修改输出路径
|
|
||||||
OUTPUT_FILE = "/tmp/test_main.jsonl"
|
|
||||||
OUTPUT_FILE_EXTRA = "/tmp/test_extra.jsonl"
|
|
||||||
|
|
||||||
def test_small_scale():
|
|
||||||
"""小规模测试"""
|
|
||||||
print("Starting small scale test...")
|
|
||||||
|
|
||||||
# 设置随机种子
|
|
||||||
random.seed(42)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 初始化tokenizer
|
|
||||||
init_tokenizer()
|
|
||||||
|
|
||||||
# 开始合并数据集
|
|
||||||
merge_datasets()
|
|
||||||
|
|
||||||
# 检查输出文件
|
|
||||||
if os.path.exists(OUTPUT_FILE):
|
|
||||||
with open(OUTPUT_FILE, 'r') as f:
|
|
||||||
main_lines = len(f.readlines())
|
|
||||||
print(f"Main file created: {main_lines} lines")
|
|
||||||
|
|
||||||
if os.path.exists(OUTPUT_FILE_EXTRA):
|
|
||||||
with open(OUTPUT_FILE_EXTRA, 'r') as f:
|
|
||||||
extra_lines = len(f.readlines())
|
|
||||||
print(f"Extra file created: {extra_lines} lines")
|
|
||||||
|
|
||||||
print("Small scale test completed successfully!")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Test failed: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_small_scale()
|
|
File diff suppressed because it is too large
Load Diff
@ -127,6 +127,7 @@ dependencies = [
|
|||||||
"regex==2024.11.6",
|
"regex==2024.11.6",
|
||||||
"requests==2.32.3",
|
"requests==2.32.3",
|
||||||
"rich==13.7.1",
|
"rich==13.7.1",
|
||||||
|
"rouge-score>=0.1.2",
|
||||||
"rpds-py==0.24.0",
|
"rpds-py==0.24.0",
|
||||||
"s3transfer==0.13.0",
|
"s3transfer==0.13.0",
|
||||||
"safetensors==0.5.3",
|
"safetensors==0.5.3",
|
||||||
|
@ -397,6 +397,66 @@ def log_memory_status(step, accelerator, stage="", detailed=False):
|
|||||||
|
|
||||||
Logger(log_msg, accelerator)
|
Logger(log_msg, accelerator)
|
||||||
|
|
||||||
|
# 验证函数
|
||||||
|
def validate_model(model, val_loader, accelerator, ctx, args):
|
||||||
|
"""
|
||||||
|
验证模型性能
|
||||||
|
Args:
|
||||||
|
model: 模型
|
||||||
|
val_loader: 验证集数据加载器
|
||||||
|
accelerator: accelerator对象
|
||||||
|
ctx: 上下文管理器
|
||||||
|
args: 参数
|
||||||
|
Returns:
|
||||||
|
dict: 包含平均损失和准确率的字典
|
||||||
|
"""
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
total_loss = 0.0
|
||||||
|
correct_predictions = 0
|
||||||
|
total_predictions = 0
|
||||||
|
num_batches = 0
|
||||||
|
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch_data in val_loader:
|
||||||
|
try:
|
||||||
|
# 数据准备
|
||||||
|
X = batch_data['input_ids'].to(accelerator.device)
|
||||||
|
Y = batch_data['labels']
|
||||||
|
|
||||||
|
# 前向传播
|
||||||
|
with ctx:
|
||||||
|
res = model(X, step=0) # 验证时step设为0
|
||||||
|
loss = criterion(res.predicate_class.cpu(), Y.cpu())
|
||||||
|
|
||||||
|
# 计算准确率
|
||||||
|
predicted_classes = torch.argmax(res.predicate_class, dim=1)
|
||||||
|
predicted_classes = predicted_classes.to(Y.device)
|
||||||
|
correct_predictions += (predicted_classes == Y).sum().item()
|
||||||
|
total_predictions += Y.size(0)
|
||||||
|
|
||||||
|
# 累计损失
|
||||||
|
total_loss += loss.item()
|
||||||
|
num_batches += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
Logger(f"验证时出错: {e}", accelerator)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 计算平均值
|
||||||
|
avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
|
||||||
|
accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
|
||||||
|
|
||||||
|
model.train() # 重新设置为训练模式
|
||||||
|
|
||||||
|
return {
|
||||||
|
'avg_loss': avg_loss,
|
||||||
|
'accuracy': accuracy,
|
||||||
|
'total_samples': total_predictions
|
||||||
|
}
|
||||||
|
|
||||||
# 日志记录函数
|
# 日志记录函数
|
||||||
def Logger(msg, accelerator=None):
|
def Logger(msg, accelerator=None):
|
||||||
# 如果没有提供accelerator,则只在主进程打印
|
# 如果没有提供accelerator,则只在主进程打印
|
||||||
@ -515,7 +575,7 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
|||||||
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer):
|
def train_epoch(epoch, accelerator, model, train_loader,val_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer):
|
||||||
# 三元组提取训练模式:不需要传统的交叉熵损失函数
|
# 三元组提取训练模式:不需要传统的交叉熵损失函数
|
||||||
epoch_start_time = time.time()
|
epoch_start_time = time.time()
|
||||||
total_steps_in_epoch = len(train_loader)
|
total_steps_in_epoch = len(train_loader)
|
||||||
@ -563,9 +623,9 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
|||||||
X = batch_data['input_ids']
|
X = batch_data['input_ids']
|
||||||
Y = batch_data['labels']
|
Y = batch_data['labels']
|
||||||
loss_mask = batch_data['loss_mask']
|
loss_mask = batch_data['loss_mask']
|
||||||
target_input_ids = batch_data['target_input_ids']
|
# target_input_ids = batch_data['target_input_ids']
|
||||||
target_attention_mask = batch_data['target_attention_mask']
|
# target_attention_mask = batch_data['target_attention_mask']
|
||||||
target_sentences = batch_data['target_sentences'] # 用于调试输出
|
# target_sentences = batch_data['target_sentences'] # 用于调试输出
|
||||||
|
|
||||||
# === 2. 学习率更新 ===
|
# === 2. 学习率更新 ===
|
||||||
if scheduler is not None:
|
if scheduler is not None:
|
||||||
@ -590,36 +650,34 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
|||||||
|
|
||||||
# === 4. 损失计算 ===
|
# === 4. 损失计算 ===
|
||||||
# 三元组提取模式:只使用ROUGE Loss进行三元组损失计算
|
# 三元组提取模式:只使用ROUGE Loss进行三元组损失计算
|
||||||
Logger("三元组提取训练模式", accelerator) if step == 0 else None
|
# Logger("三元组提取训练模式", accelerator) if step == 0 else None
|
||||||
|
|
||||||
# 确保有三元组输出
|
# # 确保有三元组输出
|
||||||
if not (hasattr(res, 'predicate_logits') and hasattr(res, 'subject_logits') and hasattr(res, 'object_logits')):
|
# if not (hasattr(res, 'predicate_logits') and hasattr(res, 'subject_logits') and hasattr(res, 'object_logits')):
|
||||||
raise ValueError("模型没有输出三元组logits,请检查模型配置")
|
# raise ValueError("模型没有输出三元组logits,请检查模型配置")
|
||||||
|
|
||||||
# 确保有目标数据
|
# # 确保有目标数据
|
||||||
if target_input_ids is None:
|
# if target_input_ids is None:
|
||||||
raise ValueError("没有三元组目标数据,请检查数据格式")
|
# raise ValueError("没有三元组目标数据,请检查数据格式")
|
||||||
|
|
||||||
# 计算三元组损失
|
# 计算分类损失
|
||||||
try:
|
try:
|
||||||
Logger("使用预tokenized三元组目标数据", accelerator) if step == 0 else None
|
Logger("使用分类交叉熵损失", accelerator) if step == 0 else None
|
||||||
|
|
||||||
# 计时GPU损失计算
|
# 计时GPU损失计算
|
||||||
if args.profile and accelerator.is_main_process and loss_start is not None:
|
if args.profile and accelerator.is_main_process and loss_start is not None:
|
||||||
loss_start.record()
|
loss_start.record()
|
||||||
|
|
||||||
# 计算优化后的嵌入余弦相似度损失
|
# 计算交叉熵损失
|
||||||
loss = compute_triple_rouge_loss_optimized(
|
criterion = nn.CrossEntropyLoss()
|
||||||
res.subject_logits, res.predicate_logits, res.object_logits,
|
loss = criterion(res.predicate_class, Y)
|
||||||
target_input_ids, target_attention_mask, model.tok_embeddings, temperature=args.temperature
|
|
||||||
)
|
|
||||||
|
|
||||||
# 计时GPU损失计算结束
|
# 计时GPU损失计算结束
|
||||||
if args.profile and accelerator.is_main_process and loss_end is not None:
|
if args.profile and accelerator.is_main_process and loss_end is not None:
|
||||||
loss_end.record()
|
loss_end.record()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
Logger(f"Error: ROUGE loss computation failed: {e}", accelerator)
|
Logger(f"Error: 分类损失计算失败: {e}", accelerator)
|
||||||
import traceback
|
import traceback
|
||||||
Logger(f"Traceback: {traceback.format_exc()}", accelerator)
|
Logger(f"Traceback: {traceback.format_exc()}", accelerator)
|
||||||
loss = res.logits.sum() * 0.0 + 1.0
|
loss = res.logits.sum() * 0.0 + 1.0
|
||||||
@ -683,13 +741,13 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
|||||||
f"GPU利用率: {gpu_time_total/iter_time*100:.1f}%", accelerator)
|
f"GPU利用率: {gpu_time_total/iter_time*100:.1f}%", accelerator)
|
||||||
Logger("=" * 50, accelerator)
|
Logger("=" * 50, accelerator)
|
||||||
|
|
||||||
Logger("=== 三元组预测示例 ===", accelerator)
|
# Logger("=== 三元组预测示例 ===", accelerator)
|
||||||
predict_sentences = triple_to_sentence(res.subject_logits, res.predicate_logits, res.object_logits,tokenizer)
|
# predict_sentences = triple_to_sentence(res.subject_logits, res.predicate_logits, res.object_logits,tokenizer)
|
||||||
# 显示前2个样本的目标句子
|
# # 显示前2个样本的目标句子
|
||||||
for i, target_sentence in enumerate(target_sentences[:2]):
|
# for i, target_sentence in enumerate(target_sentences[:2]):
|
||||||
Logger(f"样本{i+1}目标: {target_sentence}", accelerator)
|
# Logger(f"样本{i+1}目标: {target_sentence}", accelerator)
|
||||||
Logger(f"样本{i+1}预测: {predict_sentences[i]}", accelerator)
|
# Logger(f"样本{i+1}预测: {predict_sentences[i]}", accelerator)
|
||||||
Logger("==================", accelerator)
|
Logger("=======val dataset=========", accelerator)
|
||||||
|
|
||||||
# 重置GPU事件
|
# 重置GPU事件
|
||||||
forward_start = torch.cuda.Event(enable_timing=True)
|
forward_start = torch.cuda.Event(enable_timing=True)
|
||||||
@ -734,11 +792,20 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
|||||||
|
|
||||||
# SwanLab日志记录
|
# SwanLab日志记录
|
||||||
if args.use_swanlab and accelerator.is_main_process and swanlab_run:
|
if args.use_swanlab and accelerator.is_main_process and swanlab_run:
|
||||||
|
Logger("=======val dataset=========", accelerator)
|
||||||
|
|
||||||
|
# 验证集评估
|
||||||
|
val_results = validate_model(model, val_loader, accelerator, ctx, args)
|
||||||
|
Logger(f"验证集结果 - 平均损失: {val_results['avg_loss']:.6f}, 准确率: {val_results['accuracy']:.4f}, 样本数: {val_results['total_samples']}", accelerator)
|
||||||
|
|
||||||
log_dict = {
|
log_dict = {
|
||||||
"epoch": epoch + 1,
|
"epoch": epoch + 1,
|
||||||
"step": step + 1,
|
"step": step + 1,
|
||||||
"total_steps_in_epoch": total_steps_in_epoch,
|
"total_steps_in_epoch": total_steps_in_epoch,
|
||||||
"triple_embedding_cosine_loss": loss.item() * args.accumulation_steps,
|
"train_loss": loss.item() * args.accumulation_steps,
|
||||||
|
"val_loss": val_results['avg_loss'],
|
||||||
|
"val_accuracy": val_results['accuracy'],
|
||||||
|
"val_samples": val_results['total_samples'],
|
||||||
"lr": current_lr,
|
"lr": current_lr,
|
||||||
"tokens_per_sec": tokens_per_sec,
|
"tokens_per_sec": tokens_per_sec,
|
||||||
"epoch_time_left_seconds": epoch_remaining_time,
|
"epoch_time_left_seconds": epoch_remaining_time,
|
||||||
@ -776,7 +843,7 @@ def main():
|
|||||||
parser.add_argument("--out_dir", type=str, default="out")
|
parser.add_argument("--out_dir", type=str, default="out")
|
||||||
parser.add_argument("--epochs", type=int, default=4)
|
parser.add_argument("--epochs", type=int, default=4)
|
||||||
parser.add_argument("--embedding_epoch", type=int, default=2, help="embedding训练的epoch数")
|
parser.add_argument("--embedding_epoch", type=int, default=2, help="embedding训练的epoch数")
|
||||||
parser.add_argument("--batch_size", type=int, default=192)
|
parser.add_argument("--batch_size", type=int, default=256)
|
||||||
parser.add_argument("--learning_rate", type=float, default=2e-4)
|
parser.add_argument("--learning_rate", type=float, default=2e-4)
|
||||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||||
parser.add_argument("--use_swanlab", default=True, action="store_true") # 替换wandb参数
|
parser.add_argument("--use_swanlab", default=True, action="store_true") # 替换wandb参数
|
||||||
@ -793,6 +860,7 @@ def main():
|
|||||||
parser.add_argument('--use_moe', default=False, type=bool)
|
parser.add_argument('--use_moe', default=False, type=bool)
|
||||||
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能,使用固定值1e-4替代")
|
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能,使用固定值1e-4替代")
|
||||||
parser.add_argument("--data_path", type=str, default="./dataset/processed_trex_data.json")
|
parser.add_argument("--data_path", type=str, default="./dataset/processed_trex_data.json")
|
||||||
|
parser.add_argument("--predicate_vocab_path", type=str, default="./dataset/predicate_stats.json", help="Path to predicate vocabulary/statistics file")
|
||||||
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
|
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
|
||||||
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
|
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
|
||||||
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
|
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
|
||||||
@ -932,7 +1000,8 @@ def main():
|
|||||||
# 创建数据集和数据加载器(专用于三元组提取训练)
|
# 创建数据集和数据加载器(专用于三元组提取训练)
|
||||||
#########################################################
|
#########################################################
|
||||||
Logger("三元组提取训练:使用 TriplePretrainDataset", accelerator)
|
Logger("三元组提取训练:使用 TriplePretrainDataset", accelerator)
|
||||||
train_ds = TriplePretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
|
train_ds = TriplePretrainDataset(data_path=args.data_path, predicate_vocab_path=args.predicate_vocab_path, tokenizer=tokenizer, max_length=lm_config.max_seq_len)
|
||||||
|
val_ds = TriplePretrainDataset(data_path=args.data_path,samples=train_ds.get_val_samples(), predicate_vocab_path=args.predicate_vocab_path, tokenizer=tokenizer, max_length=lm_config.max_seq_len)
|
||||||
|
|
||||||
# 创建自定义collate_fn来处理优化后的数据格式
|
# 创建自定义collate_fn来处理优化后的数据格式
|
||||||
def triple_collate_fn(batch):
|
def triple_collate_fn(batch):
|
||||||
@ -940,17 +1009,17 @@ def main():
|
|||||||
input_ids = torch.stack([item['input_ids'] for item in batch])
|
input_ids = torch.stack([item['input_ids'] for item in batch])
|
||||||
labels = torch.stack([item['labels'] for item in batch])
|
labels = torch.stack([item['labels'] for item in batch])
|
||||||
loss_mask = torch.stack([item['loss_mask'] for item in batch])
|
loss_mask = torch.stack([item['loss_mask'] for item in batch])
|
||||||
target_input_ids = torch.stack([item['target_input_ids'] for item in batch])
|
# target_input_ids = torch.stack([item['target_input_ids'] for item in batch])
|
||||||
target_attention_mask = torch.stack([item['target_attention_mask'] for item in batch])
|
# target_attention_mask = torch.stack([item['target_attention_mask'] for item in batch])
|
||||||
target_sentences = [item['target_sentence'] for item in batch] # 用于调试
|
# target_sentences = [item['target_sentence'] for item in batch] # 用于调试
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'input_ids': input_ids,
|
'input_ids': input_ids,
|
||||||
'labels': labels,
|
'labels': labels,
|
||||||
'loss_mask': loss_mask,
|
'loss_mask': loss_mask,
|
||||||
'target_input_ids': target_input_ids,
|
# 'target_input_ids': target_input_ids,
|
||||||
'target_attention_mask': target_attention_mask,
|
# 'target_attention_mask': target_attention_mask,
|
||||||
'target_sentences': target_sentences
|
# 'target_sentences': target_sentences
|
||||||
}
|
}
|
||||||
|
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
@ -963,6 +1032,15 @@ def main():
|
|||||||
# persistent_workers 和 prefetch_factor 在 num_workers=0 时自动禁用
|
# persistent_workers 和 prefetch_factor 在 num_workers=0 时自动禁用
|
||||||
collate_fn=triple_collate_fn
|
collate_fn=triple_collate_fn
|
||||||
)
|
)
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_ds,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
pin_memory=False,
|
||||||
|
drop_last=True,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=0,
|
||||||
|
collate_fn=triple_collate_fn
|
||||||
|
)
|
||||||
|
|
||||||
#########################################################
|
#########################################################
|
||||||
# 创建优化器
|
# 创建优化器
|
||||||
@ -993,7 +1071,7 @@ def main():
|
|||||||
overall_start_time = time.time() # Record overall start time
|
overall_start_time = time.time() # Record overall start time
|
||||||
for epoch in range(args.epochs):
|
for epoch in range(args.epochs):
|
||||||
Logger(f"开始第{epoch+1}轮训练", accelerator)
|
Logger(f"开始第{epoch+1}轮训练", accelerator)
|
||||||
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer) # Pass tokenizer
|
train_epoch(epoch, accelerator, model, train_loader,val_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer) # Pass tokenizer
|
||||||
|
|
||||||
# 每个epoch结束后进行内存清理
|
# 每个epoch结束后进行内存清理
|
||||||
Logger(f"第{epoch+1}轮训练完成,进行内存清理", accelerator)
|
Logger(f"第{epoch+1}轮训练完成,进行内存清理", accelerator)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user