Merge pull request #231 from quantus/feature/multi-column-observer

Feature/multi column observer
This commit is contained in:
Konsta Vesterinen 2016-07-17 22:58:55 +03:00 committed by GitHub
commit 8fc4fe5550
2 changed files with 151 additions and 47 deletions

View File

@ -148,6 +148,29 @@ Category has many Products.
session.commit()
catalog.product_count # 1
Observing multiple columns
-----------------------
You can also observe multiple columns by spesifying all the observable columns
in the decorator.
::
class Order(Base):
__tablename__ = 'order'
id = sa.Column(sa.Integer, primary_key=True)
unit_price = sa.Column(sa.Integer)
amount = sa.Column(sa.Integer)
total_price = sa.Column(sa.Integer)
@observes('amount', 'unit_price')
def total_price_observer(self, amount, unit_price):
self.total_price = amount * unit_price
"""
import itertools
from collections import defaultdict, Iterable, namedtuple
@ -158,7 +181,7 @@ from .functions import getdotattr, has_changes
from .path import AttrPath
from .utils import is_sequence
Callback = namedtuple('Callback', ['func', 'path', 'backref', 'fullpath'])
Callback = namedtuple('Callback', ['func', 'backref', 'fullpath'])
class PropertyObserver(object):
@ -208,32 +231,33 @@ class PropertyObserver(object):
)
def gather_paths(self):
for class_, callbacks in self.generator_registry.items():
for callback in callbacks:
path = AttrPath(class_, callback.__observes__)
for class_, generators in self.generator_registry.items():
for callback in generators:
full_paths = []
for call_path in callback.__observes__:
full_paths.append(AttrPath(class_, call_path))
self.callback_map[class_].append(
Callback(
func=callback,
path=path,
backref=None,
fullpath=path
)
)
for index in range(len(path)):
i = index + 1
prop = path[index].property
if isinstance(prop, sa.orm.RelationshipProperty):
prop_class = path[index].property.mapper.class_
self.callback_map[prop_class].append(
Callback(
func=callback,
path=path[i:],
backref=~ (path[:i]),
fullpath=path
)
for path in full_paths:
self.callback_map[class_].append(
Callback(
func=callback,
backref=None,
fullpath=full_paths
)
)
for index in range(len(path)):
i = index + 1
prop = path[index].property
if isinstance(prop, sa.orm.RelationshipProperty):
prop_class = path[index].property.mapper.class_
self.callback_map[prop_class].append(
Callback(
func=callback,
backref=~ (path[:i]),
fullpath=full_paths
)
)
def gather_callback_args(self, obj, callbacks):
for callback in callbacks:
@ -252,18 +276,19 @@ class PropertyObserver(object):
def get_callback_args(self, root_obj, callback):
session = sa.orm.object_session(root_obj)
objects = getdotattr(
objects = [getdotattr(
root_obj,
callback.fullpath,
path,
lambda obj: obj not in session.deleted
)
path = str(callback.fullpath)
if '.' in path or has_changes(root_obj, path):
return (
root_obj,
callback.func,
objects
)
) for path in callback.fullpath]
paths = [str(path) for path in callback.fullpath]
for path in paths:
if '.' in path or has_changes(root_obj, path):
return (
root_obj,
callback.func,
objects
)
def iterate_objects_and_callbacks(self, session):
objs = itertools.chain(session.new, session.dirty, session.deleted)
@ -277,21 +302,25 @@ class PropertyObserver(object):
for obj, callbacks in self.iterate_objects_and_callbacks(session):
args = self.gather_callback_args(obj, callbacks)
for (root_obj, func, objects) in args:
if is_sequence(objects):
callback_args[root_obj][func] = (
callback_args[root_obj][func] | set(objects)
)
else:
callback_args[root_obj][func] = objects
if not callback_args[root_obj][func]:
callback_args[root_obj][func] = {}
for i, object_ in enumerate(objects):
if is_sequence(object_):
callback_args[root_obj][func][i] = (
callback_args[root_obj][func].get(i, set()) |
set(object_)
)
else:
callback_args[root_obj][func][i] = object_
for root_obj, callback_objs in callback_args.items():
for callback, objs in callback_objs.items():
callback(root_obj, objs)
callback(root_obj, *[objs[i] for i in range(len(objs))])
observer = PropertyObserver()
def observes(path, observer=observer):
def observes(*paths, **observer_kw):
"""
Mark method as property observer for the given property path. Inside
transaction observer gathers all changes made in given property path and
@ -327,14 +356,17 @@ def observes(path, observer=observer):
.. versionadded: 0.28.0
:param path: Dot-notated property path, eg. 'categories.products.price'
:param observer: :meth:`PropertyObserver` object
:param *paths: One or more dot-notated property paths, eg.
'categories.products.price'
:param **observer: A dictionary where value for key 'observer' contains
:meth:`PropertyObserver` object
"""
observer.register_listeners()
observer_ = observer_kw.pop('observer', observer)
observer_.register_listeners()
def wraps(func):
def wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs)
wrapper.__observes__ = path
wrapper.__observes__ = paths
return wrapper
return wraps

View File

@ -58,3 +58,75 @@ class TestObservesForColumnWithoutActualChanges(object):
product.price = 500
session.commit()
assert str(e.value) == 'Trying to change price'
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesForMultipleColumns(object):
@pytest.fixture
def Order(self, Base):
class Order(Base):
__tablename__ = 'order'
id = sa.Column(sa.Integer, primary_key=True)
unit_price = sa.Column(sa.Integer)
amount = sa.Column(sa.Integer)
total_price = sa.Column(sa.Integer)
@observes('amount', 'unit_price')
def total_price_observer(self, amount, unit_price):
self.total_price = amount * unit_price
return Order
@pytest.fixture
def init_models(self, Order):
pass
def test_only_notifies_observer_on_actual_changes(self, session, Order):
order = Order()
order.amount = 2
order.unit_price = 10
session.add(order)
session.flush()
order.amount = 1
session.flush()
assert order.total_price == 10
order.unit_price = 100
session.flush()
assert order.total_price == 100
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesForMultipleColumnsFiresOnlyOnce(object):
@pytest.fixture
def Order(self, Base):
class Order(Base):
__tablename__ = 'order'
id = sa.Column(sa.Integer, primary_key=True)
unit_price = sa.Column(sa.Integer)
amount = sa.Column(sa.Integer)
@observes('amount', 'unit_price')
def total_price_observer(self, amount, unit_price):
self.call_count = self.call_count + 1
return Order
@pytest.fixture
def init_models(self, Order):
pass
def test_only_notifies_observer_on_actual_changes(self, session, Order):
order = Order()
order.amount = 2
order.unit_price = 10
order.call_count = 0
session.add(order)
session.flush()
assert order.call_count == 1
order.amount = 1
order.unit_price = 100
session.flush()
assert order.call_count == 2