Use pytest fixtures to reduce complexity and repetition
Also: Allow override of database name and user in tests (important for me as I would have to mess with my PSQL and MySQL database users otherwise) Use dict.items instead of six.iteritems as it sporadically caused RuntimeError: dictionary changed size during iteration in Python 2.6 tests. Fix typo DNS to DSN Adds Python 3.5 to tox.ini Added an .editorconfig Import babel.dates in sqlalchemy_utils.i18n as an exception would be raised when using the latest versions of babel.
This commit is contained in:
parent
5bdd4d3efb
commit
815f07d6c1
|
@ -0,0 +1,14 @@
|
|||
# EditorConfig helps developers define and maintain consistent
|
||||
# coding styles between different editors and IDEs
|
||||
# editorconfig.org
|
||||
|
||||
root = true
|
||||
|
||||
|
||||
[*]
|
||||
indent_style = space
|
||||
end_of_line = lf
|
||||
charset = utf-8
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
indent_size = 4
|
|
@ -15,6 +15,8 @@ var
|
|||
sdist
|
||||
develop-eggs
|
||||
.installed.cfg
|
||||
.cache
|
||||
.eggs
|
||||
lib
|
||||
lib64
|
||||
docs/_build
|
||||
|
@ -42,3 +44,6 @@ nosetests.xml
|
|||
Session.vim
|
||||
.netrwhist
|
||||
*~
|
||||
|
||||
# Sublime Text
|
||||
*.sublime-*
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[settings]
|
||||
known_first_party=sqlalchemy_utils,tests
|
||||
known_first_party=sqlalchemy_utils
|
||||
line_length=79
|
||||
multi_line_output=3
|
||||
not_skip=__init__.py
|
||||
|
|
40
.travis.yml
40
.travis.yml
|
@ -1,3 +1,6 @@
|
|||
sudo: false
|
||||
language: python
|
||||
|
||||
addons:
|
||||
postgresql: "9.4"
|
||||
|
||||
|
@ -6,22 +9,29 @@ before_script:
|
|||
- psql -c 'create extension hstore;' -U postgres -d sqlalchemy_utils_test
|
||||
- mysql -e 'create database sqlalchemy_utils_test;'
|
||||
|
||||
language: python
|
||||
python:
|
||||
- 2.6
|
||||
- 2.7
|
||||
- 3.3
|
||||
- 3.4
|
||||
- 3.5
|
||||
|
||||
env:
|
||||
- EXTRAS=test
|
||||
- EXTRAS=test_all
|
||||
matrix:
|
||||
include:
|
||||
- python: 2.6
|
||||
env:
|
||||
- "TOXENV=py26"
|
||||
- python: 2.7
|
||||
env:
|
||||
- "TOXENV=py27"
|
||||
- python: 3.3
|
||||
env:
|
||||
- "TOXENV=py33"
|
||||
- python: 3.4
|
||||
env:
|
||||
- "TOXENV=py34"
|
||||
- python: 3.5
|
||||
env:
|
||||
- "TOXENV=py35"
|
||||
- python: 3.5
|
||||
env:
|
||||
- "TOXENV=lint"
|
||||
|
||||
install:
|
||||
- pip install -e .[$EXTRAS]
|
||||
- pip install tox
|
||||
|
||||
script:
|
||||
- isort --recursive --diff sqlalchemy_utils tests && isort --recursive --check-only sqlalchemy_utils tests
|
||||
- flake8 sqlalchemy_utils tests
|
||||
- py.test
|
||||
- tox
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
include CHANGES.rst LICENSE README.rst
|
||||
include CHANGES.rst LICENSE README.rst conftest.py .isort.cfg
|
||||
recursive-include tests *
|
||||
recursive-exclude tests *.pyc
|
||||
recursive-include docs *
|
||||
|
|
|
@ -0,0 +1,198 @@
|
|||
import os
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base, synonym_for
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy_utils import (
|
||||
aggregates,
|
||||
coercion_listener,
|
||||
i18n,
|
||||
InstrumentedList
|
||||
)
|
||||
|
||||
from sqlalchemy_utils.types.pg_composite import remove_composite_listeners
|
||||
|
||||
|
||||
@sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute')
|
||||
def count_sql_calls(conn, cursor, statement, parameters, context, executemany):
|
||||
try:
|
||||
conn.query_count += 1
|
||||
except AttributeError:
|
||||
conn.query_count = 0
|
||||
|
||||
|
||||
warnings.simplefilter('error', sa.exc.SAWarning)
|
||||
|
||||
sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener)
|
||||
|
||||
|
||||
def get_locale():
|
||||
class Locale():
|
||||
territories = {'FI': 'Finland'}
|
||||
|
||||
return Locale()
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def db_name():
|
||||
return os.environ.get('SQLALCHEMY_UTILS_TEST_DB', 'sqlalchemy_utils_test')
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def postgresql_db_user():
|
||||
return os.environ.get('SQLALCHEMY_UTILS_TEST_POSTGRESQL_USER', 'postgres')
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def mysql_db_user():
|
||||
return os.environ.get('SQLALCHEMY_UTILS_TEST_MYSQL_USER', 'root')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def postgresql_dsn(postgresql_db_user, db_name):
|
||||
return 'postgres://{0}@localhost/{1}'.format(postgresql_db_user, db_name)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mysql_dsn(mysql_db_user, db_name):
|
||||
return 'mysql+pymysql://{0}@localhost/{1}'.format(mysql_db_user, db_name)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sqlite_memory_dsn():
|
||||
return 'sqlite:///:memory:'
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sqlite_file_dsn():
|
||||
return 'sqlite:///{0}.db'.format(db_name)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dsn(request):
|
||||
if 'postgresql_dsn' in request.fixturenames:
|
||||
return request.getfuncargvalue('postgresql_dsn')
|
||||
elif 'mysql_dsn' in request.fixturenames:
|
||||
return request.getfuncargvalue('mysql_dsn')
|
||||
elif 'sqlite_file_dsn' in request.fixturenames:
|
||||
return request.getfuncargvalue('sqlite_file_dsn')
|
||||
elif 'sqlite_memory_dsn' in request.fixturenames:
|
||||
pass # Return default
|
||||
return request.getfuncargvalue('sqlite_memory_dsn')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine(dsn):
|
||||
engine = create_engine(dsn)
|
||||
# engine.echo = True
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection(engine):
|
||||
return engine.connect()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Base():
|
||||
return declarative_base()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def User(Base):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
return User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Category(Base):
|
||||
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
title = sa.Column(sa.Unicode(255))
|
||||
|
||||
@hybrid_property
|
||||
def full_name(self):
|
||||
return u'%s %s' % (self.title, self.name)
|
||||
|
||||
@full_name.expression
|
||||
def full_name(self):
|
||||
return sa.func.concat(self.title, ' ', self.name)
|
||||
|
||||
@hybrid_property
|
||||
def articles_count(self):
|
||||
return len(self.articles)
|
||||
|
||||
@articles_count.expression
|
||||
def articles_count(cls):
|
||||
Article = Base._decl_class_registry['Article']
|
||||
return (
|
||||
sa.select([sa.func.count(Article.id)])
|
||||
.where(Article.category_id == cls.id)
|
||||
.correlate(Article.__table__)
|
||||
.label('article_count')
|
||||
)
|
||||
|
||||
@property
|
||||
def name_alias(self):
|
||||
return self.name
|
||||
|
||||
@synonym_for('name')
|
||||
@property
|
||||
def name_synonym(self):
|
||||
return self.name
|
||||
return Category
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Article(Base, Category):
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255), index=True)
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
|
||||
|
||||
category = sa.orm.relationship(
|
||||
Category,
|
||||
primaryjoin=category_id == Category.id,
|
||||
backref=sa.orm.backref(
|
||||
'articles',
|
||||
collection_class=InstrumentedList
|
||||
)
|
||||
)
|
||||
return Article
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(User, Category, Article):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(request, engine, connection, Base, init_models):
|
||||
sa.orm.configure_mappers()
|
||||
Base.metadata.create_all(connection)
|
||||
Session = sessionmaker(bind=connection)
|
||||
session = Session()
|
||||
i18n.get_locale = get_locale
|
||||
|
||||
def teardown():
|
||||
aggregates.manager.reset()
|
||||
session.close_all()
|
||||
Base.metadata.drop_all(connection)
|
||||
remove_composite_listeners()
|
||||
connection.close()
|
||||
engine.dispose()
|
||||
|
||||
request.addfinalizer(teardown)
|
||||
|
||||
return session
|
|
@ -11,6 +11,8 @@ SQLAlchemy-Utils has been tested against the following Python platforms.
|
|||
- cPython 2.6
|
||||
- cPython 2.7
|
||||
- cPython 3.3
|
||||
- cPython 3.4
|
||||
- cPython 3.5
|
||||
|
||||
|
||||
Installing an official release
|
||||
|
|
2
setup.py
2
setup.py
|
@ -89,11 +89,11 @@ setup(
|
|||
'Operating System :: OS Independent',
|
||||
'Programming Language :: Python',
|
||||
'Programming Language :: Python :: 2',
|
||||
'Programming Language :: Python :: 2.6',
|
||||
'Programming Language :: Python :: 2.7',
|
||||
'Programming Language :: Python :: 3',
|
||||
'Programming Language :: Python :: 3.3',
|
||||
'Programming Language :: Python :: 3.4',
|
||||
'Programming Language :: Python :: 3.5',
|
||||
'Topic :: Internet :: WWW/HTTP :: Dynamic Content',
|
||||
'Topic :: Software Development :: Libraries :: Python Modules'
|
||||
]
|
||||
|
|
|
@ -365,7 +365,6 @@ TODO
|
|||
from collections import defaultdict
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
import six
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.declarative import declared_attr
|
||||
from sqlalchemy.sql.functions import _FunctionGenerator
|
||||
|
@ -519,7 +518,7 @@ class AggregationManager(object):
|
|||
)
|
||||
|
||||
def update_generator_registry(self):
|
||||
for class_, attrs in six.iteritems(aggregated_attrs):
|
||||
for class_, attrs in aggregated_attrs.items():
|
||||
for expr, path, column in attrs:
|
||||
value = AggregatedValue(
|
||||
class_=class_,
|
||||
|
@ -539,7 +538,7 @@ class AggregationManager(object):
|
|||
if class_ in self.generator_registry:
|
||||
object_dict[class_].append(obj)
|
||||
|
||||
for class_, objects in six.iteritems(object_dict):
|
||||
for class_, objects in object_dict.items():
|
||||
for aggregate_value in self.generator_registry[class_]:
|
||||
query = aggregate_value.update_query(objects)
|
||||
if query is not None:
|
||||
|
|
|
@ -10,7 +10,7 @@ from sqlalchemy.sql.expression import (
|
|||
)
|
||||
from sqlalchemy.sql.functions import GenericFunction
|
||||
|
||||
from sqlalchemy_utils.functions.orm import quote
|
||||
from .functions.orm import quote
|
||||
|
||||
|
||||
class explain(Executable, ClauseElement):
|
||||
|
|
|
@ -7,8 +7,7 @@ import sqlalchemy as sa
|
|||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.exc import OperationalError, ProgrammingError
|
||||
|
||||
from sqlalchemy_utils.expressions import explain_analyze
|
||||
|
||||
from ..expressions import explain_analyze
|
||||
from ..utils import starts_with
|
||||
from .orm import quote
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from collections import defaultdict
|
||||
from itertools import groupby
|
||||
|
||||
import six
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.exc import NoInspectionAvailable
|
||||
from sqlalchemy.orm import object_session
|
||||
|
@ -167,7 +166,7 @@ def merge_references(from_, to, foreign_keys=None):
|
|||
new_values = get_foreign_key_values(fk, to)
|
||||
criteria = (
|
||||
getattr(fk.constraint.table.c, key) == value
|
||||
for key, value in six.iteritems(old_values)
|
||||
for key, value in old_values.items()
|
||||
)
|
||||
try:
|
||||
mapper = get_mapper(fk.constraint.table)
|
||||
|
|
|
@ -19,7 +19,7 @@ from sqlalchemy.orm.query import _ColumnEntity
|
|||
from sqlalchemy.orm.session import object_session
|
||||
from sqlalchemy.orm.util import AliasedInsp
|
||||
|
||||
from sqlalchemy_utils.utils import is_sequence
|
||||
from ..utils import is_sequence
|
||||
|
||||
|
||||
def get_class_by_table(base, table, data=None):
|
||||
|
|
|
@ -8,9 +8,8 @@ from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
|
|||
from sqlalchemy.orm.session import _state_session
|
||||
from sqlalchemy.util import set_creation_order
|
||||
|
||||
from sqlalchemy_utils.functions import identity
|
||||
|
||||
from .exceptions import ImproperlyConfigured
|
||||
from .functions import identity
|
||||
|
||||
|
||||
class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
||||
|
|
|
@ -8,6 +8,7 @@ from .exceptions import ImproperlyConfigured
|
|||
|
||||
try:
|
||||
import babel
|
||||
import babel.dates
|
||||
except ImportError:
|
||||
babel = None
|
||||
|
||||
|
|
|
@ -154,9 +154,9 @@ from collections import defaultdict, Iterable, namedtuple
|
|||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.functions import getdotattr, has_changes
|
||||
from sqlalchemy_utils.path import AttrPath
|
||||
from sqlalchemy_utils.utils import is_sequence
|
||||
from .functions import getdotattr, has_changes
|
||||
from .path import AttrPath
|
||||
from .utils import is_sequence
|
||||
|
||||
Callback = namedtuple('Callback', ['func', 'path', 'backref', 'fullpath'])
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import six
|
||||
|
||||
from sqlalchemy_utils import i18n
|
||||
from sqlalchemy_utils.utils import str_coercible
|
||||
from .. import i18n
|
||||
from ..utils import str_coercible
|
||||
|
||||
|
||||
@str_coercible
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import six
|
||||
|
||||
from sqlalchemy_utils import i18n, ImproperlyConfigured
|
||||
from sqlalchemy_utils.utils import str_coercible
|
||||
from .. import i18n, ImproperlyConfigured
|
||||
from ..utils import str_coercible
|
||||
|
||||
|
||||
@str_coercible
|
||||
|
|
|
@ -4,8 +4,8 @@ try:
|
|||
except ImportError:
|
||||
# Python 2.6 port
|
||||
from total_ordering import total_ordering
|
||||
from sqlalchemy_utils import i18n
|
||||
from sqlalchemy_utils.utils import str_coercible
|
||||
from .. import i18n
|
||||
from ..utils import str_coercible
|
||||
|
||||
|
||||
@str_coercible
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import six
|
||||
|
||||
from sqlalchemy_utils.utils import str_coercible
|
||||
|
||||
from ..utils import str_coercible
|
||||
from .weekday import WeekDay
|
||||
|
||||
|
||||
|
|
|
@ -6,8 +6,7 @@ from datetime import datetime
|
|||
import six
|
||||
from sqlalchemy import types
|
||||
|
||||
from sqlalchemy_utils.exceptions import ImproperlyConfigured
|
||||
|
||||
from ..exceptions import ImproperlyConfigured
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
arrow = None
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
import six
|
||||
from sqlalchemy import types
|
||||
|
||||
from sqlalchemy_utils.exceptions import ImproperlyConfigured
|
||||
|
||||
from ..exceptions import ImproperlyConfigured
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
colour = None
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
import six
|
||||
from sqlalchemy import types
|
||||
|
||||
from sqlalchemy_utils.primitives import Country
|
||||
|
||||
from ..primitives import Country
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
import six
|
||||
from sqlalchemy import types
|
||||
|
||||
from sqlalchemy_utils import i18n, ImproperlyConfigured
|
||||
from sqlalchemy_utils.primitives import Currency
|
||||
|
||||
from .. import i18n, ImproperlyConfigured
|
||||
from ..primitives import Currency
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
|
||||
|
|
|
@ -5,8 +5,7 @@ import datetime
|
|||
import six
|
||||
from sqlalchemy.types import Binary, String, TypeDecorator
|
||||
|
||||
from sqlalchemy_utils.exceptions import ImproperlyConfigured
|
||||
|
||||
from ..exceptions import ImproperlyConfigured
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
cryptography = None
|
||||
|
@ -84,7 +83,7 @@ class AesEngine(EncryptionDecryptionBaseEngine):
|
|||
value = str(value)
|
||||
decryptor = self.cipher.decryptor()
|
||||
decrypted = base64.b64decode(value)
|
||||
decrypted = decryptor.update(decrypted)+decryptor.finalize()
|
||||
decrypted = decryptor.update(decrypted) + decryptor.finalize()
|
||||
decrypted = decrypted.rstrip(self.PADDING)
|
||||
if not isinstance(decrypted, six.string_types):
|
||||
decrypted = decrypted.decode('utf-8')
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
import six
|
||||
from sqlalchemy import types
|
||||
|
||||
from sqlalchemy_utils.exceptions import ImproperlyConfigured
|
||||
|
||||
from ..exceptions import ImproperlyConfigured
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
ip_address = None
|
||||
|
|
|
@ -5,8 +5,7 @@ from sqlalchemy import types
|
|||
from sqlalchemy.dialects import oracle, postgresql
|
||||
from sqlalchemy.ext.mutable import Mutable
|
||||
|
||||
from sqlalchemy_utils.exceptions import ImproperlyConfigured
|
||||
|
||||
from ..exceptions import ImproperlyConfigured
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
passlib = None
|
||||
|
|
|
@ -109,7 +109,7 @@ from sqlalchemy.types import (
|
|||
UserDefinedType
|
||||
)
|
||||
|
||||
from sqlalchemy_utils import ImproperlyConfigured
|
||||
from .. import ImproperlyConfigured
|
||||
|
||||
psycopg2 = None
|
||||
CompositeCaster = None
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from sqlalchemy import types
|
||||
|
||||
from sqlalchemy_utils.exceptions import ImproperlyConfigured
|
||||
from sqlalchemy_utils.utils import str_coercible
|
||||
|
||||
from ..exceptions import ImproperlyConfigured
|
||||
from ..utils import str_coercible
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
try:
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
import six
|
||||
from sqlalchemy import types
|
||||
|
||||
from sqlalchemy_utils.exceptions import ImproperlyConfigured
|
||||
|
||||
from ..exceptions import ImproperlyConfigured
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
import six
|
||||
from sqlalchemy import types
|
||||
|
||||
from sqlalchemy_utils import i18n
|
||||
from sqlalchemy_utils.exceptions import ImproperlyConfigured
|
||||
from sqlalchemy_utils.primitives import WeekDay, WeekDays
|
||||
|
||||
from .. import i18n
|
||||
from ..exceptions import ImproperlyConfigured
|
||||
from ..primitives import WeekDay, WeekDays
|
||||
from .bit import BitType
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
|
|
|
@ -1,132 +1,3 @@
|
|||
import warnings
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base, synonym_for
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from sqlalchemy_utils import (
|
||||
aggregates,
|
||||
coercion_listener,
|
||||
i18n,
|
||||
InstrumentedList
|
||||
)
|
||||
from sqlalchemy_utils.types.pg_composite import remove_composite_listeners
|
||||
|
||||
|
||||
@sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute')
|
||||
def count_sql_calls(conn, cursor, statement, parameters, context, executemany):
|
||||
try:
|
||||
conn.query_count += 1
|
||||
except AttributeError:
|
||||
conn.query_count = 0
|
||||
|
||||
|
||||
warnings.simplefilter('error', sa.exc.SAWarning)
|
||||
|
||||
|
||||
sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener)
|
||||
|
||||
|
||||
def get_locale():
|
||||
class Locale():
|
||||
territories = {'FI': 'Finland'}
|
||||
|
||||
return Locale()
|
||||
|
||||
|
||||
class TestCase(object):
|
||||
dns = 'sqlite:///:memory:'
|
||||
create_tables = True
|
||||
|
||||
def setup_method(self, method):
|
||||
self.engine = create_engine(self.dns)
|
||||
# self.engine.echo = True
|
||||
self.connection = self.engine.connect()
|
||||
self.Base = declarative_base()
|
||||
|
||||
self.create_models()
|
||||
sa.orm.configure_mappers()
|
||||
if self.create_tables:
|
||||
self.Base.metadata.create_all(self.connection)
|
||||
|
||||
Session = sessionmaker(bind=self.connection)
|
||||
self.session = Session()
|
||||
|
||||
i18n.get_locale = get_locale
|
||||
|
||||
def teardown_method(self, method):
|
||||
aggregates.manager.reset()
|
||||
self.session.close_all()
|
||||
if self.create_tables:
|
||||
self.Base.metadata.drop_all(self.connection)
|
||||
remove_composite_listeners()
|
||||
self.connection.close()
|
||||
self.engine.dispose()
|
||||
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
title = sa.Column(sa.Unicode(255))
|
||||
|
||||
@hybrid_property
|
||||
def full_name(self):
|
||||
return u'%s %s' % (self.title, self.name)
|
||||
|
||||
@full_name.expression
|
||||
def full_name(self):
|
||||
return sa.func.concat(self.title, ' ', self.name)
|
||||
|
||||
@hybrid_property
|
||||
def articles_count(self):
|
||||
return len(self.articles)
|
||||
|
||||
@articles_count.expression
|
||||
def articles_count(cls):
|
||||
return (
|
||||
sa.select([sa.func.count(self.Article.id)])
|
||||
.where(self.Article.category_id == self.Category.id)
|
||||
.correlate(self.Article.__table__)
|
||||
.label('article_count')
|
||||
)
|
||||
|
||||
@property
|
||||
def name_alias(self):
|
||||
return self.name
|
||||
|
||||
@synonym_for('name')
|
||||
@property
|
||||
def name_synonym(self):
|
||||
return self.name
|
||||
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255), index=True)
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
|
||||
|
||||
category = sa.orm.relationship(
|
||||
Category,
|
||||
primaryjoin=category_id == Category.id,
|
||||
backref=sa.orm.backref(
|
||||
'articles',
|
||||
collection_class=InstrumentedList
|
||||
)
|
||||
)
|
||||
|
||||
self.User = User
|
||||
self.Category = Category
|
||||
self.Article = Article
|
||||
|
||||
|
||||
def assert_contains(clause, query):
|
||||
# Test that query executes
|
||||
query.all()
|
||||
|
|
|
@ -1,61 +1,76 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregateValueGenerationWithBackrefs(TestCase):
|
||||
def create_models(self):
|
||||
class Thread(self.Base):
|
||||
__tablename__ = 'thread'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
@pytest.fixture
|
||||
def Thread(Base):
|
||||
class Thread(Base):
|
||||
__tablename__ = 'thread'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated('comments', sa.Column(sa.Integer, default=0))
|
||||
def comment_count(self):
|
||||
return sa.func.count('1')
|
||||
@aggregated('comments', sa.Column(sa.Integer, default=0))
|
||||
def comment_count(self):
|
||||
return sa.func.count('1')
|
||||
return Thread
|
||||
|
||||
class Comment(self.Base):
|
||||
__tablename__ = 'comment'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
content = sa.Column(sa.Unicode(255))
|
||||
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
|
||||
|
||||
thread = sa.orm.relationship(Thread, backref='comments')
|
||||
@pytest.fixture
|
||||
def Comment(Base, Thread):
|
||||
class Comment(Base):
|
||||
__tablename__ = 'comment'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
content = sa.Column(sa.Unicode(255))
|
||||
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
|
||||
|
||||
self.Thread = Thread
|
||||
self.Comment = Comment
|
||||
thread = sa.orm.relationship(Thread, backref='comments')
|
||||
return Comment
|
||||
|
||||
def test_assigns_aggregates_on_insert(self):
|
||||
thread = self.Thread()
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(Thread, Comment):
|
||||
pass
|
||||
|
||||
|
||||
class TestAggregateValueGenerationWithBackrefs(object):
|
||||
|
||||
def test_assigns_aggregates_on_insert(self, session, Thread, Comment):
|
||||
thread = Thread()
|
||||
thread.name = u'some article name'
|
||||
self.session.add(thread)
|
||||
comment = self.Comment(content=u'Some content', thread=thread)
|
||||
self.session.add(comment)
|
||||
self.session.commit()
|
||||
self.session.refresh(thread)
|
||||
session.add(thread)
|
||||
comment = Comment(content=u'Some content', thread=thread)
|
||||
session.add(comment)
|
||||
session.commit()
|
||||
session.refresh(thread)
|
||||
assert thread.comment_count == 1
|
||||
|
||||
def test_assigns_aggregates_on_separate_insert(self):
|
||||
thread = self.Thread()
|
||||
def test_assigns_aggregates_on_separate_insert(
|
||||
self,
|
||||
session,
|
||||
Thread,
|
||||
Comment
|
||||
):
|
||||
thread = Thread()
|
||||
thread.name = u'some article name'
|
||||
self.session.add(thread)
|
||||
self.session.commit()
|
||||
comment = self.Comment(content=u'Some content', thread=thread)
|
||||
self.session.add(comment)
|
||||
self.session.commit()
|
||||
self.session.refresh(thread)
|
||||
session.add(thread)
|
||||
session.commit()
|
||||
comment = Comment(content=u'Some content', thread=thread)
|
||||
session.add(comment)
|
||||
session.commit()
|
||||
session.refresh(thread)
|
||||
assert thread.comment_count == 1
|
||||
|
||||
def test_assigns_aggregates_on_delete(self):
|
||||
thread = self.Thread()
|
||||
def test_assigns_aggregates_on_delete(self, session, Thread, Comment):
|
||||
thread = Thread()
|
||||
thread.name = u'some article name'
|
||||
self.session.add(thread)
|
||||
self.session.commit()
|
||||
comment = self.Comment(content=u'Some content', thread=thread)
|
||||
self.session.add(comment)
|
||||
self.session.commit()
|
||||
self.session.delete(comment)
|
||||
self.session.commit()
|
||||
self.session.refresh(thread)
|
||||
session.add(thread)
|
||||
session.commit()
|
||||
comment = Comment(content=u'Some content', thread=thread)
|
||||
session.add(comment)
|
||||
session.commit()
|
||||
session.delete(comment)
|
||||
session.commit()
|
||||
session.refresh(thread)
|
||||
assert thread.comment_count == 0
|
||||
|
|
|
@ -1,67 +1,76 @@
|
|||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestLazyEvaluatedSelectExpressionsForAggregates(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.fixture
|
||||
def Product(Base):
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
return Product
|
||||
|
||||
@aggregated('products', sa.Column(sa.Numeric, default=0))
|
||||
def net_worth(self):
|
||||
return sa.func.sum(Product.price)
|
||||
|
||||
products = sa.orm.relationship('Product', backref='catalog')
|
||||
@pytest.fixture
|
||||
def Catalog(Base, Product):
|
||||
class Catalog(Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
@aggregated('products', sa.Column(sa.Numeric, default=0))
|
||||
def net_worth(self):
|
||||
return sa.func.sum(Product.price)
|
||||
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
products = sa.orm.relationship('Product', backref='catalog')
|
||||
return Catalog
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Product = Product
|
||||
|
||||
def test_assigns_aggregates_on_insert(self):
|
||||
catalog = self.Catalog(
|
||||
@pytest.fixture
|
||||
def init_models(Product, Catalog):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestLazyEvaluatedSelectExpressionsForAggregates(object):
|
||||
|
||||
def test_assigns_aggregates_on_insert(self, session, Product, Catalog):
|
||||
catalog = Catalog(
|
||||
name=u'Some catalog'
|
||||
)
|
||||
self.session.add(catalog)
|
||||
self.session.commit()
|
||||
product = self.Product(
|
||||
session.add(catalog)
|
||||
session.commit()
|
||||
product = Product(
|
||||
name=u'Some product',
|
||||
price=Decimal('1000'),
|
||||
catalog=catalog
|
||||
)
|
||||
self.session.add(product)
|
||||
self.session.commit()
|
||||
self.session.refresh(catalog)
|
||||
session.add(product)
|
||||
session.commit()
|
||||
session.refresh(catalog)
|
||||
assert catalog.net_worth == Decimal('1000')
|
||||
|
||||
def test_assigns_aggregates_on_update(self):
|
||||
catalog = self.Catalog(
|
||||
def test_assigns_aggregates_on_update(self, session, Product, Catalog):
|
||||
catalog = Catalog(
|
||||
name=u'Some catalog'
|
||||
)
|
||||
self.session.add(catalog)
|
||||
self.session.commit()
|
||||
product = self.Product(
|
||||
session.add(catalog)
|
||||
session.commit()
|
||||
product = Product(
|
||||
name=u'Some product',
|
||||
price=Decimal('1000'),
|
||||
catalog=catalog
|
||||
)
|
||||
self.session.add(product)
|
||||
self.session.commit()
|
||||
session.add(product)
|
||||
session.commit()
|
||||
product.price = Decimal('500')
|
||||
self.session.commit()
|
||||
self.session.refresh(catalog)
|
||||
session.commit()
|
||||
session.refresh(catalog)
|
||||
assert catalog.net_worth == Decimal('500')
|
||||
|
|
|
@ -1,101 +1,121 @@
|
|||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestLazyEvaluatedSelectExpressionsForAggregates(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.fixture
|
||||
def Product(Base):
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
type = sa.Column(sa.Unicode(255))
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
return Product
|
||||
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': type
|
||||
}
|
||||
|
||||
@aggregated('products', sa.Column(sa.Numeric, default=0))
|
||||
def net_worth(self):
|
||||
return sa.func.sum(Product.price)
|
||||
@pytest.fixture
|
||||
def Catalog(Base, Product):
|
||||
class Catalog(Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
type = sa.Column(sa.Unicode(255))
|
||||
|
||||
products = sa.orm.relationship('Product', backref='catalog')
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': type
|
||||
}
|
||||
|
||||
class CostumeCatalog(Catalog):
|
||||
__tablename__ = 'costume_catalog'
|
||||
id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True
|
||||
)
|
||||
@aggregated('products', sa.Column(sa.Numeric, default=0))
|
||||
def net_worth(self):
|
||||
return sa.func.sum(Product.price)
|
||||
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'costumes',
|
||||
}
|
||||
products = sa.orm.relationship('Product', backref='catalog')
|
||||
return Catalog
|
||||
|
||||
class CarCatalog(Catalog):
|
||||
__tablename__ = 'car_catalog'
|
||||
id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True
|
||||
)
|
||||
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'cars',
|
||||
}
|
||||
@pytest.fixture
|
||||
def CostumeCatalog(Catalog):
|
||||
class CostumeCatalog(Catalog):
|
||||
__tablename__ = 'costume_catalog'
|
||||
id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True
|
||||
)
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'costumes',
|
||||
}
|
||||
return CostumeCatalog
|
||||
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.CostumeCatalog = CostumeCatalog
|
||||
self.CarCatalog = CarCatalog
|
||||
self.Product = Product
|
||||
@pytest.fixture
|
||||
def CarCatalog(Catalog):
|
||||
class CarCatalog(Catalog):
|
||||
__tablename__ = 'car_catalog'
|
||||
id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True
|
||||
)
|
||||
|
||||
def test_columns_inherited_from_parent(self):
|
||||
assert self.CarCatalog.net_worth
|
||||
assert self.CostumeCatalog.net_worth
|
||||
assert self.Catalog.net_worth
|
||||
assert not hasattr(self.CarCatalog.__table__.c, 'net_worth')
|
||||
assert not hasattr(self.CostumeCatalog.__table__.c, 'net_worth')
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'cars',
|
||||
}
|
||||
return CarCatalog
|
||||
|
||||
def test_assigns_aggregates_on_insert(self):
|
||||
catalog = self.Catalog(
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(Product, Catalog, CostumeCatalog, CarCatalog):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestLazyEvaluatedSelectExpressionsForAggregates(object):
|
||||
|
||||
def test_columns_inherited_from_parent(
|
||||
self,
|
||||
Catalog,
|
||||
CarCatalog,
|
||||
CostumeCatalog
|
||||
):
|
||||
assert CarCatalog.net_worth
|
||||
assert CostumeCatalog.net_worth
|
||||
assert Catalog.net_worth
|
||||
assert not hasattr(CarCatalog.__table__.c, 'net_worth')
|
||||
assert not hasattr(CostumeCatalog.__table__.c, 'net_worth')
|
||||
|
||||
def test_assigns_aggregates_on_insert(self, session, Product, Catalog):
|
||||
catalog = Catalog(
|
||||
name=u'Some catalog'
|
||||
)
|
||||
self.session.add(catalog)
|
||||
self.session.commit()
|
||||
product = self.Product(
|
||||
session.add(catalog)
|
||||
session.commit()
|
||||
product = Product(
|
||||
name=u'Some product',
|
||||
price=Decimal('1000'),
|
||||
catalog=catalog
|
||||
)
|
||||
self.session.add(product)
|
||||
self.session.commit()
|
||||
self.session.refresh(catalog)
|
||||
session.add(product)
|
||||
session.commit()
|
||||
session.refresh(catalog)
|
||||
assert catalog.net_worth == Decimal('1000')
|
||||
|
||||
def test_assigns_aggregates_on_update(self):
|
||||
catalog = self.Catalog(
|
||||
def test_assigns_aggregates_on_update(self, session, Catalog, Product):
|
||||
catalog = Catalog(
|
||||
name=u'Some catalog'
|
||||
)
|
||||
self.session.add(catalog)
|
||||
self.session.commit()
|
||||
product = self.Product(
|
||||
session.add(catalog)
|
||||
session.commit()
|
||||
product = Product(
|
||||
name=u'Some product',
|
||||
price=Decimal('1000'),
|
||||
catalog=catalog
|
||||
)
|
||||
self.session.add(product)
|
||||
self.session.commit()
|
||||
session.add(product)
|
||||
session.commit()
|
||||
product.price = Decimal('500')
|
||||
self.session.commit()
|
||||
self.session.refresh(catalog)
|
||||
session.commit()
|
||||
session.refresh(catalog)
|
||||
assert catalog.net_worth == Decimal('500')
|
||||
|
|
|
@ -1,72 +1,81 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregatesWithManyToManyRelationships(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.fixture
|
||||
def User(Base):
|
||||
user_group = sa.Table(
|
||||
'user_group',
|
||||
Base.metadata,
|
||||
sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')),
|
||||
sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id'))
|
||||
)
|
||||
|
||||
def create_models(self):
|
||||
user_group = sa.Table(
|
||||
'user_group',
|
||||
self.Base.metadata,
|
||||
sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')),
|
||||
sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id'))
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated('groups', sa.Column(sa.Integer, default=0))
|
||||
def group_count(self):
|
||||
return sa.func.count('1')
|
||||
|
||||
groups = sa.orm.relationship(
|
||||
'Group',
|
||||
backref='users',
|
||||
secondary=user_group
|
||||
)
|
||||
return User
|
||||
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated('groups', sa.Column(sa.Integer, default=0))
|
||||
def group_count(self):
|
||||
return sa.func.count('1')
|
||||
@pytest.fixture
|
||||
def Group(Base):
|
||||
class Group(Base):
|
||||
__tablename__ = 'group'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
return Group
|
||||
|
||||
groups = sa.orm.relationship(
|
||||
'Group',
|
||||
backref='users',
|
||||
secondary=user_group
|
||||
)
|
||||
|
||||
class Group(self.Base):
|
||||
__tablename__ = 'group'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
@pytest.fixture
|
||||
def init_models(User, Group):
|
||||
pass
|
||||
|
||||
self.User = User
|
||||
self.Group = Group
|
||||
|
||||
def test_assigns_aggregates_on_insert(self):
|
||||
user = self.User(
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestAggregatesWithManyToManyRelationships(object):
|
||||
|
||||
def test_assigns_aggregates_on_insert(self, session, User, Group):
|
||||
user = User(
|
||||
name=u'John Matrix'
|
||||
)
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
group = self.Group(
|
||||
session.add(user)
|
||||
session.commit()
|
||||
group = Group(
|
||||
name=u'Some group',
|
||||
users=[user]
|
||||
)
|
||||
self.session.add(group)
|
||||
self.session.commit()
|
||||
self.session.refresh(user)
|
||||
session.add(group)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
assert user.group_count == 1
|
||||
|
||||
def test_updates_aggregates_on_delete(self):
|
||||
user = self.User(
|
||||
def test_updates_aggregates_on_delete(self, session, User, Group):
|
||||
user = User(
|
||||
name=u'John Matrix'
|
||||
)
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
group = self.Group(
|
||||
session.add(user)
|
||||
session.commit()
|
||||
group = Group(
|
||||
name=u'Some group',
|
||||
users=[user]
|
||||
)
|
||||
self.session.add(group)
|
||||
self.session.commit()
|
||||
self.session.refresh(user)
|
||||
session.add(group)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user.groups = []
|
||||
self.session.commit()
|
||||
self.session.refresh(user)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
assert user.group_count == 0
|
||||
|
|
|
@ -1,80 +1,92 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregateManyToManyAndManyToMany(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.fixture
|
||||
def Category(Base):
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
return Category
|
||||
|
||||
def create_models(self):
|
||||
catalog_products = sa.Table(
|
||||
'catalog_product',
|
||||
self.Base.metadata,
|
||||
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
|
||||
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
|
||||
|
||||
@pytest.fixture
|
||||
def Catalog(Base, Category):
|
||||
class Catalog(Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated(
|
||||
'products.categories',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def category_count(self):
|
||||
return sa.func.count(sa.distinct(Category.id))
|
||||
return Catalog
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Product(Base, Catalog, Category):
|
||||
catalog_products = sa.Table(
|
||||
'catalog_product',
|
||||
Base.metadata,
|
||||
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
|
||||
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
|
||||
)
|
||||
|
||||
product_categories = sa.Table(
|
||||
'category_product',
|
||||
Base.metadata,
|
||||
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
|
||||
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
|
||||
)
|
||||
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
catalog_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey('catalog.id')
|
||||
)
|
||||
|
||||
product_categories = sa.Table(
|
||||
'category_product',
|
||||
self.Base.metadata,
|
||||
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
|
||||
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
|
||||
catalogs = sa.orm.relationship(
|
||||
Catalog,
|
||||
backref='products',
|
||||
secondary=catalog_products
|
||||
)
|
||||
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
categories = sa.orm.relationship(
|
||||
Category,
|
||||
backref='products',
|
||||
secondary=product_categories
|
||||
)
|
||||
return Product
|
||||
|
||||
@aggregated(
|
||||
'products.categories',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def category_count(self):
|
||||
return sa.func.count(sa.distinct(Category.id))
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
@pytest.fixture
|
||||
def init_models(Category, Catalog, Product):
|
||||
pass
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
catalog_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey('catalog.id')
|
||||
)
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestAggregateManyToManyAndManyToMany(object):
|
||||
|
||||
catalogs = sa.orm.relationship(
|
||||
Catalog,
|
||||
backref='products',
|
||||
secondary=catalog_products
|
||||
)
|
||||
|
||||
categories = sa.orm.relationship(
|
||||
Category,
|
||||
backref='products',
|
||||
secondary=product_categories
|
||||
)
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.Product = Product
|
||||
|
||||
def test_insert(self):
|
||||
category = self.Category()
|
||||
def test_insert(self, session, Product, Category, Catalog):
|
||||
category = Category()
|
||||
products = [
|
||||
self.Product(categories=[category]),
|
||||
self.Product(categories=[category])
|
||||
Product(categories=[category]),
|
||||
Product(categories=[category])
|
||||
]
|
||||
catalog = self.Catalog(products=products)
|
||||
self.session.add(catalog)
|
||||
catalog2 = self.Catalog(products=products)
|
||||
self.session.add(catalog)
|
||||
self.session.commit()
|
||||
catalog = Catalog(products=products)
|
||||
session.add(catalog)
|
||||
catalog2 = Catalog(products=products)
|
||||
session.add(catalog)
|
||||
session.commit()
|
||||
assert catalog.category_count == 1
|
||||
assert catalog2.category_count == 1
|
||||
|
|
|
@ -1,81 +1,96 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
|
||||
def create_models(self):
|
||||
class Thread(self.Base):
|
||||
__tablename__ = 'thread'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
@pytest.fixture
|
||||
def Comment(Base):
|
||||
class Comment(Base):
|
||||
__tablename__ = 'comment'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
content = sa.Column(sa.Unicode(255))
|
||||
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
|
||||
return Comment
|
||||
|
||||
@aggregated(
|
||||
'comments',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def comment_count(self):
|
||||
return sa.func.count('1')
|
||||
|
||||
@aggregated('comments', sa.Column(sa.Integer))
|
||||
def last_comment_id(self):
|
||||
return sa.func.max(Comment.id)
|
||||
@pytest.fixture
|
||||
def Thread(Base, Comment):
|
||||
class Thread(Base):
|
||||
__tablename__ = 'thread'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
comments = sa.orm.relationship(
|
||||
'Comment',
|
||||
backref='thread'
|
||||
)
|
||||
@aggregated(
|
||||
'comments',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def comment_count(self):
|
||||
return sa.func.count('1')
|
||||
|
||||
Thread.last_comment = sa.orm.relationship(
|
||||
@aggregated('comments', sa.Column(sa.Integer))
|
||||
def last_comment_id(self):
|
||||
return sa.func.max(Comment.id)
|
||||
|
||||
comments = sa.orm.relationship(
|
||||
'Comment',
|
||||
primaryjoin='Thread.last_comment_id == Comment.id',
|
||||
foreign_keys=[Thread.last_comment_id],
|
||||
viewonly=True
|
||||
backref='thread'
|
||||
)
|
||||
|
||||
class Comment(self.Base):
|
||||
__tablename__ = 'comment'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
content = sa.Column(sa.Unicode(255))
|
||||
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
|
||||
Thread.last_comment = sa.orm.relationship(
|
||||
'Comment',
|
||||
primaryjoin='Thread.last_comment_id == Comment.id',
|
||||
foreign_keys=[Thread.last_comment_id],
|
||||
viewonly=True
|
||||
)
|
||||
return Thread
|
||||
|
||||
self.Thread = Thread
|
||||
self.Comment = Comment
|
||||
|
||||
def test_assigns_aggregates_on_insert(self):
|
||||
thread = self.Thread()
|
||||
@pytest.fixture
|
||||
def init_models(Comment, Thread):
|
||||
pass
|
||||
|
||||
|
||||
class TestAggregateValueGenerationForSimpleModelPaths(object):
|
||||
|
||||
def test_assigns_aggregates_on_insert(self, session, Thread, Comment):
|
||||
thread = Thread()
|
||||
thread.name = u'some article name'
|
||||
self.session.add(thread)
|
||||
comment = self.Comment(content=u'Some content', thread=thread)
|
||||
self.session.add(comment)
|
||||
self.session.commit()
|
||||
self.session.refresh(thread)
|
||||
session.add(thread)
|
||||
comment = Comment(content=u'Some content', thread=thread)
|
||||
session.add(comment)
|
||||
session.commit()
|
||||
session.refresh(thread)
|
||||
assert thread.comment_count == 1
|
||||
assert thread.last_comment_id == comment.id
|
||||
|
||||
def test_assigns_aggregates_on_separate_insert(self):
|
||||
thread = self.Thread()
|
||||
def test_assigns_aggregates_on_separate_insert(
|
||||
self,
|
||||
session,
|
||||
Thread,
|
||||
Comment
|
||||
):
|
||||
thread = Thread()
|
||||
thread.name = u'some article name'
|
||||
self.session.add(thread)
|
||||
self.session.commit()
|
||||
comment = self.Comment(content=u'Some content', thread=thread)
|
||||
self.session.add(comment)
|
||||
self.session.commit()
|
||||
self.session.refresh(thread)
|
||||
session.add(thread)
|
||||
session.commit()
|
||||
comment = Comment(content=u'Some content', thread=thread)
|
||||
session.add(comment)
|
||||
session.commit()
|
||||
session.refresh(thread)
|
||||
assert thread.comment_count == 1
|
||||
assert thread.last_comment_id == 1
|
||||
|
||||
def test_assigns_aggregates_on_delete(self):
|
||||
thread = self.Thread()
|
||||
def test_assigns_aggregates_on_delete(self, session, Thread, Comment):
|
||||
thread = Thread()
|
||||
thread.name = u'some article name'
|
||||
self.session.add(thread)
|
||||
self.session.commit()
|
||||
comment = self.Comment(content=u'Some content', thread=thread)
|
||||
self.session.add(comment)
|
||||
self.session.commit()
|
||||
self.session.delete(comment)
|
||||
self.session.commit()
|
||||
self.session.refresh(thread)
|
||||
session.add(thread)
|
||||
session.commit()
|
||||
comment = Comment(content=u'Some content', thread=thread)
|
||||
session.add(comment)
|
||||
session.commit()
|
||||
session.delete(comment)
|
||||
session.commit()
|
||||
session.refresh(thread)
|
||||
assert thread.comment_count == 0
|
||||
assert thread.last_comment_id is None
|
||||
|
|
|
@ -1,76 +1,88 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregateOneToManyAndManyToMany(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.fixture
|
||||
def Category(Base):
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
return Category
|
||||
|
||||
def create_models(self):
|
||||
product_categories = sa.Table(
|
||||
'category_product',
|
||||
self.Base.metadata,
|
||||
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
|
||||
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
|
||||
|
||||
@pytest.fixture
|
||||
def Catalog(Base, Category):
|
||||
class Catalog(Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated(
|
||||
'products.categories',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def category_count(self):
|
||||
return sa.func.count(sa.distinct(Category.id))
|
||||
return Catalog
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Product(Base, Catalog, Category):
|
||||
product_categories = sa.Table(
|
||||
'category_product',
|
||||
Base.metadata,
|
||||
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
|
||||
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
|
||||
)
|
||||
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
catalog_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey('catalog.id')
|
||||
)
|
||||
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
catalog = sa.orm.relationship(
|
||||
Catalog,
|
||||
backref='products'
|
||||
)
|
||||
|
||||
@aggregated(
|
||||
'products.categories',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def category_count(self):
|
||||
return sa.func.count(sa.distinct(Category.id))
|
||||
categories = sa.orm.relationship(
|
||||
Category,
|
||||
backref='products',
|
||||
secondary=product_categories
|
||||
)
|
||||
return Product
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
@pytest.fixture
|
||||
def init_models(Category, Catalog, Product):
|
||||
pass
|
||||
|
||||
catalog_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey('catalog.id')
|
||||
)
|
||||
|
||||
catalog = sa.orm.relationship(
|
||||
Catalog,
|
||||
backref='products'
|
||||
)
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestAggregateOneToManyAndManyToMany(object):
|
||||
|
||||
categories = sa.orm.relationship(
|
||||
Category,
|
||||
backref='products',
|
||||
secondary=product_categories
|
||||
)
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.Product = Product
|
||||
|
||||
def test_insert(self):
|
||||
category = self.Category()
|
||||
def test_insert(self, session, Category, Catalog, Product):
|
||||
category = Category()
|
||||
products = [
|
||||
self.Product(categories=[category]),
|
||||
self.Product(categories=[category])
|
||||
Product(categories=[category]),
|
||||
Product(categories=[category])
|
||||
]
|
||||
catalog = self.Catalog(products=products)
|
||||
self.session.add(catalog)
|
||||
catalog = Catalog(products=products)
|
||||
session.add(catalog)
|
||||
products2 = [
|
||||
self.Product(categories=[category]),
|
||||
self.Product(categories=[category])
|
||||
Product(categories=[category]),
|
||||
Product(categories=[category])
|
||||
]
|
||||
catalog2 = self.Catalog(products=products2)
|
||||
self.session.add(catalog)
|
||||
self.session.commit()
|
||||
catalog2 = Catalog(products=products2)
|
||||
session.add(catalog)
|
||||
session.commit()
|
||||
assert catalog.category_count == 1
|
||||
assert catalog2.category_count == 1
|
||||
|
|
|
@ -1,64 +1,76 @@
|
|||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregateOneToManyAndOneToMany(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.fixture
|
||||
def Catalog(Base):
|
||||
class Catalog(Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
@aggregated(
|
||||
'categories.products',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def product_count(self):
|
||||
return sa.func.count('1')
|
||||
|
||||
@aggregated(
|
||||
'categories.products',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def product_count(self):
|
||||
return sa.func.count('1')
|
||||
categories = sa.orm.relationship('Category', backref='catalog')
|
||||
return Catalog
|
||||
|
||||
categories = sa.orm.relationship('Category', backref='catalog')
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
@pytest.fixture
|
||||
def Category(Base):
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
|
||||
products = sa.orm.relationship('Product', backref='category')
|
||||
products = sa.orm.relationship('Product', backref='category')
|
||||
return Category
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||
@pytest.fixture
|
||||
def Product(Base):
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.Product = Product
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||
return Product
|
||||
|
||||
def test_assigns_aggregates(self):
|
||||
category = self.Category(name=u'Some category')
|
||||
catalog = self.Catalog(
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(Catalog, Category, Product):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestAggregateOneToManyAndOneToMany(object):
|
||||
|
||||
def test_assigns_aggregates(self, session, Category, Catalog, Product):
|
||||
category = Category(name=u'Some category')
|
||||
catalog = Catalog(
|
||||
categories=[category]
|
||||
)
|
||||
catalog.name = u'Some catalog'
|
||||
self.session.add(catalog)
|
||||
self.session.commit()
|
||||
product = self.Product(
|
||||
session.add(catalog)
|
||||
session.commit()
|
||||
product = Product(
|
||||
name=u'Some product',
|
||||
price=Decimal('1000'),
|
||||
category=category
|
||||
)
|
||||
self.session.add(product)
|
||||
self.session.commit()
|
||||
self.session.refresh(catalog)
|
||||
session.add(product)
|
||||
session.commit()
|
||||
session.refresh(catalog)
|
||||
assert catalog.product_count == 1
|
||||
|
|
|
@ -1,88 +1,129 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class Test3LevelDeepOneToMany(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.fixture
|
||||
def Catalog(Base):
|
||||
class Catalog(Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
@aggregated(
|
||||
'categories.sub_categories.products',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def product_count(self):
|
||||
return sa.func.count('1')
|
||||
|
||||
@aggregated(
|
||||
'categories.sub_categories.products',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def product_count(self):
|
||||
return sa.func.count('1')
|
||||
categories = sa.orm.relationship('Category', backref='catalog')
|
||||
return Catalog
|
||||
|
||||
categories = sa.orm.relationship('Category', backref='catalog')
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
@pytest.fixture
|
||||
def Category(Base):
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
|
||||
sub_categories = sa.orm.relationship(
|
||||
'SubCategory', backref='category'
|
||||
)
|
||||
sub_categories = sa.orm.relationship(
|
||||
'SubCategory', backref='category'
|
||||
)
|
||||
return Category
|
||||
|
||||
class SubCategory(self.Base):
|
||||
__tablename__ = 'sub_category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||
products = sa.orm.relationship('Product', backref='sub_category')
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
price = sa.Column(sa.Numeric)
|
||||
@pytest.fixture
|
||||
def SubCategory(Base):
|
||||
class SubCategory(Base):
|
||||
__tablename__ = 'sub_category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||
products = sa.orm.relationship('Product', backref='sub_category')
|
||||
return SubCategory
|
||||
|
||||
sub_category_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey('sub_category.id')
|
||||
)
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.SubCategory = SubCategory
|
||||
self.Product = Product
|
||||
@pytest.fixture
|
||||
def Product(Base):
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
def test_assigns_aggregates(self):
|
||||
catalog = self.catalog_factory()
|
||||
self.session.commit()
|
||||
self.session.refresh(catalog)
|
||||
assert catalog.product_count == 1
|
||||
sub_category_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey('sub_category.id')
|
||||
)
|
||||
return Product
|
||||
|
||||
def catalog_factory(self):
|
||||
product = self.Product()
|
||||
sub_category = self.SubCategory(
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(Catalog, Category, SubCategory, Product):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def catalog_factory(Product, SubCategory, Category, Catalog, session):
|
||||
def catalog_factory():
|
||||
product = Product()
|
||||
sub_category = SubCategory(
|
||||
products=[product]
|
||||
)
|
||||
category = self.Category(sub_categories=[sub_category])
|
||||
catalog = self.Catalog(categories=[category])
|
||||
self.session.add(catalog)
|
||||
category = Category(sub_categories=[sub_category])
|
||||
catalog = Catalog(categories=[category])
|
||||
session.add(catalog)
|
||||
return catalog
|
||||
return catalog_factory
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class Test3LevelDeepOneToMany(object):
|
||||
|
||||
def test_assigns_aggregates(self, session, catalog_factory):
|
||||
catalog = catalog_factory()
|
||||
session.commit()
|
||||
session.refresh(catalog)
|
||||
assert catalog.product_count == 1
|
||||
|
||||
def catalog_factory(
|
||||
self,
|
||||
session,
|
||||
Product,
|
||||
SubCategory,
|
||||
Category,
|
||||
Catalog
|
||||
):
|
||||
product = Product()
|
||||
sub_category = SubCategory(
|
||||
products=[product]
|
||||
)
|
||||
category = Category(sub_categories=[sub_category])
|
||||
catalog = Catalog(categories=[category])
|
||||
session.add(catalog)
|
||||
return catalog
|
||||
|
||||
def test_only_updates_affected_aggregates(self):
|
||||
catalog = self.catalog_factory()
|
||||
catalog2 = self.catalog_factory()
|
||||
self.session.commit()
|
||||
def test_only_updates_affected_aggregates(
|
||||
self,
|
||||
session,
|
||||
catalog_factory,
|
||||
Product
|
||||
):
|
||||
catalog = catalog_factory()
|
||||
catalog2 = catalog_factory()
|
||||
session.commit()
|
||||
|
||||
# force set catalog2 product_count to zero in order to check if it gets
|
||||
# updated when the other catalog's product count gets updated
|
||||
self.session.execute(
|
||||
session.execute(
|
||||
'UPDATE catalog SET product_count = 0 WHERE id = %d'
|
||||
% catalog2.id
|
||||
)
|
||||
|
||||
catalog.categories[0].sub_categories[0].products.append(
|
||||
self.Product()
|
||||
Product()
|
||||
)
|
||||
self.session.commit()
|
||||
self.session.refresh(catalog)
|
||||
self.session.refresh(catalog2)
|
||||
session.commit()
|
||||
session.refresh(catalog)
|
||||
session.refresh(catalog2)
|
||||
assert catalog.product_count == 2
|
||||
assert catalog2.product_count == 0
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import aggregated, TSVectorType
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
def tsvector_reduce_concat(vectors):
|
||||
|
@ -13,45 +13,54 @@ def tsvector_reduce_concat(vectors):
|
|||
)
|
||||
|
||||
|
||||
class TestSearchVectorAggregates(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.fixture
|
||||
def Product(Base):
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
return Product
|
||||
|
||||
@aggregated('products', sa.Column(TSVectorType))
|
||||
def product_search_vector(self):
|
||||
return tsvector_reduce_concat(
|
||||
sa.func.to_tsvector(Product.name)
|
||||
)
|
||||
|
||||
products = sa.orm.relationship('Product', backref='catalog')
|
||||
@pytest.fixture
|
||||
def Catalog(Base, Product):
|
||||
class Catalog(Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
@aggregated('products', sa.Column(TSVectorType))
|
||||
def product_search_vector(self):
|
||||
return tsvector_reduce_concat(
|
||||
sa.func.to_tsvector(Product.name)
|
||||
)
|
||||
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
products = sa.orm.relationship('Product', backref='catalog')
|
||||
return Catalog
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Product = Product
|
||||
|
||||
def test_assigns_aggregates_on_insert(self):
|
||||
catalog = self.Catalog(
|
||||
@pytest.fixture
|
||||
def init_models(Product, Catalog):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestSearchVectorAggregates(object):
|
||||
|
||||
def test_assigns_aggregates_on_insert(self, session, Product, Catalog):
|
||||
catalog = Catalog(
|
||||
name=u'Some catalog'
|
||||
)
|
||||
self.session.add(catalog)
|
||||
self.session.commit()
|
||||
product = self.Product(
|
||||
session.add(catalog)
|
||||
session.commit()
|
||||
product = Product(
|
||||
name=u'Product XYZ',
|
||||
catalog=catalog
|
||||
)
|
||||
self.session.add(product)
|
||||
self.session.commit()
|
||||
self.session.refresh(catalog)
|
||||
session.add(product)
|
||||
session.commit()
|
||||
session.refresh(catalog)
|
||||
assert catalog.product_search_vector == "'product':1 'xyz':2"
|
||||
|
|
|
@ -1,61 +1,76 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
|
||||
def create_models(self):
|
||||
class Thread(self.Base):
|
||||
__tablename__ = 'thread'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
@pytest.fixture
|
||||
def Thread(Base):
|
||||
class Thread(Base):
|
||||
__tablename__ = 'thread'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated('comments', sa.Column(sa.Integer, default=0))
|
||||
def comment_count(self):
|
||||
return sa.func.count('1')
|
||||
@aggregated('comments', sa.Column(sa.Integer, default=0))
|
||||
def comment_count(self):
|
||||
return sa.func.count('1')
|
||||
|
||||
comments = sa.orm.relationship('Comment', backref='thread')
|
||||
comments = sa.orm.relationship('Comment', backref='thread')
|
||||
return Thread
|
||||
|
||||
class Comment(self.Base):
|
||||
__tablename__ = 'comment'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
content = sa.Column(sa.Unicode(255))
|
||||
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
|
||||
|
||||
self.Thread = Thread
|
||||
self.Comment = Comment
|
||||
@pytest.fixture
|
||||
def Comment(Base):
|
||||
class Comment(Base):
|
||||
__tablename__ = 'comment'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
content = sa.Column(sa.Unicode(255))
|
||||
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
|
||||
return Comment
|
||||
|
||||
def test_assigns_aggregates_on_insert(self):
|
||||
thread = self.Thread()
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(Thread, Comment):
|
||||
pass
|
||||
|
||||
|
||||
class TestAggregateValueGenerationForSimpleModelPaths(object):
|
||||
|
||||
def test_assigns_aggregates_on_insert(self, session, Thread, Comment):
|
||||
thread = Thread()
|
||||
thread.name = u'some article name'
|
||||
self.session.add(thread)
|
||||
comment = self.Comment(content=u'Some content', thread=thread)
|
||||
self.session.add(comment)
|
||||
self.session.commit()
|
||||
self.session.refresh(thread)
|
||||
session.add(thread)
|
||||
comment = Comment(content=u'Some content', thread=thread)
|
||||
session.add(comment)
|
||||
session.commit()
|
||||
session.refresh(thread)
|
||||
assert thread.comment_count == 1
|
||||
|
||||
def test_assigns_aggregates_on_separate_insert(self):
|
||||
thread = self.Thread()
|
||||
def test_assigns_aggregates_on_separate_insert(
|
||||
self,
|
||||
session,
|
||||
Thread,
|
||||
Comment
|
||||
):
|
||||
thread = Thread()
|
||||
thread.name = u'some article name'
|
||||
self.session.add(thread)
|
||||
self.session.commit()
|
||||
comment = self.Comment(content=u'Some content', thread=thread)
|
||||
self.session.add(comment)
|
||||
self.session.commit()
|
||||
self.session.refresh(thread)
|
||||
session.add(thread)
|
||||
session.commit()
|
||||
comment = Comment(content=u'Some content', thread=thread)
|
||||
session.add(comment)
|
||||
session.commit()
|
||||
session.refresh(thread)
|
||||
assert thread.comment_count == 1
|
||||
|
||||
def test_assigns_aggregates_on_delete(self):
|
||||
thread = self.Thread()
|
||||
def test_assigns_aggregates_on_delete(self, session, Thread, Comment):
|
||||
thread = Thread()
|
||||
thread.name = u'some article name'
|
||||
self.session.add(thread)
|
||||
self.session.commit()
|
||||
comment = self.Comment(content=u'Some content', thread=thread)
|
||||
self.session.add(comment)
|
||||
self.session.commit()
|
||||
self.session.delete(comment)
|
||||
self.session.commit()
|
||||
self.session.refresh(thread)
|
||||
session.add(thread)
|
||||
session.commit()
|
||||
comment = Comment(content=u'Some content', thread=thread)
|
||||
session.add(comment)
|
||||
session.commit()
|
||||
session.delete(comment)
|
||||
session.commit()
|
||||
session.refresh(thread)
|
||||
assert thread.comment_count == 0
|
||||
|
|
|
@ -1,59 +1,74 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregatedWithColumnAlias(TestCase):
|
||||
def create_models(self):
|
||||
class Thread(self.Base):
|
||||
__tablename__ = 'thread'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
@pytest.fixture
|
||||
def Thread(Base):
|
||||
class Thread(Base):
|
||||
__tablename__ = 'thread'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
@aggregated(
|
||||
'comments',
|
||||
sa.Column('_comment_count', sa.Integer, default=0)
|
||||
)
|
||||
def comment_count(self):
|
||||
return sa.func.count('1')
|
||||
@aggregated(
|
||||
'comments',
|
||||
sa.Column('_comment_count', sa.Integer, default=0)
|
||||
)
|
||||
def comment_count(self):
|
||||
return sa.func.count('1')
|
||||
|
||||
comments = sa.orm.relationship('Comment', backref='thread')
|
||||
comments = sa.orm.relationship('Comment', backref='thread')
|
||||
return Thread
|
||||
|
||||
class Comment(self.Base):
|
||||
__tablename__ = 'comment'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
|
||||
|
||||
self.Thread = Thread
|
||||
self.Comment = Comment
|
||||
@pytest.fixture
|
||||
def Comment(Base):
|
||||
class Comment(Base):
|
||||
__tablename__ = 'comment'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
|
||||
return Comment
|
||||
|
||||
def test_assigns_aggregates_on_insert(self):
|
||||
thread = self.Thread()
|
||||
self.session.add(thread)
|
||||
comment = self.Comment(thread=thread)
|
||||
self.session.add(comment)
|
||||
self.session.commit()
|
||||
self.session.refresh(thread)
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(Thread, Comment):
|
||||
pass
|
||||
|
||||
|
||||
class TestAggregatedWithColumnAlias(object):
|
||||
|
||||
def test_assigns_aggregates_on_insert(self, session, Thread, Comment):
|
||||
thread = Thread()
|
||||
session.add(thread)
|
||||
comment = Comment(thread=thread)
|
||||
session.add(comment)
|
||||
session.commit()
|
||||
session.refresh(thread)
|
||||
assert thread.comment_count == 1
|
||||
|
||||
def test_assigns_aggregates_on_separate_insert(self):
|
||||
thread = self.Thread()
|
||||
self.session.add(thread)
|
||||
self.session.commit()
|
||||
comment = self.Comment(thread=thread)
|
||||
self.session.add(comment)
|
||||
self.session.commit()
|
||||
self.session.refresh(thread)
|
||||
def test_assigns_aggregates_on_separate_insert(
|
||||
self,
|
||||
session,
|
||||
Thread,
|
||||
Comment
|
||||
):
|
||||
thread = Thread()
|
||||
session.add(thread)
|
||||
session.commit()
|
||||
comment = Comment(thread=thread)
|
||||
session.add(comment)
|
||||
session.commit()
|
||||
session.refresh(thread)
|
||||
assert thread.comment_count == 1
|
||||
|
||||
def test_assigns_aggregates_on_delete(self):
|
||||
thread = self.Thread()
|
||||
self.session.add(thread)
|
||||
self.session.commit()
|
||||
comment = self.Comment(thread=thread)
|
||||
self.session.add(comment)
|
||||
self.session.commit()
|
||||
self.session.delete(comment)
|
||||
self.session.commit()
|
||||
self.session.refresh(thread)
|
||||
def test_assigns_aggregates_on_delete(self, session, Thread, Comment):
|
||||
thread = Thread()
|
||||
session.add(thread)
|
||||
session.commit()
|
||||
comment = Comment(thread=thread)
|
||||
session.add(comment)
|
||||
session.commit()
|
||||
session.delete(comment)
|
||||
session.commit()
|
||||
session.refresh(thread)
|
||||
assert thread.comment_count == 0
|
||||
|
|
|
@ -1,47 +1,56 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregateValueGenerationWithCascadeDelete(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.fixture
|
||||
def Thread(Base):
|
||||
class Thread(Base):
|
||||
__tablename__ = 'thread'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
def create_models(self):
|
||||
class Thread(self.Base):
|
||||
__tablename__ = 'thread'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
@aggregated('comments', sa.Column(sa.Integer, default=0))
|
||||
def comment_count(self):
|
||||
return sa.func.count('1')
|
||||
|
||||
@aggregated('comments', sa.Column(sa.Integer, default=0))
|
||||
def comment_count(self):
|
||||
return sa.func.count('1')
|
||||
comments = sa.orm.relationship(
|
||||
'Comment',
|
||||
passive_deletes=True,
|
||||
backref='thread'
|
||||
)
|
||||
return Thread
|
||||
|
||||
comments = sa.orm.relationship(
|
||||
'Comment',
|
||||
passive_deletes=True,
|
||||
backref='thread'
|
||||
)
|
||||
|
||||
class Comment(self.Base):
|
||||
__tablename__ = 'comment'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
content = sa.Column(sa.Unicode(255))
|
||||
thread_id = sa.Column(
|
||||
sa.Integer,
|
||||
sa.ForeignKey('thread.id', ondelete='CASCADE')
|
||||
)
|
||||
@pytest.fixture
|
||||
def Comment(Base):
|
||||
class Comment(Base):
|
||||
__tablename__ = 'comment'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
content = sa.Column(sa.Unicode(255))
|
||||
thread_id = sa.Column(
|
||||
sa.Integer,
|
||||
sa.ForeignKey('thread.id', ondelete='CASCADE')
|
||||
)
|
||||
return Comment
|
||||
|
||||
self.Thread = Thread
|
||||
self.Comment = Comment
|
||||
|
||||
def test_something(self):
|
||||
thread = self.Thread()
|
||||
@pytest.fixture
|
||||
def init_models(Thread, Comment):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestAggregateValueGenerationWithCascadeDelete(object):
|
||||
|
||||
def test_something(self, session, Thread, Comment):
|
||||
thread = Thread()
|
||||
thread.name = u'some article name'
|
||||
self.session.add(thread)
|
||||
comment = self.Comment(content=u'Some content', thread=thread)
|
||||
self.session.add(comment)
|
||||
self.session.commit()
|
||||
self.session.expire_all()
|
||||
self.session.delete(thread)
|
||||
self.session.commit()
|
||||
session.add(thread)
|
||||
comment = Comment(content=u'Some content', thread=thread)
|
||||
session.add(comment)
|
||||
session.commit()
|
||||
session.expire_all()
|
||||
session.delete(thread)
|
||||
session.commit()
|
||||
|
|
|
@ -1,29 +1,35 @@
|
|||
import pytest
|
||||
|
||||
from sqlalchemy_utils import analyze
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAnalyzeWithPostgres(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestAnalyzeWithPostgres(object):
|
||||
|
||||
def test_runtime(self):
|
||||
query = self.session.query(self.Article)
|
||||
assert analyze(self.connection, query).runtime
|
||||
def test_runtime(self, session, connection, Article):
|
||||
query = session.query(Article)
|
||||
assert analyze(connection, query).runtime
|
||||
|
||||
def test_node_types_with_join(self):
|
||||
def test_node_types_with_join(self, session, connection, Article):
|
||||
query = (
|
||||
self.session.query(self.Article)
|
||||
.join(self.Article.category)
|
||||
session.query(Article)
|
||||
.join(Article.category)
|
||||
)
|
||||
analysis = analyze(self.connection, query)
|
||||
analysis = analyze(connection, query)
|
||||
assert analysis.node_types == [
|
||||
u'Hash Join', u'Seq Scan', u'Hash', u'Seq Scan'
|
||||
]
|
||||
|
||||
def test_node_types_with_index_only_scan(self):
|
||||
def test_node_types_with_index_only_scan(
|
||||
self,
|
||||
session,
|
||||
connection,
|
||||
Article
|
||||
):
|
||||
query = (
|
||||
self.session.query(self.Article.name)
|
||||
.order_by(self.Article.name)
|
||||
session.query(Article.name)
|
||||
.order_by(Article.name)
|
||||
.limit(10)
|
||||
)
|
||||
analysis = analyze(self.connection, query)
|
||||
analysis = analyze(connection, query)
|
||||
assert analysis.node_types == [u'Limit', u'Index Only Scan']
|
||||
|
|
|
@ -1,11 +1,8 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from flexmock import flexmock
|
||||
from pytest import mark
|
||||
|
||||
from sqlalchemy_utils import create_database, database_exists, drop_database
|
||||
from tests import TestCase
|
||||
|
||||
pymysql = None
|
||||
try:
|
||||
|
@ -14,38 +11,73 @@ except ImportError:
|
|||
pass
|
||||
|
||||
|
||||
class DatabaseTest(TestCase):
|
||||
def test_create_and_drop(self):
|
||||
assert not database_exists(self.url)
|
||||
create_database(self.url)
|
||||
assert database_exists(self.url)
|
||||
drop_database(self.url)
|
||||
assert not database_exists(self.url)
|
||||
class DatabaseTest(object):
|
||||
def test_create_and_drop(self, dsn):
|
||||
assert not database_exists(dsn)
|
||||
create_database(dsn)
|
||||
assert database_exists(dsn)
|
||||
drop_database(dsn)
|
||||
assert not database_exists(dsn)
|
||||
|
||||
|
||||
class TestDatabaseSQLite(DatabaseTest):
|
||||
url = 'sqlite:///sqlalchemy_utils.db'
|
||||
@pytest.mark.usefixtures('sqlite_memory_dsn')
|
||||
class TestDatabaseSQLiteMemory(object):
|
||||
|
||||
def setup(self):
|
||||
if os.path.exists('sqlalchemy_utils.db'):
|
||||
os.remove('sqlalchemy_utils.db')
|
||||
|
||||
def test_exists_memory(self):
|
||||
assert database_exists('sqlite:///:memory:')
|
||||
def test_exists_memory(self, dsn):
|
||||
assert database_exists(dsn)
|
||||
|
||||
|
||||
@mark.skipif('pymysql is None')
|
||||
@pytest.mark.usefixtures('sqlite_file_dsn')
|
||||
class TestDatabaseSQLiteFile(DatabaseTest):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skipif('pymysql is None')
|
||||
@pytest.mark.usefixtures('mysql_dsn')
|
||||
class TestDatabaseMySQL(DatabaseTest):
|
||||
url = 'mysql+pymysql://travis@localhost/db_test_sqlalchemy_util'
|
||||
|
||||
@pytest.fixture
|
||||
def db_name(self):
|
||||
return 'db_test_sqlalchemy_util'
|
||||
|
||||
|
||||
@mark.skipif('pymysql is None')
|
||||
@pytest.mark.skipif('pymysql is None')
|
||||
@pytest.mark.usefixtures('mysql_dsn')
|
||||
class TestDatabaseMySQLWithQuotedName(DatabaseTest):
|
||||
url = 'mysql+pymysql://travis@localhost/db_test_sqlalchemy-util'
|
||||
|
||||
@pytest.fixture
|
||||
def db_name(self):
|
||||
return 'db_test_sqlalchemy-util'
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestDatabasePostgres(DatabaseTest):
|
||||
|
||||
@pytest.fixture
|
||||
def db_name(self):
|
||||
return 'db_test_sqlalchemy_util'
|
||||
|
||||
def test_template(self):
|
||||
(
|
||||
flexmock(sa.engine.Engine)
|
||||
.should_receive('execute')
|
||||
.with_args(
|
||||
"CREATE DATABASE db_test_sqlalchemy_util ENCODING 'utf8' "
|
||||
"TEMPLATE my_template"
|
||||
)
|
||||
)
|
||||
create_database(
|
||||
'postgres://postgres@localhost/db_test_sqlalchemy_util',
|
||||
template='my_template'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestDatabasePostgresWithQuotedName(DatabaseTest):
|
||||
url = 'postgres://postgres@localhost/db_test_sqlalchemy-util'
|
||||
|
||||
@pytest.fixture
|
||||
def db_name(self):
|
||||
return 'db_test_sqlalchemy-util'
|
||||
|
||||
def test_template(self):
|
||||
(
|
||||
|
@ -61,21 +93,3 @@ class TestDatabasePostgresWithQuotedName(DatabaseTest):
|
|||
'postgres://postgres@localhost/db_test_sqlalchemy-util',
|
||||
template='my-template'
|
||||
)
|
||||
|
||||
|
||||
class TestDatabasePostgres(DatabaseTest):
|
||||
url = 'postgres://postgres@localhost/db_test_sqlalchemy_util'
|
||||
|
||||
def test_template(self):
|
||||
(
|
||||
flexmock(sa.engine.Engine)
|
||||
.should_receive('execute')
|
||||
.with_args(
|
||||
"CREATE DATABASE db_test_sqlalchemy_util ENCODING 'utf8' "
|
||||
"TEMPLATE my_template"
|
||||
)
|
||||
)
|
||||
create_database(
|
||||
'postgres://postgres@localhost/db_test_sqlalchemy_util',
|
||||
template='my_template'
|
||||
)
|
||||
|
|
|
@ -1,18 +1,23 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import dependent_objects, get_referencing_foreign_keys
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestDependentObjects(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
class TestDependentObjects(object):
|
||||
|
||||
@pytest.fixture
|
||||
def User(self, Base):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
first_name = sa.Column(sa.Unicode(255))
|
||||
last_name = sa.Column(sa.Unicode(255))
|
||||
return User
|
||||
|
||||
class Article(self.Base):
|
||||
@pytest.fixture
|
||||
def Article(self, Base, User):
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
|
||||
|
@ -22,8 +27,11 @@ class TestDependentObjects(TestCase):
|
|||
|
||||
author = sa.orm.relationship(User, foreign_keys=[author_id])
|
||||
owner = sa.orm.relationship(User, foreign_keys=[owner_id])
|
||||
return Article
|
||||
|
||||
class BlogPost(self.Base):
|
||||
@pytest.fixture
|
||||
def BlogPost(self, Base, User):
|
||||
class BlogPost(Base):
|
||||
__tablename__ = 'blog_post'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
owner_id = sa.Column(
|
||||
|
@ -31,21 +39,22 @@ class TestDependentObjects(TestCase):
|
|||
)
|
||||
|
||||
owner = sa.orm.relationship(User)
|
||||
return BlogPost
|
||||
|
||||
self.User = User
|
||||
self.Article = Article
|
||||
self.BlogPost = BlogPost
|
||||
@pytest.fixture
|
||||
def init_models(self, User, Article, BlogPost):
|
||||
pass
|
||||
|
||||
def test_returns_all_dependent_objects(self):
|
||||
user = self.User(first_name=u'John')
|
||||
def test_returns_all_dependent_objects(self, session, User, Article):
|
||||
user = User(first_name=u'John')
|
||||
articles = [
|
||||
self.Article(author=user),
|
||||
self.Article(),
|
||||
self.Article(owner=user),
|
||||
self.Article(author=user, owner=user)
|
||||
Article(author=user),
|
||||
Article(),
|
||||
Article(owner=user),
|
||||
Article(author=user, owner=user)
|
||||
]
|
||||
self.session.add_all(articles)
|
||||
self.session.commit()
|
||||
session.add_all(articles)
|
||||
session.commit()
|
||||
|
||||
deps = list(dependent_objects(user))
|
||||
assert len(deps) == 3
|
||||
|
@ -53,23 +62,29 @@ class TestDependentObjects(TestCase):
|
|||
assert articles[2] in deps
|
||||
assert articles[3] in deps
|
||||
|
||||
def test_with_foreign_keys_parameter(self):
|
||||
user = self.User(first_name=u'John')
|
||||
def test_with_foreign_keys_parameter(
|
||||
self,
|
||||
session,
|
||||
User,
|
||||
Article,
|
||||
BlogPost
|
||||
):
|
||||
user = User(first_name=u'John')
|
||||
objects = [
|
||||
self.Article(author=user),
|
||||
self.Article(),
|
||||
self.Article(owner=user),
|
||||
self.Article(author=user, owner=user),
|
||||
self.BlogPost(owner=user)
|
||||
Article(author=user),
|
||||
Article(),
|
||||
Article(owner=user),
|
||||
Article(author=user, owner=user),
|
||||
BlogPost(owner=user)
|
||||
]
|
||||
self.session.add_all(objects)
|
||||
self.session.commit()
|
||||
session.add_all(objects)
|
||||
session.commit()
|
||||
|
||||
deps = list(
|
||||
dependent_objects(
|
||||
user,
|
||||
(
|
||||
fk for fk in get_referencing_foreign_keys(self.User)
|
||||
fk for fk in get_referencing_foreign_keys(User)
|
||||
if fk.ondelete == 'RESTRICT' or fk.ondelete is None
|
||||
)
|
||||
).limit(5)
|
||||
|
@ -79,15 +94,20 @@ class TestDependentObjects(TestCase):
|
|||
assert objects[3] in deps
|
||||
|
||||
|
||||
class TestDependentObjectsWithColumnAliases(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
class TestDependentObjectsWithColumnAliases(object):
|
||||
|
||||
@pytest.fixture
|
||||
def User(self, Base):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
first_name = sa.Column(sa.Unicode(255))
|
||||
last_name = sa.Column(sa.Unicode(255))
|
||||
return User
|
||||
|
||||
class Article(self.Base):
|
||||
@pytest.fixture
|
||||
def Article(self, Base, User):
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
author_id = sa.Column(
|
||||
|
@ -100,8 +120,11 @@ class TestDependentObjectsWithColumnAliases(TestCase):
|
|||
|
||||
author = sa.orm.relationship(User, foreign_keys=[author_id])
|
||||
owner = sa.orm.relationship(User, foreign_keys=[owner_id])
|
||||
return Article
|
||||
|
||||
class BlogPost(self.Base):
|
||||
@pytest.fixture
|
||||
def BlogPost(self, Base, User):
|
||||
class BlogPost(Base):
|
||||
__tablename__ = 'blog_post'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
owner_id = sa.Column(
|
||||
|
@ -110,21 +133,22 @@ class TestDependentObjectsWithColumnAliases(TestCase):
|
|||
)
|
||||
|
||||
owner = sa.orm.relationship(User)
|
||||
return BlogPost
|
||||
|
||||
self.User = User
|
||||
self.Article = Article
|
||||
self.BlogPost = BlogPost
|
||||
@pytest.fixture
|
||||
def init_models(self, User, Article, BlogPost):
|
||||
pass
|
||||
|
||||
def test_returns_all_dependent_objects(self):
|
||||
user = self.User(first_name=u'John')
|
||||
def test_returns_all_dependent_objects(self, session, User, Article):
|
||||
user = User(first_name=u'John')
|
||||
articles = [
|
||||
self.Article(author=user),
|
||||
self.Article(),
|
||||
self.Article(owner=user),
|
||||
self.Article(author=user, owner=user)
|
||||
Article(author=user),
|
||||
Article(),
|
||||
Article(owner=user),
|
||||
Article(author=user, owner=user)
|
||||
]
|
||||
self.session.add_all(articles)
|
||||
self.session.commit()
|
||||
session.add_all(articles)
|
||||
session.commit()
|
||||
|
||||
deps = list(dependent_objects(user))
|
||||
assert len(deps) == 3
|
||||
|
@ -132,23 +156,29 @@ class TestDependentObjectsWithColumnAliases(TestCase):
|
|||
assert articles[2] in deps
|
||||
assert articles[3] in deps
|
||||
|
||||
def test_with_foreign_keys_parameter(self):
|
||||
user = self.User(first_name=u'John')
|
||||
def test_with_foreign_keys_parameter(
|
||||
self,
|
||||
session,
|
||||
User,
|
||||
Article,
|
||||
BlogPost
|
||||
):
|
||||
user = User(first_name=u'John')
|
||||
objects = [
|
||||
self.Article(author=user),
|
||||
self.Article(),
|
||||
self.Article(owner=user),
|
||||
self.Article(author=user, owner=user),
|
||||
self.BlogPost(owner=user)
|
||||
Article(author=user),
|
||||
Article(),
|
||||
Article(owner=user),
|
||||
Article(author=user, owner=user),
|
||||
BlogPost(owner=user)
|
||||
]
|
||||
self.session.add_all(objects)
|
||||
self.session.commit()
|
||||
session.add_all(objects)
|
||||
session.commit()
|
||||
|
||||
deps = list(
|
||||
dependent_objects(
|
||||
user,
|
||||
(
|
||||
fk for fk in get_referencing_foreign_keys(self.User)
|
||||
fk for fk in get_referencing_foreign_keys(User)
|
||||
if fk.ondelete == 'RESTRICT' or fk.ondelete is None
|
||||
)
|
||||
).limit(5)
|
||||
|
@ -158,50 +188,64 @@ class TestDependentObjectsWithColumnAliases(TestCase):
|
|||
assert objects[3] in deps
|
||||
|
||||
|
||||
class TestDependentObjectsWithManyReferences(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
class TestDependentObjectsWithManyReferences(object):
|
||||
|
||||
@pytest.fixture
|
||||
def User(self, Base):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
first_name = sa.Column(sa.Unicode(255))
|
||||
last_name = sa.Column(sa.Unicode(255))
|
||||
return User
|
||||
|
||||
class BlogPost(self.Base):
|
||||
@pytest.fixture
|
||||
def BlogPost(self, Base, User):
|
||||
class BlogPost(Base):
|
||||
__tablename__ = 'blog_post'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
|
||||
author = sa.orm.relationship(User)
|
||||
return BlogPost
|
||||
|
||||
class Article(self.Base):
|
||||
@pytest.fixture
|
||||
def Article(self, Base, User):
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
|
||||
author = sa.orm.relationship(User)
|
||||
return Article
|
||||
|
||||
self.User = User
|
||||
self.Article = Article
|
||||
self.BlogPost = BlogPost
|
||||
@pytest.fixture
|
||||
def init_models(self, User, BlogPost, Article):
|
||||
pass
|
||||
|
||||
def test_with_many_dependencies(self):
|
||||
user = self.User(first_name=u'John')
|
||||
def test_with_many_dependencies(self, session, User, Article, BlogPost):
|
||||
user = User(first_name=u'John')
|
||||
objects = [
|
||||
self.Article(author=user),
|
||||
self.BlogPost(author=user)
|
||||
Article(author=user),
|
||||
BlogPost(author=user)
|
||||
]
|
||||
self.session.add_all(objects)
|
||||
self.session.commit()
|
||||
session.add_all(objects)
|
||||
session.commit()
|
||||
deps = list(dependent_objects(user))
|
||||
assert len(deps) == 2
|
||||
|
||||
|
||||
class TestDependentObjectsWithCompositeKeys(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
class TestDependentObjectsWithCompositeKeys(object):
|
||||
|
||||
@pytest.fixture
|
||||
def User(self, Base):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
first_name = sa.Column(sa.Unicode(255), primary_key=True)
|
||||
last_name = sa.Column(sa.Unicode(255), primary_key=True)
|
||||
return User
|
||||
|
||||
class Article(self.Base):
|
||||
@pytest.fixture
|
||||
def Article(self, Base, User):
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
author_first_name = sa.Column(sa.Unicode(255))
|
||||
|
@ -214,20 +258,22 @@ class TestDependentObjectsWithCompositeKeys(TestCase):
|
|||
)
|
||||
|
||||
author = sa.orm.relationship(User)
|
||||
return Article
|
||||
|
||||
self.User = User
|
||||
self.Article = Article
|
||||
@pytest.fixture
|
||||
def init_models(self, User, Article):
|
||||
pass
|
||||
|
||||
def test_returns_all_dependent_objects(self):
|
||||
user = self.User(first_name=u'John', last_name=u'Smith')
|
||||
def test_returns_all_dependent_objects(self, session, User, Article):
|
||||
user = User(first_name=u'John', last_name=u'Smith')
|
||||
articles = [
|
||||
self.Article(author=user),
|
||||
self.Article(),
|
||||
self.Article(),
|
||||
self.Article(author=user)
|
||||
Article(author=user),
|
||||
Article(),
|
||||
Article(),
|
||||
Article(author=user)
|
||||
]
|
||||
self.session.add_all(articles)
|
||||
self.session.commit()
|
||||
session.add_all(articles)
|
||||
session.commit()
|
||||
|
||||
deps = list(dependent_objects(user))
|
||||
assert len(deps) == 2
|
||||
|
@ -235,14 +281,19 @@ class TestDependentObjectsWithCompositeKeys(TestCase):
|
|||
assert articles[3] in deps
|
||||
|
||||
|
||||
class TestDependentObjectsWithSingleTableInheritance(TestCase):
|
||||
def create_models(self):
|
||||
class Category(self.Base):
|
||||
class TestDependentObjectsWithSingleTableInheritance(object):
|
||||
|
||||
@pytest.fixture
|
||||
def Category(self, Base):
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
return Category
|
||||
|
||||
class TextItem(self.Base):
|
||||
@pytest.fixture
|
||||
def TextItem(self, Base, Category):
|
||||
class TextItem(Base):
|
||||
__tablename__ = 'text_item'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
@ -261,33 +312,39 @@ class TestDependentObjectsWithSingleTableInheritance(TestCase):
|
|||
__mapper_args__ = {
|
||||
'polymorphic_on': type,
|
||||
}
|
||||
return TextItem
|
||||
|
||||
@pytest.fixture
|
||||
def Article(self, TextItem):
|
||||
class Article(TextItem):
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': u'article'
|
||||
}
|
||||
return Article
|
||||
|
||||
@pytest.fixture
|
||||
def BlogPost(self, TextItem):
|
||||
class BlogPost(TextItem):
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': u'blog_post'
|
||||
}
|
||||
return BlogPost
|
||||
|
||||
self.Category = Category
|
||||
self.TextItem = TextItem
|
||||
self.Article = Article
|
||||
self.BlogPost = BlogPost
|
||||
@pytest.fixture
|
||||
def init_models(self, Category, TextItem, Article, BlogPost):
|
||||
pass
|
||||
|
||||
def test_returns_all_dependent_objects(self):
|
||||
category1 = self.Category(name=u'Category #1')
|
||||
category2 = self.Category(name=u'Category #2')
|
||||
def test_returns_all_dependent_objects(self, session, Category, Article):
|
||||
category1 = Category(name=u'Category #1')
|
||||
category2 = Category(name=u'Category #2')
|
||||
articles = [
|
||||
self.Article(category=category1),
|
||||
self.Article(category=category1),
|
||||
self.Article(category=category2),
|
||||
self.Article(category=category2),
|
||||
Article(category=category1),
|
||||
Article(category=category1),
|
||||
Article(category=category2),
|
||||
Article(category=category2),
|
||||
]
|
||||
self.session.add_all(articles)
|
||||
self.session.commit()
|
||||
session.add_all(articles)
|
||||
session.commit()
|
||||
|
||||
deps = list(dependent_objects(category1))
|
||||
assert len(deps) == 2
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from sqlalchemy_utils import escape_like
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestEscapeLike(TestCase):
|
||||
class TestEscapeLike(object):
|
||||
def test_escapes_wildcards(self):
|
||||
assert escape_like('_*%') == '*_***%'
|
||||
|
|
|
@ -1,21 +1,20 @@
|
|||
from pytest import raises
|
||||
import pytest
|
||||
|
||||
from sqlalchemy_utils import get_bind
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestGetBind(TestCase):
|
||||
def test_with_session(self):
|
||||
assert get_bind(self.session) == self.connection
|
||||
class TestGetBind(object):
|
||||
def test_with_session(self, session, connection):
|
||||
assert get_bind(session) == connection
|
||||
|
||||
def test_with_connection(self):
|
||||
assert get_bind(self.connection) == self.connection
|
||||
def test_with_connection(self, session, connection):
|
||||
assert get_bind(connection) == connection
|
||||
|
||||
def test_with_model_object(self):
|
||||
article = self.Article()
|
||||
self.session.add(article)
|
||||
assert get_bind(article) == self.connection
|
||||
def test_with_model_object(self, session, connection, Article):
|
||||
article = Article()
|
||||
session.add(article)
|
||||
assert get_bind(article) == connection
|
||||
|
||||
def test_with_unknown_type(self):
|
||||
with raises(TypeError):
|
||||
with pytest.raises(TypeError):
|
||||
get_bind(None)
|
||||
|
|
|
@ -1,15 +1,14 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from pytest import raises
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import get_class_by_table
|
||||
|
||||
|
||||
class TestGetClassByTableWithJoinedTableInheritance(object):
|
||||
def setup_method(self, method):
|
||||
self.Base = declarative_base()
|
||||
|
||||
class Entity(self.Base):
|
||||
@pytest.fixture
|
||||
def Entity(self, Base):
|
||||
class Entity(Base):
|
||||
__tablename__ = 'entity'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
|
@ -18,7 +17,10 @@ class TestGetClassByTableWithJoinedTableInheritance(object):
|
|||
'polymorphic_on': type,
|
||||
'polymorphic_identity': 'entity'
|
||||
}
|
||||
return Entity
|
||||
|
||||
@pytest.fixture
|
||||
def User(self, Entity):
|
||||
class User(Entity):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(
|
||||
|
@ -29,31 +31,29 @@ class TestGetClassByTableWithJoinedTableInheritance(object):
|
|||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'user'
|
||||
}
|
||||
return User
|
||||
|
||||
self.Entity = Entity
|
||||
self.User = User
|
||||
|
||||
def test_returns_class(self):
|
||||
assert get_class_by_table(self.Base, self.User.__table__) == self.User
|
||||
def test_returns_class(self, Base, User, Entity):
|
||||
assert get_class_by_table(Base, User.__table__) == User
|
||||
assert get_class_by_table(
|
||||
self.Base,
|
||||
self.Entity.__table__
|
||||
) == self.Entity
|
||||
Base,
|
||||
Entity.__table__
|
||||
) == Entity
|
||||
|
||||
def test_table_with_no_associated_class(self):
|
||||
def test_table_with_no_associated_class(self, Base):
|
||||
table = sa.Table(
|
||||
'some_table',
|
||||
self.Base.metadata,
|
||||
Base.metadata,
|
||||
sa.Column('id', sa.Integer)
|
||||
)
|
||||
assert get_class_by_table(self.Base, table) is None
|
||||
assert get_class_by_table(Base, table) is None
|
||||
|
||||
|
||||
class TestGetClassByTableWithSingleTableInheritance(object):
|
||||
def setup_method(self, method):
|
||||
self.Base = declarative_base()
|
||||
|
||||
class Entity(self.Base):
|
||||
@pytest.fixture
|
||||
def Entity(self, Base):
|
||||
class Entity(Base):
|
||||
__tablename__ = 'entity'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
|
@ -62,38 +62,39 @@ class TestGetClassByTableWithSingleTableInheritance(object):
|
|||
'polymorphic_on': type,
|
||||
'polymorphic_identity': 'entity'
|
||||
}
|
||||
return Entity
|
||||
|
||||
@pytest.fixture
|
||||
def User(self, Entity):
|
||||
class User(Entity):
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'user'
|
||||
}
|
||||
return User
|
||||
|
||||
self.Entity = Entity
|
||||
self.User = User
|
||||
|
||||
def test_multiple_classes_without_data_parameter(self):
|
||||
with raises(ValueError):
|
||||
def test_multiple_classes_without_data_parameter(self, Base, Entity, User):
|
||||
with pytest.raises(ValueError):
|
||||
assert get_class_by_table(
|
||||
self.Base,
|
||||
self.Entity.__table__
|
||||
Base,
|
||||
Entity.__table__
|
||||
)
|
||||
|
||||
def test_multiple_classes_with_data_parameter(self):
|
||||
def test_multiple_classes_with_data_parameter(self, Base, Entity, User):
|
||||
assert get_class_by_table(
|
||||
self.Base,
|
||||
self.Entity.__table__,
|
||||
Base,
|
||||
Entity.__table__,
|
||||
{'type': 'entity'}
|
||||
) == self.Entity
|
||||
) == Entity
|
||||
assert get_class_by_table(
|
||||
self.Base,
|
||||
self.Entity.__table__,
|
||||
Base,
|
||||
Entity.__table__,
|
||||
{'type': 'user'}
|
||||
) == self.User
|
||||
) == User
|
||||
|
||||
def test_multiple_classes_with_bogus_data(self):
|
||||
with raises(ValueError):
|
||||
def test_multiple_classes_with_bogus_data(self, Base, Entity, User):
|
||||
with pytest.raises(ValueError):
|
||||
assert get_class_by_table(
|
||||
self.Base,
|
||||
self.Entity.__table__,
|
||||
Base,
|
||||
Entity.__table__,
|
||||
{'type': 'unknown'}
|
||||
)
|
||||
|
|
|
@ -1,42 +1,44 @@
|
|||
from copy import copy
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from pytest import raises
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import get_column_key
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Building(Base):
|
||||
class Building(Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column('_name', sa.Unicode(255))
|
||||
return Building
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Movie(Base):
|
||||
class Movie(Base):
|
||||
__tablename__ = 'movie'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
return Movie
|
||||
|
||||
|
||||
class TestGetColumnKey(object):
|
||||
def setup_method(self, method):
|
||||
Base = declarative_base()
|
||||
|
||||
class Building(Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column('_name', sa.Unicode(255))
|
||||
|
||||
class Movie(Base):
|
||||
__tablename__ = 'movie'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
self.Building = Building
|
||||
self.Movie = Movie
|
||||
|
||||
def test_supports_aliases(self):
|
||||
def test_supports_aliases(self, Building):
|
||||
assert (
|
||||
get_column_key(self.Building, self.Building.__table__.c.id) ==
|
||||
get_column_key(Building, Building.__table__.c.id) ==
|
||||
'id'
|
||||
)
|
||||
assert (
|
||||
get_column_key(self.Building, self.Building.__table__.c._name) ==
|
||||
get_column_key(Building, Building.__table__.c._name) ==
|
||||
'name'
|
||||
)
|
||||
|
||||
def test_supports_vague_matching_of_column_objects(self):
|
||||
column = copy(self.Building.__table__.c._name)
|
||||
assert get_column_key(self.Building, column) == 'name'
|
||||
def test_supports_vague_matching_of_column_objects(self, Building):
|
||||
column = copy(Building.__table__.c._name)
|
||||
assert get_column_key(Building, column) == 'name'
|
||||
|
||||
def test_throws_value_error_for_unknown_column(self):
|
||||
with raises(sa.orm.exc.UnmappedColumnError):
|
||||
get_column_key(self.Building, self.Movie.__table__.c.id)
|
||||
def test_throws_value_error_for_unknown_column(self, Building, Movie):
|
||||
with pytest.raises(sa.orm.exc.UnmappedColumnError):
|
||||
get_column_key(Building, Movie.__table__.c.id)
|
||||
|
|
|
@ -1,65 +1,65 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import get_columns
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Building(Base):
|
||||
class Building(Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
name = sa.Column('_name', sa.Unicode(255))
|
||||
return Building
|
||||
|
||||
|
||||
class TestGetColumns(object):
|
||||
def setup_method(self, method):
|
||||
Base = declarative_base()
|
||||
|
||||
class Building(Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
name = sa.Column('_name', sa.Unicode(255))
|
||||
|
||||
self.Building = Building
|
||||
|
||||
def test_table(self):
|
||||
def test_table(self, Building):
|
||||
assert isinstance(
|
||||
get_columns(self.Building.__table__),
|
||||
get_columns(Building.__table__),
|
||||
sa.sql.base.ImmutableColumnCollection
|
||||
)
|
||||
|
||||
def test_instrumented_attribute(self):
|
||||
assert get_columns(self.Building.id) == [self.Building.__table__.c._id]
|
||||
def test_instrumented_attribute(self, Building):
|
||||
assert get_columns(Building.id) == [Building.__table__.c._id]
|
||||
|
||||
def test_column_property(self):
|
||||
assert get_columns(self.Building.id.property) == [
|
||||
self.Building.__table__.c._id
|
||||
def test_column_property(self, Building):
|
||||
assert get_columns(Building.id.property) == [
|
||||
Building.__table__.c._id
|
||||
]
|
||||
|
||||
def test_column(self):
|
||||
assert get_columns(self.Building.__table__.c._id) == [
|
||||
self.Building.__table__.c._id
|
||||
def test_column(self, Building):
|
||||
assert get_columns(Building.__table__.c._id) == [
|
||||
Building.__table__.c._id
|
||||
]
|
||||
|
||||
def test_declarative_class(self):
|
||||
def test_declarative_class(self, Building):
|
||||
assert isinstance(
|
||||
get_columns(self.Building),
|
||||
get_columns(Building),
|
||||
sa.util._collections.OrderedProperties
|
||||
)
|
||||
|
||||
def test_declarative_object(self):
|
||||
def test_declarative_object(self, Building):
|
||||
assert isinstance(
|
||||
get_columns(self.Building()),
|
||||
get_columns(Building()),
|
||||
sa.util._collections.OrderedProperties
|
||||
)
|
||||
|
||||
def test_mapper(self):
|
||||
def test_mapper(self, Building):
|
||||
assert isinstance(
|
||||
get_columns(self.Building.__mapper__),
|
||||
get_columns(Building.__mapper__),
|
||||
sa.util._collections.OrderedProperties
|
||||
)
|
||||
|
||||
def test_class_alias(self):
|
||||
def test_class_alias(self, Building):
|
||||
assert isinstance(
|
||||
get_columns(sa.orm.aliased(self.Building)),
|
||||
get_columns(sa.orm.aliased(Building)),
|
||||
sa.util._collections.OrderedProperties
|
||||
)
|
||||
|
||||
def test_table_alias(self):
|
||||
alias = sa.orm.aliased(self.Building.__table__)
|
||||
def test_table_alias(self, Building):
|
||||
alias = sa.orm.aliased(Building.__table__)
|
||||
assert isinstance(
|
||||
get_columns(alias),
|
||||
sa.sql.base.ImmutableColumnCollection
|
||||
|
|
|
@ -1,41 +1,41 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
|
||||
from sqlalchemy_utils import get_hybrid_properties
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Category(Base):
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@hybrid_property
|
||||
def lowercase_name(self):
|
||||
return self.name.lower()
|
||||
|
||||
@lowercase_name.expression
|
||||
def lowercase_name(cls):
|
||||
return sa.func.lower(cls.name)
|
||||
return Category
|
||||
|
||||
|
||||
class TestGetHybridProperties(object):
|
||||
def setup_method(self, method):
|
||||
Base = declarative_base()
|
||||
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@hybrid_property
|
||||
def lowercase_name(self):
|
||||
return self.name.lower()
|
||||
|
||||
@lowercase_name.expression
|
||||
def lowercase_name(cls):
|
||||
return sa.func.lower(cls.name)
|
||||
|
||||
self.Category = Category
|
||||
|
||||
def test_declarative_model(self):
|
||||
def test_declarative_model(self, Category):
|
||||
assert (
|
||||
list(get_hybrid_properties(self.Category).keys()) ==
|
||||
list(get_hybrid_properties(Category).keys()) ==
|
||||
['lowercase_name']
|
||||
)
|
||||
|
||||
def test_mapper(self):
|
||||
def test_mapper(self, Category):
|
||||
assert (
|
||||
list(get_hybrid_properties(sa.inspect(self.Category)).keys()) ==
|
||||
list(get_hybrid_properties(sa.inspect(Category)).keys()) ==
|
||||
['lowercase_name']
|
||||
)
|
||||
|
||||
def test_aliased_class(self):
|
||||
props = get_hybrid_properties(sa.orm.aliased(self.Category))
|
||||
def test_aliased_class(self, Category):
|
||||
props = get_hybrid_properties(sa.orm.aliased(Category))
|
||||
assert list(props.keys()) == ['lowercase_name']
|
||||
|
|
|
@ -1,104 +1,106 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from pytest import raises
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import get_mapper
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestGetMapper(object):
|
||||
def setup_method(self, method):
|
||||
self.Base = declarative_base()
|
||||
|
||||
class Building(self.Base):
|
||||
@pytest.fixture
|
||||
def Building(self, Base):
|
||||
class Building(Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
return Building
|
||||
|
||||
self.Building = Building
|
||||
def test_table(self, Building):
|
||||
assert get_mapper(Building.__table__) == sa.inspect(Building)
|
||||
|
||||
def test_table(self):
|
||||
assert get_mapper(self.Building.__table__) == sa.inspect(self.Building)
|
||||
|
||||
def test_declarative_class(self):
|
||||
def test_declarative_class(self, Building):
|
||||
assert (
|
||||
get_mapper(self.Building) ==
|
||||
sa.inspect(self.Building)
|
||||
get_mapper(Building) ==
|
||||
sa.inspect(Building)
|
||||
)
|
||||
|
||||
def test_declarative_object(self):
|
||||
def test_declarative_object(self, Building):
|
||||
assert (
|
||||
get_mapper(self.Building()) ==
|
||||
sa.inspect(self.Building)
|
||||
get_mapper(Building()) ==
|
||||
sa.inspect(Building)
|
||||
)
|
||||
|
||||
def test_mapper(self):
|
||||
def test_mapper(self, Building):
|
||||
assert (
|
||||
get_mapper(self.Building.__mapper__) ==
|
||||
sa.inspect(self.Building)
|
||||
get_mapper(Building.__mapper__) ==
|
||||
sa.inspect(Building)
|
||||
)
|
||||
|
||||
def test_class_alias(self):
|
||||
def test_class_alias(self, Building):
|
||||
assert (
|
||||
get_mapper(sa.orm.aliased(self.Building)) ==
|
||||
sa.inspect(self.Building)
|
||||
get_mapper(sa.orm.aliased(Building)) ==
|
||||
sa.inspect(Building)
|
||||
)
|
||||
|
||||
def test_instrumented_attribute(self):
|
||||
def test_instrumented_attribute(self, Building):
|
||||
assert (
|
||||
get_mapper(self.Building.id) == sa.inspect(self.Building)
|
||||
get_mapper(Building.id) == sa.inspect(Building)
|
||||
)
|
||||
|
||||
def test_table_alias(self):
|
||||
alias = sa.orm.aliased(self.Building.__table__)
|
||||
def test_table_alias(self, Building):
|
||||
alias = sa.orm.aliased(Building.__table__)
|
||||
assert (
|
||||
get_mapper(alias) ==
|
||||
sa.inspect(self.Building)
|
||||
sa.inspect(Building)
|
||||
)
|
||||
|
||||
def test_column(self):
|
||||
def test_column(self, Building):
|
||||
assert (
|
||||
get_mapper(self.Building.__table__.c.id) ==
|
||||
sa.inspect(self.Building)
|
||||
get_mapper(Building.__table__.c.id) ==
|
||||
sa.inspect(Building)
|
||||
)
|
||||
|
||||
def test_column_of_an_alias(self):
|
||||
def test_column_of_an_alias(self, Building):
|
||||
assert (
|
||||
get_mapper(sa.orm.aliased(self.Building.__table__).c.id) ==
|
||||
sa.inspect(self.Building)
|
||||
get_mapper(sa.orm.aliased(Building.__table__).c.id) ==
|
||||
sa.inspect(Building)
|
||||
)
|
||||
|
||||
|
||||
class TestGetMapperWithQueryEntities(TestCase):
|
||||
def create_models(self):
|
||||
class Building(self.Base):
|
||||
class TestGetMapperWithQueryEntities(object):
|
||||
|
||||
@pytest.fixture
|
||||
def Building(self, Base):
|
||||
class Building(Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
return Building
|
||||
|
||||
self.Building = Building
|
||||
@pytest.fixture
|
||||
def init_models(self, Building):
|
||||
pass
|
||||
|
||||
def test_mapper_entity_with_mapper(self):
|
||||
entity = self.session.query(self.Building.__mapper__)._entities[0]
|
||||
def test_mapper_entity_with_mapper(self, session, Building):
|
||||
entity = session.query(Building.__mapper__)._entities[0]
|
||||
assert (
|
||||
get_mapper(entity) ==
|
||||
sa.inspect(self.Building)
|
||||
sa.inspect(Building)
|
||||
)
|
||||
|
||||
def test_mapper_entity_with_class(self):
|
||||
entity = self.session.query(self.Building)._entities[0]
|
||||
def test_mapper_entity_with_class(self, session, Building):
|
||||
entity = session.query(Building)._entities[0]
|
||||
assert (
|
||||
get_mapper(entity) ==
|
||||
sa.inspect(self.Building)
|
||||
sa.inspect(Building)
|
||||
)
|
||||
|
||||
def test_column_entity(self):
|
||||
query = self.session.query(self.Building.id)
|
||||
assert get_mapper(query._entities[0]) == sa.inspect(self.Building)
|
||||
def test_column_entity(self, session, Building):
|
||||
query = session.query(Building.id)
|
||||
assert get_mapper(query._entities[0]) == sa.inspect(Building)
|
||||
|
||||
|
||||
class TestGetMapperWithMultipleMappersFound(object):
|
||||
def setup_method(self, method):
|
||||
Base = declarative_base()
|
||||
|
||||
@pytest.fixture
|
||||
def Building(self, Base):
|
||||
class Building(Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
@ -106,29 +108,30 @@ class TestGetMapperWithMultipleMappersFound(object):
|
|||
class BigBuilding(Building):
|
||||
pass
|
||||
|
||||
self.Building = Building
|
||||
self.BigBuilding = BigBuilding
|
||||
return Building
|
||||
|
||||
def test_table(self):
|
||||
with raises(ValueError):
|
||||
get_mapper(self.Building.__table__)
|
||||
def test_table(self, Building):
|
||||
with pytest.raises(ValueError):
|
||||
get_mapper(Building.__table__)
|
||||
|
||||
def test_table_alias(self):
|
||||
alias = sa.orm.aliased(self.Building.__table__)
|
||||
with raises(ValueError):
|
||||
def test_table_alias(self, Building):
|
||||
alias = sa.orm.aliased(Building.__table__)
|
||||
with pytest.raises(ValueError):
|
||||
get_mapper(alias)
|
||||
|
||||
|
||||
class TestGetMapperForTableWithoutMapper(object):
|
||||
def setup_method(self, method):
|
||||
|
||||
@pytest.fixture
|
||||
def building(self):
|
||||
metadata = sa.MetaData()
|
||||
self.building = sa.Table('building', metadata)
|
||||
return sa.Table('building', metadata)
|
||||
|
||||
def test_table(self):
|
||||
with raises(ValueError):
|
||||
get_mapper(self.building)
|
||||
def test_table(self, building):
|
||||
with pytest.raises(ValueError):
|
||||
get_mapper(building)
|
||||
|
||||
def test_table_alias(self):
|
||||
alias = sa.orm.aliased(self.building)
|
||||
with raises(ValueError):
|
||||
def test_table_alias(self, building):
|
||||
alias = sa.orm.aliased(building)
|
||||
with pytest.raises(ValueError):
|
||||
get_mapper(alias)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import get_primary_keys
|
||||
|
||||
|
@ -9,40 +9,40 @@ except ImportError:
|
|||
from ordereddict import OrderedDict
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Building(Base):
|
||||
class Building(Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
name = sa.Column('_name', sa.Unicode(255))
|
||||
return Building
|
||||
|
||||
|
||||
class TestGetPrimaryKeys(object):
|
||||
def setup_method(self, method):
|
||||
Base = declarative_base()
|
||||
|
||||
class Building(Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
name = sa.Column('_name', sa.Unicode(255))
|
||||
|
||||
self.Building = Building
|
||||
|
||||
def test_table(self):
|
||||
assert get_primary_keys(self.Building.__table__) == OrderedDict({
|
||||
'_id': self.Building.__table__.c._id
|
||||
def test_table(self, Building):
|
||||
assert get_primary_keys(Building.__table__) == OrderedDict({
|
||||
'_id': Building.__table__.c._id
|
||||
})
|
||||
|
||||
def test_declarative_class(self):
|
||||
assert get_primary_keys(self.Building) == OrderedDict({
|
||||
'id': self.Building.__table__.c._id
|
||||
def test_declarative_class(self, Building):
|
||||
assert get_primary_keys(Building) == OrderedDict({
|
||||
'id': Building.__table__.c._id
|
||||
})
|
||||
|
||||
def test_declarative_object(self):
|
||||
assert get_primary_keys(self.Building()) == OrderedDict({
|
||||
'id': self.Building.__table__.c._id
|
||||
def test_declarative_object(self, Building):
|
||||
assert get_primary_keys(Building()) == OrderedDict({
|
||||
'id': Building.__table__.c._id
|
||||
})
|
||||
|
||||
def test_class_alias(self):
|
||||
alias = sa.orm.aliased(self.Building)
|
||||
def test_class_alias(self, Building):
|
||||
alias = sa.orm.aliased(Building)
|
||||
assert get_primary_keys(alias) == OrderedDict({
|
||||
'id': self.Building.__table__.c._id
|
||||
'id': Building.__table__.c._id
|
||||
})
|
||||
|
||||
def test_table_alias(self):
|
||||
alias = sa.orm.aliased(self.Building.__table__)
|
||||
def test_table_alias(self, Building):
|
||||
alias = sa.orm.aliased(Building.__table__)
|
||||
assert get_primary_keys(alias) == OrderedDict({
|
||||
'_id': alias.c._id
|
||||
})
|
||||
|
|
|
@ -1,102 +1,115 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import get_query_entities
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestGetQueryEntities(TestCase):
|
||||
def create_models(self):
|
||||
class TextItem(self.Base):
|
||||
__tablename__ = 'text_item'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
@pytest.fixture
|
||||
def TextItem(Base):
|
||||
class TextItem(Base):
|
||||
__tablename__ = 'text_item'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
type = sa.Column(sa.Unicode(255))
|
||||
type = sa.Column(sa.Unicode(255))
|
||||
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': type,
|
||||
}
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': type,
|
||||
}
|
||||
return TextItem
|
||||
|
||||
class Article(TextItem):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
|
||||
)
|
||||
category = sa.Column(sa.Unicode(255))
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': u'article'
|
||||
}
|
||||
|
||||
class BlogPost(TextItem):
|
||||
__tablename__ = 'blog_post'
|
||||
id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
|
||||
)
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': u'blog_post'
|
||||
}
|
||||
@pytest.fixture
|
||||
def Article(TextItem):
|
||||
class Article(TextItem):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
|
||||
)
|
||||
category = sa.Column(sa.Unicode(255))
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': u'article'
|
||||
}
|
||||
return Article
|
||||
|
||||
self.TextItem = TextItem
|
||||
self.Article = Article
|
||||
self.BlogPost = BlogPost
|
||||
|
||||
def test_mapper(self):
|
||||
query = self.session.query(sa.inspect(self.TextItem))
|
||||
assert get_query_entities(query) == [self.TextItem]
|
||||
@pytest.fixture
|
||||
def BlogPost(TextItem):
|
||||
class BlogPost(TextItem):
|
||||
__tablename__ = 'blog_post'
|
||||
id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
|
||||
)
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': u'blog_post'
|
||||
}
|
||||
return BlogPost
|
||||
|
||||
def test_entity(self):
|
||||
query = self.session.query(self.TextItem)
|
||||
assert get_query_entities(query) == [self.TextItem]
|
||||
|
||||
def test_instrumented_attribute(self):
|
||||
query = self.session.query(self.TextItem.id)
|
||||
assert get_query_entities(query) == [self.TextItem]
|
||||
@pytest.fixture
|
||||
def init_models(TextItem, Article, BlogPost):
|
||||
pass
|
||||
|
||||
def test_column(self):
|
||||
query = self.session.query(self.TextItem.__table__.c.id)
|
||||
assert get_query_entities(query) == [self.TextItem.__table__]
|
||||
|
||||
def test_aliased_selectable(self):
|
||||
selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
|
||||
query = self.session.query(selectable)
|
||||
class TestGetQueryEntities(object):
|
||||
|
||||
def test_mapper(self, session, TextItem):
|
||||
query = session.query(sa.inspect(TextItem))
|
||||
assert get_query_entities(query) == [TextItem]
|
||||
|
||||
def test_entity(self, session, TextItem):
|
||||
query = session.query(TextItem)
|
||||
assert get_query_entities(query) == [TextItem]
|
||||
|
||||
def test_instrumented_attribute(self, session, TextItem):
|
||||
query = session.query(TextItem.id)
|
||||
assert get_query_entities(query) == [TextItem]
|
||||
|
||||
def test_column(self, session, TextItem):
|
||||
query = session.query(TextItem.__table__.c.id)
|
||||
assert get_query_entities(query) == [TextItem.__table__]
|
||||
|
||||
def test_aliased_selectable(self, session, TextItem, BlogPost):
|
||||
selectable = sa.orm.with_polymorphic(TextItem, [BlogPost])
|
||||
query = session.query(selectable)
|
||||
assert get_query_entities(query) == [selectable]
|
||||
|
||||
def test_joined_entity(self):
|
||||
query = self.session.query(self.TextItem).join(
|
||||
self.BlogPost, self.BlogPost.id == self.TextItem.id
|
||||
def test_joined_entity(self, session, TextItem, BlogPost):
|
||||
query = session.query(TextItem).join(
|
||||
BlogPost, BlogPost.id == TextItem.id
|
||||
)
|
||||
assert get_query_entities(query) == [
|
||||
self.TextItem, sa.inspect(self.BlogPost)
|
||||
TextItem, sa.inspect(BlogPost)
|
||||
]
|
||||
|
||||
def test_joined_aliased_entity(self):
|
||||
alias = sa.orm.aliased(self.BlogPost)
|
||||
def test_joined_aliased_entity(self, session, TextItem, BlogPost):
|
||||
alias = sa.orm.aliased(BlogPost)
|
||||
|
||||
query = self.session.query(self.TextItem).join(
|
||||
alias, alias.id == self.TextItem.id
|
||||
query = session.query(TextItem).join(
|
||||
alias, alias.id == TextItem.id
|
||||
)
|
||||
assert get_query_entities(query) == [self.TextItem, alias]
|
||||
assert get_query_entities(query) == [TextItem, alias]
|
||||
|
||||
def test_column_entity_with_label(self):
|
||||
query = self.session.query(self.Article.id.label('id'))
|
||||
assert get_query_entities(query) == [self.Article]
|
||||
def test_column_entity_with_label(self, session, Article):
|
||||
query = session.query(Article.id.label('id'))
|
||||
assert get_query_entities(query) == [Article]
|
||||
|
||||
def test_with_subquery(self):
|
||||
def test_with_subquery(self, session, Article):
|
||||
number_of_articles = (
|
||||
sa.select(
|
||||
[sa.func.count(self.Article.id)],
|
||||
[sa.func.count(Article.id)],
|
||||
)
|
||||
.select_from(
|
||||
self.Article.__table__
|
||||
Article.__table__
|
||||
)
|
||||
).label('number_of_articles')
|
||||
|
||||
query = self.session.query(self.Article, number_of_articles)
|
||||
query = session.query(Article, number_of_articles)
|
||||
assert get_query_entities(query) == [
|
||||
self.Article,
|
||||
Article,
|
||||
number_of_articles
|
||||
]
|
||||
|
||||
def test_aliased_entity(self):
|
||||
alias = sa.orm.aliased(self.Article)
|
||||
query = self.session.query(alias)
|
||||
def test_aliased_entity(self, session, Article):
|
||||
alias = sa.orm.aliased(Article)
|
||||
query = session.query(alias)
|
||||
assert get_query_entities(query) == [alias]
|
||||
|
|
|
@ -1,17 +1,22 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import get_referencing_foreign_keys
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestGetReferencingFksWithCompositeKeys(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
class TestGetReferencingFksWithCompositeKeys(object):
|
||||
|
||||
@pytest.fixture
|
||||
def User(self, Base):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
first_name = sa.Column(sa.Unicode(255), primary_key=True)
|
||||
last_name = sa.Column(sa.Unicode(255), primary_key=True)
|
||||
return User
|
||||
|
||||
class Article(self.Base):
|
||||
@pytest.fixture
|
||||
def Article(self, Base, User):
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
author_first_name = sa.Column(sa.Unicode(255))
|
||||
|
@ -22,22 +27,26 @@ class TestGetReferencingFksWithCompositeKeys(TestCase):
|
|||
[User.first_name, User.last_name]
|
||||
),
|
||||
)
|
||||
return Article
|
||||
|
||||
self.User = User
|
||||
self.Article = Article
|
||||
@pytest.fixture
|
||||
def init_models(self, User, Article):
|
||||
pass
|
||||
|
||||
def test_with_declarative_class(self):
|
||||
fks = get_referencing_foreign_keys(self.User)
|
||||
assert self.Article.__table__.foreign_keys == fks
|
||||
def test_with_declarative_class(self, User, Article):
|
||||
fks = get_referencing_foreign_keys(User)
|
||||
assert Article.__table__.foreign_keys == fks
|
||||
|
||||
def test_with_table(self):
|
||||
fks = get_referencing_foreign_keys(self.User.__table__)
|
||||
assert self.Article.__table__.foreign_keys == fks
|
||||
def test_with_table(self, User, Article):
|
||||
fks = get_referencing_foreign_keys(User.__table__)
|
||||
assert Article.__table__.foreign_keys == fks
|
||||
|
||||
|
||||
class TestGetReferencingFksWithInheritance(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
class TestGetReferencingFksWithInheritance(object):
|
||||
|
||||
@pytest.fixture
|
||||
def User(self, Base):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
type = sa.Column(sa.Unicode)
|
||||
|
@ -47,14 +56,20 @@ class TestGetReferencingFksWithInheritance(TestCase):
|
|||
__mapper_args__ = {
|
||||
'polymorphic_on': 'type'
|
||||
}
|
||||
return User
|
||||
|
||||
@pytest.fixture
|
||||
def Admin(self, User):
|
||||
class Admin(User):
|
||||
__tablename__ = 'admin'
|
||||
id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(User.id), primary_key=True
|
||||
)
|
||||
return Admin
|
||||
|
||||
class TextItem(self.Base):
|
||||
@pytest.fixture
|
||||
def TextItem(self, Base, User):
|
||||
class TextItem(Base):
|
||||
__tablename__ = 'textitem'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
type = sa.Column(sa.Unicode)
|
||||
|
@ -62,7 +77,10 @@ class TestGetReferencingFksWithInheritance(TestCase):
|
|||
__mapper_args__ = {
|
||||
'polymorphic_on': 'type'
|
||||
}
|
||||
return TextItem
|
||||
|
||||
@pytest.fixture
|
||||
def Article(self, TextItem):
|
||||
class Article(TextItem):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(
|
||||
|
@ -71,16 +89,16 @@ class TestGetReferencingFksWithInheritance(TestCase):
|
|||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'article'
|
||||
}
|
||||
return Article
|
||||
|
||||
self.Admin = Admin
|
||||
self.User = User
|
||||
self.Article = Article
|
||||
self.TextItem = TextItem
|
||||
@pytest.fixture
|
||||
def init_models(self, User, Admin, TextItem, Article):
|
||||
pass
|
||||
|
||||
def test_with_declarative_class(self):
|
||||
fks = get_referencing_foreign_keys(self.Admin)
|
||||
assert self.TextItem.__table__.foreign_keys == fks
|
||||
def test_with_declarative_class(self, Admin, TextItem):
|
||||
fks = get_referencing_foreign_keys(Admin)
|
||||
assert TextItem.__table__.foreign_keys == fks
|
||||
|
||||
def test_with_table(self):
|
||||
fks = get_referencing_foreign_keys(self.Admin.__table__)
|
||||
def test_with_table(self, Admin):
|
||||
fks = get_referencing_foreign_keys(Admin.__table__)
|
||||
assert fks == set([])
|
||||
|
|
|
@ -1,76 +1,86 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import get_tables
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestGetTables(TestCase):
|
||||
def create_models(self):
|
||||
class TextItem(self.Base):
|
||||
__tablename__ = 'text_item'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
type = sa.Column(sa.Unicode(255))
|
||||
@pytest.fixture
|
||||
def TextItem(Base):
|
||||
class TextItem(Base):
|
||||
__tablename__ = 'text_item'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
type = sa.Column(sa.Unicode(255))
|
||||
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': type,
|
||||
'with_polymorphic': '*'
|
||||
}
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': type,
|
||||
'with_polymorphic': '*'
|
||||
}
|
||||
return TextItem
|
||||
|
||||
class Article(TextItem):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
|
||||
)
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': u'article'
|
||||
}
|
||||
|
||||
self.TextItem = TextItem
|
||||
self.Article = Article
|
||||
@pytest.fixture
|
||||
def Article(TextItem):
|
||||
class Article(TextItem):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
|
||||
)
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': u'article'
|
||||
}
|
||||
return Article
|
||||
|
||||
def test_child_class_using_join_table_inheritance(self):
|
||||
assert get_tables(self.Article) == [
|
||||
self.TextItem.__table__,
|
||||
self.Article.__table__
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(TextItem, Article):
|
||||
pass
|
||||
|
||||
|
||||
class TestGetTables(object):
|
||||
|
||||
def test_child_class_using_join_table_inheritance(self, TextItem, Article):
|
||||
assert get_tables(Article) == [
|
||||
TextItem.__table__,
|
||||
Article.__table__
|
||||
]
|
||||
|
||||
def test_entity_using_with_polymorphic(self):
|
||||
assert get_tables(self.TextItem) == [
|
||||
self.TextItem.__table__,
|
||||
self.Article.__table__
|
||||
def test_entity_using_with_polymorphic(self, TextItem, Article):
|
||||
assert get_tables(TextItem) == [
|
||||
TextItem.__table__,
|
||||
Article.__table__
|
||||
]
|
||||
|
||||
def test_instrumented_attribute(self):
|
||||
assert get_tables(self.TextItem.name) == [
|
||||
self.TextItem.__table__,
|
||||
def test_instrumented_attribute(self, TextItem):
|
||||
assert get_tables(TextItem.name) == [
|
||||
TextItem.__table__,
|
||||
]
|
||||
|
||||
def test_polymorphic_instrumented_attribute(self):
|
||||
assert get_tables(self.Article.id) == [
|
||||
self.TextItem.__table__,
|
||||
self.Article.__table__
|
||||
def test_polymorphic_instrumented_attribute(self, TextItem, Article):
|
||||
assert get_tables(Article.id) == [
|
||||
TextItem.__table__,
|
||||
Article.__table__
|
||||
]
|
||||
|
||||
def test_column(self):
|
||||
assert get_tables(self.Article.__table__.c.id) == [
|
||||
self.Article.__table__
|
||||
def test_column(self, Article):
|
||||
assert get_tables(Article.__table__.c.id) == [
|
||||
Article.__table__
|
||||
]
|
||||
|
||||
def test_mapper_entity_with_class(self):
|
||||
query = self.session.query(self.Article)
|
||||
def test_mapper_entity_with_class(self, session, TextItem, Article):
|
||||
query = session.query(Article)
|
||||
assert get_tables(query._entities[0]) == [
|
||||
self.TextItem.__table__, self.Article.__table__
|
||||
TextItem.__table__, Article.__table__
|
||||
]
|
||||
|
||||
def test_mapper_entity_with_mapper(self):
|
||||
query = self.session.query(sa.inspect(self.Article))
|
||||
def test_mapper_entity_with_mapper(self, session, TextItem, Article):
|
||||
query = session.query(sa.inspect(Article))
|
||||
assert get_tables(query._entities[0]) == [
|
||||
self.TextItem.__table__, self.Article.__table__
|
||||
TextItem.__table__, Article.__table__
|
||||
]
|
||||
|
||||
def test_column_entity(self):
|
||||
query = self.session.query(self.Article.id)
|
||||
def test_column_entity(self, session, TextItem, Article):
|
||||
query = session.query(Article.id)
|
||||
assert get_tables(query._entities[0]) == [
|
||||
self.TextItem.__table__, self.Article.__table__
|
||||
TextItem.__table__, Article.__table__
|
||||
]
|
||||
|
|
|
@ -1,46 +1,49 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import get_type
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def User(Base):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
return User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Article(Base, User):
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
|
||||
author = sa.orm.relationship(User)
|
||||
|
||||
some_property = sa.orm.column_property(
|
||||
sa.func.coalesce(id, 1)
|
||||
)
|
||||
return Article
|
||||
|
||||
|
||||
class TestGetType(object):
|
||||
def setup_method(self, method):
|
||||
Base = declarative_base()
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
def test_instrumented_attribute(self, Article):
|
||||
assert isinstance(get_type(Article.id), sa.Integer)
|
||||
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
def test_column_property(self, Article):
|
||||
assert isinstance(get_type(Article.id.property), sa.Integer)
|
||||
|
||||
author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
|
||||
author = sa.orm.relationship(User)
|
||||
def test_column(self, Article):
|
||||
assert isinstance(get_type(Article.__table__.c.id), sa.Integer)
|
||||
|
||||
some_property = sa.orm.column_property(
|
||||
sa.func.coalesce(id, 1)
|
||||
)
|
||||
def test_calculated_column_property(self, Article):
|
||||
assert isinstance(get_type(Article.some_property), sa.Integer)
|
||||
|
||||
self.Article = Article
|
||||
self.User = User
|
||||
def test_relationship_property(self, Article, User):
|
||||
assert get_type(Article.author) == User
|
||||
|
||||
def test_instrumented_attribute(self):
|
||||
assert isinstance(get_type(self.Article.id), sa.Integer)
|
||||
|
||||
def test_column_property(self):
|
||||
assert isinstance(get_type(self.Article.id.property), sa.Integer)
|
||||
|
||||
def test_column(self):
|
||||
assert isinstance(get_type(self.Article.__table__.c.id), sa.Integer)
|
||||
|
||||
def test_calculated_column_property(self):
|
||||
assert isinstance(get_type(self.Article.some_property), sa.Integer)
|
||||
|
||||
def test_relationship_property(self):
|
||||
assert get_type(self.Article.author) == self.User
|
||||
|
||||
def test_scalar_select(self):
|
||||
query = sa.select([self.Article.id]).as_scalar()
|
||||
def test_scalar_select(self, Article):
|
||||
query = sa.select([Article.id]).as_scalar()
|
||||
assert isinstance(get_type(query), sa.Integer)
|
||||
|
|
|
@ -1,72 +1,94 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.functions import getdotattr
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestGetDotAttr(TestCase):
|
||||
def create_models(self):
|
||||
class Document(self.Base):
|
||||
__tablename__ = 'document'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
@pytest.fixture
|
||||
def Document(Base):
|
||||
class Document(Base):
|
||||
__tablename__ = 'document'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
return Document
|
||||
|
||||
class Section(self.Base):
|
||||
__tablename__ = 'section'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
document_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Document.id)
|
||||
)
|
||||
@pytest.fixture
|
||||
def Section(Base, Document):
|
||||
class Section(Base):
|
||||
__tablename__ = 'section'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
document = sa.orm.relationship(Document, backref='sections')
|
||||
document_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Document.id)
|
||||
)
|
||||
|
||||
class SubSection(self.Base):
|
||||
__tablename__ = 'subsection'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
document = sa.orm.relationship(Document, backref='sections')
|
||||
return Section
|
||||
|
||||
section_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Section.id)
|
||||
)
|
||||
|
||||
section = sa.orm.relationship(Section, backref='subsections')
|
||||
@pytest.fixture
|
||||
def SubSection(Base, Section):
|
||||
class SubSection(Base):
|
||||
__tablename__ = 'subsection'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class SubSubSection(self.Base):
|
||||
__tablename__ = 'subsubsection'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
section_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Section.id)
|
||||
)
|
||||
|
||||
subsection_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(SubSection.id)
|
||||
)
|
||||
section = sa.orm.relationship(Section, backref='subsections')
|
||||
return SubSection
|
||||
|
||||
subsection = sa.orm.relationship(
|
||||
SubSection, backref='subsubsections'
|
||||
)
|
||||
|
||||
self.Document = Document
|
||||
self.Section = Section
|
||||
self.SubSection = SubSection
|
||||
self.SubSubSection = SubSubSection
|
||||
@pytest.fixture
|
||||
def SubSubSection(Base, SubSection):
|
||||
class SubSubSection(Base):
|
||||
__tablename__ = 'subsubsection'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
def test_simple_objects(self):
|
||||
document = self.Document(name=u'some document')
|
||||
section = self.Section(document=document)
|
||||
subsection = self.SubSection(section=section)
|
||||
subsection_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(SubSection.id)
|
||||
)
|
||||
|
||||
subsection = sa.orm.relationship(
|
||||
SubSection, backref='subsubsections'
|
||||
)
|
||||
return SubSubSection
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(Document, Section, SubSection, SubSubSection):
|
||||
pass
|
||||
|
||||
|
||||
class TestGetDotAttr(object):
|
||||
|
||||
def test_simple_objects(self, Document, Section, SubSection):
|
||||
document = Document(name=u'some document')
|
||||
section = Section(document=document)
|
||||
subsection = SubSection(section=section)
|
||||
|
||||
assert getdotattr(
|
||||
subsection,
|
||||
'section.document.name'
|
||||
) == u'some document'
|
||||
|
||||
def test_with_instrumented_lists(self):
|
||||
document = self.Document(name=u'some document')
|
||||
section = self.Section(document=document)
|
||||
subsection = self.SubSection(section=section)
|
||||
subsubsection = self.SubSubSection(subsection=subsection)
|
||||
def test_with_instrumented_lists(
|
||||
self,
|
||||
Document,
|
||||
Section,
|
||||
SubSection,
|
||||
SubSubSection
|
||||
):
|
||||
document = Document(name=u'some document')
|
||||
section = Section(document=document)
|
||||
subsection = SubSection(section=section)
|
||||
subsubsection = SubSubSection(subsection=subsection)
|
||||
|
||||
assert getdotattr(document, 'sections') == [section]
|
||||
assert getdotattr(document, 'sections.subsections') == [
|
||||
|
@ -76,10 +98,10 @@ class TestGetDotAttr(TestCase):
|
|||
subsubsection
|
||||
]
|
||||
|
||||
def test_class_paths(self):
|
||||
assert getdotattr(self.Section, 'document') is self.Section.document
|
||||
def test_class_paths(self, Document, Section, SubSection):
|
||||
assert getdotattr(Section, 'document') is Section.document
|
||||
assert (
|
||||
getdotattr(self.SubSection, 'section.document') is
|
||||
self.Section.document
|
||||
getdotattr(SubSection, 'section.document') is
|
||||
Section.document
|
||||
)
|
||||
assert getdotattr(self.Section, 'document.name') is self.Document.name
|
||||
assert getdotattr(Section, 'document.name') is Document.name
|
||||
|
|
|
@ -1,47 +1,44 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import has_changes
|
||||
|
||||
|
||||
class HasChangesTestCase(object):
|
||||
def setup_method(self, method):
|
||||
Base = declarative_base()
|
||||
|
||||
class Article(Base):
|
||||
__tablename__ = 'article_translation'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
title = sa.Column(sa.String(100))
|
||||
|
||||
self.Article = Article
|
||||
@pytest.fixture
|
||||
def Article(Base):
|
||||
class Article(Base):
|
||||
__tablename__ = 'article_translation'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
title = sa.Column(sa.String(100))
|
||||
return Article
|
||||
|
||||
|
||||
class TestHasChangesWithStringAttr(HasChangesTestCase):
|
||||
def test_without_changed_attr(self):
|
||||
article = self.Article()
|
||||
class TestHasChangesWithStringAttr(object):
|
||||
def test_without_changed_attr(self, Article):
|
||||
article = Article()
|
||||
assert not has_changes(article, 'title')
|
||||
|
||||
def test_with_changed_attr(self):
|
||||
article = self.Article(title='Some title')
|
||||
def test_with_changed_attr(self, Article):
|
||||
article = Article(title='Some title')
|
||||
assert has_changes(article, 'title')
|
||||
|
||||
|
||||
class TestHasChangesWithMultipleAttrs(HasChangesTestCase):
|
||||
def test_without_changed_attr(self):
|
||||
article = self.Article()
|
||||
class TestHasChangesWithMultipleAttrs(object):
|
||||
def test_without_changed_attr(self, Article):
|
||||
article = Article()
|
||||
assert not has_changes(article, ['title'])
|
||||
|
||||
def test_with_changed_attr(self):
|
||||
article = self.Article(title='Some title')
|
||||
def test_with_changed_attr(self, Article):
|
||||
article = Article(title='Some title')
|
||||
assert has_changes(article, ['title', 'id'])
|
||||
|
||||
|
||||
class TestHasChangesWithExclude(HasChangesTestCase):
|
||||
def test_without_changed_attr(self):
|
||||
article = self.Article()
|
||||
class TestHasChangesWithExclude(object):
|
||||
def test_without_changed_attr(self, Article):
|
||||
article = Article()
|
||||
assert not has_changes(article, exclude=['id'])
|
||||
|
||||
def test_with_changed_attr(self):
|
||||
article = self.Article(title='Some title')
|
||||
def test_with_changed_attr(self, Article):
|
||||
article = Article(title='Some title')
|
||||
assert has_changes(article, exclude=['id'])
|
||||
assert not has_changes(article, exclude=['title'])
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from pytest import raises
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import get_fk_constraint_for_columns, has_index
|
||||
|
||||
|
||||
class TestHasIndex(object):
|
||||
def setup_method(self, method):
|
||||
Base = declarative_base()
|
||||
|
||||
@pytest.fixture
|
||||
def table(self, Base):
|
||||
class ArticleTranslation(Base):
|
||||
__tablename__ = 'article_translation'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
@ -21,24 +20,23 @@ class TestHasIndex(object):
|
|||
__table_args__ = (
|
||||
sa.Index('my_index', is_deleted, is_archived),
|
||||
)
|
||||
return ArticleTranslation.__table__
|
||||
|
||||
self.table = ArticleTranslation.__table__
|
||||
|
||||
def test_column_that_belongs_to_an_alias(self):
|
||||
alias = sa.orm.aliased(self.table)
|
||||
with raises(TypeError):
|
||||
def test_column_that_belongs_to_an_alias(self, table):
|
||||
alias = sa.orm.aliased(table)
|
||||
with pytest.raises(TypeError):
|
||||
assert has_index(alias.c.id)
|
||||
|
||||
def test_compound_primary_key(self):
|
||||
assert has_index(self.table.c.id)
|
||||
assert not has_index(self.table.c.locale)
|
||||
def test_compound_primary_key(self, table):
|
||||
assert has_index(table.c.id)
|
||||
assert not has_index(table.c.locale)
|
||||
|
||||
def test_single_column_index(self):
|
||||
assert has_index(self.table.c.is_published)
|
||||
def test_single_column_index(self, table):
|
||||
assert has_index(table.c.is_published)
|
||||
|
||||
def test_compound_column_index(self):
|
||||
assert has_index(self.table.c.is_deleted)
|
||||
assert not has_index(self.table.c.is_archived)
|
||||
def test_compound_column_index(self, table):
|
||||
assert has_index(table.c.is_deleted)
|
||||
assert not has_index(table.c.is_archived)
|
||||
|
||||
def test_table_without_primary_key(self):
|
||||
article = sa.Table(
|
||||
|
@ -50,8 +48,7 @@ class TestHasIndex(object):
|
|||
|
||||
|
||||
class TestHasIndexWithFKConstraint(object):
|
||||
def test_composite_fk_without_index(self):
|
||||
Base = declarative_base()
|
||||
def test_composite_fk_without_index(self, Base):
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
|
@ -78,8 +75,7 @@ class TestHasIndexWithFKConstraint(object):
|
|||
)
|
||||
assert not has_index(constraint)
|
||||
|
||||
def test_composite_fk_with_index(self):
|
||||
Base = declarative_base()
|
||||
def test_composite_fk_with_index(self, Base):
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
|
@ -109,8 +105,7 @@ class TestHasIndexWithFKConstraint(object):
|
|||
)
|
||||
assert has_index(constraint)
|
||||
|
||||
def test_composite_fk_with_partial_index_match(self):
|
||||
Base = declarative_base()
|
||||
def test_composite_fk_with_partial_index_match(self, Base):
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
|
|
|
@ -1,18 +1,20 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from pytest import raises
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import get_fk_constraint_for_columns, has_unique_index
|
||||
|
||||
|
||||
class TestHasUniqueIndex(object):
|
||||
def setup_method(self, method):
|
||||
Base = declarative_base()
|
||||
|
||||
@pytest.fixture
|
||||
def articles(self, Base):
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
return Article.__table__
|
||||
|
||||
@pytest.fixture
|
||||
def article_translations(self, Base):
|
||||
class ArticleTranslation(Base):
|
||||
__tablename__ = 'article_translation'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
@ -26,35 +28,33 @@ class TestHasUniqueIndex(object):
|
|||
sa.Index('my_index', is_archived, is_published, unique=True),
|
||||
)
|
||||
|
||||
self.articles = Article.__table__
|
||||
self.article_translations = ArticleTranslation.__table__
|
||||
return ArticleTranslation.__table__
|
||||
|
||||
def test_primary_key(self):
|
||||
assert has_unique_index(self.articles.c.id)
|
||||
def test_primary_key(self, articles):
|
||||
assert has_unique_index(articles.c.id)
|
||||
|
||||
def test_column_of_aliased_table(self):
|
||||
alias = sa.orm.aliased(self.articles)
|
||||
with raises(TypeError):
|
||||
def test_column_of_aliased_table(self, articles):
|
||||
alias = sa.orm.aliased(articles)
|
||||
with pytest.raises(TypeError):
|
||||
assert has_unique_index(alias.c.id)
|
||||
|
||||
def test_unique_index(self):
|
||||
assert has_unique_index(self.article_translations.c.is_deleted)
|
||||
def test_unique_index(self, article_translations):
|
||||
assert has_unique_index(article_translations.c.is_deleted)
|
||||
|
||||
def test_compound_primary_key(self):
|
||||
assert not has_unique_index(self.article_translations.c.id)
|
||||
assert not has_unique_index(self.article_translations.c.locale)
|
||||
def test_compound_primary_key(self, article_translations):
|
||||
assert not has_unique_index(article_translations.c.id)
|
||||
assert not has_unique_index(article_translations.c.locale)
|
||||
|
||||
def test_single_column_index(self):
|
||||
assert not has_unique_index(self.article_translations.c.is_published)
|
||||
def test_single_column_index(self, article_translations):
|
||||
assert not has_unique_index(article_translations.c.is_published)
|
||||
|
||||
def test_compound_column_unique_index(self):
|
||||
assert not has_unique_index(self.article_translations.c.is_published)
|
||||
assert not has_unique_index(self.article_translations.c.is_archived)
|
||||
def test_compound_column_unique_index(self, article_translations):
|
||||
assert not has_unique_index(article_translations.c.is_published)
|
||||
assert not has_unique_index(article_translations.c.is_archived)
|
||||
|
||||
|
||||
class TestHasUniqueIndexWithFKConstraint(object):
|
||||
def test_composite_fk_without_index(self):
|
||||
Base = declarative_base()
|
||||
def test_composite_fk_without_index(self, Base):
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
|
@ -81,8 +81,7 @@ class TestHasUniqueIndexWithFKConstraint(object):
|
|||
)
|
||||
assert not has_unique_index(constraint)
|
||||
|
||||
def test_composite_fk_with_index(self):
|
||||
Base = declarative_base()
|
||||
def test_composite_fk_with_index(self, Base):
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
|
@ -115,8 +114,7 @@ class TestHasUniqueIndexWithFKConstraint(object):
|
|||
)
|
||||
assert has_unique_index(constraint)
|
||||
|
||||
def test_composite_fk_with_partial_index_match(self):
|
||||
Base = declarative_base()
|
||||
def test_composite_fk_with_partial_index_match(self, Base):
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
|
|
|
@ -1,39 +1,46 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.functions import identity
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class IdentityTestCase(TestCase):
|
||||
def test_for_transient_class_without_id(self):
|
||||
assert identity(self.Building()) == (None, )
|
||||
class IdentityTestCase(object):
|
||||
|
||||
def test_for_transient_class_with_id(self):
|
||||
building = self.Building(name=u'Some building')
|
||||
self.session.add(building)
|
||||
self.session.flush()
|
||||
@pytest.fixture
|
||||
def init_models(self, Building):
|
||||
pass
|
||||
|
||||
def test_for_transient_class_without_id(self, Building):
|
||||
assert identity(Building()) == (None, )
|
||||
|
||||
def test_for_transient_class_with_id(self, session, Building):
|
||||
building = Building(name=u'Some building')
|
||||
session.add(building)
|
||||
session.flush()
|
||||
|
||||
assert identity(building) == (building.id, )
|
||||
|
||||
def test_identity_for_class(self):
|
||||
assert identity(self.Building) == (self.Building.id, )
|
||||
def test_identity_for_class(self, Building):
|
||||
assert identity(Building) == (Building.id, )
|
||||
|
||||
|
||||
class TestIdentity(IdentityTestCase):
|
||||
def create_models(self):
|
||||
class Building(self.Base):
|
||||
|
||||
@pytest.fixture
|
||||
def Building(self, Base):
|
||||
class Building(Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
self.Building = Building
|
||||
return Building
|
||||
|
||||
|
||||
class TestIdentityWithColumnAlias(IdentityTestCase):
|
||||
def create_models(self):
|
||||
class Building(self.Base):
|
||||
|
||||
@pytest.fixture
|
||||
def Building(self, Base):
|
||||
class Building(Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
self.Building = Building
|
||||
return Building
|
||||
|
|
|
@ -1,24 +1,24 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import is_loaded
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Article(Base):
|
||||
class Article(Base):
|
||||
__tablename__ = 'article_translation'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
title = sa.orm.deferred(sa.Column(sa.String(100)))
|
||||
return Article
|
||||
|
||||
|
||||
class TestIsLoaded(object):
|
||||
def setup_method(self, method):
|
||||
Base = declarative_base()
|
||||
|
||||
class Article(Base):
|
||||
__tablename__ = 'article_translation'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
title = sa.orm.deferred(sa.Column(sa.String(100)))
|
||||
|
||||
self.Article = Article
|
||||
|
||||
def test_loaded_property(self):
|
||||
article = self.Article(id=1)
|
||||
def test_loaded_property(self, Article):
|
||||
article = Article(id=1)
|
||||
assert is_loaded(article, 'id')
|
||||
|
||||
def test_unloaded_property(self):
|
||||
article = self.Article(id=4)
|
||||
def test_unloaded_property(self, Article):
|
||||
article = Article(id=4)
|
||||
assert not is_loaded(article, 'title')
|
||||
|
|
|
@ -2,11 +2,10 @@ import pytest
|
|||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import json_sql
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestJSONSQL(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestJSONSQL(object):
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
('value', 'result'),
|
||||
|
@ -27,7 +26,7 @@ class TestJSONSQL(TestCase):
|
|||
)
|
||||
)
|
||||
)
|
||||
def test_compiled_scalars(self, value, result):
|
||||
def test_compiled_scalars(self, connection, value, result):
|
||||
assert result == (
|
||||
self.connection.execute(sa.select([json_sql(value)])).fetchone()[0]
|
||||
connection.execute(sa.select([json_sql(value)])).fetchone()[0]
|
||||
)
|
||||
|
|
|
@ -1,90 +1,102 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.functions.sort_query import make_order_by_deterministic
|
||||
from tests import assert_contains, TestCase
|
||||
|
||||
from .. import assert_contains
|
||||
|
||||
|
||||
class TestMakeOrderByDeterministic(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode)
|
||||
email = sa.Column(sa.Unicode, unique=True)
|
||||
@pytest.fixture
|
||||
def Article(Base):
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
|
||||
author = sa.orm.relationship('User')
|
||||
return Article
|
||||
|
||||
email_lower = sa.orm.column_property(
|
||||
sa.func.lower(name)
|
||||
)
|
||||
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
|
||||
author = sa.orm.relationship(User)
|
||||
@pytest.fixture
|
||||
def User(Base, Article):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode)
|
||||
email = sa.Column(sa.Unicode, unique=True)
|
||||
|
||||
User.article_count = sa.orm.column_property(
|
||||
sa.select([sa.func.count()], from_obj=Article)
|
||||
.where(Article.author_id == User.id)
|
||||
.label('article_count')
|
||||
email_lower = sa.orm.column_property(
|
||||
sa.func.lower(name)
|
||||
)
|
||||
|
||||
self.User = User
|
||||
self.Article = Article
|
||||
User.article_count = sa.orm.column_property(
|
||||
sa.select([sa.func.count()], from_obj=Article)
|
||||
.where(Article.author_id == User.id)
|
||||
.label('article_count')
|
||||
)
|
||||
return User
|
||||
|
||||
def test_column_property(self):
|
||||
query = self.session.query(self.User).order_by(self.User.email_lower)
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(Article, User):
|
||||
pass
|
||||
|
||||
|
||||
class TestMakeOrderByDeterministic(object):
|
||||
|
||||
def test_column_property(self, session, User):
|
||||
query = session.query(User).order_by(User.email_lower)
|
||||
query = make_order_by_deterministic(query)
|
||||
assert_contains('lower("user".name), "user".id ASC', query)
|
||||
|
||||
def test_unique_column(self):
|
||||
query = self.session.query(self.User).order_by(self.User.email)
|
||||
def test_unique_column(self, session, User):
|
||||
query = session.query(User).order_by(User.email)
|
||||
query = make_order_by_deterministic(query)
|
||||
|
||||
assert str(query).endswith('ORDER BY "user".email')
|
||||
|
||||
def test_non_unique_column(self):
|
||||
query = self.session.query(self.User).order_by(self.User.name)
|
||||
def test_non_unique_column(self, session, User):
|
||||
query = session.query(User).order_by(User.name)
|
||||
query = make_order_by_deterministic(query)
|
||||
assert_contains('ORDER BY "user".name, "user".id ASC', query)
|
||||
|
||||
def test_descending_order_by(self):
|
||||
query = self.session.query(self.User).order_by(
|
||||
sa.desc(self.User.name)
|
||||
def test_descending_order_by(self, session, User):
|
||||
query = session.query(User).order_by(
|
||||
sa.desc(User.name)
|
||||
)
|
||||
query = make_order_by_deterministic(query)
|
||||
assert_contains('ORDER BY "user".name DESC, "user".id DESC', query)
|
||||
|
||||
def test_ascending_order_by(self):
|
||||
query = self.session.query(self.User).order_by(
|
||||
sa.asc(self.User.name)
|
||||
def test_ascending_order_by(self, session, User):
|
||||
query = session.query(User).order_by(
|
||||
sa.asc(User.name)
|
||||
)
|
||||
query = make_order_by_deterministic(query)
|
||||
assert_contains('ORDER BY "user".name ASC, "user".id ASC', query)
|
||||
|
||||
def test_string_order_by(self):
|
||||
query = self.session.query(self.User).order_by('name')
|
||||
def test_string_order_by(self, session, User):
|
||||
query = session.query(User).order_by('name')
|
||||
query = make_order_by_deterministic(query)
|
||||
assert_contains('ORDER BY "user".name, "user".id ASC', query)
|
||||
|
||||
def test_annotated_label(self):
|
||||
query = self.session.query(self.User).order_by(self.User.article_count)
|
||||
def test_annotated_label(self, session, User):
|
||||
query = session.query(User).order_by(User.article_count)
|
||||
query = make_order_by_deterministic(query)
|
||||
assert_contains('article_count, "user".id ASC', query)
|
||||
|
||||
def test_annotated_label_with_descending_order(self):
|
||||
query = self.session.query(self.User).order_by(
|
||||
sa.desc(self.User.article_count)
|
||||
def test_annotated_label_with_descending_order(self, session, User):
|
||||
query = session.query(User).order_by(
|
||||
sa.desc(User.article_count)
|
||||
)
|
||||
query = make_order_by_deterministic(query)
|
||||
assert_contains('ORDER BY article_count DESC, "user".id DESC', query)
|
||||
|
||||
def test_query_without_order_by(self):
|
||||
query = self.session.query(self.User)
|
||||
def test_query_without_order_by(self, session, User):
|
||||
query = session.query(User)
|
||||
query = make_order_by_deterministic(query)
|
||||
assert 'ORDER BY "user".id' in str(query)
|
||||
|
||||
def test_alias(self):
|
||||
alias = sa.orm.aliased(self.User.__table__)
|
||||
query = self.session.query(alias).order_by(alias.c.name)
|
||||
def test_alias(self, session, User):
|
||||
alias = sa.orm.aliased(User.__table__)
|
||||
query = session.query(alias).order_by(alias.c.name)
|
||||
query = make_order_by_deterministic(query)
|
||||
assert str(query).endswith('ORDER BY user_1.name, "user".id ASC')
|
||||
|
|
|
@ -1,20 +1,25 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import merge_references
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestMergeReferences(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
class TestMergeReferences(object):
|
||||
|
||||
@pytest.fixture
|
||||
def User(self, Base):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
def __repr__(self):
|
||||
return 'User(%r)' % self.name
|
||||
return User
|
||||
|
||||
class BlogPost(self.Base):
|
||||
@pytest.fixture
|
||||
def BlogPost(self, Base, User):
|
||||
class BlogPost(Base):
|
||||
__tablename__ = 'blog_post'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
title = sa.Column(sa.Unicode(255))
|
||||
|
@ -22,35 +27,37 @@ class TestMergeReferences(TestCase):
|
|||
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
|
||||
|
||||
author = sa.orm.relationship(User)
|
||||
return BlogPost
|
||||
|
||||
self.User = User
|
||||
self.BlogPost = BlogPost
|
||||
@pytest.fixture
|
||||
def init_models(self, User, BlogPost):
|
||||
pass
|
||||
|
||||
def test_updates_foreign_keys(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
post = self.BlogPost(title=u'Some title', author=john)
|
||||
post2 = self.BlogPost(title=u'Other title', author=jack)
|
||||
self.session.add(john)
|
||||
self.session.add(jack)
|
||||
self.session.add(post)
|
||||
self.session.add(post2)
|
||||
self.session.commit()
|
||||
def test_updates_foreign_keys(self, session, User, BlogPost):
|
||||
john = User(name=u'John')
|
||||
jack = User(name=u'Jack')
|
||||
post = BlogPost(title=u'Some title', author=john)
|
||||
post2 = BlogPost(title=u'Other title', author=jack)
|
||||
session.add(john)
|
||||
session.add(jack)
|
||||
session.add(post)
|
||||
session.add(post2)
|
||||
session.commit()
|
||||
merge_references(john, jack)
|
||||
self.session.commit()
|
||||
session.commit()
|
||||
assert post.author == jack
|
||||
assert post2.author == jack
|
||||
|
||||
def test_object_merging_whenever_possible(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
post = self.BlogPost(title=u'Some title', author=john)
|
||||
post2 = self.BlogPost(title=u'Other title', author=jack)
|
||||
self.session.add(john)
|
||||
self.session.add(jack)
|
||||
self.session.add(post)
|
||||
self.session.add(post2)
|
||||
self.session.commit()
|
||||
def test_object_merging_whenever_possible(self, session, User, BlogPost):
|
||||
john = User(name=u'John')
|
||||
jack = User(name=u'Jack')
|
||||
post = BlogPost(title=u'Some title', author=john)
|
||||
post2 = BlogPost(title=u'Other title', author=jack)
|
||||
session.add(john)
|
||||
session.add(jack)
|
||||
session.add(post)
|
||||
session.add(post2)
|
||||
session.commit()
|
||||
# Load the author for post
|
||||
assert post.author_id == john.id
|
||||
merge_references(john, jack)
|
||||
|
@ -58,18 +65,23 @@ class TestMergeReferences(TestCase):
|
|||
assert post2.author_id == jack.id
|
||||
|
||||
|
||||
class TestMergeReferencesWithManyToManyAssociations(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
class TestMergeReferencesWithManyToManyAssociations(object):
|
||||
|
||||
@pytest.fixture
|
||||
def User(self, Base):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
def __repr__(self):
|
||||
return 'User(%r)' % self.name
|
||||
return User
|
||||
|
||||
@pytest.fixture
|
||||
def Team(self, Base):
|
||||
team_member = sa.Table(
|
||||
'team_member', self.Base.metadata,
|
||||
'team_member', Base.metadata,
|
||||
sa.Column(
|
||||
'user_id', sa.Integer,
|
||||
sa.ForeignKey('user.id', ondelete='CASCADE'),
|
||||
|
@ -82,46 +94,56 @@ class TestMergeReferencesWithManyToManyAssociations(TestCase):
|
|||
)
|
||||
)
|
||||
|
||||
class Team(self.Base):
|
||||
class Team(Base):
|
||||
__tablename__ = 'team'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
members = sa.orm.relationship(
|
||||
User,
|
||||
'User',
|
||||
secondary=team_member,
|
||||
backref='teams'
|
||||
)
|
||||
return Team
|
||||
|
||||
self.User = User
|
||||
self.Team = Team
|
||||
@pytest.fixture
|
||||
def init_models(self, User, Team):
|
||||
pass
|
||||
|
||||
def test_supports_associations(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
team = self.Team(name=u'Team')
|
||||
def test_supports_associations(self, session, User, Team):
|
||||
john = User(name=u'John')
|
||||
jack = User(name=u'Jack')
|
||||
team = Team(name=u'Team')
|
||||
team.members.append(john)
|
||||
self.session.add(john)
|
||||
self.session.add(jack)
|
||||
self.session.commit()
|
||||
session.add(john)
|
||||
session.add(jack)
|
||||
session.commit()
|
||||
merge_references(john, jack)
|
||||
assert john not in team.members
|
||||
assert jack in team.members
|
||||
|
||||
|
||||
class TestMergeReferencesWithManyToManyAssociationObjects(TestCase):
|
||||
def create_models(self):
|
||||
class Team(self.Base):
|
||||
class TestMergeReferencesWithManyToManyAssociationObjects(object):
|
||||
|
||||
@pytest.fixture
|
||||
def Team(self, Base):
|
||||
class Team(Base):
|
||||
__tablename__ = 'team'
|
||||
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
return Team
|
||||
|
||||
class User(self.Base):
|
||||
@pytest.fixture
|
||||
def User(self, Base):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
return User
|
||||
|
||||
class TeamMember(self.Base):
|
||||
@pytest.fixture
|
||||
def TeamMember(self, Base, User, Team):
|
||||
class TeamMember(Base):
|
||||
__tablename__ = 'team_member'
|
||||
user_id = sa.Column(
|
||||
sa.Integer,
|
||||
|
@ -150,22 +172,23 @@ class TestMergeReferencesWithManyToManyAssociationObjects(TestCase):
|
|||
),
|
||||
primaryjoin=user_id == User.id,
|
||||
)
|
||||
return TeamMember
|
||||
|
||||
self.User = User
|
||||
self.TeamMember = TeamMember
|
||||
self.Team = Team
|
||||
@pytest.fixture
|
||||
def init_models(self, User, Team, TeamMember):
|
||||
pass
|
||||
|
||||
def test_supports_associations(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
team = self.Team(name=u'Team')
|
||||
team.members.append(self.TeamMember(user=john))
|
||||
self.session.add(john)
|
||||
self.session.add(jack)
|
||||
self.session.add(team)
|
||||
self.session.commit()
|
||||
def test_supports_associations(self, session, User, Team, TeamMember):
|
||||
john = User(name=u'John')
|
||||
jack = User(name=u'Jack')
|
||||
team = Team(name=u'Team')
|
||||
team.members.append(TeamMember(user=john))
|
||||
session.add(john)
|
||||
session.add(jack)
|
||||
session.add(team)
|
||||
session.commit()
|
||||
merge_references(john, jack)
|
||||
self.session.commit()
|
||||
session.commit()
|
||||
users = [member.user for member in team.members]
|
||||
assert john not in users
|
||||
assert jack in users
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
from sqlalchemy_utils.functions import naturally_equivalent
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestNaturallyEquivalent(TestCase):
|
||||
def test_returns_true_when_properties_match(self):
|
||||
class TestNaturallyEquivalent(object):
|
||||
def test_returns_true_when_properties_match(self, User):
|
||||
assert naturally_equivalent(
|
||||
self.User(name=u'someone'), self.User(name=u'someone')
|
||||
User(name=u'someone'), User(name=u'someone')
|
||||
)
|
||||
|
||||
def test_skips_primary_keys(self):
|
||||
def test_skips_primary_keys(self, User):
|
||||
assert naturally_equivalent(
|
||||
self.User(id=1, name=u'someone'), self.User(id=2, name=u'someone')
|
||||
User(id=1, name=u'someone'), User(id=2, name=u'someone')
|
||||
)
|
||||
|
|
|
@ -1,24 +1,32 @@
|
|||
from itertools import chain
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.functions import non_indexed_foreign_keys
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestFindNonIndexedForeignKeys(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
class TestFindNonIndexedForeignKeys(object):
|
||||
|
||||
@pytest.fixture
|
||||
def User(self, Base):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
return User
|
||||
|
||||
class Category(self.Base):
|
||||
@pytest.fixture
|
||||
def Category(self, Base):
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
return Category
|
||||
|
||||
class Article(self.Base):
|
||||
@pytest.fixture
|
||||
def Article(self, Base, User, Category):
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
@ -34,13 +42,14 @@ class TestFindNonIndexedForeignKeys(TestCase):
|
|||
'articles',
|
||||
)
|
||||
)
|
||||
return Article
|
||||
|
||||
self.User = User
|
||||
self.Category = Category
|
||||
self.Article = Article
|
||||
@pytest.fixture
|
||||
def init_models(self, User, Category, Article):
|
||||
pass
|
||||
|
||||
def test_finds_all_non_indexed_fks(self):
|
||||
fks = non_indexed_foreign_keys(self.Base.metadata, self.engine)
|
||||
def test_finds_all_non_indexed_fks(self, session, Base, engine):
|
||||
fks = non_indexed_foreign_keys(Base.metadata, engine)
|
||||
assert (
|
||||
'article' in
|
||||
fks
|
||||
|
|
|
@ -1,18 +1,22 @@
|
|||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from sqlalchemy_utils.functions import quote
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestQuote(TestCase):
|
||||
def test_quote_with_preserved_keyword(self):
|
||||
assert quote(self.connection, 'order') == '"order"'
|
||||
assert quote(self.session, 'order') == '"order"'
|
||||
assert quote(self.engine, 'order') == '"order"'
|
||||
class TestQuote(object):
|
||||
def test_quote_with_preserved_keyword(self, engine, connection, session):
|
||||
assert quote(connection, 'order') == '"order"'
|
||||
assert quote(session, 'order') == '"order"'
|
||||
assert quote(engine, 'order') == '"order"'
|
||||
assert quote(postgresql.dialect(), 'order') == '"order"'
|
||||
|
||||
def test_quote_with_non_preserved_keyword(self):
|
||||
assert quote(self.connection, 'some_order') == 'some_order'
|
||||
assert quote(self.session, 'some_order') == 'some_order'
|
||||
assert quote(self.engine, 'some_order') == 'some_order'
|
||||
def test_quote_with_non_preserved_keyword(
|
||||
self,
|
||||
engine,
|
||||
connection,
|
||||
session
|
||||
):
|
||||
assert quote(connection, 'some_order') == 'some_order'
|
||||
assert quote(session, 'some_order') == 'some_order'
|
||||
assert quote(engine, 'some_order') == 'some_order'
|
||||
assert quote(postgresql.dialect(), 'some_order') == 'some_order'
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.functions import (
|
||||
|
@ -5,52 +6,58 @@ from sqlalchemy_utils.functions import (
|
|||
render_expression,
|
||||
render_statement
|
||||
)
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestRender(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
class TestRender(object):
|
||||
|
||||
@pytest.fixture
|
||||
def User(self, Base):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
return User
|
||||
|
||||
self.User = User
|
||||
@pytest.fixture
|
||||
def init_models(self, User):
|
||||
pass
|
||||
|
||||
def test_render_orm_query(self):
|
||||
query = self.session.query(self.User).filter_by(id=3)
|
||||
def test_render_orm_query(self, session, User):
|
||||
query = session.query(User).filter_by(id=3)
|
||||
text = render_statement(query)
|
||||
|
||||
assert 'SELECT user.id, user.name' in text
|
||||
assert 'FROM user' in text
|
||||
assert 'WHERE user.id = 3' in text
|
||||
|
||||
def test_render_statement(self):
|
||||
statement = self.User.__table__.select().where(self.User.id == 3)
|
||||
text = render_statement(statement, bind=self.session.bind)
|
||||
def test_render_statement(self, session, User):
|
||||
statement = User.__table__.select().where(User.id == 3)
|
||||
text = render_statement(statement, bind=session.bind)
|
||||
|
||||
assert 'SELECT user.id, user.name' in text
|
||||
assert 'FROM user' in text
|
||||
assert 'WHERE user.id = 3' in text
|
||||
|
||||
def test_render_statement_without_mapper(self):
|
||||
def test_render_statement_without_mapper(self, session):
|
||||
statement = sa.select([sa.text('1')])
|
||||
text = render_statement(statement, bind=self.session.bind)
|
||||
text = render_statement(statement, bind=session.bind)
|
||||
|
||||
assert 'SELECT 1' in text
|
||||
|
||||
def test_render_ddl(self):
|
||||
expression = 'self.User.__table__.create(engine)'
|
||||
stream = render_expression(expression, self.engine)
|
||||
def test_render_ddl(self, engine, User):
|
||||
expression = 'User.__table__.create(engine)'
|
||||
stream = render_expression(expression, engine)
|
||||
|
||||
text = stream.getvalue()
|
||||
|
||||
assert 'CREATE TABLE user' in text
|
||||
assert 'PRIMARY KEY' in text
|
||||
|
||||
def test_render_mock_ddl(self):
|
||||
def test_render_mock_ddl(self, engine, User):
|
||||
# TODO: mock_engine doesn't seem to work with locally scoped variables.
|
||||
self.engine = engine
|
||||
with mock_engine('self.engine') as stream:
|
||||
self.User.__table__.create(self.engine)
|
||||
User.__table__.create(self.engine)
|
||||
|
||||
text = stream.getvalue()
|
||||
|
||||
|
|
|
@ -1,26 +1,33 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import table_name
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestTableName(TestCase):
|
||||
def create_models(self):
|
||||
class Building(self.Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
@pytest.fixture
|
||||
def Building(Base):
|
||||
class Building(Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
return Building
|
||||
|
||||
self.Building = Building
|
||||
|
||||
def test_class(self):
|
||||
assert table_name(self.Building) == 'building'
|
||||
del self.Building.__tablename__
|
||||
assert table_name(self.Building) == 'building'
|
||||
@pytest.fixture
|
||||
def init_models(Base):
|
||||
pass
|
||||
|
||||
def test_attribute(self):
|
||||
assert table_name(self.Building.id) == 'building'
|
||||
assert table_name(self.Building.name) == 'building'
|
||||
|
||||
def test_target(self):
|
||||
assert table_name(self.Building()) == 'building'
|
||||
class TestTableName(object):
|
||||
|
||||
def test_class(self, Building):
|
||||
assert table_name(Building) == 'building'
|
||||
del Building.__tablename__
|
||||
assert table_name(Building) == 'building'
|
||||
|
||||
def test_attribute(self, Building):
|
||||
assert table_name(Building.id) == 'building'
|
||||
assert table_name(Building.name) == 'building'
|
||||
|
||||
def test_target(self, Building):
|
||||
assert table_name(Building()) == 'building'
|
||||
|
|
|
@ -1,109 +1,105 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import six
|
||||
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class GenericRelationshipTestCase(TestCase):
|
||||
def test_set_as_none(self):
|
||||
event = self.Event()
|
||||
class GenericRelationshipTestCase(object):
|
||||
def test_set_as_none(self, Event):
|
||||
event = Event()
|
||||
event.object = None
|
||||
assert event.object is None
|
||||
|
||||
def test_set_manual_and_get(self):
|
||||
user = self.User()
|
||||
def test_set_manual_and_get(self, session, User, Event):
|
||||
user = User()
|
||||
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
event = self.Event()
|
||||
event = Event()
|
||||
event.object_id = user.id
|
||||
event.object_type = six.text_type(type(user).__name__)
|
||||
|
||||
assert event.object is None
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
session.add(event)
|
||||
session.commit()
|
||||
|
||||
assert event.object == user
|
||||
|
||||
def test_set_and_get(self):
|
||||
user = self.User()
|
||||
def test_set_and_get(self, session, User, Event):
|
||||
user = User()
|
||||
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
event = self.Event(object=user)
|
||||
event = Event(object=user)
|
||||
|
||||
assert event.object_id == user.id
|
||||
assert event.object_type == type(user).__name__
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
session.add(event)
|
||||
session.commit()
|
||||
|
||||
assert event.object == user
|
||||
|
||||
def test_compare_instance(self):
|
||||
user1 = self.User()
|
||||
user2 = self.User()
|
||||
def test_compare_instance(self, session, User, Event):
|
||||
user1 = User()
|
||||
user2 = User()
|
||||
|
||||
self.session.add_all([user1, user2])
|
||||
self.session.commit()
|
||||
session.add_all([user1, user2])
|
||||
session.commit()
|
||||
|
||||
event = self.Event(object=user1)
|
||||
event = Event(object=user1)
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
session.add(event)
|
||||
session.commit()
|
||||
|
||||
assert event.object == user1
|
||||
assert event.object != user2
|
||||
|
||||
def test_compare_query(self):
|
||||
user1 = self.User()
|
||||
user2 = self.User()
|
||||
def test_compare_query(self, session, User, Event):
|
||||
user1 = User()
|
||||
user2 = User()
|
||||
|
||||
self.session.add_all([user1, user2])
|
||||
self.session.commit()
|
||||
session.add_all([user1, user2])
|
||||
session.commit()
|
||||
|
||||
event = self.Event(object=user1)
|
||||
event = Event(object=user1)
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
session.add(event)
|
||||
session.commit()
|
||||
|
||||
q = self.session.query(self.Event)
|
||||
q = session.query(Event)
|
||||
assert q.filter_by(object=user1).first() is not None
|
||||
assert q.filter_by(object=user2).first() is None
|
||||
assert q.filter(self.Event.object == user2).first() is None
|
||||
assert q.filter(Event.object == user2).first() is None
|
||||
|
||||
def test_compare_not_query(self):
|
||||
user1 = self.User()
|
||||
user2 = self.User()
|
||||
def test_compare_not_query(self, session, User, Event):
|
||||
user1 = User()
|
||||
user2 = User()
|
||||
|
||||
self.session.add_all([user1, user2])
|
||||
self.session.commit()
|
||||
session.add_all([user1, user2])
|
||||
session.commit()
|
||||
|
||||
event = self.Event(object=user1)
|
||||
event = Event(object=user1)
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
session.add(event)
|
||||
session.commit()
|
||||
|
||||
q = self.session.query(self.Event)
|
||||
assert q.filter(self.Event.object != user2).first() is not None
|
||||
q = session.query(Event)
|
||||
assert q.filter(Event.object != user2).first() is not None
|
||||
|
||||
def test_compare_type(self):
|
||||
user1 = self.User()
|
||||
user2 = self.User()
|
||||
def test_compare_type(self, session, User, Event):
|
||||
user1 = User()
|
||||
user2 = User()
|
||||
|
||||
self.session.add_all([user1, user2])
|
||||
self.session.commit()
|
||||
session.add_all([user1, user2])
|
||||
session.commit()
|
||||
|
||||
event1 = self.Event(object=user1)
|
||||
event2 = self.Event(object=user2)
|
||||
event1 = Event(object=user1)
|
||||
event2 = Event(object=user2)
|
||||
|
||||
self.session.add_all([event1, event2])
|
||||
self.session.commit()
|
||||
session.add_all([event1, event2])
|
||||
session.commit()
|
||||
|
||||
statement = self.Event.object.is_type(self.User)
|
||||
q = self.session.query(self.Event).filter(statement)
|
||||
statement = Event.object.is_type(User)
|
||||
q = session.query(Event).filter(statement)
|
||||
assert q.first() is not None
|
||||
|
|
|
@ -1,36 +1,54 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.declarative import declared_attr
|
||||
|
||||
from sqlalchemy_utils import generic_relationship
|
||||
from tests.generic_relationship import GenericRelationshipTestCase
|
||||
|
||||
from . import GenericRelationshipTestCase
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Building(Base):
|
||||
class Building(Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
return Building
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def User(Base):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
return User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def EventBase(Base):
|
||||
class EventBase(Base):
|
||||
__abstract__ = True
|
||||
|
||||
object_type = sa.Column(sa.Unicode(255))
|
||||
object_id = sa.Column(sa.Integer, nullable=False)
|
||||
|
||||
@declared_attr
|
||||
def object(cls):
|
||||
return generic_relationship('object_type', 'object_id')
|
||||
return EventBase
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Event(EventBase):
|
||||
class Event(EventBase):
|
||||
__tablename__ = 'event'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
return Event
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(Building, User, Event):
|
||||
pass
|
||||
|
||||
|
||||
class TestGenericRelationshipWithAbstractBase(GenericRelationshipTestCase):
|
||||
def create_models(self):
|
||||
class Building(self.Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
class EventBase(self.Base):
|
||||
__abstract__ = True
|
||||
|
||||
object_type = sa.Column(sa.Unicode(255))
|
||||
object_id = sa.Column(sa.Integer, nullable=False)
|
||||
|
||||
@declared_attr
|
||||
def object(cls):
|
||||
return generic_relationship('object_type', 'object_id')
|
||||
|
||||
class Event(EventBase):
|
||||
__tablename__ = 'event'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
self.Building = Building
|
||||
self.User = User
|
||||
self.Event = Event
|
||||
pass
|
||||
|
|
|
@ -1,30 +1,44 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import generic_relationship
|
||||
from tests.generic_relationship import GenericRelationshipTestCase
|
||||
|
||||
from . import GenericRelationshipTestCase
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Building(Base):
|
||||
class Building(Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
return Building
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def User(Base):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
return User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Event(Base):
|
||||
class Event(Base):
|
||||
__tablename__ = 'event'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
object_type = sa.Column(sa.Unicode(255), name="objectType")
|
||||
object_id = sa.Column(sa.Integer, nullable=False)
|
||||
|
||||
object = generic_relationship(object_type, object_id)
|
||||
return Event
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(Building, User, Event):
|
||||
pass
|
||||
|
||||
|
||||
class TestGenericRelationship(GenericRelationshipTestCase):
|
||||
def create_models(self):
|
||||
class Building(self.Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
class Event(self.Base):
|
||||
__tablename__ = 'event'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
object_type = sa.Column(sa.Unicode(255), name="objectType")
|
||||
object_id = sa.Column(sa.Integer, nullable=False)
|
||||
|
||||
object = generic_relationship(object_type, object_id)
|
||||
|
||||
self.Building = Building
|
||||
self.User = User
|
||||
self.Event = Event
|
||||
pass
|
||||
|
|
|
@ -1,66 +1,84 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
import six
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import generic_relationship
|
||||
from tests.generic_relationship import GenericRelationshipTestCase
|
||||
|
||||
from ..generic_relationship import GenericRelationshipTestCase
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def incrementor():
|
||||
class Incrementor(object):
|
||||
value = 1
|
||||
return Incrementor()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Building(Base, incrementor):
|
||||
class Building(Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
code = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
def __init__(self):
|
||||
incrementor.value += 1
|
||||
self.id = incrementor.value
|
||||
self.code = incrementor.value
|
||||
return Building
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def User(Base, incrementor):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
code = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
def __init__(self):
|
||||
incrementor.value += 1
|
||||
self.id = incrementor.value
|
||||
self.code = incrementor.value
|
||||
return User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Event(Base):
|
||||
class Event(Base):
|
||||
__tablename__ = 'event'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
object_type = sa.Column(sa.Unicode(255))
|
||||
object_id = sa.Column(sa.Integer, nullable=False)
|
||||
object_code = sa.Column(sa.Integer, nullable=False)
|
||||
|
||||
object = generic_relationship(
|
||||
object_type, (object_id, object_code)
|
||||
)
|
||||
return Event
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(Building, User, Event):
|
||||
pass
|
||||
|
||||
|
||||
class TestGenericRelationship(GenericRelationshipTestCase):
|
||||
index = 1
|
||||
|
||||
def create_models(self):
|
||||
class Building(self.Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
code = sa.Column(sa.Integer, primary_key=True)
|
||||
def test_set_manual_and_get(self, session, Event, User):
|
||||
user = User()
|
||||
|
||||
def __init__(obj_self):
|
||||
self.index += 1
|
||||
obj_self.id = self.index
|
||||
obj_self.code = self.index
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
code = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
def __init__(obj_self):
|
||||
self.index += 1
|
||||
obj_self.id = self.index
|
||||
obj_self.code = self.index
|
||||
|
||||
class Event(self.Base):
|
||||
__tablename__ = 'event'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
object_type = sa.Column(sa.Unicode(255))
|
||||
object_id = sa.Column(sa.Integer, nullable=False)
|
||||
object_code = sa.Column(sa.Integer, nullable=False)
|
||||
|
||||
object = generic_relationship(
|
||||
object_type, (object_id, object_code)
|
||||
)
|
||||
|
||||
self.Building = Building
|
||||
self.User = User
|
||||
self.Event = Event
|
||||
|
||||
def test_set_manual_and_get(self):
|
||||
user = self.User()
|
||||
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
|
||||
event = self.Event()
|
||||
event = Event()
|
||||
event.object_id = user.id
|
||||
event.object_type = six.text_type(type(user).__name__)
|
||||
event.object_code = user.code
|
||||
|
||||
assert event.object is None
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
session.add(event)
|
||||
session.commit()
|
||||
|
||||
assert event.object == user
|
||||
|
|
|
@ -1,68 +1,79 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
import six
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
|
||||
from sqlalchemy_utils import generic_relationship
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestGenericRelationship(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
@pytest.fixture
|
||||
def User(Base):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
return User
|
||||
|
||||
class UserHistory(self.Base):
|
||||
__tablename__ = 'user_history'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
transaction_id = sa.Column(sa.Integer, primary_key=True)
|
||||
@pytest.fixture
|
||||
def UserHistory(Base):
|
||||
class UserHistory(Base):
|
||||
__tablename__ = 'user_history'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
class Event(self.Base):
|
||||
__tablename__ = 'event'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
transaction_id = sa.Column(sa.Integer, primary_key=True)
|
||||
return UserHistory
|
||||
|
||||
transaction_id = sa.Column(sa.Integer)
|
||||
|
||||
object_type = sa.Column(sa.Unicode(255))
|
||||
object_id = sa.Column(sa.Integer, nullable=False)
|
||||
@pytest.fixture
|
||||
def Event(Base):
|
||||
class Event(Base):
|
||||
__tablename__ = 'event'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
object = generic_relationship(
|
||||
object_type, object_id
|
||||
)
|
||||
transaction_id = sa.Column(sa.Integer)
|
||||
|
||||
@hybrid_property
|
||||
def object_version_type(self):
|
||||
return self.object_type + 'History'
|
||||
object_type = sa.Column(sa.Unicode(255))
|
||||
object_id = sa.Column(sa.Integer, nullable=False)
|
||||
|
||||
@object_version_type.expression
|
||||
def object_version_type(cls):
|
||||
return sa.func.concat(cls.object_type, 'History')
|
||||
object = generic_relationship(
|
||||
object_type, object_id
|
||||
)
|
||||
|
||||
object_version = generic_relationship(
|
||||
object_version_type, (object_id, transaction_id)
|
||||
)
|
||||
@hybrid_property
|
||||
def object_version_type(self):
|
||||
return self.object_type + 'History'
|
||||
|
||||
self.User = User
|
||||
self.UserHistory = UserHistory
|
||||
self.Event = Event
|
||||
@object_version_type.expression
|
||||
def object_version_type(cls):
|
||||
return sa.func.concat(cls.object_type, 'History')
|
||||
|
||||
def test_set_manual_and_get(self):
|
||||
user = self.User(id=1)
|
||||
history = self.UserHistory(id=1, transaction_id=1)
|
||||
self.session.add(user)
|
||||
self.session.add(history)
|
||||
self.session.commit()
|
||||
object_version = generic_relationship(
|
||||
object_version_type, (object_id, transaction_id)
|
||||
)
|
||||
return Event
|
||||
|
||||
event = self.Event(transaction_id=1)
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(User, UserHistory, Event):
|
||||
pass
|
||||
|
||||
|
||||
class TestGenericRelationship(object):
|
||||
|
||||
def test_set_manual_and_get(self, session, User, UserHistory, Event):
|
||||
user = User(id=1)
|
||||
history = UserHistory(id=1, transaction_id=1)
|
||||
session.add(user)
|
||||
session.add(history)
|
||||
session.commit()
|
||||
|
||||
event = Event(transaction_id=1)
|
||||
event.object_id = user.id
|
||||
event.object_type = six.text_type(type(user).__name__)
|
||||
assert event.object is None
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
session.add(event)
|
||||
session.commit()
|
||||
|
||||
assert event.object == user
|
||||
assert event.object_version == history
|
||||
|
|
|
@ -1,164 +1,178 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
import six
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import generic_relationship
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestGenericRelationship(TestCase):
|
||||
def create_models(self):
|
||||
class Employee(self.Base):
|
||||
__tablename__ = 'employee'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String(50))
|
||||
type = sa.Column(sa.String(20))
|
||||
@pytest.fixture
|
||||
def Employee(Base):
|
||||
class Employee(Base):
|
||||
__tablename__ = 'employee'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String(50))
|
||||
type = sa.Column(sa.String(20))
|
||||
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': type,
|
||||
'polymorphic_identity': 'employee'
|
||||
}
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': type,
|
||||
'polymorphic_identity': 'employee'
|
||||
}
|
||||
return Employee
|
||||
|
||||
class Manager(Employee):
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'manager'
|
||||
}
|
||||
|
||||
class Engineer(Employee):
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'engineer'
|
||||
}
|
||||
@pytest.fixture
|
||||
def Manager(Employee):
|
||||
class Manager(Employee):
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'manager'
|
||||
}
|
||||
return Manager
|
||||
|
||||
class Event(self.Base):
|
||||
__tablename__ = 'event'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
object_type = sa.Column(sa.Unicode(255))
|
||||
object_id = sa.Column(sa.Integer, nullable=False)
|
||||
@pytest.fixture
|
||||
def Engineer(Employee):
|
||||
class Engineer(Employee):
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'engineer'
|
||||
}
|
||||
return Engineer
|
||||
|
||||
object = generic_relationship(object_type, object_id)
|
||||
|
||||
self.Employee = Employee
|
||||
self.Manager = Manager
|
||||
self.Engineer = Engineer
|
||||
self.Event = Event
|
||||
@pytest.fixture
|
||||
def Event(Base):
|
||||
class Event(Base):
|
||||
__tablename__ = 'event'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
def test_set_as_none(self):
|
||||
event = self.Event()
|
||||
object_type = sa.Column(sa.Unicode(255))
|
||||
object_id = sa.Column(sa.Integer, nullable=False)
|
||||
|
||||
object = generic_relationship(object_type, object_id)
|
||||
return Event
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(Employee, Manager, Engineer, Event):
|
||||
pass
|
||||
|
||||
|
||||
class TestGenericRelationship(object):
|
||||
|
||||
def test_set_as_none(self, Event):
|
||||
event = Event()
|
||||
event.object = None
|
||||
assert event.object is None
|
||||
|
||||
def test_set_manual_and_get(self):
|
||||
manager = self.Manager()
|
||||
def test_set_manual_and_get(self, session, Manager, Event):
|
||||
manager = Manager()
|
||||
|
||||
self.session.add(manager)
|
||||
self.session.commit()
|
||||
session.add(manager)
|
||||
session.commit()
|
||||
|
||||
event = self.Event()
|
||||
event = Event()
|
||||
event.object_id = manager.id
|
||||
event.object_type = six.text_type(type(manager).__name__)
|
||||
|
||||
assert event.object is None
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
session.add(event)
|
||||
session.commit()
|
||||
|
||||
assert event.object == manager
|
||||
|
||||
def test_set_and_get(self):
|
||||
manager = self.Manager()
|
||||
def test_set_and_get(self, session, Manager, Event):
|
||||
manager = Manager()
|
||||
|
||||
self.session.add(manager)
|
||||
self.session.commit()
|
||||
session.add(manager)
|
||||
session.commit()
|
||||
|
||||
event = self.Event(object=manager)
|
||||
event = Event(object=manager)
|
||||
|
||||
assert event.object_id == manager.id
|
||||
assert event.object_type == type(manager).__name__
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
session.add(event)
|
||||
session.commit()
|
||||
|
||||
assert event.object == manager
|
||||
|
||||
def test_compare_instance(self):
|
||||
manager1 = self.Manager()
|
||||
manager2 = self.Manager()
|
||||
def test_compare_instance(self, session, Manager, Event):
|
||||
manager1 = Manager()
|
||||
manager2 = Manager()
|
||||
|
||||
self.session.add_all([manager1, manager2])
|
||||
self.session.commit()
|
||||
session.add_all([manager1, manager2])
|
||||
session.commit()
|
||||
|
||||
event = self.Event(object=manager1)
|
||||
event = Event(object=manager1)
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
session.add(event)
|
||||
session.commit()
|
||||
|
||||
assert event.object == manager1
|
||||
assert event.object != manager2
|
||||
|
||||
def test_compare_query(self):
|
||||
manager1 = self.Manager()
|
||||
manager2 = self.Manager()
|
||||
def test_compare_query(self, session, Manager, Event):
|
||||
manager1 = Manager()
|
||||
manager2 = Manager()
|
||||
|
||||
self.session.add_all([manager1, manager2])
|
||||
self.session.commit()
|
||||
session.add_all([manager1, manager2])
|
||||
session.commit()
|
||||
|
||||
event = self.Event(object=manager1)
|
||||
event = Event(object=manager1)
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
session.add(event)
|
||||
session.commit()
|
||||
|
||||
q = self.session.query(self.Event)
|
||||
q = session.query(Event)
|
||||
assert q.filter_by(object=manager1).first() is not None
|
||||
assert q.filter_by(object=manager2).first() is None
|
||||
assert q.filter(self.Event.object == manager2).first() is None
|
||||
assert q.filter(Event.object == manager2).first() is None
|
||||
|
||||
def test_compare_not_query(self):
|
||||
manager1 = self.Manager()
|
||||
manager2 = self.Manager()
|
||||
def test_compare_not_query(self, session, Manager, Event):
|
||||
manager1 = Manager()
|
||||
manager2 = Manager()
|
||||
|
||||
self.session.add_all([manager1, manager2])
|
||||
self.session.commit()
|
||||
session.add_all([manager1, manager2])
|
||||
session.commit()
|
||||
|
||||
event = self.Event(object=manager1)
|
||||
event = Event(object=manager1)
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
session.add(event)
|
||||
session.commit()
|
||||
|
||||
q = self.session.query(self.Event)
|
||||
assert q.filter(self.Event.object != manager2).first() is not None
|
||||
q = session.query(Event)
|
||||
assert q.filter(Event.object != manager2).first() is not None
|
||||
|
||||
def test_compare_type(self):
|
||||
manager1 = self.Manager()
|
||||
manager2 = self.Manager()
|
||||
def test_compare_type(self, session, Manager, Event):
|
||||
manager1 = Manager()
|
||||
manager2 = Manager()
|
||||
|
||||
self.session.add_all([manager1, manager2])
|
||||
self.session.commit()
|
||||
session.add_all([manager1, manager2])
|
||||
session.commit()
|
||||
|
||||
event1 = self.Event(object=manager1)
|
||||
event2 = self.Event(object=manager2)
|
||||
event1 = Event(object=manager1)
|
||||
event2 = Event(object=manager2)
|
||||
|
||||
self.session.add_all([event1, event2])
|
||||
self.session.commit()
|
||||
session.add_all([event1, event2])
|
||||
session.commit()
|
||||
|
||||
statement = self.Event.object.is_type(self.Manager)
|
||||
q = self.session.query(self.Event).filter(statement)
|
||||
statement = Event.object.is_type(Manager)
|
||||
q = session.query(Event).filter(statement)
|
||||
assert q.first() is not None
|
||||
|
||||
def test_compare_super_type(self):
|
||||
manager1 = self.Manager()
|
||||
manager2 = self.Manager()
|
||||
def test_compare_super_type(self, session, Manager, Event, Employee):
|
||||
manager1 = Manager()
|
||||
manager2 = Manager()
|
||||
|
||||
self.session.add_all([manager1, manager2])
|
||||
self.session.commit()
|
||||
session.add_all([manager1, manager2])
|
||||
session.commit()
|
||||
|
||||
event1 = self.Event(object=manager1)
|
||||
event2 = self.Event(object=manager2)
|
||||
event1 = Event(object=manager1)
|
||||
event2 = Event(object=manager2)
|
||||
|
||||
self.session.add_all([event1, event2])
|
||||
self.session.commit()
|
||||
session.add_all([event1, event2])
|
||||
session.commit()
|
||||
|
||||
statement = self.Event.object.is_type(self.Employee)
|
||||
q = self.session.query(self.Event).filter(statement)
|
||||
statement = Event.object.is_type(Employee)
|
||||
q = session.query(Event).filter(statement)
|
||||
assert q.first() is not None
|
||||
|
|
147
tests/mixins.py
147
tests/mixins.py
|
@ -1,18 +1,24 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
class ThreeLevelDeepOneToOne(object):
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
|
||||
@pytest.fixture
|
||||
def Catalog(self, Base, Category):
|
||||
class Catalog(Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
category = sa.orm.relationship(
|
||||
'Category',
|
||||
Category,
|
||||
uselist=False,
|
||||
backref='catalog'
|
||||
)
|
||||
return Catalog
|
||||
|
||||
class Category(self.Base):
|
||||
@pytest.fixture
|
||||
def Category(self, Base, SubCategory):
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
catalog_id = sa.Column(
|
||||
|
@ -22,12 +28,15 @@ class ThreeLevelDeepOneToOne(object):
|
|||
)
|
||||
|
||||
sub_category = sa.orm.relationship(
|
||||
'SubCategory',
|
||||
SubCategory,
|
||||
uselist=False,
|
||||
backref='category'
|
||||
)
|
||||
return Category
|
||||
|
||||
class SubCategory(self.Base):
|
||||
@pytest.fixture
|
||||
def SubCategory(self, Base, Product):
|
||||
class SubCategory(Base):
|
||||
__tablename__ = 'sub_category'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
category_id = sa.Column(
|
||||
|
@ -36,12 +45,15 @@ class ThreeLevelDeepOneToOne(object):
|
|||
sa.ForeignKey('category._id')
|
||||
)
|
||||
product = sa.orm.relationship(
|
||||
'Product',
|
||||
Product,
|
||||
uselist=False,
|
||||
backref='sub_category'
|
||||
)
|
||||
return SubCategory
|
||||
|
||||
class Product(self.Base):
|
||||
@pytest.fixture
|
||||
def Product(self, Base):
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
price = sa.Column(sa.Integer)
|
||||
|
@ -51,22 +63,27 @@ class ThreeLevelDeepOneToOne(object):
|
|||
sa.Integer,
|
||||
sa.ForeignKey('sub_category._id')
|
||||
)
|
||||
return Product
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.SubCategory = SubCategory
|
||||
self.Product = Product
|
||||
@pytest.fixture
|
||||
def init_models(self, Catalog, Category, SubCategory, Product):
|
||||
pass
|
||||
|
||||
|
||||
class ThreeLevelDeepOneToMany(object):
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
|
||||
@pytest.fixture
|
||||
def Catalog(self, Base, Category):
|
||||
class Catalog(Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
|
||||
categories = sa.orm.relationship('Category', backref='catalog')
|
||||
categories = sa.orm.relationship(Category, backref='catalog')
|
||||
return Catalog
|
||||
|
||||
class Category(self.Base):
|
||||
@pytest.fixture
|
||||
def Category(self, Base, SubCategory):
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
catalog_id = sa.Column(
|
||||
|
@ -76,10 +93,13 @@ class ThreeLevelDeepOneToMany(object):
|
|||
)
|
||||
|
||||
sub_categories = sa.orm.relationship(
|
||||
'SubCategory', backref='category'
|
||||
SubCategory, backref='category'
|
||||
)
|
||||
return Category
|
||||
|
||||
class SubCategory(self.Base):
|
||||
@pytest.fixture
|
||||
def SubCategory(self, Base, Product):
|
||||
class SubCategory(Base):
|
||||
__tablename__ = 'sub_category'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
category_id = sa.Column(
|
||||
|
@ -88,11 +108,14 @@ class ThreeLevelDeepOneToMany(object):
|
|||
sa.ForeignKey('category._id')
|
||||
)
|
||||
products = sa.orm.relationship(
|
||||
'Product',
|
||||
Product,
|
||||
backref='sub_category'
|
||||
)
|
||||
return SubCategory
|
||||
|
||||
class Product(self.Base):
|
||||
@pytest.fixture
|
||||
def Product(self, Base):
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
@ -105,25 +128,42 @@ class ThreeLevelDeepOneToMany(object):
|
|||
|
||||
def __repr__(self):
|
||||
return '<Product id=%r>' % self.id
|
||||
return Product
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.SubCategory = SubCategory
|
||||
self.Product = Product
|
||||
@pytest.fixture
|
||||
def init_models(self, Catalog, Category, SubCategory, Product):
|
||||
pass
|
||||
|
||||
|
||||
class ThreeLevelDeepManyToMany(object):
|
||||
def create_models(self):
|
||||
|
||||
@pytest.fixture
|
||||
def Catalog(self, Base, Category):
|
||||
|
||||
catalog_category = sa.Table(
|
||||
'catalog_category',
|
||||
self.Base.metadata,
|
||||
Base.metadata,
|
||||
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog._id')),
|
||||
sa.Column('category_id', sa.Integer, sa.ForeignKey('category._id'))
|
||||
)
|
||||
|
||||
class Catalog(Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
|
||||
categories = sa.orm.relationship(
|
||||
Category,
|
||||
backref='catalogs',
|
||||
secondary=catalog_category
|
||||
)
|
||||
return Catalog
|
||||
|
||||
@pytest.fixture
|
||||
def Category(self, Base, SubCategory):
|
||||
|
||||
category_subcategory = sa.Table(
|
||||
'category_subcategory',
|
||||
self.Base.metadata,
|
||||
Base.metadata,
|
||||
sa.Column(
|
||||
'category_id',
|
||||
sa.Integer,
|
||||
|
@ -136,9 +176,23 @@ class ThreeLevelDeepManyToMany(object):
|
|||
)
|
||||
)
|
||||
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
|
||||
sub_categories = sa.orm.relationship(
|
||||
SubCategory,
|
||||
backref='categories',
|
||||
secondary=category_subcategory
|
||||
)
|
||||
return Category
|
||||
|
||||
@pytest.fixture
|
||||
def SubCategory(self, Base, Product):
|
||||
|
||||
subcategory_product = sa.Table(
|
||||
'subcategory_product',
|
||||
self.Base.metadata,
|
||||
Base.metadata,
|
||||
sa.Column(
|
||||
'subcategory_id',
|
||||
sa.Integer,
|
||||
|
@ -151,41 +205,24 @@ class ThreeLevelDeepManyToMany(object):
|
|||
)
|
||||
)
|
||||
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
|
||||
categories = sa.orm.relationship(
|
||||
'Category',
|
||||
backref='catalogs',
|
||||
secondary=catalog_category
|
||||
)
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
|
||||
sub_categories = sa.orm.relationship(
|
||||
'SubCategory',
|
||||
backref='categories',
|
||||
secondary=category_subcategory
|
||||
)
|
||||
|
||||
class SubCategory(self.Base):
|
||||
class SubCategory(Base):
|
||||
__tablename__ = 'sub_category'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
products = sa.orm.relationship(
|
||||
'Product',
|
||||
Product,
|
||||
backref='sub_categories',
|
||||
secondary=subcategory_product
|
||||
)
|
||||
return SubCategory
|
||||
|
||||
class Product(self.Base):
|
||||
@pytest.fixture
|
||||
def Product(self, Base):
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
price = sa.Column(sa.Numeric)
|
||||
return Product
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.SubCategory = SubCategory
|
||||
self.Product = Product
|
||||
@pytest.fixture
|
||||
def init_models(self, Catalog, Category, SubCategory, Product):
|
||||
pass
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from pytest import raises
|
||||
|
||||
from sqlalchemy_utils.observer import observes
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestObservesForColumn(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestObservesForColumn(object):
|
||||
|
||||
def create_models(self):
|
||||
class Product(self.Base):
|
||||
@pytest.fixture
|
||||
def Product(self, Base):
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
price = sa.Column(sa.Integer)
|
||||
|
@ -17,21 +17,25 @@ class TestObservesForColumn(TestCase):
|
|||
@observes('price')
|
||||
def product_price_observer(self, price):
|
||||
self.price = price * 2
|
||||
return Product
|
||||
|
||||
self.Product = Product
|
||||
@pytest.fixture
|
||||
def init_models(self, Product):
|
||||
pass
|
||||
|
||||
def test_simple_insert(self):
|
||||
product = self.Product(price=100)
|
||||
self.session.add(product)
|
||||
self.session.flush()
|
||||
def test_simple_insert(self, session, Product):
|
||||
product = Product(price=100)
|
||||
session.add(product)
|
||||
session.flush()
|
||||
assert product.price == 200
|
||||
|
||||
|
||||
class TestObservesForColumnWithoutActualChanges(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestObservesForColumnWithoutActualChanges(object):
|
||||
|
||||
def create_models(self):
|
||||
class Product(self.Base):
|
||||
@pytest.fixture
|
||||
def Product(self, Base):
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
price = sa.Column(sa.Integer)
|
||||
|
@ -39,15 +43,18 @@ class TestObservesForColumnWithoutActualChanges(TestCase):
|
|||
@observes('price')
|
||||
def product_price_observer(self, price):
|
||||
raise Exception('Trying to change price')
|
||||
return Product
|
||||
|
||||
self.Product = Product
|
||||
@pytest.fixture
|
||||
def init_models(self, Product):
|
||||
pass
|
||||
|
||||
def test_only_notifies_observer_on_actual_changes(self):
|
||||
product = self.Product()
|
||||
self.session.add(product)
|
||||
self.session.flush()
|
||||
def test_only_notifies_observer_on_actual_changes(self, session, Product):
|
||||
product = Product()
|
||||
session.add(product)
|
||||
session.flush()
|
||||
|
||||
with raises(Exception) as e:
|
||||
with pytest.raises(Exception) as e:
|
||||
product.price = 500
|
||||
self.session.commit()
|
||||
session.commit()
|
||||
assert str(e.value) == 'Trying to change price'
|
||||
|
|
|
@ -1,137 +1,158 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.observer import observes
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestObservesForManyToManyToManyToMany(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.fixture
|
||||
def Catalog(Base):
|
||||
catalog_category = sa.Table(
|
||||
'catalog_category',
|
||||
Base.metadata,
|
||||
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
|
||||
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id'))
|
||||
)
|
||||
|
||||
def create_models(self):
|
||||
catalog_category = sa.Table(
|
||||
'catalog_category',
|
||||
self.Base.metadata,
|
||||
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
|
||||
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id'))
|
||||
class Catalog(Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
product_count = sa.Column(sa.Integer, default=0)
|
||||
|
||||
@observes('categories.sub_categories.products')
|
||||
def product_observer(self, products):
|
||||
self.product_count = len(products)
|
||||
|
||||
categories = sa.orm.relationship(
|
||||
'Category',
|
||||
backref='catalogs',
|
||||
secondary=catalog_category
|
||||
)
|
||||
return Catalog
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Category(Base):
|
||||
category_subcategory = sa.Table(
|
||||
'category_subcategory',
|
||||
Base.metadata,
|
||||
sa.Column(
|
||||
'category_id',
|
||||
sa.Integer,
|
||||
sa.ForeignKey('category.id')
|
||||
),
|
||||
sa.Column(
|
||||
'subcategory_id',
|
||||
sa.Integer,
|
||||
sa.ForeignKey('sub_category.id')
|
||||
)
|
||||
)
|
||||
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
sub_categories = sa.orm.relationship(
|
||||
'SubCategory',
|
||||
backref='categories',
|
||||
secondary=category_subcategory
|
||||
)
|
||||
return Category
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def SubCategory(Base):
|
||||
subcategory_product = sa.Table(
|
||||
'subcategory_product',
|
||||
Base.metadata,
|
||||
sa.Column(
|
||||
'subcategory_id',
|
||||
sa.Integer,
|
||||
sa.ForeignKey('sub_category.id')
|
||||
),
|
||||
sa.Column(
|
||||
'product_id',
|
||||
sa.Integer,
|
||||
sa.ForeignKey('product.id')
|
||||
)
|
||||
)
|
||||
|
||||
class SubCategory(Base):
|
||||
__tablename__ = 'sub_category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
products = sa.orm.relationship(
|
||||
'Product',
|
||||
backref='sub_categories',
|
||||
secondary=subcategory_product
|
||||
)
|
||||
|
||||
category_subcategory = sa.Table(
|
||||
'category_subcategory',
|
||||
self.Base.metadata,
|
||||
sa.Column(
|
||||
'category_id',
|
||||
sa.Integer,
|
||||
sa.ForeignKey('category.id')
|
||||
),
|
||||
sa.Column(
|
||||
'subcategory_id',
|
||||
sa.Integer,
|
||||
sa.ForeignKey('sub_category.id')
|
||||
)
|
||||
)
|
||||
return SubCategory
|
||||
|
||||
subcategory_product = sa.Table(
|
||||
'subcategory_product',
|
||||
self.Base.metadata,
|
||||
sa.Column(
|
||||
'subcategory_id',
|
||||
sa.Integer,
|
||||
sa.ForeignKey('sub_category.id')
|
||||
),
|
||||
sa.Column(
|
||||
'product_id',
|
||||
sa.Integer,
|
||||
sa.ForeignKey('product.id')
|
||||
)
|
||||
)
|
||||
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
product_count = sa.Column(sa.Integer, default=0)
|
||||
@pytest.fixture
|
||||
def Product(Base):
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
price = sa.Column(sa.Numeric)
|
||||
return Product
|
||||
|
||||
@observes('categories.sub_categories.products')
|
||||
def product_observer(self, products):
|
||||
self.product_count = len(products)
|
||||
|
||||
categories = sa.orm.relationship(
|
||||
'Category',
|
||||
backref='catalogs',
|
||||
secondary=catalog_category
|
||||
)
|
||||
@pytest.fixture
|
||||
def init_models(Catalog, Category, SubCategory, Product):
|
||||
pass
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
sub_categories = sa.orm.relationship(
|
||||
'SubCategory',
|
||||
backref='categories',
|
||||
secondary=category_subcategory
|
||||
)
|
||||
@pytest.fixture
|
||||
def catalog(session, Catalog, Category, SubCategory, Product):
|
||||
sub_category = SubCategory(products=[Product()])
|
||||
category = Category(sub_categories=[sub_category])
|
||||
catalog = Catalog(categories=[category])
|
||||
session.add(catalog)
|
||||
session.flush()
|
||||
return catalog
|
||||
|
||||
class SubCategory(self.Base):
|
||||
__tablename__ = 'sub_category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
products = sa.orm.relationship(
|
||||
'Product',
|
||||
backref='sub_categories',
|
||||
secondary=subcategory_product
|
||||
)
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
price = sa.Column(sa.Numeric)
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestObservesForManyToManyToManyToMany(object):
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.SubCategory = SubCategory
|
||||
self.Product = Product
|
||||
|
||||
def create_catalog(self):
|
||||
sub_category = self.SubCategory(products=[self.Product()])
|
||||
category = self.Category(sub_categories=[sub_category])
|
||||
catalog = self.Catalog(categories=[category])
|
||||
self.session.add(catalog)
|
||||
self.session.flush()
|
||||
return catalog
|
||||
|
||||
def test_simple_insert(self):
|
||||
catalog = self.create_catalog()
|
||||
def test_simple_insert(self, catalog):
|
||||
assert catalog.product_count == 1
|
||||
|
||||
def test_add_leaf_object(self):
|
||||
catalog = self.create_catalog()
|
||||
product = self.Product()
|
||||
def test_add_leaf_object(self, catalog, session, Product):
|
||||
product = Product()
|
||||
catalog.categories[0].sub_categories[0].products.append(product)
|
||||
self.session.flush()
|
||||
session.flush()
|
||||
assert catalog.product_count == 2
|
||||
|
||||
def test_remove_leaf_object(self):
|
||||
catalog = self.create_catalog()
|
||||
product = self.Product()
|
||||
def test_remove_leaf_object(self, catalog, session, Product):
|
||||
product = Product()
|
||||
catalog.categories[0].sub_categories[0].products.append(product)
|
||||
self.session.flush()
|
||||
self.session.delete(product)
|
||||
self.session.flush()
|
||||
session.flush()
|
||||
session.delete(product)
|
||||
session.flush()
|
||||
assert catalog.product_count == 1
|
||||
|
||||
def test_delete_intermediate_object(self):
|
||||
catalog = self.create_catalog()
|
||||
self.session.delete(catalog.categories[0].sub_categories[0])
|
||||
self.session.commit()
|
||||
def test_delete_intermediate_object(self, catalog, session):
|
||||
session.delete(catalog.categories[0].sub_categories[0])
|
||||
session.commit()
|
||||
assert catalog.product_count == 0
|
||||
|
||||
def test_gathered_objects_are_distinct(self):
|
||||
catalog = self.Catalog()
|
||||
category = self.Category(catalogs=[catalog])
|
||||
product = self.Product()
|
||||
def test_gathered_objects_are_distinct(
|
||||
self,
|
||||
session,
|
||||
Catalog,
|
||||
Category,
|
||||
SubCategory,
|
||||
Product
|
||||
):
|
||||
catalog = Catalog()
|
||||
category = Category(catalogs=[catalog])
|
||||
product = Product()
|
||||
category.sub_categories.append(
|
||||
self.SubCategory(products=[product])
|
||||
SubCategory(products=[product])
|
||||
)
|
||||
self.session.add(
|
||||
self.SubCategory(categories=[category], products=[product])
|
||||
session.add(
|
||||
SubCategory(categories=[category], products=[product])
|
||||
)
|
||||
self.session.commit()
|
||||
session.commit()
|
||||
assert catalog.product_count == 1
|
||||
|
|
|
@ -1,107 +1,127 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.observer import observes
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestObservesFor3LevelDeepOneToMany(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.fixture
|
||||
def Catalog(Base):
|
||||
class Catalog(Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
product_count = sa.Column(sa.Integer, default=0)
|
||||
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
product_count = sa.Column(sa.Integer, default=0)
|
||||
@observes('categories.sub_categories.products')
|
||||
def product_observer(self, products):
|
||||
self.product_count = len(products)
|
||||
|
||||
@observes('categories.sub_categories.products')
|
||||
def product_observer(self, products):
|
||||
self.product_count = len(products)
|
||||
categories = sa.orm.relationship('Category', backref='catalog')
|
||||
return Catalog
|
||||
|
||||
categories = sa.orm.relationship('Category', backref='catalog')
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
@pytest.fixture
|
||||
def Category(Base):
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
|
||||
sub_categories = sa.orm.relationship(
|
||||
'SubCategory', backref='category'
|
||||
)
|
||||
sub_categories = sa.orm.relationship(
|
||||
'SubCategory', backref='category'
|
||||
)
|
||||
return Category
|
||||
|
||||
class SubCategory(self.Base):
|
||||
__tablename__ = 'sub_category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||
products = sa.orm.relationship(
|
||||
'Product',
|
||||
backref='sub_category'
|
||||
)
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
price = sa.Column(sa.Numeric)
|
||||
@pytest.fixture
|
||||
def SubCategory(Base):
|
||||
class SubCategory(Base):
|
||||
__tablename__ = 'sub_category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||
products = sa.orm.relationship(
|
||||
'Product',
|
||||
backref='sub_category'
|
||||
)
|
||||
return SubCategory
|
||||
|
||||
sub_category_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey('sub_category.id')
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return '<Product id=%r>' % self.id
|
||||
@pytest.fixture
|
||||
def Product(Base):
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.SubCategory = SubCategory
|
||||
self.Product = Product
|
||||
sub_category_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey('sub_category.id')
|
||||
)
|
||||
|
||||
def create_catalog(self):
|
||||
sub_category = self.SubCategory(products=[self.Product()])
|
||||
category = self.Category(sub_categories=[sub_category])
|
||||
catalog = self.Catalog(categories=[category])
|
||||
self.session.add(catalog)
|
||||
self.session.commit()
|
||||
return catalog
|
||||
def __repr__(self):
|
||||
return '<Product id=%r>' % self.id
|
||||
return Product
|
||||
|
||||
def test_simple_insert(self):
|
||||
catalog = self.create_catalog()
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(Catalog, Category, SubCategory, Product):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def catalog(session, Catalog, Category, SubCategory, Product):
|
||||
sub_category = SubCategory(products=[Product()])
|
||||
category = Category(sub_categories=[sub_category])
|
||||
catalog = Catalog(categories=[category])
|
||||
session.add(catalog)
|
||||
session.commit()
|
||||
return catalog
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestObservesFor3LevelDeepOneToMany(object):
|
||||
|
||||
def test_simple_insert(self, catalog):
|
||||
assert catalog.product_count == 1
|
||||
|
||||
def test_add_leaf_object(self):
|
||||
catalog = self.create_catalog()
|
||||
product = self.Product()
|
||||
def test_add_leaf_object(self, catalog, session, Product):
|
||||
product = Product()
|
||||
catalog.categories[0].sub_categories[0].products.append(product)
|
||||
self.session.flush()
|
||||
session.flush()
|
||||
assert catalog.product_count == 2
|
||||
|
||||
def test_remove_leaf_object(self):
|
||||
catalog = self.create_catalog()
|
||||
product = self.Product()
|
||||
def test_remove_leaf_object(self, catalog, session, Product):
|
||||
product = Product()
|
||||
catalog.categories[0].sub_categories[0].products.append(product)
|
||||
self.session.flush()
|
||||
self.session.delete(product)
|
||||
self.session.commit()
|
||||
session.flush()
|
||||
session.delete(product)
|
||||
session.commit()
|
||||
assert catalog.product_count == 1
|
||||
self.session.delete(
|
||||
session.delete(
|
||||
catalog.categories[0].sub_categories[0].products[0]
|
||||
)
|
||||
self.session.commit()
|
||||
session.commit()
|
||||
assert catalog.product_count == 0
|
||||
|
||||
def test_delete_intermediate_object(self):
|
||||
catalog = self.create_catalog()
|
||||
self.session.delete(catalog.categories[0].sub_categories[0])
|
||||
self.session.commit()
|
||||
def test_delete_intermediate_object(self, catalog, session):
|
||||
session.delete(catalog.categories[0].sub_categories[0])
|
||||
session.commit()
|
||||
assert catalog.product_count == 0
|
||||
|
||||
def test_gathered_objects_are_distinct(self):
|
||||
catalog = self.Catalog()
|
||||
category = self.Category(catalog=catalog)
|
||||
product = self.Product()
|
||||
def test_gathered_objects_are_distinct(
|
||||
self,
|
||||
session,
|
||||
Catalog,
|
||||
Category,
|
||||
SubCategory,
|
||||
Product
|
||||
):
|
||||
catalog = Catalog()
|
||||
category = Category(catalog=catalog)
|
||||
product = Product()
|
||||
category.sub_categories.append(
|
||||
self.SubCategory(products=[product])
|
||||
SubCategory(products=[product])
|
||||
)
|
||||
self.session.add(
|
||||
self.SubCategory(category=category, products=[product])
|
||||
session.add(
|
||||
SubCategory(category=category, products=[product])
|
||||
)
|
||||
self.session.commit()
|
||||
session.commit()
|
||||
assert catalog.product_count == 1
|
||||
|
|
|
@ -1,96 +1,116 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.observer import observes
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestObservesForOneToManyToOneToMany(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.fixture
|
||||
def Catalog(Base):
|
||||
class Catalog(Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
product_count = sa.Column(sa.Integer, default=0)
|
||||
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
product_count = sa.Column(sa.Integer, default=0)
|
||||
@observes('categories.sub_category.products')
|
||||
def product_observer(self, products):
|
||||
self.product_count = len(products)
|
||||
|
||||
@observes('categories.sub_category.products')
|
||||
def product_observer(self, products):
|
||||
self.product_count = len(products)
|
||||
categories = sa.orm.relationship('Category', backref='catalog')
|
||||
return Catalog
|
||||
|
||||
categories = sa.orm.relationship('Category', backref='catalog')
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
@pytest.fixture
|
||||
def Category(Base):
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
|
||||
sub_category = sa.orm.relationship(
|
||||
'SubCategory',
|
||||
uselist=False,
|
||||
backref='category'
|
||||
)
|
||||
sub_category = sa.orm.relationship(
|
||||
'SubCategory',
|
||||
uselist=False,
|
||||
backref='category'
|
||||
)
|
||||
return Category
|
||||
|
||||
class SubCategory(self.Base):
|
||||
__tablename__ = 'sub_category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||
products = sa.orm.relationship('Product', backref='sub_category')
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
price = sa.Column(sa.Numeric)
|
||||
@pytest.fixture
|
||||
def SubCategory(Base):
|
||||
class SubCategory(Base):
|
||||
__tablename__ = 'sub_category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||
products = sa.orm.relationship('Product', backref='sub_category')
|
||||
return SubCategory
|
||||
|
||||
sub_category_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey('sub_category.id')
|
||||
)
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.SubCategory = SubCategory
|
||||
self.Product = Product
|
||||
@pytest.fixture
|
||||
def Product(Base):
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
def create_catalog(self):
|
||||
sub_category = self.SubCategory(products=[self.Product()])
|
||||
category = self.Category(sub_category=sub_category)
|
||||
catalog = self.Catalog(categories=[category])
|
||||
self.session.add(catalog)
|
||||
self.session.flush()
|
||||
return catalog
|
||||
sub_category_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey('sub_category.id')
|
||||
)
|
||||
return Product
|
||||
|
||||
def test_simple_insert(self):
|
||||
catalog = self.create_catalog()
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(Catalog, Category, SubCategory, Product):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def catalog(session, Catalog, Category, SubCategory, Product):
|
||||
sub_category = SubCategory(products=[Product()])
|
||||
category = Category(sub_category=sub_category)
|
||||
catalog = Catalog(categories=[category])
|
||||
session.add(catalog)
|
||||
session.flush()
|
||||
return catalog
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestObservesForOneToManyToOneToMany(object):
|
||||
|
||||
def test_simple_insert(self, catalog):
|
||||
assert catalog.product_count == 1
|
||||
|
||||
def test_add_leaf_object(self):
|
||||
catalog = self.create_catalog()
|
||||
product = self.Product()
|
||||
def test_add_leaf_object(self, catalog, session, Product):
|
||||
product = Product()
|
||||
catalog.categories[0].sub_category.products.append(product)
|
||||
self.session.flush()
|
||||
session.flush()
|
||||
assert catalog.product_count == 2
|
||||
|
||||
def test_remove_leaf_object(self):
|
||||
catalog = self.create_catalog()
|
||||
product = self.Product()
|
||||
def test_remove_leaf_object(self, catalog, session, Product):
|
||||
product = Product()
|
||||
catalog.categories[0].sub_category.products.append(product)
|
||||
self.session.flush()
|
||||
self.session.delete(product)
|
||||
self.session.flush()
|
||||
session.flush()
|
||||
session.delete(product)
|
||||
session.flush()
|
||||
assert catalog.product_count == 1
|
||||
|
||||
def test_delete_intermediate_object(self):
|
||||
catalog = self.create_catalog()
|
||||
self.session.delete(catalog.categories[0].sub_category)
|
||||
self.session.commit()
|
||||
def test_delete_intermediate_object(self, catalog, session):
|
||||
session.delete(catalog.categories[0].sub_category)
|
||||
session.commit()
|
||||
assert catalog.product_count == 0
|
||||
|
||||
def test_gathered_objects_are_distinct(self):
|
||||
catalog = self.Catalog()
|
||||
category = self.Category(catalog=catalog)
|
||||
product = self.Product()
|
||||
category.sub_category = self.SubCategory(products=[product])
|
||||
self.session.add(
|
||||
self.Category(catalog=catalog, sub_category=category.sub_category)
|
||||
def test_gathered_objects_are_distinct(
|
||||
self,
|
||||
session,
|
||||
Catalog,
|
||||
Category,
|
||||
SubCategory,
|
||||
Product
|
||||
):
|
||||
catalog = Catalog()
|
||||
category = Category(catalog=catalog)
|
||||
product = Product()
|
||||
category.sub_category = SubCategory(products=[product])
|
||||
session.add(
|
||||
Category(catalog=catalog, sub_category=category.sub_category)
|
||||
)
|
||||
self.session.commit()
|
||||
session.commit()
|
||||
assert catalog.product_count == 1
|
||||
|
|
|
@ -1,53 +1,66 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.observer import observes
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestObservesForOneToManyToOneToMany(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.fixture
|
||||
def Device(Base):
|
||||
class Device(Base):
|
||||
__tablename__ = 'device'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
return Device
|
||||
|
||||
def create_models(self):
|
||||
class Device(self.Base):
|
||||
__tablename__ = 'device'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
|
||||
class Order(self.Base):
|
||||
__tablename__ = 'order'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
@pytest.fixture
|
||||
def Order(Base):
|
||||
class Order(Base):
|
||||
__tablename__ = 'order'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
device_id = sa.Column(
|
||||
'device', sa.ForeignKey('device.id'), nullable=False
|
||||
device_id = sa.Column(
|
||||
'device', sa.ForeignKey('device.id'), nullable=False
|
||||
)
|
||||
device = sa.orm.relationship('Device', backref='orders')
|
||||
return Order
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def SalesInvoice(Base):
|
||||
class SalesInvoice(Base):
|
||||
__tablename__ = 'sales_invoice'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
order_id = sa.Column(
|
||||
'order',
|
||||
sa.ForeignKey('order.id'),
|
||||
nullable=False
|
||||
)
|
||||
order = sa.orm.relationship(
|
||||
'Order',
|
||||
backref=sa.orm.backref(
|
||||
'invoice',
|
||||
uselist=False
|
||||
)
|
||||
device = sa.orm.relationship('Device', backref='orders')
|
||||
)
|
||||
device_name = sa.Column(sa.String)
|
||||
|
||||
class SalesInvoice(self.Base):
|
||||
__tablename__ = 'sales_invoice'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
order_id = sa.Column(
|
||||
'order',
|
||||
sa.ForeignKey('order.id'),
|
||||
nullable=False
|
||||
)
|
||||
order = sa.orm.relationship(
|
||||
'Order',
|
||||
backref=sa.orm.backref(
|
||||
'invoice',
|
||||
uselist=False
|
||||
)
|
||||
)
|
||||
device_name = sa.Column(sa.String)
|
||||
@observes('order.device')
|
||||
def process_device(self, device):
|
||||
self.device_name = device.name
|
||||
|
||||
@observes('order.device')
|
||||
def process_device(self, device):
|
||||
self.device_name = device.name
|
||||
return SalesInvoice
|
||||
|
||||
self.Device = Device
|
||||
self.Order = Order
|
||||
self.SalesInvoice = SalesInvoice
|
||||
|
||||
def test_observable_root_obj_is_none(self):
|
||||
order = self.Order(device=self.Device(name='Something'))
|
||||
self.session.add(order)
|
||||
self.session.flush()
|
||||
@pytest.fixture
|
||||
def init_models(Device, Order, SalesInvoice):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestObservesForOneToManyToOneToMany(object):
|
||||
|
||||
def test_observable_root_obj_is_none(self, session, Device, Order):
|
||||
order = Order(device=Device(name='Something'))
|
||||
session.add(order)
|
||||
session.flush()
|
||||
|
|
|
@ -1,84 +1,98 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.observer import observes
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestObservesForOneToOneToOneToOne(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.fixture
|
||||
def Catalog(Base):
|
||||
class Catalog(Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
product_price = sa.Column(sa.Integer)
|
||||
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
product_price = sa.Column(sa.Integer)
|
||||
@observes('category.sub_category.product')
|
||||
def product_observer(self, product):
|
||||
self.product_price = product.price if product else None
|
||||
|
||||
@observes('category.sub_category.product')
|
||||
def product_observer(self, product):
|
||||
self.product_price = product.price if product else None
|
||||
category = sa.orm.relationship(
|
||||
'Category',
|
||||
uselist=False,
|
||||
backref='catalog'
|
||||
)
|
||||
return Catalog
|
||||
|
||||
category = sa.orm.relationship(
|
||||
'Category',
|
||||
uselist=False,
|
||||
backref='catalog'
|
||||
)
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
@pytest.fixture
|
||||
def Category(Base):
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
|
||||
sub_category = sa.orm.relationship(
|
||||
'SubCategory',
|
||||
uselist=False,
|
||||
backref='category'
|
||||
)
|
||||
sub_category = sa.orm.relationship(
|
||||
'SubCategory',
|
||||
uselist=False,
|
||||
backref='category'
|
||||
)
|
||||
return Category
|
||||
|
||||
class SubCategory(self.Base):
|
||||
__tablename__ = 'sub_category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||
product = sa.orm.relationship(
|
||||
'Product',
|
||||
uselist=False,
|
||||
backref='sub_category'
|
||||
)
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
price = sa.Column(sa.Integer)
|
||||
@pytest.fixture
|
||||
def SubCategory(Base):
|
||||
class SubCategory(Base):
|
||||
__tablename__ = 'sub_category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||
product = sa.orm.relationship(
|
||||
'Product',
|
||||
uselist=False,
|
||||
backref='sub_category'
|
||||
)
|
||||
return SubCategory
|
||||
|
||||
sub_category_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey('sub_category.id')
|
||||
)
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.SubCategory = SubCategory
|
||||
self.Product = Product
|
||||
@pytest.fixture
|
||||
def Product(Base):
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
price = sa.Column(sa.Integer)
|
||||
|
||||
def create_catalog(self):
|
||||
sub_category = self.SubCategory(product=self.Product(price=123))
|
||||
category = self.Category(sub_category=sub_category)
|
||||
catalog = self.Catalog(category=category)
|
||||
self.session.add(catalog)
|
||||
self.session.flush()
|
||||
return catalog
|
||||
sub_category_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey('sub_category.id')
|
||||
)
|
||||
return Product
|
||||
|
||||
def test_simple_insert(self):
|
||||
catalog = self.create_catalog()
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(Catalog, Category, SubCategory, Product):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def catalog(session, Catalog, Category, SubCategory, Product):
|
||||
sub_category = SubCategory(product=Product(price=123))
|
||||
category = Category(sub_category=sub_category)
|
||||
catalog = Catalog(category=category)
|
||||
session.add(catalog)
|
||||
session.flush()
|
||||
return catalog
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestObservesForOneToOneToOneToOne(object):
|
||||
|
||||
def test_simple_insert(self, catalog):
|
||||
assert catalog.product_price == 123
|
||||
|
||||
def test_replace_leaf_object(self):
|
||||
catalog = self.create_catalog()
|
||||
product = self.Product(price=44)
|
||||
def test_replace_leaf_object(self, catalog, session, Product):
|
||||
product = Product(price=44)
|
||||
catalog.category.sub_category.product = product
|
||||
self.session.flush()
|
||||
session.flush()
|
||||
assert catalog.product_price == 44
|
||||
|
||||
def test_delete_leaf_object(self):
|
||||
catalog = self.create_catalog()
|
||||
self.session.delete(catalog.category.sub_category.product)
|
||||
self.session.flush()
|
||||
def test_delete_leaf_object(self, catalog, session):
|
||||
session.delete(catalog.category.sub_category.product)
|
||||
session.flush()
|
||||
assert catalog.product_price is None
|
||||
|
|
|
@ -1,32 +1,36 @@
|
|||
import pytest
|
||||
import six
|
||||
from pytest import mark, raises
|
||||
|
||||
from sqlalchemy_utils import Country, i18n
|
||||
|
||||
|
||||
@mark.skipif('i18n.babel is None')
|
||||
@pytest.fixture
|
||||
def set_get_locale():
|
||||
i18n.get_locale = lambda: i18n.babel.Locale('en')
|
||||
|
||||
|
||||
@pytest.mark.skipif('i18n.babel is None')
|
||||
@pytest.mark.usefixtures('set_get_locale')
|
||||
class TestCountry(object):
|
||||
def setup_method(self, method):
|
||||
i18n.get_locale = lambda: i18n.babel.Locale('en')
|
||||
|
||||
def test_init(self):
|
||||
assert Country(u'FI') == Country(Country(u'FI'))
|
||||
|
||||
def test_constructor_with_wrong_type(self):
|
||||
with raises(TypeError) as e:
|
||||
with pytest.raises(TypeError) as e:
|
||||
Country(None)
|
||||
assert str(e.value) == (
|
||||
"Country() argument must be a string or a country, not 'NoneType'"
|
||||
)
|
||||
|
||||
def test_constructor_with_invalid_code(self):
|
||||
with raises(ValueError) as e:
|
||||
with pytest.raises(ValueError) as e:
|
||||
Country('SomeUnknownCode')
|
||||
assert str(e.value) == (
|
||||
'Could not convert string to country code: SomeUnknownCode'
|
||||
)
|
||||
|
||||
@mark.parametrize(
|
||||
@pytest.mark.parametrize(
|
||||
'code',
|
||||
(
|
||||
'FI',
|
||||
|
@ -37,7 +41,7 @@ class TestCountry(object):
|
|||
Country.validate(code)
|
||||
|
||||
def test_validate_with_invalid_code(self):
|
||||
with raises(ValueError) as e:
|
||||
with pytest.raises(ValueError) as e:
|
||||
Country.validate('SomeUnknownCode')
|
||||
assert str(e.value) == (
|
||||
'Could not convert string to country code: SomeUnknownCode'
|
||||
|
|
|
@ -1,14 +1,18 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import pytest
|
||||
import six
|
||||
from pytest import mark, raises
|
||||
|
||||
from sqlalchemy_utils import Currency, i18n
|
||||
|
||||
|
||||
@mark.skipif('i18n.babel is None')
|
||||
@pytest.fixture
|
||||
def set_get_locale():
|
||||
i18n.get_locale = lambda: i18n.babel.Locale('en')
|
||||
|
||||
|
||||
@pytest.mark.skipif('i18n.babel is None')
|
||||
@pytest.mark.usefixtures('set_get_locale')
|
||||
class TestCurrency(object):
|
||||
def setup_method(self, method):
|
||||
i18n.get_locale = lambda: i18n.babel.Locale('en')
|
||||
|
||||
def test_init(self):
|
||||
assert Currency('USD') == Currency(Currency('USD'))
|
||||
|
@ -17,14 +21,14 @@ class TestCurrency(object):
|
|||
assert len(set([Currency('USD'), Currency('USD')])) == 1
|
||||
|
||||
def test_invalid_currency_code(self):
|
||||
with raises(ValueError):
|
||||
with pytest.raises(ValueError):
|
||||
Currency('Unknown code')
|
||||
|
||||
def test_invalid_currency_code_type(self):
|
||||
with raises(TypeError):
|
||||
with pytest.raises(TypeError):
|
||||
Currency(None)
|
||||
|
||||
@mark.parametrize(
|
||||
@pytest.mark.parametrize(
|
||||
('code', 'name'),
|
||||
(
|
||||
('USD', 'US Dollar'),
|
||||
|
@ -34,7 +38,7 @@ class TestCurrency(object):
|
|||
def test_name_property(self, code, name):
|
||||
assert Currency(code).name == name
|
||||
|
||||
@mark.parametrize(
|
||||
@pytest.mark.parametrize(
|
||||
('code', 'symbol'),
|
||||
(
|
||||
('USD', u'$'),
|
||||
|
|
|
@ -6,10 +6,14 @@ from sqlalchemy_utils import i18n
|
|||
from sqlalchemy_utils.primitives import WeekDay, WeekDays
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def set_get_locale():
|
||||
i18n.get_locale = lambda: i18n.babel.Locale('fi')
|
||||
|
||||
|
||||
@pytest.mark.skipif('i18n.babel is None')
|
||||
@pytest.mark.usefixtures('set_get_locale')
|
||||
class TestWeekDay(object):
|
||||
def setup_method(self, method):
|
||||
i18n.get_locale = lambda: i18n.babel.Locale('fi')
|
||||
|
||||
def test_constructor_with_valid_index(self):
|
||||
day = WeekDay(1)
|
||||
|
|
|
@ -1,26 +1,27 @@
|
|||
import pytest
|
||||
|
||||
from sqlalchemy_utils.relationships import chained_join
|
||||
from tests import TestCase
|
||||
from tests.mixins import (
|
||||
|
||||
from ..mixins import (
|
||||
ThreeLevelDeepManyToMany,
|
||||
ThreeLevelDeepOneToMany,
|
||||
ThreeLevelDeepOneToOne
|
||||
)
|
||||
|
||||
|
||||
class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
create_tables = False
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany):
|
||||
|
||||
def test_simple_join(self):
|
||||
assert str(chained_join(self.Catalog.categories)) == (
|
||||
def test_simple_join(self, Catalog):
|
||||
assert str(chained_join(Catalog.categories)) == (
|
||||
'catalog_category JOIN category ON '
|
||||
'category._id = catalog_category.category_id'
|
||||
)
|
||||
|
||||
def test_two_relations(self):
|
||||
def test_two_relations(self, Catalog, Category):
|
||||
sql = chained_join(
|
||||
self.Catalog.categories,
|
||||
self.Category.sub_categories
|
||||
Catalog.categories,
|
||||
Category.sub_categories
|
||||
)
|
||||
assert str(sql) == (
|
||||
'catalog_category JOIN category ON category._id = '
|
||||
|
@ -30,11 +31,11 @@ class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase):
|
|||
'category_subcategory.subcategory_id'
|
||||
)
|
||||
|
||||
def test_three_relations(self):
|
||||
def test_three_relations(self, Catalog, Category, SubCategory):
|
||||
sql = chained_join(
|
||||
self.Catalog.categories,
|
||||
self.Category.sub_categories,
|
||||
self.SubCategory.products
|
||||
Catalog.categories,
|
||||
Category.sub_categories,
|
||||
SubCategory.products
|
||||
)
|
||||
assert str(sql) == (
|
||||
'catalog_category JOIN category ON category._id = '
|
||||
|
@ -47,28 +48,27 @@ class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase):
|
|||
)
|
||||
|
||||
|
||||
class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany, TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
create_tables = False
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany):
|
||||
|
||||
def test_simple_join(self):
|
||||
assert str(chained_join(self.Catalog.categories)) == 'category'
|
||||
def test_simple_join(self, Catalog):
|
||||
assert str(chained_join(Catalog.categories)) == 'category'
|
||||
|
||||
def test_two_relations(self):
|
||||
def test_two_relations(self, Catalog, Category):
|
||||
sql = chained_join(
|
||||
self.Catalog.categories,
|
||||
self.Category.sub_categories
|
||||
Catalog.categories,
|
||||
Category.sub_categories
|
||||
)
|
||||
assert str(sql) == (
|
||||
'category JOIN sub_category ON category._id = '
|
||||
'sub_category._category_id'
|
||||
)
|
||||
|
||||
def test_three_relations(self):
|
||||
def test_three_relations(self, Catalog, Category, SubCategory):
|
||||
sql = chained_join(
|
||||
self.Catalog.categories,
|
||||
self.Category.sub_categories,
|
||||
self.SubCategory.products
|
||||
Catalog.categories,
|
||||
Category.sub_categories,
|
||||
SubCategory.products
|
||||
)
|
||||
assert str(sql) == (
|
||||
'category JOIN sub_category ON category._id = '
|
||||
|
@ -77,28 +77,27 @@ class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany, TestCase):
|
|||
)
|
||||
|
||||
|
||||
class TestChainedJoinForDeepOneToOne(ThreeLevelDeepOneToOne, TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
create_tables = False
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestChainedJoinForDeepOneToOne(ThreeLevelDeepOneToOne):
|
||||
|
||||
def test_simple_join(self):
|
||||
assert str(chained_join(self.Catalog.category)) == 'category'
|
||||
def test_simple_join(self, Catalog):
|
||||
assert str(chained_join(Catalog.category)) == 'category'
|
||||
|
||||
def test_two_relations(self):
|
||||
def test_two_relations(self, Catalog, Category):
|
||||
sql = chained_join(
|
||||
self.Catalog.category,
|
||||
self.Category.sub_category
|
||||
Catalog.category,
|
||||
Category.sub_category
|
||||
)
|
||||
assert str(sql) == (
|
||||
'category JOIN sub_category ON category._id = '
|
||||
'sub_category._category_id'
|
||||
)
|
||||
|
||||
def test_three_relations(self):
|
||||
def test_three_relations(self, Catalog, Category, SubCategory):
|
||||
sql = chained_join(
|
||||
self.Catalog.category,
|
||||
self.Category.sub_category,
|
||||
self.SubCategory.product
|
||||
Catalog.category,
|
||||
Category.sub_category,
|
||||
SubCategory.product
|
||||
)
|
||||
assert str(sql) == (
|
||||
'category JOIN sub_category ON category._id = '
|
||||
|
|
|
@ -1,31 +1,23 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from sqlalchemy_utils.relationships import select_correlated_expression
|
||||
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def base():
|
||||
return declarative_base()
|
||||
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def group_user_cls(base):
|
||||
@pytest.fixture
|
||||
def group_user_tbl(Base):
|
||||
return sa.Table(
|
||||
'group_user',
|
||||
base.metadata,
|
||||
Base.metadata,
|
||||
sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')),
|
||||
sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id'))
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def group_cls(base):
|
||||
class Group(base):
|
||||
@pytest.fixture
|
||||
def group_tbl(Base):
|
||||
class Group(Base):
|
||||
__tablename__ = 'group'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
|
@ -33,11 +25,11 @@ def group_cls(base):
|
|||
return Group
|
||||
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def friendship_cls(base):
|
||||
@pytest.fixture
|
||||
def friendship_tbl(Base):
|
||||
return sa.Table(
|
||||
'friendships',
|
||||
base.metadata,
|
||||
Base.metadata,
|
||||
sa.Column(
|
||||
'friend_a_id',
|
||||
sa.Integer,
|
||||
|
@ -53,35 +45,37 @@ def friendship_cls(base):
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def user_cls(base, group_user_cls, friendship_cls):
|
||||
class User(base):
|
||||
@pytest.fixture
|
||||
def User(Base, group_user_tbl, friendship_tbl):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
groups = sa.orm.relationship(
|
||||
'Group',
|
||||
secondary=group_user_cls,
|
||||
secondary=group_user_tbl,
|
||||
backref='users'
|
||||
)
|
||||
|
||||
# this relationship is used for persistence
|
||||
friends = sa.orm.relationship(
|
||||
'User',
|
||||
secondary=friendship_cls,
|
||||
primaryjoin=id == friendship_cls.c.friend_a_id,
|
||||
secondaryjoin=id == friendship_cls.c.friend_b_id,
|
||||
secondary=friendship_tbl,
|
||||
primaryjoin=id == friendship_tbl.c.friend_a_id,
|
||||
secondaryjoin=id == friendship_tbl.c.friend_b_id,
|
||||
)
|
||||
|
||||
friendship_union = sa.select([
|
||||
friendship_cls.c.friend_a_id,
|
||||
friendship_cls.c.friend_b_id
|
||||
friendship_union = (
|
||||
sa.select([
|
||||
friendship_tbl.c.friend_a_id,
|
||||
friendship_tbl.c.friend_b_id
|
||||
]).union(
|
||||
sa.select([
|
||||
friendship_cls.c.friend_b_id,
|
||||
friendship_cls.c.friend_a_id]
|
||||
friendship_tbl.c.friend_b_id,
|
||||
friendship_tbl.c.friend_a_id]
|
||||
)
|
||||
).alias()
|
||||
).alias()
|
||||
)
|
||||
|
||||
User.all_friends = sa.orm.relationship(
|
||||
'User',
|
||||
|
@ -94,9 +88,9 @@ def user_cls(base, group_user_cls, friendship_cls):
|
|||
return User
|
||||
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def category_cls(base, group_user_cls, friendship_cls):
|
||||
class Category(base):
|
||||
@pytest.fixture
|
||||
def Category(Base, group_user_tbl, friendship_tbl):
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
|
@ -111,9 +105,9 @@ def category_cls(base, group_user_cls, friendship_cls):
|
|||
return Category
|
||||
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def article_cls(base, category_cls, user_cls):
|
||||
class Article(base):
|
||||
@pytest.fixture
|
||||
def Article(Base, Category, User):
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
|
@ -129,144 +123,104 @@ def article_cls(base, category_cls, user_cls):
|
|||
|
||||
content = sa.Column(sa.String)
|
||||
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey(category_cls.id))
|
||||
category = sa.orm.relationship(category_cls, backref='articles')
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
|
||||
category = sa.orm.relationship(Category, backref='articles')
|
||||
|
||||
author_id = sa.Column(sa.Integer, sa.ForeignKey(user_cls.id))
|
||||
author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
|
||||
author = sa.orm.relationship(
|
||||
user_cls,
|
||||
primaryjoin=author_id == user_cls.id,
|
||||
User,
|
||||
primaryjoin=author_id == User.id,
|
||||
backref='authored_articles'
|
||||
)
|
||||
|
||||
owner_id = sa.Column(sa.Integer, sa.ForeignKey(user_cls.id))
|
||||
owner_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
|
||||
owner = sa.orm.relationship(
|
||||
user_cls,
|
||||
primaryjoin=owner_id == user_cls.id,
|
||||
User,
|
||||
primaryjoin=owner_id == User.id,
|
||||
backref='owned_articles'
|
||||
)
|
||||
return Article
|
||||
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def comment_cls(base, article_cls, user_cls):
|
||||
class Comment(base):
|
||||
@pytest.fixture
|
||||
def Comment(Base, Article, User):
|
||||
class Comment(Base):
|
||||
__tablename__ = 'comment'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
content = sa.Column(sa.String)
|
||||
article_id = sa.Column(sa.Integer, sa.ForeignKey(article_cls.id))
|
||||
article = sa.orm.relationship(article_cls, backref='comments')
|
||||
article_id = sa.Column(sa.Integer, sa.ForeignKey(Article.id))
|
||||
article = sa.orm.relationship(Article, backref='comments')
|
||||
|
||||
author_id = sa.Column(sa.Integer, sa.ForeignKey(user_cls.id))
|
||||
author = sa.orm.relationship(user_cls, backref='comments')
|
||||
author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
|
||||
author = sa.orm.relationship(User, backref='comments')
|
||||
|
||||
article_cls.comment_count = sa.orm.column_property(
|
||||
Article.comment_count = sa.orm.column_property(
|
||||
sa.select([sa.func.count(Comment.id)])
|
||||
.where(Comment.article_id == article_cls.id)
|
||||
.correlate_except(article_cls)
|
||||
.where(Comment.article_id == Article.id)
|
||||
.correlate_except(Article)
|
||||
)
|
||||
|
||||
return Comment
|
||||
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def composite_pk_cls(base):
|
||||
class CompositePKModel(base):
|
||||
__tablename__ = 'composite_pk_model'
|
||||
a = sa.Column(sa.Integer, primary_key=True)
|
||||
b = sa.Column(sa.Integer, primary_key=True)
|
||||
return CompositePKModel
|
||||
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def dns():
|
||||
return 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
|
||||
|
||||
@pytest.yield_fixture(scope='class')
|
||||
def engine(dns):
|
||||
engine = create_engine(dns)
|
||||
engine.echo = True
|
||||
yield engine
|
||||
engine.dispose()
|
||||
|
||||
|
||||
@pytest.yield_fixture(scope='class')
|
||||
def connection(engine):
|
||||
conn = engine.connect()
|
||||
yield conn
|
||||
conn.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def model_mapping(article_cls, category_cls, comment_cls, group_cls, user_cls):
|
||||
@pytest.fixture
|
||||
def model_mapping(Article, Category, Comment, group_tbl, User):
|
||||
return {
|
||||
'articles': article_cls,
|
||||
'categories': category_cls,
|
||||
'comments': comment_cls,
|
||||
'groups': group_cls,
|
||||
'users': user_cls
|
||||
'articles': Article,
|
||||
'categories': Category,
|
||||
'comments': Comment,
|
||||
'groups': group_tbl,
|
||||
'users': User
|
||||
}
|
||||
|
||||
|
||||
@pytest.yield_fixture(scope='class')
|
||||
def table_creator(base, connection, model_mapping):
|
||||
sa.orm.configure_mappers()
|
||||
base.metadata.create_all(connection)
|
||||
yield
|
||||
base.metadata.drop_all(connection)
|
||||
@pytest.fixture
|
||||
def init_models(Article, Category, Comment, group_tbl, User):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.yield_fixture(scope='class')
|
||||
def session(connection):
|
||||
Session = sessionmaker(bind=connection)
|
||||
session = Session()
|
||||
yield session
|
||||
session.close_all()
|
||||
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
@pytest.fixture
|
||||
def dataset(
|
||||
session,
|
||||
user_cls,
|
||||
group_cls,
|
||||
article_cls,
|
||||
category_cls,
|
||||
comment_cls
|
||||
User,
|
||||
group_tbl,
|
||||
Article,
|
||||
Category,
|
||||
Comment
|
||||
):
|
||||
group = group_cls(name='Group 1')
|
||||
group2 = group_cls(name='Group 2')
|
||||
user = user_cls(id=1, name='User 1', groups=[group, group2])
|
||||
user2 = user_cls(id=2, name='User 2')
|
||||
user3 = user_cls(id=3, name='User 3', groups=[group])
|
||||
user4 = user_cls(id=4, name='User 4', groups=[group2])
|
||||
user5 = user_cls(id=5, name='User 5')
|
||||
group = group_tbl(name='Group 1')
|
||||
group2 = group_tbl(name='Group 2')
|
||||
user = User(id=1, name='User 1', groups=[group, group2])
|
||||
user2 = User(id=2, name='User 2')
|
||||
user3 = User(id=3, name='User 3', groups=[group])
|
||||
user4 = User(id=4, name='User 4', groups=[group2])
|
||||
user5 = User(id=5, name='User 5')
|
||||
|
||||
user.friends = [user2]
|
||||
user2.friends = [user3, user4]
|
||||
user3.friends = [user5]
|
||||
|
||||
article = article_cls(
|
||||
article = Article(
|
||||
name='Some article',
|
||||
author=user,
|
||||
owner=user2,
|
||||
category=category_cls(
|
||||
category=Category(
|
||||
id=1,
|
||||
name='Some category',
|
||||
subcategories=[
|
||||
category_cls(
|
||||
Category(
|
||||
id=2,
|
||||
name='Subcategory 1',
|
||||
subcategories=[
|
||||
category_cls(
|
||||
Category(
|
||||
id=3,
|
||||
name='Subsubcategory 1',
|
||||
subcategories=[
|
||||
category_cls(
|
||||
Category(
|
||||
id=5,
|
||||
name='Subsubsubcategory 1',
|
||||
),
|
||||
category_cls(
|
||||
Category(
|
||||
id=6,
|
||||
name='Subsubsubcategory 2',
|
||||
)
|
||||
|
@ -274,11 +228,11 @@ def dataset(
|
|||
)
|
||||
]
|
||||
),
|
||||
category_cls(id=4, name='Subcategory 2'),
|
||||
Category(id=4, name='Subcategory 2'),
|
||||
]
|
||||
),
|
||||
comments=[
|
||||
comment_cls(
|
||||
Comment(
|
||||
content='Some comment',
|
||||
author=user
|
||||
)
|
||||
|
@ -290,7 +244,7 @@ def dataset(
|
|||
session.commit()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('table_creator', 'dataset')
|
||||
@pytest.mark.usefixtures('dataset', 'postgresql_dsn')
|
||||
class TestSelectCorrelatedExpression(object):
|
||||
@pytest.mark.parametrize(
|
||||
('model_key', 'related_model_key', 'path', 'result'),
|
||||
|
@ -428,20 +382,20 @@ class TestSelectCorrelatedExpression(object):
|
|||
def test_with_non_aggregate_function(
|
||||
self,
|
||||
session,
|
||||
user_cls,
|
||||
article_cls
|
||||
User,
|
||||
Article
|
||||
):
|
||||
aggregate = select_correlated_expression(
|
||||
article_cls,
|
||||
sa.func.json_build_object('name', user_cls.name),
|
||||
Article,
|
||||
sa.func.json_build_object('name', User.name),
|
||||
'comments.author',
|
||||
user_cls
|
||||
User
|
||||
)
|
||||
|
||||
query = session.query(
|
||||
article_cls.id,
|
||||
Article.id,
|
||||
aggregate.label('author_json')
|
||||
).order_by(article_cls.id)
|
||||
).order_by(Article.id)
|
||||
result = query.all()
|
||||
assert result == [
|
||||
(1, {'name': 'User 1'})
|
||||
|
|
|
@ -9,143 +9,152 @@ from sqlalchemy_utils import (
|
|||
assert_non_nullable,
|
||||
assert_nullable
|
||||
)
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class AssertionTestCase(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
@pytest.fixture()
|
||||
def User(Base):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
name = sa.Column('_name', sa.String(20))
|
||||
age = sa.Column('_age', sa.Integer, nullable=False)
|
||||
email = sa.Column(
|
||||
'_email', sa.String(200), nullable=False, unique=True
|
||||
)
|
||||
fav_numbers = sa.Column('_fav_numbers', ARRAY(sa.Integer))
|
||||
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||
name = sa.Column('_name', sa.String(20))
|
||||
age = sa.Column('_age', sa.Integer, nullable=False)
|
||||
email = sa.Column(
|
||||
'_email', sa.String(200), nullable=False, unique=True
|
||||
)
|
||||
fav_numbers = sa.Column('_fav_numbers', ARRAY(sa.Integer))
|
||||
|
||||
__table_args__ = (
|
||||
sa.CheckConstraint(sa.and_(age >= 0, age <= 150)),
|
||||
sa.CheckConstraint(
|
||||
sa.and_(
|
||||
sa.func.array_length(fav_numbers, 1) <= 8
|
||||
)
|
||||
__table_args__ = (
|
||||
sa.CheckConstraint(sa.and_(age >= 0, age <= 150)),
|
||||
sa.CheckConstraint(
|
||||
sa.and_(
|
||||
sa.func.array_length(fav_numbers, 1) <= 8
|
||||
)
|
||||
)
|
||||
|
||||
self.User = User
|
||||
|
||||
def setup_method(self, method):
|
||||
TestCase.setup_method(self, method)
|
||||
user = self.User(
|
||||
name='Someone',
|
||||
email='someone@example.com',
|
||||
age=15,
|
||||
fav_numbers=[1, 2, 3]
|
||||
)
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
self.user = user
|
||||
return User
|
||||
|
||||
|
||||
class TestAssertMaxLengthWithArray(AssertionTestCase):
|
||||
def test_with_max_length(self):
|
||||
assert_max_length(self.user, 'fav_numbers', 8)
|
||||
assert_max_length(self.user, 'fav_numbers', 8)
|
||||
@pytest.fixture()
|
||||
def user(User, session):
|
||||
user = User(
|
||||
name='Someone',
|
||||
email='someone@example.com',
|
||||
age=15,
|
||||
fav_numbers=[1, 2, 3]
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
return user
|
||||
|
||||
def test_smaller_than_max_length(self):
|
||||
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestAssertMaxLengthWithArray(object):
|
||||
|
||||
def test_with_max_length(self, user):
|
||||
assert_max_length(user, 'fav_numbers', 8)
|
||||
assert_max_length(user, 'fav_numbers', 8)
|
||||
|
||||
def test_smaller_than_max_length(self, user):
|
||||
with pytest.raises(AssertionError):
|
||||
assert_max_length(self.user, 'fav_numbers', 7)
|
||||
assert_max_length(user, 'fav_numbers', 7)
|
||||
with pytest.raises(AssertionError):
|
||||
assert_max_length(self.user, 'fav_numbers', 7)
|
||||
assert_max_length(user, 'fav_numbers', 7)
|
||||
|
||||
def test_bigger_than_max_length(self):
|
||||
def test_bigger_than_max_length(self, user):
|
||||
with pytest.raises(AssertionError):
|
||||
assert_max_length(self.user, 'fav_numbers', 9)
|
||||
assert_max_length(user, 'fav_numbers', 9)
|
||||
with pytest.raises(AssertionError):
|
||||
assert_max_length(self.user, 'fav_numbers', 9)
|
||||
assert_max_length(user, 'fav_numbers', 9)
|
||||
|
||||
|
||||
class TestAssertNonNullable(AssertionTestCase):
|
||||
def test_non_nullable_column(self):
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestAssertNonNullable(object):
|
||||
|
||||
def test_non_nullable_column(self, user):
|
||||
# Test everything twice so that session gets rolled back properly
|
||||
assert_non_nullable(self.user, 'age')
|
||||
assert_non_nullable(self.user, 'age')
|
||||
assert_non_nullable(user, 'age')
|
||||
assert_non_nullable(user, 'age')
|
||||
|
||||
def test_nullable_column(self):
|
||||
def test_nullable_column(self, user):
|
||||
with pytest.raises(AssertionError):
|
||||
assert_non_nullable(self.user, 'name')
|
||||
assert_non_nullable(user, 'name')
|
||||
with pytest.raises(AssertionError):
|
||||
assert_non_nullable(self.user, 'name')
|
||||
assert_non_nullable(user, 'name')
|
||||
|
||||
|
||||
class TestAssertNullable(AssertionTestCase):
|
||||
def test_nullable_column(self):
|
||||
assert_nullable(self.user, 'name')
|
||||
assert_nullable(self.user, 'name')
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestAssertNullable(object):
|
||||
|
||||
def test_non_nullable_column(self):
|
||||
def test_nullable_column(self, user):
|
||||
assert_nullable(user, 'name')
|
||||
assert_nullable(user, 'name')
|
||||
|
||||
def test_non_nullable_column(self, user):
|
||||
with pytest.raises(AssertionError):
|
||||
assert_nullable(self.user, 'age')
|
||||
assert_nullable(user, 'age')
|
||||
with pytest.raises(AssertionError):
|
||||
assert_nullable(self.user, 'age')
|
||||
assert_nullable(user, 'age')
|
||||
|
||||
|
||||
class TestAssertMaxLength(AssertionTestCase):
|
||||
def test_with_max_length(self):
|
||||
assert_max_length(self.user, 'name', 20)
|
||||
assert_max_length(self.user, 'name', 20)
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestAssertMaxLength(object):
|
||||
|
||||
def test_with_non_nullable_column(self):
|
||||
assert_max_length(self.user, 'email', 200)
|
||||
assert_max_length(self.user, 'email', 200)
|
||||
def test_with_max_length(self, user):
|
||||
assert_max_length(user, 'name', 20)
|
||||
assert_max_length(user, 'name', 20)
|
||||
|
||||
def test_smaller_than_max_length(self):
|
||||
with pytest.raises(AssertionError):
|
||||
assert_max_length(self.user, 'name', 19)
|
||||
with pytest.raises(AssertionError):
|
||||
assert_max_length(self.user, 'name', 19)
|
||||
def test_with_non_nullable_column(self, user):
|
||||
assert_max_length(user, 'email', 200)
|
||||
assert_max_length(user, 'email', 200)
|
||||
|
||||
def test_bigger_than_max_length(self):
|
||||
def test_smaller_than_max_length(self, user):
|
||||
with pytest.raises(AssertionError):
|
||||
assert_max_length(self.user, 'name', 21)
|
||||
assert_max_length(user, 'name', 19)
|
||||
with pytest.raises(AssertionError):
|
||||
assert_max_length(self.user, 'name', 21)
|
||||
assert_max_length(user, 'name', 19)
|
||||
|
||||
def test_bigger_than_max_length(self, user):
|
||||
with pytest.raises(AssertionError):
|
||||
assert_max_length(user, 'name', 21)
|
||||
with pytest.raises(AssertionError):
|
||||
assert_max_length(user, 'name', 21)
|
||||
|
||||
|
||||
class TestAssertMinValue(AssertionTestCase):
|
||||
def test_with_min_value(self):
|
||||
assert_min_value(self.user, 'age', 0)
|
||||
assert_min_value(self.user, 'age', 0)
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestAssertMinValue(object):
|
||||
|
||||
def test_smaller_than_min_value(self):
|
||||
with pytest.raises(AssertionError):
|
||||
assert_min_value(self.user, 'age', -1)
|
||||
with pytest.raises(AssertionError):
|
||||
assert_min_value(self.user, 'age', -1)
|
||||
def test_with_min_value(self, user):
|
||||
assert_min_value(user, 'age', 0)
|
||||
assert_min_value(user, 'age', 0)
|
||||
|
||||
def test_bigger_than_min_value(self):
|
||||
def test_smaller_than_min_value(self, user):
|
||||
with pytest.raises(AssertionError):
|
||||
assert_min_value(self.user, 'age', 1)
|
||||
assert_min_value(user, 'age', -1)
|
||||
with pytest.raises(AssertionError):
|
||||
assert_min_value(self.user, 'age', 1)
|
||||
assert_min_value(user, 'age', -1)
|
||||
|
||||
def test_bigger_than_min_value(self, user):
|
||||
with pytest.raises(AssertionError):
|
||||
assert_min_value(user, 'age', 1)
|
||||
with pytest.raises(AssertionError):
|
||||
assert_min_value(user, 'age', 1)
|
||||
|
||||
|
||||
class TestAssertMaxValue(AssertionTestCase):
|
||||
def test_with_min_value(self):
|
||||
assert_max_value(self.user, 'age', 150)
|
||||
assert_max_value(self.user, 'age', 150)
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestAssertMaxValue(object):
|
||||
|
||||
def test_smaller_than_max_value(self):
|
||||
with pytest.raises(AssertionError):
|
||||
assert_max_value(self.user, 'age', 149)
|
||||
with pytest.raises(AssertionError):
|
||||
assert_max_value(self.user, 'age', 149)
|
||||
def test_with_min_value(self, user):
|
||||
assert_max_value(user, 'age', 150)
|
||||
assert_max_value(user, 'age', 150)
|
||||
|
||||
def test_bigger_than_max_value(self):
|
||||
def test_smaller_than_max_value(self, user):
|
||||
with pytest.raises(AssertionError):
|
||||
assert_max_value(self.user, 'age', 151)
|
||||
assert_max_value(user, 'age', 149)
|
||||
with pytest.raises(AssertionError):
|
||||
assert_max_value(self.user, 'age', 151)
|
||||
assert_max_value(user, 'age', 149)
|
||||
|
||||
def test_bigger_than_max_value(self, user):
|
||||
with pytest.raises(AssertionError):
|
||||
assert_max_value(user, 'age', 151)
|
||||
with pytest.raises(AssertionError):
|
||||
assert_max_value(user, 'age', 151)
|
||||
|
|
|
@ -1,117 +1,108 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from pytest import raises
|
||||
|
||||
from sqlalchemy_utils import auto_delete_orphans, ImproperlyConfigured
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAutoDeleteOrphans(TestCase):
|
||||
def create_models(self):
|
||||
tagging = sa.Table(
|
||||
'tagging',
|
||||
self.Base.metadata,
|
||||
sa.Column(
|
||||
'tag_id',
|
||||
sa.Integer,
|
||||
sa.ForeignKey('tag.id', ondelete='cascade'),
|
||||
primary_key=True
|
||||
),
|
||||
sa.Column(
|
||||
'entry_id',
|
||||
sa.Integer,
|
||||
sa.ForeignKey('entry.id', ondelete='cascade'),
|
||||
primary_key=True
|
||||
)
|
||||
@pytest.fixture
|
||||
def tagging_tbl(Base):
|
||||
return sa.Table(
|
||||
'tagging',
|
||||
Base.metadata,
|
||||
sa.Column(
|
||||
'tag_id',
|
||||
sa.Integer,
|
||||
sa.ForeignKey('tag.id', ondelete='cascade'),
|
||||
primary_key=True
|
||||
),
|
||||
sa.Column(
|
||||
'entry_id',
|
||||
sa.Integer,
|
||||
sa.ForeignKey('entry.id', ondelete='cascade'),
|
||||
primary_key=True
|
||||
)
|
||||
)
|
||||
|
||||
class Tag(self.Base):
|
||||
__tablename__ = 'tag'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String(100), unique=True, nullable=False)
|
||||
|
||||
def __init__(self, name=None):
|
||||
self.name = name
|
||||
@pytest.fixture
|
||||
def Tag(Base):
|
||||
class Tag(Base):
|
||||
__tablename__ = 'tag'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String(100), unique=True, nullable=False)
|
||||
|
||||
class Entry(self.Base):
|
||||
__tablename__ = 'entry'
|
||||
def __init__(self, name=None):
|
||||
self.name = name
|
||||
return Tag
|
||||
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
tags = sa.orm.relationship(
|
||||
'Tag',
|
||||
secondary=tagging,
|
||||
backref='entries'
|
||||
)
|
||||
@pytest.fixture
|
||||
def Entry(Base, Tag, tagging_tbl):
|
||||
class Entry(Base):
|
||||
__tablename__ = 'entry'
|
||||
|
||||
auto_delete_orphans(Entry.tags)
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
self.Tag = Tag
|
||||
self.Entry = Entry
|
||||
tags = sa.orm.relationship(
|
||||
'Tag',
|
||||
secondary=tagging_tbl,
|
||||
backref='entries'
|
||||
)
|
||||
auto_delete_orphans(Entry.tags)
|
||||
return Entry
|
||||
|
||||
def test_orphan_deletion(self):
|
||||
r1 = self.Entry()
|
||||
r2 = self.Entry()
|
||||
r3 = self.Entry()
|
||||
|
||||
@pytest.fixture
|
||||
def EntryWithoutTagsBackref(Base, Tag, tagging_tbl):
|
||||
class EntryWithoutTagsBackref(Base):
|
||||
__tablename__ = 'entry'
|
||||
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
tags = sa.orm.relationship(
|
||||
'Tag',
|
||||
secondary=tagging_tbl
|
||||
)
|
||||
return EntryWithoutTagsBackref
|
||||
|
||||
|
||||
class TestAutoDeleteOrphans(object):
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(self, Entry, Tag):
|
||||
pass
|
||||
|
||||
def test_orphan_deletion(self, session, Entry, Tag):
|
||||
r1 = Entry()
|
||||
r2 = Entry()
|
||||
r3 = Entry()
|
||||
t1, t2, t3, t4 = (
|
||||
self.Tag('t1'),
|
||||
self.Tag('t2'),
|
||||
self.Tag('t3'),
|
||||
self.Tag('t4')
|
||||
Tag('t1'),
|
||||
Tag('t2'),
|
||||
Tag('t3'),
|
||||
Tag('t4')
|
||||
)
|
||||
|
||||
r1.tags.extend([t1, t2])
|
||||
r2.tags.extend([t2, t3])
|
||||
r3.tags.extend([t4])
|
||||
self.session.add_all([r1, r2, r3])
|
||||
session.add_all([r1, r2, r3])
|
||||
|
||||
assert self.session.query(self.Tag).count() == 4
|
||||
assert session.query(Tag).count() == 4
|
||||
r2.tags.remove(t2)
|
||||
assert self.session.query(self.Tag).count() == 4
|
||||
assert session.query(Tag).count() == 4
|
||||
r1.tags.remove(t2)
|
||||
assert self.session.query(self.Tag).count() == 3
|
||||
assert session.query(Tag).count() == 3
|
||||
r1.tags.remove(t1)
|
||||
assert self.session.query(self.Tag).count() == 2
|
||||
assert session.query(Tag).count() == 2
|
||||
|
||||
|
||||
class TestAutoDeleteOrphansWithoutBackref(TestCase):
|
||||
def create_models(self):
|
||||
tagging = sa.Table(
|
||||
'tagging',
|
||||
self.Base.metadata,
|
||||
sa.Column(
|
||||
'tag_id',
|
||||
sa.Integer,
|
||||
sa.ForeignKey('tag.id', ondelete='cascade'),
|
||||
primary_key=True
|
||||
),
|
||||
sa.Column(
|
||||
'entry_id',
|
||||
sa.Integer,
|
||||
sa.ForeignKey('entry.id', ondelete='cascade'),
|
||||
primary_key=True
|
||||
)
|
||||
)
|
||||
class TestAutoDeleteOrphansWithoutBackref(object):
|
||||
|
||||
class Tag(self.Base):
|
||||
__tablename__ = 'tag'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String(100), unique=True, nullable=False)
|
||||
@pytest.fixture
|
||||
def init_models(self, EntryWithoutTagsBackref, Tag):
|
||||
pass
|
||||
|
||||
def __init__(self, name=None):
|
||||
self.name = name
|
||||
|
||||
class Entry(self.Base):
|
||||
__tablename__ = 'entry'
|
||||
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
tags = sa.orm.relationship(
|
||||
'Tag',
|
||||
secondary=tagging
|
||||
)
|
||||
|
||||
self.Entry = Entry
|
||||
|
||||
def test_orphan_deletion(self):
|
||||
with raises(ImproperlyConfigured):
|
||||
auto_delete_orphans(self.Entry.tags)
|
||||
def test_orphan_deletion(self, EntryWithoutTagsBackref):
|
||||
with pytest.raises(ImproperlyConfigured):
|
||||
auto_delete_orphans(EntryWithoutTagsBackref.tags)
|
||||
|
|
|
@ -1,50 +1,60 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import EmailType
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestCaseInsensitiveComparator(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
email = sa.Column(EmailType)
|
||||
@pytest.fixture
|
||||
def User(Base):
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
email = sa.Column(EmailType)
|
||||
|
||||
def __repr__(self):
|
||||
return 'Building(%r)' % self.id
|
||||
def __repr__(self):
|
||||
return 'Building(%r)' % self.id
|
||||
return User
|
||||
|
||||
self.User = User
|
||||
|
||||
def test_supports_equals(self):
|
||||
@pytest.fixture
|
||||
def init_models(User):
|
||||
pass
|
||||
|
||||
|
||||
class TestCaseInsensitiveComparator(object):
|
||||
|
||||
def test_supports_equals(self, session, User):
|
||||
query = (
|
||||
self.session.query(self.User)
|
||||
.filter(self.User.email == u'email@example.com')
|
||||
session.query(User)
|
||||
.filter(User.email == u'email@example.com')
|
||||
)
|
||||
|
||||
assert '"user".email = lower(:lower_1)' in str(query)
|
||||
|
||||
def test_supports_in_(self):
|
||||
def test_supports_in_(self, session, User):
|
||||
query = (
|
||||
self.session.query(self.User)
|
||||
.filter(self.User.email.in_([u'email@example.com', u'a']))
|
||||
session.query(User)
|
||||
.filter(User.email.in_([u'email@example.com', u'a']))
|
||||
)
|
||||
assert (
|
||||
'"user".email IN (lower(:lower_1), lower(:lower_2))'
|
||||
in str(query)
|
||||
)
|
||||
|
||||
def test_supports_notin_(self):
|
||||
def test_supports_notin_(self, session, User):
|
||||
query = (
|
||||
self.session.query(self.User)
|
||||
.filter(self.User.email.notin_([u'email@example.com', u'a']))
|
||||
session.query(User)
|
||||
.filter(User.email.notin_([u'email@example.com', u'a']))
|
||||
)
|
||||
assert (
|
||||
'"user".email NOT IN (lower(:lower_1), lower(:lower_2))'
|
||||
in str(query)
|
||||
)
|
||||
|
||||
def test_does_not_apply_lower_to_types_that_are_already_lowercased(self):
|
||||
assert str(self.User.email == self.User.email) == (
|
||||
def test_does_not_apply_lower_to_types_that_are_already_lowercased(
|
||||
self,
|
||||
User
|
||||
):
|
||||
assert str(User.email == User.email) == (
|
||||
'"user".email = "user".email'
|
||||
)
|
||||
|
|
|
@ -1,86 +1,93 @@
|
|||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from pytest import raises
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from sqlalchemy_utils import Asterisk, row_to_json
|
||||
from sqlalchemy_utils.expressions import explain, explain_analyze
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class ExpressionTestCase(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
|
||||
def create_models(self):
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
content = sa.Column(sa.UnicodeText)
|
||||
|
||||
self.Article = Article
|
||||
|
||||
def assert_startswith(self, query, query_part):
|
||||
@pytest.fixture
|
||||
def assert_startswith(session):
|
||||
def assert_startswith(query, query_part):
|
||||
assert str(
|
||||
query.compile(dialect=postgresql.dialect())
|
||||
).startswith(query_part)
|
||||
# Check that query executes properly
|
||||
self.session.execute(query)
|
||||
session.execute(query)
|
||||
return assert_startswith
|
||||
|
||||
|
||||
class TestExplain(ExpressionTestCase):
|
||||
def test_render_explain(self):
|
||||
self.assert_startswith(
|
||||
explain(self.session.query(self.Article)),
|
||||
@pytest.fixture
|
||||
def Article(Base):
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
content = sa.Column(sa.UnicodeText)
|
||||
return Article
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestExplain(object):
|
||||
|
||||
def test_render_explain(self, session, assert_startswith, Article):
|
||||
assert_startswith(
|
||||
explain(session.query(Article)),
|
||||
'EXPLAIN SELECT'
|
||||
)
|
||||
|
||||
def test_render_explain_with_analyze(self):
|
||||
self.assert_startswith(
|
||||
explain(self.session.query(self.Article), analyze=True),
|
||||
def test_render_explain_with_analyze(
|
||||
self,
|
||||
session,
|
||||
assert_startswith,
|
||||
Article
|
||||
):
|
||||
assert_startswith(
|
||||
explain(session.query(Article), analyze=True),
|
||||
'EXPLAIN (ANALYZE true) SELECT'
|
||||
)
|
||||
|
||||
def test_with_string_as_stmt_param(self):
|
||||
self.assert_startswith(
|
||||
def test_with_string_as_stmt_param(self, assert_startswith):
|
||||
assert_startswith(
|
||||
explain('SELECT 1 FROM article'),
|
||||
'EXPLAIN SELECT'
|
||||
)
|
||||
|
||||
def test_format(self):
|
||||
self.assert_startswith(
|
||||
def test_format(self, assert_startswith):
|
||||
assert_startswith(
|
||||
explain('SELECT 1 FROM article', format='json'),
|
||||
'EXPLAIN (FORMAT json) SELECT'
|
||||
)
|
||||
|
||||
def test_timing(self):
|
||||
self.assert_startswith(
|
||||
def test_timing(self, assert_startswith):
|
||||
assert_startswith(
|
||||
explain('SELECT 1 FROM article', analyze=True, timing=False),
|
||||
'EXPLAIN (ANALYZE true, TIMING false) SELECT'
|
||||
)
|
||||
|
||||
def test_verbose(self):
|
||||
self.assert_startswith(
|
||||
def test_verbose(self, assert_startswith):
|
||||
assert_startswith(
|
||||
explain('SELECT 1 FROM article', verbose=True),
|
||||
'EXPLAIN (VERBOSE true) SELECT'
|
||||
)
|
||||
|
||||
def test_buffers(self):
|
||||
self.assert_startswith(
|
||||
def test_buffers(self, assert_startswith):
|
||||
assert_startswith(
|
||||
explain('SELECT 1 FROM article', analyze=True, buffers=True),
|
||||
'EXPLAIN (ANALYZE true, BUFFERS true) SELECT'
|
||||
)
|
||||
|
||||
def test_costs(self):
|
||||
self.assert_startswith(
|
||||
def test_costs(self, assert_startswith):
|
||||
assert_startswith(
|
||||
explain('SELECT 1 FROM article', costs=False),
|
||||
'EXPLAIN (COSTS false) SELECT'
|
||||
)
|
||||
|
||||
|
||||
class TestExplainAnalyze(ExpressionTestCase):
|
||||
def test_render_explain_analyze(self):
|
||||
class TestExplainAnalyze(object):
|
||||
def test_render_explain_analyze(self, session, Article):
|
||||
assert str(
|
||||
explain_analyze(self.session.query(self.Article))
|
||||
explain_analyze(session.query(Article))
|
||||
.compile(
|
||||
dialect=postgresql.dialect()
|
||||
)
|
||||
|
@ -111,7 +118,7 @@ class TestAsterisk(object):
|
|||
|
||||
class TestRowToJson(object):
|
||||
def test_compiler_with_default_dialect(self):
|
||||
with raises(sa.exc.CompileError):
|
||||
with pytest.raises(sa.exc.CompileError):
|
||||
str(row_to_json(sa.text('article.*')))
|
||||
|
||||
def test_compiler_with_postgresql(self):
|
||||
|
@ -128,7 +135,7 @@ class TestRowToJson(object):
|
|||
|
||||
class TestArrayAgg(object):
|
||||
def test_compiler_with_default_dialect(self):
|
||||
with raises(sa.exc.CompileError):
|
||||
with pytest.raises(sa.exc.CompileError):
|
||||
str(sa.func.array_agg(sa.text('u.name')))
|
||||
|
||||
def test_compiler_with_postgresql(self):
|
||||
|
|
|
@ -1,27 +1,29 @@
|
|||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils.listeners import force_instant_defaults
|
||||
from tests import TestCase
|
||||
|
||||
force_instant_defaults()
|
||||
|
||||
|
||||
class TestInstantDefaultListener(TestCase):
|
||||
def create_models(self):
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255), default=u'Some article')
|
||||
created_at = sa.Column(sa.DateTime, default=datetime.now)
|
||||
@pytest.fixture
|
||||
def Article(Base):
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255), default=u'Some article')
|
||||
created_at = sa.Column(sa.DateTime, default=datetime.now)
|
||||
return Article
|
||||
|
||||
self.Article = Article
|
||||
|
||||
def test_assigns_defaults_on_object_construction(self):
|
||||
article = self.Article()
|
||||
class TestInstantDefaultListener(object):
|
||||
|
||||
def test_assigns_defaults_on_object_construction(self, Article):
|
||||
article = Article()
|
||||
assert article.name == u'Some article'
|
||||
|
||||
def test_callables_as_defaults(self):
|
||||
article = self.Article()
|
||||
def test_callables_as_defaults(self, Article):
|
||||
article = Article()
|
||||
assert isinstance(article.created_at, datetime)
|
||||
|
|
|
@ -1,14 +1,19 @@
|
|||
from tests import TestCase
|
||||
|
||||
|
||||
class TestInstrumentedList(TestCase):
|
||||
def test_any_returns_true_if_member_has_attr_defined(self):
|
||||
category = self.Category()
|
||||
category.articles.append(self.Article())
|
||||
category.articles.append(self.Article(name=u'some name'))
|
||||
class TestInstrumentedList(object):
|
||||
def test_any_returns_true_if_member_has_attr_defined(
|
||||
self,
|
||||
Category,
|
||||
Article
|
||||
):
|
||||
category = Category()
|
||||
category.articles.append(Article())
|
||||
category.articles.append(Article(name=u'some name'))
|
||||
assert category.articles.any('name')
|
||||
|
||||
def test_any_returns_false_if_no_member_has_attr_defined(self):
|
||||
category = self.Category()
|
||||
category.articles.append(self.Article())
|
||||
def test_any_returns_false_if_no_member_has_attr_defined(
|
||||
self,
|
||||
Category,
|
||||
Article
|
||||
):
|
||||
category = Category()
|
||||
category.articles.append(Article())
|
||||
assert not category.articles.any('name')
|
||||
|
|
|
@ -1,39 +1,40 @@
|
|||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import Timestamp
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestTimestamp(TestCase):
|
||||
@pytest.fixture
|
||||
def Article(Base):
|
||||
class Article(Base, Timestamp):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255), default=u'Some article')
|
||||
return Article
|
||||
|
||||
def create_models(self):
|
||||
class Article(self.Base, Timestamp):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255), default=u'Some article')
|
||||
|
||||
self.Article = Article
|
||||
class TestTimestamp(object):
|
||||
|
||||
def test_created(self):
|
||||
def test_created(self, session, Article):
|
||||
then = datetime.utcnow()
|
||||
article = self.Article()
|
||||
article = Article()
|
||||
|
||||
self.session.add(article)
|
||||
self.session.commit()
|
||||
session.add(article)
|
||||
session.commit()
|
||||
|
||||
assert article.created >= then and article.created <= datetime.utcnow()
|
||||
|
||||
def test_updated(self):
|
||||
article = self.Article()
|
||||
def test_updated(self, session, Article):
|
||||
article = Article()
|
||||
|
||||
self.session.add(article)
|
||||
self.session.commit()
|
||||
session.add(article)
|
||||
session.commit()
|
||||
|
||||
then = datetime.utcnow()
|
||||
article.name = u"Something"
|
||||
|
||||
self.session.commit()
|
||||
session.commit()
|
||||
|
||||
assert article.updated >= then and article.updated <= datetime.utcnow()
|
||||
|
|
|
@ -1,122 +1,127 @@
|
|||
import pytest
|
||||
import six
|
||||
import sqlalchemy as sa
|
||||
from pytest import mark
|
||||
from sqlalchemy.util.langhelpers import symbol
|
||||
|
||||
from sqlalchemy_utils.path import AttrPath, Path
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAttrPath(TestCase):
|
||||
def create_models(self):
|
||||
class Document(self.Base):
|
||||
__tablename__ = 'document'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
@pytest.fixture
|
||||
def Document(Base):
|
||||
class Document(Base):
|
||||
__tablename__ = 'document'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
return Document
|
||||
|
||||
class Section(self.Base):
|
||||
__tablename__ = 'section'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
document_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Document.id)
|
||||
)
|
||||
@pytest.fixture
|
||||
def Section(Base, Document):
|
||||
class Section(Base):
|
||||
__tablename__ = 'section'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
document = sa.orm.relationship(Document, backref='sections')
|
||||
|
||||
class SubSection(self.Base):
|
||||
__tablename__ = 'subsection'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
section_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Section.id)
|
||||
)
|
||||
|
||||
section = sa.orm.relationship(Section, backref='subsections')
|
||||
|
||||
self.Document = Document
|
||||
self.Section = Section
|
||||
self.SubSection = SubSection
|
||||
|
||||
@mark.parametrize(
|
||||
('class_', 'path', 'direction'),
|
||||
(
|
||||
('SubSection', 'section', symbol('MANYTOONE')),
|
||||
document_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Document.id)
|
||||
)
|
||||
)
|
||||
def test_direction(self, class_, path, direction):
|
||||
|
||||
document = sa.orm.relationship(Document, backref='sections')
|
||||
return Section
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def SubSection(Base, Section):
|
||||
class SubSection(Base):
|
||||
__tablename__ = 'subsection'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
section_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Section.id)
|
||||
)
|
||||
|
||||
section = sa.orm.relationship(Section, backref='subsections')
|
||||
return SubSection
|
||||
|
||||
|
||||
class TestAttrPath(object):
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(self, Document, Section, SubSection):
|
||||
pass
|
||||
|
||||
def test_direction(self, SubSection):
|
||||
assert (
|
||||
AttrPath(getattr(self, class_), path).direction == direction
|
||||
AttrPath(SubSection, 'section').direction == symbol('MANYTOONE')
|
||||
)
|
||||
|
||||
def test_invert(self):
|
||||
path = ~ AttrPath(self.SubSection, 'section.document')
|
||||
def test_invert(self, Document, Section, SubSection):
|
||||
path = ~ AttrPath(SubSection, 'section.document')
|
||||
assert path.parts == [
|
||||
self.Document.sections,
|
||||
self.Section.subsections
|
||||
Document.sections,
|
||||
Section.subsections
|
||||
]
|
||||
assert str(path.path) == 'sections.subsections'
|
||||
|
||||
def test_len(self):
|
||||
len(AttrPath(self.SubSection, 'section.document')) == 2
|
||||
def test_len(self, SubSection):
|
||||
len(AttrPath(SubSection, 'section.document')) == 2
|
||||
|
||||
def test_init(self):
|
||||
path = AttrPath(self.SubSection, 'section.document')
|
||||
assert path.class_ == self.SubSection
|
||||
def test_init(self, SubSection):
|
||||
path = AttrPath(SubSection, 'section.document')
|
||||
assert path.class_ == SubSection
|
||||
assert path.path == Path('section.document')
|
||||
|
||||
def test_iter(self):
|
||||
path = AttrPath(self.SubSection, 'section.document')
|
||||
def test_iter(self, Section, SubSection):
|
||||
path = AttrPath(SubSection, 'section.document')
|
||||
assert list(path) == [
|
||||
self.SubSection.section,
|
||||
self.Section.document
|
||||
SubSection.section,
|
||||
Section.document
|
||||
]
|
||||
|
||||
def test_repr(self):
|
||||
path = AttrPath(self.SubSection, 'section.document')
|
||||
def test_repr(self, SubSection):
|
||||
path = AttrPath(SubSection, 'section.document')
|
||||
assert repr(path) == (
|
||||
"AttrPath(SubSection, 'section.document')"
|
||||
)
|
||||
|
||||
def test_index(self):
|
||||
path = AttrPath(self.SubSection, 'section.document')
|
||||
assert path.index(self.Section.document) == 1
|
||||
assert path.index(self.SubSection.section) == 0
|
||||
def test_index(self, Section, SubSection):
|
||||
path = AttrPath(SubSection, 'section.document')
|
||||
assert path.index(Section.document) == 1
|
||||
assert path.index(SubSection.section) == 0
|
||||
|
||||
def test_getitem(self):
|
||||
path = AttrPath(self.SubSection, 'section.document')
|
||||
assert path[0] is self.SubSection.section
|
||||
assert path[1] is self.Section.document
|
||||
def test_getitem(self, Section, SubSection):
|
||||
path = AttrPath(SubSection, 'section.document')
|
||||
assert path[0] is SubSection.section
|
||||
assert path[1] is Section.document
|
||||
|
||||
def test_getitem_with_slice(self):
|
||||
path = AttrPath(self.SubSection, 'section.document')
|
||||
assert path[:] == AttrPath(self.SubSection, 'section.document')
|
||||
assert path[:-1] == AttrPath(self.SubSection, 'section')
|
||||
assert path[1:] == AttrPath(self.Section, 'document')
|
||||
def test_getitem_with_slice(self, Section, SubSection):
|
||||
path = AttrPath(SubSection, 'section.document')
|
||||
assert path[:] == AttrPath(SubSection, 'section.document')
|
||||
assert path[:-1] == AttrPath(SubSection, 'section')
|
||||
assert path[1:] == AttrPath(Section, 'document')
|
||||
|
||||
def test_eq(self):
|
||||
def test_eq(self, SubSection):
|
||||
assert (
|
||||
AttrPath(self.SubSection, 'section.document') ==
|
||||
AttrPath(self.SubSection, 'section.document')
|
||||
AttrPath(SubSection, 'section.document') ==
|
||||
AttrPath(SubSection, 'section.document')
|
||||
)
|
||||
assert not (
|
||||
AttrPath(self.SubSection, 'section') ==
|
||||
AttrPath(self.SubSection, 'section.document')
|
||||
AttrPath(SubSection, 'section') ==
|
||||
AttrPath(SubSection, 'section.document')
|
||||
)
|
||||
|
||||
def test_ne(self):
|
||||
def test_ne(self, SubSection):
|
||||
assert not (
|
||||
AttrPath(self.SubSection, 'section.document') !=
|
||||
AttrPath(self.SubSection, 'section.document')
|
||||
AttrPath(SubSection, 'section.document') !=
|
||||
AttrPath(SubSection, 'section.document')
|
||||
)
|
||||
assert (
|
||||
AttrPath(self.SubSection, 'section') !=
|
||||
AttrPath(self.SubSection, 'section.document')
|
||||
AttrPath(SubSection, 'section') !=
|
||||
AttrPath(SubSection, 'section.document')
|
||||
)
|
||||
|
||||
|
||||
|
@ -133,7 +138,7 @@ class TestPath(object):
|
|||
path = Path('s.s2.s3')
|
||||
assert list(path) == ['s', 's2', 's3']
|
||||
|
||||
@mark.parametrize(('path', 'length'), (
|
||||
@pytest.mark.parametrize(('path', 'length'), (
|
||||
(Path('s.s2.s3'), 3),
|
||||
(Path('s.s2'), 2),
|
||||
(Path(''), 0)
|
||||
|
@ -167,14 +172,14 @@ class TestPath(object):
|
|||
path = Path('s.s2.s3')
|
||||
assert path[1:] == Path('s2.s3')
|
||||
|
||||
@mark.parametrize(('test', 'result'), (
|
||||
@pytest.mark.parametrize(('test', 'result'), (
|
||||
(Path('s.s2') == Path('s.s2'), True),
|
||||
(Path('s.s2') == Path('s.s3'), False)
|
||||
))
|
||||
def test_eq(self, test, result):
|
||||
assert test is result
|
||||
|
||||
@mark.parametrize(('test', 'result'), (
|
||||
@pytest.mark.parametrize(('test', 'result'), (
|
||||
(Path('s.s2') != Path('s.s2'), False),
|
||||
(Path('s.s2') != Path('s.s3'), True)
|
||||
))
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue