reset graph in DNN
This commit is contained in:
parent
c76c8e7bfb
commit
f0deb63cdf
|
@ -34,7 +34,8 @@ class NeuralNet(object):
|
|||
batch_size=1,
|
||||
explore_iters=500,
|
||||
noise_scale_begin=0.1,
|
||||
noise_scale_end=0):
|
||||
noise_scale_end=0,
|
||||
reset_seed=False):
|
||||
|
||||
self.history = None
|
||||
self.recommend_iters = 0
|
||||
|
@ -49,6 +50,9 @@ class NeuralNet(object):
|
|||
self.vars = {}
|
||||
self.ops = {}
|
||||
|
||||
tf.reset_default_graph()
|
||||
if reset_seed:
|
||||
tf.set_random_seed(0)
|
||||
self.session = tf.Session()
|
||||
self.graph = tf.get_default_graph()
|
||||
with self.graph.as_default():
|
||||
|
|
|
@ -26,17 +26,18 @@ class TestNN(unittest.TestCase):
|
|||
np.random.seed(0)
|
||||
set_random_seed(0)
|
||||
cls.model = NeuralNet(n_input=X_test.shape[1],
|
||||
batch_size=X_test.shape[0])
|
||||
batch_size=X_test.shape[0],
|
||||
reset_seed=True)
|
||||
cls.model.fit(X_train, y_train)
|
||||
cls.nn_result = cls.model.predict(X_test)
|
||||
cls.nn_recommend = cls.model.recommend(X_test)
|
||||
|
||||
def test_nn_ypreds(self):
|
||||
ypreds_round = ['%.3f' % x[0] for x in self.nn_result]
|
||||
expected_ypreds = ['20.021', '22.578', '22.722', '26.889', '24.362', '23.258']
|
||||
expected_ypreds = ['21.279', '22.668', '23.115', '27.228', '25.892', '23.967']
|
||||
self.assertEqual(ypreds_round, expected_ypreds)
|
||||
|
||||
def test_nn_yrecommend(self):
|
||||
recommends_round = ['%.3f' % x[0] for x in self.nn_recommend.minl]
|
||||
expected_recommends = ['13.321', '15.482', '15.621', '18.648', '16.982', '15.986']
|
||||
expected_recommends = ['21.279', '21.279', '21.279', '21.279', '21.279', '21.279']
|
||||
self.assertEqual(recommends_round, expected_recommends)
|
||||
|
|
Loading…
Reference in New Issue