fix sigma in new gpr model

This commit is contained in:
bohanjason 2019-11-29 20:05:00 -05:00 committed by Dana Van Aken
parent 7ee615a3f3
commit 5654d23637
1 changed files with 8 additions and 4 deletions

View File

@ -53,7 +53,10 @@ def tf_optimize(model, Xnew_arr, learning_rate=0.01, maxiter=100, ucb_beta=3.,
beta_t = tf.constant(ucb_beta, name='ucb_beta', dtype=settings.float_type)
y_mean_var = model.likelihood.predict_mean_and_var(*model._build_predict(Xin))
loss = tf.subtract(y_mean_var[0], tf.multiply(beta_t, y_mean_var[1]), name='loss_fn')
y_mean = y_mean_var[0]
y_var = y_mean_var[1]
y_std = tf.sqrt(y_var)
loss = tf.subtract(y_mean, tf.multiply(beta_t, y_std), name='loss_fn')
opt = tf.train.AdamOptimizer(learning_rate, epsilon=1e-6)
train_op = opt.minimize(loss)
variables = opt.variables()
@ -64,10 +67,11 @@ def tf_optimize(model, Xnew_arr, learning_rate=0.01, maxiter=100, ucb_beta=3.,
for i in range(maxiter):
session.run(train_op)
Xnew_value = session.run(Xnew_bounded)
y_mean_value, y_var_value = session.run(y_mean_var)
y_mean_value = session.run(y_mean)
y_std_value = session.run(y_std)
loss_value = session.run(loss)
assert_all_finite(Xnew_value)
assert_all_finite(y_mean_value)
assert_all_finite(y_var_value)
assert_all_finite(y_std_value)
assert_all_finite(loss_value)
return GPRGDResult(y_mean_value, y_var_value, loss_value, Xnew_value)
return GPRGDResult(y_mean_value, y_std_value, loss_value, Xnew_value)