Added support for multiple model types
parent 75265f6652
commit 5e464bbd3f
@@ -25,7 +25,7 @@ import swanlab # replaces the wandb import
 import gc # garbage-collection module
 import psutil # system resource monitoring module
 
-from model.model import MiniMindLM, RMSNorm
+
 from model.LMConfig import LMConfig
 from model.dataset import PretrainDataset
 
@@ -105,6 +105,9 @@ def get_lr(it, num_iters, learning_rate):
 
 # Model initialization function
 def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None):
+    if args.model_type == "model":
+        Logger(f"Using model type: {args.model_type}")
+        from model.model import MiniMindLM, RMSNorm
     tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
     model = MiniMindLM(lm_config)
 
@@ -276,6 +279,13 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
         Logger(f"Database embeddings and sentences stored in model")
 
         Logger(f'Total LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
+    elif args.model_type == "model_original":
+        Logger(f"Using model type: {args.model_type}")
+        from model.model_original import MiniMindLM, RMSNorm
+        tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
+        model = MiniMindLM(lm_config)
+        Logger(f'Total LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
+
     return model, tokenizer
 
 def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run):
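
The two hunks above add the model selection itself: init_model branches on args.model_type and defers the model import until the type is known, which is also why the top-level "from model.model import MiniMindLM, RMSNorm" was removed in the first hunk. Below is a minimal sketch of the same dispatch pattern, assuming the repo's model/ package layout; the explicit else guard is an addition not present in the commit:

    # Sketch only: mirrors the dispatch introduced in init_model.
    from transformers import AutoTokenizer

    def init_model(lm_config, args):
        if args.model_type == "model":
            from model.model import MiniMindLM            # deferred import
        elif args.model_type == "model_original":
            from model.model_original import MiniMindLM   # deferred import
        else:
            # Not in the commit: fail fast on an unknown type.
            raise ValueError(f"Unknown model_type: {args.model_type!r}")
        tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
        model = MiniMindLM(lm_config)
        return model, tokenizer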
@@ -389,7 +399,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
             # Add the auxiliary loss, if present
             try:
                 aux_loss = sum(l.feed_forward.aux_loss for l in model.module.layers
-                               if hasattr(l.feed_forward, 'aux_loss'))
+                               if hasattr(l, 'feed_forward') and hasattr(l.feed_forward, 'aux_loss'))
                 loss += aux_loss
             except Exception as e:
                 Logger(f"Warning: Could not add auxiliary loss: {e}")
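
The one-line fix above matters when model.module.layers mixes block types: with the old filter, any layer lacking a feed_forward attribute raises AttributeError inside the generator (the access l.feed_forward happens before hasattr can answer), so the whole auxiliary-loss sum falls through to the except branch. A self-contained illustration with stand-in classes (all names here are hypothetical):

    # Stand-in layer types; only the MoE-style one carries an aux_loss.
    class MoEFFN:
        aux_loss = 0.01                      # e.g. a load-balancing loss

    class DenseBlock:
        def __init__(self):
            self.feed_forward = object()     # feed_forward without aux_loss

    class AttnOnlyBlock:
        pass                                 # no feed_forward at all

    moe_block = DenseBlock()
    moe_block.feed_forward = MoEFFN()
    layers = [DenseBlock(), AttnOnlyBlock(), moe_block]

    # Old filter: raises AttributeError on AttnOnlyBlock.
    #   sum(l.feed_forward.aux_loss for l in layers
    #       if hasattr(l.feed_forward, 'aux_loss'))

    # New filter: skips AttnOnlyBlock safely and sums only real aux losses.
    aux_loss = sum(l.feed_forward.aux_loss for l in layers
                   if hasattr(l, 'feed_forward') and hasattr(l.feed_forward, 'aux_loss'))
    print(aux_loss)  # 0.01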
@@ -586,7 +596,7 @@ def main():
     parser.add_argument('--max_seq_len', default=512, type=int)
     parser.add_argument('--use_moe', default=False, type=bool)
     parser.add_argument('--disable_db', action='store_true', help="Disable the database feature and use a fixed value of 1e-4 instead")
-    parser.add_argument("--data_path", type=str, default="./dataset/merged_pretrain.jsonl")
+    parser.add_argument("--data_path", type=str, default="./dataset/stable/merged_pretrain.jsonl")
     parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
     parser.add_argument("--profile", action="store_true", default=True, help="Enable profiling")
     parser.add_argument("--profile_interval", type=int, default=10, help="Profiling print interval (steps)")
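
One unchanged context line is worth flagging in passing: type=bool in argparse does not parse booleans, because bool() is called on the raw string and any non-empty string, including "False", is truthy. The commit leaves --use_moe as-is; a conventional store_true flag (sketched below, not part of the commit) avoids the trap:

    import argparse

    # As in the diff: --use_moe False still yields True.
    p = argparse.ArgumentParser()
    p.add_argument('--use_moe', default=False, type=bool)
    print(p.parse_args(['--use_moe', 'False']).use_moe)  # True

    # Conventional alternative: presence of the flag means True.
    p2 = argparse.ArgumentParser()
    p2.add_argument('--use_moe', action='store_true')
    print(p2.parse_args([]).use_moe)             # False
    print(p2.parse_args(['--use_moe']).use_moe)  # True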
@@ -599,6 +609,7 @@ def main():
     parser.add_argument("--recompute_clusters", action="store_true", default=False, help="Force recomputation of clusters, ignoring cache files")
     parser.add_argument("--memory_monitor", action="store_true", default=False, help="Enable memory monitoring")
     parser.add_argument("--memory_monitor_interval", type=int, default=10, help="Memory monitoring interval (steps)")
+    parser.add_argument("--model_type", type=str, default="model", help="Which model implementation to train")  # model, model_original
     args = parser.parse_args()
 
     #########################################################
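
With the new flag, either variant is selected at launch time via --model_type model (the default) or --model_type model_original. The valid values live only in the trailing comment; as a hedged alternative, not what the commit does, argparse's choices= would reject anything but the two supported types at parse time:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default="model",
                        choices=["model", "model_original"],
                        help="Which model implementation to train")

    args = parser.parse_args(["--model_type", "model_original"])
    print(args.model_type)  # model_original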