Possible bug in Nx.subtract()

The following lines work as expected. Both return -1:

Nx.subtract(2, 3)  # -1
Nx.subtract(Nx.tensor(2), 3) # -1

Output:

#Nx.Tensor<
  s32
  -1
>

However, the following block returns 4294967295:

bs = 3
y_real = Nx.tensor([1, 2, 3])
y_pred = Nx.tensor([0, 2, 3])

# [0, 1, 1]
Nx.equal(y_real, y_pred)
# 2
|> Nx.sum()
# 4294967295
|> Nx.subtract(bs)

Output:

#Nx.Tensor<
  u32
  4294967295
>

I got the same result in EXLA.

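One is signed (s32), the other is unsigned (u32). The underlying bits are identical: Nx.equal returns an unsigned tensor, so the sum stays unsigned, and 2 - 3 wraps around to 4294967295 instead of giving -1. I guess you will have to be explicit about your tensor types.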
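As a rough sketch of what being explicit could look like, the pipeline from the question can cast the sum to a signed type before subtracting. The `Mix.install` line and the Nx version constraint are my assumptions so the snippet runs as a standalone script; `{:s, 32}` is picked only to match the s32 output shown earlier. The last lines use plain Elixir binaries to show that -1 and 4294967295 share the same 32 bits.

```elixir
# Assumption: run as a standalone script; the version constraint is illustrative.
Mix.install([{:nx, "~> 0.7"}])

bs = 3
y_real = Nx.tensor([1, 2, 3])
y_pred = Nx.tensor([0, 2, 3])

result =
  y_real
  # [0, 1, 1], stored as an unsigned tensor
  |> Nx.equal(y_pred)
  # 2, still unsigned
  |> Nx.sum()
  # cast to a signed type so the subtraction can go below zero
  |> Nx.as_type({:s, 32})
  # -1
  |> Nx.subtract(bs)

IO.inspect(result)

# Plain Elixir, no Nx: encode -1 as a signed 32-bit integer,
# then read the same bytes back as unsigned.
<<as_unsigned::unsigned-integer-size(32)>> = <<-1::signed-integer-size(32)>>
IO.inspect(as_unsigned)
# 4294967295
```

Casting right after `Nx.sum()` keeps the comparison itself untouched and only changes the type where negative values become possible.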
