Blox

A functional, lightweight neural network library for JAX

Abstractions shape how we think. JAX comes with strong abstractions of its own: composable transformations over pure functions, explicit state flow through function signatures, and an XLA compilation model that rewards pure, side-effect-free code. A neural network library built on top of JAX either keeps those abstractions visible or hides them, and one that keeps them visible pays you back model after model.

Blox is what I built: a small, modular, JAX-native neural network library.

The whole mental model fits on one line:

outputs, params = model(params, inputs)

Parameters go in, outputs and updated parameters come out. Because state flows through the signature, JAX's transformations (jax.jit, jax.grad, jax.vmap, jax.checkpoint) apply to a Blox model directly, with no library-specific decorators or special-cased helpers.
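To make that calling convention concrete, here is a minimal pure-JAX sketch. It is not Blox's actual API: init_params, model, loss_fn, train_step, the running_mean state, LEARNING_RATE and the plain SGD update are all hypothetical names chosen for illustration. It only shows that a function shaped like outputs, params = model(params, inputs) passes through jax.jit, jax.value_and_grad, jax.checkpoint and jax.vmap with no wrapper of any kind.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical SGD step size for this sketch


def init_params(key, in_dim, out_dim):
    # Hypothetical dense layer: trainable weights plus one piece of
    # non-trainable state (a running mean of the outputs).
    return {
        "w": jax.random.normal(key, (in_dim, out_dim)) * 0.1,
        "b": jnp.zeros((out_dim,)),
        "running_mean": jnp.zeros((out_dim,)),
    }


def model(params, inputs):
    # outputs, params = model(params, inputs): state goes in and comes out.
    outputs = inputs @ params["w"] + params["b"]
    # "Updating" the running mean means returning a new dict; nothing mutates.
    new_mean = 0.9 * params["running_mean"] + 0.1 * jnp.mean(outputs, axis=0)
    return outputs, {**params, "running_mean": new_mean}


def loss_fn(params, inputs, targets):
    # jax.checkpoint wraps the model like any other pure function.
    outputs, new_params = jax.checkpoint(model)(params, inputs)
    loss = jnp.mean((outputs - targets) ** 2)
    return loss, new_params  # updated state rides along as auxiliary output


@jax.jit
def train_step(params, inputs, targets):
    (loss, new_params), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        params, inputs, targets
    )
    # Plain SGD on the trainable leaves; the running mean comes from the
    # forward pass, not from the gradient.
    updated = {
        **new_params,
        "w": new_params["w"] - LEARNING_RATE * grads["w"],
        "b": new_params["b"] - LEARNING_RATE * grads["b"],
    }
    return updated, loss


key = jax.random.PRNGKey(0)
params = init_params(key, in_dim=3, out_dim=2)
x = jnp.ones((4, 3))
y = jnp.zeros((4, 2))

params_1, loss_1 = train_step(params, x, y)
params_2, loss_2 = train_step(params_1, x, y)

# vmap over an ensemble of five independently initialised parameter sets,
# sharing the same inputs (in_axes=None for the second argument).
member_params = [init_params(k, 3, 2) for k in jax.random.split(key, 5)]
ensemble = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *member_params)
ens_out, ens_params = jax.vmap(model, in_axes=(0, None))(ensemble, x)
# ens_out.shape == (5, 4, 2); ens_params["running_mean"].shape == (5, 2)
```

Because the running mean lives in the returned dict rather than in hidden object state, jax.vmap batches it along with the weights, so each ensemble member ends up with its own updated statistics, and the original params dict is never modified.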

I wrote it for two kinds of people. If you are learning JAX, there is no framework magic to reverse-engineer; what you read is what runs. If you are already shipping JAX in anger, you get full transparency for custom training loops, novel architectures, and scaling work where you actually want to see the execution stack.

Install with pip install jax-blox (Python 3.11+, JAX 0.10+). MIT licensed.