如何使用tensorflow.js在Chrome上训练神经网络

如何使用tensorflow.js在Chrome上训练神经网络

本教程只是简单演示如何使用脚本语言(javascript)在浏览器中创建神经网络,并进行训练和预测。

在本教程中,我们将构建一个模型来推断两个数字之间的关系,其中y = 2x -1 (y = 2x -1)。

让我们开始创建一个基本的html文件

如何使用tensorflow.js在Chrome上训练神经网络

现在我们需要导入tensorflow.js机器学习库

如何使用tensorflow.js在Chrome上训练神经网络

创建训练函数

我们需要使函数异步,以便它可以在后台运行而不影响我们的网页。

如何使用tensorflow.js在Chrome上训练神经网络

函数说明:

我们在函数中异步调用model.fit(),为此我们需要将神经网络模型作为参数传递给异步函数。

我们对模型使用了一个wait,这样它就可以一直等到训练结束。

我们在训练结束后使用javascript进行回调,比如在本例中,我们调用onEpochEnd来打印训练结束后的最终损失。

用单个神经网络创建模型

如何使用tensorflow.js在Chrome上训练神经网络

神经网络模型摘要

model.summary()

如何使用tensorflow.js在Chrome上训练神经网络

具有1个神经元网络的模型摘要

PS:神经网络模型显示训练参数为2,这是由于有权重和偏差这两个参数(即w和c)。

样本数据

如何使用tensorflow.js在Chrome上训练神经网络

说明:

就像我们在python中使用numpy一样,我们需要使用tf.tensor2d()函数来定义二维数组。
在tensor2d函数中提及数组的形状很重要。[6,1] 为数组形状。

异步训练和预测

如何使用tensorflow.js在Chrome上训练神经网络

添加一些数据以显示在网页上

如何使用tensorflow.js在Chrome上训练神经网络

最终的html文件将如下所示

<!DOCTYPE html><html><head> <title>Training a model on browser</title> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script><script lang="js">        async function doTraining(model){            const history =                   await model.fit(xs, ys,                         { epochs: 500,                          callbacks:{                              onEpochEnd: async(epoch, logs) =>{                                  console.log("Epoch:"                                               + epoch                                               + " Loss:"                                               + logs.loss);                                                                }                          }                        });        }        const model = tf.sequential();        model.add(tf.layers.dense({units: 1, inputShape: [1]}));        model.compile({loss:'meanSquaredError',                        optimizer:'sgd'});        model.summary();        const xs = tf.tensor2d([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], [6, 1]);        const ys = tf.tensor2d([-3.0, -1.0, 2.0, 3.0, 5.0, 7.0], [6, 1]);        doTraining(model).then(() => {            alert(model.predict(tf.tensor2d([10], [1,1])));        });    </script></head><body> <h1 align="center">Press 'f12' key or 'Ctrl' + 'Shift' + 'i' to check whats going on</h1></body></html>

如何使用tensorflow.js在Chrome上训练神经网络

最后,在浏览器上训练和预测模型

使用Google Chrome浏览器打开html文件,然后按“ F12”键检查开发者控制台。

如何使用tensorflow.js在Chrome上训练神经网络

您可以在开发者控制台中看到训练阶段的损失。训练结束后,网页上会自动显示预测结果提示框。

如何使用tensorflow.js在Chrome上训练神经网络

这是一个提示框,显示输入数字10的预测。根据方程Y = 2X-1,输入x = 10的输出应该是y = 19。我们的神经网络模型预测了18.91,已经非常接近了。

相关推荐