Compare commits

...

No commits in common. "master" and "old/SLM" have entirely different histories.

59 changed files with 2074 additions and 24892 deletions

9
.gitignore vendored
View File

@@ -2,11 +2,4 @@
/dataset
/out
wandb/
**/*.log
models/sentence_transformers/
models/sentence_transformers_cache/
**/*.pyc
qwen2-1.7B/
images/
cache/
.venv/
**/*.log

124
.vscode/launch.json vendored
View File

@@ -1,124 +0,0 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "MiniMind Training (Direct Python)",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/train_pretrain_accelerate.py",
"args": [
"--out_dir", "out",
"--epochs", "3",
"--embedding_epoch", "2",
"--batch_size", "128",
"--learning_rate", "8e-5",
"--dtype", "bfloat16",
"--use_swanlab",
"--swanlab_project", "MiniMind-Pretrain",
"--num_workers", "1",
"--accumulation_steps", "16",
"--grad_clip", "0.5",
"--warmup_iters", "0",
"--log_interval", "1",
"--save_interval", "10000",
"--dim", "512",
"--n_layers", "8",
"--max_seq_len", "512",
"--data_path", "./dataset/stable/merged_pretrain.jsonl",
"--profile",
"--profile_interval", "10",
"--use_flash_attn",
"--knowledge_num", "1048576",
"--knowledge_length", "32",
"--database_init_path", "./dataset/stable/sentence_trex_data.json",
"--fast_clustering",
"--cluster_cache_path", "./cache/cluster_tokens_single.pt",
"--memory_monitor_interval", "10",
"--model_type", "model",
"--model_size", "538"
],
"env": {
"CUDA_VISIBLE_DEVICES": "0",
"NCCL_DEBUG": "INFO",
"PYTHONFAULTHANDLER": "1"
},
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"justMyCode": false,
"stopOnEntry": false,
"python": "${workspaceFolder}/.venv/bin/python"
},
{
"name": "MiniMind Training (Direct Python - Simple)",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/train_pretrain_accelerate.py",
"args": [
"--epochs", "1",
"--batch_size", "32",
"--learning_rate", "1e-4",
"--log_interval", "10",
"--profile_interval", "2",
"--model_type", "model_original"
],
"env": {
"CUDA_VISIBLE_DEVICES": "0"
},
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"justMyCode": false,
"stopOnEntry": false,
"python": "${workspaceFolder}/.venv/bin/python"
},
{
"name": "MiniMind Test (Direct Python)",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/test.py",
"env": {
"CUDA_VISIBLE_DEVICES": "0"
},
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"justMyCode": false,
"python": "${workspaceFolder}/.venv/bin/python"
},
{
"name": "MiniMind Training Debug (Accelerate)",
"type": "python",
"request": "launch",
"module": "accelerate.commands.launch",
"args": [
"--num_processes=1",
"--mixed_precision=bf16",
"${workspaceFolder}/train_pretrain_accelerate.py",
"--epochs", "1",
"--batch_size", "32",
"--learning_rate", "1e-4",
"--log_interval", "10",
"--profile_interval", "2",
"--model_type", "model_original"
],
"env": {
"CUDA_VISIBLE_DEVICES": "0"
},
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"justMyCode": false,
"stopOnEntry": false,
"python": "${workspaceFolder}/.venv/bin/python"
},
{
"name": "MiniMind Test Only",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/test.py",
"env": {
"CUDA_VISIBLE_DEVICES": "0"
},
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"justMyCode": false
}
]
}
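
Note: the first configuration above fixes every pretraining hyperparameter. As a quick sanity check, the sketch below (illustrative only, not part of the repository) derives the effective optimizer batch size implied by those arguments.

    # Sketch: effective batch implied by the "MiniMind Training (Direct Python)" args.
    batch_size = 128            # --batch_size
    accumulation_steps = 16     # --accumulation_steps
    max_seq_len = 512           # --max_seq_len
    num_gpus = 1                # CUDA_VISIBLE_DEVICES=0 -> a single GPU
    effective_batch = batch_size * accumulation_steps * num_gpus   # 2048 samples per optimizer step
    tokens_per_step = effective_batch * max_seq_len                # 1,048,576 tokens per optimizer step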

18
.vscode/settings.json vendored
View File

@@ -1,18 +0,0 @@
{
"python.pythonPath": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
"python.defaultInterpreterPath": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
"python.terminal.activateEnvironment": true,
"python.terminal.activateEnvInCurrentTerminal": true,
"python.linting.enabled": true,
"python.linting.pylintEnabled": false,
"python.linting.flake8Enabled": true,
"python.formatting.provider": "black",
"python.analysis.autoImportCompletions": true,
"python.analysis.typeCheckingMode": "off",
"files.exclude": {
"**/__pycache__": true,
"**/*.pyc": true,
"**/.git": false,
"**/wandb": false
}
}

128
CODE_OF_CONDUCT.md Normal file
View File

@@ -0,0 +1,128 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

1509
README_en.md Normal file

File diff suppressed because it is too large

View File

@@ -1,17 +0,0 @@
compute_environment: LOCAL_MACHINE
deepspeed_config:
deepspeed_config_file: ds_config.json
zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
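
Note: this Accelerate config delegates the ZeRO and optimizer settings to ds_config.json and requests 4 bf16 processes on one machine. A hedged usage sketch follows; the YAML's original filename is not shown in this diff, so accelerate_config.yaml is assumed.

    # Sketch: launch with this config via the Accelerate CLI (filename assumed).
    #   accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py
    # distributed_type: DEEPSPEED means the "auto" fields in ds_config.json are
    # expected to be resolved from the training script / Accelerate side at launch time.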

View File

@@ -1,144 +0,0 @@
import os
import argparse
import torch
from transformers import AutoTokenizer
from model.model import MiniMindLM, ExtractDB
from model.LMConfig import LMConfig
def decode_dataset(model_path, output_path, device="cuda"):
"""
Decode the weight_down_embed buffer in the model to readable text
Args:
model_path: Path to the model checkpoint
output_path: Path to save the decoded text
device: Device to load the model on
"""
print(f"Loading tokenizer from ./model/minimind_tokenizer")
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
print(f"Setting up model configuration")
# Create model configuration matching the training parameters
lm_config = LMConfig(
dim=1024,
n_layers=32,
max_seq_len=1024,
use_flash_attn=True,
knowledge_num=16384, # From the script parameters
knowledge_length=64 # From the script parameters
)
print(f"Initializing model")
model = MiniMindLM(lm_config).to(device)
print(f"Loading model weights from {model_path}")
state_dict = torch.load(model_path, map_location=device)
# Get model parameters
model_state = dict(model.named_parameters())
model_state.update(dict(model.named_buffers()))
# Find parameters with matching names but different shapes
shape_mismatch = {}
for name, param in model_state.items():
if name in state_dict and param.shape != state_dict[name].shape:
shape_mismatch[name] = (param.shape, state_dict[name].shape)
# Find parameters in model but not in state_dict and vice versa
model_only = set(model_state.keys()) - set(state_dict.keys())
state_dict_only = set(state_dict.keys()) - set(model_state.keys())
# Create filtered state_dict with only compatible parameters
filtered_state_dict = {}
for name, param in state_dict.items():
if name in model_state and param.shape == model_state[name].shape:
filtered_state_dict[name] = param
# Print parameter differences
if shape_mismatch:
print(f"Parameters with shape mismatches: {len(shape_mismatch)}")
for name, (model_shape, state_shape) in shape_mismatch.items():
print(f" {name}: model={model_shape}, checkpoint={state_shape}")
if model_only:
print(f"Parameters in model but not in checkpoint: {len(model_only)}")
for name in sorted(model_only):
print(f" {name}: {model_state[name].shape}")
# 特殊处理pos_cis_real参数
if name == "pos_cis_real":
print(f"Detected pos_cis_real parameter. This is a position encoding that will be initialized automatically.")
if state_dict_only:
print(f"Parameters in checkpoint but not in model: {len(state_dict_only)}")
for name in sorted(state_dict_only):
print(f" {name}: {state_dict[name].shape}")
# 如果checkpoint中有output.weight但模型中没有尝试加载到tok_embeddings
if name == "output.weight" and "tok_embeddings.weight" in model_state:
print(f"Found output.weight in checkpoint but not in model. Will try to map it to tok_embeddings.weight")
if model_state["tok_embeddings.weight"].shape == state_dict["output.weight"].shape:
filtered_state_dict["tok_embeddings.weight"] = state_dict["output.weight"]
# Load only the compatible parameters
print(f"Loading {len(filtered_state_dict)}/{len(state_dict)} parameters")
model.load_state_dict(filtered_state_dict, strict=False)
# 检查extract_db和weight_down_embed是否存在
if not hasattr(model, "extract_db"):
print("ERROR: Model does not have extract_db attribute. This is required for decoding.")
return
print("Accessing weight_down_embed buffer")
# Get the weight_down_embed buffer from the model
try:
weight_down_embed = model.extract_db.weight_down_embed
print(f"Successfully accessed weight_down_embed buffer")
except Exception as e:
print(f"ERROR: Failed to access weight_down_embed buffer: {e}")
print(f"Model structure: {model.__class__.__name__}")
print(f"ExtractDB attributes: {dir(model.extract_db)}")
return
print(f"Shape of weight_down_embed: {weight_down_embed.shape}")
print(f"Data type of weight_down_embed: {weight_down_embed.dtype}")
# Create output directory if it doesn't exist
os.makedirs(os.path.dirname(output_path), exist_ok=True)
print(f"Decoding knowledge and writing to {output_path}")
knowledge_num, knowledge_length = weight_down_embed.shape
with open(output_path, 'w', encoding='utf-8') as f:
for i in range(knowledge_num):
try:
# Get token IDs for this knowledge entry
token_ids = weight_down_embed[i].cpu().tolist()
# Decode tokens to text
text = tokenizer.decode(token_ids, skip_special_tokens=True)
# Write to file
f.write(f"Knowledge_{i}: {text}\n")
# Print progress periodically
if (i + 1) % 100 == 0:
print(f"Decoded {i + 1}/{knowledge_num} knowledge entries")
except Exception as e:
print(f"Error decoding knowledge entry {i}: {e}")
f.write(f"Knowledge_{i}: [ERROR DECODING]\n")
print(f"Decoding completed. Output saved to {output_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Decode MiniMind model's knowledge database")
parser.add_argument("--model_path", type=str, default="out/pretrain_1024.pth",
help="Path to the model checkpoint")
parser.add_argument("--output_path", type=str, default="out/knowledge_db.txt",
help="Path to save the decoded text file")
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to load the model on")
args = parser.parse_args()
decode_dataset(args.model_path, args.output_path, args.device)
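
Note: the script above is driven by argparse, so a typical run simply uses its defaults. A short sketch (the deleted file's name is not shown in this diff, so the call is written against the function directly):

    # Sketch: invoke the decoder with the argparse defaults shown above.
    import torch
    decode_dataset(
        model_path="out/pretrain_1024.pth",
        output_path="out/knowledge_db.txt",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )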

View File

@@ -1,49 +0,0 @@
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true
},
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"steps_per_print": 100,
"wall_clock_breakdown": false
}
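
Note: every "auto" field above is resolved at launch time; the invariant DeepSpeed enforces is that the global batch equals the per-GPU micro batch times gradient accumulation times the number of data-parallel processes. A worked check (illustrative, mixing the numbers from launch.json and the 4-process Accelerate config):

    # DeepSpeed batch invariant:
    #   train_batch_size = train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size
    micro_batch_per_gpu = 128      # --batch_size in launch.json
    gradient_accumulation = 16     # --accumulation_steps in launch.json
    world_size = 4                 # num_processes in the Accelerate config
    train_batch_size = micro_batch_per_gpu * gradient_accumulation * world_size   # 8192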

View File

@@ -1,26 +0,0 @@
# 1. 元数据:需要修改,请为该实验配置名称和描述
name: ycz-minimind-test
description: 测试minimind-test
# 2. 运行环境:一般不修改,如有需求可以手动替换为指定镜像
environment:
image: determinedai/pytorch-ngc:0.38.0 # 此项无需修改
# 3. 指定NAS上的数据集: 需要修改仅修改bind_mounts字段container_path和read_only无需修改
#将<YOUR_DATASET_FOLDER_NAME>替换为您存放在NAS上Volume1/Share/datasets/的数据集文件夹名称
# 请再次确保您已在 NAS上的Volume1/Share/datasets/存放了<YOUR_DATASET_FOLDER_NAME>数据集
# 4. 计算资源:无需修改
resources:
slots_per_trial: 1 # 此项无需修改
resource_pool: rtx4090 # 此项无需修改
# 5. 搜索器:无需修改
searcher:
name: single
metric: test_accuracy
smaller_is_better: false
# 6. 启动入口:无需修改
entrypoint: sh startup.sh
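
Note: this is a Determined AI experiment config (the Chinese comments mark which fields to edit: the name/description, the NAS dataset bind mount, and nothing else). A hedged submission sketch, assuming the file is saved as determined.yaml:

    # Sketch: submit the experiment with the Determined CLI (assumed filename and context dir).
    #   det experiment create determined.yaml .
    # The entrypoint "sh startup.sh" then runs inside the determinedai/pytorch-ngc:0.38.0
    # image on a single rtx4090 slot, as declared in the resources section above.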

New binary files (contents not shown in this diff):

BIN  images/1-wiki.png           Normal file       136 KiB
BIN  images/2-wiki.png           Normal file       73 KiB
BIN  images/3-wiki.png           Normal file       230 KiB
BIN  images/4-wiki.png           Normal file       104 KiB
BIN  images/5-wiki.png           Normal file       239 KiB
BIN  (file name not shown)                         121 KiB
BIN  images/LLM-structure.png    Executable file   372 KiB
BIN  images/and_huggingface.png  Normal file       178 KiB
BIN  images/and_modelscope.png   Normal file       150 KiB
BIN  images/compare_radar.png    Normal file       519 KiB
BIN  images/dataset.jpg          Normal file       146 KiB
BIN  images/gpt3_config.png      Normal file       66 KiB
BIN  images/logo.png             Normal file       495 KiB
BIN  images/logo2.png            Normal file       615 KiB
BIN  images/minimind2.gif        Normal file       3.8 MiB
BIN  images/pre_512_loss.png     Normal file       559 KiB
BIN  images/pre_768_loss.png     Normal file       531 KiB
BIN  images/sft_512_loss.png     Normal file       1006 KiB
BIN  images/sft_768_loss.png     Normal file       943 KiB

View File

@@ -1,6 +0,0 @@
def main():
print("Hello from minimind!")
if __name__ == "__main__":
main()

View File

@@ -9,8 +9,8 @@ class LMConfig(PretrainedConfig):
self,
dim: int = 512,
n_layers: int = 8,
n_heads: int = 32,
n_kv_heads: int = 8,
n_heads: int = 8,
n_kv_heads: int = 2,
vocab_size: int = 6400,
hidden_dim: int = None,
multiple_of: int = 64,
@@ -19,7 +19,6 @@ class LMConfig(PretrainedConfig):
rope_theta: int = 1e6,
dropout: float = 0.0,
flash_attn: bool = True,
embeddings_epoch: int = 2,
####################################################
# DB related configurations
####################################################
@@ -37,16 +36,6 @@ class LMConfig(PretrainedConfig):
aux_loss_alpha: float = 0.1,
seq_aux: bool = True,
norm_topk_prob: bool = True,
####################################################
knowledge_num: int = 64*64,
knowledge_length: int = 8,
knowledge_dim: int = 128,
####################################################
# Triple extraction related configurations
####################################################
max_subject_len: int = 8,
max_predicate_len: int = 4,
max_object_len: int = 8,
**kwargs,
):
self.dim = dim
@@ -61,7 +50,6 @@ class LMConfig(PretrainedConfig):
self.rope_theta = rope_theta
self.dropout = dropout
self.flash_attn = flash_attn
self.embeddings_epoch = embeddings_epoch
####################################################
# DB related configurations
####################################################
@@ -78,14 +66,4 @@
self.aux_loss_alpha = aux_loss_alpha # 辅助损失的alpha参数
self.seq_aux = seq_aux # 是否在序列级别上计算辅助损失
self.norm_topk_prob = norm_topk_prob # 是否标准化top-k概率
####################################################
self.knowledge_num = knowledge_num
self.knowledge_length = knowledge_length
self.knowledge_dim = knowledge_dim
####################################################
# Triple extraction related configurations
####################################################
self.max_subject_len = max_subject_len
self.max_predicate_len = max_predicate_len
self.max_object_len = max_object_len
super().__init__(**kwargs)
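
Note: with the n_heads=8, n_kv_heads=2 defaults shown on one side of this hunk (versus 32 and 8 on the other), the model keeps a 4:1 grouped-query ratio while the per-head dimension grows. A quick check of the derived shapes (sketch):

    dim, n_heads, n_kv_heads = 512, 8, 2   # defaults above
    head_dim = dim // n_heads              # 64 (it would be 16 with n_heads=32)
    n_rep = n_heads // n_kv_heads          # 4 query heads share each KV head (GQA)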

View File

@@ -9,73 +9,8 @@ import torch
from sklearn.model_selection import train_test_split
import os
import ast
from tqdm import tqdm
os.environ["TOKENIZERS_PARALLELISM"] = "true"
def process_sample_filter(data_args):
"""处理单个样本的过滤逻辑"""
sample, valid_predicates = data_args
if 'target' in sample and isinstance(sample['target'], list):
# 过滤target中的低频谓词
valid_targets = []
for triple in sample['target']:
if isinstance(triple, dict) and 'predicate' in triple:
if triple['predicate'] in valid_predicates:
valid_targets.append(triple)
# 如果还有有效的target保留这个样本
if valid_targets:
sample['target'] = valid_targets
return sample
else:
return None
else:
# 如果没有target信息保留样本
return sample
def process_sample_validation(data_args):
"""处理单个样本的验证逻辑"""
sample, predicate_vocab = data_args
if not isinstance(sample, dict) or 'text' not in sample:
return None
targets = sample.get('target', [])
if not isinstance(targets, list) or len(targets) == 0:
# 如果没有有效的target创建一个默认的
selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
else:
# 验证并选择target优先选择占比小的谓词
selected_target = None
min_percentage = float('inf')
for triple in targets:
if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
predicate = triple['predicate']
# 使用predicate_vocab中的统计信息
if predicate in predicate_vocab:
stats = predicate_vocab[predicate]
if isinstance(stats, dict) and 'percentage' in stats:
percentage = stats['percentage']
if percentage < min_percentage:
min_percentage = percentage
selected_target = triple
elif selected_target is None:
selected_target = triple
elif selected_target is None:
selected_target = triple
# 如果没有找到有效的target使用默认值
if selected_target is None:
selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
return {
'text': sample['text'],
'target': selected_target # 只保留一个target
}
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class PretrainDataset(Dataset):
@@ -98,14 +33,9 @@ class PretrainDataset(Dataset):
def __getitem__(self, index):
sample = self.samples[index]
text = str(sample['text'])
# 检查并添加<|im_start|>和<|im_end|>如果不存在
if not text.startswith(self.tokenizer.bos_token):
text = f"{self.tokenizer.bos_token}{text}"
if not text.endswith(self.tokenizer.eos_token):
text = f"{text}{self.tokenizer.eos_token}"
# 构建输入文本
text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
encoding = self.tokenizer(
text,
max_length=self.max_length,
@@ -128,8 +58,8 @@ class SFTDataset(Dataset):
self.tokenizer = tokenizer
self.max_length = max_length
self.samples = self.load_data(jsonl_path)
self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids
self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
def __len__(self):
return len(self.samples)
@@ -196,8 +126,8 @@ class DPODataset(Dataset):
self.tokenizer = tokenizer
self.max_length = max_length
self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids
self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
with open(file_path, 'r', encoding='utf-8') as f:
self.data = []
for line in f:
@@ -266,249 +196,14 @@ class DPODataset(Dataset):
return loss_mask
class TriplePretrainDataset(Dataset):
"""
优化的三元组预训练数据集
- 每个样本只保留一个target三元组
- 预先tokenize所有数据
- 使用进度条显示处理进度
"""
def __init__(self, data_path=None, predicate_vocab_path=None, samples = None,tokenizer=None, max_length=512):
super().__init__()
self.tokenizer = tokenizer
self.max_length = max_length
self.val_samples = None
self.predicate_to_id = {} # 初始化
if samples is None:
self.predicate_vocab = self.load_predicate_vocab(predicate_vocab_path)
print("🚀 开始加载和预处理三元组数据...")
self.samples,self.val_samples = self.load_and_preprocess_data(data_path)
print("🚀 加载和预处理三元组数据完成")
else:
cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
data_filename = os.path.basename(data_path).split('.')[0]
predicate_to_id_path = os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json')
self.predicate_to_id = self.load_predicate_vocab(predicate_to_id_path)
self.samples = samples
print("🚀 加载和预处理三元组数据完成")
def load_predicate_vocab(self, path):
with open(path, 'r', encoding='utf-8') as f:
predicate_vocab = json.load(f)
return predicate_vocab
def get_val_samples(self):
return self.val_samples
def clear_cache(self, data_path):
"""清除缓存文件"""
cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
data_filename = os.path.basename(data_path).split('.')[0]
cache_files = [
os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
os.path.join(cache_dir, f'{data_filename}_val_samples.json')
]
for cache_file in cache_files:
if os.path.exists(cache_file):
os.remove(cache_file)
print(f"🗑️ 已删除缓存文件: {cache_file}")
if os.path.exists(cache_dir) and not os.listdir(cache_dir):
os.rmdir(cache_dir)
print(f"🗑️ 已删除空的缓存目录: {cache_dir}")
def load_and_preprocess_data(self, path):
"""加载并预处理三元组数据"""
# 生成缓存文件名(基于数据文件路径)
cache_dir = os.path.join(os.path.dirname(path), 'cache')
os.makedirs(cache_dir, exist_ok=True)
data_filename = os.path.basename(path).split('.')[0]
cache_files = {
'predicate_vocab': os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
'predicate_to_id': os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
'train_samples': os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
'val_samples': os.path.join(cache_dir, f'{data_filename}_val_samples.json')
}
# 检查缓存文件是否存在
cache_exists = all(os.path.exists(cache_file) for cache_file in cache_files.values())
if cache_exists:
print("📁 发现缓存文件,直接加载...")
# 从缓存加载
with open(cache_files['predicate_vocab'], 'r', encoding='utf-8') as f:
self.predicate_vocab = json.load(f)
with open(cache_files['predicate_to_id'], 'r', encoding='utf-8') as f:
self.predicate_to_id = json.load(f)
with open(cache_files['train_samples'], 'r', encoding='utf-8') as f:
train_samples = json.load(f)
with open(cache_files['val_samples'], 'r', encoding='utf-8') as f:
val_samples = json.load(f)
print(f"✅ 从缓存加载完成:")
print(f"✅ 谓词词表大小: {len(self.predicate_vocab)}")
print(f"✅ 训练集大小: {len(train_samples)}")
print(f"✅ 测试集大小: {len(val_samples)}")
return train_samples, val_samples
# 缓存不存在,重新处理数据
print("📂 缓存不存在,开始加载和处理原始数据...")
# 1. 加载原始数据
print("📂 加载原始数据...")
if path.endswith('.json'):
with open(path, 'r', encoding='utf-8') as f:
data = json.load(f)
elif path.endswith('.jsonl'):
data = []
with open(path, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
data.append(json.loads(line.strip()))
else:
raise ValueError(f"Unsupported file format: {path}")
print(f"📊 原始数据量: {len(data)} 个样本")
# 2. 使用self.predicate_vocab过滤占比小于0.01%的谓词数据
print("🔍 过滤低频谓词数据...")
print(f"📊 谓词统计数据: 总共{len(self.predicate_vocab)}个谓词")
# 3.获取占比大于等于0.01%的谓词
valid_predicates = set()
for predicate, stats in self.predicate_vocab.items():
if isinstance(stats, dict) and 'percentage' in stats:
if stats['percentage'] >= 0.01:
valid_predicates.add(predicate)
else:
# 如果不是统计格式,假设是有效谓词
valid_predicates.add(predicate)
print(f"📊 占比≥0.01%的谓词: {len(valid_predicates)}")
# 4.过滤数据:去除包含低频谓词的数据(单进程处理)
original_count = len(data)
filtered_data = []
print("🚀 开始过滤低频谓词数据...")
for sample in tqdm(data, desc="过滤低频谓词"):
result = process_sample_filter((sample, valid_predicates))
if result is not None:
filtered_data.append(result)
data = filtered_data
print(f"✅ 过滤完成: 去除前{original_count}条,去除后{len(data)}")
# 5. 去除self.predicate_vocab中占比小于0.01%的谓词,并创建谓词到序号的映射
print("🔍 更新谓词词表并创建序号映射...")
original_vocab_size = len(self.predicate_vocab)
filtered_predicate_vocab = {}
for predicate, stats in self.predicate_vocab.items():
if isinstance(stats, dict) and 'percentage' in stats:
if stats['percentage'] >= 0.01:
filtered_predicate_vocab[predicate] = stats
else:
# 如果不是统计格式,保留
filtered_predicate_vocab[predicate] = stats
# 创建谓词到序号的映射字典
self.predicate_to_id = {predicate: idx for idx, predicate in enumerate(filtered_predicate_vocab.keys())}
self.predicate_vocab = filtered_predicate_vocab
print(f"✅ 谓词词表更新: 去除前{original_vocab_size}个,去除后{len(self.predicate_vocab)}")
print(f"✅ 谓词映射创建: {len(self.predicate_to_id)}个谓词对应序号")
# 6. 数据验证和筛选只保留一个target优先选择占比小的谓词以平衡数据单进程处理
print("🔍 验证数据格式并选择单个target平衡数据...")
valid_samples = []
print("🚀 开始验证数据格式...")
for sample in tqdm(data, desc="验证数据格式"):
result = process_sample_validation((sample, self.predicate_vocab))
if result is not None:
valid_samples.append(result)
print(f"✅ 有效样本数: {len(valid_samples)}")
# 7.拆分训练集合与测试集合
import random
random.seed(42)
val_samples = random.sample(valid_samples, min(1000, len(valid_samples)))
train_samples = [sample for sample in valid_samples if sample not in val_samples]
print(f"✅ 训练集大小: {len(train_samples)}")
print(f"✅ 测试集大小: {len(val_samples)}")
# 8. 保存到缓存文件
print("💾 保存处理结果到缓存文件...")
with open(cache_files['predicate_vocab'], 'w', encoding='utf-8') as f:
json.dump(self.predicate_vocab, f, ensure_ascii=False, indent=2)
with open(cache_files['predicate_to_id'], 'w', encoding='utf-8') as f:
json.dump(self.predicate_to_id, f, ensure_ascii=False, indent=2)
with open(cache_files['train_samples'], 'w', encoding='utf-8') as f:
json.dump(train_samples, f, ensure_ascii=False, indent=2)
with open(cache_files['val_samples'], 'w', encoding='utf-8') as f:
json.dump(val_samples, f, ensure_ascii=False, indent=2)
print("✅ 缓存文件保存完成")
return train_samples, val_samples
def __len__(self):
return len(self.samples)
def _triple_to_sentence(self, triple):
"""将三元组转换为句子格式"""
return f"{triple['subject']} {triple['predicate']} {triple['object']}"
def __getitem__(self, index):
"""返回数据,用于谓词分类任务"""
sample = self.samples[index]
# 在运行时tokenize输入文本
input_text = f"{self.tokenizer.bos_token}{sample['text']}{self.tokenizer.eos_token}"
encoding = self.tokenizer(
input_text,
max_length=self.max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
input_ids = encoding.input_ids.squeeze()
loss_mask = (input_ids != self.tokenizer.pad_token_id)
# 获取谓词分类标签
target_predicate = sample['target']['predicate']
predicate_label = self.predicate_to_id.get(target_predicate) # 默认为0如果找不到
# 构建训练数据
X = input_ids[:-1]
loss_mask = loss_mask[1:]
return {
'input_ids': X,
'labels': torch.tensor(predicate_label, dtype=torch.long), # 谓词分类标签
'loss_mask': loss_mask
}
class RLAIFDataset(Dataset):
def __init__(self, jsonl_path, tokenizer, max_length=1024):
super().__init__()
self.tokenizer = tokenizer
self.max_length = max_length
self.samples = self.load_data(jsonl_path)
self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids
self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
def __len__(self):
return len(self.samples)
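
Note: after this change the assistant span markers are the plain token sequences for '<s>assistant' and '</s>' (bos_id / eos_id above). The loss-mask construction itself is not shown in this hunk; the sketch below illustrates the usual approach of masking the loss to tokens between those markers (illustrative, not the repository's own implementation).

    # Sketch: mark only the assistant response tokens (between bos_id and eos_id) as trainable.
    def build_loss_mask(input_ids, bos_id, eos_id):
        mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(bos_id)] == bos_id:
                start = i + len(bos_id)
                end = start
                while end < len(input_ids) and input_ids[end:end + len(eos_id)] != eos_id:
                    end += 1
                for j in range(start, min(end + len(eos_id), len(input_ids))):
                    mask[j] = 1
                i = end + len(eos_id)
            else:
                i += 1
        return mask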

View File

@@ -14,7 +14,7 @@
},
{
"id": 1,
"content": "<|im_start|>",
"content": "<s>",
"single_word": false,
"lstrip": false,
"rstrip": false,
@@ -23,7 +23,7 @@
},
{
"id": 2,
"content": "<|im_end|>",
"content": "</s>",
"single_word": false,
"lstrip": false,
"rstrip": false,
@@ -56,8 +56,8 @@
"ignore_merges": false,
"vocab": {
"<unk>": 0,
"<|im_start|>": 1,
"<|im_end|>": 2,
"<s>": 1,
"</s>": 2,
"!": 3,
"\"": 4,
"#": 5,

View File

@@ -12,7 +12,7 @@
"special": true
},
"1": {
"content": "<|im_start|>",
"content": "<s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
@@ -20,7 +20,7 @@
"special": true
},
"2": {
"content": "<|im_end|>",
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
@@ -29,9 +29,9 @@
}
},
"additional_special_tokens": [],
"bos_token": "<|im_start|>",
"bos_token": "<s>",
"clean_up_tokenization_spaces": false,
"eos_token": "<|im_end|>",
"eos_token": "</s>",
"legacy": true,
"model_max_length": 32768,
"pad_token": "<unk>",
@@ -39,5 +39,5 @@
"spaces_between_special_tokens": false,
"tokenizer_class": "PreTrainedTokenizerFast",
"unk_token": "<unk>",
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\n你是 MiniMind是一个有用的人工智能助手。<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<s>system\\n' + system_message + '</s>\\n' }}{% else %}{{ '<s>system\\n你是 MiniMind是一个有用的人工智能助手。</s>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
}
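
Note: the chat_template above now wraps every turn as <s>role\n ... </s> instead of the <|im_start|>/<|im_end|> markers. A hedged usage sketch with transformers (paths assumed):

    # Sketch: render a conversation with the updated template.
    from transformers import AutoTokenizer
    tok = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")
    msgs = [{"role": "user", "content": "你好"}]
    text = tok.apply_chat_template(msgs, tokenize=False)
    # Per the template above, `text` takes the form:
    #   <s>system\n...(default system prompt)...</s>\n<s>user\n你好</s>\n<s>assistant\n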

File diff suppressed because one or more lines are too long

View File

@@ -2,8 +2,7 @@ import math
import struct
import inspect
import time
import gc
#子空间二维分解+梯度更新
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
@@ -12,9 +11,14 @@ import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch import nn, einsum
from einops import rearrange, repeat
def exists(val):
return val is not None
# RMSNorm 类定义了一个用于归一化输入张量的模块。
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
@@ -27,7 +31,7 @@ class RMSNorm(torch.nn.Module):
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
# precompute_pos_cis 函数用于预计算位置编码。
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
@@ -35,7 +39,7 @@ def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis
# apply_rotary_emb 函数用于应用旋转位置编码。
def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x):
ndim = x.ndim
@@ -51,244 +55,18 @@ def apply_rotary_emb(xq, xk, pos_cis):
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class KnowledgeDataset(nn.Module):
def __init__(self, params, tok_embeddings, is_train=True):
super().__init__()
self.is_train = is_train
self.params = params
self.tok_embeddings = tok_embeddings
# repeat_kv 函数用于重复键值对。
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
# 嵌入参数
self.knowledge_dim = params.knowledge_dim
self.key_dim = self.knowledge_dim // 2
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.knowledge_dim, bias=False),
)
## 数据库参数
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
# 修改键存储为二维分解空间,设置为可训练参数
self.num_keys = int(math.sqrt(self.knowledge_num))
# 确保keys是可训练参数
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.num_keys)
# 知识库存储 - 使用register_buffer因为这是整数索引不需要梯度
self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
# 计算step数目用于动态调整权重
self.step_counter = 0
# 移除批次计数器和更新频率相关代码
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
if self.is_train == False:
return all_scores, all_indices
batch_size = all_scores.size(0)
device = all_scores.device
dtype = all_scores.dtype
# 记录进入智能选择前的内存状态
if hasattr(self, 'step_counter'):
self.step_counter += 1
# 禁用GPU内存监控记录以提高性能
# if self.step_counter % 50 == 0: # 每50次调用记录一次
# if torch.cuda.is_available():
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
# print(f"[INTEL_SELECT_ENTER] Step {self.step_counter}: GPU Memory: {allocated_before:.2f}GB")
# 对每个batch进行分层选择
enhanced_scores = all_scores.clone()
query_features = query.mean(dim=1) # [batch_size, dim]
# 预先计算所有候选条目的嵌入(批量优化)
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
# 批量计算唯一候选条目的嵌入
candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
# 获取flat_tokens对应的index保留这些变量以便其他地方使用
pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
)
unique_candidate_features = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
).mean(dim=1) # [num_unique_candidates, dim]
# 归一化候选特征(优化相似度计算)
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
normalized_queries = F.normalize(query_features, dim=-1)
# 收集所有batch的best_tokens
batch_best_tokens = []
batch_best_tokens_embeddings = []
for batch_idx in range(batch_size):
indices = all_indices[batch_idx]
# 获取当前batch候选条目对应的特征索引
start_idx = batch_idx * len(indices)
end_idx = start_idx + len(indices)
batch_inverse_indices = inverse_indices[start_idx:end_idx]
# 使用预计算的归一化特征进行优化相似度计算
batch_candidate_features = normalized_candidates[batch_inverse_indices]
query_feature = normalized_queries[batch_idx]
# 使用矩阵乘法计算余弦相似度
similarity_scores = torch.mv(batch_candidate_features, query_feature)
# 找到最大相似度分数的索引
max_similarity_idx = torch.argmax(similarity_scores)
# 获取最大相似度对应的候选条目索引
best_candidate_idx = indices[max_similarity_idx]
# 获取对应的tokens
best_tokens = self.knowledge_dataset[best_candidate_idx]
best_tokens_embeddings = self.tok_embeddings(best_tokens)
# 将当前batch的best_tokens添加到列表中
batch_best_tokens.append(best_tokens)
batch_best_tokens_embeddings.append(best_tokens_embeddings)
# 将所有batch的best_tokens堆叠成一个张量
# [batch_size, knowledge_length]
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
# 清理中间张量以防止内存泄漏
del all_candidate_indices, unique_indices, inverse_indices
del unique_candidate_features, normalized_candidates, normalized_queries
del batch_best_tokens, batch_best_tokens_embeddings
del flat_tokens, flat_embeddings, pre_update_embeddings
# 记录退出智能选择后的内存状态(已禁用以提高性能)
# if hasattr(self, 'step_counter') and self.step_counter % 50 == 0:
# if torch.cuda.is_available():
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
# print(f"[INTEL_SELECT_EXIT] Step {self.step_counter}: GPU Memory: {allocated_after:.2f}GB")
# 强制垃圾回收(仅在监控步骤)
if hasattr(self, 'step_counter') and self.step_counter % 100 == 0:
gc.collect()
# if torch.cuda.is_available():
# torch.cuda.empty_cache()
return all_best_tokens, all_best_tokens_embeddings
def search_index(self, x):
batch_size, seq_len, dim = x.shape
# 1. 序列维度平均
x_flat = x.mean(dim=1) # [batch_size, dim]
# 2. 生成查询向量并重塑为两个子查询
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim]
# 调整维度顺序,使子空间维度位于首位
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
# 3. 计算每个子空间的相似度
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 4. 在两个子空间分别做top-k
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 5. 组合两个子空间的结果
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
# 6. 将结果重塑为二维
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
# 7. 选择最终的top-k结果
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
indices = torch.gather(all_indices, 1, indices_of_indices)
# 8. 应用智能分层选择策略
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
return best_tokens, best_tokens_embeddings
class CrossAttention(nn.Module):
def __init__(
self,
config
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = self.config.dim // self.num_heads
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
# 监控交叉注意力开始时的内存(已禁用以提高性能)
if not hasattr(self, 'call_counter'):
self.call_counter = 0
self.call_counter += 1
# 禁用GPU内存监控记录以提高性能
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
# print(f"[CROSS_ATTN_ENTER] Call {self.call_counter}: GPU Memory: {allocated_before:.2f}GB")
# 分离多头
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
context = self.to_out(context)
# 清理中间张量
del q, k, v, attn_scores, attn_weights
# 监控交叉注意力结束时的内存(已禁用以提高性能)
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
# print(f"[CROSS_ATTN_EXIT] Call {self.call_counter}: GPU Memory: {allocated_after:.2f}GB")
return context
class Attention(nn.Module):
def __init__(self, args: LMConfig):
@@ -314,14 +92,58 @@ class Attention(nn.Module):
def forward(self,
x: torch.Tensor,
pos_cis: torch.Tensor):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
pos_cis: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache=True,
db_value=None):
bsz, seq_len, _ = x.shape #bsz: 批量大小, seq_len: 序列长度, _: 隐藏维度
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) #将输入张量x分别通过线性层wq, wk, wv进行变换得到查询、键和值。
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim) #将变换后的张量xq重塑为形状为(bsz, seq_len, n_local_heads, head_dim)的形状。
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xk重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xv重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。
# 应用旋转位置编码
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
# kv_cache实现
if past_key_value is not None:
xk = torch.cat([past_key_value[0], xk], dim=1)
xv = torch.cat([past_key_value[1], xv], dim=1)
past_kv = (xk, xv) if use_cache else None
# 重复键值对
xq, xk, xv = (
xq.transpose(1, 2),
repeat_kv(xk, self.n_rep).transpose(1, 2),
repeat_kv(xv, self.n_rep).transpose(1, 2)
)
# 如果提供了db_value根据头的数量调整它的形状并与xv合并
if db_value is not None:
# 确保db_value的形状与xv兼容假设db_value形状为[B, N, H, D]
if db_value.ndim == 4: # [B, N, H, D]
db_value = db_value.transpose(1, 2) # -> [B, H, N, D]
# 检查是否需要调整D维度
if db_value.shape[-1] != xv.shape[-1]:
# 如果db_value的维度与xv不同可以添加一个投影层
# 或者在这里使用简单的调整方法
# 这里我们简单地通过均值池化或重复来调整维度
if db_value.shape[-1] > xv.shape[-1]:
# 降维
factor = db_value.shape[-1] // xv.shape[-1]
db_value = db_value.view(bsz, self.n_local_heads, seq_len, factor, xv.shape[-1])
db_value = db_value.mean(dim=3)
else:
# 升维
factor = xv.shape[-1] // db_value.shape[-1]
db_value = db_value.unsqueeze(-1).repeat(1, 1, 1, 1, factor)
db_value = db_value.view(bsz, self.n_local_heads, seq_len, xv.shape[-1])
# 将db_value与xv相加或融合
# 这里我们简单地将它们相加,但你也可以使用其他融合方法
xv = xv + db_value
# 使用Flash Attention
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
@@ -339,9 +161,56 @@ class Attention(nn.Module):
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output
return output, past_kv
class CrossAttention(nn.Module):
def __init__(
self,
config
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = self.config.dim // self.num_heads
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
# 分离多头
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
context = self.to_out(context)
return context
class FeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
@@ -474,31 +343,168 @@ class MOEFeedForward(nn.Module):
class MiniMindBlock(nn.Module):
def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
def __init__(self, layer_id: int, config: LMConfig):
super().__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
self.self_attention = Attention(config)
self.cross_attention = CrossAttention(config)
self.knowledge_dataset = knowledge_dataset
self.attention = Attention(config)
self.cross_att = CrossAttention(config)
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
# 假设num_experts是已定义的总专家数量的平方根
# 查询生成的参数
# 创建查询生成模块
# if weight_down_embed is not None:
# self.to_queries = nn.Sequential(
# nn.Linear(config.dim, self.dim_key * 2, bias=False),
# # nn.Unflatten(2, (2, self.n_heads, self.dim_key)) # 替代Rearrange
# )
# # 超参数
# self.product_key_topk = min(16, self.num_keys) # 确保不超过num_keys
# self.num_experts_per_head_topk = 1 # 最终每个头选取的专家数
def forward(self, x, pos_cis):
h_attn = self.self_attention(
def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=True):
# import pdb;pdb.set_trace()
# db_value = None
# # 如果有weight_down_embed使用Product Key机制
# if self.weight_down_embed is not None:
# # 1. 生成queries
# batch_size, seq_len, dim = x.shape
# # collapse sequence dimension by averaging
# x_flat = x.mean(dim=1) # [batch_size, dim]
# queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
# queries = queries.reshape(batch_size, 2, self.dim_key) # [batch_size, 2, dim_key]
# queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key]
# # 2. 计算queries与keys的相似度
# sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# # 3. 在两个子空间分别做top-k
# scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
# scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
# indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# # 4. 组合两个子空间的分数和索引
# all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
# all_scores = all_scores.view(*all_scores.shape[:-2], -1)
# all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
# all_indices = all_indices.view(*all_indices.shape[:-2], -1)
# # 5. 最终top-k选择
# scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
# indices = all_indices.gather(-1, pk_indices)
# # 6. 从embedding中获取专家值
# # 从embedding中获取值
# flat_indices = indices.view(-1) # 将索引展平为一维张量
# db_values = self.weight_down_embed(flat_indices)
# # 重塑回原始形状
# db_value = db_values.view(batch_size, -1, dim)
# 注意力计算
h_attn, past_kv = self.attention(
self.attention_norm(x),
pos_cis
pos_cis,
past_key_value=past_key_value,
use_cache=use_cache,
db_value=db_value
)
db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
h_attn = self.cross_attention(h_attn, db_embeddings)
h = x + h_attn
out = h + self.feed_forward(self.ffn_norm(h))
return out
h_attn = self.cross_att(h_attn, db_value)
# 残差连接
h = x + h_attn
# 前馈神经网络
out = h + self.feed_forward(self.ffn_norm(h))
return out, past_kv
class ExtractDB(nn.Module):
def __init__(self,params):
# 修改专家数量和知识维度,确保能开方
super().__init__()
self.batch_size = None
self.dim = params.dim
self.dim_key = self.dim // 2
self.num_experts = 10 * 10 # 100专家确保是完全平方数
# 将knowledge_dim设置为与head_dim相同以便在attention中直接使用
self.head_dim = params.dim // params.n_heads
self.knowledge_dim = 8*params.dim
# 使用register_buffer代替nn.Parameter避免梯度问题
self.register_buffer('weight_down_embed', torch.randn(self.num_experts, self.knowledge_dim) * 0.02)
self.num_keys = int(math.sqrt(self.num_experts)) if self.num_experts > 0 else 0
self.product_key_topk = min(16, self.num_keys)
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02)
self.num_experts_per_head_topk = 1
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.dim_key * 2, bias=False),
)
def q_to_k(self,x):
# 1. 生成queries
self.batch_size, seq_len, dim = x.shape
# collapse sequence dimension by averaging
x_flat = x.mean(dim=1) # [batch_size, dim]
queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
queries = queries.reshape(self.batch_size, 2, self.dim_key) # [batch_size, 2, dim_key]
queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key]
# 2. 计算queries与keys的相似度
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 3. 在两个子空间分别做top-k
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 4. 组合两个子空间的分数和索引
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
all_scores = all_scores.view(*all_scores.shape[:-2], -1)
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
all_indices = all_indices.view(*all_indices.shape[:-2], -1)
# 5. 最终top-k选择
scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
indices = all_indices.gather(-1, pk_indices)
flat_indices = indices.view(-1)
return flat_indices
def get_data(self, index):
# 直接从GPU获取embedding
db_values = self.weight_down_embed[index]
db_value = db_values.view(self.batch_size, -1, self.dim)
return db_value
@torch.no_grad()
def updata_value(self, k, v):
# 直接更新buffer上的值 (不需要梯度)
v_reshaped = v.view(v.size(0), -1)
# 确保数据类型匹配
v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
self.weight_down_embed[k] = v_reshaped
class MiniMindLM(PreTrainedModel):
config_class = LMConfig
@@ -509,63 +515,130 @@ class MiniMindLM(PreTrainedModel):
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
# 移除旧的weight_down_embed声明
self.extract_db = ExtractDB(self.params)
# 将self.weight_down_embed传递给每个MiniMindBlock
self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.tok_embeddings.weight = self.output.weight
# Calculate input dimension
input_dim = (self.params.max_seq_len-1)*self.params.n_layers
# Use a bottleneck architecture to reduce parameters
bottleneck_dim = 256 # Significantly smaller bottleneck dimension
# Factorized shared downsampling using two smaller convolutions
self.shared_downsample = nn.Sequential(
# First reduce input dimension to bottleneck
nn.Conv1d(input_dim, bottleneck_dim, kernel_size=1, padding='same'),
nn.ReLU(), # Non-linearity to improve representation capacity
# Then expand to target dimension
nn.Conv1d(bottleneck_dim, 128*8, kernel_size=1, padding='same')
)
# Specific layers for v path
self.downsample_v_specific = nn.Sequential(
nn.Conv1d(128*8, 128, kernel_size=1, padding='same'),
nn.Conv1d(128, 8, kernel_size=1, padding='same')
)
# Specific layers for q path
self.downsample_q_specific = nn.Sequential(
nn.Conv1d(128*8, 512, kernel_size=1, padding='same')
)
self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.OUT = CausalLMOutputWithPast()
self.freeze_embedding = False
self.params = params
def forward(self,
input_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
logits_to_keep: Union[int, torch.Tensor] = 0,
step: int = 0,
**args):
past_key_values = past_key_values or [None] * len(self.layers)
start_pos = args.get('start_pos', 0)
# if self.freeze_embedding and step == 0:
# self.tok_embeddings.weight.requires_grad = False
# # 移除对knowledge_dataset.freeze_embedding的设置让键更新由batch_counter控制
# # self.knowledge_dataset.freeze_embedding = True
# print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
past_kvs = []
h_list = []
for l, layer in enumerate(self.layers):
h = layer(
h, pos_cis
# 禁用数据库模式,使用固定值替代数据库查询
if self.params.disable_db:
# 创建一个形状为[batch_size, n_layers, dim]的tensor所有元素值为1e-4
batch_size = h.size(0)
db_value = torch.full((batch_size, self.n_layers, self.params.dim), 1e-4,
dtype=h.dtype, device=h.device)
else:
# 正常模式,使用数据库查询
index = self.extract_db.q_to_k(h)
db_value = self.extract_db.get_data(index)
h, past_kv = layer(
h, db_value, pos_cis,
past_key_value=past_key_values[l],
use_cache=use_cache
)
past_kvs.append(past_kv)
h_list.append(h.unsqueeze(0))
h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3)
# 只在非禁用数据库模式下执行数据库更新逻辑
if not self.params.disable_db:
# 使用detach()分离计算图,避免多次反向传播
h_tensor_detached = h_tensor.detach()
h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0], -1, self.params.dim)
# 数据库更新逻辑与主计算图分离
with torch.no_grad():
# Compute shared downsampling layer once
shared_features = self.shared_downsample(h_tensor_detached)
z_v = self.downsample_v_specific(shared_features)
z_q = self.downsample_q_specific(shared_features)
z_k = self.extract_db.q_to_k(z_q)
self.extract_db.updata_value(z_k, z_v)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
# aux_loss is uniformly disabled
aux_loss = 0
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
# Further simplified: keep only the required fields
output = CausalLMOutputWithPast(
logits=logits,
past_key_values=past_kvs,
)
output.hidden_states = h
output.aux_loss = aux_loss
# Try to attach the extra attributes (if supported)
# try:
# output.hidden_states = h
# output.aux_loss = aux_loss
# except:
# pass
return output
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
stream=False, rp=1., use_cache=True, pad_token_id=0, num_return_sequences=1, **args):
# Streaming generation
if stream:
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
# Direct (non-streaming) generation
generated = []
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
for _ in range(num_return_sequences):
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1)
@ -582,20 +655,17 @@ class MiniMindLM(PreTrainedModel):
res = output.view(input_ids.size(0) * num_return_sequences, -1)
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
start = input_ids.shape[1]
for _ in range(max_new_tokens):
# Feed the full input_ids at every step; no KV cache is used
out = self(input_ids, **args)
logits = out.logits[:, -1, :] # take the logits at the last position
# Repetition penalty
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
start, first_seq, past_kvs = input_ids.shape[1], True, None
while input_ids.shape[1] < max_new_tokens - 1:
if first_seq or not use_cache:
out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False
else:
out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
start_pos=input_ids.shape[1] - 1, **args)
logits, past_kvs = out.logits[:, -1, :], out.past_key_values
logits[:, list(set(input_ids.tolist()[0]))] /= rp
# Temperature scaling
logits /= (temperature + 1e-9)
# Top-p (nucleus) sampling
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
@ -605,14 +675,8 @@ class MiniMindLM(PreTrainedModel):
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
# Sample the next token
input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
# Yield the newly generated portion
yield input_ids[:, start:]
# Stop once the EOS token is generated
if input_ids_next.item() == eos_token_id:
break
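The rewritten _stream above switches from recomputing the full prompt at every step to incremental decoding with a KV cache: the prompt is run once (prefill), and afterwards only the newest token is fed together with start_pos so the rotary positions stay aligned. A minimal standalone sketch of that pattern, assuming a hypothetical `model` with the forward signature shown above (greedy decoding for brevity):

import torch

@torch.inference_mode()
def cached_decode(model, input_ids, max_new_tokens=16, use_cache=True):
    past_kvs = None
    for _ in range(max_new_tokens):
        if past_kvs is None or not use_cache:
            # prefill: run the whole prompt once and keep the per-layer (k, v) tensors
            out = model(input_ids, past_key_values=past_kvs, use_cache=use_cache)
        else:
            # decode: feed only the newest token; start_pos keeps the RoPE positions aligned
            out = model(input_ids[:, -1:], past_key_values=past_kvs,
                        use_cache=use_cache, start_pos=input_ids.shape[1] - 1)
        past_kvs = out.past_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)
    return input_ids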

View File

@ -1,732 +0,0 @@
import math
import struct
import inspect
import time
import gc
# Two-dimensional subspace decomposition of the keys + gradient-based updates
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis
def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x):
ndim = x.ndim
assert 0 <= 1 < ndim
assert pos_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return pos_cis.view(*shape)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
pos_cis = unite_shape(pos_cis, xq_)
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
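A small standalone shape check for the two rotary helpers above (assuming they are importable from this module; the sizes are made-up examples):

import torch

dim, n_heads, seq_len, bsz = 512, 8, 16, 2
head_dim = dim // n_heads
pos_cis = precompute_pos_cis(dim=head_dim, end=seq_len)   # [seq_len, head_dim // 2], complex64
xq = torch.randn(bsz, seq_len, n_heads, head_dim)
xk = torch.randn(bsz, seq_len, n_heads, head_dim)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, pos_cis)          # rotation preserves the shapes
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape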
class KnowledgeDataset(nn.Module):
def __init__(self, params, tok_embeddings, is_train=True):
super().__init__()
self.is_train = is_train
self.params = params
self.tok_embeddings = tok_embeddings
# Embedding parameters
self.knowledge_dim = params.knowledge_dim
self.key_dim = self.knowledge_dim // 2
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.knowledge_dim, bias=False),
)
## Database parameters
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
# Keys are stored as a two-dimensional factorized space and registered as trainable parameters
self.num_keys = int(math.sqrt(self.knowledge_num))
# Make sure the keys are trainable parameters
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.num_keys)
# Knowledge store - uses register_buffer because these are integer indices that need no gradients
self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
# Step counter, used to adjust weights dynamically
self.step_counter = 0
# Batch counter and update-frequency logic removed
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
if self.is_train == False:
return all_scores, all_indices
batch_size = all_scores.size(0)
device = all_scores.device
dtype = all_scores.dtype
# Record the memory state before entering intelligent selection
if hasattr(self, 'step_counter'):
self.step_counter += 1
# GPU memory monitoring disabled to improve performance
# if self.step_counter % 50 == 0: # log once every 50 calls
# if torch.cuda.is_available():
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
# print(f"[INTEL_SELECT_ENTER] Step {self.step_counter}: GPU Memory: {allocated_before:.2f}GB")
# Hierarchical selection for each batch entry
enhanced_scores = all_scores.clone()
query_features = query.mean(dim=1) # [batch_size, dim]
# Precompute embeddings for all candidate entries (batched optimization)
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
# Batch-compute embeddings of the unique candidate entries
candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
# Indices corresponding to flat_tokens; kept so they can be used elsewhere
pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
)
unique_candidate_features = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
).mean(dim=1) # [num_unique_candidates, dim]
# Normalize the candidate features (speeds up the similarity computation)
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
normalized_queries = F.normalize(query_features, dim=-1)
# Collect best_tokens for every batch entry
batch_best_tokens = []
batch_best_tokens_embeddings = []
for batch_idx in range(batch_size):
indices = all_indices[batch_idx]
# Feature indices of the current batch entry's candidates
start_idx = batch_idx * len(indices)
end_idx = start_idx + len(indices)
batch_inverse_indices = inverse_indices[start_idx:end_idx]
# Use the precomputed normalized features for the similarity computation
batch_candidate_features = normalized_candidates[batch_inverse_indices]
query_feature = normalized_queries[batch_idx]
# Cosine similarity via a matrix-vector product
similarity_scores = torch.mv(batch_candidate_features, query_feature)
# Index of the highest similarity score
max_similarity_idx = torch.argmax(similarity_scores)
# Candidate entry index with the highest similarity
best_candidate_idx = indices[max_similarity_idx]
# Fetch the corresponding tokens
best_tokens = self.knowledge_dataset[best_candidate_idx]
best_tokens_embeddings = self.tok_embeddings(best_tokens)
# Append this entry's best_tokens to the list
batch_best_tokens.append(best_tokens)
batch_best_tokens_embeddings.append(best_tokens_embeddings)
# Stack best_tokens from the whole batch into one tensor
# [batch_size, knowledge_length]
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
# Free intermediate tensors to prevent memory leaks
del all_candidate_indices, unique_indices, inverse_indices
del unique_candidate_features, normalized_candidates, normalized_queries
del batch_best_tokens, batch_best_tokens_embeddings
del flat_tokens, flat_embeddings, pre_update_embeddings
# Memory logging after intelligent selection (disabled to improve performance)
# if hasattr(self, 'step_counter') and self.step_counter % 50 == 0:
# if torch.cuda.is_available():
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
# print(f"[INTEL_SELECT_EXIT] Step {self.step_counter}: GPU Memory: {allocated_after:.2f}GB")
# Force garbage collection (only on monitored steps)
if hasattr(self, 'step_counter') and self.step_counter % 100 == 0:
gc.collect()
# if torch.cuda.is_available():
# torch.cuda.empty_cache()
return all_best_tokens, all_best_tokens_embeddings
def search_index(self, x):
batch_size, seq_len, dim = x.shape
# 1. Average over the sequence dimension
x_flat = x.mean(dim=1) # [batch_size, dim]
# 2. Build the query vector and reshape it into two sub-queries
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim]
# Reorder dimensions so the subspace dimension comes first
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
# 3. Compute the similarity within each subspace
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 4. Take top-k in each of the two subspaces
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 5. Combine the results of the two subspaces
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
# 6. Reshape the results to two dimensions
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
# 7. Select the final top-k results
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
indices = torch.gather(all_indices, 1, indices_of_indices)
# 8. Apply the intelligent hierarchical selection strategy
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
return best_tokens, best_tokens_embeddings
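search_index implements a product-key lookup: the knowledge_num entries are addressed by a pair of sub-keys drawn from two tables of num_keys = sqrt(knowledge_num) rows each, so scoring costs on the order of 2 * num_keys dot products instead of knowledge_num. A standalone sketch of the same combination step with assumed sizes:

import torch

num_keys, key_dim, topk = 1024, 64, 16                      # addresses num_keys**2 = 1_048_576 entries
keys = torch.randn(num_keys, 2, key_dim)                     # two independent sub-key tables
queries = torch.randn(1, 2, key_dim).permute(1, 0, 2)        # [2, batch, key_dim]
sim = torch.einsum('p b d, k p d -> p b k', queries, keys)   # [2, batch, num_keys]
(sx, ix), (sy, iy) = sim[0].topk(topk, dim=-1), sim[1].topk(topk, dim=-1)
scores = sx.unsqueeze(-1) + sy.unsqueeze(-2)                 # [batch, topk, topk] pair scores
flat_ids = ix.unsqueeze(-1) * num_keys + iy.unsqueeze(-2)    # indices into the full table
best_scores, best_pos = scores.reshape(1, -1).topk(topk, dim=-1)
best_ids = flat_ids.reshape(1, -1).gather(1, best_pos)       # final knowledge-entry indices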
class CrossAttention(nn.Module):
def __init__(
self,
config
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = self.config.dim // self.num_heads
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
# Memory monitoring at the start of cross-attention (disabled to improve performance)
if not hasattr(self, 'call_counter'):
self.call_counter = 0
self.call_counter += 1
# GPU memory monitoring disabled to improve performance
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
# print(f"[CROSS_ATTN_ENTER] Call {self.call_counter}: GPU Memory: {allocated_before:.2f}GB")
# Split into multiple heads
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
context = self.to_out(context)
# Free intermediate tensors
del q, k, v, attn_scores, attn_weights
# Memory monitoring at the end of cross-attention (disabled to improve performance)
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
# print(f"[CROSS_ATTN_EXIT] Call {self.call_counter}: GPU Memory: {allocated_after:.2f}GB")
return context
class Attention(nn.Module):
def __init__(self, args: LMConfig):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
def forward(self,
x: torch.Tensor,
pos_cis: torch.Tensor):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
xq, xk, xv,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True
)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output
class FeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
if config.hidden_dim is None:
hidden_dim = 4 * config.dim
hidden_dim = int(2 * hidden_dim / 3)
config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class MoEGate(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.scoring_func = config.scoring_func
self.alpha = config.aux_loss_alpha
self.seq_aux = config.seq_aux
self.norm_topk_prob = config.norm_topk_prob
self.gating_dim = config.dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
self.reset_parameters()
def reset_parameters(self) -> None:
import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(1, topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = 0
return topk_idx, topk_weight, aux_loss
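In the non-seq_aux branch above, fi is n_routed_experts times the fraction of selections routed to expert i and Pi is the mean gate probability, so aux_loss = alpha * sum_i(Pi * fi) equals alpha exactly when routing is perfectly uniform. A tiny numeric check of that baseline (toy values, not taken from the repository):

import torch
import torch.nn.functional as F

E, alpha = 4, 0.1
scores = torch.full((8, E), 1.0 / E)                       # uniform gate probabilities for 8 tokens
topk_idx = torch.arange(8) % E                              # selections spread evenly over experts
ce = F.one_hot(topk_idx, num_classes=E).float().mean(0)     # selection frequency per expert
Pi, fi = scores.mean(0), ce * E
aux_loss = (Pi * fi).sum() * alpha                          # == alpha at the uniform baseline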
class MOEFeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.experts = nn.ModuleList([
FeedForward(config)
for _ in range(config.n_routed_experts)
])
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
self.shared_experts = FeedForward(config)
def forward(self, x):
identity = x
orig_shape = x.shape
bsz, seq_len, _ = x.shape
# Select experts via the gating mechanism
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
y = torch.empty_like(x, dtype=torch.float16)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # keep dtypes consistent
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
self.aux_loss = aux_loss
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.config.num_experts_per_tok
# When tokens_per_expert = [6, 15, 20, 26], tokens_per_expert.shape[0] is the number of experts, 4 in this case,
# and with token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...],
# token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the 6 positions handled by expert 0 (a token may be processed by several experts, depending on num_experts_per_tok),
# the next 9 positions token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...] are the tokens handled by expert 1, and so on.
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens).to(expert_cache.dtype)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
return expert_cache
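The comment above describes the dispatch used by moe_infer: sorting the flattened expert assignments groups all selections for one expert together, and integer-dividing the sorted positions by num_experts_per_tok recovers which token each selection came from. A standalone toy illustration:

import torch

num_experts_per_tok = 2
flat_expert_indices = torch.tensor([1, 0, 0, 2, 1, 2])        # 3 tokens x 2 experts each
idxs = flat_expert_indices.argsort()                           # group selections by expert
tokens_per_expert = flat_expert_indices.bincount().cumsum(0)   # per-expert boundaries, here [2, 4, 6]
token_idxs = idxs // num_experts_per_tok                       # owning token of each selection
for i, end in enumerate(tokens_per_expert.tolist()):
    start = 0 if i == 0 else tokens_per_expert[i - 1].item()
    print(f"expert {i} handles tokens {token_idxs[start:end].tolist()}")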
class TripleExtractionHead(nn.Module):
"""三元组提取任务头"""
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
# Hyperparameters for the triple lengths
self.max_subject_len = config.max_subject_len
self.max_predicate_len = config.max_predicate_len
self.max_object_len = config.max_object_len
# Self-attention
self.self_attention = Attention(config)
self.self_attn_norm = RMSNorm(config.dim, eps=config.norm_eps)
# Cross-attention (for subject and object extraction)
# self.cross_attention_subject = CrossAttention(config)
# self.cross_attention_object = CrossAttention(config)
# Normalization layers
self.subject_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.object_norm = RMSNorm(config.dim, eps=config.norm_eps)
# Feed-forward network
self.predicate_ff = FeedForward(config)
# self.subject_ff = FeedForward(config)
# self.object_ff = FeedForward(config)
# Output projection layers - modified to support sequence prediction
self.predicate_output = nn.Linear(config.dim, 264, bias=False)
# self.subject_output = nn.Linear(config.dim, self.max_subject_len * config.dim, bias=False)
# self.object_output = nn.Linear(config.dim, self.max_object_len * config.dim, bias=False)
print(f"三元组提取任务头配置:")
print(f"- 主语最大长度: {self.max_subject_len}")
print(f"- 谓语最大长度: {self.max_predicate_len}")
print(f"- 宾语最大长度: {self.max_object_len}")
def forward(self, h, pos_cis):
"""
Args:
h: [batch_size, seq_len, dim] - 来自transformer层的隐藏状态
pos_cis: 位置编码
Returns:
predicate_logits: [batch_size, seq_len, max_predicate_len, vocab_size] - 谓语序列预测
subject_logits: [batch_size, seq_len, max_subject_len, vocab_size] - 主语序列预测
object_logits: [batch_size, seq_len, max_object_len, vocab_size] - 宾语序列预测
"""
batch_size, seq_len, dim = h.shape
# 1. h goes through self-attention to produce h1
h1 = self.self_attention(self.self_attn_norm(h), pos_cis)
h1 = h + h1 # residual connection
# 2. h1 goes through the feed-forward network to produce the predicate output
predicate_features = self.predicate_ff(h1)
predicate_features = predicate_features.mean(dim=1)
predicate_class = self.predicate_output(predicate_features) # [batch_size, 264] predicate class logits
# # 3. h1 through cross-attention (k and v are both h) gives h2
# h2 = self.cross_attention_subject(h1, h) # query is h1, key and value are both h
# h2 = h1 + h2 # residual connection
# # 4. h2 through the feed-forward network gives the subject output
# subject_features = self.subject_ff(self.subject_norm(h2))
# subject_features = subject_features.mean(dim=1)
# subject_raw = self.subject_output(subject_features) # [batch_size, max_subject_len * vocab_size]
# subject_logits = subject_raw.view(batch_size, self.max_subject_len, -1)
# # 5. h2 through cross-attention (k and v are both h) gives h3
# h3 = self.cross_attention_object(h2, h) # query is h2, key and value are both h
# h3 = h2 + h3 # residual connection
# # 6. h3 through the feed-forward network gives the object output
# object_features = self.object_ff(self.object_norm(h3))
# object_features = object_features.mean(dim=1)
# object_raw = self.object_output(object_features) # [batch_size, max_object_len * vocab_size]
# object_logits = object_raw.view(batch_size, self.max_object_len, -1)
return predicate_class
class MiniMindBlock(nn.Module):
def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
super().__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
self.self_attention = Attention(config)
self.cross_attention = CrossAttention(config)
self.knowledge_dataset = knowledge_dataset
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
def forward(self, x, pos_cis):
h_attn = self.self_attention(
self.attention_norm(x),
pos_cis
)
db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
h_attn = self.cross_attention(h_attn, db_embeddings)
h = x + h_attn
out = h + self.feed_forward(self.ffn_norm(h))
return out
class MiniMindLM(PreTrainedModel):
config_class = LMConfig
def __init__(self, params: LMConfig = None,mode="triple"):
self.params = params or LMConfig()
super().__init__(self.params)
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.tok_embeddings.weight = self.output.weight
# Triple-extraction task head (trainable)
self.triple_extraction_head = TripleExtractionHead(params)
self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.OUT = CausalLMOutputWithPast()
self.freeze_embedding = False
self.mode = mode
# Freeze the weights of all specified components
self._freeze_components()
def _freeze_components(self):
"""冻结指定组件的权重"""
# Freeze the token embedding layer
for param in self.tok_embeddings.parameters():
param.requires_grad = False
# Freeze the knowledge database
for param in self.knowledge_dataset.parameters():
param.requires_grad = False
# Freeze all transformer layers
for param in self.layers.parameters():
param.requires_grad = False
# Freeze the output layer
for param in self.output.parameters():
param.requires_grad = False
# pos_cis is a buffer and therefore needs no gradients anyway, but to be explicit
# (buffers default to requires_grad=False)
if hasattr(self, 'pos_cis'):
self.pos_cis.requires_grad = False
print("已冻结以下组件的权重:")
print("- tok_embeddings")
print("- knowledge_dataset")
print("- layers (所有transformer层)")
print("- output")
print("- pos_cis")
print("注意triple_extraction_head 保持可训练状态")
def forward(self,
input_ids: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
step: int = 0,
**args):
start_pos = args.get('start_pos', 0)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
for l, layer in enumerate(self.layers):
h = layer(
h, pos_cis
)
# Apply the triple-extraction task head
predicate_class = self.triple_extraction_head(h, pos_cis)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
# Further simplified: keep only the required fields
output = CausalLMOutputWithPast(
logits=logits,
)
output.hidden_states = h
output.aux_loss = aux_loss
# Attach the triple-extraction result
# Note: predicate_class has shape [batch_size, 264]
output.predicate_class = predicate_class
return output
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
# Streaming generation
if stream:
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
# Direct (non-streaming) generation
generated = []
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
for _ in range(num_return_sequences):
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1)
generated.append(full_sequence)
max_length = max(seq.size(1) for seq in generated)
generated = [
torch.cat(
[seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
dim=-1)
for seq in generated
]
output = torch.cat(generated, dim=0)
res = output.view(input_ids.size(0) * num_return_sequences, -1)
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
start, first_seq, past_kvs = input_ids.shape[1], True, None
while input_ids.shape[1] < max_new_tokens - 1:
if first_seq:
out, first_seq = self(input_ids, **args), False
else:
out = self(input_ids[:, -1:],
start_pos=input_ids.shape[1] - 1, **args)
logits, past_kvs = out.logits[:, -1, :], out.past_key_values
logits[:, list(set(input_ids.tolist()[0]))] /= rp
logits /= (temperature + 1e-9)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id:
break
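Since _freeze_components leaves only the TripleExtractionHead trainable, a training step for this file reduces to running the frozen backbone and applying a classification loss to predicate_class. A minimal sketch under assumed names (`model` is MiniMindLM(mode="triple"); `batch` supplies input_ids and an integer predicate label, both hypothetical here):

import torch
import torch.nn.functional as F

optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),   # only triple_extraction_head parameters
    lr=1e-4,
)

def train_step(batch):
    out = model(batch["input_ids"])                        # backbone runs with frozen weights
    loss = F.cross_entropy(out.predicate_class,            # [batch_size, 264] class logits
                           batch["predicate_label"])       # [batch_size] target ids
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()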

View File

@ -1,488 +0,0 @@
import math
import struct
import inspect
import time
import gc
# Two-dimensional subspace decomposition of the keys + gradient-based updates
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis
def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x):
ndim = x.ndim
assert 0 <= 1 < ndim
assert pos_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return pos_cis.view(*shape)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
pos_cis = unite_shape(pos_cis, xq_)
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class KnowledgeDataset(nn.Module):
def __init__(self, params, tok_embeddings, is_train=True):
super().__init__()
self.is_train = is_train
self.params = params
self.tok_embeddings = tok_embeddings
# Embedding parameters
self.knowledge_dim = params.knowledge_dim
self.key_dim = self.knowledge_dim // 2
self.to_queries = nn.Sequential(
nn.Linear(params.dim, self.knowledge_dim, bias=False),
)
## Database parameters
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
# Keys are stored as a two-dimensional factorized space and registered as trainable parameters
self.num_keys = int(math.sqrt(self.knowledge_num))
# Make sure the keys are trainable parameters
self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
self.product_key_topk = min(16, self.num_keys)
# Knowledge store - uses register_buffer because these are integer indices that need no gradients
self.register_buffer('knowledge_dataset',
torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
# Step counter, used to adjust weights dynamically
self.step_counter = 0
# Batch counter and update-frequency logic removed
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
if self.is_train == False:
return all_scores, all_indices
batch_size = all_scores.size(0)
device = all_scores.device
dtype = all_scores.dtype
# Record the memory state before entering intelligent selection
if hasattr(self, 'step_counter'):
self.step_counter += 1
# GPU memory monitoring disabled to improve performance
# if self.step_counter % 50 == 0: # log once every 50 calls
# if torch.cuda.is_available():
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
# print(f"[INTEL_SELECT_ENTER] Step {self.step_counter}: GPU Memory: {allocated_before:.2f}GB")
# Hierarchical selection for each batch entry
enhanced_scores = all_scores.clone()
query_features = query.mean(dim=1) # [batch_size, dim]
# Precompute embeddings for all candidate entries (batched optimization)
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
# Batch-compute embeddings of the unique candidate entries
candidate_tokens = self.knowledge_dataset[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
# Indices corresponding to flat_tokens; kept so they can be used elsewhere
pre_update_indices = unique_indices.view(-1)
pre_update_embeddings = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
)
unique_candidate_features = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
).mean(dim=1) # [num_unique_candidates, dim]
# Normalize the candidate features (speeds up the similarity computation)
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
normalized_queries = F.normalize(query_features, dim=-1)
# Collect best_tokens for every batch entry
batch_best_tokens = []
batch_best_tokens_embeddings = []
for batch_idx in range(batch_size):
indices = all_indices[batch_idx]
# Feature indices of the current batch entry's candidates
start_idx = batch_idx * len(indices)
end_idx = start_idx + len(indices)
batch_inverse_indices = inverse_indices[start_idx:end_idx]
# Use the precomputed normalized features for the similarity computation
batch_candidate_features = normalized_candidates[batch_inverse_indices]
query_feature = normalized_queries[batch_idx]
# Cosine similarity via a matrix-vector product
similarity_scores = torch.mv(batch_candidate_features, query_feature)
# Index of the highest similarity score
max_similarity_idx = torch.argmax(similarity_scores)
# Candidate entry index with the highest similarity
best_candidate_idx = indices[max_similarity_idx]
# Fetch the corresponding tokens
best_tokens = self.knowledge_dataset[best_candidate_idx]
best_tokens_embeddings = self.tok_embeddings(best_tokens)
# Append this entry's best_tokens to the list
batch_best_tokens.append(best_tokens)
batch_best_tokens_embeddings.append(best_tokens_embeddings)
# Stack best_tokens from the whole batch into one tensor
# [batch_size, knowledge_length]
all_best_tokens = torch.stack(batch_best_tokens, dim=0)
all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
# Free intermediate tensors to prevent memory leaks
del all_candidate_indices, unique_indices, inverse_indices
del unique_candidate_features, normalized_candidates, normalized_queries
del batch_best_tokens, batch_best_tokens_embeddings
del flat_tokens, flat_embeddings, pre_update_embeddings
# Memory logging after intelligent selection (disabled to improve performance)
# if hasattr(self, 'step_counter') and self.step_counter % 50 == 0:
# if torch.cuda.is_available():
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
# print(f"[INTEL_SELECT_EXIT] Step {self.step_counter}: GPU Memory: {allocated_after:.2f}GB")
# Force garbage collection (only on monitored steps)
if hasattr(self, 'step_counter') and self.step_counter % 100 == 0:
gc.collect()
# if torch.cuda.is_available():
# torch.cuda.empty_cache()
return all_best_tokens, all_best_tokens_embeddings
def search_index(self, x):
batch_size, seq_len, dim = x.shape
# 1. Average over the sequence dimension
x_flat = x.mean(dim=1) # [batch_size, dim]
# 2. Build the query vector and reshape it into two sub-queries
queries = self.to_queries(x_flat) # [batch_size, knowledge_dim]
queries = queries.reshape(batch_size, 2, self.key_dim) # [batch_size, 2, key_dim]
# Reorder dimensions so the subspace dimension comes first
queries = queries.permute(1, 0, 2) # [2, batch_size, key_dim]
# 3. Compute the similarity within each subspace
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 4. Take top-k in each of the two subspaces
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 5. Combine the results of the two subspaces
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) # [batch_size, topk, topk]
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) # [batch_size, topk, topk]
# 6. Reshape the results to two dimensions
all_scores = all_scores.reshape(batch_size, -1) # [batch_size, topk*topk]
all_indices = all_indices.reshape(batch_size, -1) # [batch_size, topk*topk]
# 7. Select the final top-k results
scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
indices = torch.gather(all_indices, 1, indices_of_indices)
# 8. Apply the intelligent hierarchical selection strategy
best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
return best_tokens, best_tokens_embeddings
class CrossAttention(nn.Module):
def __init__(
self,
config
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = self.config.dim // self.num_heads
self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
batch_size = x.size(0)
# Memory monitoring at the start of cross-attention (disabled to improve performance)
if not hasattr(self, 'call_counter'):
self.call_counter = 0
self.call_counter += 1
# GPU memory monitoring disabled to improve performance
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
# allocated_before = torch.cuda.memory_allocated() / (1024**3)
# print(f"[CROSS_ATTN_ENTER] Call {self.call_counter}: GPU Memory: {allocated_before:.2f}GB")
# Split into multiple heads
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
context = self.to_out(context)
# Free intermediate tensors
del q, k, v, attn_scores, attn_weights
# Memory monitoring at the end of cross-attention (disabled to improve performance)
# if self.call_counter % 100 == 0 and torch.cuda.is_available():
# allocated_after = torch.cuda.memory_allocated() / (1024**3)
# print(f"[CROSS_ATTN_EXIT] Call {self.call_counter}: GPU Memory: {allocated_after:.2f}GB")
return context
class Attention(nn.Module):
def __init__(self, args: LMConfig):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
def forward(self,
x: torch.Tensor,
pos_cis: torch.Tensor):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
xq, xk, xv,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True
)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output
class MiniMindBlock(nn.Module):
def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
super().__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
self.self_attention = Attention(config)
self.cross_attention = CrossAttention(config)
self.knowledge_dataset = knowledge_dataset
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
# ffn_norm and feed_forward removed; the FeedForward layer is no longer used
def forward(self, x, pos_cis):
h_attn = self.self_attention(
self.attention_norm(x),
pos_cis
)
db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
h_attn = self.cross_attention(h_attn, db_embeddings)
h = x + h_attn
# FeedForward layer removed; return the attention output directly
return h
class MiniMindLM(PreTrainedModel):
config_class = LMConfig
def __init__(self, params: LMConfig = None):
self.params = params or LMConfig()
super().__init__(self.params)
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.tok_embeddings.weight = self.output.weight
self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.OUT = CausalLMOutputWithPast()
self.freeze_embedding = False
def forward(self,
input_ids: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
step: int = 0,
**args):
start_pos = args.get('start_pos', 0)
# if self.freeze_embedding and step == 0:
# self.tok_embeddings.weight.requires_grad = False
# # Removed the knowledge_dataset.freeze_embedding toggle; key updates are controlled by batch_counter instead
# # self.knowledge_dataset.freeze_embedding = True
# print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
for l, layer in enumerate(self.layers):
h = layer(
h, pos_cis
)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
# aux_loss computation removed because the FeedForward layer is no longer used
aux_loss = 0
# Further simplified: keep only the required fields
output = CausalLMOutputWithPast(
logits=logits,
)
output.hidden_states = h
output.aux_loss = aux_loss
return output
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
# Streaming generation
if stream:
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
# Direct (non-streaming) generation
generated = []
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
for _ in range(num_return_sequences):
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1)
generated.append(full_sequence)
max_length = max(seq.size(1) for seq in generated)
generated = [
torch.cat(
[seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
dim=-1)
for seq in generated
]
output = torch.cat(generated, dim=0)
res = output.view(input_ids.size(0) * num_return_sequences, -1)
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
start = input_ids.shape[1]
for _ in range(max_new_tokens):
# Feed the full input_ids at every step; no KV cache is used
out = self(input_ids, **args)
logits = out.logits[:, -1, :] # take the logits at the last position
# Repetition penalty
logits[:, list(set(input_ids.tolist()[0]))] /= rp
# Temperature scaling
logits /= (temperature + 1e-9)
# Top-p (nucleus) sampling
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
# Sample the next token
input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
# Yield the newly generated portion
yield input_ids[:, start:]
# Stop once the EOS token is generated
if input_ids_next.item() == eos_token_id:
break
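The sampling loop above applies a repetition penalty, temperature scaling, and nucleus (top-p) filtering before drawing the next token. A standalone sketch of just the top-p filter, equivalent to the masking steps in _stream:

import torch
import torch.nn.functional as F

def top_p_filter(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cumulative > top_p
    remove[:, 1:] = remove[:, :-1].clone()    # shift so the first token past the cutoff survives
    remove[:, 0] = False                       # always keep the most probable token
    indices_to_remove = remove.scatter(1, sorted_indices, remove)
    return logits.masked_fill(indices_to_remove, float('-inf'))

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
next_id = torch.multinomial(F.softmax(top_p_filter(logits), dim=-1), num_samples=1)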

View File

@ -1,386 +0,0 @@
import math
import struct
import inspect
import time
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis
def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x):
ndim = x.ndim
assert 0 <= 1 < ndim
assert pos_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return pos_cis.view(*shape)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
pos_cis = unite_shape(pos_cis, xq_)
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
class Attention(nn.Module):
def __init__(self, args: LMConfig):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
def forward(self,
x: torch.Tensor,
pos_cis: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache=False):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
# KV-cache implementation
if past_key_value is not None:
xk = torch.cat([past_key_value[0], xk], dim=1)
xv = torch.cat([past_key_value[1], xv], dim=1)
past_kv = (xk, xv) if use_cache else None
xq, xk, xv = (
xq.transpose(1, 2),
repeat_kv(xk, self.n_rep).transpose(1, 2),
repeat_kv(xv, self.n_rep).transpose(1, 2)
)
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
xq, xk, xv,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True
)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output, past_kv
class FeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
if config.hidden_dim is None:
hidden_dim = 4 * config.dim
hidden_dim = int(2 * hidden_dim / 3)
config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class MoEGate(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.scoring_func = config.scoring_func
self.alpha = config.aux_loss_alpha
self.seq_aux = config.seq_aux
self.norm_topk_prob = config.norm_topk_prob
self.gating_dim = config.dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
self.reset_parameters()
def reset_parameters(self) -> None:
import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(1, topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = 0
return topk_idx, topk_weight, aux_loss
class MOEFeedForward(nn.Module):
def __init__(self, config: LMConfig):
super().__init__()
self.config = config
self.experts = nn.ModuleList([
FeedForward(config)
for _ in range(config.n_routed_experts)
])
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
self.shared_experts = FeedForward(config)
def forward(self, x):
identity = x
orig_shape = x.shape
bsz, seq_len, _ = x.shape
# Select experts via the gating mechanism
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
y = torch.empty_like(x, dtype=torch.float16)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # keep dtypes consistent
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
self.aux_loss = aux_loss
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.config.num_experts_per_tok
# When tokens_per_expert = [6, 15, 20, 26], tokens_per_expert.shape[0] is the number of experts, 4 in this case,
# and with token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...],
# token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the 6 positions handled by expert 0 (a token may be processed by several experts, depending on num_experts_per_tok),
# the next 9 positions token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...] are the tokens handled by expert 1, and so on.
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens).to(expert_cache.dtype)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
return expert_cache
class MiniMindBlock(nn.Module):
def __init__(self, layer_id: int, config: LMConfig):
super().__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
self.attention = Attention(config)
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
def forward(self, x, pos_cis, past_key_value=None, use_cache=False):
h_attn, past_kv = self.attention(
self.attention_norm(x),
pos_cis,
past_key_value=past_key_value,
use_cache=use_cache
)
h = x + h_attn
out = h + self.feed_forward(self.ffn_norm(h))
return out, past_kv
class MiniMindLM(PreTrainedModel):
config_class = LMConfig
def __init__(self, params: LMConfig = None):
self.params = params or LMConfig()
super().__init__(self.params)
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.tok_embeddings.weight = self.output.weight
self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.OUT = CausalLMOutputWithPast()
def forward(self,
input_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
logits_to_keep: Union[int, torch.Tensor] = 0,
**args):
past_key_values = past_key_values or [None] * len(self.layers)
start_pos = args.get('start_pos', 0)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
past_kvs = []
for l, layer in enumerate(self.layers):
h, past_kv = layer(
h, pos_cis,
past_key_value=past_key_values[l],
use_cache=use_cache
)
past_kvs.append(past_kv)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
# aux_loss is uniformly disabled
aux_loss = 0
self.OUT.__setitem__('last_hidden_state', h)
self.OUT.__setitem__('logits', logits)
self.OUT.__setitem__('aux_loss', aux_loss)
self.OUT.__setitem__('past_key_values', past_kvs)
return self.OUT
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
stream=False, rp=1., use_cache=True, pad_token_id=0, num_return_sequences=1, **args):
# Streaming generation
if stream:
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
# Direct (non-streaming) generation
generated = []
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
for _ in range(num_return_sequences):
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1)
generated.append(full_sequence)
max_length = max(seq.size(1) for seq in generated)
generated = [
torch.cat(
[seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
dim=-1)
for seq in generated
]
output = torch.cat(generated, dim=0)
res = output.view(input_ids.size(0) * num_return_sequences, -1)
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
start, first_seq, past_kvs = input_ids.shape[1], True, None
while input_ids.shape[1] < max_new_tokens - 1:
if first_seq or not use_cache:
out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False
else:
out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
start_pos=input_ids.shape[1] - 1, **args)
logits, past_kvs = out.logits[:, -1, :], out.past_key_values
logits[:, list(set(input_ids.tolist()[0]))] /= rp
logits /= (temperature + 1e-9)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
input_ids = torch.cat((input_ids, input_ids_next), dim=1)
yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id:
break
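A minimal usage sketch for the class above, assuming the MiniMindLM and LMConfig definitions are in scope and that a trained checkpoint and the repository's tokenizer directory exist; the paths and hyperparameters here are illustrative assumptions only:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")  # assumed local tokenizer path
model = MiniMindLM(LMConfig(dim=512, n_layers=8, max_seq_len=512))
state = torch.load("./out/pretrain_512.pth", map_location="cpu")  # hypothetical checkpoint file
model.load_state_dict(state, strict=False)
model.eval()

prompt = "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
out = model.generate(input_ids, eos_token_id=tokenizer.eos_token_id,
                     max_new_tokens=64, temperature=0.75, top_p=0.9, use_cache=True)
print(tokenizer.decode(out[0][input_ids.size(1):], skip_special_tokens=True))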

File diff suppressed because it is too large

View File

@ -1,43 +0,0 @@
{
"add_bos_token": false,
"add_eos_token": false,
"add_prefix_space": false,
"added_tokens_decoder": {
"0": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1": {
"content": "<|im_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2": {
"content": "<|im_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [],
"bos_token": "<|im_start|>",
"clean_up_tokenization_spaces": false,
"eos_token": "<|im_end|>",
"legacy": true,
"model_max_length": 32768,
"pad_token": "<unk>",
"sp_model_kwargs": {},
"spaces_between_special_tokens": false,
"tokenizer_class": "PreTrainedTokenizerFast",
"unk_token": "<unk>",
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\n你是 MiniMind是一个有用的人工智能助手。<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"
}
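For reference, the chat_template above is what transformers applies when a conversation is rendered through this tokenizer; a small sketch, assuming the config is saved in a local tokenizer directory (the path is an assumption):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")  # assumed directory containing this config
messages = [{"role": "user", "content": "Who are you?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
# The template injects the default system message and already appends the assistant header:
# <|im_start|>system\n...<|im_end|>\n<|im_start|>user\nWho are you?<|im_end|>\n<|im_start|>assistant\n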

File diff suppressed because one or more lines are too long

View File

@ -1,133 +0,0 @@
import json
import os
import datetime
from typing import List, Dict, Any
# Configuration
json_path = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/dataset/combined.json"
prepare_num = 1048576  # number of records to put in database_init.json; adjust as needed
output_dir = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/dataset/"
def convert_to_database_init_format(sentences: List[str], importance_score: float = 10.0) -> Dict[str, Any]:
"""
将句子列表转换为 database_init.json 格式
Args:
sentences: 句子列表
importance_score: 重要性评分默认为10.0
Returns:
转换后的字典格式数据
"""
# 构建句子数据
sentence_data = []
for sentence in sentences:
sentence_item = {
"original_sentence": sentence,
"corrected_sentence": sentence, # 与original_sentence相同
"importance_score": importance_score
}
sentence_data.append(sentence_item)
# 构建完整的数据结构
result = {
"metadata": {
"batch_number": 1,
"batch_size": len(sentences),
"total_processed_count": len(sentences),
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"total_sentences": len(sentences),
"duplicates_removed": 0 # 在此函数中不涉及去重所以设为0
},
"sentences": sentence_data
}
return result
def preprocess_combined_json():
# Read the raw data
print("Reading combined.json...")
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f)
total_count = len(data)
print(f"Loaded {total_count} records in total")
# Process every record: join subject, predicate and object into a sentence while keeping the original record
print("Processing records and building sentences...")
sentence_to_original = {}  # map from sentence back to the original record
all_sentences = []
for i, item in enumerate(data):
# Join subject, predicate and object into one sentence
sentence = f"{item['subject']} {item['predicate']} {item['object']}"
all_sentences.append(sentence)
# Record the sentence-to-original mapping (on duplicates, keep the first occurrence)
if sentence not in sentence_to_original:
sentence_to_original[sentence] = item
if (i + 1) % 100000 == 0:
print(f"Processed {i + 1}/{total_count} records")
print(f"Sentence construction finished, {len(all_sentences)} sentences in total")
# Deduplicate
print("Deduplicating...")
unique_sentences = list(set(all_sentences))
duplicates_removed = len(all_sentences) - len(unique_sentences)
print(f"Deduplication finished, before: {len(all_sentences)}, after: {len(unique_sentences)}, duplicates removed: {duplicates_removed}")
# Check whether there are enough deduplicated sentences
if len(unique_sentences) < prepare_num:
print(f"Warning: deduplicated count ({len(unique_sentences)}) is less than the requested amount ({prepare_num})")
print(f"Using all {len(unique_sentences)} deduplicated sentences")
selected_sentences = unique_sentences
else:
print(f"Selecting the first {prepare_num} deduplicated sentences")
selected_sentences = unique_sentences[:prepare_num]
# Convert to the database_init.json format
print("Converting to database_init.json format...")
database_init_data = convert_to_database_init_format(selected_sentences, importance_score=10.0)
# Update the duplicates_removed field in the metadata
database_init_data["metadata"]["duplicates_removed"] = duplicates_removed
# Save database_init.json
database_output_path = os.path.join(output_dir, "database_init_from_combined.json")
print(f"Saving {database_output_path}...")
with open(database_output_path, "w", encoding="utf-8") as f:
json.dump(database_init_data, f, ensure_ascii=False, indent=2)
print(f"database_init_from_combined.json saved, containing {len(selected_sentences)} records")
# Save the remaining data as a training set (keeping the original format)
remaining_sentences = unique_sentences[prepare_num:] if len(unique_sentences) > prepare_num else []
if remaining_sentences:
# Convert the remaining sentences back to the original format
print(f"Converting the remaining {len(remaining_sentences)} records back to the original format...")
remaining_original_data = []
for sentence in remaining_sentences:
if sentence in sentence_to_original:
remaining_original_data.append(sentence_to_original[sentence])
print(f"Saving the remaining {len(remaining_original_data)} records as the training set...")
train_output_path = os.path.join(output_dir, "combined_train.json")
with open(train_output_path, "w", encoding="utf-8") as f:
json.dump(remaining_original_data, f, ensure_ascii=False, indent=2)
print("combined_train.json saved")
else:
print("No remaining data for the training set")
remaining_original_data = []
print("\nData processing finished!")
print(f"Original records: {total_count}")
print(f"After sentence construction: {len(all_sentences)} sentences")
print(f"After deduplication: {len(unique_sentences)} sentences")
print(f"Used for database_init: {len(selected_sentences)}")
print(f"Remaining training data: {len(remaining_original_data) if remaining_sentences else 0}")
if __name__ == "__main__":
preprocess_combined_json()

View File

@ -1,741 +0,0 @@
import json
import os
import pandas as pd
import tarfile
import tempfile
import shutil
from pathlib import Path
import re
import langdetect
from tqdm import tqdm
import logging
import random
import hashlib
from transformers import AutoTokenizer
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration
BASE_DIR = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind"
OUTPUT_FILE = os.path.join(BASE_DIR, "dataset", "merged_pretrain.jsonl")
OUTPUT_FILE_EXTRA = os.path.join(BASE_DIR, "dataset", "merged_pretrain_extra.jsonl")
# Data source paths
PRETRAIN_HQ_PATH = os.path.join(BASE_DIR, "dataset", "pretrain_hq.jsonl")
WIKIPEDIA_PATH = "/home/pci/nas/share/datasets/wikipedia/data/20220301.en"
GUTENBERG_PATH = "/home/pci/nas/share/datasets/gutenberg/data"
OPENWEBTEXT_PATH = "/home/pci/nas/share/datasets/openwebtext/subsets"
# Token length limits
MIN_TOKENS = 410
MAX_TOKENS = 490
# Dataset quality and sampling ratio configuration - main file
DATASET_CONFIG = {
"pretrain_hq": {"quality": "high", "sample_ratio": 1.0, "max_samples": None},  # high quality, use all of it
"wikipedia": {"quality": "high", "sample_ratio": 1.0, "max_samples": 5000000},  # high quality, use all, at most 5,000,000 records
"gutenberg": {"quality": "medium", "sample_ratio": 0.8, "max_samples": 3000000},  # medium quality, use 80%, at most 3,000,000 records
"openwebtext": {"quality": "low", "sample_ratio": 0.2, "max_samples": 2000000}  # low quality, use 20%, at most 2,000,000 records
}
# Configuration for the extra file - leftover data
DATASET_CONFIG_EXTRA = {
"wikipedia": {"quality": "high", "sample_ratio": 1.0, "max_samples": None},  # all of the remainder
"gutenberg": {"quality": "medium", "sample_ratio": 0.8, "max_samples": 5000000},  # 80% of the remainder, at most 5,000,000 records
"openwebtext": {"quality": "low", "sample_ratio": 0.6, "max_samples": 4000000}  # 60% of the remainder, at most 4,000,000 records
}
# Global state: hashes of data already selected for the main file
selected_data_hashes = {
"wikipedia": set(),
"gutenberg": set(),
"openwebtext": set()
}
# The tokenizer is initialized lazily
tokenizer = None
def init_tokenizer():
"""初始化tokenizer"""
global tokenizer
try:
# First try the local minimind tokenizer
local_tokenizer_path = os.path.join(BASE_DIR, "model", "minimind_tokenizer")
if os.path.exists(local_tokenizer_path):
tokenizer = AutoTokenizer.from_pretrained(local_tokenizer_path)
logger.info("Local MiniMind tokenizer initialized successfully")
else:
# If the local tokenizer is missing, fall back to GPT-2 in offline mode
tokenizer = AutoTokenizer.from_pretrained("gpt2", local_files_only=True)
logger.info("GPT-2 tokenizer initialized successfully (offline)")
except Exception as e:
logger.error(f"Error initializing tokenizer: {e}")
logger.info("Trying to use a simple fallback tokenizer...")
# Use a simple word-based tokenization as the fallback
tokenizer = None
logger.warning("Using simple word-based tokenization as fallback")
def count_tokens(text):
"""计算文本的token数量"""
if tokenizer is None:
init_tokenizer()
if tokenizer is not None:
try:
tokens = tokenizer.encode(text, add_special_tokens=False)
return len(tokens)
except:
pass
# If tokenization fails or the tokenizer is None, fall back to a rough estimate
return int(len(text.split()) * 1.3)  # rough estimate; ensure an integer is returned
def is_english_text(text, threshold=0.8):
"""检测文本是否为英文"""
try:
if len(text) < 50:  # skip detection for very short texts
return True
detected_lang = langdetect.detect(text)
return detected_lang == 'en'
except:
# If detection fails, fall back to a simple English-character ratio check
english_chars = sum(1 for char in text if char.isascii() and (char.isalpha() or char in ' .,!?-'))
total_chars = len(text.replace('\n', ' ').replace('\t', ' '))
return (english_chars / max(total_chars, 1)) > threshold
def truncate_to_token_limit(text, target_tokens=MAX_TOKENS):
"""将文本截断到目标token数量"""
if tokenizer is None:
init_tokenizer()
if tokenizer is not None:
try:
tokens = tokenizer.encode(text, add_special_tokens=False)
if len(tokens) <= target_tokens:
return text
# Truncate to the target length
truncated_tokens = tokens[:target_tokens]
truncated_text = tokenizer.decode(truncated_tokens, skip_special_tokens=True)
# Try to cut at a sentence boundary to keep sentences intact
sentences = truncated_text.split('.')
if len(sentences) > 1:
# Keep every sentence except the final, incomplete one
truncated_text = '.'.join(sentences[:-1]) + '.'
return truncated_text
except:
pass
# If processing fails or the tokenizer is None, estimate by character count
estimated_chars = int(target_tokens / 1.3 * 4)  # rough estimate
text = text[:estimated_chars]
# Try to cut at a sentence boundary to keep sentences intact
sentences = text.split('.')
if len(sentences) > 1:
text = '.'.join(sentences[:-1]) + '.'
return text
def split_text_into_chunks(text, target_chunk_size=1500):
"""将长文本分割成多个中等长度的段落块"""
# 清理文本
text = text.strip()
if not text:
return []
# Collapse excessive newlines and spaces
text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
text = re.sub(r' +', ' ', text)
chunks = []
# Split by paragraphs
paragraphs = text.split('\n\n')
current_chunk = ""
for paragraph in paragraphs:
paragraph = paragraph.strip()
if not paragraph:
continue
# If the current chunk plus the new paragraph is still a reasonable size, append it
if len(current_chunk) + len(paragraph) < target_chunk_size:
if current_chunk:
current_chunk += "\n\n" + paragraph
else:
current_chunk = paragraph
else:
# If the current chunk is non-empty, save it
if current_chunk:
chunks.append(current_chunk)
# If the paragraph itself is very long, split it further
if len(paragraph) > target_chunk_size * 2:
# Split the long paragraph by sentence
sentences = re.split(r'(?<=[.!?])\s+', paragraph)
temp_chunk = ""
for sentence in sentences:
if len(temp_chunk) + len(sentence) < target_chunk_size:
if temp_chunk:
temp_chunk += " " + sentence
else:
temp_chunk = sentence
else:
if temp_chunk:
chunks.append(temp_chunk)
temp_chunk = sentence
if temp_chunk:
current_chunk = temp_chunk
else:
current_chunk = ""
else:
current_chunk = paragraph
# Add the final chunk
if current_chunk:
chunks.append(current_chunk)
return chunks
def format_text_for_pretrain(text):
"""将文本格式化为预训练格式并检查token长度"""
# 清理文本
text = text.strip()
if not text:
return None
# Collapse excessive newlines and spaces
text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
text = re.sub(r' +', ' ', text)
# Check the token count
token_count = count_tokens(text)
# Too short: skip
if token_count < MIN_TOKENS:
return None
# Too long: truncate
if token_count > MAX_TOKENS:
text = truncate_to_token_limit(text, MAX_TOKENS)
token_count = count_tokens(text)
# Check again that the length is within the allowed range
if token_count < MIN_TOKENS or token_count > MAX_TOKENS:
return None
# Wrap in the pretraining format
formatted_text = f"<|im_start|>{text}<|im_end|>"
return formatted_text
def get_text_hash(text):
"""获取文本的哈希值,用于去重"""
return hashlib.md5(text.encode('utf-8')).hexdigest()
def should_sample(dataset_name, current_count, config_dict=None, is_extra_mode=False):
"""根据配置决定是否采样当前记录"""
if config_dict is None:
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
config = config_dict[dataset_name]
# Stop once the maximum sample count is reached
if config["max_samples"] and current_count >= config["max_samples"]:
return False
# Otherwise decide randomly according to the sampling ratio
return random.random() < config["sample_ratio"]
def process_pretrain_hq():
"""处理已有的高质量预训练数据 - 直接输出,不做任何处理"""
logger.info("Processing pretrain_hq.jsonl...")
count = 0
with open(PRETRAIN_HQ_PATH, 'r', encoding='utf-8') as f:
for line in tqdm(f, desc="Processing pretrain_hq"):
try:
data = json.loads(line.strip())
text = data.get('text', '').strip()
if text:  # emit any non-empty text as-is, without further checks
if should_sample("pretrain_hq", count):
yield {"text": text}
count += 1
except json.JSONDecodeError:
continue
logger.info(f"Processed {count} records from pretrain_hq.jsonl")
def process_wikipedia(is_extra_mode=False):
"""处理Wikipedia数据"""
mode_text = "extra" if is_extra_mode else "main"
logger.info(f"Processing Wikipedia data ({mode_text} mode)...")
count = 0
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
# Collect all English Wikipedia files
wiki_files = list(Path(WIKIPEDIA_PATH).glob("*.parquet"))
for file_path in tqdm(wiki_files, desc=f"Processing Wikipedia files ({mode_text})"):
if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
break
try:
df = pd.read_parquet(file_path)
for _, row in df.iterrows():
if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
break
text = row.get('text', '').strip()
if text and len(text) > 200:  # pre-filter texts that are too short
# First split the long text into medium-sized chunks
chunks = split_text_into_chunks(text, target_chunk_size=2000)
for chunk in chunks:
if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
break
chunk_hash = get_text_hash(chunk)
# In extra mode, skip chunks that were already selected for the main file
if is_extra_mode and chunk_hash in selected_data_hashes["wikipedia"]:
continue
formatted_text = format_text_for_pretrain(chunk)
if formatted_text and should_sample("wikipedia", count, config_dict, is_extra_mode):
# In main mode, record the chunk hash
if not is_extra_mode:
selected_data_hashes["wikipedia"].add(chunk_hash)
yield {"text": formatted_text}
count += 1
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
continue
logger.info(f"Processed {count} records from Wikipedia ({mode_text} mode)")
def process_gutenberg(is_extra_mode=False):
"""处理Gutenberg数据"""
mode_text = "extra" if is_extra_mode else "main"
logger.info(f"Processing Gutenberg data ({mode_text} mode)...")
count = 0
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
# Collect all Gutenberg training files
gutenberg_files = list(Path(GUTENBERG_PATH).glob("train-*.parquet"))
for file_path in tqdm(gutenberg_files, desc=f"Processing Gutenberg files ({mode_text})"):
if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
break
try:
df = pd.read_parquet(file_path)
for _, row in df.iterrows():
if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
break
text = row.get('text', '').strip()
if text and len(text) > 300 and is_english_text(text):  # pre-filter
# First split the long text into medium-sized chunks
chunks = split_text_into_chunks(text, target_chunk_size=1800)
for chunk in chunks:
if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
break
chunk_hash = get_text_hash(chunk)
# In extra mode, skip chunks that were already selected for the main file
if is_extra_mode and chunk_hash in selected_data_hashes["gutenberg"]:
continue
formatted_text = format_text_for_pretrain(chunk)
if formatted_text and should_sample("gutenberg", count, config_dict, is_extra_mode):
# In main mode, record the chunk hash
if not is_extra_mode:
selected_data_hashes["gutenberg"].add(chunk_hash)
yield {"text": formatted_text}
count += 1
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
continue
logger.info(f"Processed {count} records from Gutenberg ({mode_text} mode)")
def process_openwebtext(is_extra_mode=False):
"""处理OpenWebText数据"""
mode_text = "extra" if is_extra_mode else "main"
logger.info(f"Processing OpenWebText data ({mode_text} mode)...")
count = 0
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
max_files = 5  # limit the number of files to avoid excessive processing time
# Collect the list of tar files
tar_files = list(Path(OPENWEBTEXT_PATH).glob("*.tar"))[:max_files]
for tar_path in tqdm(tar_files, desc=f"Processing OpenWebText files ({mode_text})"):
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
try:
with tarfile.open(tar_path, 'r') as outer_tar:
# 创建临时目录处理外层tar
with tempfile.TemporaryDirectory() as temp_dir:
outer_tar.extractall(temp_dir)
# 处理解压后的xz文件
for root, dirs, files in os.walk(temp_dir):
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
for file in files:
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
if file.endswith('.xz'):
xz_path = os.path.join(root, file)
# 创建另一个临时目录处理xz文件
with tempfile.TemporaryDirectory() as xz_temp_dir:
try:
# 解压xz文件
import subprocess
decompressed_path = os.path.join(xz_temp_dir, file[:-3]) # 移除.xz后缀
subprocess.run(['xz', '-dc', xz_path],
stdout=open(decompressed_path, 'wb'),
check=True)
# 检查解压后的文件是否是tar格式
if tarfile.is_tarfile(decompressed_path):
# 处理内层tar文件
with tarfile.open(decompressed_path, 'r') as inner_tar:
with tempfile.TemporaryDirectory() as inner_temp_dir:
inner_tar.extractall(inner_temp_dir)
# 处理最终的txt文件
for inner_root, inner_dirs, inner_files in os.walk(inner_temp_dir):
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
for txt_file in inner_files:
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
if txt_file.endswith('.txt'):
txt_path = os.path.join(inner_root, txt_file)
try:
with open(txt_path, 'r', encoding='utf-8', errors='ignore') as f:
text = f.read().strip()
if text and len(text) > 500 and is_english_text(text):
# 先将长文本分割成中等大小的块
chunks = split_text_into_chunks(text, target_chunk_size=1600)
for chunk in chunks:
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
chunk_hash = get_text_hash(chunk)
# 在额外模式下,跳过已经被主文件选中的数据
if is_extra_mode and chunk_hash in selected_data_hashes["openwebtext"]:
continue
formatted_text = format_text_for_pretrain(chunk)
if formatted_text and should_sample("openwebtext", count, config_dict, is_extra_mode):
# 在主模式下记录哈希值
if not is_extra_mode:
selected_data_hashes["openwebtext"].add(chunk_hash)
yield {"text": formatted_text}
count += 1
except Exception as e:
logger.debug(f"Error reading txt file {txt_path}: {e}")
continue
else:
# 如果不是tar文件直接作为文本处理
try:
with open(decompressed_path, 'r', encoding='utf-8', errors='ignore') as f:
text = f.read().strip()
if text and len(text) > 500 and is_english_text(text):
chunks = split_text_into_chunks(text, target_chunk_size=1600)
for chunk in chunks:
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
break
chunk_hash = get_text_hash(chunk)
# 在额外模式下,跳过已经被主文件选中的数据
if is_extra_mode and chunk_hash in selected_data_hashes["openwebtext"]:
continue
formatted_text = format_text_for_pretrain(chunk)
if formatted_text and should_sample("openwebtext", count, config_dict, is_extra_mode):
# 在主模式下记录哈希值
if not is_extra_mode:
selected_data_hashes["openwebtext"].add(chunk_hash)
yield {"text": formatted_text}
count += 1
except Exception as e:
logger.debug(f"Error reading decompressed file {decompressed_path}: {e}")
continue
except Exception as e:
logger.debug(f"Error processing xz file {xz_path}: {e}")
continue
except Exception as e:
logger.error(f"Error processing {tar_path}: {e}")
continue
logger.info(f"Processed {count} records from OpenWebText ({mode_text} mode)")
def merge_datasets():
"""合并所有数据集,生成主文件和额外文件"""
logger.info("Starting dataset merging...")
logger.info("Main dataset configuration:")
for name, config in DATASET_CONFIG.items():
logger.info(f" {name}: quality={config['quality']}, ratio={config['sample_ratio']}, max={config['max_samples']}")
logger.info("Extra dataset configuration:")
for name, config in DATASET_CONFIG_EXTRA.items():
logger.info(f" {name}: quality={config['quality']}, ratio={config['sample_ratio']}, max={config['max_samples']}")
# Make sure the output directories exist
os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True)
os.makedirs(os.path.dirname(OUTPUT_FILE_EXTRA), exist_ok=True)
# Statistics
main_dataset_stats = {}
extra_dataset_stats = {}
# Phase 1: generate the main file
logger.info("="*50)
logger.info("PHASE 1: Generating main dataset file")
logger.info("="*50)
with open(OUTPUT_FILE, 'w', encoding='utf-8') as outfile:
main_total_count = 0
# Process each dataset (main mode)
main_datasets = [
("pretrain_hq", process_pretrain_hq),
("wikipedia", lambda: process_wikipedia(is_extra_mode=False)),
("gutenberg", lambda: process_gutenberg(is_extra_mode=False)),
("openwebtext", lambda: process_openwebtext(is_extra_mode=False))
]
for dataset_name, dataset_func in main_datasets:
logger.info(f"Processing {dataset_name} for main file...")
dataset_count = 0
try:
for record in dataset_func():
json.dump(record, outfile, ensure_ascii=False)
outfile.write('\n')
dataset_count += 1
main_total_count += 1
# Log progress every 5000 records
if main_total_count % 5000 == 0:
logger.info(f"Main file: Processed {main_total_count} total records")
# Record statistics
main_dataset_stats[dataset_name] = {
'selected': dataset_count,
'config': DATASET_CONFIG[dataset_name]
}
except Exception as e:
logger.error(f"Error processing {dataset_name} for main file: {e}")
main_dataset_stats[dataset_name] = {
'selected': dataset_count,
'config': DATASET_CONFIG[dataset_name]
}
logger.info(f"Main file - Completed {dataset_name}: {dataset_count} records")
logger.info(f"Main file generation completed. Total records: {main_total_count}")
# Phase 2: generate the extra file
logger.info("="*50)
logger.info("PHASE 2: Generating extra dataset file")
logger.info("="*50)
with open(OUTPUT_FILE_EXTRA, 'w', encoding='utf-8') as outfile:
extra_total_count = 0
# Process each dataset (extra mode) - pretrain_hq is excluded
extra_datasets = [
("wikipedia", lambda: process_wikipedia(is_extra_mode=True)),
("gutenberg", lambda: process_gutenberg(is_extra_mode=True)),
("openwebtext", lambda: process_openwebtext(is_extra_mode=True))
]
for dataset_name, dataset_func in extra_datasets:
logger.info(f"Processing {dataset_name} for extra file...")
dataset_count = 0
try:
for record in dataset_func():
json.dump(record, outfile, ensure_ascii=False)
outfile.write('\n')
dataset_count += 1
extra_total_count += 1
# 每5000条记录输出一次进度
if extra_total_count % 5000 == 0:
logger.info(f"Extra file: Processed {extra_total_count} total records")
# 保存统计信息
extra_dataset_stats[dataset_name] = {
'selected': dataset_count,
'config': DATASET_CONFIG_EXTRA[dataset_name]
}
except Exception as e:
logger.error(f"Error processing {dataset_name} for extra file: {e}")
extra_dataset_stats[dataset_name] = {
'selected': dataset_count,
'config': DATASET_CONFIG_EXTRA[dataset_name]
}
logger.info(f"Extra file - Completed {dataset_name}: {dataset_count} records")
logger.info(f"Extra file generation completed. Total records: {extra_total_count}")
# Print detailed statistics
print_detailed_statistics(main_dataset_stats, main_total_count, extra_dataset_stats, extra_total_count)
logger.info("All dataset processing completed successfully!")
logger.info(f"Main file saved to: {OUTPUT_FILE}")
logger.info(f"Extra file saved to: {OUTPUT_FILE_EXTRA}")
def print_detailed_statistics(main_dataset_stats, main_total_count, extra_dataset_stats, extra_total_count):
"""打印详细统计信息"""
print("\n" + "="*100)
print("DATASET PROCESSING SUMMARY")
print("="*100)
# Main file statistics
print("\nMAIN FILE (merged_pretrain.jsonl):")
print("-" * 90)
print(f"{'Dataset':<15} {'Selected':<10} {'Sample Ratio':<12} {'Max Limit':<12} {'% of Main':<12} {'Quality':<8}")
print("-" * 90)
for dataset_name, stats in main_dataset_stats.items():
selected = stats['selected']
config = stats['config']
ratio = config['sample_ratio']
max_limit = config['max_samples'] if config['max_samples'] else "No limit"
percentage = (selected / main_total_count * 100) if main_total_count > 0 else 0
quality = config['quality']
print(f"{dataset_name:<15} {selected:<10,} {ratio:<12.1%} {str(max_limit):<12} {percentage:<12.2f}% {quality:<8}")
print("-" * 90)
print(f"{'MAIN TOTAL':<15} {main_total_count:<10,} {'':<12} {'':<12} {'100.00%':<12} {'':<8}")
# Extra file statistics
print("\nEXTRA FILE (merged_pretrain_extra.jsonl):")
print("-" * 90)
print(f"{'Dataset':<15} {'Selected':<10} {'Sample Ratio':<12} {'Max Limit':<12} {'% of Extra':<12} {'Quality':<8}")
print("-" * 90)
for dataset_name, stats in extra_dataset_stats.items():
selected = stats['selected']
config = stats['config']
ratio = config['sample_ratio']
max_limit = config['max_samples'] if config['max_samples'] else "No limit"
percentage = (selected / extra_total_count * 100) if extra_total_count > 0 else 0
quality = config['quality']
print(f"{dataset_name:<15} {selected:<10,} {ratio:<12.1%} {str(max_limit):<12} {percentage:<12.2f}% {quality:<8}")
print("-" * 90)
print(f"{'EXTRA TOTAL':<15} {extra_total_count:<10,} {'':<12} {'':<12} {'100.00%':<12} {'':<8}")
# Overall statistics
total_records = main_total_count + extra_total_count
print("\nOVERALL STATISTICS:")
print("-" * 50)
print(f"Main file records: {main_total_count:>10,}")
print(f"Extra file records: {extra_total_count:>10,}")
print(f"Total records: {total_records:>10,}")
print(f"Token range per sample: {MIN_TOKENS}-{MAX_TOKENS} tokens")
# Quality distribution statistics
quality_stats = {}
for dataset_name, stats in main_dataset_stats.items():
quality = stats['config']['quality']
if quality not in quality_stats:
quality_stats[quality] = {'main': 0, 'extra': 0}
quality_stats[quality]['main'] += stats['selected']
for dataset_name, stats in extra_dataset_stats.items():
quality = stats['config']['quality']
if quality not in quality_stats:
quality_stats[quality] = {'main': 0, 'extra': 0}
quality_stats[quality]['extra'] += stats['selected']
print("\nQUALITY DISTRIBUTION:")
print("-" * 60)
print(f"{'Quality':<12} {'Main File':<12} {'Extra File':<12} {'Total':<12} {'%':<8}")
print("-" * 60)
for quality in sorted(quality_stats.keys()):
main_count = quality_stats[quality]['main']
extra_count = quality_stats[quality]['extra']
total_count = main_count + extra_count
percentage = (total_count / total_records * 100) if total_records > 0 else 0
print(f"{quality.capitalize():<12} {main_count:<12,} {extra_count:<12,} {total_count:<12,} {percentage:<8.2f}%")
print("-" * 60)
print(f"{'Total':<12} {main_total_count:<12,} {extra_total_count:<12,} {total_records:<12,} {'100.00%':<8}")
print(f"\nFiles saved to:")
print(f" Main file: {OUTPUT_FILE}")
print(f" Extra file: {OUTPUT_FILE_EXTRA}")
print("="*100)
def main():
"""主函数"""
try:
# 设置随机种子以确保结果可重现
random.seed(42)
# 检查依赖包
try:
import langdetect
from transformers import AutoTokenizer
except ImportError as e:
logger.error(f"Missing dependencies: {e}")
logger.error("Please install: pip install langdetect transformers")
return
# Initialize the tokenizer
init_tokenizer()
# Check the input file
if not os.path.exists(PRETRAIN_HQ_PATH):
logger.error(f"pretrain_hq.jsonl not found at {PRETRAIN_HQ_PATH}")
return
# Start merging the datasets
merge_datasets()
logger.info("All processing completed successfully!")
except Exception as e:
logger.error(f"Error in main process: {e}")
raise
if __name__ == "__main__":
main()
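Each line of the merged output is a single JSON object with one text field wrapped in the <|im_start|>/<|im_end|> markers (pretrain_hq records are passed through unchanged, so only the other sources are guaranteed to fall inside the 410-490 token window). A minimal spot-check sketch, assuming the output file and the local MiniMind tokenizer exist at these illustrative relative paths:

import json
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")  # assumed local tokenizer path
with open("./dataset/merged_pretrain.jsonl", "r", encoding="utf-8") as f:  # assumed output location
    for _, line in zip(range(3), f):
        text = json.loads(line)["text"]
        print(len(tok.encode(text, add_special_tokens=False)), repr(text[:60]))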

View File

@ -1,442 +0,0 @@
import json
import os
import argparse
from typing import List, Dict, Any, Optional
from collections import defaultdict
import pickle
from pathlib import Path
class WikidataRelationManager:
"""Wikidata关系管理器支持动态获取和缓存"""
def __init__(self, cache_file: str = "wikidata_relations_cache.pkl",
mapping_file: str = None):
self.cache_file = cache_file
self.mapping_file = mapping_file
self.relations = {}
# API-related attributes have been removed
# Initial set of base relation mappings
self.base_relations = {
# # 基本关系
# 'P31': 'instance of',
# 'P279': 'subclass of',
# 'P17': 'country',
# 'P159': 'headquarters location',
# 'P571': 'inception',
# # 人物关系
# 'P19': 'place of birth',
# 'P20': 'place of death',
# 'P27': 'country of citizenship',
# 'P106': 'occupation',
# 'P22': 'father',
# 'P25': 'mother',
# 'P26': 'spouse',
# 'P40': 'child',
# 'P69': 'educated at',
# 'P108': 'employer',
# # 地理关系
# 'P36': 'capital',
# 'P131': 'located in',
# 'P47': 'shares border with',
# 'P206': 'located on terrain feature',
# 'P1376': 'capital of',
# # 组织关系
# 'P112': 'founded by',
# 'P127': 'owned by',
# 'P169': 'chief executive officer',
# 'P488': 'chairperson',
# 'P749': 'parent organization',
# # 作品关系
# 'P50': 'author',
# 'P57': 'director',
# 'P58': 'screenwriter',
# 'P161': 'cast member',
# 'P175': 'performer',
# 'P577': 'publication date',
# 'P123': 'publisher',
# 'P136': 'genre',
# # 时间关系
# 'P155': 'follows',
# 'P156': 'followed by',
# 'P580': 'start time',
# 'P582': 'end time',
# # 体育关系
# 'P54': 'member of sports team',
# 'P413': 'position played on team',
# 'P118': 'league',
# # 科学关系
# 'P275': 'copyright license',
# 'P170': 'creator',
# 'P398': 'child astronomical body',
# 'P397': 'parent astronomical body',
# # 其他常见关系
# 'P37': 'official language',
# 'P1923': 'place of marriage',
# 'P737': 'influenced by',
# 'P463': 'member of',
# 'P39': 'position held',
# 'P276': 'location',
# 'P1441': 'present in work',
}
self.load_cache()
def load_cache(self):
"""加载缓存的关系映射优先使用JSON映射文件"""
try:
# First try to load the JSON mapping file
if self.mapping_file and os.path.exists(self.mapping_file):
with open(self.mapping_file, 'r', encoding='utf-8') as f:
self.relations = json.load(f)
print(f"从JSON映射文件加载了 {len(self.relations)} 个关系映射")
return
# 尝试加载pickle缓存文件
if os.path.exists(self.cache_file):
with open(self.cache_file, 'rb') as f:
self.relations = pickle.load(f)
print(f"从pickle缓存加载了 {len(self.relations)} 个关系映射")
else:
self.relations = self.base_relations.copy()
print(f"初始化基础关系映射: {len(self.relations)}")
except Exception as e:
print(f"加载缓存失败: {e}")
self.relations = self.base_relations.copy()
def save_cache(self):
"""保存关系映射到缓存"""
try:
with open(self.cache_file, 'wb') as f:
pickle.dump(self.relations, f)
print(f"已保存 {len(self.relations)} 个关系映射到缓存")
except Exception as e:
print(f"保存缓存失败: {e}")
# 删除了网络抓取功能,改为纯离线模式
def get_relation_name(self, property_id: str) -> Optional[str]:
"""获取关系名称,仅使用本地映射"""
if property_id in self.relations:
return self.relations[property_id]
# If it is not in the local mapping, return None so this relation is skipped
return None
# Network-based batch fetching and preloading have been removed
class TRexProcessor:
"""Processor for the T-REx dataset."""
def __init__(self, relation_manager: WikidataRelationManager):
self.relation_manager = relation_manager
def extract_predicate_id(self, uri: str) -> str:
"""从URI中提取属性ID"""
if uri and 'prop/direct/' in uri:
return uri.split('/')[-1]
elif uri and uri.startswith('P') and uri[1:].isdigit():
return uri
return uri if uri else 'unknown'
def get_relation_name(self, predicate_uri: str) -> Optional[str]:
"""获取关系的可读名称"""
predicate_id = self.extract_predicate_id(predicate_uri)
return self.relation_manager.get_relation_name(predicate_id)
# Predicate collection has been removed since preloading is no longer needed
def is_valid_triple(self, triple: Dict[str, Any], confidence_threshold: float,
boundary_threshold: int) -> bool:
"""检查三元组是否满足过滤条件"""
try:
# Check that the triple is a dictionary
if not isinstance(triple, dict):
return False
# Check the required fields
if not all(key in triple for key in ['subject', 'predicate', 'object']):
return False
subject = triple['subject']
predicate = triple['predicate']
object_info = triple['object']
# 检查subject、predicate、object是否都为字典
if not isinstance(subject, dict) or not isinstance(predicate, dict) or not isinstance(object_info, dict):
return False
# 检查主语和宾语是否有有效的URI和surfaceform
if not (subject.get('uri') and subject.get('surfaceform')):
return False
if not (object_info.get('uri') and object_info.get('surfaceform')):
return False
if not predicate.get('uri'):
return False
# 检查置信度(如果存在)
confidence = triple.get('confidence')
if confidence is not None and confidence < confidence_threshold:
return False
# 检查边界信息(如果设置了阈值)
if boundary_threshold > 0:
subject_boundaries = subject.get('boundaries')
object_boundaries = object_info.get('boundaries')
if not subject_boundaries or not object_boundaries:
return False
# 检查边界是否为列表且长度至少为2
if not (isinstance(subject_boundaries, list) and len(subject_boundaries) >= 2):
return False
if not (isinstance(object_boundaries, list) and len(object_boundaries) >= 2):
return False
try:
# 检查边界长度是否合理
subject_length = subject_boundaries[1] - subject_boundaries[0]
object_length = object_boundaries[1] - object_boundaries[0]
if subject_length < boundary_threshold or object_length < boundary_threshold:
return False
except (TypeError, IndexError):
return False
# 检查文本内容是否合理
subject_text = subject.get('surfaceform', '').strip()
object_text = object_info.get('surfaceform', '').strip()
if not subject_text or not object_text:
return False
# 过滤掉过长或过短的实体
if len(subject_text) > 100 or len(object_text) > 100:
return False
if len(subject_text) < 2 or len(object_text) < 2:
return False
return True
except (KeyError, TypeError, AttributeError):
return False
def process_single_file(self, file_path: str, confidence_threshold: float,
boundary_threshold: int) -> List[Dict[str, Any]]:
"""处理单个JSON文件"""
print(f"Processing file: {file_path}")
processed_data = []
try:
with open(file_path, 'r', encoding='utf-8') as f:
# Read the whole file as a JSON array
print(f"Loading JSON array file: {file_path}")
data_list = json.load(f)
print(f"The file contains {len(data_list)} entries")
for idx, data in enumerate(data_list):
try:
# Basic fields
text = data.get('text', '').strip()
if not text:
continue
# Process the triples
triples = data.get('triples', [])
if not triples:
continue
valid_targets = []
for triple in triples:
if self.is_valid_triple(triple, confidence_threshold, boundary_threshold):
# Look up the relation name; skip this triple if it cannot be resolved
relation_name = self.get_relation_name(triple['predicate']['uri'])
if relation_name is None:
continue  # skip unresolvable relations
target = {
'subject': triple['subject']['surfaceform'].strip(),
'predicate': relation_name,
'object': triple['object']['surfaceform'].strip()
}
valid_targets.append(target)
# If there are valid triples, add this sample to the results
if valid_targets:
processed_data.append({
'text': text,
'target': valid_targets
})
except Exception as e:
if idx <= 10:  # only print the first 10 errors
print(f"Error processing entry in {file_path} at index {idx}: {e}")
continue
except FileNotFoundError:
print(f"文件未找到: {file_path}")
except json.JSONDecodeError as e:
print(f"JSON解析错误 in {file_path}: {e}")
except Exception as e:
print(f"处理文件时出错 {file_path}: {e}")
print(f"{file_path} 提取了 {len(processed_data)} 个有效样本")
return processed_data
def process_folder(self, folder_path: str, confidence_threshold: float,
boundary_threshold: int) -> List[Dict[str, Any]]:
"""处理文件夹中的所有JSON文件"""
all_processed_data = []
if not os.path.exists(folder_path):
raise FileNotFoundError(f"Folder does not exist: {folder_path}")
# Collect all JSON files
json_files = [f for f in os.listdir(folder_path) if f.endswith('.json')]
if not json_files:
raise ValueError(f"No JSON files found in {folder_path}")
print(f"Found {len(json_files)} JSON files")
for filename in sorted(json_files):
file_path = os.path.join(folder_path, filename)
processed_data = self.process_single_file(file_path, confidence_threshold, boundary_threshold)
all_processed_data.extend(processed_data)
# Save the final relation cache
self.relation_manager.save_cache()
return all_processed_data
def generate_statistics(self, processed_data: List[Dict[str, Any]]) -> Dict[str, Any]:
"""生成数据统计信息"""
total_samples = len(processed_data)
total_triples = sum(len(sample['target']) for sample in processed_data)
# Count relation types
relation_counts = defaultdict(int)
for sample in processed_data:
for target in sample['target']:
relation_counts[target['predicate']] += 1
# Text length statistics
text_lengths = [len(sample['text']) for sample in processed_data]
avg_text_length = sum(text_lengths) / len(text_lengths) if text_lengths else 0
# Number of triples per text
triples_per_text = [len(sample['target']) for sample in processed_data]
avg_triples_per_text = sum(triples_per_text) / len(triples_per_text) if triples_per_text else 0
return {
'total_samples': total_samples,
'total_triples': total_triples,
'avg_text_length': round(avg_text_length, 2),
'avg_triples_per_text': round(avg_triples_per_text, 2),
'relation_distribution': dict(sorted(relation_counts.items(),
key=lambda x: x[1], reverse=True)),
'top_10_relations': dict(list(sorted(relation_counts.items(),
key=lambda x: x[1], reverse=True))[:10]),
'total_unique_relations': len(relation_counts),
'cached_relations': len(self.relation_manager.relations)
}
def main():
parser = argparse.ArgumentParser(description='Process the T-REx dataset with dynamic relation lookup')
parser.add_argument('--folder_path', type=str, default='/home/pci/ycz/Code/Minimind/dataset/trex', help='Path to the folder containing the JSON files')
parser.add_argument('--confidence_threshold', type=float, default=0.5,
help='Confidence threshold (default: 0.5)')
parser.add_argument('--boundary_threshold', type=int, default=0,
help='Boundary length threshold (default: 0, no filtering)')
parser.add_argument('--output', type=str, default='./processed_trex_data.json',
help='Output file name (default: processed_trex_data.json)')
parser.add_argument('--stats', type=str, default='trex_statistics.json',
help='Statistics output file name (default: trex_statistics.json)')
parser.add_argument('--cache_file', type=str, default='wikidata_relations_cache.pkl',
help='Relation cache file name (default: wikidata_relations_cache.pkl)')
parser.add_argument('--mapping_file', type=str, default="/home/pci/ycz/Code/Minimind/preprocessing/sample_property_mappings.json",
help='Path to the JSON mapping file (required; used for relation name mapping)')
args = parser.parse_args()
print("T-REx数据集处理器支持动态关系获取")
print("=" * 60)
print(f"输入文件夹: {args.folder_path}")
print(f"置信度阈值: {args.confidence_threshold}")
print(f"边界长度阈值: {args.boundary_threshold}")
print(f"输出文件: {args.output}")
print(f"关系缓存文件: {args.cache_file}")
print(f"JSON映射文件: {args.mapping_file if args.mapping_file else '未指定'}")
print("=" * 60)
# 检查映射文件是否存在
if not args.mapping_file or not os.path.exists(args.mapping_file):
print(f"错误: 映射文件不存在或未指定: {args.mapping_file}")
print("请确保提供有效的JSON映射文件。")
return 1
# Create the relation manager
relation_manager = WikidataRelationManager(
cache_file=args.cache_file,
mapping_file=args.mapping_file
)
# Create the processor
processor = TRexProcessor(relation_manager)
try:
# Process the data
processed_data = processor.process_folder(
args.folder_path,
args.confidence_threshold,
args.boundary_threshold
)
print(f"\n处理完成!总共处理了 {len(processed_data)} 个样本")
# 生成统计信息
stats = processor.generate_statistics(processed_data)
# 保存处理后的数据
with open(args.output, 'w', encoding='utf-8') as f:
json.dump(processed_data, f, ensure_ascii=False, indent=2)
# 保存统计信息
with open(args.stats, 'w', encoding='utf-8') as f:
json.dump(stats, f, ensure_ascii=False, indent=2)
print(f"\n数据已保存到: {args.output}")
print(f"统计信息已保存到: {args.stats}")
print(f"关系缓存已保存到: {args.cache_file}")
# 打印统计摘要
print("\n数据统计摘要:")
print("=" * 30)
print(f"总样本数: {stats['total_samples']}")
print(f"总三元组数: {stats['total_triples']}")
print(f"唯一关系数: {stats['total_unique_relations']}")
print(f"缓存关系数: {stats['cached_relations']}")
print(f"平均文本长度: {stats['avg_text_length']}")
print(f"平均每文本三元组数: {stats['avg_triples_per_text']}")
print("\n前10个最常见关系:")
for relation, count in stats['top_10_relations'].items():
print(f" {relation}: {count}")
except Exception as e:
print(f"处理过程中出错: {e}")
return 1
return 0
if __name__ == "__main__":
exit(main())
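For orientation, each record in the processed output pairs the source text with its surviving triples; a small inspection sketch, assuming the default --output filename above:

import json

with open("./processed_trex_data.json", "r", encoding="utf-8") as f:  # default output path from the script
    samples = json.load(f)
print(len(samples), "samples")
sample = samples[0]
print(sample["text"][:80])
for t in sample["target"][:3]:
    print(t["subject"], "->", t["predicate"], "->", t["object"])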

View File

@ -1,441 +0,0 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
import re
import asyncio
import aiofiles
from concurrent.futures import ThreadPoolExecutor
from preprocessing.agent_system.extractor_agent.agent import DepartmentAgent
from typing import Dict, List, Tuple
import gc
import time
import psutil
from tqdm.asyncio import tqdm as async_tqdm
from tqdm import tqdm
json_path = "dataset/merged_pretrain_extra.jsonl"
output_path = "dataset/processed_triples.jsonl"
# Tuned configuration - reduced resource consumption
BATCH_SIZE = 5000  # number of records per batch
MAX_CONCURRENT = 200  # maximum number of concurrent paragraph tasks
AGENT_POOL_SIZE = 20  # number of agent instances in the pool
def get_memory_usage():
"""获取当前内存使用情况"""
process = psutil.Process(os.getpid())
memory_info = process.memory_info()
memory_mb = memory_info.rss / 1024 / 1024
return memory_mb
def print_memory_info(stage=""):
"""打印内存使用信息"""
memory_mb = get_memory_usage()
print(f"🔧 {stage} - 内存使用: {memory_mb:.1f} MB")
# Create a pool of extractor agents to avoid concurrency conflicts
def create_extractor_pool(pool_size: int = 5):
"""Create the extractor agent pool."""
print(f"正在创建 {pool_size} 个agent实例...")
agents = []
for i in range(pool_size):
try:
agent = DepartmentAgent(model_type="deepseek")
agents.append(agent)
print(f" ✓ Agent {i+1}/{pool_size} 创建成功")
except Exception as e:
print(f" ✗ Agent {i+1} 创建失败: {e}")
print(f"Agent池创建完成实际创建了 {len(agents)} 个实例")
return agents
# The agent pool is initialized lazily
AGENT_POOL = None
agent_pool_index = 0
def get_agent_pool():
"""获取agent池延迟初始化"""
global AGENT_POOL
if AGENT_POOL is None:
print_memory_info("创建Agent池前")
AGENT_POOL = create_extractor_pool(pool_size=AGENT_POOL_SIZE)
print_memory_info("创建Agent池后")
return AGENT_POOL
def get_next_agent():
"""轮询获取下一个可用的agent"""
global agent_pool_index
pool = get_agent_pool()
agent = pool[agent_pool_index % len(pool)]
agent_pool_index += 1
return agent
def clean_and_split_text(text):
"""
去除文本开头结尾的标记并按句子分割
"""
# 去除开头的<|im_start|>和结尾的<|im_end|>
text = text.strip()
if text.startswith('<|im_start|>'):
text = text[len('<|im_start|>'):]
if text.endswith('<|im_end|>'):
text = text[:-len('<|im_end|>')]
# Clean the text and remove redundant whitespace
text = text.strip()
# Split into sentences (on periods, question marks, exclamation marks, etc.)
# Use a regular expression to match sentence-ending punctuation
sentence_endings = r'[.!?。!?]'
sentences = re.split(sentence_endings, text)
# Clean each sentence, dropping whitespace and empty sentences
cleaned_sentences = []
for sentence in sentences:
sentence = sentence.strip()
if sentence and len(sentence) > 5: # 只保留非空且有意义的句子
cleaned_sentences.append(sentence)
return cleaned_sentences
async def extract_triple_from_sentence_async(sentence: str, context: str = None) -> Dict:
"""
Asynchronously extract a triple from a sentence using an extractor agent.
"""
try:
# Get an agent instance
agent = get_next_agent()
result = await agent.async_run(sentence=sentence, context=context)
return {
"sentence": sentence,
"triple": {
"subject": result.triple.subject,
"predicate": result.triple.predicate,
"object": result.triple.object
},
"confidence": result.confidence
}
except Exception as e:
return {
"sentence": sentence,
"triple": {
"subject": "",
"predicate": "",
"object": ""
},
"confidence": 0.0,
"error": str(e)
}
async def process_paragraph_async(line_num: int, original_text: str, semaphore: asyncio.Semaphore) -> Dict:
"""
Asynchronously process a single paragraph.
"""
async with semaphore:  # limit concurrency
try:
# Clean and split the text
sentences = clean_and_split_text(original_text)
if not sentences:
return None
# 构建当前段落的结果
paragraph_result = {
"source_line": line_num,
"original_paragraph": original_text,
"sentences": [],
"triples": []
}
# 异步处理所有句子
tasks = []
for sentence in sentences:
task = extract_triple_from_sentence_async(sentence, context=original_text)
tasks.append(task)
# 等待所有句子处理完成
triple_results = await asyncio.gather(*tasks)
# 整理结果
for i, sentence in enumerate(sentences):
paragraph_result["sentences"].append(sentence)
paragraph_result["triples"].append(triple_results[i])
return paragraph_result
except Exception as e:
print(f"处理第 {line_num} 行时出错: {e}")
return None
async def process_batch_async(batch_data: List[Tuple[int, str]], batch_num: int) -> List[Dict]:
"""
Asynchronously process one batch of data (with a progress bar and memory monitoring).
"""
print(f"\n=== 异步处理批次 {batch_num} ===")
print(f"批次大小: {len(batch_data)} 条记录")
print_memory_info(f"批次 {batch_num} 开始前")
start_time = time.time()
# 创建信号量控制并发数量
semaphore = asyncio.Semaphore(MAX_CONCURRENT)
# 分块处理任务,避免一次性创建太多任务
chunk_size = 1000 # 每次处理1000个任务
all_results = []
for chunk_start in range(0, len(batch_data), chunk_size):
chunk_end = min(chunk_start + chunk_size, len(batch_data))
chunk_data = batch_data[chunk_start:chunk_end]
print(f"处理子块 {chunk_start//chunk_size + 1}/{(len(batch_data)-1)//chunk_size + 1} ({len(chunk_data)} 条记录)")
# 创建当前块的异步任务
tasks = []
for line_num, original_text in chunk_data:
task = process_paragraph_async(line_num, original_text, semaphore)
tasks.append(task)
# 使用进度条处理当前块
progress_bar = tqdm(total=len(tasks), desc=f"批次{batch_num}-块{chunk_start//chunk_size + 1}", unit="段落", ncols=100)
chunk_results = []
completed_tasks = 0
# 使用as_completed来获取完成的任务并更新进度条
for coro in asyncio.as_completed(tasks):
try:
result = await coro
chunk_results.append(result)
completed_tasks += 1
# 更新进度条
progress_bar.update(1)
# 每完成50个任务更新一次描述
if completed_tasks % 50 == 0:
valid_results = [r for r in chunk_results if r is not None]
progress_bar.set_postfix({
'有效': len(valid_results),
'完成': completed_tasks,
'成功率': f"{len(valid_results)/completed_tasks*100:.1f}%"
})
except Exception as e:
print(f"任务执行失败: {e}")
completed_tasks += 1
progress_bar.update(1)
progress_bar.close()
all_results.extend(chunk_results)
# 每个块完成后清理内存
del tasks, chunk_results
gc.collect()
print_memory_info(f"批次 {batch_num}{chunk_start//chunk_size + 1} 完成后")
# 过滤None结果
valid_results = [result for result in all_results if result is not None]
# 统计信息
batch_sentences = sum(len(result["sentences"]) for result in valid_results)
batch_triples = sum(
sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
for result in valid_results
)
end_time = time.time()
processing_time = end_time - start_time
print(f"批次 {batch_num} 异步处理完成:")
print(f" - 有效段落: {len(valid_results)}/{len(batch_data)} ({len(valid_results)/len(batch_data)*100:.1f}%)")
print(f" - 总句子数: {batch_sentences}")
print(f" - 成功三元组: {batch_triples}")
print(f" - 三元组成功率: {batch_triples/batch_sentences*100:.1f}%" if batch_sentences > 0 else "无句子")
print(f" - 处理时间: {processing_time:.2f}")
print(f" - 处理速度: {len(batch_data)/processing_time:.2f}段落/秒")
print_memory_info(f"批次 {batch_num} 完成后")
return valid_results
async def write_results_batch(results: List[Dict], output_path: str):
"""
Asynchronously write results in batches (with progress output).
"""
try:
print(f"开始批量写入 {len(results)} 条结果...")
# 准备写入内容
content_lines = []
for result in results:
content_lines.append(json.dumps(result, ensure_ascii=False))
# 异步批量写入
async with aiofiles.open(output_path, "a", encoding="utf-8") as f:
await f.write("\n".join(content_lines) + "\n")
print(f"✓ 成功批量写入 {len(results)} 条结果到 {output_path}")
except Exception as e:
print(f"✗ 批量写入失败: {e}")
print("尝试逐条写入...")
# 如果批量写入失败,回退到逐条写入(带进度条)
async with aiofiles.open(output_path, "a", encoding="utf-8") as f:
for result in tqdm(results, desc="逐条写入", unit=""):
await f.write(json.dumps(result, ensure_ascii=False) + "\n")
print(f"✓ 逐条写入完成")
# 主处理流程
async def main_async():
total_processed = 0
total_sentences = 0
total_triples = 0
batch_num = 0
print("=== 开始异步批次处理JSONL文件 ===")
print(f"优化后的配置信息:")
print(f" - 批次大小: {BATCH_SIZE:,} 条记录")
print(f" - 最大并发数: {MAX_CONCURRENT}")
print(f" - Agent池大小: {AGENT_POOL_SIZE}")
print(f" - 输入文件: {json_path}")
print(f" - 输出文件: {output_path}")
print()
print_memory_info("程序开始")
# 清空输出文件
async with aiofiles.open(output_path, "w", encoding="utf-8") as f:
pass
# 读取并处理数据
with open(json_path, "r", encoding="utf-8") as f_in:
batch_data = []
for line_num, line in enumerate(f_in):
if line.strip(): # 跳过空行
try:
item = json.loads(line)
original_text = item.get("text", "")
if original_text:
batch_data.append((line_num + 1, original_text))
# 当批次达到指定大小时,异步处理这个批次
if len(batch_data) >= BATCH_SIZE:
batch_num += 1
# 异步处理批次
batch_results = await process_batch_async(batch_data, batch_num)
# 批量写入结果
if batch_results:
await write_results_batch(batch_results, output_path)
# 统计信息
batch_sentences = sum(len(result["sentences"]) for result in batch_results)
batch_triples = sum(
sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
for result in batch_results
)
total_processed += len(batch_data)
total_sentences += batch_sentences
total_triples += batch_triples
print(f"\n📊 批次 {batch_num} 累计统计:")
print(f" - 累计处理段落: {total_processed:,}")
print(f" - 累计句子数: {total_sentences:,}")
print(f" - 累计三元组: {total_triples:,}")
print(f" - 整体成功率: {total_triples/total_sentences*100:.1f}%")
print("-" * 80)
# 清理批次数据,释放内存
batch_data.clear()
batch_results.clear()
gc.collect() # 强制垃圾回收
print_memory_info(f"批次 {batch_num} 清理后")
except json.JSONDecodeError as e:
print(f"{line_num + 1} 行JSON解析错误: {e}")
except Exception as e:
print(f"处理第 {line_num + 1} 行时出错: {e}")
# 处理最后一个不完整的批次
if batch_data:
batch_num += 1
batch_results = await process_batch_async(batch_data, batch_num)
if batch_results:
await write_results_batch(batch_results, output_path)
batch_sentences = sum(len(result["sentences"]) for result in batch_results)
batch_triples = sum(
sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
for result in batch_results
)
total_processed += len(batch_data)
total_sentences += batch_sentences
total_triples += batch_triples
# 最终统计
print(f"\n{'='*80}")
print(f"🎉 所有批次异步处理完成!")
print(f"{'='*80}")
print(f"最终统计:")
print(f" - 总批次数: {batch_num}")
print(f" - 总段落数: {total_processed:,}")
print(f" - 总句子数: {total_sentences:,}")
print(f" - 总三元组: {total_triples:,}")
print(f" - 整体成功率: {total_triples/total_sentences*100:.1f}%" if total_sentences > 0 else "无有效句子")
print(f" - 输出文件: {output_path}")
print(f"{'='*80}")
print_memory_info("程序结束前")
# 显示示例结果
await show_sample_results()
async def show_sample_results():
"""显示前几个处理结果作为示例"""
print("\n📋 前3个处理结果示例:")
try:
async with aiofiles.open(output_path, "r", encoding="utf-8") as f:
i = 0
async for line in f:
if i >= 3:
break
item = json.loads(line)
print(f"\n--- 段落 {i+1} (来源行: {item['source_line']}) ---")
print(f"原始段落: {item['original_paragraph'][:100]}...")
print(f"句子数量: {len(item['sentences'])}")
if item['triples']:
for j, triple in enumerate(item['triples'][:2]): # 只显示前2个三元组
print(f" 句子 {j+1}: {triple['sentence'][:50]}...")
if triple['confidence'] > 0:
print(f" 三元组: {triple['triple']['subject']} -> {triple['triple']['predicate']} -> {triple['triple']['object']}")
print(f" 置信度: {triple['confidence']:.2f}")
else:
print(f" 提取失败: {triple.get('error', '未知错误')}")
i += 1
except Exception as e:
print(f"读取示例结果时出错: {e}")
def main():
"""主入口函数"""
try:
# 运行异步主函数
asyncio.run(main_async())
except KeyboardInterrupt:
print("\n⚠️ 用户中断处理")
except Exception as e:
print(f"❌ 处理过程中出现错误: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@ -1,176 +0,0 @@
[project]
name = "minimind"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"accelerate==1.7.0",
"aiohappyeyeballs==2.6.1",
"aiohttp==3.11.17",
"aiosignal==1.3.2",
"altair==5.5.0",
"annotated-types==0.7.0",
"anyio==4.9.0",
"async-timeout==5.0.1",
"attrs==25.3.0",
"blinker==1.9.0",
"boto3==1.38.41",
"botocore==1.38.41",
"cachetools==5.5.2",
"certifi==2025.1.31",
"charset-normalizer==3.4.1",
"click==8.1.8",
"contourpy==1.3.2",
"cycler==0.12.1",
"datasets==2.21.0",
"datasketch==1.6.4",
"deepspeed==0.17.0",
"determined>=0.37.0",
"dill==0.3.8",
"distro==1.9.0",
"docker-pycreds==0.4.0",
"einops==0.8.1",
"exceptiongroup==1.2.2",
"filelock==3.18.0",
"Flask==3.0.3",
"Flask-Cors==4.0.0",
"fonttools==4.57.0",
"frozenlist==1.6.0",
"fsspec==2024.6.1",
"gitdb==4.0.12",
"GitPython==3.1.44",
"h11==0.14.0",
"hjson==3.1.0",
"httpcore==1.0.8",
"httpx==0.28.1",
"huggingface-hub==0.30.2",
"importlib_metadata==7.2.1",
"itsdangerous==2.2.0",
"jieba==0.42.1",
"Jinja2==3.1.2",
"jiter==0.9.0",
"jmespath==1.0.1",
"joblib==1.4.2",
"jsonlines==4.0.0",
"jsonpointer==2.1",
"jsonschema==4.23.0",
"jsonschema-specifications==2024.10.1",
"kiwisolver==1.4.8",
"langdetect==1.0.9",
"markdown-it-py==3.0.0",
"MarkupSafe==3.0.2",
"marshmallow==3.22.0",
"matplotlib==3.10.0",
"mdurl==0.1.2",
"modelscope==1.25.0",
"mpi4py>=4.0.3",
"mpmath==1.3.0",
"msgpack==1.1.0",
"multidict==6.4.3",
"multiprocess==0.70.16",
"narwhals==1.35.0",
"networkx==3.4.2",
"ngrok==1.4.0",
"ninja==1.11.1.4",
"nltk==3.8",
"numpy==1.26.4",
"nvidia-cublas-cu11==11.11.3.6",
"nvidia-cublas-cu12==12.6.4.1",
"nvidia-cuda-cupti-cu11==11.8.87",
"nvidia-cuda-cupti-cu12==12.6.80",
"nvidia-cuda-nvrtc-cu11==11.8.89",
"nvidia-cuda-nvrtc-cu12==12.6.77",
"nvidia-cuda-runtime-cu11==11.8.89",
"nvidia-cuda-runtime-cu12==12.6.77",
"nvidia-cudnn-cu11==9.1.0.70",
"nvidia-cudnn-cu12==9.5.1.17",
"nvidia-cufft-cu11==10.9.0.58",
"nvidia-cufft-cu12==11.3.0.4",
"nvidia-cufile-cu12==1.11.1.6",
"nvidia-curand-cu11==10.3.0.86",
"nvidia-curand-cu12==10.3.7.77",
"nvidia-cusolver-cu11==11.4.1.48",
"nvidia-cusolver-cu12==11.7.1.2",
"nvidia-cusparse-cu11==11.7.5.86",
"nvidia-cusparse-cu12==12.5.4.2",
"nvidia-cusparselt-cu12==0.6.3",
"nvidia-ml-py==12.575.51",
"nvidia-nccl-cu11==2.21.5",
"nvidia-nccl-cu12==2.26.2",
"nvidia-nvjitlink-cu12==12.6.85",
"nvidia-nvtx-cu11==11.8.86",
"nvidia-nvtx-cu12==12.6.77",
"openai==1.59.6",
"packaging==23.2",
"pandas>=2.0.0",
"peft==0.7.1",
"pillow==10.4.0",
"platformdirs==4.3.7",
"prettytable==3.16.0",
"propcache==0.3.1",
"protobuf==4.25.6",
"psutil==5.9.8",
"py-cpuinfo==9.0.0",
"pyarrow==19.0.1",
"pydantic==2.11.7",
"pydantic_core==2.33.2",
"pydeck==0.9.1",
"pyecharts==2.0.8",
"Pygments==2.19.1",
"pynvml==12.0.0",
"pyparsing==3.2.3",
"python-dateutil==2.9.0.post0",
"pytz==2025.2",
"PyYAML==6.0.2",
"referencing==0.36.2",
"regex==2024.11.6",
"requests==2.32.3",
"rich==13.7.1",
"rouge-score>=0.1.2",
"rpds-py==0.24.0",
"s3transfer==0.13.0",
"safetensors==0.5.3",
"scikit-learn==1.5.1",
"scipy==1.15.2",
"sentence-transformers==2.3.1",
"sentencepiece==0.2.0",
"sentry-sdk==2.26.1",
"setproctitle==1.3.5",
"simhash==2.1.2",
"simplejson==3.20.1",
"six==1.17.0",
"smmap==5.0.2",
"sniffio==1.3.1",
"streamlit==1.30.0",
"swankit==0.2.4",
"swanlab==0.6.4",
"sympy==1.13.3",
"tenacity==8.5.0",
"threadpoolctl==3.6.0",
"tiktoken>=0.8.0",
"tokenizers==0.21.1",
"toml==0.10.2",
"torch==2.7.1",
"torchaudio==2.7.1",
"torchvision==0.22.1",
"tornado==6.4.2",
"tqdm==4.67.1",
"transformers==4.52.4",
"triton==3.3.1",
"trl==0.13.0",
"typing-inspection==0.4.1",
"typing_extensions==4.13.2",
"tzlocal==5.3.1",
"ujson==5.1.0",
"urllib3==2.4.0",
"validators==0.34.0",
"wandb==0.18.3",
"watchdog==6.0.0",
"wcwidth==0.2.13",
"Werkzeug==3.1.3",
"wrapt==1.17.2",
"xxhash==3.5.0",
"yarl==1.20.0",
"zipp==3.21.0",
]

View File

@ -1,165 +1,30 @@
accelerate==1.7.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.17
aiosignal==1.3.2
altair==5.5.0
annotated-types==0.7.0
anyio==4.9.0
async-timeout==5.0.1
attrs==25.3.0
blinker==1.9.0
boto3==1.38.41
botocore==1.38.41
cachetools==5.5.2
certifi==2025.1.31
charset-normalizer==3.4.1
click==8.1.8
contourpy==1.3.2
cycler==0.12.1
datasets==2.21.0
datasketch==1.6.4
deepspeed==0.17.0
dill==0.3.8
distro==1.9.0
docker-pycreds==0.4.0
einops==0.8.1
exceptiongroup==1.2.2
filelock==3.18.0
Flask==3.0.3
Flask-Cors==4.0.0
fonttools==4.57.0
frozenlist==1.6.0
fsspec==2024.6.1
gitdb==4.0.12
GitPython==3.1.44
h11==0.14.0
hjson==3.1.0
httpcore==1.0.8
httpx==0.28.1
huggingface-hub==0.30.2
importlib_metadata==7.2.1
itsdangerous==2.2.0
Flask_Cors==4.0.0
jieba==0.42.1
Jinja2==3.1.2
jiter==0.9.0
jmespath==1.0.1
joblib==1.4.2
jsonlines==4.0.0
jsonpointer==2.1
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
kiwisolver==1.4.8
langdetect==1.0.9
markdown-it-py==3.0.0
MarkupSafe==3.0.2
marshmallow==3.22.0
matplotlib==3.10.0
mdurl==0.1.2
modelscope==1.25.0
mpmath==1.3.0
msgpack==1.1.0
multidict==6.4.3
multiprocess==0.70.16
narwhals==1.35.0
networkx==3.4.2
ngrok==1.4.0
ninja==1.11.1.4
nltk==3.8
numpy==1.26.4
nvidia-cublas-cu11==11.11.3.6
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu11==9.1.0.70
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu11==10.9.0.58
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu11==10.3.0.86
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu11==11.7.5.86
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-ml-py==12.575.51
nvidia-nccl-cu11==2.21.5
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu11==11.8.86
nvidia-nvtx-cu12==12.6.77
openai==1.59.6
packaging==23.2
pandas==1.5.3
peft==0.7.1
pillow==10.4.0
platformdirs==4.3.7
prettytable==3.16.0
propcache==0.3.1
protobuf==4.25.6
psutil==5.9.8
py-cpuinfo==9.0.0
pyarrow==19.0.1
pydantic==2.11.7
pydantic_core==2.33.2
pydeck==0.9.1
pyecharts==2.0.8
Pygments==2.19.1
pynvml==12.0.0
pyparsing==3.2.3
python-dateutil==2.9.0.post0
pytz==2025.2
PyYAML==6.0.2
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
pydantic==2.8.2
rich==13.7.1
rpds-py==0.24.0
s3transfer==0.13.0
safetensors==0.5.3
scikit-learn==1.5.1
scipy==1.15.2
sentence-transformers==2.3.1
sentencepiece==0.2.0
sentry-sdk==2.26.1
setproctitle==1.3.5
scikit_learn==1.5.1
sentence_transformers==2.3.1
simhash==2.1.2
simplejson==3.20.1
six==1.17.0
smmap==5.0.2
sniffio==1.3.1
streamlit==1.30.0
swankit==0.2.4
swanlab==0.6.4
sympy==1.13.3
tenacity==8.5.0
threadpoolctl==3.6.0
tiktoken==0.5.1
tokenizers==0.21.1
toml==0.10.2
torch==2.7.1
torchaudio==2.7.1
torchvision==0.22.1
tornado==6.4.2
tqdm==4.67.1
transformers==4.52.4
triton==3.3.1
transformers==4.48.0
jinja2==3.1.2
jsonlines==4.0.0
trl==0.13.0
typing-inspection==0.4.1
typing_extensions==4.13.2
tzlocal==5.3.1
ujson==5.1.0
urllib3==2.4.0
validators==0.34.0
wandb==0.18.3
watchdog==6.0.0
wcwidth==0.2.13
Werkzeug==3.1.3
wrapt==1.17.2
xxhash==3.5.0
yarl==1.20.0
zipp==3.21.0
streamlit==1.30.0
torch==2.2.2
torchvision==0.17.2

View File

@ -1,52 +0,0 @@
#!/bin/bash
# Activate the conda environment
# source $(conda info --base)/etc/profile.d/conda.sh
# conda activate mini
# Set environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
# Option 1: use a pre-configured accelerate config file
# accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
# --epochs 3 \
# --batch_size 24 \
# --learning_rate 2e-4 \
# --dtype bfloat16 \
# --accumulation_steps 32 \
# --grad_clip 1.0 \
# --log_interval 100 \
# --save_interval 10000 \
# --dim 1024 \
# --n_layers 32 \
# --max_seq_len 1024 \
# --use_flash_attn \
# --profile \
# --profile_interval 10
# Option 2: configure accelerate directly via command-line arguments
# Memory-leak debugging configuration - reduced memory usage
CUDA_VISIBLE_DEVICES=0 uv run -p .venv python -m accelerate.commands.launch \
--num_processes=1 \
--mixed_precision=bf16 \
--main_process_port=29500 \
train_pretrain_accelerate.py
# --batch_size 128 \
# --num_workers 1
# --knowledge_num 48020 \
# --num_workers 1 \
# --epochs 4 \
# --learning_rate 2e-4 \
# --dtype bfloat16 \
# --accumulation_steps 32 \
# --grad_clip 1.0 \
# --log_interval 50 \
# --save_interval 10000 \
# --dim 512 \
# --n_layers 8 \
# --max_seq_len 512 \
# --use_flash_attn \
# --profile \
# --profile_interval 10

View File

@ -1,50 +0,0 @@
#!/bin/bash
# Activate the conda environment
source $(conda info --base)/etc/profile.d/conda.sh
conda activate ycz_accelerate
# Set environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
# Option 1: use a pre-configured accelerate config file
# accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
# --epochs 3 \
# --batch_size 24 \
# --learning_rate 2e-4 \
# --dtype bfloat16 \
# --accumulation_steps 32 \
# --grad_clip 1.0 \
# --log_interval 100 \
# --save_interval 10000 \
# --dim 1024 \
# --n_layers 32 \
# --max_seq_len 1024 \
# --use_flash_attn \
# --profile \
# --profile_interval 10
# Option 2: configure accelerate directly via command-line arguments
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--multi_gpu \
--num_processes=4 \
--mixed_precision=bf16 \
--main_process_port=29500 \
train_pretrain_accelerate.py \
--epochs 3 \
--batch_size 24 \
--learning_rate 2e-4 \
--dtype bfloat16 \
--accumulation_steps 32 \
--grad_clip 1.0 \
--log_interval 100 \
--save_interval 10000 \
--dim 1024 \
--n_layers 32 \
--max_seq_len 1024 \
--use_flash_attn \
--profile \
--profile_interval 10 \
--knowledge_num 16384 \
--knowledge_length 64

View File

@ -1,46 +0,0 @@
#!/bin/bash
# Activate the conda environment
source $(conda info --base)/etc/profile.d/conda.sh
conda activate ycz_accelerate
# Set environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
# Experiment 1.3.0 - configure accelerate directly via command-line arguments
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--multi_gpu \
--num_processes=4 \
--mixed_precision=bf16 \
--main_process_port=29500 \
train_pretrain_accelerate.py \
--out_dir "out" \
--epochs 3 \
--embedding_epoch 2 \
--batch_size 64 \
--learning_rate 8e-5 \
--dtype bfloat16 \
--use_swanlab \
--swanlab_project "MiniMind-Pretrain" \
--num_workers 1 \
--accumulation_steps 16 \
--grad_clip 0.5 \
--warmup_iters 0 \
--log_interval 100 \
--save_interval 10000 \
--dim 1024 \
--n_layers 48 \
--max_seq_len 512 \
--data_path "./dataset/stable/merged_pretrain.jsonl" \
--profile \
--profile_interval 10 \
--use_flash_attn \
--knowledge_num 1048576 \
--knowledge_length 32 \
--database_init_path "./dataset/stable/sentence_trex_data.json" \
--fast_clustering \
--cluster_cache_path "./cache/cluster_tokens_single.pt" \
--memory_monitor_interval 10 \
--model_type "model_original" \
--model_size 538

View File

@ -1,47 +0,0 @@
#!/bin/bash
# Activate the conda environment
source $(conda info --base)/etc/profile.d/conda.sh
conda activate ycz_accelerate
# Set environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
# Experiment 1.3.0 - configure accelerate directly via command-line arguments
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--multi_gpu \
--num_processes=4 \
--mixed_precision=bf16 \
--main_process_port=29500 \
train_pretrain_accelerate.py \
--out_dir "out" \
--epochs 3 \
--embedding_epoch 2 \
--batch_size 48 \
--learning_rate 2e-4 \
--dtype bfloat16 \
--use_swanlab \
--swanlab_project "MiniMind-Pretrain" \
--num_workers 1 \
--accumulation_steps 32 \
--grad_clip 1.0 \
--warmup_iters 0 \
--log_interval 100 \
--save_interval 10000 \
--dim 1024 \
--n_layers 18 \
--max_seq_len 512 \
--use_moe False \
--data_path "./dataset/stable/merged_pretrain.jsonl" \
--profile \
--profile_interval 10 \
--use_flash_attn \
--knowledge_num 1048576 \
--knowledge_length 32 \
--database_init_path "./dataset/stable/sentence_trex_data.json" \
--fast_clustering \
--cluster_cache_path "./cache/cluster_tokens_single.pt" \
--memory_monitor_interval 10 \
--model_type "model_no_feed" \
--model_size 814.724

View File

@ -1,47 +0,0 @@
#!/bin/bash
# Activate the conda environment
source $(conda info --base)/etc/profile.d/conda.sh
conda activate ycz_accelerate
# Set environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
# Experiment 1.3.0 - configure accelerate directly via command-line arguments
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--multi_gpu \
--num_processes=4 \
--mixed_precision=bf16 \
--main_process_port=29500 \
train_pretrain_accelerate.py \
--out_dir "out" \
--epochs 3 \
--embedding_epoch 2 \
--batch_size 48 \
--learning_rate 2e-4 \
--dtype bfloat16 \
--use_swanlab \
--swanlab_project "MiniMind-Pretrain" \
--num_workers 1 \
--accumulation_steps 32 \
--grad_clip 1.0 \
--warmup_iters 0 \
--log_interval 100 \
--save_interval 10000 \
--dim 1024 \
--n_layers 18 \
--max_seq_len 512 \
--use_moe False \
--data_path "./dataset/stable/merged_pretrain.jsonl" \
--profile \
--profile_interval 10 \
--use_flash_attn \
--knowledge_num 1048576 \
--knowledge_length 32 \
--database_init_path "./dataset/stable/sentence_trex_data.json" \
--fast_clustering \
--cluster_cache_path "./cache/cluster_tokens_single.pt" \
--memory_monitor_interval 10 \
--model_type "model" \
--model_size 814.724

View File

@ -1,45 +0,0 @@
#!/bin/bash
# Activate the conda environment
# source $(conda info --base)/etc/profile.d/conda.sh
# conda activate ycz_accelerate
# Set environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
# Experiment 1.3.0 - configure accelerate directly via command-line arguments
CUDA_VISIBLE_DEVICES=0 accelerate launch \
--num_processes=1 \
--mixed_precision=bf16 \
--main_process_port=29500 \
train_pretrain_accelerate.py \
--out_dir "out" \
--epochs 3 \
--embedding_epoch 2 \
--batch_size 128 \
--learning_rate 8e-5 \
--dtype bfloat16 \
--use_swanlab \
--swanlab_project "MiniMind-Pretrain" \
--num_workers 1 \
--accumulation_steps 16 \
--grad_clip 0.5 \
--warmup_iters 0 \
--log_interval 100 \
--save_interval 10000 \
--dim 512 \
--n_layers 8 \
--max_seq_len 512 \
--data_path "./dataset/stable/merged_pretrain.jsonl" \
--profile \
--profile_interval 10 \
--use_flash_attn \
--knowledge_num 1048576 \
--knowledge_length 32 \
--database_init_path "./dataset/stable/sentence_trex_data.json" \
--fast_clustering \
--cluster_cache_path "./cache/cluster_tokens_single.pt" \
--memory_monitor_interval 10 \
--model_type "model" \
--model_size 538
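The experiment scripts above pass --memory_monitor_interval to train_pretrain_accelerate.py, but the code that consumes it is not part of the hunks shown here. The following is only an illustrative sketch of what periodic GPU-memory logging at such an interval could look like; the function name and log format are assumptions, not the project's actual implementation.
# Illustrative sketch only: the real --memory_monitor_interval handling lives in
# train_pretrain_accelerate.py and is not shown in this diff.
import torch

def log_gpu_memory(step: int, interval: int = 10) -> None:
    # Log allocated/reserved CUDA memory every `interval` steps.
    if torch.cuda.is_available() and step % interval == 0:
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        print(f"step {step}: allocated {allocated:.2f} GiB, reserved {reserved:.2f} GiB")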

View File

@ -32,7 +32,7 @@ def train_tokenizer():
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
# Define the special tokens
special_tokens = ["<unk>", "<|im_start|>", "<|im_end|>"]
special_tokens = ["<unk>", "<s>", "</s>"]
# Set up the trainer and add the special tokens
trainer = trainers.BpeTrainer(
@ -53,8 +53,8 @@ def train_tokenizer():
# Check the indices of the special tokens
assert tokenizer.token_to_id("<unk>") == 0
assert tokenizer.token_to_id("<|im_start|>") == 1
assert tokenizer.token_to_id("<|im_end|>") == 2
assert tokenizer.token_to_id("<s>") == 1
assert tokenizer.token_to_id("</s>") == 2
# Save the tokenizer
tokenizer_dir = "../model/minimind_tokenizer"
@ -77,7 +77,7 @@ def train_tokenizer():
"special": True
},
"1": {
"content": "<|im_start|>",
"content": "<s>",
"lstrip": False,
"normalized": False,
"rstrip": False,
@ -85,7 +85,7 @@ def train_tokenizer():
"special": True
},
"2": {
"content": "<|im_end|>",
"content": "</s>",
"lstrip": False,
"normalized": False,
"rstrip": False,
@ -94,9 +94,9 @@ def train_tokenizer():
}
},
"additional_special_tokens": [],
"bos_token": "<|im_start|>",
"bos_token": "<s>",
"clean_up_tokenization_spaces": False,
"eos_token": "<|im_end|>",
"eos_token": "</s>",
"legacy": True,
"model_max_length": 32768,
"pad_token": "<unk>",
@ -104,7 +104,7 @@ def train_tokenizer():
"spaces_between_special_tokens": False,
"tokenizer_class": "PreTrainedTokenizerFast",
"unk_token": "<unk>",
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\n你是 MiniMind是一个有用的人工智能助手。<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<s>system\\n' + system_message + '</s>\\n' }}{% else %}{{ '<s>system\\n你是 MiniMind是一个有用的人工智能助手。</s>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
}
# Save the config file
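The hunk above switches the special tokens on the old/SLM branch from <|im_start|>/<|im_end|> to <s>/</s> in both the trainer asserts and the saved tokenizer config. A minimal sketch (not part of the diff; the save path mirrors the tokenizer_dir above, everything else is an assumption) for sanity-checking the saved tokenizer could look like:
# Sketch: verify the retagged special tokens and the chat template after saving.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("../model/minimind_tokenizer")
assert tokenizer.convert_tokens_to_ids("<unk>") == 0
assert tokenizer.convert_tokens_to_ids("<s>") == 1
assert tokenizer.convert_tokens_to_ids("</s>") == 2
# Render the chat template to confirm <s>/</s> wrap each turn.
messages = [{"role": "user", "content": "你好"}]
print(tokenizer.apply_chat_template(messages, tokenize=False))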

View File

@ -1,33 +0,0 @@
#!/bin/bash
set -e
# After the container starts, first install all dependencies from requirements.txt
# pip install -r requirements.txt
# bash install.sh -y
python3 -m pip install --upgrade pip
pip install uv -i https://pypi.tuna.tsinghua.edu.cn/simple
# Change to the project directory
cd /ycz/Minimind
# Check and repair the virtual environment
if [ ! -f .venv/bin/python ] || [ ! -x .venv/bin/python ]; then
echo "Virtual environment is broken or missing, recreating with uv..."
rm -rf .venv
uv venv .venv
fi
# Do not activate the virtual environment manually; let uv manage it automatically
# . ./.venv/bin/activate
# Sync dependencies with uv
uv sync
# After installation completes, run the main training script
# "$@" forwards the arguments defined by the entrypoint in experiment.yaml to the python script
CUDA_VISIBLE_DEVICES=0 uv run python -m accelerate.commands.launch \
--num_processes=1 \
--mixed_precision=bf16 \
--main_process_port=29500 \
train_pretrain_accelerate.py "$@"

File diff suppressed because it is too large

View File

@ -13,7 +13,6 @@ from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, DistributedSampler
# Communication-profiling imports removed
from contextlib import nullcontext
from typing import Optional
@ -41,69 +40,19 @@ def get_lr(current_step, total_steps, lr):
def train_epoch(epoch, wandb):
loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
# Define moe_path at the start of the function to avoid referencing an undefined variable in exception handling
moe_path = '_moe' if lm_config.use_moe else ''
# Add CUDA events to profile performance
if args.profile and (not ddp or dist.get_rank() == 0):
data_start = torch.cuda.Event(enable_timing=True)
data_end = torch.cuda.Event(enable_timing=True)
forward_start = torch.cuda.Event(enable_timing=True)
forward_end = torch.cuda.Event(enable_timing=True)
backward_start = torch.cuda.Event(enable_timing=True)
backward_end = torch.cuda.Event(enable_timing=True)
optimizer_start = torch.cuda.Event(enable_timing=True)
optimizer_end = torch.cuda.Event(enable_timing=True)
# CUDA graph optimization code removed
# Prefetch data
prefetch_factor = 2 # number of batches to prefetch
data_iter = iter(train_loader)
prefetch_batches = []
# Prefetch the initial batches
for _ in range(min(prefetch_factor, len(train_loader))):
for step, (X, Y, loss_mask) in enumerate(train_loader):
try:
batch = next(data_iter)
prefetch_batches.append([t.to(args.device, non_blocking=True) for t in batch])
except StopIteration:
break
for step in range(len(train_loader)):
try:
# Time the data loading
if args.profile and (not ddp or dist.get_rank() == 0):
data_start.record()
# Use the prefetched data
if prefetch_batches:
X, Y, loss_mask = prefetch_batches.pop(0)
else:
# If the prefetch queue is empty, load directly
X, Y, loss_mask = [t.to(args.device) for t in next(data_iter)]
# Asynchronously prefetch the next batch
if step + prefetch_factor < len(train_loader):
try:
batch = next(data_iter)
prefetch_batches.append([t.to(args.device, non_blocking=True) for t in batch])
except StopIteration:
pass
if args.profile and (not ddp or dist.get_rank() == 0):
data_end.record()
# Move the data to the device
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
# Update the learning rate
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# Time the forward pass
if args.profile and (not ddp or dist.get_rank() == 0):
forward_start.record()
# Regular forward pass
with ctx:
res = model(X)
loss = loss_fct(
@ -127,13 +76,6 @@ def train_epoch(epoch, wandb):
# If something goes wrong, skip adding the auxiliary loss
loss = loss / args.accumulation_steps
# Backward pass
scaler.scale(loss).backward()
if args.profile and (not ddp or dist.get_rank() == 0):
forward_end.record()
backward_start.record()
# Print data types for debugging
if step == 0 and (not ddp or dist.get_rank() == 0): # Print only for the first step of the first epoch on the main process
Logger("---- Data Type Check ----")
@ -146,21 +88,9 @@ def train_epoch(epoch, wandb):
Logger(f"loss.dtype: {loss.dtype}")
Logger("-------------------------")
if args.profile and (not ddp or dist.get_rank() == 0):
backward_end.record()
scaler.scale(loss).backward()
# Profile on every step, not only when gradient accumulation completes
if (step + 1) % args.profile_interval == 0:
# Record optimizer time (if this is a gradient-accumulation step)
if (step + 1) % args.accumulation_steps == 0:
optimizer_start.record()
# Optimizer step
if (step + 1) % args.accumulation_steps == 0:
if args.profile and (not ddp or dist.get_rank() == 0):
if (step + 1) % args.profile_interval != 0:
optimizer_start.record()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
@ -169,40 +99,6 @@ def train_epoch(epoch, wandb):
optimizer.zero_grad(set_to_none=True)
if args.profile and (not ddp or dist.get_rank() == 0):
optimizer_end.record()
# Profiling output every profile_interval steps
if args.profile and (not ddp or dist.get_rank() == 0) and (step + 1) % args.profile_interval == 0:
# Synchronize CUDA events to get accurate timings
torch.cuda.synchronize()
# Compute the time spent in each stage
data_time = data_start.elapsed_time(data_end)
forward_time = forward_start.elapsed_time(forward_end)
backward_time = backward_start.elapsed_time(backward_end)
# Optimizer time only exists when a gradient-accumulation step completes
if (step + 1) % args.accumulation_steps == 0:
optimizer_time = optimizer_start.elapsed_time(optimizer_end)
total_compute_time = forward_time + backward_time + optimizer_time
Logger(f"性能分析 - 步骤 {step+1}:")
Logger(f" 数据加载时间: {data_time:.2f} ms")
Logger(f" 前向传播时间: {forward_time:.2f} ms")
Logger(f" 反向传播时间: {backward_time:.2f} ms")
Logger(f" 优化器时间: {optimizer_time:.2f} ms")
Logger(f" 总计算时间: {total_compute_time:.2f} ms")
Logger(f" 计算/数据比例: {total_compute_time / data_time:.2f}")
else:
# Not a gradient-accumulation step, so there is no optimizer time
total_compute_time = forward_time + backward_time
Logger(f"性能分析 - 步骤 {step+1} (梯度累积中):")
Logger(f" 数据加载时间: {data_time:.2f} ms")
Logger(f" 前向传播时间: {forward_time:.2f} ms")
Logger(f" 反向传播时间: {backward_time:.2f} ms")
Logger(f" 总计算时间: {total_compute_time:.2f} ms")
Logger(f" 计算/数据比例: {total_compute_time / data_time:.2f}")
# Print logs
if step % args.log_interval == 0:
spend_time = time.time() - start_time
@ -217,44 +113,14 @@ def train_epoch(epoch, wandb):
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
log_dict = {
"loss": loss.item() * args.accumulation_steps,
"lr": optimizer.param_groups[-1]['lr'],
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
}
# If profiling is enabled, also log the performance metrics
if args.profile and (step + 1) % args.profile_interval == 0:
# Basic performance metrics
perf_dict = {
"data_time_ms": data_time,
"forward_time_ms": forward_time,
"backward_time_ms": backward_time
}
# Optimizer time only exists when a gradient-accumulation step completes
if (step + 1) % args.accumulation_steps == 0:
total_compute_time = forward_time + backward_time + optimizer_time
perf_dict.update({
"optimizer_time_ms": optimizer_time,
"compute_time_ms": total_compute_time
})
else:
total_compute_time = forward_time + backward_time
perf_dict.update({
"compute_time_ms": total_compute_time
})
log_dict.update(perf_dict)
wandb.log(log_dict)
# Communication-profiling code removed
wandb.log({"loss": loss.item() * args.accumulation_steps,
"lr": optimizer.param_groups[-1]['lr'],
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
# Save the model
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
model.eval()
# Use the moe_path variable defined at the start of the function
# moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
@ -270,7 +136,7 @@ def train_epoch(epoch, wandb):
save_path = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}_nanERROR.pth'
if os.path.exists(save_path):
os.remove(save_path)
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
@ -280,18 +146,18 @@ def train_epoch(epoch, wandb):
for name, param in model.named_parameters():
if param.grad is not None and torch.isnan(param.grad).any():
print(f"NaN gradient in parameter: {name}")
for name, param in model.named_parameters():
if param.grad is not None and torch.isnan(param.grad).any():
print(f"Parameter {name} values: {param.data}")
print(f"Parameter {name} gradients: {param.grad}")
raise ValueError("NaN gradient detected")
def init_model(lm_config, pretrained_embedding_path: Optional[str] = None):
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained('/mnt/lzn/Minimind/Minimind/model/minimind_tokenizer')
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
# Load the model
model = MiniMindLM(lm_config).to(args.device)
@ -309,9 +175,6 @@ def init_model(lm_config, pretrained_embedding_path: Optional[str] = None):
return model, tokenizer
# Communication-profiling functions removed
def init_distributed_mode():
if not ddp: return # If distributed data parallel (DDP) is not enabled, return immediately and do nothing.
global ddp_local_rank, DEVICE # Declare these two variables as globals so they can also be accessed outside the function.
@ -330,42 +193,35 @@ if __name__ == "__main__":
parser.add_argument("--out_dir", type=str, default="out")
# To reach "zero" as quickly as possible, set epochs to 1; otherwise make use of the limited data and train for 2-6 epochs.
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=24)
parser.add_argument("--learning_rate", type=float, default=2e-4)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--learning_rate", type=float, default=5e-4)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") #如果GPU可用则使用GPU否则使用CPU。
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", default=True, action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
parser.add_argument("--num_workers", type=int, default=48)
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--ddp", action="store_true")
parser.add_argument("--accumulation_steps", type=int, default=32) #梯度累积步数,用于控制梯度更新频率。
parser.add_argument("--accumulation_steps", type=int, default=8) #梯度累积步数,用于控制梯度更新频率。
parser.add_argument("--grad_clip", type=float, default=1.0) #梯度裁剪阈值,用于防止梯度爆炸。
parser.add_argument("--warmup_iters", type=int, default=0) #预热迭代次数,用于控制学习率预热过程。
parser.add_argument("--log_interval", type=int, default=100) #日志打印间隔,用于控制日志打印的频率。
parser.add_argument("--save_interval", type=int, default=10000) #模型保存间隔,用于控制模型保存的频率。
parser.add_argument("--save_interval", type=int, default=100) #模型保存间隔,用于控制模型保存的频率。
parser.add_argument('--local_rank', type=int, default=-1) #本地进程编号,用于分布式训练。
parser.add_argument('--dim', default=1024, type=int) #模型维度,用于控制模型的大小。
parser.add_argument('--n_layers', default=32, type=int) #层数,用于控制模型层数。
parser.add_argument('--max_seq_len', default=1024, type=int) #最大序列长度,用于控制输入序列的最大长度。
parser.add_argument('--dim', default=768, type=int) #模型维度,用于控制模型的大小。
parser.add_argument('--n_layers', default=8, type=int) #层数,用于控制模型层数。
parser.add_argument('--max_seq_len', default=512, type=int) #最大序列长度,用于控制输入序列的最大长度。
parser.add_argument('--use_moe', default=False, type=bool) #是否使用MOE用于控制是否使用MOE。
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能使用固定值1e-4替代") #禁用数据库功能,启用特殊模式
parser.add_argument("--data_path", type=str, default="/mnt/lzn/Minimind/dataset/dir/pretrain_hq.jsonl") #数据路径,用于控制数据集的路径。
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl") #数据路径,用于控制数据集的路径。
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
# Profiling-related arguments
parser.add_argument("--profile", action="store_true", default=True, help="Enable profiling")
parser.add_argument("--profile_interval", type=int, default=10, help="Profiling print interval (in steps)")
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="Enable FlashAttention")
args = parser.parse_args()
print(args)
lm_config = LMConfig(
dim=args.dim,
n_layers=args.n_layers,
max_seq_len=args.max_seq_len,
dim=args.dim,
n_layers=args.n_layers,
max_seq_len=args.max_seq_len,
use_moe=args.use_moe,
disable_db=args.disable_db, # pass the disable-database flag
flash_attn=args.use_flash_attn # enable FlashAttention support
disable_db=args.disable_db # pass the disable-database flag
) # Create the LMConfig object that controls the model configuration.
args.save_dir = os.path.join(args.out_dir) # Set the save directory.
os.makedirs(args.save_dir, exist_ok=True) # Create the save directory.
@ -398,42 +254,35 @@ if __name__ == "__main__":
if args.use_wandb and (not ddp or ddp_local_rank == 0):
import wandb
# Merge args and lm_config parameters for wandb config
config = vars(args).copy()
config.update(lm_config.__dict__)
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config)
else:
wandb = None
model, tokenizer = init_model(lm_config, args.pretrained_embedding_path)
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_sampler = DistributedSampler(train_ds) if ddp else None
# Optimized DataLoader configuration
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=True,
pin_memory_device=f"cuda:{ddp_local_rank}" if ddp else "cuda:0", # 指定pin_memory设备
drop_last=False,
shuffle=False,
num_workers=args.num_workers,
sampler=train_sampler,
persistent_workers=True if args.num_workers > 0 else False, # keep worker processes alive
prefetch_factor=2 if args.num_workers > 0 else None # prefetch factor
sampler=train_sampler
)
# Enable GradScaler only when using float16; bfloat16 does not need it
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16']))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
if ddp:
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
# Keep find_unused_parameters=True because the model really does have unused parameters
model = DistributedDataParallel(model, device_ids=[ddp_local_rank], find_unused_parameters=True)
# Keep set_detect_anomaly for now to aid debugging
# Once training is stable, this line can be commented out to speed things up
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
torch.autograd.set_detect_anomaly(True)
iter_per_epoch = len(train_loader)
for epoch in range(args.epochs):
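Taken together, the hunks above remove the CUDA-event profiling, batch prefetching, and per-step performance logging from the master branch, leaving a plain gradient-accumulation loop on the old/SLM branch. The following is a self-contained sketch of that remaining pattern; the model interface (res.logits), loader, and hyperparameter values are placeholders rather than the project's exact code.
# Sketch of the gradient-accumulation step kept on the old/SLM branch.
import torch
from torch import nn, optim

def run_epoch(model, train_loader, device="cuda", accumulation_steps=8, grad_clip=1.0, use_fp16=False):
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    optimizer = optim.AdamW(model.parameters(), lr=5e-4)
    # GradScaler is only needed for float16; bfloat16 runs without it.
    scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
    for step, (X, Y, loss_mask) in enumerate(train_loader):
        X, Y, loss_mask = X.to(device), Y.to(device), loss_mask.to(device)
        res = model(X)  # assumes the model output exposes .logits
        # Token-level loss, masked, averaged, then scaled for accumulation.
        loss = loss_fct(res.logits.view(-1, res.logits.size(-1)), Y.view(-1))
        loss = (loss * loss_mask.view(-1)).sum() / loss_mask.sum()
        loss = loss / accumulation_steps
        scaler.scale(loss).backward()
        if (step + 1) % accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)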

File diff suppressed because it is too large

4835
uv.lock generated

File diff suppressed because it is too large