Use pytest fixtures to reduce complexity and repetition

Also:

Allow override of database name and user in tests (important for me as I would have to mess with my PSQL and MySQL database users otherwise)
Use dict.items instead of six.iteritems as it sporadically caused RuntimeError: dictionary changed size during iteration in Python 2.6 tests.
Fix typo DNS to DSN
Adds Python 3.5 to tox.ini
Added an .editorconfig
Import babel.dates in sqlalchemy_utils.i18n as an exception would be raised when using the latest versions of babel.
This commit is contained in:
Jacob Magnusson 2016-01-18 16:32:12 +01:00
parent 5bdd4d3efb
commit 815f07d6c1
128 changed files with 5412 additions and 4286 deletions

14
.editorconfig Normal file
View File

@ -0,0 +1,14 @@
# EditorConfig helps developers define and maintain consistent
# coding styles between different editors and IDEs
# editorconfig.org
root = true
[*]
indent_style = space
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
indent_size = 4

5
.gitignore vendored
View File

@ -15,6 +15,8 @@ var
sdist
develop-eggs
.installed.cfg
.cache
.eggs
lib
lib64
docs/_build
@ -42,3 +44,6 @@ nosetests.xml
Session.vim
.netrwhist
*~
# Sublime Text
*.sublime-*

View File

@ -1,5 +1,5 @@
[settings]
known_first_party=sqlalchemy_utils,tests
known_first_party=sqlalchemy_utils
line_length=79
multi_line_output=3
not_skip=__init__.py

View File

@ -1,3 +1,6 @@
sudo: false
language: python
addons:
postgresql: "9.4"
@ -6,22 +9,29 @@ before_script:
- psql -c 'create extension hstore;' -U postgres -d sqlalchemy_utils_test
- mysql -e 'create database sqlalchemy_utils_test;'
language: python
python:
- 2.6
- 2.7
- 3.3
- 3.4
- 3.5
env:
- EXTRAS=test
- EXTRAS=test_all
matrix:
include:
- python: 2.6
env:
- "TOXENV=py26"
- python: 2.7
env:
- "TOXENV=py27"
- python: 3.3
env:
- "TOXENV=py33"
- python: 3.4
env:
- "TOXENV=py34"
- python: 3.5
env:
- "TOXENV=py35"
- python: 3.5
env:
- "TOXENV=lint"
install:
- pip install -e .[$EXTRAS]
- pip install tox
script:
- isort --recursive --diff sqlalchemy_utils tests && isort --recursive --check-only sqlalchemy_utils tests
- flake8 sqlalchemy_utils tests
- py.test
- tox

View File

@ -1,4 +1,4 @@
include CHANGES.rst LICENSE README.rst
include CHANGES.rst LICENSE README.rst conftest.py .isort.cfg
recursive-include tests *
recursive-exclude tests *.pyc
recursive-include docs *

198
conftest.py Normal file
View File

@ -0,0 +1,198 @@
import os
import warnings
import pytest
import sqlalchemy as sa
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base, synonym_for
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import (
aggregates,
coercion_listener,
i18n,
InstrumentedList
)
from sqlalchemy_utils.types.pg_composite import remove_composite_listeners
@sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute')
def count_sql_calls(conn, cursor, statement, parameters, context, executemany):
try:
conn.query_count += 1
except AttributeError:
conn.query_count = 0
warnings.simplefilter('error', sa.exc.SAWarning)
sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener)
def get_locale():
class Locale():
territories = {'FI': 'Finland'}
return Locale()
@pytest.fixture(scope='session')
def db_name():
return os.environ.get('SQLALCHEMY_UTILS_TEST_DB', 'sqlalchemy_utils_test')
@pytest.fixture(scope='session')
def postgresql_db_user():
return os.environ.get('SQLALCHEMY_UTILS_TEST_POSTGRESQL_USER', 'postgres')
@pytest.fixture(scope='session')
def mysql_db_user():
return os.environ.get('SQLALCHEMY_UTILS_TEST_MYSQL_USER', 'root')
@pytest.fixture
def postgresql_dsn(postgresql_db_user, db_name):
return 'postgres://{0}@localhost/{1}'.format(postgresql_db_user, db_name)
@pytest.fixture
def mysql_dsn(mysql_db_user, db_name):
return 'mysql+pymysql://{0}@localhost/{1}'.format(mysql_db_user, db_name)
@pytest.fixture
def sqlite_memory_dsn():
return 'sqlite:///:memory:'
@pytest.fixture
def sqlite_file_dsn():
return 'sqlite:///{0}.db'.format(db_name)
@pytest.fixture
def dsn(request):
if 'postgresql_dsn' in request.fixturenames:
return request.getfuncargvalue('postgresql_dsn')
elif 'mysql_dsn' in request.fixturenames:
return request.getfuncargvalue('mysql_dsn')
elif 'sqlite_file_dsn' in request.fixturenames:
return request.getfuncargvalue('sqlite_file_dsn')
elif 'sqlite_memory_dsn' in request.fixturenames:
pass # Return default
return request.getfuncargvalue('sqlite_memory_dsn')
@pytest.fixture
def engine(dsn):
engine = create_engine(dsn)
# engine.echo = True
return engine
@pytest.fixture
def connection(engine):
return engine.connect()
@pytest.fixture
def Base():
return declarative_base()
@pytest.fixture
def User(Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
return User
@pytest.fixture
def Category(Base):
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
title = sa.Column(sa.Unicode(255))
@hybrid_property
def full_name(self):
return u'%s %s' % (self.title, self.name)
@full_name.expression
def full_name(self):
return sa.func.concat(self.title, ' ', self.name)
@hybrid_property
def articles_count(self):
return len(self.articles)
@articles_count.expression
def articles_count(cls):
Article = Base._decl_class_registry['Article']
return (
sa.select([sa.func.count(Article.id)])
.where(Article.category_id == cls.id)
.correlate(Article.__table__)
.label('article_count')
)
@property
def name_alias(self):
return self.name
@synonym_for('name')
@property
def name_synonym(self):
return self.name
return Category
@pytest.fixture
def Article(Base, Category):
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255), index=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
category = sa.orm.relationship(
Category,
primaryjoin=category_id == Category.id,
backref=sa.orm.backref(
'articles',
collection_class=InstrumentedList
)
)
return Article
@pytest.fixture
def init_models(User, Category, Article):
pass
@pytest.fixture
def session(request, engine, connection, Base, init_models):
sa.orm.configure_mappers()
Base.metadata.create_all(connection)
Session = sessionmaker(bind=connection)
session = Session()
i18n.get_locale = get_locale
def teardown():
aggregates.manager.reset()
session.close_all()
Base.metadata.drop_all(connection)
remove_composite_listeners()
connection.close()
engine.dispose()
request.addfinalizer(teardown)
return session

View File

@ -11,6 +11,8 @@ SQLAlchemy-Utils has been tested against the following Python platforms.
- cPython 2.6
- cPython 2.7
- cPython 3.3
- cPython 3.4
- cPython 3.5
Installing an official release

View File

@ -89,11 +89,11 @@ setup(
'Operating System :: OS Independent',
'Programming Language :: Python',
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.6',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Topic :: Internet :: WWW/HTTP :: Dynamic Content',
'Topic :: Software Development :: Libraries :: Python Modules'
]

View File

@ -365,7 +365,6 @@ TODO
from collections import defaultdict
from weakref import WeakKeyDictionary
import six
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.sql.functions import _FunctionGenerator
@ -519,7 +518,7 @@ class AggregationManager(object):
)
def update_generator_registry(self):
for class_, attrs in six.iteritems(aggregated_attrs):
for class_, attrs in aggregated_attrs.items():
for expr, path, column in attrs:
value = AggregatedValue(
class_=class_,
@ -539,7 +538,7 @@ class AggregationManager(object):
if class_ in self.generator_registry:
object_dict[class_].append(obj)
for class_, objects in six.iteritems(object_dict):
for class_, objects in object_dict.items():
for aggregate_value in self.generator_registry[class_]:
query = aggregate_value.update_query(objects)
if query is not None:

View File

@ -10,7 +10,7 @@ from sqlalchemy.sql.expression import (
)
from sqlalchemy.sql.functions import GenericFunction
from sqlalchemy_utils.functions.orm import quote
from .functions.orm import quote
class explain(Executable, ClauseElement):

View File

@ -7,8 +7,7 @@ import sqlalchemy as sa
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy_utils.expressions import explain_analyze
from ..expressions import explain_analyze
from ..utils import starts_with
from .orm import quote

View File

@ -1,7 +1,6 @@
from collections import defaultdict
from itertools import groupby
import six
import sqlalchemy as sa
from sqlalchemy.exc import NoInspectionAvailable
from sqlalchemy.orm import object_session
@ -167,7 +166,7 @@ def merge_references(from_, to, foreign_keys=None):
new_values = get_foreign_key_values(fk, to)
criteria = (
getattr(fk.constraint.table.c, key) == value
for key, value in six.iteritems(old_values)
for key, value in old_values.items()
)
try:
mapper = get_mapper(fk.constraint.table)

View File

@ -19,7 +19,7 @@ from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.session import object_session
from sqlalchemy.orm.util import AliasedInsp
from sqlalchemy_utils.utils import is_sequence
from ..utils import is_sequence
def get_class_by_table(base, table, data=None):

View File

@ -8,9 +8,8 @@ from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
from sqlalchemy.orm.session import _state_session
from sqlalchemy.util import set_creation_order
from sqlalchemy_utils.functions import identity
from .exceptions import ImproperlyConfigured
from .functions import identity
class GenericAttributeImpl(attributes.ScalarAttributeImpl):

View File

@ -8,6 +8,7 @@ from .exceptions import ImproperlyConfigured
try:
import babel
import babel.dates
except ImportError:
babel = None

View File

@ -154,9 +154,9 @@ from collections import defaultdict, Iterable, namedtuple
import sqlalchemy as sa
from sqlalchemy_utils.functions import getdotattr, has_changes
from sqlalchemy_utils.path import AttrPath
from sqlalchemy_utils.utils import is_sequence
from .functions import getdotattr, has_changes
from .path import AttrPath
from .utils import is_sequence
Callback = namedtuple('Callback', ['func', 'path', 'backref', 'fullpath'])

View File

@ -1,7 +1,7 @@
import six
from sqlalchemy_utils import i18n
from sqlalchemy_utils.utils import str_coercible
from .. import i18n
from ..utils import str_coercible
@str_coercible

View File

@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
import six
from sqlalchemy_utils import i18n, ImproperlyConfigured
from sqlalchemy_utils.utils import str_coercible
from .. import i18n, ImproperlyConfigured
from ..utils import str_coercible
@str_coercible

View File

@ -4,8 +4,8 @@ try:
except ImportError:
# Python 2.6 port
from total_ordering import total_ordering
from sqlalchemy_utils import i18n
from sqlalchemy_utils.utils import str_coercible
from .. import i18n
from ..utils import str_coercible
@str_coercible

View File

@ -1,7 +1,6 @@
import six
from sqlalchemy_utils.utils import str_coercible
from ..utils import str_coercible
from .weekday import WeekDay

View File

@ -6,8 +6,7 @@ from datetime import datetime
import six
from sqlalchemy import types
from sqlalchemy_utils.exceptions import ImproperlyConfigured
from ..exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible
arrow = None

View File

@ -1,8 +1,7 @@
import six
from sqlalchemy import types
from sqlalchemy_utils.exceptions import ImproperlyConfigured
from ..exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible
colour = None

View File

@ -1,8 +1,7 @@
import six
from sqlalchemy import types
from sqlalchemy_utils.primitives import Country
from ..primitives import Country
from .scalar_coercible import ScalarCoercible

View File

@ -1,9 +1,8 @@
import six
from sqlalchemy import types
from sqlalchemy_utils import i18n, ImproperlyConfigured
from sqlalchemy_utils.primitives import Currency
from .. import i18n, ImproperlyConfigured
from ..primitives import Currency
from .scalar_coercible import ScalarCoercible

View File

@ -5,8 +5,7 @@ import datetime
import six
from sqlalchemy.types import Binary, String, TypeDecorator
from sqlalchemy_utils.exceptions import ImproperlyConfigured
from ..exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible
cryptography = None
@ -84,7 +83,7 @@ class AesEngine(EncryptionDecryptionBaseEngine):
value = str(value)
decryptor = self.cipher.decryptor()
decrypted = base64.b64decode(value)
decrypted = decryptor.update(decrypted)+decryptor.finalize()
decrypted = decryptor.update(decrypted) + decryptor.finalize()
decrypted = decrypted.rstrip(self.PADDING)
if not isinstance(decrypted, six.string_types):
decrypted = decrypted.decode('utf-8')

View File

@ -1,8 +1,7 @@
import six
from sqlalchemy import types
from sqlalchemy_utils.exceptions import ImproperlyConfigured
from ..exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible
ip_address = None

View File

@ -5,8 +5,7 @@ from sqlalchemy import types
from sqlalchemy.dialects import oracle, postgresql
from sqlalchemy.ext.mutable import Mutable
from sqlalchemy_utils.exceptions import ImproperlyConfigured
from ..exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible
passlib = None

View File

@ -109,7 +109,7 @@ from sqlalchemy.types import (
UserDefinedType
)
from sqlalchemy_utils import ImproperlyConfigured
from .. import ImproperlyConfigured
psycopg2 = None
CompositeCaster = None

View File

@ -1,8 +1,7 @@
from sqlalchemy import types
from sqlalchemy_utils.exceptions import ImproperlyConfigured
from sqlalchemy_utils.utils import str_coercible
from ..exceptions import ImproperlyConfigured
from ..utils import str_coercible
from .scalar_coercible import ScalarCoercible
try:

View File

@ -1,8 +1,7 @@
import six
from sqlalchemy import types
from sqlalchemy_utils.exceptions import ImproperlyConfigured
from ..exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible

View File

@ -1,10 +1,9 @@
import six
from sqlalchemy import types
from sqlalchemy_utils import i18n
from sqlalchemy_utils.exceptions import ImproperlyConfigured
from sqlalchemy_utils.primitives import WeekDay, WeekDays
from .. import i18n
from ..exceptions import ImproperlyConfigured
from ..primitives import WeekDay, WeekDays
from .bit import BitType
from .scalar_coercible import ScalarCoercible

View File

@ -1,132 +1,3 @@
import warnings
import sqlalchemy as sa
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base, synonym_for
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import (
aggregates,
coercion_listener,
i18n,
InstrumentedList
)
from sqlalchemy_utils.types.pg_composite import remove_composite_listeners
@sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute')
def count_sql_calls(conn, cursor, statement, parameters, context, executemany):
try:
conn.query_count += 1
except AttributeError:
conn.query_count = 0
warnings.simplefilter('error', sa.exc.SAWarning)
sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener)
def get_locale():
class Locale():
territories = {'FI': 'Finland'}
return Locale()
class TestCase(object):
dns = 'sqlite:///:memory:'
create_tables = True
def setup_method(self, method):
self.engine = create_engine(self.dns)
# self.engine.echo = True
self.connection = self.engine.connect()
self.Base = declarative_base()
self.create_models()
sa.orm.configure_mappers()
if self.create_tables:
self.Base.metadata.create_all(self.connection)
Session = sessionmaker(bind=self.connection)
self.session = Session()
i18n.get_locale = get_locale
def teardown_method(self, method):
aggregates.manager.reset()
self.session.close_all()
if self.create_tables:
self.Base.metadata.drop_all(self.connection)
remove_composite_listeners()
self.connection.close()
self.engine.dispose()
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
title = sa.Column(sa.Unicode(255))
@hybrid_property
def full_name(self):
return u'%s %s' % (self.title, self.name)
@full_name.expression
def full_name(self):
return sa.func.concat(self.title, ' ', self.name)
@hybrid_property
def articles_count(self):
return len(self.articles)
@articles_count.expression
def articles_count(cls):
return (
sa.select([sa.func.count(self.Article.id)])
.where(self.Article.category_id == self.Category.id)
.correlate(self.Article.__table__)
.label('article_count')
)
@property
def name_alias(self):
return self.name
@synonym_for('name')
@property
def name_synonym(self):
return self.name
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255), index=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
category = sa.orm.relationship(
Category,
primaryjoin=category_id == Category.id,
backref=sa.orm.backref(
'articles',
collection_class=InstrumentedList
)
)
self.User = User
self.Category = Category
self.Article = Article
def assert_contains(clause, query):
# Test that query executes
query.all()

View File

