How do I print the model summary in PyTorch?
To quickly get a model summary in PyTorch, use `torchinfo`. It replaces the older `torchsummary` package, and you can install it with `pip install torchinfo`. Next, call its `summary` function with your model and the correct input tensor size. You instantly get a detailed report of your model's layers, parameter counts, and output shapes—essential for validating the model architecture.
Dive deeper with pytorch-summary
`torchinfo` is good for a quick look, but `pytorch-summary` (imported as `torchsummary`) offers a deeper dive into the model. It details trainable parameters and estimated memory usage, helping to diagnose memory-related bottlenecks. Change `'cuda'` to `'cpu'` if you are running on CPU.
Unmasking trainable and non-trainable parameters
When setting up your model, it’s important to know which layers are trainable and which are non-trainable. A little function can help you demystify this:
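One way to write such a helper (the function name `describe_parameters` is my own; any name works):

```python
import torch

def describe_parameters(model: torch.nn.Module) -> None:
    """Print counts of trainable vs. frozen (non-trainable) parameters."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    print(f"Trainable params:     {trainable:,}")
    print(f"Non-trainable params: {frozen:,}")

model = torch.nn.Linear(10, 5)
model.weight.requires_grad = False  # freeze one tensor for demonstration
describe_parameters(model)  # trainable: 5 (bias), non-trainable: 50 (weight)
```

Freezing is done per-tensor via `requires_grad`, so this split tells you exactly how much of the model the optimizer will actually update.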
Handling dynamic inputs with torchinfo
Due to PyTorch's dynamic graph computation, the model architecture might change based on input size. `torchinfo` is versatile here, allowing you to specify multiple input sizes:
Workaround when external dependencies are out of bounds
In some scenarios external packages can't be used. In those cases, a simple `print(model)` gives a basic overview of the model's layers, and `state_dict()` lists every parameter tensor and its shape:
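Both approaches use only core PyTorch; the `Sequential` model below is just an example:

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(10, 20),
    torch.nn.ReLU(),
    torch.nn.Linear(20, 2),
)

# Built-in repr: layer types and their constructor arguments.
print(model)

# state_dict(): every parameter/buffer name with its shape.
for name, tensor in model.state_dict().items():
    print(f"{name}: {tuple(tensor.shape)}")
```

Unlike `torchinfo`, neither approach runs a forward pass, so you see parameter shapes but not output shapes.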
Roll your own model summary
Creating your own tools allows you to customize based on your specific needs! Let's write a simple function to pull apart the model's layers:
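One possible sketch: walk `named_modules()`, skip container modules, and report parameter counts per leaf layer (the layout and helper name are my own choices):

```python
import torch

def model_summary(model: torch.nn.Module) -> None:
    """Minimal hand-rolled summary: one line per leaf module."""
    total = 0
    for name, module in model.named_modules():
        # Skip containers (Sequential, the root model, ...): report only leaves.
        if len(list(module.children())) > 0:
            continue
        params = sum(p.numel() for p in module.parameters())
        total += params
        print(f"{name or 'model':<12}{module.__class__.__name__:<12}{params:>12,} params")
    print(f"{'total':<24}{total:>12,} params")

model_summary(torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.Tanh()))
```

From here you can extend the function however you like, for example adding a column for `requires_grad` or grouping counts by layer type.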