init
This commit is contained in:
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Main module for building a DL pipeline.
|
||||
|
||||
"""
|
||||
from enum import Enum, auto
|
||||
|
||||
|
||||
class Stage(Enum):
    """Lifecycle phases of the pipeline: training, validation, and testing."""

    TRAIN = auto()
    DEV = auto()
    TEST = auto()
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim import AdamW
|
||||
from ml_pipeline.training.runner import Runner
|
||||
from ml_pipeline import config, logger
|
||||
|
||||
|
||||
def run(evaluate=False):
    """Assemble the pipeline (data, model, optimizer, runner) and train.

    Args:
        evaluate: When True, the dataset helper loads the test split
            instead of the training split.
    """
    # Dataset plus a shuffling loader to iterate over it in batches.
    train_data = get_dataset(evaluate)
    loader = DataLoader(
        train_data,
        batch_size=config.training.batch_size,
        shuffle=True,
    )

    net = get_model(name=config.model.name)
    opt = AdamW(net.parameters(), lr=config.training.learning_rate)

    # The runner owns the loop mechanics: forward, loss, backward, update.
    runner = Runner(
        dataset=train_data,
        dataloader=loader,
        model=net,
        optimizer=opt,
    )

    # Train for the configured number of epochs, logging each step's
    # running-average loss.
    for _epoch in range(config.training.epochs):
        for step in runner.step():
            logger.info(f"{step}")
|
||||
|
||||
def get_model(name='vgg11', dataset=None):
    """Build the requested model and move it to the configured device.

    Args:
        name: Architecture to build; 'vgg11' selects the CNN, anything
            else falls back to the fully-connected DNN.
        dataset: Dataset exposing ``in_out_features()``; required only for
            the DNN branch, ignored for 'vgg11'. (New, defaulted parameter —
            backward compatible with existing ``get_model(name=...)`` calls.)

    Returns:
        The instantiated model, cast to ``config.training.device``.

    Raises:
        ValueError: If the DNN branch is taken without a dataset.
    """
    from ml_pipeline.model.linear import DNN
    from ml_pipeline.model.cnn import VGG11

    if name == 'vgg11':
        model = VGG11(config.data.in_channels, config.data.num_classes)
    else:
        # BUG FIX: the original referenced an undefined global `dataset`
        # here, raising NameError for any non-vgg11 name. The feature sizes
        # now come from an explicit parameter.
        if dataset is None:
            raise ValueError("get_model requires `dataset` to size the DNN")
        in_features, out_features = dataset.in_out_features()
        model = DNN(in_features, config.model.hidden_size, out_features)

    # Cast to the configured device in BOTH branches — the original moved
    # only the DNN, leaving VGG11 wherever it was constructed.
    return model.to(config.training.device)
|
||||
|
||||
|
||||
def get_dataset(evaluate=False):
    """Load the MNIST CSV dataset for the requested split.

    Args:
        evaluate: When True, load the test CSV; otherwise the train CSV.

    Returns:
        An ``MnistDataset`` backed by the selected CSV file.
    """
    from ml_pipeline.data.dataset import MnistDataset

    # Choose the split from config based on the mode.
    csv_file_path = config.data.test_path if evaluate else config.data.train_path

    # NOTE(review): the original built a torchvision transform pipeline
    # (ToTensor + MNIST mean/std Normalize) but never passed it to
    # MnistDataset, so it was dead code; it has been removed. If the dataset
    # is meant to normalize its images, wire a transform through its
    # constructor instead — confirm against MnistDataset's signature.
    return MnistDataset(csv_file_path)
|
||||
@@ -0,0 +1,46 @@
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
class Runner:
    """Implements routine training work: iterating the dataloader, computing
    the loss, back-propagating, and updating model weights."""

    def __init__(self, dataset: Dataset, dataloader: DataLoader, model: nn.Module, optimizer: Optimizer):
        """Store the training components.

        Args:
            dataset: The underlying dataset (kept for reference).
            dataloader: Iterable of (sample, target) batches.
            model: The network to train.
            optimizer: Optimizer bound to ``model.parameters()``.
        """
        self.dataset = dataset
        self.dataloader = dataloader
        self.model = model
        self.optimizer = optimizer

        # Classification targets, so use cross-entropy rather than MSE.
        self.criterion = nn.CrossEntropyLoss()

    def step(self):
        """Run one epoch of training as a generator.

        Yields:
            float: The training loss averaged over all samples seen so far
            in this epoch (one value per batch).
        """
        # Training mode (affects batchnorm and dropout layers).
        self.model.train()

        total_loss, total_samples = 0.0, 0
        for sample, target in self.dataloader:
            self.optimizer.zero_grad()          # reset gradients to 0
            prediction = self.model(sample)     # forward pass
            loss = self.criterion(prediction, target)

            # Send loss backwards to populate gradients, then update weights.
            loss.backward()
            self.optimizer.step()

            batch_size = len(sample)
            # BUG FIX: the original accumulated the live loss *tensor*,
            # keeping the autograd graph of every batch alive for the whole
            # epoch, and divided a sum of per-batch MEAN losses by the total
            # SAMPLE count — a mis-weighted average. Accumulate the detached
            # scalar weighted by batch size so the yielded value is a true
            # per-sample running average.
            total_loss += loss.item() * batch_size
            total_samples += batch_size
            yield total_loss / total_samples
|
||||
Reference in New Issue
Block a user