千家信息网

TensorFlow卷积神经网络MNIST数据集实现方法是什么

发表于:2024-09-22 作者:千家信息网编辑
千家信息网最后更新 2024年09月22日,本篇内容主要讲解"TensorFlow卷积神经网络MNIST数据集实现方法是什么",感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习"TensorFlow卷积神经
千家信息网最后更新 2024年09月22日TensorFlow卷积神经网络MNIST数据集实现方法是什么

本篇内容主要讲解"TensorFlow卷积神经网络MNIST数据集实现方法是什么",感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习"TensorFlow卷积神经网络MNIST数据集实现方法是什么"吧!

这里使用TensorFlow实现一个简单的卷积神经网络,使用的是MNIST数据集。网络结构为:数据输入层-卷积层1-池化层1-卷积层2-池化层2-全连接层1-全连接层2(输出层),这是一个简单但非常有代表性的卷积神经网络。

import tensorflow as tfimport numpy as npimport input_datamnist = input_data.read_data_sets('data/', one_hot=True)print("MNIST ready")sess = tf.InteractiveSession()# 定义好初始化函数以便重复使用。给权重制造一些随机噪声来打破完全对称,使用截断的正态分布,标准差设为0.1,# 同时因为使用relu,也给偏执增加一些小的正值(0.1)用来避免死亡节点(dead neurons)def weight_variable(shape):    initial = tf.truncated_normal(shape, stddev=0.1)    return tf.Variable(initial)def bias_variable(shape):    initial = tf.constant(0.1, shape=shape)    return tf.Variable(initial)def conv2d(x, W):    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') # 参数分别指定了卷积核的尺寸、多少个channel、filter的个数即产生特征图的个数# 2x2最大池化,即将一个2x2的像素块降为1x1的像素。最大池化会保留原始像素块中灰度值最高的那一个像素,即保留最显著的特征。def max_pool_2x2(x):    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')n_input  = 784 # 28*28的灰度图,像素个数784n_output = 10  # 是10分类问题# 在设计网络结构前,先定义输入的placeholder,x是特征,y是真实的labelx = tf.placeholder(tf.float32, [None, n_input]) y = tf.placeholder(tf.float32, [None, n_output]) x_image = tf.reshape(x, [-1, 28, 28, 1]) # 对图像做预处理,将1D的输入向量转为2D的图片结构,即1*784到28*28的结构,-1代表样本数量不固定,1代表颜色通道数量# 定义第一个卷积层,使用前面写好的函数进行参数初始化,包括weight和biasW_conv1 = weight_variable([3, 3, 1, 32])b_conv1 = bias_variable([32])h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)h_pool1 = max_pool_2x2(h_conv1)# 定义第二个卷积层W_conv2 = weight_variable([3, 3, 32, 64])b_conv2 = bias_variable([64])h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)h_pool2 = max_pool_2x2(h_conv2)# fc1,将两次池化后的7*7共128个特征图转换为1D向量,隐含节点1024由自己定义W_fc1 = weight_variable([7*7*64, 1024])b_fc1 = bias_variable([1024])h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)# 为了减轻过拟合,使用Dropout层keep_prob = tf.placeholder(tf.float32)h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)# Dropout层输出连接一个Softmax层,得到最后的概率输出W_fc2 = weight_variable([1024, 10])b_fc2 = bias_variable([10])pred = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) #前向传播的预测值,print("CNN READY")# 定义损失函数为交叉熵损失函数cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=[1]))# 优化器optm = tf.train.AdamOptimizer(0.001).minimize(cost)# 定义评测准确率的操作corr = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)) # 对比预测值的索引和真实label的索引是否一样,一样返回True,不一样返回Falseaccuracy = tf.reduce_mean(tf.cast(corr, tf.float32))# 初始化所有参数tf.global_variables_initializer().run()print("FUNCTIONS READY")training_epochs = 1000 # 所有样本迭代1000次batch_size = 100 # 每进行一次迭代选择100个样本display_step = 1for i in range(training_epochs):    avg_cost = 0.    total_batch = int(mnist.train.num_examples/batch_size)    batch = mnist.train.next_batch(batch_size)    optm.run(feed_dict={x:batch[0], y:batch[1], keep_prob:0.7})    avg_cost += sess.run(cost, feed_dict={x:batch[0], y:batch[1], keep_prob:1.0})/total_batch    if i % display_step ==0: # 每10次训练,对准确率进行一次测试        train_accuracy = accuracy.eval(feed_dict={x:batch[0], y:batch[1], keep_prob:1.0})        test_accuracy = accuracy.eval(feed_dict={x:mnist.test.images, y:mnist.test.labels, keep_prob:1.0})        print("step: %d  cost: %.9f  TRAIN ACCURACY: %.3f  TEST ACCURACY: %.3f" % (i, avg_cost, train_accuracy, test_accuracy))print("DONE")

训练迭代1000次之后,测试分类正确率达到了98.6%

step: 999 cost: 0.000048231 TRAIN ACCURACY: 0.990 TEST ACCURACY: 0.986

在2000次的时候达到了99.1%

step: 2004 cost: 0.000042901 TRAIN ACCURACY: 0.990 TEST ACCURACY: 0.991

相比之前简单神经网络,CNN的效果明显较好,这其中主要的性能提升都来自于更优秀的网络设计,即卷积神经网络对图像特征的提取和抽象能力。依靠卷积核的权值共享,CNN的参数量并没有爆炸,降低计算量的同时也减轻了过拟合,因此整个模型的性能有较大的提升。

到此,相信大家对"TensorFlow卷积神经网络MNIST数据集实现方法是什么"有了更深的了解,不妨来实际操作一番吧!这里是网站,更多相关内容可以进入相关频道进行查询,关注我们,继续学习!

卷积 网络 神经 神经网络 数据 像素 特征 方法 函数 参数 结构 个数 代表 样本 输入 输出 迭代 最大 网络结构 全连 数据库的安全要保护哪些东西 数据库安全各自的含义是什么 生产安全数据库录入 数据库的安全性及管理 数据库安全策略包含哪些 海淀数据库安全审计系统 建立农村房屋安全信息数据库 易用的数据库客户端支持安全管理 连接数据库失败ssl安全错误 数据库的锁怎样保障安全 万网轻云服务器怎么样 英雄联盟服务器一直未响应 服务器cpu散热片在哪 数据库系统软件有哪些厂家 网络安全检查是什么时候 想学网络安全考研学什么专业 我的世界服务器如何用命令换皮肤 红量网络技术有限公司 北才软件开发 怀旧服服务器的几大主播 机电网络安全管理制度 最出名的软件开发公司 消费者数据库 零售户 宽带网络技术及计算机网络技术 南京互联网科技有限公司电话 服务器下载和客户端下载 软件开发怎么写 纵目科技产业互联网融资 oracle数据库报表 上海迎喜互联网科技有限公司 媒体广电网络安全论文 怎么能转进灰烬使者服务器 计算网络技术跟计算机科学 怎么找出数据库中重复的值 雅尔塔会议记录软件开发 数据库建表的5大约束 那个代理服务器可用 明朝思维导图软件开发 计算机网络技术包含学科 安徽多功能软件开发平均价格
0