研究|AI研究的提速器! DeepMind力荐的JAX到底有多强大?

From: DeepMind 编译: T.R
DeepMind 做出过很多世界级的研究成果 , 这离不开一套高效完备的工具链和系统 。 工程师们致力于构建工具、可规模化的算法 , 创造充满挑战的虚拟与物理世界 , 以此来训练和测试人工智能系统;在开发与研究的同时 , 也不断测评新的机器学习工具包与开源框架 , 不断提高生产力 。
最近 , 随着越来越多项目的使用 , 由谷歌研究开发的 JAX 工具包逐渐得到研究和工程领域人员的喜爱 。 JAX 工具包具有良好的工程哲学和便捷的使用方法 , DeepMind 的研究人员在充分地使用后整理出了详细的分析 , 并汇总了JAX的整个生态系统介绍 , 为想要入坑的小伙伴们提供了详细的参考 。
JAX 厉害在哪儿?
JAX 是一个用于高性能数值计算的 Python 库 , 特别为机器学习领域的高性能计算设计 。 它的 API 基于 Numpy 构建 , 包含丰富的数值计算与科学计算函数 。

研究|AI研究的提速器! DeepMind力荐的JAX到底有多强大?
文章图片

Python 和 Numpy 的广泛使用 , 使得 JAX 十分简洁、灵活、易于上手 , 学习成本也比较低 。 除了 Numpy 的 API 外 , JAX 还包含一系列可拓展、可组合的系统功能 , 有力地支持了机器学习研究 。 这些功能特性主要包括:
可差分:基于梯度的优化方法在机器学习领域具有十分重要的作用 。 JAX 可通过grad、hessian、jacfwd 和 jacrev 等函数转换 , 原生支持任意数值函数的前向和反向模式的自动微分 。
向量化:在机器学习中 , 通常需要在大规模的数据上运行相同的函数 , 例如计算整个批次的损失或每个样本的损失等 。 JAX 通过 vmap 变换提供了自动矢量化算法 , 大大简化了这种类型的计算 , 这使得研究人员在处理新算法时无需再去处理批量化的问题 。 JAX 同时还可以通过 pmap 转换支持大规模的数据并行 , 从而优雅地将单个处理器无法处理的大数据进行处理 。
JIT编译:XLA (Accelerated Linear Algebra, 加速线性代数) 被用于 JIT 即时编译 , 在 GPU 和云 TPU 加速器上执行 JAX 程序 。 JIT 编译与 JAX 的 API (与 Numpy 一致的数据函数) 为研发人员提供了便捷接入高性能计算的可能 , 无需特别的经验就能将计算运行在多个加速器上 。
DeepMind 是如何使用 JAX 的?
为了支持前沿的人工智能研究 , 需要有效平衡好快速的原型验证 , 以及在生产系统上进行规模化部署与迭代验证的能力 。 这些项目的挑战在于:研究领域正在发生着的变化迅速而深远 , 并且难以预测 。 突破可能发生在任何地方 , 并且常常会改变整个研究团队的需求和轨迹 。 在瞬息万变的环境中 , 工程团队的职责是总结每一个项目中学习到的经验和代码 , 使其可以在未来被有效复用 。
模块化是一种非常成功的方法 , 它将研究中最为重要和关键的构建模块抽取出来 , 成为通过良好测试的基本构造模块 。 模块化使得研究人员可以聚焦于研究本身 , 并从其他可复用的代码中受益 , 使研究避免bug的干扰、得到更多的性能提升 。 工程团队还发现 , 最重要的是保证每个库都有明确的定义范围 , 并确保它们之间可以相互调用但保持独立 。 此外 , 还需要具有增量复用的能力 , 具有可选择而不被其他功能锁定的能力 。 这些要素至关重要 , 有助于为研究人员提供最大程度的灵活性与选择性 。
另外 , 在开发 JAX 的生态过程中 , 还需要保证与已有计算框架 (例如 Tensorflow、Sonnet、TRFL 等) 的连续性与一致性 , 需要在构建过程中尽量接近其基础数学原理 , 实现完善的自描述 , 并避免从纸面到代码的思维跳跃 。
最后 , 开源是保障研究成果共享和促进研究的重要途径 。 越多的人参与探索 JAX 生态系统 , 就越会促进 JAX 更为迅速的发展壮大 。
DeepMing 的 JAX 生态系统
多年来 , 许多研究基于 JAX 构建了涉及多个领域的工具与开发库 , 为研发人员提供了诸多轮子、形成了丰富的生态系统 。 下面让我们一起来看看 DeepMind 常用的工具都有哪些 。
1. Haiku: 基于 JAX 的编程模型对处理神经网络这类含有可训练参数的对象会很复杂 。 为此 , 研究人员开发了神经网络库 Haiku , 让用户可以使用熟悉的面向对象编程模型 , 同时得以利用 JAX 强大并简洁的纯函数范式 。
Haiku有诸多活跃用户 , 在谷歌和 DeepMind 就得到了数百位研究人员的使用 , 并已经被多个外部项目 (Coax、DeepChem、NumPyro) 采用 。 它基于 DeepMind 神经网络工具 Sonnet 的 API 构建 , 从 Sonnet 到 Haiku 的迁移成本会逐渐降低 。

