ottertune/server/analysis/tests/test_ddpg.py

#
# OtterTune - test_ddpg.py
#
# Copyright (c) 2017-18, Carnegie Mellon University Database Group
#
import random
import unittest
import numpy as np
import torch
from analysis.ddpg.ddpg import DDPG

# test ddpg model:
# The environment has a 1-dim state and a 1-dim action; the reward is calculated as follows:
# if state < 0.5, taking action < 0.5 gets reward 1, taking action >= 0.5 gets reward 0
# if state >= 0.5, taking action >= 0.5 gets reward 1, taking action < 0.5 gets reward 0
# Train on 100 sampled transitions (10 model updates each), then test for 500 iterations
# If the average reward during the test is larger than 0.5, this test passes
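

# A minimal sketch of the reward rule described above, kept as a standalone helper
# for readability. The name `reference_reward` is illustrative only and is not part
# of the original test; the loops below inline the same expression instead of
# calling this function.
def reference_reward(state, action):
    # Reward is 1.0 when the state and the action fall on the same side of 0.5,
    # otherwise 0.0.
    return 1.0 if (state - 0.5) * (action - 0.5) > 0 else 0.0
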
class TestDDPG(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        super(TestDDPG, cls).setUpClass()
        cls.ddpg = DDPG(n_actions=1, n_states=1, gamma=0, alr=0.02)
        knob_data = np.zeros(1)
        metric_data = np.array([random.random()])
        # Train on 100 random transitions; the reward follows the rule described above.
        for _ in range(100):
            prev_metric_data = metric_data
            metric_data = np.array([random.random()])
            reward = 1.0 if (prev_metric_data[0] - 0.5) * (knob_data[0] - 0.5) > 0 else 0.0
            reward = np.array([reward])
            cls.ddpg.add_sample(prev_metric_data, knob_data, reward, metric_data)
            # Update the actor/critic networks 10 times per sample, then pick the
            # next action from the current policy.
            for _ in range(10):
                cls.ddpg.update()
            knob_data = cls.ddpg.choose_action(metric_data)

    def test_ddpg_ypreds(self):
        # Evaluate the trained policy on 500 random states and check the average reward.
        total_reward = 0.0
        for _ in range(500):
            prev_metric_data = np.array([random.random()])
            knob_data = self.ddpg.choose_action(prev_metric_data)
            reward = 1.0 if (prev_metric_data[0] - 0.5) * (knob_data[0] - 0.5) > 0 else 0.0
            total_reward += reward
        self.assertGreater(total_reward / 500, 0.5)


if __name__ == '__main__':
    unittest.main()