千家信息网

python中怎么使用Tensorflow训练BP神经网络实现鸢尾花分类

发表于:2024-11-11 作者:千家信息网编辑
千家信息网最后更新 2024年11月11日,这篇文章主要介绍了python中怎么使用Tensorflow训练BP神经网络实现鸢尾花分类的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇python中怎么使用Tensor
千家信息网最后更新 2024年11月11日python中怎么使用Tensorflow训练BP神经网络实现鸢尾花分类

这篇文章主要介绍了python中怎么使用Tensorflow训练BP神经网络实现鸢尾花分类的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇python中怎么使用Tensorflow训练BP神经网络实现鸢尾花分类文章都会有所收获,下面我们一起来看看吧。

使用软件

Python 3.8,Tensorflow2.0

问题描述

鸢尾花主要分为狗尾草鸢尾(0)、杂色鸢尾(1)、弗吉尼亚鸢尾(2)。
人们发现通过计算鸢尾花的花萼长、花萼宽、花瓣长、花瓣宽可以将鸢尾花分类。
所以只要给出足够多的鸢尾花花萼、花瓣数据,以及对应种类,使用合适的神经网络训练,就可以实现鸢尾花分类。

搭建神经网络

输入数据是花萼长、花萼宽、花瓣长、花瓣宽,是n行四列的矩阵。
而输出的是每个种类的概率,是n行三列的矩阵。
我们采用BP神经网络,设X为输入数据,Y为输出数据,W为权重,B偏置。有

y=x∗w+b

因为x为n行四列的矩阵,y为n行三列的矩阵,所以w必须为四行三列的矩阵,每个神经元对应一个b,所以b为一行三列的的矩阵。
神经网络如下图。

所以,只要找到合适的w和b,就能准确判断鸢尾花的种类。
下面就开始对这两个参数进行训练。

训练参数

损失函数

损失函数表达的是预测值(y*)和真实值(y)的差距,我们采用均方误差公式作为损失函数。

损失函数值越小,说明预测值和真实值越接近,w和b就越合适。
如果人来一组一组试,那肯定是不行的。所以我们采用梯度下降算法来找到损失函数最小值。
梯度:对函数求偏导的向量。梯度下降的方向就是函数减少的方向。

其中a为学习率,即梯度下降的步长,如果a太大,就可能错过最优值,如果a太小,则就需要更多步才能找到最优值。所以选择合适的学习率很关键。

参数优化

通过反向传播来优化参数。
反向传播:从后向前,逐层求损失函数对每层神经元参数的偏导数,迭代更新所有参数。
比如

可以看到w会逐渐趋向于loss的最小值0。
以上就是我们训练的全部关键点。

代码

数据集

我们使用sklearn包提供的鸢尾花数据集。共150组数据。
打乱保证数据的随机性,取前120个为训练集,后30个为测试集。

# 导入数据,分别为输入特征和标签x_data = datasets.load_iris().data ## 存花萼、花瓣特征数据y_data = datasets.load_iris().target # 存对应种类# 随机打乱数据(因为原始数据是顺序的,顺序不打乱会影响准确率)# seed: 随机数种子,是一个整数,当设置之后,每次生成的随机数都一样(为方便教学,以保每位同学结果一致)np.random.seed(116)  # 使用相同的seed,保证输入特征和标签一一对应np.random.shuffle(x_data)np.random.seed(116)np.random.shuffle(y_data)tf.random.set_seed(116)# 将打乱后的数据集分割为训练集和测试集,训练集为前120行,测试集为后30行x_train = x_data[:-30]y_train = y_data[:-30]x_test = x_data[-30:]y_test = y_data[-30:]# 转换x的数据类型,否则后面矩阵相乘时会因数据类型不一致报错x_train = tf.cast(x_train, tf.float32)x_test = tf.cast(x_test, tf.float32)# from_tensor_slices函数使输入特征和标签值一一对应。(把数据集分批次,每个批次batch组数据)train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

参数

# 生成神经网络的参数,4个输入特征故,输入层为4个输入节点;因为3分类,故输出层为3个神经元# 用tf.Variable()标记参数可训练w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1)) # 四行三列,方差为0.1b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1)) # 一行三列,方差为0.1

训练

