2019-04-05 18:45:29 +08:00
import os
import sys
import re
import six
2019-05-10 10:11:06 +08:00
import math
2019-04-05 18:45:29 +08:00
import lmdb
import torch
2019-05-21 13:31:52 +08:00
from natsort import natsorted
2019-04-05 18:45:29 +08:00
from PIL import Image
import numpy as np
from torch . utils . data import Dataset , ConcatDataset , Subset
from torch . _utils import _accumulate
import torchvision . transforms as transforms
class Batch_Balanced_Dataset(object):
    """Draw every training batch from several sub-datasets in fixed proportions.

    One DataLoader is kept per selected sub-dataset; ``get_batch`` pulls one
    mini-batch from each loader and concatenates them.
    """

    def __init__(self, opt):
        """
        Modulate the data ratio in the batch.
        For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5",
        the 50% of the batch is filled with MJ and the other 50% of the batch is filled with ST.

        For each entry of ``opt.select_data`` a DataLoader is built whose batch size is
        ``max(round(opt.batch_size * ratio), 1)``. ``opt.batch_size`` is then overwritten
        with the sum of those per-loader batch sizes. Progress is printed and appended to
        ``./saved_models/{opt.exp_name}/log_dataset.txt``.

        Raises:
            AssertionError: if ``opt.select_data`` and ``opt.batch_ratio`` differ in length.
        """
        # ``with`` guarantees the log file is closed even if dataset construction fails.
        with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:

            def emit(message):
                # Echo to stdout and append the same text (plus newline) to the dataset log.
                print(message)
                log.write(message + '\n')

            dashed_line = '-' * 80
            emit(dashed_line)
            emit(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}')
            assert len(opt.select_data) == len(opt.batch_ratio)

            _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
            self.data_loader_list = []
            self.dataloader_iter_list = []
            batch_size_list = []
            Total_batch_size = 0
            for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio):
                _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1)
                emit(dashed_line)
                _dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d])
                total_number_dataset = len(_dataset)
                # hierarchical_dataset already printed this text; only persist it here.
                log.write(_dataset_log)

                # The total number of data can be modified with opt.total_data_usage_ratio.
                # ex) opt.total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage.
                # See 4.2 section in our paper.
                # The first ``number_dataset`` samples (in concatenation order) are kept; slicing a
                # range clips to the dataset length exactly like the previous split-based code.
                number_dataset = int(total_number_dataset * float(opt.total_data_usage_ratio))
                _dataset = Subset(_dataset, range(total_number_dataset)[:number_dataset])

                selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n'
                selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}'
                emit(selected_d_log)
                batch_size_list.append(str(_batch_size))
                Total_batch_size += _batch_size

                _data_loader = torch.utils.data.DataLoader(
                    _dataset, batch_size=_batch_size,
                    shuffle=True,
                    num_workers=int(opt.workers),
                    collate_fn=_AlignCollate, pin_memory=True)
                self.data_loader_list.append(_data_loader)
                self.dataloader_iter_list.append(iter(_data_loader))

            Total_batch_size_log = f'{dashed_line}\n'
            batch_size_sum = '+'.join(batch_size_list)
            Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n'
            Total_batch_size_log += f'{dashed_line}'
            # Callers read the effective (summed) batch size back from opt.
            opt.batch_size = Total_batch_size
            emit(Total_batch_size_log)

    def get_batch(self):
        """Return one balanced batch: concatenated image tensors and a flat list of labels.

        A loader that is exhausted is restarted, so every sub-dataset keeps contributing
        indefinitely. The builtin ``next()`` is used because DataLoader iterators do not
        provide a ``.next()`` method in current PyTorch.
        """
        balanced_batch_images = []
        balanced_batch_texts = []

        for i, data_loader_iter in enumerate(self.dataloader_iter_list):
            try:
                image, text = next(data_loader_iter)
            except StopIteration:
                self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
                image, text = next(self.dataloader_iter_list[i])
            except ValueError:
                # Best-effort skip of this loader for this step (behavior kept from the
                # original). NOTE(review): presumably raised by the collate function on an
                # empty batch — confirm.
                continue
            balanced_batch_images.append(image)
            balanced_batch_texts += text

        balanced_batch_images = torch.cat(balanced_batch_images, 0)

        return balanced_batch_images, balanced_batch_texts
