init
This commit is contained in:
@@ -0,0 +1,27 @@
|
||||
from pathlib import Path
|
||||
import requests
|
||||
import logging
|
||||
from ml_pipeline import config
|
||||
|
||||
|
||||
# module-level logger named after this module, per stdlib logging convention
logger = logging.getLogger(__name__)
def download(data_path: Path, force=False):
    """Download the MNIST train/test CSV files into *data_path*.

    Args:
        data_path: directory the csv files are written into; the filename is
            taken from the last path segment of each url.
        force: when True, re-download and overwrite an existing file.
    """
    urls = {
        'train': 'https://pjreddie.com/media/files/mnist_train.csv',
        'test': 'https://pjreddie.com/media/files/mnist_test.csv',
    }
    for dataset, url in urls.items():
        filename = data_path / url.split('/')[-1]
        if filename.exists() and not force:
            # original message logged no filename ("(unknown)" placeholder);
            # lazy %-args avoid formatting when INFO is disabled
            logger.info('%s exists (set force to overwrite)', filename)
            continue
        logger.info('downloading %s %s', dataset, url)
        # timeout so a stalled server cannot hang the pipeline forever
        response = requests.get(url, timeout=60)
        if response.status_code == 200:
            filename.write_bytes(response.content)
            logger.info('file downloaded %s', filename)
        else:
            # a failed download is an error, not routine info
            logger.error('failed to download %s (status %s)', url, response.status_code)
@@ -0,0 +1,66 @@
|
||||
from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
import einops
|
||||
import csv
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
from ml_pipeline import config, logger
|
||||
|
||||
|
||||
class MnistDataset(Dataset):
    """
    The MNIST database of handwritten digits.

    Training set is 60k labeled examples, test is 10k examples.
    The b/w images normalized to 20x20, preserving aspect ratio.

    It's the defacto standard image training set to learn about classification in DL
    """

    def __init__(self, path: Path):
        """
        give a path to one of the csv files from:
        https://pjreddie.com/projects/mnist-in-csv/
        (e.g. mnist_train.csv — `_load` opens this path directly as a file,
        not as a directory)

        Raises:
            ValueError: if *path* is empty/None.
            FileNotFoundError: if *path* does not exist.
        """
        # raise instead of assert: asserts are stripped under `python -O`
        if not path:
            raise ValueError("dataset path required")
        self.path = Path(path)
        if not self.path.exists():
            raise FileNotFoundError(f"could not find dataset path: {path}")
        self.features, self.labels = self._load()

    def __getitem__(self, idx):
        return (self.features[idx], self.labels[idx])

    def __len__(self):
        return len(self.features)

    def _load(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Parse the csv into (images, labels) tensors.

        Returns images as float32 of shape (n, 1, 224, 224): the 28x28 source
        pixels get a channel dim and are upscaled 8x by pixel repetition;
        labels are int64.
        """
        images, labels = [], []
        # cap on rows read; assumes a row-count int — confirm semantics of
        # config.training.examples when "all rows" is wanted
        examples = config.training.examples
        with open(self.path, mode="r") as file:
            reader = csv.reader(file)
            for line, content in enumerate(reader):
                if line == examples:
                    break
                # column 0 is the digit label, the remaining 784 are pixels
                labels.append(int(content[0]))
                images.append([int(x) for x in content[1:]])
        labels = torch.tensor(labels, dtype=torch.int64)
        images = torch.tensor(images, dtype=torch.float32)
        images = einops.rearrange(images, "n (w h) -> n w h", w=28, h=28)
        # add channel dim and upscale 28 -> 224 by repeating each pixel 8x8
        images = einops.repeat(
            images, "n w h -> n c (w r_w) (h r_h)", c=1, r_w=8, r_h=8
        )
        return (images, labels)
|
||||
|
||||
def debug():
    """Smoke-test MnistDataset: load the training csv and log basic stats."""
    path = Path(config.paths.data) / "mnist_train.csv"
    dataset = MnistDataset(path=path)
    logger.info(f"len: {len(dataset)}")
    logger.info(f"first shape: {dataset[0][0].shape}")
    # features are rank-4 (n, c, w, h) after the channel/upscale step in
    # MnistDataset._load; the old pattern "n w h -> w h" named only three
    # axes and raised an einops rank-mismatch error
    mean = einops.reduce(dataset[:10][0], "n c w h -> w h", "mean")
    logger.info(f"mean shape: {mean.shape}")
    logger.info(f"mean image: {mean}")
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env python3
import csv
from sys import stdout

# 'pip install pyspark' for these
from pyspark import SparkFiles
from pyspark.sql import SparkSession

# building a session spins up a local hadoop cluster under the hood (!)
spark = SparkSession.builder.getOrCreate()

# ship the source csv into the cluster's filesystem first
spark.sparkContext.addFile("https://csvbase.com/meripaterson/stock-exchanges.csv")

# pandas-flavoured dataframe pipeline: read, project one column, drop
# nulls, order the result
local_path = SparkFiles.get('stock-exchanges.csv')
frame = spark.read.csv(f"file://{local_path}", header=True)
mics = frame.select("MIC").na.drop().sort("MIC")

# spark offers no simple csv-to-stdout path, so hand off to stdlib csv
csv.writer(stdout).writerows(mics.collect())
Reference in New Issue
Block a user