improve ddpg
This commit is contained in:
parent
67a4a70c09
commit
21fce27291
|
@ -23,21 +23,21 @@ LOG = get_analysis_logger(__name__)
|
|||
|
||||
class Actor(nn.Module):
|
||||
|
||||
def __init__(self, n_states, n_actions):
|
||||
def __init__(self, n_states, n_actions, hidden_sizes):
|
||||
super(Actor, self).__init__()
|
||||
self.layers = nn.Sequential(
|
||||
nn.Linear(n_states, 128),
|
||||
nn.Linear(n_states, hidden_sizes[0]),
|
||||
nn.LeakyReLU(negative_slope=0.2),
|
||||
nn.BatchNorm1d(128),
|
||||
nn.Linear(128, 128),
|
||||
nn.BatchNorm1d(hidden_sizes[0]),
|
||||
nn.Linear(hidden_sizes[0], hidden_sizes[1]),
|
||||
nn.Tanh(),
|
||||
nn.Dropout(0.3),
|
||||
nn.BatchNorm1d(128),
|
||||
|
||||
nn.Linear(128, 64),
|
||||
nn.BatchNorm1d(hidden_sizes[1]),
|
||||
nn.Linear(hidden_sizes[1], hidden_sizes[2]),
|
||||
nn.Tanh(),
|
||||
nn.BatchNorm1d(64),
|
||||
nn.Linear(64, n_actions)
|
||||
nn.Dropout(0.3),
|
||||
nn.BatchNorm1d(hidden_sizes[2]),
|
||||
nn.Linear(hidden_sizes[2], n_actions)
|
||||
)
|
||||
# This act layer maps the output to (0, 1)
|
||||
self.act = nn.Sigmoid()
|
||||
|
@ -58,21 +58,21 @@ class Actor(nn.Module):
|
|||
|
||||
class Critic(nn.Module):
|
||||
|
||||
def __init__(self, n_states, n_actions):
|
||||
def __init__(self, n_states, n_actions, hidden_sizes):
|
||||
super(Critic, self).__init__()
|
||||
self.state_input = nn.Linear(n_states, 128)
|
||||
self.action_input = nn.Linear(n_actions, 128)
|
||||
self.state_input = nn.Linear(n_states, hidden_sizes[0])
|
||||
self.action_input = nn.Linear(n_actions, hidden_sizes[0])
|
||||
self.act = nn.Tanh()
|
||||
self.layers = nn.Sequential(
|
||||
nn.Linear(256, 256),
|
||||
nn.Linear(hidden_sizes[0] * 2, hidden_sizes[1]),
|
||||
nn.LeakyReLU(negative_slope=0.2),
|
||||
nn.BatchNorm1d(256),
|
||||
|
||||
nn.Linear(256, 64),
|
||||
nn.Dropout(0.3),
|
||||
nn.BatchNorm1d(hidden_sizes[1]),
|
||||
nn.Linear(hidden_sizes[1], hidden_sizes[2]),
|
||||
nn.Tanh(),
|
||||
nn.Dropout(0.3),
|
||||
nn.BatchNorm1d(64),
|
||||
nn.Linear(64, 1),
|
||||
nn.BatchNorm1d(hidden_sizes[2]),
|
||||
nn.Linear(hidden_sizes[2], 1),
|
||||
)
|
||||
self._init_weights()
|
||||
|
||||
|
@ -100,7 +100,8 @@ class Critic(nn.Module):
|
|||
class DDPG(object):
|
||||
|
||||
def __init__(self, n_states, n_actions, model_name='', alr=0.001, clr=0.001,
|
||||
gamma=0.9, batch_size=32, tau=0.002, shift=0, memory_size=100000):
|
||||
gamma=0.9, batch_size=32, tau=0.002, shift=0, memory_size=100000,
|
||||
a_hidden_sizes=[128, 128, 64], c_hidden_sizes=[128, 256, 64]):
|
||||
self.n_states = n_states
|
||||
self.n_actions = n_actions
|
||||
self.alr = alr
|
||||
|
@ -109,6 +110,8 @@ class DDPG(object):
|
|||
self.batch_size = batch_size
|
||||
self.gamma = gamma
|
||||
self.tau = tau
|
||||
self.a_hidden_sizes = a_hidden_sizes
|
||||
self.c_hidden_sizes = c_hidden_sizes
|
||||
self.shift = shift
|
||||
|
||||
self._build_network()
|
||||
|
@ -121,10 +124,10 @@ class DDPG(object):
|
|||
return Variable(torch.FloatTensor(x))
|
||||
|
||||
def _build_network(self):
|
||||
self.actor = Actor(self.n_states, self.n_actions)
|
||||
self.target_actor = Actor(self.n_states, self.n_actions)
|
||||
self.critic = Critic(self.n_states, self.n_actions)
|
||||
self.target_critic = Critic(self.n_states, self.n_actions)
|
||||
self.actor = Actor(self.n_states, self.n_actions, self.a_hidden_sizes)
|
||||
self.target_actor = Actor(self.n_states, self.n_actions, self.a_hidden_sizes)
|
||||
self.critic = Critic(self.n_states, self.n_actions, self.c_hidden_sizes)
|
||||
self.target_critic = Critic(self.n_states, self.n_actions, self.c_hidden_sizes)
|
||||
|
||||
# Copy actor's parameters
|
||||
self._update_target(self.target_actor, self.actor, tau=1.0)
|
||||
|
|
|
@ -25,6 +25,7 @@ from analysis.gpr import gpr_models # noqa
|
|||
from analysis.gpr import ucb # noqa
|
||||
from analysis.gpr.optimize import tf_optimize # noqa
|
||||
|
||||
|
||||
LOG = get_analysis_logger(__name__)
|
||||
|
||||
|
||||
|
@ -106,8 +107,10 @@ def ddpg(env, config, n_loops=100):
|
|||
a_lr = config['a_lr']
|
||||
c_lr = config['c_lr']
|
||||
n_epochs = config['n_epochs']
|
||||
ahs = config['a_hidden_sizes']
|
||||
chs = config['c_hidden_sizes']
|
||||
model_ddpg = DDPG(n_actions=env.knob_dim, n_states=env.metric_dim, gamma=gamma,
|
||||
clr=c_lr, alr=a_lr, shift=0.1)
|
||||
clr=c_lr, alr=a_lr, shift=0, a_hidden_sizes=ahs, c_hidden_sizes=chs)
|
||||
knob_data = np.random.rand(env.knob_dim)
|
||||
prev_metric_data = np.zeros(env.metric_dim)
|
||||
|
||||
|
@ -122,7 +125,7 @@ def ddpg(env, config, n_loops=100):
|
|||
|
||||
for i in range(n_loops):
|
||||
reward, metric_data = env.simulate(knob_data)
|
||||
model_ddpg.add_sample(prev_metric_data, prev_knob_data, prev_reward, metric_data)
|
||||
model_ddpg.add_sample(prev_metric_data, prev_knob_data, prev_reward, prev_metric_data)
|
||||
prev_metric_data = metric_data
|
||||
prev_knob_data = knob_data
|
||||
prev_reward = reward
|
||||
|
@ -184,6 +187,7 @@ def dnn(env, config, n_loops=100):
|
|||
actions, rewards = memory.get_all()
|
||||
model_nn.fit(np.array(actions), -np.array(rewards), fit_epochs=50)
|
||||
res = model_nn.recommend(X_samples, Xmin, Xmax, recommend_epochs=10, explore=False)
|
||||
|
||||
best_config_idx = np.argmin(res.minl.ravel())
|
||||
best_config = res.minl_conf[best_config_idx, :]
|
||||
if ou_process:
|
||||
|
@ -313,7 +317,7 @@ def gpr_new(env, config, n_loops=100):
|
|||
model_kwargs['hyperparameters'] = None
|
||||
model_kwargs['optimize_hyperparameters'] = optimize_hyperparams
|
||||
|
||||
X_new, ypred, model_params, hyperparameters = run_optimize(np.array(actions),
|
||||
X_new, ypred, _, hyperparameters = run_optimize(np.array(actions),
|
||||
-np.array(rewards),
|
||||
X_samples,
|
||||
model_name,
|
||||
|
@ -342,8 +346,8 @@ def plotlines(xs, results, labels, title, path):
|
|||
N = 1
|
||||
weights = np.ones(N)
|
||||
for x_axis, result, label in zip(xs, results, labels):
|
||||
result = np.convolve(weights/weights.sum(), result.flatten())[N-1:-N+1]
|
||||
lines.append(plt.plot(x_axis[:-N+1], result, label=label, lw=4)[0])
|
||||
result = np.convolve(weights/weights.sum(), result.flatten())[N-1:-N]
|
||||
lines.append(plt.plot(x_axis[:-N], result, label=label, lw=4)[0])
|
||||
plt.legend(handles=lines, fontsize=30)
|
||||
plt.title(title, fontsize=25)
|
||||
plt.xticks(fontsize=25)
|
||||
|
@ -357,8 +361,8 @@ def plotlines(xs, results, labels, title, path):
|
|||
def run(tuners, configs, labels, title, env, n_loops, n_repeats):
|
||||
if not plt:
|
||||
LOG.info("Cannot import matplotlib. Will write results to files instead of figures.")
|
||||
random.seed(0)
|
||||
np.random.seed(1)
|
||||
random.seed(2)
|
||||
np.random.seed(2)
|
||||
torch.manual_seed(0)
|
||||
results = []
|
||||
xs = []
|
||||
|
@ -385,16 +389,17 @@ def run(tuners, configs, labels, title, env, n_loops, n_repeats):
|
|||
|
||||
|
||||
def main():
|
||||
env = Environment(knob_dim=24, metric_dim=60, modes=[2], reward_variance=0.05)
|
||||
title = 'compare'
|
||||
n_repeats = [1, 1, 1, 1]
|
||||
n_loops = 80
|
||||
configs = [{'gamma': 0., 'c_lr': 0.001, 'a_lr': 0.01, 'num_collections': 50, 'n_epochs': 50},
|
||||
{'num_samples': 30, 'num_collections': 50},
|
||||
{'num_samples': 30, 'num_collections': 50},
|
||||
{'num_samples': 30, 'num_collections': 50}]
|
||||
tuners = [ddpg, gpr_new, dnn, gpr]
|
||||
labels = [tuner.__name__ for tuner in tuners]
|
||||
env = Environment(knob_dim=8, metric_dim=60, modes=[2], reward_variance=0.15)
|
||||
title = 'ddpg_structure_nodrop'
|
||||
n_repeats = [2, 2]
|
||||
n_loops = 100
|
||||
configs = [{'gamma': 0., 'c_lr': 0.001, 'a_lr': 0.02, 'num_collections': 1, 'n_epochs': 30,
|
||||
'a_hidden_sizes': [128, 128, 64], 'c_hidden_sizes': [64, 128, 64]},
|
||||
{'gamma': 0., 'c_lr': 0.001, 'a_lr': 0.02, 'num_collections': 1, 'n_epochs': 30,
|
||||
'a_hidden_sizes': [64, 64, 32], 'c_hidden_sizes': [64, 128, 64]},
|
||||
]
|
||||
tuners = [ddpg, ddpg]
|
||||
labels = ['1', '2']
|
||||
run(tuners, configs, labels, title, env, n_loops, n_repeats)
|
||||
|
||||
|
||||
|
|
|
@ -25,15 +25,18 @@ class TestDDPG(unittest.TestCase):
|
|||
np.random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
super(TestDDPG, cls).setUpClass()
|
||||
cls.ddpg = DDPG(n_actions=1, n_states=1, gamma=0)
|
||||
for _ in range(700):
|
||||
knob_data = np.array([random.random()])
|
||||
prev_metric_data = np.array([random.random()])
|
||||
cls.ddpg = DDPG(n_actions=1, n_states=1, gamma=0, alr=0.02)
|
||||
knob_data = np.zeros(1)
|
||||
metric_data = np.array([random.random()])
|
||||
for _ in range(100):
|
||||
prev_metric_data = metric_data
|
||||
metric_data = np.array([random.random()])
|
||||
reward = 1.0 if (prev_metric_data[0] - 0.5) * (knob_data[0] - 0.5) > 0 else 0.0
|
||||
reward = np.array([reward])
|
||||
cls.ddpg.add_sample(prev_metric_data, knob_data, reward, metric_data)
|
||||
for _ in range(10):
|
||||
cls.ddpg.update()
|
||||
knob_data = cls.ddpg.choose_action(metric_data)
|
||||
|
||||
def test_ddpg_ypreds(self):
|
||||
total_reward = 0.0
|
||||
|
|
|
@ -77,7 +77,16 @@ DNN_DEBUG_INTERVAL = 100
|
|||
DDPG_BATCH_SIZE = 32
|
||||
|
||||
# Learning rate of actor network
|
||||
ACTOR_LEARNING_RATE = 0.01
|
||||
ACTOR_LEARNING_RATE = 0.02
|
||||
|
||||
# Learning rate of critic network
|
||||
CRITIC_LEARNING_RATE = 0.001
|
||||
|
||||
# Number of update epochs per iteration
|
||||
UPDATE_EPOCHS = 30
|
||||
|
||||
# The number of hidden units in each layer of the actor MLP
|
||||
ACTOR_HIDDEN_SIZES = [128, 128, 64]
|
||||
|
||||
# The number of hidden units in each layer of the critic MLP
|
||||
CRITIC_HIDDEN_SIZES = [64, 128, 64]
|
||||
|
|
|
@ -31,7 +31,8 @@ from website.settings import (DEFAULT_LENGTH_SCALE, DEFAULT_MAGNITUDE,
|
|||
DEFAULT_EPSILON, MAX_ITER, GPR_EPS,
|
||||
DEFAULT_SIGMA_MULTIPLIER, DEFAULT_MU_MULTIPLIER,
|
||||
DDPG_BATCH_SIZE, ACTOR_LEARNING_RATE,
|
||||
CRITIC_LEARNING_RATE,
|
||||
CRITIC_LEARNING_RATE, UPDATE_EPOCHS,
|
||||
ACTOR_HIDDEN_SIZES, CRITIC_HIDDEN_SIZES,
|
||||
DNN_TRAIN_ITER, DNN_EXPLORE, DNN_EXPLORE_ITER,
|
||||
DNN_NOISE_SCALE_BEGIN, DNN_NOISE_SCALE_END,
|
||||
DNN_DEBUG, DNN_DEBUG_INTERVAL)
|
||||
|
@ -278,12 +279,9 @@ def train_ddpg(result_id):
|
|||
result = Result.objects.get(pk=result_id)
|
||||
session = Result.objects.get(pk=result_id).session
|
||||
session_results = Result.objects.filter(session=session,
|
||||
creation_time__lt=result.creation_time)
|
||||
creation_time__lte=result.creation_time)
|
||||
result_info = {}
|
||||
result_info['newest_result_id'] = result_id
|
||||
if len(session_results) == 0:
|
||||
LOG.info('No previous result. Abort.')
|
||||
return result_info
|
||||
|
||||
# Extract data from result
|
||||
result = Result.objects.filter(pk=result_id)
|
||||
|
@ -332,13 +330,14 @@ def train_ddpg(result_id):
|
|||
|
||||
# Update ddpg
|
||||
ddpg = DDPG(n_actions=knob_num, n_states=metric_num, alr=ACTOR_LEARNING_RATE,
|
||||
clr=CRITIC_LEARNING_RATE, gamma=0, batch_size=DDPG_BATCH_SIZE)
|
||||
clr=CRITIC_LEARNING_RATE, gamma=0, batch_size=DDPG_BATCH_SIZE,
|
||||
a_hidden_sizes=ACTOR_HIDDEN_SIZES, c_hidden_sizes=CRITIC_HIDDEN_SIZES)
|
||||
if session.ddpg_actor_model and session.ddpg_critic_model:
|
||||
ddpg.set_model(session.ddpg_actor_model, session.ddpg_critic_model)
|
||||
if session.ddpg_reply_memory:
|
||||
ddpg.replay_memory.set(session.ddpg_reply_memory)
|
||||
ddpg.add_sample(normalized_metric_data, knob_data, reward, normalized_metric_data)
|
||||
for _ in range(25):
|
||||
for _ in range(UPDATE_EPOCHS):
|
||||
ddpg.update()
|
||||
session.ddpg_actor_model, session.ddpg_critic_model = ddpg.get_model()
|
||||
session.ddpg_reply_memory = ddpg.replay_memory.get()
|
||||
|
@ -362,7 +361,8 @@ def configuration_recommendation_ddpg(result_info): # pylint: disable=invalid-n
|
|||
knob_num = len(knob_labels)
|
||||
metric_num = len(metric_data)
|
||||
|
||||
ddpg = DDPG(n_actions=knob_num, n_states=metric_num)
|
||||
ddpg = DDPG(n_actions=knob_num, n_states=metric_num, a_hidden_sizes=ACTOR_HIDDEN_SIZES,
|
||||
c_hidden_sizes=CRITIC_HIDDEN_SIZES)
|
||||
if session.ddpg_actor_model is not None and session.ddpg_critic_model is not None:
|
||||
ddpg.set_model(session.ddpg_actor_model, session.ddpg_critic_model)
|
||||
if session.ddpg_reply_memory is not None:
|
||||
|
|
Loading…
Reference in New Issue