brain of mat kelcey...


keras3 to jax to tflite/liteRT via orbax

November 25, 2024 at 02:10 PM | categories: tflite, litert, keras3, jax, orbax

keras 3, jax, tflite, oh my.

i've always seen keras as providing two key things; 1) a way to compose layers into models and 2) an opinionated fit loop

for the last few years i've been focussed much more on jax than tensorflow, which means i've not really been using keras for any hobby stuff. i'm super excited that keras3 supports a jax backend because i can go back to using keras for its clear model definition while keeping the flexibility to mess around with weird non-standard model optimisation.

my day job has some focus around the tflite ecosystem ( recently renamed liteRT ) and this set of tools has had great(†) support for tensorflow conversion, though not-so-great (†) support for jax conversion. † for some definition of great

if you have a "standard" keras model you can do conversion to tflite via the various keras save formats. ( in which case whether the backend is jax or not is not relevant ) however after using keras3/jax for a little while i'm finding myself building composite models that don't fit into this "standard" keras model.

consider my yolz project; it's a model that i train with jax that's actually made up of two keras models, and given the multiple input nature of this model using different nested batching it's not something that fits nicely with the standard keras model.

the crux of it all is: you can compose two keras models into one, but not if one of them has been vmapped ( or has had some other jax function transform applied to it ). i can define the parts with keras and train them with jax/optax, but i can't put them "back" into a keras model to save for tflite export.

orbax tflite export

jax to tflite conversion has been through a LOT of churn; more often than not it feels like i've been using libs listed under .experimental.

the one i started poking with today is orbax

consider the following keras3 model

from keras.layers import Input, Conv2D, GlobalMaxPooling2D
from keras.models import Model

input = Input((64, 64, 3))
y = Conv2D(filters=8, strides=1, kernel_size=3, activation='relu', padding='same')(input)
y = GlobalMaxPooling2D()(y)
model = Model(input, y)

we can convert this to a jax function for inference with a closure over the non trainable params

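import jax.numpy as jnp

# trainable params are passed explicitly; non trainable ones are closed over
params = model.trainable_variables
nt_params = model.non_trainable_variables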

def predict(params, x):
  y_pred, _nt_params = model.stateless_call(params, nt_params, x, training=False)
  return y_pred

eg_x = jnp.ones((1, 64, 64, 3))
predict(params, eg_x).shape
# (1, 8)

and then pass this function to the tf lite export with no surprises.

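import tensorflow as tf
from orbax.export import JaxModule, constants

# wrap the predict function defined above; 'b, ...' marks the leading axis as polymorphic
jax_module = JaxModule(params, predict, input_polymorphic_shape='b, ...')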
converter = tf.lite.TFLiteConverter.from_concrete_functions(
   [
       jax_module.methods[constants.DEFAULT_METHOD_KEY].get_concrete_function(
           tf.TensorSpec(shape=(1, 64, 64, 3), dtype=tf.float32, name="input")
       )
   ]
)
tflite_model = converter.convert()
with open('keras.tflite', 'wb') as f:
  f.write(tflite_model)

a vmapped model

more interestingly we can make a vectorised version of the inference function that is called on a set of inputs and then does a reduction ( such as the embedding model does in yolz )

import jax
import jax.numpy as jnp

input = Input((64, 64, 3))
y = Conv2D(filters=8, strides=1, kernel_size=3, activation='relu', padding='same')(input)
y = GlobalMaxPooling2D()(y)
model = Model(input, y)

params = model.trainable_variables
nt_params = model.non_trainable_variables

def predict(params, x):
  """ the vanilla version of inference """
  y_pred, _nt_params = model.stateless_call(params, nt_params, x, training=False)
  return y_pred

def v_predict(params, x):
  """ a variant that supports multiple inputs and does a mean over the result """
  y_preds = jax.vmap(predict, [None, 0])(params, x)
  return jnp.mean(y_preds, axis=-2)

eg_x = jnp.ones((1, 4, 64, 64, 3))
v_predict(params, eg_x).shape
# (1, 8)

this time we call with the additional shape axis representing the set of inputs; in this example 4

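# rebuild the module around the vmapped function before converting
jax_module = JaxModule(params, v_predict, input_polymorphic_shape='b, ...')
converter = tf.lite.TFLiteConverter.from_concrete_functions(
   [
       jax_module.methods[constants.DEFAULT_METHOD_KEY].get_concrete_function(
           tf.TensorSpec(shape=(1, 4, 64, 64, 3), dtype=tf.float32, name="input")
       )
   ]
)
tflite_model = converter.convert()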

and we see the tflite result is correctly compiled as a reshape to collapse the leading axis ( from (B, 4, ...) to (4B, ...) for the convolution and then back to (B, 4, ...) for the reduction )

we also see how the additional plumbing is included for the mean reduction mixed with the pooling
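the reshape equivalence described above can be checked directly in jax. this is a minimal sketch, not the conv model from this post: `toy_predict` is a hypothetical stand-in ( a per pixel linear layer with relu and global max pooling ) and the shapes are made up. it compares the vmapped version against manually collapsing (B, 4, ...) to (4B, ...), running the batched function once, and reshaping back.

```python
import jax
import jax.numpy as jnp

def toy_predict(w, x):
    # hypothetical stand-in for the keras model: x is (N, H, W, C), w is (C, F)
    h = jax.nn.relu(jnp.einsum('nhwc,cf->nhwf', x, w))  # per pixel linear + relu
    return jnp.max(h, axis=(1, 2))                       # global max pool -> (N, F)

def v_predict(w, x):
    # x is (B, S, H, W, C); vmap over B, the inner call treats S as its batch
    y = jax.vmap(toy_predict, in_axes=(None, 0))(w, x)   # (B, S, F)
    return jnp.mean(y, axis=-2)                          # (B, F)

def manual_predict(w, x):
    # the same computation with the reshapes written out by hand
    b, s = x.shape[:2]
    flat = x.reshape((b * s,) + x.shape[2:])             # (B*S, H, W, C)
    y = toy_predict(w, flat).reshape((b, s, -1))         # (B, S, F)
    return jnp.mean(y, axis=-2)                          # (B, F)

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (3, 8))
x = jax.random.normal(key_x, (2, 4, 16, 16, 3))

vmapped = v_predict(w, x)
manual = manual_predict(w, x)
```

both `vmapped` and `manual` should agree up to float rounding; the vmapped version just never has to mention the reshape.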

note: all of this is in addition to the batching we use during training, which is there not just for hardware optimisation but also for gradient variance reduction

as i mentioned before in the yolz blog post, this is all reshaping, axis plumbing, etc that we can do manually, but it's much easier, especially as we start to nest vmaps, when the compiler does it for us
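to make the nesting point concrete, here's a hedged sketch with two levels of vmap. the `score` function and the shapes are made up for illustration ( they're not from yolz ); the point is only that each extra leading axis costs one more `jax.vmap` rather than another round of hand written reshapes.

```python
import jax
import jax.numpy as jnp

def score(w, x):
    # hypothetical per example function: x is a single (D,) vector, w is (D,)
    return jnp.tanh(jnp.dot(w, x))

# inner vmap maps over a set of S vectors, outer vmap maps over a batch of B sets
per_set = jax.vmap(score, in_axes=(None, 0))       # (S, D) -> (S,)
per_batch = jax.vmap(per_set, in_axes=(None, 0))   # (B, S, D) -> (B, S)

w = jnp.arange(5, dtype=jnp.float32) / 5.0
x = jnp.ones((2, 3, 5))
nested = per_batch(w, x)

# the manual equivalent: collapse both leading axes, matmul, reshape back
manual = jnp.tanh(x.reshape((-1, 5)) @ w).reshape((2, 3))
```

with only two axes the manual version is still short, but every additional axis adds more index bookkeeping to it, while the vmapped version keeps `score` written for a single example.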

code

here's the colab