simplify ddpg

This commit is contained in:
yangdsh 2019-10-23 20:00:20 +00:00 committed by Dana Van Aken
parent 336221d886
commit 21f4f40b88
4 changed files with 46 additions and 244 deletions

View File

@ -8,14 +8,9 @@
# deep reinforcement learning." Proceedings of the 2019 International Conference # deep reinforcement learning." Proceedings of the 2019 International Conference
# on Management of Data. ACM, 2019 # on Management of Data. ACM, 2019
import os
import pickle import pickle
import math
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import init, Parameter
import torch.nn.functional as F
import torch.optim as optimizer import torch.optim as optimizer
from torch.autograd import Variable from torch.autograd import Variable
@ -26,68 +21,9 @@ from analysis.util import get_analysis_logger
LOG = get_analysis_logger(__name__) LOG = get_analysis_logger(__name__)
# code from https://github.com/Kaixhin/NoisyNet-A3C/blob/master/model.py
class NoisyLinear(nn.Linear):
def __init__(self, in_features, out_features, sigma_init=0.05, bias=True):
super(NoisyLinear, self).__init__(in_features, out_features, bias=True)
# reuse self.weight and self.bias
self.sigma_init = sigma_init
self.sigma_weight = Parameter(torch.Tensor(out_features, in_features))
self.sigma_bias = Parameter(torch.Tensor(out_features))
self.epsilon_weight = None
self.epsilon_bias = None
self.register_buffer('epsilon_weight', torch.zeros(out_features, in_features))
self.register_buffer('epsilon_bias', torch.zeros(out_features))
self.reset_parameters()
def reset_parameters(self):
# Only init after all params added (otherwise super().__init__() fails)
if hasattr(self, 'sigma_weight'):
init.uniform(self.weight, -math.sqrt(3 / self.in_features),
math.sqrt(3 / self.in_features))
init.uniform(self.bias, -math.sqrt(3 / self.in_features),
math.sqrt(3 / self.in_features))
init.constant(self.sigma_weight, self.sigma_init)
init.constant(self.sigma_bias, self.sigma_init)
def forward(self, x):
return F.linear(x, self.weight + self.sigma_weight * Variable(self.epsilon_weight),
self.bias + self.sigma_bias * Variable(self.epsilon_bias))
def sample_noise(self):
self.epsilon_weight = torch.randn(self.out_features, self.in_features)
self.epsilon_bias = torch.randn(self.out_features)
def remove_noise(self):
self.epsilon_weight = torch.zeros(self.out_features, self.in_features)
self.epsilon_bias = torch.zeros(self.out_features)
class Normalizer(object):
def __init__(self, mean, variance):
if isinstance(mean, list):
mean = np.array(mean)
if isinstance(variance, list):
variance = np.array(variance)
self.mean = mean
self.std = np.sqrt(variance + 0.00001)
def normalize(self, x):
if isinstance(x, list):
x = np.array(x)
x = x - self.mean
x = x / self.std
return Variable(torch.FloatTensor(x))
def __call__(self, x, *args, **kwargs):
return self.normalize(x)
class Actor(nn.Module): class Actor(nn.Module):
def __init__(self, n_states, n_actions, noisy=False): def __init__(self, n_states, n_actions):
super(Actor, self).__init__() super(Actor, self).__init__()
self.layers = nn.Sequential( self.layers = nn.Sequential(
nn.Linear(n_states, 128), nn.Linear(n_states, 128),
@ -100,13 +36,11 @@ class Actor(nn.Module):
nn.Linear(128, 64), nn.Linear(128, 64),
nn.Tanh(), nn.Tanh(),
nn.BatchNorm1d(64), nn.BatchNorm1d(64),
nn.Linear(64, n_actions)
) )
if noisy: # This act layer maps the output to (0, 1)
self.out = NoisyLinear(64, n_actions)
else:
self.out = nn.Linear(64, n_actions)
self._init_weights()
self.act = nn.Sigmoid() self.act = nn.Sigmoid()
self._init_weights()
def _init_weights(self): def _init_weights(self):
@ -115,13 +49,10 @@ class Actor(nn.Module):
m.weight.data.normal_(0.0, 1e-2) m.weight.data.normal_(0.0, 1e-2)
m.bias.data.uniform_(-0.1, 0.1) m.bias.data.uniform_(-0.1, 0.1)
def sample_noise(self): def forward(self, states): # pylint: disable=arguments-differ
self.out.sample_noise()
def forward(self, x): # pylint: disable=arguments-differ actions = self.act(self.layers(states))
return actions
out = self.act(self.out(self.layers(x)))
return out
class Critic(nn.Module): class Critic(nn.Module):
@ -156,11 +87,11 @@ class Critic(nn.Module):
m.weight.data.normal_(0.0, 1e-2) m.weight.data.normal_(0.0, 1e-2)
m.bias.data.uniform_(-0.1, 0.1) m.bias.data.uniform_(-0.1, 0.1)
def forward(self, x, action): # pylint: disable=arguments-differ def forward(self, states, actions): # pylint: disable=arguments-differ
x = self.act(self.state_input(x)) states = self.act(self.state_input(states))
action = self.act(self.action_input(action)) actions = self.act(self.action_input(actions))
_input = torch.cat([x, action], dim=1) _input = torch.cat([states, actions], dim=1)
value = self.layers(_input) value = self.layers(_input)
return value return value
@ -168,8 +99,7 @@ class Critic(nn.Module):
class DDPG(object): class DDPG(object):
def __init__(self, n_states, n_actions, model_name='', alr=0.001, clr=0.001, def __init__(self, n_states, n_actions, model_name='', alr=0.001, clr=0.001,
gamma=0.9, batch_size=32, tau=0.002, memory_size=100000, gamma=0.9, batch_size=32, tau=0.002, memory_size=100000):
ouprocess=True, mean_var_path=None, supervised=False):
self.n_states = n_states self.n_states = n_states
self.n_actions = n_actions self.n_actions = n_actions
self.alr = alr self.alr = alr
@ -178,25 +108,7 @@ class DDPG(object):
self.batch_size = batch_size self.batch_size = batch_size
self.gamma = gamma self.gamma = gamma
self.tau = tau self.tau = tau
self.ouprocess = ouprocess
if mean_var_path is None:
mean = np.zeros(n_states)
var = np.zeros(n_states)
elif not os.path.exists(mean_var_path):
mean = np.zeros(n_states)
var = np.zeros(n_states)
else:
with open(mean_var_path, 'rb') as f:
mean, var = pickle.load(f)
self.normalizer = Normalizer(mean, var)
if supervised:
self._build_actor()
LOG.info("Supervised Learning Initialized")
else:
# Build Network
self._build_network() self._build_network()
self.replay_memory = PrioritizedReplayMemory(capacity=memory_size) self.replay_memory = PrioritizedReplayMemory(capacity=memory_size)
@ -206,30 +118,12 @@ class DDPG(object):
def totensor(x): def totensor(x):
return Variable(torch.FloatTensor(x)) return Variable(torch.FloatTensor(x))
def _build_actor(self):
if self.ouprocess:
noisy = False
else:
noisy = True
self.actor = Actor(self.n_states, self.n_actions, noisy=noisy)
self.actor_criterion = nn.MSELoss()
self.actor_optimizer = optimizer.Adam(lr=self.alr, params=self.actor.parameters())
def _build_network(self): def _build_network(self):
if self.ouprocess: self.actor = Actor(self.n_states, self.n_actions)
noisy = False
else:
noisy = True
self.actor = Actor(self.n_states, self.n_actions, noisy=noisy)
self.target_actor = Actor(self.n_states, self.n_actions) self.target_actor = Actor(self.n_states, self.n_actions)
self.critic = Critic(self.n_states, self.n_actions) self.critic = Critic(self.n_states, self.n_actions)
self.target_critic = Critic(self.n_states, self.n_actions) self.target_critic = Critic(self.n_states, self.n_actions)
# if model params are provided, load them
if len(self.model_name):
self.load_model(model_name=self.model_name)
LOG.info("Loading model from file: %s", self.model_name)
# Copy actor's parameters # Copy actor's parameters
self._update_target(self.target_actor, self.actor, tau=1.0) self._update_target(self.target_actor, self.actor, tau=1.0)
@ -249,64 +143,57 @@ class DDPG(object):
target_param.data * (1 - tau) + param.data * tau target_param.data * (1 - tau) + param.data * tau
) )
def reset(self, sigma): def reset(self, sigma, theta):
self.noise.reset(sigma) self.noise.reset(sigma, theta)
def _sample_batch(self): def _sample_batch(self):
batch, idx = self.replay_memory.sample(self.batch_size) batch, idx = self.replay_memory.sample(self.batch_size)
# batch = self.replay_memory.sample(self.batch_size)
states = list(map(lambda x: x[0].tolist(), batch)) # pylint: disable=W0141 states = list(map(lambda x: x[0].tolist(), batch)) # pylint: disable=W0141
next_states = list(map(lambda x: x[3].tolist(), batch)) # pylint: disable=W0141
actions = list(map(lambda x: x[1].tolist(), batch)) # pylint: disable=W0141 actions = list(map(lambda x: x[1].tolist(), batch)) # pylint: disable=W0141
rewards = list(map(lambda x: x[2], batch)) # pylint: disable=W0141 rewards = list(map(lambda x: x[2], batch)) # pylint: disable=W0141
terminates = list(map(lambda x: x[4], batch)) # pylint: disable=W0141 next_states = list(map(lambda x: x[3].tolist(), batch)) # pylint: disable=W0141
return idx, states, next_states, actions, rewards, terminates return idx, states, next_states, actions, rewards
def add_sample(self, state, action, reward, next_state, terminate): def add_sample(self, state, action, reward, next_state):
self.critic.eval() self.critic.eval()
self.actor.eval() self.actor.eval()
self.target_critic.eval() self.target_critic.eval()
self.target_actor.eval() self.target_actor.eval()
batch_state = self.normalizer([state.tolist()]) batch_state = self.totensor([state.tolist()])
batch_next_state = self.normalizer([next_state.tolist()]) batch_next_state = self.totensor([next_state.tolist()])
current_value = self.critic(batch_state, self.totensor([action.tolist()])) current_value = self.critic(batch_state, self.totensor([action.tolist()]))
target_action = self.target_actor(batch_next_state) target_action = self.target_actor(batch_next_state)
target_value = self.totensor([reward]) \ target_value = self.totensor([reward]) \
+ self.totensor([0 if x else 1 for x in [terminate]]) \ + self.target_critic(batch_next_state, target_action) * self.gamma
* self.target_critic(batch_next_state, target_action) * self.gamma
error = float(torch.abs(current_value - target_value).data.numpy()[0]) error = float(torch.abs(current_value - target_value).data.numpy()[0])
self.target_actor.train() self.target_actor.train()
self.actor.train() self.actor.train()
self.critic.train() self.critic.train()
self.target_critic.train() self.target_critic.train()
self.replay_memory.add(error, (state, action, reward, next_state, terminate)) self.replay_memory.add(error, (state, action, reward, next_state))
def update(self): def update(self):
idxs, states, next_states, actions, rewards, terminates = self._sample_batch() idxs, states, next_states, actions, rewards = self._sample_batch()
batch_states = self.normalizer(states) batch_states = self.totensor(states)
batch_next_states = self.normalizer(next_states) batch_next_states = self.totensor(next_states)
batch_actions = self.totensor(actions) batch_actions = self.totensor(actions)
batch_rewards = self.totensor(rewards) batch_rewards = self.totensor(rewards)
mask = [0 if x else 1 for x in terminates]
mask = self.totensor(mask)
target_next_actions = self.target_actor(batch_next_states).detach() target_next_actions = self.target_actor(batch_next_states).detach()
target_next_value = self.target_critic(batch_next_states, target_next_actions).detach() target_next_value = self.target_critic(batch_next_states, target_next_actions).detach()
current_value = self.critic(batch_states, batch_actions) current_value = self.critic(batch_states, batch_actions)
# TODO (dongshen): This clause is the original clause, but it has some mistakes
# next_value = batch_rewards + mask * target_next_value * self.gamma
# Since terminate is always false, I remove the mask here.
next_value = batch_rewards + target_next_value * self.gamma next_value = batch_rewards + target_next_value * self.gamma
# Update Critic
# update prioritized memory # update prioritized memory
if isinstance(self.replay_memory, PrioritizedReplayMemory):
error = torch.abs(current_value - next_value).data.numpy() error = torch.abs(current_value - next_value).data.numpy()
for i in range(self.batch_size): for i in range(self.batch_size):
idx = idxs[i] idx = idxs[i]
self.replay_memory.update(idx, error[i][0]) self.replay_memory.update(idx, error[i][0])
# Update Critic
loss = self.loss_criterion(current_value, next_value) loss = self.loss_criterion(current_value, next_value)
self.critic_optimizer.zero_grad() self.critic_optimizer.zero_grad()
loss.backward() loss.backward()
@ -327,101 +214,17 @@ class DDPG(object):
return loss.data, policy_loss.data return loss.data, policy_loss.data
def choose_action(self, x): def choose_action(self, states):
""" Select Action according to the current state
Args:
x: np.array, current state
"""
self.actor.eval() self.actor.eval()
act = self.actor(self.normalizer([x.tolist()])).squeeze(0) act = self.actor(self.totensor([states.tolist()])).squeeze(0)
self.actor.train() self.actor.train()
action = act.data.numpy() action = act.data.numpy()
if self.ouprocess:
action += self.noise.noise() action += self.noise.noise()
return action.clip(0, 1) return action.clip(0, 1)
def sample_noise(self):
self.actor.sample_noise()
def load_model(self, model_name):
""" Load Torch Model from files
Args:
model_name: str, model path
"""
self.actor.load_state_dict(
torch.load('{}_actor.pth'.format(model_name))
)
self.critic.load_state_dict(
torch.load('{}_critic.pth'.format(model_name))
)
def save_model(self, model_name):
""" Save Torch Model from files
Args:
model_dir: str, model dir
title: str, model name
"""
torch.save(
self.actor.state_dict(),
'{}_actor.pth'.format(model_name)
)
torch.save(
self.critic.state_dict(),
'{}_critic.pth'.format(model_name)
)
def set_model(self, actor_dict, critic_dict): def set_model(self, actor_dict, critic_dict):
self.actor.load_state_dict(pickle.loads(actor_dict)) self.actor.load_state_dict(pickle.loads(actor_dict))
self.critic.load_state_dict(pickle.loads(critic_dict)) self.critic.load_state_dict(pickle.loads(critic_dict))
def get_model(self): def get_model(self):
return pickle.dumps(self.actor.state_dict()), pickle.dumps(self.critic.state_dict()) return pickle.dumps(self.actor.state_dict()), pickle.dumps(self.critic.state_dict())
def save_actor(self, path):
""" save actor network
Args:
path, str, path to save
"""
torch.save(
self.actor.state_dict(),
path
)
def load_actor(self, path):
""" load actor network
Args:
path, str, path to load
"""
self.actor.load_state_dict(
torch.load(path)
)
def train_actor(self, batch_data, is_train=True):
""" Train the actor separately with data
Args:
batch_data: tuple, (states, actions)
is_train: bool
Return:
_loss: float, training loss
"""
states, action = batch_data
if is_train:
self.actor.train()
pred = self.actor(self.normalizer(states))
action = self.totensor(action)
_loss = self.actor_criterion(pred, action)
self.actor_optimizer.zero_grad()
_loss.backward()
self.actor_optimizer.step()
else:
self.actor.eval()
pred = self.actor(self.normalizer(states))
action = self.totensor(action)
_loss = self.actor_criterion(pred, action)
return _loss.data[0]

