rename batch to runner.

fill out makefile.
add dev pipeline.
gitignore data dir.
add logger.py.
fill out readme.md.
export env.yml.
Matt
2023-01-26 11:00:24 -08:00
parent 1f13224c4f
commit 0f12b26e40
8 changed files with 275 additions and 33 deletions


@@ -4,3 +4,5 @@ lr: 2e-4
 batch_size: 16
 num_workers: 0
 device: "cpu"
+epochs: 4
+dev_after: 20
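The two new keys drive the training loop below: epochs bounds the epoch loop and dev_after gates how often the dev runner fires. A minimal sketch of reading this file directly with OmegaConf (the file name here is an assumption; the pipeline itself goes through hydra):

from omegaconf import OmegaConf

# Hypothetical path; the repo's actual config location is not shown in this diff.
config = OmegaConf.load("config.yaml")
print(config.epochs, config.dev_after)  # 4 20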

src/logger.py Normal file

@@ -0,0 +1,111 @@
import random
import string
from typing import Protocol, Tuple, Optional

import torch
import wandb
import numpy as np
from PIL import Image
from einops import rearrange


class Logger(Protocol):
    def metrics(self, metrics: dict, epoch: int):
        """loss etc."""

    def hyperparameters(self, hyperparameters: dict):
        """model states"""

    def predictions(self, predictions: dict):
        """inference time stuff"""

    def images(self, images: np.ndarray):
        """log images"""
class WandbLogger:
    def __init__(self, project: str, entity: str, name: Optional[str], notes: str):
        self.project = project
        self.entity = entity
        self.notes = notes
        self.experiment = wandb.init(project=project, entity=entity, notes=notes)
        self.experiment.name = name
        # Everything to be logged this step accumulates here until flush().
        self.data_dict = {}

    def metrics(self, metrics: dict, epoch: int = None):
        """loss etc."""
        self.data_dict.update(metrics)

    def hyperparameters(self, hyperparameters: dict):
        """model states"""
        self.experiment.config.update(hyperparameters, allow_val_change=True)

    def predictions(self, predictions: dict):
        """inference time stuff"""

    def image(self, image: dict):
        """log images to wandb"""
        self.data_dict.update({"Generate Image": image})

    def video(self, images: np.ndarray, title: str):
        """log a stack of image frames to wandb as a video"""
        # wandb.Video expects (batch, time, channel, height, width) for batched input.
        images = np.uint8(rearrange(images, "t b c h w -> b t c h w"))
        self.data_dict.update({f"{title}": wandb.Video(images, fps=20)})

    def flush(self):
        # Send everything accumulated since the last flush as one log step.
        self.experiment.log(self.data_dict)
        self.data_dict = {}
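WandbLogger buffers everything in data_dict and only calls wandb.log on flush(), so metrics logged over the course of one epoch land in a single wandb step. A minimal usage sketch (the project/entity/name values are placeholders):

logger = WandbLogger(project="mnist", entity="me", name="run-1", notes="")
logger.metrics({"train/loss": 0.42}, epoch=0)
logger.metrics({"dev/loss": 0.57}, epoch=0)
logger.flush()  # a single wandb.log call for the whole step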
class DebugLogger:
    def __init__(self, project: str, entity: str, name: str, notes: str):
        self.project = project
        self.entity = entity
        self.name = name
        self.notes = notes

    def metrics(self, metrics: dict, epoch: int = None):
        """loss etc."""
        print(f"metrics: {metrics}")

    def hyperparameters(self, hyperparameters: dict):
        """model states"""
        print(f"hyperparameters: {hyperparameters}")

    def predictions(self, predictions: dict):
        """inference time stuff"""
class Checkpoint:
    def __init__(self, checkpoint_path):
        self.checkpoint_path = checkpoint_path

    def load(self) -> Tuple:
        checkpoint = torch.load(self.checkpoint_path)
        model = checkpoint["model"]
        optimizer = checkpoint["optimizer"]
        epoch = checkpoint["epoch"]
        loss = checkpoint["loss"]
        return (model, optimizer, epoch, loss)

    def save(self, model: torch.nn.Module, optimizer, epoch, loss):
        checkpoint = {
            "model": model,
            "optimizer": optimizer,
            "epoch": epoch,
            "loss": loss,
        }
        # Save under a random name so successive checkpoints don't clobber each other.
        name = "".join(random.choices(string.ascii_letters, k=10)) + ".tar"
        torch.save(checkpoint, name)
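Because save() serializes the whole model and optimizer objects rather than their state_dicts, load() hands back fully constructed instances. A minimal round-trip sketch (the saved file name is random, so the path passed to load below is illustrative):

model = torch.nn.Linear(784, 10)
optimizer = torch.optim.Adam(model.parameters())
Checkpoint("unused.tar").save(model, optimizer, epoch=3, loss=0.21)
# Later, pointing at whatever random .tar the save produced:
model, optimizer, epoch, loss = Checkpoint("AbCdEfGhIj.tar").load()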


@@ -3,7 +3,17 @@ main class for building a DL pipeline.
 """
-from batch import Batch
+"""
+the main entry point for training a model
+coordinates:
+- datasets
+- dataloaders
+- runner
+"""
+from runner import Runner
 from model.linear import DNN
 from model.cnn import VGG16, VGG11
 from data import MnistDataset
@@ -11,7 +21,6 @@ from utils import Stage
 import torch
 from pathlib import Path
-from collate import channel_to_batch
 import hydra
 from omegaconf import DictConfig
@@ -24,13 +33,24 @@ def train(config: DictConfig):
     batch_size = config.batch_size
     num_workers = config.num_workers
     device = config.device
+    epochs = config.epochs
-    path = Path(config.app_dir) / "storage/mnist_train.csv"
-    trainset = MnistDataset(path=path)
+    train_path = Path(config.app_dir) / "data/mnist_train.csv"
+    trainset = MnistDataset(path=train_path)
+    dev_path = Path(config.app_dir) / "data/mnist_test.csv"
+    devset = MnistDataset(path=dev_path)
     trainloader = torch.utils.data.DataLoader(
         trainset,
         batch_size=batch_size,
         shuffle=True,
         num_workers=num_workers,
         # collate_fn=channel_to_batch,
     )
+    devloader = torch.utils.data.DataLoader(
+        devset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=num_workers,
+        # collate_fn=channel_to_batch,
@@ -38,7 +58,7 @@ def train(config: DictConfig):
     model = VGG11(in_channels=1, num_classes=10)
     criterion = torch.nn.CrossEntropyLoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
-    batch = Batch(
+    train_runner = Runner(
         stage=Stage.TRAIN,
         model=model,
         device=torch.device(device),
@@ -47,10 +67,22 @@ def train(config: DictConfig):
         optimizer=optimizer,
         config=config,
     )
-    log = batch.run(
-        "Run run run run. Run run run away. Oh Oh oH OHHHHHHH yayayayayayayayaya! - David Byrne"
+    dev_runner = Runner(
+        stage=Stage.DEV,
+        model=model,
+        device=torch.device(device),
+        loader=devloader,
+        criterion=criterion,
+        optimizer=optimizer,
+        config=config,
     )
+    for epoch in range(epochs):
+        if epoch % config.dev_after == 0:
+            dev_log = dev_runner.run("dev epoch")
+        else:
+            train_log = train_runner.run("train epoch")


 if __name__ == "__main__":
     train()
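With the config values above (epochs: 4, dev_after: 20), only epoch 0 satisfies epoch % config.dev_after == 0, so this loop runs a single dev pass and then trains on the remaining epochs; note that a dev epoch replaces, rather than follows, a train epoch. A standalone sketch of the gating:

epochs, dev_after = 4, 20
for epoch in range(epochs):
    if epoch % dev_after == 0:
        print(f"epoch {epoch}: dev")    # epoch 0 only
    else:
        print(f"epoch {epoch}: train")  # epochs 1-3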


@@ -1,3 +1,6 @@
+"""
+runner for training and validating
+"""
 import torch
 from torch import nn
 from torch import optim
@@ -8,7 +11,7 @@ from utils import Stage
 from omegaconf import DictConfig


-class Batch:
+class Runner:
     def __init__(
         self,
         stage: Stage,
@@ -34,8 +37,7 @@ class Batch:
         self.model.train()
         if self.config.debug:
             breakpoint()
-        epoch = 0
-        for epoch, (x, y) in enumerate(tqdm(self.loader, desc=desc)):
+        for batch, (x, y) in enumerate(tqdm(self.loader, desc=desc)):
            self.optimizer.zero_grad()
            loss = self._run_batch((x, y))
            loss.backward()  # Send loss backwards to accumulate gradients
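_run_batch is not part of this hunk; given that the loop hands it an (x, y) pair and backpropagates the returned loss, a plausible sketch (an assumption, not the file's actual body) is:

def _run_batch(self, batch):
    x, y = batch
    x, y = x.to(self.device), y.to(self.device)  # move data to the model's device
    prediction = self.model(x)                   # forward pass
    loss = self.criterion(prediction, y)
    return loss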