This commit is contained in:
yangdsh 2020-03-23 03:38:54 +00:00 committed by Dana Van Aken
parent 56860d6364
commit b4e5fb2e66
1 changed files with 13 additions and 8 deletions

View File

@ -297,15 +297,20 @@ def preprocessing(result_id, algorithm):
# implement a sampling technique to generate new training data). # implement a sampling technique to generate new training data).
has_pipeline_data = PipelineData.objects.filter(workload=newest_result.workload).exists() has_pipeline_data = PipelineData.objects.filter(workload=newest_result.workload).exists()
session_results = Result.objects.filter(session=session) session_results = Result.objects.filter(session=session)
useful_results_cnt = len(session_results) results_cnt = len(session_results)
for result in session_results: skip_ddpg = False
for i, result in enumerate(session_results):
if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name: if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name:
useful_results_cnt -= 1 results_cnt -= 1
if not has_pipeline_data or useful_results_cnt == 0 or session.tuning_session == 'lhs': if i == len(session_results) - 1 and algorithm == AlgorithmType.DDPG:
skip_ddpg = True
if not has_pipeline_data or results_cnt == 0 or skip_ddpg or session.tuning_session == 'lhs':
if not has_pipeline_data and session.tuning_session == 'tuning_session': if not has_pipeline_data and session.tuning_session == 'tuning_session':
LOG.debug("Background tasks haven't ran for this workload yet, picking data with lhs.") LOG.debug("Background tasks haven't ran for this workload yet, picking data with lhs.")
if useful_results_cnt == 0 and session.tuning_session == 'tuning_session': if results_cnt == 0 and session.tuning_session == 'tuning_session':
LOG.debug("Not enough data in this session, picking data with lhs.") LOG.debug("Not enough data in this session, picking data with lhs.")
if skip_ddpg:
LOG.debug("The most recent result cannot be used by DDPG, picking data with lhs.")
all_samples = JSONUtil.loads(session.lhs_samples) all_samples = JSONUtil.loads(session.lhs_samples)
if len(all_samples) == 0: if len(all_samples) == 0:
@ -457,11 +462,11 @@ def train_ddpg(train_ddpg_input):
params = JSONUtil.loads(session.hyperparameters) params = JSONUtil.loads(session.hyperparameters)
session_results = Result.objects.filter(session=session, session_results = Result.objects.filter(session=session,
creation_time__lt=result.creation_time) creation_time__lt=result.creation_time)
useful_results_cnt = len(session_results) results_cnt = len(session_results)
first_valid_result = -1 first_valid_result = -1
for i, result in enumerate(session_results): for i, result in enumerate(session_results):
if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name: if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name:
useful_results_cnt -= 1 results_cnt -= 1
else: else:
last_valid_result = i last_valid_result = i
first_valid_result = i if first_valid_result == -1 else first_valid_result first_valid_result = i if first_valid_result == -1 else first_valid_result
@ -470,7 +475,7 @@ def train_ddpg(train_ddpg_input):
# Extract data from result and previous results # Extract data from result and previous results
result = Result.objects.filter(pk=result_id) result = Result.objects.filter(pk=result_id)
if useful_results_cnt == 0: if results_cnt == 0:
base_result_id = result_id base_result_id = result_id
prev_result_id = result_id prev_result_id = result_id
else: else: