谷歌jax:下一代深度学习框架大战的前夜(代码片段)

ScorpioDoctor ScorpioDoctor     2022-12-05     116

关键词:

机器学习工具的革命:下一代深度学习框架大战的前夜

本文来自 人工智能社区 http://studyai.com/article/a6b60b7f

1、介绍

让我们回到一个更简单的时代,当所有人在机器学习中谈论的都是支持向量机和增强树(Boosted Tree),而Andrew Ng将神经网络引入了一个在实践中可能永远不会用到的巧妙的帽子戏法。

那一年是2012年,基于计算机视觉的比赛ImageNet将再次被最新的核心方法组合赢得。当然,直到几位研究人员公布了AlexNet,它的错误率几乎是竞争对手的一半,我们现在通常称之为 “深度学习”

许多人认为AlexNet是十年来最重要的科学突破之一,它无疑有助于改变ML研究的格局。然而,不需要太多的时间就可以认识到,在幕后,它只是先前迭代改进的组合,许多改进可以追溯到90年代初。AlexNet的核心“仅仅”是一个改进的LeNet3,它具有更多的层、更好的权重初始化、激活函数和数据增强功能。

2、ML研究中的工具

那么是什么让AlexNet如此引人注目呢?我相信答案在于研究人员掌握的工具,使他们能够在GPU加速器上运行人工神经网络,这在当时是一个相对新颖的想法。事实上,Alex Krizhevsky的前同事回忆说,比赛前的许多会议都是由Alex描述他在CUDA怪癖和特写方面的进步组成的。

现在让我们回到2015,当ML研究文章提交开始全面爆发,包括(重新)出现许多现在有希望的方法,如生成对抗学习,深度强化学习,元学习,自我监督学习,联合学习,神经架构搜索,神经微分方程,神经图网络,等等。

有人可能会说,这只是人工智能炒作的自然结果。然而,我相信一个重要的因素是第二代通用ML框架的出现,比如TensorFlow和PyTorch,以及NVIDIA在 AI上的全面应用。以前存在的框架,如Caffe和Theano,很难使用,也很难扩展,这就减缓了新思想的研究和开发。

3、革命的必要性

TensorFlow和PyTorch无疑是一个积极的网络,团队努力改进库。最近,他们交付了TensorFlow 2.0,它有一个更直接的接口和eager 模式,Pythorch 1.0 提供了计算图的JIT编译以及对基于XLA的加速器(如TPU)的支持。然而,这些框架也开始达到它们的极限,迫使研究人员进入一些途径,同时关闭其他途径,就像他们的前辈一样。

AlphaStar和OpenAI等备受瞩目的DRL项目不仅利用了大规模的计算集群,还通过结合深度变换器、嵌套的递归网络、深度残差塔等,突破了深度学习体系结构组件的局限性。

Demis Hassabis 在接受泰晤士报采访时表示,DeepMind将专注于直接应用人工智能实现科学突破。我们已经可以通过他们最近发表的一些关于神经科学和蛋白质折叠的自然文章,看到这一方向的转变。即使对这些出版物作一个简短的浏览,也足以看出这些项目在工程方面需要一些非常规的方法。

在NeurIPS 2019,概率规划和贝叶斯推理是热门话题,尤其是不确定性估计和因果推理。领先的人工智能研究人员展示了他们对人工智能未来的展望。值得注意的是,Yoshua Bengio描述了向 system 2 深度学习的过渡,包括分布外泛化、稀疏图网络和因果推理。

总而言之,下一代ML工具的一些要求是:

细粒度控制流使用
非标准优化循环
作为一等公民的高阶微分
一等公民的概率规划
在一个模型中支持多个异构加速器
从单机到大型集群的无缝可扩展性

理想情况下,这些工具还应该维护一个干净、直接和可扩展的API,使科学家能够快速研究和开发他们的想法。

4、下一代工具

好消息是,今天已经有许多候选工具,是为了响应科学计算的需要而出现的。从实验项目如zygate.jl到专门语言,如halide1和DiffTaichi。有趣的是,许多项目从自动微分社区的研究人员所做的基础工作中获得灵感,这些工作与ML并行发展。

