brain of mat kelcey


an illustrative einsum example

May 27, 2020

recently i changed some quite clumsy looking code into something more elegant using a cool numpy function called einsum, and felt it was worth writing up.

this example just uses numpy, but einsum is also provided in tensorflow, jax and pytorch.

there's also a terser colab version of this post if you'd prefer to walk through that.

and for a video walkthrough, see this youtube clip.

the only user defined function we are going to use is this little rnd one to make random arrays of various sizes.

>>> import numpy as np

>>> def rnd(*args):
...     return np.random.random(args)

let's start with a fundamental operation, the matrix multiply. numpy provides it through the matmul function, which in this case takes an (i, j) sized matrix and a (j, k) sized one and produces an (i, k) result; the classic 2d matrix multiply.

>>> np.matmul(rnd(4, 5), rnd(5, 6)).shape

(4, 6)

there are a number of interpretations of what a matrix multiply represents, and one i use a lot is calculating, in batch, a bunch of pairwise dot products. say we have S1 and S2, each representing a set of 32d embeddings, 5 in S1 and 10 in S2. if we want to calculate all pairwise combos we can use a matrix multiply, but note we need to massage S2 a bit; a matrix multiply wants the inner dimensions to match, so we need to transpose S2 to make it (32, 10).

>>> S1 = rnd(5, 32)
>>> S2 = rnd(10, 32)
>>> np.matmul(S1, S2.T).shape

(5, 10)
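as a quick sanity check of the pairwise dot product interpretation, here's a small sketch: entry [a, b] of the result should be the dot product of row a of S1 with row b of S2. the explicit loop is just for illustration, not something you'd do for real.

```python
import numpy as np

S1 = np.random.random((5, 32))   # 5 embeddings, 32d each
S2 = np.random.random((10, 32))  # 10 embeddings, 32d each

pairwise = np.matmul(S1, S2.T)   # shape (5, 10)

# recompute every entry as an explicit dot product of one row of S1 with one row of S2
expected = np.array([[np.dot(S1[a], S2[b]) for b in range(10)] for a in range(5)])

print(pairwise.shape, np.allclose(pairwise, expected))  # (5, 10) True
```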

an important thing to note about matmul is that it automatically handles batching. for example, a leading dimension of 10 makes it equivalent to 10 independent matrix multiplies. we could do this in a for loop, but expressing it as a single matmul will most likely be faster since the underlying linear algebra libraries can kick in with a bunch of optimisations.

>>> np.matmul(rnd(10, 4, 5),
...           rnd(10, 5, 6)).shape

(10, 4, 6)
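to make the "10 independent matrix multiplies" reading concrete, a minimal sketch comparing the single batched matmul against an explicit python loop over the leading axis, stacked back together:

```python
import numpy as np

A = np.random.random((10, 4, 5))
B = np.random.random((10, 5, 6))

batched = np.matmul(A, B)  # one call, shape (10, 4, 6)

# the same thing as 10 independent 2d matrix multiplies, stacked back together
looped = np.stack([np.matmul(A[n], B[n]) for n in range(10)])

print(batched.shape, np.allclose(batched, looped))  # (10, 4, 6) True
```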

and note that there can be multiple leading dimensions. all that matters is that the last two axes follow the (i, j) by (j, k) rule.

>>> np.matmul(rnd(10, 20, 30, 4, 5),
...           rnd(10, 20, 30, 5, 6)).shape

(10, 20, 30, 4, 6)

so let's talk about the real use case that led me to write this post. we have a query that i want to compare to 10 keys by using a dot product; basically an unnormalised attention / soft map lookup.

>>> E = 32
>>> query_embedding = rnd(E)
>>> keys_embeddings = rnd(10, E)

even though the query is a vector we can still use matmul, because matmul will interpret a 1d argument as either a row vector, as in v1 below (i.e. a matrix of shape (1, E)), or a column vector, as in v2 (a matrix of shape (E, 1)); note that in v2 the keys and query are swapped.

>>> v1 = np.matmul(query_embedding, keys_embeddings.T)
>>> v2 = np.matmul(keys_embeddings, query_embedding)
>>> v1.shape, v2.shape, np.all(np.isclose(v1, v2))

((10,), (10,), True)
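to see what matmul is doing implicitly here, a sketch that does the row / column promotion by hand with reshape. the explicit versions keep the extra axis of size 1, whereas with a 1d argument matmul adds that axis for the multiply and then removes it again.

```python
import numpy as np

E = 32
query_embedding = np.random.random(E)
keys_embeddings = np.random.random((10, E))

# v1 style: query as an explicit row vector, (1, E) @ (E, 10) -> (1, 10)
row = np.matmul(query_embedding.reshape(1, E), keys_embeddings.T)
# v2 style: query as an explicit column vector, (10, E) @ (E, 1) -> (10, 1)
col = np.matmul(keys_embeddings, query_embedding.reshape(E, 1))

v1 = np.matmul(query_embedding, keys_embeddings.T)
v2 = np.matmul(keys_embeddings, query_embedding)

print(row.shape, col.shape)  # (1, 10) (10, 1)
print(np.allclose(row[0], v1), np.allclose(col[:, 0], v2))  # True True
```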

my actual use case, though, is doing this query to 10 keys comparison across a "batch" of 100 independent sets. to use matmul in its batched form we need to do a bit of massaging of both arrays, since the default behaviour of matmul no longer gives us what we want.

>>> N = 100
>>> E = 32
>>> query_embeddings = rnd(N, E)
>>> keys_embeddings = rnd(N, 10, E)
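to see what goes wrong without the massaging, a sketch of two naive attempts. reusing .T fails outright, and transposing only the last two axes of keys (while leaving query as a plain (N, E) matrix) runs but broadcasts every query against every set of keys, giving an (N, N, 10) result of which we only want the "diagonal" entries. this is illustrative only; it does N times more work than needed.

```python
import numpy as np

N, E = 100, 32
query_embeddings = np.random.random((N, E))
keys_embeddings = np.random.random((N, 10, E))

# .T reverses *all* axes of keys, giving (E, 10, N); the inner dims no longer line up
try:
    np.matmul(query_embeddings, keys_embeddings.T)
    failed = False
except ValueError:
    failed = True

# swapping only the last two axes runs, but the (N, E) query is treated as one matrix
# and broadcast against all N key sets: naive[m, n, k] = query[n] . keys[m, k]
naive = np.matmul(query_embeddings, keys_embeddings.transpose(0, 2, 1))

# we only want entries where query and key set share the same n
wanted = naive[np.arange(N), np.arange(N)]  # shape (N, 10)
expected = np.array([keys_embeddings[n] @ query_embeddings[n] for n in range(N)])

print(failed, naive.shape, wanted.shape, np.allclose(wanted, expected))
# True (100, 100, 10) (100, 10) True
```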

firstly we need to be more explicit that the query embeddings represent N row vectors.

>>> query_embeddings.reshape(N, 1, E).shape

(100, 1, 32)

secondly, we need to transpose the keys, but this time we can't just use .T (which reverses all three axes); we need to be explicit about keeping the leading axis in place and swapping only the final two.

>>> keys_embeddings.transpose(0, 2, 1).shape

(100, 32, 10)
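a short sketch of the difference, comparing .T with transpose(0, 2, 1) and with np.swapaxes, which is another way to swap just two named axes:

```python
import numpy as np

keys_embeddings = np.random.random((100, 10, 32))

print(keys_embeddings.T.shape)                   # (32, 10, 100): .T reverses every axis
print(keys_embeddings.transpose(0, 2, 1).shape)  # (100, 32, 10): swap only the last two
print(np.swapaxes(keys_embeddings, 1, 2).shape)  # (100, 32, 10): same, via swapaxes

# both "swap the last two axes" forms produce identical arrays
assert np.array_equal(keys_embeddings.transpose(0, 2, 1),
                      np.swapaxes(keys_embeddings, 1, 2))
```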

after these two transforms we have arrays that trigger the batched matrix multiply behaviour we want. note that the inner axis of size 1, from reshaping the query, is still around.

>>> np.matmul(query_embeddings.reshape(N, 1, E),
...           keys_embeddings.transpose(0, 2, 1)).shape

(100, 1, 10)

and if we don't want that inner axis, we can explicitly squeeze it away.

>>> np.squeeze(np.matmul(query_embeddings.reshape(N, 1, E),
...                      keys_embeddings.transpose(0, 2, 1))).shape

(100, 10)
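one caveat with the squeeze above: np.squeeze with no axis argument removes every axis of size 1, so with a batch of one (N = 1) it would drop the batch axis too. a sketch of a slightly safer variant that names the axis; batched_scores is a hypothetical helper, not part of numpy.

```python
import numpy as np

def batched_scores(query_embeddings, keys_embeddings):
    # hypothetical helper: (N, E) queries vs (N, K, E) keys -> (N, K) dot products
    n, e = query_embeddings.shape
    scores = np.matmul(query_embeddings.reshape(n, 1, e),
                       keys_embeddings.transpose(0, 2, 1))  # (N, 1, K)
    return np.squeeze(scores, axis=1)  # only drop the inner axis of size 1

# with a batch of one, an unqualified np.squeeze also drops the N axis
q, k = np.random.random((1, 32)), np.random.random((1, 10, 32))
unsafe = np.squeeze(np.matmul(q.reshape(1, 1, 32), k.transpose(0, 2, 1)))

print(unsafe.shape, batched_scores(q, k).shape)  # (10,) (1, 10)
```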

this all works, but it felt clumsy to me. so let's go back through these examples using einsum, which lets us describe the calculation explicitly without massaging query and keys into shapes that make matmul work.

for each example we'll do the matmul version first (denoted by m), followed by the einsum equivalent (denoted by e).

let's start with matrix multiply. einsum takes a subscripts arg which describes the computation we want to do. the first part, ij,jk, is a comma separated list of the axis labels of the inputs, in this case A and B.

the first input, A, is 2d with axes we're naming i and j. the second input, B, is also 2d, with axes we're naming j and k.

the output we want, ->ik, is 2d with axes i and k; since j appears in the inputs but not in the output, it's summed over. this is exactly a matrix multiply.

>>> A, B = rnd(4, 5), rnd(5, 6)
>>>
>>> m = np.matmul(A, B)
>>>
>>> e = np.einsum('ij,jk->ik', A, B)
>>>
>>> m.shape, e.shape, np.all(np.isclose(m, e))

((4, 6), (4, 6), True)
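to spell out what the subscripts mean, here's a deliberately slow sketch of 'ij,jk->ik' as explicit loops: every label in the output is a loop we keep, and j, the label missing from the output, is a loop we sum over.

```python
import numpy as np

A, B = np.random.random((4, 5)), np.random.random((5, 6))

# 'ij,jk->ik' spelled out: loop over i and k (they appear in the output),
# accumulate over j (it only appears in the inputs, so it is reduced)
out = np.zeros((4, 6))
for i in range(4):
    for k in range(6):
        for j in range(5):
            out[i, k] += A[i, j] * B[j, k]

print(np.allclose(out, np.einsum('ij,jk->ik', A, B)))  # True
```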

consider again the case of two sets of 32d embeddings, S1 and S2, where we want all the pairwise dot products. since we can simply label j as the second axis of S2, we don't need the transpose.

>>> S1, S2 = rnd(5, 32), rnd(10, 32)
>>>
>>> m = np.matmul(S1, S2.T)
>>>
>>> e = np.einsum('ij,kj->ik', S1, S2)
>>>
>>> m.shape, e.shape, np.all(np.isclose(m, e))

((5, 10), (5, 10), True)

in fact, in einsum we can use whatever letters we like, so we're free to use the subscripts as a weak form of documentation; each axis gets one letter to describe what it represents.

>>> S1, S2 = rnd(5, 32), rnd(10, 32)
>>>
>>> m = np.matmul(S1, S2.T)
>>>
>>> e = np.einsum('ae,be->ab', S1, S2)
>>>
>>> m.shape, e.shape, np.all(np.isclose(m, e))

((5, 10), (5, 10), True)

next let's consider the batch form of a matrix multiply. recall that matmul handles this by default, whereas with einsum we have to be explicit about it. even so, it's arguably clearer, since we don't need to understand matmul's implicit batching behaviour.

>>> A, B = rnd(10, 4, 5), rnd(10, 5, 6)
>>>
>>> m = np.matmul(A, B)
>>>
>>> e = np.einsum('nij,njk->nik', A, B)
>>>
>>> m.shape, e.shape, np.all(np.isclose(m, e))

((10, 4, 6), (10, 4, 6), True)
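einsum can also mirror matmul's support for any number of leading dimensions: an ellipsis, ..., in the subscripts stands for "whatever leading axes there are". a sketch against the earlier (10, 20, 30, 4, 5) example:

```python
import numpy as np

A = np.random.random((10, 20, 30, 4, 5))
B = np.random.random((10, 20, 30, 5, 6))

m = np.matmul(A, B)
# '...' matches all the leading batch axes, so only the last two need naming
e = np.einsum('...ij,...jk->...ik', A, B)

print(m.shape, e.shape, np.allclose(m, e))  # (10, 20, 30, 4, 6) (10, 20, 30, 4, 6) True
```

the trade off is that we're back to leaving the batch axes unnamed, so for a fixed, known number of batch axes naming them explicitly is arguably the more self documenting choice.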

going back to our query vs keys comparison, let's consider the single instance case. the main difference is that we don't need to assume anything about the query being interpreted as a row or column vector; we just explicitly describe its axis (in this case the single axis e).

>>> E = 32
>>>
>>> query_embedding = rnd(E)
>>> keys_embeddings = rnd(10, E)
>>>
>>> m = np.matmul(query_embedding, keys_embeddings.T)
>>>
>>> e = np.einsum('e,ke->k', query_embedding, keys_embeddings)
>>>
>>> m.shape, e.shape, np.all(np.isclose(m, e))

((10,), (10,), True)

and finally, consider the batched version, where the einsum version ends up being much simpler. we don't rely on any matmul behaviour that forces us to reshape or transpose things into particular shapes. as before, because einsum makes us be explicit about which axes are kept and which are reduced, we finish with a much more direct expression.

>>> N = 100
>>> E = 32
>>>
>>> query_embeddings = rnd(N, E)
>>> keys_embeddings = rnd(N, 10, E)
>>>
>>> m = np.squeeze(np.matmul(query_embeddings.reshape(N, 1, E),
...                          keys_embeddings.transpose(0, 2, 1)))
>>>
>>> e = np.einsum('ne,nke->nk', query_embeddings, keys_embeddings)
>>>
>>> m.shape, e.shape, np.all(np.isclose(m, e))

((100, 10), (100, 10), True)
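one way to read 'ne,nke->nk' is "multiply elementwise, lining up the n and e axes, then sum out e". a sketch of that reading using plain broadcasting; it materialises the full (N, 10, E) product, so it's for intuition rather than speed.

```python
import numpy as np

N, E = 100, 32
query_embeddings = np.random.random((N, E))
keys_embeddings = np.random.random((N, 10, E))

e = np.einsum('ne,nke->nk', query_embeddings, keys_embeddings)

# insert a size 1 k axis into the queries so each one broadcasts against its 10 keys,
# multiply elementwise to get (N, 10, E), then sum away the e axis
b = (query_embeddings[:, np.newaxis, :] * keys_embeddings).sum(axis=-1)

print(e.shape, b.shape, np.allclose(e, b))  # (100, 10) (100, 10) True
```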