Hey all, newbie here with a newbie question
I’m wondering what’s the easiest way to get the first n columns of a multi-dimensional tensor.
For instance, if I have a two-dimensional (3x3) tensor like this:
```elixir
t = Nx.tensor([
  [0, 1, 2],
  [3, 4, 5],
  [6, 7, 8]
])
#Nx.Tensor<
  s64[3][3]
  [
    [0, 1, 2],
    [3, 4, 5],
    [6, 7, 8]
  ]
>
```
How can I get the first 2 columns, laid out as rows? I mean `[[0, 3, 6], [1, 4, 7]]`.
A simple solution is to transpose the tensor and then take the first 2 rows, like:
```elixir
Nx.transpose(t)[0..1]
#Nx.Tensor<
  s64[2][3]
  [
    [0, 3, 6],
    [1, 4, 7]
  ]
>
```
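For comparison, a sketch that slices the columns first and transposes afterwards. It assumes Elixir 1.12+ (for `Mix.install`) and a recent Nx release from Hex; the `~> 0.7` version requirement is an assumption. `Nx.slice` takes a start index and a length per axis, while `Nx.slice_along_axis` does the same along a single axis:

```elixir
# Assumption: Nx is fetched from Hex; the version requirement is illustrative.
Mix.install([{:nx, "~> 0.7"}])

t = Nx.tensor([
  [0, 1, 2],
  [3, 4, 5],
  [6, 7, 8]
])

# All 3 rows, first 2 columns: start at [0, 0], lengths [3, 2].
# cols has shape {3, 2}: [[0, 1], [3, 4], [6, 7]]
cols = Nx.slice(t, [0, 0], [3, 2])

# The same slice expressed along axis 1 (the column axis) only.
cols_alt = Nx.slice_along_axis(t, 0, 2, axis: 1)

# Transposing puts each column on its own row: [[0, 3, 6], [1, 4, 7]]
first_two_columns = Nx.transpose(cols)
IO.inspect(first_two_columns)
```

A transpose is still needed to lay the columns out as rows, so this is not really shorter than transposing first; it only transposes the smaller, already sliced tensor.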
But I was wondering whether there is a more straightforward way to do that.
I’m asking because I’ve seen that in Python (NumPy) you can do that with `array[0:2]`.
This is just curiosity; it’s not based on any particular need.
```
❯ python3
>>> import numpy as np
>>> x1 = [1, 2, 3]
>>> x2 = [4, 5, 6]
>>> x3 = [7, 8, 9]
>>> X = np.column_stack([x1, x2, x3])
>>> X
array([[1, 4, 7],
       [2, 5, 8],
       [3, 6, 9]])
>>> X[0:2]
array([[1, 4, 7],
       [2, 5, 8]])
```
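Worth noting: in NumPy, `X[0:2]` selects the first two *rows* of `X`. It only returns the original vectors because `np.column_stack` placed `x1`, `x2` and `x3` as columns (the first two columns would be `X[:, 0:2]`). A sketch of the same steps in Nx, assuming a recent Nx release from Hex (the version requirement is an assumption): `Nx.stack` with `axis: 1` plays the role of `column_stack`, and a plain range then behaves like `X[0:2]`:

```elixir
# Assumption: Nx is fetched from Hex; the version requirement is illustrative.
Mix.install([{:nx, "~> 0.7"}])

x1 = Nx.tensor([1, 2, 3])
x2 = Nx.tensor([4, 5, 6])
x3 = Nx.tensor([7, 8, 9])

# Stacking along a new axis 1 makes each vector a column, like np.column_stack.
# x is [[1, 4, 7], [2, 5, 8], [3, 6, 9]]
x = Nx.stack([x1, x2, x3], axis: 1)

# A range on the first axis takes the first 2 rows, like X[0:2] in NumPy.
first_two = x[0..1]
IO.inspect(first_two)
```

So with the data built the same way, `x[0..1]` in Nx and `X[0:2]` in NumPy return the same thing.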
Thank you all in advance
Cheers