brain of mat kelcey...
an illustrative einsum example
May 27, 2020 at 12:00 AM | categories: talk, short_tute
intro
recently i changed some quite clumsy looking code into something more elegant using a cool function in numpy called einsum, and felt it was worth writing up.
this example is just going to be using numpy, but einsum is also provided in tensorflow, jax and pytorch.
there's also a terser colab version of this post if you'd prefer to walk through that.
and for a video walkthrough of this see...
example walkthrough
the only user defined function we are going to use is this little rnd one to make random arrays of various sizes.
>>> import numpy as np
>>> def rnd(*args):
...     return np.random.random(args)
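e.g. a quick sanity check on the shapes it gives back.
>>> rnd(3, 4).shape
(3, 4)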
let's start with a fundamental operation, the matrix multiply. numpy provides it through the matmul function, which in this case takes an (i, j) sized matrix and a (j, k) sized one and produces an (i, k) result; the classic 2d matrix multiply.
>>> np.matmul(rnd(4, 5), rnd(5, 6)).shape
(4, 6)
there are a number of interpretations of what a matrix multiply represents, and a common one i use a lot is calculating, in batch, a bunch of pairwise dot products. if we have S1 and S2, both of which represent some set of embeddings, in this case 32d embeddings, 5 for S1 and 10 for S2, and we want to calculate all pairwise combos, we can use a matrix multiply. but note we need to massage S2 a bit; a matrix multiply wants the inner dimensions to match, so we need to transpose S2 to make it (32, 10).
>>> S1 = rnd(5, 32)
>>> S2 = rnd(10, 32)
>>> np.matmul(S1, S2.T).shape
(5, 10)
an important thing to note about matmul is that it automatically handles the idea of batching. for example a leading dimension of 10 makes it equivalent to 10 independent matrix multiplies. we could do this in a for loop, but expressing it in a single matmul will most likely be faster as the underlying linear algebra libraries can kick in with a bunch of optimisations.
>>> np.matmul(rnd(10, 4, 5),
...           rnd(10, 5, 6)).shape
(10, 4, 6)
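as a quick sanity check, here's a sketch of that for loop equivalence (batched and looped are just illustrative names, not from the original code).
>>> A, B = rnd(10, 4, 5), rnd(10, 5, 6)
>>>
>>> batched = np.matmul(A, B)
>>> looped = np.stack([np.matmul(A[n], B[n]) for n in range(10)])
>>>
>>> np.all(np.isclose(batched, looped))
True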
and note that there can be multiple leading dimensions. all that matters is that the last two axes follow the i,j,k rule.
>>> np.matmul(rnd(10, 20, 30, 4, 5),
...           rnd(10, 20, 30, 5, 6)).shape
(10, 20, 30, 4, 6)
so let's talk about the real use case that led me to write this post. we have a query that i want to compare to 10 keys using a dot product; basically an unnormalised attention / soft map lookup.
>>> E = 32
>>> query_embedding = rnd(E)
>>> keys_embeddings = rnd(10, E)
even though query is a vector we can use matmul to do this, because matmul can interpret the vector as either a row vector, as in v1 below (i.e. a matrix of shape (1, E)), or a column vector, as in v2 (a matrix of shape (E, 1); note that keys and query are swapped).
>>> v1 = np.matmul(query_embedding, keys_embeddings.T)
>>> v2 = np.matmul(keys_embeddings, query_embedding)
>>> v1.shape, v2.shape, np.all(np.isclose(v1, v2))
((10,), (10,), True)
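as an aside, if we wanted the normalised attention weights we'd just push these scores through a softmax; the little softmax helper below is an illustrative sketch, not part of the original code.
>>> def softmax(x):
...     z = np.exp(x - np.max(x))  # subtract max for numerical stability
...     return z / np.sum(z)
>>> weights = softmax(v1)
>>> weights.shape, np.isclose(np.sum(weights), 1.0)
((10,), True)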
my actual use case though is wanting to do the query to 10 keys comparison, but across a "batch" of 100 independent sets. so if we want to use matmul in the batched form we need to do a bit of massaging of these two, since the assumed behaviour of matmul will no longer work.
>>> N = 100
>>> E = 32
>>> query_embeddings = rnd(N, E)
>>> keys_embeddings = rnd(N, 10, E)
firstly we need to be more explicit that the query embeddings represent N row vectors.
>>> query_embeddings.reshape(N, 1, E).shape
(100, 1, 32)
secondly we need to do the transpose for keys, but this time we can't just use .T; we need to be explicit about keeping the leading axis the same, swapping only the final two.
>>> keys_embeddings.transpose(0, 2, 1).shape
(100, 32, 10)
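( equivalently, np.swapaxes does the same swap of the final two axes. )
>>> np.swapaxes(keys_embeddings, 1, 2).shape
(100, 32, 10)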
after doing these two transforms we get the matrices that will trigger the batch matrix multiply behaviour we want. note that we end up with that inner dimension of 1 still around.
>>> np.matmul(query_embeddings.reshape(N, 1, E),
...           keys_embeddings.transpose(0, 2, 1)).shape
(100, 1, 10)
and if we don't want that inner axis, we can explicitly squeeze
it away.
>>> np.squeeze(np.matmul(query_embeddings.reshape(N, 1, E),
...                      keys_embeddings.transpose(0, 2, 1))).shape
(100, 10)
this all works, but it felt clumsy to me. so let's go back through these examples using einsum, which gives us more explicit ways to define the calculation without the massaging of query and keys to make matmul work. we'll redo the matmul versions first (denoted by m) followed by the einsum equivalents (denoted by e).
let's start with matrix multiply. einsum takes a subscripts arg which describes the computation we want to do. the first part, ij,jk, is a comma separated list of the dimensions of the inputs, in this case A and B. the first input, A, is 2d with axes we're naming i and j. the second input, B, is 2d also, with axes we're naming j and k. the output we want, ->ik, is 2d and takes the axes i and k, with a reduction along j. this is exactly a matrix multiply.
>>> A, B = rnd(4, 5), rnd(5, 6)
>>>
>>> m = np.matmul(A, B)
>>>
>>> e = np.einsum('ij,jk->ik', A, B)
>>>
>>> m.shape, e.shape, np.all(np.isclose(m, e))
((4, 6), (4, 6), True)
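( a small aside: if we drop the ->ik entirely, einsum falls back to the implicit einstein convention and still sums over the repeated j; i find spelling out the output clearer though. )
>>> np.einsum('ij,jk', A, B).shape
(4, 6)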
consider again the case of two sets of 32d embeddings, S1 and S2, and wanting to get all the pairwise products. since the j is now the second axis of both inputs we don't need the transpose.
>>> S1, S2 = rnd(5, 32), rnd(10, 32)
>>>
>>> m = np.matmul(S1, S2.T)
>>>
>>> e = np.einsum('ij,kj->ik', S1, S2)
>>>
>>> m.shape, e.shape, np.all(np.isclose(m, e))
((5, 10), (5, 10), True)
in fact in einsum we can use whatever letters we like, so we're free to use the character subscripts as a weak form of documentation. we've basically got one letter to describe what each axis represents.
>>> S1, S2 = rnd(5, 32), rnd(10, 32)
>>>
>>> m = np.matmul(S1, S2.T)
>>>
>>> e = np.einsum('ae,be->ab', S1, S2)
>>>
>>> m.shape, e.shape, np.all(np.isclose(m, e))
((5, 10), (5, 10), True)
next let's consider the batch form of a matrix multiply. recall: matmul handles this by default, but with einsum we have to be explicit about it. even still, it's arguably clearer, since we end up not having to understand the assumed behaviour of matmul.
>>> A, B = rnd(10, 4, 5), rnd(10, 5, 6)
>>>
>>> m = np.matmul(A, B)
>>>
>>> e = np.einsum('nij,njk->nik', A, B)
>>>
>>> m.shape, e.shape, np.all(np.isclose(m, e))
((10, 4, 6), (10, 4, 6), True)
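and if we want the multiple leading dimensions version from before, einsum supports an ellipsis for any number of batch axes.
>>> A, B = rnd(10, 20, 30, 4, 5), rnd(10, 20, 30, 5, 6)
>>>
>>> np.einsum('...ij,...jk->...ik', A, B).shape
(10, 20, 30, 4, 6)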
going back to our query vs keys comparison, let's consider the single instance case. one main thing that's different is that we don't need any assumption about query being interpretable as a row or column matrix; we can just explicitly describe the axes (in this case query is just the 1d e).
>>> E = 32
>>>
>>> query_embedding = rnd(E)
>>> keys_embeddings = rnd(10, E)
>>>
>>> m = np.matmul(query_embedding, keys_embeddings.T)
>>>
>>> e = np.einsum('e,ke->k', query_embedding, keys_embeddings)
>>>
>>> m.shape, e.shape, np.all(np.isclose(m, e))
((10,), (10,), True)
and finally consider the batched version, where we see the einsum version ends up being much simpler. we don't rely on any behaviour of matmul that forces us to change things to particular shapes. as before, because we're explicit about what we want in terms of reduction, we finish with a much more direct expression.
>>> N = 100
>>> E = 32
>>>
>>> query_embeddings = rnd(N, E)
>>> keys_embeddings = rnd(N, 10, E)
>>>
>>> m = np.squeeze(np.matmul(query_embeddings.reshape(N, 1, E),
...                          keys_embeddings.transpose(0, 2, 1)))
>>>
>>> e = np.einsum('ne,nke->nk', query_embeddings, keys_embeddings)
>>>
>>> m.shape, e.shape, np.all(np.isclose(m, e))
((100, 10), (100, 10), True)
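as one last sanity check, a sketch comparing against a plain python loop (looped is again just an illustrative name).
>>> looped = np.stack([np.matmul(keys_embeddings[n], query_embeddings[n])
...                    for n in range(N)])
>>>
>>> np.all(np.isclose(e, looped))
True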