研究|AI研究的提速器! DeepMind力荐的JAX到底有多强大?
文章图片

详情请参看下方链接:
https://github.com/deepmind/dm-haiku
2. Optax: 基于梯度的优化方法是机器学习的基础 。 Optax 提供了包含梯度转换与合成运算符 (例如链) 的工具库 , 这使得仅仅一行代码就可以实现许多标准的优化器 (例如RMSProp或Adam) 。 Optax 天然支持用户在定制优化器中对基本部分进行重组 , 并且还提供了许多用于随机梯度估计和二阶优化的工具集 。

研究|AI研究的提速器! DeepMind力荐的JAX到底有多强大?
文章图片

详情请参看下方链接:
https://github.com/deepmind/optax
3. RLax:许多成功的项目都位于深度学习与强化学习的交叉学科前沿 。 RLax 库为构建 RL 主体提供了有效的基础模块 。 RLax 中的组件涵盖了丰富的算法和思想 , 包括:TD学习法、策略梯度、参与者评定法 (actor critics)、MAP、近端策略优化、非线性价值转换以及通用价值函数和多种学习探索方法 。

研究|AI研究的提速器! DeepMind力荐的JAX到底有多强大?
文章图片

尽管 RLax 提供了一些介绍性的构建代理的例子 , 但它的主要目标并不是用作构建和部署完整 RL 系统的框架 。 另一个工具包 Acme 是基于 RLax 组件构建的框架 , 如果需要框架性的支持可以使用 Acem 来完成 。
详情请参看下方链接:
Rlax:https://github.com/deepmind/rlax
Acem:https://deepmind.com/research/publications/Acme
4. Chex:测试对于验证软件的可靠性至关重要 , 对于验证代码的正确性也同样重要 。 代码的正确性将提高研究结论的可信度 。 为了对代码进行有效测试 , 测试工具集 Chex 为工具库作者提供了验证通用构件是否正确且健壮的有效手段 , 终端用户也可以使用它们来检查实验代码是否可靠 。

研究|AI研究的提速器! DeepMind力荐的JAX到底有多强大?
文章图片

Chex 提供了各种实用的工具 , 包括可识别 JAX 的单元测试、JAX 数据类型的属性声明校验、模拟和伪造多种设备的测试环境 。 Chex 已被用于 DeepMind 的 JAX 生态系统以及多个外部项目 (例如 Coax 和 MineRL ) 。
详情请参看下方链接:
https://github.com/deepmind/chex
5. Jraph:图神经网络 (GNN) 是近年来飞速发展的研究领域 , 在交通流量预测和物流模拟等众多领域具有非常强大的应用潜力 。 Jraph (与“长颈鹿”一个发音) 是一个轻量级的库 , 支持在 JAX 中构建和使用 GNN。

研究|AI研究的提速器! DeepMind力荐的JAX到底有多强大?
文章图片

Jraph 提供了用于图数据结构的标准化数据结构 , 以及一系列用于处理图的工具库和构建可扩展 GNNs 的神经网络库 。 此外还包括下面一些特性:基于Graph Tuples 的批处理可以有效利用硬件加速器的性能;通过填充、掩膜以及在输入分区上定义损失来实现对于可变形图卷积的 JIT 支持 。 与 Optax 及其他 JAX 库一样 , Jraph 对用户选择其他神经网络库没有的任何限制 。
详情请参看下方链接:
https://github.com/deepmind/jraph
JAX 生态系统一直在持续演进 , 越来越多的研究与开发者开始使用 JAX 生态中的工具库来加速研究与开发 , 相信你也能从中收获到更强大的机器学习性能!
更多细节请参看下面的链接:
https://deepmind.com/blog/article/using-jax-to-accelerate-our-research
【研究|AI研究的提速器! DeepMind力荐的JAX到底有多强大?】https://github.com/google/jax#neural-network-libraries

    推荐阅读