improve ddpg
This commit is contained in:
@@ -25,15 +25,18 @@ class TestDDPG(unittest.TestCase):
|
||||
np.random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
super(TestDDPG, cls).setUpClass()
|
||||
cls.ddpg = DDPG(n_actions=1, n_states=1, gamma=0)
|
||||
for _ in range(700):
|
||||
knob_data = np.array([random.random()])
|
||||
prev_metric_data = np.array([random.random()])
|
||||
cls.ddpg = DDPG(n_actions=1, n_states=1, gamma=0, alr=0.02)
|
||||
knob_data = np.zeros(1)
|
||||
metric_data = np.array([random.random()])
|
||||
for _ in range(100):
|
||||
prev_metric_data = metric_data
|
||||
metric_data = np.array([random.random()])
|
||||
reward = 1.0 if (prev_metric_data[0] - 0.5) * (knob_data[0] - 0.5) > 0 else 0.0
|
||||
reward = np.array([reward])
|
||||
cls.ddpg.add_sample(prev_metric_data, knob_data, reward, metric_data)
|
||||
cls.ddpg.update()
|
||||
for _ in range(10):
|
||||
cls.ddpg.update()
|
||||
knob_data = cls.ddpg.choose_action(metric_data)
|
||||
|
||||
def test_ddpg_ypreds(self):
|
||||
total_reward = 0.0
|
||||
|
||||
Reference in New Issue
Block a user