Merge pull request #778 from datastax/python-host-filter-policy_PYTHON-761

Python host filter policy PYTHON-761
This commit is contained in:
Jim Witschey 2017-07-12 16:15:43 -04:00 committed by GitHub
commit 852d39e6c0
5 changed files with 388 additions and 14 deletions

View File

@ -4,6 +4,7 @@
Features
--------
* Add idle_heartbeat_timeout cluster option to tune how long to wait for heartbeat responses. (PYTHON-762)
* Add HostFilterPolicy (PYTHON-761)
Bug Fixes
---------
@ -20,6 +21,10 @@ Other
* Bump Cython dependency version to 0.25.2 (PYTHON-754)
* Fix DeprecationWarning when using lz4 (PYTHON-769)
Other
-----
* Deprecate WhiteListRoundRobinPolicy (PYTHON-759)
3.10.0
======
May 24, 2017

View File

@ -17,6 +17,7 @@ import logging
from random import randint, shuffle
from threading import Lock
import socket
from warnings import warn
from cassandra import ConsistencyLevel, OperationTimedOut
@ -396,6 +397,10 @@ class TokenAwarePolicy(LoadBalancingPolicy):
class WhiteListRoundRobinPolicy(RoundRobinPolicy):
"""
|wlrrp| **is deprecated. It will be removed in 4.0.** It can effectively be
reimplemented using :class:`.HostFilterPolicy`. For more information, see
PYTHON-758_.
A subclass of :class:`.RoundRobinPolicy` which evenly
distributes queries across all nodes in the cluster,
regardless of what datacenter the nodes may be in, but
@ -405,12 +410,25 @@ class WhiteListRoundRobinPolicy(RoundRobinPolicy):
https://datastax-oss.atlassian.net/browse/JAVA-145
Where connection errors occur when connection
attempts are made to private IP addresses remotely
.. |wlrrp| raw:: html
<b><code>WhiteListRoundRobinPolicy</code></b>
.. _PYTHON-758: https://datastax-oss.atlassian.net/browse/PYTHON-758
"""
def __init__(self, hosts):
"""
The `hosts` parameter should be a sequence of hosts to permit
connections to.
"""
msg = ('WhiteListRoundRobinPolicy is deprecated. '
'It will be removed in 4.0. '
'It can effectively be reimplemented using HostFilterPolicy.')
warn(msg, DeprecationWarning)
# DeprecationWarnings are silent by default so we also log the message
log.warning(msg)
self._allowed_hosts = hosts
self._allowed_hosts_resolved = [endpoint[4][0] for a in self._allowed_hosts
@ -441,6 +459,116 @@ class WhiteListRoundRobinPolicy(RoundRobinPolicy):
RoundRobinPolicy.on_add(self, host)
class HostFilterPolicy(LoadBalancingPolicy):
"""
A :class:`.LoadBalancingPolicy` subclass configured with a child policy,
and a single-argument predicate. This policy defers to the child policy for
hosts where ``predicate(host)`` is truthy. Hosts for which
``predicate(host)`` is falsey will be considered :attr:`.IGNORED`, and will
not be used in a query plan.
This can be used in the cases where you need a whitelist or blacklist
policy, e.g. to prepare for decommissioning nodes or for testing:
.. code-block:: python
def address_is_ignored(host):
return host.address in [ignored_address0, ignored_address1]
blacklist_filter_policy = HostFilterPolicy(
child_policy=RoundRobinPolicy(),
predicate=address_is_ignored
)
cluster = Cluster(
primary_host,
load_balancing_policy=blacklist_filter_policy,
)
Please note that whitelist and blacklist policies are not recommended for
general, day-to-day use. You probably want something like
:class:`.DCAwareRoundRobinPolicy`, which prefers a local DC but has
fallbacks, over a brute-force method like whitelisting or blacklisting.
"""
def __init__(self, child_policy, predicate):
"""
:param child_policy: an instantiated :class:`.LoadBalancingPolicy`
that this one will defer to.
:param predicate: a one-parameter function that takes a :class:`.Host`.
If it returns a falsey value, the :class:`.Host` will
be :attr:`.IGNORED` and not returned in query plans.
"""
super(HostFilterPolicy, self).__init__()
self._child_policy = child_policy
self._predicate = predicate
def on_up(self, host, *args, **kwargs):
if self.predicate(host):
return self._child_policy.on_up(host, *args, **kwargs)
def on_down(self, host, *args, **kwargs):
if self.predicate(host):
return self._child_policy.on_down(host, *args, **kwargs)
def on_add(self, host, *args, **kwargs):
if self.predicate(host):
return self._child_policy.on_add(host, *args, **kwargs)
def on_remove(self, host, *args, **kwargs):
if self.predicate(host):
return self._child_policy.on_remove(host, *args, **kwargs)
@property
def predicate(self):
"""
A predicate, set on object initialization, that takes a :class:`.Host`
and returns a value. If the value is falsy, the :class:`.Host` is
:class:`~HostDistance.IGNORED`. If the value is truthy,
:class:`.HostFilterPolicy` defers to the child policy to determine the
host's distance.
This is a read-only value set in ``__init__``, implemented as a
``property``.
"""
return self._predicate
def distance(self, host):
"""
Checks if ``predicate(host)``, then returns
:attr:`~HostDistance.IGNORED` if falsey, and defers to the child policy
otherwise.
"""
if self.predicate(host):
return self._child_policy.distance(host)
else:
return HostDistance.IGNORED
def populate(self, cluster, hosts):
self._child_policy.populate(
cluster=cluster,
hosts=[h for h in hosts if self.predicate(h)]
)
def make_query_plan(self, working_keyspace=None, query=None):
"""
Defers to the child policy's
:meth:`.LoadBalancingPolicy.make_query_plan`. Since host changes (up,
down, addition, and removal) have not been propagated to the child
policy, the child policy will only ever return policies for which
:meth:`.predicate(host)` was truthy when that change occurred.
"""
child_qp = self._child_policy.make_query_plan(
working_keyspace=working_keyspace, query=query
)
for host in child_qp:
if self.predicate(host):
yield host
def check_supported(self):
return self._child_policy.check_supported()
class ConvictionPolicy(object):
"""
A policy which decides when hosts should be considered down
@ -619,6 +747,7 @@ class WriteType(object):
A lighweight-transaction write, such as "DELETE ... IF EXISTS".
"""
WriteType.name_to_value = {
'SIMPLE': WriteType.SIMPLE,
'BATCH': WriteType.BATCH,

View File

@ -24,6 +24,14 @@ Load Balancing
.. autoclass:: TokenAwarePolicy
:members:
.. autoclass:: HostFilterPolicy
# we document these methods manually so we can specify a param to predicate
.. automethod:: predicate(host)
.. automethod:: distance
.. automethod:: make_query_plan
Translating Server Node Addresses
---------------------------------

View File

@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import struct, time, logging, sys, traceback
import logging
import struct
import sys
import traceback
from cassandra import ConsistencyLevel, Unavailable, OperationTimedOut, ReadTimeout, ReadFailure, \
WriteTimeout, WriteFailure
from cassandra.cluster import Cluster, NoHostAvailable, ExecutionProfile
from cassandra.cluster import Cluster, NoHostAvailable
from cassandra.concurrent import execute_concurrent_with_args
from cassandra.metadata import murmur3
from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy,
TokenAwarePolicy, WhiteListRoundRobinPolicy)
TokenAwarePolicy, WhiteListRoundRobinPolicy,
HostFilterPolicy)
from cassandra.query import SimpleStatement
from tests.integration import use_singledc, use_multidc, remove_cluster, PROTOCOL_VERSION
@ -40,7 +44,7 @@ log = logging.getLogger(__name__)
class LoadBalancingPolicyTests(unittest.TestCase):
def setUp(self):
remove_cluster() # clear ahead of test so it doesn't use one left in unknown state
remove_cluster() # clear ahead of test so it doesn't use one left in unknown state
self.coordinator_stats = CoordinatorStats()
self.prepared = None
self.probe_cluster = None
@ -105,7 +109,7 @@ class LoadBalancingPolicyTests(unittest.TestCase):
query_string = 'SELECT * FROM %s.cf WHERE k = ?' % keyspace
if not self.prepared or self.prepared.query_string != query_string:
self.prepared = session.prepare(query_string)
self.prepared.consistency_level=consistency_level
self.prepared.consistency_level = consistency_level
for i in range(count):
tries = 0
while True:
@ -508,7 +512,7 @@ class LoadBalancingPolicyTests(unittest.TestCase):
self.coordinator_stats.reset_counts()
stop(2)
self._wait_for_nodes_down([2],cluster)
self._wait_for_nodes_down([2], cluster)
self._query(session, keyspace)
@ -662,3 +666,37 @@ class LoadBalancingPolicyTests(unittest.TestCase):
pass
finally:
cluster.shutdown()
def test_black_list_with_host_filter_policy(self):
use_singledc()
keyspace = 'test_black_list_with_hfp'
ignored_address = (IP_FORMAT % 2)
hfp = HostFilterPolicy(
child_policy=RoundRobinPolicy(),
predicate=lambda host: host.address != ignored_address
)
cluster = Cluster(
(IP_FORMAT % 1,),
load_balancing_policy=hfp,
protocol_version=PROTOCOL_VERSION,
topology_event_refresh_window=0,
status_event_refresh_window=0
)
self.addCleanup(cluster.shutdown)
session = cluster.connect()
self._wait_for_nodes_up([1, 2, 3])
self.assertNotIn(ignored_address, [h.address for h in hfp.make_query_plan()])
create_schema(cluster, session, keyspace)
self._insert(session, keyspace)
self._query(session, keyspace)
self.coordinator_stats.assert_query_count_equals(self, 1, 6)
self.coordinator_stats.assert_query_count_equals(self, 2, 0)
self.coordinator_stats.assert_query_count_equals(self, 3, 6)
# policy should not allow reconnecting to ignored host
force_stop(2)
self._wait_for_nodes_down([2])
self.assertFalse(cluster.metadata._hosts[ignored_address].is_currently_reconnecting())

