67 lines
2.2 KiB
Python
67 lines
2.2 KiB
Python
from torch.utils.data import Dataset
|
|
import numpy as np
|
|
import einops
|
|
import csv
|
|
import torch
|
|
from pathlib import Path
|
|
from typing import Tuple
|
|
from {{cookiecutter.module_name}} import config, logger
|
|
|
|
|
|
class MnistDataset(Dataset):
    """
    The MNIST database of handwritten digits, loaded from a CSV file.

    Training set is 60k labeled examples, test is 10k examples.
    The b/w images are normalized to 20x20, preserving aspect ratio; each
    CSV row stores one image as IMAGE_SIZE x IMAGE_SIZE pixel values.

    It's the de facto standard image training set to learn about
    classification in DL.

    The whole file is parsed eagerly in the constructor:
      * ``features``: float32 tensor shaped
        ``(n, 1, IMAGE_SIZE * UPSCALE, IMAGE_SIZE * UPSCALE)``
      * ``labels``: int64 tensor shaped ``(n,)``
    """

    # Side length, in pixels, of one image as stored in the CSV.
    IMAGE_SIZE = 28
    # Integer factor by which every pixel is repeated along both image axes.
    # The in-memory tensor therefore holds UPSCALE ** 2 times as many values
    # as the CSV does.
    UPSCALE = 8

    def __init__(self, path: Path):
        """
        Give the path to one of the csv files (not their directory) from:
        https://pjreddie.com/projects/mnist-in-csv/

        Raises AssertionError when the path is empty or does not exist.
        """
        assert path, "dataset path required"
        self.path = Path(path)
        assert self.path.exists(), f"could not find dataset path: {path}"
        self.features, self.labels = self._load()

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair(s) selected by ``idx``.

        ``idx`` is forwarded unchanged to both tensors, so whatever tensor
        indexing accepts (an int, a slice, ...) works here as well.
        """
        return (self.features[idx], self.labels[idx])

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.features)

    def _load(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Parse the CSV at ``self.path`` into ``(images, labels)`` tensors.

        Every non-blank row is ``label, pixel, pixel, ...``. Reading stops
        once ``config.training.examples`` rows have been collected; a value
        that is never reached (e.g. ``None``) loads the whole file.
        """
        limit = config.training.examples
        pixels_per_image = self.IMAGE_SIZE * self.IMAGE_SIZE

        images, labels = [], []
        # newline="" lets the csv module do its own line-ending handling, as
        # its documentation asks for.
        with open(self.path, mode="r", newline="") as file:
            for row in csv.reader(file):
                if len(labels) == limit:
                    break
                if not row:
                    # csv.reader yields an empty list for a blank line
                    # (e.g. a trailing newline); there is nothing to parse.
                    continue
                labels.append(int(row[0]))
                images.append([int(value) for value in row[1:]])

        labels = torch.tensor(labels, dtype=torch.int64)
        images = torch.tensor(images, dtype=torch.float32)
        # Pin the 2-D shape explicitly: with zero rows the tensor built above
        # is 1-D, which the "n (h w)" pattern below cannot match.
        images = images.reshape(len(labels), pixels_per_image)
        images = einops.rearrange(
            images, "n (h w) -> n h w", h=self.IMAGE_SIZE, w=self.IMAGE_SIZE
        )
        # Add a singleton channel axis and repeat every pixel UPSCALE times
        # along each image axis.
        images = einops.repeat(
            images,
            "n h w -> n c (h r_h) (w r_w)",
            c=1,
            r_h=self.UPSCALE,
            r_w=self.UPSCALE,
        )
        return (images, labels)
|
|
|
|
|
|
def debug():
    """Load the MNIST training CSV and log a few sanity-check statistics.

    Reads ``mnist_train.csv`` from ``config.paths.data`` and logs the dataset
    length, the shape of the first image and the mean of the first ten
    images.
    """
    path = Path(config.paths.data) / "mnist_train.csv"
    dataset = MnistDataset(path=path)
    logger.info(f"len: {len(dataset)}")
    logger.info(f"first shape: {dataset[0][0].shape}")
    # The features are 4-D (n, c, h, w) because the loader adds a channel
    # axis, so the pattern must name that axis too; a three-axis pattern is
    # rejected by einops. Averaging over n and c yields a 2-D mean image.
    mean = einops.reduce(dataset[:10][0], "n c h w -> h w", "mean")
    logger.info(f"mean shape: {mean.shape}")
    logger.info(f"mean image: {mean}")
|
|
|
|
|