From c85aa57866c5277570b622cd907c2b160805ba93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pekka=20P=C3=B6yry?= Date: Thu, 14 Jul 2016 13:54:43 +0300 Subject: [PATCH 1/2] Remove unused internal variable from observer --- sqlalchemy_utils/observer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sqlalchemy_utils/observer.py b/sqlalchemy_utils/observer.py index 4f08722..b4cb37f 100644 --- a/sqlalchemy_utils/observer.py +++ b/sqlalchemy_utils/observer.py @@ -158,7 +158,7 @@ from .functions import getdotattr, has_changes from .path import AttrPath from .utils import is_sequence -Callback = namedtuple('Callback', ['func', 'path', 'backref', 'fullpath']) +Callback = namedtuple('Callback', ['func', 'backref', 'fullpath']) class PropertyObserver(object): @@ -215,7 +215,6 @@ class PropertyObserver(object): self.callback_map[class_].append( Callback( func=callback, - path=path, backref=None, fullpath=path ) @@ -229,7 +228,6 @@ class PropertyObserver(object): self.callback_map[prop_class].append( Callback( func=callback, - path=path[i:], backref=~ (path[:i]), fullpath=path ) From 1f8dccfd11f08c4d63b8fe56ca61fdaa90432b6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pekka=20P=C3=B6yry?= Date: Thu, 14 Jul 2016 13:52:21 +0300 Subject: [PATCH 2/2] Add support for multi column observers --- sqlalchemy_utils/observer.py | 122 ++++++++++++++++--------- tests/observes/test_column_property.py | 72 +++++++++++++++ 2 files changed, 150 insertions(+), 44 deletions(-) diff --git a/sqlalchemy_utils/observer.py b/sqlalchemy_utils/observer.py index b4cb37f..0e8fcf5 100644 --- a/sqlalchemy_utils/observer.py +++ b/sqlalchemy_utils/observer.py @@ -148,6 +148,29 @@ Category has many Products. session.commit() catalog.product_count # 1 + +Observing multiple columns +----------------------- + +You can also observe multiple columns by spesifying all the observable columns +in the decorator. + + +:: + + class Order(Base): + __tablename__ = 'order' + id = sa.Column(sa.Integer, primary_key=True) + unit_price = sa.Column(sa.Integer) + amount = sa.Column(sa.Integer) + total_price = sa.Column(sa.Integer) + + @observes('amount', 'unit_price') + def total_price_observer(self, amount, unit_price): + self.total_price = amount * unit_price + + + """ import itertools from collections import defaultdict, Iterable, namedtuple @@ -208,30 +231,33 @@ class PropertyObserver(object): ) def gather_paths(self): - for class_, callbacks in self.generator_registry.items(): - for callback in callbacks: - path = AttrPath(class_, callback.__observes__) + for class_, generators in self.generator_registry.items(): + for callback in generators: + full_paths = [] + for call_path in callback.__observes__: + full_paths.append(AttrPath(class_, call_path)) - self.callback_map[class_].append( - Callback( - func=callback, - backref=None, - fullpath=path - ) - ) - - for index in range(len(path)): - i = index + 1 - prop = path[index].property - if isinstance(prop, sa.orm.RelationshipProperty): - prop_class = path[index].property.mapper.class_ - self.callback_map[prop_class].append( - Callback( - func=callback, - backref=~ (path[:i]), - fullpath=path - ) + for path in full_paths: + self.callback_map[class_].append( + Callback( + func=callback, + backref=None, + fullpath=full_paths ) + ) + + for index in range(len(path)): + i = index + 1 + prop = path[index].property + if isinstance(prop, sa.orm.RelationshipProperty): + prop_class = path[index].property.mapper.class_ + self.callback_map[prop_class].append( + Callback( + func=callback, + backref=~ (path[:i]), + fullpath=full_paths + ) + ) def gather_callback_args(self, obj, callbacks): for callback in callbacks: @@ -250,18 +276,19 @@ class PropertyObserver(object): def get_callback_args(self, root_obj, callback): session = sa.orm.object_session(root_obj) - objects = getdotattr( + objects = [getdotattr( root_obj, - callback.fullpath, + path, lambda obj: obj not in session.deleted - ) - path = str(callback.fullpath) - if '.' in path or has_changes(root_obj, path): - return ( - root_obj, - callback.func, - objects - ) + ) for path in callback.fullpath] + paths = [str(path) for path in callback.fullpath] + for path in paths: + if '.' in path or has_changes(root_obj, path): + return ( + root_obj, + callback.func, + objects + ) def iterate_objects_and_callbacks(self, session): objs = itertools.chain(session.new, session.dirty, session.deleted) @@ -275,21 +302,25 @@ class PropertyObserver(object): for obj, callbacks in self.iterate_objects_and_callbacks(session): args = self.gather_callback_args(obj, callbacks) for (root_obj, func, objects) in args: - if is_sequence(objects): - callback_args[root_obj][func] = ( - callback_args[root_obj][func] | set(objects) - ) - else: - callback_args[root_obj][func] = objects + if not callback_args[root_obj][func]: + callback_args[root_obj][func] = {} + for i, object_ in enumerate(objects): + if is_sequence(object_): + callback_args[root_obj][func][i] = ( + callback_args[root_obj][func].get(i, set()) | + set(object_) + ) + else: + callback_args[root_obj][func][i] = object_ for root_obj, callback_objs in callback_args.items(): for callback, objs in callback_objs.items(): - callback(root_obj, objs) + callback(root_obj, *[objs[i] for i in range(len(objs))]) observer = PropertyObserver() -def observes(path, observer=observer): +def observes(*paths, **observer_kw): """ Mark method as property observer for the given property path. Inside transaction observer gathers all changes made in given property path and @@ -325,14 +356,17 @@ def observes(path, observer=observer): .. versionadded: 0.28.0 - :param path: Dot-notated property path, eg. 'categories.products.price' - :param observer: :meth:`PropertyObserver` object + :param *paths: One or more dot-notated property paths, eg. + 'categories.products.price' + :param **observer: A dictionary where value for key 'observer' contains + :meth:`PropertyObserver` object """ - observer.register_listeners() + observer_ = observer_kw.pop('observer', observer) + observer_.register_listeners() def wraps(func): def wrapper(self, *args, **kwargs): return func(self, *args, **kwargs) - wrapper.__observes__ = path + wrapper.__observes__ = paths return wrapper return wraps diff --git a/tests/observes/test_column_property.py b/tests/observes/test_column_property.py index f388f26..fde0db5 100644 --- a/tests/observes/test_column_property.py +++ b/tests/observes/test_column_property.py @@ -58,3 +58,75 @@ class TestObservesForColumnWithoutActualChanges(object): product.price = 500 session.commit() assert str(e.value) == 'Trying to change price' + + +@pytest.mark.usefixtures('postgresql_dsn') +class TestObservesForMultipleColumns(object): + + @pytest.fixture + def Order(self, Base): + class Order(Base): + __tablename__ = 'order' + id = sa.Column(sa.Integer, primary_key=True) + unit_price = sa.Column(sa.Integer) + amount = sa.Column(sa.Integer) + total_price = sa.Column(sa.Integer) + + @observes('amount', 'unit_price') + def total_price_observer(self, amount, unit_price): + self.total_price = amount * unit_price + return Order + + @pytest.fixture + def init_models(self, Order): + pass + + def test_only_notifies_observer_on_actual_changes(self, session, Order): + order = Order() + order.amount = 2 + order.unit_price = 10 + session.add(order) + session.flush() + + order.amount = 1 + session.flush() + assert order.total_price == 10 + + order.unit_price = 100 + session.flush() + assert order.total_price == 100 + + +@pytest.mark.usefixtures('postgresql_dsn') +class TestObservesForMultipleColumnsFiresOnlyOnce(object): + + @pytest.fixture + def Order(self, Base): + class Order(Base): + __tablename__ = 'order' + id = sa.Column(sa.Integer, primary_key=True) + unit_price = sa.Column(sa.Integer) + amount = sa.Column(sa.Integer) + + @observes('amount', 'unit_price') + def total_price_observer(self, amount, unit_price): + self.call_count = self.call_count + 1 + return Order + + @pytest.fixture + def init_models(self, Order): + pass + + def test_only_notifies_observer_on_actual_changes(self, session, Order): + order = Order() + order.amount = 2 + order.unit_price = 10 + order.call_count = 0 + session.add(order) + session.flush() + assert order.call_count == 1 + + order.amount = 1 + order.unit_price = 100 + session.flush() + assert order.call_count == 2