"""Train the SVM classifiers used to recognise characters.

Two classifiers are trained (or loaded from ``./train_dat`` when a saved
model already exists): one for single-character classes such as letters and
digits, and one for Chinese province characters.
"""
import json
import os
import sys

import cv2
import numpy as np
from numpy.linalg import norm

# Side length, in pixels, of the square image produced by `deskew`.
SZ = 20
# Offset added to an index into `provinces` to build a Chinese class label.
PROVINCE_START = 1000
# Alternating entries: pinyin folder name, then the matching Chinese character.
provinces = [
    "zh_cuan", "川",
    "zh_e", "鄂",
    "zh_gan", "赣",
    "zh_gan1", "甘",
    "zh_gui", "贵",
    "zh_gui1", "桂",
    "zh_hei", "黑",
    "zh_hu", "沪",
    "zh_ji", "冀",
    "zh_jin", "津",
    "zh_jing", "京",
    "zh_jl", "吉",
    "zh_liao", "辽",
    "zh_lu", "鲁",
    "zh_meng", "蒙",
    "zh_min", "闽",
    "zh_ning", "宁",
    "zh_qing", "靑",
    "zh_qiong", "琼",
    "zh_shan", "陕",
    "zh_su", "苏",
    "zh_sx", "晋",
    "zh_wan", "皖",
    "zh_xiang", "湘",
    "zh_xin", "新",
    "zh_yu", "豫",
    "zh_yu1", "渝",
    "zh_yue", "粤",
    "zh_yun", "云",
    "zh_zang", "藏",
    "zh_zhe", "浙",
]


# Data processing
def deskew(img):
    """Return *img* with its horizontal skew removed.

    The skew is estimated from the central image moments as ``mu11 / mu02``
    and cancelled with a horizontal shear.  When ``|mu02|`` is below ``1e-2``
    the estimate is unusable and an unmodified copy of *img* is returned.
    Otherwise the result is an ``SZ`` x ``SZ`` image.
    """
    m = cv2.moments(img)
    if abs(m['mu02']) < 1e-2:
        return img.copy()
    skew = m['mu11'] / m['mu02']
    # Shear along x; the translation term keeps the shear centred on the
    # middle row of an SZ-pixel-high image.
    M = np.float32([[1, skew, -0.5 * SZ * skew], [0, 1, 0]])
    return cv2.warpAffine(
        img, M, (SZ, SZ), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR
    )


# Feature engineering
def preprocess_hog(digits):
    """Compute a gradient-orientation histogram descriptor for each image.

    For every image the Sobel gradients are converted to magnitude and angle,
    the angle is quantised into 16 bins, and a magnitude-weighted histogram
    is built for each of the four image quadrants.  The concatenated
    histograms are then mapped through the Hellinger-kernel transform
    (L1-normalise, square root, L2-normalise).

    Args:
        digits: iterable of single-channel images (assumed to be
            ``SZ`` x ``SZ``, as produced by `deskew` -- confirm for other
            callers).

    Returns:
        A ``numpy.float32`` array with one descriptor row per input image.
    """
    bin_n = 16
    half = SZ // 2  # quadrant boundary (10 for the 20-pixel images)
    eps = 1e-7  # guards the normalisations against division by zero
    samples = []
    for img in digits:
        gx = cv2.Sobel(img, cv2.CV_32F, 1, 0)
        gy = cv2.Sobel(img, cv2.CV_32F, 0, 1)
        mag, ang = cv2.cartToPolar(gx, gy)
        bins = np.int32(bin_n * ang / (2 * np.pi))
        quadrants = (
            (slice(None, half), slice(None, half)),
            (slice(half, None), slice(None, half)),
            (slice(None, half), slice(half, None)),
            (slice(half, None), slice(half, None)),
        )
        hists = [
            np.bincount(bins[q].ravel(), mag[q].ravel(), bin_n)
            for q in quadrants
        ]
        hist = np.hstack(hists)

        # Transform to Hellinger kernel.
        hist /= hist.sum() + eps
        hist = np.sqrt(hist)
        hist /= norm(hist) + eps

        samples.append(hist)
    return np.float32(samples)


def _load_gray_images(folder, filenames):
    """Read each readable file of *folder* as a grayscale image."""
    images = []
    for filename in filenames:
        filepath = os.path.join(folder, filename)
        image = cv2.imread(filepath)
        if image is None:
            # cv2.imread returns None for files it cannot decode (for
            # example stray non-image files); skip them rather than crash
            # in cvtColor.
            print("skipping unreadable image:", filepath, file=sys.stderr)
            continue
        images.append(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY))
    return images


