save ddpg model in database

This commit is contained in:
Dongsheng Yang 2019-09-26 14:44:28 -04:00 committed by Dana Van Aken
parent c8fbaf6e4b
commit a3fcf59f07
12 changed files with 683 additions and 736 deletions

View File

@ -1,10 +1,10 @@
#
# __init__.py
#
# Copyright
#
from analysis.ddpg.ddpg import DDPG
__all__ = ["DDPG"]
#
# OtterTune - __init__.py
#
# Copyright (c) 2017-18, Carnegie Mellon University Database Group
#
from analysis.ddpg.ddpg import DDPG
__all__ = ["DDPG"]

View File

@ -1,509 +1,428 @@
#
# ddpg.py
#
# Copyright
#
"""
Deep Deterministic Policy Gradient Model
"""
import logging
import os
import sys
import math
import pickle
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init, Parameter
import torch.nn.functional as F
import torch.optim as optimizer
from torch.autograd import Variable
from analysis.ddpg.OUProcess import OUProcess
from analysis.ddpg.prioritized_replay_memory import PrioritizedReplayMemory
LOG = logging.getLogger(__name__)
sys.path.append('../')
# code from https://github.com/Kaixhin/NoisyNet-A3C/blob/master/model.py
class NoisyLinear(nn.Linear):
    """Linear layer whose weight and bias can be perturbed by Gaussian noise.

    ``forward`` uses ``weight + sigma_weight * epsilon_weight`` and
    ``bias + sigma_bias * epsilon_bias``. The ``sigma_*`` tensors are learnable
    parameters; the ``epsilon_*`` tensors are buffers that start at zero and
    are replaced by ``sample_noise`` / ``remove_noise``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        sigma_init: initial value of every entry of ``sigma_weight`` and
            ``sigma_bias``.
        bias: accepted for signature compatibility but ignored; the layer is
            always built with a bias because ``forward`` adds noise to it.
    """

    def __init__(self, in_features, out_features, sigma_init=0.05, bias=True):
        super(NoisyLinear, self).__init__(in_features, out_features, bias=True)
        # reuse self.weight and self.bias
        self.sigma_init = sigma_init
        self.sigma_weight = Parameter(torch.Tensor(out_features, in_features))
        self.sigma_bias = Parameter(torch.Tensor(out_features))
        self.register_buffer('epsilon_weight', torch.zeros(out_features, in_features))
        self.register_buffer('epsilon_bias', torch.zeros(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight/bias uniformly and the sigma tensors to ``sigma_init``."""
        # Only init after all params added (otherwise super().__init__() fails)
        if hasattr(self, 'sigma_weight'):
            bound = math.sqrt(3 / self.in_features)
            # The in-place ``*_`` initialisers replace the deprecated
            # ``init.uniform`` / ``init.constant`` aliases.
            init.uniform_(self.weight, -bound, bound)
            init.uniform_(self.bias, -bound, bound)
            init.constant_(self.sigma_weight, self.sigma_init)
            init.constant_(self.sigma_bias, self.sigma_init)

    def forward(self, x):
        """Apply the linear map using the currently stored noise buffers."""
        return F.linear(x, self.weight + self.sigma_weight * Variable(self.epsilon_weight),
                        self.bias + self.sigma_bias * Variable(self.epsilon_bias))

    # pylint: disable=attribute-defined-outside-init
    def sample_noise(self):
        """Draw fresh standard-normal noise for the weight and bias."""
        self.epsilon_weight = torch.randn(self.out_features, self.in_features)
        self.epsilon_bias = torch.randn(self.out_features)

    def remove_noise(self):
        """Reset the noise buffers to zero so ``forward`` is deterministic."""
        self.epsilon_weight = torch.zeros(self.out_features, self.in_features)
        self.epsilon_bias = torch.zeros(self.out_features)
    # pylint: enable=attribute-defined-outside-init
class Normalizer(object):
    """Standardise inputs with a fixed per-feature mean and variance.

    ``mean`` and ``variance`` may be lists or NumPy arrays. The divisor is
    ``sqrt(variance + 1e-5)``; the small constant keeps the division finite
    when a variance entry is zero.
    """

    def __init__(self, mean, variance):
        self.mean = np.array(mean) if isinstance(mean, list) else mean
        if isinstance(variance, list):
            variance = np.array(variance)
        self.std = np.sqrt(variance + 0.00001)

    def normalize(self, x):
        """Return ``(x - mean) / std`` as a float tensor wrapped in ``Variable``."""
        values = np.array(x) if isinstance(x, list) else x
        scaled = (values - self.mean) / self.std
        return Variable(torch.FloatTensor(scaled))

    def __call__(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored.
        return self.normalize(x)
class ActorLow(nn.Module):
    """Small actor network with one 32-unit hidden layer and a Tanh output.

    The input is batch-normalised, then passed through
    ``Linear(n_states, 32)`` and ``Linear(32, n_actions)``, each followed by
    ``LeakyReLU(0.2)``; the result is squashed with Tanh.
    """

    def __init__(self, n_states, n_actions, ):
        super(ActorLow, self).__init__()
        hidden = 32
        stack = [
            nn.BatchNorm1d(n_states),
            nn.Linear(n_states, hidden),
            nn.LeakyReLU(negative_slope=0.2),
            nn.BatchNorm1d(hidden),
            nn.Linear(hidden, n_actions),
            nn.LeakyReLU(negative_slope=0.2),
        ]
        self.layers = nn.Sequential(*stack)
        self._init_weights()
        self.out_func = nn.Tanh()

    def _init_weights(self):
        """Re-initialise every Linear layer: N(0, 1e-3) weights, U(-0.1, 0.1) biases."""
        linears = (layer for layer in self.layers if isinstance(layer, nn.Linear))
        for linear in linears:
            linear.weight.data.normal_(0.0, 1e-3)
            linear.bias.data.uniform_(-0.1, 0.1)

    def forward(self, x):  # pylint: disable=arguments-differ
        """Map a batch of states to Tanh-squashed outputs."""
        return self.out_func(self.layers(x))
class CriticLow(nn.Module):
    """Small critic network producing one value per (state, action) pair.

    The state is batch-normalised and embedded with ``Linear(n_states, 32)``;
    the action is embedded with ``Linear(n_actions, 32)``. Both embeddings go
    through ``LeakyReLU(0.2)``, are concatenated, and mapped to a single
    output by ``Linear(64, 1)`` followed by ``LeakyReLU(0.2)``.
    """

    def __init__(self, n_states, n_actions):
        super(CriticLow, self).__init__()
        embed = 32
        self.state_input = nn.Linear(n_states, embed)
        self.action_input = nn.Linear(n_actions, embed)
        self.act = nn.LeakyReLU(negative_slope=0.2)
        self.state_bn = nn.BatchNorm1d(n_states)
        head = [
            nn.Linear(64, 1),
            nn.LeakyReLU(negative_slope=0.2),
        ]
        self.layers = nn.Sequential(*head)
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear layer: N(0, 1e-3) weights, U(-0.1, 0.1) biases."""
        linears = [self.state_input, self.action_input]
        linears.extend(layer for layer in self.layers if isinstance(layer, nn.Linear))
        for linear in linears:
            linear.weight.data.normal_(0.0, 1e-3)
            linear.bias.data.uniform_(-0.1, 0.1)

    def forward(self, x, action):  # pylint: disable=arguments-differ
        """Return the critic output for a batch of states and actions."""
        state_features = self.act(self.state_input(self.state_bn(x)))
        action_features = self.act(self.action_input(action))
        joined = torch.cat([state_features, action_features], dim=1)
        return self.layers(joined)
class Actor(nn.Module):
    """Policy network mapping a batch of states to actions in ``(0, 1)``.

    Three hidden Linear layers (128, 128, 64 units) with LeakyReLU/Tanh
    activations, batch normalisation and dropout feed an output layer that is
    either a ``NoisyLinear`` (``noisy=True``) or a plain ``nn.Linear``; a
    Sigmoid squashes the result.

    Args:
        n_states: dimension of the state vector.
        n_actions: dimension of the action vector.
        noisy: use a ``NoisyLinear`` output layer when True.
    """

    def __init__(self, n_states, n_actions, noisy=False):
        super(Actor, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(n_states, 128),
            nn.LeakyReLU(negative_slope=0.2),
            nn.BatchNorm1d(128),
            nn.Linear(128, 128),
            nn.Tanh(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.BatchNorm1d(64),
        )
        if noisy:
            self.out = NoisyLinear(64, n_actions)
        else:
            self.out = nn.Linear(64, n_actions)
        self._init_weights()
        self.act = nn.Sigmoid()

    def _init_weights(self):
        """Re-initialise the hidden Linear layers (``self.out`` is left untouched)."""
        for m in self.layers:
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0.0, 1e-2)
                m.bias.data.uniform_(-0.1, 0.1)

    def sample_noise(self):
        """Resample the output layer's parameter noise, if it has any.

        With ``noisy=False`` the output layer is a plain ``nn.Linear`` that
        has no ``sample_noise`` method; the call is then a no-op instead of
        raising ``AttributeError``.
        """
        sampler = getattr(self.out, 'sample_noise', None)
        if sampler is not None:
            sampler()

    def forward(self, x):  # pylint: disable=arguments-differ
        """Return Sigmoid-squashed actions for a batch of states."""
        out = self.act(self.out(self.layers(x)))
        return out
class Critic(nn.Module):
    """Value network producing one scalar per (state, action) pair.

    State and action are embedded separately into 128 units with Tanh,
    concatenated to 256 features, and reduced to a single output by a stack
    of Linear / LeakyReLU / BatchNorm / Tanh / Dropout layers.
    """

    def __init__(self, n_states, n_actions):
        super(Critic, self).__init__()
        self.state_input = nn.Linear(n_states, 128)
        self.action_input = nn.Linear(n_actions, 128)
        self.act = nn.Tanh()
        trunk = [
            nn.Linear(256, 256),
            nn.LeakyReLU(negative_slope=0.2),
            nn.BatchNorm1d(256),
            nn.Linear(256, 64),
            nn.Tanh(),
            nn.Dropout(0.3),
            nn.BatchNorm1d(64),
            nn.Linear(64, 1),
        ]
        self.layers = nn.Sequential(*trunk)
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear layer: N(0, 1e-2) weights, U(-0.1, 0.1) biases."""
        linears = [self.state_input, self.action_input]
        linears.extend(layer for layer in self.layers if isinstance(layer, nn.Linear))
        for linear in linears:
            linear.weight.data.normal_(0.0, 1e-2)
            linear.bias.data.uniform_(-0.1, 0.1)

    def forward(self, x, action):  # pylint: disable=arguments-differ
        """Return the critic output for a batch of states and actions."""
        state_features = self.act(self.state_input(x))
        action_features = self.act(self.action_input(action))
        joined = torch.cat([state_features, action_features], dim=1)
        return self.layers(joined)
class DDPG(object):
    """Deep Deterministic Policy Gradient agent.

    Holds an actor and a critic with target copies, Adam optimizers for both,
    a ``PrioritizedReplayMemory`` and an ``OUProcess`` exploration-noise
    source. In ``supervised`` mode only the actor is built.
    """

    def __init__(self, n_states, n_actions, opt=None, ouprocess=True, mean_var_path=None,
                 supervised=False):
        """ DDPG Algorithms
        Args:
            n_states: int, dimension of states
            n_actions: int, dimension of actions
            opt: dict, params with keys 'model', 'alr', 'clr', 'gamma',
                'batch_size', 'tau' and 'memory_size'; defaults are used
                when None
            ouprocess: bool, add OU-process noise to chosen actions; when
                False the actor gets a NoisyLinear output layer instead
            mean_var_path: str, path of a pickled (mean, var) pair used to
                normalise states; zeros are used when missing
            supervised, bool, pre-train the actor with supervised learning
        """
        self.n_states = n_states
        self.n_actions = n_actions

        if opt is None:
            opt = {
                'model': '',
                'alr': 0.001,
                'clr': 0.001,
                'gamma': 0.9,
                'batch_size': 32,
                'tau': 0.002,
                'memory_size': 100000
            }

        # Params
        self.alr = opt['alr']
        self.clr = opt['clr']
        self.model_name = opt['model']
        self.batch_size = opt['batch_size']
        self.gamma = opt['gamma']
        self.tau = opt['tau']
        self.ouprocess = ouprocess

        if mean_var_path is not None and os.path.exists(mean_var_path):
            with open(mean_var_path, 'rb') as f:
                mean, var = pickle.load(f)
        else:
            # No statistics available: fall back to an (almost) identity normaliser.
            mean = np.zeros(n_states)
            var = np.zeros(n_states)

        self.normalizer = Normalizer(mean, var)

        if supervised:
            self._build_actor()
            LOG.info("Supervised Learning Initialized")
        else:
            # Build Network
            self._build_network()
            LOG.info('Finish Initializing Networks')

        self.replay_memory = PrioritizedReplayMemory(capacity=opt['memory_size'])
        self.noise = OUProcess(n_actions)
        # LOG.info('DDPG Initialzed!')

    @staticmethod
    def totensor(x):
        """Convert a (nested) list or array of numbers to a float tensor."""
        return Variable(torch.FloatTensor(x))

    def _build_actor(self):
        """Create only the actor, its MSE criterion and optimizer (supervised mode)."""
        noisy = not self.ouprocess
        self.actor = Actor(self.n_states, self.n_actions, noisy=noisy)
        self.actor_criterion = nn.MSELoss()
        self.actor_optimizer = optimizer.Adam(lr=self.alr, params=self.actor.parameters())

    def _build_network(self):
        """Create actor/critic, their targets and optimizers; optionally load weights."""
        noisy = not self.ouprocess
        self.actor = Actor(self.n_states, self.n_actions, noisy=noisy)
        self.target_actor = Actor(self.n_states, self.n_actions)
        self.critic = Critic(self.n_states, self.n_actions)
        self.target_critic = Critic(self.n_states, self.n_actions)

        # if model params are provided, load them
        if self.model_name:
            self.load_model(model_name=self.model_name)
            LOG.info("Loading model from file: %s", self.model_name)

        # Copy actor's parameters
        self._update_target(self.target_actor, self.actor, tau=1.0)

        # Copy critic's parameters
        self._update_target(self.target_critic, self.critic, tau=1.0)

        self.loss_criterion = nn.MSELoss()
        self.actor_optimizer = optimizer.Adam(lr=self.alr, params=self.actor.parameters(),
                                              weight_decay=1e-5)
        self.critic_optimizer = optimizer.Adam(lr=self.clr, params=self.critic.parameters(),
                                               weight_decay=1e-5)

    @staticmethod
    def _update_target(target, source, tau):
        """Soft-update: target <- (1 - tau) * target + tau * source (parameters only)."""
        for (target_param, param) in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(
                target_param.data * (1 - tau) + param.data * tau
            )

    def reset(self, sigma):
        """Reset the exploration noise process, passing ``sigma`` through."""
        self.noise.reset(sigma)

    def _sample_batch(self):
        """Sample ``batch_size`` transitions and split them into per-field lists."""
        batch, idx = self.replay_memory.sample(self.batch_size)
        # Each transition is (state, action, reward, next_state, terminate).
        states = [sample[0].tolist() for sample in batch]
        next_states = [sample[3].tolist() for sample in batch]
        actions = [sample[1].tolist() for sample in batch]
        rewards = [sample[2] for sample in batch]
        terminates = [sample[4] for sample in batch]
        return idx, states, next_states, actions, rewards, terminates

    def add_sample(self, state, action, reward, next_state, terminate):
        """Store a transition with a priority derived from its current TD error."""
        self.critic.eval()
        self.actor.eval()
        self.target_critic.eval()
        self.target_actor.eval()
        batch_state = self.normalizer([state.tolist()])
        batch_next_state = self.normalizer([next_state.tolist()])
        current_value = self.critic(batch_state, self.totensor([action.tolist()]))
        target_action = self.target_actor(batch_next_state)
        target_value = self.totensor([reward]) \
            + self.totensor([0 if x else 1 for x in [terminate]]) \
            * self.target_critic(batch_next_state, target_action) * self.gamma
        # Single-element tensor -> Python float.
        error = float(torch.abs(current_value - target_value).item())

        self.target_actor.train()
        self.actor.train()
        self.critic.train()
        self.target_critic.train()
        self.replay_memory.add(error, (state, action, reward, next_state, terminate))

    def update(self):
        """Run one critic update, one actor update and a soft target update.

        Returns:
            tuple, (critic loss data, policy loss data)
        """
        idxs, states, next_states, actions, rewards, _terminates = self._sample_batch()
        batch_states = self.normalizer(states)
        batch_next_states = self.normalizer(next_states)
        batch_actions = self.totensor(actions)
        # Force a column vector so rewards line up with the (batch, 1) critic
        # output; scalar rewards would otherwise broadcast to (batch, batch).
        batch_rewards = self.totensor(rewards).view(-1, 1)

        target_next_actions = self.target_actor(batch_next_states).detach()
        target_next_value = self.target_critic(batch_next_states, target_next_actions).detach()
        current_value = self.critic(batch_states, batch_actions)
        # TODO (dongshen): This clause is the original clause, but it has some mistakes
        # next_value = batch_rewards + mask * target_next_value * self.gamma
        # Since terminate is always false, I remove the mask here.
        next_value = batch_rewards + target_next_value * self.gamma
        # Update Critic
        # update prioritized memory
        error = torch.abs(current_value - next_value).data.numpy()
        for idx, sample_error in zip(idxs, error):
            self.replay_memory.update(idx, sample_error[0])

        loss = self.loss_criterion(current_value, next_value)
        self.critic_optimizer.zero_grad()
        loss.backward()
        self.critic_optimizer.step()

        # Update Actor
        self.critic.eval()
        policy_loss = -self.critic(batch_states, self.actor(batch_states))
        policy_loss = policy_loss.mean()
        self.actor_optimizer.zero_grad()
        policy_loss.backward()

        self.actor_optimizer.step()
        self.critic.train()

        self._update_target(self.target_critic, self.critic, tau=self.tau)
        self._update_target(self.target_actor, self.actor, tau=self.tau)

        return loss.data, policy_loss.data

    def choose_action(self, x):
        """ Select Action according to the current state
        Args:
            x: np.array, current state
        Return:
            np.array, action clipped to [0, 1]
        """
        self.actor.eval()
        act = self.actor(self.normalizer([x.tolist()])).squeeze(0)
        self.actor.train()
        action = act.data.numpy()
        if self.ouprocess:
            action += self.noise.noise()
        return action.clip(0, 1)

    def sample_noise(self):
        """Resample the actor's parameter noise."""
        self.actor.sample_noise()

    def load_model(self, model_name):
        """ Load Torch Model from files
        Args:
            model_name: str, model path prefix; '_actor.pth' and
                '_critic.pth' are appended
        """
        self.actor.load_state_dict(
            torch.load('{}_actor.pth'.format(model_name))
        )
        self.critic.load_state_dict(
            torch.load('{}_critic.pth'.format(model_name))
        )

    def save_model(self, model_name):
        """ Save Torch Model to files
        Args:
            model_name: str, model path prefix; '_actor.pth' and
                '_critic.pth' are appended
        """
        torch.save(
            self.actor.state_dict(),
            '{}_actor.pth'.format(model_name)
        )

        torch.save(
            self.critic.state_dict(),
            '{}_critic.pth'.format(model_name)
        )

    def save_actor(self, path):
        """ save actor network
        Args:
             path, str, path to save
        """
        torch.save(
            self.actor.state_dict(),
            path
        )

    def load_actor(self, path):
        """ load actor network
        Args:
             path, str, path to load
        """
        self.actor.load_state_dict(
            torch.load(path)
        )

    def train_actor(self, batch_data, is_train=True):
        """ Train the actor separately with data
        Args:
            batch_data: tuple, (states, actions)
            is_train: bool, take an optimizer step when True, otherwise only
                evaluate the loss
        Return:
            _loss: float, training loss
        """
        states, action = batch_data

        if is_train:
            self.actor.train()
            pred = self.actor(self.normalizer(states))
            action = self.totensor(action)

            _loss = self.actor_criterion(pred, action)

            self.actor_optimizer.zero_grad()
            _loss.backward()
            self.actor_optimizer.step()

        else:
            self.actor.eval()
            pred = self.actor(self.normalizer(states))
            action = self.totensor(action)
            _loss = self.actor_criterion(pred, action)

        # MSELoss yields a 0-dim tensor; indexing it with [0] raises on
        # PyTorch >= 0.4, so convert with item().
        return _loss.item()
#
# OtterTune - ddpg.py
#
# Copyright (c) 2017-18, Carnegie Mellon University Database Group
#
# from: https://github.com/KqSMea8/CDBTune
# Zhang, Ji, et al. "An end-to-end automatic cloud database tuning system using
# deep reinforcement learning." Proceedings of the 2019 International Conference
# on Management of Data. ACM, 2019
import os
import pickle
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init, Parameter
import torch.nn.functional as F
import torch.optim as optimizer
from torch.autograd import Variable
from analysis.ddpg.ou_process import OUProcess
from analysis.ddpg.prioritized_replay_memory import PrioritizedReplayMemory
from analysis.util import get_analysis_logger
LOG = get_analysis_logger(__name__)
# code from https://github.com/Kaixhin/NoisyNet-A3C/blob/master/model.py
class NoisyLinear(nn.Linear):
    """Linear layer whose weight and bias can be perturbed by Gaussian noise.

    ``forward`` uses ``weight + sigma_weight * epsilon_weight`` and
    ``bias + sigma_bias * epsilon_bias``. The ``sigma_*`` tensors are learnable
    parameters; the ``epsilon_*`` tensors are buffers that start at zero and
    are replaced by ``sample_noise`` / ``remove_noise``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        sigma_init: initial value of every entry of ``sigma_weight`` and
            ``sigma_bias``.
        bias: accepted for signature compatibility but ignored; the layer is
            always built with a bias because ``forward`` adds noise to it.
    """

    def __init__(self, in_features, out_features, sigma_init=0.05, bias=True):
        super(NoisyLinear, self).__init__(in_features, out_features, bias=True)
        # reuse self.weight and self.bias
        self.sigma_init = sigma_init
        self.sigma_weight = Parameter(torch.Tensor(out_features, in_features))
        self.sigma_bias = Parameter(torch.Tensor(out_features))
        # The epsilon_* names must not be pre-assigned as plain attributes:
        # register_buffer rejects a name that already exists on the module
        # (and a plain attribute would shadow the buffer on lookup).
        self.register_buffer('epsilon_weight', torch.zeros(out_features, in_features))
        self.register_buffer('epsilon_bias', torch.zeros(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight/bias uniformly and the sigma tensors to ``sigma_init``."""
        # Only init after all params added (otherwise super().__init__() fails)
        if hasattr(self, 'sigma_weight'):
            bound = math.sqrt(3 / self.in_features)
            # The in-place ``*_`` initialisers replace the deprecated
            # ``init.uniform`` / ``init.constant`` aliases.
            init.uniform_(self.weight, -bound, bound)
            init.uniform_(self.bias, -bound, bound)
            init.constant_(self.sigma_weight, self.sigma_init)
            init.constant_(self.sigma_bias, self.sigma_init)

    def forward(self, x):
        """Apply the linear map using the currently stored noise buffers."""
        return F.linear(x, self.weight + self.sigma_weight * Variable(self.epsilon_weight),
                        self.bias + self.sigma_bias * Variable(self.epsilon_bias))

    def sample_noise(self):
        """Draw fresh standard-normal noise for the weight and bias."""
        self.epsilon_weight = torch.randn(self.out_features, self.in_features)
        self.epsilon_bias = torch.randn(self.out_features)

    def remove_noise(self):
        """Reset the noise buffers to zero so ``forward`` is deterministic."""
        self.epsilon_weight = torch.zeros(self.out_features, self.in_features)
        self.epsilon_bias = torch.zeros(self.out_features)
class Normalizer(object):
    """Standardise inputs with a fixed per-feature mean and variance.

    ``mean`` and ``variance`` may be lists or NumPy arrays. The divisor is
    ``sqrt(variance + 1e-5)``; the small constant keeps the division finite
    when a variance entry is zero.
    """

    def __init__(self, mean, variance):
        self.mean = np.array(mean) if isinstance(mean, list) else mean
        if isinstance(variance, list):
            variance = np.array(variance)
        self.std = np.sqrt(variance + 0.00001)

    def normalize(self, x):
        """Return ``(x - mean) / std`` as a float tensor wrapped in ``Variable``."""
        values = np.array(x) if isinstance(x, list) else x
        scaled = (values - self.mean) / self.std
        return Variable(torch.FloatTensor(scaled))

    def __call__(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored.
        return self.normalize(x)
class Actor(nn.Module):
    """Policy network mapping a batch of states to actions in ``(0, 1)``.

    Three hidden Linear layers (128, 128, 64 units) with LeakyReLU/Tanh
    activations, batch normalisation and dropout feed an output layer that is
    either a ``NoisyLinear`` (``noisy=True``) or a plain ``nn.Linear``; a
    Sigmoid squashes the result.

    Args:
        n_states: dimension of the state vector.
        n_actions: dimension of the action vector.
        noisy: use a ``NoisyLinear`` output layer when True.
    """

    def __init__(self, n_states, n_actions, noisy=False):
        super(Actor, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(n_states, 128),
            nn.LeakyReLU(negative_slope=0.2),
            nn.BatchNorm1d(128),
            nn.Linear(128, 128),
            nn.Tanh(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.BatchNorm1d(64),
        )
        if noisy:
            self.out = NoisyLinear(64, n_actions)
        else:
            self.out = nn.Linear(64, n_actions)
        self._init_weights()
        self.act = nn.Sigmoid()

    def _init_weights(self):
        """Re-initialise the hidden Linear layers (``self.out`` is left untouched)."""
        for m in self.layers:
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0.0, 1e-2)
                m.bias.data.uniform_(-0.1, 0.1)

    def sample_noise(self):
        """Resample the output layer's parameter noise, if it has any.

        With ``noisy=False`` the output layer is a plain ``nn.Linear`` that
        has no ``sample_noise`` method; the call is then a no-op instead of
        raising ``AttributeError``.
        """
        sampler = getattr(self.out, 'sample_noise', None)
        if sampler is not None:
            sampler()

    def forward(self, x):  # pylint: disable=arguments-differ
        """Return Sigmoid-squashed actions for a batch of states."""
        out = self.act(self.out(self.layers(x)))
        return out
class Critic(nn.Module):
    """Value network producing one scalar per (state, action) pair.

    State and action are embedded separately into 128 units with Tanh,
    concatenated to 256 features, and reduced to a single output by a stack
    of Linear / LeakyReLU / BatchNorm / Tanh / Dropout layers.
    """

    def __init__(self, n_states, n_actions):
        super(Critic, self).__init__()
        self.state_input = nn.Linear(n_states, 128)
        self.action_input = nn.Linear(n_actions, 128)
        self.act = nn.Tanh()
        trunk = [
            nn.Linear(256, 256),
            nn.LeakyReLU(negative_slope=0.2),
            nn.BatchNorm1d(256),
            nn.Linear(256, 64),
            nn.Tanh(),
            nn.Dropout(0.3),
            nn.BatchNorm1d(64),
            nn.Linear(64, 1),
        ]
        self.layers = nn.Sequential(*trunk)
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear layer: N(0, 1e-2) weights, U(-0.1, 0.1) biases."""
        linears = [self.state_input, self.action_input]
        linears.extend(layer for layer in self.layers if isinstance(layer, nn.Linear))
        for linear in linears:
            linear.weight.data.normal_(0.0, 1e-2)
            linear.bias.data.uniform_(-0.1, 0.1)

    def forward(self, x, action):  # pylint: disable=arguments-differ
        """Return the critic output for a batch of states and actions."""
        state_features = self.act(self.state_input(x))
        action_features = self.act(self.action_input(action))
        joined = torch.cat([state_features, action_features], dim=1)
        return self.layers(joined)
class DDPG(object):
    """Deep Deterministic Policy Gradient agent.

    Holds an actor and a critic with target copies, Adam optimizers for both,
    a ``PrioritizedReplayMemory`` and an ``OUProcess`` exploration-noise
    source. In ``supervised`` mode only the actor is built.
    """

    def __init__(self, n_states, n_actions, model_name='', alr=0.001, clr=0.001,
                 gamma=0.9, batch_size=32, tau=0.002, memory_size=100000,
                 ouprocess=True, mean_var_path=None, supervised=False):
        """ DDPG Algorithms
        Args:
            n_states: int, dimension of states
            n_actions: int, dimension of actions
            model_name: str, path prefix of saved weights to load ('' = none)
            alr: float, actor learning rate
            clr: float, critic learning rate
            gamma: float, discount factor
            batch_size: int, number of transitions sampled per update
            tau: float, soft-update rate of the target networks
            memory_size: int, capacity of the replay memory
            ouprocess: bool, add OU-process noise to chosen actions; when
                False the actor gets a NoisyLinear output layer instead
            mean_var_path: str, path of a pickled (mean, var) pair used to
                normalise states; zeros are used when missing
            supervised: bool, pre-train the actor with supervised learning
        """
        self.n_states = n_states
        self.n_actions = n_actions
        self.alr = alr
        self.clr = clr
        self.model_name = model_name
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau
        self.ouprocess = ouprocess

        if mean_var_path is not None and os.path.exists(mean_var_path):
            with open(mean_var_path, 'rb') as f:
                mean, var = pickle.load(f)
        else:
            # No statistics available: fall back to an (almost) identity normaliser.
            mean = np.zeros(n_states)
            var = np.zeros(n_states)

        self.normalizer = Normalizer(mean, var)

        if supervised:
            self._build_actor()
            LOG.info("Supervised Learning Initialized")
        else:
            # Build Network
            self._build_network()
            LOG.info('Finish Initializing Networks')

        self.replay_memory = PrioritizedReplayMemory(capacity=memory_size)
        self.noise = OUProcess(n_actions)

    @staticmethod
    def totensor(x):
        """Convert a (nested) list or array of numbers to a float tensor."""
        return Variable(torch.FloatTensor(x))

    def _build_actor(self):
        """Create only the actor, its MSE criterion and optimizer (supervised mode)."""
        noisy = not self.ouprocess
        self.actor = Actor(self.n_states, self.n_actions, noisy=noisy)
        self.actor_criterion = nn.MSELoss()
        self.actor_optimizer = optimizer.Adam(lr=self.alr, params=self.actor.parameters())

    def _build_network(self):
        """Create actor/critic, their targets and optimizers; optionally load weights."""
        noisy = not self.ouprocess
        self.actor = Actor(self.n_states, self.n_actions, noisy=noisy)
        self.target_actor = Actor(self.n_states, self.n_actions)
        self.critic = Critic(self.n_states, self.n_actions)
        self.target_critic = Critic(self.n_states, self.n_actions)

        # if model params are provided, load them
        if self.model_name:
            self.load_model(model_name=self.model_name)
            LOG.info("Loading model from file: %s", self.model_name)

        # Copy actor's parameters
        self._update_target(self.target_actor, self.actor, tau=1.0)

        # Copy critic's parameters
        self._update_target(self.target_critic, self.critic, tau=1.0)

        self.loss_criterion = nn.MSELoss()
        self.actor_optimizer = optimizer.Adam(lr=self.alr, params=self.actor.parameters(),
                                              weight_decay=1e-5)
        self.critic_optimizer = optimizer.Adam(lr=self.clr, params=self.critic.parameters(),
                                               weight_decay=1e-5)

    @staticmethod
    def _update_target(target, source, tau):
        """Soft-update: target <- (1 - tau) * target + tau * source (parameters only)."""
        for (target_param, param) in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(
                target_param.data * (1 - tau) + param.data * tau
            )

    def reset(self, sigma):
        """Reset the exploration noise process, passing ``sigma`` through."""
        self.noise.reset(sigma)

    def _sample_batch(self):
        """Sample ``batch_size`` transitions and split them into per-field lists."""
        batch, idx = self.replay_memory.sample(self.batch_size)
        # Each transition is (state, action, reward, next_state, terminate).
        states = [sample[0].tolist() for sample in batch]
        next_states = [sample[3].tolist() for sample in batch]
        actions = [sample[1].tolist() for sample in batch]
        rewards = [sample[2] for sample in batch]
        terminates = [sample[4] for sample in batch]
        return idx, states, next_states, actions, rewards, terminates

    def add_sample(self, state, action, reward, next_state, terminate):
        """Store a transition with a priority derived from its current TD error."""
        self.critic.eval()
        self.actor.eval()
        self.target_critic.eval()
        self.target_actor.eval()
        batch_state = self.normalizer([state.tolist()])
        batch_next_state = self.normalizer([next_state.tolist()])
        current_value = self.critic(batch_state, self.totensor([action.tolist()]))
        target_action = self.target_actor(batch_next_state)
        target_value = self.totensor([reward]) \
            + self.totensor([0 if x else 1 for x in [terminate]]) \
            * self.target_critic(batch_next_state, target_action) * self.gamma
        # Single-element tensor -> Python float.
        error = float(torch.abs(current_value - target_value).item())

        self.target_actor.train()
        self.actor.train()
        self.critic.train()
        self.target_critic.train()
        self.replay_memory.add(error, (state, action, reward, next_state, terminate))

    def update(self):
        """Run one critic update, one actor update and a soft target update.

        Returns:
            tuple, (critic loss data, policy loss data)
        """
        idxs, states, next_states, actions, rewards, _terminates = self._sample_batch()
        batch_states = self.normalizer(states)
        batch_next_states = self.normalizer(next_states)
        batch_actions = self.totensor(actions)
        # Force a column vector so rewards line up with the (batch, 1) critic
        # output; scalar rewards would otherwise broadcast to (batch, batch).
        batch_rewards = self.totensor(rewards).view(-1, 1)

        target_next_actions = self.target_actor(batch_next_states).detach()
        target_next_value = self.target_critic(batch_next_states, target_next_actions).detach()
        current_value = self.critic(batch_states, batch_actions)
        # TODO (dongshen): This clause is the original clause, but it has some mistakes
        # next_value = batch_rewards + mask * target_next_value * self.gamma
        # Since terminate is always false, I remove the mask here.
        next_value = batch_rewards + target_next_value * self.gamma
        # Update Critic
        # update prioritized memory
        error = torch.abs(current_value - next_value).data.numpy()
        for idx, sample_error in zip(idxs, error):
            self.replay_memory.update(idx, sample_error[0])

        loss = self.loss_criterion(current_value, next_value)
        self.critic_optimizer.zero_grad()
        loss.backward()
        self.critic_optimizer.step()

        # Update Actor
        self.critic.eval()
        policy_loss = -self.critic(batch_states, self.actor(batch_states))
        policy_loss = policy_loss.mean()
        self.actor_optimizer.zero_grad()
        policy_loss.backward()

        self.actor_optimizer.step()
        self.critic.train()

        self._update_target(self.target_critic, self.critic, tau=self.tau)
        self._update_target(self.target_actor, self.actor, tau=self.tau)

        return loss.data, policy_loss.data

    def choose_action(self, x):
        """ Select Action according to the current state
        Args:
            x: np.array, current state
        Return:
            np.array, action clipped to [0, 1]
        """
        self.actor.eval()
        act = self.actor(self.normalizer([x.tolist()])).squeeze(0)
        self.actor.train()
        action = act.data.numpy()
        if self.ouprocess:
            action += self.noise.noise()
        return action.clip(0, 1)

    def sample_noise(self):
        """Resample the actor's parameter noise."""
        self.actor.sample_noise()

    def load_model(self, model_name):
        """ Load Torch Model from files
        Args:
            model_name: str, model path prefix; '_actor.pth' and
                '_critic.pth' are appended
        """
        self.actor.load_state_dict(
            torch.load('{}_actor.pth'.format(model_name))
        )
        self.critic.load_state_dict(
            torch.load('{}_critic.pth'.format(model_name))
        )

    def save_model(self, model_name):
        """ Save Torch Model to files
        Args:
            model_name: str, model path prefix; '_actor.pth' and
                '_critic.pth' are appended
        """
        torch.save(
            self.actor.state_dict(),
            '{}_actor.pth'.format(model_name)
        )

        torch.save(
            self.critic.state_dict(),
            '{}_critic.pth'.format(model_name)
        )

    def set_model(self, actor_dict, critic_dict):
        """ Load actor and critic weights from pickled state dicts
        Args:
            actor_dict: bytes, pickled actor state dict
            critic_dict: bytes, pickled critic state dict
        """
        # NOTE(review): pickle.loads executes arbitrary code embedded in its
        # input; only pass bytes produced by get_model from a trusted store.
        self.actor.load_state_dict(pickle.loads(actor_dict))
        self.critic.load_state_dict(pickle.loads(critic_dict))

    def get_model(self):
        """ Return the actor and critic state dicts as pickled bytes
        Return:
            tuple, (actor bytes, critic bytes)
        """
        return pickle.dumps(self.actor.state_dict()), pickle.dumps(self.critic.state_dict())

    def save_actor(self, path):
        """ save actor network
        Args:
             path, str, path to save
        """
        torch.save(
            self.actor.state_dict(),
            path
        )

    def load_actor(self, path):
        """ load actor network
        Args:
             path, str, path to load
        """
        self.actor.load_state_dict(
            torch.load(path)
        )

    def train_actor(self, batch_data, is_train=True):
        """ Train the actor separately with data
        Args:
            batch_data: tuple, (states, actions)
            is_train: bool, take an optimizer step when True, otherwise only
                evaluate the loss
        Return:
            _loss: float, training loss
        """
        states, action = batch_data

        if is_train:
            self.actor.train()
            pred = self.actor(self.normalizer(states))
            action = self.totensor(action)

            _loss = self.actor_criterion(pred, action)

            self.actor_optimizer.zero_grad()
            _loss.backward()
            self.actor_optimizer.step()

        else:
            self.actor.eval()
            pred = self.actor(self.normalizer(states))
            action = self.totensor(action)
            _loss = self.actor_criterion(pred, action)

        # MSELoss yields a 0-dim tensor; indexing it with [0] raises on
        # PyTorch >= 0.4, so convert with item().
        return _loss.item()

View File

@ -1,41 +1,33 @@
#
# OUProcess.py
#
# Copyright
#
import numpy as np
# from https://github.com/songrotek/DDPG/blob/master/ou_noise.py
class OUProcess(object):
    """Discrete-time Ornstein-Uhlenbeck style noise generator.

    Every ``noise`` call moves the stored value towards ``mu`` by a fraction
    ``theta`` and adds Gaussian noise scaled by ``sigma``.
    """

    def __init__(self, n_actions, theta=0.15, mu=0, sigma=0.1, ):
        self.n_actions = n_actions
        self.theta = theta
        self.mu = mu
        self.sigma = sigma
        self.current_value = self.mu * np.ones(self.n_actions)

    def reset(self, sigma=0):
        """Restart the process at ``mu``; a non-zero ``sigma`` replaces the stored one."""
        self.current_value = self.mu * np.ones(self.n_actions)
        if sigma == 0:
            return
        self.sigma = sigma

    def noise(self):
        """Advance the process one step and return the new value."""
        previous = self.current_value
        pull = self.theta * (self.mu - previous)
        jitter = self.sigma * np.random.randn(len(previous))
        self.current_value = previous + (pull + jitter)
        return self.current_value
if __name__ == '__main__':
    # Demo: plot 1000 successive samples of a 3-dimensional OU process.
    import matplotlib.pyplot as plt  # pylint: disable=wrong-import-position
    ou = OUProcess(3, theta=0.3)  # pylint: disable=invalid-name
    states = [ou.noise() for _ in range(1000)]  # pylint: disable=invalid-name
    plt.plot(states)
    plt.show()
#
# OtterTune - ou_process.py
#
# Copyright (c) 2017-18, Carnegie Mellon University Database Group
#
# from: https://github.com/KqSMea8/CDBTune
# Zhang, Ji, et al. "An end-to-end automatic cloud database tuning system using
# deep reinforcement learning." Proceedings of the 2019 International Conference
# on Management of Data. ACM, 2019
import numpy as np
class OUProcess(object):
    """Discrete-time Ornstein-Uhlenbeck style noise generator.

    Every ``noise`` call moves the stored value towards ``mu`` by a fraction
    ``theta`` and adds Gaussian noise scaled by ``sigma``.
    """

    def __init__(self, n_actions, theta=0.15, mu=0, sigma=0.1, ):
        self.n_actions = n_actions
        self.theta = theta
        self.mu = mu
        self.sigma = sigma
        self.current_value = self.mu * np.ones(self.n_actions)

    def reset(self, sigma=0):
        """Restart the process at ``mu``; a non-zero ``sigma`` replaces the stored one."""
        self.current_value = self.mu * np.ones(self.n_actions)
        if sigma == 0:
            return
        self.sigma = sigma

    def noise(self):
        """Advance the process one step and return the new value."""
        previous = self.current_value
        pull = self.theta * (self.mu - previous)
        jitter = self.sigma * np.random.randn(len(previous))
        self.current_value = previous + (pull + jitter)
        return self.current_value

View File

@ -1,121 +1,132 @@
#
# prioritized_replay_memory.py
#
# Copyright
#
import random
import pickle
import numpy as np
class SumTree(object):
    """Binary sum tree stored in a flat array, used for proportional sampling.

    ``tree`` holds ``2 * capacity - 1`` nodes: the last ``capacity`` entries
    are leaf priorities and every internal node stores the sum of its two
    children, so ``tree[0]`` is the total priority. ``data`` holds the payload
    of each leaf and is overwritten in ring-buffer order.
    """

    # Class-level default for the next write slot (kept so instances created
    # without running __init__ still have the attribute).
    write = 0

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.num_entries = 0
        self.write = 0

    def _propagate(self, idx, change):
        """Add ``change`` to every ancestor of node ``idx``.

        Iterative on purpose: when ``idx`` is already the root (capacity 1)
        there is nothing to update, whereas the recursive form computed a
        parent of -1 and never terminated.
        """
        while idx != 0:
            idx = (idx - 1) // 2
            self.tree[idx] += change

    def _retrieve(self, idx, s):
        """Descend from ``idx`` to the leaf whose cumulative range contains ``s``."""
        while True:
            left = 2 * idx + 1
            if left >= len(self.tree):
                return idx
            if s <= self.tree[left]:
                idx = left
            else:
                # Skip the left subtree's mass and continue on the right.
                s = s - self.tree[left]
                idx = left + 1

    def total(self):
        """Return the sum of all leaf priorities."""
        return self.tree[0]

    def add(self, p, data):
        """Store ``data`` with priority ``p`` in the next slot, wrapping when full."""
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write += 1
        if self.write >= self.capacity:
            self.write = 0
        if self.num_entries < self.capacity:
            self.num_entries += 1

    def update(self, idx, p):
        """Set the priority of tree node ``idx`` to ``p`` and fix the sums above it."""
        change = p - self.tree[idx]

        self.tree[idx] = p
        self._propagate(idx, change)

    def get(self, s):
        """Return ``[tree index, priority, data]`` of the leaf selected by ``s``."""
        idx = self._retrieve(0, s)
        data_idx = idx - self.capacity + 1
        return [idx, self.tree[idx], self.data[data_idx]]
class PrioritizedReplayMemory(object):
    """Replay memory that samples transitions in proportion to their priority.

    Priorities are ``(error + e) ** a`` and are stored in a ``SumTree``.
    """

    def __init__(self, capacity):
        self.tree = SumTree(capacity)
        self.capacity = capacity
        self.e = 0.01  # pylint: disable=invalid-name
        self.a = 0.6  # pylint: disable=invalid-name
        self.beta = 0.4
        self.beta_increment_per_sampling = 0.001

    def _get_priority(self, error):
        """Map an error to its priority, ``(error + e) ** a``."""
        return (error + self.e) ** self.a

    def add(self, error, sample):
        """Store ``sample`` with a priority derived from ``error``."""
        # (s, a, r, s, t)
        p = self._get_priority(error)
        self.tree.add(p, sample)

    def __len__(self):
        return self.tree.num_entries

    def sample(self, n):
        """Draw ``n`` samples, one from each equal-mass segment of the tree.

        Returns:
            tuple, (list of stored samples, list of their tree indices)
        """
        batch = []
        idxs = []
        segment = self.tree.total() / n

        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])

        for i in range(n):
            low = segment * i
            high = segment * (i + 1)
            s = random.uniform(low, high)
            (idx, _priority, data) = self.tree.get(s)
            batch.append(data)
            idxs.append(idx)
        # NOTE: importance-sampling weights are not computed or returned;
        # beta is only advanced above.
        return batch, idxs

    def update(self, idx, error):
        """Recompute the priority of tree node ``idx`` from ``error``."""
        p = self._get_priority(error)
        self.tree.update(idx, p)

    def save(self, path):
        """Pickle the tree to ``path``."""
        # Context manager so the handle is closed even if pickling fails.
        with open(path, 'wb') as f:
            pickle.dump({"tree": self.tree}, f)

    def load_memory(self, path):
        """Replace the tree with the one pickled at ``path``."""
        # NOTE(review): pickle.load runs arbitrary code from the file; only
        # load files written by save() from a trusted location.
        with open(path, 'rb') as f:
            _memory = pickle.load(f)
        self.tree = _memory['tree']
#
# OtterTune - prioritized_replay_memory.py
#
# Copyright (c) 2017-18, Carnegie Mellon University Database Group
#
# from: https://github.com/KqSMea8/CDBTune
# Zhang, Ji, et al. "An end-to-end automatic cloud database tuning system using
# deep reinforcement learning." Proceedings of the 2019 International Conference
# on Management of Data. ACM, 2019
import random
import pickle
import numpy as np
class SumTree(object):
    """Binary sum tree stored in a flat array, used for proportional sampling.

    ``tree`` holds ``2 * capacity - 1`` nodes: the last ``capacity`` entries
    are leaf priorities and every internal node stores the sum of its two
    children, so ``tree[0]`` is the total priority. ``data`` holds the payload
    of each leaf and is overwritten in ring-buffer order.
    """

    # Class-level default for the next write slot (kept so instances created
    # without running __init__ still have the attribute).
    write = 0

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.num_entries = 0
        self.write = 0

    def _propagate(self, idx, change):
        """Add ``change`` to every ancestor of node ``idx``.

        Iterative on purpose: when ``idx`` is already the root (capacity 1)
        there is nothing to update, whereas the recursive form computed a
        parent of -1 and never terminated.
        """
        while idx != 0:
            idx = (idx - 1) // 2
            self.tree[idx] += change

    def _retrieve(self, idx, s):
        """Descend from ``idx`` to the leaf whose cumulative range contains ``s``."""
        while True:
            left = 2 * idx + 1
            if left >= len(self.tree):
                return idx
            if s <= self.tree[left]:
                idx = left
            else:
                # Skip the left subtree's mass and continue on the right.
                s = s - self.tree[left]
                idx = left + 1

    def total(self):
        """Return the sum of all leaf priorities."""
        return self.tree[0]

    def add(self, p, data):
        """Store ``data`` with priority ``p`` in the next slot, wrapping when full."""
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write += 1
        if self.write >= self.capacity:
            self.write = 0
        if self.num_entries < self.capacity:
            self.num_entries += 1

    def update(self, idx, p):
        """Set the priority of tree node ``idx`` to ``p`` and fix the sums above it."""
        change = p - self.tree[idx]

        self.tree[idx] = p
        self._propagate(idx, change)

    def get(self, s):
        """Return ``[tree index, priority, data]`` of the leaf selected by ``s``."""
        idx = self._retrieve(0, s)
        data_idx = idx - self.capacity + 1
        return [idx, self.tree[idx], self.data[data_idx]]
class PrioritizedReplayMemory(object):
    """Replay memory that samples transitions in proportion to their priority.

    Priorities are ``(error + e) ** a`` and are stored in a ``SumTree``. The
    tree can be persisted to a file (``save`` / ``load_memory``) or to bytes
    (``get`` / ``set``).
    """

    def __init__(self, capacity):
        self.tree = SumTree(capacity)
        self.capacity = capacity
        self.e = 0.01  # pylint: disable=invalid-name
        self.a = 0.6  # pylint: disable=invalid-name
        self.beta = 0.4
        self.beta_increment_per_sampling = 0.001

    def _get_priority(self, error):
        """Map an error to its priority, ``(error + e) ** a``."""
        return (error + self.e) ** self.a

    def add(self, error, sample):
        """Store ``sample`` with a priority derived from ``error``."""
        # (s, a, r, s, t)
        p = self._get_priority(error)
        self.tree.add(p, sample)

    def __len__(self):
        return self.tree.num_entries

    def sample(self, n):
        """Draw ``n`` samples, one from each equal-mass segment of the tree.

        Returns:
            tuple, (list of stored samples, list of their tree indices)
        """
        batch = []
        idxs = []
        segment = self.tree.total() / n

        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])

        for i in range(n):
            low = segment * i
            high = segment * (i + 1)
            s = random.uniform(low, high)
            (idx, _priority, data) = self.tree.get(s)
            batch.append(data)
            idxs.append(idx)
        # NOTE: importance-sampling weights are not computed or returned;
        # beta is only advanced above.
        return batch, idxs

    def update(self, idx, error):
        """Recompute the priority of tree node ``idx`` from ``error``."""
        p = self._get_priority(error)
        self.tree.update(idx, p)

    def save(self, path):
        """Pickle the tree to ``path``."""
        # Context manager so the handle is closed even if pickling fails.
        with open(path, 'wb') as f:
            pickle.dump({"tree": self.tree}, f)

    def load_memory(self, path):
        """Replace the tree with the one pickled at ``path``."""
        # NOTE(review): pickle.load runs arbitrary code from the file; only
        # load files written by save() from a trusted location.
        with open(path, 'rb') as f:
            _memory = pickle.load(f)
        self.tree = _memory['tree']

    def get(self):
        """Return the tree serialised as pickled bytes."""
        return pickle.dumps({"tree": self.tree})

    def set(self, binary):
        """Replace the tree with the one contained in pickled ``binary``."""
        # NOTE(review): pickle.loads runs arbitrary code from its input; only
        # pass bytes produced by get() from a trusted store.
        self.tree = pickle.loads(binary)['tree']

