119 lines
4.3 KiB
Python
119 lines
4.3 KiB
Python
import click
|
|
from tqdm import tqdm
|
|
from enum import Enum, auto
|
|
from dotenv import load_dotenv
|
|
import os
|
|
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
from accelerate import Accelerator
|
|
|
|
from train.dataset import NewsDataset
|
|
from train.model import Classifier
|
|
from data.main import paths, connect, ticklabels
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
class Stage(Enum):
    """Phase of a pass over the data.

    ``main`` uses this to pick between the training pass and the dev
    (evaluation) pass.
    """

    # Values are assigned automatically in declaration order (1, 2).
    TRAIN = auto()
    DEV = auto()
|
|
|
|
@click.command('main')
@click.option('--epochs', default=10, type=int)
def main(epochs):
    """Train the news classifier and save its weights.

    The dataset is split 80/20 into train and dev subsets under a fixed seed.
    Each epoch runs one training pass; every ``dev_after``-th epoch is
    preceded by an evaluation pass over the dev subset. After the last epoch
    the weights are written to ``paths('model') / 'torch_clf.pth'``.

    Args:
        epochs: Number of training epochs (``--epochs``, default 10).

    Environment variables:
        NUMBER_OF_WORKERS: DataLoader worker processes (default 0).
        EMBEDDING_LENGTH: ``embedding_length`` given to ``Classifier``
            (default 384).
        BATCH_SIZE: Batch size for both loaders (default 512).
    """
    dev_after = 5
    visible_devices = None
    lr = 1e-4
    debug = False

    # It's possible to control which GPUs the process can see using an
    # environment variable. This has to happen before the Accelerator or the
    # model is created, otherwise the setting may be read too late.
    if visible_devices:
        os.environ['CUDA_VISIBLE_DEVICES'] = visible_devices
    if debug:
        os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    torch.manual_seed(0)

    # NOTE(review): this module imports `load_dotenv` but never calls it here;
    # confirm that the .env file is loaded elsewhere before these lookups.
    num_workers = int(os.getenv('NUMBER_OF_WORKERS', 0))
    embedding_length = int(os.getenv('EMBEDDING_LENGTH', 384))
    batch_size = int(os.getenv('BATCH_SIZE', 512))

    dataset = NewsDataset()
    trainset, devset = torch.utils.data.random_split(dataset, [0.8, 0.2])
    trainloader = DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )
    # Evaluate in batches too; the loss below is weighted per sample, so the
    # reported value does not depend on the batch size.
    devloader = DataLoader(
        devset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    accelerator = Accelerator()
    model = Classifier(embedding_length=embedding_length, classes=5)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Wrap objects with accelerate.
    model, optimizer, trainloader, devloader = accelerator.prepare(
        model, optimizer, trainloader, devloader
    )

    def run(stage):
        """Run one pass for ``stage`` and return the mean per-sample loss.

        Weights are only updated when ``stage`` is ``Stage.TRAIN``. Returns
        NaN when the loader yields no batches.
        """
        training = stage == Stage.TRAIN
        model.train(training)
        dataloader = trainloader if training else devloader
        desc = 'train epoch' if training else 'dev epoch'

        loss_sum = 0.0
        samples = 0

        # Make sure there are no leftover gradients before starting an epoch.
        optimizer.zero_grad()

        # No autograd bookkeeping is needed for the dev pass.
        with torch.set_grad_enabled(training):
            for x, y in tqdm(dataloader, desc=desc):
                pred_y = model(x)  # Forward pass through model
                loss = criterion(pred_y, y)

                # Only update model weights on training.
                if training:
                    accelerator.backward(loss)
                    optimizer.step()
                    optimizer.zero_grad()

                # Detach so the running sum does not keep every batch's
                # computation graph alive.
                batch_len = len(y)
                loss_sum += loss.detach() * batch_len
                samples += batch_len

        if samples == 0:
            return float('nan')
        return float(loss_sum / samples)

    for epoch in range(epochs):
        if (epoch + 1) % dev_after == 0:
            print(f"dev loss: {run(Stage.DEV):.3f}")
        print(f"train loss: {run(Stage.TRAIN):.3f}")

    # Save the unwrapped module so the checkpoint keys match a plain
    # ``Classifier``, and write the file from one process only.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        state = accelerator.unwrap_model(model).state_dict()
        torch.save(state, paths('model') / 'torch_clf.pth')
|
|
|
|
@click.command('validate')
def validate():
    """Plot diagnostics for the saved classifier over the whole dataset.

    Loads the weights from ``paths('model') / 'torch_clf.pth'``, runs the
    model on every sample of ``NewsDataset`` and writes two figures into
    ``paths('data') / 'runs'``:

    * ``label_hist.png`` - histogram of the raw model outputs, one hue per
      output column.
    * ``confusion_matrix.png`` - confusion matrix of the dataset labels
      against the arg-max prediction.
    """
    # Imported locally, so these packages are only required when this
    # command is actually run.
    from sklearn.metrics import ConfusionMatrixDisplay
    import matplotlib.pyplot as plt
    import seaborn as sns

    embedding_length = int(os.getenv('EMBEDDING_LENGTH', 384))
    model = Classifier(embedding_length=embedding_length, classes=5)
    # map_location lets weights saved from a GPU run load on a CPU-only host.
    weights = torch.load(paths('model') / 'torch_clf.pth', map_location='cpu')
    model.load_state_dict(weights)
    model.eval()

    # Materialize the dataset once: element 0 is fed to the model, element 1
    # is used as the reference labels.
    samples = NewsDataset()[:]
    features, y = samples[0], samples[1]

    with torch.no_grad():
        out = model(torch.as_tensor(features))

    # Hand plain NumPy arrays to pandas / scikit-learn.
    scores = out.detach().cpu().numpy()

    sns.histplot(pd.DataFrame(scores).melt(), x='value', hue='variable', palette='rainbow')
    out_path = paths('data') / 'runs'
    out_path.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path / 'label_hist.png')
    plt.close()

    y_pred = scores.argmax(axis=1)
    _, ax = plt.subplots(figsize=(10, 5))
    ConfusionMatrixDisplay.from_predictions(y, y_pred, ax=ax)
    ax.set(
        title="confusion matrix for torch classifier on the full dataset.",
        xticklabels=ticklabels(),
        yticklabels=ticklabels(),
    )
    plt.savefig(out_path / 'confusion_matrix.png')
    plt.close()
|