improve ddpg
This commit is contained in:
@@ -77,7 +77,16 @@ DNN_DEBUG_INTERVAL = 100
|
||||
DDPG_BATCH_SIZE = 32
|
||||
|
||||
# Learning rate of actor network
|
||||
ACTOR_LEARNING_RATE = 0.01
|
||||
ACTOR_LEARNING_RATE = 0.02
|
||||
|
||||
# Learning rate of critic network
|
||||
CRITIC_LEARNING_RATE = 0.001
|
||||
|
||||
# Number of update epochs per iteration
|
||||
UPDATE_EPOCHS = 30
|
||||
|
||||
# The number of hidden units in each layer of the actor MLP
|
||||
ACTOR_HIDDEN_SIZES = [128, 128, 64]
|
||||
|
||||
# The number of hidden units in each layer of the critic MLP
|
||||
CRITIC_HIDDEN_SIZES = [64, 128, 64]
|
||||
|
||||
@@ -31,7 +31,8 @@ from website.settings import (DEFAULT_LENGTH_SCALE, DEFAULT_MAGNITUDE,
|
||||
DEFAULT_EPSILON, MAX_ITER, GPR_EPS,
|
||||
DEFAULT_SIGMA_MULTIPLIER, DEFAULT_MU_MULTIPLIER,
|
||||
DDPG_BATCH_SIZE, ACTOR_LEARNING_RATE,
|
||||
CRITIC_LEARNING_RATE,
|
||||
CRITIC_LEARNING_RATE, UPDATE_EPOCHS,
|
||||
ACTOR_HIDDEN_SIZES, CRITIC_HIDDEN_SIZES,
|
||||
DNN_TRAIN_ITER, DNN_EXPLORE, DNN_EXPLORE_ITER,
|
||||
DNN_NOISE_SCALE_BEGIN, DNN_NOISE_SCALE_END,
|
||||
DNN_DEBUG, DNN_DEBUG_INTERVAL)
|
||||
@@ -278,12 +279,9 @@ def train_ddpg(result_id):
|
||||
result = Result.objects.get(pk=result_id)
|
||||
session = Result.objects.get(pk=result_id).session
|
||||
session_results = Result.objects.filter(session=session,
|
||||
creation_time__lt=result.creation_time)
|
||||
creation_time__lte=result.creation_time)
|
||||
result_info = {}
|
||||
result_info['newest_result_id'] = result_id
|
||||
if len(session_results) == 0:
|
||||
LOG.info('No previous result. Abort.')
|
||||
return result_info
|
||||
|
||||
# Extract data from result
|
||||
result = Result.objects.filter(pk=result_id)
|
||||
@@ -332,13 +330,14 @@ def train_ddpg(result_id):
|
||||
|
||||
# Update ddpg
|
||||
ddpg = DDPG(n_actions=knob_num, n_states=metric_num, alr=ACTOR_LEARNING_RATE,
|
||||
clr=CRITIC_LEARNING_RATE, gamma=0, batch_size=DDPG_BATCH_SIZE)
|
||||
clr=CRITIC_LEARNING_RATE, gamma=0, batch_size=DDPG_BATCH_SIZE,
|
||||
a_hidden_sizes=ACTOR_HIDDEN_SIZES, c_hidden_sizes=CRITIC_HIDDEN_SIZES)
|
||||
if session.ddpg_actor_model and session.ddpg_critic_model:
|
||||
ddpg.set_model(session.ddpg_actor_model, session.ddpg_critic_model)
|
||||
if session.ddpg_reply_memory:
|
||||
ddpg.replay_memory.set(session.ddpg_reply_memory)
|
||||
ddpg.add_sample(normalized_metric_data, knob_data, reward, normalized_metric_data)
|
||||
for _ in range(25):
|
||||
for _ in range(UPDATE_EPOCHS):
|
||||
ddpg.update()
|
||||
session.ddpg_actor_model, session.ddpg_critic_model = ddpg.get_model()
|
||||
session.ddpg_reply_memory = ddpg.replay_memory.get()
|
||||
@@ -362,7 +361,8 @@ def configuration_recommendation_ddpg(result_info): # pylint: disable=invalid-n
|
||||
knob_num = len(knob_labels)
|
||||
metric_num = len(metric_data)
|
||||
|
||||
ddpg = DDPG(n_actions=knob_num, n_states=metric_num)
|
||||
ddpg = DDPG(n_actions=knob_num, n_states=metric_num, a_hidden_sizes=ACTOR_HIDDEN_SIZES,
|
||||
c_hidden_sizes=CRITIC_HIDDEN_SIZES)
|
||||
if session.ddpg_actor_model is not None and session.ddpg_critic_model is not None:
|
||||
ddpg.set_model(session.ddpg_actor_model, session.ddpg_critic_model)
|
||||
if session.ddpg_reply_memory is not None:
|
||||
|
||||
Reference in New Issue
Block a user