import os

import cv2
import numpy as np

# Directory holding the training images (pos-N.pgm / neg-N.pgm, see path()).
datapath: str = "./CarData/TrainImages/"
# Number of sample indices read per class when training car_detector().
SAMPLES: int = 200

def path(cls, i, base_dir=None):
  """Return the file name of the ``i``-th training image of class ``cls``.

  Files on disk are numbered from 1 (``pos-1.pgm``, ``pos-2.pgm``, ...), so
  the 0-based index ``i`` is shifted by one.

  Args:
    cls: file-name prefix of the class, e.g. ``"pos-"`` or ``"neg-"``.
    i: 0-based sample index.
    base_dir: directory containing the images; defaults to the module-level
      ``datapath`` (looked up at call time).

  Returns:
    The image path as a string.
  """
  if base_dir is None:
    base_dir = datapath
  # os.path.join inserts a separator only when needed. The old "%s/%s" format
  # doubled it ("TrainImages//pos-1.pgm") because datapath already ends in
  # "/". It also turned an empty directory into the absolute path "/pos-1.pgm".
  return os.path.join(base_dir, "%s%d.pgm" % (cls, i + 1))

def get_flann_matcher():
  """Build a FLANN-based descriptor matcher using a randomized k-d tree index.

  Search parameters are passed empty, so FLANN's own defaults apply.
  """
  kdtree_index = 1  # FLANN algorithm id for the (randomized) k-d tree index
  index_params = {"algorithm": kdtree_index, "trees": 5}
  search_params = {}
  return cv2.FlannBasedMatcher(index_params, search_params)

def get_bow_extractor(extract, match):
  """Combine a descriptor extractor and a matcher into a BOW image encoder.

  Args:
    extract: descriptor extractor used to describe keypoints.
    match: descriptor matcher used to assign descriptors to vocabulary words.

  Returns:
    A cv2.BOWImgDescriptorExtractor. Its vocabulary is not set here; the
    caller must set it with setVocabulary() before computing histograms.
  """
  bow_extractor = cv2.BOWImgDescriptorExtractor(extract, match)
  return bow_extractor

def get_extract_detect():
  """Create a pair of SIFT instances: (descriptor extractor, keypoint detector).

  SIFT moved from the contrib module ``cv2.xfeatures2d`` into the main ``cv2``
  namespace in OpenCV 4.4.0. Plain opencv-python builds have no
  ``xfeatures2d`` at all, so prefer ``cv2.SIFT_create``. Fall back to the
  contrib factory only on older builds.

  Returns:
    tuple: two independent SIFT objects (extractor, detector).
  """
  sift_create = getattr(cv2, "SIFT_create", None)
  if sift_create is None:
    # OpenCV < 4.4: SIFT is only available from opencv-contrib.
    sift_create = cv2.xfeatures2d.SIFT_create
  return sift_create(), sift_create()

def extract_sift(fn, extractor, detector):
  """Return the local descriptors of the grayscale image stored at ``fn``.

  Keypoints are found with ``detector.detect`` and described with
  ``extractor.compute``. Both are SIFT instances in this module.

  Args:
    fn: path of the image file.
    extractor: object providing ``compute(image, keypoints)``.
    detector: object providing ``detect(image)``.

  Returns:
    The descriptor array, i.e. element [1] of ``extractor.compute``'s
    result. NOTE(review): presumably None when no keypoints are found;
    callers should check.

  Raises:
    OSError: if ``fn`` cannot be read as an image.
  """
  im = cv2.imread(fn, 0)  # flag 0 = load as single-channel grayscale
  # cv2.imread reports failure by returning None instead of raising. Catch it
  # here so the error names the offending file, rather than surfacing later as
  # an opaque cv2.error from detect().
  if im is None:
    raise OSError("could not read image: %s" % fn)
  return extractor.compute(im, detector.detect(im))[1]
    
def bow_features(img, extractor_bow, detector):
  """Encode ``img`` as a bag-of-visual-words histogram.

  Args:
    img: image to encode.
    extractor_bow: BOW descriptor extractor whose vocabulary is already set.
    detector: object providing ``detect(image)`` that returns keypoints.

  Returns:
    Whatever ``extractor_bow.compute(img, keypoints)`` returns.
  """
  keypoints = detector.detect(img)
  return extractor_bow.compute(img, keypoints)