他们中的许多人在最近的NeurIPS 2019项目转型研讨会上亮相。我最兴奋的两个是 S4TFJAX 。他们都致力于把可微编程变成工具链的一个组成部分,但他们都以自己的方式,几乎是正交的。

4.1、S4TF

顾名思义,S4TF将TensorFlow ML框架与Swift编程语言紧密集成。对该项目的信任投票是由Chris Lattner领导的,他曾撰写LLVM、Clang和Swift。

Swift是一种编译编程语言,它的主要卖点之一是强大的静态类型系统。最后一部分简单地说,SWIFT包含了在Python中使用代码的验证和转换,例如在C++中的易用性。

let a: Int = 1
let b = 2
let c = "3"

print(a + b)         // 3
print(b + c)         // compilation (!) error
print(String(b) + c) // 23

Swift特性使S4TF团队能够通过在编译过程中对使用有效算法执行的计算图进行分析、验证和优化,来满足下一代列表中的许多要求。

最重要的是,自动微分的处理被卸载到编译器中。

这个项目是一项重大的工程,在投产前还有一些路要走。然而,这是一个让工程师和研究人员都去尝试的好时机,并且有可能对它的发展做出贡献。关于S4TF的工作已经在编程语言和自动微分理论的交叉点上产生了有趣的科学进展。

关于S4TF,有一件事对我来说特别突出,那就是他们的社区推广方式。例如,核心开发人员每周都会举行设计会议,任何有兴趣参加甚至参与的人都可以参加。

当然,TensorFlow本身在这个例子中得到了很好的支持。

import TensorFlow

struct Model: Layer 
    var conv = Conv2D<Float>(filterShape: (5, 5, 6, 16), activation: relu)
    var pool = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
    var flatten = Flatten<Float>()
    var dense = Dense<Float>(inputSize: 16 * 5 * 5, outputSize: 100, activation: relu)
    var logits = Dense<Float>(inputSize: 100, outputSize: 10, activation: identity)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> 
        return input.sequenced(through: conv, pool, flatten, dense, logits)
    


var model = Model()
let optimizer = RMSProp(for: model, learningRate: 3e-4, decay: 1e-6)

for batch in CIFAR10().trainDataset.batched(128) 
  let (loss, gradients) = valueWithGradient(at: model)  model in
    softmaxCrossEntropy(logits: model(batch.data), labels: batch.label)
  
  print(loss)
  optimizer.update(&model, along: gradients)

另一方面,如果一个关键特性被证明难以实现,那么对整个管道有深入的了解就特别有价值。例如,MLIR编译器框架是S4TF工作的直接结果。

虽然可微编程是核心目标,但S4TF远不止支持各种下一代ML工具(如调试器)的基础设施。例如,假设一个IDE警告用户自定义模型计算总是导致零梯度,甚至不执行它。

Python有一个围绕科学计算构建的令人难以置信的社区,S4TF团队已经明确地花了时间通过互操作性来接受它。

import Python // All that is necessary to enable the interop.

let np = Python.import("numpy") // Can import any Python module.
let plt = Python.import("matplotlib.pyplot") 

let x = np.arange(0, 10, 0.01)
plt.plot(x, np.sin(x)) // Can use the modules as if inside Python.
plt.show() // Will show the sin plot, just as you would expect.

这个项目是一项重大的工程,在投产前还有一些路要走。然而,这是一个让工程师和研究人员都去尝试的好时机,并且有可能对它的发展做出贡献。关于S4TF的工作已经在编程语言和自动微分理论的交叉点上产生了有趣的科学进展。

关于S4TF,有一件事对我来说特别突出,那就是他们的社区推广方式。例如,核心开发人员每周都会举行设计会议,任何有兴趣参加甚至参与的人都可以参加。

要了解更多有关Swift for TensorFlow的信息,以下是一些有用的资源:

