ml_pipeline_cookiecutter/{{cookiecutter.project_slug}}/{{cookiecutter.module_name}}/cli.py

42 lines
1.0 KiB
Python
Raw Normal View History

2024-04-06 13:02:31 -07:00
import click
@click.group()
@click.version_option()
def cli():
    """
    ml_pipeline: a template for building, training and running pytorch models.
    """
    # Root command group. Each subcommand in this module attaches itself to
    # this group with ``@cli.command("<area>:<action>")``. The callback does
    # no work of its own. ``@click.version_option()`` contributes the
    # ``--version`` flag.
    # NOTE: the docstring above is the group's ``--help`` text (click reads
    # it), so edit it as user-facing output, not as an internal comment.
@cli.command("pipeline:train")
def pipeline_train():
    """run the training pipeline with train data"""
    # The import is deferred to call time, so `--help` and the other
    # subcommands never import the training code.
    # Package-relative on purpose: this file lives in a cookiecutter template
    # whose package directory is named by ``module_name``, so a hard-coded
    # ``ml_pipeline`` import breaks any project generated with another name.
    from .training import pipeline

    # evaluate=False selects the training path (pipeline:evaluate passes True).
    pipeline.run(evaluate=False)
@cli.command("pipeline:evaluate")
def pipeline_evaluate():
    """run the training pipeline with test data"""
    # The import is deferred to call time, so `--help` and the other
    # subcommands never import the training code.
    # Package-relative on purpose: the template's package directory is named
    # by ``module_name``, so a hard-coded ``ml_pipeline`` import breaks any
    # project generated with another name.
    from .training import pipeline

    # Same pipeline as pipeline:train, switched to evaluation mode.
    pipeline.run(evaluate=True)
@cli.command("app:serve")
def app_serve():
    """run the api server pipeline with pretrained model"""
    # The import is deferred to call time so other subcommands don't load the
    # app module. It is package-relative so it resolves inside whatever
    # package the cookiecutter template was rendered into, instead of a
    # hard-coded ``ml_pipeline``.
    from . import app

    # Blocks for as long as the server runs — presumably; depends on app.run.
    app.run()
@cli.command("data:download")
def data_download():
    """download the train and test data"""
    # Package-relative imports so the command works in a project rendered
    # with any ``module_name``, not only ``ml_pipeline``. The data-then-config
    # order matches the original import order.
    from . import data, config
    from pathlib import Path

    # The configured data location is wrapped in a Path before being handed
    # to the downloader. NOTE(review): presumably config.paths.data is a str
    # or path-like; confirm against the config module.
    data.download(Path(config.paths.data))
@cli.command("data:debug")
def data_debug():
    """debug the dataset class"""
    # The import is deferred to call time and is package-relative, so it
    # resolves inside the rendered template package rather than a hard-coded
    # ``ml_pipeline``.
    from .data import dataset

    dataset.debug()