Thursday, 3 November 2022

How to sample batch from only one class at each iteration in PyTorch

I want to train a classifier on the ImageNet dataset (1000 classes) and I need each batch to contain 64 images from the same class, with consecutive batches drawn from different classes. Is it possible to create a DataLoader in PyTorch that implements this functionality?

The same question was asked here, but for TensorFlow. In one of the answers, it is suggested to:

  1. create one dataset per class,

  2. batch each of these datasets — each batch has samples from a single class,

  3. zip all of them into one big dataset of batches,

  4. shuffle this dataset — the shuffling will occur on the batches, not on the samples, so it won't change the fact that batches are single class.

Could someone please provide a PyTorch example or a code sketch for the aforementioned steps? Alternatively, is there a more efficient way using the existing samplers?
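
To make the steps concrete, here is a rough sketch of how I imagine they could map to PyTorch. The helper single_class_batches is only illustrative, and it assumes the dataset exposes per-sample labels through a targets list, as ImageFolder does:

import random

def single_class_batches(dataset, batch_size):
    # step 1: group sample indices by class
    # (ImageFolder keeps the per-sample class indices in dataset.targets)
    per_class = {}
    for idx, label in enumerate(dataset.targets):
        per_class.setdefault(label, []).append(idx)

    # step 2: batch each class on its own, so every batch is single-class
    # (incomplete tail batches are dropped)
    batches = []
    for indices in per_class.values():
        random.shuffle(indices)
        for i in range(0, len(indices) - batch_size + 1, batch_size):
            batches.append(indices[i:i + batch_size])

    # steps 3 and 4: merge the per-class batches and shuffle the order of the
    # batches, not the samples, so each batch stays single-class
    random.shuffle(batches)
    return batches

# usage sketch: hand the pre-built index batches to DataLoader as batch_sampler
# train_loader = torch.utils.data.DataLoader(
#     train_dataset, batch_sampler=single_class_batches(train_dataset, 64))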

I will try to gradually fill in the needed code, starting from the main function. Any help is highly appreciated.

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import random
import argparse
import torch
import os


class DS(Dataset):
    def __init__(self, data):
        super(DS, self).__init__()
        self.data = data

        # FakeData exposes the number of classes via .num_classes; ImageFolder
        # exposes the class names via .classes instead
        num_classes = getattr(data, 'num_classes', None) or len(data.classes)

        # create a list of lists, where every sublist contains the indices of
        # the samples that belong to that class_label
        # (note: iterating over the dataset loads every sample just to read its
        # label; for ImageFolder the labels are also available via data.targets)
        self.indices = [[] for _ in range(num_classes)]
        for i, (_, class_label) in enumerate(data):
            self.indices[class_label].append(i)

    def classes(self):
        return self.indices

    def __getitem__(self, index):
        return self.data[index]
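
    def __len__(self):
        # optional addition for completeness: standard map-style Dataset
        # protocol; the wrapped torchvision dataset knows its own length
        return len(self.data)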


class Sampler():
    def __init__(self, classes, class_per_batch, batch_size):
        # classes is a list of lists where each sublist refers to a class and contains
        # the sample ids that belong to that class
        self.classes = classes
        self.n_batches = sum([len(x) for x in classes]) // batch_size
        self.class_per_batch = class_per_batch
        self.batch_size = batch_size

    def __iter__(self):

        batches = []
        for _ in range(self.n_batches):
            # create a list with random sampling without replacement from a range
            classes = random.sample(range(len(self.classes)), self.class_per_batch)
            # batch contains the samples ids that belong to a subset of classes
            batch = []
            for i in range(self.batch_size):
                # pick a random class id among the classes chosen for this batch
                klass = random.choice(classes)
                # pick a random sample id that belongs to the class klass;
                # sampling is with replacement, so some samples may repeat and
                # some classes/samples may never be drawn within an epoch
                batch.append(random.choice(self.classes[klass]))
            batches.append(batch)
        return iter(batches)
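
    def __len__(self):
        # optional addition: lets len() work on the DataLoader when this
        # object is passed as batch_sampler (the loader defers its length to it)
        return self.n_batches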

def main():
    # wrap the dataset, build the single-class batch sampler, and create the loader
    _train_dataset = DS(train_dataset)
    _batch_sampler = Sampler(_train_dataset.classes(), class_per_batch=1, batch_size=args.batch_size)
    train_loader = DataLoader(dataset=_train_dataset, batch_sampler=_batch_sampler)
    labels = []
    for i, (inputs, _labels) in enumerate(train_loader):
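        # with class_per_batch=1 every batch should hold exactly one class, so
        # torch.unique(_labels) is a single-element tensor and .item() works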
        labels.append(torch.unique(_labels).item())
        print("Unique labels: {}".format(torch.unique(_labels).item()))

    labels = set(labels)
    print('Length of traversed unique labels: {}'.format(len(labels)))

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

    parser.add_argument('--data', metavar='DIR', nargs='?', default='imagenet',
                        help='path to dataset (default: imagenet)')
    parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")

    parser.add_argument('-b', '--batch-size', default=64, type=int,
                        metavar='N',
                        help='mini-batch size (default: 64), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')

    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')

    args = parser.parse_args()

    if args.dummy:
        print("=> Dummy data is used!")
        num_classes = 100
        train_dataset = datasets.FakeData(size=12811, image_size=(3, 224, 224),
                                          num_classes=num_classes, transform=transforms.ToTensor())
        val_dataset = datasets.FakeData(5000, (3, 224, 224), num_classes, transforms.ToTensor())
    else:
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))

    # Samplers are initialized to None here; the custom single-class batch
    # sampler and its loader are built inside main()
    train_sampler, val_sampler = None, None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    main()

Prints: Length of traversed unique labels: 88



from How to sample batch from only one class at each iteration in PyTorch
