Series: is there a method that returns the indices where a predicate is true?

Hello! :wave:

I'm currently translating a small NumPy + pandas script to the Nx/Explorer combo.
Apologies if this is a very basic question, but I can't find anything like a `.where(List, fun)` or `.find(List, fun)` that would return the indices of the elements for which `fun` returns true.
A close analogue is NumPy's `nonzero` function.
Do we have something similar in Nx/Explorer?

Example code:

```elixir
input = [0, 1, 0, 0, -4, 5]
# Want the indexes of elements that are not zero (or any other predicate)
output = [1, 4, 5]
```
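For a plain Elixir list, the behaviour described above can be sketched with `Enum` alone. The module and function names (`ListIndices.where/2`) are made up for illustration; this is not an existing Nx or Explorer API.

```elixir
defmodule ListIndices do
  # Hypothetical helper (not part of Nx or Explorer): returns the positions
  # of the elements of `list` for which `fun` returns a truthy value.
  def where(list, fun) when is_list(list) and is_function(fun, 1) do
    list
    # Pair each element with its position: {element, index}
    |> Enum.with_index()
    # Keep only the pairs whose element satisfies the predicate
    |> Enum.filter(fn {element, _index} -> fun.(element) end)
    # Drop the elements, keep the positions
    |> Enum.map(fn {_element, index} -> index end)
  end
end

input = [0, 1, 0, 0, -4, 5]
ListIndices.where(input, &(&1 != 0))
# => [1, 4, 5]
```

Any one-argument predicate works, e.g. `&(&1 < 0)` for the negative entries.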

Thanks!

Thanks to @polvalente's reply on Slack, I realized that this is a problem to be solved "back in Elixir" land (i.e. using `Enum`), since the size of the result depends on the data (the output shape is dynamic), which makes it a poor fit for Nx (and for Explorer too, as far as I know).
His words:

nonzero-like functions can’t work in pure Nx because they imply dynamic shapes.
You can use something like argsort to get the highest indices, and Nx.sum(Nx.not_equal(arr, 0)) to get the count, and then bring them both to Elixir to slice the list accordingly
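One way to read that recipe, as a minimal sketch for a 1-D tensor: build a static-shape mask in Nx, count the non-zero entries, argsort the mask so the matching positions come first, and do the data-dependent slicing in Elixir. It assumes the `nx` package is available via `Mix.install` (the version constraint is an assumption), and the module name `NxNonzero` is hypothetical. Since stability of `Nx.argsort/2` is not relied on here, the kept positions are re-sorted in Elixir.

```elixir
# Assumption: the version constraint below; any recent Nx should behave the same.
Mix.install([{:nx, "~> 0.7"}])

defmodule NxNonzero do
  # Hypothetical helper, not part of Nx: indices of the non-zero elements
  # of a 1-D tensor, following the "count in Nx, slice in Elixir" recipe.
  def nonzero_indices(tensor) do
    # 1 where the element is non-zero, 0 elsewhere (same shape as the input)
    mask = Nx.not_equal(tensor, 0)

    # How many elements are non-zero, brought back to Elixir as an integer
    count = mask |> Nx.sum() |> Nx.to_number()

    # Positions ordered so that those holding a 1 in the mask come first
    order = Nx.argsort(mask, direction: :desc)

    # The result length depends on the data, so the slicing happens in Elixir.
    # argsort is not assumed to be stable, so the kept positions are re-sorted.
    order
    |> Nx.to_flat_list()
    |> Enum.take(count)
    |> Enum.sort()
  end
end

NxNonzero.nonzero_indices(Nx.tensor([0, 1, 0, 0, -4, 5]))
# => [1, 4, 5]
```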

Thanks!


In Explorer you have more options for this than in Nx. For example, there's `filter`, which you can use to return all non-zero elements:

```elixir
require Explorer.Series
series = Explorer.Series.from_list([-1, 0, 1])
Explorer.Series.filter(series, _ != 0)
# #Explorer.Series<
#   Polars[2]
#   s64 [-1, 1]
# >
```

You could also build a boolean mask with a comparison function like `not_equal` and then apply it with `mask`:

```elixir
series = Explorer.Series.from_list([-1, 0, 1])
mask = Explorer.Series.not_equal(series, 0)
# #Explorer.Series<
#   Polars[3]
#   boolean [true, false, true]
# >
Explorer.Series.mask(series, mask)
# #Explorer.Series<
#   Polars[2]
#   s64 [-1, 1]
# >
```

This seems very similar to a problem I faced when porting sklearn’s count vectorizer to Nx.

You can see how I handled it while staying in Nx here: `mighty/lib/preprocessing/count_vectorizer.ex` in acalejos/mighty on GitHub (commit 8e4e3a47f233043448c91788368f0ae02716440b)

I also wrote about this process here if you’re interested: Python NumPy to Elixir-Nx


Thank you @billylanchantin! The mask feature is really cool and will surely come in handy later on.
In this case, though, I'm after the indices of the elements that match the condition rather than the elements themselves, which is what `mask` returns, if I understand correctly?

Super nice write-up @acalejos! Thanks for linking it. I'll need a bit of time to digest it and see whether I can adapt it directly to my scenario. :smiling_face:


@_toni

In that case you may want something like this:

```elixir
series = Explorer.Series.from_list([-1, 0, 1])
# #Explorer.Series<
#   Polars[3]
#   s64 [-1, 0, 1]
# >
mask = Explorer.Series.not_equal(series, 0)
# #Explorer.Series<
#   Polars[3]
#   boolean [true, false, true]
# >
indices = Explorer.Series.row_index(series)
# #Explorer.Series<
#   Polars[3]
#   u32 [0, 1, 2]
# >
Explorer.Series.mask(indices, mask)
# #Explorer.Series<
#   Polars[2]
#   u32 [0, 2]
# >
```
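The same `row_index` + `mask` combination can be wrapped into something shaped like the `.where(List, fun)` from the original question, where `fun` takes the whole series and returns a boolean series. This is a sketch: `SeriesIndices.where/2` is a hypothetical helper, not an Explorer function, and the `~> 0.10` version constraint is an assumption (any release that ships `Series.row_index/1` should do).

```elixir
# Assumption: an Explorer release that provides Series.row_index/1
Mix.install([{:explorer, "~> 0.10"}])

defmodule SeriesIndices do
  alias Explorer.Series

  # Hypothetical helper: `predicate` receives the whole series and must
  # return a boolean series of the same length (e.g. via Series.not_equal/2).
  # Returns the matching row indices as a plain Elixir list.
  def where(series, predicate) when is_function(predicate, 1) do
    mask = predicate.(series)

    series
    |> Series.row_index()
    |> Series.mask(mask)
    |> Series.to_list()
  end
end

series = Explorer.Series.from_list([0, 1, 0, 0, -4, 5])
SeriesIndices.where(series, &Explorer.Series.not_equal(&1, 0))
# => [1, 4, 5]
```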

Oh, that indeed achieves the desired output, thanks a lot @billylanchantin :pray:
Also, one could replace `Series.not_equal` with `Series.map` for a more generic condition. Great!


You got it!

Possibly! But remember: the second argument to `Series.mask/2` must have dtype `boolean`. If you're working with predicates the way you would with `np.where`, you probably want `Series.not_equal/2` or a similar comparison function.
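Comparison functions can also be combined for compound predicates while keeping the mask boolean. Below is a sketch that builds "non-zero AND less than 5" with `Series.not_equal/2`, `Series.less/2` and `Series.and/2`, and checks the mask dtype before masking. The helper `PredicateIndices.indices/2` and its error message are hypothetical, and the `~> 0.10` version constraint is an assumption.

```elixir
# Assumption: an Explorer release that provides Series.row_index/1
Mix.install([{:explorer, "~> 0.10"}])

defmodule PredicateIndices do
  alias Explorer.Series

  # Hypothetical helper: returns the row indices where `mask` is true.
  # Series.mask/2 expects a boolean mask, so fail early with an explicit
  # message if something else (e.g. a numeric series) is passed in.
  def indices(series, mask) do
    if Series.dtype(mask) != :boolean do
      raise ArgumentError,
            "expected a boolean mask, got dtype #{inspect(Series.dtype(mask))}"
    end

    series
    |> Series.row_index()
    |> Series.mask(mask)
    |> Series.to_list()
  end
end

alias Explorer.Series

series = Series.from_list([0, 1, 0, 0, -4, 5])

# Each comparison returns a boolean series; Series.and/2 combines them
# element-wise into "non-zero AND less than 5".
mask = Series.and(Series.not_equal(series, 0), Series.less(series, 5))

PredicateIndices.indices(series, mask)
# => [1, 4]
```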
