Wednesday, 16 February 2022

Retrieving data from LMDB in Python is throwing an error

This is how the dataset was created:


import os
import random

import lmdb
from tqdm import tqdm

# gt_file, dataset_name and fasttext_model are defined at module level
# elsewhere in my script.

def createDataset1(
  outputPath,
  imagePathList,
  labelList,
  lexiconList=None,
  validset_percent=10,
  testset_percent=0,
  random_seed=1111,
  checkValid=True,
  ):
  """
  Create LMDB dataset for CRNN training.
  ARGS:
      outputPath    : LMDB output path
      imagePathList : list of image path
      labelList     : list of corresponding groundtruth texts
      lexiconList   : (optional) list of lexicon lists
      checkValid    : if true, check the validity of every image
  """

  train_path = os.path.join(outputPath, "training", "9M")
  valid_path = os.path.join(outputPath, "validation", "9M")
  # CAUTION: if train_path (lmdb) already exists, this function adds the
  # dataset into it, so remove the former one and re-create the lmdb.
  if os.path.exists(train_path):
      os.system(f"rm -r {train_path}")

  if os.path.exists(valid_path):
      os.system(f"rm -r {valid_path}")
  
  os.makedirs(train_path, exist_ok=True)
  os.makedirs(valid_path, exist_ok=True)
  gt_train_path = gt_file.replace(".txt", "_train.txt")
  gt_valid_path = gt_file.replace(".txt", "_valid.txt")
  data_log = open(gt_train_path, "w", encoding="utf-8")

  if testset_percent != 0:

    test_path = os.path.join(outputPath, "evaluation", dataset_name)
    if os.path.exists(test_path):
      os.system(f"rm -r {test_path}")
    os.makedirs(test_path, exist_ok=True)
    gt_test_path = gt_file.replace(".txt", "_test.txt")

  assert len(imagePathList) == len(labelList)
  nSamples = len(imagePathList)

  num_valid_dataset = int(nSamples * validset_percent / 100.0)
  num_test_dataset = int(nSamples * testset_percent / 100.0)
  num_train_dataset = nSamples - num_valid_dataset - num_test_dataset

  print("validation datasets: ",num_valid_dataset,"\n", "test datasets: ", num_test_dataset, " \n training datasets: ", num_train_dataset)

  env = lmdb.open(outputPath, map_size=1099511627776)
  cache = {}
  cnt = 1

  random.seed(random_seed)
  random.shuffle(imagePathList)

  for i in tqdm(range(nSamples)):
    data_log.write(imagePathList[i])
    imagePath = imagePathList[i]
    label = labelList[i]
    if len(label) == 0:
      continue
    if not os.path.exists(imagePath):
      print('%s does not exist' % imagePath)
      continue
    with open(imagePath, 'rb') as f:
      imageBin = f.read()
    if checkValid:
      if not checkImageIsValid(imageBin):
        print('%s is not a valid image' % imagePath)
        continue
    embed_vec = fasttext_model[label]
    imageKey = 'image-%09d' % cnt
    labelKey = 'label-%09d' % cnt
    embedKey = 'embed-%09d' % cnt
    cache[imageKey] = imageBin
    cache[labelKey] = label.encode()
    cache[embedKey] = ' '.join(str(v) for v in embed_vec.tolist()).encode()
    if lexiconList:
      lexiconKey = 'lexicon-%09d' % cnt
      cache[lexiconKey] = ' '.join(lexiconList[i])
    if cnt % 1000 == 0:
      writeCache(env, cache)
      cache = {}
      print('Written %d / %d' % (cnt, nSamples))

    # finish train dataset and start validation dataset
    if i + 1 == num_train_dataset:
      print(f"# Train dataset: {num_train_dataset} is finished")
      cache["num-samples".encode()] = str(num_train_dataset).encode()
      writeCache(env, cache)
      data_log.close()


      #start validation set
      env = lmdb.open(valid_path, map_size=30 * 2 ** 30)
      cache = {}
      cnt = 0
      data_log = open(gt_valid_path, "w", encoding="utf-8")
    
    # Finish train/valid dataset and Start test dataset
    if (i + 1 == num_train_dataset + num_valid_dataset) and num_test_dataset != 0:
        print(f"# Valid dataset: {num_valid_dataset} is finished")
        cache["num-samples".encode()] = str(num_valid_dataset).encode()
        writeCache(env, cache)
        data_log.close()

        # start test set
        env = lmdb.open(test_path, map_size=30 * 2 ** 30)
        cache = {}
        cnt = 0  # not 1 at this time
        data_log = open(gt_test_path, "w", encoding="utf-8")


    cnt += 1

  if testset_percent == 0:
      cache["num-samples".encode()] = str(num_valid_dataset).encode()
      writeCache(env, cache)
      print(f"# Valid dataset: {num_valid_dataset} is finished")
  else:
      cache["num-samples".encode()] = str(num_test_dataset).encode()
      writeCache(env, cache)
      print(f"# Test dataset: {num_test_dataset} is finished")

This is how I am trying to retrieve the data:


import os

import lmdb
from torch.utils import data  # assuming the usual PyTorch Dataset base

# global_args, mox and get_vocabulary come from the surrounding project;
# mox (ModelArts moxing) is only available in the remote environment.

class LmdbDataset(data.Dataset):
  def __init__(self, root, voc_type, max_len, num_samples, transform=None):
    super(LmdbDataset, self).__init__()

    if global_args.run_on_remote:
      dataset_name = os.path.basename(root)
      data_cache_url = "/cache/%s" % dataset_name
      if not os.path.exists(data_cache_url):
        os.makedirs(data_cache_url)
      if mox.file.exists(root):
        mox.file.copy_parallel(root, data_cache_url)
      else:
        raise ValueError("%s not exists!" % root)
      
      self.env = lmdb.open(data_cache_url, max_readers=32, readonly=True)
    else:
      self.env = lmdb.open(root, max_readers=32, readonly=True)

    assert self.env is not None, "cannot create lmdb from %s" % root
    self.txn = self.env.begin()

    self.voc_type = voc_type
    self.transform = transform
    self.max_len = max_len
    # nums = b"num-samples"  
    # print('NUM SAMPLES ------ \n',nums)
    nSamples = self.txn.get('num-samples'.encode())
    print("STRING nSamples :", nSamples)
    self.nSamples = int(self.txn.get(b"num-samples"))
    self.nSamples = min(self.nSamples, num_samples)

    assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS']
    self.EOS = 'EOS'
    self.PADDING = 'PADDING'
    self.UNKNOWN = 'UNKNOWN'
    self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN)
    self.char2id = dict(zip(self.voc, range(len(self.voc))))
    self.id2char = dict(zip(range(len(self.voc)), self.voc))

    self.rec_num_classes = len(self.voc)
    self.lowercase = (voc_type == 'LOWERCASE')


I am getting the error below whenever the code calls self.txn.get(b"num-samples"):

Traceback (most recent call last):
  File "main.py", line 268, in <module>
    main(args)
  File "main.py", line 157, in main
    train_dataset, train_loader = get_data_lmdb(args.synthetic_train_data_dir, args.voc_type, args.max_len, args.num_train,
  File "main.py", line 66, in get_data_lmdb
    dataset_list.append(LmdbDataset(data_dir_, voc_type, max_len, num_samples))
  File "/Users/SEED/lib/datasets/dataset.py", line 189, in __init__
    self.nSamples = int(self.txn.get(b"num-samples"))
TypeError: int() argument must be a string, a bytes-like object or a number, not 'NoneType'
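
From the py-lmdb documentation, txn.get() returns None when a key is absent, so the key must not be present in the environment I am opening. A guard like this (my own check, not part of the original code) makes the failure explicit:

raw = self.txn.get(b"num-samples")
if raw is None:
  raise KeyError("num-samples key not found in this LMDB environment")
self.nSamples = int(raw)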

I have tried many different suggestions from Stack Overflow threads and elsewhere online, but I could not figure out what is wrong. As a sanity check, the snippet below is what I use to list every key that actually exists in an LMDB directory (db_path is a placeholder for whatever directory is passed as root):
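
import lmdb

db_path = "/path/to/lmdb"  # placeholder for the directory passed as root
env = lmdb.open(db_path, readonly=True, lock=False)
with env.begin() as txn:
  for key, _ in txn.cursor():
    print(key)  # keys come back as bytes; b'num-samples' should be listed
env.close()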

What is causing this error, and how can I fix it?