View File

@ -21,10 +21,12 @@ class OUProcess(object):
self.sigma = sigma self.sigma = sigma
self.current_value = np.ones(self.n_actions) * self.mu self.current_value = np.ones(self.n_actions) * self.mu
def reset(self, sigma=0): def reset(self, sigma=0, theta=0):
self.current_value = np.ones(self.n_actions) * self.mu self.current_value = np.ones(self.n_actions) * self.mu
if sigma != 0: if sigma != 0:
self.sigma = sigma self.sigma = sigma
if theta != 0:
self.theta = theta
def noise(self): def noise(self):
x = self.current_value x = self.current_value

View File

@ -32,8 +32,7 @@ class TestDDPG(unittest.TestCase):
metric_data = np.array([random.random()]) metric_data = np.array([random.random()])
reward = 1.0 if (prev_metric_data[0] - 0.5) * (knob_data[0] - 0.5) > 0 else 0.0 reward = 1.0 if (prev_metric_data[0] - 0.5) * (knob_data[0] - 0.5) > 0 else 0.0
reward = np.array([reward]) reward = np.array([reward])
cls.ddpg.add_sample(prev_metric_data, knob_data, reward, metric_data, False) cls.ddpg.add_sample(prev_metric_data, knob_data, reward, metric_data)
if len(cls.ddpg.replay_memory) > 32:
cls.ddpg.update() cls.ddpg.update()
def test_ddpg_ypreds(self): def test_ddpg_ypreds(self):

