diff --git a/nodepool/driver/static/provider.py b/nodepool/driver/static/provider.py index 772818fc3..64d1cdd5c 100644 --- a/nodepool/driver/static/provider.py +++ b/nodepool/driver/static/provider.py @@ -33,7 +33,6 @@ class StaticNodeProvider(Provider): def __init__(self, provider, *args): self.provider = provider - self.static_nodes = {} def checkHost(self, node): '''Check node is reachable''' @@ -217,34 +216,37 @@ class StaticNodeProvider(Provider): finally: self.zk.unlockNode(node) + def syncNodeCount(self, registered, node, pool): + current_count = registered[node["name"]] + + # Register nodes to synchronize with our configuration. + if current_count < node["max-parallel-jobs"]: + register_cnt = node["max-parallel-jobs"] - current_count + self.registerNodeFromConfig( + register_cnt, self.provider.name, pool.name, node) + + # De-register nodes to synchronize with our configuration. + # This case covers an existing node, but with a decreased + # max-parallel-jobs value. + elif current_count > node["max-parallel-jobs"]: + deregister_cnt = current_count - node["max-parallel-jobs"] + try: + self.deregisterNode(deregister_cnt, node["name"]) + except Exception: + self.log.exception("Couldn't deregister static node:") + def _start(self, zk_conn): self.zk = zk_conn registered = self.getRegisteredNodeHostnames() + static_nodes = {} for pool in self.provider.pools.values(): for node in pool.nodes: - current_count = registered[node["name"]] - - # Register nodes to synchronize with our configuration. - if current_count < node["max-parallel-jobs"]: - register_cnt = node["max-parallel-jobs"] - current_count - try: - self.registerNodeFromConfig( - register_cnt, self.provider.name, pool.name, node) - except Exception: - self.log.exception("Couldn't register static node:") - continue - - # De-register nodes to synchronize with our configuration. - # This case covers an existing node, but with a decreased - # max-parallel-jobs value. - elif current_count > node["max-parallel-jobs"]: - deregister_cnt = current_count - node["max-parallel-jobs"] - try: - self.deregisterNode(deregister_cnt, node["name"]) - except Exception: - self.log.exception("Couldn't deregister static node:") - continue + try: + self.syncNodeCount(registered, node, pool) + except Exception: + self.log.exception("Couldn't sync node:") + continue try: self.updateNodeFromConfig(node) @@ -252,13 +254,13 @@ class StaticNodeProvider(Provider): self.log.exception("Couldn't update static node:") continue - self.static_nodes[node["name"]] = node + static_nodes[node["name"]] = node # De-register nodes to synchronize with our configuration. # This case covers any registered nodes that no longer appear in # the config. for hostname in list(registered): - if hostname not in self.static_nodes: + if hostname not in static_nodes: try: self.deregisterNode(registered[hostname], hostname) except Exception: @@ -275,11 +277,20 @@ class StaticNodeProvider(Provider): self.log.debug("Stopping") def listNodes(self): + registered = self.getRegisteredNodeHostnames() servers = [] - for node in self.static_nodes.values(): - servers.append(node) + for pool in self.provider.pools.values(): + for node in pool.nodes: + if node["name"] in registered: + servers.append(node) return servers + def poolNodes(self): + nodes = {} + for pool in self.provider.pools.values(): + nodes.update({n["name"]: n for n in pool.nodes}) + return nodes + def cleanupNode(self, server_id): return True @@ -293,7 +304,14 @@ class StaticNodeProvider(Provider): return True def cleanupLeakedResources(self): - pass + registered = self.getRegisteredNodeHostnames() + for pool in self.provider.pools.values(): + for node in pool.nodes: + try: + self.syncNodeCount(registered, node, pool) + except Exception: + self.log.exception("Couldn't sync node:") + continue def getRequestHandler(self, poolworker, request): return StaticNodeRequestHandler(poolworker, request) @@ -304,10 +322,10 @@ class StaticNodeProvider(Provider): ''' # It's possible a deleted node no longer exists in our config, so # don't bother to reregister. - if node.hostname not in self.static_nodes: + static_node = self.poolNodes().get(node.hostname) + if static_node is None: return - static_node = self.static_nodes[node.hostname] try: registered = self.getRegisteredNodeHostnames() except Exception: diff --git a/nodepool/tests/test_driver_static.py b/nodepool/tests/test_driver_static.py index a696ac77f..0e863ac4e 100644 --- a/nodepool/tests/test_driver_static.py +++ b/nodepool/tests/test_driver_static.py @@ -277,3 +277,19 @@ class TestDriverStatic(tests.DBTestCase): new_nodes = self.waitForNodes('fake-label') self.assertEqual(len(new_nodes), 1) self.assertEqual(nodes[0].hostname, new_nodes[0].hostname) + + def test_missing_static_node(self): + """Test that a missing static node is added""" + configfile = self.setup_config('static-2-nodes.yaml') + pool = self.useNodepool(configfile, watermark_sleep=1) + pool.start() + + self.log.debug("Waiting for initial nodes") + nodes = self.waitForNodes('fake-label', 2) + self.assertEqual(len(nodes), 2) + + self.zk.deleteNode(nodes[0]) + + self.log.debug("Waiting for node to transition to ready again") + nodes = self.waitForNodes('fake-label', 2) + self.assertEqual(len(nodes), 2)