vehicle-license-plate-recog.../svm_train.py

173 lines
4.5 KiB
Python
Raw Normal View History

2019-07-06 09:59:59 +08:00
'''
训练svm
'''
import cv2
import numpy as np
from numpy.linalg import norm
import sys
import os
import json
# Side length in pixels of the square image produced by deskew()'s warp.
SZ = 20
# Offset added to the position in `provinces` to form the class label of a
# Chinese province character when training the Chinese classifier.
PROVINCE_START = 1000
# Flat list alternating a pinyin directory name (as found under
# ./train/charsChinese) with the Chinese province character it stands for.
# Training labels point at the character entry (pinyin index + 1), so the
# character slots must not be empty.
provinces = [
    "zh_cuan", "川",
    "zh_e", "鄂",
    "zh_gan", "赣",
    "zh_gan1", "甘",
    "zh_gui", "贵",
    "zh_gui1", "桂",
    "zh_hei", "黑",
    "zh_hu", "沪",
    "zh_ji", "冀",
    "zh_jin", "津",
    "zh_jing", "京",
    "zh_jl", "吉",
    "zh_liao", "辽",
    "zh_lu", "鲁",
    "zh_meng", "蒙",
    "zh_min", "闽",
    "zh_ning", "宁",
    "zh_qing", "青",
    "zh_qiong", "琼",
    "zh_shan", "陕",
    "zh_su", "苏",
    "zh_sx", "晋",
    "zh_wan", "皖",
    "zh_xiang", "湘",
    "zh_xin", "新",
    "zh_yu", "豫",
    "zh_yu1", "渝",
    "zh_yue", "粤",
    "zh_yun", "云",
    "zh_zang", "藏",
    "zh_zhe", "浙",
]
# Data preprocessing
def deskew(img):
    """Undo the slant of a character image using its central moments.

    The shear factor is mu11 / mu02. It is removed with an affine warp that
    shears about row SZ/2 and produces an SZ x SZ image. When |mu02| is
    below 1e-2 the ratio is unreliable, so an untouched copy of the input is
    returned instead (that copy keeps the input's own size).
    """
    moments = cv2.moments(img)
    mu02 = moments['mu02']
    if abs(mu02) < 1e-2:
        return img.copy()

    shear = moments['mu11'] / mu02
    offset = -0.5 * SZ * shear
    warp = np.float32([
        [1, shear, offset],
        [0, 1, 0],
    ])
    # WARP_INVERSE_MAP: `warp` maps output coordinates back to the source.
    warp_flags = cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR
    straightened = cv2.warpAffine(img, warp, (SZ, SZ), flags=warp_flags)
    return straightened
# Feature engineering
def preprocess_hog(digits, bin_n=16):
    """Compute a HOG-style descriptor for every character image.

    For each image, Sobel gradients are converted to magnitude/angle. Angles
    are quantised into ``bin_n`` orientation bins, and a magnitude-weighted
    histogram is built for each of four cells split at row/column 10 (the
    quadrants of a 20x20 image). The four histograms are concatenated and
    passed through the Hellinger kernel: L1-normalise, square root, then
    L2-normalise.

    Args:
        digits: iterable of 2-D single-channel images. The fixed split at 10
            assumes 20x20 (SZ) inputs -- NOTE(review): confirm for callers.
        bin_n: number of orientation bins per cell (default 16).

    Returns:
        float32 array of shape (len(digits), 4 * bin_n). An empty input
        yields an empty array.
    """
    eps = 1e-7
    samples = []
    for img in digits:
        gx = cv2.Sobel(img, cv2.CV_32F, 1, 0)
        gy = cv2.Sobel(img, cv2.CV_32F, 0, 1)
        mag, ang = cv2.cartToPolar(gx, gy)
        # cartToPolar reports angles in the closed range [0, 2*pi]. An angle
        # of (or rounding to) 2*pi would fall into bin `bin_n`, making
        # np.bincount return bin_n + 1 entries and giving this sample a longer
        # feature vector than the others. Clamp it into the last bin.
        bins = np.minimum(np.int32(bin_n * ang / (2 * np.pi)), bin_n - 1)
        bin_cells = bins[:10, :10], bins[10:, :10], bins[:10, 10:], bins[10:, 10:]
        mag_cells = mag[:10, :10], mag[10:, :10], mag[:10, 10:], mag[10:, 10:]
        hists = [
            np.bincount(b.ravel(), weights=m.ravel(), minlength=bin_n)
            for b, m in zip(bin_cells, mag_cells)
        ]
        hist = np.hstack(hists)
        # Transform to the Hellinger kernel.
        hist /= hist.sum() + eps
        hist = np.sqrt(hist)
        hist /= norm(hist) + eps
        samples.append(hist)
    return np.float32(samples)
class StatModel:
    """Persistence helpers for a wrapper that keeps its model in ``self.model``."""

    def load(self, fn):
        """Replace ``self.model`` with the result of ``self.model.load(fn)``.

        The return value is assigned back rather than relying on an in-place
        load -- presumably because the cv2 binding's ``load`` returns a new
        model object (NOTE(review): confirm for the installed OpenCV version).
        """
        reloaded = self.model.load(fn)
        self.model = reloaded

    def save(self, fn):
        """Write ``self.model`` to ``fn`` using the model's own ``save``."""
        target = self.model
        target.save(fn)