View File

@ -325,20 +325,19 @@ def train_ddpg(result_id):
# Calculate the reward # Calculate the reward
objective = objective / base_objective objective = objective / base_objective
if metric_meta[target_objective].improvement == '(less is better)': if metric_meta[target_objective].improvement == '(less is better)':
reward = -objective * objective reward = -objective
else: else:
reward = objective * objective reward = objective
LOG.info('reward: %f', reward) LOG.info('reward: %f', reward)
# Update ddpg # Update ddpg
ddpg = DDPG(n_actions=knob_num, n_states=metric_num, alr=ACTOR_LEARNING_RATE, ddpg = DDPG(n_actions=knob_num, n_states=metric_num, alr=ACTOR_LEARNING_RATE,
clr=CRITIC_LEARNING_RATE, gamma=0.0, batch_size=DDPG_BATCH_SIZE, tau=0.0) clr=CRITIC_LEARNING_RATE, gamma=0, batch_size=DDPG_BATCH_SIZE)
if session.ddpg_actor_model and session.ddpg_critic_model: if session.ddpg_actor_model and session.ddpg_critic_model:
ddpg.set_model(session.ddpg_actor_model, session.ddpg_critic_model) ddpg.set_model(session.ddpg_actor_model, session.ddpg_critic_model)
if session.ddpg_reply_memory: if session.ddpg_reply_memory:
ddpg.replay_memory.set(session.ddpg_reply_memory) ddpg.replay_memory.set(session.ddpg_reply_memory)
ddpg.add_sample(normalized_metric_data, knob_data, reward, normalized_metric_data, False) ddpg.add_sample(normalized_metric_data, knob_data, reward, normalized_metric_data)
if len(ddpg.replay_memory) > 32:
ddpg.update() ddpg.update()
session.ddpg_actor_model, session.ddpg_critic_model = ddpg.get_model() session.ddpg_actor_model, session.ddpg_critic_model = ddpg.get_model()
session.ddpg_reply_memory = ddpg.replay_memory.get() session.ddpg_reply_memory = ddpg.replay_memory.get()
@ -362,8 +361,7 @@ def configuration_recommendation_ddpg(result_info): # pylint: disable=invalid-n
knob_num = len(knob_labels) knob_num = len(knob_labels)
metric_num = len(metric_data) metric_num = len(metric_data)
ddpg = DDPG(n_actions=knob_num, n_states=metric_num, alr=ACTOR_LEARNING_RATE, ddpg = DDPG(n_actions=knob_num, n_states=metric_num)
clr=CRITIC_LEARNING_RATE, gamma=0.0, batch_size=DDPG_BATCH_SIZE, tau=0.0)
if session.ddpg_actor_model is not None and session.ddpg_critic_model is not None: if session.ddpg_actor_model is not None and session.ddpg_critic_model is not None:
ddpg.set_model(session.ddpg_actor_model, session.ddpg_critic_model) ddpg.set_model(session.ddpg_actor_model, session.ddpg_critic_model)
if session.ddpg_reply_memory is not None: if session.ddpg_reply_memory is not None: