Hello everyone!
I’m watching the great video “Let’s build GPT: from scratch, in code, spelled out.” by Andrej Karpathy.
All the code is written in Python (using NumPy and PyTorch), and I'm trying to rewrite it in Nx/Axon.
I'm now at minute 55, where a matrix multiplication between two tensors with different numbers of dimensions is performed.
I was able to obtain the same result with Nx.dot/6, but only by specifying the contracting and batch axes (see the docs). I searched a bit online, but I'm not sure I fully understood it.
Let's take this silly example in Python (the results/values are shown as comments):
import numpy as np
import torch
from torch.nn import functional as F
# B -> batch, T -> time, C -> channel
B,T,C = 4,8,2
x = np.arange(0, 64, 1).reshape((B, T, C))
# array([[[ 0, 1],
# [ 2, 3],
# [ 4, 5],
# [ 6, 7],
# [ 8, 9],
# [10, 11],
# [12, 13],
# [14, 15]],
# [[16, 17],
# [18, 19],
# [20, 21],
# [22, 23],
# [24, 25],
# [26, 27],
# [28, 29],
# [30, 31]],
# [[32, 33],
# [34, 35],
# [36, 37],
# [38, 39],
# [40, 41],
# [42, 43],
# [44, 45],
# [46, 47]],
# [[48, 49],
# [50, 51],
# [52, 53],
# [54, 55],
# [56, 57],
# [58, 59],
# [60, 61],
# [62, 63]]])
# lower-triangular mask: after the softmax, row t holds uniform weights over positions 0..t
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
# tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
# [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
# [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
# [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
# [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
# [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
# [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
# [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
res = wei @ x
# tensor([[[ 0.0000, 1.0000],
# [ 1.0000, 2.0000],
# [ 2.0000, 3.0000],
# [ 3.0000, 4.0000],
# [ 4.0000, 5.0000],
# [ 5.0000, 6.0000],
# [ 6.0000, 7.0000],
# [ 7.0000, 8.0000]],
# [[16.0000, 17.0000],
# [17.0000, 18.0000],
# [18.0000, 19.0000],
# [19.0000, 20.0000],
# [20.0000, 21.0000],
# [21.0000, 22.0000],
# [22.0000, 23.0000],
# [23.0000, 24.0000]],
# [[32.0000, 33.0000],
# [33.0000, 34.0000],
# [34.0000, 35.0000],
# [35.0000, 36.0000],
# [36.0000, 37.0000],
# [37.0000, 38.0000],
# [38.0000, 39.0000],
# [39.0000, 40.0000]],
# [[48.0000, 49.0000],
# [49.0000, 50.0000],
# [50.0000, 51.0000],
# [51.0000, 52.0000],
# [52.0000, 53.0000],
# [53.0000, 54.0000],
# [54.0000, 55.0000],
# [55.0000, 56.0000]]], dtype=torch.float64)
res.shape
# torch.Size([4, 8, 2])
The relevant information is:
- x shape is (B, T, C)
- wei shape is (T, T)
- res shape is (B, T, C)
Apparently the matrix multiplication wei @ x in torch automatically adds the B (batch) dimension to wei by broadcasting it.
I think that's what the torch docs mean by:
The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable). For example, if input is a (j×1×n×n) tensor and other is a (k×n×n) tensor, out will be a (j×k×n×n) tensor.
I could get the same result in Elixir with Nx after a bit of trial and error:
Mix.install(
  [
    {:nx, "~> 0.6", override: true},
    {:exla, "~> 0.6"},  # needed for the EXLA.Backend set below
    {:axon, "~> 0.5"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)
{b, t, c} = {4, 8, 2}
x = Nx.iota({b, t, c})
#Nx.Tensor<
# s64[4][8][2]
# EXLA.Backend<host:0, 0.668616701.3656777744.246793>
# [
# [
# [0, 1],
# [2, 3],
# [4, 5],
# [6, 7],
# [8, 9],
# [10, 11],
# [12, 13],
# [14, 15]
# ],
# [
# [16, 17],
# [18, 19],
# [20, 21],
# [22, 23],
# [24, 25],
# [26, 27],
# [28, 29],
# [30, 31]
# ],
# [
# [32, 33],
# [34, 35],
# [36, 37],
# [38, 39],
# [40, 41],
# [42, 43],
# [44, 45],
# [46, 47]
# ],
# [
# [48, 49],
# ...
# ]
# ]
# >
wei =
  Nx.broadcast(Nx.Constants.neg_infinity(), {t, t})  # {T, T} filled with -inf
  |> Nx.triu(k: 1)                                    # keep -inf strictly above the diagonal, 0 elsewhere
  |> Axon.Activations.softmax()                       # each row becomes a uniform average over positions 0..t
# #Nx.Tensor<
# f32[8][8]
# EXLA.Backend<host:0, 0.668616701.3656777744.246805>
# [
# [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.3333333432674408, 0.3333333432674408, 0.3333333432674408, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0],
# [0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.0, 0.0, 0.0],
# [0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.0, 0.0],
# [0.1428571492433548, 0.1428571492433548, ...],
# ...
# ]
# >
# WRONG, DIFFERENT RESULT! ❌
# res = Nx.dot(wei, x)
# WRONG, DIFFERENT RESULT! ❌
# res = Nx.dot(wei, x) |> Nx.reshape({b, t, c})
# ☝️ Add the batch dimension by broadcasting it
wei_with_batches = Nx.broadcast(wei, {b, t, t})
# #Nx.Tensor<
# f32[4][8][8]
# EXLA.Backend<host:0, 0.668616701.3656777744.246806>
# [
# [
# [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.3333333432674408, 0.3333333432674408, 0.3333333432674408, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0],
# [0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.0, 0.0, 0.0],
# [0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.0, 0.0],
# [0.1428571492433548, 0.1428571492433548, ...],
# ...
# ],
# [
# [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.3333333432674408, 0.3333333432674408, 0.3333333432674408, 0.0, 0.0, 0.0, 0.0, 0.0],
# ...
# ],
# ...
# ]
# >
# WRONG, DIFFERENT RESULT! ❌
# res = Nx.dot(wei_with_batches, x)
# CORRECT! SAME RESULT! ✅
res = Nx.dot(wei_with_batches, [2], [0], x, [1], [0])
# #Nx.Tensor<
# f32[4][8][2]
# EXLA.Backend<host:0, 0.668616701.3656777744.246807>
# [
# [
# [0.0, 1.0],
# [1.0, 2.0],
# [2.0, 3.0],
# [3.0, 4.0],
# [4.0, 5.0],
# [5.0, 6.000000476837158],
# [6.000000476837158, 7.000000476837158],
# [7.0, 8.0]
# ],
# [
# [16.0, 17.0],
# [17.0, 18.0],
# [18.0, 19.000001907348633],
# [19.0, 20.0],
# [20.0, 21.0],
# [21.000001907348633, 22.0],
# [22.000001907348633, 23.0],
# [23.0, 24.0]
# ],
# [
# [32.0, 33.0],
# [33.0, 34.0],
# [34.0, 35.0],
# [35.0, 36.0],
# [36.0, 37.0],
# [37.0, 38.0],
# [38.0, 39.000003814697266],
# [39.0, 40.0]
# ],
# [
# [48.0, 49.0],
# ...
# ]
# ]
# >
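Mirroring the res.shape check from the Python snippet, the shape matches:
Nx.shape(res)
# {4, 8, 2}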
To recap:
- x shape is (B, T, C)
- wei shape is (T, T)
- wei_with_batches shape is (B, T, T)
- res shape is (B, T, C)
In particular, I needed to use the Nx.dot/6 function specifying the contracting and batch axes:
Nx.dot(wei_with_batches, [2], [0], x, [1], [0])
where the contracting axes are:
- the 2nd T in wei_with_batches ((B, T, T), axis [2])
- the T in x ((B, T, C), axis [1])
and the batch axes are:
- B in both wei_with_batches and x
Quoting the docs:
The dot product is computed by multiplying the values from t1 given by contract_axes1 against the values from t2 given by contract_axes2, considering batch axes of batch_axes1 and batch_axes2. For instance, the first axis in contract_axes1 will be matched against the first axis in contract_axes2 and so on.
The axes given by contract_axes1 and contract_axes2 are effectively removed from the final tensor, which is why they are often called the contraction axes.
Specifying batch axes will compute a vectorized dot product along the given batch dimensions. The length of batch_axes1 and batch_axes2 must match.
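If I understood the quote correctly, the same contraction can be spelled out by hand with a broadcasted element-wise multiply followed by a sum over the shared T axis. This is just a sketch to check my understanding (manual is a scratch name, everything else comes from the code above):
# wei_with_batches[b][i][k] * x[b][k][j], summed over k (the contracted T axis)
manual =
  wei_with_batches
  |> Nx.new_axis(3)                  # {B, T, T} -> {B, T, T, 1}
  |> Nx.multiply(Nx.new_axis(x, 1))  # x becomes {B, 1, T, C}; the product broadcasts to {B, T, T, C}
  |> Nx.sum(axes: [2])               # sum out the shared T axis -> {B, T, C}

Nx.all_close(manual, res)
# should be 1 (equal up to floating-point error)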
So I guess the final res tensor is (B, T, C) because, if we set the batch dimension aside for a moment:
(T, T) @ (T, C) = (T, C), and adding back the B batch dimension gives (B, T, C)
(it kinda makes sense, but I think there is a better way of explaining it).
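That per-batch view can also be checked directly on a single batch element (again just a sanity-check sketch; x[0] and res[0] use the tensor access syntax):
# one batch element: (T, T) . (T, C) -> (T, C)
res_0 = Nx.dot(wei, x[0])
Nx.all_close(res_0, res[0])
# should be 1: each batch element is just wei multiplied with x[b]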
To conclude, I'd like to understand whether my reasoning above makes sense and whether this is the correct way to replicate the Python snippet, or if there is a better way to do it.
Thank youuuu!