class SVM(StatModel):
    """C-SVC with an RBF kernel, backed by ``cv2.ml.SVM``.

    Besides wrapping a single classifier, ``train_svm``/``save_traindata``
    drive two separate classifiers: one for letters/digits (./train/chars2)
    and one for Chinese province characters (./train/charsChinese). Both are
    cached as .dat files under ./train_dat.
    """

    def __init__(self, C=1, gamma=0.5):
        """Create an untrained RBF C-SVC.

        Args:
            C: SVM penalty parameter.
            gamma: RBF kernel coefficient.
        """
        self.model = cv2.ml.SVM_create()
        self.model.setGamma(gamma)
        self.model.setC(C)
        self.model.setKernel(cv2.ml.SVM_RBF)
        self.model.setType(cv2.ml.SVM_C_SVC)

    def train(self, samples, responses):
        """Train the SVM; ``samples`` holds one feature vector per row."""
        self.model.train(samples, cv2.ml.ROW_SAMPLE, responses)

    def predict(self, samples):
        """Return the predicted label of every row of ``samples`` as a flat array."""
        r = self.model.predict(samples)
        # cv2 predict returns a (retval, results) pair; keep only the results.
        return r[1].ravel()

    @staticmethod
    def _alnum_label(dirname):
        """Label for a letter/digit folder: ord() of its single-char name, else None (skip)."""
        return ord(dirname) if len(dirname) == 1 else None

    @staticmethod
    def _province_label(dirname):
        """Label for a ``zh_*`` province folder, else None (skip).

        Raises ValueError (from list.index) for a ``zh_*`` name missing from
        ``provinces``.
        """
        if not dirname.startswith("zh_"):
            return None
        # +1 points at the Chinese character stored right after its pinyin.
        return provinces.index(dirname) + PROVINCE_START + 1

    @staticmethod
    def _load_char_samples(root_dir, label_for_dir):
        """Walk ``root_dir`` and return (HOG features, int32 labels).

        ``label_for_dir`` maps a directory basename to its integer label, or
        to None when that directory should be skipped. Files that cv2 cannot
        decode as images are ignored.

        Raises:
            ValueError: if no usable image is found.
        """
        images = []
        labels = []
        for root, _dirs, files in os.walk(root_dir):
            label = label_for_dir(os.path.basename(root))
            if label is None:
                continue
            for filename in files:
                img = cv2.imread(os.path.join(root, filename))
                if img is None:
                    # Not a readable image (e.g. stray metadata files).
                    continue
                images.append(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY))
                labels.append(label)
        if not images:
            raise ValueError("no training images found under %r" % (root_dir,))
        features = preprocess_hog(list(map(deskew, images)))
        print(features.shape)
        # cv2.ml expects 32-bit integer class labels; be explicit rather than
        # depending on the platform's default integer width.
        return features, np.array(labels, dtype=np.int32)

    def train_svm(self):
        """Load (or, if no cached file exists, train) both character classifiers.

        NOTE(review): this replaces ``self.model`` (normally the raw cv2 SVM)
        with a new ``SVM`` wrapper, so afterwards ``self.predict`` no longer
        sees a cv2 model, and ``self.load`` would set ``self.model`` to None.
        Kept as-is for compatibility with existing callers.

        Returns:
            (model, modelchinese): the letters/digits and Chinese SVM wrappers.
        """
        # Recognises English letters and digits.
        self.model = SVM(C=1, gamma=0.5)
        # Recognises Chinese province characters.
        self.modelchinese = SVM(C=1, gamma=0.5)

        if os.path.exists("./train_dat/svm.dat"):
            self.model.load("./train_dat/svm.dat")
        else:
            samples, labels = self._load_char_samples("./train/chars2", self._alnum_label)
            self.model.train(samples, labels)

        if os.path.exists("./train_dat/svmchinese.dat"):
            self.modelchinese.load("./train_dat/svmchinese.dat")
        else:
            samples, labels = self._load_char_samples("./train/charsChinese", self._province_label)
            self.modelchinese.train(samples, labels)

        return self.model, self.modelchinese

    def save_traindata(self):
        """Save both classifiers under ./train_dat, skipping files that already exist.

        Raises:
            AttributeError: if ``train_svm()`` has not been called yet.
        """
        if not hasattr(self, "modelchinese"):
            # Before train_svm(), self.model is an untrained cv2 SVM. Saving it
            # would cache a useless svm.dat that train_svm() would then load.
            raise AttributeError("call train_svm() before save_traindata()")
        os.makedirs("./train_dat", exist_ok=True)
        if not os.path.exists("./train_dat/svm.dat"):
            self.model.save("./train_dat/svm.dat")
        if not os.path.exists("./train_dat/svmchinese.dat"):
            self.modelchinese.save("./train_dat/svmchinese.dat")
if __name__ == "__main__":
    svm_model = SVM(C=1, gamma=0.5)
    model_1, model_2 = svm_model.train_svm()
    # Persist the classifiers so later runs load them instead of retraining.
    # This must come after train_svm(): before it, svm_model holds only an
    # untrained model and no Chinese classifier.
    os.makedirs("./train_dat", exist_ok=True)
    svm_model.save_traindata()
    print(model_1)