How to use all values of an axis as parameters of a function?

Suppose I have a tensor like this:

f32[3][3][2]
[
    [
        [0.0, 0.0],
        [0.0, 1.0],
        [0.0, 2.0]
    ],
    [
        [1.0, 0.0],
        [1.0, 1.0],
        [1.0, 2.0]
    ],
    [
        [2.0, 0.0],
        [2.0, 1.0],
        [2.0, 2.0]
    ]
]

Is there a way to pass the two values along axis 2 as the arguments of a function, and collect the returned values into a tensor of shape f32[3][3]?

I tried Nx.reduce, but it seems to iterate along the axis one element at a time instead of giving me both values at once, right?

This will depend a lot on what you’re trying to achieve.
Some functions, such as Nx.sum, Nx.reduce_max and Nx.product, accept an :axes option. You can check examples here: https://github.com/elixir-nx/nx/blob/main/nx/guides/tensor-aggregation-101.livemd
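For the tensor in the question, a minimal sketch of that approach (assuming Nx 0.7 or later, installed here with Mix.install) reduces over axis 2, which collapses each pair into one value and leaves a {3, 3} result:

```elixir
Mix.install([{:nx, "~> 0.7"}])

# The {3, 3, 2} tensor from the question: element [i][j] is the pair [i, j]
t =
  Nx.tensor(
    [
      [[0.0, 0.0], [0.0, 1.0], [0.0, 2.0]],
      [[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]],
      [[2.0, 0.0], [2.0, 1.0], [2.0, 2.0]]
    ],
    type: :f32
  )

# Reducing over axis 2 combines the two values of each pair,
# so both results have shape {3, 3}
sums = Nx.sum(t, axes: [2])
maxes = Nx.reduce_max(t, axes: [2])

IO.inspect(sums)
IO.inspect(maxes)
```

This only works when the per-pair computation is one of these built-in aggregations; arbitrary functions need the vectorization approach.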

However, if you have a more complex function that works on {2}-shaped tensors, and you want to treat your {3, 3, 2} tensor as a collection of those, you could approach it via vectorization: https://github.com/elixir-nx/nx/blob/main/nx/guides/vectorization.livemd

That looks promising, but from there I don't understand how to iterate through my collection of {2}-shaped tensors and call a function on each of them. Nx.map still takes individual elements instead of the {2}-shaped tensors. I searched the Nx docs without success.

The idea would be for you to implement a function that takes a {2}-shaped tensor, and then vectorize a collection of them to pass to it.

For example:

iex(2)> fun = fn x -> Nx.add(x[0], x[1]) end 
#Function<42.125776118/1 in :erl_eval.expr/6>
iex(3)> Nx.Defn.jit_apply(fun, [Nx.tensor([10, 20])])
#Nx.Tensor<
  s64
  30
>
iex(4)> t2 = Nx.tensor([[10, 20], [30, 40], [-10, -20], [-30, -40]])
#Nx.Tensor<
  s64[4][2]
  [
    [10, 20],
    [30, 40],
    [-10, -20],
    [-30, -40]
  ]
>
iex(5)> Nx.Defn.jit_apply(fun, [t2])
#Nx.Tensor<
  s64[2]
  [40, 60]
>
iex(6)> Nx.Defn.jit_apply(fun, [Nx.vectorize(t2, :rows)])
#Nx.Tensor<
  vectorized[rows: 4]
  s64
  [30, 70, -30, -70]
>

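Notice how vectorizing the input changes how the function treats the tensor. Without vectorization, x[0] and x[1] select the first two rows of t2 and add them element-wise ([10, 20] + [30, 40] = [40, 60]). With vectorization, the function sees each row as its own {2}-shaped input and yields the sum of each row.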
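Applied to the {3, 3, 2} tensor from the original question, the same idea might look like the sketch below. It assumes Nx 0.7 or later; pair_fun is a hypothetical stand-in for your own function on a {2}-shaped tensor. Vectorizing the first two axes makes each innermost pair a single input, and Nx.devectorize turns those axes back into regular ones so the result has shape f32[3][3].

```elixir
Mix.install([{:nx, "~> 0.7"}])

t =
  Nx.tensor(
    [
      [[0.0, 0.0], [0.0, 1.0], [0.0, 2.0]],
      [[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]],
      [[2.0, 0.0], [2.0, 1.0], [2.0, 2.0]]
    ],
    type: :f32
  )

# Hypothetical per-pair function: it receives one {2}-shaped tensor
# and returns a scalar (here x[0] * 10 + x[1])
pair_fun = fn x -> Nx.add(Nx.multiply(x[0], 10), x[1]) end

# Vectorize the first two axes, so the function only sees the last axis
vectorized = Nx.vectorize(t, [:rows, :cols])

# Each of the 3 * 3 pairs is passed to pair_fun as a {2}-shaped tensor
vec_result = Nx.Defn.jit_apply(pair_fun, [vectorized])

# Turn the vectorized axes back into ordinary axes: shape {3, 3}
result = Nx.devectorize(vec_result, keep_names: false)

IO.inspect(result)
```

With keep_names: false the devectorized axes are left unnamed; the default keeps :rows and :cols as axis names, which does not change the shape.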

Sorry for taking so long. This was a great help, and I finally solved my problem. Thank you very much, Paulo.


Glad it was!