Make constants editable
This commit is contained in:
parent
d1a9933808
commit
42e1a4add3
|
@ -45,6 +45,12 @@ BATCH_SIZE = 3000
|
|||
# Threads for TensorFlow config
|
||||
NUM_THREADS = 4
|
||||
|
||||
# Value of beta for UCB
|
||||
UCB_BETA = 'get_beta_td'
|
||||
|
||||
# Name of the GPR model to use (GPFLOW only)
|
||||
GPR_MODEL_NAME = 'BasicGP'
|
||||
|
||||
# ---GRADIENT DESCENT CONSTANTS---
|
||||
# the maximum iterations of gradient descent
|
||||
MAX_ITER = 500
|
||||
|
|
|
@ -42,7 +42,8 @@ from website.settings import (USE_GPFLOW, DEFAULT_LENGTH_SCALE, DEFAULT_MAGNITUD
|
|||
ACTOR_HIDDEN_SIZES, CRITIC_HIDDEN_SIZES,
|
||||
DNN_TRAIN_ITER, DNN_EXPLORE, DNN_EXPLORE_ITER,
|
||||
DNN_NOISE_SCALE_BEGIN, DNN_NOISE_SCALE_END,
|
||||
DNN_DEBUG, DNN_DEBUG_INTERVAL, GPR_DEBUG)
|
||||
DNN_DEBUG, DNN_DEBUG_INTERVAL, GPR_DEBUG, UCB_BETA,
|
||||
GPR_MODEL_NAME)
|
||||
|
||||
from website.settings import INIT_FLIP_PROB, FLIP_PROB_DECAY
|
||||
from website.types import VarType
|
||||
|
@ -660,13 +661,12 @@ def configuration_recommendation(recommendation_input):
|
|||
opt_kwargs['maxiter'] = MAX_ITER
|
||||
opt_kwargs['bounds'] = [X_min, X_max]
|
||||
opt_kwargs['debug'] = GPR_DEBUG
|
||||
ucb_beta = 'get_beta_td'
|
||||
opt_kwargs['ucb_beta'] = ucb.get_ucb_beta(ucb_beta, scale=DEFAULT_UCB_SCALE,
|
||||
opt_kwargs['ucb_beta'] = ucb.get_ucb_beta(UCB_BETA, scale=DEFAULT_UCB_SCALE,
|
||||
t=i + 1., ndim=X_scaled.shape[1])
|
||||
tf.reset_default_graph()
|
||||
graph = tf.get_default_graph()
|
||||
gpflow.reset_default_session(graph=graph)
|
||||
m = gpr_models.create_model('BasicGP', X=X_scaled, y=y_scaled, **model_kwargs)
|
||||
m = gpr_models.create_model(GPR_MODEL_NAME, X=X_scaled, y=y_scaled, **model_kwargs)
|
||||
res = tf_optimize(m.model, X_samples, **opt_kwargs)
|
||||
else:
|
||||
model = GPRGD(length_scale=DEFAULT_LENGTH_SCALE,
|
||||
|
|
|
@ -73,6 +73,8 @@ urlpatterns = [
|
|||
url(r'^edit/session/', website_views.alt_create_or_edit_session, name='backdoor_edit_session'),
|
||||
url(r'^create/user/', website_views.alt_create_user, name='backdoor_create_user'),
|
||||
url(r'^delete/user/', website_views.alt_delete_user, name='backdoor_delete_user'),
|
||||
url(r'^info/(?P<name>[0-9a-zA-Z]+)', website_views.alt_get_info, name="backdoor_info"),
|
||||
url(r'^set_constant/(?P<name>[0-9a-zA-Z_]+)', website_views.alt_set_constant, name="backdoor_set_constant"),
|
||||
|
||||
# train ddpg with results in the given session
|
||||
url(r'^train_ddpg/sessions/(?P<session_id>[0-9]+)$', website_views.train_ddpg_loops, name='train_ddpg_loops'),
|
||||
|
|
|
@ -342,6 +342,19 @@ class LabelUtil(object):
|
|||
return style_labels
|
||||
|
||||
|
||||
def set_constant(name, value):
|
||||
getattr(constants, name) # Throw exception if not a valid option
|
||||
setattr(constants, name, value)
|
||||
|
||||
|
||||
def get_constants():
|
||||
constants_dict = OrderedDict()
|
||||
for name, value in sorted(constants.__dict__.items()):
|
||||
if not name.startswith('_') and name == name.upper():
|
||||
constants_dict[name] = value
|
||||
return constants_dict
|
||||
|
||||
|
||||
def dump_debug_info(session, pretty_print=False):
|
||||
files = {}
|
||||
|
||||
|
@ -418,11 +431,7 @@ def dump_debug_info(session, pretty_print=False):
|
|||
files['logs/{}.log'.format(logger_name)] = log_values
|
||||
|
||||
# Save settings
|
||||
constants_dict = OrderedDict()
|
||||
for name, value in sorted(constants.__dict__.items()):
|
||||
if not name.startswith('_') and name == name.upper():
|
||||
constants_dict[name] = value
|
||||
files['constants.json'] = constants_dict
|
||||
files['constants.json'] = get_constants()
|
||||
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
root = 'debug_{}'.format(timestamp)
|
||||
|
|
|
@ -1129,6 +1129,33 @@ def train_ddpg_loops(request, session_id): # pylint: disable=unused-argument
|
|||
return HttpResponse()
|
||||
|
||||
|
||||
@csrf_exempt
|
||||
def alt_get_info(request, name):
|
||||
# Backdoor method for getting basic info
|
||||
if name == 'constants':
|
||||
info = utils.get_constants()
|
||||
return HttpResponse(JSONUtil.dumps(info))
|
||||
else:
|
||||
LOG.warning("Invalid name for info request: %s", name)
|
||||
return HttpResponse("Invalid name for info request: {}".format(name), status=400)
|
||||
|
||||
|
||||
@csrf_exempt
|
||||
def alt_set_constant(request, name):
|
||||
# Sets a constant defined in settings/constants.py
|
||||
LOG.info('POST: %s', request.POST)
|
||||
LOG.info('POST.lists(): %s', request.POST.lists())
|
||||
value = request.POST['value']
|
||||
LOG.info('name: %s, value: %s, type: %s', name, value, type(value))
|
||||
#data = {k: v[0] for k, v in request.POST.lists()}
|
||||
try:
|
||||
utils.set_constant(name, value)
|
||||
except AttributeError as e:
|
||||
LOG.warning(e)
|
||||
return HttpResponse(e, status=400)
|
||||
return HttpResponse("Successfully updated {} to '{}'".format(name, value))
|
||||
|
||||
|
||||
@csrf_exempt
|
||||
def alt_create_user(request):
|
||||
response = dict(created=False, error=None, user=None)
|
||||
|
|
Loading…
Reference in New Issue