save dnn model into database
This commit is contained in:
committed by
Dana Van Aken
parent
c37ef9c072
commit
25d0838376
@@ -8,6 +8,7 @@ Created on Sep 16, 2019
|
||||
@author: Bohan Zhang
|
||||
'''
|
||||
|
||||
import pickle
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
@@ -28,7 +29,6 @@ class NeuralNet(object):
|
||||
|
||||
def __init__(self,
|
||||
n_input,
|
||||
weights_file,
|
||||
learning_rate=0.01,
|
||||
debug=False,
|
||||
debug_interval=100,
|
||||
@@ -36,9 +36,6 @@ class NeuralNet(object):
|
||||
explore_iters=500,
|
||||
noise_scale_begin=0.1,
|
||||
noise_scale_end=0):
|
||||
# absolute path for the model weitghs file
|
||||
# one model for each (project, session)
|
||||
self.weights_file = weights_file
|
||||
|
||||
self.history = None
|
||||
self.recommend_iters = 0
|
||||
@@ -58,7 +55,6 @@ class NeuralNet(object):
|
||||
layers.Dense(64, activation=tf.nn.relu),
|
||||
layers.Dense(1)
|
||||
])
|
||||
self.load_weights()
|
||||
self.model.compile(loss='mean_squared_error',
|
||||
optimizer=self.optimizer,
|
||||
metrics=['mean_squared_error', 'mean_absolute_error'])
|
||||
@@ -66,17 +62,28 @@ class NeuralNet(object):
|
||||
self.ops = {}
|
||||
self.build_graph()
|
||||
|
||||
def save_weights(self):
|
||||
self.model.save_weights(self.weights_file)
|
||||
def save_weights(self, weights_file):
|
||||
self.model.save_weights(weights_file)
|
||||
|
||||
def load_weights(self):
|
||||
def load_weights(self, weights_file):
|
||||
try:
|
||||
self.model.load_weights(self.weights_file)
|
||||
self.model.load_weights(weights_file)
|
||||
if self.debug:
|
||||
LOG.info('Neural Network Model weights file exists, load weights from the file')
|
||||
except Exception: # pylint: disable=broad-except
|
||||
LOG.info('Weights file does not match neural network model, train model from scratch')
|
||||
|
||||
def get_weights_bin(self):
|
||||
return pickle.dumps(self.model.get_weights())
|
||||
|
||||
def set_weights_bin(self, weights):
|
||||
try:
|
||||
self.model.set_weights(pickle.loads(weights))
|
||||
if self.debug:
|
||||
LOG.info('Neural Network Model weights exists, load the existing weights')
|
||||
except Exception: # pylint: disable=broad-except
|
||||
LOG.info('Weights does not match neural network model, train model from scratch')
|
||||
|
||||
# Build same neural network as self.model, But input X is variables,
|
||||
# weights are placedholders. Find optimial X using gradient descent.
|
||||
def build_graph(self):
|
||||
@@ -109,8 +116,6 @@ class NeuralNet(object):
|
||||
def fit(self, X_train, y_train, fit_epochs=500):
|
||||
self.history = self.model.fit(
|
||||
X_train, y_train, epochs=fit_epochs, verbose=0)
|
||||
# save model weights
|
||||
self.save_weights()
|
||||
if self.debug:
|
||||
mse = self.history.history['mean_squared_error']
|
||||
i = 0
|
||||
|
||||
Reference in New Issue
Block a user