rename batch to runner.

fill out makefile.
add dev pipeline.
gitignore data dir.
add logger.py.
fill out readme.md.
export env.yml.
This commit is contained in:
Matt 2023-01-26 11:00:24 -08:00
parent 1f13224c4f
commit 0f12b26e40
8 changed files with 275 additions and 33 deletions

View File

@ -1,15 +1,25 @@
CONDA_ENV=ml_pipeline CONDA_ENV=ml_pipeline
.PHONY: help
all: run all: help
run: run: ## run the pipeline (train)
./launch.sh python src/pipeline.py \
debug=false
debug: ## run the pipeline (train) with debugging enabled
python src/pipeline.py \
debug=true
data: data: ## download the mnist data
python src/data.py wget https://pjreddie.com/media/files/mnist_train.csv -O data/mnist_train.csv
wget https://pjreddie.com/media/files/mnist_test.csv -O data/mnist_test.csv
batch: env_import: environment.yml ## import any changes to env.yml into conda env
python src/batch.py conda env update -n ${CONDA_ENV} --file $^
env_export: ## export the conda envirnoment without package or name
conda env export | head -n -1 | tail -n +2 > $@
help: ## display this help message
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
install:
conda env updates -n ${CONDA_ENV} --file environment.yml

View File

@ -40,14 +40,54 @@ Especially in a research lab, standardizing on a few components makes switching
In this template, you'll see the following: In this template, you'll see the following:
- `src/model`, `src/config`, `storage`, `test` dirs. ## directory structure
- `if __name__ == "__main__"` tests.
- `src/model`
- `src/config`
- `data/`
- `test/`
- pytest: unit testing.
- good for data shape
- TODO:
- `docs/`
- switching projects is easier with these in place
- organize them
- `**/__init__.py`
- creates modules out of dir.
- `import module` works with these.
- `README.md`
- root level required.
- can exist inside any dir.
- `environment.yml`
- `Makefile`
- to install and run stuff.
- houses common operations and scripts.
- `launch.sh`
- script to dispatch training.
## testing
- `if __name__ == "__main__"`.
- good way to test things
- enables lots breakpoints.
## config
- Hydra config. - Hydra config.
- dataloader, optimizer, criterion, device, state are constructed in main, but passed to an object that runs batches. - quickly experiment with hyperparameters
- tqdm to track progress. - good way to define env. variables
- debug config flag enables lots breakpoints. - lr, workers, batch_size
- python type hints. - debug
- a `launch.sh` script to dispatch training.
- a Makefile to install and run stuff. ## data
- automatic linting with the `black` package.
- collate functions! - collate functions!
## formatting python
- python type hints.
- automatic linting with the `black` package.
## running
- tqdm to track progress.
## architecture
- dataloader, optimizer, criterion, device, state are constructed in main, but passed to an object that runs batches.

1
data/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
*.csv

View File

