abstract_arrays | Module |
ad_checkpoint | Module |
api_util | Module |
Array | |
block_until_ready() | |
checking_leaks | |
checkpoint() | |
checkpoint_policies | |
check_tracer_leaks | |
clear_backends() | |
closure_convert() | |
config | |
core | Module |
custom_batching | Module |
custom_derivatives | Module |
custom_gradient() | |
custom_jvp | jax._src.custom_derivatives.custom_jvp class |
custom_transpose | Module |
custom_vjp | jax._src.custom_derivatives.custom_vjp class |
debug | Module |
debug_infs | |
debug_nans | |
default_backend() | |
default_device | |
default_matmul_precision | |
default_prng_impl | |
_deprecated_ad | Module |
_deprecated_curry() | |
_deprecated_flatten_fun_nokwargs | |
_deprecated_partial_eval | Module |
_deprecated_pxla | Module |
_deprecated_ShapedArray | jax._src.core.ShapedArray class |
_deprecated_xla | Module |
_deprecations | dict object |
Device | |
device_count() | |
device_get() | |
device_put() | |
device_put_replicated() | |
device_put_sharded() | |
devices() | |
disable_jit() | |
distributed | Module |
dtypes | Module |
effects_barrier() | |
enable_checks | |
enable_custom_prng | |
enable_custom_vjp_by_custom_transpose | |
ensure_compile_time_eval() | |
errors | Module |
eval_shape() | |
float0 | |
grad() | |
hessian() | |
host_count() | |
host_id() | |
host_ids() | |
image | Module |
interpreters | Module |
jacfwd() | |
jacobian() | |
jacrev() | |
jax | Module |
jax2tf_associative_scan_reductions | |
jit() | |
jvp() | |
lax | Module |
lib | Module |
linearize() | |
linear_transpose() | |
linear_util | Module |
live_arrays() | |
local_device_count() | |
local_devices() | |
log_compiles | |
make_array_from_callback() | |
make_array_from_single_device_arrays() | |
make_jaxpr() | |
monitoring | Module |
named_call() | |
named_scope() | |
nn | Module |
numpy | Module |
numpy_dtype_promotion | |
numpy_rank_promotion | |
ops | Module |
pmap() | |
print_environment_info() | |
process_count() | |
process_index() | |
profiler | Module |
pure_callback() | |
random | Module |
remat() | |
scipy | Module (see also here) |
ShapeDtypeStruct | jax._src.api.ShapeDtypeStruct class |
Shard | jax._src.array.Shard class |
sharding | Module |
spmd_mode | |
_src | Module |
stages | Module |
transfer_guard() | |
transfer_guard_device_to_device | |
transfer_guard_device_to_host | |
transfer_guard_host_to_device | |
treedef_is_leaf() | |
tree_flatten() | |
tree_leaves() | |
tree_map() | |
tree_structure() | |
tree_transpose() | |
tree_unflatten() | |
tree_util | Module |
typing | Module |
util | Module |
value_and_grad() | |
version | Module |
vjp() | |
vmap() | |
xla_computation() | |