remove conda. add requirements.txt. add tests
parent 355e83843f · commit 44250fc618

.gitignore
@@ -1,4 +1,163 @@
storage/
data/

outputs
# Byte-compiled / optimized / DLL files
__pycache__/
*.swp
*.tmp
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
Makefile
@@ -1,15 +1,30 @@
-CONDA_ENV=ml_pipeline
+PYTHON=.venv/bin/python3
+.PHONY: help test

-all: run

-run:
-	python src/pipeline.py train
+init:
+	python3.9 -m virtualenv .venv

-data:
-	python src/data.py
+run: ## run the pipeline (train)
+	$(PYTHON) src/train.py \
+		debug=false

-batch:
-	python src/batch.py
+debug: ## run the pipeline (train) with debugging enabled
+	$(PYTHON) src/train.py \
+		debug=true

+data: ## download the mnist data
+	wget https://pjreddie.com/media/files/mnist_train.csv -O data/mnist_train.csv
+	wget https://pjreddie.com/media/files/mnist_test.csv -O data/mnist_test.csv
+test:
+	find . -iname "*.py" | entr -c pytest

+install:
+	$(PYTHON) -m pip install -r requirements.txt

+help: ## display this help message
+	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

-install:
-	conda env updates -n ${CONDA_ENV} --file environment.yml
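A note on the new run/debug targets: `debug=false` / `debug=true` look like Hydra-style command-line overrides (hydra-core appears in the new requirements.txt) rather than hand-parsed flags. A minimal sketch of an entry point consistent with that usage follows; the config path, config name, and the `debug` key are assumptions, since src/train.py itself is not shown in this diff.

# sketch only: a Hydra entry point matching "make run" / "make debug".
# config_path, config_name and the `debug` key are hypothetical here.
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="train", version_base=None)
def main(cfg: DictConfig) -> None:
    # "debug=true" on the command line overrides cfg.debug,
    # provided the config file declares a `debug` key
    if cfg.debug:
        print("debug run: fewer examples, more logging")
    ...


if __name__ == "__main__":
    main()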
requirements.txt
@@ -0,0 +1,12 @@
black
click
einops
hydra-core
matplotlib
numpy
wandb
pytest
python-dotenv
torch
requests
tqdm
src/data/dataset.py
@@ -0,0 +1,66 @@
from torch.utils.data import Dataset
import numpy as np
import einops
import csv
import torch
from pathlib import Path
from typing import Tuple
import os


class MnistDataset(Dataset):
    """
    The MNIST database of handwritten digits.
    The training set is 60k labeled examples; the test set is 10k examples.
    The b/w digits are size-normalized to 20x20, preserving aspect ratio,
    and centered in a 28x28 field.

    It's the de facto standard image dataset for learning about classification in DL.
    """

    def __init__(self, path: Path):
        """
        `path` points at a single CSV file in the format described at
        https://pjreddie.com/projects/mnist-in-csv/
        (label in the first column, 784 pixel values in the rest).
        """
        self.path = path
        self.features, self.labels = self._load()

    def __getitem__(self, idx):
        return (self.features[idx], self.labels[idx])

    def __len__(self):
        return len(self.features)

    def _load(self) -> Tuple[torch.Tensor, torch.Tensor]:
        # read at most TRAINING_EXAMPLES rows from the CSV
        with open(self.path, mode="r") as file:
            images, labels = [], []
            csv_file = csv.reader(file)
            examples = int(os.getenv("TRAINING_EXAMPLES", 1000))
            for line, content in enumerate(csv_file):
                if line == examples:
                    break
                labels.append(int(content[0]))
                image = [int(x) for x in content[1:]]
                images.append(image)
            labels = torch.tensor(labels, dtype=torch.long)
            images = torch.tensor(images, dtype=torch.float32)
            # 784 flat pixels -> 28x28, then upsample 8x to 1x224x224 by repeating pixels
            images = einops.rearrange(images, "n (w h) -> n w h", w=28, h=28)
            images = einops.repeat(
                images, "n w h -> n c (w r_w) (h r_h)", c=1, r_w=8, r_h=8
            )
            return (images, labels)


def main():
    path = Path("storage/mnist_train.csv")
    dataset = MnistDataset(path=path)
    print(f"len: {len(dataset)}")
    print(f"first shape: {dataset[0][0].shape}")
    # images come out of _load as (n, c, w, h), so reduce over batch and channel
    mean = einops.reduce(dataset[:10][0], "n c w h -> w h", "mean")
    print(f"mean shape: {mean.shape}")
    print(f"mean image: {mean}")


if __name__ == "__main__":
    main()
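For context, the dataset plugs straight into a standard PyTorch DataLoader. A minimal sketch of that usage follows; the batch size is illustrative, and the storage/ path is the one used in main() above.

# sketch: iterating MnistDataset with a DataLoader (batch size is illustrative)
from pathlib import Path

from torch.utils.data import DataLoader

from src.data.dataset import MnistDataset

dataset = MnistDataset(path=Path("storage/mnist_train.csv"))
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for images, labels in loader:
    # images: (32, 1, 224, 224) float32, labels: (32,) int64
    print(images.shape, labels.shape)
    break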
@@ -0,0 +1,21 @@
#!/usr/bin/env python3
from sys import stdout
import csv

# 'pip install pyspark' for these
from pyspark import SparkFiles
from pyspark.sql import SparkSession

# make a spark "session". this creates a local hadoop cluster by default (!)
spark = SparkSession.builder.getOrCreate()
# put the input file in the cluster's filesystem:
spark.sparkContext.addFile("https://csvbase.com/meripaterson/stock-exchanges.csv")
# the following is much like for pandas
df = (
    spark.read.csv(f"file://{SparkFiles.get('stock-exchanges.csv')}", header=True)
    .select("MIC")
    .na.drop()
    .sort("MIC")
)
# pyspark has no easy way to write csv to stdout - use python's csv lib
csv.writer(stdout).writerows(df.collect())
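The comment above likens the pipeline to pandas; for comparison, here is a rough pandas equivalent of the same select/drop-nulls/sort. This is a sketch only, assuming the same csvbase URL stays reachable, and pandas is not in this commit's requirements.txt.

# sketch: the same MIC-column pipeline with pandas instead of pyspark
import sys

import pandas as pd

df = pd.read_csv("https://csvbase.com/meripaterson/stock-exchanges.csv")
out = df["MIC"].dropna().sort_values()
out.to_csv(sys.stdout, index=False, header=False)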
@@ -37,10 +37,10 @@ class VGG11(nn.Module):
        self.linear_layers = nn.Sequential(
            nn.Linear(in_features=512 * 7 * 7, out_features=4096),
            nn.ReLU(),
-           nn.Dropout2d(0.5),
+           nn.Dropout(0.5),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
-           nn.Dropout2d(0.5),
+           nn.Dropout(0.5),
            nn.Linear(in_features=4096, out_features=self.num_classes),
        )
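The swap from nn.Dropout2d to nn.Dropout in the classifier is the right fix: Dropout2d zeroes entire channels of an (N, C, H, W) feature map, while plain Dropout zeroes individual activations, which is what you want once the features are flattened to (N, 4096). A small sketch illustrating the difference follows; the shapes are illustrative, not taken from this repo.

# sketch: channel-wise vs element-wise dropout (shapes are illustrative)
import torch
from torch import nn

conv_features = torch.ones(1, 4, 3, 3)  # (N, C, H, W)
flat_features = torch.ones(2, 8)        # (N, features)

torch.manual_seed(0)
# Dropout2d drops whole channels: each 3x3 map is either all zeros or all scaled
print(nn.Dropout2d(p=0.5).train()(conv_features))
# Dropout drops individual elements of the flattened features
print(nn.Dropout(p=0.5).train()(flat_features))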
@@ -0,0 +1,2 @@
TRAIN_PATH=${HOME}/Dev/ml/data/mnist_train.csv
INPUT_FEATURES=40
conftest.py
@@ -0,0 +1,20 @@
# conftest.py
import pytest
import os
from dotenv import load_dotenv
from pathlib import Path


@pytest.fixture(autouse=True)
def load_env():
    # Load test environment variables from the .env.test file next to this conftest
    env = Path(__file__).parent / ".env.test"
    if not load_dotenv(env):
        raise RuntimeError(f"{env} not loaded")
    # os.environ['MY_ENV_VAR'] = 'some_value'
    # You can add more setup code here if needed

    yield

    # Optional: cleanup code after the test (if needed),
    # e.g. unset environment variables that should not persist between tests
@@ -0,0 +1,17 @@
from src.model.linear import DNN
from src.data.dataset import MnistDataset
import os


def test_size_of_dataset():
    examples = 500
    os.environ["TRAINING_EXAMPLES"] = str(examples)
    channels = 1
    # MnistDataset._load upsamples the 28x28 digits 8x, hence 224x224
    width, height = 224, 224
    dataset = MnistDataset(os.getenv("TRAIN_PATH"))
    # label = dataset[0][1].item()
    image = dataset[0][0].shape
    assert channels == image[0]
    assert width == image[1]
    assert height == image[2]
    assert len(dataset) == examples
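One caveat with the test above: setting os.environ directly leaks TRAINING_EXAMPLES into every test that runs afterwards. A hedged variant using pytest's built-in monkeypatch fixture, which restores the variable when the test finishes, could look like this (same checks, different setup mechanism, not part of this commit):

# sketch: same check, but the env change is undone after the test
import os

from src.data.dataset import MnistDataset


def test_size_of_dataset_monkeypatched(monkeypatch):
    examples = 500
    monkeypatch.setenv("TRAINING_EXAMPLES", str(examples))
    dataset = MnistDataset(os.getenv("TRAIN_PATH"))
    assert dataset[0][0].shape == (1, 224, 224)
    assert len(dataset) == examples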