diff --git a/dataset.py b/dataset.py index 409f3f6..1b61dbb 100755 --- a/dataset.py +++ b/dataset.py @@ -121,12 +121,28 @@ class LmdbDataset(Dataset): nSamples = int(txn.get('num-samples'.encode())) self.nSamples = nSamples + # Filtering + self.filtered_index_list = [] + for index in range(self.nSamples): + index += 1 # lmdb starts with 1 + label_key = 'label-%09d'.encode() % index + label = txn.get(label_key).decode('utf-8') + + if len(label) > self.opt.batch_max_length: + print(f'The length of the label is longer than max_length: length {len(label)}, {label} in dataset {self.root}') + continue + + self.filtered_index_list.append(index) + + self.nSamples = len(self.filtered_index_list) + def __len__(self): return self.nSamples def __getitem__(self, index): assert index <= len(self), 'index range error' - index += 1 + index = self.filtered_index_list[index] + with self.env.begin(write=False) as txn: label_key = 'label-%09d'.encode() % index label = txn.get(label_key).decode('utf-8') @@ -144,11 +160,12 @@ class LmdbDataset(Dataset): except IOError: print(f'Corrupted image for {index}') - return - - if len(label) > self.opt.batch_max_length: - print(f'The length of the label is longer than max_length: length {len(label)}, {label} in dataset {self.root}') - return + # make dummy image and dummy label for corrupted image. + if self.opt.rgb: + img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) + else: + img = Image.new('L', (self.opt.imgW, self.opt.imgH)) + label = '[dummy_label]' if not self.opt.sensitive: label = label.lower()