JAX cheatsheet

A comprehensive guide to JAX, a high-performance library for numerical computing and machine learning research.

Core Concepts

Transformations

@jax.jit
Just-in-time compilation to XLA for speed.
jax.grad()
Automatic differentiation for gradients.
jax.vmap()
Automatic vectorization for batching.
jax.pmap()
Parallelization across multiple devices (e.g., GPUs/TPUs).
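
The transformations compose, which is often how they are used in practice. A minimal sketch, assuming a toy function `loss` (a hypothetical name, not part of JAX) that maps a weight vector and one example to a scalar:

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Scalar-valued function of w for a single example x (illustrative only)
    return jnp.sum((w * x) ** 2)

# grad differentiates with respect to the first argument (w) by default
grad_loss = jax.grad(loss)

# vmap maps over the leading axis of x while sharing w across the batch
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))

# jit compiles the composed function with XLA
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
grads = fast_batched_grad(w, xs)  # shape (3, 2): one gradient per example
```

Here each row of `grads` equals `2 * w * x**2` for the matching row of `xs`.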

Random Numbers

JAX requires explicit PRNG key management.

from jax import random
key = random.PRNGKey(0)

# Split key before use
key, subkey = random.split(key)
r = random.normal(subkey)

# Common distributions (each draw gets its own subkey)
key, k_u, k_i = random.split(key, 3)
u = random.uniform(k_u)
i = random.randint(k_i, (5,), 0, 10)
  • Never reuse keys; always split them.
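
To see why reuse matters, the sketch below draws twice from the same key and then from two split keys. The variable names are illustrative only.

```python
from jax import random

key = random.PRNGKey(0)

# Reusing the same key reproduces the same draw
a = random.normal(key, (3,))
b = random.normal(key, (3,))
same = bool((a == b).all())

# Splitting yields separate subkeys, and the draws differ
key, k1, k2 = random.split(key, 3)
c = random.normal(k1, (3,))
d = random.normal(k2, (3,))
different = not bool((c == d).all())
```

`same` is True because a key fully determines the values drawn from it.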

Device & Memory

JAX runs on CPU, GPU, and TPU.

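import jax
import jax.numpy as jnp

# List available devices
jax.devices()

# Explicitly place data on a device
# (jax.devices('gpu') raises a RuntimeError if no GPU backend is available)
x = jnp.arange(10)
x_gpu = jax.device_put(x, jax.devices('gpu')[0])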

# Clear JIT compilation caches
jax.clear_caches()
# jax.clear_backends() is deprecated in recent JAX releases; prefer clear_caches()
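  • Data placement is usually automatic: new arrays go to the default device (the first accelerator if one is available, otherwise the CPU).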
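
A portable variant of explicit placement, as a rough sketch: it picks whichever device JAX lists first, so it should also run on a CPU-only machine. The names `device` and `x_dev` are illustrative.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)

# Use the first listed device instead of assuming a GPU is present
device = jax.devices()[0]
x_dev = jax.device_put(x, device)

print(device.platform)   # e.g. 'cpu' on a CPU-only machine
print(x_dev.devices())   # set containing the device that holds the array
```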

Common Usage

Basic Operations

import jax.numpy as jnp

# Create arrays
x = jnp.arange(10)

# Operations are similar to NumPy
y = jnp.sin(x)
z = jnp.dot(x, x)

# JAX arrays are immutable
# x[0] = 5  <- This will raise an error
# Use the .at property for updates
x = x.at[0].set(5)
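
The .at property supports more than set. A short sketch of a few update forms, each returning a new array and leaving the original unchanged:

```python
import jax.numpy as jnp

x = jnp.arange(5)                    # [0 1 2 3 4]
y = x.at[0].set(10)                  # new array; x itself is unchanged
z = x.at[1:3].add(100)               # add to a slice
m = x.at[jnp.array([0, 0])].add(1)   # repeated indices accumulate

print(x)  # [0 1 2 3 4]
print(y)  # [10  1  2  3  4]
print(z)  # [  0 101 102   3   4]
print(m)  # [2 1 2 3 4]
```

Unlike NumPy's `x[idx] += 1`, `.at[idx].add` applies every repeated index, so `m[0]` is 2.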

Neural Network Snippets

# A simple linear layer
def linear_layer(params, x):
    w, b = params
    return jnp.dot(x, w) + b

# A simple ReLU activation
def relu(x):
    return jnp.maximum(0, x)

# Combine them
def simple_net(params, x):
    return relu(linear_layer(params, x))
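
One way these pieces might be used together is a single gradient step. This is a sketch under stated assumptions: the loss `mse_loss`, the step function `sgd_step`, the constant `LEARNING_RATE`, and the array shapes are all hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp
from jax import random

def linear_layer(params, x):
    w, b = params
    return jnp.dot(x, w) + b

def relu(x):
    return jnp.maximum(0, x)

def simple_net(params, x):
    return relu(linear_layer(params, x))

def mse_loss(params, x, y):
    # Hypothetical loss: mean squared error against targets y
    return jnp.mean((simple_net(params, x) - y) ** 2)

LEARNING_RATE = 0.1  # assumed value, for illustration only

@jax.jit
def sgd_step(params, x, y):
    # value_and_grad returns the loss and its gradient w.r.t. params
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
    # params is a tuple (w, b); tree_map applies the update to each leaf
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss

key = random.PRNGKey(0)
k_w, k_x = random.split(key)
params = (random.normal(k_w, (3, 2)) * 0.1, jnp.zeros(2))
x = random.normal(k_x, (8, 3))   # batch of 8 inputs with 3 features
y = jnp.ones((8, 2))             # dummy targets

new_params, loss = sgd_step(params, x, y)
```

Because `params` is a plain tuple, it is already a valid pytree and needs no extra registration.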

Best Practices & Gotchas

Best Practices

  • Apply @jit to your top-level, pure functions.
  • Use vmap for efficient batching instead of Python loops.
  • Use lax.scan for loops inside JIT-compiled functions.
  • Profile your code with jax.profiler to find bottlenecks.
  • Keep functions “pure”: no side-effects like printing or global state modification.
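
The lax.scan and vmap advice above can be sketched with a running sum. The function name `running_sum` is illustrative, and `jnp.cumsum` would be the practical choice for this exact task.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def running_sum(xs):
    def step(carry, x):
        new_carry = carry + x
        return new_carry, new_carry   # (next carry, per-step output)
    # The initial carry matches the dtype of xs so the carry type stays fixed
    total, partials = lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return total, partials

xs = jnp.arange(1.0, 6.0)             # [1. 2. 3. 4. 5.]
total, partials = running_sum(xs)     # 15.0 and [1. 3. 6. 10. 15.]

# vmap replaces a Python loop over the rows of a batch
batch = jnp.stack([xs, 2 * xs])
batch_totals, _ = jax.vmap(running_sum)(batch)   # [15. 30.]
```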

Common Gotchas

  • Immutability: JAX arrays cannot be modified in place. Use the .at[...].set(...) syntax for updates.
  • PRNG Keys: Keys must be explicitly managed and split before use. Reusing keys will lead to identical “random” numbers.
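  • Side-effects: print() and writes to external variables run only while the function is being traced (typically the first call for each input shape and dtype), not on every call, and traced values print as tracer objects. Use jax.debug.print to print runtime values.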
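  • Dynamic Shapes: Under JIT, array shapes must be known at trace time. Each new input shape triggers a recompile, and operations whose output shape depends on array values (for example boolean-mask indexing such as x[x > 0]) raise an error.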
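
The side-effect and shape gotchas can be observed directly. In this sketch, `trace_count` is a hypothetical global used only to count how often the Python body runs; it is not a pattern to copy into real code.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, used only to observe tracing

@jax.jit
def double(x):
    global trace_count
    trace_count += 1               # Python side effect: runs only while tracing
    jax.debug.print("x = {}", x)   # runs on every call, with runtime values
    return x * 2

double(jnp.ones(3))   # first call with shape (3,): traces and compiles
double(jnp.ones(3))   # same shape and dtype: cached, body is not re-run
double(jnp.ones(4))   # new shape: traces and compiles again

print(trace_count)    # 2
```

Three calls produce only two traces, one per distinct input shape.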
