""" Author: Weisen Pan Date: 2023-10-24 """ import argparse import torch from preprocess import preprocess_data_exp23_dag, preprocess_data_exp23 from select_model import select_model_exp2 from models.DAG_Transformer import DAGTransformer from models.CNN import CNNModel from models.LSTM import LSTMModel from models.Vanilla_Transformer import VanillaTransformerModel from train_model_dag import train from train_model_vanilla import train as train_vanilla # Argument parsing parser = argparse.ArgumentParser() parser.add_argument('--model_name', required=True) # Choices: DAGTransformer, CNN, LSTM, VanillaTransformer parser.add_argument('--split', default='Branch060202') # Choices: Branch090505, Branch080101, Branch060202 opt = parser.parse_args() valid_models = ['DAGTransformer', 'CNN', 'LSTM', 'VanillaTransformer'] valid_splits = ['Branch090505', 'Branch080101', 'Branch060202'] if opt.model_name not in valid_models: raise AssertionError('model should be one of: ' + '/'.join(valid_models)) model_name = opt.model_name if opt.split not in valid_splits: raise AssertionError('split should be one of: ' + '/'.join(valid_splits)) split = opt.split config = select_model_exp2(model_name) if model_name == 'DAGTransformer': train_data, val_data, test_data = preprocess_data_exp23_dag(split) else: train_data, val_data, test_data = preprocess_data_exp23(split) # Creating data loaders loader_args = {'batch_size': config.batch_size, 'num_workers': 2, 'shuffle': False} train_loader = torch.utils.data.DataLoader(dataset=train_data, **loader_args) val_loader = torch.utils.data.DataLoader(dataset=val_data, **loader_args) test_loader = torch.utils.data.DataLoader(dataset=test_data, **loader_args) if __name__ == '__main__': if model_name == 'DAGTransformer': model = DAGTransformer(config).to(config.device) train(config, model, train_loader, val_loader, test_loader) else: model_class = { 'LSTM': LSTMModel, 'CNN': CNNModel, 'VanillaTransformer': VanillaTransformerModel }[model_name] model = model_class(config).to(config.device) train_vanilla(config, model, train_loader, val_loader, test_loader)