View File

@ -185,6 +185,9 @@ class Migration(migrations.Migration):
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(max_length=64, verbose_name=b'session name')),
('description', models.TextField(blank=True, null=True)),
('ddpg_actor_model', models.BinaryField(null=True, blank=True)),
('ddpg_critic_model', models.BinaryField(null=True, blank=True)),
('ddpg_reply_memory', models.BinaryField(null=True, blank=True)),
('creation_time', models.DateTimeField()),
('last_update', models.DateTimeField()),
('upload_code', models.CharField(max_length=30, unique=True)),

View File

@ -187,6 +187,9 @@ class Session(BaseModel):
hardware = models.ForeignKey(Hardware)
algorithm = models.IntegerField(choices=AlgorithmType.choices(),
default=AlgorithmType.OTTERTUNE)
ddpg_actor_model = models.BinaryField(null=True, blank=True)
ddpg_critic_model = models.BinaryField(null=True, blank=True)
ddpg_reply_memory = models.BinaryField(null=True, blank=True)
project = models.ForeignKey(Project)
creation_time = models.DateTimeField()

View File

@ -353,7 +353,7 @@ class BaseParser(object, metaclass=ABCMeta):
def format_enum(self, enum_value, metadata):
enumvals = metadata.enumvals.split(',')
return enumvals[enum_value]
return enumvals[int(round(enum_value))]
def format_integer(self, int_value, metadata):
return int(round(int_value))

