improve dnn

This commit is contained in:
Bohan Zhang
2019-12-18 21:37:40 -08:00
committed by Dana Van Aken
parent 3a722df5e2
commit 9633b0e99c
3 changed files with 27 additions and 15 deletions

View File

@@ -84,6 +84,9 @@ HP_LEARNING_RATE = 0.001
# ---GRADIENT DESCENT FOR DNN---
DNN_TRAIN_ITER = 100
# Gradient Descent iteration for recommendation
DNN_GD_ITER = 100
DNN_EXPLORE = False
DNN_EXPLORE_ITER = 500

View File

@@ -44,7 +44,7 @@ from website.settings import (USE_GPFLOW, DEFAULT_LENGTH_SCALE, DEFAULT_MAGNITUD
DNN_TRAIN_ITER, DNN_EXPLORE, DNN_EXPLORE_ITER,
DNN_NOISE_SCALE_BEGIN, DNN_NOISE_SCALE_END,
DNN_DEBUG, DNN_DEBUG_INTERVAL, GPR_DEBUG, UCB_BETA,
GPR_MODEL_NAME, ENABLE_DUMMY_ENCODER)
GPR_MODEL_NAME, ENABLE_DUMMY_ENCODER, DNN_GD_ITER)
from website.settings import INIT_FLIP_PROB, FLIP_PROB_DECAY
from website.types import VarType
@@ -700,7 +700,7 @@ def configuration_recommendation(recommendation_input):
model_nn.set_weights_bin(session.dnn_model)
model_nn.fit(X_scaled, y_scaled, fit_epochs=DNN_TRAIN_ITER)
res = model_nn.recommend(X_samples, X_min, X_max,
explore=DNN_EXPLORE, recommend_epochs=MAX_ITER)
explore=DNN_EXPLORE, recommend_epochs=DNN_GD_ITER)
session.dnn_model = model_nn.get_weights_bin()
session.save()