Axon question: LSTM with Dropout

I would like to do something like the following:

input = Axon.input("input", shape: {prediction_days, 1})

model =
  input
  |> Axon.lstm(50)
  |> Axon.dropout(rate: 0.2)
  |> Axon.lstm(50)
  |> Axon.dropout(rate: 0.2)
  |> Axon.lstm(50)
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(1)

But when I try this out in Livebook, I get the following error:

** (FunctionClauseError) no function clause matching in Axon.dropout/2    
    
    The following arguments were given to Axon.dropout/2:
    
        # 1
        {#Axon<
           inputs: %{"input" => {60, 1}}
           outputs: "lstm_1_output_sequence"
           nodes: 6
         >,
         {#Axon<
            inputs: %{"input" => {60, 1}}
            outputs: "lstm_1_c_hidden_state"
            nodes: 6
          >,
          #Axon<
            inputs: %{"input" => {60, 1}}
            outputs: "lstm_1_h_hidden_state"
            nodes: 6
          >}}
    
        # 2
        [rate: 0.2]
    
    Attempted function clauses (showing 1 out of 1):
    
        def dropout(%Axon{} = x, opts)

The Python code I am trying to replicate looks like this:

model = Sequential()

model.add(LSTM(units=50, return_sequences=True, input_shape=[x_train.shape[1], 1]))
model.add(Dropout(0.2))
model.add(LSTM(units=50, return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(units=50))
model.add(Dropout(0.2))
model.add(Dense(units=1))

model.compile(optimizer='adam', loss='mean_squared_error')
model.fit(x_train, y_train, epochs=25, batch_size=32)

Any help? Thanks in advance!

Axon’s LSTM (and all other RNNs) returns a tuple of {seq, state}. If you want to apply dropout, you need to extract the seq element from the tuple and apply dropout to that.
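For example, something along these lines should get past the error (a rough sketch against Axon 0.3; prediction_days and train_data are assumed to be defined elsewhere, as in your snippets):

input = Axon.input("input", shape: {prediction_days, 1})

# Axon.lstm/2 returns {output_sequence, {c_state, h_state}},
# so pattern match and pipe only the sequence onward
{seq, _state} = Axon.lstm(input, 50)
{seq, _state} = seq |> Axon.dropout(rate: 0.2) |> Axon.lstm(50)
{seq, _state} = seq |> Axon.dropout(rate: 0.2) |> Axon.lstm(50)

model =
  seq
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(1)

# rough equivalent of model.compile/model.fit; train_data is
# assumed to be an enumerable of {x, y} batches
model
|> Axon.Loop.trainer(:mean_squared_error, :adam)
|> Axon.Loop.run(train_data, %{}, epochs: 25)

One difference from your Keras model: the last Keras LSTM drops return_sequences, so only the final timestep reaches the Dense layer, while this sketch keeps the full sequence. You may want to slice out the last timestep before Axon.dense/2 to match the Python output exactly.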

I would advise against this, though: applying normal dropout to RNN outputs is usually discouraged because it can hinder learning. You’d need recurrent dropout to achieve the regularization effect you’re looking for, but Axon does not currently support it.

Thanks for the info, @seanmor5. I am somewhat new to neural networks, so I was basically just copying my model from a tutorial I found on YouTube that used TensorFlow and scikit-learn. Next time, I will try to do more research before asking a question on a topic I am not familiar with. Thanks again!

No problem, and don’t feel bad for having questions! It’s how we all learn.

There are some RNN examples in the Axon repo if you want to check those out for some additional help!

Sorry, but I have another quick question. I am trying to reshape a tensor in Nx, but the EXLA backend won’t let me for some reason. Here is the code:

{shape_1, shape_2} = Nx.shape(x_tensor)

x_tensor = Nx.reshape(x_tensor, {shape_1, shape_2, 1})

And this is the error I see:

** (MatchError) no match of right hand side value: #Function<0.58063470/0 in Nx.LazyContainer.Nx.Tensor.traverse/3>
    (exla 0.3.0) lib/exla/defn/buffers.ex:58: anonymous fn/2 in EXLA.Defn.Buffers.from_nx!/1
    (elixir 1.14.0) lib/enum.ex:2468: Enum."-reduce/3-lists^foldl/2-0-"/3
    (exla 0.3.0) lib/exla/defn/buffers.ex:57: EXLA.Defn.Buffers.from_nx!/1
    (exla 0.3.0) lib/exla/defn.ex:303: EXLA.Defn.maybe_outfeed/7
    (stdlib 4.0.1) timer.erl:235: :timer.tc/1
    (exla 0.3.0) lib/exla/defn.ex:264: anonymous fn/7 in EXLA.Defn.__compile__/4
    (nx 0.3.0) lib/nx/defn.ex:432: Nx.Defn.do_jit_apply/3

Never mind, I updated EXLA to the GitHub version (instead of the hex.pm release) and the error went away on its own. Cheers!
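In case it helps anyone else, the dependency change in mix.exs was along these lines (the exact git options are just how I happened to point at the repo; EXLA lives in the elixir-nx/nx monorepo, so I believe a sparse checkout is needed):

# in mix.exs deps — pull Nx and EXLA from GitHub instead of hex.pm
{:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
{:exla, github: "elixir-nx/nx", sparse: "exla", override: true}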