fix multithread issue in DNN
This commit is contained in:
parent
2a7bc6145f
commit
c76c8e7bfb
|
@ -760,9 +760,21 @@ def integration_tests():
|
||||||
response = get_result(upload_code='ottertuneTestTuningDNN')
|
response = get_result(upload_code='ottertuneTestTuningDNN')
|
||||||
assert response['status'] == 'good'
|
assert response['status'] == 'good'
|
||||||
|
|
||||||
|
# 2rd iteration Test DNN
|
||||||
|
upload_result(result_dir='./integrationTests/data/', prefix='1__',
|
||||||
|
upload_code='ottertuneTestTuningDNN')
|
||||||
|
response = get_result(upload_code='ottertuneTestTuningDNN')
|
||||||
|
assert response['status'] == 'good'
|
||||||
|
|
||||||
# Test GPR
|
# Test GPR
|
||||||
LOG.info('Test GPR (gaussian process regression)')
|
LOG.info('Test GPR (gaussian process regression)')
|
||||||
upload_result(result_dir='./integrationTests/data/', prefix='0__',
|
upload_result(result_dir='./integrationTests/data/', prefix='0__',
|
||||||
upload_code='ottertuneTestTuningGPR')
|
upload_code='ottertuneTestTuningGPR')
|
||||||
response = get_result(upload_code='ottertuneTestTuningGPR')
|
response = get_result(upload_code='ottertuneTestTuningGPR')
|
||||||
assert response['status'] == 'good'
|
assert response['status'] == 'good'
|
||||||
|
|
||||||
|
# 2rd iteration Test GPR
|
||||||
|
upload_result(result_dir='./integrationTests/data/', prefix='1__',
|
||||||
|
upload_code='ottertuneTestTuningGPR')
|
||||||
|
response = get_result(upload_code='ottertuneTestTuningGPR')
|
||||||
|
assert response['status'] == 'good'
|
||||||
|
|
|
@ -46,38 +46,52 @@ class NeuralNet(object):
|
||||||
self.explore_iters = explore_iters
|
self.explore_iters = explore_iters
|
||||||
self.noise_scale_begin = noise_scale_begin
|
self.noise_scale_begin = noise_scale_begin
|
||||||
self.noise_scale_end = noise_scale_end
|
self.noise_scale_end = noise_scale_end
|
||||||
self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
|
|
||||||
# input X is placeholder, weights are variables.
|
|
||||||
self.model = keras.Sequential([
|
|
||||||
keras.layers.Dense(64, activation=tf.nn.relu, input_shape=[n_input]),
|
|
||||||
keras.layers.Dropout(0.5),
|
|
||||||
keras.layers.Dense(64, activation=tf.nn.relu),
|
|
||||||
keras.layers.Dense(1)
|
|
||||||
])
|
|
||||||
self.model.compile(loss='mean_squared_error',
|
|
||||||
optimizer=self.optimizer,
|
|
||||||
metrics=['mean_squared_error', 'mean_absolute_error'])
|
|
||||||
self.vars = {}
|
self.vars = {}
|
||||||
self.ops = {}
|
self.ops = {}
|
||||||
self.build_graph()
|
|
||||||
|
|
||||||
def save_weights(self, weights_file):
|
self.session = tf.Session()
|
||||||
self.model.save_weights(weights_file)
|
self.graph = tf.get_default_graph()
|
||||||
|
with self.graph.as_default():
|
||||||
|
with self.session.as_default(): # pylint: disable=not-context-manager
|
||||||
|
self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
|
||||||
|
# input X is placeholder, weights are variables.
|
||||||
|
self.model = keras.Sequential([
|
||||||
|
keras.layers.Dense(64, activation=tf.nn.relu, input_shape=[n_input]),
|
||||||
|
keras.layers.Dropout(0.5),
|
||||||
|
keras.layers.Dense(64, activation=tf.nn.relu),
|
||||||
|
keras.layers.Dense(1)
|
||||||
|
])
|
||||||
|
self.model.compile(loss='mean_squared_error',
|
||||||
|
optimizer=self.optimizer,
|
||||||
|
metrics=['mean_squared_error', 'mean_absolute_error'])
|
||||||
|
self._build_graph()
|
||||||
|
|
||||||
def load_weights(self, weights_file):
|
def save_weights_file(self, weights_file):
|
||||||
|
with self.graph.as_default():
|
||||||
|
with self.session.as_default(): # pylint: disable=not-context-manager
|
||||||
|
self.model.save_weights(weights_file)
|
||||||
|
|
||||||
|
def load_weights_file(self, weights_file):
|
||||||
try:
|
try:
|
||||||
self.model.load_weights(weights_file)
|
with self.graph.as_default():
|
||||||
|
with self.session.as_default(): # pylint: disable=not-context-manager
|
||||||
|
self.model.load_weights(weights_file)
|
||||||
if self.debug:
|
if self.debug:
|
||||||
LOG.info('Neural Network Model weights file exists, load weights from the file')
|
LOG.info('Neural Network Model weights file exists, load weights from the file')
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
LOG.info('Weights file does not match neural network model, train model from scratch')
|
LOG.info('Weights file does not match neural network model, train model from scratch')
|
||||||
|
|
||||||
def get_weights_bin(self):
|
def get_weights_bin(self):
|
||||||
return pickle.dumps(self.model.get_weights())
|
with self.graph.as_default():
|
||||||
|
with self.session.as_default(): # pylint: disable=not-context-manager
|
||||||
|
weights = self.model.get_weights()
|
||||||
|
return pickle.dumps(weights)
|
||||||
|
|
||||||
def set_weights_bin(self, weights):
|
def set_weights_bin(self, weights):
|
||||||
try:
|
try:
|
||||||
self.model.set_weights(pickle.loads(weights))
|
with self.graph.as_default():
|
||||||
|
with self.session.as_default(): # pylint: disable=not-context-manager
|
||||||
|
self.model.set_weights(pickle.loads(weights))
|
||||||
if self.debug:
|
if self.debug:
|
||||||
LOG.info('Neural Network Model weights exists, load the existing weights')
|
LOG.info('Neural Network Model weights exists, load the existing weights')
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
|
@ -85,58 +99,62 @@ class NeuralNet(object):
|
||||||
|
|
||||||
# Build same neural network as self.model, But input X is variables,
|
# Build same neural network as self.model, But input X is variables,
|
||||||
# weights are placedholders. Find optimial X using gradient descent.
|
# weights are placedholders. Find optimial X using gradient descent.
|
||||||
def build_graph(self):
|
def _build_graph(self):
|
||||||
batch_size = self.batch_size
|
batch_size = self.batch_size
|
||||||
self.graph = tf.Graph()
|
|
||||||
with self.graph.as_default():
|
with self.graph.as_default():
|
||||||
x_ = tf.Variable(tf.ones([batch_size, self.n_input]))
|
with self.session.as_default(): # pylint: disable=not-context-manager
|
||||||
w1_ = tf.placeholder(tf.float32, [self.n_input, 64])
|
x_ = tf.Variable(tf.ones([batch_size, self.n_input]))
|
||||||
b1_ = tf.placeholder(tf.float32, [64])
|
w1_ = tf.placeholder(tf.float32, [self.n_input, 64])
|
||||||
w2_ = tf.placeholder(tf.float32, [64, 64])
|
b1_ = tf.placeholder(tf.float32, [64])
|
||||||
b2_ = tf.placeholder(tf.float32, [64])
|
w2_ = tf.placeholder(tf.float32, [64, 64])
|
||||||
w3_ = tf.placeholder(tf.float32, [64, 1])
|
b2_ = tf.placeholder(tf.float32, [64])
|
||||||
b3_ = tf.placeholder(tf.float32, [1])
|
w3_ = tf.placeholder(tf.float32, [64, 1])
|
||||||
l1_ = tf.nn.relu(tf.add(tf.matmul(x_, w1_), b1_))
|
b3_ = tf.placeholder(tf.float32, [1])
|
||||||
l2_ = tf.nn.relu(tf.add(tf.matmul(l1_, w2_), b2_))
|
l1_ = tf.nn.relu(tf.add(tf.matmul(x_, w1_), b1_))
|
||||||
y_ = tf.add(tf.matmul(l2_, w3_), b3_)
|
l2_ = tf.nn.relu(tf.add(tf.matmul(l1_, w2_), b2_))
|
||||||
optimizer_ = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
|
y_ = tf.add(tf.matmul(l2_, w3_), b3_)
|
||||||
train_ = optimizer_.minimize(y_)
|
optimizer_ = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
|
||||||
|
train_ = optimizer_.minimize(y_)
|
||||||
|
|
||||||
self.vars['x_'] = x_
|
self.vars['x_'] = x_
|
||||||
self.vars['y_'] = y_
|
self.vars['y_'] = y_
|
||||||
self.vars['w1_'] = w1_
|
self.vars['w1_'] = w1_
|
||||||
self.vars['w2_'] = w2_
|
self.vars['w2_'] = w2_
|
||||||
self.vars['w3_'] = w3_
|
self.vars['w3_'] = w3_
|
||||||
self.vars['b1_'] = b1_
|
self.vars['b1_'] = b1_
|
||||||
self.vars['b2_'] = b2_
|
self.vars['b2_'] = b2_
|
||||||
self.vars['b3_'] = b3_
|
self.vars['b3_'] = b3_
|
||||||
self.ops['train_'] = train_
|
self.ops['train_'] = train_
|
||||||
|
|
||||||
def fit(self, X_train, y_train, fit_epochs=500):
|
def fit(self, X_train, y_train, fit_epochs=500):
|
||||||
self.history = self.model.fit(
|
with self.graph.as_default():
|
||||||
X_train, y_train, epochs=fit_epochs, verbose=0)
|
with self.session.as_default(): # pylint: disable=not-context-manager
|
||||||
if self.debug:
|
self.history = self.model.fit(
|
||||||
mse = self.history.history['mean_squared_error']
|
X_train, y_train, epochs=fit_epochs, verbose=0)
|
||||||
i = 0
|
if self.debug:
|
||||||
size = len(mse)
|
mse = self.history.history['mean_squared_error']
|
||||||
while(i < size):
|
i = 0
|
||||||
LOG.info("Neural network training phase, epoch %d: mean_squared_error %f",
|
size = len(mse)
|
||||||
i, mse[i])
|
while(i < size):
|
||||||
i += self.debug_interval
|
LOG.info("Neural network training phase, epoch %d: mean_squared_error %f",
|
||||||
LOG.info("Neural network training phase, epoch %d: mean_squared_error %f",
|
i, mse[i])
|
||||||
size - 1, mse[size - 1])
|
i += self.debug_interval
|
||||||
|
LOG.info("Neural network training phase, epoch %d: mean_squared_error %f",
|
||||||
|
size - 1, mse[size - 1])
|
||||||
|
|
||||||
def predict(self, X_pred):
|
def predict(self, X_pred):
|
||||||
return self.model.predict(X_pred)
|
with self.graph.as_default():
|
||||||
|
with self.session.as_default(): # pylint: disable=not-context-manager
|
||||||
|
return self.model.predict(X_pred)
|
||||||
|
|
||||||
# Reference: Parameter Space Noise for Exploration.ICLR 2018, https://arxiv.org/abs/1706.01905
|
# Reference: Parameter Space Noise for Exploration.ICLR 2018, https://arxiv.org/abs/1706.01905
|
||||||
def add_noise(self, weights):
|
def _add_noise(self, weights):
|
||||||
scale = self.adaptive_noise_scale()
|
scale = self._adaptive_noise_scale()
|
||||||
size = weights.shape[-1]
|
size = weights.shape[-1]
|
||||||
noise = scale * np.random.normal(size=size)
|
noise = scale * np.random.normal(size=size)
|
||||||
return weights + noise
|
return weights + noise
|
||||||
|
|
||||||
def adaptive_noise_scale(self):
|
def _adaptive_noise_scale(self):
|
||||||
if self.recommend_iters > self.explore_iters:
|
if self.recommend_iters > self.explore_iters:
|
||||||
scale = self.noise_scale_end
|
scale = self.noise_scale_end
|
||||||
else:
|
else:
|
||||||
|
@ -147,69 +165,71 @@ class NeuralNet(object):
|
||||||
def recommend(self, X_start, X_min=None, X_max=None, recommend_epochs=500, explore=False):
|
def recommend(self, X_start, X_min=None, X_max=None, recommend_epochs=500, explore=False):
|
||||||
batch_size = len(X_start)
|
batch_size = len(X_start)
|
||||||
assert(batch_size == self.batch_size)
|
assert(batch_size == self.batch_size)
|
||||||
w1, b1 = self.model.get_layer(index=0).get_weights()
|
|
||||||
w2, b2 = self.model.get_layer(index=2).get_weights()
|
|
||||||
w3, b3 = self.model.get_layer(index=3).get_weights()
|
|
||||||
|
|
||||||
if explore is True:
|
with self.graph.as_default():
|
||||||
w1 = self.add_noise(w1)
|
with self.session.as_default() as sess: # pylint: disable=not-context-manager
|
||||||
b1 = self.add_noise(b1)
|
w1, b1 = self.model.get_layer(index=0).get_weights()
|
||||||
w2 = self.add_noise(w2)
|
w2, b2 = self.model.get_layer(index=2).get_weights()
|
||||||
b2 = self.add_noise(b2)
|
w3, b3 = self.model.get_layer(index=3).get_weights()
|
||||||
w3 = self.add_noise(w3)
|
|
||||||
b3 = self.add_noise(b3)
|
|
||||||
|
|
||||||
y_predict = self.predict(X_start)
|
if explore is True:
|
||||||
if self.debug:
|
w1 = self._add_noise(w1)
|
||||||
LOG.info("Recommend phase, y prediction: min %f, max %f, mean %f",
|
b1 = self._add_noise(b1)
|
||||||
np.min(y_predict), np.max(y_predict), np.mean(y_predict))
|
w2 = self._add_noise(w2)
|
||||||
|
b2 = self._add_noise(b2)
|
||||||
|
w3 = self._add_noise(w3)
|
||||||
|
b3 = self._add_noise(b3)
|
||||||
|
|
||||||
with tf.Session(graph=self.graph) as sess:
|
y_predict = self.predict(X_start)
|
||||||
init = tf.global_variables_initializer()
|
if self.debug:
|
||||||
sess.run(init)
|
LOG.info("Recommend phase, y prediction: min %f, max %f, mean %f",
|
||||||
assign_x_op = self.vars['x_'].assign(X_start)
|
np.min(y_predict), np.max(y_predict), np.mean(y_predict))
|
||||||
sess.run(assign_x_op)
|
|
||||||
y_before = sess.run(self.vars['y_'],
|
|
||||||
feed_dict={self.vars['w1_']: w1, self.vars['w2_']: w2,
|
|
||||||
self.vars['w3_']: w3, self.vars['b1_']: b1,
|
|
||||||
self.vars['b2_']: b2, self.vars['b3_']: b3})
|
|
||||||
if self.debug:
|
|
||||||
LOG.info("Recommend phase, y before gradient descent: min %f, max %f, mean %f",
|
|
||||||
np.min(y_before), np.max(y_before), np.mean(y_before))
|
|
||||||
|
|
||||||
for i in range(recommend_epochs):
|
init = tf.global_variables_initializer()
|
||||||
sess.run(self.ops['train_'],
|
sess.run(init)
|
||||||
feed_dict={self.vars['w1_']: w1, self.vars['w2_']: w2,
|
assign_x_op = self.vars['x_'].assign(X_start)
|
||||||
self.vars['w3_']: w3, self.vars['b1_']: b1,
|
sess.run(assign_x_op)
|
||||||
self.vars['b2_']: b2, self.vars['b3_']: b3})
|
y_before = sess.run(self.vars['y_'],
|
||||||
|
feed_dict={self.vars['w1_']: w1, self.vars['w2_']: w2,
|
||||||
|
self.vars['w3_']: w3, self.vars['b1_']: b1,
|
||||||
|
self.vars['b2_']: b2, self.vars['b3_']: b3})
|
||||||
|
if self.debug:
|
||||||
|
LOG.info("Recommend phase, y before gradient descent: min %f, max %f, mean %f",
|
||||||
|
np.min(y_before), np.max(y_before), np.mean(y_before))
|
||||||
|
|
||||||
# constrain by X_min and X_max
|
for i in range(recommend_epochs):
|
||||||
if X_min is not None and X_max is not None:
|
sess.run(self.ops['train_'],
|
||||||
X_train = sess.run(self.vars['x_'])
|
feed_dict={self.vars['w1_']: w1, self.vars['w2_']: w2,
|
||||||
X_train = np.minimum(X_train, X_max)
|
self.vars['w3_']: w3, self.vars['b1_']: b1,
|
||||||
X_train = np.maximum(X_train, X_min)
|
self.vars['b2_']: b2, self.vars['b3_']: b3})
|
||||||
constraint_x_op = self.vars['x_'].assign(X_train)
|
|
||||||
sess.run(constraint_x_op)
|
|
||||||
|
|
||||||
if self.debug and i % self.debug_interval == 0:
|
# constrain by X_min and X_max
|
||||||
y_train = sess.run(self.vars['y_'],
|
if X_min is not None and X_max is not None:
|
||||||
|
X_train = sess.run(self.vars['x_'])
|
||||||
|
X_train = np.minimum(X_train, X_max)
|
||||||
|
X_train = np.maximum(X_train, X_min)
|
||||||
|
constraint_x_op = self.vars['x_'].assign(X_train)
|
||||||
|
sess.run(constraint_x_op)
|
||||||
|
|
||||||
|
if self.debug and i % self.debug_interval == 0:
|
||||||
|
y_train = sess.run(self.vars['y_'],
|
||||||
|
feed_dict={self.vars['w1_']: w1, self.vars['w2_']: w2,
|
||||||
|
self.vars['w3_']: w3, self.vars['b1_']: b1,
|
||||||
|
self.vars['b2_']: b2, self.vars['b3_']: b3})
|
||||||
|
LOG.info("Recommend phase, epoch %d, y: min %f, max %f, mean %f",
|
||||||
|
i, np.min(y_train), np.max(y_train), np.mean(y_train))
|
||||||
|
|
||||||
|
y_recommend = sess.run(self.vars['y_'],
|
||||||
feed_dict={self.vars['w1_']: w1, self.vars['w2_']: w2,
|
feed_dict={self.vars['w1_']: w1, self.vars['w2_']: w2,
|
||||||
self.vars['w3_']: w3, self.vars['b1_']: b1,
|
self.vars['w3_']: w3, self.vars['b1_']: b1,
|
||||||
self.vars['b2_']: b2, self.vars['b3_']: b3})
|
self.vars['b2_']: b2, self.vars['b3_']: b3})
|
||||||
LOG.info("Recommend phase, epoch %d, y: min %f, max %f, mean %f",
|
X_recommend = sess.run(self.vars['x_'])
|
||||||
i, np.min(y_train), np.max(y_train), np.mean(y_train))
|
res = NeuralNetResult(minl=y_recommend, minl_conf=X_recommend)
|
||||||
|
|
||||||
y_recommend = sess.run(self.vars['y_'],
|
if self.debug:
|
||||||
feed_dict={self.vars['w1_']: w1, self.vars['w2_']: w2,
|
LOG.info("Recommend phase, epoch %d, y after gradient descent: \
|
||||||
self.vars['w3_']: w3, self.vars['b1_']: b1,
|
min %f, max %f, mean %f", recommend_epochs, np.min(y_recommend),
|
||||||
self.vars['b2_']: b2, self.vars['b3_']: b3})
|
np.max(y_recommend), np.mean(y_recommend))
|
||||||
X_recommend = sess.run(self.vars['x_'])
|
|
||||||
res = NeuralNetResult(minl=y_recommend, minl_conf=X_recommend)
|
|
||||||
|
|
||||||
if self.debug:
|
self.recommend_iters += 1
|
||||||
LOG.info("Recommend phase, epoch %d, y after gradient descent: \
|
return res
|
||||||
min %f, max %f, mean %f", recommend_epochs, np.min(y_recommend),
|
|
||||||
np.max(y_recommend), np.mean(y_recommend))
|
|
||||||
|
|
||||||
self.recommend_iters += 1
|
|
||||||
return res
|
|
||||||
|
|
Loading…
Reference in New Issue