I came up with this code; hopefully I read the original correctly:
iex(3)> key = Nx.Random.key(System.system_time())
#Nx.Tensor<
u32[2]
[392370618, 318517384]
>
iex(4)> {data, key} = Nx.Random.randint(key, 0, 20, shape: {10})
{#Nx.Tensor<
s64[10]
[8, 0, 10, 14, 8, 19, 9, 9, 13, 12]
>,
#Nx.Tensor<
u32[2]
[2734010943, 1393708537]
>}
iex(5)> defmodule Mod do
...(5)> import Nx.Defn
...(5)>
...(5)> # Draws `batch_size` windows of `block_size` consecutive elements from the
...(5)> # 1-D tensor `data` (x), together with the same windows shifted one
...(5)> # position to the right (y). Returns `{x, y, key}`, where `key` is the
...(5)> # key handed back by `Nx.Random.randint`, so the caller can keep sampling.
...(5)> #
...(5)> # Options: `:batch_size` (default 4) and `:block_size` (default 8).
...(5)> defn get_batch(data, key, opts \\ []) do
...(5)> # Fill in the defaults for any option the caller left out.
...(5)> opts = keyword!(opts, batch_size: 4, block_size: 8)
...(5)>
...(5)> block_size = opts[:block_size]
...(5)> batch_size = opts[:batch_size]
...(5)>
...(5)> # The single-element pattern only matches a rank-1 shape; `n` is its length.
...(5)> {n} = Nx.shape(data)
...(5)>
...(5)> # One random start offset per batch row, shaped {batch_size, 1}.
...(5)> # NOTE(review): assumes the upper bound of randint is exclusive, so the
...(5)> # largest offset is n - block_size - 1 -- confirm against the Nx docs.
...(5)> {ix, key} = Nx.Random.randint(key, 0, n - block_size, shape: {batch_size, 1})
...(5)>
...(5)> # {batch_size, 1} offsets plus a {1, block_size} iota gives one row of
...(5)> # consecutive indices per offset: ix, ix + 1, ..., ix + block_size - 1.
...(5)> x_indices = ix + Nx.iota({1, block_size})
...(5)>
...(5)> # y reads the element immediately after each element read by x.
...(5)> x = Nx.take(data, x_indices)
...(5)> y = Nx.take(data, x_indices + 1)
...(5)>
...(5)> {x, y, key}
...(5)> end
...(5)> end
{:module, Mod,
<<70, 79, 82, 49, 0, 0, 13, 68, 66, 69, 65, 77, 65, 116, 85, 56, 0, 0, 1, 145,
0, 0, 0, 38, 10, 69, 108, 105, 120, 105, 114, 46, 77, 111, 100, 8, 95, 95,
105, 110, 102, 111, 95, 95, 10, 97, 116, ...>>, true}
iex(6)> {x, y, key} = Mod.get_batch(data, key, batch_size: 6, block_size: 3)
{#Nx.Tensor<
s64[6][3]
[
[10, 14, 8],
[8, 19, 9],
[14, 8, 19],
[8, 19, 9],
[14, 8, 19],
[10, 14, 8]
]
>,
#Nx.Tensor<
s64[6][3]
[
[14, 8, 19],
[19, 9, 9],
[8, 19, 9],
[19, 9, 9],
[8, 19, 9],
[14, 8, 19]
]
>,
#Nx.Tensor<
u32[2]
[3992323708, 1530313987]
>}