# Understanding `Nx.dot/6`: contracting and batch axes

I’m watching the great video “Let’s build GPT: from scratch, in code, spelled out.” by Andrej Karpathy.

All the code is written in Python (using NumPy and torch), and I’m trying to rewrite it in Nx/Axon.

I’m now at minute 55, where a matrix multiplication between two tensors with different numbers of dimensions is performed.

I was able to get the same result with `Nx.dot/6`, but only by specifying the contracting and batch axes (see the docs). I searched online a bit, but I’m not sure I understood it correctly.

Take this toy example in Python (results/values shown as comments):

```python
import numpy as np
import torch
from torch.nn import functional as F

# B -> batch, T -> time, C -> channel
B,T,C = 4,8,2

x = np.arange(0, 64, 1).reshape((B, T, C))
# array([[[ 0,  1],
#         [ 2,  3],
#         [ 4,  5],
#         [ 6,  7],
#         [ 8,  9],
#         [10, 11],
#         [12, 13],
#         [14, 15]],

#        [[16, 17],
#         [18, 19],
#         [20, 21],
#         [22, 23],
#         [24, 25],
#         [26, 27],
#         [28, 29],
#         [30, 31]],

#        [[32, 33],
#         [34, 35],
#         [36, 37],
#         [38, 39],
#         [40, 41],
#         [42, 43],
#         [44, 45],
#         [46, 47]],

#        [[48, 49],
#         [50, 51],
#         [52, 53],
#         [54, 55],
#         [56, 57],
#         [58, 59],
#         [60, 61],
#         [62, 63]]])

tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
# tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
#         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
#         [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
#         [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

res = wei @ x
# tensor([[[ 0.0000,  1.0000],
#          [ 1.0000,  2.0000],
#          [ 2.0000,  3.0000],
#          [ 3.0000,  4.0000],
#          [ 4.0000,  5.0000],
#          [ 5.0000,  6.0000],
#          [ 6.0000,  7.0000],
#          [ 7.0000,  8.0000]],

#         [[16.0000, 17.0000],
#          [17.0000, 18.0000],
#          [18.0000, 19.0000],
#          [19.0000, 20.0000],
#          [20.0000, 21.0000],
#          [21.0000, 22.0000],
#          [22.0000, 23.0000],
#          [23.0000, 24.0000]],

#         [[32.0000, 33.0000],
#          [33.0000, 34.0000],
#          [34.0000, 35.0000],
#          [35.0000, 36.0000],
#          [36.0000, 37.0000],
#          [37.0000, 38.0000],
#          [38.0000, 39.0000],
#          [39.0000, 40.0000]],

#         [[48.0000, 49.0000],
#          [49.0000, 50.0000],
#          [50.0000, 51.0000],
#          [51.0000, 52.0000],
#          [52.0000, 53.0000],
#          [53.0000, 54.0000],
#          [54.0000, 55.0000],
#          [55.0000, 56.0000]]], dtype=torch.float64)

res.shape
# torch.Size([4, 8, 2])
``````

The relevant shapes are:

• `x` shape is `(B,T,C)`
• `wei` shape is `(T, T)`
• `res` shape is `(B, T, C)`

Apparently, matrix multiplication in torch (`wei @ x`) automatically broadcasts `wei` across the B (batch) dimension.

I think that’s what the torch docs mean by:

> The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable). For example, if `input` is a `(j×1×n×n)` tensor and `other` is a `(k×n×n)` tensor, `out` will be a `(j×k×n×n)` tensor.
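NumPy's `np.matmul` follows the same batch-broadcasting rule as `torch.matmul`, so a small NumPy sketch (assuming NumPy is installed; names mirror the snippet above) can confirm that multiplying a `(T, T)` matrix by a `(B, T, C)` tensor gives the same values as broadcasting the matrix to `(B, T, T)` first. The averaging matrix is built with a row-normalized `np.tril` instead of the masked softmax, which yields the same values:

```python
import numpy as np

B, T, C = 4, 8, 2
x = np.arange(B * T * C, dtype=np.float64).reshape(B, T, C)

# Lower-triangular averaging matrix: row i averages positions 0..i,
# the same values the masked softmax produces in the torch snippet.
wei = np.tril(np.ones((T, T)))
wei = wei / wei.sum(axis=1, keepdims=True)

# (T, T) @ (B, T, C): the missing batch dimension of `wei` is broadcast.
implicit = wei @ x

# Same product with the batch dimension made explicit: (B, T, T) @ (B, T, C).
explicit = np.broadcast_to(wei, (B, T, T)) @ x

print(implicit.shape)                   # (4, 8, 2)
print(np.allclose(implicit, explicit))  # True

# The docs example: (j, 1, n, n) @ (k, n, n) -> (j, k, n, n), here j=5, k=6, n=3.
print(np.matmul(np.ones((5, 1, 3, 3)), np.ones((6, 3, 3))).shape)  # (5, 6, 3, 3)
```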

After some trial and error, I got the same result with Elixir Nx:

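```elixir
# EXLA must be installed for EXLA.Backend to be available as the default backend.
Mix.install(
  [
    {:nx, "~> 0.6", override: true},
    {:exla, "~> 0.6"},
    {:axon, "~> 0.5"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

{b, t, c} = {4, 8, 2}

x = Nx.iota({b, t, c})
# #Nx.Tensor<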
#   s64[4][8][2]
#   EXLA.Backend<host:0, 0.668616701.3656777744.246793>
#   [
#     [
#       [0, 1],
#       [2, 3],
#       [4, 5],
#       [6, 7],
#       [8, 9],
#       [10, 11],
#       [12, 13],
#       [14, 15]
#     ],
#     [
#       [16, 17],
#       [18, 19],
#       [20, 21],
#       [22, 23],
#       [24, 25],
#       [26, 27],
#       [28, 29],
#       [30, 31]
#     ],
#     [
#       [32, 33],
#       [34, 35],
#       [36, 37],
#       [38, 39],
#       [40, 41],
#       [42, 43],
#       [44, 45],
#       [46, 47]
#     ],
#     [
#       [48, 49],
#       ...
#     ]
#   ]
# >

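wei =
  # start from a {t, t} matrix filled with -inf
  Nx.broadcast(Nx.Constants.neg_infinity(), {t, t})
  # keep -inf strictly above the diagonal, zero everything on and below it
  |> Nx.triu(k: 1)
  # softmax over each row turns row i into a uniform average over positions 0..i
  |> Axon.Activations.softmax()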
# #Nx.Tensor<
#   f32[8][8]
#   EXLA.Backend<host:0, 0.668616701.3656777744.246805>
#   [
#     [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#     [0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#     [0.3333333432674408, 0.3333333432674408, 0.3333333432674408, 0.0, 0.0, 0.0, 0.0, 0.0],
#     [0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0],
#     [0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.0, 0.0, 0.0],
#     [0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.0, 0.0],
#     [0.1428571492433548, 0.1428571492433548, ...],
#     ...
#   ]
# >

# WRONG, DIFFERENT RESULT! ❌
# res = Nx.dot(wei, x)

# WRONG, DIFFERENT RESULT! ❌
# res = Nx.dot(wei, x) |> Nx.reshape({b, t, c})

wei_with_batches = Nx.broadcast(wei, {b, t, t})
# #Nx.Tensor<
#   f32[4][8][8]
#   EXLA.Backend<host:0, 0.668616701.3656777744.246806>
#   [
#     [
#       [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#       [0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#       [0.3333333432674408, 0.3333333432674408, 0.3333333432674408, 0.0, 0.0, 0.0, 0.0, 0.0],
#       [0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0],
#       [0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.0, 0.0, 0.0],
#       [0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.0, 0.0],
#       [0.1428571492433548, 0.1428571492433548, ...],
#       ...
#     ],
#     [
#       [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#       [0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#       [0.3333333432674408, 0.3333333432674408, 0.3333333432674408, 0.0, 0.0, 0.0, 0.0, 0.0],
#       ...
#     ],
#     ...
#   ]
# >

# WRONG, DIFFERENT RESULT! ❌
# res = Nx.dot(wei_with_batches, x)

# CORRECT! SAME RESULT! ✅
res = Nx.dot(wei_with_batches, [2], [0], x, [1], [0])
# #Nx.Tensor<
#   f32[4][8][2]
#   EXLA.Backend<host:0, 0.668616701.3656777744.246807>
#   [
#     [
#       [0.0, 1.0],
#       [1.0, 2.0],
#       [2.0, 3.0],
#       [3.0, 4.0],
#       [4.0, 5.0],
#       [5.0, 6.000000476837158],
#       [6.000000476837158, 7.000000476837158],
#       [7.0, 8.0]
#     ],
#     [
#       [16.0, 17.0],
#       [17.0, 18.0],
#       [18.0, 19.000001907348633],
#       [19.0, 20.0],
#       [20.0, 21.0],
#       [21.000001907348633, 22.0],
#       [22.000001907348633, 23.0],
#       [23.0, 24.0]
#     ],
#     [
#       [32.0, 33.0],
#       [33.0, 34.0],
#       [34.0, 35.0],
#       [35.0, 36.0],
#       [36.0, 37.0],
#       [37.0, 38.0],
#       [38.0, 39.000003814697266],
#       [39.0, 40.0]
#     ],
#     [
#       [48.0, 49.0],
#       ...
#     ]
#   ]
# >
``````
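Why do the commented-out `Nx.dot/2` calls give different results? According to the `Nx.dot/2` docs, when the second tensor has two or more dimensions, plain `dot` sums over the last axis of the first tensor and the second-to-last axis of the second, keeping every other axis, the same convention as `numpy.dot`; it does not batch. Assuming that convention, this NumPy sketch (an analogy, not Nx code) reproduces both failure modes with the same shapes:

```python
import numpy as np

B, T, C = 4, 8, 2
x = np.arange(B * T * C, dtype=np.float64).reshape(B, T, C)
wei = np.tril(np.ones((T, T)))
wei = wei / wei.sum(axis=1, keepdims=True)
wei_b = np.broadcast_to(wei, (B, T, T))

expected = np.matmul(wei_b, x)  # the batched result, (B, T, C)

# np.dot contracts wei's last axis with x's second-to-last axis: (T, B, C).
d1 = np.dot(wei, x)
print(d1.shape)                                       # (8, 4, 2)
# Reshaping only reflows values in memory order, so it scrambles them...
print(np.allclose(d1.reshape(B, T, C), expected))     # False
# ...while swapping the first two axes puts them where they belong.
print(np.allclose(d1.transpose(1, 0, 2), expected))   # True

# With broadcast weights, np.dot pairs every batch with every batch:
# (B, T, B, C). Only entries where both batch indices match are wanted.
d2 = np.dot(wei_b, x)
print(d2.shape)                                       # (4, 8, 4, 2)
diag = np.stack([d2[i, :, i, :] for i in range(B)])
print(np.allclose(diag, expected))                    # True
```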

To recap:

• `x` shape is `(B,T,C)`
• `wei` shape is `(T, T)`
• `wei_with_batches` shape is `(B, T, T)`
• `res` shape is `(B, T, C)`

In particular, I needed to use `Nx.dot/6`, specifying the contracting and batch axes:

```elixir
Nx.dot(wei_with_batches, [2], [0], x, [1], [0])
``````

where the contracting axes are:

• the second `T` in `wei_with_batches` (axis 2 of `(B, T, T)`)
• `T` in `x` (axis 1 of `(B, T, C)`)

and the batch axes are:

• `B` in both `wei_with_batches` and `x` (axis 0 in each)

Quoting the docs:

> The dot product is computed by multiplying the values from `t1` given by `contract_axes1` against the values from `t2` given by `contract_axes2`, considering batch axes of `batch_axes1` and `batch_axes2`. For instance, the first axis in `contract_axes1` will be matched against the first axis in `contract_axes2` and so on.
>
> The axes given by `contract_axes1` and `contract_axes2` are effectively removed from the final tensor, which is why they are often called the contraction axes.
>
> Specifying batch axes will compute a vectorized dot product along the given batch dimensions. The length of `batch_axes1` and `batch_axes2` must match.
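As a rough mental model of that description, here is a NumPy sketch of a general contraction with batch axes. The function name `dot_general` and its argument order are hypothetical, chosen only to mirror `Nx.dot/6`; this is not how Nx implements it. The sketch assumes the output lists the batch axes first, then the remaining axes of `t1`, then the remaining axes of `t2` (the XLA DotGeneral convention, which is consistent with the `(B, T, C)` result above):

```python
import numpy as np

def dot_general(t1, contract1, batch1, t2, contract2, batch2):
    """Hypothetical helper: contract t1's `contract1` axes against t2's
    `contract2` axes, pairing batch axis batch1[i] with batch2[i].
    Output axis order: batch axes, free axes of t1, free axes of t2."""
    assert len(contract1) == len(contract2) and len(batch1) == len(batch2)
    for a, b in zip(contract1 + batch1, contract2 + batch2):
        assert t1.shape[a] == t2.shape[b], "paired axes must have equal sizes"

    # "Free" axes are the ones that are neither contracted nor batched.
    free1 = [a for a in range(t1.ndim) if a not in contract1 and a not in batch1]
    free2 = [a for a in range(t2.ndim) if a not in contract2 and a not in batch2]

    # Reorder to (batch, free, contract) and (batch, contract, free).
    lhs = np.transpose(t1, batch1 + free1 + contract1)
    rhs = np.transpose(t2, batch2 + contract2 + free2)

    batch_shape = [t1.shape[a] for a in batch1]
    free1_shape = [t1.shape[a] for a in free1]
    free2_shape = [t2.shape[a] for a in free2]
    k = int(np.prod([t1.shape[a] for a in contract1]))
    nb = int(np.prod(batch_shape))

    # Flatten each group so a single batched matmul performs the contraction.
    lhs = lhs.reshape(nb, int(np.prod(free1_shape)), k)
    rhs = rhs.reshape(nb, k, int(np.prod(free2_shape)))
    out = np.matmul(lhs, rhs)
    return out.reshape(batch_shape + free1_shape + free2_shape)

B, T, C = 4, 8, 2
x = np.arange(B * T * C, dtype=np.float64).reshape(B, T, C)
wei = np.tril(np.ones((T, T)))
wei = wei / wei.sum(axis=1, keepdims=True)
wei_b = np.broadcast_to(wei, (B, T, T))

# Mirrors Nx.dot(wei_with_batches, [2], [0], x, [1], [0]).
res = dot_general(wei_b, [2], [0], x, [1], [0])
print(res.shape)  # (4, 8, 2)

# The same contraction as an einsum: k disappears (contracted), b is kept (batch).
print(np.allclose(res, np.einsum("btk,bkc->btc", wei_b, x)))  # True
```

The einsum subscripts are a compact way to read any `Nx.dot/6` call: a letter shared by both inputs and absent from the output is a contracting axis, and a letter shared by both inputs and kept in the output is a batch axis.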

So I guess the final `res` tensor is `(B, T, C)` because, if we keep the batch dimension on the side for a moment:

`(T, T) @ (T, C) = (T, C)`, then adding back the `B` batch dimension → `(B, T, C)`
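One way to make "keep the batch dimension on the side" concrete: a batched contraction gives the same result as running the plain `(T, T) @ (T, C)` product once per batch element and stacking the results along a new leading axis. A NumPy sketch with the toy shapes from the question:

```python
import numpy as np

B, T, C = 4, 8, 2
x = np.arange(B * T * C, dtype=np.float64).reshape(B, T, C)
wei = np.tril(np.ones((T, T)))
wei = wei / wei.sum(axis=1, keepdims=True)

# One ordinary matrix product per batch element: (T, T) @ (T, C) -> (T, C).
per_batch = [wei @ x[b] for b in range(B)]

# Stacking along a new axis 0 adds the batch dimension back: (B, T, C).
res = np.stack(per_batch, axis=0)

print(res.shape)                  # (4, 8, 2)
print(np.allclose(res, wei @ x))  # True
```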

(It kind of makes sense, but I think there’s a better way to explain it.)

To conclude, I’d like to know whether my reasoning above makes sense, and whether this is the correct way to replicate the Python snippet or there is a better way to do it.

`Nx.dot` doesn’t generally apply broadcasting (although some cases do, such as `dot(scalar, non_scalar_tensor)`).

That being said, you can use a combination of `Nx.new_axis` + `Nx.tile` on your weights tensor to mimic that.

That is:

```elixir
# w: {T, T}
# x: {B, T, C}
# then:

w2 = Nx.new_axis(w, 0)  # {1, T, T}
Nx.tile(w2, [b, 1, 1])  # {B, T, T}, where b is the batch size
``````
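In NumPy terms (an analogy, not Nx code), `Nx.new_axis` + `Nx.tile` corresponds to `np.newaxis` + `np.tile`, and it yields exactly the same values as broadcasting straight to `(B, T, T)`. One visible difference in NumPy is that `np.tile` materializes `B` copies, while `np.broadcast_to` returns a read-only view; this sketch does not show whether any Nx backend makes the same distinction:

```python
import numpy as np

B, T = 4, 8
w = np.tril(np.ones((T, T)))
w = w / w.sum(axis=1, keepdims=True)

w2 = w[np.newaxis, :, :]                   # (1, T, T)
tiled = np.tile(w2, (B, 1, 1))             # (B, T, T), B real copies
broadcast = np.broadcast_to(w, (B, T, T))  # (B, T, T), read-only view

print(tiled.shape, broadcast.shape)      # (4, 8, 8) (4, 8, 8)
print(np.array_equal(tiled, broadcast))  # True
```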

However, Nx 0.6 brought vectorization, and with that you can more easily deal with the batched case:

```elixir
# x: {B, T, C}
# w: {K, T}

vec_x = Nx.vectorize(x, :some_name) # vectorized[some_name: B] {T, C}
vec_result = Nx.dot(w, [1], vec_x, [0]) # vectorized[some_name: B] {K, C}
# Note: vec_result = Nx.dot(w, vec_x) returns the same result :)
result = Nx.devectorize(vec_result, keep_names: false) # {B, K, C}
``````

Example with values:

```elixir
iex(5)> x
#Nx.Tensor<
s64[4][2][3]
[
[
[0, 1, 2],
[3, 4, 5]
],
[
[6, 7, 8],
[9, 10, 11]
],
[
[12, 13, 14],
[15, 16, 17]
],
[
[18, 19, 20],
[21, 22, 23]
]
]
>
iex(6)> w = Nx.iota({5, 2})
#Nx.Tensor<
s64[5][2]
[
[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]
]
>
iex(7)> Nx.dot(w, Nx.vectorize(x, :some_name))
#Nx.Tensor<
vectorized[some_name: 4]
s64[5][3]
[
[
[3, 4, 5],
[9, 14, 19],
[15, 24, 33],
[21, 34, 47],
[27, 44, 61]
],
[
[9, 10, 11],
[39, 44, 49],
[69, 78, 87],
[99, 112, 125],
[129, 146, 163]
],
[
[15, 16, 17],
[69, 74, 79],
[123, 132, 141],
[177, 190, 203],
[231, 248, 265]
],
[
[21, 22, 23],
[99, 104, ...],
...
]
]
>
iex(8)> |> Nx.devectorize(keep_names: false)
#Nx.Tensor<
s64[4][5][3]
[
[
[3, 4, 5],
[9, 14, 19],
[15, 24, 33],
[21, 34, 47],
[27, 44, 61]
],
[
[9, 10, 11],
[39, 44, 49],
[69, 78, 87],
[99, 112, 125],
[129, 146, 163]
],
[
[15, 16, 17],
[69, 74, 79],
[123, 132, 141],
[177, 190, 203],
[231, 248, 265]
],
[
[21, 22, 23],
[99, 104, ...],
...
]
]
>
``````
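To double-check the numbers above, here is the same computation in NumPy, used purely as a reference (NumPy has no equivalent of Nx's named vectorization). Per batch element it is a plain `(5, 2) @ (2, 3)` product, and `np.einsum` expresses the whole thing in one call with the batch axis in front, matching the devectorized `{4, 5, 3}` shape:

```python
import numpy as np

x = np.arange(4 * 2 * 3).reshape(4, 2, 3)  # same values as Nx.iota({4, 2, 3})
w = np.arange(5 * 2).reshape(5, 2)         # same values as Nx.iota({5, 2})

# One (5, 2) @ (2, 3) product per batch element, stacked: (4, 5, 3).
stacked = np.stack([w @ xb for xb in x])

# The same contraction in one call: t is contracted, b stays in front.
fused = np.einsum("kt,btc->bkc", w, x)

print(stacked.shape)                   # (4, 5, 3)
print(np.array_equal(stacked, fused))  # True
```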
Hey @polvalente, thank you so much for the fast reply and for providing some examples too, really helpful!

What you suggested:

```elixir
w2 = Nx.new_axis(w, 0) # -> {1, T, T}
Nx.tile(w2, [b, 1, 1]) # -> {B, T, T}
``````

is basically the same as what I did with:

```elixir
wei = Nx.broadcast(wei, {b, t, t})
``````

Why would you prefer the combination of `Nx.new_axis` and `Nx.tile` over `Nx.broadcast`?

> However, Nx 0.6 brought vectorization, and with that you can more easily deal with the batched case:

Vectorization for the win!
I remember seeing that section in the Nx guide, but I didn’t fully get it at the time (I still don’t, but I understand it better now). And wow, it really fits this case perfectly; thanks for bringing it up.

Thanks again for chiming in, and for the great work you and the Nx team are doing.

You’re right! I had a brain fart for a moment, because `Nx.broadcast(w, x)` doesn’t work directly.

You can do something like:

```elixir
iex(4)> t1.shape
{2, 3}
iex(5)> t2.shape
{5, 3, 4}
iex(6)> Nx.broadcast(t1, Tuple.insert_at(t1.shape, 0, Nx.axis_size(t2, 0)))
#Nx.Tensor<
s64[5][2][3]
...
>
``````
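The same trick in NumPy, as an analogy where `np.broadcast_to` plays the role of `Nx.broadcast`: prepend the leading axis size of `t2` to `t1`'s shape and broadcast to that. The shapes match the iex session above; the `np.arange` values are just placeholders:

```python
import numpy as np

t1 = np.arange(2 * 3).reshape(2, 3)         # shape (2, 3)
t2 = np.arange(5 * 3 * 4).reshape(5, 3, 4)  # shape (5, 3, 4)

# Counterpart of Tuple.insert_at(t1.shape, 0, Nx.axis_size(t2, 0)).
target_shape = (t2.shape[0],) + t1.shape    # (5, 2, 3)
t1_b = np.broadcast_to(t1, target_shape)

print(t1_b.shape)                 # (5, 2, 3)

# With t1 batched, a per-batch (2, 3) @ (3, 4) product gives (5, 2, 4).
print(np.matmul(t1_b, t2).shape)  # (5, 2, 4)
```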

Note that inside `defn` you’ll probably need a separate `deftransformp` to apply this broadcast.
Vectorization comes with its own set of difficulties, though, if you want to support a function that may or may not receive a vectorized tensor, so each approach has its own trade-offs.

