"""
Author: Weisen Pan
Date: 2023-10-24
"""

import torch.nn as nn
import torch.nn.functional as F


class CNNModel(nn.Module):
    """Point-wise 1-D convolution, channel max-pooling and a linear head.

    Shapes (derived from the layer sizes built in ``__init__``):
        input:  ``(batch, config.num_task, config.n_feat)`` -- the middle
                dimension must equal ``config.num_task`` for the flattened
                features to match ``self.fc``'s input size.
        output: ``(batch, config.num_classes)`` raw scores (no softmax).
    """

    def __init__(self, config):
        """Create the layers described by ``config``.

        Args:
            config (object): Provides the attributes ``n_feat``, ``outdim``,
                ``pooldim``, ``num_task``, ``num_classes`` and ``dropout``.
        """
        super().__init__()
        # kernel_size=1: each position's n_feat values are mapped to outdim
        # channels independently of neighbouring positions.
        self.conv = nn.Conv1d(config.n_feat, config.outdim, kernel_size=1)
        self.maxpool = nn.MaxPool1d(config.pooldim)
        # MaxPool1d's default stride equals its kernel size, so each position
        # keeps floor(outdim / pooldim) pooled values.
        pooled_width = config.outdim // config.pooldim
        self.fc = nn.Linear(pooled_width * config.num_task, config.num_classes)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Compute class scores for a batch.

        Args:
            x (torch.Tensor): Rank-3 tensor ``(batch, num_task, n_feat)``.

        Returns:
            torch.Tensor: ``(batch, num_classes)`` unnormalised scores.
        """
        # Conv1d expects channels in dim 1: (batch, n_feat, num_task).
        channels_first = x.permute(0, 2, 1)
        activated = self.dropout(F.relu(self.conv(channels_first)))
        # Swap back to (batch, num_task, outdim) so MaxPool1d reduces over the
        # conv output channels of each task position rather than over positions.
        pooled = self.maxpool(activated.permute(0, 2, 1))
        flat = pooled.reshape(pooled.size(0), -1)
        return self.fc(flat)