DQN RL with Axon/Nx

I’m experimenting with DQN algorithms and wanted to use Axon/Nx to test some ideas. However, the Loop API does not seem to be a good fit for such reinforcement learning algorithms, since they need to use the current model parameters for prediction after each batch (prediction and training are interleaved).

Is there a better approach than coding my own training API for now?

Can you elaborate on the limitations you’re running into?

Okay, I did some research and I think I see where you are running into problems, but there are a couple of ways you can overcome these limitations. As a caveat, I haven’t done much RL.

As I understand it, you maintain an experience replay buffer which contains state transition information. At each step, you select an action based on the current observation of the environment:

```elixir
defn select_action(model, params, step, observation) do
  random = Nx.random_uniform({})
  # Epsilon greedy: eps decays exponentially from @eps_start toward @eps_end
  eps = @eps_end + (@eps_start - @eps_end) * Nx.exp(-1.0 * step / @eps_decay)

  if random > eps do
    # Take model action (exploit): index of the highest predicted Q-value
    model
    |> Axon.predict(params, observation, mode: :inference)
    |> Nx.argmax()
    |> Nx.reshape({1, 1})
  else
    # Take random action (explore); {:s, 64} matches the type Nx.argmax returns
    Nx.random_uniform({1, 1}, 0, @n_actions, type: {:s, 64})
  end
end
```
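To see how the epsilon-greedy schedule behaves, here is a plain-Elixir sketch without Nx. It assumes the usual exponential decay `eps = eps_end + (eps_start - eps_end) * exp(-step / eps_decay)`, so epsilon starts at `eps_start` and shrinks toward `eps_end`, and it picks the action as the index of the largest Q-value. The module name and constant values are hypothetical; tune them for your environment.

```elixir
defmodule EpsilonGreedy do
  # Hypothetical constants for illustration only
  @eps_start 0.9
  @eps_end 0.05
  @eps_decay 200.0

  # Exponential decay: equals @eps_start at step 0 and approaches @eps_end
  def epsilon(step) do
    @eps_end + (@eps_start - @eps_end) * :math.exp(-1.0 * step / @eps_decay)
  end

  # q_values is a list of Q-value estimates, one per action.
  # With probability epsilon pick a random action index (explore),
  # otherwise pick the index of the largest Q-value (exploit).
  def select_action(q_values, step) do
    if :rand.uniform() > epsilon(step) do
      q_values
      |> Enum.with_index()
      |> Enum.max_by(fn {q, _idx} -> q end)
      |> elem(1)
    else
      # :rand.uniform(n) returns 1..n, so shift to a 0-based action index
      :rand.uniform(length(q_values)) - 1
    end
  end
end
```

Early on, almost every action is random; after a few multiples of `@eps_decay` steps the agent mostly exploits, exploring only about `@eps_end` of the time.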

Now that we have a function which selects actions, the next step is to define an optimization step. You can do this at a more granular level with the optimization API itself, but that is more verbose. Instead, Axon.Loop.train_step/3 (used when constructing the loop) creates a single update function which updates the model and optimizer parameters behind the scenes. Assuming you have a step state with train_state as a field:

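```elixir
# A regular def, because sampled_experiences arrives as an Elixir list of maps
def optimization_step(step_fn, train_state, sampled_experiences) do
  # Stack the sampled observations into one batched tensor
  # (assumes each experience stores its observation under :observation)
  sampled_observations = sampled_experiences |> Enum.map(& &1[:observation]) |> Nx.stack()
  expected_state_action_values = get_expected_values(sampled_experiences)
  # Train step, assuming you optimize the model to predict Q
  step_fn.(train_state, {sampled_observations, expected_state_action_values})
end
```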
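get_expected_values is left undefined in this outline. In standard DQN the target for a transition is `r + gamma * max_a' Q(s', a')`, or just `r` when the episode ended at that transition. A plain-Elixir sketch of that computation, assuming (hypothetically) that each sampled experience carries :reward, :done, and :next_q_values, the latter being the network's Q estimates for the next state:

```elixir
defmodule DQNTargets do
  # Hypothetical discount factor
  @gamma 0.99

  # Each transition is a map with :reward, :done and :next_q_values.
  # Target: reward + gamma * max over next-state Q-values,
  # or just the reward if the episode terminated.
  def expected_values(transitions) do
    Enum.map(transitions, fn %{reward: r, done: done, next_q_values: next_q} ->
      if done, do: r, else: r + @gamma * Enum.max(next_q)
    end)
  end
end
```

In a real implementation the next-state Q-values would come from a batched forward pass (often through a separate, periodically synced target network) rather than being stored per experience.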

Now you need to combine these into an actual step that the loop recognizes. Notice this doesn’t need to be defn, and in this case it shouldn’t be, because we need to do some work off the device:

```elixir
def step(state, obs, model, train_step) do
  # Do whatever you need to do to get an observation from the environment.
  # If the environment is a process, you can store it in the step state; here I
  # assume that `obs` (which is essentially the batch) equates to querying the
  # environment for something specific
  obs = get_observation(obs, state[:env])
  # Select an action based on the environment/observation
  action = select_action(model, state[:train_state][:params], state[:train_state][:i], obs)

  # Get the experience from taking the action
  {experience, new_env} = act(action, state[:env])

  # If we have enough experiences in the experience replay, we can sample and optimize!
  new_train_state =
    if Enum.count(state[:experiences]) > @batch_size do
      batch = Enum.take_random(state[:experiences], @batch_size)
      optimization_step(train_step, state[:train_state], batch)
    else
      state[:train_state]
    end

  # Add the experience to the replay
  new_experiences = [experience | state[:experiences]]

  # Return the updated step state with the new environment, experience replay, and train state
  %{
    train_state: new_train_state,
    experiences: new_experiences,
    env: new_env
  }
end
```
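The step above grows the experience list without bound. DQN implementations commonly cap the replay at a fixed capacity and evict the oldest experiences, which bounds memory and lets stale transitions age out. A minimal sketch of such a buffer in plain Elixir, using Erlang's :queue for FIFO eviction; the ReplayBuffer module and its functions are hypothetical:

```elixir
defmodule ReplayBuffer do
  # Bounded FIFO experience replay: once `capacity` is reached,
  # each push drops the oldest experience.
  defstruct queue: :queue.new(), size: 0, capacity: 10_000

  def new(capacity) when is_integer(capacity) and capacity > 0 do
    %__MODULE__{capacity: capacity}
  end

  # Room left: append and grow
  def push(%__MODULE__{size: size, capacity: cap} = buf, experience) when size < cap do
    %{buf | queue: :queue.in(experience, buf.queue), size: size + 1}
  end

  # Full: drop the oldest experience, then append the new one
  def push(%__MODULE__{} = buf, experience) do
    {_dropped, rest} = :queue.out(buf.queue)
    %{buf | queue: :queue.in(experience, rest)}
  end

  # Uniformly samples up to `n` distinct experiences
  def sample(%__MODULE__{queue: q}, n), do: q |> :queue.to_list() |> Enum.take_random(n)
end
```

In step, ReplayBuffer.push/2 would take the place of prepending to the list, buf.size the place of Enum.count/1, and ReplayBuffer.sample/2 the place of Enum.take_random/2. Sampling still walks the whole buffer, so a tuple- or array-backed ring buffer is worth considering for large capacities.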

We also need an initialization function:

```elixir
def init(train_init, compiler) do
  fn ->
    train_state = Nx.Defn.jit(train_init, [], compiler: compiler)
    env = init_env()

    %{
      train_state: train_state,
      experiences: [],
      env: env
    }
  end
end
```

And now we can construct the loop:

```elixir
# Build the Axon model
model = build_model()

# Create the train step: Huber loss, RMSProp optimizer
{train_init, train_step} = Axon.Loop.train_step(model, :huber, :rmsprop)

# Build the actual loop from the init and step functions defined above. Assuming we're
# interacting with an environment that is a process or something similar, we just pass
# dummy data whose length equates to the max steps per episode; if you have something
# more meaningful, that works too
train_init
|> init(EXLA)
|> Axon.Loop.step(&step(&1, &2, model, train_step))
|> Axon.Loop.run(1..@max_steps, epochs: @max_epochs, compiler: EXLA)
```
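Conceptually, running the loop over 1..@max_steps threads the step state through the step function once per element, much like Enum.reduce/3, which is what lets action selection use the freshly updated parameters at every step. The toy sketch below drops Axon entirely so only that control flow is visible; the environment, action selection, and "optimizer" are hypothetical stubs, and the parameters are just an integer counter:

```elixir
defmodule ToyLoop do
  # Hypothetical stand-ins for the real environment, policy, and optimizer
  defp select_action(_params, _obs), do: Enum.random([0, 1])
  defp act(action, env), do: {%{action: action, reward: 1.0}, env + 1}
  defp optimize(params, batch), do: params + length(batch)

  @batch_size 4

  # One interleaved step: act with the current params, train once the
  # replay holds more than @batch_size experiences, then store the new one.
  def step(state, _step_index) do
    action = select_action(state.params, state.env)
    {experience, env} = act(action, state.env)

    params =
      if length(state.experiences) > @batch_size do
        optimize(state.params, Enum.take_random(state.experiences, @batch_size))
      else
        state.params
      end

    %{state | env: env, params: params, experiences: [experience | state.experiences]}
  end

  # Thread the state through `step` once per element of 1..max_steps
  def run(max_steps) do
    init = %{env: 0, params: 0, experiences: []}
    Enum.reduce(1..max_steps, init, fn i, state -> step(state, i) end)
  end
end
```

With 10 steps and a batch size of 4, training kicks in from the sixth step onward, so ToyLoop.run(10) performs five optimization calls.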

This is a very rough, off-the-top-of-my-head outline of how I would solve the problem; it’s hard for me to come up with a concrete solution without working through it myself. There are other ways, too: for example, PyTorch Ignite uses event handlers in a more sophisticated way to update the model. I will put RL examples on my backlog of things to add to the repository. I am currently reworking one of my older libraries, which can interact with the Arcade Learning Environment from Elixir, to add some easier-to-work-with RL examples.

I hope this helps, let me know if you have any questions!


Wow, that’s awesome! Thanks for such a detailed answer, @seanmor5!

I’ll start from there and keep you posted. Thanks again!