Fast.ai’s Lessons 13 and 14
Design Doc: Why Swift For TensorFlow?
Model Training Tutorial
Pre-Built Google Colab Notebook
Swift Models for Popular Architectures in DL and DRL

4.2、 JAX

JAX是一个函数转换的集合,例如实时编译和自动微分,它是用一个API在XLA上实现的瘦包装器,API本质上是NumPy和SciPy的替代品。事实上,开始使用JAX的一种方法是将其视为一个加速器支持的NumPy。

import jax.numpy as np

# Will be seamlessly executed on an accelerator such as GPU/TPU.
x, w, b = np.ones((3, 1000, 1000))
y = np.dot(w, x) + b

当然,实际上,JAX远不止这些。对许多人来说,这个项目似乎是凭空出现的,但事实是,它是跨越三个项目的五年多研究的演变。值得注意的是,JAX是从Autograd(一种对本地程序代码的AD的研究)发展而来的,它概括了支持任意转换的核心思想

def f(x):
  return np.where(x > 0, x, x / (1 + np.exp(-x)))

# Note: same singular style for the API entry points.
jit_f = jax.jit(f) # Will be 10-100x faster, depending on the accelerator.
grad_f = jax.grad(f) # Will work as expected, handling both branches. 

除了上面讨论的gradjit之外,还有两个更优秀的JAX转换示例,帮助用户通过批处理维度的自动矢量化(vmap)或跨多个设备(pmap)批处理数据。

a = np.ones((100, 300))

def g(vec):
  return np.dot(a, vec)

# Suppose `z` is a batch of 10 samples of 1 x 300 vectors.
z = np.ones((10, 300))

g(z) # Will not work due to (batch) dimension mismatch (100x300 x 10x300).

vec_g = jax.vmap(g)
vec_g(z) # Will work, efficiently propagating through batch dimension.

# Manual solution requires "playing" with matrix transpositions.
np.dot(a, z.T)

这些特性一开始可能看起来令人困惑,但经过一些实践,它们变成了研究人员工具箱中不可替代的一部分。它们甚至激发了最近TensorFlow和PyTorch中类似功能的开发。

目前,JAX作者在开发新特性时似乎坚持自己的核心能力。当然,合理的方法也是其主要缺点之一:缺乏内置的神经网络组件,除了证明Stax的概念。

添加高级功能是终端用户可能参与和贡献的东西,并且赋予JAX的坚实基础,任务可能比看起来更容易。例如,现在有两个构建在JAX之上的“竞争”库,它们都是由Google研究人员开发的,使用不同的方法:TraxFlax

# Trax approach is functional.
# Note: params are stored outside and `forward` is "pure".

import jax.numpy as np
from trax.layers import base

class Linear(base.Layer):
  def __init__(self, num_units, init_fn):
    super().__init__()
    self.num_units = num_units
    self.init_fn = init_fn

  def forward(self, x, w):
    return np.dot(x, w)

  def new_weights(self, input_signature):
    w = self.init_fn((input_signature.shape, self._num_units))
    return w
# Flax approach is object-oriented, closer to PyTorch style.

import jax.numpy as np
from flax import nn

class Linear(nn.Module):
  def apply(self, x, num_units, init_fn):
    W = self.param('W', (x.shape[-1], num_units), init_fn)
    return np.dot(x, W)

尽管有些人可能更喜欢由核心开发人员认可的单一方式,但方法的多样性很好地表明这项技术是可靠的。

在JAX特性特别突出的领域,也有一些研究方向。例如,在元学习中,训练元学习器的一种常见方法是计算输入的梯度。为了有效地解决这一问题,需要一种计算梯度的替代方法——前向模式自动微分法,这在JAX中是现成的,但在其他库中要么是不存在的,要么是实验性的特性。

JAX可能比它的S4TF计数器部件更精良,更易于生产,Google研究的一些最新发展依赖于它,比如 Reformer——一种内存高效的转换器模型,能够处理一百万字的上下文窗口,同时安装在消费者的GPU上,和神经切线-一个无限宽复杂神经网络库。

