Tensorflow快餐教程(10) - 循环神经网络

0
1
0
1. 云栖社区>
2. 博客>
3. 正文

Tensorflow快餐教程(10) - 循环神经网络

lusing 2018-05-08 21:24:15 浏览2568

循环神经网络

LSTM

RNN在网络中增加了对之前状态的记忆（循环）项，因此不能直接使用前面BP网络中的梯度下降方法。不过，只要在该方法的基础上，把循环项在各个时间步上的输入都考虑进来（即先把网络沿时间展开，再做反向传播），就可以训练RNN。这种改进方法叫做BPTT算法（Back-Propagation Through Time，随时间反向传播）。

LSTM的细节我们放到后面详细讲。我们先看看在Tensorflow中如何实现一个LSTM模型：

def RNN(x, weights, biases):
    """Build a single-layer LSTM classifier and return its logits.

    The input is split along axis 1 into ``timesteps`` tensors, run through a
    ``BasicLSTMCell`` with ``num_hidden`` units via ``static_rnn``. The output
    of the final time step is then projected with ``weights['out']`` and
    ``biases['out']``.

    Args:
        x: input batch tensor; presumably (batch, timesteps, num_input), like
            the ``X`` placeholder of the full example - TODO confirm.
        weights: dict holding the output projection matrix under 'out'.
        biases: dict holding the output bias vector under 'out'.

    Returns:
        Unnormalised class scores (logits), one row per example.

    Note:
        ``timesteps`` and ``num_hidden`` are read from module-level globals,
        which this excerpt assumes are defined elsewhere.
    """
    # static_rnn consumes a Python list with one tensor per time step.
    step_inputs = tf.unstack(x, num=timesteps, axis=1)

    cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)
    step_outputs, _final_state = rnn.static_rnn(cell, step_inputs,
                                                dtype=tf.float32)

    # Only the last time step's output is used for classification.
    last_output = step_outputs[-1]
    logits = tf.matmul(last_output, weights['out']) + biases['out']
    return logits

from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import rnn

from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST. The training loop and the test evaluation read `mnist`, but the
# original listing never created it (NameError). one_hot=True matches the
# [None, num_classes] label placeholder and the argmax(Y, 1) accuracy check.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# Training parameters
learning_rate = 0.001
training_steps = 10000
batch_size = 128
display_step = 200  # log minibatch loss/accuracy every this many steps

# Network parameters: each flat image is reshaped to (timesteps, num_input),
# i.e. fed to the LSTM as 28 steps of 28 pixels.
num_input = 28  # MNIST data input (img shape: 28*28)
timesteps = 28  # timesteps
num_hidden = 128  # hidden layer num of features
num_classes = 10  # MNIST total classes (0-9 digits)

# Graph inputs: image sequences and one-hot labels.
X = tf.placeholder("float", [None, timesteps, num_input])
Y = tf.placeholder("float", [None, num_classes])

# Initial weights / biases of the output projection layer.
weights = {
    'out': tf.Variable(tf.random_normal([num_hidden, num_classes]))
}
biases = {
    'out': tf.Variable(tf.random_normal([num_classes]))
}

def RNN(x, weights, biases):
    """Return the logits of an LSTM classifier applied to the batch ``x``.

    Args:
        x: batch tensor fed from the ``X`` placeholder, i.e. shaped
            (batch, timesteps, num_input).
        weights: dict whose 'out' entry is the (num_hidden, num_classes)
            output projection matrix.
        biases: dict whose 'out' entry is the (num_classes,) output bias.

    Returns:
        Tensor of unnormalised class scores, one row per example.

    Uses the module-level ``timesteps`` and ``num_hidden`` settings.
    """
    # Split the time axis into a list of per-step tensors for static_rnn.
    sequence = tf.unstack(x, num=timesteps, axis=1)

    lstm = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)
    per_step_outputs, _ = rnn.static_rnn(lstm, sequence, dtype=tf.float32)

    # Project the final step's hidden output onto the class scores.
    projection, offset = weights['out'], biases['out']
    return tf.matmul(per_step_outputs[-1], projection) + offset

logits = RNN(X, weights, biases)
prediction = tf.nn.softmax(logits)

# Define the loss and the optimizer.
loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    logits=logits, labels=Y))
# `optimizer` was used here but never defined in the original listing
# (NameError); use plain gradient descent driven by learning_rate.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op)

# Evaluation: fraction of examples whose predicted class matches the label.
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Op that initialises every variable; it is run first in the session.
init = tf.global_variables_initializer()

with tf.Session() as sess:
    # Initialise all variables before training.
    sess.run(init)

    for step in range(1, training_steps + 1):
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Turn each flat image into a (timesteps, num_input) sequence.
        batch_x = batch_x.reshape((batch_size, timesteps, num_input))
        feed = {X: batch_x, Y: batch_y}
        sess.run(train_op, feed_dict=feed)

        # Report on the first step and then every display_step steps.
        should_log = step == 1 or step % display_step == 0
        if should_log:
            loss, acc = sess.run([loss_op, accuracy], feed_dict=feed)
            print("Step {}, Minibatch Loss= {:.4f}, Training Accuracy= {:.3f}"
                  .format(step, loss, acc))

    print("Optimization Finished!")

    # Evaluate on the first test_len test images only.
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, timesteps, num_input))
    test_label = mnist.test.labels[:test_len]
    test_acc = sess.run(accuracy, feed_dict={X: test_data, Y: test_label})
    print("Testing Accuracy:", test_acc)

门控循环单元GRU(Gated Recurrent Unit)

LSTM所使用的技术属于门控RNN（Gated RNN）技术。除了LSTM之外，还有一种应用广泛的门控RNN，叫做GRU（Gated Recurrent Unit）。

Tensorflow的tf.contrib.rnn模块提供了以下几种常用的RNN单元，其中GRUCell就是GRU的实现：

• tf.contrib.rnn.BasicRNNCell
• tf.contrib.rnn.BasicLSTMCell
• tf.contrib.rnn.GRUCell
• tf.contrib.rnn.LSTMCell
• tf.contrib.rnn.LayerNormBasicLSTMCell

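双向循环神经网络

双向循环神经网络（Bidirectional RNN）用一个正向RNN和一个反向RNN分别从两个方向处理输入序列，再把两个方向的输出拼接起来。下面的代码用两个LSTM单元构建双向RNN，因此输出层权值矩阵的行数是2*num_hidden：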

def BiRNN(x, weights, biases):
    """Build a bidirectional LSTM classifier and return its logits.

    Args:
        x: input batch tensor; presumably (batch, timesteps, num_input) -
            TODO confirm against the caller.
        weights: dict holding the output projection matrix under 'out'.
        biases: dict holding the output bias vector under 'out'.

    Returns:
        Unnormalised class scores (logits), one row per example.

    Note:
        ``timesteps`` and ``num_hidden`` are read from module-level globals.
    """
    # static_bidirectional_rnn consumes a list with one tensor per time step.
    x = tf.unstack(x, timesteps, 1)

    # One LSTM cell for the forward direction, one for the backward direction.
    lstm_fw_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)
    lstm_bw_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)

    # Build the bidirectional RNN exactly once. The previous version wrapped
    # this in `try/except Exception` and called it a second time on ANY error.
    # Rebuilding may collide with variables the first call already created,
    # hiding the real error. Newer TensorFlow returns
    # (outputs, output_state_fw, output_state_bw); older versions returned
    # only the outputs, so inspect the result instead of catching exceptions.
    result = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
                                          dtype=tf.float32)
    if isinstance(result, tuple) and len(result) == 3:
        outputs = result[0]
    else:
        outputs = result

    # Classify from the last time step's output.
    return tf.matmul(outputs[-1], weights['out']) + biases['out']


from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as np

from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST. The training loop and the test evaluation read `mnist`, but the
# original listing never created it (NameError). one_hot=True matches the
# [None, num_classes] label placeholder and the argmax(Y, 1) accuracy check.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# Training parameters
learning_rate = 0.001
training_steps = 10000
batch_size = 128
display_step = 200  # log minibatch loss/accuracy every this many steps

# Network parameters: each flat image is reshaped to (timesteps, num_input).
num_input = 28  # MNIST data input (img shape: 28*28)
timesteps = 28  # timesteps
num_hidden = 128  # hidden layer num of features
num_classes = 10  # MNIST total classes (0-9 digits)

# Graph inputs: image sequences and one-hot labels.
X = tf.placeholder("float", [None, timesteps, num_input])
Y = tf.placeholder("float", [None, num_classes])

# Output projection. The bidirectional RNN concatenates the forward and
# backward outputs, hence 2*num_hidden input rows.
weights = {
    'out': tf.Variable(tf.random_normal([2*num_hidden, num_classes]))
}
biases = {
    'out': tf.Variable(tf.random_normal([num_classes]))
}

def BiRNN(x, weights, biases):
    """Build a bidirectional LSTM classifier and return its logits.

    Args:
        x: batch tensor fed from the ``X`` placeholder, i.e. shaped
            (batch, timesteps, num_input).
        weights: dict whose 'out' entry is the (2*num_hidden, num_classes)
            output projection matrix.
        biases: dict whose 'out' entry is the (num_classes,) output bias.

    Returns:
        Unnormalised class scores (logits), one row per example.

    Note:
        ``timesteps`` and ``num_hidden`` are read from module-level globals.
    """
    # static_bidirectional_rnn consumes a list with one tensor per time step.
    x = tf.unstack(x, timesteps, 1)

    # One LSTM cell for the forward direction, one for the backward direction.
    lstm_fw_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)
    lstm_bw_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)

    # Build the bidirectional RNN exactly once. The previous version wrapped
    # this in `try/except Exception` and called it a second time on ANY error.
    # Rebuilding may collide with variables the first call already created,
    # hiding the real error. Newer TensorFlow returns
    # (outputs, output_state_fw, output_state_bw); old TensorFlow versions
    # return only outputs, not states, so inspect the result instead.
    result = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
                                          dtype=tf.float32)
    if isinstance(result, tuple) and len(result) == 3:
        outputs = result[0]
    else:
        outputs = result

    # Classify from the last time step's output (fw and bw concatenated,
    # which is why weights['out'] has 2*num_hidden rows).
    return tf.matmul(outputs[-1], weights['out']) + biases['out']

logits = BiRNN(X, weights, biases)
prediction = tf.nn.softmax(logits)

# Define the loss and the optimizer.
loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    logits=logits, labels=Y))
# `optimizer` was used here but never defined in the original listing
# (NameError); use plain gradient descent driven by learning_rate.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op)

# Evaluation: fraction of examples whose predicted class matches the label.
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Op that initialises every variable; it is run first in the session.
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)  # variables must be initialised before any training step

    for step in range(1, training_steps + 1):
        images, labels = mnist.train.next_batch(batch_size)
        # Reshape the flat images into (timesteps, num_input) sequences.
        images = images.reshape((batch_size, timesteps, num_input))
        batch_feed = {X: images, Y: labels}
        sess.run(train_op, feed_dict=batch_feed)

        # Log on the very first step and on every display_step-th step.
        if step == 1 or step % display_step == 0:
            loss, acc = sess.run([loss_op, accuracy], feed_dict=batch_feed)
            message = ("Step {}, Minibatch Loss= {:.4f}, "
                       "Training Accuracy= {:.3f}").format(step, loss, acc)
            print(message)

    print("Optimization Finished!")

    # Evaluate on the first test_len test images only.
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, timesteps, num_input))
    test_label = mnist.test.labels[:test_len]
    test_feed = {X: test_data, Y: test_label}
    print("Testing Accuracy:", sess.run(accuracy, feed_dict=test_feed))

小结

本系列往期文章：

Tensorflow快餐教程(1) - 30行代码搞定手写识别：https://yq.aliyun.com/articles/582122
Tensorflow快餐教程(2) - 标量运算：https://yq.aliyun.com/articles/582490
Tensorflow快餐教程(3) - 向量：https://yq.aliyun.com/articles/584202
Tensorflow快餐教程(4) - 矩阵：https://yq.aliyun.com/articles/584526
Tensorflow快餐教程(5) - 范数：https://yq.aliyun.com/articles/584896
Tensorflow快餐教程(6) - 矩阵分解：https://yq.aliyun.com/articles/585599
Tensorflow快餐教程(7) - 梯度下降：https://yq.aliyun.com/articles/587350
Tensorflow快餐教程(8) - 深度学习简史：https://yq.aliyun.com/articles/588920
Tensorflow快餐教程(9) - 卷积：https://yq.aliyun.com/articles/590233

lusing
+ 关注