View File

@ -18,9 +18,10 @@ except ImportError:
import unittest # noqa
from itertools import islice, cycle
from mock import Mock, patch
from mock import Mock, patch, call
from random import randint
import six
from six.moves._thread import LockType
import sys
import struct
from threading import Thread
@ -34,7 +35,7 @@ from cassandra.policies import (RoundRobinPolicy, WhiteListRoundRobinPolicy, DCA
RetryPolicy, WriteType,
DowngradingConsistencyRetryPolicy, ConstantReconnectionPolicy,
LoadBalancingPolicy, ConvictionPolicy, ReconnectionPolicy, FallthroughRetryPolicy,
IdentityTranslator, EC2MultiRegionTranslator)
IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy)
from cassandra.pool import Host
from cassandra.query import Statement
@ -421,7 +422,6 @@ class DCAwareRoundRobinPolicyTest(unittest.TestCase):
policy.on_up(hosts[2])
policy.on_up(hosts[3])
another_host = Host(5, SimpleConvictionPolicy)
another_host.set_location_info("dc3", "rack1")
new_host.set_location_info("dc3", "rack1")
@ -755,7 +755,7 @@ class TokenAwarePolicyTest(unittest.TestCase):
@test_category policy
"""
self._assert_shuffle(keyspace=None, routing_key='routing_key')
def test_no_shuffle_if_given_no_routing_key(self):
"""
Test to validate the hosts are not shuffled when no routing_key is provided
@ -766,7 +766,7 @@ class TokenAwarePolicyTest(unittest.TestCase):
@test_category policy
"""
self._assert_shuffle(keyspace='keyspace', routing_key=None)
@patch('cassandra.policies.shuffle')
def _assert_shuffle(self, patched_shuffle, keyspace, routing_key):
hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)]
@ -884,7 +884,7 @@ class ExponentialReconnectionPolicyTest(unittest.TestCase):
self.assertRaises(ValueError, ExponentialReconnectionPolicy, -1, 0)
self.assertRaises(ValueError, ExponentialReconnectionPolicy, 0, -1)
self.assertRaises(ValueError, ExponentialReconnectionPolicy, 9000, 1)
self.assertRaises(ValueError, ExponentialReconnectionPolicy, 1, 2,-1)
self.assertRaises(ValueError, ExponentialReconnectionPolicy, 1, 2, -1)
def test_schedule_no_max(self):
base_delay = 2.0
@ -1232,11 +1232,27 @@ class WhiteListRoundRobinPolicyTest(unittest.TestCase):
self.assertEqual(policy.distance(host), HostDistance.LOCAL)
def test_deprecated(self):
import warnings
warnings.resetwarnings() # in case we've instantiated one before
# set up warning filters to allow all, set up restore when this test is done
filters_backup, warnings.filters = warnings.filters, []
self.addCleanup(setattr, warnings, 'filters', filters_backup)
with warnings.catch_warnings(record=True) as caught_warnings:
WhiteListRoundRobinPolicy([])
self.assertEqual(len(caught_warnings), 1)
warning_message = caught_warnings[-1]
self.assertEqual(warning_message.category, DeprecationWarning)
self.assertIn('4.0', warning_message.message.args[0])
class AddressTranslatorTest(unittest.TestCase):
def test_identity_translator(self):
it = IdentityTranslator()
addr = '127.0.0.1'
IdentityTranslator()
@patch('socket.getfqdn', return_value='localhost')
def test_ec2_multi_region_translator(self, *_):
@ -1245,3 +1261,181 @@ class AddressTranslatorTest(unittest.TestCase):
translated = ec2t.translate(addr)
self.assertIsNot(translated, addr) # verifies that the resolver path is followed
self.assertEqual(translated, addr) # and that it resolves to the same address
class HostFilterPolicyInitTest(unittest.TestCase):
def setUp(self):
self.child_policy, self.predicate = (Mock(name='child_policy'),
Mock(name='predicate'))
def _check_init(self, hfp):
self.assertIs(hfp._child_policy, self.child_policy)
self.assertIsInstance(hfp._hosts_lock, LockType)
# we can't use a simple assertIs because we wrap the function
arg0, arg1 = Mock(name='arg0'), Mock(name='arg1')
hfp.predicate(arg0)
hfp.predicate(arg1)
self.predicate.assert_has_calls([call(arg0), call(arg1)])
def test_init_arg_order(self):
self._check_init(HostFilterPolicy(self.child_policy, self.predicate))
def test_init_kwargs(self):
self._check_init(HostFilterPolicy(
predicate=self.predicate, child_policy=self.child_policy
))
def test_immutable_predicate(self):
expected_message_regex = "can't set attribute"
hfp = HostFilterPolicy(child_policy=Mock(name='child_policy'),
predicate=Mock(name='predicate'))
with self.assertRaisesRegexp(AttributeError, expected_message_regex):
hfp.predicate = object()
class HostFilterPolicyDeferralTest(unittest.TestCase):
def setUp(self):
self.passthrough_hfp = HostFilterPolicy(
child_policy=Mock(name='child_policy'),
predicate=Mock(name='passthrough_predicate',
return_value=True)
)
self.filterall_hfp = HostFilterPolicy(
child_policy=Mock(name='child_policy'),
predicate=Mock(name='filterall_predicate',
return_value=False)
)
def _check_host_triggered_method(self, policy, name):
arg, kwarg = Mock(name='arg'), Mock(name='kwarg')
expect_deferral = policy is self.passthrough_hfp
method, child_policy_method = (getattr(policy, name),
getattr(policy._child_policy, name))
result = method(arg, kw=kwarg)
if expect_deferral:
# method calls the child policy's method...
child_policy_method.assert_called_once_with(arg, kw=kwarg)
# and returns its return value
self.assertIs(result, child_policy_method.return_value)
else:
child_policy_method.assert_not_called()
def test_defer_on_up_to_child_policy(self):
self._check_host_triggered_method(self.passthrough_hfp, 'on_up')
def test_defer_on_down_to_child_policy(self):
self._check_host_triggered_method(self.passthrough_hfp, 'on_down')
def test_defer_on_add_to_child_policy(self):
self._check_host_triggered_method(self.passthrough_hfp, 'on_add')
def test_defer_on_remove_to_child_policy(self):
self._check_host_triggered_method(self.passthrough_hfp, 'on_remove')
def test_filtered_host_on_up_doesnt_call_child_policy(self):
self._check_host_triggered_method(self.filterall_hfp, 'on_up')
def test_filtered_host_on_down_doesnt_call_child_policy(self):
self._check_host_triggered_method(self.filterall_hfp, 'on_down')
def test_filtered_host_on_add_doesnt_call_child_policy(self):
self._check_host_triggered_method(self.filterall_hfp, 'on_add')
def test_filtered_host_on_remove_doesnt_call_child_policy(self):
self._check_host_triggered_method(self.filterall_hfp, 'on_remove')
def _check_check_supported_deferral(self, policy):
policy.check_supported()
policy._child_policy.check_supported.assert_called_once()
def test_check_supported_defers_to_child(self):
self._check_check_supported_deferral(self.passthrough_hfp)
def test_check_supported_defers_to_child_when_predicate_filtered(self):
self._check_check_supported_deferral(self.filterall_hfp)
class HostFilterPolicyDistanceTest(unittest.TestCase):
def setUp(self):
self.hfp = HostFilterPolicy(
child_policy=Mock(name='child_policy', distance=Mock(name='distance')),
predicate=lambda host: host.address == 'acceptme'
)
self.ignored_host = Host(inet_address='ignoreme', conviction_policy_factory=Mock())
self.accepted_host = Host(inet_address='acceptme', conviction_policy_factory=Mock())
def test_ignored_with_filter(self):
self.assertEqual(self.hfp.distance(self.ignored_host),
HostDistance.IGNORED)
self.assertNotEqual(self.hfp.distance(self.accepted_host),
HostDistance.IGNORED)
def test_accepted_filter_defers_to_child_policy(self):
self.hfp._child_policy.distance.side_effect = distances = Mock(), Mock()
# getting the distance for an ignored host shouldn't affect subsequent results
self.hfp.distance(self.ignored_host)
# first call of _child_policy with count() side effect
self.assertEqual(self.hfp.distance(self.accepted_host), distances[0])
# second call of _child_policy with count() side effect
self.assertEqual(self.hfp.distance(self.accepted_host), distances[1])
class HostFilterPolicyPopulateTest(unittest.TestCase):
def test_populate_deferred_to_child(self):
hfp = HostFilterPolicy(
child_policy=Mock(name='child_policy'),
predicate=lambda host: True
)
mock_cluster, hosts = (Mock(name='cluster'),
['host1', 'host2', 'host3'])
hfp.populate(mock_cluster, hosts)
hfp._child_policy.populate.assert_called_once_with(
cluster=mock_cluster,
hosts=hosts
)
def test_child_not_populated_with_filtered_hosts(self):
hfp = HostFilterPolicy(
child_policy=Mock(name='child_policy'),
predicate=lambda host: 'acceptme' in host
)
mock_cluster, hosts = (Mock(name='cluster'),
['acceptme0', 'ignoreme0', 'ignoreme1', 'acceptme1'])
hfp.populate(mock_cluster, hosts)
hfp._child_policy.populate.assert_called_once()
self.assertEqual(
hfp._child_policy.populate.call_args[1]['hosts'],
['acceptme0', 'acceptme1']
)
class HostFilterPolicyQueryPlanTest(unittest.TestCase):
def test_query_plan_deferred_to_child(self):
child_policy = Mock(
name='child_policy',
make_query_plan=Mock(
return_value=[object(), object(), object()]
)
)
hfp = HostFilterPolicy(
child_policy=child_policy,
predicate=lambda host: True
)
working_keyspace, query = (Mock(name='working_keyspace'),
Mock(name='query'))
qp = list(hfp.make_query_plan(working_keyspace=working_keyspace,
query=query))
hfp._child_policy.make_query_plan.assert_called_once_with(
working_keyspace=working_keyspace,
query=query
)
self.assertEqual(qp, hfp._child_policy.make_query_plan.return_value)