ml_pipeline/batch.py

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from data import FashionDataset
from utils import Stage
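
# `Stage` is imported from utils.py, which is not shown here. A minimal sketch
# of what it presumably looks like (an assumption, not the actual utils code):
#
#   from enum import Enum, auto
#
#   class Stage(Enum):
#       TRAIN = auto()
#       VAL = auto()
#       TEST = auto()
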

class Batch:
    def __init__(
        self,
        stage: Stage,
        model: nn.Module,
        device: torch.device,
        loader: DataLoader,
        optimizer: optim.Optimizer,
        criterion: nn.Module,
    ):
        """Run a model over a DataLoader for one pass of the given stage.

        Accumulates the total loss across batches in ``self.loss``.
        """
        self.stage = stage
        self.device = device
        self.model = model.to(device)
        self.loader = loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.loss = 0.0

    def run(self, desc):
        is_training = self.stage is Stage.TRAIN
        self.model.train(is_training)  # train()/eval() mode based on the stage
        self.loss = 0.0  # reset the running loss for this pass over the loader
        with torch.set_grad_enabled(is_training):
            for x, y in tqdm(self.loader, desc=desc):
                loss = self._run_batch((x, y))
                if is_training:
                    self.optimizer.zero_grad()
                    loss.backward()  # backpropagate to accumulate gradients
                    self.optimizer.step()  # gradient update on the model weights
                self.loss += loss.item()

    def _run_batch(self, sample):
        true_x, true_y = sample
        true_x, true_y = true_x.to(self.device), true_y.to(self.device)
        pred_y = self.model(true_x)
        loss = self.criterion(pred_y, true_y)
        return loss
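
# `FashionDataset` comes from data.py, which is also not shown. main() assumes
# it is a torch Dataset whose __getitem__ yields an (image tensor of shape
# (1, 28, 28), integer class label) pair parsed from the Fashion-MNIST CSV
# (an assumption about the local module, not a description of its actual code).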


def main():
    # Placeholder model: a lone conv layer can't emit class logits, so pool
    # and project down to the 10 Fashion-MNIST classes.
    model = nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(64, 10),
    )
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=2e-4)
    path = "fashion-mnist_train.csv"
    dataset = FashionDataset(path)
    batch_size = 16
    num_workers = 1
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,  # shuffle training samples each pass
        num_workers=num_workers,
    )
    batch = Batch(
        Stage.TRAIN,
        device=torch.device("cpu"),
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loader=loader,
    )
    batch.run("train")


if __name__ == "__main__":
    main()
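
# A validation pass can reuse the same Batch class with a different stage.
# A sketch, assuming a held-out CSV (the filename here is hypothetical):
#
#   val_loader = DataLoader(FashionDataset("fashion-mnist_test.csv"), batch_size=16)
#   val_batch = Batch(Stage.VAL, model=model, device=torch.device("cpu"),
#                     loader=val_loader, optimizer=optimizer, criterion=criterion)
#   val_batch.run("val")  # runs in eval mode with gradients disabled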