View File

@ -35,7 +35,7 @@ MAX_TRAIN_SIZE = 7000
# Batch size in GPR model
BATCH_SIZE = 3000
# Threads for TensorFlow config
# Threads for TensorFlow config
NUM_THREADS = 4
# ---GRADIENT DESCENT CONSTANTS---
@ -54,3 +54,19 @@ DEFAULT_EPSILON = 1e-6
DEFAULT_SIGMA_MULTIPLIER = 3.0
DEFAULT_MU_MULTIPLIER = 1.0
# ---CONSTRAINTS CONSTANTS---
# Batch size in DDPG model (passed to the DDPG constructor as `batch_size`)
DDPG_BATCH_SIZE = 32
# Learning rate of actor network (passed to the DDPG constructor as `alr`)
ACTOR_LEARNING_RATE = 0.001
# Learning rate of critic network (passed to the DDPG constructor as `clr`)
CRITIC_LEARNING_RATE = 0.001
# The impact of future reward on the decision (passed to the DDPG constructor
# as `gamma`; presumably the RL discount factor -- TODO confirm in the DDPG model)
GAMMA = 0.1
# The changing rate of the target network (passed to the DDPG constructor as
# `tau`; presumably the soft-update coefficient -- TODO confirm in the DDPG model)
TAU = 0.002

