JAX

E95194

JAX is a high-performance numerical computing library for Python that combines NumPy-like APIs with automatic differentiation and just-in-time compilation, widely used for machine learning and scientific computing.
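The description names three ingredients: a NumPy-like array API (jax.numpy), automatic differentiation (jax.grad) and just-in-time compilation (jax.jit). The following is a minimal sketch combining them on a toy least-squares loss. The loss function, data and variable names are illustrative assumptions and are not taken from this entry.

```python
import jax
import jax.numpy as jnp

# Illustrative loss: mean squared error of a linear model y ≈ x @ w
def loss(w, x, y):
    pred = jnp.dot(x, w)              # NumPy-like call via jax.numpy
    return jnp.mean((pred - y) ** 2)

# jax.grad returns a new function computing d(loss)/dw (first argument by default)
grad_loss = jax.grad(loss)

# jax.jit traces the gradient function once and compiles it with XLA
fast_grad_loss = jax.jit(grad_loss)

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = fast_grad_loss(w, x, y)           # gradient at w = 0 is x.T @ (-y) here
```

Because jax.grad returns an ordinary Python function, the result can be passed to jax.jit, or to jax.grad again, like any other function. This composability is what the entry lists as "functional transformations".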


Statements (57)
Predicate Object
instanceOf Python library
numerical computing library
open-source software
compatibleWith Flax
Haiku
NumPy
Optax
SciPy ecosystem
TensorFlow Probability (JAX backend)
developedBy Google
Google Research
documentation https://jax.readthedocs.io
https://jax.readthedocs.io/en/latest/
hasAPIStyle NumPy-like API
hasComponent jax.experimental
jax.lax
jax.numpy
jax.random
implements NumPy API subset
XLA-backed array operations
automatic differentiation primitives
license Apache License 2.0
programmingLanguage Python
repository https://github.com/google/jax
supportsFeature GPU acceleration
TPU acceleration
XLA compilation
automatic differentiation
custom gradients
differentiation of Python functions
forward-mode automatic differentiation
functional transformations
gradient-based optimization
higher-order differentiation
jit compilation decorator
just-in-time compilation
just-in-time compiled NumPy operations
just-in-time compiled control flow
parallelization
pmap parallel mapping
random number generation
reverse-mode automatic differentiation
vectorization
vmap vectorized mapping
targetUser engineers
machine learning researchers
scientists
usedFor deep learning
differentiable programming
large-scale linear algebra
machine learning research
neural network training
numerical optimization
probabilistic modeling
scientific computing
simulation-based inference
writtenIn Python
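Several supportsFeature entries above (higher-order differentiation, vmap vectorized mapping, forward-mode automatic differentiation, custom gradients) are transformations that compose with one another. The following is a minimal sketch of each. The function f and the gradient-clipping helper clip_gradient are hypothetical examples chosen for illustration, not part of JAX itself.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical scalar function: f(x) = x * sin(x)
    return x * jnp.sin(x)

# Higher-order differentiation: grad composes with itself
df = jax.grad(f)     # f'(x)  = sin(x) + x cos(x)
d2f = jax.grad(df)   # f''(x) = 2 cos(x) - x sin(x)

# Vectorization: vmap lifts the scalar derivative over a batch axis
xs = jnp.linspace(0.0, 1.0, 5)
batched_df = jax.vmap(df)(xs)

# Forward-mode AD: jvp returns f(x) and the directional derivative f'(x) * v
primal_out, tangent_out = jax.jvp(f, (1.0,), (1.0,))

# Custom gradients (hypothetical helper): identity on the forward pass,
# clipped cotangent on the backward pass
@jax.custom_vjp
def clip_gradient(x):
    return x

def clip_fwd(x):
    return x, None  # no residuals are needed for the backward pass

def clip_bwd(_, g):
    return (jnp.clip(g, -1.0, 1.0),)

clip_gradient.defvjp(clip_fwd, clip_bwd)

# Without the custom rule this derivative would be 5.0; here it is clipped to 1.0
clipped = jax.grad(lambda x: 5.0 * clip_gradient(x))(2.0)
```

jax.grad uses reverse mode, which is efficient for a scalar output and many inputs. jax.jvp uses forward mode, which suits the opposite shape. Both can be nested with vmap and jit.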
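Two more entries, random number generation and just-in-time compiled control flow, differ noticeably from plain NumPy. JAX's PRNG takes an explicit key that is split instead of mutated, and data-dependent branches and loops inside jit use jax.lax primitives. The following is a minimal sketch of both. The helper names sum_of_squares and safe_abs are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Explicit PRNG state: create a key, then split it instead of reusing it
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
samples = jax.random.normal(subkey, shape=(3,))

@jax.jit
def sum_of_squares(n):
    # lax.fori_loop(lower, upper, body_fun, init_val); n may be a traced value
    return lax.fori_loop(0, n, lambda i, acc: acc + i * i, jnp.int32(0))

@jax.jit
def safe_abs(x):
    # A Python `if x >= 0` would fail under jit because x is a tracer;
    # lax.cond(pred, true_fun, false_fun, operand) selects the branch instead
    return lax.cond(x >= 0, lambda v: v, lambda v: -v, x)
```

Calling the same key twice yields the same samples, which makes runs reproducible. Here sum_of_squares(4) evaluates to 0 + 1 + 4 + 9 = 14.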

