Is there a numpy.take equivalent for Nx?
I currently have two tensors:
t = Nx.tensor([4, 3, 5, 7, 6, 8])
indices = Nx.tensor([5, 3, 4])
I am looking for a way to get the values of t located at the given indices without having to convert my tensor into a list.
Something like this:
iex> Nx.take(t, indices)
#Nx.Tensor<
u64[3]
[8, 7, 6]
>
# Or
iex> t[indices]
#Nx.Tensor<
u64[3]
[8, 7, 6]
>
I took a look at the source code. Tensors seem to support the syntax tensor[[index1, index2]],
but from what I understand, for a 2-dimensional tensor this syntax returns the single value located at tensor[index1][index2],
not the values at positions index1 and index2.
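To illustrate what I mean, here is a minimal sketch of the behaviour as I understand it. The version constraint in Mix.install and the tensor m are only assumptions for the example, not something taken from the Nx source:

```elixir
# Assumption: a recent Nx release; adjust the version constraint as needed.
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical 2-dimensional tensor used only for this illustration.
m = Nx.tensor([[1, 2, 3], [4, 5, 6]])

# A list inside the brackets supplies one index per axis,
# so this selects row 1, column 2 ...
a = m[[1, 2]]

# ... which is the same element as chained access.
b = m[1][2]

IO.inspect(Nx.to_number(a))
IO.inspect(Nx.to_number(b))
```

Both lookups give me the same scalar, so the list is treated as one coordinate per axis and not as a set of positions to gather along a single axis.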
Thanks in advance for your help