Wednesday 4 August 2021

What is the best way to load data with tf.data.Dataset in a memory-efficient way?

I'm trying to load data for training (optimizing) a model for object detection + instance segmentation. However, tf.data.Dataset is giving me a bit of a headache when loading the instance segmentation masks: it uses all the memory on the server (more than 128 GB) even with a small dataset.

Is there a way to load the data in a more memory-efficient way? Right now we are using this code:

train_dataset, train_examples = dataset.load_train_datasets()
ds = (
    train_dataset.shuffle(min(100, train_examples), reshuffle_each_iteration=True)  # shuffle raw records
    .map(dataset.decode, num_parallel_calls=args.num_parallel_calls)  # decode a single example
    .map(train_processing.prepare_for_batch, num_parallel_calls=args.num_parallel_calls)  # per-sample masks and boxes
    .batch(args.batch_size)
    .map(train_processing.preprocess_batch, num_parallel_calls=args.num_parallel_calls)  # final per-batch preprocessing
    .prefetch(AUTOTUNE)  # AUTOTUNE = tf.data.AUTOTUNE
)

The problem is that the second map call (train_processing.prepare_for_batch, which takes a single element) and the third one (train_processing.preprocess_batch, which takes a batch of elements) create a lot of binary segmentation masks, and those masks are what use up all the memory.
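For a rough sense of scale, here is a back-of-the-envelope estimate of the mask footprint (all numbers below are illustrative assumptions, not our real configuration):

    # purely illustrative numbers for dense per-instance binary masks
    height, width = 1024, 1024       # assumed image resolution
    instances_per_image = 50         # assumed number of instances per image
    batch_size = 8
    bytes_per_value = 1              # uint8 mask

    mask_bytes = height * width * instances_per_image * batch_size * bytes_per_value
    print(f"~{mask_bytes / 1e9:.2f} GB of masks per batch")  # ~0.42 GB for a single batch

With float32 masks, or with several batches in flight at once because of num_parallel_calls and prefetch, this multiplies quickly.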

Is there a way to reorganize the mapping functions to save memory? I was thinking of something like the sketch shown after this list:

1. take the first 100 samples
2. decode the samples
3. prepare the masks and bounding boxes for one sample
4. take a batch of them
5. do the final per-batch preparation of the data
6. fit one step / one batch of data
7. clean that data from memory
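Roughly like this; it reuses the same helper modules (dataset, train_processing) and args from above, while train_step, model, and the (images, targets) element structure are hypothetical placeholders:

    import tensorflow as tf

    train_dataset, train_examples = dataset.load_train_datasets()

    ds = (
        train_dataset.shuffle(min(100, train_examples), reshuffle_each_iteration=True)  # 1. small buffer of raw records
        .map(dataset.decode, num_parallel_calls=1)                      # 2. decode one sample at a time
        .map(train_processing.prepare_for_batch, num_parallel_calls=1)  # 3. masks/boxes for one sample
        .batch(args.batch_size)                                         # 4. take a batch of them
        .map(train_processing.preprocess_batch, num_parallel_calls=1)   # 5. final per-batch preparation
        .prefetch(1)                                                    # keep only one prepared batch ahead
    )

    for images, targets in ds:                # 6. fit one step / one batch
        train_step(model, images, targets)    # hypothetical training step
        # 7. the batch tensors go out of scope here and can be freed

The idea would be that tf.data already streams batches lazily, so the memory alive at any moment is roughly the shuffle buffer plus whatever elements num_parallel_calls and prefetch keep in flight; dialing those down trades throughput for memory.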



from What is the best way to load data with tf.data.Dataset in memory efficient way
