fix ddpg
This commit is contained in:
parent
56860d6364
commit
b4e5fb2e66
|
@ -297,15 +297,20 @@ def preprocessing(result_id, algorithm):
|
|||
# implement a sampling technique to generate new training data).
|
||||
has_pipeline_data = PipelineData.objects.filter(workload=newest_result.workload).exists()
|
||||
session_results = Result.objects.filter(session=session)
|
||||
useful_results_cnt = len(session_results)
|
||||
for result in session_results:
|
||||
results_cnt = len(session_results)
|
||||
skip_ddpg = False
|
||||
for i, result in enumerate(session_results):
|
||||
if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name:
|
||||
useful_results_cnt -= 1
|
||||
if not has_pipeline_data or useful_results_cnt == 0 or session.tuning_session == 'lhs':
|
||||
results_cnt -= 1
|
||||
if i == len(session_results) - 1 and algorithm == AlgorithmType.DDPG:
|
||||
skip_ddpg = True
|
||||
if not has_pipeline_data or results_cnt == 0 or skip_ddpg or session.tuning_session == 'lhs':
|
||||
if not has_pipeline_data and session.tuning_session == 'tuning_session':
|
||||
LOG.debug("Background tasks haven't ran for this workload yet, picking data with lhs.")
|
||||
if useful_results_cnt == 0 and session.tuning_session == 'tuning_session':
|
||||
if results_cnt == 0 and session.tuning_session == 'tuning_session':
|
||||
LOG.debug("Not enough data in this session, picking data with lhs.")
|
||||
if skip_ddpg:
|
||||
LOG.debug("The most recent result cannot be used by DDPG, picking data with lhs.")
|
||||
|
||||
all_samples = JSONUtil.loads(session.lhs_samples)
|
||||
if len(all_samples) == 0:
|
||||
|
@ -457,11 +462,11 @@ def train_ddpg(train_ddpg_input):
|
|||
params = JSONUtil.loads(session.hyperparameters)
|
||||
session_results = Result.objects.filter(session=session,
|
||||
creation_time__lt=result.creation_time)
|
||||
useful_results_cnt = len(session_results)
|
||||
results_cnt = len(session_results)
|
||||
first_valid_result = -1
|
||||
for i, result in enumerate(session_results):
|
||||
if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name:
|
||||
useful_results_cnt -= 1
|
||||
results_cnt -= 1
|
||||
else:
|
||||
last_valid_result = i
|
||||
first_valid_result = i if first_valid_result == -1 else first_valid_result
|
||||
|
@ -470,7 +475,7 @@ def train_ddpg(train_ddpg_input):
|
|||
|
||||
# Extract data from result and previous results
|
||||
result = Result.objects.filter(pk=result_id)
|
||||
if useful_results_cnt == 0:
|
||||
if results_cnt == 0:
|
||||
base_result_id = result_id
|
||||
prev_result_id = result_id
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue