change the expected value in ddpg test
This commit is contained in:
parent
f071a0e62c
commit
090387a176
|
@ -17,27 +17,26 @@ class TestDDPG(unittest.TestCase):
|
|||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
np.random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
super(TestDDPG, cls).setUpClass()
|
||||
boston = datasets.load_boston()
|
||||
data = boston['data']
|
||||
X_train = data[0:500]
|
||||
cls.X_test = data[500:]
|
||||
X_test = data[500:]
|
||||
y_train = boston['target'][0:500].reshape(500, 1)
|
||||
cls.ddpg = DDPG(n_actions=1, n_states=13)
|
||||
ddpg = DDPG(n_actions=1, n_states=13)
|
||||
for i in range(500):
|
||||
knob_data = np.array([random.random()])
|
||||
prev_metric_data = X_train[i - 1]
|
||||
metric_data = X_train[i]
|
||||
reward = y_train[i - 1]
|
||||
cls.ddpg.add_sample(prev_metric_data, knob_data, reward, metric_data, False)
|
||||
if len(cls.ddpg.replay_memory) > 32:
|
||||
cls.ddpg.update()
|
||||
ddpg.add_sample(prev_metric_data, knob_data, reward, metric_data, False)
|
||||
if len(ddpg.replay_memory) > 32:
|
||||
ddpg.update()
|
||||
cls.ypreds_round = ['%.4f' % ddpg.choose_action(x)[0] for x in X_test]
|
||||
|
||||
def test_ddpg_ypreds(self):
|
||||
ypreds_round = [round(self.ddpg.choose_action(x)[0], 4) for x in self.X_test]
|
||||
expected_ypreds = [0.1778, 0.1914, 0.2607, 0.4459, 0.5660, 0.3836]
|
||||
for ypred_round, expected_ypred in zip(ypreds_round, expected_ypreds):
|
||||
self.assertAlmostEqual(ypred_round, expected_ypred, places=6)
|
||||
expected_ypreds = ['0.3169', '0.3240', '0.3934', '0.5787', '0.6988', '0.5163']
|
||||
self.assertEqual(self.ypreds_round, expected_ypreds)
|
||||
|
|
|
@ -299,7 +299,7 @@ def configuration_recommendation_ddpg(result_info): # pylint: disable=invalid-n
|
|||
metric_scalar = MinMaxScaler().fit(metric_data.reshape(1, -1))
|
||||
normalized_metric_data = metric_scalar.transform(metric_data.reshape(1, -1))[0]
|
||||
cleaned_knob_data = clean_knob_data(agg_data['X_matrix'], agg_data['X_columnlabels'],
|
||||
session)
|
||||
session)
|
||||
knob_labels = np.array(cleaned_knob_data[1]).flatten()
|
||||
knob_num = len(knob_labels)
|
||||
metric_num = len(metric_data)
|
||||
|
|
Loading…
Reference in New Issue