Merge "[db] Move serialize wrapper in the proper place"

This commit is contained in:
Jenkins 2017-07-15 05:34:11 +00:00 committed by Gerrit Code Review
commit c7fb777811
4 changed files with 143 additions and 119 deletions

View File

@ -40,14 +40,9 @@ these objects be simple dictionaries.
"""
import datetime as dt
from oslo_config import cfg
from oslo_db import api as db_api
from oslo_db import options as db_options
import six
from rally.common.i18n import _
CONF = cfg.CONF
@ -59,37 +54,6 @@ db_options.set_defaults(CONF, connection="sqlite:////tmp/rally.sqlite")
IMPL = None
def serialize_data(data):
if data is None:
return None
if isinstance(data, (six.integer_types,
six.string_types,
six.text_type,
dt.date,
dt.time,
float,
)):
return data
if isinstance(data, dict):
return {k: serialize_data(v) for k, v in data.items()}
if isinstance(data, (list, tuple)):
return [serialize_data(i) for i in data]
if hasattr(data, "_as_dict"):
result = data._as_dict()
for k, v in result.items():
result[k] = serialize_data(v)
return result
raise ValueError(_("Can not serialize %s") % data)
def serialize(fn):
def wrapper(*args, **kwargs):
result = fn(*args, **kwargs)
return serialize_data(result)
return wrapper
def get_impl():
global IMPL
@ -143,7 +107,6 @@ def task_get(uuid, detailed=False):
for subtask in task["subtasks"]:
for workload in subtask["workloads"]:
del workload["context_execution"]
del workload["_profiling_data"]
return task

View File

@ -30,11 +30,11 @@ from oslo_config import cfg
from oslo_db import exception as db_exc
from oslo_db.sqlalchemy import session as db_session
from oslo_utils import timeutils
import six
from sqlalchemy import or_
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm import load_only as sa_loadonly
from rally.common.db import api as db_api
from rally.common.db.sqlalchemy import models
from rally.common.i18n import _
from rally import consts
@ -50,6 +50,44 @@ _FACADE = None
INITIAL_REVISION_UUID = "ca3626f62937"
def serialize_data(data):
if data is None:
return None
if isinstance(data, (six.integer_types,
six.string_types,
six.text_type,
dt.date,
dt.time,
float,
)):
return data
if isinstance(data, dict):
return {k: serialize_data(v) for k, v in data.items()}
if isinstance(data, (list, tuple)):
return [serialize_data(i) for i in data]
if hasattr(data, "_as_dict"):
# NOTE(andreykurilin): it is an instance of the Model. It support a
# method `_as_dict`, which should transform an object into dict
# (quite logical as from the method name), BUT it does some extra
# work - tries to load properties which were marked to not be loaded
# in particular request and fails since the session object is not
# present. That is why the code bellow makes a custom transformation.
result = {}
for key in data.__dict__:
if not key.startswith("_"):
result[key] = serialize_data(getattr(data, key))
return result
raise ValueError(_("Can not serialize %s") % data)
def serialize(fn):
def wrapper(*args, **kwargs):
result = fn(*args, **kwargs)
return serialize_data(result)
return wrapper
def _create_facade_lazily():
global _FACADE
@ -239,10 +277,10 @@ class Connection(object):
for raw in workload_data.chunk_data["raw"]],
key=lambda x: x["timestamp"])
@db_api.serialize
@serialize
def task_get(self, uuid=None, detailed=False):
session = get_session()
task = db_api.serialize_data(self._task_get(uuid, session=session))
task = serialize_data(self._task_get(uuid, session=session))
if detailed:
task["subtasks"] = self._subtasks_get_all_by_task_uuid(
@ -250,11 +288,11 @@ class Connection(object):
return task
@db_api.serialize
@serialize
def task_get_status(self, uuid):
return self._task_get(uuid, load_only="status").status
@db_api.serialize
@serialize
def task_create(self, values):
tags = values.pop("tags", None)
# TODO(ikhudoshyn): currently 'input_task'
@ -279,7 +317,7 @@ class Connection(object):
task.tags = sorted(self._tags_get(task.uuid, consts.TagType.TASK))
return task
@db_api.serialize
@serialize
def task_update(self, uuid, values):
session = get_session()
values.pop("uuid", None)
@ -315,7 +353,7 @@ class Connection(object):
raise exceptions.RallyException(msg)
return result
@db_api.serialize
@serialize
def task_list(self, status=None, deployment=None, tags=None):
session = get_session()
tasks = []
@ -379,18 +417,18 @@ class Connection(object):
task_uuid=task_uuid).all())
subtasks = []
for subtask in result:
subtask = db_api.serialize_data(subtask)
subtask = serialize_data(subtask)
subtask["workloads"] = []
workloads = (self.model_query(models.Workload, session=session).
filter_by(subtask_uuid=subtask["uuid"]).all())
for workload in workloads:
workload.data = self._task_workload_data_get_all(
workload.uuid)
subtask["workloads"].append(db_api.serialize_data(workload))
subtask["workloads"].append(serialize_data(workload))
subtasks.append(subtask)
return subtasks
@db_api.serialize
@serialize
def subtask_create(self, task_uuid, title, description=None, context=None):
subtask = models.Subtask(task_uuid=task_uuid)
subtask.update({
@ -401,7 +439,7 @@ class Connection(object):
subtask.save()
return subtask
@db_api.serialize
@serialize
def subtask_update(self, subtask_uuid, values):
subtask = self.model_query(models.Subtask).filter_by(
uuid=subtask_uuid).first()
@ -409,12 +447,12 @@ class Connection(object):
subtask.save()
return subtask
@db_api.serialize
@serialize
def workload_get(self, workload_uuid):
return self.model_query(models.Workload).filter_by(
uuid=workload_uuid).first()
@db_api.serialize
@serialize
def workload_create(self, task_uuid, subtask_uuid, name, description,
position, runner, runner_type, hooks, context, sla,
args, context_execution, statistics):
@ -434,7 +472,7 @@ class Connection(object):
workload.save()
return workload
@db_api.serialize
@serialize
def workload_data_create(self, task_uuid, workload_uuid, chunk_order,
data):
workload_data = models.WorkloadData(task_uuid=task_uuid,
@ -483,7 +521,7 @@ class Connection(object):
workload_data.save()
return workload_data
@db_api.serialize
@serialize
def workload_set_results(self, workload_uuid, subtask_uuid, task_uuid,
load_duration, full_duration, start_time,
sla_results, hooks_results):
@ -584,7 +622,7 @@ class Connection(object):
raise exceptions.DeploymentNotFound(deployment=deployment)
return stored_deployment
@db_api.serialize
@serialize
def deployment_create(self, values):
deployment = models.Deployment()
try:
@ -607,11 +645,11 @@ class Connection(object):
if not count:
raise exceptions.DeploymentNotFound(deployment=uuid)
@db_api.serialize
@serialize
def deployment_get(self, deployment):
return self._deployment_get(deployment)
@db_api.serialize
@serialize
def deployment_update(self, deployment, values):
session = get_session()
values.pop("uuid", None)
@ -620,7 +658,7 @@ class Connection(object):
dpl.update(values)
return dpl
@db_api.serialize
@serialize
def deployment_list(self, status=None, parent_uuid=None, name=None):
query = (self.model_query(models.Deployment).
filter_by(parent_uuid=parent_uuid))
@ -631,14 +669,14 @@ class Connection(object):
query = query.filter_by(status=status)
return query.all()
@db_api.serialize
@serialize
def resource_create(self, values):
resource = models.Resource()
resource.update(values)
resource.save()
return resource
@db_api.serialize
@serialize
def resource_get_all(self, deployment_uuid, provider_name=None, type=None):
query = (self.model_query(models.Resource).
filter_by(deployment_uuid=deployment_uuid))
@ -654,7 +692,7 @@ class Connection(object):
if not count:
raise exceptions.ResourceNotFound(id=id)
@db_api.serialize
@serialize
def verifier_create(self, name, vtype, namespace, source, version,
system_wide, extra_settings=None):
verifier = models.Verifier()
@ -665,7 +703,7 @@ class Connection(object):
verifier.save()
return verifier
@db_api.serialize
@serialize
def verifier_get(self, verifier_id):
return self._verifier_get(verifier_id)
@ -678,7 +716,7 @@ class Connection(object):
raise exceptions.ResourceNotFound(id=verifier_id)
return verifier
@db_api.serialize
@serialize
def verifier_list(self, status=None):
query = self.model_query(models.Verifier)
if status:
@ -696,7 +734,7 @@ class Connection(object):
if not count:
raise exceptions.ResourceNotFound(id=verifier_id)
@db_api.serialize
@serialize
def verifier_update(self, verifier_id, properties):
session = get_session()
with session.begin():
@ -705,7 +743,7 @@ class Connection(object):
verifier.save()
return verifier
@db_api.serialize
@serialize
def verification_create(self, verifier_id, deployment_id, tags=None,
run_args=None):
verifier = self._verifier_get(verifier_id)
@ -726,7 +764,7 @@ class Connection(object):
return verification
@db_api.serialize
@serialize
def verification_get(self, verification_uuid):
verification = self._verification_get(verification_uuid)
verification.tags = sorted(self._tags_get(verification.uuid,
@ -741,7 +779,7 @@ class Connection(object):
raise exceptions.ResourceNotFound(id=verification_uuid)
return verification
@db_api.serialize
@serialize
def verification_list(self, verifier_id=None, deployment_id=None,
tags=None, status=None):
session = get_session()
@ -783,7 +821,7 @@ class Connection(object):
if not count:
raise exceptions.ResourceNotFound(id=verification_uuid)
@db_api.serialize
@serialize
def verification_update(self, verification_uuid, properties):
session = get_session()
with session.begin():
@ -792,7 +830,7 @@ class Connection(object):
verification.save()
return verification
@db_api.serialize
@serialize
def register_worker(self, values):
try:
worker = models.Worker()
@ -804,7 +842,7 @@ class Connection(object):
raise exceptions.WorkerAlreadyRegistered(
worker=values["hostname"])
@db_api.serialize
@serialize
def get_worker(self, hostname):
try:
return (self.model_query(models.Worker).

View File

@ -19,12 +19,10 @@ import collections
import copy
import datetime as dt
import ddt
import mock
from six import moves
from rally.common import db
from rally.common.db import api as db_api
from rally import consts
from rally import exceptions
from tests.unit import test
@ -32,53 +30,6 @@ from tests.unit import test
NOW = dt.datetime.now()
class FakeSerializable(object):
def __init__(self, **kwargs):
self.dict = {}
self.dict.update(kwargs)
def _as_dict(self):
return self.dict
@ddt.ddt
class SerializeTestCase(test.DBTestCase):
def setUp(self):
super(SerializeTestCase, self).setUp()
@ddt.data(
{"data": 1, "serialized": 1},
{"data": 1.1, "serialized": 1.1},
{"data": "a string", "serialized": "a string"},
{"data": NOW, "serialized": NOW},
{"data": {"k1": 1, "k2": 2}, "serialized": {"k1": 1, "k2": 2}},
{"data": [1, "foo"], "serialized": [1, "foo"]},
{"data": ["foo", 1, {"a": "b"}], "serialized": ["foo", 1, {"a": "b"}]},
{"data": FakeSerializable(a=1), "serialized": {"a": 1}},
{"data": [FakeSerializable(a=1),
FakeSerializable(b=FakeSerializable(c=1))],
"serialized": [{"a": 1}, {"b": {"c": 1}}]},
)
@ddt.unpack
def test_serialize(self, data, serialized):
@db_api.serialize
def fake_method():
return data
results = fake_method()
self.assertEqual(results, serialized)
def test_serialize_value_error(self):
@db_api.serialize
def fake_method():
class Fake(object):
pass
return Fake()
self.assertRaises(ValueError, fake_method)
class ConnectionTestCase(test.DBTestCase):
def test_schema_revision(self):
rev = db.schema_revision()
@ -500,12 +451,11 @@ class WorkloadTestCase(test.DBTestCase):
runner_type=w_runner_type)
workload.pop("uuid")
workload.pop("start_time")
workload.pop("created_at")
workload.pop("updated_at")
self.assertEqual(
{"_profiling_data": "", "context_execution": {},
{"context_execution": {},
"statistics": {},
"subtask_uuid": self.subtask_uuid,
"task_uuid": self.task_uuid,

View File

@ -0,0 +1,73 @@
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# NOTE(andreykurilin): most tests for sqlalchemy api is merged with db_api
# tests. Hope, it will be fixed someday.
import datetime as dt
import ddt
from rally.common.db.sqlalchemy import api as db_api
from tests.unit import test
NOW = dt.datetime.now()
class FakeSerializable(object):
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
def _as_dict(self):
return self.__dict__
@ddt.ddt
class SerializeTestCase(test.DBTestCase):
def setUp(self):
super(SerializeTestCase, self).setUp()
@ddt.data(
{"data": 1, "serialized": 1},
{"data": 1.1, "serialized": 1.1},
{"data": "a string", "serialized": "a string"},
{"data": NOW, "serialized": NOW},
{"data": {"k1": 1, "k2": 2}, "serialized": {"k1": 1, "k2": 2}},
{"data": [1, "foo"], "serialized": [1, "foo"]},
{"data": ["foo", 1, {"a": "b"}], "serialized": ["foo", 1, {"a": "b"}]},
{"data": FakeSerializable(a=1), "serialized": {"a": 1}},
{"data": [FakeSerializable(a=1),
FakeSerializable(b=FakeSerializable(c=1))],
"serialized": [{"a": 1}, {"b": {"c": 1}}]},
)
@ddt.unpack
def test_serialize(self, data, serialized):
@db_api.serialize
def fake_method():
return data
results = fake_method()
self.assertEqual(serialized, results)
def test_serialize_value_error(self):
@db_api.serialize
def fake_method():
class Fake(object):
pass
return Fake()
self.assertRaises(ValueError, fake_method)