change to algoroithm tpyes
This commit is contained in:
parent
25d0838376
commit
e9f503ef3e
|
@ -21,7 +21,7 @@ from analysis.constraints import ParamConstraintHelper
|
|||
from website.models import (PipelineData, PipelineRun, Result, Workload, KnobCatalog,
|
||||
MetricCatalog, SessionKnob)
|
||||
from website.parser import Parser
|
||||
from website.types import PipelineTaskType
|
||||
from website.types import PipelineTaskType, AlgorithmType
|
||||
from website.utils import DataUtil, JSONUtil
|
||||
from website.settings import IMPORTANT_KNOB_NUMBER, NUM_SAMPLES, TOP_NUM_CONFIG # pylint: disable=no-name-in-module
|
||||
from website.settings import (DEFAULT_LENGTH_SCALE, DEFAULT_MAGNITUDE,
|
||||
|
@ -543,9 +543,8 @@ def configuration_recommendation(recommendation_input):
|
|||
|
||||
session = newest_result.session
|
||||
res = None
|
||||
assert algorithm in ['gpr', 'dnn']
|
||||
|
||||
if algorithm == 'dnn':
|
||||
if algorithm == AlgorithmType.DNN:
|
||||
# neural network model
|
||||
model_nn = NeuralNet(n_input=X_samples.shape[1],
|
||||
batch_size=X_samples.shape[0],
|
||||
|
@ -562,7 +561,7 @@ def configuration_recommendation(recommendation_input):
|
|||
session.dnn_model = model_nn.get_weights_bin()
|
||||
session.save()
|
||||
|
||||
elif algorithm == 'gpr':
|
||||
elif algorithm == AlgorithmType.OTTERTUNE:
|
||||
# default gpr model
|
||||
model = GPRGD(length_scale=DEFAULT_LENGTH_SCALE,
|
||||
magnitude=DEFAULT_MAGNITUDE,
|
||||
|
|
|
@ -529,14 +529,14 @@ def handle_result_files(session, files):
|
|||
result_id = result.pk
|
||||
response = None
|
||||
if session.algorithm == AlgorithmType.OTTERTUNE:
|
||||
response = chain(aggregate_target_results.s(result.pk),
|
||||
response = chain(aggregate_target_results.s(result.pk, session.algorithm),
|
||||
map_workload.s(),
|
||||
configuration_recommendation.s()).apply_async()
|
||||
elif session.algorithm == AlgorithmType.DDPG:
|
||||
response = chain(train_ddpg.s(result.pk),
|
||||
configuration_recommendation_ddpg.s()).apply_async()
|
||||
elif session.algorithm == AlgorithmType.DNN:
|
||||
response = chain(aggregate_target_results.s(result.pk, 'dnn'),
|
||||
response = chain(aggregate_target_results.s(result.pk, session.algorithm),
|
||||
map_workload.s(),
|
||||
configuration_recommendation.s()).apply_async()
|
||||
|
||||
|
|
Loading…
Reference in New Issue