diff --git a/train_extra_accelerate.py b/train_extra_accelerate.py index d8f5fcd..37b4e47 100644 --- a/train_extra_accelerate.py +++ b/train_extra_accelerate.py @@ -24,7 +24,9 @@ from sklearn.metrics.pairwise import cosine_similarity import swanlab # 替换wandb导入 import gc # 添加垃圾回收模块 import psutil # 添加系统资源监控模块 +import os +os.environ['CUDA_VISIBLE_DEVICES']='2' from model.model_extra import MiniMindLM, RMSNorm # 使用model_extra from model.LMConfig import LMConfig from model.dataset import TriplePretrainDataset # 只需要三元组数据集 @@ -785,14 +787,14 @@ def main(): parser.add_argument("--accumulation_steps", type=int, default=32) parser.add_argument("--grad_clip", type=float, default=1.0) parser.add_argument("--warmup_iters", type=int, default=0) - parser.add_argument("--log_interval", type=int, default=100) + parser.add_argument("--log_interval", type=int, default=50) parser.add_argument("--save_interval", type=int, default=10000) parser.add_argument('--dim', default=512, type=int) parser.add_argument('--n_layers', default=8, type=int) parser.add_argument('--max_seq_len', default=512, type=int) parser.add_argument('--use_moe', default=False, type=bool) parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能,使用固定值1e-4替代") - parser.add_argument("--data_path", type=str, default="./dataset/processed_trex_data.json") + parser.add_argument("--data_path", type=str, default="/home/rwkv/RWKV-TS/RETRO_TEST/extract/sample_1000.json") parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)") parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析") parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")