196 lines
6.8 KiB
Python
196 lines
6.8 KiB
Python
# -*- coding: utf-8 -*-
|
|
# @Author: Weisen Pan
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
def _count_rnn_cell(input_size, hidden_size, bias=True):
|
|
"""Calculate the total operations for a single RNN cell.
|
|
|
|
Args:
|
|
input_size (int): Size of the input.
|
|
hidden_size (int): Size of the hidden state.
|
|
bias (bool, optional): Whether the RNN cell uses bias. Defaults to True.
|
|
|
|
Returns:
|
|
int: Total number of operations for the RNN cell.
|
|
"""
|
|
ops = hidden_size * (input_size + hidden_size) + hidden_size
|
|
if bias:
|
|
ops += hidden_size * 2
|
|
return ops
|
|
|
|
def count_rnn_cell(cell: nn.RNNCell, x, y=None):
    """Add the operation count of one ``nn.RNNCell`` call to ``cell.total_ops``.

    Args:
        cell (nn.RNNCell): Cell being profiled. It must already have a
            ``total_ops`` tensor attribute, which is updated in place.
        x: Positional inputs of the call; ``x[0]`` is the input tensor,
            either batched ``(batch, input_size)`` or unbatched
            ``(input_size,)``.
        y: Output of the call. Unused; accepted so the function also matches
            the ``hook(module, args, output)`` forward-hook signature.
    """
    per_sample_ops = _count_rnn_cell(cell.input_size, cell.hidden_size, cell.bias)

    inputs = x[0]
    # For an unbatched (1-D) input, dim 0 is the feature size, not the batch
    # size, so it counts as a single sample.
    batch_size = inputs.size(0) if inputs.dim() > 1 else 1

    total_ops = per_sample_ops * batch_size
    cell.total_ops += torch.DoubleTensor([int(total_ops)])
|
|
|
|
def _count_gru_cell(input_size, hidden_size, bias=True):
|
|
"""Calculate the total operations for a single GRU cell.
|
|
|
|
Args:
|
|
input_size (int): Size of the input.
|
|
hidden_size (int): Size of the hidden state.
|
|
bias (bool, optional): Whether the GRU cell uses bias. Defaults to True.
|
|
|
|
Returns:
|
|
int: Total number of operations for the GRU cell.
|
|
"""
|
|
ops = (hidden_size + input_size) * hidden_size + hidden_size
|
|
if bias:
|
|
ops += hidden_size * 2
|
|
ops *= 2 # For reset and update gates
|
|
|
|
ops += (hidden_size + input_size) * hidden_size + hidden_size # Calculate new gate
|
|
if bias:
|
|
ops += hidden_size * 2
|
|
ops += hidden_size # Hadamard product
|
|
ops += hidden_size * 3 # Final output
|
|
|
|
return ops
|
|
|
|
def count_gru_cell(cell: nn.GRUCell, x, y=None):
    """Add the operation count of one ``nn.GRUCell`` call to ``cell.total_ops``.

    Args:
        cell (nn.GRUCell): Cell being profiled. It must already have a
            ``total_ops`` tensor attribute, which is updated in place.
        x: Positional inputs of the call; ``x[0]`` is the input tensor,
            either batched ``(batch, input_size)`` or unbatched
            ``(input_size,)``.
        y: Output of the call. Unused; accepted so the function also matches
            the ``hook(module, args, output)`` forward-hook signature.
    """
    per_sample_ops = _count_gru_cell(cell.input_size, cell.hidden_size, cell.bias)

    inputs = x[0]
    # For an unbatched (1-D) input, dim 0 is the feature size, not the batch
    # size, so it counts as a single sample.
    batch_size = inputs.size(0) if inputs.dim() > 1 else 1

    total_ops = per_sample_ops * batch_size
    cell.total_ops += torch.DoubleTensor([int(total_ops)])
