
JAX Roll

JAX Roll Implementation of Finite Differences


See how jnp.roll implements the central finite-difference stencil for the 1D wave equation: watch how the arrays are shifted and how the boundary conditions are applied.

Wave Equation Implementation

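```python
import jax.numpy as jnp

# Central difference scheme using jnp.roll
# u0, u1: previous and current time levels; C2 = (c*dt/dx)**2
n = u1.shape[0]
u1p = jnp.roll(u1, 1).at[0].set(0)     # u[j-1]; wrapped entry at index 0 set to 0
u1n = jnp.roll(u1, -1).at[n-1].set(0)  # u[j+1]; wrapped entry at index n-1 set to 0
u2 = 2 * u1 - u0 + C2 * (u1p - 2 * u1 + u1n)
```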
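To see the update in context, here is a self-contained sketch of the full time loop. It assumes a Gaussian initial displacement, zero initial velocity, and a Courant number C = 0.9; the helper names `neighbors`, `step`, and `simulate` are hypothetical, and the special first step (which uses the zero initial velocity to avoid needing a level before u0) is the standard leapfrog start rather than something taken from this page. As in the snippet above, values just outside the array are treated as 0.

```python
import jax
import jax.numpy as jnp


def neighbors(u):
    """Return (u[j-1], u[j+1]); values outside the array are taken as 0."""
    n = u.shape[0]
    up = jnp.roll(u, 1).at[0].set(0.0)       # left neighbors
    un = jnp.roll(u, -1).at[n - 1].set(0.0)  # right neighbors
    return up, un


@jax.jit
def step(u0, u1, C2):
    """One leapfrog step: u2 = 2*u1 - u0 + C2*(u[j-1] - 2u[j] + u[j+1])."""
    up, un = neighbors(u1)
    return 2 * u1 - u0 + C2 * (up - 2 * u1 + un)


def simulate(u_init, C, num_steps):
    """Advance u_init by num_steps (>= 1) steps, assuming zero initial velocity."""
    C2 = C ** 2
    up, un = neighbors(u_init)
    u0 = u_init
    # Special first step: zero initial velocity removes the need for a level before u0.
    u1 = u_init + 0.5 * C2 * (up - 2 * u_init + un)
    for _ in range(num_steps - 1):
        u0, u1 = u1, step(u0, u1, C2)
    return u1


# Example: Gaussian pulse centred on a grid of 101 points over [0, 1].
x = jnp.linspace(0.0, 1.0, 101)
u_init = jnp.exp(-((x - 0.5) / 0.05) ** 2)
u_final = simulate(u_init, C=0.9, num_steps=200)
```

Keeping C ≤ 1 (the CFL condition for this explicit scheme) keeps the iteration stable; with C > 1 the solution grows without bound.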

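Key Insight: Instead of indexing u[j-1] and u[j+1] explicitly, we use jnp.roll to shift the entire array by one position. Because the shift wraps around (the element pushed off one end reappears at the other), the wrapped entry is then overwritten with the boundary value using .at[].set(). JAX arrays are immutable, so .at[].set() returns an updated copy rather than modifying the array in place.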
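A quick way to confirm that roll-and-overwrite computes the same stencil as explicit indexing is to compare it with a slicing version. The sketch below uses `jnp.pad`, whose default constant mode pads with zeros, to mirror the zero boundary values; both function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def second_difference_roll(u):
    """u[j-1] - 2u[j] + u[j+1] via jnp.roll, with zero values outside the array."""
    n = u.shape[0]
    up = jnp.roll(u, 1).at[0].set(0.0)
    un = jnp.roll(u, -1).at[n - 1].set(0.0)
    return up - 2 * u + un


def second_difference_slice(u):
    """Same stencil with explicit slicing: pad one zero at each end, then slice."""
    padded = jnp.pad(u, 1)  # constant mode, zeros by default
    return padded[:-2] - 2 * u + padded[2:]


u = jax.random.normal(jax.random.PRNGKey(0), (16,))
d_roll = second_difference_roll(u)
d_slice = second_difference_slice(u)
```

Both forms give identical results here; which one reads better is largely a matter of taste.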

Finite Difference Stencil:
∂²u/∂x² ≈ (u[j-1] - 2u[j] + u[j+1]) / Δx²
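The stencil is exact for quadratics: for u = x², (x - Δx)² - 2x² + (x + Δx)² = 2Δx², so dividing by Δx² returns exactly 2. The sketch below checks this numerically. `second_derivative` is a hypothetical helper, and only interior points are compared because the two end points see the zero values assumed outside the array.

```python
import jax.numpy as jnp


def second_derivative(u, dx):
    """Approximate d2u/dx2 as (u[j-1] - 2u[j] + u[j+1]) / dx**2, zero outside the array."""
    n = u.shape[0]
    up = jnp.roll(u, 1).at[0].set(0.0)
    un = jnp.roll(u, -1).at[n - 1].set(0.0)
    return (up - 2 * u + un) / dx ** 2


n = 11
dx = 1.0 / (n - 1)
x = jnp.linspace(0.0, 1.0, n)
d2 = second_derivative(x ** 2, dx)
interior = d2[1:-1]  # end points are affected by the zero boundary values
```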

The visualization walks through four steps:
1. Original array u1
2. Roll by +1 (left neighbors, u[j-1])
3. Roll by -1 (right neighbors, u[j+1])
4. Finite difference
Original Array: u1

Current wave state

u1p = jnp.roll(u1, +1).at[0].set(0)

Left neighbors: u1p[j] = u1[j-1], with the boundary entry u1p[0] = 0
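On a small array (values chosen only for illustration) the shift is easy to follow: rolling by +1 moves every element one slot toward higher indices, so position j holds u1[j-1], and the last element wraps into position 0 until it is overwritten. The original u1 is left unchanged.

```python
import jax.numpy as jnp

u1 = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
rolled = jnp.roll(u1, 1)     # [5, 1, 2, 3, 4]: u1[4] wraps into slot 0
u1p = rolled.at[0].set(0.0)  # [0, 1, 2, 3, 4]: wrapped value replaced by the boundary value
```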

u1n = jnp.roll(u1, -1).at[n-1].set(0)

Right neighbors: u1n[j] = u1[j+1], with the boundary entry u1n[n-1] = 0
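Rolling by -1 works the same way in the other direction, and it also shows why the .at[].set() step matters: without it, the wrapped value stays in place and the stencil silently uses a periodic boundary, where the right neighbor of the last point is the first point. The small array below is illustrative only.

```python
import jax.numpy as jnp

u1 = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
n = u1.shape[0]
rolled = jnp.roll(u1, -1)        # [2, 3, 4, 5, 1]: u1[0] wraps into slot n-1
u1n = rolled.at[n - 1].set(0.0)  # [2, 3, 4, 5, 0]: zero value beyond the right end

# Skipping the .at[...] step keeps the wrapped value, i.e. a periodic boundary.
u1n_periodic = rolled
```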

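Finite Difference Result

u1p[j] - 2×u1[j] + u1n[j]

Second-derivative approximation: dividing this by Δx² gives ∂²u/∂x² ≈ (u[j-1] - 2u[j] + u[j+1]) / Δx². For the wave equation ∂²u/∂t² = c² ∂²u/∂x², the update u2 = 2 * u1 - u0 + C2 * (u1p - 2 * u1 + u1n) folds the 1/Δx² factor, together with c² and Δt², into C2 = (c·Δt/Δx)².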
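On the small array [1, 2, 3, 4, 5] (chosen only for illustration), the combination behaves as expected for a straight line: a linear profile has zero curvature, so the stencil returns 0 wherever the neighbors continue the line. The left end also gives 0 because the assumed outside value 0 happens to lie on the line (u = j + 1 at j = -1); at the right end the zero boundary value breaks the pattern and produces -6.

```python
import jax.numpy as jnp

u1 = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
n = u1.shape[0]
u1p = jnp.roll(u1, 1).at[0].set(0.0)       # [0, 1, 2, 3, 4]
u1n = jnp.roll(u1, -1).at[n - 1].set(0.0)  # [2, 3, 4, 5, 0]
d2 = u1p - 2 * u1 + u1n                    # [0, 0, 0, 0, -6]
```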