diff --git a/requires.py b/requires.py index ce0bf7a..a0a67b1 100644 --- a/requires.py +++ b/requires.py @@ -18,6 +18,7 @@ from charms.reactive import hook from charms.reactive import RelationBase from charms.reactive import scopes from charms.reactive.helpers import data_changed +from charmhelpers.core import hookenv class HAClusterRequires(RelationBase): @@ -32,13 +33,38 @@ class HAClusterRequires(RelationBase): @hook('{requires:hacluster}-relation-changed') def changed(self): - self.set_state('{relation_name}.available') + if self.is_clustered(): + self.set_state('{relation_name}.available') + else: + self.remove_state('{relation_name}.available') @hook('{requires:hacluster}-relation-{broken,departed}') def departed(self): self.remove_state('{relation_name}.available') self.remove_state('{relation_name}.connected') + def is_clustered(self): + """Has the hacluster charm set clustered? + + The hacluster charm sets cluster=True when it determines it is ready. + Check the relation data for clustered and force a boolean return. + + :returns: boolean + """ + clustered_values = self.get_remote_all('clustered') + if clustered_values: + # There is only ever one subordinate hacluster unit + clustered = clustered_values[0] + # Future versions of hacluster will return a bool + # Current versions return a string + if type(clustered) is bool: + return clustered + elif (clustered is not None and + (clustered.lower() == 'true' or + clustered.lower() == 'yes')): + return True + return False + def bind_on(self, iface=None, mcastport=None): relation_data = {} if iface: @@ -167,3 +193,16 @@ class HAClusterRequires(RelationBase): resources.group(group, *dns_res_group_members) self.set_local(resources=resources) + + def get_remote_all(self, key, default=None): + """Return a list of all values presented by remote units for key""" + values = [] + for conversation in self.conversations(): + for relation_id in conversation.relation_ids: + for unit in hookenv.related_units(relation_id): + value = hookenv.relation_get(key, + unit, + relation_id) or default + if value: + values.append(value) + return list(set(values))