67 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			67 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
| from torch.utils.data import Dataset
 | |
| import numpy as np
 | |
| import einops
 | |
| import csv
 | |
| import torch
 | |
| from pathlib import Path
 | |
| from typing import Tuple
 | |
| from {{cookiecutter.module_name}} import config, logger
 | |
| 
 | |
| 
 | |
| class MnistDataset(Dataset):
 | |
|     """
 | |
|     The MNIST database of handwritten digits.
 | |
|     Training set is 60k labeled examples, test is 10k examples.
 | |
|     The b/w images normalized to 20x20, preserving aspect ratio.
 | |
| 
 | |
|     It's the defacto standard image training set to learn about classification in DL
 | |
|     """
 | |
| 
 | |
|     def __init__(self, path: Path):
 | |
|         """
 | |
|         give a path to a dir that contains the following csv files:
 | |
|         https://pjreddie.com/projects/mnist-in-csv/
 | |
|         """
 | |
|         assert path, "dataset path required"
 | |
|         self.path = Path(path)
 | |
|         assert self.path.exists(), f"could not find dataset path: {path}"
 | |
|         self.features, self.labels = self._load()
 | |
| 
 | |
|     def __getitem__(self, idx):
 | |
|         return (self.features[idx], self.labels[idx])
 | |
| 
 | |
|     def __len__(self):
 | |
|         return len(self.features)
 | |
| 
 | |
|     def _load(self) -> Tuple[torch.Tensor, torch.Tensor]:
 | |
|         # opening the CSV file
 | |
|         with open(self.path, mode="r") as file:
 | |
|             images, labels = [], []
 | |
|             csvFile = csv.reader(file)
 | |
|             examples = config.training.examples
 | |
|             for line, content in enumerate(csvFile):
 | |
|                 if line == examples:
 | |
|                     break
 | |
|                 labels.append(int(content[0]))
 | |
|                 image = [int(x) for x in content[1:]]
 | |
|                 images.append(image)
 | |
|             labels = torch.tensor(labels, dtype=torch.int64)
 | |
|             images = torch.tensor(images, dtype=torch.float32)
 | |
|             images = einops.rearrange(images, "n (w h) -> n w h", w=28, h=28)
 | |
|             images = einops.repeat(
 | |
|                 images, "n w h -> n c (w r_w) (h r_h)", c=1, r_w=8, r_h=8
 | |
|             )
 | |
|             return (images, labels)
 | |
| 
 | |
| 
 | |
| def debug():
 | |
|     path = Path(config.paths.data) /  "mnist_train.csv"
 | |
|     dataset = MnistDataset(path=path)
 | |
|     logger.info(f"len: {len(dataset)}")
 | |
|     logger.info(f"first shape: {dataset[0][0].shape}")
 | |
|     mean = einops.reduce(dataset[:10][0], "n w h -> w h", "mean")
 | |
|     logger.info(f"mean shape: {mean.shape}")
 | |
|     logger.info(f"mean image: {mean}")
 | |
| 
 | |
| 
 |