ml_pipeline_cookiecutter/{{cookiecutter.project_name}}/{{cookiecutter.module_name}}/training/runner.py

47 lines
1.7 KiB
Python
Raw Normal View History

2024-04-06 13:02:31 -07:00
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Optimizer
class Runner:
    """Bundle a model, optimizer and dataloader and run training epochs.

    Attributes:
        dataset: The dataset handed to the constructor. It is only stored;
            nothing in this class reads it.
        dataloader: Iterable producing ``(sample, target)`` batches.
        model: The module being trained.
        optimizer: Optimizer that is stepped once per batch.
        criterion: Loss function, ``nn.CrossEntropyLoss()`` with its defaults.
    """

    def __init__(
        self,
        dataset: Dataset,
        dataloader: DataLoader,
        model: nn.Module,
        optimizer: Optimizer,
    ) -> None:
        """Store the training components and create the loss function.

        Args:
            dataset: Dataset backing ``dataloader``; kept as an attribute only.
            dataloader: Iterable of ``(sample, target)`` batches.
            model: Module to train.
            optimizer: Optimizer used to update the model.
        """
        self.dataset = dataset
        self.optimizer = optimizer
        self.model = model
        self.dataloader = dataloader

        # Classification loss; PyTorch's default reduction is "mean", i.e.
        # each call returns the average loss over the samples in the batch.
        self.criterion = nn.CrossEntropyLoss()

    def step(self):
        """Run one training epoch, yielding the running average loss.

        This is a generator: no work happens until it is iterated, and the
        epoch is only complete once it has been exhausted. After every batch
        the model weights are updated and the per-sample average loss over
        all batches processed so far is yielded, so the last yielded value is
        the average over the whole epoch. An empty dataloader yields nothing.

        Yields:
            torch.Tensor: Scalar tensor, detached from the autograd graph,
            holding the running per-sample average loss.
        """
        # turn the model to training mode (affects batchnorm and dropout)
        self.model.train()

        # A mean-reduced loss must be weighted by its batch size before it is
        # summed, otherwise dividing by the sample count under-reports the
        # loss by roughly a factor of the batch size.
        mean_reduced = getattr(self.criterion, "reduction", "mean") == "mean"

        total_loss, total_samples = 0.0, 0
        for sample, target in self.dataloader:
            self.optimizer.zero_grad()  # reset gradients to 0
            prediction = self.model(sample)  # forward pass through model
            loss = self.criterion(prediction, target)  # error calculation

            # increment gradients within model by sending loss backwards
            loss.backward()
            self.optimizer.step()  # update model weights

            batch_size = len(sample)
            # Detach before accumulating so the running total does not keep
            # every batch's autograd history alive for the whole epoch.
            batch_loss = loss.detach()
            if mean_reduced:
                batch_loss = batch_loss * batch_size

            total_loss += batch_loss  # increment running loss
            total_samples += batch_size

            # average of the loss over every sample seen so far
            yield total_loss / total_samples