# Source path: use-case-and-architecture/EdgeFLite/thop/hooks_basic.py
# Commit 4ec0a23e73 by Weisen Pan: "Edge Federated Learning for Improved Training Efficiency"
# Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
# Committed: 2024-09-18 18:39:43 -07:00
# (File viewer stats: 92 lines, 3.0 KiB, Python)
# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import argparse
import logging
import math

import torch
import torch.nn as nn
from torch.nn.modules.conv import _ConvNd
# Module-level constant; not referenced by any hook in the code shown here.
# NOTE(review): presumably a multiply-add scaling factor kept for thop
# compatibility (1 => count one multiply-add as one op) — confirm before changing.
multiply_adds = 1
def count_parameters(m, x, y):
    """Record the parameter count of module ``m``.

    The sum of ``numel()`` over ``m.parameters()`` is stored as a 1-element
    double tensor in slot 0 of ``m.total_params``. ``x`` and ``y`` are unused.
    """
    n_params = 0
    for param in m.parameters():
        n_params += param.numel()
    m.total_params[0] = torch.DoubleTensor([n_params])
def zero_ops(m, x, y):
    """Hook for modules treated as free: adds zero to ``m.total_ops``."""
    nothing = torch.DoubleTensor([0])
    m.total_ops += nothing
def count_convNd(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
    """Count operations for a convolutional layer.

    Each output element is charged ``in_channels // groups`` times the size
    of one spatial kernel, plus one extra op when the layer has a bias.
    ``x`` is not used.
    """
    # weight[0][0] is a single spatial kernel slice, e.g. Kh x Kw for 2-D conv.
    per_kernel = m.weight[0][0].numel()
    bias_term = 0 if m.bias is None else 1
    per_output = (m.in_channels // m.groups) * per_kernel + bias_term
    m.total_ops += torch.DoubleTensor([y.nelement() * per_output])
def count_convNd_ver2(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
    """Alternative op count for convolutional layers.

    Charges every output position (batch dim and spatial dims of ``y``,
    excluding the channel dim 1) with the full weight tensor plus bias:
    ``prod(y.shape without dim 1) * (weight.numel() + bias.numel())``.
    ``x`` is not used.
    """
    # Count batch x spatial output positions straight from the shape; the
    # previous version allocated a zero tensor of that size on every call
    # just to read its numel().
    out_shape = y.size()
    output_size = math.prod(out_shape[:1]) * math.prod(out_shape[2:])
    kernel_ops = m.weight.numel() + (m.bias.numel() if m.bias is not None else 0)
    m.total_ops += torch.DoubleTensor([output_size * kernel_ops])
def count_bn(m, x, y):
    """Count operations for a batch-normalization layer.

    Charges two ops per input element for the normalisation. Batch-statistic
    computation that happens in training mode is not counted separately.
    """
    x = x[0]
    nelements = x.numel()
    # Previously total_ops was only assigned when `not m.training`, so the hook
    # raised UnboundLocalError on any module in training mode. The per-element
    # normalisation cost is incurred in both modes, so charge it unconditionally.
    total_ops = 2 * nelements
    m.total_ops += torch.DoubleTensor([total_ops])
def count_relu(m, x, y):
    """Count operations for a ReLU activation: one op per input element."""
    first_input = x[0]
    m.total_ops += torch.DoubleTensor([first_input.numel()])
def count_softmax(m, x, y):
    """Count operations for a softmax layer.

    For every slice along the softmax dimension of length ``n``, charges
    ``2 * n - 1`` ops. The total is ``(numel / n) * (2 * n - 1)``, which for a
    2-D ``(batch, features)`` input with ``dim=1`` is
    ``batch * (2 * features - 1)``.

    The previous version unpacked ``x.size()`` into exactly two values, so it
    crashed on any non-2-D input, for example attention scores or 4-D maps.
    """
    x = x[0]
    if x.dim() == 0:
        nfeatures = 1
    else:
        dim = getattr(m, "dim", None)
        if dim is None:
            # Modules without an explicit dim: use the same implicit-dim
            # choice as torch.nn.functional.softmax (0 for 0/1/3-D, else 1).
            dim = 0 if x.dim() in (0, 1, 3) else 1
        nfeatures = x.size(dim)

    if nfeatures == 0:
        # Empty softmax dimension means an empty tensor: nothing to compute.
        total_ops = 0
    else:
        batch_size = x.numel() // nfeatures
        total_ops = batch_size * (2 * nfeatures - 1)
    m.total_ops += torch.DoubleTensor([total_ops])
def count_avgpool(m, x, y):
    """Count operations for an average-pooling layer.

    Charges a single op per output element of ``y``. ``x`` is not used.
    """
    ops_per_output = 1
    m.total_ops += torch.DoubleTensor([ops_per_output * y.numel()])
def count_adap_avgpool(m, x, y):
    """Count operations for an adaptive average-pooling layer.

    The pooling window along each pooled dim is approximated as
    ``input_size // output_size``. Each output element is charged
    ``prod(window) + 1`` ops.

    ``None`` entries in ``m.output_size`` mean "keep the input size" for that
    dim. They are resolved to the input size here; the previous version
    passed them to ``torch.DoubleTensor``, which failed.
    """
    inp = x[0]
    out_size = m.output_size
    if isinstance(out_size, (tuple, list)):
        # Pool over the trailing len(output_size) dims. For batched input
        # this is exactly shape[2:].
        in_size = inp.shape[inp.dim() - len(out_size):]
    else:
        # A single int applies to every dim after the leading two.
        in_size = inp.shape[2:]
        out_size = (out_size,) * len(in_size)

    kernel_ops = 1
    for in_dim, out_dim in zip(in_size, out_size):
        if out_dim is None:
            out_dim = in_dim
        # An output dim of 0 yields an empty y, so the window size is irrelevant.
        kernel_ops *= (in_dim // out_dim) if out_dim else 0
    kernel_ops += 1

    num_elements = y.numel()
    m.total_ops += torch.DoubleTensor([kernel_ops * num_elements])
def count_upsample(m, x, y):
    """Count operations for an upsampling layer.

    ``"nearest"`` mode is treated as free. The interpolating modes listed in
    the table below are charged a fixed cost per output element of ``y``.
    Any other mode logs a warning and is counted as zero ops.

    The previous version omitted ``"trilinear"`` from its supported-mode
    check, so the trilinear cost entry was unreachable. The table is now the
    single source of supported modes.
    """
    # Ops charged per output element for each interpolating mode.
    ops_per_element = {
        "linear": 5,
        "bilinear": 11,
        "bicubic": 259,  # 224 + 35
        # NOTE(review): described as "2 * bilinear + 1 * linear", but
        # 2 * 11 + 5 = 27, not 31 — verify the intended cost.
        "trilinear": 31,
    }

    if m.mode == "nearest":
        return zero_ops(m, x, y)
    if m.mode not in ops_per_element:
        logging.warning("Mode %s is not implemented yet, assuming zero ops", m.mode)
        return zero_ops(m, x, y)

    total_ops = ops_per_element[m.mode] * y.nelement()
    m.total_ops += torch.DoubleTensor([total_ops])
def count_linear(m, x, y):
    """Count operations for a fully connected layer.

    Every output element of ``y`` is charged ``m.in_features`` ops. The bias
    is not counted and ``x`` is not used.
    """
    outputs = y.numel()
    m.total_ops += torch.DoubleTensor([outputs * m.in_features])