diff --git a/server/analysis/simulation.py b/server/analysis/simulation.py index 9e69f8f..dff650f 100644 --- a/server/analysis/simulation.py +++ b/server/analysis/simulation.py @@ -4,10 +4,13 @@ # Copyright (c) 2017-18, Carnegie Mellon University Database Group # + +import imp import random import os import sys try: + imp.find_module('matplotlib.pyplot') import matplotlib.pyplot as plt except (ModuleNotFoundError, ImportError): plt = None @@ -111,7 +114,7 @@ def plotlines(x_axis, data1, data2, label1, label2, title, path): plt.clf() -def main(knob_dim=192, metric_dim=60, lr=0.001, mode=0, n_loops=1000): +def main(knob_dim=8, metric_dim=60, lr=0.0001, mode=2, n_loops=2000): if not plt: LOG.info("Cannot import matplotlib. Will write results to files instead of figures.") random.seed(0) @@ -119,7 +122,7 @@ def main(knob_dim=192, metric_dim=60, lr=0.001, mode=0, n_loops=1000): torch.manual_seed(0) env = Environment(knob_dim, metric_dim, mode=mode) - n_repeats = 5 + n_repeats = 10 for i in range(n_repeats): if i == 0: results1, x_axis = train_ddpg(env, gamma=0, lr=lr, n_loops=n_loops)