ml_pipeline_cookiecutter/{{cookiecutter.project_name}}/{{cookiecutter.module_name}}/data/dataset.py

67 lines
2.2 KiB
Python
Raw Normal View History

2024-04-06 13:02:31 -07:00
from torch.utils.data import Dataset
import numpy as np
import einops
import csv
import torch
from pathlib import Path
from typing import Tuple
from {{cookiecutter.module_name}} import config, logger
2024-04-06 13:02:31 -07:00
class MnistDataset(Dataset):
    """
    The MNIST database of handwritten digits.

    Training set is 60k labeled examples, test set is 10k examples.
    The original b/w digits were size-normalized to fit a 20x20 box (preserving
    aspect ratio) and centered in a 28x28 image.
    It's the de facto standard image set for learning about classification in DL.

    Each item is a ``(feature, label)`` pair: ``feature`` is a float32 tensor of
    shape ``(1, 224, 224)`` (the 28x28 image upscaled 8x by pixel repetition,
    raw 0-255 intensities, not normalized) and ``label`` is an int64 scalar.
    """

    # Side length of one MNIST image as stored in a CSV row (28 * 28 = 784 pixels).
    IMAGE_SIDE = 28
    # Every pixel is repeated this many times along both spatial axes (28 * 8 = 224).
    UPSCALE = 8

    def __init__(self, path: Path):
        """
        Load MNIST examples from a single CSV file.

        The file uses the format from https://pjreddie.com/projects/mnist-in-csv/
        (e.g. ``mnist_train.csv``): each row is the label followed by the 784
        pixel values of the image.

        Args:
            path: path to the CSV *file* itself (a str is accepted too).

        Raises:
            ValueError: if ``path`` is empty or None.
            FileNotFoundError: if ``path`` does not exist.
        """
        # Explicit raises instead of asserts: asserts vanish under ``python -O``.
        if not path:
            raise ValueError("dataset path required")
        self.path = Path(path)
        if not self.path.exists():
            raise FileNotFoundError(f"could not find dataset path: {path}")
        self.features, self.labels = self._load()

    def __getitem__(self, idx):
        """Return ``(features[idx], labels[idx])``; slices return batched tensors."""
        return (self.features[idx], self.labels[idx])

    def __len__(self):
        """Number of loaded examples."""
        return len(self.features)

    def _load(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Parse the CSV at ``self.path`` into ``(images, labels)`` tensors.

        At most ``config.training.examples`` rows are read; if that value never
        equals the running count (e.g. it is None), the whole file is read.

        Returns:
            images: float32 tensor of shape ``(n, 1, 224, 224)``.
            labels: int64 tensor of shape ``(n,)``.
        """
        limit = config.training.examples
        labels, pixels = [], []
        # newline="" is what the csv module expects for files handed to csv.reader.
        with open(self.path, mode="r", newline="", encoding="utf-8") as file:
            for row in csv.reader(file):
                if len(labels) == limit:
                    break
                if not row:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                labels.append(int(row[0]))
                pixels.append([int(value) for value in row[1:]])

        side = self.IMAGE_SIDE
        label_tensor = torch.tensor(labels, dtype=torch.int64)
        # reshape keeps an empty file as a (0, 784) tensor instead of a 1-D one,
        # so the rearrange below still sees two axes.
        images = torch.tensor(pixels, dtype=torch.float32).reshape(-1, side * side)
        # CSV pixels are row-major: the outer index is the row (height).
        images = einops.rearrange(images, "n (h w) -> n h w", h=side, w=side)
        # Add a singleton channel axis and upscale by nearest-neighbour repetition.
        images = einops.repeat(
            images,
            "n h w -> n c (h r_h) (w r_w)",
            c=1,
            r_h=self.UPSCALE,
            r_w=self.UPSCALE,
        )
        return (images, label_tensor)
def debug():
    """
    Smoke-test the dataset by loading the training CSV and logging a few stats.

    Reads ``mnist_train.csv`` from ``config.paths.data``, then logs the number of
    examples, the shape of the first feature tensor, and the per-pixel mean of
    (up to) the first 10 images.
    """
    path = Path(config.paths.data) / "mnist_train.csv"
    dataset = MnistDataset(path=path)
    logger.info(f"len: {len(dataset)}")
    logger.info(f"first shape: {dataset[0][0].shape}")
    # Features are (n, c, h, w); the old 3-axis pattern "n w h" raised a shape
    # error on this 4-D tensor. Average over the example axis only.
    mean = einops.reduce(dataset[:10][0], "n c h w -> c h w", "mean")
    logger.info(f"mean shape: {mean.shape}")
    logger.info(f"mean image: {mean}")