update
This commit is contained in:
parent
fc688ddde4
commit
1a8c86360d
@ -1,8 +1,8 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 激活conda环境
|
||||
source $(conda info --base)/etc/profile.d/conda.sh
|
||||
conda activate ycz_accelerate
|
||||
# source $(conda info --base)/etc/profile.d/conda.sh
|
||||
# conda activate ycz_accelerate
|
||||
|
||||
# 设置环境变量以帮助调试
|
||||
export NCCL_DEBUG=INFO
|
||||
@ -27,8 +27,7 @@ export PYTHONFAULTHANDLER=1
|
||||
|
||||
# 方法2: 使用命令行参数直接配置accelerate
|
||||
CUDA_VISIBLE_DEVICES=0 accelerate launch \
|
||||
--multi_gpu \
|
||||
--num_processes=4 \
|
||||
--num_processes=1 \
|
||||
--mixed_precision=bf16 \
|
||||
--main_process_port=29500 \
|
||||
train_pretrain_accelerate.py \
|
||||
|
@ -339,9 +339,9 @@ def main():
|
||||
# 我们已经将复数版本的pos_cis替换为实数版本的pos_cis_real
|
||||
# 但为了安全起见,我们仍然将其设置为不参与分布式训练
|
||||
if hasattr(model, "pos_cis_real"):
|
||||
Logger(f'检测到pos_cis_real实数张量,将其设置为不参与分布式训练', accelerator)
|
||||
Logger(f'检测到pos_cis_real实数张量,将其设置为参与分布式训练', accelerator)
|
||||
# 设置模型的_ddp_params_and_buffers_to_ignore属性
|
||||
model._ddp_params_and_buffers_to_ignore = {"pos_cis_real"}
|
||||
# model._ddp_params_and_buffers_to_ignore = {"pos_cis_real"}
|
||||
# 兼容旧版本,检查是否仍有pos_cis
|
||||
elif hasattr(model, "pos_cis"):
|
||||
Logger(f'检测到pos_cis复数张量,将其设置为不参与分布式训练', accelerator)
|
||||
|
Loading…
x
Reference in New Issue
Block a user