简单图像识别第一步:减色化+柱状图
这里我们进行简单的图像识别。
图像识别是识别图像中物体的类别(它属于哪个类)的任务。图像识别通常被称为Classification、Categorization、Clustering等。
一种常见的方法是通过 HOG、SIFT、SURF 等方法从图像中提取一些特征,并通过特征确定物体类别。这种方法在CNN普及之前广泛采用,但CNN可以完成从特征提取到分类等一系列任务。
这里,利用图像的颜色直方图来执行简单的图像识别。算法如下:
- 将图像
train_***.jpg
进行减色处理(\text{RGB}取4种值)。 - 创建减色图像的直方图。直方图中,\text{RGB}分别取四个值,但为了区分它们,B = [1,4]、G = [5,8]、R = [9,12],这样bin=12。请注意,我们还需要为每个图像保存相应的柱状图。也就是说,需要将数据储存在
database = np.zeros((10(训练数据集数), 13(RGB + class), dtype=np.int)
中。 - 将步骤2中计算得到的柱状图记为 database。
- 计算想要识别的图像
test@@@.jpg
与直方图之间的差,将差称作特征量。 - 直方图的差异的总和是最小图像是预测的类别。换句话说,它被认为与近色图像属于同一类。
- 计算将想要识别的图像(
test_@@@.jpg
)的柱状图(与train_***.jpg
的柱状图)的差,将这个差作为特征量。 - 统计柱状图的差,差最小的图像为预测的类别。换句话说,可以认为待识别图像与具有相似颜色的图像属于同一类。
在这里,实现步骤1至步骤3并可视化柱状图。
训练数据集存放在文件夹dataset
中,分为trainakahara@@@.jpg
(类别1)和trainmadara@@@.jpg
(类别2)两类,共计10张。akahara
是红腹蝾螈(Cynops pyrrhogaster),madara
是理纹欧螈(Triturus marmoratus)。
train_akahara_1.jpg:
train_akahara_2.jpg:
train_akahara_3.jpg:
train_akahara_4.jpg:
train_akahara_5.jpg:
train_madara_1.jpg
train_madara_2.jpg
train_madara_3.jpg
train_madara_4.jpg
train_madara_5.jpg
这种预先将特征量存储在数据库中的方法是第一代人工智能方法。这个想法是逻辑是,如果你预先记住整个模式,那么在识别的时候就没有问题。但是,这样做会消耗大量内存,这是一种有局限的方法。
输出 |
---|
python实现:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
## Dicrease color
def dic_color(img):
img //= 63
img = img * 64 + 32
return img
## Database
def get_DB():
# get image paths
train = glob("dataset/train_*")
train.sort()
# prepare database
db = np.zeros((len(train), 13), dtype=np.int32)
# each image
for i, path in enumerate(train):
img = dic_color(cv2.imread(path))
# get histogram
for j in range(4):
db[i, j] = len(np.where(img[..., 0] == (64 * j + 32))[0])
db[i, j+4] = len(np.where(img[..., 1] == (64 * j + 32))[0])
db[i, j+8] = len(np.where(img[..., 2] == (64 * j + 32))[0])
# get class
if 'akahara' in path:
cls = 0
elif 'madara' in path:
cls = 1
# store class label
db[i, -1] = cls
img_h = img.copy() // 64
img_h[..., 1] += 4
img_h[..., 2] += 8
plt.subplot(2, 5, i+1)
plt.hist(img_h.ravel(), bins=12, rwidth=0.8)
plt.title(path)
print(db)
plt.show()
# get database
get_DB()
被存储的直方图的内容
[[ 172 12254 2983 975 485 11576 3395 928 387 10090 4845 1062 0]
[ 3627 7350 4420 987 1743 8438 4651 1552 848 9089 4979 1468 0]
[ 1646 6547 5807 2384 1715 8502 5233 934 1553 5270 7167 2394 0]
[ 749 10142 5465 28 1431 7922 7001 30 1492 7819 7024 49 0]
[ 927 4197 8581 2679 669 5689 7959 2067 506 3973 6387 5518 0]
[ 2821 6404 2540 4619 1625 7317 3019 4423 225 8635 1591 5933 1]
[ 5575 7831 1619 1359 4638 6777 3553 1416 4675 7964 2176 1569 1]
[ 4867 7523 3275 719 4457 6390 3049 2488 4328 7135 3377 1544 1]
[ 7881 6160 1992 351 7426 3967 4258 733 7359 4979 3322 724 1]
[ 5638 6580 3916 250 5041 4185 6286 872 5226 4930 5552 676 1]]