# brain of mat kelcey

## an illustrative einsum example

May 27, 2020

recently i changed what was quite clumsy looking code into something more elegant using a cool function in numpy called einsum and felt it was worth writing up.

this example is just going to be using numpy, but einsum is also provided in tensorflow, jax and pytorch

there's also a terser colab version of this post if you'd prefer to walk through that.

and for a video walkthrough of this see this youtube clip

the only user defined function we are going to use is this little `rnd` one to make random arrays of various sizes.

```
>>> import numpy as np
>>> def rnd(*args):
...     return np.random.random(args)
```

let's start with a fundamental operation, the matrix multiply.
numpy provides it through the `matmul` function, which in this case takes
an `(i, j)` sized matrix and a `(j, k)` sized one and produces an
`(i, k)` result; the classic 2d matrix multiply.

```
>>> np.matmul(rnd(4, 5), rnd(5, 6)).shape
(4, 6)
```

there are a number of interpretations of what a matrix multiply
represents and a common one i use a lot is calculating,
in batch, a bunch of pairwise dot products. say we have `S1` and `S2`,
both of which represent some set of embeddings, in this case
32d embeddings, 5 for `S1` and 10 for `S2`. if we want to calculate all pairwise
combos we can use a matrix multiply. but note we need to massage
`S2` a bit; a matrix multiply wants the inner dimensions to match so
we need to transpose `S2` to make it `(32, 10)`.

```
>>> S1 = rnd(5, 32)
>>> S2 = rnd(10, 32)
>>> np.matmul(S1, S2.T).shape
(5, 10)
```
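to make the "pairwise dot products" reading concrete, here's a minimal sketch that computes the same `(5, 10)` result with two explicit loops and checks it against the matrix multiply. the names `pairwise` and `batched` are just for illustration.

```python
import numpy as np

S1 = np.random.random((5, 32))   # 5 embeddings, 32d each
S2 = np.random.random((10, 32))  # 10 embeddings, 32d each

# explicit version: one dot product per (row of S1, row of S2) pair
pairwise = np.zeros((5, 10))
for a in range(5):
    for b in range(10):
        pairwise[a, b] = np.dot(S1[a], S2[b])

# the matrix multiply does all 50 dot products in one call
batched = np.matmul(S1, S2.T)

print(np.allclose(pairwise, batched))
```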

an important thing to note about `matmul` is that it automatically handles
the idea of batching. for example a leading dimension of 10 makes
it equivalent to 10 independent matrix multiplies. we could do this
in a for loop but expressing it in a single `matmul` will most likely
be faster as the underlying linear algebra libraries can kick in with
a bunch of optimisations.

```
>>> np.matmul(rnd(10, 4, 5),
...           rnd(10, 5, 6)).shape
(10, 4, 6)
```
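as a rough sketch of what "10 independent matrix multiplies" means, the for loop version would look something like the following. `looped` and `batched` are illustrative names, and this only checks the results agree; it makes no claim about timing.

```python
import numpy as np

A = np.random.random((10, 4, 5))
B = np.random.random((10, 5, 6))

# one plain 2d matrix multiply per entry along the leading axis
looped = np.stack([np.matmul(A[n], B[n]) for n in range(10)])

# the batched form does the same thing in a single call
batched = np.matmul(A, B)

print(looped.shape, np.allclose(looped, batched))
```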

and note that there can be multiple leading dimensions.
all that matters is that the last two axes follow the `i,j,k` rule.

```
>>> np.matmul(rnd(10, 20, 30, 4, 5),
...           rnd(10, 20, 30, 5, 6)).shape
(10, 20, 30, 4, 6)
```

so let's talk about the real use case that led me to write
this post. we have a `query` that i want to compare to 10 `keys` by
using a dot product; basically an unnormalised attention / soft map lookup.

```
>>> E = 32
>>> query_embedding = rnd(E)
>>> keys_embeddings = rnd(10, E)
```

even though `query` is a vector we can use `matmul` to do this because
`matmul` can interpret the vector as either a row vector as below in v1
(i.e. a matrix of shape `(1, E)`) or a column vector as in v2
(a matrix of shape `(E, 1)`; note that `keys` and `query` are swapped).
in both cases the added axis of size 1 is dropped again from the result.

```
>>> v1 = np.matmul(query_embedding, keys_embeddings.T)
>>> v2 = np.matmul(keys_embeddings, query_embedding)
>>> v1.shape, v2.shape, np.all(np.isclose(v1, v2))
((10,), (10,), True)
```
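to spell out what `matmul` is doing for us here, this sketch does the row and column promotion by hand and then drops the extra axis again. the `_explicit` names are only for illustration.

```python
import numpy as np

E = 32
query_embedding = np.random.random(E)
keys_embeddings = np.random.random((10, E))

# v1 by hand: treat query as a (1, E) row vector, then drop the axis we added
row = query_embedding.reshape(1, E)
v1_explicit = np.matmul(row, keys_embeddings.T)[0]      # (1, 10) -> (10,)

# v2 by hand: treat query as an (E, 1) column vector, then drop the axis we added
col = query_embedding.reshape(E, 1)
v2_explicit = np.matmul(keys_embeddings, col)[:, 0]     # (10, 1) -> (10,)

# the versions that let matmul handle the 1d input itself
v1 = np.matmul(query_embedding, keys_embeddings.T)
v2 = np.matmul(keys_embeddings, query_embedding)

print(np.allclose(v1, v1_explicit), np.allclose(v2, v2_explicit))
```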

my actual use case though is wanting to do the `query` to 10 `keys`
comparison, but across a "batch" of 100 independent sets. so if we
want to use `matmul` in the batched form we need to do a bit
of massaging of these two since the assumed behaviour of `matmul` will
no longer work.

```
>>> N = 100
>>> E = 32
>>> query_embeddings = rnd(N, E)
>>> keys_embeddings = rnd(N, 10, E)
```

firstly we need to be more explicit that the `query` embeddings
represent N row vectors.

```
>>> query_embeddings.reshape(N, 1, E).shape
(100, 1, 32)
```

secondly we need to do the transpose for `keys`, but this time we
can't just use `.T`; we need to be explicit about keeping the
leading axis the same, swapping only the final two.

```
>>> keys_embeddings.transpose(0, 2, 1).shape
(100, 32, 10)
```

after doing these two transforms we get the matrices that will trigger the batch matrix multiply behaviour we want. note that we end up with that inner dimension of 1 still around.

```
>>> np.matmul(query_embeddings.reshape(N, 1, E),
...           keys_embeddings.transpose(0, 2, 1)).shape
(100, 1, 10)
```

and if we don't want that inner axis, we can explicitly `squeeze` it away.

```
>>> np.squeeze(np.matmul(query_embeddings.reshape(N, 1, E),
...                      keys_embeddings.transpose(0, 2, 1))).shape
(100, 10)
```
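one thing worth being a bit careful about: `np.squeeze` with no `axis` removes every axis of size 1, not just the one we introduced. here's a small sketch, assuming a hypothetical "batch" of a single set (`N = 1`), showing how passing `axis=1` keeps the batch axis around.

```python
import numpy as np

N, E = 1, 32   # hypothetical edge case: a batch of just one set
query_embeddings = np.random.random((N, E))
keys_embeddings = np.random.random((N, 10, E))

product = np.matmul(query_embeddings.reshape(N, 1, E),
                    keys_embeddings.transpose(0, 2, 1))   # shape (1, 1, 10)

squeezed_all = np.squeeze(product)          # drops both size 1 axes
squeezed_one = np.squeeze(product, axis=1)  # drops only the axis we added

print(squeezed_all.shape, squeezed_one.shape)
```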

this all works but it felt clumsy to me. so let's go back through
these examples but using `einsum`, which gives us more explicit ways to
define the calculation without the massaging of `query` and `keys` to
make `matmul` work.

we'll redo the matmul versions again first (denoted
by `m`) followed by the einsum equivalents (denoted by `e`).

let's start with matrix multiply. `einsum` takes a `subscripts` arg
which describes the computation we want to do. the first part, `ij,jk`, is a
comma separated list of the dimensions of the inputs, in this case `A` and `B`.

the first input, `A`, is 2d with axes we're naming `i` and `j`.
the second input, `B`, is also 2d, with axes we're naming `j` and `k`.

the output we want, `->ik`, is 2d and takes the axes `i` and `k` with a reduction
along `j`. this is exactly a matrix multiply.

```
>>> A, B = rnd(4, 5), rnd(5, 6)
>>>
>>> m = np.matmul(A, B)
>>>
>>> e = np.einsum('ij,jk->ik', A, B)
>>>
>>> m.shape, e.shape, np.all(np.isclose(m, e))
((4, 6), (4, 6), True)
```
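one way to read `ij,jk->ik` is as a set of nested loops: every letter is a loop index, and any letter missing from the output gets summed over. a minimal, deliberately slow sketch of that reading (the name `out` is only for illustration):

```python
import numpy as np

A = np.random.random((4, 5))
B = np.random.random((5, 6))

I, J = A.shape
_, K = B.shape

out = np.zeros((I, K))
for i in range(I):
    for k in range(K):
        for j in range(J):  # j is not in the output, so we sum over it
            out[i, k] += A[i, j] * B[j, k]

e = np.einsum('ij,jk->ik', A, B)

print(np.allclose(out, e))
```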

consider again the case of two sets of 32d embeddings `S1` and `S2` and
wanting to get all the pairwise products. since we can name `j` as the
second axis of both inputs, we don't need the transpose.

```
>>> S1, S2 = rnd(5, 32), rnd(10, 32)
>>>
>>> m = np.matmul(S1, S2.T)
>>>
>>> e = np.einsum('ij,kj->ik', S1, S2)
>>>
>>> m.shape, e.shape, np.all(np.isclose(m, e))
((5, 10), (5, 10), True)
```

in fact in `einsum` we can use whatever letters we like, so we're
free to try to use the character subscripts as a weak form of documentation.
we've basically got one letter to describe what the axis represents.

```
>>> S1, S2 = rnd(5, 32), rnd(10, 32)
>>>
>>> m = np.matmul(S1, S2.T)
>>>
>>> e = np.einsum('ae,be->ab', S1, S2)
>>>
>>> m.shape, e.shape, np.all(np.isclose(m, e))
((5, 10), (5, 10), True)
```

next let's consider the batch form of a matrix multiply. recall: `matmul`
handles this by default, but with `einsum` we have to be
explicit about it. even so it's arguably clearer since we end up not
having to understand the assumed behaviour of `matmul`.

```
>>> A, B = rnd(10, 4, 5), rnd(10, 5, 6)
>>>
>>> m = np.matmul(A, B)
>>>
>>> e = np.einsum('nij,njk->nik', A, B)
>>>
>>> m.shape, e.shape, np.all(np.isclose(m, e))
((10, 4, 6), (10, 4, 6), True)
```
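for the case of several leading dimensions, numpy's `einsum` also accepts an ellipsis (`...`) in the subscripts to stand for "whatever leading axes there are". a small sketch, reusing the shapes from the earlier multiple leading dimension `matmul` example:

```python
import numpy as np

A = np.random.random((10, 20, 30, 4, 5))
B = np.random.random((10, 20, 30, 5, 6))

m = np.matmul(A, B)

# '...' covers the three leading axes; only i, j, k are named
e = np.einsum('...ij,...jk->...ik', A, B)

print(m.shape, e.shape, np.allclose(m, e))
```

whether this is clearer than naming each leading axis is a judgement call; the named version documents more, the ellipsis version is shape agnostic.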

going back to our `query` vs `keys` comparison, let's consider the
single instance case. one main thing that is different is we don't need
to have any assumption about `query` being interpretable as a row
or column matrix, we can just explicitly describe the axis
(in this case it's just the 1d `e`).

```
>>> E = 32
>>>
>>> query_embedding = rnd(E)
>>> keys_embeddings = rnd(10, E)
>>>
>>> m = np.matmul(query_embedding, keys_embeddings.T)
>>>
>>> e = np.einsum('e,ke->k', query_embedding, keys_embeddings)
>>>
>>> m.shape, e.shape, np.all(np.isclose(m, e))
((10,), (10,), True)
```

and finally consider the batched version where we see the `einsum` version ends
up being much simpler. we don't rely on any behaviour of `matmul` that forces
us to change things to particular shapes. as before, because we need to be
explicit about what we want in terms of reduction, we finish with a much
simpler expression that is much more direct.

```
>>> N = 100
>>> E = 32
>>>
>>> query_embeddings = rnd(N, E)
>>> keys_embeddings = rnd(N, 10, E)
>>>
>>> m = np.squeeze(np.matmul(query_embeddings.reshape(N, 1, E),
...                          keys_embeddings.transpose(0, 2, 1)))
>>>
>>> e = np.einsum('ne,nke->nk', query_embeddings, keys_embeddings)
>>>
>>> m.shape, e.shape, np.all(np.isclose(m, e))
((100, 10), (100, 10), True)
```
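as a sanity check on what `ne,nke->nk` is saying, here's a sketch of the same calculation as plain loops: for each set `n` and each key `k`, take the dot product over `e` between that set's query and that key. `K` and `looped` are illustrative names.

```python
import numpy as np

N, K, E = 100, 10, 32
query_embeddings = np.random.random((N, E))
keys_embeddings = np.random.random((N, K, E))

looped = np.zeros((N, K))
for n in range(N):
    for k in range(K):
        # e is not in the output, so it is the axis that gets reduced
        looped[n, k] = np.dot(query_embeddings[n], keys_embeddings[n, k])

e = np.einsum('ne,nke->nk', query_embeddings, keys_embeddings)

print(np.allclose(looped, e))
```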