Blox

A small, functional neural network library for JAX, built to keep JAX's strengths visible instead of papering over them

Abstractions shape how we think. JAX comes with strong abstractions of its own: composable transformations over pure functions, explicit state flow through function signatures, an XLA compilation model that rewards clean code. A neural network library on top either keeps those abstractions visible or hides them, and the ones that keep them visible pay you back model after model.

Blox is what I built: a small, modular, JAX-native neural network library. The whole mental model fits on one line:

outputs, params = model(params, inputs)

Parameters go in, outputs and updated parameters come out. Because state flows through the signature, every JAX transformation (jax.jit, jax.grad, jax.vmap, jax.checkpoint) works on a Blox model with no decorators or special-cased helpers.
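A minimal sketch of that convention in plain JAX follows. It is not Blox's API: linear_model, loss_fn, and the dict-of-arrays params are hypothetical stand-ins, assuming only that a model is a pure function returning (outputs, params). Because nothing is hidden inside the function, each transform is applied to it directly. jax.grad uses has_aux=True so the returned params travel alongside the loss instead of being differentiated.

```python
import jax
import jax.numpy as jnp


# Hypothetical model following the "outputs, params = model(params, inputs)"
# convention. params is a plain dict here, not Blox's Params container.
def linear_model(params, inputs):
    outputs = inputs @ params["w"] + params["b"]
    # This layer has no mutable state, so params come back unchanged.
    return outputs, params


def loss_fn(params, inputs, targets):
    preds, new_params = linear_model(params, inputs)
    loss = jnp.mean((preds - targets) ** 2)
    # Updated params ride along as auxiliary output, outside the gradient.
    return loss, new_params


key = jax.random.key(0)
params = {"w": jax.random.normal(key, (3, 2)), "b": jnp.zeros((2,))}
inputs = jnp.ones((4, 3))
targets = jnp.zeros((4, 2))

# jax.jit: compile the pure function as-is.
outputs, params = jax.jit(linear_model)(params, inputs)

# jax.grad: differentiate w.r.t. argument 0; has_aux returns the new params.
grads, params = jax.grad(loss_fn, has_aux=True)(params, inputs, targets)

# jax.vmap: share params (in_axes=None) and map over the batch axis of inputs.
batched_model = jax.vmap(linear_model, in_axes=(None, 0), out_axes=(0, None))
per_example_out, _ = batched_model(params, inputs)

# jax.checkpoint: same function; activations are recomputed when differentiated.
remat_out, _ = jax.checkpoint(linear_model)(params, inputs)
```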

Who it's for

📘 Learners

No framework magic to reverse-engineer. What you read is what runs, which makes it the cleanest way to understand how neural networks actually work at the JAX level.

⚙️ Practitioners

Full transparency for custom training loops, novel architectures, and scaling work where you actually want to see the execution stack.

Blox vs Equinox vs Flax NNX

JAX has more than one answer to "what neural-net library do I use on top?" The choice mostly comes down to where state lives and how it crosses JAX's transformation boundaries.

| | Equinox | Flax NNX | Blox |
|---|---|---|---|
| Where parameters live | Inside the module: the module is the parameter tree. | Inside mutable Module instances as Param variables. | In a separate Params container, passed in and out of the model. |
| Model object | An immutable PyTree (a dataclass-shaped one). | A mutable Python class with reference semantics. | A plain Python object describing the graph; the call signature is a pure function. |
| Params / graph separation | Coupled: params and module structure are the same tree. | Coupled: both live inside the Module. | Decoupled: a static Graph describes structure, dynamic Params hold the arrays. |
| Using JAX transforms | jax.jit / jax.grad work directly; non-array fields need eqx.filter_* variants. | Cross transform boundaries via nnx.split / nnx.merge to swap between mutable and functional form. | jax.jit, jax.grad, jax.vmap, jax.checkpoint all called directly, no wrappers. |
| Implicit / global state | None. | None: RNGs and shapes are threaded by the user. | None: RNG counters live inside Params alongside the rest of the state. |
| Non-trainable state (RNG, batch-norm stats) | Stored as PyTree leaves alongside weights; filtering decides what gets gradients. | Different Variable subclasses (Param, BatchStat, …) split via nnx.split. | All state in one container; params.split() divides trainable from non-trainable at the gradient boundary. |
| Surface area to learn | Small: modules, filtering, a handful of helpers. | Larger: Module system, Variables, split/merge, transform wrappers. | Small: Graph, Module, Params, Rng. Three source files. |
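The idea of keeping RNG counters in the parameter state and splitting trainable from non-trainable state at the gradient boundary can be sketched in plain JAX. This is not Blox's params.split(): the "trainable"/"state" dict layout, the hypothetical split_state helper, and the fixed BASE_KEY are all assumptions for illustration. The model folds an integer counter from its params into a base key to get a fresh dropout mask, then returns the incremented counter as part of the updated params. Only the float half is passed as the differentiated argument, which matters because jax.grad rejects integer inputs unless allow_int=True.

```python
import jax
import jax.numpy as jnp

BASE_KEY = jax.random.key(42)  # assumption: one fixed base key for dropout


def init_params(key):
    # Hypothetical layout: float weights and integer state in one tree.
    return {
        "trainable": {"w": jax.random.normal(key, (3, 2)), "b": jnp.zeros((2,))},
        "state": {"rng_counter": jnp.array(0, dtype=jnp.int32)},
    }


def dropout_model(params, inputs, rate=0.5):
    # Derive a fresh dropout key from the counter stored in params.
    key = jax.random.fold_in(BASE_KEY, params["state"]["rng_counter"])
    h = inputs @ params["trainable"]["w"] + params["trainable"]["b"]
    keep = jax.random.bernoulli(key, 1.0 - rate, h.shape)
    outputs = jnp.where(keep, h / (1.0 - rate), 0.0)
    # Advance the counter so the next call draws a different mask.
    new_state = {"rng_counter": params["state"]["rng_counter"] + 1}
    return outputs, {"trainable": params["trainable"], "state": new_state}


def split_state(params):
    # Hypothetical helper: separate float weights from non-trainable state.
    return params["trainable"], params["state"]


def loss_fn(trainable, state, inputs, targets):
    outputs, new_params = dropout_model({"trainable": trainable, "state": state}, inputs)
    return jnp.mean((outputs - targets) ** 2), new_params


params = init_params(jax.random.key(0))
inputs, targets = jnp.ones((4, 3)), jnp.zeros((4, 2))

trainable, state = split_state(params)
# Differentiate only w.r.t. trainable (argnums=0); the int counter stays out.
grads, params = jax.grad(loss_fn, has_aux=True)(trainable, state, inputs, targets)

# Plain SGD step on the trainable half; the state half comes from the model.
params = {
    "trainable": jax.tree_util.tree_map(
        lambda p, g: p - 0.1 * g, params["trainable"], grads
    ),
    "state": params["state"],
}
```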

Equinox is a good fit if you like the "model is a PyTree" mental model and are comfortable reaching for eqx.filter_* variants when non-array fields show up. Flax NNX is the right pick if mutable, PyTorch-style module objects feel natural and you don't mind learning when to cross into the functional API; it also has a large community and a strong development team behind it. Blox optimises for something different: it keeps the model graph and the parameter tree as two separate objects, so the same params can drive different graphs (actor and learner, static and dynamic scan, train and eval), and every JAX transform applies to a plain function with no library-specific wrapper in the way.
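As a rough illustration of one params tree driving two graphs, here is a plain-JAX sketch of the static vs dynamic scan case. It is not Blox code; rnn_step, unrolled_rnn, scanned_rnn, and the parameter names are hypothetical. Both functions follow the (outputs, params) return convention and consume the same dict: one unrolls a Python loop at trace time, the other uses jax.lax.scan.

```python
import jax
import jax.numpy as jnp


def rnn_step(params, h, x):
    # Hypothetical tanh RNN cell: one recurrent step.
    return jnp.tanh(x @ params["wx"] + h @ params["wh"] + params["b"])


def unrolled_rnn(params, h0, xs):
    # Static graph: the Python loop is unrolled at trace time,
    # so each of the T steps appears separately in the compiled program.
    h = h0
    for t in range(xs.shape[0]):
        h = rnn_step(params, h, xs[t])
    return h, params


def scanned_rnn(params, h0, xs):
    # Dynamic graph: lax.scan traces the step once and loops inside XLA.
    def body(h, x):
        return rnn_step(params, h, x), None

    h_final, _ = jax.lax.scan(body, h0, xs)
    return h_final, params


k1, k2 = jax.random.split(jax.random.key(0))
params = {
    "wx": jax.random.normal(k1, (3, 4)) * 0.5,
    "wh": jax.random.normal(k2, (4, 4)) * 0.5,
    "b": jnp.zeros((4,)),
}
xs = jnp.ones((5, 3))  # 5 time steps, 3 input features each
h0 = jnp.zeros((4,))

# One params tree, two different graphs.
h_static, _ = jax.jit(unrolled_rnn)(params, h0, xs)
h_dynamic, _ = jax.jit(scanned_rnn)(params, h0, xs)
```

Unrolling can give XLA more room to optimise short, fixed-length loops, while scan keeps compile time roughly flat as the sequence grows. Because the params live outside both graphs, switching between them needs no conversion step.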

Install

pip install jax-blox

Python 3.11+, JAX 0.10+. MIT licensed.