# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch
import torch.nn as nn


def _count_rnn_cell(input_size, hidden_size, bias=True):
    """Calculate the total operations for a single RNN cell.

    Args:
        input_size (int): Size of the input.
        hidden_size (int): Size of the hidden state.
        bias (bool, optional): Whether the RNN cell uses bias. Defaults to True.

    Returns:
        int: Total number of operations for the RNN cell.
    """
    # h' = tanh(W_ih @ x + b_ih + W_hh @ h + b_hh)
    ops = hidden_size * (input_size + hidden_size) + hidden_size
    if bias:
        ops += hidden_size * 2
    return ops


def count_rnn_cell(cell: nn.RNNCell, x: tuple):
    """Count operations for the RNNCell over a batch of input.

    Args:
        cell (nn.RNNCell): The RNNCell to count operations for.
        x (tuple): Inputs received by the module's forward hook; x[0] is the
            input tensor of shape (batch_size, input_size).
    """
    ops = _count_rnn_cell(cell.input_size, cell.hidden_size, cell.bias)
    batch_size = x[0].size(0)
    total_ops = ops * batch_size
    cell.total_ops += torch.DoubleTensor([int(total_ops)])


def _count_gru_cell(input_size, hidden_size, bias=True):
    """Calculate the total operations for a single GRU cell.

    Args:
        input_size (int): Size of the input.
        hidden_size (int): Size of the hidden state.
        bias (bool, optional): Whether the GRU cell uses bias. Defaults to True.

    Returns:
        int: Total number of operations for the GRU cell.
    """
    ops = (hidden_size + input_size) * hidden_size + hidden_size
    if bias:
        ops += hidden_size * 2
    ops *= 2  # For reset and update gates

    # Calculate new gate
    ops += (hidden_size + input_size) * hidden_size + hidden_size
    if bias:
        ops += hidden_size * 2
    ops += hidden_size  # Hadamard product

    ops += hidden_size * 3  # Final output
    return ops


def count_gru_cell(cell: nn.GRUCell, x: tuple):
    """Count operations for the GRUCell over a batch of input.

    Args:
        cell (nn.GRUCell): The GRUCell to count operations for.
        x (tuple): Inputs received by the module's forward hook; x[0] is the
            input tensor of shape (batch_size, input_size).
    """
    ops = _count_gru_cell(cell.input_size, cell.hidden_size, cell.bias)
    batch_size = x[0].size(0)
    total_ops = ops * batch_size
    cell.total_ops += torch.DoubleTensor([int(total_ops)])


def _count_lstm_cell(input_size, hidden_size, bias=True):
    """Calculate the total operations for a single LSTM cell.

    Args:
        input_size (int): Size of the input.
        hidden_size (int): Size of the hidden state.
        bias (bool, optional): Whether the LSTM cell uses bias. Defaults to True.

    Returns:
        int: Total number of operations for the LSTM cell.
    """
    ops = (input_size + hidden_size) * hidden_size + hidden_size
    if bias:
        ops += hidden_size * 2
    ops *= 4  # For input, forget, output, and cell gates

    ops += hidden_size * 3  # Cell state update
    ops += hidden_size  # Final output
    return ops


def count_lstm_cell(cell: nn.LSTMCell, x: tuple):
    """Count operations for the LSTMCell over a batch of input.

    Args:
        cell (nn.LSTMCell): The LSTMCell to count operations for.
        x (tuple): Inputs received by the module's forward hook; x[0] is the
            input tensor of shape (batch_size, input_size).
    """
    ops = _count_lstm_cell(cell.input_size, cell.hidden_size, cell.bias)
    batch_size = x[0].size(0)
    total_ops = ops * batch_size
    cell.total_ops += torch.DoubleTensor([int(total_ops)])


