dataset filtering update
This commit is contained in:
parent
cf390a0873
commit
19605e07fa
29
dataset.py
29
dataset.py
|
@ -121,12 +121,28 @@ class LmdbDataset(Dataset):
|
|||
nSamples = int(txn.get('num-samples'.encode()))
|
||||
self.nSamples = nSamples
|
||||
|
||||
# Filtering
|
||||
self.filtered_index_list = []
|
||||
for index in range(self.nSamples):
|
||||
index += 1 # lmdb starts with 1
|
||||
label_key = 'label-%09d'.encode() % index
|
||||
label = txn.get(label_key).decode('utf-8')
|
||||
|
||||
if len(label) > self.opt.batch_max_length:
|
||||
print(f'The length of the label is longer than max_length: length {len(label)}, {label} in dataset {self.root}')
|
||||
continue
|
||||
|
||||
self.filtered_index_list.append(index)
|
||||
|
||||
self.nSamples = len(self.filtered_index_list)
|
||||
|
||||
def __len__(self):
|
||||
return self.nSamples
|
||||
|
||||
def __getitem__(self, index):
|
||||
assert index <= len(self), 'index range error'
|
||||
index += 1
|
||||
index = self.filtered_index_list[index]
|
||||
|
||||
with self.env.begin(write=False) as txn:
|
||||
label_key = 'label-%09d'.encode() % index
|
||||
label = txn.get(label_key).decode('utf-8')
|
||||
|
@ -144,11 +160,12 @@ class LmdbDataset(Dataset):
|
|||
|
||||
except IOError:
|
||||
print(f'Corrupted image for {index}')
|
||||
return
|
||||
|
||||
if len(label) > self.opt.batch_max_length:
|
||||
print(f'The length of the label is longer than max_length: length {len(label)}, {label} in dataset {self.root}')
|
||||
return
|
||||
# make dummy image and dummy label for corrupted image.
|
||||
if self.opt.rgb:
|
||||
img = Image.new('RGB', (self.opt.imgW, self.opt.imgH))
|
||||
else:
|
||||
img = Image.new('L', (self.opt.imgW, self.opt.imgH))
|
||||
label = '[dummy_label]'
|
||||
|
||||
if not self.opt.sensitive:
|
||||
label = label.lower()
|
||||
|
|
Loading…
Reference in New Issue