# How can I speed up this Prim's algorithm implementation?

My implementation seems “correct” but times out on LeetCode for the longer test cases. I tried using Task.async_stream to concurrently process some of the steps but this made it worse. Any suggestions?

```elixir
defmodule Solution do
  @spec min_cost_connect_points(points :: [[integer]]) :: integer
  def min_cost_connect_points(points) when length(points) < 2, do: 0

  def min_cost_connect_points(points) do
    vertices = points |> Enum.with_index() |> Enum.map(fn {pt, i} -> {i, pt} end) |> Enum.into(%{})
    len = length(points)

    set = MapSet.new(1..(len - 1))
    visited = MapSet.new([0])
    cost = 0
    do_prims(vertices, set, visited, cost)
  end

  defp do_prims(vertices, set, visited, cost) do
    if MapSet.size(set) == 0 do
      cost
    else
      {next_vertex, min_dist} = smallest_distance(vertices, set, visited)
      visited = MapSet.put(visited, next_vertex)
      cost = cost + min_dist
      set = MapSet.difference(set, visited)
      do_prims(vertices, set, visited, cost)
    end
  end

  defp gen_weights(vertices, current, candidates) do
    m =
      candidates
      |> Enum.reduce(%{}, fn i, map ->
        Map.put(map, i, distance(vertices[i], vertices[current]))
      end)

    %{current => m}
  end

  defp smallest_distance(vertices, candidates, visited) do
    visited
    |> Enum.map(fn i -> smallest_helper(vertices, i, candidates) end)
    |> Enum.min_by(&elem(&1, 1))
  end

  defp smallest_helper(vertices, i, candidates) do
    weights = gen_weights(vertices, i, candidates)

    candidates
    |> Enum.map(fn j -> {j, weights[i][j]} end)
    |> Enum.min_by(&elem(&1, 1))
  end

  defp distance([a, b], [c, d]), do: abs(a - c) + abs(b - d)
end
```

The usual culprit in these kinds of algorithms is doing the same work more than once inside a loop. Tracing the calls here: each iteration of `do_prims` calls `smallest_distance`, which calls `smallest_helper` for every visited vertex, and each of those calls rebuilds the full distance map to every candidate. That is O(n²) work per iteration, repeated for up to n iterations, so this implementation takes time proportional to the cube of the input size.

`Task.async_stream` is not going to save you from cubic growth for long, as you’ve already noticed.

One piece of redundant work jumps out right away: `smallest_distance` calls `smallest_helper` for every element of `visited` on every iteration, but `visited` only grows by one element per iteration, so all but one of those calls repeat work already done in the previous step.
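To illustrate the idea (a sketch of textbook O(n²) Prim, not a drop-in for your module — the `PrimSketch` name and helpers are mine): keep a map from each unvisited vertex to its cheapest known distance to the tree, and after adding a vertex, relax the remaining entries against that one new vertex only.

```elixir
defmodule PrimSketch do
  # Sketch of O(n^2) Prim: `best` maps each unvisited index to its cheapest
  # known distance to the tree. Each iteration picks the minimum, then
  # relaxes the remaining entries against the newly added vertex only.
  def min_cost_connect_points(points) when length(points) < 2, do: 0

  def min_cost_connect_points(points) do
    vertices = points |> Enum.with_index() |> Map.new(fn {pt, i} -> {i, pt} end)

    # initial distances: from vertex 0 to every other vertex
    best =
      for i <- 1..(map_size(vertices) - 1), into: %{} do
        {i, distance(vertices[i], vertices[0])}
      end

    loop(vertices, best, 0)
  end

  defp loop(_vertices, best, cost) when map_size(best) == 0, do: cost

  defp loop(vertices, best, cost) do
    # pick the unvisited vertex closest to the tree
    {next, d} = Enum.min_by(best, fn {_i, dist} -> dist end)
    best = Map.delete(best, next)

    # relax: only the newly added vertex can improve any remaining distance
    best =
      Map.new(best, fn {i, old} ->
        {i, min(old, distance(vertices[i], vertices[next]))}
      end)

    loop(vertices, best, cost + d)
  end

  defp distance([a, b], [c, d]), do: abs(a - c) + abs(b - d)
end
```

Each iteration is now a single O(n) scan plus an O(n) relaxation pass, giving O(n²) overall.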


Thanks for the detailed reply. I knew that was the bottleneck, but I must not understand the logic behind the algorithm, because when I stopped iterating over all visited nodes in `smallest_distance` I got the wrong answer.

I’ve tried a new approach using `:gb_trees` as a priority queue. It seems more efficient but also prone to error: using distances as keys means entries get overwritten when two edges share the same distance, so some minimum distances for some nodes are lost, while using node indices as keys would throw away the efficiency benefit of the balanced tree in the first place. I forced it to work by having the tree store a list of node indices under each distance and handling the empty-list case. It’s not fast, but it completes within the allowed time.

```elixir
  def min_cost_connect_points(points) do
    vertices = points |> Enum.with_index() |> Enum.map(fn {pt, i} -> {i, pt} end) |> Enum.into(%{})
    len = length(points)

    candidates = MapSet.new(0..(len - 1))
    queue = :gb_trees.enter(0, [0], :gb_trees.empty())
    cost = 0
    do_prims(vertices, candidates, queue, cost)
  end

  defp do_prims(vertices, candidates, queue, cost) do
    if MapSet.size(candidates) == 0 do
      cost
    else
      {min_dist, [next_vertex | rest], new_queue} = :gb_trees.take_smallest(queue)

      new_queue =
        if rest == [] do
          new_queue
        else
          :gb_trees.enter(min_dist, rest, new_queue)
        end

      # skip if next_vertex was removed from candidates b/c it has already been visited
      if not MapSet.member?(candidates, next_vertex) do
        do_prims(vertices, candidates, new_queue, cost)
      else
        candidates = MapSet.delete(candidates, next_vertex)

        updated_queue =
          candidates
          |> Enum.reduce(new_queue, fn i, q ->
            d = distance(vertices[i], vertices[next_vertex])

            is =
              if :gb_trees.is_defined(d, q) do
                is = :gb_trees.get(d, q)
                [i | is]
              else
                [i]
              end

            :gb_trees.enter(d, is, q)
          end)

        do_prims(vertices, candidates, updated_queue, cost + min_dist)
      end
    end
  end

  defp distance([a, b], [c, d]), do: abs(a - c) + abs(b - d)
```
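For what it’s worth, one common way to sidestep the overwriting problem entirely is to key the tree on `{distance, node}` tuples: Erlang’s term ordering compares same-size tuples element by element, so `:gb_trees.take_smallest/1` still pops the cheapest edge first, and keys are unique per node, so no lists are needed. A sketch under that assumption (the `PrimGbTrees` module name is mine, and this hasn’t been run against the full LeetCode suite):

```elixir
defmodule PrimGbTrees do
  # Sketch: composite {distance, node} keys keep :gb_trees keys unique,
  # so nothing is overwritten and no list-of-indices bookkeeping is needed.
  def min_cost_connect_points(points) when length(points) < 2, do: 0

  def min_cost_connect_points(points) do
    vertices = points |> Enum.with_index() |> Map.new(fn {pt, i} -> {i, pt} end)
    candidates = MapSet.new(0..(length(points) - 1))
    queue = :gb_trees.insert({0, 0}, true, :gb_trees.empty())
    do_prims(vertices, candidates, queue, 0)
  end

  defp do_prims(vertices, candidates, queue, cost) do
    if MapSet.size(candidates) == 0 do
      cost
    else
      # tuples compare element-wise, so the smallest key is the cheapest edge
      {{min_dist, next_vertex}, _marker, queue} = :gb_trees.take_smallest(queue)

      if MapSet.member?(candidates, next_vertex) do
        candidates = MapSet.delete(candidates, next_vertex)

        queue =
          Enum.reduce(candidates, queue, fn i, q ->
            d = distance(vertices[i], vertices[next_vertex])
            # enter/3 (unlike insert/3) tolerates a {d, i} key already present
            :gb_trees.enter({d, i}, true, q)
          end)

        do_prims(vertices, candidates, queue, cost + min_dist)
      else
        # stale entry for an already-visited vertex; skip it
        do_prims(vertices, candidates, queue, cost)
      end
    end
  end

  defp distance([a, b], [c, d]), do: abs(a - c) + abs(b - d)
end
```

The stale-entry skip plays the same role as your `MapSet.member?` check; the only change is that duplicate distances can never collide in the key space.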