

yolz; you only look o̶n̶c̶e̶ zero times

October 26, 2024 at 06:45 PM | categories: keras3, jax

what does zero shot mean?

transfer learning

transfer learning is arguably the greatest success story of classifier neural nets. the idea is that an N layer model can be thought of as an N-1 layer feature extractor followed by a 'simple' logistic regression. transfer learning works because the large N-1 layer feature extractor is reusable across related domains, so you can get a long way just by retraining the classifier head to get a second model.
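
as a minimal sketch of this in keras3 ( the backbone choice, input size and class count here are illustrative assumptions, not from any particular project ):

```python
import keras

# a pretrained "first model" backbone; ResNet50 is just a stand-in
# for any N-1 layer feature extractor.
backbone = keras.applications.ResNet50(include_top=False, pooling="avg")
backbone.trainable = False  # freeze the reusable feature extractor

# bolt a fresh 'simple' logistic-regression-style head on top and
# retrain only it for the second model's classes.
NUM_NEW_CLASSES = 10  # assumption; set to the second task's class count
inputs = keras.Input(shape=(224, 224, 3))
features = backbone(inputs, training=False)
outputs = keras.layers.Dense(NUM_NEW_CLASSES, activation="softmax")(features)
model = keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
```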

note that this works best when the first model was trained on a set of classes that is, somehow, related to the set of classes in the second model.

( for more details see this talk i did on self supervised learning )

what is few shot learning?

the fact that transfer learning even works is kind of an accident ( in the sense that it was never the objective of training the first model ). however there's no reason you can't make it an objective. techniques such as meta learning focus on this and aim to train a model that can be adapted with transfer learning using only a few examples, sometimes as few as a single one per class ( e.g. reptile )

( for more details see this talk i did on learning to learn )

what is zero shot?

with transfer learning you have to provide a bunch of new examples during retraining. with few shot learning you only need to provide a few examples during retraining. but can you train a model that generalises to new classes without needing to retrain at all?

a zero shot learner does exactly this; the main technique is that the new class to generalise to is somehow provided as part of the input.

note how LLMs are the ultimate in zero shot learning, since they use an input that is one of the richest representations we know: language.

what is zero shot detection?

to do zero shot detection we need a model that takes as input not just the 'standard' image we want to find objects in, but also examples of the thing we're actually trying to detect.

we'll call the image the 'scene' and the inputs related to the object of interest the 'object reference'.

so that the model is robust to unseen items we'll learn an object embedding using contrastive learning. and since we also want the model to be robust to different angles of the object, we'll take multiple views as input for the object reference...

[ figure: object reference and scene on the left, the network, and example detections on the right ]

note that this model has a single "class" output, P(object), not any form of softmax.
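
a rough sketch of the shapes involved ( the layer sizes, names and architecture here are illustrative assumptions, not the actual yolz network ):

```python
import keras
from keras import layers, ops

N_VIEWS, REF_HW, SCENE_HW = 8, 64, 256  # assumed sizes

# object embedding sub model: embeds a single reference view.
view = keras.Input((REF_HW, REF_HW, 3))
x = layers.Conv2D(32, 3, strides=2, padding="same", activation="relu")(view)
x = layers.Conv2D(64, 3, strides=2, padding="same", activation="relu")(x)
x = layers.GlobalAveragePooling2D()(x)
embed_net = keras.Model(view, layers.Dense(128)(x))

# full model: mean-pool the N view embeddings, extract a 16x16 grid
# of scene features, and dot-product them to get P(object) per cell.
refs = keras.Input((N_VIEWS, REF_HW, REF_HW, 3))
obj_e = ops.mean(layers.TimeDistributed(embed_net)(refs), axis=1)  # (B, 128)

scene = keras.Input((SCENE_HW, SCENE_HW, 3))
f = layers.Conv2D(32, 3, strides=4, padding="same", activation="relu")(scene)
f = layers.Conv2D(128, 3, strides=4, padding="same", activation="relu")(f)  # (B, 16, 16, 128)

# single sigmoid output per grid cell; no softmax anywhere.
p_object = ops.sigmoid(ops.einsum("bghe,be->bgh", f, obj_e))  # (B, 16, 16)
model = keras.Model([scene, refs], p_object)
```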

contrastive embeddings

firstly how do we train the object embeddings?

  • sample C classes
  • sample N examples of each class as 'anchors'
  • calculate their embeddings and take the mean
  • sample another set of N examples as 'positives'
  • calculate their embeddings and take the mean
  • train them so that cosine_sim(mean(anc_i), mean(pos_j)) = 1.0 for i==j and 0.0 otherwise ( see the sketch below )
[ figure: example class 1 anchors ( top row ) and positives ]
[ figure: example class 2 anchors and positives ]

( for more details on contrastive learning see this tute i did on keras.io )
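
as a jax sketch of that loss ( the actual yolz loss may differ; the mean squared error against an identity target matrix here just mirrors the 1.0/0.0 description above ):

```python
import jax.numpy as jnp

def contrastive_loss(anc_e, pos_e):
    # anc_e, pos_e: (C, N, E) embeddings of anchors and positives.
    # mean over the N examples per class, then l2 normalise so the
    # matrix product below gives cosine similarities.
    anc_m = jnp.mean(anc_e, axis=1)
    pos_m = jnp.mean(pos_e, axis=1)
    anc_m /= jnp.linalg.norm(anc_m, axis=-1, keepdims=True)
    pos_m /= jnp.linalg.norm(pos_m, axis=-1, keepdims=True)
    sims = anc_m @ pos_m.T            # (C, C) cosine similarities
    targets = jnp.eye(sims.shape[0])  # 1.0 for i==j, 0.0 otherwise
    return jnp.mean((sims - targets) ** 2)
```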

this trains an embedding space over sets of N examples that generalises to new classes.

detection

we jointly train a second sub model that extracts scene features to be combined with the object embeddings. the features from the scene and the embedding of the object of interest are combined to produce a grid of detections.
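
one way the joint objective might look ( a sketch only; the equal weighting and the use of optax's BCE helper are assumptions, and contrastive_loss is from the sketch above ):

```python
import jax.numpy as jnp
import optax

def joint_loss(anc_e, pos_e, grid_logits, grid_labels):
    # contrastive term on the object embeddings plus a per-cell
    # binary cross entropy on the P(object) detection grid.
    c_loss = contrastive_loss(anc_e, pos_e)
    d_loss = jnp.mean(optax.sigmoid_binary_cross_entropy(grid_logits, grid_labels))
    return c_loss + d_loss  # assumption: equal weighting of the two terms
```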

this requires a non trivial fit loop since there's no "standard" sense of a batch. keras3's stateless_call with the jax backend provides a lot of flexibility here!
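
a minimal sketch of the pattern, following the keras.io custom-training-in-jax guide and reusing the illustrative model from above ( the loss and batch shapes are stand-ins ):

```python
import jax
import keras
import numpy as np

def compute_loss(trainable_vars, non_trainable_vars, scene, refs, labels):
    # stateless_call threads all model state in and out explicitly,
    # so this function is pure and can be jit'd and differentiated.
    p_object, non_trainable_vars = model.stateless_call(
        trainable_vars, non_trainable_vars, [scene, refs])
    loss = keras.ops.mean(keras.losses.binary_crossentropy(labels, p_object))
    return loss, non_trainable_vars

grad_fn = jax.jit(jax.value_and_grad(compute_loss, has_aux=True))

# stand-in batch data, just to show the call.
scene_b = np.random.rand(2, SCENE_HW, SCENE_HW, 3).astype("float32")
refs_b = np.random.rand(2, N_VIEWS, REF_HW, REF_HW, 3).astype("float32")
labels_b = np.zeros((2, 16, 16), "float32")

trainable = [v.value for v in model.trainable_variables]
non_trainable = [v.value for v in model.non_trainable_variables]
(loss, non_trainable), grads = grad_fn(
    trainable, non_trainable, scene_b, refs_b, labels_b)
```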

for inference we can 1) completely drop the leading C axis and just run for a single novel object, and 2) run with just one set of N reference objects ( i.e. we don't need anchors and positives ).
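
reusing the illustrative model from the shape sketch above, inference might look like:

```python
import numpy as np

# one novel object: a single set of N reference views, no anchors vs
# positives split and no leading C classes axis.
novel_refs = np.random.rand(1, N_VIEWS, REF_HW, REF_HW, 3)  # stand-in data
scene_img = np.random.rand(1, SCENE_HW, SCENE_HW, 3)        # stand-in data
p_object = model.predict([scene_img, novel_refs])           # (1, 16, 16) grid
```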

note: the next step of model development would be to not just do detection but regression to an actual bounding box ( as in the more general yolo )

first round results

this was a time boxed project and the first results were good enough that i think the project proves a point.

[ figure: general example; model picks up novel objects ok ]
[ figure: object confusion; gets confused by objects of similar shape / colour ]
[ figures: object references ]

findings

  • keras3 and stateless_call with jax is a great combo and i'll be using this more
  • focal loss gives a non trivial bump to the performance of the scene classifier; makes me wonder if an equivalent version can be used for the contrastive loss ( see the sketch below )
  • having a stop gradient from the scene loss to the object embeddings gives the result we expect: the embeddings model has a lower loss but the overall classifier performs worse.
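
for reference, a sketch of binary focal loss on the P(object) grid ( the gamma/alpha values are the usual defaults from lin et al. 2017, not necessarily what was used here ):

```python
import jax
import jax.numpy as jnp

def binary_focal_loss(logits, labels, gamma=2.0, alpha=0.25):
    # down-weights easy, well-classified grid cells so the rare
    # object-containing cells dominate the gradient.
    p = jax.nn.sigmoid(logits)
    p_t = labels * p + (1 - labels) * (1 - p)
    alpha_t = labels * alpha + (1 - labels) * (1 - alpha)
    bce = -jnp.log(jnp.clip(p_t, 1e-7, 1.0))
    return jnp.mean(alpha_t * (1 - p_t) ** gamma * bce)

# ( the stop gradient experiment from the last finding is just a
#   jax.lax.stop_gradient on the object embeddings in the scene
#   loss path. )
```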

code

all the code is on github