Elixir/Nx/Axon equivalents of PyTorch code

Hello Nx community! :wave:

I’m a noob in the field of AI (and Python), trying to accomplish a task in my favourite language. I’d like to use an ONNX model that was exported from a PyTorch model built in Python, and which I can load from the filesystem. Someone else has already written Python code that runs the Torch model, and I would like to do the same in Elixir. It boils down to the following:

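import pandas as pd
import numpy as np
import torch

# Loads a fully pickled model; its class definition must be importable here.
model = torch.load('/path/to/model.torch', map_location=torch.device('cpu'))
pred = {}

# enumerate() yields (index, data_frame) pairs; iterating the list alone would not.
for index, data_frame in enumerate(some_list_of_data_frames):
    # DataFrame -> float32 NumPy array -> tensor of shape (rows, cols)
    feats = torch.from_numpy(np.array(data_frame, dtype=np.float32))
    with torch.no_grad():  # skip autograd bookkeeping during inference
        # unsqueeze(0) adds a batch dimension of size 1; [0] strips it off again
        pred[index] = model.forward(feats.unsqueeze(0)).detach().cpu().numpy()[0]

# One row per input frame, one column per model output
result = pd.DataFrame.from_dict(pred, orient="index", columns=["some", "set", "of", "columns"])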

The data frames are built somehow, but that’s for later (with Explorer, I suppose). I’d like to know what a comparable Elixir implementation would look like, if that is possible at all. I managed to load the ONNX model with the axon_onnx library, which returns the Axon model and a map of params, so that’s a start. But I don’t really know how to mimic torch.from_numpy(), torch.no_grad(), model.forward(...).detach().cpu().numpy() and pd.DataFrame.from_dict(...). If anyone has tips to share, I’d be very grateful :pray:

Maybe this guide is a good starting point: Converting ONNX file into Axon by meanderingstream · Pull Request #227 · elixir-nx/axon · GitHub


Thanks for the reference! This gives me some insight into how to run the model: the gist is to call Axon.predict/4 on the input tensor.

As a follow-up question: to construct the input tensor I’d like to do a linear interpolation of series data, just like this numpy interp function: numpy.interp — NumPy v1.23 Manual
Is this in scope for Nx, or should I look elsewhere? I can’t find anything related to linear interpolation right now (but maybe there are better terms for this problem and I’m not searching with the right words).
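For reference, here is a plain-Python sketch (no NumPy) of what numpy.interp computes with its default arguments. It only uses loops, comparisons and a binary search, so the same logic should carry over to an Elixir Enum pipeline. The function name `interp` is made up for illustration; it assumes `xp` is strictly increasing (NumPy also requires increasing `xp` and does not check it) and clamps out-of-range points to the first/last `fp` value, as NumPy does when `left`/`right` are not given.

```python
from bisect import bisect_right


def interp(x, xp, fp):
    """Piecewise-linear interpolation, mirroring numpy.interp's defaults.

    Hypothetical helper. Assumes xp is strictly increasing and
    len(xp) == len(fp) >= 1. Points outside [xp[0], xp[-1]] are
    clamped to fp[0] / fp[-1].
    """
    result = []
    for xi in x:
        if xi <= xp[0]:
            result.append(fp[0])
        elif xi >= xp[-1]:
            result.append(fp[-1])
        else:
            # Index of the first sample point strictly greater than xi,
            # so that xp[j - 1] <= xi < xp[j].
            j = bisect_right(xp, xi)
            x0, x1 = xp[j - 1], xp[j]
            y0, y1 = fp[j - 1], fp[j]
            t = (xi - x0) / (x1 - x0)  # relative position inside the segment
            result.append(y0 + t * (y1 - y0))
    return result


print(interp([0, 1, 1.5, 2.5, 3.14], [1, 2, 3], [3, 2, 0]))
```

With `xp = [1, 2, 3]` and `fp = [3, 2, 0]`, the point 2.5 lies halfway between the samples at 2 and 3, so it interpolates to 1.0, while 0 and 3.14 fall outside the range and are clamped to 3 and 0.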

I have a similar need for a histogram where I can specify the bins and sum values within each bin, to end up with a tensor of well-structured series data. Something like NumPy’s histogram function: numpy.histogram — NumPy v1.23 Manual
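Summing values per bin is what numpy.histogram does when you pass `weights`. Below is a plain-Python sketch of `numpy.histogram(values, bins=edges, weights=weights)[0]`, again with only loops and a binary search so it can be ported by hand. The name `weighted_histogram` is hypothetical; it assumes `edges` is strictly increasing with at least two entries, and follows NumPy's convention that every bin is half-open `[edges[i], edges[i + 1])` except the last, which also includes its right edge. Values outside the edges are dropped.

```python
from bisect import bisect_right


def weighted_histogram(values, weights, edges):
    """Sum each weight into the bin its value falls in.

    Hypothetical helper. Assumes `edges` is strictly increasing and has
    at least two entries. Bins are [edges[i], edges[i + 1]) except the
    last, which is closed on the right; out-of-range values are ignored.
    """
    sums = [0.0] * (len(edges) - 1)
    for value, weight in zip(values, weights):
        if value < edges[0] or value > edges[-1]:
            continue  # outside the histogram range
        if value == edges[-1]:
            i = len(sums) - 1  # the last bin includes its right edge
        else:
            i = bisect_right(edges, value) - 1
        sums[i] += weight
    return sums


print(weighted_histogram([0.5, 1.0, 1.5, 3.0, 4.0, -1], [1, 2, 3, 4, 5, 6], [0, 1, 2, 3]))
```

Here 0.5 lands in the first bin, 1.0 and 1.5 in the second, 3.0 in the last (closed) bin, and 4.0 and -1 are dropped, giving `[1.0, 5.0, 4.0]`. To sum the values themselves rather than separate weights, pass the same list as both `values` and `weights`.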

As far as I understand, these things are about building and manipulating tensors, so they should live in or around the Nx library. But correct me if I’m wrong, and let me know if there are other libraries I should look into.

After thinking about it some more, I think I’m looking in the wrong place. What I want to do (interpolating, and aggregating into a histogram) is not something Nx offers out of the box. I can’t expect all the NumPy goodies to be available in Nx or Explorer; NumPy has a much bigger API surface.

In my specific case, I can roll my own transformations. That does involve enumerating all values, but that’s fine because my tensors are relatively small.