reorganize pipeline dir and location of files.
add readmes to all dirs.
This commit is contained in:
  parent 0f12b26e40
  commit ecc8939517
Makefile — 6 changed lines

@@ -4,17 +4,17 @@ CONDA_ENV=ml_pipeline

 all: help

 run: ## run the pipeline (train)
-	python src/pipeline.py \
+	python src/train.py \
 		debug=false

 debug: ## run the pipeline (train) with debugging enabled
-	python src/pipeline.py \
+	python src/train.py \
 		debug=true

 data: ## download the mnist data
 	wget https://pjreddie.com/media/files/mnist_train.csv -O data/mnist_train.csv
 	wget https://pjreddie.com/media/files/mnist_test.csv -O data/mnist_test.csv

-env_import: environment.yml ## import any changes to env.yml into conda env
+install: environment.yml ## import any changes to env.yml into conda env
 	conda env update -n ${CONDA_ENV} --file $^

 env_export: ## export the conda environment without package or name
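Note on the bare `debug=false` argument: `src/train.py` parses its command line with hydra (see the `import hydra` / `DictConfig` lines in the train.py hunk near the bottom of this diff), so `key=value` pairs override config values. A minimal sketch of such an entry point, assuming `src/config/main.yaml` defines a top-level `debug` key:

```python
# sketch of a hydra entry point like src/train.py's; the real body is
# not shown in this diff, and the `debug` key in main.yaml is an assumption.
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="config", config_name="main")
def main(cfg: DictConfig) -> None:
    # `python src/train.py debug=true` overrides the value in main.yaml
    print(f"debug mode: {cfg.debug}")


if __name__ == "__main__":
    main()
```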
README.md — 114 changed lines

@@ -6,9 +6,15 @@ Instead of remembering where to put everything and making a different choice for

 Think of it like a mini PyTorch Lightning, with all the gory internals exposed for extension and modification.

+This project lives here: [https://github.com/publicmatt.com/ml_pipeline](https://github.com/publicmatt.com/ml_pipeline).
+
 # Usage

+```bash
+make help # lists available options.
+```
+
 ## Install:

 Install the conda requirements:
@@ -17,10 +23,12 @@ Install the conda requirements:

 make install
 ```

-Which is a proxy for calling:
+## Data:
+
+Download the MNIST data from PJ Reddie's website:

 ```bash
-conda env updates -n ml_pipeline --file environment.yml
+make data
 ```

 ## Run:
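For reference, each row of these CSVs is a label followed by 784 pixel values (one flattened 28×28 image); a quick way to sanity-check the download, sketched under that assumption:

```python
# sketch: inspect the downloaded MNIST CSVs; rows are assumed to be
# label,pixel0,...,pixel783 with pixel values in 0-255.
import numpy as np

rows = np.loadtxt("data/mnist_train.csv", delimiter=",")
labels = rows[:, 0].astype(int)            # digit 0-9 per row
images = rows[:, 1:].reshape(-1, 28, 28)   # one 28x28 image per row
print(labels[0], images[0].max())
```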
@@ -31,7 +39,6 @@ Run the code on MNIST with the following command:

 make run
 ```

-
 # Tutorial

 The motivation for building a template for deep learning pipelines is this: deep learning is hard enough without every code base being a little different.
@@ -41,33 +48,90 @@ Especially in a research lab, standardizing on a few components makes switching

 In this template, you'll see the following:

 ## directory structure
+
+```
+.
+├── README.md
+├── environment.yml
+├── launch.sh
+├── Makefile
+├── data
+│   ├── mnist_test.csv
+│   └── mnist_train.csv
+├── docs
+│   └── 2023-01-26.md
+├── src
+│   ├── config
+│   │   └── main.yaml
+│   ├── data
+│   │   ├── __init__.py
+│   │   ├── README.md
+│   │   ├── collate.py
+│   │   └── dataset.py
+│   ├── eval.py
+│   ├── __init__.py
+│   ├── model
+│   │   ├── __init__.py
+│   │   ├── README.md
+│   │   ├── cnn.py
+│   │   └── linear.py
+│   ├── pipeline
+│   │   ├── __init__.py
+│   │   ├── README.md
+│   │   ├── logger.py
+│   │   ├── runner.py
+│   │   └── utils.py
+│   ├── sample.py
+│   └── train.py
+└── test
+    ├── __init__.py
+    └── test_pipeline.py
+
+8 directories, 25 files
+```

-- `src/model`
-- `src/config`
-- `data/`
-- `test/`
-    - pytest: unit testing.
-        - good for data shape
-        - TODO:
-- `docs/`
-    - switching projects is easier with these in place
-    - organize them
-- `**/__init__.py`
-    - creates modules out of dir.
-    - `import module` works with these.
-- `README.md`
-    - root level required.
-    - can exist inside any dir.
-- `environment.yml`
-- `Makefile`
-    - to install and run stuff.
-    - houses common operations and scripts.
-- `launch.sh`
-    - script to dispatch training.
+## what and why?
+
+- `environment.yml`
+    - Hutch Research has standardized on conda.
+    - here's a good tutorial on getting that set up: [seth email](mailto:bassetis@wwu.edu)
+- `launch.sh` or `Makefile`
+    - to install and run stuff.
+    - houses common operations and scripts.
+    - `launch.sh` to dispatch training.
+- `README.md`
+    - explain the project and how to run it.
+    - list authors.
+    - list resources that new collaborators might need.
+    - root-level dir.
+    - can exist inside any dir.
+    - reads nicely on github.com.
+- `docs/`
+    - switching projects is easier with these in place.
+    - organize them by meeting, or weekly agenda.
+    - generally a collection of markdown files.
+- `test/`
+    - TODO
+    - pytest: unit testing.
+    - good for data shape. not sure what else.
+- `data/`
+    - raw data.
+    - do not commit these to the repo, generally.
+        - `echo "*.csv" >> data/.gitignore`
+- `__init__.py`
+    - creates modules out of dirs.
+    - `import module` works b/c of these.
+- `src/model/`
+    - if you have a large project, you might have multiple architectures/models.
+    - small projects might just have `model/VGG.py` or `model/3d_unet.py`.
+- `src/config`
+    - based on the hydra python package.
+    - quickly change run variables and hyperparameters.
+- `src/pipeline`
+    - where the magic happens.
+    - `train.py` creates all the objects, hands them off to the runner for batching, and monitors each epoch.

 ## testing

 - `if __name__ == "__main__"`.
     - good way to test things.
 - enables lots of breakpoints.
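As an annotation on the `test/` bullet above: a shape check is the kind of test `test/test_pipeline.py` could hold. A hedged sketch; `MnistDataset`'s constructor argument and the `(1, 28, 28)` sample shape are assumptions, since neither appears in this diff:

```python
# test/test_pipeline.py -- sketch of a pytest shape test; the dataset
# path and the sample shape below are assumptions.
from data.dataset import MnistDataset


def test_sample_shape():
    dataset = MnistDataset("data/mnist_test.csv")
    x, y = dataset[0]
    assert x.shape == (1, 28, 28)  # channel, height, width
    assert 0 <= int(y) <= 9        # MNIST digit label
```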
@@ -81,6 +145,8 @@ In this template, you'll see the following:

 ## data
 - collate functions!
+- datasets.
+- dataloader.

 ## formatting python
 - python type hints.
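The collate function this pipeline wires in is `channel_to_batch` (imported from `data.collate` in the train.py hunk below); its body is not part of this diff, so here is only a generic sketch of what a collate function does:

```python
# sketch of a DataLoader collate function; channel_to_batch's real
# implementation lives in src/data/collate.py and is not shown here.
import torch


def collate(samples):
    # samples: list of (image_tensor, label) pairs from the Dataset
    xs, ys = zip(*samples)
    return torch.stack(xs), torch.tensor(ys)

# usage: DataLoader(dataset, batch_size=16, collate_fn=collate)
```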
@@ -88,6 +154,8 @@ In this template, you'll see the following:

 ## running
 - tqdm to track progress.
+- wandb for logging.

 ## architecture
 - dataloader, optimizer, criterion, device, state are constructed in main, but passed to an object that runs batches.
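A sketch of the tqdm + wandb combination those bullets describe; the wandb project name and metric keys are assumptions, not taken from this repo:

```python
# hedged sketch of progress tracking (tqdm) plus logging (wandb);
# project and metric names are assumptions.
import wandb
from tqdm import tqdm


def train_epoch(loader, model, criterion, optimizer):
    """One epoch with a tqdm progress bar and per-batch wandb logging."""
    for inputs, targets in tqdm(loader, desc="train"):
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        wandb.log({"train/loss": loss.item()})

# wandb.init(project="ml_pipeline") would run once, before the epochs.
```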
launch.sh — 2 changed lines (filename inferred: it is the dispatch script in the directory tree above)

@@ -1,2 +1,2 @@
-python src/pipeline.py \
+python src/train.py \
     debug=false
src/mpv.py — 158 lines (file deleted)

@@ -1,158 +0,0 @@
-# pytorch mlp for multiclass classification
-from numpy import vstack
-from numpy import argmax
-from pandas import read_csv
-from sklearn.preprocessing import LabelEncoder
-from sklearn.metrics import accuracy_score
-from torch import Tensor
-from torch.utils.data import Dataset
-from torch.utils.data import DataLoader
-from torch.utils.data import random_split
-from torch.nn import Linear
-from torch.nn import ReLU
-from torch.nn import Softmax
-from torch.nn import Module
-from torch.optim import SGD
-from torch.nn import CrossEntropyLoss
-from torch.nn.init import kaiming_uniform_
-from torch.nn.init import xavier_uniform_
-
-# dataset definition
-class CSVDataset(Dataset):
-    # load the dataset
-    def __init__(self, path):
-        # load the csv file as a dataframe
-        df = read_csv(path, header=None)
-        # store the inputs and outputs
-        self.X = df.values[:, :-1]
-        self.y = df.values[:, -1]
-        # ensure input data is floats
-        self.X = self.X.astype('float32')
-        # label encode target and ensure the values are floats
-        self.y = LabelEncoder().fit_transform(self.y)
-
-    # number of rows in the dataset
-    def __len__(self):
-        return len(self.X)
-
-    # get a row at an index
-    def __getitem__(self, idx):
-        return [self.X[idx], self.y[idx]]
-
-    # get indexes for train and test rows
-    def get_splits(self, n_test=0.33):
-        # determine sizes
-        test_size = round(n_test * len(self.X))
-        train_size = len(self.X) - test_size
-        # calculate the split
-        return random_split(self, [train_size, test_size])
-
-# model definition
-class MLP(Module):
-    # define model elements
-    def __init__(self, n_inputs):
-        super(MLP, self).__init__()
-        # input to first hidden layer
-        self.hidden1 = Linear(n_inputs, 10)
-        kaiming_uniform_(self.hidden1.weight, nonlinearity='relu')
-        self.act1 = ReLU()
-        # second hidden layer
-        self.hidden2 = Linear(10, 8)
-        kaiming_uniform_(self.hidden2.weight, nonlinearity='relu')
-        self.act2 = ReLU()
-        # third hidden layer and output
-        self.hidden3 = Linear(8, 3)
-        xavier_uniform_(self.hidden3.weight)
-        self.act3 = Softmax(dim=1)
-
-    # forward propagate input
-    def forward(self, X):
-        # input to first hidden layer
-        X = self.hidden1(X)
-        X = self.act1(X)
-        # second hidden layer
-        X = self.hidden2(X)
-        X = self.act2(X)
-        # output layer
-        X = self.hidden3(X)
-        X = self.act3(X)
-        return X
-
-# prepare the dataset
-def prepare_data(path):
-    # load the dataset
-    dataset = CSVDataset(path)
-    # calculate split
-    train, test = dataset.get_splits()
-    # prepare data loaders
-    train_dl = DataLoader(train, batch_size=32, shuffle=True)
-    test_dl = DataLoader(test, batch_size=1024, shuffle=False)
-    return train_dl, test_dl
-
-# train the model
-def train_model(train_dl, model):
-    # define the optimization
-    criterion = CrossEntropyLoss()
-    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
-    # enumerate epochs
-    for epoch in range(500):
-        # enumerate mini batches
-        for i, (inputs, targets) in enumerate(train_dl):
-            # clear the gradients
-            optimizer.zero_grad()
-            # compute the model output
-            yhat = model(inputs)
-            # calculate loss
-            loss = criterion(yhat, targets)
-            # credit assignment
-            loss.backward()
-            # update model weights
-            optimizer.step()
-
-# evaluate the model
-def evaluate_model(test_dl, model):
-    predictions, actuals = list(), list()
-    for i, (inputs, targets) in enumerate(test_dl):
-        # evaluate the model on the test set
-        yhat = model(inputs)
-        # retrieve numpy array
-        yhat = yhat.detach().numpy()
-        actual = targets.numpy()
-        # convert to class labels
-        yhat = argmax(yhat, axis=1)
-        # reshape for stacking
-        actual = actual.reshape((len(actual), 1))
-        yhat = yhat.reshape((len(yhat), 1))
-        # store
-        predictions.append(yhat)
-        actuals.append(actual)
-    predictions, actuals = vstack(predictions), vstack(actuals)
-    # calculate accuracy
-    acc = accuracy_score(actuals, predictions)
-    return acc
-
-# make a class prediction for one row of data
-def predict(row, model):
-    # convert row to data
-    row = Tensor([row])
-    # make prediction
-    yhat = model(row)
-    # retrieve numpy array
-    yhat = yhat.detach().numpy()
-    return yhat
-
-# prepare the data
-path = 'https://raw.githubusercontent.com/jbrownlee/Datasets/master/iris.csv'
-train_dl, test_dl = prepare_data(path)
-print(len(train_dl.dataset), len(test_dl.dataset))
-# define the network
-model = MLP(4)
-# train the model
-train_model(train_dl, model)
-# evaluate the model
-acc = evaluate_model(test_dl, model)
-print('Accuracy: %.3f' % acc)
-# make a single prediction
-row = [5.1,3.5,1.4,0.2]
-yhat = predict(row, model)
-print('Predicted: %s (class=%d)' % (yhat, argmax(yhat)))
src/pipeline/runner.py — path inferred from the directory tree above

@@ -5,9 +5,8 @@ import torch

 from torch import nn
 from torch import optim
 from torch.utils.data import DataLoader
-from data import MnistDataset
 from tqdm import tqdm
-from utils import Stage
+from pipeline.utils import Stage
 from omegaconf import DictConfig

@@ -51,29 +50,3 @@ class Runner:

         pred_y = self.model(true_x)
         loss = self.criterion(pred_y, true_y)
         return loss
-
-
-def main():
-    model = nn.Conv2d(1, 64, 3)
-    criterion = torch.nn.CrossEntropyLoss()
-    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
-    path = "mnist_train.csv"
-    dataset = MnistDataset(path)
-    batch_size = 16
-    num_workers = 1
-    loader = torch.utils.data.DataLoader(
-        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
-    )
-    batch = Batch(
-        Stage.TRAIN,
-        device=torch.device("cpu"),
-        model=model,
-        criterion=criterion,
-        optimizer=optimizer,
-        loader=loader,
-    )
-    batch.run("test")
-
-
-if __name__ == "__main__":
-    main()
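With runner.py's ad-hoc `main()` deleted, object construction moves to the hydra entry point in `src/train.py`, matching the README's architecture note. A hedged sketch of that wiring, with keyword arguments mirroring the deleted `Batch(...)` call above (Runner's actual signature is not shown in this diff and may differ):

```python
# sketch of the construction now done in src/train.py; Runner's real
# signature may differ from the deleted Batch(...) call mirrored here.
import torch
from pipeline.runner import Runner
from pipeline.utils import Stage
from data.dataset import MnistDataset
from data.collate import channel_to_batch

model = torch.nn.Conv2d(1, 64, 3)
loader = torch.utils.data.DataLoader(
    MnistDataset("data/mnist_train.csv"),
    batch_size=16,
    collate_fn=channel_to_batch,
)
runner = Runner(
    Stage.TRAIN,
    device=torch.device("cpu"),
    model=model,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=2e-4),
    loader=loader,
)
```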
src/train.py — filename inferred; this is the entry point the Makefile and launch.sh now invoke

@@ -13,14 +13,14 @@ coordinates:

 - runner

 """
-from runner import Runner
+from pipeline.runner import Runner
 from model.linear import DNN
 from model.cnn import VGG16, VGG11
-from data import MnistDataset
+from data.dataset import MnistDataset
-from utils import Stage
+from pipeline.utils import Stage
 import torch
 from pathlib import Path
-from collate import channel_to_batch
+from data.collate import channel_to_batch
 import hydra
 from omegaconf import DictConfig