How can I speed up this Prim's algorithm implementation?

My implementation seems “correct” but times out on LeetCode for the longer test cases. I tried using Task.async_stream to process some of the steps concurrently, but that made it worse. Any suggestions?

defmodule Solution do
  @spec min_cost_connect_points(points :: [[integer]]) :: integer
  def min_cost_connect_points(points) when length(points) < 2, do: 0
  def min_cost_connect_points(points) do
    vertices = points |> Enum.with_index() |> Enum.map(fn {pt, i} -> {i, pt} end) |> Enum.into(%{})
    len = length(points)
    
    set = MapSet.new(1..len - 1)
    visited = MapSet.new([0])
    cost = 0
    do_prims(vertices, set, visited, cost)
  end

  defp do_prims(vertices, set, visited, cost) do
    if MapSet.size(set) == 0 do
      cost
    else
      {next_vertex, min_dist} = smallest_distance(vertices, set, visited)
      visited = MapSet.put(visited, next_vertex)
      cost = cost + min_dist
      set = MapSet.difference(set, visited)
      do_prims(vertices, set, visited, cost)
    end
  end

  defp gen_weights(vertices, current, candidates) do
    m =
      Enum.reduce(candidates, %{}, fn i, map ->
        Map.put(map, i, distance(vertices[i], vertices[current]))
      end)

    %{current => m}
  end

  defp smallest_distance(vertices, candidates, visited) do
    visited
    |> Enum.map(fn i -> smallest_helper(vertices, i, candidates) end)
    |> Enum.min_by(&elem(&1, 1))
  end

  defp smallest_helper(vertices, i, candidates) do
    weights = gen_weights(vertices, i, candidates)

    candidates
    |> Enum.map(fn j -> {j, weights[i][j]} end)
    |> Enum.min_by(&elem(&1, 1))
  end

  defp distance([a, b], [c, d]), do: abs(a - c) + abs(b - d)
end

The usual culprit in these kinds of algorithms is doing the same work more than once inside a loop. Tracing the calls here: do_prims runs once per vertex, smallest_distance maps over every element of visited, and each smallest_helper call (via gen_weights) computes a distance to every remaining candidate.

So this implementation takes time proportional to the cube of the input size.

As you’ve noted, Task.async_stream is not going to save you from cubic growth for long.

One piece of work that jumps out right away: in smallest_distance, smallest_helper is called for every element of visited on every iteration, recomputing the same distances each time, even though visited only grows by one element per iteration.
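The usual way around that is to carry the best known edge weight from the tree to each unvisited vertex, and after adding a vertex compare the remaining candidates against that one new vertex only. Something along these lines (a rough, untested sketch; PrimSketch and everything inside it is just for illustration), which makes each iteration a single pass over the candidates and the whole run quadratic rather than cubic:

defmodule PrimSketch do
  @spec min_cost_connect_points(points :: [[integer]]) :: integer
  def min_cost_connect_points(points) when length(points) < 2, do: 0

  def min_cost_connect_points(points) do
    vertices =
      points |> Enum.with_index() |> Map.new(fn {pt, i} -> {i, pt} end)

    # best[i] = cheapest known edge from the growing tree to vertex i,
    # seeded with the distances from vertex 0
    best =
      for {i, pt} <- vertices, i != 0, into: %{} do
        {i, distance(pt, vertices[0])}
      end

    do_prims(vertices, best, 0)
  end

  defp do_prims(_vertices, best, cost) when map_size(best) == 0, do: cost

  defp do_prims(vertices, best, cost) do
    # pick the unvisited vertex with the cheapest edge into the tree
    {next_vertex, min_dist} = Enum.min_by(best, fn {_i, d} -> d end)
    best = Map.delete(best, next_vertex)

    # only edges to the newly added vertex can improve a candidate's best edge
    best =
      Map.new(best, fn {i, d} ->
        {i, min(d, distance(vertices[i], vertices[next_vertex]))}
      end)

    do_prims(vertices, best, cost + min_dist)
  end

  defp distance([a, b], [c, d]), do: abs(a - c) + abs(b - d)
end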


Thanks for the detailed reply. I knew that was the bottleneck. I must not understand the logic behind the algorithm though, because when I did not iterate over all visited nodes in smallest_distance I got the wrong answer.

I’ve tried a new approach using :gb_trees as a priority queue, and it seems more efficient but also prone to error. I suspect the issue is that using distances as keys means a node’s minimum distance can be silently overwritten whenever another node ends up at the same distance, so some minimum edges get lost. Using node indices as the keys instead would eliminate the efficiency boost of using balanced trees in the first place. So I forced it to work by making the tree store a list of node indices for each distance key and handling the case where that list empties out. It’s not fast, but it completes in the allowed time.
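To show what I mean about overwriting: :gb_trees.enter/3 replaces the value when the key already exists, so with bare distances as keys a second node at the same distance clobbers the first (illustrative iex snippet):

t = :gb_trees.enter(5, :a, :gb_trees.empty())
t = :gb_trees.enter(5, :b, t)   # same distance key: :a is replaced, not queued alongside
:gb_trees.to_list(t)            #=> [{5, :b}]

Anyway, here's what I ended up with: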

  def min_cost_connect_points(points) do
    vertices = points |> Enum.with_index() |> Enum.map(fn {pt, i} -> {i, pt} end) |> Enum.into(%{})
    len = length(points)
    
    candidates = MapSet.new(0..len - 1)
    queue = :gb_trees.enter(0, [0], :gb_trees.empty()) 
    cost = 0
    do_prims(vertices, candidates, queue, cost)
  end

  defp do_prims(vertices, candidates, queue, cost) do
    if MapSet.size(candidates) == 0 do
      cost
    else
      {min_dist, [next_vertex | rest], new_queue} = :gb_trees.take_smallest(queue)

      new_queue =
        if rest == [] do
          new_queue
        else
          :gb_trees.enter(min_dist, rest, new_queue)
        end

      if not MapSet.member?(candidates, next_vertex) do
        # skip: next_vertex was removed from candidates because it has already been visited
        do_prims(vertices, candidates, new_queue, cost)
      else
        candidates = MapSet.delete(candidates, next_vertex)

        updated_queue =
          Enum.reduce(candidates, new_queue, fn i, q ->
            d = distance(vertices[i], vertices[next_vertex])

            is =
              if :gb_trees.is_defined(d, q) do
                [i | :gb_trees.get(d, q)]
              else
                [i]
              end

            :gb_trees.enter(d, is, q)
          end)

        do_prims(vertices, candidates, updated_queue, cost + min_dist)
      end
    end
  end
      
  defp distance([a, b], [c, d]), do: abs(a - c) + abs(b - d)
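
For reference, a rough sketch of the tuple-key variant (untested; SolutionTupleKeys is just a placeholder name): keying the queue on {distance, vertex} tuples instead of bare distances. Erlang compares tuples element by element, so take_smallest still pops a minimum-distance entry, keys for different vertices never collide, and the list bookkeeping goes away:

defmodule SolutionTupleKeys do
  # Sketch only: same structure as above, but the queue key is a
  # {distance, vertex} tuple and the stored value is ignored.
  def min_cost_connect_points(points) when length(points) < 2, do: 0

  def min_cost_connect_points(points) do
    vertices =
      points |> Enum.with_index() |> Enum.map(fn {pt, i} -> {i, pt} end) |> Enum.into(%{})

    candidates = MapSet.new(0..(length(points) - 1))
    queue = :gb_trees.enter({0, 0}, true, :gb_trees.empty())
    do_prims(vertices, candidates, queue, 0)
  end

  defp do_prims(vertices, candidates, queue, cost) do
    if MapSet.size(candidates) == 0 do
      cost
    else
      {{min_dist, next_vertex}, _ignored, queue} = :gb_trees.take_smallest(queue)

      if MapSet.member?(candidates, next_vertex) do
        candidates = MapSet.delete(candidates, next_vertex)

        queue =
          Enum.reduce(candidates, queue, fn i, q ->
            d = distance(vertices[i], vertices[next_vertex])
            # keys are {distance, vertex} pairs, so entries for different vertices never collide
            :gb_trees.enter({d, i}, true, q)
          end)

        do_prims(vertices, candidates, queue, cost + min_dist)
      else
        # stale entry for an already-visited vertex; just drop it
        do_prims(vertices, candidates, queue, cost)
      end
    end
  end

  defp distance([a, b], [c, d]), do: abs(a - c) + abs(b - d)
end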