CIFAR-10 Nearest Neighbor Classification in Python 3 (NearestNeighbor)
Published: 2021-11-21 04:41:19  Category: Technical article

Dataset used:

http://www.cs.toronto.edu/~kriz/cifar.html (download the Python version yourself)

This follows the notes from Stanford's CS231n course.

The implementation is as follows:

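Note that the dataset is fairly large. To check that the program runs correctly, only the first 1000 training images (and the first 100 test images) are used here. To train and test on the full dataset, remove the index slicing.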
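If the slicing is removed to evaluate all 10,000 test images against all 50,000 training images, the per-test-image Python loop becomes very slow. The sketch below shows one common speed-up, which is also the vectorization used in CS231n. It swaps the L1 distance for the squared L2 (Euclidean) distance, so accuracy may differ somewhat from the L1 version above.

The trick is to expand ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a·b. This lets one matrix product compute the distances for a whole chunk of test rows. The function name `predict_l2_chunked` and its `chunk_size` parameter are hypothetical and do not come from the original post. Converting all 50,000 training rows to float64 takes roughly 1.2 GB of memory.

```python
import numpy as np

def predict_l2_chunked(Xtr, ytr, Xte, chunk_size=500):
    """Nearest-neighbour labels for Xte using squared L2 distance.

    Hypothetical helper (not from the original post).
    Xtr: (N, D) training rows, ytr: (N,) labels, Xte: (M, D) test rows.
    Uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, so no (M, N, D)
    temporary array is ever built.
    """
    Xtr = Xtr.astype(np.float64)        # avoid uint8 wraparound/overflow
    tr_sq = np.sum(Xtr ** 2, axis=1)    # (N,) squared norms of training rows
    preds = np.empty(Xte.shape[0], dtype=ytr.dtype)
    for start in range(0, Xte.shape[0], chunk_size):
        chunk = Xte[start:start + chunk_size].astype(np.float64)  # (c, D)
        te_sq = np.sum(chunk ** 2, axis=1)[:, None]               # (c, 1)
        d2 = te_sq + tr_sq[None, :] - 2.0 * (chunk @ Xtr.T)       # (c, N)
        # label of the closest training row for every test row in the chunk
        preds[start:start + chunk_size] = ytr[np.argmin(d2, axis=1)]
    return preds
```

With the variables from the main script, usage would be `Yte_predict = predict_l2_chunked(Xtr_rows, Ytr_rows, Xte_rows)` followed by `np.mean(Yte_predict == Yte)`. Smaller values of `chunk_size` reduce peak memory at the cost of more loop iterations.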

import pickle as p
import matplotlib.pyplot as plt
import numpy as np


# NearestNeighbor class
class NearestNeighbor(object):
    def __init__(self):
        pass

    def train(self, X, y):
        """ X is N x D where each row is an example. y is 1-dimensional of size N """
        # the nearest neighbor classifier simply remembers all the training data;
        # cast from uint8 so that pixel differences cannot wrap around
        self.Xtr = X.astype(np.int32)
        self.ytr = y

    def predict(self, X):
        """ X is N x D where each row is an example we wish to predict a label for """
        X = X.astype(np.int32)  # same cast as in train()
        num_test = X.shape[0]
        # make sure that the output type matches the label type
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)

        # loop over all test rows
        for i in range(num_test):
            # find the nearest training image to the i'th test image
            # using the L1 distance (sum of absolute value differences)
            distances = np.sum(np.abs(self.Xtr - X[i, :]), axis=1)
            min_index = np.argmin(distances)  # index of the smallest distance
            Ypred[i] = self.ytr[min_index]  # predict the label of the nearest example

        return Ypred


def load_CIFAR_batch(filename):
    """ load a single batch of CIFAR-10 """
    with open(filename, 'rb') as f:
        datadict = p.load(f, encoding='latin1')
        X = datadict['data']
        Y = datadict['labels']
        Y = np.array(Y)  # the labels in the dict are a list; convert them to a NumPy array
        return X, Y


def load_CIFAR_Labels(filename):
    """ load the list of class names from batches.meta """
    with open(filename, 'rb') as f:
        label_names = p.load(f, encoding='latin1')
        names = label_names['label_names']
        return names
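One pitfall is worth checking when running a nearest-neighbor classifier on raw CIFAR-10 data. In the Python batches, `datadict['data']` is a `uint8` array. NumPy subtraction between `uint8` arrays wraps around modulo 256 instead of going negative, so `np.abs` cannot recover the true pixel difference. The small demonstration below uses two made-up pixel values. It shows why differences should be computed in a wider signed or floating-point type.

```python
import numpy as np

a = np.array([10], dtype=np.uint8)
b = np.array([20], dtype=np.uint8)

# uint8 arithmetic: 10 - 20 wraps around to 246, and abs() keeps 246
wrapped = np.abs(a - b)
# widen to a signed type first: |10 - 20| = 10
correct = np.abs(a.astype(np.int32) - b.astype(np.int32))

print(wrapped, correct)  # [246] [10]
```

With wrapped differences, the L1 distances no longer measure how different two images are. The nearest "neighbor" found is then largely meaningless.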
# load data
label_names = load_CIFAR_Labels("cifar-10-batches-py/batches.meta")
imgX1, imgY1 = load_CIFAR_batch("cifar-10-batches-py/data_batch_1")
imgX2, imgY2 = load_CIFAR_batch("cifar-10-batches-py/data_batch_2")
imgX3, imgY3 = load_CIFAR_batch("cifar-10-batches-py/data_batch_3")
imgX4, imgY4 = load_CIFAR_batch("cifar-10-batches-py/data_batch_4")
imgX5, imgY5 = load_CIFAR_batch("cifar-10-batches-py/data_batch_5")
Xte_rows, Yte = load_CIFAR_batch("cifar-10-batches-py/test_batch")

Xtr_rows = np.concatenate((imgX1, imgX2, imgX3, imgX4, imgX5))
Ytr_rows = np.concatenate((imgY1, imgY2, imgY3, imgY4, imgY5))

nn = NearestNeighbor()  # create a Nearest Neighbor classifier
nn.train(Xtr_rows[:1000, :], Ytr_rows[:1000])  # train on the first 1000 training images and labels
Yte_predict = nn.predict(Xte_rows[:100, :])  # predict labels for the first 100 test images
# print the classification accuracy: the fraction of test examples
# whose predicted label matches the true label
print('accuracy: %f' % (np.mean(Yte_predict == Yte[:100])))

# show a picture: the first 1024 values of a row are the red channel,
# displayed here as a 32x32 grayscale image
image = imgX1[6, 0:1024].reshape(32, 32)
print(image.shape)
plt.imshow(image, cmap=plt.cm.gray)
plt.axis('off')  # hide the axes around the image
plt.show()
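The display code above shows only `imgX1[6, 0:1024]`. That slice is the red channel alone, rendered in grayscale. According to the CIFAR-10 page, each 3072-value row stores 1024 red values, then 1024 green, then 1024 blue, each channel in row-major order. A full-color view can therefore be built by reshaping to (3, 32, 32) and moving the channel axis last. The helper name `row_to_rgb` is hypothetical, chosen for this sketch.

```python
import numpy as np

def row_to_rgb(row):
    """Convert one 3072-value CIFAR-10 row into a (32, 32, 3) RGB image.

    Hypothetical helper. Layout assumed from the CIFAR-10 page:
    [R(1024) | G(1024) | B(1024)], each channel a 32x32 row-major image.
    """
    # (3, 32, 32) -> (32, 32, 3): put the channel axis last for imshow
    return np.asarray(row).reshape(3, 32, 32).transpose(1, 2, 0)
```

For example, `plt.imshow(row_to_rgb(imgX1[6]))` would display the same training image in color. Matplotlib interprets a `uint8` RGB array in the 0 to 255 range directly, so no `cmap` is needed.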

Reposted from: https://blog.csdn.net/xiaojiajia007/article/details/53708818
