fix original ddpg

This commit is contained in:
yangdsh 2020-04-14 00:05:46 +00:00 committed by Dana Van Aken
parent b43a911bd7
commit fd6dc1ea54
1 changed files with 29 additions and 20 deletions

View File

@ -340,8 +340,9 @@ def preprocessing(result_id, algorithm):
session_results = Result.objects.filter(session=session) session_results = Result.objects.filter(session=session)
results_cnt = len(session_results) results_cnt = len(session_results)
skip_ddpg = False skip_ddpg = False
ignore = ['range_test']
for i, result in enumerate(session_results): for i, result in enumerate(session_results):
if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name: if any(symbol in result.metric_data.name for symbol in ignore):
results_cnt -= 1 results_cnt -= 1
if i == len(session_results) - 1 and algorithm == AlgorithmType.DDPG: if i == len(session_results) - 1 and algorithm == AlgorithmType.DDPG:
skip_ddpg = True skip_ddpg = True
@ -525,8 +526,9 @@ def train_ddpg(train_ddpg_input):
results_cnt = len(session_results) results_cnt = len(session_results)
first_valid_result = -1 first_valid_result = -1
ignore = ['range_test']
for i, result in enumerate(session_results): for i, result in enumerate(session_results):
if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name: if any(symbol in result.metric_data.name for symbol in ignore):
results_cnt -= 1 results_cnt -= 1
else: else:
last_valid_result = i last_valid_result = i
@ -545,12 +547,26 @@ def train_ddpg(train_ddpg_input):
base_result = Result.objects.filter(pk=base_result_id) base_result = Result.objects.filter(pk=base_result_id)
prev_result = Result.objects.filter(pk=prev_result_id) prev_result = Result.objects.filter(pk=prev_result_id)
agg_data = DataUtil.aggregate_data(result) agg_data = DataUtil.aggregate_data(result, ignore)
base_metric_data = (DataUtil.aggregate_data(base_result))['y_matrix'].flatten() prev_agg_data = DataUtil.aggregate_data(prev_result, ignore)
prev_metric_data = (DataUtil.aggregate_data(prev_result))['y_matrix'].flatten() metric_data = agg_data['y_matrix'].flatten()
prev_metric_data = prev_agg_data['y_matrix'].flatten()
base_metric_data = (DataUtil.aggregate_data(base_result, ignore))['y_matrix'].flatten()
target_objective = session.target_objective target_objective = session.target_objective
prev_obj_idx = [i for i, n in enumerate(agg_data['y_columnlabels']) if n == target_objective] # Filter ys by current target objective metric
target_obj_idx = [i for i, n in enumerate(agg_data['y_columnlabels']) if n == target_objective]
if len(target_obj_idx) == 0:
raise Exception(('[{}] Could not find target objective in metrics '
'(target_obj={})').format(task_name, target_objective))
if len(target_obj_idx) > 1:
raise Exception(('[{}] Found {} instances of target objective in '
'metrics (target_obj={})').format(task_name,
len(target_obj_idx),
target_objective))
objective = metric_data[target_obj_idx]
base_objective = base_metric_data[target_obj_idx]
prev_objective = prev_metric_data[target_obj_idx]
# Clean metric data # Clean metric data
metric_data, metric_labels = clean_metric_data(agg_data['y_matrix'], metric_data, metric_labels = clean_metric_data(agg_data['y_matrix'],
@ -558,6 +574,11 @@ def train_ddpg(train_ddpg_input):
metric_data = metric_data.flatten() metric_data = metric_data.flatten()
metric_scalar = MinMaxScaler().fit(metric_data.reshape(1, -1)) metric_scalar = MinMaxScaler().fit(metric_data.reshape(1, -1))
normalized_metric_data = metric_scalar.transform(metric_data.reshape(1, -1))[0] normalized_metric_data = metric_scalar.transform(metric_data.reshape(1, -1))[0]
prev_metric_data, _ = clean_metric_data(prev_agg_data['y_matrix'],
prev_agg_data['y_columnlabels'], session)
prev_metric_data = prev_metric_data.flatten()
prev_metric_scalar = MinMaxScaler().fit(prev_metric_data.reshape(1, -1))
prev_normalized_metric_data = prev_metric_scalar.transform(prev_metric_data.reshape(1, -1))[0]
# Clean knob data # Clean knob data
cleaned_knob_data = clean_knob_data(agg_data['X_matrix'], agg_data['X_columnlabels'], session) cleaned_knob_data = clean_knob_data(agg_data['X_matrix'], agg_data['X_columnlabels'], session)
@ -569,25 +590,13 @@ def train_ddpg(train_ddpg_input):
metric_num = len(metric_data) metric_num = len(metric_data)
LOG.debug('%s: knob_num: %d, metric_num: %d', task_name, knob_num, metric_num) LOG.debug('%s: knob_num: %d, metric_num: %d', task_name, knob_num, metric_num)
# Filter ys by current target objective metric
target_obj_idx = [i for i, n in enumerate(metric_labels) if n == target_objective]
if len(target_obj_idx) == 0:
raise Exception(('[{}] Could not find target objective in metrics '
'(target_obj={})').format(task_name, target_objective))
elif len(target_obj_idx) > 1:
raise Exception(('[{}] Found {} instances of target objective in '
'metrics (target_obj={})').format(task_name,
len(target_obj_idx),
target_objective))
objective = metric_data[target_obj_idx]
base_objective = base_metric_data[prev_obj_idx]
prev_objective = prev_metric_data[prev_obj_idx]
metric_meta = db.target_objectives.get_metric_metadata(session.dbms.pk, metric_meta = db.target_objectives.get_metric_metadata(session.dbms.pk,
session.target_objective) session.target_objective)
# Calculate the reward # Calculate the reward
if params['DDPG_SIMPLE_REWARD']: if params['DDPG_SIMPLE_REWARD']:
objective = objective / base_objective objective = objective / base_objective
prev_normalized_metric_data = normalized_metric_data
if metric_meta[target_objective].improvement == '(less is better)': if metric_meta[target_objective].improvement == '(less is better)':
reward = -objective reward = -objective
else: else:
@ -618,7 +627,7 @@ def train_ddpg(train_ddpg_input):
ddpg.set_model(session.ddpg_actor_model, session.ddpg_critic_model) ddpg.set_model(session.ddpg_actor_model, session.ddpg_critic_model)
if session.ddpg_reply_memory: if session.ddpg_reply_memory:
ddpg.replay_memory.set(session.ddpg_reply_memory) ddpg.replay_memory.set(session.ddpg_reply_memory)
ddpg.add_sample(normalized_metric_data, knob_data, reward, normalized_metric_data) ddpg.add_sample(prev_normalized_metric_data, knob_data, reward, normalized_metric_data)
for _ in range(params['DDPG_UPDATE_EPOCHS']): for _ in range(params['DDPG_UPDATE_EPOCHS']):
ddpg.update() ddpg.update()
session.ddpg_actor_model, session.ddpg_critic_model = ddpg.get_model() session.ddpg_actor_model, session.ddpg_critic_model = ddpg.get_model()