def _count_rnn_layers(model: nn.RNN, num_layers, input_size, hidden_size):
    """Calculate the total operations for RNN layers.

    Args:
        model (nn.RNN): The RNN model.
        num_layers (int): Number of layers in the RNN.
        input_size (int): Size of the input.
        hidden_size (int): Size of the hidden state.

    Returns:
        int: Total number of operations for the RNN layers.
    """
    ops = _count_rnn_cell(input_size, hidden_size, model.bias)
    # Deeper layers receive the (possibly bidirectional) hidden state as input.
    for _ in range(num_layers - 1):
        ops += _count_rnn_cell(hidden_size * (2 if model.bidirectional else 1),
                               hidden_size, model.bias)
    return ops


def count_rnn(model: nn.RNN, x: tuple):
    """Count operations for the entire RNN over a batch of input.

    Args:
        model (nn.RNN): The RNN model.
        x (tuple): Inputs received by the module's forward hook; x[0] is the
            input tensor of shape (batch, seq, feature) when batch_first is
            True, otherwise (seq, batch, feature).
    """
    batch_size = x[0].size(0) if model.batch_first else x[0].size(1)
    num_steps = x[0].size(1) if model.batch_first else x[0].size(0)
    ops = _count_rnn_layers(model, model.num_layers, model.input_size, model.hidden_size)
    total_ops = ops * num_steps * batch_size
    model.total_ops += torch.DoubleTensor([int(total_ops)])


def _count_gru_layers(model: nn.GRU, num_layers, input_size, hidden_size):
    """Calculate the total operations for GRU layers.

    Args:
        model (nn.GRU): The GRU model.
        num_layers (int): Number of layers in the GRU.
        input_size (int): Size of the input.
        hidden_size (int): Size of the hidden state.

    Returns:
        int: Total number of operations for the GRU layers.
    """
    ops = _count_gru_cell(input_size, hidden_size, model.bias)
    # Deeper layers receive the (possibly bidirectional) hidden state as input.
    for _ in range(num_layers - 1):
        ops += _count_gru_cell(hidden_size * (2 if model.bidirectional else 1),
                               hidden_size, model.bias)
    return ops


def count_gru(model: nn.GRU, x: tuple):
    """Count operations for the entire GRU over a batch of input.

    Args:
        model (nn.GRU): The GRU model.
        x (tuple): Inputs received by the module's forward hook; x[0] is the
            input tensor of shape (batch, seq, feature) when batch_first is
            True, otherwise (seq, batch, feature).
    """
    batch_size = x[0].size(0) if model.batch_first else x[0].size(1)
    num_steps = x[0].size(1) if model.batch_first else x[0].size(0)
    ops = _count_gru_layers(model, model.num_layers, model.input_size, model.hidden_size)
    total_ops = ops * num_steps * batch_size
    model.total_ops += torch.DoubleTensor([int(total_ops)])


def _count_lstm_layers(model: nn.LSTM, num_layers, input_size, hidden_size):
    """Calculate the total operations for LSTM layers.

    Args:
        model (nn.LSTM): The LSTM model.
        num_layers (int): Number of layers in the LSTM.
        input_size (int): Size of the input.
        hidden_size (int): Size of the hidden state.

    Returns:
        int: Total number of operations for the LSTM layers.
    """
    ops = _count_lstm_cell(input_size, hidden_size, model.bias)
    # Deeper layers receive the (possibly bidirectional) hidden state as input.
    for _ in range(num_layers - 1):
        ops += _count_lstm_cell(hidden_size * (2 if model.bidirectional else 1),
                                hidden_size, model.bias)
    return ops


def count_lstm(model: nn.LSTM, x: tuple):
    """Count operations for the entire LSTM over a batch of input.

    Args:
        model (nn.LSTM): The LSTM model.
        x (tuple): Inputs received by the module's forward hook; x[0] is the
            input tensor of shape (batch, seq, feature) when batch_first is
            True, otherwise (seq, batch, feature).
    """
    batch_size = x[0].size(0) if model.batch_first else x[0].size(1)
    num_steps = x[0].size(1) if model.batch_first else x[0].size(0)
    ops = _count_lstm_layers(model, model.num_layers, model.input_size, model.hidden_size)
    total_ops = ops * num_steps * batch_size
    model.total_ops += torch.DoubleTensor([int(total_ops)])