# Nx.Random.multivariate_normal failed

Hi! I’m trying to generate some gaussian distributed variables, but encountered an error:

** (ArithmeticError) bad argument in arithmetic expression
(stdlib 4.3.1.3) :math.sqrt(-4.303224443447107e-13)
(complex 0.5.0) lib/complex.ex:772: Complex.sqrt/1
(nx 0.7.1) lib/nx/binary_backend.ex:2551: Nx.BinaryBackend.“-binary_to_binary/4-lbc\$^6/2-11-”/5
(nx 0.7.1) lib/nx/binary_backend.ex:933: Nx.BinaryBackend.element_wise_unary_op/3
(nx 0.7.1) lib/nx/defn/evaluator.ex:441: Nx.Defn.Evaluator.eval_apply/4
(nx 0.7.1) lib/nx/defn/evaluator.ex:256: Nx.Defn.Evaluator.eval/3
(nx 0.7.1) lib/nx/defn/evaluator.ex:359: Nx.Defn.Evaluator.eval_apply/4

Please see the following code:

``````  defn gaussian_rbf(x1, x2, l \\ 1.0, sigma_f \\ 1.0) do
dist_matrix =
Nx.sum(Nx.pow(x1, 2), axes: [1], keep_axes: true) +
Nx.sum(Nx.pow(x2, 2), axes: [1]) -
2 * Nx.dot(x1, Nx.transpose(x2))

Nx.pow(sigma_f, 2) * Nx.exp(-1 / (2 * Nx.pow(l, 2)) * dist_matrix)
end

def sample_gp(n \\ 20) do
x =
Nx.linspace(0, 1, n: n, type: {:f, 64})
|> Nx.reshape({:auto, 1})

mu = Nx.broadcast(0.0, {n, 1})
cov = gaussian_rbf(x, x)

IO.inspect(cov, limit: :infinity)

{multivariate_normal, _new_key} =
Nx.Random.key(59)
|> Nx.Random.multivariate_normal(Nx.flatten(mu), cov)

{mu, cov, x, multivariate_normal}
end
``````

Thank you!

Edit: the cov is here, I can use it in python (`np.random.multivariate_normal(mu.reshape(-1), cov, 10)`) but with a `RuntimeWarning: covariance is not symmetric positive-semidefinite. samples = np.random.multivariate_normal(mu.reshape(-1), cov, 10)`

The matrix is very close to the one generated by numpy with the same method, but a little bit different

nx.Tensor<
f64[20][20]
[
[1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622919, 0.8457004773919743, 0.8191844691944301, 0.7913048223601589, 0.762259565588209, 0.7322492269229477, 0.7014745656989334, 0.6701343875280227, 0.638423474554494, 0.6065306597126334],
[0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622918, 0.8457004773919744, 0.8191844691944301, 0.7913048223601589, 0.762259565588209, 0.7322492269229478, 0.7014745656989335, 0.6701343875280227, 0.638423474554494],
[0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622918, 0.8457004773919744, 0.8191844691944301, 0.7913048223601588, 0.762259565588209, 0.7322492269229478, 0.7014745656989334, 0.6701343875280227],
[0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622919, 0.8457004773919744, 0.81918446919443, 0.7913048223601589, 0.7622595655882091, 0.7322492269229478, 0.7014745656989334],
[0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622919, 0.8457004773919743, 0.8191844691944301, 0.7913048223601589, 0.762259565588209, 0.7322492269229477],
[0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528131, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622918, 0.8457004773919744, 0.8191844691944301, 0.7913048223601589, 0.762259565588209],
[0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622919, 0.8457004773919744, 0.8191844691944301, 0.7913048223601589],
[0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622919, 0.8457004773919743, 0.81918446919443],
[0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622918, 0.8457004773919743],
[0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528131, 0.9343847048354291, 0.9151725435669718, 0.8938758661318704, 0.8706596335622919],
[0.8706596335622919, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.9876119969931421, 0.9780830788850546, 0.9659665827628695, 0.9513611828528131, 0.9343847048354292, 0.9151725435669716, 0.8938758661318704],
[0.8457004773919743, 0.8706596335622918, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528131, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716],
[0.8191844691944301, 0.8457004773919744, 0.8706596335622918, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291],
[0.7913048223601589, 0.8191844691944301, 0.8457004773919744, 0.8706596335622919, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.9876119969931421, 0.9944751522138855, 0.998615917176126, 1.0, 0.9986159171761259, 0.9944751522138856, 0.9876119969931421, 0.9780830788850546, 0.9659665827628697, 0.9513611828528132],
[0.762259565588209, 0.7913048223601589, 0.8191844691944301, 0.8457004773919744, 0.8706596335622919, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.9986159171761259, 1.0, 0.998615917176126, 0.9944751522138856, 0.987611996993142, 0.9780830788850546, 0.9659665827628696],
[0.7322492269229477, 0.762259565588209, 0.7913048223601588, 0.81918446919443, 0.8457004773919743, 0.8706596335622918, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528131, 0.9659665827628695, 0.9780830788850546, 0.987611996993142, 0.9944751522138856, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138856, 0.987611996993142, 0.9780830788850546],
[0.7014745656989334, 0.7322492269229478, 0.762259565588209, 0.7913048223601589, 0.8191844691944301, 0.8457004773919744, 0.8706596335622919, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528131, 0.9659665827628696, 0.9780830788850546, 0.9876119969931421, 0.9944751522138856, 0.998615917176126, 1.0, 0.9986159171761261, 0.9944751522138856, 0.987611996993142],
[0.6701343875280227, 0.7014745656989335, 0.7322492269229478, 0.7622595655882091, 0.7913048223601589, 0.8191844691944301, 0.8457004773919744, 0.8706596335622919, 0.8938758661318704, 0.9151725435669718, 0.9343847048354292, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138856, 0.9986159171761261, 1.0, 0.9986159171761261, 0.9944751522138856],
[0.638423474554494, 0.6701343875280227, 0.7014745656989334, 0.7322492269229478, 0.762259565588209, 0.7913048223601589, 0.8191844691944301, 0.8457004773919743, 0.8706596335622918, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628697, 0.9780830788850546, 0.987611996993142, 0.9944751522138856, 0.9986159171761261, 1.0, 0.998615917176126],
[0.6065306597126334, 0.638423474554494, 0.6701343875280227, 0.7014745656989334, 0.7322492269229477, 0.762259565588209, 0.7913048223601589, 0.81918446919443, 0.8457004773919743, 0.8706596335622919, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138856, 0.998615917176126, 1.0]
]

