improve simulator and ddpg
This commit is contained in:
@@ -32,6 +32,7 @@ class Actor(nn.Module):
|
||||
nn.Linear(128, 128),
|
||||
nn.Tanh(),
|
||||
nn.Dropout(0.3),
|
||||
nn.BatchNorm1d(128),
|
||||
|
||||
nn.Linear(128, 64),
|
||||
nn.Tanh(),
|
||||
@@ -99,7 +100,7 @@ class Critic(nn.Module):
|
||||
class DDPG(object):
|
||||
|
||||
def __init__(self, n_states, n_actions, model_name='', alr=0.001, clr=0.001,
|
||||
gamma=0.9, batch_size=32, tau=0.002, memory_size=100000):
|
||||
gamma=0.9, batch_size=32, tau=0.002, shift=0, memory_size=100000):
|
||||
self.n_states = n_states
|
||||
self.n_actions = n_actions
|
||||
self.alr = alr
|
||||
@@ -108,6 +109,7 @@ class DDPG(object):
|
||||
self.batch_size = batch_size
|
||||
self.gamma = gamma
|
||||
self.tau = tau
|
||||
self.shift = shift
|
||||
|
||||
self._build_network()
|
||||
|
||||
@@ -184,7 +186,7 @@ class DDPG(object):
|
||||
target_next_actions = self.target_actor(batch_next_states).detach()
|
||||
target_next_value = self.target_critic(batch_next_states, target_next_actions).detach()
|
||||
current_value = self.critic(batch_states, batch_actions)
|
||||
next_value = batch_rewards + target_next_value * self.gamma
|
||||
next_value = batch_rewards + target_next_value * self.gamma + self.shift
|
||||
|
||||
# update prioritized memory
|
||||
if isinstance(self.replay_memory, PrioritizedReplayMemory):
|
||||
|
||||
Reference in New Issue
Block a user