ml_pipeline/pipeline.py

44 lines
997 B
Python

import click
from batch import Batch
from model.linear import DNN
from model.cnn import VGG16, VGG11
from data import FashionDataset
from utils import Stage
import torch
@click.group()
def cli():
    """Command-line entry point for the ML pipeline.

    The group does no work itself; it only collects the sub-commands
    registered on it with ``@cli.command()``.
    """
@cli.command()
@click.option(
    "--path",
    default="fashion-mnist_train.csv",
    show_default=True,
    help="File passed to FashionDataset as the training set.",
)
@click.option(
    "--batch-size",
    type=click.IntRange(min=1),
    default=16,
    show_default=True,
    help="Number of samples per mini-batch.",
)
@click.option(
    "--num-workers",
    type=click.IntRange(min=0),
    default=8,
    show_default=True,
    help="Number of DataLoader worker processes (0 loads in the main process).",
)
@click.option(
    "--lr",
    type=float,
    default=2e-4,
    show_default=True,
    help="Learning rate for the Adam optimizer.",
)
@click.option(
    "--shuffle/--no-shuffle",
    default=False,
    show_default=True,
    help="Reshuffle the training set each time the loader is iterated.",
)
def train(
    path: str,
    batch_size: int,
    num_workers: int,
    lr: float,
    shuffle: bool,
) -> None:
    """Train a VGG11 classifier on the dataset found at PATH.

    Every option defaults to the value that used to be hard-coded, so
    running the command without arguments behaves as before.
    """
    trainset = FashionDataset(path=path)
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

    # Single input channel, ten output classes.
    model = VGG11(in_channels=1, num_classes=10)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # The device stays on CPU: whether Batch moves the model and tensors to
    # the given device is not visible here, so it is not exposed as an option.
    batch = Batch(
        stage=Stage.TRAIN,
        model=model,
        device=torch.device("cpu"),
        loader=trainloader,
        criterion=criterion,
        optimizer=optimizer,
    )

    # NOTE(review): the string is presumably a progress/description label
    # consumed by Batch.run -- confirm against its definition.
    batch.run(
        "Run run run run. Run run run away. Oh Oh oH OHHHHHHH yayayayayayayayaya! - David Byrne"
    )
if __name__ == "__main__":
    # Only dispatch to the click group when executed as a script; importing
    # this module (e.g. from DataLoader worker processes) must not start the CLI.
    cli()