change to algoroithm tpyes

This commit is contained in:
bohanjason 2019-09-29 23:58:35 -04:00 committed by Dana Van Aken
parent 25d0838376
commit e9f503ef3e
2 changed files with 5 additions and 6 deletions

View File

@ -21,7 +21,7 @@ from analysis.constraints import ParamConstraintHelper
from website.models import (PipelineData, PipelineRun, Result, Workload, KnobCatalog,
MetricCatalog, SessionKnob)
from website.parser import Parser
from website.types import PipelineTaskType
from website.types import PipelineTaskType, AlgorithmType
from website.utils import DataUtil, JSONUtil
from website.settings import IMPORTANT_KNOB_NUMBER, NUM_SAMPLES, TOP_NUM_CONFIG # pylint: disable=no-name-in-module
from website.settings import (DEFAULT_LENGTH_SCALE, DEFAULT_MAGNITUDE,
@ -543,9 +543,8 @@ def configuration_recommendation(recommendation_input):
session = newest_result.session
res = None
assert algorithm in ['gpr', 'dnn']
if algorithm == 'dnn':
if algorithm == AlgorithmType.DNN:
# neural network model
model_nn = NeuralNet(n_input=X_samples.shape[1],
batch_size=X_samples.shape[0],
@ -562,7 +561,7 @@ def configuration_recommendation(recommendation_input):
session.dnn_model = model_nn.get_weights_bin()
session.save()
elif algorithm == 'gpr':
elif algorithm == AlgorithmType.OTTERTUNE:
# default gpr model
model = GPRGD(length_scale=DEFAULT_LENGTH_SCALE,
magnitude=DEFAULT_MAGNITUDE,

View File

@ -529,14 +529,14 @@ def handle_result_files(session, files):
result_id = result.pk
response = None
if session.algorithm == AlgorithmType.OTTERTUNE:
response = chain(aggregate_target_results.s(result.pk),
response = chain(aggregate_target_results.s(result.pk, session.algorithm),
map_workload.s(),
configuration_recommendation.s()).apply_async()
elif session.algorithm == AlgorithmType.DDPG:
response = chain(train_ddpg.s(result.pk),
configuration_recommendation_ddpg.s()).apply_async()
elif session.algorithm == AlgorithmType.DNN:
response = chain(aggregate_target_results.s(result.pk, 'dnn'),
response = chain(aggregate_target_results.s(result.pk, session.algorithm),
map_workload.s(),
configuration_recommendation.s()).apply_async()