a = 0.1  # 学习率为0.1epoch = 500  # 循环500轮# 训练部分for epoch in range(epoch):  # 数据集级别的循环,每个epoch循环一次数据集    for step, (x_train, y_train) in enumerate(train_db):  # batch级别的循环 ,每个step循环一个batch        with tf.GradientTape() as tape:  # with结构记录梯度信息            y = tf.matmul(x_train, w1) + b1  # 神经网络乘加运算            y = tf.nn.softmax(y)  # 使输出y符合概率分布            y_ = tf.one_hot(y_train, depth=3)  # 将标签值转换为独热码格式,方便计算loss            loss = tf.reduce_mean(tf.square(y_ - y))  # 采用均方误差损失函数mse = mean(sum(y-y*)^2)        # 计算loss对w, b的梯度        grads = tape.gradient(loss, [w1, b1])        # 实现梯度更新 w1 = w1 - lr * w1_grad    b = b - lr * b_grad        w1.assign_sub(a * grads[0])  # 参数w1自更新        b1.assign_sub(a * grads[1])  # 参数b自更新

测试

# 测试部分total_correct, total_number = 0, 0for x_test, y_test in test_db:    # 前向传播求概率    y = tf.matmul(x_test, w1) + b1    y = tf.nn.softmax(y)    predict = tf.argmax(y, axis=1)  # 返回y中最大值的索引,即预测的分类    # 将predict转换为y_test的数据类型    predict = tf.cast(predict, dtype=y_test.dtype)    # 若分类正确,则correct=1,否则为0,将bool型的结果转换为int型    correct = tf.cast(tf.equal(predict, y_test), dtype=tf.int32)    # 将每个batch的correct数加起来    correct = tf.reduce_sum(correct)    # 将所有batch中的correct数加起来    total_correct += int(correct)    # total_number为测试的总样本数,也就是x_test的行数,shape[0]返回变量的行数    total_number += x_test.shape[0]# 总的准确率等于total_correct/total_numberacc = total_correct / total_numberprint("测试准确率 = %.2f %%" % (acc * 100.0))my_test = np.array([[5.9, 3.0, 5.1, 1.8]])print("输入 5.9  3.0  5.1  1.8")my_test = tf.convert_to_tensor(my_test)my_test = tf.cast(my_test, tf.float32)y = tf.matmul(my_test, w1) + b1y = tf.nn.softmax(y)species = {0: "狗尾鸢尾", 1: "杂色鸢尾", 2: "弗吉尼亚鸢尾"}predict = np.array(tf.argmax(y, axis=1))[0]  # 返回y中最大值的索引,即预测的分类print("该鸢尾花为:" + species.get(predict))

关于"python中怎么使用Tensorflow训练BP神经网络实现鸢尾花分类"这篇文章的内容就介绍到这里,感谢各位的阅读!相信大家对"python中怎么使用Tensorflow训练BP神经网络实现鸢尾花分类"知识都有一定的了解,大家如果还想学习更多知识,欢迎关注行业资讯频道。

数据 鸢尾 训练 神经 鸢尾花 参数 神经网络 网络 分类 函数 输入 损失 梯度 矩阵 测试 花瓣 花萼 特征 循环 合适 数据库的安全要保护哪些东西 数据库安全各自的含义是什么 生产安全数据库录入 数据库的安全性及管理 数据库安全策略包含哪些 海淀数据库安全审计系统 建立农村房屋安全信息数据库 易用的数据库客户端支持安全管理 连接数据库失败ssl安全错误 数据库的锁怎样保障安全 软件开发自测标准文档 2g网络和4g网络安全 云服务器如何配置安全组 腾讯的数据库平台开发师 密码学与网络安全10章答案 方舟生存服务器延迟na acc数据库密码忘记 东莞来思网络技术有限公司 计算机分为客户机服务器模式 黄河之滨网络安全 智能化软件开发服务哪些行业 三台服务器价格表 黑龙江特种网络技术产品介绍 2017网络安全知识问答 迅雷离线下载服务器 遂昌允诚网络技术有限公司 通用的数据库管理工具 计算机网络工程师和网络技术 为什么lol服务器连接失败 测控网络技术于洋课后答案 戴尔服务器运营卡 软件开发公司的英语标语 软件开发面试时hr说我话少 my sql数据库操作面板 网络安全的表现形式有哪些 沙特阿拉伯网络安全 思讯数据库可以用原来的仓库吗 等保2.0 网络安全设备 中央网络安全和信息化办公室 丝路杯网络安全技能大赛
0