该库还被更广泛的科学计算社区所接受,用于分子动力学、概率规划和约束优化等领域的工作。

要开始使用JAX并进一步阅读,请查看以下内容:

Talk: Overview by Skye Wanderman-Milne, a core developer (starts at 44:26)
Notebook: Quickstart, going over fundamental features
Notebook: Cloud TPU Playground
Blog: You don’t know JAX
Blog: Massively parallel MCMC with JAX
Blog: Differentiable Path Tracing on the GPU/TPU

5、结论

ML研究已经开始触及我们目前可以使用的工具的极限,但是一些新的、令人兴奋的候选者即将到来,比如JAX和S4TF。如果你觉得自己更像一个工程师,而不是一个研究人员,并想知道是否有一个地方值得攻入,答案是明确的:现在是攻入它的最佳时机。此外,您还有机会参与到下一代ML工具的底层!

请注意,这并不意味着TensorFlow或PyTorch会在不久的将来就马上寿终正寝。在这些成熟的、经过战斗测试的库中仍然有很多价值。毕竟,JAX和S4TF都有TensorFlow的一部分。但是,如果你即将开始一个新的研究项目,或者如果你觉得你是围绕着库的限制而不是你的想法工作,那么也许给他们一个尝试!

谷歌jax快速入门笔记详解和案例(代码片段)

一.什么是JAX?JAX最初由谷歌大脑团队的MattJohnson、RoyFrostig、DougalMaclaurin和ChrisLeary等人发起,借助Autograd的更新版本,并且结合了XLA,可对Python程序与NumPy运算执行自动微分,支持循环、分支、递归、闭包函数求导&#x... 查看详情

谷歌jax快速入门笔记详解和案例(代码片段)

一.什么是JAX?JAX最初由谷歌大脑团队的MattJohnson、RoyFrostig、DougalMaclaurin和ChrisLeary等人发起,借助Autograd的更新版本,并且结合了XLA,可对Python程序与NumPy运算执行自动微分,支持循环、分支、递归、闭包函数求导&#x... 查看详情

谷歌jax快速入门笔记详解和案例(代码片段)

一.什么是JAX?JAX最初由谷歌大脑团队的MattJohnson、RoyFrostig、DougalMaclaurin和ChrisLeary等人发起,借助Autograd的更新版本,并且结合了XLA,可对Python程序与NumPy运算执行自动微分,支持循环、分支、递归、闭包函数求导&#x... 查看详情

两个不同的深度学习框架如何使用相同的模型?(代码片段)

在使用张量流here的深度梦想示例中,代码引用了谷歌开发的inception5h模型。然而,谷歌here的原始代码是使用caffe,而不是tensorflow,可能是因为当时不存在张量流。两个不同的框架可以使用相同的模型怎么样?与bvlc_googlenet.caffemod... 查看详情

深度学习及pytorch基础

...结,最后列出没有学明白的问题。【任务二】代码练习在谷歌Colab上完成代码练习中的2.1、2.2、2.3、2.4节,关键步骤截图,并附一些自己的想法和解读。【任务三】进阶练习在谷歌Colab上完成猫狗大战的VGG模型的迁移学习,关键... 查看详情

tensorflow-谷歌深度学习库命令行参数(代码片段)

程序的入口:tf.app.runtf.app.run(main=None,argv=None)运行程序,可以提供‘main’函数以及函数参数列表。处理flag解析然后执行main函数。 什么是flag解析呢? 由于深度学习神经网络往往需要对各种Hyperparameter调优,比如学习率,... 查看详情

tensorflow败给pytorch,谷歌:未来就靠你了,jax

整理|彭慧中责编|屠敏出品|CSDN(ID:CSDNnews)谷歌是机器学习领域的开拓者,它于2015年发布开源深度学习框架TensorFlow,开创了现代机器学习的生态系统.TensorFlow一经发布便迅速被开发者热捧,谷歌也由此成... 查看详情

深度学习框架介绍(代码片段)

