add hydra config.
remove click. add launch script. add test dir. switch from fashion mnist to generic.
This commit is contained in:
parent
404e39206b
commit
1f13224c4f
|
@ -1,2 +1,4 @@
|
||||||
storage/
|
storage/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
outputs/
|
||||||
|
.env
|
||||||
|
|
2
Makefile
2
Makefile
|
@ -3,7 +3,7 @@ CONDA_ENV=ml_pipeline
|
||||||
all: run
|
all: run
|
||||||
|
|
||||||
run:
|
run:
|
||||||
python src/pipeline.py train
|
./launch.sh
|
||||||
|
|
||||||
data:
|
data:
|
||||||
python src/data.py
|
python src/data.py
|
||||||
|
|
26
README.md
26
README.md
|
@ -7,9 +7,9 @@ Instead of remembering where to put everything and making a different choice for
|
||||||
Think of it like a mini-pytorch lightening, with all the fory internals exposed for extension and modification.
|
Think of it like a mini-pytorch lightening, with all the fory internals exposed for extension and modification.
|
||||||
|
|
||||||
|
|
||||||
## Usage
|
# Usage
|
||||||
|
|
||||||
### Install:
|
## Install:
|
||||||
|
|
||||||
Install the conda requirements:
|
Install the conda requirements:
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ Which is a proxy for calling:
|
||||||
conda env updates -n ml_pipeline --file environment.yml
|
conda env updates -n ml_pipeline --file environment.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
### Run:
|
## Run:
|
||||||
|
|
||||||
Run the code on MNIST with the following command:
|
Run the code on MNIST with the following command:
|
||||||
|
|
||||||
|
@ -31,3 +31,23 @@ Run the code on MNIST with the following command:
|
||||||
make run
|
make run
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
# Tutorial
|
||||||
|
|
||||||
|
The motivation for building a template for deep learning pipelines is this: deep learning is hard enough without every code baase being a little different.
|
||||||
|
|
||||||
|
Especially in a research lab, standardizing on a few components makes switching between projects easier.
|
||||||
|
|
||||||
|
In this template, you'll see the following:
|
||||||
|
|
||||||
|
- `src/model`, `src/config`, `storage`, `test` dirs.
|
||||||
|
- `if __name__ == "__main__"` tests.
|
||||||
|
- Hydra config.
|
||||||
|
- dataloader, optimizer, criterion, device, state are constructed in main, but passed to an object that runs batches.
|
||||||
|
- tqdm to track progress.
|
||||||
|
- debug config flag enables lots breakpoints.
|
||||||
|
- python type hints.
|
||||||
|
- a `launch.sh` script to dispatch training.
|
||||||
|
- a Makefile to install and run stuff.
|
||||||
|
- automatic linting with the `black` package.
|
||||||
|
- collate functions!
|
||||||
|
|
20
src/batch.py
20
src/batch.py
|
@ -2,21 +2,24 @@ import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch import optim
|
from torch import optim
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from data import FashionDataset
|
from data import MnistDataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from utils import Stage
|
from utils import Stage
|
||||||
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
|
|
||||||
class Batch:
|
class Batch:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
stage: Stage,
|
stage: Stage,
|
||||||
model: nn.Module, device,
|
model: nn.Module,
|
||||||
|
device,
|
||||||
loader: DataLoader,
|
loader: DataLoader,
|
||||||
optimizer: optim.Optimizer,
|
optimizer: optim.Optimizer,
|
||||||
criterion: nn.Module,
|
criterion: nn.Module,
|
||||||
|
config: DictConfig = None,
|
||||||
):
|
):
|
||||||
"""todo"""
|
self.config = config
|
||||||
self.stage = stage
|
self.stage = stage
|
||||||
self.device = device
|
self.device = device
|
||||||
self.model = model.to(device)
|
self.model = model.to(device)
|
||||||
|
@ -26,7 +29,11 @@ class Batch:
|
||||||
self.loss = 0
|
self.loss = 0
|
||||||
|
|
||||||
def run(self, desc):
|
def run(self, desc):
|
||||||
self.model.train()
|
# set the model to train model
|
||||||
|
if self.stage == Stage.TRAIN:
|
||||||
|
self.model.train()
|
||||||
|
if self.config.debug:
|
||||||
|
breakpoint()
|
||||||
epoch = 0
|
epoch = 0
|
||||||
for epoch, (x, y) in enumerate(tqdm(self.loader, desc=desc)):
|
for epoch, (x, y) in enumerate(tqdm(self.loader, desc=desc)):
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
@ -34,6 +41,7 @@ class Batch:
|
||||||
loss.backward() # Send loss backwards to accumulate gradients
|
loss.backward() # Send loss backwards to accumulate gradients
|
||||||
self.optimizer.step() # Perform a gradient update on the weights of the mode
|
self.optimizer.step() # Perform a gradient update on the weights of the mode
|
||||||
self.loss += loss.item()
|
self.loss += loss.item()
|
||||||
|
return self.loss
|
||||||
|
|
||||||
def _run_batch(self, sample):
|
def _run_batch(self, sample):
|
||||||
true_x, true_y = sample
|
true_x, true_y = sample
|
||||||
|
@ -47,8 +55,8 @@ def main():
|
||||||
model = nn.Conv2d(1, 64, 3)
|
model = nn.Conv2d(1, 64, 3)
|
||||||
criterion = torch.nn.CrossEntropyLoss()
|
criterion = torch.nn.CrossEntropyLoss()
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
|
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
|
||||||
path = "fashion-mnist_train.csv"
|
path = "mnist_train.csv"
|
||||||
dataset = FashionDataset(path)
|
dataset = MnistDataset(path)
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
num_workers = 1
|
num_workers = 1
|
||||||
loader = torch.utils.data.DataLoader(
|
loader = torch.utils.data.DataLoader(
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
def channel_to_batch(batch):
|
||||||
|
"""TODO"""
|
||||||
|
return batch
|
|
@ -0,0 +1,6 @@
|
||||||
|
app_dir: ${hydra:runtime.cwd}
|
||||||
|
debug: true
|
||||||
|
lr: 2e-4
|
||||||
|
batch_size: 16
|
||||||
|
num_workers: 0
|
||||||
|
device: "cpu"
|
48
src/data.py
48
src/data.py
|
@ -3,51 +3,69 @@ import numpy as np
|
||||||
import einops
|
import einops
|
||||||
import csv
|
import csv
|
||||||
import torch
|
import torch
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
class FashionDataset(Dataset):
|
class MnistDataset(Dataset):
|
||||||
def __init__(self, path: str):
|
"""
|
||||||
|
The MNIST database of handwritten digits.
|
||||||
|
Training set is 60k labeled examples, test is 10k examples.
|
||||||
|
The b/w images normalized to 20x20, preserving aspect ratio.
|
||||||
|
|
||||||
|
It's the defacto standard image training set to learn about classification in DL
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, path: Path):
|
||||||
|
"""
|
||||||
|
give a path to a dir that contains the following csv files:
|
||||||
|
https://pjreddie.com/projects/mnist-in-csv/
|
||||||
|
"""
|
||||||
self.path = path
|
self.path = path
|
||||||
self.x, self.y = self.load()
|
self.features, self.labels = self.load()
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
return (self.x[idx], self.y[idx])
|
return (self.features[idx], self.labels[idx])
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.x)
|
return len(self.features)
|
||||||
|
|
||||||
def load(self):
|
def load(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
# opening the CSV file
|
# opening the CSV file
|
||||||
with open(self.path, mode="r") as file:
|
with open(self.path, mode="r") as file:
|
||||||
images = list()
|
images = list()
|
||||||
classes = list()
|
labels = list()
|
||||||
# reading the CSV file
|
# reading the CSV file
|
||||||
csvFile = csv.reader(file)
|
csvFile = csv.reader(file)
|
||||||
# displaying the contents of the CSV file
|
# displaying the contents of the CSV file
|
||||||
header = next(csvFile)
|
# header = next(csvFile)
|
||||||
limit = 1000
|
limit = 1000
|
||||||
for line in csvFile:
|
for line in csvFile:
|
||||||
if limit < 1:
|
if limit < 1:
|
||||||
break
|
break
|
||||||
classes.append(int(line[:1][0]))
|
label = int(line[0])
|
||||||
images.append([int(x) for x in line[1:]])
|
labels.append(label)
|
||||||
|
image = [int(x) for x in line[1:]]
|
||||||
|
images.append(image)
|
||||||
limit -= 1
|
limit -= 1
|
||||||
classes = torch.tensor(classes, dtype=torch.long)
|
labels = torch.tensor(labels, dtype=torch.long)
|
||||||
images = torch.tensor(images, dtype=torch.float32)
|
images = torch.tensor(images, dtype=torch.float32)
|
||||||
images = einops.rearrange(images, "n (w h) -> n w h", w=28, h=28)
|
images = einops.rearrange(images, "n (w h) -> n w h", w=28, h=28)
|
||||||
images = einops.repeat(
|
images = einops.repeat(
|
||||||
images, "n w h -> n c (w r_w) (h r_h)", c=1, r_w=8, r_h=8
|
images, "n w h -> n c (w r_w) (h r_h)", c=1, r_w=8, r_h=8
|
||||||
)
|
)
|
||||||
return (images, classes)
|
return (images, labels)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
path = "fashion-mnist_train.csv"
|
|
||||||
dataset = FashionDataset(path=path)
|
path = "storage/mnist_train.csv"
|
||||||
|
dataset = MnistDataset(path=path)
|
||||||
print(f"len: {len(dataset)}")
|
print(f"len: {len(dataset)}")
|
||||||
print(f"first shape: {dataset[0][0].shape}")
|
print(f"first shape: {dataset[0][0].shape}")
|
||||||
mean = einops.reduce(dataset[:10], "n w h -> w h", "mean")
|
mean = einops.reduce(dataset[:10][0], "n w h -> w h", "mean")
|
||||||
print(f"mean shape: {mean.shape}")
|
print(f"mean shape: {mean.shape}")
|
||||||
|
print(f"mean image: {mean}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -3,46 +3,54 @@ main class for building a DL pipeline.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import click
|
|
||||||
from batch import Batch
|
from batch import Batch
|
||||||
from model.linear import DNN
|
from model.linear import DNN
|
||||||
from model.cnn import VGG16, VGG11
|
from model.cnn import VGG16, VGG11
|
||||||
from data import FashionDataset
|
from data import MnistDataset
|
||||||
from utils import Stage
|
from utils import Stage
|
||||||
import torch
|
import torch
|
||||||
|
from pathlib import Path
|
||||||
|
from collate import channel_to_batch
|
||||||
|
|
||||||
|
import hydra
|
||||||
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@hydra.main(config_path="config", config_name="main")
|
||||||
def cli():
|
def train(config: DictConfig):
|
||||||
pass
|
if config.debug:
|
||||||
|
breakpoint()
|
||||||
|
lr = config.lr
|
||||||
|
batch_size = config.batch_size
|
||||||
|
num_workers = config.num_workers
|
||||||
|
device = config.device
|
||||||
|
|
||||||
|
path = Path(config.app_dir) / "storage/mnist_train.csv"
|
||||||
@cli.command()
|
trainset = MnistDataset(path=path)
|
||||||
def train():
|
|
||||||
batch_size = 16
|
|
||||||
num_workers = 8
|
|
||||||
|
|
||||||
path = "fashion-mnist_train.csv"
|
|
||||||
trainset = FashionDataset(path=path)
|
|
||||||
|
|
||||||
trainloader = torch.utils.data.DataLoader(
|
trainloader = torch.utils.data.DataLoader(
|
||||||
trainset, batch_size=batch_size, shuffle=False, num_workers=num_workers
|
trainset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=num_workers,
|
||||||
|
# collate_fn=channel_to_batch,
|
||||||
)
|
)
|
||||||
model = VGG11(in_channels=1, num_classes=10)
|
model = VGG11(in_channels=1, num_classes=10)
|
||||||
criterion = torch.nn.CrossEntropyLoss()
|
criterion = torch.nn.CrossEntropyLoss()
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
|
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
||||||
batch = Batch(
|
batch = Batch(
|
||||||
stage=Stage.TRAIN,
|
stage=Stage.TRAIN,
|
||||||
model=model,
|
model=model,
|
||||||
device=torch.device("cpu"),
|
device=torch.device(device),
|
||||||
loader=trainloader,
|
loader=trainloader,
|
||||||
criterion=criterion,
|
criterion=criterion,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
|
config=config,
|
||||||
)
|
)
|
||||||
batch.run(
|
log = batch.run(
|
||||||
"Run run run run. Run run run away. Oh Oh oH OHHHHHHH yayayayayayayayaya! - David Byrne"
|
"Run run run run. Run run run away. Oh Oh oH OHHHHHHH yayayayayayayayaya! - David Byrne"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli()
|
train()
|
||||||
|
|
|
@ -0,0 +1,10 @@
|
||||||
|
from src.model.linear import DNN
|
||||||
|
from src.data import GenericDataset
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def test_size_of_dataset():
|
||||||
|
features = 40
|
||||||
|
os.environ["INPUT_FEATURES"] = str(features)
|
||||||
|
dataset = GenericDataset()
|
||||||
|
assert len(dataset[0][0]) == features
|
Loading…
Reference in New Issue