diff --git a/server/website/website/tasks/async_tasks.py b/server/website/website/tasks/async_tasks.py index 3e98cd6..ca8e533 100644 --- a/server/website/website/tasks/async_tasks.py +++ b/server/website/website/tasks/async_tasks.py @@ -75,6 +75,16 @@ class MapWorkloadTask(BaseTask): # pylint: disable=abstract-method task_meta.save() +class ConfigurationRecommendation(BaseTask): # pylint: disable=abstract-method + + def on_success(self, retval, task_id, args, kwargs): + super(ConfigurationRecommendation, self).on_success(retval, task_id, args, kwargs) + + task_meta = TaskMeta.objects.get(task_id=task_id) + task_meta.result = retval + task_meta.save() + + def clean_knob_data(knob_matrix, knob_labels, session): # Makes sure that all knobs in the dbms are included in the knob_matrix and knob_labels knob_matrix = np.array(knob_matrix) @@ -420,7 +430,7 @@ def create_and_save_recommendation(recommended_knobs, result, status, **kwargs): return retval -@shared_task(base=BaseTask, name='configuration_recommendation_ddpg') +@shared_task(base=ConfigurationRecommendation, name='configuration_recommendation_ddpg') def configuration_recommendation_ddpg(result_info): # pylint: disable=invalid-name LOG.info('Use ddpg to recommend configuration') result_id = result_info['newest_result_id'] @@ -640,7 +650,7 @@ def combine_workload(target_data): dummy_encoder, constraint_helper -@shared_task(base=BaseTask, name='configuration_recommendation') +@shared_task(base=ConfigurationRecommendation, name='configuration_recommendation') def configuration_recommendation(recommendation_input): target_data, algorithm = recommendation_input LOG.info('configuration_recommendation called') diff --git a/server/website/website/utils.py b/server/website/website/utils.py index 1e49d98..5e0a5b8 100644 --- a/server/website/website/utils.py +++ b/server/website/website/utils.py @@ -3,6 +3,7 @@ # # Copyright (c) 2017-18, Carnegie Mellon University Database Group # +import celery import datetime import json import logging @@ -73,6 +74,16 @@ class MediaUtil(object): class TaskUtil(object): + @staticmethod + def get_task_ids_from_tuple(task_tuple): + task_res = celery.result.result_from_tuple(task_tuple) + task_ids = [] + task = task_res + while task is not None: + task_ids.insert(0, task) + task = task.parent + return task_ids + @staticmethod def get_tasks(task_ids): task_ids = task_ids or [] diff --git a/server/website/website/views.py b/server/website/website/views.py index 00b321e..b7250ac 100644 --- a/server/website/website/views.py +++ b/server/website/website/views.py @@ -403,16 +403,24 @@ def result_view(request, project_id, session_id, result_id): # default_metrics = {mname: metric_data[mname] * metric_meta[mname].scale # for mname in default_metrics} + if session.tuning_session == 'no_tuning_session': + status = None + next_conf = '' + next_conf_available = False + else: + task_tuple = JSONUtil.loads(target.task_ids) + task_ids = TaskUtil.get_task_ids_from_tuple(task_tuple) + tasks = TaskUtil.get_tasks(task_ids) + status, _ = TaskUtil.get_task_status(tasks, len(task_ids)) - task_ids = [t for t in (target.task_ids or '').split(',') if t.strip() != ''] - tasks = TaskUtil.get_tasks(task_ids) - status, _ = TaskUtil.get_task_status(tasks, len(task_ids)) - - next_conf_available = True if status == 'SUCCESS' else False - next_conf = '' - cfg = target.next_configuration - LOG.debug("status: %s, next_conf_available: %s, next_conf: %s, type: %s", - status, next_conf_available, cfg, type(cfg)) + if status == 'SUCCESS': # pylint: disable=simplifiable-if-statement + next_conf_available = True + else: + next_conf_available = False + next_conf = '' + cfg = target.next_configuration + LOG.debug("status: %s, next_conf_available: %s, next_conf: %s, type: %s", + status, next_conf_available, cfg, type(cfg)) if next_conf_available: try: @@ -917,8 +925,8 @@ def download_debug_info(request, project_id, session_id): # pylint: disable=unu @login_required(login_url=reverse_lazy('login')) def tuner_status_view(request, project_id, session_id, result_id): # pylint: disable=unused-argument res = Result.objects.get(pk=result_id) - - task_ids = [t for t in (res.task_ids or '').split(',') if t.strip() != ''] + task_tuple = JSONUtil.loads(res.task_ids) + task_ids = TaskUtil.get_task_ids_from_tuple(task_tuple) tasks = TaskUtil.get_tasks(task_ids) overall_status, num_completed = TaskUtil.get_task_status(tasks, len(task_ids))