深度学习框架介绍1.常见深度学习框架对比2.TensorFlow的特点3.TensorFlow的安装4.Tenssorlfow使用技巧1.常见深度学习框架对比tensorflow的github:2.TensorFlow的特点官网:https://www.tensorflow.org/语言多样(LanguageOptions)TensorFlo 查看详情

谷歌算法研究员:我为什么钟爱pytorch?(代码片段)

老铁们好!我是一名前谷歌的算法研究员,处理深度学习相关项目已有三年经验,接下来会在平台上给大家分享一些深度学习,计算机视觉和统计机器学习的心得体会,当然了内推简历一定是收的。这篇文章,不想说太多学术的... 查看详情

基于深度强化学习的局内战斗自动化测试探索(代码片段)

游戏项目研发时,期望搭建自动化测试平台,发现局内bug,避免重复劳动、提高测试效率以及避免人为的操作错误。其中环境要求使用项目需要使用Airtest、poco对接强化学习的服务器,实现Airtest将状态信息发送给服务器,服务器... 查看详情

codeforbetter谷歌开发者之声——tensorflow与深度学习(代码片段)

目录一、TensorFlow简介二、机器学习与深度学习2.1什么是机器学习2.2 什么是深度学习2.3机器学习和深度学习应用2.4趋势三、TensorFlow实现递归神经网络一、TensorFlow简介  TensorFlow是由Google团队开发的深度学习框架之一,它是一... 查看详情

深度学习计算框架实现(代码片段)

参考与评述参考书目《DeepLearning》LanGoodfellow.经典的深度学习框架是以计算图&梯度下降方法实现对前馈网络的有监督学习。这里复现了前馈计算图的梯度计算实现。一、前馈计算图实现1.前向与梯度计算结果数组(保存输入节... 查看详情

tensorflow-谷歌深度学习库图片处理模块(代码片段)

Module:tf.image这篇文章主要介绍TensorFlow处理图片这一块,这个模块和之前说过的文件I/O处理一样也是主要从python导过来的。通过官方文档,我们了解到这个模块主要有一下这些个函数。Functionsadjust_brightness(...):AdjustthebrightnessofRGBorG... 查看详情

深度学习框架keras入门案例(代码片段)

...训练支持Keras的发展得到关键公司的支持,比如:谷歌、微软等详细信息见中文官网:https://keras.io/zh/why-use-keras/主要步骤使用Keras解决机器学习/深度学习问题的主要步骤:特征工程+数据划分搭建神经网络模型ad... 查看详情

《tensorflow实战google深度学习框架》pdf一套四本+源代码_高清_完整

...解析与实战》,不能错过的学习Tensorflow书籍。TensorFlow是谷歌2015年开源的主流深度学习框架,目前已在谷歌、优步(Uber)、京东、小米等科技公司广泛应用。《Tensorflow实战》为使用TensorFlow深度学习框架的入门参考书,旨在帮助... 查看详情

深度学习框架-tensorflow(代码片段)

...操作知道tf.keras中的相关模块及常用方法 1.1TensorFlow介绍深度学习框架TensorFlow一经发布,就受到了广泛的关注,并在计算机视觉、音频处理、推荐系统和自然语言处理等场景下都被大面积推广使用,现在已发布2.3.0版... 查看详情

学习《tensorflow实战google深度学习框架(第2版)》中文pdf和代码

TensorFlow是谷歌2015年开源的主流深度学习框架,目前已得到广泛应用。《TensorFlow:实战Google深度学习框架(第2版)》为TensorFlow入门参考书,帮助快速、有效的方式上手TensorFlow和深度学习。书中省略了烦琐的数学模型推导,从实... 查看详情

深度学习框架中的动态shape问题(代码片段)

前言最近这段时间实现了公司内部深度学习框架的dynamicshape的功能。在完成这个功能之后,我进一步深刻体会到了乔布斯说过的一段话:Whenyoustartlookingataproblemanditseemsreallysimple,youdon’treallyunderstandthecomplexityoftheproblem.Thenyo... 查看详情