How would you leverage or translate an existing, clear `Enum.reduce/3`-alike struct method into a valid `defimpl` for `Enumerable.reduce/3`?
When implementing a “simple” reducer, it’s clear enough (at least how to map the call back into `Enum.reduce`), but I just cannot wrap my head around how to interface appropriately with all those bizarre status atoms… like, what’s the difference between `:halted`, `:suspended`, and just plain `:done`?
Example: a lazy `Multiset` struct with an `Enum.reduce/3`-alike method but no `Enumerable.reduce/3` implementation. (Ctrl-F for “defimpl” to skip to the `Enumerable` implementation block.)
defmodule ISHYGDDT.Multiset do
@moduledoc """
A multiset (bag): an unordered container in which an element may occur
more than once. The struct wraps a map from each element to its count.
"""
defstruct counts: %{}
defmodule Multiplicities do
@moduledoc """
Operations on plain (non-struct) multiplicity maps, i.e. maps from an element
to its count.

Elements with a count of 0 SHOULD NOT be present as keys; `from_counts/1` can
be used to clean up a "sloppy" map that still contains zero counts.
"""
@type t(element) :: %{optional(element) => pos_integer()}
@type t() :: t(term())

# Like `t/1`, but as any enumerable of entries, where zero counts are allowed.
@typep counts_lax(element) :: Enumerable.t({element, non_neg_integer()})

defguardp is_pos_integer(n) when is_integer(n) and n > 0
defguardp is_non_neg_integer(n) when is_integer(n) and n >= 0

# Lazily expands `{element, count}` entries into `count` copies of each element.
defp s_dupe_flat(entries) do
  Stream.flat_map(entries, fn {element, count} -> Stream.duplicate(element, count) end)
end
@doc """
Returns `true` iff the argument is a well-formed multiplicity map: a map whose
every value is a positive integer. Any non-map argument yields `false`.
"""
@spec ok?(maybe_counts :: term()) :: boolean()
def ok?(maybe_counts) when is_map(maybe_counts) do
  maybe_counts
  |> Map.values()
  |> Enum.all?(fn count -> is_pos_integer(count) end)
end

def ok?(_not_a_map), do: false
@doc """
Builds a multiplicity map from an enumerable of individual elements, counting
how many times each one occurs.

This mirrors what Python's [`collections.Counter` constructor](https://docs.python.org/3/library/collections.html#collections.Counter)
does when given a non-mapping iterable.

See also `Multiplicities.from_counts/1`.
"""
@spec from_elements(Enumerable.t(e)) :: t(e) when e: term()
def from_elements(enumerable), do: Enum.frequencies(enumerable)
@doc """
Builds a multiplicity map from `{element, count}` entries.

A plain map is filtered directly. Any other enumerable is folded entry by
entry, so an element may appear several times and its counts are added
together. (This may cost a little performance; build the map yourself if
that matters.)

Entries with a count of 0 are dropped. An entry that is not an
`{element, count}` pair with a non-negative integer count (which includes
any negative count) raises `ArgumentError`.

See also `Multiplicities.from_elements/1`.
"""
@spec from_counts(counts_lax(e)) :: t(e) when e: term()
def from_counts(map) when is_non_struct_map(map) do
  :maps.filtermap(fn element, count -> checked_count!({element, count}) > 0 end, map)
end

def from_counts(enumerable) do
  # keys may repeat here, so counts are accumulated rather than copied
  Enum.reduce(enumerable, %{}, fn entry, acc ->
    case checked_count!(entry) do
      0 ->
        acc

      count ->
        {element, _} = entry
        Map.update(acc, element, count, &(&1 + count))
    end
  end)
end

# Returns the count of a valid `{element, count}` entry; raises otherwise.
defp checked_count!({_element, count}) when is_non_neg_integer(count), do: count

defp checked_count!(other),
  do: raise(ArgumentError, "expected `{element :: term(), count :: non_neg_integer()}`, got: #{inspect other}")
