运用tensorflow写的第一个神经网络

        因为实训课要用LSTM+attention机制在钢材领域做一个关系抽取。作为仅仅只学过一点深度学习网络的小白在b站上学习了RNN,LSTM的一些理论知识。

但只懂得一些理论知识是无法完成关系抽取的任务的。于是从图书馆借来《tensoflow实战-----深度学习框架》,在此开始记录我的tensorflow神经网络编程!

       首先先介绍一下tensorflow的运作机制,对一个具体的计算而言,一般可以分为两个阶段,第一个阶段用来定义计算图中的计算,第二个阶段用来执行计算。

有了这个概念之后,就会发现这一操作能很好的将框架定义部分,和模型训练部分很好的分开,以下是第一次实验的代码:一个简单的分类问题,一个2,3,1(三层,每一层的节点数)的神经网络。

import tensorflow as tffrom numpy.random import RandomStatebatch_size = 8w1 = tf.Variable(tf.random_normal((2, 3), stddev=1, seed=1))//随机初始化权重,第二个参数为为标准差w2 = tf.Variable(tf.random_normal((3, 1), stddev=1, seed=1))//随机初始化权重x = tf.placeholder(tf.float32, shape=(None, 2), name="x_input")//placeholder一般用来在训练时存放输入数据,因为如果定义成常量的话,所消耗的空间太大y_=tf.placeholder(tf.float32, shape=(None, 1), name="y_input")//参数介绍,需要定义类型和维度,None的意思是,不知道有几组训练数biases1 = tf.Variable(tf.random_normal((1,3),stddev=1))//定义偏置,其实所谓偏置就是截距的概念biases2 = tf.Variable(tf.random_normal((1,1),stddev=1))#a = tf.matmul(x, w1)+biases1//以下是实现前向传播a = tf.sigmoid(tf.matmul(x, w1)+biases1)//用sigmoid函数充当激活函数,用来去线性化y = tf.matmul(a, w2)+biases2y = tf.sigmoid(y)#损失函数选用交叉熵函数cross_entropy = -tf.reduce_mean(y_*tf.log(tf.clip_by_value(y, 1e-10, 1.0))+(1-y)*tf.log(tf.clip_by_value(1-y, 1e-10, 1.0)))#选择优化方法(即更新权重所用的反向传播的方法,这个adam法还不知道啥意思,目前只知道梯度下降)train_step = tf.train.AdamOptimizer(0, 0.001).minimize(cross_entropy)#生成随机数据集rdm = RandomState(1)#随机因子为1dataset_size = 128X = rdm.rand(dataset_size, 2)Y = [[int(x1+x2<1)] for (x1, x2) in X]//生成会话开始训练模型,即前面所提到的执行计算的阶段with tf.Session() as sess:   //tensorflow中所有张量都要初始化  initall = tf.global_variables_initializer()    sess.run(initall)#print(sess.run(biases1))    print(sess.run(w1))print(sess.run(w2))//训练集中抽取一小个部分叫一个batch,训练过程是一个batch一个batch训练的    steps = 5000    for i in range(steps):        start = (i*batch_size)%dataset_size        end = min(start+batch_size, dataset_size)        sess.run(train_step, feed_dict={x:X[start:end],y_:Y[start:end]})        //每训练1000次查看一下训练结果,即交叉熵函数的值,越小越好    if(i%1000==0):              total_cross=sess.run(cross_entropy, feed_dict={x:X, y_:Y})              print(i,"  ",total_cross)   //最后查看一下最后更新的权重print(sess.run(w1))print(sess.run(w2))第一次写博客,也是初学,有问题请大家指出哈。

相关推荐