Compare commits


2 Commits

SHA1        Message                        Date
770c34f0e3  DynamicKV-LLM Pretrain v1.2.1  2025-06-08 02:20:36 +00:00
1678e739b6  DynamicKV-LLM Pretrain v1.2.0  2025-06-07 02:41:45 +00:00
6 changed files with 128 additions and 322 deletions

.gitignore (vendored, 2 changes)

@@ -7,3 +7,5 @@ models/sentence_transformers/
 models/sentence_transformers_cache/
 **/*.pyc
 qwen2-1.7B/
+images/
+cache/

.vscode/launch.json (vendored, 8 changes)

@@ -7,7 +7,7 @@
             "request": "launch",
             "program": "${workspaceFolder}/train_pretrain_accelerate.py",
             "console": "integratedTerminal",
-            "python": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
+            "python": "/opt/conda/envs/mini/bin/python",
             "cwd": "${workspaceFolder}",
             "env": {
                 "PYTHONPATH": "${workspaceFolder}",
@@ -23,7 +23,7 @@
             "request": "launch",
             "program": "${workspaceFolder}/train_pretrain_accelerate.py",
             "console": "integratedTerminal",
-            "python": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
+            "python": "/opt/conda/envs/mini/bin/python",
             "args": [
                 "--hidden_size", "512",
                 "--max_seq_len", "512",
@@ -46,7 +46,7 @@
             "request": "launch",
             "program": "${workspaceFolder}/train_pretrain_accelerate.py",
             "console": "integratedTerminal",
-            "python": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
+            "python": "/opt/conda/envs/mini/bin/python",
             "args": [
                 "--hidden_size", "512",
                 "--max_seq_len", "512",
@@ -73,7 +73,7 @@
             "request": "launch",
             "program": "${workspaceFolder}/train_pretrain_accelerate.py",
             "console": "integratedTerminal",
-            "python": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
+            "python": "/opt/conda/envs/mini/bin/python",
             "args": [
                 "--hidden_size", "512",
                 "--max_seq_len", "256",


@@ -19,6 +19,7 @@ class LMConfig(PretrainedConfig):
             rope_theta: int = 1e6,
             dropout: float = 0.0,
             flash_attn: bool = True,
+            embeddings_epoch: int = 2,
             ####################################################
             # DB related configurations
             ####################################################
@@ -54,6 +55,7 @@ class LMConfig(PretrainedConfig):
         self.rope_theta = rope_theta
         self.dropout = dropout
         self.flash_attn = flash_attn
+        self.embeddings_epoch = embeddings_epoch
         ####################################################
         # DB related configurations
         ####################################################
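
The new embeddings_epoch field rides along with the existing config options. A minimal usage sketch, assuming only the constructor arguments visible in this diff (the import path is a guess, since the file name is not shown above):

    # Sketch: construct the config with the new field and read it back.
    # The import path is an assumption (the file name is not shown in this diff);
    # only the constructor arguments visible above are relied on.
    from model.LMConfig import LMConfig

    config = LMConfig(
        dropout=0.0,
        flash_attn=True,
        embeddings_epoch=2,   # epoch at which token embeddings get frozen (see the training changes below)
    )
    print(config.embeddings_epoch)  # 2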


@@ -81,6 +81,8 @@ class KnowledgeDataset(nn.Module):
         # Count steps, used to dynamically adjust weights
         self.step_counter = 0

+        self.freeze_embedding = False
+
     def intelligent_selection(self, query, all_scores, all_indices):
@@ -169,6 +171,8 @@ class KnowledgeDataset(nn.Module):
         return all_best_tokens, all_best_tokens_embeddings

     def _update_keys_with_embeddings(self, pre_update_indices, pre_update_embeddings):
+        if self.freeze_embedding:
+            return
         # Update self.keys with pre_update_embeddings
         with torch.no_grad():
             pre_update_embeddings = pre_update_embeddings.mean(dim=1)  # [337, 512]
@@ -199,8 +203,26 @@ class KnowledgeDataset(nn.Module):
         if self.is_train:
             # Get the indices of keys that have not been updated yet
             not_updated_indices = torch.where(self.has_update_keys == 0)[0]

             # If there are keys that have not been updated, randomly select num_update_keys of them to update
             if len(not_updated_indices) > 0:
+                num_update_keys = int(self.knowledge_num * 0.01)
+                perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
+                perm_num = perm.shape[0]
+                pre_update_indices = not_updated_indices[perm]
+                pre_update_tokens = self.knowledge_dataset[pre_update_indices]
+                pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
+                pre_update_embeddings = pre_update_embeddings.view(perm_num, self.knowledge_length, -1)
+                self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
+                # Mark the modified keys as updated
+                with torch.no_grad():
+                    self.has_update_keys[pre_update_indices] = 1
+            else:
+                print("all keys are updated")
+                # Reset the update status of all keys
+                self.has_update_keys.zero_()
+                # Re-collect all updatable indices
+                not_updated_indices = torch.arange(len(self.has_update_keys), device=self.has_update_keys.device)
                 num_update_keys = int(self.knowledge_num * 0.01)
                 perm = torch.randperm(len(not_updated_indices))[:num_update_keys]
                 pre_update_indices = not_updated_indices[perm]
@@ -208,6 +230,12 @@ class KnowledgeDataset(nn.Module):
                 pre_update_embeddings = self.tok_embeddings(pre_update_tokens.view(-1))
                 pre_update_embeddings = pre_update_embeddings.view(num_update_keys, self.knowledge_length, -1)
                 self._update_keys_with_embeddings(pre_update_indices, pre_update_embeddings)
+                # Mark the modified keys as updated
+                with torch.no_grad():
+                    self.has_update_keys[pre_update_indices] = 1

         return best_tokens, best_tokens_embeddings
@@ -484,12 +512,20 @@ class MiniMindLM(PreTrainedModel):
                              precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                              persistent=False)
         self.OUT = CausalLMOutputWithPast()
+        self.freeze_embedding = False

     def forward(self,
                 input_ids: Optional[torch.Tensor] = None,
                 logits_to_keep: Union[int, torch.Tensor] = 0,
+                step: int = 0,
                 **args):
         start_pos = args.get('start_pos', 0)
+        if self.freeze_embedding and step == 0:
+            self.tok_embeddings.weight.requires_grad = False
+            # Also freeze the KnowledgeDataset embedding updates
+            self.knowledge_dataset.freeze_embedding = True
+            print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
+            print("knowledge_dataset.freeze_embedding: ", self.knowledge_dataset.freeze_embedding)
         h = self.dropout(self.tok_embeddings(input_ids))
         pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
         for l, layer in enumerate(self.layers):
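
Taken together, these changes gate every key refresh behind a freeze_embedding flag and refresh roughly 1% of the knowledge keys per call, starting a new round once every key has been touched. A condensed, self-contained sketch of that selection pattern (the 1% ratio and the has_update_keys bookkeeping follow the diff; everything else is illustrative):

    import torch

    # Illustrative stand-in for the bookkeeping in KnowledgeDataset: pick ~1% of the keys
    # that have not been refreshed yet, and start a new round once all keys are done.
    knowledge_num = 8192
    has_update_keys = torch.zeros(knowledge_num)   # 1 = refreshed in the current round
    freeze_embedding = False                       # set to True once embeddings are frozen

    def pick_keys_to_refresh():
        if freeze_embedding:
            return torch.empty(0, dtype=torch.long)   # frozen: no key refresh at all
        not_updated = torch.where(has_update_keys == 0)[0]
        if len(not_updated) == 0:                     # every key touched: reset the round
            has_update_keys.zero_()
            not_updated = torch.arange(knowledge_num)
        num_update = int(knowledge_num * 0.01)        # ~1% of the knowledge base per call
        perm = torch.randperm(len(not_updated))[:num_update]
        chosen = not_updated[perm]
        has_update_keys[chosen] = 1                   # mark as refreshed
        return chosen

    print(pick_keys_to_refresh().shape)               # torch.Size([81])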


@@ -1,8 +1,8 @@
 #!/bin/bash

 # Activate the conda environment
-# source $(conda info --base)/etc/profile.d/conda.sh
-# conda activate ycz_accelerate
+source $(conda info --base)/etc/profile.d/conda.sh
+conda activate mini

 # Set environment variables to help with debugging
 export NCCL_DEBUG=INFO
@@ -26,24 +26,9 @@ export PYTHONFAULTHANDLER=1
 # --profile_interval 10

 # Method 2: configure accelerate directly with command-line arguments
-CUDA_VISIBLE_DEVICES=0 accelerate launch \
+CUDA_VISIBLE_DEVICES=0 /opt/conda/envs/mini/bin/python -m accelerate.commands.launch \
     --num_processes=1 \
     --mixed_precision=bf16 \
     --main_process_port=29500 \
     train_pretrain_accelerate.py \
-    --epochs 3 \
-    --batch_size 24 \
-    --learning_rate 2e-4 \
-    --dtype bfloat16 \
-    --accumulation_steps 32 \
-    --grad_clip 1.0 \
-    --log_interval 100 \
-    --save_interval 10000 \
-    --dim 512 \
-    --n_layers 12 \
-    --max_seq_len 512 \
-    --use_flash_attn \
-    --profile \
-    --profile_interval 10\
-    --knowledge_num 4096 \
-    --knowledge_length 8


@@ -88,54 +88,52 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
     if database_init_path:
         import json
-        import numpy as np
-        from sentence_transformers import SentenceTransformer
         import os

-        # Clustering parameters (defined up front for the cache check)
+        # Database parameters
         knowledge_num = args.knowledge_num
         knowledge_length = args.knowledge_length

-        # Check whether to use the cache (checked early to avoid unnecessary data processing)
+        # Check whether to use the cache
         cache_dir = os.path.dirname(args.cluster_cache_path)
         if cache_dir:
             os.makedirs(cache_dir, exist_ok=True)
-        clustered_tensor = None
+        processed_tensor = None

-        # Try to load cached clustering results
+        # Try to load cached processed results
         if not args.recompute_clusters and os.path.exists(args.cluster_cache_path):
             try:
-                Logger(f"Loading cached cluster results from {args.cluster_cache_path}")
-                clustered_tensor = torch.load(args.cluster_cache_path)
+                Logger(f"Loading cached processed results from {args.cluster_cache_path}")
+                processed_tensor = torch.load(args.cluster_cache_path)

                 # Verify that the shape of the cached file is usable
-                cached_knowledge_num, cached_knowledge_length = clustered_tensor.shape
+                cached_knowledge_num, cached_knowledge_length = processed_tensor.shape

                 if cached_knowledge_length == knowledge_length:
                     if cached_knowledge_num >= knowledge_num:
                         # The cache is large enough; truncate and use it
-                        clustered_tensor = clustered_tensor[:knowledge_num, :]
-                        Logger(f"Successfully loaded cached clusters with shape {clustered_tensor.shape}")
+                        processed_tensor = processed_tensor[:knowledge_num, :]
+                        Logger(f"Successfully loaded cached data with shape {processed_tensor.shape}")
                         Logger(f"Truncated from cached shape ({cached_knowledge_num}, {cached_knowledge_length}) to required shape ({knowledge_num}, {knowledge_length})")
-                        Logger("Skipping database initialization and clustering - using cached results")
+                        Logger("Skipping database initialization - using cached results")
                     else:
                         # The cache is too small; recompute
                         Logger(f"Cached knowledge_num ({cached_knowledge_num}) < required knowledge_num ({knowledge_num}), recomputing...")
-                        clustered_tensor = None
+                        processed_tensor = None
                 else:
                     # knowledge_length does not match; recompute
                     Logger(f"Cached knowledge_length ({cached_knowledge_length}) != required knowledge_length ({knowledge_length}), recomputing...")
-                    clustered_tensor = None
+                    processed_tensor = None
             except Exception as e:
-                Logger(f"Failed to load cached clusters: {e}, recomputing...")
-                clustered_tensor = None
+                Logger(f"Failed to load cached data: {e}, recomputing...")
+                processed_tensor = None

-        # Only run database initialization and clustering when there is no valid cache
-        if clustered_tensor is None:
+        # Only run database initialization and processing when there is no valid cache
+        if processed_tensor is None:
             Logger(f"Loading database initialization data from {database_init_path}")

-            # 1. Load the JSON file and convert it into a dict
+            # 1. Load the JSON file
             with open(database_init_path, 'r', encoding='utf-8') as f:
                 database_data = json.load(f)
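
The cache check above reduces to one rule: reuse the cached tensor only if its row length matches knowledge_length exactly and it has at least knowledge_num rows, truncating any extra rows. A small sketch of that validation in isolation (names mirror the diff; the helper function itself is illustrative):

    import os
    import torch

    def load_cached_knowledge(cache_path, knowledge_num, knowledge_length):
        """Return a (knowledge_num, knowledge_length) tensor from the cache, or None to trigger recomputation."""
        if not os.path.exists(cache_path):
            return None
        try:
            cached = torch.load(cache_path)
        except Exception:
            return None                          # unreadable cache: recompute
        cached_num, cached_len = cached.shape
        if cached_len != knowledge_length:       # row length must match exactly
            return None
        if cached_num < knowledge_num:           # too few rows: recompute
            return None
        return cached[:knowledge_num, :]         # enough rows: truncate and reuse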
@@ -147,300 +145,73 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
             sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
             Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")

-            # 3. Download and initialize the local embedding model
-            embedding_model_name = "sentence-transformers/all-mpnet-base-v2"  # lightweight but effective model
-            embedding_model_dir = "./models/sentence_transformers/models--sentence-transformers--all-mpnet-base-v2"
-            embedding_cache_dir = "./models/sentence_transformers/cache"
-            os.makedirs(embedding_cache_dir, exist_ok=True)
-            Logger(f"Loading embedding model: {embedding_model_name}")
-            try:
-                embedding_model = SentenceTransformer(embedding_model_dir, cache_folder=embedding_cache_dir)
-                Logger("Embedding model loaded successfully")
-            except Exception as e:
-                Logger(f"Failed to load embedding model: {e}")
-                Logger("Falling back to random embeddings")
-                embedding_model = None
-
-            # 4. Compute an embedding and token length for every corrected_sentence
-            Logger("Processing sentences for embeddings and token lengths...")
-
-            # Extract all sentences
-            sentences = [sentence_data.get('corrected_sentence', '') for sentence_data in sorted_sentences]
-
-            # Compute token lengths in batch
-            Logger("Computing token lengths...")
-            token_lengths = []
-            for sentence in sentences:
-                tokens = tokenizer.encode(sentence, add_special_tokens=False)
-                token_lengths.append(len(tokens))
-
-            # Compute embeddings in batches - much faster
-            Logger("Computing embeddings in batches...")
-            embeddings_list = []
-            batch_size = 256  # adjust according to GPU memory
-
-            if embedding_model is not None:
-                try:
-                    for i in range(0, len(sentences), batch_size):
-                        batch_sentences = sentences[i:i+batch_size]
-                        batch_embeddings = embedding_model.encode(
-                            batch_sentences,
-                            convert_to_tensor=False,
-                            show_progress_bar=True if i == 0 else False,
-                            batch_size=batch_size
-                        )
-                        embeddings_list.extend(batch_embeddings)
-
-                        if (i + batch_size) % (batch_size * 10) == 0:
-                            Logger(f"Processed {min(i + batch_size, len(sentences))}/{len(sentences)} sentences")
-
-                    Logger("Batch embedding computation completed")
-                except Exception as e:
-                    Logger(f"Error in batch encoding: {e}")
-                    Logger("Falling back to random embeddings")
-                    embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
-            else:
-                # Use random embeddings
-                embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
-
-            # Build the list of processed sentences
-            processed_sentences = []
-            for i, (sentence_data, embedding, token_length) in enumerate(zip(sorted_sentences, embeddings_list, token_lengths)):
-                processed_sentences.append({
-                    'sentence': sentence_data.get('corrected_sentence', ''),
-                    'importance_score': sentence_data.get('importance_score', 0.0),
-                    'token_length': token_length,
-                    'embedding': embedding,  # Convert numpy array to list
-                    'original_index': i
-                })
-
-            # Convert to numpy arrays for subsequent processing
-            embeddings_array = np.array(embeddings_list)
-            token_lengths_array = np.array(token_lengths)
-
-            Logger(f"Embedding processing completed:")
-            Logger(f"  - Total sentences: {len(processed_sentences)}")
-            Logger(f"  - Embedding shape: {embeddings_array.shape}")
-            Logger(f"  - Average token length: {np.mean(token_lengths_array):.2f}")
-            Logger(f"  - Token length range: {np.min(token_lengths_array)} - {np.max(token_lengths_array)}")
-
-            # Clustering parameter definitions
-            min_tokens = int(0.85 * knowledge_length)
-            max_tokens = int(0.95 * knowledge_length)
-
-            # Optimization 1: precompute the similarity matrix of all embeddings (if the dataset is not too large)
-            if len(processed_sentences) <= 10000:  # only precompute when the dataset is not too large
-                Logger("Pre-computing similarity matrix for faster clustering...")
-                embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
-                similarity_matrix = cosine_similarity(embeddings_matrix)
-                Logger(f"Similarity matrix computed: {similarity_matrix.shape}")
-            else:
-                similarity_matrix = None
-                embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
-
-            clustered_rows = []
-            remaining_indices = list(range(len(processed_sentences)))  # use indices instead of objects
-
-            Logger(f"Target: {knowledge_num} clusters, each with {min_tokens}-{max_tokens} tokens")
-
-            # Choose the clustering algorithm
-            if args.fast_clustering and len(processed_sentences) > 5000:
-                Logger("Using ultra-fast approximate clustering algorithm...")
-
-                # Ultra-fast clustering: random sampling + batch processing
-                import random
-                random.seed(42)  # ensure reproducibility
-
-                # Stratified sampling by importance
-                high_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] > 0.7]
-                medium_importance = [i for i, s in enumerate(processed_sentences) if 0.3 <= s['importance_score'] <= 0.7]
-                low_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] < 0.3]
-
-                Logger(f"Importance distribution: High={len(high_importance)}, Medium={len(medium_importance)}, Low={len(low_importance)}")
-
-                for cluster_idx in tqdm(range(knowledge_num)):
-                    # Pick the seed by stratum: prefer high-importance sentences
-                    if high_importance:
-                        seed_pool = high_importance
-                    elif medium_importance:
-                        seed_pool = medium_importance
-                    else:
-                        seed_pool = low_importance if low_importance else list(range(len(processed_sentences)))
-
-                    if not seed_pool:
-                        break
-
-                    # Randomly pick a seed (within the same importance stratum)
-                    seed_global_idx = random.choice(seed_pool)
-                    seed_sentence = processed_sentences[seed_global_idx]
-
-                    # Remove the seed from all pools
-                    for pool in [high_importance, medium_importance, low_importance]:
-                        if seed_global_idx in pool:
-                            pool.remove(seed_global_idx)
-
-                    current_cluster_indices = [seed_global_idx]
-                    current_tokens = seed_sentence['token_length']
-
-                    if current_tokens < max_tokens:
-                        # Fast selection: randomly pick candidates instead of scanning everything
-                        all_remaining = high_importance + medium_importance + low_importance
-                        if all_remaining:
-                            # Randomly sample candidate sentences (instead of computing all similarities)
-                            sample_size = min(2000, len(all_remaining))
-                            candidates = random.sample(all_remaining, sample_size)
-
-                            # Simply select by token length and importance
-                            for candidate_idx in candidates:
-                                candidate = processed_sentences[candidate_idx]
-                                candidate_tokens = candidate['token_length']
-
-                                if current_tokens + candidate_tokens + 1 <= max_tokens:
-                                    current_cluster_indices.append(candidate_idx)
-                                    current_tokens += candidate_tokens + 1
-
-                                    # Remove from the pools
-                                    for pool in [high_importance, medium_importance, low_importance]:
-                                        if candidate_idx in pool:
-                                            pool.remove(candidate_idx)
-                                            break
-
-                                    if current_tokens >= min_tokens:
-                                        break
-
-                    # Build the cluster text
-                    cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
-                    cluster_text = '\n '.join(cluster_sentences)
-
-                    # Convert to tokens
-                    cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
-                    if len(cluster_tokens) > knowledge_length:
-                        cluster_tokens = cluster_tokens[:knowledge_length]
-                    else:
-                        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
-                        cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))
-
-                    clustered_rows.append(cluster_tokens)
-
-                    if (cluster_idx + 1) % 1000 == 0:
-                        total_remaining = len(high_importance) + len(medium_importance) + len(low_importance)
-                        Logger(f"Fast clustering: {cluster_idx + 1}/{knowledge_num} clusters, {total_remaining} sentences remaining")
-            else:
-                # Original optimized algorithm (for medium-sized datasets)
-                # Optimization 2: batch processing and more efficient data structures
-                for cluster_idx in tqdm(range(knowledge_num)):
-                    if not remaining_indices:
-                        Logger(f"No more sentences available. Created {cluster_idx} clusters.")
-                        break
-
-                    # 2.1 Pick the sentence with the highest importance_score as the seed
-                    remaining_sentences_subset = [processed_sentences[i] for i in remaining_indices]
-                    seed_idx_in_subset = max(range(len(remaining_sentences_subset)),
-                                             key=lambda i: remaining_sentences_subset[i]['importance_score'])
-                    seed_global_idx = remaining_indices[seed_idx_in_subset]
-                    seed_sentence = processed_sentences[seed_global_idx]
-
-                    # Remove the seed from the remaining indices
-                    remaining_indices.remove(seed_global_idx)
-
-                    # Current cluster
-                    current_cluster_indices = [seed_global_idx]
-                    current_tokens = seed_sentence['token_length']
-
-                    if current_tokens >= max_tokens:
-                        # If the seed sentence already exceeds the max token count, use it directly as a cluster
-                        cluster_text = seed_sentence['sentence']
-                    else:
-                        # 2.2 Optimized similarity computation and selection
-                        if remaining_indices:
-                            if similarity_matrix is not None:
-                                # Use the precomputed similarity matrix
-                                similarities = similarity_matrix[seed_global_idx][remaining_indices]
-                            else:
-                                # Compute similarities on the fly (in batch)
-                                seed_embedding = embeddings_matrix[seed_global_idx:seed_global_idx+1]
-                                remaining_embeddings = embeddings_matrix[remaining_indices]
-                                similarities = cosine_similarity(seed_embedding, remaining_embeddings)[0]
-
-                            # Build a list of (similarity, global index, position in remaining_indices) tuples
-                            similarity_tuples = [(similarities[i], remaining_indices[i], i)
-                                                 for i in range(len(remaining_indices))]
-
-                            # Sort by similarity (descending)
-                            similarity_tuples.sort(key=lambda x: x[0], reverse=True)
-
-                            # Optimization 3: greedy selection, but with a limited search range for speed
-                            max_candidates = min(len(similarity_tuples), 500)  # only consider the 500 most similar sentences
-                            selected_indices_in_remaining = []
-                            for sim_score, global_idx, pos_in_remaining in similarity_tuples[:max_candidates]:
-                                candidate = processed_sentences[global_idx]
-                                candidate_tokens = candidate['token_length']
-
-                                if current_tokens + candidate_tokens + 1 <= max_tokens:  # +1 for newline
-                                    current_cluster_indices.append(global_idx)
-                                    selected_indices_in_remaining.append(pos_in_remaining)
-                                    current_tokens += candidate_tokens + 1
-
-                                    if current_tokens >= min_tokens:
-                                        break
-
-                            # Remove the selected sentences in batch (from back to front to avoid index issues)
-                            for pos in sorted(selected_indices_in_remaining, reverse=True):
-                                remaining_indices.pop(pos)
-
-                    # Concatenate the sentences
-                    cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
-                    cluster_text = '\n'.join(cluster_sentences)
-
-                    # Convert the cluster text to tokens
-                    cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
-                    # Truncate or pad to knowledge_length
-                    if len(cluster_tokens) > knowledge_length:
-                        cluster_tokens = cluster_tokens[:knowledge_length]
-                    else:
-                        # Pad with pad_token_id
-                        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
-                        cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))
-
-                    clustered_rows.append(cluster_tokens)
-
-                    # Optimization 4: reduce logging frequency
-                    if (cluster_idx + 1) % 500 == 0:
-                        Logger(f"Created {cluster_idx + 1}/{knowledge_num} clusters, {len(remaining_indices)} sentences remaining")
-
-            # If the number of clusters is insufficient, pad with random tokens
-            while len(clustered_rows) < knowledge_num:
-                pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
-                random_tokens = [pad_token_id] * knowledge_length
-                clustered_rows.append(random_tokens)
+            # 3. Process each entry individually, without clustering
+            Logger("Processing individual sentences...")
+
+            processed_rows = []
+
+            # Get the id of the empty (pad) token, used for padding
+            pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
+
+            # Process the required number of sentences
+            num_to_process = min(knowledge_num, len(sorted_sentences))
+
+            for i in range(num_to_process):
+                sentence_data = sorted_sentences[i]
+                sentence = sentence_data.get('corrected_sentence', '')
+
+                # Convert the sentence to tokens
+                sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
+
+                # Truncate or pad to knowledge_length
+                if len(sentence_tokens) > knowledge_length:
+                    # Truncate if too long
+                    sentence_tokens = sentence_tokens[:knowledge_length]
+                    Logger(f"Sentence {i+1} truncated from {len(tokenizer.encode(sentence, add_special_tokens=False))} to {knowledge_length} tokens")
+                else:
+                    # Pad with empty tokens if too short
+                    original_length = len(sentence_tokens)
+                    sentence_tokens.extend([pad_token_id] * (knowledge_length - len(sentence_tokens)))
+                    if original_length < knowledge_length:
+                        Logger(f"Sentence {i+1} padded from {original_length} to {knowledge_length} tokens")
+
+                processed_rows.append(sentence_tokens)
+
+                if (i + 1) % 1000 == 0:
+                    Logger(f"Processed {i + 1}/{num_to_process} sentences")
+
+            # If there are not enough sentences, fill the remaining slots with empty tokens
+            while len(processed_rows) < knowledge_num:
+                empty_tokens = [pad_token_id] * knowledge_length
+                processed_rows.append(empty_tokens)
+                if len(processed_rows) % 1000 == 0:
+                    Logger(f"Added empty entry {len(processed_rows)}/{knowledge_num}")
+
+            Logger(f"Finished adding empty entries. Total: {len(processed_rows)}/{knowledge_num}")

             # Convert to tensor
-            clustered_tensor = torch.tensor(clustered_rows, dtype=torch.long)
-
-            Logger(f"Clustering completed:")
-            Logger(f"  - Created {len(clustered_rows)} clusters")
-            Logger(f"  - Cluster shape: {clustered_tensor.shape}")
+            processed_tensor = torch.tensor(processed_rows, dtype=torch.long)
+
+            Logger(f"Data processing completed:")
+            Logger(f"  - Processed {num_to_process} sentences")
+            Logger(f"  - Added {knowledge_num - num_to_process} empty entries")
+            Logger(f"  - Final shape: {processed_tensor.shape}")
             Logger(f"  - Expected shape: ({knowledge_num}, {knowledge_length})")

-            # Save the clustering results to the cache file
+            # Save the processed results to the cache file
             try:
-                torch.save(clustered_tensor, args.cluster_cache_path)
-                Logger(f"Cluster results saved to {args.cluster_cache_path}")
+                torch.save(processed_tensor, args.cluster_cache_path)
+                Logger(f"Processed results saved to {args.cluster_cache_path}")
             except Exception as e:
-                Logger(f"Failed to save cluster results: {e}")
+                Logger(f"Failed to save processed results: {e}")

-        # 3. Initialize the model's weight_down_embed
+        # 4. Initialize the model's knowledge_dataset
         if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
-            model.knowledge_dataset.knowledge_dataset.data.copy_(clustered_tensor)
-            Logger("Successfully initialized model.knowledge_dataset.knowledge_dataset with clustered data")
+            model.knowledge_dataset.knowledge_dataset.data.copy_(processed_tensor)
+            Logger("Successfully initialized model.knowledge_dataset.knowledge_dataset with processed data")
         else:
             Logger("Warning: Could not find model.knowledge_dataset.knowledge_dataset to initialize")
             # Store as a global variable as a fallback
-            globals()['clustered_database'] = clustered_tensor
+            globals()['processed_database'] = processed_tensor

         Logger(f"Database embeddings and sentences stored in model")
@@ -453,6 +224,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
     total_steps_in_epoch = len(train_loader)
     total_training_steps = args.epochs * total_steps_in_epoch
     moe_path = '_moe' if args.use_moe else ''
+    best_loss = float('10000')

     # Add CUDA events for performance profiling (main process only)
     if args.profile and accelerator.is_main_process:
@@ -516,7 +288,12 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
             # Forward pass
             with ctx:
-                res = model(X)
+                if step == 0 and args.embedding_epoch == epoch:
+                    # The freeze_embedding attribute must be set on the original model, not the wrapped one
+                    unwrapped_model = accelerator.unwrap_model(model)
+                    unwrapped_model.freeze_embedding = True
+                    Logger(f"Set freeze_embedding=True for epoch {epoch}, step {step}", accelerator)
+                res = model(X, step=step)
                 loss = loss_fct(
                     res.logits.view(-1, res.logits.size(-1)),
                     Y.view(-1)
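
Because Accelerate wraps the model (potentially in DDP), the flag is set on the unwrapped module; setting a custom attribute on the wrapper would not necessarily reach MiniMindLM. A stripped-down, runnable sketch of that step (TinyModel is a stand-in; only unwrap_model and the freeze_embedding attribute come from the diff):

    import torch.nn as nn
    from accelerate import Accelerator

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = nn.Embedding(10, 4)
            self.freeze_embedding = False   # same custom flag as MiniMindLM

    accelerator = Accelerator()
    model = accelerator.prepare(TinyModel())

    # Setting the flag on the wrapped model may not reach the custom attribute (e.g. under DDP),
    # so the diff unwraps first and flips it on the underlying module:
    accelerator.unwrap_model(model).freeze_embedding = True
    print(accelerator.unwrap_model(model).freeze_embedding)  # True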
@@ -640,7 +417,9 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
                     wandb.log(log_dict)

             # Save the model (main process only)
-            if (step + 1) % args.save_interval == 0 and accelerator.is_main_process:
+            loss_total = loss.item() * args.accumulation_steps
+            if best_loss > loss_total and accelerator.is_main_process:
+                best_loss = loss_total
                 # Use the moe_path variable defined at the beginning of the function
                 ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
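
Checkpointing now keys off the best loss seen so far rather than a fixed save_interval: the micro-batch loss is scaled back up by accumulation_steps and compared against a running best before saving. A small sketch of the pattern (the helper name is hypothetical; the comparison and scaling mirror the diff):

    import torch

    best_loss = float('10000')  # as in the diff; note this is just 10000.0, not infinity

    def maybe_save_checkpoint(state_dict, micro_loss, accumulation_steps, ckpt_path, is_main_process=True):
        """Save only when the rescaled loss improves on the best seen so far (hypothetical helper)."""
        global best_loss
        loss_total = micro_loss * accumulation_steps   # mirrors loss.item() * args.accumulation_steps
        if is_main_process and loss_total < best_loss:
            best_loss = loss_total
            torch.save(state_dict, ckpt_path)
            return True
        return False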
@@ -660,7 +439,8 @@ def main():
     parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
     parser.add_argument("--out_dir", type=str, default="out")
     parser.add_argument("--epochs", type=int, default=4)
-    parser.add_argument("--batch_size", type=int, default=48)
+    parser.add_argument("--embedding_epoch", type=int, default=2, help="number of epochs of embedding training")
+    parser.add_argument("--batch_size", type=int, default=64)
     parser.add_argument("--learning_rate", type=float, default=2e-4)
     parser.add_argument("--dtype", type=str, default="bfloat16")
     parser.add_argument("--use_wandb", default=True, action="store_true")
@@ -681,8 +461,8 @@ def main():
     parser.add_argument("--profile", action="store_true", default=True, help="enable performance profiling")
     parser.add_argument("--profile_interval", type=int, default=10, help="profiling print interval (in steps)")
     parser.add_argument("--use_flash_attn", action="store_true", default=True, help="enable FlashAttention")
-    parser.add_argument("--knowledge_num", type=int, default=4096, help="number of entries in the knowledge base")
-    parser.add_argument("--knowledge_length", type=int, default=16, help="sentence length of knowledge base entries")
+    parser.add_argument("--knowledge_num", type=int, default=8192, help="number of entries in the knowledge base")
+    parser.add_argument("--knowledge_length", type=int, default=32, help="sentence length of knowledge base entries")
     parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="database initialization path")
     parser.add_argument("--fast_clustering", action="store_true", default=True, help="use the fast approximate clustering algorithm (for large datasets)")
     parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="path of the clustering result cache file")
@@ -724,7 +504,8 @@ def main():
         disable_db=args.disable_db,
         flash_attn=args.use_flash_attn,
         knowledge_num=args.knowledge_num,
-        knowledge_length=args.knowledge_length
+        knowledge_length=args.knowledge_length,
+        embeddings_epoch=args.embedding_epoch
     )
     #########################################################