refactor files.

add a notebook module.
add config package.
This commit is contained in:
publicmatt
2024-04-05 18:37:24 -07:00
parent e508efefee
commit bae586612e
25 changed files with 431 additions and 234 deletions

View File

@@ -1,20 +0,0 @@
# conftest.py
import pytest
import os
from dotenv import load_dotenv
from pathlib import Path
@pytest.fixture(autouse=True)
def load_env():
# Set up your environment variables here
env = Path(__file__).parent / ".env.test"
if not load_dotenv(env):
raise RuntimeError(".env not loaded")
# os.environ['MY_ENV_VAR'] = 'some_value'
# You can add more setup code here if needed
yield
# Optional: Cleanup code after test (if needed)
# e.g., unset environment variables if they should not persist after test

6
test/test_cnn.py Normal file
View File

@@ -0,0 +1,6 @@
from ml_pipeline import config
from ml_pipeline.model.cnn import VGG11
def test_in_channels():
assert config.model.name == 'vgg11'

28
test/test_inputs.py Normal file
View File

@@ -0,0 +1,28 @@
from ml_pipeline.data.dataset import MnistDataset
from ml_pipeline import config
from pathlib import Path
import pytest
@pytest.mark.skip()
def test_init():
pass
def test_getitem():
train_set = MnistDataset(config.data.train_path)
assert train_set[0][1].item() == 5
repeated = 8
length = 28
channels = 1
assert train_set[0][0].shape == (channels, length * repeated, length * repeated)
@pytest.mark.skip()
def test_loader():
from torch.utils.data import DataLoader
train_set = MnistDataset(config.data.train_path)
# train_loader = DataLoader(train_set, batch_size=config.training.batch_size, shuffle=True)
# for sample, target in train_loader:
# assert len(sample) == config.training.batch_size
# len(sample)
# len(target)

View File

@@ -1,17 +0,0 @@
from src.model.linear import DNN
from src.data.dataset import MnistDataset
import os
def test_size_of_dataset():
examples = 500
os.environ["TRAINING_EXAMPLES"] = str(examples)
channels = 1
width, height = 224, 224
dataset = MnistDataset(os.getenv("TRAIN_PATH"))
# label = dataset[0][1].item()
image = dataset[0][0].shape
assert channels == image[0]
assert width == image[1]
assert height == image[2]
assert len(dataset) == examples