""" Pika connection pool inspired by: - https://github.com/WeatherDecisionTechnologies/flask-pika and this interface: - http://docs.sqlalchemy.org/en/latest/core/pooling.html#sqlalchemy.pool.Pool Get it like this: .. code:: python $ pip install pika-pool Use it like e.g. this: .. code:: python import json import pika import pika_pool params = pika.URLParameters( 'amqp://guest:guest@localhost:5672/%2F?socket_timeout=5&connection_attempts=2' ) pool = pika_pool.QueuedPool( create=lambda: pika.BlockingConnection(parameters=params) recycle=45, max_size=10, max_overflow=10, timeout=10, ) with pool.acquire() as cxn: cxn.channel.basic_publish( body=json.dumps({'type': 'banana', 'color': 'yellow'}), exchange='exchange', routing_key='banana', properties={ 'content_type': 'application/json', 'content_encoding': 'utf-8', 'delivery_mode': 2 } ) """ from __future__ import unicode_literals from datetime import datetime import logging try: # python 3 import queue except ImportError: # python 2 import Queue as queue import select import socket import threading import time import pika.exceptions __version__ = '0.1.2' __all__ = [ 'Error' 'Timeout' 'Overflow' 'Connection', 'Pool', 'NullPool', 'QueuedPool', ] logger = logging.getLogger(__name__) class Error(Exception): pass class Overflow(Error): """ Raised when a `Pool.acquire` cannot allocate anymore connections. """ pass class Timeout(Error): """ Raised when an attempt to `Pool.acquire` a connection has timedout. """ pass class Connection(object): """ Connection acquired from a `Pool` instance. Get them like this: .. code:: python with pool.acquire() as cxn: print cxn.channel """ #: Exceptions that imply connection has been invalidated. connectivity_errors = ( pika.exceptions.AMQPConnectionError, pika.exceptions.ConnectionClosed, pika.exceptions.ChannelClosed, select.error, # XXX: https://github.com/pika/pika/issues/412 ) @classmethod def is_connection_invalidated(cls, exc): """ Says whether the given exception indicates the connection has been invalidated. :param exc: Exception object. :return: True if connection has been invalidted, otherwise False. """ return any( isinstance(exc, error)for error in cls.connectivity_errors ) def __init__(self, pool, fairy): self.pool = pool self.fairy = fairy @property def channel(self): if self.fairy.channel is None: self.fairy.channel = self.fairy.cxn.channel() return self.fairy.channel def close(self): self.pool.close(self.fairy) self.fairy = None def release(self): self.pool.release(self.fairy) self.fairy = None def __enter__(self): return self def __exit__(self, type, value, traceback): if type is None or not self.is_connection_invalidated(value): self.release() else: self.close() class Pool(object): """ Pool interface similar to: http://docs.sqlalchemy.org/en/latest/core/pooling.html#sqlalchemy.pool.Pool and used like: .. code:: python with pool.acquire(timeout=60) as cxn: cxn.channel.basic_publish( ... ) """ #: Acquired connection type. Connection = Connection def __init__(self, create): """ :param create: Callable creating a new connection. """ self.create = create def acquire(self, timeout=None): """ Retrieve a connection from the pool or create a new one. """ raise NotImplementedError def release(self, fairy): """ Return a connection to the pool. """ raise NotImplementedError def close(self, fairy): """ Forcibly close a connection, suppressing any connection errors. """ fairy.close() class Fairy(object): """ Connection wrapper for tracking its associated state. """ def __init__(self, cxn): self.cxn = cxn self.channel = None def close(self): if self.channel: try: self.channel.close() except Connection.connectivity_errors as ex: if not Connection.is_connection_invalidated(ex): raise self.channel = None try: self.cxn.close() except Connection.connectivity_errors as ex: if not Connection.is_connection_invalidated(ex): raise @property def cxn_params(self): if isinstance(self.cxn, pika.BaseConnection): return self.cxn.params if isinstance(self.cxn, pika.BlockingConnection): return self.cxn._impl.params @property def cxn_str(self): params = self.cxn_params if params: return '{0}:{1}/{2}'.format(params.host, params.port, params.virtual_host) def __str__(self): return ', '.join('{0}={1}'.format(k, v) for k, v in [ ('cxn', self.cxn_str), ('channel', '{0}'.format(int(self.channel) if self.channel is not None else self.channel)), ]) def _create(self): """ All fairy creates go through here. """ return self.Fairy(self.create()) class NullPool(Pool): """ Dummy pool. It opens/closes connections on each acquire/release. """ def acquire(self, timeout=None): return self.Connection(self, self._create()) def release(self, fairy): self.close(fairy) class QueuedPool(Pool): """ Queue backed pool. """ def __init__(self, create, max_size=10, max_overflow=10, timeout=30, recycle=None, stale=None, ): """ :param max_size: Maximum number of connections to keep queued. :param max_overflow: Maximum number of connections to create above `max_size`. :param timeout: Default number of seconds to wait for a connections to available. :param recycle: Lifetime of a connection (since creation) in seconds or None for no recycling. Expired connections are closed on acquire. :param stale: Threshold at which inactive (since release) connections are considered stale in seconds or None for no staleness. Stale connections are closed on acquire. """ self.max_size = max_size self.max_overflow = max_overflow self.timeout = timeout self.recycle = recycle self.stale = stale self._queue = queue.Queue(maxsize=self.max_size) self._avail_lock = threading.Lock() self._avail = self.max_size + self.max_overflow super(QueuedPool, self).__init__(create) def acquire(self, timeout=None): try: fairy = self._queue.get(False) except queue.Empty: try: fairy = self._create() except Overflow: timeout = timeout or self.timeout try: fairy = self._queue.get(timeout=timeout) except queue.Empty: try: fairy = self._create() except Overflow: raise Timeout() if self.is_expired(fairy): logger.info('closing expired connection - %s', fairy) self.close(fairy) return self.acquire(timeout=timeout) if self.is_stale(fairy): logger.info('closing stale connection - %s', fairy) self.close(fairy) return self.acquire(timeout=timeout) return self.Connection(self, fairy) def release(self, fairy): fairy.released_at = time.time() try: self._queue.put_nowait(fairy) except queue.Full: self.close(fairy) def close(self, fairy): # inc with self._avail_lock: self._avail += 1 return super(QueuedPool, self).close(fairy) def _create(self): # dec with self._avail_lock: if self._avail <= 0: raise Overflow() self._avail -= 1 try: return super(QueuedPool, self)._create() except: # inc with self._avail_lock: self._avail += 1 raise class Fairy(Pool.Fairy): def __init__(self, cxn): super(QueuedPool.Fairy, self).__init__(cxn) self.released_at = self.created_at = time.time() def __str__(self): return ', '.join('{0}={1}'.format(k, v) for k, v in [ ('cxn', self.cxn_str), ('channel', '{0}'.format(int(self.channel) if self.channel is not None else self.channel)), ('created_at', '{0}'.format(datetime.fromtimestamp(self.created_at).isoformat())), ('released_at', '{0}'.format(datetime.fromtimestamp(self.released_at).isoformat())), ]) def is_stale(self, fairy): if not self.stale: return False return (time.time() - fairy.released_at) > self.stale def is_expired(self, fairy): if not self.recycle: return False return (time.time() - fairy.created_at) > self.recycle