PyTorch Parallel Computing (nn.parallel.DistributedDataParallel)
This post is the follow-up to the one on nn.DataParallel; readers who want the nn.DataParallel material can refer to that earlier post.
Why use nn.parallel.DistributedDataParallel? First, look at what the official PyTorch documentation says about nn.DataParallel:
WARNING
It is recommended to use DistributedDataParallel, instead of this class, to do multi-GPU training, even if there is only a single node. See: Use nn.parallel.DistributedDataParallel instead of multiprocessing or nn.DataParallel and Distributed Data Parallel.
In other words, even for single-node multi-GPU training, DistributedDataParallel is recommended, which brings us to the differences between the two:
nn.DataParallel: can only be used for single-node multi-GPU training, and cannot be combined with apex acceleration
nn.parallel.DistributedDataParallel: can be used for both single-node multi-GPU and multi-node multi-GPU training, and can be combined with apex acceleration
What is apex?
apex is a PyTorch extension maintained by NVIDIA that supports mixed-precision and distributed training. It can speed up convergence and also save GPU memory. Since this post is about parallel computing, apex is not covered in any detail here.
Alright, let's begin.
I. Why parallel computing?
When training on a large dataset or with a very large model, a single GPU often cannot hold everything; the original AlexNet, for example, was trained across two GPUs. Parallel computation generally follows one of two strategies. In model parallelism, different parts of the model are placed on different GPUs and their outputs are combined. In data parallelism, different portions of the data are placed on different GPUs, each GPU computes on its own slice, and the results are aggregated. Data parallelism not only speeds up computation and lets us use a larger effective batch size (so each epoch needs fewer iterations), the larger batch also gives more stable gradient estimates.
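To make the contrast concrete, here is a minimal sketch of model parallelism, assuming a machine with two GPUs named cuda:0 and cuda:1; the two-layer model and the layer sizes are made up for illustration.

import torch
import torch.nn as nn

class TwoGPUModel(nn.Module):
    def __init__(self):
        super().__init__()
        # first half of the model lives on GPU 0, second half on GPU 1
        self.part1 = nn.Linear(1024, 512).to('cuda:0')
        self.part2 = nn.Linear(512, 10).to('cuda:1')

    def forward(self, x):
        x = self.part1(x.to('cuda:0'))
        # activations are moved between GPUs on the way through the model
        return self.part2(x.to('cuda:1'))

model = TwoGPUModel()
out = model(torch.randn(8, 1024))   # the output ends up on cuda:1

Data parallelism instead replicates the whole model on every GPU and gives each replica a different slice of the batch; that is exactly what nn.DataParallel and DistributedDataParallel automate, and it is the strategy used in the rest of this post.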
II. Basic concepts
nn.parallel.DistributedDataParallel involves a few recurring parameters; here is a brief explanation.
These concepts matter; without them the rest of the post will be confusing.
Multi-node, multi-GPU
world_size: the total number of processes taking part in training; with the usual one-process-per-GPU setup this is (number of machines) × (GPUs per machine)
rank: the global index of the current process across all machines
local_rank: the index of the GPU (process) within its own machine
Single-node, multi-GPU
world_size: the total number of GPUs (processes) on the machine
rank: the index of the GPU/process
local_rank: the index of the GPU/process, identical to rank
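To see where these values come from at runtime, here is a small sketch (the helper name describe_process is made up) of how they are usually read inside a process once the process group has been initialized; launchers such as torch.distributed.launch or torchrun export LOCAL_RANK as an environment variable.

import os
import torch.distributed as dist

def describe_process():
    # total number of processes in the job, across all machines
    world_size = dist.get_world_size()
    # global index of this process, 0 .. world_size - 1
    rank = dist.get_rank()
    # index of this process on its own machine; falls back to rank on a single node
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    print(f"world_size={world_size}, rank={rank}, local_rank={local_rank}")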
III. Using DistributedDataParallel
There are two ways to use nn.parallel.DistributedDataParallel. One goes through torch.multiprocessing; its advantage is that nothing extra has to be added on the command line. The other goes through torch.distributed and requires launching the script from the terminal with
python -m torch.distributed.launch --nproc_per_node=8 --xx --xx train.py
where the options after "--" are additional arguments; its advantage is that it is simpler to set up than the first method. Both methods are described below (it may help to read the second one first, since it is the simpler of the two).
1. multiprocessing
There is an official example of this approach that the walk-through below follows.
Step 1: import the required modules. For the multiprocessing approach these are:
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.multiprocessing import Process
Step 2: define your own arguments with argparse:
import argparse

parser = argparse.ArgumentParser()
'''    ...your params    '''
'''    ...distributed params    '''
parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')
# whether to enable SyncBatchNorm
parser.add_argument('--syncBN', type=bool, default=True)
# number of processes to start; when using the launch utility this does not need to be
# set by hand, it is derived from nproc_per_node automatically
parser.add_argument('--world-size', default=4, type=int, help='number of distributed processes')
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
opt = parser.parse_args()
Step 3: call mp.spawn, the function that launches the parallel GPU processes. Its signature is:
torch.multiprocessing.spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn')
fn: the function to run, defined beforehand; usually this is the main function, with signature main(rank, *args). rank is the process index (for single-node multi-GPU it can be read as the GPU index) and *args are the extra arguments passed to the function.
args: the arguments passed to fn, as a tuple
nprocs: number of processes to spawn
join: whether to perform a blocking join on all processes; the call returns None if join is True and a ProcessContext if join is False
Spawns nprocs processes that run fn with args.
If one of the processes exits with a non-zero exit status, the remaining processes are killed and an exception is raised with the cause of termination. In the case an exception was caught in the child process, it is forwarded and its traceback is included in the exception raised in the parent process.
Call it like this:
mp.spawn(main,
         args=(opt,),
         nprocs=opt.world_size,
         join=True)
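As a sanity check before the full training script, here is a minimal self-contained sketch of what mp.spawn does; the function name demo_main, the address "localhost", and the port "12355" are made-up placeholders. It starts nprocs copies of the function, each receiving its own rank as the first argument.

import os
import torch.distributed as dist
import torch.multiprocessing as mp

def demo_main(rank, world_size):
    # each spawned process runs this function with its own rank
    os.environ["MASTER_ADDR"] = "localhost"   # placeholder address
    os.environ["MASTER_PORT"] = "12355"       # placeholder free port
    # "nccl" needs one GPU per process; use "gloo" to try this on CPU only
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    print(f"hello from rank {rank} of {world_size}")
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2                             # e.g. two GPUs
    mp.spawn(demo_main, args=(world_size,), nprocs=world_size, join=True)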
Step 4: define the main function.
The data-loading pipeline becomes: Dataset → DistributedSampler → BatchSampler → DataLoader (DistributedSampler and BatchSampler both come from torch.utils.data).
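Before the full main() below, a tiny standalone sketch of just this pipeline may help show how DistributedSampler splits the indices between ranks; the dataset size, num_replicas and rank values are made up, and they are passed explicitly only so the sketch runs without init_process_group.

import torch
from torch.utils.data import TensorDataset, BatchSampler, DataLoader
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(800).float())         # 800 samples
# normally DistributedSampler reads num_replicas/rank from the process group
sampler = DistributedSampler(dataset, num_replicas=8, rank=0)
print(len(sampler))                                         # -> 100 indices for this rank
batch_sampler = BatchSampler(sampler, batch_size=16, drop_last=True)
loader = DataLoader(dataset, batch_sampler=batch_sampler)
print(len(loader))                                          # -> 6 batches (100 // 16)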
# (the snippet below also assumes: import os, math, tempfile, torch,
#  torch.optim as optim, from torch.optim import lr_scheduler;
#  MyDataSet, train_one_epoch, evaluate and cleanup are user-defined helpers)
def main(rank, args):
    '''    initialize the environment for this process    '''
    args.rank = rank
    args.gpu = rank            # one process per GPU, so the GPU index equals the rank
    args.distributed = True    # args.world_size was already set through argparse

    torch.cuda.set_device(args.gpu)
    device = torch.device(args.device)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)
    dist.barrier()

    # printing and filesystem checks happen only on the main process
    if rank == 0:
        print(args)
        if os.path.exists("xx") is False:
            os.makedirs("xx")

    '''    parameters needed before training    '''

    '''    datasets    '''
    # build the datasets
    train_datasets = MyDataSet(xxx)
    val_datasets = MyDataSet(xxx)

    # assign each rank's process its own subset of sample indices,
    # e.g. 800 samples on 8 GPUs -> 100 samples per GPU
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_datasets)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_datasets)

    # each GPU got 100 samples; with batch_size=16 that is 100 / 16 = 6 batches with 4 left over.
    # drop_last=True discards those 4 samples; drop_last=False would keep them as one smaller batch
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)

    # number of dataloader workers, printed only on rank 0
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    if rank == 0:
        print('Using {} dataloader workers every process'.format(nw))

    # pin_memory speeds up copying the data to the GPU
    # the validation set uses no batch sampler, so batch_size and sampler are passed directly
    train_dataloader = torch.utils.data.DataLoader(train_datasets,
                                                   batch_sampler=train_batch_sampler,
                                                   pin_memory=True,
                                                   num_workers=nw,
                                                   collate_fn=train_datasets.collate_fn)
    val_dataloader = torch.utils.data.DataLoader(val_datasets,
                                                 batch_size=batch_size,
                                                 sampler=val_sampler,
                                                 pin_memory=True,
                                                 num_workers=nw,
                                                 collate_fn=val_datasets.collate_fn)

    '''    load the model    '''
    model = model(xxx).to(device)

    # load pretrained weights if they exist
    if os.path.exists(weights_path):
        weights_dict = torch.load(weights_path, map_location=device)
        load_weights_dict = {k: v for k, v in weights_dict.items()
                             if model.state_dict()[k].numel() == v.numel()}
        model.load_state_dict(load_weights_dict, strict=False)
    else:
        checkpoint_path = os.path.join(tempfile.gettempdir(), "initial_weights.pt")
        # without pretrained weights, save the weights of the rank-0 process and let the other
        # processes load them, so that every process starts from identical initial weights
        if rank == 0:
            torch.save(model.state_dict(), checkpoint_path)
        dist.barrier()
        # note: map_location must be specified here, otherwise the first GPU ends up using extra memory
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # optionally freeze weights
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # freeze everything except the final fully-connected layer
            if "fc" not in name:
                para.requires_grad_(False)
    else:
        # SyncBatchNorm is only worth using for networks that contain BN layers
        if args.syncBN:
            # training is slower with SyncBatchNorm
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)

    # wrap the model in DDP
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

    # optimizer
    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=0.005)
    # Scheduler: cosine learning-rate schedule, https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    '''    train    '''
    for epoch in range(args.epochs):
        train_sampler.set_epoch(epoch)   # reshuffle differently every epoch

        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_dataloader,
                                    device=device,
                                    epoch=epoch)
        scheduler.step()

        sum_num = evaluate(model=model,
                           data_loader=val_dataloader,
                           device=device)
        acc = sum_num / val_sampler.total_size

        if rank == 0:   # print / save only on the main process
            print()
            torch.save(xxx)

    # remove the temporary checkpoint file
    if rank == 0:
        if os.path.exists(checkpoint_path) is True:
            os.remove(checkpoint_path)

    cleanup()
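cleanup() is not defined in the snippet above; a minimal definition, assuming it does nothing more than tear down the process group, would be:

def cleanup():
    # destroy the default process group so every process can exit cleanly
    dist.destroy_process_group()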
Before dist.init_process_group() is called in the initialization above, the following environment variables must be set, otherwise an error is raised:
"""
Error:
ValueError: Error initializing torch.distributed using env:// rendezvous: environment variable MASTER_ADDR expected, but not set
"""
# add the following statements
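A typical choice is to export MASTER_ADDR and MASTER_PORT before the init call; the address and port values below are placeholders (on a single machine the address is simply localhost, and the port can be any free port).

import os

# tell the env:// rendezvous where the rank-0 process is listening
os.environ["MASTER_ADDR"] = "localhost"   # placeholder: address of the rank-0 machine
os.environ["MASTER_PORT"] = "12355"       # placeholder: any free port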
