use GPFlow in workload mapping
This commit is contained in:
committed by
Dana Van Aken
parent
389174302f
commit
25d1950e67
@@ -111,7 +111,13 @@ class BasicGP(BaseModel):
|
||||
]
|
||||
|
||||
def _build_kernel(self, kernel_kwargs, **kwargs):
|
||||
k = gpflow.kernels.Matern12(lengthscales=2, **kernel_kwargs[0])
|
||||
ls = 2
|
||||
var = 1
|
||||
if kwargs.get('lengthscales') is not None:
|
||||
ls = kwargs['lengthscales']
|
||||
if kwargs.get('variance') is not None:
|
||||
var = kwargs['variance']
|
||||
k = gpflow.kernels.Matern12(variance=var, lengthscales=ls, **kernel_kwargs[0])
|
||||
if kwargs.pop('optimize_hyperparameters'):
|
||||
k.lengthscales.transform = gpflow.transforms.Logistic(
|
||||
*self._LENGTHSCALE_BOUNDS)
|
||||
|
||||
@@ -16,6 +16,13 @@ from analysis.util import get_analysis_logger
|
||||
LOG = get_analysis_logger(__name__)
|
||||
|
||||
|
||||
class GPRResult():
|
||||
|
||||
def __init__(self, ypreds=None, sigmas=None):
|
||||
self.ypreds = ypreds
|
||||
self.sigmas = sigmas
|
||||
|
||||
|
||||
class GPRGDResult():
|
||||
|
||||
def __init__(self, ypreds=None, sigmas=None, minl=None, minl_conf=None):
|
||||
@@ -25,6 +32,20 @@ class GPRGDResult():
|
||||
self.minl_conf = minl_conf
|
||||
|
||||
|
||||
def gpflow_predict(model, Xin):
|
||||
fmean, fvar, _, _, _ = model._build_predict(Xin) # pylint: disable=protected-access
|
||||
y_mean_var = model.likelihood.predict_mean_and_var(fmean, fvar)
|
||||
y_mean = y_mean_var[0]
|
||||
y_var = y_mean_var[1]
|
||||
y_std = tf.sqrt(y_var)
|
||||
|
||||
session = model.enquire_session(session=None)
|
||||
with session.as_default():
|
||||
y_mean_value = session.run(y_mean)
|
||||
y_std_value = session.run(y_std)
|
||||
return GPRResult(y_mean_value, y_std_value)
|
||||
|
||||
|
||||
def tf_optimize(model, Xnew_arr, learning_rate=0.01, maxiter=100, ucb_beta=3.,
|
||||
active_dims=None, bounds=None, debug=True):
|
||||
Xnew_arr = check_array(Xnew_arr, copy=False, warn_on_dtype=True, dtype=FLOAT_DTYPES)
|
||||
|
||||
Reference in New Issue
Block a user