Nx.Random.multivariate_normal fails with an ArithmeticError

Hi! I’m trying to generate some Gaussian-distributed variables, but I encountered an error:

** (ArithmeticError) bad argument in arithmetic expression
(stdlib 4.3.1.3) :math.sqrt(-4.303224443447107e-13)
(complex 0.5.0) lib/complex.ex:772: Complex.sqrt/1
(nx 0.7.1) lib/nx/binary_backend.ex:2551: Nx.BinaryBackend."-binary_to_binary/4-lbc$^6/2-11-"/5
(nx 0.7.1) lib/nx/binary_backend.ex:933: Nx.BinaryBackend.element_wise_unary_op/3
(nx 0.7.1) lib/nx/defn/evaluator.ex:441: Nx.Defn.Evaluator.eval_apply/4
(nx 0.7.1) lib/nx/defn/evaluator.ex:256: Nx.Defn.Evaluator.eval/3
(nx 0.7.1) lib/nx/defn/evaluator.ex:359: Nx.Defn.Evaluator.eval_apply/4

Please see the following code:

  defn gaussian_rbf(x1, x2, l \\ 1.0, sigma_f \\ 1.0) do
    dist_matrix =
      Nx.sum(Nx.pow(x1, 2), axes: [1], keep_axes: true) +
        Nx.sum(Nx.pow(x2, 2), axes: [1]) -
        2 * Nx.dot(x1, Nx.transpose(x2))

    Nx.pow(sigma_f, 2) * Nx.exp(-1 / (2 * Nx.pow(l, 2)) * dist_matrix)
  end

  def sample_gp(n \\ 20) do
    x =
      Nx.linspace(0, 1, n: n, type: {:f, 64})
      |> Nx.reshape({:auto, 1})

    mu = Nx.broadcast(0.0, {n, 1})
    cov = gaussian_rbf(x, x)

    IO.inspect(cov, limit: :infinity)

    {multivariate_normal, _new_key} =
      Nx.Random.key(59)
      |> Nx.Random.multivariate_normal(Nx.flatten(mu), cov)

    {mu, cov, x, multivariate_normal}
  end

Thank you! :smiley:

Edit: the cov matrix is below. I can use it in Python (np.random.multivariate_normal(mu.reshape(-1), cov, 10)), but it raises RuntimeWarning: covariance is not symmetric positive-semidefinite.

The matrix is very close to the one generated by NumPy with the same method, but differs slightly (a quick check is sketched after the tensor below):

#Nx.Tensor<
f64[20][20]
[
[1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622919, 0.8457004773919743, 0.8191844691944301, 0.7913048223601589, 0.762259565588209, 0.7322492269229477, 0.7014745656989334, 0.6701343875280227, 0.638423474554494, 0.6065306597126334],
[0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622918, 0.8457004773919744, 0.8191844691944301, 0.7913048223601589, 0.762259565588209, 0.7322492269229478, 0.7014745656989335, 0.6701343875280227, 0.638423474554494],
[0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622918, 0.8457004773919744, 0.8191844691944301, 0.7913048223601588, 0.762259565588209, 0.7322492269229478, 0.7014745656989334, 0.6701343875280227],
[0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622919, 0.8457004773919744, 0.81918446919443, 0.7913048223601589, 0.7622595655882091, 0.7322492269229478, 0.7014745656989334],
[0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622919, 0.8457004773919743, 0.8191844691944301, 0.7913048223601589, 0.762259565588209, 0.7322492269229477],
[0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528131, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622918, 0.8457004773919744, 0.8191844691944301, 0.7913048223601589, 0.762259565588209],
[0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622919, 0.8457004773919744, 0.8191844691944301, 0.7913048223601589],
[0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622919, 0.8457004773919743, 0.81918446919443],
[0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622918, 0.8457004773919743],
[0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528131, 0.9343847048354291, 0.9151725435669718, 0.8938758661318704, 0.8706596335622919],
[0.8706596335622919, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.9876119969931421, 0.9780830788850546, 0.9659665827628695, 0.9513611828528131, 0.9343847048354292, 0.9151725435669716, 0.8938758661318704],
[0.8457004773919743, 0.8706596335622918, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528131, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716],
[0.8191844691944301, 0.8457004773919744, 0.8706596335622918, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291],
[0.7913048223601589, 0.8191844691944301, 0.8457004773919744, 0.8706596335622919, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.9876119969931421, 0.9944751522138855, 0.998615917176126, 1.0, 0.9986159171761259, 0.9944751522138856, 0.9876119969931421, 0.9780830788850546, 0.9659665827628697, 0.9513611828528132],
[0.762259565588209, 0.7913048223601589, 0.8191844691944301, 0.8457004773919744, 0.8706596335622919, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.9986159171761259, 1.0, 0.998615917176126, 0.9944751522138856, 0.987611996993142, 0.9780830788850546, 0.9659665827628696],
[0.7322492269229477, 0.762259565588209, 0.7913048223601588, 0.81918446919443, 0.8457004773919743, 0.8706596335622918, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528131, 0.9659665827628695, 0.9780830788850546, 0.987611996993142, 0.9944751522138856, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138856, 0.987611996993142, 0.9780830788850546],
[0.7014745656989334, 0.7322492269229478, 0.762259565588209, 0.7913048223601589, 0.8191844691944301, 0.8457004773919744, 0.8706596335622919, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528131, 0.9659665827628696, 0.9780830788850546, 0.9876119969931421, 0.9944751522138856, 0.998615917176126, 1.0, 0.9986159171761261, 0.9944751522138856, 0.987611996993142],
[0.6701343875280227, 0.7014745656989335, 0.7322492269229478, 0.7622595655882091, 0.7913048223601589, 0.8191844691944301, 0.8457004773919744, 0.8706596335622919, 0.8938758661318704, 0.9151725435669718, 0.9343847048354292, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138856, 0.9986159171761261, 1.0, 0.9986159171761261, 0.9944751522138856],
[0.638423474554494, 0.6701343875280227, 0.7014745656989334, 0.7322492269229478, 0.762259565588209, 0.7913048223601589, 0.8191844691944301, 0.8457004773919743, 0.8706596335622918, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628697, 0.9780830788850546, 0.987611996993142, 0.9944751522138856, 0.9986159171761261, 1.0, 0.998615917176126],
[0.6065306597126334, 0.638423474554494, 0.6701343875280227, 0.7014745656989334, 0.7322492269229477, 0.762259565588209, 0.7913048223601589, 0.81918446919443, 0.8457004773919743, 0.8706596335622919, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138856, 0.998615917176126, 1.0]
]
>
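
Since the warning is about the covariance not being symmetric positive-semidefinite, a quick way to quantify that is to look at the largest asymmetry and the smallest eigenvalue. This is just a sketch of such a check (it assumes Nx.LinAlg.eigh tolerates the tiny asymmetry):

# Measure how far cov is from being exactly symmetric.
max_asymmetry =
  cov
  |> Nx.subtract(Nx.transpose(cov))
  |> Nx.abs()
  |> Nx.reduce_max()

# eigh expects a (near-)Hermitian input; the smallest eigenvalue tells us
# whether cov is numerically positive-semidefinite.
{eigenvalues, _eigenvectors} = Nx.LinAlg.eigh(cov)
min_eigenvalue = Nx.reduce_min(eigenvalues)

IO.inspect({max_asymmetry, min_eigenvalue})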

How are you calling this code?

edit: also, what is the Nx version you’re using?

How are you calling this code?

Just GaussianProcess.sample_gp() (both functions are in the GaussianProcess module).

what is the Nx version you’re using?

{:nx, "~> 0.7.1"}

Unfortunately, I wasn’t able to reproduce the problem locally. However, given that you get that warning in NumPy as well, it might be indicative of numerical stability issues in your gaussian_rbf function.
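
If that is the case, one common workaround is to symmetrize the covariance and add a tiny diagonal "jitter" before sampling: the ||x1||² + ||x2||² - 2 x1·x2 expansion can leave tiny negative distances for identical points, so cov ends up only approximately symmetric positive-semidefinite. A minimal sketch (the module name and jitter value are just illustrative, not part of your code):

defmodule CovarianceFix do
  # Average cov with its transpose to remove round-off asymmetry, then add a
  # small positive value to the diagonal so that the Cholesky factorization
  # used when sampling succeeds.
  def stabilize(cov, jitter \\ 1.0e-10) do
    n = Nx.axis_size(cov, 0)

    cov
    |> Nx.add(Nx.transpose(cov))
    |> Nx.divide(2)
    |> Nx.add(Nx.multiply(jitter, Nx.eye(n, type: Nx.type(cov))))
  end
end

You would then pass CovarianceFix.stabilize(cov) instead of cov to Nx.Random.multivariate_normal.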

Thank you! This is my full implementation; please try to run it :slight_smile:

I found that while using Nx.linspace(0, 1, n: n, type: {:f, 64}), I can run sample_gp(n) with n up to 16; passing 17 causes the error. Increasing the range to 0..2 raises the maximum n to 22. I’m on an Apple M3 Max, so maybe the issue is platform related?

I forget what I modified, but NumPy doesn’t warn anymore; maybe that was due to printing precision.

If you still cannot reproduce the problem, I’ll try to raise an issue on GitHub :smile:

defmodule GaussianProcess do
  import Nx.Defn

  defn gaussian_rbf(x1, x2, l \\ 1.0, sigma_f \\ 1.0) do
    dist_matrix =
      Nx.sum(Nx.pow(x1, 2), axes: [1], keep_axes: true) +
        Nx.sum(Nx.pow(x2, 2), axes: [1]) -
        2 * Nx.dot(x1, Nx.transpose(x2))

    Nx.pow(sigma_f, 2) * Nx.exp(-1.0 / (2 * Nx.pow(l, 2)) * dist_matrix)
  end

  def sample_gp(n \\ 20) do
    x =
      Nx.linspace(0, 1, n: n, type: {:f, 64})
      |> Nx.reshape({:auto, 1})

    mu = Nx.broadcast(0.0, {n, 1})
    cov = gaussian_rbf(x, x)

    IO.inspect(cov, limit: :infinity)

    {multivariate_normal, _new_key} =
      Nx.Random.key(59)
      |> Nx.Random.multivariate_normal(Nx.flatten(mu), cov, shape: {50})

    {mu, cov, x, multivariate_normal}
  end

  def plot({mu, cov, x, samples}) do
    x = Nx.to_flat_list(x)
    mu = Nx.to_flat_list(mu)
    samples = Nx.to_list(samples)

    uncertainty =
      Nx.multiply(1.96, Nx.sqrt(Nx.take_diagonal(cov)))
      |> Nx.to_flat_list()

    # Prepare data for plotting
    data =
      Enum.with_index(x)
      |> Enum.map(fn {x, i} ->
        %{
          x: x,
          mu: Enum.at(mu, i),
          upper: Enum.at(mu, i) + Enum.at(uncertainty, i),
          lower: Enum.at(mu, i) - Enum.at(uncertainty, i),
          samples: Enum.map(samples, fn sample -> Enum.at(sample, i) end)
        }
      end)

    # Plot using VegaLite
    VegaLite.new(width: 400, height: 300)
    |> VegaLite.data_from_values(data)
    |> VegaLite.layers(
      [
        VegaLite.new()
        |> VegaLite.mark(:line, %{point: true})
        |> VegaLite.encode_field(:x, "x", type: :quantitative)
        |> VegaLite.encode_field(:y, "mu", type: :quantitative, axis: %{title: "Mean"}),
        VegaLite.new()
        |> VegaLite.mark(:area, %{opacity: 0.3})
        |> VegaLite.encode_field(:x, "x", type: :quantitative)
        |> VegaLite.encode_field(:y, "upper", type: :quantitative)
        |> VegaLite.encode_field(:y2, "lower", type: :quantitative)
      ] ++
        Enum.map(Enum.with_index(samples), fn {_sample, i} ->
          VegaLite.new()
          |> VegaLite.mark(:line, %{strokeDash: [5, 5]})
          |> VegaLite.encode_field(:x, "x", type: :quantitative)
          |> VegaLite.encode_field(:y, "samples[#{i}]",
            type: :quantitative
            # axis: %{title: "Samples"}
          )
        end)
    )

    # |> VegaLite.to_spec()
  end
end

GaussianProcess.sample_gp(30)
|> GaussianProcess.plot()

Nx.Random.multivariate_normal defaults to the Nx.LinAlg.cholesky decomposition, while NumPy uses SVD. Everything works after changing the method.
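
For reference, a sketch of the change (assuming the option is spelled method: :svd, as documented for Nx.Random.multivariate_normal):

{multivariate_normal, _new_key} =
  Nx.Random.key(59)
  |> Nx.Random.multivariate_normal(Nx.flatten(mu), cov, shape: {50}, method: :svd)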
