a877aed45f
Change-Id: I16cd7730c1e0732253ac52f51010f6b813295aa7
45 lines
1.2 KiB
Python
45 lines
1.2 KiB
Python
"""
|
|
Author: Weisen Pan
|
|
Date: 2023-10-24
|
|
"""
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
class CNNModel(nn.Module):
    """Small 1-D convolutional model configured from a config object.

    Pipeline: point-wise (kernel size 1) ``Conv1d`` over the feature axis,
    ReLU, dropout, max-pooling over the convolution's output channels,
    flatten, and a final linear layer with ``config.num_classes`` outputs.
    """

    def __init__(self, config):
        """
        Initializes the CNN model with the given configuration.

        Args:
            config (object): Configuration object read via attribute access.
                Attributes used:
                - ``n_feat``: size of the input's last axis (conv in-channels).
                - ``outdim``: conv out-channels.
                - ``pooldim``: max-pool kernel size (stride defaults to it).
                - ``num_task``: size of the input's second axis; it enters
                  the linear layer's input size.
                - ``num_classes``: linear layer output size.
                - ``dropout``: dropout probability.
        """
        super().__init__()

        # kernel_size=1: every position along the sequence axis is projected
        # independently from n_feat to outdim channels.
        self.conv = nn.Conv1d(
            in_channels=config.n_feat,
            out_channels=config.outdim,
            kernel_size=1,
        )
        self.maxpool = nn.MaxPool1d(config.pooldim)
        # After pooling, each of the num_task positions keeps
        # floor(outdim / pooldim) values; the flattened vector feeds fc.
        self.fc = nn.Linear(
            (config.outdim // config.pooldim) * config.num_task,
            config.num_classes,
        )
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """
        Forward pass of the CNN model.

        Args:
            x (torch.Tensor): 3-D input of shape (batch, L, n_feat). For the
                flattened features to match the linear layer, L must equal
                ``config.num_task``.

        Returns:
            torch.Tensor: Output of the final linear layer, shape
            (batch, num_classes).
        """
        # Conv1d expects (batch, channels, length): move features to channels.
        out = x.permute(0, 2, 1)                   # (batch, n_feat, L)
        out = F.relu(self.conv(out))               # (batch, outdim, L)
        out = self.dropout(out)
        # Swap back so MaxPool1d slides over the outdim axis per position.
        out = self.maxpool(out.permute(0, 2, 1))   # (batch, L, outdim // pooldim)
        out = out.reshape(out.size(0), -1)         # (batch, L * (outdim // pooldim))
        out = self.fc(out)

        return out
|