was doing some work with kalman filters recently and, as always, wasn't sure what the best way to tune the filter configuration was.
in the past i've taken a simple grid/random search like approach but was curious how i could express this tuning as an optimisation using jax instead.
first though, what is a kalman filter? they're a pretty broad concept but the main thing i've used them for is dynamic system prediction.
they operate in a two step fashion...
a predict step, which predicts something about the system based on some internal state, and
an update step, which integrates new observation information into the filter's state, ready for the next predict.
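for reference, these are the textbook linear kalman filter equations, written with the same symbol names as the numpy implementation further down ( A, B, u, Q, H, R, P, with z the observation ). this is the standard formulation, not anything specific to this post.

```latex
\begin{aligned}
\text{predict:}\quad & x_k^- = A x_{k-1} + B u_{k-1} \\
& P_k^- = A P_{k-1} A^\top + Q \\
\text{update:}\quad & S_k = H P_k^- H^\top + R \\
& K_k = P_k^- H^\top S_k^{-1} \\
& x_k = x_k^- + K_k \left( z_k - H x_k^- \right) \\
& P_k = \left( I - K_k H \right) P_k^-
\end{aligned}
```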
in this post we'll use a simple 2D trajectory as the system, e.g. throwing an object, with the task of the kalman filter being the prediction of the object's position at the next time step.
consider then the simple system of throwing an object under a trivial physics model
def simulate_throw(dx, dy):
    x, y = 0, 0
    for _ in range(10):
        yield x, y
        x += dx
        y += dy
        dy -= 1
we can use this to simulate a couple of throws and plot the trajectories...
# draw_throw_with_colours is a plotting helper not shown in this post
draw_throw_with_colours(
    [np.array(list(simulate_throw(dx=3, dy=5))), np.array(list(simulate_throw(dx=4, dy=3)))],
    ['red', 'green']
)
can we use a kalman filter to predict the next state of these systems? i do hope so, it's what they were designed for!
implementations of a kalman filter can vary a lot so let's just use this random kalman filter from the internet
note: using this implementation, or any random snippet of code from the internet, for, say, controlling a rocket might correctly be considered a generally "bad idea". the correctness, or otherwise, of this filter is irrelevant to the task of jaxifying it :)
# as is from https://machinelearningspace.com/2d-object-tracking-using-kalman-filter/ with some
# minor changes
import numpy as np

class KalmanFilter(object):
    def __init__(self, dt, u_x, u_y, std_acc, x_std_meas, y_std_meas):
        """
        :param dt: sampling time (time for 1 cycle)
        :param u_x: acceleration in x-direction
        :param u_y: acceleration in y-direction
        :param std_acc: process noise magnitude
        :param x_std_meas: standard deviation of the measurement in x-direction
        :param y_std_meas: standard deviation of the measurement in y-direction
        """
        # Define sampling time
        self.dt = dt
        # Define the control input variables
        self.u = np.matrix([[u_x], [u_y]])
        # Initial State
        self.x = np.matrix([[0, 0], [0, 0], [0, 0], [0, 0]])
        # Define the State Transition Matrix A
        self.A = np.matrix([[1, 0, self.dt, 0],
                            [0, 1, 0, self.dt],
                            [0, 0, 1, 0],
                            [0, 0, 0, 1]])
        # Define the Control Input Matrix B
        self.B = np.matrix([[(self.dt**2)/2, 0],
                            [0, (self.dt**2)/2],
                            [self.dt, 0],
                            [0, self.dt]])
        # Define Measurement Mapping Matrix
        self.H = np.matrix([[1, 0, 0, 0],
                            [0, 1, 0, 0]])
        # Initial Process Noise Covariance
        self.Q = np.matrix([[(self.dt**4)/4, 0, (self.dt**3)/2, 0],
                            [0, (self.dt**4)/4, 0, (self.dt**3)/2],
                            [(self.dt**3)/2, 0, self.dt**2, 0],
                            [0, (self.dt**3)/2, 0, self.dt**2]]) * std_acc**2
        # Initial Measurement Noise Covariance
        self.R = np.matrix([[x_std_meas**2, 0],
                            [0, y_std_meas**2]])
        # Initial Covariance Matrix
        self.P = np.eye(self.A.shape[1])

    def predict(self):
        # Refer to :Eq.(9) and Eq.(10)
        # Update time state
        # x_k = Ax_(k-1) + Bu_(k-1)     Eq.(9)
        self.x = np.dot(self.A, self.x) + np.dot(self.B, self.u)
        # Calculate error covariance
        # P = A*P*A' + Q                Eq.(10)
        self.P = np.dot(np.dot(self.A, self.P), self.A.T) + self.Q
        return self.x[0]

    def update(self, z):
        # Refer to :Eq.(11), Eq.(12) and Eq.(13)
        # S = H*P*H'+R
        S = np.dot(self.H, np.dot(self.P, self.H.T)) + self.R
        # Calculate the Kalman Gain
        # K = P * H'* inv(H*P*H'+R)
        K = np.dot(np.dot(self.P, self.H.T), np.linalg.inv(S))  # Eq.(11)
        self.x = self.x + np.dot(K, (z - np.dot(self.H, self.x)))  # Eq.(12)
        I = np.eye(self.H.shape[1])
        # Update error covariance matrix
        self.P = (I - (K * self.H)) * self.P  # Eq.(13)
a couple of things to note about this filter...
it has two methods; predict, which provides an estimate of x, and update, which updates the internal state of the filter based on a real observation. the implementations of these are kinda opaque and i wouldn't be surprised if there are a stack of subtle bugs here that a dynamic systems expert could spot. i am happy to take it "as is" for the purpose of poking around in some jax.
since predict and update change the internal state of the filter, it's expected they are called in sequence; predict, update, predict, update, etc.
the filter holds a number of matrices; u, B, H, etc, some of which are to do with the internal state of the filter ( like P ) with others representing a form of config around how we expect the dynamics of the system to behave ( like A ). these latter matrices are configured based off scalar values such as dt and x_std_meas.
we can use this filter as is to make predictions about the next time step of a throw, and it's not too bad at it...
# construct a filter with some config
filter = KalmanFilter(
    dt=1.0, u_x=0, u_y=0,
    std_acc=1.0, x_std_meas=0.1, y_std_meas=0.1)
# simulate a throw; materialise it as a list so it can be both stepped through and plotted
xy_trues = list(simulate_throw(2.8, 4.8))
# step through the trajectory
xy_preds = []
for xy_true in xy_trues:
    # make prediction based on filter and record it
    xy_pred = filter.predict()
    xy_preds.append(xy_pred)
    # update the filter state based on the true value
    filter.update(xy_true)
# plot the pair
# * red denotes true values,
# * green denotes predicted values based on the time step before
xy_preds = np.stack(xy_preds).squeeze()
draw_throw_with_colours(
    [xy_trues, xy_preds],
    ['red', 'green'])
next let's port this kalman filter to jax. there are a few aspects to this....
a key concept in the jax port is making predict and update fully functional, and that means taking a state and returning it for the two methods. i.e. something like...
def predict(params, state):
    ...
    return state, xy_pred

def update(params, state, z):
    ...
    return state
we can then have all the P, A, Q, etc go in either params or state.
note: we are going to be explicit about two different types of variables we are dealing with here...
state represents the internal state of the filter that changes over time based on the sequence of predict, update, predict, ... calls
params represents the configuration items, based on dt etc, that we eventually want to get gradients for ( with respect to a loss function )
the pairing of predict then update can be expressed as the following, which reassigns state on each method call
def predict_then_update(params, state, xy_true):
    state, xy_pred = predict(params, state)
    state = update(params, state, xy_true)
    return state, xy_pred
what type are params and state then? we want them to be collections of variables which are well supported in jax under the idea of pytrees, and we can get a long way with just thinking of them as dictionaries...
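as a small illustration of why plain dicts work well as pytrees: jax.tree_util.tree_map maps a function over every leaf, and jax.grad of a function of a dict returns gradients in the same dict structure. the values and toy_loss here are made up for illustration.

```python
import jax
import jax.numpy as jnp

# a toy params pytree: just a dict of arrays (values made up for illustration)
params = {
    'A': jnp.eye(4),
    'R': jnp.eye(2) * 0.01,
}

# tree_map applies a function to every leaf and keeps the dict structure
scaled = jax.tree_util.tree_map(lambda p: p * 2.0, params)

def toy_loss(params):
    # hypothetical scalar function of the whole dict
    return jnp.sum(params['A'] ** 2) + jnp.sum(params['R'])

# jax.grad returns gradients with the same dict structure as params
grads = jax.grad(toy_loss)(params)
print({k: v.shape for k, v in grads.items()})
```

this is what lets the update step later on apply a gradient step with a single tree_map over params and gradients.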
note for this jax port: i've swapped the dot calls for @ which, IMHO, reads easier. i've also dropped u_x and u_y, which were 0.0s anyways, and no u implies no B either...
import jax
import jax.numpy as jnp

def predict(params, state):
    state['x'] = params['A'] @ state['x']
    state['P'] = ((params['A'] @ state['P']) @ params['A'].T) + params['Q']
    xy_pred = state['x'][0]
    return state, xy_pred

def update(params, state, z):
    # Define Measurement Mapping Matrix
    H = jnp.array([[1, 0, 0, 0],
                   [0, 1, 0, 0]])
    S = (H @ (state['P'] @ H.T)) + params['R']
    K = (state['P'] @ H.T) @ jnp.linalg.inv(S)
    state['x'] = state['x'] + (K @ (z - (H @ state['x'])))
    I = jnp.eye(4)
    state['P'] = (I - (K @ H)) @ state['P']
    return state
the final missing piece then is how we define the initial values for params and state
def default_params():
    dt = 1.0
    std_acc = 1.0
    x_std_meas, y_std_meas = 0.1, 0.1
    return {
        # Define the State Transition Matrix A
        'A': jnp.array([[1, 0, dt, 0],
                        [0, 1, 0, dt],
                        [0, 0, 1, 0],
                        [0, 0, 0, 1]]),
        # Initial Measurement Noise Covariance
        'R': jnp.array([[x_std_meas**2, 0],
                        [0, y_std_meas**2]]),
        # Initial Process Noise Covariance
        'Q': jnp.array([[(dt**4)/4, 0, (dt**3)/2, 0],
                        [0, (dt**4)/4, 0, (dt**3)/2],
                        [(dt**3)/2, 0, dt**2, 0],
                        [0, (dt**3)/2, 0, dt**2]]) * std_acc**2
    }

def initial_state():
    return {
        # Initial State
        'x': jnp.zeros((4, 2)),
        # Initial Covariance Matrix
        'P': jnp.eye(4),
    }
which all comes together like the numpy one as ...
params = default_params()
state = initial_state()
# materialise the throw as an array; jax array ops won't accept the python tuples the generator yields
xy_trues = np.array(list(simulate_throw(2.8, 4.8)))
xy_preds = []
for xy_true in xy_trues:
    state, xy_pred = predict_then_update(params, state, xy_true)
    xy_preds.append(xy_pred)
xy_preds = np.stack(xy_preds)
draw_throw_with_colours(
    [xy_trues, xy_preds],
    ['red', 'green'])
before we go any deeper into jax land let's talk about one more aspect of using kalman filters.
the main idea touched on already is that we can use them to make a prediction about the next state of a system before we observe it.
xy_pred_t0 = predict()
# wait for actual xy_true_t0
update(xy_true_t0)
xy_pred_t1 = predict()
# wait for actual xy_true_t1
update(xy_true_t1)
but if we really trust the filter we can also use it to handle missing observations, by passing in the last prediction as the observation when we don't have one ( for whatever reason )
xy_pred_t0 = predict()
# oh oh! for whatever reason we don't have _t0
update(xy_pred_t0) # use predicted instead
xy_pred_t1 = predict()
# wait for actual xy_true_t1
update(xy_true_t1)
this can be handy for a signal that's noisy and dropping out, but it does put more pressure on the filter to be robust to any compounding error it might exhibit.
during training then ( which we still haven't talked about yet ) we can induce this by occasionally, randomly, dropping out xy_true values and using the prior xy_pred value instead.
( those with a background in RNNs might notice this is closely related to teacher forcing; randomly feeding the model's own predictions back in place of the ground truth is, more specifically, what scheduled sampling does )
the code to do this, with a 20% dropout rate, can be....
rng = np.random.default_rng()

def predict_then_update(params, state, has_observation, xy_true):
    state, xy_pred = predict(params, state)
    # without an observation, feed the prediction back in as if it had been observed
    xy_for_update = jnp.where(has_observation, xy_true, xy_pred)
    state = update(params, state, xy_for_update)
    return state, xy_pred

state = initial_state()
xy_preds = []
has_observations = []
for xy_true in xy_trues:
    has_observation = bool(rng.uniform() > 0.2)
    has_observations.append(has_observation)
    state, xy_pred = predict_then_update(params, state, has_observation, xy_true)
    xy_preds.append(xy_pred)
xy_preds = np.stack(xy_preds)
print("has_observations", has_observations)
draw_throw_with_colours(
    [xy_trues, xy_preds],
    ['red', 'green'])
has_observations [True, True, True, True, True, False, False, True, False, False]
notice how the filter shoots way off after that pair of Falses. the default_params values might be good enough to predict one step in the future, but they don't look robust to predicting two or more steps in the future.
a key first thing to do is to implement the for loop with jax.lax.scan so that jax can more cleanly trace it.
jax.lax.scan provides the classic functional programming idea of iterating over a method with a carried state.
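a tiny, self-contained example of those scan semantics ( unrelated to the filter, just a running sum ): the step function takes the carry and the current element, and returns the next carry plus a per-step output, which scan stacks into an array.

```python
import jax
import jax.numpy as jnp

def running_sum_step(carry, x):
    # carry: state threaded through every iteration
    # x: the current element of the sequence being scanned over
    new_carry = carry + x
    # return (next carry, per-step output); scan stacks the outputs
    return new_carry, new_carry

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
# start from a float32 zero so the carry dtype matches what the step returns
final_carry, per_step = jax.lax.scan(running_sum_step, jnp.zeros(()), xs)
print(final_carry, per_step)
```

in the filter the carry is the state dict and the per-step output is xy_pred, which is exactly how the rolled out version below uses it.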
note:
we need to move the sampling of has_observation into the jax function, which means having to add a key to the state that will be carried along between calls to predict_then_update_single.
we assume every xy_trues sequence is the same length. in the cases where we might have a different roll out length we'd need to use something like zero padding and loss masking, or bucketing by sequence length.
we can now get all the xy_preds in a single call from xy_trues.
def initial_state(seed):
    return {
        # Initial State
        'x': jnp.zeros((4, 2)),
        # Initial Covariance Matrix
        'P': jnp.eye(4),
        # rng key for missing observations
        'key': jax.random.key(seed)
    }

def rolled_out_predict_then_update(params, seed, xy_trues, missing_rate):
    def predict_then_update_single(state, xy_true):
        state, xy_pred = predict(params, state)
        state['key'], subkey = jax.random.split(state['key'])
        has_observation = jax.random.uniform(subkey) > missing_rate
        xy_for_update = jnp.where(has_observation, xy_true, xy_pred)
        state = update(params, state, xy_for_update)
        return state, xy_pred
    _final_state, predictions = jax.lax.scan(predict_then_update_single, initial_state(seed), xy_trues)
    return predictions

xy_trues = np.array(list(simulate_throw(2.8, 4.8)))
seed = 1234414
xy_preds = jax.jit(rolled_out_predict_then_update)(default_params(), seed, xy_trues, missing_rate=0.2)
draw_throw_with_colours(
    [xy_trues, xy_preds],
    ['red', 'green'])
this example has a bunch of missed observations and the filter struggles quite a bit :/
with a function now that takes a list of xy_true values and returns a list of xy_pred values, we can start to think about a loss, the simplest being mean squared error.
def loss_fn(params, seed, xy_trues, missing_rate):
    predictions = rolled_out_predict_then_update(params, seed, xy_trues, missing_rate)
    squared_difference = (predictions - xy_trues) ** 2
    return jnp.mean(squared_difference)
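the scan notes mention zero padding plus loss masking for roll outs of different lengths. a minimal sketch of what a masked version of this mean squared error could look like, assuming a hypothetical lengths array holding each trajectory's true length ( masked_mse is an illustrative name, not something from the post ):

```python
import jax.numpy as jnp

def masked_mse(predictions, targets, lengths):
    """hypothetical helper: mean squared error over zero padded (batch, time, 2)
    arrays, ignoring time steps at or beyond each sequence's true length."""
    time = predictions.shape[1]
    # mask[b, t] is 1.0 for real steps and 0.0 for padding
    mask = (jnp.arange(time)[None, :] < lengths[:, None]).astype(predictions.dtype)
    squared_difference = ((predictions - targets) ** 2) * mask[:, :, None]
    # normalise by the number of real values, not the padded array size
    num_values = jnp.sum(mask) * predictions.shape[-1]
    return jnp.sum(squared_difference) / num_values

# two trajectories padded to length 3; the second only has 2 real steps
predictions = jnp.zeros((2, 3, 2))
targets = jnp.ones((2, 3, 2)).at[1, 2].set(100.0)  # junk in the padded slot
lengths = jnp.array([3, 2])
print(masked_mse(predictions, targets, lengths))
```

the junk value in the padded slot doesn't leak into the loss, which is the whole point of the mask.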
it's interesting to note how the losses are quite unstable across seeds, especially as we increase the missing_rate.
for missing_rate in np.linspace(0.0, 0.9, 10):
    losses = [loss_fn(default_params(), seed, xy_trues, missing_rate)
              for seed in range(100)]
    print(missing_rate, np.mean(losses), np.std(losses))
0.0 2.359239 2.3841858e-07
0.1 5.6141863 10.308739
0.2 18.179102 53.249737
0.3 88.48642 323.95425
0.4 173.53473 505.12674
0.5 269.37463 636.1734
0.6 497.91122 774.9014
0.7 712.81384 858.53827
0.8 989.5386 972.14777
0.9 661.3348 749.5342
having a loss function means we can get gradients with respect to the params and use them in a trivial gradient descent update step
@jax.jit
def update_step(params, seed, xy_trues):
    gradients = jax.grad(loss_fn)(params, seed, xy_trues, missing_rate=0.1)
    def apply_gradients(p, g):
        learning_rate = 1e-5
        return p - learning_rate * g
    return jax.tree_util.tree_map(apply_gradients, params, gradients)
this update step allows us to sample a trajectory, run it through the rolled out filter, calculate a loss and gradient and finally update the params...
params = default_params()
seed = 0
for i in range(1000):
    dx = 2 + np.random.uniform() * 5  # (2, 7)
    dy = 4 + np.random.uniform() * 4  # (4, 8)
    xy_trues = np.array(list(simulate_throw(dx, dy)))
    params = update_step(params, seed, xy_trues)
    if i % 100 == 0:
        print("loss", i, loss_fn(params, seed, xy_trues, missing_rate=0.1))
    seed += 1
loss 0 2.5613098
loss 100 12.690718
loss 200 4.0097485
loss 300 4.4077697
loss 400 4.8039966
loss 500 3.5365129
loss 600 3.0296855
loss 700 2.361609
loss 800 3.0249815
loss 900 1.7365919
we see the loss has dropped, hooray! but how does the filter behave?
if we plot some examples of the default_params versus these trained params for a range of missing_rate values, we can see the trained filter is much more robust to missing values
xy_trues = np.array(list(simulate_throw(3, 5)))
seed = 1234414

def throw_img(xy_trues, seed, missing_rate):
    xy_preds_initial_params = rolled_out_predict_then_update(default_params(), seed, xy_trues, missing_rate)
    xy_preds_trained_params = rolled_out_predict_then_update(params, seed, xy_trues, missing_rate)
    return draw_throw_with_colours(
        [xy_trues, xy_preds_initial_params, xy_preds_trained_params],
        ['red', 'green', 'yellow'])
( figure: trajectories for missing_rate=0.0, 0.2 and 0.4; red is the true throw, green uses default_params, yellow uses the trained params )
let's look at the differences between the original params and the ones that were learnt....
param | original | learnt |
---|---|---|
A | [[1. 0. 1. 0.] [0. 1. 0. 1.] [0. 0. 1. 0.] [0. 0. 0. 1.]] | [[ 0.958 0.014 0.897 0.038] [ 0.024 1.016 0.021 1.020] [ 0.004 -0.019 0.967 -0.003] [-0.026 -0.020 0.024 1.025]] |
Q | [[0.25 0. 0.5 0. ] [0. 0.25 0. 0.5 ] [0.5 0. 1. 0. ] [0. 0.5 0. 1. ]] | [[ 0.278 0.064 0.470 -0.035] [-0.001 0.279 0.013 0.513] [ 0.504 -0.027 1.031 0.042] [-0.002 0.465 0.003 1.005]] |
R | [[0.01 0. ] [0. 0.01]] | [[0.069 0.029] [0.011 0.137]] |
we can see that A and Q had minor changes but R was changed much more, particularly the value related to y_std_meas.
seeing this result made me realise something. by providing the full A matrix to be optimised we end up with non-zero and non-one values, whereas i really only wanted to tune for dt.
it is interesting that the model has tuned the transition matrix fully, but maybe in some cases it's better to constrain things and only allow it to change dt.
to do this we just need to be more explicit about what the actual parameter set is..
def default_params():
    return {
        'dt': 1.0,
        'std_acc': 1.0,
        'x_std_meas': 0.1,
        'y_std_meas': 0.1
    }
and then materialise A, Q and R from the params as required in predict and update
def predict(params, state):
    dt = params['dt']
    # State Transition Matrix
    A = jnp.array([[1, 0, dt, 0],
                   [0, 1, 0, dt],
                   [0, 0, 1, 0],
                   [0, 0, 0, 1]])
    # Process Noise Covariance
    Q = jnp.array([[(dt**4)/4, 0, (dt**3)/2, 0],
                   [0, (dt**4)/4, 0, (dt**3)/2],
                   [(dt**3)/2, 0, dt**2, 0],
                   [0, (dt**3)/2, 0, dt**2]]) * params['std_acc']**2
    state['x'] = A @ state['x']
    state['P'] = ((A @ state['P']) @ A.T) + Q
    xy_pred = state['x'][0]
    return state, xy_pred

def update(params, state, z):
    # Define Measurement Mapping Matrix
    H = jnp.array([[1, 0, 0, 0],
                   [0, 1, 0, 0]])
    # Measurement Noise Covariance
    R = jnp.array([[params['x_std_meas'] ** 2, 0],
                   [0, params['y_std_meas'] ** 2]])
    S = (H @ (state['P'] @ H.T)) + R
    K = (state['P'] @ H.T) @ jnp.linalg.inv(S)
    state['x'] = state['x'] + (K @ (z - (H @ state['x'])))
    I = jnp.eye(4)
    state['P'] = (I - (K @ H)) @ state['P']
    return state
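a quick way to convince yourself that gradients flow from a materialised matrix back to the scalar dt is to differentiate a tiny function that builds A from dt, the same way predict now does. the state vector here ( position (0, 0), velocity (3, 5) ) is made up, and the 4 element vector is a simplification of the (4, 2) state used in this post.

```python
import jax
import jax.numpy as jnp

def transition_matrix(dt):
    # same structure as A in predict, built from a (possibly traced) scalar dt
    return jnp.array([[1, 0, dt, 0],
                      [0, 1, 0, dt],
                      [0, 0, 1, 0],
                      [0, 0, 0, 1]])

def next_x_position(dt):
    # hypothetical single state vector: position (0, 0), velocity (3, 5)
    x = jnp.array([0.0, 0.0, 3.0, 5.0])
    # predicted x position is 0 + 3 * dt, so d/d(dt) should be 3
    return (transition_matrix(dt) @ x)[0]

print(jax.grad(next_x_position)(1.0))
```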
doing this makes for a minor difference in the loss, and it's not really even visible in the trace visualisation.
to be honest though given the instability of the loss there are bigger things at play here in terms of ways to improve...
step | full matrix params | scalar params |
---|---|---|
0 | 2.561 | 5.138 |
100 | 12.69 | 24.98 |
200 | 4.009 | 2.928 |
300 | 4.407 | 6.072 |
400 | 4.803 | 2.422 |
500 | 3.536 | 2.577 |
600 | 3.029 | 4.388 |
700 | 2.361 | 3.006 |
800 | 3.024 | 1.631 |
900 | 1.736 | 3.306 |
running an update step per example is often troublesome. not only is it slow (since we're not making use of as much vectorisation as possible) it can also suffer a lot from gradient variance problems.
the general approach is to batch things and calculate gradients with respect to multiple examples before an update step. jax.vmap is perfect for this...
we can express the loss with respect to multiple examples by using jax to make a version of the rollout function that rolls out multiple trajectories at once...
note: in_axes denotes that we want the vmap transform to vectorise over the second and third args of the rollout function ( seed & xy_trues ) while broadcasting the first and last args ( params & missing_rate )
def loss_fn(params, seeds, xy_truess, missing_rate):
    v_rolled_out_predict_then_update = jax.vmap(rolled_out_predict_then_update, in_axes=(None, 0, 0, None))
    predictions = v_rolled_out_predict_then_update(params, seeds, xy_truess, missing_rate)
    squared_difference = (predictions - xy_truess) ** 2
    return jnp.mean(squared_difference)
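a toy version of the same in_axes pattern, with a hypothetical toy_rollout standing in for rolled_out_predict_then_update: params and scale are broadcast ( None ) while seeds and trajectories are mapped over their leading axis ( 0 ).

```python
import jax
import jax.numpy as jnp

def toy_rollout(params, seed, xs, scale):
    # hypothetical stand-in for rolled_out_predict_then_update;
    # each call sees a single seed and a single (10, 2) trajectory
    return params['w'] * xs * scale + seed

# broadcast params and scale, vectorise over seed and xs
batched_rollout = jax.vmap(toy_rollout, in_axes=(None, 0, 0, None))

params = {'w': jnp.array(2.0)}
seeds = jnp.arange(3)          # a batch of 3 seeds
xss = jnp.ones((3, 10, 2))     # a batch of 3 trajectories
out = batched_rollout(params, seeds, xss, 0.5)
print(out.shape)  # every entry of example b is 2.0 * 1.0 * 0.5 + b
```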
this loss is called as before but instead with a batch of seeds and xy_trues
batch_size = 8
next_seed = 0  # running seed counter; initialise once, before the training loop
xy_truess = []
seeds = []
for _ in range(batch_size):
    dx = 2 + np.random.uniform() * 5  # (2, 7)
    dy = 4 + np.random.uniform() * 4  # (4, 8)
    xy_truess.append(np.array(list(simulate_throw(dx, dy))))
    seeds.append(next_seed)
    next_seed += 1
xy_truess = np.stack(xy_truess)
seeds = np.array(seeds)
...
params = update_step(params, seeds, xy_truess)
interestingly i found this didn't actually help! usually for a neural network this is a big win, for speed as well as gradient variance, but in this case it was behaving worse for me :/
rolling your own optimisation step is generally (*) a bad idea, better to use an update step using an optax optimiser.
import optax

params = default_params()
opt = optax.adam(1e-6)
opt_state = opt.init(params)
...
@jax.jit
def update_step(params, opt_state, seeds, xy_truess):
    gradients = jax.grad(loss_fn)(params, seeds, xy_truess, missing_rate=0.1)
    updates, opt_state = opt.update(gradients, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state
...
for i in range(1000):
    ...
    params, opt_state = update_step(params, opt_state, seeds, xy_truess)
    ...
having said that, trying a couple of different optimisers gave no better result than the simple hand rolled update step from above!!! i guess the dynamics are weird enough that they don't fit the space that adam, rmsprop, etc are tuned for by default (?) i'm betting the noisy behaviour of the rollout is at fault...
sometimes the best way to avoid noisy samples is to go nuts on a large fixed dataset, e.g. if we just materialise a large range of cross product values for dx, dy ...
N = 40
SEEDS_PER_DX_DY = 50
xy_truess = []
seeds = []
for dx_l in np.linspace(0, 1, N):
    for dy_l in np.linspace(0, 1, N):
        xy_trues = np.array(list(simulate_throw(dx=2+(dx_l*5), dy=4+(dy_l*4))))
        for _ in range(SEEDS_PER_DX_DY):
            xy_truess.append(xy_trues)
xy_truess = np.stack(xy_truess)
seeds = np.array(range(len(xy_truess)))
print("xy_truess", xy_truess.shape, "seeds", seeds.shape)
xy_truess (80000, 10, 2) seeds (80000,)
we can then jit this fixed data into the update step
@jax.jit
def update_step(params, opt_state):
    gradients = jax.grad(loss_fn)(params, seeds, xy_truess, missing_rate=0.1)
    updates, opt_state = opt.update(gradients, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state

for i in range(1000):
    params, opt_state = update_step(params, opt_state)
    if i % 100 == 0:
        print("loss", i, loss_fn(params, seeds, xy_truess, missing_rate=0.1))
and we get a fast update and the most stable loss, though it flattens out pretty quick
loss 0 9.521738
loss 100 5.65931
loss 200 5.651896
loss 300 5.6428213
loss 400 5.636997
...
variable | initial | trained |
---|---|---|
dt | 1.0 | 1.3539 |
std_acc | 1.0 | 0.1803 |
x_std_meas | 0.1 | 0.1394 |
y_std_meas | 0.1 | 0.1 |
curiously in this formulation we've ended up with no gradients with respect to y_std_meas (?) so i've either got a bug or it's the subtle way it's being used.. TODO!
anyways, that's my time up. check out the prototype colab i made for this post
in my last post, wavenet on an mcu, we talked about running a cached optimised version of wavenet on an MCU at 48kHz. this time we're going to get things running on an FPGA instead!
note: this post assumes you've read that last one; it describes the model architecture, some caching tricks, and describes a waveshaping task.
when doing some research on how to make things faster i stumbled on the eurorack-pmod project which is a piece of hardware that connects an FPGA with a eurorack setup. it includes 192kHz analog-digital-analog conversion and supports 4 channels in and out. just add an FPGA! perfect!
i'd never done anything beyond blinking LEDs with an FPGA but it can't be that hard.... right? right?!?!
well. turns out it wasn't that simple... i couldn't find any open source examples of compiling a neural net to an FPGA. the closest was HLS4ML but, if i understand it correctly, it only works with (expensive) proprietary FPGA toolchains :(
HLS4ML did at least put me onto qkeras, which is a key component we'll talk about in a second.
( not interested in details? jump straight to the bottom for the demos )
recall from the last post that an important aspect of the MCU project was having multiple versions for training vs inference
the FPGA version ended with a similar collection of variants;
for both the MCU and FPGA versions we use a single firmware running on the daisy patch for training data collection.
let's expand on each part of the FPGA version...
neural networks end up doing a lot of matrix multiplication. a lot.
for training it's common to work with a floating point representation since it gives the best continuous view of the loss space.
but for inference we can often reduce the precision of the weights and use integer arithmetic. we're motivated to do this since, for a lot of systems, integer math is much faster to run than full floating point math.
( and, actually, in a lot of systems we just simply might not have the ability to do floating point math at all! )
the process of converting from the floats used during training to something simpler for inference is called quantisation, and we'll look at two flavours of it for this project.
for this project i mainly used fixed point numbers.
fixed point numbers are a simpler alternative to floating point numbers, with some constraints around range and precision, but they allow multiplication to be done as if it were integer multiplication. ( this is the first project i've done with fixed point math and i'm hooked, it's perfect for neural nets! )
the high level idea is you specify a total number of bits and then how many of those bits you want to use for the integer part of the number, and how many you want to use for representing the fractional part.
in this project all inputs, outputs, weights and biases are 16 bits in total with 4 bits for the integer and 12 bits for the fractional part. ( i.e. FP4.12 )
with "only" 4 bits for the integer part the range of values is +/- 2^3 = 8 ( one of the 4 integer bits effectively acts as the sign bit ). though this might seem limiting, it's actually ok for a network, where activations etc are generally centered on zero ( especially if we add a bit of L2 regularisation along the way )
with 12 bits allocated to the fractional part we are able to describe numbers with a precision of 2^-12 = 0.00024.
here are some examples; we show the 16 bit binary number with a decimal point after the 4th bit to denote the change from the integer part to the fractional part.
bits | decimal |
---|---|
0010.0000 0000 0000 | 2^1 = 2 |
0101.0000 0000 0000 | 2^2 + 2^0 = 4 + 1 = 5 |
0000.0000 0000 0000 | 0 |
0000.1000 0000 0000 | 2^-1 = 0.5 |
0000.1001 0000 0000 | 2^-1 + 2^-4 = 0.5 + 0.0625 = 0.5625 |
0000.0000 0000 0001 | 2^-12 = 0.000244140625 |
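a pure Python sketch of FP4.12, assuming two's complement with the sign counted among the 4 integer bits, round-to-nearest and saturation at the range limits ( qkeras has its own rounding and saturation options, so treat these choices as assumptions ). it reproduces the bit patterns in the table, and shows that a fixed point multiply is just an integer multiply followed by a right shift of 12.

```python
FRAC_BITS = 12                        # FP4.12: 4 integer bits (incl. sign), 12 fractional
TOTAL_BITS = 16
MIN_INT = -(1 << (TOTAL_BITS - 1))    # -32768 -> -8.0
MAX_INT = (1 << (TOTAL_BITS - 1)) - 1 #  32767 -> just under 8.0

def saturate(q: int) -> int:
    return max(MIN_INT, min(MAX_INT, q))

def to_fixed(x: float) -> int:
    """quantise a float to the nearest FP4.12 value, stored as a signed int"""
    return saturate(round(x * (1 << FRAC_BITS)))

def to_float(q: int) -> float:
    return q / (1 << FRAC_BITS)

def fixed_mul(a: int, b: int) -> int:
    """multiply two FP4.12 values with integer ops only. the raw product has
    24 fractional bits, so shift right by 12 to get back to FP4.12
    (python's >> floors, i.e. rounds towards -inf)"""
    return saturate((a * b) >> FRAC_BITS)

def bits(q: int) -> str:
    # 16 bit two's complement pattern, formatted like the table: iiii.ffff ffff ffff
    s = format(q & 0xFFFF, '016b')
    return f"{s[:4]}.{s[4:8]} {s[8:12]} {s[12:]}"

for x in (2.0, 5.0, 0.0, 0.5, 0.5625, 2 ** -12):
    print(bits(to_fixed(x)), x)

print(to_float(fixed_mul(to_fixed(2.0), to_fixed(0.5625))))
```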
the purpose of the qkeras model then is to train in full float32 but provide the ability to export the weights and biases in this fixed point configuration.
notes...
the options for quantisation can get pretty crazy too!
qkeras provides another scheme called power-of-two quantisation where all values are quantised to be only powers of two. i.e. depending on the fixed point config, a weight/bias can only be one of [ +/- 1, 0, +/- 1/2, +/- 1/4, +/- 1/8, ...]
though this seems overly restrictive it has one HUGE important benefit... when a weight is a power of two then the "multiplication" of a feature by that weight can be simply done with a bit shift operation. and bit shifting is VERY fast.
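a pure Python sketch of the idea: snap a weight to sign * 2**exp, then "multiply" an integer ( e.g. FP4.12 ) feature by it with a right shift. quantise_po2 and po2_mul are hypothetical helpers, and rounding in the log2 domain plus the min_exp cutoff are my assumptions; qkeras's own power-of-two quantiser has its own rounding and range options.

```python
import math

def quantise_po2(w: float, min_exp: int = -7) -> tuple[int, int]:
    """hypothetical helper: snap w to sign * 2**exp with exp <= 0, returning (sign, exp).
    weights too small to represent are snapped to zero (sign 0)."""
    if w == 0.0:
        return 0, 0
    exp = min(0, round(math.log2(abs(w))))  # only +/- 1, 1/2, 1/4, ...
    if exp < min_exp:
        return 0, 0
    return (1 if w > 0 else -1), exp

def po2_mul(x_fixed: int, sign: int, exp: int) -> int:
    """multiply an integer feature by sign * 2**exp using an arithmetic
    right shift by -exp instead of a multiplier"""
    if sign == 0:
        return 0
    shifted = x_fixed >> (-exp)
    return shifted if sign > 0 else -shifted

sign, exp = quantise_po2(0.3)        # snaps to +1/4
print(sign, exp, po2_mul(4096, sign, exp))  # 4096 is 1.0 in FP4.12
```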
and there are ways to "recover" representational power too, the best one i found being based around matrix factorisation.
if we have, say, a restricted weight matrix W of shape (8, 8) it can only contain those fixed values. but if we instead represent W as a product of matrices, say, (8, 32) and (32, 8), then even though all the individual weights are restricted, the product of the matrices, an effective (8, 8) matrix, has many more possible values.
the pro is that the weights can take many more values, all the sums of combos of w*w. the con though is we have to do two mat muls. depending on how much space we have for allocating the shift operations compared to the number of multiply units we have, this tradeoff might be ok!
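a quick numpy check of that claim, assuming weights are drawn uniformly from the restricted set {0, +/-1, +/-1/2, +/-1/4, +/-1/8}; random_restricted is a hypothetical helper and the random draw just stands in for trained weights.

```python
import numpy as np

rng = np.random.default_rng(0)
# 0 plus +/- 2**-e for e in 0..3, i.e. 9 allowed values
allowed = np.array([0.0] + [s * 2.0 ** -e for s in (1, -1) for e in range(4)])

def random_restricted(shape):
    # hypothetical: sample every weight from the allowed set
    return rng.choice(allowed, size=shape)

W = random_restricted((8, 8))
W1 = random_restricted((8, 32))
W2 = random_restricted((32, 8))
W_eff = W1 @ W2  # effective (8, 8) matrix, at the cost of two mat muls

print("distinct values in W    :", len(np.unique(W)))
print("distinct values in W1@W2:", len(np.unique(W_eff)))
```

all the products and sums here are of dyadic fractions, so they're exact in float64 and the distinct value counts are exact too.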
i messed around with this a lot and, though it was interesting and generally worked, it turned out that for the FPGA sizing ( i'm using at least ) the best result was to just use fixed point multiplication instead. :/ i'm guessing this is a fault on my part in terms of poor verilog design, and i still have some ideas to try at least...
anyways, back to the models. the next model after the qkeras one is a fxpmath one.
the fxpmath version connects the fixed point export of the qkeras model with the caching approach and the inference logic that will go into the verilog design
the activation caching has two elements that are basically the same as the MCU version; the left shift buffer for handling the first input and an activation cache for handling the activations between each convolutional layer.
the big difference comes in with the implementation of the convolution which i had to roll from scratch :/ but at least it allows for some very optimised parallelisation.
consider a 1D convolution with kernel size K=4 & input / output feature depth of 16.
since we don't intend to stride this convolution at all ( that's handled implicitly by the activation caching ) we can treat this convolution as the following steps...
each of the x4 matrix multiplications is actually just a row by matrix multiplication ( we'll call this a row_by_matrix_multiply from now on ), so can be decomposed into j=16 independent dot products,
and each of those dot products can be decomposed into 16 independent multiplications followed by an accumulation.
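a float numpy sketch of that decomposition for a single output time step, assuming the activation cache has already gathered the K=4 input rows the kernel needs into a (K, 16) window. the function names mirror the module names used in the verilog design, but this is illustrative Python ( no fixed point, no clock cycles ), checked against a direct einsum.

```python
import numpy as np

K, D_IN, D_OUT = 4, 16, 16
rng = np.random.default_rng(0)
w = rng.normal(size=(K, D_IN, D_OUT))  # one (16, 16) weight matrix per kernel tap
b = rng.normal(size=(D_OUT,))
window = rng.normal(size=(K, D_IN))    # the K cached input rows for this time step

def dot_product(x_row, w_col):
    # D_IN independent multiplications followed by an accumulation
    products = [x_row[i] * w_col[i] for i in range(len(x_row))]
    return sum(products)

def row_by_matrix_multiply(x_row, w_mat):
    # D_OUT independent dot products, one per output column
    return np.array([dot_product(x_row, w_mat[:, j]) for j in range(w_mat.shape[1])])

def conv1d_step(window, w, b):
    # K independent row_by_matrix_multiplys, summed, plus bias, then relu
    acc = sum(row_by_matrix_multiply(window[k], w[k]) for k in range(K))
    return np.maximum(acc + b, 0.0)

out = conv1d_step(window, w, b)
reference = np.maximum(np.einsum('ki,kio->o', window, w) + b, 0.0)
print(np.allclose(out, reference))
```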
the reason for being so explicit about what is independent, and what isn't, comes into play with the verilog version.
verilog is a language used to "program" FPGAs and it's not at all like other languages; it's been really interesting to learn.
in the MCU version the two big concerns were
but the FPGA version is a little different. instead we have more flexibility to design things based on what we want to run in parallel, vs what we want to run sequentially. for neural networks this gives lots of options for design!
at a 30,000" view of verilog we have two main concerns; 1) executing code in parallel 2) executing blocks of code sequentially.
e.g. consider a dot product A.B with |A|=|B|=4
normally we'd think of this as simply something like a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3], provided by a single np.dot call.
but if we're writing verilog we have to think of things in terms of hardware, and that means considering parallel vs sequential.
the simplest sequential way would be like the following pseudo code...
// WARNING PSEUDO CODE!
case(state):
  multiply_0:
    // recall; the following three statements are run _in parallel_
    accum <= 0                  // set accumulator to 0
    product <= a[0] * b[0]      // set intermediate product variable to a0.b0
    state <= multiply_1         // set the next state
  multiply_1:
    accum <= accum + product    // update accumulator with the product value from the last state
    product <= a[1] * b[1]      // set product variable to a1.b1
    state <= multiply_2         // set the next state
  multiply_2:
    accum <= accum + product
    product <= a[2] * b[2]
    state <= multiply_3
  multiply_3:
    accum <= accum + product
    product <= a[3] * b[3]
    state <= final_accumulate
  final_accumulate:
    accum <= accum + product
    state <= done
  done:
    // final result available in `accum`
    state <= done
doing things this way does the dot product in 5 cycles...
an important aspect of how verilog works is to note that the statements in one of those case clauses all run in parallel. i.e. all right hand sides are evaluated and then assigned to the left hand side
e.g. the following would swap a and b
a <= b; b <= a;
and is functionally equivalent to ...
b <= a; a <= b;
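to make the non-blocking semantics concrete, here's a small Python simulation ( purely illustrative, this isn't how verilog is actually simulated ): each loop iteration is one clock cycle in which every right hand side is computed from the current register values before any of them are assigned. nonblocking and run_sequential_dot are hypothetical helpers; the second one walks the sequential dot product state machine and counts its cycles.

```python
def nonblocking(regs, assignments):
    """apply one clock cycle of non-blocking assignments. assignments is a list
    of (register_name, fn) pairs; every fn sees the *old* register values, and
    only then are all the new values assigned."""
    new_values = {name: fn(regs) for name, fn in assignments}
    return {**regs, **new_values}

# a <= b; b <= a;  swaps, and so does the reversed order
regs = {'a': 1, 'b': 2}
swapped = nonblocking(regs, [('a', lambda r: r['b']), ('b', lambda r: r['a'])])
swapped_reversed = nonblocking(regs, [('b', lambda r: r['a']), ('a', lambda r: r['b'])])

def run_sequential_dot(a, b):
    """simulate the sequential 4 element dot product state machine; returns (result, cycles)"""
    regs = {'state': 'multiply_0', 'accum': 0, 'product': 0}
    cycles = 0
    while regs['state'] != 'done':
        s = regs['state']
        if s == 'multiply_0':
            step = [('accum', lambda r: 0),
                    ('product', lambda r: a[0] * b[0]),
                    ('state', lambda r: 'multiply_1')]
        elif s in ('multiply_1', 'multiply_2', 'multiply_3'):
            i = int(s[-1])
            next_state = 'final_accumulate' if i == 3 else f'multiply_{i + 1}'
            step = [('accum', lambda r: r['accum'] + r['product']),
                    ('product', lambda r: a[i] * b[i]),
                    ('state', lambda r: next_state)]
        else:  # final_accumulate
            step = [('accum', lambda r: r['accum'] + r['product']),
                    ('state', lambda r: 'done')]
        regs = nonblocking(regs, step)  # commit everything at the clock edge
        cycles += 1
    return regs['accum'], cycles

print(swapped, swapped_reversed)
print(run_sequential_dot([1, 2, 3, 4], [5, 6, 7, 8]))
```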
so, thinking in terms of sequential vs parallel, we have the option to do more than one multiplication at any given time. this requires the hardware to be able to support 2 multiplications per clock cycle, but saves some clock cycles. not much in this case, but it adds up as |a| and |b| increase...
// WARNING PSEUDO CODE!
case(state):
  multiply_01:
    accum_0 <= 0
    accum_1 <= 0
    product_0 <= a[0] * b[0]
    product_1 <= a[1] * b[1]
    state <= multiply_23
  multiply_23:
    accum_0 <= accum_0 + product_0
    accum_1 <= accum_1 + product_1
    product_0 <= a[2] * b[2]
    product_1 <= a[3] * b[3]
    state <= accumulate_0
  accumulate_0:
    accum_0 <= accum_0 + product_0
    accum_1 <= accum_1 + product_1
    state <= accumulate_1
  accumulate_1:
    accum_0 <= accum_0 + accum_1
    state <= done
  done:
    // final result available in `accum_0`
    state <= done
or, if we want to go nuts, and, again, can support it, we can do all the elements multiplications at the same time, and then hierarchically accumulate the result into one.
// WARNING PSEUDO CODE!
case(state):
multiply_all:
// calculate all 4 elements of dot product in parallel; ( note: requires 4 available multiplication units )
product_0 <= a[0] * b[0]
product_1 <= a[1] * b[1]
product_2 <= a[2] * b[2]
product_3 <= a[3] * b[3]
state <= accumulate_0
accumulate_0:
// add p0 and p1 at the same time as p2 and p3
accum_0 <= product_0 + product_1
accum_1 <= product_2 + product_3
accumulate_1:
// final add of the two
accum_0 <= accum_0 + accum_1
state <= done
done:
// final result available in `accum0`
state <= done
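as a sanity check of the fully parallel version, here's a rough python model of that state machine, assuming accumulate_0 moves on to accumulate_1. each loop iteration is one clock; all the updates for a state are computed from the old register values and then applied together, mimicking the non-blocking assignments. run_parallel_dot is a hypothetical name for illustration.

```python
def run_parallel_dot(a, b):
    """Simulate the 'multiply_all' -> 'accumulate_0' -> 'accumulate_1' machine
    for 4 element vectors. Returns (dot product, clock cycles until done)."""
    regs = {"state": "multiply_all"}
    cycles = 0
    while regs["state"] != "done":
        s = regs["state"]
        if s == "multiply_all":
            # 4 multiplies in the same clock (needs 4 multiply units)
            updates = {f"product_{i}": a[i] * b[i] for i in range(4)}
            updates["state"] = "accumulate_0"
        elif s == "accumulate_0":
            # two adds in parallel
            updates = {"accum_0": regs["product_0"] + regs["product_1"],
                       "accum_1": regs["product_2"] + regs["product_3"],
                       "state": "accumulate_1"}
        elif s == "accumulate_1":
            updates = {"accum_0": regs["accum_0"] + regs["accum_1"],
                       "state": "done"}
        else:
            raise ValueError(f"unknown state {s}")
        regs = {**regs, **updates}  # apply every update at the clock edge
        cycles += 1
    return regs["accum_0"], cycles

print(run_parallel_dot([1, 2, 3, 4], [5, 6, 7, 8]))  # (70, 3)
```

so 3 clocks to a result, vs 4 for the two-at-a-time version and 5 for the fully sequential one.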
this general idea of how much we do in a single clock cycle versus making values available for the next clock cycle gives a lot of flexibility for a design
specifically for this neural net i've represented the K=4 conv 1d as
a dot_product module where the elements of dot products are calculated sequentially; i.e. x[0]*w[0] in the first step, x[1]*w[1] in the second, etc. so a dot product of N values takes N*M clock cycles ( where M is the number of cycles the multiply unit takes ) + some parallel accumulation.
a row_by_matrix_multiply module which runs the j=16 dot products required for each row_by_matrix_multiply in parallel
a conv1d module that runs the K=4 row_by_matrix_multiplys also in parallel, as well as handling the state machine for accumulating results with a bias and applying relu.
so we end up having all j=16 * K=4 = 64 dot products run in parallel, all together.
having said this, there are a number of ways to restructure this; e.g. if there were too many dot products to run in parallel for the 4 row_by_matrix_multiplys we could run 2 of them in parallel, and then when they were finished, run the other 2. there are loads of trade offs between the number of multiply units available vs the time required to run them.
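a back of the envelope way to see this trade off is a tiny cost model. this is a hypothetical sketch, not measured from the design: it assumes each dot product uses one multiply unit and runs its N multiplies sequentially ( N*M cycles ), all j dot products of a row_by_matrix_multiply run at once, k_parallel of the K row_by_matrix_multiplys run together, and it ignores the accumulation cycles.

```python
import math

def conv1d_cost(n_in, j_out, k_taps, mult_cycles, k_parallel):
    """Hypothetical (multiply units, clock cycles) estimate for one conv1d step.

    n_in: dot product length, j_out: dot products per row_by_matrix_multiply,
    k_taps: kernel size K, mult_cycles: cycles per multiply (M),
    k_parallel: how many of the K row_by_matrix_multiplys run at once.
    """
    multipliers = j_out * k_parallel          # units busy at the same time
    batches = math.ceil(k_taps / k_parallel)  # rounds needed to cover all K taps
    cycles = batches * n_in * mult_cycles     # each round is one sequential dot product
    return multipliers, cycles

print(conv1d_cost(16, 16, 4, 1, k_parallel=4))  # (64, 16) all at once
print(conv1d_cost(16, 16, 4, 1, k_parallel=2))  # (32, 32) 2 then the other 2
```

halving the multiply units roughly doubles the cycles; whether that's acceptable depends on how many clocks we have between samples.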
in the MCU version i was only feeding in samples based on the embedding corner points, one of 4 types sampled randomly from...
input ( core wave, e0, e1 ) | output |
---|---|
(triangle, 0, 0) | sine |
(triangle, 0, 1) | ramp |
(triangle, 1, 1) | zigzag |
(triangle, 1, 0) | square |
for the first FPGA version i did this but the model was large enough that it was quickly overfitting to this and basically outputting noise for the intermediate points. as such for training of this model i changed things a bit to include interpolated data.
basically we emit corner points, say (e0=0, e1=1, sine wave) as well as interpolated points, where we pick two waves, say sine and ramp, and a random point between them and train for that point as an interpolated wave between the two ( using constant power interpolation )
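the post doesn't give the exact interpolation formula, so this is a sketch assuming the common constant ( equal ) power crossfade law, where the two gains are cos(t*pi/2) and sin(t*pi/2) so their squares always sum to 1. constant_power_mix is a hypothetical helper name.

```python
import numpy as np

def constant_power_mix(wave_a, wave_b, t):
    """Crossfade from wave_a (t=0) to wave_b (t=1) with constant power gains."""
    gain_a = np.cos(t * np.pi / 2)
    gain_b = np.sin(t * np.pi / 2)
    return gain_a * wave_a + gain_b * wave_b

# e.g. a training target 30% of the way from sine to ramp
phase = np.linspace(0, 1, 256, endpoint=False)
sine = np.sin(2 * np.pi * phase)
ramp = 2 * phase - 1
target = constant_power_mix(sine, ramp, 0.3)
```

a plain linear mix ( 1-t, t ) would dip in loudness around t=0.5 for uncorrelated waves, which is the usual reason to prefer the constant power version.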
though i couldn't get the MCU model to converge well with this kind of data, the larger FPGA variant has no problems.
i also messed around with 3d input embeddings; to translate between any pairing, but it didn't add anything really so i stuck with 2d.
whereas the final MCU model was ...
---------------------------------------------------------------------------
Layer (type)          Output Shape        Par#   Conv1D params
---------------------------------------------------------------------------
input (InputLayer)    [(None, 256, 3)]    0
c0a (Conv1D)          (None, 64, 4)       52     F=4, K=4, D=1, P=causal
c0b (Conv1D)          (None, 64, 4)       20     F=4, K=1
c1a (Conv1D)          (None, 16, 4)       68     F=4, K=4, D=4, P=causal
c1b (Conv1D)          (None, 16, 4)       20     F=4, K=1
c2a (Conv1D)          (None, 4, 4)        68     F=4, K=4, D=16, P=causal
c2b (Conv1D)          (None, 4, 4)        20     F=4, K=1
c3a (Conv1D)          (None, 1, 8)        136    F=8, K=4, D=64, P=causal
c3b (Conv1D)          (None, 1, 8)        72     F=8, K=1
y_pred (Conv1D)       (None, 1, 1)        9      F=1, K=1
---------------------------------------------------------------------------
Trainable params: 465
---------------------------------------------------------------------------
... the current FPGA version i'm running is ...
-----------------------------------------------------------------
Layer (type)               Output Shape       Param #
-----------------------------------------------------------------
input_1 (InputLayer)       [(None, 64, 4)]    0
qconv_0 (QConv1D)          (None, 16, 16)     272
qrelu_0 (QActivation)      (None, 16, 16)     0
qconv_1 (QConv1D)          (None, 4, 16)      1040
qrelu_1 (QActivation)      (None, 4, 16)      0
qconv_2 (QConv1D)          (None, 1, 4)       260
-----------------------------------------------------------------
Trainable params: 1,572
-----------------------------------------------------------------
it has the following differences
so compared to the MCU version
to be honest utilisation is a bit harder to compare; the trade off between compute and space is quite different with an FPGA design.
for each sample coming in at 192kHz the FPGA is running a simple state machine of 1) accept next sample 2) run the sequence of qconvs and activation caches, then 3) output the result and sit in a while-true loop until the next sample. when i say above the FPGA is running at 30% what i really should say is that it's spending 70% of the time in the post sample processing while-true loop waiting for the next sample.
looking at the device utilisation we have the following...
Info: Device utilisation:
Info:     TRELLIS_IO:      11/  365     3%
Info:     DCCA:             5/   56     8%
Info:     DP16KD:          24/  208    11%
Info:     MULT18X18D:     134/  156    85%
Info:     EHXPLLL:          1/    4    25%
Info:     TRELLIS_FF:   21081/83640    25%
Info:     TRELLIS_COMB: 44647/83640    53%
Info:     TRELLIS_RAMW:   192/10455     1%
the pieces of interest are...
DP16KD: which is the amount of ( one type of ) RAM being used; this looks to be dominated by the activation cache, so with only 11% being used there is a lot of room for having more layers.
MULT18X18D: is the big one, it's the max number of multiplication DSP units being used at any one time. in this model that's qconv_1 with in_dim = out_dim = 16. since it's already at 85% if we wanted to increase the filter size much more we might be forced to not run all 16 dot products of the 16x16 row_by_matrix_multiply at once but instead, say, do 8 in parallel, then the other 8.
this would incur a latency hit, but that's totally fine given we still have a lot of clock time available to do work between samples.
the trouble is the code as written would end up being tricky to refactor.
currently things are set up so that the entire network has to run before the next sample comes in. this is just because it was the simplest thing to do while i'm learning verilog and it seems like the FPGA is fast enough for it. but with a neural net it doesn't have to be like that; we really just need to finish the first layer before the next sample comes, not the whole network. as long as we don't mind a little bit of output latency we can run a layer per sample clock tick. doing this would delay the output by number-of-layers sample clock ticks, but at 192kHz that'd be fine :)
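a minimal python sketch of the layer-per-tick idea, assuming each layer can be treated as a function of a single value ( the real conv layers also carry their activation caches ). every tick each stage works on what the previous stage finished last tick, so a sample's result pops out len(layers) ticks later. pipelined is a hypothetical helper.

```python
def pipelined(samples, layers):
    """Run one layer stage per sample tick; returns what is emitted each tick
    (None until the pipeline has filled)."""
    stage_out = [None] * len(layers)
    outputs = []
    for x in samples:
        outputs.append(stage_out[-1])  # emit the result finished on the previous tick
        # update right to left so each stage reads last tick's value from its input stage
        for i in reversed(range(len(layers))):
            inp = x if i == 0 else stage_out[i - 1]
            stage_out[i] = None if inp is None else layers[i](inp)
    return outputs

layers = [lambda v: v + 1, lambda v: v * 2, lambda v: v - 3]
print(pipelined(range(6), layers))  # [None, None, None, -1, 1, 3]
```

the result for sample 0, (0+1)*2-3 = -1, arrives 3 ticks later, i.e. one tick per layer.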
another way to run a bigger network is to continue using the same naive MULT18X18D dsp allocation but just use an intermediate layer twice; e.g. if you have a network input -> conv0 -> output you can get extra depth by running input -> conv0 -> conv0 -> output instead. you lose a bit of representation power, since the same layer needs to model two layers, but sometimes it's worth it. in this model we'd get extra depth without having to worry about more allocation, and we've got plenty of headroom to do more compute.
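here's the layer reuse idea as a numpy sketch. conv0 is modelled as a 1x1 conv ( i.e. a per-timestep matmul ) to keep it short, and the shapes are made up; the point is just that the second pass reuses exactly the same weights, so on the FPGA it would reuse the same multiply units.

```python
import numpy as np

rng = np.random.default_rng(0)
C = 16
W_conv0 = rng.normal(size=(C, C)) * 0.25  # the single shared layer
b_conv0 = np.zeros(C)
W_out = rng.normal(size=(C, 1))

def relu(v):
    return np.maximum(v, 0.0)

def forward(x, n_passes):
    """x: (T, C). n_passes=1 is input -> conv0 -> output,
    n_passes=2 is input -> conv0 -> conv0 -> output (identical weights)."""
    h = x
    for _ in range(n_passes):
        h = relu(h @ W_conv0 + b_conv0)
    return h @ W_out

x = rng.normal(size=(10, C))
shallow, deeper = forward(x, 1), forward(x, 2)  # same parameter count, more depth
```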
the work in progress model i've been tinkering with for the power of two quantisation is the following...
_________________________________________________________________
Layer (type)                   Output Shape       Param #
_________________________________________________________________
input_1 (InputLayer)           [(None, 640, 4)]   0
qconv_0_qb (QConv1D)           (None, 640, 8)     136
qrelu_0 (QActivation)          (None, 640, 8)     0
qconv_1_qb (QConv1D)           (None, 640, 8)     264
qrelu_1 (QActivation)          (None, 640, 8)     0
qconv_1_1a_po2 (QConv1D)       (None, 640, 16)    144
qconv_1_1b_po2 (QConv1D)       (None, 640, 8)     136
qrelu_1_1 (QActivation)        (None, 640, 8)     0
qconv_1_2a_po2 (QConv1D)       (None, 640, 16)    144
qconv_1_2b_po2 (QConv1D)       (None, 640, 8)     136
qrelu_1_2 (QActivation)        (None, 640, 8)     0
qconv_2_qb (QConv1D)           (None, 640, 4)     132
_________________________________________________________________
Trainable params: 1,092
_________________________________________________________________
_qb layers are the normal quantised bits layers which have fixed point weights and use MULT18X18D units for inference.
_po2 layers have power-of-two weights and use just shift operators for inference.
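why shifts? if a weight is ±2^e then multiplying a fixed point activation by it is just a shift by |e| plus a sign. this numpy sketch rounds weights to the nearest power of two in log2 space ( an assumption; not necessarily the exact rounding the training quantiser uses ) and checks the shift-only multiply against the real product. quantise_po2 and po2_multiply are hypothetical names.

```python
import numpy as np

def quantise_po2(w):
    """Round non-zero weights to sign * 2**exponent (nearest in log2 space)."""
    sign = np.sign(w).astype(int)
    exponent = np.round(np.log2(np.abs(w))).astype(int)
    return sign, exponent

def po2_multiply(x_fixed, sign, exponent):
    """Multiply integer (fixed point) activations by power-of-two weights using
    only shifts: left shift for exponent >= 0, arithmetic right shift otherwise."""
    shifted = np.where(exponent >= 0,
                       x_fixed << np.maximum(exponent, 0),
                       x_fixed >> np.maximum(-exponent, 0))
    return sign * shifted

w = np.array([0.26, -0.5, 3.1, -0.12])
sign, exponent = quantise_po2(w)       # ~ [0.25, -0.5, 4, -0.125]
x = np.array([64, 64, 64, 64])         # integer activations
print(po2_multiply(x, sign, exponent))  # [  16  -32  256   -8]
```

note right shifts floor, so small activations lose low bits; that's part of the accuracy cost of po2 weights.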
the output quality is the same and it uses fewer MULT18X18D units but doesn't quite fit :)
Info: Device utilisation:
Info:     TRELLIS_IO:      11/  365     3%
Info:     DCCA:             5/   56     8%
Info:     DP16KD:          12/  208     5%
Info:     MULT18X18D:      70/  156    44%
Info:     EHXPLLL:          1/    4    25%
Info:     TRELLIS_FF:   37356/83640    44%
Info:     TRELLIS_COMB: 85053/83640   101%   close!!
Info:     TRELLIS_RAMW:    96/10455     0%
i spent a bunch of time trying various combos of reuse of the modules but never had a design that would meet timing for the FPGA :/
i still feel i'm doing something wrong here, and might come back to it.
waveforms generated by the model across the embedding space
the above images show the range of waveforms generated by the model across the two dimensional embedding space. of note...
let's look at some examples!
( make sure subtitles are enabled! )
an example of a triangle core wave as input and a manual transition of the embedding values between the corners of the 2d space.
modulating the embedding x value at audio rates makes for some great timbres! the FPGA and the eurorack pmod have no problems handling this.
since the model was trained only on a triangle wave if you give it something discontinuous, like a ramp or square, it glitches! :)
and if you sequence things, add an envelope to the embedding point & stick in some 909 kick & hats.... what do you have? neural net techno! doff. doff. some delay on the hats and clap, but no effects on the oscillator. ( mixed the hats too loud as well, story of my life )
there is a lot of room to optimise for a larger network. see the issues on github
the code for training and verilog simulation is in this github repo
whereas the code representing this as core on the eurorack pmod is in this github repo on a branch
the electro smith daisy patch is a eurorack module made for prototyping; it provides a powerful audio focussed microcontroller setup (the daisy 'seed' with arm cortex-m7 running at 480MHz) along with all the connectivity required to be a eurorack module.
pretty much the first thing i thought when i saw one was; "could it run a neural net?" microcontrollers are fast, but so too is audio rate.
after a bit of research i found there are a couple of people who have done some real time audio effects processing on the daisy seed; e.g. guitarML's neural seed
all the examples i found though were quite small recurrent models and i was keen to give wavenet a go instead ( i'm a big fan of causal dilated convolutions )
the wavenet architecture is a 1D convolutional network designed to operate on timeseries data. it is composed of two key structures;
usually it's just dilated convolutions stacked but for my variant i also include one extra 1x1 conv between each dilation; a 1x1 doesn't break any of the dilation structure wavenet needs and including another non linearity is almost always a good thing.
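for reference, a causal dilated conv1d is small enough to write out in numpy. this is a sketch, not the keras implementation: it uses stride 1, zero history before the start, and puts kernel tap K-1 on the current step. a 1x1 conv is just the K=1 case, i.e. a per-timestep matrix multiply. causal_dilated_conv1d is a hypothetical name.

```python
import numpy as np

def causal_dilated_conv1d(x, w, b, dilation):
    """x: (T, C_in), w: (K, C_in, C_out), b: (C_out,).
    y[t] = b + sum_k x[t - (K-1-k)*dilation] @ w[k], with zeros before t=0,
    so output t never sees inputs later than t."""
    k_taps = w.shape[0]
    pad = (k_taps - 1) * dilation
    x_padded = np.concatenate([np.zeros((pad, x.shape[1])), x], axis=0)
    T = x.shape[0]
    y = np.tile(b, (T, 1)).astype(float)
    for k in range(k_taps):
        start = k * dilation  # tap k looks back (K-1-k)*dilation steps
        y += x_padded[start:start + T] @ w[k]
    return y

rng = np.random.default_rng(1)
x = rng.normal(size=(20, 3))
w = rng.normal(size=(4, 3, 2))
b = rng.normal(size=2)
y = causal_dilated_conv1d(x, w, b, dilation=4)  # (20, 2)
```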
the first step was to write a firmware that just recorded some audio and/or control voltages into buffers and then streamed them out over a serial connection. it's kinda clumsy, and i'm sure there's a cleaner way, but it works without having to mess around too much. the code is datalogger_firmware/
from there a wavenet could be trained ( see keras_model.py ) and it was trivial to quantise and export to a c++ lib for the microcontroller using edge impulse's bring-your-own-model
from there i got a bit stuck though integrating tensorflow with the daisy optimised arm_math.h stuff ( Make files and such aren't my expertise ) so instead i thought i'd use this chance to write my own inference. convolutions don't need much, it's all just a bunch of matrix math after all.
while poking around doing this it suddenly occurred to me that you could do heavy heavy caching of the convolution activations! i was super excited, thinking i was doing something novel, and it wasn't until i was actually finished that i found out someone else had basically thought of the same idea, what they call fast-wavenet :/
oh well, it was fun to discover independently i suppose, and i ended up implementing things a bit differently.
so let's walk through the caching optimisation...
consider the following wavenet like network; it has 16 inputs being integrated thru 3 layers to a single prediction
if we consider a sliding window input at time step 4 we see that the output of node 0 ( the node with the red 0 ) will be the processed values from [0, 1, 2, 3]
a bit later on an interesting thing happens at time step 8; node 1 now gets the same inputs that node 0 had 4 steps ago.
so we don't need to calculate it! when processing a timeseries we can see that the node 1 output just lags node 0 by 4 steps, and nodes 2 and 3 lag another 4 each. as long as we've got the memory for caching we can store all these in a circular buffer.
and this whole thing is stackable! in fact we only ever need to run the right hand side convolutions, as long as we have the memory to cache.
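here's the caching trick checked in a small python sketch ( not the daisy c++ ). assumptions: a single channel, tanh activations, and two K=4 layers with dilations 1 and 4, so a 16 step receptive field like the diagram. the streaming version computes just one layer 0 node and one layer 1 node per step, reading the older layer 0 outputs at lags 4, 8 and 12 from a cache; collections.deque with maxlen stands in for the circular buffer.

```python
import numpy as np
from collections import deque

K = 4
rng = np.random.default_rng(0)
w0 = rng.normal(size=K)  # layer 0: K=4, dilation 1
w1 = rng.normal(size=K)  # layer 1: K=4, dilation 4

def recompute_all(window16):
    """Newest output computed from scratch from the last 16 inputs."""
    h = [np.tanh(window16[i:i + K] @ w0) for i in range(0, 16, K)]  # h[t-12], h[t-8], h[t-4], h[t]
    return np.tanh(np.array(h) @ w1)

class CachedNet:
    """Streaming version: per step, ONE layer 0 node and ONE layer 1 node."""
    def __init__(self):
        self.inputs = deque([0.0] * K, maxlen=K)      # last K raw inputs
        self.h_cache = deque([0.0] * 13, maxlen=13)   # layer 0 outputs h[t-12] .. h[t]

    def step(self, x_t):
        self.inputs.append(x_t)
        self.h_cache.append(np.tanh(np.array(self.inputs) @ w0))
        lagged = np.array([self.h_cache[i] for i in (0, 4, 8, 12)])
        return np.tanh(lagged @ w1)

xs = rng.normal(size=50)
padded = np.concatenate([np.zeros(15), xs])  # zero history before the stream starts
net = CachedNet()
cached = [net.step(v) for v in xs]
full = [recompute_all(padded[t:t + 16]) for t in range(len(xs))]
```

both give the same outputs, but the cached version does 2 node evaluations per step instead of 5.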
one big win for this on the daisy is that it can all run fast and small enough that we can use full float32 math everywhere. hoorah!
in terms of coding the whole thing becomes an exercise in
the following model has an input of 3 values => a receptive field of 256 steps, and a single output. it runs on the daisy using the above caching technique at 48kHz @88% CPU ( ~20 micro seconds per inference ). we'll describe what the 3 input values are in a bit. it's the model we'll use for the examples at the end.
cNa - denotes the dilated convolutions
cNb - denotes the 1x1 convolutions that follow cNa
F - number of filters
K - kernel size; either 4 for cNa or 1 for cNb
D - dilation; K^layer# for cNa, or 1 for cNb
P - padding; 'causal' or the layer default 'valid'
____________________________________________________________________________________________
Layer (type)          Output Shape        Par#   Conv1D params
============================================================================================
input (InputLayer)    [(None, 256, 3)]    0
c0a (Conv1D)          (None, 64, 4)       52     F=4, K=4, D=1, P=causal
c0b (Conv1D)          (None, 64, 4)       20     F=4, K=1
c1a (Conv1D)          (None, 16, 4)       68     F=4, K=4, D=4, P=causal
c1b (Conv1D)          (None, 16, 4)       20     F=4, K=1
c2a (Conv1D)          (None, 4, 4)        68     F=4, K=4, D=16, P=causal
c2b (Conv1D)          (None, 4, 4)        20     F=4, K=1
c3a (Conv1D)          (None, 1, 8)        136    F=8, K=4, D=64, P=causal
c3b (Conv1D)          (None, 1, 8)        72     F=8, K=1
y_pred (Conv1D)       (None, 1, 1)        9      F=1, K=1
============================================================================================
Total params: 465
____________________________________________________________________________________________
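the numbers in that summary can be checked by hand. each causal conv with kernel K and dilation D extends the receptive field by (K-1)*D steps ( the 1x1 cNb convs add nothing ), and a Conv1D has K*C_in*F weights plus F biases.

```python
K = 4
dilations = [1, 4, 16, 64]  # D for c0a..c3a, i.e. K**layer
receptive_field = 1 + sum((K - 1) * d for d in dilations)
print(receptive_field)  # 256

def conv1d_params(c_in, filters, kernel):
    return kernel * c_in * filters + filters  # weights + biases

# (name, input channels, filters, kernel size)
layers = [("c0a", 3, 4, 4), ("c0b", 4, 4, 1), ("c1a", 4, 4, 4), ("c1b", 4, 4, 1),
          ("c2a", 4, 4, 4), ("c2b", 4, 4, 1), ("c3a", 4, 8, 4), ("c3b", 8, 8, 1),
          ("y_pred", 8, 1, 1)]
total_params = sum(conv1d_params(c, f, k) for _, c, f, k in layers)
print(total_params)  # 465
```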
as a side note, the daisy can actually run at 96kHz but at this rate i could only run a smaller [c0a, c0b, c1a, c1b] model with just 2 filters each. it runs, but i couldn't get it to train on the examples i show below, so i didn't use it. shame because naming the blog post "at (almost) 100,000 inferences a second" would have had a nicer ring to it :D
it can also run slower, e.g. 32kHz, which allows either more filters per step, or even more depth => a larger receptive field.
but these combos demonstrate an interesting set of trade offs we have between
for training i just use keras.layers.Conv1D everywhere but then export to the device in two passes; a python prototype and then the c++ code for the device.
the final code on the daisy uses the cmsis-dsp library for all the matrix math. only three pieces end up being used though
arm_mat_init_f32 for making the matrix structures,
arm_mat_mult_f32 for the actual matrix multiplications and
arm_add_f32 for the kernel accumulation and adding biases here and there.
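to see why those three calls are enough, here's one output step of a K tap conv written as K ( 1 x C_in ) @ ( C_in x C_out ) matrix multiplies plus vector adds, in float32. this is numpy standing in for the cmsis-dsp calls, not the cmsis api itself; conv_step is a hypothetical name.

```python
import numpy as np

def conv_step(taps, weights, bias):
    """One output step of a K tap conv.
    taps: K input vectors (C_in,), oldest first; weights: (K, C_in, C_out); bias: (C_out,)."""
    acc = bias.astype(np.float32).copy()
    for x_k, w_k in zip(taps, weights):
        row = x_k.reshape(1, -1).astype(np.float32)   # 1 x C_in matrix
        product = row @ w_k.astype(np.float32)        # the matrix multiply
        acc = acc + product.ravel()                   # kernel accumulation (vector add)
    return acc

rng = np.random.default_rng(0)
taps = [rng.normal(size=4) for _ in range(4)]
weights = rng.normal(size=(4, 4, 8))
bias = rng.normal(size=8)
out = conv_step(taps, weights, bias)  # (8,) float32
```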
to be honest none of these were benchmarked and i assume they are faster than if the multiplications were just rolled out (???)
the first pass was to get the inference working as a prototype in python. the cmsisdsp python lib was used ( a pure python api equivalent ) and the code is cmsisdsp_py_version/.
having cmsisdsp was very handy to prototype, especially as i've not done any cmsis stuff before.
the code for the daisy is under inference_firmware/ and is a port of the python version. it's where i probably spent the most time, especially block.h
it has 4 main parts
left_shift_buffer.h to handle the shifted InputLayer
block.h to handle the sequential running of each cNa and cNb pair
rolling_cache.h which represents the circular buffer used for activation lag and
regression.h which is just a simple y=mx+b using the weights from the final 1x1 Conv1D regression used during training
it all ends up being pretty straightforward but i did iterate a bit; any code involving memcpy and pointer arithmetic needs careful attention.
the "best" bit of the code is where the initial keras model is trained in a python notebook and then exported to be used in the c++ code by having python code construct a model_defn.h with a bunch of print statements :/ it's actually written to /tmp/model_defn.h no less! what a hack, lol. may god have mercy on my soul.
for a demo project i wanted to make a little waveshaper; something that takes one waveform and outputs another.
so 5 waveforms were collected from another eurorack oscillator; a triangle, sine, saw, square and a weird zigzag thing. during data collection random voltages were sampled to change the oscillator's frequency.
to convert to an actual training dataset the triangle wave is used as input with one of the other waves acting as output. which output is decided by including two additional selector variables in the inputs. these variables are only ever {0, 1} during training but can act (very loosely) as an embedding since any float value in the range (0, 1) can be passed in when running on the device.
input | output |
---|---|
(triangle, 0, 0) | sine |
(triangle, 0, 1) | ramp |
(triangle, 1, 0) | square |
(triangle, 1, 1) | zigzag |
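concretely, each training input is the triangle wave plus two constant selector channels ( the x2 and x3 mentioned below ) picked per the table above. a numpy sketch; make_example and SELECTORS are hypothetical names.

```python
import numpy as np

SELECTORS = {"sine": (0, 0), "ramp": (0, 1), "square": (1, 0), "zigzag": (1, 1)}

def make_example(triangle, target_name):
    """triangle: (T,) core wave -> inputs of shape (T, 3): [triangle, x2, x3]."""
    x2, x3 = SELECTORS[target_name]
    T = triangle.shape[0]
    return np.stack([triangle,
                     np.full(T, x2, dtype=float),
                     np.full(T, x3, dtype=float)], axis=1)

phase = np.linspace(0, 1, 256, endpoint=False)
triangle = 1 - 4 * np.abs(phase - 0.5)  # a triangle wave in [-1, 1]
inputs = make_example(triangle, "square")
```

on the device nothing stops us filling the selector channels with values like 0.2 instead, which is exactly the out of distribution case explored next.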
the model was trained for < 1 min ( the model is pretty small, and there's not much variety in the data ).
it does well on held out test data...
but the much more interesting thing is what happens when we input values for x2 and x3 that weren't seen in the training data
e.g. what if we choose a point that is 20% between sine (0, 0) and square (1, 0) ? we end up with some weird non defined part of the input space.
these out of training distribution things are always the most interesting to me; values for x2 and x3 in between (0, 1) give a weird sort-of-interpolation. generally these results are never a smooth transition unless there's some aspect of the loss that directs it to be so (and in this case there isn't).
we get the classic model hallucination stuff that i've always loved. ( see my older post on hallucinating softmaxes for more info )
we could encourage the model to make full use of the space by including a GAN like discriminator loss. this would be trained on random samples of values for (x2, x3). i've seen this kind of training force the model to be much more consistent for interpolated inputs.
this video of the network actually running gives a better idea of the "interpolation". the green wave shows the input core triangle wave, the blue wave shows the output waveshaped result. the daisy module display shows the values for x2 and x3; with these values controlled by hand from the module on the right.
the video transitions across values of (x2, x3), moving between the corner (0, 1) values. at each corner the frequency of the core triangle input wave was adjusted over a range as well. something i hadn't considered is that when we change the frequency to be outside the range of what was seen during training the model does some weird extrapolation.
the interpolation extrapolation stuff makes some weird things, but it all sounds cool :)
further ideas could be...
code on github
here's a recording; check it out!
and here's a pdf of the slides
eurosat/all is a dataset of 27,000 64x64 satellite images taken with 13 spectral bands. each image is labelled one of ten classes.
for the purpose of classification these 13 aren't equally useful, and the information in them varies across resolutions. if we were designing a sensor we might choose to use different channels in different resolutions.
how can we explore the trade off between mixed resolutions and whether to use a channel at all?
let's start with a simple baseline to see what performance we get. we won't spend too much time on this model, we just want something that we can iterate on quickly.
the simple model shown below trained on a 'training' split with adam hits 0.942 top 1 accuracy on a 2nd 'validation' split in 5 epochs. that'll do for a start.
let's check the effect of including different combos of input channels. we'll do so by introducing a channel mask.
a mask of all ones denotes using all channels and gives our baseline performance
mask | validation accuracy |
---|---|
[1,1,1,1,1,1,1,1,1,1,1,1,1] | 0.942 |
a mask of all zeros denotes using no channels and acts as a sanity check; it gives the performance of random chance which is in line with what we expect given the balanced training set. (note: we standardise the input data so that it has zero mean per channel (with the mean, standard deviation parameters fit against training data only) so we can get this effect)
mask | validation accuracy |
---|---|
[1,1,1,1,1,1,1,1,1,1,1,1,1] | 0.942 |
[0,0,0,0,0,0,0,0,0,0,0,0,0] | 0.113 |
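a numpy sketch of why zeroing a channel "removes" it: standardise with per channel stats fit on the training split only, then multiply by the mask, so a masked channel is stuck at the training mean and carries no information. fit_standardiser and standardise_and_mask are hypothetical helpers, and the (N, H, W, C) layout is an assumption.

```python
import numpy as np

def fit_standardiser(train_x):
    """Per channel mean / std from the training split only; train_x: (N, H, W, C)."""
    mean = train_x.mean(axis=(0, 1, 2))
    std = train_x.std(axis=(0, 1, 2)) + 1e-8
    return mean, std

def standardise_and_mask(x, mean, std, mask):
    """After standardising, 0 == the training mean, so masked channels are uninformative."""
    return ((x - mean) / std) * np.asarray(mask, dtype=x.dtype)

rng = np.random.default_rng(0)
train = rng.normal(5.0, 2.0, size=(10, 4, 4, 13)).astype(np.float32)
mean, std = fit_standardiser(train)
keep_all_but_11 = [1] * 11 + [0, 1]
out = standardise_and_mask(train, mean, std, keep_all_but_11)
```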
but what about if we drop just one channel? i.e. a mask of all ones except for a single zero.
channel to drop | validation accuracy |
---|---|
0 | 0.735 |
1 | 0.528 |
2 | 0.661 |
3 | 0.675 |
4 | 0.809 |
5 | 0.724 |
6 | 0.749 |
7 | 0.634 |
8 | 0.874 |
9 | 0.934 |
10 | 0.593 |
11 | 0.339 |
12 | 0.896 |
from this we can see that the performance hit we get from losing a single channel is not always the same. in particular consider channel 11; if we drop that channel we get a huge hit! does that mean that if we keep only 11 that should give reasonable performance?
mask | validation accuracy |
---|---|
[1,1,1,1,1,1,1,1,1,1,1,1,1] | 0.942 |
[0,0,0,0,0,0,0,0,0,0,0,0,0] | 0.113 |
[0,0,0,0,0,0,0,0,0,0,0,1,0] (keep just 11) | 0.260 |
bbbzzzttt (or other appropriate annoying buzzer noise). channel 11 is contributing to the classification but it's not being used independently. in general this is exactly the behaviour we want from a neural network but what should we do to explore the effect of not having this dependence?
consider using a dropout idea, just with input channels instead of intermediate nodes.
what behaviour do we get if we drop channels out during training? i.e. with 50% probability we replace an entire input channel with 0s?
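a numpy sketch of input channel dropout, assuming (B, H, W, C) standardised inputs: one coin flip per (example, channel), and the whole channel is zeroed if it comes up. whether to rescale kept channels by 1/(1-p) like standard dropout isn't stated, so this sketch doesn't. input_channel_dropout is a hypothetical name.

```python
import numpy as np

def input_channel_dropout(x, rng, drop_prob=0.5):
    """Zero entire input channels, independently per example. x: (B, H, W, C)."""
    B, _, _, C = x.shape
    keep = rng.random((B, 1, 1, C)) >= drop_prob  # broadcasts over H and W
    return x * keep

rng = np.random.default_rng(0)
x = np.ones((16, 8, 8, 13))
dropped = input_channel_dropout(x, rng)
```

only applied at training time; at eval time the mask experiments above decide which channels are present.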
things take longer to train and we get a slight hit in accuracy...
dropout? | validation accuracy |
---|---|
no | 0.942 |
yes | 0.934 |
...but now when we mask out one channel at a time we don't get a big hit for losing any particular one.
channel to drop | validation accuracy (no dropout) | validation accuracy (with dropout) |
---|---|---|
0 | 0.735 | 0.931 |
1 | 0.528 | 0.931 |
2 | 0.661 | 0.936 |
3 | 0.675 | 0.935 |
4 | 0.809 | 0.937 |
5 | 0.724 | 0.934 |
6 | 0.749 | 0.931 |
7 | 0.634 | 0.927 |
8 | 0.874 | 0.927 |
9 | 0.934 | 0.927 |
10 | 0.593 | 0.927 |
11 | 0.339 | 0.933 |
12 | 0.896 | 0.937 |
now that we have a model that is robust to any combo of channels what do we see if we use a simple genetic algorithm (GA) to evolve the channel mask to use with this pre trained network? a mask that represents using all channels will be the best right? right?
we'll evolve the GA using the network trained above but based on its performance on a 3rd "ga_train" split using the inverse loss as a fitness function.
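the post doesn't spell out the GA details, so here's a deliberately minimal one as a sketch: members are {0, 1} masks, fitness is 1 / loss, each generation keeps the fitter half and refills with bit flip mutants of the survivors. loss_fn is a stand-in for "run the pretrained network on ga_train with this mask"; a toy loss is used here so it runs on its own. evolve_mask is a hypothetical name.

```python
import numpy as np

def evolve_mask(loss_fn, n_channels=13, pop_size=20, generations=30,
                mutation_rate=0.1, seed=0):
    """Minimal GA over channel masks with fitness = 1 / loss_fn(mask)."""
    rng = np.random.default_rng(seed)
    pop = rng.integers(0, 2, size=(pop_size, n_channels))
    for _ in range(generations):
        fitness = np.array([1.0 / loss_fn(m) for m in pop])
        survivors = pop[np.argsort(-fitness)[: pop_size // 2]]  # elitism: best half kept
        flips = rng.random(survivors.shape) < mutation_rate
        children = np.where(flips, 1 - survivors, survivors)
        pop = np.concatenate([survivors, children])
    fitness = np.array([1.0 / loss_fn(m) for m in pop])
    return pop[np.argmax(fitness)]

# toy stand-in loss: 1 + number of bits differing from some target mask
target = np.array([1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1])
toy_loss = lambda m: 1.0 + np.sum(m != target)
best = evolve_mask(toy_loss)
```

note each fitness evaluation is only inference with a fixed model, which is what makes this cheap.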
amusingly the GA finds that a mask of [1,1,0,1,0,0,0,1,1,0,1,0,1] does marginally better than all channels, but only uses 7 of the 13 channels!
mask | split | accuracy |
---|---|---|
[1,1,1,1,1,1,1,1,1,1,1,1,1] (all) | ga_validate | 0.934 |
[1,1,0,1,0,0,0,1,1,0,1,0,1] (ga) | ga_validate | 0.936 |
important note: we can imagine the best performance overall would be to have the GA evolve not the channels to use from this model, but the channels to use when training from scratch. this would though require a lot more model training, basically a full training cycle per fitness evaluation :( in the approach we describe here we only have to train a single model and then have the GA just run inference.
taking the idea of channel selection a step further, what if we got the GA to not only decide whether to use a channel or not, but also what resolution it should be in?
consider some example images across resolutions....
example images (just RGB channels shown) at the original x64 resolution and at x32, x16 and x8
we could then weight the use of a channel based on resolution; the higher the resolution the more the channel "costs" to use, with not using the channel at all being "free".
to support this we can change the GA to represent members not as a string of {0, 1}s but instead a sequence of {0, x8, x16, x32, x64} values per channel where these represent...
resolution | description | channel cost |
---|---|---|
x64 | use original (64, 64) version of input | 0.8 |
x32 | use a 1/2 res (32, 32) version of input | 0.4 |
x16 | use a 1/4 res (16, 16) version of input | 0.2 |
x8 | use a 1/8 res (8, 8) version of input | 0.1 |
0 | don't use channel | 0 |
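the post doesn't say how the lower resolution versions are made; a simple assumption is average pooling each (64, 64) channel by a factor of 2, 4 or 8, as in this numpy sketch ( downsample and RES_FACTOR are hypothetical names ).

```python
import numpy as np

RES_FACTOR = {"x64": 1, "x32": 2, "x16": 4, "x8": 8}

def downsample(channel, factor):
    """Average pool a (H, W) channel by `factor` in each spatial dim."""
    h, w = channel.shape
    return channel.reshape(h // factor, factor, w // factor, factor).mean(axis=(1, 3))

channel = np.arange(64 * 64, dtype=float).reshape(64, 64)
versions = {name: downsample(channel, f) for name, f in RES_FACTOR.items()}
print({name: v.shape for name, v in versions.items()})
# {'x64': (64, 64), 'x32': (32, 32), 'x16': (16, 16), 'x8': (8, 8)}
```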
the change in the encoding of our GA is trivial, just 5 values per channel instead of 2, but before we look at that; how do we change our network?
we can do it without having to add too many extra parameters by using the magic of fully convolutional networks :)
notice how the main trunk of our first network was a series of 2d convolutions with a global spatial mean. this network will simply take as input all the resolutions we need! we can simply reuse it multiple times!
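here's the property that makes the reuse possible, as a numpy sketch: a fully convolutional trunk ending in a global spatial mean maps any (H, W, C) input to the same sized feature vector. to keep it short the trunk is just two 1x1 convs ( the real one uses spatial convs ), and feeding it all 13 channels with unassigned ones zeroed is an assumption about how each resolution is handled.

```python
import numpy as np

rng = np.random.default_rng(0)
W1 = rng.normal(size=(13, 32)) * 0.1  # hypothetical 1x1 conv: 13 channels -> 32
W2 = rng.normal(size=(32, 64)) * 0.1  # hypothetical 1x1 conv: 32 -> 64

def trunk(x):
    """Shared trunk: per pixel layers then a global spatial mean.
    x: (H, W, 13) for any H, W -> (64,) features."""
    h = np.maximum(x @ W1, 0.0)
    h = np.maximum(h @ W2, 0.0)
    return h.mean(axis=(0, 1))

feats_x64 = trunk(rng.normal(size=(64, 64, 13)))
feats_x8 = trunk(rng.normal(size=(8, 8, 13)))
combined = np.concatenate([feats_x64, feats_x8])  # one possible way to combine resolutions
```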
so we can have our network...
note: try as i might i can't get steps 2 to 4 to run parallelised in a pmap. i asked on github about it and it looks to be something you can't do at the moment.
when we consider channel cost vs loss there is no single best solution, it's a classic example of a pareto front where we see a tradeoff between the channel_cost and loss.
consider this sampling of 1,000 random channel masks...
the GA needs to operate with a fitness that's a single scalar; for now we just use a simple combo of (1.0 / loss) - channel_cost
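a sketch of that fitness using the channel cost table above, assuming a member's channel_cost is the sum of its per channel costs ( the post doesn't say whether it's a sum or a mean ). "ignore" here is the "0 / don't use channel" row.

```python
CHANNEL_COST = {"x64": 0.8, "x32": 0.4, "x16": 0.2, "x8": 0.1, "ignore": 0.0}

def channel_cost(member):
    """Total resolution cost of a member (summing is an assumption)."""
    return sum(CHANNEL_COST[r] for r in member)

def fitness(loss, member):
    return (1.0 / loss) - channel_cost(member)

evolved = ["x16", "x64", "x64", "x16", "x32", "ignore", "x8", "x64", "x8",
           "ignore", "x8", "ignore", "x32"]
print(round(channel_cost(evolved), 2))  # 3.9
```

changing the weighting between the two terms is what moves the chosen solution along the pareto front.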
running with this fitness function we evolve the solution
[x16, x64, x64, x16, x32, ignore, x8, x64, x8, ignore, x8, ignore, x32]
it's on the pareto front, as we'd hope, and it's interesting that it includes a mix of resolutions including ignoring 3 channels completely :)
different mixings of loss and channel_cost would result in different GA solutions along the front
all the code is on github
a classic piece of advice i hear for people using tpus is that they should "try using a bigger batch size".
this got me thinking; i wonder how big a batch size i could reasonably use? how would the optimisation go? how fast could i get things?
let's train a model on the eurosat/rgb dataset. it's a 10 way classification problem on 64x64 images
with a training split of 80% we have 21,600 training examples. we'll use another 10% for validation (2,700 images) and just not use the final 10% test split #hack
for the model we'll use a simple stack of convolutions with channel sizes 32, 64, 128 and 256, a stride of 2 for spatial reduction all with gelu activation. after the convolutions we'll do a simple global spatial pooling, a single 128d dense layer with gelu and then a 10d logit output. a pretty vanilla architecture of ~400K params. nothing fancy.
a v3-32 tpu pod slice is 4 hosts, each with 8 tpu devices.
21,600 training examples total => 5,400 examples per host => 675 examples per device.
this number of images easily fits on a device. great.
now usually augmentation is something we do randomly per batch, but for this hack we're interested in seeing how big a batch we can run. so why not fill out the dataset a bit by just running a stack of augmentations before training?
for each image we'll do 90, 180 and 270 deg rotations along with left/right flips for a total of 8 augmented images for each original image. e.g.....
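the 8 variants are the 4 rotations of the image and the 4 rotations of its left/right mirror. a numpy sketch ( the post did this with nested pmaps and vmaps in jax; dihedral_augment is a hypothetical name ):

```python
import numpy as np

def dihedral_augment(img):
    """All 8 rotation / flip variants of a square (H, W, C) image."""
    out = []
    for base in (img, np.fliplr(img)):
        for k in range(4):  # 0, 90, 180, 270 degrees
            out.append(np.rot90(base, k=k, axes=(0, 1)))
    return np.stack(out)

img = np.arange(4 * 4 * 3).reshape(4, 4, 3)
augmented = dihedral_augment(img)  # (8, 4, 4, 3)
```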
this gives us now 172,800 images total => 43,200 per host => 5,400 per tpu device. which still fits no problem.
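the sharding arithmetic, spelled out:

```python
n_train = 21_600                # 80% of eurosat/rgb's 27,000 images
hosts, devices_per_host = 4, 8  # v3-32 pod slice
n_aug = 8                       # 4 rotations x {original, left/right flip}

per_host = n_train // hosts                   # 5,400
per_device = per_host // devices_per_host     # 675
total_aug = n_train * n_aug                   # 172,800
per_host_aug = total_aug // hosts             # 43,200
per_device_aug = per_host_aug // devices_per_host  # 5,400
```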
side note: turns out doing this augmentation was one of the most fun parts of this hack :) see this tweet thread for some more info on how i used nested pmaps and vmaps to do it!
one motivation i had for this hack was to compare adam to lamb. i'd seen lamb referred to in the past, would it perform better for this model/dataset size? turns out it does! a simple sweep comparing lamb, adam and sgd shows lamb consistently doing the best. definitely one to add to the tuning mix from now on.
not only does the augmented data fit sharded across devices but we can replicate both the model parameters and the optimiser state as well. this is important for speed since the main training loop doesn't have to do any host/device communication. taking a data parallel approach means the only cross device comms is a gradient psum.
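why a gradient psum is all the comms we need: with equal sized shards and a mean loss, averaging the per device gradients gives exactly the full batch gradient. here's a numpy stand-in for what jax.lax.psum ( divided by the device count, i.e. jax.lax.pmean ) does across pmap devices, using a y=mx+b model to keep it small.

```python
import numpy as np

def grad_mse(params, x, y):
    """Gradient of mean squared error for y_hat = m*x + b, w.r.t. (m, b)."""
    m, b = params
    err = m * x + b - y
    return np.array([np.mean(2 * err * x), np.mean(2 * err)])

rng = np.random.default_rng(0)
n_devices = 32
x = rng.normal(size=n_devices * 64)
y = 3.0 * x + 1.0
params = np.array([0.5, -0.2])  # replicated on every device

shards_x = x.reshape(n_devices, -1)  # each device holds its own slice of data
shards_y = y.reshape(n_devices, -1)
local_grads = np.stack([grad_mse(params, sx, sy) for sx, sy in zip(shards_x, shards_y)])
synced = local_grads.sum(axis=0) / n_devices  # psum then divide == pmean
full = grad_mse(params, x, y)
```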
for training we run an inner loop just pumping the params = update(params) step.
an outer loop runs the inner loop 100 times before doing a validation accuracy check.
the inner loop runs in 1.5s for the 100 iterations and since each iteration is a forward & backwards pass for all 172,800 images across all hosts that's 11M images processed per second. 🔥🔥🔥
at this speed the best result of 0.95 on validation takes 13 outer loops; i.e. all done in under 20s. o_O !!
when reviewing runs i did laugh to see sgd with momentum make a top 10 entry.
new t-shirt slogan: "sgd with momentum; always worth a try"
all the code in hacktastic undocumented form on github
this 4 (and a bit) part tute series starts with jax fundamentals, builds up to describing a data parallel approach to training on a cloud tpu pod slice, and finishes with a tpu pod slice implementation of ensemble nets....
all with the goal of solving 1d y=mx+b
and though it may seem like a bit of overkill it turns out it's a good example to work through so that we can focus on the library support without having to worry about the modelling.
in this first section we introduce some jax fundamentals; e.g. make_jaxpr, grad, jit, vmap & pmap.
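a small taste of most of those on y=mx+b, as a sketch that runs on a single cpu ( pmap is left out since it needs multiple devices; its calling style mirrors vmap but maps over devices ).

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    m, b = params
    return m * x + b

def loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

x = jnp.array([0.0, 1.0, 2.0])
y = 2.0 * x + 1.0                         # data from y = 2x + 1
params = (jnp.array(0.0), jnp.array(0.0))

print(jax.make_jaxpr(loss)(params, x, y))  # the traced program jax will compile

grad_loss = jax.jit(jax.grad(loss))        # d loss / d (m, b), compiled with XLA
g_m, g_b = grad_loss(params, x, y)         # -26/3 and -6 for this data

# vmap: evaluate a batch of (m, b) candidates in one call
many_params = (jnp.array([1.0, 2.0]), jnp.array([0.0, 1.0]))
batched = jax.vmap(predict, in_axes=(0, None))(many_params, x)  # shape (2, 3)
```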
colab: 01 pmap jit vmap oh my.ipynb
in part 2 we use the techniques from part 1 to solve y=mx+b in pure jax. we'll also introduce pytrees and various tree_utils for manipulating them.
we run first on a single device and work up to using pmap to demonstrate a simple data parallelism approach. along the way we'll do a small detour to a tpu pod slice to illustrate the difference in a multi host setup.
( note: the experience as described here for a pod slice isn't publicly available yet; but sign up via the JAX on Cloud TPU Interest Form to get more info. see also this JAX on Cloud TPUs (NeurIPS 2020) talk )
colab: 02 y mx b on a tpu.ipynb
next we introduce haiku as a way of defining our model and optax as a library to provide standard optimisers. to illustrate their use we'll do a minimal port of our model and training loop to use them.
colab: 03 y mx b in haiku.ipynb
in part 4 we'll reimplement ensemble nets for this trivial model, continuing to do things in a way that supports running on a tpu pod slice.
colab: 04 y mx b haiku ensemble.ipynb
to wrap up we acknowledge that though tpu pod slices and data parallel approaches are fun we could have just solved this in a single calculation using the normal equation... :D
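for the record, the boring version: stack a design matrix [x, 1] and solve the normal equation (X^T X) theta = X^T y directly.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=100)
y = 3.0 * x + 1.5 + rng.normal(scale=0.01, size=100)

X = np.stack([x, np.ones_like(x)], axis=1)  # design matrix [x, 1]
m, b = np.linalg.solve(X.T @ X, X.T @ y)   # normal equation
print(m, b)  # close to 3.0 and 1.5
```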
colab: 05 booooooooooooooooring.ipynb
y=mx+b!!!
OOD detection is an often overlooked part of many projects. it's super important to understand when your model is operating on data it's not familiar with!
though this can be framed as the problem of detecting that an input differs from training data, discussed more in this 'data engineering concerns for machine learning products' talk, for this post we're going to look at the problem more from the angle of the model needing to be able to express its own lack of confidence.
a core theme of this post is treating OOD as a function of the output of the model. when i've played with this idea before the best result i've had is from using entropy as a proxy of confidence. in the past it's not been in the context of OOD directly but instead as a way of prioritising data annotation (i.e. uncertainty sampling for active learning)
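entropy as a confidence proxy, as a numpy sketch: it's 0 for a one-hot prediction and log(num_classes) for a uniform one, so higher entropy means less confident.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)  # for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def predictive_entropy(logits):
    """Entropy (nats) of the softmax distribution for each row of logits."""
    p = softmax(logits)
    return -np.sum(p * np.log(p + 1e-12), axis=-1)

uncertain = predictive_entropy(np.zeros(10))               # ~log(10) = 2.303
confident = predictive_entropy(np.array([20.0] + [0.0] * 9))  # ~0
```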
using entropy as a stable measure of confidence though requires a model be well calibrated.
neural nets are sadly notorious for not being well calibrated under the default ways we train these days.
the first approach i was shown to calibrate a neural net is to train your model as normal, then finetune just the final classifier layer using a held out set; the general idea being that this is effectively just a logistic regression on features learnt by the rest of the network and so should come with the better guarantees that logistic regression provides us regarding calibration (see this great overview from sklearn). whereas this has always worked well for me, it does require two steps of training, as well as the need for a separate held out set to be managed.
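to see why this is "just a logistic regression"; with the rest of the network frozen the features are fixed inputs, and fitting a dense softmax layer on them with cross entropy is multinomial logistic regression. a minimal numpy sketch, assuming plain full batch gradient descent (any optimiser would do) and made up features; the names are hypothetical and this isn't the objax code used later.

```python
import numpy as np

def log_softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)  # stabilise before exp
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def fit_classifier_layer(features, labels, num_classes, lr=0.05, steps=300):
    """fit only the weights W and bias b of a final dense layer on frozen features"""
    n, d = features.shape
    W = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    one_hot = np.eye(num_classes)[labels]
    for _ in range(steps):
        probs = np.exp(log_softmax(features @ W + b))
        # gradient of mean cross entropy wrt the logits is (probs - one_hot) / n
        grad_logits = (probs - one_hot) / n
        W -= lr * features.T @ grad_logits
        b -= lr * grad_logits.sum(axis=0)
    return W, b

# usage on made up "features" from a frozen network; two well separated classes
rng = np.random.default_rng(0)
features = np.concatenate([rng.normal(-2, 1, (100, 8)), rng.normal(2, 1, (100, 8))])
labels = np.array([0] * 100 + [1] * 100)
W, b = fit_classifier_layer(features, labels, num_classes=2)
accuracy = ((features @ W + b).argmax(axis=-1) == labels).mean()
print(accuracy)
```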
another approach to calibration i revisited this week comes from this great overview paper from 2017 on calibration of modern neural networks. it touches on platt scaling, which i've also used successfully in the past, but includes something even simpler i didn't really notice until it was pointed out to me by a colleague; don't bother tuning the entire last layer, just fit a single temperature rescaling of the output logits, i.e. what can be thought of as a single parameter version of platt scaling. what a great simple idea!
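a minimal numpy sketch of temperature scaling, under some assumptions; the logits are divided by a single scalar T, and T is chosen to minimise the negative log likelihood on a held out calibration set. here that's done with a coarse grid search over candidate temperatures rather than gradient descent, and the data is synthetic; labels are sampled from softmax(z) while the "model" reports 2z, so it's overconfident and the fitted T should land near 2.

```python
import numpy as np

def nll(logits, labels, temperature):
    """mean negative log likelihood of labels under softmax(logits / temperature)"""
    z = logits / temperature
    z = z - z.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def fit_temperature(logits, labels, candidates=np.linspace(0.25, 5.0, 96)):
    """pick the single scalar temperature that minimises calibration set NLL"""
    losses = [nll(logits, labels, t) for t in candidates]
    return float(candidates[int(np.argmin(losses))])

# synthetic, deliberately overconfident "model": labels are drawn from
# softmax(z) but the model reports logits 2*z, so the best temperature is ~2
rng = np.random.default_rng(0)
z = rng.normal(scale=2.0, size=(5000, 10))
probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
# inverse cdf sampling of one label per row (clipped in case of float rounding)
labels = np.minimum((probs.cumsum(axis=-1) < rng.random((5000, 1))).sum(axis=-1), 9)
logits = 2.0 * z
temperature = fit_temperature(logits, labels)
print(temperature)
```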
the main purpose of this post though is to reproduce some ideas around out of distribution detection and calibration when you use focal loss. this was described in this great piece of work calibrating deep neural networks using focal loss. we'll implement that, as well as the other two approaches, for comparison.
( as an aside, another great simple way of determining confidence is using an ensemble! i explore this idea a bit in my post on ensemble nets where i train ensembles as a single model using jax vmap )
let's start with cifar10 and split the data up a bit differently than the standard splits...
train and test splits.
automobile & cat as a "hard" ood set ( i've intentionally chosen these two since i know they cause model confusion as observed in this keras.io tute on metric learning ).
train and 0.1 for each of validation, calibration and test.
we'll additionally generate an "easy" ood set which will just be random images.
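a minimal numpy sketch of this kind of split, working on label indices only. assumptions (not from the original code): cifar10 label ids 1 (automobile) and 3 (cat) form ood_hard, the remaining in-distribution indices are shuffled and split 0.1 / 0.1 / 0.1 for validation, calibration and test with the rest (~0.7) going to train, and ood_easy is uniform noise shaped like cifar10 images. split_indices and make_ood_easy are hypothetical names.

```python
import numpy as np

OOD_HARD_CLASSES = (1, 3)  # cifar10 label ids for automobile and cat

def split_indices(labels, seed=0):
    """return a dict of index arrays, one per split; labels is a 1D int array"""
    rng = np.random.default_rng(seed)
    is_ood = np.isin(labels, OOD_HARD_CLASSES)
    ood_hard = np.where(is_ood)[0]
    in_dist = rng.permutation(np.where(~is_ood)[0])
    n_10pc = int(0.1 * len(in_dist))
    validation = in_dist[:n_10pc]
    calibration = in_dist[n_10pc:2 * n_10pc]
    test = in_dist[2 * n_10pc:3 * n_10pc]
    train = in_dist[3 * n_10pc:]  # whatever remains, ~0.7
    return dict(train=train, validation=validation, calibration=calibration,
                test=test, ood_hard=ood_hard)

def make_ood_easy(num_images, seed=0):
    """'easy' ood set: uniform random noise images shaped like cifar10 (32x32x3)"""
    rng = np.random.default_rng(seed)
    return rng.uniform(0.0, 1.0, size=(num_images, 32, 32, 3)).astype(np.float32)
```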
in terms of how we'll use these splits...
train will be the main dataset to train against.
validation will be used during training for learning rate stepping, early stopping, etc.
calibration will be used when we want to tune a model in some way for calibration.
test will be used for final held out analysis.
ood_hard and ood_easy will be used as the representative out of distribution sets.
(figure: example images from the in distribution, out of distribution (hard) and out of distribution (easy) sets)
model_1, on train, tuning against validate.
model_2 by finetuning the last layer of model_1 against calibrate; i.e. the logistic regression approach to calibration.
model_3 by just fitting a scalar temperature rescaling of the last layer of model_1; i.e. the temperature scaling approach.
train using validate for tuning.
at each step we'll check the entropy distributions across the various splits, including the easy and hard ood sets
for a baseline we'll use a resnet18 objax model trained against train using validate for simple early stopping.
# define model
model = objax.zoo.resnet_v2.ResNet18(in_channels=3, num_classes=10)
# train against all model vars
trainable_vars = model.vars()
$ python3 train_basic_model.py --loss-fn cross_entropy --model-dir m1
learning_rate 0.001 validation accuracy 0.372 entropy min/mean/max 0.0005 0.8479 1.9910
learning_rate 0.001 validation accuracy 0.449 entropy min/mean/max 0.0000 0.4402 1.9456
learning_rate 0.001 validation accuracy 0.517 entropy min/mean/max 0.0000 0.3566 1.8013
learning_rate 0.001 validation accuracy 0.578 entropy min/mean/max 0.0000 0.3070 1.7604
learning_rate 0.001 validation accuracy 0.680 entropy min/mean/max 0.0000 0.3666 1.7490
learning_rate 0.001 validation accuracy 0.705 entropy min/mean/max 0.0000 0.3487 1.6987
learning_rate 0.001 validation accuracy 0.671 entropy min/mean/max 0.0000 0.3784 1.8276
learning_rate 0.0001 validation accuracy 0.805 entropy min/mean/max 0.0000 0.3252 1.8990
learning_rate 0.0001 validation accuracy 0.818 entropy min/mean/max 0.0000 0.3224 1.7743
learning_rate 0.0001 validation accuracy 0.807 entropy min/mean/max 0.0000 0.3037 1.8580
we can check the distribution of entropy values of this trained model against our various splits
some observations...
recall that a higher entropy => a more uniform distribution => less prediction confidence.
as such our goal is to have the entropy of the ood set as high as possible without dropping the test set too much.
as mentioned before, the main approach i'm familiar with is retraining the classifier layer. we'll start with model_1 and use the so-far-unseen calibration set for fine tuning it.
# define model
model = objax.zoo.resnet_v2.ResNet18(in_channels=3, num_classes=10)
# restore from model_1
objax.io.load_var_collection('m1/weights.npz', model.vars())
# train against last layer only
classifier_layer = model[-1]
trainable_vars = classifier_layer.vars()
$ python3 train_model_2.py --input-model-dir m1 --output-model-dir m2
learning_rate 0.001 calibration accuracy 0.808 entropy min/mean/max 0.0000 0.4449 1.8722
learning_rate 0.001 calibration accuracy 0.809 entropy min/mean/max 0.0002 0.5236 1.9026
learning_rate 0.001 calibration accuracy 0.808 entropy min/mean/max 0.0004 0.5468 1.9079
learning_rate 0.001 calibration accuracy 0.811 entropy min/mean/max 0.0005 0.5496 1.9214
learning_rate 0.001 calibration accuracy 0.812 entropy min/mean/max 0.0001 0.5416 1.9020
learning_rate 0.001 calibration accuracy 0.812 entropy min/mean/max 0.0002 0.5413 1.9044
learning_rate 0.001 calibration accuracy 0.811 entropy min/mean/max 0.0003 0.5458 1.8990
learning_rate 0.001 calibration accuracy 0.814 entropy min/mean/max 0.0001 0.5472 1.9002
learning_rate 0.001 calibration accuracy 0.812 entropy min/mean/max 0.0001 0.5455 1.9124
learning_rate 0.001 calibration accuracy 0.813 entropy min/mean/max 0.0001 0.5503 1.9092
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0001 0.5485 1.9011
learning_rate 0.0001 calibration accuracy 0.814 entropy min/mean/max 0.0001 0.5476 1.9036
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0001 0.5483 1.9051
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0000 0.5479 1.9038
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0000 0.5482 1.9043
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0000 0.5482 1.9059
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0000 0.5479 1.9068
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0000 0.5480 1.9058
learning_rate 0.0001 calibration accuracy 0.812 entropy min/mean/max 0.0000 0.5488 1.9040
learning_rate 0.0001 calibration accuracy 0.814 entropy min/mean/max 0.0000 0.5477 1.9051
if we compare (top1) accuracy of model_1 vs model_2 we see a slight improvement in model_2, attributable to the slight extra training i suppose (?)
$ python3 calculate_metrics.py --model m1/weights.npz
train accuracy 0.970
validate accuracy 0.807
calibrate accuracy 0.799
test accuracy 0.803
$ python3 calculate_metrics.py --model m2/weights.npz
train accuracy 0.980
validate accuracy 0.816
calibrate accuracy 0.814
test accuracy 0.810
more importantly though; how do the entropy distributions look?
some observations...
in model_3 we create a simple single parameter layer that represents rescaling, append it to the pretrained resnet and train just that layer.
import jax.numpy as jnp
import objax

class Temperature(objax.module.Module):
    def __init__(self):
        super().__init__()
        # single trainable scalar, initialised to 1.0 (i.e. no rescaling)
        self.temperature = objax.variable.TrainVar(jnp.array([1.0]))

    def __call__(self, x):
        # rescale logits; a positive temperature keeps their ordering
        return x / self.temperature.value

# define model
model = objax.zoo.resnet_v2.ResNet18(in_channels=3, num_classes=10)
# restore from model_1
objax.io.load_var_collection('m1/weights.npz', model.vars())
# add a temp rescaling layer
temperature_layer = Temperature()
model.append(temperature_layer)
# train against just this layer
trainable_vars = temperature_layer.vars()
$ python3 train_model_3.py --input-model-dir m1 --output-model-dir m3
learning_rate 0.01 temp 1.4730 calibration accuracy 0.799 entropy min/mean/max 0.0006 0.5215 1.9348
learning_rate 0.01 temp 1.5955 calibration accuracy 0.799 entropy min/mean/max 0.0013 0.5828 1.9611
learning_rate 0.01 temp 1.6118 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5911 1.9643
learning_rate 0.01 temp 1.6162 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5933 1.9652
learning_rate 0.01 temp 1.6118 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5911 1.9643
learning_rate 0.01 temp 1.6181 calibration accuracy 0.799 entropy min/mean/max 0.0015 0.5943 1.9655
learning_rate 0.01 temp 1.6126 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5915 1.9645
learning_rate 0.01 temp 1.6344 calibration accuracy 0.799 entropy min/mean/max 0.0016 0.6025 1.9687
learning_rate 0.01 temp 1.6338 calibration accuracy 0.799 entropy min/mean/max 0.0016 0.6022 1.9686
learning_rate 0.01 temp 1.6018 calibration accuracy 0.799 entropy min/mean/max 0.0013 0.5860 1.9623
learning_rate 0.001 temp 1.6054 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5878 1.9630
learning_rate 0.001 temp 1.6047 calibration accuracy 0.799 entropy min/mean/max 0.0013 0.5874 1.9629
learning_rate 0.001 temp 1.6102 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5903 1.9640
learning_rate 0.001 temp 1.6095 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5899 1.9638
learning_rate 0.001 temp 1.6112 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5908 1.9642
learning_rate 0.001 temp 1.6145 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5924 1.9648
learning_rate 0.001 temp 1.6134 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5919 1.9646
learning_rate 0.001 temp 1.6144 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5924 1.9648
learning_rate 0.001 temp 1.6105 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5904 1.9640
learning_rate 0.001 temp 1.6151 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5927 1.9650
some observations...
model_1. this is since the temp rescaling never changes the ordering of predictions. ( just this property alone could be enough reason to use this approach in some cases! )
model_2 in terms of raising the entropy of ood!
not bad for fitting a single parameter :D
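a small numpy check of the ordering property, using made up logits; dividing by any positive temperature is a monotonic rescaling, so the argmax (and therefore top1 accuracy) can't change, while temperatures above 1 flatten the softmax and so raise its entropy. the temperatures tried include 1.6, roughly where model_3's temperature ended up.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def entropy(probs):
    return -(probs * np.log(probs)).sum(axis=-1)

logits = np.array([[4.0, 1.0, 0.5], [0.2, 2.5, 2.4]])  # made up example logits
for t in (0.5, 1.0, 1.6, 3.0):
    scaled = logits / t
    # ordering of classes (and so the argmax) is unchanged by a positive rescale
    assert (np.argsort(scaled, axis=-1) == np.argsort(logits, axis=-1)).all()
    print(t, entropy(softmax(scaled)).round(3))
```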
as a final experiment we'll use focal loss for training instead of cross entropy
note: i rolled my own focal loss function which is always dangerous! it's not impossible (i.e. it's likely) i've done something wrong in terms of numerical stability; please let me know if you see a dumb bug :)
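import jax
import jax.numpy as jnp

def focal_loss_sparse(logits, y_true, gamma=1.0):
    # log probabilities via log_softmax for numerical stability
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # pick out the log probability of the true class for each example
    log_probs = log_probs[jnp.arange(len(log_probs)), y_true]
    probs = jnp.exp(log_probs)
    # cross entropy, down weighted by (1 - p)^gamma for confident examples
    elementwise_loss = -1 * ((1 - probs)**gamma) * log_probs
    return elementwise_loss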
note that focal loss includes a gamma hyperparameter. we'll trial 1.0, 2.0 and 3.0 in line with the experiments mentioned in calibrating deep neural networks using focal loss
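to get a feel for what gamma does; a numpy transcription of the per-example loss above (numpy rather than jax just to keep it dependency light), written in terms of the probability the model assigns to the true class. with gamma=0 it reduces to plain cross entropy, -log(p); larger gammas scale that by (1 - p)^gamma, down weighting examples the model is already confident about.

```python
import numpy as np

def focal_loss_from_prob(p_true, gamma):
    """focal loss given the probability assigned to the true class"""
    return -((1.0 - p_true) ** gamma) * np.log(p_true)

p_true = np.array([0.1, 0.5, 0.9, 0.99])
cross_entropy = -np.log(p_true)
for gamma in (0.0, 1.0, 2.0, 3.0):
    loss = focal_loss_from_prob(p_true, gamma)
    # ratio to cross entropy is (1 - p)^gamma; tiny for confident examples
    print(gamma, (loss / cross_entropy).round(4))
```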
$ python3 train_basic_model.py --loss-fn focal_loss --gamma 1.0 --model-dir m4_1
learning_rate 0.001 validation accuracy 0.314 entropy min/mean/max 0.0193 1.0373 2.0371
learning_rate 0.001 validation accuracy 0.326 entropy min/mean/max 0.0000 0.5101 1.8546
learning_rate 0.001 validation accuracy 0.615 entropy min/mean/max 0.0000 0.6680 1.8593
learning_rate 0.001 validation accuracy 0.604 entropy min/mean/max 0.0000 0.5211 1.8107
learning_rate 0.001 validation accuracy 0.555 entropy min/mean/max 0.0001 0.5127 1.7489
learning_rate 0.0001 validation accuracy 0.758 entropy min/mean/max 0.0028 0.6008 1.8130
learning_rate 0.0001 validation accuracy 0.780 entropy min/mean/max 0.0065 0.6590 1.9305
learning_rate 0.0001 validation accuracy 0.796 entropy min/mean/max 0.0052 0.6390 1.8589
learning_rate 0.0001 validation accuracy 0.790 entropy min/mean/max 0.0029 0.5886 1.9151
$ python3 train_basic_model.py --loss-fn focal_loss --gamma 2.0 --model-dir m4_2
learning_rate 0.001 validation accuracy 0.136 entropy min/mean/max 0.0001 0.2410 1.7906
learning_rate 0.001 validation accuracy 0.530 entropy min/mean/max 0.0000 0.8216 2.0485
learning_rate 0.001 validation accuracy 0.495 entropy min/mean/max 0.0003 0.7524 1.9086
learning_rate 0.001 validation accuracy 0.507 entropy min/mean/max 0.0000 0.6308 1.8460
learning_rate 0.001 validation accuracy 0.540 entropy min/mean/max 0.0000 0.5899 2.0196
learning_rate 0.001 validation accuracy 0.684 entropy min/mean/max 0.0001 0.6189 1.9029
learning_rate 0.001 validation accuracy 0.716 entropy min/mean/max 0.0000 0.6509 1.8822
learning_rate 0.001 validation accuracy 0.606 entropy min/mean/max 0.0001 0.7096 1.8312
learning_rate 0.0001 validation accuracy 0.810 entropy min/mean/max 0.0002 0.6208 1.9447
learning_rate 0.0001 validation accuracy 0.807 entropy min/mean/max 0.0004 0.6034 1.8420
$ python3 train_basic_model.py --loss-fn focal_loss --gamma 3.0 --model-dir m4_3
learning_rate 0.001 validation accuracy 0.288 entropy min/mean/max 0.0185 1.1661 1.9976
learning_rate 0.001 validation accuracy 0.319 entropy min/mean/max 0.0170 1.0925 2.1027
learning_rate 0.001 validation accuracy 0.401 entropy min/mean/max 0.0396 0.9795 2.0361
learning_rate 0.001 validation accuracy 0.528 entropy min/mean/max 0.0001 0.8534 1.9115
learning_rate 0.001 validation accuracy 0.528 entropy min/mean/max 0.0002 0.7405 1.8509
learning_rate 0.0001 validation accuracy 0.748 entropy min/mean/max 0.0380 0.9204 1.9697
learning_rate 0.0001 validation accuracy 0.760 entropy min/mean/max 0.0764 0.9807 1.9861
learning_rate 0.0001 validation accuracy 0.772 entropy min/mean/max 0.0906 0.9646 2.0556
learning_rate 0.0001 validation accuracy 0.781 entropy min/mean/max 0.0768 0.8973 1.9264
learning_rate 0.0001 validation accuracy 0.780 entropy min/mean/max 0.0871 0.8212 1.9814
some observations...
model_2 than model_1 in terms of calibration.
ood data, but with a drop in overall validation accuracy. maybe the separation isn't as much as i'd like, but i'm happy to experiment more given this result.
interesting!! focal loss is definitely going in my "calibration toolkit" :D
recall the distribution plot of entropies from the model_3 experiment
we can eyeball that the entropy values are low for in-distribution sets (train, validate, calibrate and test) and high for the two ood sets (easy and hard), but how do we actually turn this difference into an in-distribution classifier? ( recall this is not a standard supervised learning problem; we only have the in-distribution instances to train against and nothing from the ood sets at training time )
amusingly i spent a bunch of time hacking around with all sorts of density estimation techniques but it turned out the simplest baseline i started with was the best :/
the baseline works on a strong, but valid enough, assumption about the entropies; the entropies for in-distribution instances will be lower than the entropies for OOD instances. given this assumption we can just do the simplest "classifier" of all; thresholding
to get a sense of the precision/recall tradeoff of this approach we can do the following...
train set; this scales all train values to (0, 1)
test, ood_easy and ood_hard; including (0, 1) clipping
x = 1.0 - x, since we want a low entropy to denote a positive instance of in-distribution
test instances as positives and ood instances as negatives.
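a minimal numpy sketch of the steps above, assuming the entropies have already been computed per split (the ones here are made up). the min max scaling is fit on train only, applied with clipping to the other sets, and inverted so a high score means "in distribution". ROC AUC is then computed with test as positives and ood as negatives, using the rank based (mann whitney) formulation rather than any particular library; all function names are hypothetical.

```python
import numpy as np

def fit_min_max(train_entropies):
    """return a function scaling entropies to (0, 1) using train min/max, with clipping"""
    lo, hi = train_entropies.min(), train_entropies.max()
    return lambda x: np.clip((x - lo) / (hi - lo), 0.0, 1.0)

def in_distribution_score(scale, entropies):
    # invert so that low entropy => high score => more likely in-distribution
    return 1.0 - scale(entropies)

def roc_auc(pos_scores, neg_scores):
    """probability a random positive outscores a random negative (ties count half)"""
    pos = pos_scores[:, None]
    neg = neg_scores[None, :]
    wins = (pos > neg).sum() + 0.5 * (pos == neg).sum()
    return float(wins / (pos.size * neg.size))

rng = np.random.default_rng(0)
train_entropies = rng.gamma(1.0, 0.3, size=1000)  # made up: in-distribution skews low
test_entropies = rng.gamma(1.0, 0.3, size=200)
ood_entropies = rng.uniform(0.8, 2.3, size=200)   # made up: ood skews high
scale = fit_min_max(train_entropies)
auc = roc_auc(in_distribution_score(scale, test_entropies),
              in_distribution_score(scale, ood_entropies))
print(auc)
```

the min max scaling is monotonic, so apart from ties introduced by clipping it doesn't change the AUC; it mainly matters once you pick an actual threshold in (0, 1).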
with this simple approach we get the following ROC, P/R and reliability plots.
i've included the attempts i made to use a kernel density estimator in fit_kernel_density_to_entropy.py so if you manage to get a result better than this trivial baseline please let me know :)
though this post is about using entropy as a way of measuring OOD there are other ways too.
the simplest way i've handled OOD detection in the past is to include an explicit OTHER label. this has worked in cases where i've had data available to train against, e.g. data that was (somehow) sampled but wasn't of interest for the actual supervised problem at hand. but i can see in lots of situations this wouldn't always be possible.
another approach i'm still hacking on is doing density estimation on the inputs; this relates very closely to the diversity sampling ideas for active learning.
these experiments were written in objax, an object oriented wrapper for jax
Mathematics for Machine Learning by Marc Peter Deisenroth, A. Aldo Faisal & Cheng Soon Ong.
this is my personal favorite book on the general math required for machine learning; the way things are described really resonates with me. available as a free pdf but i got a paper copy to support the authors after reading the first half.
Linear Algebra and Learning from Data by Gilbert Strang.
this is gilbert's most recent work. it's really great, he's such a good teacher, and his freely available lectures are even better. it's a shorter text than his other classic intro below with more of a focus on how things are connected to modern machine learning techniques.
Introduction to Linear Algebra by Gilbert Strang.
this was my favorite linear algebra book for a long time before his 'learning from data' came out. this is a larger book with a more comprehensive view of linear algebra.
Think Stats: Probability and Statistics for Programmers by Allen Downey.
this book focuses on practical computation methods for probability and statistics. i got a lot out of working through this one. it's all in python and available for free. ( exciting update! as part of writing this post i've discovered there's a new edition to read!)
Doing Bayesian Data Analysis by John Kruschke
on the bayesian side of things this is the book i've most enjoyed working through. i've only got the first edition which was R and BUGS but i see the second edition is R, JAGS and Stan. it'd be fun i'm sure to work through it doing everything in numpyro. i might do that in all my free time. haha. "free time" hahaha. sob.
The Elements of Statistical Learning by Hastie, Tibshirani and Friedman
this is still one of the most amazing fundamental machine learning books i've ever had. in fact i've purchased this book twice and given it away both times :/ i might buy another copy some time soon, even though it's been freely available to download for ages. an amazing piece of work.
Probabilistic Graphical Models by Daphne Koller & Nir Friedman
this is an epic textbook that i'd love to understand better. i've read a couple of sections in detail but not the entire tome yet.
Pattern Recognition and Machine Learning by Christopher Bishop
this is probably the best overall machine learning textbook i've ever read. such a beautiful book and the pdf is FREE FOR DOWNLOAD!!!
Machine Learning: A Probabilistic Perspective by Kevin Murphy
this is my second favorite general theory text on machine learning. i got kevin to sign my copy when he was passing my desk once but someone borrowed it and never gave it back :( so if you see a copy with my name on the spine let me know!
Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow by Aurélien Géron
this is the book i point most people to when they are interested in getting up to speed with modern applied machine learning without too much concern for the theory. it's very up to date (as much as a book can be) with the latest libraries and, most importantly, provides a good overview of not just neural stuff but fundamental scikit-learn as well.
Machine Learning Engineering by Andriy Burkov
a great book focussing on the operations side of running a machine learning system. i'm a bit under half way through the free online version and very likely to buy a physical copy to finish it and support the author. great stuff and, in many ways, a more impactful book than any of the theory books here.
Introduction to Data Mining by Pang-Ning Tan, Michael Steinbach & Vipin Kumar
this is another one that was also on my list from ten years ago and though its section on neural networks is a bit of a chuckle these days there is still a bunch of really great fundamental stuff in this book. very practical and easy to digest. i also see there's a second edition now. i reckon this would complement the "hands on" book above very well.
Speech and Language Processing by Dan Jurafsky & James Martin
still the best overview of NLP there is (IMHO). can't wait to read the 3rd edition which apparently will cover more modern stuff (e.g. transformers) but until then, for the love of god though, please don't be one of those "this entire book is irrelevant now! just fine tune BERT" people :/
Numerical Optimization by Jorge Nocedal & Stephen J. Wright
this book is super hard core and maybe more an operations research book than machine learning. though i've not read it cover to cover the couple of bits i've worked through really taught me a lot. i'd love to understand the stuff in this text better; it's so so fundamental to machine learning (and more)
Deep Learning by Ian Goodfellow
writing a book specifically on deep learning is very dangerous since things move so fast but if anyone can do it, ian can... i think ian's approach to explaining neural networks from the ground up is one of my favorites. i got the first edition hardback but it's free to download from the website.
Probabilistic Robotics by Sebastian Thrun, Wolfram Burgard and Dieter Fox
when i first joined a robotics group i bought a stack of ML/robotics books and this was by far the best. it's good intro stuff, and maybe already dated in places given its age (the 2006 edition i have) but i still got a bunch from it.
TinyML by Pete Warden & Daniel Situnayake
this was a super super fun book to tech review! neural networks on microcontrollers?!? yes please!
Evolutionary Computation by David Fogel
this is still my favorite book on evolutionary algorithms; i've had this for a loooong time now. i still feel like evolutionary approaches are due for a big big comeback any time soon....
the good thing about writing a list is you get people telling you cool ones you've missed :)
the top three i've chosen (that are in the mail) are...
Causal Inference in Statistics by Judea Pearl, Madelyn Glymour & Nicholas P. Jewell
recommended by animesh garg who quite rightly points out the lack of causality in the books above.
Information Theory, Inference and Learning Algorithms by David MacKay
i've seen this book mentioned a number of times and was most recently recommended by my colleague dane so it's time to get it.