
Python library: JAX

Installing JAX

$ sudo pip install jax
…
Successfully installed jax-…

$ python3
…
>>> import jax
…
ModuleNotFoundError: No module named 'jaxlib'
>>> exit()
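
Installing jax did not pull in jaxlib, JAX's compiled support library, so it has to be installed separately:

$ sudo pip install jaxlib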
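A quick way to check that the installation works is the following minimal sketch. It only calls functions from the member list of jax. The reported backend and devices depend on the machine, so the values in the comments are examples, not guaranteed output.

```python
import jax

# Basic facts about the installation and the hardware JAX can see
print("jax version:", jax.__version__)
print("backend:    ", jax.default_backend())   # e.g. 'cpu', 'gpu' or 'tpu'
print("devices:    ", jax.devices())
print("count:      ", jax.device_count())

# A more detailed report in one call (versions incl. jaxlib, Python, platform)
jax.print_environment_info()
```

If `import jax` still fails with the jaxlib error, the two packages were most likely installed for different Python interpreters.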

Members of jax

abstract_arrays Module
ad_checkpoint Module
api_util Module
Array
block_until_ready()
checking_leaks
checkpoint()
checkpoint_policies
check_tracer_leaks
clear_backends()
closure_convert()
config
core Module
custom_batching Module
custom_derivatives Module
custom_gradient()
custom_jvp jax._src.custom_derivatives.custom_jvp class
custom_transpose Module
custom_vjp jax._src.custom_derivatives.custom_vjp class
debug Module
debug_infs
debug_nans
default_backend()
default_device
default_matmul_precision
default_prng_impl
_deprecated_ad Module
_deprecated_curry()
_deprecated_flatten_fun_nokwargs
_deprecated_partial_eval Module
_deprecated_pxla Module
_deprecated_ShapedArray jax._src.core.ShapedArray class
_deprecated_xla Module
_deprecations dict object
Device
device_count()
device_get()
device_put()
device_put_replicated()
device_put_sharded()
devices()
disable_jit()
distributed Module
dtypes Module
effects_barrier()
enable_checks
enable_custom_prng
enable_custom_vjp_by_custom_transpose
ensure_compile_time_eval()
errors Module
eval_shape()
float0
grad()
hessian()
host_count()
host_id()
host_ids()
image Module
interpreters Module
jacfwd()
jacobian()
jacrev()
jax Module
jax2tf_associative_scan_reductions
jit()
jvp()
lax Module
lib Module
linearize()
linear_transpose()
linear_util Module
live_arrays()
local_device_count()
local_devices()
log_compiles
make_array_from_callback()
make_array_from_single_device_arrays()
make_jaxpr()
monitoring Module
named_call()
named_scope()
nn Module
numpy Module
numpy_dtype_promotion
numpy_rank_promotion
ops Module
pmap()
print_environment_info()
process_count()
process_index()
profiler Module
pure_callback()
random Module
remat()
scipy Module
ShapeDtypeStruct jax._src.api.ShapeDtypeStruct class
Shard jax._src.array.Shard class
sharding Module
spmd_mode
_src Module
stages Module
transfer_guard()
transfer_guard_device_to_device
transfer_guard_device_to_host
transfer_guard_host_to_device
treedef_is_leaf()
tree_flatten()
tree_leaves()
tree_map()
tree_structure()
tree_transpose()
tree_unflatten()
tree_util Module
typing Module
util Module
value_and_grad()
version Module
vjp()
vmap()
xla_computation()
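The following sketch exercises a few of the members listed above: grad(), value_and_grad(), jit(), vmap(), hessian(), make_jaxpr(), the random module, device_put(), block_until_ready() and tree_map from tree_util. The toy function `f` and the dictionary `params` are illustrative assumptions and not part of JAX. The tree functions are called through `jax.tree_util` because the top-level aliases such as `jax.tree_map` are deprecated in newer JAX releases.

```python
import jax
import jax.numpy as jnp

# Illustrative toy function (not part of JAX): f(x) = sum(x**2)
def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])

# grad(): gradient of f with respect to its first argument, i.e. 2*x
df = jax.grad(f)
print(df(x))                       # [2. 4. 6.]

# value_and_grad(): f(x) and its gradient from a single call
val, g = jax.value_and_grad(f)(x)
print(val, g)                      # 14.0 [2. 4. 6.]

# jit(): compile f with XLA; the compiled version gives the same result
f_fast = jax.jit(f)
print(f_fast(x))                   # 14.0

# vmap(): map f over the leading axis of a batch without a Python loop
batch = jnp.stack([x, 2 * x])      # shape (2, 3)
print(jax.vmap(f)(batch))          # [14. 56.]

# hessian(): matrix of second derivatives; for sum(x**2) it is 2 * identity
print(jax.hessian(f)(x))

# make_jaxpr(): show the intermediate representation JAX traces from f
print(jax.make_jaxpr(f)(x))

# random: randomness is driven by explicit keys that are split, not by global state
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, (3,))

# device_put(): place an array on a device (here the first available one);
# block_until_ready(): wait until the asynchronously dispatched result exists
x_dev = jax.device_put(x + noise, jax.devices()[0])
result = jax.block_until_ready(f_fast(x_dev))

# tree_util.tree_map(): apply a function to every leaf of a nested structure
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}   # illustrative pytree
scaled = jax.tree_util.tree_map(lambda leaf: 0.5 * leaf, params)
print(jax.tree_util.tree_leaves(scaled))
```

The transformations compose, so for example `jax.jit(jax.vmap(jax.grad(f)))` is also a valid function of a batch.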

See also

Haiku is a library built on top of JAX. It provides a simple, composable abstraction for machine learning research.
Paxml (aka Pax) is a framework to configure and run machine learning experiments on top of JAX.
Pax was used to train (or at least develop) PaLM 2.
