博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
TensorFlow官方文档学习02-MNIST初级课程
阅读量:6947 次
发布时间:2019-06-27

本文共 1994 字,大约阅读时间需要 6 分钟。

MNIST数据准备

MNIST数据是机器学习入门数据,相当于其他编程语言的第一个“hello world”程序。

  1. 首先在项目目录下新建文件夹“MNIST_data”

2. 官网下载数据的压缩包放入“MNIST_data”文件夹下。 数据官网:

3. 导入并测试数据集

from tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("MNIST_data/", one_hot=True)print("Training data size:", mnist.train.num_examples)#输出:#Extracting MNIST_data/train-images-idx3-ubyte.gz#Extracting MNIST_data/train-labels-idx1-ubyte.gz#Extracting MNIST_data/t10k-images-idx3-ubyte.gz#Extracting MNIST_data/t10k-labels-idx1-ubyte.gz#Training data size: 55000复制代码

Softmax回归实现

  1. 简要介绍
    Softmax回归可以实现多分类。我们所要用到的MNIST_data数据集其实是数字识别数据集,将图片中的数据识别成0~9这十个数,是一个多分类问题。
    图示:
    公式:
y=softmax(Wx+b)
softmax(z)=normalize(exp(z))
  1. 实现
import tensorflow as tfx = tf.placeholder("float", [None, 784]) #x是占位符placeholder。占位符可以让我们在程序运行计算时再输入这个值。此处x是一个形状为[none, 784]的张量,用来表示任意数量的MNIST图像,每一张图展开成784维的向量。(第一个维度none可以是任意长度)。W = tf.Variable(tf.zeros([784, 10]))#权重,初始化为0b = tf.Variable(tf.zeros([10]))#偏置,初始化为0y = tf.nn.softmax(tf.matmul(x,W) + b)#模型实现:tf.matmul是矩阵的乘法y_ = tf.placeholder("float",[None, 10])#用于输入正确值cross_entropy = -tf.reduce_sum(y_*tf.log(y))#计算交叉熵作为成本函数。-(y_*tf.log(y)是交叉熵的定义。tf.reduce_sum()计算张量所有元素的总和,即所有图片交叉熵的总和。train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)#使用梯度下降法不断的通过最小化交叉熵来调整参数W和b。init = tf.initialize_all_variables()#运行计算前,初始化变量sess = tf.Session()sess.run(init)for i in range(1000):#模型循环1000次,每一次循环中都随机抓取100个批处理数据点来替换之前的占位符。实际上,我们可以每一次训练的时候都使用所有的数据,但这样会造成极大的计算开销,所以每次随机选择100个数据进行训练,因此,这里实际上是随机梯度下降算法。  batch_xs, batch_ys = mnist.train.next_batch(100)  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))#统计预测正确的标签数。函数tf.argmax(),它能给出某个tensor对象在某一维上的其数据最大值所在的索引值。tf.argmax(y,1)返回预测的标签,tf.argmax(y_,1)返回正确的标签。本行代码返回一组布尔值。accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))#将这组布尔值转化为浮点数并求平均值。这个平均值就是准确率print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))#打印结果,填充占位符。#输出#0.9171复制代码

转载于:https://juejin.im/post/5b6cf8566fb9a04fab453c28

你可能感兴趣的文章