View File

@ -7,7 +7,7 @@ from .async_tasks import (aggregate_target_results,
configuration_recommendation,
map_workload,
train_ddpg,
run_ddpg)
configuration_recommendation_ddpg)
from .periodic_tasks import (run_background_tasks)

View File

@ -5,14 +5,12 @@
#
import random
import queue
from os.path import dirname, abspath, join
import os
import numpy as np
from celery.task import task, Task
from celery.utils.log import get_task_logger
from djcelery.models import TaskMeta
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from analysis.ddpg.ddpg import DDPG
from analysis.gp import GPRNP
@ -29,7 +27,10 @@ from website.settings import (DEFAULT_LENGTH_SCALE, DEFAULT_MAGNITUDE,
MAX_TRAIN_SIZE, BATCH_SIZE, NUM_THREADS,
DEFAULT_RIDGE, DEFAULT_LEARNING_RATE,
DEFAULT_EPSILON, MAX_ITER, GPR_EPS,
DEFAULT_SIGMA_MULTIPLIER, DEFAULT_MU_MULTIPLIER)
DEFAULT_SIGMA_MULTIPLIER, DEFAULT_MU_MULTIPLIER,
DDPG_BATCH_SIZE, ACTOR_LEARNING_RATE,
CRITIC_LEARNING_RATE, GAMMA, TAU)
from website.settings import INIT_FLIP_PROB, FLIP_PROB_DECAY
from website.types import VarType
@ -235,10 +236,10 @@ def train_ddpg(result_id):
# Clean knob data
cleaned_agg_data = clean_knob_data(agg_data['X_matrix'], agg_data['X_columnlabels'], session)
agg_data['X_matrix'] = np.array(cleaned_agg_data[0]).flatten()
agg_data['X_columnlabels'] = np.array(cleaned_agg_data[1]).flatten()
knob_data = DataUtil.normalize_knob_data(agg_data['X_matrix'],
agg_data['X_columnlabels'], session)
knob_data = np.array(cleaned_agg_data[0])
knob_labels = np.array(cleaned_agg_data[1])
knob_bounds = np.vstack(DataUtil.get_knob_bounds(knob_labels.flatten(), session))
knob_data = MinMaxScaler().fit(knob_bounds).transform(knob_data)[0]
knob_num = len(knob_data)
metric_num = len(metric_data)
LOG.info('knob_num: %d, metric_num: %d', knob_num, metric_num)
@ -276,26 +277,23 @@ def train_ddpg(result_id):
* (2 * prev_objective - objective) / prev_objective
# Update ddpg
project_root = dirname(dirname(dirname(abspath(__file__))))
saved_memory = join(project_root, 'checkpoint/reply_memory_' + session.project.name)
saved_model = join(project_root, 'checkpoint/ddpg_' + session.project.name)
ddpg = DDPG(n_actions=knob_num, n_states=metric_num)
if os.path.exists(saved_memory):
ddpg.replay_memory.load_memory(saved_memory)
ddpg.load_model(saved_model)
ddpg = DDPG(n_actions=knob_num, n_states=metric_num, alr=ACTOR_LEARNING_RATE,
clr=CRITIC_LEARNING_RATE, gamma=GAMMA, batch_size=DDPG_BATCH_SIZE, tau=TAU)
if session.ddpg_actor_model and session.ddpg_critic_model:
ddpg.set_model(session.ddpg_actor_model, session.ddpg_critic_model)
if session.ddpg_reply_memory:
ddpg.replay_memory.set(session.ddpg_reply_memory)
ddpg.add_sample(prev_metric_data, knob_data, reward, metric_data, False)
if len(ddpg.replay_memory) > 32:
ddpg.update()
checkpoint_dir = join(project_root, 'checkpoint')
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
ddpg.replay_memory.save(saved_memory)
ddpg.save_model(saved_model)
session.ddpg_actor_model, session.ddpg_critic_model = ddpg.get_model()
session.ddpg_reply_memory = ddpg.replay_memory.get()
session.save()
return result_info
@task(base=ConfigurationRecommendation, name='run_ddpg')
def run_ddpg(result_info):
@task(base=ConfigurationRecommendation, name='configuration_recommendation_ddpg')
def configuration_recommendation_ddpg(result_info): # pylint: disable=invalid-name
LOG.info('Use ddpg to recommend configuration')
result_id = result_info['newest_result_id']
result = Result.objects.filter(pk=result_id)
@ -305,20 +303,20 @@ def run_ddpg(result_info):
cleaned_agg_data = clean_knob_data(agg_data['X_matrix'], agg_data['X_columnlabels'],
session)
knob_labels = np.array(cleaned_agg_data[1]).flatten()
knob_data = np.array(cleaned_agg_data[0]).flatten()
knob_num = len(knob_data)
knob_num = len(knob_labels)
metric_num = len(metric_data)
project_root = dirname(dirname(dirname(abspath(__file__))))
saved_memory = join(project_root, 'checkpoint/reply_memory_' + session.project.name)
saved_model = join(project_root, 'checkpoint/ddpg_' + session.project.name)
ddpg = DDPG(n_actions=knob_num, n_states=metric_num)
if os.path.exists(saved_memory):
ddpg.replay_memory.load_memory(saved_memory)
ddpg.load_model(saved_model)
ddpg = DDPG(n_actions=knob_num, n_states=metric_num, alr=ACTOR_LEARNING_RATE,
clr=CRITIC_LEARNING_RATE, gamma=GAMMA, batch_size=DDPG_BATCH_SIZE, tau=TAU)
if session.ddpg_actor_model is not None and session.ddpg_critic_model is not None:
ddpg.set_model(session.ddpg_actor_model, session.ddpg_critic_model)
if session.ddpg_reply_memory is not None:
ddpg.replay_memory.set(session.ddpg_reply_memory)
knob_data = ddpg.choose_action(metric_data)
LOG.info('recommended knob: %s', knob_data)
knob_data = DataUtil.denormalize_knob_data(knob_data, knob_labels, session)
knob_bounds = np.vstack(DataUtil.get_knob_bounds(knob_labels, session))
knob_data = MinMaxScaler().fit(knob_bounds).inverse_transform(knob_data.reshape(1, -1))[0]
conf_map = {k: knob_data[i] for i, k in enumerate(knob_labels)}
conf_map_res = {}
conf_map_res['status'] = 'good'

