from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Optimizer


class Runner:
    """Runner that implements routine training functions such as running epochs or performing inference."""

    def __init__(self, dataset: Dataset, dataloader: DataLoader, model: nn.Module, optimizer: Optimizer):
        # Initialize class attributes
        self.dataset = dataset

        # Store the optimizer, model, and dataloader
        self.optimizer, self.model, self.dataloader = (
            optimizer, model, dataloader
        )

        # Targets are integer class labels, so use cross-entropy loss
        # self.criterion = nn.MSELoss()  # alternative for regression targets
        self.criterion = nn.CrossEntropyLoss()

    def step(self):
        """Runs one epoch of training.

        Updates the model weights and tracks the training loss.

        Yields:
            float: The running average of the per-sample loss over the epoch so far.
        """

        # Put the model in training mode (affects batch norm and dropout)
        self.model.train()

        total_loss, total_samples = 0.0, 0
        for sample, target in self.dataloader:
            self.optimizer.zero_grad()  # reset gradients to zero
            prediction = self.model(sample)  # forward pass through the model
            loss = self.criterion(prediction, target)  # error calculation

            # Accumulate gradients in the model by propagating the loss backwards
            loss.backward()
            self.optimizer.step()  # update the model weights

            # Weight the batch-mean loss by the batch size so the running average
            # is per sample; .item() detaches the value from the autograd graph
            total_loss += loss.item() * len(sample)
            total_samples += len(sample)
            yield total_loss / total_samples  # running per-sample average loss
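

# --- Usage sketch -------------------------------------------------------------
# A minimal, hypothetical example of driving Runner.step() as a generator.
# The toy dataset, model, and SGD hyperparameters below are assumptions made
# for illustration; they are not taken from this repository.
if __name__ == "__main__":
    import torch
    from torch.utils.data import TensorDataset

    # 64 random 10-feature samples with integer class labels in {0, 1, 2}
    features = torch.randn(64, 10)
    labels = torch.randint(0, 3, (64,))
    dataset = TensorDataset(features, labels)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

    model = nn.Linear(10, 3)  # produces logits for 3 classes, as CrossEntropyLoss expects
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    runner = Runner(dataset, dataloader, model, optimizer)
    for running_avg in runner.step():  # one epoch; yields the running average loss
        print(f"running average loss: {running_avg:.4f}")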