deep-text-recognition-bench.../demo.py

130 lines
6.2 KiB
Python
Raw Normal View History

2019-05-17 21:44:38 +08:00
import string
import argparse
import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
2019-10-22 19:46:45 +08:00
import torch.nn.functional as F
2019-05-17 21:44:38 +08:00
from utils import CTCLabelConverter, AttnLabelConverter
from dataset import RawDataset, AlignCollate
from model import Model
2019-08-03 16:03:46 +08:00
# Module-wide compute device: CUDA when torch reports it available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
2019-05-17 21:44:38 +08:00
def demo(opt):
    """Recognize the text in every image under ``opt.image_folder``.

    Builds the model described by ``opt``, loads the weights stored at
    ``opt.saved_model``, greedily decodes the images batch by batch and, for
    each image, prints ``image_path<TAB>predicted_labels<TAB>confidence`` to
    stdout while appending the same lines to ``./log_demo_result.txt``.
    A dashed header is emitted before every batch.

    Args:
        opt: parsed command-line namespace. It is mutated: ``opt.num_class``
            is set from the label converter and ``opt.input_channel`` is
            forced to 3 when ``opt.rgb`` is true.
    """
    model, converter = _load_demo_model(opt)
    demo_loader = _build_demo_loader(opt)

    dashed_line = '-' * 80
    head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

    model.eval()
    # Open the log once for the whole run (it used to be re-opened per batch)
    # and let ``with`` close it even if inference raises part-way through.
    with torch.no_grad(), open('./log_demo_result.txt', 'a', encoding='utf-8') as log:
        for image_tensors, image_path_list in demo_loader:
            preds, preds_str = _predict_batch(model, converter, image_tensors, opt)

            print(f'{dashed_line}\n{head}\n{dashed_line}')
            log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

            # Per-step probability of the greedily chosen class.
            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                if 'Attn' in opt.Prediction:
                    pred, pred_max_prob = _prune_after_eos(pred, pred_max_prob)

                # Confidence = product of the per-step max probabilities.
                # prod() of an empty tensor is 1, so an empty prediction no
                # longer raises IndexError as cumprod(dim=0)[-1] did.
                confidence_score = pred_max_prob.prod()

                row = f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}'
                print(row)
                log.write(f'{row}\n')


def _load_demo_model(opt):
    """Build the model for ``opt`` on ``device``, load saved weights; return (model, converter)."""
    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
    # Wrapped in DataParallel *before* loading — presumably because the
    # checkpoints were saved from a DataParallel model; confirm.
    model = torch.nn.DataParallel(model).to(device)

    print('loading pretrained model from %s' % opt.saved_model)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))
    return model, converter


def _build_demo_loader(opt):
    """Return an unshuffled DataLoader over the images in ``opt.image_folder``."""
    align_collate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    return torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=align_collate,
        # Pinned host memory only helps host->GPU copies; skip it on CPU.
        pin_memory=device.type == 'cuda')


def _predict_batch(model, converter, image_tensors, opt):
    """Greedy-decode one batch; return (pre-softmax scores, decoded strings).

    The scores are indexed as (batch, step, class) by the code below:
    ``max(2)`` picks the class and ``size(1)`` is the step count.
    """
    batch_size = image_tensors.size(0)
    image = image_tensors.to(device)
    # Placeholder decoder inputs sized for the maximum label length.
    length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
    text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

    if 'CTC' in opt.Prediction:
        preds = model(image, text_for_pred)
        # Every sample is decoded over the full output length.
        preds_size = torch.IntTensor([preds.size(1)] * batch_size)
        _, preds_index = preds.max(2)
        preds_str = converter.decode(preds_index, preds_size)
    else:
        preds = model(image, text_for_pred, is_train=False)
        _, preds_index = preds.max(2)
        preds_str = converter.decode(preds_index, length_for_pred)
    return preds, preds_str


def _prune_after_eos(pred, pred_max_prob, eos='[s]'):
    """Cut an attention prediction at its first ``eos`` token.

    Returns the text before ``eos`` and the same-length prefix of
    ``pred_max_prob``. If ``eos`` is absent, both are returned unchanged.
    Previously, ``find``'s -1 was used as a slice end, which silently dropped
    the last character. Assumes one decoded character per step before
    ``eos`` — TODO confirm for multi-character tokens.
    """
    eos_index = pred.find(eos)
    if eos_index == -1:
        return pred, pred_max_prob
    return pred[:eos_index], pred_max_prob[:eos_index]
if __name__ == '__main__':
    # (flag, add_argument keyword arguments) for each CLI option, listed in
    # the order they are registered (and therefore shown by --help).
    cli_options = [
        # Runtime / I/O
        ('--image_folder', dict(required=True, help='path to image_folder which contains text images')),
        ('--workers', dict(type=int, help='number of data loading workers', default=4)),
        ('--batch_size', dict(type=int, default=192, help='input batch size')),
        ('--saved_model', dict(required=True, help="path to saved_model to evaluation")),
        # Data processing
        ('--batch_max_length', dict(type=int, default=25, help='maximum-label-length')),
        ('--imgH', dict(type=int, default=32, help='the height of the input image')),
        ('--imgW', dict(type=int, default=100, help='the width of the input image')),
        ('--rgb', dict(action='store_true', help='use rgb input')),
        ('--character', dict(type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')),
        ('--sensitive', dict(action='store_true', help='for sensitive character mode')),
        ('--PAD', dict(action='store_true', help='whether to keep ratio then pad for image resize')),
        # Model Architecture
        ('--Transformation', dict(type=str, required=True, help='Transformation stage. None|TPS')),
        ('--FeatureExtraction', dict(type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')),
        ('--SequenceModeling', dict(type=str, required=True, help='SequenceModeling stage. None|BiLSTM')),
        ('--Prediction', dict(type=str, required=True, help='Prediction stage. CTC|Attn')),
        ('--num_fiducial', dict(type=int, default=20, help='number of fiducial points of TPS-STN')),
        ('--input_channel', dict(type=int, default=1, help='the number of input channel of Feature extractor')),
        ('--output_channel', dict(type=int, default=512,
                                  help='the number of output channel of Feature extractor')),
        ('--hidden_size', dict(type=int, default=256, help='the size of the LSTM hidden state')),
    ]
    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    opt = parser.parse_args()

    # vocab / character number configuration:
    # case-sensitive mode swaps in string.printable minus its 6 trailing whitespace chars.
    if opt.sensitive:
        opt.character = string.printable[:-6]  # same with ASTER setting (use 94 char).

    cudnn.benchmark = True
    cudnn.deterministic = True
    opt.num_gpu = torch.cuda.device_count()
    demo(opt)