upload and download ddpg models

This commit is contained in:
yangdsh 2020-04-17 01:41:48 +00:00 committed by Dana Van Aken
parent ad96c03902
commit 288805aae3
2 changed files with 17 additions and 0 deletions

View File

@ -247,6 +247,9 @@ CELERYD_MAX_TASKS_PER_CHILD = 20
# the task has been executed, not just before
CELERY_ACKS_LATE = False
# Set the upload max size to a large value for uploading DDPG models
DATA_UPLOAD_MAX_MEMORY_SIZE = 102410241024
djcelery.setup_loader()
# ==============================================

View File

@ -4,6 +4,7 @@
# Copyright (c) 2017-18, Carnegie Mellon University Database Group
#
# pylint: disable=too-many-lines
import base64
import csv
import logging
import os
@ -1564,6 +1565,15 @@ def alt_create_or_edit_session(request):
defaults['hardware'] = hardware
defaults['upload_code'] = data.pop('upload_code', None) or MediaUtil.upload_code_generator()
defaults.update(creation_time=ts, last_update=ts, **data)
if 'ddpg_actor_model' in defaults:
defaults['ddpg_actor_model'] =\
base64.decodebytes(defaults['ddpg_actor_model'].encode('utf8'))
defaults['ddpg_critic_model'] =\
base64.decodebytes(defaults['ddpg_critic_model'].encode('utf8'))
defaults['ddpg_reply_memory'] =\
base64.decodebytes(defaults['ddpg_replay_memory'].encode('utf8'))
# There is a typo in the object name. After correcting that typo, remove the next line.
defaults.pop('ddpg_replay_memory')
session, created = Session.objects.get_or_create(user=user, project=project,
name=session_name, defaults=defaults)
@ -1624,6 +1634,10 @@ def alt_create_or_edit_session(request):
res['hardware_id'] = res['hardware']
res['hardware'] = model_to_dict(session.hardware)
res['algorithm'] = AlgorithmType.name(res['algorithm'])
if session.ddpg_actor_model is not None:
res['ddpg_actor_model'] = base64.encodebytes(session.ddpg_actor_model).decode('utf8')
res['ddpg_critic_model'] = base64.encodebytes(session.ddpg_critic_model).decode('utf8')
res['ddpg_replay_memory'] = base64.encodebytes(session.ddpg_reply_memory).decode('utf8')
sk = SessionKnob.objects.get_knobs_for_session(session)
sess_knobs = {}
for knob in sk: