use-case-and-architecture/ai_computing_force_scheduling/run_CFN_schedule.py
Weisen Pan a877aed45f AI-based CFN Traffic Control and Computer Force Scheduling
Change-Id: I16cd7730c1e0732253ac52f51010f6b813295aa7
2023-11-03 00:09:19 -07:00

58 lines
2.2 KiB
Python

"""
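Command-line script that trains one of the DAGTransformer, CNN, LSTM or
VanillaTransformer models on a selected data split.

Author: Weisen Pan
Date: 2023-10-24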
"""
import argparse
import torch
from preprocess import preprocess_data_exp23_dag, preprocess_data_exp23
from select_model import select_model_exp2
from models.DAG_Transformer import DAGTransformer
from models.CNN import CNNModel
from models.LSTM import LSTMModel
from models.Vanilla_Transformer import VanillaTransformerModel
from train_model_dag import train
from train_model_vanilla import train as train_vanilla
# Accepted command-line values. Kept at module level under their original
# names so they stay importable.
valid_models = ['DAGTransformer', 'CNN', 'LSTM', 'VanillaTransformer']
valid_splits = ['Branch090505', 'Branch080101', 'Branch060202']


def parse_args(argv=None):
    """Parse and validate the command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        The argparse namespace with ``model_name`` and ``split`` attributes.

    Raises:
        AssertionError: If ``model_name`` or ``split`` is not one of the
            accepted values.
        SystemExit: Raised by argparse when ``--model_name`` is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', required=True,
                        help='one of: ' + '/'.join(valid_models))
    parser.add_argument('--split', default='Branch060202',
                        help='one of: ' + '/'.join(valid_splits))
    opt = parser.parse_args(argv)

    # The model name is validated before the split, so an invalid model is
    # reported first when both are wrong.
    if opt.model_name not in valid_models:
        raise AssertionError('model should be one of: ' + '/'.join(valid_models))
    if opt.split not in valid_splits:
        raise AssertionError('split should be one of: ' + '/'.join(valid_splits))
    return opt


def build_loaders(config, model_name, split):
    """Preprocess the chosen split and wrap each partition in a DataLoader.

    Args:
        config: Model configuration; only ``config.batch_size`` is read here.
        model_name: Selected model name. 'DAGTransformer' uses the DAG
            preprocessing routine, every other model the plain one.
        split: Name of the data split passed to the preprocessing routine.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple.
    """
    if model_name == 'DAGTransformer':
        train_data, val_data, test_data = preprocess_data_exp23_dag(split)
    else:
        train_data, val_data, test_data = preprocess_data_exp23(split)

    # All three loaders share the same settings; shuffling is disabled.
    loader_args = {'batch_size': config.batch_size, 'num_workers': 2, 'shuffle': False}
    train_loader = torch.utils.data.DataLoader(dataset=train_data, **loader_args)
    val_loader = torch.utils.data.DataLoader(dataset=val_data, **loader_args)
    test_loader = torch.utils.data.DataLoader(dataset=test_data, **loader_args)
    return train_loader, val_loader, test_loader


def main(argv=None):
    """Run the full pipeline: parse options, load data, build and train the model.

    Args:
        argv: Optional argument list forwarded to ``parse_args``.
    """
    opt = parse_args(argv)
    model_name = opt.model_name

    config = select_model_exp2(model_name)
    train_loader, val_loader, test_loader = build_loaders(config, model_name, opt.split)

    if model_name == 'DAGTransformer':
        model = DAGTransformer(config).to(config.device)
        train(config, model, train_loader, val_loader, test_loader)
    else:
        model_class = {
            'LSTM': LSTMModel,
            'CNN': CNNModel,
            'VanillaTransformer': VanillaTransformerModel,
        }[model_name]
        model = model_class(config).to(config.device)
        train_vanilla(config, model, train_loader, val_loader, test_loader)


# Everything with side effects lives behind this guard. The loaders use worker
# processes (num_workers=2); when workers are started with the "spawn" method
# they re-import the main module, and any top-level parsing or preprocessing
# would be repeated in each of them.
if __name__ == '__main__':
    main()