I’m working my way through Machine Learning in Elixir, and chapter 8 has a section on fine-tuning a vision model loaded from ONNX.
In particular, there’s this snippet:
model = model |> Axon.unfreeze(up: 50)
Axon.unfreeze/2 is now deprecated in favor of the new Axon.ModelState.unfreeze/2.
I’m struggling to update this part of the code, though I understand the idea behind it.
I found a more recent fine-tuning example in "Elixir Machine Learning: Training Models in Axon is Getting Better":
model_state =
  Axon.ModelState.freeze(model_state, fn
    ["sequence_classification_head.output", _] -> false
    _ -> true
  end)
However, the layers in the model I loaded from ONNX don’t have names as descriptive as "sequence_classification_head.output", and there are 130+ nodes, so I’m not sure how to target only the top N of them.
I suppose I could write a function that traverses the model nodes and works out which ones to freeze, basically re-implementing Axon.unfreeze(up: 50), but that doesn’t seem like the Axon way; otherwise it probably wouldn’t have been deprecated.
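For reference, this is roughly what that re-implementation might look like. It is only a sketch under two assumptions: that I already have the layer names in input-to-output order (for example, copied from the model summary), and that the mask passed to Axon.ModelState.freeze/2 receives a [layer_name, param_name] path, as in the example above. The module FineTune and the function freeze_all_but_last/2 are hypothetical names of my own.

```elixir
defmodule FineTune do
  @moduledoc """
  Hypothetical helper approximating the old `Axon.unfreeze(model, up: n)`.
  It builds a mask for `Axon.ModelState.freeze/2` that freezes every layer
  except the last `n` entries of an ordered list of layer names.
  """

  # `ordered_names` is ASSUMED to list layer names from input to output.
  def freeze_all_but_last(ordered_names, n)
      when is_list(ordered_names) and is_integer(n) and n >= 0 do
    # Enum.take/2 with a negative count keeps the last `n` elements.
    trainable =
      ordered_names
      |> Enum.take(-n)
      |> MapSet.new()

    # ASSUMPTION: the mask receives the access path [layer_name, param_name]
    # and must return true when that parameter should be frozen.
    fn [layer_name | _rest] -> not MapSet.member?(trainable, layer_name) end
  end
end
```

It would then be applied as model_state = Axon.ModelState.freeze(model_state, FineTune.freeze_all_but_last(layer_names, 50)). Since it only counts the names in the list, it may not select exactly the same layers as up: 50 if the list does not contain every node in the graph.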
For the sake of learning, I’m sticking with the deprecated function for now, but can anyone suggest the recommended way to do this in the latest Axon version?