fix ddpg metrics
This commit is contained in:
parent
a9f1556184
commit
21971e673f
|
@ -152,28 +152,22 @@ def clean_metric_data(metric_matrix, metric_labels, session):
|
|||
metric_cat = [session.target_objective]
|
||||
for metric_obj in metric_objs:
|
||||
metric_cat.append(metric_obj.name)
|
||||
matrix = np.array(metric_matrix)
|
||||
missing_columns = sorted(set(metric_cat) - set(metric_labels))
|
||||
unused_columns = set(metric_labels) - set(metric_cat)
|
||||
LOG.debug("clean_metric_data added %d metrics and removed %d metric.", len(missing_columns),
|
||||
len(unused_columns))
|
||||
# If columns are missing from the matrix
|
||||
if missing_columns:
|
||||
for metric in missing_columns:
|
||||
# append a missing column after the last column
|
||||
index = matrix.shape[1] # pylint: disable=unsubscriptable-object
|
||||
default_val = 0
|
||||
matrix = np.insert(matrix, index, default_val, axis=1)
|
||||
metric_labels.append(metric)
|
||||
default_val = 0
|
||||
metric_cat_size = len(metric_cat)
|
||||
matrix = np.ones((len(metric_matrix), metric_cat_size)) * default_val
|
||||
metric_labels_dict = {n: i for i, n in enumerate(metric_labels)}
|
||||
# column labels in matrix has the same order as ones in metric catalog
|
||||
# missing values are filled with default_val
|
||||
for i, metric_name in enumerate(metric_cat):
|
||||
if metric_name in metric_labels_dict:
|
||||
index = metric_labels_dict[metric_name]
|
||||
matrix[:, i] = metric_matrix[:, index]
|
||||
LOG.debug(matrix.shape)
|
||||
# If they are useless columns in the matrix
|
||||
if unused_columns:
|
||||
indexes = [i for i, n in enumerate(metric_labels) if n in unused_columns]
|
||||
# Delete unused columns
|
||||
matrix = np.delete(matrix, indexes, 1)
|
||||
for i in sorted(indexes, reverse=True):
|
||||
del metric_labels[i]
|
||||
return matrix, metric_labels
|
||||
return matrix, metric_cat
|
||||
|
||||
|
||||
def save_execution_time(start_ts, fn, result):
|
||||
|
|
Loading…
Reference in New Issue