# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import argparse
import logging

import torch
import torch.nn as nn
from torch.nn.modules.conv import _ConvNd

multiply_adds = 1

# NOTE(review): every counter below has the signature (m, x, y) and
# accumulates into ``m.total_ops`` (or writes ``m.total_params``). They look
# like forward hooks (module, tuple of inputs, output) whose accumulator
# tensors are created by the caller -- confirm against the registering code.


def count_parameters(m, x, y):
    """Record the number of parameters of module ``m``.

    Sums ``numel()`` over ``m.parameters()`` and stores it in
    ``m.total_params[0]`` as a double tensor. ``x`` and ``y`` are unused.
    """
    total_params = sum(p.numel() for p in m.parameters())
    m.total_params[0] = torch.DoubleTensor([total_params])


def zero_ops(m, x, y):
    """Count zero operations for ``m`` (adds 0 to ``m.total_ops``)."""
    m.total_ops += torch.DoubleTensor([0])


def count_convNd(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
    """Count operations for an N-d convolution layer.

    Each output element is charged ``in_channels // groups * prod(kernel)``
    ops, plus one extra op when the layer has a bias.
    """
    # weight[0][0] drops the two channel dims, leaving the kernel (Kw x Kh ...).
    kernel_ops = m.weight[0][0].numel()
    bias_ops = 1 if m.bias is not None else 0
    total_ops = y.nelement() * (m.in_channels // m.groups * kernel_ops + bias_ops)
    m.total_ops += torch.DoubleTensor([total_ops])


def count_convNd_ver2(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
    """Alternative convolution op count.

    Charges ``(weight elements + bias elements)`` ops for every output
    position, where positions are all output dims except dim 1 (channels).
    """
    # Product of the output shape without the channel dim; torch.Size.numel()
    # computes it without allocating a tensor of that shape.
    output_size = torch.Size(y.shape[:1] + y.shape[2:]).numel()
    kernel_ops = m.weight.numel() + (m.bias.numel() if m.bias is not None else 0)
    m.total_ops += torch.DoubleTensor([output_size * kernel_ops])


def count_bn(m, x, y):
    """Count operations for a batch-normalization layer.

    In eval mode (``not m.training``) two ops are charged per input element;
    in training mode nothing is charged.
    """
    nelements = x[0].numel()
    # Training mode contributes 0 so ``total_ops`` is always defined.
    total_ops = 0 if m.training else 2 * nelements
    m.total_ops += torch.DoubleTensor([total_ops])


def count_relu(m, x, y):
    """Count one operation per input element for a ReLU activation."""
    nelements = x[0].numel()
    m.total_ops += torch.DoubleTensor([nelements])


def count_softmax(m, x, y):
    """Count operations for a softmax layer.

    Every slice of length ``n`` along the softmax dimension is charged
    ``2 * n - 1`` ops. The dimension is read from ``m.dim``; when it is
    ``None`` PyTorch's implicit-dim rule is mirrored (dim 0 for 0-, 1- and
    3-D inputs, dim 1 otherwise). For a 2-D ``(batch, features)`` input this
    gives ``batch * (2 * features - 1)``.
    """
    x = x[0]
    dim = getattr(m, "dim", None)
    if dim is None:
        dim = 0 if x.dim() in (0, 1, 3) else 1
    nfeatures = x.size(dim) if x.dim() > 0 else 1
    # Number of independent slices normalized along ``dim``.
    batch_size = x.numel() // nfeatures if nfeatures else 0
    # NOTE(review): 2 * n - 1 ops per slice; confirm which of exp / add / div
    # this is intended to count.
    total_ops = batch_size * (2 * nfeatures - 1)
    m.total_ops += torch.DoubleTensor([total_ops])


def count_avgpool(m, x, y):
    """Count one operation per output element for an average-pooling layer."""
    num_elements = y.numel()
    m.total_ops += torch.DoubleTensor([num_elements])


def count_adap_avgpool(m, x, y):
    """Count operations for an adaptive average-pooling layer.

    The effective kernel along each spatial dim is ``in_size // out_size``.
    Each output element is charged ``prod(kernel) + 1`` ops.
    ``m.output_size`` may be an int, or a sequence whose ``None`` entries
    mean "same as the input size".
    """
    in_spatial = x[0].shape[2:]
    out_size = m.output_size
    if out_size is None or isinstance(out_size, int):
        out_size = (out_size,) * len(in_spatial)

    kernel_volume = 1
    for in_dim, out_dim in zip(in_spatial, out_size):
        if out_dim is None:
            out_dim = in_dim
        kernel_volume *= in_dim // out_dim

    kernel_ops = kernel_volume + 1
    m.total_ops += torch.DoubleTensor([kernel_ops * y.numel()])


# Ops charged per output element for each counted interpolation mode.
_UPSAMPLE_OPS_PER_ELEMENT = {
    "linear": 5,
    "bilinear": 11,
    "bicubic": 259,  # 224 + 35
    # NOTE(review): 31 == 2 * 13 + 5, which differs from
    # 2 * bilinear (11) + linear (5) == 27; confirm the intended cost.
    "trilinear": 31,
}


def count_upsample(m, x, y):
    """Count operations for an upsample layer.

    ``nearest`` costs nothing. ``linear``, ``bilinear``, ``bicubic`` and
    ``trilinear`` are charged a fixed number of ops per output element.
    Any other mode logs a warning and is counted as zero ops.
    """
    if m.mode == "nearest":
        return zero_ops(m, x, y)

    ops_per_element = _UPSAMPLE_OPS_PER_ELEMENT.get(m.mode)
    if ops_per_element is None:
        logging.warning("Mode %s is not implemented yet, assuming zero ops", m.mode)
        return zero_ops(m, x, y)

    m.total_ops += torch.DoubleTensor([ops_per_element * y.nelement()])


def count_linear(m, x, y):
    """Count ``in_features`` operations per output element of a linear layer.

    The bias term is not counted.
    """
    total_ops = m.in_features * y.numel()
    m.total_ops += torch.DoubleTensor([total_ops])