from jax fundamentals to running on a tpu pod slice
this 4 (and a bit) part tute series starts with jax fundamentals, builds up to describing a data parallel approach to training on a cloud tpu pod slice, and finishes with a tpu pod slice implementation of ensemble nets, all with the goal of solving 1d y=mx+b.
and though using a tpu pod slice for this may seem like a bit of overkill, it turns out to be a good example to work through, since it lets us focus on the library support without having to worry about the modelling.
part 1: some jax basics
in this first section we introduce some jax fundamentals; e.g. make_jaxpr, grad, jit, vmap & pmap.
colab: 01 pmap jit vmap oh my.ipynb
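as a rough taste of what part 1 covers, here's a minimal sketch (not the colab's actual code) that pokes at each of make_jaxpr, grad, jit, vmap & pmap using a squared error loss for y=mx+b. the toy data (true m=3, b=4) and the starting params are made up for illustration.

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # mean squared error of a 1d linear model y_pred = m*x + b
    m, b = params
    y_pred = m * x + b
    return jnp.mean((y_pred - y) ** 2)

params = (jnp.array(2.0), jnp.array(1.0))  # hypothetical starting point
x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = 3.0 * x + 4.0                          # hypothetical ground truth m=3, b=4

# make_jaxpr: print the traced intermediate representation of the loss
print(jax.make_jaxpr(loss)(params, x, y))

# grad: gradient of the loss w.r.t. its first argument (the params tuple)
grad_fn = jax.grad(loss)
dm, db = grad_fn(params, x, y)

# jit: compile the same gradient function with XLA
dm_jit, db_jit = jax.jit(grad_fn)(params, x, y)

# vmap: evaluate the loss for a batch of candidate m values in one call
ms = jnp.array([1.0, 2.0, 3.0])
losses = jax.vmap(lambda m: loss((m, jnp.array(4.0)), x, y))(ms)

# pmap: run a function on every local device, one row of input per device
n = jax.local_device_count()
per_device_x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)
per_device_sums = jax.pmap(lambda v: jnp.sum(v))(per_device_x)
```

on a plain cpu runtime local_device_count() is usually 1, so pmap just runs on that one device; on a tpu the same code spreads the rows across the cores.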
part 2: solving y=mx+b
we start on a single device and work up to using pmap to demonstrate a simple data parallel approach. along the way we take a small detour to a tpu pod slice to illustrate what changes in a multi-host setup.
( note: the pod slice experience described here isn't publicly available yet, but you can sign up via the JAX on Cloud TPU Interest Form to get more info. see also the JAX on Cloud TPUs (NeurIPS 2020) talk. )
colab: 02 y mx b on a tpu.ipynb
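the single device to pmap progression can be sketched roughly like this. it's an illustration under assumptions rather than the colab's code: plain gradient descent, a made up learning rate and per-device batch size, and noise free data from m=3, b=4. the data parallel version replicates the params on every device, gives each device its own shard of the batch, and uses jax.lax.pmean to average gradients so all replicas take the same step.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

LR = 0.1          # hypothetical learning rate
PER_DEVICE = 32   # hypothetical number of examples per device

def loss(params, x, y):
    # mean squared error of y_pred = m*x + b
    m, b = params
    return jnp.mean((m * x + b - y) ** 2)

def sgd(params, grads):
    # plain gradient descent, applied leaf by leaf
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

@jax.jit
def update(params, x, y):
    # single device step on the full batch
    return sgd(params, jax.grad(loss)(params, x, y))

@functools.partial(jax.pmap, axis_name="batch")
def p_update(params, x, y):
    # each device computes gradients on its own shard of the batch ...
    grads = jax.grad(loss)(params, x, y)
    # ... then gradients are averaged across devices so replicas stay in sync
    grads = jax.lax.pmean(grads, axis_name="batch")
    return sgd(params, grads)

n = jax.local_device_count()
rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=n * PER_DEVICE).astype(np.float32)
y = 3.0 * x + 4.0  # hypothetical ground truth m=3, b=4
init_params = (jnp.array(0.0), jnp.array(0.0))

# single device training
params = init_params
for _ in range(500):
    params = update(params, x, y)

# data parallel training: replicate params along a leading device axis ...
p_params = jax.tree_util.tree_map(
    lambda p: jnp.broadcast_to(p, (n,) + p.shape), init_params)
# ... and shard the data along that same axis
x_sharded = x.reshape(n, PER_DEVICE)
y_sharded = y.reshape(n, PER_DEVICE)
for _ in range(500):
    p_params = p_update(p_params, x_sharded, y_sharded)

# every replica holds the same params, so just read replica 0
m_dp, b_dp = p_params[0][0], p_params[1][0]
```

since every shard is the same size, the mean of the per-shard mean gradients equals the full batch gradient, so both runs should land on (about) the same m and b. on a multi-host pod slice, as i understand it, each host runs this same program and feeds only its local devices (hence local_device_count), while the pmean spans the devices on all hosts.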
part 3: introducing haiku and optax
next we introduce haiku as a way of defining our model and optax as a library that provides standard optimisers. to illustrate their use we'll do a minimal port of our model and training loop to them.
colab: 03 y mx b in haiku.ipynb
part 4: ensemble nets
in part 4 we'll reimplement ensemble nets for this trivial model, continuing to do things in a way that supports running on a tpu pod slice.
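one way to train a small ensemble of y=mx+b models in plain jax is to stack each member's params along a leading axis and vmap the training step over it. this is a sketch of that idea only, not the part 4 implementation: the ensemble size, the random init and the shared training data are all assumptions, and the ensemble prediction is taken as the plain mean over members.

```python
import jax
import jax.numpy as jnp
import numpy as np

NUM_MEMBERS = 4  # hypothetical ensemble size
LR = 0.1         # hypothetical learning rate

def loss(params, x, y):
    m, b = params
    return jnp.mean((m * x + b - y) ** 2)

def sgd_step(params, x, y):
    # one gradient descent step for a single ensemble member
    grads = jax.grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

# vmap over the leading (member) axis of params; data is shared (None)
ensemble_step = jax.jit(jax.vmap(sgd_step, in_axes=(0, None, None)))

def ensemble_predict(params, x):
    m, b = params
    # per-member predictions of shape (NUM_MEMBERS, len(x)), then averaged
    per_member = m[:, None] * x[None, :] + b[:, None]
    return per_member.mean(axis=0)

# each member starts from a different random (m, b)
key_m, key_b = jax.random.split(jax.random.PRNGKey(0))
params = (jax.random.normal(key_m, (NUM_MEMBERS,)),
          jax.random.normal(key_b, (NUM_MEMBERS,)))

rng = np.random.default_rng(0)
x = jnp.asarray(rng.uniform(-1, 1, size=64), dtype=jnp.float32)
y = 3.0 * x + 4.0  # hypothetical ground truth m=3, b=4

for _ in range(500):
    params = ensemble_step(params, x, y)

preds = ensemble_predict(params, jnp.array([0.0, 1.0]))
```

because the member axis is just another array axis, the same layout could in principle be split across devices with pmap instead of (or as well as) vmap, which is what makes the approach friendly to a pod slice.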
part 5: some sanity
to wrap up we acknowledge that though tpu pod slices and data parallel approaches are fun, we could have just solved this in a single calculation using the normal equation... :D
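for y=mx+b the normal equation amounts to building a design matrix X whose rows are [x, 1] and solving (XᵀX)θ = Xᵀy for θ = [m, b]. a minimal sketch, using the same made up noise free data (m=3, b=4) as a stand in:

```python
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=64).astype(np.float32)
y = 3.0 * x + 4.0  # hypothetical ground truth m=3, b=4

# design matrix with a column of ones for the bias term, shape (64, 2)
X = jnp.stack([x, jnp.ones_like(x)], axis=1)

# solve (X^T X) theta = X^T y rather than forming the inverse explicitly
m, b = jnp.linalg.solve(X.T @ X, X.T @ y)
```

no training loop, no devices, no optimiser; one 2x2 linear solve.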