I'm currently translating a small NumPy + pandas script to the Nx/Explorer combo.
Apologies if this is a very basic question, but I'm missing something like a `where(list, fun)` or `find(list, fun)` that would return the indices of the elements matching `fun`.
The closest analogy would be NumPy's `nonzero` function.
Do we have something similar in Nx/Explorer?
Example code:
```elixir
input = [0, 1, 0, 0, -4, 5]
# Want the indices of the elements that are not zero (or that match any other predicate)
output = [1, 4, 5]
```
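For reference, the plain-Elixir version of the example above can be written with `Enum.with_index/1` and a comprehension filter. This is a minimal sketch; `ListIndices.where/2` is a hypothetical helper name, not part of Nx, Explorer, or the standard library.

```elixir
defmodule ListIndices do
  # Hypothetical helper: returns the indices of the elements of `list`
  # for which `pred` returns a truthy value (in the spirit of np.nonzero).
  def where(list, pred) when is_list(list) and is_function(pred, 1) do
    # Pair each element with its index, keep the pairs whose element
    # passes the predicate, and collect only the indices.
    for {elem, idx} <- Enum.with_index(list), pred.(elem), do: idx
  end
end

input = [0, 1, 0, 0, -4, 5]
ListIndices.where(input, &(&1 != 0))
```

Any other predicate works the same way, e.g. `ListIndices.where(input, &(&1 > 0))` returns `[1, 5]`.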
Thanks to @polvalente's reply on Slack, I realized this is a problem to be solved back in Elixir land (i.e. using `Enum`), because the length of the result depends on the data and so isn't known ahead of time. That makes it a poor fit for Nx (and maybe Explorer too, as far as I know?).
His exact words:
> `nonzero`-like functions can't work in pure Nx because they imply dynamic shapes.
> You can use something like `argsort` to get the highest indices, and `Nx.sum(Nx.not_equal(arr, 0))` to get the count, and then bring them both to Elixir to slice the list accordingly.
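The suggestion above can be sketched as follows, assuming a 1-D tensor and Nx `~> 0.7` (the version constraint and the `NonzeroViaNx` module name are illustrative assumptions). Nx only produces fixed-shape results (the mask, its sum, and the argsort); the dynamic-length slice happens in Elixir.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule NonzeroViaNx do
  # Hypothetical helper, assumes `tensor` is rank 1.
  def nonzero(tensor) do
    # 1 where the element is non-zero, 0 elsewhere (same shape as the input)
    mask = Nx.not_equal(tensor, 0)

    # Number of non-zero elements, brought back as a plain Elixir integer
    count = mask |> Nx.sum() |> Nx.to_number()

    # Sorting the mask in descending order puts the positions of the 1s first.
    # We don't rely on argsort being stable, so the kept indices are re-sorted.
    mask
    |> Nx.argsort(direction: :desc)
    |> Nx.to_flat_list()
    |> Enum.take(count)
    |> Enum.sort()
  end
end

NonzeroViaNx.nonzero(Nx.tensor([0, 1, 0, 0, -4, 5]))
```

The final `Enum.take/2` is exactly the step that can't live inside Nx: its output length is `count`, which is only known once the data has been evaluated.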
Thank you @billylanchantin! The mask feature is really cool and will surely come in handy later on.
In this case, though, I'm after the indices of the elements that match the condition rather than the elements themselves, which is what `mask` would return, if I understand correctly?
Oh, that does indeed produce the desired output. Thanks a lot, @billylanchantin!
One could also replace `Series.not_equal` with `Series.map` for a more generic condition. Great!
Possibly! But remember: the second argument to `Series.mask/2` must have dtype boolean. You probably want `Series.not_equal/2` or a similar comparison function if you're working with predicates the way you would with `np.where`.
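A minimal sketch of the mask-based approach discussed above, assuming Explorer `~> 0.8`. The exact code from the thread isn't reproduced here; this is a reconstruction: build a series of row positions and mask it with the boolean series returned by `Series.not_equal/2`, so the result is indices rather than values.

```elixir
Mix.install([{:explorer, "~> 0.8"}])

alias Explorer.Series

values = Series.from_list([0, 1, 0, 0, -4, 5])

# A series holding each row's position: 0, 1, ..., size - 1
# (the //1 step keeps the range empty when the series is empty)
positions = Series.from_list(Enum.to_list(0..(Series.size(values) - 1)//1))

# Boolean series: true where the value is non-zero.
# Series.mask/2 needs this second argument to have a boolean dtype.
keep = Series.not_equal(values, 0)

# Masking the positions (not the values) yields the matching indices
nonzero_indices =
  positions
  |> Series.mask(keep)
  |> Series.to_list()
```

For other conditions, swapping in another comparison that returns a boolean series, such as `Series.greater(values, 0)`, keeps the rest of the pipeline unchanged.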