fix original ddpg
This commit is contained in:
parent
b43a911bd7
commit
fd6dc1ea54
|
@ -340,8 +340,9 @@ def preprocessing(result_id, algorithm):
|
||||||
session_results = Result.objects.filter(session=session)
|
session_results = Result.objects.filter(session=session)
|
||||||
results_cnt = len(session_results)
|
results_cnt = len(session_results)
|
||||||
skip_ddpg = False
|
skip_ddpg = False
|
||||||
|
ignore = ['range_test']
|
||||||
for i, result in enumerate(session_results):
|
for i, result in enumerate(session_results):
|
||||||
if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name:
|
if any(symbol in result.metric_data.name for symbol in ignore):
|
||||||
results_cnt -= 1
|
results_cnt -= 1
|
||||||
if i == len(session_results) - 1 and algorithm == AlgorithmType.DDPG:
|
if i == len(session_results) - 1 and algorithm == AlgorithmType.DDPG:
|
||||||
skip_ddpg = True
|
skip_ddpg = True
|
||||||
|
@ -525,8 +526,9 @@ def train_ddpg(train_ddpg_input):
|
||||||
|
|
||||||
results_cnt = len(session_results)
|
results_cnt = len(session_results)
|
||||||
first_valid_result = -1
|
first_valid_result = -1
|
||||||
|
ignore = ['range_test']
|
||||||
for i, result in enumerate(session_results):
|
for i, result in enumerate(session_results):
|
||||||
if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name:
|
if any(symbol in result.metric_data.name for symbol in ignore):
|
||||||
results_cnt -= 1
|
results_cnt -= 1
|
||||||
else:
|
else:
|
||||||
last_valid_result = i
|
last_valid_result = i
|
||||||
|
@ -545,12 +547,26 @@ def train_ddpg(train_ddpg_input):
|
||||||
base_result = Result.objects.filter(pk=base_result_id)
|
base_result = Result.objects.filter(pk=base_result_id)
|
||||||
prev_result = Result.objects.filter(pk=prev_result_id)
|
prev_result = Result.objects.filter(pk=prev_result_id)
|
||||||
|
|
||||||
agg_data = DataUtil.aggregate_data(result)
|
agg_data = DataUtil.aggregate_data(result, ignore)
|
||||||
base_metric_data = (DataUtil.aggregate_data(base_result))['y_matrix'].flatten()
|
prev_agg_data = DataUtil.aggregate_data(prev_result, ignore)
|
||||||
prev_metric_data = (DataUtil.aggregate_data(prev_result))['y_matrix'].flatten()
|
metric_data = agg_data['y_matrix'].flatten()
|
||||||
|
prev_metric_data = prev_agg_data['y_matrix'].flatten()
|
||||||
|
base_metric_data = (DataUtil.aggregate_data(base_result, ignore))['y_matrix'].flatten()
|
||||||
|
|
||||||
target_objective = session.target_objective
|
target_objective = session.target_objective
|
||||||
prev_obj_idx = [i for i, n in enumerate(agg_data['y_columnlabels']) if n == target_objective]
|
# Filter ys by current target objective metric
|
||||||
|
target_obj_idx = [i for i, n in enumerate(agg_data['y_columnlabels']) if n == target_objective]
|
||||||
|
if len(target_obj_idx) == 0:
|
||||||
|
raise Exception(('[{}] Could not find target objective in metrics '
|
||||||
|
'(target_obj={})').format(task_name, target_objective))
|
||||||
|
if len(target_obj_idx) > 1:
|
||||||
|
raise Exception(('[{}] Found {} instances of target objective in '
|
||||||
|
'metrics (target_obj={})').format(task_name,
|
||||||
|
len(target_obj_idx),
|
||||||
|
target_objective))
|
||||||
|
objective = metric_data[target_obj_idx]
|
||||||
|
base_objective = base_metric_data[target_obj_idx]
|
||||||
|
prev_objective = prev_metric_data[target_obj_idx]
|
||||||
|
|
||||||
# Clean metric data
|
# Clean metric data
|
||||||
metric_data, metric_labels = clean_metric_data(agg_data['y_matrix'],
|
metric_data, metric_labels = clean_metric_data(agg_data['y_matrix'],
|
||||||
|
@ -558,6 +574,11 @@ def train_ddpg(train_ddpg_input):
|
||||||
metric_data = metric_data.flatten()
|
metric_data = metric_data.flatten()
|
||||||
metric_scalar = MinMaxScaler().fit(metric_data.reshape(1, -1))
|
metric_scalar = MinMaxScaler().fit(metric_data.reshape(1, -1))
|
||||||
normalized_metric_data = metric_scalar.transform(metric_data.reshape(1, -1))[0]
|
normalized_metric_data = metric_scalar.transform(metric_data.reshape(1, -1))[0]
|
||||||
|
prev_metric_data, _ = clean_metric_data(prev_agg_data['y_matrix'],
|
||||||
|
prev_agg_data['y_columnlabels'], session)
|
||||||
|
prev_metric_data = prev_metric_data.flatten()
|
||||||
|
prev_metric_scalar = MinMaxScaler().fit(prev_metric_data.reshape(1, -1))
|
||||||
|
prev_normalized_metric_data = prev_metric_scalar.transform(prev_metric_data.reshape(1, -1))[0]
|
||||||
|
|
||||||
# Clean knob data
|
# Clean knob data
|
||||||
cleaned_knob_data = clean_knob_data(agg_data['X_matrix'], agg_data['X_columnlabels'], session)
|
cleaned_knob_data = clean_knob_data(agg_data['X_matrix'], agg_data['X_columnlabels'], session)
|
||||||
|
@ -569,25 +590,13 @@ def train_ddpg(train_ddpg_input):
|
||||||
metric_num = len(metric_data)
|
metric_num = len(metric_data)
|
||||||
LOG.debug('%s: knob_num: %d, metric_num: %d', task_name, knob_num, metric_num)
|
LOG.debug('%s: knob_num: %d, metric_num: %d', task_name, knob_num, metric_num)
|
||||||
|
|
||||||
# Filter ys by current target objective metric
|
|
||||||
target_obj_idx = [i for i, n in enumerate(metric_labels) if n == target_objective]
|
|
||||||
if len(target_obj_idx) == 0:
|
|
||||||
raise Exception(('[{}] Could not find target objective in metrics '
|
|
||||||
'(target_obj={})').format(task_name, target_objective))
|
|
||||||
elif len(target_obj_idx) > 1:
|
|
||||||
raise Exception(('[{}] Found {} instances of target objective in '
|
|
||||||
'metrics (target_obj={})').format(task_name,
|
|
||||||
len(target_obj_idx),
|
|
||||||
target_objective))
|
|
||||||
objective = metric_data[target_obj_idx]
|
|
||||||
base_objective = base_metric_data[prev_obj_idx]
|
|
||||||
prev_objective = prev_metric_data[prev_obj_idx]
|
|
||||||
metric_meta = db.target_objectives.get_metric_metadata(session.dbms.pk,
|
metric_meta = db.target_objectives.get_metric_metadata(session.dbms.pk,
|
||||||
session.target_objective)
|
session.target_objective)
|
||||||
|
|
||||||
# Calculate the reward
|
# Calculate the reward
|
||||||
if params['DDPG_SIMPLE_REWARD']:
|
if params['DDPG_SIMPLE_REWARD']:
|
||||||
objective = objective / base_objective
|
objective = objective / base_objective
|
||||||
|
prev_normalized_metric_data = normalized_metric_data
|
||||||
if metric_meta[target_objective].improvement == '(less is better)':
|
if metric_meta[target_objective].improvement == '(less is better)':
|
||||||
reward = -objective
|
reward = -objective
|
||||||
else:
|
else:
|
||||||
|
@ -618,7 +627,7 @@ def train_ddpg(train_ddpg_input):
|
||||||
ddpg.set_model(session.ddpg_actor_model, session.ddpg_critic_model)
|
ddpg.set_model(session.ddpg_actor_model, session.ddpg_critic_model)
|
||||||
if session.ddpg_reply_memory:
|
if session.ddpg_reply_memory:
|
||||||
ddpg.replay_memory.set(session.ddpg_reply_memory)
|
ddpg.replay_memory.set(session.ddpg_reply_memory)
|
||||||
ddpg.add_sample(normalized_metric_data, knob_data, reward, normalized_metric_data)
|
ddpg.add_sample(prev_normalized_metric_data, knob_data, reward, normalized_metric_data)
|
||||||
for _ in range(params['DDPG_UPDATE_EPOCHS']):
|
for _ in range(params['DDPG_UPDATE_EPOCHS']):
|
||||||
ddpg.update()
|
ddpg.update()
|
||||||
session.ddpg_actor_model, session.ddpg_critic_model = ddpg.get_model()
|
session.ddpg_actor_model, session.ddpg_critic_model = ddpg.get_model()
|
||||||
|
|
Loading…
Reference in New Issue