Okay, I did some research and I think I see where you are running into problems, but there are a couple of ways you can overcome these limitations. As a caveat, I haven’t done much RL.
As I understand it, you maintain an experience replay which contains state-transition information. At each step, you select an action given an observation from the environment:
defn select_action(model, params, step, observation) do
  # Epsilon greedy: explore with probability eps, otherwise exploit
  sample = Nx.random_uniform({})
  eps = @eps_end + (@eps_start - @eps_end) * Nx.exp(-1.0 * step / @eps_decay)

  if sample > eps do
    # Take model action (exploit): index of the highest-valued Q output
    model
    |> Axon.predict(params, observation, mode: :inference)
    |> Nx.argmax()
    |> Nx.reshape({1, 1})
  else
    # Take random action (explore)
    Nx.random_uniform({1, 1}, 0, @n_actions, type: {:u, 64})
  end
end
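For reference, here is what a single entry in the experience replay might look like. This shape is an assumption on my part, not anything Axon requires; the keys are illustrative and should match whatever your sampling code expects:

```elixir
# One transition: everything needed to compute a Q-learning target later.
# `observation`, `action`, `reward`, `next_observation`, and `done?` are
# placeholders for values you get from the environment and select_action/4.
experience = %{
  state: observation,            # observation before acting
  action: action,                # {1, 1} tensor from select_action/4
  reward: reward,                # scalar reward from the environment
  next_state: next_observation,  # observation after acting
  done: done?                    # whether the episode ended on this step
}
```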
Okay, so now we have a function which selects actions. Our next step is to define an optimization step. You can do this at a more granular level with the optimization API itself, but it’s more verbose. This will create a single update function which updates model and optimization parameters behind the scenes. Assuming you have a step state with train_state as a field:
defn optimization_step(step_fn, train_state, sampled_experiences) do
  sampled_observations = sampled_experiences[:observations]
  expected_state_action_values = get_expected_values(sampled_experiences)
  # Train step, assuming you optimize model to predict Q
  step_fn.(train_state, {sampled_observations, expected_state_action_values})
end
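I left `get_expected_values/1` undefined above. A rough sketch of the usual Q-learning target (reward plus discounted max Q-value of the next state) might look like the following; `@gamma`, the `:rewards`/`:next_observations` keys, and `target_predict/1` (however you run a target/frozen copy of the network) are all assumptions here:

```elixir
defn get_expected_values(sampled_experiences) do
  rewards = sampled_experiences[:rewards]
  next_obs = sampled_experiences[:next_observations]

  # Max Q-value over actions for each sampled next state. `target_predict/1`
  # is a placeholder for predicting with a (possibly frozen) target model.
  next_values =
    next_obs
    |> target_predict()
    |> Nx.reduce_max(axes: [-1])

  # Bellman target: r + gamma * max_a' Q(s', a')
  rewards + @gamma * next_values
end
```

You'd also want to mask out `next_values` for terminal transitions, which I've omitted for brevity.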
Now you need to combine these into an actual step that the loop recognizes. Notice this doesn’t need to be a defn, and in this case it shouldn’t be, because we need to do some things off the device:
def step(state, obs, model, train_step) do
  # Do whatever you need to do to get an observation from the environment.
  # If the environment is a process, you can store it in the step state; here I
  # assume that `obs`, which is essentially a batch, equates to querying the
  # environment for something specific
  obs = get_observation(obs, state[:env])
  # Select action based on environment/observation
  action = select_action(model, state[:train_state][:params], state[:train_state][:i], obs)
  # Get experience from action
  {experience, new_env} = act(action, state[:env])
  # If we have enough experiences in the experience replay, then we can sample and optimize!
  new_train_state =
    if Enum.count(state[:experiences]) > @batch_size do
      optimization_step(train_step, state[:train_state], Enum.take_random(state[:experiences], @batch_size))
    else
      state[:train_state]
    end
  # Add experience to replay
  new_experiences = [experience | state[:experiences]]
  # Return updated step state with new environment, experience replay, and train state
  %{
    train_state: new_train_state,
    experiences: new_experiences,
    env: new_env
  }
end
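`get_observation/2` and `act/2` are placeholders for however you actually talk to your environment. If the environment lives in a process (say, a GenServer), the stubs could be as thin as this; the message shapes (`:observe`, `{:act, n}`) are hypothetical and should match whatever your environment speaks:

```elixir
# Hypothetical environment-as-a-process interface.
def get_observation(_batch, env_pid) do
  # Ask the environment process for the current observation tensor
  GenServer.call(env_pid, :observe)
end

def act(action, env_pid) do
  # Apply the action and get back one transition. Returning the pid as the
  # "new" environment lets the step state keep referring to the same process.
  experience = GenServer.call(env_pid, {:act, Nx.to_number(action)})
  {experience, env_pid}
end
```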
We also need to initialize in some way:
def init(train_init, compiler) do
  fn ->
    train_state = Nx.Defn.jit(train_init, [], compiler: compiler)
    %{
      train_state: train_state,
      experiences: [],
      env: init_env()
    }
  end
end
And now we can construct the loop:
# Build the Axon model
model = build_model()
# Create the train step: huber loss, rmsprop optimizer
{train_init, train_step} = Axon.Loop.train_step(model, :huber, :rmsprop)
# Build the actual loop from the init and step defined above. Assuming we're
# interacting with an environment that is a process or something, we just pass
# dummy data which equates to max steps per episode; if you have something
# more meaningful, that works too
train_init
|> init(EXLA)
|> Axon.Loop.step(&step(&1, &2, model, train_step))
|> Axon.Loop.run(1..@max_steps, epochs: @max_epochs, compiler: EXLA)
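All of the `@`-attributes used above are hyperparameters you'd define yourself at the top of the module. One plausible set, just to make the sketch concrete (the values are placeholders; tune them for your problem):

```elixir
@eps_start 0.9   # initial exploration rate
@eps_end 0.05    # final exploration rate
@eps_decay 1000  # exploration decay time constant, in steps
@gamma 0.99      # discount factor used when computing expected Q values
@batch_size 128  # replay samples per optimization step
@n_actions 2     # size of the environment's action space
@max_steps 1000  # steps per episode (the dummy range above)
@max_epochs 50   # number of episodes
```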
This is a very rough, off-the-top-of-my-head outline of how I would solve the problem. It’s hard for me to come up with a concrete solution without working through it myself, and there are other ways; for example, PyTorch Ignite uses event handlers in a more sophisticated way to update the model. I will put RL examples on my backlog of things to add to the repository. I am currently reworking one of my older libraries, which can interact with the Arcade Learning Environment from Elixir, to add some easier-to-work-with RL examples.
I hope this helps, let me know if you have any questions!