a877aed45f
Change-Id: I16cd7730c1e0732253ac52f51010f6b813295aa7
58 lines
2.2 KiB
Python
58 lines
2.2 KiB
Python
"""
|
|
Author: Weisen Pan
|
|
Date: 2023-10-24
|
|
"""
|
|
import argparse
|
|
import torch
|
|
|
|
from preprocess import preprocess_data_exp23_dag, preprocess_data_exp23
|
|
from select_model import select_model_exp2
|
|
from models.DAG_Transformer import DAGTransformer
|
|
from models.CNN import CNNModel
|
|
from models.LSTM import LSTMModel
|
|
from models.Vanilla_Transformer import VanillaTransformerModel
|
|
from train_model_dag import train
|
|
from train_model_vanilla import train as train_vanilla
|
|
|
|
# Allowed values for the command-line options.
valid_models = ['DAGTransformer', 'CNN', 'LSTM', 'VanillaTransformer']
valid_splits = ['Branch090505', 'Branch080101', 'Branch060202']


def parse_args(argv=None):
    """Parse and validate the command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``model_name`` and ``split`` attributes.

    An unknown model name or split makes argparse print a usage message and
    raise ``SystemExit`` with status 2.
    """
    parser = argparse.ArgumentParser(
        description='Train a model on one of the branch data splits.')
    parser.add_argument('--model_name', required=True, choices=valid_models,
                        help='model architecture to train')
    parser.add_argument('--split', default='Branch060202', choices=valid_splits,
                        help='data split to train and evaluate on')
    return parser.parse_args(argv)


def build_loaders(model_name, split, config):
    """Preprocess ``split`` and wrap its train/val/test sets in DataLoaders.

    DAGTransformer uses ``preprocess_data_exp23_dag``; every other model uses
    ``preprocess_data_exp23``. Both are expected to return exactly three
    datasets (train, val, test).

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    if model_name == 'DAGTransformer':
        train_data, val_data, test_data = preprocess_data_exp23_dag(split)
    else:
        train_data, val_data, test_data = preprocess_data_exp23(split)

    # NOTE(review): shuffle=False applies to the training loader too; keep it
    # only if sample order must be preserved during training -- confirm intent.
    loader_args = {'batch_size': config.batch_size, 'num_workers': 2, 'shuffle': False}
    train_loader = torch.utils.data.DataLoader(dataset=train_data, **loader_args)
    val_loader = torch.utils.data.DataLoader(dataset=val_data, **loader_args)
    test_loader = torch.utils.data.DataLoader(dataset=test_data, **loader_args)
    return train_loader, val_loader, test_loader


def main(argv=None):
    """Parse options, build the data loaders and train the selected model.

    DAGTransformer is trained with ``train`` (from ``train_model_dag``); the
    remaining models share ``train_vanilla`` (from ``train_model_vanilla``).
    """
    opt = parse_args(argv)
    model_name = opt.model_name
    split = opt.split

    config = select_model_exp2(model_name)
    train_loader, val_loader, test_loader = build_loaders(model_name, split, config)

    if model_name == 'DAGTransformer':
        model = DAGTransformer(config).to(config.device)
        train(config, model, train_loader, val_loader, test_loader)
    else:
        model_class = {
            'LSTM': LSTMModel,
            'CNN': CNNModel,
            'VanillaTransformer': VanillaTransformerModel,
        }[model_name]
        model = model_class(config).to(config.device)
        train_vanilla(config, model, train_loader, val_loader, test_loader)


# All work happens under the guard. With num_workers > 0 and the "spawn"
# start method (default on Windows/macOS), every DataLoader worker re-imports
# this module. Module-level parsing/preprocessing would then be repeated in
# each worker.
if __name__ == '__main__':
    main()
|