Can anyone help me implement similarity/cosine similarity in Nx?

I’m trying to port some code from Python to Elixir. The Python code generates tensors from BERT embeddings, then does some form of similarity comparison between them. From what I’ve found online, it looks like cosine similarity is the calculation I’m looking for, but I can’t quite understand it well enough to implement it in Nx.

The formula is listed as A ⋅ B / (||A|| ||B||). I have two tensors with the shape #Nx.Tensor<f32[1][6][119547]...>. So far this is all I’ve come up with:

for i <- 0..5, j <- 0..5 do
  t1 = tensor1[0][i]
  t2 = tensor2[0][j]
  Nx.dot(t1, t2) / ???
end

I found another formula on the Wikipedia page, given as NumPy code:

np.sum(a*b)/(np.sqrt(np.sum(a**2)) * np.sqrt(np.sum(b**2)))

I think converted to Nx that’s:

defmodule CosSim do
  import Nx.Defn
  defn cosine_similarity(a, b) do
    left = Nx.sqrt(Nx.sum(a**2))
    right = Nx.sqrt(Nx.sum(b**2))
    Nx.sum(a * b) / (left * right)
  end
end

If I do a quick test:

a = Nx.tensor([1,2,3])
b = Nx.tensor([4,5,6])
CosSim.cosine_similarity(a, b)

#Nx.Tensor<
  f32
  EXLA.Backend<host:0, 0.528063503.4042653716.111775>
  0.9746317863464355
>

If I try to validate it in python:

>>> a = np.matrix([1,2,3])
>>> b = np.matrix([4,5,6])
>>> np.sum(a*b)/(np.sqrt(np.sum(a**2)) * np.sqrt(np.sum(b**2)))

ValueError: shapes (1,3) and (1,3) not aligned: 3 (dim 1) != 1 (dim 0)

So something is off. Admittedly I’m very new to any of this ML/Nx stuff, so maybe I’m way off, or maybe I’m close. Any tips?
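Incidentally, the NumPy error here comes from using np.matrix, where * means matrix multiplication, so a*b tries to multiply a (1,3) by a (1,3); with plain 1-D arrays the formula works. As an illustrative cross-check (not from the thread), the same formula in plain Python, without NumPy, reproduces the Nx result:

```python
import math

a = [1, 2, 3]
b = [4, 5, 6]

dot = sum(x * y for x, y in zip(a, b))     # 32
norm_a = math.sqrt(sum(x * x for x in a))  # sqrt(14)
norm_b = math.sqrt(sum(y * y for y in b))  # sqrt(77)

similarity = dot / (norm_a * norm_b)
print(similarity)  # ~0.9746, matching the Nx output above
```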


We have Cosine Distance available in Scholar: scholar/distance.ex at main · elixir-nx/scholar · GitHub

Cosine Similarity would be 1 - CosDistance 🙂


Awesome, thanks. I’ll give that a go. I swear I spent 30 minutes googling anything related to the subject in Elixir/Nx, and Google was useless. It seems to have gotten steadily worse over the past few years, and in the past few months dramatically so. Maybe it’s just me.

FYI, scholar is not available on hex.pm. I had to pull it manually as a Git dependency, even though the docs say to add it as a dep in mix.exs.

Yes, the library hasn’t been released yet. The README probably still has the default autogenerated instructions.

Pitching in just to say that your function is formally correct, although you could simplify it a bit by using existing Nx functions for the dot product and the norm (possibly benefiting from some optimizations):

defmodule CosSim do
  import Nx.Defn

  defn cosine_similarity(a, b) do
    Nx.dot(a, b) / (Nx.LinAlg.norm(a) * Nx.LinAlg.norm(b))
  end
end

You can verify that it’s correct with some sanity checks:

# Similarity for vectors pointing in the same direction is 1
a = Nx.tensor([1, 2])
b = Nx.tensor([2, 4])
CosSim.cosine_similarity(a, b)
# =>
# #Nx.Tensor<
#   f32
#   1.0
# >

# Similarity for orthogonal vectors is 0
a = Nx.tensor([3, 0])
b = Nx.tensor([0, 3])
CosSim.cosine_similarity(a, b)
# =>
# #Nx.Tensor<
#   f32
#   0.0
# >

# Similarity for vectors pointing in opposite directions is -1:
a = Nx.tensor([5, 0, 2])
b = Nx.tensor([-10, 0, -4])
CosSim.cosine_similarity(a, b)
# =>
# #Nx.Tensor<
#   f32
#   -1.0
# >

# Vectors at 45deg have cosine similarity = cos(Pi/4) ~ 0.7071
a = Nx.tensor([1, 0])
b = Nx.tensor([1, 1])
CosSim.cosine_similarity(a, b)
# =>
# #Nx.Tensor<
#   f32
#   0.7071067690849304
# >

Note though that the definition of cosine distance in Scholar also has additional special cases for when the L2 norm is so small that it would cause numerical stability issues, or when one or both of the operands have a norm of zero. Additionally, it can calculate the cosine distance between multiple vectors in a batch, when you pass rank-2 tensors as arguments.
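To illustrate why such a guard matters: the plain formula divides by the product of the norms, which is zero for a zero vector. A plain-Python sketch of the idea (this is not Scholar's actual code, and the eps threshold is my own illustrative choice):

```python
import math

def cosine_similarity(a, b, eps=1e-12):
    # eps is an illustrative threshold, not Scholar's actual constant
    dot = sum(x * y for x, y in zip(a, b))
    denom = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    if denom < eps:
        # a zero (or near-zero) vector has no direction, so the similarity
        # is undefined; returning 0.0 is one common convention
        return 0.0
    return dot / denom
```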


Thanks @lucaong and @polvalente for all the help. I’m making my way through porting the code, and learning a lot, but it’s all very new to me. The cosine_similarity function that scikit-learn uses (which happens to be what the code I’m porting uses too) seems to do something a bit different, and I’m wondering what it might be called, as I’m not even sure what to start googling for. Here’s an example of what it returns:

from sklearn.metrics.pairwise import cosine_similarity
>>> a = np.matrix([[1,2,3], [4,5,6]])
>>> b = np.matrix([[9,8,7], [6,5,4]])
>>> cosine_similarity(a, b)
array([[0.88265899, 0.85280287],
       [0.96546332, 0.94805195]])

Now, if I use the axes option from the Scholar formula with axes: [1], I get:

#Nx.Tensor<
  f32[2]
  EXLA.Backend<host:0, 0.2926165521.2327445524.101050>
  [0.1173410415649414, 0.05194807052612305]
>

Which, if I subtract from 1 as Paulo mentioned before, does give me two of those numbers: the top-left to bottom-right diagonal.

iex(52)> 1 - 0.1173410
0.882659
iex(53)> 1 - 0.051948
0.948052

I’m not sure where the other two come from, how they would be calculated, or what this operation is called. Any insight? I’m guessing pairwise was the first clue I should have followed.

I’d guess that when you pass matrices to cosine_similarity in scikit-learn, it treats each matrix as a collection of row vectors and returns the pairwise similarities between each vector in the first matrix and each vector in the second, which is why you get four elements as a result.
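That interpretation can be checked with a plain-Python sketch (illustrative only, not scikit-learn's implementation) that compares every row of a against every row of b and reproduces the 2×2 result:

```python
import math

def cosine(u, v):
    dot = sum(x * y for x, y in zip(u, v))
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(y * y for y in v))
    return dot / (norm_u * norm_v)

a = [[1, 2, 3], [4, 5, 6]]
b = [[9, 8, 7], [6, 5, 4]]

# row i of the result compares a[i] against every row of b
result = [[cosine(u, v) for v in b] for u in a]
```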


Yup, that was it!

>>> [cosine_similarity(a[i], b[j]) for i in range(2) for j in range(2)]
[array([[0.88265899]]), array([[0.85280287]]), array([[0.96546332]]), array([[0.94805195]])]

You could implement the same behavior as scikit (pairwise cosine similarities between vectors in the given matrices) without iteration, by normalizing along the last axis, transposing the second matrix, and performing a matrix multiplication:

defmodule Pairwise do
  import Nx.Defn

  defnp l2_norm(x, opts \\ []) do
    (x * x)
    |> Nx.sum(opts)
    |> Nx.sqrt()
  end

  defnp normalize(x, opts \\ []) do
    x / l2_norm(x, axes: opts[:axes], keep_axes: true)
  end

  defn cosine_similarity(a, b) do
    normalized_a = normalize(a, axes: [-1])
    normalized_b = normalize(b, axes: [-1])

    Nx.dot(normalized_a, Nx.transpose(normalized_b))
  end
end
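As a cross-check of the normalize-then-multiply idea, here is the same computation as a plain-Python sketch (illustrative only, not the Nx code), which reproduces scikit-learn's numbers:

```python
import math

def normalize_rows(m):
    # divide each row by its L2 norm, so every row has unit length
    return [[x / math.sqrt(sum(v * v for v in row)) for x in row] for row in m]

def matmul_transpose(a, b):
    # a @ transpose(b): dot product of each row of a with each row of b
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b]
            for row_a in a]

a = [[1, 2, 3], [4, 5, 6]]
b = [[9, 8, 7], [6, 5, 4]]

sims = matmul_transpose(normalize_rows(a), normalize_rows(b))
```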

You can now use it passing two vectors:

a = Nx.tensor([1, 2, 3])
b = Nx.tensor([9, 8, 7])

Pairwise.cosine_similarity(a, b)
# =>
#  #Nx.Tensor<
#    f32
#    0.8826588988304138
#  >

Or passing two matrices, and getting pairwise similarities, like with scikit:

a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
b = Nx.tensor([[9, 8, 7], [6, 5, 4]])

Pairwise.cosine_similarity(a, b)
# =>
# #Nx.Tensor<
#   f32[2][2]
#   [
#     [0.8826588988304138, 0.8528028130531311],
#     [0.9654632806777954, 0.948051929473877]
#   ]
# >


I wish I could give you more than one heart. I’ve been working on a solution since I last commented two hours ago. I got it working without defn, but then I checked its performance against the sklearn one (no idea why; I knew mine would be slower, and it doesn’t really matter). Of course I then wanted to get it working with defn, so I tried to get something similar to for i <- 0..w1, j <- 0..w2, do: cosine_similarity(t1[i], t2[j]) working with a nested while in defn, very ugly funneling all those parameters through, only to finally learn that I can’t change a tensor’s shape inside a while loop in defn. At least that’s what I think this is telling me:

** (CompileError) iex:242: the do-block in while must return tensors with the same shape, type, and names as the initial arguments.

Body matches template:

{#Nx.Tensor<
   f32[2]
 >, #Nx.Tensor<
   s64
 >, #Nx.Tensor<
   s64
 >, #Nx.Tensor<
   s64
 >, #Nx.Tensor<
   s64[2][5]
 >, #Nx.Tensor<
   s64[3][5]
 >}

and initial argument has template:

{#Nx.Tensor<
   f32[1]
 >, #Nx.Tensor<
   s64
 >, #Nx.Tensor<
   s64
 >, #Nx.Tensor<
   s64
 >, #Nx.Tensor<
   s64[2][5]
 >, #Nx.Tensor<
   s64[3][5]
 >}

I’ve been using Elixir for about 5-6 years, so this feels weird, like I’m coding with my hands tied behind my back. I was just about ready to throw in the towel, but you came and saved the day for me. Thank you!


This error has to do with the first element of your tuple: you passed in an f32[1] but are returning an f32[2].

I don’t have your implementation, so the best I can do is suggest you either pad the input to make it f32[2], or drop one of the numbers from the f32[2] output.