# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Importing necessary modules
import logging
from distutils.version import LooseVersion  # Used for version comparisons

import torch
import torch.nn as nn

from .basic_hooks import *  # Basic hooks (counting functions for common operations)
from .rnn_hooks import *  # Hooks specific to RNN operations

# Uncomment the following to create a module-level logger
# logger = logging.getLogger(__name__)  # Creating a logger instance
# logger.setLevel(logging.INFO)  # Setting the log level to INFO

# Functions to print text in different colors,
# useful for visually differentiating output in the terminal
def prRed(skk):
    print("\033[91m{}\033[00m".format(skk))  # Print red text


def prGreen(skk):
    print("\033[92m{}\033[00m".format(skk))  # Print green text


def prYellow(skk):
    print("\033[93m{}\033[00m".format(skk))  # Print yellow text

# Check whether the installed version of PyTorch is outdated
if LooseVersion(torch.__version__) < LooseVersion("1.0.0"):
    # If the version is below 1.0.0, emit a warning
    logging.warning(
        f"You are using an old version of PyTorch {torch.__version__}, which THOP may not support in the future."
    )

# Default dtype for the op/param counting buffers
default_dtype = torch.float64  # 64-bit floats avoid precision loss in large counts

# Registry of hooks for different PyTorch layers:
# each layer type is mapped to its respective counting function
register_hooks = {
    nn.ZeroPad2d: zero_ops,
    nn.Conv1d: count_convNd, nn.Conv2d: count_convNd, nn.Conv3d: count_convNd,
    nn.ConvTranspose1d: count_convNd, nn.ConvTranspose2d: count_convNd, nn.ConvTranspose3d: count_convNd,
    nn.BatchNorm1d: count_bn, nn.BatchNorm2d: count_bn, nn.BatchNorm3d: count_bn, nn.SyncBatchNorm: count_bn,
    nn.ReLU: zero_ops, nn.ReLU6: zero_ops, nn.LeakyReLU: count_relu,
    nn.MaxPool1d: zero_ops, nn.MaxPool2d: zero_ops, nn.MaxPool3d: zero_ops,
    nn.AdaptiveMaxPool1d: zero_ops, nn.AdaptiveMaxPool2d: zero_ops, nn.AdaptiveMaxPool3d: zero_ops,
    nn.AvgPool1d: count_avgpool, nn.AvgPool2d: count_avgpool, nn.AvgPool3d: count_avgpool,
    nn.AdaptiveAvgPool1d: count_adap_avgpool, nn.AdaptiveAvgPool2d: count_adap_avgpool, nn.AdaptiveAvgPool3d: count_adap_avgpool,
    nn.Linear: count_linear, nn.Dropout: zero_ops,
    nn.Upsample: count_upsample, nn.UpsamplingBilinear2d: count_upsample, nn.UpsamplingNearest2d: count_upsample,
    nn.RNNCell: count_rnn_cell, nn.GRUCell: count_gru_cell, nn.LSTMCell: count_lstm_cell,
    nn.RNN: count_rnn, nn.GRU: count_gru, nn.LSTM: count_lstm,
}
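
# Example of a custom counting rule (an illustrative sketch; `MySwish` and
# `count_my_swish` are hypothetical names, not part of this module). A rule has
# the standard forward-hook signature and accumulates into `total_ops`:
#
#     class MySwish(nn.Module):
#         def forward(self, x):
#             return x * torch.sigmoid(x)
#
#     def count_my_swish(m, x, y):
#         # Assume one multiply and one sigmoid evaluation per output element
#         m.total_ops += torch.DoubleTensor([2 * y.numel()])
#
# Pass it via custom_ops={MySwish: count_my_swish} to profile()/profile_origin().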


# Profile a model's operations (MACs) and parameters.
# Counting hooks are attached to every leaf module; after one forward pass,
# per-module counts are summed into model-level totals.
def profile_origin(model, inputs, custom_ops=None, verbose=True):
    handler_collection = []  # Handles of the registered forward hooks
    types_collection = set()  # Layer types that have already been reported
    custom_ops = custom_ops or {}  # User-supplied counting rules

    def add_hooks(m):
        # Skip compound modules (those that contain other modules);
        # only leaf modules are counted here
        if len(list(m.children())) > 0:
            return

        # Warn if the module already defines the buffers we are about to add
        if hasattr(m, "total_ops") or hasattr(m, "total_params"):
            logging.warning(f"Either .total_ops or .total_params is already defined in {str(m)}. Be cautious.")

        # Add buffers to store the total number of operations and parameters
        m.register_buffer('total_ops', torch.zeros(1, dtype=default_dtype))
        m.register_buffer('total_params', torch.zeros(1, dtype=default_dtype))

        # Count the number of parameters for this module
        for p in m.parameters():
            m.total_params += torch.DoubleTensor([p.numel()])

        # Determine which function to use for counting operations;
        # custom rules take precedence over the built-in registry
        m_type = type(m)
        fn = custom_ops.get(m_type, register_hooks.get(m_type, None))

        if fn:
            # A counting function exists: register it as a forward hook
            if m_type not in types_collection and verbose:
                print(f"[INFO] {'Customize rule' if m_type in custom_ops else 'Register'} {fn.__qualname__} for {m_type}.")
            handler = m.register_forward_hook(fn)
            handler_collection.append(handler)
        else:
            # Warn if no counting rule is found
            if m_type not in types_collection and verbose:
                prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero MACs and zero Params.")

        types_collection.add(m_type)

    # Set the model to evaluation mode and attach the hooks
    model.eval()
    model.apply(add_hooks)

    # Run a forward pass with no gradients
    with torch.no_grad():
        model(*inputs)

    # Sum the operations and parameters across all hooked layers
    total_ops = sum(m.total_ops.item() for m in model.modules() if hasattr(m, 'total_ops'))
    total_params = sum(m.total_params.item() for m in model.modules() if hasattr(m, 'total_params'))

    # Restore the model to training mode, then remove the hooks and buffers
    model.train()
    for handler in handler_collection:
        handler.remove()
    for m in model.modules():
        if hasattr(m, "total_ops"):
            del m._buffers['total_ops']
        if hasattr(m, "total_params"):
            del m._buffers['total_params']

    return total_ops, total_params
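
# Illustrative usage of profile_origin (a minimal sketch; the model and the
# input shape below are assumptions for demonstration only):
#
#     model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
#     macs, params = profile_origin(model, inputs=(torch.randn(1, 3, 32, 32),))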


# Updated profiling function with a different approach for hierarchical modules:
# per-layer counts are aggregated recursively over the module tree
def profile(model: nn.Module, inputs, custom_ops=None, verbose=True):
    handler_collection = {}  # Maps each hooked module to its hook handles
    types_collection = set()  # Layer types that have already been reported
    custom_ops = custom_ops or {}  # User-supplied counting rules

    def add_hooks(m: nn.Module):
        # Add buffers for storing total ops and params
        m.register_buffer('total_ops', torch.zeros(1, dtype=default_dtype))
        m.register_buffer('total_params', torch.zeros(1, dtype=default_dtype))

        # Find the appropriate counting function for this layer;
        # custom rules take precedence over the built-in registry
        fn = custom_ops.get(type(m), register_hooks.get(type(m), None))
        if fn:
            # Register hooks that count both operations and parameters
            handler_collection[m] = (m.register_forward_hook(fn), m.register_forward_hook(count_parameters))
            if type(m) not in types_collection and verbose:
                print(f"[INFO] {'Customize rule' if type(m) in custom_ops else 'Register'} {fn.__qualname__} for {type(m)}.")
        else:
            # Warn if no rule is found for this layer
            if type(m) not in types_collection and verbose:
                prRed(f"[WARN] Cannot find rule for {type(m)}. Treat it as zero MACs and zero Params.")
        types_collection.add(type(m))

    # Set the model to evaluation mode and attach the hooks
    model.eval()
    model.apply(add_hooks)

    # Run a forward pass with no gradients
    with torch.no_grad():
        model(*inputs)

    # Recursively count ops and params over the module tree: hooked modules
    # contribute their own buffers, containers are traversed depth-first
    def dfs_count(module: nn.Module) -> (int, int):
        total_ops, total_params = 0, 0
        for m in module.children():
            if m in handler_collection:
                m_ops, m_params = m.total_ops.item(), m.total_params.item()
            else:
                m_ops, m_params = dfs_count(m)
            total_ops += m_ops
            total_params += m_params
        return total_ops, total_params

    total_ops, total_params = dfs_count(model)  # Perform the depth-first count

    # Restore the model to training mode and remove the hooks
    model.train()
    for op_handler, params_handler in handler_collection.values():
        op_handler.remove()
        params_handler.remove()
    # add_hooks registered buffers on every module (including containers),
    # so remove them everywhere, not only on the hooked modules
    for m in model.modules():
        if 'total_ops' in m._buffers:
            del m._buffers['total_ops']
        if 'total_params' in m._buffers:
            del m._buffers['total_params']

    return total_ops, total_params
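

# Minimal smoke test (an illustrative sketch, not part of the original module):
# builds a small CNN and prints its MAC and parameter counts.
if __name__ == "__main__":
    net = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(16, 10),
    )
    macs, params = profile(net, inputs=(torch.randn(1, 3, 32, 32),))
    prGreen(f"MACs: {macs:.0f}, Params: {params:.0f}")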