brain of mat kelcey...

solving y=mx+b... with jax on a tpu pod slice

February 07, 2021 at 01:00 PM | categories: tpu, ensemble_nets, jax, projects, haiku

from jax fundamentals to running on a tpu pod slice

this 4 (and a bit) part tute series starts with jax fundamentals, builds up to describing a data parallel approach to training on a cloud tpu pod slice, and finishes with a tpu pod slice implementation of ensemble nets... all with the goal of solving 1d y=mx+b.

and though it may seem like a bit of overkill, it turns out to be a good example to work through, since it lets us focus on the library support without having to worry about the modelling.

part 1: some jax basics

in this first section we introduce some jax fundamentals; e.g. make_jaxpr, grad, jit, vmap & pmap.
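as a rough taste of these five transforms (a sketch only, not the colab's code; the loss, the `(m, b)` tuple and the toy numbers are assumptions for illustration), here they all are applied to the squared error of a single y=mx+b prediction:

```python
import jax
import jax.numpy as jnp

# squared error of one prediction from y = m*x + b; params is an (m, b) tuple
def loss(params, x, y):
    m, b = params
    return (m * x + b - y) ** 2

params = (2.0, 1.0)

# make_jaxpr traces the function and prints jax's intermediate representation
print(jax.make_jaxpr(loss)(params, 3.0, 6.0))

# grad differentiates w.r.t. the first argument, returning a (dm, db) tuple
grad_fn = jax.grad(loss)
dm, db = grad_fn(params, 3.0, 6.0)

# jit compiles the gradient function with XLA (traced on the first call)
fast_grad_fn = jax.jit(grad_fn)

# vmap maps over a batch of (x, y) pairs while params (in_axes=None) is shared
batched_grad_fn = jax.vmap(fast_grad_fn, in_axes=(None, 0, 0))
xs = jnp.array([0.0, 1.0, 2.0])
ys = jnp.array([2.0, 3.0, 4.0])
batch_dms, batch_dbs = batched_grad_fn(params, xs, ys)

# pmap maps over the leading axis with one slice per device,
# so that axis can't be bigger than the local device count
n_devices = jax.local_device_count()
doubled = jax.pmap(lambda x: x * 2.0)(jnp.arange(n_devices, dtype=jnp.float32))
```

with m=2, b=1, x=3, y=6 the error is 1, so `dm` is 2·1·3 = 6 and `db` is 2; on a cpu-only machine `pmap` still runs, just over a single device.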

colab: 01 pmap jit vmap oh my.ipynb

part 2: solving y=mx+b

in part 2 we use the techniques from part 1 to solve y=mx+b in pure jax. we'll also introduce pytrees and various tree_utils for manipulating them.
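a minimal sketch of what a pure jax solve can look like, assuming synthetic data drawn from y = 3x + 2 with a little noise, params held as a dict pytree, and plain gradient descent with a hand-picked learning rate (all illustrative choices, not necessarily what the colab uses):

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed value for this sketch

# synthetic data (assumed): y = 3x + 2 plus a small amount of noise
key = jax.random.PRNGKey(0)
x_key, noise_key = jax.random.split(key)
xs = jax.random.uniform(x_key, (100,))
ys = 3.0 * xs + 2.0 + 0.01 * jax.random.normal(noise_key, (100,))

# params as a pytree: a dict of scalar leaves
params = {"m": jnp.array(0.0), "b": jnp.array(0.0)}

def mse(params, xs, ys):
    preds = params["m"] * xs + params["b"]
    return jnp.mean((preds - ys) ** 2)

@jax.jit
def update(params, xs, ys):
    # grads comes back with the same pytree structure as params
    grads = jax.grad(mse)(params, xs, ys)
    # tree_map applies the sgd step leaf by leaf across the two matching pytrees
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

for _ in range(1000):
    params = update(params, xs, ys)

print(params)  # m should end up close to 3 and b close to 2
```

because `grad` returns a pytree shaped like `params`, the update never needs to know what the leaves are called; adding another parameter to the dict needs no change to `update`.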

we run first on a single device and work up to using pmap to demonstrate a simple data parallelism approach. along the way we'll take a small detour to a tpu pod slice to illustrate the difference in a multi host setup.
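the data parallel idea can be sketched for a single host as follows (an illustration under assumed names, data and hyperparameters, not the colab's code, and it doesn't cover the multi host case): params are replicated on every local device, each device gets its own shard of the batch, and `jax.lax.pmean` averages the gradients so every replica takes the same step.

```python
import functools

import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed value for this sketch
n_devices = jax.local_device_count()

def mse(params, xs, ys):
    preds = params["m"] * xs + params["b"]
    return jnp.mean((preds - ys) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def update(params, xs, ys):
    # each device computes gradients on its own shard of the data
    grads = jax.grad(mse)(params, xs, ys)
    # average across devices so all replicas apply the identical update
    grads = jax.lax.pmean(grads, axis_name="batch")
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

# synthetic data (assumed): y = 3x + 2, reshaped to [n_devices, per_device_batch]
key = jax.random.PRNGKey(0)
xs = jax.random.uniform(key, (n_devices * 64,))
ys = 3.0 * xs + 2.0
xs = xs.reshape(n_devices, -1)
ys = ys.reshape(n_devices, -1)

# replicate params: every leaf gets a leading axis of size n_devices
params = {"m": jnp.array(0.0), "b": jnp.array(0.0)}
replicated = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_devices), params)

for _ in range(1000):
    replicated = update(replicated, xs, ys)

# the replicas stay in sync, so just take the first copy
params = jax.tree_util.tree_map(lambda p: p[0], replicated)
```

on a machine with one device this degenerates to plain single device training, which makes it handy for checking the logic before moving to a tpu.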

( note: the experience as described here for a pod slice isn't publicly available yet, but sign up via the JAX on Cloud TPU Interest Form to get more info. see also this JAX on Cloud TPUs (NeurIPS 2020) talk )

colab: 02 y mx b on a tpu.ipynb

part 3: introducing haiku and optax

next we introduce haiku as a way of defining our model and optax as a library providing standard optimisers. to illustrate their use we'll do a minimal port of our model and training loop to use them.

colab: 03 y mx b in haiku.ipynb

part 4: ensemble nets

in part 4 we'll reimplement ensemble nets for this trivial model, continuing to do things in a way that supports running on a tpu pod slice.
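one simple way to train several copies of this model at once is to stack their params along a leading "member" axis and `vmap` the training step over it. the sketch below does that in plain jax for brevity. it assumes independently initialised members, each trained on its own loss, with the ensemble prediction taken as the member mean. that's an illustrative assumption and may differ from how the colab's ensemble nets are set up. `NUM_MODELS` and the other constants are hypothetical.

```python
import jax
import jax.numpy as jnp

NUM_MODELS = 4        # hypothetical ensemble size
LEARNING_RATE = 0.1   # assumed value for this sketch

def init_member(key):
    # each member gets its own random starting point for m and b
    m_key, b_key = jax.random.split(key)
    return {"m": jax.random.normal(m_key), "b": jax.random.normal(b_key)}

def mse(params, xs, ys):
    preds = params["m"] * xs + params["b"]
    return jnp.mean((preds - ys) ** 2)

def update_member(params, xs, ys):
    grads = jax.grad(mse)(params, xs, ys)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

# vmap over the member axis of params; the data (in_axes=None) is shared
update_ensemble = jax.jit(jax.vmap(update_member, in_axes=(0, None, None)))

key = jax.random.PRNGKey(0)
data_key, init_key = jax.random.split(key)
xs = jax.random.uniform(data_key, (100,))
ys = 3.0 * xs + 2.0

# stacked params: every leaf has shape [NUM_MODELS]
ensemble = jax.vmap(init_member)(jax.random.split(init_key, NUM_MODELS))

for _ in range(1000):
    ensemble = update_ensemble(ensemble, xs, ys)

def predict(ensemble, x):
    # run every member on x, then average their predictions
    member_preds = jax.vmap(lambda p: p["m"] * x + p["b"])(ensemble)
    return jnp.mean(member_preds, axis=0)
```

because the members only differ along the leading axis, the same stacked params could also be sharded across devices with `pmap` rather than mapped with `vmap`.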

colab: 04 y mx b haiku ensemble.ipynb

part 5: some sanity

to wrap up, we acknowledge that though tpu pod slices and data parallel approaches are fun, we could have just solved this in a single calculation using the normal equation... :D
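for comparison, here's the "boring" version as a minimal sketch (same assumed synthetic data as the sketches above, not the colab's code): build a design matrix with a column of ones for the intercept and solve the normal equation (XᵀX)θ = Xᵀy directly.

```python
import jax
import jax.numpy as jnp

# synthetic data (assumed): y = 3x + 2
key = jax.random.PRNGKey(0)
xs = jax.random.uniform(key, (100,))
ys = 3.0 * xs + 2.0

# design matrix: one column for x (-> m) and a column of ones (-> b)
X = jnp.stack([xs, jnp.ones_like(xs)], axis=1)

# normal equation (X^T X) theta = X^T y, solved without forming an explicit inverse
theta = jnp.linalg.solve(X.T @ X, X.T @ ys)
m, b = theta
```

no learning rate, no iterations, no devices... just one 2x2 linear solve.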

colab: 05 booooooooooooooooring.ipynb

what a way to solve y=mx+b !!!