rename batch to runner.
fill out makefile. add dev pipeline. gitignore data dir. add logger.py. fill out readme.md. export env.yml.
This commit is contained in:
@@ -4,3 +4,5 @@ lr: 2e-4
|
||||
batch_size: 16
|
||||
num_workers: 0
|
||||
device: "cpu"
|
||||
epochs: 4
|
||||
dev_after: 20
|
||||
|
||||
111
src/logger.py
Normal file
111
src/logger.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from tkinter import W
|
||||
import torch
|
||||
import wandb
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from einops import rearrange
|
||||
from typing import Protocol, Tuple, Optional
|
||||
|
||||
|
||||
class Logger(Protocol):
    """Structural interface every experiment logger must satisfy.

    `epoch` is optional here because concrete loggers (WandbLogger,
    DebugLogger) accept metrics with or without an epoch number.
    """

    def metrics(self, metrics: dict, epoch: Optional[int] = None):
        """Log scalar training/eval metrics (loss etc.)."""
        ...

    def hyperparameters(self, hyperparameters: dict):
        """Log the run configuration / hyperparameters."""
        ...

    def predictions(self, predictions: dict):
        """Log inference-time outputs."""
        ...

    def images(self, images: np.ndarray):
        """Log image arrays."""
        ...
|
||||
|
||||
|
||||
class WandbLogger:
    """Logger that buffers data and flushes it to a wandb run.

    Metrics, images and videos accumulate in ``self.data_dict`` until
    ``flush()`` sends them in a single ``wandb.log`` call, so everything
    buffered since the last flush shares one wandb step.
    """

    def __init__(self, project: str, entity: str, name: Optional[str], notes: str):
        self.project = project
        self.entity = entity
        self.notes = notes
        self.experiment = wandb.init(project=project, entity=entity, notes=notes)
        # Only override the auto-generated run name when one was supplied;
        # previously a None name clobbered wandb's generated name.
        if name is not None:
            self.experiment.name = name

        # Pending key/value pairs; emptied by flush().
        self.data_dict = {}

    def metrics(self, metrics: dict, epoch: Optional[int] = None):
        """Buffer scalar metrics (loss etc.).

        `epoch` is accepted to satisfy the Logger protocol; it is
        currently unused by this implementation.
        """
        self.data_dict.update(metrics)

    def hyperparameters(self, hyperparameters: dict):
        """Record the run configuration on the wandb run."""
        self.experiment.config.update(hyperparameters, allow_val_change=True)

    def predictions(self, predictions: dict):
        """Log inference-time outputs (not implemented yet)."""

    def image(self, image: dict):
        """Buffer a single generated image under the 'Generate Image' key."""
        self.data_dict.update({'Generate Image' : image})

    def video(self, images: np.ndarray, title: str):
        """Buffer a video built from a stack of frames.

        `images` is assumed to be shaped (t, b, c, h, w) — TODO confirm
        against callers; wandb.Video expects (batch, time, c, h, w).
        """
        images = np.uint8(rearrange(images, 't b c h w -> b t c h w'))
        self.data_dict.update({f"{title}": wandb.Video(images, fps=20)})

    def flush(self):
        """Send everything buffered so far to wandb and clear the buffer."""
        self.experiment.log(self.data_dict)
        self.data_dict = {}
|
||||
|
||||
|
||||
class DebugLogger:
    """Stdout-only Logger implementation for local debugging (no wandb).

    Mirrors the WandbLogger constructor and surface so it can be swapped
    in without touching call sites.
    """

    def __init__(self, project: str, entity: str, name: str, notes: str):
        # Stored only so the constructor is a drop-in for WandbLogger.
        self.project = project
        self.entity = entity
        self.name = name
        self.notes = notes

    def metrics(self, metrics: dict, epoch: Optional[int] = None):
        """Print scalar metrics (loss etc.)."""
        print(f"metrics: {metrics}")

    def hyperparameters(self, hyperparameters: dict):
        """Print the run configuration."""
        print(f"hyperparameters: {hyperparameters}")

    def predictions(self, predictions: dict):
        """Print inference-time outputs (previously dropped silently)."""
        print(f"predictions: {predictions}")

    def image(self, image):
        """Note that an image was produced without rendering it."""
        print("image: <logged>")

    def flush(self):
        """No buffering in debug mode; nothing to do."""
