fix ddpg
This commit is contained in:
parent
56860d6364
commit
b4e5fb2e66
|
@ -297,15 +297,20 @@ def preprocessing(result_id, algorithm):
|
||||||
# implement a sampling technique to generate new training data).
|
# implement a sampling technique to generate new training data).
|
||||||
has_pipeline_data = PipelineData.objects.filter(workload=newest_result.workload).exists()
|
has_pipeline_data = PipelineData.objects.filter(workload=newest_result.workload).exists()
|
||||||
session_results = Result.objects.filter(session=session)
|
session_results = Result.objects.filter(session=session)
|
||||||
useful_results_cnt = len(session_results)
|
results_cnt = len(session_results)
|
||||||
for result in session_results:
|
skip_ddpg = False
|
||||||
|
for i, result in enumerate(session_results):
|
||||||
if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name:
|
if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name:
|
||||||
useful_results_cnt -= 1
|
results_cnt -= 1
|
||||||
if not has_pipeline_data or useful_results_cnt == 0 or session.tuning_session == 'lhs':
|
if i == len(session_results) - 1 and algorithm == AlgorithmType.DDPG:
|
||||||
|
skip_ddpg = True
|
||||||
|
if not has_pipeline_data or results_cnt == 0 or skip_ddpg or session.tuning_session == 'lhs':
|
||||||
if not has_pipeline_data and session.tuning_session == 'tuning_session':
|
if not has_pipeline_data and session.tuning_session == 'tuning_session':
|
||||||
LOG.debug("Background tasks haven't ran for this workload yet, picking data with lhs.")
|
LOG.debug("Background tasks haven't ran for this workload yet, picking data with lhs.")
|
||||||
if useful_results_cnt == 0 and session.tuning_session == 'tuning_session':
|
if results_cnt == 0 and session.tuning_session == 'tuning_session':
|
||||||
LOG.debug("Not enough data in this session, picking data with lhs.")
|
LOG.debug("Not enough data in this session, picking data with lhs.")
|
||||||
|
if skip_ddpg:
|
||||||
|
LOG.debug("The most recent result cannot be used by DDPG, picking data with lhs.")
|
||||||
|
|
||||||
all_samples = JSONUtil.loads(session.lhs_samples)
|
all_samples = JSONUtil.loads(session.lhs_samples)
|
||||||
if len(all_samples) == 0:
|
if len(all_samples) == 0:
|
||||||
|
@ -457,11 +462,11 @@ def train_ddpg(train_ddpg_input):
|
||||||
params = JSONUtil.loads(session.hyperparameters)
|
params = JSONUtil.loads(session.hyperparameters)
|
||||||
session_results = Result.objects.filter(session=session,
|
session_results = Result.objects.filter(session=session,
|
||||||
creation_time__lt=result.creation_time)
|
creation_time__lt=result.creation_time)
|
||||||
useful_results_cnt = len(session_results)
|
results_cnt = len(session_results)
|
||||||
first_valid_result = -1
|
first_valid_result = -1
|
||||||
for i, result in enumerate(session_results):
|
for i, result in enumerate(session_results):
|
||||||
if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name:
|
if 'range_test' in result.metric_data.name or 'default' in result.metric_data.name:
|
||||||
useful_results_cnt -= 1
|
results_cnt -= 1
|
||||||
else:
|
else:
|
||||||
last_valid_result = i
|
last_valid_result = i
|
||||||
first_valid_result = i if first_valid_result == -1 else first_valid_result
|
first_valid_result = i if first_valid_result == -1 else first_valid_result
|
||||||
|
@ -470,7 +475,7 @@ def train_ddpg(train_ddpg_input):
|
||||||
|
|
||||||
# Extract data from result and previous results
|
# Extract data from result and previous results
|
||||||
result = Result.objects.filter(pk=result_id)
|
result = Result.objects.filter(pk=result_id)
|
||||||
if useful_results_cnt == 0:
|
if results_cnt == 0:
|
||||||
base_result_id = result_id
|
base_result_id = result_id
|
||||||
prev_result_id = result_id
|
prev_result_id = result_id
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue