Is there a numpy.take equivalent for Nx?
I currently have two tensors:
    t = Nx.tensor([4, 3, 5, 7, 6, 8])
    indices = Nx.tensor([5, 3, 4])
I am looking for a way to get the values located at those indices without converting my tensor into a list.
Something like this:
    iex> Nx.take(t, indices)
    #Nx.Tensor<
      u64
      [8, 7, 6]
    >

    # Or

    iex> t[indices]
    #Nx.Tensor<
      u64
      [8, 7, 6]
    >
I took a look at the source code. Tensors seem to support the syntax tensor[[index1, index2]], but from what I understand, for a 2-dimensional tensor this syntax returns the single value located at tensor[index1][index2], not the values at positions index1 and index2.
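A minimal sketch of the list-index behavior described above, assuming a recent Nx release. The `~> 0.5` version constraint and the 2-D tensor `m` are illustrative assumptions, not part of the original question. It shows that a list of integers inside the brackets indexes successive axes and returns one element, the same one reached by chaining the accesses.

```elixir
# Assumption: run as a standalone .exs script; Mix.install fetches Nx.
# (Inside an existing Mix project, add :nx to deps instead and drop this line.)
Mix.install([{:nx, "~> 0.5"}])

# Hypothetical 2-D tensor, used only to illustrate the access syntax.
m = Nx.tensor([[4, 3, 5], [7, 6, 8]])

# Each entry of the list consumes one axis: row 1, then column 2.
a = m[[1, 2]]

# Chained access reaches the same element.
b = m[1][2]

IO.inspect(Nx.to_number(a)) # 8
IO.inspect(Nx.to_number(b)) # 8
```

Because each list entry is applied to the next axis, the result is a single scalar element, which matches the reading of tensor[[index1, index2]] given in the question.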
Thanks in advance for your help.