Simple Image Recognition, Step 4: k-NN
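In Simple Image Recognition with OpenCV, Step 2, we predicted the class of each test image from the single training image with the closest colors. With that approach, however, the image closest to test_madara_2.jpg turned out to be train_akahara_2.jpg.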
| test_madara_2.jpg | train_akahara_2.jpg |
|---|---|
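Comparing the two images, their proportions of green and black look about the same, so the overall color of the two images appears identical. The misclassification happens because the training image selected during recognition is one that deviates from most of its class. In other words, the features of the training dataset are not cleanly separated, and the dataset sometimes contains samples that lie away from the feature distribution of their class.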
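To avoid this, here we select the three training images with the closest colors, predict the final class by a vote among them, and then compute the accuracy.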
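Selecting the k training samples with the most similar features and voting among them is called the k-nearest neighbor algorithm (k-NN: k-Nearest Neighbor); here k = 3. The NN method in Simple Image Recognition with OpenCV, Step 2 is the special case k = 1.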
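The voting idea can be sketched on toy data before applying it to images. The snippet below is a minimal illustration, not the article's implementation: the function name `knn_predict` and the 2-D feature values are hypothetical, and the distance is assumed to be the sum of absolute differences (L1). One class-0 sample is deliberately placed among the class-1 samples to mimic an outlier training image.

```python
import numpy as np

def knn_predict(train_feats, train_labels, query, k=3):
    """Predict the label of `query` by majority vote among its k nearest samples."""
    # L1 (sum of absolute differences) distance to every training sample
    dists = np.abs(train_feats - query).sum(axis=1)
    # indices of the k closest training samples
    nearest = np.argsort(dists)[:k]
    # count the votes per class and return the most frequent label
    return int(np.bincount(train_labels[nearest]).argmax())

# Toy 2-D features (hypothetical values, not taken from the dataset).
train_feats = np.array([
    [10, 0], [12, 1], [11, 0],   # class 0, typical samples
    [2, 10],                     # class 0, outlier lying among class 1
    [0, 10], [1, 12], [3, 11],   # class 1
])
train_labels = np.array([0, 0, 0, 0, 1, 1, 1])
query = np.array([2, 9])

print(knn_predict(train_feats, train_labels, query, k=1))  # 0: the outlier wins
print(knn_predict(train_feats, train_labels, query, k=3))  # 1: the vote overrides it
```

With k = 1 the single outlier decides the result, while with k = 3 the two genuine class-1 neighbors outvote it. Note that `argmax` breaks ties in favor of the smaller label, so an odd k is the safer choice for two classes.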
Python implementation:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
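# Decrease color: quantize each channel to 4 levels (32, 96, 160, 224)
def dic_color(img):
    # divide by 64 (not 63) so that values 252-255 stay in the top bin
    img //= 64
    img = img * 64 + 32
    return img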
# Database
def get_DB():
    # get training image paths
    train = glob("dataset/train_*")
    train.sort()

    # prepare database: 12 histogram bins + 1 class label per image
    db = np.zeros((len(train), 13), dtype=np.int32)
    pdb = []

    # each training image
    for i, path in enumerate(train):
        # read image and reduce its colors
        img = dic_color(cv2.imread(path))

        # histogram: 4 bins for each of the 3 color channels
        for j in range(4):
            db[i, j] = len(np.where(img[..., 0] == (64 * j + 32))[0])
            db[i, j+4] = len(np.where(img[..., 1] == (64 * j + 32))[0])
            db[i, j+8] = len(np.where(img[..., 2] == (64 * j + 32))[0])

        # get class from the file name
        if 'akahara' in path:
            cls = 0
        elif 'madara' in path:
            cls = 1

        # store class label
        db[i, -1] = cls

        # add image path
        pdb.append(path)

    return db, pdb
# test
def test_DB(db, pdb, N=3):
    # get test image paths
    test = glob("dataset/test_*")
    test.sort()

    accuracy_N = 0.

    # each test image
    for path in test:
        # read image and reduce its colors
        img = dic_color(cv2.imread(path))

        # get histogram
        hist = np.zeros(12, dtype=np.int32)
        for j in range(4):
            hist[j] = len(np.where(img[..., 0] == (64 * j + 32))[0])
            hist[j+4] = len(np.where(img[..., 1] == (64 * j + 32))[0])
            hist[j+8] = len(np.where(img[..., 2] == (64 * j + 32))[0])

        # get histogram difference (sum of absolute differences)
        difs = np.abs(db[:, :12] - hist)
        difs = np.sum(difs, axis=1)

        # get indices of the N nearest training images
        pred_i = np.argsort(difs)[:N]

        # class labels of the N nearest training images
        pred = db[pred_i, -1]

        # majority vote to get the class label
        if len(pred[pred == 0]) > len(pred[pred == 1]):
            pl = "akahara"
        else:
            pl = 'madara'

        print(path, "is similar >> ", end='')
        for i in pred_i:
            print(pdb[i], end=', ')
        print("|Pred >>", pl)

        # count correct predictions
        gt = "akahara" if "akahara" in path else "madara"
        if gt == pl:
            accuracy_N += 1.

    accuracy = accuracy_N / len(test)
    print("Accuracy >>", accuracy, "({}/{})".format(int(accuracy_N), len(test)))
db, pdb = get_DB()
test_DB(db, pdb)
Answer:
test_akahara_1.jpg is similar >> train_akahara_3.jpg, train_akahara_2.jpg, train_akahara_4.jpg, |Pred >> akahara
test_akahara_2.jpg is similar >> train_akahara_1.jpg, train_akahara_2.jpg, train_akahara_4.jpg, |Pred >> akahara
test_madara_1.jpg is similar >> train_madara_2.jpg, train_madara_4.jpg, train_madara_3.jpg, |Pred >> madara
test_madara_2.jpg is similar >> train_akahara_2.jpg, train_madara_3.jpg, train_madara_2.jpg, |Pred >> madara
Accuracy >> 1.0 (4/4)
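The 12-bin feature that `get_DB` and `test_DB` build with explicit loops can also be computed in one helper. The following is a sketch under assumptions: the name `color_histogram` is hypothetical, the input is assumed to be a `uint8` array of shape (H, W, 3) as returned by `cv2.imread`, and the quantization uses four equal-width bins per channel (integer division by 64). It runs on a synthetic array, so no dataset files are needed.

```python
import numpy as np

def color_histogram(img):
    """Return a 12-bin feature: 4 quantized levels for each of the 3 channels."""
    # map 0..255 to bin indices 0..3 without modifying the input image
    levels = (img // 64).astype(np.intp)
    # count pixels per bin, channel by channel (same ordering as db[:, :12])
    hists = [np.bincount(levels[..., c].ravel(), minlength=4) for c in range(3)]
    return np.concatenate(hists).astype(np.int32)

# Synthetic 2x2 image (hypothetical pixel values) to check the bin counts.
img = np.array([[[0, 64, 128], [63, 127, 255]],
                [[10, 70, 200], [250, 20, 30]]], dtype=np.uint8)
feat = color_histogram(img)
print(feat)
```

Because the helper counts bin indices directly, it does not need the intermediate values 32, 96, 160 and 224, and it leaves the input image unchanged.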