From a3833d83b1e4ec33f79302e1f5daad90702921b5 Mon Sep 17 00:00:00 2001 From: yangdsh Date: Mon, 14 Oct 2019 17:24:17 +0000 Subject: [PATCH] normalize reward; square reward --- server/website/website/tasks/async_tasks.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/server/website/website/tasks/async_tasks.py b/server/website/website/tasks/async_tasks.py index f2861eb..b3d8c3d 100644 --- a/server/website/website/tasks/async_tasks.py +++ b/server/website/website/tasks/async_tasks.py @@ -235,8 +235,12 @@ def train_ddpg(result_id): # Extract data from result result = Result.objects.filter(pk=result_id) + base_result_id = session_results[0].pk + base_result = Result.objects.filter(pk=base_result_id) + agg_data = DataUtil.aggregate_data(result) metric_data = agg_data['y_matrix'].flatten() + base_metric_data = (DataUtil.aggregate_data(base_result))['y_matrix'].flatten() metric_scalar = MinMaxScaler().fit(metric_data.reshape(1, -1)) normalized_metric_data = metric_scalar.transform(metric_data.reshape(1, -1))[0] @@ -262,14 +266,16 @@ def train_ddpg(result_id): 'metrics (target_obj={})').format(len(target_obj_idx), target_objective)) objective = metric_data[target_obj_idx] + base_objective = base_metric_data[target_obj_idx] metric_meta = MetricCatalog.objects.get_metric_meta(result.session.dbms, result.session.target_objective) # Calculate the reward + objective = objective / base_objective if metric_meta[target_objective].improvement == '(less is better)': - reward = -objective + reward = -objective * objective else: - reward = objective + reward = objective * objective LOG.info('reward: %f', reward) # Update ddpg