谷歌jax快速入门笔记详解和案例(代码片段)

cui_yonghua cui_yonghua     2023-03-28     378

关键词:

一. 什么是JAX?

JAX最初由谷歌大脑团队的 Matt Johnson、Roy Frostig、Dougal Maclaurin 和 Chris Leary 等人发起,借助 Autograd 的更新版本,并且结合了 XLA,可对 Python 程序与 NumPy 运算执行自动微分,支持循环、分支、递归、闭包函数求导,也可以求三阶导数;依赖于 XLA,JAX 可以在 GPU 和 TPU 上编译和运行 NumPy 程序;通过 grad,可以支持自动模式反向传播和正向传播,且二者可以任意组合成任何顺序。

JAX并非是一个深度学习的框架或者库,它的设计目标也并非是作为一个新的深度学习框架。

简单来说,JAX是一个包含可组合函数变换的数值计算库,只不过深度学习恰好是JAX能做的一项工作。

JAX处于函数变换(function transformations)和科学计算的交界处,所以也有能力训练神经网络模型,但不止于训练。

目前JAX在Github上已经斩获了超2万多颗star:

github地址:https://github.com/google/jax(截至目前,star数:20.3k)

官方文档:https://jax.readthedocs.io/en/latest/

JAX 是一个非常有前途的项目,并且用户一直在稳步增长。JAX 已经在深度学习、机器人 / 控制系统、贝叶斯方法和科学模拟等诸多领域得到了广泛应用。

二. 为什么应该使用JAX

JAX目前已经达到深度学习的最高水平。在当前开源的框架中,没有哪一个框架能在简洁、易用、速度这3个方面有两个能同时超过JAX。

  • 简洁:JAX的设计追求最少的封装,尽量避免重复造轮子。设计遵循tensor→variable(autograd)→module 3个由低到高的层次,分别代表高维数组(张量),自动求导(变量)和神经网络(层/模块),而且这3个抽象直接连接紧密,可以同时进行修改和操作。而tensorflow充斥着graph、operation、tensor、layer等全新的概念。JAX源码只有 tensorflow 的十分之一左右,更少的抽象、更直观的设计使得JAX的源码十分易于阅读。
  • 速度:在许多测评中,JAX的速度表现胜过TensorFlow和PyTorch等框架。
  • 易用:JAX是所有框架中面向对象设计的最优雅的一个,符合人们的思维,可以让用户专注于自己的想法,不需要考虑太多关于框架本身的束缚。

2.1 应该使用JAX的 6个原因:

  1. 加速NumPy。NumPy是用Python进行科学计算的基本软件包之一,但它只与CPU兼容。JAX提供了一个NumPy的实现(具有近乎相同的API),可以非常容易地在GPU和TPU上工作。对于许多用户来说,仅仅这一点就足以证明使用JAX的合理性。

  2. XLA,即加速线性代数(Accelerated Linear Algebra),是一个全程序优化编译器,专门为线性代数设计。JAX是建立在XLA之上的,大大提升了计算速度的上限。

  3. JIT。JAX允许用户使用XLA将函数转化为JIT(just in time)编译的版本。这意味用户可以通过给计算函数添加一个简单的函数装饰器来提高计算速度,可能是几个数量级的性能提升。

  4. 自动求导。JAX文档将JAX称为Autograd和XLA的结合体。自动求导的能力在科学计算的许多领域都是至关重要的,而JAX提供了几个强大的自动求导工具。

  5. 深度学习。虽然JAX本身不是一个深度学习框架,但它肯定为深度学习提供了一个更充分的基础。现在有许多建立在JAX之上的深度学习库,例如Flax、Haiku和Elegy。甚至有研究人员在PyTorch vs TensorFlow文章中强调JAX也是一个值得关注的「框架」,推荐其用于基于TPU的深度学习研究。JAX对Hessians的高效计算也与深度学习有关,因为它们使高阶优化技术更进一步。

  6. 通用可微分编程范式。虽然可以使用JAX来构建和训练深度学习模型,但它也为通用可微分编程提供了一个框架。这意味着JAX可以通过使用基于模型的机器学习方法来解决实际问题。

2.2 不该使用JAX的情况

