41 lines
1.1 KiB
Python
41 lines
1.1 KiB
Python
|
import numpy as np
|
||
|
|
||
|
|
||
|
def get_beta_t(t, **kwargs):
    """Return the time-dependent UCB exploration coefficient for step ``t``.

    Computes ``2 * log(t / sqrt(log(2 * t)))``.

    Args:
        t: Scalar step/iteration counter. Must be > 0.5: ``log(2 * t)`` is
            zero at ``t == 0.5`` (division by zero -> ``inf``) and negative
            below it (square root of a negative -> ``nan``), so smaller
            values cannot produce a usable coefficient.
        **kwargs: Ignored; accepted so every function in ``_UCB_MAP`` can be
            called with the same keyword arguments.

    Returns:
        The coefficient as a Python ``float``.

    Raises:
        ValueError: If ``t`` is not > 0.5 (including NaN).
    """
    # Explicit check instead of ``assert``: asserts are removed under
    # ``python -O``, and ``t > 0`` alone still let (0, 0.5] through to
    # nan/inf results.
    if not t > 0.5:
        raise ValueError(("Invalid value for 't': {} "
                          "(expected > 0.5)").format(t))
    return float(2. * np.log(t / np.sqrt(np.log(2. * t))))
|
||
|
|
||
|
|
||
|
def get_beta_td(t, ndim, bound=1.0, **kwargs):
    """Return a UCB exploration coefficient depending on step and dimension.

    Computes ``bt = 2 * log(ndim * t**2 * pi**2 / (6 * bound))`` and returns
    ``sqrt(bt)``, clamped to ``0.0`` when ``bt`` is not positive.

    NOTE(review): unlike ``get_beta_t`` (which returns its log term directly),
    this returns the square root of the log term — confirm that callers
    expect this difference.

    Args:
        t: Positive scalar step/iteration counter.
        ndim: Positive dimension count (converted with ``float``).
        bound: Positive scale in the denominator; defaults to 1.0.
            NOTE(review): looks like the confidence parameter (delta) of a
            GP-UCB-style schedule — confirm.
        **kwargs: Ignored; accepted so every function in ``_UCB_MAP`` can be
            called with the same keyword arguments.

    Returns:
        A non-negative Python ``float``.

    Raises:
        ValueError: If ``t``, ``ndim`` or ``bound`` is not > 0 (including NaN).
    """
    # Explicit checks instead of ``assert`` so validation survives
    # ``python -O`` (otherwise t == 0 silently becomes log(0) == -inf and the
    # function returns 0.0).
    for name, value in (('t', t), ('ndim', ndim), ('bound', bound)):
        if not value > 0.:
            raise ValueError(("Invalid value for '{}': {} "
                              "(expected > 0.0)").format(name, value))
    bt = 2. * np.log(float(ndim) * t**2 * np.pi**2 / (6. * bound))
    # Clamp: when the log argument is < 1 (small t * ndim), bt is negative and
    # has no real square root.
    return float(np.sqrt(bt)) if bt > 0. else 0.
|
||
|
|
||
|
|
||
|
# Registry of named UCB-beta schedules. ``get_ucb_beta`` looks a string
# ``ucb_beta`` up here and calls the matching function with its kwargs;
# ``check_valid`` uses the keys to decide which names are accepted.
_UCB_MAP = dict(
    get_beta_t=get_beta_t,
    get_beta_td=get_beta_td,
)
|
||
|
|
||
|
|
||
|
def get_ucb_beta(ucb_beta, **kwargs):
    """Resolve ``ucb_beta`` to a non-negative float exploration coefficient.

    Args:
        ucb_beta: Either a non-negative real number, used as-is (ints and
            numpy scalars are converted to ``float``), or the name of a
            schedule registered in ``_UCB_MAP``.
        **kwargs: Forwarded to the named schedule function; ignored when
            ``ucb_beta`` is numeric.

    Returns:
        The coefficient as a float >= 0.0.

    Raises:
        ValueError: From ``check_valid`` when ``ucb_beta`` is a negative
            number or an unknown name. The schedule function may raise on
            invalid ``kwargs``.
    """
    import numbers

    # An int such as ``2`` (or a non-float numpy scalar) used to fall through
    # to the name lookup and be rejected with a misleading "expected
    # 'get_beta_t' or 'get_beta_td'" error. Coerce any real number to float
    # first; bool is excluded because it is an int subclass, not a beta.
    if (isinstance(ucb_beta, numbers.Real)
            and not isinstance(ucb_beta, (bool, float))):
        ucb_beta = float(ucb_beta)

    check_valid(ucb_beta)
    if not isinstance(ucb_beta, float):
        ucb_beta = _UCB_MAP[ucb_beta](**kwargs)
    # Internal sanity checks on the resolved value.
    assert isinstance(ucb_beta, float), type(ucb_beta)
    assert ucb_beta >= 0.0
    return ucb_beta
|
||
|
|
||
|
|
||
|
def check_valid(ucb_beta):
    """Validate a ``ucb_beta`` setting, raising ``ValueError`` if unusable.

    A float is valid when it is >= 0.0 (NaN is rejected). Any other value
    must be one of the schedule names registered in ``_UCB_MAP``.

    Args:
        ucb_beta: A float coefficient or a schedule name.

    Raises:
        ValueError: For a negative or NaN float, or for a value that is not
            a key of ``_UCB_MAP`` (including unhashable values).
    """
    if isinstance(ucb_beta, float):
        # ``not x >= 0`` rather than ``x < 0`` so NaN, which compares False
        # against everything, is rejected here instead of slipping through.
        if not ucb_beta >= 0.0:
            raise ValueError(("Invalid value for 'ucb_beta': {} "
                              "(expected >= 0.0)").format(ucb_beta))
        return
    try:
        known = ucb_beta in _UCB_MAP
    except TypeError:
        # Unhashable values (lists, dicts, ...) cannot be schedule names;
        # report them with the same ValueError as any other bad name.
        known = False
    if not known:
        raise ValueError(("Invalid value for 'ucb_beta': {} "
                          "(expected 'get_beta_t' or 'get_beta_td')").format(ucb_beta))
|