Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00

196 lines
6.8 KiB
Python

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch
import torch.nn as nn
def _count_rnn_cell(input_size, hidden_size, bias=True):
"""Calculate the total operations for a single RNN cell.
Args:
input_size (int): Size of the input.
hidden_size (int): Size of the hidden state.
bias (bool, optional): Whether the RNN cell uses bias. Defaults to True.
Returns:
int: Total number of operations for the RNN cell.
"""
ops = hidden_size * (input_size + hidden_size) + hidden_size
if bias:
ops += hidden_size * 2
return ops
def count_rnn_cell(cell: nn.RNNCell, x: torch.Tensor):
    """Add the operation count of one ``nn.RNNCell`` call to ``cell.total_ops``.

    Args:
        cell (nn.RNNCell): The RNNCell to count operations for. Its
            ``input_size``, ``hidden_size`` and ``bias`` attributes are read,
            and its ``total_ops`` attribute is incremented in place.
        x: Inputs of the call; only ``x[0]`` is read and it is treated as the
            input tensor. NOTE(review): the annotation says ``torch.Tensor``
            but the indexing matches a forward-hook style tuple of inputs --
            confirm against the caller.
    """
    ops = _count_rnn_cell(cell.input_size, cell.hidden_size, cell.bias)
    inp = x[0]
    # nn.RNNCell also accepts an unbatched 1-D input of shape (input_size,).
    # There dim 0 is the feature size, so it must not be used as batch size.
    batch_size = inp.size(0) if inp.dim() > 1 else 1
    total_ops = ops * batch_size
    cell.total_ops += torch.DoubleTensor([int(total_ops)])
def _count_gru_cell(input_size, hidden_size, bias=True):
"""Calculate the total operations for a single GRU cell.
Args:
input_size (int): Size of the input.
hidden_size (int): Size of the hidden state.
bias (bool, optional): Whether the GRU cell uses bias. Defaults to True.
Returns:
int: Total number of operations for the GRU cell.
"""
ops = (hidden_size + input_size) * hidden_size + hidden_size
if bias:
ops += hidden_size * 2
ops *= 2 # For reset and update gates
ops += (hidden_size + input_size) * hidden_size + hidden_size # Calculate new gate
if bias:
ops += hidden_size * 2
ops += hidden_size # Hadamard product
ops += hidden_size * 3 # Final output
return ops
def count_gru_cell(cell: nn.GRUCell, x: torch.Tensor):
    """Add the operation count of one ``nn.GRUCell`` call to ``cell.total_ops``.

    Args:
        cell (nn.GRUCell): The GRUCell to count operations for. Its
            ``input_size``, ``hidden_size`` and ``bias`` attributes are read,
            and its ``total_ops`` attribute is incremented in place.
        x: Inputs of the call; only ``x[0]`` is read and it is treated as the
            input tensor. NOTE(review): the annotation says ``torch.Tensor``
            but the indexing matches a forward-hook style tuple of inputs --
            confirm against the caller.
    """
    ops = _count_gru_cell(cell.input_size, cell.hidden_size, cell.bias)
    inp = x[0]
    # nn.GRUCell also accepts an unbatched 1-D input of shape (input_size,).
    # There dim 0 is the feature size, so it must not be used as batch size.
    batch_size = inp.size(0) if inp.dim() > 1 else 1
    total_ops = ops * batch_size
    cell.total_ops += torch.DoubleTensor([int(total_ops)])
def _count_lstm_cell(input_size, hidden_size, bias=True):
"""Calculate the total operations for a single LSTM cell.
Args:
input_size (int): Size of the input.
hidden_size (int): Size of the hidden state.
bias (bool, optional): Whether the LSTM cell uses bias. Defaults to True.
Returns:
int: Total number of operations for the LSTM cell.
"""
ops = (input_size + hidden_size) * hidden_size + hidden_size
if bias:
ops += hidden_size * 2
ops *= 4 # For input, forget, output, and cell gates
ops += hidden_size * 3 # Cell state update
ops += hidden_size # Final output
return ops
def count_lstm_cell(cell: nn.LSTMCell, x: torch.Tensor):
    """Add the operation count of one ``nn.LSTMCell`` call to ``cell.total_ops``.

    Args:
        cell (nn.LSTMCell): The LSTMCell to count operations for. Its
            ``input_size``, ``hidden_size`` and ``bias`` attributes are read,
            and its ``total_ops`` attribute is incremented in place.
        x: Inputs of the call; only ``x[0]`` is read and it is treated as the
            input tensor. NOTE(review): the annotation says ``torch.Tensor``
            but the indexing matches a forward-hook style tuple of inputs --
            confirm against the caller.
    """
    ops = _count_lstm_cell(cell.input_size, cell.hidden_size, cell.bias)
    inp = x[0]
    # nn.LSTMCell also accepts an unbatched 1-D input of shape (input_size,).
    # There dim 0 is the feature size, so it must not be used as batch size.
    batch_size = inp.size(0) if inp.dim() > 1 else 1
    total_ops = ops * batch_size
    cell.total_ops += torch.DoubleTensor([int(total_ops)])
