42 lines
1.0 KiB
Python
42 lines
1.0 KiB
Python
|
import click
|
||
|
|
||
|
@click.group()
@click.version_option()
def cli():
    """
    ml_pipeline: a template for building, training and running pytorch models.
    """
    # Root command group: every subcommand in this module attaches itself via
    # ``@cli.command(...)``. The docstring above is what click shows as the
    # group's ``--help`` text, so its wording is intentionally left unchanged.
    # ``click.version_option()`` contributes a ``--version`` flag.
|
||
|
|
||
|
|
||
|
@cli.command("pipeline:train")
def pipeline_train():
    """run the training pipeline with train data"""
    # Imported inside the command rather than at module top (presumably so
    # other subcommands and ``--help`` don't load the training stack).
    from ml_pipeline.training import pipeline as training_pipeline

    # Per this command's help text, evaluate=False runs on the train data;
    # the sibling ``pipeline:evaluate`` command passes True instead.
    evaluate_mode = False
    training_pipeline.run(evaluate=evaluate_mode)
|
||
|
|
||
|
@cli.command("pipeline:evaluate")
def pipeline_evaluate():
    """run the training pipeline with test data"""
    # Same deferred import as ``pipeline:train``; only the flag differs.
    from ml_pipeline.training import pipeline as training_pipeline

    # Per this command's help text, evaluate=True runs on the test data.
    evaluate_mode = True
    training_pipeline.run(evaluate=evaluate_mode)
|
||
|
|
||
|
@cli.command("app:serve")
def app_serve():
    """run the api server pipeline with pretrained model"""
    # Deferred import of the serving entry point. Whether ``run()`` blocks
    # until the server stops is defined in ml_pipeline.app — TODO confirm.
    from ml_pipeline import app as api_app

    api_app.run()
|
||
|
|
||
|
@cli.command("data:download")
def data_download():
    """download the train and test data"""
    from pathlib import Path

    # Project modules are loaded in the original order: data, then config.
    from ml_pipeline import data, config

    # The destination directory comes from the project config and is wrapped
    # in a Path (config.paths.data may be a plain string — verify).
    target_dir = Path(config.paths.data)
    data.download(target_dir)
|
||
|
|
||
|
@cli.command("data:debug")
def data_debug():
    """debug the dataset class"""
    # Thin wrapper: everything this command does is implemented by
    # ``debug()`` in ml_pipeline.data.dataset.
    from ml_pipeline.data import dataset as ds

    ds.debug()
|