Improve DDPG

This commit is contained in:
yangdsh
2019-11-09 01:58:10 +00:00
committed by Dana Van Aken
parent 67a4a70c09
commit 21fce27291
5 changed files with 79 additions and 59 deletions

View File

@@ -25,15 +25,18 @@ class TestDDPG(unittest.TestCase):
np.random.seed(0)
torch.manual_seed(0)
super(TestDDPG, cls).setUpClass()
cls.ddpg = DDPG(n_actions=1, n_states=1, gamma=0)
for _ in range(700):
knob_data = np.array([random.random()])
prev_metric_data = np.array([random.random()])
cls.ddpg = DDPG(n_actions=1, n_states=1, gamma=0, alr=0.02)
knob_data = np.zeros(1)
metric_data = np.array([random.random()])
for _ in range(100):
prev_metric_data = metric_data
metric_data = np.array([random.random()])
reward = 1.0 if (prev_metric_data[0] - 0.5) * (knob_data[0] - 0.5) > 0 else 0.0
reward = np.array([reward])
cls.ddpg.add_sample(prev_metric_data, knob_data, reward, metric_data)
cls.ddpg.update()
for _ in range(10):
cls.ddpg.update()
knob_data = cls.ddpg.choose_action(metric_data)
def test_ddpg_ypreds(self):
total_reward = 0.0