Created: by Pradeep Gowda . Updated: Mar 01, 2023 Tagged: deep-learning · python

JAX is a Python library that augments NumPy and Python code with composable function transformations, which make operations common in machine learning programs easy to express. Concretely, you can write standard Python/NumPy code and immediately be able to:

- Compute the derivative of a function via a successor to autograd
- Just-in-time compile a function to run efficiently on an accelerator via XLA
- Automagically vectorize a function, so that e.g. you can process a “batch” of data in parallel.

learn more … “You don’t know Jax” (2019)
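The three transformations listed above can be sketched roughly as follows. This is a minimal illustration, assuming `jax` is installed; the function `f` and the arrays `x` and `batch` are hypothetical names made up for this example, not taken from the JAX docs.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function: sum of squares, written NumPy-style.
    return jnp.sum(x ** 2)


# 1. Differentiate: jax.grad(f) returns a new function computing df/dx.
df = jax.grad(f)

# 2. JIT-compile: jax.jit(f) traces f and compiles it with XLA.
fast_f = jax.jit(f)

# 3. Vectorize: jax.vmap(f) applies f along the leading (batch) axis.
batched_f = jax.vmap(f)

x = jnp.array([1.0, 2.0, 3.0])
batch = jnp.stack([x, 2.0 * x])  # shape (2, 3): two example inputs

grad_x = df(x)              # gradient of sum(x^2) is 2 * x
value = fast_f(x)           # same result as f(x), via the compiled path
per_row = batched_f(batch)  # one scalar result per row of the batch
```

Because the transformations are ordinary higher-order functions, they compose, e.g. `jax.jit(jax.vmap(jax.grad(f)))` would give a compiled per-example gradient.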

JAX’s website - google/jax: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more