"""CLI commands for training and validating the news-embedding classifier."""
import click
from tqdm import tqdm
from enum import Enum, auto
from dotenv import load_dotenv
import os
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from train.dataset import NewsDataset
from train.model import Classifier
from data.main import paths, connect, ticklabels
import numpy as np
import pandas as pd


class Stage(Enum):
    """Phase of an epoch: parameter updates (TRAIN) or evaluation only (DEV)."""
    TRAIN = auto()
    DEV = auto()


@click.command('main')
@click.option('--epochs', default=10, type=int)
def main(epochs):
    """Train the classifier on NewsDataset embeddings and save its weights.

    A dev (validation) pass runs every ``dev_after`` epochs, before that
    epoch's training pass. The final state dict is written to
    ``paths('model') / 'torch_clf.pth'``.

    Args:
        epochs: number of training epochs to run.
    """
    dev_after = 5
    visible_devices = None
    lr = 1e-4
    debug = False

    # FIX: load_dotenv was imported but never invoked, so the .env settings
    # (NUMBER_OF_WORKERS, EMBEDDING_LENGTH, BATCH_SIZE) were silently ignored.
    load_dotenv()

    torch.manual_seed(0)

    # FIX: CUDA_VISIBLE_DEVICES only takes effect if set before the CUDA
    # context exists — it must precede Accelerator() construction, not
    # follow it as before.
    if visible_devices:
        os.environ['CUDA_VISIBLE_DEVICES'] = visible_devices
    if debug:
        os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    num_workers = int(os.getenv('NUMBER_OF_WORKERS', 0))
    embedding_length = int(os.getenv('EMBEDDING_LENGTH', 384))

    dataset = NewsDataset()
    trainset, devset = torch.utils.data.random_split(dataset, [0.8, 0.2])

    batch_size = int(os.getenv('BATCH_SIZE', 512))
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                             num_workers=num_workers, drop_last=True)
    # FIX: the dev loader previously fell back to the default batch_size=1,
    # making every dev pass needlessly slow; evaluate in full batches.
    devloader = DataLoader(devset, batch_size=batch_size, shuffle=False,
                           num_workers=num_workers)

    accelerator = Accelerator()
    model = Classifier(embedding_length=embedding_length, classes=5)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Wrap objects with accelerate so model/optimizer/loaders land on the
    # right device(s) automatically.
    model, optimizer, trainloader, devloader = accelerator.prepare(
        model, optimizer, trainloader, devloader)

    def run(stage):
        """Run one epoch for ``stage``; return the mean per-batch loss.

        FIX: ``stage`` was previously a free variable assigned in the
        enclosing scope just before each call — now an explicit parameter.
        """
        running_loss = 0.0
        training = stage == Stage.TRAIN
        model.train() if training else model.eval()
        dataloader = trainloader if training else devloader
        desc = 'train epoch' if training else 'dev epoch'
        if debug:
            ...
        # Make sure there are no leftover gradients before starting an epoch.
        optimizer.zero_grad()
        # FIX: autograd is only needed while training; disabling it for the
        # dev pass avoids building graphs during evaluation.
        with torch.set_grad_enabled(training):
            for batch, (x, y) in enumerate(tqdm(dataloader, desc=desc)):
                pred_y = model(x)  # forward pass through the model
                loss = criterion(pred_y, y)
                # FIX: accumulate a Python float, not the loss tensor —
                # summing tensors retained every batch's autograd graph for
                # the whole epoch (a steady memory leak).
                running_loss += loss.item()
                # Only update model weights when training.
                if training:
                    accelerator.backward(loss)  # populate gradients
                    optimizer.step()            # update model weights
                    optimizer.zero_grad()       # reset gradients to 0
        return running_loss / len(dataloader)

    for epoch in range(epochs):
        if (epoch + 1) % dev_after == 0:
            log = run(Stage.DEV)
            print(f"dev loss: {log:.3f}")
        log = run(Stage.TRAIN)
        print(f"train loss: {log:.3f}")

    torch.save(model.state_dict(), paths('model') / 'torch_clf.pth')


@click.command('validate')
def validate():
    """Evaluate the saved classifier on the full dataset and plot diagnostics.

    Writes a per-class output histogram (``label_hist.png``) and a confusion
    matrix (``confusion_matrix.png``) to ``paths('data') / 'runs'``.
    """
    from sklearn.metrics import ConfusionMatrixDisplay
    import matplotlib.pyplot as plt
    import seaborn as sns

    # NOTE(review): the previous version also loaded embeddings.npy and
    # embedding_ids.npy here but never used them — dead I/O, removed.
    embedding_length = int(os.getenv('EMBEDDING_LENGTH', 384))
    model = Classifier(embedding_length=embedding_length, classes=5)
    model.load_state_dict(torch.load(paths('model') / 'torch_clf.pth'))
    model.eval()

    dataset = NewsDataset()
    y = dataset[:][1]
    with torch.no_grad():
        out = model(torch.tensor(dataset[:][0]))

    out_path = (paths('data') / 'runs')
    out_path.mkdir(exist_ok=True)

    # Distribution of the model's raw outputs, one hue per class column.
    sns.histplot(pd.DataFrame(out).melt(), x='value', hue='variable',
                 palette='rainbow')
    plt.savefig(out_path / 'label_hist.png')
    plt.close()

    y_pred = out.argmax(axis=1)
    fig, ax = plt.subplots(figsize=(10, 5))
    ConfusionMatrixDisplay.from_predictions(y, y_pred, ax=ax)
    # FIX: title claimed a "kNN classifier" but this command evaluates the
    # saved torch Classifier.
    ax.set(title="confusion matrix for torch classifier on test data.",
           xticklabels=ticklabels(), yticklabels=ticklabels())
    plt.savefig(out_path / 'confusion_matrix.png')
    plt.close()
    # FIX: removed a stray breakpoint() left over from debugging — it dropped
    # every non-interactive run into pdb after saving the plots.