Add machine learning model ddpg
parent b9dc726b9c
commit c83f2649b6
@@ -12,7 +12,7 @@ import sys
 import json
 import logging
 import time
-import os.path
+import os
 import re
 import glob
 from multiprocessing import Process
@@ -145,6 +145,8 @@ def signal_controller():

 @task
 def save_dbms_result():
+    if not os.path.exists(CONF['save_path']):
+        os.makedirs(CONF['save_path'])
     t = int(time.time())
     files = ['knobs.json', 'metrics_after.json', 'metrics_before.json', 'summary.json']
     for f_ in files:
@@ -0,0 +1,41 @@
+#
+# OUProcess.py
+#
+# Copyright
+#
+
+import numpy as np
+
+
+# from https://github.com/songrotek/DDPG/blob/master/ou_noise.py
+class OUProcess(object):
+
+    def __init__(self, n_actions, theta=0.15, mu=0, sigma=0.1):
+
+        self.n_actions = n_actions
+        self.theta = theta
+        self.mu = mu
+        self.sigma = sigma
+        self.current_value = np.ones(self.n_actions) * self.mu
+
+    def reset(self, sigma=0):
+        self.current_value = np.ones(self.n_actions) * self.mu
+        if sigma != 0:
+            self.sigma = sigma
+
+    def noise(self):
+        x = self.current_value
+        dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(len(x))
+        self.current_value = x + dx
+        return self.current_value
+
+
+if __name__ == '__main__':
+    import matplotlib.pyplot as plt  # pylint: disable=wrong-import-position
+    ou = OUProcess(3, theta=0.3)  # pylint: disable=invalid-name
+    states = []  # pylint: disable=invalid-name
+    for i in range(1000):
+        states.append(ou.noise())
+
+    plt.plot(states)
+    plt.show()
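
Note: noise() above is the unit-step Euler discretization of an Ornstein-Uhlenbeck process. As a sketch of the recursion being implemented (step size dt = 1):

    x_{t+1} = x_t + \theta (\mu - x_t) + \sigma \varepsilon_t,  \varepsilon_t ~ N(0, I)

With theta = 0.15 and mu = 0 the noise is mean-reverting toward zero but temporally correlated, the usual exploration noise for continuous-action DDPG.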
@@ -0,0 +1,10 @@
+#
+# __init__.py
+#
+# Copyright
+#
+
+
+from analysis.ddpg.ddpg import DDPG
+
+__all__ = ["DDPG"]
@@ -0,0 +1,509 @@
+#
+# ddpg.py
+#
+# Copyright
+#
+"""
+Deep Deterministic Policy Gradient Model
+
+"""
+
+import logging
+import os
+import sys
+import math
+import pickle
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import init, Parameter
+import torch.nn.functional as F
+import torch.optim as optimizer
+from torch.autograd import Variable
+
+from analysis.ddpg.OUProcess import OUProcess
+from analysis.ddpg.prioritized_replay_memory import PrioritizedReplayMemory
+
+LOG = logging.getLogger(__name__)
+
+sys.path.append('../')
+
+
+# code from https://github.com/Kaixhin/NoisyNet-A3C/blob/master/model.py
+class NoisyLinear(nn.Linear):
+    def __init__(self, in_features, out_features, sigma_init=0.05, bias=True):
+        super(NoisyLinear, self).__init__(in_features, out_features, bias=True)
+        # reuse self.weight and self.bias
+        self.sigma_init = sigma_init
+        self.sigma_weight = Parameter(torch.Tensor(out_features, in_features))
+        self.sigma_bias = Parameter(torch.Tensor(out_features))
+        self.register_buffer('epsilon_weight', torch.zeros(out_features, in_features))
+        self.register_buffer('epsilon_bias', torch.zeros(out_features))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        # Only init after all params added (otherwise super().__init__() fails)
+        if hasattr(self, 'sigma_weight'):
+            init.uniform(self.weight, -math.sqrt(3 / self.in_features),
+                         math.sqrt(3 / self.in_features))
+            init.uniform(self.bias, -math.sqrt(3 / self.in_features),
+                         math.sqrt(3 / self.in_features))
+            init.constant(self.sigma_weight, self.sigma_init)
+            init.constant(self.sigma_bias, self.sigma_init)
+
+    def forward(self, x):
+        return F.linear(x, self.weight + self.sigma_weight * Variable(self.epsilon_weight),
+                        self.bias + self.sigma_bias * Variable(self.epsilon_bias))
+
+    # pylint: disable=attribute-defined-outside-init
+    def sample_noise(self):
+        self.epsilon_weight = torch.randn(self.out_features, self.in_features)
+        self.epsilon_bias = torch.randn(self.out_features)
+
+    def remove_noise(self):
+        self.epsilon_weight = torch.zeros(self.out_features, self.in_features)
+        self.epsilon_bias = torch.zeros(self.out_features)
+    # pylint: enable=attribute-defined-outside-init
+
+
+class Normalizer(object):
+
+    def __init__(self, mean, variance):
+        if isinstance(mean, list):
+            mean = np.array(mean)
+        if isinstance(variance, list):
+            variance = np.array(variance)
+        self.mean = mean
+        self.std = np.sqrt(variance + 0.00001)
+
+    def normalize(self, x):
+        if isinstance(x, list):
+            x = np.array(x)
+        x = x - self.mean
+        x = x / self.std
+
+        return Variable(torch.FloatTensor(x))
+
+    def __call__(self, x, *args, **kwargs):
+        return self.normalize(x)
+
+
+class ActorLow(nn.Module):
+
+    def __init__(self, n_states, n_actions):
+        super(ActorLow, self).__init__()
+        self.layers = nn.Sequential(
+            nn.BatchNorm1d(n_states),
+            nn.Linear(n_states, 32),
+            nn.LeakyReLU(negative_slope=0.2),
+            nn.BatchNorm1d(32),
+            nn.Linear(32, n_actions),
+            nn.LeakyReLU(negative_slope=0.2)
+        )
+        self._init_weights()
+        self.out_func = nn.Tanh()
+
+    def _init_weights(self):
+        for m in self.layers:
+            if isinstance(m, nn.Linear):
+                m.weight.data.normal_(0.0, 1e-3)
+                m.bias.data.uniform_(-0.1, 0.1)
+
+    def forward(self, x):  # pylint: disable=arguments-differ
+        out = self.layers(x)
+        return self.out_func(out)
+
+
+class CriticLow(nn.Module):
+
+    def __init__(self, n_states, n_actions):
+        super(CriticLow, self).__init__()
+        self.state_input = nn.Linear(n_states, 32)
+        self.action_input = nn.Linear(n_actions, 32)
+        self.act = nn.LeakyReLU(negative_slope=0.2)
+        self.state_bn = nn.BatchNorm1d(n_states)
+        self.layers = nn.Sequential(
+            nn.Linear(64, 1),
+            nn.LeakyReLU(negative_slope=0.2),
+        )
+        self._init_weights()
+
+    def _init_weights(self):
+        self.state_input.weight.data.normal_(0.0, 1e-3)
+        self.state_input.bias.data.uniform_(-0.1, 0.1)
+
+        self.action_input.weight.data.normal_(0.0, 1e-3)
+        self.action_input.bias.data.uniform_(-0.1, 0.1)
+
+        for m in self.layers:
+            if isinstance(m, nn.Linear):
+                m.weight.data.normal_(0.0, 1e-3)
+                m.bias.data.uniform_(-0.1, 0.1)
+
+    def forward(self, x, action):  # pylint: disable=arguments-differ
+        x = self.state_bn(x)
+        x = self.act(self.state_input(x))
+        action = self.act(self.action_input(action))
+
+        _input = torch.cat([x, action], dim=1)
+        value = self.layers(_input)
+        return value
+
+
+class Actor(nn.Module):
+
+    def __init__(self, n_states, n_actions, noisy=False):
+        super(Actor, self).__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(n_states, 128),
+            nn.LeakyReLU(negative_slope=0.2),
+            nn.BatchNorm1d(128),
+            nn.Linear(128, 128),
+            nn.Tanh(),
+            nn.Dropout(0.3),
+
+            nn.Linear(128, 64),
+            nn.Tanh(),
+            nn.BatchNorm1d(64),
+        )
+        if noisy:
+            self.out = NoisyLinear(64, n_actions)
+        else:
+            self.out = nn.Linear(64, n_actions)
+        self._init_weights()
+        self.act = nn.Sigmoid()
+
+    def _init_weights(self):
+        for m in self.layers:
+            if isinstance(m, nn.Linear):
+                m.weight.data.normal_(0.0, 1e-2)
+                m.bias.data.uniform_(-0.1, 0.1)
+
+    def sample_noise(self):
+        self.out.sample_noise()
+
+    def forward(self, x):  # pylint: disable=arguments-differ
+        out = self.act(self.out(self.layers(x)))
+        return out
+
+
+class Critic(nn.Module):
+
+    def __init__(self, n_states, n_actions):
+        super(Critic, self).__init__()
+        self.state_input = nn.Linear(n_states, 128)
+        self.action_input = nn.Linear(n_actions, 128)
+        self.act = nn.Tanh()
+        self.layers = nn.Sequential(
+            nn.Linear(256, 256),
+            nn.LeakyReLU(negative_slope=0.2),
+            nn.BatchNorm1d(256),
+
+            nn.Linear(256, 64),
+            nn.Tanh(),
+            nn.Dropout(0.3),
+            nn.BatchNorm1d(64),
+            nn.Linear(64, 1),
+        )
+        self._init_weights()
+
+    def _init_weights(self):
+        self.state_input.weight.data.normal_(0.0, 1e-2)
+        self.state_input.bias.data.uniform_(-0.1, 0.1)
+
+        self.action_input.weight.data.normal_(0.0, 1e-2)
+        self.action_input.bias.data.uniform_(-0.1, 0.1)
+
+        for m in self.layers:
+            if isinstance(m, nn.Linear):
+                m.weight.data.normal_(0.0, 1e-2)
+                m.bias.data.uniform_(-0.1, 0.1)
+
+    def forward(self, x, action):  # pylint: disable=arguments-differ
+        x = self.act(self.state_input(x))
+        action = self.act(self.action_input(action))
+
+        _input = torch.cat([x, action], dim=1)
+        value = self.layers(_input)
+        return value
+
+
+class DDPG(object):
+
+    def __init__(self, n_states, n_actions, opt=None, ouprocess=True, mean_var_path=None,
+                 supervised=False):
+        """ DDPG Algorithms
+        Args:
+            n_states: int, dimension of states
+            n_actions: int, dimension of actions
+            opt: dict, params
+            supervised: bool, pre-train the actor with supervised learning
+        """
+        self.n_states = n_states
+        self.n_actions = n_actions
+
+        if opt is None:
+            opt = {
+                'model': '',
+                'alr': 0.001,
+                'clr': 0.001,
+                'gamma': 0.9,
+                'batch_size': 32,
+                'tau': 0.002,
+                'memory_size': 100000
+            }
+
+        # Params
+        self.alr = opt['alr']
+        self.clr = opt['clr']
+        self.model_name = opt['model']
+        self.batch_size = opt['batch_size']
+        self.gamma = opt['gamma']
+        self.tau = opt['tau']
+        self.ouprocess = ouprocess
+
+        if mean_var_path is None:
+            mean = np.zeros(n_states)
+            var = np.zeros(n_states)
+        elif not os.path.exists(mean_var_path):
+            mean = np.zeros(n_states)
+            var = np.zeros(n_states)
+        else:
+            with open(mean_var_path, 'rb') as f:
+                mean, var = pickle.load(f)
+
+        self.normalizer = Normalizer(mean, var)
+
+        if supervised:
+            self._build_actor()
+            LOG.info("Supervised Learning Initialized")
+        else:
+            # Build Network
+            self._build_network()
+            LOG.info('Finish Initializing Networks')
+
+        self.replay_memory = PrioritizedReplayMemory(capacity=opt['memory_size'])
+        self.noise = OUProcess(n_actions)
+        # LOG.info('DDPG Initialized!')
+
+    @staticmethod
+    def totensor(x):
+        return Variable(torch.FloatTensor(x))
+
+    def _build_actor(self):
+        if self.ouprocess:
+            noisy = False
+        else:
+            noisy = True
+        self.actor = Actor(self.n_states, self.n_actions, noisy=noisy)
+        self.actor_criterion = nn.MSELoss()
+        self.actor_optimizer = optimizer.Adam(lr=self.alr, params=self.actor.parameters())
+
+    def _build_network(self):
+        if self.ouprocess:
+            noisy = False
+        else:
+            noisy = True
+        self.actor = Actor(self.n_states, self.n_actions, noisy=noisy)
+        self.target_actor = Actor(self.n_states, self.n_actions)
+        self.critic = Critic(self.n_states, self.n_actions)
+        self.target_critic = Critic(self.n_states, self.n_actions)
+
+        # if model params are provided, load them
+        if len(self.model_name):
+            self.load_model(model_name=self.model_name)
+            LOG.info("Loading model from file: %s", self.model_name)
+
+        # Copy actor's parameters
+        self._update_target(self.target_actor, self.actor, tau=1.0)
+
+        # Copy critic's parameters
+        self._update_target(self.target_critic, self.critic, tau=1.0)
+
+        self.loss_criterion = nn.MSELoss()
+        self.actor_optimizer = optimizer.Adam(lr=self.alr, params=self.actor.parameters(),
+                                              weight_decay=1e-5)
+        self.critic_optimizer = optimizer.Adam(lr=self.clr, params=self.critic.parameters(),
+                                               weight_decay=1e-5)
+
+    @staticmethod
+    def _update_target(target, source, tau):
+        for (target_param, param) in zip(target.parameters(), source.parameters()):
+            target_param.data.copy_(
+                target_param.data * (1 - tau) + param.data * tau
+            )
+
+    def reset(self, sigma):
+        self.noise.reset(sigma)
+
+    def _sample_batch(self):
+        batch, idx = self.replay_memory.sample(self.batch_size)
+        # batch = self.replay_memory.sample(self.batch_size)
+        states = list(map(lambda x: x[0].tolist(), batch))  # pylint: disable=W0141
+        next_states = list(map(lambda x: x[3].tolist(), batch))  # pylint: disable=W0141
+        actions = list(map(lambda x: x[1].tolist(), batch))  # pylint: disable=W0141
+        rewards = list(map(lambda x: x[2], batch))  # pylint: disable=W0141
+        terminates = list(map(lambda x: x[4], batch))  # pylint: disable=W0141
+
+        return idx, states, next_states, actions, rewards, terminates
+
+    def add_sample(self, state, action, reward, next_state, terminate):
+        self.critic.eval()
+        self.actor.eval()
+        self.target_critic.eval()
+        self.target_actor.eval()
+        batch_state = self.normalizer([state.tolist()])
+        batch_next_state = self.normalizer([next_state.tolist()])
+        current_value = self.critic(batch_state, self.totensor([action.tolist()]))
+        target_action = self.target_actor(batch_next_state)
+        target_value = self.totensor([reward]) \
+            + self.totensor([0 if x else 1 for x in [terminate]]) \
+            * self.target_critic(batch_next_state, target_action) * self.gamma
+        error = float(torch.abs(current_value - target_value).data.numpy()[0])
+
+        self.target_actor.train()
+        self.actor.train()
+        self.critic.train()
+        self.target_critic.train()
+        self.replay_memory.add(error, (state, action, reward, next_state, terminate))
+
+    def update(self):
+        idxs, states, next_states, actions, rewards, terminates = self._sample_batch()
+        batch_states = self.normalizer(states)
+        batch_next_states = self.normalizer(next_states)
+        batch_actions = self.totensor(actions)
+        batch_rewards = self.totensor(rewards)
+        mask = [0 if x else 1 for x in terminates]
+        mask = self.totensor(mask)
+
+        target_next_actions = self.target_actor(batch_next_states).detach()
+        target_next_value = self.target_critic(batch_next_states, target_next_actions).detach()
+        current_value = self.critic(batch_states, batch_actions)
+        # TODO (dongshen): the line below is the original update rule, but it has some mistakes
+        # next_value = batch_rewards + mask * target_next_value * self.gamma
+        # Since terminate is always false, I remove the mask here.
+        next_value = batch_rewards + target_next_value * self.gamma
+        # Update Critic
+
+        # update prioritized memory
+        error = torch.abs(current_value - next_value).data.numpy()
+        for i in range(self.batch_size):
+            idx = idxs[i]
+            self.replay_memory.update(idx, error[i][0])
+
+        loss = self.loss_criterion(current_value, next_value)
+        self.critic_optimizer.zero_grad()
+        loss.backward()
+        self.critic_optimizer.step()
+
+        # Update Actor
+        self.critic.eval()
+        policy_loss = -self.critic(batch_states, self.actor(batch_states))
+        policy_loss = policy_loss.mean()
+        self.actor_optimizer.zero_grad()
+        policy_loss.backward()
+
+        self.actor_optimizer.step()
+        self.critic.train()
+
+        self._update_target(self.target_critic, self.critic, tau=self.tau)
+        self._update_target(self.target_actor, self.actor, tau=self.tau)
+
+        return loss.data, policy_loss.data
+
+    def choose_action(self, x):
+        """ Select an action according to the current state
+        Args:
+            x: np.array, current state
+        """
+        self.actor.eval()
+        act = self.actor(self.normalizer([x.tolist()])).squeeze(0)
+        self.actor.train()
+        action = act.data.numpy()
+        if self.ouprocess:
+            action += self.noise.noise()
+        return action.clip(0, 1)
+
+    def sample_noise(self):
+        self.actor.sample_noise()
+
+    def load_model(self, model_name):
+        """ Load torch model from files
+        Args:
+            model_name: str, model path
+        """
+        self.actor.load_state_dict(
+            torch.load('{}_actor.pth'.format(model_name))
+        )
+        self.critic.load_state_dict(
+            torch.load('{}_critic.pth'.format(model_name))
+        )
+
+    def save_model(self, model_name):
+        """ Save torch model to files
+        Args:
+            model_name: str, model path prefix
+        """
+        torch.save(
+            self.actor.state_dict(),
+            '{}_actor.pth'.format(model_name)
+        )
+
+        torch.save(
+            self.critic.state_dict(),
+            '{}_critic.pth'.format(model_name)
+        )
+
+    def save_actor(self, path):
+        """ Save the actor network
+        Args:
+            path: str, path to save
+        """
+        torch.save(
+            self.actor.state_dict(),
+            path
+        )
+
+    def load_actor(self, path):
+        """ Load the actor network
+        Args:
+            path: str, path to load
+        """
+        self.actor.load_state_dict(
+            torch.load(path)
+        )
+
+    def train_actor(self, batch_data, is_train=True):
+        """ Train the actor separately with data
+        Args:
+            batch_data: tuple, (states, actions)
+            is_train: bool
+        Return:
+            _loss: float, training loss
+        """
+        states, action = batch_data
+
+        if is_train:
+            self.actor.train()
+            pred = self.actor(self.normalizer(states))
+            action = self.totensor(action)
+
+            _loss = self.actor_criterion(pred, action)
+
+            self.actor_optimizer.zero_grad()
+            _loss.backward()
+            self.actor_optimizer.step()
+
+        else:
+            self.actor.eval()
+            pred = self.actor(self.normalizer(states))
+            action = self.totensor(action)
+            _loss = self.actor_criterion(pred, action)
+
+        return _loss.data[0]
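
Note: a minimal usage sketch of the DDPG class above, assuming states and actions arrive as 1-D numpy arrays (which is how the Celery tasks later in this commit call it); the dimensions and reward value here are made up:

    import numpy as np
    from analysis.ddpg.ddpg import DDPG

    ddpg = DDPG(n_states=10, n_actions=3)        # hypothetical: 10 metrics, 3 knobs
    state = np.random.rand(10)                   # current metric vector
    knobs = ddpg.choose_action(state)            # sigmoid output plus OU noise, clipped to [0, 1]
    reward = 0.5                                 # hypothetical reward for this observation
    next_state = np.random.rand(10)
    ddpg.add_sample(state, knobs, reward, next_state, False)
    if len(ddpg.replay_memory) > ddpg.batch_size:
        critic_loss, actor_loss = ddpg.update()  # one prioritized mini-batch step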
@@ -0,0 +1,121 @@
+#
+# prioritized_replay_memory.py
+#
+# Copyright
+#
+import random
+import pickle
+import numpy as np
+
+
+class SumTree(object):
+    write = 0
+
+    def __init__(self, capacity):
+        self.capacity = capacity
+        self.tree = np.zeros(2 * capacity - 1)
+        self.data = np.zeros(capacity, dtype=object)
+        self.num_entries = 0
+
+    def _propagate(self, idx, change):
+        parent = (idx - 1) // 2
+        self.tree[parent] += change
+        if parent != 0:
+            self._propagate(parent, change)
+
+    def _retrieve(self, idx, s):
+        left = 2 * idx + 1
+        right = left + 1
+
+        if left >= len(self.tree):
+            return idx
+
+        if s <= self.tree[left]:
+            return self._retrieve(left, s)
+        else:
+            return self._retrieve(right, s - self.tree[left])
+
+    def total(self):
+        return self.tree[0]
+
+    def add(self, p, data):
+        idx = self.write + self.capacity - 1
+
+        self.data[self.write] = data
+        self.update(idx, p)
+
+        self.write += 1
+        if self.write >= self.capacity:
+            self.write = 0
+        if self.num_entries < self.capacity:
+            self.num_entries += 1
+
+    def update(self, idx, p):
+        change = p - self.tree[idx]
+
+        self.tree[idx] = p
+        self._propagate(idx, change)
+
+    def get(self, s):
+        idx = self._retrieve(0, s)
+        data_idx = idx - self.capacity + 1
+        return [idx, self.tree[idx], self.data[data_idx]]
+
+
+class PrioritizedReplayMemory(object):
+
+    def __init__(self, capacity):
+        self.tree = SumTree(capacity)
+        self.capacity = capacity
+        self.e = 0.01  # pylint: disable=invalid-name
+        self.a = 0.6  # pylint: disable=invalid-name
+        self.beta = 0.4
+        self.beta_increment_per_sampling = 0.001
+
+    def _get_priority(self, error):
+        return (error + self.e) ** self.a
+
+    def add(self, error, sample):
+        # (s, a, r, s, t)
+        p = self._get_priority(error)
+        self.tree.add(p, sample)
+
+    def __len__(self):
+        return self.tree.num_entries
+
+    def sample(self, n):
+        batch = []
+        idxs = []
+        segment = self.tree.total() / n
+        priorities = []
+
+        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])
+
+        for i in range(n):
+            a = segment * i
+            b = segment * (i + 1)
+
+            s = random.uniform(a, b)
+            (idx, p, data) = self.tree.get(s)
+            priorities.append(p)
+            batch.append(data)
+            idxs.append(idx)
+        return batch, idxs
+
+        # sampling_probabilities = priorities / self.tree.total()
+        # is_weight = np.power(self.tree.num_entries * sampling_probabilities, -self.beta)
+        # is_weight /= is_weight.max()
+
+    def update(self, idx, error):
+        p = self._get_priority(error)
+        self.tree.update(idx, p)
+
+    def save(self, path):
+        f = open(path, 'wb')
+        pickle.dump({"tree": self.tree}, f)
+        f.close()
+
+    def load_memory(self, path):
+        with open(path, 'rb') as f:
+            _memory = pickle.load(f)
+        self.tree = _memory['tree']
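
Note: this is the standard proportional prioritized-replay scheme: each transition is stored in the sum-tree with priority p = (|TD error| + e)^a, and sample(n) splits the total priority mass into n equal segments and draws one transition per segment, so high-error transitions are replayed more often (the importance-sampling weights are left commented out, sitting after the return). A worked example with hypothetical TD errors:

    # TD errors 0.0, 1.0, 7.0 with e=0.01, a=0.6 give priorities
    # 0.01**0.6 ~= 0.063, 1.01**0.6 ~= 1.006, 7.01**0.6 ~= 3.217 (total ~= 4.286),
    # so the third transition covers ~75% of the mass and SumTree._retrieve()
    # returns it for roughly three of every four uniform draws.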
@@ -5,7 +5,9 @@
 #
 from .async_tasks import (aggregate_target_results,
                           configuration_recommendation,
-                          map_workload)
+                          map_workload,
+                          train_ddpg,
+                          run_ddpg)


 from .periodic_tasks import (run_background_tasks)
@@ -5,6 +5,8 @@
 #
 import random
 import queue
+from os.path import dirname, abspath, join
+import os
 import numpy as np

 from celery.task import task, Task
@@ -12,6 +14,7 @@ from celery.utils.log import get_task_logger
 from djcelery.models import TaskMeta
 from sklearn.preprocessing import StandardScaler

+from analysis.ddpg.ddpg import DDPG
 from analysis.gp import GPRNP
 from analysis.gp_tf import GPRGD
 from analysis.preprocessing import Bin, DummyEncoder
@@ -30,6 +33,7 @@ from website.settings import (DEFAULT_LENGTH_SCALE, DEFAULT_MAGNITUDE,
 from website.settings import INIT_FLIP_PROB, FLIP_PROB_DECAY
 from website.types import VarType
+

 LOG = get_task_logger(__name__)

@@ -41,6 +45,17 @@ class UpdateTask(Task):  # pylint: disable=abstract-method
         self.default_retry_delay = 60


+class TrainDDPG(UpdateTask):  # pylint: disable=abstract-method
+    def on_success(self, retval, task_id, args, kwargs):
+        super(TrainDDPG, self).on_success(retval, task_id, args, kwargs)
+
+        # Completely delete this result because it's huge and not
+        # interesting
+        task_meta = TaskMeta.objects.get(task_id=task_id)
+        task_meta.result = None
+        task_meta.save()
+
+
 class AggregateTargetResults(UpdateTask):  # pylint: disable=abstract-method

     def on_success(self, retval, task_id, args, kwargs):
@@ -194,6 +209,126 @@ def gen_random_data(knobs):
     return random_knob_result


+@task(base=TrainDDPG, name='train_ddpg')
+def train_ddpg(result_id):
+    LOG.info('Add training data to ddpg and train ddpg')
+    result = Result.objects.get(pk=result_id)
+    session = Result.objects.get(pk=result_id).session
+    session_results = Result.objects.filter(session=session,
+                                            creation_time__lt=result.creation_time)
+    result_info = {}
+    result_info['newest_result_id'] = result_id
+    if len(session_results) == 0:
+        LOG.info('No previous result. Abort.')
+        return result_info
+    prev_result_id = session_results[len(session_results) - 1].pk
+    base_result_id = session_results[0].pk
+    prev_result = Result.objects.filter(pk=prev_result_id)
+    base_result = Result.objects.filter(pk=base_result_id)
+
+    # Extract data from result
+    result = Result.objects.filter(pk=result_id)
+    agg_data = DataUtil.aggregate_data(result)
+    metric_data = agg_data['y_matrix'].flatten()
+    prev_metric_data = (DataUtil.aggregate_data(prev_result))['y_matrix'].flatten()
+    base_metric_data = (DataUtil.aggregate_data(base_result))['y_matrix'].flatten()
+
+    # Clean knob data
+    cleaned_agg_data = clean_knob_data(agg_data['X_matrix'], agg_data['X_columnlabels'], session)
+    agg_data['X_matrix'] = np.array(cleaned_agg_data[0]).flatten()
+    agg_data['X_columnlabels'] = np.array(cleaned_agg_data[1]).flatten()
+    knob_data = DataUtil.normalize_knob_data(agg_data['X_matrix'],
+                                             agg_data['X_columnlabels'], session)
+    knob_num = len(knob_data)
+    metric_num = len(metric_data)
+    LOG.info('knob_num: %d, metric_num: %d', knob_num, metric_num)
+
+    # Filter ys by current target objective metric
+    result = Result.objects.get(pk=result_id)
+    target_objective = result.session.target_objective
+    target_obj_idx = [i for i, n in enumerate(agg_data['y_columnlabels']) if n == target_objective]
+    if len(target_obj_idx) == 0:
+        raise Exception(('Could not find target objective in metrics '
+                         '(target_obj={})').format(target_objective))
+    elif len(target_obj_idx) > 1:
+        raise Exception(('Found {} instances of target objective in '
+                         'metrics (target_obj={})').format(len(target_obj_idx),
+                                                           target_objective))
+    objective = metric_data[target_obj_idx]
+    prev_objective = prev_metric_data[target_obj_idx]
+    base_objective = base_metric_data[target_obj_idx]
+    metric_meta = MetricCatalog.objects.get_metric_meta(result.session.dbms,
+                                                        result.session.target_objective)
+
+    # Calculate the reward
+    reward = 0
+    if metric_meta[target_objective].improvement == '(less is better)':
+        if objective - base_objective <= 0:
+            reward = -(np.square(objective / base_objective) - 1) * objective / prev_objective
+        else:
+            reward = (np.square((2 * base_objective - objective) / base_objective) - 1)\
+                * (2 * prev_objective - objective) / prev_objective
+    else:
+        if objective - base_objective > 0:
+            reward = (np.square(objective / base_objective) - 1) * objective / prev_objective
+        else:
+            reward = -(np.square((2 * base_objective - objective) / base_objective) - 1)\
+                * (2 * prev_objective - objective) / prev_objective
+
+    # Update ddpg
+    project_root = dirname(dirname(dirname(abspath(__file__))))
+    saved_memory = join(project_root, 'checkpoint/reply_memory_' + session.project.name)
+    saved_model = join(project_root, 'checkpoint/ddpg_' + session.project.name)
+    ddpg = DDPG(n_actions=knob_num, n_states=metric_num)
+    if os.path.exists(saved_memory):
+        ddpg.replay_memory.load_memory(saved_memory)
+        ddpg.load_model(saved_model)
+    ddpg.add_sample(prev_metric_data, knob_data, reward, metric_data, False)
+    if len(ddpg.replay_memory) > 32:
+        ddpg.update()
+    checkpoint_dir = join(project_root, 'checkpoint')
+    if not os.path.exists(checkpoint_dir):
+        os.makedirs(checkpoint_dir)
+    ddpg.replay_memory.save(saved_memory)
+    ddpg.save_model(saved_model)
+    return result_info
+
+
+@task(base=ConfigurationRecommendation, name='run_ddpg')
+def run_ddpg(result_info):
+    LOG.info('Use ddpg to recommend configuration')
+    result_id = result_info['newest_result_id']
+    result = Result.objects.filter(pk=result_id)
+    session = Result.objects.get(pk=result_id).session
+    agg_data = DataUtil.aggregate_data(result)
+    metric_data = agg_data['y_matrix'].flatten()
+    cleaned_agg_data = clean_knob_data(agg_data['X_matrix'], agg_data['X_columnlabels'],
+                                       session)
+    knob_labels = np.array(cleaned_agg_data[1]).flatten()
+    knob_data = np.array(cleaned_agg_data[0]).flatten()
+    knob_num = len(knob_data)
+    metric_num = len(metric_data)
+
+    project_root = dirname(dirname(dirname(abspath(__file__))))
+    saved_memory = join(project_root, 'checkpoint/reply_memory_' + session.project.name)
+    saved_model = join(project_root, 'checkpoint/ddpg_' + session.project.name)
+    ddpg = DDPG(n_actions=knob_num, n_states=metric_num)
+    if os.path.exists(saved_memory):
+        ddpg.replay_memory.load_memory(saved_memory)
+        ddpg.load_model(saved_model)
+    knob_data = ddpg.choose_action(metric_data)
+    LOG.info('recommended knob: %s', knob_data)
+    knob_data = DataUtil.denormalize_knob_data(knob_data, knob_labels, session)
+    conf_map = {k: knob_data[i] for i, k in enumerate(knob_labels)}
+    conf_map_res = {}
+    conf_map_res['status'] = 'good'
+    conf_map_res['recommendation'] = conf_map
+    conf_map_res['info'] = 'INFO: ddpg'
+    for k in knob_labels:
+        LOG.info('%s: %f', k, conf_map[k])
+    return conf_map_res
+
+
 @task(base=ConfigurationRecommendation, name='configuration_recommendation')
 def configuration_recommendation(target_data):
     LOG.info('configuration_recommendation called')
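
Note on the reward in train_ddpg above: the newest objective value y is scored against the session's first result y_base and its previous result y_prev, with the branches mirrored for less-is-better metrics. In the more-is-better case, an improvement over the baseline earns

    reward = ((y / y_base)^2 - 1) * (y / y_prev)

For example (hypothetical throughputs), y_base = 100, y_prev = 110, y = 120 gives (1.2^2 - 1) * (120/110) = 0.44 * 1.0909 ≈ 0.48.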
@@ -65,6 +65,9 @@ urlpatterns = [
     # Back door
     url(r'^query_and_get/(?P<upload_code>[0-9a-zA-Z]+)$', website_views.give_result, name="backdoor"),

+    # train ddpg with results in the given session
+    url(r'^train_ddpg/sessions/(?P<session_id>[0-9]+)$', website_views.train_ddpg_loops, name='train_ddpg_loops'),
 ]

 if settings.DEBUG:
@@ -20,7 +20,7 @@ from django.utils.text import capfirst
 from djcelery.models import TaskMeta

 from .types import LabelStyleType, VarType
-from .models import KnobCatalog, DBMSCatalog
+from .models import KnobCatalog, DBMSCatalog, SessionKnob

 LOG = logging.getLogger(__name__)

@@ -92,6 +92,34 @@ class TaskUtil(object):

 class DataUtil(object):

+    @staticmethod
+    def normalize_knob_data(knob_values, knob_labels, session):
+        for i, knob in enumerate(knob_labels):
+            knob_object = KnobCatalog.objects.get(dbms=session.dbms, name=knob, tunable=True)
+            minval = float(knob_object.minval)
+            maxval = float(knob_object.maxval)
+            knob_new = SessionKnob.objects.filter(knob=knob_object, session=session, tunable=True)
+            if knob_new.exists():
+                minval = float(knob_new[0].minval)
+                maxval = float(knob_new[0].maxval)
+            knob_values[i] = (knob_values[i] - minval) / (maxval - minval)
+            knob_values[i] = max(0, min(knob_values[i], 1))
+        return knob_values
+
+    @staticmethod
+    def denormalize_knob_data(knob_values, knob_labels, session):
+        for i, knob in enumerate(knob_labels):
+            knob_object = KnobCatalog.objects.get(dbms=session.dbms, name=knob, tunable=True)
+            minval = float(knob_object.minval)
+            maxval = float(knob_object.maxval)
+            knob_session_object = SessionKnob.objects.filter(knob=knob_object, session=session,
+                                                             tunable=True)
+            if knob_session_object.exists():
+                minval = float(knob_session_object[0].minval)
+                maxval = float(knob_session_object[0].maxval)
+            knob_values[i] = knob_values[i] * (maxval - minval) + minval
+        return knob_values
+
     @staticmethod
     def aggregate_data(results):
         knob_labels = list(JSONUtil.loads(results[0].knob_data.data).keys())
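
Note: normalize_knob_data / denormalize_knob_data above are plain min-max scaling against each knob's [minval, maxval] range (a session-specific SessionKnob range overrides the catalog one), with the normalized value clamped to [0, 1]. A worked round trip with a hypothetical knob range:

    # knob with minval=128, maxval=1024:
    # normalize:   (512 - 128) / (1024 - 128) = 384 / 896 ~= 0.4286
    # denormalize: 0.4286 * (1024 - 128) + 128 ~= 512  (raw value recovered)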
@@ -30,7 +30,7 @@ from .models import (BackupData, DBMSCatalog, KnobCatalog, KnobData, MetricCatalog,
                      MetricData, MetricManager, Project, Result, Session, Workload,
                      SessionKnob)
 from .parser import Parser
-from .tasks import (aggregate_target_results, map_workload,
+from .tasks import (aggregate_target_results, map_workload, train_ddpg, run_ddpg,
                     configuration_recommendation)
 from .types import (DBMSType, KnobUnitType, MetricType,
                     TaskType, VarType, WorkloadStatusType, AlgorithmType)
@@ -967,3 +967,11 @@ def give_result(request, upload_code):  # pylint: disable=unused-argument
     # success
     res = Result.objects.get(pk=lastest_result.pk)
     return HttpResponse(JSONUtil.dumps(res.next_configuration), content_type='application/json')
+
+
+def train_ddpg_loops(request, session_id):  # pylint: disable=unused-argument
+    session = get_object_or_404(Session, pk=session_id, user=request.user)  # pylint: disable=unused-variable
+    results = Result.objects.filter(session=session_id)
+    for result in results:
+        train_ddpg(result.pk)
+    return HttpResponse()
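
Note: together, the route and view added above form a debugging endpoint that replays every stored result of a session through train_ddpg, calling the task synchronously rather than via .delay(). A hypothetical invocation against a local dev server (host and session id made up):

    curl http://localhost:8000/train_ddpg/sessions/1    # returns an empty 200 response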