分布式训练中的参数local_rank

local_rank 是一个常用于分布式训练中的参数,用于指示当前进程的本地编号。它帮助在分布式环境中区分不同的进程。通常情况下,local_rank 的值为 -1 表示不进行分布式训练,值为 0 表示第一个(主)进程,其它正数表示其它辅助进程。

在分布式训练中,我们常常需要确保某些操作(例如下载模型和词汇表)只由一个进程完成,以避免重复工作和资源浪费。以下是 local_rank 在不同情况下的用法解释:

  1. local_rank == -1

  2. 表示不进行分布式训练。代码在单机单卡(或 CPU)模式下运行。
  3. local_rank == 0

  4. 表示主进程。在多机多卡或单机多卡模式下,通常第一个进程(local_rank 为 0)负责一些需要全局唯一的操作,例如下载模型和词汇表。
  5. local_rank > 0

  6. 表示其它辅助进程。用于分布式训练的其它进程。

在你的代码中,local_rank 不在 [-1, 0] 中表示所有非主进程或非单机单卡模式的进程。通过这种检查,我们可以确保只有主进程或非分布式模式下才执行某些初始化操作。

示例代码

if args.local_rank not in [-1, 0]:
    torch.distributed.barrier()  # 确保只有第一个进程在分布式训练中下载模型和词汇表

tokenizer = AutoTokenizer.from_pretrained(
    args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
    do_lower_case=args.do_lower_case,
    cache_dir=args.cache_dir if args.cache_dir else None,
)

if args.local_rank == 0:
    torch.distributed.barrier()  # 确保只有第一个进程在分布式训练中下载模型和词汇表

解释

  1. 检查 local_rank 是否在 [-1, 0]

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    
  2. 如果 local_rank 不在 [-1, 0] 中,表示这是一个辅助进程或在分布式训练模式下运行。torch.distributed.barrier() 是一个同步屏障,确保在所有进程到达这一点之前不会继续执行。这样可以确保在继续执行之前,所有辅助进程等待主进程下载模型和词汇表。
  3. 加载分词器

    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    
  4. 主进程的同步屏障

    if args.local_rank == 0:
        torch.distributed.barrier()
    
  5. 如果 local_rank 为 0,表示这是主进程。再次调用 torch.distributed.barrier(),确保在主进程完成下载模型和词汇表后,辅助进程才继续执行。

完整示例

为了更好地理解这一点,这里是一个更详细的例子:

import torch
from transformers import AutoTokenizer

class Args:
    def __init__(self):
        self.local_rank = 0  # 设为0表示主进程,可以改为-1表示非分布式训练
        self.tokenizer_name = None
        self.model_name_or_path = 'bert-base-uncased'
        self.do_lower_case = True
        self.cache_dir = './cache'
        self.output_dir = './output'
        self.do_train = True

args = Args()

# 初始化分布式训练环境(模拟)
if args.local_rank != -1:
    torch.distributed.init_process_group(backend='nccl', rank=args.local_rank, world_size=1)

if args.do_train:
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # 确保辅助进程等待主进程下载模型和词汇表

    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    if args.local_rank == 0:
        torch.distributed.barrier()  # 确保辅助进程等待主进程下载模型和词汇表

    # 继续后续的训练步骤...

    print("Tokenizer loaded successfully.")

# 清理分布式环境
if args.local_rank != -1:
    torch.distributed.destroy_process_group()

这个示例展示了如何在分布式训练中使用 local_rank 来确保模型和词汇表的下载只在主进程中进行,并同步所有进程。

作者:挨打且不服66

物联沃分享整理
物联沃-IOTWORD物联网 » 分布式训练中的参数local_rank

发表回复