|
|
|
|
def _count_lstm_cell(input_size, hidden_size, bias=True):
|
|
"""Calculate the total operations for a single LSTM cell.
|
|
|
|
Args:
|
|
input_size (int): Size of the input.
|
|
hidden_size (int): Size of the hidden state.
|
|
bias (bool, optional): Whether the LSTM cell uses bias. Defaults to True.
|
|
|
|
Returns:
|
|
int: Total number of operations for the LSTM cell.
|
|
"""
|
|
ops = (input_size + hidden_size) * hidden_size + hidden_size
|
|
if bias:
|
|
ops += hidden_size * 2
|
|
ops *= 4 # For input, forget, output, and cell gates
|
|
|
|
ops += hidden_size * 3 # Cell state update
|
|
ops += hidden_size # Final output
|
|
|
|
return ops
|
|
|
|
def count_lstm_cell(cell: nn.LSTMCell, x, y=None):
    """Add the operation count of one ``nn.LSTMCell`` call to ``cell.total_ops``.

    Args:
        cell (nn.LSTMCell): Cell being profiled. It must already have a
            ``total_ops`` tensor attribute, which is updated in place.
        x: Positional inputs of the call; ``x[0]`` is the input tensor,
            either batched ``(batch, input_size)`` or unbatched
            ``(input_size,)``.
        y: Output of the call. Unused; accepted so the function also matches
            the ``hook(module, args, output)`` forward-hook signature.
    """
    per_sample_ops = _count_lstm_cell(cell.input_size, cell.hidden_size, cell.bias)

    inputs = x[0]
    # For an unbatched (1-D) input, dim 0 is the feature size, not the batch
    # size, so it counts as a single sample.
    batch_size = inputs.size(0) if inputs.dim() > 1 else 1

    total_ops = per_sample_ops * batch_size
    cell.total_ops += torch.DoubleTensor([int(total_ops)])
|
|
|
|
def _count_rnn_layers(model: nn.RNN, num_layers, input_size, hidden_size):
    """Return the per-step, per-sample operation count of all RNN layers.

    Args:
        model (nn.RNN): Module whose ``bias`` and ``bidirectional`` settings
            are read.
        num_layers (int): Number of stacked layers.
        input_size (int): Feature size fed to the first layer.
        hidden_size (int): Hidden feature size of every layer.

    Returns:
        int: Operations for one time step of one sample across all layers
        and directions.
    """
    # A bidirectional model runs a forward and a reverse cell in every layer,
    # so each layer's cost is counted once per direction.
    num_directions = 2 if model.bidirectional else 1

    first_layer_ops = _count_rnn_cell(input_size, hidden_size, model.bias)

    # Layers after the first consume the concatenated outputs of all
    # directions of the previous layer.
    stacked_input_size = hidden_size * num_directions
    stacked_layer_ops = _count_rnn_cell(stacked_input_size, hidden_size, model.bias)

    extra_layers = max(num_layers - 1, 0)
    return num_directions * (first_layer_ops + extra_layers * stacked_layer_ops)
|
|
|
|
def count_rnn(model: nn.RNN, x, y=None):
    """Add the operation count of one ``nn.RNN`` forward call to ``model.total_ops``.

    Args:
        model (nn.RNN): Module being profiled. It must already have a
            ``total_ops`` tensor attribute, which is updated in place.
        x: Positional inputs of the call; ``x[0]`` is the input sequence: a
            batched 3-D tensor laid out per ``model.batch_first``, an
            unbatched 2-D tensor ``(seq_len, input_size)``, or a
            ``PackedSequence``.
        y: Output of the call. Unused; accepted so the function also matches
            the ``hook(module, args, output)`` forward-hook signature.
    """
    inputs = x[0]
    if isinstance(inputs, nn.utils.rnn.PackedSequence):
        # batch_sizes[t] is the number of sequences present at step t, so the
        # sum is the number of (step, sample) cell evaluations.
        cell_evals = int(inputs.batch_sizes.sum())
    elif inputs.dim() == 2:
        # Unbatched input: one sample, one evaluation per time step.
        cell_evals = inputs.size(0)
    else:
        batch_dim, step_dim = (0, 1) if model.batch_first else (1, 0)
        cell_evals = inputs.size(step_dim) * inputs.size(batch_dim)

    ops = _count_rnn_layers(model, model.num_layers, model.input_size, model.hidden_size)
    total_ops = ops * cell_evals
    model.total_ops += torch.DoubleTensor([int(total_ops)])
|
|
|
|
def _count_gru_layers(model: nn.GRU, num_layers, input_size, hidden_size):
    """Return the per-step, per-sample operation count of all GRU layers.

    Args:
        model (nn.GRU): Module whose ``bias`` and ``bidirectional`` settings
            are read.
        num_layers (int): Number of stacked layers.
        input_size (int): Feature size fed to the first layer.
        hidden_size (int): Hidden feature size of every layer.

    Returns:
        int: Operations for one time step of one sample across all layers
        and directions.
    """
    # A bidirectional model runs a forward and a reverse cell in every layer,
    # so each layer's cost is counted once per direction.
    num_directions = 2 if model.bidirectional else 1

    first_layer_ops = _count_gru_cell(input_size, hidden_size, model.bias)

    # Layers after the first consume the concatenated outputs of all
    # directions of the previous layer.
    stacked_input_size = hidden_size * num_directions
    stacked_layer_ops = _count_gru_cell(stacked_input_size, hidden_size, model.bias)

    extra_layers = max(num_layers - 1, 0)
    return num_directions * (first_layer_ops + extra_layers * stacked_layer_ops)
