OpenCV3-Python下ANN进行MNIST数字识别
admin 于 2018年10月08日 发表在 计算机视觉
1. MNIST数据库
MNIST数据库是Web上非常流行的OCR和手写字符识别分类器的训练资源。
下载mnist.pkl.gz数据集,并将其放置到与.py文件同级目录下。
2. 初始化参数
(1)输入层
由于采用MNIST数据库,它里面的每幅图像大小为28x28像素,即784像素,因此输入层有784个输入节点。
(2)隐藏层
隐藏层大小没有固定,通过多次尝试发现,在训练数据量不大的情况下,50-60个节点可得到最好结果。
(3)输出层
输出层为0-9的数字,共10个节点。
3. 封装ANN库
为了尽可能自动化地执行,此处建立了一个迷你库,用来封装OpenCV中ANN的原始实现。
注意:pickle是Python标准库自带的序列化模块,无需额外安装;此处用它来反序列化mnist.pkl.gz中的MNIST数据。
digits_ann.py内容如下:
"""OpenCV ANN handwritten digit recognition example.

Wraps OpenCV's own ANN by automating the loading of data and supplying
default parameters, such as 20 hidden nodes, 10000 samples and 1 training
epoch.

The load data code is taken from
http://neuralnetworksanddeeplearning.com/chap1.html by Michael Nielsen.
"""
import gzip
import pickle

import cv2
import numpy as np


def vectorized_result(j):
    """Return a (10, 1) one-hot column vector with 1.0 at index ``j``."""
    e = np.zeros((10, 1))
    e[j] = 1.0
    return e


def load_data():
    """Unpickle the three data splits stored in ``./mnist.pkl.gz``.

    Returns:
        The ``(training_data, valid_data, test_data)`` tuple read from the
        archive.
    """
    # The ``with`` statement closes the file, so no explicit close() call
    # is needed.
    with gzip.open('./mnist.pkl.gz') as fp:
        # Depending on the Python version, ``encoding='bytes'`` has to be
        # passed as the second argument, otherwise unpickling fails with an
        # encoding error.
        # NOTE: pickle can execute arbitrary code while loading; only use an
        # archive obtained from a trusted source.
        training_data, valid_data, test_data = pickle.load(
            fp, encoding='bytes')
    return (training_data, valid_data, test_data)


def wrap_data():
    """Reshape the raw splits into lists of ``(input, label)`` pairs.

    Every input becomes a (784, 1) column vector.  Training labels are
    converted to one-hot (10, 1) vectors via ``vectorized_result``;
    validation and test labels are kept as loaded.

    Returns:
        ``(training_data, validation_data, test_data)``, each a list of
        2-tuples.
    """
    # Per the original author's note the splits hold 50000, 10000 and 10000
    # entries respectively.
    tr_d, va_d, te_d = load_data()

    # Training set.
    training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
    training_results = [vectorized_result(y) for y in tr_d[1]]
    training_data = list(zip(training_inputs, training_results))

    # Validation set.
    validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
    validation_data = list(zip(validation_inputs, va_d[1]))

    # Test set.
    test_inputs = [np.reshape(x, (784, 1)) for x in te_d[0]]
    test_data = list(zip(test_inputs, te_d[1]))

    return (training_data, validation_data, test_data)


def create_ANN(hidden=20):
    """Create an untrained 784-``hidden``-10 multi-layer perceptron.

    Args:
        hidden: Number of nodes in the single hidden layer.

    Returns:
        The configured ``cv2.ml.ANN_MLP`` instance.
    """
    ann = cv2.ml.ANN_MLP_create()
    # Layer sizes: 784 input nodes, ``hidden`` hidden nodes, 10 output nodes.
    # NOTE(review): configured first because some OpenCV 3.x releases appear
    # to reset activation-related state inside setLayerSizes() - confirm
    # against the installed version.
    ann.setLayerSizes(np.array([784, hidden, 10]))
    # Resilient propagation.  ANN_MLP_UPDATE_WEIGHTS is a flag for train(),
    # not a training method, so it is no longer OR-ed in here.
    ann.setTrainMethod(cv2.ml.ANN_MLP_RPROP)
    # Symmetric sigmoid; the alternatives are cv2.ml.ANN_MLP_IDENTITY and
    # cv2.ml.ANN_MLP_GAUSSIAN.
    ann.setActivationFunction(cv2.ml.ANN_MLP_SIGMOID_SYM)
    # Termination criteria: at most 100 iterations or an epsilon of 0.1.
    ann.setTermCriteria(
        (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 100, 0.1))
    return ann