@doc """
Returns the number of distinct elements, i.e. the size of the support.
"""
@spec support_count(t()) :: non_neg_integer()
def support_count(multiplicities), do: map_size(multiplicities)
@doc """
Returns how many times `element` occurs, or 0 when it has no entry.
"""
@spec count_element(t(e), e) :: non_neg_integer() when e: term()
def count_element(multiplicities, element), do: Map.get(multiplicities, element, 0)
@doc """
Returns the distinct elements (the support) as a `MapSet`.
"""
@spec support(t(e)) :: support :: MapSet.t(e) when e: term()
def support(multiplicities) do
  multiplicities
  |> Map.keys()
  |> MapSet.new()
end
@doc """
Folds `fun` over the multiplicity map with `Enum.reduce/3`, in one of three views:

  * `:elements` - each element, once per occurrence;
  * `:support` - each distinct element once;
  * `:multiplicities` - the raw `{element, count}` entries.
"""
@spec reduce_by(t(e), :elements, t_acc, (e, t_acc -> t_acc)) :: t_acc when e: term(), t_acc: term()
@spec reduce_by(t(e), :support, t_acc, (e, t_acc -> t_acc)) :: t_acc when e: term(), t_acc: term()
@spec reduce_by(t(e), :multiplicities, t_acc, ({e, pos_integer()}, t_acc -> t_acc)) :: t_acc when e: term(), t_acc: term()
def reduce_by(multiplicities, mode, acc, fun)
    when mode in [:elements, :support, :multiplicities] do
  source =
    case mode do
      :elements -> s_dupe_flat(multiplicities)
      :support -> Stream.map(multiplicities, fn {element, _count} -> element end)
      :multiplicities -> multiplicities
    end

  Enum.reduce(source, acc, fun)
end
@doc """
Lists every element, repeated once per occurrence, in map-entry order.
"""
@spec to_list(t(e)) :: [e] when e: term()
def to_list(multiplicities) do
  Enum.flat_map(multiplicities, fn {element, count} -> List.duplicate(element, count) end)
end
@doc """
Converts a set, or any enumerable, into a multiplicity map in which every
distinct member has a count of 1.
"""
@spec from_set(MapSet.t(e) | Enumerable.t(e)) :: t(e) when e: term()
def from_set(set) do
  # `Map.from_keys/2` only accepts a list, so handing it a `MapSet` struct
  # raises; `Map.new/2` accepts any enumerable.
  Map.new(set, fn element -> {element, 1} end)
end
# Enumerable type methods
@doc """
Returns the cardinality of the multiset: the sum of all counts, so repeated
occurrences are included.
"""
@spec count(t()) :: non_neg_integer()
def count(multiplicities) do
  multiplicities
  |> Map.values()
  |> Enum.sum()
end
@doc """
Returns `true` iff `element` has an entry in the multiplicity map.
"""
@spec member?(t(), term()) :: boolean()
def member?(multiplicities, element), do: Map.has_key?(multiplicities, element)
# Prepends `count` copies of `element` onto `list` (the same loop shape as
# `:lists.duplicate/2`, but accumulating onto an existing list):
# https://github.com/erlang/otp/blob/OTP-27.2.1/lib/stdlib/src/lists.erl#L512
defp prepend_duplicate(count, element, list)
defp prepend_duplicate(0, _, l), do: l
defp prepend_duplicate(n, x, l), do: prepend_duplicate(n - 1, x, [x | l])

