ml_pipeline_cookiecutter/{{cookiecutter.project_name}}/{{cookiecutter.module_name}}/training/pipeline.py

56 lines
2.0 KiB
Python
Raw Permalink Normal View History

2024-04-06 13:02:31 -07:00
from torch.utils.data import DataLoader
from torch.optim import AdamW
from {{cookiecutter.module_name}}.training.runner import Runner
from {{cookiecutter.module_name}} import config, logger
2024-04-06 13:02:31 -07:00
def run(evaluate=False):
    """Wire up data, model and optimizer, then iterate the runner for each epoch.

    Args:
        evaluate: Passed straight through to ``get_dataset`` to choose which
            dataset is loaded.

    Returns:
        None. Every item yielded by ``runner.step()`` is written to the logger.
    """
    # NOTE(review): the same optimisation loop runs whether or not ``evaluate``
    # is set; only the dataset selection changes. Confirm this is intended.
    data = get_dataset(evaluate)
    loader = DataLoader(
        data,
        batch_size=config.training.batch_size,
        shuffle=True,
    )

    net = get_model(name=config.model.name)
    optim = AdamW(net.parameters(), lr=config.training.learning_rate)

    # The runner receives every training component; this function only drives it.
    runner = Runner(dataset=data, dataloader=loader, model=net, optimizer=optim)

    epoch_count = config.training.epochs
    for _epoch in range(epoch_count):
        # One pass over whatever the runner yields, logging each item as-is.
        for step in runner.step():
            logger.info(f"{step}")
def get_model(name='vgg11'):
from {{cookiecutter.module_name}}.model.linear import DNN
from {{cookiecutter.module_name}}.model.cnn import VGG11
2024-04-06 13:02:31 -07:00
if name == 'vgg11':
return VGG11(config.data.in_channels, config.data.num_classes)
else:
# Create the model and optimizer and cast model to the appropriate GPU
in_features, out_features = dataset.in_out_features()
model = DNN(in_features, config.model.hidden_size, out_features)
return model.to(config.training.device)
def get_dataset(evaluate=False):
# Usage
from {{cookiecutter.module_name}}.data.dataset import MnistDataset
2024-04-06 13:02:31 -07:00
from torchvision import transforms
csv_file_path = config.data.train_path if not evaluate else config.data.test_path
transform = transforms.Compose([
transforms.ToTensor(), # Converts a PIL Image or numpy.ndarray to a FloatTensor and scales the image's pixel intensity values to the [0., 1.] range
transforms.Normalize((0.1307,), (0.3081,)) # Normalize using the mean and std specific to MNIST
])
dataset = MnistDataset(csv_file_path)
return dataset