Java还欠缺什么才能真正支持机器/深度学习?

如何让团队开始使用ML以及如何最好地将ML与我们运行的现有系统集成?

实际上没有用Java构建的ML框架(有DL4J,但我真的不知道有谁使用它,MXNet有一个Scala API而不是Java,而且它不是用Java编写的,Tensorflow有一个不完整的Java API),但是Java在企业中拥有巨大的使用范围,在过去的20年中,在全球范围内投资了数万亿美元的金融服务,交易,电子商务和电信公司 - 这个名单是无穷无尽的。对于机器学习,“第一个公民”编程语言不是Java,而是Python。

就个人而言,我喜欢用Python和Java编写代码,但Frank Greco提出了一个让我思考的有趣问题:

Java还需要什么才可在ML中与Python竞争?如果Java认真对待真正支持机器学习怎么办?

很重要么?

自1998年以来,就多个企业的变革而言,Java一直处于领先地位 - 网络,移动,浏览器与原生,消息传递,i18n和l10n全球化支持,扩展和支持各种企业信息存储值得一提的是,从关系数据库到Elasticsearch。

机器学习行业并非如此。Java团队如果进入ML只能有两个选择:

  1. 在Python中重新训练/共同训练。
  2. 使用供应商API为您的企业系统添加机器学习功能。

这两种选择都不是真的很好。第一个需要大量的前期时间和投资加上持续的维护成本,而第二个风险是供应商锁定,供应商解除支持,引入第三方组件(需要支付网络价格),这可能是一个性能关键系统,并且需要您可以在组织边界之外共享数据 - 对某些人来说是不行的。

在我看来,最具破坏性的是文化消耗的可能性 - 团队无法改变他们不理解或无法维护的代码,Java团队有可能在企业计算的下一波浪潮机器学习浪潮中落后 。

因此,Java编程语言和平台拥有一流的机器学习支持是非常重要,如果没有,Java将面临被未来5到10年内支持ML的语言慢慢取代的风险。

为什么Python在ML中占据主导地位?

首先,让我们考虑为什么Python是机器学习和深度学习的主要语言。

我怀疑这一切都始于一个功能 - 列表的切片slicing支持。这种支持是可扩展的:任何实现__getitem__和__setitem__方法的Python类都可以使用这种语法进行切片。下面的代码段显示了这个Python功能的强大和自然性。

a = [1, 2, 3, 4, 5, 6, 7, 8]
print(a[1:4])
#returns [2, 3, 4] - -挑选出中间元素的切片
print(a[1:-1])
#returns [2, 3, 4, 5, 6, 7] - 跳过第0和最后一个元素
print(a[4:])
#returns [5, 6, 7, 8] - 终点默认
print(a[:4])
#returns [1, 2, 3, 4] -开始点被默认

当然,还有更多。与旧的Java代码相比,Python代码更简洁,更简洁。支持未经检查的异常,开发人员可以轻松地编写一次性Python脚本来尝试填充,而不会陷入“一切都是一个类”的Java思维模式中。使用Python很容易。

但是现在我认为是主要因素 - 尽管Python社区在维持2.7和3之间的凝聚力方面做了一顿狗晚餐,但他们在构建设计良好,快速的数字计算库(NumPy)方面做得更好 。Numpy是围绕ndarray构建的 - N维数组对象。直接来自文档:“ NumPy的主要对象是同构多维数组。它是一个元素表(通常是数字),所有相同的类型,由正整数元组索引 “。


NumPy中的所有内容都是将数据放入ndarray然后对其执行操作。NumPy支持多种类型的索引,广播,矢量化以提高速度,并且通常允许开发人员轻松创建和操作大型数字数组。

下一个片段显示了ndarray 索引和正在进行的广播,这些是ML / DL中的核心操作。

import numpy as np
#Simple broadcast example
a = np.array([1.0, 2.0, 3.0])
b = 2.0
c = a * b
print(c)
#returns [ 2. 4. 6.] - the scalar b is automatically promoted / broadcast and applied to the vector a to create c
#return返回[2. 4. 6.] - 标量b被自动提升/广播并应用于向量a以创建c
#2-d (matrix with rank 2) indexing in NumPy - this extends to Tensors - i.e. rank > 2
y = np.arange(35).reshape(5,7)
print(y)
# array([[ 0, 1, 2, 3, 4, 5, 6],
# [ 7, 8, 9, 10, 11, 12, 13],
# [14, 15, 16, 17, 18, 19, 20],
# [21, 22, 23, 24, 25, 26, 27],
# [28, 29, 30, 31, 32, 33, 34]])
print(y[0,0])
# 单个单元格访问 - notation is row-major, returns 0
print(y[4])
# returns all of row 4: array([28, 29, 30, 31, 32, 33, 34])
print(y[:,2])
# returns all of column 2: array([ 2, 9, 16, 23, 30])

