diff --git a/server/analysis/gpr/gpr_models.py b/server/analysis/gpr/gpr_models.py index 82b04bf..3efbfe1 100644 --- a/server/analysis/gpr/gpr_models.py +++ b/server/analysis/gpr/gpr_models.py @@ -118,8 +118,43 @@ class BasicGP(BaseModel): return k +class ExpWhiteGP(BaseModel): + + _KERNEL_HP_KEYS = [ + [ + 'GPRC/kern/kernels/0/variance', + 'GPRC/kern/kernels/0/lengthscales', + ], + [ + 'GPRC/kern/kernels/1/variance', + ], + ] + + def _get_kernel_kwargs(self, **kwargs): + X_dim = kwargs.pop('X_dim') + return [ + { + 'input_dim': X_dim, + 'ARD': False + }, + { + 'input_dim': X_dim, + }, + ] + + def _build_kernel(self, kernel_kwargs, **kwargs): + k0 = gpflow.kernels.Exponential(**kernel_kwargs[0]) + k1 = gpflow.kernels.White(**kernel_kwargs[1]) + if kwargs.pop('optimize_hyperparameters'): + k0.lengthscales.transform = gpflow.transforms.Logistic( + *self._LENGTHSCALE_BOUNDS) + k = k0 + k1 + return k + + _MODEL_MAP = { 'BasicGP': BasicGP, + 'ExpWhiteGP': ExpWhiteGP, }