remove conda. add requirements.txt. add tests

This commit is contained in:
publicmatt 2024-03-14 13:47:37 -07:00
parent 355e83843f
commit 44250fc618
9 changed files with 324 additions and 12 deletions

165
.gitignore vendored
View File

@ -1,4 +1,163 @@
storage/
data/
outputs
# Byte-compiled / optimized / DLL files
__pycache__/
*.swp
*.tmp
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

View File

@ -1,15 +1,30 @@
CONDA_ENV=ml_pipeline
PYTHON=.venv/bin/python3
.PHONY: help test
all: run
run:
python src/pipeline.py train
init:
python3.9 -m virtualenv .venv
data:
python src/data.py
run: ## run the pipeline (train)
$(PYTHON) src/train.py \
debug=false
batch:
python src/batch.py
debug: ## run the pipeline (train) with debugging enabled
$(PYTHON) src/train.py \
debug=true
data: ## download the mnist data
wget https://pjreddie.com/media/files/mnist_train.csv -O data/mnist_train.csv
wget https://pjreddie.com/media/files/mnist_test.csv -O data/mnist_test.csv
test:
find . -iname "*.py" | entr -c pytest
install:
$(PYTHON) -m pip install -r requirements.txt
help: ## display this help message
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
install:
conda env updates -n ${CONDA_ENV} --file environment.yml

12
requirements.txt Normal file
View File

@ -0,0 +1,12 @@
black
click
einops
hydra-core
matplotlib
numpy
wandb
pytest
python-dotenv
torch
requests
tqdm

66
src/data/dataset.py Normal file
View File

@ -0,0 +1,66 @@
from torch.utils.data import Dataset
import numpy as np
import einops
import csv
import torch
from pathlib import Path
from typing import Tuple
import os
class MnistDataset(Dataset):
"""
The MNIST database of handwritten digits.
Training set is 60k labeled examples, test is 10k examples.
The b/w images normalized to 20x20, preserving aspect ratio.
It's the defacto standard image training set to learn about classification in DL
"""
def __init__(self, path: Path):
"""
give a path to a dir that contains the following csv files:
https://pjreddie.com/projects/mnist-in-csv/
"""
self.path = path
self.features, self.labels = self._load()
def __getitem__(self, idx):
return (self.features[idx], self.labels[idx])
def __len__(self):
return len(self.features)
def _load(self) -> Tuple[torch.Tensor, torch.Tensor]:
# opening the CSV file
with open(self.path, mode="r") as file:
images, labels = [], []
csvFile = csv.reader(file)
examples = int(os.getenv("TRAINING_EXAMPLES", 1000))
for line, content in enumerate(csvFile):
if line == examples:
break
labels.append(int(content[0]))
image = [int(x) for x in content[1:]]
images.append(image)
labels = torch.tensor(labels, dtype=torch.long)
images = torch.tensor(images, dtype=torch.float32)
images = einops.rearrange(images, "n (w h) -> n w h", w=28, h=28)
images = einops.repeat(
images, "n w h -> n c (w r_w) (h r_h)", c=1, r_w=8, r_h=8
)
return (images, labels)
def main():
path = Path("storage/mnist_train.csv")
dataset = MnistDataset(path=path)
print(f"len: {len(dataset)}")
print(f"first shape: {dataset[0][0].shape}")
mean = einops.reduce(dataset[:10][0], "n w h -> w h", "mean")
print(f"mean shape: {mean.shape}")
print(f"mean image: {mean}")
if __name__ == "__main__":
main()

21
src/data/spark.py Normal file
View File

@ -0,0 +1,21 @@
#!/usr/bin/env python3
from sys import stdout
import csv
# 'pip install pyspark' for these
from pyspark import SparkFiles
from pyspark.sql import SparkSession
# make a spark "session". this creates a local hadoop cluster by default (!)
spark = SparkSession.builder.getOrCreate()
# put the input file in the cluster's filesystem:
spark.sparkContext.addFile("https://csvbase.com/meripaterson/stock-exchanges.csv")
# the following is much like for pandas
df = (
spark.read.csv(f"file://{SparkFiles.get('stock-exchanges.csv')}", header=True)
.select("MIC")
.na.drop()
.sort("MIC")
)
# pyspark has no easy way to write csv to stdout - use python's csv lib
csv.writer(stdout).writerows(df.collect())

View File

@ -37,10 +37,10 @@ class VGG11(nn.Module):
self.linear_layers = nn.Sequential(
nn.Linear(in_features=512 * 7 * 7, out_features=4096),
nn.ReLU(),
nn.Dropout2d(0.5),
nn.Dropout(0.5),
nn.Linear(in_features=4096, out_features=4096),
nn.ReLU(),
nn.Dropout2d(0.5),
nn.Dropout(0.5),
nn.Linear(in_features=4096, out_features=self.num_classes),
)

2
test/.env.test Normal file
View File

@ -0,0 +1,2 @@
TRAIN_PATH=${HOME}/Dev/ml/data/mnist_train.csv
INPUT_FEATURES=40

20
test/conftest.py Normal file
View File

@ -0,0 +1,20 @@
# conftest.py
import pytest
import os
from dotenv import load_dotenv
from pathlib import Path
@pytest.fixture(autouse=True)
def load_env():
# Set up your environment variables here
env = Path(__file__).parent / ".env.test"
if not load_dotenv(env):
raise RuntimeError(".env not loaded")
# os.environ['MY_ENV_VAR'] = 'some_value'
# You can add more setup code here if needed
yield
# Optional: Cleanup code after test (if needed)
# e.g., unset environment variables if they should not persist after test

17
test/test_pipeline.py Normal file
View File

@ -0,0 +1,17 @@
from src.model.linear import DNN
from src.data.dataset import MnistDataset
import os
def test_size_of_dataset():
examples = 500
os.environ["TRAINING_EXAMPLES"] = str(examples)
channels = 1
width, height = 224, 224
dataset = MnistDataset(os.getenv("TRAIN_PATH"))
# label = dataset[0][1].item()
image = dataset[0][0].shape
assert channels == image[0]
assert width == image[1]
assert height == image[2]
assert len(dataset) == examples