204 lines
4.9 KiB
Python
204 lines
4.9 KiB
Python
import os
|
|
import warnings
|
|
|
|
import pytest
|
|
import sqlalchemy as sa
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.ext.declarative import declarative_base, synonym_for
|
|
from sqlalchemy.ext.hybrid import hybrid_property
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy_utils import (
|
|
aggregates,
|
|
coercion_listener,
|
|
i18n,
|
|
InstrumentedList
|
|
)
|
|
|
|
from sqlalchemy_utils.types.pg_composite import remove_composite_listeners
|
|
|
|
|
|
@sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute')
def count_sql_calls(conn, cursor, statement, parameters, context, executemany):
    """Keep a running ``query_count`` attribute on every connection.

    Registered for ``before_cursor_execute`` on all engines. A connection
    without the attribute gets it initialised to 0 on its first statement;
    each later statement increments it by one.

    NOTE(review): because of the 0 seed, the first statement on a connection
    is not counted -- presumably tests compare deltas; confirm before
    changing the seed.
    """
    # Missing attribute -> -1 + 1 == 0, matching the original seeding.
    conn.query_count = getattr(conn, 'query_count', -1) + 1
|
|
|
|
|
|
# Turn every SQLAlchemy SAWarning into an exception so it fails the test
# instead of being printed and ignored.
warnings.simplefilter('error', category=sa.exc.SAWarning)

# Hook sqlalchemy_utils' coercion_listener into the 'mapper_configured'
# event; using the ``mapper`` callable as the target applies it to all mappers.
sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener)
|
|
|
|
|
|
def get_locale():
    """Return a minimal stand-in for a locale object.

    The object only exposes ``territories``, a mapping containing the single
    entry ``'FI' -> 'Finland'``. A new class (and a new dict) is created on
    every call, so callers never share state.
    """
    class Locale:
        territories = dict(FI='Finland')

    fake_locale = Locale()
    return fake_locale
|
|
|
|
|
|
@pytest.fixture(scope='session')
def db_name():
    """Name of the test database (override with SQLALCHEMY_UTILS_TEST_DB)."""
    name = os.getenv('SQLALCHEMY_UTILS_TEST_DB', 'sqlalchemy_utils_test')
    return name
|
|
|
|
|
|
@pytest.fixture(scope='session')
def postgresql_db_user():
    """PostgreSQL user for tests (override with SQLALCHEMY_UTILS_TEST_POSTGRESQL_USER)."""
    user = os.getenv('SQLALCHEMY_UTILS_TEST_POSTGRESQL_USER', 'postgres')
    return user
|
|
|
|
|
|
@pytest.fixture(scope='session')
def mysql_db_user():
    """MySQL user for tests (override with SQLALCHEMY_UTILS_TEST_MYSQL_USER)."""
    user = os.getenv('SQLALCHEMY_UTILS_TEST_MYSQL_USER', 'root')
    return user
|
|
|
|
|
|
@pytest.fixture
def postgresql_dsn(postgresql_db_user, db_name):
    """Return a DSN for ``db_name`` on localhost as ``postgresql_db_user``.

    Uses the canonical ``postgresql://`` scheme. The short ``postgres://``
    alias is deprecated and rejected by newer SQLAlchemy releases, while
    ``postgresql://`` is accepted by all of them.
    """
    return 'postgresql://{0}@localhost/{1}'.format(postgresql_db_user, db_name)
|
|
|
|
|
|
@pytest.fixture
def mysql_dsn(mysql_db_user, db_name):
    """Return a PyMySQL DSN for ``db_name`` on localhost as ``mysql_db_user``."""
    return 'mysql+pymysql://{user}@localhost/{database}'.format(
        user=mysql_db_user,
        database=db_name,
    )
|
|
|
|
|
|
@pytest.fixture
def sqlite_memory_dsn():
    """Return the DSN of an in-memory SQLite database."""
    memory_dsn = 'sqlite:///:memory:'
    return memory_dsn
|
|
|
|
|
|
@pytest.fixture
def sqlite_none_database_dsn():
    """Return a SQLite DSN that has no database path component."""
    bare_dsn = 'sqlite://'
    return bare_dsn
|
|
|
|
|
|
@pytest.fixture
def sqlite_file_dsn(db_name):
    """Return a DSN for an on-disk SQLite database file ``<db_name>.db``.

    ``db_name`` must be requested as a fixture argument. Referring to the
    module-level ``db_name`` name instead would format the fixture *function
    object* into the path (``sqlite:///<function db_name at 0x...>.db``).
    """
    return 'sqlite:///{0}.db'.format(db_name)
|
|
|
|
|
|
@pytest.fixture
def dsn(request):
    """Return the database URL selected by the requesting test.

    A test picks a backend by also requesting one of the ``*_dsn`` fixtures.
    The first of ``postgresql_dsn``, ``mysql_dsn`` and ``sqlite_file_dsn``
    present in the test's fixture names wins. Otherwise the in-memory SQLite
    DSN is returned, including when ``sqlite_memory_dsn`` is requested
    explicitly.
    """
    # ``getfuncargvalue`` was superseded by ``getfixturevalue`` (pytest 3.0)
    # and later removed; prefer the new API and fall back for old pytest.
    get_fixture_value = getattr(request, 'getfixturevalue', None)
    if get_fixture_value is None:
        get_fixture_value = request.getfuncargvalue

    for fixture_name in ('postgresql_dsn', 'mysql_dsn', 'sqlite_file_dsn'):
        if fixture_name in request.fixturenames:
            return get_fixture_value(fixture_name)
    return get_fixture_value('sqlite_memory_dsn')
