diff --git a/tooz/drivers/memcached.py b/tooz/drivers/memcached.py index e7960719..d11b66ce 100644 --- a/tooz/drivers/memcached.py +++ b/tooz/drivers/memcached.py @@ -80,15 +80,18 @@ class MemcachedLock(locking.Lock): class MemcachedDriver(coordination.CoordinationDriver): _GROUP_PREFIX = b'_TOOZ_GROUP_' + _GROUP_LEADER_PREFIX = b'_TOOZ_GROUP_LEADER_' _MEMBER_PREFIX = b'_TOOZ_MEMBER_' _GROUP_LIST_KEY = b'_TOOZ_GROUP_LIST' - def __init__(self, member_id, membership_timeout=30, lock_timeout=30): + def __init__(self, member_id, membership_timeout=30, lock_timeout=30, + leader_timeout=30): super(MemcachedDriver, self).__init__() self._member_id = member_id self._groups = set() self.membership_timeout = membership_timeout self.lock_timeout = lock_timeout + self.leader_timeout = leader_timeout @staticmethod def _msgpack_serializer(key, value): @@ -121,6 +124,11 @@ class MemcachedDriver(coordination.CoordinationDriver): def stop(self): self.client.delete(self._encode_member_id(self._member_id)) map(self.leave_group, list(self._groups)) + + for group_id in six.iterkeys(self._hooks_elected_leader): + if self.get_leader(group_id).get() == self._member_id: + self.client.delete(self._encode_group_leader(group_id)) + self.client.close() def _encode_group_id(self, group_id): @@ -129,6 +137,9 @@ class MemcachedDriver(coordination.CoordinationDriver): def _encode_member_id(self, member_id): return self._MEMBER_PREFIX + member_id + def _encode_group_leader(self, group_id): + return self._GROUP_LEADER_PREFIX + group_id + @retry def _add_group_to_group_list(self, group_id): """Add group to the group list. @@ -241,6 +252,10 @@ class MemcachedDriver(coordination.CoordinationDriver): raise Retry return MemcachedAsyncResult(None) + def get_leader(self, group_id): + return MemcachedAsyncResult( + self.client.get(self._encode_group_leader(group_id))) + def heartbeat(self): self.client.set(self._encode_member_id(self._member_id), "It's alive!", @@ -249,6 +264,11 @@ class MemcachedDriver(coordination.CoordinationDriver): for lock in self._acquired_locks: lock.heartbeat() + for group_id in six.iterkeys(self._hooks_elected_leader): + if self.get_leader(group_id).get() == self._member_id: + self.client.touch(self._encode_group_leader(group_id), + expire=self.leader_timeout) + def _init_watch_group(self, group_id): members = self.client.get(self._encode_group_id(group_id)) if members is None: @@ -275,13 +295,13 @@ class MemcachedDriver(coordination.CoordinationDriver): return super(MemcachedDriver, self).unwatch_leave_group( group_id, callback) - @staticmethod - def watch_elected_as_leader(group_id, callback): - raise NotImplementedError + def watch_elected_as_leader(self, group_id, callback): + return super(MemcachedDriver, self).watch_elected_as_leader( + group_id, callback) - @staticmethod - def unwatch_elected_as_leader(group_id, callback): - raise NotImplementedError + def unwatch_elected_as_leader(self, group_id, callback): + return super(MemcachedDriver, self).unwatch_elected_as_leader( + group_id, callback) def get_lock(self, name): return MemcachedLock(self, name, self.lock_timeout) @@ -307,6 +327,19 @@ class MemcachedDriver(coordination.CoordinationDriver): self._group_members[group_id] = group_members + for group_id in six.iterkeys(self._hooks_elected_leader): + lock_id = self._encode_group_leader(group_id) + # Try to grab the lock, if that fails, that means someone has it + # already. + if self.client.add(lock_id, self._member_id, + expire=self.leader_timeout, + noreply=False): + # We got the lock + self._hooks_elected_leader[group_id].run( + coordination.LeaderElected( + group_id, + self._member_id)) + return result