@doc """
Builds the `{size, slicing_fun}` pair used to implement `Enumerable.slice/1`.

`size` is the cardinality of the multiset. `slicing_fun.(start, length, step)`
returns the `length` elements at positions `start`, `start + step`, ...,
where positions follow the order in which `Enum.reduce/3` visits the map's
entries, with all copies of an element adjacent.

Indexing takes O(support) time once. Each slice then does one
O(log(support)) lookup per distinct element it draws from, plus O(length)
to build the result, instead of walking every position.
"""
@spec enumerable_slice(t(e)) :: {size :: non_neg_integer(), slicing_fun :: (start :: non_neg_integer(), length :: pos_integer(), step :: pos_integer() -> [e])} when e: term()
def enumerable_slice(multiplicities) do
  # Key every run of equal elements by the position of its first copy.
  # Zero counts are skipped: they occupy no positions and would collide
  # with the next run's key.
  {size, tree} =
    Enum.reduce(multiplicities, {0, :gb_trees.empty()}, fn
      {_element, 0}, acc ->
        acc

      {element, count}, {first_position, tree} ->
        {first_position + count, :gb_trees.insert(first_position, {element, count}, tree)}
    end)

  {size, fn start, length, step -> slice_runs(tree, start, length, step, []) end}
end

# Collects `remaining` elements at `position`, `position + step`, ... into
# `acc` (in reverse). Each iteration jumps straight to the run containing
# `position` and takes every requested position that falls inside that run.
defp slice_runs(_tree, _position, 0, _step, acc), do: :lists.reverse(acc)

defp slice_runs(tree, position, remaining, step, acc) do
  # `:gb_trees.smaller/2` returns the largest key strictly below its
  # argument, i.e. here the run whose first position is <= `position`.
  case :gb_trees.smaller(position + 1, tree) do
    {first, {element, count}} when position < first + count ->
      # requested positions inside this run: position, position + step, ...
      # up to and including the run's last position `first + count - 1`
      taken = min(remaining, div(first + count - 1 - position, step) + 1)
      acc = prepend_duplicate(taken, element, acc)
      slice_runs(tree, position + taken * step, remaining - taken, step, acc)

    _past_the_end ->
      # Enumerable is expected to request only in-range positions; stop
      # instead of looping if that assumption is ever violated.
      :lists.reverse(acc)
  end
end
@doc """
Returns `true` iff every element of `lhs` occurs in `rhs` at least as many
times as in `lhs` (a non-strict subset).

See also `Multiplicities.difference/2` and `Multiplicities.difference!/2`.
"""
@spec subset?(t(), t()) :: boolean()
def subset?(lhs, rhs) do
  Enum.all?(lhs, fn {element, needed} ->
    case Map.fetch(rhs, element) do
      {:ok, available} -> needed <= available
      :error -> false
    end
  end)
end
@doc """
Returns the multiset union: each element gets the larger of its two counts.
"""
@spec union(t(e1), t(e2)) :: t(e1 | e2) when e1: term(), e2: term()
def union(lhs, rhs) do
  Enum.reduce(rhs, lhs, fn {element, count}, acc ->
    Map.update(acc, element, count, &max(&1, count))
  end)
end
@doc """
Returns the multiset intersection: every element present in both arguments,
with the smaller of its two counts.
"""
@spec intersection(t(e1 | e3), t(e2 | e3)) :: t(e3) when e1: term(), e2: term(), e3: term()
def intersection(lhs, rhs) do
  # Walk the smaller map and look each element up in the larger one.
  # (`map_size/1` already rejects non-maps.)
  {smaller, larger} = if map_size(lhs) <= map_size(rhs), do: {lhs, rhs}, else: {rhs, lhs}

  :maps.filtermap(
    fn element, count ->
      case larger do
        %{^element => other_count} -> {true, min(count, other_count)}
        %{} -> false
      end
    end,
    smaller
  )
end
@doc """
Returns the multiset sum: each element's counts from both arguments are added.
"""
@spec sum(t(e1), t(e2)) :: t(e1 | e2) when e1: term(), e2: term()
def sum(lhs, rhs), do: Map.merge(lhs, rhs, fn _element, left, right -> left + right end)
@doc """
Subtracts `rhs` from `lhs`, clamping at zero: elements whose count would drop
to 0 or below are removed, and elements absent from `rhs` are kept unchanged.

See also `Multiplicities.subset?/2` and `Multiplicities.difference!/2`.
"""
@spec difference(t(e), t()) :: t(e) when e: term()
def difference(lhs, rhs) do
  :maps.filtermap(
    fn element, count ->
      case rhs do
        %{^element => removed} when count > removed -> {true, count - removed}
        %{^element => _removed} -> false
        %{} -> true
        not_a_map -> :erlang.error({:badmap, not_a_map})
      end
    end,
    lhs
  )
