fix project_slug error.
replace ml_pipeline with project_name or module_name from config
This commit is contained in:
2
{{cookiecutter.project_name}}/test/.env.test
Normal file
2
{{cookiecutter.project_name}}/test/.env.test
Normal file
@@ -0,0 +1,2 @@
|
||||
TRAIN_PATH=${HOME}/Dev/ml/data/mnist_train.csv
|
||||
INPUT_FEATURES=40
|
||||
6
{{cookiecutter.project_name}}/test/test_cnn.py
Normal file
6
{{cookiecutter.project_name}}/test/test_cnn.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from {{cookiecutter.module_name}} import config
|
||||
from {{cookiecutter.module_name}}.model.cnn import VGG11
|
||||
|
||||
def test_in_channels():
|
||||
assert config.model.name == 'vgg11'
|
||||
|
||||
28
{{cookiecutter.project_name}}/test/test_inputs.py
Normal file
28
{{cookiecutter.project_name}}/test/test_inputs.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from {{cookiecutter.module_name}}.data.dataset import MnistDataset
|
||||
from {{cookiecutter.module_name}} import config
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
|
||||
@pytest.mark.skip()
|
||||
def test_init():
|
||||
pass
|
||||
|
||||
|
||||
def test_getitem():
|
||||
train_set = MnistDataset(config.data.train_path)
|
||||
|
||||
assert train_set[0][1].item() == 5
|
||||
repeated = 8
|
||||
length = 28
|
||||
channels = 1
|
||||
assert train_set[0][0].shape == (channels, length * repeated, length * repeated)
|
||||
|
||||
@pytest.mark.skip()
|
||||
def test_loader():
|
||||
from torch.utils.data import DataLoader
|
||||
train_set = MnistDataset(config.data.train_path)
|
||||
# train_loader = DataLoader(train_set, batch_size=config.training.batch_size, shuffle=True)
|
||||
# for sample, target in train_loader:
|
||||
# assert len(sample) == config.training.batch_size
|
||||
# len(sample)
|
||||
# len(target)
|
||||
Reference in New Issue
Block a user