Let's build GPT with Nx: how do I translate this code into Elixir?

Here is my translation; hopefully I read the original code correctly:

iex(3)> key = Nx.Random.key(System.system_time())  # PRNG key seeded from the current system time
#Nx.Tensor<
  u32[2]
  [392370618, 318517384]
>
iex(4)> {data, key} = Nx.Random.randint(key, 0, 20, shape: {10})  # 10 ints in [0, 20); returns {tensor, next_key}
{#Nx.Tensor<
   s64[10]
   [8, 0, 10, 14, 8, 19, 9, 9, 13, 12]
 >,
 #Nx.Tensor<
   u32[2]
   [2734010943, 1393708537]
 >}
iex(5)> defmodule Mod do                                        
...(5)>   import Nx.Defn
...(5)> 
...(5)>   # Samples a batch of random contiguous windows from the rank-1 tensor `data`.
...(5)>   #
...(5)>   # Options (validated by keyword!/2):
...(5)>   #   :batch_size - number of windows to sample (default 4)
...(5)>   #   :block_size - length of each window (default 8)
...(5)>   #
...(5)>   # Returns {x, y, key}:
...(5)>   #   x   - {batch_size, block_size} windows taken from `data`
...(5)>   #   y   - the same windows shifted one element ahead (the element that
...(5)>   #         follows each position in x)
...(5)>   #   key - the advanced PRNG key, returned so the caller can thread it
...(5)>   #         into the next call instead of reusing the old one
...(5)>   defn get_batch(data, key, opts \\ []) do
...(5)>     opts = keyword!(opts, batch_size: 4, block_size: 8)
...(5)> 
...(5)>     block_size = opts[:block_size]
...(5)>     batch_size = opts[:batch_size]
...(5)>     
...(5)>     # Pattern match enforces that `data` is a 1-D tensor.
...(5)>     {n} = Nx.shape(data)
...(5)> 
...(5)>     # One start offset per batch row, drawn from [0, n - block_size) (max is
...(5)>     # exclusive), so the last y index (ix + block_size) is at most n - 1.
...(5)>     # NOTE(review): assumes block_size < n; otherwise the range is empty.
...(5)>     {ix, key} = Nx.Random.randint(key, 0, n - block_size, shape: {batch_size, 1})
...(5)> 
...(5)>     # {batch_size, 1} + {1, block_size} broadcasts to {batch_size, block_size};
...(5)>     # row i holds ix_i, ix_i + 1, ..., ix_i + block_size - 1.
...(5)>     x_indices = ix + Nx.iota({1, block_size})
...(5)>     
...(5)>     # Gather the windows; y uses the same indices shifted by one.
...(5)>     x = Nx.take(data, x_indices)
...(5)>     y = Nx.take(data, x_indices + 1)
...(5)>     
...(5)>     {x, y, key}
...(5)>   end
...(5)> end
{:module, Mod,
 <<70, 79, 82, 49, 0, 0, 13, 68, 66, 69, 65, 77, 65, 116, 85, 56, 0, 0, 1, 145,
   0, 0, 0, 38, 10, 69, 108, 105, 120, 105, 114, 46, 77, 111, 100, 8, 95, 95,
   105, 110, 102, 111, 95, 95, 10, 97, 116, ...>>, true}
iex(6)> {x, y, key} = Mod.get_batch(data, key, batch_size: 6, block_size: 3)  # each y row is its x row shifted one element ahead
{#Nx.Tensor<
   s64[6][3]
   [
     [10, 14, 8],
     [8, 19, 9],
     [14, 8, 19],
     [8, 19, 9],
     [14, 8, 19],
     [10, 14, 8]
   ]
 >,
 #Nx.Tensor<
   s64[6][3]
   [
     [14, 8, 19],
     [19, 9, 9],
     [8, 19, 9],
     [19, 9, 9],
     [8, 19, 9],
     [14, 8, 19]
   ]
 >,
 #Nx.Tensor<
   u32[2]
   [3992323708, 1530313987]
 >}
2 Likes