JAX (software)
{{Short description|Machine learning framework designed for parallelization and automatic differentiation}}
{{Overly detailed|date=February 2025}}{{Infobox software
| title = JAX
| name =
| logo = Google JAX logo.svg
| logo caption = JAX logo
| logo alt =
| logo size =
| collapsible =
| screenshot = Google JAX screenshot.png
| screenshot size =
| screenshot alt =
| caption =
| author =
| developer = Google, Nvidia{{cite web |url=https://github.com/jax-ml/jax/blob/main/AUTHORS |title=jax/AUTHORS at main · jax-ml/jax |date= |author= |website=GitHub |accessdate= December 21, 2024}}
| released =
| latest release version =
| latest release date =
| latest preview version = v0.4.31
| latest preview date = {{Start date and age|2024|07|30|df=yes}}
| repo = {{GitHub|jax-ml/jax}}
| programming language = Python, C++
| middleware =
| operating system = Linux, macOS, Windows
| size = 9.0 MB
| language count =
| language footnote =
| genre = Machine learning
| license = Apache 2.0
| website =
}}
JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. It is developed by Google with contributions from Nvidia and other community contributors.{{Citation |title=JAX: Autograd and XLA |date=2022-06-18 |url=https://github.com/google/jax |archive-url=https://web.archive.org/web/20220618205214/https://github.com/google/jax |publisher=Google |bibcode=2021ascl.soft11002B |access-date=2022-06-18 |archive-date=2022-06-18 |last1=Bradbury |first1=James |last2=Frostig |first2=Roy |last3=Hawkins |first3=Peter |last4=Johnson |first4=Matthew James |last5=Leary |first5=Chris |last6=MacLaurin |first6=Dougal |last7=Necula |first7=George |last8=Paszke |first8=Adam |last9=Vanderplas |first9=Jake |last10=Wanderman-Milne |first10=Skye |last11=Zhang |first11=Qiao |journal=Astrophysics Source Code Library }}{{Cite journal |last1=Frostig |first1=Roy |last2=Johnson |first2=Matthew James |last3=Leary |first3=Chris |date=2018-02-02 |year=2018 |title=Compiling machine learning programs via high-level tracing |url=https://mlsys.org/Conferences/doc/2018/146.pdf |url-status=live |journal=MLsys |pages=1–3 |archive-url=https://web.archive.org/web/20220621153349/https://mlsys.org/Conferences/doc/2018/146.pdf |archive-date=2022-06-21}}{{Cite web |title=Using JAX to accelerate our research |url=https://www.deepmind.com/blog/using-jax-to-accelerate-our-research |url-status=live |archive-url=https://web.archive.org/web/20220618205746/https://www.deepmind.com/blog/using-jax-to-accelerate-our-research |archive-date=2022-06-18 |access-date=2022-06-18 |website=www.deepmind.com |language=en}}
It is described as bringing together a modified version of [https://github.com/HIPS/autograd autograd] (automatic differentiation of functions to obtain their gradient functions) and OpenXLA's XLA (Accelerated Linear Algebra). It is designed to follow the structure and workflow of NumPy as closely as possible and works with various existing frameworks such as TensorFlow and PyTorch.{{Cite web |last=Lynley |first=Matthew |title=Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta |url=https://www.businessinsider.com/facebook-pytorch-beat-google-tensorflow-jax-meta-ai-2022-6 |archive-url=https://web.archive.org/web/20220621143905/https://www.businessinsider.com/facebook-pytorch-beat-google-tensorflow-jax-meta-ai-2022-6 |archive-date=2022-06-21 |access-date=2022-06-21 |website=Business Insider |language=en-US}}{{Cite web |date=2022-04-25 |title=Why is Google's JAX so popular? |url=https://analyticsindiamag.com/why-is-googles-jax-so-popular/ |url-status=live |archive-url=https://web.archive.org/web/20220618210503/https://analyticsindiamag.com/why-is-googles-jax-so-popular/ |archive-date=2022-06-18 |access-date=2022-06-18 |website=Analytics India Magazine |language=en-US}} The primary features of JAX are:{{cite web | url=https://docs.jax.dev/en/latest/quickstart.html | title=Quickstart — JAX documentation }}
- Providing a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.
- Built-in Just-In-Time (JIT) compilation via OpenXLA, an open-source machine learning compiler ecosystem.
- Efficient evaluation of gradients via its automatic differentiation transformations.
- Automatic vectorization of functions to efficiently map them over arrays representing batches of inputs.
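These transformations can be composed with one another; for instance, the gradient of a function can be vectorized over a batch of inputs and then just-in-time compiled. A minimal sketch (the function and values here are illustrative only):
<syntaxhighlight lang="python">
from jax import grad, jit, vmap
import jax.numpy as jnp

# derivative of tanh at a point, vectorized over a batch, then compiled
dtanh = jit(vmap(grad(jnp.tanh)))

print(dtanh(jnp.array([0.0, 0.5, 1.0])))  # approximately [1.  0.786  0.42]
</syntaxhighlight>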
== grad ==
{{Main|Automatic differentiation}}
The code below demonstrates automatic differentiation with the {{code|grad}} function.
<syntaxhighlight lang="python">
# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)
</syntaxhighlight>
The final line should output:
<syntaxhighlight lang="python">
0.19661194
</syntaxhighlight>
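As a check, the derivative of the logistic function <math>\sigma(x)</math> satisfies <math>\sigma'(x) = \sigma(x)\,(1 - \sigma(x))</math>, so at <math>x = 1</math> the gradient is <math>\sigma(1)\,(1 - \sigma(1)) \approx 0.7311 \times 0.2689 \approx 0.1966</math>, matching the printed value.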
== jit ==
The code below demonstrates how the {{code|jit}} function optimizes computation through operation fusion.
<syntaxhighlight lang="python">
# imports
from jax import jit
import jax.numpy as jnp

# define the cube function
def cube(x):
    return x * x * x

# generate data
x = jnp.ones((10000, 10000))

# create the jit version of the cube function
jit_cube = jit(cube)

# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)
</syntaxhighlight>
The computation time for {{code|jit_cube}} should be noticeably shorter than that for {{code|cube}}. Increasing the size of the array created for {{code|x}} will further widen the difference.
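Because JAX dispatches work asynchronously and the first call to a jitted function includes compilation time, a fair timing comparison warms up the compiled function and blocks until results are ready. A minimal sketch, reusing the definitions above:
<syntaxhighlight lang="python">
import timeit

# warm-up: the first call to jit_cube triggers tracing and XLA compilation
jit_cube(x).block_until_ready()

# time both versions, blocking until the asynchronous results are ready
t_plain = timeit.timeit(lambda: cube(x).block_until_ready(), number=10)
t_jit = timeit.timeit(lambda: jit_cube(x).block_until_ready(), number=10)
print(t_plain, t_jit)
</syntaxhighlight>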
== vmap ==
{{Main|Array programming}}
The code below demonstrates vectorization of a per-example gradient computation with the {{code|vmap}} function.
<syntaxhighlight lang="python">
# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp

# a method that computes per-example gradients for a batch of inputs
# (self._net_grads, self._net_params and self._flatten_batch are assumed to be
# defined by the enclosing class)
def grads(self, inputs):
    in_grad_partial = partial(self._net_grads, self._net_params)
    grad_vmap = vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = jnp.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads
</syntaxhighlight>
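Since the function above relies on attributes of an enclosing class, a self-contained sketch of {{code|vmap}} on a hypothetical scalar function may be clearer:
<syntaxhighlight lang="python">
from jax import vmap
import jax.numpy as jnp

# a function written for a single pair of scalars
def weighted_sum(x, y):
    return 0.5 * x + 2.0 * y

# vmap maps it over the leading axis of both inputs, producing a batched version
batched_weighted_sum = vmap(weighted_sum)

xs = jnp.arange(4.0)   # [0., 1., 2., 3.]
ys = jnp.ones(4)       # [1., 1., 1., 1.]
print(batched_weighted_sum(xs, ys))  # [2.  2.5 3.  3.5]
</syntaxhighlight>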
== pmap ==
The code below demonstrates parallelization of matrix multiplication across devices with the {{code|pmap}} function.
<syntaxhighlight lang="python">
# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)
</syntaxhighlight>
The final line should print the values:
<syntaxhighlight lang="python">
[1.1566595 1.1805978]
</syntaxhighlight>
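{{code|pmap}} maps a function over the leading axis of its inputs with one device per element, so the example above assumes at least two available devices. The available devices can be inspected before mapping; a minimal sketch:
<syntaxhighlight lang="python">
import jax

# the mapped axis size must not exceed the number of devices
print(jax.device_count())  # number of available devices
print(jax.devices())       # list of device objects

# on a CPU-only machine, multiple host devices can be simulated by setting, e.g.,
# XLA_FLAGS="--xla_force_host_platform_device_count=2" before importing jax
</syntaxhighlight>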
== External links ==
- Documentation: {{URL|https://jax.readthedocs.io/}}
- Colab (Jupyter/iPython) Quickstart Guide: {{URL|https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/quickstart.ipynb}}
- TensorFlow's XLA: {{URL|https://www.tensorflow.org/xla}} (Accelerated Linear Algebra)
- YouTube TensorFlow Channel "Intro to JAX: Accelerating Machine Learning research": {{URL|https://www.youtube.com/watch?v=WdTeDXsOSj4}}
- Original paper: {{URL|https://mlsys.org/Conferences/doc/2018/146.pdf}}
== References ==
{{reflist}}
{{differentiable computing}}
[[Category:Articles with example Python (programming language) code]]