I want the proper, official, bug-free way to do the following:
- resume from a checkpoint to continue training on multiple GPUs
- save a checkpoint correctly during training with multiple GPUs
My guess is the following:
- to do 1, have all the processes load the checkpoint from the file, then call DDP(mdl) in each process. I assume the checkpoint was saved with ddp_mdl.module.state_dict().
- to do 2, simply check who has rank == 0 and have only that process do torch.save({'model': ddp_mdl.module.state_dict()}).
Approximate code:
def save_ckpt(rank, ddp_model, optimizer, path):
    # only rank 0 writes the checkpoint; save the unwrapped module's weights
    if rank == 0:
        state = {'model': ddp_model.module.state_dict(),
                 'optimizer': optimizer.state_dict()}
        torch.save(state, path)

def load_ckpt(path, distributed, map_location=torch.device('cpu')):
    # load to CPU first so the processes don't all step onto the same GPU
    checkpoint = torch.load(path, map_location=map_location)
    model = Net(...)
    optimizer = ...
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    if distributed:
        model = DDP(model, device_ids=[gpu], find_unused_parameters=True)
    return model, optimizer
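To make the intended usage concrete, here is a rough sketch of how I picture wiring these into the training loop (setup/cleanup are the process-group helpers from the official tutorial; Net, train_one_epoch and num_epochs are placeholders of mine, not anything official):

def demo_train(rank, world_size, ckpt_path, resume=False):
    setup(rank, world_size)            # init_process_group, as in the tutorial
    torch.cuda.set_device(rank)

    model = Net(...).to(rank)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    if resume:
        # every rank reads the file; mapping to CPU avoids piling onto GPU 0
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    ddp_model = DDP(model, device_ids=[rank])

    for epoch in range(num_epochs):
        train_one_epoch(ddp_model, optimizer)        # placeholder
        save_ckpt(rank, ddp_model, optimizer, ckpt_path)
        dist.barrier()   # don't let other ranks run ahead while rank 0 writes

    cleanup()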
Is this correct?
One of the reasons I am asking is that distributed code can go subtly wrong, and I want to make sure that does not happen to me. Of course I want to avoid deadlocks, but those would be obvious if they happened (e.g. it could happen if all the processes somehow tried to open the same checkpoint file at the same time; in that case I'd make sure only one of them loads it at a time, or have rank 0 alone load it and then send it to the rest of the processes, as in the sketch below).
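For reference, here is a minimal sketch of that second option (rank 0 reads the file and broadcasts it to the other ranks), assuming the process group is already initialized. broadcast_object_list pickles the checkpoint, so this only makes sense for state dicts that fit comfortably in memory; the function name and arguments are my own:

import torch
import torch.distributed as dist

def load_ckpt_rank0_then_broadcast(path, rank, model, optimizer):
    # only rank 0 touches the file on disk
    obj = [torch.load(path, map_location='cpu')] if rank == 0 else [None]
    # broadcast the (pickled) checkpoint dict from rank 0 to every other rank;
    # with the NCCL backend, make sure torch.cuda.set_device(rank) was called first
    dist.broadcast_object_list(obj, src=0)
    checkpoint = obj[0]
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])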
I am also asking because the official docs don't make sense to me. I will paste their code and explanation since links can die sometimes:
Save and Load Checkpoints

It’s common to use torch.save and torch.load to checkpoint modules during training and recover from checkpoints. See SAVING AND LOADING MODELS for more details. When using DDP, one optimization is to save the model in only one process and then load it to all processes, reducing write overhead. This is correct because all processes start from the same parameters and gradients are synchronized in backward passes, and hence optimizers should keep setting parameters to the same values. If you use this optimization, make sure all processes do not start loading before the saving is finished. Besides, when loading the module, you need to provide an appropriate map_location argument to prevent a process to step into others’ devices. If map_location is missing, torch.load will first load the module to CPU and then copy each parameter to where it was saved, which would result in all processes on the same machine using the same set of devices. For more advanced failure recovery and elasticity support, please refer to TorchElastic.
def demo_checkpoint(rank, world_size):
    print(f"Running DDP checkpoint example on rank {rank}.")
    setup(rank, world_size)

    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    if rank == 0:
        # All processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes.
        # Therefore, saving it in one process is sufficient.
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure that process 1 loads the model after process
    # 0 saves it.
    dist.barrier()
    # configure map_location properly
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn = nn.MSELoss()
    loss_fn(outputs, labels).backward()
    optimizer.step()

    # Not necessary to use a dist.barrier() to guard the file deletion below
    # as the AllReduce ops in the backward pass of DDP already served as
    # a synchronization.
    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()
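One difference I notice between the tutorial and my sketch: the tutorial saves ddp_model.state_dict(), whose keys carry a 'module.' prefix and therefore load cleanly only into another DDP-wrapped model, whereas my save_ckpt saves ddp_model.module.state_dict(), which has plain keys. As far as I understand, to load a DDP-saved checkpoint into a plain (non-DDP) model you would have to strip the prefix yourself, roughly like this (my own workaround, not from the tutorial):

# a checkpoint saved via ddp_model.state_dict() has keys like 'module.net1.weight'
state = torch.load(CHECKPOINT_PATH, map_location='cpu')
state = {k[len('module.'):] if k.startswith('module.') else k: v
         for k, v in state.items()}
plain_model = ToyModel()        # the unwrapped model class from the tutorial
plain_model.load_state_dict(state)

# saving ddp_model.module.state_dict() instead avoids the prefix entirely,
# which is why my save_ckpt above goes through the .module attribute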
Related:
- https://discuss.pytorch.org/t/checkpointing-ddp-module-instead-of-ddp-itself/115714
- https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
- https://discuss.pytorch.org/t/ddp-and-gradient-checkpointing/132244
- https://github.com/pytorch/pytorch/issues/23138
- What is the proper way to checkpoint during training when using distributed data parallel (DDP) in PyTorch? https://discuss.pytorch.org/t/what-is-the-proper-way-to-checkpoint-during-training-when-using-distributed-data-parallel-ddp-in-pytorch/139575