brain of mat kelcey...

differentiable kalman filters in jax

March 29, 2024 at 06:45 PM | categories: jax

kalman filters have numerous parameters to tune; why not roll them out with jax and use backprop to fit them?
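a minimal sketch of the idea, assuming a hypothetical toy 1d local level (random walk) model with only two parameters, the log process noise variance and the log measurement noise variance. it is not the model from the post. the filter is rolled out with `jax.lax.scan`, and the mean negative log likelihood of the innovations is differentiated with `jax.value_and_grad` so both noise levels can be fitted by plain gradient descent.

```python
import jax
import jax.numpy as jnp


def kalman_nll(log_params, zs):
    """Mean negative log likelihood of zs under a 1d local level (random walk) model.

    log_params = [log_q, log_r]: log process / measurement noise variances.
    A hypothetical toy model, not the one from the post.
    """
    q, r = jnp.exp(log_params)

    def step(carry, z):
        x, p = carry
        # predict: state is a random walk, so the mean is unchanged and variance grows by q
        p_pred = p + q
        # innovation (measurement residual) and its variance
        y = z - x
        s = p_pred + r
        # update: blend prediction and measurement using the kalman gain
        k = p_pred / s
        x_new = x + k * y
        p_new = (1.0 - k) * p_pred
        # gaussian negative log likelihood of this innovation
        nll = 0.5 * (jnp.log(2.0 * jnp.pi * s) + y**2 / s)
        return (x_new, p_new), nll

    # initialise the state at the first observation with unit variance
    init = (zs[0], jnp.ones_like(zs[0]))
    _, nlls = jax.lax.scan(step, init, zs[1:])
    return jnp.mean(nlls)


def fit(zs, steps=200, lr=0.1):
    """Plain gradient descent on the log noise variances, backpropagating through the scan."""
    loss_and_grad = jax.jit(jax.value_and_grad(kalman_nll))
    log_params = jnp.zeros(2)  # start at q = r = 1
    losses = []
    for _ in range(steps):
        loss, g = loss_and_grad(log_params, zs)
        losses.append(float(loss))
        log_params = log_params - lr * g
    return log_params, losses


# simulate a noisy random walk with known noise variances
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
true_q, true_r, n = 0.01, 0.5, 300
xs = jnp.cumsum(jnp.sqrt(true_q) * jax.random.normal(k1, (n,)))
zs = xs + jnp.sqrt(true_r) * jax.random.normal(k2, (n,))

log_params, losses = fit(zs)
print("fitted q, r:", jnp.exp(log_params))
```

fitting in log space keeps both variances positive without needing any constraints; a real filter with full state transition and covariance matrices would follow the same pattern, just with more parameters in the pytree.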

high performance ML with JAX

September 12, 2021 at 12:30 PM | categories: jax, talk

did a talk at pycon on jax. check out the recording!

evolved channel selection

March 01, 2021 at 10:20 PM | categories: projects, ga, jax

rather than use all 13 channels in a multispectral image for classification, can we train a model that is robust to all combos, at all resolutions, and use a genetic algorithm to choose which channels are the most valuable? (spoiler: yes)

crazy large batch sizes

February 14, 2021 at 10:30 PM | categories: quick_hack, tpu, jax

a quick hack to see how fast we can get a v3-32 pod slice cranking with a global batch size of 170,000; tl;dr: pretty fast!

solving y=mx+b... with jax on a tpu pod slice

February 07, 2021 at 01:00 PM | categories: tpu, ensemble_nets, jax, projects, haiku

a 4 (and a bit) part tutorial / colab / screencast series, starting with jax fundamentals and working up to a data parallel approach for running on a cloud tpu pod slice... all focused on solving the toughest problem in machine learning: 1d y=mx+b
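a rough sketch of the "jax fundamentals" starting point only: fitting m and b on a single device with `jax.grad` and plain gradient descent on noise free toy data. the function names and the learning rate are illustrative assumptions, and the data parallel / pod slice part of the series is not shown.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1


def predict(params, x):
    m, b = params
    return m * x + b


def loss(params, x, y):
    # mean squared error between predictions and targets
    return jnp.mean((predict(params, x) - y) ** 2)


@jax.jit
def update(params, x, y):
    # jax.grad differentiates loss w.r.t. its first argument (the (m, b) tuple)
    grads = jax.grad(loss)(params, x, y)
    # plain gradient descent step applied to each leaf of the params pytree
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


# noise free toy data from y = 3x + 1
x = jnp.linspace(-1.0, 1.0, 100)
y = 3.0 * x + 1.0

params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(500):
    params = update(params, x, y)

m, b = params
print(f"m={float(m):.3f} b={float(b):.3f}")
```

because `update` is a pure function of (params, data), it is the kind of step that can later be replicated across devices, with gradients averaged between them.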

out of distribution detection using focal loss

December 02, 2020 at 01:00 PM | categories: objax, jax, projects

a series of small experiments on using focal loss to do out of distribution detection
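for reference, a minimal jax sketch of the standard multi-class focal loss from the focal loss paper (Lin et al., 2017), without the optional alpha class weighting. it assumes integer labels and raw logits. how the post then turns a focal loss trained model into an out of distribution score is not reproduced here.

```python
import jax
import jax.numpy as jnp


def focal_loss(logits, labels, gamma=2.0):
    """Multi-class focal loss, -(1 - p_t)^gamma * log(p_t), averaged over the batch.

    logits: [batch, num_classes] unnormalised scores; labels: [batch] integer class ids.
    """
    log_p = jax.nn.log_softmax(logits)
    # log probability assigned to the true class of each example
    log_p_t = jnp.take_along_axis(log_p, labels[:, None], axis=-1)[:, 0]
    p_t = jnp.exp(log_p_t)
    # (1 - p_t)^gamma down-weights examples the model already classifies confidently
    return jnp.mean(-((1.0 - p_t) ** gamma) * log_p_t)


logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = jnp.array([0, 2])
print(focal_loss(logits, labels))
```

with `gamma=0.0` the modulating factor is 1 and this reduces to ordinary softmax cross entropy.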

dithernet very slow movie player

October 21, 2020 at 10:30 PM | categories: gan, jax, projects, objax

a GAN experiment to generate dithers for an e-ink screen that minimise pixel change between frames, for a very slow movie player.

ensemble networks

September 17, 2020 at 06:30 AM | categories: objax, projects, ensemble_nets, jax

ensemble nets: using jax vmap to batch over not just the inputs of a model but also over the parameters of multiple models.
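a small sketch of the vmap trick, assuming a hypothetical two layer mlp (`init_mlp` and `mlp` are illustrative names, not code from the post). vmapping the init over a batch of keys stacks every parameter with a leading model axis, and vmapping the forward pass with `in_axes=(0, None)` runs every model on the same input batch in one call.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, in_dim=4, hidden=16, out_dim=3):
    # hypothetical two layer mlp parameters for a single model
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (in_dim, hidden)) * 0.1,
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, out_dim)) * 0.1,
        "b2": jnp.zeros(out_dim),
    }


def mlp(params, x):
    # forward pass for ONE model; x is already a batch of inputs [batch, in_dim]
    h = jax.nn.relu(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


num_models = 5
keys = jax.random.split(jax.random.PRNGKey(0), num_models)
# vmap the init over keys: every leaf gains a leading [num_models] axis
ensemble_params = jax.vmap(init_mlp)(keys)

x = jax.random.normal(jax.random.PRNGKey(1), (8, 4))  # batch of 8 inputs
# map over the model axis of the params, broadcast the same inputs to every model
ensemble_logits = jax.vmap(mlp, in_axes=(0, None))(ensemble_params, x)
print(ensemble_logits.shape)  # (5, 8, 3)
```

switching to `in_axes=(0, 0)` would instead give each model its own batch of inputs, and the whole ensemble can be trained with a single `jax.grad` over the stacked params.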

metric learning for image similarity search in objax

September 02, 2020 at 12:00 PM | categories: objax, metric_learning, jax

an objax tutorial on using metric learning for image similarity.

a jax random embedding ensemble network

June 15, 2020 at 06:30 AM | categories: ensemble_nets, jax

random embedding networks can be used to generate weakly labelled data for metric learning and they see a large benefit from being run in ensembles. can we represent these ensembles as a single forward pass in jax? why yes! yes we can!

popular posts...

ensemble nets : training ensembles as a single model using jax on a tpu pod slice (sept 2020)

bnn : counting bees with a rasp pi (may 2018)

drivebot : learning to do laps with reinforcement learning and neural nets (feb 2016)

wikipedia philosophy : do all first links on wikipedia lead to philosophy? (aug 2011)

cartpole++ : deep RL hacking with a complex 3d cart pole environment (aug 2016)

malmomo : deep RL hacking on minecraft with malmo (jan 2017)

some papers from my time at google research / brain...

my honours thesis

the co-evolution of cooperative behaviour (1997): evolving neural nets with genetic algorithms for communication problems.

old projects...