work around a pylint & astroid bug

This commit is contained in:
yangdsh 2019-10-22 22:22:18 +00:00 committed by Dana Van Aken
parent 794418d29f
commit d71c131e5b
1 changed files with 5 additions and 2 deletions

View File

@ -4,10 +4,13 @@
# Copyright (c) 2017-18, Carnegie Mellon University Database Group # Copyright (c) 2017-18, Carnegie Mellon University Database Group
# #
import imp
import random import random
import os import os
import sys import sys
try: try:
imp.find_module('matplotlib.pyplot')
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
except (ModuleNotFoundError, ImportError): except (ModuleNotFoundError, ImportError):
plt = None plt = None
@ -111,7 +114,7 @@ def plotlines(x_axis, data1, data2, label1, label2, title, path):
plt.clf() plt.clf()
def main(knob_dim=192, metric_dim=60, lr=0.001, mode=0, n_loops=1000): def main(knob_dim=8, metric_dim=60, lr=0.0001, mode=2, n_loops=2000):
if not plt: if not plt:
LOG.info("Cannot import matplotlib. Will write results to files instead of figures.") LOG.info("Cannot import matplotlib. Will write results to files instead of figures.")
random.seed(0) random.seed(0)
@ -119,7 +122,7 @@ def main(knob_dim=192, metric_dim=60, lr=0.001, mode=0, n_loops=1000):
torch.manual_seed(0) torch.manual_seed(0)
env = Environment(knob_dim, metric_dim, mode=mode) env = Environment(knob_dim, metric_dim, mode=mode)
n_repeats = 5 n_repeats = 10
for i in range(n_repeats): for i in range(n_repeats):
if i == 0: if i == 0:
results1, x_axis = train_ddpg(env, gamma=0, lr=lr, n_loops=n_loops) results1, x_axis = train_ddpg(env, gamma=0, lr=lr, n_loops=n_loops)