From fd6dc1ea54a693379683cad2f538c6dc876f74f7 Mon Sep 17 00:00:00 2001 From: yangdsh Date: Tue, 14 Apr 2020 00:05:46 +0000 Subject: [PATCH] fix original ddpg --- server/website/website/tasks/async_tasks.py | 49 ++++++++++++--------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/server/website/website/tasks/async_tasks.py b/server/website/website/tasks/async_tasks.py index f6db926..9d2a11e 100644 --- a/server/website/website/tasks/async_tasks.py +++ b/server/website/website/tasks/async_tasks.py @@ -340,8 +340,9 @@ def preprocessing(result_id, algorithm): session_results = Result.objects.filter(session=session) results_cnt = len(session_results) skip_ddpg = False + ignore = ['range_test'] for i, result in enumerate(session_results): - if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name: + if any(symbol in result.metric_data.name for symbol in ignore): results_cnt -= 1 if i == len(session_results) - 1 and algorithm == AlgorithmType.DDPG: skip_ddpg = True @@ -525,8 +526,9 @@ def train_ddpg(train_ddpg_input): results_cnt = len(session_results) first_valid_result = -1 + ignore = ['range_test'] for i, result in enumerate(session_results): - if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name: + if any(symbol in result.metric_data.name for symbol in ignore): results_cnt -= 1 else: last_valid_result = i @@ -545,12 +547,26 @@ def train_ddpg(train_ddpg_input): base_result = Result.objects.filter(pk=base_result_id) prev_result = Result.objects.filter(pk=prev_result_id) - agg_data = DataUtil.aggregate_data(result) - base_metric_data = (DataUtil.aggregate_data(base_result))['y_matrix'].flatten() - prev_metric_data = (DataUtil.aggregate_data(prev_result))['y_matrix'].flatten() + agg_data = DataUtil.aggregate_data(result, ignore) + prev_agg_data = DataUtil.aggregate_data(prev_result, ignore) + metric_data = agg_data['y_matrix'].flatten() + prev_metric_data = prev_agg_data['y_matrix'].flatten() + base_metric_data = (DataUtil.aggregate_data(base_result, ignore))['y_matrix'].flatten() target_objective = session.target_objective - prev_obj_idx = [i for i, n in enumerate(agg_data['y_columnlabels']) if n == target_objective] + # Filter ys by current target objective metric + target_obj_idx = [i for i, n in enumerate(agg_data['y_columnlabels']) if n == target_objective] + if len(target_obj_idx) == 0: + raise Exception(('[{}] Could not find target objective in metrics ' + '(target_obj={})').format(task_name, target_objective)) + if len(target_obj_idx) > 1: + raise Exception(('[{}] Found {} instances of target objective in ' + 'metrics (target_obj={})').format(task_name, + len(target_obj_idx), + target_objective)) + objective = metric_data[target_obj_idx] + base_objective = base_metric_data[target_obj_idx] + prev_objective = prev_metric_data[target_obj_idx] # Clean metric data metric_data, metric_labels = clean_metric_data(agg_data['y_matrix'], @@ -558,6 +574,11 @@ def train_ddpg(train_ddpg_input): metric_data = metric_data.flatten() metric_scalar = MinMaxScaler().fit(metric_data.reshape(1, -1)) normalized_metric_data = metric_scalar.transform(metric_data.reshape(1, -1))[0] + prev_metric_data, _ = clean_metric_data(prev_agg_data['y_matrix'], + prev_agg_data['y_columnlabels'], session) + prev_metric_data = prev_metric_data.flatten() + prev_metric_scalar = MinMaxScaler().fit(prev_metric_data.reshape(1, -1)) + prev_normalized_metric_data = prev_metric_scalar.transform(prev_metric_data.reshape(1, -1))[0] # Clean knob data cleaned_knob_data = clean_knob_data(agg_data['X_matrix'], agg_data['X_columnlabels'], session) @@ -569,25 +590,13 @@ def train_ddpg(train_ddpg_input): metric_num = len(metric_data) LOG.debug('%s: knob_num: %d, metric_num: %d', task_name, knob_num, metric_num) - # Filter ys by current target objective metric - target_obj_idx = [i for i, n in enumerate(metric_labels) if n == target_objective] - if len(target_obj_idx) == 0: - raise Exception(('[{}] Could not find target objective in metrics ' - '(target_obj={})').format(task_name, target_objective)) - elif len(target_obj_idx) > 1: - raise Exception(('[{}] Found {} instances of target objective in ' - 'metrics (target_obj={})').format(task_name, - len(target_obj_idx), - target_objective)) - objective = metric_data[target_obj_idx] - base_objective = base_metric_data[prev_obj_idx] - prev_objective = prev_metric_data[prev_obj_idx] metric_meta = db.target_objectives.get_metric_metadata(session.dbms.pk, session.target_objective) # Calculate the reward if params['DDPG_SIMPLE_REWARD']: objective = objective / base_objective + prev_normalized_metric_data = normalized_metric_data if metric_meta[target_objective].improvement == '(less is better)': reward = -objective else: @@ -618,7 +627,7 @@ def train_ddpg(train_ddpg_input): ddpg.set_model(session.ddpg_actor_model, session.ddpg_critic_model) if session.ddpg_reply_memory: ddpg.replay_memory.set(session.ddpg_reply_memory) - ddpg.add_sample(normalized_metric_data, knob_data, reward, normalized_metric_data) + ddpg.add_sample(prev_normalized_metric_data, knob_data, reward, normalized_metric_data) for _ in range(params['DDPG_UPDATE_EPOCHS']): ddpg.update() session.ddpg_actor_model, session.ddpg_critic_model = ddpg.get_model()