Add testing for ddpg model

parent 876d975713
commit f08204b9d6
@@ -198,7 +198,6 @@ class DDPG(object):
         else:
             # Build Network
             self._build_network()
-            LOG.info('Finish Initializing Networks')
 
         self.replay_memory = PrioritizedReplayMemory(capacity=memory_size)
         self.noise = OUProcess(n_actions)
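The context lines above show the DDPG constructor wiring up a prioritized replay buffer and an OUProcess exploration noise source. As a rough illustration only (a generic Ornstein-Uhlenbeck sketch, not OtterTune's OUProcess class; the mu/theta/sigma parameters are assumptions), this is the kind of temporally correlated noise DDPG adds to the actor's output while collecting samples:

```python
import numpy as np

# Generic Ornstein-Uhlenbeck noise sketch (not OtterTune's OUProcess;
# mu, theta, and sigma defaults are assumptions).
class SimpleOUNoise(object):
    def __init__(self, n_actions, mu=0.0, theta=0.15, sigma=0.2):
        self.mu = mu * np.ones(n_actions)
        self.theta = theta
        self.sigma = sigma
        self.state = np.copy(self.mu)

    def noise(self):
        # Drift back toward mu plus Gaussian diffusion; successive samples are
        # correlated, which helps exploration in continuous action spaces.
        dx = self.theta * (self.mu - self.state) + self.sigma * np.random.randn(len(self.state))
        self.state = self.state + dx
        return self.state
```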
@@ -0,0 +1,44 @@
+#
+# OtterTune - test_ddpg.py
+#
+# Copyright (c) 2017-18, Carnegie Mellon University Database Group
+#
+
+import random
+import unittest
+from sklearn import datasets
+import numpy as np
+import torch
+from analysis.ddpg.ddpg import DDPG
+
+
+# test ddpg model
+class TestDDPG(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        torch.manual_seed(0)
+        random.seed(0)
+        np.random.seed(0)
+        super(TestDDPG, cls).setUpClass()
+        boston = datasets.load_boston()
+        data = boston['data']
+        X_train = data[0:500]
+        cls.X_test = data[500:]
+        y_train = boston['target'][0:500].reshape(500, 1)
+        cls.ddpg = DDPG(n_actions=1, n_states=13)
+        for i in range(500):
+            knob_data = np.array([random.random()])
+            prev_metric_data = X_train[i - 1]
+            metric_data = X_train[i]
+            reward = y_train[i - 1]
+            cls.ddpg.add_sample(prev_metric_data, knob_data, reward, metric_data, False)
+            if len(cls.ddpg.replay_memory) > 32:
+                cls.ddpg.update()
+
+    def test_ddpg_ypreds(self):
+        ypreds_round = [round(self.ddpg.choose_action(x)[0], 4) for x in self.X_test]
+        expected_ypreds = [0.1778, 0.1914, 0.2607, 0.4459, 0.5660, 0.3836]
+        self.assertEqual(ypreds_round, expected_ypreds)
+        for ypred_round, expected_ypred in zip(ypreds_round, expected_ypreds):
+            self.assertAlmostEqual(ypred_round, expected_ypred, places=6)
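The expected predictions in test_ddpg_ypreds only hold because setUpClass fixes the torch, random, and numpy seeds before training, so the case is meant to run as an ordinary unittest. A minimal runner sketch, assuming the new file is importable as tests.test_ddpg (the actual package path is not shown in this diff):

```python
import unittest

# Minimal runner sketch; the dotted path 'tests.test_ddpg.TestDDPG' is an
# assumption -- substitute the package path where test_ddpg.py actually lives.
if __name__ == '__main__':
    suite = unittest.TestLoader().loadTestsFromName('tests.test_ddpg.TestDDPG')
    unittest.TextTestRunner(verbosity=2).run(suite)
```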
@@ -83,7 +83,7 @@ ACTOR_LEARNING_RATE = 0.001
 CRITIC_LEARNING_RATE = 0.001
 
 # The impact of future reward on the decision
-GAMMA = 0.1
+GAMMA = 0.9
 
 # The changing rate of the target network
 TAU = 0.002
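GAMMA and TAU are configuration constants: GAMMA discounts the future-value term in the critic's Bellman target (raising it from 0.1 to 0.9 makes future reward weigh much more heavily), while TAU sets how quickly the target networks track the learned ones. A generic DDPG-style sketch of where they enter, with function and argument names that are assumptions rather than OtterTune's actual code:

```python
# Generic DDPG-style sketch (names are assumptions, not OtterTune's code).
# reward and next_q are expected to be torch tensors; the nets are torch.nn.Modules.

def critic_target(reward, next_q, gamma=0.9):
    # Bellman target: immediate reward plus discounted estimate of future value.
    return reward + gamma * next_q

def soft_update(target_net, source_net, tau=0.002):
    # target <- tau * source + (1 - tau) * target, applied every update step,
    # so the target network changes slowly and stabilizes training.
    for t_param, s_param in zip(target_net.parameters(), source_net.parameters()):
        t_param.data.copy_(tau * s_param.data + (1.0 - tau) * t_param.data)
```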