Understanding `Nx.dot/6`: contracting and batch axes

`Nx.dot` generally doesn't apply broadcasting (although some cases do, such as `dot(scalar, tensor)`).
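A quick way to see this is the following minimal sketch. It assumes Nx 0.6 or later installed via `Mix.install`, and the shapes are arbitrary illustrations. A scalar operand simply scales the other tensor. A `{T, T}` by `{B, T, C}` call, however, is not broadcast over `B`. As far as I know, `Nx.dot/2` follows the NumPy-style rule: it contracts the last axis of the first tensor with the second-to-last axis of the second, so the batch axis ends up in the middle of the result.

```elixir
Mix.install([{:nx, "~> 0.6"}])

# Scalar operand: behaves like Nx.multiply/2, i.e. it is broadcast.
t = Nx.tensor([1, 2, 3])
scaled = Nx.dot(2, t)

# Illustrative shapes: T = 2, B = 4, C = 3.
w = Nx.iota({2, 2})     # {T, T}
x = Nx.iota({4, 2, 3})  # {B, T, C}

# No broadcasting over B: the last axis of w is contracted with the
# second-to-last axis of x, so the result has shape {T, B, C}.
plain = Nx.dot(w, x)

IO.inspect(Nx.to_flat_list(scaled), label: "Nx.dot(2, t)")
IO.inspect(Nx.shape(plain), label: "Nx.dot(w, x) shape")
```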

That being said, you can combine `Nx.new_axis` and `Nx.tile` on your weights tensor to mimic broadcasting.

That is:

```elixir
# w: {T, T}
# x: {B, T, C}
{b, _t, _c} = Nx.shape(x)

w2 = Nx.new_axis(w, 0)            # {1, T, T}
w_tiled = Nx.tile(w2, [b, 1, 1])  # {B, T, T}
```
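To finish the computation, the tiled weights can be contracted against `x` with `Nx.dot/6`, declaring axis 0 of both tensors as a batch axis. Here is a minimal sketch. It uses made-up sizes (`B = 4`, `T = 2`, `C = 3`) with `Nx.iota` standing in for real data, and it assumes Nx 0.6 or later.

```elixir
Mix.install([{:nx, "~> 0.6"}])

# Illustrative data: x is {B, T, C} = {4, 2, 3}, w is {T, T} = {2, 2}.
x = Nx.iota({4, 2, 3})
w = Nx.iota({2, 2})

{b, _t, _c} = Nx.shape(x)

# {T, T} -> {1, T, T} -> {B, T, T}: one copy of w per batch entry.
w_tiled = w |> Nx.new_axis(0) |> Nx.tile([b, 1, 1])

# Contract axis 2 of w_tiled with axis 1 of x (the T axes) and treat
# axis 0 of both as the batch axis. Batch axes come first in the output.
result = Nx.dot(w_tiled, [2], [0], x, [1], [0])  # {B, T, C}

IO.inspect(Nx.shape(result), label: "result shape")
```

Each `result[i]` should match `Nx.dot(w, x[i])`. The trade-off is that `Nx.tile` may materialize `B` copies of `w`, depending on the backend.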

However, Nx 0.6 introduced vectorization, which makes the batched case much easier to handle:

```elixir
# x: {B, T, C}
# w: {K, T}
vec_x = Nx.vectorize(x, :some_name)                     # vectorized[some_name: B] {T, C}
vec_result = Nx.dot(w, [1], vec_x, [0])                 # vectorized[some_name: B] {K, C}
# Note: Nx.dot(w, vec_x) returns the same result :)
result = Nx.devectorize(vec_result, keep_names: false)  # {B, K, C}
```
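As a sanity check, the following minimal sketch runs the vectorized snippet on illustrative data and compares it with the tile-plus-batch-axes approach. It assumes Nx 0.6 or later. `reference` and `all_equal` are just local names chosen for this comparison.

```elixir
Mix.install([{:nx, "~> 0.6"}])

# Illustrative data: x is {B, T, C} = {4, 2, 3}, w is {K, T} = {5, 2}.
x = Nx.iota({4, 2, 3})
w = Nx.iota({5, 2})

# Vectorized approach: every entry along B is a {T, C} matrix.
vec_x = Nx.vectorize(x, :some_name)
vec_result = Nx.dot(w, [1], vec_x, [0])
result = Nx.devectorize(vec_result, keep_names: false)  # {B, K, C}

# Reference: tile w to {B, K, T} and use axis 0 as a batch axis in Nx.dot/6.
{b, _t, _c} = Nx.shape(x)
w_tiled = w |> Nx.new_axis(0) |> Nx.tile([b, 1, 1])
reference = Nx.dot(w_tiled, [2], [0], x, [1], [0])      # {B, K, C}

# Nx.all/1 yields 1 only if every element compared by Nx.equal/2 matches.
all_equal = Nx.all(Nx.equal(result, reference)) |> Nx.to_number()
IO.inspect(all_equal, label: "vectorized == tiled")
```

The vectorized version avoids building the tiled copy of `w` by hand, while producing the same `{B, K, C}` tensor.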

Example with values (here `x` is `Nx.iota({4, 2, 3})`):

iex(5)> x            
#Nx.Tensor<
  s64[4][2][3]
  [
    [
      [0, 1, 2],
      [3, 4, 5]
    ],
    [
      [6, 7, 8],
      [9, 10, 11]
    ],
    [
      [12, 13, 14],
      [15, 16, 17]
    ],
    [
      [18, 19, 20],
      [21, 22, 23]
    ]
  ]
>
iex(6)> w = Nx.iota({5, 2})
#Nx.Tensor<
  s64[5][2]
  [
    [0, 1],
    [2, 3],
    [4, 5],
    [6, 7],
    [8, 9]
  ]
>
iex(7)> Nx.dot(w, Nx.vectorize(x, :some_name))
#Nx.Tensor<
  vectorized[some_name: 4]
  s64[5][3]
  [
    [
      [3, 4, 5],
      [9, 14, 19],
      [15, 24, 33],
      [21, 34, 47],
      [27, 44, 61]
    ],
    [
      [9, 10, 11],
      [39, 44, 49],
      [69, 78, 87],
      [99, 112, 125],
      [129, 146, 163]
    ],
    [
      [15, 16, 17],
      [69, 74, 79],
      [123, 132, 141],
      [177, 190, 203],
      [231, 248, 265]
    ],
    [
      [21, 22, 23],
      [99, 104, ...],
      ...
    ]
  ]
>
iex(8)> |> Nx.devectorize(keep_names: false)
#Nx.Tensor<
  s64[4][5][3]
  [
    [
      [3, 4, 5],
      [9, 14, 19],
      [15, 24, 33],
      [21, 34, 47],
      [27, 44, 61]
    ],
    [
      [9, 10, 11],
      [39, 44, 49],
      [69, 78, 87],
      [99, 112, 125],
      [129, 146, 163]
    ],
    [
      [15, 16, 17],
      [69, 74, 79],
      [123, 132, 141],
      [177, 190, 203],
      [231, 248, 265]
    ],
    [
      [21, 22, 23],
      [99, 104, ...],
      ...
    ]
  ]
>