Hi
I'm having some trouble getting started with Nx. What is the problem with the following code?
```elixir
defmodule MyNxModule do
  import Nx.Defn

  defn vector_with_all_zeros(length) do
    Nx.broadcast(0.0, {length})
  end
end
```
When I run it from IEx:

```elixir
iex> MyNxModule.vector_with_all_zeros(10)
```
the following error is raised:

```
** (ArgumentError) invalid dimension in axis 0 in shape. Each dimension must be a positive integer, got #Nx.Tensor<
  s64
  Nx.Defn.Expr
  parameter a:0 s64
> in shape {#Nx.Tensor<
  s64
  Nx.Defn.Expr
  parameter a:0 s64
>}
    (nx 0.1.0) lib/nx/shape.ex:35: Nx.Shape.validate!/3
    (nx 0.1.0) lib/nx.ex:2285: Nx.broadcast/3
```
It works when I replace the `length` function parameter with a fixed integer value. Is there a problem with the function definition?
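Judging from the `Nx.Defn.Expr` and `parameter a:0` in the error, `length` seems to reach `Nx.broadcast` as a traced tensor rather than as the integer `10`, while a shape has to be built from plain integers. Under that assumption, here is a minimal sketch of two possible workarounds: build the vector in a regular `def`, or stay in `defn` and pass a template tensor whose shape is reused. The `Mix.install/1` line and the name `zeros_like_template` are illustrative only, not part of the original code.

```elixir
# Hypothetical standalone script (run with `elixir zeros.exs`);
# Mix.install/1 requires Elixir 1.12 or later.
Mix.install([{:nx, "~> 0.1"}])

defmodule MyNxModule do
  import Nx.Defn

  # Workaround 1: a plain `def`. Here `length` is an ordinary Elixir
  # integer, so it can be used to build the shape tuple {length}.
  def vector_with_all_zeros(length) when is_integer(length) do
    Nx.broadcast(0.0, {length})
  end

  # Workaround 2 (hypothetical name): stay inside `defn`, but take the
  # shape from a template tensor. Nx.broadcast/2 accepts a tensor as its
  # second argument and uses that tensor's shape, which is known while the
  # function is traced even though its values are not.
  defn zeros_like_template(template) do
    Nx.broadcast(0.0, template)
  end
end

IO.inspect(MyNxModule.vector_with_all_zeros(10))
IO.inspect(MyNxModule.zeros_like_template(Nx.iota({10})))
```

Both calls should print a 10-element `f32` tensor of zeros. The first keeps the original calling convention; the second is closer to the original `defn` but changes what the caller passes in.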
Thank you for your help.