分布式训练中的参数local_rank
local_rank
是一个常用于分布式训练中的参数,用于指示当前进程的本地编号。它帮助在分布式环境中区分不同的进程。通常情况下,local_rank
的值为 -1 表示不进行分布式训练,值为 0 表示第一个(主)进程,其它正数表示其它辅助进程。
在分布式训练中,我们常常需要确保某些操作(例如下载模型和词汇表)只由一个进程完成,以避免重复工作和资源浪费。以下是 local_rank
在不同情况下的用法解释:
-
local_rank == -1
: - 表示不进行分布式训练。代码在单机单卡(或 CPU)模式下运行。
-
local_rank == 0
: - 表示主进程。在多机多卡或单机多卡模式下,通常第一个进程(local_rank 为 0)负责一些需要全局唯一的操作,例如下载模型和词汇表。
-
local_rank > 0
: - 表示其它辅助进程。用于分布式训练的其它进程。
在你的代码中,local_rank
不在 [-1, 0]
中表示所有非主进程或非单机单卡模式的进程。通过这种检查,我们可以确保只有主进程或非分布式模式下才执行某些初始化操作。
示例代码
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # 确保只有第一个进程在分布式训练中下载模型和词汇表
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
do_lower_case=args.do_lower_case,
cache_dir=args.cache_dir if args.cache_dir else None,
)
if args.local_rank == 0:
torch.distributed.barrier() # 确保只有第一个进程在分布式训练中下载模型和词汇表
解释
-
检查
local_rank
是否在[-1, 0]
中:if args.local_rank not in [-1, 0]: torch.distributed.barrier()
- 如果
local_rank
不在[-1, 0]
中,表示这是一个辅助进程或在分布式训练模式下运行。torch.distributed.barrier()
是一个同步屏障,确保在所有进程到达这一点之前不会继续执行。这样可以确保在继续执行之前,所有辅助进程等待主进程下载模型和词汇表。 -
加载分词器:
tokenizer = AutoTokenizer.from_pretrained( args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case, cache_dir=args.cache_dir if args.cache_dir else None, )
-
主进程的同步屏障:
if args.local_rank == 0: torch.distributed.barrier()
- 如果
local_rank
为 0,表示这是主进程。再次调用torch.distributed.barrier()
,确保在主进程完成下载模型和词汇表后,辅助进程才继续执行。
完整示例
为了更好地理解这一点,这里是一个更详细的例子:
import torch
from transformers import AutoTokenizer
class Args:
def __init__(self):
self.local_rank = 0 # 设为0表示主进程,可以改为-1表示非分布式训练
self.tokenizer_name = None
self.model_name_or_path = 'bert-base-uncased'
self.do_lower_case = True
self.cache_dir = './cache'
self.output_dir = './output'
self.do_train = True
args = Args()
# 初始化分布式训练环境(模拟)
if args.local_rank != -1:
torch.distributed.init_process_group(backend='nccl', rank=args.local_rank, world_size=1)
if args.do_train:
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # 确保辅助进程等待主进程下载模型和词汇表
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
do_lower_case=args.do_lower_case,
cache_dir=args.cache_dir if args.cache_dir else None,
)
if args.local_rank == 0:
torch.distributed.barrier() # 确保辅助进程等待主进程下载模型和词汇表
# 继续后续的训练步骤...
print("Tokenizer loaded successfully.")
# 清理分布式环境
if args.local_rank != -1:
torch.distributed.destroy_process_group()
这个示例展示了如何在分布式训练中使用 local_rank
来确保模型和词汇表的下载只在主进程中进行,并同步所有进程。
作者:挨打且不服66