How are you calling this code?

edit: also, what is the Nx version you’re using?

How are you calling this code?

just `GaussianProcess.sample_gp()` (both functions are in GaussianProcess module)

what is the Nx version you’re using?

{:nx, “~> 0.7.1”}

Unfortunately I wasn’t able to reproduce the problem locally. However, given that you get that warning in Numpy, it might be indicative of numerical stability issues in your gaussian_rbf function.

Thank you! This is my full implementation, please try to run it

I found that while using `Nx.linspace(0, 1, n: n, type: {:f, 64})`, I can run `sample_gp(n)` with n at most 16, giving 17 will cause the error. Increasing the size to 0…2 the maximum n will come to 22 (sry for my bad English, hope you can understand…). I’m on apple M3 Max, may be the issue is platform related?

I forgot what was modified but numpy doesn’t warn anymore, maybe that was due to printing precision.

If you still cannot reproduce the problem, I’ll try to raise an issue on GitHub

``````defmodule GaussianProcess do
import Nx.Defn

defn gaussian_rbf(x1, x2, l \\ 1.0, sigma_f \\ 1.0) do
dist_matrix =
Nx.sum(Nx.pow(x1, 2), axes: [1], keep_axes: true) +
Nx.sum(Nx.pow(x2, 2), axes: [1]) -
2 * Nx.dot(x1, Nx.transpose(x2))

Nx.pow(sigma_f, 2) * Nx.exp(-1.0 / (2 * Nx.pow(l, 2)) * dist_matrix)
end

def sample_gp(n \\ 20) do
x =
Nx.linspace(0, 1, n: n, type: {:f, 64})
|> Nx.reshape({:auto, 1})

mu = Nx.broadcast(0.0, {n, 1})
cov = gaussian_rbf(x, x)

IO.inspect(cov, limit: :infinity)

{multivariate_normal, _new_key} =
Nx.Random.key(59)
|> Nx.Random.multivariate_normal(Nx.flatten(mu), cov, shape: {50})

{mu, cov, x, multivariate_normal}
end

def plot({mu, cov, x, samples}) do
x = Nx.to_flat_list(x)
mu = Nx.to_flat_list(mu)
samples = Nx.to_list(samples)

uncertainty =
Nx.multiply(1.96, Nx.sqrt(Nx.take_diagonal(cov)))
|> Nx.to_flat_list()

# Prepare data for plotting
data =
Enum.with_index(x)
|> Enum.map(fn {x, i} ->
%{
x: x,
mu: Enum.at(mu, i),
upper: Enum.at(mu, i) + Enum.at(uncertainty, i),
lower: Enum.at(mu, i) - Enum.at(uncertainty, i),
samples: Enum.map(samples, fn sample -> Enum.at(sample, i) end)
}
end)

# Plot using VegaLite
VegaLite.new(width: 400, height: 300)
|> VegaLite.data_from_values(data)
|> VegaLite.layers(
[
VegaLite.new()
|> VegaLite.mark(:line, %{point: true})
|> VegaLite.encode_field(:x, "x", type: :quantitative)
|> VegaLite.encode_field(:y, "mu", type: :quantitative, axis: %{title: "Mean"}),
VegaLite.new()
|> VegaLite.mark(:area, %{opacity: 0.3})
|> VegaLite.encode_field(:x, "x", type: :quantitative)
|> VegaLite.encode_field(:y, "upper", type: :quantitative)
|> VegaLite.encode_field(:y2, "lower", type: :quantitative)
] ++
Enum.map(Enum.with_index(samples), fn {_sample, i} ->
VegaLite.new()
|> VegaLite.mark(:line, %{strokeDash: [5, 5]})
|> VegaLite.encode_field(:x, "x", type: :quantitative)
|> VegaLite.encode_field(:y, "samples[#{i}]",
type: :quantitative
# axis: %{title: "Samples"}
)
end)
)

# |> VegaLite.to_spec()
end
end

GaussianProcess.sample_gp(30)
|> GaussianProcess.plot()
``````

Nx defaults to Nx.LinAlg.cholesky method, and numpy is using svd. Everything works after changing the method.

2 Likes