tensorflow官网给出的入门案例,即给定一堆手写数字图片,使用机器学习方法对该数字图片进行分类。
1.导入数据集,tensorflow可以直接使用如下方式在网上在线下载
import tensorflow as tf import ssl ssl._create_default_https_context = ssl._create_unverified_context#不加这个会报错的,当你urllib.urlopen一个 https 的时候会验证一次 SSL 证书 ,当目标使用的是自签名的证书时就会爆出该错误消息。 from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('MNIST_data', one_hot=True)#这个数据集每一张图片都有784个像素点28*28 xs = tf.placeholder(tf.float32,[None,784])#28*28 ys = tf.placeholder(tf.float32,[None,10])#每一个图片对应十个标签的输出
Extracting MNIST_data/train-images-idx3-ubyte.gz Extracting MNIST_data/train-labels-idx1-ubyte.gz Extracting MNIST_data/t10k-images-idx3-ubyte.gz Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
2.添加层的代码,这里只做一个仅有输入层和输出层的简单模型
def add_layer(inputs,in_size,out_size,activation_function=None): Weights = tf.Variable(tf.random_normal([in_size,out_size]))#定义一个大小为in_size和out_size的一个权重矩阵 biases = tf.Variable(tf.zeros([1,out_size])+0.1) #定义一个1行,out_size列的偏置。一般biases不推荐为0 Wx_plus_b = tf.matmul(inputs,Weights)+biases if activation_function is None: output = Wx_plus_b else: output = activation_function(Wx_plus_b) return output prediction = add_layer(xs,784,10,activation_function=tf.nn.softmax)#这里用的是softmax作为激活函数
3.设置误差函数
cross_entroy = tf.reduce_mean(-tf.reduce_sum(ys*tf.log(prediction),reduction_indices=[1])) #交叉熵
softmax作为激励函数+交叉熵作为损失函数是非常常用的输出层
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entroy)
4.训练
sess = tf.Session() sess.run(tf.global_variables_initializer()) def compute_accuracy(v_xs,v_ys): global prediction y_pre = sess.run(prediction,feed_dict={xs:v_xs}) correct_prediction = tf.equal(tf.arg_max(y_pre,1),tf.arg_max(v_ys,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#tf.cast()函数的作用是执行 tensorflow 中张量数据类型转换 res = sess.run(accuracy,feed_dict={xs:v_xs,ys:v_ys}) return res for i in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys}) if i%50==0: print(compute_accuracy(mnist.test.images,mnist.test.labels))
0.074 0.6654 0.7534 0.795 0.8156 0.8296 0.8352 0.8476 0.8521 0.8556 0.8583 0.8621 0.8601 0.8681 0.8695 0.8746 0.8701 0.8725 0.8748 0.8769
softmax+交叉熵:https://segmentfault.com/a/1190000017320763