# -*- coding: utf-8 -*-
# @Author: Weisen Pan

__all__ = ['model_summary']

import torch
import torch.nn as nn
import numpy as np
import os
import json
from collections import OrderedDict


# Format FLOPs value with appropriate unit (T, G, M, K)
def format_flops(flops):
    units = [(1e12, 'T'), (1e9, 'G'), (1e6, 'M'), (1e3, 'K')]
    for scale, suffix in units:
        if flops >= scale:
            return f"{flops / scale:.1f}{suffix}"
    return f"{flops:.1f}"


# Calculate the number of trainable or non-trainable parameters
def calculate_grad_params(param_count, param):
    if param.requires_grad:
        return param_count, 0
    else:
        return 0, param_count


# Compute FLOPs and parameters for a convolutional layer
def compute_conv_flops(layer, input, output):
    oh, ow = output.shape[-2:]                      # Output height and width
    kh, kw = layer.kernel_size                      # Kernel height and width
    ic, oc = layer.in_channels, layer.out_channels  # Input/output channels
    groups = layer.groups                           # Number of groups for grouped convolution

    total_trainable = 0
    total_non_trainable = 0
    flops = 0

    # Compute parameters and FLOPs for the weight
    if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'):
        param_count = np.prod(layer.weight.shape)
        trainable, non_trainable = calculate_grad_params(param_count, layer.weight)
        total_trainable += trainable
        total_non_trainable += non_trainable
        flops += (2 * ic * kh * kw - 1) * oh * ow * (oc // groups)

    # Compute parameters and FLOPs for the bias
    if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'):
        param_count = np.prod(layer.bias.shape)
        trainable, non_trainable = calculate_grad_params(param_count, layer.bias)
        total_trainable += trainable
        total_non_trainable += non_trainable
        flops += oh * ow * (oc // groups)

    return total_trainable, total_non_trainable, flops


# Compute FLOPs and parameters for normalization layers (BatchNorm, GroupNorm)
def compute_norm_flops(layer, input, output):
    total_trainable = 0
    total_non_trainable = 0

    if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'):
        param_count = np.prod(layer.weight.shape)
        trainable, non_trainable = calculate_grad_params(param_count, layer.weight)
        total_trainable += trainable
        total_non_trainable += non_trainable
    if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'):
        param_count = np.prod(layer.bias.shape)
        trainable, non_trainable = calculate_grad_params(param_count, layer.bias)
        total_trainable += trainable
        total_non_trainable += non_trainable
    if hasattr(layer, 'running_mean'):
        total_non_trainable += np.prod(layer.running_mean.shape)
    if hasattr(layer, 'running_var'):
        total_non_trainable += np.prod(layer.running_var.shape)

    # FLOPs for normalization operations
    flops = np.prod(input[0].shape)
    if layer.affine:
        flops *= 2

    return total_trainable, total_non_trainable, flops


# Compute FLOPs and parameters for linear (fully connected) layers
def compute_linear_flops(layer, input, output):
    ic, oc = layer.in_features, layer.out_features  # Input/output features

    total_trainable = 0
    total_non_trainable = 0
    flops = 0

    # Compute parameters and FLOPs for the weight
    if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'):
        param_count = np.prod(layer.weight.shape)
        trainable, non_trainable = calculate_grad_params(param_count, layer.weight)
        total_trainable += trainable
        total_non_trainable += non_trainable
        flops += (2 * ic - 1) * oc

    # Compute parameters and FLOPs for the bias
    if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'):
        param_count = np.prod(layer.bias.shape)
        trainable, non_trainable = calculate_grad_params(param_count, layer.bias)
        total_trainable += trainable
        total_non_trainable += non_trainable
        flops += oc

    return total_trainable, total_non_trainable, flops
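
# The three compute_*_flops helpers above all follow the same contract:
# (layer, input, output) -> (trainable_params, non_trainable_params, flops).
# As an illustrative sketch only (not part of the original module), a handler
# for average pooling could follow the same pattern; the name
# `compute_pool_flops` and the "kh * kw FLOPs per output element" convention
# are assumptions of this example.
def compute_pool_flops(layer, input, output):
    kh, kw = (layer.kernel_size, layer.kernel_size) if isinstance(layer.kernel_size, int) else layer.kernel_size
    # Pooling has no learnable parameters; each output element reduces kh * kw inputs.
    flops = int(np.prod(output.shape)) * kh * kw
    return 0, 0, flops
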

# Model summary function: calculates the total parameters and FLOPs for a model
@torch.no_grad()
def model_summary(model, input_data, target_data=None, is_coremodel=True, return_data=False):
    model.eval()
    summary_info = OrderedDict()
    hooks = []

    # Hook function to register each leaf layer and compute its parameters/FLOPs
    def register_layer_hook(layer):
        def hook(layer, input, output):
            layer_name = f"{layer.__class__.__name__}-{len(summary_info) + 1}"
            summary_info[layer_name] = OrderedDict()
            summary_info[layer_name]['input_shape'] = list(input[0].shape)
            summary_info[layer_name]['output_shape'] = (
                list(output.shape) if not isinstance(output, (list, tuple))
                else [list(o.shape) for o in output]
            )
            if isinstance(layer, nn.Conv2d):
                trainable, non_trainable, flops = compute_conv_flops(layer, input, output)
            elif isinstance(layer, (nn.BatchNorm2d, nn.GroupNorm)):
                trainable, non_trainable, flops = compute_norm_flops(layer, input, output)
            elif isinstance(layer, nn.Linear):
                trainable, non_trainable, flops = compute_linear_flops(layer, input, output)
            else:
                trainable, non_trainable, flops = 0, 0, 0
            summary_info[layer_name]['trainable_params'] = trainable
            summary_info[layer_name]['non_trainable_params'] = non_trainable
            summary_info[layer_name]['total_params'] = trainable + non_trainable
            summary_info[layer_name]['flops'] = flops

        if not isinstance(layer, (nn.Sequential, nn.ModuleList, nn.Identity)):
            hooks.append(layer.register_forward_hook(hook))

    model.apply(register_layer_hook)

    # Run one forward pass so the hooks fire, then remove them
    if is_coremodel:
        model(input_data, target=target_data, mode='summary')
    else:
        model(input_data)
    for hook in hooks:
        hook.remove()

    total_params, trainable_params, total_flops = 0, 0, 0
    for layer_name, layer_info in summary_info.items():
        total_params += layer_info['total_params']
        trainable_params += layer_info['trainable_params']
        total_flops += layer_info['flops']

    # Model size assumes 4 bytes (float32) per parameter
    param_size_mb = total_params * 4 / (1024 ** 2)
    print(f"Total parameters: {total_params:,} ({format_flops(total_params)})")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Non-trainable parameters: {total_params - trainable_params:,}")
    print(f"Total FLOPs: {total_flops:,} ({format_flops(total_flops)})")
    print(f"Model size: {param_size_mb:.2f} MB")

    if return_data:
        return total_params, total_flops


# Example usage with a grouped convolutional layer
if __name__ == '__main__':
    conv_layer = nn.Conv2d(50, 10, 3, padding=1, groups=5, bias=True)
    model_summary(conv_layer, torch.rand((1, 50, 10, 10)),
                  target_data=torch.ones(1, dtype=torch.long), is_coremodel=False)
    for name, param in conv_layer.named_parameters():
        print(f"{name}: {param.size()}")


# Save the model's details to disk (note: despite the name, this writes
# str(model_content) to a .txt file rather than JSON)
def save_model_as_json(args, model_content):
    """Save the model's textual description to a file."""
    os.makedirs(args.model_dir, exist_ok=True)
    filename = os.path.join(args.model_dir, f"model_{args.split_factor}.txt")
    with open(filename, 'w') as f:
        f.write(str(model_content))
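
# Minimal sketch of true JSON output (assumption): `json` is imported above but
# `save_model_as_json` writes plain text. If an actual JSON dump of the
# per-layer summary were wanted, a helper like the one below would do; the name
# `save_summary_as_json` and the `summary` argument (an OrderedDict shaped like
# the one built inside model_summary) are illustrative, not part of the
# original API.
def save_summary_as_json(path, summary):
    """Dump a per-layer summary dictionary to a JSON file."""
    with open(path, 'w') as f:
        # default=int converts the numpy integers produced by np.prod()
        json.dump(summary, f, indent=2, default=int)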