How do I save a trained model in PyTorch?
To save a PyTorch model, store the model parameters using `torch.save()` as follows:
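A minimal sketch of that save, assuming a tiny `nn.Linear` stand-in for your real model and an illustrative filename:

```python
import torch
import torch.nn as nn

# Stand-in model; substitute your own architecture.
model = nn.Linear(10, 2)

# Save only the learnable parameters (the state dict).
torch.save(model.state_dict(), "model_weights.pth")
```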
To load a saved model, it's thankfully just as easy: re-create the model, then call `model.load_state_dict()`:
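A sketch of the loading side, again with an illustrative `nn.Linear` and filename (the save is repeated here so the snippet runs standalone):

```python
import torch
import torch.nn as nn

# A matching save, re-created so this snippet is self-contained.
torch.save(nn.Linear(10, 2).state_dict(), "model_weights.pth")

# Re-create the architecture first, then pull in the saved parameters.
model = nn.Linear(10, 2)
model.load_state_dict(torch.load("model_weights.pth"))
model.eval()  # only if you're loading for inference
```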
This way of storing models involves saving only the learnable parameters. This results in a compact, portable and, dare I say, quite good-looking model file.
Saving and Loading: A detailed run-through
Taking a pitstop during the training marathon
Long training sessions are just life (or should I say, strife?). To help you out, you can save checkpoints with more than just model parameters:
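One common checkpoint shape looks like this; the epoch number, loss value, and filename are illustrative:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in for your real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

checkpoint = {
    "epoch": 5,                                   # hypothetical epoch count
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": 0.42,                                 # hypothetical running loss
}
torch.save(checkpoint, "checkpoint.pth")
```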
When loading the checkpoint, initialize the model and optimizer first:
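A sketch of that loading pattern (the checkpoint is re-created at the top so the snippet runs standalone; normally it comes from an earlier run):

```python
import torch
import torch.nn as nn

# Setup only: fabricate a checkpoint so this snippet is self-contained.
_m = nn.Linear(10, 2)
_opt = torch.optim.SGD(_m.parameters(), lr=0.01)
torch.save({"epoch": 5,
            "model_state_dict": _m.state_dict(),
            "optimizer_state_dict": _opt.state_dict()},
           "checkpoint.pth")

# The actual loading pattern: build model and optimizer first.
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1  # resume from the next epoch
```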
If you're loading mid-race for inference rather than to resume training, don't forget to call `model.eval()`; skipping it leaves dropout and batch-norm layers in training mode, and that way lies batch chaos.
When more is actually more: Saving the whole model
At times, it may be expedient to save the entire model, especially when the source code goes into hide-and-seek mode during deployment:
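A sketch of a whole-model save, with an illustrative `nn.Linear` and filename:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in for your real architecture

# Pickles the entire module, including a reference to its class definition.
torch.save(model, "model_full.pth")
```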
And bring it back to life with:
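And a sketch of the resurrection (the save is repeated so the snippet runs standalone; note that newer PyTorch releases make `torch.load` refuse to unpickle arbitrary objects unless you pass `weights_only=False`):

```python
import torch
import torch.nn as nn

# Setup only: re-create the saved file so this snippet is self-contained.
torch.save(nn.Linear(10, 2), "model_full.pth")

# weights_only=False is needed on newer PyTorch, where torch.load
# defaults to refusing arbitrary pickled objects.
model = torch.load("model_full.pth", weights_only=False)
model.eval()
```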
Ensuring smooth transition between PyTorch versions
Version compatibility of PyTorch is a fickle mistress. So, do remember to jot down the PyTorch version used to save your models. Future-you will appreciate it.
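One lightweight way to do that jotting, sketched here as an assumption rather than an official convention, is to stash `torch.__version__` inside the checkpoint itself:

```python
import torch
import torch.nn as nn

# Record the library version alongside the weights so future-you
# knows exactly which PyTorch release produced this file.
checkpoint = {
    "torch_version": torch.__version__,
    "model_state_dict": nn.Linear(10, 2).state_dict(),  # stand-in model
}
torch.save(checkpoint, "versioned_checkpoint.pth")
```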
Special considerations for the horde: Distributed training
Models thriving in `DataParallel` or `DistributedDataParallel` situations? Save the state with `model.module.state_dict()`:
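A sketch with an illustrative wrapped `nn.Linear`:

```python
import torch
import torch.nn as nn

model = nn.DataParallel(nn.Linear(10, 2))  # stand-in wrapped model

# Saving model.module's state dict keeps the keys free of the
# "module." prefix that DataParallel would otherwise add.
torch.save(model.module.state_dict(), "model_weights.pth")
```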
Shape-shifting: Loading into a different architecture
Loading a state dict into a different model architecture can create hiccups. Ensure a smooth ride by defining the model class before loading the state dict, and pass `map_location` to `torch.load()` when the checkpoint and your machine disagree about devices. And remember, pray to the gods of GPU memory.
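When the architectures only partially match, one escape hatch is `load_state_dict(..., strict=False)`, which skips non-matching keys instead of raising an error. The `OldNet`/`NewNet` classes below are purely illustrative:

```python
import torch
import torch.nn as nn

class OldNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)

class NewNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)   # matching layer: weights transfer
        self.fc2 = nn.Linear(5, 2)    # new layer: stays randomly initialized

torch.save(OldNet().state_dict(), "old_weights.pth")

new_model = NewNet()
# strict=False skips mismatched keys instead of raising a RuntimeError.
missing, unexpected = new_model.load_state_dict(
    torch.load("old_weights.pth"), strict=False
)
# missing:    keys NewNet has but the checkpoint lacks (fc2.*)
# unexpected: keys the checkpoint has but NewNet lacks (none here)
```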
Extra Scoop: Advanced details and best practices
Don't shoot an arrow in the dark!
Directly saving `model.parameters()` might lead to unfortunate consequences: it returns a generator of bare tensors, with no layer names attached. Let's stick with the tried and true `state_dict()` for the sake of future generations.
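A quick look at the difference, using a throwaway `nn.Linear`:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # throwaway example model

# parameters() yields bare tensors with no layer names attached...
print(type(model.parameters()))         # a generator object
# ...while state_dict() maps each parameter name to its tensor.
print(list(model.state_dict().keys()))  # ['weight', 'bias']
```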
It's all about the label
Adorn your file with a `.pt` or `.pth` extension. It's the trademark of PyTorch, after all.
Putting your model through a photocopier
A `DataParallel` model can create havoc while loading if not dealt with properly: its state dict keys carry a `module.` prefix. Remember to wrap the model in `DataParallel` before you call `load_state_dict()`, or strip the prefix from the keys. You won't regret it.
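Both options sketched below; the wrapped `nn.Linear`, the filename, and the prefix-stripping helper are illustrative (note `str.removeprefix` needs Python 3.9+):

```python
import torch
import torch.nn as nn

# Setup only: pretend a checkpoint was saved from a DataParallel model
# without unwrapping it, so every key starts with "module.".
saved = nn.DataParallel(nn.Linear(10, 2))
torch.save(saved.state_dict(), "dp_weights.pth")

# Option 1: wrap the fresh model in DataParallel first, then load.
model = nn.DataParallel(nn.Linear(10, 2))
model.load_state_dict(torch.load("dp_weights.pth"))

# Option 2: strip the "module." prefix and load into a plain model.
plain = nn.Linear(10, 2)
state = {k.removeprefix("module."): v
         for k, v in torch.load("dp_weights.pth").items()}
plain.load_state_dict(state)
```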
Put your notes in the checkpoint
In addition to the model and optimizer state, include epochs, performance metrics, and other training parameters. It's like a digital time capsule for nerds! Resuming training later becomes easier.
Serialized does not mean immortal!
Changes in your code, refactoring, or even renaming a class can break your serialized model; whole-model saves in particular pickle references to your source code. Remember to test-load your serialized models regularly to keep them immortal.
The balance of power! (Or efficiency and deployment)
Strive for the golden balance between efficient storage (`state_dict()`) and easy deployment (the full model).