MNIST LSTM: Training, Testing, Model Saving, Loading, and Recognition


Original article. Please cite the source when reposting: http://blog.csdn.net/wanggao_1990/article/details/77964504

In the MNIST handwritten-digit database, each character (0-9) corresponds to a 28x28 single-channel image. Each column (or row) of the image can be treated as a feature vector, and all rows (or columns) as a sequence. A character can then be modeled by an RNN (LSTM) with input size 28 and sequence length 28. For any given character, say 0, the row-to-row dynamics are captured well by the RNN, and the changes across these consecutive rows characterize that particular character's pattern. An RNN can therefore be used for character recognition.
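As a minimal NumPy sketch of this reshaping (the random vector below is only a stand-in for one flattened MNIST image):

    import numpy as np

    # One flattened MNIST image: 784 pixel values (a random stand-in here)
    flat = np.random.rand(784)
    # Reinterpret it as a sequence: 28 time steps, each a 28-dim feature vector
    seq = flat.reshape(28, 28)
    print(seq.shape)  # (28, 28) -> (n_steps, n_input)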

TensorFlow provides a decent RNN interface. The basic approach is:
1. Build the basic unit of the RNN network, the cell. TF provides many types of cells: BasicRNNCell, BasicLSTMCell, LSTMCell, and so on.
2. Connect the cells into an RNN network by calling rnn.static_rnn or rnn.static_bidirectional_rnn. This example uses rnn.static_bidirectional_rnn (the API differs somewhat across TF versions); see the sketch after this list.
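A minimal sketch of these two steps for the unidirectional case, assuming the TensorFlow 1.x contrib API (the bidirectional version used below differs only in taking a forward and a backward cell):

    import tensorflow as tf
    from tensorflow.contrib import rnn

    # Step 1: build the basic cell
    cell = rnn.BasicLSTMCell(128, forget_bias=1.0)

    # Step 2: unroll the cell into an RNN over 28 time steps
    x = tf.placeholder("float", [None, 28, 28])
    inputs = tf.unstack(x, 28, 1)  # list of 28 tensors, each (batch, 28)
    outputs, state = rnn.static_rnn(cell, inputs, dtype=tf.float32)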

LSTM Training and Testing

    '''
    A Bidirectional Recurrent Neural Network (LSTM) implementation example
    using the TensorFlow library. This example uses the MNIST database of
    handwritten digits (http://yann.lecun.com/exdb/mnist/).
    Long Short Term Memory paper: http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
    Author: Aymeric Damien
    Project: https://github.com/aymericdamien/TensorFlow-Examples/
    '''
    from __future__ import print_function

    import os
    import numpy as np
    import tensorflow as tf
    from tensorflow.contrib import rnn

    # The author's user-data converter (not needed for the MNIST example):
    # from convert_data import convert_datas

    # Import MNIST data
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("/data/", one_hot=True)

    '''
    To classify images using a bidirectional recurrent neural network, we
    consider every image row as a sequence of pixels. Because MNIST image
    shape is 28*28px, we will then handle 28 sequences of 28 steps for
    every sample.
    '''

    # Parameters
    learning_rate = 0.001
    # Total number of training samples to consume
    training_iters = 100000
    # Number of samples per training step
    batch_size = 128
    # Print progress every display_step steps
    display_step = 10

    # Network Parameters
    # n_steps * n_input is the image itself: each row is fed to one time step.
    n_input = 28    # MNIST data input (img shape: 28*28)
    n_steps = 28    # timesteps
    # Hidden layer size
    n_hidden = 128  # hidden layer num of features
    n_classes = 10  # MNIST total classes (0-9 digits)

    # tf Graph input
    # In [None, n_steps, n_input], None means the batch dimension is not fixed.
    x = tf.placeholder("float", [None, n_steps, n_input], name="input_x")
    y = tf.placeholder("float", [None, n_classes], name="input_y")

    # Define weights and biases
    weights = tf.Variable(tf.random_normal([2*n_hidden, n_classes]), name="weights")
    biases = tf.Variable(tf.random_normal([n_classes]), name="biases")

    def BiRNN(x, weights, biases):
        # Prepare data shape to match `bidirectional_rnn` function requirements
        # Current data input shape: (batch_size, n_steps, n_input)
        # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)

        # Unstack to get a list of 'n_steps' tensors of shape (batch_size, n_input)
        x = tf.unstack(x, n_steps, 1)

        # Define lstm cells with tensorflow
        # Forward direction cell
        lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
        # Backward direction cell
        lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)

        # Get lstm cell output
        try:
            outputs, _, _ = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
                                                         dtype=tf.float32)
        except Exception:  # Old TensorFlow version only returns outputs, not states
            outputs = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
                                                   dtype=tf.float32)

        # Linear activation, using rnn inner loop last output
        return tf.add(tf.matmul(outputs[-1], weights), biases)

    pred = BiRNN(x, weights, biases)

    # Define loss and optimizer
    # softmax_cross_entropy_with_logits measures the probability error in
    # discrete classification tasks where the classes are mutually exclusive.
    # It returns a 1-D tensor of length batch_size holding the per-sample
    # cross-entropy loss; reduce_mean (no axis given) averages over all values.
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

    # Evaluate model
    correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Initializing the variables
    init = tf.global_variables_initializer()

    # Launch the graph
    with tf.Session() as sess:
        sess.run(init)
        step = 1
        # Keep training until reach max iterations
        while step * batch_size < training_iters:
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            # Reshape data to get 28 sequences of 28 elements
            batch_x = batch_x.reshape((batch_size, n_steps, n_input))
            # Run optimization op (backprop)
            sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
            if step % display_step == 0:
                # Calculate batch accuracy
                acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
                # Calculate batch loss
                loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
                print("Iter " + str(step*batch_size) + ", Minibatch Loss= " +
                      "{:.6f}".format(loss) + ", Training Accuracy= " +
                      "{:.5f}".format(acc))
            step += 1
        print("Optimization Finished!")

        # Calculate accuracy on the MNIST test images.
        # (To test only the first 128 images instead:
        #  test_len = 128
        #  test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
        #  test_label = mnist.test.labels[:test_len])
        # In the author's own project the input was batch_size*30*17; for real
        # testing, the data must match TensorFlow's input placeholder shape.
        test_data = mnist.test.images.reshape((-1, n_steps, n_input))
        test_label = mnist.test.labels
        print("Testing Accuracy:",
              sess.run(accuracy, feed_dict={x: test_data, y: test_label}))

Saving the Trained Model

Right after the test-accuracy output above, add the following code and rerun the script:

    saver = tf.train.Saver()
    model_path = "./model/my_model"
    save_path = saver.save(sess, model_path)
    print("Model saved in file: %s" % save_path)

This is only one of several ways to save, and it stores the entire network structure. In model_path, "model" is the directory the model is saved into, and "my_model" is the filename prefix, which can be thought of as the model's name.

After it runs, a new directory named "model" appears in the current directory, containing four files: checkpoint, my_model.data-00000-of-00001, my_model.index, and my_model.meta. There are plenty of write-ups online describing these four files.
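As a side note, if you later need the checkpoint prefix programmatically, TensorFlow can read it back from the checkpoint file (a small sketch, assuming the ./model directory from above):

    # Resolve the most recent checkpoint prefix recorded in ./model/checkpoint
    ckpt = tf.train.latest_checkpoint('./model')
    print(ckpt)  # e.g. ./model/my_model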

Note that this only saves the model. The point of saving here is to load the model later and classify new inputs without rebuilding the whole network. That means the computation nodes needed at recognition time must also be saved, so they can be retrieved and used to compute the output. One prediction node needs to be added: after pred = BiRNN(x, weights, biases), insert:

    tf.add_to_collection('predict', pred)

This binds the pred computation to the name "predict", so that after the model is loaded, the operation node can be retrieved by that name.
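The same mechanism works for any node. For instance, one could also store the input placeholder in a collection rather than look it up by operation name later (an optional variation, not what the original code does):

    # Optional alternative: also bind the input placeholder to a collection name
    tf.add_to_collection('input_x', x)
    # ... then, after loading the meta graph:
    # input_x = tf.get_collection('input_x')[0]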


Loading the Trained Model and Recognition

Loading the model is simple; the key code is as follows:

    with tf.Session() as sess:
        new_saver = tf.train.import_meta_graph('./model/my_model.meta')
        new_saver.restore(sess, './model/my_model')

Note that the path passed to restore() must match the one used when saving.

Next, retrieve the needed nodes from the loaded model: first the pred operation stored in the "predict" collection, and second the input x that pred requires, i.e. the placeholder named "input_x" in the training code. Continue with the following code:

    graph = tf.get_default_graph()
    predict = tf.get_collection('predict')[0]
    input_x = graph.get_operation_by_name("input_x").outputs[0]

Finally, feed in an image and have the model classify it:

    x = mnist.test.images[0].reshape((1, n_steps, n_input))
    res = sess.run(predict, feed_dict={input_x: x})

This uses the first image of the test set. The procedure is similar to the testing phase above, except no label is fed as the second input. The class can be obtained from the returned result via tf.argmax.
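Since sess.run already returns a NumPy array, the class index can also be read off directly, without extra graph evaluations (a small sketch):

    # res has shape (1, n_classes); argmax along axis 1 gives the class index
    pred_class = res.argmax(axis=1)[0]
    print("predicted digit:", pred_class)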

When using argmax, first confirm the shape of the data so the correct axis is reduced over. The complete code for this part is as follows (note that n_steps and n_classes must match the values used at training time, 28 and 10 for MNIST):

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data

    mnist = input_data.read_data_sets("/data/", one_hot=True)

    # These must match the values used at training time
    n_input = 28
    n_steps = 28
    n_classes = 10

    with tf.Session() as sess:
        new_saver = tf.train.import_meta_graph('./model/my_model.meta')
        new_saver.restore(sess, './model/my_model')

        graph = tf.get_default_graph()
        predict = tf.get_collection('predict')[0]
        input_x = graph.get_operation_by_name("input_x").outputs[0]

        x = mnist.test.images[0].reshape((1, n_steps, n_input))
        y = mnist.test.labels[0].reshape(-1, n_classes)  # already one-hot

        res = sess.run(predict, feed_dict={input_x: x})
        print("Actual class: ", str(sess.run(tf.argmax(y, 1))),
              ", predict class ", str(sess.run(tf.argmax(res, 1))),
              ", predict ", str(sess.run(tf.equal(tf.argmax(y, 1), tf.argmax(res, 1)))))