虽然JAX有可能极大地提高你的程序的性能,但也有几种情况下,是不适合使用JAX的。

  1. JAX仍然被官方认为是一个实验性框架,而不是一个完全成熟的Google产品,所以如果你正在考虑转移到JAX,需要慎重考虑。
  2. 在使用JAX时,调试的时间成本会更高,并且有很多bug仍然未被发现。对于那些没有牢固掌握函数式编程的人来说,使用JAX可能不值得。在开始将JAX用于正式的项目之前,请确保了解使用JAX的常见陷阱。
  3. JAX没有针对CPU计算进行优化。鉴于JAX是以「加速优先」的方式开发的,因此每个操作的调度并没有完全优化。正因为如此,在某些情况下,NumPy实际上可能比JAX更快,特别是对于小程序来说。
  4. JAX与Windows不兼容。目前在Windows上没有对JAX的支持。

三. 安装和使用

python环境下安装:
pip3 install jax
pip3 install jaxlib

能成功打印如下代码即可成功安装。

import jax
print(jax.random.PRNGKey(17))  # [ 0 17]

注意:jax.numpy是CPU、GPU和TPU上的numpy,具有出色的自动差异化功能,可用于高性能机器学习研究。。由此特意进行了测试:

先试一下原生的numpy:

import numpy as np
import time

x = np.random.random([5000, 5000]).astype(np.float32)
try:
    st = time.time()
    y = np.matmul(x, x)
    print(time.time() - st)
    print(y)
except Exception as e:
    print(f"error: e")

执行结果如下:(耗时4.1秒)

再来试一下jax带的numpy:(耗时4.8秒)

import jax.numpy as np
from jax import random
import time

x = random.uniform(random.PRNGKey(0), [10000, 10000])
st = time.time()
try:
    y = np.matmul(x, x)
    print(time.time() - st)
    print(y)
except Exception as e:
    print(f"error: e")

执行结果如下:

实践是检验真理的唯一标准。(我用的MacBook Pro做的测试)

为什么JAX没有比numpy更快呢?

谷歌jax快速入门笔记详解和案例(代码片段)

一.什么是JAX?JAX最初由谷歌大脑团队的MattJohnson、RoyFrostig、DougalMaclaurin和ChrisLeary等人发起,借助Autograd的更新版本,并且结合了XLA,可对Python程序与NumPy运算执行自动微分,支持循环、分支、递归、闭包函数求导&#x... 查看详情

springaop快速入门详解(代码片段)

文章目录AOP1.AOP简介1.1AOP简介和作用1.2AOP中的核心概念3.1AOP工作流程2.AOP入门案例2.1AOP入门案例思路分析2.2AOP入门案例实现(一)导入aop相关坐标(二)定义dao接口与实现类(三)定义通知类,制作通知方法(四)定义切入点表达式、配... 查看详情

『jax中文文档』jax快速入门(代码片段)

...:https://jax.readthedocs.io/en/latest/notebooks/quickstart.htmlJAX快速入门首先解答一个问题:JAX是什么?简单的说就是GPU加速、支持自动微分(autodiff)的numpy。众所周知,numpy是Python下的基础数值运算库,得到广泛应用。用P... 查看详情

spring事务快速入门详解(代码片段)

文章目录Spring事务简介什么是事务?事务特征(ACID)Spring事务的作用Spring事务核心接口Spring事务角色Spring事务实战分析需求和分析代码实现测试环境准备使用Spring事务管理修改案例在业务层接口上添加Spring事务管理... 查看详情

html入门学习笔记+详细案例(代码片段)

✨HTML入门学习笔记+详细案例作者介绍:🎓作者:偷偷敲代码的青花瓷🐱‍🚀👀作者的Gitee:代码仓库✨✨我和大家一样都是热爱编程✨,很高兴能在此和大家分享知识,希望在分享知识的同时,能和大家一起共同进... 查看详情

flume快速入门及常用案例整理(代码片段)

flume快速入门及常用案例整理flume概述1.1flume定义flume是Cloudera提供的一个高可用的,高可靠的,分布式的海量日志采集、聚合和传输的系统,flume基于流式架构,灵活简单Flume最主要的作用就是,实时读取服务器... 查看详情

casbin访问控制框架入门详解及java案例示范(代码片段)

