I'm trying to build a keypoint detection system using Keras. I have a U-Net-like model: a series of convolution, batch normalization, and max pooling layers, followed by a symmetric series of upsampling, convolution, and batch normalization layers (plus skip connections). When given 100 instances, I can call `model.fit()` without a problem. However, if I leave the model unchanged but use 500 instances, Keras crashes with an OOM exception. Why does this happen, and is there anything I can do to fix it?
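In case it helps, here's a simplified sketch of the kind of builder I'm using. This is not my exact `build_model` (which also takes `filter_step`, `stage_steps`, `initial_convolutions`, and `stacks`, as in the call below); it just illustrates the encoder/decoder structure:

```python
import tensorflow as tf
from tensorflow.keras import layers

# Simplified sketch; the real builder also handles filter_step,
# stage_steps, initial_convolutions, and stacks.
def build_model(filters=50, stages=5, input_shape=(1024, 1024, 3)):
    inputs = layers.Input(shape=input_shape)
    x = inputs
    skips = []

    # Encoder: conv -> batch norm -> max pool at each stage, keeping
    # the pre-pool activation for the skip connection.
    for _ in range(stages):
        x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
        x = layers.BatchNormalization()(x)
        skips.append(x)
        x = layers.MaxPooling2D(2)(x)

    # Decoder: upsample -> conv -> batch norm, concatenating the
    # matching encoder activation at each stage.
    for skip in reversed(skips):
        x = layers.UpSampling2D(2)(x)
        x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Concatenate()([x, skip])

    # One channel per pixel for the keypoint heatmap.
    outputs = layers.Conv2D(1, 1, activation="sigmoid")(x)
    model = tf.keras.Model(inputs, outputs)
    # Optimizer/loss here are placeholders, not necessarily what I use.
    model.compile(optimizer="adam", loss="binary_crossentropy")
    return model
```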
Here's (what I think is) the relevant part of the code where I call `model.fit()`:
```python
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, LearningRateScheduler

model = build_model(
    filters=50,
    filter_step=1,
    stages=5,
    stage_steps=1,
    initial_convolutions=0,
    stacks=1,
)
print(model.summary())

dataset = tf.data.Dataset.from_tensor_slices((X, y))
dataset = dataset.batch(1)

model.fit(
    dataset,
    epochs=2**7,
    callbacks=[
        EarlyStopping(monitor="loss", patience=5, min_delta=1e-7, start_from_epoch=10),
        LearningRateScheduler(step_decay),
    ],
)
```
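`step_decay` isn't shown above; it's just a typical step-decay learning-rate schedule, something like:

```python
# Approximately my schedule: halve the learning rate every 20 epochs.
def step_decay(epoch):
    initial_lr, drop, epochs_per_drop = 1e-3, 0.5, 20
    return initial_lr * (drop ** (epoch // epochs_per_drop))
```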
`X` and `y` are NumPy arrays with the following shapes:

- `X`: (100, 1024, 1024, 3)
- `y`: (100, 1024, 1024)
100 here is the dataset size. If I increase it to 500 (or more), I get the out-of-memory exception. It appears to me that Keras is perhaps trying to load the entire dataset into memory, despite my use of `from_tensor_slices` and `batch(1)`, so I'm clearly misunderstanding something.
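Back-of-the-envelope (assuming the arrays are float32): at 500 instances, `X` alone is 500 × 1024 × 1024 × 3 × 4 bytes ≈ 6.3 GB, plus roughly 2.1 GB for `y`, versus about 1.3 GB + 0.4 GB at 100 instances. That would be consistent with the whole dataset being materialized at once. One workaround I've been considering is streaming examples from a generator instead of `from_tensor_slices`; an untested sketch (dtypes assumed, `X` and `y` as above):

```python
import tensorflow as tf

# Untested sketch: yield one example at a time instead of handing
# the whole NumPy arrays to from_tensor_slices.
def examples():
    for i in range(len(X)):
        yield X[i], y[i]

dataset = tf.data.Dataset.from_generator(
    examples,
    output_signature=(
        tf.TensorSpec(shape=(1024, 1024, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(1024, 1024), dtype=tf.float32),
    ),
)
dataset = dataset.batch(1)
```

But I don't know whether this addresses the actual cause, or whether the arrays are even the problem here.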