Understanding `Nx.dot/6`: contracting and batch axes

Hello everyone :wave:

I’m watching the great video “Let’s build GPT: from scratch, in code, spelled out.” by Andrej Karpathy.

All the code is written in Python (using numpy and torch) and I’m trying to rewrite it in Nx/Axon.

I’m now at minute 55, where a matrix multiplication between two tensors of different ranks is performed.

I was able to obtain the same result with Nx.dot/6 but only when specifying the contracting and batch axes (see docs). I searched a bit online, but I’m not fully sure I got it correctly.

Let’s take this silly example in Python (results/values are shown in the comments):

import numpy as np
import torch
from torch.nn import functional as F

# B -> batch, T -> time, C -> channel
B,T,C = 4,8,2

x = np.arange(0, 64, 1).reshape((B, T, C))
# array([[[ 0,  1],
#         [ 2,  3],
#         [ 4,  5],
#         [ 6,  7],
#         [ 8,  9],
#         [10, 11],
#         [12, 13],
#         [14, 15]],

#        [[16, 17],
#         [18, 19],
#         [20, 21],
#         [22, 23],
#         [24, 25],
#         [26, 27],
#         [28, 29],
#         [30, 31]],

#        [[32, 33],
#         [34, 35],
#         [36, 37],
#         [38, 39],
#         [40, 41],
#         [42, 43],
#         [44, 45],
#         [46, 47]],

#        [[48, 49],
#         [50, 51],
#         [52, 53],
#         [54, 55],
#         [56, 57],
#         [58, 59],
#         [60, 61],
#         [62, 63]]])


tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
# tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
#         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
#         [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
#         [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


res = wei @ x
# tensor([[[ 0.0000,  1.0000],
#          [ 1.0000,  2.0000],
#          [ 2.0000,  3.0000],
#          [ 3.0000,  4.0000],
#          [ 4.0000,  5.0000],
#          [ 5.0000,  6.0000],
#          [ 6.0000,  7.0000],
#          [ 7.0000,  8.0000]],

#         [[16.0000, 17.0000],
#          [17.0000, 18.0000],
#          [18.0000, 19.0000],
#          [19.0000, 20.0000],
#          [20.0000, 21.0000],
#          [21.0000, 22.0000],
#          [22.0000, 23.0000],
#          [23.0000, 24.0000]],

#         [[32.0000, 33.0000],
#          [33.0000, 34.0000],
#          [34.0000, 35.0000],
#          [35.0000, 36.0000],
#          [36.0000, 37.0000],
#          [37.0000, 38.0000],
#          [38.0000, 39.0000],
#          [39.0000, 40.0000]],

#         [[48.0000, 49.0000],
#          [49.0000, 50.0000],
#          [50.0000, 51.0000],
#          [51.0000, 52.0000],
#          [52.0000, 53.0000],
#          [53.0000, 54.0000],
#          [54.0000, 55.0000],
#          [55.0000, 56.0000]]], dtype=torch.float64)

res.shape
# torch.Size([4, 8, 2])

The relevant information is:

  • x shape is (B, T, C)
  • wei shape is (T, T)
  • res shape is (B, T, C)

And apparently the matrix multiplication wei @ x in torch automatically adds the B (batch) dimension to wei by broadcasting it.

I think that’s what the torch docs mean by:

The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable). For example, if input is a (j×1×n×n) tensor and other is a (k×n×n) tensor, out will be a (j×k×n×n) tensor.

I could get the same result with Elixir’s Nx after a bit of trial and error:

Mix.install(
  [
    {:nx, "~> 0.6", override: true},
    {:exla, "~> 0.6"},
    {:axon, "~> 0.5"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

{b, t, c} = {4, 8, 2}

x = Nx.iota({b, t, c})
#Nx.Tensor<
#   s64[4][8][2]
#   EXLA.Backend<host:0, 0.668616701.3656777744.246793>
#   [
#     [
#       [0, 1],
#       [2, 3],
#       [4, 5],
#       [6, 7],
#       [8, 9],
#       [10, 11],
#       [12, 13],
#       [14, 15]
#     ],
#     [
#       [16, 17],
#       [18, 19],
#       [20, 21],
#       [22, 23],
#       [24, 25],
#       [26, 27],
#       [28, 29],
#       [30, 31]
#     ],
#     [
#       [32, 33],
#       [34, 35],
#       [36, 37],
#       [38, 39],
#       [40, 41],
#       [42, 43],
#       [44, 45],
#       [46, 47]
#     ],
#     [
#       [48, 49],
#       ...
#     ]
#   ]
# >

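# Build the lower-triangular averaging matrix without masked_fill:
# -inf strictly above the diagonal, 0 elsewhere, then softmax each row.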
wei =
  Nx.broadcast(Nx.Constants.neg_infinity(), {t, t})
  |> Nx.triu(k: 1)
  |> Axon.Activations.softmax()
# #Nx.Tensor<
#   f32[8][8]
#   EXLA.Backend<host:0, 0.668616701.3656777744.246805>
#   [
#     [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#     [0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#     [0.3333333432674408, 0.3333333432674408, 0.3333333432674408, 0.0, 0.0, 0.0, 0.0, 0.0],
#     [0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0],
#     [0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.0, 0.0, 0.0],
#     [0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.0, 0.0],
#     [0.1428571492433548, 0.1428571492433548, ...],
#     ...
#   ]
# >


# WRONG, DIFFERENT RESULT! ❌
# res = Nx.dot(wei, x)

# WRONG, DIFFERENT RESULT! ❌
# res = Nx.dot(wei, x) |> Nx.reshape({b, t, c})


# ☝️ Add the batch dimension by broadcasting it
wei_with_batches = Nx.broadcast(wei, {b, t, t})
# #Nx.Tensor<
#   f32[4][8][8]
#   EXLA.Backend<host:0, 0.668616701.3656777744.246806>
#   [
#     [
#       [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#       [0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#       [0.3333333432674408, 0.3333333432674408, 0.3333333432674408, 0.0, 0.0, 0.0, 0.0, 0.0],
#       [0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0],
#       [0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.0, 0.0, 0.0],
#       [0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.0, 0.0],
#       [0.1428571492433548, 0.1428571492433548, ...],
#       ...
#     ],
#     [
#       [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#       [0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#       [0.3333333432674408, 0.3333333432674408, 0.3333333432674408, 0.0, 0.0, 0.0, 0.0, 0.0],
#       ...
#     ],
#     ...
#   ]
# >

# WRONG, DIFFERENT RESULT! ❌
# res = Nx.dot(wei_with_batches, x)

# CORRECT! SAME RESULT! ✅
res = Nx.dot(wei_with_batches, [2], [0], x, [1], [0])
# #Nx.Tensor<
#   f32[4][8][2]
#   EXLA.Backend<host:0, 0.668616701.3656777744.246807>
#   [
#     [
#       [0.0, 1.0],
#       [1.0, 2.0],
#       [2.0, 3.0],
#       [3.0, 4.0],
#       [4.0, 5.0],
#       [5.0, 6.000000476837158],
#       [6.000000476837158, 7.000000476837158],
#       [7.0, 8.0]
#     ],
#     [
#       [16.0, 17.0],
#       [17.0, 18.0],
#       [18.0, 19.000001907348633],
#       [19.0, 20.0],
#       [20.0, 21.0],
#       [21.000001907348633, 22.0],
#       [22.000001907348633, 23.0],
#       [23.0, 24.0]
#     ],
#     [
#       [32.0, 33.0],
#       [33.0, 34.0],
#       [34.0, 35.0],
#       [35.0, 36.0],
#       [36.0, 37.0],
#       [37.0, 38.0],
#       [38.0, 39.000003814697266],
#       [39.0, 40.0]
#     ],
#     [
#       [48.0, 49.0],
#       ...
#     ]
#   ]
# >

To recap:

  • x shape is (B, T, C)
  • wei shape is (T, T)
  • wei_with_batches shape is (B, T, T)
  • res shape is (B, T, C)

In particular, I needed to use the Nx.dot/6 function specifying the contracting and batch axes:

Nx.dot(wei_with_batches, [2], [0], x, [1], [0])

where the contracting axes are:

  • the second T in wei_with_batches (axis 2 of (B, T, T))
  • the T in x (axis 1 of (B, T, C))

and the batch axes are:

  • the B (axis 0) in both wei_with_batches and x

Quoting the docs:

The dot product is computed by multiplying the values from t1 given by contract_axes1 against the values from t2 given by contract_axes2, considering batch axes of batch_axes1 and batch_axes2. For instance, the first axis in contract_axes1 will be matched against the first axis in contract_axes2 and so on.

The axes given by contract_axes1 and contract_axes2 are effectively removed from the final tensor, which is why they are often called the contraction axes.

Specifying batch axes will compute a vectorized dot product along the given batch dimensions. The length of batch_axes1 and batch_axes2 must match.

So I guess the final res tensor is (B, T, C) because, if we set the batch dimension aside for a moment:

(T, T) @ (T, C) = (T, C), then adding back the B batch dimension → (B, T, C)

(kinda makes sense, but I think there is a better way to explain it :see_no_evil:).
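As a sanity check of that shape reasoning (my own illustration, just on iota tensors): the batch axes are kept, the contracting axes disappear, and the remaining axes of the first tensor are followed by the remaining axes of the second one:

left = Nx.iota({4, 8, 8}, type: :f32)   # stands in for wei_with_batches: {B, T, T}
right = Nx.iota({4, 8, 2}, type: :f32)  # stands in for x:                {B, T, C}

Nx.dot(left, [2], [0], right, [1], [0]).shape
# => {4, 8, 2}, i.e. {B} ++ {T left from the first tensor} ++ {C left from the second}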

To conclude, I’d like to understand whether my reasoning above makes sense and whether this is the correct way to replicate the Python snippet, or if there is a better way of doing it.

Thank youuuu :bowing_man:


Nx.dot doesn’t generally apply broadcasting (although some cases, such as dot(scalar, non-scalar), do).

That being said, you can use a combination of Nx.new_axis + Nx.tile on your weights tensor to mimic that.

That is:

w -> {T, T}
x -> {B, T, C}

then

w2 = Nx.new_axis(w, 0) -> {1, T, T}
Nx.tile(w2, [B, 1, 1]) -> {B, T, T}
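
Concretely, with the shapes from your example (B = 4, T = 8), that would look roughly like:

w = Nx.iota({8, 8}, type: :f32)     # {T, T}
w2 = Nx.new_axis(w, 0)              # {1, T, T}
w_batched = Nx.tile(w2, [4, 1, 1])  # {B, T, T}
w_batched.shape
# => {4, 8, 8}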

However, Nx 0.6 brought vectorization, and with that you can more easily deal with the batched case:

x -> {B, T, C}
w -> {K, T}

vec_x = Nx.vectorize(x, :some_name) # vectorized[some_name: B] {T, C}
vec_result = Nx.dot(w, [1], vec_x, [0]) # vectorized[some_name: B] {K, C}
# Note: vec_result = Nx.dot(w, vec_x) returns the same result :)
result = Nx.devectorize(vec_result, keep_names: false) # {B, K, C}

Example with values:

iex(5)> x            
#Nx.Tensor<
  s64[4][2][3]
  [
    [
      [0, 1, 2],
      [3, 4, 5]
    ],
    [
      [6, 7, 8],
      [9, 10, 11]
    ],
    [
      [12, 13, 14],
      [15, 16, 17]
    ],
    [
      [18, 19, 20],
      [21, 22, 23]
    ]
  ]
>
iex(6)> w = Nx.iota({5, 2})
#Nx.Tensor<
  s64[5][2]
  [
    [0, 1],
    [2, 3],
    [4, 5],
    [6, 7],
    [8, 9]
  ]
>
iex(7)> Nx.dot(w, Nx.vectorize(x, :some_name))
#Nx.Tensor<
  vectorized[some_name: 4]
  s64[5][3]
  [
    [
      [3, 4, 5],
      [9, 14, 19],
      [15, 24, 33],
      [21, 34, 47],
      [27, 44, 61]
    ],
    [
      [9, 10, 11],
      [39, 44, 49],
      [69, 78, 87],
      [99, 112, 125],
      [129, 146, 163]
    ],
    [
      [15, 16, 17],
      [69, 74, 79],
      [123, 132, 141],
      [177, 190, 203],
      [231, 248, 265]
    ],
    [
      [21, 22, 23],
      [99, 104, ...],
      ...
    ]
  ]
>
iex(8)> |> Nx.devectorize(keep_names: false)
#Nx.Tensor<
  s64[4][5][3]
  [
    [
      [3, 4, 5],
      [9, 14, 19],
      [15, 24, 33],
      [21, 34, 47],
      [27, 44, 61]
    ],
    [
      [9, 10, 11],
      [39, 44, 49],
      [69, 78, 87],
      [99, 112, 125],
      [129, 146, 163]
    ],
    [
      [15, 16, 17],
      [69, 74, 79],
      [123, 132, 141],
      [177, 190, 203],
      [231, 248, 265]
    ],
    [
      [21, 22, 23],
      [99, 104, ...],
      ...
    ]
  ]
>

Hey @polvalente, thank you so much for the fast reply and for providing some examples too, really helpful!

What you suggested:

w2 = Nx.new_axis(w, 0) # -> {1, T, T}
Nx.tile(w2, [B, 1, 1]) # -> {B, T, T}

is basically the same as what I did with:

wei = Nx.broadcast(wei, {b, t, t})

Why would you prefer the combination of Nx.new_axis and Nx.tile over Nx.broadcast?

However, Nx 0.6 brought vectorization, and with that you can more easily deal with the batched case:

Vectorization for the WIN!
I remember seeing that section in the Nx guide, but I didn’t fully get it at the time (still don’t entirely, but it’s better now :grimacing:), and oh wow, it really fits this case perfectly, thanks for bringing it up. :bowing_man:
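
If I got it right, applied to my wei and x from above, the vectorized version would be roughly (untested sketch on my side):

res =
  wei                                   # {T, T} = {8, 8}
  |> Nx.dot(Nx.vectorize(x, :batch))    # vectorized[batch: 4], inner shape {8, 2}
  |> Nx.devectorize(keep_names: false)  # {4, 8, 2} = {B, T, C}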

Thanks again for chiming in and for the great work you and the Nx team are doing :heart: :guitar:


You’re right! I just brainfarted for a bit because Nx.broadcast(w, x) doesn’t work directly.

You can do something like:

iex(4)> t1.shape
{2, 3}
iex(5)> t2.shape
{5, 3, 4}
iex(6)> Nx.broadcast(t1, Tuple.insert_at(t1.shape, 0, Nx.axis_size(t2, 0)))
#Nx.Tensor<
  s64[5][2][3]
  ...
>

Note that inside defn you’ll probably need to define a separate deftransformp for applying this broadcast.
Vectorization comes with its own set of difficulties if you want to support a function that receives a maybe-vectorized tensor, though, so each approach has its own trade-offs.
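
A rough sketch of the deftransformp approach (module and function names are just placeholders):

defmodule MyAttention do
  import Nx.Defn

  defn weighted_sum(w, x) do
    w
    |> batch_like(x)
    |> Nx.dot([2], [0], x, [1], [0])
  end

  # Tuple manipulation on shapes can't happen inside defn itself,
  # so the broadcast is built in a deftransformp instead.
  deftransformp batch_like(w, x) do
    Nx.broadcast(w, Tuple.insert_at(Nx.shape(w), 0, Nx.axis_size(x, 0)))
  end
end

MyAttention.weighted_sum(wei, x) should then return the same {B, T, C} result as the explicit Nx.dot/6 call earlier in the thread.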
