Initialized the database
This commit is contained in:
parent c09cd63794
commit c96a9c35d5

.gitignore (vendored, +2)
@@ -3,3 +3,5 @@
 /out
 wandb/
 **/*.log
+models/sentence_transformers/
+models/sentence_transformers_cache/
preprocessing/README_trex_processor.md (new file, +97)
@@ -0,0 +1,97 @@
# TREx Dataset Processing Tool: Usage Guide

This tool processes the TREx dataset in two steps:

1. **Sentence extraction**: extract triples from the TREx dataset and convert them into natural-language sentences (a minimal sketch of this conversion is shown below).
2. **LLM processing**: use the ollama qwen3:4b model to correct each sentence and assign it an importance score.
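The actual templates live in `trex_to_sentences_simple.py`; the snippet below is only an illustrative sketch of the step-1 idea, with a hypothetical `triple_to_sentence` helper and a made-up example triple.

```python
# Hypothetical sketch (not the script's actual code): a TREx triple
# (subject, predicate, object) is joined into a plain sentence that the
# LLM can later correct and score.

def triple_to_sentence(subject: str, predicate: str, obj: str) -> str:
    """Join a knowledge triple into a naive natural-language sentence."""
    return f"{subject} {predicate} {obj}."

if __name__ == "__main__":
    # e.g. ("Albert Einstein", "was born in", "Ulm") -> "Albert Einstein was born in Ulm."
    print(triple_to_sentence("Albert Einstein", "was born in", "Ulm"))
```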
## Installing dependencies

```bash
pip install agno asyncio pydantic
```

Make sure ollama is installed and running, and pull the qwen3:4b model:

```bash
ollama pull qwen3:4b
```
## Usage

### 1. Full pipeline (both steps in sequence)

```bash
python trex_to_sentences_simple.py --step all --input_dir dataset/TREx --max_files 2
```

### 2. Running the steps separately

#### Step 1: sentence extraction only

```bash
python trex_to_sentences_simple.py --step extract --input_dir dataset/TREx --sentences_json my_sentences.json --max_files 2
```

#### Step 2: LLM processing only

```bash
python trex_to_sentences_simple.py --step llm --sentences_json my_sentences.json --output_file final_output.txt
```
## Main parameters

- `--step`: which step to run
  - `extract`: sentence extraction only
  - `llm`: LLM processing only
  - `all`: full pipeline (default)

- `--input_dir`: TREx dataset directory (default: `dataset/TREx`)
- `--sentences_json`: JSON file of extracted sentences (default: `extracted_sentences.json`)
- `--output_file`: final output file (default: `trex_sentences_enhanced.txt`)
- `--max_files`: maximum number of files to process (useful for testing)
- `--no_llm`: disable LLM processing

A sketch of how these options might be wired up with argparse is shown below.
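The parameter names and defaults above are taken from this README; the argparse sketch below is an assumed illustration of the CLI, not the actual contents of `trex_to_sentences_simple.py`.

```python
# Hypothetical sketch of the CLI described above; the real script may differ.
import argparse

def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="TREx dataset processing tool")
    parser.add_argument("--step", choices=["extract", "llm", "all"], default="all",
                        help="which step to run")
    parser.add_argument("--input_dir", default="dataset/TREx",
                        help="TREx dataset directory")
    parser.add_argument("--sentences_json", default="extracted_sentences.json",
                        help="JSON file of extracted sentences")
    parser.add_argument("--output_file", default="trex_sentences_enhanced.txt",
                        help="final output file")
    parser.add_argument("--max_files", type=int, default=None,
                        help="maximum number of files to process (for testing)")
    parser.add_argument("--no_llm", action="store_true",
                        help="disable LLM processing")
    return parser

if __name__ == "__main__":
    args = build_parser().parse_args()
    print(args)
```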
## Output files

**Note: all output files are written to the `./output/` directory automatically.**

### Step 1 output
- `output/extracted_sentences.json`: the raw extracted sentences, with metadata

### Step 2 output
- `output/{output_file}.txt`: the corrected sentences as plain text
- `output/{output_file}.json`: the full processing results (original sentence, corrected sentence, and score)
- `output/{output_file}_sorted_by_importance.txt`: sentences sorted by importance score (a sketch of reading the JSON back is shown right below this list)
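For reference, a minimal sketch of loading the step-2 JSON results. The file name and field names here are assumptions; the training-side loader added in this commit reads `corrected_sentence` and `importance_score` keys, so the same keys are assumed for the output.

```python
# Minimal sketch: load the step-2 results and list the highest-scored sentences.
# File location and field names are assumed, not taken verbatim from the script.
import json

with open("output/trex_sentences_enhanced.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# Assume either a {"sentences": [...]} layout (as the training-side loader expects)
# or a plain list of entries.
entries = data["sentences"] if isinstance(data, dict) else data
top = sorted(entries, key=lambda e: e.get("importance_score", 0.0), reverse=True)[:10]
for entry in top:
    print(f"{entry.get('importance_score', 0.0):.2f}  {entry.get('corrected_sentence', '')}")
```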
### Checkpoint files
- `output/{output_file}_checkpoint_{count}.json`: checkpoints saved automatically every 2000 sentences
## Checkpoint recovery

- Step 2 automatically detects existing checkpoint files (in the `output/` directory).
- Only sentences that have not been processed yet are sent to the LLM, so no work is repeated.
- If every sentence has already been processed, the final output files are generated directly.

A sketch of this resume logic is shown below.
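The resume behaviour described above could look roughly like the sketch below; the checkpoint layout (a JSON list of processed entries keyed by the original sentence) is an assumption, not the script's actual format.

```python
# Hypothetical sketch of checkpoint-based resumption; the real checkpoint layout may differ.
import glob
import json

def load_processed(output_dir: str, output_stem: str) -> dict:
    """Merge all checkpoint files and return {original_sentence: result_entry}."""
    processed = {}
    for path in sorted(glob.glob(f"{output_dir}/{output_stem}_checkpoint_*.json")):
        with open(path, "r", encoding="utf-8") as f:
            for entry in json.load(f):
                processed[entry.get("original_sentence", "")] = entry
    return processed

def remaining_sentences(all_sentences: list[str], processed: dict) -> list[str]:
    """Return only the sentences that still need LLM processing."""
    return [s for s in all_sentences if s not in processed]
```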
## Example workflow

```bash
# 1. Extract the sentences first (this step is fast)
python trex_to_sentences_simple.py --step extract --max_files 5

# 2. Run the LLM processing afterwards (slow, but resumable)
python trex_to_sentences_simple.py --step llm

# If the run is interrupted, running step 2 again resumes from the latest checkpoint
python trex_to_sentences_simple.py --step llm
```
## Performance characteristics

- **Concurrency**: up to 54 concurrent LLM requests
- **Checkpointing**: automatic saves every 2000 sentences, so interrupted runs can resume
- **Progress reporting**: detailed progress and time estimates
- **Error handling**: if an LLM request fails, the original sentence and a default score are used

A sketch of this concurrency and fallback pattern is shown below.
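The concurrency and fallback behaviour listed above can be captured with an asyncio semaphore; the sketch below is an assumed illustration of the pattern (hypothetical `call_llm` coroutine and default score), not the script's actual code.

```python
# Hypothetical sketch of bounded concurrency with a fallback, mirroring the
# behaviour described above: at most 54 requests in flight, and a failed
# request falls back to the original sentence with a default score.
import asyncio

MAX_CONCURRENCY = 54
DEFAULT_SCORE = 0.5  # assumed default, not taken from the script

async def call_llm(sentence: str) -> dict:
    """Placeholder for the real LLM call (ollama qwen3:4b via agno)."""
    raise NotImplementedError

async def process_sentence(sentence: str, sem: asyncio.Semaphore) -> dict:
    async with sem:
        try:
            return await call_llm(sentence)
        except Exception:
            # Fallback: keep the original sentence with a default importance score.
            return {"corrected_sentence": sentence, "importance_score": DEFAULT_SCORE}

async def process_all(sentences: list[str]) -> list[dict]:
    sem = asyncio.Semaphore(MAX_CONCURRENCY)
    return await asyncio.gather(*(process_sentence(s, sem) for s in sentences))
```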
## Notes

1. Step 1 must be completed before running step 2 for the first time.
2. Checkpoint files take up extra disk space (each one contains all data processed so far).
3. LLM processing speed depends on model performance and network conditions.
4. Test on a small batch first with the `--max_files` parameter.
(File diff suppressed because it is too large.)
@@ -3,6 +3,7 @@ import os
 os.environ["WANDB_MODE"] = "offline"  # or use "dryrun"
 import platform
 import argparse
+from tqdm import tqdm
 import time
 import math
 import warnings
@@ -18,8 +19,10 @@ from accelerate.utils import set_seed
 from accelerate.utils import DeepSpeedPlugin
 from accelerate.utils import DistributedDataParallelKwargs
 from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
+import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
+
-from model.model import MiniMindLM
+from model.model import MiniMindLM, RMSNorm
 from model.LMConfig import LMConfig
 from model.dataset import PretrainDataset

@@ -41,10 +44,41 @@ def get_lr(it, num_iters, learning_rate):
     return learning_rate * 0.5 * (1.0 + math.cos(math.pi * it / num_iters))

 # Model initialization function
-def init_model(lm_config, pretrained_embedding_path=None):
+def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None):
     tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
     model = MiniMindLM(lm_config)
+
+    # Default model initialization
+    Logger("Performing default model initialization...")
+
+    # Initialize the token embedding weights
+    nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02)
+
+    # Initialize the output weights (when they are not tied to the embeddings)
+    if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight:
+        nn.init.normal_(model.output.weight, mean=0.0, std=0.02)
+
+    # Initialize all linear layers
+    for name, module in model.named_modules():
+        if isinstance(module, nn.Linear):
+            # Xavier/Glorot initialization
+            nn.init.xavier_uniform_(module.weight)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            # Embedding layers use a normal distribution
+            nn.init.normal_(module.weight, mean=0.0, std=0.02)
+        elif isinstance(module, RMSNorm):
+            # RMSNorm weights are initialized to 1
+            if hasattr(module, 'weight'):
+                nn.init.ones_(module.weight)
+
+    # Initialize the knowledge-database keys
+    if hasattr(model.extract_db, 'keys'):
+        nn.init.normal_(model.extract_db.keys, mean=0.0, std=0.02)
+
+    Logger("Default model initialization completed")
+
     # Load pretrained embedding weights if provided
     if pretrained_embedding_path:
         Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
@@ -52,6 +86,334 @@ def init_model(lm_config, pretrained_embedding_path=None):
         model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
         model.output.weight.data.copy_(pretrained_embeddings)  # tied weights
+
+    if database_init_path:
+        import json
+        import numpy as np
+        from sentence_transformers import SentenceTransformer
+        import os
+
+        Logger(f"Loading database initialization data from {database_init_path}")
+
+        # 1. Load the JSON file
+        with open(database_init_path, 'r', encoding='utf-8') as f:
+            database_data = json.load(f)
+
+        # Extract the list of sentences
+        sentences_data = database_data.get('sentences', [])
+        Logger(f"Loaded {len(sentences_data)} sentences from database")
+
+        # 2. Sort by importance_score (highest first)
+        sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
+        Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")
+
+        # 3. Download and initialize the local embedding model
+        embedding_model_name = "sentence-transformers/all-mpnet-base-v2"  # lightweight but strong model
+        embedding_model_dir = "./models/sentence_transformers/models--sentence-transformers--all-mpnet-base-v2"
+        embedding_cache_dir = "./models/sentence_transformers/cache"
+        os.makedirs(embedding_cache_dir, exist_ok=True)
+
+        Logger(f"Loading embedding model: {embedding_model_name}")
+        try:
+            embedding_model = SentenceTransformer(embedding_model_dir, cache_folder=embedding_cache_dir)
+            Logger("Embedding model loaded successfully")
+        except Exception as e:
+            Logger(f"Failed to load embedding model: {e}")
+            Logger("Falling back to random embeddings")
+            embedding_model = None
+
+        # 4. Compute an embedding and a token length for every corrected_sentence
+        Logger("Processing sentences for embeddings and token lengths...")
+
+        # Collect all sentences
+        sentences = [sentence_data.get('corrected_sentence', '') for sentence_data in sorted_sentences]
+
+        # Compute token lengths
+        Logger("Computing token lengths...")
+        token_lengths = []
+        for sentence in sentences:
+            tokens = tokenizer.encode(sentence, add_special_tokens=False)
+            token_lengths.append(len(tokens))
+
+        # Compute embeddings in batches -- much faster than sentence by sentence
+        Logger("Computing embeddings in batches...")
+        embeddings_list = []
+        batch_size = 256  # adjust according to GPU memory
+
+        if embedding_model is not None:
+            try:
+                for i in range(0, len(sentences), batch_size):
+                    batch_sentences = sentences[i:i+batch_size]
+                    batch_embeddings = embedding_model.encode(
+                        batch_sentences,
+                        convert_to_tensor=False,
+                        show_progress_bar=True if i == 0 else False,
+                        batch_size=batch_size
+                    )
+                    embeddings_list.extend(batch_embeddings)
+
+                    if (i + batch_size) % (batch_size * 10) == 0:
+                        Logger(f"Processed {min(i + batch_size, len(sentences))}/{len(sentences)} sentences")
+
+                Logger("Batch embedding computation completed")
+            except Exception as e:
+                Logger(f"Error in batch encoding: {e}")
+                Logger("Falling back to random embeddings")
+                embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
+        else:
+            # Use random embeddings
+            embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
+
+        # Build the list of processed sentences
+        processed_sentences = []
+        for i, (sentence_data, embedding, token_length) in enumerate(zip(sorted_sentences, embeddings_list, token_lengths)):
+            processed_sentences.append({
+                'sentence': sentence_data.get('corrected_sentence', ''),
+                'importance_score': sentence_data.get('importance_score', 0.0),
+                'token_length': token_length,
+                'embedding': embedding,  # kept as a numpy array
+                'original_index': i
+            })
+
+        # # Create a JSON-serializable version for saving
+        # json_serializable_sentences = []
+        # for sentence in processed_sentences:
+        #     json_sentence = sentence.copy()
+        #     # Convert embedding to list if it's a numpy array
+        #     if hasattr(json_sentence['embedding'], 'tolist'):
+        #         json_sentence['embedding'] = json_sentence['embedding'].tolist()
+        #     json_serializable_sentences.append(json_sentence)
+
+        # json.dump(json_serializable_sentences, open('processed_sentences.json', 'w', encoding='utf-8'))
+
+        # processed_sentences = json.load(open('processed_sentences.json', 'r', encoding='utf-8'))
+
+        # Convert to numpy arrays for the steps below
+        embeddings_array = np.array(embeddings_list)
+        token_lengths_array = np.array(token_lengths)
+
+        Logger(f"Embedding processing completed:")
+        Logger(f"  - Total sentences: {len(processed_sentences)}")
+        Logger(f"  - Embedding shape: {embeddings_array.shape}")
+        Logger(f"  - Average token length: {np.mean(token_lengths_array):.2f}")
+        Logger(f"  - Token length range: {np.min(token_lengths_array)} - {np.max(token_lengths_array)}")
+
+        # 5. Clustering -- optimized version
+        Logger("Starting optimized clustering process...")
+
+        # Clustering parameters
+        knowledge_num = args.knowledge_num
+        knowledge_length = args.knowledge_length
+        min_tokens = int(0.9 * knowledge_length)
+        max_tokens = knowledge_length
+
+        # Optimization 1: precompute the full similarity matrix (only when the dataset is not too large)
+        if len(processed_sentences) <= 10000:
+            Logger("Pre-computing similarity matrix for faster clustering...")
+            embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
+            similarity_matrix = cosine_similarity(embeddings_matrix)
+            Logger(f"Similarity matrix computed: {similarity_matrix.shape}")
+        else:
+            similarity_matrix = None
+            embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
+
+        clustered_rows = []
+        remaining_indices = list(range(len(processed_sentences)))  # work with indices instead of objects
+
+        Logger(f"Target: {knowledge_num} clusters, each with {min_tokens}-{max_tokens} tokens")
+
+        # Choose the clustering algorithm
+        if args.fast_clustering and len(processed_sentences) > 5000:
+            Logger("Using ultra-fast approximate clustering algorithm...")
+
+            # Ultra-fast clustering: random sampling + batched processing
+            import random
+            random.seed(42)  # for reproducibility
+
+            # Stratified sampling by importance
+            high_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] > 0.7]
+            medium_importance = [i for i, s in enumerate(processed_sentences) if 0.3 <= s['importance_score'] <= 0.7]
+            low_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] < 0.3]
+
+            Logger(f"Importance distribution: High={len(high_importance)}, Medium={len(medium_importance)}, Low={len(low_importance)}")
+
+            for cluster_idx in tqdm(range(knowledge_num)):
+                # Seed selection by importance tier: prefer high-importance sentences
+                if high_importance:
+                    seed_pool = high_importance
+                elif medium_importance:
+                    seed_pool = medium_importance
+                else:
+                    seed_pool = low_importance if low_importance else list(range(len(processed_sentences)))
+
+                if not seed_pool:
+                    break
+
+                # Pick a random seed (within the same importance tier)
+                seed_global_idx = random.choice(seed_pool)
+                seed_sentence = processed_sentences[seed_global_idx]
+
+                # Remove the seed from every pool
+                for pool in [high_importance, medium_importance, low_importance]:
+                    if seed_global_idx in pool:
+                        pool.remove(seed_global_idx)
+
+                current_cluster_indices = [seed_global_idx]
+                current_tokens = seed_sentence['token_length']
+
+                if current_tokens < max_tokens:
+                    # Fast selection: only look at a random sample of the remaining sentences
+                    all_remaining = high_importance + medium_importance + low_importance
+                    if all_remaining:
+                        # Randomly sample candidates (instead of computing all similarities)
+                        sample_size = min(100, len(all_remaining))
+                        candidates = random.sample(all_remaining, sample_size)
+
+                        # Select simply by token length and importance
+                        for candidate_idx in candidates:
+                            candidate = processed_sentences[candidate_idx]
+                            candidate_tokens = candidate['token_length']
+
+                            if current_tokens + candidate_tokens + 1 <= max_tokens:
+                                current_cluster_indices.append(candidate_idx)
+                                current_tokens += candidate_tokens + 1
+
+                                # Remove the candidate from its pool
+                                for pool in [high_importance, medium_importance, low_importance]:
+                                    if candidate_idx in pool:
+                                        pool.remove(candidate_idx)
+                                        break
+
+                            if current_tokens >= min_tokens:
+                                break
+
+                # Build the cluster text
+                cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
+                cluster_text = '\n'.join(cluster_sentences)
+
+                # Convert to tokens
+                cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
+                if len(cluster_tokens) > knowledge_length:
+                    cluster_tokens = cluster_tokens[:knowledge_length]
+                else:
+                    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
+                    cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))
+
+                clustered_rows.append(cluster_tokens)
+
+                if (cluster_idx + 1) % 1000 == 0:
+                    total_remaining = len(high_importance) + len(medium_importance) + len(low_importance)
+                    Logger(f"Fast clustering: {cluster_idx + 1}/{knowledge_num} clusters, {total_remaining} sentences remaining")
+
+        else:
+            # Original optimized algorithm (for medium-sized datasets)
+            # Optimization 2: batched processing and more efficient data structures
+            for cluster_idx in tqdm(range(knowledge_num)):
+                if not remaining_indices:
+                    Logger(f"No more sentences available. Created {cluster_idx} clusters.")
+                    break
+
+                # 5.1 Use the sentence with the highest importance_score as the seed
+                remaining_sentences_subset = [processed_sentences[i] for i in remaining_indices]
+                seed_idx_in_subset = max(range(len(remaining_sentences_subset)),
+                                         key=lambda i: remaining_sentences_subset[i]['importance_score'])
+                seed_global_idx = remaining_indices[seed_idx_in_subset]
+                seed_sentence = processed_sentences[seed_global_idx]
+
+                # Remove the seed from the remaining indices
+                remaining_indices.remove(seed_global_idx)
+
+                # Current cluster
+                current_cluster_indices = [seed_global_idx]
+                current_tokens = seed_sentence['token_length']
+
+                if current_tokens >= max_tokens:
+                    # If the seed sentence already exceeds the token budget, it forms a cluster on its own
+                    cluster_text = seed_sentence['sentence']
+                else:
+                    # 5.2 Optimized similarity computation and selection
+                    if remaining_indices:
+                        if similarity_matrix is not None:
+                            # Use the precomputed similarity matrix
+                            similarities = similarity_matrix[seed_global_idx][remaining_indices]
+                        else:
+                            # Compute similarities on the fly (batched)
+                            seed_embedding = embeddings_matrix[seed_global_idx:seed_global_idx+1]
+                            remaining_embeddings = embeddings_matrix[remaining_indices]
+                            similarities = cosine_similarity(seed_embedding, remaining_embeddings)[0]
+
+                        # Build tuples of (similarity, global index, position in remaining_indices)
+                        similarity_tuples = [(similarities[i], remaining_indices[i], i)
+                                             for i in range(len(remaining_indices))]
+
+                        # Sort by similarity (descending)
+                        similarity_tuples.sort(key=lambda x: x[0], reverse=True)
+
+                        # Optimization 3: greedy selection over a limited candidate window for speed
+                        max_candidates = min(len(similarity_tuples), 500)  # only consider the 500 most similar sentences
+
+                        selected_indices_in_remaining = []
+                        for sim_score, global_idx, pos_in_remaining in similarity_tuples[:max_candidates]:
+                            candidate = processed_sentences[global_idx]
+                            candidate_tokens = candidate['token_length']
+
+                            if current_tokens + candidate_tokens + 1 <= max_tokens:  # +1 for newline
+                                current_cluster_indices.append(global_idx)
+                                selected_indices_in_remaining.append(pos_in_remaining)
+                                current_tokens += candidate_tokens + 1
+
+                                if current_tokens >= min_tokens:
+                                    break
+
+                        # Remove the selected sentences in bulk (back to front so positions stay valid)
+                        for pos in sorted(selected_indices_in_remaining, reverse=True):
+                            remaining_indices.pop(pos)
+
+                    # Join the sentences
+                    cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
+                    cluster_text = '\n'.join(cluster_sentences)
+
+                # Convert the cluster text to tokens
+                cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
+
+                # Truncate or pad to knowledge_length
+                if len(cluster_tokens) > knowledge_length:
+                    cluster_tokens = cluster_tokens[:knowledge_length]
+                else:
+                    # Pad with pad_token_id
+                    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
+                    cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))
+
+                clustered_rows.append(cluster_tokens)
+
+                # Optimization 4: log less frequently
+                if (cluster_idx + 1) % 500 == 0:
+                    Logger(f"Created {cluster_idx + 1}/{knowledge_num} clusters, {len(remaining_indices)} sentences remaining")
+
+        # If there are not enough clusters, pad with rows of pad tokens
+        while len(clustered_rows) < knowledge_num:
+            pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
+            random_tokens = [pad_token_id] * knowledge_length
+            clustered_rows.append(random_tokens)
+
+        # Convert to a tensor
+        clustered_tensor = torch.tensor(clustered_rows, dtype=torch.long)
+
+        Logger(f"Clustering completed:")
+        Logger(f"  - Created {len(clustered_rows)} clusters")
+        Logger(f"  - Cluster shape: {clustered_tensor.shape}")
+        Logger(f"  - Expected shape: ({knowledge_num}, {knowledge_length})")
+
+        # 6. Initialize the model's weight_down_embed
+        if hasattr(model, 'extract_db') and hasattr(model.extract_db, 'weight_down_embed'):
+            model.extract_db.weight_down_embed.data.copy_(clustered_tensor)
+            Logger("Successfully initialized model.extract_db.weight_down_embed with clustered data")
+        else:
+            Logger("Warning: Could not find model.extract_db.weight_down_embed to initialize")
+            # Keep the tensor in a global variable as a fallback
+            globals()['clustered_database'] = clustered_tensor
+
+        Logger(f"Database embeddings and sentences stored in model")
+
     Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
     return model, tokenizer

@@ -290,7 +652,9 @@ def main():
     parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
     parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
     parser.add_argument("--knowledge_num", type=int, default=64*64,help="知识库的数据数目")
-    parser.add_argument("--knowledge_length", type=int, default=8,help="知识库的句子长度")
+    parser.add_argument("--knowledge_length", type=int, default=64,help="知识库的句子长度")
+    parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径")
+    parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")
     args = parser.parse_args()

     #########################################################
@@ -379,7 +743,7 @@ def main():
     #########################################################
     # Initialize the model and tokenizer
     #########################################################
-    model, tokenizer = init_model(lm_config, args.pretrained_embedding_path)
+    model, tokenizer = init_model(lm_config, args.pretrained_embedding_path, args.database_init_path, args)
     # Pass the accelerator to the Logger calls inside init_model
     Logger(f'模型初始化完成', accelerator)