def _fit_from_images(model, images, labels, source):
    """Deskew *images*, extract descriptors and train *model* on them."""
    if not images:
        raise ValueError("no training images found under " + source)
    features = preprocess_hog([deskew(image) for image in images])
    responses = np.array(labels)
    print(features.shape)
    model.train(features, responses)


class StatModel(object):
    """Thin load/save wrapper around the object stored in ``self.model``."""

    def load(self, fn):
        """Replace ``self.model`` with the model loaded from file *fn*."""
        self.model = self.model.load(fn)

    def save(self, fn):
        """Write ``self.model`` to file *fn*."""
        self.model.save(fn)


class SVM(StatModel):
    """RBF-kernel C-SVC classifier backed by ``cv2.ml.SVM``."""

    def __init__(self, C=1, gamma=0.5):
        """Create the underlying OpenCV SVM with the given *C* and *gamma*."""
        self.model = cv2.ml.SVM_create()
        self.model.setGamma(gamma)
        self.model.setC(C)
        self.model.setKernel(cv2.ml.SVM_RBF)
        self.model.setType(cv2.ml.SVM_C_SVC)

    # train svm
    def train(self, samples, responses):
        """Train on *samples* (one sample per row) with labels *responses*."""
        self.model.train(samples, cv2.ml.ROW_SAMPLE, responses)

    # inference
    def predict(self, samples):
        """Return the flattened second element of the model's predict result."""
        r = self.model.predict(samples)
        return r[1].ravel()

    def train_svm(self):
        """Load or train both classifiers and return them.

        For each classifier, a saved model under ``./train_dat`` is loaded
        when present; otherwise the classifier is trained from the images
        under ``./train/chars2`` (single-character folder names, labelled
        with ``ord`` of the name) or ``./train/charsChinese`` (``zh_*``
        folder names, labelled via `provinces`).

        NOTE(review): this replaces ``self.model`` (an OpenCV model after
        ``__init__``) with an ``SVM`` wrapper, so ``self.predict`` on this
        instance no longer talks to an OpenCV model directly afterwards.
        Kept as-is because callers receive the wrappers.

        Returns:
            Tuple ``(self.model, self.modelchinese)`` of ``SVM`` instances.

        Raises:
            ValueError: if a classifier must be trained but no readable
                training image is found, or a ``zh_*`` folder name is not
                listed in `provinces`.
        """
        # Recognises English letters and digits.
        self.model = SVM(C=1, gamma=0.5)
        # Recognises Chinese characters.
        self.modelchinese = SVM(C=1, gamma=0.5)

        if os.path.exists("./train_dat/svm.dat"):
            self.model.load("./train_dat/svm.dat")
        else:
            chars_train = []
            chars_label = []
            for root, dirs, files in os.walk("./train/chars2"):
                folder = os.path.basename(root)
                # Only folders named with exactly one character hold samples.
                if len(folder) != 1:
                    continue
                root_int = ord(folder)
                images = _load_gray_images(root, files)
                chars_train.extend(images)
                chars_label.extend([root_int] * len(images))
            _fit_from_images(
                self.model, chars_train, chars_label, "./train/chars2"
            )

        if os.path.exists("./train_dat/svmchinese.dat"):
            self.modelchinese.load("./train_dat/svmchinese.dat")
        else:
            chars_train = []
            chars_label = []
            for root, dirs, files in os.walk("./train/charsChinese"):
                pinyin = os.path.basename(root)
                if not pinyin.startswith("zh_"):
                    continue
                # The +1 selects the Chinese character that follows the
                # pinyin entry in `provinces`.
                index = provinces.index(pinyin) + PROVINCE_START + 1
                images = _load_gray_images(root, files)
                chars_train.extend(images)
                chars_label.extend([index] * len(images))
            _fit_from_images(
                self.modelchinese,
                chars_train,
                chars_label,
                "./train/charsChinese",
            )

        return self.model, self.modelchinese

    def save_traindata(self):
        """Save each classifier to ``./train_dat`` unless its file exists.

        Requires `train_svm` to have run first, since that is where
        ``self.modelchinese`` is created.
        """
        # Saving fails when the target directory is missing, so create it.
        os.makedirs("./train_dat", exist_ok=True)
        if not os.path.exists("./train_dat/svm.dat"):
            self.model.save("./train_dat/svm.dat")
        if not os.path.exists("./train_dat/svmchinese.dat"):
            self.modelchinese.save("./train_dat/svmchinese.dat")


if __name__ == "__main__":
    svm_model = SVM(C=1, gamma=0.5)
    # Saving is left disabled, as in the original script; call
    # svm_model.save_traindata() after train_svm() to persist the models.
    model_1, model_2 = svm_model.train_svm()
    print(model_1)