Compare commits
No commits in common. "master" and "old/HPC" have entirely different histories.
.gitignore (vendored, 9 changed lines)
@@ -2,11 +2,4 @@
 /dataset
 /out
 wandb/
-**/*.log
-models/sentence_transformers/
-models/sentence_transformers_cache/
-**/*.pyc
-qwen2-1.7B/
-images/
-cache/
-.venv/
+**/*.log
.vscode/launch.json (vendored, deleted, 124 lines)
@@ -1,124 +0,0 @@
{
    "version": "0.2.0",
    "configurations": [
        {
            "name": "MiniMind Training (Direct Python)",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/train_pretrain_accelerate.py",
            "args": [
                "--out_dir", "out",
                "--epochs", "3",
                "--embedding_epoch", "2",
                "--batch_size", "128",
                "--learning_rate", "8e-5",
                "--dtype", "bfloat16",
                "--use_swanlab",
                "--swanlab_project", "MiniMind-Pretrain",
                "--num_workers", "1",
                "--accumulation_steps", "16",
                "--grad_clip", "0.5",
                "--warmup_iters", "0",
                "--log_interval", "1",
                "--save_interval", "10000",
                "--dim", "512",
                "--n_layers", "8",
                "--max_seq_len", "512",
                "--data_path", "./dataset/stable/merged_pretrain.jsonl",
                "--profile",
                "--profile_interval", "10",
                "--use_flash_attn",
                "--knowledge_num", "1048576",
                "--knowledge_length", "32",
                "--database_init_path", "./dataset/stable/sentence_trex_data.json",
                "--fast_clustering",
                "--cluster_cache_path", "./cache/cluster_tokens_single.pt",
                "--memory_monitor_interval", "10",
                "--model_type", "model",
                "--model_size", "538"
            ],
            "env": {
                "CUDA_VISIBLE_DEVICES": "0",
                "NCCL_DEBUG": "INFO",
                "PYTHONFAULTHANDLER": "1"
            },
            "cwd": "${workspaceFolder}",
            "console": "integratedTerminal",
            "justMyCode": false,
            "stopOnEntry": false,
            "python": "${workspaceFolder}/.venv/bin/python"
        },
        {
            "name": "MiniMind Training (Direct Python - Simple)",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/train_pretrain_accelerate.py",
            "args": [
                "--epochs", "1",
                "--batch_size", "32",
                "--learning_rate", "1e-4",
                "--log_interval", "10",
                "--profile_interval", "2",
                "--model_type", "model_original"
            ],
            "env": {
                "CUDA_VISIBLE_DEVICES": "0"
            },
            "cwd": "${workspaceFolder}",
            "console": "integratedTerminal",
            "justMyCode": false,
            "stopOnEntry": false,
            "python": "${workspaceFolder}/.venv/bin/python"
        },
        {
            "name": "MiniMind Test (Direct Python)",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/test.py",
            "env": {
                "CUDA_VISIBLE_DEVICES": "0"
            },
            "cwd": "${workspaceFolder}",
            "console": "integratedTerminal",
            "justMyCode": false,
            "python": "${workspaceFolder}/.venv/bin/python"
        },
        {
            "name": "MiniMind Training Debug (Accelerate)",
            "type": "python",
            "request": "launch",
            "module": "accelerate.commands.launch",
            "args": [
                "--num_processes=1",
                "--mixed_precision=bf16",
                "${workspaceFolder}/train_pretrain_accelerate.py",
                "--epochs", "1",
                "--batch_size", "32",
                "--learning_rate", "1e-4",
                "--log_interval", "10",
                "--profile_interval", "2",
                "--model_type", "model_original"
            ],
            "env": {
                "CUDA_VISIBLE_DEVICES": "0"
            },
            "cwd": "${workspaceFolder}",
            "console": "integratedTerminal",
            "justMyCode": false,
            "stopOnEntry": false,
            "python": "${workspaceFolder}/.venv/bin/python"
        },
        {
            "name": "MiniMind Test Only",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/test.py",
            "env": {
                "CUDA_VISIBLE_DEVICES": "0"
            },
            "cwd": "${workspaceFolder}",
            "console": "integratedTerminal",
            "justMyCode": false
        }
    ]
}
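Editor's note: the "MiniMind Training Debug (Accelerate)" configuration above starts training through accelerate's module entry point rather than a console script. A minimal sketch of the same invocation outside VS Code, assuming accelerate is installed and the training script exists at the repository root (argument values copied from the deleted configuration; this is an illustration, not part of the repo):

import runpy
import sys

# Equivalent to: python -m accelerate.commands.launch --num_processes=1 ...
sys.argv = [
    "accelerate-launch",
    "--num_processes=1",
    "--mixed_precision=bf16",
    "train_pretrain_accelerate.py",
    "--epochs", "1",
    "--batch_size", "32",
    "--learning_rate", "1e-4",
    "--log_interval", "10",
    "--profile_interval", "2",
    "--model_type", "model_original",
]
runpy.run_module("accelerate.commands.launch", run_name="__main__")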
.vscode/settings.json (vendored, deleted, 18 lines)
@@ -1,18 +0,0 @@
{
    "python.pythonPath": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
    "python.defaultInterpreterPath": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
    "python.terminal.activateEnvironment": true,
    "python.terminal.activateEnvInCurrentTerminal": true,
    "python.linting.enabled": true,
    "python.linting.pylintEnabled": false,
    "python.linting.flake8Enabled": true,
    "python.formatting.provider": "black",
    "python.analysis.autoImportCompletions": true,
    "python.analysis.typeCheckingMode": "off",
    "files.exclude": {
        "**/__pycache__": true,
        "**/*.pyc": true,
        "**/.git": false,
        "**/wandb": false
    }
}
CODE_OF_CONDUCT.md (new file, 128 lines)
@@ -0,0 +1,128 @@
# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
  overall community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or
  advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
  address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
.
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series
of actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within
the community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.

Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.
README_en.md (new file, 1509 lines; contents not expanded in this view)

(file name not shown in this view; a deleted Hugging Face Accelerate config in YAML)
@@ -1,17 +0,0 @@
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: ds_config.json
  zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
(file name not shown in this view; a deleted utility script that decodes the model's knowledge database to text)
@@ -1,144 +0,0 @@
import os
import argparse
import torch
from transformers import AutoTokenizer
from model.model import MiniMindLM, ExtractDB
from model.LMConfig import LMConfig


def decode_dataset(model_path, output_path, device="cuda"):
    """
    Decode the weight_down_embed buffer in the model to readable text

    Args:
        model_path: Path to the model checkpoint
        output_path: Path to save the decoded text
        device: Device to load the model on
    """
    print(f"Loading tokenizer from ./model/minimind_tokenizer")
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')

    print(f"Setting up model configuration")
    # Create model configuration matching the training parameters
    lm_config = LMConfig(
        dim=1024,
        n_layers=32,
        max_seq_len=1024,
        use_flash_attn=True,
        knowledge_num=16384,  # From the script parameters
        knowledge_length=64   # From the script parameters
    )

    print(f"Initializing model")
    model = MiniMindLM(lm_config).to(device)

    print(f"Loading model weights from {model_path}")
    state_dict = torch.load(model_path, map_location=device)

    # Get model parameters
    model_state = dict(model.named_parameters())
    model_state.update(dict(model.named_buffers()))

    # Find parameters with matching names but different shapes
    shape_mismatch = {}
    for name, param in model_state.items():
        if name in state_dict and param.shape != state_dict[name].shape:
            shape_mismatch[name] = (param.shape, state_dict[name].shape)

    # Find parameters in model but not in state_dict and vice versa
    model_only = set(model_state.keys()) - set(state_dict.keys())
    state_dict_only = set(state_dict.keys()) - set(model_state.keys())

    # Create filtered state_dict with only compatible parameters
    filtered_state_dict = {}
    for name, param in state_dict.items():
        if name in model_state and param.shape == model_state[name].shape:
            filtered_state_dict[name] = param

    # Print parameter differences
    if shape_mismatch:
        print(f"Parameters with shape mismatches: {len(shape_mismatch)}")
        for name, (model_shape, state_shape) in shape_mismatch.items():
            print(f"  {name}: model={model_shape}, checkpoint={state_shape}")

    if model_only:
        print(f"Parameters in model but not in checkpoint: {len(model_only)}")
        for name in sorted(model_only):
            print(f"  {name}: {model_state[name].shape}")
            # Special handling for the pos_cis_real parameter
            if name == "pos_cis_real":
                print(f"Detected pos_cis_real parameter. This is a position encoding that will be initialized automatically.")

    if state_dict_only:
        print(f"Parameters in checkpoint but not in model: {len(state_dict_only)}")
        for name in sorted(state_dict_only):
            print(f"  {name}: {state_dict[name].shape}")
            # If the checkpoint has output.weight but the model does not, try mapping it onto tok_embeddings
            if name == "output.weight" and "tok_embeddings.weight" in model_state:
                print(f"Found output.weight in checkpoint but not in model. Will try to map it to tok_embeddings.weight")
                if model_state["tok_embeddings.weight"].shape == state_dict["output.weight"].shape:
                    filtered_state_dict["tok_embeddings.weight"] = state_dict["output.weight"]

    # Load only the compatible parameters
    print(f"Loading {len(filtered_state_dict)}/{len(state_dict)} parameters")
    model.load_state_dict(filtered_state_dict, strict=False)

    # Check that extract_db and weight_down_embed exist
    if not hasattr(model, "extract_db"):
        print("ERROR: Model does not have extract_db attribute. This is required for decoding.")
        return

    print("Accessing weight_down_embed buffer")
    # Get the weight_down_embed buffer from the model
    try:
        weight_down_embed = model.extract_db.weight_down_embed
        print(f"Successfully accessed weight_down_embed buffer")
    except Exception as e:
        print(f"ERROR: Failed to access weight_down_embed buffer: {e}")
        print(f"Model structure: {model.__class__.__name__}")
        print(f"ExtractDB attributes: {dir(model.extract_db)}")
        return

    print(f"Shape of weight_down_embed: {weight_down_embed.shape}")
    print(f"Data type of weight_down_embed: {weight_down_embed.dtype}")

    # Create output directory if it doesn't exist
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    print(f"Decoding knowledge and writing to {output_path}")
    knowledge_num, knowledge_length = weight_down_embed.shape

    with open(output_path, 'w', encoding='utf-8') as f:
        for i in range(knowledge_num):
            try:
                # Get token IDs for this knowledge entry
                token_ids = weight_down_embed[i].cpu().tolist()

                # Decode tokens to text
                text = tokenizer.decode(token_ids, skip_special_tokens=True)

                # Write to file
                f.write(f"Knowledge_{i}: {text}\n")

                # Print progress periodically
                if (i + 1) % 100 == 0:
                    print(f"Decoded {i + 1}/{knowledge_num} knowledge entries")
            except Exception as e:
                print(f"Error decoding knowledge entry {i}: {e}")
                f.write(f"Knowledge_{i}: [ERROR DECODING]\n")

    print(f"Decoding completed. Output saved to {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Decode MiniMind model's knowledge database")
    parser.add_argument("--model_path", type=str, default="out/pretrain_1024.pth",
                        help="Path to the model checkpoint")
    parser.add_argument("--output_path", type=str, default="out/knowledge_db.txt",
                        help="Path to save the decoded text file")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to load the model on")

    args = parser.parse_args()

    decode_dataset(args.model_path, args.output_path, args.device)
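Editor's note: the trickiest part of the script above is the tolerant checkpoint load, which keeps only tensors whose name and shape match the model and then loads with strict=False. A minimal, self-contained sketch of that pattern with a toy module and checkpoint (not the MiniMind model):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
# Toy checkpoint: 'bias' has the wrong shape on purpose
checkpoint = {"weight": torch.zeros(4, 4), "bias": torch.zeros(8)}

model_state = dict(model.named_parameters())
# Keep only tensors that match the model by name and shape
filtered = {k: v for k, v in checkpoint.items()
            if k in model_state and v.shape == model_state[k].shape}
# strict=False leaves unmatched parameters at their initialization
missing, unexpected = model.load_state_dict(filtered, strict=False)
print(sorted(filtered), missing, unexpected)  # ['weight'] ['bias'] []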
(file name not shown in this view; a deleted DeepSpeed config, presumably the ds_config.json referenced by the Accelerate config above)
@@ -1,49 +0,0 @@
{
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 5e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true
    },
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },
    "steps_per_print": 100,
    "wall_clock_breakdown": false
}
(file name not shown in this view; a deleted Determined AI experiment config)
@@ -1,26 +0,0 @@
# 1. Metadata: edit as needed — give the experiment a name and description
name: ycz-minimind-test
description: Test run for minimind-test

# 2. Runtime environment: usually left unchanged; replace with a specific image if required
environment:
  image: determinedai/pytorch-ngc:0.38.0  # no need to change this

# 3. Dataset on the NAS: edit only the bind_mounts field; container_path and read_only stay as-is
# Replace <YOUR_DATASET_FOLDER_NAME> with the name of the dataset folder stored under Volume1/Share/datasets/ on the NAS
# Make sure the <YOUR_DATASET_FOLDER_NAME> dataset actually exists under Volume1/Share/datasets/ on the NAS


# 4. Compute resources: no changes needed
resources:
  slots_per_trial: 1  # no need to change this
  resource_pool: rtx4090  # no need to change this

# 5. Searcher: no changes needed
searcher:
  name: single
  metric: test_accuracy
  smaller_is_better: false

# 6. Entrypoint: no changes needed
entrypoint: sh startup.sh
New binary files under images/:
images/1-wiki.png (new file, 136 KiB)
images/2-wiki.png (new file, 73 KiB)
images/3-wiki.png (new file, 230 KiB)
images/4-wiki.png (new file, 104 KiB)
images/5-wiki.png (new file, 239 KiB)
images/LLM-structure-moe.png (new file, 121 KiB)
images/LLM-structure.png (new executable file, 372 KiB)
images/and_huggingface.png (new file, 178 KiB)
images/and_modelscope.png (new file, 150 KiB)
images/compare_radar.png (new file, 519 KiB)
images/dataset.jpg (new file, 146 KiB)
images/gpt3_config.png (new file, 66 KiB)
images/logo.png (new file, 495 KiB)
images/logo2.png (new file, 615 KiB)
images/minimind2.gif (new file, 3.8 MiB)
images/pre_512_loss.png (new file, 559 KiB)
images/pre_768_loss.png (new file, 531 KiB)
images/sft_512_loss.png (new file, 1006 KiB)
images/sft_768_loss.png (new file, 943 KiB)
main.py (deleted, 6 lines)
@@ -1,6 +0,0 @@
def main():
    print("Hello from minimind!")


if __name__ == "__main__":
    main()
(file name not shown in this view; the hunk headers place these changes in class LMConfig(PretrainedConfig), i.e. likely model/LMConfig.py)
@@ -19,7 +19,6 @@ class LMConfig(PretrainedConfig):
         rope_theta: int = 1e6,
         dropout: float = 0.0,
         flash_attn: bool = True,
-        embeddings_epoch: int = 2,
         ####################################################
         # DB related configurations
         ####################################################
@@ -37,16 +36,6 @@ class LMConfig(PretrainedConfig):
         aux_loss_alpha: float = 0.1,
         seq_aux: bool = True,
         norm_topk_prob: bool = True,
-        ####################################################
-        knowledge_num: int = 64*64,
-        knowledge_length: int = 8,
-        knowledge_dim: int = 128,
-        ####################################################
-        # Triple extraction related configurations
-        ####################################################
-        max_subject_len: int = 8,
-        max_predicate_len: int = 4,
-        max_object_len: int = 8,
         **kwargs,
     ):
         self.dim = dim
@@ -61,7 +50,6 @@ class LMConfig(PretrainedConfig):
         self.rope_theta = rope_theta
         self.dropout = dropout
         self.flash_attn = flash_attn
-        self.embeddings_epoch = embeddings_epoch
         ####################################################
         # DB related configurations
         ####################################################
@@ -78,14 +66,4 @@ class LMConfig(PretrainedConfig):
         self.aux_loss_alpha = aux_loss_alpha  # alpha for the auxiliary loss
         self.seq_aux = seq_aux  # whether to compute the auxiliary loss at sequence level
         self.norm_topk_prob = norm_topk_prob  # whether to normalize top-k probabilities
-        ####################################################
-        self.knowledge_num = knowledge_num
-        self.knowledge_length = knowledge_length
-        self.knowledge_dim = knowledge_dim
-        ####################################################
-        # Triple extraction related configurations
-        ####################################################
-        self.max_subject_len = max_subject_len
-        self.max_predicate_len = max_predicate_len
-        self.max_object_len = max_object_len
         super().__init__(**kwargs)
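Editor's note: a minimal sketch instantiating the master-branch config with the defaults shown above (values taken from this diff; the import path assumes the repository layout used elsewhere in this comparison). The knowledge store is deliberately sized as a perfect square, since model.py later factors it into two sqrt(knowledge_num)-entry key tables:

from model.LMConfig import LMConfig

cfg = LMConfig(
    dim=512,
    n_layers=8,
    max_seq_len=512,
    knowledge_num=64 * 64,  # 4096 entries -> 64 keys per product-key subspace
    knowledge_length=8,     # tokens stored per knowledge entry
    knowledge_dim=128,      # query/key width for the knowledge lookup
)
print(cfg.knowledge_num, cfg.knowledge_length, cfg.knowledge_dim)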
model/dataset.py (325 changed lines)
@@ -9,73 +9,8 @@ import torch
 from sklearn.model_selection import train_test_split
 import os
 import ast
 from tqdm import tqdm

-os.environ["TOKENIZERS_PARALLELISM"] = "true"
-
-
-def process_sample_filter(data_args):
-    """Filtering logic for a single sample."""
-    sample, valid_predicates = data_args
-    if 'target' in sample and isinstance(sample['target'], list):
-        # Filter low-frequency predicates out of target
-        valid_targets = []
-        for triple in sample['target']:
-            if isinstance(triple, dict) and 'predicate' in triple:
-                if triple['predicate'] in valid_predicates:
-                    valid_targets.append(triple)
-
-        # Keep the sample if any valid targets remain
-        if valid_targets:
-            sample['target'] = valid_targets
-            return sample
-        else:
-            return None
-    else:
-        # Keep the sample if it carries no target info
-        return sample
-
-
-def process_sample_validation(data_args):
-    """Validation logic for a single sample."""
-    sample, predicate_vocab = data_args
-    if not isinstance(sample, dict) or 'text' not in sample:
-        return None
-
-    targets = sample.get('target', [])
-    if not isinstance(targets, list) or len(targets) == 0:
-        # No valid target: fall back to a default triple
-        selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
-    else:
-        # Validate and pick one target, preferring low-share predicates
-        selected_target = None
-        min_percentage = float('inf')
-
-        for triple in targets:
-            if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
-                predicate = triple['predicate']
-
-                # Use the statistics stored in predicate_vocab
-                if predicate in predicate_vocab:
-                    stats = predicate_vocab[predicate]
-                    if isinstance(stats, dict) and 'percentage' in stats:
-                        percentage = stats['percentage']
-                        if percentage < min_percentage:
-                            min_percentage = percentage
-                            selected_target = triple
-                    elif selected_target is None:
-                        selected_target = triple
-            elif selected_target is None:
-                selected_target = triple
-
-        # Fall back to the default if no valid target was found
-        if selected_target is None:
-            selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
-
-    return {
-        'text': sample['text'],
-        'target': selected_target  # keep only one target
-    }
+os.environ["TOKENIZERS_PARALLELISM"] = "false"


 class PretrainDataset(Dataset):
@@ -98,14 +33,9 @@ class PretrainDataset(Dataset):

     def __getitem__(self, index):
         sample = self.samples[index]
-        text = str(sample['text'])
-
-        # Add <|im_start|> and <|im_end|> if they are missing
-        if not text.startswith(self.tokenizer.bos_token):
-            text = f"{self.tokenizer.bos_token}{text}"
-        if not text.endswith(self.tokenizer.eos_token):
-            text = f"{text}{self.tokenizer.eos_token}"
+        # Build the input text
+        text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
         encoding = self.tokenizer(
             text,
             max_length=self.max_length,
@@ -128,8 +58,8 @@ class SFTDataset(Dataset):
         self.tokenizer = tokenizer
         self.max_length = max_length
         self.samples = self.load_data(jsonl_path)
-        self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
-        self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids
+        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
+        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

     def __len__(self):
         return len(self.samples)
@@ -196,8 +126,8 @@ class DPODataset(Dataset):
         self.tokenizer = tokenizer
         self.max_length = max_length
         self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
-        self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
-        self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids
+        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
+        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
         with open(file_path, 'r', encoding='utf-8') as f:
             self.data = []
             for line in f:
@@ -266,249 +196,14 @@ class DPODataset(Dataset):
         return loss_mask

-
-class TriplePretrainDataset(Dataset):
-    """
-    Optimized triple-pretraining dataset:
-    - keeps only one target triple per sample
-    - tokenizes all data up front
-    - shows processing progress with a progress bar
-    """
-    def __init__(self, data_path=None, predicate_vocab_path=None, samples=None, tokenizer=None, max_length=512):
-        super().__init__()
-        self.tokenizer = tokenizer
-        self.max_length = max_length
-        self.val_samples = None
-        self.predicate_to_id = {}  # initialize
-        if samples is None:
-            self.predicate_vocab = self.load_predicate_vocab(predicate_vocab_path)
-            print("🚀 Loading and preprocessing triple data...")
-            self.samples, self.val_samples = self.load_and_preprocess_data(data_path)
-            print("🚀 Finished loading and preprocessing triple data")
-        else:
-            cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
-            data_filename = os.path.basename(data_path).split('.')[0]
-            predicate_to_id_path = os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json')
-            self.predicate_to_id = self.load_predicate_vocab(predicate_to_id_path)
-            self.samples = samples
-            print("🚀 Finished loading and preprocessing triple data")
-
-    def load_predicate_vocab(self, path):
-        with open(path, 'r', encoding='utf-8') as f:
-            predicate_vocab = json.load(f)
-        return predicate_vocab
-
-    def get_val_samples(self):
-        return self.val_samples
-
-    def clear_cache(self, data_path):
-        """Remove the cache files."""
-        cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
-        data_filename = os.path.basename(data_path).split('.')[0]
-        cache_files = [
-            os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
-            os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
-            os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
-            os.path.join(cache_dir, f'{data_filename}_val_samples.json')
-        ]
-
-        for cache_file in cache_files:
-            if os.path.exists(cache_file):
-                os.remove(cache_file)
-                print(f"🗑️ Removed cache file: {cache_file}")
-
-        if os.path.exists(cache_dir) and not os.listdir(cache_dir):
-            os.rmdir(cache_dir)
-            print(f"🗑️ Removed empty cache directory: {cache_dir}")
-
-    def load_and_preprocess_data(self, path):
-        """Load and preprocess the triple data."""
-        # Cache file names derived from the data file path
-        cache_dir = os.path.join(os.path.dirname(path), 'cache')
-        os.makedirs(cache_dir, exist_ok=True)
-
-        data_filename = os.path.basename(path).split('.')[0]
-        cache_files = {
-            'predicate_vocab': os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
-            'predicate_to_id': os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
-            'train_samples': os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
-            'val_samples': os.path.join(cache_dir, f'{data_filename}_val_samples.json')
-        }
-
-        # Check whether all cache files exist
-        cache_exists = all(os.path.exists(cache_file) for cache_file in cache_files.values())
-
-        if cache_exists:
-            print("📁 Cache files found, loading directly...")
-            # Load from cache
-            with open(cache_files['predicate_vocab'], 'r', encoding='utf-8') as f:
-                self.predicate_vocab = json.load(f)
-
-            with open(cache_files['predicate_to_id'], 'r', encoding='utf-8') as f:
-                self.predicate_to_id = json.load(f)
-
-            with open(cache_files['train_samples'], 'r', encoding='utf-8') as f:
-                train_samples = json.load(f)
-
-            with open(cache_files['val_samples'], 'r', encoding='utf-8') as f:
-                val_samples = json.load(f)
-
-            print(f"✅ Loaded from cache:")
-            print(f"✅ Predicate vocab size: {len(self.predicate_vocab)}")
-            print(f"✅ Training set size: {len(train_samples)}")
-            print(f"✅ Test set size: {len(val_samples)}")
-
-            return train_samples, val_samples
-
-        # No cache: process the raw data from scratch
-        print("📂 No cache found, loading and processing raw data...")
-
-        # 1. Load the raw data
-        print("📂 Loading raw data...")
-        if path.endswith('.json'):
-            with open(path, 'r', encoding='utf-8') as f:
-                data = json.load(f)
-        elif path.endswith('.jsonl'):
-            data = []
-            with open(path, 'r', encoding='utf-8') as f:
-                for line in f:
-                    if line.strip():
-                        data.append(json.loads(line.strip()))
-        else:
-            raise ValueError(f"Unsupported file format: {path}")
-
-        print(f"📊 Raw data size: {len(data)} samples")
-
-        # 2. Use self.predicate_vocab to drop data whose predicates have a share below 0.01%
-        print("🔍 Filtering low-frequency predicate data...")
-        print(f"📊 Predicate statistics: {len(self.predicate_vocab)} predicates in total")
-
-        # 3. Collect the predicates with a share >= 0.01%
-        valid_predicates = set()
-        for predicate, stats in self.predicate_vocab.items():
-            if isinstance(stats, dict) and 'percentage' in stats:
-                if stats['percentage'] >= 0.01:
-                    valid_predicates.add(predicate)
-            else:
-                # Not in statistics format: assume the predicate is valid
-                valid_predicates.add(predicate)
-
-        print(f"📊 Predicates with share >= 0.01%: {len(valid_predicates)}")
-
-        # 4. Filter the data: drop samples that only contain low-frequency predicates (single process)
-        original_count = len(data)
-        filtered_data = []
-
-        print("🚀 Filtering low-frequency predicate data...")
-        for sample in tqdm(data, desc="Filtering low-frequency predicates"):
-            result = process_sample_filter((sample, valid_predicates))
-            if result is not None:
-                filtered_data.append(result)
-
-        data = filtered_data
-        print(f"✅ Filtering done: {original_count} samples before, {len(data)} after")
-
-        # 5. Drop predicates with share < 0.01% from self.predicate_vocab and build a predicate-to-id map
-        print("🔍 Updating the predicate vocab and building the id map...")
-        original_vocab_size = len(self.predicate_vocab)
-        filtered_predicate_vocab = {}
-
-        for predicate, stats in self.predicate_vocab.items():
-            if isinstance(stats, dict) and 'percentage' in stats:
-                if stats['percentage'] >= 0.01:
-                    filtered_predicate_vocab[predicate] = stats
-            else:
-                # Not in statistics format: keep it
-                filtered_predicate_vocab[predicate] = stats
-
-        # Build the predicate-to-index mapping
-        self.predicate_to_id = {predicate: idx for idx, predicate in enumerate(filtered_predicate_vocab.keys())}
-        self.predicate_vocab = filtered_predicate_vocab
-        print(f"✅ Predicate vocab updated: {original_vocab_size} before, {len(self.predicate_vocab)} after")
-        print(f"✅ Predicate map built: {len(self.predicate_to_id)} predicates mapped to ids")
-
-        # 6. Validate the data and keep a single target per sample, preferring low-share predicates to balance the data (single process)
-        print("🔍 Validating data format and selecting one target per sample (balanced)...")
-        valid_samples = []
-
-        print("🚀 Validating data format...")
-        for sample in tqdm(data, desc="Validating data format"):
-            result = process_sample_validation((sample, self.predicate_vocab))
-            if result is not None:
-                valid_samples.append(result)
-
-        print(f"✅ Valid samples: {len(valid_samples)}")
-
-        # 7. Split into training and test sets
-        import random
-        random.seed(42)
-        val_samples = random.sample(valid_samples, min(1000, len(valid_samples)))
-        train_samples = [sample for sample in valid_samples if sample not in val_samples]
-        print(f"✅ Training set size: {len(train_samples)}")
-        print(f"✅ Test set size: {len(val_samples)}")
-
-        # 8. Save everything to the cache files
-        print("💾 Saving processed results to cache files...")
-        with open(cache_files['predicate_vocab'], 'w', encoding='utf-8') as f:
-            json.dump(self.predicate_vocab, f, ensure_ascii=False, indent=2)
-
-        with open(cache_files['predicate_to_id'], 'w', encoding='utf-8') as f:
-            json.dump(self.predicate_to_id, f, ensure_ascii=False, indent=2)
-
-        with open(cache_files['train_samples'], 'w', encoding='utf-8') as f:
-            json.dump(train_samples, f, ensure_ascii=False, indent=2)
-
-        with open(cache_files['val_samples'], 'w', encoding='utf-8') as f:
-            json.dump(val_samples, f, ensure_ascii=False, indent=2)
-
-        print("✅ Cache files saved")
-
-        return train_samples, val_samples
-
-    def __len__(self):
-        return len(self.samples)
-
-    def _triple_to_sentence(self, triple):
-        """Convert a triple into sentence form."""
-        return f"{triple['subject']} {triple['predicate']} {triple['object']}"
-
-    def __getitem__(self, index):
-        """Return one sample for the predicate classification task."""
-        sample = self.samples[index]
-
-        # Tokenize the input text at runtime
-        input_text = f"{self.tokenizer.bos_token}{sample['text']}{self.tokenizer.eos_token}"
-        encoding = self.tokenizer(
-            input_text,
-            max_length=self.max_length,
-            padding='max_length',
-            truncation=True,
-            return_tensors='pt'
-        )
-        input_ids = encoding.input_ids.squeeze()
-        loss_mask = (input_ids != self.tokenizer.pad_token_id)
-
-        # Predicate classification label
-        target_predicate = sample['target']['predicate']
-        predicate_label = self.predicate_to_id.get(target_predicate)  # note: returns None, not 0, if the predicate is missing
-
-        # Build the training tensors
-        X = input_ids[:-1]
-        loss_mask = loss_mask[1:]
-
-        return {
-            'input_ids': X,
-            'labels': torch.tensor(predicate_label, dtype=torch.long),  # predicate classification label
-            'loss_mask': loss_mask
-        }
-
-
 class RLAIFDataset(Dataset):
     def __init__(self, jsonl_path, tokenizer, max_length=1024):
         super().__init__()
         self.tokenizer = tokenizer
         self.max_length = max_length
         self.samples = self.load_data(jsonl_path)
-        self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
-        self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids
+        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
+        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

     def __len__(self):
         return len(self.samples)
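Editor's note: TriplePretrainDataset.__getitem__ above returns input_ids[:-1] as X and loss_mask[1:] as the mask. A minimal sketch of that off-by-one alignment for next-token training, with toy token ids (pad id assumed to be the tokenizer's pad token):

import torch

pad_id = 0
input_ids = torch.tensor([1, 27, 35, 99, 2, 0, 0])  # <s> ... </s> pad pad
loss_mask = (input_ids != pad_id)

X = input_ids[:-1]    # model input
Y = input_ids[1:]     # next-token targets, aligned position-wise with X
mask = loss_mask[1:]  # masks out positions whose *target* is padding
print(X.tolist())     # [1, 27, 35, 99, 2, 0]
print(Y.tolist())     # [27, 35, 99, 2, 0, 0]
print(mask.tolist())  # [True, True, True, True, False, False]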
(file name not shown in this view; special-token and vocab entries, likely model/minimind_tokenizer/tokenizer.json)
@@ -14,7 +14,7 @@
     },
     {
       "id": 1,
-      "content": "<|im_start|>",
+      "content": "<s>",
       "single_word": false,
       "lstrip": false,
       "rstrip": false,
@@ -23,7 +23,7 @@
     },
     {
       "id": 2,
-      "content": "<|im_end|>",
+      "content": "</s>",
       "single_word": false,
       "lstrip": false,
       "rstrip": false,
@@ -56,8 +56,8 @@
     "ignore_merges": false,
     "vocab": {
       "<unk>": 0,
-      "<|im_start|>": 1,
-      "<|im_end|>": 2,
+      "<s>": 1,
+      "</s>": 2,
       "!": 3,
       "\"": 4,
       "#": 5,

(file name not shown in this view; likely model/minimind_tokenizer/tokenizer_config.json)
@@ -12,7 +12,7 @@
       "special": true
     },
     "1": {
-      "content": "<|im_start|>",
+      "content": "<s>",
       "lstrip": false,
       "normalized": false,
       "rstrip": false,
@@ -20,7 +20,7 @@
       "special": true
     },
     "2": {
-      "content": "<|im_end|>",
+      "content": "</s>",
       "lstrip": false,
       "normalized": false,
       "rstrip": false,
@@ -29,9 +29,9 @@
     }
   },
   "additional_special_tokens": [],
-  "bos_token": "<|im_start|>",
+  "bos_token": "<s>",
   "clean_up_tokenization_spaces": false,
-  "eos_token": "<|im_end|>",
+  "eos_token": "</s>",
   "legacy": true,
   "model_max_length": 32768,
   "pad_token": "<unk>",
@@ -39,5 +39,5 @@
   "spaces_between_special_tokens": false,
   "tokenizer_class": "PreTrainedTokenizerFast",
   "unk_token": "<unk>",
-  "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\n你是 MiniMind,是一个有用的人工智能助手。<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"
+  "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<s>system\\n' + system_message + '</s>\\n' }}{% else %}{{ '<s>system\\n你是 MiniMind,是一个有用的人工智能助手。</s>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
 }
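Editor's note: the chat_template change above swaps every <|im_start|>/<|im_end|> turn delimiter for <s>/</s>. A minimal sketch of what rendering a conversation looks like under the old/HPC template, assuming the tokenizer files live in ./model/minimind_tokenizer as the decode script earlier suggests:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")
messages = [
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "你好!"},
]
text = tokenizer.apply_chat_template(messages, tokenize=False)
print(text)
# Expected shape of the output under the old/HPC template:
# <s>system\n你是 MiniMind,是一个有用的人工智能助手。</s>\n<s>user\n你好</s>\n<s>assistant\n你好!</s>\n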
model/model.py (651 changed lines)
@@ -2,8 +2,7 @@ import math
 import struct
 import inspect
 import time
-import gc
-# 2-D subspace decomposition + gradient update

 from .LMConfig import LMConfig
 from typing import Any, Optional, Tuple, List, Union
 import numpy as np
@@ -12,9 +11,14 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
+from torch import nn, einsum
+from einops import rearrange, repeat
+
+def exists(val):
+    return val is not None


 # RMSNorm defines a module that normalizes its input tensor.
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
@@ -27,7 +31,7 @@ class RMSNorm(torch.nn.Module):
     def forward(self, x):
         return self.weight * self._norm(x.float()).type_as(x)


+# precompute_pos_cis precomputes the rotary position encodings.
 def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)  # type: ignore
@@ -35,7 +39,7 @@ def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
     pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
     return pos_cis


+# apply_rotary_emb applies the rotary position embedding to queries and keys.
 def apply_rotary_emb(xq, xk, pos_cis):
     def unite_shape(pos_cis, x):
         ndim = x.ndim
@@ -51,244 +55,18 @@ def apply_rotary_emb(xq, xk, pos_cis):
     xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
     return xq_out.type_as(xq), xk_out.type_as(xk)

-class KnowledgeDataset(nn.Module):
-    def __init__(self, params, tok_embeddings, is_train=True):
-        super().__init__()
-        self.is_train = is_train
-        self.params = params
-        self.tok_embeddings = tok_embeddings
-
-        # Embedding parameters
-        self.knowledge_dim = params.knowledge_dim
-        self.key_dim = self.knowledge_dim // 2
-        self.to_queries = nn.Sequential(
-            nn.Linear(params.dim, self.knowledge_dim, bias=False),
-        )
-
-        ## Database parameters
-        self.knowledge_num = params.knowledge_num
-        self.knowledge_length = params.knowledge_length
-
-        # Store keys in a 2-D factorized space, as trainable parameters
-        self.num_keys = int(math.sqrt(self.knowledge_num))
-        # Make sure the keys are trainable
-        self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
-        self.product_key_topk = min(16, self.num_keys)
-
-        # Knowledge store - a register_buffer because these are integer indices and need no gradients
-        self.register_buffer('knowledge_dataset',
-            torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))
-
-        # Step counter, used for dynamic weight adjustment
-        self.step_counter = 0
-
-        # Batch counter and update-frequency code removed
-
-    def intelligent_selection(self, query, all_scores, all_indices):
-        """Intelligent hierarchical selection strategy."""
-        if self.is_train == False:
-            return all_scores, all_indices
-
-        batch_size = all_scores.size(0)
-        device = all_scores.device
-        dtype = all_scores.dtype
-
-        # Record the memory state before entering intelligent selection
-        if hasattr(self, 'step_counter'):
-            self.step_counter += 1
-            # GPU memory logging disabled for performance
-            # if self.step_counter % 50 == 0:  # log every 50 calls
-            #     if torch.cuda.is_available():
-            #         allocated_before = torch.cuda.memory_allocated() / (1024**3)
-            #         print(f"[INTEL_SELECT_ENTER] Step {self.step_counter}: GPU Memory: {allocated_before:.2f}GB")
-
-        # Hierarchical selection per batch element
-        enhanced_scores = all_scores.clone()
-        query_features = query.mean(dim=1)  # [batch_size, dim]
-
-        # Precompute embeddings for all candidate entries (batched optimization)
-        all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
-        unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
-
-        # Batch-compute embeddings for the unique candidate entries
-        candidate_tokens = self.knowledge_dataset[unique_indices]
-        flat_tokens = candidate_tokens.view(-1)
-        flat_embeddings = self.tok_embeddings(flat_tokens)
-
-        # Indices corresponding to flat_tokens (kept for use elsewhere)
-        pre_update_indices = unique_indices.view(-1)
-        pre_update_embeddings = flat_embeddings.view(
-            len(unique_indices), self.knowledge_length, -1
-        )
-
-        unique_candidate_features = flat_embeddings.view(
-            len(unique_indices), self.knowledge_length, -1
-        ).mean(dim=1)  # [num_unique_candidates, dim]
-
-        # Normalize candidate features (optimized similarity computation)
-        normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
-        normalized_queries = F.normalize(query_features, dim=-1)
-
-        # Collect best_tokens across the batch
-        batch_best_tokens = []
-        batch_best_tokens_embeddings = []
-
-        for batch_idx in range(batch_size):
-            indices = all_indices[batch_idx]
-
-            # Feature indices for this batch element's candidates
-            start_idx = batch_idx * len(indices)
-            end_idx = start_idx + len(indices)
-            batch_inverse_indices = inverse_indices[start_idx:end_idx]
-
-            # Use the precomputed normalized features for the similarity computation
-            batch_candidate_features = normalized_candidates[batch_inverse_indices]
-            query_feature = normalized_queries[batch_idx]
-
-            # Cosine similarity via a matrix-vector product
-            similarity_scores = torch.mv(batch_candidate_features, query_feature)
-
-            # Index of the highest similarity score
-            max_similarity_idx = torch.argmax(similarity_scores)
-
-            # Candidate entry index with the highest similarity
-            best_candidate_idx = indices[max_similarity_idx]
-
-            # Corresponding tokens
-            best_tokens = self.knowledge_dataset[best_candidate_idx]
-            best_tokens_embeddings = self.tok_embeddings(best_tokens)
-
-            # Append this batch element's best_tokens
-            batch_best_tokens.append(best_tokens)
-            batch_best_tokens_embeddings.append(best_tokens_embeddings)
-
-        # Stack all best_tokens into one tensor
-        # [batch_size, knowledge_length]
-        all_best_tokens = torch.stack(batch_best_tokens, dim=0)
-        all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)
-
-        # Free intermediate tensors to avoid memory leaks
-        del all_candidate_indices, unique_indices, inverse_indices
-        del unique_candidate_features, normalized_candidates, normalized_queries
-        del batch_best_tokens, batch_best_tokens_embeddings
-        del flat_tokens, flat_embeddings, pre_update_embeddings
-
-        # Memory logging on exit (disabled for performance)
-        # if hasattr(self, 'step_counter') and self.step_counter % 50 == 0:
-        #     if torch.cuda.is_available():
-        #         allocated_after = torch.cuda.memory_allocated() / (1024**3)
-        #         print(f"[INTEL_SELECT_EXIT] Step {self.step_counter}: GPU Memory: {allocated_after:.2f}GB")
-
-        # Force garbage collection (only on monitored steps)
-        if hasattr(self, 'step_counter') and self.step_counter % 100 == 0:
-            gc.collect()
-            # if torch.cuda.is_available():
-            #     torch.cuda.empty_cache()
-
-        return all_best_tokens, all_best_tokens_embeddings
-
-    def search_index(self, x):
-        batch_size, seq_len, dim = x.shape
-
-        # 1. Average over the sequence dimension
-        x_flat = x.mean(dim=1)  # [batch_size, dim]
-
-        # 2. Generate the query vector and reshape it into two sub-queries
-        queries = self.to_queries(x_flat)  # [batch_size, knowledge_dim]
-        queries = queries.reshape(batch_size, 2, self.key_dim)  # [batch_size, 2, key_dim]
-        # Move the subspace dimension to the front
-        queries = queries.permute(1, 0, 2)  # [2, batch_size, key_dim]
-
-        # 3. Similarity within each subspace
-        sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
-
-        # 4. Top-k separately in the two subspaces
-        scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
-        scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
-        indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
-
-        # 5. Combine the results of the two subspaces
-        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)  # [batch_size, topk, topk]
-        all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)  # [batch_size, topk, topk]
-
-        # 6. Flatten the results back to 2-D
-        all_scores = all_scores.reshape(batch_size, -1)  # [batch_size, topk*topk]
-        all_indices = all_indices.reshape(batch_size, -1)  # [batch_size, topk*topk]
-
-        # 7. Final top-k selection
-        scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
-        indices = torch.gather(all_indices, 1, indices_of_indices)
-
-        # 8. Apply the intelligent hierarchical selection strategy
-        best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)
-
-        return best_tokens, best_tokens_embeddings
-
-
-class CrossAttention(nn.Module):
-    def __init__(
-            self,
-            config
-    ):
-        super().__init__()
-        self.config = config
-        self.num_heads = 8
-        self.head_dim = self.config.dim // self.num_heads
-        self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
-        self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
-        self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
-
-        self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
-
-    def forward(self, x, db, context_mask=None, pos_emb=None):
-        batch_size = x.size(0)
-
-        # Memory monitoring at cross-attention entry (disabled for performance)
-        if not hasattr(self, 'call_counter'):
-            self.call_counter = 0
-        self.call_counter += 1
-
-        # GPU memory logging disabled for performance
-        # if self.call_counter % 100 == 0 and torch.cuda.is_available():
-        #     allocated_before = torch.cuda.memory_allocated() / (1024**3)
-        #     print(f"[CROSS_ATTN_ENTER] Call {self.call_counter}: GPU Memory: {allocated_before:.2f}GB")
-
-        # Split into heads
-        q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
-        k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
-        v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
-
-        if pos_emb is not None:
-            pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
-            q = q + pos_emb
-            k = k + pos_emb
-            v = v + pos_emb
-
-        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
-
-        if context_mask is not None:
-            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
-            attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
-
-        attn_weights = F.softmax(attn_scores, dim=-1)
-
-        context = torch.matmul(attn_weights, v)
-
-        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
-
-        context = self.to_out(context)
-
-        # Free intermediate tensors
-        del q, k, v, attn_scores, attn_weights
-
-        # Memory monitoring at cross-attention exit (disabled for performance)
-        # if self.call_counter % 100 == 0 and torch.cuda.is_available():
-        #     allocated_after = torch.cuda.memory_allocated() / (1024**3)
-        #     print(f"[CROSS_ATTN_EXIT] Call {self.call_counter}: GPU Memory: {allocated_after:.2f}GB")
-
-        return context
+# repeat_kv repeats key/value pairs.
+def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
+    bs, slen, n_kv_heads, head_dim = x.shape
+    if n_rep == 1:
+        return x
+    return (
+        x[:, :, :, None, :]
+        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+    )

 class Attention(nn.Module):
     def __init__(self, args: LMConfig):
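Editor's note: the candidate combination in search_index above is the product-key trick — top-k in each of two small key tables yields topk² candidates over num_keys² entries, and the flat id of pair (ix, iy) is ix * num_keys + iy. A minimal sketch with toy sizes, random scores standing in for the einsum similarities:

import torch

num_keys, topk = 4, 2  # 16 knowledge entries, top-2 per subspace
scores_x, indices_x = torch.randn(1, num_keys).topk(topk, dim=-1)
scores_y, indices_y = torch.randn(1, num_keys).topk(topk, dim=-1)

# Pairwise-sum scores and flat ids of the topk*topk candidate pairs
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)              # [1, 2, 2]
all_indices = indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2)

all_scores = all_scores.reshape(1, -1)    # 4 candidates out of 16 entries
all_indices = all_indices.reshape(1, -1)
scores, pick = all_scores.topk(topk, dim=-1)
indices = torch.gather(all_indices, 1, pick)
print(indices)                            # final entry ids in [0, 16)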
@@ -314,14 +92,58 @@ class Attention(nn.Module):

     def forward(self,
                 x: torch.Tensor,
-                pos_cis: torch.Tensor):
-        bsz, seq_len, _ = x.shape
-        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
-        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
-        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
+                pos_cis: torch.Tensor,
+                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+                use_cache=True,
+                db_value=None):
+        bsz, seq_len, _ = x.shape  # bsz: batch size, seq_len: sequence length, _: hidden dim
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)  # project x through wq, wk, wv to get queries, keys, values
+        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)  # reshape xq to (bsz, seq_len, n_local_heads, head_dim)
+        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)  # reshape xk to (bsz, seq_len, n_local_kv_heads, head_dim)
+        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)  # reshape xv to (bsz, seq_len, n_local_kv_heads, head_dim)

         # Apply rotary position embedding
         xq, xk = apply_rotary_emb(xq, xk, pos_cis)
+        # kv-cache implementation
+        if past_key_value is not None:
+            xk = torch.cat([past_key_value[0], xk], dim=1)
+            xv = torch.cat([past_key_value[1], xv], dim=1)
+        past_kv = (xk, xv) if use_cache else None
+
+        # Repeat key/value pairs
+        xq, xk, xv = (
+            xq.transpose(1, 2),
+            repeat_kv(xk, self.n_rep).transpose(1, 2),
+            repeat_kv(xv, self.n_rep).transpose(1, 2)
+        )
+
+        # If db_value is provided, adapt its shape to the number of heads and merge it into xv
+        if db_value is not None:
+            # Make db_value compatible with xv; db_value is assumed to have shape [B, N, H, D]
+            if db_value.ndim == 4:  # [B, N, H, D]
+                db_value = db_value.transpose(1, 2)  # -> [B, H, N, D]
+
+            # Check whether the D dimension needs adjusting
+            if db_value.shape[-1] != xv.shape[-1]:
+                # If db_value's width differs from xv, a projection layer could be added;
+                # here a simple mean-pool / repeat adjustment is used instead
+                if db_value.shape[-1] > xv.shape[-1]:
+                    # Reduce the width
+                    factor = db_value.shape[-1] // xv.shape[-1]
+                    db_value = db_value.view(bsz, self.n_local_heads, seq_len, factor, xv.shape[-1])
+                    db_value = db_value.mean(dim=3)
+                else:
+                    # Expand the width
+                    factor = xv.shape[-1] // db_value.shape[-1]
+                    db_value = db_value.unsqueeze(-1).repeat(1, 1, 1, 1, factor)
+                    db_value = db_value.view(bsz, self.n_local_heads, seq_len, xv.shape[-1])
+
+            # Add db_value to xv (simple additive fusion; other fusion methods would also work)
+            xv = xv + db_value

         # Use Flash Attention
         if self.flash and seq_len != 1:
             dropout_p = self.dropout if self.training else 0.0
             output = F.scaled_dot_product_attention(
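Editor's note: two steps in the forward above are easy to misread — the kv cache concatenates along the sequence axis (dim=1, before heads are moved), and repeat_kv tiles each of the n_kv_heads so grouped-query attention can pair them with n_heads queries. A minimal sketch of both with toy shapes:

import torch

bsz, n_kv_heads, head_dim, n_rep = 1, 2, 4, 3

# kv cache: 5 cached positions + 1 new token -> attend over 6 positions
past_k = torch.randn(bsz, 5, n_kv_heads, head_dim)
xk = torch.cat([past_k, torch.randn(bsz, 1, n_kv_heads, head_dim)], dim=1)
print(xk.shape)  # torch.Size([1, 6, 2, 4])

# repeat_kv: tile each kv head n_rep times -> 6 effective heads
y = (xk[:, :, :, None, :]
     .expand(bsz, 6, n_kv_heads, n_rep, head_dim)
     .reshape(bsz, 6, n_kv_heads * n_rep, head_dim))
print(y.shape)                              # torch.Size([1, 6, 6, 4])
assert torch.equal(y[:, :, 0], y[:, :, 1])  # first n_rep heads are copies of kv head 0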
@@ -339,9 +161,56 @@ class Attention(nn.Module):

         output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
         output = self.resid_dropout(self.wo(output))
-        return output
+        return output, past_kv


+class CrossAttention(nn.Module):
+    def __init__(
+            self,
+            config
+    ):
+        super().__init__()
+        self.config = config
+        self.num_heads = 8
+        self.head_dim = self.config.dim // self.num_heads
+        self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
+        self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
+        self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)
+
+        self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)
+
+    def forward(self, x, db, context_mask=None, pos_emb=None):
+        batch_size = x.size(0)
+
+        # Split into heads
+        q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+        if pos_emb is not None:
+            pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+            q = q + pos_emb
+            k = k + pos_emb
+            v = v + pos_emb
+
+        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+
+        if context_mask is not None:
+            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
+            attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
+
+        attn_weights = F.softmax(attn_scores, dim=-1)
+
+        context = torch.matmul(attn_weights, v)
+
+        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)
+
+        context = self.to_out(context)
+
+        return context
+
 class FeedForward(nn.Module):
     def __init__(self, config: LMConfig):
         super().__init__()
@ -474,30 +343,167 @@ class MOEFeedForward(nn.Module):
|
||||
|
||||
|
||||
class MiniMindBlock(nn.Module):
    def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
    def __init__(self, layer_id: int, config: LMConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.self_attention = Attention(config)
        self.cross_attention = CrossAttention(config)
        self.knowledge_dataset = knowledge_dataset
        self.attention = Attention(config)
        self.cross_att = CrossAttention(config)

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)

    def forward(self, x, pos_cis):
        h_attn = self.self_attention(
        # Assume num_experts is the square root of the total number of defined experts


        # Query-generation parameters


        # Create the query-generation module
        # if weight_down_embed is not None:
        #     self.to_queries = nn.Sequential(
        #         nn.Linear(config.dim, self.dim_key * 2, bias=False),
        #         # nn.Unflatten(2, (2, self.n_heads, self.dim_key))  # replaces Rearrange
        #     )

        #     # Hyperparameters
        #     self.product_key_topk = min(16, self.num_keys)  # make sure this does not exceed num_keys
        #     self.num_experts_per_head_topk = 1  # number of experts finally selected per head

    def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=True):
        # import pdb;pdb.set_trace()
        # db_value = None

        # # If weight_down_embed is present, use the Product Key mechanism
        # if self.weight_down_embed is not None:
        #     # 1. Generate queries
        #     batch_size, seq_len, dim = x.shape

        #     # collapse sequence dimension by averaging
        #     x_flat = x.mean(dim=1)  # [batch_size, dim]
        #     queries = self.to_queries(x_flat)  # [batch_size, 2*dim_key]
        #     queries = queries.reshape(batch_size, 2, self.dim_key)  # [batch_size, 2, dim_key]
        #     queries = queries.permute(1, 0, 2)  # [2, batch_size, dim_key]

        #     # 2. Compute the similarity between queries and keys
        #     sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)

        #     # 3. Top-k within each of the two subspaces
        #     scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
        #     scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
        #     indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]

        #     # 4. Combine scores and indices from the two subspaces
        #     all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
        #     all_scores = all_scores.view(*all_scores.shape[:-2], -1)

        #     all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
        #     all_indices = all_indices.view(*all_indices.shape[:-2], -1)

        #     # 5. Final top-k selection
        #     scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
        #     indices = all_indices.gather(-1, pk_indices)

        #     # 6. Fetch expert values from the embedding

        #     # Fetch values from the embedding
        #     flat_indices = indices.view(-1)  # flatten the indices into a 1-D tensor
        #     db_values = self.weight_down_embed(flat_indices)

        #     # Reshape back to the original shape
        #     db_value = db_values.view(batch_size, -1, dim)


        # Attention computation
        h_attn, past_kv = self.attention(
            self.attention_norm(x),
            pos_cis
            pos_cis,
            past_key_value=past_key_value,
            use_cache=use_cache,
            db_value=db_value
        )
        db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
        h_attn = self.cross_attention(h_attn, db_embeddings)

        h_attn = self.cross_att(h_attn, db_value)

        # Residual connection
        h = x + h_attn

        # Feed-forward network
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
        return out, past_kv

class ExtractDB(nn.Module):
    def __init__(self, params):
        # Adjust the expert count and knowledge dimension so the count has an integer square root
        super().__init__()
        self.batch_size = None
        self.dim = params.dim
        self.dim_key = self.dim // 2
        self.num_experts = 10 * 10  # 100 experts; make sure this is a perfect square
        # Set knowledge_dim to match head_dim so it can be used directly in attention
        self.head_dim = params.dim // params.n_heads
        self.knowledge_dim = 8 * params.dim

        # Use register_buffer instead of nn.Parameter to avoid gradient issues
        self.register_buffer('weight_down_embed', torch.randn(self.num_experts, self.knowledge_dim) * 0.02)

        self.num_keys = int(math.sqrt(self.num_experts)) if self.num_experts > 0 else 0
        self.product_key_topk = min(16, self.num_keys)
        self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02)
        self.num_experts_per_head_topk = 1
        self.to_queries = nn.Sequential(
            nn.Linear(params.dim, self.dim_key * 2, bias=False),
        )

    def q_to_k(self, x):
        # 1. Generate queries
        self.batch_size, seq_len, dim = x.shape

        # collapse sequence dimension by averaging
        x_flat = x.mean(dim=1)  # [batch_size, dim]

        queries = self.to_queries(x_flat)  # [batch_size, 2*dim_key]
        queries = queries.reshape(self.batch_size, 2, self.dim_key)  # [batch_size, 2, dim_key]
        queries = queries.permute(1, 0, 2)  # [2, batch_size, dim_key]

        # 2. Compute the similarity between queries and keys
        sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)

        # 3. Top-k within each of the two subspaces
        scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
        scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
        indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]

        # 4. Combine scores and indices from the two subspaces
        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
        all_scores = all_scores.view(*all_scores.shape[:-2], -1)

        all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
        all_indices = all_indices.view(*all_indices.shape[:-2], -1)

        # 5. Final top-k selection
        scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
        indices = all_indices.gather(-1, pk_indices)
        flat_indices = indices.view(-1)
        return flat_indices
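The product-key trick in q_to_k addresses num_keys² experts with two top-k searches over only num_keys sub-keys each. A standalone sketch of the index composition with toy sizes (not the module's real dimensions):

import torch

num_keys, topk, batch = 10, 3, 4                    # 10 x 10 = 100 virtual experts
scores_x, indices_x = torch.randn(batch, num_keys).topk(topk, dim=-1)
scores_y, indices_y = torch.randn(batch, num_keys).topk(topk, dim=-1)

# Every pair (i, j) of sub-key hits addresses expert i * num_keys + j
all_scores = (scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)).view(batch, -1)
all_indices = (indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2)).view(batch, -1)

best_scores, pos = all_scores.topk(1, dim=-1)
expert_ids = all_indices.gather(-1, pos)            # flat ids in [0, num_keys**2)
assert int(expert_ids.max()) < num_keys ** 2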

    def get_data(self, index):
        # Fetch the embeddings directly on the GPU
        db_values = self.weight_down_embed[index]
        db_value = db_values.view(self.batch_size, -1, self.dim)
        return db_value

    @torch.no_grad()
    def updata_value(self, k, v):
        # Update the buffer values in place (no gradients needed)
        v_reshaped = v.view(v.size(0), -1)
        # Make sure the dtypes match
        v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
        self.weight_down_embed[k] = v_reshaped



class MiniMindLM(PreTrainedModel):
@@ -509,63 +515,129 @@ class MiniMindLM(PreTrainedModel):
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
        self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
        # Remove the old weight_down_embed declaration
        self.extract_db = ExtractDB(self.params)

        # Pass self.weight_down_embed to every MiniMindBlock
        self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.tok_embeddings.weight = self.output.weight

        # Calculate input dimension
        input_dim = (self.params.max_seq_len - 1) * self.params.n_layers
        # Use a bottleneck architecture to reduce parameters
        bottleneck_dim = 256  # Significantly smaller bottleneck dimension

        # Factorized shared downsampling using two smaller convolutions
        self.shared_downsample = nn.Sequential(
            # First reduce input dimension to bottleneck
            nn.Conv1d(input_dim, bottleneck_dim, kernel_size=1, padding='same'),
            nn.ReLU(),  # Non-linearity to improve representation capacity
            # Then expand to target dimension
            nn.Conv1d(bottleneck_dim, 128*8, kernel_size=1, padding='same')
        )

        # Specific layers for v path
        self.downsample_v_specific = nn.Sequential(
            nn.Conv1d(128*8, 128, kernel_size=1, padding='same'),
            nn.Conv1d(128, 8, kernel_size=1, padding='same')
        )

        # Specific layers for q path
        self.downsample_q_specific = nn.Sequential(
            nn.Conv1d(128*8, 512, kernel_size=1, padding='same')
        )
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
        self.OUT = CausalLMOutputWithPast()
        self.freeze_embedding = False
        self.params = params

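The factorized downsampling above exists purely to save parameters. A rough sketch of the saving, assuming max_seq_len=512 and n_layers=8 as in the training configuration (other sizes as in the code):

import torch.nn as nn

input_dim, bottleneck_dim, out_dim = 511 * 8, 256, 128 * 8

direct = nn.Conv1d(input_dim, out_dim, kernel_size=1)
factored = nn.Sequential(
    nn.Conv1d(input_dim, bottleneck_dim, kernel_size=1),
    nn.ReLU(),
    nn.Conv1d(bottleneck_dim, out_dim, kernel_size=1),
)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(direct), count(factored))   # roughly 4.2M vs 1.3M parameters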
    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                step: int = 0,
                **args):
        past_key_values = past_key_values or [None] * len(self.layers)
        start_pos = args.get('start_pos', 0)
        # if self.freeze_embedding and step == 0:
        #     self.tok_embeddings.weight.requires_grad = False
        #     # Removed the knowledge_dataset.freeze_embedding setting; key updates are controlled by batch_counter
        #     # self.knowledge_dataset.freeze_embedding = True
        #     print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        past_kvs = []
        h_list = []

        for l, layer in enumerate(self.layers):
            h = layer(
                h, pos_cis
            # Database-disabled mode: replace the database lookup with a fixed value
            if self.params.disable_db:
                # Create a tensor of shape [batch_size, n_layers, dim] filled with 1e-4
                batch_size = h.size(0)
                db_value = torch.full((batch_size, self.n_layers, self.params.dim), 1e-4,
                                      dtype=h.dtype, device=h.device)
            else:
                # Normal mode: query the database
                index = self.extract_db.q_to_k(h)
                db_value = self.extract_db.get_data(index)

            h, past_kv = layer(
                h, db_value, pos_cis,
                past_key_value=past_key_values[l],
                use_cache=use_cache
            )

            past_kvs.append(past_kv)
            h_list.append(h.unsqueeze(0))

        h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3)

        # Run the database-update logic only when the database is not disabled
        if not self.params.disable_db:
            # Use detach() to cut the computation graph and avoid backpropagating more than once
            h_tensor_detached = h_tensor.detach()
            h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0], -1, self.params.dim)

            # The database update is kept separate from the main computation graph
            with torch.no_grad():
                # Compute shared downsampling layer once
                shared_features = self.shared_downsample(h_tensor_detached)
                z_v = self.downsample_v_specific(shared_features)
                z_q = self.downsample_q_specific(shared_features)
                z_k = self.extract_db.q_to_k(z_q)
                self.extract_db.updata_value(z_k, z_v)

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(self.norm(h)[:, slice_indices, :])
        # Uniformly disable aux_loss
        aux_loss = 0
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))

        # Simplified further; keep only the necessary fields
        output = CausalLMOutputWithPast(
            logits=logits,
            past_key_values=past_kvs,
        )
        output.hidden_states = h

        output.aux_loss = aux_loss

        # Try to attach other attributes (if supported)
        # try:
        #     output.hidden_states = h
        # except:
        #     pass

        return output

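The update path in forward() works because the written values are detached before entering torch.no_grad(), so the main loss can still backpropagate exactly once. A self-contained sketch of that pattern (buffer and shapes are illustrative):

import torch

store = torch.zeros(100, 8)                      # knowledge buffer, never part of the graph
h = torch.randn(4, 8, requires_grad=True)        # hidden states from the main graph

loss = h.sum()
with torch.no_grad():                            # write path, detached as in forward()
    store[torch.tensor([3, 7, 11, 42])] = h[:4].detach().to(store.dtype)

loss.backward()                                  # still works; the write added no graph edges
assert h.grad is not None and not store.requires_grad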
    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
                 stream=False, rp=1., use_cache=True, pad_token_id=0, num_return_sequences=1, **args):
        # Streaming generation
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)

        # Direct generation
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
                tokens_list = [tokens[:, -1:] for tokens in out]
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                full_sequence = torch.cat([non_pad, gen], dim=-1)
@@ -582,20 +654,17 @@ class MiniMindLM(PreTrainedModel):
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        start = input_ids.shape[1]
        for _ in range(max_new_tokens):
            # Pass the full input_ids every time; no KV cache is used
            out = self(input_ids, **args)
            logits = out.logits[:, -1, :]  # take the logits at the last position

            # Repetition penalty
    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
        start, first_seq, past_kvs = input_ids.shape[1], True, None
        while input_ids.shape[1] < max_new_tokens - 1:
            if first_seq or not use_cache:
                out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False
            else:
                out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
                           start_pos=input_ids.shape[1] - 1, **args)
            logits, past_kvs = out.logits[:, -1, :], out.past_key_values
            logits[:, list(set(input_ids.tolist()[0]))] /= rp

            # Temperature sampling
            logits /= (temperature + 1e-9)

            # Top-p sampling
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
@@ -605,14 +674,8 @@ class MiniMindLM(PreTrainedModel):
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')

            # Sample the next token
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)

            # Yield the newly generated part
            yield input_ids[:, start:]

            # Stop generation once the end-of-sequence token appears
            if input_ids_next.item() == eos_token_id:
                break

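For reference, the top-p filtering used in _stream can be exercised on its own; with the toy logits below and top_p=0.9, the two lowest-probability tokens are masked before sampling:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
top_p = 0.9

sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
remove = cumulative > top_p
remove[:, 1:] = remove[:, :-1].clone()     # shift so the first token over the threshold survives
remove[:, 0] = False
logits[remove.scatter(1, sorted_indices, remove)] = -float('Inf')

next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)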
@@ -1,732 +0,0 @@
import math
import struct
import inspect
import time
import gc
# Two-dimensional subspace decomposition + gradient updates
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast



class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self.weight * self._norm(x.float()).type_as(x)


def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis


def apply_rotary_emb(xq, xk, pos_cis):
    def unite_shape(pos_cis, x):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

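A quick sanity check on the rotary embedding helpers above: rotating by unit complex numbers should leave the per-head vector norms unchanged (shapes here are arbitrary toy values):

import torch

pos_cis = precompute_pos_cis(dim=64, end=16)
xq = torch.randn(2, 16, 8, 64)                 # [batch, seq_len, heads, head_dim]
xk = torch.randn(2, 16, 8, 64)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, pos_cis)
assert torch.allclose(xq.norm(dim=-1), xq_rot.norm(dim=-1), atol=1e-4)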
class KnowledgeDataset(nn.Module):
    def __init__(self, params, tok_embeddings, is_train=True):
        super().__init__()
        self.is_train = is_train
        self.params = params
        self.tok_embeddings = tok_embeddings

        # Embedding parameters
        self.knowledge_dim = params.knowledge_dim
        self.key_dim = self.knowledge_dim // 2
        self.to_queries = nn.Sequential(
            nn.Linear(params.dim, self.knowledge_dim, bias=False),
        )

        ## Database parameters
        self.knowledge_num = params.knowledge_num
        self.knowledge_length = params.knowledge_length

        # Store the keys as a two-dimensional factorized space, as trainable parameters
        self.num_keys = int(math.sqrt(self.knowledge_num))
        # Make sure the keys are trainable
        self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
        self.product_key_topk = min(16, self.num_keys)

        # Knowledge storage - registered as a buffer because it holds integer indices and needs no gradients
        self.register_buffer('knowledge_dataset',
                             torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))

        # Step counter used to adjust weights dynamically
        self.step_counter = 0

        # Batch counter and update-frequency logic removed

    def intelligent_selection(self, query, all_scores, all_indices):
        """Intelligent hierarchical selection strategy"""
        if not self.is_train:
            return all_scores, all_indices

        batch_size = all_scores.size(0)
        device = all_scores.device
        dtype = all_scores.dtype

        # Record memory state before entering intelligent selection
        if hasattr(self, 'step_counter'):
            self.step_counter += 1
            # GPU memory logging disabled for performance
            # if self.step_counter % 50 == 0:  # log once every 50 calls
            #     if torch.cuda.is_available():
            #         allocated_before = torch.cuda.memory_allocated() / (1024**3)
            #         print(f"[INTEL_SELECT_ENTER] Step {self.step_counter}: GPU Memory: {allocated_before:.2f}GB")

        # Perform hierarchical selection for each batch
        enhanced_scores = all_scores.clone()
        query_features = query.mean(dim=1)  # [batch_size, dim]

        # Precompute the embeddings of all candidate entries (batched optimization)
        all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
        unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)

        # Compute the embeddings of the unique candidate entries in one batch
        candidate_tokens = self.knowledge_dataset[unique_indices]
        flat_tokens = candidate_tokens.view(-1)
        flat_embeddings = self.tok_embeddings(flat_tokens)

        # Keep the indices corresponding to flat_tokens (retained for use elsewhere)
        pre_update_indices = unique_indices.view(-1)
        pre_update_embeddings = flat_embeddings.view(
            len(unique_indices), self.knowledge_length, -1
        )

        unique_candidate_features = flat_embeddings.view(
            len(unique_indices), self.knowledge_length, -1
        ).mean(dim=1)  # [num_unique_candidates, dim]

        # Normalize candidate features (optimizes the similarity computation)
        normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
        normalized_queries = F.normalize(query_features, dim=-1)

        # Collect best_tokens for every batch item
        batch_best_tokens = []
        batch_best_tokens_embeddings = []

        for batch_idx in range(batch_size):
            indices = all_indices[batch_idx]

            # Feature indices of the current batch item's candidate entries
            start_idx = batch_idx * len(indices)
            end_idx = start_idx + len(indices)
            batch_inverse_indices = inverse_indices[start_idx:end_idx]

            # Use the precomputed normalized features for the similarity computation
            batch_candidate_features = normalized_candidates[batch_inverse_indices]
            query_feature = normalized_queries[batch_idx]

            # Cosine similarity via matrix multiplication
            similarity_scores = torch.mv(batch_candidate_features, query_feature)

            # Index of the highest similarity score
            max_similarity_idx = torch.argmax(similarity_scores)

            # Candidate entry index with the highest similarity
            best_candidate_idx = indices[max_similarity_idx]

            # Fetch the corresponding tokens
            best_tokens = self.knowledge_dataset[best_candidate_idx]
            best_tokens_embeddings = self.tok_embeddings(best_tokens)

            # Append the current batch item's best_tokens to the list
            batch_best_tokens.append(best_tokens)
            batch_best_tokens_embeddings.append(best_tokens_embeddings)

        # Stack the best_tokens of all batch items into one tensor
        # [batch_size, knowledge_length]
        all_best_tokens = torch.stack(batch_best_tokens, dim=0)
        all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)

        # Clean up intermediate tensors to avoid memory leaks
        del all_candidate_indices, unique_indices, inverse_indices
        del unique_candidate_features, normalized_candidates, normalized_queries
        del batch_best_tokens, batch_best_tokens_embeddings
        del flat_tokens, flat_embeddings, pre_update_embeddings

        # Memory logging after intelligent selection (disabled for performance)
        # if hasattr(self, 'step_counter') and self.step_counter % 50 == 0:
        #     if torch.cuda.is_available():
        #         allocated_after = torch.cuda.memory_allocated() / (1024**3)
        #         print(f"[INTEL_SELECT_EXIT] Step {self.step_counter}: GPU Memory: {allocated_after:.2f}GB")

        # Force garbage collection (only on monitored steps)
        if hasattr(self, 'step_counter') and self.step_counter % 100 == 0:
            gc.collect()
            # if torch.cuda.is_available():
            #     torch.cuda.empty_cache()

        return all_best_tokens, all_best_tokens_embeddings



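The torch.unique(..., return_inverse=True) step above is what lets embeddings be computed once per unique candidate and then broadcast back to every batch position. In isolation:

import torch

all_candidates = torch.tensor([5, 9, 5, 2, 9, 5])          # duplicates across the batch
unique, inverse = torch.unique(all_candidates, return_inverse=True)
embeddings = torch.randn(len(unique), 4)                   # computed once per unique id
per_candidate = embeddings[inverse]                        # mapped back to every position
assert per_candidate.shape == (6, 4)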
    def search_index(self, x):
        batch_size, seq_len, dim = x.shape

        # 1. Average over the sequence dimension
        x_flat = x.mean(dim=1)  # [batch_size, dim]

        # 2. Generate the query vector and reshape it into two sub-queries
        queries = self.to_queries(x_flat)  # [batch_size, knowledge_dim]
        queries = queries.reshape(batch_size, 2, self.key_dim)  # [batch_size, 2, key_dim]
        # Reorder the dimensions so the subspace dimension comes first
        queries = queries.permute(1, 0, 2)  # [2, batch_size, key_dim]

        # 3. Compute the similarity within each subspace
        sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)

        # 4. Top-k within each of the two subspaces
        scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
        scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
        indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]

        # 5. Combine the results from the two subspaces
        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)  # [batch_size, topk, topk]
        all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)  # [batch_size, topk, topk]

        # 6. Reshape the results to two dimensions
        all_scores = all_scores.reshape(batch_size, -1)  # [batch_size, topk*topk]
        all_indices = all_indices.reshape(batch_size, -1)  # [batch_size, topk*topk]

        # 7. Select the final top-k results
        scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
        indices = torch.gather(all_indices, 1, indices_of_indices)

        # 8. Apply the intelligent hierarchical selection strategy
        best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)

        return best_tokens, best_tokens_embeddings

class CrossAttention(nn.Module):
    def __init__(
            self,
            config
    ):
        super().__init__()
        self.config = config
        self.num_heads = 8
        self.head_dim = self.config.dim // self.num_heads
        self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)

        self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)

    def forward(self, x, db, context_mask=None, pos_emb=None):
        batch_size = x.size(0)

        # Monitor memory at the start of cross-attention (logging disabled for performance)
        if not hasattr(self, 'call_counter'):
            self.call_counter = 0
        self.call_counter += 1

        # GPU memory logging disabled for performance
        # if self.call_counter % 100 == 0 and torch.cuda.is_available():
        #     allocated_before = torch.cuda.memory_allocated() / (1024**3)
        #     print(f"[CROSS_ATTN_ENTER] Call {self.call_counter}: GPU Memory: {allocated_before:.2f}GB")

        # Split into multiple heads
        q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if pos_emb is not None:
            pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
            q = q + pos_emb
            k = k + pos_emb
            v = v + pos_emb

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if context_mask is not None:
            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)

        attn_weights = F.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_weights, v)

        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)

        context = self.to_out(context)

        # Clean up intermediate tensors
        del q, k, v, attn_scores, attn_weights

        # Memory monitoring at the end of cross-attention (disabled for performance)
        # if self.call_counter % 100 == 0 and torch.cuda.is_available():
        #     allocated_after = torch.cuda.memory_allocated() / (1024**3)
        #     print(f"[CROSS_ATTN_EXIT] Call {self.call_counter}: GPU Memory: {allocated_after:.2f}GB")

        return context

class Attention(nn.Module):
    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self,
                x: torch.Tensor,
                pos_cis: torch.Tensor):
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output


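A small equivalence check for the two attention branches above: the explicit upper-triangular -inf mask and F.scaled_dot_product_attention(..., is_causal=True) produce the same result (toy shapes, dropout off):

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 5, 8)            # [batch, heads, seq_len, head_dim]
mask = torch.triu(torch.full((5, 5), float("-inf")), diagonal=1)
scores = (q @ k.transpose(-2, -1)) / 8 ** 0.5 + mask
manual = F.softmax(scores, dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(manual, fused, atol=1e-5)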
class FeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        if config.hidden_dim is None:
            hidden_dim = 4 * config.dim
            hidden_dim = int(2 * hidden_dim / 3)
            config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))


class MoEGate(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss,
                                torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
                    seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = 0
        return topk_idx, topk_weight, aux_loss


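The non-seq_aux branch of the auxiliary loss can be reproduced in a few lines; it grows when routing collapses onto few experts and is smallest for a uniform token spread (toy numbers, alpha hypothetical):

import torch
import torch.nn.functional as F

n_experts, alpha = 4, 0.01
scores = torch.rand(32, n_experts).softmax(dim=-1)     # gate probabilities per token
topk_idx = scores.topk(2, dim=-1).indices
ce = F.one_hot(topk_idx.view(-1), num_classes=n_experts).float().mean(0)
Pi = scores.mean(0)                                    # mean gate probability per expert
fi = ce * n_experts                                    # fraction of tokens routed, scaled
aux_loss = (Pi * fi).sum() * alpha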
class MOEFeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            self.shared_experts = FeedForward(config)

    def forward(self, x):
        identity = x
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape
        # Select experts through the gating mechanism
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            y = torch.empty_like(x, dtype=torch.float16)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # keep the dtype consistent
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.config.num_experts_per_tok
        # When tokens_per_expert = [6, 15, 20, 26], tokens_per_expert.shape[0] is the number of experts (4 here),
        # and with token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...]
        # token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the token positions handled by expert 0
        # (a token may be processed by several experts, depending on num_experts_per_tok);
        # the next nine positions token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...] belong to expert 1, and so on
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache


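The dispatch in moe_infer sorts token slots by expert id so that each expert processes one contiguous chunk. Its indexing, stripped to the essentials:

import torch

flat_expert_indices = torch.tensor([1, 0, 2, 1, 0, 2, 1, 3])   # expert id per token slot
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
for i, end_idx in enumerate(tokens_per_expert):
    start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
    print(i, idxs[start_idx:end_idx].tolist())                 # slots handled by expert i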
class TripleExtractionHead(nn.Module):
    """Triple-extraction task head"""
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config

        # Maximum lengths for the triple components
        self.max_subject_len = config.max_subject_len
        self.max_predicate_len = config.max_predicate_len
        self.max_object_len = config.max_object_len

        # Self-attention
        self.self_attention = Attention(config)
        self.self_attn_norm = RMSNorm(config.dim, eps=config.norm_eps)

        # Cross-attention (for subject and object extraction)
        # self.cross_attention_subject = CrossAttention(config)
        # self.cross_attention_object = CrossAttention(config)

        # Normalization layers
        self.subject_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.object_norm = RMSNorm(config.dim, eps=config.norm_eps)

        # Feed-forward networks
        self.predicate_ff = FeedForward(config)
        # self.subject_ff = FeedForward(config)
        # self.object_ff = FeedForward(config)

        # Output projections - changed to support sequence prediction
        self.predicate_output = nn.Linear(config.dim, 264, bias=False)
        # self.subject_output = nn.Linear(config.dim, self.max_subject_len * config.dim, bias=False)
        # self.object_output = nn.Linear(config.dim, self.max_object_len * config.dim, bias=False)

        print(f"Triple-extraction head configuration:")
        print(f"- max subject length: {self.max_subject_len}")
        print(f"- max predicate length: {self.max_predicate_len}")
        print(f"- max object length: {self.max_object_len}")

    def forward(self, h, pos_cis):
        """
        Args:
            h: [batch_size, seq_len, dim] - hidden states from the transformer layers
            pos_cis: positional encoding
        Returns:
            predicate_logits: [batch_size, seq_len, max_predicate_len, vocab_size] - predicate sequence prediction
            subject_logits: [batch_size, seq_len, max_subject_len, vocab_size] - subject sequence prediction
            object_logits: [batch_size, seq_len, max_object_len, vocab_size] - object sequence prediction
        """
        batch_size, seq_len, dim = h.shape

        # 1. h goes through self-attention to give h1
        h1 = self.self_attention(self.self_attn_norm(h), pos_cis)
        h1 = h + h1  # residual connection

        # 2. h1 goes through the feed-forward network to give the predicate output
        predicate_features = self.predicate_ff(h1)
        predicate_features = predicate_features.mean(dim=1)
        predicate_class = self.predicate_output(predicate_features)  # [batch_size, max_predicate_len * vocab_size]

        # # 3. h1 goes through cross-attention (k and v are both h) to give h2
        # h2 = self.cross_attention_subject(h1, h)  # query is h1; key and value are both h
        # h2 = h1 + h2  # residual connection

        # # 4. h2 goes through the feed-forward network to give the subject output
        # subject_features = self.subject_ff(self.subject_norm(h2))
        # subject_features = subject_features.mean(dim=1)
        # subject_raw = self.subject_output(subject_features)  # [batch_size, max_subject_len * vocab_size]
        # subject_logits = subject_raw.view(batch_size, self.max_subject_len, -1)

        # # 5. h2 goes through cross-attention (k and v are both h) to give h3
        # h3 = self.cross_attention_object(h2, h)  # query is h2; key and value are both h
        # h3 = h2 + h3  # residual connection

        # # 6. h3 goes through the feed-forward network to give the object output
        # object_features = self.object_ff(self.object_norm(h3))
        # object_features = object_features.mean(dim=1)
        # object_raw = self.object_output(object_features)  # [batch_size, max_object_len * vocab_size]
        # object_logits = object_raw.view(batch_size, self.max_object_len, -1)

        return predicate_class


class MiniMindBlock(nn.Module):
    def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.self_attention = Attention(config)
        self.cross_attention = CrossAttention(config)
        self.knowledge_dataset = knowledge_dataset

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)

    def forward(self, x, pos_cis):
        h_attn = self.self_attention(
            self.attention_norm(x),
            pos_cis
        )
        db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
        h_attn = self.cross_attention(h_attn, db_embeddings)
        h = x + h_attn
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class MiniMindLM(PreTrainedModel):
    config_class = LMConfig

    def __init__(self, params: LMConfig = None, mode="triple"):
        self.params = params or LMConfig()
        super().__init__(self.params)
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
        self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.tok_embeddings.weight = self.output.weight

        # Add the triple-extraction task head (trainable)
        self.triple_extraction_head = TripleExtractionHead(params)
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
        self.OUT = CausalLMOutputWithPast()
        self.freeze_embedding = False

        self.mode = mode

        # Freeze the weights of all designated components
        self._freeze_components()

    def _freeze_components(self):
        """Freeze the weights of the designated components"""
        # Freeze the token embeddings
        for param in self.tok_embeddings.parameters():
            param.requires_grad = False

        # Freeze the knowledge database
        for param in self.knowledge_dataset.parameters():
            param.requires_grad = False

        # Freeze all transformer layers
        for param in self.layers.parameters():
            param.requires_grad = False

        # Freeze the output layer
        for param in self.output.parameters():
            param.requires_grad = False

        # pos_cis is a buffer and never needs gradients, but set it explicitly for clarity
        # (buffers default to requires_grad=False)
        if hasattr(self, 'pos_cis'):
            self.pos_cis.requires_grad = False

        print("The weights of the following components are frozen:")
        print("- tok_embeddings")
        print("- knowledge_dataset")
        print("- layers (all transformer layers)")
        print("- output")
        print("- pos_cis")
        print("Note: triple_extraction_head remains trainable")

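A quick audit of what _freeze_components leaves trainable, assuming LMConfig() can be constructed with its defaults (hypothetical usage, not part of the original file):

model = MiniMindLM(LMConfig())
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(len(trainable), "trainable tensors;", frozen, "frozen parameters")
# only triple_extraction_head.* should appear in the trainable list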
    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                step: int = 0,
                **args):
        start_pos = args.get('start_pos', 0)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        for l, layer in enumerate(self.layers):
            h = layer(
                h, pos_cis
            )

        # Apply the triple-extraction task head
        predicate_class = self.triple_extraction_head(h, pos_cis)

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(self.norm(h)[:, slice_indices, :])
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))

        # Simplified further; keep only the necessary fields
        output = CausalLMOutputWithPast(
            logits=logits,
        )
        output.hidden_states = h
        output.aux_loss = aux_loss

        # Attach the triple-extraction result
        # Note: the dimensions are now [batch_size, seq_len, max_len, vocab_size]
        output.predicate_class = predicate_class

        return output

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        # Streaming generation
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)

        # Direct generation
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                tokens_list = [tokens[:, -1:] for tokens in out]
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                full_sequence = torch.cat([non_pad, gen], dim=-1)
                generated.append(full_sequence)

        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        start, first_seq, past_kvs = input_ids.shape[1], True, None
        while input_ids.shape[1] < max_new_tokens - 1:
            if first_seq:
                out, first_seq = self(input_ids, **args), False
            else:
                out = self(input_ids[:, -1:],
                           start_pos=input_ids.shape[1] - 1, **args)
            logits, past_kvs = out.logits[:, -1, :], out.past_key_values
            logits[:, list(set(input_ids.tolist()[0]))] /= rp
            logits /= (temperature + 1e-9)
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break

@@ -1,488 +0,0 @@
import math
import struct
import inspect
import time
import gc
# Two-dimensional subspace decomposition + gradient updates
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast



class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self.weight * self._norm(x.float()).type_as(x)


def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis


def apply_rotary_emb(xq, xk, pos_cis):
    def unite_shape(pos_cis, x):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

class KnowledgeDataset(nn.Module):
    def __init__(self, params, tok_embeddings, is_train=True):
        super().__init__()
        self.is_train = is_train
        self.params = params
        self.tok_embeddings = tok_embeddings

        # Embedding parameters
        self.knowledge_dim = params.knowledge_dim
        self.key_dim = self.knowledge_dim // 2
        self.to_queries = nn.Sequential(
            nn.Linear(params.dim, self.knowledge_dim, bias=False),
        )

        ## Database parameters
        self.knowledge_num = params.knowledge_num
        self.knowledge_length = params.knowledge_length

        # Store the keys as a two-dimensional factorized space, as trainable parameters
        self.num_keys = int(math.sqrt(self.knowledge_num))
        # Make sure the keys are trainable
        self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.key_dim) * 0.02, requires_grad=True)
        self.product_key_topk = min(16, self.num_keys)

        # Knowledge storage - registered as a buffer because it holds integer indices and needs no gradients
        self.register_buffer('knowledge_dataset',
                             torch.randint(low=0, high=params.vocab_size, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long))

        # Step counter used to adjust weights dynamically
        self.step_counter = 0

        # Batch counter and update-frequency logic removed

    def intelligent_selection(self, query, all_scores, all_indices):
        """Intelligent hierarchical selection strategy"""
        if not self.is_train:
            return all_scores, all_indices

        batch_size = all_scores.size(0)
        device = all_scores.device
        dtype = all_scores.dtype

        # Record memory state before entering intelligent selection
        if hasattr(self, 'step_counter'):
            self.step_counter += 1
            # GPU memory logging disabled for performance
            # if self.step_counter % 50 == 0:  # log once every 50 calls
            #     if torch.cuda.is_available():
            #         allocated_before = torch.cuda.memory_allocated() / (1024**3)
            #         print(f"[INTEL_SELECT_ENTER] Step {self.step_counter}: GPU Memory: {allocated_before:.2f}GB")

        # Perform hierarchical selection for each batch
        enhanced_scores = all_scores.clone()
        query_features = query.mean(dim=1)  # [batch_size, dim]

        # Precompute the embeddings of all candidate entries (batched optimization)
        all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
        unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)

        # Compute the embeddings of the unique candidate entries in one batch
        candidate_tokens = self.knowledge_dataset[unique_indices]
        flat_tokens = candidate_tokens.view(-1)
        flat_embeddings = self.tok_embeddings(flat_tokens)

        # Keep the indices corresponding to flat_tokens (retained for use elsewhere)
        pre_update_indices = unique_indices.view(-1)
        pre_update_embeddings = flat_embeddings.view(
            len(unique_indices), self.knowledge_length, -1
        )

        unique_candidate_features = flat_embeddings.view(
            len(unique_indices), self.knowledge_length, -1
        ).mean(dim=1)  # [num_unique_candidates, dim]

        # Normalize candidate features (optimizes the similarity computation)
        normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
        normalized_queries = F.normalize(query_features, dim=-1)

        # Collect best_tokens for every batch item
        batch_best_tokens = []
        batch_best_tokens_embeddings = []

        for batch_idx in range(batch_size):
            indices = all_indices[batch_idx]

            # Feature indices of the current batch item's candidate entries
            start_idx = batch_idx * len(indices)
            end_idx = start_idx + len(indices)
            batch_inverse_indices = inverse_indices[start_idx:end_idx]

            # Use the precomputed normalized features for the similarity computation
            batch_candidate_features = normalized_candidates[batch_inverse_indices]
            query_feature = normalized_queries[batch_idx]

            # Cosine similarity via matrix multiplication
            similarity_scores = torch.mv(batch_candidate_features, query_feature)

            # Index of the highest similarity score
            max_similarity_idx = torch.argmax(similarity_scores)

            # Candidate entry index with the highest similarity
            best_candidate_idx = indices[max_similarity_idx]

            # Fetch the corresponding tokens
            best_tokens = self.knowledge_dataset[best_candidate_idx]
            best_tokens_embeddings = self.tok_embeddings(best_tokens)

            # Append the current batch item's best_tokens to the list
            batch_best_tokens.append(best_tokens)
            batch_best_tokens_embeddings.append(best_tokens_embeddings)

        # Stack the best_tokens of all batch items into one tensor
        # [batch_size, knowledge_length]
        all_best_tokens = torch.stack(batch_best_tokens, dim=0)
        all_best_tokens_embeddings = torch.stack(batch_best_tokens_embeddings, dim=0)

        # Clean up intermediate tensors to avoid memory leaks
        del all_candidate_indices, unique_indices, inverse_indices
        del unique_candidate_features, normalized_candidates, normalized_queries
        del batch_best_tokens, batch_best_tokens_embeddings
        del flat_tokens, flat_embeddings, pre_update_embeddings

        # Memory logging after intelligent selection (disabled for performance)
        # if hasattr(self, 'step_counter') and self.step_counter % 50 == 0:
        #     if torch.cuda.is_available():
        #         allocated_after = torch.cuda.memory_allocated() / (1024**3)
        #         print(f"[INTEL_SELECT_EXIT] Step {self.step_counter}: GPU Memory: {allocated_after:.2f}GB")

        # Force garbage collection (only on monitored steps)
        if hasattr(self, 'step_counter') and self.step_counter % 100 == 0:
            gc.collect()
            # if torch.cuda.is_available():
            #     torch.cuda.empty_cache()

        return all_best_tokens, all_best_tokens_embeddings



    def search_index(self, x):
        batch_size, seq_len, dim = x.shape

        # 1. Average over the sequence dimension
        x_flat = x.mean(dim=1)  # [batch_size, dim]

        # 2. Generate the query vector and reshape it into two sub-queries
        queries = self.to_queries(x_flat)  # [batch_size, knowledge_dim]
        queries = queries.reshape(batch_size, 2, self.key_dim)  # [batch_size, 2, key_dim]
        # Reorder the dimensions so the subspace dimension comes first
        queries = queries.permute(1, 0, 2)  # [2, batch_size, key_dim]

        # 3. Compute the similarity within each subspace
        sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)

        # 4. Top-k within each of the two subspaces
        scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
        scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
        indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]

        # 5. Combine the results from the two subspaces
        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)  # [batch_size, topk, topk]
        all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)  # [batch_size, topk, topk]

        # 6. Reshape the results to two dimensions
        all_scores = all_scores.reshape(batch_size, -1)  # [batch_size, topk*topk]
        all_indices = all_indices.reshape(batch_size, -1)  # [batch_size, topk*topk]

        # 7. Select the final top-k results
        scores, indices_of_indices = all_scores.topk(self.product_key_topk, dim=-1)
        indices = torch.gather(all_indices, 1, indices_of_indices)

        # 8. Apply the intelligent hierarchical selection strategy
        best_tokens, best_tokens_embeddings = self.intelligent_selection(x, scores, indices)

        return best_tokens, best_tokens_embeddings

class CrossAttention(nn.Module):
    def __init__(
            self,
            config
    ):
        super().__init__()
        self.config = config
        self.num_heads = 8
        self.head_dim = self.config.dim // self.num_heads
        self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)

        self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)

    def forward(self, x, db, context_mask=None, pos_emb=None):
        batch_size = x.size(0)

        # Monitor memory at the start of cross-attention (logging disabled for performance)
        if not hasattr(self, 'call_counter'):
            self.call_counter = 0
        self.call_counter += 1

        # GPU memory logging disabled for performance
        # if self.call_counter % 100 == 0 and torch.cuda.is_available():
        #     allocated_before = torch.cuda.memory_allocated() / (1024**3)
        #     print(f"[CROSS_ATTN_ENTER] Call {self.call_counter}: GPU Memory: {allocated_before:.2f}GB")

        # Split into multiple heads
        q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if pos_emb is not None:
            pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
            q = q + pos_emb
            k = k + pos_emb
            v = v + pos_emb

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if context_mask is not None:
            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)

        attn_weights = F.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_weights, v)

        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)

        context = self.to_out(context)

        # Clean up intermediate tensors
        del q, k, v, attn_scores, attn_weights

        # Memory monitoring at the end of cross-attention (disabled for performance)
        # if self.call_counter % 100 == 0 and torch.cuda.is_available():
        #     allocated_after = torch.cuda.memory_allocated() / (1024**3)
        #     print(f"[CROSS_ATTN_EXIT] Call {self.call_counter}: GPU Memory: {allocated_after:.2f}GB")

        return context

class Attention(nn.Module):
|
||||
def __init__(self, args: LMConfig):
|
||||
super().__init__()
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
assert args.n_heads % self.n_kv_heads == 0
|
||||
self.n_local_heads = args.n_heads
|
||||
self.n_local_kv_heads = self.n_kv_heads
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
|
||||
self.attn_dropout = nn.Dropout(args.dropout)
|
||||
self.resid_dropout = nn.Dropout(args.dropout)
|
||||
self.dropout = args.dropout
|
||||
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
|
||||
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
|
||||
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
self.register_buffer("mask", mask, persistent=False)
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
pos_cis: torch.Tensor):
|
||||
bsz, seq_len, _ = x.shape
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
|
||||
|
||||
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
|
||||
if self.flash and seq_len != 1:
|
||||
dropout_p = self.dropout if self.training else 0.0
|
||||
output = F.scaled_dot_product_attention(
|
||||
xq, xk, xv,
|
||||
attn_mask=None,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=True
|
||||
)
|
||||
else:
|
||||
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
||||
scores += self.mask[:, :, :seq_len, :seq_len]
|
||||
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
||||
scores = self.attn_dropout(scores)
|
||||
output = scores @ xv
|
||||
|
||||
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
|
||||
output = self.resid_dropout(self.wo(output))
|
||||
return output
|
||||
|
||||
|
||||
|
||||
|
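
The flash and manual paths above are meant to be numerically equivalent: is_causal=True applies the same upper-triangular mask that the registered mask buffer encodes. A quick sanity check, assuming dropout is off (shapes and tolerance are illustrative):

import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # (batch, n_heads, seq_len, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

fast = F.scaled_dot_product_attention(q, k, v, is_causal=True)

mask = torch.triu(torch.full((16, 16), float("-inf")), diagonal=1)
slow = torch.softmax((q @ k.transpose(-2, -1)) / math.sqrt(64) + mask, dim=-1) @ v

assert torch.allclose(fast, slow, atol=1e-4)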

class MiniMindBlock(nn.Module):
    def __init__(self, layer_id: int, config: LMConfig, knowledge_dataset: KnowledgeDataset):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.self_attention = Attention(config)
        self.cross_attention = CrossAttention(config)
        self.knowledge_dataset = knowledge_dataset

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        # ffn_norm and feed_forward removed: the FeedForward layer is no longer used

    def forward(self, x, pos_cis):
        h_attn = self.self_attention(
            self.attention_norm(x),
            pos_cis
        )
        db, db_embeddings = self.knowledge_dataset.search_index(h_attn)
        h_attn = self.cross_attention(h_attn, db_embeddings)
        h = x + h_attn
        # FeedForward layer removed: return the attention output directly
        return h


class MiniMindLM(PreTrainedModel):
    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        # Bug fix: rebind params so the default LMConfig() is also used below when params is None
        self.params = params = params or LMConfig()
        super().__init__(self.params)
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.knowledge_dataset = KnowledgeDataset(params, self.tok_embeddings)
        self.layers = nn.ModuleList([MiniMindBlock(l, params, self.knowledge_dataset) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
        self.OUT = CausalLMOutputWithPast()
        self.freeze_embedding = False

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                step: int = 0,
                **args):
        start_pos = args.get('start_pos', 0)
        # if self.freeze_embedding and step == 0:
        #     self.tok_embeddings.weight.requires_grad = False
        #     # Do not set knowledge_dataset.freeze_embedding here; key updates are driven by batch_counter
        #     # self.knowledge_dataset.freeze_embedding = True
        #     print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        for l, layer in enumerate(self.layers):
            h = layer(
                h, pos_cis
            )

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(self.norm(h)[:, slice_indices, :])
        # aux_loss computation removed: the FeedForward layer is no longer used
        aux_loss = 0

        # Simplified further: keep only the fields that are needed
        output = CausalLMOutputWithPast(
            logits=logits,
        )
        output.hidden_states = h

        output.aux_loss = aux_loss

        return output

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        # Streaming generation
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)

        # Direct (non-streaming) generation
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                tokens_list = [tokens[:, -1:] for tokens in out]
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                full_sequence = torch.cat([non_pad, gen], dim=-1)
                generated.append(full_sequence)

        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        start = input_ids.shape[1]
        for _ in range(max_new_tokens):
            # Pass the full input_ids every step; no KV cache is used
            out = self(input_ids, **args)
            logits = out.logits[:, -1, :]  # logits at the last position

            # Repetition penalty
            logits[:, list(set(input_ids.tolist()[0]))] /= rp

            # Temperature scaling
            logits /= (temperature + 1e-9)

            # Top-p (nucleus) sampling
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')

            # Sample the next token
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)

            # Yield the part generated so far
            yield input_ids[:, start:]

            # Stop once the EOS token is generated
            if input_ids_next.item() == eos_token_id:
                break
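
The shifted mask in the top-p block above keeps the first token whose cumulative probability crosses top_p. A worked sketch with made-up logits:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cum = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)  # ~[0.61, 0.83, 0.97, 1.00]
remove = cum > 0.90                                           # [False, False, True, True]
remove[:, 1:] = remove[:, :-1].clone()                        # shift: keep the crossing token
remove[:, 0] = False                                          # -> [False, False, False, True]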
@ -1,386 +0,0 @@
import math
import struct
import inspect
import time

from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self.weight * self._norm(x.float()).type_as(x)


def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis


def apply_rotary_emb(xq, xk, pos_cis):
    def unite_shape(pos_cis, x):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
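
repeat_kv expands grouped key/value heads so every query head has a matching KV head (grouped-query attention). A shape sketch:

import torch

x = torch.randn(2, 16, 2, 64)     # (batch, seq_len, n_kv_heads=2, head_dim)
y = repeat_kv(x, n_rep=4)         # repeat_kv as defined above
assert y.shape == (2, 16, 8, 64)  # 2 KV heads now serve 8 query heads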

class Attention(nn.Module):
    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self,
                x: torch.Tensor,
                pos_cis: torch.Tensor,
                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                use_cache=False):
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
        # KV cache
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)
        past_kv = (xk, xv) if use_cache else None

        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2)
        )
        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output, past_kv


class FeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        if config.hidden_dim is None:
            hidden_dim = 4 * config.dim
            hidden_dim = int(2 * hidden_dim / 3)
            config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
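
The hidden_dim arithmetic above is the LLaMA-style SwiGLU sizing: 2/3 of 4*dim, rounded up to a multiple of multiple_of. A worked example with illustrative values, not taken from LMConfig:

dim, multiple_of = 512, 64           # illustrative values
hidden_dim = int(2 * (4 * dim) / 3)  # 1365
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
assert hidden_dim == 1408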

class MoEGate(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss,
                                torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
                    seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = 0
        return topk_idx, topk_weight, aux_loss
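
In the non-sequence branch above, ce holds the fraction of expert assignments routed to each expert and Pi the mean gate probability, so aux_loss = alpha * n_experts * sum_i Pi_i * ce_i; perfectly balanced routing yields exactly alpha. A numeric sketch with illustrative values:

import torch

alpha, n = 0.1, 4
ce = torch.tensor([0.25, 0.25, 0.25, 0.25])  # balanced assignment fractions
Pi = torch.tensor([0.25, 0.25, 0.25, 0.25])  # balanced mean gate probabilities
aux = (Pi * (ce * n)).sum() * alpha          # == alpha == 0.1
assert torch.isclose(aux, torch.tensor(alpha))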

class MOEFeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            self.shared_experts = FeedForward(config)

    def forward(self, x):
        identity = x
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape
        # Select experts via the gating mechanism
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            y = torch.empty_like(x, dtype=torch.float16)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # keep dtypes consistent
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.config.num_experts_per_tok
        # Example: with tokens_per_expert = [6, 15, 20, 26], tokens_per_expert.shape[0] is the
        # number of experts (4 here), and with token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12, ...]
        # token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the 6 token positions handled by expert 0
        # (a token may be routed to several experts, depending on num_experts_per_tok);
        # the next 9 positions, token_idxs[6:15] -> [4, 5, 6, 10, 11, 12, ...], belong to expert 1, and so on.
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache


class MiniMindBlock(nn.Module):
    def __init__(self, layer_id: int, config: LMConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.attention = Attention(config)

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)

    def forward(self, x, pos_cis, past_key_value=None, use_cache=False):
        h_attn, past_kv = self.attention(
            self.attention_norm(x),
            pos_cis,
            past_key_value=past_key_value,
            use_cache=use_cache
        )
        h = x + h_attn
        out = h + self.feed_forward(self.ffn_norm(h))
        return out, past_kv


class MiniMindLM(PreTrainedModel):
    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        # Bug fix: rebind params so the default LMConfig() is also used below when params is None
        self.params = params = params or LMConfig()
        super().__init__(self.params)
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer("pos_cis",
                             precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
                             persistent=False)
        self.OUT = CausalLMOutputWithPast()

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                **args):
        past_key_values = past_key_values or [None] * len(self.layers)
        start_pos = args.get('start_pos', 0)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        past_kvs = []
        for l, layer in enumerate(self.layers):
            h, past_kv = layer(
                h, pos_cis,
                past_key_value=past_key_values[l],
                use_cache=use_cache
            )
            past_kvs.append(past_kv)

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(self.norm(h)[:, slice_indices, :])
        # aux_loss is uniformly left unused here
        aux_loss = 0
        self.OUT.__setitem__('last_hidden_state', h)
        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('aux_loss', aux_loss)
        self.OUT.__setitem__('past_key_values', past_kvs)
        return self.OUT

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., use_cache=True, pad_token_id=0, num_return_sequences=1, **args):
        # Streaming generation
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)

        # Direct (non-streaming) generation
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
                tokens_list = [tokens[:, -1:] for tokens in out]
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                full_sequence = torch.cat([non_pad, gen], dim=-1)
                generated.append(full_sequence)

        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
        start, first_seq, past_kvs = input_ids.shape[1], True, None
        while input_ids.shape[1] < max_new_tokens - 1:
            if first_seq or not use_cache:
                out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False
            else:
                out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
                           start_pos=input_ids.shape[1] - 1, **args)
            logits, past_kvs = out.logits[:, -1, :], out.past_key_values
            logits[:, list(set(input_ids.tolist()[0]))] /= rp
            logits /= (temperature + 1e-9)
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break
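
With use_cache=True the loop above pays for a full forward pass only once; each later step feeds just the newest token plus start_pos while the per-layer (xk, xv) caches grow along dim=1. A minimal driver, kept as comments since model and tokenizer construction are illustrative:

# model = MiniMindLM(LMConfig())
# ids = tokenizer("Hello", return_tensors="pt").input_ids
# last = None
# for new_tokens in model.generate(ids, stream=True, use_cache=True):
#     last = new_tokens  # each yield is every token generated so far, not a delta
# print(tokenizer.decode(last[0]))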
@ -1,43 +0,0 @@
{
  "add_bos_token": false,
  "add_eos_token": false,
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "0": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<|im_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "<|im_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "additional_special_tokens": [],
  "bos_token": "<|im_start|>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "<|im_end|>",
  "legacy": true,
  "model_max_length": 32768,
  "pad_token": "<unk>",
  "sp_model_kwargs": {},
  "spaces_between_special_tokens": false,
  "tokenizer_class": "PreTrainedTokenizerFast",
  "unk_token": "<unk>",
  "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\n你是 MiniMind,是一个有用的人工智能助手。<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"
}
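
A sketch of how this chat_template renders with transformers (the tokenizer path is an assumed local directory, not confirmed by this diff):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")  # assumed path
text = tok.apply_chat_template([{"role": "user", "content": "Hi"}], tokenize=False)
# -> '<|im_start|>system\n你是 MiniMind,是一个有用的人工智能助手。<|im_end|>\n'
#    '<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n'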
@ -1,133 +0,0 @@
import json
import os
import datetime
from typing import List, Dict, Any

# Configuration
json_path = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/dataset/combined.json"
prepare_num = 1048576  # number of records for database_init.json; adjust as needed
output_dir = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind/dataset/"


def convert_to_database_init_format(sentences: List[str], importance_score: float = 10.0) -> Dict[str, Any]:
    """
    Convert a list of sentences into the database_init.json format.

    Args:
        sentences: list of sentences
        importance_score: importance score, defaults to 10.0

    Returns:
        the converted data as a dict
    """
    # Build the sentence records
    sentence_data = []
    for sentence in sentences:
        sentence_item = {
            "original_sentence": sentence,
            "corrected_sentence": sentence,  # same as original_sentence
            "importance_score": importance_score
        }
        sentence_data.append(sentence_item)

    # Build the full data structure
    result = {
        "metadata": {
            "batch_number": 1,
            "batch_size": len(sentences),
            "total_processed_count": len(sentences),
            "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "total_sentences": len(sentences),
            "duplicates_removed": 0  # this function does no deduplication, so 0
        },
        "sentences": sentence_data
    }

    return result
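
For a single input sentence the function produces a structure like the following (timestamp illustrative):

# convert_to_database_init_format(["Paris capital of France"]) ->
# {
#     "metadata": {"batch_number": 1, "batch_size": 1, "total_processed_count": 1,
#                  "timestamp": "2024-01-01 00:00:00", "total_sentences": 1,
#                  "duplicates_removed": 0},
#     "sentences": [{"original_sentence": "Paris capital of France",
#                    "corrected_sentence": "Paris capital of France",
#                    "importance_score": 10.0}]
# }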

def preprocess_combined_json():
    # Load the raw data
    print("Reading combined.json...")
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    total_count = len(data)
    print(f"{total_count} records in total")

    # Process every record: join subject, predicate, object into a sentence, keeping the original record
    print("Joining records into sentences...")
    sentence_to_original = {}  # maps each sentence back to its original record
    all_sentences = []

    for i, item in enumerate(data):
        # Join subject, predicate, object into one sentence
        sentence = f"{item['subject']} {item['predicate']} {item['object']}"
        all_sentences.append(sentence)

        # Map the sentence to its original record (on duplicates, keep the first occurrence)
        if sentence not in sentence_to_original:
            sentence_to_original[sentence] = item

        if (i + 1) % 100000 == 0:
            print(f"Processed {i + 1}/{total_count} records")

    print(f"Sentence joining done: {len(all_sentences)} sentences")

    # Deduplicate
    print("Deduplicating...")
    unique_sentences = list(set(all_sentences))
    duplicates_removed = len(all_sentences) - len(unique_sentences)
    print(f"Deduplication done: {len(all_sentences)} before, {len(unique_sentences)} after, {duplicates_removed} duplicates removed")

    # Check whether there is enough deduplicated data
    if len(unique_sentences) < prepare_num:
        print(f"Warning: deduplicated count ({len(unique_sentences)}) is below the requested amount ({prepare_num})")
        print(f"Using all {len(unique_sentences)} deduplicated sentences")
        selected_sentences = unique_sentences
    else:
        print(f"Selecting the first {prepare_num} deduplicated sentences")
        selected_sentences = unique_sentences[:prepare_num]

    # Convert to the database_init.json format
    print("Converting to the database_init.json format...")
    database_init_data = convert_to_database_init_format(selected_sentences, importance_score=10.0)

    # Update duplicates_removed in the metadata
    database_init_data["metadata"]["duplicates_removed"] = duplicates_removed

    # Save database_init.json
    database_output_path = os.path.join(output_dir, "database_init_from_combined.json")
    print(f"Saving {database_output_path}...")
    with open(database_output_path, "w", encoding="utf-8") as f:
        json.dump(database_init_data, f, ensure_ascii=False, indent=2)

    print(f"database_init_from_combined.json saved with {len(selected_sentences)} records")

    # Save the remaining data as a training set (original format)
    remaining_sentences = unique_sentences[prepare_num:] if len(unique_sentences) > prepare_num else []
    if remaining_sentences:
        # Convert the remaining sentences back to the original format
        print(f"Converting the remaining {len(remaining_sentences)} sentences back to the original format...")
        remaining_original_data = []
        for sentence in remaining_sentences:
            if sentence in sentence_to_original:
                remaining_original_data.append(sentence_to_original[sentence])

        print(f"Saving the remaining {len(remaining_original_data)} records as the training set...")
        train_output_path = os.path.join(output_dir, "combined_train.json")
        with open(train_output_path, "w", encoding="utf-8") as f:
            json.dump(remaining_original_data, f, ensure_ascii=False, indent=2)
        print("combined_train.json saved")
    else:
        print("No remaining data for the training set")
        remaining_original_data = []

    print("\nData processing finished!")
    print(f"Raw records: {total_count}")
    print(f"Joined sentences: {len(all_sentences)}")
    print(f"After deduplication: {len(unique_sentences)}")
    print(f"Used for database_init: {len(selected_sentences)}")
    print(f"Remaining training records: {len(remaining_original_data) if remaining_sentences else 0}")


if __name__ == "__main__":
    preprocess_combined_json()
@ -1,741 +0,0 @@
import json
import os
import pandas as pd
import tarfile
import tempfile
import shutil
from pathlib import Path
import re
import langdetect
from tqdm import tqdm
import logging
import random
import hashlib
from transformers import AutoTokenizer

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
BASE_DIR = "/home/pci/nas/AI_Large_Model_Team/ycz/Minimind"
OUTPUT_FILE = os.path.join(BASE_DIR, "dataset", "merged_pretrain.jsonl")
OUTPUT_FILE_EXTRA = os.path.join(BASE_DIR, "dataset", "merged_pretrain_extra.jsonl")

# Data source paths
PRETRAIN_HQ_PATH = os.path.join(BASE_DIR, "dataset", "pretrain_hq.jsonl")
WIKIPEDIA_PATH = "/home/pci/nas/share/datasets/wikipedia/data/20220301.en"
GUTENBERG_PATH = "/home/pci/nas/share/datasets/gutenberg/data"
OPENWEBTEXT_PATH = "/home/pci/nas/share/datasets/openwebtext/subsets"

# Token length limits
MIN_TOKENS = 410
MAX_TOKENS = 490

# Dataset quality and sampling configuration - main file
DATASET_CONFIG = {
    "pretrain_hq": {"quality": "high", "sample_ratio": 1.0, "max_samples": None},     # high quality, use everything
    "wikipedia": {"quality": "high", "sample_ratio": 1.0, "max_samples": 5000000},    # high quality, use all, at most 5M records
    "gutenberg": {"quality": "medium", "sample_ratio": 0.8, "max_samples": 3000000},  # medium quality, use 80%, at most 3M records
    "openwebtext": {"quality": "low", "sample_ratio": 0.2, "max_samples": 2000000}    # low quality, use 20%, at most 2M records
}

# Configuration for the extra file - remaining data
DATASET_CONFIG_EXTRA = {
    "wikipedia": {"quality": "high", "sample_ratio": 1.0, "max_samples": None},       # all of the remainder
    "gutenberg": {"quality": "medium", "sample_ratio": 0.8, "max_samples": 5000000},  # 80% of the remainder, at most 5M records
    "openwebtext": {"quality": "low", "sample_ratio": 0.6, "max_samples": 4000000}    # 60% of the remainder, at most 4M records
}

# Global state: hashes of the records already selected
selected_data_hashes = {
    "wikipedia": set(),
    "gutenberg": set(),
    "openwebtext": set()
}

# Tokenizer (initialized lazily)
tokenizer = None

def init_tokenizer():
    """Initialize the tokenizer."""
    global tokenizer
    try:
        # Prefer the local minimind tokenizer
        local_tokenizer_path = os.path.join(BASE_DIR, "model", "minimind_tokenizer")
        if os.path.exists(local_tokenizer_path):
            tokenizer = AutoTokenizer.from_pretrained(local_tokenizer_path)
            logger.info("Local MiniMind tokenizer initialized successfully")
        else:
            # Fall back to GPT-2 if the local tokenizer is missing (offline mode)
            tokenizer = AutoTokenizer.from_pretrained("gpt2", local_files_only=True)
            logger.info("GPT-2 tokenizer initialized successfully (offline)")
    except Exception as e:
        logger.error(f"Error initializing tokenizer: {e}")
        logger.info("Trying to use a simple fallback tokenizer...")
        # Fall back to simple word-based tokenization
        tokenizer = None
        logger.warning("Using simple word-based tokenization as fallback")

def count_tokens(text):
    """Count the number of tokens in a text."""
    if tokenizer is None:
        init_tokenizer()

    if tokenizer is not None:
        try:
            tokens = tokenizer.encode(text, add_special_tokens=False)
            return len(tokens)
        except Exception:
            pass

    # If tokenization failed or no tokenizer is available, fall back to a rough estimate
    return int(len(text.split()) * 1.3)  # rough estimate; ensure an integer is returned

def is_english_text(text, threshold=0.8):
    """Detect whether a text is English."""
    try:
        if len(text) < 50:  # skip detection for very short texts
            return True
        detected_lang = langdetect.detect(text)
        return detected_lang == 'en'
    except Exception:
        # If detection fails, fall back to a simple English-character ratio check
        english_chars = sum(1 for char in text if char.isascii() and (char.isalpha() or char in ' .,!?-'))
        total_chars = len(text.replace('\n', ' ').replace('\t', ' '))
        return (english_chars / max(total_chars, 1)) > threshold

def truncate_to_token_limit(text, target_tokens=MAX_TOKENS):
    """Truncate a text to the target number of tokens."""
    if tokenizer is None:
        init_tokenizer()

    if tokenizer is not None:
        try:
            tokens = tokenizer.encode(text, add_special_tokens=False)
            if len(tokens) <= target_tokens:
                return text

            # Truncate to the target length
            truncated_tokens = tokens[:target_tokens]
            truncated_text = tokenizer.decode(truncated_tokens, skip_special_tokens=True)

            # Try to cut at a sentence boundary to keep sentences intact
            sentences = truncated_text.split('.')
            if len(sentences) > 1:
                # Keep every sentence except the trailing incomplete one
                truncated_text = '.'.join(sentences[:-1]) + '.'

            return truncated_text
        except Exception:
            pass

    # If processing failed or no tokenizer is available, estimate by character count
    estimated_chars = int(target_tokens / 1.3 * 4)  # rough estimate
    text = text[:estimated_chars]

    # Try to cut at a sentence boundary to keep sentences intact
    sentences = text.split('.')
    if len(sentences) > 1:
        text = '.'.join(sentences[:-1]) + '.'

    return text

def split_text_into_chunks(text, target_chunk_size=1500):
    """Split a long text into several medium-sized chunks."""
    # Clean the text
    text = text.strip()
    if not text:
        return []

    # Collapse excessive newlines and spaces
    text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
    text = re.sub(r' +', ' ', text)

    chunks = []

    # Split by paragraph
    paragraphs = text.split('\n\n')
    current_chunk = ""

    for paragraph in paragraphs:
        paragraph = paragraph.strip()
        if not paragraph:
            continue

        # If the current chunk plus the new paragraph stays a reasonable size, append it
        if len(current_chunk) + len(paragraph) < target_chunk_size:
            if current_chunk:
                current_chunk += "\n\n" + paragraph
            else:
                current_chunk = paragraph
        else:
            # Flush the current chunk if it is non-empty
            if current_chunk:
                chunks.append(current_chunk)

            # If the paragraph itself is very long, split it further
            if len(paragraph) > target_chunk_size * 2:
                # Split the long paragraph by sentence
                sentences = re.split(r'(?<=[.!?])\s+', paragraph)
                temp_chunk = ""

                for sentence in sentences:
                    if len(temp_chunk) + len(sentence) < target_chunk_size:
                        if temp_chunk:
                            temp_chunk += " " + sentence
                        else:
                            temp_chunk = sentence
                    else:
                        if temp_chunk:
                            chunks.append(temp_chunk)
                        temp_chunk = sentence

                if temp_chunk:
                    current_chunk = temp_chunk
                else:
                    current_chunk = ""
            else:
                current_chunk = paragraph

    # Append the last chunk
    if current_chunk:
        chunks.append(current_chunk)

    return chunks

def format_text_for_pretrain(text):
    """Format a text into the pretraining format and check its token length."""
    # Clean the text
    text = text.strip()
    if not text:
        return None

    # Collapse excessive newlines and spaces
    text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
    text = re.sub(r' +', ' ', text)

    # Check the token length
    token_count = count_tokens(text)

    # Skip texts that are too short
    if token_count < MIN_TOKENS:
        return None

    # Truncate texts that are too long
    if token_count > MAX_TOKENS:
        text = truncate_to_token_limit(text, MAX_TOKENS)
        token_count = count_tokens(text)

    # Check again that the length is within the valid range
    if token_count < MIN_TOKENS or token_count > MAX_TOKENS:
        return None

    # Format for pretraining
    formatted_text = f"<|im_start|>{text}<|im_end|>"
    return formatted_text
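
Together, the helpers above form the per-document pipeline: split into chunks, then keep only chunks that land in the 410-490 token window. A condensed sketch (write_jsonl is a hypothetical sink, not defined in this script):

# for chunk in split_text_into_chunks(raw_text, target_chunk_size=2000):
#     sample = format_text_for_pretrain(chunk)  # None if outside MIN_TOKENS..MAX_TOKENS
#     if sample is not None:
#         write_jsonl({"text": sample})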
||||
def get_text_hash(text):
|
||||
"""获取文本的哈希值,用于去重"""
|
||||
return hashlib.md5(text.encode('utf-8')).hexdigest()
|
||||
|
||||
def should_sample(dataset_name, current_count, config_dict=None, is_extra_mode=False):
|
||||
"""根据配置决定是否采样当前记录"""
|
||||
if config_dict is None:
|
||||
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
|
||||
|
||||
config = config_dict[dataset_name]
|
||||
|
||||
# 检查是否达到最大样本数
|
||||
if config["max_samples"] and current_count >= config["max_samples"]:
|
||||
return False
|
||||
|
||||
# 根据采样比例随机决定
|
||||
return random.random() < config["sample_ratio"]
|
||||
|
||||
def process_pretrain_hq():
|
||||
"""处理已有的高质量预训练数据 - 直接输出,不做任何处理"""
|
||||
logger.info("Processing pretrain_hq.jsonl...")
|
||||
count = 0
|
||||
|
||||
with open(PRETRAIN_HQ_PATH, 'r', encoding='utf-8') as f:
|
||||
for line in tqdm(f, desc="Processing pretrain_hq"):
|
||||
try:
|
||||
data = json.loads(line.strip())
|
||||
text = data.get('text', '').strip()
|
||||
|
||||
if text: # 只要有文本就直接输出,不做任何检测
|
||||
if should_sample("pretrain_hq", count):
|
||||
yield {"text": text}
|
||||
count += 1
|
||||
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
logger.info(f"Processed {count} records from pretrain_hq.jsonl")
|
||||
|
||||
def process_wikipedia(is_extra_mode=False):
|
||||
"""处理Wikipedia数据"""
|
||||
mode_text = "extra" if is_extra_mode else "main"
|
||||
logger.info(f"Processing Wikipedia data ({mode_text} mode)...")
|
||||
count = 0
|
||||
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
|
||||
|
||||
# 获取所有英文Wikipedia文件
|
||||
wiki_files = list(Path(WIKIPEDIA_PATH).glob("*.parquet"))
|
||||
|
||||
for file_path in tqdm(wiki_files, desc=f"Processing Wikipedia files ({mode_text})"):
|
||||
if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
|
||||
break
|
||||
|
||||
try:
|
||||
df = pd.read_parquet(file_path)
|
||||
for _, row in df.iterrows():
|
||||
if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
|
||||
break
|
||||
|
||||
text = row.get('text', '').strip()
|
||||
if text and len(text) > 200: # 预过滤太短的文本
|
||||
# 先将长文本分割成中等大小的块
|
||||
chunks = split_text_into_chunks(text, target_chunk_size=2000)
|
||||
|
||||
for chunk in chunks:
|
||||
if config_dict["wikipedia"]["max_samples"] and count >= config_dict["wikipedia"]["max_samples"]:
|
||||
break
|
||||
|
||||
chunk_hash = get_text_hash(chunk)
|
||||
|
||||
# 在额外模式下,跳过已经被主文件选中的数据
|
||||
if is_extra_mode and chunk_hash in selected_data_hashes["wikipedia"]:
|
||||
continue
|
||||
|
||||
formatted_text = format_text_for_pretrain(chunk)
|
||||
if formatted_text and should_sample("wikipedia", count, config_dict, is_extra_mode):
|
||||
# 在主模式下记录哈希值
|
||||
if not is_extra_mode:
|
||||
selected_data_hashes["wikipedia"].add(chunk_hash)
|
||||
|
||||
yield {"text": formatted_text}
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Processed {count} records from Wikipedia ({mode_text} mode)")
|
||||
|
||||
def process_gutenberg(is_extra_mode=False):
|
||||
"""处理Gutenberg数据"""
|
||||
mode_text = "extra" if is_extra_mode else "main"
|
||||
logger.info(f"Processing Gutenberg data ({mode_text} mode)...")
|
||||
count = 0
|
||||
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
|
||||
|
||||
# 获取所有Gutenberg训练文件
|
||||
gutenberg_files = list(Path(GUTENBERG_PATH).glob("train-*.parquet"))
|
||||
|
||||
for file_path in tqdm(gutenberg_files, desc=f"Processing Gutenberg files ({mode_text})"):
|
||||
if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
|
||||
break
|
||||
|
||||
try:
|
||||
df = pd.read_parquet(file_path)
|
||||
for _, row in df.iterrows():
|
||||
if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
|
||||
break
|
||||
|
||||
text = row.get('text', '').strip()
|
||||
if text and len(text) > 300 and is_english_text(text): # 预过滤
|
||||
# 先将长文本分割成中等大小的块
|
||||
chunks = split_text_into_chunks(text, target_chunk_size=1800)
|
||||
|
||||
for chunk in chunks:
|
||||
if config_dict["gutenberg"]["max_samples"] and count >= config_dict["gutenberg"]["max_samples"]:
|
||||
break
|
||||
|
||||
chunk_hash = get_text_hash(chunk)
|
||||
|
||||
# 在额外模式下,跳过已经被主文件选中的数据
|
||||
if is_extra_mode and chunk_hash in selected_data_hashes["gutenberg"]:
|
||||
continue
|
||||
|
||||
formatted_text = format_text_for_pretrain(chunk)
|
||||
if formatted_text and should_sample("gutenberg", count, config_dict, is_extra_mode):
|
||||
# 在主模式下记录哈希值
|
||||
if not is_extra_mode:
|
||||
selected_data_hashes["gutenberg"].add(chunk_hash)
|
||||
|
||||
yield {"text": formatted_text}
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Processed {count} records from Gutenberg ({mode_text} mode)")
|
||||
|
||||
def process_openwebtext(is_extra_mode=False):
|
||||
"""处理OpenWebText数据"""
|
||||
mode_text = "extra" if is_extra_mode else "main"
|
||||
logger.info(f"Processing OpenWebText data ({mode_text} mode)...")
|
||||
count = 0
|
||||
config_dict = DATASET_CONFIG_EXTRA if is_extra_mode else DATASET_CONFIG
|
||||
max_files = 5 # 减少处理的文件数量以避免过长处理时间
|
||||
|
||||
# 获取tar文件列表
|
||||
tar_files = list(Path(OPENWEBTEXT_PATH).glob("*.tar"))[:max_files]
|
||||
|
||||
for tar_path in tqdm(tar_files, desc=f"Processing OpenWebText files ({mode_text})"):
|
||||
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
|
||||
break
|
||||
|
||||
try:
|
||||
with tarfile.open(tar_path, 'r') as outer_tar:
|
||||
# 创建临时目录处理外层tar
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
outer_tar.extractall(temp_dir)
|
||||
|
||||
# 处理解压后的xz文件
|
||||
for root, dirs, files in os.walk(temp_dir):
|
||||
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
|
||||
break
|
||||
|
||||
for file in files:
|
||||
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
|
||||
break
|
||||
|
||||
if file.endswith('.xz'):
|
||||
xz_path = os.path.join(root, file)
|
||||
|
||||
# 创建另一个临时目录处理xz文件
|
||||
with tempfile.TemporaryDirectory() as xz_temp_dir:
|
||||
try:
|
||||
# 解压xz文件
|
||||
import subprocess
|
||||
decompressed_path = os.path.join(xz_temp_dir, file[:-3]) # 移除.xz后缀
|
||||
subprocess.run(['xz', '-dc', xz_path],
|
||||
stdout=open(decompressed_path, 'wb'),
|
||||
check=True)
|
||||
|
||||
# 检查解压后的文件是否是tar格式
|
||||
if tarfile.is_tarfile(decompressed_path):
|
||||
# 处理内层tar文件
|
||||
with tarfile.open(decompressed_path, 'r') as inner_tar:
|
||||
with tempfile.TemporaryDirectory() as inner_temp_dir:
|
||||
inner_tar.extractall(inner_temp_dir)
|
||||
|
||||
# 处理最终的txt文件
|
||||
for inner_root, inner_dirs, inner_files in os.walk(inner_temp_dir):
|
||||
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
|
||||
break
|
||||
|
||||
for txt_file in inner_files:
|
||||
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
|
||||
break
|
||||
|
||||
if txt_file.endswith('.txt'):
|
||||
txt_path = os.path.join(inner_root, txt_file)
|
||||
try:
|
||||
with open(txt_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
text = f.read().strip()
|
||||
if text and len(text) > 500 and is_english_text(text):
|
||||
# 先将长文本分割成中等大小的块
|
||||
chunks = split_text_into_chunks(text, target_chunk_size=1600)
|
||||
|
||||
for chunk in chunks:
|
||||
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
|
||||
break
|
||||
|
||||
chunk_hash = get_text_hash(chunk)
|
||||
|
||||
# 在额外模式下,跳过已经被主文件选中的数据
|
||||
if is_extra_mode and chunk_hash in selected_data_hashes["openwebtext"]:
|
||||
continue
|
||||
|
||||
formatted_text = format_text_for_pretrain(chunk)
|
||||
if formatted_text and should_sample("openwebtext", count, config_dict, is_extra_mode):
|
||||
# 在主模式下记录哈希值
|
||||
if not is_extra_mode:
|
||||
selected_data_hashes["openwebtext"].add(chunk_hash)
|
||||
|
||||
yield {"text": formatted_text}
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.debug(f"Error reading txt file {txt_path}: {e}")
|
||||
continue
|
||||
else:
|
||||
# 如果不是tar文件,直接作为文本处理
|
||||
try:
|
||||
with open(decompressed_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
text = f.read().strip()
|
||||
if text and len(text) > 500 and is_english_text(text):
|
||||
chunks = split_text_into_chunks(text, target_chunk_size=1600)
|
||||
|
||||
for chunk in chunks:
|
||||
if config_dict["openwebtext"]["max_samples"] and count >= config_dict["openwebtext"]["max_samples"]:
|
||||
break
|
||||
|
||||
chunk_hash = get_text_hash(chunk)
|
||||
|
||||
# 在额外模式下,跳过已经被主文件选中的数据
|
||||
if is_extra_mode and chunk_hash in selected_data_hashes["openwebtext"]:
|
||||
continue
|
||||
|
||||
formatted_text = format_text_for_pretrain(chunk)
|
||||
if formatted_text and should_sample("openwebtext", count, config_dict, is_extra_mode):
|
||||
# 在主模式下记录哈希值
|
||||
if not is_extra_mode:
|
||||
selected_data_hashes["openwebtext"].add(chunk_hash)
|
||||
|
||||
yield {"text": formatted_text}
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.debug(f"Error reading decompressed file {decompressed_path}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error processing xz file {xz_path}: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {tar_path}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Processed {count} records from OpenWebText ({mode_text} mode)")
|
||||
|
||||
def merge_datasets():
|
||||
"""合并所有数据集,生成主文件和额外文件"""
|
||||
logger.info("Starting dataset merging...")
|
||||
logger.info("Main dataset configuration:")
|
||||
for name, config in DATASET_CONFIG.items():
|
||||
logger.info(f" {name}: quality={config['quality']}, ratio={config['sample_ratio']}, max={config['max_samples']}")
|
||||
|
||||
logger.info("Extra dataset configuration:")
|
||||
for name, config in DATASET_CONFIG_EXTRA.items():
|
||||
logger.info(f" {name}: quality={config['quality']}, ratio={config['sample_ratio']}, max={config['max_samples']}")
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True)
|
||||
os.makedirs(os.path.dirname(OUTPUT_FILE_EXTRA), exist_ok=True)
|
||||
|
||||
# 统计信息
|
||||
main_dataset_stats = {}
|
||||
extra_dataset_stats = {}
|
||||
|
||||
# 第一阶段:生成主文件
|
||||
logger.info("="*50)
|
||||
logger.info("PHASE 1: Generating main dataset file")
|
||||
logger.info("="*50)
|
||||
|
||||
with open(OUTPUT_FILE, 'w', encoding='utf-8') as outfile:
|
||||
main_total_count = 0
|
||||
|
||||
# 处理各个数据集(主模式)
|
||||
main_datasets = [
|
||||
("pretrain_hq", process_pretrain_hq),
|
||||
("wikipedia", lambda: process_wikipedia(is_extra_mode=False)),
|
||||
("gutenberg", lambda: process_gutenberg(is_extra_mode=False)),
|
||||
("openwebtext", lambda: process_openwebtext(is_extra_mode=False))
|
||||
]
|
||||
|
||||
for dataset_name, dataset_func in main_datasets:
|
||||
logger.info(f"Processing {dataset_name} for main file...")
|
||||
dataset_count = 0
|
||||
|
||||
try:
|
||||
for record in dataset_func():
|
||||
json.dump(record, outfile, ensure_ascii=False)
|
||||
outfile.write('\n')
|
||||
dataset_count += 1
|
||||
main_total_count += 1
|
||||
|
||||
# 每5000条记录输出一次进度
|
||||
if main_total_count % 5000 == 0:
|
||||
logger.info(f"Main file: Processed {main_total_count} total records")
|
||||
|
||||
# 保存统计信息
|
||||
main_dataset_stats[dataset_name] = {
|
||||
'selected': dataset_count,
|
||||
'config': DATASET_CONFIG[dataset_name]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {dataset_name} for main file: {e}")
|
||||
main_dataset_stats[dataset_name] = {
|
||||
'selected': dataset_count,
|
||||
'config': DATASET_CONFIG[dataset_name]
|
||||
}
|
||||
|
||||
logger.info(f"Main file - Completed {dataset_name}: {dataset_count} records")
|
||||
|
||||
logger.info(f"Main file generation completed. Total records: {main_total_count}")
|
||||
|
||||
# 第二阶段:生成额外文件
|
||||
logger.info("="*50)
|
||||
logger.info("PHASE 2: Generating extra dataset file")
|
||||
logger.info("="*50)
|
||||
|
||||
with open(OUTPUT_FILE_EXTRA, 'w', encoding='utf-8') as outfile:
|
||||
extra_total_count = 0
|
||||
|
||||
# 处理各个数据集(额外模式)- 不包括pretrain_hq
|
||||
extra_datasets = [
|
||||
("wikipedia", lambda: process_wikipedia(is_extra_mode=True)),
|
||||
("gutenberg", lambda: process_gutenberg(is_extra_mode=True)),
|
||||
("openwebtext", lambda: process_openwebtext(is_extra_mode=True))
|
||||
]
|
||||
|
||||
for dataset_name, dataset_func in extra_datasets:
|
||||
logger.info(f"Processing {dataset_name} for extra file...")
|
||||
dataset_count = 0
|
||||
|
||||
try:
|
||||
for record in dataset_func():
|
||||
json.dump(record, outfile, ensure_ascii=False)
|
||||
outfile.write('\n')
|
||||
dataset_count += 1
|
||||
extra_total_count += 1
|
||||
|
||||
# 每5000条记录输出一次进度
|
||||
if extra_total_count % 5000 == 0:
|
||||
logger.info(f"Extra file: Processed {extra_total_count} total records")
|
||||
|
||||
# 保存统计信息
|
||||
extra_dataset_stats[dataset_name] = {
|
||||
'selected': dataset_count,
|
||||
'config': DATASET_CONFIG_EXTRA[dataset_name]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {dataset_name} for extra file: {e}")
|
||||
extra_dataset_stats[dataset_name] = {
|
||||
'selected': dataset_count,
|
||||
'config': DATASET_CONFIG_EXTRA[dataset_name]
|
||||
}
|
||||
|
||||
logger.info(f"Extra file - Completed {dataset_name}: {dataset_count} records")
|
||||
|
||||
logger.info(f"Extra file generation completed. Total records: {extra_total_count}")
|
||||
|
||||
# 打印详细统计信息
|
||||
print_detailed_statistics(main_dataset_stats, main_total_count, extra_dataset_stats, extra_total_count)
|
||||
|
||||
logger.info("All dataset processing completed successfully!")
|
||||
logger.info(f"Main file saved to: {OUTPUT_FILE}")
|
||||
logger.info(f"Extra file saved to: {OUTPUT_FILE_EXTRA}")
|
||||
|
||||
def print_detailed_statistics(main_dataset_stats, main_total_count, extra_dataset_stats, extra_total_count):
|
||||
"""打印详细统计信息"""
|
||||
print("\n" + "="*100)
|
||||
print("DATASET PROCESSING SUMMARY")
|
||||
print("="*100)
|
||||
|
||||
# 主文件统计
|
||||
print("\nMAIN FILE (merged_pretrain.jsonl):")
|
||||
print("-" * 90)
|
||||
print(f"{'Dataset':<15} {'Selected':<10} {'Sample Ratio':<12} {'Max Limit':<12} {'% of Main':<12} {'Quality':<8}")
|
||||
print("-" * 90)
|
||||
|
||||
for dataset_name, stats in main_dataset_stats.items():
|
||||
selected = stats['selected']
|
||||
config = stats['config']
|
||||
ratio = config['sample_ratio']
|
||||
max_limit = config['max_samples'] if config['max_samples'] else "No limit"
|
||||
        percentage = (selected / main_total_count * 100) if main_total_count > 0 else 0
        quality = config['quality']

        print(f"{dataset_name:<15} {selected:<10,} {ratio:<12.1%} {str(max_limit):<12} {percentage:<12.2f}% {quality:<8}")

    print("-" * 90)
    print(f"{'MAIN TOTAL':<15} {main_total_count:<10,} {'':<12} {'':<12} {'100.00%':<12} {'':<8}")

    # Extra-file statistics
    print("\nEXTRA FILE (merged_pretrain_extra.jsonl):")
    print("-" * 90)
    print(f"{'Dataset':<15} {'Selected':<10} {'Sample Ratio':<12} {'Max Limit':<12} {'% of Extra':<12} {'Quality':<8}")
    print("-" * 90)

    for dataset_name, stats in extra_dataset_stats.items():
        selected = stats['selected']
        config = stats['config']
        ratio = config['sample_ratio']
        max_limit = config['max_samples'] if config['max_samples'] else "No limit"
        percentage = (selected / extra_total_count * 100) if extra_total_count > 0 else 0
        quality = config['quality']

        print(f"{dataset_name:<15} {selected:<10,} {ratio:<12.1%} {str(max_limit):<12} {percentage:<12.2f}% {quality:<8}")

    print("-" * 90)
    print(f"{'EXTRA TOTAL':<15} {extra_total_count:<10,} {'':<12} {'':<12} {'100.00%':<12} {'':<8}")

    # Overall statistics
    total_records = main_total_count + extra_total_count
    print("\nOVERALL STATISTICS:")
    print("-" * 50)
    print(f"Main file records: {main_total_count:>10,}")
    print(f"Extra file records: {extra_total_count:>10,}")
    print(f"Total records: {total_records:>10,}")
    print(f"Token range per sample: {MIN_TOKENS}-{MAX_TOKENS} tokens")

    # Quality distribution statistics
    quality_stats = {}
    for dataset_name, stats in main_dataset_stats.items():
        quality = stats['config']['quality']
        if quality not in quality_stats:
            quality_stats[quality] = {'main': 0, 'extra': 0}
        quality_stats[quality]['main'] += stats['selected']

    for dataset_name, stats in extra_dataset_stats.items():
        quality = stats['config']['quality']
        if quality not in quality_stats:
            quality_stats[quality] = {'main': 0, 'extra': 0}
        quality_stats[quality]['extra'] += stats['selected']

    print("\nQUALITY DISTRIBUTION:")
    print("-" * 60)
    print(f"{'Quality':<12} {'Main File':<12} {'Extra File':<12} {'Total':<12} {'%':<8}")
    print("-" * 60)
    for quality in sorted(quality_stats.keys()):
        main_count = quality_stats[quality]['main']
        extra_count = quality_stats[quality]['extra']
        total_count = main_count + extra_count
        percentage = (total_count / total_records * 100) if total_records > 0 else 0
        print(f"{quality.capitalize():<12} {main_count:<12,} {extra_count:<12,} {total_count:<12,} {percentage:<8.2f}%")
    print("-" * 60)
    print(f"{'Total':<12} {main_total_count:<12,} {extra_total_count:<12,} {total_records:<12,} {'100.00%':<8}")

    print(f"\nFiles saved to:")
    print(f"  Main file: {OUTPUT_FILE}")
    print(f"  Extra file: {OUTPUT_FILE_EXTRA}")
    print("="*100)


def main():
    """Main entry point."""
    try:
        # Set the random seed so results are reproducible
        random.seed(42)

        # Check required dependencies
        try:
            import langdetect
            from transformers import AutoTokenizer
        except ImportError as e:
            logger.error(f"Missing dependencies: {e}")
            logger.error("Please install: pip install langdetect transformers")
            return

        # Initialize the tokenizer
        init_tokenizer()

        # Check that the input file exists
        if not os.path.exists(PRETRAIN_HQ_PATH):
            logger.error(f"pretrain_hq.jsonl not found at {PRETRAIN_HQ_PATH}")
            return

        # Merge the datasets
        merge_datasets()
        logger.info("All processing completed successfully!")

    except Exception as e:
        logger.error(f"Error in main process: {e}")
        raise


if __name__ == "__main__":
    main()
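Side note on the aggregation above: the init-if-missing pattern around quality_stats can also be written with collections.defaultdict. A minimal, self-contained sketch; the stats dict here is a made-up placeholder, not data produced by this script:

# Same quality aggregation via collections.defaultdict (illustration only).
from collections import defaultdict

demo_stats = {'wiki': {'config': {'quality': 'high'}, 'selected': 120},
              'web':  {'config': {'quality': 'medium'}, 'selected': 80}}
quality_totals = defaultdict(lambda: {'main': 0, 'extra': 0})
for name, stats in demo_stats.items():
    quality_totals[stats['config']['quality']]['main'] += stats['selected']
print(dict(quality_totals))  # {'high': {'main': 120, 'extra': 0}, 'medium': {'main': 80, 'extra': 0}}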
@ -1,442 +0,0 @@
import json
import os
import argparse
from typing import List, Dict, Any, Optional
from collections import defaultdict
import pickle
from pathlib import Path


class WikidataRelationManager:
    """Wikidata relation manager with dynamic lookup and caching support."""

    def __init__(self, cache_file: str = "wikidata_relations_cache.pkl",
                 mapping_file: str = None):
        self.cache_file = cache_file
        self.mapping_file = mapping_file
        self.relations = {}
        # API-related attributes removed

        # Initial set of base relation mappings
        self.base_relations = {
            # # Basic relations
            # 'P31': 'instance of',
            # 'P279': 'subclass of',
            # 'P17': 'country',
            # 'P159': 'headquarters location',
            # 'P571': 'inception',

            # # Person relations
            # 'P19': 'place of birth',
            # 'P20': 'place of death',
            # 'P27': 'country of citizenship',
            # 'P106': 'occupation',
            # 'P22': 'father',
            # 'P25': 'mother',
            # 'P26': 'spouse',
            # 'P40': 'child',
            # 'P69': 'educated at',
            # 'P108': 'employer',

            # # Geographic relations
            # 'P36': 'capital',
            # 'P131': 'located in',
            # 'P47': 'shares border with',
            # 'P206': 'located on terrain feature',
            # 'P1376': 'capital of',

            # # Organization relations
            # 'P112': 'founded by',
            # 'P127': 'owned by',
            # 'P169': 'chief executive officer',
            # 'P488': 'chairperson',
            # 'P749': 'parent organization',

            # # Creative-work relations
            # 'P50': 'author',
            # 'P57': 'director',
            # 'P58': 'screenwriter',
            # 'P161': 'cast member',
            # 'P175': 'performer',
            # 'P577': 'publication date',
            # 'P123': 'publisher',
            # 'P136': 'genre',

            # # Temporal relations
            # 'P155': 'follows',
            # 'P156': 'followed by',
            # 'P580': 'start time',
            # 'P582': 'end time',

            # # Sports relations
            # 'P54': 'member of sports team',
            # 'P413': 'position played on team',
            # 'P118': 'league',

            # # Science relations
            # 'P275': 'copyright license',
            # 'P170': 'creator',
            # 'P398': 'child astronomical body',
            # 'P397': 'parent astronomical body',

            # # Other common relations
            # 'P37': 'official language',
            # 'P1923': 'place of marriage',
            # 'P737': 'influenced by',
            # 'P463': 'member of',
            # 'P39': 'position held',
            # 'P276': 'location',
            # 'P1441': 'present in work',
        }

        self.load_cache()

    def load_cache(self):
        """Load cached relation mappings, preferring the JSON mapping file."""
        try:
            # Try the JSON mapping file first
            if self.mapping_file and os.path.exists(self.mapping_file):
                with open(self.mapping_file, 'r', encoding='utf-8') as f:
                    self.relations = json.load(f)
                print(f"Loaded {len(self.relations)} relation mappings from the JSON mapping file")
                return

            # Fall back to the pickle cache file
            if os.path.exists(self.cache_file):
                with open(self.cache_file, 'rb') as f:
                    self.relations = pickle.load(f)
                print(f"Loaded {len(self.relations)} relation mappings from the pickle cache")
            else:
                self.relations = self.base_relations.copy()
                print(f"Initialized base relation mappings: {len(self.relations)} entries")
        except Exception as e:
            print(f"Failed to load cache: {e}")
            self.relations = self.base_relations.copy()

    def save_cache(self):
        """Save the relation mappings to the cache file."""
        try:
            with open(self.cache_file, 'wb') as f:
                pickle.dump(self.relations, f)
            print(f"Saved {len(self.relations)} relation mappings to the cache")
        except Exception as e:
            print(f"Failed to save cache: {e}")

    # Web-scraping support removed; the manager now runs fully offline

    def get_relation_name(self, property_id: str) -> Optional[str]:
        """Get a relation name using only the local mapping."""
        if property_id in self.relations:
            return self.relations[property_id]

        # If the local mapping has no entry, return None (the relation is skipped)
        return None

    # Network-based batch fetching and preloading were removed


class TRexProcessor:
    """Processor for the T-REx dataset."""

    def __init__(self, relation_manager: WikidataRelationManager):
        self.relation_manager = relation_manager

    def extract_predicate_id(self, uri: str) -> str:
        """Extract the property ID from a URI."""
        if uri and 'prop/direct/' in uri:
            return uri.split('/')[-1]
        elif uri and uri.startswith('P') and uri[1:].isdigit():
            return uri
        return uri if uri else 'unknown'

    def get_relation_name(self, predicate_uri: str) -> Optional[str]:
        """Get the human-readable name of a relation."""
        predicate_id = self.extract_predicate_id(predicate_uri)
        return self.relation_manager.get_relation_name(predicate_id)

    # Predicate collection removed, since preloading is no longer needed

    def is_valid_triple(self, triple: Dict[str, Any], confidence_threshold: float,
                        boundary_threshold: int) -> bool:
        """Check whether a triple passes all filter conditions."""
        try:
            # The triple must be a dict
            if not isinstance(triple, dict):
                return False

            # Required fields must be present
            if not all(key in triple for key in ['subject', 'predicate', 'object']):
                return False

            subject = triple['subject']
            predicate = triple['predicate']
            object_info = triple['object']

            # subject, predicate and object must all be dicts
            if not isinstance(subject, dict) or not isinstance(predicate, dict) or not isinstance(object_info, dict):
                return False

            # Subject and object need a valid URI and surfaceform
            if not (subject.get('uri') and subject.get('surfaceform')):
                return False
            if not (object_info.get('uri') and object_info.get('surfaceform')):
                return False
            if not predicate.get('uri'):
                return False

            # Check the confidence (if present)
            confidence = triple.get('confidence')
            if confidence is not None and confidence < confidence_threshold:
                return False

            # Check boundary info (if a threshold is set)
            if boundary_threshold > 0:
                subject_boundaries = subject.get('boundaries')
                object_boundaries = object_info.get('boundaries')

                if not subject_boundaries or not object_boundaries:
                    return False

                # Boundaries must be lists with at least two elements
                if not (isinstance(subject_boundaries, list) and len(subject_boundaries) >= 2):
                    return False
                if not (isinstance(object_boundaries, list) and len(object_boundaries) >= 2):
                    return False

                try:
                    # Boundary spans must be long enough
                    subject_length = subject_boundaries[1] - subject_boundaries[0]
                    object_length = object_boundaries[1] - object_boundaries[0]

                    if subject_length < boundary_threshold or object_length < boundary_threshold:
                        return False
                except (TypeError, IndexError):
                    return False

            # The text content must be sensible
            subject_text = subject.get('surfaceform', '').strip()
            object_text = object_info.get('surfaceform', '').strip()

            if not subject_text or not object_text:
                return False

            # Filter out entities that are too long or too short
            if len(subject_text) > 100 or len(object_text) > 100:
                return False
            if len(subject_text) < 2 or len(object_text) < 2:
                return False

            return True

        except (KeyError, TypeError, AttributeError):
            return False

    def process_single_file(self, file_path: str, confidence_threshold: float,
                            boundary_threshold: int) -> List[Dict[str, Any]]:
        """Process a single JSON file."""
        print(f"Processing file: {file_path}")

        processed_data = []

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                # Load the whole file as one JSON array
                print(f"Loading JSON array file: {file_path}")
                data_list = json.load(f)
                print(f"File contains {len(data_list)} entries")

                for idx, data in enumerate(data_list):
                    try:
                        # Pull out the basic fields
                        text = data.get('text', '').strip()
                        if not text:
                            continue

                        # Process the triples
                        triples = data.get('triples', [])
                        if not triples:
                            continue

                        valid_targets = []

                        for triple in triples:
                            if self.is_valid_triple(triple, confidence_threshold, boundary_threshold):
                                # Resolve the relation name; skip the triple if it cannot be resolved
                                relation_name = self.get_relation_name(triple['predicate']['uri'])
                                if relation_name is None:
                                    continue  # Skip unresolvable relations

                                target = {
                                    'subject': triple['subject']['surfaceform'].strip(),
                                    'predicate': relation_name,
                                    'object': triple['object']['surfaceform'].strip()
                                }
                                valid_targets.append(target)

                        # Keep the entry if it has at least one valid triple
                        if valid_targets:
                            processed_data.append({
                                'text': text,
                                'target': valid_targets
                            })

                    except Exception as e:
                        if idx < 10:  # Only report the first 10 errors
                            print(f"Error processing entry in {file_path} at index {idx}: {e}")
                        continue

        except FileNotFoundError:
            print(f"File not found: {file_path}")
        except json.JSONDecodeError as e:
            print(f"JSON decode error in {file_path}: {e}")
        except Exception as e:
            print(f"Error processing file {file_path}: {e}")

        print(f"Extracted {len(processed_data)} valid samples from {file_path}")
        return processed_data

    def process_folder(self, folder_path: str, confidence_threshold: float,
                       boundary_threshold: int) -> List[Dict[str, Any]]:
        """Process all JSON files in a folder."""
        all_processed_data = []

        if not os.path.exists(folder_path):
            raise FileNotFoundError(f"Folder does not exist: {folder_path}")

        # Collect all JSON files
        json_files = [f for f in os.listdir(folder_path) if f.endswith('.json')]

        if not json_files:
            raise ValueError(f"No JSON files found in {folder_path}")

        print(f"Found {len(json_files)} JSON files")

        for filename in sorted(json_files):
            file_path = os.path.join(folder_path, filename)
            processed_data = self.process_single_file(file_path, confidence_threshold, boundary_threshold)
            all_processed_data.extend(processed_data)

        # Save the final relation cache
        self.relation_manager.save_cache()

        return all_processed_data

    def generate_statistics(self, processed_data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Generate statistics for the processed data."""
        total_samples = len(processed_data)
        total_triples = sum(len(sample['target']) for sample in processed_data)

        # Count relation types
        relation_counts = defaultdict(int)
        for sample in processed_data:
            for target in sample['target']:
                relation_counts[target['predicate']] += 1

        # Text-length statistics
        text_lengths = [len(sample['text']) for sample in processed_data]
        avg_text_length = sum(text_lengths) / len(text_lengths) if text_lengths else 0

        # Triples per text
        triples_per_text = [len(sample['target']) for sample in processed_data]
        avg_triples_per_text = sum(triples_per_text) / len(triples_per_text) if triples_per_text else 0

        return {
            'total_samples': total_samples,
            'total_triples': total_triples,
            'avg_text_length': round(avg_text_length, 2),
            'avg_triples_per_text': round(avg_triples_per_text, 2),
            'relation_distribution': dict(sorted(relation_counts.items(),
                                                 key=lambda x: x[1], reverse=True)),
            'top_10_relations': dict(list(sorted(relation_counts.items(),
                                                 key=lambda x: x[1], reverse=True))[:10]),
            'total_unique_relations': len(relation_counts),
            'cached_relations': len(self.relation_manager.relations)
        }


def main():
    parser = argparse.ArgumentParser(description='Process the T-REx dataset (with dynamic relation lookup)')
    parser.add_argument('--folder_path', type=str, default='/home/pci/ycz/Code/Minimind/dataset/trex',
                        help='Path to the folder containing the JSON files')
    parser.add_argument('--confidence_threshold', type=float, default=0.5,
                        help='Confidence threshold (default: 0.5)')
    parser.add_argument('--boundary_threshold', type=int, default=0,
                        help='Boundary length threshold (default: 0, no filtering)')
    parser.add_argument('--output', type=str, default='./processed_trex_data.json',
                        help='Output file name (default: processed_trex_data.json)')
    parser.add_argument('--stats', type=str, default='trex_statistics.json',
                        help='Statistics output file name (default: trex_statistics.json)')
    parser.add_argument('--cache_file', type=str, default='wikidata_relations_cache.pkl',
                        help='Relation cache file name (default: wikidata_relations_cache.pkl)')
    parser.add_argument('--mapping_file', type=str, default="/home/pci/ycz/Code/Minimind/preprocessing/sample_property_mappings.json",
                        help='Path to the JSON mapping file (required; used for relation-name mapping)')

    args = parser.parse_args()

    print("T-REx dataset processor (with dynamic relation lookup)")
    print("=" * 60)
    print(f"Input folder: {args.folder_path}")
    print(f"Confidence threshold: {args.confidence_threshold}")
    print(f"Boundary length threshold: {args.boundary_threshold}")
    print(f"Output file: {args.output}")
    print(f"Relation cache file: {args.cache_file}")
    print(f"JSON mapping file: {args.mapping_file if args.mapping_file else 'not specified'}")
    print("=" * 60)

    # Make sure the mapping file exists
    if not args.mapping_file or not os.path.exists(args.mapping_file):
        print(f"Error: mapping file missing or not specified: {args.mapping_file}")
        print("Please provide a valid JSON mapping file.")
        return 1

    # Create the relation manager
    relation_manager = WikidataRelationManager(
        cache_file=args.cache_file,
        mapping_file=args.mapping_file
    )

    # Create the processor
    processor = TRexProcessor(relation_manager)

    try:
        # Process the data
        processed_data = processor.process_folder(
            args.folder_path,
            args.confidence_threshold,
            args.boundary_threshold
        )

        print(f"\nDone! Processed {len(processed_data)} samples in total")

        # Generate statistics
        stats = processor.generate_statistics(processed_data)

        # Save the processed data
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(processed_data, f, ensure_ascii=False, indent=2)

        # Save the statistics
        with open(args.stats, 'w', encoding='utf-8') as f:
            json.dump(stats, f, ensure_ascii=False, indent=2)

        print(f"\nData saved to: {args.output}")
        print(f"Statistics saved to: {args.stats}")
        print(f"Relation cache saved to: {args.cache_file}")

        # Print a summary
        print("\nData statistics summary:")
        print("=" * 30)
        print(f"Total samples: {stats['total_samples']}")
        print(f"Total triples: {stats['total_triples']}")
        print(f"Unique relations: {stats['total_unique_relations']}")
        print(f"Cached relations: {stats['cached_relations']}")
        print(f"Average text length: {stats['avg_text_length']}")
        print(f"Average triples per text: {stats['avg_triples_per_text']}")
        print("\nTop 10 most frequent relations:")
        for relation, count in stats['top_10_relations'].items():
            print(f"  {relation}: {count}")

    except Exception as e:
        print(f"Error during processing: {e}")
        return 1

    return 0


if __name__ == "__main__":
    exit(main())
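For reference, a minimal sketch of the record shape is_valid_triple() expects, using the classes above on a single made-up triple (the URIs, surface forms, and offsets are invented for illustration):

sample_triple = {
    'subject': {'uri': 'http://www.wikidata.org/entity/Q90', 'surfaceform': 'Paris', 'boundaries': [0, 5]},
    'predicate': {'uri': 'http://www.wikidata.org/prop/direct/P1376'},
    'object': {'uri': 'http://www.wikidata.org/entity/Q142', 'surfaceform': 'France', 'boundaries': [20, 26]},
    'confidence': 0.9,
}
manager = WikidataRelationManager()  # no mapping file: uses the local cache or base mapping
processor = TRexProcessor(manager)
print(processor.is_valid_triple(sample_triple, confidence_threshold=0.5, boundary_threshold=0))  # True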
@ -1,441 +0,0 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import json
import re
import asyncio
import aiofiles
from concurrent.futures import ThreadPoolExecutor
from preprocessing.agent_system.extractor_agent.agent import DepartmentAgent
from typing import Dict, List, Tuple
import gc
import time
import psutil
from tqdm.asyncio import tqdm as async_tqdm
from tqdm import tqdm

json_path = "dataset/merged_pretrain_extra.jsonl"
output_path = "dataset/processed_triples.jsonl"

# Tuned configuration parameters - reduced resource consumption
BATCH_SIZE = 5000     # Records per batch
MAX_CONCURRENT = 200  # Maximum number of concurrent extractions
AGENT_POOL_SIZE = 20  # Number of agent instances in the pool

def get_memory_usage():
    """Return the current memory usage in MB."""
    process = psutil.Process(os.getpid())
    memory_info = process.memory_info()
    memory_mb = memory_info.rss / 1024 / 1024
    return memory_mb

def print_memory_info(stage=""):
    """Print the current memory usage."""
    memory_mb = get_memory_usage()
    print(f"🔧 {stage} - memory usage: {memory_mb:.1f} MB")

# Create a pool of extractor agents to avoid concurrency conflicts
def create_extractor_pool(pool_size: int = 5):
    """Create the extractor agent pool."""
    print(f"Creating {pool_size} agent instances...")
    agents = []
    for i in range(pool_size):
        try:
            agent = DepartmentAgent(model_type="deepseek")
            agents.append(agent)
            print(f"  ✓ Agent {i+1}/{pool_size} created")
        except Exception as e:
            print(f"  ✗ Agent {i+1} creation failed: {e}")
    print(f"Agent pool ready; created {len(agents)} instances")
    return agents

# Lazily initialized agent pool
AGENT_POOL = None
agent_pool_index = 0

def get_agent_pool():
    """Return the agent pool, initializing it lazily."""
    global AGENT_POOL
    if AGENT_POOL is None:
        print_memory_info("before creating agent pool")
        AGENT_POOL = create_extractor_pool(pool_size=AGENT_POOL_SIZE)
        print_memory_info("after creating agent pool")
    return AGENT_POOL

def get_next_agent():
    """Round-robin over the pool to get the next available agent."""
    global agent_pool_index
    pool = get_agent_pool()
    agent = pool[agent_pool_index % len(pool)]
    agent_pool_index += 1
    return agent

def clean_and_split_text(text):
    """
    Strip the start/end markers from the text and split it into sentences.
    """
    # Remove a leading <|im_start|> and a trailing <|im_end|>
    text = text.strip()
    if text.startswith('<|im_start|>'):
        text = text[len('<|im_start|>'):]
    if text.endswith('<|im_end|>'):
        text = text[:-len('<|im_end|>')]

    # Clean up extra whitespace
    text = text.strip()

    # Split into sentences on end-of-sentence punctuation
    # (the regex matches both ASCII and CJK sentence terminators)
    sentence_endings = r'[.!?。!?]'
    sentences = re.split(sentence_endings, text)

    # Trim each sentence and drop empty ones
    cleaned_sentences = []
    for sentence in sentences:
        sentence = sentence.strip()
        if sentence and len(sentence) > 5:  # Keep only non-empty, meaningful sentences
            cleaned_sentences.append(sentence)

    return cleaned_sentences
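# Illustration (hypothetical input, not from the dataset): the splitter above
# strips the chat markers and keeps sentences longer than 5 characters, e.g.
#   clean_and_split_text("<|im_start|>It rains today. Streets are wet!<|im_end|>")
#   -> ['It rains today', 'Streets are wet']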

async def extract_triple_from_sentence_async(sentence: str, context: str = None) -> Dict:
    """
    Asynchronously extract a triple from a sentence using an extractor agent.
    """
    try:
        # Grab an agent instance
        agent = get_next_agent()
        result = await agent.async_run(sentence=sentence, context=context)
        return {
            "sentence": sentence,
            "triple": {
                "subject": result.triple.subject,
                "predicate": result.triple.predicate,
                "object": result.triple.object
            },
            "confidence": result.confidence
        }
    except Exception as e:
        return {
            "sentence": sentence,
            "triple": {
                "subject": "",
                "predicate": "",
                "object": ""
            },
            "confidence": 0.0,
            "error": str(e)
        }

async def process_paragraph_async(line_num: int, original_text: str, semaphore: asyncio.Semaphore) -> Dict:
    """
    Asynchronously process a single paragraph.
    """
    async with semaphore:  # Limit the number of concurrent tasks
        try:
            # Clean and split the text
            sentences = clean_and_split_text(original_text)

            if not sentences:
                return None

            # Build the result record for this paragraph
            paragraph_result = {
                "source_line": line_num,
                "original_paragraph": original_text,
                "sentences": [],
                "triples": []
            }

            # Process all sentences asynchronously
            tasks = []
            for sentence in sentences:
                task = extract_triple_from_sentence_async(sentence, context=original_text)
                tasks.append(task)

            # Wait for all sentences to finish
            triple_results = await asyncio.gather(*tasks)

            # Collect the results
            for i, sentence in enumerate(sentences):
                paragraph_result["sentences"].append(sentence)
                paragraph_result["triples"].append(triple_results[i])

            return paragraph_result

        except Exception as e:
            print(f"Error processing line {line_num}: {e}")
            return None

async def process_batch_async(batch_data: List[Tuple[int, str]], batch_num: int) -> List[Dict]:
    """
    Asynchronously process one batch, with a progress bar and memory monitoring.
    """
    print(f"\n=== Processing batch {batch_num} asynchronously ===")
    print(f"Batch size: {len(batch_data)} records")
    print_memory_info(f"before batch {batch_num}")

    start_time = time.time()

    # Semaphore to cap the number of concurrent tasks
    semaphore = asyncio.Semaphore(MAX_CONCURRENT)

    # Process tasks in chunks to avoid creating too many at once
    chunk_size = 1000  # Tasks per chunk
    all_results = []

    for chunk_start in range(0, len(batch_data), chunk_size):
        chunk_end = min(chunk_start + chunk_size, len(batch_data))
        chunk_data = batch_data[chunk_start:chunk_end]

        print(f"Processing chunk {chunk_start//chunk_size + 1}/{(len(batch_data)-1)//chunk_size + 1} ({len(chunk_data)} records)")

        # Create the async tasks for this chunk
        tasks = []
        for line_num, original_text in chunk_data:
            task = process_paragraph_async(line_num, original_text, semaphore)
            tasks.append(task)

        # Drive this chunk with a progress bar
        progress_bar = tqdm(total=len(tasks), desc=f"batch{batch_num}-chunk{chunk_start//chunk_size + 1}", unit="para", ncols=100)

        chunk_results = []
        completed_tasks = 0

        # Use as_completed to collect finished tasks and update the bar
        for coro in asyncio.as_completed(tasks):
            try:
                result = await coro
                chunk_results.append(result)
                completed_tasks += 1

                # Advance the progress bar
                progress_bar.update(1)

                # Refresh the postfix every 50 completed tasks
                if completed_tasks % 50 == 0:
                    valid_results = [r for r in chunk_results if r is not None]
                    progress_bar.set_postfix({
                        'valid': len(valid_results),
                        'done': completed_tasks,
                        'success': f"{len(valid_results)/completed_tasks*100:.1f}%"
                    })
            except Exception as e:
                print(f"Task failed: {e}")
                completed_tasks += 1
                progress_bar.update(1)

        progress_bar.close()
        all_results.extend(chunk_results)

        # Free memory after each chunk
        del tasks, chunk_results
        gc.collect()

        print_memory_info(f"after batch {batch_num} chunk {chunk_start//chunk_size + 1}")

    # Drop None results
    valid_results = [result for result in all_results if result is not None]

    # Statistics
    batch_sentences = sum(len(result["sentences"]) for result in valid_results)
    batch_triples = sum(
        sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
        for result in valid_results
    )

    end_time = time.time()
    processing_time = end_time - start_time

    print(f"Batch {batch_num} finished:")
    print(f"  - valid paragraphs: {len(valid_results)}/{len(batch_data)} ({len(valid_results)/len(batch_data)*100:.1f}%)")
    print(f"  - total sentences: {batch_sentences}")
    print(f"  - successful triples: {batch_triples}")
    print(f"  - triple success rate: {batch_triples/batch_sentences*100:.1f}%" if batch_sentences > 0 else "no sentences")
    print(f"  - processing time: {processing_time:.2f}s")
    print(f"  - throughput: {len(batch_data)/processing_time:.2f} paragraphs/s")

    print_memory_info(f"after batch {batch_num}")

    return valid_results

async def write_results_batch(results: List[Dict], output_path: str):
    """
    Write the results in one asynchronous batch, with progress reporting.
    """
    try:
        print(f"Writing {len(results)} results in one batch...")

        # Prepare the lines to write
        content_lines = []
        for result in results:
            content_lines.append(json.dumps(result, ensure_ascii=False))

        # Asynchronous batch write
        async with aiofiles.open(output_path, "a", encoding="utf-8") as f:
            await f.write("\n".join(content_lines) + "\n")

        print(f"✓ Wrote {len(results)} results to {output_path}")

    except Exception as e:
        print(f"✗ Batch write failed: {e}")
        print("Falling back to writing one record at a time...")

        # Fallback: write records one by one (with a progress bar)
        async with aiofiles.open(output_path, "a", encoding="utf-8") as f:
            for result in tqdm(results, desc="per-record writes", unit="rec"):
                await f.write(json.dumps(result, ensure_ascii=False) + "\n")
        print(f"✓ Per-record writes finished")

# Main processing pipeline
async def main_async():
    total_processed = 0
    total_sentences = 0
    total_triples = 0
    batch_num = 0

    print("=== Starting async batch processing of the JSONL file ===")
    print(f"Configuration:")
    print(f"  - batch size: {BATCH_SIZE:,} records")
    print(f"  - max concurrency: {MAX_CONCURRENT}")
    print(f"  - agent pool size: {AGENT_POOL_SIZE}")
    print(f"  - input file: {json_path}")
    print(f"  - output file: {output_path}")
    print()

    print_memory_info("program start")

    # Truncate the output file
    async with aiofiles.open(output_path, "w", encoding="utf-8") as f:
        pass

    # Read and process the data
    with open(json_path, "r", encoding="utf-8") as f_in:
        batch_data = []

        for line_num, line in enumerate(f_in):
            if line.strip():  # Skip empty lines
                try:
                    item = json.loads(line)
                    original_text = item.get("text", "")

                    if original_text:
                        batch_data.append((line_num + 1, original_text))

                        # Once the batch reaches the target size, process it asynchronously
                        if len(batch_data) >= BATCH_SIZE:
                            batch_num += 1

                            # Process the batch
                            batch_results = await process_batch_async(batch_data, batch_num)

                            # Write the results
                            if batch_results:
                                await write_results_batch(batch_results, output_path)

                                # Statistics
                                batch_sentences = sum(len(result["sentences"]) for result in batch_results)
                                batch_triples = sum(
                                    sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
                                    for result in batch_results
                                )

                                total_processed += len(batch_data)
                                total_sentences += batch_sentences
                                total_triples += batch_triples

                                print(f"\n📊 Cumulative stats after batch {batch_num}:")
                                print(f"  - paragraphs processed: {total_processed:,}")
                                print(f"  - sentences: {total_sentences:,}")
                                print(f"  - triples: {total_triples:,}")
                                print(f"  - overall success rate: {total_triples/total_sentences*100:.1f}%")
                                print("-" * 80)

                            # Clear the batch data to free memory
                            batch_data.clear()
                            batch_results.clear()
                            gc.collect()  # Force a garbage collection

                            print_memory_info(f"after cleaning up batch {batch_num}")

                except json.JSONDecodeError as e:
                    print(f"JSON decode error at line {line_num + 1}: {e}")
                except Exception as e:
                    print(f"Error processing line {line_num + 1}: {e}")

        # Process the final partial batch
        if batch_data:
            batch_num += 1
            batch_results = await process_batch_async(batch_data, batch_num)

            if batch_results:
                await write_results_batch(batch_results, output_path)

                batch_sentences = sum(len(result["sentences"]) for result in batch_results)
                batch_triples = sum(
                    sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
                    for result in batch_results
                )

                total_processed += len(batch_data)
                total_sentences += batch_sentences
                total_triples += batch_triples

    # Final statistics
    print(f"\n{'='*80}")
    print(f"🎉 All batches processed!")
    print(f"{'='*80}")
    print(f"Final statistics:")
    print(f"  - batches: {batch_num}")
    print(f"  - paragraphs: {total_processed:,}")
    print(f"  - sentences: {total_sentences:,}")
    print(f"  - triples: {total_triples:,}")
    print(f"  - overall success rate: {total_triples/total_sentences*100:.1f}%" if total_sentences > 0 else "no valid sentences")
    print(f"  - output file: {output_path}")
    print(f"{'='*80}")

    print_memory_info("before program exit")

    # Show a few sample results
    await show_sample_results()

async def show_sample_results():
    """Show the first few processed records as examples."""
    print("\n📋 First 3 processed records:")
    try:
        async with aiofiles.open(output_path, "r", encoding="utf-8") as f:
            i = 0
            async for line in f:
                if i >= 3:
                    break
                item = json.loads(line)
                print(f"\n--- Paragraph {i+1} (source line: {item['source_line']}) ---")
                print(f"Original paragraph: {item['original_paragraph'][:100]}...")
                print(f"Sentence count: {len(item['sentences'])}")
                if item['triples']:
                    for j, triple in enumerate(item['triples'][:2]):  # Show only the first 2 triples
                        print(f"  Sentence {j+1}: {triple['sentence'][:50]}...")
                        if triple['confidence'] > 0:
                            print(f"    Triple: {triple['triple']['subject']} -> {triple['triple']['predicate']} -> {triple['triple']['object']}")
                            print(f"    Confidence: {triple['confidence']:.2f}")
                        else:
                            print(f"    Extraction failed: {triple.get('error', 'unknown error')}")
                i += 1
    except Exception as e:
        print(f"Error reading sample results: {e}")

def main():
    """Main entry point."""
    try:
        # Run the async main function
        asyncio.run(main_async())
    except KeyboardInterrupt:
        print("\n⚠️ Interrupted by user")
    except Exception as e:
        print(f"❌ Error during processing: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()
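The core throttling idea in process_batch_async boils down to a semaphore plus as_completed. A minimal, self-contained sketch of just that pattern, with a toy worker standing in for the agent call (runnable on its own, not part of the pipeline):

import asyncio

async def worker(i: int, sem: asyncio.Semaphore) -> int:
    async with sem:                # at most 8 coroutines run concurrently
        await asyncio.sleep(0.01)  # stand-in for one agent call
        return i * i

async def demo() -> None:
    sem = asyncio.Semaphore(8)
    tasks = [worker(i, sem) for i in range(100)]
    done = [await coro for coro in asyncio.as_completed(tasks)]
    print(len(done))               # 100, gathered in completion order

asyncio.run(demo())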
176
pyproject.toml
@ -1,176 +0,0 @@
[project]
name = "minimind"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "accelerate==1.7.0",
    "aiohappyeyeballs==2.6.1",
    "aiohttp==3.11.17",
    "aiosignal==1.3.2",
    "altair==5.5.0",
    "annotated-types==0.7.0",
    "anyio==4.9.0",
    "async-timeout==5.0.1",
    "attrs==25.3.0",
    "blinker==1.9.0",
    "boto3==1.38.41",
    "botocore==1.38.41",
    "cachetools==5.5.2",
    "certifi==2025.1.31",
    "charset-normalizer==3.4.1",
    "click==8.1.8",
    "contourpy==1.3.2",
    "cycler==0.12.1",
    "datasets==2.21.0",
    "datasketch==1.6.4",
    "deepspeed==0.17.0",
    "determined>=0.37.0",
    "dill==0.3.8",
    "distro==1.9.0",
    "docker-pycreds==0.4.0",
    "einops==0.8.1",
    "exceptiongroup==1.2.2",
    "filelock==3.18.0",
    "Flask==3.0.3",
    "Flask-Cors==4.0.0",
    "fonttools==4.57.0",
    "frozenlist==1.6.0",
    "fsspec==2024.6.1",
    "gitdb==4.0.12",
    "GitPython==3.1.44",
    "h11==0.14.0",
    "hjson==3.1.0",
    "httpcore==1.0.8",
    "httpx==0.28.1",
    "huggingface-hub==0.30.2",
    "importlib_metadata==7.2.1",
    "itsdangerous==2.2.0",
    "jieba==0.42.1",
    "Jinja2==3.1.2",
    "jiter==0.9.0",
    "jmespath==1.0.1",
    "joblib==1.4.2",
    "jsonlines==4.0.0",
    "jsonpointer==2.1",
    "jsonschema==4.23.0",
    "jsonschema-specifications==2024.10.1",
    "kiwisolver==1.4.8",
    "langdetect==1.0.9",
    "markdown-it-py==3.0.0",
    "MarkupSafe==3.0.2",
    "marshmallow==3.22.0",
    "matplotlib==3.10.0",
    "mdurl==0.1.2",
    "modelscope==1.25.0",
    "mpi4py>=4.0.3",
    "mpmath==1.3.0",
    "msgpack==1.1.0",
    "multidict==6.4.3",
    "multiprocess==0.70.16",
    "narwhals==1.35.0",
    "networkx==3.4.2",
    "ngrok==1.4.0",
    "ninja==1.11.1.4",
    "nltk==3.8",
    "numpy==1.26.4",
    "nvidia-cublas-cu11==11.11.3.6",
    "nvidia-cublas-cu12==12.6.4.1",
    "nvidia-cuda-cupti-cu11==11.8.87",
    "nvidia-cuda-cupti-cu12==12.6.80",
    "nvidia-cuda-nvrtc-cu11==11.8.89",
    "nvidia-cuda-nvrtc-cu12==12.6.77",
    "nvidia-cuda-runtime-cu11==11.8.89",
    "nvidia-cuda-runtime-cu12==12.6.77",
    "nvidia-cudnn-cu11==9.1.0.70",
    "nvidia-cudnn-cu12==9.5.1.17",
    "nvidia-cufft-cu11==10.9.0.58",
    "nvidia-cufft-cu12==11.3.0.4",
    "nvidia-cufile-cu12==1.11.1.6",
    "nvidia-curand-cu11==10.3.0.86",
    "nvidia-curand-cu12==10.3.7.77",
    "nvidia-cusolver-cu11==11.4.1.48",
    "nvidia-cusolver-cu12==11.7.1.2",
    "nvidia-cusparse-cu11==11.7.5.86",
    "nvidia-cusparse-cu12==12.5.4.2",
    "nvidia-cusparselt-cu12==0.6.3",
    "nvidia-ml-py==12.575.51",
    "nvidia-nccl-cu11==2.21.5",
    "nvidia-nccl-cu12==2.26.2",
    "nvidia-nvjitlink-cu12==12.6.85",
    "nvidia-nvtx-cu11==11.8.86",
    "nvidia-nvtx-cu12==12.6.77",
    "openai==1.59.6",
    "packaging==23.2",
    "pandas>=2.0.0",
    "peft==0.7.1",
    "pillow==10.4.0",
    "platformdirs==4.3.7",
    "prettytable==3.16.0",
    "propcache==0.3.1",
    "protobuf==4.25.6",
    "psutil==5.9.8",
    "py-cpuinfo==9.0.0",
    "pyarrow==19.0.1",
    "pydantic==2.11.7",
    "pydantic_core==2.33.2",
    "pydeck==0.9.1",
    "pyecharts==2.0.8",
    "Pygments==2.19.1",
    "pynvml==12.0.0",
    "pyparsing==3.2.3",
    "python-dateutil==2.9.0.post0",
    "pytz==2025.2",
    "PyYAML==6.0.2",
    "referencing==0.36.2",
    "regex==2024.11.6",
    "requests==2.32.3",
    "rich==13.7.1",
    "rouge-score>=0.1.2",
    "rpds-py==0.24.0",
    "s3transfer==0.13.0",
    "safetensors==0.5.3",
    "scikit-learn==1.5.1",
    "scipy==1.15.2",
    "sentence-transformers==2.3.1",
    "sentencepiece==0.2.0",
    "sentry-sdk==2.26.1",
    "setproctitle==1.3.5",
    "simhash==2.1.2",
    "simplejson==3.20.1",
    "six==1.17.0",
    "smmap==5.0.2",
    "sniffio==1.3.1",
    "streamlit==1.30.0",
    "swankit==0.2.4",
    "swanlab==0.6.4",
    "sympy==1.13.3",
    "tenacity==8.5.0",
    "threadpoolctl==3.6.0",
    "tiktoken>=0.8.0",
    "tokenizers==0.21.1",
    "toml==0.10.2",
    "torch==2.7.1",
    "torchaudio==2.7.1",
    "torchvision==0.22.1",
    "tornado==6.4.2",
    "tqdm==4.67.1",
    "transformers==4.52.4",
    "triton==3.3.1",
    "trl==0.13.0",
    "typing-inspection==0.4.1",
    "typing_extensions==4.13.2",
    "tzlocal==5.3.1",
    "ujson==5.1.0",
    "urllib3==2.4.0",
    "validators==0.34.0",
    "wandb==0.18.3",
    "watchdog==6.0.0",
    "wcwidth==0.2.13",
    "Werkzeug==3.1.3",
    "wrapt==1.17.2",
    "xxhash==3.5.0",
    "yarl==1.20.0",
    "zipp==3.21.0",
]
155
requirements.txt
@ -1,165 +1,30 @@
accelerate==1.7.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.17
aiosignal==1.3.2
altair==5.5.0
annotated-types==0.7.0
anyio==4.9.0
async-timeout==5.0.1
attrs==25.3.0
blinker==1.9.0
boto3==1.38.41
botocore==1.38.41
cachetools==5.5.2
certifi==2025.1.31
charset-normalizer==3.4.1
click==8.1.8
contourpy==1.3.2
cycler==0.12.1
datasets==2.21.0
datasketch==1.6.4
deepspeed==0.17.0
dill==0.3.8
distro==1.9.0
docker-pycreds==0.4.0
einops==0.8.1
exceptiongroup==1.2.2
filelock==3.18.0
Flask==3.0.3
Flask-Cors==4.0.0
fonttools==4.57.0
frozenlist==1.6.0
fsspec==2024.6.1
gitdb==4.0.12
GitPython==3.1.44
h11==0.14.0
hjson==3.1.0
httpcore==1.0.8
httpx==0.28.1
huggingface-hub==0.30.2
importlib_metadata==7.2.1
itsdangerous==2.2.0
Flask_Cors==4.0.0
jieba==0.42.1
Jinja2==3.1.2
jiter==0.9.0
jmespath==1.0.1
joblib==1.4.2
jsonlines==4.0.0
jsonpointer==2.1
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
kiwisolver==1.4.8
langdetect==1.0.9
markdown-it-py==3.0.0
MarkupSafe==3.0.2
marshmallow==3.22.0
matplotlib==3.10.0
mdurl==0.1.2
modelscope==1.25.0
mpmath==1.3.0
msgpack==1.1.0
multidict==6.4.3
multiprocess==0.70.16
narwhals==1.35.0
networkx==3.4.2
ngrok==1.4.0
ninja==1.11.1.4
nltk==3.8
numpy==1.26.4
nvidia-cublas-cu11==11.11.3.6
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu11==9.1.0.70
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu11==10.9.0.58
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu11==10.3.0.86
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu11==11.7.5.86
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-ml-py==12.575.51
nvidia-nccl-cu11==2.21.5
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu11==11.8.86
nvidia-nvtx-cu12==12.6.77
openai==1.59.6
packaging==23.2
pandas==1.5.3
peft==0.7.1
pillow==10.4.0
platformdirs==4.3.7
prettytable==3.16.0
propcache==0.3.1
protobuf==4.25.6
psutil==5.9.8
py-cpuinfo==9.0.0
pyarrow==19.0.1
pydantic==2.11.7
pydantic_core==2.33.2
pydeck==0.9.1
pyecharts==2.0.8
Pygments==2.19.1
pynvml==12.0.0
pyparsing==3.2.3
python-dateutil==2.9.0.post0
pytz==2025.2
PyYAML==6.0.2
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
pydantic==2.8.2
rich==13.7.1
rpds-py==0.24.0
s3transfer==0.13.0
safetensors==0.5.3
scikit-learn==1.5.1
scipy==1.15.2
sentence-transformers==2.3.1
sentencepiece==0.2.0
sentry-sdk==2.26.1
setproctitle==1.3.5
scikit_learn==1.5.1
sentence_transformers==2.3.1
simhash==2.1.2
simplejson==3.20.1
six==1.17.0
smmap==5.0.2
sniffio==1.3.1
streamlit==1.30.0
swankit==0.2.4
swanlab==0.6.4
sympy==1.13.3
tenacity==8.5.0
threadpoolctl==3.6.0
tiktoken==0.5.1
tokenizers==0.21.1
toml==0.10.2
torch==2.7.1
torchaudio==2.7.1
torchvision==0.22.1
tornado==6.4.2
tqdm==4.67.1
transformers==4.52.4
triton==3.3.1
transformers==4.48.0
jinja2==3.1.2
jsonlines==4.0.0
trl==0.13.0
typing-inspection==0.4.1
typing_extensions==4.13.2
tzlocal==5.3.1
ujson==5.1.0
urllib3==2.4.0
validators==0.34.0
wandb==0.18.3
watchdog==6.0.0
wcwidth==0.2.13
Werkzeug==3.1.3
wrapt==1.17.2
xxhash==3.5.0
yarl==1.20.0
zipp==3.21.0
streamlit==1.30.0
torch==2.2.2
torchvision==0.17.2
@ -1,52 +0,0 @@
#!/bin/bash

# Activate the conda environment
# source $(conda info --base)/etc/profile.d/conda.sh
# conda activate mini

# Environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

# Option 1: use a pre-configured accelerate config file
# accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
# --epochs 3 \
# --batch_size 24 \
# --learning_rate 2e-4 \
# --dtype bfloat16 \
# --accumulation_steps 32 \
# --grad_clip 1.0 \
# --log_interval 100 \
# --save_interval 10000 \
# --dim 1024 \
# --n_layers 32 \
# --max_seq_len 1024 \
# --use_flash_attn \
# --profile \
# --profile_interval 10

# Option 2: configure accelerate directly via command-line arguments
# Memory-leak debugging configuration - reduced memory usage
CUDA_VISIBLE_DEVICES=0 uv run -p .venv python -m accelerate.commands.launch \
    --num_processes=1 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    train_pretrain_accelerate.py
# --batch_size 128 \
# --num_workers 1
# --knowledge_num 48020 \
# --num_workers 1 \
# --epochs 4 \
# --learning_rate 2e-4 \
# --dtype bfloat16 \
# --accumulation_steps 32 \
# --grad_clip 1.0 \
# --log_interval 50 \
# --save_interval 10000 \
# --dim 512 \
# --n_layers 8 \
# --max_seq_len 512 \
# --use_flash_attn \
# --profile \
# --profile_interval 10
@ -1,50 +0,0 @@
#!/bin/bash

# Activate the conda environment
source $(conda info --base)/etc/profile.d/conda.sh
conda activate ycz_accelerate

# Environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

# Option 1: use a pre-configured accelerate config file
# accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
# --epochs 3 \
# --batch_size 24 \
# --learning_rate 2e-4 \
# --dtype bfloat16 \
# --accumulation_steps 32 \
# --grad_clip 1.0 \
# --log_interval 100 \
# --save_interval 10000 \
# --dim 1024 \
# --n_layers 32 \
# --max_seq_len 1024 \
# --use_flash_attn \
# --profile \
# --profile_interval 10

# Option 2: configure accelerate directly via command-line arguments
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
    --multi_gpu \
    --num_processes=4 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    train_pretrain_accelerate.py \
    --epochs 3 \
    --batch_size 24 \
    --learning_rate 2e-4 \
    --dtype bfloat16 \
    --accumulation_steps 32 \
    --grad_clip 1.0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 1024 \
    --n_layers 32 \
    --max_seq_len 1024 \
    --use_flash_attn \
    --profile \
    --profile_interval 10 \
    --knowledge_num 16384 \
    --knowledge_length 64
@ -1,46 +0,0 @@
#!/bin/bash

# Activate the conda environment
source $(conda info --base)/etc/profile.d/conda.sh
conda activate ycz_accelerate

# Environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

# Experiment 1.3.0 - configure accelerate directly via command-line arguments
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
    --multi_gpu \
    --num_processes=4 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    train_pretrain_accelerate.py \
    --out_dir "out" \
    --epochs 3 \
    --embedding_epoch 2 \
    --batch_size 64 \
    --learning_rate 8e-5 \
    --dtype bfloat16 \
    --use_swanlab \
    --swanlab_project "MiniMind-Pretrain" \
    --num_workers 1 \
    --accumulation_steps 16 \
    --grad_clip 0.5 \
    --warmup_iters 0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 1024 \
    --n_layers 48 \
    --max_seq_len 512 \
    --data_path "./dataset/stable/merged_pretrain.jsonl" \
    --profile \
    --profile_interval 10 \
    --use_flash_attn \
    --knowledge_num 1048576 \
    --knowledge_length 32 \
    --database_init_path "./dataset/stable/sentence_trex_data.json" \
    --fast_clustering \
    --cluster_cache_path "./cache/cluster_tokens_single.pt" \
    --memory_monitor_interval 10 \
    --model_type "model_original" \
    --model_size 538
@ -1,47 +0,0 @@
#!/bin/bash

# Activate the conda environment
source $(conda info --base)/etc/profile.d/conda.sh
conda activate ycz_accelerate

# Environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

# Experiment 1.3.0 - configure accelerate directly via command-line arguments
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
    --multi_gpu \
    --num_processes=4 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    train_pretrain_accelerate.py \
    --out_dir "out" \
    --epochs 3 \
    --embedding_epoch 2 \
    --batch_size 48 \
    --learning_rate 2e-4 \
    --dtype bfloat16 \
    --use_swanlab \
    --swanlab_project "MiniMind-Pretrain" \
    --num_workers 1 \
    --accumulation_steps 32 \
    --grad_clip 1.0 \
    --warmup_iters 0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 1024 \
    --n_layers 18 \
    --max_seq_len 512 \
    --use_moe False \
    --data_path "./dataset/stable/merged_pretrain.jsonl" \
    --profile \
    --profile_interval 10 \
    --use_flash_attn \
    --knowledge_num 1048576 \
    --knowledge_length 32 \
    --database_init_path "./dataset/stable/sentence_trex_data.json" \
    --fast_clustering \
    --cluster_cache_path "./cache/cluster_tokens_single.pt" \
    --memory_monitor_interval 10 \
    --model_type "model_no_feed" \
    --model_size 814.724
@ -1,47 +0,0 @@
#!/bin/bash

# Activate the conda environment
source $(conda info --base)/etc/profile.d/conda.sh
conda activate ycz_accelerate

# Environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

# Experiment 1.3.0 - configure accelerate directly via command-line arguments
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
    --multi_gpu \
    --num_processes=4 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    train_pretrain_accelerate.py \
    --out_dir "out" \
    --epochs 3 \
    --embedding_epoch 2 \
    --batch_size 48 \
    --learning_rate 2e-4 \
    --dtype bfloat16 \
    --use_swanlab \
    --swanlab_project "MiniMind-Pretrain" \
    --num_workers 1 \
    --accumulation_steps 32 \
    --grad_clip 1.0 \
    --warmup_iters 0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 1024 \
    --n_layers 18 \
    --max_seq_len 512 \
    --use_moe False \
    --data_path "./dataset/stable/merged_pretrain.jsonl" \
    --profile \
    --profile_interval 10 \
    --use_flash_attn \
    --knowledge_num 1048576 \
    --knowledge_length 32 \
    --database_init_path "./dataset/stable/sentence_trex_data.json" \
    --fast_clustering \
    --cluster_cache_path "./cache/cluster_tokens_single.pt" \
    --memory_monitor_interval 10 \
    --model_type "model" \
    --model_size 814.724
@ -1,45 +0,0 @@
#!/bin/bash

# Activate the conda environment
# source $(conda info --base)/etc/profile.d/conda.sh
# conda activate ycz_accelerate

# Environment variables to help with debugging
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

# Experiment 1.3.0 - configure accelerate directly via command-line arguments
CUDA_VISIBLE_DEVICES=0 accelerate launch \
    --num_processes=1 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    train_pretrain_accelerate.py \
    --out_dir "out" \
    --epochs 3 \
    --embedding_epoch 2 \
    --batch_size 128 \
    --learning_rate 8e-5 \
    --dtype bfloat16 \
    --use_swanlab \
    --swanlab_project "MiniMind-Pretrain" \
    --num_workers 1 \
    --accumulation_steps 16 \
    --grad_clip 0.5 \
    --warmup_iters 0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 512 \
    --n_layers 8 \
    --max_seq_len 512 \
    --data_path "./dataset/stable/merged_pretrain.jsonl" \
    --profile \
    --profile_interval 10 \
    --use_flash_attn \
    --knowledge_num 1048576 \
    --knowledge_length 32 \
    --database_init_path "./dataset/stable/sentence_trex_data.json" \
    --fast_clustering \
    --cluster_cache_path "./cache/cluster_tokens_single.pt" \
    --memory_monitor_interval 10 \
    --model_type "model" \
    --model_size 538
|
@ -32,7 +32,7 @@ def train_tokenizer():
|
||||
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
||||
|
||||
# 定义特殊token
|
||||
special_tokens = ["<unk>", "<|im_start|>", "<|im_end|>"]
|
||||
special_tokens = ["<unk>", "<s>", "</s>"]
|
||||
|
||||
# 设置训练器并添加特殊token
|
||||
trainer = trainers.BpeTrainer(
|
||||
@ -53,8 +53,8 @@ def train_tokenizer():
|
||||
|
||||
# 检查特殊token的索引
|
||||
assert tokenizer.token_to_id("<unk>") == 0
|
||||
assert tokenizer.token_to_id("<|im_start|>") == 1
|
||||
assert tokenizer.token_to_id("<|im_end|>") == 2
|
||||
assert tokenizer.token_to_id("<s>") == 1
|
||||
assert tokenizer.token_to_id("</s>") == 2
|
||||
|
||||
# 保存tokenizer
|
||||
tokenizer_dir = "../model/minimind_tokenizer"
|
||||
@ -77,7 +77,7 @@ def train_tokenizer():
|
||||
"special": True
|
||||
},
|
||||
"1": {
|
||||
"content": "<|im_start|>",
|
||||
"content": "<s>",
|
||||
"lstrip": False,
|
||||
"normalized": False,
|
||||
"rstrip": False,
|
||||
@ -85,7 +85,7 @@ def train_tokenizer():
|
||||
"special": True
|
||||
},
|
||||
"2": {
|
||||
"content": "<|im_end|>",
|
||||
"content": "</s>",
|
||||
"lstrip": False,
|
||||
"normalized": False,
|
||||
"rstrip": False,
|
||||
@ -94,9 +94,9 @@ def train_tokenizer():
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [],
|
||||
"bos_token": "<|im_start|>",
|
||||
"bos_token": "<s>",
|
||||
"clean_up_tokenization_spaces": False,
|
||||
"eos_token": "<|im_end|>",
|
||||
"eos_token": "</s>",
|
||||
"legacy": True,
|
||||
"model_max_length": 32768,
|
||||
"pad_token": "<unk>",
|
||||
@ -104,7 +104,7 @@ def train_tokenizer():
|
||||
"spaces_between_special_tokens": False,
|
||||
"tokenizer_class": "PreTrainedTokenizerFast",
|
||||
"unk_token": "<unk>",
|
||||
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\n你是 MiniMind,是一个有用的人工智能助手。<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"
|
||||
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<s>system\\n' + system_message + '</s>\\n' }}{% else %}{{ '<s>system\\n你是 MiniMind,是一个有用的人工智能助手。</s>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
|
||||
}
|
||||
|
||||
# 保存配置文件
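The template swap above (from <|im_start|>/<|im_end|> to <s>/</s>) can be sanity-checked once the files are written. A hypothetical snippet, with the tokenizer path taken from the script above:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("../model/minimind_tokenizer")
text = tok.apply_chat_template(
    [{"role": "user", "content": "hello"}],
    tokenize=False,
    add_generation_prompt=True,
)
print(text)  # expected to start with "<s>system\n..." and end with "<s>assistant\n"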
33
startup.sh
@ -1,33 +0,0 @@
#!/bin/bash
set -e

# After the container starts, first install all dependencies from requirements.txt
# pip install -r requirements.txt

# bash install.sh -y
python3 -m pip install --upgrade pip
pip install uv -i https://pypi.tuna.tsinghua.edu.cn/simple

# Switch to the project directory
cd /ycz/Minimind

# Check the virtual environment and repair it if needed
if [ ! -f .venv/bin/python ] || [ ! -x .venv/bin/python ]; then
    echo "Virtual environment is broken or missing, recreating with uv..."
    rm -rf .venv
    uv venv .venv
fi

# Do not activate the virtual environment manually; let uv manage it
# . ./.venv/bin/activate

# Sync dependencies with uv
uv sync

# Once installation finishes, run the main training script.
# "$@" forwards the arguments defined by the entrypoint in experiment.yaml to the Python script.
CUDA_VISIBLE_DEVICES=0 uv run python -m accelerate.commands.launch \
    --num_processes=1 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    train_pretrain_accelerate.py "$@"
@ -13,7 +13,6 @@ from torch import optim, nn
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
# 移除通信分析工具导入
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
|
||||
@ -43,67 +42,18 @@ def train_epoch(epoch, wandb):
|
||||
start_time = time.time()
|
||||
# 在函数开始处定义moe_path,避免在异常处理中引用未定义变量
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
|
||||
# 添加CUDA事件来分析性能
|
||||
if args.profile and (not ddp or dist.get_rank() == 0):
|
||||
data_start = torch.cuda.Event(enable_timing=True)
|
||||
data_end = torch.cuda.Event(enable_timing=True)
|
||||
forward_start = torch.cuda.Event(enable_timing=True)
|
||||
forward_end = torch.cuda.Event(enable_timing=True)
|
||||
backward_start = torch.cuda.Event(enable_timing=True)
|
||||
backward_end = torch.cuda.Event(enable_timing=True)
|
||||
optimizer_start = torch.cuda.Event(enable_timing=True)
|
||||
optimizer_end = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
# 移除CUDA图优化代码
|
||||
|
||||
# 预取数据
|
||||
prefetch_factor = 2 # 预取的批次数
|
||||
data_iter = iter(train_loader)
|
||||
prefetch_batches = []
|
||||
|
||||
# 预取初始批次
|
||||
for _ in range(min(prefetch_factor, len(train_loader))):
|
||||
for step, (X, Y, loss_mask) in enumerate(train_loader):
|
||||
try:
|
||||
batch = next(data_iter)
|
||||
prefetch_batches.append([t.to(args.device, non_blocking=True) for t in batch])
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
for step in range(len(train_loader)):
|
||||
try:
|
||||
# 计时数据加载
|
||||
if args.profile and (not ddp or dist.get_rank() == 0):
|
||||
data_start.record()
|
||||
|
||||
# 使用预取的数据
|
||||
if prefetch_batches:
|
||||
X, Y, loss_mask = prefetch_batches.pop(0)
|
||||
else:
|
||||
# 如果预取队列为空,直接加载
|
||||
X, Y, loss_mask = [t.to(args.device) for t in next(data_iter)]
|
||||
|
||||
# 异步预取下一批数据
|
||||
if step + prefetch_factor < len(train_loader):
|
||||
try:
|
||||
batch = next(data_iter)
|
||||
prefetch_batches.append([t.to(args.device, non_blocking=True) for t in batch])
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
if args.profile and (not ddp or dist.get_rank() == 0):
|
||||
data_end.record()
|
||||
# 将数据加载到设备上
|
||||
X = X.to(args.device)
|
||||
Y = Y.to(args.device)
|
||||
loss_mask = loss_mask.to(args.device)
|
||||
|
||||
# 更新学习率
|
||||
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
# 计时前向传播
|
||||
if args.profile and (not ddp or dist.get_rank() == 0):
|
||||
forward_start.record()
|
||||
|
||||
# 常规前向传播
|
||||
with ctx:
|
||||
res = model(X)
|
||||
loss = loss_fct(
|
||||
@ -127,13 +77,6 @@ def train_epoch(epoch, wandb):
|
||||
# 如果出错,不添加辅助损失
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
# 反向传播
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
if args.profile and (not ddp or dist.get_rank() == 0):
|
||||
forward_end.record()
|
||||
backward_start.record()
|
||||
|
||||
# Print data types for debugging
|
||||
if step == 0 and (not ddp or dist.get_rank() == 0): # Print only for the first step of the first epoch on the main process
|
||||
Logger("---- Data Type Check ----")
|
||||
@ -146,21 +89,9 @@ def train_epoch(epoch, wandb):
|
||||
Logger(f"loss.dtype: {loss.dtype}")
|
||||
Logger("-------------------------")
|
||||
|
||||
if args.profile and (not ddp or dist.get_rank() == 0):
|
||||
backward_end.record()
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# 在每一步都进行性能分析,而不仅仅是在梯度累积完成时
|
||||
if (step + 1) % args.profile_interval == 0:
|
||||
# 记录优化器时间(如果是梯度累积步骤)
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
optimizer_start.record()
|
||||
|
||||
# 优化器步骤
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
if args.profile and (not ddp or dist.get_rank() == 0):
|
||||
if (step + 1) % args.profile_interval != 0:
|
||||
optimizer_start.record()
|
||||
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
@@ -169,40 +100,6 @@ def train_epoch(epoch, wandb):

            optimizer.zero_grad(set_to_none=True)

            if args.profile and (not ddp or dist.get_rank() == 0):
                optimizer_end.record()

        # Profiling output (every profile_interval steps)
        if args.profile and (not ddp or dist.get_rank() == 0) and (step + 1) % args.profile_interval == 0:
            # Synchronize CUDA events to get accurate timings
            torch.cuda.synchronize()

            # Per-stage elapsed times
            data_time = data_start.elapsed_time(data_end)
            forward_time = forward_start.elapsed_time(forward_end)
            backward_time = backward_start.elapsed_time(backward_end)

            # Optimizer time exists only when this step completed an accumulation cycle
            if (step + 1) % args.accumulation_steps == 0:
                optimizer_time = optimizer_start.elapsed_time(optimizer_end)
                total_compute_time = forward_time + backward_time + optimizer_time
                Logger(f"Profiling - step {step+1}:")
                Logger(f"  Data loading time: {data_time:.2f} ms")
                Logger(f"  Forward time: {forward_time:.2f} ms")
                Logger(f"  Backward time: {backward_time:.2f} ms")
                Logger(f"  Optimizer time: {optimizer_time:.2f} ms")
                Logger(f"  Total compute time: {total_compute_time:.2f} ms")
                Logger(f"  Compute/data ratio: {total_compute_time / data_time:.2f}")
            else:
                # Mid-accumulation step: no optimizer time
                total_compute_time = forward_time + backward_time
                Logger(f"Profiling - step {step+1} (accumulating gradients):")
                Logger(f"  Data loading time: {data_time:.2f} ms")
                Logger(f"  Forward time: {forward_time:.2f} ms")
                Logger(f"  Backward time: {backward_time:.2f} ms")
                Logger(f"  Total compute time: {total_compute_time:.2f} ms")
                Logger(f"  Compute/data ratio: {total_compute_time / data_time:.2f}")
        # Logging
        if step % args.log_interval == 0:
            spend_time = time.time() - start_time
@@ -217,39 +114,9 @@ def train_epoch(epoch, wandb):
                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))

            if (wandb is not None) and (not ddp or dist.get_rank() == 0):
                log_dict = {
                    "loss": loss.item() * args.accumulation_steps,
                    "lr": optimizer.param_groups[-1]['lr'],
                    "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
                }

                # When profiling is enabled, also log the performance metrics
                if args.profile and (step + 1) % args.profile_interval == 0:
                    # Basic performance metrics
                    perf_dict = {
                        "data_time_ms": data_time,
                        "forward_time_ms": forward_time,
                        "backward_time_ms": backward_time
                    }

                    # Optimizer time exists only when this step completed an accumulation cycle
                    if (step + 1) % args.accumulation_steps == 0:
                        total_compute_time = forward_time + backward_time + optimizer_time
                        perf_dict.update({
                            "optimizer_time_ms": optimizer_time,
                            "compute_time_ms": total_compute_time
                        })
                    else:
                        total_compute_time = forward_time + backward_time
                        perf_dict.update({
                            "compute_time_ms": total_compute_time
                        })

                    log_dict.update(perf_dict)

                wandb.log(log_dict)

            # Communication-profiling code removed
                wandb.log({"loss": loss.item() * args.accumulation_steps,
                           "lr": optimizer.param_groups[-1]['lr'],
                           "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
        # Save the model checkpoint
        if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
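The body of this save branch is elided by the hunk. A sketch of the usual rank-0 checkpoint pattern in this kind of training script (the file name is illustrative, not taken from the diff):

model.eval()
ckp = f'{args.save_dir}/pretrain_{lm_config.dim}.pth'   # illustrative path
# Unwrap the DDP wrapper so the saved keys match the bare model
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
torch.save(state_dict, ckp)
model.train()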
@@ -291,7 +158,7 @@ def train_epoch(epoch, wandb):

def init_model(lm_config, pretrained_embedding_path: Optional[str] = None):
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained('/mnt/lzn/Minimind/Minimind/model/minimind_tokenizer')
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    # Load the model
    model = MiniMindLM(lm_config).to(args.device)

@@ -309,9 +176,6 @@ def init_model(lm_config, pretrained_embedding_path: Optional[str] = None):
    return model, tokenizer
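The elided middle of init_model presumably applies the optional pretrained embedding weights named in the signature. A sketch of such a load (the embedding attribute name is an assumption):

if pretrained_embedding_path is not None:
    weights = torch.load(pretrained_embedding_path, map_location=args.device)
    model.tok_embeddings.weight.data.copy_(weights)   # attribute name assumed
    Logger(f'Loaded pretrained embeddings from {pretrained_embedding_path}')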

# Communication-profiling function removed


def init_distributed_mode():
    if not ddp: return  # If distributed data parallel (DDP) is not enabled, return without doing anything.
    global ddp_local_rank, DEVICE  # Declare these as globals so they are accessible outside this function.
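The rest of this function is cut off by the hunk; the standard torch.distributed setup it performs looks roughly like this (a sketch assuming the usual torchrun environment variables):

def init_distributed_mode():
    if not ddp: return
    global ddp_local_rank, DEVICE

    dist.init_process_group(backend="nccl")           # NCCL backend for multi-GPU training
    ddp_local_rank = int(os.environ["LOCAL_RANK"])    # set by torchrun / the launch utility
    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)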
@@ -330,42 +194,35 @@ if __name__ == "__main__":
    parser.add_argument("--out_dir", type=str, default="out")
    # To train a "zero" model as fast as possible, set epochs to 1; otherwise make the most of the limited data with 2-6 epochs.
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=24)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")  # Use the GPU if available, otherwise the CPU.
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_wandb", default=True, action="store_true")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
    parser.add_argument("--num_workers", type=int, default=48)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--ddp", action="store_true")
    parser.add_argument("--accumulation_steps", type=int, default=32)  # Gradient-accumulation steps; controls how often parameters are updated.
    parser.add_argument("--accumulation_steps", type=int, default=64)  # Gradient-accumulation steps; controls how often parameters are updated.
    parser.add_argument("--grad_clip", type=float, default=1.0)  # Gradient-clipping threshold, guards against exploding gradients.
    parser.add_argument("--warmup_iters", type=int, default=0)  # Number of warmup iterations for the learning-rate schedule.
    parser.add_argument("--log_interval", type=int, default=100)  # How often (in steps) to print logs.
    parser.add_argument("--save_interval", type=int, default=10000)  # How often (in steps) to save the model.
    parser.add_argument("--save_interval", type=int, default=100)  # How often (in steps) to save the model.
    parser.add_argument('--local_rank', type=int, default=-1)  # Local process rank for distributed training.
    parser.add_argument('--dim', default=1024, type=int)  # Model hidden dimension.
    parser.add_argument('--dim', default=2048, type=int)  # Model hidden dimension.
    parser.add_argument('--n_layers', default=32, type=int)  # Number of transformer layers.
    parser.add_argument('--max_seq_len', default=1024, type=int)  # Maximum input sequence length.
    parser.add_argument('--use_moe', default=False, type=bool)  # Whether to use MoE.
    parser.add_argument('--disable_db', action='store_true', help="Disable the database feature and use a fixed value of 1e-4 instead")  # Disables the database feature (special mode).
    parser.add_argument("--data_path", type=str, default="/mnt/lzn/Minimind/dataset/dir/pretrain_hq.jsonl")  # Path to the training dataset.
    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")  # Path to the training dataset.
    parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
    # Profiling-related arguments
    parser.add_argument("--profile", action="store_true", default=True, help="Enable profiling")
    parser.add_argument("--profile_interval", type=int, default=10, help="Profiling print interval (in steps)")
    parser.add_argument("--use_flash_attn", action="store_true", default=True, help="Enable FlashAttention")
    args = parser.parse_args()
    print(args)


    lm_config = LMConfig(
        dim=args.dim,
        n_layers=args.n_layers,
        max_seq_len=args.max_seq_len,
        use_moe=args.use_moe,
        disable_db=args.disable_db,  # Forward the disable-db flag
        flash_attn=args.use_flash_attn  # FlashAttention support
        disable_db=args.disable_db  # Forward the disable-db flag
    )  # Build the LMConfig that controls the model configuration.
    args.save_dir = os.path.join(args.out_dir)  # Resolve the checkpoint directory.
    os.makedirs(args.save_dir, exist_ok=True)  # Create it if it does not exist.
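Each parameter update therefore consumes batch_size * accumulation_steps sequences of up to max_seq_len tokens. With the master-branch defaults above:

batch_size, accumulation_steps, max_seq_len = 24, 32, 1024
examples_per_update = batch_size * accumulation_steps    # 768 sequences per optimizer step
tokens_per_update = examples_per_update * max_seq_len    # 786,432 tokens per optimizer step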
@@ -406,34 +263,28 @@ if __name__ == "__main__":
        wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config)
    else:
        wandb = None

    model, tokenizer = init_model(lm_config, args.pretrained_embedding_path)
    train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if ddp else None
    # Tuned DataLoader configuration
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        pin_memory_device=f"cuda:{ddp_local_rank}" if ddp else "cuda:0",  # Device that pinned memory targets
        drop_last=False,
        shuffle=False,
        num_workers=args.num_workers,
        sampler=train_sampler,
        persistent_workers=True if args.num_workers > 0 else False,  # Keep worker processes alive between epochs
        prefetch_factor=2 if args.num_workers > 0 else None  # Batches each worker prefetches
        sampler=train_sampler
    )

    # Enable GradScaler only for float16; bfloat16 does not need loss scaling
    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16']))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
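The ctx wrapped around the forward pass earlier is an autocast context; its construction is not shown in this diff, but a sketch of how it typically pairs with the dtype-conditional scaler:

from contextlib import nullcontext

device_type = "cuda" if "cuda" in args.device else "cpu"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(
    dtype=torch.bfloat16 if args.dtype == "bfloat16" else torch.float16)
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == "float16"))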

    if ddp:
        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
        # Keep find_unused_parameters=True: the model genuinely has parameters that go unused in some passes
        # Add find_unused_parameters=True to fix the unused-parameter error
        model = DistributedDataParallel(model, device_ids=[ddp_local_rank], find_unused_parameters=True)

    # Keep set_detect_anomaly on for debugging for now;
    # once training is stable, comment it out for speed
    torch.autograd.set_detect_anomaly(True)
    iter_per_epoch = len(train_loader)
    for epoch in range(args.epochs):
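When a DistributedSampler is in use, the loop should reseed the shuffle before iterating each epoch; a sketch of the loop head:

for epoch in range(args.epochs):
    if train_sampler is not None:
        train_sampler.set_epoch(epoch)   # reshuffles shards differently every epoch
    train_epoch(epoch, wandb)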