def car_detector(samples=None, cluster_count=12, skip_indices=(129,)):
  """Train a car / non-car SVM on bag-of-visual-words SIFT histograms.

  Pipeline:
    1. SIFT descriptors of the positive ("pos-") and negative ("neg-")
       training images are clustered by cv2.BOWKMeansTrainer into a visual
       vocabulary.
    2. Each image is encoded as a BOW histogram over that vocabulary.
    3. A cv2.ml SVM is trained on the histograms, with label 1 for
       positives and -1 for negatives.

  Args:
    samples: number of sample indices to read for each class; defaults to
      the module-level SAMPLES, looked up at call time.
    cluster_count: vocabulary size (number of k-means clusters).
    skip_indices: sample indices to leave out entirely, for both classes.
      129 is skipped by default because the original author observed bad
      data at that index (root cause not recorded).

  Returns:
    (svm, extract_bow): the trained cv2.ml SVM and the BOW extractor whose
    vocabulary was set here. New images must be encoded with this same
    extractor before being passed to svm.predict.

  Raises:
    ValueError: if no sample pair yielded SIFT descriptors.
  """
  if samples is None:
    samples = SAMPLES
  skip = set(skip_indices)
  pos, neg = "pos-", "neg-"
  extract, detect = get_extract_detect()
  flann = get_flann_matcher()

  print ("building BOWKMeansTrainer...")
  bow_kmeans_trainer = cv2.BOWKMeansTrainer(cluster_count)
  extract_bow = get_bow_extractor(extract, flann)

  print ("adding features to trainer")
  # Indices whose pos AND neg images both produced descriptors. The training
  # pass below reuses exactly this list, so the classes stay balanced and
  # both passes see the same samples.
  used = []
  for i in range(samples):
    if i in skip:
      continue
    descriptors = [extract_sift(path(cls, i), extract, detect) for cls in (pos, neg)]
    # An image with no keypoints yields no descriptors (None), which the
    # k-means trainer cannot accept. This is likely what went wrong at index
    # 129 (unconfirmed). Drop the whole pair instead of crashing.
    if any(d is None for d in descriptors):
      print ("skipping sample %d: no SIFT descriptors" % i)
      continue
    for d in descriptors:
      bow_kmeans_trainer.add(d)
    used.append(i)

  if not used:
    raise ValueError("no training sample produced SIFT descriptors")

  vocabulary = bow_kmeans_trainer.cluster()
  extract_bow.setVocabulary(vocabulary)

  traindata, trainlabels = [], []
  print ("adding to train data")
  for i in used:
    # 1 marks a positive (car) sample, -1 a negative one. To train more
    # classes, add further (prefix, label) pairs here with a distinct integer
    # label per class, e.g. ("classN-", N).
    for cls, label in ((pos, 1), (neg, -1)):
      # compute() yields a (1, cluster_count) histogram; extend() appends it
      # as one row.
      traindata.extend(bow_features(cv2.imread(path(cls, i), 0), extract_bow, detect))
      trainlabels.append(label)

  svm = cv2.ml.SVM_create()
  # The SVM is trained with OpenCV's default parameters (per the ml::SVM
  # documentation: a C_SVC with an RBF kernel). Knobs worth tuning:
  #   svm.setType(cv2.ml.SVM_C_SVC)
  #   svm.setKernel(cv2.ml.SVM_RBF)
  #     SVM_LINEAR separates the classes with a hyperplane and is often
  #     enough for two classes. SVM_RBF uses a Gaussian kernel and gives a
  #     non-linear decision boundary.
  #   svm.setGamma(1)
  #     RBF kernel width: a larger gamma means a narrower kernel and a more
  #     complex boundary, with a higher risk of overfitting. The number of
  #     support vectors affects training and prediction speed. Background:
  #     https://blog.csdn.net/wusecaiyun/article/details/49681431
  #   svm.setC(35)
  #     Misclassification penalty: a larger C fits the training set more
  #     tightly (risk of overfitting); too small a C underfits.

  # cv2.ml expects float32 samples and int32 class responses.
  svm.train(np.array(traindata, dtype=np.float32), cv2.ml.ROW_SAMPLE,
            np.array(trainlabels, dtype=np.int32))
  return svm, extract_bow