end
@doc """
Subtracts the second argument from the first, strictly.

Raises `KeyError` iff the second argument is not a subset of the first, i.e.
when some element occurs more often in `rhs` than in `lhs`. Elements whose
count drops to exactly 0 are removed.

See also `Multiplicities.subset?/2` and `Multiplicities.difference/2`.
"""
@spec difference!(t(e), t()) :: t(e) when e: term()
def difference!(lhs, rhs) do
  :maps.fold(fn element, count_2, acc ->
    case acc do
      %{^element => count_1} when count_1 >= count_2 ->
        count_3 = count_1 - count_2
        if count_3 > 0 do
          %{acc | element => count_3}
        else
          Map.delete(acc, element)
        end
      %{} ->
        # missing, or not enough copies to remove
        raise KeyError, term: lhs, key: {__MODULE__, {element, count_2}}
      other ->
        :erlang.error({:badmap, other})
    end
  end, lhs, rhs)
end
@doc """
Returns the multiset symmetric difference: each element's count is the
absolute difference of its two counts, and elements where they agree vanish.
"""
@spec symmetric_difference(t(e1), t(e2)) :: t(e1 | e2) when e1: term(), e2: term()
def symmetric_difference(lhs, rhs) do
  # Fold the map that is no larger into the other one.
  {base, folded} =
    if map_size(rhs) > map_size(lhs), do: {rhs, lhs}, else: {lhs, rhs}

  :maps.fold(
    fn element, count, acc ->
      case acc do
        %{^element => base_count} ->
          diff = abs(base_count - count)
          if diff > 0, do: %{acc | element => diff}, else: Map.delete(acc, element)

        %{} ->
          Map.put(acc, element, count)

        not_a_map ->
          :erlang.error({:badmap, not_a_map})
      end
    end,
    base,
    folded
  )
end
end
@typedoc "A multiset whose elements have type `element`."
@type t(element) :: %__MODULE__{counts: Multiplicities.t(element)}
@typedoc "A multiset of arbitrary elements."
@type t() :: t(term())
# create structs
@doc """
Creates a multiset from `arg`, interpreted according to its type:

  * a plain (non-struct) map is read as `element => count` (see `from_counts/1`);
  * a list is read as individual elements (see `from_elements/1`);
  * a `MapSet` is converted with `Multiplicities.from_set/1`.

Anything else is ambiguous and raises `ArgumentError`; call `from_elements/1`
or `from_counts/1` explicitly instead. With no argument, returns an empty multiset.
"""
def new(arg \\ %{}) do
  counts =
    case arg do
      map when is_non_struct_map(map) -> Multiplicities.from_counts(map)
      list when is_list(list) -> Multiplicities.from_elements(list)
      %MapSet{} = set -> Multiplicities.from_set(set)
      _ambiguous_enumerable -> raise ArgumentError, "Bad initializer. Explicitly call from_elements/1 or from_counts/1 instead"
    end

  %__MODULE__{counts: counts}
end
@doc "Builds a multiset from individual elements; see `Multiplicities.from_elements/1`."
def from_elements(enumerable) do
  %__MODULE__{counts: Multiplicities.from_elements(enumerable)}
end

@doc "Builds a multiset from `{element, count}` entries; see `Multiplicities.from_counts/1`."
def from_counts(enumerable) do
  %__MODULE__{counts: Multiplicities.from_counts(enumerable)}
end
# query structs
@doc "Returns `true` iff `element` occurs in the multiset."
def member?(%__MODULE__{counts: counts}, element), do: Multiplicities.member?(counts, element)

