update
This commit is contained in:
parent
cb3152dc94
commit
2f6995d667
@ -622,10 +622,12 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
# 固定冻结前面的条目
|
# 固定冻结前面的条目
|
||||||
freeze_mask[:freeze_num] = True
|
freeze_mask[:freeze_num] = True
|
||||||
self.register_buffer('freeze_mask', freeze_mask, persistent=False)
|
self.register_buffer('freeze_mask', freeze_mask, persistent=False)
|
||||||
print(f"🔥 Memory bank freezing enabled: {freeze_num}/{params.knowledge_num} entries ({params.freeze_ratio*100:.1f}%) frozen")
|
print(f"🔥 Memory bank freezing enabled: {freeze_num}/{params.knowledge_num} entries ({params.freeze_ratio*100:.1f}%) frozen", flush=True)
|
||||||
|
import sys; sys.stdout.flush()
|
||||||
else:
|
else:
|
||||||
self.register_buffer('freeze_mask', torch.zeros(params.knowledge_num, dtype=torch.bool), persistent=False)
|
self.register_buffer('freeze_mask', torch.zeros(params.knowledge_num, dtype=torch.bool), persistent=False)
|
||||||
print(f"🔥 Memory bank freezing disabled: all entries can be updated")
|
print(f"🔥 Memory bank freezing disabled: all entries can be updated", flush=True)
|
||||||
|
import sys; sys.stdout.flush()
|
||||||
|
|
||||||
self.OUT = CausalLMOutputWithPast()
|
self.OUT = CausalLMOutputWithPast()
|
||||||
|
|
||||||
|
|||||||
@ -40,8 +40,8 @@ LOG_FILE="$LOG_DIR/experiment.log"
|
|||||||
# ----------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------
|
||||||
# 🤖 硬件配置
|
# 🤖 硬件配置
|
||||||
# ----------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------
|
||||||
CUDA_VISIBLE_DEVICES="0"
|
CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
NUM_PROCESSES="1"
|
NUM_PROCESSES="4"
|
||||||
MIXED_PRECISION="bf16"
|
MIXED_PRECISION="bf16"
|
||||||
MAIN_PROCESS_PORT="29500"
|
MAIN_PROCESS_PORT="29500"
|
||||||
|
|
||||||
@ -66,9 +66,9 @@ DISABLE_DB="false"
|
|||||||
# 🤖 训练超参数
|
# 🤖 训练超参数
|
||||||
# ----------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------
|
||||||
EPOCHS="3"
|
EPOCHS="3"
|
||||||
EMBEDDING_EPOCH="2"
|
EMBEDDING_EPOCH="42"
|
||||||
BATCH_SIZE="32" # 🔥 降低批次大小以适应更复杂的计算
|
BATCH_SIZE="4" # 🔥 降低批次大小以适应更复杂的计算
|
||||||
ACCUMULATION_STEPS="12" # 🔥 增加累积步数保持有效批次大小
|
ACCUMULATION_STEPS="4" # 🔥 增加累积步数保持有效批次大小
|
||||||
LEARNING_RATE="2e-4" # 🔥 适度降低学习率提升稳定性
|
LEARNING_RATE="2e-4" # 🔥 适度降低学习率提升稳定性
|
||||||
DTYPE="bfloat16"
|
DTYPE="bfloat16"
|
||||||
GRAD_CLIP="1.0"
|
GRAD_CLIP="1.0"
|
||||||
@ -86,7 +86,7 @@ CLUSTER_CACHE_PATH="None" # 禁用聚类缓存
|
|||||||
VAL_DATA_PATH="dataset/stable/eval_data.json"
|
VAL_DATA_PATH="dataset/stable/eval_data.json"
|
||||||
|
|
||||||
# 训练配置
|
# 训练配置
|
||||||
NUM_WORKERS="1"
|
NUM_WORKERS="8"
|
||||||
LOG_INTERVAL="100" # 🔥 更频繁的日志记录观察四个损失
|
LOG_INTERVAL="100" # 🔥 更频繁的日志记录观察四个损失
|
||||||
VAL_INTERVAL="100"
|
VAL_INTERVAL="100"
|
||||||
SAVE_INTERVAL="10000"
|
SAVE_INTERVAL="10000"
|
||||||
@ -215,7 +215,7 @@ run_experiment() {
|
|||||||
echo "⏰ 开始时间: $EXPERIMENT_DATE"
|
echo "⏰ 开始时间: $EXPERIMENT_DATE"
|
||||||
|
|
||||||
# 构建训练命令
|
# 构建训练命令
|
||||||
local train_cmd="CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python train_pretrain_accelerate.py"
|
local train_cmd="CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py"
|
||||||
|
|
||||||
# 添加训练参数
|
# 添加训练参数
|
||||||
train_cmd+=" --out_dir \"$LOG_DIR\""
|
train_cmd+=" --out_dir \"$LOG_DIR\""
|
||||||
|
|||||||
@ -37,9 +37,13 @@ EXPERIMENT_DATE="$(date '+%Y-%m-%d %H:%M:%S')" # 自动记录实验开始时间
|
|||||||
# source "$VIRTUAL_ENV/bin/activate"
|
# source "$VIRTUAL_ENV/bin/activate"
|
||||||
|
|
||||||
# 调试和监控环境变量
|
# 调试和监控环境变量
|
||||||
export NCCL_DEBUG=INFO # NCCL 调试信息
|
|
||||||
export PYTHONFAULTHANDLER=1 # Python 故障处理
|
export PYTHONFAULTHANDLER=1 # Python 故障处理
|
||||||
export CUDA_LAUNCH_BLOCKING=1 # CUDA 同步执行(调试用)
|
# export NCCL_DEBUG=INFO # NCCL 调试信息(仅调试时启用)
|
||||||
|
# export CUDA_LAUNCH_BLOCKING=1 # CUDA 同步执行(严重影响性能,仅调试时启用)
|
||||||
|
|
||||||
|
# 🔥 强制禁用输出缓冲,确保日志立即写入(不影响GPU性能)
|
||||||
|
export PYTHONUNBUFFERED=1 # Python 解释器不缓冲输出
|
||||||
|
export PYTHONIOENCODING=utf-8 # 确保编码一致性
|
||||||
|
|
||||||
# SwanLab 配置
|
# SwanLab 配置
|
||||||
export SWANLAB_API_KEY="[SWANLAB_API_KEY]" # 🤖 [AI构建] SwanLab API密钥
|
export SWANLAB_API_KEY="[SWANLAB_API_KEY]" # 🤖 [AI构建] SwanLab API密钥
|
||||||
@ -292,8 +296,8 @@ echo "退出代码: \$?"
|
|||||||
EOF
|
EOF
|
||||||
chmod +x "$train_script"
|
chmod +x "$train_script"
|
||||||
|
|
||||||
# 使用nohup后台运行
|
# 使用nohup后台运行,并使用stdbuf禁用缓冲
|
||||||
nohup bash "$train_script" >> "$LOG_FILE" 2>&1 &
|
nohup stdbuf -oL -eL bash "$train_script" >> "$LOG_FILE" 2>&1 &
|
||||||
local train_pid=$!
|
local train_pid=$!
|
||||||
|
|
||||||
echo "🔥 训练进程已启动,PID: $train_pid"
|
echo "🔥 训练进程已启动,PID: $train_pid"
|
||||||
|
|||||||
@ -1,6 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
# 设置环境变量 - 将wandb替换为SwanLab
|
# 设置环境变量 - 将wandb替换为SwanLab
|
||||||
# os.environ["SWANLAB_MODE"] = "online" # SwanLab使用在线模式
|
# os.environ["SWANLAB_MODE"] = "online" # SwanLab使用在线模式
|
||||||
|
|
||||||
|
# 🔥 强制禁用输出缓冲,确保日志立即写入
|
||||||
|
os.environ['PYTHONUNBUFFERED'] = '1' # Python 解释器不缓冲输出
|
||||||
|
os.environ['PYTHONIOENCODING'] = 'utf-8' # 确保编码一致性
|
||||||
import platform
|
import platform
|
||||||
import argparse
|
import argparse
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -92,7 +96,9 @@ def log_memory_status(step, prefetch_batches, accelerator, stage="", detailed=Fa
|
|||||||
def Logger(msg, accelerator=None):
|
def Logger(msg, accelerator=None):
|
||||||
# 如果没有提供accelerator,则只在主进程打印
|
# 如果没有提供accelerator,则只在主进程打印
|
||||||
if accelerator is None or accelerator.is_main_process:
|
if accelerator is None or accelerator.is_main_process:
|
||||||
print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}")
|
print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}", flush=True) # 强制刷新输出缓冲
|
||||||
|
import sys
|
||||||
|
sys.stdout.flush() # 确保立即写入
|
||||||
|
|
||||||
# Helper function to format seconds into HH:MM:SS
|
# Helper function to format seconds into HH:MM:SS
|
||||||
def format_time(seconds):
|
def format_time(seconds):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user