<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0"
     xmlns:content="http://purl.org/rss/1.0/modules/content/"
     xmlns:sy="http://purl.org/rss/1.0/modules/syndication/"
     xmlns:atom="http://www.w3.org/2005/Atom"
     xmlns:dc="http://purl.org/dc/elements/1.1/"
     xmlns:wfw="http://wellformedweb.org/CommentAPI/"
     >
  <channel>
    <title>brain of mat kelcey...</title>
    <link>http://matpalm.com/blog</link>
    <description>thoughts from a data scientist wannabe</description>
    <generator>Blogofile</generator>
    <sy:updatePeriod>hourly</sy:updatePeriod>
    <sy:updateFrequency>1</sy:updateFrequency>
    <item>
      <title>keras3 to jax to tflite/liteRT via orbax</title>
      <link>http://matpalm.com/blog/keras3_jax_tflite</link>
      <category><![CDATA[tflite]]></category>
      <category><![CDATA[litert]]></category>
      <category><![CDATA[keras3]]></category>
      <category><![CDATA[jax]]></category>
      <category><![CDATA[orbax]]></category>
      <guid>http://matpalm.com/blog/keras3_jax_tflite</guid>
      <description>keras3 to jax to tflite/liteRT via orbax</description>
      <content:encoded><![CDATA[<h2>keras 3, jax, tflite, oh my.</h2>
<p>i've always seen keras as providing two key things;
1) a way to compose layers into models and
2) an opinionated <code>fit</code> loop</p>
<p>for the last few years i've been focussed much more on jax than tensorflow, which means i've
not really been using keras for any hobby stuff. i'm super excited that keras3 supports a jax
backend because i can go back to using keras for its clear model definition while having
the flexibility to mess around with weird non standard model optimisation.</p>
<p>my day job has some focus around the tflite ecosystem ( recently renamed liteRT ) and
this set of tools has had great(†) support for tensorflow conversion, though not-so-great (†) support
for jax conversion. <small>† for some definition of great</small></p>
<p>if you have a "standard" keras model you can do conversion to tflite via the various keras save formats
( in which case whether the backend is jax or not is irrelevant ). however after using keras3/jax for
a little while i'm finding myself building composite models that don't fit this "standard" keras model.</p>
<p>consider my <a href="/blog/yolz/">yolz</a> project; it's a model that i train with jax but that's actually made up of
two keras models, and given the multiple input nature of this model, with its different levels of nested batching,
it's not something that fits nicely with the standard keras model.</p>
<p>the crux of it all is; <b>you can compose two keras models into one, but not if one of them
has been vmapped ( or has had some other jax function transform applied to it )</b>. i can define the parts with
keras, train them with jax/optax, but i can't put them "back" into a keras model to save for tf lite export.</p>
<h2>orbax tflite export</h2>
<p>jax to tflite conversion has been through a LOT of churn; more often than not, it feels like, i've been using libs listed under
<code>.experimental.</code>...</p>
<p>the one i started poking with today is <a href="https://ai.google.dev/edge/litert/models/jax_to_tflite">orbax</a></p>
<p>consider the following keras3 model</p>
<pre class="prettyprint"><code class="language-python">from keras.layers import Input, Conv2D, GlobalMaxPooling2D
from keras.models import Model

input = Input((64, 64, 3))
y = Conv2D(filters=8, strides=1, kernel_size=3, activation='relu', padding='same')(input)
y = GlobalMaxPooling2D()(y)
model = Model(input, y)
</code></pre>

<p>we can convert this to a jax function for inference with a closure over the
non trainable params</p>
<pre class="prettyprint"><code class="language-python">import jax.numpy as jnp

params = model.trainable_variables
nt_params = model.non_trainable_variables

def predict(params, x):
  y_pred, _nt_params = model.stateless_call(params, nt_params, x, training=False)
  return y_pred

eg_x = jnp.ones((1, 64, 64, 3))
predict(params, eg_x).shape
# (1, 8)
</code></pre>

<p>and then pass this function to the tf lite export with no surprises.</p>
<pre class="prettyprint"><code class="language-python">import tensorflow as tf
from orbax.export import JaxModule, constants

jax_module = JaxModule(params, predict, input_polymorphic_shape='b, ...')
converter = tf.lite.TFLiteConverter.from_concrete_functions(
   [
       jax_module.methods[constants.DEFAULT_METHOD_KEY].get_concrete_function(
           tf.TensorSpec(shape=(1, 64, 64, 3), dtype=tf.float32, name="input")
       )
   ]
)
tflite_model = converter.convert()
with open('keras.tflite', 'wb') as f:
  f.write(tflite_model)
</code></pre>

<p><img src="/blog/imgs/2024/k3_orbax/tflite.single.png"/></p>
<h2>a vmapped model</h2>
<p>more interestingly we can make a vectorised version of the inference function
that is called on a set of inputs and then does a reduction ( such as the embedding
model does in yolz )</p>
<pre class="prettyprint"><code class="language-python">input = Input((64, 64, 3))
y = Conv2D(filters=8, strides=1, kernel_size=3, activation='relu', padding='same')(input)
y = GlobalMaxPooling2D()(y)
model = Model(input, y)

params = model.trainable_variables
nt_params = model.non_trainable_variables

def predict(params, x):
  """ the vanilla version of inference """
  y_pred, _nt_params = model.stateless_call(params, nt_params, x, training=False)
  return y_pred

def v_predict(params, x):
  """ a variant that supports multiple inputs and does a mean over the result """
  y_preds = jax.vmap(predict, [None, 0])(params, x)
  return jnp.mean(y_preds, axis=-2)

eg_x = jnp.ones((1, 4, 64, 64, 3))
v_predict(params, eg_x).shape
# (1, 8)
</code></pre>

<p>this time we call with the additional shape axis representing the set of inputs; in this example 4</p>
<pre class="prettyprint"><code class="language-python">jax_module = JaxModule(params, v_predict, input_polymorphic_shape='b, ...')
converter = tf.lite.TFLiteConverter.from_concrete_functions(
   [
       jax_module.methods[constants.DEFAULT_METHOD_KEY].get_concrete_function(
           tf.TensorSpec(shape=(1, 4, 64, 64, 3), dtype=tf.float32, name="input")
       )
   ]
)
</code></pre>

<p>and we see the tflite result is correctly compiled as a reshape to collapse the leading
axis ( from <code>(B, 4, ...)</code> to <code>(4B, ...)</code> for the convolution, and then back to <code>(B, 4, ...)</code> for the reduction )</p>
<p>we also see how the additional plumbing is included for the <code>mean</code> reduction mixed with the <code>pooling</code></p>
<p><img src="/blog/imgs/2024/k3_orbax/tflite.vmap.png"/></p>
<p>note: all of this is additional to the batching we use during training, not just for hardware optimisation but
also for gradient variance reduction</p>
<p>as i mentioned before in the yolz blog post, this is all reshaping, axis plumbing etc that
we can do manually, but it's much easier, especially as we start to nest vmaps, when
the compiler does it for us</p>
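<p>as a tiny standalone sanity check of that equivalence ( with a made-up toy reduction standing in for the conv ), vmapping a per-set function gives the same result as the manual reshape-collapse-reshape:</p>
<pre class="prettyprint"><code class="language-python">import jax
import jax.numpy as jnp

# toy per-set function standing in for the conv; x: (4, 8) -> (4,)
def f(x):
  return x.sum(axis=-1)

x = jnp.arange(2 * 4 * 8, dtype=jnp.float32).reshape(2, 4, 8)  # (B=2, 4, 8)

via_vmap = jax.vmap(f)(x)                       # (B, 4)
via_reshape = f(x.reshape(8, 8)).reshape(2, 4)  # collapse (B, 4, ...) -> (4B, ...), apply, restore
</code></pre>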
<h2>code</h2>
<p>here's the <a href="https://colab.research.google.com/drive/1lD-6CTVxcJ29yHbOw6kbmrZKkJvEKmnD?usp=sharing">colab</a></p>]]></content:encoded>
    </item>
    <item>
      <title>yolz; you only look o̶n̶c̶e̶ zero times</title>
      <link>http://matpalm.com/blog/yolz</link>
      <category><![CDATA[keras3]]></category>
      <category><![CDATA[jax]]></category>
      <guid>http://matpalm.com/blog/yolz</guid>
      <description>yolz; you only look o̶n̶c̶e̶ zero times</description>
      <content:encoded><![CDATA[<h2>what does zero shot mean?</h2>
<h3>transfer learning</h3>
<p>transfer learning is arguably the greatest success story of classifier neural nets.
it works based on the idea that an N layer model can be thought of as an
N-1 layer feature extractor followed by a 'simple' logistic regression. transfer learning
is based on the idea that the large N-1 layer feature extractor is reusable across related
domains and you can get a long way just by retraining the classifier to get a second model.</p>
<p>note that this works best when the first model was trained on a set of classes that is, somehow,
related to the set of the classes in the second model.</p>
<p>( for more details see this talk i did on
<a href="https://matpalm.com/blog/self_supervised_learning/">self supervised learning</a> )</p>
<h3>what is few shot learning?</h3>
<p>the fact that transfer learning even works is kind of an accident ( in the sense that
it was never the objective of training the first model ). however there's no reason you
can't <em>make</em> it an objective. techniques such as meta learning focus on this and aim to
train a model that can be adapted with transfer learning from only a <strong>few</strong> examples, sometimes
as few as a single one per class ( e.g. <a href="https://openai.com/index/reptile/">reptile</a> )</p>
<p>( for more details see this talk i did on
<a href="https://matpalm.com/blog/learning_to_learn/">learning to learn</a> )</p>
<h3>what is zero shot learning?</h3>
<p>with transfer learning you have to provide <i>a lot</i> of new examples during retraining.
with few shot learning you only need to provide <i>a few</i> examples during retraining.
but can you train a model that generalises to new classes without needing to retrain at all?</p>
<p>a zero shot learner aims to do this with a key technique being that the new class
to generalise to is somehow provided as part of the model input itself.</p>
<p>( note how LLMs are the ultimate in zero shot learning since they use an input that is one of the
richest representations we know, language )</p>
<h2>what is zero shot detection</h2>
<p>to do zero shot detection we need a model that takes two inputs</p>
<ol>
<li>examples of the thing we're actually trying to detect ( the 'reference objects' )</li>
<li>the image we want to find those objects in ( the 'scene' )</li>
</ol>
<p>and returns as output a binary mask of the location of the object of interest...</p>
<p><img src="/blog/imgs/2024/yolz/in_out_example.png"/></p>
<p>since we want the model to be robust to different angles of the
object, we'll take as input <em>multiple views</em> of the object reference...</p>
<h2>inference model</h2>
<p>first let's consider the inference version of this model..</p>
<p><img src="/blog/imgs/2024/yolz/inference_model.png" /></p>
<p>the N object references are run through an embedding network with the "mean" embedding being the result</p>
<p>the scene input is run through a standard u-net like architecture with the object embeddings combined
with the scene features in the middle of the u-net ( i.e. after the downsampling, but before
the upsampling ).</p>
<p>note we have one embedding to mix into a <code>(10, 10)</code> feature map
so we first broadcast the embedding to match the feature map size.</p>
<p>the best performing approach to combining was just simple elementwise addition ( after a non linear
projection of both scene features and object embeddings to a joint feature depth ).
some simple experiments were run trying to use the object embeddings as a query for self
attention against the scene features but it wasn't any better and was more computationally
expensive. less is more sometimes i guess, and maybe it'd provide a better result for a more
complex version of the problem?</p>
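<p>a toy numpy sketch of that combination step ( made-up shapes, and skipping the non linear projections ):</p>
<pre class="prettyprint"><code class="language-python">import numpy as np

E = 8
embedding = np.random.rand(E)               # (E,) mean object embedding
scene_features = np.random.rand(10, 10, E)  # (10, 10, F) scene features mid u-net

# broadcast the embedding across the spatial dims to match the feature map...
embedding_b = np.broadcast_to(embedding, (10, 10, E))
# ...then combine with simple elementwise addition
combined = scene_features + embedding_b
</code></pre>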
<h2>training model</h2>
<p>now consider the additional parts required for the training model...</p>
<p><img src="/blog/imgs/2024/yolz/training_model.png" /></p>
<p>the main change is around the object reference branch.</p>
<ol>
<li>since a scene has many objects of interest we want to train against the masks of <em>all</em> the
objects present, not just one. this means the inclusion of an additional <code>C</code> axis.</li>
<li>since we want a set of embeddings that will generalise well to novel examples we jointly train
the embedding model with a contrastive loss. this requires running the embeddings for <em>pairs</em>
of examples that will go to the contrastive loss. note: we only pass the anchors down to the scene
branch.</li>
</ol>
<p>the joint training means that the embeddings are forced to both</p>
<ol>
<li>generalise well based on the contrastive loss but also</li>
<li>be useful for features to include in the scene branch.</li>
</ol>
<p>though the main change is around the contrastive loss, we also need to do extra broadcasting in the scene
branch. in the inference model we needed to broadcast the embeddings from <code>(E)</code> to <code>(10, 10, E)</code>
but now the broadcasting will be <code>(C, E)</code> to <code>(C, 10, 10, E)</code>, which requires the scene features to
also be broadcast from <code>(10, 10, F)</code> to <code>(C, 10, 10, F)</code></p>
<p>note here the output across <code>C</code> remains independent predictions, i.e. there is no softmax or anything</p>
<h2>contrastive embeddings</h2>
<p>as a side note, how do we train the object embeddings? pretty standard...</p>
<ul>
<li>sample C classes</li>
<li>sample N examples of each C as 'anchors' ( with randomly coloured backgrounds )</li>
<li>calculate their embeddings and take the mean</li>
<li>sample another set of N examples as 'positives'</li>
<li>calculate their embeddings and take the mean</li>
<li>train them so the cosine_sim(mean(anc_i), mean(pos_j)) = 1.0 for i==j and = 0.0 otherwise</li>
</ul>
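<p>one simple way to sketch that objective ( mean squared error of the (C, C) cosine sim matrix against an identity target; the actual loss used in yolz may be formulated differently ):</p>
<pre class="prettyprint"><code class="language-python">import numpy as np

def contrastive_loss(anc_means, pos_means):
  # anc_means, pos_means: (C, E) mean embeddings for anchors and positives
  a = anc_means / np.linalg.norm(anc_means, axis=-1, keepdims=True)
  p = pos_means / np.linalg.norm(pos_means, axis=-1, keepdims=True)
  sims = a @ p.T                    # (C, C) cosine_sim(mean(anc_i), mean(pos_j))
  targets = np.eye(len(anc_means))  # want 1.0 for i==j, 0.0 otherwise
  return np.mean((sims - targets) ** 2)

# orthonormal embeddings that match their positives give zero loss
assert contrastive_loss(np.eye(3), np.eye(3)) == 0.0
</code></pre>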
<table class='data'>
<tr>
<td>example class 1 anchors ( top row ) and positives</td>
<td><img src="/blog/imgs/2024/yolz/s017000_train_anchor_positives.png"/></td>
</tr>
<tr>
<td>example class 2 anchors and positives</td>
<td><img src="/blog/imgs/2024/yolz/s006000_train_anchor_positives.png"/></td>
</tr>
</table>

<p>( for more details on contrastive learning
see <a href="https://keras.io/examples/vision/metric_learning/">this tute i did on keras.io</a> )</p>
<p>this trains an embedding space for sets of N examples that generalises to new classes. the random
backgrounds were included to make the training robust to background colour.</p>
<h2>model definition and fitting</h2>
<p>this all requires a non trivial fit loop since there's no "standard" sense of a batch anywhere, and
the contrastive embeddings in this model are a great example of the power of <code>jax.vmap</code> IMHO.</p>
<p>an image model will by default operate with a leading batch dim <code>(B, H, W, 3)</code>. since we want
to embed N objects this fits nicely as <code>(N, H, W, 3) -&gt; (N, E)</code>, after which we take the mean <code>(E)</code>
"outside" of the model forward pass.</p>
<p>but now we want to do this for <code>C</code> examples, so our data is <code>(C, N, H, W, 3)</code>. the
classic approach to this, since the CxN examples are independent, would be to reshape to <code>(CN, H, W, 3)</code>,
run through to <code>(CN, E)</code> after which we'd need to reshape back to <code>(C, N, E)</code> and finally run the mean over
the second axis to get <code>(C, E)</code></p>
<p>though it's totally doable as a manual thing, it's a lot simpler in my mind to compose
this instead as just a <code>vmap</code> over
the operation of <code>(N, H, W, 3) -&gt; (N, E) -&gt; (E)</code> ( especially since that last mean step happens
<em>outside</em> the model forward pass of batch=N. )</p>
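<p>with a toy stand-in for the embedding model, that composition looks like:</p>
<pre class="prettyprint"><code class="language-python">import jax
import jax.numpy as jnp

# toy stand-in for the embedding model; (N, H, W, 3) -> (N, E) with E=3
def embed(x):
  return x.mean(axis=(1, 2))

# (N, H, W, 3) -> (E); the mean happens outside the "forward pass"
def mean_embed(x):
  return embed(x).mean(axis=0)

x = jnp.ones((5, 4, 8, 8, 3))    # (C, N, H, W, 3)
out = jax.vmap(mean_embed)(x)    # (C, E) with no manual reshaping
</code></pre>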
<p>furthermore there's another bigger reason to get <code>vmap</code> to do this for us....</p>
<p>so far we've looked at a model that has two inputs;</p>
<ol>
<li>object references <code>(C, 2, N, 64, 64, 3)</code> and</li>
<li>scene <code>(640, 640, 3)</code></li>
</ol>
<p>but to get the best result in terms of gradient variance we actually want to batch the entire
composite model and give it inputs</p>
<ol>
<li>object references <code>(B, C, 2, N, 64, 64, 3)</code> and</li>
<li>scene <code>(B, 640, 640, 3)</code></li>
</ol>
<p>and this nesting of vmaps inside vmaps completely handles all the required axis remapping for us.</p>
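<p>as a sketch ( same toy embedding stand-in as before, made-up shapes ), nesting one more vmap handles the leading batch axis too:</p>
<pre class="prettyprint"><code class="language-python">import jax
import jax.numpy as jnp

def embed(x):       # toy (N, H, W, 3) -> (N, 3) embedding stand-in
  return x.mean(axis=(1, 2))

def mean_embed(x):  # (N, H, W, 3) -> (3,)
  return embed(x).mean(axis=0)

# vmap over C, then vmap that over B; all axis plumbing handled by the compiler
batched_mean_embed = jax.vmap(jax.vmap(mean_embed))

x = jnp.ones((2, 5, 4, 8, 8, 3))  # (B, C, N, H, W, 3)
out = batched_mean_embed(x)       # (B, C, E)
</code></pre>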
<h2>basic results</h2>
<p>here we show a basic example of inference on held out data.</p>
<ul>
<li>given an input image ( far left )</li>
<li>top row is reference objects and detection for 3 best performers</li>
<li>bottom row is reference objects and detection for 3 worst performers</li>
</ul>
<iframe width="640" height="480" src="https://www.youtube.com/embed/hLUtVB6r1r4?si=SZiPoxzgIE3ccn5D" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" referrerpolicy="strict-origin-when-cross-origin" allowfullscreen></iframe>

<h2>findings</h2>
<ul>
<li>keras3 and stateless_call with jax is a great combo and i'll be using this more</li>
<li>focal loss gives a non trivial bump to the performance of the scene classifier; makes me wonder if
an equivalent version can be used for the contrastive loss.</li>
<li>having a stop gradient from the scene loss to the object embeddings gives the result we
expect; the embeddings model has a lower loss but the overall classifier performs worse.</li>
</ul>
<h2>code</h2>
<p>all the code on <a href="https://github.com/matpalm/yolz">github</a></p>]]></content:encoded>
    </item>
    <item>
      <title>differentiable kalman filters in jax</title>
      <link>http://matpalm.com/blog/differentiable_kalman_filters_in_jax</link>
      <category><![CDATA[jax]]></category>
      <guid>http://matpalm.com/blog/differentiable_kalman_filters_in_jax</guid>
      <description>differentiable kalman filters in jax</description>
      <content:encoded><![CDATA[<h2>kalman filters</h2>
<p>was doing some work with kalman filters recently and, as always, wasn't sure what the best way to
tune the filter configuration.</p>
<p>in the past i've taken a simple grid/random search like approach
but was curious how i could express this tuning as an optimisation using jax instead.</p>
<p>first though, what is a
<a href="https://en.wikipedia.org/wiki/Kalman_filter">kalman filter</a>?
they're a pretty broad concept but the main thing i've used them for is dynamic system prediction.</p>
<p>they operate in a two step fashion...</p>
<ol>
<li>a <code>predict</code> step which predicts something about the system based on some internal state and</li>
<li>an <code>update</code> step which integrates new observation information into the filter's state ready for the next <code>predict</code></li>
</ol>
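<p>to make that two step cycle concrete before the 2D version below, here's a minimal 1D toy filter ( tracking a constant value; all the constants here are made up ):</p>
<pre class="prettyprint"><code class="language-python"># minimal 1D kalman filter; x is the state estimate, P its variance
def predict(x, P, Q=1e-2):
  # constant-state model; the estimate is unchanged, uncertainty grows
  return x, P + Q

def update(x, P, z, R=1e-1):
  K = P / (P + R)  # kalman gain
  return x + K * (z - x), (1 - K) * P

x, P = 0.0, 1.0
for z in [1.0, 1.1, 0.9, 1.0]:  # noisy observations of a true value of 1.0
  x, P = predict(x, P)
  x, P = update(x, P, z)
# x converges towards 1.0 and P shrinks as observations come in
</code></pre>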
<p>in this post we'll use a simple 2D trajectory as the system, e.g. throwing an object</p>
<p>the task of the kalman filter will be to predict the object's position at the next time step</p>
<h2>a simple dynamics system</h2>
<p>consider then the simple system of throwing an object under a trivial physics model</p>
<pre class="prettyprint"><code class="language-python">def simulate_throw(dx, dy):
  x, y = 0, 0
  for _ in range(10):
    yield x, y
    x += dx
    y += dy
    dy -= 1
</code></pre>

<p>we can use this to simulate a couple of throws and plot the trajectories...</p>
<pre class="prettyprint"><code class="language-python">draw_throw_with_colours(
    [list(simulate_throw(dx=3, dy=5)), list(simulate_throw(dx=4, dy=3))],
    ['red', 'green']
)
</code></pre>

<p><img src="/blog/imgs/2024/dkf/throw1.png"/></p>
<h2>a numpy kalman filter implementation</h2>
<p>can we use a kalman filter to predict the next state of these systems?
i do hope so, it's what they were designed for!</p>
<p>implementations of a kalman filter can vary a lot so let's just use this
<a href="https://machinelearningspace.com/2d-object-tracking-using-kalman-filter/">random kalman filter from the internet</a></p>
<p><i>note: using this implementation, on any random snippet of code from the internet, for, say,
controlling a rocket or something might be correctly considered a generally "bad idea". the
correctness, or otherwise, of this filter is irrelevant to the task of jaxifying it :)</i></p>
<pre class="prettyprint"><code class="language-python"># as is from https://machinelearningspace.com/2d-object-tracking-using-kalman-filter/ with some
# minor changes

class KalmanFilter(object):
  def __init__(self, dt, u_x,u_y, std_acc, x_std_meas, y_std_meas):
    """
    :param dt: sampling time (time for 1 cycle)
    :param u_x: acceleration in x-direction
    :param u_y: acceleration in y-direction
    :param std_acc: process noise magnitude
    :param x_std_meas: standard deviation of the measurement in x-direction
    :param y_std_meas: standard deviation of the measurement in y-direction
    """
    # Define sampling time
    self.dt = dt
    # Define the  control input variables
    self.u = np.matrix([[u_x],[u_y]])
    # Intial State
    self.x = np.matrix([[0, 0], [0, 0], [0, 0], [0, 0]])
    # Define the State Transition Matrix A
    self.A = np.matrix([[1, 0, self.dt, 0],
                        [0, 1, 0, self.dt],
                        [0, 0, 1, 0],
                        [0, 0, 0, 1]])
    # Define the Control Input Matrix B
    self.B = np.matrix([[(self.dt**2)/2, 0],
                        [0, (self.dt**2)/2],
                        [self.dt, 0],
                        [0, self.dt]])
    # Define Measurement Mapping Matrix
    self.H = np.matrix([[1, 0, 0, 0],
                        [0, 1, 0, 0]])
    # Initial Process Noise Covariance
    self.Q = np.matrix([[(self.dt**4)/4, 0, (self.dt**3)/2, 0],
                        [0, (self.dt**4)/4, 0, (self.dt**3)/2],
                        [(self.dt**3)/2, 0, self.dt**2, 0],
                        [0, (self.dt**3)/2, 0, self.dt**2]]) * std_acc**2
    # Initial Measurement Noise Covariance
    self.R = np.matrix([[x_std_meas**2,0],
                        [0, y_std_meas**2]])
    # Initial Covariance Matrix
    self.P = np.eye(self.A.shape[1])

  def predict(self):
    # Refer to :Eq.(9) and Eq.(10)
    # Update time state
    #x_k =Ax_(k-1) + Bu_(k-1)     Eq.(9)
    self.x = np.dot(self.A, self.x) + np.dot(self.B, self.u)
    # Calculate error covariance
    # P= A*P*A' + Q               Eq.(10)
    self.P = np.dot(np.dot(self.A, self.P), self.A.T) + self.Q
    return self.x[0]

  def update(self, z):
    # Refer to :Eq.(11), Eq.(12) and Eq.(13)
    # S = H*P*H'+R
    S = np.dot(self.H, np.dot(self.P, self.H.T)) + self.R
    # Calculate the Kalman Gain
    # K = P * H'* inv(H*P*H'+R)
    K = np.dot(np.dot(self.P, self.H.T), np.linalg.inv(S))  #Eq.(11)
    self.x = self.x + np.dot(K, (z - np.dot(self.H, self.x)))   #Eq.(12)
    I = np.eye(self.H.shape[1])
    # Update error covariance matrix
    self.P = (I - (K * self.H)) * self.P   #Eq.(13)
</code></pre>

<p>a couple of things to note about this filter</p>
<ul>
<li>recall the api is two methods; <code>predict</code> which provides an estimate of <code>x</code> and <code>update</code> which updates
the internal state of the filter based on a real observation. the implementations of these are
kinda opaque and i wouldn't be surprised if there are a stack of subtle bugs here that a
dynamic systems expert could spot. i am happy to take it "as is" for the purpose of poking around in some jax</li>
<li>since <em>both</em> <code>predict</code> and <code>update</code> change the internal state of the filter it's expected they
are called in sequence; <code>predict</code>, <code>update</code>, <code>predict</code>, <code>update</code> etc</li>
<li>there are a bunch of cryptically named variables; <code>u</code>, <code>B</code>, <code>H</code> etc, some of which are
to do with the internal state of the filter ( like <code>P</code> ) with others representing a form of config
around how we expect the dynamics of the system to behave ( like <code>A</code> ). these latter
matrices are configured based off scalar values such as <code>dt</code> and <code>x_std_meas</code>.</li>
</ul>
<h2>predicting a trajectory with the numpy kalman filter</h2>
<p>we can use this filter as is to make predictions about the next time step of a throw, and it's
not too bad at it...</p>
<pre class="prettyprint"><code class="language-python"># construct a filter with some config
filter = KalmanFilter(
    dt=1.0, u_x=0, u_y=0,
    std_acc=1.0, x_std_meas=0.1, y_std_meas=0.1)

# simulate a throw
xy_trues = list(simulate_throw(2.8, 4.8))  # materialise; the generator is consumed twice below

# step through the trajectory
xy_preds = []
for xy_true in xy_trues:
  # make prediction based on filter and record it
  xy_pred = filter.predict()
  xy_preds.append(xy_pred)  
  # update the filter state based on the true value
  filter.update(xy_true)

# plot the pair
# * red denotes true values,
# * green denotes predicted values based on the time step before
xy_preds = np.stack(xy_preds).squeeze()
draw_throw_with_colours(
    [xy_trues, xy_preds],
    ['red',    'green'])
</code></pre>

<p><img src="/blog/imgs/2024/dkf/throw_predict1.png"/></p>
<h2>porting to pure functional jax</h2>
<p>next let's port this kalman filter to jax. there's a few aspects of this....</p>
<h3>params and states</h3>
<p>a key concept in the jax port is making <code>predict</code> and <code>update</code> fully functional,
and that means taking a state and returning it for the two methods. i.e. something like...</p>
<pre class="prettyprint"><code class="language-python">def predict(params, state):
  ...
  return state, xy_pred

def update(params, state, z):
  ...
  return state
</code></pre>

<p>we can then have all the <code>P</code>, <code>A</code>, <code>Q</code>, etc go in either <code>params</code> or <code>state</code>.</p>
<p>note: we are going to be explicit about two different types of variables we are dealing with here...</p>
<ul>
<li><code>state</code> represents the internal state of the filter that changes over time 
based on the sequence of <code>predict</code>, <code>update</code>, <code>predict</code>, ... calls</li>
<li><code>params</code> represents the configuration items, based on <code>dt</code> etc, that we eventually
want to get gradients for ( with respect to a loss function ) </li>
</ul>
<p>the pairing of <code>predict</code> then <code>update</code> can be expressed as the following
that reassigns <code>state</code> each method call</p>
<pre class="prettyprint"><code class="language-python">def predict_then_update(params, state, xy_true):
  state, xy_pred = predict(params, state)
  state = update(params, state, xy_true)
  return state, xy_pred
</code></pre>

<p>what type are <code>params</code> and <code>state</code> then?
we want them to be collections of variables which are well supported in jax
under the idea of
<a href="https://jax.readthedocs.io/en/latest/pytrees.html">pytrees</a>
and we can get a long way with just thinking of them as dictionaries...</p>
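<p>for example, jax's tree utilities operate over every leaf array in such a dictionary ( toy params here, shaped like ours ):</p>
<pre class="prettyprint"><code class="language-python">import jax
import jax.numpy as jnp

params = {'A': jnp.eye(4), 'R': 0.01 * jnp.eye(2)}  # a params pytree

# tree_map applies a function to every leaf array in the structure;
# this is also how optimisers apply per-leaf updates to params
scaled = jax.tree_util.tree_map(lambda a: 2 * a, params)
</code></pre>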
<p>note for this jax port:</p>
<ul>
<li>i switched from <code>jnp.dot</code> to <code>@</code> which, IMHO, reads easier.</li>
<li>given no control inputs i've dropped <code>u_x</code> and <code>u_y</code> which were 0.0s anyways. and no <code>u</code> implies no <code>B</code> either...</li>
<li>the implementation feels a bit clunky with dictionary look ups but oh well...</li>
</ul>
<pre class="prettyprint"><code class="language-python">def predict(params, state):
  state['x'] = params['A'] @ state['x']
  state['P'] = ((params['A'] @ state['P']) @ params['A'].T) + params['Q']
  xy_pred = state['x'][0]
  return state, xy_pred

def update(params, state, z):
  # Define Measurement Mapping Matrix
  H = jnp.array([[1, 0, 0, 0],
                 [0, 1, 0, 0]])
  S = (H @ (state['P'] @ H.T)) + params['R']
  K = (state['P'] @ H.T) @ jnp.linalg.inv(S)
  state['x'] = state['x'] + (K @ (z - (H @ state['x'])))
  I = jnp.eye(4)
  state['P'] = (I - (K @ H)) @ state['P']
  return state
</code></pre>

<p>the final missing piece then is how we define the initial values for <code>params</code> and <code>state</code></p>
<pre class="prettyprint"><code class="language-python">def default_params():
  dt = 1.0
  std_acc = 1.0
  x_std_meas, y_std_meas = 0.1, 0.1
  return {
    # Define the State Transition Matrix A
    'A': jnp.array([[1, 0, dt, 0],
                    [0, 1, 0, dt],
                    [0, 0, 1, 0],
                    [0, 0, 0, 1]]),
    # Initial Measurement Noise Covariance
    'R': jnp.array([[x_std_meas**2, 0],
                    [0, y_std_meas**2]]),
    # Initial Process Noise Covariance
    'Q': jnp.array([[(dt**4)/4, 0, (dt**3)/2, 0],
                    [0, (dt**4)/4, 0, (dt**3)/2],
                    [(dt**3)/2, 0, dt**2, 0],
                    [0, (dt**3)/2, 0, dt**2]]) * std_acc**2
  }

def initial_state():
  return {
    # Initial State
    'x': jnp.zeros((4, 2)),
    # Initial Covariance Matrix
    'P': jnp.eye(4),
  }
</code></pre>

<p>which all comes together like the numpy one as ...</p>
<pre class="prettyprint"><code class="language-python">params = default_params()
state = initial_state()
xy_trues = list(simulate_throw(2.8, 4.8))  # materialise; the generator is consumed twice below
xy_preds = []
for xy_true in xy_trues:
    state, xy_pred = predict_then_update(params, state, xy_true)
    xy_preds.append(xy_pred)
xy_preds = np.stack(xy_preds)

draw_throw_with_colours(
    [xy_trues, xy_preds],
    ['red',    'green'])
</code></pre>

<p><img src="/blog/imgs/2024/dkf/throw_predict2.png"/></p>
<h2>a minor diversion regarding rolling out kalman filters and teacher forcing.</h2>
<p>before we go any deeper into jax land let's talk about one more aspect of using kalman filters.</p>
<p>the main idea touched on already is that we can use them to make a prediction about
the next state of a system <em>before</em> we observe it.</p>
<pre class="prettyprint"><code class="language-python">xy_pred_t0 = predict()
# wait for actual xy_true_t0
update(xy_true_t0)
xy_pred_t1 = predict()
# wait for actual xy_true_t1
update(xy_true_t1)
</code></pre>

<p>but if we really trust the filter we can use it to also handle <em>missing</em> observations
by passing in the last prediction as the observed value when we don't have one ( for whatever reason )</p>
<pre class="prettyprint"><code class="language-python">xy_pred_t0 = predict()
# oh oh! for whatever reason we don't have _t0
update(xy_pred_t0)  # use predicted instead
xy_pred_t1 = predict()
# wait for actual xy_true_t1
update(xy_true_t1)
</code></pre>

<p>this can be handy for a signal that's noisy and dropping out <b>but</b> it does put more pressure
on the filter to be robust to any compounding error it might exhibit.</p>
<p>during training then ( which we still haven't talked about yet ) we can induce this
by occasionally randomly dropping out
<code>xy_true</code> values and using the prior <code>xy_pred</code> value instead.</p>
<p>( those with a background in RNNs might notice this is the same as
<a href="https://en.wikipedia.org/wiki/Teacher_forcing">teacher forcing</a> )</p>
<p>the code to do this based on a 20% dropout can be....</p>
<pre class="prettyprint"><code class="language-python">def predict_then_update(params, state, has_observation, xy_true):
  state, xy_pred = predict(params, state)
  xy_for_update = jnp.where(has_observation, xy_true, xy_pred)
  state = update(params, state, xy_for_update)
  return state, xy_pred

rng = np.random.default_rng()
xy_preds = []
has_observations = []
for xy_true in xy_trues:
  has_observation = rng.uniform() > 0.2
  has_observations.append(has_observation)
  state, xy_pred = predict_then_update(params, state, has_observation, xy_true)
  xy_preds.append(xy_pred)
xy_preds = np.stack(xy_preds)

print("has_observations", has_observations)

draw_throw_with_colours(
    [xy_trues, xy_preds],
    ['red', 'green'])
</code></pre>

<pre class="prettyprint"><code class="language-python">has_observations [True, True, True, True, True, False, False, True, False, False]
</code></pre>

<p><img src="/blog/imgs/2024/dkf/throw_predict_tf0.2.png" /></p>
<p>notice how the filter shoots way off after that pair of Falses. the <code>default_params</code> values
might be good enough to predict one step in the future, but they don't look robust to predicting two
or more steps in the future.</p>
<h2>time to jax things up!!!</h2>
<h3>jax.lax.scan</h3>
<p>a key first thing to do is to implement the for loop with 
<a href="https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html">jax.lax.scan</a>
so that jax can more cleanly trace it.</p>
<p><code>jax.lax.scan</code> provides the classic functional programming idea of iterating over a method
with a carried state.</p>
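<p>as a tiny standalone example of scan's ( carry, per step output ) contract:</p>
<pre class="prettyprint"><code class="language-python">import jax
import jax.numpy as jnp

def step(carry, x):
  carry = carry + x
  return carry, carry  # (new carry, output recorded for this step)

final, running = jax.lax.scan(step, 0.0, jnp.arange(1.0, 5.0))
# final carry is 10.0; running totals are [1., 3., 6., 10.]
</code></pre>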
<p>note:</p>
<ul>
<li>we move the random allocation of <code>has_observation</code> into the jax function
which means having to assign a <code>key</code> to the <code>state</code> that will be carried along
between calls to <code>predict_then_update_single</code> </li>
<li>the scanning works cleanly because each <code>xy_trues</code> sequence is the same length. in the
cases where we might have a different roll out length we'd need to use something like zero padding
and loss masking or bucketing by sequence length.</li>
<li>we get the entire set of <code>xy_preds</code> in a single call from <code>xy_trues</code></li>
</ul>
<pre class="prettyprint"><code class="language-python">def initial_state(seed):
  return {
    # Initial State
    'x': jnp.zeros((4, 2)),
    # Initial Covariance Matrix
    'P': jnp.eye(4),
    # rng key for missing observations
    'key': jax.random.key(seed)
  }

def rolled_out_predict_then_update(params, seed, xy_trues, missing_rate):

  def predict_then_update_single(state, xy_true):
    state, xy_pred = predict(params, state)

    state['key'], subkey = jax.random.split(state['key'])
    has_observation = jax.random.uniform(subkey) > missing_rate
    xy_for_update = jnp.where(has_observation, xy_true, xy_pred)

    state = update(params, state, xy_for_update)
    return state, xy_pred

  _final_state, predictions = jax.lax.scan(predict_then_update_single, initial_state(seed), xy_trues)
  return predictions

xy_trues = np.array(list(simulate_throw(2.8, 4.8)))

seed = 1234414
xy_preds = jax.jit(rolled_out_predict_then_update)(default_params(), seed, xy_trues, missing_rate=0.2)

draw_throw_with_colours(
    [xy_trues, xy_preds],
    ['red',    'green'])

</code></pre>

<p><img src="/blog/imgs/2024/dkf/throw_predict.scan.tf0.2.png" /></p>
<p>this example has a bunch of missed observations and the filter struggles quite a bit :/</p>
<h3>a loss function, jax.grad &amp; a simple training loop</h3>
<p>with a function now that takes a list of <code>xy_true</code> values and returns a list of <code>xy_pred</code> values
we can start to think about a loss, the simplest being mean square error.</p>
<pre class="prettyprint"><code class="language-python">def loss_fn(params, seed, xy_trues, missing_rate):
  predictions = rolled_out_predict_then_update(params, seed, xy_trues, missing_rate)
  squared_difference = (predictions - xy_trues) ** 2
  return jnp.mean(squared_difference)
</code></pre>

<p>it's interesting to note how the losses are quite unstable across seeds,
especially as we increase the <code>missing_rate</code>. </p>
<pre class="prettyprint"><code class="language-python">for missing_rate in np.linspace(0.0, 0.9, 10):
  losses = [loss_fn(default_params(), seed, xy_trues, missing_rate)
            for seed in range(100)]  
  print(missing_rate, np.mean(losses), np.std(losses))

0.0 2.359239  2.3841858e-07
0.1 5.6141863 10.308739
0.2 18.179102 53.249737
0.3 88.48642  323.95425
0.4 173.53473 505.12674
0.5 269.37463 636.1734
0.6 497.91122 774.9014
0.7 712.81384 858.53827
0.8 989.5386  972.14777
0.9 661.3348  749.5342
</code></pre>

<p>having a loss function means we can get gradients with respect to the params and use them
in a trivial gradient descent update step</p>
<pre class="prettyprint"><code class="language-python">@jax.jit
def update_step(params, seed, xy_trues):
  gradients = jax.grad(loss_fn)(params, seed, xy_trues, missing_rate=0.1)

  def apply_gradients(p, g):
    learning_rate = 1e-5
    return p - learning_rate * g

  return jax.tree_util.tree_map(apply_gradients, params, gradients)
</code></pre>

<p>this update step allows us to sample a trajectory, run it through the rolled out filter, calculate
a loss and gradient and finally update the params...</p>
<pre class="prettyprint"><code class="language-python">params = default_params()
seed = 0
for i in range(1000):
  dx = 2 + np.random.uniform() * 5  # (2, 7)
  dy = 4 + np.random.uniform() * 4  # (4, 8)
  xy_trues = np.array(list(simulate_throw(dx, dy)))
  params = update_step(params, seed, xy_trues)
  if i % 100 == 0:
    print("loss", i, loss_fn(params, seed, xy_trues, missing_rate=0.1))
  seed += 1
</code></pre>

<pre>
loss 0   2.5613098
loss 100 12.690718
loss 200 4.0097485
loss 300 4.4077697
loss 400 4.8039966
loss 500 3.5365129
loss 600 3.0296855
loss 700 2.361609
loss 800 3.0249815
loss 900 1.7365919
</pre>

<p>we see the loss has dropped, hooray! but how does the filter behave?</p>
<p>if we plot some examples of the <code>default_params</code> versus these trained <code>params</code> for
a range of <code>missing_rate</code> we can see the trained filter is much more robust to missing values</p>
<ul>
<li>red represents ground truth</li>
<li>green represents the filter behaviour with the default params</li>
<li>yellow represents the filter behaviour with the trained params</li>
</ul>
<pre class="prettyprint"><code class="language-python">xy_trues = simulate_throw_a(3, 5)
seed = 1234414

def throw_img(xy_trues, seed, missing_rate):
  xy_preds_initial_params = rolled_out_predict_then_update(default_params(), seed, xy_trues, missing_rate)
  xy_preds_trained_params = rolled_out_predict_then_update(params, seed, xy_trues, missing_rate)
  return draw_throw_with_colours(
      [xy_trues, xy_preds_initial_params, xy_preds_trained_params],
      ['red',    'green',                 'yellow'])
</code></pre>

<table>
<tr>
<td><img src="/blog/imgs/2024/dkf/default_vs_trained.mr0.0.png"/></td>
<td><img src="/blog/imgs/2024/dkf/default_vs_trained.mr0.2.png"/></td>
<td><img src="/blog/imgs/2024/dkf/default_vs_trained.mr0.4.png"/></td>
</tr>
<tr>
<td>missing_rate=0.0</td>
<td>missing_rate=0.2</td>
<td>missing_rate=0.4</td>
</tr>
</table>

<h2>how do the optimised params differ?</h2>
<p>let's look at the differences between the original params and the ones that were learnt....</p>
<table class='data'>
<tr><th>param</th><th>original</th><th>learnt</th></tr>
<tr>
<td>A</td>
<td><pre>
[[1. 0. 1. 0.] 
 [0. 1. 0. 1.] 
 [0. 0. 1. 0.] 
 [0. 0. 0. 1.]]
</pre></td>
<td><pre>
[[ 0.958  0.014  0.897  0.038] 
 [ 0.024  1.016  0.021  1.020] 
 [ 0.004 -0.019  0.967 -0.003] 
 [-0.026 -0.020  0.024  1.025]]
</pre></td>
</tr>
<tr>
<td>Q</td>
<td><pre>
[[0.25 0.   0.5  0.  ] 
 [0.   0.25 0.   0.5 ] 
 [0.5  0.   1.   0.  ] 
 [0.   0.5  0.   1.  ]]
</pre></td>
<td><pre>
[[ 0.278  0.064  0.470 -0.035] 
 [-0.001  0.279  0.013  0.513] 
 [ 0.504 -0.027  1.031  0.042] 
 [-0.002  0.465  0.003  1.005]]
</pre></td>
</tr>
<tr>
<td>R</td>
<td><pre>
[[0.01 0.  ] 
 [0.   0.01]]
</pre></td>
<td><pre>
[[0.069 0.029] 
 [0.011 0.137]]
</pre></td>
</tr>
</table>

<p>we can see that <code>A</code> and <code>Q</code> had minor changes but <code>R</code> was changed much more, particularly the value
related to <code>y_std_meas</code>.</p>
<h2>but what are we actually optimising?</h2>
<p>seeing this result made me realise something. by providing the full <code>A</code> matrix to be optimised
we end up with non-zero and non-one values, whereas i really only wanted to tune for <code>dt</code></p>
<p>it is interesting that the model has tuned the transition matrix fully,
but maybe in some cases it's better to constrain things and only allow it to change <code>dt</code></p>
<p>to do this we just need to be more explicit about what the actual parameter set is..</p>
<pre class="prettyprint"><code class="language-python">def default_params():
  return {
    'dt': 1.0,
    'std_acc': 1.0,
    'x_std_meas': 0.1,
    'y_std_meas': 0.1
  }
</code></pre>

<p>and then materialise <code>A</code>, <code>Q</code> and <code>R</code> from the <code>params</code> as required in <code>predict</code> and <code>update</code></p>
<pre class="prettyprint"><code class="language-python">def predict(params, state):
  dt = params['dt']
  # State Transition Matrix
  A = jnp.array([[1, 0, dt, 0],
                 [0, 1, 0, dt],
                 [0, 0, 1, 0],
                 [0, 0, 0, 1]])
  # Process Noise Covariance
  Q = jnp.array([[(dt**4)/4, 0, (dt**3)/2, 0],
                 [0, (dt**4)/4, 0, (dt**3)/2],
                 [(dt**3)/2, 0, dt**2, 0],
                 [0, (dt**3)/2, 0, dt**2]]) * params['std_acc']**2

  state['x'] = A @ state['x']
  state['P'] = ((A @ state['P']) @ A.T) + Q
  xy_pred = state['x'][0]
  return state, xy_pred

def update(params, state, z):
  # Define Measurement Mapping Matrix
  H = jnp.array([[1, 0, 0, 0],
                 [0, 1, 0, 0]])
  # Measurement Noise Covariance
  R = jnp.array([[params['x_std_meas'] **2, 0],
                 [0, params['y_std_meas'] **2]])

  S = (H @ (state['P'] @ H.T)) + R
  K = (state['P'] @ H.T) @ jnp.linalg.inv(S)
  state['x'] = state['x'] + (K @ (z - (H @ state['x'])))
  I = jnp.eye(4)
  state['P'] = (I - (K @ H)) @ state['P']
  return state
</code></pre>

<p>doing this makes for a minor difference in the loss, and it's not really even visible in the trace
visualisation.</p>
<p>to be honest though, given the instability of the loss, there are bigger things at play here
in terms of ways to improve...</p>
<table class='data'>
<tr> <th>step</th> <th>full matrix params</th> <th>scalar params</th> </tr>
<tr> <td>0</td>   <td>2.561</td> <td>5.138</td> </tr>
<tr> <td>100</td> <td>12.69</td> <td>24.98</td> </tr>
<tr> <td>200</td> <td>4.009</td> <td>2.928</td> </tr>
<tr> <td>300</td> <td>4.407</td> <td>6.072</td> </tr>
<tr> <td>400</td> <td>4.803</td> <td>2.422</td> </tr>
<tr> <td>500</td> <td>3.536</td> <td>2.577</td> </tr>
<tr> <td>600</td> <td>3.029</td> <td>4.388</td> </tr>
<tr> <td>700</td> <td>2.361</td> <td>3.006</td> </tr>
<tr> <td>800</td> <td>3.024</td> <td>1.631</td> </tr>
<tr> <td>900</td> <td>1.736</td> <td>3.306</td> </tr>
</table>

<h2>some extensions</h2>
<h3>jax.vmap</h3>
<p>running an update step <em>per</em> example is often troublesome. not only is it slow (since we're
not making use of as much vectorisation as possible) it can also suffer a lot from gradient
variance problems.</p>
<p>the general approach is to batch things and calculate gradients with respect to multiple examples
before an update step.
<a href="https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html">jax.vmap</a>
is perfect for this...</p>
<p>we can express the loss with respect to multiple examples by using jax to make a version of
the rollout function that rolls out multiple trajectories at once...</p>
<p>note: <code>in_axes</code> tells the vmap transform to vectorise over the
second and third args of the loss function ( <code>seed</code> &amp; <code>xy_trues</code> )
while broadcasting the first and last args ( <code>params</code> &amp; <code>missing_rate</code> )</p>
<pre class="prettyprint"><code class="language-python">def loss_fn(params, seeds, xy_truess, missing_rate):
  v_rolled_out_predict_then_update = jax.vmap(rolled_out_predict_then_update, in_axes=[None, 0, 0, None])
  predictions = v_rolled_out_predict_then_update(params, seeds, xy_truess, missing_rate)
  squared_difference = (predictions - xy_truess) ** 2
  return jnp.mean(squared_difference)
</code></pre>
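<p>as a minimal standalone example of the <code>in_axes</code> semantics ( the function here is made up purely for illustration )...</p>
<pre class="prettyprint"><code class="language-python">import jax
import jax.numpy as jnp

def scaled_sum(scale, xs):
  return scale * jnp.sum(xs)

# broadcast `scale` ( None ) while vectorising over the leading axis of `xs` ( 0 )
v_scaled_sum = jax.vmap(scaled_sum, in_axes=[None, 0])
v_scaled_sum(2.0, jnp.ones((3, 4)))  # -> [8., 8., 8.]
</code></pre>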

<p>this loss is called as before but instead with a batch of <code>seeds</code> and <code>xy_trues</code></p>
<pre class="prettyprint"><code class="language-python">batch_size = 8
next_seed = 0
xy_truess = []
seeds = []
for _ in range(batch_size):
  dx = 2 + np.random.uniform() * 5  # (2, 7)
  dy = 4 + np.random.uniform() * 4  # (4, 8)
  xy_truess.append(simulate_throw_a(dx, dy))
  seeds.append(next_seed)
  next_seed += 1
xy_truess = np.stack(xy_truess)
seeds = np.array(seeds)
...
params = update_step(params, seeds, xy_truess)
</code></pre>

<p>interestingly i found this didn't actually help! usually for a neural network this is a big win, for speed as well as
gradient variance, but in this case it was behaving <em>worse</em> for me :/</p>
<h3>why you no optax?</h3>
<p>rolling your own optimisation step is generally a bad idea; it's better to use an update step
built on an optax optimiser.</p>
<pre class="prettyprint"><code class="language-python">params = default_params()
opt = optax.adam(1e-6)
opt_state = opt.init(params)
...
@jax.jit
def update_step(params, opt_state, seeds, xy_truess):
  gradients = jax.grad(loss_fn)(params, seeds, xy_truess, missing_rate=0.1)  
  updates, opt_state = opt.update(gradients, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state
...
for i in range(1000):
  ...
  params, opt_state = update_step(params, opt_state, seeds, xy_truess)
  ...
</code></pre>

<p>having said that, trying a couple of different optimisers gave no better result than the simple
hand rolled update step from above!!! i guess the dynamics are weird enough that they don't fit
the expected default space of adam, rmsprop, etc (?) am betting the noisy behaviour of the rollout
is at fault...</p>
<h3>cross product of possible data</h3>
<p>sometimes the best way to avoid noisy samples is to go nuts with a large fixed dataset
e.g. if we just materialise a large range of cross product values for <code>dx</code>, <code>dy</code> ...</p>
<pre class="prettyprint"><code class="language-python">N = 40
SEEDS_PER_DX_DY = 50
xy_truess = []
for dx_l in np.linspace(0, 1, N):  
  for dy_l in np.linspace(0, 1, N):    
    xy_trues = simulate_throw_a(dx=2+(dx_l*5), dy=4+(dy_l*4))
    for _ in range(SEEDS_PER_DX_DY):
      xy_truess.append(xy_trues)
xy_truess = np.stack(xy_truess)
seeds = np.array(range(len(xy_truess)))
print("xy_truess", xy_truess.shape, "seeds", seeds.shape)
</code></pre>

<pre>
xy_truess (80000, 10, 2) seeds (80000,)
</pre>

<p>we can then <code>jit</code> this fixed data into the update step</p>
<pre class="prettyprint"><code class="language-python">@jax.jit
def update_step(params, opt_state):
  gradients = jax.grad(loss_fn)(params, seeds, xy_truess, missing_rate=0.1)
  updates, opt_state = opt.update(gradients, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state

for i in range(1000):
  params, opt_state = update_step(params, opt_state)
  if i % 100 == 0:
    print("loss", i, loss_fn(params, seeds, xy_truess, missing_rate=0.1))

</code></pre>

<p>and we get a fast update and the most stable loss, though it flattens out
pretty quickly</p>
<pre>
loss 0 9.521738
loss 100 5.65931
loss 200 5.651896
loss 300 5.6428213
loss 400 5.636997
...
</pre>

<table class='data'>
<tr> <th>variable</th>    <th>initial</th> <th>trained</th></tr>
<tr> <td>dt</td>          <td>1.0</td>     <td>1.3539</td> </tr>
<tr> <td>std_acc</td>     <td>1.0</td>     <td>0.1803</td> </tr>
<tr> <td>x_std_meas</td>  <td>0.1</td>     <td>0.1394</td> </tr>
<tr> <td>y_std_meas</td>  <td>0.1</td>     <td>0.1</td> </tr>
</table>

<p>curiously in this formulation we've ended up with no gradients with respect to <code>y_std_meas</code> (?)</p>
<p>so i've either got a bug or it's the subtle way it's being used.. TODO!</p>
<p>anyways, that's my time up. check out
<a href="https://colab.research.google.com/drive/15051eqBadUxQifxR-4nS-qMYbYaoW7Ne?usp=sharing">the prototype colab i made for this post</a></p>]]></content:encoded>
    </item>
    <item>
      <title>a (larger) wavenet neural net running on an FPGA at (almost) 200,000 inferences / sec</title>
      <link>http://matpalm.com/blog/wavenet_on_fpga</link>
      <category><![CDATA[eurorack]]></category>
      <category><![CDATA[wavenet]]></category>
      <category><![CDATA[fpga]]></category>
      <guid>http://matpalm.com/blog/wavenet_on_fpga</guid>
      <description>a (larger) wavenet neural net running on an FPGA at (almost) 200,000 inferences / sec</description>
      <content:encoded><![CDATA[<h2>neural nets on an MCU?</h2>
<p>in my last post, <a href="https://matpalm.com/blog/wavenet_on_mcu/">wavenet on an mcu</a>, we talked about
running a cached optimised version of wavenet on an MCU at 48kHz. this time we're going to get
things running on an FPGA instead!</p>
<p><i>note: this post assumes you've
read that last one; it describes the model architecture, some caching tricks, and describes a
waveshaping task.</i></p>
<h2>neural nets on an FPGA?</h2>
<p>when doing some research on how to make things faster i stumbled on the
<a href="https://github.com/apfelaudio/eurorack-pmod">eurorack-pmod project</a> which is a piece of hardware
that connects an FPGA with a eurorack setup. it includes 192kHz analog-digital-analog conversion
and supports 4 channels in and out. just add an FPGA! perfect!</p>
<p><img src="/blog/imgs/2023/wv_fpga/setup.jpg"/></p>
<p>i'd never done anything beyond blinking LEDs with an FPGA but it can't be that
hard.... right? right?!?!</p>
<p>well. turns out it wasn't that simple... i couldn't find any open source examples of
compiling a neural net to an FPGA. the closest was <a href="https://fastmachinelearning.org/hls4ml/">HLS4ML</a>
but, if i understand it correctly, it works only with (expensive) proprietary FPGA toolchains :(</p>
<p>HLS4ML did at least put me onto <a href="https://github.com/google/qkeras">qkeras</a> which is a key component we'll talk about in a second.</p>
<p>( not interested in details? jump straight to the bottom for the demos )</p>
<h2>multiple versions for training and optimised inference</h2>
<p>recall from the last post that an important aspect of the MCU project was having
multiple versions for training vs inference</p>
<ul>
<li>a <a href="https://keras.io/">keras</a> version; the main version for training the weights.
<a href="https://github.com/matpalm/cached_dilated_causal_convolutions/tree/master/keras_version">code</a></li>
<li>a <a href="https://pypi.org/project/cmsisdsp/">cmsisdsp_py</a> version; a python prototype used to develop the caching inference &amp; understand the cmsis api.
<a href="https://github.com/matpalm/cached_dilated_causal_convolutions/tree/master/cmsisdsp_py_version">code</a></li>
<li>an inference firmware; the final c++ firmware for running on the MCU using
<a href="https://www.keil.com/pack/doc/CMSIS/DSP/html/index.html">cmsis-dsp</a> for all the matrix math
<a href="https://github.com/matpalm/cached_dilated_causal_convolutions/tree/master/inference_firmware">code</a></li>
</ul>
<p>the FPGA version ended with a similar collection of variants;</p>
<ul>
<li>a <a href="https://github.com/google/qkeras">qkeras</a> version; main version for training quantised weights.
<a href="https://github.com/matpalm/cached_dilated_causal_convolutions/tree/master/qkeras_version">code</a></li>
<li>a <a href="https://github.com/francof2a/fxpmath">fxpmath</a> version; a python prototype of hand rolled fixed point inference ( reusing the caching inference from the MCU version ).
<a href="https://github.com/matpalm/cached_dilated_causal_convolutions/tree/master/fxpmath_version">code</a></li>
<li>a <a href="https://en.wikipedia.org/wiki/Verilog">verilog</a> version; an FPGA version running first in simulation and then on eurorack pmod hardware.
<a href="https://github.com/matpalm/cached_dilated_causal_convolutions/tree/master/sverilog_version">code</a></li>
</ul>
<p>for both the MCU and FPGA versions we use a single firmware running on the daisy patch
for
<a href="https://github.com/matpalm/cached_dilated_causal_convolutions/tree/master/datalogger_firmware">training data collection</a>.</p>
<p>let's expand on each part of the FPGA version...</p>
<h3>qkeras</h3>
<p>neural networks end up doing a lot of matrix multiplication. <em>a lot</em>.</p>
<p>for training it's common to work with a floating point representation since it gives the best continuous view of the loss space.</p>
<p>but for inference we can often reduce the precision of the weights and use integer arithmetic.
we're motivated to do this since, for a lot of systems, integer math is <em>much</em> faster to run than full floating point math.</p>
<p>( and, actually, in a lot of systems we just simply might not have the ability to do floating point math at all! )</p>
<p>the process of converting from floats during training to something simpler for inference is called quantisation and we'll look at two flavours of it for this project.</p>
<h4>side note 1; fixed point arithmetic</h4>
<p>for this project i mainly used <a href="https://en.wikipedia.org/wiki/Fixed-point_arithmetic">fixed point numbers</a>.</p>
<p>fixed point numbers are a simpler representation of floating point numbers with
some constraints around range and precision but they allow the multiplication to be done as if it were integer
multiplication. ( this is the first project i've done with fixed point math and i'm hooked, it's perfect for neural nets! )</p>
<p>the high level idea is you specify a total number of bits and then how many of those bits you want to use
for the integer part of the number, and how many you want to use for representing the fractional part.</p>
<p>in this project all inputs, outputs, weights and biases are 16 bits in total with 4 bits for the integer
and 12 bits for the fractional part. ( i.e. FP4.12 )</p>
<p>with "only" 4 bits for the integer part this means the range of values is +/- 2^3 = 8. though this might seem
limiting, it's actually ok for a network, where generally activations etc are centered on zero ( especially if we add
a bit of L2 regularisation along the way )</p>
<p>with 12 bits allocated to the fractional part we are able to describe numbers with a precision of 2^-12 = 0.00024.</p>
<p>here are some examples; we show the 16 bit binary number with a decimal point after the 4th bit to denote the change
from the integer part to the fractional part.</p>
<table class='data'>
<tr><th>bits</th><th>decimal</th></tr>
<tr><td>0010.0000 0000 0000</td><td>2^1 = 2</td></tr>
<tr><td>0101.0000 0000 0000</td><td>2^2 + 2^0 = 4 + 1 = 5</td></tr>
<tr><td>0000.0000 0000 0000</td><td>0</td></tr>
<tr><td>0000.1000 0000 0000</td><td>2^-1 = 0.5</td></tr>
<tr><td>0000.1001 0000 0000</td><td>2^-1 + 2^-4 = 0.5 + 0.0625 = 0.5625</td></tr>
<tr><td>0000.0000 0000 0001</td><td>2^-12 = 0.000244140625</td></tr>
</table>
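<p>a quick python sketch of the idea ( hand rolled helpers for illustration, not the fxpmath api )...</p>
<pre class="prettyprint"><code class="language-python">def to_fp4_12(x):
  # quantise a float to FP4.12; 16 bits total, 12 of them fractional,
  # stored as a plain two's complement integer
  return int(round(x * (1 << 12)))

def from_fp4_12(i):
  return i / (1 << 12)

def fp_mul(a, b):
  # multiplying two FP4.12 ints gives an FP8.24 result; shifting right
  # by 12 "slices" the middle bits out to make it FP4.12 again
  return (a * b) >> 12

a = to_fp4_12(0.5625)        # 0000.1001 0000 0000
b = to_fp4_12(2.0)           # 0010.0000 0000 0000
from_fp4_12(fp_mul(a, b))    # -> 1.125
</code></pre>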

<p>the purpose of the qkeras model then is to train in full float32 but provide the ability to
export the weights and biases in this configured fixed point configuration. </p>
<p>notes...</p>
<ul>
<li>using two's complement for negative numbers means it's not <em>quite</em> +/-8 but it's close :)</li>
<li>though the weights, biases and activations are in FP4.12 the multiplications and accumulations are
calculated with FP8.24 ( since multiplying two FP4.12 numbers produces an FP8.24 result ). so in all the following
we accumulate as FP8.24 and then "slice" the middle bits out to make it a FP4.12 number again.</li>
</ul>
<h4 id="po2">side note 2; power of 2 quantisation</h4>

<p>the options for quantisation can get pretty crazy too!</p>
<p>qkeras provides another scheme called power-of-two quantisation where all values are quantised to be <em>only</em> powers of two.
i.e. depending on the fixed point config, a weight/bias can only be one of [ +/- 1, 0, +/- 1/2, +/- 1/4, +/- 1/8, ...]</p>
<p>though this seems overly restrictive it has one HUGE important benefit... when a weight is a power of two then the "multiplication"
of a feature by that weight can be simply done with a bit shift operation. and bit shifting is VERY fast.</p>
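<p>a python sketch of the shift trick ( my own illustrative helpers, not qkeras code; assumes the weight really is +/- a power of two )...</p>
<pre class="prettyprint"><code class="language-python">import math

def po2_multiply(x_fp, w):
  # multiply a fixed point value by a power of two weight using only a shift
  e = round(math.log2(abs(w)))              # w assumed to be +/- a power of two
  shifted = x_fp << e if e >= 0 else x_fp >> -e
  return -shifted if w < 0 else shifted

x = int(0.75 * (1 << 12))                   # 0.75 in FP4.12
po2_multiply(x, 0.25) / (1 << 12)           # -> 0.1875
</code></pre>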
<p>and there are ways to "recover" representational power too, the best one i found being based around matrix factorisation.</p>
<p>if we have, say, a restricted weight matrix W of shape (8, 8) it can only contain those fixed values.
but if we instead represent W as <em>a product</em> of matrices, say, (8, 32) . (32, 8), then we can see that, even though all the
individual weights are restricted, the product of the matrices, an effective (8, 8) matrix, has many more possible values.</p>
<p>the pro is that the weights can take many more values, all the combos of w*w. the con though is we have to do two mat muls.
depending on how much space we have for allocating the shift operations compared to the number of multiply units we have, this tradeoff might be ok!</p>
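<p>a quick numpy sketch of why the factorisation helps ( illustrative only; random matrices standing in for trained weights )...</p>
<pre class="prettyprint"><code class="language-python">import numpy as np

rng = np.random.default_rng(0)

def snap_po2(W):
  # snap each weight to the nearest signed power of two
  return np.sign(W) * 2.0 ** np.round(np.log2(np.abs(W)))

# a po2 restricted (8, 8) matrix has only a handful of distinct entries...
W = snap_po2(rng.normal(size=(8, 8)))
# ...but the product of two po2 restricted factors has many more
U = snap_po2(rng.normal(size=(8, 32)))
V = snap_po2(rng.normal(size=(32, 8)))
effective = U @ V
print(len(np.unique(W)), len(np.unique(effective)))
</code></pre>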
<p>i messed around with this a lot and though it was interesting, and generally worked, it turned out that for the FPGA sizing
( i'm using at least ) the best result was to just use fixed point multiplication instead. :/ am guessing this is a fault on my part in terms
of poor verilog design and i still have some ideas to try at least...</p>
<h3>fxpmath version</h3>
<p>anyways, back to the models. the next model after the qkeras one is a
<a href="https://github.com/francof2a/fxpmath">fxpmath</a> one.</p>
<p>the fxpmath version connects the qkeras model's fixed point export with the caching approach and the inference logic that will go into the verilog design</p>
<p>the activation caching has two elements that are basically the same as the MCU version; the <code>left shift buffer</code>
for handling the first input and an <code>activation cache</code> for handling the activations between each convolutional
layer.</p>
<p>the <em>big</em> difference comes in with the implementation of the convolution which i had to roll from scratch :/
but at least it allows for some very optimised parallelisation.</p>
<h4>convolution 1d</h4>
<p>consider a 1D convolution with kernel size K=4 &amp; input / output feature depth of 16.</p>
<p>since we don't intend to stride this convolution at all ( that's handled implicitly by the activation caching ) we can treat this convolution as the following steps...</p>
<ul>
<li>4 instances of an input by weight matrix multiplication; i.e. 4x a (1, 16) . (16, 16) matrix multiply.</li>
<li>an accumulation of these K=4 results</li>
<li>the adding of a bias</li>
<li>applying an activation ( just relu for this project )</li>
</ul>
<p>each of the x4 matrix multiplications is actually just a row by matrix multiplication, so can be decomposed into j=16
independent dot products ( we'll call this a <code>row_by_matrix_multiply</code> from now on )</p>
<p>and each of <em>those</em> dot products can be decomposed into 16 independent multiplications followed by an accumulation.</p>
<p>the reason for being so explicit about what is independent, and what isn't, comes into play with the verilog version.</p>
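<p>in numpy terms the decomposition above looks something like this ( shapes only; random values standing in for the real cached activations and weights )...</p>
<pre class="prettyprint"><code class="language-python">import numpy as np

rng = np.random.default_rng(0)
K, D = 4, 16
xs = rng.normal(size=(K, D))      # the K cached input rows
Ws = rng.normal(size=(K, D, D))   # one (16, 16) weight matrix per kernel tap
bias = rng.normal(size=D)

# K independent row by matrix multiplies, accumulated, biased, then relu
acc = sum(xs[k] @ Ws[k] for k in range(K)) + bias
out = np.maximum(acc, 0.0)        # shape (16,)
</code></pre>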
<h3>verilog version</h3>
<p>verilog is a language used to "program" FPGAs and it's not at all like other languages; it's been really interesting to learn.</p>
<p>in the MCU version the two big concerns were</p>
<ol>
<li>is there enough RAM to hold the samples and network? ( was never a problem ) and</li>
<li>does the code run fast enough to process a sample before the next comes in?</li>
</ol>
<p>but the FPGA version is a little different. instead we have more flexibility to design things based on what we want to run in parallel vs what we want to
run sequentially. for neural networks this gives lots of options for design!</p>
<h3>side note 3; mats brutally short intro to verilog for matrix math</h3>
<p>at a 30,000ft view of verilog we have two main concerns; 1) executing code in parallel 2) executing blocks of code sequentially.</p>
<h4>introducing verilog with a dot product</h4>
<p>e.g. consider a dot product A.B with |A|=|B|=4</p>
<p>normally we'd think of this as simply something like <code>a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3]</code> provided by a single <code>np.dot</code> call.</p>
<p>but if we're writing verilog we have to think of things in terms of hardware, and that means considering parallel vs sequential.</p>
<p>the simplest sequential way would be like the following pseudo code... </p>
<pre class="prettyprint"><code class="language-verilog">// WARNING PSEUDO CODE!
case(state):
  multiply_0:
    // recall; the following three statements are run _in parallel_ 
    accum <= 0                    // set accumulator to 0
    product <= a[0] * b[0]        // set intermediate product variable to a0.b0
    state <= multiply_1           // set the next state
  multiply_1:
    accum <= accum + product      // update accumulator <b>with the product value from the last state</b>
    product <= a[1] * b[1]        // set product variable to a1.b1
    state <= multiply_2           // set the next state
  multiply_2:
    accum <= accum + product
    product <= a[2] * b[2]
    state <= multiply_3
  multiply_3:
    accum <= accum + product
    product <= a[3] * b[3]
    state <= final_accumulate
  final_accumulate:
    accum <= accum + product
    state <= done
  done:
    // final result available in `accum`
    state <= done
</code></pre>

<p>doing things this way does the dot product in 5 cycles... </p>
<h4>an important aside...</h4>
<p>an important aspect of how verilog works is to note that the statements in one of those case clauses <i>all run in parallel</i>. i.e. all right hand 
sides are evaluated and then assigned to the left hand side</p>
<p>e.g. the following would swap a and b</p>
<pre class="prettyprint">
a <= b;
b <= a;
</pre>

<p>and is functionally equivalent to ...</p>
<pre class="prettyprint">
b <= a;
a <= b;
</pre>

<p>so thinking in terms of sequential vs parallel we have the option to do 
more than one multiplication at any given time. this requires the hardware to be able to support 2 multiplications per 
clock cycle, but saves some clock cycles. not much in this case, but it adds up as |a| and |b| increase...</p>
<pre class="prettyprint"><code class="language-verilog">// WARNING PSEUDO CODE!
case(state):
  multiply_01:
    accum_0 <= 0
    accum_1 <= 0
    product_0 <= a[0] * b[0]
    product_1 <= a[1] * b[1]    
    state <= multiply_23
  multiply_23:
    accum_0 <= accum_0 + product_0
    accum_1 <= accum_1 + product_1  
    product_0 <= a[2] * b[2]
    product_1 <= a[3] * b[3]    
    state <= accumulate_0
  accumulate_0:
    accum_0 <= accum_0 + product_0
    accum_1 <= accum_1 + product_1    
    state <= accumulate_1
  accumulate_1:
    accum_0 <= accum_0 + accum_1
    state <= done
  done:
    // final result available in `accum_0`
    state <= done
</code></pre>

<p>or, if we want to go nuts, and, again, can support it, we can do all the elements multiplications at the same time, 
and then hierarchically accumulate the result into one. </p>
<pre class="prettyprint"><code class="language-verilog">// WARNING PSEUDO CODE!
case(state):
  multiply_all:
    // calculate all 4 elements of dot product in parallel; ( note: requires 4 available multiplication units )
    product_0 <= a[0] * b[0]
    product_1 <= a[1] * b[1]
    product_2 <= a[2] * b[2]
    product_3 <= a[3] * b[3]
    state <= accumulate_0
  accumulate_0:
    // add p0 and p1 at the same time as p2 and p3
    accum_0 <= product_0 + product_1
    accum_1 <= product_2 + product_3
    state <= accumulate_1
  accumulate_1:
    // final add of the two
    accum_0 <= accum_0 + accum_1
    state <= done
  done:
    // final result available in `accum_0`
    state <= done
</code></pre>

<p>this general idea of how much we do in a single clock cycle versus making values available in the next clock gives a lot
of flexibility for a design</p>
<h4>conv1d design in verilog</h4>
<p>specifically for this neural net i've represented the K=4 conv 1d as</p>
<ul>
<li>a <code>dot_product</code> module
<a href="https://github.com/matpalm/cached_dilated_causal_convolutions/blob/master/sverilog_version/src/dot_product.sv">code</a>
where the elements of dot products are calculated sequentially; i.e. <code>x[0]*w[0]</code> in the first step, <code>x[1]*w[1]</code> in the second, etc. 
so a dot product of N values takes N*M clock cycles ( where M is the number of cycles the multiply unit takes ) + some parallel accumulation.</li>
<li>a <code>row_by_matrix_multiply</code> module
<a href="https://github.com/matpalm/cached_dilated_causal_convolutions/blob/master/sverilog_version/src/row_by_matrix_multiply.sv">code</a>
which runs the j=16 dot products required for each <code>row_by_matrix_multiply</code> in parallel</li>
<li>and a <code>conv1d</code> module
<a href="https://github.com/matpalm/cached_dilated_causal_convolutions/blob/master/sverilog_version/src/conv1d.sv">code</a>
that runs the K=4 <code>row_by_matrix_multiply</code>s in parallel as well, and handles the state machine for accumulating results with a bias and applying relu. </li>
</ul>
<p>so we end up having all j=16 * K=4 = 64 dot products run in parallel, all together.</p>
<p>having said this, there are a number of ways to restructure this; e.g. if there were too many dot products to run 
in parallel for the 4 <code>row_by_matrix_multiply</code> we could run 2 of them in parallel, and then when they were finished, 
run the other 2. there are loads of trade offs between the number of multiply units available vs the time required to run them.</p>
<h2>a slightly modified data set</h2>
<p>in the MCU version i was only feeding in samples based on the embedding corner points, one of 4 types sampled randomly from...</p>
<table class='data'>
<tr><th>input ( core wave, e0, e1 ) </th><th>output</th></tr>
<tr><td>(triangle, 0, 0)</td><td>sine</td></tr>
<tr><td>(triangle, 0, 1)</td><td>ramp</td></tr>
<tr><td>(triangle, 1, 1)</td><td>zigzag</td></tr>
<tr><td>(triangle, 1, 0)</td><td>square</td></tr>
</table>

<p>for the first FPGA version i did this but the model was large enough that it was quickly overfitting this and basically outputting 
noise for the intermediate points. as such for training of this model i changed things a bit to include interpolated data. </p>
<p>basically we emit corner points, say (e0=0, e1=1, sine wave) as well as interpolated points, where we pick two waves, say sine and ramp, and a 
random point between them and train for that point as an interpolated wave between the two ( using constant power interpolation ) </p>
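<p>constant power interpolation is just a cos/sin crossfade; a minimal sketch ( the wave values below are made-up placeholders, not the logged eurorack data )</p>

```python
import math

def constant_power_interpolate(wave_a, wave_b, t):
    # crossfade between two waves with constant power; t=0 gives
    # wave_a, t=1 gives wave_b, and the two gains always satisfy
    # gain_a**2 + gain_b**2 == 1
    gain_a = math.cos(t * math.pi / 2)
    gain_b = math.sin(t * math.pi / 2)
    return [gain_a * a + gain_b * b for a, b in zip(wave_a, wave_b)]

sine_ish = [0.0, 0.5, 1.0]   # placeholder samples
ramp_ish = [0.0, 0.25, 0.5]
midpoint = constant_power_interpolate(sine_ish, ramp_ish, 0.5)
```

<p>compared to a plain linear crossfade this avoids the perceived dip in level halfway between the two waves.</p>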
<p>though i couldn't get the MCU model to converge well with this kind of data, the larger FPGA variant has no problems.</p>
<p>i also messed around with a 3d input embedding, to translate between any pairing, but it didn't really add anything so i stuck with 2d.</p>
<h2>final model</h2>
<p>where as the final MCU model was ...</p>
<pre class="prettyprint">
---------------------------------------------------------------------------
 Layer (type)         Output Shape      Par#  Conv1D params
---------------------------------------------------------------------------
 input (InputLayer)   [(None, 256, 3)]  0
 c0a (Conv1D)         (None, 64, 4)     52    F=4, K=4, D=1, P=causal
 c0b (Conv1D)         (None, 64, 4)     20    F=4, K=1
 c1a (Conv1D)         (None, 16, 4)     68    F=4, K=4, D=4, P=causal
 c1b (Conv1D)         (None, 16, 4)     20    F=4, K=1
 c2a (Conv1D)         (None, 4, 4)      68    F=4, K=4, D=16, P=causal
 c2b (Conv1D)         (None, 4, 4)      20    F=4, K=1
 c3a (Conv1D)         (None, 1, 8)      136   F=8, K=4, D=64, P=causal
 c3b (Conv1D)         (None, 1, 8)      72    F=8, K=1
 y_pred (Conv1D)      (None, 1, 1)      13    F=1, K=1
---------------------------------------------------------------------------
Trainable params: 465
---------------------------------------------------------------------------
</pre>

<p>... the current FPGA version i'm running is ...</p>
<pre class="prettyprint">
-----------------------------------------------------------------
 Layer (type)                Output Shape             Param #   
-----------------------------------------------------------------
 input_1 (InputLayer)        [(None, 64, 4)]          0                                                                         
 qconv_0 (QConv1D)           (None, 16, 16)           272                                                                        
 qrelu_0 (QActivation)       (None, 16, 16)           0                                                                          
 qconv_1 (QConv1D)           (None, 4, 16)            1040                                                                       
 qrelu_1 (QActivation)       (None, 4, 16)            0                                                                          
 qconv_2 (QConv1D)           (None, 1, 4)             260                                                                        
-----------------------------------------------------------------
Trainable params: 1,572
-----------------------------------------------------------------
</pre>

<p>it has the following differences</p>
<ul>
<li>have lowered the depth for the model from 4 to 3; a receptive field of 64 is enough to handle this waveshaping</li>
<li>have dropped the secondary 1x1 conv and am just running 1 conv1d per layer; primarily because the code is just simpler this way</li>
<li>all the internal filter sizes have been increased to 16d. </li>
</ul>
<p>so compared to the MCU version</p>
<ul>
<li>it has x3 the params</li>
<li>it's running at 192kHz instead of 48kHz ( so x4 faster )</li>
<li>and is running at a utilisation of 30% instead of 88% </li>
</ul>
<p>to be honest utilisation is a bit harder to compare; the trade off between compute and space is quite different with an FPGA design.</p>
<p>for each sample coming in at 192kHz the FPGA is running a simple state machine of 1) accept next sample 2) run the sequence of qconvs and 
activation caches, then 3) output the result and sit in a while-true loop until the next sample. when i say above the FPGA is running at 30% what
i really should say is that it's spending 70% of the time in the post sample processing while-true loop waiting for the next sample.</p>
<p>looking at the device utilisation we have the following..</p>
<pre class="prettyprint">
Info: Device utilisation:
Info:             TRELLIS_IO:    11/  365     3%
Info:                   DCCA:     5/   56     8%
Info:                 DP16KD:    24/  208    11%
Info:             MULT18X18D:   134/  156    85%
Info:                EHXPLLL:     1/    4    25%
Info:             TRELLIS_FF: 21081/83640    25%
Info:           TRELLIS_COMB: 44647/83640    53%
Info:           TRELLIS_RAMW:   192/10455     1%
</pre>

<p>the pieces of interest are...</p>
<ul>
<li><b><code>DP16KD</code></b>: which is the amount of ( one type of ) RAM being used; this looks to be dominated by the activation cache, 
so with only 11% being used there is a lot of room for having more layers.</li>
<li><b><code>MULT18X18D</code></b>: is the big one; it's the max number of multiplication DSP units being used at any one time. in this model that's dominated by <code>qconv_1</code> 
with in_dim = out_dim = 16. since it's already at 85%, if we wanted to increase the filter size much more we might be forced to <em>not</em> 
run all 16 dot products of the 16x16 <code>row_by_matrix_multiply</code> at once but instead, say, do 8 in parallel, then the other 8. 
this would incur a latency hit, but that's totally fine given we still have a lot of clock time available to do work between samples. 
the trouble is the code as written would end up being tricky to refactor. </li>
</ul>
<p>currently things are setup so that the <em>entire</em> network has to run before the next sample comes in. this is just because it was
the simplest thing to do while i'm learning verilog and it seems like the FPGA is fast enough for it. but with a neural net it 
doesn't have to be like that; we really just need to finish the <em>first</em> layer before next sample comes, not the whole network.
as long as we don't mind a little bit of output latency we can run a layer per sample clock tick. doing this would
delay the output by number-of-layers sample clock ticks, but at 192kHz that'd be fine :)</p>
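<p>the layer-per-sample-tick idea can be sketched as a tiny software pipeline; the "layers" here are trivial stand-ins for the real qconv modules, just to show the staggering and the pipeline-fill delay.</p>

```python
# sketch of layer-per-tick pipelining; each tick runs every layer once,
# but each on progressively older data, so the output for a given sample
# appears len(layers)-1 ticks after it goes in.

def make_pipeline(layers):
    # one slot of in-flight state per layer
    stages = [None] * len(layers)

    def tick(sample):
        nonlocal stages
        # shift in-flight values one stage down the pipe, applying the
        # next layer to each, then admit the new sample at the front
        stages = [layer(value) if value is not None else None
                  for layer, value in zip(layers, [sample] + stages[:-1])]
        return stages[-1]   # None until the pipe fills

    return tick

tick = make_pipeline([lambda x: x + 1, lambda x: x * 2, lambda x: x - 3])
results = [tick(s) for s in [10, 20, 30, 40, 50]]
# the first len(layers)-1 ticks emit None while the pipeline fills,
# then each tick emits the fully processed sample from 2 ticks ago
```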
<p>another way to run a bigger network is to continue using the same naive MULT18X18D dsp allocation but just use an intermediate layer 
twice; e.g. if you have a network input -&gt; conv0 -&gt; output you can get extra depth by running input -&gt; conv0 -&gt; conv0 -&gt; output instead.
you lose a bit of representation power, since the same layer needs to model two layers, but sometimes it's worth it. in this model we'd get extra
depth without having to worry about more allocation, and we've got plenty of headroom to do more compute.</p>
<h3>the po2 work in progress model</h3>
<p>the work in progress model i've been tinkering with for the power of two quantisation is the following...</p>
<pre class="prettyprint">
_________________________________________________________________
 Layer (type)                Output Shape              Param #  
_________________________________________________________________
 input_1 (InputLayer)        [(None, 640, 4)]          0                                                                          
 qconv_0_qb (QConv1D)        (None, 640, 8)            136       
 qrelu_0 (QActivation)       (None, 640, 8)            0         
 qconv_1_qb (QConv1D)        (None, 640, 8)            264       
 qrelu_1 (QActivation)       (None, 640, 8)            0         
 qconv_1_1a_po2 (QConv1D)    (None, 640, 16)           144       
 qconv_1_1b_po2 (QConv1D)    (None, 640, 8)            136       
 qrelu_1_1 (QActivation)     (None, 640, 8)            0         
 qconv_1_2a_po2 (QConv1D)    (None, 640, 16)           144       
 qconv_1_2b_po2 (QConv1D)    (None, 640, 8)            136       
 qrelu_1_2 (QActivation)     (None, 640, 8)            0         
 qconv_2_qb (QConv1D)        (None, 640, 4)            132       
_________________________________________________________________
Trainable params: 1,092
_________________________________________________________________
</pre>

<ul>
<li>the layers postfixed with <code>_qb</code> are the normal quantised bits layers which are fixed point weights and use <code>MULT18X18D</code> units for inference.</li>
<li>the layers postfixed with <code>_po2</code> have power-of-two weights and use just shift operators for inference.</li>
</ul>
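<p>to see why po2 weights need no multipliers: multiplying a fixed point value by &plusmn;2^-k is just an arithmetic shift plus an optional negate. a sketch ( the sign/shift weight encoding here is an assumption for illustration, not the exact qkeras format )</p>

```python
# multiply by a power-of-two weight using only a shift and a negate;
# no DSP multiply unit required.

def po2_multiply(x_fixed, sign, shift):
    # weight = sign * 2**-shift ; implemented as a right shift
    result = x_fixed >> shift
    return -result if sign < 0 else result

# the value 12.0 with 4 fractional bits is 12 << 4 = 192
x = 192
assert po2_multiply(x, +1, 1) == 96    # * 0.5
assert po2_multiply(x, +1, 2) == 48    # * 0.25
assert po2_multiply(x, -1, 3) == -24   # * -0.125
```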
<p>the output quality is the same and it uses fewer <b><code>MULT18X18D</code></b> units but doesn't quite fit :)</p>
<pre class="prettyprint">
Info: Device utilisation:
Info:             TRELLIS_IO:    11/  365     3%
Info:                   DCCA:     5/   56     8%
Info:                 DP16KD:    12/  208     5%
Info:             MULT18X18D:    70/  156    44%
Info:                EHXPLLL:     1/    4    25%
Info:             TRELLIS_FF: 37356/83640    44%
Info:           TRELLIS_COMB: 85053/83640   101%   close!!
Info:           TRELLIS_RAMW:    96/10455     0%
</pre>

<p>i spent a bunch of time trying various combos of module reuse but never had a design that would meet timing on the FPGA :/</p>
<p>i still feel i'm doing something wrong here, and might come back to it. </p>
<h2>waveshaping wave examples</h2>
<table class='data'>
<tr><td><img src="/blog/imgs/2023/wv_fpga/interp_examples.jpg"/></td></tr>
<tr><td>waveforms generated by the model across the embedding space</td></tr>
</table>

<p>the above images show the range of waveforms generated by the model across the two dimensional embedding space. of note...</p>
<ul>
<li>the corners represent the original logged waveshapes of ( from top left clockwise ) sine, zigzag, square and ramp. 50% of the training examples are of this type.</li>
<li>the outer edges represent the examples given to the model during training, interpolated ( by constant power cross fade ) from pairs of corners. these represent the other 50% of training examples.</li>
<li>the inner examples are all interpolations created by the model. note that the closer you get to the middle the further you are from a training example and the noisier the waves get...</li>
</ul>
<h2>waveshaping audio examples</h2>
<p>lets look at some examples!</p>
<ul>
<li>green trace; core triangle wave being wave shaped</li>
<li>blue trace; embedding x value, range -1 to 1</li>
<li>red trace; embedding y value; range -1 to 1</li>
<li>yellow; neural net output ( and the audio we hear )</li>
</ul>
<p>( make sure subtitles are enabled! )</p>
<h3>corners of the embedding space</h3>
<p>an example of a triangle core wave as input and a manual transition of the embedding values between the corners of the 2d space.</p>
<iframe width="560" height="315" src="https://www.youtube.com/embed/J2921Cir6xw?si=208sna6kkd11wuKt" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>

<h3>modulating an embedding point at audio rates</h3>
<p>modulating the embedding x value at audio rates makes for some great timbres! the FPGA and the eurorack pmod have no problems handling this.</p>
<iframe width="560" height="315" src="https://www.youtube.com/embed/-14O8rNdFyQ?si=s4oQkxNpCIsavVRf" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>

<h3>discontinuity glitches</h3>
<p>since the model was trained <em>only</em> on a triangle wave if you give it something discontinuous, like a ramp or square, it glitches! :)</p>
<iframe width="560" height="315" src="https://www.youtube.com/embed/zpNIqOuacqE?si=q6WKjLVB2AXcXI-l" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>

<h3>neural net techno</h3>
<p>and if you sequence things, add an envelope to the embedding point &amp; stick in some 909 kick &amp; hats.... what do you have? 
neural net techno! doff. doff. some delay on the hats and clap, but no effects on the oscillator. 
( mixed the hats too loud as well, story of my life )</p>
<iframe width="560" height="315" src="https://www.youtube.com/embed/GAH_acoRsTQ?si=oZvgJGGPv4PxlAmA" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>

<h3>TODOs</h3>
<p>there is a lot of room to optimise for a larger network. 
see the <a href="https://github.com/matpalm/cached_dilated_causal_convolutions/issues">issues on github</a></p>
<h3>code</h3>
<p>the code for training and verilog simulation is in this 
<a href="https://github.com/matpalm/cached_dilated_causal_convolutions">github repo</a></p>
<p>whereas the code representing this as core on the eurorack pmod is in this 
<a href="https://github.com/matpalm/eurorack-pmod/tree/network">github repo on a branch</a></p>]]></content:encoded>
    </item>
    <item>
      <title>a wavenet neural net running on a microcontroller at (almost) 50,000 inferences / sec</title>
      <link>http://matpalm.com/blog/wavenet_on_mcu</link>
      <category><![CDATA[mcu]]></category>
      <category><![CDATA[wavenet]]></category>
      <category><![CDATA[eurorack]]></category>
      <guid>http://matpalm.com/blog/wavenet_on_mcu</guid>
      <description>a wavenet neural net running on a microcontroller at (almost) 50,000 inferences / sec</description>
      <content:encoded><![CDATA[<h2>neural nets on the daisy patch?</h2>
<p>the <a href="https://www.electro-smith.com/daisy/patch">electro smith daisy patch</a>
is a eurorack module made for prototyping; it provides a powerful audio focussed
microcontroller setup (the daisy 'seed' with arm cortex-m7 running at 480MHz) along with all the
connectivity required to be a eurorack module.</p>
<p>pretty much the first thing i thought when i saw one was; "could it run a neural net?"
microcontrollers are fast, but so too is audio rate.</p>
<p>note: see also the updated version of this project <a href="https://matpalm.com/blog/wavenet_on_fpga/">running on an FPGA</a> though it's worth reading this first to understand the model and waveshaping task</p>
<p><img src="/blog/imgs/2023/wv_mcu/daisy.jpg"/></p>
<p>after a bit of research i found there are a couple of people who have done some real time audio effects
processing on the daisy seed; e.g. <a href="https://github.com/GuitarML/NeuralSeed">guitarML's neural seed</a></p>
<p>all the examples i found though were quite small recurrent models and i was keen to give
<a href="https://www.deepmind.com/blog/wavenet-a-generative-model-for-raw-audio">wavenet</a>
a go instead ( i'm a big fan of causal dilated
convolutions )</p>
<h2>wavenet 101</h2>
<p>the wavenet architecture is a 1D convolutional network designed to operate on timeseries data. it is composed of two key structures;</p>
<ul>
<li>it uses causal padding to ensure that each convolution never depends on anything from the future and</li>
<li>it increases dilation exponentially each layer to give the following structure</li>
</ul>
<p><img src="/blog/imgs/2023/wv_mcu/net1.png"/></p>
<p>usually it's just dilated convolutions stacked, but for my variant i also include one extra
1x1 conv between each dilated conv; a 1x1 doesn't break any of the dilation structure
wavenet needs, and including another non linearity is almost always a good thing.</p>
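<p>the causal + dilated structure is small enough to sketch in pure python; a single channel toy version, with the ( assumed ) convention that kernel tap i looks back i*dilation steps.</p>

```python
# a pure python sketch of a causal dilated 1d convolution on a single
# channel; each output depends only on the current and K-1 *past*
# samples, spaced `dilation` apart, so nothing leaks from the future.

def causal_dilated_conv1d(x, kernel, dilation):
    out = []
    for t in range(len(x)):
        acc = 0.0
        for i, w in enumerate(kernel):
            # tap i looks back i*dilation steps; causal padding means
            # taps before the start of the signal contribute zero
            tap = t - i * dilation
            if tap >= 0:
                acc += w * x[tap]
        out.append(acc)
    return out

# with an identity-ish kernel the input passes straight through...
assert causal_dilated_conv1d([1.0, 2.0, 3.0, 4.0], [1.0, 0.0], dilation=2) == [1.0, 2.0, 3.0, 4.0]
# ...and a pure "look back 2" kernel just delays the signal
assert causal_dilated_conv1d([1.0, 2.0, 3.0, 4.0], [0.0, 1.0], dilation=2) == [0.0, 0.0, 1.0, 2.0]
```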
<h2>training</h2>
<p>the first step was to write a firmware that just recorded some audio and/or control voltages into buffers and then
streamed them out over a serial connection. it's kinda clumsy, and i'm sure there's a cleaner way, but it works without
having to mess around too much. the code is
<a href="https://github.com/matpalm/cached_dilated_causal_convolutions/tree/master/datalogger_firmware">datalogger_firmware/</a></p>
<p>from there a wavenet could be trained
( see <a href="https://github.com/matpalm/cached_dilated_causal_convolutions/blob/master/cmsisdsp_py_version/keras_model.py">keras_model.py</a> )
and it was trivial to quantise and export to a c++ lib for the microcontroller using
<a href="https://edgeimpulse.com">edge impulse's</a>
<a href="https://docs.edgeimpulse.com/docs/edge-impulse-studio/bring-your-own-model-byom">bring-your-own-model</a></p>
<p>from there i got a bit stuck though integrating tensorflow with the daisy
optimised <code>arm_math.h</code> stuff ( Make files and such aren't my expertise ) so instead i
thought i'd use this chance to write my own inference. convolutions don't need much,
it's all just a bunch of matrix math after all.</p>
<p>while poking around doing this it suddenly occurred to me that
you could do heavy heavy caching of the convolution activations!
i was super excited, thinking i was doing something novel, and it wasn't until i was actually finished that i found out
someone else had basically thought of the same idea, what they call <a href="https://github.com/tomlepaine/fast-wavenet">fast-wavenet</a> :/</p>
<p>oh well, it was fun to discover independently i suppose, and i ended up implementing things a bit differently.</p>
<h2>caching activations for inference</h2>
<p>so let's walk through the caching optimisation...</p>
<p>consider the following wavenet like network; it has 16 inputs being integrated thru 3 layers to a single prediction</p>
<p><img src="/blog/imgs/2023/wv_mcu/net1.png"/></p>
<p>if we consider a sliding window input at time step 4 we see that the output of node 0 ( the node with the red 0 )
will be the processed values from [0, 1, 2, 3]</p>
<p><img src="/blog/imgs/2023/wv_mcu/net2.png"/></p>
<p>a bit later on an interesting thing happens at time step 8; node 1 now gets the same inputs that node 0 had 4 steps ago.</p>
<p><img src="/blog/imgs/2023/wv_mcu/net3.png"/></p>
<p>so we don't need to calculate it! when processing a timeseries we can see that the node 1 output just
lags node 0 by 4 steps, and nodes 2 and 3 lag another 4 each. as long as we've got the memory for caching
we can store all these in a circular buffer. </p>
<p><img src="/blog/imgs/2023/wv_mcu/net4.png"/></p>
<p>and this whole thing is stackable! in fact we only ever need to run the right hand side convolutions,
as long as we have the memory to cache.</p>
<p><img src="/blog/imgs/2023/wv_mcu/net5.png"/></p>
<p>one big win for this on the daisy is that it can all run fast and small enough that we can
use full float32 math everywhere. hoorah!</p>
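<p>the caching scheme above can be sketched as a small circular buffer; this is a toy python version of the idea, not a port of the real <code>rolling_cache.h</code>.</p>

```python
# because a node's output just lags its sibling by `dilation` steps, we
# only ever compute the right-most node and read the lagged siblings
# out of a circular buffer of recent activations.

class RollingCache:
    def __init__(self, dilation, kernel_size):
        # enough history to serve taps up to (kernel_size-1)*dilation old
        self.size = dilation * (kernel_size - 1) + 1
        self.buffer = [0.0] * self.size
        self.head = 0
        self.dilation = dilation

    def add(self, value):
        # overwrite the oldest slot with the newest activation
        self.head = (self.head + 1) % self.size
        self.buffer[self.head] = value

    def taps(self, kernel_size):
        # most recent value plus the lagged values, oldest first
        return [self.buffer[(self.head - i * self.dilation) % self.size]
                for i in reversed(range(kernel_size))]

cache = RollingCache(dilation=4, kernel_size=2)
for v in range(10):          # feed activations 0..9
    cache.add(float(v))
# the conv now sees the newest value and the one from 4 steps back
assert cache.taps(2) == [5.0, 9.0]
```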
<p>in terms of coding the whole thing becomes an exercise in</p>
<ol>
<li>preparing the data in the right way to pass to some crazy optimised linear algebra math and</li>
<li>doing so with as little memory copying as possible :/</li>
</ol>
<h2>representative architectures</h2>
<p>the following model has an input of 3 values =&gt; a receptive field of 256 steps, and a single output.
it runs on the daisy using the above caching technique at 48KHz @88% CPU ( ~20 micro seconds per inference ). 
we'll describe what the 3 input values are in a bit. it's the model we'll use for the examples at the end.</p>
<pre>
cNa - denotes the dilated convolutions
cNb - denotes the 1x1s convolutions that follow cNa
F - number of filters
K - kernel size; either 4 for cNa or 1 for cNb
D - dilation; K^layer# for cNa, or 1 for cNb
P - padding; 'causal' or the layer default 'valid'
</pre>

<pre>
____________________________________________________________________________________________
 Layer (type)         Output Shape      Par#  Conv1D params   
============================================================================================
 input (InputLayer)   [(None, 256, 3)]  0                                   
 c0a (Conv1D)         (None, 64, 4)     52    F=4, K=4, D=1, P=causal
 c0b (Conv1D)         (None, 64, 4)     20    F=4, K=1
 c1a (Conv1D)         (None, 16, 4)     68    F=4, K=4, D=4, P=causal
 c1b (Conv1D)         (None, 16, 4)     20    F=4, K=1
 c2a (Conv1D)         (None, 4, 4)      68    F=4, K=4, D=16, P=causal
 c2b (Conv1D)         (None, 4, 4)      20    F=4, K=1
 c3a (Conv1D)         (None, 1, 8)      136   F=8, K=4, D=64, P=causal        
 c3b (Conv1D)         (None, 1, 8)      72    F=8, K=1
 y_pred (Conv1D)      (None, 1, 1)      13    F=1, K=1
============================================================================================
Total params: 465
____________________________________________________________________________________________
</pre>

<p>as a side note, the daisy can actually run at 96kHz but at this rate i could only run a
smaller [c0a, c0b, c1a, c2b] model with just 2 filters each. it runs, but i couldn't get it
to train on the examples i show below, so i didn't use it. shame because naming the blog post
"at (almost) 100,000 inferences a second" would have had a nicer ring to it :D</p>
<p>it can also run slower, e.g. 32kHz, which allows either more filters per step, or even more
depth =&gt; a larger receptive field.</p>
<p>but these combos demonstrate an interesting set of trade offs we have between</p>
<ul>
<li>depth ( which dictates the receptive field / input size )</li>
<li>number of filters we can manage and</li>
<li>what audio rate we can run at.</li>
</ul>
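<p>the depth / receptive field relationship is simple to state: with kernel size K and dilation K^layer, the receptive field after n dilated layers is K^n ( the 1x1 convs don't change it ).</p>

```python
# receptive field of a stack of dilated convs with kernel size K and
# dilation K**layer; equivalently 1 + (K-1) * (1 + K + K**2 + ...)

def receptive_field(kernel_size, n_dilated_layers):
    return kernel_size ** n_dilated_layers

# the K=4, 4-dilated-layer model above sees 256 input steps
assert receptive_field(4, 4) == 256
```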
<p>for training i just used <code>keras.layers.Conv1D</code> everywhere and then exported to the device in two passes; a python
prototype and then the c++ code for the device.</p>
<h3>cmsis-dsp</h3>
<p>the final code on the daisy uses the
<a href="https://www.keil.com/pack/doc/CMSIS/DSP/html/index.html">cmsis-dsp</a> library for all the matrix math.
only three pieces end up being used though</p>
<ul>
<li><code>arm_mat_init_f32</code> for making the matrix structures,</li>
<li><code>arm_mat_mult_f32</code> for the actual matrix multiplications and</li>
<li><code>arm_add_f32</code> for the kernel accumulation and adding biases here and there.</li>
</ul>
<p>to be honest none of these were benchmarked and i <i>assume</i> they are faster than if the multiplications
were just rolled out (???)</p>
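<p>for the curious, the "bunch of matrix math" per step looks something like this numpy sketch; one matmul per kernel tap plus an accumulate, which is roughly what the <code>arm_mat_mult_f32</code> / <code>arm_add_f32</code> calls are doing. the shapes and relu placement here are illustrative, not the exact firmware.</p>

```python
# one conv step expressed as K matrix multiplies plus an accumulate.
import numpy as np

def conv_step(taps, kernels, bias):
    # taps: K vectors of in_d activations ( from the rolling caches )
    # kernels: K matrices of shape (in_d, out_d); bias: (out_d,)
    accum = bias.copy()
    for x, w in zip(taps, kernels):
        accum += x @ w            # one matmul + one add per kernel tap
    return np.maximum(accum, 0.0)  # relu

rng = np.random.default_rng(0)
taps = [rng.normal(size=3) for _ in range(4)]
kernels = [rng.normal(size=(3, 2)) for _ in range(4)]
out = conv_step(taps, kernels, np.zeros(2))
assert out.shape == (2,)
```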
<h3>python cmsis-dsp prototype</h3>
<p>the first pass was to get the inference working as a prototype in python.
the <a href="https://pypi.org/project/cmsisdsp/">cmsisdsp</a> python lib was used ( a pure python api equivalent ) and the code is
<a href="https://github.com/matpalm/cached_dilated_causal_convolutions/tree/master/cmsisdsp_py_version">cmsisdsp_py_version/</a>.</p>
<p>having <code>cmsisdsp</code> was very handy to prototype, especially as i've not done any <code>cmsis</code> stuff before. </p>
<h3>cmsis on the daisy</h3>
<p>the code for the daisy is under
<a href="https://github.com/matpalm/cached_dilated_causal_convolutions/tree/master/inference_firmware">inference_firmware/</a>
and is a port of the python version. it's where i probably spent the most amount of time, especially <code>block.h</code></p>
<p>it has 4 main parts</p>
<ul>
<li><code>left_shift_buffer.h</code> to handle the shifted <code>InputLayer</code></li>
<li><code>block.h</code> to handle the sequential running of each <code>cNa</code> and <code>cNb</code> pair</li>
<li><code>rolling_cache.h</code> which represents the circular buffer used for activation lag and</li>
<li><code>regression.h</code> which is just a simple <code>y=mx+b</code> using the weights from the final 1x1 Conv1D regression used during training</li>
</ul>
<p>it all ends up being pretty straightforward but i did iterate a bit;
any code involving memcpy and pointer arithmetic needs careful attention.</p>
<p>the "best" bit of the code is where the initial keras model is trained in a python notebook and
then exported to be used in the c++ code by having python code construct a <code>model_defn.h</code> with
a bunch of print statements :/ it's actually written to <code>/tmp/model_defn.h</code> no less! what
a hack, lol. may god have mercy on my soul.</p>
<h2>an example waveshaper</h2>
<p>for a demo project i wanted to make a little waveshaper; something that takes one waveform and outputs another.</p>
<p>so 5 waveforms were collected from another eurorack oscillator; a triangle, sine, saw, square and a weird zigzag thing.
during data collection random voltages were sampled to change the oscillator's frequency. </p>
<p>to convert this to an actual training dataset the triangle wave is used as input, with one of the other waves acting as the output.
which output is decided by including two additional selector variables in the inputs.
these variables are only ever {0, 1} during training but can act (very loosely) as an embedding, since any float
value in the range (0, 1) can be passed when running on the device.</p>
<table class='data'>
<tr><th>input</th><th>output</th></tr>
<tr><td>(triangle, 0, 0)</td><td>sine</td></tr>
<tr><td>(triangle, 0, 1)</td><td>ramp</td></tr>
<tr><td>(triangle, 1, 0)</td><td>square</td></tr>
<tr><td>(triangle, 1, 1)</td><td>zigzag</td></tr>
</table>
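<p>the mapping above can be sketched as a little dataset builder; the wave samples here are made-up stand-ins for the logged recordings.</p>

```python
# build (input, target) pairs where the input per timestep is
# ( triangle sample, x2, x3 ) and the selector bits pick the target wave.

SELECTOR_TO_WAVE = {(0, 0): 'sine', (0, 1): 'ramp',
                    (1, 0): 'square', (1, 1): 'zigzag'}

def make_examples(triangle, waves, x2, x3):
    target = waves[SELECTOR_TO_WAVE[(x2, x3)]]
    # the selector values are constant across the window
    return [((t, x2, x3), y) for t, y in zip(triangle, target)]

waves = {'sine': [0.0, 0.7], 'ramp': [-1.0, -0.5],
         'square': [1.0, 1.0], 'zigzag': [0.3, -0.3]}
examples = make_examples([0.0, 0.5], waves, 0, 1)
assert examples[0] == ((0.0, 0, 1), -1.0)   # selectors (0,1) => ramp
```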

<p>the model was trained for &lt; 1 min ( the model is pretty small, and there's not much variety in the data ).</p>
<p>it does well on held out test data...</p>
<table class='data'>
<tr><td><img src="/blog/imgs/2023/wv_mcu/test_sine.png"/></td></tr>
<tr><td><img src="/blog/imgs/2023/wv_mcu/test_ramp.png"/></td></tr>
<tr><td><img src="/blog/imgs/2023/wv_mcu/test_square.png"/></td></tr>
<tr><td><img src="/blog/imgs/2023/wv_mcu/test_zigzag.png"/></td></tr>
</table>

<p>but the much more interesting thing is what happens when we input values for x2 and x3 that weren't seen during training</p>
<p>e.g. what if we choose a point that is 20% between sine (0, 0) and square (1, 0) ? we end up with some weird undefined
part of the input space.</p>
<table class='data'>
<tr><td><img src="/blog/imgs/2023/wv_mcu/test_inbetween.png"/></td></tr>
</table>

<p>these out of training distribution things are always the most interesting to me; values for x2 and x3
inbetween (0, 1) give a weird sort-of-interpolation. generally these results are never a smooth transition unless
there's some aspect of the loss that directs it to be so (and in this case there isn't). </p>
<p>we get the classic model hallucination stuff that i've always loved. 
( see my older post on <a href="https://matpalm.com/blog/2015/03/15/hallucinating_softmaxs/">hallucinating softmaxs</a> for more info )</p>
<p>we <em>could</em> encourage the model to make full use of the space by including a GAN like discriminator loss. this would be trained on
random samples of values for (x2, x3). i've seen this kind of training force the model to be much
more consistent for interpolated inputs.</p>
<p>this video of the network actually running gives a better idea of the "interpolation".
the green wave shows the input core triangle wave, the blue wave shows the output waveshaped result.
the daisy module display shows the values for x2 and x3; with these values controlled by hand from the module on the right.</p>
<iframe width="560" height="315" src="https://www.youtube.com/embed/W2E7iZeZDlM?si=jSJ4IeVOxpv4lFW1" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>

<p>the video transitions across values of (x2, x3), moving between the corner (0, 1) values. at each corner the frequency of
the core triangle input wave was adjusted over a range as well. something i hadn't considered is that when we change
the frequency to be <i>outside</i> the range of what was seen during training the model does some weird extrapolation.</p>
<p>the interpolation extrapolation stuff makes some weird things, but it all sounds cool :)</p>
<h2>appendix</h2>
<p>further ideas could be...</p>
<ul>
<li>use mulaw encoding? the original deepmind wavenet paper talks about the model output being a distribution across mulaw encoded values. that's interesting; i've used mulaw for compression in loopers and delays, but never as a categorical thing to classify!</li>
<li>there's other things we can add in the block; e.g. is this model too small to use a skip connection around the 1x1 block?</li>
<li>further study on relationship of audio rate vs dilation/receptive field</li>
<li>GAN discriminator across randomly sampled (x2, x3)</li>
<li>this is all running in float32 math but there's a number of quantised operators available too. how much bigger/faster could the network be if we swap to quantised calculations?</li>
</ul>
<p>note: see also the updated version of this project <a href="https://matpalm.com/blog/wavenet_on_fpga/">running on an FPGA</a></p>
<p>code on <a href="https://github.com/matpalm/cached_dilated_causal_convolutions">github</a></p>]]></content:encoded>
    </item>
    <item>
      <title>high performance ML with JAX</title>
      <link>http://matpalm.com/blog/pycon_jax_talk</link>
      <category><![CDATA[jax]]></category>
      <category><![CDATA[talk]]></category>
      <guid>http://matpalm.com/blog/pycon_jax_talk</guid>
      <description>high performance ML with JAX</description>
      <content:encoded><![CDATA[<p>last friday i did a talk at <a href="https://2021.pycon.org.au/">pycon</a> on jax</p>
<p>here's a recording; check it out!</p>
<iframe width="560" height="315" src="https://www.youtube.com/embed/cqbBjM4_yGw" frameborder="0" allow="accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>

<p>and here's <a href="https://drive.google.com/file/d/1_5nYIvXm_TWGYwmCwcDQtn1YI62pxz3P/view?usp=sharing">a pdf of the slides</a></p>]]></content:encoded>
    </item>
    <item>
      <title>evolved channel selection</title>
      <link>http://matpalm.com/blog/evolved_channel_selection</link>
      <category><![CDATA[projects]]></category>
      <category><![CDATA[ga]]></category>
      <category><![CDATA[jax]]></category>
      <guid>http://matpalm.com/blog/evolved_channel_selection</guid>
      <description>evolved channel selection</description>
      <content:encoded><![CDATA[<h1>multi spectral channel data</h1>
<p><a href="https://github.com/phelber/eurosat">eurosat/all</a> is a dataset
of 27,000 64x64 satellite images taken with 13 spectral bands.
each image is labelled one of ten classes.</p>
<p>for the purpose of classification these 13 aren't equally useful, and the
information in them varies across resolutions. if we were designing a sensor
we might choose to use different channels in different resolutions.</p>
<p>how can we explore the trade off between mixed resolutions and whether to use
a channel at all?</p>
<h1>a simple baseline model</h1>
<p>let's start with a simple baseline to see what performance we get. we won't
spend too much time on this model, we just want something that we can
iterate on quickly.</p>
<p>the simple model shown below trained on a 'training' split with adam hits 0.942
top 1 accuracy on a 2nd 'validation' split in 5 epochs. that'll do for a start.</p>
<p><img src="/blog/imgs/2021/ecs/single.svg.png"/></p>
<h1>what is the benefit of each channel?</h1>
<p>let's check the effect of including different combos of input channels. we'll do so by
introducing a channel mask.</p>
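in code, masking channels is just a broadcast multiply over the trailing channel axis. a minimal numpy sketch (the function name is mine, not from the repo; channels-last layout assumed):

```python
import numpy as np

def apply_channel_mask(x, mask):
    # x: (batch, height, width, channels); mask: (channels,) of {0, 1}
    # broadcasting over the trailing axis zeros out whole input channels
    return x * np.asarray(mask, dtype=x.dtype)

x = np.ones((2, 64, 64, 13))
mask = [1] * 13
mask[11] = 0  # drop channel 11
masked = apply_channel_mask(x, mask)
```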
<p>a mask of all ones denotes using all channels and gives our baseline performance</p>
<table class='data'>
<tr><td>mask</td><td>validation accuracy</td></tr>
<tr><td>[1,1,1,1,1,1,1,1,1,1,1,1,1]</td><td>0.942</td></tr>
</table>

<p>a mask of all zeros denotes using <em>no</em> channels and acts as a sanity check; it
gives the performance of random chance, which is in line with what we expect
given the balanced training set. (note: we standardise the input data so that
it has zero mean per channel, with the mean and standard deviation parameters
fit against training data only, which is what makes zero-masking behave this way)</p>
<table class='data'>
<tr><td>mask</td><td>validation accuracy</td></tr>
<tr><td>[1,1,1,1,1,1,1,1,1,1,1,1,1]</td><td>0.942</td></tr>
<tr><td>[0,0,0,0,0,0,0,0,0,0,0,0,0]</td><td>0.113</td></tr>
</table>

<p>but what about if we drop just one channel? i.e. a mask of all ones except for a single zero.</p>
<table class='data'>
<tr><td>channel to drop</td><td>validation accuracy</td></tr>
<tr><td>0</td><td>0.735</td></tr>
<tr><td>1</td><td>0.528</td></tr>
<tr><td>2</td><td>0.661</td></tr>
<tr><td>3</td><td>0.675</td></tr>
<tr><td>4</td><td>0.809</td></tr>
<tr><td>5</td><td>0.724</td></tr>
<tr><td>6</td><td>0.749</td></tr>
<tr><td>7</td><td>0.634</td></tr>
<tr><td>8</td><td>0.874</td></tr>
<tr><td>9</td><td>0.934</td></tr>
<tr><td>10</td><td>0.593</td></tr>
<tr><td>11</td><td>0.339</td></tr>
<tr><td>12</td><td>0.896</td></tr>
</table>

<p>from this we can see that the performance hit from losing a single channel
is not always the same. in particular consider channel 11; if we drop it we
get a huge hit! does that mean that keeping <em>only</em> channel 11 should give reasonable
performance?</p>
<table class='data'>
<tr><td>mask</td><td>validation accuracy</td></tr>
<tr><td>[1,1,1,1,1,1,1,1,1,1,1,1,1]</td><td>0.942</td></tr>
<tr><td>[0,0,0,0,0,0,0,0,0,0,0,0,0]</td><td>0.113</td></tr>
<tr><td>[0,0,0,0,0,0,0,0,0,0,0,1,0] (keep just 11)</td><td>0.260</td></tr>
</table>

<p>bbbzzzttt (or other appropriate annoying buzzer noise). channel 11
is contributing to the classification but it's not being used independently.
in general this is exactly the behaviour we want from a neural network
but what should we do to explore the effect of not having this dependence?</p>
<h1>dropping out channels</h1>
<p>consider using a dropout idea, just with input channels instead of intermediate nodes.</p>
<p>what behaviour do we get if we drop channels out during
training? i.e. with 50% probability we replace an entire input channel with 0s?</p>
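this channel-level dropout can be sketched in a few lines of numpy (the function name is mine; the real version runs inside the training step, with the dropout disabled at eval time):

```python
import numpy as np

def channel_dropout(x, rng, drop_prob=0.5):
    # with probability drop_prob, replace an entire input channel with 0s
    keep = (rng.random(x.shape[-1]) >= drop_prob).astype(x.dtype)
    return x * keep  # broadcast over the trailing channel axis

rng = np.random.default_rng(0)
dropped = channel_dropout(np.ones((4, 64, 64, 13)), rng)
```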
<p>things take longer to train and we get a slight hit in accuracy...</p>
<table class='data'>
<tr><td>dropout?</td><td>validation accuracy</td></tr>
<tr><td>no</td><td>0.942</td></tr>
<tr><td>yes</td><td>0.934</td></tr>
</table>

<p>...but now when we mask out one channel at a time we don't get a big hit for losing any
particular one.</p>
<table class='data'>
<tr><td>channel to drop</td><td colspan=2>validation accuracy</td></tr>
<tr><td></td><td>no dropout</td><td>with dropout</td></tr>
<tr><td>0</td><td>0.735</td><td>0.931</td></tr>
<tr><td>1</td><td>0.528</td><td>0.931</td></tr>
<tr><td>2</td><td>0.661</td><td>0.936</td></tr>
<tr><td>3</td><td>0.675</td><td>0.935</td></tr>
<tr><td>4</td><td>0.809</td><td>0.937</td></tr>
<tr><td>5</td><td>0.724</td><td>0.934</td></tr>
<tr><td>6</td><td>0.749</td><td>0.931</td></tr>
<tr><td>7</td><td>0.634</td><td>0.927</td></tr>
<tr><td>8</td><td>0.874</td><td>0.927</td></tr>
<tr><td>9</td><td>0.934</td><td>0.927</td></tr>
<tr><td>10</td><td>0.593</td><td>0.927</td></tr>
<tr><td>11</td><td>0.339</td><td>0.933</td></tr>
<tr><td>12</td><td>0.896</td><td>0.937</td></tr>
</table>

<h1>evolving the channel selection</h1>
<p>now that we have a model that is robust to any combo of channels what do we see
if we use a simple genetic algorithm (GA) to evolve the channel mask to use
with this pre-trained network?
a mask that represents using all channels will be the best right? right?</p>
<p>we'll run the GA against the network trained above, using the inverse of its loss
on a 3rd "ga_train" split as the fitness function.</p>
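a toy sketch of a simple GA over binary channel masks (numpy; the fitness here is a made-up stand-in for the real one, which is the inverse loss of the pre-trained network evaluated on the ga_train split):

```python
import numpy as np

rng = np.random.default_rng(0)
N_CHANNELS, POP_SIZE, GENERATIONS = 13, 32, 20

def fitness(mask):
    # placeholder: in the post this is 1/loss of the pre-trained network
    # run with the given channel mask; here, a toy score that just
    # prefers some channels over others
    weights = np.linspace(0.1, 1.0, N_CHANNELS)
    return float(mask @ weights)

def mutate(mask, p=0.1):
    # flip each bit independently with probability p
    flip = rng.random(N_CHANNELS) < p
    return np.where(flip, 1 - mask, mask)

def crossover(a, b):
    # single point crossover
    point = rng.integers(1, N_CHANNELS)
    return np.concatenate([a[:point], b[point:]])

pop = rng.integers(0, 2, size=(POP_SIZE, N_CHANNELS))
for _ in range(GENERATIONS):
    scores = np.array([fitness(m) for m in pop])
    def select():
        # tournament selection: keep the fitter of two random members
        i, j = rng.integers(0, POP_SIZE, size=2)
        return pop[i] if scores[i] >= scores[j] else pop[j]
    pop = np.stack([mutate(crossover(select(), select()))
                    for _ in range(POP_SIZE)])

best = pop[np.argmax([fitness(m) for m in pop])]
```

note that each fitness evaluation is only a forward pass of the already-trained network, which is what makes this approach cheap.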
<p>amusingly the GA finds that a mask of <code>[1,1,0,1,0,0,0,1,1,0,1,0,1]</code>
does marginally better than all channels, but only uses 1/2 of them!</p>
<table class='data'>
<tr><td>mask</td><td>split</td><td>accuracy</td></tr>
<tr><td>[1,1,1,1,1,1,1,1,1,1,1,1,1] (all)</td><td>ga_validate</td><td>0.934</td></tr>
<tr><td>[1,1,0,1,0,0,0,1,1,0,1,0,1] (ga)</td><td>ga_validate</td><td>0.936</td></tr>
</table>

<p>important note: we can imagine the best performance overall would come from having the GA
evolve not the channels to use with this pre-trained model, but the channels to use when training <em>from
scratch</em>. that would, though, require a lot more model training; basically a full training
cycle per fitness evaluation :( with the approach we describe here
we only have to train a single model and the GA just runs inference.</p>
<h1>what about different resolutions?</h1>
<p>taking the idea of channel selection a step further, what if we got the GA to not
only decide whether to use a channel or not, but also <em>what resolution</em> it should be
in?</p>
<p>consider some example images across resolutions....</p>
<table class='data'>
<tr><td colspan=4>example images (just RGB channels shown)</td></tr>
<tr><td>orig x64</td><td>x32</td><td>x16</td><td>x8</td></tr>
<tr>
<td><img src="/blog/imgs/2021/ecs/i05_x64.png" width='128'/></td>
<td><img src="/blog/imgs/2021/ecs/i05_x32.png" width='128'/></td>
<td><img src="/blog/imgs/2021/ecs/i05_x16.png" width='128'/></td>
<td><img src="/blog/imgs/2021/ecs/i05_x08.png" width='128'/></td>
</tr>
<tr>
<td><img src="/blog/imgs/2021/ecs/i06_x64.png" width='128'/></td>
<td><img src="/blog/imgs/2021/ecs/i06_x32.png" width='128'/></td>
<td><img src="/blog/imgs/2021/ecs/i06_x16.png" width='128'/></td>
<td><img src="/blog/imgs/2021/ecs/i06_x08.png" width='128'/></td>
</tr>
<tr>
<td><img src="/blog/imgs/2021/ecs/i08_x64.png" width='128'/></td>
<td><img src="/blog/imgs/2021/ecs/i08_x32.png" width='128'/></td>
<td><img src="/blog/imgs/2021/ecs/i08_x16.png" width='128'/></td>
<td><img src="/blog/imgs/2021/ecs/i08_x08.png" width='128'/></td>
</tr>
</table>

<p>we could then weight the use of a channel based on resolution; the higher the resolution
the more the channel "costs" to use, with not using the channel at all being "free".</p>
<p>to support this we can change the GA to represent members not as a string of {0, 1}s
but instead a sequence of {0, x8, x16, x32, x64} values per channel where these represent...</p>
<table class='data'>
<tr><td><b>resolution</b></td><td><b>description</b></td><td><b>channel cost</b></td></tr>
<tr><td>x64</td><td>use original (64, 64) version of input</td><td>0.8</td></tr>
<tr><td>x32</td><td>use a 1/2 res (32, 32) version of input</td><td>0.4</td></tr>
<tr><td>x16</td><td>use a 1/4 res (16, 16) version of input</td><td>0.2</td></tr>
<tr><td>x8</td><td>use a 1/8 res (8, 8) version of input</td><td>0.1</td></tr>
<tr><td>0</td><td>don't use channel</td><td>0</td></tr>
</table>
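the cost table above as a lookup (a sketch; the string encoding of a GA member is mine):

```python
# per-channel cost by resolution, from the table above
COST = {"x64": 0.8, "x32": 0.4, "x16": 0.2, "x8": 0.1, "0": 0.0}

def channel_cost(member):
    # member: one resolution choice per channel
    return sum(COST[res] for res in member)

# e.g. the kind of member the GA evolves later in the post
member = ["x16", "x64", "x64", "x16", "x32", "0", "x8",
          "x64", "x8", "0", "x8", "0", "x32"]
```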

<p>the change in the encoding of our GA is trivial, just 5 values per channel instead of 2,
but before we look at that; how do we change our network?</p>
<p>we can do it without having to add too many extra parameters by using the magic of fully
convolutional networks :)</p>
<p>notice how the main trunk of our first network was a series of 2d convolutions
followed by a global spatial mean. being fully convolutional, this trunk can take
input at whatever resolution we need, so we can simply reuse it multiple times!</p>
<p>so we can have our network...</p>
<ol>
<li>take the original x64 input</li>
<li>downsample it multiple times to x32, x16 and x8</li>
<li>mask out the channels so that each channel is only represented in one of the resolutions (or not
represented at all if we want to ignore that channel)</li>
<li>run the main trunk network with shared parameters on each of the masked resolutions</li>
<li>combine the outputs with a simple channel concatenation</li>
<li>do one more non linear mixing (because, why not..)</li>
<li>finish with the logits</li>
</ol>
<p><img src="/blog/imgs/2021/ecs/multi_res.svg.png"/></p>
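the forward pass above can be sketched roughly in numpy (the <code>downsample</code> and <code>trunk</code> here are crude stand-ins for real image resizing and the shared convolutional trunk):

```python
import numpy as np

RESOLUTIONS = [64, 32, 16, 8]

def downsample(x, factor):
    # naive average pooling; a stand-in for proper image resizing
    b, h, w, c = x.shape
    return x.reshape(b, h // factor, factor, w // factor, factor, c).mean((2, 4))

def trunk(x):
    # stand-in for the shared fully convolutional trunk + global spatial mean;
    # any op mapping (b, h, w, c) -> (b, features) fits the sketch
    return x.mean(axis=(1, 2))

def forward(x, res_assignment):
    # res_assignment: chosen resolution per channel, 0 meaning "don't use"
    outputs = []
    for res in RESOLUTIONS:
        mask = np.array([1.0 if r == res else 0.0 for r in res_assignment])
        xr = downsample(x, 64 // res) * mask  # steps 2 & 3
        outputs.append(trunk(xr))             # step 4, shared parameters
    return np.concatenate(outputs, axis=-1)   # step 5; mixing + logits follow

x = np.ones((2, 64, 64, 13))
assignment = [16, 64, 64, 16, 32, 0, 8, 64, 8, 0, 8, 0, 32]
features = forward(x, assignment)
```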
<p>note: try as i might i can't get steps 2 to 4 to run parallelised in a pmap.
<a href="https://github.com/google/jax/discussions/5895">asked on github about it</a>
and looks to be something you can't do at the moment.</p>
<h1>the channel cost vs loss pareto front</h1>
<p>when we consider channel cost vs loss there is no single best solution; instead we get a classic
<a href="https://en.wikipedia.org/wiki/Pareto_efficiency">pareto front</a>
trading off the two objectives.</p>
<p>consider this sampling of 1,000 random channel masks...</p>
<p><img src="/blog/imgs/2021/ecs/pareto_front.just_random.png"/></p>
<h1>rerunning the GA</h1>
<p>the GA needs to operate with a fitness that's a single scalar; for now we just use
a simple combo of <code>(1.0 / loss) - channel_cost</code></p>
<p>running with this fitness function we evolve the solution
<code>[x16, x64, x64, x16, x32, ignore, x8, x64, x8, ignore, x8, ignore, x32]</code></p>
<p>it's on the pareto front, as we'd hope, and it's interesting that it includes
a mix of resolutions including ignoring 3 channels completely :)</p>
<p><img src="/blog/imgs/2021/ecs/pareto_front.with_ga.png"/></p>
<p>different mixings of loss and channel_cost would result in different GA solutions along the front</p>
<h1>code</h1>
<p>all the code is <a href="https://github.com/matpalm/evolved_channel_selection">on github</a></p>]]></content:encoded>
    </item>
    <item>
      <title>crazy large batch sizes</title>
      <link>http://matpalm.com/blog/crazy_large_batch_sizes</link>
      <category><![CDATA[quick_hack]]></category>
      <category><![CDATA[tpu]]></category>
      <category><![CDATA[jax]]></category>
      <guid>http://matpalm.com/blog/crazy_large_batch_sizes</guid>
      <description>crazy large batch sizes</description>
      <content:encoded><![CDATA[<h1>"use a bigger batch"</h1>
<p>a classic piece of advice i hear for people using tpus is that
they should "try using a bigger batch size".</p>
<p>this got me thinking;
i wonder how big a batch size i could reasonably use?
how would the optimisation go?
how fast could i get things?</p>
<h1>dataset</h1>
<p>let's train a model on the
<a href="https://github.com/phelber/eurosat">eurosat/rgb dataset</a>.
it's a 10 way classification problem on 64x64 images</p>
<p><img src="/blog/imgs/2020/en/sample_images.png"/></p>
<p>with a training split of 80% we have 21,600 training examples.
we'll use another 10% for validation (2,700 images)
and just not use the final 10% test split #hack</p>
<h1>model</h1>
<p>for the
<a href="https://github.com/matpalm/large_batch/blob/master/model.py">model</a>
we'll use a simple stack of convolutions with
channel sizes 32, 64, 128 and 256, a stride of 2 for spatial reduction
all with gelu activation. after the convolutions we'll do a simple global
spatial pooling, a single 128d dense layer with gelu and then a
10d logit output. a pretty vanilla architecture of ~400K params. nothing fancy.</p>
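as a rough sanity check on the ~400K figure, we can count parameters for this stack assuming 3x3 kernels (the exact architecture lives in model.py on github):

```python
# rough parameter count for the conv stack described above, 3x3 kernels assumed
def conv_params(c_in, c_out, k=3):
    # kernel weights + biases
    return k * k * c_in * c_out + c_out

channels = [3, 32, 64, 128, 256]  # rgb input, then the conv stack
total = sum(conv_params(a, b) for a, b in zip(channels, channels[1:]))
total += 256 * 128 + 128  # 128d dense layer after global pooling
total += 128 * 10 + 10    # 10d logit output
```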
<h1>splitting up the data</h1>
<p>a v3-32 tpu pod slice is 4 hosts, each with 8 tpu devices.</p>
<p>21,600 training examples total =&gt; 5,400 examples per host =&gt; 675 examples per device.</p>
<p>this number of images easily fits on a device. great.</p>
<h2>augmentation</h2>
<p>now usually augmentation is something we do randomly per batch, but for this hack
we're interested in seeing how big a batch we can run. so why not
fill out the dataset a bit by just running a stack of augmentations before training?</p>
<p>for each image we'll do 90, 180 and 270 deg rotations along with left/right flips
for a total of 8 augmented images for each original image. e.g.....</p>
<p><img src="/blog/imgs/2021/lb/augmentations.png"/></p>
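the 8 variants per image come out of a couple of numpy ops (a sketch; the repo's version may differ in details):

```python
import numpy as np

def augment(img):
    # 4 rotations x {identity, left/right flip} = 8 variants per image
    variants = []
    for k in range(4):
        rotated = np.rot90(img, k)  # 0, 90, 180, 270 deg
        variants.append(rotated)
        variants.append(np.fliplr(rotated))
    return np.stack(variants)

imgs = augment(np.arange(16.0).reshape(4, 4))
```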
<p>this now gives us 172,800 images total =&gt; 43,200 per host =&gt; 5,400 per tpu device,
which still fits no problem.</p>
<p>side note: turns out doing this augmentation was one of the most fun parts of
this hack :)</p>
<h2>optimisers?</h2>
<p>one motivation i had for this hack was to compare adam to lamb. i'd
seen lamb referred to in the past; would it perform better at this model/dataset size?
turns out it does! a simple sweep comparing lamb, adam and sgd shows lamb consistently
doing the best. definitely one to add to the tuning mix from now on.</p>
<h2>data / model / optimiser state placement</h2>
<p>not only does the augmented data fit sharded across devices but we can replicate
both the model parameters and the optimiser state as well. this is important
for speed since the main training loop doesn't have to do any host/device
communication. taking a data parallel approach means the only cross device
comms is a gradient psum.</p>
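the data parallel pattern can be sketched in plain numpy, with a mean across the device axis standing in for the gradient psum (in jax this would be a <code>pmap</code>ed step using <code>jax.lax.pmean</code>; the toy linear model here is mine):

```python
import numpy as np

N_DEVICES = 8

def grad(w, x, y):
    # gradient of mean squared error for a 1d linear model pred = w * x
    return np.mean(2 * (w * x - y) * x)

rng = np.random.default_rng(0)
x = rng.normal(size=(N_DEVICES, 100))  # leading axis = device axis
y = 3.0 * x                            # true weight is 3
w = 0.0
for _ in range(200):
    # each "device" computes gradients on its own shard of the data...
    per_device = np.array([grad(w, xs, ys) for xs, ys in zip(x, y)])
    # ...then the only cross device communication is averaging them
    w -= 0.1 * per_device.mean()       # the mean stands in for the psum
```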
<h1>results</h1>
<p>for training we run an inner loop just pumping the <code>params = update(params)</code> step.</p>
<p>an outer loop runs the inner loop 100 times before doing a validation accuracy check.</p>
<p>the inner loop runs at 1.5s for the 100 iterations and since each iteration is a
forward &amp; backwards pass for all 172,800 images across all hosts that's 11M images
processed per second. 🔥🔥🔥</p>
<p>at this speed the best result of 0.95 on validation takes 13 outer loops;
i.e. all done in under 20s. o_O !!</p>
<p>when reviewing runs i did laugh to see sgd with momentum make a top 10 entry.</p>
<p><em>new t-shirt slogan: "sgd with momentum; always worth a try"</em></p>
<p><img src="/blog/imgs/2021/lb/top_10.png"/></p>
<h1>code</h1>
<p>all the code in hacktastic undocumented form
<a href="https://github.com/matpalm/large_batch">on github</a></p>]]></content:encoded>
    </item>
    <item>
      <title>solving y=mx+b... with jax on a tpu pod slice</title>
      <link>http://matpalm.com/blog/ymxb_pod_slice</link>
      <category><![CDATA[tpu]]></category>
      <category><![CDATA[ensemble_nets]]></category>
      <category><![CDATA[jax]]></category>
      <category><![CDATA[projects]]></category>
      <category><![CDATA[haiku]]></category>
      <guid>http://matpalm.com/blog/ymxb_pod_slice</guid>
      <description>solving y=mx+b... with jax on a tpu pod slice</description>
      <content:encoded><![CDATA[<h1>from jax fundamentals to running on a tpu pod slice</h1>
<p>this 4 (and a bit) part tute series starts with
<a href="https://jax.readthedocs.io/en/latest/">jax</a>
fundamentals, builds up to describing a data parallel approach to training on a
<a href="https://cloud.google.com/tpu">cloud tpu pod slice</a>, and
finishes with a tpu pod slice implementation of
<a href="http://matpalm.com/blog/ensemble_nets">ensemble nets</a>....
all with the goal of solving 1d <code>y=mx+b</code></p>
<p>and though it may seem like a bit of overkill it turns out it's a good example
to work through so that we can focus on the library support without having
to worry about the modelling.</p>
<h2>part 1: some jax basics</h2>
<p>in this first section we introduce some jax fundamentals;
e.g. make_jaxpr, grad, jit, vmap &amp; pmap.</p>
<iframe width="560" height="315" src="https://www.youtube.com/embed/W1vfBDFLm7Q" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>

<p>colab: <a href="https://colab.research.google.com/drive/1vmEckPE6o9pDJF1tctPW5yyuAIVZ817B">01 pmap jit vmap oh my.ipynb</a></p>
<h2>part 2: solving y=mx+b</h2>
<p>in part 2 we use the techniques from part 1 to solve <code>y=mx+b</code> in pure jax. we'll also
introduce
<a href="https://jax.readthedocs.io/en/latest/pytrees.html">pytrees</a>
and various
<a href="https://jax.readthedocs.io/en/latest/jax.tree_util.html">tree_utils</a>
for manipulating them.</p>
<p>we run first on a single device and work up to using
<a href="https://jax.readthedocs.io/en/latest/jax.html#jax.pmap">pmap</a> to demonstrate a
simple data parallelism approach. along the way we'll do a small detour to a tpu pod slice
to illustrate the difference in a multi host setup.</p>
<p>( note: the experience as described here for a pod slice isn't publicly available yet; but sign up via the
<a href="http://goo.gle/jax-tpu-signup">JAX on Cloud TPU Interest Form</a> to get more info. see also this
<a href="https://docs.google.com/presentation/d/1eBfNKT3D3lEWtcn4mkgvvitKZTU7HSn4f3fGHqYgpeA/">JAX on Cloud TPUs (NeurIPS 2020)</a> talk )</p>
<iframe width="560" height="315" src="https://www.youtube.com/embed/XXsSZlHzHcw" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>

<p>colab: <a href="https://colab.research.google.com/drive/1qkjyNtmPzQIYHAY7l1YqvMgTOfbQAfgA?usp=sharing">02 y mx b on a tpu.ipynb</a></p>
<h2>part 3: introducing haiku and optax</h2>
<p>next we introduce <a href="https://github.com/deepmind/dm-haiku">haiku</a> as a way
of defining our model and <a href="https://github.com/deepmind/optax">optax</a> as a
library to provide standard optimisers. to illustrate their use we'll do a minimal
port of our model and training loop to use them.</p>
<iframe width="560" height="315" src="https://www.youtube.com/embed/8-N1-7lPWOs" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>

<p>colab: <a href="https://colab.research.google.com/drive/1_MDXCnwmLTPPm4qJnpHYj-X0E-a-ev6H?usp=sharing">03 y mx b in haiku.ipynb</a></p>
<h2>part 4: ensemble nets</h2>
<p>in part 4 we'll reimplement <a href="http://matpalm.com/blog/ensemble_nets">ensemble nets</a>
for this trivial model, continuing to do things in a way that supports
running on a tpu pod slice.</p>
<iframe width="560" height="315" src="https://www.youtube.com/embed/_-ftTbABKuk" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>

<p>colab: <a href="https://colab.research.google.com/drive/1uWZQQfp5T4nKRdg_kgV93tIzkA316f2O?usp=sharing">04 y mx b haiku ensemble.ipynb</a></p>
<h2>part 5: some sanity</h2>
<p>to wrap up we acknowledge that though tpu pod slices and data parallel
approaches <em>are</em> fun we could have just solved this in a single
calculation using the normal equation... :D</p>
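for reference, the single-calculation version is just the normal equation (a numpy sketch with made-up data):

```python
import numpy as np

# solve y = m*x + b in closed form: theta = (X^T X)^-1 X^T y
rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=100)
y = 2.5 * x + 0.7                           # true m=2.5, b=0.7

X = np.stack([x, np.ones_like(x)], axis=1)  # design matrix [x, 1]
m, b = np.linalg.solve(X.T @ X, X.T @ y)
```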
<iframe width="560" height="315" src="https://www.youtube.com/embed/5hKse-PUo0k" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>

<p>colab: <a href="https://colab.research.google.com/drive/1sX5plJfIesT-5IUU-d9U-mimXtWOXEE3?usp=sharing">05 booooooooooooooooring.ipynb</a></p>
<h1>what a way to solve <code>y=mx+b</code> !!!</h1>
<p><img src="/blog/imgs/2021/ymxb/tenor.gif"/></p>]]></content:encoded>
    </item>
    <item>
      <title>develomentor.com podcast interview</title>
      <link>http://matpalm.com/blog/develomentor_podcast</link>
      <category><![CDATA[talk]]></category>
      <guid>http://matpalm.com/blog/develomentor_podcast</guid>
      <description>develomentor.com podcast interview</description>
      <content:encoded><![CDATA[<p>was <a href="https://develomentor.com/2020/12/07/matthew-kelcey-former-research-engineer-at-google-brain-114/">a guest on the develomentor podcast</a> talking about random parts of my career. i always enjoy chatting to grant, hope you get to have a listen!</p>]]></content:encoded>
    </item>
  </channel>
</rss>
