Merge "Make msgpack registries copyable (and add __contains__)"

This commit is contained in:
Jenkins 2016-03-24 02:51:19 +00:00 committed by Gerrit Code Review
commit 8cb124e149
2 changed files with 346 additions and 26 deletions

View File

@ -41,6 +41,36 @@ import six.moves.xmlrpc_client as xmlrpclib
netaddr = importutils.try_import("netaddr")
class Interval(object):
"""Small and/or simple immutable integer/float interval class.
Interval checking is **inclusive** of the min/max boundaries.
"""
def __init__(self, min_value, max_value):
if min_value > max_value:
raise ValueError("Minimum value %s must be less than"
" or equal to maximum value %s" % (min_value,
max_value))
self._min_value = min_value
self._max_value = max_value
@property
def min_value(self):
return self._min_value
@property
def max_value(self):
return self._max_value
def __contains__(self, value):
return value >= self.min_value and value <= self.max_value
def __repr__(self):
return 'Interval(%s, %s)' % (self._min_value, self._max_value)
# Expose these so that users don't have to import msgpack to gain these.
PackException = msgpack.PackException
@ -61,48 +91,115 @@ class HandlerRegistry(object):
.. versionadded:: 1.5
"""
# Applications can assign 0 to 127 to store
# application-specific type information...
reserved_extension_range = Interval(0, 32)
"""
These ranges are **always** reserved for use by ``oslo.serialization`` and
its own add-ons extensions (these extensions are meant to be generally
applicable to all of python).
"""
non_reserved_extension_range = Interval(33, 127)
"""
These ranges are **always** reserved for use by applications building
their own type specific handlers (the meaning of extensions in this range
will typically vary depending on application).
"""
min_value = 0
"""
Applications can assign 0 to 127 to store application (or library)
specific type handlers; see above ranges for what is reserved by this
library and what is not.
"""
max_value = 127
"""
Applications can assign 0 to 127 to store application (or library)
specific type handlers; see above ranges for what is reserved by this
library and what is not.
"""
def __init__(self):
self._handlers = {}
self._num_handlers = 0
self.frozen = False
def __iter__(self):
return six.itervalues(self._handlers)
"""Iterates over **all** registered handlers."""
for handlers in six.itervalues(self._handlers):
for h in handlers:
yield h
def register(self, handler):
def register(self, handler, reserved=False, override=False):
"""Register a extension handler to handle its associated type."""
if self.frozen:
raise ValueError("Frozen handler registry can't be modified")
if reserved:
ok_interval = self.reserved_extension_range
else:
ok_interval = self.non_reserved_extension_range
ident = handler.identity
if ident < self.min_value:
if ident < ok_interval.min_value:
raise ValueError("Handler '%s' identity must be greater"
" or equal to %s" % (handler, self.min_value))
if ident > self.max_value:
" or equal to %s" % (handler,
ok_interval.min_value))
if ident > ok_interval.max_value:
raise ValueError("Handler '%s' identity must be less than"
" or equal to %s" % (handler, self.max_value))
if ident in self._handlers:
raise ValueError("Already registered handler with"
" or equal to %s" % (handler,
ok_interval.max_value))
if ident in self._handlers and override:
existing_handlers = self._handlers[ident]
# Insert at the front so that overrides get selected before
# whatever existed before the override...
existing_handlers.insert(0, handler)
self._num_handlers += 1
elif ident in self._handlers and not override:
raise ValueError("Already registered handler(s) with"
" identity %s: %s" % (ident,
self._handlers[ident]))
else:
self._handlers[ident] = handler
self._handlers[ident] = [handler]
self._num_handlers += 1
def __len__(self):
return len(self._handlers)
"""Return how many extension handlers are registered."""
return self._num_handlers
def __contains__(self, identity):
"""Return if any handler exists for the given identity (number)."""
return identity in self._handlers
def copy(self, unfreeze=False):
"""Deep copy the given registry (and its handlers)."""
c = type(self)()
for ident, handlers in six.iteritems(self._handlers):
cloned_handlers = []
for h in handlers:
if hasattr(h, 'copy'):
h = h.copy(c)
cloned_handlers.append(h)
c._handlers[ident] = cloned_handlers
c._num_handlers += len(cloned_handlers)
if not unfreeze and self.frozen:
c.frozen = True
return c
def get(self, identity):
"""Get the handle for the given numeric identity (or none)."""
return self._handlers.get(identity, None)
"""Get the handler for the given numeric identity (or none)."""
maybe_handlers = self._handlers.get(identity)
if maybe_handlers:
# Prefer the first (if there are many) as this is how we
# override built-in extensions (for those that wish to do this).
return maybe_handlers[0]
else:
return None
def match(self, obj):
"""Match the registries handlers to the given object (or none)."""
for handler in six.itervalues(self._handlers):
if isinstance(obj, handler.handles):
return handler
for possible_handlers in six.itervalues(self._handlers):
for h in possible_handlers:
if isinstance(obj, h.handles):
return h
return None
@ -126,6 +223,9 @@ class DateTimeHandler(object):
def __init__(self, registry):
self._registry = registry
def copy(self, registry):
return type(self)(registry)
def serialize(self, dt):
dct = {
u'day': dt.day,
@ -222,6 +322,9 @@ class SetHandler(object):
def __init__(self, registry):
self._registry = registry
def copy(self, registry):
return type(self)(registry)
def serialize(self, obj):
return dumps(list(obj), registry=self._registry)
@ -241,6 +344,9 @@ class XMLRPCDateTimeHandler(object):
def __init__(self, registry):
self._handler = DateTimeHandler(registry)
def copy(self, registry):
return type(self)(registry)
def serialize(self, obj):
dt = datetime.datetime(*tuple(obj.timetuple())[:6])
return self._handler.serialize(dt)
@ -257,6 +363,9 @@ class DateHandler(object):
def __init__(self, registry):
self._registry = registry
def copy(self, registry):
return type(self)(registry)
def serialize(self, d):
dct = {
u'year': d.year,
@ -286,7 +395,7 @@ def _serializer(registry, obj):
def _unserializer(registry, code, data):
handler = registry.get(code)
if handler is None:
if not handler:
return msgpack.ExtType(code, data)
else:
return handler.deserialize(data)
@ -294,15 +403,15 @@ def _unserializer(registry, code, data):
def _create_default_registry():
registry = HandlerRegistry()
registry.register(DateTimeHandler(registry))
registry.register(DateHandler(registry))
registry.register(UUIDHandler())
registry.register(CountHandler())
registry.register(SetHandler(registry))
registry.register(FrozenSetHandler(registry))
registry.register(DateTimeHandler(registry), reserved=True)
registry.register(DateHandler(registry), reserved=True)
registry.register(UUIDHandler(), reserved=True)
registry.register(CountHandler(), reserved=True)
registry.register(SetHandler(registry), reserved=True)
registry.register(FrozenSetHandler(registry), reserved=True)
if netaddr is not None:
registry.register(NetAddrIPHandler())
registry.register(XMLRPCDateTimeHandler(registry))
registry.register(NetAddrIPHandler(), reserved=True)
registry.register(XMLRPCDateTimeHandler(registry), reserved=True)
registry.frozen = True
return registry

View File

@ -0,0 +1,211 @@
# Copyright (C) 2015 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import datetime
import itertools
import uuid
import netaddr
from oslotest import base as test_base
from pytz import timezone
import six
import six.moves.xmlrpc_client as xmlrpclib
from oslo_serialization import msgpackutils
_TZ_FMT = '%Y-%m-%d %H:%M:%S %Z%z'
class Color(object):
def __init__(self, r, g, b):
self.r = r
self.g = g
self.b = b
class ColorHandler(object):
handles = (Color,)
identity = (
msgpackutils.HandlerRegistry.non_reserved_extension_range.min_value + 1
)
@staticmethod
def serialize(obj):
blob = '%s, %s, %s' % (obj.r, obj.g, obj.b)
if six.PY3:
blob = blob.encode("ascii")
return blob
@staticmethod
def deserialize(data):
chunks = [int(c.strip()) for c in data.split(b",")]
return Color(chunks[0], chunks[1], chunks[2])
class MySpecialSetHandler(object):
handles = (set,)
identity = msgpackutils.SetHandler.identity
def _dumps_loads(obj):
obj = msgpackutils.dumps(obj)
return msgpackutils.loads(obj)
class MsgPackUtilsTest(test_base.BaseTestCase):
def test_list(self):
self.assertEqual(_dumps_loads([1, 2, 3]), [1, 2, 3])
def test_empty_list(self):
self.assertEqual(_dumps_loads([]), [])
def test_tuple(self):
# Seems like we do lose whether it was a tuple or not...
#
# Maybe fixed someday:
#
# https://github.com/msgpack/msgpack-python/issues/98
self.assertEqual(_dumps_loads((1, 2, 3)), [1, 2, 3])
def test_dict(self):
self.assertEqual(_dumps_loads(dict(a=1, b=2, c=3)),
dict(a=1, b=2, c=3))
def test_empty_dict(self):
self.assertEqual(_dumps_loads({}), {})
def test_complex_dict(self):
src = {
'now': datetime.datetime(1920, 2, 3, 4, 5, 6, 7),
'later': datetime.datetime(1921, 2, 3, 4, 5, 6, 9),
'a': 1,
'b': 2.0,
'c': [],
'd': set([1, 2, 3]),
'zzz': uuid.uuid4(),
'yyy': 'yyy',
'ddd': b'bbb',
'today': datetime.date.today(),
}
self.assertEqual(_dumps_loads(src), src)
def test_itercount(self):
it = itertools.count(1)
six.next(it)
six.next(it)
it2 = _dumps_loads(it)
self.assertEqual(six.next(it), six.next(it2))
it = itertools.count(0)
it2 = _dumps_loads(it)
self.assertEqual(six.next(it), six.next(it2))
def test_itercount_step(self):
it = itertools.count(1, 3)
it2 = _dumps_loads(it)
self.assertEqual(six.next(it), six.next(it2))
def test_set(self):
self.assertEqual(_dumps_loads(set([1, 2])), set([1, 2]))
def test_empty_set(self):
self.assertEqual(_dumps_loads(set([])), set([]))
def test_frozenset(self):
self.assertEqual(_dumps_loads(frozenset([1, 2])), frozenset([1, 2]))
def test_empty_frozenset(self):
self.assertEqual(_dumps_loads(frozenset([])), frozenset([]))
def test_datetime_preserve(self):
x = datetime.datetime(1920, 2, 3, 4, 5, 6, 7)
self.assertEqual(_dumps_loads(x), x)
def test_datetime(self):
x = xmlrpclib.DateTime()
x.decode("19710203T04:05:06")
self.assertEqual(_dumps_loads(x), x)
def test_ipaddr(self):
thing = {'ip_addr': netaddr.IPAddress('1.2.3.4')}
self.assertEqual(_dumps_loads(thing), thing)
def test_today(self):
today = datetime.date.today()
self.assertEqual(today, _dumps_loads(today))
def test_datetime_tz_clone(self):
eastern = timezone('US/Eastern')
now = datetime.datetime.now()
e_dt = eastern.localize(now)
e_dt2 = _dumps_loads(e_dt)
self.assertEqual(e_dt, e_dt2)
self.assertEqual(e_dt.strftime(_TZ_FMT), e_dt2.strftime(_TZ_FMT))
def test_datetime_tz_different(self):
eastern = timezone('US/Eastern')
pacific = timezone('US/Pacific')
now = datetime.datetime.now()
e_dt = eastern.localize(now)
p_dt = pacific.localize(now)
self.assertNotEqual(e_dt, p_dt)
self.assertNotEqual(e_dt.strftime(_TZ_FMT), p_dt.strftime(_TZ_FMT))
e_dt2 = _dumps_loads(e_dt)
p_dt2 = _dumps_loads(p_dt)
self.assertNotEqual(e_dt2, p_dt2)
self.assertNotEqual(e_dt2.strftime(_TZ_FMT), p_dt2.strftime(_TZ_FMT))
self.assertEqual(e_dt, e_dt2)
self.assertEqual(p_dt, p_dt2)
def test_copy_then_register(self):
registry = msgpackutils.default_registry
self.assertRaises(ValueError,
registry.register, MySpecialSetHandler(),
reserved=True, override=True)
registry = registry.copy(unfreeze=True)
registry.register(MySpecialSetHandler(),
reserved=True, override=True)
h = registry.match(set())
self.assertIsInstance(h, MySpecialSetHandler)
def test_bad_register(self):
registry = msgpackutils.default_registry
self.assertRaises(ValueError,
registry.register, MySpecialSetHandler(),
reserved=True, override=True)
self.assertRaises(ValueError,
registry.register, MySpecialSetHandler())
registry = registry.copy(unfreeze=True)
registry.register(ColorHandler())
self.assertRaises(ValueError,
registry.register, ColorHandler())
def test_custom_register(self):
registry = msgpackutils.default_registry.copy(unfreeze=True)
registry.register(ColorHandler())
c = Color(255, 254, 253)
c_b = msgpackutils.dumps(c, registry=registry)
c = msgpackutils.loads(c_b, registry=registry)
self.assertEqual(255, c.r)
self.assertEqual(254, c.g)
self.assertEqual(253, c.b)