博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
用TensorFlow搭建一个全连接神经网络
阅读量:4193 次
发布时间:2019-05-26

本文共 3047 字,大约阅读时间需要 10 分钟。

用TensorFlow搭建一个全连接神经网络


说明

  • 本例子利用TensorFlow搭建一个全连接神经网络,实现对MNIST手写数字的识别。

先上代码

from tensorflow.examples.tutorials.mnist import input_dataimport tensorflow as tf# prepare datamnist = input_data.read_data_sets('MNIST_data', one_hot=True)xs = tf.placeholder(tf.float32, [None, 784])ys = tf.placeholder(tf.float32, [None, 10])# the model of the fully-connected networkweights = tf.Variable(tf.random_normal([784, 10]))biases = tf.Variable(tf.zeros([1, 10]) + 0.1)outputs = tf.matmul(xs, weights) + biasespredictions = tf.nn.softmax(outputs)cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),                                              reduction_indices=[1]))train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)# compute the accuracycorrect_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))with tf.Session() as sess:    init = tf.global_variables_initializer()    sess.run(init)    for i in range(1000):        batch_xs, batch_ys = mnist.train.next_batch(100)        sess.run(train_step, feed_dict={            xs: batch_xs,            ys: batch_ys        })        if i % 50 == 0:            print(sess.run(accuracy, feed_dict={                xs: mnist.test.images,                ys: mnist.test.labels            }))

代码解析

1. 读取MNIST数据

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

2. 建立占位符

xs = tf.placeholder(tf.float32, [None, 784])ys = tf.placeholder(tf.float32, [None, 10])
  • xs 代表图片像素数据, 每张图片(28×28)被展开成(1×784), 有多少图片还未定, 所以shape为None×784.
  • ys 代表图片标签数据, 0-9十个数字被表示成One-hot形式, 即只有对应bit为1, 其余为0.

3. 建立模型

weights = tf.Variable(tf.random_normal([784, 10]))biases = tf.Variable(tf.zeros([1, 10]) + 0.1)outputs = tf.matmul(xs, weights) + biasespredictions = tf.nn.softmax(outputs)cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),                                              reduction_indices=[1]))train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

使用Softmax函数作为激活函数:

ouput=Softmax(input×weight+bias)

4. 计算正确率

correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

5. 使用模型

with tf.Session() as sess:    init = tf.global_variables_initializer()    sess.run(init)    for i in range(1000):        batch_xs, batch_ys = mnist.train.next_batch(100)        sess.run(train_step, feed_dict={            xs: batch_xs,            ys: batch_ys        })        if i % 50 == 0:            print(sess.run(accuracy, feed_dict={                xs: mnist.test.images,                ys: mnist.test.labels            }))

运行结果

训练1000个循环, 准确率在87%左右.

Extracting MNIST_data/train-images-idx3-ubyte.gzExtracting MNIST_data/train-labels-idx1-ubyte.gzExtracting MNIST_data/t10k-images-idx3-ubyte.gzExtracting MNIST_data/t10k-labels-idx1-ubyte.gz0.10410.6320.73570.78370.79710.81470.82830.83760.84230.85010.85010.85330.85670.85970.85520.86470.86540.87010.87120.8712

参考


  • 我的个人主页:
  • 我的CSDN博客:
  • 我的简书:
  • 我的GitHub:

转载地址:http://kiloi.baihongyu.com/

你可能感兴趣的文章
sublime
查看>>
linux 内存函数
查看>>
sdcardfs
查看>>
csdn 代码拷贝编译错误解决方法
查看>>
软件开发经典书籍
查看>>
spinlock原理
查看>>
dup源码分析
查看>>
try_to_wakeup 选择其他核运行逻辑
查看>>
2021-03-28
查看>>
rtlinux
查看>>
OPPO Reno3系列旗舰官宣:骁龙765G+正反双曲面设计
查看>>
一加8系列新机有望亮相CES 2020:全系支持5G网络
查看>>
称对方攀附使用近似商标 “汽车之家”起诉索赔500万
查看>>
三星突然发布Galaxy S10 Lite和Note 10 Lite:有不同也有所同
查看>>
小米10/10 Pro详细规格曝光:120Hz高刷新率屏+66W超级闪充
查看>>
支付宝2019年账单周一见 你准备好了吗?
查看>>
三星年度旗舰S20要来了:潜望式长焦加一亿像素!
查看>>
2019年微信数据报告:男性用户最爱搜“小姐姐”,表情包最受欢迎的是它
查看>>
莫名其妙就发个手机!这家公司员工晒年终奖品:人手一部iPhone 11
查看>>
苹果iPhone发布13周年:累计销量近20亿部
查看>>