Nx.to_flat_list inside defn


I am working on an implementation of the Resize operator for axon_onnx (pull request #35 in elixir-nx/axon_onnx, "Support for `op_type` "Resize"" by jnnks) and need to calculate the output shape of the tensor inside a defn context.

Consider the following snippet:

```elixir
defmodule Sample do
  import Nx.Defn

  defn calc_output_size(input, scales) do
    # take the input's shape and multiply it by scales
    #   input: Nx.Tensor<f32[1][1][2][4]>
    #   scales: Nx.Tensor<f32[4]>
    # return a tuple
    output_size =
      input
      |> Nx.shape()
      |> Tuple.to_list()
      |> Nx.tensor()
      |> Nx.multiply(scales)
      |> Nx.to_flat_list()
      |> List.to_tuple()

    output_size
  end
end
```

The call to Nx.to_flat_list fails at compile time or at runtime.
How can I get the values out of the tensor and into a tuple?

Fails with what error message?

At runtime:

```
** (ArgumentError) cannot invoke to_binary/2 on Nx.Defn.Expr.
This typically means you are invoking an unsupported Nx function
inside `defn` or inside JIT compiled code
```

or, in a separate example, at compile time:

```
** (CompileError) #cell:24: Nx.to_flat_list/1 is not allowed inside defn
```

The answer is no: lists are an Elixir construct, so they can't be traversed on the native CPU/GPU. The idea is that everything inside defn runs on the native device.

If that's the last step of your computation, then you should invoke it outside of defn, once it is over. If it is part of the computation, then you want to write it in a way that doesn't require to_flat_list.
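A minimal sketch of both options, assuming Nx 0.5 or later (for `deftransformp`) and an output size of floor(dim * scale) per axis, which ignores the ONNX `roi` input. `calc_output_size/2` is simply moved to a plain `def`. `zeros_resized/2`, its `:scales` option, and the helper `scaled_shape/2` are hypothetical names for illustration, not axon_onnx API. In the in-defn variant the scales are a compile-time list rather than a tensor, because a shape has to be known while tracing and cannot come from tensor values.

```elixir
Mix.install([{:nx, "~> 0.5"}])

defmodule Sample do
  import Nx.Defn

  # Option 1: plain `def`, not `defn`. This runs as regular Elixir,
  # so Nx.to_flat_list/1 is allowed and returns Elixir numbers.
  def calc_output_size(input, scales) do
    input
    |> Nx.shape()
    |> Tuple.to_list()
    |> Nx.tensor()
    |> Nx.multiply(scales)
    |> Nx.to_flat_list()
    # Assumption: output dims are floor(dim * scale), as integers.
    |> Enum.map(&floor/1)
    |> List.to_tuple()
  end

  # Option 2 (hypothetical): the shape is needed inside defn, so the
  # scales come in as a compile-time option (a plain list), not a tensor.
  defn zeros_resized(input, opts \\ []) do
    opts = keyword!(opts, [:scales])
    shape = scaled_shape(input, opts)
    Nx.broadcast(0.0, shape)
  end

  # Runs as plain Elixir while tracing; Nx.shape/1 works on the traced
  # input because shapes are static.
  deftransformp scaled_shape(input, opts) do
    scales = Keyword.fetch!(opts, :scales)

    input
    |> Nx.shape()
    |> Tuple.to_list()
    |> Enum.zip(scales)
    |> Enum.map(fn {dim, scale} -> floor(dim * scale) end)
    |> List.to_tuple()
  end
end
```

For example, `Sample.zeros_resized(Nx.iota({1, 1, 2, 4}), scales: [1, 1, 2, 2])` returns a tensor of zeros with shape `{1, 1, 4, 8}`.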

I was hoping you wouldn’t say that :smiley:

I also realize this could work in defn as long as we return slices of the original tensor. It is something you could do by hand using while, but we should probably add conveniences.
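When the length is known from the static shape, the slice idea can also be sketched without `while`: a `deftransformp` helper (assuming Nx 0.5 or later) builds one slice per element at trace time. `SliceSketch`, `to_scalar_tuple/1`, and `split_scalars/1` are hypothetical names, not an existing Nx convenience. Each element stays a scalar tensor, so the tuple can be returned from defn, but its elements are not Elixir integers and cannot be used as a shape.

```elixir
Mix.install([{:nx, "~> 0.5"}])

defmodule SliceSketch do
  import Nx.Defn

  # Hypothetical helper: returns the elements of a rank-1 tensor as a
  # tuple of scalar tensors (slices of the input), which defn can return.
  defn to_scalar_tuple(t) do
    split_scalars(t)
  end

  # Plain Elixir at trace time: the element count comes from the static
  # shape, and each element is a 1-element slice squeezed to a scalar.
  deftransformp split_scalars(t) do
    {n} = Nx.shape(t)

    0..(n - 1)//1
    |> Enum.map(fn i -> t |> Nx.slice([i], [1]) |> Nx.squeeze() end)
    |> List.to_tuple()
  end
end
```

For example, `SliceSketch.to_scalar_tuple(Nx.tensor([1.0, 1.0, 4.0, 8.0]))` returns a 4-tuple of scalar f32 tensors; calling `Nx.to_number/1` on each one after defn returns gives the plain numbers.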