|
|
|
|
|
|
@pytest.fixture
def engine(dsn):
    """Return a SQLAlchemy engine for the DSN chosen by the ``dsn`` fixture.

    To log emitted SQL while debugging, set ``echo = True`` on the returned
    engine.
    """
    return create_engine(dsn)
|
|
|
|
|
|
@pytest.fixture
def connection(request, engine):
    """Return a connection opened from ``engine``.

    The connection is closed when the test finishes. The ``session`` fixture
    also closes it, and closing an already-closed connection is a no-op. The
    finalizer here only ensures that tests requesting ``connection`` without
    ``session`` do not leak it.
    """
    conn = engine.connect()
    request.addfinalizer(conn.close)
    return conn
|
|
|
|
|
|
@pytest.fixture
def Base():
    """Return a brand-new declarative base, so each test has its own metadata."""
    declarative = declarative_base()
    return declarative
|
|
|
|
|
|
@pytest.fixture
def User(Base):
    """Declare and return a minimal ``User`` model on the ``user`` table."""
    class User(Base):
        __tablename__ = 'user'

        # Integer surrogate primary key with autoincrement enabled.
        id = sa.Column(sa.Integer, primary_key=True, autoincrement=True)
        name = sa.Column(sa.Unicode(255))

    return User
|
|
|
|
|
|
@pytest.fixture
def Category(Base):
    """Declare and return the ``Category`` model on the ``category`` table.

    Besides plain columns it exposes:

    * ``full_name`` -- hybrid property; ``"<title> <name>"`` in Python,
      ``CONCAT(title, ' ', name)`` in SQL.
    * ``articles_count`` -- hybrid property; ``len(self.articles)`` in
      Python, a labelled scalar ``COUNT`` subquery over ``Article`` in SQL.
    * ``name_alias`` -- plain read-only property returning ``name``.
    * ``name_synonym`` -- property declared as a synonym for ``name``.
    """
    class Category(Base):
        __tablename__ = 'category'

        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
        title = sa.Column(sa.Unicode(255))

        @hybrid_property
        def full_name(self):
            # Instance-level value: title and name joined by a space.
            return u'{0} {1}'.format(self.title, self.name)

        @full_name.expression
        def full_name(cls):
            # Class-level SQL expression producing the same string.
            return sa.func.concat(cls.title, ' ', cls.name)

        @hybrid_property
        def articles_count(self):
            # ``articles`` is the backref added by the Article model.
            return len(self.articles)

        @articles_count.expression
        def articles_count(cls):
            # Article is declared by a separate fixture that depends on this
            # one, so it is resolved from the declarative registry only when
            # the expression is built.
            # NOTE(review): ``_decl_class_registry`` is a private SQLAlchemy
            # attribute that may be absent in newer releases -- confirm
            # against the pinned version.
            article_cls = Base._decl_class_registry['Article']
            count_query = (
                sa.select([sa.func.count(article_cls.id)])
                .where(article_cls.category_id == cls.id)
                .correlate(article_cls.__table__)
            )
            return count_query.label('article_count')

        @property
        def name_alias(self):
            return self.name

        @synonym_for('name')
        @property
        def name_synonym(self):
            return self.name

    return Category
|
|
|
|
|
|
@pytest.fixture
def Article(Base, Category):
    """Declare and return the ``Article`` model on the ``article`` table.

    Each article optionally belongs to a ``Category``. The relationship adds
    an ``articles`` backref on ``Category`` whose collections are
    sqlalchemy_utils ``InstrumentedList`` instances.
    """
    # backref() only packages the name and options; building it up front
    # does not change the mapping.
    articles_backref = sa.orm.backref(
        'articles',
        collection_class=InstrumentedList,
    )

    class Article(Base):
        __tablename__ = 'article'

        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255), index=True)
        category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))

        category = sa.orm.relationship(
            Category,
            primaryjoin=category_id == Category.id,
            backref=articles_backref,
        )

    return Article
|
|
|
|
|
|
@pytest.fixture
def init_models(User, Category, Article):
    """Pull in the model fixtures so User, Category and Article get declared.

    The fixture exists only for its dependencies; it returns ``None``.
    """
|
|
|
|
|
|
@pytest.fixture
def session(request, engine, connection, Base, init_models):
    """Return an ORM session bound to ``connection`` with the test schema created.

    Setup:

    * configures all mappers;
    * creates every table on ``Base.metadata``;
    * replaces ``i18n.get_locale`` with the module's ``get_locale`` stub.

    Teardown:

    * resets the aggregates manager and closes all sessions;
    * drops the tables and removes composite listeners;
    * finally closes the connection and disposes the engine.

    The connection/engine cleanup runs even if an earlier teardown step
    raises. Composite listeners are removed even if dropping tables fails.
    """
    sa.orm.configure_mappers()
    Base.metadata.create_all(connection)
    Session = sessionmaker(bind=connection)
    session = Session()
    i18n.get_locale = get_locale

    def teardown():
        try:
            aggregates.manager.reset()
            # ``Session.close_all()`` is deprecated; prefer the module-level
            # ``close_all_sessions`` when this SQLAlchemy version provides it.
            close_all_sessions = getattr(sa.orm, 'close_all_sessions', None)
            if close_all_sessions is not None:
                close_all_sessions()
            else:
                session.close_all()
            try:
                Base.metadata.drop_all(connection)
            finally:
                remove_composite_listeners()
        finally:
            connection.close()
            engine.dispose()

    request.addfinalizer(teardown)
    return session
|