refactor files.
add a notebook module. add config package.
This commit is contained in:
@@ -1,20 +0,0 @@
|
||||
# conftest.py
|
||||
import pytest
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def load_env():
|
||||
# Set up your environment variables here
|
||||
env = Path(__file__).parent / ".env.test"
|
||||
if not load_dotenv(env):
|
||||
raise RuntimeError(".env not loaded")
|
||||
# os.environ['MY_ENV_VAR'] = 'some_value'
|
||||
# You can add more setup code here if needed
|
||||
|
||||
yield
|
||||
|
||||
# Optional: Cleanup code after test (if needed)
|
||||
# e.g., unset environment variables if they should not persist after test
|
||||
6
test/test_cnn.py
Normal file
6
test/test_cnn.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from ml_pipeline import config
|
||||
from ml_pipeline.model.cnn import VGG11
|
||||
|
||||
def test_in_channels():
|
||||
assert config.model.name == 'vgg11'
|
||||
|
||||
28
test/test_inputs.py
Normal file
28
test/test_inputs.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from ml_pipeline.data.dataset import MnistDataset
|
||||
from ml_pipeline import config
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
|
||||
@pytest.mark.skip()
|
||||
def test_init():
|
||||
pass
|
||||
|
||||
|
||||
def test_getitem():
|
||||
train_set = MnistDataset(config.data.train_path)
|
||||
|
||||
assert train_set[0][1].item() == 5
|
||||
repeated = 8
|
||||
length = 28
|
||||
channels = 1
|
||||
assert train_set[0][0].shape == (channels, length * repeated, length * repeated)
|
||||
|
||||
@pytest.mark.skip()
|
||||
def test_loader():
|
||||
from torch.utils.data import DataLoader
|
||||
train_set = MnistDataset(config.data.train_path)
|
||||
# train_loader = DataLoader(train_set, batch_size=config.training.batch_size, shuffle=True)
|
||||
# for sample, target in train_loader:
|
||||
# assert len(sample) == config.training.batch_size
|
||||
# len(sample)
|
||||
# len(target)
|
||||
@@ -1,17 +0,0 @@
|
||||
from src.model.linear import DNN
|
||||
from src.data.dataset import MnistDataset
|
||||
import os
|
||||
|
||||
|
||||
def test_size_of_dataset():
|
||||
examples = 500
|
||||
os.environ["TRAINING_EXAMPLES"] = str(examples)
|
||||
channels = 1
|
||||
width, height = 224, 224
|
||||
dataset = MnistDataset(os.getenv("TRAIN_PATH"))
|
||||
# label = dataset[0][1].item()
|
||||
image = dataset[0][0].shape
|
||||
assert channels == image[0]
|
||||
assert width == image[1]
|
||||
assert height == image[2]
|
||||
assert len(dataset) == examples
|
||||
Reference in New Issue
Block a user