dataset filtering update

This commit is contained in:
Baek JeongHun 2019-05-09 03:30:52 +00:00
parent cf390a0873
commit 19605e07fa
1 changed file with 23 additions and 6 deletions

View File

@ -121,12 +121,28 @@ class LmdbDataset(Dataset):
nSamples = int(txn.get('num-samples'.encode()))
self.nSamples = nSamples
# Filtering
self.filtered_index_list = []
for index in range(self.nSamples):
index += 1 # lmdb starts with 1
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key).decode('utf-8')
if len(label) > self.opt.batch_max_length:
print(f'The length of the label is longer than max_length: length {len(label)}, {label} in dataset {self.root}')
continue
self.filtered_index_list.append(index)
self.nSamples = len(self.filtered_index_list)
def __len__(self):
# Dataset size is whatever self.nSamples currently holds; the filtering loop
# shown above reassigns it to len(self.filtered_index_list), i.e. the count of
# samples whose label length does not exceed opt.batch_max_length.
return self.nSamples
def __getitem__(self, index):
assert index <= len(self), 'index range error'
index += 1
index = self.filtered_index_list[index]
with self.env.begin(write=False) as txn:
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key).decode('utf-8')
@ -144,11 +160,12 @@ class LmdbDataset(Dataset):
except IOError:
print(f'Corrupted image for {index}')
return
if len(label) > self.opt.batch_max_length:
print(f'The length of the label is longer than max_length: length {len(label)}, {label} in dataset {self.root}')
return
# make dummy image and dummy label for corrupted image.
if self.opt.rgb:
img = Image.new('RGB', (self.opt.imgW, self.opt.imgH))
else:
img = Image.new('L', (self.opt.imgW, self.opt.imgH))
label = '[dummy_label]'
if not self.opt.sensitive:
label = label.lower()