# -*- coding: utf-8 -*- # @Author: Weisen Pan # Importing necessary modules from distutils.version import LooseVersion # Used for version comparisons from .basic_hooks import * # Importing basic hooks (functions for profiling operations) from .rnn_hooks import * # Importing hooks specific to RNN operations # Uncomment the following for logging purposes # import logging # logger = logging.getLogger(__name__) # Creating a logger instance # logger.setLevel(logging.INFO) # Setting the log level to INFO # Functions to print text in different colors # Useful for visually differentiating output in terminal def prRed(skk): print("\033[91m{}\033[00m".format(skk)) # Print red text def prGreen(skk): print("\033[92m{}\033[00m".format(skk)) # Print green text def prYellow(skk): print("\033[93m{}\033[00m".format(skk)) # Print yellow text # Checking if the installed version of PyTorch is outdated if LooseVersion(torch.__version__) < LooseVersion("1.0.0"): # If the version is below 1.0.0, print a warning logging.warning( f"You are using an old version of PyTorch {torch.__version__}, which THOP may not support in the future." ) # Setting the default data type for tensors default_dtype = torch.float64 # Using 64-bit float as the default precision # Register hooks for different layers in PyTorch # Each layer type is mapped to its respective counting function register_hooks = { nn.ZeroPad2d: zero_ops, nn.Conv1d: count_convNd, nn.Conv2d: count_convNd, nn.Conv3d: count_convNd, nn.ConvTranspose1d: count_convNd, nn.ConvTranspose2d: count_convNd, nn.ConvTranspose3d: count_convNd, nn.BatchNorm1d: count_bn, nn.BatchNorm2d: count_bn, nn.BatchNorm3d: count_bn, nn.SyncBatchNorm: count_bn, nn.ReLU: zero_ops, nn.ReLU6: zero_ops, nn.LeakyReLU: count_relu, nn.MaxPool1d: zero_ops, nn.MaxPool2d: zero_ops, nn.MaxPool3d: zero_ops, nn.AdaptiveMaxPool1d: zero_ops, nn.AdaptiveMaxPool2d: zero_ops, nn.AdaptiveMaxPool3d: zero_ops, nn.AvgPool1d: count_avgpool, nn.AvgPool2d: count_avgpool, nn.AvgPool3d: count_avgpool, nn.AdaptiveAvgPool1d: count_adap_avgpool, nn.AdaptiveAvgPool2d: count_adap_avgpool, nn.AdaptiveAvgPool3d: count_adap_avgpool, nn.Linear: count_linear, nn.Dropout: zero_ops, nn.Upsample: count_upsample, nn.UpsamplingBilinear2d: count_upsample, nn.UpsamplingNearest2d: count_upsample, nn.RNNCell: count_rnn_cell, nn.GRUCell: count_gru_cell, nn.LSTMCell: count_lstm_cell, nn.RNN: count_rnn, nn.GRU: count_gru, nn.LSTM: count_lstm, } # Function for profiling model operations and parameters # This tracks how many operations (ops) and parameters (params) a model uses def profile_origin(model, inputs, custom_ops=None, verbose=True): handler_collection = [] # Collection of hooks types_collection = set() # Keep track of registered layer types custom_ops = custom_ops or {} # Custom operation handling def add_hooks(m): # Ignore compound modules (those that contain other modules) if len(list(m.children())) > 0: return # Check if the module already has the required attributes if hasattr(m, "total_ops") or hasattr(m, "total_params"): logging.warning(f"Either .total_ops or .total_params is already defined in {str(m)}. Be cautious.") # Add buffers to store the total number of operations and parameters m.register_buffer('total_ops', torch.zeros(1, dtype=default_dtype)) m.register_buffer('total_params', torch.zeros(1, dtype=default_dtype)) # Count the number of parameters for this module for p in m.parameters(): m.total_params += torch.DoubleTensor([p.numel()]) # Determine which function to use for counting operations m_type = type(m) fn = custom_ops.get(m_type, register_hooks.get(m_type, None)) if fn: # If the function exists, register the forward hook if m_type not in types_collection and verbose: print(f"[INFO] {'Customize rule' if m_type in custom_ops else 'Register'} {fn.__qualname__} for {m_type}.") handler = m.register_forward_hook(fn) handler_collection.append(handler) else: # Warn if no counting rule is found if m_type not in types_collection and verbose: prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero MACs and zero Params.") types_collection.add(m_type) # Set the model to evaluation mode (no gradients) model.eval() model.apply(add_hooks) # Run a forward pass with no gradients with torch.no_grad(): model(*inputs) # Sum up the total operations and parameters across all layers total_ops = sum(m.total_ops.item() for m in model.modules() if hasattr(m, 'total_ops')) total_params = sum(m.total_params.item() for m in model.modules() if hasattr(m, 'total_params')) # Restore the model to training mode and remove hooks model.train() for handler in handler_collection: handler.remove() for m in model.modules(): if hasattr(m, "total_ops"): del m._buffers['total_ops'] if hasattr(m, "total_params"): del m._buffers['total_params'] return total_ops, total_params # Return the total number of ops and params # Updated profiling function with a different approach for hierarchical modules def profile(model: nn.Module, inputs, custom_ops=None, verbose=True): handler_collection = {} # Dictionary to store handlers types_collection = set() # Store layer types that have been processed custom_ops = custom_ops or {} # Custom operation handling def add_hooks(m: nn.Module): # Add buffers for storing total ops and params m.register_buffer('total_ops', torch.zeros(1, dtype=default_dtype)) m.register_buffer('total_params', torch.zeros(1, dtype=default_dtype)) # Find the appropriate counting function for this layer fn = custom_ops.get(type(m), register_hooks.get(type(m), None)) if fn: # Register hooks for both operations and parameters handler_collection[m] = (m.register_forward_hook(fn), m.register_forward_hook(count_parameters)) if type(m) not in types_collection and verbose: print(f"[INFO] {'Customize rule' if type(m) in custom_ops else 'Register'} {fn.__qualname__} for {type(m)}.") else: # Warn if no rule is found for this layer if type(m) not in types_collection and verbose: prRed(f"[WARN] Cannot find rule for {type(m)}. Treat it as zero MACs and zero Params.") types_collection.add(type(m)) # Set the model to evaluation mode model.eval() model.apply(add_hooks) # Run a forward pass with no gradients with torch.no_grad(): model(*inputs) # Recursive function to count ops and params for hierarchical models def dfs_count(module: nn.Module) -> (int, int): total_ops, total_params = 0, 0 for m in module.children(): if m in handler_collection: m_ops, m_params = m.total_ops.item(), m.total_params.item() else: m_ops, m_params = dfs_count(m) total_ops += m_ops total_params += m_params return total_ops, total_params total_ops, total_params = dfs_count(model) # Perform the depth-first count # Restore the model to training mode and remove hooks model.train() for m, (op_handler, params_handler) in handler_collection.items(): op_handler.remove() params_handler.remove() del m._buffers['total_ops'] del m._buffers['total_params'] return total_ops, total_params # Return the total ops and params