use-case-and-architecture/EdgeFLite/thop/profiling.py
Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00


# -*- coding: utf-8 -*-
# @Author: Weisen Pan
# Importing necessary modules
import logging  # Used below to warn about unsupported PyTorch versions

import torch
import torch.nn as nn
from distutils.version import LooseVersion  # Used for version comparisons

from .basic_hooks import *  # Basic hooks (counting functions for common layers)
from .rnn_hooks import *  # Hooks specific to RNN operations

# Uncomment the following to attach a module-level logger
# logger = logging.getLogger(__name__)  # Creating a logger instance
# logger.setLevel(logging.INFO)  # Setting the log level to INFO
# Functions to print text in different colors
# Useful for visually differentiating output in the terminal
def prRed(skk):
    print("\033[91m{}\033[00m".format(skk))  # Print red text

def prGreen(skk):
    print("\033[92m{}\033[00m".format(skk))  # Print green text

def prYellow(skk):
    print("\033[93m{}\033[00m".format(skk))  # Print yellow text
# Warn if the installed version of PyTorch is outdated
if LooseVersion(torch.__version__) < LooseVersion("1.0.0"):
    # Versions below 1.0.0 may lose support in future THOP releases
    logging.warning(
        f"You are using an old version of PyTorch {torch.__version__}, which THOP may not support in the future."
    )

# Default data type for the op/param counters
default_dtype = torch.float64  # 64-bit floats avoid overflow when accumulating large counts
# Register hooks for different layer types
# Each layer type is mapped to its respective counting function
register_hooks = {
    nn.ZeroPad2d: zero_ops,
    nn.Conv1d: count_convNd, nn.Conv2d: count_convNd, nn.Conv3d: count_convNd,
    nn.ConvTranspose1d: count_convNd, nn.ConvTranspose2d: count_convNd, nn.ConvTranspose3d: count_convNd,
    nn.BatchNorm1d: count_bn, nn.BatchNorm2d: count_bn, nn.BatchNorm3d: count_bn, nn.SyncBatchNorm: count_bn,
    nn.ReLU: zero_ops, nn.ReLU6: zero_ops, nn.LeakyReLU: count_relu,
    nn.MaxPool1d: zero_ops, nn.MaxPool2d: zero_ops, nn.MaxPool3d: zero_ops,
    nn.AdaptiveMaxPool1d: zero_ops, nn.AdaptiveMaxPool2d: zero_ops, nn.AdaptiveMaxPool3d: zero_ops,
    nn.AvgPool1d: count_avgpool, nn.AvgPool2d: count_avgpool, nn.AvgPool3d: count_avgpool,
    nn.AdaptiveAvgPool1d: count_adap_avgpool, nn.AdaptiveAvgPool2d: count_adap_avgpool, nn.AdaptiveAvgPool3d: count_adap_avgpool,
    nn.Linear: count_linear, nn.Dropout: zero_ops,
    nn.Upsample: count_upsample, nn.UpsamplingBilinear2d: count_upsample, nn.UpsamplingNearest2d: count_upsample,
    nn.RNNCell: count_rnn_cell, nn.GRUCell: count_gru_cell, nn.LSTMCell: count_lstm_cell,
    nn.RNN: count_rnn, nn.GRU: count_gru, nn.LSTM: count_lstm,
}
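
# Illustrative sketch (not part of the original file): layers missing from the
# registry above can be covered through the `custom_ops` argument of the
# profilers below. Counting rules use the standard forward-hook signature
# (module, inputs, output) and accumulate into the module's `total_ops`
# buffer. `Swish`, `count_swish`, `net`, and `data` are hypothetical names:
#
#     class Swish(nn.Module):
#         def forward(self, x):
#             return x * torch.sigmoid(x)
#
#     def count_swish(m, x, y):
#         # Rough estimate: one multiply plus one sigmoid per output element
#         m.total_ops += torch.DoubleTensor([2 * y.numel()])
#
#     macs, params = profile(net, inputs=(data,), custom_ops={Swish: count_swish})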
# Function for profiling model operations and parameters
# This tracks how many operations (MACs) and parameters a model uses
def profile_origin(model, inputs, custom_ops=None, verbose=True):
    handler_collection = []  # Collection of registered hook handles
    types_collection = set()  # Layer types already reported
    custom_ops = custom_ops or {}  # User-supplied counting rules, keyed by module type

    def add_hooks(m):
        # Skip compound modules (those that contain other modules)
        if len(list(m.children())) > 0:
            return
        # Warn if the module already defines the counter attributes
        if hasattr(m, "total_ops") or hasattr(m, "total_params"):
            logging.warning(f"Either .total_ops or .total_params is already defined in {str(m)}. Be cautious.")
        # Add buffers to store the total number of operations and parameters
        m.register_buffer('total_ops', torch.zeros(1, dtype=default_dtype))
        m.register_buffer('total_params', torch.zeros(1, dtype=default_dtype))
        # Count the parameters of this module
        for p in m.parameters():
            m.total_params += torch.DoubleTensor([p.numel()])
        # Determine which function to use for counting operations
        # (custom rules take precedence over the built-in registry)
        m_type = type(m)
        fn = custom_ops.get(m_type, register_hooks.get(m_type, None))
        if fn:
            # If a counting rule exists, register it as a forward hook
            if m_type not in types_collection and verbose:
                print(f"[INFO] {'Customize rule' if m_type in custom_ops else 'Register'} {fn.__qualname__} for {m_type}.")
            handler = m.register_forward_hook(fn)
            handler_collection.append(handler)
        else:
            # Warn if no counting rule is found
            if m_type not in types_collection and verbose:
                prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero MACs and zero Params.")
        types_collection.add(m_type)

    # Switch to evaluation mode so layers such as dropout and batch norm
    # behave deterministically during the profiling pass
    model.eval()
    model.apply(add_hooks)
    # Run a forward pass without tracking gradients
    with torch.no_grad():
        model(*inputs)
    # Sum the total operations and parameters across all hooked layers
    total_ops = sum(m.total_ops.item() for m in model.modules() if hasattr(m, 'total_ops'))
    total_params = sum(m.total_params.item() for m in model.modules() if hasattr(m, 'total_params'))
    # Restore training mode and remove the hooks and counter buffers
    model.train()
    for handler in handler_collection:
        handler.remove()
    for m in model.modules():
        if hasattr(m, "total_ops"):
            del m._buffers['total_ops']
        if hasattr(m, "total_params"):
            del m._buffers['total_params']
    return total_ops, total_params  # Total MACs and parameter count
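
# Usage sketch for profile_origin (the `net` variable and input shape below
# are assumptions, shown purely for illustration):
#
#     net = nn.Conv2d(3, 8, kernel_size=3)
#     macs, params = profile_origin(net, inputs=(torch.randn(1, 3, 32, 32),))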
# Updated profiling function that aggregates counts hierarchically,
# descending through compound modules instead of flattening the module tree
def profile(model: nn.Module, inputs, custom_ops=None, verbose=True):
    handler_collection = {}  # Maps each hooked module to its hook handles
    types_collection = set()  # Layer types already reported
    custom_ops = custom_ops or {}  # User-supplied counting rules, keyed by module type

    def add_hooks(m: nn.Module):
        # Add buffers for storing total ops and params
        m.register_buffer('total_ops', torch.zeros(1, dtype=default_dtype))
        m.register_buffer('total_params', torch.zeros(1, dtype=default_dtype))
        # Find the appropriate counting function for this layer
        # (custom rules take precedence over the built-in registry)
        fn = custom_ops.get(type(m), register_hooks.get(type(m), None))
        if fn:
            # Register hooks for both operation and parameter counting
            handler_collection[m] = (m.register_forward_hook(fn), m.register_forward_hook(count_parameters))
            if type(m) not in types_collection and verbose:
                print(f"[INFO] {'Customize rule' if type(m) in custom_ops else 'Register'} {fn.__qualname__} for {type(m)}.")
        else:
            # Warn if no rule is found for this layer
            if type(m) not in types_collection and verbose:
                prRed(f"[WARN] Cannot find rule for {type(m)}. Treat it as zero MACs and zero Params.")
        types_collection.add(type(m))

    # Switch to evaluation mode for a deterministic profiling pass
    model.eval()
    model.apply(add_hooks)
    # Run a forward pass without tracking gradients
    with torch.no_grad():
        model(*inputs)

    # Recursive function to count ops and params for hierarchical models
    def dfs_count(module: nn.Module) -> (int, int):
        total_ops, total_params = 0, 0
        for m in module.children():
            if m in handler_collection:
                # Hooked leaf module: read its counters directly
                m_ops, m_params = m.total_ops.item(), m.total_params.item()
            else:
                # Compound (or unhooked) module: recurse into its children
                m_ops, m_params = dfs_count(m)
            total_ops += m_ops
            total_params += m_params
        return total_ops, total_params

    total_ops, total_params = dfs_count(model)  # Perform the depth-first count

    # Restore training mode and remove hooks and counter buffers
    model.train()
    for m, (op_handler, params_handler) in handler_collection.items():
        op_handler.remove()
        params_handler.remove()
        del m._buffers['total_ops']
        del m._buffers['total_params']
    return total_ops, total_params  # Total MACs and parameter count
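
# Minimal self-test sketch (the toy model, input shape, and printed format
# below are assumptions added for illustration; they are not part of the
# original library API):
if __name__ == "__main__":
    # Small CNN that exercises the conv, batch-norm, pooling, and linear hooks
    demo_model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(16, 10),
    )
    dummy_input = torch.randn(1, 3, 32, 32)
    macs, params = profile(demo_model, inputs=(dummy_input,))
    prGreen(f"MACs: {macs:,.0f}, Params: {params:,.0f}")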