# wwu-577/src/train/main.py
import click
from tqdm import tqdm
from enum import Enum, auto
from dotenv import load_dotenv
import os
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from train.dataset import NewsDataset
from train.model import Classifier
from data.main import paths, connect, ticklabels
import numpy as np
import pandas as pd
class Stage(Enum):
    """Which phase of an epoch is running: weight updates (TRAIN) or held-out evaluation (DEV)."""

    TRAIN = auto()
    DEV = auto()
@click.command('main')
@click.option('--epochs', default=10, type=int)
def main(epochs):
    """Train the news classifier for *epochs* epochs and save its weights.

    A dev (validation) pass runs every ``dev_after`` epochs; the final
    state dict is written to ``paths('model') / 'torch_clf.pth'``.
    Environment knobs: NUMBER_OF_WORKERS, EMBEDDING_LENGTH, BATCH_SIZE.
    """
    dev_after = 5           # run a dev pass every N epochs
    visible_devices = None  # e.g. "0,1" to restrict visible GPUs
    lr = 1e-4
    debug = False
    torch.manual_seed(0)    # deterministic split/shuffle/init

    # FIX: CUDA_VISIBLE_DEVICES must be exported *before* the Accelerator
    # (and thus any CUDA context) is created, otherwise it is ignored.
    if visible_devices:
        os.environ['CUDA_VISIBLE_DEVICES'] = visible_devices
    if debug:
        os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # synchronous kernel launches

    num_workers = int(os.getenv('NUMBER_OF_WORKERS', 0))
    embedding_length = int(os.getenv('EMBEDDING_LENGTH', 384))
    batch_size = int(os.getenv('BATCH_SIZE', 512))

    dataset = NewsDataset()
    trainset, devset = torch.utils.data.random_split(dataset, [0.8, 0.2])
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                             num_workers=num_workers, drop_last=True)
    devloader = DataLoader(devset, shuffle=False, num_workers=num_workers)

    accelerator = Accelerator()
    model = Classifier(embedding_length=embedding_length, classes=5)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # wrap objects with accelerate (device placement, optional DDP)
    model, optimizer, trainloader, devloader = accelerator.prepare(
        model, optimizer, trainloader, devloader)

    def run(stage):
        """Run one pass over the stage's dataloader; return the mean batch loss.

        Weights are updated only when ``stage is Stage.TRAIN``; during DEV
        the pass runs without building autograd graphs.
        """
        training = stage is Stage.TRAIN
        if training:
            model.train()
        else:
            model.eval()
        dataloader = trainloader if training else devloader
        desc = 'train epoch' if training else 'dev epoch'

        running_loss = 0.0
        # make sure there are no leftover gradients before starting the epoch
        optimizer.zero_grad()
        # FIX: disable autograd entirely on the dev pass
        with torch.set_grad_enabled(training):
            for x, y in tqdm(dataloader, desc=desc):
                pred_y = model(x)            # forward pass
                loss = criterion(pred_y, y)
                # FIX: accumulate the detached scalar; summing the loss
                # *tensor* kept every batch's graph alive (memory leak)
                running_loss += loss.item()
                if training:
                    accelerator.backward(loss)  # accumulate gradients
                    optimizer.step()            # update weights
                    optimizer.zero_grad()       # reset gradients
        return running_loss / len(dataloader)

    for epoch in range(epochs):
        # periodic dev evaluation (before this epoch's training, as before)
        if (epoch + 1) % dev_after == 0:
            print(f"dev loss: {run(Stage.DEV):.3f}")
        print(f"train loss: {run(Stage.TRAIN):.3f}")

    torch.save(model.state_dict(), paths('model') / 'torch_clf.pth')
@click.command('validate')
def validate():
    """Evaluate the saved classifier and write diagnostic plots.

    Loads ``torch_clf.pth``, runs the whole dataset through the model, and
    saves a per-class output histogram plus a confusion matrix under
    ``paths('data') / 'runs'``.
    """
    from sklearn.metrics import ConfusionMatrixDisplay
    import matplotlib.pyplot as plt
    import seaborn as sns

    embedding_length = int(os.getenv('EMBEDDING_LENGTH', 384))
    model = Classifier(embedding_length=embedding_length, classes=5)
    model.load_state_dict(torch.load(paths('model') / 'torch_clf.pth'))
    model.eval()

    dataset = NewsDataset()
    # NOTE(review): assumes dataset[:] yields (features, labels) for the
    # full dataset — confirm against NewsDataset.__getitem__
    x, y = dataset[:]
    with torch.no_grad():
        out = model(torch.tensor(x))

    out_path = paths('data') / 'runs'
    # FIX: parents=True so a missing 'data' dir doesn't crash the run
    out_path.mkdir(parents=True, exist_ok=True)

    # distribution of raw model outputs, one hue per class column
    sns.histplot(pd.DataFrame(out.numpy()).melt(),
                 x='value', hue='variable', palette='rainbow')
    plt.savefig(out_path / 'label_hist.png')
    plt.close()

    y_pred = out.argmax(axis=1)
    fig, ax = plt.subplots(figsize=(10, 5))
    ConfusionMatrixDisplay.from_predictions(y, y_pred, ax=ax)
    # FIX: title said "kNN classifier" but this evaluates the torch Classifier
    ax.set(title="confusion matrix for torch classifier on test data.",
           xticklabels=ticklabels(), yticklabels=ticklabels())
    plt.savefig(out_path / 'confusion_matrix.png')
    plt.close()
    # FIX: removed leftover breakpoint() (halted non-interactive runs) and
    # the unused embeddings / embedding_ids loads