Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a09926e9ca | ||
|
|
996f4bc97c | ||
|
|
ecc8939517 | ||
|
|
0f12b26e40 | ||
|
|
1f13224c4f |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,2 +1,4 @@
|
|||||||
storage/
|
storage/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
outputs/
|
||||||
|
.env
|
||||||
|
|||||||
31
Makefile
31
Makefile
@@ -1,15 +1,28 @@
|
|||||||
CONDA_ENV=ml_pipeline
|
CONDA_ENV=ml_pipeline
|
||||||
|
.PHONY: help
|
||||||
|
|
||||||
all: run
|
all: help
|
||||||
|
|
||||||
run:
|
run: ## run the pipeline (train)
|
||||||
python src/pipeline.py train
|
python src/train.py \
|
||||||
|
debug=false
|
||||||
|
debug: ## run the pipeline (train) with debugging enabled
|
||||||
|
python src/train.py \
|
||||||
|
debug=true
|
||||||
|
|
||||||
data:
|
data: ## download the mnist data
|
||||||
python src/data.py
|
wget https://pjreddie.com/media/files/mnist_train.csv -O data/mnist_train.csv
|
||||||
|
wget https://pjreddie.com/media/files/mnist_test.csv -O data/mnist_test.csv
|
||||||
|
|
||||||
batch:
|
install: conda-lock.yml ## import any changes to env.yml into conda env
|
||||||
python src/batch.py
|
conda-lock install --name ${CONDA_ENV} $^
|
||||||
|
|
||||||
|
lock: environment.yml ## lock the current conda env
|
||||||
|
conda-lock
|
||||||
|
|
||||||
|
env_export: ## export the conda envirnoment without package or name
|
||||||
|
conda env export | head -n -1 | tail -n +2 > $@
|
||||||
|
|
||||||
|
help: ## display this help message
|
||||||
|
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
||||||
|
|
||||||
install:
|
|
||||||
conda env updates -n ${CONDA_ENV} --file environment.yml
|
|
||||||
|
|||||||
138
README.md
138
README.md
@@ -6,10 +6,16 @@ Instead of remembering where to put everything and making a different choice for
|
|||||||
|
|
||||||
Think of it like a mini-pytorch lightening, with all the fory internals exposed for extension and modification.
|
Think of it like a mini-pytorch lightening, with all the fory internals exposed for extension and modification.
|
||||||
|
|
||||||
|
This project lives here: [https://github.com/publicmatt.com/ml_pipeline](https://github.com/publicmatt.com/ml_pipeline).
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
### Install:
|
# Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make help # lists available options.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Install:
|
||||||
|
|
||||||
Install the conda requirements:
|
Install the conda requirements:
|
||||||
|
|
||||||
@@ -17,13 +23,15 @@ Install the conda requirements:
|
|||||||
make install
|
make install
|
||||||
```
|
```
|
||||||
|
|
||||||
Which is a proxy for calling:
|
## Data:
|
||||||
|
|
||||||
|
Download mnist data from PJReadie's website:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
conda env updates -n ml_pipeline --file environment.yml
|
make data
|
||||||
```
|
```
|
||||||
|
|
||||||
### Run:
|
## Run:
|
||||||
|
|
||||||
Run the code on MNIST with the following command:
|
Run the code on MNIST with the following command:
|
||||||
|
|
||||||
@@ -31,3 +39,123 @@ Run the code on MNIST with the following command:
|
|||||||
make run
|
make run
|
||||||
```
|
```
|
||||||
|
|
||||||
|
# Tutorial
|
||||||
|
|
||||||
|
The motivation for building a template for deep learning pipelines is this: deep learning is hard enough without every code baase being a little different.
|
||||||
|
|
||||||
|
Especially in a research lab, standardizing on a few components makes switching between projects easier.
|
||||||
|
|
||||||
|
In this template, you'll see the following:
|
||||||
|
|
||||||
|
## directory structure
|
||||||
|
```
|
||||||
|
.
|
||||||
|
├── README.md
|
||||||
|
├── environment.yml
|
||||||
|
├── launch.sh
|
||||||
|
├── Makefile
|
||||||
|
├── data
|
||||||
|
│ ├── mnist_test.csv
|
||||||
|
│ └── mnist_train.csv
|
||||||
|
├── docs
|
||||||
|
│ └── 2023-01-26.md
|
||||||
|
├── src
|
||||||
|
│ ├── config
|
||||||
|
│ │ └── main.yaml
|
||||||
|
│ ├── data
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── README.md
|
||||||
|
│ │ ├── collate.py
|
||||||
|
│ │ └── dataset.py
|
||||||
|
│ ├── eval.py
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── model
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── README.md
|
||||||
|
│ │ ├── cnn.py
|
||||||
|
│ │ └── linear.py
|
||||||
|
│ ├── pipeline
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── README.md
|
||||||
|
│ │ ├── logger.py
|
||||||
|
│ │ ├── runner.py
|
||||||
|
│ │ └── utils.py
|
||||||
|
│ ├── sample.py
|
||||||
|
│ └── train.py
|
||||||
|
└── test
|
||||||
|
├── __init__.py
|
||||||
|
└── test_pipeline.py
|
||||||
|
|
||||||
|
8 directories, 25 files
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
## what and why?
|
||||||
|
|
||||||
|
- `environment.yml`
|
||||||
|
- hutch research has standardized on conda
|
||||||
|
- here's a good tutorial on getting that setup: [seth email](emailto:bassetis@wwu.edu)
|
||||||
|
- `launch.sh` or `Makefile`
|
||||||
|
- to install and run stuff.
|
||||||
|
- houses common operations and scripts.
|
||||||
|
- `launch.sh` to dispatch training.
|
||||||
|
- `README.md`
|
||||||
|
- explain the project and how to run it.
|
||||||
|
- list authors.
|
||||||
|
- list resources that new collaborators might need.
|
||||||
|
- root level dir.
|
||||||
|
- can exist inside any dir.
|
||||||
|
- reads nicely on github.com.
|
||||||
|
- `docs/`
|
||||||
|
- switching projects is easier with these in place.
|
||||||
|
- organize them by meeting, or weekly agenda.
|
||||||
|
- generally collection of markdown files.
|
||||||
|
- `test/`
|
||||||
|
- TODO
|
||||||
|
- pytest: unit testing.
|
||||||
|
- good for data shape. not sure what else.
|
||||||
|
- `data/`
|
||||||
|
- raw data
|
||||||
|
- do not commit these to repo generally.
|
||||||
|
- `echo "*.csv" >> data/.gitignore`
|
||||||
|
- `__init__.py`
|
||||||
|
- creates modules out of dir.
|
||||||
|
- `import module` works b/c of these.
|
||||||
|
- `src/model/`
|
||||||
|
- if you have a large project, you might have multiple architectures/models.
|
||||||
|
- small projects might just have `model/VGG.py` or `model/3d_unet.py`.
|
||||||
|
- `src/config`
|
||||||
|
- based on hydra python package.
|
||||||
|
- quickly change run variables and hyperparameters.
|
||||||
|
- `src/pipeline`
|
||||||
|
- where the magic happens.
|
||||||
|
- `train.py` creates all the objects, hands them off to runner for batching, monitors each epoch.
|
||||||
|
|
||||||
|
## testing
|
||||||
|
- `if __name__ == "__main__"`.
|
||||||
|
- good way to test things
|
||||||
|
- enables lots breakpoints.
|
||||||
|
|
||||||
|
## config
|
||||||
|
- Hydra config.
|
||||||
|
- quickly experiment with hyperparameters
|
||||||
|
- good way to define env. variables
|
||||||
|
- lr, workers, batch_size
|
||||||
|
- debug
|
||||||
|
|
||||||
|
## data
|
||||||
|
- collate functions!
|
||||||
|
- datasets.
|
||||||
|
- dataloader.
|
||||||
|
|
||||||
|
## formatting python
|
||||||
|
- python type hints.
|
||||||
|
- automatic linting with the `black` package.
|
||||||
|
|
||||||
|
## running
|
||||||
|
- tqdm to track progress.
|
||||||
|
- wandb for logging.
|
||||||
|
|
||||||
|
## architecture
|
||||||
|
- dataloader, optimizer, criterion, device, state are constructed in main, but passed to an object that runs batches.
|
||||||
|
|
||||||
|
|||||||
73
bin/install_conda.sh
Normal file
73
bin/install_conda.sh
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
PYTHON_VERSION=3.10
|
||||||
|
ENV_NAME=ml_pipeline
|
||||||
|
INSTALL_DIR=$HOME/Dev
|
||||||
|
# for wwu research:
|
||||||
|
# INSTALL_DIR=/research/hutchinson/workspace/$USERNAME
|
||||||
|
|
||||||
|
####################
|
||||||
|
#
|
||||||
|
# download miniconda
|
||||||
|
#
|
||||||
|
####################
|
||||||
|
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O $HOME/Downloads/Miniconda3-latest-Linux-x86_64.sh
|
||||||
|
|
||||||
|
####################
|
||||||
|
#
|
||||||
|
# run install script
|
||||||
|
# headless
|
||||||
|
#
|
||||||
|
####################
|
||||||
|
rm -rf $INSTALL_DIR/miniconda3
|
||||||
|
bash $HOME/Downloads/Miniconda3-latest-Linux-x86_64.sh -b -p $INSTALL_DIR/miniconda3
|
||||||
|
|
||||||
|
####################
|
||||||
|
#
|
||||||
|
# create first conda environment
|
||||||
|
#
|
||||||
|
####################
|
||||||
|
conda create --name $ENV_NAME python=$PYTHON_VERSION -y
|
||||||
|
|
||||||
|
################
|
||||||
|
#
|
||||||
|
# place the following in $HOME/.bashrc
|
||||||
|
#
|
||||||
|
# then use `hutchconda` to activate base env
|
||||||
|
#
|
||||||
|
################
|
||||||
|
|
||||||
|
# WORKSPACE_DIR=/research/hutchinson/workspace/$USERNAME
|
||||||
|
# hutchconda() {
|
||||||
|
# __conda_setup="$('$WORKSPACE_DIR/miniconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
|
||||||
|
# if [ $? -eq 0 ]; then
|
||||||
|
# eval "$__conda_setup"
|
||||||
|
# else
|
||||||
|
# if [ -f "$WORKSPACE_DIR/miniconda3/etc/profile.d/conda.sh" ]; then
|
||||||
|
# . "$WORKSPACE_DIR/miniconda3/etc/profile.d/conda.sh"
|
||||||
|
# else
|
||||||
|
# export PATH="$WORKSPACE_DIR/miniconda3/bin:$PATH"
|
||||||
|
# fi
|
||||||
|
# fi
|
||||||
|
# unset __conda_setup
|
||||||
|
# }
|
||||||
|
|
||||||
|
|
||||||
|
####################
|
||||||
|
#
|
||||||
|
# activate conda environment
|
||||||
|
#
|
||||||
|
####################
|
||||||
|
conda activate $ENV_NAME
|
||||||
|
|
||||||
|
####################
|
||||||
|
#
|
||||||
|
# install pytorch
|
||||||
|
#
|
||||||
|
####################
|
||||||
|
conda install -c pytorch pytorch -y
|
||||||
|
|
||||||
|
####################
|
||||||
|
#
|
||||||
|
# or install from envirnoment.yml
|
||||||
|
#
|
||||||
|
####################
|
||||||
|
conda env update -n $ENV_NAME --file environment.yml
|
||||||
3233
conda-lock.yml
Normal file
3233
conda-lock.yml
Normal file
File diff suppressed because it is too large
Load Diff
1
data/.gitignore
vendored
Normal file
1
data/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
*.csv
|
||||||
0
docs/2023-01-26.md
Normal file
0
docs/2023-01-26.md
Normal file
124
environment.yml
124
environment.yml
@@ -1,111 +1,23 @@
|
|||||||
name: ml
|
|
||||||
channels:
|
channels:
|
||||||
- pytorch
|
- pytorch
|
||||||
- conda-forge
|
- conda-forge
|
||||||
- defaults
|
- defaults
|
||||||
dependencies:
|
dependencies:
|
||||||
- _libgcc_mutex=0.1=conda_forge
|
- conda-lock
|
||||||
- _openmp_mutex=4.5=2_gnu
|
- black
|
||||||
- black=22.6.0=py310h06a4308_0
|
- click
|
||||||
- blas=1.0=mkl
|
- einops
|
||||||
- brotli=1.0.9=h5eee18b_7
|
- hydra-core
|
||||||
- brotli-bin=1.0.9=h5eee18b_7
|
- matplotlib
|
||||||
- bzip2=1.0.8=h7f98852_4
|
- numpy
|
||||||
- ca-certificates=2022.10.11=h06a4308_0
|
- pip
|
||||||
- click=8.0.3=pyhd3eb1b0_0
|
- wandb
|
||||||
- colorama=0.4.6=pyhd8ed1ab_0
|
- pytest
|
||||||
- cycler=0.11.0=pyhd3eb1b0_0
|
- python=3.10
|
||||||
- dbus=1.13.18=hb2f20db_0
|
- python-dotenv
|
||||||
- einops=0.4.1=pyhd8ed1ab_0
|
- pytorch=1.13
|
||||||
- expat=2.4.9=h6a678d5_0
|
- requests
|
||||||
- fontconfig=2.13.1=h6c09931_0
|
- sqlite
|
||||||
- fonttools=4.25.0=pyhd3eb1b0_0
|
- tqdm
|
||||||
- freetype=2.12.1=h4a9f257_0
|
platforms:
|
||||||
- giflib=5.2.1=h7b6447c_0
|
- linux-64
|
||||||
- glib=2.69.1=h4ff587b_1
|
|
||||||
- gst-plugins-base=1.14.0=h8213a91_2
|
|
||||||
- gstreamer=1.14.0=h28cd5cc_2
|
|
||||||
- icu=58.2=he6710b0_3
|
|
||||||
- intel-openmp=2021.4.0=h06a4308_3561
|
|
||||||
- jpeg=9e=h7f8727e_0
|
|
||||||
- kiwisolver=1.4.2=py310h295c915_0
|
|
||||||
- krb5=1.19.2=hac12032_0
|
|
||||||
- lcms2=2.12=h3be6417_0
|
|
||||||
- ld_impl_linux-64=2.39=hc81fddc_0
|
|
||||||
- lerc=3.0=h295c915_0
|
|
||||||
- libbrotlicommon=1.0.9=h5eee18b_7
|
|
||||||
- libbrotlidec=1.0.9=h5eee18b_7
|
|
||||||
- libbrotlienc=1.0.9=h5eee18b_7
|
|
||||||
- libclang=10.0.1=default_hb85057a_2
|
|
||||||
- libdeflate=1.8=h7f8727e_5
|
|
||||||
- libedit=3.1.20210910=h7f8727e_0
|
|
||||||
- libevent=2.1.12=h8f2d780_0
|
|
||||||
- libffi=3.3=he6710b0_2
|
|
||||||
- libgcc-ng=12.2.0=h65d4601_19
|
|
||||||
- libgfortran-ng=12.2.0=h69a702a_19
|
|
||||||
- libgfortran5=12.2.0=h337968e_19
|
|
||||||
- libgomp=12.2.0=h65d4601_19
|
|
||||||
- libllvm10=10.0.1=hbcb73fb_5
|
|
||||||
- libnsl=2.0.0=h7f98852_0
|
|
||||||
- libopenblas=0.3.21=pthreads_h78a6416_3
|
|
||||||
- libpng=1.6.37=hbc83047_0
|
|
||||||
- libpq=12.9=h16c4e8d_3
|
|
||||||
- libstdcxx-ng=12.2.0=h46fd767_19
|
|
||||||
- libtiff=4.4.0=hecacb30_0
|
|
||||||
- libuuid=1.0.3=h7f8727e_2
|
|
||||||
- libwebp=1.2.4=h11a3e52_0
|
|
||||||
- libwebp-base=1.2.4=h5eee18b_0
|
|
||||||
- libxcb=1.15=h7f8727e_0
|
|
||||||
- libxkbcommon=1.0.1=hfa300c1_0
|
|
||||||
- libxml2=2.9.14=h74e7548_0
|
|
||||||
- libxslt=1.1.35=h4e12654_0
|
|
||||||
- lz4-c=1.9.3=h295c915_1
|
|
||||||
- matplotlib=3.5.2=py310h06a4308_0
|
|
||||||
- matplotlib-base=3.5.2=py310hf590b9c_0
|
|
||||||
- mkl=2021.4.0=h06a4308_640
|
|
||||||
- mkl-service=2.4.0=py310h7f8727e_0
|
|
||||||
- mkl_fft=1.3.1=py310hd6ae3a3_0
|
|
||||||
- mkl_random=1.2.2=py310h00e6091_0
|
|
||||||
- munkres=1.1.4=py_0
|
|
||||||
- mypy_extensions=0.4.3=py310h06a4308_0
|
|
||||||
- ncurses=6.3=h27087fc_1
|
|
||||||
- nspr=4.33=h295c915_0
|
|
||||||
- nss=3.74=h0370c37_0
|
|
||||||
- numpy=1.23.3=py310hd5efca6_0
|
|
||||||
- numpy-base=1.23.3=py310h8e6c178_0
|
|
||||||
- openssl=1.1.1q=h7f8727e_0
|
|
||||||
- packaging=21.3=pyhd3eb1b0_0
|
|
||||||
- pathspec=0.10.1=pyhd8ed1ab_0
|
|
||||||
- pcre=8.45=h295c915_0
|
|
||||||
- pillow=9.2.0=py310hace64e9_1
|
|
||||||
- pip=22.3=pyhd8ed1ab_0
|
|
||||||
- platformdirs=2.5.2=pyhd8ed1ab_1
|
|
||||||
- ply=3.11=py310h06a4308_0
|
|
||||||
- pyparsing=3.0.9=py310h06a4308_0
|
|
||||||
- pyqt=5.15.7=py310h6a678d5_1
|
|
||||||
- python=3.10.6=haa1d7c7_1
|
|
||||||
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
|
||||||
- pytorch=1.13.0=py3.10_cpu_0
|
|
||||||
- pytorch-mutex=1.0=cpu
|
|
||||||
- qt-main=5.15.2=h327a75a_7
|
|
||||||
- qt-webengine=5.15.9=hd2b0992_4
|
|
||||||
- qtwebkit=5.212=h4eab89a_4
|
|
||||||
- readline=8.1.2=h0f457ee_0
|
|
||||||
- setuptools=65.5.0=pyhd8ed1ab_0
|
|
||||||
- sip=6.6.2=py310h6a678d5_0
|
|
||||||
- six=1.16.0=pyhd3eb1b0_1
|
|
||||||
- sqlite=3.39.3=h5082296_0
|
|
||||||
- tk=8.6.12=h1ccaba5_0
|
|
||||||
- toml=0.10.2=pyhd3eb1b0_0
|
|
||||||
- tomli=2.0.1=py310h06a4308_0
|
|
||||||
- tornado=6.2=py310h5eee18b_0
|
|
||||||
- tqdm=4.64.1=pyhd8ed1ab_0
|
|
||||||
- typing_extensions=4.3.0=py310h06a4308_0
|
|
||||||
- tzdata=2022e=h191b570_0
|
|
||||||
- wheel=0.37.1=pyhd8ed1ab_0
|
|
||||||
- xz=5.2.6=h166bdaf_0
|
|
||||||
- zlib=1.2.13=h5eee18b_0
|
|
||||||
- zstd=1.5.2=ha4553b6_0
|
|
||||||
- pip:
|
|
||||||
- pyqt5-sip==12.11.0
|
|
||||||
prefix: /home/personal/Dev/conda/envs/ml
|
|
||||||
|
|||||||
8
src/config/main.yaml
Normal file
8
src/config/main.yaml
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
app_dir: ${hydra:runtime.cwd}
|
||||||
|
debug: true
|
||||||
|
lr: 2e-4
|
||||||
|
batch_size: 16
|
||||||
|
num_workers: 0
|
||||||
|
device: "cpu"
|
||||||
|
epochs: 4
|
||||||
|
dev_after: 20
|
||||||
54
src/data.py
54
src/data.py
@@ -1,54 +0,0 @@
|
|||||||
from torch.utils.data import Dataset
|
|
||||||
import numpy as np
|
|
||||||
import einops
|
|
||||||
import csv
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class FashionDataset(Dataset):
|
|
||||||
def __init__(self, path: str):
|
|
||||||
self.path = path
|
|
||||||
self.x, self.y = self.load()
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
return (self.x[idx], self.y[idx])
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.x)
|
|
||||||
|
|
||||||
def load(self):
|
|
||||||
# opening the CSV file
|
|
||||||
with open(self.path, mode="r") as file:
|
|
||||||
images = list()
|
|
||||||
classes = list()
|
|
||||||
# reading the CSV file
|
|
||||||
csvFile = csv.reader(file)
|
|
||||||
# displaying the contents of the CSV file
|
|
||||||
header = next(csvFile)
|
|
||||||
limit = 1000
|
|
||||||
for line in csvFile:
|
|
||||||
if limit < 1:
|
|
||||||
break
|
|
||||||
classes.append(int(line[:1][0]))
|
|
||||||
images.append([int(x) for x in line[1:]])
|
|
||||||
limit -= 1
|
|
||||||
classes = torch.tensor(classes, dtype=torch.long)
|
|
||||||
images = torch.tensor(images, dtype=torch.float32)
|
|
||||||
images = einops.rearrange(images, "n (w h) -> n w h", w=28, h=28)
|
|
||||||
images = einops.repeat(
|
|
||||||
images, "n w h -> n c (w r_w) (h r_h)", c=1, r_w=8, r_h=8
|
|
||||||
)
|
|
||||||
return (images, classes)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
path = "fashion-mnist_train.csv"
|
|
||||||
dataset = FashionDataset(path=path)
|
|
||||||
print(f"len: {len(dataset)}")
|
|
||||||
print(f"first shape: {dataset[0][0].shape}")
|
|
||||||
mean = einops.reduce(dataset[:10], "n w h -> w h", "mean")
|
|
||||||
print(f"mean shape: {mean.shape}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
0
src/data/README.md
Normal file
0
src/data/README.md
Normal file
0
src/data/__init__.py
Normal file
0
src/data/__init__.py
Normal file
6
src/data/collate.py
Normal file
6
src/data/collate.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
def channel_to_batch(batch):
|
||||||
|
"""TODO"""
|
||||||
|
return batch
|
||||||
72
src/data/dataset.py
Normal file
72
src/data/dataset.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
from torch.utils.data import Dataset
|
||||||
|
import numpy as np
|
||||||
|
import einops
|
||||||
|
import csv
|
||||||
|
import torch
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class MnistDataset(Dataset):
|
||||||
|
"""
|
||||||
|
The MNIST database of handwritten digits.
|
||||||
|
Training set is 60k labeled examples, test is 10k examples.
|
||||||
|
The b/w images normalized to 20x20, preserving aspect ratio.
|
||||||
|
|
||||||
|
It's the defacto standard image training set to learn about classification in DL
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, path: Path):
|
||||||
|
"""
|
||||||
|
give a path to a dir that contains the following csv files:
|
||||||
|
https://pjreddie.com/projects/mnist-in-csv/
|
||||||
|
"""
|
||||||
|
self.path = path
|
||||||
|
self.features, self.labels = self.load()
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return (self.features[idx], self.labels[idx])
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.features)
|
||||||
|
|
||||||
|
def load(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# opening the CSV file
|
||||||
|
with open(self.path, mode="r") as file:
|
||||||
|
images = list()
|
||||||
|
labels = list()
|
||||||
|
# reading the CSV file
|
||||||
|
csvFile = csv.reader(file)
|
||||||
|
# displaying the contents of the CSV file
|
||||||
|
# header = next(csvFile)
|
||||||
|
limit = 1000
|
||||||
|
for line in csvFile:
|
||||||
|
if limit < 1:
|
||||||
|
break
|
||||||
|
label = int(line[0])
|
||||||
|
labels.append(label)
|
||||||
|
image = [int(x) for x in line[1:]]
|
||||||
|
images.append(image)
|
||||||
|
limit -= 1
|
||||||
|
labels = torch.tensor(labels, dtype=torch.long)
|
||||||
|
images = torch.tensor(images, dtype=torch.float32)
|
||||||
|
images = einops.rearrange(images, "n (w h) -> n w h", w=28, h=28)
|
||||||
|
images = einops.repeat(
|
||||||
|
images, "n w h -> n c (w r_w) (h r_h)", c=1, r_w=8, r_h=8
|
||||||
|
)
|
||||||
|
return (images, labels)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
path = "storage/mnist_train.csv"
|
||||||
|
dataset = MnistDataset(path=path)
|
||||||
|
print(f"len: {len(dataset)}")
|
||||||
|
print(f"first shape: {dataset[0][0].shape}")
|
||||||
|
mean = einops.reduce(dataset[:10][0], "n w h -> w h", "mean")
|
||||||
|
print(f"mean shape: {mean.shape}")
|
||||||
|
print(f"mean image: {mean}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
0
src/eval.py
Normal file
0
src/eval.py
Normal file
158
src/mpv.py
158
src/mpv.py
@@ -1,158 +0,0 @@
|
|||||||
# pytorch mlp for multiclass classification
|
|
||||||
from numpy import vstack
|
|
||||||
from numpy import argmax
|
|
||||||
from pandas import read_csv
|
|
||||||
from sklearn.preprocessing import LabelEncoder
|
|
||||||
from sklearn.metrics import accuracy_score
|
|
||||||
from torch import Tensor
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from torch.utils.data import random_split
|
|
||||||
from torch.nn import Linear
|
|
||||||
from torch.nn import ReLU
|
|
||||||
from torch.nn import Softmax
|
|
||||||
from torch.nn import Module
|
|
||||||
from torch.optim import SGD
|
|
||||||
from torch.nn import CrossEntropyLoss
|
|
||||||
from torch.nn.init import kaiming_uniform_
|
|
||||||
from torch.nn.init import xavier_uniform_
|
|
||||||
|
|
||||||
# dataset definition
|
|
||||||
class CSVDataset(Dataset):
|
|
||||||
# load the dataset
|
|
||||||
def __init__(self, path):
|
|
||||||
# load the csv file as a dataframe
|
|
||||||
df = read_csv(path, header=None)
|
|
||||||
# store the inputs and outputs
|
|
||||||
self.X = df.values[:, :-1]
|
|
||||||
self.y = df.values[:, -1]
|
|
||||||
# ensure input data is floats
|
|
||||||
self.X = self.X.astype('float32')
|
|
||||||
# label encode target and ensure the values are floats
|
|
||||||
self.y = LabelEncoder().fit_transform(self.y)
|
|
||||||
|
|
||||||
# number of rows in the dataset
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.X)
|
|
||||||
|
|
||||||
# get a row at an index
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
return [self.X[idx], self.y[idx]]
|
|
||||||
|
|
||||||
# get indexes for train and test rows
|
|
||||||
def get_splits(self, n_test=0.33):
|
|
||||||
# determine sizes
|
|
||||||
test_size = round(n_test * len(self.X))
|
|
||||||
train_size = len(self.X) - test_size
|
|
||||||
# calculate the split
|
|
||||||
return random_split(self, [train_size, test_size])
|
|
||||||
|
|
||||||
# model definition
|
|
||||||
class MLP(Module):
|
|
||||||
# define model elements
|
|
||||||
def __init__(self, n_inputs):
|
|
||||||
super(MLP, self).__init__()
|
|
||||||
# input to first hidden layer
|
|
||||||
self.hidden1 = Linear(n_inputs, 10)
|
|
||||||
kaiming_uniform_(self.hidden1.weight, nonlinearity='relu')
|
|
||||||
self.act1 = ReLU()
|
|
||||||
# second hidden layer
|
|
||||||
self.hidden2 = Linear(10, 8)
|
|
||||||
kaiming_uniform_(self.hidden2.weight, nonlinearity='relu')
|
|
||||||
self.act2 = ReLU()
|
|
||||||
# third hidden layer and output
|
|
||||||
self.hidden3 = Linear(8, 3)
|
|
||||||
xavier_uniform_(self.hidden3.weight)
|
|
||||||
self.act3 = Softmax(dim=1)
|
|
||||||
|
|
||||||
# forward propagate input
|
|
||||||
def forward(self, X):
|
|
||||||
# input to first hidden layer
|
|
||||||
X = self.hidden1(X)
|
|
||||||
X = self.act1(X)
|
|
||||||
# second hidden layer
|
|
||||||
X = self.hidden2(X)
|
|
||||||
X = self.act2(X)
|
|
||||||
# output layer
|
|
||||||
X = self.hidden3(X)
|
|
||||||
X = self.act3(X)
|
|
||||||
return X
|
|
||||||
|
|
||||||
# prepare the dataset
|
|
||||||
def prepare_data(path):
|
|
||||||
# load the dataset
|
|
||||||
dataset = CSVDataset(path)
|
|
||||||
# calculate split
|
|
||||||
train, test = dataset.get_splits()
|
|
||||||
# prepare data loaders
|
|
||||||
train_dl = DataLoader(train, batch_size=32, shuffle=True)
|
|
||||||
test_dl = DataLoader(test, batch_size=1024, shuffle=False)
|
|
||||||
return train_dl, test_dl
|
|
||||||
|
|
||||||
# train the model
|
|
||||||
def train_model(train_dl, model):
|
|
||||||
# define the optimization
|
|
||||||
criterion = CrossEntropyLoss()
|
|
||||||
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
|
|
||||||
# enumerate epochs
|
|
||||||
for epoch in range(500):
|
|
||||||
# enumerate mini batches
|
|
||||||
for i, (inputs, targets) in enumerate(train_dl):
|
|
||||||
# clear the gradients
|
|
||||||
optimizer.zero_grad()
|
|
||||||
# compute the model output
|
|
||||||
yhat = model(inputs)
|
|
||||||
# calculate loss
|
|
||||||
loss = criterion(yhat, targets)
|
|
||||||
# credit assignment
|
|
||||||
loss.backward()
|
|
||||||
# update model weights
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
# evaluate the model
|
|
||||||
def evaluate_model(test_dl, model):
|
|
||||||
predictions, actuals = list(), list()
|
|
||||||
for i, (inputs, targets) in enumerate(test_dl):
|
|
||||||
# evaluate the model on the test set
|
|
||||||
yhat = model(inputs)
|
|
||||||
# retrieve numpy array
|
|
||||||
yhat = yhat.detach().numpy()
|
|
||||||
actual = targets.numpy()
|
|
||||||
# convert to class labels
|
|
||||||
yhat = argmax(yhat, axis=1)
|
|
||||||
# reshape for stacking
|
|
||||||
actual = actual.reshape((len(actual), 1))
|
|
||||||
yhat = yhat.reshape((len(yhat), 1))
|
|
||||||
# store
|
|
||||||
predictions.append(yhat)
|
|
||||||
actuals.append(actual)
|
|
||||||
predictions, actuals = vstack(predictions), vstack(actuals)
|
|
||||||
# calculate accuracy
|
|
||||||
acc = accuracy_score(actuals, predictions)
|
|
||||||
return acc
|
|
||||||
|
|
||||||
# make a class prediction for one row of data
|
|
||||||
def predict(row, model):
|
|
||||||
# convert row to data
|
|
||||||
row = Tensor([row])
|
|
||||||
# make prediction
|
|
||||||
yhat = model(row)
|
|
||||||
# retrieve numpy array
|
|
||||||
yhat = yhat.detach().numpy()
|
|
||||||
return yhat
|
|
||||||
|
|
||||||
# prepare the data
|
|
||||||
path = 'https://raw.githubusercontent.com/jbrownlee/Datasets/master/iris.csv'
|
|
||||||
train_dl, test_dl = prepare_data(path)
|
|
||||||
print(len(train_dl.dataset), len(test_dl.dataset))
|
|
||||||
# define the network
|
|
||||||
model = MLP(4)
|
|
||||||
# train the model
|
|
||||||
train_model(train_dl, model)
|
|
||||||
# evaluate the model
|
|
||||||
acc = evaluate_model(test_dl, model)
|
|
||||||
print('Accuracy: %.3f' % acc)
|
|
||||||
# make a single prediction
|
|
||||||
row = [5.1,3.5,1.4,0.2]
|
|
||||||
yhat = predict(row, model)
|
|
||||||
print('Predicted: %s (class=%d)' % (yhat, argmax(yhat)))
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
"""
|
|
||||||
main class for building a DL pipeline.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import click
|
|
||||||
from batch import Batch
|
|
||||||
from model.linear import DNN
|
|
||||||
from model.cnn import VGG16, VGG11
|
|
||||||
from data import FashionDataset
|
|
||||||
from utils import Stage
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
|
||||||
def cli():
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
|
||||||
def train():
|
|
||||||
batch_size = 16
|
|
||||||
num_workers = 8
|
|
||||||
|
|
||||||
path = "fashion-mnist_train.csv"
|
|
||||||
trainset = FashionDataset(path=path)
|
|
||||||
|
|
||||||
trainloader = torch.utils.data.DataLoader(
|
|
||||||
trainset, batch_size=batch_size, shuffle=False, num_workers=num_workers
|
|
||||||
)
|
|
||||||
model = VGG11(in_channels=1, num_classes=10)
|
|
||||||
criterion = torch.nn.CrossEntropyLoss()
|
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
|
|
||||||
batch = Batch(
|
|
||||||
stage=Stage.TRAIN,
|
|
||||||
model=model,
|
|
||||||
device=torch.device("cpu"),
|
|
||||||
loader=trainloader,
|
|
||||||
criterion=criterion,
|
|
||||||
optimizer=optimizer,
|
|
||||||
)
|
|
||||||
batch.run(
|
|
||||||
"Run run run run. Run run run away. Oh Oh oH OHHHHHHH yayayayayayayayaya! - David Byrne"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
cli()
|
|
||||||
0
src/pipeline/README.md
Normal file
0
src/pipeline/README.md
Normal file
111
src/pipeline/logger.py
Normal file
111
src/pipeline/logger.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
from tkinter import W
|
||||||
|
import torch
|
||||||
|
import wandb
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from einops import rearrange
|
||||||
|
from typing import Protocol, Tuple, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class Logger(Protocol):
|
||||||
|
def metrics(self, metrics: dict, epoch: int):
|
||||||
|
"""loss etc."""
|
||||||
|
|
||||||
|
def hyperparameters(self, hyperparameters: dict):
|
||||||
|
"""model states"""
|
||||||
|
|
||||||
|
def predictions(self, predictions: dict):
|
||||||
|
"""inference time stuff"""
|
||||||
|
|
||||||
|
def images(self, images: np.ndarray):
|
||||||
|
"""log images"""
|
||||||
|
|
||||||
|
|
||||||
|
class WandbLogger:
|
||||||
|
def __init__(self, project: str, entity: str, name: Optional[str], notes: str):
|
||||||
|
self.project = project
|
||||||
|
self.entity = entity
|
||||||
|
self.notes = notes
|
||||||
|
self.experiment = wandb.init(project=project, entity=entity, notes=notes)
|
||||||
|
self.experiment.name = name
|
||||||
|
|
||||||
|
self.data_dict = {}
|
||||||
|
|
||||||
|
def metrics(self, metrics: dict):
|
||||||
|
"""loss etc."""
|
||||||
|
|
||||||
|
self.data_dict.update(metrics)
|
||||||
|
|
||||||
|
def hyperparameters(self, hyperparameters: dict):
|
||||||
|
"""model states"""
|
||||||
|
self.experiment.config.update(hyperparameters, allow_val_change=True)
|
||||||
|
|
||||||
|
def predictions(self, predictions: dict):
|
||||||
|
"""inference time stuff"""
|
||||||
|
|
||||||
|
def image(self, image: dict):
|
||||||
|
"""log images to wandb"""
|
||||||
|
self.data_dict.update({'Generate Image' : image})
|
||||||
|
|
||||||
|
def video(self, images: str, title: str):
|
||||||
|
"""log images to wandb"""
|
||||||
|
|
||||||
|
images = np.uint8(rearrange(images, 't b c h w -> b t c h w'))
|
||||||
|
self.data_dict.update({f"{title}": wandb.Video(images, fps=20)})
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
self.experiment.log(self.data_dict)
|
||||||
|
self.data_dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
class DebugLogger:
|
||||||
|
def __init__(self, project: str, entity: str, name: str, notes: str):
|
||||||
|
self.project = project
|
||||||
|
self.entity = entity
|
||||||
|
self.name = name
|
||||||
|
self.notes = notes
|
||||||
|
|
||||||
|
def metrics(self, metrics: dict, epoch: int = None):
|
||||||
|
"""
|
||||||
|
loss etc.
|
||||||
|
"""
|
||||||
|
print(f"metrics: {metrics}")
|
||||||
|
|
||||||
|
def hyperparameters(self, hyperparameters: dict):
|
||||||
|
"""
|
||||||
|
model states
|
||||||
|
"""
|
||||||
|
print(f"hyperparameters: {hyperparameters}")
|
||||||
|
|
||||||
|
def predictions(self, predictions: dict):
|
||||||
|
"""
|
||||||
|
inference time stuff
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Checkpoint:
|
||||||
|
def __init__(self, checkpoint_path):
|
||||||
|
self.checkpoint_path = checkpoint_path
|
||||||
|
|
||||||
|
def load(self) -> Tuple:
|
||||||
|
checkpoint = torch.load(self.checkpoint_path)
|
||||||
|
model = checkpoint["model"]
|
||||||
|
optimizer = checkpoint["optimizer"]
|
||||||
|
epoch = checkpoint["epoch"]
|
||||||
|
loss = checkpoint["loss"]
|
||||||
|
return (model, optimizer, epoch, loss)
|
||||||
|
|
||||||
|
def save(self, model: torch.nn.Module, optimizer, epoch, loss):
|
||||||
|
checkpoint = {
|
||||||
|
"model": model,
|
||||||
|
"optimizer": optimizer,
|
||||||
|
"epoch": epoch,
|
||||||
|
"loss": loss,
|
||||||
|
}
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
|
||||||
|
name = "".join(random.choices(string.ascii_letters, k=10)) + ".tar"
|
||||||
|
torch.save(checkpoint, f"{name}")
|
||||||
|
|
||||||
|
|
||||||
@@ -1,22 +1,27 @@
|
|||||||
|
"""
|
||||||
|
runner for training and valdating
|
||||||
|
"""
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch import optim
|
from torch import optim
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from data import FashionDataset
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from utils import Stage
|
from pipeline.utils import Stage
|
||||||
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
|
|
||||||
class Batch:
|
class Runner:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
stage: Stage,
|
stage: Stage,
|
||||||
model: nn.Module, device,
|
model: nn.Module,
|
||||||
|
device,
|
||||||
loader: DataLoader,
|
loader: DataLoader,
|
||||||
optimizer: optim.Optimizer,
|
optimizer: optim.Optimizer,
|
||||||
criterion: nn.Module,
|
criterion: nn.Module,
|
||||||
|
config: DictConfig = None,
|
||||||
):
|
):
|
||||||
"""todo"""
|
self.config = config
|
||||||
self.stage = stage
|
self.stage = stage
|
||||||
self.device = device
|
self.device = device
|
||||||
self.model = model.to(device)
|
self.model = model.to(device)
|
||||||
@@ -26,14 +31,18 @@ class Batch:
|
|||||||
self.loss = 0
|
self.loss = 0
|
||||||
|
|
||||||
def run(self, desc):
    """Run one full pass over the loader; return the accumulated loss.

    Args:
        desc: label shown on the tqdm progress bar.

    Returns:
        The runner's accumulated summed batch loss (``self.loss``).

    Bug fixes:
    - Gradient updates (``zero_grad``/``backward``/``step``) previously ran
      unconditionally, so a DEV-stage runner updated the model weights on
      validation data. They are now gated on ``Stage.TRAIN``, and the model
      is switched to eval mode for non-training stages.
    - ``self.config.debug`` crashed with AttributeError when no config was
      supplied (``config`` defaults to None in ``__init__``); guarded now.
    """
    is_train = self.stage == Stage.TRAIN
    if is_train:
        self.model.train()
    else:
        self.model.eval()
    if self.config is not None and self.config.debug:
        breakpoint()
    for batch, (x, y) in enumerate(tqdm(self.loader, desc=desc)):
        if is_train:
            self.optimizer.zero_grad()
        loss = self._run_batch((x, y))
        if is_train:
            loss.backward()  # accumulate gradients
            self.optimizer.step()  # apply the gradient update to the weights
        self.loss += loss.item()
    return self.loss
|
||||||
|
|
||||||
def _run_batch(self, sample):
|
def _run_batch(self, sample):
|
||||||
true_x, true_y = sample
|
true_x, true_y = sample
|
||||||
@@ -41,29 +50,3 @@ class Batch:
|
|||||||
pred_y = self.model(true_x)
|
pred_y = self.model(true_x)
|
||||||
loss = self.criterion(pred_y, true_y)
|
loss = self.criterion(pred_y, true_y)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Smoke-test the Batch runner: one TRAIN pass of a tiny conv model
    over the FashionMNIST CSV on CPU."""
    model = nn.Conv2d(1, 64, 3)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

    dataset = FashionDataset("fashion-mnist_train.csv")
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=16, shuffle=False, num_workers=1
    )

    batch = Batch(
        Stage.TRAIN,
        device=torch.device("cpu"),
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loader=loader,
    )
    batch.run("test")


if __name__ == "__main__":
    main()
|
|
||||||
0
src/sample.py
Normal file
0
src/sample.py
Normal file
88
src/train.py
Normal file
88
src/train.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
"""
|
||||||
|
main class for building a DL pipeline.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
the main entry point for training a model
|
||||||
|
|
||||||
|
coordinates:
|
||||||
|
|
||||||
|
- datasets
|
||||||
|
- dataloaders
|
||||||
|
- runner
|
||||||
|
|
||||||
|
"""
|
||||||
|
from pipeline.runner import Runner
|
||||||
|
from model.linear import DNN
|
||||||
|
from model.cnn import VGG16, VGG11
|
||||||
|
from data.dataset import MnistDataset
|
||||||
|
from pipeline.utils import Stage
|
||||||
|
import torch
|
||||||
|
from pathlib import Path
|
||||||
|
from data.collate import channel_to_batch
|
||||||
|
import hydra
|
||||||
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(config_path="config", config_name="main")
def train(config: DictConfig):
    """Train VGG11 on MNIST, coordinating datasets, loaders and runners.

    Args:
        config: hydra config providing lr, batch_size, num_workers, device,
            epochs, app_dir, dev_after and debug.

    Bug fix: the original loop was ``if epoch % dev_after == 0: dev
    else: train`` — so every validation epoch skipped training entirely
    (and epoch 0 validated an untrained model). Training now runs every
    epoch, with validation every ``dev_after`` epochs; the loss returned
    by each runner is printed instead of being assigned to unused locals.
    """
    if config.debug:
        breakpoint()

    lr = config.lr
    batch_size = config.batch_size
    num_workers = config.num_workers
    device = config.device
    epochs = config.epochs

    app_dir = Path(config.app_dir)
    trainset = MnistDataset(path=app_dir / "data/mnist_train.csv")
    devset = MnistDataset(path=app_dir / "data/mnist_test.csv")

    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        # collate_fn=channel_to_batch,
    )
    devloader = torch.utils.data.DataLoader(
        devset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        # collate_fn=channel_to_batch,
    )

    model = VGG11(in_channels=1, num_classes=10)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_runner = Runner(
        stage=Stage.TRAIN,
        model=model,
        device=torch.device(device),
        loader=trainloader,
        criterion=criterion,
        optimizer=optimizer,
        config=config,
    )
    dev_runner = Runner(
        stage=Stage.DEV,
        model=model,
        device=torch.device(device),
        loader=devloader,
        criterion=criterion,
        optimizer=optimizer,
        config=config,
    )

    for epoch in range(epochs):
        train_loss = train_runner.run("train epoch")
        print(f"epoch {epoch}: train loss {train_loss}")
        # Validate periodically; previously these epochs skipped training.
        if epoch % config.dev_after == 0:
            dev_loss = dev_runner.run("dev epoch")
            print(f"epoch {epoch}: dev loss {dev_loss}")


if __name__ == "__main__":
    train()
|
||||||
0
test/__init__.py
Normal file
0
test/__init__.py
Normal file
10
test/test_pipeline.py
Normal file
10
test/test_pipeline.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from src.model.linear import DNN
|
||||||
|
from src.data import GenericDataset
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def test_size_of_dataset():
    """GenericDataset's feature dimension follows the INPUT_FEATURES env var.

    Bug fix: the original set ``os.environ["INPUT_FEATURES"]`` and never
    restored it, leaking the mutation into every test that runs afterwards.
    The previous value is now restored in a ``finally`` block.
    """
    features = 40
    previous = os.environ.get("INPUT_FEATURES")
    os.environ["INPUT_FEATURES"] = str(features)
    try:
        dataset = GenericDataset()
        assert len(dataset[0][0]) == features
    finally:
        # Undo the environment mutation regardless of the assertion outcome.
        if previous is None:
            del os.environ["INPUT_FEATURES"]
        else:
            os.environ["INPUT_FEATURES"] = previous
|
||||||
Reference in New Issue
Block a user