@doc "Returns the total number of occurrences, counting repeats."
def count(%__MODULE__{counts: counts}), do: Multiplicities.count(counts)

@doc "Returns the distinct elements as a `MapSet`."
def support(%__MODULE__{counts: counts}), do: Multiplicities.support(counts)

@doc "Returns the number of distinct elements."
def support_count(%__MODULE__{counts: counts}), do: Multiplicities.support_count(counts)

@doc "Returns how many times `element` occurs (0 if absent)."
def count_element(%__MODULE__{counts: counts}, element), do: Multiplicities.count_element(counts, element)
# manipulate structs
@doc "Returns `true` iff every element of `lhs` occurs at least as often in `rhs`."
def subset?(%__MODULE__{counts: lhs}, %__MODULE__{counts: rhs}), do: Multiplicities.subset?(lhs, rhs)

@doc "Returns the underlying `element => count` map."
def to_counts(%__MODULE__{counts: counts}), do: counts

@doc "Lists every element, once per occurrence."
def to_list(%__MODULE__{counts: counts}), do: Multiplicities.to_list(counts)

@doc "Folds over the multiset; see `Multiplicities.reduce_by/4` for the available modes."
def reduce_by(%__MODULE__{counts: counts}, mode, acc, fun),
  do: Multiplicities.reduce_by(counts, mode, acc, fun)
# combine structs
@doc "Multiset union (per-element maximum); see `Multiplicities.union/2`."
def union(%__MODULE__{counts: lhs}, %__MODULE__{counts: rhs}),
  do: wrap_counts(Multiplicities.union(lhs, rhs))

@doc "Multiset intersection; see `Multiplicities.intersection/2`."
def intersection(%__MODULE__{counts: lhs}, %__MODULE__{counts: rhs}),
  do: wrap_counts(Multiplicities.intersection(lhs, rhs))

@doc "Multiset sum (per-element total); see `Multiplicities.sum/2`."
def sum(%__MODULE__{counts: lhs}, %__MODULE__{counts: rhs}),
  do: wrap_counts(Multiplicities.sum(lhs, rhs))

@doc "Clamping difference; see `Multiplicities.difference/2`."
def difference(%__MODULE__{counts: lhs}, %__MODULE__{counts: rhs}),
  do: wrap_counts(Multiplicities.difference(lhs, rhs))

@doc "Strict difference, raising `KeyError`; see `Multiplicities.difference!/2`."
def difference!(%__MODULE__{counts: lhs}, %__MODULE__{counts: rhs}),
  do: wrap_counts(Multiplicities.difference!(lhs, rhs))

@doc "Symmetric difference; see `Multiplicities.symmetric_difference/2`."
def symmetric_difference(%__MODULE__{counts: lhs}, %__MODULE__{counts: rhs}),
  do: wrap_counts(Multiplicities.symmetric_difference(lhs, rhs))

# Re-wraps a bare multiplicity map in the struct.
defp wrap_counts(counts), do: %__MODULE__{counts: counts}
# protocol helpers
@doc """
Renders a multiset for the `Inspect` protocol.

An empty multiset prints as `Multiset.new()`. If the total count is at least
`(support_count + 1) * (support_count + 1)` (i.e. `floor(sqrt(count))`
exceeds the number of distinct elements), it prints the compact
`Multiset.from_counts(%{...})` form, so that e.g. `%{42 => 10**100}` is not
expanded. Otherwise every occurrence is printed as `Multiset.new([...])`.
"""
def inspect(multiset, opts) do
  size = count(multiset)
  support = support_count(multiset)

  cond do
    size == 0 ->
      "Multiset.new()"

    # Exact integer form of `floor(sqrt(size)) > support`. `:math.sqrt/1`
    # would convert `size` to a float, which loses precision for large
    # counts and raises ArithmeticError past the float range.
    (support + 1) * (support + 1) <= size ->
      Inspect.Algebra.concat([
        "Multiset.from_counts(",
        Inspect.Map.inspect(to_counts(multiset), opts),
        ")"
      ])

    true ->
      # https://github.com/elixir-lang/elixir/blob/v1.18.2/lib/elixir/lib/map_set.ex#L444
      opts = %Inspect.Opts{opts | charlists: :as_lists}

      Inspect.Algebra.concat([
        "Multiset.new(",
        Inspect.List.inspect(to_list(multiset), opts),
        ")"
      ])
  end
