Friday, 22 September 2023

Why do I run out of memory when training with a large dataset, but have no problems with a small dataset?

I'm trying to build a keypoint detection system using Keras. I have a UNet-like model: a series of convolution, batch normalization, and max pooling layers, followed by a symmetric series of upsampling, convolution, and batch normalization layers, with skip connections. With 100 instances, I can call model.fit() without a problem. However, if I keep the model the same but use 500 instances, Keras crashes with an out-of-memory (OOM) exception. Why does this happen, and is there anything I can do to fix it?
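build_model itself isn't shown here; roughly speaking, though, the architecture it builds looks like this simplified sketch (the layer counts, kernel sizes, activations, and compile settings below are placeholders, not the actual code):

import tensorflow as tf
from tensorflow.keras import layers

def build_model_sketch(filters=50, stages=5, input_shape=(1024, 1024, 3)):
    """Simplified stand-in for build_model: conv/BN/pool encoder,
    upsample/conv/BN decoder, skip connections between matching scales."""
    inputs = layers.Input(shape=input_shape)
    x = inputs
    skips = []

    # Encoder: convolution -> batch normalization -> max pooling
    for _ in range(stages):
        x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
        x = layers.BatchNormalization()(x)
        skips.append(x)
        x = layers.MaxPooling2D(2)(x)

    # Decoder: upsampling -> skip connection -> convolution -> batch normalization
    for skip in reversed(skips):
        x = layers.UpSampling2D(2)(x)
        x = layers.Concatenate()([x, skip])
        x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
        x = layers.BatchNormalization()(x)

    # One heatmap channel per pixel for the keypoints
    outputs = layers.Conv2D(1, 1, activation="sigmoid")(x)

    model = tf.keras.Model(inputs, outputs)
    # Placeholder compile settings; the real ones differ
    model.compile(optimizer="adam", loss="binary_crossentropy")
    return model

The real build_model takes more parameters (filter_step, stage_steps, initial_convolutions, stacks), but the overall encoder/decoder structure is the same idea.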

Here's (what I think is) the relevant part of the code where I call model.fit:

import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, LearningRateScheduler

# build_model, step_decay, X, and y are defined elsewhere.
model = build_model(
    filters=50,
    filter_step=1,
    stages=5,
    stage_steps=1,
    initial_convolutions=0,
    stacks=1,
)

model.summary()

# Wrap the NumPy arrays in a tf.data pipeline, one sample per batch.
dataset = tf.data.Dataset.from_tensor_slices((X, y))
dataset = dataset.batch(1)

model.fit(
    dataset,
    epochs=2**7,
    callbacks=[
        EarlyStopping(monitor="loss", patience=5, min_delta=1e-7, start_from_epoch=10),
        LearningRateScheduler(step_decay),
    ],
)

X and y are NumPy arrays with the following shapes:

  • X: (100, 1024, 1024, 3)
  • y: (100, 1024, 1024)

100 here is the dataset size. If I increase it to 500 (or more), I get the out-of-memory exception. It looks to me as though Keras is trying to load the entire dataset into memory, despite my using from_tensor_slices and batch(1), so I'm clearly misunderstanding something.
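For scale, the raw arrays alone are already large. Assuming float32 and 500 instances with the shapes above, a back-of-the-envelope calculation gives:

n = 500
bytes_per_float32 = 4
x_bytes = n * 1024 * 1024 * 3 * bytes_per_float32   # X: (500, 1024, 1024, 3)
y_bytes = n * 1024 * 1024 * bytes_per_float32       # y: (500, 1024, 1024)
print(f"X ~ {x_bytes / 1024**3:.1f} GiB, y ~ {y_bytes / 1024**3:.1f} GiB")
# prints: X ~ 5.9 GiB, y ~ 2.0 GiB

If from_tensor_slices also keeps its own constant copy of those tensors on top of the original NumPy arrays (which is my reading of the documentation, though I may be wrong), that roughly doubles the footprint before training even starts.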



