How do I print the model summary in PyTorch?
To get a quick model summary in PyTorch, use torchinfo. It replaces the older torchsummary, and you can install it with pip install torchinfo. Then call its summary function with your model and the expected input tensor size.
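For example, a minimal sketch using torchvision's resnet18 as a stand-in for your own model (the 224x224 RGB input size is an assumption; use whatever your model expects):

```python
import torchvision.models as models
from torchinfo import summary

model = models.resnet18()  # stand-in; substitute your own model

# input_size is (batch, channels, height, width) for one 224x224 RGB image
summary(model, input_size=(1, 3, 224, 224))
```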
Boom! You instantly get a detailed report of your model's layers, parameter counts, and output shapes—essential for validating the model architecture.
Dive deeper with pytorch-summary
torchinfo is good for a quick look, but pytorch-summary offers a deeper dive into the model. It reports per-layer parameter counts and estimated memory usage, helping you diagnose memory-related bottlenecks. In the sketch below, change 'cuda' to 'cpu' if you're running without a GPU.
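A minimal sketch, again borrowing resnet18 as a placeholder (note that pytorch-summary typically installs and imports as torchsummary):

```python
import torch
import torchvision.models as models
from torchsummary import summary  # pip install torchsummary

# Pick a device and make sure the model lives on it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = models.resnet18().to(device)

# Reports output shapes, parameter counts, and estimated memory
# (input size, forward/backward pass size, params size) per layer.
summary(model, input_size=(3, 224, 224), device=device)
```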
Unmasking trainable and non-trainable parameters
When setting up your model, it’s important to know which layers are trainable and which are non-trainable. A little function can help you demystify this:
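Here's one such helper. The name count_parameters is made up for illustration; it's not part of PyTorch itself:

```python
import torch.nn as nn

def count_parameters(model: nn.Module) -> None:
    """Split parameter counts by requires_grad status."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    print(f"Trainable params:     {trainable:,}")
    print(f"Non-trainable params: {frozen:,}")

# Example: freeze a layer's weight, then inspect.
model = nn.Linear(10, 2)
model.weight.requires_grad = False  # freeze the weight, keep the bias trainable
count_parameters(model)
# Trainable params:     2
# Non-trainable params: 20
```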
Handling dynamic inputs with torchinfo
Because PyTorch builds its computation graph dynamically, a model's behavior can change with input size. torchinfo is versatile enough to handle this: for models with multiple inputs, you can pass one size per input:
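A sketch with a hypothetical two-input model (TwoInputNet and its layer sizes are invented for illustration):

```python
import torch
import torch.nn as nn
from torchinfo import summary

class TwoInputNet(nn.Module):
    """Hypothetical model whose forward takes two differently sized inputs."""
    def __init__(self):
        super().__init__()
        self.fc_a = nn.Linear(16, 8)
        self.fc_b = nn.Linear(28, 8)
        self.head = nn.Linear(16, 1)

    def forward(self, a, b):
        # Concatenate both branches before the final layer.
        combined = torch.cat([self.fc_a(a), self.fc_b(b)], dim=1)
        return self.head(combined)

# One size per input; torchinfo creates a dummy tensor for each.
summary(TwoInputNet(), input_size=[(1, 16), (1, 28)])
```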
Workaround when external dependencies are out of bounds
There are some scenarios where external packages can't be used. In those cases, you can use state_dict() or a simple print(model) to get a basic overview of model layers and parameters:
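For example, with a small stand-in model built from nn.Sequential (substitute your own):

```python
import torch.nn as nn

# A small stand-in model; replace with your own.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 28 * 28, 10),
)

# print(model) shows the module tree exactly as defined.
print(model)

# state_dict() maps every parameter and buffer name to its tensor.
for name, tensor in model.state_dict().items():
    print(f"{name:25s} {tuple(tensor.shape)}")
```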
Roll your own model summary
Rolling your own tool lets you tailor the output to your exact needs! Let's write a simple function to pull apart the model's layers:
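A minimal sketch that walks named_modules() and prints one row per leaf layer (the name simple_summary is invented for illustration):

```python
import torch.nn as nn

def simple_summary(model: nn.Module) -> None:
    """Hand-rolled summary: one row per leaf module, plus a total."""
    total = 0
    for name, module in model.named_modules():
        if list(module.children()):  # skip containers, keep leaf layers
            continue
        params = sum(p.numel() for p in module.parameters())
        total += params
        print(f"{name:30s} {type(module).__name__:15s} {params:>12,}")
    print(f"{'Total parameters':46s} {total:>12,}")

# Try it on any nn.Module:
simple_summary(nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)))
```

From here you can extend the function however you like, e.g. to also report output shapes via forward hooks.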