@ -1,4 +1,3 @@
name: ml
channels: channels:
- pytorch - pytorch
- conda-forge - conda-forge
@ -6,30 +5,50 @@ channels:
dependencies: dependencies:
- _libgcc_mutex=0.1=conda_forge - _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu - _openmp_mutex=4.5=2_gnu
- accelerate=0.15.0=pyhd8ed1ab_0
- antlr-python-runtime=4.9.3=pyhd8ed1ab_1
- appdirs=1.4.4=pyh9f0ad1d_0
- astroid=2.11.7=py310h06a4308_0
- attrs=22.1.0=py310h06a4308_0
- black=22.6.0=py310h06a4308_0 - black=22.6.0=py310h06a4308_0
- blas=1.0=mkl - blas=1.0=mkl
- brotli=1.0.9=h5eee18b_7 - brotli=1.0.9=h5eee18b_7
- brotli-bin=1.0.9=h5eee18b_7 - brotli-bin=1.0.9=h5eee18b_7
- brotlipy=0.7.0=py310h5764c6d_1005
- bzip2=1.0.8=h7f98852_4 - bzip2=1.0.8=h7f98852_4
- ca-certificates=2022.10.11=h06a4308_0 - ca-certificates=2022.12.7=ha878542_0
- certifi=2022.12.7=pyhd8ed1ab_0
- cffi=1.15.1=py310h74dc2b5_0
- charset-normalizer=2.1.1=pyhd8ed1ab_0
- click=8.0.3=pyhd3eb1b0_0 - click=8.0.3=pyhd3eb1b0_0
- colorama=0.4.6=pyhd8ed1ab_0 - colorama=0.4.6=pyhd8ed1ab_0
- cryptography=39.0.0=py310h65dfdc0_0
- cycler=0.11.0=pyhd3eb1b0_0 - cycler=0.11.0=pyhd3eb1b0_0
- dbus=1.13.18=hb2f20db_0 - dbus=1.13.18=hb2f20db_0
- dill=0.3.6=pyhd8ed1ab_1
- docker-pycreds=0.4.0=py_0
- einops=0.4.1=pyhd8ed1ab_0 - einops=0.4.1=pyhd8ed1ab_0
- expat=2.4.9=h6a678d5_0 - expat=2.4.9=h6a678d5_0
- fontconfig=2.13.1=h6c09931_0 - fontconfig=2.13.1=h6c09931_0
- fonttools=4.25.0=pyhd3eb1b0_0 - fonttools=4.25.0=pyhd3eb1b0_0
- freetype=2.12.1=h4a9f257_0 - freetype=2.12.1=h4a9f257_0
- giflib=5.2.1=h7b6447c_0 - giflib=5.2.1=h7b6447c_0
- gitdb=4.0.10=pyhd8ed1ab_0
- gitpython=3.1.30=pyhd8ed1ab_0
- glib=2.69.1=h4ff587b_1 - glib=2.69.1=h4ff587b_1
- gst-plugins-base=1.14.0=h8213a91_2 - gst-plugins-base=1.14.0=h8213a91_2
- gstreamer=1.14.0=h28cd5cc_2 - gstreamer=1.14.0=h28cd5cc_2
- hydra-core=1.3.1=pyhd8ed1ab_0
- icu=58.2=he6710b0_3 - icu=58.2=he6710b0_3
- idna=3.4=pyhd8ed1ab_0
- importlib_resources=5.10.2=pyhd8ed1ab_0
- iniconfig=1.1.1=pyhd3eb1b0_0
- intel-openmp=2021.4.0=h06a4308_3561 - intel-openmp=2021.4.0=h06a4308_3561
- isort=5.9.3=pyhd3eb1b0_0
- jpeg=9e=h7f8727e_0 - jpeg=9e=h7f8727e_0
- kiwisolver=1.4.2=py310h295c915_0 - kiwisolver=1.4.2=py310h295c915_0
- krb5=1.19.2=hac12032_0 - krb5=1.19.2=hac12032_0
- lazy-object-proxy=1.6.0=py310h7f8727e_0
- lcms2=2.12=h3be6417_0 - lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.39=hc81fddc_0 - ld_impl_linux-64=2.39=hc81fddc_0
- lerc=3.0=h295c915_0 - lerc=3.0=h295c915_0
@ -50,6 +69,7 @@ dependencies:
- libopenblas=0.3.21=pthreads_h78a6416_3 - libopenblas=0.3.21=pthreads_h78a6416_3
- libpng=1.6.37=hbc83047_0 - libpng=1.6.37=hbc83047_0
- libpq=12.9=h16c4e8d_3 - libpq=12.9=h16c4e8d_3
- libprotobuf=3.20.1=h4ff587b_0
- libstdcxx-ng=12.2.0=h46fd767_19 - libstdcxx-ng=12.2.0=h46fd767_19
- libtiff=4.4.0=hecacb30_0 - libtiff=4.4.0=hecacb30_0
- libuuid=1.0.3=h7f8727e_2 - libuuid=1.0.3=h7f8727e_2
@ -62,6 +82,7 @@ dependencies:
- lz4-c=1.9.3=h295c915_1 - lz4-c=1.9.3=h295c915_1
- matplotlib=3.5.2=py310h06a4308_0 - matplotlib=3.5.2=py310h06a4308_0
- matplotlib-base=3.5.2=py310hf590b9c_0 - matplotlib-base=3.5.2=py310hf590b9c_0
- mccabe=0.7.0=pyhd3eb1b0_0
- mkl=2021.4.0=h06a4308_640 - mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py310h7f8727e_0 - mkl-service=2.4.0=py310h7f8727e_0
- mkl_fft=1.3.1=py310hd6ae3a3_0 - mkl_fft=1.3.1=py310hd6ae3a3_0
@ -73,39 +94,62 @@ dependencies:
- nss=3.74=h0370c37_0 - nss=3.74=h0370c37_0
- numpy=1.23.3=py310hd5efca6_0 - numpy=1.23.3=py310hd5efca6_0
- numpy-base=1.23.3=py310h8e6c178_0 - numpy-base=1.23.3=py310h8e6c178_0
- openssl=1.1.1q=h7f8727e_0 - omegaconf=2.3.0=pyhd8ed1ab_0
- openssl=1.1.1s=h0b41bf4_1
- packaging=21.3=pyhd3eb1b0_0 - packaging=21.3=pyhd3eb1b0_0
- pathspec=0.10.1=pyhd8ed1ab_0 - pathspec=0.10.1=pyhd8ed1ab_0
- pathtools=0.1.2=py_1
- pcre=8.45=h295c915_0 - pcre=8.45=h295c915_0
- pillow=9.2.0=py310hace64e9_1 - pillow=9.2.0=py310hace64e9_1
- pip=22.3=pyhd8ed1ab_0 - pip=22.3=pyhd8ed1ab_0
- platformdirs=2.5.2=pyhd8ed1ab_1 - platformdirs=2.5.2=pyhd8ed1ab_1
- pluggy=1.0.0=py310h06a4308_1
- ply=3.11=py310h06a4308_0 - ply=3.11=py310h06a4308_0
- protobuf=3.20.1=py310h295c915_0
- psutil=5.9.4=py310h5764c6d_0
- py=1.11.0=pyhd3eb1b0_0
- pycparser=2.21=pyhd8ed1ab_0
- pylint=2.14.5=py310h06a4308_0
- pyopenssl=23.0.0=pyhd8ed1ab_0
- pyparsing=3.0.9=py310h06a4308_0 - pyparsing=3.0.9=py310h06a4308_0
- pyqt=5.15.7=py310h6a678d5_1 - pyqt=5.15.7=py310h6a678d5_1
- pysocks=1.7.1=pyha2e5f31_6
- pytest=7.1.2=py310h06a4308_0
- python=3.10.6=haa1d7c7_1 - python=3.10.6=haa1d7c7_1
- python-dateutil=2.8.2=pyhd3eb1b0_0 - python-dateutil=2.8.2=pyhd3eb1b0_0
- python-dotenv=0.21.0=py310h06a4308_0
- python_abi=3.10=2_cp310
- pytorch=1.13.0=py3.10_cpu_0 - pytorch=1.13.0=py3.10_cpu_0
- pytorch-mutex=1.0=cpu - pytorch-mutex=1.0=cpu
- pyyaml=6.0=py310h5764c6d_5
- qt-main=5.15.2=h327a75a_7 - qt-main=5.15.2=h327a75a_7
- qt-webengine=5.15.9=hd2b0992_4 - qt-webengine=5.15.9=hd2b0992_4
- qtwebkit=5.212=h4eab89a_4 - qtwebkit=5.212=h4eab89a_4
- readline=8.1.2=h0f457ee_0 - readline=8.1.2=h0f457ee_0
- requests=2.28.2=pyhd8ed1ab_0
- sentry-sdk=1.14.0=pyhd8ed1ab_0
- setproctitle=1.3.2=py310h5764c6d_1
- setuptools=65.5.0=pyhd8ed1ab_0 - setuptools=65.5.0=pyhd8ed1ab_0
- sip=6.6.2=py310h6a678d5_0 - sip=6.6.2=py310h6a678d5_0
- six=1.16.0=pyhd3eb1b0_1 - six=1.16.0=pyhd3eb1b0_1
- smmap=3.0.5=pyh44b312d_0
- sqlite=3.39.3=h5082296_0 - sqlite=3.39.3=h5082296_0
- tk=8.6.12=h1ccaba5_0 - tk=8.6.12=h1ccaba5_0
- toml=0.10.2=pyhd3eb1b0_0 - toml=0.10.2=pyhd3eb1b0_0
- tomli=2.0.1=py310h06a4308_0 - tomli=2.0.1=py310h06a4308_0
- tomlkit=0.11.1=py310h06a4308_0
- tornado=6.2=py310h5eee18b_0 - tornado=6.2=py310h5eee18b_0
- tqdm=4.64.1=pyhd8ed1ab_0 - tqdm=4.64.1=pyhd8ed1ab_0
- typing_extensions=4.3.0=py310h06a4308_0 - typing_extensions=4.3.0=py310h06a4308_0
- tzdata=2022e=h191b570_0 - tzdata=2022e=h191b570_0
- urllib3=1.26.14=pyhd8ed1ab_0
- wandb=0.13.9=pyhd8ed1ab_0
- wheel=0.37.1=pyhd8ed1ab_0 - wheel=0.37.1=pyhd8ed1ab_0
- wrapt=1.14.1=py310h5eee18b_0
- xz=5.2.6=h166bdaf_0 - xz=5.2.6=h166bdaf_0
- yaml=0.2.5=h7f98852_2
- zipp=3.11.0=pyhd8ed1ab_0
- zlib=1.2.13=h5eee18b_0 - zlib=1.2.13=h5eee18b_0
- zstd=1.5.2=ha4553b6_0 - zstd=1.5.2=ha4553b6_0
- pip: - pip:
- pyqt5-sip==12.11.0 - pyqt5-sip==12.11.0
prefix: /home/personal/Dev/conda/envs/ml

