Hi! I’m trying to generate some gaussian distributed variables, but encountered an error:
** (ArithmeticError) bad argument in arithmetic expression
(stdlib 4.3.1.3) :math.sqrt(-4.303224443447107e-13)
(complex 0.5.0) lib/complex.ex:772: Complex.sqrt/1
(nx 0.7.1) lib/nx/binary_backend.ex:2551: Nx.BinaryBackend.“-binary_to_binary/4-lbc$^6/2-11-”/5
(nx 0.7.1) lib/nx/binary_backend.ex:933: Nx.BinaryBackend.element_wise_unary_op/3
(nx 0.7.1) lib/nx/defn/evaluator.ex:441: Nx.Defn.Evaluator.eval_apply/4
(nx 0.7.1) lib/nx/defn/evaluator.ex:256: Nx.Defn.Evaluator.eval/3
(nx 0.7.1) lib/nx/defn/evaluator.ex:359: Nx.Defn.Evaluator.eval_apply/4
Please see the following code:
defn gaussian_rbf(x1, x2, l \\ 1.0, sigma_f \\ 1.0) do
dist_matrix =
Nx.sum(Nx.pow(x1, 2), axes: [1], keep_axes: true) +
Nx.sum(Nx.pow(x2, 2), axes: [1]) -
2 * Nx.dot(x1, Nx.transpose(x2))
Nx.pow(sigma_f, 2) * Nx.exp(-1 / (2 * Nx.pow(l, 2)) * dist_matrix)
end
def sample_gp(n \\ 20) do
x =
Nx.linspace(0, 1, n: n, type: {:f, 64})
|> Nx.reshape({:auto, 1})
mu = Nx.broadcast(0.0, {n, 1})
cov = gaussian_rbf(x, x)
IO.inspect(cov, limit: :infinity)
{multivariate_normal, _new_key} =
Nx.Random.key(59)
|> Nx.Random.multivariate_normal(Nx.flatten(mu), cov)
{mu, cov, x, multivariate_normal}
end
Thank you!
Edit: the cov is here, I can use it in python (np.random.multivariate_normal(mu.reshape(-1), cov, 10)
) but with a RuntimeWarning: covariance is not symmetric positive-semidefinite. samples = np.random.multivariate_normal(mu.reshape(-1), cov, 10)
The matrix is very close to the one generated by numpy with the same method, but a little bit different
nx.Tensor<
f64[20][20]
[
[1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622919, 0.8457004773919743, 0.8191844691944301, 0.7913048223601589, 0.762259565588209, 0.7322492269229477, 0.7014745656989334, 0.6701343875280227, 0.638423474554494, 0.6065306597126334],
[0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622918, 0.8457004773919744, 0.8191844691944301, 0.7913048223601589, 0.762259565588209, 0.7322492269229478, 0.7014745656989335, 0.6701343875280227, 0.638423474554494],
[0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622918, 0.8457004773919744, 0.8191844691944301, 0.7913048223601588, 0.762259565588209, 0.7322492269229478, 0.7014745656989334, 0.6701343875280227],
[0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622919, 0.8457004773919744, 0.81918446919443, 0.7913048223601589, 0.7622595655882091, 0.7322492269229478, 0.7014745656989334],
[0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622919, 0.8457004773919743, 0.8191844691944301, 0.7913048223601589, 0.762259565588209, 0.7322492269229477],
[0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528131, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622918, 0.8457004773919744, 0.8191844691944301, 0.7913048223601589, 0.762259565588209],
[0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622919, 0.8457004773919744, 0.8191844691944301, 0.7913048223601589],
[0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622919, 0.8457004773919743, 0.81918446919443],
[0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716, 0.8938758661318704, 0.8706596335622918, 0.8457004773919743],
[0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528131, 0.9343847048354291, 0.9151725435669718, 0.8938758661318704, 0.8706596335622919],
[0.8706596335622919, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.9876119969931421, 0.9780830788850546, 0.9659665827628695, 0.9513611828528131, 0.9343847048354292, 0.9151725435669716, 0.8938758661318704],
[0.8457004773919743, 0.8706596335622918, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528131, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291, 0.9151725435669716],
[0.8191844691944301, 0.8457004773919744, 0.8706596335622918, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138855, 0.987611996993142, 0.9780830788850546, 0.9659665827628696, 0.9513611828528132, 0.9343847048354291],
[0.7913048223601589, 0.8191844691944301, 0.8457004773919744, 0.8706596335622919, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.9876119969931421, 0.9944751522138855, 0.998615917176126, 1.0, 0.9986159171761259, 0.9944751522138856, 0.9876119969931421, 0.9780830788850546, 0.9659665827628697, 0.9513611828528132],
[0.762259565588209, 0.7913048223601589, 0.8191844691944301, 0.8457004773919744, 0.8706596335622919, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138855, 0.9986159171761259, 1.0, 0.998615917176126, 0.9944751522138856, 0.987611996993142, 0.9780830788850546, 0.9659665827628696],
[0.7322492269229477, 0.762259565588209, 0.7913048223601588, 0.81918446919443, 0.8457004773919743, 0.8706596335622918, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528131, 0.9659665827628695, 0.9780830788850546, 0.987611996993142, 0.9944751522138856, 0.998615917176126, 1.0, 0.998615917176126, 0.9944751522138856, 0.987611996993142, 0.9780830788850546],
[0.7014745656989334, 0.7322492269229478, 0.762259565588209, 0.7913048223601589, 0.8191844691944301, 0.8457004773919744, 0.8706596335622919, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528131, 0.9659665827628696, 0.9780830788850546, 0.9876119969931421, 0.9944751522138856, 0.998615917176126, 1.0, 0.9986159171761261, 0.9944751522138856, 0.987611996993142],
[0.6701343875280227, 0.7014745656989335, 0.7322492269229478, 0.7622595655882091, 0.7913048223601589, 0.8191844691944301, 0.8457004773919744, 0.8706596335622919, 0.8938758661318704, 0.9151725435669718, 0.9343847048354292, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138856, 0.9986159171761261, 1.0, 0.9986159171761261, 0.9944751522138856],
[0.638423474554494, 0.6701343875280227, 0.7014745656989334, 0.7322492269229478, 0.762259565588209, 0.7913048223601589, 0.8191844691944301, 0.8457004773919743, 0.8706596335622918, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628697, 0.9780830788850546, 0.987611996993142, 0.9944751522138856, 0.9986159171761261, 1.0, 0.998615917176126],
[0.6065306597126334, 0.638423474554494, 0.6701343875280227, 0.7014745656989334, 0.7322492269229477, 0.762259565588209, 0.7913048223601589, 0.81918446919443, 0.8457004773919743, 0.8706596335622919, 0.8938758661318704, 0.9151725435669716, 0.9343847048354291, 0.9513611828528132, 0.9659665827628696, 0.9780830788850546, 0.987611996993142, 0.9944751522138856, 0.998615917176126, 1.0]
]