1、Casbin基本介绍Casbin是一个强大的、高效的开源访问控制框架,其权限管理机制支持多种访问控制模型。Casbin可以:支持自定义请求的格式,Listitem默认的请求格式为subject,object,action。2.具有访问控制模型model和策略po... 查看详情

casbin访问控制框架入门详解及java案例示范(代码片段)

1、Casbin基本介绍Casbin是一个强大的、高效的开源访问控制框架,其权限管理机制支持多种访问控制模型。Casbin可以:支持自定义请求的格式,Listitem默认的请求格式为subject,object,action。2.具有访问控制模型model和策略po... 查看详情

[python图像识别]四十五.目标检测入门普及和imageai“傻瓜式”对象检测案例详解(代码片段)

...章是讲解PythonOpenCV图像处理知识,前期主要讲解图像入门、OpenCV基础用法,中期讲解图像处理的各种算法,包括图像锐化算子、图像增强技术、图像分割等,后期结合深度学习研究图像识别、图像分类应用。希望... 查看详情

(超详解)springboot初级部分-快速入门-02(代码片段)

文章目录SpringBoot-快速入门-021.需求2.实现步骤2.1创建Maven项目2.2导入SpringBoot起步依赖2.3定义Controller2.4编写引导类2.5启动测试3.SpringInitializr创建SpringBoot工程4.SpringBoot起步依赖原理分析5.总结SpringBoot-快速入门-02该文章参考:黑... 查看详情

sparksql入门指南《读书笔记》(代码片段)

文章目录sparkSQL入门指南第一章初识sparkmysql1.1Spark的诞生和SparkSQL是什么?1.2 SparkSQL能做什么?第2章 Spark安装、编程环境搭建以及打包提交运行spark案例:运行pyspark案例其他案例第3章 Spark上的RDD(ResilientDistrib... 查看详情

springboot2整合nacos组件,环境搭建和入门案例详解(代码片段)

...理微服务。Nacos提供一组简单易用的特性集,帮助开发者快速实现动态服务发现、服务配置、服务元数据及流量管理。敏捷构建、交付和管理微服务平台。2、关键特性动态配置服务服务发现和服务健康监测动态DNS服务服务及其元... 查看详情

前端如何开始深度学习,那不妨试试jax(代码片段)

...个框架之外,一些新生的框架也不容小觑,比如谷歌推出的JAX深度学习框架。1.1、快速发展的JAXJAX是一个用于高性能数值计算的Python库,专门为深度学习领域的高性能计算而设计。自2018年底谷歌的JAX出现以来,它... 查看详情

jdbc-00-笔记(代码片段)

JDBC-系列-笔记一、JDBC快速入门1.jdbc的概念2.jdbc的本质3.jdbc的快速入门程序二、JDBC各个功能类详解1.DriverManager2.Connection3.Statement4.ResultSet三、JDBC的学习内容1,JDBC增删改查案例演示2,JDBC工具类抽取3,SQL防注入攻击࿰... 查看详情

分享《微信小程序开发图解案例教程》pdf及代码+《微信小程序开发入门及案例详解》pdf及代码

...单易懂,提供了丰富详尽的实战案例,带读者边做边学,快速掌握微信小程序的设计和实现。 《微信小程序开发入门及案例详解》PDF,514页;配套源代码。主要介绍微信小程序的开发流程,有部分内容是来自微信小程序开发... 查看详情

mybatis框架理解与快速入门详解(代码片段)

...;1.5原始jdbc操作的分析1.6Mybatis是什么操作?1.7Mybatis的快速入门1.7.1环境搭建1.7.2编写测试代码1.8知识小 查看详情

springcloud——ribbon的学习笔记以及入门案例(代码片段)

什么是RibbonRibbon是一个基于HTTP和TCP的客服端负载均衡工具,它是基于NetflixRibbon实现的,它几乎存在每一个SpringCloud微服务之中,包括Feign提供的声明式服务调用也是基于Ribbon实现的其默认提供很多负载均衡算法,... 查看详情

mobx知识点集合案例(快速入门)(代码片段)

文章目录一、文章参考二、知识点串联2.1使用注解方式串联2.2使用函数一、文章参考bilibiliReact系列教程之Mobxmobx中文文档二、知识点串联2.1使用注解方式串联importReact,Componentfrom'react';importReactDOMfrom'react-dom';import'./i... 查看详情