Hello everyone
I’m watching the great video “Let’s build GPT: from scratch, in code, spelled out.” by Andrej Karpathy.
All the code is written in Python (using NumPy and PyTorch) and I’m trying to rewrite it in Nx/Axon.
I’m now at minute 55, where a matrix multiplication between two tensors with a different number of dimensions is performed.
I was able to obtain the same result with Nx.dot/6,
but only by specifying the contracting and batch axes (see docs). I searched a bit online, but I’m not fully sure I got it right.
Let’s take this silly example in Python (the results/values are shown as comments):
import numpy as np
import torch
from torch.nn import functional as F
# B -> batch, T -> time, C -> channel
B,T,C = 4,8,2
x = np.arange(0, 64, 1).reshape((B, T, C))
# array([[[ 0, 1],
# [ 2, 3],
# [ 4, 5],
# [ 6, 7],
# [ 8, 9],
# [10, 11],
# [12, 13],
# [14, 15]],
# [[16, 17],
# [18, 19],
# [20, 21],
# [22, 23],
# [24, 25],
# [26, 27],
# [28, 29],
# [30, 31]],
# [[32, 33],
# [34, 35],
# [36, 37],
# [38, 39],
# [40, 41],
# [42, 43],
# [44, 45],
# [46, 47]],
# [[48, 49],
# [50, 51],
# [52, 53],
# [54, 55],
# [56, 57],
# [58, 59],
# [60, 61],
# [62, 63]]])
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
# tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
# [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
# [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
# [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
# [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
# [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
# [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
# [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
res = wei @ x
# tensor([[[ 0.0000, 1.0000],
# [ 1.0000, 2.0000],
# [ 2.0000, 3.0000],
# [ 3.0000, 4.0000],
# [ 4.0000, 5.0000],
# [ 5.0000, 6.0000],
# [ 6.0000, 7.0000],
# [ 7.0000, 8.0000]],
# [[16.0000, 17.0000],
# [17.0000, 18.0000],
# [18.0000, 19.0000],
# [19.0000, 20.0000],
# [20.0000, 21.0000],
# [21.0000, 22.0000],
# [22.0000, 23.0000],
# [23.0000, 24.0000]],
# [[32.0000, 33.0000],
# [33.0000, 34.0000],
# [34.0000, 35.0000],
# [35.0000, 36.0000],
# [36.0000, 37.0000],
# [37.0000, 38.0000],
# [38.0000, 39.0000],
# [39.0000, 40.0000]],
# [[48.0000, 49.0000],
# [49.0000, 50.0000],
# [50.0000, 51.0000],
# [51.0000, 52.0000],
# [52.0000, 53.0000],
# [53.0000, 54.0000],
# [54.0000, 55.0000],
# [55.0000, 56.0000]]], dtype=torch.float64)
res.shape
# torch.Size([4, 8, 2])
The relevant information is:
- x shape is (B, T, C)
- wei shape is (T, T)
- res shape is (B, T, C)
Apparently the matrix multiplication wei @ x in torch automatically adds the B (batch) dimension to wei by broadcasting it.
I think that’s what the torch docs mean with:

The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable). For example, if input is a (j×1×n×n) tensor and other is a (k×n×n) tensor, out will be a (j×k×n×n) tensor.
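To convince myself, here is a small check on the Python side (a sketch reusing the variables above; x_f and wei_b are just my own throwaway names, and x needs converting since it is a NumPy array):

x_f = torch.as_tensor(x, dtype=wei.dtype)        # x as a float tensor, shape (B, T, C)
wei_b = torch.broadcast_to(wei, (B, T, T))       # broadcast wei by hand to (B, T, T)
print(torch.allclose(wei @ x_f, wei_b @ x_f))    # True -> `@` broadcasts wei implicitly
# the (j×1×n×n) @ (k×n×n) example from the docs:
print((torch.ones(5, 1, 3, 3) @ torch.ones(7, 3, 3)).shape)  # torch.Size([5, 7, 3, 3])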
I could get the same result with Elixir's Nx after a bit of trial and error:
Mix.install(
  [
    {:nx, "~> 0.6", override: true},
    {:exla, "~> 0.6"},  # provides the EXLA.Backend configured below
    {:axon, "~> 0.5"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)
{b, t, c} = {4, 8, 2}
x = Nx.iota({b, t, c})
#Nx.Tensor<
# s64[4][8][2]
# EXLA.Backend<host:0, 0.668616701.3656777744.246793>
# [
# [
# [0, 1],
# [2, 3],
# [4, 5],
# [6, 7],
# [8, 9],
# [10, 11],
# [12, 13],
# [14, 15]
# ],
# [
# [16, 17],
# [18, 19],
# [20, 21],
# [22, 23],
# [24, 25],
# [26, 27],
# [28, 29],
# [30, 31]
# ],
# [
# [32, 33],
# [34, 35],
# [36, 37],
# [38, 39],
# [40, 41],
# [42, 43],
# [44, 45],
# [46, 47]
# ],
# [
# [48, 49],
# ...
# ]
# ]
# >
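# Build the (T, T) weight matrix: -inf strictly above the diagonal, 0 elsewhere,
# then softmax each row (same effect as the tril/masked_fill/softmax steps above)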
wei =
  Nx.broadcast(Nx.Constants.neg_infinity(), {t, t})
  |> Nx.triu(k: 1)
  |> Axon.Activations.softmax()
# #Nx.Tensor<
# f32[8][8]
# EXLA.Backend<host:0, 0.668616701.3656777744.246805>
# [
# [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.3333333432674408, 0.3333333432674408, 0.3333333432674408, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0],
# [0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.0, 0.0, 0.0],
# [0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.0, 0.0],
# [0.1428571492433548, 0.1428571492433548, ...],
# ...
# ]
# >
# WRONG, DIFFERENT RESULT! ❌
# res = Nx.dot(wei, x)
# WRONG, DIFFERENT RESULT! ❌
# res = Nx.dot(wei, x) |> Nx.reshape({b, t, c})
# Instead, add the batch dimension to wei by broadcasting it explicitly:
wei_with_batches = Nx.broadcast(wei, {b, t, t})
# #Nx.Tensor<
# f32[4][8][8]
# EXLA.Backend<host:0, 0.668616701.3656777744.246806>
# [
# [
# [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.3333333432674408, 0.3333333432674408, 0.3333333432674408, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0],
# [0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.0, 0.0, 0.0],
# [0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.1666666716337204, 0.0, 0.0],
# [0.1428571492433548, 0.1428571492433548, ...],
# ...
# ],
# [
# [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
# [0.3333333432674408, 0.3333333432674408, 0.3333333432674408, 0.0, 0.0, 0.0, 0.0, 0.0],
# ...
# ],
# ...
# ]
# >
# WRONG, DIFFERENT RESULT! ❌
# res = Nx.dot(wei_with_batches, x)
# CORRECT! SAME RESULT! ✅
res = Nx.dot(wei_with_batches, [2], [0], x, [1], [0])
# #Nx.Tensor<
# f32[4][8][2]
# EXLA.Backend<host:0, 0.668616701.3656777744.246807>
# [
# [
# [0.0, 1.0],
# [1.0, 2.0],
# [2.0, 3.0],
# [3.0, 4.0],
# [4.0, 5.0],
# [5.0, 6.000000476837158],
# [6.000000476837158, 7.000000476837158],
# [7.0, 8.0]
# ],
# [
# [16.0, 17.0],
# [17.0, 18.0],
# [18.0, 19.000001907348633],
# [19.0, 20.0],
# [20.0, 21.0],
# [21.000001907348633, 22.0],
# [22.000001907348633, 23.0],
# [23.0, 24.0]
# ],
# [
# [32.0, 33.0],
# [33.0, 34.0],
# [34.0, 35.0],
# [35.0, 36.0],
# [36.0, 37.0],
# [37.0, 38.0],
# [38.0, 39.000003814697266],
# [39.0, 40.0]
# ],
# [
# [48.0, 49.0],
# ...
# ]
# ]
# >
To recap:
- x shape is (B, T, C)
- wei shape is (T, T)
- wei_with_batches shape is (B, T, T)
- res shape is (B, T, C)
In particular, I needed to use the Nx.dot/6 function, specifying the contracting and batch axes:

Nx.dot(wei_with_batches, [2], [0], x, [1], [0])

where the contracting axes are:
- the 2nd T in wei_with_batches ((B, T, T)[2])
- the T in x ((B, T, C)[1])

and the batch axes are:
- B in both wei_with_batches and x
Quoting the docs:

The dot product is computed by multiplying the values from t1 given by contract_axes1 against the values from t2 given by contract_axes2, considering batch axes of batch_axes1 and batch_axes2. For instance, the first axis in contract_axes1 will be matched against the first axis in contract_axes2 and so on.

The axes given by contract_axes1 and contract_axes2 are effectively removed from the final tensor, which is why they are often called the contraction axes.

Specifying batch axes will compute a vectorized dot product along the given batch dimensions. The length of batch_axes1 and batch_axes2 must match.
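In einsum terms (a sketch on the Python side, reusing the variables from the first snippet; x_f and wei_b are my own helper names), the same choice of axes reads: the batch axis b is carried through, and the contracted axis k disappears from the result.

# res[b][i][c] = sum over k of wei_with_batches[b][i][k] * x[b][k][c]
x_f = torch.as_tensor(x, dtype=wei.dtype)              # (B, T, C)
wei_b = torch.broadcast_to(wei, (B, T, T))             # (B, T, T)
res_einsum = torch.einsum('bik,bkc->bic', wei_b, x_f)  # b: batch axis, k: contracted axis
print(torch.allclose(res_einsum, wei @ x_f))           # True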
So I guess the final res tensor is (B, T, C) because, keeping the batch dimension aside for a moment, (T, T) @ (T, C) = (T, C), and adding back the B batch dimension gives (B, T, C) (kinda makes sense, but I think there is a better way to explain it).
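One quick way to double-check that view, again as a sketch on the Python side (x_f is my own helper name): do the plain (T, T) @ (T, C) product for each batch slice separately and stack the results.

x_f = torch.as_tensor(x, dtype=wei.dtype)                  # (B, T, C)
per_batch = torch.stack([wei @ x_f[i] for i in range(B)])  # B independent (T, T) @ (T, C) products
print(per_batch.shape)                                     # torch.Size([4, 8, 2])
print(torch.allclose(per_batch, wei @ x_f))                # True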
To conclude, I’d like to understand whether my reasoning above makes sense and whether this is the correct way to replicate the Python snippet, or if there is a better way to do it.
Thank youuuu