reset graph in DNN

This commit is contained in:
bohanjason 2019-12-16 23:54:29 -05:00 committed by Dana Van Aken
parent c76c8e7bfb
commit f0deb63cdf
2 changed files with 9 additions and 4 deletions

View File

@ -34,7 +34,8 @@ class NeuralNet(object):
batch_size=1, batch_size=1,
explore_iters=500, explore_iters=500,
noise_scale_begin=0.1, noise_scale_begin=0.1,
noise_scale_end=0): noise_scale_end=0,
reset_seed=False):
self.history = None self.history = None
self.recommend_iters = 0 self.recommend_iters = 0
@ -49,6 +50,9 @@ class NeuralNet(object):
self.vars = {} self.vars = {}
self.ops = {} self.ops = {}
tf.reset_default_graph()
if reset_seed:
tf.set_random_seed(0)
self.session = tf.Session() self.session = tf.Session()
self.graph = tf.get_default_graph() self.graph = tf.get_default_graph()
with self.graph.as_default(): with self.graph.as_default():

View File

@ -26,17 +26,18 @@ class TestNN(unittest.TestCase):
np.random.seed(0) np.random.seed(0)
set_random_seed(0) set_random_seed(0)
cls.model = NeuralNet(n_input=X_test.shape[1], cls.model = NeuralNet(n_input=X_test.shape[1],
batch_size=X_test.shape[0]) batch_size=X_test.shape[0],
reset_seed=True)
cls.model.fit(X_train, y_train) cls.model.fit(X_train, y_train)
cls.nn_result = cls.model.predict(X_test) cls.nn_result = cls.model.predict(X_test)
cls.nn_recommend = cls.model.recommend(X_test) cls.nn_recommend = cls.model.recommend(X_test)
def test_nn_ypreds(self): def test_nn_ypreds(self):
ypreds_round = ['%.3f' % x[0] for x in self.nn_result] ypreds_round = ['%.3f' % x[0] for x in self.nn_result]
expected_ypreds = ['20.021', '22.578', '22.722', '26.889', '24.362', '23.258'] expected_ypreds = ['21.279', '22.668', '23.115', '27.228', '25.892', '23.967']
self.assertEqual(ypreds_round, expected_ypreds) self.assertEqual(ypreds_round, expected_ypreds)
def test_nn_yrecommend(self): def test_nn_yrecommend(self):
recommends_round = ['%.3f' % x[0] for x in self.nn_recommend.minl] recommends_round = ['%.3f' % x[0] for x in self.nn_recommend.minl]
expected_recommends = ['13.321', '15.482', '15.621', '18.648', '16.982', '15.986'] expected_recommends = ['21.279', '21.279', '21.279', '21.279', '21.279', '21.279']
self.assertEqual(recommends_round, expected_recommends) self.assertEqual(recommends_round, expected_recommends)