From 26e2299908d58df68069ac4e9af018c3c7b3f5d4 Mon Sep 17 00:00:00 2001 From: bharath Date: Sun, 14 Oct 2018 00:22:56 +0530 Subject: [PATCH] Add tensorflow driver implementation Change-Id: I951cea9325d2ea4a843ea55d1731c481df899474 --- devstack/lib/gyan | 1 + gyan/api/controllers/v1/__init__.py | 10 +- gyan/api/controllers/v1/ml_models.py | 121 ++++++++++-------- gyan/api/controllers/v1/schemas/ml_models.py | 9 +- .../controllers/v1/schemas/parameter_types.py | 14 ++ .../controllers/v1/views/ml_models_view.py | 33 +++-- gyan/api/middleware/parsable_error.py | 2 - gyan/api/utils.py | 2 +- gyan/common/consts.py | 5 +- gyan/common/policies/ml_model.py | 15 ++- gyan/common/service.py | 3 +- gyan/common/utils.py | 14 +- gyan/compute/api.py | 18 ++- gyan/compute/manager.py | 27 ++-- gyan/compute/rpcapi.py | 7 +- gyan/conf/scheduler.py | 3 - ...99_add_ml_type_and_ml_data_to_ml_model_.py | 39 ++++++ gyan/db/sqlalchemy/api.py | 19 +-- gyan/db/sqlalchemy/models.py | 12 +- gyan/ml_model/tensorflow/driver.py | 25 +++- gyan/objects/fields.py | 15 +++ gyan/objects/ml_model.py | 42 +++++- gyan/tests/base.py | 5 - requirements.txt | 27 +++- setup.py | 2 - 25 files changed, 332 insertions(+), 138 deletions(-) create mode 100644 gyan/db/sqlalchemy/alembic/versions/f3bf9414f399_add_ml_type_and_ml_data_to_ml_model_.py diff --git a/devstack/lib/gyan b/devstack/lib/gyan index caabb7e..ad5c525 100644 --- a/devstack/lib/gyan +++ b/devstack/lib/gyan @@ -305,6 +305,7 @@ function start_gyan_compute { function start_gyan { # ``run_process`` checks ``is_service_enabled``, it is not needed here + mkdir -p /opt/stack/data/gyan start_gyan_api start_gyan_compute } diff --git a/gyan/api/controllers/v1/__init__.py b/gyan/api/controllers/v1/__init__.py index 02f32ec..c4d908e 100644 --- a/gyan/api/controllers/v1/__init__.py +++ b/gyan/api/controllers/v1/__init__.py @@ -82,10 +82,10 @@ class V1(controllers_base.APIBase): 'hosts', '', bookmark=True)] v1.ml_models = [link.make_link('self', pecan.request.host_url, - 'ml_models', ''), + 'ml-models', ''), link.make_link('bookmark', pecan.request.host_url, - 'ml_models', '', + 'ml-models', '', bookmark=True)] return v1 @@ -147,9 +147,9 @@ class Controller(controllers_base.Controller): {'url': pecan.request.url, 'method': pecan.request.method, 'body': pecan.request.body}) - LOG.debug(msg) - + # LOG.debug(msg) + LOG.debug(args) return super(Controller, self)._route(args) -__all__ = ('Controller',) \ No newline at end of file +__all__ = ('Controller',) diff --git a/gyan/api/controllers/v1/ml_models.py b/gyan/api/controllers/v1/ml_models.py index 9e0bf3e..6a48f0e 100644 --- a/gyan/api/controllers/v1/ml_models.py +++ b/gyan/api/controllers/v1/ml_models.py @@ -10,6 +10,7 @@ # License for the specific language governing permissions and limitations # under the License. +import base64 import shlex from oslo_log import log as logging @@ -74,12 +75,13 @@ class MLModelController(base.Controller): """Controller for MLModels.""" _custom_actions = { - 'train': ['POST'], + 'upload_trained_model': ['POST'], 'deploy': ['GET'], - 'undeploy': ['GET'] + 'undeploy': ['GET'], + 'predict': ['POST'] } - + @pecan.expose('json') @exception.wrap_pecan_controller_exception def get_all(self, **kwargs): @@ -149,33 +151,55 @@ class MLModelController(base.Controller): context.all_projects = True ml_model = utils.get_ml_model(ml_model_ident) check_policy_on_ml_model(ml_model.as_dict(), "ml_model:get_one") - if ml_model.node: - compute_api = pecan.request.compute_api - try: - ml_model = compute_api.ml_model_show(context, ml_model) - except exception.MLModelHostNotUp: - raise exception.ServerNotUsable - return view.format_ml_model(context, pecan.request.host_url, ml_model.as_dict()) + @base.Controller.api_version("1.0") + @pecan.expose('json') + @exception.wrap_pecan_controller_exception + def upload_trained_model(self, ml_model_ident, **kwargs): + context = pecan.request.context + LOG.debug(ml_model_ident) + ml_model = utils.get_ml_model(ml_model_ident) + LOG.debug(ml_model) + ml_model.ml_data = pecan.request.body + ml_model.save(context) + pecan.response.status = 200 + compute_api = pecan.request.compute_api + new_model = view.format_ml_model(context, pecan.request.host_url, + ml_model.as_dict()) + compute_api.ml_model_create(context, new_model) + return new_model + + @base.Controller.api_version("1.0") + @pecan.expose('json') + @exception.wrap_pecan_controller_exception + def predict(self, ml_model_ident, **kwargs): + context = pecan.request.context + LOG.debug(ml_model_ident) + ml_model = utils.get_ml_model(ml_model_ident) + pecan.response.status = 200 + compute_api = pecan.request.compute_api + predict_dict = { + "data": base64.b64encode(pecan.request.POST['file'].file.read()) + } + prediction = compute_api.ml_model_predict(context, ml_model_ident, **predict_dict) + return prediction + @base.Controller.api_version("1.0") @pecan.expose('json') @api_utils.enforce_content_types(['application/json']) @exception.wrap_pecan_controller_exception - @validation.validate_query_param(pecan.request, schema.query_param_create) @validation.validated(schema.ml_model_create) def post(self, **ml_model_dict): return self._do_post(**ml_model_dict) - def _do_post(self, **ml_model_dict): """Create or run a new ml model. :param ml_model_dict: a ml_model within the request body. """ context = pecan.request.context - compute_api = pecan.request.compute_api policy.enforce(context, "ml_model:create", action="ml_model:create") @@ -183,22 +207,24 @@ class MLModelController(base.Controller): ml_model_dict['user_id'] = context.user_id name = ml_model_dict.get('name') ml_model_dict['name'] = name - - ml_model_dict['status'] = consts.CREATING + + ml_model_dict['status'] = consts.CREATED + ml_model_dict['ml_type'] = ml_model_dict['type'] extra_spec = {} extra_spec['hints'] = ml_model_dict.get('hints', None) + #ml_model_dict["model_data"] = open("/home/bharath/model.zip", "rb").read() new_ml_model = objects.ML_Model(context, **ml_model_dict) - new_ml_model.create(context) - - compute_api.ml_model_create(context, new_ml_model, **kwargs) + ml_model = new_ml_model.create(context) + LOG.debug(new_ml_model) + #compute_api.ml_model_create(context, new_ml_model) # Set the HTTP Location Header pecan.response.location = link.build_url('ml_models', - new_ml_model.uuid) - pecan.response.status = 202 - return view.format_ml_model(context, pecan.request.node_url, - new_ml_model.as_dict()) + ml_model.id) + pecan.response.status = 201 + return view.format_ml_model(context, pecan.request.host_url, + ml_model.as_dict()) + - @pecan.expose('json') @exception.wrap_pecan_controller_exception @validation.validated(schema.ml_model_update) @@ -217,11 +243,11 @@ class MLModelController(base.Controller): return view.format_ml_model(context, pecan.request.node_url, ml_model.as_dict()) - + @pecan.expose('json') @exception.wrap_pecan_controller_exception @validation.validate_query_param(pecan.request, schema.query_param_delete) - def delete(self, ml_model_ident, force=False, **kwargs): + def delete(self, ml_model_ident, **kwargs): """Delete a ML Model. :param ml_model_ident: UUID or Name of a ML Model. @@ -230,27 +256,7 @@ class MLModelController(base.Controller): context = pecan.request.context ml_model = utils.get_ml_model(ml_model_ident) check_policy_on_ml_model(ml_model.as_dict(), "ml_model:delete") - try: - force = strutils.bool_from_string(force, strict=True) - except ValueError: - bools = ', '.join(strutils.TRUE_STRINGS + strutils.FALSE_STRINGS) - raise exception.InvalidValue(_('Valid force values are: %s') - % bools) - stop = kwargs.pop('stop', False) - try: - stop = strutils.bool_from_string(stop, strict=True) - except ValueError: - bools = ', '.join(strutils.TRUE_STRINGS + strutils.FALSE_STRINGS) - raise exception.InvalidValue(_('Valid stop values are: %s') - % bools) - compute_api = pecan.request.compute_api - if not force: - utils.validate_ml_model_state(ml_model, 'delete') - ml_model.status = consts.DELETING - if ml_model.node: - compute_api.ml_model_delete(context, ml_model, force) - else: - ml_model.destroy(context) + ml_model.destroy(context) pecan.response.status = 204 @@ -261,15 +267,19 @@ class MLModelController(base.Controller): :param ml_model_ident: UUID or Name of a ML Model. """ + context = pecan.request.context ml_model = utils.get_ml_model(ml_model_ident) check_policy_on_ml_model(ml_model.as_dict(), "ml_model:deploy") utils.validate_ml_model_state(ml_model, 'deploy') LOG.debug('Calling compute.ml_model_deploy with %s', - ml_model.uuid) - context = pecan.request.context - compute_api = pecan.request.compute_api - compute_api.ml_model_deploy(context, ml_model) + ml_model.id) + ml_model.status = consts.DEPLOYED + url = pecan.request.url.replace("deploy", "predict") + ml_model.url = url + ml_model.save(context) pecan.response.status = 202 + return view.format_ml_model(context, pecan.request.host_url, + ml_model.as_dict()) @pecan.expose('json') @exception.wrap_pecan_controller_exception @@ -278,12 +288,15 @@ class MLModelController(base.Controller): :param ml_model_ident: UUID or Name of a ML Model. """ + context = pecan.request.context ml_model = utils.get_ml_model(ml_model_ident) check_policy_on_ml_model(ml_model.as_dict(), "ml_model:deploy") utils.validate_ml_model_state(ml_model, 'undeploy') LOG.debug('Calling compute.ml_model_deploy with %s', - ml_model.uuid) - context = pecan.request.context - compute_api = pecan.request.compute_api - compute_api.ml_model_undeploy(context, ml_model) + ml_model.id) + ml_model.status = consts.SCHEDULED + ml_model.url = None + ml_model.save(context) pecan.response.status = 202 + return view.format_ml_model(context, pecan.request.host_url, + ml_model.as_dict()) diff --git a/gyan/api/controllers/v1/schemas/ml_models.py b/gyan/api/controllers/v1/schemas/ml_models.py index 07c2eb5..5f64a12 100644 --- a/gyan/api/controllers/v1/schemas/ml_models.py +++ b/gyan/api/controllers/v1/schemas/ml_models.py @@ -18,8 +18,11 @@ _ml_model_properties = {} ml_model_create = { 'type': 'object', - 'properties': _ml_model_properties, - 'required': ['name'], + 'properties': { + "name": parameter_types.ml_model_name, + "type": parameter_types.ml_model_type + }, + 'required': ['name', 'type'], 'additionalProperties': False } @@ -46,4 +49,4 @@ query_param_delete = { 'stop': parameter_types.boolean_extended }, 'additionalProperties': False -} \ No newline at end of file +} diff --git a/gyan/api/controllers/v1/schemas/parameter_types.py b/gyan/api/controllers/v1/schemas/parameter_types.py index 4a2ecab..fa1f64c 100644 --- a/gyan/api/controllers/v1/schemas/parameter_types.py +++ b/gyan/api/controllers/v1/schemas/parameter_types.py @@ -95,3 +95,17 @@ hostname = { # real systems. 'pattern': '^[a-zA-Z0-9-._]*$', } + +ml_model_name = { + 'type': 'string', + 'minLength': 1, + 'maxLength': 255, + 'pattern': '^[a-zA-Z0-9-._]*$' +} + +ml_model_type = { + 'type': 'string', + 'minLength': 1, + 'maxLength': 255, + 'pattern': '^[a-zA-Z0-9-._]*$' +} diff --git a/gyan/api/controllers/v1/views/ml_models_view.py b/gyan/api/controllers/v1/views/ml_models_view.py index 3721fec..50af09c 100644 --- a/gyan/api/controllers/v1/views/ml_models_view.py +++ b/gyan/api/controllers/v1/views/ml_models_view.py @@ -13,41 +13,46 @@ import itertools +from oslo_log import log as logging + from gyan.api.controllers import link from gyan.common.policies import ml_model as policies _basic_keys = ( - 'uuid', + 'id', 'user_id', 'project_id', 'name', 'url', 'status', 'status_reason', - 'task_state', - 'labels', - 'host', - 'status_detail' + 'host_id', + 'deployed', + 'ml_type' ) +LOG = logging.getLogger(__name__) + def format_ml_model(context, url, ml_model): def transform(key, value): + LOG.debug(key) + LOG.debug(value) if key not in _basic_keys: return # strip the key if it is not allowed by policy policy_action = policies.ML_MODEL % ('get_one:%s' % key) if not context.can(policy_action, fatal=False, might_not_exist=True): return - if key == 'uuid': - yield ('uuid', value) - if url: - yield ('links', [link.make_link( - 'self', url, 'ml_models', value), - link.make_link( - 'bookmark', url, - 'ml_models', value, - bookmark=True)]) + if key == 'id': + yield ('id', value) + # if url: + # yield ('links', [link.make_link( + # 'self', url, 'ml_models', value), + # link.make_link( + # 'bookmark', url, + # 'ml_models', value, + # bookmark=True)]) else: yield (key, value) diff --git a/gyan/api/middleware/parsable_error.py b/gyan/api/middleware/parsable_error.py index 63e61a4..7511694 100644 --- a/gyan/api/middleware/parsable_error.py +++ b/gyan/api/middleware/parsable_error.py @@ -1,5 +1,3 @@ -# Copyright ? 2012 New Dream Network, LLC (DreamHost) -# # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at diff --git a/gyan/api/utils.py b/gyan/api/utils.py index d48aa8b..d83291c 100644 --- a/gyan/api/utils.py +++ b/gyan/api/utils.py @@ -113,4 +113,4 @@ def version_check(action, version): if req_version < min_version: raise exception.InvalidParamInVersion(param=action, req_version=req_version, - min_version=min_version) \ No newline at end of file + min_version=min_version) diff --git a/gyan/common/consts.py b/gyan/common/consts.py index d45edaa..a086a0c 100644 --- a/gyan/common/consts.py +++ b/gyan/common/consts.py @@ -14,4 +14,7 @@ ALLOCATED = 'allocated' CREATED = 'created' UNDEPLOYED = 'undeployed' -DEPLOYED = 'deployed' \ No newline at end of file +DEPLOYED = 'deployed' +CREATING = 'CREATING' +CREATED = 'CREATED' +SCHEDULED = 'SCHEDULED' \ No newline at end of file diff --git a/gyan/common/policies/ml_model.py b/gyan/common/policies/ml_model.py index f45b167..6368c20 100644 --- a/gyan/common/policies/ml_model.py +++ b/gyan/common/policies/ml_model.py @@ -106,16 +106,27 @@ rules = [ ] ), policy.DocumentedRuleDefault( - name=ML_MODEL % 'upload', + name=ML_MODEL % 'upload_trained_model', check_str=base.RULE_ADMIN_OR_OWNER, description='Upload the trained ML Model', operations=[ { - 'path': '/v1/ml_models/{ml_model_ident}/upload', + 'path': '/v1/ml_models/{ml_model_ident}/upload_trained_model', 'method': 'POST' } ] ), + policy.DocumentedRuleDefault( + name=ML_MODEL % 'deploy', + check_str=base.RULE_ADMIN_OR_OWNER, + description='Upload the trained ML Model', + operations=[ + { + 'path': '/v1/ml_models/{ml_model_ident}/deploy', + 'method': 'GET' + } + ] + ), ] diff --git a/gyan/common/service.py b/gyan/common/service.py index f0b8e64..9ffe131 100644 --- a/gyan/common/service.py +++ b/gyan/common/service.py @@ -27,7 +27,8 @@ CONF = gyan.conf.CONF def prepare_service(argv=None): if argv is None: - argv = [] + argv = ['/usr/local/bin/gyan-api', '--config-file', '/etc/gyan/gyan.conf'] + argv = ['/usr/local/bin/gyan-api', '--config-file', '/etc/gyan/gyan.conf'] log.register_options(CONF) config.parse_args(argv) config.set_config_defaults() diff --git a/gyan/common/utils.py b/gyan/common/utils.py index f3950a7..0dd9ba9 100644 --- a/gyan/common/utils.py +++ b/gyan/common/utils.py @@ -23,6 +23,7 @@ import functools import inspect import json import mimetypes +import os from oslo_concurrency import processutils from oslo_context import context as common_context @@ -44,7 +45,7 @@ CONF = gyan.conf.CONF LOG = logging.getLogger(__name__) VALID_STATES = { - 'deploy': [consts.CREATED, consts.UNDEPLOYED], + 'deploy': [consts.CREATED, consts.UNDEPLOYED, consts.SCHEDULED], 'undeploy': [consts.DEPLOYED] } def safe_rstrip(value, chars=None): @@ -162,7 +163,7 @@ def get_ml_model(ml_model_ident): def validate_ml_model_state(ml_model, action): if ml_model.status not in VALID_STATES[action]: raise exception.InvalidStateException( - id=ml_model.uuid, + id=ml_model.id, action=action, actual_state=ml_model.status) @@ -253,3 +254,12 @@ def decode_file_data(data): return base64.b64decode(data) except (TypeError, binascii.Error): raise exception.Base64Exception() + + +def save_model(path, model): + file_path = os.path.join(path, model.id) + with open(file_path+'.zip', 'wb') as f: + f.write(model.ml_data) + zip_ref = zipfile.ZipFile(file_path+'.zip', 'r') + zip_ref.extractall(file_path) + zip_ref.close() \ No newline at end of file diff --git a/gyan/compute/api.py b/gyan/compute/api.py index cf62ed0..bf687c5 100644 --- a/gyan/compute/api.py +++ b/gyan/compute/api.py @@ -28,7 +28,6 @@ CONF = gyan.conf.CONF LOG = logging.getLogger(__name__) -@profiler.trace_cls("rpc") class API(object): """API for interacting with the compute manager.""" @@ -36,10 +35,11 @@ class API(object): self.rpcapi = rpcapi.API(context=context) super(API, self).__init__() - def ml_model_create(self, context, new_ml_model, extra_spec): + def ml_model_create(self, context, new_ml_model, **extra_spec): try: - host_state = self._schedule_ml_model(context, ml_model, - extra_spec) + host_state = { + "host": "localhost" + } #self._schedule_ml_model(context, ml_model, extra_spec) except exception.NoValidHost: new_ml_model.status = consts.ERROR new_ml_model.status_reason = _( @@ -51,13 +51,17 @@ class API(object): new_ml_model.status_reason = _("Unexpected exception occurred.") new_ml_model.save(context) raise - - self.rpcapi.ml_model_create(context, host_state['host'], + LOG.debug(host_state) + return self.rpcapi.ml_model_create(context, host_state['host'], new_ml_model) + + def ml_model_predict(self, context, ml_model_id, **kwargs): + return self.rpcapi.ml_model_predict(context, ml_model_id, + **kwargs) def ml_model_delete(self, context, ml_model, *args): self._record_action_start(context, ml_model, ml_model_actions.DELETE) return self.rpcapi.ml_model_delete(context, ml_model, *args) def ml_model_show(self, context, ml_model): - return self.rpcapi.ml_model_show(context, ml_model) \ No newline at end of file + return self.rpcapi.ml_model_show(context, ml_model) diff --git a/gyan/compute/manager.py b/gyan/compute/manager.py index dee77ed..146f70c 100644 --- a/gyan/compute/manager.py +++ b/gyan/compute/manager.py @@ -1,5 +1,3 @@ -# Copyright 2016 IBM Corp. -# # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at @@ -12,7 +10,9 @@ # License for the specific language governing permissions and limitations # under the License. +import base64 import itertools +import os import six import time @@ -49,17 +49,18 @@ class Manager(periodic_task.PeriodicTasks): self.host = CONF.compute.host self._resource_tracker = None - def ml_model_create(self, context, limits, requested_networks, - requested_volumes, ml_model, run, pci_requests=None): - @utils.synchronized(ml_model.uuid) - def do_ml_model_create(): - created_ml_model = self._do_ml_model_create( - context, ml_model, requested_networks, requested_volumes, - pci_requests, limits) - if run: - self._do_ml_model_start(context, created_ml_model) + def ml_model_create(self, context, ml_model): + db_ml_model = objects.ML_Model.get_by_uuid_db(context, ml_model["id"]) + utils.save_model(CONF.state_path, db_ml_model) + obj_ml_model = objects.ML_Model.get_by_uuid(context, ml_model["id"]) + obj_ml_model.status = consts.SCHEDULED + obj_ml_model.status_reason = "The ML Model is scheduled and saved to the host %s" % self.host + obj_ml_model.save(context) - utils.spawn_n(do_ml_model_create) + def ml_model_predict(self, context, ml_model_id, kwargs): + #open("/home/bharath/Documents/0.png", "wb").write(base64.b64decode(kwargs["data"])) + model_path = os.path.join(CONF.state_path, ml_model_id) + return self.driver.predict(context, model_path, base64.b64decode(kwargs["data"])) @wrap_ml_model_event(prefix='compute') def _do_ml_model_create(self, context, ml_model, requested_networks, @@ -118,4 +119,4 @@ class Manager(periodic_task.PeriodicTasks): rt = compute_host_tracker.ComputeHostTracker(self.host, self.driver) self._resource_tracker = rt - return self._resource_tracker \ No newline at end of file + return self._resource_tracker diff --git a/gyan/compute/rpcapi.py b/gyan/compute/rpcapi.py index f8824bb..6fb530b 100644 --- a/gyan/compute/rpcapi.py +++ b/gyan/compute/rpcapi.py @@ -1,5 +1,3 @@ -# Copyright 2016 IBM Corp. -# # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at @@ -30,7 +28,6 @@ def check_ml_model_host(func): return wrap -@profiler.trace_cls("rpc") class API(rpc_service.API): """Client side of the ml_model compute rpc API. @@ -51,6 +48,10 @@ class API(rpc_service.API): self._cast(host, 'ml_model_create', ml_model=ml_model) + def ml_model_predict(self, context, ml_model_id, **kwargs): + return self._call("localhost", 'ml_model_predict', + ml_model_id=ml_model_id, kwargs=kwargs) + @check_ml_model_host def ml_model_delete(self, context, ml_model, force): return self._cast(ml_model.host, 'ml_model_delete', diff --git a/gyan/conf/scheduler.py b/gyan/conf/scheduler.py index 554f614..946315c 100644 --- a/gyan/conf/scheduler.py +++ b/gyan/conf/scheduler.py @@ -1,6 +1,3 @@ -# Copyright 2015 OpenStack Foundation -# All Rights Reserved. -# # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at diff --git a/gyan/db/sqlalchemy/alembic/versions/f3bf9414f399_add_ml_type_and_ml_data_to_ml_model_.py b/gyan/db/sqlalchemy/alembic/versions/f3bf9414f399_add_ml_type_and_ml_data_to_ml_model_.py new file mode 100644 index 0000000..111e261 --- /dev/null +++ b/gyan/db/sqlalchemy/alembic/versions/f3bf9414f399_add_ml_type_and_ml_data_to_ml_model_.py @@ -0,0 +1,39 @@ +"""Add ml_type and ml_data to ml_model table + +Revision ID: f3bf9414f399 +Revises: cebd81b206ca +Create Date: 2018-10-13 09:48:36.783322 + +""" + +# revision identifiers, used by Alembic. +revision = 'f3bf9414f399' +down_revision = 'cebd81b206ca' +branch_labels = None +depends_on = None + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('compute_host', schema=None) as batch_op: + batch_op.alter_column('hostname', + existing_type=mysql.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('status', + existing_type=mysql.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('type', + existing_type=mysql.VARCHAR(length=255), + nullable=False) + + with op.batch_alter_table('ml_model', schema=None) as batch_op: + batch_op.add_column(sa.Column('ml_data', sa.LargeBinary(length=(2**32)-1), nullable=True)) + batch_op.add_column(sa.Column('ml_type', sa.String(length=255), nullable=True)) + batch_op.add_column(sa.Column('started_at', sa.DateTime(), nullable=True)) + batch_op.create_unique_constraint('uniq_mlmodel0uuid', ['id']) + batch_op.drop_constraint(u'ml_model_ibfk_1', type_='foreignkey') + + # ### end Alembic commands ### diff --git a/gyan/db/sqlalchemy/api.py b/gyan/db/sqlalchemy/api.py index 125d727..9f1cb38 100644 --- a/gyan/db/sqlalchemy/api.py +++ b/gyan/db/sqlalchemy/api.py @@ -1,5 +1,3 @@ -# Copyright 2013 Hewlett-Packard Development Company, L.P. -# # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at @@ -13,6 +11,7 @@ # under the License. """SQLAlchemy storage backend.""" +from oslo_log import log as logging from oslo_db import exception as db_exc from oslo_db.sqlalchemy import session as db_session @@ -39,6 +38,7 @@ profiler_sqlalchemy = importutils.try_import('osprofiler.sqlalchemy') CONF = gyan.conf.CONF _FACADE = None +LOG = logging.getLogger(__name__) def _create_facade_lazily(): @@ -90,7 +90,7 @@ def add_identity_filter(query, value): if strutils.is_int_like(value): return query.filter_by(id=value) elif uuidutils.is_uuid_like(value): - return query.filter_by(uuid=value) + return query.filter_by(id=value) else: raise exception.InvalidIdentity(identity=value) @@ -230,16 +230,17 @@ class Connection(object): def list_ml_models(self, context, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): - query = model_query(models.Capsule) + query = model_query(models.ML_Model) query = self._add_project_filters(context, query) query = self._add_ml_models_filters(query, filters) - return _paginate_query(models.Capsule, limit, marker, + LOG.debug(filters) + return _paginate_query(models.ML_Model, limit, marker, sort_key, sort_dir, query) def create_ml_model(self, context, values): # ensure defaults are present for new ml_models - if not values.get('uuid'): - values['uuid'] = uuidutils.generate_uuid() + if not values.get('id'): + values['id'] = uuidutils.generate_uuid() ml_model = models.ML_Model() ml_model.update(values) try: @@ -252,7 +253,7 @@ class Connection(object): def get_ml_model_by_uuid(self, context, ml_model_uuid): query = model_query(models.ML_Model) query = self._add_project_filters(context, query) - query = query.filter_by(uuid=ml_model_uuid) + query = query.filter_by(id=ml_model_uuid) try: return query.one() except NoResultFound: @@ -261,7 +262,7 @@ class Connection(object): def get_ml_model_by_name(self, context, ml_model_name): query = model_query(models.ML_Model) query = self._add_project_filters(context, query) - query = query.filter_by(meta_name=ml_model_name) + query = query.filter_by(name=ml_model_name) try: return query.one() except NoResultFound: diff --git a/gyan/db/sqlalchemy/models.py b/gyan/db/sqlalchemy/models.py index b148439..31c1362 100644 --- a/gyan/db/sqlalchemy/models.py +++ b/gyan/db/sqlalchemy/models.py @@ -31,6 +31,7 @@ from sqlalchemy import orm from sqlalchemy import schema from sqlalchemy import sql from sqlalchemy import String +from sqlalchemy import LargeBinary from sqlalchemy import Text from sqlalchemy.types import TypeDecorator, TEXT @@ -120,11 +121,12 @@ class ML_Model(Base): name = Column(String(255)) status = Column(String(20)) status_reason = Column(Text, nullable=True) - task_state = Column(String(20)) - host_id = Column(String(255)) - status_detail = Column(String(50)) - deployed = Column(String(50)) + host_id = Column(String(255), nullable=True) deployed = Column(Text, nullable=True) + url = Column(Text, nullable=True) + hints = Column(Text, nullable=True) + ml_type = Column(String(255), nullable=True) + ml_data = Column(LargeBinary(length=(2**32)-1), nullable=True) started_at = Column(DateTime) @@ -138,4 +140,4 @@ class ComputeHost(Base): id = Column(String(36), primary_key=True, nullable=False) hostname = Column(String(255), nullable=False) status = Column(String(255), nullable=False) - type = Column(String(255), nullable=False) \ No newline at end of file + type = Column(String(255), nullable=False) diff --git a/gyan/ml_model/tensorflow/driver.py b/gyan/ml_model/tensorflow/driver.py index 7b41a5a..114f5c9 100644 --- a/gyan/ml_model/tensorflow/driver.py +++ b/gyan/ml_model/tensorflow/driver.py @@ -15,8 +15,13 @@ import datetime import eventlet import functools import types +import png +import os +import tempfile +import numpy as np + +import tensorflow as tf -from docker import errors from oslo_log import log as logging from oslo_utils import timeutils from oslo_utils import uuidutils @@ -47,6 +52,24 @@ class TensorflowDriver(driver.MLModelDriver): return ml_model pass + def _load(self, session, path): + saver = tf.train.import_meta_graph(path + '/model.meta') + saver.restore(session, tf.train.latest_checkpoint(path)) + return tf.get_default_graph() + + def predict(self, context, ml_model_path, data): + session = tf.Session() + graph = self._load(session, ml_model_path) + img_file, img_path = tempfile.mkstemp() + with os.fdopen(img_file, 'wb') as f: + f.write(data) + png_data = png.Reader(img_path) + img = np.array(list(png_data.read()[2])) + img = img.reshape(1, 784) + tensor = graph.get_tensor_by_name('x:0') + prediction = graph.get_tensor_by_name('classification:0') + return {"data": session.run(prediction, feed_dict={tensor:img})[0]} + def delete(self, context, ml_model, force): pass diff --git a/gyan/objects/fields.py b/gyan/objects/fields.py index 07dc6da..59f454d 100644 --- a/gyan/objects/fields.py +++ b/gyan/objects/fields.py @@ -43,3 +43,18 @@ class Json(fields.FieldType): class JsonField(fields.AutoTypedField): AUTO_TYPE = Json() + + +class ModelFieldType(fields.FieldType): + def coerce(self, obj, attr, value): + return value + + def from_primitive(self, obj, attr, value): + return self.coerce(obj, attr, value) + + def to_primitive(self, obj, attr, value): + return value + + +class ModelField(fields.AutoTypedField): + AUTO_TYPE = ModelFieldType() \ No newline at end of file diff --git a/gyan/objects/ml_model.py b/gyan/objects/ml_model.py index 3420ba8..0f24501 100644 --- a/gyan/objects/ml_model.py +++ b/gyan/objects/ml_model.py @@ -22,6 +22,7 @@ from gyan.objects import fields as z_fields LOG = logging.getLogger(__name__) + @base.GyanObjectRegistry.register class ML_Model(base.GyanPersistentObject, base.GyanObject): VERSION = '1' @@ -35,16 +36,19 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject): 'status_reason': fields.StringField(nullable=True), 'url': fields.StringField(nullable=True), 'deployed': fields.BooleanField(nullable=True), - 'node': fields.UUIDField(nullable=True), 'hints': fields.StringField(nullable=True), 'created_at': fields.DateTimeField(tzinfo_aware=False, nullable=True), - 'updated_at': fields.DateTimeField(tzinfo_aware=False, nullable=True) + 'updated_at': fields.DateTimeField(tzinfo_aware=False, nullable=True), + 'ml_data': z_fields.ModelField(nullable=True), + 'ml_type': fields.StringField(nullable=True) } @staticmethod def _from_db_object(ml_model, db_ml_model): """Converts a database entity to a formal object.""" for field in ml_model.fields: + if 'field' == 'ml_data': + continue setattr(ml_model, field, db_ml_model[field]) ml_model.obj_reset_changes() @@ -67,6 +71,17 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject): db_ml_model = dbapi.get_ml_model_by_uuid(context, uuid) ml_model = ML_Model._from_db_object(cls(context), db_ml_model) return ml_model + + @base.remotable_classmethod + def get_by_uuid_db(cls, context, uuid): + """Find a ml model based on uuid and return a :class:`ML_Model` object. + + :param uuid: the uuid of a ml model. + :param context: Security context + :returns: a :class:`ML_Model` object. + """ + db_ml_model = dbapi.get_ml_model_by_uuid(context, uuid) + return db_ml_model @base.remotable_classmethod def get_by_name(cls, context, name): @@ -125,7 +140,7 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject): """ values = self.obj_get_changes() db_ml_model = dbapi.create_ml_model(context, values) - self._from_db_object(self, db_ml_model) + return self._from_db_object(self, db_ml_model) @base.remotable def destroy(self, context=None): @@ -138,7 +153,26 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject): A context should be set when instantiating the object, e.g.: ML Model(context) """ - dbapi.destroy_ml_model(context, self.uuid) + dbapi.destroy_ml_model(context, self.id) + self.obj_reset_changes() + + @base.remotable + def save(self, context=None): + """Save updates to this ML Model. + + Updates will be made column by column based on the result + of self.what_changed(). + + :param context: Security context. NOTE: This should only + be used internally by the indirection_api. + Unfortunately, RPC requires context as the first + argument, even though we don't use it. + A context should be set when instantiating the + object, e.g.: ML Model(context) + """ + updates = self.obj_get_changes() + dbapi.update_ml_model(context, self.id, updates) + self.obj_reset_changes() def obj_load_attr(self, attrname): diff --git a/gyan/tests/base.py b/gyan/tests/base.py index 1c30cdb..bc2d9c8 100644 --- a/gyan/tests/base.py +++ b/gyan/tests/base.py @@ -1,8 +1,3 @@ -# -*- coding: utf-8 -*- - -# Copyright 2010-2011 OpenStack Foundation -# Copyright (c) 2013 Hewlett-Packard Development Company, L.P. -# # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at diff --git a/requirements.txt b/requirements.txt index 1d18dd3..035de03 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,29 @@ # of appearance. Changing the order has an impact on the overall integration # process, which may cause wedges in the gate later. -pbr>=2.0 # Apache-2.0 +PyYAML>=3.12 # MIT +eventlet!=0.18.3,!=0.20.1,>=0.18.2 # MIT +keystonemiddleware>=4.17.0 # Apache-2.0 +pecan!=1.0.2,!=1.0.3,!=1.0.4,!=1.2,>=1.0.0 # BSD +oslo.i18n>=3.15.3 # Apache-2.0 +oslo.log>=3.36.0 # Apache-2.0 +oslo.concurrency>=3.25.0 # Apache-2.0 +oslo.config>=5.2.0 # Apache-2.0 +oslo.messaging>=5.29.0 # Apache-2.0 +oslo.middleware>=3.31.0 # Apache-2.0 +oslo.policy>=1.30.0 # Apache-2.0 +oslo.privsep>=1.23.0 # Apache-2.0 +oslo.serialization!=2.19.1,>=2.18.0 # Apache-2.0 +oslo.service!=1.28.1,>=1.24.0 # Apache-2.0 +oslo.versionedobjects>=1.31.2 # Apache-2.0 +oslo.context>=2.19.2 # Apache-2.0 +oslo.utils>=3.33.0 # Apache-2.0 +oslo.db>=4.27.0 # Apache-2.0 +os-brick>=2.2.0 # Apache-2.0 +six>=1.10.0 # MIT +SQLAlchemy!=1.1.5,!=1.1.6,!=1.1.7,!=1.1.8,>=1.0.10 # MIT +stevedore>=1.20.0 # Apache-2.0 +pypng +numpy +tensorflow +idx2numpy diff --git a/setup.py b/setup.py index 056c16c..98b93eb 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,3 @@ -# Copyright (c) 2013 Hewlett-Packard Development Company, L.P. -# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at