Normalize reward by the base objective; square reward
This commit is contained in:
parent
090387a176
commit
a3833d83b1
|
@ -235,8 +235,12 @@ def train_ddpg(result_id):
|
|||
|
||||
# Extract data from result
|
||||
result = Result.objects.filter(pk=result_id)
|
||||
base_result_id = session_results[0].pk
|
||||
base_result = Result.objects.filter(pk=base_result_id)
|
||||
|
||||
agg_data = DataUtil.aggregate_data(result)
|
||||
metric_data = agg_data['y_matrix'].flatten()
|
||||
base_metric_data = (DataUtil.aggregate_data(base_result))['y_matrix'].flatten()
|
||||
metric_scalar = MinMaxScaler().fit(metric_data.reshape(1, -1))
|
||||
normalized_metric_data = metric_scalar.transform(metric_data.reshape(1, -1))[0]
|
||||
|
||||
|
@ -262,14 +266,16 @@ def train_ddpg(result_id):
|
|||
'metrics (target_obj={})').format(len(target_obj_idx),
|
||||
target_objective))
|
||||
objective = metric_data[target_obj_idx]
|
||||
base_objective = base_metric_data[target_obj_idx]
|
||||
metric_meta = MetricCatalog.objects.get_metric_meta(result.session.dbms,
|
||||
result.session.target_objective)
|
||||
|
||||
# Calculate the reward
|
||||
objective = objective / base_objective
|
||||
if metric_meta[target_objective].improvement == '(less is better)':
|
||||
reward = -objective
|
||||
reward = -objective * objective
|
||||
else:
|
||||
reward = objective
|
||||
reward = objective * objective
|
||||
LOG.info('reward: %f', reward)
|
||||
|
||||
# Update ddpg
|
||||
|
|
Loading…
Reference in New Issue