From 44250fc618b256e3bce18c5668a544c87c9094bb Mon Sep 17 00:00:00 2001
From: publicmatt
Date: Thu, 14 Mar 2024 13:47:37 -0700
Subject: [PATCH] remove conda. add requirements.txt. add tests

---
 .gitignore            | 165 +++++++++++++++++++++++++++++++++++++++++-
 Makefile              |  29 ++++++--
 requirements.txt      |  12 +++
 src/data/dataset.py   |  66 +++++++++++++++++
 src/data/spark.py     |  21 ++++++
 src/model/cnn.py      |   4 +-
 test/.env.test        |   2 +
 test/conftest.py      |  20 +++++
 test/test_pipeline.py |  17 +++++
 9 files changed, 324 insertions(+), 12 deletions(-)
 create mode 100644 requirements.txt
 create mode 100644 src/data/dataset.py
 create mode 100644 src/data/spark.py
 create mode 100644 test/.env.test
 create mode 100644 test/conftest.py
 create mode 100644 test/test_pipeline.py

diff --git a/.gitignore b/.gitignore
index cffc7d1..4550c23 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,163 @@
-storage/
+data/
+
+outputs
+
+# Byte-compiled / optimized / DLL files
 __pycache__/
-*.swp
-*.tmp
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+#   in version control.
+#   https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+#  and can be added to the global gitignore or merged into this file.  For a more nuclear
+#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
diff --git a/Makefile b/Makefile
index 334c298..c706e0e 100644
--- a/Makefile
+++ b/Makefile
@@ -1,15 +1,30 @@
-CONDA_ENV=ml_pipeline
+PYTHON=.venv/bin/python3
+.PHONY: help test
 
 all: run
 
-run:
-	python src/pipeline.py train
+init:
+	python3.9 -m virtualenv .venv
 
-data:
-	python src/data.py
+run: ## run the pipeline (train)
+	$(PYTHON) src/train.py \
+		debug=false
 
-batch:
-	python src/batch.py
+debug: ## run the pipeline (train) with debugging enabled
+	$(PYTHON) src/train.py \
+		debug=true
+
+data: ## download the mnist data
+	wget https://pjreddie.com/media/files/mnist_train.csv -O data/mnist_train.csv
+	wget https://pjreddie.com/media/files/mnist_test.csv -O data/mnist_test.csv
+
+test:
+	find . -iname "*.py" | entr -c pytest
+
+install:
+	$(PYTHON) -m pip install -r requirements.txt
+
+help: ## display this help message
+	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
-install:
-	conda env updates -n ${CONDA_ENV} --file environment.yml
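The `debug=false` / `debug=true` arguments that the run and debug targets pass to src/train.py are Hydra-style config overrides (hydra-core is added in requirements.txt below). A minimal sketch of an entry point that would accept them; the config path and the `debug` key are assumptions for illustration, not the repo's actual layout:

    # hypothetical src/train.py entry point; config_path/config_name are assumed
    import hydra
    from omegaconf import DictConfig


    @hydra.main(config_path="../conf", config_name="config", version_base=None)
    def main(cfg: DictConfig) -> None:
        # `debug=true` on the command line overrides the value from config.yaml
        if cfg.debug:
            print("debug mode enabled")


    if __name__ == "__main__":
        main()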
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..c94b179
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,12 @@
+black
+click
+einops
+hydra-core
+matplotlib
+numpy
+wandb
+pytest
+python-dotenv
+torch
+requests
+tqdm
diff --git a/src/data/dataset.py b/src/data/dataset.py
new file mode 100644
index 0000000..6be33b0
--- /dev/null
+++ b/src/data/dataset.py
@@ -0,0 +1,66 @@
+from torch.utils.data import Dataset
+import numpy as np
+import einops
+import csv
+import torch
+from pathlib import Path
+from typing import Tuple
+import os
+
+
+class MnistDataset(Dataset):
+    """
+    The MNIST database of handwritten digits.
+    The training set has 60k labeled examples; the test set has 10k.
+    The b/w digits are size-normalized to 20x20 (preserving aspect ratio) and centered in 28x28 images.
+
+    It's the de facto standard image dataset for learning about classification in deep learning.
+    """
+
+    def __init__(self, path: Path):
+        """
+        give a path to one of the csv files from:
+        https://pjreddie.com/projects/mnist-in-csv/
+        """
+        self.path = path
+        self.features, self.labels = self._load()
+
+    def __getitem__(self, idx):
+        return (self.features[idx], self.labels[idx])
+
+    def __len__(self):
+        return len(self.features)
+
+    def _load(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        # each csv row is: label, pixel_0, ..., pixel_783
+        with open(self.path, mode="r") as file:
+            images, labels = [], []
+            csvFile = csv.reader(file)
+            examples = int(os.getenv("TRAINING_EXAMPLES", 1000))
+            for line, content in enumerate(csvFile):
+                if line == examples:
+                    break
+                labels.append(int(content[0]))
+                image = [int(x) for x in content[1:]]
+                images.append(image)
+        labels = torch.tensor(labels, dtype=torch.long)
+        images = torch.tensor(images, dtype=torch.float32)
+        images = einops.rearrange(images, "n (w h) -> n w h", w=28, h=28)
+        images = einops.repeat(
+            images, "n w h -> n c (w r_w) (h r_h)", c=1, r_w=8, r_h=8
+        )
+        return (images, labels)
+
+
+def main():
+    path = Path("data/mnist_train.csv")
+    dataset = MnistDataset(path=path)
+    print(f"len: {len(dataset)}")
+    print(f"first shape: {dataset[0][0].shape}")
+    mean = einops.reduce(dataset[:10][0], "n c w h -> c w h", "mean")
+    print(f"mean shape: {mean.shape}")
+    print(f"mean image: {mean}")
+
+
+if __name__ == "__main__":
+    main()
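The rearrange/repeat pair in `_load` maps each flat 784-pixel row to a 1x224x224 tensor: reshape to 28x28, then a nearest-neighbor 8x upscale by repeating every pixel 8 times along each axis. A standalone shape check, illustrative only:

    # sketch of the einops transforms used in MnistDataset._load
    import einops
    import torch

    x = torch.rand(10, 784)                                  # n (w h)
    x = einops.rearrange(x, "n (w h) -> n w h", w=28, h=28)  # (10, 28, 28)
    x = einops.repeat(x, "n w h -> n c (w r_w) (h r_h)", c=1, r_w=8, r_h=8)
    print(x.shape)  # torch.Size([10, 1, 224, 224])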
diff --git a/src/data/spark.py b/src/data/spark.py
new file mode 100644
index 0000000..7e082cc
--- /dev/null
+++ b/src/data/spark.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python3
+from sys import stdout
+import csv
+
+# 'pip install pyspark' for these
+from pyspark import SparkFiles
+from pyspark.sql import SparkSession
+
+# make a spark "session". this starts a local Spark cluster by default (!)
+spark = SparkSession.builder.getOrCreate()
+# put the input file in the cluster's filesystem:
+spark.sparkContext.addFile("https://csvbase.com/meripaterson/stock-exchanges.csv")
+# the following reads much like pandas
+df = (
+    spark.read.csv(f"file://{SparkFiles.get('stock-exchanges.csv')}", header=True)
+    .select("MIC")
+    .na.drop()
+    .sort("MIC")
+)
+# pyspark has no easy way to write csv to stdout - use python's csv lib
+csv.writer(stdout).writerows(df.collect())
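For comparison, roughly the same pipeline in pandas (pandas is not in requirements.txt, so this is illustrative only):

    # rough pandas equivalent of src/data/spark.py
    import pandas as pd

    df = pd.read_csv("https://csvbase.com/meripaterson/stock-exchanges.csv")
    print(df["MIC"].dropna().sort_values().to_csv(index=False, header=False))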
diff --git a/src/model/cnn.py b/src/model/cnn.py
index 51f983f..04e0d12 100644
--- a/src/model/cnn.py
+++ b/src/model/cnn.py
@@ -37,10 +37,10 @@ class VGG11(nn.Module):
         self.linear_layers = nn.Sequential(
             nn.Linear(in_features=512 * 7 * 7, out_features=4096),
             nn.ReLU(),
-            nn.Dropout2d(0.5),
+            nn.Dropout(0.5),
             nn.Linear(in_features=4096, out_features=4096),
             nn.ReLU(),
-            nn.Dropout2d(0.5),
+            nn.Dropout(0.5),
             nn.Linear(in_features=4096, out_features=self.num_classes),
         )
diff --git a/test/.env.test b/test/.env.test
new file mode 100644
index 0000000..bc63937
--- /dev/null
+++ b/test/.env.test
@@ -0,0 +1,2 @@
+TRAIN_PATH=${HOME}/Dev/ml/data/mnist_train.csv
+INPUT_FEATURES=40
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 0000000..fc788d0
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,20 @@
+# conftest.py
+import pytest
+import os
+from dotenv import load_dotenv
+from pathlib import Path
+
+
+@pytest.fixture(autouse=True)
+def load_env():
+    # load test-only environment variables (TRAIN_PATH etc.) before each test
+    env = Path(__file__).parent / ".env.test"
+    if not load_dotenv(env):
+        raise RuntimeError(".env.test not loaded")
+    # os.environ['MY_ENV_VAR'] = 'some_value'
+    # more setup code can go here if needed
+
+    yield
+
+    # optional: cleanup after each test, e.g. unset environment
+    # variables that should not persist between tests
diff --git a/test/test_pipeline.py b/test/test_pipeline.py
new file mode 100644
index 0000000..9ffb306
--- /dev/null
+++ b/test/test_pipeline.py
@@ -0,0 +1,17 @@
+from src.model.linear import DNN
+from src.data.dataset import MnistDataset
+import os
+
+
+def test_size_of_dataset():
+    examples = 500
+    os.environ["TRAINING_EXAMPLES"] = str(examples)
+    channels = 1
+    width, height = 224, 224
+    dataset = MnistDataset(os.getenv("TRAIN_PATH"))
+    # label = dataset[0][1].item()
+    image = dataset[0][0].shape
+    assert channels == image[0]
+    assert width == image[1]
+    assert height == image[2]
+    assert len(dataset) == examples
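test_pipeline.py imports from `src.*`, so the suite assumes pytest is launched from the repository root (e.g. `python -m pytest`, which prepends the working directory to sys.path). A sketch of how the dataset would typically be consumed downstream, assuming `make data` has populated data/mnist_train.csv:

    # illustrative only: feeding MnistDataset to a DataLoader
    from torch.utils.data import DataLoader

    from src.data.dataset import MnistDataset

    loader = DataLoader(MnistDataset("data/mnist_train.csv"), batch_size=32, shuffle=True)
    images, labels = next(iter(loader))
    print(images.shape, labels.shape)  # torch.Size([32, 1, 224, 224]) torch.Size([32])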