brain of mat kelcey...


yolz; you only look o̶n̶c̶e̶ zero times

October 26, 2024 at 06:45 PM | categories: keras3, jax

what does zero shot mean?

transfer learning

transfer learning is arguably the greatest success story of classifier neural nets. it works based on the idea that an N layer model can be thought of as an N-1 layer feature extractor followed by a 'simple' logistic regression. the large N-1 layer feature extractor is often reusable across related domains, so you can get a long way just by retraining the final classifier to get a second model.

note that this works best when the first model was trained on a set of classes that is, somehow, related to the set of the classes in the second model.
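a minimal JAX sketch of 'retrain just the classifier', under loudly stated assumptions: the frozen feature extractor is a hypothetical stand in ( a fixed random projection plus relu, not a real pretrained backbone ), the data is made up, and only the logistic regression head is differentiated and updated.

```python
import jax
import jax.numpy as jnp

k_backbone, k_x = jax.random.split(jax.random.PRNGKey(0))

# hypothetical frozen "N-1 layer" feature extractor; its weights are never updated
W_backbone = jax.random.normal(k_backbone, (8, 16))

def features(x):
    return jax.nn.relu(x @ W_backbone)

# made up data for the "second" task
x = jax.random.normal(k_x, (256, 8))
y = (x[:, 0] > 0).astype(jnp.float32)

def loss_fn(head, x, y):
    # logistic regression on top of the fixed features
    logits = features(x) @ head["w"] + head["b"]
    # binary cross entropy in terms of logits: softplus(z) - y * z
    return jnp.mean(jnp.logaddexp(0.0, logits) - y * logits)

head0 = {"w": jnp.zeros(16), "b": jnp.zeros(())}
grad_fn = jax.jit(jax.grad(loss_fn))  # gradients w.r.t. the head params only

head = head0
for _ in range(200):
    grads = grad_fn(head, x, y)
    head = jax.tree_util.tree_map(lambda p, g: p - 0.02 * g, head, grads)
```

the backbone only ever appears as a constant inside loss_fn, so retraining touches nothing but the 17 head parameters.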

( for more details see this talk i did on self supervised learning )

what is few shot learning?

the fact that transfer learning even works is kind of an accident ( in the sense that it was never the objective of training the first model ). however there's no reason you can't make it an objective. techniques such as meta learning focus on this and aim to train a model that can be adapted with transfer learning using only a few examples, sometimes as few as a single one per class ( e.g. reptile ).

( for more details see this talk i did on learning to learn )

what is zero shot learning?

with transfer learning you have to provide a lot of new examples during retraining. with few shot learning you only need to provide a few examples during retraining. but can you train a model that generalises to new classes without needing to retrain at all?

a zero shot learner aims to do exactly this, with the key technique being that the new class to generalise to is somehow provided as part of the model input itself.

( note how LLMs are the ultimate in zero shot learning since they use as input one of the richest representations we know, language )

what is zero shot detection?

to do zero shot detection we need a model that takes two inputs

  1. examples of the thing we're actually trying to detect ( the 'reference objects' )
  2. the image we want to find those objects in ( the 'scene' )

and returns as output a binary mask of the location of the object of interest...

since we want the model to be robust to different angles of the object, we'll take multiple views of the reference object as input...

inference model

first let's consider the inference version of this model...

each of the N object references is run through an embedding network, and the mean of the N embeddings is used as the single object embedding.
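a sketch of the reference branch at inference time, assuming a hypothetical embedding network ( here just a flatten plus a random projection, standing in for a real conv net ). the N views act as the batch dim and the mean is taken afterwards.

```python
import jax
import jax.numpy as jnp

H, W, E = 64, 64, 32
# hypothetical stand in for the embedding network: flatten each view, project to E dims
proj = jax.random.normal(jax.random.PRNGKey(0), (H * W * 3, E)) * 0.01

def embed(views):
    # views: (N, H, W, 3) -> (N, E); N plays the role of the batch dim
    return jnp.tanh(views.reshape(views.shape[0], -1) @ proj)

def reference_embedding(views):
    # (N, H, W, 3) -> (N, E) -> mean over the N views -> (E,)
    return embed(views).mean(axis=0)

views = jax.random.uniform(jax.random.PRNGKey(1), (4, H, W, 3))
ref = reference_embedding(views)  # shape (32,)
```

taking the mean means the result doesn't depend on the order of the views, and N is free to vary between calls.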

the scene input is run through a standard u-net like architecture with the object embeddings combined with the scene features in the middle of the u-net ( i.e. after the downsampling, but before the upsampling ).

note that we have a single embedding to mix into a (10, 10) feature map, so we first broadcast the embedding across all the spatial positions to match the feature map size.
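a sketch of that broadcast, assuming E=32 and a scene feature depth F=64 ( both values made up for illustration ):

```python
import jax.numpy as jnp

E, F = 32, 64
embedding = jnp.arange(E, dtype=jnp.float32)   # (E,) the single mean object embedding
scene_features = jnp.zeros((10, 10, F))        # (10, 10, F) from the middle of the u-net

# repeat the same embedding at every spatial position: (E,) -> (10, 10, E)
tiled = jnp.broadcast_to(embedding, scene_features.shape[:2] + (E,))
```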

the best performing way to combine them was just simple elementwise addition ( after a non linear projection of both the scene features and the object embeddings to a joint feature depth ). some simple experiments tried using the object embeddings as a query for attention against the scene features, but it wasn't any better and was more computationally expensive. less is more sometimes i guess, though maybe attention would give a better result on a more complex version of the problem?
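a sketch of the 'project then add' combination, assuming single dense layers with relu as the non linear projections and made up depths; the real model's projection layers may well differ.

```python
import jax
import jax.numpy as jnp

F, E, J = 64, 32, 48   # scene depth, embedding depth, joint depth; all assumed values
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "scene_w": jax.random.normal(k1, (F, J)) / jnp.sqrt(F),
    "obj_w": jax.random.normal(k2, (E, J)) / jnp.sqrt(E),
}

def combine(params, scene_features, embedding):
    # scene_features: (10, 10, F), embedding: (E,)
    scene_proj = jax.nn.relu(scene_features @ params["scene_w"])  # (10, 10, J)
    obj_proj = jax.nn.relu(embedding @ params["obj_w"])           # (J,)
    # elementwise addition; (J,) broadcasts across all 10x10 positions
    return scene_proj + obj_proj

scene = jax.random.normal(k3, (10, 10, F))
emb = jax.random.normal(k4, (E,))
combined = combine(params, scene, emb)  # (10, 10, J)
```

projecting both sides to the same depth J is what makes the plain addition well defined.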

training model

now consider the additional parts required for the training model...

the main change is around the object reference branch.

  1. since a scene has many objects of interest we want to train against the masks of all the objects present, not just one. this means including an additional C axis ( one entry per object class ).
  2. since we want a set of embeddings that will generalise well to novel examples, we jointly train the embedding model with a contrastive loss. this requires running the embedding model on pairs of example sets ( anchors and positives ) that then go to the contrastive loss. note: we only pass the anchors down to the scene branch.

the joint training means that the embeddings are forced to both

  1. generalise well based on the contrastive loss, but also
  2. be useful for features to include in the scene branch.

though the main change is around the contrastive loss, we also need to do extra broadcasting in the scene branch. in the inference model we needed to broadcast the embedding from (E) to (10, 10, E), but now the broadcasting will be from (C, E) to (C, 10, 10, E), which requires the scene features to also be broadcast from (10, 10, F) to (C, 10, 10, F).
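the same broadcasting written out explicitly, with made up sizes C=3, E=32, F=64. an alternative in jax would be to vmap the single class combine step over the C axis while leaving the scene argument unmapped ( in_axes=None ).

```python
import jax.numpy as jnp

C, E, F = 3, 32, 64
embeddings = jnp.arange(C * E, dtype=jnp.float32).reshape(C, E)                  # (C, E) anchors
scene_features = jnp.arange(10 * 10 * F, dtype=jnp.float32).reshape(10, 10, F)   # (10, 10, F)

# (C, E) -> (C, 1, 1, E) -> (C, 10, 10, E)
emb_b = jnp.broadcast_to(embeddings[:, None, None, :], (C, 10, 10, E))
# (10, 10, F) -> (1, 10, 10, F) -> (C, 10, 10, F); every class sees the same scene
scene_b = jnp.broadcast_to(scene_features[None], (C, 10, 10, F))
```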

note that the outputs across C remain independent predictions, i.e. there is no softmax across the classes.
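a sketch of what 'independent predictions across C' means for the output and the loss, assuming per pixel logits of shape (C, H, W) and a plain binary cross entropy per class and pixel:

```python
import jax
import jax.numpy as jnp

def mask_probs(logits):
    # logits: (C, H, W); a sigmoid per class and pixel, no softmax across C
    return jax.nn.sigmoid(logits)

def mask_loss(logits, masks):
    # per class, per pixel binary cross entropy; masks are 0/1 with shape (C, H, W)
    return jnp.mean(jnp.logaddexp(0.0, logits) - masks * logits)

logits = jnp.array([[[2.0, -1.0]], [[3.0, -2.0]]])   # (C=2, H=1, W=2)
masks = jnp.array([[[1.0, 0.0]], [[1.0, 0.0]]])
probs = mask_probs(logits)   # both classes can be "on" at the same pixel
```

with a softmax the C probabilities at each pixel would have to sum to one; here changing one class's logits never changes another class's prediction.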

contrastive embeddings

as a side note, how do we train the object embeddings? pretty standard...

  • sample C classes
  • sample N examples of each class as 'anchors' ( with randomly coloured backgrounds )
  • calculate their embeddings and take the mean
  • sample another set of N examples as 'positives'
  • calculate their embeddings and take the mean
  • train them so the cosine_sim(mean(anc_i), mean(pos_j)) = 1.0 for i==j and = 0.0 otherwise
example class 1 anchors ( top row ) and positives
example class 2 anchors and positives
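a sketch of the loss described in the steps above, operating on already computed embeddings of shape (C, N, E). using squared error between the (C, C) cosine similarity matrix and the identity is an assumption here, just one simple way to express the 1.0 / 0.0 targets; the actual repo may use a different formulation ( e.g. a softmax cross entropy over the similarity matrix ).

```python
import jax
import jax.numpy as jnp

def l2_normalise(x, eps=1e-8):
    return x / (jnp.linalg.norm(x, axis=-1, keepdims=True) + eps)

def contrastive_loss(anchor_embs, positive_embs):
    # anchor_embs, positive_embs: (C, N, E) embeddings of N views for each of C classes
    anc = l2_normalise(anchor_embs.mean(axis=1))    # (C, E) mean anchor per class
    pos = l2_normalise(positive_embs.mean(axis=1))  # (C, E) mean positive per class
    sims = anc @ pos.T                              # (C, C) cosine similarities
    target = jnp.eye(sims.shape[0])                 # 1.0 for i == j, 0.0 otherwise
    return jnp.mean((sims - target) ** 2)
```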

( for more details on contrastive learning see this tute i did on keras.io )

this trains an embedding space for sets of N examples that generalises to new classes. the random backgrounds were included to make the training more robust.

model definition and fitting

this all requires a non trivial fit loop since there's no "standard" sense of a batch anywhere, and the contrastive embeddings in this model are a great example of the power of jax.vmap IMHO.

an image model will by default operate with a leading batch dim (B, H, W, 3). since we want to embed N objects this fits nicely as (N, H, W, 3) -> (N, E), after which we take the mean (E) "outside" of the model forward pass.

but now we want to do this for C classes, so our data is (C, N, H, W, 3). the classic approach, since the CxN examples are independent, would be to reshape to (CN, H, W, 3), run through to (CN, E), then reshape back to (C, N, E) and finally take the mean over the second axis to get (C, E).

though it's totally doable as a manual thing, it's a lot simpler in my mind to compose this instead as just a vmap over the operation (N, H, W, 3) -> (N, E) -> (E) ( especially since that last mean step happens outside the model forward pass of batch=N ).
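a sketch showing the two approaches give the same (C, E) result, using a hypothetical embed function ( flatten plus random projection ) in place of the real network and small made up sizes:

```python
import jax
import jax.numpy as jnp

C, N, H, W, E = 3, 4, 8, 8, 16
proj = jax.random.normal(jax.random.PRNGKey(0), (H * W * 3, E)) * 0.1

def embed(views):
    # hypothetical embedding network: (N, H, W, 3) -> (N, E)
    return jnp.tanh(views.reshape(views.shape[0], -1) @ proj)

def mean_embedding(views):
    # the single class operation: (N, H, W, 3) -> (N, E) -> (E,)
    return embed(views).mean(axis=0)

x = jax.random.uniform(jax.random.PRNGKey(1), (C, N, H, W, 3))

# classic approach: fold C into the batch axis, then unfold and reduce
flat = embed(x.reshape(C * N, H, W, 3))        # (C*N, E)
manual = flat.reshape(C, N, E).mean(axis=1)    # (C, E)

# vmap approach: map the single class function over the leading C axis
vmapped = jax.vmap(mean_embedding)(x)          # (C, E)
```

the vmap version never mentions C at all inside mean_embedding, which is what makes it easy to nest further.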

furthermore there's another, bigger reason to get vmap to do this for us...

so far we've looked at a model that has two inputs:

  1. object references (C, 2, N, 64, 64, 3), where the 2 is the anchor / positive pair, and
  2. scene (640, 640, 3)

but to get the best result in terms of gradient variance we actually want to batch the entire composite model and give inputs

  1. object references (B, C, 2, N, 64, 64, 3) and
  2. scene (B, 640, 640, 3)

and this nesting of vmaps inside vmaps completely handles all the required axis remapping for us.
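a toy sketch of the nesting, with tiny made up sizes in place of 64x64 / 640x640 and hypothetical embedding and scene functions ( the mask logits here are just a dot product of per pixel scene features with each anchor, not the u-net described above ). the per example function only ever sees (C, 2, N, h, w, 3) and (S, S, 3); vmap adds the pair, C and B axes.

```python
import jax
import jax.numpy as jnp

B, C, N, h, w, S, E = 2, 3, 4, 8, 8, 16, 16   # tiny stand ins for the real sizes
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
obj_proj = jax.random.normal(k1, (h * w * 3, E)) * 0.1   # hypothetical embedding weights
scene_proj = jax.random.normal(k2, (3, E)) * 0.1         # hypothetical scene weights

def mean_embedding(views):
    # (N, h, w, 3) -> (N, E) -> (E,)
    return jnp.tanh(views.reshape(views.shape[0], -1) @ obj_proj).mean(axis=0)

def single_example(obj_refs, scene):
    # obj_refs: (C, 2, N, h, w, 3) anchors + positives; scene: (S, S, 3)
    embs = jax.vmap(jax.vmap(mean_embedding))(obj_refs)          # (C, 2, E)
    anchors = embs[:, 0]                                         # (C, E); only anchors feed the scene
    scene_feats = jnp.tanh(scene @ scene_proj)                   # (S, S, E)
    mask_logits = jax.vmap(lambda a: scene_feats @ a)(anchors)   # (C, S, S)
    return embs, mask_logits

# one more vmap adds the outer batch axis B to every input and output
batched = jax.vmap(single_example)

obj_refs = jax.random.uniform(k3, (B, C, 2, N, h, w, 3))
scenes = jax.random.uniform(k4, (B, S, S, 3))
embs, mask_logits = batched(obj_refs, scenes)   # (B, C, 2, E), (B, C, S, S)
```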

basic results

here we show a basic example of inference on held out data.

  • given an input image ( far left )
  • top row is reference objects and detection for 3 best performers
  • bottom row is reference objects and detection for 3 worst performers

findings

  • keras3 and stateless_call with jax is a great combo and i'll be using this more
  • focal loss gives a non trivial bump to the performance of the scene classifier; it makes me wonder if an equivalent version could be used for the contrastive loss.
  • adding a stop gradient between the scene loss and the object embeddings gives the result we'd expect; the embedding model has a lower loss but the overall classifier performs worse.
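the post doesn't say which focal loss variant was used; as an assumption, here's the common sigmoid ( binary ) form applied to per pixel logits, where gamma=2.0 and alpha=0.25 are typical defaults rather than values from the repo:

```python
import jax
import jax.numpy as jnp

def binary_focal_loss(logits, labels, gamma=2.0, alpha=0.25):
    # sigmoid focal loss; gamma / alpha defaults are common choices, not from the post
    p = jax.nn.sigmoid(logits)
    ce = jnp.logaddexp(0.0, logits) - labels * logits       # per element BCE = -log(p_t)
    p_t = labels * p + (1.0 - labels) * (1.0 - p)           # probability of the true label
    alpha_t = labels * alpha + (1.0 - labels) * (1.0 - alpha)
    # (1 - p_t) ** gamma down weights easy, already well classified pixels
    return jnp.mean(alpha_t * (1.0 - p_t) ** gamma * ce)

logits = jnp.array([4.0, 0.0, -3.0])
labels = jnp.array([1.0, 1.0, 0.0])
loss = binary_focal_loss(logits, labels)
```

with gamma=0 and alpha=0.5 this reduces to half the plain binary cross entropy, which makes the down weighting of easy pixels the only real change.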

code

all the code on github