fix bugs
This commit is contained in:
		
							parent
							
								
									24194293bc
								
							
						
					
					
						commit
						d3c7bb661d
					
				|  | @ -25,8 +25,7 @@ from analysis.gpr import ucb | |||
| from analysis.gpr.optimize import tf_optimize | ||||
| from analysis.preprocessing import Bin, DummyEncoder | ||||
| from analysis.constraints import ParamConstraintHelper | ||||
| from website.models import (PipelineData, PipelineRun, Result, Workload, KnobCatalog, SessionKnob, | ||||
|                             MetricCatalog) | ||||
| from website.models import PipelineData, PipelineRun, Result, Workload, SessionKnob, MetricCatalog | ||||
| from website import db | ||||
| from website.types import PipelineTaskType, AlgorithmType, VarType | ||||
| from website.utils import DataUtil, JSONUtil | ||||
|  | @ -336,7 +335,7 @@ def train_ddpg(result_id): | |||
|         prev_result_id = result_id | ||||
|     else: | ||||
|         base_result_id = session_results[0].pk | ||||
|         prev_result_id = session_results[len(session_results)-1].pk | ||||
|         prev_result_id = session_results[len(session_results) - 1].pk | ||||
|     base_result = Result.objects.filter(pk=base_result_id) | ||||
|     prev_result = Result.objects.filter(pk=prev_result_id) | ||||
| 
 | ||||
|  | @ -381,7 +380,7 @@ def train_ddpg(result_id): | |||
|         result.session.dbms.pk, result.session.target_objective) | ||||
| 
 | ||||
|     # Calculate the reward | ||||
|     if DDPG_SIMPLE_REWARD: | ||||
|     if params['DDPG_SIMPLE_REWARD']: | ||||
|         objective = objective / base_objective | ||||
|         if metric_meta[target_objective].improvement == '(less is better)': | ||||
|             reward = -objective | ||||
|  | @ -652,7 +651,8 @@ def combine_workload(target_data): | |||
|         X_min[i] = col_min | ||||
|         X_max[i] = col_max | ||||
| 
 | ||||
|     return X_columnlabels, X_scaler, X_scaled, y_scaled, X_max, X_min | ||||
|     return X_columnlabels, X_scaler, X_scaled, y_scaled, X_max, X_min,\ | ||||
|         dummy_encoder, constraint_helper | ||||
| 
 | ||||
| 
 | ||||
| @task(base=ConfigurationRecommendation, name='configuration_recommendation') | ||||
|  | @ -661,7 +661,7 @@ def configuration_recommendation(recommendation_input): | |||
|     LOG.info('configuration_recommendation called') | ||||
|     newest_result = Result.objects.get(pk=target_data['newest_result_id']) | ||||
|     session = newest_result.session | ||||
|     params = session.hyper_parameters | ||||
|     params = JSONUtil.loads(session.hyper_parameters) | ||||
| 
 | ||||
|     if target_data['bad'] is True: | ||||
|         target_data_res = create_and_save_recommendation( | ||||
|  | @ -672,9 +672,8 @@ def configuration_recommendation(recommendation_input): | |||
|                   AlgorithmType.name(algorithm), JSONUtil.dumps(target_data, pprint=True)) | ||||
|         return target_data_res | ||||
| 
 | ||||
|     latest_pipeline_run = PipelineRun.objects.get(pk=target_data['pipeline_run']) | ||||
| 
 | ||||
|     X_columnlabels, X_scaler, X_scaled, y_scaled, X_max, X_min = combine_workload(target_data) | ||||
|     X_columnlabels, X_scaler, X_scaled, y_scaled, X_max, X_min,\ | ||||
|         dummy_encoder, constraint_helper = combine_workload(target_data) | ||||
| 
 | ||||
|     # FIXME: we should generate more samples and use a smarter sampling | ||||
|     # technique | ||||
|  | @ -698,9 +697,9 @@ def configuration_recommendation(recommendation_input): | |||
|             # make sure it is within the range. | ||||
|             dist = sum(np.square(X_max - X_scaled[item[1]])) | ||||
|             if dist < 0.001: | ||||
|                 X_samples = np.vstack((X_samples, X_scaled[item[1]] - abs(GPR_EPS))) | ||||
|                 X_samples = np.vstack((X_samples, X_scaled[item[1]] - abs(params['GPR_EPS']))) | ||||
|             else: | ||||
|                 X_samples = np.vstack((X_samples, X_scaled[item[1]] + abs(GPR_EPS))) | ||||
|                 X_samples = np.vstack((X_samples, X_scaled[item[1]] + abs(params['GPR_EPS']))) | ||||
|             i = i + 1 | ||||
|         except queue.Empty: | ||||
|             break | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue