simplify ddpg
This commit is contained in:
@@ -325,21 +325,20 @@ def train_ddpg(result_id):
|
||||
# Calculate the reward
|
||||
objective = objective / base_objective
|
||||
if metric_meta[target_objective].improvement == '(less is better)':
|
||||
reward = -objective * objective
|
||||
reward = -objective
|
||||
else:
|
||||
reward = objective * objective
|
||||
reward = objective
|
||||
LOG.info('reward: %f', reward)
|
||||
|
||||
# Update ddpg
|
||||
ddpg = DDPG(n_actions=knob_num, n_states=metric_num, alr=ACTOR_LEARNING_RATE,
|
||||
clr=CRITIC_LEARNING_RATE, gamma=0.0, batch_size=DDPG_BATCH_SIZE, tau=0.0)
|
||||
clr=CRITIC_LEARNING_RATE, gamma=0, batch_size=DDPG_BATCH_SIZE)
|
||||
if session.ddpg_actor_model and session.ddpg_critic_model:
|
||||
ddpg.set_model(session.ddpg_actor_model, session.ddpg_critic_model)
|
||||
if session.ddpg_reply_memory:
|
||||
ddpg.replay_memory.set(session.ddpg_reply_memory)
|
||||
ddpg.add_sample(normalized_metric_data, knob_data, reward, normalized_metric_data, False)
|
||||
if len(ddpg.replay_memory) > 32:
|
||||
ddpg.update()
|
||||
ddpg.add_sample(normalized_metric_data, knob_data, reward, normalized_metric_data)
|
||||
ddpg.update()
|
||||
session.ddpg_actor_model, session.ddpg_critic_model = ddpg.get_model()
|
||||
session.ddpg_reply_memory = ddpg.replay_memory.get()
|
||||
session.save()
|
||||
@@ -362,8 +361,7 @@ def configuration_recommendation_ddpg(result_info): # pylint: disable=invalid-n
|
||||
knob_num = len(knob_labels)
|
||||
metric_num = len(metric_data)
|
||||
|
||||
ddpg = DDPG(n_actions=knob_num, n_states=metric_num, alr=ACTOR_LEARNING_RATE,
|
||||
clr=CRITIC_LEARNING_RATE, gamma=0.0, batch_size=DDPG_BATCH_SIZE, tau=0.0)
|
||||
ddpg = DDPG(n_actions=knob_num, n_states=metric_num)
|
||||
if session.ddpg_actor_model is not None and session.ddpg_critic_model is not None:
|
||||
ddpg.set_model(session.ddpg_actor_model, session.ddpg_critic_model)
|
||||
if session.ddpg_reply_memory is not None:
|
||||
|
||||
Reference in New Issue
Block a user