end
@doc """
Folds `fun` over every element of `multiset`, once per occurrence.

This is a plain fold with a bare accumulator, as in `Enum.reduce/3`, not the
tagged `{:cont, acc}` form used by `Enumerable.reduce/3`.
"""
def reduce(%__MODULE__{counts: counts}, acc, fun),
  do: Multiplicities.reduce_by(counts, :elements, acc, fun)
@doc """
Returns `{:ok, size, slicing_fun}`, the shape expected from `Enumerable.slice/1`.
"""
def slice(%__MODULE__{counts: counts}) do
  {total, slicer} = Multiplicities.enumerable_slice(counts)
  {:ok, total, slicer}
end
defimpl Inspect, for: ISHYGDDT.Multiset do
  # Rendering lives in `ISHYGDDT.Multiset.inspect/2`; this only wires it up.
  def inspect(multiset, opts), do: ISHYGDDT.Multiset.inspect(multiset, opts)
end
defimpl Enumerable, for: ISHYGDDT.Multiset do
  # `Enumerable.count/1` and `member?/2` must wrap their answers in
  # `{:ok, _}`; delegating straight to functions that return a bare integer
  # or boolean breaks `Enum.count/1` and `Enum.member?/2`.
  def count(multiset), do: {:ok, ISHYGDDT.Multiset.count(multiset)}

  def member?(multiset, element), do: {:ok, ISHYGDDT.Multiset.member?(multiset, element)}

  # `ISHYGDDT.Multiset.slice/1` already returns `{:ok, size, slicing_fun}`.
  def slice(multiset), do: ISHYGDDT.Multiset.slice(multiset)

  # `Enumerable.reduce/3` cannot simply call `Enum.reduce/3`: the caller
  # drives the traversal one element at a time through a tagged accumulator.
  #
  #   * `{:cont, acc}`    - keep going; once exhausted, answer `{:done, acc}`.
  #   * `{:halt, acc}`    - the caller is finished early (e.g. `Enum.take/2`);
  #                         stop at once and answer `{:halted, acc}`.
  #   * `{:suspend, acc}` - the caller wants to pause (e.g. `Stream.zip/2`);
  #                         answer `{:suspended, acc, continuation}`, where
  #                         `continuation` resumes from exactly this point
  #                         when given the next tagged accumulator.
  #
  # `fun.(element, acc)` itself returns the next tagged accumulator.
  # Entries are visited in `:maps.to_list/1` order, the same order a plain
  # `Enum.reduce/3` over the counts map uses.
  def reduce(%ISHYGDDT.Multiset{counts: counts}, acc, fun) do
    reduce_entries(:maps.to_list(counts), acc, fun)
  end

  defp reduce_entries(_entries, {:halt, acc}, _fun), do: {:halted, acc}

  defp reduce_entries(entries, {:suspend, acc}, fun),
    do: {:suspended, acc, &reduce_entries(entries, &1, fun)}

  defp reduce_entries([], {:cont, acc}, _fun), do: {:done, acc}

  # Emit one copy of `element`, then continue with one fewer remaining.
  defp reduce_entries([{element, count} | rest], {:cont, acc}, fun) when count > 0,
    do: reduce_entries([{element, count - 1} | rest], fun.(element, acc), fun)

  # Every copy of this element has been emitted; move on to the next entry.
  defp reduce_entries([{_element, 0} | rest], {:cont, _acc} = tagged, fun),
    do: reduce_entries(rest, tagged, fun)
end
end