From b4e5fb2e662283efcab71cb92f4eafd9f8bee571 Mon Sep 17 00:00:00 2001 From: yangdsh Date: Mon, 23 Mar 2020 03:38:54 +0000 Subject: [PATCH] fix ddpg --- server/website/website/tasks/async_tasks.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/server/website/website/tasks/async_tasks.py b/server/website/website/tasks/async_tasks.py index b4dfc36..a3b2eff 100644 --- a/server/website/website/tasks/async_tasks.py +++ b/server/website/website/tasks/async_tasks.py @@ -297,15 +297,20 @@ def preprocessing(result_id, algorithm): # implement a sampling technique to generate new training data). has_pipeline_data = PipelineData.objects.filter(workload=newest_result.workload).exists() session_results = Result.objects.filter(session=session) - useful_results_cnt = len(session_results) - for result in session_results: + results_cnt = len(session_results) + skip_ddpg = False + for i, result in enumerate(session_results): if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name: - useful_results_cnt -= 1 - if not has_pipeline_data or useful_results_cnt == 0 or session.tuning_session == 'lhs': + results_cnt -= 1 + if i == len(session_results) - 1 and algorithm == AlgorithmType.DDPG: + skip_ddpg = True + if not has_pipeline_data or results_cnt == 0 or skip_ddpg or session.tuning_session == 'lhs': if not has_pipeline_data and session.tuning_session == 'tuning_session': LOG.debug("Background tasks haven't ran for this workload yet, picking data with lhs.") - if useful_results_cnt == 0 and session.tuning_session == 'tuning_session': + if results_cnt == 0 and session.tuning_session == 'tuning_session': LOG.debug("Not enough data in this session, picking data with lhs.") + if skip_ddpg: + LOG.debug("The most recent result cannot be used by DDPG, picking data with lhs.") all_samples = JSONUtil.loads(session.lhs_samples) if len(all_samples) == 0: @@ -457,11 +462,11 @@ def train_ddpg(train_ddpg_input): params = JSONUtil.loads(session.hyperparameters) session_results = Result.objects.filter(session=session, creation_time__lt=result.creation_time) - useful_results_cnt = len(session_results) + results_cnt = len(session_results) first_valid_result = -1 for i, result in enumerate(session_results): if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name: - useful_results_cnt -= 1 + results_cnt -= 1 else: last_valid_result = i first_valid_result = i if first_valid_result == -1 else first_valid_result @@ -470,7 +475,7 @@ def train_ddpg(train_ddpg_input): # Extract data from result and previous results result = Result.objects.filter(pk=result_id) - if useful_results_cnt == 0: + if results_cnt == 0: base_result_id = result_id prev_result_id = result_id else: