73 lines
2.2 KiB
Python
73 lines
2.2 KiB
Python
import csv
from itertools import islice
from pathlib import Path
from typing import Optional, Tuple

import einops
import numpy as np
import torch
from torch.utils.data import Dataset
|
|
|
|
|
|
class MnistDataset(Dataset):
    """
    The MNIST database of handwritten digits, loaded from a csv file.

    The full training set is 60k labeled examples, the test set is 10k examples.
    The b/w digits are size-normalized to fit a 20x20 box (preserving aspect
    ratio) and stored as 28x28 images.

    Each csv row holds the label in the first column followed by the 784
    flattened pixel values. On load every image is reshaped to 28x28, upscaled
    by a factor of 8 along both spatial axes (each pixel is repeated 8 times)
    and given a single channel axis, so one sample has shape (1, 224, 224).

    It's the de facto standard image training set to learn about
    classification in DL.
    """

    # Side length of one image as stored in the csv (28 * 28 = 784 columns).
    _SIDE = 28
    # Repetition factor applied to both spatial axes (28 * 8 = 224).
    _UPSCALE = 8

    def __init__(self, path: Path, limit: Optional[int] = 1000):
        """
        Read the samples of one MNIST csv file into memory.

        Args:
            path: path to a single csv file (not a directory), in the format
                published at https://pjreddie.com/projects/mnist-in-csv/
            limit: maximum number of rows to read. ``None`` reads the whole
                file. Defaults to 1000, the cap that used to be hard-coded.

        Raises:
            ValueError: if ``limit`` is given and smaller than 1.
        """
        if limit is not None and limit < 1:
            raise ValueError(f"limit must be a positive integer or None, got {limit}")
        self.path = path
        self.limit = limit
        self.features, self.labels = self.load()

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair(s) selected by ``idx``."""
        return (self.features[idx], self.labels[idx])

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.features)

    def load(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Parse ``self.path`` and build the image and label tensors.

        Returns:
            ``(images, labels)`` where ``images`` is a float32 tensor of shape
            ``(n, 1, 224, 224)`` and ``labels`` is an int64 tensor of shape
            ``(n,)``.
        """
        raw_labels, raw_pixels = self._read_rows(self.path, self.limit)

        labels = torch.tensor(raw_labels, dtype=torch.long)
        images = torch.tensor(raw_pixels, dtype=torch.float32)

        # Un-flatten every row of 784 values into a 28x28 grid.
        images = einops.rearrange(
            images, "n (w h) -> n w h", w=self._SIDE, h=self._SIDE
        )
        # Add the channel axis and repeat every pixel along both spatial axes.
        images = einops.repeat(
            images,
            "n w h -> n c (w r_w) (h r_h)",
            c=1,
            r_w=self._UPSCALE,
            r_h=self._UPSCALE,
        )
        return (images, labels)

    @staticmethod
    def _read_rows(path, limit=None):
        """Parse up to ``limit`` non-blank csv rows into ``(labels, pixel_rows)`` lists of ints."""
        labels = []
        pixel_rows = []
        # newline="" lets the csv module do its own line-ending handling.
        with open(path, mode="r", newline="") as file:
            # Blank lines come back as empty rows; they carry no sample, so
            # they are skipped and do not count against the limit.
            rows = (row for row in csv.reader(file) if row)
            for row in islice(rows, limit):
                labels.append(int(row[0]))
                pixel_rows.append([int(value) for value in row[1:]])
        return labels, pixel_rows
|
|
|
|
|
|
def main():
    """
    Smoke-test the dataset on ``storage/mnist_train.csv``.

    Prints the number of loaded samples, the shape of the first image, and
    the shape and values of the mean image of the first 10 samples.
    """
    path = Path("storage/mnist_train.csv")
    dataset = MnistDataset(path=path)
    print(f"len: {len(dataset)}")
    print(f"first shape: {dataset[0][0].shape}")
    # The loaded images carry a channel axis, so a slice of them is 4-D
    # (n, c, w, h). The pattern has to name that axis; it is averaged away
    # together with the sample axis so the result is a plain (w, h) image.
    mean = einops.reduce(dataset[:10][0], "n c w h -> w h", "mean")
    print(f"mean shape: {mean.shape}")
    print(f"mean image: {mean}")
|
|
|
|
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|