添加了对于多种模型的支持
This commit is contained in:
parent
75265f6652
commit
5e464bbd3f
@@ -25,7 +25,7 @@ import swanlab # 替换wandb导入
|
||||
import gc # 添加垃圾回收模块
|
||||
import psutil # 添加系统资源监控模块
|
||||
|
||||
from model.model import MiniMindLM, RMSNorm
|
||||
|
||||
from model.LMConfig import LMConfig
|
||||
from model.dataset import PretrainDataset
|
||||
|
||||
@@ -105,6 +105,9 @@ def get_lr(it, num_iters, learning_rate):
|
||||
|
||||
# 初始化模型函数
|
||||
def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None):
|
||||
if args.model_type == "model":
|
||||
Logger(f"Using model type: {args.model_type}")
|
||||
from model.model import MiniMindLM, RMSNorm
|
||||
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
|
||||
model = MiniMindLM(lm_config)
|
||||
|
||||
@@ -276,6 +279,13 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
||||
Logger(f"Database embeddings and sentences stored in model")
|
||||
|
||||
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
elif args.model_type == "model_original":
|
||||
Logger(f"Using model type: {args.model_type}")
|
||||
from model.model_original import MiniMindLM, RMSNorm
|
||||
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
|
||||
model = MiniMindLM(lm_config)
|
||||
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run):
|
||||
@@ -389,7 +399,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
||||
# 添加辅助损失,如果存在的话
|
||||
try:
|
||||
aux_loss = sum(l.feed_forward.aux_loss for l in model.module.layers
|
||||
- if hasattr(l.feed_forward, 'aux_loss'))
|
||||
+ if hasattr(l, 'feed_forward') and hasattr(l.feed_forward, 'aux_loss'))
|
||||
loss += aux_loss
|
||||
except Exception as e:
|
||||
Logger(f"Warning: Could not add auxiliary loss: {e}")
|
||||
@@ -586,7 +596,7 @@ def main():
|
||||
parser.add_argument('--max_seq_len', default=512, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能,使用固定值1e-4替代")
|
||||
- parser.add_argument("--data_path", type=str, default="./dataset/merged_pretrain.jsonl")
|
||||
+ parser.add_argument("--data_path", type=str, default="./dataset/stable/merged_pretrain.jsonl")
|
||||
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
|
||||
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
|
||||
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
|
||||
@@ -599,6 +609,7 @@ def main():
|
||||
parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件")
|
||||
parser.add_argument("--memory_monitor", action="store_true", default=False, help="启用内存监控")
|
||||
parser.add_argument("--memory_monitor_interval", type=int, default=10, help="内存监控间隔(步数)")
|
||||
parser.add_argument("--model_type", type=str, default="model", help="使用什么模型训练") #model,model_original
|
||||
args = parser.parse_args()
|
||||
|
||||
#########################################################
|
||||
|
Loading…
x
Reference in New Issue
Block a user