fix original ddpg
This commit is contained in:
parent
b43a911bd7
commit
fd6dc1ea54
|
@ -340,8 +340,9 @@ def preprocessing(result_id, algorithm):
|
|||
session_results = Result.objects.filter(session=session)
|
||||
results_cnt = len(session_results)
|
||||
skip_ddpg = False
|
||||
ignore = ['range_test']
|
||||
for i, result in enumerate(session_results):
|
||||
if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name:
|
||||
if any(symbol in result.metric_data.name for symbol in ignore):
|
||||
results_cnt -= 1
|
||||
if i == len(session_results) - 1 and algorithm == AlgorithmType.DDPG:
|
||||
skip_ddpg = True
|
||||
|
@ -525,8 +526,9 @@ def train_ddpg(train_ddpg_input):
|
|||
|
||||
results_cnt = len(session_results)
|
||||
first_valid_result = -1
|
||||
ignore = ['range_test']
|
||||
for i, result in enumerate(session_results):
|
||||
if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name:
|
||||
if any(symbol in result.metric_data.name for symbol in ignore):
|
||||
results_cnt -= 1
|
||||
else:
|
||||
last_valid_result = i
|
||||
|
@ -545,12 +547,26 @@ def train_ddpg(train_ddpg_input):
|
|||
base_result = Result.objects.filter(pk=base_result_id)
|
||||
prev_result = Result.objects.filter(pk=prev_result_id)
|
||||
|
||||
agg_data = DataUtil.aggregate_data(result)
|
||||
base_metric_data = (DataUtil.aggregate_data(base_result))['y_matrix'].flatten()
|
||||
prev_metric_data = (DataUtil.aggregate_data(prev_result))['y_matrix'].flatten()
|
||||
agg_data = DataUtil.aggregate_data(result, ignore)
|
||||
prev_agg_data = DataUtil.aggregate_data(prev_result, ignore)
|
||||
metric_data = agg_data['y_matrix'].flatten()
|
||||
prev_metric_data = prev_agg_data['y_matrix'].flatten()
|
||||
base_metric_data = (DataUtil.aggregate_data(base_result, ignore))['y_matrix'].flatten()
|
||||
|
||||
target_objective = session.target_objective
|
||||
prev_obj_idx = [i for i, n in enumerate(agg_data['y_columnlabels']) if n == target_objective]
|
||||
# Filter ys by current target objective metric
|
||||
target_obj_idx = [i for i, n in enumerate(agg_data['y_columnlabels']) if n == target_objective]
|
||||
if len(target_obj_idx) == 0:
|
||||
raise Exception(('[{}] Could not find target objective in metrics '
|
||||
'(target_obj={})').format(task_name, target_objective))
|
||||
if len(target_obj_idx) > 1:
|
||||
raise Exception(('[{}] Found {} instances of target objective in '
|
||||
'metrics (target_obj={})').format(task_name,
|
||||
len(target_obj_idx),
|
||||
target_objective))
|
||||
objective = metric_data[target_obj_idx]
|
||||
base_objective = base_metric_data[target_obj_idx]
|
||||
prev_objective = prev_metric_data[target_obj_idx]
|
||||
|
||||
# Clean metric data
|
||||
metric_data, metric_labels = clean_metric_data(agg_data['y_matrix'],
|
||||
|
@ -558,6 +574,11 @@ def train_ddpg(train_ddpg_input):
|
|||
metric_data = metric_data.flatten()
|
||||
metric_scalar = MinMaxScaler().fit(metric_data.reshape(1, -1))
|
||||
normalized_metric_data = metric_scalar.transform(metric_data.reshape(1, -1))[0]
|
||||
prev_metric_data, _ = clean_metric_data(prev_agg_data['y_matrix'],
|
||||
prev_agg_data['y_columnlabels'], session)
|
||||
prev_metric_data = prev_metric_data.flatten()
|
||||
prev_metric_scalar = MinMaxScaler().fit(prev_metric_data.reshape(1, -1))
|
||||
prev_normalized_metric_data = prev_metric_scalar.transform(prev_metric_data.reshape(1, -1))[0]
|
||||
|
||||
# Clean knob data
|
||||
cleaned_knob_data = clean_knob_data(agg_data['X_matrix'], agg_data['X_columnlabels'], session)
|
||||
|
@ -569,25 +590,13 @@ def train_ddpg(train_ddpg_input):
|
|||
metric_num = len(metric_data)
|
||||
LOG.debug('%s: knob_num: %d, metric_num: %d', task_name, knob_num, metric_num)
|
||||
|
||||
# Filter ys by current target objective metric
|
||||
target_obj_idx = [i for i, n in enumerate(metric_labels) if n == target_objective]
|
||||
if len(target_obj_idx) == 0:
|
||||
raise Exception(('[{}] Could not find target objective in metrics '
|
||||
'(target_obj={})').format(task_name, target_objective))
|
||||
elif len(target_obj_idx) > 1:
|
||||
raise Exception(('[{}] Found {} instances of target objective in '
|
||||
'metrics (target_obj={})').format(task_name,
|
||||
len(target_obj_idx),
|
||||
target_objective))
|
||||
objective = metric_data[target_obj_idx]
|
||||
base_objective = base_metric_data[prev_obj_idx]
|
||||
prev_objective = prev_metric_data[prev_obj_idx]
|
||||
metric_meta = db.target_objectives.get_metric_metadata(session.dbms.pk,
|
||||
session.target_objective)
|
||||
|
||||
# Calculate the reward
|
||||
if params['DDPG_SIMPLE_REWARD']:
|
||||
objective = objective / base_objective
|
||||
prev_normalized_metric_data = normalized_metric_data
|
||||
if metric_meta[target_objective].improvement == '(less is better)':
|
||||
reward = -objective
|
||||
else:
|
||||
|
@ -618,7 +627,7 @@ def train_ddpg(train_ddpg_input):
|
|||
ddpg.set_model(session.ddpg_actor_model, session.ddpg_critic_model)
|
||||
if session.ddpg_reply_memory:
|
||||
ddpg.replay_memory.set(session.ddpg_reply_memory)
|
||||
ddpg.add_sample(normalized_metric_data, knob_data, reward, normalized_metric_data)
|
||||
ddpg.add_sample(prev_normalized_metric_data, knob_data, reward, normalized_metric_data)
|
||||
for _ in range(params['DDPG_UPDATE_EPOCHS']):
|
||||
ddpg.update()
|
||||
session.ddpg_actor_model, session.ddpg_critic_model = ddpg.get_model()
|
||||
|
|
Loading…
Reference in New Issue