使用分类来获取谓词
This commit is contained in:
parent
fcab661af9
commit
75265f6652
291
model/dataset.py
291
model/dataset.py
@ -14,6 +14,70 @@ from tqdm import tqdm
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
|
||||
def process_sample_filter(data_args):
|
||||
"""处理单个样本的过滤逻辑"""
|
||||
sample, valid_predicates = data_args
|
||||
if 'target' in sample and isinstance(sample['target'], list):
|
||||
# 过滤target中的低频谓词
|
||||
valid_targets = []
|
||||
for triple in sample['target']:
|
||||
if isinstance(triple, dict) and 'predicate' in triple:
|
||||
if triple['predicate'] in valid_predicates:
|
||||
valid_targets.append(triple)
|
||||
|
||||
# 如果还有有效的target,保留这个样本
|
||||
if valid_targets:
|
||||
sample['target'] = valid_targets
|
||||
return sample
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
# 如果没有target信息,保留样本
|
||||
return sample
|
||||
|
||||
|
||||
def process_sample_validation(data_args):
|
||||
"""处理单个样本的验证逻辑"""
|
||||
sample, predicate_vocab = data_args
|
||||
if not isinstance(sample, dict) or 'text' not in sample:
|
||||
return None
|
||||
|
||||
targets = sample.get('target', [])
|
||||
if not isinstance(targets, list) or len(targets) == 0:
|
||||
# 如果没有有效的target,创建一个默认的
|
||||
selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
|
||||
else:
|
||||
# 验证并选择target,优先选择占比小的谓词
|
||||
selected_target = None
|
||||
min_percentage = float('inf')
|
||||
|
||||
for triple in targets:
|
||||
if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
|
||||
predicate = triple['predicate']
|
||||
|
||||
# 使用predicate_vocab中的统计信息
|
||||
if predicate in predicate_vocab:
|
||||
stats = predicate_vocab[predicate]
|
||||
if isinstance(stats, dict) and 'percentage' in stats:
|
||||
percentage = stats['percentage']
|
||||
if percentage < min_percentage:
|
||||
min_percentage = percentage
|
||||
selected_target = triple
|
||||
elif selected_target is None:
|
||||
selected_target = triple
|
||||
elif selected_target is None:
|
||||
selected_target = triple
|
||||
|
||||
# 如果没有找到有效的target,使用默认值
|
||||
if selected_target is None:
|
||||
selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
|
||||
|
||||
return {
|
||||
'text': sample['text'],
|
||||
'target': selected_target # 只保留一个target
|
||||
}
|
||||
|
||||
|
||||
class PretrainDataset(Dataset):
|
||||
def __init__(self, data_path, tokenizer, max_length=512):
|
||||
super().__init__()
|
||||
@ -204,15 +268,94 @@ class TriplePretrainDataset(Dataset):
|
||||
- 预先tokenize所有数据
|
||||
- 使用进度条显示处理进度
|
||||
"""
|
||||
def __init__(self, data_path, tokenizer, max_length=512):
|
||||
def __init__(self, data_path=None, predicate_vocab_path=None, samples = None,tokenizer=None, max_length=512):
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.val_samples = None
|
||||
self.predicate_to_id = {} # 初始化
|
||||
if samples is None:
|
||||
self.predicate_vocab = self.load_predicate_vocab(predicate_vocab_path)
|
||||
print("🚀 开始加载和预处理三元组数据...")
|
||||
self.samples = self.load_and_preprocess_data(data_path)
|
||||
self.samples,self.val_samples = self.load_and_preprocess_data(data_path)
|
||||
print("🚀 加载和预处理三元组数据完成")
|
||||
else:
|
||||
cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
|
||||
data_filename = os.path.basename(data_path).split('.')[0]
|
||||
predicate_to_id_path = os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json')
|
||||
self.predicate_to_id = self.load_predicate_vocab(predicate_to_id_path)
|
||||
self.samples = samples
|
||||
print("🚀 加载和预处理三元组数据完成")
|
||||
def load_predicate_vocab(self, path):
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
predicate_vocab = json.load(f)
|
||||
return predicate_vocab
|
||||
|
||||
def get_val_samples(self):
|
||||
return self.val_samples
|
||||
|
||||
def clear_cache(self, data_path):
|
||||
"""清除缓存文件"""
|
||||
cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
|
||||
data_filename = os.path.basename(data_path).split('.')[0]
|
||||
cache_files = [
|
||||
os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
|
||||
os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
|
||||
os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
|
||||
os.path.join(cache_dir, f'{data_filename}_val_samples.json')
|
||||
]
|
||||
|
||||
for cache_file in cache_files:
|
||||
if os.path.exists(cache_file):
|
||||
os.remove(cache_file)
|
||||
print(f"🗑️ 已删除缓存文件: {cache_file}")
|
||||
|
||||
if os.path.exists(cache_dir) and not os.listdir(cache_dir):
|
||||
os.rmdir(cache_dir)
|
||||
print(f"🗑️ 已删除空的缓存目录: {cache_dir}")
|
||||
|
||||
def load_and_preprocess_data(self, path):
|
||||
"""加载并预处理三元组数据"""
|
||||
# 生成缓存文件名(基于数据文件路径)
|
||||
cache_dir = os.path.join(os.path.dirname(path), 'cache')
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
data_filename = os.path.basename(path).split('.')[0]
|
||||
cache_files = {
|
||||
'predicate_vocab': os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
|
||||
'predicate_to_id': os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
|
||||
'train_samples': os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
|
||||
'val_samples': os.path.join(cache_dir, f'{data_filename}_val_samples.json')
|
||||
}
|
||||
|
||||
# 检查缓存文件是否存在
|
||||
cache_exists = all(os.path.exists(cache_file) for cache_file in cache_files.values())
|
||||
|
||||
if cache_exists:
|
||||
print("📁 发现缓存文件,直接加载...")
|
||||
# 从缓存加载
|
||||
with open(cache_files['predicate_vocab'], 'r', encoding='utf-8') as f:
|
||||
self.predicate_vocab = json.load(f)
|
||||
|
||||
with open(cache_files['predicate_to_id'], 'r', encoding='utf-8') as f:
|
||||
self.predicate_to_id = json.load(f)
|
||||
|
||||
with open(cache_files['train_samples'], 'r', encoding='utf-8') as f:
|
||||
train_samples = json.load(f)
|
||||
|
||||
with open(cache_files['val_samples'], 'r', encoding='utf-8') as f:
|
||||
val_samples = json.load(f)
|
||||
|
||||
print(f"✅ 从缓存加载完成:")
|
||||
print(f"✅ 谓词词表大小: {len(self.predicate_vocab)}")
|
||||
print(f"✅ 训练集大小: {len(train_samples)}")
|
||||
print(f"✅ 测试集大小: {len(val_samples)}")
|
||||
|
||||
return train_samples, val_samples
|
||||
|
||||
# 缓存不存在,重新处理数据
|
||||
print("📂 缓存不存在,开始加载和处理原始数据...")
|
||||
|
||||
# 1. 加载原始数据
|
||||
print("📂 加载原始数据...")
|
||||
if path.endswith('.json'):
|
||||
@ -229,70 +372,91 @@ class TriplePretrainDataset(Dataset):
|
||||
|
||||
print(f"📊 原始数据量: {len(data)} 个样本")
|
||||
|
||||
# 2. 数据验证和筛选(只保留一个target)
|
||||
print("🔍 验证数据格式并选择单个target...")
|
||||
# 2. 使用self.predicate_vocab过滤占比小于0.01%的谓词数据
|
||||
print("🔍 过滤低频谓词数据...")
|
||||
print(f"📊 谓词统计数据: 总共{len(self.predicate_vocab)}个谓词")
|
||||
|
||||
# 3.获取占比大于等于0.01%的谓词
|
||||
valid_predicates = set()
|
||||
for predicate, stats in self.predicate_vocab.items():
|
||||
if isinstance(stats, dict) and 'percentage' in stats:
|
||||
if stats['percentage'] >= 0.01:
|
||||
valid_predicates.add(predicate)
|
||||
else:
|
||||
# 如果不是统计格式,假设是有效谓词
|
||||
valid_predicates.add(predicate)
|
||||
|
||||
print(f"📊 占比≥0.01%的谓词: {len(valid_predicates)}个")
|
||||
|
||||
# 4.过滤数据:去除包含低频谓词的数据(单进程处理)
|
||||
original_count = len(data)
|
||||
filtered_data = []
|
||||
|
||||
print("🚀 开始过滤低频谓词数据...")
|
||||
for sample in tqdm(data, desc="过滤低频谓词"):
|
||||
result = process_sample_filter((sample, valid_predicates))
|
||||
if result is not None:
|
||||
filtered_data.append(result)
|
||||
|
||||
data = filtered_data
|
||||
print(f"✅ 过滤完成: 去除前{original_count}条,去除后{len(data)}条")
|
||||
|
||||
# 5. 去除self.predicate_vocab中占比小于0.01%的谓词,并创建谓词到序号的映射
|
||||
print("🔍 更新谓词词表并创建序号映射...")
|
||||
original_vocab_size = len(self.predicate_vocab)
|
||||
filtered_predicate_vocab = {}
|
||||
|
||||
for predicate, stats in self.predicate_vocab.items():
|
||||
if isinstance(stats, dict) and 'percentage' in stats:
|
||||
if stats['percentage'] >= 0.01:
|
||||
filtered_predicate_vocab[predicate] = stats
|
||||
else:
|
||||
# 如果不是统计格式,保留
|
||||
filtered_predicate_vocab[predicate] = stats
|
||||
|
||||
# 创建谓词到序号的映射字典
|
||||
self.predicate_to_id = {predicate: idx for idx, predicate in enumerate(filtered_predicate_vocab.keys())}
|
||||
self.predicate_vocab = filtered_predicate_vocab
|
||||
print(f"✅ 谓词词表更新: 去除前{original_vocab_size}个,去除后{len(self.predicate_vocab)}个")
|
||||
print(f"✅ 谓词映射创建: {len(self.predicate_to_id)}个谓词对应序号")
|
||||
|
||||
# 6. 数据验证和筛选(只保留一个target),优先选择占比小的谓词以平衡数据(单进程处理)
|
||||
print("🔍 验证数据格式并选择单个target(平衡数据)...")
|
||||
valid_samples = []
|
||||
|
||||
for i, sample in enumerate(tqdm(data, desc="验证数据格式")):
|
||||
if not isinstance(sample, dict) or 'text' not in sample:
|
||||
continue
|
||||
|
||||
targets = sample.get('target', [])
|
||||
if not isinstance(targets, list) or len(targets) == 0:
|
||||
# 如果没有有效的target,创建一个默认的
|
||||
selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
|
||||
else:
|
||||
# 验证并选择第一个有效的target
|
||||
selected_target = None
|
||||
for triple in targets:
|
||||
if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
|
||||
selected_target = triple
|
||||
break
|
||||
|
||||
# 如果没有找到有效的target,使用默认值
|
||||
if selected_target is None:
|
||||
selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
|
||||
|
||||
valid_samples.append({
|
||||
'text': sample['text'],
|
||||
'target': selected_target # 只保留一个target
|
||||
})
|
||||
print("🚀 开始验证数据格式...")
|
||||
for sample in tqdm(data, desc="验证数据格式"):
|
||||
result = process_sample_validation((sample, self.predicate_vocab))
|
||||
if result is not None:
|
||||
valid_samples.append(result)
|
||||
|
||||
print(f"✅ 有效样本数: {len(valid_samples)}")
|
||||
|
||||
# 3. 分批tokenize目标句子
|
||||
print("🔤 分批tokenize目标句子...")
|
||||
# 7.拆分训练集合与测试集合
|
||||
import random
|
||||
random.seed(42)
|
||||
val_samples = random.sample(valid_samples, min(1000, len(valid_samples)))
|
||||
train_samples = [sample for sample in valid_samples if sample not in val_samples]
|
||||
print(f"✅ 训练集大小: {len(train_samples)}")
|
||||
print(f"✅ 测试集大小: {len(val_samples)}")
|
||||
|
||||
processed_samples = []
|
||||
batch_size = 1000 # 每批处理1000个句子,避免内存爆炸
|
||||
# 8. 保存到缓存文件
|
||||
print("💾 保存处理结果到缓存文件...")
|
||||
with open(cache_files['predicate_vocab'], 'w', encoding='utf-8') as f:
|
||||
json.dump(self.predicate_vocab, f, ensure_ascii=False, indent=2)
|
||||
|
||||
for i in tqdm(range(0, len(valid_samples), batch_size), desc="分批tokenize目标句子"):
|
||||
# 获取当前批次
|
||||
batch_samples = valid_samples[i:i + batch_size]
|
||||
with open(cache_files['predicate_to_id'], 'w', encoding='utf-8') as f:
|
||||
json.dump(self.predicate_to_id, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 提取当前批次的目标句子
|
||||
batch_target_sentences = [self._triple_to_sentence(sample['target']) for sample in batch_samples]
|
||||
with open(cache_files['train_samples'], 'w', encoding='utf-8') as f:
|
||||
json.dump(train_samples, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 批量tokenize当前批次
|
||||
batch_encodings = self.tokenizer(
|
||||
batch_target_sentences,
|
||||
max_length=128, # 目标句子通常较短
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
return_tensors='pt'
|
||||
)
|
||||
with open(cache_files['val_samples'], 'w', encoding='utf-8') as f:
|
||||
json.dump(val_samples, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 构建当前批次的样本数据
|
||||
for j, sample in enumerate(batch_samples):
|
||||
processed_samples.append({
|
||||
'text': sample['text'], # 保持原始文本,不进行tokenize
|
||||
'target_input_ids': batch_encodings.input_ids[j],
|
||||
'target_attention_mask': batch_encodings.attention_mask[j],
|
||||
'target_sentence': batch_target_sentences[j], # 保留原始句子用于调试
|
||||
})
|
||||
print("✅ 缓存文件保存完成")
|
||||
|
||||
print(f"🎉 数据预处理完成! 共处理 {len(processed_samples)} 个样本")
|
||||
return processed_samples
|
||||
return train_samples, val_samples
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
@ -302,10 +466,10 @@ class TriplePretrainDataset(Dataset):
|
||||
return f"{triple['subject']} {triple['predicate']} {triple['object']}"
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""返回数据,输入文本在运行时tokenize,目标已预tokenize"""
|
||||
"""返回数据,用于谓词分类任务"""
|
||||
sample = self.samples[index]
|
||||
|
||||
# 在运行时tokenize输入文本(用于语言建模)
|
||||
# 在运行时tokenize输入文本
|
||||
input_text = f"{self.tokenizer.bos_token}{sample['text']}{self.tokenizer.eos_token}"
|
||||
encoding = self.tokenizer(
|
||||
input_text,
|
||||
@ -317,19 +481,18 @@ class TriplePretrainDataset(Dataset):
|
||||
input_ids = encoding.input_ids.squeeze()
|
||||
loss_mask = (input_ids != self.tokenizer.pad_token_id)
|
||||
|
||||
# 获取谓词分类标签
|
||||
target_predicate = sample['target']['predicate']
|
||||
predicate_label = self.predicate_to_id.get(target_predicate) # 默认为0如果找不到
|
||||
|
||||
# 构建训练数据
|
||||
X = input_ids[:-1]
|
||||
Y = input_ids[1:]
|
||||
loss_mask = loss_mask[1:]
|
||||
|
||||
return {
|
||||
'input_ids': X,
|
||||
'labels': Y,
|
||||
'loss_mask': loss_mask,
|
||||
'target_input_ids': sample['target_input_ids'], # 已经是tensor
|
||||
'target_attention_mask': sample['target_attention_mask'], # 已经是tensor
|
||||
'target_sentence': sample['target_sentence'], # 字符串,用于调试
|
||||
'original_text': sample['text']
|
||||
'labels': torch.tensor(predicate_label, dtype=torch.long), # 谓词分类标签
|
||||
'loss_mask': loss_mask
|
||||
}
|
||||
|
||||
|
||||
|
@ -489,8 +489,8 @@ class TripleExtractionHead(nn.Module):
|
||||
self.self_attn_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
||||
|
||||
# 交叉注意力机制(用于主语和宾语提取)
|
||||
self.cross_attention_subject = CrossAttention(config)
|
||||
self.cross_attention_object = CrossAttention(config)
|
||||
# self.cross_attention_subject = CrossAttention(config)
|
||||
# self.cross_attention_object = CrossAttention(config)
|
||||
|
||||
# 归一化层
|
||||
self.subject_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
||||
@ -498,13 +498,13 @@ class TripleExtractionHead(nn.Module):
|
||||
|
||||
# Feed Forward 网络
|
||||
self.predicate_ff = FeedForward(config)
|
||||
self.subject_ff = FeedForward(config)
|
||||
self.object_ff = FeedForward(config)
|
||||
# self.subject_ff = FeedForward(config)
|
||||
# self.object_ff = FeedForward(config)
|
||||
|
||||
# 输出投影层 - 修改为支持序列预测
|
||||
self.predicate_output = nn.Linear(config.dim, self.max_predicate_len *config.dim, bias=False)
|
||||
self.subject_output = nn.Linear(config.dim, self.max_subject_len * config.dim, bias=False)
|
||||
self.object_output = nn.Linear(config.dim, self.max_object_len * config.dim, bias=False)
|
||||
self.predicate_output = nn.Linear(config.dim, 264, bias=False)
|
||||
# self.subject_output = nn.Linear(config.dim, self.max_subject_len * config.dim, bias=False)
|
||||
# self.object_output = nn.Linear(config.dim, self.max_object_len * config.dim, bias=False)
|
||||
|
||||
print(f"三元组提取任务头配置:")
|
||||
print(f"- 主语最大长度: {self.max_subject_len}")
|
||||
@ -530,30 +530,29 @@ class TripleExtractionHead(nn.Module):
|
||||
# 2. h1通过feed_forward得到谓语输出
|
||||
predicate_features = self.predicate_ff(h1)
|
||||
predicate_features = predicate_features.mean(dim=1)
|
||||
predicate_raw = self.predicate_output(predicate_features) # [batch_size, max_predicate_len * vocab_size]
|
||||
predicate_logits = predicate_raw.view(batch_size, self.max_predicate_len, -1)
|
||||
predicate_class = self.predicate_output(predicate_features) # [batch_size, max_predicate_len * vocab_size]
|
||||
|
||||
# 3. h1通过交叉注意力(k,v都是h)得到h2
|
||||
h2 = self.cross_attention_subject(h1, h) # query是h1,key和value都是h
|
||||
h2 = h1 + h2 # 残差连接
|
||||
# # 3. h1通过交叉注意力(k,v都是h)得到h2
|
||||
# h2 = self.cross_attention_subject(h1, h) # query是h1,key和value都是h
|
||||
# h2 = h1 + h2 # 残差连接
|
||||
|
||||
# 4. h2通过feed_forward得到主语输出
|
||||
subject_features = self.subject_ff(self.subject_norm(h2))
|
||||
subject_features = subject_features.mean(dim=1)
|
||||
subject_raw = self.subject_output(subject_features) # [batch_size, max_subject_len * vocab_size]
|
||||
subject_logits = subject_raw.view(batch_size, self.max_subject_len, -1)
|
||||
# # 4. h2通过feed_forward得到主语输出
|
||||
# subject_features = self.subject_ff(self.subject_norm(h2))
|
||||
# subject_features = subject_features.mean(dim=1)
|
||||
# subject_raw = self.subject_output(subject_features) # [batch_size, max_subject_len * vocab_size]
|
||||
# subject_logits = subject_raw.view(batch_size, self.max_subject_len, -1)
|
||||
|
||||
# 5. h2通过交叉注意力(k,v都是h)得到h3
|
||||
h3 = self.cross_attention_object(h2, h) # query是h2,key和value都是h
|
||||
h3 = h2 + h3 # 残差连接
|
||||
# # 5. h2通过交叉注意力(k,v都是h)得到h3
|
||||
# h3 = self.cross_attention_object(h2, h) # query是h2,key和value都是h
|
||||
# h3 = h2 + h3 # 残差连接
|
||||
|
||||
# 6. h3通过feed_forward得到宾语输出
|
||||
object_features = self.object_ff(self.object_norm(h3))
|
||||
object_features = object_features.mean(dim=1)
|
||||
object_raw = self.object_output(object_features) # [batch_size, max_object_len * vocab_size]
|
||||
object_logits = object_raw.view(batch_size, self.max_object_len, -1)
|
||||
# # 6. h3通过feed_forward得到宾语输出
|
||||
# object_features = self.object_ff(self.object_norm(h3))
|
||||
# object_features = object_features.mean(dim=1)
|
||||
# object_raw = self.object_output(object_features) # [batch_size, max_object_len * vocab_size]
|
||||
# object_logits = object_raw.view(batch_size, self.max_object_len, -1)
|
||||
|
||||
return predicate_logits, subject_logits, object_logits
|
||||
return predicate_class
|
||||
|
||||
|
||||
class MiniMindBlock(nn.Module):
|
||||
@ -656,18 +655,8 @@ class MiniMindLM(PreTrainedModel):
|
||||
)
|
||||
|
||||
# 应用三元组提取任务头
|
||||
predicate_logits, subject_logits, object_logits = self.triple_extraction_head(h, pos_cis)
|
||||
predicate_logits = predicate_logits.reshape(input_ids.size(0)*self.params.max_predicate_len, -1)
|
||||
subject_logits = subject_logits.reshape(input_ids.size(0)*self.params.max_subject_len, -1)
|
||||
object_logits = object_logits.reshape(input_ids.size(0)*self.params.max_object_len, -1)
|
||||
predicate_class = self.triple_extraction_head(h, pos_cis)
|
||||
|
||||
predicate_logits = self.output(predicate_logits)
|
||||
subject_logits = self.output(subject_logits)
|
||||
object_logits = self.output(object_logits)
|
||||
|
||||
predicate_logits = predicate_logits.reshape(input_ids.size(0), self.params.max_predicate_len, -1)
|
||||
subject_logits = subject_logits.reshape(input_ids.size(0), self.params.max_subject_len, -1)
|
||||
object_logits = object_logits.reshape(input_ids.size(0), self.params.max_object_len, -1)
|
||||
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.output(self.norm(h)[:, slice_indices, :])
|
||||
@ -682,9 +671,7 @@ class MiniMindLM(PreTrainedModel):
|
||||
|
||||
# 添加三元组提取结果
|
||||
# 注意:现在的维度是 [batch_size, seq_len, max_len, vocab_size]
|
||||
output.predicate_logits = predicate_logits
|
||||
output.subject_logits = subject_logits
|
||||
output.object_logits = object_logits
|
||||
output.predicate_class = predicate_class
|
||||
|
||||
return output
|
||||
|
||||
|
@ -1,225 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
JSON文件合并脚本
|
||||
读取多个JSON文件并合并为一个JSON文件
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Any, Union
|
||||
|
||||
# 需要合并的JSON文件列表
|
||||
JSON_FILES_TO_MERGE = [
|
||||
"output/trex_sentences_enhanced_checkpoint_360000.json"
|
||||
]
|
||||
for i in range(1, 1010):
|
||||
JSON_FILES_TO_MERGE.append(f"output/trex_sentences_enhanced_batch_{i}.json")
|
||||
|
||||
def load_json_file(file_path: str) -> Union[Dict, List, None]:
|
||||
"""加载JSON文件"""
|
||||
if not os.path.exists(file_path):
|
||||
print(f"警告: 文件 {file_path} 不存在")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
print(f"成功加载: {file_path}")
|
||||
return data
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"错误: 无法解析JSON文件 {file_path} - {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"错误: 读取文件 {file_path} 失败 - {e}")
|
||||
return None
|
||||
|
||||
def merge_json_data(data1: Union[Dict, List], data2: Union[Dict, List]) -> Union[Dict, List]:
|
||||
"""合并两个JSON数据结构"""
|
||||
|
||||
# 如果两个都是列表,直接合并
|
||||
if isinstance(data1, list) and isinstance(data2, list):
|
||||
print(f"合并两个列表: {len(data1)} + {len(data2)} = {len(data1) + len(data2)} 项")
|
||||
return data1 + data2
|
||||
|
||||
# 如果两个都是字典
|
||||
elif isinstance(data1, dict) and isinstance(data2, dict):
|
||||
print("合并两个字典结构")
|
||||
merged = data1.copy()
|
||||
|
||||
# 特殊处理:如果都有'sentences'字段且为列表,合并sentences
|
||||
if 'sentences' in data1 and 'sentences' in data2:
|
||||
if isinstance(data1['sentences'], list) and isinstance(data2['sentences'], list):
|
||||
print(f"合并sentences字段: {len(data1['sentences'])} + {len(data2['sentences'])} = {len(data1['sentences']) + len(data2['sentences'])} 项")
|
||||
merged['sentences'] = data1['sentences'] + data2['sentences']
|
||||
|
||||
# 更新metadata if exists
|
||||
if 'metadata' in merged:
|
||||
if isinstance(merged['metadata'], dict):
|
||||
merged['metadata']['total_sentences'] = len(merged['sentences'])
|
||||
merged['metadata']['merged_from'] = [os.path.basename(f) for f in JSON_FILES_TO_MERGE if os.path.exists(f)]
|
||||
|
||||
# 合并其他字段
|
||||
for key, value in data2.items():
|
||||
if key != 'sentences' and key not in merged:
|
||||
merged[key] = value
|
||||
|
||||
return merged
|
||||
|
||||
# 普通字典合并
|
||||
for key, value in data2.items():
|
||||
if key in merged:
|
||||
# 如果key重复且都是列表,合并列表
|
||||
if isinstance(merged[key], list) and isinstance(value, list):
|
||||
merged[key] = merged[key] + value
|
||||
# 如果key重复且都是字典,递归合并
|
||||
elif isinstance(merged[key], dict) and isinstance(value, dict):
|
||||
merged[key] = merge_json_data(merged[key], value)
|
||||
else:
|
||||
# 其他情况保留第二个文件的值
|
||||
merged[key] = value
|
||||
print(f"字段 '{key}' 被覆盖")
|
||||
else:
|
||||
merged[key] = value
|
||||
|
||||
return merged
|
||||
|
||||
# 类型不匹配的情况,创建一个包含两者的新结构
|
||||
else:
|
||||
print("数据类型不匹配,创建包含两者的新结构")
|
||||
return {
|
||||
"data_from_save.json": data1,
|
||||
"data_from_save2.json": data2,
|
||||
"merged_at": "test.py"
|
||||
}
|
||||
|
||||
def save_merged_json(data: Union[Dict, List], output_path: str):
|
||||
"""保存合并后的JSON数据"""
|
||||
try:
|
||||
# 确保输出目录存在
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"合并结果已保存到: {output_path}")
|
||||
|
||||
# 显示统计信息
|
||||
if isinstance(data, dict):
|
||||
if 'sentences' in data and isinstance(data['sentences'], list):
|
||||
print(f"总计句子数: {len(data['sentences'])}")
|
||||
print(f"总计字段数: {len(data)}")
|
||||
elif isinstance(data, list):
|
||||
print(f"总计列表项数: {len(data)}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"错误: 保存文件失败 - {e}")
|
||||
|
||||
def remove_duplicates_from_sentences(data: Union[Dict, List]) -> Union[Dict, List]:
|
||||
"""从合并结果中移除重复的句子(基于句子内容)"""
|
||||
if isinstance(data, dict) and 'sentences' in data:
|
||||
if isinstance(data['sentences'], list):
|
||||
original_count = len(data['sentences'])
|
||||
seen_sentences = set()
|
||||
unique_sentences = []
|
||||
|
||||
for item in data['sentences']:
|
||||
if isinstance(item, dict):
|
||||
# 如果是字典,使用sentence字段或corrected_sentence字段作为唯一标识
|
||||
sentence_key = item.get('sentence') or item.get('corrected_sentence') or item.get('original_sentence')
|
||||
elif isinstance(item, str):
|
||||
sentence_key = item
|
||||
else:
|
||||
sentence_key = str(item)
|
||||
|
||||
if sentence_key and sentence_key not in seen_sentences:
|
||||
seen_sentences.add(sentence_key)
|
||||
unique_sentences.append(item)
|
||||
|
||||
data['sentences'] = unique_sentences
|
||||
|
||||
# 更新metadata
|
||||
if 'metadata' in data and isinstance(data['metadata'], dict):
|
||||
data['metadata']['total_sentences'] = len(unique_sentences)
|
||||
data['metadata']['duplicates_removed'] = original_count - len(unique_sentences)
|
||||
|
||||
print(f"去重完成: {original_count} -> {len(unique_sentences)} (移除了 {original_count - len(unique_sentences)} 个重复项)")
|
||||
|
||||
return data
|
||||
|
||||
def merge_multiple_json_data(data_list: List[Union[Dict, List]]) -> Union[Dict, List]:
|
||||
"""合并多个JSON数据结构"""
|
||||
if not data_list:
|
||||
return {}
|
||||
|
||||
if len(data_list) == 1:
|
||||
return data_list[0]
|
||||
|
||||
print(f"准备合并 {len(data_list)} 个JSON数据结构")
|
||||
|
||||
# 从第一个数据开始,逐步合并其他数据
|
||||
merged_data = data_list[0]
|
||||
|
||||
for i, data in enumerate(data_list[1:], 1):
|
||||
print(f"正在合并第 {i+1} 个数据结构...")
|
||||
merged_data = merge_json_data(merged_data, data)
|
||||
|
||||
return merged_data
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("=== JSON文件合并脚本 ===")
|
||||
|
||||
# 输出路径
|
||||
output_path = "output/merged.json"
|
||||
|
||||
print(f"准备合并以下文件:")
|
||||
for i, file_path in enumerate(JSON_FILES_TO_MERGE, 1):
|
||||
print(f" {i}. {file_path}")
|
||||
print(f"输出文件: {output_path}")
|
||||
print()
|
||||
|
||||
# 加载所有文件
|
||||
loaded_data = []
|
||||
successfully_loaded = []
|
||||
|
||||
for file_path in JSON_FILES_TO_MERGE:
|
||||
data = load_json_file(file_path)
|
||||
if data is not None:
|
||||
loaded_data.append(data)
|
||||
successfully_loaded.append(file_path)
|
||||
|
||||
# 检查是否至少有一个文件加载成功
|
||||
if not loaded_data:
|
||||
print("错误: 没有文件能够成功加载,退出")
|
||||
return
|
||||
|
||||
print(f"成功加载了 {len(loaded_data)} 个文件:")
|
||||
for file_path in successfully_loaded:
|
||||
print(f" ✓ {file_path}")
|
||||
|
||||
if len(loaded_data) < len(JSON_FILES_TO_MERGE):
|
||||
failed_count = len(JSON_FILES_TO_MERGE) - len(loaded_data)
|
||||
print(f"警告: {failed_count} 个文件加载失败")
|
||||
print()
|
||||
|
||||
# 合并所有数据
|
||||
if len(loaded_data) == 1:
|
||||
print("只有一个文件可用,直接使用...")
|
||||
merged_data = loaded_data[0]
|
||||
else:
|
||||
print("开始合并所有文件...")
|
||||
merged_data = merge_multiple_json_data(loaded_data)
|
||||
|
||||
# 去重处理
|
||||
print("\n检查并去除重复项...")
|
||||
merged_data = remove_duplicates_from_sentences(merged_data)
|
||||
|
||||
# 保存合并结果
|
||||
print("\n保存合并结果...")
|
||||
save_merged_json(merged_data, output_path)
|
||||
|
||||
print("\n=== 合并完成 ===")
|
||||
print(f"合并了 {len(successfully_loaded)} 个文件的数据")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,61 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
小规模测试预处理脚本
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加路径
|
||||
sys.path.append('/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/preprocessing')
|
||||
|
||||
# 导入主模块
|
||||
from preprocess_pretrain import *
|
||||
|
||||
# 修改配置为小规模测试
|
||||
DATASET_CONFIG["wikipedia"]["max_samples"] = 100
|
||||
DATASET_CONFIG["gutenberg"]["max_samples"] = 50
|
||||
DATASET_CONFIG["openwebtext"]["max_samples"] = 20
|
||||
|
||||
DATASET_CONFIG_EXTRA["wikipedia"]["max_samples"] = 50
|
||||
DATASET_CONFIG_EXTRA["gutenberg"]["max_samples"] = 30
|
||||
DATASET_CONFIG_EXTRA["openwebtext"]["max_samples"] = 15
|
||||
|
||||
# 修改输出路径
|
||||
OUTPUT_FILE = "/tmp/test_main.jsonl"
|
||||
OUTPUT_FILE_EXTRA = "/tmp/test_extra.jsonl"
|
||||
|
||||
def test_small_scale():
|
||||
"""小规模测试"""
|
||||
print("Starting small scale test...")
|
||||
|
||||
# 设置随机种子
|
||||
random.seed(42)
|
||||
|
||||
try:
|
||||
# 初始化tokenizer
|
||||
init_tokenizer()
|
||||
|
||||
# 开始合并数据集
|
||||
merge_datasets()
|
||||
|
||||
# 检查输出文件
|
||||
if os.path.exists(OUTPUT_FILE):
|
||||
with open(OUTPUT_FILE, 'r') as f:
|
||||
main_lines = len(f.readlines())
|
||||
print(f"Main file created: {main_lines} lines")
|
||||
|
||||
if os.path.exists(OUTPUT_FILE_EXTRA):
|
||||
with open(OUTPUT_FILE_EXTRA, 'r') as f:
|
||||
extra_lines = len(f.readlines())
|
||||
print(f"Extra file created: {extra_lines} lines")
|
||||
|
||||
print("Small scale test completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_small_scale()
|
File diff suppressed because it is too large
Load Diff
@ -127,6 +127,7 @@ dependencies = [
|
||||
"regex==2024.11.6",
|
||||
"requests==2.32.3",
|
||||
"rich==13.7.1",
|
||||
"rouge-score>=0.1.2",
|
||||
"rpds-py==0.24.0",
|
||||
"s3transfer==0.13.0",
|
||||
"safetensors==0.5.3",
|
||||
|
@ -397,6 +397,66 @@ def log_memory_status(step, accelerator, stage="", detailed=False):
|
||||
|
||||
Logger(log_msg, accelerator)
|
||||
|
||||
# 验证函数
|
||||
def validate_model(model, val_loader, accelerator, ctx, args):
|
||||
"""
|
||||
验证模型性能
|
||||
Args:
|
||||
model: 模型
|
||||
val_loader: 验证集数据加载器
|
||||
accelerator: accelerator对象
|
||||
ctx: 上下文管理器
|
||||
args: 参数
|
||||
Returns:
|
||||
dict: 包含平均损失和准确率的字典
|
||||
"""
|
||||
model.eval()
|
||||
|
||||
total_loss = 0.0
|
||||
correct_predictions = 0
|
||||
total_predictions = 0
|
||||
num_batches = 0
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_data in val_loader:
|
||||
try:
|
||||
# 数据准备
|
||||
X = batch_data['input_ids'].to(accelerator.device)
|
||||
Y = batch_data['labels']
|
||||
|
||||
# 前向传播
|
||||
with ctx:
|
||||
res = model(X, step=0) # 验证时step设为0
|
||||
loss = criterion(res.predicate_class.cpu(), Y.cpu())
|
||||
|
||||
# 计算准确率
|
||||
predicted_classes = torch.argmax(res.predicate_class, dim=1)
|
||||
predicted_classes = predicted_classes.to(Y.device)
|
||||
correct_predictions += (predicted_classes == Y).sum().item()
|
||||
total_predictions += Y.size(0)
|
||||
|
||||
# 累计损失
|
||||
total_loss += loss.item()
|
||||
num_batches += 1
|
||||
|
||||
except Exception as e:
|
||||
Logger(f"验证时出错: {e}", accelerator)
|
||||
continue
|
||||
|
||||
# 计算平均值
|
||||
avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
|
||||
accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
|
||||
|
||||
model.train() # 重新设置为训练模式
|
||||
|
||||
return {
|
||||
'avg_loss': avg_loss,
|
||||
'accuracy': accuracy,
|
||||
'total_samples': total_predictions
|
||||
}
|
||||
|
||||
# 日志记录函数
|
||||
def Logger(msg, accelerator=None):
|
||||
# 如果没有提供accelerator,则只在主进程打印
|
||||
@ -515,7 +575,7 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
||||
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
return model, tokenizer
|
||||
|
||||
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer):
|
||||
def train_epoch(epoch, accelerator, model, train_loader,val_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer):
|
||||
# 三元组提取训练模式:不需要传统的交叉熵损失函数
|
||||
epoch_start_time = time.time()
|
||||
total_steps_in_epoch = len(train_loader)
|
||||
@ -563,9 +623,9 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
||||
X = batch_data['input_ids']
|
||||
Y = batch_data['labels']
|
||||
loss_mask = batch_data['loss_mask']
|
||||
target_input_ids = batch_data['target_input_ids']
|
||||
target_attention_mask = batch_data['target_attention_mask']
|
||||
target_sentences = batch_data['target_sentences'] # 用于调试输出
|
||||
# target_input_ids = batch_data['target_input_ids']
|
||||
# target_attention_mask = batch_data['target_attention_mask']
|
||||
# target_sentences = batch_data['target_sentences'] # 用于调试输出
|
||||
|
||||
# === 2. 学习率更新 ===
|
||||
if scheduler is not None:
|
||||
@ -590,36 +650,34 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
||||
|
||||
# === 4. 损失计算 ===
|
||||
# 三元组提取模式:只使用ROUGE Loss进行三元组损失计算
|
||||
Logger("三元组提取训练模式", accelerator) if step == 0 else None
|
||||
# Logger("三元组提取训练模式", accelerator) if step == 0 else None
|
||||
|
||||
# 确保有三元组输出
|
||||
if not (hasattr(res, 'predicate_logits') and hasattr(res, 'subject_logits') and hasattr(res, 'object_logits')):
|
||||
raise ValueError("模型没有输出三元组logits,请检查模型配置")
|
||||
# # 确保有三元组输出
|
||||
# if not (hasattr(res, 'predicate_logits') and hasattr(res, 'subject_logits') and hasattr(res, 'object_logits')):
|
||||
# raise ValueError("模型没有输出三元组logits,请检查模型配置")
|
||||
|
||||
# 确保有目标数据
|
||||
if target_input_ids is None:
|
||||
raise ValueError("没有三元组目标数据,请检查数据格式")
|
||||
# # 确保有目标数据
|
||||
# if target_input_ids is None:
|
||||
# raise ValueError("没有三元组目标数据,请检查数据格式")
|
||||
|
||||
# 计算三元组损失
|
||||
# 计算分类损失
|
||||
try:
|
||||
Logger("使用预tokenized三元组目标数据", accelerator) if step == 0 else None
|
||||
Logger("使用分类交叉熵损失", accelerator) if step == 0 else None
|
||||
|
||||
# 计时GPU损失计算
|
||||
if args.profile and accelerator.is_main_process and loss_start is not None:
|
||||
loss_start.record()
|
||||
|
||||
# 计算优化后的嵌入余弦相似度损失
|
||||
loss = compute_triple_rouge_loss_optimized(
|
||||
res.subject_logits, res.predicate_logits, res.object_logits,
|
||||
target_input_ids, target_attention_mask, model.tok_embeddings, temperature=args.temperature
|
||||
)
|
||||
# 计算交叉熵损失
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
loss = criterion(res.predicate_class, Y)
|
||||
|
||||
# 计时GPU损失计算结束
|
||||
if args.profile and accelerator.is_main_process and loss_end is not None:
|
||||
loss_end.record()
|
||||
|
||||
except Exception as e:
|
||||
Logger(f"Error: ROUGE loss computation failed: {e}", accelerator)
|
||||
Logger(f"Error: 分类损失计算失败: {e}", accelerator)
|
||||
import traceback
|
||||
Logger(f"Traceback: {traceback.format_exc()}", accelerator)
|
||||
loss = res.logits.sum() * 0.0 + 1.0
|
||||
@ -683,13 +741,13 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
||||
f"GPU利用率: {gpu_time_total/iter_time*100:.1f}%", accelerator)
|
||||
Logger("=" * 50, accelerator)
|
||||
|
||||
Logger("=== 三元组预测示例 ===", accelerator)
|
||||
predict_sentences = triple_to_sentence(res.subject_logits, res.predicate_logits, res.object_logits,tokenizer)
|
||||
# 显示前2个样本的目标句子
|
||||
for i, target_sentence in enumerate(target_sentences[:2]):
|
||||
Logger(f"样本{i+1}目标: {target_sentence}", accelerator)
|
||||
Logger(f"样本{i+1}预测: {predict_sentences[i]}", accelerator)
|
||||
Logger("==================", accelerator)
|
||||
# Logger("=== 三元组预测示例 ===", accelerator)
|
||||
# predict_sentences = triple_to_sentence(res.subject_logits, res.predicate_logits, res.object_logits,tokenizer)
|
||||
# # 显示前2个样本的目标句子
|
||||
# for i, target_sentence in enumerate(target_sentences[:2]):
|
||||
# Logger(f"样本{i+1}目标: {target_sentence}", accelerator)
|
||||
# Logger(f"样本{i+1}预测: {predict_sentences[i]}", accelerator)
|
||||
Logger("=======val dataset=========", accelerator)
|
||||
|
||||
# 重置GPU事件
|
||||
forward_start = torch.cuda.Event(enable_timing=True)
|
||||
@ -734,11 +792,20 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
||||
|
||||
# SwanLab日志记录
|
||||
if args.use_swanlab and accelerator.is_main_process and swanlab_run:
|
||||
Logger("=======val dataset=========", accelerator)
|
||||
|
||||
# 验证集评估
|
||||
val_results = validate_model(model, val_loader, accelerator, ctx, args)
|
||||
Logger(f"验证集结果 - 平均损失: {val_results['avg_loss']:.6f}, 准确率: {val_results['accuracy']:.4f}, 样本数: {val_results['total_samples']}", accelerator)
|
||||
|
||||
log_dict = {
|
||||
"epoch": epoch + 1,
|
||||
"step": step + 1,
|
||||
"total_steps_in_epoch": total_steps_in_epoch,
|
||||
"triple_embedding_cosine_loss": loss.item() * args.accumulation_steps,
|
||||
"train_loss": loss.item() * args.accumulation_steps,
|
||||
"val_loss": val_results['avg_loss'],
|
||||
"val_accuracy": val_results['accuracy'],
|
||||
"val_samples": val_results['total_samples'],
|
||||
"lr": current_lr,
|
||||
"tokens_per_sec": tokens_per_sec,
|
||||
"epoch_time_left_seconds": epoch_remaining_time,
|
||||
@ -776,7 +843,7 @@ def main():
|
||||
parser.add_argument("--out_dir", type=str, default="out")
|
||||
parser.add_argument("--epochs", type=int, default=4)
|
||||
parser.add_argument("--embedding_epoch", type=int, default=2, help="embedding训练的epoch数")
|
||||
parser.add_argument("--batch_size", type=int, default=192)
|
||||
parser.add_argument("--batch_size", type=int, default=256)
|
||||
parser.add_argument("--learning_rate", type=float, default=2e-4)
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_swanlab", default=True, action="store_true") # 替换wandb参数
|
||||
@ -793,6 +860,7 @@ def main():
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能,使用固定值1e-4替代")
|
||||
parser.add_argument("--data_path", type=str, default="./dataset/processed_trex_data.json")
|
||||
parser.add_argument("--predicate_vocab_path", type=str, default="./dataset/predicate_stats.json", help="Path to predicate vocabulary/statistics file")
|
||||
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
|
||||
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
|
||||
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
|
||||
@ -932,7 +1000,8 @@ def main():
|
||||
# 创建数据集和数据加载器(专用于三元组提取训练)
|
||||
#########################################################
|
||||
Logger("三元组提取训练:使用 TriplePretrainDataset", accelerator)
|
||||
train_ds = TriplePretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
|
||||
train_ds = TriplePretrainDataset(data_path=args.data_path, predicate_vocab_path=args.predicate_vocab_path, tokenizer=tokenizer, max_length=lm_config.max_seq_len)
|
||||
val_ds = TriplePretrainDataset(data_path=args.data_path,samples=train_ds.get_val_samples(), predicate_vocab_path=args.predicate_vocab_path, tokenizer=tokenizer, max_length=lm_config.max_seq_len)
|
||||
|
||||
# 创建自定义collate_fn来处理优化后的数据格式
|
||||
def triple_collate_fn(batch):
|
||||
@ -940,17 +1009,17 @@ def main():
|
||||
input_ids = torch.stack([item['input_ids'] for item in batch])
|
||||
labels = torch.stack([item['labels'] for item in batch])
|
||||
loss_mask = torch.stack([item['loss_mask'] for item in batch])
|
||||
target_input_ids = torch.stack([item['target_input_ids'] for item in batch])
|
||||
target_attention_mask = torch.stack([item['target_attention_mask'] for item in batch])
|
||||
target_sentences = [item['target_sentence'] for item in batch] # 用于调试
|
||||
# target_input_ids = torch.stack([item['target_input_ids'] for item in batch])
|
||||
# target_attention_mask = torch.stack([item['target_attention_mask'] for item in batch])
|
||||
# target_sentences = [item['target_sentence'] for item in batch] # 用于调试
|
||||
|
||||
return {
|
||||
'input_ids': input_ids,
|
||||
'labels': labels,
|
||||
'loss_mask': loss_mask,
|
||||
'target_input_ids': target_input_ids,
|
||||
'target_attention_mask': target_attention_mask,
|
||||
'target_sentences': target_sentences
|
||||
# 'target_input_ids': target_input_ids,
|
||||
# 'target_attention_mask': target_attention_mask,
|
||||
# 'target_sentences': target_sentences
|
||||
}
|
||||
|
||||
train_loader = DataLoader(
|
||||
@ -963,6 +1032,15 @@ def main():
|
||||
# persistent_workers 和 prefetch_factor 在 num_workers=0 时自动禁用
|
||||
collate_fn=triple_collate_fn
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_ds,
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=False,
|
||||
drop_last=True,
|
||||
shuffle=False,
|
||||
num_workers=0,
|
||||
collate_fn=triple_collate_fn
|
||||
)
|
||||
|
||||
#########################################################
|
||||
# 创建优化器
|
||||
@ -993,7 +1071,7 @@ def main():
|
||||
overall_start_time = time.time() # Record overall start time
|
||||
for epoch in range(args.epochs):
|
||||
Logger(f"开始第{epoch+1}轮训练", accelerator)
|
||||
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer) # Pass tokenizer
|
||||
train_epoch(epoch, accelerator, model, train_loader,val_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run, tokenizer) # Pass tokenizer
|
||||
|
||||
# 每个epoch结束后进行内存清理
|
||||
Logger(f"第{epoch+1}轮训练完成,进行内存清理", accelerator)
|
||||
|
Loading…
x
Reference in New Issue
Block a user