""" Author: Weisen Pan Date: 2023-10-24 """ import numpy as np import torch import torch.nn.functional as F from datetime import timedelta from sklearn import metrics from scheduler import WarmUpLR, downLR def get_time_difference(start_time): """Compute the time elapsed since the start_time.""" end_time = time.time() elapsed_time = end_time - start_time return timedelta(seconds=int(round(elapsed_time))) def train(config, model, data): start_time = time.time() model.train() optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate) warmup_epoch = config.num_epochs / 2 scheduler = downLR(optimizer, (config.num_epochs - warmup_epoch)) warmup_scheduler = WarmUpLR(optimizer, warmup_epoch) dev_best_loss, dev_best_acc, test_best_acc = float('inf'), 0.0, 0.0 learning_rates = np.zeros((config.num_epochs, 2)) for epoch in range(config.num_epochs): print(f'Epoch [{epoch + 1}/{config.num_epochs}]') learning_rates[epoch][0] = epoch if epoch >= warmup_epoch: current_learning_rate = scheduler.get_lr()[0] learning_rates[epoch][1] = current_learning_rate else: current_learning_rate = warmup_scheduler.get_lr()[0] learning_rates[epoch][0] = current_learning_rate print(f"Learning Rate: {current_learning_rate}") data = data.to(config.device) outputs = model(data) model.zero_grad() loss = F.cross_entropy(outputs[data.train_mask], data.labels[data.train_mask]) loss.backward() optimizer.step() if epoch < warmup_epoch: warmup_scheduler.step() else: scheduler.step() predictions = torch.max(outputs[data.train_mask], 1)[1] train_acc = get_accuracy(predictions, data.labels[data.train_mask]) dev_acc, dev_loss = evaluate(config, model, data) test_acc, test_loss = test(config, model, data) if dev_loss < dev_best_loss: dev_best_loss = dev_loss improve_marker = '*' else: improve_marker = '' if dev_acc > dev_best_acc: dev_best_acc = dev_acc test_best_acc = test_acc elapsed_time = get_time_difference(start_time) status = (f'Iter: {epoch + 1:>6}, Train Loss: {loss.item():>5.2f}, Train Acc: {train_acc:>6.2%}, ' f'Val Loss: {dev_loss:>5.2f}, Val Acc: {dev_acc:>6.2%}, ' f'Test Loss: {test_loss:>5.2f}, Test Acc: {test_acc:>6.2%}, Time: {elapsed_time} {improve_marker}') print(status) print(f'Best Val Acc: {dev_best_acc}, Best Test Acc: {test_best_acc}') test(config, model, data, final=True) def test(config, model, data, final=False): model.eval() with torch.no_grad(): outputs = model(data) test_loss = F.cross_entropy(outputs[data.test_mask], data.labels[data.test_mask]) predictions = torch.max(outputs[data.test_mask], 1)[1] test_acc = get_accuracy(predictions, data.labels[data.test_mask]) if final: print(f'Test Loss: {test_loss:>5.2f}, Test Acc: {test_acc:>6.2%}') confusion = metrics.confusion_matrix(predictions.cpu().numpy(), data.labels[data.test_mask].cpu().numpy()) print('Confusion Matrix:\n', confusion) return test_acc, test_loss, confusion return test_acc, test_loss def evaluate(config, model, data): model.eval() with torch.no_grad(): outputs = model(data) eval_loss = F.cross_entropy(outputs[data.val_mask], data.labels[data.val_mask]) predictions = torch.max(outputs[data.val_mask], 1)[1] eval_acc = get_accuracy(predictions, data.labels[data.val_mask]) return eval_acc, eval_loss def get_accuracy(predictions, true_labels): return metrics.accuracy_score(predictions.cpu().numpy(), true_labels.cpu().numpy())