change to algoroithm tpyes

This commit is contained in:
bohanjason 2019-09-29 23:58:35 -04:00 committed by Dana Van Aken
parent 25d0838376
commit e9f503ef3e
2 changed files with 5 additions and 6 deletions

View File

@ -21,7 +21,7 @@ from analysis.constraints import ParamConstraintHelper
from website.models import (PipelineData, PipelineRun, Result, Workload, KnobCatalog, from website.models import (PipelineData, PipelineRun, Result, Workload, KnobCatalog,
MetricCatalog, SessionKnob) MetricCatalog, SessionKnob)
from website.parser import Parser from website.parser import Parser
from website.types import PipelineTaskType from website.types import PipelineTaskType, AlgorithmType
from website.utils import DataUtil, JSONUtil from website.utils import DataUtil, JSONUtil
from website.settings import IMPORTANT_KNOB_NUMBER, NUM_SAMPLES, TOP_NUM_CONFIG # pylint: disable=no-name-in-module from website.settings import IMPORTANT_KNOB_NUMBER, NUM_SAMPLES, TOP_NUM_CONFIG # pylint: disable=no-name-in-module
from website.settings import (DEFAULT_LENGTH_SCALE, DEFAULT_MAGNITUDE, from website.settings import (DEFAULT_LENGTH_SCALE, DEFAULT_MAGNITUDE,
@ -543,9 +543,8 @@ def configuration_recommendation(recommendation_input):
session = newest_result.session session = newest_result.session
res = None res = None
assert algorithm in ['gpr', 'dnn']
if algorithm == 'dnn': if algorithm == AlgorithmType.DNN:
# neural network model # neural network model
model_nn = NeuralNet(n_input=X_samples.shape[1], model_nn = NeuralNet(n_input=X_samples.shape[1],
batch_size=X_samples.shape[0], batch_size=X_samples.shape[0],
@ -562,7 +561,7 @@ def configuration_recommendation(recommendation_input):
session.dnn_model = model_nn.get_weights_bin() session.dnn_model = model_nn.get_weights_bin()
session.save() session.save()
elif algorithm == 'gpr': elif algorithm == AlgorithmType.OTTERTUNE:
# default gpr model # default gpr model
model = GPRGD(length_scale=DEFAULT_LENGTH_SCALE, model = GPRGD(length_scale=DEFAULT_LENGTH_SCALE,
magnitude=DEFAULT_MAGNITUDE, magnitude=DEFAULT_MAGNITUDE,

View File

@ -529,14 +529,14 @@ def handle_result_files(session, files):
result_id = result.pk result_id = result.pk
response = None response = None
if session.algorithm == AlgorithmType.OTTERTUNE: if session.algorithm == AlgorithmType.OTTERTUNE:
response = chain(aggregate_target_results.s(result.pk), response = chain(aggregate_target_results.s(result.pk, session.algorithm),
map_workload.s(), map_workload.s(),
configuration_recommendation.s()).apply_async() configuration_recommendation.s()).apply_async()
elif session.algorithm == AlgorithmType.DDPG: elif session.algorithm == AlgorithmType.DDPG:
response = chain(train_ddpg.s(result.pk), response = chain(train_ddpg.s(result.pk),
configuration_recommendation_ddpg.s()).apply_async() configuration_recommendation_ddpg.s()).apply_async()
elif session.algorithm == AlgorithmType.DNN: elif session.algorithm == AlgorithmType.DNN:
response = chain(aggregate_target_results.s(result.pk, 'dnn'), response = chain(aggregate_target_results.s(result.pk, session.algorithm),
map_workload.s(), map_workload.s(),
configuration_recommendation.s()).apply_async() configuration_recommendation.s()).apply_async()