Let's build GPT with Nx - how to translate this code into Elixir?

I’m studying “Let’s build GPT: from scratch, in code, spelled out” and following along in Elixir.

I have no idea how to translate the code below into Elixir.

batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])  # <-- especially this line
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

code: Google Colab

What is the best translation for that?

2 Likes

Assuming your data is an Elixir list/enumerable, you need to work with it outside of defn. Then you stack it and call defn. In other words, you should write this code in “pure Elixir”. 🙂
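
For example, a minimal sketch assuming data is a plain list of token ids (the names and sizes here are illustrative):

batch_size = 4
block_size = 8

# pick batch_size random starting offsets in pure Elixir
ix = for _ <- 1..batch_size, do: Enum.random(0..(length(data) - block_size - 1))

# slice out the windows, turn them into tensors, and stack them
x = Nx.stack(for i <- ix, do: Nx.tensor(Enum.slice(data, i, block_size)))
y = Nx.stack(for i <- ix, do: Nx.tensor(Enum.slice(data, i + 1, block_size)))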

2 Likes

Thanks for the reply!

I finally got it working as shown below.

I tried to use Nx.Random.choice/3, but couldn’t find a way to make it work.

key = Nx.Random.key(1337)

# how many independent sequences will we process in parallel?
batch_size = 4
# what is the maximum context length for predictions?
block_size = 8

defmodule Trainer do
  @key key

  def get_batch(%Nx.Tensor{} = tensor, batch_size, block_size) do
    # sample batch_size random starting offsets in [0, size - block_size)
    {ix, _} = Nx.Random.randint(@key, 0, Nx.size(tensor) - block_size, shape: {batch_size})

    {x, y} =
      ix
      |> Nx.to_list()
      |> Enum.reduce({[], []}, fn i, {x, y} ->
        {
          [tensor[i..(i + block_size - 1)] | x],
          [tensor[(i + 1)..(i + block_size)] | y]
        }
      end)

    {Nx.Batch.stack(x), Nx.Batch.stack(y)}
  end
end

{xb, yb} = Trainer.get_batch(train_tensor, batch_size, block_size)

Nx.Defn.jit_apply(&Function.identity/1, [xb]) |> IO.inspect(label: "inputs")
Nx.Defn.jit_apply(&Function.identity/1, [yb]) |> IO.inspect(label: "targets")
1 Like

You should avoid setting module attributes to tensors, because that will only work when using the BinaryBackend as your default backend.
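
For instance, here is a minimal rework of the get_batch above that threads the key through as an argument and returns the updated key (the exact signature is illustrative):

defmodule Trainer do
  def get_batch(%Nx.Tensor{} = tensor, key, batch_size, block_size) do
    # the key is passed in and returned instead of baked in at compile time
    {ix, key} = Nx.Random.randint(key, 0, Nx.size(tensor) - block_size, shape: {batch_size})

    {x, y} =
      ix
      |> Nx.to_list()
      |> Enum.reduce({[], []}, fn i, {x, y} ->
        {
          [tensor[i..(i + block_size - 1)] | x],
          [tensor[(i + 1)..(i + block_size)] | y]
        }
      end)

    {Nx.Batch.stack(x), Nx.Batch.stack(y), key}
  end
end

# the caller re-threads the key:
# {xb, yb, key} = Trainer.get_batch(train_tensor, key, batch_size, block_size)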

2 Likes

Have you tried looking into Nx.take and Nx.slice_along_axis to build your batches directly using Nx?
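
For reference, Nx.take gathers along an axis (axis 0 by default), and a 2-D index tensor yields a 2-D result, which is exactly the batch shape needed here (values are illustrative):

data = Nx.tensor([10, 11, 12, 13, 14])
idx = Nx.tensor([[0, 1, 2], [2, 3, 4]])
Nx.take(data, idx)
#=> [[10, 11, 12],
#    [12, 13, 14]]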

1 Like

If you are not inside defn, I would use regular Enum.random for the randomness. And you can use Nx.stack (Nx.Batch will just call Nx.stack anyway).

Yeah, if the input is a tensor, then using defn is better, but I am not sure how to do it.
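
For the non-defn route, something like this sketch, assuming train_tensor is a 1-D tensor (names are illustrative):

n = Nx.size(train_tensor)

# plain Elixir randomness, so there is no key to thread through
ix = for _ <- 1..batch_size, do: Enum.random(0..(n - block_size - 1))

x = Nx.stack(for i <- ix, do: train_tensor[i..(i + block_size - 1)])
y = Nx.stack(for i <- ix, do: train_tensor[(i + 1)..(i + block_size)])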

1 Like

I got this code, hopefully I read the original one correctly:

iex(3)> key = Nx.Random.key(System.system_time())
#Nx.Tensor<
  u32[2]
  [392370618, 318517384]
>
iex(4)> {data, key} = Nx.Random.randint(key, 0, 20, shape: {10})
{#Nx.Tensor<
   s64[10]
   [8, 0, 10, 14, 8, 19, 9, 9, 13, 12]
 >,
 #Nx.Tensor<
   u32[2]
   [2734010943, 1393708537]
 >}
iex(5)> defmodule Mod do                                        
...(5)>   import Nx.Defn
...(5)> 
...(5)>   defn get_batch(data, key, opts \\ []) do
...(5)>     opts = keyword!(opts, batch_size: 4, block_size: 8)
...(5)> 
...(5)>     block_size = opts[:block_size]
...(5)>     batch_size = opts[:batch_size]
...(5)>     
...(5)>     {n} = Nx.shape(data)
...(5)> 
...(5)>     {ix, key} = Nx.Random.randint(key, 0, n - block_size, shape: {batch_size, 1})
...(5)> 
...(5)>     x_indices = ix + Nx.iota({1, block_size})
...(5)>     
...(5)>     x = Nx.take(data, x_indices)
...(5)>     y = Nx.take(data, x_indices + 1)
...(5)>     
...(5)>     {x, y, key}
...(5)>   end
...(5)> end
{:module, Mod,
 <<70, 79, 82, 49, 0, 0, 13, 68, 66, 69, 65, 77, 65, 116, 85, 56, 0, 0, 1, 145,
   0, 0, 0, 38, 10, 69, 108, 105, 120, 105, 114, 46, 77, 111, 100, 8, 95, 95,
   105, 110, 102, 111, 95, 95, 10, 97, 116, ...>>, true}
iex(6)> {x, y, key} = Mod.get_batch(data, key, batch_size: 6, block_size: 3)
{#Nx.Tensor<
   s64[6][3]
   [
     [10, 14, 8],
     [8, 19, 9],
     [14, 8, 19],
     [8, 19, 9],
     [14, 8, 19],
     [10, 14, 8]
   ]
 >,
 #Nx.Tensor<
   s64[6][3]
   [
     [14, 8, 19],
     [19, 9, 9],
     [8, 19, 9],
     [19, 9, 9],
     [8, 19, 9],
     [14, 8, 19]
   ]
 >,
 #Nx.Tensor<
   u32[2]
   [3992323708, 1530313987]
 >}
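
For anyone following along: the x_indices = ix + Nx.iota({1, block_size}) line leans on broadcasting. Adding a {batch_size, 1} tensor to a {1, block_size} iota produces the full {batch_size, block_size} grid of indices. A small illustration:

Nx.add(Nx.tensor([[3], [7]]), Nx.iota({1, 4}))
#=> [[3, 4, 5, 6],
#    [7, 8, 9, 10]]
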
2 Likes

Nice! It may be something we want to include in Nx, but most of the time batches are not tensors, so I am not sure!

2 Likes

This, together with some things I tried to do in the last few days, got me thinking: if we had Nx.iota with strides (both Jax and PyTorch seem to accept a step in arange), we could get a more efficient version of NxSignal.as_windowed… something to investigate in the next few days!

This snippet seems like something we could have on a cheatsheet. I think we need a Python -> Nx translation cheatsheet (or one for each major lib).

3 Likes

We can support ranges in iota for sure, in addition to the shape. I think that would be a great feature.

3 Likes

Exactly, it would be the equivalent of a stepped range in Elixir: 0..n//s would be equivalent to Nx.iota({n + 1}, strides: [s]), but we would support n dimensions, of course.

4 Likes

Ok, we’ve added support for stepped ranges, and the step can also be negative! This mimics the behavior of Python’s array slicing.
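
If I am reading the change right, that enables slicing like this (illustrative):

t = Nx.iota({10})
t[0..9//2]   # every second element: [0, 2, 4, 6, 8]
t[9..0//-1]  # negative step reverses, as in Python: [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]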

7 Likes