def _count_rnn_layers(model: nn.RNN, num_layers, input_size, hidden_size):
    """Calculate the per-step, per-sample operations of all RNN layers.

    The first layer consumes ``input_size`` features; every further layer
    consumes the previous layer's output, which is ``2 * hidden_size`` wide
    for a bidirectional model. A bidirectional model runs an independent
    cell per direction in every layer, so the total is doubled in that case.

    Args:
        model (nn.RNN): The RNN model; ``bias`` and ``bidirectional`` are read.
        num_layers (int): Number of layers in the RNN.
        input_size (int): Size of the input.
        hidden_size (int): Size of the hidden state.

    Returns:
        int: Total number of operations for the RNN layers, summed over all
        layers and directions.
    """
    num_directions = 2 if model.bidirectional else 1

    # Cost of one direction: the first layer plus the stacked layers.
    ops = _count_rnn_cell(input_size, hidden_size, model.bias)
    stacked_layers = max(num_layers - 1, 0)
    if stacked_layers:
        ops += stacked_layers * _count_rnn_cell(
            hidden_size * num_directions, hidden_size, model.bias
        )

    # Each direction has its own weights and performs the same work.
    return ops * num_directions
def count_rnn(model: nn.RNN, x: torch.Tensor):
    """Add the operation count of one ``nn.RNN`` forward pass to ``model.total_ops``.

    The per-step, per-sample cost of all layers is multiplied by the number
    of (time step, sample) positions found in the input.

    Args:
        model (nn.RNN): The RNN model. Its ``batch_first``, ``num_layers``,
            ``input_size`` and ``hidden_size`` attributes are read, and its
            ``total_ops`` attribute is incremented in place.
        x: Inputs of the call; only ``x[0]`` is read. It may be a batched 3-D
            tensor, an unbatched 2-D tensor ``(seq_len, input_size)`` or a
            ``PackedSequence``. NOTE(review): the annotation says
            ``torch.Tensor`` but the indexing matches a forward-hook style
            tuple of inputs -- confirm against the caller.
    """
    inp = x[0]
    if isinstance(inp, nn.utils.rnn.PackedSequence):
        # batch_sizes[t] is the number of sequences present at step t, so the
        # sum is the exact number of cell evaluations. A PackedSequence has
        # no .size(), so it cannot go through the tensor path below.
        positions = int(inp.batch_sizes.sum())
    elif inp.dim() == 2:
        # Unbatched input is (seq_len, input_size) whatever batch_first is.
        positions = inp.size(0)
    else:
        batch_dim, step_dim = (0, 1) if model.batch_first else (1, 0)
        positions = inp.size(step_dim) * inp.size(batch_dim)

    ops = _count_rnn_layers(model, model.num_layers, model.input_size, model.hidden_size)
    total_ops = ops * positions
    model.total_ops += torch.DoubleTensor([int(total_ops)])
def _count_gru_layers(model: nn.GRU, num_layers, input_size, hidden_size):
    """Calculate the per-step, per-sample operations of all GRU layers.

    The first layer consumes ``input_size`` features; every further layer
    consumes the previous layer's output, which is ``2 * hidden_size`` wide
    for a bidirectional model. A bidirectional model runs an independent
    cell per direction in every layer, so the total is doubled in that case.

    Args:
        model (nn.GRU): The GRU model; ``bias`` and ``bidirectional`` are read.
        num_layers (int): Number of layers in the GRU.
        input_size (int): Size of the input.
        hidden_size (int): Size of the hidden state.

    Returns:
        int: Total number of operations for the GRU layers, summed over all
        layers and directions.
    """
    num_directions = 2 if model.bidirectional else 1

    # Cost of one direction: the first layer plus the stacked layers.
    ops = _count_gru_cell(input_size, hidden_size, model.bias)
    stacked_layers = max(num_layers - 1, 0)
    if stacked_layers:
        ops += stacked_layers * _count_gru_cell(
            hidden_size * num_directions, hidden_size, model.bias
        )

    # Each direction has its own weights and performs the same work.
    return ops * num_directions