def train(ann, samples=10000, epochs=1):
    """Train ``ann`` on the first ``samples`` training records.

    The selected records are stacked into one matrix and passed to OpenCV as
    a single batch per epoch.  Training one record per ``ann.train`` call
    without the ``ANN_MLP_UPDATE_WEIGHTS`` flag recomputes the weights from
    scratch on every call, so only the last record would be learned.

    Args:
        ann: Network returned by ``create_ANN``.
        samples: Maximum number of training records to use.
        epochs: Number of passes over the selected records.

    Returns:
        ``(ann, test)`` where ``test`` is the test split from ``wrap_data``.
    """
    # tr: training set; _val: validation set (unused); test: test set.
    tr, _val, test = wrap_data()

    count = min(samples, len(tr))
    if count <= 0:
        # Nothing to train on; hand the network back untouched.
        return ann, test

    batch = tr[:count]
    inputs = np.array([data.ravel() for data, _ in batch], dtype=np.float32)
    targets = np.array([digit.ravel() for _, digit in batch],
                       dtype=np.float32)
    train_data = cv2.ml.TrainData_create(inputs, cv2.ml.ROW_SAMPLE, targets)

    for x in range(epochs):
        # The first pass initialises the weights; later passes must refine
        # them instead of starting over.
        flags = 0 if x == 0 else cv2.ml.ANN_MLP_UPDATE_WEIGHTS
        ann.train(train_data, flags)
        print("Epoch %d: Trained %d/%d" % (x, count, samples))
        print("Epoch %d complete" % x)
    return ann, test


def predict(ann, sample):
    """Classify a single 2-D grayscale digit image.

    The image is resized to 28x28 whenever either dimension differs from 28,
    then flattened into one float32 row of 784 values.

    NOTE(review): the pixel range of ``sample`` presumably has to match the
    range of the training inputs - verify against the caller.

    Args:
        ann: Trained network.
        sample: 2-D numpy array.

    Returns:
        Whatever ``ann.predict`` returns for the single-row input.
    """
    rows, cols = sample.shape
    resized = sample
    # ``or`` instead of the former ``and``: a 28xN or Nx28 image also needs
    # resizing, otherwise the flattened vector is not 784 long.
    if (rows != 28 or cols != 28) and rows * cols > 0:
        resized = cv2.resize(sample, (28, 28), interpolation=cv2.INTER_CUBIC)
    return ann.predict(np.array([resized.ravel()], dtype=np.float32))
4. ANN库调用方法
import cv2
import numpy as np

import digits_ann as ANN


def inside(r1, r2):
    """Return True when rectangle ``r1`` lies strictly inside ``r2``.

    Both rectangles are ``(x, y, w, h)`` tuples.
    """
    x1, y1, w1, h1 = r1
    x2, y2, w2, h2 = r2
    return (x1 > x2) and (y1 > y2) and (x1 + w1 < x2 + w2) and (y1 + h1 < y2 + h2)


def wrap_digit(rect):
    """Turn a bounding rectangle into a padded square around its centre.

    The shorter side is grown to match the longer one, then the origin is
    moved up/left by the padding and the side length is increased by it.

    Args:
        rect: ``(x, y, w, h)`` bounding rectangle.

    Returns:
        ``(x, y, w, h)`` of the padded square as ints; ``x`` and ``y`` may be
        negative for rectangles close to the image border.
    """
    x, y, w, h = rect
    padding = 5
    hcenter = x + w / 2
    vcenter = y + h / 2
    if h > w:
        w = h
        x = hcenter - (w / 2)
    else:
        h = w
        y = vcenter - (h / 2)
    return (int(x - padding), int(y - padding), int(w + padding), int(h + padding))


# NOTE: for a first test it is advisable to use the complete training set and
# to iterate several times until convergence, e.g.:
#     ann, test_data = ANN.train(ANN.create_ANN(100), 50000, 30)
ann, test_data = ANN.train(ANN.create_ANN(10), 50000, 1)

# Load the picture to recognise and pre-process it.
path = "./3.png"
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if img is None:
    # cv2.imread signals failure by returning None instead of raising.
    raise FileNotFoundError("cannot read image: %s" % path)

if img.ndim == 2:
    # IMREAD_UNCHANGED keeps single-channel files as-is; BGR2GRAY would
    # reject them, and the coloured overlay below needs three channels.
    bw = img
    img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
else:
    bw = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
bw = cv2.GaussianBlur(bw, (7, 7), 0)
ret, thbw = cv2.threshold(bw, 127, 255, cv2.THRESH_BINARY_INV)
thbw = cv2.erode(thbw, np.ones((2, 2), np.uint8), iterations=2)

# OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
# (contours, hierarchy); the last two items are the same in both.
cntrs, hier = cv2.findContours(
    thbw.copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]

# Area used to recognise and skip a contour that spans the whole picture.
b = (img.shape[0] - 3) * (img.shape[1] - 3)

rectangles = []
for c in cntrs:
    r = x, y, w, h = cv2.boundingRect(c)
    a = cv2.contourArea(c)
    # Keep only rectangles that are not nested in an already accepted one.
    if any(inside(r, q) for q in rectangles):
        continue
    if a != b:
        rectangles.append(r)

for r in rectangles:
    x, y, w, h = wrap_digit(r)
    cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2)
    # Clamp the origin: a negative start index would make numpy count from
    # the far edge and select the wrong region.
    roi = thbw[max(y, 0):y + h, max(x, 0):x + w]
    try:
        digit_class = ANN.predict(ann, roi)[0]
    except (cv2.error, ValueError):
        # Best effort: skip regions the network cannot process (for example
        # an empty ROI) and carry on with the remaining rectangles.
        print("except")
        continue
    cv2.putText(img, "%d" % digit_class, (x, y - 1),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0))

cv2.imshow("thbw", thbw)
cv2.imshow("contours", img)
cv2.waitKey()
cv2.destroyAllWindows()
5. 识别结果
除调用自带的ANN接口,也可自己尝试从零开始写ANN,可参考Michael Nielsen的文章。