添加了对于多种模型的支持

This commit is contained in:
Yu Chengzhang 2025-07-12 18:00:53 +08:00
parent 75265f6652
commit 5e464bbd3f

View File

@@ -25,7 +25,7 @@ import swanlab # 替换wandb导入
import gc # 添加垃圾回收模块
import psutil # 添加系统资源监控模块
from model.model import MiniMindLM, RMSNorm
from model.LMConfig import LMConfig
from model.dataset import PretrainDataset
@@ -105,6 +105,9 @@ def get_lr(it, num_iters, learning_rate):
# 初始化模型函数
def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None):
if args.model_type == "model":
Logger(f"Using model type: {args.model_type}")
from model.model import MiniMindLM, RMSNorm
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
model = MiniMindLM(lm_config)
@@ -276,6 +279,13 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
Logger(f"Database embeddings and sentences stored in model")
Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
elif args.model_type == "model_original":
Logger(f"Using model type: {args.model_type}")
from model.model_original import MiniMindLM, RMSNorm
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
model = MiniMindLM(lm_config)
Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
return model, tokenizer
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run):
@@ -389,7 +399,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
# 添加辅助损失,如果存在的话
try:
aux_loss = sum(l.feed_forward.aux_loss for l in model.module.layers
if hasattr(l.feed_forward, 'aux_loss'))
if hasattr(l, 'feed_forward') and hasattr(l.feed_forward, 'aux_loss'))
loss += aux_loss
except Exception as e:
Logger(f"Warning: Could not add auxiliary loss: {e}")
@@ -586,7 +596,7 @@ def main():
parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能使用固定值1e-4替代")
parser.add_argument("--data_path", type=str, default="./dataset/merged_pretrain.jsonl")
parser.add_argument("--data_path", type=str, default="./dataset/stable/merged_pretrain.jsonl")
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
@@ -599,6 +609,7 @@ def main():
parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件")
parser.add_argument("--memory_monitor", action="store_true", default=False, help="启用内存监控")
parser.add_argument("--memory_monitor_interval", type=int, default=10, help="内存监控间隔(步数)")
parser.add_argument("--model_type", type=str, default="model", help="使用什么模型训练") #model,model_original
args = parser.parse_args()
#########################################################