Merge "[db] Move serialize wrapper in the proper place"
This commit is contained in:
commit
c7fb777811
|
@ -40,14 +40,9 @@ these objects be simple dictionaries.
|
|||
|
||||
"""
|
||||
|
||||
import datetime as dt
|
||||
|
||||
from oslo_config import cfg
|
||||
from oslo_db import api as db_api
|
||||
from oslo_db import options as db_options
|
||||
import six
|
||||
|
||||
from rally.common.i18n import _
|
||||
|
||||
|
||||
CONF = cfg.CONF
|
||||
|
@ -59,37 +54,6 @@ db_options.set_defaults(CONF, connection="sqlite:////tmp/rally.sqlite")
|
|||
IMPL = None
|
||||
|
||||
|
||||
def serialize_data(data):
|
||||
if data is None:
|
||||
return None
|
||||
if isinstance(data, (six.integer_types,
|
||||
six.string_types,
|
||||
six.text_type,
|
||||
dt.date,
|
||||
dt.time,
|
||||
float,
|
||||
)):
|
||||
return data
|
||||
if isinstance(data, dict):
|
||||
return {k: serialize_data(v) for k, v in data.items()}
|
||||
if isinstance(data, (list, tuple)):
|
||||
return [serialize_data(i) for i in data]
|
||||
if hasattr(data, "_as_dict"):
|
||||
result = data._as_dict()
|
||||
for k, v in result.items():
|
||||
result[k] = serialize_data(v)
|
||||
return result
|
||||
|
||||
raise ValueError(_("Can not serialize %s") % data)
|
||||
|
||||
|
||||
def serialize(fn):
|
||||
def wrapper(*args, **kwargs):
|
||||
result = fn(*args, **kwargs)
|
||||
return serialize_data(result)
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_impl():
|
||||
global IMPL
|
||||
|
||||
|
@ -143,7 +107,6 @@ def task_get(uuid, detailed=False):
|
|||
for subtask in task["subtasks"]:
|
||||
for workload in subtask["workloads"]:
|
||||
del workload["context_execution"]
|
||||
del workload["_profiling_data"]
|
||||
return task
|
||||
|
||||
|
||||
|
|
|
@ -30,11 +30,11 @@ from oslo_config import cfg
|
|||
from oslo_db import exception as db_exc
|
||||
from oslo_db.sqlalchemy import session as db_session
|
||||
from oslo_utils import timeutils
|
||||
import six
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
from sqlalchemy.orm import load_only as sa_loadonly
|
||||
|
||||
from rally.common.db import api as db_api
|
||||
from rally.common.db.sqlalchemy import models
|
||||
from rally.common.i18n import _
|
||||
from rally import consts
|
||||
|
@ -50,6 +50,44 @@ _FACADE = None
|
|||
INITIAL_REVISION_UUID = "ca3626f62937"
|
||||
|
||||
|
||||
def serialize_data(data):
|
||||
if data is None:
|
||||
return None
|
||||
if isinstance(data, (six.integer_types,
|
||||
six.string_types,
|
||||
six.text_type,
|
||||
dt.date,
|
||||
dt.time,
|
||||
float,
|
||||
)):
|
||||
return data
|
||||
if isinstance(data, dict):
|
||||
return {k: serialize_data(v) for k, v in data.items()}
|
||||
if isinstance(data, (list, tuple)):
|
||||
return [serialize_data(i) for i in data]
|
||||
if hasattr(data, "_as_dict"):
|
||||
# NOTE(andreykurilin): it is an instance of the Model. It support a
|
||||
# method `_as_dict`, which should transform an object into dict
|
||||
# (quite logical as from the method name), BUT it does some extra
|
||||
# work - tries to load properties which were marked to not be loaded
|
||||
# in particular request and fails since the session object is not
|
||||
# present. That is why the code bellow makes a custom transformation.
|
||||
result = {}
|
||||
for key in data.__dict__:
|
||||
if not key.startswith("_"):
|
||||
result[key] = serialize_data(getattr(data, key))
|
||||
return result
|
||||
|
||||
raise ValueError(_("Can not serialize %s") % data)
|
||||
|
||||
|
||||
def serialize(fn):
|
||||
def wrapper(*args, **kwargs):
|
||||
result = fn(*args, **kwargs)
|
||||
return serialize_data(result)
|
||||
return wrapper
|
||||
|
||||
|
||||
def _create_facade_lazily():
|
||||
global _FACADE
|
||||
|
||||
|
@ -239,10 +277,10 @@ class Connection(object):
|
|||
for raw in workload_data.chunk_data["raw"]],
|
||||
key=lambda x: x["timestamp"])
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def task_get(self, uuid=None, detailed=False):
|
||||
session = get_session()
|
||||
task = db_api.serialize_data(self._task_get(uuid, session=session))
|
||||
task = serialize_data(self._task_get(uuid, session=session))
|
||||
|
||||
if detailed:
|
||||
task["subtasks"] = self._subtasks_get_all_by_task_uuid(
|
||||
|
@ -250,11 +288,11 @@ class Connection(object):
|
|||
|
||||
return task
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def task_get_status(self, uuid):
|
||||
return self._task_get(uuid, load_only="status").status
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def task_create(self, values):
|
||||
tags = values.pop("tags", None)
|
||||
# TODO(ikhudoshyn): currently 'input_task'
|
||||
|
@ -279,7 +317,7 @@ class Connection(object):
|
|||
task.tags = sorted(self._tags_get(task.uuid, consts.TagType.TASK))
|
||||
return task
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def task_update(self, uuid, values):
|
||||
session = get_session()
|
||||
values.pop("uuid", None)
|
||||
|
@ -315,7 +353,7 @@ class Connection(object):
|
|||
raise exceptions.RallyException(msg)
|
||||
return result
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def task_list(self, status=None, deployment=None, tags=None):
|
||||
session = get_session()
|
||||
tasks = []
|
||||
|
@ -379,18 +417,18 @@ class Connection(object):
|
|||
task_uuid=task_uuid).all())
|
||||
subtasks = []
|
||||
for subtask in result:
|
||||
subtask = db_api.serialize_data(subtask)
|
||||
subtask = serialize_data(subtask)
|
||||
subtask["workloads"] = []
|
||||
workloads = (self.model_query(models.Workload, session=session).
|
||||
filter_by(subtask_uuid=subtask["uuid"]).all())
|
||||
for workload in workloads:
|
||||
workload.data = self._task_workload_data_get_all(
|
||||
workload.uuid)
|
||||
subtask["workloads"].append(db_api.serialize_data(workload))
|
||||
subtask["workloads"].append(serialize_data(workload))
|
||||
subtasks.append(subtask)
|
||||
return subtasks
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def subtask_create(self, task_uuid, title, description=None, context=None):
|
||||
subtask = models.Subtask(task_uuid=task_uuid)
|
||||
subtask.update({
|
||||
|
@ -401,7 +439,7 @@ class Connection(object):
|
|||
subtask.save()
|
||||
return subtask
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def subtask_update(self, subtask_uuid, values):
|
||||
subtask = self.model_query(models.Subtask).filter_by(
|
||||
uuid=subtask_uuid).first()
|
||||
|
@ -409,12 +447,12 @@ class Connection(object):
|
|||
subtask.save()
|
||||
return subtask
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def workload_get(self, workload_uuid):
|
||||
return self.model_query(models.Workload).filter_by(
|
||||
uuid=workload_uuid).first()
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def workload_create(self, task_uuid, subtask_uuid, name, description,
|
||||
position, runner, runner_type, hooks, context, sla,
|
||||
args, context_execution, statistics):
|
||||
|
@ -434,7 +472,7 @@ class Connection(object):
|
|||
workload.save()
|
||||
return workload
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def workload_data_create(self, task_uuid, workload_uuid, chunk_order,
|
||||
data):
|
||||
workload_data = models.WorkloadData(task_uuid=task_uuid,
|
||||
|
@ -483,7 +521,7 @@ class Connection(object):
|
|||
workload_data.save()
|
||||
return workload_data
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def workload_set_results(self, workload_uuid, subtask_uuid, task_uuid,
|
||||
load_duration, full_duration, start_time,
|
||||
sla_results, hooks_results):
|
||||
|
@ -584,7 +622,7 @@ class Connection(object):
|
|||
raise exceptions.DeploymentNotFound(deployment=deployment)
|
||||
return stored_deployment
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def deployment_create(self, values):
|
||||
deployment = models.Deployment()
|
||||
try:
|
||||
|
@ -607,11 +645,11 @@ class Connection(object):
|
|||
if not count:
|
||||
raise exceptions.DeploymentNotFound(deployment=uuid)
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def deployment_get(self, deployment):
|
||||
return self._deployment_get(deployment)
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def deployment_update(self, deployment, values):
|
||||
session = get_session()
|
||||
values.pop("uuid", None)
|
||||
|
@ -620,7 +658,7 @@ class Connection(object):
|
|||
dpl.update(values)
|
||||
return dpl
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def deployment_list(self, status=None, parent_uuid=None, name=None):
|
||||
query = (self.model_query(models.Deployment).
|
||||
filter_by(parent_uuid=parent_uuid))
|
||||
|
@ -631,14 +669,14 @@ class Connection(object):
|
|||
query = query.filter_by(status=status)
|
||||
return query.all()
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def resource_create(self, values):
|
||||
resource = models.Resource()
|
||||
resource.update(values)
|
||||
resource.save()
|
||||
return resource
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def resource_get_all(self, deployment_uuid, provider_name=None, type=None):
|
||||
query = (self.model_query(models.Resource).
|
||||
filter_by(deployment_uuid=deployment_uuid))
|
||||
|
@ -654,7 +692,7 @@ class Connection(object):
|
|||
if not count:
|
||||
raise exceptions.ResourceNotFound(id=id)
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def verifier_create(self, name, vtype, namespace, source, version,
|
||||
system_wide, extra_settings=None):
|
||||
verifier = models.Verifier()
|
||||
|
@ -665,7 +703,7 @@ class Connection(object):
|
|||
verifier.save()
|
||||
return verifier
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def verifier_get(self, verifier_id):
|
||||
return self._verifier_get(verifier_id)
|
||||
|
||||
|
@ -678,7 +716,7 @@ class Connection(object):
|
|||
raise exceptions.ResourceNotFound(id=verifier_id)
|
||||
return verifier
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def verifier_list(self, status=None):
|
||||
query = self.model_query(models.Verifier)
|
||||
if status:
|
||||
|
@ -696,7 +734,7 @@ class Connection(object):
|
|||
if not count:
|
||||
raise exceptions.ResourceNotFound(id=verifier_id)
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def verifier_update(self, verifier_id, properties):
|
||||
session = get_session()
|
||||
with session.begin():
|
||||
|
@ -705,7 +743,7 @@ class Connection(object):
|
|||
verifier.save()
|
||||
return verifier
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def verification_create(self, verifier_id, deployment_id, tags=None,
|
||||
run_args=None):
|
||||
verifier = self._verifier_get(verifier_id)
|
||||
|
@ -726,7 +764,7 @@ class Connection(object):
|
|||
|
||||
return verification
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def verification_get(self, verification_uuid):
|
||||
verification = self._verification_get(verification_uuid)
|
||||
verification.tags = sorted(self._tags_get(verification.uuid,
|
||||
|
@ -741,7 +779,7 @@ class Connection(object):
|
|||
raise exceptions.ResourceNotFound(id=verification_uuid)
|
||||
return verification
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def verification_list(self, verifier_id=None, deployment_id=None,
|
||||
tags=None, status=None):
|
||||
session = get_session()
|
||||
|
@ -783,7 +821,7 @@ class Connection(object):
|
|||
if not count:
|
||||
raise exceptions.ResourceNotFound(id=verification_uuid)
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def verification_update(self, verification_uuid, properties):
|
||||
session = get_session()
|
||||
with session.begin():
|
||||
|
@ -792,7 +830,7 @@ class Connection(object):
|
|||
verification.save()
|
||||
return verification
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def register_worker(self, values):
|
||||
try:
|
||||
worker = models.Worker()
|
||||
|
@ -804,7 +842,7 @@ class Connection(object):
|
|||
raise exceptions.WorkerAlreadyRegistered(
|
||||
worker=values["hostname"])
|
||||
|
||||
@db_api.serialize
|
||||
@serialize
|
||||
def get_worker(self, hostname):
|
||||
try:
|
||||
return (self.model_query(models.Worker).
|
||||
|
|
|
@ -19,12 +19,10 @@ import collections
|
|||
import copy
|
||||
import datetime as dt
|
||||
|
||||
import ddt
|
||||
import mock
|
||||
from six import moves
|
||||
|
||||
from rally.common import db
|
||||
from rally.common.db import api as db_api
|
||||
from rally import consts
|
||||
from rally import exceptions
|
||||
from tests.unit import test
|
||||
|
@ -32,53 +30,6 @@ from tests.unit import test
|
|||
NOW = dt.datetime.now()
|
||||
|
||||
|
||||
class FakeSerializable(object):
|
||||
def __init__(self, **kwargs):
|
||||
self.dict = {}
|
||||
self.dict.update(kwargs)
|
||||
|
||||
def _as_dict(self):
|
||||
return self.dict
|
||||
|
||||
|
||||
@ddt.ddt
|
||||
class SerializeTestCase(test.DBTestCase):
|
||||
def setUp(self):
|
||||
super(SerializeTestCase, self).setUp()
|
||||
|
||||
@ddt.data(
|
||||
{"data": 1, "serialized": 1},
|
||||
{"data": 1.1, "serialized": 1.1},
|
||||
{"data": "a string", "serialized": "a string"},
|
||||
{"data": NOW, "serialized": NOW},
|
||||
{"data": {"k1": 1, "k2": 2}, "serialized": {"k1": 1, "k2": 2}},
|
||||
{"data": [1, "foo"], "serialized": [1, "foo"]},
|
||||
{"data": ["foo", 1, {"a": "b"}], "serialized": ["foo", 1, {"a": "b"}]},
|
||||
{"data": FakeSerializable(a=1), "serialized": {"a": 1}},
|
||||
{"data": [FakeSerializable(a=1),
|
||||
FakeSerializable(b=FakeSerializable(c=1))],
|
||||
"serialized": [{"a": 1}, {"b": {"c": 1}}]},
|
||||
)
|
||||
@ddt.unpack
|
||||
def test_serialize(self, data, serialized):
|
||||
@db_api.serialize
|
||||
def fake_method():
|
||||
return data
|
||||
|
||||
results = fake_method()
|
||||
self.assertEqual(results, serialized)
|
||||
|
||||
def test_serialize_value_error(self):
|
||||
@db_api.serialize
|
||||
def fake_method():
|
||||
class Fake(object):
|
||||
pass
|
||||
|
||||
return Fake()
|
||||
|
||||
self.assertRaises(ValueError, fake_method)
|
||||
|
||||
|
||||
class ConnectionTestCase(test.DBTestCase):
|
||||
def test_schema_revision(self):
|
||||
rev = db.schema_revision()
|
||||
|
@ -500,12 +451,11 @@ class WorkloadTestCase(test.DBTestCase):
|
|||
runner_type=w_runner_type)
|
||||
|
||||
workload.pop("uuid")
|
||||
workload.pop("start_time")
|
||||
workload.pop("created_at")
|
||||
workload.pop("updated_at")
|
||||
|
||||
self.assertEqual(
|
||||
{"_profiling_data": "", "context_execution": {},
|
||||
{"context_execution": {},
|
||||
"statistics": {},
|
||||
"subtask_uuid": self.subtask_uuid,
|
||||
"task_uuid": self.task_uuid,
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
# All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# NOTE(andreykurilin): most tests for sqlalchemy api is merged with db_api
|
||||
# tests. Hope, it will be fixed someday.
|
||||
|
||||
import datetime as dt
|
||||
|
||||
import ddt
|
||||
|
||||
from rally.common.db.sqlalchemy import api as db_api
|
||||
from tests.unit import test
|
||||
|
||||
|
||||
NOW = dt.datetime.now()
|
||||
|
||||
|
||||
class FakeSerializable(object):
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
def _as_dict(self):
|
||||
return self.__dict__
|
||||
|
||||
|
||||
@ddt.ddt
|
||||
class SerializeTestCase(test.DBTestCase):
|
||||
def setUp(self):
|
||||
super(SerializeTestCase, self).setUp()
|
||||
|
||||
@ddt.data(
|
||||
{"data": 1, "serialized": 1},
|
||||
{"data": 1.1, "serialized": 1.1},
|
||||
{"data": "a string", "serialized": "a string"},
|
||||
{"data": NOW, "serialized": NOW},
|
||||
{"data": {"k1": 1, "k2": 2}, "serialized": {"k1": 1, "k2": 2}},
|
||||
{"data": [1, "foo"], "serialized": [1, "foo"]},
|
||||
{"data": ["foo", 1, {"a": "b"}], "serialized": ["foo", 1, {"a": "b"}]},
|
||||
{"data": FakeSerializable(a=1), "serialized": {"a": 1}},
|
||||
{"data": [FakeSerializable(a=1),
|
||||
FakeSerializable(b=FakeSerializable(c=1))],
|
||||
"serialized": [{"a": 1}, {"b": {"c": 1}}]},
|
||||
)
|
||||
@ddt.unpack
|
||||
def test_serialize(self, data, serialized):
|
||||
@db_api.serialize
|
||||
def fake_method():
|
||||
return data
|
||||
|
||||
results = fake_method()
|
||||
self.assertEqual(serialized, results)
|
||||
|
||||
def test_serialize_value_error(self):
|
||||
@db_api.serialize
|
||||
def fake_method():
|
||||
class Fake(object):
|
||||
pass
|
||||
|
||||
return Fake()
|
||||
|
||||
self.assertRaises(ValueError, fake_method)
|
Loading…
Reference in New Issue