TensorFlow.js:让你在浏览器中玩转机器学习
当谈及机器学习和谷歌的TensorFlow时候,相比于JavaScript和其他浏览器,大多数人会想到Python和专用硬件。本文解释了TensorFlow.js的用途,以及机器学习在浏览器中运行的意义。
TensorFlow.js 是一个JavaScript库,可以在浏览器中运行,也可以通过服务器上的Node.js.运行。但是,在本文中,我们着眼的范围仅在于浏览器中的应用程序。TensorFlow.js的界面完全基于TensorFlow的高级API Keras。Keras代码通常只能和TensorFlow.js代码区分开,大多数差异是因为在Python和JavaScript配置参数中语言结构不同而导致的。
每台GPU都有机器学习
TensorFlow.js可以让你从零开始机器学习。如果有可使用的必需数据,你可以直接在浏览器中训练或执行模型。为此,TensorFlow.js通过WebGL浏览器API,使用计算机的图形卡(GPU)。 这样一来,由于WebGL浏览器需要一些技巧,才能强制执行TensorFlow.js所需的矩阵乘法,导致部分性能最终会丢失。然而,这是无法避免的,因为TensorFlow.js作为一种机器学习策略,是神经网络的主要支撑。这些损耗,可以在训练期间或预测期间,通过矩阵乘法准确反映出来。
到这里,我们已经看到了TensorFlow.js胜过TensorFlow的第一个优势:尽管TensorFlow目前只能通过CUDA支持NVIDIA GPU,但TensorFlow.js可以和任意显卡配合使用。清单1包含了使用High Level API在浏览器中创建顺序神经网络的代码。如果你了解TensorFlow的Keras API,那一切操作都很清楚,教程也可以在 tensorflow.org 上找到。
Listing 1
// create a sequential model
const model = tf.sequential();
// add a fully connected layer with 10 units (neurons)
model.add(tf.layers.dense({units: 10}));
// add a convolutional layer to work on a monochrome 28x28 pixel image with 8
// filter units
model.add(tf.layers.conv2d({
inputShape: [28, 28, 1],
filters: 8
}));
// compile the model like you would do in Keras
// the API speaks for itself
model.compile({
optimizer: 'adam',
loss: 'categoricalCrossentropy',
metrics: ['accuracy']
});
与所有浏览器API交互
寻找不同操作系统和设备上的接口地址,着实令人苦恼。而基于浏览器的应用
程序,在开发过程中便无需这样,它们甚至可以访问相机或麦克风一类的复杂硬件,这些硬件都固定在HTML标准中,目前所有的浏览器都支持。
此外,浏览器的性质,既为浏览器交互而设计,同样也适合你。因此,如今想获取一个具有机器学习功能的交互式应用程序比以往更容易。
举个例子,我们有一个简单的游戏Scavenger Hunt,它可以在手机浏览器中运行,从而给我们带来极大的乐趣。
如下图所示,在现实世界中,你必须快速找到与显示的表情符号相匹配的对象。为此,内置摄像机和训练好的神经网络,可以检测到相匹配的对象。即使没有机器学习技能,任何JavaScript开发人员都可以运用这样的模型。
机器学习,无需在每台计算机上安装
TensorFlow.js允许你使用TensorFlow部署预先创建的模型。此模型可能已经在强硬件上完全或一定程度上接受过训练。然而,在浏览器中,它仅被归结为应用程序或进一步训练。图2显示了通过不同姿势控制的吃豆人变体,它基于预先训练的网络,在浏览器中根据自己的姿势进行再训练,我们称之为迁移学习。
模型由提供的程序进行转换,并且可以在加载后通过输入类似于以下内容的行,进行异步加载:
const model = await tf.loadModel('model.json')
之后,模型不再与浏览器中直接创建的模型相区分开。因此,它便于进行预测,接着,预测又在GPU上异步执行:
const example = tf.tensor([[150, 45, 10]]);
const prediction = model.predict(example);
const value = await prediction.data();
除了通过游戏进行娱乐外,这里还可以设想更多有用的应用程序。比如通过手势进行导航或互动,可以为残疾人或特殊情况下的人提供帮助。 正如前面已经提到的:只需加载一个网站,即可完成所有操作。
位置检测技术的另一案例,是下图中的PoseNet。它已经过预先训练,即使图片中有多个人,它也可以识别脸部,手臂和腿部的位置。在这里,即使有一定的距离,我们也去有能力去有效地控制重要程序。PoseNet的使用非常简单,甚至不需要机器学习领域的基础知识。
清单2进行了概述。
Listing 2
import * as posenet from '@tensorflow-models/posenet';
import * as tf from '@tensorflow/tfjs';
// load the posenet model
const model = await posenet.load();
// get the poses from a video element linked to the camera
const poses = await model.estimateMultiplePoses(video);
// poses contain
// - confidence score
// - x, y positions
用户数据无需离开浏览器
特别是现在,根据GDPR进行数据保护,已经越来越重要。人们会考虑,他们是否想在计算机上有特定的cookie,或者是否愿意把用户的统计数据发送给制造商,用于改善软件的用户体验。如果反过来,会怎么样?制造商提供了如何使用软件的一般模型,类似于上述的吃豆人游戏,它通过转移学习模型来适应个人用户。尽管这方面成果不多,但非常有发展潜力,让我们拭目以待。
总结
首先,浏览器中的机器学习似乎对许多开发人员没有多大意义。但是,如果你仔细研究,就会发现它的应用可能性,这是其他平台无法提供的:
1.培训:你可以直接与机器学习概念进行交互,通过实验进行学习。
2.开发:如果你已经拥有或想要或需要构建JS应用程序,则可以直接使用或训练机器学习模型。
3.游戏:仅通过相机进行实时位置估算(当前相机前方的人们如何移动)或图像识别,可以与游戏直接结合。已经有一些非常酷的游戏案例,但是,你可以做的远不止游戏。
4.部署:假设你已经拥有了机器学习模型,想知道如何投入生产。可以用浏览器来解决这个问题。即使是已经完成的模型也可以集中到您自己的应用程序中,并无需深入了解机器学习。
5.交互式可视化:用于交互式项目甚至艺术项目。
正如我们在上图中看到的,对于相同的硬件,在TensorFlow上的性能仍有不足。在1080GTX GPU上运行后,作为比较,我们测量出使用MobileNet进行预测的时间,因为提到了它的运用示例。在这种情况下,TensorFlow的运行速度比TensorFlow.js快了三到四倍。但是,两个值都非常低。WebGPU标准可以更直接地访问GPU,有望实现更好的性能。