keras3 to jax to tflite/liteRT via orbax
November 25, 2024 at 02:10 PM | categories: tflite, litert, keras3, jax, orbax
keras 3, jax, tflite, oh my.
i've always seen keras as providing two key things:
1) a way to compose layers into models and
2) an opinionated fit loop
for the last few years i've been focussed much more on jax than tensorflow, which means i've not really been using keras for any hobby stuff. am super excited that keras3 supports a jax backend because i can go back to using keras for its clear model definition while having the flexibility to mess around with weird non-standard model optimisation.
my day job has some focus around the tflite ecosystem ( recently renamed liteRT ) and this set of tools has had great(†) support for tensorflow conversion, though not-so-great(†) support for jax conversion.
† for some definition of great
if you have a "standard" keras model you can do conversion to tflite via the various keras save formats ( in which case whether the backend is jax or not is irrelevant ). however, after using keras3/jax for a little while, i'm finding myself building composite models that don't fit into this "standard" keras mould.
consider my yolz project; it's a model i train with jax that's actually made up of two keras models, and given the multiple-input nature of the model, with its different levels of nested batching, it's not something that fits nicely into a single standard keras model.
the crux of it all is: you can compose two keras models into one, but not if one of them has been vmapped ( or has had some other jax function transform applied to it ). i can define the parts with keras and train them with jax/optax, but i can't put them "back" into a keras model to save for tflite export.
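as a minimal sketch of the problem ( toy dense models here, just for illustration; assumes keras3 running with the jax backend ):

import jax
from keras.layers import Input, Dense
from keras.models import Model

# composing two keras models into one is fine...
i1 = Input((16,))
enc = Model(i1, Dense(8)(i1))
i2 = Input((8,))
head = Model(i2, Dense(1)(i2))
combined = Model(enc.input, head(enc.output))  # still a keras model

# ...but a vmapped encoder is just a jax function, not a keras layer,
# so there's nothing to wire "back" into a Model
def enc_fn(params, x):
    y, _ = enc.stateless_call(params, enc.non_trainable_variables, x, training=False)
    return y

v_enc = jax.vmap(enc_fn, in_axes=[None, 0])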
orbax tflite export
jax to tflite conversion has been through a LOT of churn; i've used libs listed under .experimental. more often than not it feels like...
the one i started poking with today is orbax
consider the following keras3 model
# ( assumes keras3 running with the jax backend,
#   e.g. os.environ['KERAS_BACKEND'] = 'jax' before importing keras )
from keras.layers import Input, Conv2D, GlobalMaxPooling2D
from keras.models import Model

input = Input((64, 64, 3))
y = Conv2D(filters=8, strides=1, kernel_size=3, activation='relu', padding='same')(input)
y = GlobalMaxPooling2D()(y)
model = Model(input, y)
we can convert this to a jax function for inference, with a closure over the non-trainable params
import jax
import jax.numpy as jnp

params = model.trainable_variables
nt_params = model.non_trainable_variables

def predict(params, x):
    y_pred, _nt_params = model.stateless_call(params, nt_params, x, training=False)
    return y_pred

eg_x = jnp.ones((1, 64, 64, 3))
predict(params, eg_x).shape
# (1, 8)
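as a quick sanity check, the stateless call should match calling the keras model directly:

import numpy as np
np.testing.assert_allclose(predict(params, eg_x), model(eg_x), atol=1e-6)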
and then pass this function to the tf lite export with no surprises.
import tensorflow as tf
from orbax.export import JaxModule, constants

jax_module = JaxModule(params, predict, input_polymorphic_shape='b, ...')

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[constants.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(1, 64, 64, 3), dtype=tf.float32, name="input")
        )
    ]
)

tflite_model = converter.convert()
with open('keras.tflite', 'wb') as f:
    f.write(tflite_model)
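we can then double check the exported artifact with the standard tf.lite.Interpreter ( just a sketch; shapes as above ):

import numpy as np

interpreter = tf.lite.Interpreter(model_path='keras.tflite')
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
interpreter.set_tensor(inp['index'], np.ones((1, 64, 64, 3), dtype=np.float32))
interpreter.invoke()
print(interpreter.get_tensor(out['index']).shape)  # (1, 8)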
a vmapped model
more interestingly we can make a vectorised version of the inference function that is called on a set of inputs and then does a reduction ( such as the embedding model does in yolz )
input = Input((64, 64, 3))
y = Conv2D(filters=8, strides=1, kernel_size=3, activation='relu', padding='same')(input)
y = GlobalMaxPooling2D()(y)
model = Model(input, y)

params = model.trainable_variables
nt_params = model.non_trainable_variables

def predict(params, x):
    """ the vanilla version of inference """
    y_pred, _nt_params = model.stateless_call(params, nt_params, x, training=False)
    return y_pred

def v_predict(params, x):
    """ a variant that supports multiple inputs and does a mean over the result """
    y_preds = jax.vmap(predict, [None, 0])(params, x)
    return jnp.mean(y_preds, axis=-2)

eg_x = jnp.ones((1, 4, 64, 64, 3))
v_predict(params, eg_x).shape
# (1, 8)
this time we call with an additional shape axis representing the set of inputs; 4 in this example
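note we first wrap the new function in a JaxModule, the same pattern as before:

jax_module = JaxModule(params, v_predict, input_polymorphic_shape='b, ...')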
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[constants.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(1, 4, 64, 64, 3), dtype=tf.float32, name="input")
        )
    ]
)
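and, exactly as before, convert and write it out ( filename just illustrative ):

tflite_model = converter.convert()
with open('keras_vmapped.tflite', 'wb') as f:
    f.write(tflite_model)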
and we see the tflite result is correctly compiled as a reshape to collapse the leading axis ( from (B, 4, ...) to (4B, ...) for the convolution, and then back to (B, 4, ...) for the reduction )
we also see how the additional plumbing is included for the mean reduction mixed with the pooling
note: all of this is in addition to the batching we use during training, which is not just for hardware optimisation but also for gradient variance reduction
as i mentioned in the yolz blog post, this is all reshaping and axis plumbing we could do manually, but it's much easier, especially as we start to nest vmaps, to let the compiler do it for us
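for reference, here's a sketch of the manual version of that plumbing, equivalent to v_predict above:

def manual_v_predict(params, x):
    # x is (B, 4, 64, 64, 3); fold the set axis into the batch axis for the conv...
    b = x.shape[0]
    y = predict(params, x.reshape(-1, 64, 64, 3))  # (4B, 8)
    # ...then unfold it and reduce over the set
    return jnp.mean(y.reshape(b, 4, -1), axis=1)   # (B, 8)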
code
here's the colab