diff --git a/heat/db/sqlalchemy/models.py b/heat/db/sqlalchemy/models.py index 290e1e8987..bb1ce84368 100644 --- a/heat/db/sqlalchemy/models.py +++ b/heat/db/sqlalchemy/models.py @@ -19,46 +19,18 @@ import uuid import sqlalchemy -from sqlalchemy.dialects import mysql from sqlalchemy.orm import relationship, backref from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy import types -from json import dumps -from json import loads from heat.openstack.common import timeutils from heat.openstack.common.db.sqlalchemy import models from heat.openstack.common.db.sqlalchemy import session from sqlalchemy.orm.session import Session +from heat.db.sqlalchemy.types import Json BASE = declarative_base() get_session = session.get_session -class Json(types.TypeDecorator): - impl = types.Text - - def load_dialect_impl(self, dialect): - if dialect.name == 'mysql': - return dialect.type_descriptor(mysql.LONGTEXT()) - else: - return self.impl - - def process_bind_param(self, value, dialect): - return dumps(value) - - def process_result_value(self, value, dialect): - return loads(value) - -# TODO(leizhang) When we removed sqlalchemy 0.7 dependence -# we can import MutableDict directly and remove ./mutable.py -try: - from sqlalchemy.ext.mutable import MutableDict as sa_MutableDict - sa_MutableDict.associate_with(Json) -except ImportError: - from heat.db.sqlalchemy.mutable import MutableDict - MutableDict.associate_with(Json) - - class HeatBase(models.ModelBase, models.TimestampMixin): """Base class for Heat Models.""" __table_args__ = {'mysql_engine': 'InnoDB'} diff --git a/heat/db/sqlalchemy/types.py b/heat/db/sqlalchemy/types.py new file mode 100644 index 0000000000..86a8cebc73 --- /dev/null +++ b/heat/db/sqlalchemy/types.py @@ -0,0 +1,50 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from json import dumps +from json import loads +from sqlalchemy import types +from sqlalchemy.dialects import mysql + + +class LongText(types.TypeDecorator): + impl = types.Text + + def load_dialect_impl(self, dialect): + if dialect.name == 'mysql': + return dialect.type_descriptor(mysql.LONGTEXT()) + else: + return self.impl + + +class Json(LongText): + + def process_bind_param(self, value, dialect): + return dumps(value) + + def process_result_value(self, value, dialect): + return loads(value) + + +def associate_with(sqltype): + # TODO(leizhang) When we removed sqlalchemy 0.7 dependence + # we can import MutableDict directly and remove ./mutable.py + try: + from sqlalchemy.ext.mutable import MutableDict as sa_MutableDict + sa_MutableDict.associate_with(Json) + except ImportError: + from heat.db.sqlalchemy.mutable import MutableDict + MutableDict.associate_with(Json) + +associate_with(LongText) +associate_with(Json) diff --git a/heat/tests/test_sqlalchemy_types.py b/heat/tests/test_sqlalchemy_types.py new file mode 100644 index 0000000000..52e8ef6fa9 --- /dev/null +++ b/heat/tests/test_sqlalchemy_types.py @@ -0,0 +1,54 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import testtools + +from heat.db.sqlalchemy.types import LongText +from heat.db.sqlalchemy.types import Json +from sqlalchemy import types +from sqlalchemy.dialects.mysql.base import MySQLDialect +from sqlalchemy.dialects.sqlite.base import SQLiteDialect + + +class LongTextTest(testtools.TestCase): + + def setUp(self): + super(LongTextTest, self).setUp() + self.sqltype = LongText() + + def test_load_dialect_impl(self): + dialect = MySQLDialect() + impl = self.sqltype.load_dialect_impl(dialect) + self.assertNotEqual(types.Text, type(impl)) + dialect = SQLiteDialect() + impl = self.sqltype.load_dialect_impl(dialect) + self.assertEqual(types.Text, type(impl)) + + +class JsonTest(testtools.TestCase): + + def setUp(self): + super(JsonTest, self).setUp() + self.sqltype = Json() + + def test_process_bind_param(self): + dialect = None + value = {'foo': 'bar'} + result = self.sqltype.process_bind_param(value, dialect) + self.assertEqual('{"foo": "bar"}', result) + + def test_process_result_value(self): + dialect = None + value = '{"foo": "bar"}' + result = self.sqltype.process_result_value(value, dialect) + self.assertEqual({'foo': 'bar'}, result)