# [Deep Learning Notes] (2) Hello, TensorFlow!

## 2. Hello, TensorFlow

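The basic workflow of a TensorFlow program:

1. Define the data
2. Define the computation graph and variables
3. Define a session
4. Run the computation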
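To make the four steps concrete without installing anything, here is a pure-Python toy that only imitates the idea of building a graph first and evaluating it later in a session. `Node`, `Placeholder`, `Variable`, `Op` and `Session` are hypothetical classes written for this illustration, not TensorFlow's API; TensorFlow 1.x follows the same define-then-run pattern with `tf.placeholder`, `tf.Variable` and `tf.Session`.

```python
# Hypothetical toy classes that mimic TensorFlow 1.x's define-then-run workflow.
# They are NOT TensorFlow's API; they only illustrate the four steps.

class Node:
    """Base class: arithmetic on nodes builds new graph nodes instead of computing."""
    def __add__(self, other):
        return Op(lambda a, c: a + c, self, other)

    def __mul__(self, other):
        return Op(lambda a, c: a * c, self, other)


class Placeholder(Node):
    """Step 1: a slot for data that is supplied only at run time."""
    def __init__(self, name):
        self.name = name


class Variable(Node):
    """Step 2: a stateful value stored in the graph (e.g. a model weight)."""
    def __init__(self, value):
        self.value = value


class Op(Node):
    """Step 2: a deferred operation; nothing is computed when it is built."""
    def __init__(self, fn, *inputs):
        self.fn = fn
        self.inputs = inputs


class Session:
    """Step 3: the object that actually evaluates graph nodes."""
    def run(self, node, feed_dict=None):
        feed_dict = feed_dict or {}
        if isinstance(node, Placeholder):
            return feed_dict[node]          # data fed in at run time
        if isinstance(node, Variable):
            return node.value               # current variable state
        args = [self.run(inp, feed_dict) for inp in node.inputs]
        return node.fn(*args)               # evaluate inputs, then this op


# Build the graph y = w * x + b; no arithmetic happens on these lines.
x = Placeholder("x")
w = Variable(2.0)
b = Variable(1.0)
y = w * x + b

# Step 4: run the computation by feeding a value into the placeholder.
sess = Session()
print(sess.run(y, feed_dict={x: 3.0}))  # → 7.0
```

Keeping graph construction separate from evaluation is what lets the same graph be run many times with different feeds, which is exactly how the training loop below reuses `train_step` on each batch.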

### 2. Handwritten Digit Recognition on the MNIST Dataset
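The model that the code below trains is softmax regression. In standard notation (a sketch; the symbols are chosen here to mirror the code: $x$ is a batch of flattened 784-pixel images, $y'$ the one-hot labels `y_`, $N$ the batch size and $\eta$ the learning rate), the forward pass, the loss and one gradient-descent update are:

$$
\begin{aligned}
z &= xW + b, \qquad y_{n,j} = \operatorname{softmax}(z_n)_j = \frac{e^{z_{n,j}}}{\sum_{k=1}^{10} e^{z_{n,k}}} \\
H &= -\frac{1}{N}\sum_{n=1}^{N}\sum_{j=1}^{10} y'_{n,j}\,\log y_{n,j} \\
\frac{\partial H}{\partial z_{n,j}} &= \frac{1}{N}\left(y_{n,j} - y'_{n,j}\right), \qquad W \leftarrow W - \eta\,\frac{\partial H}{\partial W}, \quad b \leftarrow b - \eta\,\frac{\partial H}{\partial b}
\end{aligned}
$$

In the code, `tf.reduce_sum(..., axis=[1])` computes the inner sum over the classes $j$, and `tf.reduce_mean` computes the outer average over the $N$ examples in the batch.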

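```python
# Written against the TensorFlow 1.x API (tf.placeholder, tf.InteractiveSession)

# 1. Load the data set
from tensorflow.examples.tutorials.mnist import input_data
# one_hot=True makes each label a 10-element vector, matching y_ below;
# "MNIST_data/" is the local directory the files are downloaded to
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# 2. Inspect the data set: train / test / validation
# training set
print(mnist.train.images.shape, mnist.train.labels.shape)
# test set
print(mnist.test.images.shape, mnist.test.labels.shape)
# validation set
print(mnist.validation.images.shape, mnist.validation.labels.shape)

# 3. Start a TensorFlow session
import tensorflow as tf
sess = tf.InteractiveSession()

# 4. Define softmax regression
# x: input images, each flattened to 784 = 28 * 28 pixels
x = tf.placeholder(tf.float32, [None, 784])
# W: weights, one column per digit class
W = tf.Variable(tf.zeros([784, 10]))
# b: biases, one per class
b = tf.Variable(tf.zeros([10]))
# y: predicted class probabilities
y = tf.nn.softmax(tf.matmul(x, W) + b)
# y_: true labels (one-hot)
y_ = tf.placeholder(tf.float32, [None, 10])
# loss: cross-entropy, summed over classes and averaged over the batch
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), axis=[1]))
# SGD: plain gradient descent; learning rate 0.5 as in the official TensorFlow MNIST tutorial
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# initialize all variables
tf.global_variables_initializer().run()

# 5. Training starts: 1000 steps, each on a random batch of 100 images
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    train_step.run({x: batch_xs, y_: batch_ys})
# Training ends

# correct prediction: does the most probable class match the label?
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
# accuracy: fraction of correct predictions
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# evaluate on the test set
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
```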
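The TensorFlow code above hides the arithmetic inside `tf.nn.softmax`, `tf.log` and `GradientDescentOptimizer`. As a rough illustration only, here is a minimal pure-Python sketch of the same steps: softmax, mean cross-entropy, one SGD step using the standard softmax/cross-entropy gradient $(y - y')/N$ with respect to the logits, and argmax accuracy. The helper names `forward`, `cross_entropy`, `sgd_step` and `accuracy` and the tiny 2-feature, 2-class data set are hypothetical, chosen so the example runs without TensorFlow or MNIST.

```python
import math


def softmax(z):
    # Subtract the max for numerical stability; the result is unchanged.
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]


def forward(x, W, b):
    # z_j = sum_i x_i * W[i][j] + b_j, then y = softmax(z)
    z = [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]
    return softmax(z)


def cross_entropy(batch_x, batch_y, W, b):
    # Mean over the batch of -sum_j y'_j * log(y_j)
    total = 0.0
    for x, t in zip(batch_x, batch_y):
        y = forward(x, W, b)
        total += -sum(tj * math.log(yj) for tj, yj in zip(t, y))
    return total / len(batch_x)


def sgd_step(batch_x, batch_y, W, b, lr):
    # The gradient of the mean cross-entropy w.r.t. each logit is (y - y') / N.
    n = len(batch_x)
    gW = [[0.0] * len(b) for _ in W]
    gb = [0.0] * len(b)
    for x, t in zip(batch_x, batch_y):
        y = forward(x, W, b)
        for j in range(len(b)):
            d = (y[j] - t[j]) / n
            gb[j] += d
            for i in range(len(x)):
                gW[i][j] += x[i] * d
    # Update the parameters in place: W <- W - lr * dW, b <- b - lr * db
    for i in range(len(W)):
        for j in range(len(b)):
            W[i][j] -= lr * gW[i][j]
    for j in range(len(b)):
        b[j] -= lr * gb[j]


def accuracy(batch_x, batch_y, W, b):
    # Fraction of examples where argmax(y) == argmax(y')
    correct = 0
    for x, t in zip(batch_x, batch_y):
        y = forward(x, W, b)
        correct += int(y.index(max(y)) == t.index(max(t)))
    return correct / len(batch_x)


# Hypothetical toy data: 2 features, 2 classes, one-hot labels.
xs = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.2, 0.8]]
ys = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
W = [[0.0, 0.0], [0.0, 0.0]]  # zero init, like tf.zeros([784, 10])
b = [0.0, 0.0]

before = cross_entropy(xs, ys, W, b)
for _ in range(100):
    sgd_step(xs, ys, W, b, lr=0.5)
after = cross_entropy(xs, ys, W, b)
print(before, after, accuracy(xs, ys, W, b))
```

With zero-initialized weights the softmax output is uniform, so the initial loss is exactly $\ln 2$ for two classes (it would be $\ln 10$ for the ten MNIST classes); gradient descent then pushes the loss down and separates the two toy classes.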


corcosa, 8736 views