""" Author: Weisen Pan Date: 2023-10-24 """ import time import torch import numpy as np import torch.nn.functional as F from datetime import timedelta from sklearn import metrics from tqdm import tqdm from scheduler import WarmUpLR, downLR def get_time_dif(start_time): """Get the time difference between now and the start time.""" elapsed_time = time.time() - start_time return timedelta(seconds=int(round(elapsed_time))) def train(config, model, train_iter, dev_iter, test_iter): start_time = time.time() model.train() optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate) warmup_steps = config.num_epochs / 2 * len(train_iter) scheduler = downLR(optimizer, (config.num_epochs - warmup_steps / len(train_iter)) * len(train_iter)) warmup_scheduler = WarmUpLR(optimizer, warmup_steps) dev_best_loss = float('inf') dev_best_acc = 0 test_best_acc = 0 for epoch in range(config.num_epochs): epoch_loss = 0 predictions, labels = [], [] for trains, label_batch, poss, masks in tqdm(train_iter): trains, label_batch, poss, masks = [tensor.to(config.device) for tensor in [trains, label_batch, poss, masks]] outputs = model(trains, poss, masks) model.zero_grad() loss = F.cross_entropy(outputs, label_batch) loss.backward() optimizer.step() if epoch < warmup_steps / len(train_iter): warmup_scheduler.step() else: scheduler.step() epoch_loss += loss.item() predictions.extend(torch.max(outputs, 1)[1].tolist()) labels.extend(label_batch.tolist()) train_acc = metrics.accuracy_score(labels, predictions) dev_acc, dev_loss = evaluate(config, model, dev_iter) if dev_loss < dev_best_loss: dev_best_loss = dev_loss if dev_acc > dev_best_acc: dev_best_acc = dev_acc test_best_acc = evaluate(config, model, test_iter)[0] time_dif = get_time_dif(start_time) print(f'Epoch: {epoch + 1}/{config.num_epochs}, Train Loss: {epoch_loss / len(train_iter):.2f}, Train Acc: {train_acc:.2%}, Dev Loss: {dev_loss:.2f}, Dev Acc: {dev_acc:.2%}, Test Best Acc: {test_best_acc:.2%}, Time: {time_dif}') test(config, model, test_iter) def test(config, model, test_iter): model.eval() test_acc, test_loss, test_confusion = evaluate(config, model, test_iter, test=True) print(f'Test Loss: {test_loss:.2f}, Test Acc: {test_acc:.2%}') print("Confusion Matrix:", test_confusion) print("Time usage:", get_time_dif(time.time())) def evaluate(config, model, data_iter, test=False): model.eval() total_loss = 0 predictions, labels = [], [] with torch.no_grad(): for texts, labels_batch, poss, masks in data_iter: texts, poss, masks, labels_batch = [tensor.to(config.device) for tensor in [texts, poss, masks, labels_batch]] outputs = model(texts, poss, masks) loss = F.cross_entropy(outputs, labels_batch) total_loss += loss.item() predictions.extend(torch.max(outputs, 1)[1].tolist()) labels.extend(labels_batch.tolist()) accuracy = metrics.accuracy_score(labels, predictions) if test: confusion = metrics.confusion_matrix(labels, predictions) return accuracy, total_loss / len(data_iter), confusion return accuracy, total_loss / len(data_iter)