The following lines work as expected. Both return -1:
```elixir
Nx.subtract(2, 3)             # -1
Nx.subtract(Nx.tensor(2), 3)  # -1
```
Output:
```
#Nx.Tensor<
  s32
  -1
>
```
However, the following block returns 4294967295 instead of -1:
```elixir
bs = 3
y_real = Nx.tensor([1, 2, 3])
y_pred = Nx.tensor([0, 2, 3])

Nx.equal(y_real, y_pred)  # [0, 1, 1]
|> Nx.sum()               # 2
|> Nx.subtract(bs)        # 4294967295
```
Output:
```
#Nx.Tensor<
  u32
  4294967295
>
```
I got the same result with EXLA.
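For reference, 4294967295 is 2^32 - 1, which is exactly what -1 becomes when its 32 bits are reinterpreted as an unsigned integer. That is consistent with the `u32` type shown in the second output. The sketch below uses plain Elixir binaries (no Nx involved), so it only illustrates the arithmetic, not Nx's own type rules:

```elixir
# Plain Elixir, no Nx: compute 2 - 3, store the result in 32 bits as a signed
# integer, then read those same 32 bits back as an unsigned integer. This mimics
# what ends up in a u32 tensor when the result is negative.
difference = 2 - 3
<<as_u32::unsigned-integer-size(32)>> = <<difference::signed-integer-size(32)>>

IO.inspect(as_u32)
# Prints 4294967295, which equals 2^32 - 1
IO.inspect(as_u32 == Integer.pow(2, 32) - 1)
```

This suggests the subtraction is carried out in the `u32` type of the summed tensor instead of being promoted to a signed type, so the negative result wraps around.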