|
|
|
|
def count_gru(model: nn.GRU, x, y=None):
    """Add the operation count of one ``nn.GRU`` forward call to ``model.total_ops``.

    Args:
        model (nn.GRU): Module being profiled. It must already have a
            ``total_ops`` tensor attribute, which is updated in place.
        x: Positional inputs of the call; ``x[0]`` is the input sequence: a
            batched 3-D tensor laid out per ``model.batch_first``, an
            unbatched 2-D tensor ``(seq_len, input_size)``, or a
            ``PackedSequence``.
        y: Output of the call. Unused; accepted so the function also matches
            the ``hook(module, args, output)`` forward-hook signature.
    """
    inputs = x[0]
    if isinstance(inputs, nn.utils.rnn.PackedSequence):
        # batch_sizes[t] is the number of sequences present at step t, so the
        # sum is the number of (step, sample) cell evaluations.
        cell_evals = int(inputs.batch_sizes.sum())
    elif inputs.dim() == 2:
        # Unbatched input: one sample, one evaluation per time step.
        cell_evals = inputs.size(0)
    else:
        batch_dim, step_dim = (0, 1) if model.batch_first else (1, 0)
        cell_evals = inputs.size(step_dim) * inputs.size(batch_dim)

    ops = _count_gru_layers(model, model.num_layers, model.input_size, model.hidden_size)
    total_ops = ops * cell_evals
    model.total_ops += torch.DoubleTensor([int(total_ops)])
|
|
|
|
def _count_lstm_layers(model: nn.LSTM, num_layers, input_size, hidden_size):
    """Return the per-step, per-sample operation count of all LSTM layers.

    Args:
        model (nn.LSTM): Module whose ``bias`` and ``bidirectional`` settings
            are read.
        num_layers (int): Number of stacked layers.
        input_size (int): Feature size fed to the first layer.
        hidden_size (int): Hidden feature size of every layer.

    Returns:
        int: Operations for one time step of one sample across all layers
        and directions.
    """
    # A bidirectional model runs a forward and a reverse cell in every layer,
    # so each layer's cost is counted once per direction.
    num_directions = 2 if model.bidirectional else 1

    first_layer_ops = _count_lstm_cell(input_size, hidden_size, model.bias)

    # Layers after the first consume the concatenated outputs of all
    # directions of the previous layer.
    stacked_input_size = hidden_size * num_directions
    stacked_layer_ops = _count_lstm_cell(stacked_input_size, hidden_size, model.bias)

    extra_layers = max(num_layers - 1, 0)
    return num_directions * (first_layer_ops + extra_layers * stacked_layer_ops)
|
|
|
|
def count_lstm(model: nn.LSTM, x, y=None):
    """Add the operation count of one ``nn.LSTM`` forward call to ``model.total_ops``.

    Args:
        model (nn.LSTM): Module being profiled. It must already have a
            ``total_ops`` tensor attribute, which is updated in place.
        x: Positional inputs of the call; ``x[0]`` is the input sequence: a
            batched 3-D tensor laid out per ``model.batch_first``, an
            unbatched 2-D tensor ``(seq_len, input_size)``, or a
            ``PackedSequence``.
        y: Output of the call. Unused; accepted so the function also matches
            the ``hook(module, args, output)`` forward-hook signature.
    """
    inputs = x[0]
    if isinstance(inputs, nn.utils.rnn.PackedSequence):
        # batch_sizes[t] is the number of sequences present at step t, so the
        # sum is the number of (step, sample) cell evaluations.
        cell_evals = int(inputs.batch_sizes.sum())
    elif inputs.dim() == 2:
        # Unbatched input: one sample, one evaluation per time step.
        cell_evals = inputs.size(0)
    else:
        batch_dim, step_dim = (0, 1) if model.batch_first else (1, 0)
        cell_evals = inputs.size(step_dim) * inputs.size(batch_dim)

    ops = _count_lstm_layers(model, model.num_layers, model.input_size, model.hidden_size)
    total_ops = ops * cell_evals
    model.total_ops += torch.DoubleTensor([int(total_ops)])
|