
How do I save a trained model in PyTorch?

python
model-saving
pytorch
deep-learning
by Nikita Barsukov · Feb 13, 2025
TLDR

To save a PyTorch model, store the model parameters using torch.save() as follows:

torch.save(model.state_dict(), 'my_model.pth')

Loading is thankfully just as easy: recreate the model, then restore its parameters with model.load_state_dict():

model = MyModel()  # don't forget to define your model class first
model.load_state_dict(torch.load('my_model.pth'))

This way of storing models saves only the learnable parameters (plus registered buffers). The result is a compact, portable and, dare I say, quite good-looking model file.
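Curious what actually lands in that file? A state_dict is just an ordered mapping from parameter and buffer names to tensors. A quick sketch, using a hypothetical two-layer model purely for illustration:

import torch.nn as nn

# hypothetical two-layer model, just to peek inside a state_dict
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# a state_dict maps parameter/buffer names to tensors
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# 0.weight (8, 4)
# 0.bias (8,)
# 2.weight (2, 8)
# 2.bias (2,)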

Saving and Loading: A detailed run-through

Taking a pitstop during the training marathon

Long training sessions are just life (or should I say, strife?). To help you out, you can save checkpoints with more than just model parameters:

checkpoint = {
    'epoch': epoch,  # round and round we go
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,  # tears or binary cross-entropy, we all have losses
    # feel free to add other metrics for your experimental pleasure
}
torch.save(checkpoint, 'pitstop.pth')

When loading the checkpoint, initialize the model and optimizer first:

checkpoint = torch.load('pitstop.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# If you were interrupted mid-epoch…
epoch = checkpoint['epoch']
loss = checkpoint['loss']

If you're loading in the middle of a racetrack (i.e., for inference), don't forget to call model.eval(). Skipping it leaves dropout and batch normalization layers in training mode, and that way lies batch chaos.
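For the inference case, a minimal sketch (MyModel and inputs are hypothetical stand-ins for your own class and data):

import torch

model = MyModel()  # same architecture as at save time
checkpoint = torch.load('pitstop.pth')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # switch dropout/batchnorm layers to inference behavior

with torch.no_grad():  # no gradient bookkeeping needed for predictions
    predictions = model(inputs)  # 'inputs' is a stand-in for your data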

When more is actually more: Saving the whole model

At times, it may be expedient to save the entire model, architecture and all. Just know that this pickles references to your model class, so the defining source code must not play hide-and-seek at load time:

torch.save(model, 'the_whole_shebang.pth')

And bring it back to life with:

model = torch.load('the_whole_shebang.pth')
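If the model was saved on a GPU and you're loading it somewhere else, torch.load accepts a map_location argument to remap devices. A quick sketch:

import torch

# load on CPU even if the model was saved from a GPU...
model = torch.load('the_whole_shebang.pth', map_location=torch.device('cpu'))

# ...or remap everything onto a specific GPU
model = torch.load('the_whole_shebang.pth', map_location='cuda:0')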

Ensuring smooth transition between PyTorch versions

Version compatibility of PyTorch is a fickle mistress. So, do remember to jot down the PyTorch version used to save your models. Future-you will appreciate it.
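One low-effort way to do that is to stash the version inside the checkpoint itself. A sketch (the filename is just an example):

import torch

checkpoint = {
    'model_state_dict': model.state_dict(),
    'pytorch_version': torch.__version__,  # breadcrumb for future-you
}
torch.save(checkpoint, 'versioned.pth')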

Special considerations for the horde: Distributed training

Models thriving in DataParallel or DistributedDataParallel situations? Save the state with model.module.state_dict():

# One for all and all for one!
torch.save(model.module.state_dict(), 'all_for_one.pth')
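Because model.module.state_dict() drops the 'module.' prefix, the result loads straight into an unwrapped model. If a checkpoint was saved from the wrapped model by mistake, you can strip the prefix yourself; a sketch, assuming Python 3.9+ for str.removeprefix and a hypothetical MyModel class:

import torch

state_dict = torch.load('all_for_one.pth')

# keys saved from a wrapped model look like 'module.layer1.weight';
# strip the prefix so a plain model accepts them
state_dict = {k.removeprefix('module.'): v for k, v in state_dict.items()}

model = MyModel()  # unwrapped model class
model.load_state_dict(state_dict)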

Shape-shifting: Loading into a different architecture

Loading a state dict into a different model architecture can create hiccups. Ensure a smooth ride by defining the target model class before loading the state dict, and expect complaints about mismatched keys when layers differ. And remember, pray to the gods of GPU memory.
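When only some layers match, load_state_dict(strict=False) skips incompatible keys instead of raising an error. A sketch, with SlightlyDifferentModel standing in for a hypothetical architecture variant:

import torch

pretrained = torch.load('my_model.pth')
new_model = SlightlyDifferentModel()  # hypothetical architecture variant

# strict=False skips mismatched keys instead of raising an error;
# the return value tells you exactly what was ignored
missing, unexpected = new_model.load_state_dict(pretrained, strict=False)
print('missing keys:', missing)
print('unexpected keys:', unexpected)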

Extra Scoop: Advanced details and best practices

Don't shoot an arrow in the dark!

Directly saving model.parameters() might lead to unfortunate consequences: it returns a generator object, not a snapshot of your weights, and it skips registered buffers entirely. Let's stick with the tried and true state_dict() for the sake of future generations.
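To see the difference for yourself, a quick sketch:

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

print(type(model.parameters()))  # <class 'generator'>: not a weight snapshot
print(type(model.state_dict()))  # <class 'collections.OrderedDict'>: names -> tensors

# state_dict also captures buffers that parameters() misses,
# e.g. BatchNorm's running statistics
print([k for k in model.state_dict() if 'running' in k])
# ['1.running_mean', '1.running_var']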

It's all about the label

Adorn your file with a .pt or .pth extension. It's the trademark of PyTorch, after all.

Putting your model through a photocopier

A DataParallel model can create havoc while loading if not dealt with properly. Remember to wrap the model in DataParallel before you call load_state_dict() on a checkpoint saved from a wrapped model. You won't regret it.
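A minimal sketch, assuming the checkpoint ('wrapped.pth', a hypothetical filename) was saved from the wrapped model, with its 'module.'-prefixed keys intact:

import torch
import torch.nn as nn

model = nn.DataParallel(MyModel())  # wrap first (MyModel is hypothetical)...
model.load_state_dict(torch.load('wrapped.pth'))  # ...then load the prefixed dict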

Put your notes in the checkpoint

In addition to the model and optimizer state, include epochs, performance metrics, and other training parameters. It's like a digital time capsule for nerds! Resuming training later becomes easier.

Serialized does not equal immortal!

Changes in your code, refactoring, or even switching to another brand of coffee can harm your serialized model. Test that your saved models still load and run regularly to ensure their immortality.
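A tiny smoke test goes a long way here; a sketch assuming a hypothetical MyModel that takes (batch, 4) inputs:

import torch

def test_state_dict_round_trip():
    model = MyModel()  # hypothetical model taking (batch, 4) inputs
    torch.save(model.state_dict(), 'tmp.pth')

    reloaded = MyModel()
    reloaded.load_state_dict(torch.load('tmp.pth'))
    reloaded.eval()

    dummy = torch.randn(1, 4)  # adjust the shape to your model
    with torch.no_grad():
        output = reloaded(dummy)  # it loads and runs: good enough as a smoke test
    assert output.shape[0] == 1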

The balance of power! (Or efficiency and deployment)

Strive for the golden balance between efficient storage (state_dict()) and easy deployment (the full model).