diff --git a/server/website/website/tasks/async_tasks.py b/server/website/website/tasks/async_tasks.py index 8900fd6..d18c9c0 100644 --- a/server/website/website/tasks/async_tasks.py +++ b/server/website/website/tasks/async_tasks.py @@ -21,7 +21,7 @@ from analysis.constraints import ParamConstraintHelper from website.models import (PipelineData, PipelineRun, Result, Workload, KnobCatalog, MetricCatalog, SessionKnob) from website.parser import Parser -from website.types import PipelineTaskType +from website.types import PipelineTaskType, AlgorithmType from website.utils import DataUtil, JSONUtil from website.settings import IMPORTANT_KNOB_NUMBER, NUM_SAMPLES, TOP_NUM_CONFIG # pylint: disable=no-name-in-module from website.settings import (DEFAULT_LENGTH_SCALE, DEFAULT_MAGNITUDE, @@ -543,9 +543,8 @@ def configuration_recommendation(recommendation_input): session = newest_result.session res = None - assert algorithm in ['gpr', 'dnn'] - if algorithm == 'dnn': + if algorithm == AlgorithmType.DNN: # neural network model model_nn = NeuralNet(n_input=X_samples.shape[1], batch_size=X_samples.shape[0], @@ -562,7 +561,7 @@ def configuration_recommendation(recommendation_input): session.dnn_model = model_nn.get_weights_bin() session.save() - elif algorithm == 'gpr': + elif algorithm == AlgorithmType.OTTERTUNE: # default gpr model model = GPRGD(length_scale=DEFAULT_LENGTH_SCALE, magnitude=DEFAULT_MAGNITUDE, diff --git a/server/website/website/views.py b/server/website/website/views.py index e9b618b..a29fb4a 100644 --- a/server/website/website/views.py +++ b/server/website/website/views.py @@ -529,14 +529,14 @@ def handle_result_files(session, files): result_id = result.pk response = None if session.algorithm == AlgorithmType.OTTERTUNE: - response = chain(aggregate_target_results.s(result.pk), + response = chain(aggregate_target_results.s(result.pk, session.algorithm), map_workload.s(), configuration_recommendation.s()).apply_async() elif session.algorithm == AlgorithmType.DDPG: response = chain(train_ddpg.s(result.pk), configuration_recommendation_ddpg.s()).apply_async() elif session.algorithm == AlgorithmType.DNN: - response = chain(aggregate_target_results.s(result.pk, 'dnn'), + response = chain(aggregate_target_results.s(result.pk, session.algorithm), map_workload.s(), configuration_recommendation.s()).apply_async()