Hey all. I’m working through porting some of my old torch code into Nx, and I ran into a particular problem that I’m hoping someone can shed some light on. I’m attempting to one-hot encode an `Nx.tensor` using Scholar, but I get an error. Here’s the code, followed by the error:
```elixir
tensor = Nx.tensor(5)
num_classes = 27
Scholar.Preprocessing.one_hot_encode(tensor, num_classes: num_classes)
```
This raises the following error:

```
** (ArgumentError) given axis (0) invalid for shape with rank 0
    (nx 0.7.3) lib/nx/shape.ex:1121: Nx.Shape.normalize_axis/4
    (nx 0.7.3) lib/nx.ex:14975: anonymous fn/3 in Nx.sort/2
    (nx 0.7.3) lib/nx.ex:5368: Nx.apply_vectorized/2
    (scholar 0.3.1) lib/scholar/preprocessing/ordinal_encoder.ex:53: Scholar.Preprocessing.OrdinalEncoder."__defn:fit_n__"/2
    (nx 0.7.3) lib/nx/defn/compiler.ex:218: Nx.Defn.Compiler.__remote__/4
    (scholar 0.3.1) lib/scholar/preprocessing/one_hot_encoder.ex:62: Scholar.Preprocessing.OneHotEncoder."__defn:fit_n__"/2
    (scholar 0.3.1) lib/scholar/preprocessing/one_hot_encoder.ex:133: Scholar.Preprocessing.OneHotEncoder."__defn:fit_transform__"/2
    #cell:bdn3o6ty5cug3rcb:8: (file)
```
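For what it’s worth, my guess is that the problem is the tensor’s rank: `Nx.tensor(5)` is a scalar (rank 0), and the stack trace fails in `Nx.Shape.normalize_axis` when the encoder internally tries to sort along axis 0, which a rank-0 tensor doesn’t have. Wrapping the value in a list so the input is a rank-1 tensor seems to avoid the error, assuming `one_hot_encode` expects a 1-D tensor of samples (which I’m not certain about):

```elixir
# Possible workaround: pass a rank-1 tensor ({1}-shaped) instead of a scalar,
# so the encoder has an axis 0 to operate along.
tensor = Nx.tensor([5])
num_classes = 27

Scholar.Preprocessing.one_hot_encode(tensor, num_classes: num_classes)
```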
I dug a bit into the original Scholar code that added one-hot encoding here: Add ordinal and one-hot encodings by msluszniak · Pull Request #26 · elixir-nx/scholar · GitHub. It looks like that is no longer the code that does the one-hot encoding, though.
I tested it out, and this seems to do what I expect. For example:
```elixir
tensor = Nx.tensor(5)
num_classes = 27

Nx.equal(
  Nx.new_axis(tensor, -1),
  Nx.iota({1, num_classes})
)
```
This looks correct to me:
```
#Nx.Tensor<
  u8[1][27]
  EXLA.Backend<host:0, 0.1032028734.2104360976.142186>
  [
    [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  ]
>
```
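As a side note, the same comparison trick seems to generalize to a whole batch of labels at once via Nx broadcasting. This is just my own sketch, not anything from Scholar: comparing a `{n, 1}`-shaped tensor against a `{num_classes}`-shaped iota broadcasts to `{n, num_classes}`, putting a 1 in each row at the position of that row’s label.

```elixir
labels = Nx.tensor([5, 0, 26])
num_classes = 27

# Broadcasting {3, 1} against {27} yields a {3, 27} u8 tensor:
# row i is 1 exactly where the column index equals labels[i].
one_hot =
  Nx.equal(
    Nx.new_axis(labels, -1),
    Nx.iota({num_classes})
  )
```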
Am I doing something wrong, or is there some sort of bug in the `one_hot_encode` function?