Improve simulator and DDPG (add BatchNorm to Actor; add reward-shift parameter to DDPG target computation)

This commit is contained in:
yangdsh
2019-10-28 17:37:08 +00:00
committed by Dana Van Aken
parent 5431154784
commit 9f71d1c8de
4 changed files with 125 additions and 80 deletions

View File

@@ -32,6 +32,7 @@ class Actor(nn.Module):
nn.Linear(128, 128),
nn.Tanh(),
nn.Dropout(0.3),
nn.BatchNorm1d(128),
nn.Linear(128, 64),
nn.Tanh(),
@@ -99,7 +100,7 @@ class Critic(nn.Module):
class DDPG(object):
def __init__(self, n_states, n_actions, model_name='', alr=0.001, clr=0.001,
gamma=0.9, batch_size=32, tau=0.002, memory_size=100000):
gamma=0.9, batch_size=32, tau=0.002, shift=0, memory_size=100000):
self.n_states = n_states
self.n_actions = n_actions
self.alr = alr
@@ -108,6 +109,7 @@ class DDPG(object):
self.batch_size = batch_size
self.gamma = gamma
self.tau = tau
self.shift = shift
self._build_network()
@@ -184,7 +186,7 @@ class DDPG(object):
target_next_actions = self.target_actor(batch_next_states).detach()
target_next_value = self.target_critic(batch_next_states, target_next_actions).detach()
current_value = self.critic(batch_states, batch_actions)
next_value = batch_rewards + target_next_value * self.gamma
next_value = batch_rewards + target_next_value * self.gamma + self.shift
# update prioritized memory
if isinstance(self.replay_memory, PrioritizedReplayMemory):