View File

@ -93,30 +93,35 @@ class TaskUtil(object):
class DataUtil(object):
@staticmethod
def normalize_knob_data(knob_values, knob_labels, session):
for i, knob in enumerate(knob_labels):
def get_knob_bounds(knob_labels, session):
minvals = []
maxvals = []
for _, knob in enumerate(knob_labels):
knob_object = KnobCatalog.objects.get(dbms=session.dbms, name=knob, tunable=True)
minval = float(knob_object.minval)
maxval = float(knob_object.maxval)
knob_new = SessionKnob.objects.filter(knob=knob_object, session=session, tunable=True)
if knob_new.exists():
minval = float(knob_new[0].minval)
maxval = float(knob_new[0].maxval)
knob_values[i] = (knob_values[i] - minval) / (maxval - minval)
knob_values[i] = max(0, min(knob_values[i], 1))
return knob_values
@staticmethod
def denormalize_knob_data(knob_values, knob_labels, session):
for i, knob in enumerate(knob_labels):
knob_object = KnobCatalog.objects.get(dbms=session.dbms, name=knob, tunable=True)
minval = float(knob_object.minval)
maxval = float(knob_object.maxval)
knob_session_object = SessionKnob.objects.filter(knob=knob_object, session=session,
tunable=True)
if knob_session_object.exists():
minval = float(knob_session_object[0].minval)
maxval = float(knob_session_object[0].maxval)
else:
minval = float(knob_object.minval)
maxval = float(knob_object.maxval)
minvals.append(minval)
maxvals.append(maxval)
return np.array(minvals), np.array(maxvals)
@staticmethod
def denormalize_knob_data(knob_values, knob_labels, session):
    """Rescale knob values in place from normalized form back to each knob's range.

    For every label the bounds come from the session's tunable ``SessionKnob``
    when one exists, otherwise from the ``KnobCatalog`` entry for the session's
    DBMS.  Each value is mapped as ``value * (maxval - minval) + minval``.

    Args:
        knob_values: indexable, mutable sequence of values aligned with ``knob_labels``.
        knob_labels: knob names to look up.
        session: session whose DBMS and knob overrides are used.

    Returns:
        The same ``knob_values`` object, modified in place.
    """
    for pos, label in enumerate(knob_labels):
        catalog_knob = KnobCatalog.objects.get(dbms=session.dbms, name=label, tunable=True)
        session_knobs = SessionKnob.objects.filter(knob=catalog_knob, session=session,
                                                   tunable=True)
        # Session-level bounds take precedence over the catalog defaults.
        bounds = session_knobs[0] if session_knobs.exists() else catalog_knob
        low = float(bounds.minval)
        high = float(bounds.maxval)
        knob_values[pos] = knob_values[pos] * (high - low) + low
    return knob_values

View File

@ -30,8 +30,8 @@ from .models import (BackupData, DBMSCatalog, KnobCatalog, KnobData, MetricCatal
MetricData, MetricManager, Project, Result, Session, Workload,
SessionKnob)
from .parser import Parser
from .tasks import (aggregate_target_results, map_workload, train_ddpg, run_ddpg,
configuration_recommendation)
from .tasks import (aggregate_target_results, map_workload, train_ddpg,
configuration_recommendation, configuration_recommendation_ddpg)
from .types import (DBMSType, KnobUnitType, MetricType,
TaskType, VarType, WorkloadStatusType, AlgorithmType)
from .utils import JSONUtil, LabelUtil, MediaUtil, TaskUtil