处理大型多维数字数组是机器学习编码的核心,尤其是深度学习。深度神经网络是节点格和边格的数字模型。在训练网络或对其进行推理时的运行时操作需要快速矩阵乘法。

NumPy已经促成并启用了更多 -  scipy,pandas和许多其他依赖于NumPy的库。领先的深度学习库(Tensorflow来自谷歌,PyTorch来自Facebook)都投入巨资在Python。Tensorflow还有其他用于Go,Java和JavaScript的API,但它们不完整且被视为不稳定。PyTorch最初是用Lua编写的,当它们从2017年相当小的语言转移到主要的Python ML生态系统时,它的受欢迎程度大幅上升。

Python的缺点

Python不是一种完美的语言 - 特别是最流行的Python运行时 - CPython - 具有全局解释器锁(GIL),因此性能缩放并不简单。此外,像PyTorch和Tensorflow这样的Python DL框架仍然将核心方法交给不透明的实现。例如,NVidia 的cuDNN库对PyTorch中[url=https://pytorch.org/docs/stable/nn.html#rnn]RNN / LSTM实现[/url]的范围产生了深远的影响。RNN和LSTM是一种非常重要的DL技术,特别适用于商业应用,因为它们专门用于对顺序,可变长度序列进行分类和预测 - 例如网络点击流,文本片段,用户事件等。

为了公平对待Python,这种不透明度/限制几乎适用于任何未用C或C ++编写的ML / DL框架。为什么?因为为了获得核心的最大性能,像矩阵乘法这样的高频操作,开发人员尽可能“接近底层冶金工艺”。

Java需要做些什么才能参与竞争?

我建议Java平台有三个主要的补充,如果存在的话,会促使Java中一个健康且蓬勃发展的机器学习生态系统的萌芽:

1.在核心语言中添加本机索引/切片支持,以与Python的易用性和表现力相媲美,可能以现有的有序集合List <E>接口为中心。这种支持还需要承认重载以支持#2点。


2.构建Tensor实现 - 可能在java.math包中,但也可以桥接到Collections API。这组类和接口将作为ndarray的等价物,并提供额外的索引支持 - 特别是三种类型的NumPy索引:字段访问,基本切片和编码ML所必需的高级索引。

3.支持广播 - 任意(但兼容)维度的标量和张量。

如果在核心Java语言和运行时中存在这三件事,它将开辟构建“ NumJava ” 的道路,相当于NumPy。巴拿马项目还可以用于提供对CPU,GPU,TPU等运行的快速张量操作的矢量化低级访问,以帮助Java ML成为最快的。

我并不是说这些补充是微不足道的 - 远非如此,但Java平台的潜在优势是巨大的。

下面的代码片段展示了我们的NumPy广播和索引示例如何在NumJava中使用Tensor类,核心语言支持切片语法,并尊重当前对运算符重载的限制。

//Java广播的张量
//使用Java 10中的var语法进行简洁性
// Java不支持运算符重载,所以我们不能做“a * b”
//我们应该将其添加到需求列表中吗?
var a = new Tensor([1.0, 2.0, 3.0]);
var b = 2.0;
var c = a.mult(b);
/**
 * And a snippet showing how the Java Tensor class could look.
 *显示Java Tensor类的外观的片段。
 */
import static java.math.Numeric.arange;
//arange returns a tensor instance and reshape is defined on tensor
var y = arange(35).reshape(5,7);
System.out.println(y);
// tensor([[ 0, 1, 2, 3, 4, 5, 6],
// [ 7, 8, 9, 10, 11, 12, 13],
// [14, 15, 16, 17, 18, 19, 20],
// [21, 22, 23, 24, 25, 26, 27],
// [28, 29, 30, 31, 32, 33, 34]])
System.out.println(y[0,0]);
// single cell access - notation is row-major, returns 0
System.out.println(y[4]);
// returns all of row 4 (5th row starting from 0 idx): tensor([28, 29, 30, 31, 32, 33, 34])
System.out.println(y[:,2]);
// returns all of column 2 (3rd col starting from 0 idx): tensor([ 2, 9, 16, 23, 30])

总结

从本文中概述的实用起点开始,我们可以拥有用Java编写并在JRE上运行的尽可能多的机器/深度学习框架,因为我们有Web,持久性或XML解析器 - 想象一下!我们可以设想Java框架支持卷积神经网络(CNN)用于前沿计算机视觉,像LSTM这样的循环神经网络实现对于顺序数据集(对业务至关重要),具有尖端的ML功能,如自动差异化等。然后,这些框架将为下一代企业级系统提供动力并为其提供动力 - 所有这些系统都使用相同的工具 - IDE,测试框架和持续集成。

作者:JDON

原文:https://www.jdon.com/51172

Java还欠缺什么才能真正支持机器/深度学习?

相关推荐