@ -1,61 +1,76 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregateValueGenerationWithBackrefs(TestCase):
def create_models(self):
class Thread(self.Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@pytest.fixture
def Thread(Base):
class Thread(Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated('comments', sa.Column(sa.Integer, default=0))
def comment_count(self):
return sa.func.count('1')
@aggregated('comments', sa.Column(sa.Integer, default=0))
def comment_count(self):
return sa.func.count('1')
return Thread
class Comment(self.Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.Unicode(255))
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
thread = sa.orm.relationship(Thread, backref='comments')
@pytest.fixture
def Comment(Base, Thread):
class Comment(Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.Unicode(255))
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
self.Thread = Thread
self.Comment = Comment
thread = sa.orm.relationship(Thread, backref='comments')
return Comment
def test_assigns_aggregates_on_insert(self):
thread = self.Thread()
@pytest.fixture
def init_models(Thread, Comment):
pass
class TestAggregateValueGenerationWithBackrefs(object):
def test_assigns_aggregates_on_insert(self, session, Thread, Comment):
thread = Thread()
thread.name = u'some article name'
self.session.add(thread)
comment = self.Comment(content=u'Some content', thread=thread)
self.session.add(comment)
self.session.commit()
self.session.refresh(thread)
session.add(thread)
comment = Comment(content=u'Some content', thread=thread)
session.add(comment)
session.commit()
session.refresh(thread)
assert thread.comment_count == 1
def test_assigns_aggregates_on_separate_insert(self):
thread = self.Thread()
def test_assigns_aggregates_on_separate_insert(
self,
session,
Thread,
Comment
):
thread = Thread()
thread.name = u'some article name'
self.session.add(thread)
self.session.commit()
comment = self.Comment(content=u'Some content', thread=thread)
self.session.add(comment)
self.session.commit()
self.session.refresh(thread)
session.add(thread)
session.commit()
comment = Comment(content=u'Some content', thread=thread)
session.add(comment)
session.commit()
session.refresh(thread)
assert thread.comment_count == 1
def test_assigns_aggregates_on_delete(self):
thread = self.Thread()
def test_assigns_aggregates_on_delete(self, session, Thread, Comment):
thread = Thread()
thread.name = u'some article name'
self.session.add(thread)
self.session.commit()
comment = self.Comment(content=u'Some content', thread=thread)
self.session.add(comment)
self.session.commit()
self.session.delete(comment)
self.session.commit()
self.session.refresh(thread)
session.add(thread)
session.commit()
comment = Comment(content=u'Some content', thread=thread)
session.add(comment)
session.commit()
session.delete(comment)
session.commit()
session.refresh(thread)
assert thread.comment_count == 0

View File

@ -1,67 +1,76 @@
from decimal import Decimal
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestLazyEvaluatedSelectExpressionsForAggregates(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.fixture
def Product(Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
return Product
@aggregated('products', sa.Column(sa.Numeric, default=0))
def net_worth(self):
return sa.func.sum(Product.price)
products = sa.orm.relationship('Product', backref='catalog')
@pytest.fixture
def Catalog(Base, Product):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
@aggregated('products', sa.Column(sa.Numeric, default=0))
def net_worth(self):
return sa.func.sum(Product.price)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
products = sa.orm.relationship('Product', backref='catalog')
return Catalog
self.Catalog = Catalog
self.Product = Product
def test_assigns_aggregates_on_insert(self):
catalog = self.Catalog(
@pytest.fixture
def init_models(Product, Catalog):
pass
@pytest.mark.usefixtures('postgresql_dsn')
class TestLazyEvaluatedSelectExpressionsForAggregates(object):
def test_assigns_aggregates_on_insert(self, session, Product, Catalog):
catalog = Catalog(
name=u'Some catalog'
)
self.session.add(catalog)
self.session.commit()
product = self.Product(
session.add(catalog)
session.commit()
product = Product(
name=u'Some product',
price=Decimal('1000'),
catalog=catalog
)
self.session.add(product)
self.session.commit()
self.session.refresh(catalog)
session.add(product)
session.commit()
session.refresh(catalog)
assert catalog.net_worth == Decimal('1000')
def test_assigns_aggregates_on_update(self):
catalog = self.Catalog(
def test_assigns_aggregates_on_update(self, session, Product, Catalog):
catalog = Catalog(
name=u'Some catalog'
)
self.session.add(catalog)
self.session.commit()
product = self.Product(
session.add(catalog)
session.commit()
product = Product(
name=u'Some product',
price=Decimal('1000'),
catalog=catalog
)
self.session.add(product)
self.session.commit()
session.add(product)
session.commit()
product.price = Decimal('500')
self.session.commit()
self.session.refresh(catalog)
session.commit()
session.refresh(catalog)
assert catalog.net_worth == Decimal('500')

View File

@ -1,101 +1,121 @@
from decimal import Decimal
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestLazyEvaluatedSelectExpressionsForAggregates(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.fixture
def Product(Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
type = sa.Column(sa.Unicode(255))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
return Product
__mapper_args__ = {
'polymorphic_on': type
}
@aggregated('products', sa.Column(sa.Numeric, default=0))
def net_worth(self):
return sa.func.sum(Product.price)
@pytest.fixture
def Catalog(Base, Product):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
type = sa.Column(sa.Unicode(255))
products = sa.orm.relationship('Product', backref='catalog')
__mapper_args__ = {
'polymorphic_on': type
}
class CostumeCatalog(Catalog):
__tablename__ = 'costume_catalog'
id = sa.Column(
sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True
)
@aggregated('products', sa.Column(sa.Numeric, default=0))
def net_worth(self):
return sa.func.sum(Product.price)
__mapper_args__ = {
'polymorphic_identity': 'costumes',
}
products = sa.orm.relationship('Product', backref='catalog')
return Catalog
class CarCatalog(Catalog):
__tablename__ = 'car_catalog'
id = sa.Column(
sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True
)
__mapper_args__ = {
'polymorphic_identity': 'cars',
}
@pytest.fixture
def CostumeCatalog(Catalog):
class CostumeCatalog(Catalog):
__tablename__ = 'costume_catalog'
id = sa.Column(
sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True
)
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
__mapper_args__ = {
'polymorphic_identity': 'costumes',
}
return CostumeCatalog
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
self.Catalog = Catalog
self.CostumeCatalog = CostumeCatalog
self.CarCatalog = CarCatalog
self.Product = Product
@pytest.fixture
def CarCatalog(Catalog):
class CarCatalog(Catalog):
__tablename__ = 'car_catalog'
id = sa.Column(
sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True
)
def test_columns_inherited_from_parent(self):
assert self.CarCatalog.net_worth
assert self.CostumeCatalog.net_worth
assert self.Catalog.net_worth
assert not hasattr(self.CarCatalog.__table__.c, 'net_worth')
assert not hasattr(self.CostumeCatalog.__table__.c, 'net_worth')
__mapper_args__ = {
'polymorphic_identity': 'cars',
}
return CarCatalog
def test_assigns_aggregates_on_insert(self):
catalog = self.Catalog(
@pytest.fixture
def init_models(Product, Catalog, CostumeCatalog, CarCatalog):
pass
@pytest.mark.usefixtures('postgresql_dsn')
class TestLazyEvaluatedSelectExpressionsForAggregates(object):
def test_columns_inherited_from_parent(
self,
Catalog,
CarCatalog,
CostumeCatalog
):
assert CarCatalog.net_worth
assert CostumeCatalog.net_worth
assert Catalog.net_worth
assert not hasattr(CarCatalog.__table__.c, 'net_worth')
assert not hasattr(CostumeCatalog.__table__.c, 'net_worth')
def test_assigns_aggregates_on_insert(self, session, Product, Catalog):
catalog = Catalog(
name=u'Some catalog'
)
self.session.add(catalog)
self.session.commit()
product = self.Product(
session.add(catalog)
session.commit()
product = Product(
name=u'Some product',
price=Decimal('1000'),
catalog=catalog
)
self.session.add(product)
self.session.commit()
self.session.refresh(catalog)
session.add(product)
session.commit()
session.refresh(catalog)
assert catalog.net_worth == Decimal('1000')
def test_assigns_aggregates_on_update(self):
catalog = self.Catalog(
def test_assigns_aggregates_on_update(self, session, Catalog, Product):
catalog = Catalog(
name=u'Some catalog'
)
self.session.add(catalog)
self.session.commit()
product = self.Product(
session.add(catalog)
session.commit()
product = Product(
name=u'Some product',
price=Decimal('1000'),
catalog=catalog
)
self.session.add(product)
self.session.commit()
session.add(product)
session.commit()
product.price = Decimal('500')
self.session.commit()
self.session.refresh(catalog)
session.commit()
session.refresh(catalog)
assert catalog.net_worth == Decimal('500')

View File

@ -1,72 +1,81 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregatesWithManyToManyRelationships(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.fixture
def User(Base):
user_group = sa.Table(
'user_group',
Base.metadata,
sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')),
sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id'))
)
def create_models(self):
user_group = sa.Table(
'user_group',
self.Base.metadata,
sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')),
sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id'))
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated('groups', sa.Column(sa.Integer, default=0))
def group_count(self):
return sa.func.count('1')
groups = sa.orm.relationship(
'Group',
backref='users',
secondary=user_group
)
return User
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated('groups', sa.Column(sa.Integer, default=0))
def group_count(self):
return sa.func.count('1')
@pytest.fixture
def Group(Base):
class Group(Base):
__tablename__ = 'group'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
return Group
groups = sa.orm.relationship(
'Group',
backref='users',
secondary=user_group
)
class Group(self.Base):
__tablename__ = 'group'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@pytest.fixture
def init_models(User, Group):
pass
self.User = User
self.Group = Group
def test_assigns_aggregates_on_insert(self):
user = self.User(
@pytest.mark.usefixtures('postgresql_dsn')
class TestAggregatesWithManyToManyRelationships(object):
def test_assigns_aggregates_on_insert(self, session, User, Group):
user = User(
name=u'John Matrix'
)
self.session.add(user)
self.session.commit()
group = self.Group(
session.add(user)
session.commit()
group = Group(
name=u'Some group',
users=[user]
)
self.session.add(group)
self.session.commit()
self.session.refresh(user)
session.add(group)
session.commit()
session.refresh(user)
assert user.group_count == 1
def test_updates_aggregates_on_delete(self):
user = self.User(
def test_updates_aggregates_on_delete(self, session, User, Group):
user = User(
name=u'John Matrix'
)
self.session.add(user)
self.session.commit()
group = self.Group(
session.add(user)
session.commit()
group = Group(
name=u'Some group',
users=[user]
)
self.session.add(group)
self.session.commit()
self.session.refresh(user)
session.add(group)
session.commit()
session.refresh(user)
user.groups = []
self.session.commit()
self.session.refresh(user)
session.commit()
session.refresh(user)
assert user.group_count == 0

View File

@ -1,80 +1,92 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils import aggregated
from tests import TestCase
class TestAggregateManyToManyAndManyToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.fixture
def Category(Base):
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
return Category
def create_models(self):
catalog_products = sa.Table(
'catalog_product',
self.Base.metadata,
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
@pytest.fixture
def Catalog(Base, Category):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'products.categories',
sa.Column(sa.Integer, default=0)
)
def category_count(self):
return sa.func.count(sa.distinct(Category.id))
return Catalog
@pytest.fixture
def Product(Base, Catalog, Category):
catalog_products = sa.Table(
'catalog_product',
Base.metadata,
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
)
product_categories = sa.Table(
'category_product',
Base.metadata,
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
)
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
catalog_id = sa.Column(
sa.Integer, sa.ForeignKey('catalog.id')
)
product_categories = sa.Table(
'category_product',
self.Base.metadata,
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
catalogs = sa.orm.relationship(
Catalog,
backref='products',
secondary=catalog_products
)
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
categories = sa.orm.relationship(
Category,
backref='products',
secondary=product_categories
)
return Product
@aggregated(
'products.categories',
sa.Column(sa.Integer, default=0)
)
def category_count(self):
return sa.func.count(sa.distinct(Category.id))
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@pytest.fixture
def init_models(Category, Catalog, Product):
pass
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
catalog_id = sa.Column(
sa.Integer, sa.ForeignKey('catalog.id')
)
@pytest.mark.usefixtures('postgresql_dsn')
class TestAggregateManyToManyAndManyToMany(object):
catalogs = sa.orm.relationship(
Catalog,
backref='products',
secondary=catalog_products
)
categories = sa.orm.relationship(
Category,
backref='products',
secondary=product_categories
)
self.Catalog = Catalog
self.Category = Category
self.Product = Product
def test_insert(self):
category = self.Category()
def test_insert(self, session, Product, Category, Catalog):
category = Category()
products = [
self.Product(categories=[category]),
self.Product(categories=[category])
Product(categories=[category]),
Product(categories=[category])
]
catalog = self.Catalog(products=products)
self.session.add(catalog)
catalog2 = self.Catalog(products=products)
self.session.add(catalog)
self.session.commit()
catalog = Catalog(products=products)
session.add(catalog)
catalog2 = Catalog(products=products)
session.add(catalog)
session.commit()
assert catalog.category_count == 1
assert catalog2.category_count == 1

View File

@ -1,81 +1,96 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
def create_models(self):
class Thread(self.Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@pytest.fixture
def Comment(Base):
class Comment(Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.Unicode(255))
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
return Comment
@aggregated(
'comments',
sa.Column(sa.Integer, default=0)
)
def comment_count(self):
return sa.func.count('1')
@aggregated('comments', sa.Column(sa.Integer))
def last_comment_id(self):
return sa.func.max(Comment.id)
@pytest.fixture
def Thread(Base, Comment):
class Thread(Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
comments = sa.orm.relationship(
'Comment',
backref='thread'
)
@aggregated(
'comments',
sa.Column(sa.Integer, default=0)
)
def comment_count(self):
return sa.func.count('1')
Thread.last_comment = sa.orm.relationship(
@aggregated('comments', sa.Column(sa.Integer))
def last_comment_id(self):
return sa.func.max(Comment.id)
comments = sa.orm.relationship(
'Comment',
primaryjoin='Thread.last_comment_id == Comment.id',
foreign_keys=[Thread.last_comment_id],
viewonly=True
backref='thread'
)
class Comment(self.Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.Unicode(255))
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
Thread.last_comment = sa.orm.relationship(
'Comment',
primaryjoin='Thread.last_comment_id == Comment.id',
foreign_keys=[Thread.last_comment_id],
viewonly=True
)
return Thread
self.Thread = Thread
self.Comment = Comment
def test_assigns_aggregates_on_insert(self):
thread = self.Thread()
@pytest.fixture
def init_models(Comment, Thread):
pass
class TestAggregateValueGenerationForSimpleModelPaths(object):
def test_assigns_aggregates_on_insert(self, session, Thread, Comment):
thread = Thread()
thread.name = u'some article name'
self.session.add(thread)
comment = self.Comment(content=u'Some content', thread=thread)
self.session.add(comment)
self.session.commit()
self.session.refresh(thread)
session.add(thread)
comment = Comment(content=u'Some content', thread=thread)
session.add(comment)
session.commit()
session.refresh(thread)
assert thread.comment_count == 1
assert thread.last_comment_id == comment.id
def test_assigns_aggregates_on_separate_insert(self):
thread = self.Thread()
def test_assigns_aggregates_on_separate_insert(
self,
session,
Thread,
Comment
):
thread = Thread()
thread.name = u'some article name'
self.session.add(thread)
self.session.commit()
comment = self.Comment(content=u'Some content', thread=thread)
self.session.add(comment)
self.session.commit()
self.session.refresh(thread)
session.add(thread)
session.commit()
comment = Comment(content=u'Some content', thread=thread)
session.add(comment)
session.commit()
session.refresh(thread)
assert thread.comment_count == 1
assert thread.last_comment_id == 1
def test_assigns_aggregates_on_delete(self):
thread = self.Thread()
def test_assigns_aggregates_on_delete(self, session, Thread, Comment):
thread = Thread()
thread.name = u'some article name'
self.session.add(thread)
self.session.commit()
comment = self.Comment(content=u'Some content', thread=thread)
self.session.add(comment)
self.session.commit()
self.session.delete(comment)
self.session.commit()
self.session.refresh(thread)
session.add(thread)
session.commit()
comment = Comment(content=u'Some content', thread=thread)
session.add(comment)
session.commit()
session.delete(comment)
session.commit()
session.refresh(thread)
assert thread.comment_count == 0
assert thread.last_comment_id is None

View File

@ -1,76 +1,88 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils import aggregated
from tests import TestCase
class TestAggregateOneToManyAndManyToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.fixture
def Category(Base):
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
return Category
def create_models(self):
product_categories = sa.Table(
'category_product',
self.Base.metadata,
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
@pytest.fixture
def Catalog(Base, Category):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'products.categories',
sa.Column(sa.Integer, default=0)
)
def category_count(self):
return sa.func.count(sa.distinct(Category.id))
return Catalog
@pytest.fixture
def Product(Base, Catalog, Category):
product_categories = sa.Table(
'category_product',
Base.metadata,
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
)
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
catalog_id = sa.Column(
sa.Integer, sa.ForeignKey('catalog.id')
)
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
catalog = sa.orm.relationship(
Catalog,
backref='products'
)
@aggregated(
'products.categories',
sa.Column(sa.Integer, default=0)
)
def category_count(self):
return sa.func.count(sa.distinct(Category.id))
categories = sa.orm.relationship(
Category,
backref='products',
secondary=product_categories
)
return Product
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
@pytest.fixture
def init_models(Category, Catalog, Product):
pass
catalog_id = sa.Column(
sa.Integer, sa.ForeignKey('catalog.id')
)
catalog = sa.orm.relationship(
Catalog,
backref='products'
)
@pytest.mark.usefixtures('postgresql_dsn')
class TestAggregateOneToManyAndManyToMany(object):
categories = sa.orm.relationship(
Category,
backref='products',
secondary=product_categories
)
self.Catalog = Catalog
self.Category = Category
self.Product = Product
def test_insert(self):
category = self.Category()
def test_insert(self, session, Category, Catalog, Product):
category = Category()
products = [
self.Product(categories=[category]),
self.Product(categories=[category])
Product(categories=[category]),
Product(categories=[category])
]
catalog = self.Catalog(products=products)
self.session.add(catalog)
catalog = Catalog(products=products)
session.add(catalog)
products2 = [
self.Product(categories=[category]),
self.Product(categories=[category])
Product(categories=[category]),
Product(categories=[category])
]
catalog2 = self.Catalog(products=products2)
self.session.add(catalog)
self.session.commit()
catalog2 = Catalog(products=products2)
session.add(catalog)
session.commit()
assert catalog.category_count == 1
assert catalog2.category_count == 1

View File

@ -1,64 +1,76 @@
from decimal import Decimal
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregateOneToManyAndOneToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.fixture
def Catalog(Base):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'categories.products',
sa.Column(sa.Integer, default=0)
)
def product_count(self):
return sa.func.count('1')
@aggregated(
'categories.products',
sa.Column(sa.Integer, default=0)
)
def product_count(self):
return sa.func.count('1')
categories = sa.orm.relationship('Category', backref='catalog')
return Catalog
categories = sa.orm.relationship('Category', backref='catalog')
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@pytest.fixture
def Category(Base):
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
products = sa.orm.relationship('Product', backref='category')
products = sa.orm.relationship('Product', backref='category')
return Category
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
@pytest.fixture
def Product(Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
self.Catalog = Catalog
self.Category = Category
self.Product = Product
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
return Product
def test_assigns_aggregates(self):
category = self.Category(name=u'Some category')
catalog = self.Catalog(
@pytest.fixture
def init_models(Catalog, Category, Product):
pass
@pytest.mark.usefixtures('postgresql_dsn')
class TestAggregateOneToManyAndOneToMany(object):
def test_assigns_aggregates(self, session, Category, Catalog, Product):
category = Category(name=u'Some category')
catalog = Catalog(
categories=[category]
)
catalog.name = u'Some catalog'
self.session.add(catalog)
self.session.commit()
product = self.Product(
session.add(catalog)
session.commit()
product = Product(
name=u'Some product',
price=Decimal('1000'),
category=category
)
self.session.add(product)
self.session.commit()
self.session.refresh(catalog)
session.add(product)
session.commit()
session.refresh(catalog)
assert catalog.product_count == 1

View File

@ -1,88 +1,129 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils import aggregated
from tests import TestCase
class Test3LevelDeepOneToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.fixture
def Catalog(Base):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
@aggregated(
'categories.sub_categories.products',
sa.Column(sa.Integer, default=0)
)
def product_count(self):
return sa.func.count('1')
@aggregated(
'categories.sub_categories.products',
sa.Column(sa.Integer, default=0)
)
def product_count(self):
return sa.func.count('1')
categories = sa.orm.relationship('Category', backref='catalog')
return Catalog
categories = sa.orm.relationship('Category', backref='catalog')
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
@pytest.fixture
def Category(Base):
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
sub_categories = sa.orm.relationship(
'SubCategory', backref='category'
)
sub_categories = sa.orm.relationship(
'SubCategory', backref='category'
)
return Category
class SubCategory(self.Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
products = sa.orm.relationship('Product', backref='sub_category')
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric)
@pytest.fixture
def SubCategory(Base):
class SubCategory(Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
products = sa.orm.relationship('Product', backref='sub_category')
return SubCategory
sub_category_id = sa.Column(
sa.Integer, sa.ForeignKey('sub_category.id')
)
self.Catalog = Catalog
self.Category = Category
self.SubCategory = SubCategory
self.Product = Product
@pytest.fixture
def Product(Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric)
def test_assigns_aggregates(self):
catalog = self.catalog_factory()
self.session.commit()
self.session.refresh(catalog)
assert catalog.product_count == 1
sub_category_id = sa.Column(
sa.Integer, sa.ForeignKey('sub_category.id')
)
return Product
def catalog_factory(self):
product = self.Product()
sub_category = self.SubCategory(
@pytest.fixture
def init_models(Catalog, Category, SubCategory, Product):
pass
@pytest.fixture
def catalog_factory(Product, SubCategory, Category, Catalog, session):
def catalog_factory():
product = Product()
sub_category = SubCategory(
products=[product]
)
category = self.Category(sub_categories=[sub_category])
catalog = self.Catalog(categories=[category])
self.session.add(catalog)
category = Category(sub_categories=[sub_category])
catalog = Catalog(categories=[category])
session.add(catalog)
return catalog
return catalog_factory
@pytest.mark.usefixtures('postgresql_dsn')
class Test3LevelDeepOneToMany(object):
def test_assigns_aggregates(self, session, catalog_factory):
catalog = catalog_factory()
session.commit()
session.refresh(catalog)
assert catalog.product_count == 1
def catalog_factory(
self,
session,
Product,
SubCategory,
Category,
Catalog
):
product = Product()
sub_category = SubCategory(
products=[product]
)
category = Category(sub_categories=[sub_category])
catalog = Catalog(categories=[category])
session.add(catalog)
return catalog
def test_only_updates_affected_aggregates(self):
catalog = self.catalog_factory()
catalog2 = self.catalog_factory()
self.session.commit()
def test_only_updates_affected_aggregates(
self,
session,
catalog_factory,
Product
):
catalog = catalog_factory()
catalog2 = catalog_factory()
session.commit()
# force set catalog2 product_count to zero in order to check if it gets
# updated when the other catalog's product count gets updated
self.session.execute(
session.execute(
'UPDATE catalog SET product_count = 0 WHERE id = %d'
% catalog2.id
)
catalog.categories[0].sub_categories[0].products.append(
self.Product()
Product()
)
self.session.commit()
self.session.refresh(catalog)
self.session.refresh(catalog2)
session.commit()
session.refresh(catalog)
session.refresh(catalog2)
assert catalog.product_count == 2
assert catalog2.product_count == 0

View File

@ -1,7 +1,7 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils import aggregated, TSVectorType
from tests import TestCase
def tsvector_reduce_concat(vectors):
@ -13,45 +13,54 @@ def tsvector_reduce_concat(vectors):
)
class TestSearchVectorAggregates(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.fixture
def Product(Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
return Product
@aggregated('products', sa.Column(TSVectorType))
def product_search_vector(self):
return tsvector_reduce_concat(
sa.func.to_tsvector(Product.name)
)
products = sa.orm.relationship('Product', backref='catalog')
@pytest.fixture
def Catalog(Base, Product):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
@aggregated('products', sa.Column(TSVectorType))
def product_search_vector(self):
return tsvector_reduce_concat(
sa.func.to_tsvector(Product.name)
)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
products = sa.orm.relationship('Product', backref='catalog')
return Catalog
self.Catalog = Catalog
self.Product = Product
def test_assigns_aggregates_on_insert(self):
catalog = self.Catalog(
@pytest.fixture
def init_models(Product, Catalog):
pass
@pytest.mark.usefixtures('postgresql_dsn')
class TestSearchVectorAggregates(object):
def test_assigns_aggregates_on_insert(self, session, Product, Catalog):
catalog = Catalog(
name=u'Some catalog'
)
self.session.add(catalog)
self.session.commit()
product = self.Product(
session.add(catalog)
session.commit()
product = Product(
name=u'Product XYZ',
catalog=catalog
)
self.session.add(product)
self.session.commit()
self.session.refresh(catalog)
session.add(product)
session.commit()
session.refresh(catalog)
assert catalog.product_search_vector == "'product':1 'xyz':2"

View File

@ -1,61 +1,76 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
def create_models(self):
class Thread(self.Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@pytest.fixture
def Thread(Base):
class Thread(Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated('comments', sa.Column(sa.Integer, default=0))
def comment_count(self):
return sa.func.count('1')
@aggregated('comments', sa.Column(sa.Integer, default=0))
def comment_count(self):
return sa.func.count('1')
comments = sa.orm.relationship('Comment', backref='thread')
comments = sa.orm.relationship('Comment', backref='thread')
return Thread
class Comment(self.Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.Unicode(255))
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
self.Thread = Thread
self.Comment = Comment
@pytest.fixture
def Comment(Base):
class Comment(Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.Unicode(255))
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
return Comment
def test_assigns_aggregates_on_insert(self):
thread = self.Thread()
@pytest.fixture
def init_models(Thread, Comment):
pass
class TestAggregateValueGenerationForSimpleModelPaths(object):
def test_assigns_aggregates_on_insert(self, session, Thread, Comment):
thread = Thread()
thread.name = u'some article name'
self.session.add(thread)
comment = self.Comment(content=u'Some content', thread=thread)
self.session.add(comment)
self.session.commit()
self.session.refresh(thread)
session.add(thread)
comment = Comment(content=u'Some content', thread=thread)
session.add(comment)
session.commit()
session.refresh(thread)
assert thread.comment_count == 1
def test_assigns_aggregates_on_separate_insert(self):
thread = self.Thread()
def test_assigns_aggregates_on_separate_insert(
self,
session,
Thread,
Comment
):
thread = Thread()
thread.name = u'some article name'
self.session.add(thread)
self.session.commit()
comment = self.Comment(content=u'Some content', thread=thread)
self.session.add(comment)
self.session.commit()
self.session.refresh(thread)
session.add(thread)
session.commit()
comment = Comment(content=u'Some content', thread=thread)
session.add(comment)
session.commit()
session.refresh(thread)
assert thread.comment_count == 1
def test_assigns_aggregates_on_delete(self):
thread = self.Thread()
def test_assigns_aggregates_on_delete(self, session, Thread, Comment):
thread = Thread()
thread.name = u'some article name'
self.session.add(thread)
self.session.commit()
comment = self.Comment(content=u'Some content', thread=thread)
self.session.add(comment)
self.session.commit()
self.session.delete(comment)
self.session.commit()
self.session.refresh(thread)
session.add(thread)
session.commit()
comment = Comment(content=u'Some content', thread=thread)
session.add(comment)
session.commit()
session.delete(comment)
session.commit()
session.refresh(thread)
assert thread.comment_count == 0

View File

@ -1,59 +1,74 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregatedWithColumnAlias(TestCase):
def create_models(self):
class Thread(self.Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
@pytest.fixture
def Thread(Base):
class Thread(Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
@aggregated(
'comments',
sa.Column('_comment_count', sa.Integer, default=0)
)
def comment_count(self):
return sa.func.count('1')
@aggregated(
'comments',
sa.Column('_comment_count', sa.Integer, default=0)
)
def comment_count(self):
return sa.func.count('1')
comments = sa.orm.relationship('Comment', backref='thread')
comments = sa.orm.relationship('Comment', backref='thread')
return Thread
class Comment(self.Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
self.Thread = Thread
self.Comment = Comment
@pytest.fixture
def Comment(Base):
class Comment(Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
return Comment
def test_assigns_aggregates_on_insert(self):
thread = self.Thread()
self.session.add(thread)
comment = self.Comment(thread=thread)
self.session.add(comment)
self.session.commit()
self.session.refresh(thread)
@pytest.fixture
def init_models(Thread, Comment):
pass
class TestAggregatedWithColumnAlias(object):
def test_assigns_aggregates_on_insert(self, session, Thread, Comment):
thread = Thread()
session.add(thread)
comment = Comment(thread=thread)
session.add(comment)
session.commit()
session.refresh(thread)
assert thread.comment_count == 1
def test_assigns_aggregates_on_separate_insert(self):
thread = self.Thread()
self.session.add(thread)
self.session.commit()
comment = self.Comment(thread=thread)
self.session.add(comment)
self.session.commit()
self.session.refresh(thread)
def test_assigns_aggregates_on_separate_insert(
self,
session,
Thread,
Comment
):
thread = Thread()
session.add(thread)
session.commit()
comment = Comment(thread=thread)
session.add(comment)
session.commit()
session.refresh(thread)
assert thread.comment_count == 1
def test_assigns_aggregates_on_delete(self):
thread = self.Thread()
self.session.add(thread)
self.session.commit()
comment = self.Comment(thread=thread)
self.session.add(comment)
self.session.commit()
self.session.delete(comment)
self.session.commit()
self.session.refresh(thread)
def test_assigns_aggregates_on_delete(self, session, Thread, Comment):
thread = Thread()
session.add(thread)
session.commit()
comment = Comment(thread=thread)
session.add(comment)
session.commit()
session.delete(comment)
session.commit()
session.refresh(thread)
assert thread.comment_count == 0

View File

@ -1,47 +1,56 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregateValueGenerationWithCascadeDelete(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.fixture
def Thread(Base):
class Thread(Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
def create_models(self):
class Thread(self.Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated('comments', sa.Column(sa.Integer, default=0))
def comment_count(self):
return sa.func.count('1')
@aggregated('comments', sa.Column(sa.Integer, default=0))
def comment_count(self):
return sa.func.count('1')
comments = sa.orm.relationship(
'Comment',
passive_deletes=True,
backref='thread'
)
return Thread
comments = sa.orm.relationship(
'Comment',
passive_deletes=True,
backref='thread'
)
class Comment(self.Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.Unicode(255))
thread_id = sa.Column(
sa.Integer,
sa.ForeignKey('thread.id', ondelete='CASCADE')
)
@pytest.fixture
def Comment(Base):
class Comment(Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.Unicode(255))
thread_id = sa.Column(
sa.Integer,
sa.ForeignKey('thread.id', ondelete='CASCADE')
)
return Comment
self.Thread = Thread
self.Comment = Comment
def test_something(self):
thread = self.Thread()
@pytest.fixture
def init_models(Thread, Comment):
pass
@pytest.mark.usefixtures('postgresql_dsn')
class TestAggregateValueGenerationWithCascadeDelete(object):
def test_something(self, session, Thread, Comment):
thread = Thread()
thread.name = u'some article name'
self.session.add(thread)
comment = self.Comment(content=u'Some content', thread=thread)
self.session.add(comment)
self.session.commit()
self.session.expire_all()
self.session.delete(thread)
self.session.commit()
session.add(thread)
comment = Comment(content=u'Some content', thread=thread)
session.add(comment)
session.commit()
session.expire_all()
session.delete(thread)
session.commit()

View File

@ -1,29 +1,35 @@
import pytest
from sqlalchemy_utils import analyze
from tests import TestCase
class TestAnalyzeWithPostgres(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.mark.usefixtures('postgresql_dsn')
class TestAnalyzeWithPostgres(object):
def test_runtime(self):
query = self.session.query(self.Article)
assert analyze(self.connection, query).runtime
def test_runtime(self, session, connection, Article):
query = session.query(Article)
assert analyze(connection, query).runtime
def test_node_types_with_join(self):
def test_node_types_with_join(self, session, connection, Article):
query = (
self.session.query(self.Article)
.join(self.Article.category)
session.query(Article)
.join(Article.category)
)
analysis = analyze(self.connection, query)
analysis = analyze(connection, query)
assert analysis.node_types == [
u'Hash Join', u'Seq Scan', u'Hash', u'Seq Scan'
]
def test_node_types_with_index_only_scan(self):
def test_node_types_with_index_only_scan(
self,
session,
connection,
Article
):
query = (
self.session.query(self.Article.name)
.order_by(self.Article.name)
session.query(Article.name)
.order_by(Article.name)
.limit(10)
)
analysis = analyze(self.connection, query)
analysis = analyze(connection, query)
assert analysis.node_types == [u'Limit', u'Index Only Scan']

View File

@ -1,11 +1,8 @@
import os
import pytest
import sqlalchemy as sa
from flexmock import flexmock
from pytest import mark
from sqlalchemy_utils import create_database, database_exists, drop_database
from tests import TestCase
pymysql = None
try:
@ -14,38 +11,73 @@ except ImportError:
pass
class DatabaseTest(TestCase):
def test_create_and_drop(self):
assert not database_exists(self.url)
create_database(self.url)
assert database_exists(self.url)
drop_database(self.url)
assert not database_exists(self.url)
class DatabaseTest(object):
def test_create_and_drop(self, dsn):
assert not database_exists(dsn)
create_database(dsn)
assert database_exists(dsn)
drop_database(dsn)
assert not database_exists(dsn)
class TestDatabaseSQLite(DatabaseTest):
url = 'sqlite:///sqlalchemy_utils.db'
@pytest.mark.usefixtures('sqlite_memory_dsn')
class TestDatabaseSQLiteMemory(object):
def setup(self):
if os.path.exists('sqlalchemy_utils.db'):
os.remove('sqlalchemy_utils.db')
def test_exists_memory(self):
assert database_exists('sqlite:///:memory:')
def test_exists_memory(self, dsn):
assert database_exists(dsn)
@mark.skipif('pymysql is None')
@pytest.mark.usefixtures('sqlite_file_dsn')
class TestDatabaseSQLiteFile(DatabaseTest):
pass
@pytest.mark.skipif('pymysql is None')
@pytest.mark.usefixtures('mysql_dsn')
class TestDatabaseMySQL(DatabaseTest):
url = 'mysql+pymysql://travis@localhost/db_test_sqlalchemy_util'
@pytest.fixture
def db_name(self):
return 'db_test_sqlalchemy_util'
@mark.skipif('pymysql is None')
@pytest.mark.skipif('pymysql is None')
@pytest.mark.usefixtures('mysql_dsn')
class TestDatabaseMySQLWithQuotedName(DatabaseTest):
url = 'mysql+pymysql://travis@localhost/db_test_sqlalchemy-util'
@pytest.fixture
def db_name(self):
return 'db_test_sqlalchemy-util'
@pytest.mark.usefixtures('postgresql_dsn')
class TestDatabasePostgres(DatabaseTest):
@pytest.fixture
def db_name(self):
return 'db_test_sqlalchemy_util'
def test_template(self):
(
flexmock(sa.engine.Engine)
.should_receive('execute')
.with_args(
"CREATE DATABASE db_test_sqlalchemy_util ENCODING 'utf8' "
"TEMPLATE my_template"
)
)
create_database(
'postgres://postgres@localhost/db_test_sqlalchemy_util',
template='my_template'
)
@pytest.mark.usefixtures('postgresql_dsn')
class TestDatabasePostgresWithQuotedName(DatabaseTest):
url = 'postgres://postgres@localhost/db_test_sqlalchemy-util'
@pytest.fixture
def db_name(self):
return 'db_test_sqlalchemy-util'
def test_template(self):
(
@ -61,21 +93,3 @@ class TestDatabasePostgresWithQuotedName(DatabaseTest):
'postgres://postgres@localhost/db_test_sqlalchemy-util',
template='my-template'
)
class TestDatabasePostgres(DatabaseTest):
url = 'postgres://postgres@localhost/db_test_sqlalchemy_util'
def test_template(self):
(
flexmock(sa.engine.Engine)
.should_receive('execute')
.with_args(
"CREATE DATABASE db_test_sqlalchemy_util ENCODING 'utf8' "
"TEMPLATE my_template"
)
)
create_database(
'postgres://postgres@localhost/db_test_sqlalchemy_util',
template='my_template'
)

View File

@ -1,18 +1,23 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils import dependent_objects, get_referencing_foreign_keys
from tests import TestCase
class TestDependentObjects(TestCase):
def create_models(self):
class User(self.Base):
class TestDependentObjects(object):
@pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
first_name = sa.Column(sa.Unicode(255))
last_name = sa.Column(sa.Unicode(255))
return User
class Article(self.Base):
@pytest.fixture
def Article(self, Base, User):
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
@ -22,8 +27,11 @@ class TestDependentObjects(TestCase):
author = sa.orm.relationship(User, foreign_keys=[author_id])
owner = sa.orm.relationship(User, foreign_keys=[owner_id])
return Article
class BlogPost(self.Base):
@pytest.fixture
def BlogPost(self, Base, User):
class BlogPost(Base):
__tablename__ = 'blog_post'
id = sa.Column(sa.Integer, primary_key=True)
owner_id = sa.Column(
@ -31,21 +39,22 @@ class TestDependentObjects(TestCase):
)
owner = sa.orm.relationship(User)
return BlogPost
self.User = User
self.Article = Article
self.BlogPost = BlogPost
@pytest.fixture
def init_models(self, User, Article, BlogPost):
pass
def test_returns_all_dependent_objects(self):
user = self.User(first_name=u'John')
def test_returns_all_dependent_objects(self, session, User, Article):
user = User(first_name=u'John')
articles = [
self.Article(author=user),
self.Article(),
self.Article(owner=user),
self.Article(author=user, owner=user)
Article(author=user),
Article(),
Article(owner=user),
Article(author=user, owner=user)
]
self.session.add_all(articles)
self.session.commit()
session.add_all(articles)
session.commit()
deps = list(dependent_objects(user))
assert len(deps) == 3
@ -53,23 +62,29 @@ class TestDependentObjects(TestCase):
assert articles[2] in deps
assert articles[3] in deps
def test_with_foreign_keys_parameter(self):
user = self.User(first_name=u'John')
def test_with_foreign_keys_parameter(
self,
session,
User,
Article,
BlogPost
):
user = User(first_name=u'John')
objects = [
self.Article(author=user),
self.Article(),
self.Article(owner=user),
self.Article(author=user, owner=user),
self.BlogPost(owner=user)
Article(author=user),
Article(),
Article(owner=user),
Article(author=user, owner=user),
BlogPost(owner=user)
]
self.session.add_all(objects)
self.session.commit()
session.add_all(objects)
session.commit()
deps = list(
dependent_objects(
user,
(
fk for fk in get_referencing_foreign_keys(self.User)
fk for fk in get_referencing_foreign_keys(User)
if fk.ondelete == 'RESTRICT' or fk.ondelete is None
)
).limit(5)
@ -79,15 +94,20 @@ class TestDependentObjects(TestCase):
assert objects[3] in deps
class TestDependentObjectsWithColumnAliases(TestCase):
def create_models(self):
class User(self.Base):
class TestDependentObjectsWithColumnAliases(object):
@pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
first_name = sa.Column(sa.Unicode(255))
last_name = sa.Column(sa.Unicode(255))
return User
class Article(self.Base):
@pytest.fixture
def Article(self, Base, User):
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
author_id = sa.Column(
@ -100,8 +120,11 @@ class TestDependentObjectsWithColumnAliases(TestCase):
author = sa.orm.relationship(User, foreign_keys=[author_id])
owner = sa.orm.relationship(User, foreign_keys=[owner_id])
return Article
class BlogPost(self.Base):
@pytest.fixture
def BlogPost(self, Base, User):
class BlogPost(Base):
__tablename__ = 'blog_post'
id = sa.Column(sa.Integer, primary_key=True)
owner_id = sa.Column(
@ -110,21 +133,22 @@ class TestDependentObjectsWithColumnAliases(TestCase):
)
owner = sa.orm.relationship(User)
return BlogPost
self.User = User
self.Article = Article
self.BlogPost = BlogPost
@pytest.fixture
def init_models(self, User, Article, BlogPost):
pass
def test_returns_all_dependent_objects(self):
user = self.User(first_name=u'John')
def test_returns_all_dependent_objects(self, session, User, Article):
user = User(first_name=u'John')
articles = [
self.Article(author=user),
self.Article(),
self.Article(owner=user),
self.Article(author=user, owner=user)
Article(author=user),
Article(),
Article(owner=user),
Article(author=user, owner=user)
]
self.session.add_all(articles)
self.session.commit()
session.add_all(articles)
session.commit()
deps = list(dependent_objects(user))
assert len(deps) == 3
@ -132,23 +156,29 @@ class TestDependentObjectsWithColumnAliases(TestCase):
assert articles[2] in deps
assert articles[3] in deps
def test_with_foreign_keys_parameter(self):
user = self.User(first_name=u'John')
def test_with_foreign_keys_parameter(
self,
session,
User,
Article,
BlogPost
):
user = User(first_name=u'John')
objects = [
self.Article(author=user),
self.Article(),
self.Article(owner=user),
self.Article(author=user, owner=user),
self.BlogPost(owner=user)
Article(author=user),
Article(),
Article(owner=user),
Article(author=user, owner=user),
BlogPost(owner=user)
]
self.session.add_all(objects)
self.session.commit()
session.add_all(objects)
session.commit()
deps = list(
dependent_objects(
user,
(
fk for fk in get_referencing_foreign_keys(self.User)
fk for fk in get_referencing_foreign_keys(User)
if fk.ondelete == 'RESTRICT' or fk.ondelete is None
)
).limit(5)
@ -158,50 +188,64 @@ class TestDependentObjectsWithColumnAliases(TestCase):
assert objects[3] in deps
class TestDependentObjectsWithManyReferences(TestCase):
def create_models(self):
class User(self.Base):
class TestDependentObjectsWithManyReferences(object):
@pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
first_name = sa.Column(sa.Unicode(255))
last_name = sa.Column(sa.Unicode(255))
return User
class BlogPost(self.Base):
@pytest.fixture
def BlogPost(self, Base, User):
class BlogPost(Base):
__tablename__ = 'blog_post'
id = sa.Column(sa.Integer, primary_key=True)
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
author = sa.orm.relationship(User)
return BlogPost
class Article(self.Base):
@pytest.fixture
def Article(self, Base, User):
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
author = sa.orm.relationship(User)
return Article
self.User = User
self.Article = Article
self.BlogPost = BlogPost
@pytest.fixture
def init_models(self, User, BlogPost, Article):
pass
def test_with_many_dependencies(self):
user = self.User(first_name=u'John')
def test_with_many_dependencies(self, session, User, Article, BlogPost):
user = User(first_name=u'John')
objects = [
self.Article(author=user),
self.BlogPost(author=user)
Article(author=user),
BlogPost(author=user)
]
self.session.add_all(objects)
self.session.commit()
session.add_all(objects)
session.commit()
deps = list(dependent_objects(user))
assert len(deps) == 2
class TestDependentObjectsWithCompositeKeys(TestCase):
def create_models(self):
class User(self.Base):
class TestDependentObjectsWithCompositeKeys(object):
@pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user'
first_name = sa.Column(sa.Unicode(255), primary_key=True)
last_name = sa.Column(sa.Unicode(255), primary_key=True)
return User
class Article(self.Base):
@pytest.fixture
def Article(self, Base, User):
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
author_first_name = sa.Column(sa.Unicode(255))
@ -214,20 +258,22 @@ class TestDependentObjectsWithCompositeKeys(TestCase):
)
author = sa.orm.relationship(User)
return Article
self.User = User
self.Article = Article
@pytest.fixture
def init_models(self, User, Article):
pass
def test_returns_all_dependent_objects(self):
user = self.User(first_name=u'John', last_name=u'Smith')
def test_returns_all_dependent_objects(self, session, User, Article):
user = User(first_name=u'John', last_name=u'Smith')
articles = [
self.Article(author=user),
self.Article(),
self.Article(),
self.Article(author=user)
Article(author=user),
Article(),
Article(),
Article(author=user)
]
self.session.add_all(articles)
self.session.commit()
session.add_all(articles)
session.commit()
deps = list(dependent_objects(user))
assert len(deps) == 2
@ -235,14 +281,19 @@ class TestDependentObjectsWithCompositeKeys(TestCase):
assert articles[3] in deps
class TestDependentObjectsWithSingleTableInheritance(TestCase):
def create_models(self):
class Category(self.Base):
class TestDependentObjectsWithSingleTableInheritance(object):
@pytest.fixture
def Category(self, Base):
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
return Category
class TextItem(self.Base):
@pytest.fixture
def TextItem(self, Base, Category):
class TextItem(Base):
__tablename__ = 'text_item'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@ -261,33 +312,39 @@ class TestDependentObjectsWithSingleTableInheritance(TestCase):
__mapper_args__ = {
'polymorphic_on': type,
}
return TextItem
@pytest.fixture
def Article(self, TextItem):
class Article(TextItem):
__mapper_args__ = {
'polymorphic_identity': u'article'
}
return Article
@pytest.fixture
def BlogPost(self, TextItem):
class BlogPost(TextItem):
__mapper_args__ = {
'polymorphic_identity': u'blog_post'
}
return BlogPost
self.Category = Category
self.TextItem = TextItem
self.Article = Article
self.BlogPost = BlogPost
@pytest.fixture
def init_models(self, Category, TextItem, Article, BlogPost):
pass
def test_returns_all_dependent_objects(self):
category1 = self.Category(name=u'Category #1')
category2 = self.Category(name=u'Category #2')
def test_returns_all_dependent_objects(self, session, Category, Article):
category1 = Category(name=u'Category #1')
category2 = Category(name=u'Category #2')
articles = [
self.Article(category=category1),
self.Article(category=category1),
self.Article(category=category2),
self.Article(category=category2),
Article(category=category1),
Article(category=category1),
Article(category=category2),
Article(category=category2),
]
self.session.add_all(articles)
self.session.commit()
session.add_all(articles)
session.commit()
deps = list(dependent_objects(category1))
assert len(deps) == 2

View File

@ -1,7 +1,6 @@
from sqlalchemy_utils import escape_like
from tests import TestCase
class TestEscapeLike(TestCase):
class TestEscapeLike(object):
def test_escapes_wildcards(self):
assert escape_like('_*%') == '*_***%'

View File

@ -1,21 +1,20 @@
from pytest import raises
import pytest
from sqlalchemy_utils import get_bind
from tests import TestCase
class TestGetBind(TestCase):
def test_with_session(self):
assert get_bind(self.session) == self.connection
class TestGetBind(object):
def test_with_session(self, session, connection):
assert get_bind(session) == connection
def test_with_connection(self):
assert get_bind(self.connection) == self.connection
def test_with_connection(self, session, connection):
assert get_bind(connection) == connection
def test_with_model_object(self):
article = self.Article()
self.session.add(article)
assert get_bind(article) == self.connection
def test_with_model_object(self, session, connection, Article):
article = Article()
session.add(article)
assert get_bind(article) == connection
def test_with_unknown_type(self):
with raises(TypeError):
with pytest.raises(TypeError):
get_bind(None)

View File

@ -1,15 +1,14 @@
import pytest
import sqlalchemy as sa
from pytest import raises
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_class_by_table
class TestGetClassByTableWithJoinedTableInheritance(object):
def setup_method(self, method):
self.Base = declarative_base()
class Entity(self.Base):
@pytest.fixture
def Entity(self, Base):
class Entity(Base):
__tablename__ = 'entity'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
@ -18,7 +17,10 @@ class TestGetClassByTableWithJoinedTableInheritance(object):
'polymorphic_on': type,
'polymorphic_identity': 'entity'
}
return Entity
@pytest.fixture
def User(self, Entity):
class User(Entity):
__tablename__ = 'user'
id = sa.Column(
@ -29,31 +31,29 @@ class TestGetClassByTableWithJoinedTableInheritance(object):
__mapper_args__ = {
'polymorphic_identity': 'user'
}
return User
self.Entity = Entity
self.User = User
def test_returns_class(self):
assert get_class_by_table(self.Base, self.User.__table__) == self.User
def test_returns_class(self, Base, User, Entity):
assert get_class_by_table(Base, User.__table__) == User
assert get_class_by_table(
self.Base,
self.Entity.__table__
) == self.Entity
Base,
Entity.__table__
) == Entity
def test_table_with_no_associated_class(self):
def test_table_with_no_associated_class(self, Base):
table = sa.Table(
'some_table',
self.Base.metadata,
Base.metadata,
sa.Column('id', sa.Integer)
)
assert get_class_by_table(self.Base, table) is None
assert get_class_by_table(Base, table) is None
class TestGetClassByTableWithSingleTableInheritance(object):
def setup_method(self, method):
self.Base = declarative_base()
class Entity(self.Base):
@pytest.fixture
def Entity(self, Base):
class Entity(Base):
__tablename__ = 'entity'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
@ -62,38 +62,39 @@ class TestGetClassByTableWithSingleTableInheritance(object):
'polymorphic_on': type,
'polymorphic_identity': 'entity'
}
return Entity
@pytest.fixture
def User(self, Entity):
class User(Entity):
__mapper_args__ = {
'polymorphic_identity': 'user'
}
return User
self.Entity = Entity
self.User = User
def test_multiple_classes_without_data_parameter(self):
with raises(ValueError):
def test_multiple_classes_without_data_parameter(self, Base, Entity, User):
with pytest.raises(ValueError):
assert get_class_by_table(
self.Base,
self.Entity.__table__
Base,
Entity.__table__
)
def test_multiple_classes_with_data_parameter(self):
def test_multiple_classes_with_data_parameter(self, Base, Entity, User):
assert get_class_by_table(
self.Base,
self.Entity.__table__,
Base,
Entity.__table__,
{'type': 'entity'}
) == self.Entity
) == Entity
assert get_class_by_table(
self.Base,
self.Entity.__table__,
Base,
Entity.__table__,
{'type': 'user'}
) == self.User
) == User
def test_multiple_classes_with_bogus_data(self):
with raises(ValueError):
def test_multiple_classes_with_bogus_data(self, Base, Entity, User):
with pytest.raises(ValueError):
assert get_class_by_table(
self.Base,
self.Entity.__table__,
Base,
Entity.__table__,
{'type': 'unknown'}
)

View File

@ -1,42 +1,44 @@
from copy import copy
import pytest
import sqlalchemy as sa
from pytest import raises
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_column_key
@pytest.fixture
def Building(Base):
class Building(Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column('_name', sa.Unicode(255))
return Building
@pytest.fixture
def Movie(Base):
class Movie(Base):
__tablename__ = 'movie'
id = sa.Column(sa.Integer, primary_key=True)
return Movie
class TestGetColumnKey(object):
def setup_method(self, method):
Base = declarative_base()
class Building(Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column('_name', sa.Unicode(255))
class Movie(Base):
__tablename__ = 'movie'
id = sa.Column(sa.Integer, primary_key=True)
self.Building = Building
self.Movie = Movie
def test_supports_aliases(self):
def test_supports_aliases(self, Building):
assert (
get_column_key(self.Building, self.Building.__table__.c.id) ==
get_column_key(Building, Building.__table__.c.id) ==
'id'
)
assert (
get_column_key(self.Building, self.Building.__table__.c._name) ==
get_column_key(Building, Building.__table__.c._name) ==
'name'
)
def test_supports_vague_matching_of_column_objects(self):
column = copy(self.Building.__table__.c._name)
assert get_column_key(self.Building, column) == 'name'
def test_supports_vague_matching_of_column_objects(self, Building):
column = copy(Building.__table__.c._name)
assert get_column_key(Building, column) == 'name'
def test_throws_value_error_for_unknown_column(self):
with raises(sa.orm.exc.UnmappedColumnError):
get_column_key(self.Building, self.Movie.__table__.c.id)
def test_throws_value_error_for_unknown_column(self, Building, Movie):
with pytest.raises(sa.orm.exc.UnmappedColumnError):
get_column_key(Building, Movie.__table__.c.id)

View File

@ -1,65 +1,65 @@
import pytest
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_columns
@pytest.fixture
def Building(Base):
class Building(Base):
__tablename__ = 'building'
id = sa.Column('_id', sa.Integer, primary_key=True)
name = sa.Column('_name', sa.Unicode(255))
return Building
class TestGetColumns(object):
def setup_method(self, method):
Base = declarative_base()
class Building(Base):
__tablename__ = 'building'
id = sa.Column('_id', sa.Integer, primary_key=True)
name = sa.Column('_name', sa.Unicode(255))
self.Building = Building
def test_table(self):
def test_table(self, Building):
assert isinstance(
get_columns(self.Building.__table__),
get_columns(Building.__table__),
sa.sql.base.ImmutableColumnCollection
)
def test_instrumented_attribute(self):
assert get_columns(self.Building.id) == [self.Building.__table__.c._id]
def test_instrumented_attribute(self, Building):
assert get_columns(Building.id) == [Building.__table__.c._id]
def test_column_property(self):
assert get_columns(self.Building.id.property) == [
self.Building.__table__.c._id
def test_column_property(self, Building):
assert get_columns(Building.id.property) == [
Building.__table__.c._id
]
def test_column(self):
assert get_columns(self.Building.__table__.c._id) == [
self.Building.__table__.c._id
def test_column(self, Building):
assert get_columns(Building.__table__.c._id) == [
Building.__table__.c._id
]
def test_declarative_class(self):
def test_declarative_class(self, Building):
assert isinstance(
get_columns(self.Building),
get_columns(Building),
sa.util._collections.OrderedProperties
)
def test_declarative_object(self):
def test_declarative_object(self, Building):
assert isinstance(
get_columns(self.Building()),
get_columns(Building()),
sa.util._collections.OrderedProperties
)
def test_mapper(self):
def test_mapper(self, Building):
assert isinstance(
get_columns(self.Building.__mapper__),
get_columns(Building.__mapper__),
sa.util._collections.OrderedProperties
)
def test_class_alias(self):
def test_class_alias(self, Building):
assert isinstance(
get_columns(sa.orm.aliased(self.Building)),
get_columns(sa.orm.aliased(Building)),
sa.util._collections.OrderedProperties
)
def test_table_alias(self):
alias = sa.orm.aliased(self.Building.__table__)
def test_table_alias(self, Building):
alias = sa.orm.aliased(Building.__table__)
assert isinstance(
get_columns(alias),
sa.sql.base.ImmutableColumnCollection

View File

@ -1,41 +1,41 @@
import pytest
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy_utils import get_hybrid_properties
@pytest.fixture
def Category(Base):
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@hybrid_property
def lowercase_name(self):
return self.name.lower()
@lowercase_name.expression
def lowercase_name(cls):
return sa.func.lower(cls.name)
return Category
class TestGetHybridProperties(object):
def setup_method(self, method):
Base = declarative_base()
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@hybrid_property
def lowercase_name(self):
return self.name.lower()
@lowercase_name.expression
def lowercase_name(cls):
return sa.func.lower(cls.name)
self.Category = Category
def test_declarative_model(self):
def test_declarative_model(self, Category):
assert (
list(get_hybrid_properties(self.Category).keys()) ==
list(get_hybrid_properties(Category).keys()) ==
['lowercase_name']
)
def test_mapper(self):
def test_mapper(self, Category):
assert (
list(get_hybrid_properties(sa.inspect(self.Category)).keys()) ==
list(get_hybrid_properties(sa.inspect(Category)).keys()) ==
['lowercase_name']
)
def test_aliased_class(self):
props = get_hybrid_properties(sa.orm.aliased(self.Category))
def test_aliased_class(self, Category):
props = get_hybrid_properties(sa.orm.aliased(Category))
assert list(props.keys()) == ['lowercase_name']

View File

@ -1,104 +1,106 @@
import pytest
import sqlalchemy as sa
from pytest import raises
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_mapper
from tests import TestCase
class TestGetMapper(object):
def setup_method(self, method):
self.Base = declarative_base()
class Building(self.Base):
@pytest.fixture
def Building(self, Base):
class Building(Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
return Building
self.Building = Building
def test_table(self, Building):
assert get_mapper(Building.__table__) == sa.inspect(Building)
def test_table(self):
assert get_mapper(self.Building.__table__) == sa.inspect(self.Building)
def test_declarative_class(self):
def test_declarative_class(self, Building):
assert (
get_mapper(self.Building) ==
sa.inspect(self.Building)
get_mapper(Building) ==
sa.inspect(Building)
)
def test_declarative_object(self):
def test_declarative_object(self, Building):
assert (
get_mapper(self.Building()) ==
sa.inspect(self.Building)
get_mapper(Building()) ==
sa.inspect(Building)
)
def test_mapper(self):
def test_mapper(self, Building):
assert (
get_mapper(self.Building.__mapper__) ==
sa.inspect(self.Building)
get_mapper(Building.__mapper__) ==
sa.inspect(Building)
)
def test_class_alias(self):
def test_class_alias(self, Building):
assert (
get_mapper(sa.orm.aliased(self.Building)) ==
sa.inspect(self.Building)
get_mapper(sa.orm.aliased(Building)) ==
sa.inspect(Building)
)
def test_instrumented_attribute(self):
def test_instrumented_attribute(self, Building):
assert (
get_mapper(self.Building.id) == sa.inspect(self.Building)
get_mapper(Building.id) == sa.inspect(Building)
)
def test_table_alias(self):
alias = sa.orm.aliased(self.Building.__table__)
def test_table_alias(self, Building):
alias = sa.orm.aliased(Building.__table__)
assert (
get_mapper(alias) ==
sa.inspect(self.Building)
sa.inspect(Building)
)
def test_column(self):
def test_column(self, Building):
assert (
get_mapper(self.Building.__table__.c.id) ==
sa.inspect(self.Building)
get_mapper(Building.__table__.c.id) ==
sa.inspect(Building)
)
def test_column_of_an_alias(self):
def test_column_of_an_alias(self, Building):
assert (
get_mapper(sa.orm.aliased(self.Building.__table__).c.id) ==
sa.inspect(self.Building)
get_mapper(sa.orm.aliased(Building.__table__).c.id) ==
sa.inspect(Building)
)
class TestGetMapperWithQueryEntities(TestCase):
def create_models(self):
class Building(self.Base):
class TestGetMapperWithQueryEntities(object):
@pytest.fixture
def Building(self, Base):
class Building(Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
return Building
self.Building = Building
@pytest.fixture
def init_models(self, Building):
pass
def test_mapper_entity_with_mapper(self):
entity = self.session.query(self.Building.__mapper__)._entities[0]
def test_mapper_entity_with_mapper(self, session, Building):
entity = session.query(Building.__mapper__)._entities[0]
assert (
get_mapper(entity) ==
sa.inspect(self.Building)
sa.inspect(Building)
)
def test_mapper_entity_with_class(self):
entity = self.session.query(self.Building)._entities[0]
def test_mapper_entity_with_class(self, session, Building):
entity = session.query(Building)._entities[0]
assert (
get_mapper(entity) ==
sa.inspect(self.Building)
sa.inspect(Building)
)
def test_column_entity(self):
query = self.session.query(self.Building.id)
assert get_mapper(query._entities[0]) == sa.inspect(self.Building)
def test_column_entity(self, session, Building):
query = session.query(Building.id)
assert get_mapper(query._entities[0]) == sa.inspect(Building)
class TestGetMapperWithMultipleMappersFound(object):
def setup_method(self, method):
Base = declarative_base()
@pytest.fixture
def Building(self, Base):
class Building(Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
@ -106,29 +108,30 @@ class TestGetMapperWithMultipleMappersFound(object):
class BigBuilding(Building):
pass
self.Building = Building
self.BigBuilding = BigBuilding
return Building
def test_table(self):
with raises(ValueError):
get_mapper(self.Building.__table__)
def test_table(self, Building):
with pytest.raises(ValueError):
get_mapper(Building.__table__)
def test_table_alias(self):
alias = sa.orm.aliased(self.Building.__table__)
with raises(ValueError):
def test_table_alias(self, Building):
alias = sa.orm.aliased(Building.__table__)
with pytest.raises(ValueError):
get_mapper(alias)
class TestGetMapperForTableWithoutMapper(object):
def setup_method(self, method):
@pytest.fixture
def building(self):
metadata = sa.MetaData()
self.building = sa.Table('building', metadata)
return sa.Table('building', metadata)
def test_table(self):
with raises(ValueError):
get_mapper(self.building)
def test_table(self, building):
with pytest.raises(ValueError):
get_mapper(building)
def test_table_alias(self):
alias = sa.orm.aliased(self.building)
with raises(ValueError):
def test_table_alias(self, building):
alias = sa.orm.aliased(building)
with pytest.raises(ValueError):
get_mapper(alias)

View File

@ -1,5 +1,5 @@
import pytest
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_primary_keys
@ -9,40 +9,40 @@ except ImportError:
from ordereddict import OrderedDict
@pytest.fixture
def Building(Base):
class Building(Base):
__tablename__ = 'building'
id = sa.Column('_id', sa.Integer, primary_key=True)
name = sa.Column('_name', sa.Unicode(255))
return Building
class TestGetPrimaryKeys(object):
def setup_method(self, method):
Base = declarative_base()
class Building(Base):
__tablename__ = 'building'
id = sa.Column('_id', sa.Integer, primary_key=True)
name = sa.Column('_name', sa.Unicode(255))
self.Building = Building
def test_table(self):
assert get_primary_keys(self.Building.__table__) == OrderedDict({
'_id': self.Building.__table__.c._id
def test_table(self, Building):
assert get_primary_keys(Building.__table__) == OrderedDict({
'_id': Building.__table__.c._id
})
def test_declarative_class(self):
assert get_primary_keys(self.Building) == OrderedDict({
'id': self.Building.__table__.c._id
def test_declarative_class(self, Building):
assert get_primary_keys(Building) == OrderedDict({
'id': Building.__table__.c._id
})
def test_declarative_object(self):
assert get_primary_keys(self.Building()) == OrderedDict({
'id': self.Building.__table__.c._id
def test_declarative_object(self, Building):
assert get_primary_keys(Building()) == OrderedDict({
'id': Building.__table__.c._id
})
def test_class_alias(self):
alias = sa.orm.aliased(self.Building)
def test_class_alias(self, Building):
alias = sa.orm.aliased(Building)
assert get_primary_keys(alias) == OrderedDict({
'id': self.Building.__table__.c._id
'id': Building.__table__.c._id
})
def test_table_alias(self):
alias = sa.orm.aliased(self.Building.__table__)
def test_table_alias(self, Building):
alias = sa.orm.aliased(Building.__table__)
assert get_primary_keys(alias) == OrderedDict({
'_id': alias.c._id
})

View File

@ -1,102 +1,115 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils import get_query_entities
from tests import TestCase
class TestGetQueryEntities(TestCase):
def create_models(self):
class TextItem(self.Base):
__tablename__ = 'text_item'
id = sa.Column(sa.Integer, primary_key=True)
@pytest.fixture
def TextItem(Base):
class TextItem(Base):
__tablename__ = 'text_item'
id = sa.Column(sa.Integer, primary_key=True)
type = sa.Column(sa.Unicode(255))
type = sa.Column(sa.Unicode(255))
__mapper_args__ = {
'polymorphic_on': type,
}
__mapper_args__ = {
'polymorphic_on': type,
}
return TextItem
class Article(TextItem):
__tablename__ = 'article'
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
category = sa.Column(sa.Unicode(255))
__mapper_args__ = {
'polymorphic_identity': u'article'
}
class BlogPost(TextItem):
__tablename__ = 'blog_post'
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
__mapper_args__ = {
'polymorphic_identity': u'blog_post'
}
@pytest.fixture
def Article(TextItem):
class Article(TextItem):
__tablename__ = 'article'
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
category = sa.Column(sa.Unicode(255))
__mapper_args__ = {
'polymorphic_identity': u'article'
}
return Article
self.TextItem = TextItem
self.Article = Article
self.BlogPost = BlogPost
def test_mapper(self):
query = self.session.query(sa.inspect(self.TextItem))
assert get_query_entities(query) == [self.TextItem]
@pytest.fixture
def BlogPost(TextItem):
class BlogPost(TextItem):
__tablename__ = 'blog_post'
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
__mapper_args__ = {
'polymorphic_identity': u'blog_post'
}
return BlogPost
def test_entity(self):
query = self.session.query(self.TextItem)
assert get_query_entities(query) == [self.TextItem]
def test_instrumented_attribute(self):
query = self.session.query(self.TextItem.id)
assert get_query_entities(query) == [self.TextItem]
@pytest.fixture
def init_models(TextItem, Article, BlogPost):
pass
def test_column(self):
query = self.session.query(self.TextItem.__table__.c.id)
assert get_query_entities(query) == [self.TextItem.__table__]
def test_aliased_selectable(self):
selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
query = self.session.query(selectable)
class TestGetQueryEntities(object):
def test_mapper(self, session, TextItem):
query = session.query(sa.inspect(TextItem))
assert get_query_entities(query) == [TextItem]
def test_entity(self, session, TextItem):
query = session.query(TextItem)
assert get_query_entities(query) == [TextItem]
def test_instrumented_attribute(self, session, TextItem):
query = session.query(TextItem.id)
assert get_query_entities(query) == [TextItem]
def test_column(self, session, TextItem):
query = session.query(TextItem.__table__.c.id)
assert get_query_entities(query) == [TextItem.__table__]
def test_aliased_selectable(self, session, TextItem, BlogPost):
selectable = sa.orm.with_polymorphic(TextItem, [BlogPost])
query = session.query(selectable)
assert get_query_entities(query) == [selectable]
def test_joined_entity(self):
query = self.session.query(self.TextItem).join(
self.BlogPost, self.BlogPost.id == self.TextItem.id
def test_joined_entity(self, session, TextItem, BlogPost):
query = session.query(TextItem).join(
BlogPost, BlogPost.id == TextItem.id
)
assert get_query_entities(query) == [
self.TextItem, sa.inspect(self.BlogPost)
TextItem, sa.inspect(BlogPost)
]
def test_joined_aliased_entity(self):
alias = sa.orm.aliased(self.BlogPost)
def test_joined_aliased_entity(self, session, TextItem, BlogPost):
alias = sa.orm.aliased(BlogPost)
query = self.session.query(self.TextItem).join(
alias, alias.id == self.TextItem.id
query = session.query(TextItem).join(
alias, alias.id == TextItem.id
)
assert get_query_entities(query) == [self.TextItem, alias]
assert get_query_entities(query) == [TextItem, alias]
def test_column_entity_with_label(self):
query = self.session.query(self.Article.id.label('id'))
assert get_query_entities(query) == [self.Article]
def test_column_entity_with_label(self, session, Article):
query = session.query(Article.id.label('id'))
assert get_query_entities(query) == [Article]
def test_with_subquery(self):
def test_with_subquery(self, session, Article):
number_of_articles = (
sa.select(
[sa.func.count(self.Article.id)],
[sa.func.count(Article.id)],
)
.select_from(
self.Article.__table__
Article.__table__
)
).label('number_of_articles')
query = self.session.query(self.Article, number_of_articles)
query = session.query(Article, number_of_articles)
assert get_query_entities(query) == [
self.Article,
Article,
number_of_articles
]
def test_aliased_entity(self):
alias = sa.orm.aliased(self.Article)
query = self.session.query(alias)
def test_aliased_entity(self, session, Article):
alias = sa.orm.aliased(Article)
query = session.query(alias)
assert get_query_entities(query) == [alias]

View File

@ -1,17 +1,22 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils import get_referencing_foreign_keys
from tests import TestCase
class TestGetReferencingFksWithCompositeKeys(TestCase):
def create_models(self):
class User(self.Base):
class TestGetReferencingFksWithCompositeKeys(object):
@pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user'
first_name = sa.Column(sa.Unicode(255), primary_key=True)
last_name = sa.Column(sa.Unicode(255), primary_key=True)
return User
class Article(self.Base):
@pytest.fixture
def Article(self, Base, User):
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
author_first_name = sa.Column(sa.Unicode(255))
@ -22,22 +27,26 @@ class TestGetReferencingFksWithCompositeKeys(TestCase):
[User.first_name, User.last_name]
),
)
return Article
self.User = User
self.Article = Article
@pytest.fixture
def init_models(self, User, Article):
pass
def test_with_declarative_class(self):
fks = get_referencing_foreign_keys(self.User)
assert self.Article.__table__.foreign_keys == fks
def test_with_declarative_class(self, User, Article):
fks = get_referencing_foreign_keys(User)
assert Article.__table__.foreign_keys == fks
def test_with_table(self):
fks = get_referencing_foreign_keys(self.User.__table__)
assert self.Article.__table__.foreign_keys == fks
def test_with_table(self, User, Article):
fks = get_referencing_foreign_keys(User.__table__)
assert Article.__table__.foreign_keys == fks
class TestGetReferencingFksWithInheritance(TestCase):
def create_models(self):
class User(self.Base):
class TestGetReferencingFksWithInheritance(object):
@pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
type = sa.Column(sa.Unicode)
@ -47,14 +56,20 @@ class TestGetReferencingFksWithInheritance(TestCase):
__mapper_args__ = {
'polymorphic_on': 'type'
}
return User
@pytest.fixture
def Admin(self, User):
class Admin(User):
__tablename__ = 'admin'
id = sa.Column(
sa.Integer, sa.ForeignKey(User.id), primary_key=True
)
return Admin
class TextItem(self.Base):
@pytest.fixture
def TextItem(self, Base, User):
class TextItem(Base):
__tablename__ = 'textitem'
id = sa.Column(sa.Integer, primary_key=True)
type = sa.Column(sa.Unicode)
@ -62,7 +77,10 @@ class TestGetReferencingFksWithInheritance(TestCase):
__mapper_args__ = {
'polymorphic_on': 'type'
}
return TextItem
@pytest.fixture
def Article(self, TextItem):
class Article(TextItem):
__tablename__ = 'article'
id = sa.Column(
@ -71,16 +89,16 @@ class TestGetReferencingFksWithInheritance(TestCase):
__mapper_args__ = {
'polymorphic_identity': 'article'
}
return Article
self.Admin = Admin
self.User = User
self.Article = Article
self.TextItem = TextItem
@pytest.fixture
def init_models(self, User, Admin, TextItem, Article):
pass
def test_with_declarative_class(self):
fks = get_referencing_foreign_keys(self.Admin)
assert self.TextItem.__table__.foreign_keys == fks
def test_with_declarative_class(self, Admin, TextItem):
fks = get_referencing_foreign_keys(Admin)
assert TextItem.__table__.foreign_keys == fks
def test_with_table(self):
fks = get_referencing_foreign_keys(self.Admin.__table__)
def test_with_table(self, Admin):
fks = get_referencing_foreign_keys(Admin.__table__)
assert fks == set([])

View File

@ -1,76 +1,86 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils import get_tables
from tests import TestCase
class TestGetTables(TestCase):
def create_models(self):
class TextItem(self.Base):
__tablename__ = 'text_item'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
type = sa.Column(sa.Unicode(255))
@pytest.fixture
def TextItem(Base):
class TextItem(Base):
__tablename__ = 'text_item'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
type = sa.Column(sa.Unicode(255))
__mapper_args__ = {
'polymorphic_on': type,
'with_polymorphic': '*'
}
__mapper_args__ = {
'polymorphic_on': type,
'with_polymorphic': '*'
}
return TextItem
class Article(TextItem):
__tablename__ = 'article'
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
__mapper_args__ = {
'polymorphic_identity': u'article'
}
self.TextItem = TextItem
self.Article = Article
@pytest.fixture
def Article(TextItem):
class Article(TextItem):
__tablename__ = 'article'
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
__mapper_args__ = {
'polymorphic_identity': u'article'
}
return Article
def test_child_class_using_join_table_inheritance(self):
assert get_tables(self.Article) == [
self.TextItem.__table__,
self.Article.__table__
@pytest.fixture
def init_models(TextItem, Article):
pass
class TestGetTables(object):
def test_child_class_using_join_table_inheritance(self, TextItem, Article):
assert get_tables(Article) == [
TextItem.__table__,
Article.__table__
]
def test_entity_using_with_polymorphic(self):
assert get_tables(self.TextItem) == [
self.TextItem.__table__,
self.Article.__table__
def test_entity_using_with_polymorphic(self, TextItem, Article):
assert get_tables(TextItem) == [
TextItem.__table__,
Article.__table__
]
def test_instrumented_attribute(self):
assert get_tables(self.TextItem.name) == [
self.TextItem.__table__,
def test_instrumented_attribute(self, TextItem):
assert get_tables(TextItem.name) == [
TextItem.__table__,
]
def test_polymorphic_instrumented_attribute(self):
assert get_tables(self.Article.id) == [
self.TextItem.__table__,
self.Article.__table__
def test_polymorphic_instrumented_attribute(self, TextItem, Article):
assert get_tables(Article.id) == [
TextItem.__table__,
Article.__table__
]
def test_column(self):
assert get_tables(self.Article.__table__.c.id) == [
self.Article.__table__
def test_column(self, Article):
assert get_tables(Article.__table__.c.id) == [
Article.__table__
]
def test_mapper_entity_with_class(self):
query = self.session.query(self.Article)
def test_mapper_entity_with_class(self, session, TextItem, Article):
query = session.query(Article)
assert get_tables(query._entities[0]) == [
self.TextItem.__table__, self.Article.__table__
TextItem.__table__, Article.__table__
]
def test_mapper_entity_with_mapper(self):
query = self.session.query(sa.inspect(self.Article))
def test_mapper_entity_with_mapper(self, session, TextItem, Article):
query = session.query(sa.inspect(Article))
assert get_tables(query._entities[0]) == [
self.TextItem.__table__, self.Article.__table__
TextItem.__table__, Article.__table__
]
def test_column_entity(self):
query = self.session.query(self.Article.id)
def test_column_entity(self, session, TextItem, Article):
query = session.query(Article.id)
assert get_tables(query._entities[0]) == [
self.TextItem.__table__, self.Article.__table__
TextItem.__table__, Article.__table__
]

View File

@ -1,46 +1,49 @@
import pytest
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_type
@pytest.fixture
def User(Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
return User
@pytest.fixture
def Article(Base, User):
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
author = sa.orm.relationship(User)
some_property = sa.orm.column_property(
sa.func.coalesce(id, 1)
)
return Article
class TestGetType(object):
def setup_method(self, method):
Base = declarative_base()
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
def test_instrumented_attribute(self, Article):
assert isinstance(get_type(Article.id), sa.Integer)
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
def test_column_property(self, Article):
assert isinstance(get_type(Article.id.property), sa.Integer)
author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
author = sa.orm.relationship(User)
def test_column(self, Article):
assert isinstance(get_type(Article.__table__.c.id), sa.Integer)
some_property = sa.orm.column_property(
sa.func.coalesce(id, 1)
)
def test_calculated_column_property(self, Article):
assert isinstance(get_type(Article.some_property), sa.Integer)
self.Article = Article
self.User = User
def test_relationship_property(self, Article, User):
assert get_type(Article.author) == User
def test_instrumented_attribute(self):
assert isinstance(get_type(self.Article.id), sa.Integer)
def test_column_property(self):
assert isinstance(get_type(self.Article.id.property), sa.Integer)
def test_column(self):
assert isinstance(get_type(self.Article.__table__.c.id), sa.Integer)
def test_calculated_column_property(self):
assert isinstance(get_type(self.Article.some_property), sa.Integer)
def test_relationship_property(self):
assert get_type(self.Article.author) == self.User
def test_scalar_select(self):
query = sa.select([self.Article.id]).as_scalar()
def test_scalar_select(self, Article):
query = sa.select([Article.id]).as_scalar()
assert isinstance(get_type(query), sa.Integer)

View File

@ -1,72 +1,94 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.functions import getdotattr
from tests import TestCase
class TestGetDotAttr(TestCase):
def create_models(self):
class Document(self.Base):
__tablename__ = 'document'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@pytest.fixture
def Document(Base):
class Document(Base):
__tablename__ = 'document'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
return Document
class Section(self.Base):
__tablename__ = 'section'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
document_id = sa.Column(
sa.Integer, sa.ForeignKey(Document.id)
)
@pytest.fixture
def Section(Base, Document):
class Section(Base):
__tablename__ = 'section'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
document = sa.orm.relationship(Document, backref='sections')
document_id = sa.Column(
sa.Integer, sa.ForeignKey(Document.id)
)
class SubSection(self.Base):
__tablename__ = 'subsection'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
document = sa.orm.relationship(Document, backref='sections')
return Section
section_id = sa.Column(
sa.Integer, sa.ForeignKey(Section.id)
)
section = sa.orm.relationship(Section, backref='subsections')
@pytest.fixture
def SubSection(Base, Section):
class SubSection(Base):
__tablename__ = 'subsection'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class SubSubSection(self.Base):
__tablename__ = 'subsubsection'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
section_id = sa.Column(
sa.Integer, sa.ForeignKey(Section.id)
)
subsection_id = sa.Column(
sa.Integer, sa.ForeignKey(SubSection.id)
)
section = sa.orm.relationship(Section, backref='subsections')
return SubSection
subsection = sa.orm.relationship(
SubSection, backref='subsubsections'
)
self.Document = Document
self.Section = Section
self.SubSection = SubSection
self.SubSubSection = SubSubSection
@pytest.fixture
def SubSubSection(Base, SubSection):
class SubSubSection(Base):
__tablename__ = 'subsubsection'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
def test_simple_objects(self):
document = self.Document(name=u'some document')
section = self.Section(document=document)
subsection = self.SubSection(section=section)
subsection_id = sa.Column(
sa.Integer, sa.ForeignKey(SubSection.id)
)
subsection = sa.orm.relationship(
SubSection, backref='subsubsections'
)
return SubSubSection
@pytest.fixture
def init_models(Document, Section, SubSection, SubSubSection):
pass
class TestGetDotAttr(object):
def test_simple_objects(self, Document, Section, SubSection):
document = Document(name=u'some document')
section = Section(document=document)
subsection = SubSection(section=section)
assert getdotattr(
subsection,
'section.document.name'
) == u'some document'
def test_with_instrumented_lists(self):
document = self.Document(name=u'some document')
section = self.Section(document=document)
subsection = self.SubSection(section=section)
subsubsection = self.SubSubSection(subsection=subsection)
def test_with_instrumented_lists(
self,
Document,
Section,
SubSection,
SubSubSection
):
document = Document(name=u'some document')
section = Section(document=document)
subsection = SubSection(section=section)
subsubsection = SubSubSection(subsection=subsection)
assert getdotattr(document, 'sections') == [section]
assert getdotattr(document, 'sections.subsections') == [
@ -76,10 +98,10 @@ class TestGetDotAttr(TestCase):
subsubsection
]
def test_class_paths(self):
assert getdotattr(self.Section, 'document') is self.Section.document
def test_class_paths(self, Document, Section, SubSection):
assert getdotattr(Section, 'document') is Section.document
assert (
getdotattr(self.SubSection, 'section.document') is
self.Section.document
getdotattr(SubSection, 'section.document') is
Section.document
)
assert getdotattr(self.Section, 'document.name') is self.Document.name
assert getdotattr(Section, 'document.name') is Document.name

View File

@ -1,47 +1,44 @@
import pytest
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import has_changes
class HasChangesTestCase(object):
def setup_method(self, method):
Base = declarative_base()
class Article(Base):
__tablename__ = 'article_translation'
id = sa.Column(sa.Integer, primary_key=True)
title = sa.Column(sa.String(100))
self.Article = Article
@pytest.fixture
def Article(Base):
class Article(Base):
__tablename__ = 'article_translation'
id = sa.Column(sa.Integer, primary_key=True)
title = sa.Column(sa.String(100))
return Article
class TestHasChangesWithStringAttr(HasChangesTestCase):
def test_without_changed_attr(self):
article = self.Article()
class TestHasChangesWithStringAttr(object):
def test_without_changed_attr(self, Article):
article = Article()
assert not has_changes(article, 'title')
def test_with_changed_attr(self):
article = self.Article(title='Some title')
def test_with_changed_attr(self, Article):
article = Article(title='Some title')
assert has_changes(article, 'title')
class TestHasChangesWithMultipleAttrs(HasChangesTestCase):
def test_without_changed_attr(self):
article = self.Article()
class TestHasChangesWithMultipleAttrs(object):
def test_without_changed_attr(self, Article):
article = Article()
assert not has_changes(article, ['title'])
def test_with_changed_attr(self):
article = self.Article(title='Some title')
def test_with_changed_attr(self, Article):
article = Article(title='Some title')
assert has_changes(article, ['title', 'id'])
class TestHasChangesWithExclude(HasChangesTestCase):
def test_without_changed_attr(self):
article = self.Article()
class TestHasChangesWithExclude(object):
def test_without_changed_attr(self, Article):
article = Article()
assert not has_changes(article, exclude=['id'])
def test_with_changed_attr(self):
article = self.Article(title='Some title')
def test_with_changed_attr(self, Article):
article = Article(title='Some title')
assert has_changes(article, exclude=['id'])
assert not has_changes(article, exclude=['title'])

View File

@ -1,14 +1,13 @@
import pytest
import sqlalchemy as sa
from pytest import raises
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_fk_constraint_for_columns, has_index
class TestHasIndex(object):
def setup_method(self, method):
Base = declarative_base()
@pytest.fixture
def table(self, Base):
class ArticleTranslation(Base):
__tablename__ = 'article_translation'
id = sa.Column(sa.Integer, primary_key=True)
@ -21,24 +20,23 @@ class TestHasIndex(object):
__table_args__ = (
sa.Index('my_index', is_deleted, is_archived),
)
return ArticleTranslation.__table__
self.table = ArticleTranslation.__table__
def test_column_that_belongs_to_an_alias(self):
alias = sa.orm.aliased(self.table)
with raises(TypeError):
def test_column_that_belongs_to_an_alias(self, table):
alias = sa.orm.aliased(table)
with pytest.raises(TypeError):
assert has_index(alias.c.id)
def test_compound_primary_key(self):
assert has_index(self.table.c.id)
assert not has_index(self.table.c.locale)
def test_compound_primary_key(self, table):
assert has_index(table.c.id)
assert not has_index(table.c.locale)
def test_single_column_index(self):
assert has_index(self.table.c.is_published)
def test_single_column_index(self, table):
assert has_index(table.c.is_published)
def test_compound_column_index(self):
assert has_index(self.table.c.is_deleted)
assert not has_index(self.table.c.is_archived)
def test_compound_column_index(self, table):
assert has_index(table.c.is_deleted)
assert not has_index(table.c.is_archived)
def test_table_without_primary_key(self):
article = sa.Table(
@ -50,8 +48,7 @@ class TestHasIndex(object):
class TestHasIndexWithFKConstraint(object):
def test_composite_fk_without_index(self):
Base = declarative_base()
def test_composite_fk_without_index(self, Base):
class User(Base):
__tablename__ = 'user'
@ -78,8 +75,7 @@ class TestHasIndexWithFKConstraint(object):
)
assert not has_index(constraint)
def test_composite_fk_with_index(self):
Base = declarative_base()
def test_composite_fk_with_index(self, Base):
class User(Base):
__tablename__ = 'user'
@ -109,8 +105,7 @@ class TestHasIndexWithFKConstraint(object):
)
assert has_index(constraint)
def test_composite_fk_with_partial_index_match(self):
Base = declarative_base()
def test_composite_fk_with_partial_index_match(self, Base):
class User(Base):
__tablename__ = 'user'

View File

@ -1,18 +1,20 @@
import pytest
import sqlalchemy as sa
from pytest import raises
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_fk_constraint_for_columns, has_unique_index
class TestHasUniqueIndex(object):
def setup_method(self, method):
Base = declarative_base()
@pytest.fixture
def articles(self, Base):
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
return Article.__table__
@pytest.fixture
def article_translations(self, Base):
class ArticleTranslation(Base):
__tablename__ = 'article_translation'
id = sa.Column(sa.Integer, primary_key=True)
@ -26,35 +28,33 @@ class TestHasUniqueIndex(object):
sa.Index('my_index', is_archived, is_published, unique=True),
)
self.articles = Article.__table__
self.article_translations = ArticleTranslation.__table__
return ArticleTranslation.__table__
def test_primary_key(self):
assert has_unique_index(self.articles.c.id)
def test_primary_key(self, articles):
assert has_unique_index(articles.c.id)
def test_column_of_aliased_table(self):
alias = sa.orm.aliased(self.articles)
with raises(TypeError):
def test_column_of_aliased_table(self, articles):
alias = sa.orm.aliased(articles)
with pytest.raises(TypeError):
assert has_unique_index(alias.c.id)
def test_unique_index(self):
assert has_unique_index(self.article_translations.c.is_deleted)
def test_unique_index(self, article_translations):
assert has_unique_index(article_translations.c.is_deleted)
def test_compound_primary_key(self):
assert not has_unique_index(self.article_translations.c.id)
assert not has_unique_index(self.article_translations.c.locale)
def test_compound_primary_key(self, article_translations):
assert not has_unique_index(article_translations.c.id)
assert not has_unique_index(article_translations.c.locale)
def test_single_column_index(self):
assert not has_unique_index(self.article_translations.c.is_published)
def test_single_column_index(self, article_translations):
assert not has_unique_index(article_translations.c.is_published)
def test_compound_column_unique_index(self):
assert not has_unique_index(self.article_translations.c.is_published)
assert not has_unique_index(self.article_translations.c.is_archived)
def test_compound_column_unique_index(self, article_translations):
assert not has_unique_index(article_translations.c.is_published)
assert not has_unique_index(article_translations.c.is_archived)
class TestHasUniqueIndexWithFKConstraint(object):
def test_composite_fk_without_index(self):
Base = declarative_base()
def test_composite_fk_without_index(self, Base):
class User(Base):
__tablename__ = 'user'
@ -81,8 +81,7 @@ class TestHasUniqueIndexWithFKConstraint(object):
)
assert not has_unique_index(constraint)
def test_composite_fk_with_index(self):
Base = declarative_base()
def test_composite_fk_with_index(self, Base):
class User(Base):
__tablename__ = 'user'
@ -115,8 +114,7 @@ class TestHasUniqueIndexWithFKConstraint(object):
)
assert has_unique_index(constraint)
def test_composite_fk_with_partial_index_match(self):
Base = declarative_base()
def test_composite_fk_with_partial_index_match(self, Base):
class User(Base):
__tablename__ = 'user'

View File

@ -1,39 +1,46 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.functions import identity
from tests import TestCase
class IdentityTestCase(TestCase):
def test_for_transient_class_without_id(self):
assert identity(self.Building()) == (None, )
class IdentityTestCase(object):
def test_for_transient_class_with_id(self):
building = self.Building(name=u'Some building')
self.session.add(building)
self.session.flush()
@pytest.fixture
def init_models(self, Building):
pass
def test_for_transient_class_without_id(self, Building):
assert identity(Building()) == (None, )
def test_for_transient_class_with_id(self, session, Building):
building = Building(name=u'Some building')
session.add(building)
session.flush()
assert identity(building) == (building.id, )
def test_identity_for_class(self):
assert identity(self.Building) == (self.Building.id, )
def test_identity_for_class(self, Building):
assert identity(Building) == (Building.id, )
class TestIdentity(IdentityTestCase):
def create_models(self):
class Building(self.Base):
@pytest.fixture
def Building(self, Base):
class Building(Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
self.Building = Building
return Building
class TestIdentityWithColumnAlias(IdentityTestCase):
def create_models(self):
class Building(self.Base):
@pytest.fixture
def Building(self, Base):
class Building(Base):
__tablename__ = 'building'
id = sa.Column('_id', sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
self.Building = Building
return Building

View File

@ -1,24 +1,24 @@
import pytest
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import is_loaded
@pytest.fixture
def Article(Base):
class Article(Base):
__tablename__ = 'article_translation'
id = sa.Column(sa.Integer, primary_key=True)
title = sa.orm.deferred(sa.Column(sa.String(100)))
return Article
class TestIsLoaded(object):
def setup_method(self, method):
Base = declarative_base()
class Article(Base):
__tablename__ = 'article_translation'
id = sa.Column(sa.Integer, primary_key=True)
title = sa.orm.deferred(sa.Column(sa.String(100)))
self.Article = Article
def test_loaded_property(self):
article = self.Article(id=1)
def test_loaded_property(self, Article):
article = Article(id=1)
assert is_loaded(article, 'id')
def test_unloaded_property(self):
article = self.Article(id=4)
def test_unloaded_property(self, Article):
article = Article(id=4)
assert not is_loaded(article, 'title')

View File

@ -2,11 +2,10 @@ import pytest
import sqlalchemy as sa
from sqlalchemy_utils import json_sql
from tests import TestCase
class TestJSONSQL(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.mark.usefixtures('postgresql_dsn')
class TestJSONSQL(object):
@pytest.mark.parametrize(
('value', 'result'),
@ -27,7 +26,7 @@ class TestJSONSQL(TestCase):
)
)
)
def test_compiled_scalars(self, value, result):
def test_compiled_scalars(self, connection, value, result):
assert result == (
self.connection.execute(sa.select([json_sql(value)])).fetchone()[0]
connection.execute(sa.select([json_sql(value)])).fetchone()[0]
)

View File

@ -1,90 +1,102 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.functions.sort_query import make_order_by_deterministic
from tests import assert_contains, TestCase
from .. import assert_contains
class TestMakeOrderByDeterministic(TestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode)
email = sa.Column(sa.Unicode, unique=True)
@pytest.fixture
def Article(Base):
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
author = sa.orm.relationship('User')
return Article
email_lower = sa.orm.column_property(
sa.func.lower(name)
)
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
author = sa.orm.relationship(User)
@pytest.fixture
def User(Base, Article):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode)
email = sa.Column(sa.Unicode, unique=True)
User.article_count = sa.orm.column_property(
sa.select([sa.func.count()], from_obj=Article)
.where(Article.author_id == User.id)
.label('article_count')
email_lower = sa.orm.column_property(
sa.func.lower(name)
)
self.User = User
self.Article = Article
User.article_count = sa.orm.column_property(
sa.select([sa.func.count()], from_obj=Article)
.where(Article.author_id == User.id)
.label('article_count')
)
return User
def test_column_property(self):
query = self.session.query(self.User).order_by(self.User.email_lower)
@pytest.fixture
def init_models(Article, User):
pass
class TestMakeOrderByDeterministic(object):
def test_column_property(self, session, User):
query = session.query(User).order_by(User.email_lower)
query = make_order_by_deterministic(query)
assert_contains('lower("user".name), "user".id ASC', query)
def test_unique_column(self):
query = self.session.query(self.User).order_by(self.User.email)
def test_unique_column(self, session, User):
query = session.query(User).order_by(User.email)
query = make_order_by_deterministic(query)
assert str(query).endswith('ORDER BY "user".email')
def test_non_unique_column(self):
query = self.session.query(self.User).order_by(self.User.name)
def test_non_unique_column(self, session, User):
query = session.query(User).order_by(User.name)
query = make_order_by_deterministic(query)
assert_contains('ORDER BY "user".name, "user".id ASC', query)
def test_descending_order_by(self):
query = self.session.query(self.User).order_by(
sa.desc(self.User.name)
def test_descending_order_by(self, session, User):
query = session.query(User).order_by(
sa.desc(User.name)
)
query = make_order_by_deterministic(query)
assert_contains('ORDER BY "user".name DESC, "user".id DESC', query)
def test_ascending_order_by(self):
query = self.session.query(self.User).order_by(
sa.asc(self.User.name)
def test_ascending_order_by(self, session, User):
query = session.query(User).order_by(
sa.asc(User.name)
)
query = make_order_by_deterministic(query)
assert_contains('ORDER BY "user".name ASC, "user".id ASC', query)
def test_string_order_by(self):
query = self.session.query(self.User).order_by('name')
def test_string_order_by(self, session, User):
query = session.query(User).order_by('name')
query = make_order_by_deterministic(query)
assert_contains('ORDER BY "user".name, "user".id ASC', query)
def test_annotated_label(self):
query = self.session.query(self.User).order_by(self.User.article_count)
def test_annotated_label(self, session, User):
query = session.query(User).order_by(User.article_count)
query = make_order_by_deterministic(query)
assert_contains('article_count, "user".id ASC', query)
def test_annotated_label_with_descending_order(self):
query = self.session.query(self.User).order_by(
sa.desc(self.User.article_count)
def test_annotated_label_with_descending_order(self, session, User):
query = session.query(User).order_by(
sa.desc(User.article_count)
)
query = make_order_by_deterministic(query)
assert_contains('ORDER BY article_count DESC, "user".id DESC', query)
def test_query_without_order_by(self):
query = self.session.query(self.User)
def test_query_without_order_by(self, session, User):
query = session.query(User)
query = make_order_by_deterministic(query)
assert 'ORDER BY "user".id' in str(query)
def test_alias(self):
alias = sa.orm.aliased(self.User.__table__)
query = self.session.query(alias).order_by(alias.c.name)
def test_alias(self, session, User):
alias = sa.orm.aliased(User.__table__)
query = session.query(alias).order_by(alias.c.name)
query = make_order_by_deterministic(query)
assert str(query).endswith('ORDER BY user_1.name, "user".id ASC')

View File

@ -1,20 +1,25 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils import merge_references
from tests import TestCase
class TestMergeReferences(TestCase):
def create_models(self):
class User(self.Base):
class TestMergeReferences(object):
@pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
def __repr__(self):
return 'User(%r)' % self.name
return User
class BlogPost(self.Base):
@pytest.fixture
def BlogPost(self, Base, User):
class BlogPost(Base):
__tablename__ = 'blog_post'
id = sa.Column(sa.Integer, primary_key=True)
title = sa.Column(sa.Unicode(255))
@ -22,35 +27,37 @@ class TestMergeReferences(TestCase):
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
author = sa.orm.relationship(User)
return BlogPost
self.User = User
self.BlogPost = BlogPost
@pytest.fixture
def init_models(self, User, BlogPost):
pass
def test_updates_foreign_keys(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
post = self.BlogPost(title=u'Some title', author=john)
post2 = self.BlogPost(title=u'Other title', author=jack)
self.session.add(john)
self.session.add(jack)
self.session.add(post)
self.session.add(post2)
self.session.commit()
def test_updates_foreign_keys(self, session, User, BlogPost):
john = User(name=u'John')
jack = User(name=u'Jack')
post = BlogPost(title=u'Some title', author=john)
post2 = BlogPost(title=u'Other title', author=jack)
session.add(john)
session.add(jack)
session.add(post)
session.add(post2)
session.commit()
merge_references(john, jack)
self.session.commit()
session.commit()
assert post.author == jack
assert post2.author == jack
def test_object_merging_whenever_possible(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
post = self.BlogPost(title=u'Some title', author=john)
post2 = self.BlogPost(title=u'Other title', author=jack)
self.session.add(john)
self.session.add(jack)
self.session.add(post)
self.session.add(post2)
self.session.commit()
def test_object_merging_whenever_possible(self, session, User, BlogPost):
john = User(name=u'John')
jack = User(name=u'Jack')
post = BlogPost(title=u'Some title', author=john)
post2 = BlogPost(title=u'Other title', author=jack)
session.add(john)
session.add(jack)
session.add(post)
session.add(post2)
session.commit()
# Load the author for post
assert post.author_id == john.id
merge_references(john, jack)
@ -58,18 +65,23 @@ class TestMergeReferences(TestCase):
assert post2.author_id == jack.id
class TestMergeReferencesWithManyToManyAssociations(TestCase):
def create_models(self):
class User(self.Base):
class TestMergeReferencesWithManyToManyAssociations(object):
@pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
def __repr__(self):
return 'User(%r)' % self.name
return User
@pytest.fixture
def Team(self, Base):
team_member = sa.Table(
'team_member', self.Base.metadata,
'team_member', Base.metadata,
sa.Column(
'user_id', sa.Integer,
sa.ForeignKey('user.id', ondelete='CASCADE'),
@ -82,46 +94,56 @@ class TestMergeReferencesWithManyToManyAssociations(TestCase):
)
)
class Team(self.Base):
class Team(Base):
__tablename__ = 'team'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
members = sa.orm.relationship(
User,
'User',
secondary=team_member,
backref='teams'
)
return Team
self.User = User
self.Team = Team
@pytest.fixture
def init_models(self, User, Team):
pass
def test_supports_associations(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
team = self.Team(name=u'Team')
def test_supports_associations(self, session, User, Team):
john = User(name=u'John')
jack = User(name=u'Jack')
team = Team(name=u'Team')
team.members.append(john)
self.session.add(john)
self.session.add(jack)
self.session.commit()
session.add(john)
session.add(jack)
session.commit()
merge_references(john, jack)
assert john not in team.members
assert jack in team.members
class TestMergeReferencesWithManyToManyAssociationObjects(TestCase):
def create_models(self):
class Team(self.Base):
class TestMergeReferencesWithManyToManyAssociationObjects(object):
@pytest.fixture
def Team(self, Base):
class Team(Base):
__tablename__ = 'team'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
return Team
class User(self.Base):
@pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
return User
class TeamMember(self.Base):
@pytest.fixture
def TeamMember(self, Base, User, Team):
class TeamMember(Base):
__tablename__ = 'team_member'
user_id = sa.Column(
sa.Integer,
@ -150,22 +172,23 @@ class TestMergeReferencesWithManyToManyAssociationObjects(TestCase):
),
primaryjoin=user_id == User.id,
)
return TeamMember
self.User = User
self.TeamMember = TeamMember
self.Team = Team
@pytest.fixture
def init_models(self, User, Team, TeamMember):
pass
def test_supports_associations(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
team = self.Team(name=u'Team')
team.members.append(self.TeamMember(user=john))
self.session.add(john)
self.session.add(jack)
self.session.add(team)
self.session.commit()
def test_supports_associations(self, session, User, Team, TeamMember):
john = User(name=u'John')
jack = User(name=u'Jack')
team = Team(name=u'Team')
team.members.append(TeamMember(user=john))
session.add(john)
session.add(jack)
session.add(team)
session.commit()
merge_references(john, jack)
self.session.commit()
session.commit()
users = [member.user for member in team.members]
assert john not in users
assert jack in users

View File

@ -1,14 +1,13 @@
from sqlalchemy_utils.functions import naturally_equivalent
from tests import TestCase
class TestNaturallyEquivalent(TestCase):
def test_returns_true_when_properties_match(self):
class TestNaturallyEquivalent(object):
def test_returns_true_when_properties_match(self, User):
assert naturally_equivalent(
self.User(name=u'someone'), self.User(name=u'someone')
User(name=u'someone'), User(name=u'someone')
)
def test_skips_primary_keys(self):
def test_skips_primary_keys(self, User):
assert naturally_equivalent(
self.User(id=1, name=u'someone'), self.User(id=2, name=u'someone')
User(id=1, name=u'someone'), User(id=2, name=u'someone')
)

View File

@ -1,24 +1,32 @@
from itertools import chain
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.functions import non_indexed_foreign_keys
from tests import TestCase
class TestFindNonIndexedForeignKeys(TestCase):
def create_models(self):
class User(self.Base):
class TestFindNonIndexedForeignKeys(object):
@pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
return User
class Category(self.Base):
@pytest.fixture
def Category(self, Base):
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
return Category
class Article(self.Base):
@pytest.fixture
def Article(self, Base, User, Category):
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@ -34,13 +42,14 @@ class TestFindNonIndexedForeignKeys(TestCase):
'articles',
)
)
return Article
self.User = User
self.Category = Category
self.Article = Article
@pytest.fixture
def init_models(self, User, Category, Article):
pass
def test_finds_all_non_indexed_fks(self):
fks = non_indexed_foreign_keys(self.Base.metadata, self.engine)
def test_finds_all_non_indexed_fks(self, session, Base, engine):
fks = non_indexed_foreign_keys(Base.metadata, engine)
assert (
'article' in
fks

View File

@ -1,18 +1,22 @@
from sqlalchemy.dialects import postgresql
from sqlalchemy_utils.functions import quote
from tests import TestCase
class TestQuote(TestCase):
def test_quote_with_preserved_keyword(self):
assert quote(self.connection, 'order') == '"order"'
assert quote(self.session, 'order') == '"order"'
assert quote(self.engine, 'order') == '"order"'
class TestQuote(object):
def test_quote_with_preserved_keyword(self, engine, connection, session):
assert quote(connection, 'order') == '"order"'
assert quote(session, 'order') == '"order"'
assert quote(engine, 'order') == '"order"'
assert quote(postgresql.dialect(), 'order') == '"order"'
def test_quote_with_non_preserved_keyword(self):
assert quote(self.connection, 'some_order') == 'some_order'
assert quote(self.session, 'some_order') == 'some_order'
assert quote(self.engine, 'some_order') == 'some_order'
def test_quote_with_non_preserved_keyword(
self,
engine,
connection,
session
):
assert quote(connection, 'some_order') == 'some_order'
assert quote(session, 'some_order') == 'some_order'
assert quote(engine, 'some_order') == 'some_order'
assert quote(postgresql.dialect(), 'some_order') == 'some_order'

View File

@ -1,3 +1,4 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.functions import (
@ -5,52 +6,58 @@ from sqlalchemy_utils.functions import (
render_expression,
render_statement
)
from tests import TestCase
class TestRender(TestCase):
def create_models(self):
class User(self.Base):
class TestRender(object):
@pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
return User
self.User = User
@pytest.fixture
def init_models(self, User):
pass
def test_render_orm_query(self):
query = self.session.query(self.User).filter_by(id=3)
def test_render_orm_query(self, session, User):
query = session.query(User).filter_by(id=3)
text = render_statement(query)
assert 'SELECT user.id, user.name' in text
assert 'FROM user' in text
assert 'WHERE user.id = 3' in text
def test_render_statement(self):
statement = self.User.__table__.select().where(self.User.id == 3)
text = render_statement(statement, bind=self.session.bind)
def test_render_statement(self, session, User):
statement = User.__table__.select().where(User.id == 3)
text = render_statement(statement, bind=session.bind)
assert 'SELECT user.id, user.name' in text
assert 'FROM user' in text
assert 'WHERE user.id = 3' in text
def test_render_statement_without_mapper(self):
def test_render_statement_without_mapper(self, session):
statement = sa.select([sa.text('1')])
text = render_statement(statement, bind=self.session.bind)
text = render_statement(statement, bind=session.bind)
assert 'SELECT 1' in text
def test_render_ddl(self):
expression = 'self.User.__table__.create(engine)'
stream = render_expression(expression, self.engine)
def test_render_ddl(self, engine, User):
expression = 'User.__table__.create(engine)'
stream = render_expression(expression, engine)
text = stream.getvalue()
assert 'CREATE TABLE user' in text
assert 'PRIMARY KEY' in text
def test_render_mock_ddl(self):
def test_render_mock_ddl(self, engine, User):
# TODO: mock_engine doesn't seem to work with locally scoped variables.
self.engine = engine
with mock_engine('self.engine') as stream:
self.User.__table__.create(self.engine)
User.__table__.create(self.engine)
text = stream.getvalue()

View File

@ -1,26 +1,33 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils import table_name
from tests import TestCase
class TestTableName(TestCase):
def create_models(self):
class Building(self.Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@pytest.fixture
def Building(Base):
class Building(Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
return Building
self.Building = Building
def test_class(self):
assert table_name(self.Building) == 'building'
del self.Building.__tablename__
assert table_name(self.Building) == 'building'
@pytest.fixture
def init_models(Base):
pass
def test_attribute(self):
assert table_name(self.Building.id) == 'building'
assert table_name(self.Building.name) == 'building'
def test_target(self):
assert table_name(self.Building()) == 'building'
class TestTableName(object):
def test_class(self, Building):
assert table_name(Building) == 'building'
del Building.__tablename__
assert table_name(Building) == 'building'
def test_attribute(self, Building):
assert table_name(Building.id) == 'building'
assert table_name(Building.name) == 'building'
def test_target(self, Building):
assert table_name(Building()) == 'building'

View File

@ -1,109 +1,105 @@
from __future__ import unicode_literals
import six
from tests import TestCase
class GenericRelationshipTestCase(TestCase):
def test_set_as_none(self):
event = self.Event()
class GenericRelationshipTestCase(object):
def test_set_as_none(self, Event):
event = Event()
event.object = None
assert event.object is None
def test_set_manual_and_get(self):
user = self.User()
def test_set_manual_and_get(self, session, User, Event):
user = User()
self.session.add(user)
self.session.commit()
session.add(user)
session.commit()
event = self.Event()
event = Event()
event.object_id = user.id
event.object_type = six.text_type(type(user).__name__)
assert event.object is None
self.session.add(event)
self.session.commit()
session.add(event)
session.commit()
assert event.object == user
def test_set_and_get(self):
user = self.User()
def test_set_and_get(self, session, User, Event):
user = User()
self.session.add(user)
self.session.commit()
session.add(user)
session.commit()
event = self.Event(object=user)
event = Event(object=user)
assert event.object_id == user.id
assert event.object_type == type(user).__name__
self.session.add(event)
self.session.commit()
session.add(event)
session.commit()
assert event.object == user
def test_compare_instance(self):
user1 = self.User()
user2 = self.User()
def test_compare_instance(self, session, User, Event):
user1 = User()
user2 = User()
self.session.add_all([user1, user2])
self.session.commit()
session.add_all([user1, user2])
session.commit()
event = self.Event(object=user1)
event = Event(object=user1)
self.session.add(event)
self.session.commit()
session.add(event)
session.commit()
assert event.object == user1
assert event.object != user2
def test_compare_query(self):
user1 = self.User()
user2 = self.User()
def test_compare_query(self, session, User, Event):
user1 = User()
user2 = User()
self.session.add_all([user1, user2])
self.session.commit()
session.add_all([user1, user2])
session.commit()
event = self.Event(object=user1)
event = Event(object=user1)
self.session.add(event)
self.session.commit()
session.add(event)
session.commit()
q = self.session.query(self.Event)
q = session.query(Event)
assert q.filter_by(object=user1).first() is not None
assert q.filter_by(object=user2).first() is None
assert q.filter(self.Event.object == user2).first() is None
assert q.filter(Event.object == user2).first() is None
def test_compare_not_query(self):
user1 = self.User()
user2 = self.User()
def test_compare_not_query(self, session, User, Event):
user1 = User()
user2 = User()
self.session.add_all([user1, user2])
self.session.commit()
session.add_all([user1, user2])
session.commit()
event = self.Event(object=user1)
event = Event(object=user1)
self.session.add(event)
self.session.commit()
session.add(event)
session.commit()
q = self.session.query(self.Event)
assert q.filter(self.Event.object != user2).first() is not None
q = session.query(Event)
assert q.filter(Event.object != user2).first() is not None
def test_compare_type(self):
user1 = self.User()
user2 = self.User()
def test_compare_type(self, session, User, Event):
user1 = User()
user2 = User()
self.session.add_all([user1, user2])
self.session.commit()
session.add_all([user1, user2])
session.commit()
event1 = self.Event(object=user1)
event2 = self.Event(object=user2)
event1 = Event(object=user1)
event2 = Event(object=user2)
self.session.add_all([event1, event2])
self.session.commit()
session.add_all([event1, event2])
session.commit()
statement = self.Event.object.is_type(self.User)
q = self.session.query(self.Event).filter(statement)
statement = Event.object.is_type(User)
q = session.query(Event).filter(statement)
assert q.first() is not None

View File

@ -1,36 +1,54 @@
from __future__ import unicode_literals
import pytest
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy_utils import generic_relationship
from tests.generic_relationship import GenericRelationshipTestCase
from . import GenericRelationshipTestCase
@pytest.fixture
def Building(Base):
class Building(Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
return Building
@pytest.fixture
def User(Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
return User
@pytest.fixture
def EventBase(Base):
class EventBase(Base):
__abstract__ = True
object_type = sa.Column(sa.Unicode(255))
object_id = sa.Column(sa.Integer, nullable=False)
@declared_attr
def object(cls):
return generic_relationship('object_type', 'object_id')
return EventBase
@pytest.fixture
def Event(EventBase):
class Event(EventBase):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
return Event
@pytest.fixture
def init_models(Building, User, Event):
pass
class TestGenericRelationshipWithAbstractBase(GenericRelationshipTestCase):
def create_models(self):
class Building(self.Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
class EventBase(self.Base):
__abstract__ = True
object_type = sa.Column(sa.Unicode(255))
object_id = sa.Column(sa.Integer, nullable=False)
@declared_attr
def object(cls):
return generic_relationship('object_type', 'object_id')
class Event(EventBase):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
self.Building = Building
self.User = User
self.Event = Event
pass

View File

@ -1,30 +1,44 @@
from __future__ import unicode_literals
import pytest
import sqlalchemy as sa
from sqlalchemy_utils import generic_relationship
from tests.generic_relationship import GenericRelationshipTestCase
from . import GenericRelationshipTestCase
@pytest.fixture
def Building(Base):
class Building(Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
return Building
@pytest.fixture
def User(Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
return User
@pytest.fixture
def Event(Base):
class Event(Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
object_type = sa.Column(sa.Unicode(255), name="objectType")
object_id = sa.Column(sa.Integer, nullable=False)
object = generic_relationship(object_type, object_id)
return Event
@pytest.fixture
def init_models(Building, User, Event):
pass
class TestGenericRelationship(GenericRelationshipTestCase):
def create_models(self):
class Building(self.Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
class Event(self.Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
object_type = sa.Column(sa.Unicode(255), name="objectType")
object_id = sa.Column(sa.Integer, nullable=False)
object = generic_relationship(object_type, object_id)
self.Building = Building
self.User = User
self.Event = Event
pass

View File

@ -1,66 +1,84 @@
from __future__ import unicode_literals
import pytest
import six
import sqlalchemy as sa
from sqlalchemy_utils import generic_relationship
from tests.generic_relationship import GenericRelationshipTestCase
from ..generic_relationship import GenericRelationshipTestCase
@pytest.fixture
def incrementor():
class Incrementor(object):
value = 1
return Incrementor()
@pytest.fixture
def Building(Base, incrementor):
class Building(Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
code = sa.Column(sa.Integer, primary_key=True)
def __init__(self):
incrementor.value += 1
self.id = incrementor.value
self.code = incrementor.value
return Building
@pytest.fixture
def User(Base, incrementor):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
code = sa.Column(sa.Integer, primary_key=True)
def __init__(self):
incrementor.value += 1
self.id = incrementor.value
self.code = incrementor.value
return User
@pytest.fixture
def Event(Base):
class Event(Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
object_type = sa.Column(sa.Unicode(255))
object_id = sa.Column(sa.Integer, nullable=False)
object_code = sa.Column(sa.Integer, nullable=False)
object = generic_relationship(
object_type, (object_id, object_code)
)
return Event
@pytest.fixture
def init_models(Building, User, Event):
pass
class TestGenericRelationship(GenericRelationshipTestCase):
index = 1
def create_models(self):
class Building(self.Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
code = sa.Column(sa.Integer, primary_key=True)
def test_set_manual_and_get(self, session, Event, User):
user = User()
def __init__(obj_self):
self.index += 1
obj_self.id = self.index
obj_self.code = self.index
session.add(user)
session.commit()
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
code = sa.Column(sa.Integer, primary_key=True)
def __init__(obj_self):
self.index += 1
obj_self.id = self.index
obj_self.code = self.index
class Event(self.Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
object_type = sa.Column(sa.Unicode(255))
object_id = sa.Column(sa.Integer, nullable=False)
object_code = sa.Column(sa.Integer, nullable=False)
object = generic_relationship(
object_type, (object_id, object_code)
)
self.Building = Building
self.User = User
self.Event = Event
def test_set_manual_and_get(self):
user = self.User()
self.session.add(user)
self.session.commit()
event = self.Event()
event = Event()
event.object_id = user.id
event.object_type = six.text_type(type(user).__name__)
event.object_code = user.code
assert event.object is None
self.session.add(event)
self.session.commit()
session.add(event)
session.commit()
assert event.object == user

View File

@ -1,68 +1,79 @@
from __future__ import unicode_literals
import pytest
import six
import sqlalchemy as sa
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy_utils import generic_relationship
from tests import TestCase
class TestGenericRelationship(TestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
@pytest.fixture
def User(Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
return User
class UserHistory(self.Base):
__tablename__ = 'user_history'
id = sa.Column(sa.Integer, primary_key=True)
transaction_id = sa.Column(sa.Integer, primary_key=True)
@pytest.fixture
def UserHistory(Base):
class UserHistory(Base):
__tablename__ = 'user_history'
id = sa.Column(sa.Integer, primary_key=True)
class Event(self.Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
transaction_id = sa.Column(sa.Integer, primary_key=True)
return UserHistory
transaction_id = sa.Column(sa.Integer)
object_type = sa.Column(sa.Unicode(255))
object_id = sa.Column(sa.Integer, nullable=False)
@pytest.fixture
def Event(Base):
class Event(Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
object = generic_relationship(
object_type, object_id
)
transaction_id = sa.Column(sa.Integer)
@hybrid_property
def object_version_type(self):
return self.object_type + 'History'
object_type = sa.Column(sa.Unicode(255))
object_id = sa.Column(sa.Integer, nullable=False)
@object_version_type.expression
def object_version_type(cls):
return sa.func.concat(cls.object_type, 'History')
object = generic_relationship(
object_type, object_id
)
object_version = generic_relationship(
object_version_type, (object_id, transaction_id)
)
@hybrid_property
def object_version_type(self):
return self.object_type + 'History'
self.User = User
self.UserHistory = UserHistory
self.Event = Event
@object_version_type.expression
def object_version_type(cls):
return sa.func.concat(cls.object_type, 'History')
def test_set_manual_and_get(self):
user = self.User(id=1)
history = self.UserHistory(id=1, transaction_id=1)
self.session.add(user)
self.session.add(history)
self.session.commit()
object_version = generic_relationship(
object_version_type, (object_id, transaction_id)
)
return Event
event = self.Event(transaction_id=1)
@pytest.fixture
def init_models(User, UserHistory, Event):
pass
class TestGenericRelationship(object):
def test_set_manual_and_get(self, session, User, UserHistory, Event):
user = User(id=1)
history = UserHistory(id=1, transaction_id=1)
session.add(user)
session.add(history)
session.commit()
event = Event(transaction_id=1)
event.object_id = user.id
event.object_type = six.text_type(type(user).__name__)
assert event.object is None
self.session.add(event)
self.session.commit()
session.add(event)
session.commit()
assert event.object == user
assert event.object_version == history

View File

@ -1,164 +1,178 @@
from __future__ import unicode_literals
import pytest
import six
import sqlalchemy as sa
from sqlalchemy_utils import generic_relationship
from tests import TestCase
class TestGenericRelationship(TestCase):
def create_models(self):
class Employee(self.Base):
__tablename__ = 'employee'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(50))
type = sa.Column(sa.String(20))
@pytest.fixture
def Employee(Base):
class Employee(Base):
__tablename__ = 'employee'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(50))
type = sa.Column(sa.String(20))
__mapper_args__ = {
'polymorphic_on': type,
'polymorphic_identity': 'employee'
}
__mapper_args__ = {
'polymorphic_on': type,
'polymorphic_identity': 'employee'
}
return Employee
class Manager(Employee):
__mapper_args__ = {
'polymorphic_identity': 'manager'
}
class Engineer(Employee):
__mapper_args__ = {
'polymorphic_identity': 'engineer'
}
@pytest.fixture
def Manager(Employee):
class Manager(Employee):
__mapper_args__ = {
'polymorphic_identity': 'manager'
}
return Manager
class Event(self.Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
object_type = sa.Column(sa.Unicode(255))
object_id = sa.Column(sa.Integer, nullable=False)
@pytest.fixture
def Engineer(Employee):
class Engineer(Employee):
__mapper_args__ = {
'polymorphic_identity': 'engineer'
}
return Engineer
object = generic_relationship(object_type, object_id)
self.Employee = Employee
self.Manager = Manager
self.Engineer = Engineer
self.Event = Event
@pytest.fixture
def Event(Base):
class Event(Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
def test_set_as_none(self):
event = self.Event()
object_type = sa.Column(sa.Unicode(255))
object_id = sa.Column(sa.Integer, nullable=False)
object = generic_relationship(object_type, object_id)
return Event
@pytest.fixture
def init_models(Employee, Manager, Engineer, Event):
pass
class TestGenericRelationship(object):
def test_set_as_none(self, Event):
event = Event()
event.object = None
assert event.object is None
def test_set_manual_and_get(self):
manager = self.Manager()
def test_set_manual_and_get(self, session, Manager, Event):
manager = Manager()
self.session.add(manager)
self.session.commit()
session.add(manager)
session.commit()
event = self.Event()
event = Event()
event.object_id = manager.id
event.object_type = six.text_type(type(manager).__name__)
assert event.object is None
self.session.add(event)
self.session.commit()
session.add(event)
session.commit()
assert event.object == manager
def test_set_and_get(self):
manager = self.Manager()
def test_set_and_get(self, session, Manager, Event):
manager = Manager()
self.session.add(manager)
self.session.commit()
session.add(manager)
session.commit()
event = self.Event(object=manager)
event = Event(object=manager)
assert event.object_id == manager.id
assert event.object_type == type(manager).__name__
self.session.add(event)
self.session.commit()
session.add(event)
session.commit()
assert event.object == manager
def test_compare_instance(self):
manager1 = self.Manager()
manager2 = self.Manager()
def test_compare_instance(self, session, Manager, Event):
manager1 = Manager()
manager2 = Manager()
self.session.add_all([manager1, manager2])
self.session.commit()
session.add_all([manager1, manager2])
session.commit()
event = self.Event(object=manager1)
event = Event(object=manager1)
self.session.add(event)
self.session.commit()
session.add(event)
session.commit()
assert event.object == manager1
assert event.object != manager2
def test_compare_query(self):
manager1 = self.Manager()
manager2 = self.Manager()
def test_compare_query(self, session, Manager, Event):
manager1 = Manager()
manager2 = Manager()
self.session.add_all([manager1, manager2])
self.session.commit()
session.add_all([manager1, manager2])
session.commit()
event = self.Event(object=manager1)
event = Event(object=manager1)
self.session.add(event)
self.session.commit()
session.add(event)
session.commit()
q = self.session.query(self.Event)
q = session.query(Event)
assert q.filter_by(object=manager1).first() is not None
assert q.filter_by(object=manager2).first() is None
assert q.filter(self.Event.object == manager2).first() is None
assert q.filter(Event.object == manager2).first() is None
def test_compare_not_query(self):
manager1 = self.Manager()
manager2 = self.Manager()
def test_compare_not_query(self, session, Manager, Event):
manager1 = Manager()
manager2 = Manager()
self.session.add_all([manager1, manager2])
self.session.commit()
session.add_all([manager1, manager2])
session.commit()
event = self.Event(object=manager1)
event = Event(object=manager1)
self.session.add(event)
self.session.commit()
session.add(event)
session.commit()
q = self.session.query(self.Event)
assert q.filter(self.Event.object != manager2).first() is not None
q = session.query(Event)
assert q.filter(Event.object != manager2).first() is not None
def test_compare_type(self):
manager1 = self.Manager()
manager2 = self.Manager()
def test_compare_type(self, session, Manager, Event):
manager1 = Manager()
manager2 = Manager()
self.session.add_all([manager1, manager2])
self.session.commit()
session.add_all([manager1, manager2])
session.commit()
event1 = self.Event(object=manager1)
event2 = self.Event(object=manager2)
event1 = Event(object=manager1)
event2 = Event(object=manager2)
self.session.add_all([event1, event2])
self.session.commit()
session.add_all([event1, event2])
session.commit()
statement = self.Event.object.is_type(self.Manager)
q = self.session.query(self.Event).filter(statement)
statement = Event.object.is_type(Manager)
q = session.query(Event).filter(statement)
assert q.first() is not None
def test_compare_super_type(self):
manager1 = self.Manager()
manager2 = self.Manager()
def test_compare_super_type(self, session, Manager, Event, Employee):
manager1 = Manager()
manager2 = Manager()
self.session.add_all([manager1, manager2])
self.session.commit()
session.add_all([manager1, manager2])
session.commit()
event1 = self.Event(object=manager1)
event2 = self.Event(object=manager2)
event1 = Event(object=manager1)
event2 = Event(object=manager2)
self.session.add_all([event1, event2])
self.session.commit()
session.add_all([event1, event2])
session.commit()
statement = self.Event.object.is_type(self.Employee)
q = self.session.query(self.Event).filter(statement)
statement = Event.object.is_type(Employee)
q = session.query(Event).filter(statement)
assert q.first() is not None

View File

@ -1,18 +1,24 @@
import pytest
import sqlalchemy as sa
class ThreeLevelDeepOneToOne(object):
def create_models(self):
class Catalog(self.Base):
@pytest.fixture
def Catalog(self, Base, Category):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column('_id', sa.Integer, primary_key=True)
category = sa.orm.relationship(
'Category',
Category,
uselist=False,
backref='catalog'
)
return Catalog
class Category(self.Base):
@pytest.fixture
def Category(self, Base, SubCategory):
class Category(Base):
__tablename__ = 'category'
id = sa.Column('_id', sa.Integer, primary_key=True)
catalog_id = sa.Column(
@ -22,12 +28,15 @@ class ThreeLevelDeepOneToOne(object):
)
sub_category = sa.orm.relationship(
'SubCategory',
SubCategory,
uselist=False,
backref='category'
)
return Category
class SubCategory(self.Base):
@pytest.fixture
def SubCategory(self, Base, Product):
class SubCategory(Base):
__tablename__ = 'sub_category'
id = sa.Column('_id', sa.Integer, primary_key=True)
category_id = sa.Column(
@ -36,12 +45,15 @@ class ThreeLevelDeepOneToOne(object):
sa.ForeignKey('category._id')
)
product = sa.orm.relationship(
'Product',
Product,
uselist=False,
backref='sub_category'
)
return SubCategory
class Product(self.Base):
@pytest.fixture
def Product(self, Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column('_id', sa.Integer, primary_key=True)
price = sa.Column(sa.Integer)
@ -51,22 +63,27 @@ class ThreeLevelDeepOneToOne(object):
sa.Integer,
sa.ForeignKey('sub_category._id')
)
return Product
self.Catalog = Catalog
self.Category = Category
self.SubCategory = SubCategory
self.Product = Product
@pytest.fixture
def init_models(self, Catalog, Category, SubCategory, Product):
pass
class ThreeLevelDeepOneToMany(object):
def create_models(self):
class Catalog(self.Base):
@pytest.fixture
def Catalog(self, Base, Category):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column('_id', sa.Integer, primary_key=True)
categories = sa.orm.relationship('Category', backref='catalog')
categories = sa.orm.relationship(Category, backref='catalog')
return Catalog
class Category(self.Base):
@pytest.fixture
def Category(self, Base, SubCategory):
class Category(Base):
__tablename__ = 'category'
id = sa.Column('_id', sa.Integer, primary_key=True)
catalog_id = sa.Column(
@ -76,10 +93,13 @@ class ThreeLevelDeepOneToMany(object):
)
sub_categories = sa.orm.relationship(
'SubCategory', backref='category'
SubCategory, backref='category'
)
return Category
class SubCategory(self.Base):
@pytest.fixture
def SubCategory(self, Base, Product):
class SubCategory(Base):
__tablename__ = 'sub_category'
id = sa.Column('_id', sa.Integer, primary_key=True)
category_id = sa.Column(
@ -88,11 +108,14 @@ class ThreeLevelDeepOneToMany(object):
sa.ForeignKey('category._id')
)
products = sa.orm.relationship(
'Product',
Product,
backref='sub_category'
)
return SubCategory
class Product(self.Base):
@pytest.fixture
def Product(self, Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column('_id', sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric)
@ -105,25 +128,42 @@ class ThreeLevelDeepOneToMany(object):
def __repr__(self):
return '<Product id=%r>' % self.id
return Product
self.Catalog = Catalog
self.Category = Category
self.SubCategory = SubCategory
self.Product = Product
@pytest.fixture
def init_models(self, Catalog, Category, SubCategory, Product):
pass
class ThreeLevelDeepManyToMany(object):
def create_models(self):
@pytest.fixture
def Catalog(self, Base, Category):
catalog_category = sa.Table(
'catalog_category',
self.Base.metadata,
Base.metadata,
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog._id')),
sa.Column('category_id', sa.Integer, sa.ForeignKey('category._id'))
)
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column('_id', sa.Integer, primary_key=True)
categories = sa.orm.relationship(
Category,
backref='catalogs',
secondary=catalog_category
)
return Catalog
@pytest.fixture
def Category(self, Base, SubCategory):
category_subcategory = sa.Table(
'category_subcategory',
self.Base.metadata,
Base.metadata,
sa.Column(
'category_id',
sa.Integer,
@ -136,9 +176,23 @@ class ThreeLevelDeepManyToMany(object):
)
)
class Category(Base):
__tablename__ = 'category'
id = sa.Column('_id', sa.Integer, primary_key=True)
sub_categories = sa.orm.relationship(
SubCategory,
backref='categories',
secondary=category_subcategory
)
return Category
@pytest.fixture
def SubCategory(self, Base, Product):
subcategory_product = sa.Table(
'subcategory_product',
self.Base.metadata,
Base.metadata,
sa.Column(
'subcategory_id',
sa.Integer,
@ -151,41 +205,24 @@ class ThreeLevelDeepManyToMany(object):
)
)
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column('_id', sa.Integer, primary_key=True)
categories = sa.orm.relationship(
'Category',
backref='catalogs',
secondary=catalog_category
)
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column('_id', sa.Integer, primary_key=True)
sub_categories = sa.orm.relationship(
'SubCategory',
backref='categories',
secondary=category_subcategory
)
class SubCategory(self.Base):
class SubCategory(Base):
__tablename__ = 'sub_category'
id = sa.Column('_id', sa.Integer, primary_key=True)
products = sa.orm.relationship(
'Product',
Product,
backref='sub_categories',
secondary=subcategory_product
)
return SubCategory
class Product(self.Base):
@pytest.fixture
def Product(self, Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column('_id', sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric)
return Product
self.Catalog = Catalog
self.Category = Category
self.SubCategory = SubCategory
self.Product = Product
@pytest.fixture
def init_models(self, Catalog, Category, SubCategory, Product):
pass

View File

@ -1,15 +1,15 @@
import pytest
import sqlalchemy as sa
from pytest import raises
from sqlalchemy_utils.observer import observes
from tests import TestCase
class TestObservesForColumn(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesForColumn(object):
def create_models(self):
class Product(self.Base):
@pytest.fixture
def Product(self, Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Integer)
@ -17,21 +17,25 @@ class TestObservesForColumn(TestCase):
@observes('price')
def product_price_observer(self, price):
self.price = price * 2
return Product
self.Product = Product
@pytest.fixture
def init_models(self, Product):
pass
def test_simple_insert(self):
product = self.Product(price=100)
self.session.add(product)
self.session.flush()
def test_simple_insert(self, session, Product):
product = Product(price=100)
session.add(product)
session.flush()
assert product.price == 200
class TestObservesForColumnWithoutActualChanges(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesForColumnWithoutActualChanges(object):
def create_models(self):
class Product(self.Base):
@pytest.fixture
def Product(self, Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Integer)
@ -39,15 +43,18 @@ class TestObservesForColumnWithoutActualChanges(TestCase):
@observes('price')
def product_price_observer(self, price):
raise Exception('Trying to change price')
return Product
self.Product = Product
@pytest.fixture
def init_models(self, Product):
pass
def test_only_notifies_observer_on_actual_changes(self):
product = self.Product()
self.session.add(product)
self.session.flush()
def test_only_notifies_observer_on_actual_changes(self, session, Product):
product = Product()
session.add(product)
session.flush()
with raises(Exception) as e:
with pytest.raises(Exception) as e:
product.price = 500
self.session.commit()
session.commit()
assert str(e.value) == 'Trying to change price'

View File

@ -1,137 +1,158 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.observer import observes
from tests import TestCase
class TestObservesForManyToManyToManyToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.fixture
def Catalog(Base):
catalog_category = sa.Table(
'catalog_category',
Base.metadata,
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id'))
)
def create_models(self):
catalog_category = sa.Table(
'catalog_category',
self.Base.metadata,
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id'))
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
product_count = sa.Column(sa.Integer, default=0)
@observes('categories.sub_categories.products')
def product_observer(self, products):
self.product_count = len(products)
categories = sa.orm.relationship(
'Category',
backref='catalogs',
secondary=catalog_category
)
return Catalog
@pytest.fixture
def Category(Base):
category_subcategory = sa.Table(
'category_subcategory',
Base.metadata,
sa.Column(
'category_id',
sa.Integer,
sa.ForeignKey('category.id')
),
sa.Column(
'subcategory_id',
sa.Integer,
sa.ForeignKey('sub_category.id')
)
)
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
sub_categories = sa.orm.relationship(
'SubCategory',
backref='categories',
secondary=category_subcategory
)
return Category
@pytest.fixture
def SubCategory(Base):
subcategory_product = sa.Table(
'subcategory_product',
Base.metadata,
sa.Column(
'subcategory_id',
sa.Integer,
sa.ForeignKey('sub_category.id')
),
sa.Column(
'product_id',
sa.Integer,
sa.ForeignKey('product.id')
)
)
class SubCategory(Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
products = sa.orm.relationship(
'Product',
backref='sub_categories',
secondary=subcategory_product
)
category_subcategory = sa.Table(
'category_subcategory',
self.Base.metadata,
sa.Column(
'category_id',
sa.Integer,
sa.ForeignKey('category.id')
),
sa.Column(
'subcategory_id',
sa.Integer,
sa.ForeignKey('sub_category.id')
)
)
return SubCategory
subcategory_product = sa.Table(
'subcategory_product',
self.Base.metadata,
sa.Column(
'subcategory_id',
sa.Integer,
sa.ForeignKey('sub_category.id')
),
sa.Column(
'product_id',
sa.Integer,
sa.ForeignKey('product.id')
)
)
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
product_count = sa.Column(sa.Integer, default=0)
@pytest.fixture
def Product(Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric)
return Product
@observes('categories.sub_categories.products')
def product_observer(self, products):
self.product_count = len(products)
categories = sa.orm.relationship(
'Category',
backref='catalogs',
secondary=catalog_category
)
@pytest.fixture
def init_models(Catalog, Category, SubCategory, Product):
pass
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
sub_categories = sa.orm.relationship(
'SubCategory',
backref='categories',
secondary=category_subcategory
)
@pytest.fixture
def catalog(session, Catalog, Category, SubCategory, Product):
sub_category = SubCategory(products=[Product()])
category = Category(sub_categories=[sub_category])
catalog = Catalog(categories=[category])
session.add(catalog)
session.flush()
return catalog
class SubCategory(self.Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
products = sa.orm.relationship(
'Product',
backref='sub_categories',
secondary=subcategory_product
)
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric)
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesForManyToManyToManyToMany(object):
self.Catalog = Catalog
self.Category = Category
self.SubCategory = SubCategory
self.Product = Product
def create_catalog(self):
sub_category = self.SubCategory(products=[self.Product()])
category = self.Category(sub_categories=[sub_category])
catalog = self.Catalog(categories=[category])
self.session.add(catalog)
self.session.flush()
return catalog
def test_simple_insert(self):
catalog = self.create_catalog()
def test_simple_insert(self, catalog):
assert catalog.product_count == 1
def test_add_leaf_object(self):
catalog = self.create_catalog()
product = self.Product()
def test_add_leaf_object(self, catalog, session, Product):
product = Product()
catalog.categories[0].sub_categories[0].products.append(product)
self.session.flush()
session.flush()
assert catalog.product_count == 2
def test_remove_leaf_object(self):
catalog = self.create_catalog()
product = self.Product()
def test_remove_leaf_object(self, catalog, session, Product):
product = Product()
catalog.categories[0].sub_categories[0].products.append(product)
self.session.flush()
self.session.delete(product)
self.session.flush()
session.flush()
session.delete(product)
session.flush()
assert catalog.product_count == 1
def test_delete_intermediate_object(self):
catalog = self.create_catalog()
self.session.delete(catalog.categories[0].sub_categories[0])
self.session.commit()
def test_delete_intermediate_object(self, catalog, session):
session.delete(catalog.categories[0].sub_categories[0])
session.commit()
assert catalog.product_count == 0
def test_gathered_objects_are_distinct(self):
catalog = self.Catalog()
category = self.Category(catalogs=[catalog])
product = self.Product()
def test_gathered_objects_are_distinct(
self,
session,
Catalog,
Category,
SubCategory,
Product
):
catalog = Catalog()
category = Category(catalogs=[catalog])
product = Product()
category.sub_categories.append(
self.SubCategory(products=[product])
SubCategory(products=[product])
)
self.session.add(
self.SubCategory(categories=[category], products=[product])
session.add(
SubCategory(categories=[category], products=[product])
)
self.session.commit()
session.commit()
assert catalog.product_count == 1

View File

@ -1,107 +1,127 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.observer import observes
from tests import TestCase
class TestObservesFor3LevelDeepOneToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.fixture
def Catalog(Base):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
product_count = sa.Column(sa.Integer, default=0)
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
product_count = sa.Column(sa.Integer, default=0)
@observes('categories.sub_categories.products')
def product_observer(self, products):
self.product_count = len(products)
@observes('categories.sub_categories.products')
def product_observer(self, products):
self.product_count = len(products)
categories = sa.orm.relationship('Category', backref='catalog')
return Catalog
categories = sa.orm.relationship('Category', backref='catalog')
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
@pytest.fixture
def Category(Base):
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
sub_categories = sa.orm.relationship(
'SubCategory', backref='category'
)
sub_categories = sa.orm.relationship(
'SubCategory', backref='category'
)
return Category
class SubCategory(self.Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
products = sa.orm.relationship(
'Product',
backref='sub_category'
)
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric)
@pytest.fixture
def SubCategory(Base):
class SubCategory(Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
products = sa.orm.relationship(
'Product',
backref='sub_category'
)
return SubCategory
sub_category_id = sa.Column(
sa.Integer, sa.ForeignKey('sub_category.id')
)
def __repr__(self):
return '<Product id=%r>' % self.id
@pytest.fixture
def Product(Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric)
self.Catalog = Catalog
self.Category = Category
self.SubCategory = SubCategory
self.Product = Product
sub_category_id = sa.Column(
sa.Integer, sa.ForeignKey('sub_category.id')
)
def create_catalog(self):
sub_category = self.SubCategory(products=[self.Product()])
category = self.Category(sub_categories=[sub_category])
catalog = self.Catalog(categories=[category])
self.session.add(catalog)
self.session.commit()
return catalog
def __repr__(self):
return '<Product id=%r>' % self.id
return Product
def test_simple_insert(self):
catalog = self.create_catalog()
@pytest.fixture
def init_models(Catalog, Category, SubCategory, Product):
pass
@pytest.fixture
def catalog(session, Catalog, Category, SubCategory, Product):
sub_category = SubCategory(products=[Product()])
category = Category(sub_categories=[sub_category])
catalog = Catalog(categories=[category])
session.add(catalog)
session.commit()
return catalog
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesFor3LevelDeepOneToMany(object):
def test_simple_insert(self, catalog):
assert catalog.product_count == 1
def test_add_leaf_object(self):
catalog = self.create_catalog()
product = self.Product()
def test_add_leaf_object(self, catalog, session, Product):
product = Product()
catalog.categories[0].sub_categories[0].products.append(product)
self.session.flush()
session.flush()
assert catalog.product_count == 2
def test_remove_leaf_object(self):
catalog = self.create_catalog()
product = self.Product()
def test_remove_leaf_object(self, catalog, session, Product):
product = Product()
catalog.categories[0].sub_categories[0].products.append(product)
self.session.flush()
self.session.delete(product)
self.session.commit()
session.flush()
session.delete(product)
session.commit()
assert catalog.product_count == 1
self.session.delete(
session.delete(
catalog.categories[0].sub_categories[0].products[0]
)
self.session.commit()
session.commit()
assert catalog.product_count == 0
def test_delete_intermediate_object(self):
catalog = self.create_catalog()
self.session.delete(catalog.categories[0].sub_categories[0])
self.session.commit()
def test_delete_intermediate_object(self, catalog, session):
session.delete(catalog.categories[0].sub_categories[0])
session.commit()
assert catalog.product_count == 0
def test_gathered_objects_are_distinct(self):
catalog = self.Catalog()
category = self.Category(catalog=catalog)
product = self.Product()
def test_gathered_objects_are_distinct(
self,
session,
Catalog,
Category,
SubCategory,
Product
):
catalog = Catalog()
category = Category(catalog=catalog)
product = Product()
category.sub_categories.append(
self.SubCategory(products=[product])
SubCategory(products=[product])
)
self.session.add(
self.SubCategory(category=category, products=[product])
session.add(
SubCategory(category=category, products=[product])
)
self.session.commit()
session.commit()
assert catalog.product_count == 1

View File

@ -1,96 +1,116 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.observer import observes
from tests import TestCase
class TestObservesForOneToManyToOneToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.fixture
def Catalog(Base):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
product_count = sa.Column(sa.Integer, default=0)
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
product_count = sa.Column(sa.Integer, default=0)
@observes('categories.sub_category.products')
def product_observer(self, products):
self.product_count = len(products)
@observes('categories.sub_category.products')
def product_observer(self, products):
self.product_count = len(products)
categories = sa.orm.relationship('Category', backref='catalog')
return Catalog
categories = sa.orm.relationship('Category', backref='catalog')
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
@pytest.fixture
def Category(Base):
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
sub_category = sa.orm.relationship(
'SubCategory',
uselist=False,
backref='category'
)
sub_category = sa.orm.relationship(
'SubCategory',
uselist=False,
backref='category'
)
return Category
class SubCategory(self.Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
products = sa.orm.relationship('Product', backref='sub_category')
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric)
@pytest.fixture
def SubCategory(Base):
class SubCategory(Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
products = sa.orm.relationship('Product', backref='sub_category')
return SubCategory
sub_category_id = sa.Column(
sa.Integer, sa.ForeignKey('sub_category.id')
)
self.Catalog = Catalog
self.Category = Category
self.SubCategory = SubCategory
self.Product = Product
@pytest.fixture
def Product(Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric)
def create_catalog(self):
sub_category = self.SubCategory(products=[self.Product()])
category = self.Category(sub_category=sub_category)
catalog = self.Catalog(categories=[category])
self.session.add(catalog)
self.session.flush()
return catalog
sub_category_id = sa.Column(
sa.Integer, sa.ForeignKey('sub_category.id')
)
return Product
def test_simple_insert(self):
catalog = self.create_catalog()
@pytest.fixture
def init_models(Catalog, Category, SubCategory, Product):
pass
@pytest.fixture
def catalog(session, Catalog, Category, SubCategory, Product):
sub_category = SubCategory(products=[Product()])
category = Category(sub_category=sub_category)
catalog = Catalog(categories=[category])
session.add(catalog)
session.flush()
return catalog
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesForOneToManyToOneToMany(object):
def test_simple_insert(self, catalog):
assert catalog.product_count == 1
def test_add_leaf_object(self):
catalog = self.create_catalog()
product = self.Product()
def test_add_leaf_object(self, catalog, session, Product):
product = Product()
catalog.categories[0].sub_category.products.append(product)
self.session.flush()
session.flush()
assert catalog.product_count == 2
def test_remove_leaf_object(self):
catalog = self.create_catalog()
product = self.Product()
def test_remove_leaf_object(self, catalog, session, Product):
product = Product()
catalog.categories[0].sub_category.products.append(product)
self.session.flush()
self.session.delete(product)
self.session.flush()
session.flush()
session.delete(product)
session.flush()
assert catalog.product_count == 1
def test_delete_intermediate_object(self):
catalog = self.create_catalog()
self.session.delete(catalog.categories[0].sub_category)
self.session.commit()
def test_delete_intermediate_object(self, catalog, session):
session.delete(catalog.categories[0].sub_category)
session.commit()
assert catalog.product_count == 0
def test_gathered_objects_are_distinct(self):
catalog = self.Catalog()
category = self.Category(catalog=catalog)
product = self.Product()
category.sub_category = self.SubCategory(products=[product])
self.session.add(
self.Category(catalog=catalog, sub_category=category.sub_category)
def test_gathered_objects_are_distinct(
self,
session,
Catalog,
Category,
SubCategory,
Product
):
catalog = Catalog()
category = Category(catalog=catalog)
product = Product()
category.sub_category = SubCategory(products=[product])
session.add(
Category(catalog=catalog, sub_category=category.sub_category)
)
self.session.commit()
session.commit()
assert catalog.product_count == 1

View File

@ -1,53 +1,66 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.observer import observes
from tests import TestCase
class TestObservesForOneToManyToOneToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.fixture
def Device(Base):
class Device(Base):
__tablename__ = 'device'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
return Device
def create_models(self):
class Device(self.Base):
__tablename__ = 'device'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
class Order(self.Base):
__tablename__ = 'order'
id = sa.Column(sa.Integer, primary_key=True)
@pytest.fixture
def Order(Base):
class Order(Base):
__tablename__ = 'order'
id = sa.Column(sa.Integer, primary_key=True)
device_id = sa.Column(
'device', sa.ForeignKey('device.id'), nullable=False
device_id = sa.Column(
'device', sa.ForeignKey('device.id'), nullable=False
)
device = sa.orm.relationship('Device', backref='orders')
return Order
@pytest.fixture
def SalesInvoice(Base):
class SalesInvoice(Base):
__tablename__ = 'sales_invoice'
id = sa.Column(sa.Integer, primary_key=True)
order_id = sa.Column(
'order',
sa.ForeignKey('order.id'),
nullable=False
)
order = sa.orm.relationship(
'Order',
backref=sa.orm.backref(
'invoice',
uselist=False
)
device = sa.orm.relationship('Device', backref='orders')
)
device_name = sa.Column(sa.String)
class SalesInvoice(self.Base):
__tablename__ = 'sales_invoice'
id = sa.Column(sa.Integer, primary_key=True)
order_id = sa.Column(
'order',
sa.ForeignKey('order.id'),
nullable=False
)
order = sa.orm.relationship(
'Order',
backref=sa.orm.backref(
'invoice',
uselist=False
)
)
device_name = sa.Column(sa.String)
@observes('order.device')
def process_device(self, device):
self.device_name = device.name
@observes('order.device')
def process_device(self, device):
self.device_name = device.name
return SalesInvoice
self.Device = Device
self.Order = Order
self.SalesInvoice = SalesInvoice
def test_observable_root_obj_is_none(self):
order = self.Order(device=self.Device(name='Something'))
self.session.add(order)
self.session.flush()
@pytest.fixture
def init_models(Device, Order, SalesInvoice):
pass
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesForOneToManyToOneToMany(object):
def test_observable_root_obj_is_none(self, session, Device, Order):
order = Order(device=Device(name='Something'))
session.add(order)
session.flush()

View File

@ -1,84 +1,98 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.observer import observes
from tests import TestCase
class TestObservesForOneToOneToOneToOne(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.fixture
def Catalog(Base):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
product_price = sa.Column(sa.Integer)
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
product_price = sa.Column(sa.Integer)
@observes('category.sub_category.product')
def product_observer(self, product):
self.product_price = product.price if product else None
@observes('category.sub_category.product')
def product_observer(self, product):
self.product_price = product.price if product else None
category = sa.orm.relationship(
'Category',
uselist=False,
backref='catalog'
)
return Catalog
category = sa.orm.relationship(
'Category',
uselist=False,
backref='catalog'
)
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
@pytest.fixture
def Category(Base):
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
sub_category = sa.orm.relationship(
'SubCategory',
uselist=False,
backref='category'
)
sub_category = sa.orm.relationship(
'SubCategory',
uselist=False,
backref='category'
)
return Category
class SubCategory(self.Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
product = sa.orm.relationship(
'Product',
uselist=False,
backref='sub_category'
)
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Integer)
@pytest.fixture
def SubCategory(Base):
class SubCategory(Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
product = sa.orm.relationship(
'Product',
uselist=False,
backref='sub_category'
)
return SubCategory
sub_category_id = sa.Column(
sa.Integer, sa.ForeignKey('sub_category.id')
)
self.Catalog = Catalog
self.Category = Category
self.SubCategory = SubCategory
self.Product = Product
@pytest.fixture
def Product(Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Integer)
def create_catalog(self):
sub_category = self.SubCategory(product=self.Product(price=123))
category = self.Category(sub_category=sub_category)
catalog = self.Catalog(category=category)
self.session.add(catalog)
self.session.flush()
return catalog
sub_category_id = sa.Column(
sa.Integer, sa.ForeignKey('sub_category.id')
)
return Product
def test_simple_insert(self):
catalog = self.create_catalog()
@pytest.fixture
def init_models(Catalog, Category, SubCategory, Product):
pass
@pytest.fixture
def catalog(session, Catalog, Category, SubCategory, Product):
sub_category = SubCategory(product=Product(price=123))
category = Category(sub_category=sub_category)
catalog = Catalog(category=category)
session.add(catalog)
session.flush()
return catalog
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesForOneToOneToOneToOne(object):
def test_simple_insert(self, catalog):
assert catalog.product_price == 123
def test_replace_leaf_object(self):
catalog = self.create_catalog()
product = self.Product(price=44)
def test_replace_leaf_object(self, catalog, session, Product):
product = Product(price=44)
catalog.category.sub_category.product = product
self.session.flush()
session.flush()
assert catalog.product_price == 44
def test_delete_leaf_object(self):
catalog = self.create_catalog()
self.session.delete(catalog.category.sub_category.product)
self.session.flush()
def test_delete_leaf_object(self, catalog, session):
session.delete(catalog.category.sub_category.product)
session.flush()
assert catalog.product_price is None

View File

@ -1,32 +1,36 @@
import pytest
import six
from pytest import mark, raises
from sqlalchemy_utils import Country, i18n
@mark.skipif('i18n.babel is None')
@pytest.fixture
def set_get_locale():
i18n.get_locale = lambda: i18n.babel.Locale('en')
@pytest.mark.skipif('i18n.babel is None')
@pytest.mark.usefixtures('set_get_locale')
class TestCountry(object):
def setup_method(self, method):
i18n.get_locale = lambda: i18n.babel.Locale('en')
def test_init(self):
assert Country(u'FI') == Country(Country(u'FI'))
def test_constructor_with_wrong_type(self):
with raises(TypeError) as e:
with pytest.raises(TypeError) as e:
Country(None)
assert str(e.value) == (
"Country() argument must be a string or a country, not 'NoneType'"
)
def test_constructor_with_invalid_code(self):
with raises(ValueError) as e:
with pytest.raises(ValueError) as e:
Country('SomeUnknownCode')
assert str(e.value) == (
'Could not convert string to country code: SomeUnknownCode'
)
@mark.parametrize(
@pytest.mark.parametrize(
'code',
(
'FI',
@ -37,7 +41,7 @@ class TestCountry(object):
Country.validate(code)
def test_validate_with_invalid_code(self):
with raises(ValueError) as e:
with pytest.raises(ValueError) as e:
Country.validate('SomeUnknownCode')
assert str(e.value) == (
'Could not convert string to country code: SomeUnknownCode'

View File

@ -1,14 +1,18 @@
# -*- coding: utf-8 -*-
import pytest
import six
from pytest import mark, raises
from sqlalchemy_utils import Currency, i18n
@mark.skipif('i18n.babel is None')
@pytest.fixture
def set_get_locale():
i18n.get_locale = lambda: i18n.babel.Locale('en')
@pytest.mark.skipif('i18n.babel is None')
@pytest.mark.usefixtures('set_get_locale')
class TestCurrency(object):
def setup_method(self, method):
i18n.get_locale = lambda: i18n.babel.Locale('en')
def test_init(self):
assert Currency('USD') == Currency(Currency('USD'))
@ -17,14 +21,14 @@ class TestCurrency(object):
assert len(set([Currency('USD'), Currency('USD')])) == 1
def test_invalid_currency_code(self):
with raises(ValueError):
with pytest.raises(ValueError):
Currency('Unknown code')
def test_invalid_currency_code_type(self):
with raises(TypeError):
with pytest.raises(TypeError):
Currency(None)
@mark.parametrize(
@pytest.mark.parametrize(
('code', 'name'),
(
('USD', 'US Dollar'),
@ -34,7 +38,7 @@ class TestCurrency(object):
def test_name_property(self, code, name):
assert Currency(code).name == name
@mark.parametrize(
@pytest.mark.parametrize(
('code', 'symbol'),
(
('USD', u'$'),

View File

@ -6,10 +6,14 @@ from sqlalchemy_utils import i18n
from sqlalchemy_utils.primitives import WeekDay, WeekDays
@pytest.fixture
def set_get_locale():
i18n.get_locale = lambda: i18n.babel.Locale('fi')
@pytest.mark.skipif('i18n.babel is None')
@pytest.mark.usefixtures('set_get_locale')
class TestWeekDay(object):
def setup_method(self, method):
i18n.get_locale = lambda: i18n.babel.Locale('fi')
def test_constructor_with_valid_index(self):
day = WeekDay(1)

View File

@ -1,26 +1,27 @@
import pytest
from sqlalchemy_utils.relationships import chained_join
from tests import TestCase
from tests.mixins import (
from ..mixins import (
ThreeLevelDeepManyToMany,
ThreeLevelDeepOneToMany,
ThreeLevelDeepOneToOne
)
class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
create_tables = False
@pytest.mark.usefixtures('postgresql_dsn')
class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany):
def test_simple_join(self):
assert str(chained_join(self.Catalog.categories)) == (
def test_simple_join(self, Catalog):
assert str(chained_join(Catalog.categories)) == (
'catalog_category JOIN category ON '
'category._id = catalog_category.category_id'
)
def test_two_relations(self):
def test_two_relations(self, Catalog, Category):
sql = chained_join(
self.Catalog.categories,
self.Category.sub_categories
Catalog.categories,
Category.sub_categories
)
assert str(sql) == (
'catalog_category JOIN category ON category._id = '
@ -30,11 +31,11 @@ class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase):
'category_subcategory.subcategory_id'
)
def test_three_relations(self):
def test_three_relations(self, Catalog, Category, SubCategory):
sql = chained_join(
self.Catalog.categories,
self.Category.sub_categories,
self.SubCategory.products
Catalog.categories,
Category.sub_categories,
SubCategory.products
)
assert str(sql) == (
'catalog_category JOIN category ON category._id = '
@ -47,28 +48,27 @@ class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase):
)
class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany, TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
create_tables = False
@pytest.mark.usefixtures('postgresql_dsn')
class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany):
def test_simple_join(self):
assert str(chained_join(self.Catalog.categories)) == 'category'
def test_simple_join(self, Catalog):
assert str(chained_join(Catalog.categories)) == 'category'
def test_two_relations(self):
def test_two_relations(self, Catalog, Category):
sql = chained_join(
self.Catalog.categories,
self.Category.sub_categories
Catalog.categories,
Category.sub_categories
)
assert str(sql) == (
'category JOIN sub_category ON category._id = '
'sub_category._category_id'
)
def test_three_relations(self):
def test_three_relations(self, Catalog, Category, SubCategory):
sql = chained_join(
self.Catalog.categories,
self.Category.sub_categories,
self.SubCategory.products
Catalog.categories,
Category.sub_categories,
SubCategory.products
)
assert str(sql) == (
'category JOIN sub_category ON category._id = '
@ -77,28 +77,27 @@ class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany, TestCase):
)
class TestChainedJoinForDeepOneToOne(ThreeLevelDeepOneToOne, TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
create_tables = False
@pytest.mark.usefixtures('postgresql_dsn')
class TestChainedJoinForDeepOneToOne(ThreeLevelDeepOneToOne):
def test_simple_join(self):
assert str(chained_join(self.Catalog.category)) == 'category'
def test_simple_join(self, Catalog):
assert str(chained_join(Catalog.category)) == 'category'
def test_two_relations(self):
def test_two_relations(self, Catalog, Category):
sql = chained_join(
self.Catalog.category,
self.Category.sub_category
Catalog.category,
Category.sub_category
)
assert str(sql) == (
'category JOIN sub_category ON category._id = '
'sub_category._category_id'
)
def test_three_relations(self):
def test_three_relations(self, Catalog, Category, SubCategory):
sql = chained_join(
self.Catalog.category,
self.Category.sub_category,
self.SubCategory.product
Catalog.category,
Category.sub_category,
SubCategory.product
)
assert str(sql) == (
'category JOIN sub_category ON category._id = '

View File

@ -1,31 +1,23 @@
import pytest
import sqlalchemy as sa
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils.relationships import select_correlated_expression
@pytest.fixture(scope='class')
def base():
return declarative_base()
@pytest.fixture(scope='class')
def group_user_cls(base):
@pytest.fixture
def group_user_tbl(Base):
return sa.Table(
'group_user',
base.metadata,
Base.metadata,
sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')),
sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id'))
)
@pytest.fixture(scope='class')
def group_cls(base):
class Group(base):
@pytest.fixture
def group_tbl(Base):
class Group(Base):
__tablename__ = 'group'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
@ -33,11 +25,11 @@ def group_cls(base):
return Group
@pytest.fixture(scope='class')
def friendship_cls(base):
@pytest.fixture
def friendship_tbl(Base):
return sa.Table(
'friendships',
base.metadata,
Base.metadata,
sa.Column(
'friend_a_id',
sa.Integer,
@ -53,35 +45,37 @@ def friendship_cls(base):
)
@pytest.fixture(scope='class')
def user_cls(base, group_user_cls, friendship_cls):
class User(base):
@pytest.fixture
def User(Base, group_user_tbl, friendship_tbl):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
groups = sa.orm.relationship(
'Group',
secondary=group_user_cls,
secondary=group_user_tbl,
backref='users'
)
# this relationship is used for persistence
friends = sa.orm.relationship(
'User',
secondary=friendship_cls,
primaryjoin=id == friendship_cls.c.friend_a_id,
secondaryjoin=id == friendship_cls.c.friend_b_id,
secondary=friendship_tbl,
primaryjoin=id == friendship_tbl.c.friend_a_id,
secondaryjoin=id == friendship_tbl.c.friend_b_id,
)
friendship_union = sa.select([
friendship_cls.c.friend_a_id,
friendship_cls.c.friend_b_id
friendship_union = (
sa.select([
friendship_tbl.c.friend_a_id,
friendship_tbl.c.friend_b_id
]).union(
sa.select([
friendship_cls.c.friend_b_id,
friendship_cls.c.friend_a_id]
friendship_tbl.c.friend_b_id,
friendship_tbl.c.friend_a_id]
)
).alias()
).alias()
)
User.all_friends = sa.orm.relationship(
'User',
@ -94,9 +88,9 @@ def user_cls(base, group_user_cls, friendship_cls):
return User
@pytest.fixture(scope='class')
def category_cls(base, group_user_cls, friendship_cls):
class Category(base):
@pytest.fixture
def Category(Base, group_user_tbl, friendship_tbl):
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
@ -111,9 +105,9 @@ def category_cls(base, group_user_cls, friendship_cls):
return Category
@pytest.fixture(scope='class')
def article_cls(base, category_cls, user_cls):
class Article(base):
@pytest.fixture
def Article(Base, Category, User):
class Article(Base):
__tablename__ = 'article'
id = sa.Column('_id', sa.Integer, primary_key=True)
name = sa.Column(sa.String)
@ -129,144 +123,104 @@ def article_cls(base, category_cls, user_cls):
content = sa.Column(sa.String)
category_id = sa.Column(sa.Integer, sa.ForeignKey(category_cls.id))
category = sa.orm.relationship(category_cls, backref='articles')
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
category = sa.orm.relationship(Category, backref='articles')
author_id = sa.Column(sa.Integer, sa.ForeignKey(user_cls.id))
author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
author = sa.orm.relationship(
user_cls,
primaryjoin=author_id == user_cls.id,
User,
primaryjoin=author_id == User.id,
backref='authored_articles'
)
owner_id = sa.Column(sa.Integer, sa.ForeignKey(user_cls.id))
owner_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
owner = sa.orm.relationship(
user_cls,
primaryjoin=owner_id == user_cls.id,
User,
primaryjoin=owner_id == User.id,
backref='owned_articles'
)
return Article
@pytest.fixture(scope='class')
def comment_cls(base, article_cls, user_cls):
class Comment(base):
@pytest.fixture
def Comment(Base, Article, User):
class Comment(Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.String)
article_id = sa.Column(sa.Integer, sa.ForeignKey(article_cls.id))
article = sa.orm.relationship(article_cls, backref='comments')
article_id = sa.Column(sa.Integer, sa.ForeignKey(Article.id))
article = sa.orm.relationship(Article, backref='comments')
author_id = sa.Column(sa.Integer, sa.ForeignKey(user_cls.id))
author = sa.orm.relationship(user_cls, backref='comments')
author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
author = sa.orm.relationship(User, backref='comments')
article_cls.comment_count = sa.orm.column_property(
Article.comment_count = sa.orm.column_property(
sa.select([sa.func.count(Comment.id)])
.where(Comment.article_id == article_cls.id)
.correlate_except(article_cls)
.where(Comment.article_id == Article.id)
.correlate_except(Article)
)
return Comment
@pytest.fixture(scope='class')
def composite_pk_cls(base):
class CompositePKModel(base):
__tablename__ = 'composite_pk_model'
a = sa.Column(sa.Integer, primary_key=True)
b = sa.Column(sa.Integer, primary_key=True)
return CompositePKModel
@pytest.fixture(scope='class')
def dns():
return 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.yield_fixture(scope='class')
def engine(dns):
engine = create_engine(dns)
engine.echo = True
yield engine
engine.dispose()
@pytest.yield_fixture(scope='class')
def connection(engine):
conn = engine.connect()
yield conn
conn.close()
@pytest.fixture(scope='class')
def model_mapping(article_cls, category_cls, comment_cls, group_cls, user_cls):
@pytest.fixture
def model_mapping(Article, Category, Comment, group_tbl, User):
return {
'articles': article_cls,
'categories': category_cls,
'comments': comment_cls,
'groups': group_cls,
'users': user_cls
'articles': Article,
'categories': Category,
'comments': Comment,
'groups': group_tbl,
'users': User
}
@pytest.yield_fixture(scope='class')
def table_creator(base, connection, model_mapping):
sa.orm.configure_mappers()
base.metadata.create_all(connection)
yield
base.metadata.drop_all(connection)
@pytest.fixture
def init_models(Article, Category, Comment, group_tbl, User):
pass
@pytest.yield_fixture(scope='class')
def session(connection):
Session = sessionmaker(bind=connection)
session = Session()
yield session
session.close_all()
@pytest.fixture(scope='class')
@pytest.fixture
def dataset(
session,
user_cls,
group_cls,
article_cls,
category_cls,
comment_cls
User,
group_tbl,
Article,
Category,
Comment
):
group = group_cls(name='Group 1')
group2 = group_cls(name='Group 2')
user = user_cls(id=1, name='User 1', groups=[group, group2])
user2 = user_cls(id=2, name='User 2')
user3 = user_cls(id=3, name='User 3', groups=[group])
user4 = user_cls(id=4, name='User 4', groups=[group2])
user5 = user_cls(id=5, name='User 5')
group = group_tbl(name='Group 1')
group2 = group_tbl(name='Group 2')
user = User(id=1, name='User 1', groups=[group, group2])
user2 = User(id=2, name='User 2')
user3 = User(id=3, name='User 3', groups=[group])
user4 = User(id=4, name='User 4', groups=[group2])
user5 = User(id=5, name='User 5')
user.friends = [user2]
user2.friends = [user3, user4]
user3.friends = [user5]
article = article_cls(
article = Article(
name='Some article',
author=user,
owner=user2,
category=category_cls(
category=Category(
id=1,
name='Some category',
subcategories=[
category_cls(
Category(
id=2,
name='Subcategory 1',
subcategories=[
category_cls(
Category(
id=3,
name='Subsubcategory 1',
subcategories=[
category_cls(
Category(
id=5,
name='Subsubsubcategory 1',
),
category_cls(
Category(
id=6,
name='Subsubsubcategory 2',
)
@ -274,11 +228,11 @@ def dataset(
)
]
),
category_cls(id=4, name='Subcategory 2'),
Category(id=4, name='Subcategory 2'),
]
),
comments=[
comment_cls(
Comment(
content='Some comment',
author=user
)
@ -290,7 +244,7 @@ def dataset(
session.commit()
@pytest.mark.usefixtures('table_creator', 'dataset')
@pytest.mark.usefixtures('dataset', 'postgresql_dsn')
class TestSelectCorrelatedExpression(object):
@pytest.mark.parametrize(
('model_key', 'related_model_key', 'path', 'result'),
@ -428,20 +382,20 @@ class TestSelectCorrelatedExpression(object):
def test_with_non_aggregate_function(
self,
session,
user_cls,
article_cls
User,
Article
):
aggregate = select_correlated_expression(
article_cls,
sa.func.json_build_object('name', user_cls.name),
Article,
sa.func.json_build_object('name', User.name),
'comments.author',
user_cls
User
)
query = session.query(
article_cls.id,
Article.id,
aggregate.label('author_json')
).order_by(article_cls.id)
).order_by(Article.id)
result = query.all()
assert result == [
(1, {'name': 'User 1'})

View File

@ -9,143 +9,152 @@ from sqlalchemy_utils import (
assert_non_nullable,
assert_nullable
)
from tests import TestCase
class AssertionTestCase(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.fixture()
def User(Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column('_id', sa.Integer, primary_key=True)
name = sa.Column('_name', sa.String(20))
age = sa.Column('_age', sa.Integer, nullable=False)
email = sa.Column(
'_email', sa.String(200), nullable=False, unique=True
)
fav_numbers = sa.Column('_fav_numbers', ARRAY(sa.Integer))
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column('_id', sa.Integer, primary_key=True)
name = sa.Column('_name', sa.String(20))
age = sa.Column('_age', sa.Integer, nullable=False)
email = sa.Column(
'_email', sa.String(200), nullable=False, unique=True
)
fav_numbers = sa.Column('_fav_numbers', ARRAY(sa.Integer))
__table_args__ = (
sa.CheckConstraint(sa.and_(age >= 0, age <= 150)),
sa.CheckConstraint(
sa.and_(
sa.func.array_length(fav_numbers, 1) <= 8
)
__table_args__ = (
sa.CheckConstraint(sa.and_(age >= 0, age <= 150)),
sa.CheckConstraint(
sa.and_(
sa.func.array_length(fav_numbers, 1) <= 8
)
)
self.User = User
def setup_method(self, method):
TestCase.setup_method(self, method)
user = self.User(
name='Someone',
email='someone@example.com',
age=15,
fav_numbers=[1, 2, 3]
)
self.session.add(user)
self.session.commit()
self.user = user
return User
class TestAssertMaxLengthWithArray(AssertionTestCase):
def test_with_max_length(self):
assert_max_length(self.user, 'fav_numbers', 8)
assert_max_length(self.user, 'fav_numbers', 8)
@pytest.fixture()
def user(User, session):
user = User(
name='Someone',
email='someone@example.com',
age=15,
fav_numbers=[1, 2, 3]
)
session.add(user)
session.commit()
return user
def test_smaller_than_max_length(self):
@pytest.mark.usefixtures('postgresql_dsn')
class TestAssertMaxLengthWithArray(object):
def test_with_max_length(self, user):
assert_max_length(user, 'fav_numbers', 8)
assert_max_length(user, 'fav_numbers', 8)
def test_smaller_than_max_length(self, user):
with pytest.raises(AssertionError):
assert_max_length(self.user, 'fav_numbers', 7)
assert_max_length(user, 'fav_numbers', 7)
with pytest.raises(AssertionError):
assert_max_length(self.user, 'fav_numbers', 7)
assert_max_length(user, 'fav_numbers', 7)
def test_bigger_than_max_length(self):
def test_bigger_than_max_length(self, user):
with pytest.raises(AssertionError):
assert_max_length(self.user, 'fav_numbers', 9)
assert_max_length(user, 'fav_numbers', 9)
with pytest.raises(AssertionError):
assert_max_length(self.user, 'fav_numbers', 9)
assert_max_length(user, 'fav_numbers', 9)
class TestAssertNonNullable(AssertionTestCase):
def test_non_nullable_column(self):
@pytest.mark.usefixtures('postgresql_dsn')
class TestAssertNonNullable(object):
def test_non_nullable_column(self, user):
# Test everything twice so that session gets rolled back properly
assert_non_nullable(self.user, 'age')
assert_non_nullable(self.user, 'age')
assert_non_nullable(user, 'age')
assert_non_nullable(user, 'age')
def test_nullable_column(self):
def test_nullable_column(self, user):
with pytest.raises(AssertionError):
assert_non_nullable(self.user, 'name')
assert_non_nullable(user, 'name')
with pytest.raises(AssertionError):
assert_non_nullable(self.user, 'name')
assert_non_nullable(user, 'name')
class TestAssertNullable(AssertionTestCase):
def test_nullable_column(self):
assert_nullable(self.user, 'name')
assert_nullable(self.user, 'name')
@pytest.mark.usefixtures('postgresql_dsn')
class TestAssertNullable(object):
def test_non_nullable_column(self):
def test_nullable_column(self, user):
assert_nullable(user, 'name')
assert_nullable(user, 'name')
def test_non_nullable_column(self, user):
with pytest.raises(AssertionError):
assert_nullable(self.user, 'age')
assert_nullable(user, 'age')
with pytest.raises(AssertionError):
assert_nullable(self.user, 'age')
assert_nullable(user, 'age')
class TestAssertMaxLength(AssertionTestCase):
def test_with_max_length(self):
assert_max_length(self.user, 'name', 20)
assert_max_length(self.user, 'name', 20)
@pytest.mark.usefixtures('postgresql_dsn')
class TestAssertMaxLength(object):
def test_with_non_nullable_column(self):
assert_max_length(self.user, 'email', 200)
assert_max_length(self.user, 'email', 200)
def test_with_max_length(self, user):
assert_max_length(user, 'name', 20)
assert_max_length(user, 'name', 20)
def test_smaller_than_max_length(self):
with pytest.raises(AssertionError):
assert_max_length(self.user, 'name', 19)
with pytest.raises(AssertionError):
assert_max_length(self.user, 'name', 19)
def test_with_non_nullable_column(self, user):
assert_max_length(user, 'email', 200)
assert_max_length(user, 'email', 200)
def test_bigger_than_max_length(self):
def test_smaller_than_max_length(self, user):
with pytest.raises(AssertionError):
assert_max_length(self.user, 'name', 21)
assert_max_length(user, 'name', 19)
with pytest.raises(AssertionError):
assert_max_length(self.user, 'name', 21)
assert_max_length(user, 'name', 19)
def test_bigger_than_max_length(self, user):
with pytest.raises(AssertionError):
assert_max_length(user, 'name', 21)
with pytest.raises(AssertionError):
assert_max_length(user, 'name', 21)
class TestAssertMinValue(AssertionTestCase):
def test_with_min_value(self):
assert_min_value(self.user, 'age', 0)
assert_min_value(self.user, 'age', 0)
@pytest.mark.usefixtures('postgresql_dsn')
class TestAssertMinValue(object):
def test_smaller_than_min_value(self):
with pytest.raises(AssertionError):
assert_min_value(self.user, 'age', -1)
with pytest.raises(AssertionError):
assert_min_value(self.user, 'age', -1)
def test_with_min_value(self, user):
assert_min_value(user, 'age', 0)
assert_min_value(user, 'age', 0)
def test_bigger_than_min_value(self):
def test_smaller_than_min_value(self, user):
with pytest.raises(AssertionError):
assert_min_value(self.user, 'age', 1)
assert_min_value(user, 'age', -1)
with pytest.raises(AssertionError):
assert_min_value(self.user, 'age', 1)
assert_min_value(user, 'age', -1)
def test_bigger_than_min_value(self, user):
with pytest.raises(AssertionError):
assert_min_value(user, 'age', 1)
with pytest.raises(AssertionError):
assert_min_value(user, 'age', 1)
class TestAssertMaxValue(AssertionTestCase):
def test_with_min_value(self):
assert_max_value(self.user, 'age', 150)
assert_max_value(self.user, 'age', 150)
@pytest.mark.usefixtures('postgresql_dsn')
class TestAssertMaxValue(object):
def test_smaller_than_max_value(self):
with pytest.raises(AssertionError):
assert_max_value(self.user, 'age', 149)
with pytest.raises(AssertionError):
assert_max_value(self.user, 'age', 149)
def test_with_min_value(self, user):
assert_max_value(user, 'age', 150)
assert_max_value(user, 'age', 150)
def test_bigger_than_max_value(self):
def test_smaller_than_max_value(self, user):
with pytest.raises(AssertionError):
assert_max_value(self.user, 'age', 151)
assert_max_value(user, 'age', 149)
with pytest.raises(AssertionError):
assert_max_value(self.user, 'age', 151)
assert_max_value(user, 'age', 149)
def test_bigger_than_max_value(self, user):
with pytest.raises(AssertionError):
assert_max_value(user, 'age', 151)
with pytest.raises(AssertionError):
assert_max_value(user, 'age', 151)

View File

@ -1,117 +1,108 @@
import pytest
import sqlalchemy as sa
from pytest import raises
from sqlalchemy_utils import auto_delete_orphans, ImproperlyConfigured
from tests import TestCase
class TestAutoDeleteOrphans(TestCase):
def create_models(self):
tagging = sa.Table(
'tagging',
self.Base.metadata,
sa.Column(
'tag_id',
sa.Integer,
sa.ForeignKey('tag.id', ondelete='cascade'),
primary_key=True
),
sa.Column(
'entry_id',
sa.Integer,
sa.ForeignKey('entry.id', ondelete='cascade'),
primary_key=True
)
@pytest.fixture
def tagging_tbl(Base):
return sa.Table(
'tagging',
Base.metadata,
sa.Column(
'tag_id',
sa.Integer,
sa.ForeignKey('tag.id', ondelete='cascade'),
primary_key=True
),
sa.Column(
'entry_id',
sa.Integer,
sa.ForeignKey('entry.id', ondelete='cascade'),
primary_key=True
)
)
class Tag(self.Base):
__tablename__ = 'tag'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(100), unique=True, nullable=False)
def __init__(self, name=None):
self.name = name
@pytest.fixture
def Tag(Base):
class Tag(Base):
__tablename__ = 'tag'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(100), unique=True, nullable=False)
class Entry(self.Base):
__tablename__ = 'entry'
def __init__(self, name=None):
self.name = name
return Tag
id = sa.Column(sa.Integer, primary_key=True)
tags = sa.orm.relationship(
'Tag',
secondary=tagging,
backref='entries'
)
@pytest.fixture
def Entry(Base, Tag, tagging_tbl):
class Entry(Base):
__tablename__ = 'entry'
auto_delete_orphans(Entry.tags)
id = sa.Column(sa.Integer, primary_key=True)
self.Tag = Tag
self.Entry = Entry
tags = sa.orm.relationship(
'Tag',
secondary=tagging_tbl,
backref='entries'
)
auto_delete_orphans(Entry.tags)
return Entry
def test_orphan_deletion(self):
r1 = self.Entry()
r2 = self.Entry()
r3 = self.Entry()
@pytest.fixture
def EntryWithoutTagsBackref(Base, Tag, tagging_tbl):
class EntryWithoutTagsBackref(Base):
__tablename__ = 'entry'
id = sa.Column(sa.Integer, primary_key=True)
tags = sa.orm.relationship(
'Tag',
secondary=tagging_tbl
)
return EntryWithoutTagsBackref
class TestAutoDeleteOrphans(object):
@pytest.fixture
def init_models(self, Entry, Tag):
pass
def test_orphan_deletion(self, session, Entry, Tag):
r1 = Entry()
r2 = Entry()
r3 = Entry()
t1, t2, t3, t4 = (
self.Tag('t1'),
self.Tag('t2'),
self.Tag('t3'),
self.Tag('t4')
Tag('t1'),
Tag('t2'),
Tag('t3'),
Tag('t4')
)
r1.tags.extend([t1, t2])
r2.tags.extend([t2, t3])
r3.tags.extend([t4])
self.session.add_all([r1, r2, r3])
session.add_all([r1, r2, r3])
assert self.session.query(self.Tag).count() == 4
assert session.query(Tag).count() == 4
r2.tags.remove(t2)
assert self.session.query(self.Tag).count() == 4
assert session.query(Tag).count() == 4
r1.tags.remove(t2)
assert self.session.query(self.Tag).count() == 3
assert session.query(Tag).count() == 3
r1.tags.remove(t1)
assert self.session.query(self.Tag).count() == 2
assert session.query(Tag).count() == 2
class TestAutoDeleteOrphansWithoutBackref(TestCase):
def create_models(self):
tagging = sa.Table(
'tagging',
self.Base.metadata,
sa.Column(
'tag_id',
sa.Integer,
sa.ForeignKey('tag.id', ondelete='cascade'),
primary_key=True
),
sa.Column(
'entry_id',
sa.Integer,
sa.ForeignKey('entry.id', ondelete='cascade'),
primary_key=True
)
)
class TestAutoDeleteOrphansWithoutBackref(object):
class Tag(self.Base):
__tablename__ = 'tag'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(100), unique=True, nullable=False)
@pytest.fixture
def init_models(self, EntryWithoutTagsBackref, Tag):
pass
def __init__(self, name=None):
self.name = name
class Entry(self.Base):
__tablename__ = 'entry'
id = sa.Column(sa.Integer, primary_key=True)
tags = sa.orm.relationship(
'Tag',
secondary=tagging
)
self.Entry = Entry
def test_orphan_deletion(self):
with raises(ImproperlyConfigured):
auto_delete_orphans(self.Entry.tags)
def test_orphan_deletion(self, EntryWithoutTagsBackref):
with pytest.raises(ImproperlyConfigured):
auto_delete_orphans(EntryWithoutTagsBackref.tags)

View File

@ -1,50 +1,60 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils import EmailType
from tests import TestCase
class TestCaseInsensitiveComparator(TestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
email = sa.Column(EmailType)
@pytest.fixture
def User(Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
email = sa.Column(EmailType)
def __repr__(self):
return 'Building(%r)' % self.id
def __repr__(self):
return 'Building(%r)' % self.id
return User
self.User = User
def test_supports_equals(self):
@pytest.fixture
def init_models(User):
pass
class TestCaseInsensitiveComparator(object):
def test_supports_equals(self, session, User):
query = (
self.session.query(self.User)
.filter(self.User.email == u'email@example.com')
session.query(User)
.filter(User.email == u'email@example.com')
)
assert '"user".email = lower(:lower_1)' in str(query)
def test_supports_in_(self):
def test_supports_in_(self, session, User):
query = (
self.session.query(self.User)
.filter(self.User.email.in_([u'email@example.com', u'a']))
session.query(User)
.filter(User.email.in_([u'email@example.com', u'a']))
)
assert (
'"user".email IN (lower(:lower_1), lower(:lower_2))'
in str(query)
)
def test_supports_notin_(self):
def test_supports_notin_(self, session, User):
query = (
self.session.query(self.User)
.filter(self.User.email.notin_([u'email@example.com', u'a']))
session.query(User)
.filter(User.email.notin_([u'email@example.com', u'a']))
)
assert (
'"user".email NOT IN (lower(:lower_1), lower(:lower_2))'
in str(query)
)
def test_does_not_apply_lower_to_types_that_are_already_lowercased(self):
assert str(self.User.email == self.User.email) == (
def test_does_not_apply_lower_to_types_that_are_already_lowercased(
self,
User
):
assert str(User.email == User.email) == (
'"user".email = "user".email'
)

View File

@ -1,86 +1,93 @@
import pytest
import sqlalchemy as sa
from pytest import raises
from sqlalchemy.dialects import postgresql
from sqlalchemy_utils import Asterisk, row_to_json
from sqlalchemy_utils.expressions import explain, explain_analyze
from tests import TestCase
class ExpressionTestCase(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
content = sa.Column(sa.UnicodeText)
self.Article = Article
def assert_startswith(self, query, query_part):
@pytest.fixture
def assert_startswith(session):
def assert_startswith(query, query_part):
assert str(
query.compile(dialect=postgresql.dialect())
).startswith(query_part)
# Check that query executes properly
self.session.execute(query)
session.execute(query)
return assert_startswith
class TestExplain(ExpressionTestCase):
def test_render_explain(self):
self.assert_startswith(
explain(self.session.query(self.Article)),
@pytest.fixture
def Article(Base):
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
content = sa.Column(sa.UnicodeText)
return Article
@pytest.mark.usefixtures('postgresql_dsn')
class TestExplain(object):
def test_render_explain(self, session, assert_startswith, Article):
assert_startswith(
explain(session.query(Article)),
'EXPLAIN SELECT'
)
def test_render_explain_with_analyze(self):
self.assert_startswith(
explain(self.session.query(self.Article), analyze=True),
def test_render_explain_with_analyze(
self,
session,
assert_startswith,
Article
):
assert_startswith(
explain(session.query(Article), analyze=True),
'EXPLAIN (ANALYZE true) SELECT'
)
def test_with_string_as_stmt_param(self):
self.assert_startswith(
def test_with_string_as_stmt_param(self, assert_startswith):
assert_startswith(
explain('SELECT 1 FROM article'),
'EXPLAIN SELECT'
)
def test_format(self):
self.assert_startswith(
def test_format(self, assert_startswith):
assert_startswith(
explain('SELECT 1 FROM article', format='json'),
'EXPLAIN (FORMAT json) SELECT'
)
def test_timing(self):
self.assert_startswith(
def test_timing(self, assert_startswith):
assert_startswith(
explain('SELECT 1 FROM article', analyze=True, timing=False),
'EXPLAIN (ANALYZE true, TIMING false) SELECT'
)
def test_verbose(self):
self.assert_startswith(
def test_verbose(self, assert_startswith):
assert_startswith(
explain('SELECT 1 FROM article', verbose=True),
'EXPLAIN (VERBOSE true) SELECT'
)
def test_buffers(self):
self.assert_startswith(
def test_buffers(self, assert_startswith):
assert_startswith(
explain('SELECT 1 FROM article', analyze=True, buffers=True),
'EXPLAIN (ANALYZE true, BUFFERS true) SELECT'
)
def test_costs(self):
self.assert_startswith(
def test_costs(self, assert_startswith):
assert_startswith(
explain('SELECT 1 FROM article', costs=False),
'EXPLAIN (COSTS false) SELECT'
)
class TestExplainAnalyze(ExpressionTestCase):
def test_render_explain_analyze(self):
class TestExplainAnalyze(object):
def test_render_explain_analyze(self, session, Article):
assert str(
explain_analyze(self.session.query(self.Article))
explain_analyze(session.query(Article))
.compile(
dialect=postgresql.dialect()
)
@ -111,7 +118,7 @@ class TestAsterisk(object):
class TestRowToJson(object):
def test_compiler_with_default_dialect(self):
with raises(sa.exc.CompileError):
with pytest.raises(sa.exc.CompileError):
str(row_to_json(sa.text('article.*')))
def test_compiler_with_postgresql(self):
@ -128,7 +135,7 @@ class TestRowToJson(object):
class TestArrayAgg(object):
def test_compiler_with_default_dialect(self):
with raises(sa.exc.CompileError):
with pytest.raises(sa.exc.CompileError):
str(sa.func.array_agg(sa.text('u.name')))
def test_compiler_with_postgresql(self):

View File

@ -1,27 +1,29 @@
from datetime import datetime
import pytest
import sqlalchemy as sa
from sqlalchemy_utils.listeners import force_instant_defaults
from tests import TestCase
force_instant_defaults()
class TestInstantDefaultListener(TestCase):
def create_models(self):
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255), default=u'Some article')
created_at = sa.Column(sa.DateTime, default=datetime.now)
@pytest.fixture
def Article(Base):
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255), default=u'Some article')
created_at = sa.Column(sa.DateTime, default=datetime.now)
return Article
self.Article = Article
def test_assigns_defaults_on_object_construction(self):
article = self.Article()
class TestInstantDefaultListener(object):
def test_assigns_defaults_on_object_construction(self, Article):
article = Article()
assert article.name == u'Some article'
def test_callables_as_defaults(self):
article = self.Article()
def test_callables_as_defaults(self, Article):
article = Article()
assert isinstance(article.created_at, datetime)

View File

@ -1,14 +1,19 @@
from tests import TestCase
class TestInstrumentedList(TestCase):
def test_any_returns_true_if_member_has_attr_defined(self):
category = self.Category()
category.articles.append(self.Article())
category.articles.append(self.Article(name=u'some name'))
class TestInstrumentedList(object):
def test_any_returns_true_if_member_has_attr_defined(
self,
Category,
Article
):
category = Category()
category.articles.append(Article())
category.articles.append(Article(name=u'some name'))
assert category.articles.any('name')
def test_any_returns_false_if_no_member_has_attr_defined(self):
category = self.Category()
category.articles.append(self.Article())
def test_any_returns_false_if_no_member_has_attr_defined(
self,
Category,
Article
):
category = Category()
category.articles.append(Article())
assert not category.articles.any('name')

View File

@ -1,39 +1,40 @@
from datetime import datetime
import pytest
import sqlalchemy as sa
from sqlalchemy_utils import Timestamp
from tests import TestCase
class TestTimestamp(TestCase):
@pytest.fixture
def Article(Base):
class Article(Base, Timestamp):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255), default=u'Some article')
return Article
def create_models(self):
class Article(self.Base, Timestamp):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255), default=u'Some article')
self.Article = Article
class TestTimestamp(object):
def test_created(self):
def test_created(self, session, Article):
then = datetime.utcnow()
article = self.Article()
article = Article()
self.session.add(article)
self.session.commit()
session.add(article)
session.commit()
assert article.created >= then and article.created <= datetime.utcnow()
def test_updated(self):
article = self.Article()
def test_updated(self, session, Article):
article = Article()
self.session.add(article)
self.session.commit()
session.add(article)
session.commit()
then = datetime.utcnow()
article.name = u"Something"
self.session.commit()
session.commit()
assert article.updated >= then and article.updated <= datetime.utcnow()

View File

@ -1,122 +1,127 @@
import pytest
import six
import sqlalchemy as sa
from pytest import mark
from sqlalchemy.util.langhelpers import symbol
from sqlalchemy_utils.path import AttrPath, Path
from tests import TestCase
class TestAttrPath(TestCase):
def create_models(self):
class Document(self.Base):
__tablename__ = 'document'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
@pytest.fixture
def Document(Base):
class Document(Base):
__tablename__ = 'document'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
return Document
class Section(self.Base):
__tablename__ = 'section'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
document_id = sa.Column(
sa.Integer, sa.ForeignKey(Document.id)
)
@pytest.fixture
def Section(Base, Document):
class Section(Base):
__tablename__ = 'section'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
document = sa.orm.relationship(Document, backref='sections')
class SubSection(self.Base):
__tablename__ = 'subsection'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
section_id = sa.Column(
sa.Integer, sa.ForeignKey(Section.id)
)
section = sa.orm.relationship(Section, backref='subsections')
self.Document = Document
self.Section = Section
self.SubSection = SubSection
@mark.parametrize(
('class_', 'path', 'direction'),
(
('SubSection', 'section', symbol('MANYTOONE')),
document_id = sa.Column(
sa.Integer, sa.ForeignKey(Document.id)
)
)
def test_direction(self, class_, path, direction):
document = sa.orm.relationship(Document, backref='sections')
return Section
@pytest.fixture
def SubSection(Base, Section):
class SubSection(Base):
__tablename__ = 'subsection'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
section_id = sa.Column(
sa.Integer, sa.ForeignKey(Section.id)
)
section = sa.orm.relationship(Section, backref='subsections')
return SubSection
class TestAttrPath(object):
@pytest.fixture
def init_models(self, Document, Section, SubSection):
pass
def test_direction(self, SubSection):
assert (
AttrPath(getattr(self, class_), path).direction == direction
AttrPath(SubSection, 'section').direction == symbol('MANYTOONE')
)
def test_invert(self):
path = ~ AttrPath(self.SubSection, 'section.document')
def test_invert(self, Document, Section, SubSection):
path = ~ AttrPath(SubSection, 'section.document')
assert path.parts == [
self.Document.sections,
self.Section.subsections
Document.sections,
Section.subsections
]
assert str(path.path) == 'sections.subsections'
def test_len(self):
len(AttrPath(self.SubSection, 'section.document')) == 2
def test_len(self, SubSection):
len(AttrPath(SubSection, 'section.document')) == 2
def test_init(self):
path = AttrPath(self.SubSection, 'section.document')
assert path.class_ == self.SubSection
def test_init(self, SubSection):
path = AttrPath(SubSection, 'section.document')
assert path.class_ == SubSection
assert path.path == Path('section.document')
def test_iter(self):
path = AttrPath(self.SubSection, 'section.document')
def test_iter(self, Section, SubSection):
path = AttrPath(SubSection, 'section.document')
assert list(path) == [
self.SubSection.section,
self.Section.document
SubSection.section,
Section.document
]
def test_repr(self):
path = AttrPath(self.SubSection, 'section.document')
def test_repr(self, SubSection):
path = AttrPath(SubSection, 'section.document')
assert repr(path) == (
"AttrPath(SubSection, 'section.document')"
)
def test_index(self):
path = AttrPath(self.SubSection, 'section.document')
assert path.index(self.Section.document) == 1
assert path.index(self.SubSection.section) == 0
def test_index(self, Section, SubSection):
path = AttrPath(SubSection, 'section.document')
assert path.index(Section.document) == 1
assert path.index(SubSection.section) == 0
def test_getitem(self):
path = AttrPath(self.SubSection, 'section.document')
assert path[0] is self.SubSection.section
assert path[1] is self.Section.document
def test_getitem(self, Section, SubSection):
path = AttrPath(SubSection, 'section.document')
assert path[0] is SubSection.section
assert path[1] is Section.document
def test_getitem_with_slice(self):
path = AttrPath(self.SubSection, 'section.document')
assert path[:] == AttrPath(self.SubSection, 'section.document')
assert path[:-1] == AttrPath(self.SubSection, 'section')
assert path[1:] == AttrPath(self.Section, 'document')
def test_getitem_with_slice(self, Section, SubSection):
path = AttrPath(SubSection, 'section.document')
assert path[:] == AttrPath(SubSection, 'section.document')
assert path[:-1] == AttrPath(SubSection, 'section')
assert path[1:] == AttrPath(Section, 'document')
def test_eq(self):
def test_eq(self, SubSection):
assert (
AttrPath(self.SubSection, 'section.document') ==
AttrPath(self.SubSection, 'section.document')
AttrPath(SubSection, 'section.document') ==
AttrPath(SubSection, 'section.document')
)
assert not (
AttrPath(self.SubSection, 'section') ==
AttrPath(self.SubSection, 'section.document')
AttrPath(SubSection, 'section') ==
AttrPath(SubSection, 'section.document')
)
def test_ne(self):
def test_ne(self, SubSection):
assert not (
AttrPath(self.SubSection, 'section.document') !=
AttrPath(self.SubSection, 'section.document')
AttrPath(SubSection, 'section.document') !=
AttrPath(SubSection, 'section.document')
)
assert (
AttrPath(self.SubSection, 'section') !=
AttrPath(self.SubSection, 'section.document')
AttrPath(SubSection, 'section') !=
AttrPath(SubSection, 'section.document')
)
@ -133,7 +138,7 @@ class TestPath(object):
path = Path('s.s2.s3')
assert list(path) == ['s', 's2', 's3']
@mark.parametrize(('path', 'length'), (
@pytest.mark.parametrize(('path', 'length'), (
(Path('s.s2.s3'), 3),
(Path('s.s2'), 2),
(Path(''), 0)
@ -167,14 +172,14 @@ class TestPath(object):
path = Path('s.s2.s3')
assert path[1:] == Path('s2.s3')
@mark.parametrize(('test', 'result'), (
@pytest.mark.parametrize(('test', 'result'), (
(Path('s.s2') == Path('s.s2'), True),
(Path('s.s2') == Path('s.s3'), False)
))
def test_eq(self, test, result):
assert test is result
@mark.parametrize(('test', 'result'), (
@pytest.mark.parametrize(('test', 'result'), (
(Path('s.s2') != Path('s.s2'), False),
(Path('s.s2') != Path('s.s3'), True)
))

Some files were not shown because too many files have changed in this diff Show More