fix project_slug error.

replace ml_pipeline with project_name or module_name from config
This commit is contained in:
publicmatt
2024-04-06 15:48:29 -07:00
parent 727f16df57
commit 6eed08d1ba
43 changed files with 81 additions and 95 deletions

View File

@@ -0,0 +1,2 @@
TRAIN_PATH=${HOME}/Dev/ml/data/mnist_train.csv
INPUT_FEATURES=40

View File

@@ -0,0 +1,6 @@
from {{cookiecutter.module_name}} import config
from {{cookiecutter.module_name}}.model.cnn import VGG11
def test_in_channels():
assert config.model.name == 'vgg11'

View File

@@ -0,0 +1,28 @@
from {{cookiecutter.module_name}}.data.dataset import MnistDataset
from {{cookiecutter.module_name}} import config
from pathlib import Path
import pytest
@pytest.mark.skip()
def test_init():
pass
def test_getitem():
train_set = MnistDataset(config.data.train_path)
assert train_set[0][1].item() == 5
repeated = 8
length = 28
channels = 1
assert train_set[0][0].shape == (channels, length * repeated, length * repeated)
@pytest.mark.skip()
def test_loader():
from torch.utils.data import DataLoader
train_set = MnistDataset(config.data.train_path)
# train_loader = DataLoader(train_set, batch_size=config.training.batch_size, shuffle=True)
# for sample, target in train_loader:
# assert len(sample) == config.training.batch_size
# len(sample)
# len(target)