Nx.dot doesn't broadcast in general (although some cases, such as dot(scalar, tensor), do).
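That being said, you can mimic broadcasting by combining Nx.new_axis and Nx.tile on your weights tensor, so that it gets an explicit batch axis, and then calling the batched form of Nx.dot (Nx.dot/6).
That is, given:
w -> {T, T}
x -> {B, T, C}
then:
w2 = Nx.new_axis(w, 0) -> {1, T, T}
w3 = Nx.tile(w2, [B, 1, 1]) -> {B, T, T}
Nx.dot(w3, [2], [0], x, [1], [0]) -> {B, T, C}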
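A minimal runnable sketch of the tiling approach, assuming it is run as a standalone script (e.g. `elixir tile.exs`) with network access so `Mix.install` can fetch Nx from Hex. The sizes B = 2, T = 3, C = 4 and the `w_batched` name are arbitrary choices for illustration. The tiled weights are consumed by the batched Nx.dot/6, which contracts axis 2 of the weights with axis 1 of x while batching over axis 0 of both.

```elixir
# Assumption: standalone script with Hex access; any Nx >= 0.6 should do.
Mix.install([{:nx, "~> 0.6"}])

# Hypothetical sizes, chosen only for illustration
{b, t, c} = {2, 3, 4}

w = Nx.iota({t, t})       # {T, T} weights shared across the batch
x = Nx.iota({b, t, c})    # {B, T, C} batched input

# Give w an explicit batch axis and repeat it B times:
# {T, T} -> {1, T, T} -> {B, T, T}
w_batched =
  w
  |> Nx.new_axis(0)
  |> Nx.tile([b, 1, 1])

# Batched dot: contract axis 2 of w_batched with axis 1 of x,
# treating axis 0 of both tensors as the batch axis.
result = Nx.dot(w_batched, [2], [0], x, [1], [0])

IO.inspect(Nx.shape(result))  # prints {2, 3, 4}
```

Each batch entry `result[i]` is just `Nx.dot(w, x[i])`; the cost is that `w` is physically copied B times.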
Alternatively, since Nx 0.6 you can use vectorization, which makes the batched case much easier to handle:
x -> {B, T, C}
w -> {K, T}
vec_x = Nx.vectorize(x, :some_name) # vectorized[some_name: B] {T, C}
vec_result = Nx.dot(w, [1], vec_x, [0]) # vectorized[some_name: B] {K, C}
# Note: vec_result = Nx.dot(w, vec_x) returns the same result :)
result = Nx.devectorize(vec_result, keep_names: false) # {B, K, C}
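A self-contained version of the vectorized snippet, again assuming a standalone script with Hex access. The sizes (B = 2, T = 3, C = 4, K = 5) and the vectorized axis name `:batch` are hypothetical. Note that the contraction axes passed to Nx.dot refer to the inner (non-vectorized) shape `{T, C}` of `vec_x`, which is why axis 0 is the T axis here.

```elixir
# Assumption: standalone script with Hex access; any Nx >= 0.6 should do.
Mix.install([{:nx, "~> 0.6"}])

# Hypothetical sizes, chosen only for illustration
{b, t, c, k} = {2, 3, 4, 5}

x = Nx.iota({b, t, c})   # {B, T, C}
w = Nx.iota({k, t})      # {K, T}

# Turn the leading axis of x into a vectorized axis named :batch.
vec_x = Nx.vectorize(x, :batch)             # vectorized[batch: B] {T, C}

# Axes refer to the inner shape, so axis 0 of vec_x is the T axis.
vec_explicit = Nx.dot(w, [1], vec_x, [0])   # vectorized[batch: B] {K, C}
vec_default = Nx.dot(w, vec_x)              # same contraction via Nx.dot/2

# Move the vectorized axis back into the shape, dropping its name.
result = Nx.devectorize(vec_explicit, keep_names: false)  # {B, K, C}

IO.inspect(Nx.shape(result))  # prints {2, 5, 4}
```

Unlike the tiling approach, nothing about `w` changes: the per-batch-entry computation is written as if x were a single `{T, C}` matrix.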
Example with values:
iex(5)> x = Nx.iota({4, 2, 3})
#Nx.Tensor<
s64[4][2][3]
[
[
[0, 1, 2],
[3, 4, 5]
],
[
[6, 7, 8],
[9, 10, 11]
],
[
[12, 13, 14],
[15, 16, 17]
],
[
[18, 19, 20],
[21, 22, 23]
]
]
>
iex(6)> w = Nx.iota({5, 2})
#Nx.Tensor<
s64[5][2]
[
[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]
]
>
iex(7)> Nx.dot(w, Nx.vectorize(x, :some_name))
#Nx.Tensor<
vectorized[some_name: 4]
s64[5][3]
[
[
[3, 4, 5],
[9, 14, 19],
[15, 24, 33],
[21, 34, 47],
[27, 44, 61]
],
[
[9, 10, 11],
[39, 44, 49],
[69, 78, 87],
[99, 112, 125],
[129, 146, 163]
],
[
[15, 16, 17],
[69, 74, 79],
[123, 132, 141],
[177, 190, 203],
[231, 248, 265]
],
[
[21, 22, 23],
[99, 104, ...],
...
]
]
>
iex(8)> |> Nx.devectorize(keep_names: false)
#Nx.Tensor<
s64[4][5][3]
[
[
[3, 4, 5],
[9, 14, 19],
[15, 24, 33],
[21, 34, 47],
[27, 44, 61]
],
[
[9, 10, 11],
[39, 44, 49],
[69, 78, 87],
[99, 112, 125],
[129, 146, 163]
],
[
[15, 16, 17],
[69, 74, 79],
[123, 132, 141],
[177, 190, 203],
[231, 248, 265]
],
[
[21, 22, 23],
[99, 104, ...],
...
]
]
>
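To cross-check that both approaches agree on these exact values, a small sketch (assuming a standalone script with Hex access) that rebuilds `x` and `w` with Nx.iota and compares the devectorized result against the tile + Nx.dot/6 version:

```elixir
# Assumption: standalone script with Hex access; any Nx >= 0.6 should do.
Mix.install([{:nx, "~> 0.6"}])

x = Nx.iota({4, 2, 3})   # same values as x in the session above
w = Nx.iota({5, 2})

# Vectorized approach, as in the IEx session
vectorized =
  w
  |> Nx.dot(Nx.vectorize(x, :some_name))
  |> Nx.devectorize(keep_names: false)   # {4, 5, 3}

# Tiling approach: explicit batch axis on w, then a batched dot
tiled =
  w
  |> Nx.new_axis(0)                      # {1, 5, 2}
  |> Nx.tile([4, 1, 1])                  # {4, 5, 2}
  |> Nx.dot([2], [0], x, [1], [0])       # {4, 5, 3}

IO.inspect(Nx.to_flat_list(vectorized) == Nx.to_flat_list(tiled))  # prints true
```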