rename batch to runner.
fill out makefile. add dev pipeline. gitignore data dir. add logger.py. fill out readme.md. export env.yml.
This commit is contained in:
parent
1f13224c4f
commit
0f12b26e40
28
Makefile
28
Makefile
|
@ -1,15 +1,25 @@
|
||||||
CONDA_ENV=ml_pipeline
|
CONDA_ENV=ml_pipeline
|
||||||
|
.PHONY: help
|
||||||
|
|
||||||
all: run
|
all: help
|
||||||
|
|
||||||
run:
|
run: ## run the pipeline (train)
|
||||||
./launch.sh
|
python src/pipeline.py \
|
||||||
|
debug=false
|
||||||
|
debug: ## run the pipeline (train) with debugging enabled
|
||||||
|
python src/pipeline.py \
|
||||||
|
debug=true
|
||||||
|
|
||||||
data:
|
data: ## download the mnist data
|
||||||
python src/data.py
|
wget https://pjreddie.com/media/files/mnist_train.csv -O data/mnist_train.csv
|
||||||
|
wget https://pjreddie.com/media/files/mnist_test.csv -O data/mnist_test.csv
|
||||||
|
|
||||||
batch:
|
env_import: environment.yml ## import any changes to env.yml into conda env
|
||||||
python src/batch.py
|
conda env update -n ${CONDA_ENV} --file $^
|
||||||
|
|
||||||
|
env_export: ## export the conda envirnoment without package or name
|
||||||
|
conda env export | head -n -1 | tail -n +2 > $@
|
||||||
|
|
||||||
|
help: ## display this help message
|
||||||
|
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
||||||
|
|
||||||
install:
|
|
||||||
conda env updates -n ${CONDA_ENV} --file environment.yml
|
|
||||||
|
|
58
README.md
58
README.md
|
@ -40,14 +40,54 @@ Especially in a research lab, standardizing on a few components makes switching
|
||||||
|
|
||||||
In this template, you'll see the following:
|
In this template, you'll see the following:
|
||||||
|
|
||||||
- `src/model`, `src/config`, `storage`, `test` dirs.
|
## directory structure
|
||||||
- `if __name__ == "__main__"` tests.
|
|
||||||
|
- `src/model`
|
||||||
|
- `src/config`
|
||||||
|
- `data/`
|
||||||
|
- `test/`
|
||||||
|
- pytest: unit testing.
|
||||||
|
- good for data shape
|
||||||
|
- TODO:
|
||||||
|
- `docs/`
|
||||||
|
- switching projects is easier with these in place
|
||||||
|
- organize them
|
||||||
|
|
||||||
|
- `**/__init__.py`
|
||||||
|
- creates modules out of dir.
|
||||||
|
- `import module` works with these.
|
||||||
|
- `README.md`
|
||||||
|
- root level required.
|
||||||
|
- can exist inside any dir.
|
||||||
|
- `environment.yml`
|
||||||
|
- `Makefile`
|
||||||
|
- to install and run stuff.
|
||||||
|
- houses common operations and scripts.
|
||||||
|
- `launch.sh`
|
||||||
|
- script to dispatch training.
|
||||||
|
|
||||||
|
## testing
|
||||||
|
|
||||||
|
- `if __name__ == "__main__"`.
|
||||||
|
- good way to test things
|
||||||
|
- enables lots breakpoints.
|
||||||
|
|
||||||
|
## config
|
||||||
- Hydra config.
|
- Hydra config.
|
||||||
- dataloader, optimizer, criterion, device, state are constructed in main, but passed to an object that runs batches.
|
- quickly experiment with hyperparameters
|
||||||
- tqdm to track progress.
|
- good way to define env. variables
|
||||||
- debug config flag enables lots breakpoints.
|
- lr, workers, batch_size
|
||||||
- python type hints.
|
- debug
|
||||||
- a `launch.sh` script to dispatch training.
|
|
||||||
- a Makefile to install and run stuff.
|
## data
|
||||||
- automatic linting with the `black` package.
|
|
||||||
- collate functions!
|
- collate functions!
|
||||||
|
|
||||||
|
## formatting python
|
||||||
|
- python type hints.
|
||||||
|
- automatic linting with the `black` package.
|
||||||
|
|
||||||
|
## running
|
||||||
|
- tqdm to track progress.
|
||||||
|
|
||||||
|
## architecture
|
||||||
|
- dataloader, optimizer, criterion, device, state are constructed in main, but passed to an object that runs batches.
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
*.csv
|
|
@ -1,4 +1,3 @@
|
||||||
name: ml
|
|
||||||
channels:
|
channels:
|
||||||
- pytorch
|
- pytorch
|
||||||
- conda-forge
|
- conda-forge
|
||||||
|
@ -6,30 +5,50 @@ channels:
|
||||||
dependencies:
|
dependencies:
|
||||||
- _libgcc_mutex=0.1=conda_forge
|
- _libgcc_mutex=0.1=conda_forge
|
||||||
- _openmp_mutex=4.5=2_gnu
|
- _openmp_mutex=4.5=2_gnu
|
||||||
|
- accelerate=0.15.0=pyhd8ed1ab_0
|
||||||
|
- antlr-python-runtime=4.9.3=pyhd8ed1ab_1
|
||||||
|
- appdirs=1.4.4=pyh9f0ad1d_0
|
||||||
|
- astroid=2.11.7=py310h06a4308_0
|
||||||
|
- attrs=22.1.0=py310h06a4308_0
|
||||||
- black=22.6.0=py310h06a4308_0
|
- black=22.6.0=py310h06a4308_0
|
||||||
- blas=1.0=mkl
|
- blas=1.0=mkl
|
||||||
- brotli=1.0.9=h5eee18b_7
|
- brotli=1.0.9=h5eee18b_7
|
||||||
- brotli-bin=1.0.9=h5eee18b_7
|
- brotli-bin=1.0.9=h5eee18b_7
|
||||||
|
- brotlipy=0.7.0=py310h5764c6d_1005
|
||||||
- bzip2=1.0.8=h7f98852_4
|
- bzip2=1.0.8=h7f98852_4
|
||||||
- ca-certificates=2022.10.11=h06a4308_0
|
- ca-certificates=2022.12.7=ha878542_0
|
||||||
|
- certifi=2022.12.7=pyhd8ed1ab_0
|
||||||
|
- cffi=1.15.1=py310h74dc2b5_0
|
||||||
|
- charset-normalizer=2.1.1=pyhd8ed1ab_0
|
||||||
- click=8.0.3=pyhd3eb1b0_0
|
- click=8.0.3=pyhd3eb1b0_0
|
||||||
- colorama=0.4.6=pyhd8ed1ab_0
|
- colorama=0.4.6=pyhd8ed1ab_0
|
||||||
|
- cryptography=39.0.0=py310h65dfdc0_0
|
||||||
- cycler=0.11.0=pyhd3eb1b0_0
|
- cycler=0.11.0=pyhd3eb1b0_0
|
||||||
- dbus=1.13.18=hb2f20db_0
|
- dbus=1.13.18=hb2f20db_0
|
||||||
|
- dill=0.3.6=pyhd8ed1ab_1
|
||||||
|
- docker-pycreds=0.4.0=py_0
|
||||||
- einops=0.4.1=pyhd8ed1ab_0
|
- einops=0.4.1=pyhd8ed1ab_0
|
||||||
- expat=2.4.9=h6a678d5_0
|
- expat=2.4.9=h6a678d5_0
|
||||||
- fontconfig=2.13.1=h6c09931_0
|
- fontconfig=2.13.1=h6c09931_0
|
||||||
- fonttools=4.25.0=pyhd3eb1b0_0
|
- fonttools=4.25.0=pyhd3eb1b0_0
|
||||||
- freetype=2.12.1=h4a9f257_0
|
- freetype=2.12.1=h4a9f257_0
|
||||||
- giflib=5.2.1=h7b6447c_0
|
- giflib=5.2.1=h7b6447c_0
|
||||||
|
- gitdb=4.0.10=pyhd8ed1ab_0
|
||||||
|
- gitpython=3.1.30=pyhd8ed1ab_0
|
||||||
- glib=2.69.1=h4ff587b_1
|
- glib=2.69.1=h4ff587b_1
|
||||||
- gst-plugins-base=1.14.0=h8213a91_2
|
- gst-plugins-base=1.14.0=h8213a91_2
|
||||||
- gstreamer=1.14.0=h28cd5cc_2
|
- gstreamer=1.14.0=h28cd5cc_2
|
||||||
|
- hydra-core=1.3.1=pyhd8ed1ab_0
|
||||||
- icu=58.2=he6710b0_3
|
- icu=58.2=he6710b0_3
|
||||||
|
- idna=3.4=pyhd8ed1ab_0
|
||||||
|
- importlib_resources=5.10.2=pyhd8ed1ab_0
|
||||||
|
- iniconfig=1.1.1=pyhd3eb1b0_0
|
||||||
- intel-openmp=2021.4.0=h06a4308_3561
|
- intel-openmp=2021.4.0=h06a4308_3561
|
||||||
|
- isort=5.9.3=pyhd3eb1b0_0
|
||||||
- jpeg=9e=h7f8727e_0
|
- jpeg=9e=h7f8727e_0
|
||||||
- kiwisolver=1.4.2=py310h295c915_0
|
- kiwisolver=1.4.2=py310h295c915_0
|
||||||
- krb5=1.19.2=hac12032_0
|
- krb5=1.19.2=hac12032_0
|
||||||
|
- lazy-object-proxy=1.6.0=py310h7f8727e_0
|
||||||
- lcms2=2.12=h3be6417_0
|
- lcms2=2.12=h3be6417_0
|
||||||
- ld_impl_linux-64=2.39=hc81fddc_0
|
- ld_impl_linux-64=2.39=hc81fddc_0
|
||||||
- lerc=3.0=h295c915_0
|
- lerc=3.0=h295c915_0
|
||||||
|
@ -50,6 +69,7 @@ dependencies:
|
||||||
- libopenblas=0.3.21=pthreads_h78a6416_3
|
- libopenblas=0.3.21=pthreads_h78a6416_3
|
||||||
- libpng=1.6.37=hbc83047_0
|
- libpng=1.6.37=hbc83047_0
|
||||||
- libpq=12.9=h16c4e8d_3
|
- libpq=12.9=h16c4e8d_3
|
||||||
|
- libprotobuf=3.20.1=h4ff587b_0
|
||||||
- libstdcxx-ng=12.2.0=h46fd767_19
|
- libstdcxx-ng=12.2.0=h46fd767_19
|
||||||
- libtiff=4.4.0=hecacb30_0
|
- libtiff=4.4.0=hecacb30_0
|
||||||
- libuuid=1.0.3=h7f8727e_2
|
- libuuid=1.0.3=h7f8727e_2
|
||||||
|
@ -62,6 +82,7 @@ dependencies:
|
||||||
- lz4-c=1.9.3=h295c915_1
|
- lz4-c=1.9.3=h295c915_1
|
||||||
- matplotlib=3.5.2=py310h06a4308_0
|
- matplotlib=3.5.2=py310h06a4308_0
|
||||||
- matplotlib-base=3.5.2=py310hf590b9c_0
|
- matplotlib-base=3.5.2=py310hf590b9c_0
|
||||||
|
- mccabe=0.7.0=pyhd3eb1b0_0
|
||||||
- mkl=2021.4.0=h06a4308_640
|
- mkl=2021.4.0=h06a4308_640
|
||||||
- mkl-service=2.4.0=py310h7f8727e_0
|
- mkl-service=2.4.0=py310h7f8727e_0
|
||||||
- mkl_fft=1.3.1=py310hd6ae3a3_0
|
- mkl_fft=1.3.1=py310hd6ae3a3_0
|
||||||
|
@ -73,39 +94,62 @@ dependencies:
|
||||||
- nss=3.74=h0370c37_0
|
- nss=3.74=h0370c37_0
|
||||||
- numpy=1.23.3=py310hd5efca6_0
|
- numpy=1.23.3=py310hd5efca6_0
|
||||||
- numpy-base=1.23.3=py310h8e6c178_0
|
- numpy-base=1.23.3=py310h8e6c178_0
|
||||||
- openssl=1.1.1q=h7f8727e_0
|
- omegaconf=2.3.0=pyhd8ed1ab_0
|
||||||
|
- openssl=1.1.1s=h0b41bf4_1
|
||||||
- packaging=21.3=pyhd3eb1b0_0
|
- packaging=21.3=pyhd3eb1b0_0
|
||||||
- pathspec=0.10.1=pyhd8ed1ab_0
|
- pathspec=0.10.1=pyhd8ed1ab_0
|
||||||
|
- pathtools=0.1.2=py_1
|
||||||
- pcre=8.45=h295c915_0
|
- pcre=8.45=h295c915_0
|
||||||
- pillow=9.2.0=py310hace64e9_1
|
- pillow=9.2.0=py310hace64e9_1
|
||||||
- pip=22.3=pyhd8ed1ab_0
|
- pip=22.3=pyhd8ed1ab_0
|
||||||
- platformdirs=2.5.2=pyhd8ed1ab_1
|
- platformdirs=2.5.2=pyhd8ed1ab_1
|
||||||
|
- pluggy=1.0.0=py310h06a4308_1
|
||||||
- ply=3.11=py310h06a4308_0
|
- ply=3.11=py310h06a4308_0
|
||||||
|
- protobuf=3.20.1=py310h295c915_0
|
||||||
|
- psutil=5.9.4=py310h5764c6d_0
|
||||||
|
- py=1.11.0=pyhd3eb1b0_0
|
||||||
|
- pycparser=2.21=pyhd8ed1ab_0
|
||||||
|
- pylint=2.14.5=py310h06a4308_0
|
||||||
|
- pyopenssl=23.0.0=pyhd8ed1ab_0
|
||||||
- pyparsing=3.0.9=py310h06a4308_0
|
- pyparsing=3.0.9=py310h06a4308_0
|
||||||
- pyqt=5.15.7=py310h6a678d5_1
|
- pyqt=5.15.7=py310h6a678d5_1
|
||||||
|
- pysocks=1.7.1=pyha2e5f31_6
|
||||||
|
- pytest=7.1.2=py310h06a4308_0
|
||||||
- python=3.10.6=haa1d7c7_1
|
- python=3.10.6=haa1d7c7_1
|
||||||
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
||||||
|
- python-dotenv=0.21.0=py310h06a4308_0
|
||||||
|
- python_abi=3.10=2_cp310
|
||||||
- pytorch=1.13.0=py3.10_cpu_0
|
- pytorch=1.13.0=py3.10_cpu_0
|
||||||
- pytorch-mutex=1.0=cpu
|
- pytorch-mutex=1.0=cpu
|
||||||
|
- pyyaml=6.0=py310h5764c6d_5
|
||||||
- qt-main=5.15.2=h327a75a_7
|
- qt-main=5.15.2=h327a75a_7
|
||||||
- qt-webengine=5.15.9=hd2b0992_4
|
- qt-webengine=5.15.9=hd2b0992_4
|
||||||
- qtwebkit=5.212=h4eab89a_4
|
- qtwebkit=5.212=h4eab89a_4
|
||||||
- readline=8.1.2=h0f457ee_0
|
- readline=8.1.2=h0f457ee_0
|
||||||
|
- requests=2.28.2=pyhd8ed1ab_0
|
||||||
|
- sentry-sdk=1.14.0=pyhd8ed1ab_0
|
||||||
|
- setproctitle=1.3.2=py310h5764c6d_1
|
||||||
- setuptools=65.5.0=pyhd8ed1ab_0
|
- setuptools=65.5.0=pyhd8ed1ab_0
|
||||||
- sip=6.6.2=py310h6a678d5_0
|
- sip=6.6.2=py310h6a678d5_0
|
||||||
- six=1.16.0=pyhd3eb1b0_1
|
- six=1.16.0=pyhd3eb1b0_1
|
||||||
|
- smmap=3.0.5=pyh44b312d_0
|
||||||
- sqlite=3.39.3=h5082296_0
|
- sqlite=3.39.3=h5082296_0
|
||||||
- tk=8.6.12=h1ccaba5_0
|
- tk=8.6.12=h1ccaba5_0
|
||||||
- toml=0.10.2=pyhd3eb1b0_0
|
- toml=0.10.2=pyhd3eb1b0_0
|
||||||
- tomli=2.0.1=py310h06a4308_0
|
- tomli=2.0.1=py310h06a4308_0
|
||||||
|
- tomlkit=0.11.1=py310h06a4308_0
|
||||||
- tornado=6.2=py310h5eee18b_0
|
- tornado=6.2=py310h5eee18b_0
|
||||||
- tqdm=4.64.1=pyhd8ed1ab_0
|
- tqdm=4.64.1=pyhd8ed1ab_0
|
||||||
- typing_extensions=4.3.0=py310h06a4308_0
|
- typing_extensions=4.3.0=py310h06a4308_0
|
||||||
- tzdata=2022e=h191b570_0
|
- tzdata=2022e=h191b570_0
|
||||||
|
- urllib3=1.26.14=pyhd8ed1ab_0
|
||||||
|
- wandb=0.13.9=pyhd8ed1ab_0
|
||||||
- wheel=0.37.1=pyhd8ed1ab_0
|
- wheel=0.37.1=pyhd8ed1ab_0
|
||||||
|
- wrapt=1.14.1=py310h5eee18b_0
|
||||||
- xz=5.2.6=h166bdaf_0
|
- xz=5.2.6=h166bdaf_0
|
||||||
|
- yaml=0.2.5=h7f98852_2
|
||||||
|
- zipp=3.11.0=pyhd8ed1ab_0
|
||||||
- zlib=1.2.13=h5eee18b_0
|
- zlib=1.2.13=h5eee18b_0
|
||||||
- zstd=1.5.2=ha4553b6_0
|
- zstd=1.5.2=ha4553b6_0
|
||||||
- pip:
|
- pip:
|
||||||
- pyqt5-sip==12.11.0
|
- pyqt5-sip==12.11.0
|
||||||
prefix: /home/personal/Dev/conda/envs/ml
|
|
||||||
|
|
|
@ -4,3 +4,5 @@ lr: 2e-4
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
num_workers: 0
|
num_workers: 0
|
||||||
device: "cpu"
|
device: "cpu"
|
||||||
|
epochs: 4
|
||||||
|
dev_after: 20
|
||||||
|
|
|
@ -0,0 +1,111 @@
|
||||||
|
from tkinter import W
|
||||||
|
import torch
|
||||||
|
import wandb
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from einops import rearrange
|
||||||
|
from typing import Protocol, Tuple, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class Logger(Protocol):
|
||||||
|
def metrics(self, metrics: dict, epoch: int):
|
||||||
|
"""loss etc."""
|
||||||
|
|
||||||
|
def hyperparameters(self, hyperparameters: dict):
|
||||||
|
"""model states"""
|
||||||
|
|
||||||
|
def predictions(self, predictions: dict):
|
||||||
|
"""inference time stuff"""
|
||||||
|
|
||||||
|
def images(self, images: np.ndarray):
|
||||||
|
"""log images"""
|
||||||
|
|
||||||
|
|
||||||
|
class WandbLogger:
|
||||||
|
def __init__(self, project: str, entity: str, name: Optional[str], notes: str):
|
||||||
|
self.project = project
|
||||||
|
self.entity = entity
|
||||||
|
self.notes = notes
|
||||||
|
self.experiment = wandb.init(project=project, entity=entity, notes=notes)
|
||||||
|
self.experiment.name = name
|
||||||
|
|
||||||
|
self.data_dict = {}
|
||||||
|
|
||||||
|
def metrics(self, metrics: dict):
|
||||||
|
"""loss etc."""
|
||||||
|
|
||||||
|
self.data_dict.update(metrics)
|
||||||
|
|
||||||
|
def hyperparameters(self, hyperparameters: dict):
|
||||||
|
"""model states"""
|
||||||
|
self.experiment.config.update(hyperparameters, allow_val_change=True)
|
||||||
|
|
||||||
|
def predictions(self, predictions: dict):
|
||||||
|
"""inference time stuff"""
|
||||||
|
|
||||||
|
def image(self, image: dict):
|
||||||
|
"""log images to wandb"""
|
||||||
|
self.data_dict.update({'Generate Image' : image})
|
||||||
|
|
||||||
|
def video(self, images: str, title: str):
|
||||||
|
"""log images to wandb"""
|
||||||
|
|
||||||
|
images = np.uint8(rearrange(images, 't b c h w -> b t c h w'))
|
||||||
|
self.data_dict.update({f"{title}": wandb.Video(images, fps=20)})
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
self.experiment.log(self.data_dict)
|
||||||
|
self.data_dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
class DebugLogger:
|
||||||
|
def __init__(self, project: str, entity: str, name: str, notes: str):
|
||||||
|
self.project = project
|
||||||
|
self.entity = entity
|
||||||
|
self.name = name
|
||||||
|
self.notes = notes
|
||||||
|
|
||||||
|
def metrics(self, metrics: dict, epoch: int = None):
|
||||||
|
"""
|
||||||
|
loss etc.
|
||||||
|
"""
|
||||||
|
print(f"metrics: {metrics}")
|
||||||
|
|
||||||
|
def hyperparameters(self, hyperparameters: dict):
|
||||||
|
"""
|
||||||
|
model states
|
||||||
|
"""
|
||||||
|
print(f"hyperparameters: {hyperparameters}")
|
||||||
|
|
||||||
|
def predictions(self, predictions: dict):
|
||||||
|
"""
|
||||||
|
inference time stuff
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Checkpoint:
|
||||||
|
def __init__(self, checkpoint_path):
|
||||||
|
self.checkpoint_path = checkpoint_path
|
||||||
|
|
||||||
|
def load(self) -> Tuple:
|
||||||
|
checkpoint = torch.load(self.checkpoint_path)
|
||||||
|
model = checkpoint["model"]
|
||||||
|
optimizer = checkpoint["optimizer"]
|
||||||
|
epoch = checkpoint["epoch"]
|
||||||
|
loss = checkpoint["loss"]
|
||||||
|
return (model, optimizer, epoch, loss)
|
||||||
|
|
||||||
|
def save(self, model: torch.nn.Module, optimizer, epoch, loss):
|
||||||
|
checkpoint = {
|
||||||
|
"model": model,
|
||||||
|
"optimizer": optimizer,
|
||||||
|
"epoch": epoch,
|
||||||
|
"loss": loss,
|
||||||
|
}
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
|
||||||
|
name = "".join(random.choices(string.ascii_letters, k=10)) + ".tar"
|
||||||
|
torch.save(checkpoint, f"{name}")
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,17 @@ main class for building a DL pipeline.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from batch import Batch
|
"""
|
||||||
|
the main entry point for training a model
|
||||||
|
|
||||||
|
coordinates:
|
||||||
|
|
||||||
|
- datasets
|
||||||
|
- dataloaders
|
||||||
|
- runner
|
||||||
|
|
||||||
|
"""
|
||||||
|
from runner import Runner
|
||||||
from model.linear import DNN
|
from model.linear import DNN
|
||||||
from model.cnn import VGG16, VGG11
|
from model.cnn import VGG16, VGG11
|
||||||
from data import MnistDataset
|
from data import MnistDataset
|
||||||
|
@ -11,7 +21,6 @@ from utils import Stage
|
||||||
import torch
|
import torch
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from collate import channel_to_batch
|
from collate import channel_to_batch
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
|
@ -24,13 +33,24 @@ def train(config: DictConfig):
|
||||||
batch_size = config.batch_size
|
batch_size = config.batch_size
|
||||||
num_workers = config.num_workers
|
num_workers = config.num_workers
|
||||||
device = config.device
|
device = config.device
|
||||||
|
epochs = config.epochs
|
||||||
|
|
||||||
path = Path(config.app_dir) / "storage/mnist_train.csv"
|
train_path = Path(config.app_dir) / "data/mnist_train.csv"
|
||||||
trainset = MnistDataset(path=path)
|
trainset = MnistDataset(path=train_path)
|
||||||
|
|
||||||
|
dev_path = Path(config.app_dir) / "data/mnist_test.csv"
|
||||||
|
devset = MnistDataset(path=dev_path)
|
||||||
|
|
||||||
trainloader = torch.utils.data.DataLoader(
|
trainloader = torch.utils.data.DataLoader(
|
||||||
trainset,
|
trainset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=num_workers,
|
||||||
|
# collate_fn=channel_to_batch,
|
||||||
|
)
|
||||||
|
devloader = torch.utils.data.DataLoader(
|
||||||
|
devset,
|
||||||
|
batch_size=batch_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
# collate_fn=channel_to_batch,
|
# collate_fn=channel_to_batch,
|
||||||
|
@ -38,7 +58,7 @@ def train(config: DictConfig):
|
||||||
model = VGG11(in_channels=1, num_classes=10)
|
model = VGG11(in_channels=1, num_classes=10)
|
||||||
criterion = torch.nn.CrossEntropyLoss()
|
criterion = torch.nn.CrossEntropyLoss()
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
||||||
batch = Batch(
|
train_runner = Runner(
|
||||||
stage=Stage.TRAIN,
|
stage=Stage.TRAIN,
|
||||||
model=model,
|
model=model,
|
||||||
device=torch.device(device),
|
device=torch.device(device),
|
||||||
|
@ -47,10 +67,22 @@ def train(config: DictConfig):
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
log = batch.run(
|
dev_runner = Runner(
|
||||||
"Run run run run. Run run run away. Oh Oh oH OHHHHHHH yayayayayayayayaya! - David Byrne"
|
stage=Stage.DEV,
|
||||||
|
model=model,
|
||||||
|
device=torch.device(device),
|
||||||
|
loader=devloader,
|
||||||
|
criterion=criterion,
|
||||||
|
optimizer=optimizer,
|
||||||
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for epoch in range(epochs):
|
||||||
|
if epoch % config.dev_after == 0:
|
||||||
|
dev_log = dev_runner.run("dev epoch")
|
||||||
|
else:
|
||||||
|
train_log = train_runner.run("train epoch")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
train()
|
train()
|
||||||
|
|
|
@ -1,3 +1,6 @@
|
||||||
|
"""
|
||||||
|
runner for training and valdating
|
||||||
|
"""
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch import optim
|
from torch import optim
|
||||||
|
@ -8,7 +11,7 @@ from utils import Stage
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
|
|
||||||
class Batch:
|
class Runner:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
stage: Stage,
|
stage: Stage,
|
||||||
|
@ -34,8 +37,7 @@ class Batch:
|
||||||
self.model.train()
|
self.model.train()
|
||||||
if self.config.debug:
|
if self.config.debug:
|
||||||
breakpoint()
|
breakpoint()
|
||||||
epoch = 0
|
for batch, (x, y) in enumerate(tqdm(self.loader, desc=desc)):
|
||||||
for epoch, (x, y) in enumerate(tqdm(self.loader, desc=desc)):
|
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
loss = self._run_batch((x, y))
|
loss = self._run_batch((x, y))
|
||||||
loss.backward() # Send loss backwards to accumulate gradients
|
loss.backward() # Send loss backwards to accumulate gradients
|
Loading…
Reference in New Issue