def hierarchical_dataset(root, opt, select_data='/'):
    """Concatenate every leaf LMDB directory under ``root`` selected by ``select_data``.

    select_data='/' contains all sub-directory of root directory.

    A leaf directory (one with no sub-directories) is selected when any entry of
    ``select_data`` is a substring of its path. Directories are visited in sorted
    order so the concatenation order, and therefore which samples a usage ratio
    keeps, is reproducible across filesystems.

    Returns:
        (ConcatDataset, str): the concatenated dataset and a log text (one line
        per selected sub-directory).

    Raises:
        ValueError: if no leaf sub-directory of ``root`` matches ``select_data``.
    """
    dataset_list = []
    dataset_log = f'dataset_root: {root}\t dataset: {select_data[0]}'
    print(dataset_log)
    dataset_log += '\n'
    for dirpath, dirnames, filenames in os.walk(root + '/'):
        # In-place sort makes the top-down walk descend in a deterministic order.
        dirnames.sort()
        if dirnames:
            continue  # only leaf directories are treated as LMDB datasets
        if not any(selected_d in dirpath for selected_d in select_data):
            continue

        dataset = LmdbDataset(dirpath, opt)
        sub_dataset_log = f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}'
        print(sub_dataset_log)
        dataset_log += f'{sub_dataset_log}\n'
        dataset_list.append(dataset)

    if not dataset_list:
        raise ValueError(f'no sub-directory of {root} matches select_data={select_data}')
    concatenated_dataset = ConcatDataset(dataset_list)

    return concatenated_dataset, dataset_log
