in my last post, wavenet on an mcu, we talked about running a cached optimised version of wavenet on an MCU at 48kHz. this time we're going to get things running on an FPGA instead!
note: this post assumes you've read that last one; it covers the model architecture, some caching tricks, and the waveshaping task.
when doing some research on how to make things faster i stumbled on the eurorack-pmod project which is a piece of hardware that connects an FPGA with a eurorack setup. it includes 192kHz analog-digital-analog conversion and supports 4 channels in and out. just add an FPGA! perfect!
i'd never done anything beyond blinking LEDs with an FPGA but it can't be that hard.... right? right?!?!
well. turns out it wasn't that simple... i couldn't find any open source examples of compiling a neural net to an FPGA. the closest was HLS4ML but, if i understand it correctly, it works only with ( expensive ) proprietary FPGA toolchains :(
HLS4ML did at least put me onto qkeras though, which is a key component we'll talk about in a second.
( not interested in details? jump straight to the bottom for the demos )
recall from the last post that an important aspect of the MCU project was having multiple versions for training vs inference
the FPGA version ended up with a similar collection of variants;
for both the MCU and FPGA versions we use a single firmware running on the daisy patch for training data collection.
let's expand on each part of the FPGA version...
neural networks end up doing a lot of matrix multiplication. a lot.
for training it's common to work with a floating point representation since it gives the best continuous view of the loss space.
but for inference we can often reduce the precision of the weights and use integer arithmetic. we're motivated to do this since, for a lot of systems, integer math is much faster to run than full floating point math.
( and, actually, in a lot of systems we just simply might not have the ability to do floating point math at all! )
the process of converting from floats during training to something simpler for inference is called quantisation and we'll look at two flavours of it for this project.
for this project i mainly used fixed point numbers.
fixed point numbers are a simpler representation of floating point numbers with some constraints around range and precision but they allow the multiplication to be done as if it were integer multiplication. ( this is the first project i've done with fixed point math and i'm hooked, it's perfect for neural nets! )
the high level idea is you specify a total number of bits and then how many of those bits you want to use for the integer part of the number, and how many you want to use for representing the fractional part.
in this project all inputs, outputs, weights and biases are 16 bits in total with 4 bits for the integer and 12 bits for the fractional part. ( i.e. FP4.12 )
with "only" 4 bits for the integer part ( one of which is the sign bit ) this means the range of values is +/- 2^3 = 8. though this might seem limiting, it's actually ok for a network, where generally activations etc are centered on zero ( especially if we add a bit of L2 regularisation along the way )
with 12 bits allocated to the fractional part we are able to describe numbers with a precision of 2^-12 = 0.00024.
here are some examples; we show the 16 bit binary number with a decimal point after the 4th bit to denote the change from the integer part to the fractional part.
bits | decimal |
---|---|
0010.0000 0000 0000 | 2^1 = 2 |
0101.0000 0000 0000 | 2^2 + 2^0 = 4 + 1 = 5 |
0000.0000 0000 0000 | 0 |
0000.1000 0000 0000 | 2^-1 = 0.5 |
0000.1001 0000 0000 | 2^-1 + 2^-4 = 0.5 + 0.0625 = 0.5625 |
0000.0000 0000 0001 | 2^-12 = 0.000244140625 |
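to make the mechanics concrete, here's a minimal python sketch of Q4.12 arithmetic using plain ints ( the helper names are mine, not from the project ). the multiply is just an integer multiply followed by a shift; negative numbers would additionally need two's complement handling.

```python
# a minimal sketch of Q4.12 fixed point arithmetic with plain python ints.
FRAC_BITS = 12
SCALE = 1 << FRAC_BITS  # 4096

def to_fixed(x):
    # quantise a float to the nearest representable Q4.12 value
    return int(round(x * SCALE))

def to_float(q):
    return q / SCALE

def fixed_mul(a, b):
    # a * b lands in Q8.24; shift right by 12 to get back to Q4.12
    return (a * b) >> FRAC_BITS

a = to_fixed(0.5625)              # the 0000.1001 0000 0000 example above
b = to_fixed(2.0)                 # 0010.0000 0000 0000
print(to_float(fixed_mul(a, b)))  # 0.5625 * 2 = 1.125
```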
the purpose of the qkeras model then is to train in full float32 but provide the ability to export the weights and biases in this fixed point configuration.
notes...
the options for quantisation can get pretty crazy too!
qkeras provides another scheme called power-of-two quantisation where all values are quantised to be only powers of two. i.e. depending on the fixed point config, a weight/bias can only be one of [ +/- 1, 0, +/- 1/2, +/- 1/4, +/- 1/8, ...]
though this seems overly restrictive it has one HUGE important benefit... when a weight is a power of two then the "multiplication" of a feature by that weight can be simply done with a bit shift operation. and bit shifting is VERY fast.
and there are ways to "recover" representational power too, the best one i found being based around matrix factorisation.
if we have, say, a restricted weight matrix W of shape (8, 8) it can only contain those fixed values. but if we instead represent W as a product of matrices, say, (8, 32) . (32, 8) then we can see that, even though all the individual weights are restricted, the product of the matrices, an effective (8, 8) matrix, has many more possible values.
the pro is that the weights can take many more values, all the combos of w*w. the con though is we have to do two mat muls. depending on how much space we have for allocating the shift operations compared to the number of multiply units we have, this tradeoff might be ok!
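both ideas fit in a few lines of numpy; this is an illustrative sketch ( `shift_mul` and `po2_matrix` are my names, not the project's code ), showing the shift-as-multiply trick and how factorisation recovers distinct effective values.

```python
import numpy as np

# power-of-two "multiplication" is just a shift: x * 2^e == x << e
# ( or x >> -e for negative e ). values here are Q4.12-style integers.
def shift_mul(x_q, exponent):
    return x_q >> -exponent if exponent < 0 else x_q << exponent

print(shift_mul(4096, -2))  # 1.0 in Q4.12 times 1/4

# the factorisation trick: matrices of power-of-two weights have few distinct
# entries, but the product of two of them can take many more effective values.
rng = np.random.default_rng(0)

def po2_matrix(shape):
    # random weights restricted to +/- { 1, 1/2, 1/4, 1/8 }
    return np.sign(rng.standard_normal(shape)) * 2.0 ** rng.integers(-3, 1, shape)

W1, W2 = po2_matrix((8, 32)), po2_matrix((32, 8))
print(len(np.unique(W1)))       # a handful of distinct weight values
print(len(np.unique(W1 @ W2)))  # far more distinct effective values
```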
i messed around with this a lot and though it was interesting, and generally worked, it turned out that for the FPGA sizing ( i'm using at least ) the best result was to just use fixed point multiplication instead. :/ am guessing this is a fault on my part in terms of poor verilog design and i still have some ideas to try at least...
anyways, back to the models. the next model after the qkeras one is a fxpmath one.
the fxpmath version connects the qkeras model's fixed point export with the caching approach and the inference logic that will go into the verilog design.
the activation caching has two elements that are basically the same as the MCU version; the left shift buffer for handling the first input, and an activation cache for handling the activations between each convolutional layer.
the big difference comes in with the implementation of the convolution which i had to roll from scratch :/ but at least it allows for some very optimised parallelisation.
consider a 1D convolution with kernel size K=4 & input / output feature depth of 16.
since we don't intend to stride this convolution at all ( that's handled implicitly by the activation caching ) we can treat this convolution as the following steps...
each of the x4 matrix multiplications is actually just a row by matrix multiplication, so can be decomposed into j=16 independent dot products ( we'll call this a row_by_matrix_multiply from now on ).
and each of those dot products can be decomposed into k=16 independent multiplications followed by an accumulation.
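the decomposition above can be sanity checked in a few lines of numpy ( shapes per the K=4, 16-feature example; the variable names are mine ):

```python
import numpy as np

# sanity check that an unstrided K=4 conv step with 16 input / 16 output
# features is just 4 row_by_matrix_multiply results summed together.
rng = np.random.default_rng(0)
K, F = 4, 16
x = rng.standard_normal((K, F))      # the 4 cached input rows
w = rng.standard_normal((K, F, F))   # kernel: one (16, 16) matrix per tap
b = rng.standard_normal(F)

# decomposed form: K independent row-by-matrix multiplies, then accumulate.
# each x[k] @ w[k] further splits into j=16 independent dot products of
# k=16 multiplies each, which is what the verilog parallelises.
acc = sum(x[k] @ w[k] for k in range(K)) + b

# reference: the same thing as one flattened matmul
ref = x.reshape(-1) @ w.reshape(K * F, F) + b
assert np.allclose(acc, ref)
```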
the reason for being so explicit about what is independent, and what isn't, comes into play with the verilog version.
verilog is a language used to "program" FPGAs and it's not at all like other languages; it's been really interesting to learn.
in the MCU version the big concerns were fitting the compute into the time budget per sample and the memory needed for caching. but the FPGA version is a little different; instead we have more flexibility to design things based on what we want to run in parallel vs what we want to run sequentially. for neural networks this gives lots of options for design!
at a 30,000 foot view of verilog we have two main concerns; 1) executing code in parallel and 2) executing blocks of code sequentially.
e.g. consider a dot product A.B with |A|=|B|=4
normally we'd think of this as simply something like a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3], as provided by a single np.dot call.
but if we're writing verilog we have to think of things in terms of hardware, and that means considering parallel vs sequential.
the simplest sequential way would be like the following pseudo code...
// WARNING PSEUDO CODE!
case(state):
  multiply_0:
    // recall; the following three statements are run _in parallel_
    accum <= 0              // set accumulator to 0
    product <= a[0] * b[0]  // set intermediate product variable to a0.b0
    state <= multiply_1     // set the next state
  multiply_1:
    accum <= accum + product  // update accumulator with the product value from the last state
    product <= a[1] * b[1]    // set product variable to a1.b1
    state <= multiply_2       // set the next state
  multiply_2:
    accum <= accum + product
    product <= a[2] * b[2]
    state <= multiply_3
  multiply_3:
    accum <= accum + product
    product <= a[3] * b[3]
    state <= final_accumulate
  final_accumulate:
    accum <= accum + product
    state <= done
  done:
    // final result available in `accum`
    state <= done
doing things this way does the dot product in 5 cycles...
an important aspect of how verilog works is to note that the statements in one of those case clauses all run in parallel. i.e. all right hand sides are evaluated and then assigned to the left hand side
e.g. the following would swap a and b
a <= b; b <= a;
and is functionally equivalent to ...
b <= a; a <= b;
so thinking in terms of sequential vs parallel we have the option to do more than one multiplication at any given time. this requires the hardware to be able to support 2 multiplications per clock cycle, but saves some clock cycles. not much in this case, but it adds up as |a| and |b| increase...
// WARNING PSEUDO CODE!
case(state):
  multiply_01:
    accum_0 <= 0
    accum_1 <= 0
    product_0 <= a[0] * b[0]
    product_1 <= a[1] * b[1]
    state <= multiply_23
  multiply_23:
    accum_0 <= accum_0 + product_0
    accum_1 <= accum_1 + product_1
    product_0 <= a[2] * b[2]
    product_1 <= a[3] * b[3]
    state <= accumulate_0
  accumulate_0:
    accum_0 <= accum_0 + product_0
    accum_1 <= accum_1 + product_1
    state <= accumulate_1
  accumulate_1:
    accum_0 <= accum_0 + accum_1
    state <= done
  done:
    // final result available in `accum_0`
    state <= done
or, if we want to go nuts, and, again, can support it, we can do all of the multiplications at the same time, and then hierarchically accumulate the result into one.
// WARNING PSEUDO CODE!
case(state):
  multiply_all:
    // calculate all 4 elements of dot product in parallel; ( note: requires 4 available multiplication units )
    product_0 <= a[0] * b[0]
    product_1 <= a[1] * b[1]
    product_2 <= a[2] * b[2]
    product_3 <= a[3] * b[3]
    state <= accumulate_0
  accumulate_0:
    // add p0 and p1 at the same time as p2 and p3
    accum_0 <= product_0 + product_1
    accum_1 <= product_2 + product_3
    state <= accumulate_1
  accumulate_1:
    // final add of the two
    accum_0 <= accum_0 + accum_1
    state <= done
  done:
    // final result available in `accum_0`
    state <= done
this general idea of how much we do in a single clock cycle versus making values available in the next clock cycle gives a lot of flexibility for a design.
specifically for this neural net i've represented the K=4 conv 1d as...
a dot_product module ( code ) where the elements of the dot products are calculated sequentially; i.e. x[0]*w[0] in the first step, x[1]*w[1] in the second, etc. so a dot product of N values takes N*M clock cycles ( where M is the number of cycles the multiply unit takes ) plus some parallel accumulation.
a row_by_matrix_multiply module ( code ) which runs the j=16 dot products required for each row_by_matrix_multiply in parallel.
a conv1d module ( code ) that runs the K=4 row_by_matrix_multiplys in parallel too, as well as handling the state machine for accumulating results with a bias and applying relu.
so we end up having all j=16 * K=4 = 64 dot products run in parallel, all together.
having said this, there are a number of ways to restructure this; e.g. if there were too many dot products to run in parallel for the 4 row_by_matrix_multiplys we could run 2 of them in parallel, and then when they were finished, run the other 2. there are loads of trade offs between the number of multiply units available vs the time required to run them.
in the MCU version i was only feeding in samples based on the embedding corner points, one of 4 types sampled randomly from...
input ( core wave, e0, e1 ) | output |
---|---|
(triangle, 0, 0) | sine |
(triangle, 0, 1) | ramp |
(triangle, 1, 1) | zigzag |
(triangle, 1, 0) | square |
for the first FPGA version i did this but the model was large enough that it quickly overfit and basically output noise for the intermediate points. as such for training of this model i changed things a bit to include interpolated data.
basically we emit corner points, say (e0=0, e1=0, sine wave), as well as interpolated points, where we pick two waves, say sine and ramp, and a random point between them and train for that point as an interpolated wave between the two ( using constant power interpolation )
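constant power interpolation here means an equal-power crossfade; a small numpy sketch ( the function name is mine ):

```python
import numpy as np

# constant power interpolation between two waves: cos/sin weights keep the
# combined power roughly constant at any mix position t.
def constant_power_mix(wave_a, wave_b, t):
    # t in [0, 1]; t=0 gives wave_a, t=1 gives wave_b
    theta = t * np.pi / 2
    return np.cos(theta) * wave_a + np.sin(theta) * wave_b

n = np.arange(256)
sine = np.sin(2 * np.pi * n / 256)
ramp = 2 * (n / 256) - 1
target = constant_power_mix(sine, ramp, 0.3)  # a training target 30% of the way
```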
though i couldn't get the MCU model to converge well with this kind of data, the larger FPGA variant has no problems.
i also messed around with 3d input embeddings, to translate between any pairing, but it didn't add anything really so i stuck with 2d.
whereas the final MCU model was ...
---------------------------------------------------------------------------
Layer (type)        Output Shape       Par#   Conv1D params
---------------------------------------------------------------------------
input (InputLayer)  [(None, 256, 3)]   0
c0a (Conv1D)        (None, 64, 4)      52     F=4, K=4, D=1, P=causal
c0b (Conv1D)        (None, 64, 4)      20     F=4, K=1
c1a (Conv1D)        (None, 16, 4)      68     F=4, K=4, D=4, P=causal
c1b (Conv1D)        (None, 16, 4)      20     F=4, K=1
c2a (Conv1D)        (None, 4, 4)       68     F=4, K=4, D=16, P=causal
c2b (Conv1D)        (None, 4, 4)       20     F=4, K=1
c3a (Conv1D)        (None, 1, 8)       136    F=8, K=4, D=64, P=causal
c3b (Conv1D)        (None, 1, 8)       72     F=8, K=1
y_pred (Conv1D)     (None, 1, 1)       13     F=1, K=1
---------------------------------------------------------------------------
Trainable params: 465
---------------------------------------------------------------------------
... the current FPGA version i'm running is ...
-----------------------------------------------------------------
Layer (type)           Output Shape      Param #
-----------------------------------------------------------------
input_1 (InputLayer)   [(None, 64, 4)]   0
qconv_0 (QConv1D)      (None, 16, 16)    272
qrelu_0 (QActivation)  (None, 16, 16)    0
qconv_1 (QConv1D)      (None, 4, 16)     1040
qrelu_1 (QActivation)  (None, 4, 16)     0
qconv_2 (QConv1D)      (None, 1, 4)      260
-----------------------------------------------------------------
Trainable params: 1,572
-----------------------------------------------------------------
it has the following differences compared to the MCU version
to be honest utilisation is a bit harder to compare; the trade off between compute and space is quite different with an FPGA design.
for each sample coming in at 192kHz the FPGA is running a simple state machine of 1) accept the next sample 2) run the sequence of qconvs and activation caches, then 3) output the result and sit in a while-true loop until the next sample. when i say above that the FPGA is running at 30% what i really should say is that it's spending 70% of the time in the post sample processing while-true loop waiting for the next sample.
looking at the device utilisation we have the following..
Info: Device utilisation:
Info:       TRELLIS_IO:    11/  365     3%
Info:             DCCA:     5/   56     8%
Info:           DP16KD:    24/  208    11%
Info:       MULT18X18D:   134/  156    85%
Info:          EHXPLLL:     1/    4    25%
Info:       TRELLIS_FF: 21081/83640   25%
Info:     TRELLIS_COMB: 44647/83640   53%
Info:     TRELLIS_RAMW:   192/10455    1%
the pieces of interest are...
DP16KD : the amount of ( one type of ) RAM being used; this looks to be dominated by the activation cache, so with only 11% being used there is a lot of room for having more layers.
MULT18X18D : the big one; this is the max number of multiplication DSP units being used at any one time. in this model that's qconv_1 with in_dim = out_dim = 16. since it's already at 85%, if we wanted to increase the filter size much more we might be forced to not run all 16 dot products of the 16x16 row_by_matrix_multiply at once but instead, say, do 8 in parallel, then the other 8.
this would incur a latency hit, but that's totally fine given we still have a lot of clock time available to do work between samples.
the trouble is the code as written would end up being tricky to refactor.
currently things are setup so that the entire network has to run before the next sample comes in. this is just because it was the simplest thing to do while i'm learning verilog and it seems like the FPGA is fast enough for it. but with a neural net it doesn't have to be like that; we really just need to finish the first layer before next sample comes, not the whole network. as long as we don't mind a little bit of output latency we can run a layer per sample clock tick. doing this would actually delay the output by number-of-layer sample clock ticks, but at 192kHz that'd be fine :)
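the layer-per-tick idea can be simulated in a few lines of python ( with toy stand-in "layers"; counted this way, where the first layer runs in the same tick the sample arrives, the delay works out to number-of-layers minus one ticks ):

```python
# a toy simulation of running one layer per sample tick instead of the whole
# network per sample: same outputs, just a few ticks late.
layers = [lambda v: v + 1, lambda v: v * 2, lambda v: v - 3]  # stand-in layers

def run_all_per_tick(samples):
    out = []
    for s in samples:
        for layer in layers:
            s = layer(s)
        out.append(s)
    return out

def run_one_layer_per_tick(samples):
    regs = [None] * len(layers)  # pipeline registers between layers
    out = []
    for s in samples:
        # update back-to-front so each layer consumes last tick's value
        for i in reversed(range(1, len(layers))):
            regs[i] = None if regs[i - 1] is None else layers[i](regs[i - 1])
        regs[0] = layers[0](s)
        out.append(regs[-1])
    return out

a = run_all_per_tick(range(8))
b = run_one_layer_per_tick(range(8))
assert b[len(layers) - 1:] == a[:-(len(layers) - 1)]  # same values, delayed
```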
another way to run a bigger network is to continue using the same naive MULT18X18D dsp allocation but just use an intermediate layer twice; e.g. if you have a network input -> conv0 -> output you can get extra depth by running input -> conv0 -> conv0 -> output instead. you lose a bit of representation power, since the same layer needs to model two layers, but sometimes it's worth it. in this model we'd get extra depth without having to worry about more allocation, and we've got plenty of headroom to do more compute.
the work in progress model i've been tinkering with for the power of two quantisation is the following...
_________________________________________________________________
Layer (type)              Output Shape      Param #
_________________________________________________________________
input_1 (InputLayer)      [(None, 640, 4)]  0
qconv_0_qb (QConv1D)      (None, 640, 8)    136
qrelu_0 (QActivation)     (None, 640, 8)    0
qconv_1_qb (QConv1D)      (None, 640, 8)    264
qrelu_1 (QActivation)     (None, 640, 8)    0
qconv_1_1a_po2 (QConv1D)  (None, 640, 16)   144
qconv_1_1b_po2 (QConv1D)  (None, 640, 8)    136
qrelu_1_1 (QActivation)   (None, 640, 8)    0
qconv_1_2a_po2 (QConv1D)  (None, 640, 16)   144
qconv_1_2b_po2 (QConv1D)  (None, 640, 8)    136
qrelu_1_2 (QActivation)   (None, 640, 8)    0
qconv_2_qb (QConv1D)      (None, 640, 4)    132
_________________________________________________________________
Trainable params: 1,092
_________________________________________________________________
the _qb layers are the normal quantised bits layers which have fixed point weights and use MULT18X18D units for inference. the _po2 layers have power-of-two weights and use just shift operations for inference.
the output quality is the same and it uses fewer MULT18X18D units, but it doesn't quite fit :)
Info: Device utilisation:
Info:       TRELLIS_IO:    11/  365     3%
Info:             DCCA:     5/   56     8%
Info:           DP16KD:    12/  208     5%
Info:       MULT18X18D:    70/  156    44%
Info:          EHXPLLL:     1/    4    25%
Info:       TRELLIS_FF: 37356/83640   44%
Info:     TRELLIS_COMB: 85053/83640  101%  close!!
Info:     TRELLIS_RAMW:    96/10455    0%
i spent a bunch of time trying various combos of reusing the modules but never had a design that would meet timing for the FPGA :/
i still feel i'm doing something wrong here, and might come back to it.
waveforms generated by the model across the embedding space
the above images show the range of waveforms generated by the model across the two dimensional embedding space. of note...
lets look at some examples!
( make sure subtitles are enabled! )
an example of a triangle core wave as input and a manual transition of the embedding values between the corners of the 2d space.
modulating the embedding x value at audio rates makes for some great timbres! the FPGA and the eurorack pmod have no problems handling this.
since the model was trained only on a triangle wave if you give it something discontinuous, like a ramp or square, it glitches! :)
and if you sequence things, add an envelope to the embedding point & stick in some 909 kick & hats.... what do you have? neural net techno! doff. doff. some delay on the hats and clap, but no effects on the oscillator. ( mixed the hats too loud as well, story of my life )
there is a lot of room to optimise for a larger network. see the issues on github
the code for training and verilog simulation is in this github repo
whereas the code representing this as core on the eurorack pmod is in this github repo on a branch
the electro smith daisy patch is a eurorack module made for prototyping; it provides a powerful audio focussed microcontroller setup (the daisy 'seed' with arm cortex-m7 running at 480MHz) along with all the connectivity required to be a eurorack module.
pretty much the first thing i thought when i saw one was; "could it run a neural net?" microcontrollers are fast, but so too is audio rate.
after a bit of research i found there are a couple of people who have done some real time audio effects processing on the daisy seed; e.g. guitarML's neural seed
all the examples i found though were quite small recurrent models and i was keen to give wavenet a go instead ( i'm a big fan of causal dilated convolutions )
the wavenet architecture is a 1D convolutional network designed to operate on timeseries data. it is composed of two key structures;
usually it's just dilated convolutions stacked but for my variant i also include one extra 1x1 conv between each dilation; a 1x1 doesn't break any of the dilation structure wavenet needs and including another non linearity is almost always a good thing.
the first step was to write a firmware that just recorded some audio and/or control voltages into buffers and then streamed them out over a serial connection. it's kinda clumsy, and i'm sure there's a cleaner way, but it works without having to mess around too much. the code is datalogger_firmware/
from there a wavenet could be trained ( see keras_model.py ) and it was trivial to quantise and export to a c++ lib for the microcontroller using edge impulse's bring-your-own-model
from there i got a bit stuck though integrating tensorflow with the daisy optimised arm_math.h stuff ( Makefiles and such aren't my expertise ) so instead i thought i'd use this chance to write my own inference. convolutions don't need much, it's all just a bunch of matrix math after all.
while poking around doing this it suddenly occurred to me that you could do heavy, heavy caching of the convolution activations! i was super excited, thinking i was doing something novel, and it wasn't until i was actually finished that i found out someone else had already thought of basically the same idea, what they call fast-wavenet :/
oh well, it was fun to discover independently i suppose, and i ended up implementing things a bit differently.
so let's walk through the caching optimisation...
consider the following wavenet like network; it has 16 inputs being integrated thru 3 layers to a single prediction
if we consider a sliding window input at time step 4 we see that the output of node 0 ( the node with the red 0 ) will be the processed values from [0, 1, 2, 3]
a bit later on an interesting things happens at time step 8; node 1 now gets the same inputs that node 0 had 4 steps ago.
so we don't need to calculate it! when processing a timeseries we can see that the node 1 output just lags node 0 by 4 steps, and nodes 2 and 3 lag another 4 each. as long as we've got the memory for caching we can store all these in a circular buffer.
and this whole thing is stackable! in fact we only ever need to run the right hand side convolutions, as long as we have the memory to cache.
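the lagging trick is easy to check in miniature with numpy ( single K=4 filter, nodes laid out every 4 steps over a sliding 16 sample window; the names and node ordering here are mine, not the post's diagram ): the newest node's output from 4 steps ago is exactly the second-newest node's output now, so it can be served from a circular buffer instead of recomputed.

```python
import numpy as np

# toy check of the lagging trick for a K=4 layer over a sliding 16 sample window
rng = np.random.default_rng(0)
w = rng.standard_normal(4)        # a single K=4 filter
stream = rng.standard_normal(64)  # incoming sample stream

def layer_nodes(window):
    # window holds the latest 16 samples; four K=4 nodes, oldest to newest
    return [window[i:i + 4] @ w for i in range(0, 16, 4)]

for t in range(20, 60):
    nodes_now = layer_nodes(stream[t - 16:t])
    nodes_then = layer_nodes(stream[t - 20:t - 4])  # the window 4 steps earlier
    # the newest node's output 4 steps ago == the second-newest node's output now
    assert np.isclose(nodes_now[2], nodes_then[3])
```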
one big win for this on the daisy is that it can all run fast and small enough that we can use full float32 math everywhere. hoorah!
in terms of coding the whole thing becomes an exercise in
the following model has an input of 3 values => a receptive field of 256 steps, and a single output. it runs on the daisy using the above caching technique at 48kHz @ 88% CPU ( ~20 microseconds per inference ). we'll describe what the 3 input values are in a bit. it's the model we'll use in the examples at the end.
cNa - denotes the dilated convolutions
cNb - denotes the 1x1 convolutions that follow cNa
F   - number of filters
K   - kernel size; either 4 for cNa or 1 for cNb
D   - dilation; K^layer# for cNa, or 1 for cNb
P   - padding; 'causal' or the layer dft 'valid'
____________________________________________________________________________________________
Layer (type)        Output Shape       Par#   Conv1D params
============================================================================================
input (InputLayer)  [(None, 256, 3)]   0
c0a (Conv1D)        (None, 64, 4)      52     F=4, K=4, D=1, P=causal
c0b (Conv1D)        (None, 64, 4)      20     F=4, K=1
c1a (Conv1D)        (None, 16, 4)      68     F=4, K=4, D=4, P=causal
c1b (Conv1D)        (None, 16, 4)      20     F=4, K=1
c2a (Conv1D)        (None, 4, 4)       68     F=4, K=4, D=16, P=causal
c2b (Conv1D)        (None, 4, 4)       20     F=4, K=1
c3a (Conv1D)        (None, 1, 8)       136    F=8, K=4, D=64, P=causal
c3b (Conv1D)        (None, 1, 8)       72     F=8, K=1
y_pred (Conv1D)     (None, 1, 1)       13     F=1, K=1
============================================================================================
Total params: 465
____________________________________________________________________________________________
as a side note, the daisy can actually run at 96kHz but at this rate i could only run a smaller [c0a, c0b, c1a, c2b] model with just 2 filters each. it runs, but i couldn't get it to train on the examples i show below, so i didn't use it. shame because naming the blog post "at (almost) 100,000 inferences a second" would have had a nicer ring to it :D
it can also run slower, e.g. 32kHz, which allows either more filters per step, or even more depth => large receptive field.
but these combos demonstrate an interesting set of trade offs we have between
for training i just use keras.layers.Conv1D everywhere but then exported to the device in two passes; a python prototype and then the c++ code for the device.
the final code on the daisy uses the cmsis-dsp library for all the matrix math. only three pieces end up being used though; arm_mat_init_f32 for making the matrix structures, arm_mat_mult_f32 for the actual matrix multiplications and arm_add_f32 for the kernel accumulation and adding biases here and there.
to be honest none of these were benchmarked and i just assume they are faster than if the multiplications were rolled out by hand (???)
the first pass was to get the inference working as a prototype in python. the cmsisdsp python lib was used ( a pure python api equivalent ) and the code is cmsisdsp_py_version/.
having cmsisdsp was very handy to prototype with, especially as i've not done any cmsis stuff before.
the code for the daisy is under inference_firmware/ and is a port of the python version. it's where i probably spent the most amount of time, especially block.h. it has 4 main parts; left_shift_buffer.h to handle the shifted InputLayer, block.h to handle the sequential running of each cNa and cNb pair, rolling_cache.h which represents the circular buffer used for activation lag, and regression.h which is just a simple y=mx+b using the weights from the final 1x1 Conv1D regression used during training.
it all ends up being pretty straightforward but i did iterate a bit; any code involving memcpy and pointer arithmetic needs careful attention.
the "best" bit of the code is where the initial keras model is trained in a python notebook and then exported to be used in the c++ code by having python code construct a model_defn.h with a bunch of print statements :/ it's actually written to /tmp/model_defn.h no less! what a hack, lol. may god have mercy on my soul.
for a demo project i wanted to make a little waveshaper; something that takes one waveform and outputs another.
so 5 waveforms were collected from another eurorack oscillator; a triangle, sine, saw, square and a weird zigzag thing. during data collection random voltages were sampled to change the oscillator's frequency.
to convert to actual training dataset the triangle wave is used as input with one of the other waves acting as output. which output is decided by including two additional selector variables in the inputs. these variables are only ever {0, 1} during training but can act (very loosely) as an embedding since any float value can be passed in the range (0, 1) when running on the device.
input | output |
---|---|
(triangle, 0, 0) | sine |
(triangle, 0, 1) | ramp |
(triangle, 1, 0) | square |
(triangle, 1, 1) | zigzag |
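the construction of these training pairs can be sketched like so ( all names here are mine, and the zigzag below is a synthetic stand-in; the real one was recorded from hardware ):

```python
import numpy as np

# building the waveshaper training pairs: the input is always the triangle wave
# plus two constant selector channels; the target is whichever wave that corner
# of the selector space picks.
n = np.arange(256)
triangle = 2 * np.abs(2 * (n / 256) - 1) - 1
waves = {
    (0, 0): np.sin(2 * np.pi * n / 256),                    # sine
    (0, 1): 2 * (n / 256) - 1,                              # ramp
    (1, 0): np.sign(np.sin(2 * np.pi * n / 256)),           # square
    (1, 1): 2 * np.abs(2 * ((2 * n) % 256) / 256 - 1) - 1,  # zigzag stand-in
}

def training_pair(selector):
    s0, s1 = selector
    # per-timestep model input is (triangle sample, s0, s1)
    x = np.stack([triangle,
                  np.full_like(triangle, s0),
                  np.full_like(triangle, s1)], axis=-1)
    return x, waves[selector]

x, y = training_pair((0, 1))
print(x.shape)  # (256, 3)
```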
the model was trained for < 1 min ( the model is pretty small, and there's not much variety in the data ).
it does well on held out test data...
but the much more interesting thing is what happens when we input values for x2 and x3 that weren't seen during training.
e.g. what if we choose a point that is 20% between sine (0, 0) and square (1, 0) ? we end up in some weird undefined part of the input space.
these out of training distribution things are always the most interesting to me; values for x2 and x3 in between (0, 1) give a weird sort-of-interpolation. generally these results are never a smooth transition unless there's some aspect of the loss that directs them to be so ( and in this case there isn't ).
we get the classic model hallucination stuff that i've always loved. ( see my older post on hallucinating softmaxs for more info )
we could encourage the model to make full use of the space by including a GAN like discriminator loss. this would be trained on random samples of values for (x2, x3). i've seen this kind of training force the model to be much more consistent for interpolated inputs.
this video of the network actually running gives a better idea of the "interpolation". the green wave shows the input core triangle wave, the blue wave shows the output waveshaped result. the daisy module display shows the values for x2 and x3; with these values controlled by hand from the module on the right.
the video transitions across values of (x2, x3), moving between the corner (0, 1) values. at each corner the frequency of the core triangle input wave was adjusted over a range as well. something i hadn't considered is that when we change the frequency to be outside the range of what was seen during training the model does some weird extrapolation.
the interpolation and extrapolation stuff makes for some weird results, but it all sounds cool :)
further ideas could be...
code on github
here's a recording; check it out!
and here's a pdf of the slides
eurosat/all is a dataset of 27,000 64x64 satellite images taken with 13 spectral bands. each image is labelled with one of ten classes.
for the purpose of classification these 13 aren't equally useful, and the information in them varies across resolutions. if we were designing a sensor we might choose to use different channels in different resolutions.
how can we explore the trade off between mixed resolutions and whether to use a channel at all?
let's start with a simple baseline to see what performance we get. we won't spend too much time on this model, we just want something that we can iterate on quickly.
the simple model shown below trained on a 'training' split with adam hits 0.942 top 1 accuracy on a 2nd 'validation' split in 5 epochs. that'll do for a start.
let's check the effect of including different combos of input channels. we'll do so by introducing a channel mask.
a mask of all ones denotes using all channels and gives our baseline performance
mask | validation accuracy |
[1,1,1,1,1,1,1,1,1,1,1,1,1] | 0.942 |
a mask of all zeros denotes using no channels and acts as a sanity check; it gives the performance of random chance, which is in line with what we expect given the balanced training set. (note: we standardise the input data so that it has zero mean per channel (with the mean and standard deviation parameters fit against training data only) which is what gives us this effect)
mask | validation accuracy |
[1,1,1,1,1,1,1,1,1,1,1,1,1] | 0.942 |
[0,0,0,0,0,0,0,0,0,0,0,0,0] | 0.113 |
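the masking mechanics are trivial; a numpy sketch ( with random stand-in data of the same (64, 64, 13) shape ) showing why the standardisation matters: once every channel is zero mean, masking a channel to 0 is the same as replacing it with its mean, i.e. "no information".

```python
import numpy as np

# channel masking over standardised input ( random stand-in data )
rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=2.0, size=(32, 64, 64, 13))

mean = x.mean(axis=(0, 1, 2))  # in practice: fit on the training split only
std = x.std(axis=(0, 1, 2))
x_std = (x - mean) / std

mask = np.ones(13)
mask[11] = 0                   # drop channel 11
x_masked = x_std * mask        # broadcasts over the trailing channel axis

assert np.allclose(x_masked[..., 11], 0)
```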
but what about if we drop just one channel? i.e. a mask of all ones except for a single zero.
channel to drop | validation accuracy |
0 | 0.735 |
1 | 0.528 |
2 | 0.661 |
3 | 0.675 |
4 | 0.809 |
5 | 0.724 |
6 | 0.749 |
7 | 0.634 |
8 | 0.874 |
9 | 0.934 |
10 | 0.593 |
11 | 0.339 |
12 | 0.896 |
from this we can see that the performance hit we get from losing a single channel is not always the same. in particular consider channel 11; if we drop that channel we take a huge hit! does that mean that if we keep only channel 11 we should get reasonable performance?
mask | validation accuracy |
[1,1,1,1,1,1,1,1,1,1,1,1,1] | 0.942 |
[0,0,0,0,0,0,0,0,0,0,0,0,0] | 0.113 |
[0,0,0,0,0,0,0,0,0,0,0,1,0] (keep just 11) | 0.260 |
bbbzzzttt (or other appropriate annoying buzzer noise). channel 11 contributes to the classification but it's not being used independently. in general this is exactly the behaviour we want from a neural network, but how can we explore the effect of removing this dependence?
consider using a dropout idea, just with input channels instead of intermediate nodes.
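a rough sketch of what this input-channel dropout might look like ( a hypothetical helper, not the post's actual training code; for simplicity it drops each channel for the whole batch ):

```python
import numpy as np

def channel_dropout(x, rng, drop_prob=0.5):
    # with probability drop_prob, replace an entire input channel with 0s
    # x: (batch, height, width, channels)
    keep = (rng.random(x.shape[-1]) >= drop_prob).astype(x.dtype)
    return x * keep

rng = np.random.default_rng(42)
x = np.ones((2, 64, 64, 13), dtype=np.float32)
dropped = channel_dropout(x, rng)
```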
what behaviour do we get if we drop channels out during training? i.e. with 50% probability we replace an entire input channel with 0s?
things take longer to train and we get a slight hit in accuracy...
dropout? | validation accuracy |
no | 0.942 |
yes | 0.934 |
...but now when we mask out one channel at a time we don't get a big hit for losing any particular one.
channel to drop | no dropout | with dropout |
0 | 0.735 | 0.931 |
1 | 0.528 | 0.931 |
2 | 0.661 | 0.936 |
3 | 0.675 | 0.935 |
4 | 0.809 | 0.937 |
5 | 0.724 | 0.934 |
6 | 0.749 | 0.931 |
7 | 0.634 | 0.927 |
8 | 0.874 | 0.927 |
9 | 0.934 | 0.927 |
10 | 0.593 | 0.927 |
11 | 0.339 | 0.933 |
12 | 0.896 | 0.937 |
now that we have a model that is robust to any combo of channels, what do we see if we use a simple genetic algorithm (GA) to evolve the channel mask used with this pre-trained network? a mask that uses all channels will be the best right? right?
we'll evolve the mask against the network trained above, basing fitness on its performance on a 3rd "ga_train" split, using the inverse loss as the fitness function.
amusingly the GA finds that the mask [1,1,0,1,0,0,0,1,1,0,1,0,1]
does marginally better than using all channels, while only using half of them!
mask | split | accuracy |
[1,1,1,1,1,1,1,1,1,1,1,1,1] (all) | ga_validate | 0.934 |
[1,1,0,1,0,0,0,1,1,0,1,0,1] (ga) | ga_validate | 0.936 |
important note: we can imagine the best performance overall would come from having the GA evolve not the channels to use with this model, but the channels to use when training from scratch. that would though require a lot more model training, basically a full training cycle per fitness evaluation :( with the approach described here we only train a single model and the GA just runs inference.
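to make this concrete, here's a toy sketch of the kind of bit-string GA that could evolve a mask ( hypothetical hyperparameters and helper names; in the real setup the fitness function would wrap inference on the ga_train split ):

```python
import numpy as np

def evolve_mask(fitness_fn, n_channels=13, pop_size=20, generations=30, seed=0):
    # truncation selection + single point crossover + bitflip mutation,
    # keeping the elite half unchanged each generation
    rng = np.random.default_rng(seed)
    pop = rng.integers(0, 2, size=(pop_size, n_channels))
    for _ in range(generations):
        scores = np.array([fitness_fn(m) for m in pop])
        elite = pop[np.argsort(-scores)[: pop_size // 2]]
        children = []
        while len(children) < pop_size - len(elite):
            a, b = elite[rng.integers(len(elite), size=2)]
            cut = int(rng.integers(1, n_channels))
            child = np.concatenate([a[:cut], b[cut:]])
            flip = rng.random(n_channels) < 0.1  # mutation
            children.append(np.where(flip, 1 - child, child))
        pop = np.vstack([elite] + children)
    scores = np.array([fitness_fn(m) for m in pop])
    return pop[int(np.argmax(scores))]

# smoke test against a known target mask; the real fitness would be 1/loss
target = np.array([1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1])
best = evolve_mask(lambda m: int((m == target).sum()))
```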
taking the idea of channel selection a step further, what if we got the GA to not only decide whether to use a channel or not, but also what resolution it should be in?
consider some example images across resolutions....
example images (just RGB channels shown): orig x64 | x32 | x16 | x8
we could then weight the use of a channel based on resolution; the higher the resolution the more the channel "costs" to use, with not using the channel at all being "free".
to support this we can change the GA to represent members not as a string of {0, 1}s but instead a sequence of {0, x8, x16, x32, x64} values per channel where these represent...
resolution | description | channel cost |
x64 | use original (64, 64) version of input | 0.8 |
x32 | use a 1/2 res (32, 32) version of input | 0.4 |
x16 | use a 1/4 res (16, 16) version of input | 0.2 |
x8 | use a 1/8 res (8, 8) version of input | 0.1 |
0 | don't use channel | 0 |
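the cost of a candidate is then just a sum over the table above; something like this hypothetical helper:

```python
# channel cost lookup matching the table above; 0 means "don't use channel"
COST = {"x64": 0.8, "x32": 0.4, "x16": 0.2, "x8": 0.1, 0: 0.0}

def channel_cost(member):
    # member: one resolution choice per channel
    return sum(COST[c] for c in member)

# e.g. the cost of the solution the GA eventually evolves
cost = channel_cost(["x16", "x64", "x64", "x16", "x32", 0, "x8",
                     "x64", "x8", 0, "x8", 0, "x32"])
```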
the change in the encoding of our GA is trivial, just 5 possible values per channel instead of 2, but before we look at that: how do we change our network?
we can do it without having to add too many extra parameters by using the magic of fully convolutional networks :)
notice how the main trunk of our first network was a series of 2d convolutions followed by a global spatial mean. such a network can take input at any of the resolutions we need, so we can simply reuse the same trunk multiple times!
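to see why this works, here's a toy numpy sketch ( not the post's actual model ): a conv trunk followed by a global spatial mean produces a fixed-size output whose shape depends only on the number of output channels, not the input resolution, so one set of weights serves x8 through x64 inputs:

```python
import numpy as np

def conv2d_same(x, w):
    # naive 'SAME' convolution; x: (H, W, Cin), w: (k, k, Cin, Cout)
    k = w.shape[0]
    pad = k // 2
    xp = np.pad(x, ((pad, pad), (pad, pad), (0, 0)))
    H, W, _ = x.shape
    out = np.zeros((H, W, w.shape[-1]), dtype=x.dtype)
    for i in range(H):
        for j in range(W):
            out[i, j] = np.tensordot(xp[i:i + k, j:j + k], w, axes=3)
    return out

def trunk(x, w):
    # conv + relu + global spatial mean: output size is independent of
    # the input resolution
    return np.maximum(conv2d_same(x, w), 0.0).mean(axis=(0, 1))

w = np.random.default_rng(0).normal(size=(3, 3, 13, 8)).astype(np.float32)
out_x64 = trunk(np.ones((64, 64, 13), np.float32), w)
out_x8 = trunk(np.ones((8, 8, 13), np.float32), w)
```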
so we can have our network...
note: try as i might i can't get steps 2 to 4 to run parallelised in a pmap. i asked on github about it and it looks to be something you can't do at the moment.
when we consider channel cost vs loss there is no single best solution; it's a classic example of a pareto front where we see a tradeoff between channel_cost and loss.
consider this sampling of 1,000 random channel masks...
the GA needs to operate with a fitness that's a single scalar; for now we just use a simple combo of (1.0 / loss) - channel_cost
running with this fitness function we evolve the solution
[x16, x64, x64, x16, x32, ignore, x8, x64, x8, ignore, x8, ignore, x32]
it's on the pareto front, as we'd hope, and it's interesting that it includes a mix of resolutions including ignoring 3 channels completely :)
different mixings of loss and channel_cost would result in different GA solutions along the front
all the code is on github
a classic piece of advice i hear for people using tpus is that they should "try using a bigger batch size".
this got me thinking; i wonder how big a batch size i could reasonably use? how would the optimisation go? how fast could i get things?
let's train a model on the eurosat/rgb dataset. it's a 10 way classification problem on 64x64 images
with a training split of 80% we have 21,600 training examples. we'll use another 10% for validation (2,700 images) and just not use the final 10% test split #hack
for the model we'll use a simple stack of convolutions with channel sizes 32, 64, 128 and 256, a stride of 2 for spatial reduction all with gelu activation. after the convolutions we'll do a simple global spatial pooling, a single 128d dense layer with gelu and then a 10d logit output. a pretty vanilla architecture of ~400K params. nothing fancy.
a v3-32 tpu pod slice is 4 hosts, each with 8 tpu devices.
21,600 training examples total => 5,400 examples per host => 675 examples per device.
this number of images easily fits on a device. great.
now usually augmentation is something we do randomly per batch, but for this hack we're interested in seeing how big a batch we can run. so why not fill out the dataset a bit by just running a stack of augmentations before training?
for each image we'll do 90, 180 and 270 deg rotations along with left/right flips for a total of 8 augmented images for each original image. e.g.....
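a minimal numpy sketch of this 8x deterministic augmentation ( the post's actual version uses nested pmaps and vmaps; this is just the idea ):

```python
import numpy as np

def augment_x8(img):
    # the four rotations (0, 90, 180, 270 degrees), each with and without
    # a left/right flip, gives 8 variants per original image
    variants = []
    for k in range(4):
        r = np.rot90(img, k)
        variants.append(r)
        variants.append(np.fliplr(r))
    return variants

variants = augment_x8(np.arange(16).reshape(4, 4))
```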
this gives us 172,800 images total => 43,200 per host => 5,400 per tpu device, which still fits no problem.
side note: turns out doing this augmentation was one of the most fun parts of this hack :) see this tweet thread for some more info on how i used nested pmaps and vmaps to do it!
one motivation i had for this hack was to compare adam to lamb. i'd seen lamb referred to in the past, would it perform better for this model/dataset size? turns out it does! a simple sweep comparing lamb, adam and sgd shows lamb consistently doing the best. definitely one to add to the tuning mix from now on.
not only does the augmented data fit sharded across devices but we can replicate both the model parameters and the optimiser state as well. this is important for speed since the main training loop doesn't have to do any host/device communication. taking a data parallel approach means the only cross device comms is a gradient psum.
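as a toy illustration of that data parallel structure ( plain numpy standing in for pmap, with hypothetical names ): each replica computes gradients on its own shard, the only communication is averaging the gradients, and every replica then applies the identical update so the parameters stay in sync:

```python
import numpy as np

def loss_grad(w, x, y):
    # gradient of mean squared error for the model y_hat = w * x
    return 2.0 * np.mean((w * x - y) * x)

def data_parallel_step(w_replicas, x_shards, y_shards, lr=0.1):
    grads = [loss_grad(w, x, y) for w, x, y in zip(w_replicas, x_shards, y_shards)]
    g = np.mean(grads)  # stand-in for the cross device gradient psum/pmean
    return [w - lr * g for w in w_replicas]

x = np.linspace(-1.0, 1.0, 32)
y = 3.0 * x
ws = [0.0] * 4  # one parameter replica per "device"
xs, ys = np.split(x, 4), np.split(y, 4)
for _ in range(200):
    ws = data_parallel_step(ws, xs, ys)
# replicas stay identical and converge toward the true slope 3.0
```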
for training we run an inner loop just pumping the params = update(params) step.
an outer loop runs the inner loop 100 times before doing a validation accuracy check.
the inner loop runs at 1.5s for the 100 iterations, and since each iteration is a forward & backwards pass over all 172,800 images across all hosts that's 11M images processed per second. 🔥🔥🔥
at this speed the best result of 0.95 on validation takes 13 outer loops; i.e. all done in under 20s. o_O !!
when reviewing runs i did laugh to see sgd with momentum make a top 10 entry.
new t-shirt slogan: "sgd with momentum; always worth a try"
all the code in hacktastic undocumented form on github
this 4 (and a bit) part tute series starts with jax fundamentals, builds up to describing a data parallel approach to training on a cloud tpu pod slice, and finishes with a tpu pod slice implementation of ensemble nets.... all with the goal of solving 1d y=mx+b
and though it may seem like a bit of overkill it turns out it's a good example to work through so that we can focus on the library support without having to worry about the modelling.
in this first section we introduce some jax fundamentals; e.g. make_jaxpr, grad, jit, vmap & pmap.
colab: 01 pmap jit vmap oh my.ipynb
in part 2 we use the techniques from part 1 to solve y=mx+b in pure jax. we'll also introduce pytrees and various tree_utils for manipulating them.
we run first on a single device and work up to using pmap to demonstrate a simple data parallelism approach. along the way we'll do a small detour to a tpu pod slice to illustrate the difference in a multi host setup.
( note: the experience as described here for a pod slice isn't publicly available yet; but sign up via the JAX on Cloud TPU Interest Form to get more info. see also this JAX on Cloud TPUs (NeurIPS 2020) talk )
colab: 02 y mx b on a tpu.ipynb
next we introduce haiku as a way of defining our model and optax as a library providing standard optimisers. to illustrate their use we'll do a minimal port of our model and training loop to use them.
colab: 03 y mx b in haiku.ipynb
in part 4 we'll reimplement ensemble nets for this trivial model, continuing to do things in a way that supports running on a tpu pod slice.
colab: 04 y mx b haiku ensemble.ipynb
to wrap up we acknowledge that though tpu pod slices and data parallel approaches are fun we could have just solved this in a single calculation using the normal equation... :D
colab: 05 booooooooooooooooring.ipynb
y=mx+b !!!
OOD detection is an often overlooked part of many projects. it's super important to understand when your model is operating on data it's not familiar with!
though this can be framed as the problem of detecting that an input differs from training data, discussed more in this 'data engineering concerns for machine learning products' talk, for this post we're going to look at the problem more from the angle of the model needing to be able to express its own lack of confidence.
a core theme of this post is treating OOD as a function of the output of the model. when i've played with this idea before, the best result i've had came from using entropy as a proxy for confidence. in the past it's not been in the context of OOD directly but instead as a way of prioritising data annotation (i.e. uncertainty sampling for active learning)
using entropy as a stable measure of confidence though requires the model to be well calibrated.
neural nets are sadly notorious for not being well calibrated under the default ways we train these days.
the first approach i was shown for calibrating a neural net is to train your model as normal, then finetune just the final classifier layer using a held out set. the general idea is that this layer is effectively just a logistic regression on features learnt by the rest of the network, and so should come with the better guarantees that logistic regression provides regarding calibration (see this great overview from sklearn). while this has always worked well for me, it does require two training steps, as well as a separate held out set to be managed.
another approach to calibration i revisited this week comes from this great overview paper from 2017 on calibration of modern neural networks. it touches on platt scaling, which i've also used successfully in the past, but includes something even simpler that i didn't really notice until a colleague pointed it out: don't bother tuning the entire last layer, just fit a single temperature rescaling of the output logits, i.e. what can be thought of as a single parameter version of platt scaling. what a great simple idea!
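a quick numpy sketch of why temperature scaling is appealing ( illustrative values; T is the single parameter that gets fit on a calibration set ):

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# dividing logits by a temperature T > 1 softens the distribution (less
# confident) without ever changing which class has the highest probability
logits = np.array([[4.0, 1.0, 0.0]])
p_raw = softmax(logits)
p_scaled = softmax(logits / 1.6)  # T = 1.6 for illustration
```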
the main purpose of this post though is to reproduce some ideas around out of distribution detection and calibration when you use focal loss, as described in this great piece of work, calibrating deep neural networks using focal loss. we'll implement that, plus the other two approaches for comparison.
( as an aside, another great simple way of determining confidence is using an ensemble! i explore this idea a bit in my post on ensemble nets where i train ensembles as a single model using jax vmap )
let's start with cifar10 and split the data up a bit differently than the standard train and test splits. we'll hold out automobile & cat as a "hard" ood set ( i've intentionally chosen these two since i know they cause model confusion as observed in this keras.io tute on metric learning ). the remaining classes are split across train and 0.1 for each of validation, calibration and test. we'll additionally generate an "easy" ood set which will just be random images.
in terms of how we'll use these splits...
- train will be the main dataset to train against.
- validation will be used during training for learning rate stepping, early stopping, etc.
- calibration will be used when we want to tune a model in some way for calibration.
- test will be used for final held out analysis.
- ood_hard and ood_easy will be used as the representative out of distribution sets.
example images: in distribution, out of distribution (hard), out of distribution (easy)
we'll work through the following models...
- model_1: trained on train, tuning against validate.
- model_2: made by finetuning the last layer of model_1 against calibrate; i.e. the logistic regression approach to calibration.
- model_3: made by fitting just a scalar temperature on the last layer of model_1; i.e. the temperature scaling approach.
- the focal loss models: trained on train using validate for tuning.
at each step we'll check the entropy distributions across the various splits, including the easy and hard ood sets
for a baseline we'll use a resnet18 objax model trained against train using validate for simple early stopping.
# define model
model = objax.zoo.resnet_v2.ResNet18(in_channels=3, num_classes=10)

# train against all model vars
trainable_vars = model.vars()
$ python3 train_basic_model.py --loss-fn cross_entropy --model-dir m1
learning_rate 0.001 validation accuracy 0.372 entropy min/mean/max 0.0005 0.8479 1.9910
learning_rate 0.001 validation accuracy 0.449 entropy min/mean/max 0.0000 0.4402 1.9456
learning_rate 0.001 validation accuracy 0.517 entropy min/mean/max 0.0000 0.3566 1.8013
learning_rate 0.001 validation accuracy 0.578 entropy min/mean/max 0.0000 0.3070 1.7604
learning_rate 0.001 validation accuracy 0.680 entropy min/mean/max 0.0000 0.3666 1.7490
learning_rate 0.001 validation accuracy 0.705 entropy min/mean/max 0.0000 0.3487 1.6987
learning_rate 0.001 validation accuracy 0.671 entropy min/mean/max 0.0000 0.3784 1.8276
learning_rate 0.0001 validation accuracy 0.805 entropy min/mean/max 0.0000 0.3252 1.8990
learning_rate 0.0001 validation accuracy 0.818 entropy min/mean/max 0.0000 0.3224 1.7743
learning_rate 0.0001 validation accuracy 0.807 entropy min/mean/max 0.0000 0.3037 1.8580
we can check the distribution of entropy values of this trained model against our various splits
some observations...
recall that a higher entropy => a more uniform distribution => less prediction confidence. as such our goal is to have the entropy of the ood sets as high as possible without dropping the test set too much.
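as a sketch, the entropy being plotted could be computed from the logits like so ( hypothetical helper name ):

```python
import numpy as np

def prediction_entropy(logits):
    # entropy of the softmax distribution over classes; the maximum value
    # is log(num_classes), reached by a perfectly uniform prediction
    z = logits - logits.max(axis=-1, keepdims=True)
    p = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    return -(p * np.log(p + 1e-12)).sum(axis=-1)

confident = prediction_entropy(np.array([[10.0, 0.0, 0.0, 0.0]]))
uncertain = prediction_entropy(np.array([[1.0, 1.0, 1.0, 1.0]]))
```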
as mentioned before, the main approach i'm familiar with is retraining the classifier layer. we'll start with model_1 and use the so-far-unseen calibration set for fine tuning it.
# define model
model = objax.zoo.resnet_v2.ResNet18(in_channels=3, num_classes=10)

# restore from model_1
objax.io.load_var_collection('m1/weights.npz', model.vars())

# train against last layer only
classifier_layer = model[-1]
trainable_vars = classifier_layer.vars()
$ python3 train_model_2.py --input-model-dir m1 --output-model-dir m2
learning_rate 0.001 calibration accuracy 0.808 entropy min/mean/max 0.0000 0.4449 1.8722
learning_rate 0.001 calibration accuracy 0.809 entropy min/mean/max 0.0002 0.5236 1.9026
learning_rate 0.001 calibration accuracy 0.808 entropy min/mean/max 0.0004 0.5468 1.9079
learning_rate 0.001 calibration accuracy 0.811 entropy min/mean/max 0.0005 0.5496 1.9214
learning_rate 0.001 calibration accuracy 0.812 entropy min/mean/max 0.0001 0.5416 1.9020
learning_rate 0.001 calibration accuracy 0.812 entropy min/mean/max 0.0002 0.5413 1.9044
learning_rate 0.001 calibration accuracy 0.811 entropy min/mean/max 0.0003 0.5458 1.8990
learning_rate 0.001 calibration accuracy 0.814 entropy min/mean/max 0.0001 0.5472 1.9002
learning_rate 0.001 calibration accuracy 0.812 entropy min/mean/max 0.0001 0.5455 1.9124
learning_rate 0.001 calibration accuracy 0.813 entropy min/mean/max 0.0001 0.5503 1.9092
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0001 0.5485 1.9011
learning_rate 0.0001 calibration accuracy 0.814 entropy min/mean/max 0.0001 0.5476 1.9036
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0001 0.5483 1.9051
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0000 0.5479 1.9038
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0000 0.5482 1.9043
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0000 0.5482 1.9059
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0000 0.5479 1.9068
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0000 0.5480 1.9058
learning_rate 0.0001 calibration accuracy 0.812 entropy min/mean/max 0.0000 0.5488 1.9040
learning_rate 0.0001 calibration accuracy 0.814 entropy min/mean/max 0.0000 0.5477 1.9051
if we compare (top1) accuracy of model_1 vs model_2 we see a slight improvement in model_2, attributable to the slight extra training i suppose (?)
$ python3 calculate_metrics.py --model m1/weights.npz
train accuracy 0.970
validate accuracy 0.807
calibrate accuracy 0.799
test accuracy 0.803
$ python3 calculate_metrics.py --model m2/weights.npz
train accuracy 0.980
validate accuracy 0.816
calibrate accuracy 0.814
test accuracy 0.810
more importantly though; how do the entropy distributions look?
some observations...
in model_3 we create a simple single parameter layer that represents rescaling, append it to the pretrained resnet and train just that layer.
class Temperature(objax.module.Module):
    def __init__(self):
        super().__init__()
        self.temperature = objax.variable.TrainVar(jnp.array([1.0]))

    def __call__(self, x):
        return x / self.temperature.value

# define model
model = objax.zoo.resnet_v2.ResNet18(in_channels=3, num_classes=10)

# restore from model_1
objax.io.load_var_collection('m1/weights.npz', model.vars())

# add a temp rescaling layer
temperature_layer = layers.Temperature()
model.append(temperature_layer)

# train against just this layer
trainable_vars = temperature_layer.vars()
$ python3 train_model_3.py --input-model-dir m1 --output-model-dir m3
learning_rate 0.01 temp 1.4730 calibration accuracy 0.799 entropy min/mean/max 0.0006 0.5215 1.9348
learning_rate 0.01 temp 1.5955 calibration accuracy 0.799 entropy min/mean/max 0.0013 0.5828 1.9611
learning_rate 0.01 temp 1.6118 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5911 1.9643
learning_rate 0.01 temp 1.6162 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5933 1.9652
learning_rate 0.01 temp 1.6118 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5911 1.9643
learning_rate 0.01 temp 1.6181 calibration accuracy 0.799 entropy min/mean/max 0.0015 0.5943 1.9655
learning_rate 0.01 temp 1.6126 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5915 1.9645
learning_rate 0.01 temp 1.6344 calibration accuracy 0.799 entropy min/mean/max 0.0016 0.6025 1.9687
learning_rate 0.01 temp 1.6338 calibration accuracy 0.799 entropy min/mean/max 0.0016 0.6022 1.9686
learning_rate 0.01 temp 1.6018 calibration accuracy 0.799 entropy min/mean/max 0.0013 0.5860 1.9623
learning_rate 0.001 temp 1.6054 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5878 1.9630
learning_rate 0.001 temp 1.6047 calibration accuracy 0.799 entropy min/mean/max 0.0013 0.5874 1.9629
learning_rate 0.001 temp 1.6102 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5903 1.9640
learning_rate 0.001 temp 1.6095 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5899 1.9638
learning_rate 0.001 temp 1.6112 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5908 1.9642
learning_rate 0.001 temp 1.6145 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5924 1.9648
learning_rate 0.001 temp 1.6134 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5919 1.9646
learning_rate 0.001 temp 1.6144 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5924 1.9648
learning_rate 0.001 temp 1.6105 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5904 1.9640
learning_rate 0.001 temp 1.6151 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5927 1.9650
some observations...
- accuracy is unchanged from model_1; this is since the temp rescaling never changes the ordering of predictions. ( just this property alone could be enough reason to use this approach in some cases! )
- it compares favourably to model_2 in terms of raising the entropy of ood !
not bad for fitting a single parameter :D
as a final experiment we'll use focal loss for training instead of cross entropy
note: i rolled my own focal loss function which is always dangerous! it's not impossible (i.e. it's likely) i've done something wrong in terms of numerical stability; please let me know if you see a dumb bug :)
import jax
import jax.numpy as jnp

def focal_loss_sparse(logits, y_true, gamma=1.0):
    # log probability of the true class for each example
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    log_probs = log_probs[jnp.arange(len(log_probs)), y_true]
    probs = jnp.exp(log_probs)
    # the (1 - p)**gamma factor downweights well classified examples
    elementwise_loss = -1 * ((1 - probs)**gamma) * log_probs
    return elementwise_loss
note that focal loss includes a gamma. we'll trial 1.0, 2.0 and 3.0 in line with the experiments mentioned in calibrating deep neural networks using focal loss
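to see the downweighting effect concretely ( a numpy re-derivation for illustration, separate from the jax code above ): with gamma=0 focal loss is exactly cross entropy, and raising gamma shrinks the loss on confidently-correct examples far more than on confidently-wrong ones:

```python
import numpy as np

def focal_loss_np(logits, y_true, gamma):
    # numerically stable log softmax, then pick out the true class
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    lp = log_probs[np.arange(len(logits)), y_true]
    return -((1.0 - np.exp(lp)) ** gamma) * lp

logits = np.array([[3.0, 0.0], [0.0, 3.0]])
y = np.array([0, 0])  # example 0 confidently right, example 1 confidently wrong
ce = focal_loss_np(logits, y, gamma=0.0)  # reduces to plain cross entropy
fl = focal_loss_np(logits, y, gamma=2.0)
```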
$ python3 train_basic_model.py --loss-fn focal_loss --gamma 1.0 --model-dir m4_1
learning_rate 0.001 validation accuracy 0.314 entropy min/mean/max 0.0193 1.0373 2.0371
learning_rate 0.001 validation accuracy 0.326 entropy min/mean/max 0.0000 0.5101 1.8546
learning_rate 0.001 validation accuracy 0.615 entropy min/mean/max 0.0000 0.6680 1.8593
learning_rate 0.001 validation accuracy 0.604 entropy min/mean/max 0.0000 0.5211 1.8107
learning_rate 0.001 validation accuracy 0.555 entropy min/mean/max 0.0001 0.5127 1.7489
learning_rate 0.0001 validation accuracy 0.758 entropy min/mean/max 0.0028 0.6008 1.8130
learning_rate 0.0001 validation accuracy 0.780 entropy min/mean/max 0.0065 0.6590 1.9305
learning_rate 0.0001 validation accuracy 0.796 entropy min/mean/max 0.0052 0.6390 1.8589
learning_rate 0.0001 validation accuracy 0.790 entropy min/mean/max 0.0029 0.5886 1.9151

$ python3 train_basic_model.py --loss-fn focal_loss --gamma 2.0 --model-dir m4_2
learning_rate 0.001 validation accuracy 0.136 entropy min/mean/max 0.0001 0.2410 1.7906
learning_rate 0.001 validation accuracy 0.530 entropy min/mean/max 0.0000 0.8216 2.0485
learning_rate 0.001 validation accuracy 0.495 entropy min/mean/max 0.0003 0.7524 1.9086
learning_rate 0.001 validation accuracy 0.507 entropy min/mean/max 0.0000 0.6308 1.8460
learning_rate 0.001 validation accuracy 0.540 entropy min/mean/max 0.0000 0.5899 2.0196
learning_rate 0.001 validation accuracy 0.684 entropy min/mean/max 0.0001 0.6189 1.9029
learning_rate 0.001 validation accuracy 0.716 entropy min/mean/max 0.0000 0.6509 1.8822
learning_rate 0.001 validation accuracy 0.606 entropy min/mean/max 0.0001 0.7096 1.8312
learning_rate 0.0001 validation accuracy 0.810 entropy min/mean/max 0.0002 0.6208 1.9447
learning_rate 0.0001 validation accuracy 0.807 entropy min/mean/max 0.0004 0.6034 1.8420

$ python3 train_basic_model.py --loss-fn focal_loss --gamma 3.0 --model-dir m4_3
learning_rate 0.001 validation accuracy 0.288 entropy min/mean/max 0.0185 1.1661 1.9976
learning_rate 0.001 validation accuracy 0.319 entropy min/mean/max 0.0170 1.0925 2.1027
learning_rate 0.001 validation accuracy 0.401 entropy min/mean/max 0.0396 0.9795 2.0361
learning_rate 0.001 validation accuracy 0.528 entropy min/mean/max 0.0001 0.8534 1.9115
learning_rate 0.001 validation accuracy 0.528 entropy min/mean/max 0.0002 0.7405 1.8509
learning_rate 0.0001 validation accuracy 0.748 entropy min/mean/max 0.0380 0.9204 1.9697
learning_rate 0.0001 validation accuracy 0.760 entropy min/mean/max 0.0764 0.9807 1.9861
learning_rate 0.0001 validation accuracy 0.772 entropy min/mean/max 0.0906 0.9646 2.0556
learning_rate 0.0001 validation accuracy 0.781 entropy min/mean/max 0.0768 0.8973 1.9264
learning_rate 0.0001 validation accuracy 0.780 entropy min/mean/max 0.0871 0.8212 1.9814
some observations...
- the focal loss models sit closer to model_2 than model_1 in terms of calibration.
- they raise the entropy on the ood data, but with a drop in overall validation accuracy. maybe the separation isn't as much as i'd like, but happy to experiment more given this result.
interesting!! focal loss is definitely going in my "calibration toolkit" :D
recall the distribution plot of entropies from the model_3 experiment
we can eyeball that the entropy values are low for the in-distribution sets (train, validate, calibrate and test) and high for the two ood sets (easy and hard), but how do we actually turn this difference into an in-distribution classifier? ( recall this is not a standard supervised learning problem; we only have the in-distribution instances to train against and nothing from the ood sets at training time )
amusingly i spent a bunch of time hacking around with all sorts of density estimation techniques but it turned out the simplest baseline i started with was the best :/
the baseline works on a strong, but valid enough, assumption about the entropies: the entropies for in-distribution instances will be lower than those for OOD instances. given this assumption we can just use the simplest "classifier" of all; thresholding
to get a sense of the precision/recall tradeoff of this approach we can do the following...
- fit a (min, max) scaling of the entropies against the train set; this scales all train values to (0, 1)
- apply that scaling to the entropies of test, ood_easy and ood_hard; including (0, 1) clipping
- flip the values with x = 1.0 - x, since we want a low entropy to denote a positive instance of in-distribution
- treat test instances as positives and ood instances as negatives.
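a minimal sketch of this thresholding setup ( hypothetical helper name; the real analysis code is in the linked repo ):

```python
import numpy as np

def fit_entropy_scorer(train_entropies):
    # fit (min, max) scaling on train entropies; score = 1 - scaled entropy,
    # clipped to (0, 1), so a high score denotes "in-distribution"
    lo, hi = float(train_entropies.min()), float(train_entropies.max())
    def score(entropies):
        scaled = np.clip((entropies - lo) / (hi - lo), 0.0, 1.0)
        return 1.0 - scaled
    return score

score = fit_entropy_scorer(np.array([0.1, 0.2, 0.5]))
in_dist = score(np.array([0.15]))  # low entropy, like a test instance
ood = score(np.array([2.0]))       # high entropy, like an ood instance
```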
with this simple approach we get the following ROC, P/R and reliability plots.
i've included the attempts i made to use a kernel density estimator in fit_kernel_density_to_entropy.py, so if you manage to get a result better than this trivial baseline please let me know :)
though this post is about using entropy as a way of measuring OOD there are other ways too.
the simplest way i've handled OOD detection in the past is to include an explicit OTHER label. this has worked in cases where i've had data available to train against, e.g. data that was (somehow) sampled but wasn't of interest for the actual supervised problem at hand. but i can see lots of situations where this wouldn't be possible.
another approach i'm still hacking on is doing density estimation on the inputs; this relates very closely to the diversity sampling ideas for active learning.
these experiments were written in objax, an object oriented wrapper for jax
Mathematics for Machine Learning by Marc Peter Deisenroth, A. Aldo Faisal & Cheng Soon Ong.
this is my personal favorite book on the general math required for machine learning, the way things are described really resonate with me. available as a free pdf but i got a paper copy to support the authors after reading the first half.
Linear Algebra and Learning from Data by Gilbert Strang.
this is gilbert's most recent work. it's really great, he's such a good teacher, and his freely available lectures are even better. it's a shorter text than his other classic intro below with more of a focus on how things are connected to modern machine learning techniques.
Introduction to Linear Algebra by Gilbert Strang.
this was my favorite linear algebra book for a long time before his 'learning from data' came out. this is a larger book with a more comprehensive view of linear algebra.
Think Stats: Probability and Statistics for Programmers by Allen Downey.
this book focuses on practical computation methods for probability and statistics. i got a lot out of working through this one. it's all in python and available for free. ( exciting update! as part of writing this post i've discovered there's a new edition to read!)
Doing Bayesian Data Analysis by John Kruschke
on the bayesian side of things this is the book i've most enjoyed working through. i've only got the first edition which was R and BUGS but i see the second edition is R, JAGS and Stan. it'd be fun i'm sure to work through it doing everything in numpyro. i might do that in all my free time. haha. "free time" hahaha. sob.
The Elements of Statistical Learning by Hastie, Tibshirani and Friedman
this is still one of the most amazing fundamental machine learning books i've ever had. in fact i've purchased this book twice and given it away both times :/ i might buy another copy some time soon, even though it's been freely available to download for ages. an amazing piece of work.
Probabilistic Graphical Models by Daphne Koller & Nir Friedman
this is an epic textbook that i'd love to understand better. i've read a couple of sections in detail but not the entire tome yet.
Pattern Recognition and Machine Learning by Christopher Bishop
this is probably the best overall machine learning text book i've ever read. such a beautiful book and the pdf is FREE FOR DOWNLOAD!!!
Machine Learning: A Probabilistic Perspective by Kevin Murphy
this is my second favorite general theory text on machine learning. i got kevin to sign my copy when he was passing my desk once but someone borrowed it and never gave it back :( so if you see a copy with my name on the spine let me know!
Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow by Aurélien Géron
this is the book i point most people to when they are interested in getting up to speed with modern applied machine learning without too much concern for the theory. it's very up to date (as much as a book can be) with the latest libraries and, most importantly, provides a good overview of not just neural stuff but fundamental scikit-learn as well.
Machine Learning Engineering by Andriy Burkov
a great book focussing on the operations side of running a machine learning system. i'm a bit under half way through the free online version and very likely to buy a physical copy to finish it and support the author. great stuff and, in many ways, a more impactful book than any of the theory books here.
Introduction to Data Mining by Pang-Ning Tan, Michael Steinbach & Vipin Kumar
this is another one that was also on my list from ten years ago and though its section on neural networks is a bit of a chuckle these days there is still a bunch of really great fundamental stuff in this book. very practical and easy to digest. i also see there's a second edition now. i reckon this would complement the "hands on" book above very well.
Speech and Language Processing by Dan Jurafsky & James Martin
still the best overview of NLP there is (IMHO). can't wait to read the 3rd edition which apparently will cover more modern stuff (e.g. transformers) but until then, for the love of god though, please don't be one of those "this entire book is irrelevant now! just fine tune BERT" people :/
Numerical Optimization by Jorge Nocedal & Stephen J. Wright
this book is super hard core and maybe more an operations research book than machine learning. though i've not read it cover to cover the couple of bits i've worked through really taught me a lot. i'd love to understand the stuff in this text better; it's so so fundamental to machine learning (and more)
Deep Learning by Ian Goodfellow
writing a book specifically on deep learning is very dangerous since things move so fast but if anyone can do it, ian can... i think ian's approach to explaining neural networks from the ground up is one of my favorites. i got the first edition hardback but it's free to download from the website.
Probabilistic Robotics by Sebastian Thrun, Wolfram Burgard and Dieter Fox
when i first joined a robotics group i bought a stack of ML/robotics books and this was by far the best. it's good intro stuff, and maybe already dated in places given its age (mine is the 2006 edition), but i still got a bunch from it.
TinyML by Pete Warden & Daniel Situnayake
this was a super super fun book to tech review! neural networks on microcontrollers?!? yes please!
Evolutionary Computation by David Fogel
this is still my favorite book on evolutionary algorithms; i've had this for a loooong time now. i still feel like evolutionary approaches are due for a big big comeback....
the good thing about writing a list is you get people telling you cool ones you've missed :)
the top three i've chosen (that are in the mail) are...
Causal Inference in Statistics by Judea Pearl, Madelyn Glymour & Nicholas P. Jewell
recommended by animesh garg who quite rightly points out the lack of causality coverage in the machine learning books above.
Information Theory, Inference and Learning Algorithms by David MacKay
i've seen this book mentioned a number of times and was most recently recommended by my colleague dane so it's time to get it.
it's been about two years since i first saw the awesome very slow movie player project by bryan boyer. i thought it was such an excellent idea but never got around to buying the hardware to make one. more recently though i've seen a couple of references to the project so i decided it was finally time to make one.
one interesting concern about an eink very slow movie player is the screen refresh. simpler eink screens refresh by doing a full cycle of a screen of white or black before displaying the new image. i hated the idea of an ambient slow player doing this every few minutes as it switched frames, so i wanted to make sure i got a piece of hardware that could do incremental update.
after a bit of shopping around i settled on a 6 inch HD screen from waveshare
it ticks all the boxes i wanted
this screen also supports grey scale, but only with a flashy full cycle redraw, so i'm going to stick to just black and white since it supports the partial redraw.
note: even though the partial redraw is basically instant it does suffer from a ghosting problem; when you draw a white pixel over a black one things are fine, but if you draw black over white, in the partial redraw, you get a slight ghosting of gray that is present until a full redraw :/
so how do you display an image when you can only show black and white? dithering! here's an example of a 384x288 RGB image dithered using PIL's implementation of the Floyd-Steinberg algorithm
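to make the idea concrete, here's a minimal pure-numpy sketch of floyd-steinberg error diffusion (PIL's `Image.convert("1")` does this internally; this version is just for illustration, and the 0.5 threshold and array handling are my own choices, not code from the project):

```python
import numpy as np

def floyd_steinberg(gray):
    """minimal floyd-steinberg error diffusion on a float grayscale
    array in [0, 1]; returns a binary array of 0s and 1s."""
    img = gray.astype(np.float64).copy()
    h, w = img.shape
    out = np.zeros_like(img)
    for y in range(h):
        for x in range(w):
            old = img[y, x]
            new = 1.0 if old >= 0.5 else 0.0
            out[y, x] = new
            err = old - new
            # diffuse the quantisation error to unvisited neighbours
            if x + 1 < w:
                img[y, x + 1] += err * 7 / 16
            if y + 1 < h:
                if x > 0:
                    img[y + 1, x - 1] += err * 3 / 16
                img[y + 1, x] += err * 5 / 16
                if x + 1 < w:
                    img[y + 1, x + 1] += err * 1 / 16
    return out

# a mid-gray patch should dither to roughly 50% white pixels
patch = np.full((64, 64), 0.5)
dithered = floyd_steinberg(patch)
print(dithered.mean())  # ~0.5
```

this is what makes dithering work: per-pixel values are quantised hard, but the local *density* of white pixels tracks the original gray level.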
original RGB vs dithered version |
it makes intuitive sense that you could have small variations in the exact locations of the dots as long as you get the densities generally right. so there's a reasonable question then: how do you dither in such a way that you get a good result, but with minimal pixel changes from a previous frame? (since we're motivated on these screens to change as little as possible)
there are two approaches i see
1) spend 30 minutes googling for a solution that no doubt someone came up with 20 years ago that can be implemented in 10 lines of c running at 1000fps ...
2) .... or train a jax based GAN to generate the dithers with a loss balancing a good dither vs no pixel change. :P
when building a very slow movie player the most critical decision is... what movie to play? i really love the 1979 classic alien, it's such a great dark movie, so i thought i'd go with it. the movie is 160,000 frames so at a playback rate of one frame every 200 seconds it'll take just over a year to finish.
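a quick back-of-envelope check of that runtime, using the frame count and playback rate above:

```python
# sanity check the "just over a year" claim
frames = 160_000
seconds_per_frame = 200
total_seconds = frames * seconds_per_frame
days = total_seconds / (60 * 60 * 24)
print(f"{days:.0f} days")  # → 370 days
```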
note that in this type of problem there is no concern around overfitting. we have access to all data going in and so it's fine to overfit as much as we like; as long as we're minimising whatever our objective is we're good to go.
i started with a unet that maps 3 channel RGB images to a single channel dither.
v1 architecture |
i tinkered a bit with the architecture but didn't spend too much time tuning it. for the final v3 result i ended with a pretty vanilla stack of encoders & decoders (with skip connections connecting an encoder to the decoder at the same spatial resolution). each encoder/decoder block uses a residual-like shortcut around a couple of convolutions. nearest neighbour upsampling gave a nicer result than deconvolutions in the decoder for the v3 result. also, gelu is my new favorite activation :)
for v1 i used a binary cross entropy loss of P(white) per pixel ( since it's what worked well for my bee counting project )
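for reference, per-pixel binary cross entropy is only a few lines of numpy (the function name and clipping epsilon are my own; this is just the standard BCE formula, not code from the project):

```python
import numpy as np

def bce_per_pixel(p_white, target, eps=1e-7):
    """mean binary cross entropy between predicted P(white) and the
    0/1 target dither, averaged over all pixels."""
    p = np.clip(p_white, eps, 1 - eps)  # avoid log(0)
    return float(np.mean(-(target * np.log(p) + (1 - target) * np.log(1 - p))))

target = np.array([[1.0, 0.0], [0.0, 1.0]])
# a confident, correct prediction has low loss...
print(bce_per_pixel(np.array([[0.9, 0.1], [0.1, 0.9]]), target))
# ...while an everywhere-uncertain one sits at log(2) ≈ 0.693
print(bce_per_pixel(np.full((2, 2), 0.5), target))
```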
as always i started by overfitting to a single example to get a baseline feel for capacity required.
v1 overfit result |
when scaling up to the full dataset i switched to training on half resolution images against a patch size of 128. working on half resolution consistently gave a better result than working with the full resolution.
as expected though this model gave us the classic type of problem we see with straight unet style image translation; we get a reasonable sense of the shapes, but no fine details around the dithering.
v1 vanilla unet with upsampling example |
side notes:
v1 vanilla unet with deconvolution example |
for v2 i added a GAN objective in an attempt to capture finer details
v2 architecture |
i started with the original pix2pix objective but reasonably quickly moved to use a wasserstein critic style objective since i've always found it more stable.
the generator (G) was the same as the unet above with the discriminator (D) running patch-based. at this point i also changed the reconstruction loss from a binary objective to just L1. i ended up using batchnorm in D, but not G. to be honest i only did a little bit of manual tuning, i'm sure there's a better result hidden in the hyperparameters somewhere.
so, for this version, the loss for G has two components
1. D(G(rgb))           # fool D
2. L1(G(rgb), dither)  # reconstruct the dither
very quickly (i.e. in < 10mins ) we get a reasonable result that is starting to show some more detail than just the blobby reconstruction.
v2 partial trained eg |
note: if the loss weight of 2) is 0 we degenerate to v1 (which proved a useful intermediate debugging step). at this point i didn't want to tune too much since the final v3 is coming...
for v3 we finally introduce a loss relating the previous frame (which was one of the main intentions of the project in the first place)
now G takes not just the RGB image, but the dither of the previous frame.
v3 architecture |
the loss for G now has three parts
1. D(G(rgb_t1)) => real      # fool D
2. L1(G(rgb_t1), dither_t1)  # reconstruct the dither
3. L1(G(rgb_t1), dither_t0)  # don't change too much from the last frame
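a sketch of how those three terms might combine into a single scalar loss; the weights are made-up placeholders (not values from the post) and the adversarial term assumes the wasserstein-style critic mentioned above:

```python
import numpy as np

def g_loss(d_score, gen_dither, dither_t1, dither_t0,
           w_adv=1.0, w_rec=10.0, w_prev=1.0):
    """combine the three generator objectives; all three weights are
    hypothetical placeholders for illustration:
      1. wasserstein-style adversarial term (push D's score up)
      2. L1 reconstruction against the true dither at t1
      3. L1 against the previous frame's dither at t0"""
    adv = -np.mean(d_score)                        # fool D
    rec = np.mean(np.abs(gen_dither - dither_t1))  # reconstruct the dither
    prev = np.mean(np.abs(gen_dither - dither_t0)) # stay close to last frame
    return w_adv * adv + w_rec * rec + w_prev * prev

# setting w_prev=0 recovers the v2 loss, which is the degenerate
# case the post uses as an intermediate debugging step
```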
normally with a network that takes as input the same thing it's outputting we have to be careful to include things like teacher forcing. but since we don't intend to use this network for any kind of rollouts we can just always feed the "true" dithers in where required. having said that, rolling out the dithers from this network would be interesting :D
the third loss objective, not changing too many pixels from the last frame, works well for generally stationary shots but is disastrous for scene changes :/
consider the following graph for a sequence of frames showing the pixel difference between frames.
when there is a scene change we observe a clear "spike" in pixel diff. my first thought was to look for these and do a full redraw on them. they're very straightforward to find (using a simple z-score based anomaly detector on a sliding window) but this doesn't pick up the troublesome case of a panning shot. in those cases there's no abrupt scene change, but there are a lot of pixels changing, so we still end up seeing a lot of ghosting.
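the sliding-window z-score detector mentioned above might look something like this (the window size and threshold are illustrative guesses, not the values used in the post):

```python
import numpy as np

def spike_indices(diffs, window=10, z_thresh=3.0):
    """flag frames whose pixel-diff is a z-score outlier relative to
    the preceding sliding window of diffs."""
    spikes = []
    for i in range(window, len(diffs)):
        hist = np.asarray(diffs[i - window:i], dtype=float)
        mu, sigma = hist.mean(), hist.std()
        # only flag upward spikes; a zero-variance window can't score
        if sigma > 0 and (diffs[i] - mu) / sigma > z_thresh:
            spikes.append(i)
    return spikes

# a steady-ish sequence of per-frame pixel diffs with a scene
# change (big jump) at index 15
diffs = [100, 102, 99, 101, 100, 98, 103, 100, 101, 99,
         100, 102, 101, 100, 99, 5000, 101, 100, 99, 102]
print(spike_indices(diffs))  # → [15]
```

note this is exactly the failure mode described above: a pan produces a *sustained* elevated diff rather than a one-frame spike, so the window mean rises with it and no z-score outlier is flagged.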
i spent ages tinkering with the best way to approach this before deciding that a simple approach of num_pixels_changed_since_last_redraw > threshold was good enough to decide if a full redraw was required (with a cooldown to ensure we're not redrawing all the time)
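that decision rule could be sketched as something like this (the class name, threshold, and cooldown values are all made up for illustration; the post doesn't give the actual numbers):

```python
class RedrawPolicy:
    """accumulate pixel changes since the last full redraw and trigger
    one when past a threshold, with a cooldown (in frames) between
    redraws. default values are hypothetical placeholders."""

    def __init__(self, threshold=50_000, cooldown=10):
        self.threshold = threshold
        self.cooldown = cooldown
        self.changed_since_redraw = 0
        self.frames_since_redraw = 0

    def step(self, pixels_changed):
        self.changed_since_redraw += pixels_changed
        self.frames_since_redraw += 1
        if (self.changed_since_redraw > self.threshold
                and self.frames_since_redraw >= self.cooldown):
            # reset the counters and do a full (flashy) redraw
            self.changed_since_redraw = 0
            self.frames_since_redraw = 0
            return True
        return False  # partial update is fine

policy = RedrawPolicy(threshold=100, cooldown=3)
# changes accumulate, but the cooldown holds off the redraw until frame 3
print([policy.step(n) for n in (200, 0, 0)])  # → [False, False, True]
```

because the count accumulates *since the last redraw*, a slow pan eventually crosses the threshold too, which is what the spike detector missed.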
the v3 network gets a very good result very quickly; unsurprisingly since the dither at time t0 provided to G is a pretty good estimate of the dither at t1 :) i.e. G can get a good result simply by copying it!
the following scenario shows this effect...
consider three sequential frames, the middle one being a scene change.
at the very start of training the reconstruction loss is dominant and we get blobby outlines of the frame.
but as the contribution from the dither at time t0 kicks in things look good in general, though the frames at the scene change end up being a ghosted mix: an attempt to copy through the old frame along with dithering the new one (depending on the relative strength of the loss terms of G).
so the v3 version generally works and i'm sure with some more tuning i could get a better result but, as luck would have it, i actually find the results from v2 more appealing when testing on the actual eink screen. so even though the intention was to do something like v3 i'm going to end up running something more like v2 (as shown in these couple of examples; though the resolution does it no justice, not to mention the fact the player will run about 5000 times slower than these gifs)
i ran for a few weeks with a prototype that lived balanced precariously on a piece of foam below its younger sibling pi zero eink screen running game of life. eventually i cut up some pieces of an old couch and made a simple wooden frame. a carpenter, i am not :/
prototype | frame |