I have a model that is too large for a single GPU. It contains a list of transformer blocks, and once the forward pass reaches the 17th hidden layer I get a CUDA out of memory error.
I therefore want to run it on multiple GPUs.
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

class SplitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Define your model segments
        self.segment1 = ...  # arbitrary input layer
        self.transformers = ...  # a torch.nn.ModuleList of transformers
        self.segment2 = ...  # arbitrary output layer
        self.loss_fn = nn.CrossEntropyLoss()
        self.device1 = torch.device('cuda:0')
        self.device2 = torch.device('cuda:1')

    def forward(self, x):
        # Forward pass for segment1 on device1
        x = self.segment1(x)
        for i, transformer in enumerate(self.transformers):
            # Report which device each sub-module of this transformer lives on
            current_device = '[' + ''.join(
                "{}: {} ".format(
                    name,
                    next(child.parameters()).device if list(child.parameters()) else "CPU",
                )
                for name, child in transformer.named_children()
            ) + ']'
            print("iterating through transformer {} on device {}".format(i, current_device))
            # Each transformer is a pair of (attention, feed-forward) blocks
            attn, ff = transformer
            x = attn(x) + x
            x = ff(x) + x
        # Forward pass for segment2 on device2
        x = self.segment2(x)
        return x

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        # Forward pass
        outputs = self(inputs)
        # Calculate loss using segment2 outputs
        loss = self.loss_fn(outputs, labels)
        # Log loss for monitoring (optional)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

model = SplitModel()
ddp_strategy = DDPStrategy(find_unused_parameters=True)
trainer = pl.Trainer(precision="16-mixed", accelerator="cuda", devices=[0, 1], strategy=ddp_strategy)
data_loader = ...  # some dataloader
trainer.fit(model, data_loader)
This would then be an example of the output:
iterating through transformer 0 on device [0: cuda:0]
iterating through transformer 1 on device [0: cuda:0 1: cuda:0]
iterating through transformer 2 on device [0: cuda:0 1: cuda:0]
iterating through transformer 3 on device [0: cuda:0 1: cuda:0]
iterating through transformer 4 on device [0: cuda:0 1: cuda:0]
iterating through transformer 5 on device [0: cuda:0 1: cuda:0]
iterating through transformer 6 on device [0: cuda:0 1: cuda:0]
iterating through transformer 7 on device [0: cuda:0 1: cuda:0]
iterating through transformer 8 on device [0: cuda:0 1: cuda:0]
iterating through transformer 9 on device [0: cuda:0 1: cuda:0]
iterating through transformer 10 on device [0: cuda:0 1: cuda:0]
iterating through transformer 11 on device [0: cuda:0 1: cuda:0]
iterating through transformer 12 on device [0: cuda:0 1: cuda:0]
iterating through transformer 13 on device [0: cuda:0 1: cuda:0]
iterating through transformer 14 on device [0: cuda:0 1: cuda:0]
iterating through transformer 15 on device [0: cuda:0 1: cuda:0]
iterating through transformer 16 on device [0: cuda:0 1: cuda:0]
iterating through transformer 17 on device [0: cuda:0 1: cuda:0]
CUDA out of memory error
However, if I add the following code to the forward pass:
self.segment2 = self.segment2.to(self.device2)
for i, transformer in enumerate(self.transformers):
    if i == 17:
        x = x.to(self.device2)
    if i > 16:
        transformer = transformer.to(self.device2)
    # ...the rest of iterating through the transformers
return self.segment2(x).to(self.device1)
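Put together, the modified forward pass I am describing would look roughly like this (just a sketch; segment1, segment2 and the transformer list are the same placeholders as above):

def forward(self, x):
    x = self.segment1(x)
    # segment2 is moved to the second GPU up front
    self.segment2 = self.segment2.to(self.device2)
    for i, transformer in enumerate(self.transformers):
        if i == 17:
            # from here on the activations live on the second GPU
            x = x.to(self.device2)
        if i > 16:
            # ...and so do the remaining transformer blocks
            transformer = transformer.to(self.device2)
        attn, ff = transformer
        x = attn(x) + x
        x = ff(x) + x
    # run segment2 on the second GPU and move the result back
    return self.segment2(x).to(self.device1)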
Then I no longer get a CUDA out of memory error. However, I do get the following error from the backward pass:
RuntimeError: grad.device() == bucket_view.device() INTERNAL ASSERT FAILED at "../torch/csrc/distributed/c10d/reducer.cpp":314, please report a bug to PyTorch.
I have also looked into sharding the model instead of manually deciding which parts to put on each GPU. The strategy in the pl.Trainer would be strategy="fsdp", but with that I get an error about the batch norm variables, one being a torch.cuda.FloatTensor and the other a torch.cuda.HalfTensor.
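Concretely, that sharded attempt looked roughly like this (the exact arguments may not be identical to what I ran):

# Let Lightning's FSDP strategy shard the model instead of placing layers by hand;
# this is where the FloatTensor vs. HalfTensor batch-norm error shows up with mixed precision
trainer = pl.Trainer(
    precision="16-mixed",
    accelerator="cuda",
    devices=[0, 1],
    strategy="fsdp",
)
trainer.fit(model, data_loader)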
Is there maybe a way to do this where I create a custom backward layer that changes the device manually?
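For example, something along the lines of a custom torch.autograd.Function that moves the activations to the second GPU in forward and sends the gradient back to the first GPU in backward (untested sketch of what I mean):

class MoveToDevice(torch.autograd.Function):
    # Moves a tensor to a target device in forward and returns the incoming
    # gradient to the tensor's original device in backward.

    @staticmethod
    def forward(ctx, x, device):
        ctx.source_device = x.device
        return x.to(device)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient for the device argument
        return grad_output.to(ctx.source_device), None

# used inside forward(), e.g. right before transformer 17:
# x = MoveToDevice.apply(x, self.device2)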