Merge pull request #231 from quantus/feature/multi-column-observer
Feature/multi column observer
This commit is contained in:
commit
8fc4fe5550
|
@ -148,6 +148,29 @@ Category has many Products.
|
|||
session.commit()
|
||||
catalog.product_count # 1
|
||||
|
||||
|
||||
Observing multiple columns
|
||||
-----------------------
|
||||
|
||||
You can also observe multiple columns by spesifying all the observable columns
|
||||
in the decorator.
|
||||
|
||||
|
||||
::
|
||||
|
||||
class Order(Base):
|
||||
__tablename__ = 'order'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
unit_price = sa.Column(sa.Integer)
|
||||
amount = sa.Column(sa.Integer)
|
||||
total_price = sa.Column(sa.Integer)
|
||||
|
||||
@observes('amount', 'unit_price')
|
||||
def total_price_observer(self, amount, unit_price):
|
||||
self.total_price = amount * unit_price
|
||||
|
||||
|
||||
|
||||
"""
|
||||
import itertools
|
||||
from collections import defaultdict, Iterable, namedtuple
|
||||
|
@ -158,7 +181,7 @@ from .functions import getdotattr, has_changes
|
|||
from .path import AttrPath
|
||||
from .utils import is_sequence
|
||||
|
||||
Callback = namedtuple('Callback', ['func', 'path', 'backref', 'fullpath'])
|
||||
Callback = namedtuple('Callback', ['func', 'backref', 'fullpath'])
|
||||
|
||||
|
||||
class PropertyObserver(object):
|
||||
|
@ -208,32 +231,33 @@ class PropertyObserver(object):
|
|||
)
|
||||
|
||||
def gather_paths(self):
|
||||
for class_, callbacks in self.generator_registry.items():
|
||||
for callback in callbacks:
|
||||
path = AttrPath(class_, callback.__observes__)
|
||||
for class_, generators in self.generator_registry.items():
|
||||
for callback in generators:
|
||||
full_paths = []
|
||||
for call_path in callback.__observes__:
|
||||
full_paths.append(AttrPath(class_, call_path))
|
||||
|
||||
self.callback_map[class_].append(
|
||||
Callback(
|
||||
func=callback,
|
||||
path=path,
|
||||
backref=None,
|
||||
fullpath=path
|
||||
)
|
||||
)
|
||||
|
||||
for index in range(len(path)):
|
||||
i = index + 1
|
||||
prop = path[index].property
|
||||
if isinstance(prop, sa.orm.RelationshipProperty):
|
||||
prop_class = path[index].property.mapper.class_
|
||||
self.callback_map[prop_class].append(
|
||||
Callback(
|
||||
func=callback,
|
||||
path=path[i:],
|
||||
backref=~ (path[:i]),
|
||||
fullpath=path
|
||||
)
|
||||
for path in full_paths:
|
||||
self.callback_map[class_].append(
|
||||
Callback(
|
||||
func=callback,
|
||||
backref=None,
|
||||
fullpath=full_paths
|
||||
)
|
||||
)
|
||||
|
||||
for index in range(len(path)):
|
||||
i = index + 1
|
||||
prop = path[index].property
|
||||
if isinstance(prop, sa.orm.RelationshipProperty):
|
||||
prop_class = path[index].property.mapper.class_
|
||||
self.callback_map[prop_class].append(
|
||||
Callback(
|
||||
func=callback,
|
||||
backref=~ (path[:i]),
|
||||
fullpath=full_paths
|
||||
)
|
||||
)
|
||||
|
||||
def gather_callback_args(self, obj, callbacks):
|
||||
for callback in callbacks:
|
||||
|
@ -252,18 +276,19 @@ class PropertyObserver(object):
|
|||
|
||||
def get_callback_args(self, root_obj, callback):
|
||||
session = sa.orm.object_session(root_obj)
|
||||
objects = getdotattr(
|
||||
objects = [getdotattr(
|
||||
root_obj,
|
||||
callback.fullpath,
|
||||
path,
|
||||
lambda obj: obj not in session.deleted
|
||||
)
|
||||
path = str(callback.fullpath)
|
||||
if '.' in path or has_changes(root_obj, path):
|
||||
return (
|
||||
root_obj,
|
||||
callback.func,
|
||||
objects
|
||||
)
|
||||
) for path in callback.fullpath]
|
||||
paths = [str(path) for path in callback.fullpath]
|
||||
for path in paths:
|
||||
if '.' in path or has_changes(root_obj, path):
|
||||
return (
|
||||
root_obj,
|
||||
callback.func,
|
||||
objects
|
||||
)
|
||||
|
||||
def iterate_objects_and_callbacks(self, session):
|
||||
objs = itertools.chain(session.new, session.dirty, session.deleted)
|
||||
|
@ -277,21 +302,25 @@ class PropertyObserver(object):
|
|||
for obj, callbacks in self.iterate_objects_and_callbacks(session):
|
||||
args = self.gather_callback_args(obj, callbacks)
|
||||
for (root_obj, func, objects) in args:
|
||||
if is_sequence(objects):
|
||||
callback_args[root_obj][func] = (
|
||||
callback_args[root_obj][func] | set(objects)
|
||||
)
|
||||
else:
|
||||
callback_args[root_obj][func] = objects
|
||||
if not callback_args[root_obj][func]:
|
||||
callback_args[root_obj][func] = {}
|
||||
for i, object_ in enumerate(objects):
|
||||
if is_sequence(object_):
|
||||
callback_args[root_obj][func][i] = (
|
||||
callback_args[root_obj][func].get(i, set()) |
|
||||
set(object_)
|
||||
)
|
||||
else:
|
||||
callback_args[root_obj][func][i] = object_
|
||||
|
||||
for root_obj, callback_objs in callback_args.items():
|
||||
for callback, objs in callback_objs.items():
|
||||
callback(root_obj, objs)
|
||||
callback(root_obj, *[objs[i] for i in range(len(objs))])
|
||||
|
||||
observer = PropertyObserver()
|
||||
|
||||
|
||||
def observes(path, observer=observer):
|
||||
def observes(*paths, **observer_kw):
|
||||
"""
|
||||
Mark method as property observer for the given property path. Inside
|
||||
transaction observer gathers all changes made in given property path and
|
||||
|
@ -327,14 +356,17 @@ def observes(path, observer=observer):
|
|||
|
||||
.. versionadded: 0.28.0
|
||||
|
||||
:param path: Dot-notated property path, eg. 'categories.products.price'
|
||||
:param observer: :meth:`PropertyObserver` object
|
||||
:param *paths: One or more dot-notated property paths, eg.
|
||||
'categories.products.price'
|
||||
:param **observer: A dictionary where value for key 'observer' contains
|
||||
:meth:`PropertyObserver` object
|
||||
"""
|
||||
observer.register_listeners()
|
||||
observer_ = observer_kw.pop('observer', observer)
|
||||
observer_.register_listeners()
|
||||
|
||||
def wraps(func):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
return func(self, *args, **kwargs)
|
||||
wrapper.__observes__ = path
|
||||
wrapper.__observes__ = paths
|
||||
return wrapper
|
||||
return wraps
|
||||
|
|
|
@ -58,3 +58,75 @@ class TestObservesForColumnWithoutActualChanges(object):
|
|||
product.price = 500
|
||||
session.commit()
|
||||
assert str(e.value) == 'Trying to change price'
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestObservesForMultipleColumns(object):
|
||||
|
||||
@pytest.fixture
|
||||
def Order(self, Base):
|
||||
class Order(Base):
|
||||
__tablename__ = 'order'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
unit_price = sa.Column(sa.Integer)
|
||||
amount = sa.Column(sa.Integer)
|
||||
total_price = sa.Column(sa.Integer)
|
||||
|
||||
@observes('amount', 'unit_price')
|
||||
def total_price_observer(self, amount, unit_price):
|
||||
self.total_price = amount * unit_price
|
||||
return Order
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(self, Order):
|
||||
pass
|
||||
|
||||
def test_only_notifies_observer_on_actual_changes(self, session, Order):
|
||||
order = Order()
|
||||
order.amount = 2
|
||||
order.unit_price = 10
|
||||
session.add(order)
|
||||
session.flush()
|
||||
|
||||
order.amount = 1
|
||||
session.flush()
|
||||
assert order.total_price == 10
|
||||
|
||||
order.unit_price = 100
|
||||
session.flush()
|
||||
assert order.total_price == 100
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestObservesForMultipleColumnsFiresOnlyOnce(object):
|
||||
|
||||
@pytest.fixture
|
||||
def Order(self, Base):
|
||||
class Order(Base):
|
||||
__tablename__ = 'order'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
unit_price = sa.Column(sa.Integer)
|
||||
amount = sa.Column(sa.Integer)
|
||||
|
||||
@observes('amount', 'unit_price')
|
||||
def total_price_observer(self, amount, unit_price):
|
||||
self.call_count = self.call_count + 1
|
||||
return Order
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(self, Order):
|
||||
pass
|
||||
|
||||
def test_only_notifies_observer_on_actual_changes(self, session, Order):
|
||||
order = Order()
|
||||
order.amount = 2
|
||||
order.unit_price = 10
|
||||
order.call_count = 0
|
||||
session.add(order)
|
||||
session.flush()
|
||||
assert order.call_count == 1
|
||||
|
||||
order.amount = 1
|
||||
order.unit_price = 100
|
||||
session.flush()
|
||||
assert order.call_count == 2
|
||||
|
|
Loading…
Reference in New Issue