improve ddpg

Author: yangdsh
Date: 2019-11-09 01:58:10 +00:00
Committed by: Dana Van Aken
Parent: 67a4a70c09
Commit: 21fce27291
5 changed files with 79 additions and 59 deletions


@@ -23,21 +23,21 @@ LOG = get_analysis_logger(__name__)
 class Actor(nn.Module):
 
-    def __init__(self, n_states, n_actions):
+    def __init__(self, n_states, n_actions, hidden_sizes):
         super(Actor, self).__init__()
         self.layers = nn.Sequential(
-            nn.Linear(n_states, 128),
+            nn.Linear(n_states, hidden_sizes[0]),
             nn.LeakyReLU(negative_slope=0.2),
-            nn.BatchNorm1d(128),
-            nn.Linear(128, 128),
+            nn.BatchNorm1d(hidden_sizes[0]),
+            nn.Linear(hidden_sizes[0], hidden_sizes[1]),
             nn.Tanh(),
             nn.Dropout(0.3),
-            nn.BatchNorm1d(128),
-            nn.Linear(128, 64),
+            nn.BatchNorm1d(hidden_sizes[1]),
+            nn.Linear(hidden_sizes[1], hidden_sizes[2]),
             nn.Tanh(),
-            nn.BatchNorm1d(64),
-            nn.Linear(64, n_actions)
+            nn.Dropout(0.3),
+            nn.BatchNorm1d(hidden_sizes[2]),
+            nn.Linear(hidden_sizes[2], n_actions)
         )
         # This act layer maps the output to (0, 1)
         self.act = nn.Sigmoid()
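
This hunk replaces the Actor's hard-coded layer widths (128, 128, 64) with a hidden_sizes list, so the actor architecture can be configured by the caller instead of being fixed in the model code. A minimal usage sketch, assuming the Actor class above is importable from this module and that its forward pass (not shown in this hunk) applies self.layers followed by the Sigmoid self.act; the dimensions below are illustrative, not from the commit:

import torch

n_states, n_actions = 12, 8          # illustrative knob/metric counts
actor = Actor(n_states, n_actions, hidden_sizes=[128, 128, 64])

actor.eval()                          # BatchNorm/Dropout need eval() for single-sample inference
state = torch.rand(1, n_states)
action = actor(state)                 # assumed forward: Sigmoid output, each value in (0, 1)
assert action.shape == (1, n_actions)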
@@ -58,21 +58,21 @@ class Actor(nn.Module):
 class Critic(nn.Module):
 
-    def __init__(self, n_states, n_actions):
+    def __init__(self, n_states, n_actions, hidden_sizes):
         super(Critic, self).__init__()
-        self.state_input = nn.Linear(n_states, 128)
-        self.action_input = nn.Linear(n_actions, 128)
+        self.state_input = nn.Linear(n_states, hidden_sizes[0])
+        self.action_input = nn.Linear(n_actions, hidden_sizes[0])
         self.act = nn.Tanh()
         self.layers = nn.Sequential(
-            nn.Linear(256, 256),
+            nn.Linear(hidden_sizes[0] * 2, hidden_sizes[1]),
             nn.LeakyReLU(negative_slope=0.2),
-            nn.BatchNorm1d(256),
-            nn.Linear(256, 64),
+            nn.Dropout(0.3),
+            nn.BatchNorm1d(hidden_sizes[1]),
+            nn.Linear(hidden_sizes[1], hidden_sizes[2]),
             nn.Tanh(),
             nn.Dropout(0.3),
-            nn.BatchNorm1d(64),
-            nn.Linear(64, 1),
+            nn.BatchNorm1d(hidden_sizes[2]),
+            nn.Linear(hidden_sizes[2], 1),
         )
         self._init_weights()
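
In the Critic, the first trunk layer becomes nn.Linear(hidden_sizes[0] * 2, hidden_sizes[1]) because the state and action are each embedded to hidden_sizes[0] units and then concatenated before the shared layers. A sketch of that data flow, written as a hypothetical standalone helper since the Critic's own forward method is not part of this hunk:

import torch

def critic_forward(critic, states, actions):
    # Hypothetical illustration of why the trunk input width is hidden_sizes[0] * 2.
    s = critic.act(critic.state_input(states))     # (batch, hidden_sizes[0])
    a = critic.act(critic.action_input(actions))   # (batch, hidden_sizes[0])
    x = torch.cat([s, a], dim=1)                   # (batch, hidden_sizes[0] * 2)
    return critic.layers(x)                        # (batch, 1) Q-value estimate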
@@ -100,7 +100,8 @@ class Critic(nn.Module):
 class DDPG(object):
 
     def __init__(self, n_states, n_actions, model_name='', alr=0.001, clr=0.001,
-                 gamma=0.9, batch_size=32, tau=0.002, shift=0, memory_size=100000):
+                 gamma=0.9, batch_size=32, tau=0.002, shift=0, memory_size=100000,
+                 a_hidden_sizes=[128, 128, 64], c_hidden_sizes=[128, 256, 64]):
         self.n_states = n_states
         self.n_actions = n_actions
         self.alr = alr
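
With the two new keyword arguments, callers can shape both networks without touching the model code; the defaults keep the previous layer widths ([128, 128, 64] for the actor, [128, 256, 64] for the critic). A usage sketch with illustrative values only:

model = DDPG(n_states=20, n_actions=12,
             alr=0.001, clr=0.001, gamma=0.9, batch_size=32, tau=0.002,
             a_hidden_sizes=[64, 64, 32],    # smaller actor network
             c_hidden_sizes=[64, 128, 32])   # smaller critic network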
@@ -109,6 +110,8 @@ class DDPG(object):
         self.batch_size = batch_size
         self.gamma = gamma
         self.tau = tau
+        self.a_hidden_sizes = a_hidden_sizes
+        self.c_hidden_sizes = c_hidden_sizes
         self.shift = shift
         self._build_network()
@@ -121,10 +124,10 @@ class DDPG(object):
         return Variable(torch.FloatTensor(x))
 
     def _build_network(self):
-        self.actor = Actor(self.n_states, self.n_actions)
-        self.target_actor = Actor(self.n_states, self.n_actions)
-        self.critic = Critic(self.n_states, self.n_actions)
-        self.target_critic = Critic(self.n_states, self.n_actions)
+        self.actor = Actor(self.n_states, self.n_actions, self.a_hidden_sizes)
+        self.target_actor = Actor(self.n_states, self.n_actions, self.a_hidden_sizes)
+        self.critic = Critic(self.n_states, self.n_actions, self.c_hidden_sizes)
+        self.target_critic = Critic(self.n_states, self.n_actions, self.c_hidden_sizes)
         # Copy actor's parameters
         self._update_target(self.target_actor, self.actor, tau=1.0)
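
_build_network now threads the configured sizes into all four networks, then initializes the targets by calling _update_target with tau=1.0. The helper itself is not shown in this diff; a typical DDPG-style soft update consistent with that call would look roughly like the sketch below (an assumption, not the file's exact code):

    def _update_target(self, target, source, tau):
        # Polyak averaging: tau=1.0 copies the online weights outright,
        # while a small tau (e.g. 0.002) moves the target slowly toward the online network.
        for t_param, param in zip(target.parameters(), source.parameters()):
            t_param.data.copy_(t_param.data * (1.0 - tau) + param.data * tau)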