softmax手写数字识别 python代码

2017-01-13 08:19:15来源:CSDN作者:u013781175人点击

第七城市

本文只贴出一段简单的代码。

数据格式如下,下图为数字2:


大小为32X32,存在多个文本文件中。

代码如下:

# -*- coding: utf-8 -*-import matplotlib.pyplot as pltimport numpy as np"""导入用于矩阵计算的numpy包和用于作图的matplotlib包"""def get_vector(f_name):    """打开文件并且将数据读入,存在链表中"""    vec = []    fp = open(f_name)    for i in xrange(32):        line = fp.readline()        for j in xrange(32):            vec.append(int(line[j]))    fp.close()    return vecdef import_data():    """导入所有训练集数据,并将数据转化为numpy矩阵格式"""    data = []    label = []    for i in xrange(10):        for j in xrange(150):            f_name = "trainingDigits//"+str(i)+"_"+str(j)+".txt"            data.append(get_vector(f_name))            label.append(i)    return np.mat(data), np.mat(label)def import_test():    """导入所有测试集数据,并将数据转化为numpy矩阵格式"""    data = []    label = []    for i in xrange(10):        for j in xrange(50):            f_name = "testDigits//"+str(i)+"_"+str(j)+".txt"            data.append(get_vector(f_name))            label.append(i)    return np.mat(data), np.mat(label)def sigmoid(x):    """定义sigmod函数"""    return 1.0/(1+np.exp(-x))def ga(w, data, label, alpha, epoch):    """    用梯度上升算法训练softmax分类器    w为要训练的参数    data为训练集数据    label为训练集数据的标签    alpha为学习率    epoch为最大迭代次数    """    errors = []    for k in xrange(epoch):        cs = data*w/1024.0    # 归一化        error = np.exp(cs)        row_sum = -error.sum(axis=1)        row_sum = row_sum.repeat(10, axis=1)        error = error/row_sum        for m in range(1500):            error[m, label[m, 0]] += 1.0        w = w + alpha * data.T*error        errors.append(1-accuracy(cs.argmax(axis=1), label, 1500))    return w, errorsdef accuracy(a, b, num):    """用于测试训练好的模型在测试集上的准确率"""    n = 0.0    for i in xrange(num):        if a[i, 0] == b[i, 0]:            n += 1.0    return n/(num+0.0)def logistic():    """主函数,首先导入数据,然后训练模型,并且计算模型在测试集上的准确率,    最后画出随着迭代次数的增加,模型的训练集上准确率的变化"""    train, train_label = import_data()    test, test_label = import_test()    w, errors = ga(np.mat(np.ones((1024, 10))), train, train_label.T, 0.01, 51)    test_ = test*w    test_cal = test_.argmax(axis=1)    # for i in xrange(500):    #     print test_cal[i, 0], test_label.T[i, 0]    print u"准确率:",    print accuracy(test_cal, test_label.T, 500)  # 打印出准确率的值    plt.figure()    plt.plot(errors, linewidth=5)    plt.grid()    plt.xlabel("Error Rate")    plt.ylabel("Iterations")    plt.xlim([0, 50])    plt.show()if __name__ == "__main__":    logistic()



第七城市

最新文章

123

最新摄影

微信扫一扫

第七城市微信公众平台