Plain DBSCAN does not scale well to high-dimensional vectors, and HNSW provides approximate nearest-neighbour (ANN) search but no clustering. By combining the two (using HNSW to answer DBSCAN's neighbourhood queries), it is possible to cluster thousands of 128-dimensional vectors quickly.
Thanks to Nx and the hnswlib package (available on Hex).
For future reference, here is the implementation:
defmodule Koko.Clustering do
  @moduledoc """
  DBSCAN clustering on top of an HNSWLib index.

  DBSCAN's region query ("which points lie within `eps` of this one?") is
  answered with an HNSW k-nearest-neighbour query instead of a brute-force scan.

  NOTE(review): point ids are assumed to be `0..count - 1` (i.e. the ids the
  index was filled with) — TODO confirm against the code that adds items.
  """
  require Logger

  # SAMPLE PARAMS
  #
  # max_elements = 10000
  # ef_construction = 200
  # M = 16
  # ef = 50 # ef should be set based on your accuracy/speed tradeoff needs

  # Distance threshold for DBSCAN. NOTE: hnswlib documents its `l2` space as
  # *squared* L2 distance, so eps is compared against squared values there.
  @eps 0.3
  # Minimum number of neighbours (the query point itself excluded) for a point
  # to be a core point of a dense region (cluster).
  @min_samples 12

  # DBSCAN

  @doc """
  Runs DBSCAN over every point stored in `index`.

  Returns an `{:s, 16}` Nx tensor with one label per point id `0..count - 1`:
  the cluster id (starting at 0) or `-1` for noise. Returns `[]` and logs a
  warning when the index is empty.

  ## Options

    * `:eps` - neighbourhood distance threshold (default `#{@eps}`)
    * `:min_samples` - minimum neighbour count for a core point
      (default `#{@min_samples}`)
  """
  def create_clusters(index, opts \\ []) do
    count = index |> instance().get_current_count() |> unwrap!()

    if count > 0 do
      eps = Keyword.get(opts, :eps, @eps)
      min_samples = Keyword.get(opts, :min_samples, @min_samples)
      # Pass the known count so each region query skips a get_current_count call.
      neighbor_opts = [eps: eps, count: count]

      # Region query. Every id is marked visited before it is queried, so this
      # runs at most once per point.
      region = fn i ->
        neighbors = hnsw_neighbors(index, get_item(index, i), neighbor_opts)
        if length(neighbors) < min_samples, do: :not_core, else: {:core, neighbors}
      end

      {assigned, _visited, _next_cluster_id} =
        Enum.reduce(0..(count - 1), {%{}, MapSet.new(), 0}, &visit_point(&1, &2, region))

      # Ids absent from `assigned` are noise. s16 is kept for compatibility;
      # cluster ids above 32_767 would not fit.
      labels =
        0..(count - 1)
        |> Enum.map(&Map.get(assigned, &1, -1))
        |> Nx.tensor(type: {:s, 16})

      Logger.debug("#{__MODULE__} labels #{inspect(labels)}")
      labels
    else
      Logger.warning("#{__MODULE__} index is empty")
      []
    end
  end

  # Outer DBSCAN step: starts a new cluster at `i` when it is an unvisited core point.
  defp visit_point(i, {assigned, visited, cluster_id} = acc, region) do
    Logger.debug("#{__MODULE__} LOOP #{i} for cluster #{cluster_id}")

    if MapSet.member?(visited, i) do
      acc
    else
      visited = MapSet.put(visited, i)

      case region.(i) do
        # Noise for now; a later cluster may still claim it as a border point.
        :not_core ->
          {assigned, visited, cluster_id}

        {:core, neighbors} ->
          {assigned, visited} =
            expand_cluster(neighbors, Map.put(assigned, i, cluster_id), visited, cluster_id, region)

          {assigned, visited, cluster_id + 1}
      end
    end
  end

  # Grows `cluster_id` from a seed list (depth-first), appending the neighbourhood
  # of every newly reached core point so the cluster spans the whole dense region.
  defp expand_cluster([], assigned, visited, _cluster_id, _region), do: {assigned, visited}

  defp expand_cluster([j | rest], assigned, visited, cluster_id, region) do
    # Unlabelled points (unvisited, or earlier marked noise) join this cluster;
    # points already in a cluster keep their first assignment.
    assigned = Map.put_new(assigned, j, cluster_id)

    if MapSet.member?(visited, j) do
      expand_cluster(rest, assigned, visited, cluster_id, region)
    else
      visited = MapSet.put(visited, j)

      seeds =
        case region.(j) do
          {:core, neighbors} -> neighbors ++ rest
          :not_core -> rest
        end

      expand_cluster(seeds, assigned, visited, cluster_id, region)
    end
  end

  @doc """
  Returns the ids of the indexed points within `eps` of `point`, excluding the
  closest hit (expected to be `point` itself when it is stored in the index).

  Returns `[]` when `point` is `nil` or when `:count` is not positive.

  ## Options

    * `:eps` - distance threshold (default `#{@eps}`)
    * `:count` - `k` for the kNN query; defaults to the current index size so
      the whole index is searched
  """
  def hnsw_neighbors(index, point, opts \\ [])

  def hnsw_neighbors(_index, nil, _opts), do: []

  def hnsw_neighbors(index, point, opts) do
    eps = Keyword.get(opts, :eps, @eps)
    # Lazy default: only ask the index for its size when the caller did not pass it.
    count =
      Keyword.get_lazy(opts, :count, fn ->
        index |> instance().get_current_count() |> unwrap!()
      end)

    if count > 0 do
      {:ok, ids, distances} = instance().knn_query(index, point, k: count)
      # Do not take the head, as it is the point itself.
      ids = ids |> Nx.to_flat_list() |> Enum.drop(1)
      distances = distances |> Nx.to_flat_list() |> Enum.drop(1)

      for {id, distance} <- Enum.zip(ids, distances), distance <= eps, do: id
    else
      []
    end
  end

  @doc """
  Fetches the vector stored under id `i` as an `:f32` Nx tensor.

  Logs an error and returns `nil` if the lookup does not yield exactly one item.
  """
  def get_item(index, i) do
    Logger.debug("#{__MODULE__} get_item #{i}")

    case instance().get_items(index, [i]) do
      {:ok, [item]} ->
        Nx.from_binary(item, :f32)

      any ->
        Logger.error("#{__MODULE__} get item #{i} failure #{inspect(any)}")
        nil
    end
  end

  # HNSW

  @doc """
  Creates a new HNSW index.

  ## Options

    * `:space` - distance space (default `:l2`)
    * `:dim` - vector dimension (default `128`); the misspelled `:din` key is
      still accepted for backward compatibility
    * `:max_elements` - index capacity (default `100`)
  """
  def new_index(opts \\ []) do
    space = Keyword.get(opts, :space, :l2)
    dim = Keyword.get_lazy(opts, :dim, fn -> Keyword.get(opts, :din, 128) end)
    max_elements = Keyword.get(opts, :max_elements, 100)

    space
    |> instance().new(dim, max_elements)
    |> unwrap!()
  end

  @doc "Runs a kNN query against `index` (thin wrapper over the HNSW backend)."
  # Cannot be defdelegate because instance() is dynamic
  def knn_query(index, query, opts \\ []) do
    instance().knn_query(index, query, opts)
  end

  @doc """
  Unwraps an `{:ok, value}` tuple; raises `RuntimeError` on `{:error, reason}`.
  """
  def unwrap!({:ok, value}), do: value

  def unwrap!({:error, reason}) do
    raise "#{__MODULE__}.unwrap!/1 received an error: #{inspect(reason)}"
  end

  defp instance do
    HNSWLib.Index
  end
end