Merge pull request #778 from datastax/python-host-filter-policy_PYTHON-761

Python host filter policy PYTHON-761
This commit is contained in:
Jim Witschey 2017-07-12 16:15:43 -04:00 committed by GitHub
commit 852d39e6c0
5 changed files with 388 additions and 14 deletions

View File

@ -4,6 +4,7 @@
Features Features
-------- --------
* Add idle_heartbeat_timeout cluster option to tune how long to wait for heartbeat responses. (PYTHON-762) * Add idle_heartbeat_timeout cluster option to tune how long to wait for heartbeat responses. (PYTHON-762)
* Add HostFilterPolicy (PYTHON-761)
Bug Fixes Bug Fixes
--------- ---------
@ -20,6 +21,10 @@ Other
* Bump Cython dependency version to 0.25.2 (PYTHON-754) * Bump Cython dependency version to 0.25.2 (PYTHON-754)
* Fix DeprecationWarning when using lz4 (PYTHON-769) * Fix DeprecationWarning when using lz4 (PYTHON-769)
Other
-----
* Deprecate WhiteListRoundRobinPolicy (PYTHON-759)
3.10.0 3.10.0
====== ======
May 24, 2017 May 24, 2017

View File

@ -17,6 +17,7 @@ import logging
from random import randint, shuffle from random import randint, shuffle
from threading import Lock from threading import Lock
import socket import socket
from warnings import warn
from cassandra import ConsistencyLevel, OperationTimedOut from cassandra import ConsistencyLevel, OperationTimedOut
@ -396,6 +397,10 @@ class TokenAwarePolicy(LoadBalancingPolicy):
class WhiteListRoundRobinPolicy(RoundRobinPolicy): class WhiteListRoundRobinPolicy(RoundRobinPolicy):
""" """
|wlrrp| **is deprecated. It will be removed in 4.0.** It can effectively be
reimplemented using :class:`.HostFilterPolicy`. For more information, see
PYTHON-758_.
A subclass of :class:`.RoundRobinPolicy` which evenly A subclass of :class:`.RoundRobinPolicy` which evenly
distributes queries across all nodes in the cluster, distributes queries across all nodes in the cluster,
regardless of what datacenter the nodes may be in, but regardless of what datacenter the nodes may be in, but
@ -405,12 +410,25 @@ class WhiteListRoundRobinPolicy(RoundRobinPolicy):
https://datastax-oss.atlassian.net/browse/JAVA-145 https://datastax-oss.atlassian.net/browse/JAVA-145
Where connection errors occur when connection Where connection errors occur when connection
attempts are made to private IP addresses remotely attempts are made to private IP addresses remotely
.. |wlrrp| raw:: html
<b><code>WhiteListRoundRobinPolicy</code></b>
.. _PYTHON-758: https://datastax-oss.atlassian.net/browse/PYTHON-758
""" """
def __init__(self, hosts): def __init__(self, hosts):
""" """
The `hosts` parameter should be a sequence of hosts to permit The `hosts` parameter should be a sequence of hosts to permit
connections to. connections to.
""" """
msg = ('WhiteListRoundRobinPolicy is deprecated. '
'It will be removed in 4.0. '
'It can effectively be reimplemented using HostFilterPolicy.')
warn(msg, DeprecationWarning)
# DeprecationWarnings are silent by default so we also log the message
log.warning(msg)
self._allowed_hosts = hosts self._allowed_hosts = hosts
self._allowed_hosts_resolved = [endpoint[4][0] for a in self._allowed_hosts self._allowed_hosts_resolved = [endpoint[4][0] for a in self._allowed_hosts
@ -441,6 +459,116 @@ class WhiteListRoundRobinPolicy(RoundRobinPolicy):
RoundRobinPolicy.on_add(self, host) RoundRobinPolicy.on_add(self, host)
class HostFilterPolicy(LoadBalancingPolicy):
"""
A :class:`.LoadBalancingPolicy` subclass configured with a child policy,
and a single-argument predicate. This policy defers to the child policy for
hosts where ``predicate(host)`` is truthy. Hosts for which
``predicate(host)`` is falsey will be considered :attr:`.IGNORED`, and will
not be used in a query plan.
This can be used in the cases where you need a whitelist or blacklist
policy, e.g. to prepare for decommissioning nodes or for testing:
.. code-block:: python
def address_is_ignored(host):
return host.address in [ignored_address0, ignored_address1]
blacklist_filter_policy = HostFilterPolicy(
child_policy=RoundRobinPolicy(),
predicate=address_is_ignored
)
cluster = Cluster(
primary_host,
load_balancing_policy=blacklist_filter_policy,
)
Please note that whitelist and blacklist policies are not recommended for
general, day-to-day use. You probably want something like
:class:`.DCAwareRoundRobinPolicy`, which prefers a local DC but has
fallbacks, over a brute-force method like whitelisting or blacklisting.
"""
def __init__(self, child_policy, predicate):
"""
:param child_policy: an instantiated :class:`.LoadBalancingPolicy`
that this one will defer to.
:param predicate: a one-parameter function that takes a :class:`.Host`.
If it returns a falsey value, the :class:`.Host` will
be :attr:`.IGNORED` and not returned in query plans.
"""
super(HostFilterPolicy, self).__init__()
self._child_policy = child_policy
self._predicate = predicate
def on_up(self, host, *args, **kwargs):
if self.predicate(host):
return self._child_policy.on_up(host, *args, **kwargs)
def on_down(self, host, *args, **kwargs):
if self.predicate(host):
return self._child_policy.on_down(host, *args, **kwargs)
def on_add(self, host, *args, **kwargs):
if self.predicate(host):
return self._child_policy.on_add(host, *args, **kwargs)
def on_remove(self, host, *args, **kwargs):
if self.predicate(host):
return self._child_policy.on_remove(host, *args, **kwargs)
@property
def predicate(self):
"""
A predicate, set on object initialization, that takes a :class:`.Host`
and returns a value. If the value is falsy, the :class:`.Host` is
:class:`~HostDistance.IGNORED`. If the value is truthy,
:class:`.HostFilterPolicy` defers to the child policy to determine the
host's distance.
This is a read-only value set in ``__init__``, implemented as a
``property``.
"""
return self._predicate
def distance(self, host):
"""
Checks if ``predicate(host)``, then returns
:attr:`~HostDistance.IGNORED` if falsey, and defers to the child policy
otherwise.
"""
if self.predicate(host):
return self._child_policy.distance(host)
else:
return HostDistance.IGNORED
def populate(self, cluster, hosts):
self._child_policy.populate(
cluster=cluster,
hosts=[h for h in hosts if self.predicate(h)]
)
def make_query_plan(self, working_keyspace=None, query=None):
"""
Defers to the child policy's
:meth:`.LoadBalancingPolicy.make_query_plan`. Since host changes (up,
down, addition, and removal) have not been propagated to the child
policy, the child policy will only ever return policies for which
:meth:`.predicate(host)` was truthy when that change occurred.
"""
child_qp = self._child_policy.make_query_plan(
working_keyspace=working_keyspace, query=query
)
for host in child_qp:
if self.predicate(host):
yield host
def check_supported(self):
return self._child_policy.check_supported()
class ConvictionPolicy(object): class ConvictionPolicy(object):
""" """
A policy which decides when hosts should be considered down A policy which decides when hosts should be considered down
@ -619,6 +747,7 @@ class WriteType(object):
A lighweight-transaction write, such as "DELETE ... IF EXISTS". A lighweight-transaction write, such as "DELETE ... IF EXISTS".
""" """
WriteType.name_to_value = { WriteType.name_to_value = {
'SIMPLE': WriteType.SIMPLE, 'SIMPLE': WriteType.SIMPLE,
'BATCH': WriteType.BATCH, 'BATCH': WriteType.BATCH,

View File

@ -24,6 +24,14 @@ Load Balancing
.. autoclass:: TokenAwarePolicy .. autoclass:: TokenAwarePolicy
:members: :members:
.. autoclass:: HostFilterPolicy
# we document these methods manually so we can specify a param to predicate
.. automethod:: predicate(host)
.. automethod:: distance
.. automethod:: make_query_plan
Translating Server Node Addresses Translating Server Node Addresses
--------------------------------- ---------------------------------

View File

@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import struct, time, logging, sys, traceback import logging
import struct
import sys
import traceback
from cassandra import ConsistencyLevel, Unavailable, OperationTimedOut, ReadTimeout, ReadFailure, \ from cassandra import ConsistencyLevel, Unavailable, OperationTimedOut, ReadTimeout, ReadFailure, \
WriteTimeout, WriteFailure WriteTimeout, WriteFailure
from cassandra.cluster import Cluster, NoHostAvailable, ExecutionProfile from cassandra.cluster import Cluster, NoHostAvailable
from cassandra.concurrent import execute_concurrent_with_args from cassandra.concurrent import execute_concurrent_with_args
from cassandra.metadata import murmur3 from cassandra.metadata import murmur3
from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy, from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy,
TokenAwarePolicy, WhiteListRoundRobinPolicy) TokenAwarePolicy, WhiteListRoundRobinPolicy,
HostFilterPolicy)
from cassandra.query import SimpleStatement from cassandra.query import SimpleStatement
from tests.integration import use_singledc, use_multidc, remove_cluster, PROTOCOL_VERSION from tests.integration import use_singledc, use_multidc, remove_cluster, PROTOCOL_VERSION
@ -40,7 +44,7 @@ log = logging.getLogger(__name__)
class LoadBalancingPolicyTests(unittest.TestCase): class LoadBalancingPolicyTests(unittest.TestCase):
def setUp(self): def setUp(self):
remove_cluster() # clear ahead of test so it doesn't use one left in unknown state remove_cluster() # clear ahead of test so it doesn't use one left in unknown state
self.coordinator_stats = CoordinatorStats() self.coordinator_stats = CoordinatorStats()
self.prepared = None self.prepared = None
self.probe_cluster = None self.probe_cluster = None
@ -105,7 +109,7 @@ class LoadBalancingPolicyTests(unittest.TestCase):
query_string = 'SELECT * FROM %s.cf WHERE k = ?' % keyspace query_string = 'SELECT * FROM %s.cf WHERE k = ?' % keyspace
if not self.prepared or self.prepared.query_string != query_string: if not self.prepared or self.prepared.query_string != query_string:
self.prepared = session.prepare(query_string) self.prepared = session.prepare(query_string)
self.prepared.consistency_level=consistency_level self.prepared.consistency_level = consistency_level
for i in range(count): for i in range(count):
tries = 0 tries = 0
while True: while True:
@ -508,7 +512,7 @@ class LoadBalancingPolicyTests(unittest.TestCase):
self.coordinator_stats.reset_counts() self.coordinator_stats.reset_counts()
stop(2) stop(2)
self._wait_for_nodes_down([2],cluster) self._wait_for_nodes_down([2], cluster)
self._query(session, keyspace) self._query(session, keyspace)
@ -662,3 +666,37 @@ class LoadBalancingPolicyTests(unittest.TestCase):
pass pass
finally: finally:
cluster.shutdown() cluster.shutdown()
def test_black_list_with_host_filter_policy(self):
use_singledc()
keyspace = 'test_black_list_with_hfp'
ignored_address = (IP_FORMAT % 2)
hfp = HostFilterPolicy(
child_policy=RoundRobinPolicy(),
predicate=lambda host: host.address != ignored_address
)
cluster = Cluster(
(IP_FORMAT % 1,),
load_balancing_policy=hfp,
protocol_version=PROTOCOL_VERSION,
topology_event_refresh_window=0,
status_event_refresh_window=0
)
self.addCleanup(cluster.shutdown)
session = cluster.connect()
self._wait_for_nodes_up([1, 2, 3])
self.assertNotIn(ignored_address, [h.address for h in hfp.make_query_plan()])
create_schema(cluster, session, keyspace)
self._insert(session, keyspace)
self._query(session, keyspace)
self.coordinator_stats.assert_query_count_equals(self, 1, 6)
self.coordinator_stats.assert_query_count_equals(self, 2, 0)
self.coordinator_stats.assert_query_count_equals(self, 3, 6)
# policy should not allow reconnecting to ignored host
force_stop(2)
self._wait_for_nodes_down([2])
self.assertFalse(cluster.metadata._hosts[ignored_address].is_currently_reconnecting())

View File

@ -18,9 +18,10 @@ except ImportError:
import unittest # noqa import unittest # noqa
from itertools import islice, cycle from itertools import islice, cycle
from mock import Mock, patch from mock import Mock, patch, call
from random import randint from random import randint
import six import six
from six.moves._thread import LockType
import sys import sys
import struct import struct
from threading import Thread from threading import Thread
@ -34,7 +35,7 @@ from cassandra.policies import (RoundRobinPolicy, WhiteListRoundRobinPolicy, DCA
RetryPolicy, WriteType, RetryPolicy, WriteType,
DowngradingConsistencyRetryPolicy, ConstantReconnectionPolicy, DowngradingConsistencyRetryPolicy, ConstantReconnectionPolicy,
LoadBalancingPolicy, ConvictionPolicy, ReconnectionPolicy, FallthroughRetryPolicy, LoadBalancingPolicy, ConvictionPolicy, ReconnectionPolicy, FallthroughRetryPolicy,
IdentityTranslator, EC2MultiRegionTranslator) IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy)
from cassandra.pool import Host from cassandra.pool import Host
from cassandra.query import Statement from cassandra.query import Statement
@ -421,7 +422,6 @@ class DCAwareRoundRobinPolicyTest(unittest.TestCase):
policy.on_up(hosts[2]) policy.on_up(hosts[2])
policy.on_up(hosts[3]) policy.on_up(hosts[3])
another_host = Host(5, SimpleConvictionPolicy) another_host = Host(5, SimpleConvictionPolicy)
another_host.set_location_info("dc3", "rack1") another_host.set_location_info("dc3", "rack1")
new_host.set_location_info("dc3", "rack1") new_host.set_location_info("dc3", "rack1")
@ -755,7 +755,7 @@ class TokenAwarePolicyTest(unittest.TestCase):
@test_category policy @test_category policy
""" """
self._assert_shuffle(keyspace=None, routing_key='routing_key') self._assert_shuffle(keyspace=None, routing_key='routing_key')
def test_no_shuffle_if_given_no_routing_key(self): def test_no_shuffle_if_given_no_routing_key(self):
""" """
Test to validate the hosts are not shuffled when no routing_key is provided Test to validate the hosts are not shuffled when no routing_key is provided
@ -766,7 +766,7 @@ class TokenAwarePolicyTest(unittest.TestCase):
@test_category policy @test_category policy
""" """
self._assert_shuffle(keyspace='keyspace', routing_key=None) self._assert_shuffle(keyspace='keyspace', routing_key=None)
@patch('cassandra.policies.shuffle') @patch('cassandra.policies.shuffle')
def _assert_shuffle(self, patched_shuffle, keyspace, routing_key): def _assert_shuffle(self, patched_shuffle, keyspace, routing_key):
hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)] hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)]
@ -884,7 +884,7 @@ class ExponentialReconnectionPolicyTest(unittest.TestCase):
self.assertRaises(ValueError, ExponentialReconnectionPolicy, -1, 0) self.assertRaises(ValueError, ExponentialReconnectionPolicy, -1, 0)
self.assertRaises(ValueError, ExponentialReconnectionPolicy, 0, -1) self.assertRaises(ValueError, ExponentialReconnectionPolicy, 0, -1)
self.assertRaises(ValueError, ExponentialReconnectionPolicy, 9000, 1) self.assertRaises(ValueError, ExponentialReconnectionPolicy, 9000, 1)
self.assertRaises(ValueError, ExponentialReconnectionPolicy, 1, 2,-1) self.assertRaises(ValueError, ExponentialReconnectionPolicy, 1, 2, -1)
def test_schedule_no_max(self): def test_schedule_no_max(self):
base_delay = 2.0 base_delay = 2.0
@ -1232,11 +1232,27 @@ class WhiteListRoundRobinPolicyTest(unittest.TestCase):
self.assertEqual(policy.distance(host), HostDistance.LOCAL) self.assertEqual(policy.distance(host), HostDistance.LOCAL)
def test_deprecated(self):
import warnings
warnings.resetwarnings() # in case we've instantiated one before
# set up warning filters to allow all, set up restore when this test is done
filters_backup, warnings.filters = warnings.filters, []
self.addCleanup(setattr, warnings, 'filters', filters_backup)
with warnings.catch_warnings(record=True) as caught_warnings:
WhiteListRoundRobinPolicy([])
self.assertEqual(len(caught_warnings), 1)
warning_message = caught_warnings[-1]
self.assertEqual(warning_message.category, DeprecationWarning)
self.assertIn('4.0', warning_message.message.args[0])
class AddressTranslatorTest(unittest.TestCase): class AddressTranslatorTest(unittest.TestCase):
def test_identity_translator(self): def test_identity_translator(self):
it = IdentityTranslator() IdentityTranslator()
addr = '127.0.0.1'
@patch('socket.getfqdn', return_value='localhost') @patch('socket.getfqdn', return_value='localhost')
def test_ec2_multi_region_translator(self, *_): def test_ec2_multi_region_translator(self, *_):
@ -1245,3 +1261,181 @@ class AddressTranslatorTest(unittest.TestCase):
translated = ec2t.translate(addr) translated = ec2t.translate(addr)
self.assertIsNot(translated, addr) # verifies that the resolver path is followed self.assertIsNot(translated, addr) # verifies that the resolver path is followed
self.assertEqual(translated, addr) # and that it resolves to the same address self.assertEqual(translated, addr) # and that it resolves to the same address
class HostFilterPolicyInitTest(unittest.TestCase):
def setUp(self):
self.child_policy, self.predicate = (Mock(name='child_policy'),
Mock(name='predicate'))
def _check_init(self, hfp):
self.assertIs(hfp._child_policy, self.child_policy)
self.assertIsInstance(hfp._hosts_lock, LockType)
# we can't use a simple assertIs because we wrap the function
arg0, arg1 = Mock(name='arg0'), Mock(name='arg1')
hfp.predicate(arg0)
hfp.predicate(arg1)
self.predicate.assert_has_calls([call(arg0), call(arg1)])
def test_init_arg_order(self):
self._check_init(HostFilterPolicy(self.child_policy, self.predicate))
def test_init_kwargs(self):
self._check_init(HostFilterPolicy(
predicate=self.predicate, child_policy=self.child_policy
))
def test_immutable_predicate(self):
expected_message_regex = "can't set attribute"
hfp = HostFilterPolicy(child_policy=Mock(name='child_policy'),
predicate=Mock(name='predicate'))
with self.assertRaisesRegexp(AttributeError, expected_message_regex):
hfp.predicate = object()
class HostFilterPolicyDeferralTest(unittest.TestCase):
def setUp(self):
self.passthrough_hfp = HostFilterPolicy(
child_policy=Mock(name='child_policy'),
predicate=Mock(name='passthrough_predicate',
return_value=True)
)
self.filterall_hfp = HostFilterPolicy(
child_policy=Mock(name='child_policy'),
predicate=Mock(name='filterall_predicate',
return_value=False)
)
def _check_host_triggered_method(self, policy, name):
arg, kwarg = Mock(name='arg'), Mock(name='kwarg')
expect_deferral = policy is self.passthrough_hfp
method, child_policy_method = (getattr(policy, name),
getattr(policy._child_policy, name))
result = method(arg, kw=kwarg)
if expect_deferral:
# method calls the child policy's method...
child_policy_method.assert_called_once_with(arg, kw=kwarg)
# and returns its return value
self.assertIs(result, child_policy_method.return_value)
else:
child_policy_method.assert_not_called()
def test_defer_on_up_to_child_policy(self):
self._check_host_triggered_method(self.passthrough_hfp, 'on_up')
def test_defer_on_down_to_child_policy(self):
self._check_host_triggered_method(self.passthrough_hfp, 'on_down')
def test_defer_on_add_to_child_policy(self):
self._check_host_triggered_method(self.passthrough_hfp, 'on_add')
def test_defer_on_remove_to_child_policy(self):
self._check_host_triggered_method(self.passthrough_hfp, 'on_remove')
def test_filtered_host_on_up_doesnt_call_child_policy(self):
self._check_host_triggered_method(self.filterall_hfp, 'on_up')
def test_filtered_host_on_down_doesnt_call_child_policy(self):
self._check_host_triggered_method(self.filterall_hfp, 'on_down')
def test_filtered_host_on_add_doesnt_call_child_policy(self):
self._check_host_triggered_method(self.filterall_hfp, 'on_add')
def test_filtered_host_on_remove_doesnt_call_child_policy(self):
self._check_host_triggered_method(self.filterall_hfp, 'on_remove')
def _check_check_supported_deferral(self, policy):
policy.check_supported()
policy._child_policy.check_supported.assert_called_once()
def test_check_supported_defers_to_child(self):
self._check_check_supported_deferral(self.passthrough_hfp)
def test_check_supported_defers_to_child_when_predicate_filtered(self):
self._check_check_supported_deferral(self.filterall_hfp)
class HostFilterPolicyDistanceTest(unittest.TestCase):
def setUp(self):
self.hfp = HostFilterPolicy(
child_policy=Mock(name='child_policy', distance=Mock(name='distance')),
predicate=lambda host: host.address == 'acceptme'
)
self.ignored_host = Host(inet_address='ignoreme', conviction_policy_factory=Mock())
self.accepted_host = Host(inet_address='acceptme', conviction_policy_factory=Mock())
def test_ignored_with_filter(self):
self.assertEqual(self.hfp.distance(self.ignored_host),
HostDistance.IGNORED)
self.assertNotEqual(self.hfp.distance(self.accepted_host),
HostDistance.IGNORED)
def test_accepted_filter_defers_to_child_policy(self):
self.hfp._child_policy.distance.side_effect = distances = Mock(), Mock()
# getting the distance for an ignored host shouldn't affect subsequent results
self.hfp.distance(self.ignored_host)
# first call of _child_policy with count() side effect
self.assertEqual(self.hfp.distance(self.accepted_host), distances[0])
# second call of _child_policy with count() side effect
self.assertEqual(self.hfp.distance(self.accepted_host), distances[1])
class HostFilterPolicyPopulateTest(unittest.TestCase):
def test_populate_deferred_to_child(self):
hfp = HostFilterPolicy(
child_policy=Mock(name='child_policy'),
predicate=lambda host: True
)
mock_cluster, hosts = (Mock(name='cluster'),
['host1', 'host2', 'host3'])
hfp.populate(mock_cluster, hosts)
hfp._child_policy.populate.assert_called_once_with(
cluster=mock_cluster,
hosts=hosts
)
def test_child_not_populated_with_filtered_hosts(self):
hfp = HostFilterPolicy(
child_policy=Mock(name='child_policy'),
predicate=lambda host: 'acceptme' in host
)
mock_cluster, hosts = (Mock(name='cluster'),
['acceptme0', 'ignoreme0', 'ignoreme1', 'acceptme1'])
hfp.populate(mock_cluster, hosts)
hfp._child_policy.populate.assert_called_once()
self.assertEqual(
hfp._child_policy.populate.call_args[1]['hosts'],
['acceptme0', 'acceptme1']
)
class HostFilterPolicyQueryPlanTest(unittest.TestCase):
def test_query_plan_deferred_to_child(self):
child_policy = Mock(
name='child_policy',
make_query_plan=Mock(
return_value=[object(), object(), object()]
)
)
hfp = HostFilterPolicy(
child_policy=child_policy,
predicate=lambda host: True
)
working_keyspace, query = (Mock(name='working_keyspace'),
Mock(name='query'))
qp = list(hfp.make_query_plan(working_keyspace=working_keyspace,
query=query))
hfp._child_policy.make_query_plan.assert_called_once_with(
working_keyspace=working_keyspace,
query=query
)
self.assertEqual(qp, hfp._child_policy.make_query_plan.return_value)