View File

@ -4,3 +4,5 @@ lr: 2e-4
batch_size: 16 batch_size: 16
num_workers: 0 num_workers: 0
device: "cpu" device: "cpu"
epochs: 4
dev_after: 20

111
src/logger.py Normal file
View File

@ -0,0 +1,111 @@
from tkinter import W
import torch
import wandb
import numpy as np
from PIL import Image
from einops import rearrange
from typing import Protocol, Tuple, Optional
class Logger(Protocol):
def metrics(self, metrics: dict, epoch: int):
"""loss etc."""
def hyperparameters(self, hyperparameters: dict):
"""model states"""
def predictions(self, predictions: dict):
"""inference time stuff"""
def images(self, images: np.ndarray):
"""log images"""
class WandbLogger:
def __init__(self, project: str, entity: str, name: Optional[str], notes: str):
self.project = project
self.entity = entity
self.notes = notes
self.experiment = wandb.init(project=project, entity=entity, notes=notes)
self.experiment.name = name
self.data_dict = {}
def metrics(self, metrics: dict):
"""loss etc."""
self.data_dict.update(metrics)
def hyperparameters(self, hyperparameters: dict):
"""model states"""
self.experiment.config.update(hyperparameters, allow_val_change=True)
def predictions(self, predictions: dict):
"""inference time stuff"""
def image(self, image: dict):
"""log images to wandb"""
self.data_dict.update({'Generate Image' : image})
def video(self, images: str, title: str):
"""log images to wandb"""
images = np.uint8(rearrange(images, 't b c h w -> b t c h w'))
self.data_dict.update({f"{title}": wandb.Video(images, fps=20)})
def flush(self):
self.experiment.log(self.data_dict)
self.data_dict = {}
class DebugLogger:
def __init__(self, project: str, entity: str, name: str, notes: str):
self.project = project
self.entity = entity
self.name = name
self.notes = notes
def metrics(self, metrics: dict, epoch: int = None):
"""
loss etc.
"""
print(f"metrics: {metrics}")
def hyperparameters(self, hyperparameters: dict):
"""
model states
"""
print(f"hyperparameters: {hyperparameters}")
def predictions(self, predictions: dict):
"""
inference time stuff
"""
class Checkpoint:
def __init__(self, checkpoint_path):
self.checkpoint_path = checkpoint_path
def load(self) -> Tuple:
checkpoint = torch.load(self.checkpoint_path)
model = checkpoint["model"]
optimizer = checkpoint["optimizer"]
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]
return (model, optimizer, epoch, loss)
def save(self, model: torch.nn.Module, optimizer, epoch, loss):
checkpoint = {
"model": model,
"optimizer": optimizer,
"epoch": epoch,
"loss": loss,
}
import random
import string
name = "".join(random.choices(string.ascii_letters, k=10)) + ".tar"
torch.save(checkpoint, f"{name}")