|
||||
|
||||
|
||||
class Checkpoint:
    """Load/save training state (model, optimizer, epoch, loss) at one path."""

    def __init__(self, checkpoint_path):
        # Single path used by BOTH load() and save(); previously save()
        # wrote to a random 10-letter filename, so a saved checkpoint
        # could never be found by load().
        self.checkpoint_path = checkpoint_path

    def load(self) -> Tuple:
        """Read the checkpoint and return (model, optimizer, epoch, loss).

        map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only
        machine. weights_only=False keeps the original full-object pickle
        behavior (torch >= 2.6 defaults to weights_only=True); only load
        checkpoints you trust — this unpickles arbitrary objects.
        """
        checkpoint = torch.load(
            self.checkpoint_path, map_location="cpu", weights_only=False
        )
        return (
            checkpoint["model"],
            checkpoint["optimizer"],
            checkpoint["epoch"],
            checkpoint["loss"],
        )

    def save(self, model: torch.nn.Module, optimizer, epoch, loss):
        """Write (model, optimizer, epoch, loss) to self.checkpoint_path."""
        checkpoint = {
            "model": model,
            "optimizer": optimizer,
            "epoch": epoch,
            "loss": loss,
        }
        torch.save(checkpoint, self.checkpoint_path)
|
||||
|
||||
|
||||
@@ -3,7 +3,17 @@ main class for building a DL pipeline.
|
||||
|
||||
"""
|
||||
|
||||
from batch import Batch
|
||||
"""
|
||||
the main entry point for training a model
|
||||
|
||||
coordinates:
|
||||
|
||||
- datasets
|
||||
- dataloaders
|
||||
- runner
|
||||
|
||||
"""
|
||||
from runner import Runner
|
||||
from model.linear import DNN
|
||||
from model.cnn import VGG16, VGG11
|
||||
from data import MnistDataset
|
||||
@@ -11,7 +21,6 @@ from utils import Stage
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from collate import channel_to_batch
|
||||
|
||||
import hydra
|
||||
from omegaconf import DictConfig
|
||||
|
||||
@@ -24,13 +33,24 @@ def train(config: DictConfig):
|
||||
batch_size = config.batch_size
|
||||
num_workers = config.num_workers
|
||||
device = config.device
|
||||
epochs = config.epochs
|
||||
|
||||
path = Path(config.app_dir) / "storage/mnist_train.csv"
|
||||
trainset = MnistDataset(path=path)
|
||||
train_path = Path(config.app_dir) / "data/mnist_train.csv"
|
||||
trainset = MnistDataset(path=train_path)
|
||||
|
||||
dev_path = Path(config.app_dir) / "data/mnist_test.csv"
|
||||
devset = MnistDataset(path=dev_path)
|
||||
|
||||
trainloader = torch.utils.data.DataLoader(
|
||||
trainset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=num_workers,
|
||||
# collate_fn=channel_to_batch,
|
||||
)
|
||||
devloader = torch.utils.data.DataLoader(
|
||||
devset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
# collate_fn=channel_to_batch,
|
||||
@@ -38,7 +58,7 @@ def train(config: DictConfig):
|
||||
model = VGG11(in_channels=1, num_classes=10)
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
||||
batch = Batch(
|
||||
train_runner = Runner(
|
||||
stage=Stage.TRAIN,
|
||||
model=model,
|
||||
device=torch.device(device),
|
||||
@@ -47,10 +67,22 @@ def train(config: DictConfig):
|
||||
optimizer=optimizer,
|
||||
config=config,
|
||||
)
|
||||
log = batch.run(
|
||||
"Run run run run. Run run run away. Oh Oh oH OHHHHHHH yayayayayayayayaya! - David Byrne"
|
||||
dev_runner = Runner(
|
||||
stage=Stage.DEV,
|
||||
model=model,
|
||||
device=torch.device(device),
|
||||
loader=devloader,
|
||||
criterion=criterion,
|
||||
optimizer=optimizer,
|
||||
config=config,
|
||||
)
|
||||
|
||||
for epoch in range(epochs):
|
||||
if epoch % config.dev_after == 0:
|
||||
dev_log = dev_runner.run("dev epoch")
|
||||
else:
|
||||
train_log = train_runner.run("train epoch")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
"""
|
||||
runner for training and validating
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import optim
|
||||
@@ -8,7 +11,7 @@ from utils import Stage
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
class Batch:
|
||||
class Runner:
|
||||
def __init__(
|
||||
self,
|
||||
stage: Stage,
|
||||
@@ -34,8 +37,7 @@ class Batch:
|
||||
self.model.train()
|
||||
if self.config.debug:
|
||||
breakpoint()
|
||||
epoch = 0
|
||||
for epoch, (x, y) in enumerate(tqdm(self.loader, desc=desc)):
|
||||
for batch, (x, y) in enumerate(tqdm(self.loader, desc=desc)):
|
||||
self.optimizer.zero_grad()
|
||||
loss = self._run_batch((x, y))
|
||||
loss.backward() # Send loss backwards to accumulate gradients
|
||||
Reference in New Issue
Block a user