Hello everyone

I’m watching the great video “Let’s build GPT: from scratch, in code, spelled out.” by Andrej Karpathy.

All the code is written in Python (using NumPy and PyTorch), and I’m trying to rewrite it in Elixir with Nx/Axon.

I’m now at minute 55, where a matrix multiplication between two tensors with a different number of dimensions is performed.

I was able to obtain the same result with `Nx.dot/6`, but only by specifying the *contracting* and *batch* axes (see the docs). I searched a bit online, but I’m not sure I understood it correctly.

Let’s take this silly example in Python (results are shown as comments):

```python
import numpy as np
import torch
from torch.nn import functional as F
# B -> batch, T -> time, C -> channel
B,T,C = 4,8,2
x = np.arange(0, 64, 1).reshape((B, T, C))
# array([[[ 0, 1],
# [ 2, 3],
# [ 4, 5],
# [ 6, 7],
# [ 8, 9],
# [10, 11],
# [12, 13],
# [14, 15]],
# [[16, 17],
# [18, 19],
# [20, 21],
# [22, 23],
# [24, 25],
# [26, 27],
# [28, 29],
# [30, 31]],
# [[32, 33],
# [34, 35],
# [36, 37],
# [38, 39],
# [40, 41],
# [42, 43],
# [44, 45],
# [46, 47]],
# [[48, 49],
# [50, 51],
# [52, 53],
# [54, 55],
# [56, 57],
# [58, 59],
# [60, 61],
# [62, 63]]])
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
# tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
# [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
# [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
# [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
# [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
# [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
# [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
# [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
res = wei @ x
# tensor([[[ 0.0000, 1.0000],
# [ 1.0000, 2.0000],
# [ 2.0000, 3.0000],
# [ 3.0000, 4.0000],
# [ 4.0000, 5.0000],
# [ 5.0000, 6.0000],
# [ 6.0000, 7.0000],
# [ 7.0000, 8.0000]],
# [[16.0000, 17.0000],
# [17.0000, 18.0000],
# [18.0000, 19.0000],
# [19.0000, 20.0000],
# [20.0000, 21.0000],
# [21.0000, 22.0000],
# [22.0000, 23.0000],
# [23.0000, 24.0000]],
# [[32.0000, 33.0000],
# [33.0000, 34.0000],
# [34.0000, 35.0000],
# [35.0000, 36.0000],
# [36.0000, 37.0000],
# [37.0000, 38.0000],
# [38.0000, 39.0000],
# [39.0000, 40.0000]],
# [[48.0000, 49.0000],
# [49.0000, 50.0000],
# [50.0000, 51.0000],
# [51.0000, 52.0000],
# [52.0000, 53.0000],
# [53.0000, 54.0000],
# [54.0000, 55.0000],
# [55.0000, 56.0000]]], dtype=torch.float64)
res.shape
# torch.Size([4, 8, 2])
```

The relevant shapes are:

- `x` has shape `(B, T, C)`
- `wei` has shape `(T, T)`
- `res` has shape `(B, T, C)`

So the torch matrix multiplication `wei @ x` apparently adds the `B` (batch) dimension to `wei` automatically, by broadcasting it.
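
A minimal NumPy sketch of that idea, assuming NumPy's `matmul` broadcasting behaves like torch's here (both document the same "stack of matrices" rule). It also assumes `wei` can be built by normalizing the lower-triangular mask row by row, which gives the same values as the masked softmax of an all-zero matrix. The sketch compares the implicit broadcast with an explicit `(B, T, T)` copy of `wei`:

```python
import numpy as np

B, T, C = 4, 8, 2
x = np.arange(B * T * C, dtype=np.float64).reshape(B, T, C)

# Assumption: row-normalizing the lower-triangular mask reproduces the
# masked softmax of an all-zero (T, T) matrix from the torch snippet.
tril = np.tril(np.ones((T, T)))
wei = tril / tril.sum(axis=1, keepdims=True)

# Implicit: matmul broadcasts the (T, T) matrix over the leading B axis of x.
res = wei @ x

# Explicit: materialize the batch axis first, then do a batched matmul.
wei_with_batches = np.broadcast_to(wei, (B, T, T))
res_explicit = np.matmul(wei_with_batches, x)

print(res.shape)                       # (4, 8, 2)
print(np.allclose(res, res_explicit))  # True
```

The explicit version is what `wei_with_batches` does on the Nx side; the implicit one is what `@` does for free.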

I think that’s what the torch docs mean by:

> The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable). For example, if `input` is a (j×1×n×n) tensor and `other` is a (k×n×n) tensor, `out` will be a (j×k×n×n) tensor.
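
The quoted shape rule can be checked with `numpy.matmul`, assuming (as its docs describe) that it treats the last two axes as matrices and broadcasts the leading ones the same way torch does. The sizes `j`, `k`, `n` below are arbitrary values picked for the sketch:

```python
import numpy as np

j, k, n = 3, 5, 4
a = np.ones((j, 1, n, n))  # plays the role of `input`
b = np.ones((k, n, n))     # plays the role of `other`

# Batch dims (j, 1) and (k,) broadcast to (j, k); the last two axes are
# multiplied as ordinary n x n matrices.
out = np.matmul(a, b)
print(out.shape)  # (3, 5, 4, 4)

# Same rule for the GPT snippet: (T, T) has no batch dims and (B, T, C) has (B,),
# so the result is (B,) followed by (T, T) @ (T, C) -> (T, C).
B, T, C = 4, 8, 2
gpt_shape = np.matmul(np.ones((T, T)), np.ones((B, T, C))).shape
print(gpt_shape)  # (4, 8, 2)
```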

After a bit of trial and error, I got the same result with Elixir Nx:

```elixir
Mix.install(
  [
    {:nx, "~> 0.6", override: true},
    # EXLA is required for the EXLA.Backend configured below
    {:exla, "~> 0.6"},
    {:axon, "~> 0.5"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

{b, t, c} = {4, 8, 2}
x = Nx.iota({b, t, c})
#Nx.Tensor<
# s64[4][8][2]
# EXLA.Backend<host:0, 0.668616701.3656777744.246793>
# [
# [
# [0, 1],
# [2, 3],
# [4, 5],
# [6, 7],
# [8, 9],
# [10, 11],
# [12, 13],
# [14, 15]
# ],
# [
# [16, 17],
# [18, 19],
# [20, 21],
# [22, 23],
# [24, 25],
# [26, 27],
# [28, 29],
# [30, 31]
# ],
# [
# [32, 33],
# [34, 35],
# [36, 37],
# [38, 39],
# [40, 41],
# [42, 43],
# [44, 45],
# [46, 47]
# ],
# [
# [48, 49],
# ...
# ]
# ]
# >
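wei =
  Nx.broadcast(Nx.Constants.neg_infinity(), {t, t})
  # keep -inf strictly above the diagonal, zero the rest (causal mask)
  |> Nx.triu(k: 1)
  # each row i becomes a uniform average over positions 0..i
  |> Axon.Activations.softmax()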
# #Nx.Tensor<
# f32[8][8]
# EXLA.Backend<host:0, 0.668616701.3656777744.246805>
# [
# [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.3333333432674408, 0.3333333432674408, 0.3333333432674408, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0],
# [0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.0, 0.0, 0.0],
# [0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.0, 0.0],
# [0.1428571492433548, 0.1428571492433548, ...],
# ...
# ]
# >
# WRONG, DIFFERENT RESULT! ❌
# res = Nx.dot(wei, x)
# WRONG, DIFFERENT RESULT! ❌
# res = Nx.dot(wei, x) |> Nx.reshape({b, t, c})
# Instead, add the batch dimension to `wei` by broadcasting it: {t, t} -> {b, t, t}
wei_with_batches = Nx.broadcast(wei, {b, t, t})
# #Nx.Tensor<
# f32[4][8][8]
# EXLA.Backend<host:0, 0.668616701.3656777744.246806>
# [
# [
# [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.3333333432674408, 0.3333333432674408, 0.3333333432674408, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0],
# [0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.0, 0.0, 0.0],
# [0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.0, 0.0],
# [0.1428571492433548, 0.1428571492433548, ...],
# ...
# ],
# [
# [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.3333333432674408, 0.3333333432674408, 0.3333333432674408, 0.0, 0.0, 0.0, 0.0, 0.0],
# ...
# ],
# ...
# ]
# >
# WRONG, DIFFERENT RESULT! ❌
# res = Nx.dot(wei_with_batches, x)
# CORRECT! SAME RESULT! ✅
# Contract axis 2 of wei_with_batches with axis 1 of x, batching over axis 0 of both
res = Nx.dot(wei_with_batches, [2], [0], x, [1], [0])
# #Nx.Tensor<
# f32[4][8][2]
# EXLA.Backend<host:0, 0.668616701.3656777744.246807>
# [
# [
# [0.0, 1.0],
# [1.0, 2.0],
# [2.0, 3.0],
# [3.0, 4.0],
# [4.0, 5.0],
# [5.0, 6.000000476837158],
# [6.000000476837158, 7.000000476837158],
# [7.0, 8.0]
# ],
# [
# [16.0, 17.0],
# [17.0, 18.0],
# [18.0, 19.000001907348633],
# [19.0, 20.0],
# [20.0, 21.0],
# [21.000001907348633, 22.0],
# [22.000001907348633, 23.0],
# [23.0, 24.0]
# ],
# [
# [32.0, 33.0],
# [33.0, 34.0],
# [34.0, 35.0],
# [35.0, 36.0],
# [36.0, 37.0],
# [37.0, 38.0],
# [38.0, 39.000003814697266],
# [39.0, 40.0]
# ],
# [
# [48.0, 49.0],
# ...
# ]
# ]
# >
```
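
To see why the two "WRONG" attempts differ, here is a NumPy sketch. It assumes that `Nx.dot/2` follows the same convention as `numpy.dot` for higher-rank inputs (sum product over the last axis of the first tensor and the second-to-last axis of the second), which is how I read the Nx docs but have not verified for every rank:

```python
import numpy as np

B, T, C = 4, 8, 2
x = np.arange(B * T * C, dtype=np.float64).reshape(B, T, C)
tril = np.tril(np.ones((T, T)))
wei = tril / tril.sum(axis=1, keepdims=True)  # same values as the softmax `wei`

# Assumption: Nx.dot/2 uses the same axis convention as np.dot, i.e. it
# contracts the last axis of the first tensor with the second-to-last axis
# of the second one and keeps every other axis (first tensor's axes first).
print(np.dot(wei, x).shape)  # (8, 4, 2): (T, B, C), not (B, T, C)

wei_with_batches = np.broadcast_to(wei, (B, T, T))
print(np.dot(wei_with_batches, x).shape)  # (4, 8, 4, 2): batches are not paired up

# Reshaping (T, B, C) to (B, T, C) keeps the flat element order, so values
# land in the wrong places; transposing actually swaps the T and B axes.
wrong = np.dot(wei, x).reshape(B, T, C)
right = np.dot(wei, x).transpose(1, 0, 2)
print(np.allclose(right, wei @ x))  # True
print(np.allclose(wrong, wei @ x))  # False
```

If that assumption holds, `Nx.dot(wei, x) |> Nx.transpose(axes: [1, 0, 2])` should give the same `res` without building `wei_with_batches`; I haven't run it, so treat it as a guess.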

To recap:

- `x` has shape `(B, T, C)`
- `wei` has shape `(T, T)`
- `wei_with_batches` has shape `(B, T, T)`
- `res` has shape `(B, T, C)`

In particular, I needed to use `Nx.dot/6`, specifying the contracting and batch axes:

```elixir
Nx.dot(wei_with_batches, [2], [0], x, [1], [0])
```

where the contracting axes are:

- the second `T` in `wei_with_batches` (axis 2 of `(B, T, T)`)
- `T` in `x` (axis 1 of `(B, T, C)`)

and the batch axes are `B` (axis 0) in both `wei_with_batches` and `x`.
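
In NumPy terms, the `Nx.dot(wei_with_batches, [2], [0], x, [1], [0])` call can be mirrored with `numpy.einsum`, which names every axis explicitly. This is only an analogy on the NumPy side, not a claim about how Nx implements it:

```python
import numpy as np

B, T, C = 4, 8, 2
x = np.arange(B * T * C, dtype=np.float64).reshape(B, T, C)
tril = np.tril(np.ones((T, T)))
wei_with_batches = np.broadcast_to(tril / tril.sum(axis=1, keepdims=True), (B, T, T))

# Axis letters: b = batch axis (axis 0 of both inputs), u = contracted axis
# (axis 2 of wei_with_batches, axis 1 of x). `u` is missing from the output,
# so it is summed over; `b` is kept and paired element-wise, like a batch axis.
res = np.einsum("btu,buc->btc", wei_with_batches, x)

# The same computation as an explicit loop: one ordinary (T, T) @ (T, C)
# matrix product per batch entry, stacked along a new leading axis.
loop = np.stack([wei_with_batches[i] @ x[i] for i in range(B)])

print(res.shape)               # (4, 8, 2)
print(np.allclose(res, loop))  # True
```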

Quoting the docs:

> The dot product is computed by multiplying the values from `t1` given by `contract_axes1` against the values from `t2` given by `contract_axes2`, considering batch axes of `batch_axes1` and `batch_axes2`. For instance, the first axis in `contract_axes1` will be matched against the first axis in `contract_axes2` and so on.
>
> The axes given by `contract_axes1` and `contract_axes2` are effectively removed from the final tensor, which is why they are often called the contraction axes.
>
> Specifying batch axes will compute a vectorized dot product along the given batch dimensions. The length of `batch_axes1` and `batch_axes2` must match.
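
To check my reading of these rules, here is a hypothetical Python helper (`dot_output_shape` is my own name, not an Nx function) that computes the result shape: contracting axes are dropped, paired axes must have equal sizes, and the result is laid out as batch axes, then the remaining axes of `t1`, then those of `t2`. That ordering is an assumption borrowed from XLA's DotGeneral, which I believe Nx follows:

```python
def dot_output_shape(shape1, contract1, batch1, shape2, contract2, batch2):
    """Hypothetical helper (not part of Nx): shape of a batched dot product.

    Assumes the result is ordered as batch axes, then the free axes of the
    first tensor, then the free axes of the second tensor.
    """
    if len(contract1) != len(contract2) or len(batch1) != len(batch2):
        raise ValueError("axis lists must have matching lengths")
    # The first contracting/batch axis of t1 is paired with the first of t2, etc.
    for a1, a2 in zip(list(contract1) + list(batch1), list(contract2) + list(batch2)):
        if shape1[a1] != shape2[a2]:
            raise ValueError(f"size mismatch: axis {a1} ({shape1[a1]}) vs axis {a2} ({shape2[a2]})")
    batch = [shape1[a] for a in batch1]
    # Contracting axes disappear; batch axes are emitted only once, up front.
    free1 = [d for i, d in enumerate(shape1) if i not in contract1 and i not in batch1]
    free2 = [d for i, d in enumerate(shape2) if i not in contract2 and i not in batch2]
    return tuple(batch + free1 + free2)


B, T, C = 4, 8, 2
print(dot_output_shape((B, T, T), [2], [0], (B, T, C), [1], [0]))  # (4, 8, 2)
print(dot_output_shape((T, T), [1], [], (B, T, C), [1], []))       # (8, 4, 2)
```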

So I guess the final `res` tensor is `(B, T, C)` because, if we set the batch dimension aside for a moment, `(T, T) @ (T, C) = (T, C)`; adding the `B` batch dimension back then gives `(B, T, C)`.

*(It kind of makes sense, but I think there is a better way to explain it.)*
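
One way to state that reasoning more precisely is with indices (my own notation: `u` is the contracted index, running over the `T` axis):

```latex
\mathrm{res}_{b,t,c} = \sum_{u=1}^{T} \mathrm{wei}_{t,u}\, x_{b,u,c},
\qquad b = 1,\dots,B,\quad t = 1,\dots,T,\quad c = 1,\dots,C
```

The index `u` appears in both factors but not in the result, so it is summed away (the contracting axis). The index `b` appears in `x` and in the result but not in `wei`, so for each fixed `b` this is exactly the matrix product `(T, T) @ (T, C)`, and broadcasting simply reuses the same `wei` for every `b`.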

To conclude, I’d like to know whether my reasoning above makes sense, and whether this is the correct way to replicate the Python snippet or there is a better way to do it.

Thank youuuu