from torch import nn


class Runner:
    """Runner class in charge of routine training functions, such as running epochs or performing inference."""

    def __init__(self, train_set, train_loader, accelerator, model, optimizer):
        # Initialize class attributes
        self.accelerator = accelerator
        self.train_set = train_set

        # Prepare the optimizer, model, and train_loader so the accelerator
        # can automatically move them to the right device(s)
        self.optimizer, self.model, self.train_loader = accelerator.prepare(
            optimizer, model, train_loader
        )

        # Use Mean Squared Error loss, suited to continuous regression targets
        self.criterion = nn.MSELoss()

    def next(self):
        """Runs an epoch of training.

        Includes updating model weights and tracking training loss.

        Returns:
            float: The loss averaged over the entire epoch.
        """

        # Put the model in training mode (affects batchnorm and dropout)
        self.model.train()

        running_loss = 0.0

        # Make sure there are no leftover gradients before starting the epoch
        self.optimizer.zero_grad()

        for sample, target in self.train_loader:
            prediction = self.model(sample)  # Forward pass through the model
            loss = self.criterion(prediction, target)  # Error calculation

            # Track the loss as a plain float; accumulating the tensor itself
            # would keep every batch's autograd graph alive
            running_loss += loss.item()

            # Populate gradients by sending the loss backwards
            self.accelerator.backward(loss)
            self.optimizer.step()  # Update model weights
            self.optimizer.zero_grad()  # Reset gradients to 0

        # Average the loss over the number of batches in the epoch
        avg_loss = running_loss / len(self.train_loader)
        return avg_loss
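

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): shows one way Runner
# might be driven with Hugging Face accelerate. The toy linear model, random
# regression data, and hyperparameters below are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from accelerate import Accelerator

    # Toy regression data: 10-feature inputs mapped to scalar targets
    inputs = torch.randn(256, 10)
    targets = torch.randn(256, 1)
    train_set = TensorDataset(inputs, targets)
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True)

    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    runner = Runner(train_set, train_loader, Accelerator(), model, optimizer)
    for epoch in range(5):
        print(f"epoch {epoch}: avg loss = {runner.next():.4f}")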