hyperparameter debug info for new gpr
This commit is contained in:
parent
f66543d965
commit
70f9f952d5
|
@ -45,4 +45,4 @@ class GPRC(GPR):
|
|||
else:
|
||||
fvar = self.kern.Kdiag(Xnew) - tf.reduce_sum(tf.square(A), 0)
|
||||
fvar = tf.tile(tf.reshape(fvar, (-1, 1)), [1, tf.shape(self.Y)[1]])
|
||||
return fmean, fvar
|
||||
return fmean, fvar, self.kern.variance, self.kern.lengthscales, self.likelihood.variance
|
||||
|
|
|
@ -26,7 +26,7 @@ class GPRGDResult():
|
|||
|
||||
|
||||
def tf_optimize(model, Xnew_arr, learning_rate=0.01, maxiter=100, ucb_beta=3.,
|
||||
active_dims=None, bounds=None):
|
||||
active_dims=None, bounds=None, debug=True):
|
||||
Xnew_arr = check_array(Xnew_arr, copy=False, warn_on_dtype=True, dtype=FLOAT_DTYPES)
|
||||
|
||||
Xnew = tf.Variable(Xnew_arr, name='Xnew', dtype=settings.float_type)
|
||||
|
@ -52,7 +52,8 @@ def tf_optimize(model, Xnew_arr, learning_rate=0.01, maxiter=100, ucb_beta=3.,
|
|||
Xin = Xnew_bounded
|
||||
|
||||
beta_t = tf.constant(ucb_beta, name='ucb_beta', dtype=settings.float_type)
|
||||
y_mean_var = model.likelihood.predict_mean_and_var(*model._build_predict(Xin))
|
||||
fmean, fvar, kvar, kls, lvar = model._build_predict(Xin) # pylint: disable=protected-access
|
||||
y_mean_var = model.likelihood.predict_mean_and_var(fmean, fvar)
|
||||
y_mean = y_mean_var[0]
|
||||
y_var = y_mean_var[1]
|
||||
y_std = tf.sqrt(y_var)
|
||||
|
@ -74,4 +75,8 @@ def tf_optimize(model, Xnew_arr, learning_rate=0.01, maxiter=100, ucb_beta=3.,
|
|||
assert_all_finite(y_mean_value)
|
||||
assert_all_finite(y_std_value)
|
||||
assert_all_finite(loss_value)
|
||||
if debug:
|
||||
LOG.info("kernel variance: %f", session.run(kvar))
|
||||
LOG.info("kernel lengthscale: %f", session.run(kls))
|
||||
LOG.info("likelihood variance: %f", session.run(lvar))
|
||||
return GPRGDResult(y_mean_value, y_std_value, loss_value, Xnew_value)
|
||||
|
|
|
@ -66,7 +66,7 @@ HP_MAX_ITER = 5000
|
|||
HP_LEARNING_RATE = 0.001
|
||||
|
||||
# ---GRADIENT DESCENT FOR DNN---
|
||||
DNN_TRAIN_ITER = 500
|
||||
DNN_TRAIN_ITER = 100
|
||||
|
||||
DNN_EXPLORE = False
|
||||
|
||||
|
|
Loading…
Reference in New Issue