def count_gru(model: nn.GRU, x: torch.Tensor):
    """Add the operation count of one ``nn.GRU`` forward pass to ``model.total_ops``.

    The per-step, per-sample cost of all layers is multiplied by the number
    of (time step, sample) positions found in the input.

    Args:
        model (nn.GRU): The GRU model. Its ``batch_first``, ``num_layers``,
            ``input_size`` and ``hidden_size`` attributes are read, and its
            ``total_ops`` attribute is incremented in place.
        x: Inputs of the call; only ``x[0]`` is read. It may be a batched 3-D
            tensor, an unbatched 2-D tensor ``(seq_len, input_size)`` or a
            ``PackedSequence``. NOTE(review): the annotation says
            ``torch.Tensor`` but the indexing matches a forward-hook style
            tuple of inputs -- confirm against the caller.
    """
    inp = x[0]
    if isinstance(inp, nn.utils.rnn.PackedSequence):
        # batch_sizes[t] is the number of sequences present at step t, so the
        # sum is the exact number of cell evaluations. A PackedSequence has
        # no .size(), so it cannot go through the tensor path below.
        positions = int(inp.batch_sizes.sum())
    elif inp.dim() == 2:
        # Unbatched input is (seq_len, input_size) whatever batch_first is.
        positions = inp.size(0)
    else:
        batch_dim, step_dim = (0, 1) if model.batch_first else (1, 0)
        positions = inp.size(step_dim) * inp.size(batch_dim)

    ops = _count_gru_layers(model, model.num_layers, model.input_size, model.hidden_size)
    total_ops = ops * positions
    model.total_ops += torch.DoubleTensor([int(total_ops)])
def _count_lstm_layers(model: nn.LSTM, num_layers, input_size, hidden_size):
    """Calculate the per-step, per-sample operations of all LSTM layers.

    The first layer consumes ``input_size`` features; every further layer
    consumes the previous layer's output, which is ``2 * hidden_size`` wide
    for a bidirectional model. A bidirectional model runs an independent
    cell per direction in every layer, so the total is doubled in that case.

    Args:
        model (nn.LSTM): The LSTM model; ``bias`` and ``bidirectional`` are read.
        num_layers (int): Number of layers in the LSTM.
        input_size (int): Size of the input.
        hidden_size (int): Size of the hidden state.

    Returns:
        int: Total number of operations for the LSTM layers, summed over all
        layers and directions.
    """
    num_directions = 2 if model.bidirectional else 1

    # Cost of one direction: the first layer plus the stacked layers.
    ops = _count_lstm_cell(input_size, hidden_size, model.bias)
    stacked_layers = max(num_layers - 1, 0)
    if stacked_layers:
        ops += stacked_layers * _count_lstm_cell(
            hidden_size * num_directions, hidden_size, model.bias
        )

    # Each direction has its own weights and performs the same work.
    return ops * num_directions
def count_lstm(model: nn.LSTM, x: torch.Tensor):
    """Add the operation count of one ``nn.LSTM`` forward pass to ``model.total_ops``.

    The per-step, per-sample cost of all layers is multiplied by the number
    of (time step, sample) positions found in the input.

    Args:
        model (nn.LSTM): The LSTM model. Its ``batch_first``, ``num_layers``,
            ``input_size`` and ``hidden_size`` attributes are read, and its
            ``total_ops`` attribute is incremented in place.
        x: Inputs of the call; only ``x[0]`` is read. It may be a batched 3-D
            tensor, an unbatched 2-D tensor ``(seq_len, input_size)`` or a
            ``PackedSequence``. NOTE(review): the annotation says
            ``torch.Tensor`` but the indexing matches a forward-hook style
            tuple of inputs -- confirm against the caller.
    """
    inp = x[0]
    if isinstance(inp, nn.utils.rnn.PackedSequence):
        # batch_sizes[t] is the number of sequences present at step t, so the
        # sum is the exact number of cell evaluations. A PackedSequence has
        # no .size(), so it cannot go through the tensor path below.
        positions = int(inp.batch_sizes.sum())
    elif inp.dim() == 2:
        # Unbatched input is (seq_len, input_size) whatever batch_first is.
        positions = inp.size(0)
    else:
        batch_dim, step_dim = (0, 1) if model.batch_first else (1, 0)
        positions = inp.size(step_dim) * inp.size(batch_dim)

    ops = _count_lstm_layers(model, model.num_layers, model.input_size, model.hidden_size)
    total_ops = ops * positions
    model.total_ops += torch.DoubleTensor([int(total_ops)])