View File

@ -3,7 +3,17 @@ main class for building a DL pipeline.
""" """
from batch import Batch """
the main entry point for training a model
coordinates:
- datasets
- dataloaders
- runner
"""
from runner import Runner
from model.linear import DNN from model.linear import DNN
from model.cnn import VGG16, VGG11 from model.cnn import VGG16, VGG11
from data import MnistDataset from data import MnistDataset
@ -11,7 +21,6 @@ from utils import Stage
import torch import torch
from pathlib import Path from pathlib import Path
from collate import channel_to_batch from collate import channel_to_batch
import hydra import hydra
from omegaconf import DictConfig from omegaconf import DictConfig
@ -24,13 +33,24 @@ def train(config: DictConfig):
batch_size = config.batch_size batch_size = config.batch_size
num_workers = config.num_workers num_workers = config.num_workers
device = config.device device = config.device
epochs = config.epochs
path = Path(config.app_dir) / "storage/mnist_train.csv" train_path = Path(config.app_dir) / "data/mnist_train.csv"
trainset = MnistDataset(path=path) trainset = MnistDataset(path=train_path)
dev_path = Path(config.app_dir) / "data/mnist_test.csv"
devset = MnistDataset(path=dev_path)
trainloader = torch.utils.data.DataLoader( trainloader = torch.utils.data.DataLoader(
trainset, trainset,
batch_size=batch_size, batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
# collate_fn=channel_to_batch,
)
devloader = torch.utils.data.DataLoader(
devset,
batch_size=batch_size,
shuffle=False, shuffle=False,
num_workers=num_workers, num_workers=num_workers,
# collate_fn=channel_to_batch, # collate_fn=channel_to_batch,
@ -38,7 +58,7 @@ def train(config: DictConfig):
model = VGG11(in_channels=1, num_classes=10) model = VGG11(in_channels=1, num_classes=10)
criterion = torch.nn.CrossEntropyLoss() criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr) optimizer = torch.optim.Adam(model.parameters(), lr=lr)
batch = Batch( train_runner = Runner(
stage=Stage.TRAIN, stage=Stage.TRAIN,
model=model, model=model,
device=torch.device(device), device=torch.device(device),
@ -47,10 +67,22 @@ def train(config: DictConfig):
optimizer=optimizer, optimizer=optimizer,
config=config, config=config,
) )
log = batch.run( dev_runner = Runner(
"Run run run run. Run run run away. Oh Oh oH OHHHHHHH yayayayayayayayaya! - David Byrne" stage=Stage.DEV,
model=model,
device=torch.device(device),
loader=devloader,
criterion=criterion,
optimizer=optimizer,
config=config,
) )
for epoch in range(epochs):
if epoch % config.dev_after == 0:
dev_log = dev_runner.run("dev epoch")
else:
train_log = train_runner.run("train epoch")
if __name__ == "__main__": if __name__ == "__main__":
train() train()

View File

@ -1,3 +1,6 @@
"""
runner for training and valdating
"""
import torch import torch
from torch import nn from torch import nn
from torch import optim from torch import optim
@ -8,7 +11,7 @@ from utils import Stage
from omegaconf import DictConfig from omegaconf import DictConfig
class Batch: class Runner:
def __init__( def __init__(
self, self,
stage: Stage, stage: Stage,
@ -34,8 +37,7 @@ class Batch:
self.model.train() self.model.train()
if self.config.debug: if self.config.debug:
breakpoint() breakpoint()
epoch = 0 for batch, (x, y) in enumerate(tqdm(self.loader, desc=desc)):
for epoch, (x, y) in enumerate(tqdm(self.loader, desc=desc)):
self.optimizer.zero_grad() self.optimizer.zero_grad()
loss = self._run_batch((x, y)) loss = self._run_batch((x, y))
loss.backward() # Send loss backwards to accumulate gradients loss.backward() # Send loss backwards to accumulate gradients