
How do I print the model summary in PyTorch?

python
model-summary
pytorch
deep-dive
by Nikita Barsukov · Dec 7, 2024
TLDR

To quickly get a model summary in PyTorch, use torchinfo. It replaces the older torchsummary, and you can install it with pip install torchinfo. Then call the summary function with your model and the input size (batch dimension included).

```python
from torchinfo import summary

# model: replace this with your actual PyTorch model
# input_size: use your model's input size, batch dimension first
summary(model, input_size=(1, 3, 224, 224))
```

Boom! You instantly get a detailed report of your model's layers, parameter counts, and output shapes—essential for validating the model architecture.

Dive deeper with pytorch-summary

torchinfo is good for a quick look, but pytorch-summary (imported as torchsummary) offers a deep dive into the model. It details trainable parameters and estimated memory usage, helping to diagnose memory-related bottlenecks. Change 'cuda' to 'cpu' if you're running on CPU, and don't forget to buckle your seatbelt before diving.

```python
from torchsummary import summary

# Take the deep dive... switch 'cuda' to 'cpu' if you're CPU-bound
model.to('cuda')
summary(model, (3, 224, 224))
```

Unmasking trainable and non-trainable parameters

When setting up your model, it’s important to know which layers are trainable and which are non-trainable. A little function can help you demystify this:

```python
print("Trainable layers:")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)  # Hands off the wheel, these layers are driving themselves
```
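The same idea extends to counting parameters on each side of the trainable/frozen divide. Here is a minimal sketch using a toy model (the layer sizes are mine, not from the article): freeze one layer, then tally parameters by requires_grad.

```python
import torch.nn as nn

# Toy model: freeze the first layer to create non-trainable parameters
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for p in model[0].parameters():
    p.requires_grad = False  # layer 0 is now frozen

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"Trainable params: {trainable}, frozen params: {frozen}")
# Trainable params: 18, frozen params: 40
```

This is handy when fine-tuning: a quick count confirms that the layers you meant to freeze really are frozen.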

Handling dynamic inputs with torchinfo

Due to PyTorch's dynamic graph computation, model architecture might change based on input size. However, torchinfo is versatile, allowing you to specify multiple input sizes:

```python
# Variable input sizes: because size does matter (sometimes)
summary(model, input_data=(input1, input2, ...))
```

Workaround when external dependencies are out of bounds

There are some scenarios where external packages can't be used. In those cases, you can use state_dict() or a simple print(model) to get a basic overview of model layers and parameters:

```python
print("Layers and parameters:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
    # Old school method when you're not allowed to bring your toys
```
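To make the dependency-free workaround concrete, here's what it yields on a toy single-layer model (the layer sizes are mine): print(model) shows the module hierarchy, and state_dict() gives every parameter tensor and its shape.

```python
import torch.nn as nn

# Toy model: a single linear layer
model = nn.Linear(3, 2)
print(model)  # Linear(in_features=3, out_features=2, bias=True)

# state_dict() maps parameter names to tensors
for name, tensor in model.state_dict().items():
    print(name, "\t", tuple(tensor.size()))
# weight   (2, 3)
# bias     (2,)
```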

Roll your own model summary

Creating your own tools allows you to customize based on your specific needs! Let's write a simple function to pull apart the model's layers:

```python
def model_summary(model):
    print("Layer \t\t Type \t\t Parameter Count")
    for name, layer in model.named_modules():  # Who needs packages anyway?
        if len(list(layer.children())) == 0:  # leaf modules only
            params = sum(p.numel() for p in layer.parameters())
            print(f"{name} \t {type(layer).__name__} \t {params}")

model_summary(your_model)
```
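As a sanity check, here is the same leaf-module walk run on a small Sequential model (the layer sizes are my own illustration): named_modules() yields the container first, which is skipped because it has children, and then each leaf layer with its parameter count.

```python
import torch.nn as nn

# Hypothetical toy model to exercise the roll-your-own summary logic
toy = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
for name, layer in toy.named_modules():
    if len(list(layer.children())) == 0:  # leaf modules only
        params = sum(p.numel() for p in layer.parameters())
        print(f"{name}\t{type(layer).__name__}\t{params}")
# 0    Linear   55
# 1    ReLU     0
# 2    Linear   6
```

The counts check out by hand: 10×5 weights + 5 biases = 55, ReLU has no parameters, and 5×1 + 1 = 6.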