Core Concepts
Transformations
@jax.jit
- Just-in-time compilation to XLA for speed.
jax.grad()
- Automatic differentiation for gradients.
jax.vmap()
- Automatic vectorization for batching.
jax.pmap()
- Parallelization across multiple devices (e.g., GPUs/TPUs).
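A minimal sketch of the first three transformations applied to one function. The function sum_of_squares and the variable names are illustrative assumptions, not part of the original sheet; pmap is left out because it needs multiple devices.

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    # Pure function: the output depends only on the input
    return jnp.sum(x ** 2)

fast_sum = jax.jit(sum_of_squares)   # compiled by XLA on the first call
grad_fn = jax.grad(sum_of_squares)   # gradient of sum(x^2) is 2x
batched = jax.vmap(sum_of_squares)   # maps over the leading axis

x = jnp.array([1.0, 2.0, 3.0])
xs = jnp.stack([x, 2 * x])           # shape (2, 3)

total = fast_sum(x)      # scalar
g = grad_fn(x)           # same shape as x
per_row = batched(xs)    # one value per row of xs
```

The transformations compose, so jax.jit(jax.vmap(jax.grad(f))) is a common pattern for per-example gradients.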
Random Numbers
JAX requires explicit PRNG key management.
from jax import random
key = random.PRNGKey(0)
# Split key before use
key, subkey = random.split(key)
r = random.normal(subkey)
# Common distributions (each call gets its own fresh subkey)
key, k_uniform, k_randint = random.split(key, 3)
u = random.uniform(k_uniform)
i = random.randint(k_randint, (5,), 0, 10)
- Never reuse keys; always split them.
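To see why reuse is a problem, the following sketch draws twice from the same key and then from two split subkeys. The seed and variable names are arbitrary choices for illustration.

```python
from jax import random
import jax.numpy as jnp

key = random.PRNGKey(42)

# Same key twice: the two draws are identical
a = random.normal(key, (3,))
b = random.normal(key, (3,))
same = bool(jnp.array_equal(a, b))

# Split first: each subkey gives an independent draw
key, sub1, sub2 = random.split(key, 3)
c = random.normal(sub1, (3,))
d = random.normal(sub2, (3,))
different = not bool(jnp.array_equal(c, d))
```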
Device & Memory
JAX runs on CPU, GPU, and TPU.
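import jax
# List available devices
jax.devices()
# Explicitly place an existing array x on the first GPU
# (jax.devices('gpu') raises RuntimeError if no GPU backend is available)
x_gpu = jax.device_put(x, jax.devices('gpu')[0])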
# Clear JIT compilation caches
jax.clear_caches()
# Note: jax.clear_backends() is deprecated and absent from recent JAX releases
- Data placement is usually automatic: arrays are created on the default device, jax.devices()[0], unless placed explicitly.
Common Usage
Basic Operations
import jax.numpy as jnp
# Create arrays
x = jnp.arange(10)
# Operations are similar to NumPy
y = jnp.sin(x)
z = jnp.dot(x, x)
# JAX arrays are immutable
# x[0] = 5 <- This raises a TypeError
# Use the .at property for updates
x = x.at[0].set(5)
Neural Network Snippets
# A simple linear layer
def linear_layer(params, x):
    w, b = params
    return jnp.dot(x, w) + b

# A simple ReLU activation
def relu(x):
    return jnp.maximum(0, x)

# Combine them
def simple_net(params, x):
    return relu(linear_layer(params, x))
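One way these snippets might be used end to end is sketched below. The helpers init_params and mse_loss, the layer sizes, and the learning rate are hypothetical and chosen only for illustration; the three network functions are repeated so the block runs on its own.

```python
import jax
import jax.numpy as jnp
from jax import random

def linear_layer(params, x):
    w, b = params
    return jnp.dot(x, w) + b

def relu(x):
    return jnp.maximum(0, x)

def simple_net(params, x):
    return relu(linear_layer(params, x))

def init_params(key, in_dim, out_dim):
    # Hypothetical helper: small random weights, zero bias
    w = 0.1 * random.normal(key, (in_dim, out_dim))
    b = jnp.zeros((out_dim,))
    return w, b

def mse_loss(params, x, y):
    # Hypothetical loss, used only to have something to differentiate
    return jnp.mean((simple_net(params, x) - y) ** 2)

key = random.PRNGKey(0)
key, k_params, k_data = random.split(key, 3)
params = init_params(k_params, in_dim=4, out_dim=2)
x = random.normal(k_data, (8, 4))   # batch of 8 inputs
y = jnp.ones((8, 2))                # dummy targets

# Loss and gradients in one compiled call; grads mirrors the (w, b) tuple
loss, grads = jax.jit(jax.value_and_grad(mse_loss))(params, x, y)

# One plain gradient-descent step over the parameter tuple
lr = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```

Because params is an ordinary tuple of arrays, jax.grad and tree_map handle it without any extra registration.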
Best Practices & Gotchas
Best Practices
- Apply @jit to your top-level, pure functions.
- Use vmap for efficient batching instead of Python loops.
- Use lax.scan for loops inside JIT-compiled functions.
- Profile your code with jax.profiler to find bottlenecks.
- Keep functions “pure”: no side-effects like printing or global state modification.
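As an illustration of the lax.scan advice, here is a small sketch of a running sum written as a scan inside a jitted function. The function name running_sum and the input values are assumptions made for the example.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def running_sum(xs):
    def step(carry, x):
        new_carry = carry + x
        # Return (next carry, per-step output)
        return new_carry, new_carry

    init = jnp.zeros((), xs.dtype)   # carry must keep a fixed shape and dtype
    total, partials = lax.scan(step, init, xs)
    return total, partials

total, partials = running_sum(jnp.arange(1.0, 5.0))   # xs = [1, 2, 3, 4]
```

Unlike a Python for loop, which is unrolled during tracing, the scan body is traced once regardless of the sequence length.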
Common Gotchas
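- Immutability: JAX arrays cannot be modified in place. Use the .at[...].set(...) syntax for updates.
- PRNG Keys: Keys must be explicitly managed and split before use. Reusing a key yields identical “random” numbers.
- Side-effects: print() or modifying external variables won’t work as expected inside JIT-compiled functions, because Python side effects run only while the function is traced, not on every call. Use jax.debug.print to print runtime values.
- Dynamic Shapes: JIT compilation requires array shapes to be known at trace time. A new input shape triggers recompilation, and operations whose output shape depends on array values (e.g., boolean-mask indexing) fail under JIT.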
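The side-effect and shape gotchas can be observed directly. In this sketch, trace_log, double, and zero_out_negatives are hypothetical names; the list append stands in for any Python side effect, and jnp.where is one possible fixed-shape alternative to boolean-mask indexing.

```python
import jax
import jax.numpy as jnp

trace_log = []

@jax.jit
def double(x):
    trace_log.append("traced")      # Python side effect: runs only while tracing
    jax.debug.print("x = {}", x)    # runtime print: runs on every call
    return 2 * x

double(jnp.ones(3))
double(jnp.ones(3))                 # same shape and dtype: cached, no retrace
n_traces_same_shape = len(trace_log)

double(jnp.ones(4))                 # new shape: traced and compiled again
n_traces_new_shape = len(trace_log)

@jax.jit
def zero_out_negatives(x):
    # x[x > 0] would have a value-dependent shape and fail under jit;
    # jnp.where keeps the output shape equal to the input shape.
    return jnp.where(x > 0, x, 0.0)

cleaned = zero_out_negatives(jnp.array([-1.0, 2.0, -3.0]))
```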