2019-04-05 18:45:29 +08:00
class LmdbDataset(Dataset):
    """(image, label) samples stored in a single LMDB directory.

    Keys read: ``'num-samples'``, ``'label-%09d'`` and ``'image-%09d'`` with 1-based
    indices. Unless ``opt.data_filtering_off`` is set, samples whose label is longer
    than ``opt.batch_max_length`` or (after lower-casing) contains a character outside
    ``opt.character`` are dropped at construction time.
    """

    def __init__(self, root, opt):
        self.root = root
        self.opt = opt
        self.env = lmdb.open(root, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
        if not self.env:
            # NOTE(review): presumably lmdb.open raises rather than returning a falsy
            # handle, which would make this unreachable; kept as a guard.
            print('cannot create lmdb from %s' % (root))
            sys.exit(1)  # non-zero status: this is a failure, not a clean exit

        with self.env.begin(write=False) as txn:
            nSamples = int(txn.get('num-samples'.encode()))
            self.nSamples = nSamples

            if self.opt.data_filtering_off:
                # for fast check or benchmark evaluation with no filtering
                self.filtered_index_list = [index + 1 for index in range(self.nSamples)]
            else:
                # Filtering part
                # If you want to evaluate IC15-2077 & CUTE datasets which have special character labels,
                # use --data_filtering_off and only evaluate on alphabets and digits.
                # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L190-L192
                # And if you want to evaluate them with the model trained with --sensitive option,
                # use --sensitive and --data_filtering_off,
                # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/dff844874dbe9e0ec8c5a52a7bd08c7f20afe704/test.py#L137-L144
                out_of_char = self._out_of_char_regex(self.opt.character)  # compiled once, not per label
                self.filtered_index_list = []
                for index in range(self.nSamples):
                    index += 1  # lmdb starts with 1
                    label_key = 'label-%09d'.encode() % index
                    label = txn.get(label_key).decode('utf-8')

                    if len(label) > self.opt.batch_max_length:
                        continue

                    # By default, images containing characters which are not in opt.character are filtered.
                    # You can add [UNK] token to `opt.character` in utils.py instead of this filtering.
                    if out_of_char.search(label.lower()):
                        continue

                    self.filtered_index_list.append(index)

                self.nSamples = len(self.filtered_index_list)

    @staticmethod
    def _out_of_char_regex(character):
        """Return a compiled regex matching any single character NOT contained in ``character``.

        ``re.escape`` makes every symbol literal inside the character class, so a ``-``,
        ``]``, ``\\`` or leading ``^`` in ``character`` can neither form a range nor
        break the pattern.
        """
        return re.compile(f'[^{re.escape(character)}]')

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        # An explicit IndexError (instead of an off-by-one assert) still fires under
        # ``python -O`` and lets sequence-style iteration stop cleanly at len(self).
        if index >= len(self):
            raise IndexError('index range error')
        index = self.filtered_index_list[index]

        with self.env.begin(write=False) as txn:
            label_key = 'label-%09d'.encode() % index
            label = txn.get(label_key).decode('utf-8')
            img_key = 'image-%09d'.encode() % index
            imgbuf = txn.get(img_key)

            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            try:
                if self.opt.rgb:
                    img = Image.open(buf).convert('RGB')  # for color image
                else:
                    img = Image.open(buf).convert('L')

            except IOError:
                print(f'Corrupted image for {index}')
                # make dummy image and dummy label for corrupted image.
                if self.opt.rgb:
                    img = Image.new('RGB', (self.opt.imgW, self.opt.imgH))
                else:
                    img = Image.new('L', (self.opt.imgW, self.opt.imgH))
                label = '[dummy_label]'

            if not self.opt.sensitive:
                label = label.lower()

            # We only train and evaluate on alphanumerics (or pre-defined character set in train.py)
            label = self._out_of_char_regex(self.opt.character).sub('', label)

        return (img, label)
2019-05-17 21:44:38 +08:00
class RawDataset(Dataset):
    """Unlabeled images (.jpg/.jpeg/.png, case-insensitive) found recursively under ``root``.

    Each item is ``(image, path)``; paths are ordered with ``natsorted``.
    """

    def __init__(self, root, opt):
        self.opt = opt
        accepted = ('.jpg', '.jpeg', '.png')
        found = [
            os.path.join(folder, fname)
            for folder, _subdirs, fnames in os.walk(root)
            for fname in fnames
            if os.path.splitext(fname)[1].lower() in accepted
        ]
        self.image_path_list = natsorted(found)
        self.nSamples = len(self.image_path_list)

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        mode = 'RGB' if self.opt.rgb else 'L'  # 'RGB' for color image
        path = self.image_path_list[index]
        try:
            img = Image.open(path).convert(mode)
        except IOError:
            print(f'Corrupted image for {index}')
            # Substitute a blank image of the configured size for an unreadable file.
            img = Image.new(mode, (self.opt.imgW, self.opt.imgH))
        return (img, path)
2019-04-05 18:45:29 +08:00
class ResizeNormalize(object):
    """Resize a PIL image to ``size`` (width, height), convert it to a tensor, then map x -> (x - 0.5) / 0.5.

    Assuming ``ToTensor`` yields values in [0, 1], the output lies in [-1, 1].
    """

    def __init__(self, size, interpolation=Image.BICUBIC):
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()

    def __call__(self, img):
        resized = img.resize(self.size, self.interpolation)
        tensor = self.toTensor(resized)
        # In-place ops return the same tensor object, so this returns the normalized tensor.
        return tensor.sub_(0.5).div_(0.5)
2019-05-10 10:11:06 +08:00
class NormalizePAD(object):
    """Normalize an image to a tensor and right-pad it to ``max_size`` = (C, H, W_max).

    Padding columns replicate the image's last column rather than using zeros.
    """

    def __init__(self, max_size, PAD_type='right'):
        self.toTensor = transforms.ToTensor()
        self.max_size = max_size
        self.max_width_half = math.floor(max_size[2] / 2)  # NOTE(review): not read in the visible code
        self.PAD_type = PAD_type  # NOTE(review): only right padding is implemented below

    def __call__(self, img):
        tensor = self.toTensor(img)
        tensor.sub_(0.5).div_(0.5)
        channels, height, width = tensor.size()
        target_w = self.max_size[2]

        padded = torch.zeros(*self.max_size, dtype=torch.float32)
        padded[:, :, :width] = tensor  # right pad
        if width != target_w:
            # add border Pad: repeat the last column across the remaining width
            border = tensor[:, :, width - 1].unsqueeze(2)
            padded[:, :, width:] = border.expand(channels, height, target_w - width)
        return padded
2019-04-05 18:45:29 +08:00
class AlignCollate(object):
    """DataLoader collate function: turn (PIL image, label) samples into (image batch tensor, labels tuple).

    ``None`` samples are dropped. With ``keep_ratio_with_pad`` each image keeps its
    aspect ratio (width capped at ``imgW``) and is right-padded; otherwise every image
    is resized to exactly (``imgW``, ``imgH``).
    """

    def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio_with_pad = keep_ratio_with_pad

    def __call__(self, batch):
        kept = [sample for sample in batch if sample is not None]
        # Raises ValueError when nothing is left to unpack.
        images, labels = zip(*kept)

        if not self.keep_ratio_with_pad:
            resize = ResizeNormalize((self.imgW, self.imgH))
            stacked = torch.cat([resize(image).unsqueeze(0) for image in images], 0)
            return stacked, labels

        # same concept with 'Rosetta' paper: keep the aspect ratio, pad up to imgW
        channels = 3 if images[0].mode == 'RGB' else 1
        pad = NormalizePAD((channels, self.imgH, self.imgW))
        padded = []
        for image in images:
            width, height = image.size
            target_w = min(math.ceil(self.imgH * (width / float(height))), self.imgW)
            resized = image.resize((target_w, self.imgH), Image.BICUBIC)
            padded.append(pad(resized).unsqueeze(0))
        return torch.cat(padded, 0), labels
def tensor2im(image_tensor, imtype=np.uint8):
    """Convert a (C, H, W) tensor into an (H, W, C) numpy image via (x + 1) / 2 * 255.

    A single-channel tensor is replicated to three channels. The result is cast to ``imtype``.
    """
    array = image_tensor.cpu().float().numpy()
    if array.shape[0] == 1:
        array = np.repeat(array, 3, axis=0)  # grayscale -> 3 identical channels
    channels_last = array.transpose(1, 2, 0)
    scaled = (channels_last + 1) / 2.0 * 255.0
    return scaled.astype(imtype)
def save_image(image_numpy, image_path):
    """Write a numpy image array to ``image_path`` using PIL (format chosen by PIL from the path)."""
    Image.fromarray(image_numpy).save(image_path)