`numpy.take` equivalent for Nx

Is there a numpy.take equivalent for Nx?

I currently have two tensors:

```elixir
t = Nx.tensor([4, 3, 5, 7, 6, 8])
indices = Nx.tensor([5, 3, 4])
```

I am looking for a way to get the values located at those indices without converting my tensor into a list.

Something like this:

```elixir
iex> Nx.take(t, indices)
#Nx.Tensor<
  u64[3]
  [8, 7, 6]
>

# Or
iex> t[indices]
#Nx.Tensor<
  u64[3]
  [8, 7, 6]
>
```

I took a look at the source code. Tensors seem to support the syntax `tensor[[index1, index2]]`, but from what I understand, for a 2-dimensional tensor this syntax returns the single value located at `tensor[index1][index2]` rather than a tensor of the values at each index.
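To make the list-access behavior concrete, here is a minimal sketch on a small 2-D tensor. It assumes a recent Nx release installed via `Mix.install`; the tensor `m` is a hypothetical example, not from the original question. A list inside the brackets indexes successive axes, so it selects one element instead of gathering several.

```elixir
Mix.install([:nx])

# Hypothetical 2x2 tensor, used only to illustrate Access semantics
m = Nx.tensor([[1, 2], [3, 4]])

# A list inside [] indexes successive axes: row 1, then column 0
a = m[[1, 0]]
# Chained access reaches the same element
b = m[1][0]

IO.inspect(Nx.to_number(a))
IO.inspect(Nx.to_number(b))
```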

Thanks in advance for your help.

Coming soon: "Implement Nx.take/3 backed by xla::Gather" by jonatanklosko, Pull Request #433 in elixir-nx/nx on GitHub 🙂
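Once `Nx.take/3` is available, the original example could look like the minimal sketch below. It assumes an Nx version that ships `Nx.take/3` (which takes an `:axis` option, defaulting to axis 0). `Nx.gather/2` is shown as an alternative; it expects each index as a full coordinate, hence the reshape to `{3, 1}`. The printed integer type (`s64`, `u64`, etc.) can differ between Nx versions, so the sketch prints plain lists.

```elixir
Mix.install([:nx])

t = Nx.tensor([4, 3, 5, 7, 6, 8])
indices = Nx.tensor([5, 3, 4])

# Nx.take/3 picks entries of t along axis 0 at the given indices
taken = Nx.take(t, indices)
IO.inspect(Nx.to_flat_list(taken))

# Nx.gather/2 treats the last dimension of indices as a coordinate,
# so reshape the 1-D indices into shape {3, 1} for a 1-D source tensor
gathered = Nx.gather(t, Nx.reshape(indices, {3, 1}))
IO.inspect(Nx.to_flat_list(gathered))
```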


Perfect, thanks!