How to create a list of slices in Nx?

Let’s say I have a 1D tensor as follows:

x = Nx.tensor([1, 2, 3, 4, 5])

Along with a 1D tensor of indexes:

ix = Nx.tensor([0, 2, 3])

From this, how could I create a new 2D tensor, where each row is a 2-element slice of x, starting at each index in ix?

Like this:

Nx.tensor([
  [1, 2],
  [3, 4],
  [4, 5]
])

I’ve taken a look at Nx.slice(), Nx.take(), and Nx.gather(), but I can’t quite get the exact shape I want with any of them.

Would love some feedback from the community!

Hey @megaboy101 :wave:

I have a working solution, but I think it is not the most idiomatic one and probably not the one you are looking for :sweat:

iex(1)> Mix.install([:nx])
:ok
iex(2)> x = Nx.tensor([1, 2, 3, 4, 5])
#Nx.Tensor<
  s64[5]
  [1, 2, 3, 4, 5]
>
iex(3)> ix = Nx.tensor([0, 2, 3])
#Nx.Tensor<
  s64[3]
  [0, 2, 3]
>
iex(4)> ix |> Nx.to_flat_list() |> Enum.map(fn start_index -> x[start_index..start_index + 1//1] end) |> Nx.stack()
#Nx.Tensor<
  s64[3][2]
  [
    [1, 2],
    [3, 4],
    [4, 5]
  ]
>

Note that x[start_index..start_index + 1//1] is the same as doing Nx.slice(x, [start_index], [2]).
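For comparison, here is the same pipeline written with Nx.slice in place of the range access. It is only a restatement of the approach in the iex session, assuming Nx is installed the same way with Mix.install.

```elixir
Mix.install([:nx])

x = Nx.tensor([1, 2, 3, 4, 5])
ix = Nx.tensor([0, 2, 3])

# Each start index becomes a length-2 slice of x; Nx.stack then
# turns the list of 1D slices into a {3, 2} tensor.
result =
  ix
  |> Nx.to_flat_list()
  |> Enum.map(fn start_index -> Nx.slice(x, [start_index], [2]) end)
  |> Nx.stack()
```

Because the indexes go through a plain Elixir list (Nx.to_flat_list), this version cannot be used inside a defn.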

But I guess there are better ways to achieve that :thinking:

Out of curiosity, are you trying to replicate a snippet from NumPy or something else? If so, maybe you can share it; it could be helpful.

Cheers :slight_smile:

You can create an n × 2 tensor with Nx.reshape:

x = Nx.tensor([1, 2, 3, 4, 5])
Nx.reshape(x, {:auto, 2})

I don’t quite understand what you need the indices for. Could you expand on that requirement?

EDIT: I just realised that your tensor has an odd number of elements. Please ignore this comment for now; it is incorrect.

Hello again :slight_smile:

I might have found a better approach that does not need you to go outside the Nx domain.

x = Nx.tensor([1, 2, 3, 4, 5])
ix = Nx.tensor([0, 2, 3])

# You compose a tensor with all the indexes and
# then you can use Nx.take

ix_2 = Nx.add(ix, 1)
#Nx.Tensor<
  s64[3]
  [1, 3, 4]
>

all_indexes = Nx.stack([ix, ix_2], axis: 1)
#Nx.Tensor<
  s64[3][2]
  [
    [0, 1],
    [2, 3],
    [3, 4]
  ]
>

Nx.take(x, all_indexes)
#Nx.Tensor<
  s64[3][2]
  [
    [1, 2],
    [3, 4],
    [4, 5]
  ]
>
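Nx.add(ix, 1) fixes the slice length at 2. A minimal sketch of the same idea for an arbitrary length, assuming every start index leaves room for the whole slice: broadcast a column of start indexes against a row of offsets from Nx.iota. Since Nx.gather was mentioned in the question, the block also shows that the same index matrix works with Nx.gather once a trailing axis of size 1 is added. The variable name len is illustrative, not from the thread.

```elixir
Mix.install([:nx])

x = Nx.tensor([1, 2, 3, 4, 5])
ix = Nx.tensor([0, 2, 3])
# Slice length; `len` is an illustrative name, not part of the original post.
len = 2

# {3, 1} column of start indexes + {1, len} row of offsets [0, 1, ...]
# broadcasts to a {3, len} matrix: row i is ix[i], ix[i] + 1, ...
# Assumes every start index leaves room for `len` elements.
all_indexes = Nx.add(Nx.new_axis(ix, 1), Nx.iota({1, len}))

# Same Nx.take call as in the post.
taken = Nx.take(x, all_indexes)

# Nx.gather reads the last axis of the indexes as coordinates into x,
# so a trailing axis of size 1 ({3, len, 1}) yields the same {3, len} result.
gathered = Nx.gather(x, Nx.new_axis(all_indexes, 2))
```

Everything here stays inside Nx, so the same lines should also work in a defn, as long as len is known at compile time (for example, passed as an option).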

Bye :slight_smile:

NxSignal.as_windowed might also help, although you might need to play around with padding.

However, I suspect reshape is the way to go and the example just wasn't worked out properly.

Edit: never mind, I think you need a combination of Nx.tile and Nx.slice_along_axis to accomplish this.

Try using NxSignal.as_windowed to get the slices, and then Nx.take along the first axis to pick the slices according to your index tensor.
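A plain-Nx sketch of this "window, then take" idea. To avoid assuming NxSignal's exact options, the windows are built here with Nx.iota and Nx.take instead of NxSignal.as_windowed (stride 1, no padding). With NxSignal installed, as_windowed should give an equivalent windows tensor, but check its docs for the option names.

```elixir
Mix.install([:nx])

x = Nx.tensor([1, 2, 3, 4, 5])
ix = Nx.tensor([0, 2, 3])
# Window length; illustrative name, not from the thread.
len = 2

# Step 1: every length-`len` window of x, stride 1, no padding.
# Row w holds the indexes w, w + 1, ..., w + len - 1.
{n} = Nx.shape(x)
n_windows = n - len + 1
window_indexes = Nx.add(Nx.iota({n_windows, 1}), Nx.iota({1, len}))
windows = Nx.take(x, window_indexes)

# Step 2: keep only the windows that start at the requested indexes.
result = Nx.take(windows, ix, axis: 0)
```

This builds all n - len + 1 windows even when only a few are needed, so taking the slices directly from a matrix of indexes does less work. The windowed form is handy if you already have the windows for other reasons.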