#!/bin/bash

# 激活conda环境
# source $(conda info --base)/etc/profile.d/conda.sh
# conda activate ycz_accelerate

# 设置环境变量以帮助调试
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

# 方法1: 使用预先配置的accelerate配置文件
# accelerate launch --config_file accelerate_config.yaml train_pretrain_accelerate.py \
#     --epochs 3 \
#     --batch_size 24 \
#     --learning_rate 2e-4 \
#     --dtype bfloat16 \
#     --accumulation_steps 32 \
#     --grad_clip 1.0 \
#     --log_interval 100 \
#     --save_interval 10000 \
#     --dim 1024 \
#     --n_layers 32 \
#     --max_seq_len 1024 \
#     --use_flash_attn \
#     --profile \
#     --profile_interval 10

# 方法2: 使用命令行参数直接配置accelerate
CUDA_VISIBLE_DEVICES=0 accelerate launch \
    --num_processes=1 \
    --mixed_precision=bf16 \
    --main_process_port=29500 \
    train_pretrain_accelerate.py \
    --epochs 3 \
    --batch_size 24 \
    --learning_rate 2e-4 \
    --dtype bfloat16 \
    --accumulation_steps 32 \
    --grad_clip 1.0 \
    --log_interval 100 \
    --save_interval 10000 \
    --dim 512 \
    --n_layers 12 \
    --max_seq_len 512 \
    --use_flash_attn \
    --profile \
    --profile_interval 10\
    --knowledge_num 4096 \
    --knowledge_length 8