纯Python实现鸢尾属植物数据集神经网络模型
摘要: 本文以Python代码完成整个鸾尾花图像分类任务,没有调用任何的数据包,适合新手阅读理解,并动手实践体验下机器学习方法的大致流程。
尝试使用过各大公司推出的植物识别APP吗?比如微软识花、花伴侣等这些APP。当你看到一朵不知道学名的花时,只需要打开植物识别APP,拍摄一张你所想辨认的植物照片并上传,APP会自动识别出该花的品种及详细介绍,感觉手机中装了一个知识渊博的生物学家,是不是很神奇?其实,背后的原理很简单,是一个图像分类的过程,将上传的图像与手机中预存的数据集或联网数据进行匹配,将其分类到对应的类别即可。随着深度学习方法的应用,图像分类的精度越来越高,在部分数据集上已经超越了人眼的能力。
相对于传统神经网络的方法而言,深度学习方法一般对数据集规模、硬件平台有着比较高的要求,如果只是单纯的想尝试了解图像分类任务的基本流程,建议采用小数据集样本及传统的神经网络方法实现。本文将带领读者采用鸢尾属植物数据集(Iris Data Set)来实现一个分类任务,整个鸢尾属植物数据集是机器学习中历史悠久的数据集,比现在常用的数字手写体数据集(Mnist Data Set)数据集还要早得多,该数据集来源于英国著名的统计学家、生物学家Ronald Fiser。本文在不使用相关软件库的情况下,从头开始构建针对鸢尾属植物数据的神经网络模型,对其进行训练并获得好的结果。
鸢尾属植物数据集是用于测试机器学习算法的最常用数据集。该数据包含四种特征,萼片长度、萼片宽度、花瓣长度和花瓣宽度,用于鸢尾属植物的不同物种(versicolor, virginica和setosa)。此外,每个物种有50个实例(数据行),下面让我们看看样本数据分布情况。
我们将在这个数据集上使用神经网络构建分类模型。为了简单起见,使用花瓣长度和花瓣宽度作为特征,且只有两类物种:versicolor和virginica。下面就让我们在Python中逐步训练针对该样本数据集的神经网络:
步骤1:准备鸢尾属植物数据集
将Iris数据集导入python并对数据进行子集划分以保留行之间的相关性:
蓝色点代表Versicolor物种,红色点代表Virginica物种。本文构建的神经网络将在这些数据上进行训练,以期最后能正确地分类物种。
步骤2:初始化参数(权重和偏置)
下面构建一个具有单个隐藏层的神经网络。此外,将隐藏图层的大小设置为6:
步骤3:前向传播(forward propagation)
在前向传播过程中,使用tanh激活函数作为第一层的激活函数,使用sigmoid激活函数作为第二层的激活函数:
步骤4:计算代价函数(cost function)
目标是使得计算的代价函数小化,本文采用交叉熵(cross-entropy)作为代价函数:
步骤5:反向传播(back propagation)
计算反向传播过程,主要是计算代价函数的导数:
步骤6:更新参数
使用反向传播过程中计算的梯度来更新权重和偏置:
步骤7:建立神经网络
将以上所有函数组合起来以创建设计的神经网络模型。总而言之,下面是模型函数的整体顺序:
1、初始化参数
2、前向传播
3、计算代价函数
4、反向传播
5、更新参数
步骤8:跑动模型
将隐藏层节点设置为6,最大迭代次数设置为10,000次,并每隔1000次打印出训练的结果:
步骤9:画出分类边界
从图中可以观察到,只有四个点被错误分类。虽然我们可以调整模型来进一步地提高模型训练精度,但该些操作显然会导致过拟合现象的出现。
资源
本文作者:【方向】
本文为云栖社区原创内容,未经允许不得转载。