# Source: use-case-and-architecture/ai_computing_force_scheduling/workflow_models/CNN.py
# Commit: a877aed45f (Weisen Pan) "AI-based CFN Traffic Control and Computer Force Scheduling"
# Change-Id: I16cd7730c1e0732253ac52f51010f6b813295aa7
# Date: 2023-11-03 00:09:19 -07:00
# File info: 45 lines, 1.2 KiB, Python

"""
Author: Weisen Pan
Date: 2023-10-24
"""
import torch.nn as nn
import torch.nn.functional as F
class CNNModel(nn.Module):
    """1-D convolutional classifier over per-task feature vectors.

    Pipeline:

    * A pointwise (``kernel_size=1``) ``Conv1d`` maps the ``n_feat`` input
      features to ``outdim`` channels at every task position.
    * ReLU and dropout follow.
    * ``MaxPool1d`` down-samples the ``outdim`` axis by ``pooldim``.
    * The result is flattened and a ``Linear`` layer produces
      ``num_classes`` scores.

    Configuration attributes read: ``n_feat``, ``outdim``, ``pooldim``,
    ``num_task``, ``num_classes``, ``dropout``.
    """

    def __init__(self, config):
        """Build the layers from ``config``.

        Args:
            config (object): Any object exposing the attributes listed in the
                class docstring.
        """
        super().__init__()
        # kernel_size=1 transforms each task position independently, so the
        # length axis (num_task) is preserved through the convolution.
        self.conv = nn.Conv1d(
            in_channels=config.n_feat,
            out_channels=config.outdim,
            kernel_size=1,
        )
        # Applied across the conv's output channels in forward(); with the
        # default stride == kernel and ceil_mode=False this leaves
        # outdim // pooldim values per task position.
        self.maxpool = nn.MaxPool1d(config.pooldim)
        pooled_width = config.outdim // config.pooldim
        flat_features = pooled_width * config.num_task
        self.fc = nn.Linear(flat_features, config.num_classes)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Compute class scores for a batch.

        Args:
            x (torch.Tensor): 3-D input of shape (batch, num_task, n_feat).
                The last axis is fed to the convolution as its channel axis.

        Returns:
            torch.Tensor: Unnormalized scores of shape (batch, num_classes).
            No softmax is applied here.
        """
        # (batch, num_task, n_feat) -> (batch, n_feat, num_task) for Conv1d.
        features_first = x.permute(0, 2, 1)
        hidden = self.dropout(F.relu(self.conv(features_first)))
        # Swap back so pooling slides over the outdim channels of each task:
        # (batch, outdim, num_task) -> (batch, num_task, outdim).
        pooled = self.maxpool(hidden.permute(0, 2, 1))
        batch_size = pooled.size(0)
        flattened = pooled.reshape(batch_size, -1)
        return self.fc(flattened)