Merge "Implement Manila REST API microversions"

This commit is contained in:
Jenkins 2015-08-15 19:02:45 +00:00 committed by Gerrit Code Review
commit 17063cdd5e
23 changed files with 1497 additions and 346 deletions

View File

@ -0,0 +1,325 @@
API Microversions
=================
Background
----------
Manila uses a framework we called 'API Microversions' for allowing changes
to the API while preserving backward compatibility. The basic idea is
that a user has to explicitly ask for their request to be treated with
a particular version of the API. So breaking changes can be added to
the API without breaking users who don't specifically ask for it. This
is done with an HTTP header ``X-OpenStack-Manila-API-Version`` which
is a monotonically increasing semantic version number starting from
``1.0``.
If a user makes a request without specifying a version, they will get
the ``DEFAULT_API_VERSION`` as defined in
``manila/api/openstack/wsgi.py``. This value is currently ``1.0`` and
is expected to remain so for quite a long time.
There is a special value ``latest`` which can be specified, which will
allow a client to always receive the most recent version of API
responses from the server.
The Nova project was the first to implement microversions. For full
details please read Nova's `Kilo spec for microversions
<http://git.openstack.org/cgit/openstack/nova-specs/tree/specs/kilo/implemented/api-microversions.rst>`_
When do I need a new Microversion?
----------------------------------
A microversion is needed when the contract to the user is
changed. The user contract covers many kinds of information such as:
- the Request
- the list of resource urls which exist on the server
Example: adding a new shares/{ID}/foo which didn't exist in a
previous version of the code
- the list of query parameters that are valid on urls
Example: adding a new parameter ``is_yellow`` servers/{ID}?is_yellow=True
- the list of query parameter values for non free form fields
Example: parameter filter_by takes a small set of constants/enums "A",
"B", "C". Adding support for new enum "D".
- new headers accepted on a request
- the Response
- the list of attributes and data structures returned
Example: adding a new attribute 'locked': True/False to the output
of shares/{ID}
- the allowed values of non free form fields
Example: adding a new allowed ``status`` to shares/{ID}
- the list of status codes allowed for a particular request
Example: an API previously could return 200, 400, 403, 404 and the
change would make the API now also be allowed to return 409.
- changing a status code on a particular response
Example: changing the return code of an API from 501 to 400.
- new headers returned on a response
The following flow chart attempts to walk through the process of "do
we need a microversion".
.. graphviz::
digraph states {
label="Do I need a microversion?"
silent_fail[shape="diamond", style="", label="Did we silently
fail to do what is asked?"];
ret_500[shape="diamond", style="", label="Did we return a 500
before?"];
new_error[shape="diamond", style="", label="Are we changing what
status code is returned?"];
new_attr[shape="diamond", style="", label="Did we add or remove an
attribute to a payload?"];
new_param[shape="diamond", style="", label="Did we add or remove
an accepted query string parameter or value?"];
new_resource[shape="diamond", style="", label="Did we add or remove a
resource url?"];
no[shape="box", style=rounded, label="No microversion needed"];
yes[shape="box", style=rounded, label="Yes, you need a microversion"];
no2[shape="box", style=rounded, label="No microversion needed, it's
a bug"];
silent_fail -> ret_500[label="no"];
silent_fail -> no2[label="yes"];
ret_500 -> no2[label="yes [1]"];
ret_500 -> new_error[label="no"];
new_error -> new_attr[label="no"];
new_error -> yes[label="yes"];
new_attr -> new_param[label="no"];
new_attr -> yes[label="yes"];
new_param -> new_resource[label="no"];
new_param -> yes[label="yes"];
new_resource -> no[label="no"];
new_resource -> yes[label="yes"];
{rank=same; yes new_attr}
{rank=same; no2 ret_500}
{rank=min; silent_fail}
}
**Footnotes**
[1] - When fixing 500 errors that previously caused stack traces, try
to map the new error into the existing set of errors that API call
could previously return (400 if nothing else is appropriate). Changing
the set of allowed status codes from a request is changing the
contract, and should be part of a microversion.
The reason why we are so strict on contract is that we'd like
application writers to be able to know, for sure, what the contract is
at every microversion in Manila. If they do not, they will need to write
conditional code in their application to handle ambiguities.
When in doubt, consider application authors. If it would work with no
client side changes on both Manila versions, you probably don't need a
microversion. If, on the other hand, there is any ambiguity, a
microversion is probably needed.
In Code
-------
In ``manila/api/openstack/wsgi.py`` we define an ``@api_version`` decorator
which is intended to be used on top-level Controller methods. It is
not appropriate for lower-level methods. Some examples:
Adding a new API method
~~~~~~~~~~~~~~~~~~~~~~~
In the controller class::
@wsgi.Controller.api_version("2.4")
def my_api_method(self, req, id):
....
This method would only be available if the caller had specified an
``X-OpenStack-Manila-API-Version`` of >= ``2.4``. If they had specified a
lower version (or not specified it and received the default of ``2.1``)
the server would respond with ``HTTP/404``.
Removing an API method
~~~~~~~~~~~~~~~~~~~~~~
In the controller class::
@wsgi.Controller.api_version("2.1", "2.4")
def my_api_method(self, req, id):
....
This method would only be available if the caller had specified an
``X-OpenStack-Manila-API-Version`` of <= ``2.4``. If ``2.5`` or later
is specified the server will respond with ``HTTP/404``.
Changing a method's behaviour
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In the controller class::
@wsgi.Controller.api_version("2.1", "2.3")
def my_api_method(self, req, id):
.... method_1 ...
@wsgi.Controller.api_version("2.4") # noqa
def my_api_method(self, req, id):
.... method_2 ...
If a caller specified ``2.1``, ``2.2`` or ``2.3`` (or received the
default of ``2.1``) they would see the result from ``method_1``,
``2.4`` or later ``method_2``.
It is vital that the two methods have the same name, so the second of
them will need ``# noqa`` to avoid failing flake8's ``F811`` rule. The
two methods may be different in any kind of semantics (schema
validation, return values, response codes, etc)
A method with only small changes between versions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
A method may have only small changes between microversions, in which
case you can decorate a private method::
@api_version("2.1", "2.4")
def _version_specific_func(self, req, arg1):
pass
@api_version(min_version="2.5") # noqa
def _version_specific_func(self, req, arg1):
pass
def show(self, req, id):
.... common stuff ....
self._version_specific_func(req, "foo")
.... common stuff ....
A change in schema only
~~~~~~~~~~~~~~~~~~~~~~~
If there is no change to the method, only to the schema that is used for
validation, you can add a version range to the ``validation.schema``
decorator::
@wsgi.Controller.api_version("2.1")
@validation.schema(dummy_schema.dummy, "2.3", "2.8")
@validation.schema(dummy_schema.dummy2, "2.9")
def update(self, req, id, body):
....
This method will be available from version ``2.1``, validated according to
``dummy_schema.dummy`` from ``2.3`` to ``2.8``, and validated according to
``dummy_schema.dummy2`` from ``2.9`` onward.
When not using decorators
~~~~~~~~~~~~~~~~~~~~~~~~~
When you don't want to use the ``@api_version`` decorator on a method
or you want to change behaviour within a method (say it leads to
simpler or simply a lot less code) you can directly test for the
requested version with a method as long as you have access to the api
request object (commonly called ``req``). Every API method has an
api_version_request object attached to the req object and that can be
used to modify behaviour based on its value::
def index(self, req):
<common code>
req_version = req.api_version_request
if req_version.matches("2.1", "2.5"):
....stuff....
elif req_version.matches("2.6", "2.10"):
....other stuff....
elif req_version > api_version_request.APIVersionRequest("2.10"):
....more stuff.....
<common code>
The first argument to the matches method is the minimum acceptable version
and the second is maximum acceptable version. A specified version can be null::
null_version = APIVersionRequest()
If the minimum version specified is null then there is no restriction on
the minimum version, and likewise if the maximum version is null there
is no restriction the maximum version. Alternatively a one sided comparison
can be used as in the example above.
Other necessary changes
-----------------------
If you are adding a patch which adds a new microversion, it is
necessary to add changes to other places which describe your change:
* Update ``REST_API_VERSION_HISTORY`` in
``manila/api/openstack/api_version_request.py``
* Update ``_MAX_API_VERSION`` in
``manila/api/openstack/api_version_request.py``
* Add a verbose description to
``manila/api/openstack/rest_api_version_history.rst``. There should
be enough information that it could be used by the docs team for
release notes.
* Update the expected versions in affected tests.
Allocating a microversion
-------------------------
If you are adding a patch which adds a new microversion, it is
necessary to allocate the next microversion number. Except under
extremely unusual circumstances and this would have been mentioned in
the blueprint for the change, the minor number of ``_MAX_API_VERSION``
will be incremented. This will also be the new microversion number for
the API change.
It is possible that multiple microversion patches would be proposed in
parallel and the microversions would conflict between patches. This
will cause a merge conflict. We don't reserve a microversion for each
patch in advance as we don't know the final merge order. Developers
may need over time to rebase their patch calculating a new version
number as above based on the updated value of ``_MAX_API_VERSION``.
Testing Microversioned API Methods
----------------------------------
Testing a microversioned API method is very similar to a normal controller
method test, you just need to add the ``X-OpenStack-Manila-API-Version``
header, for example::
req = fakes.HTTPRequest.blank('/testable/url/endpoint')
req.headers = {'X-OpenStack-Manila-API-Version': '2.2'}
req.api_version_request = api_version.APIVersionRequest('2.6')
controller = controller.TestableController()
res = controller.index(req)
... assertions about the response ...

View File

@ -0,0 +1 @@
.. include:: ../../../manila/api/openstack/rest_api_version_history.rst

View File

@ -57,6 +57,8 @@ API Reference
:maxdepth: 3
api
api_microversion_dev
api_microversion_history
Module Reference
----------------

View File

@ -29,7 +29,7 @@ paste.app_factory = manila.api.v1.router:APIRouter.factory
pipeline = faultwrap osshareversionapp
[app:osshareversionapp]
paste.app_factory = manila.api.versions:Versions.factory
paste.app_factory = manila.api.versions:VersionsRouter.factory
##########
# Shared #

View File

@ -15,15 +15,21 @@
# under the License.
from oslo_config import cfg
from oslo_log import log
import paste.urlmap
from manila.i18n import _LW
LOG = log.getLogger(__name__)
CONF = cfg.CONF
def root_app_factory(loader, global_conf, **local_conf):
if not CONF.enable_v1_api:
del local_conf['/v1']
if not CONF.enable_v2_api:
del local_conf['/v2']
if CONF.enable_v1_api:
LOG.warning(_LW('The config option enable_v1_api is deprecated, is '
'not used, and will be removed in a future release.'))
if CONF.enable_v2_api:
LOG.warning(_LW('The config option enable_v2_api is deprecated, is '
'not used, and will be removed in a future release.'))
return paste.urlmap.urlmap_factory(loader, global_conf, **local_conf)

View File

@ -0,0 +1,140 @@
# Copyright 2014 IBM Corp.
# Copyright 2015 Clinton Knight
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import re
from manila import exception
from manila import utils
# Define the minimum and maximum version of the API across all of the
# REST API. The format of the version is:
# X.Y where:
#
# - X will only be changed if a significant backwards incompatible API
# change is made which affects the API as whole. That is, something
# that is only very very rarely incremented.
#
# - Y when you make any change to the API. Note that this includes
# semantic changes which may not affect the input or output formats or
# even originate in the API code layer. We are not distinguishing
# between backwards compatible and backwards incompatible changes in
# the versioning system. It must be made clear in the documentation as
# to what is a backwards compatible change and what is a backwards
# incompatible one.
#
# You must update the API version history string below with a one or
# two line description as well as update rest_api_version_history.rst
REST_API_VERSION_HISTORY = """
REST API Version History:
* 1.0 - Initial version. Includes all V1 APIs and extensions in Kilo.
* 1.1 - Versions API updated to reflect beginning of microversions epoch.
"""
# The minimum and maximum versions of the API supported
# The default api version request is defined to be the
# the minimum version of the API supported.
_MIN_API_VERSION = "1.0"
_MAX_API_VERSION = "1.1"
DEFAULT_API_VERSION = _MIN_API_VERSION
# NOTE(cyeoh): min and max versions declared as functions so we can
# mock them for unittests. Do not use the constants directly anywhere
# else.
def min_api_version():
return APIVersionRequest(_MIN_API_VERSION)
def max_api_version():
return APIVersionRequest(_MAX_API_VERSION)
class APIVersionRequest(utils.ComparableMixin):
"""This class represents an API Version Request.
This class includes convenience methods for manipulation
and comparison of version numbers as needed to implement
API microversions.
"""
def __init__(self, version_string=None):
"""Create an API version request object."""
self.ver_major = None
self.ver_minor = None
if version_string is not None:
match = re.match(r"^([1-9]\d*)\.([1-9]\d*|0)$",
version_string)
if match:
self.ver_major = int(match.group(1))
self.ver_minor = int(match.group(2))
else:
raise exception.InvalidAPIVersionString(version=version_string)
def __str__(self):
"""Debug/Logging representation of object."""
return ("API Version Request Major: %(major)s, Minor: %(minor)s"
% {'major': self.ver_major, 'minor': self.ver_minor})
def is_null(self):
return self.ver_major is None and self.ver_minor is None
def _cmpkey(self):
"""Return the value used by ComparableMixin for rich comparisons."""
return self.ver_major, self.ver_minor
def matches(self, min_version, max_version):
"""Compares this version to the specified min/max range.
Returns whether the version object represents a version
greater than or equal to the minimum version and less than
or equal to the maximum version.
If min_version is null then there is no minimum limit.
If max_version is null then there is no maximum limit.
If self is null then raise ValueError.
:param min_version: Minimum acceptable version.
:param max_version: Maximum acceptable version.
:returns: boolean
"""
if self.is_null():
raise ValueError
if max_version.is_null() and min_version.is_null():
return True
elif max_version.is_null():
return min_version <= self
elif min_version.is_null():
return self <= max_version
else:
return min_version <= self <= max_version
def get_string(self):
"""Returns a string representation of this object.
If this method is used to create an APIVersionRequest,
the resulting object will be an equivalent request.
"""
if self.is_null():
raise ValueError
return ("%(major)s.%(minor)s" %
{'major': self.ver_major, 'minor': self.ver_minor})

View File

@ -0,0 +1,32 @@
REST API Version History
========================
This documents the changes made to the REST API with every
microversion change. The description for each version should be a
verbose one which has enough information to be suitable for use in
user documentation.
1.0
---
The 1.0 Manila API includes all v1 core APIs existing prior to
the introduction of microversions.
1.1
---
This is the initial version of the Manila API which supports
microversions.
A user can specify a header in the API request::
X-OpenStack-Manila-API-Version: <version>
where ``<version>`` is any valid api version for this API.
If no version is specified then the API will behave as version 1.0
was requested.
The only API change in version 1.1 is versions, i.e.
GET http://localhost:8786/, which now returns the minimum and
current microversion values.

View File

@ -0,0 +1,47 @@
# Copyright 2014 IBM Corp.
# Copyright 2015 Clinton Knight
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from manila import utils
class VersionedMethod(utils.ComparableMixin):
def __init__(self, name, start_version, end_version, func):
"""Versioning information for a single method.
Minimum and maximums are inclusive.
:param name: Name of the method
:param start_version: Minimum acceptable version
:param end_version: Maximum acceptable_version
:param func: Method to call
"""
self.name = name
self.start_version = start_version
self.end_version = end_version
self.func = func
def __str__(self):
args = {
'name': self.name,
'start': self.start_version,
'end': self.end_version
}
return ("Version Method %(name)s: min: %(start)s, max: %(end)s" % args)
def _cmpkey(self):
"""Return the value used by ComparableMixin for rich comparisons."""
return self.start_version

View File

@ -13,19 +13,24 @@
# License for the specific language governing permissions and limitations
# under the License.
import functools
import inspect
import math
import time
from oslo_log import log
from oslo_serialization import jsonutils
from oslo_utils import strutils
import six
import time
import webob
import webob.exc
from manila.api.openstack import api_version_request as api_version
from manila.api.openstack import versioned_method
from manila import exception
from manila.i18n import _
from manila.i18n import _LE
from manila.i18n import _LI
from manila import utils
from manila import wsgi
LOG = log.getLogger(__name__)
@ -38,6 +43,13 @@ _MEDIA_TYPE_MAP = {
'application/json': 'json',
}
# name of attribute to keep version method information
VER_METHOD_ATTR = 'versioned_methods'
# Name of header used by clients to request a specific version
# of the REST API
API_VERSION_REQUEST_HEADER = 'X-OpenStack-Manila-API-Version'
class Request(webob.Request):
"""Add some OpenStack API-specific logic to the base webob.Request."""
@ -45,6 +57,8 @@ class Request(webob.Request):
def __init__(self, *args, **kwargs):
super(Request, self).__init__(*args, **kwargs)
self._resource_cache = {}
if not hasattr(self, 'api_version_request'):
self.api_version_request = api_version.APIVersionRequest()
def cache_resource(self, resource_to_cache, id_attribute='id', name=None):
"""Cache the given resource.
@ -178,8 +192,7 @@ class Request(webob.Request):
def get_content_type(self):
"""Determine content type of the request body.
Does not do any body introspection, only checks header
Does not do any body introspection, only checks header.
"""
if "Content-Type" not in self.headers:
return None
@ -192,6 +205,32 @@ class Request(webob.Request):
return content_type
def set_api_version_request(self):
"""Set API version request based on the request header information."""
if API_VERSION_REQUEST_HEADER in self.headers:
hdr_string = self.headers[API_VERSION_REQUEST_HEADER]
# 'latest' is a special keyword which is equivalent to requesting
# the maximum version of the API supported
if hdr_string == 'latest':
self.api_version_request = api_version.max_api_version()
else:
self.api_version_request = api_version.APIVersionRequest(
hdr_string)
# Check that the version requested is within the global
# minimum/maximum of supported API versions
if not self.api_version_request.matches(
api_version.min_api_version(),
api_version.max_api_version()):
raise exception.InvalidGlobalAPIVersion(
req_ver=self.api_version_request.get_string(),
min_ver=api_version.min_api_version().get_string(),
max_ver=api_version.max_api_version().get_string())
else:
self.api_version_request = api_version.APIVersionRequest(
api_version.DEFAULT_API_VERSION)
class ActionDispatcher(object):
"""Maps method name to local methods through action name."""
@ -199,7 +238,7 @@ class ActionDispatcher(object):
def dispatch(self, *args, **kwargs):
"""Find and call local method."""
action = kwargs.pop('action', 'default')
action_method = getattr(self, str(action), self.default)
action_method = getattr(self, six.text_type(action), self.default)
return action_method(*args, **kwargs)
def default(self, data):
@ -300,7 +339,7 @@ class ResponseObject(object):
optional.
"""
def __init__(self, obj, code=None, **serializers):
def __init__(self, obj, code=None, headers=None, **serializers):
"""Binds serializers with an object.
Takes keyword arguments akin to the @serializer() decorator
@ -313,7 +352,7 @@ class ResponseObject(object):
self.serializers = serializers
self._default_code = 200
self._code = code
self._headers = {}
self._headers = headers or {}
self.serializer = None
self.media_type = None
@ -406,8 +445,8 @@ class ResponseObject(object):
response = webob.Response()
response.status_int = self.code
for hdr, value in self._headers.items():
response.headers[hdr] = value
response.headers['Content-Type'] = content_type
response.headers[hdr] = six.text_type(value)
response.headers['Content-Type'] = six.text_type(content_type)
if self.obj is not None:
response.body = serializer.serialize(self.obj)
@ -462,6 +501,8 @@ class ResourceExceptionHandler(object):
if isinstance(ex_value, exception.NotAuthorized):
msg = six.text_type(ex_value)
raise Fault(webob.exc.HTTPForbidden(explanation=msg))
elif isinstance(ex_value, exception.VersionNotFoundForAPIMethod):
raise
elif isinstance(ex_value, exception.Invalid):
raise Fault(exception.ConvertedException(
code=ex_value.code, explanation=six.text_type(ex_value)))
@ -494,8 +535,8 @@ class Resource(wsgi.Application):
Exceptions derived from webob.exc.HTTPException will be automatically
wrapped in Fault() to provide API friendly error responses.
"""
support_api_request_version = True
def __init__(self, controller, action_peek=None, **deserializers):
"""
@ -656,6 +697,11 @@ class Resource(wsgi.Application):
with ResourceExceptionHandler():
response = ext(req=request, resp_obj=resp_obj,
**action_args)
except exception.VersionNotFoundForAPIMethod:
# If an attached extension (@wsgi.extends) for the
# method has no version match its not an error. We
# just don't run the extends code
continue
except Fault as ex:
response = ex
@ -671,6 +717,16 @@ class Resource(wsgi.Application):
LOG.info("%(method)s %(url)s" % {"method": request.method,
"url": request.url})
if self.support_api_request_version:
# Set the version of the API requested based on the header
try:
request.set_api_version_request()
except exception.InvalidAPIVersionString as e:
return Fault(webob.exc.HTTPBadRequest(
explanation=six.text_type(e)))
except exception.InvalidGlobalAPIVersion as e:
return Fault(webob.exc.HTTPNotAcceptable(
explanation=six.text_type(e)))
# Identify the action, its arguments, and the requested
# content type
@ -704,6 +760,16 @@ class Resource(wsgi.Application):
msg = _("Malformed request body")
return Fault(webob.exc.HTTPBadRequest(explanation=msg))
if body:
msg = ("Action: '%(action)s', calling method: %(meth)s, body: "
"%(body)s") % {'action': action,
'body': six.text_type(body),
'meth': six.text_type(meth)}
LOG.debug(strutils.mask_password(msg))
else:
LOG.debug("Calling method '%(meth)s'",
{'meth': six.text_type(meth)})
# Now, deserialize the request body...
try:
if content_type:
@ -775,6 +841,16 @@ class Resource(wsgi.Application):
LOG.info(msg)
if hasattr(response, 'headers'):
for hdr, val in response.headers.items():
# Headers must be utf-8 strings
response.headers[hdr] = six.text_type(val)
if not request.api_version_request.is_null():
response.headers[API_VERSION_REQUEST_HEADER] = (
request.api_version_request.get_string())
response.headers['Vary'] = API_VERSION_REQUEST_HEADER
return response
def get_method(self, request, action, content_type, body):
@ -809,7 +885,13 @@ class Resource(wsgi.Application):
def dispatch(self, method, request, action_args):
"""Dispatch a call to the action-specific method."""
return method(req=request, **action_args)
try:
return method(req=request, **action_args)
except exception.VersionNotFoundForAPIMethod:
# We deliberately don't return any message information
# about the exception to the user so it looks as if
# the method is simply not implemented.
return Fault(webob.exc.HTTPNotFound())
def action(name):
@ -869,9 +951,22 @@ class ControllerMetaclass(type):
# Find all actions
actions = {}
extensions = []
versioned_methods = None
# start with wsgi actions from base classes
for base in bases:
actions.update(getattr(base, 'wsgi_actions', {}))
if base.__name__ == "Controller":
# NOTE(cyeoh): This resets the VER_METHOD_ATTR attribute
# between API controller class creations. This allows us
# to use a class decorator on the API methods that doesn't
# require naming explicitly what method is being versioned as
# it can be implicit based on the method decorated. It is a bit
# ugly.
if VER_METHOD_ATTR in base.__dict__:
versioned_methods = getattr(base, VER_METHOD_ATTR)
delattr(base, VER_METHOD_ATTR)
for key, value in cls_dict.items():
if not callable(value):
continue
@ -883,6 +978,8 @@ class ControllerMetaclass(type):
# Add the actions and extensions to the class dict
cls_dict['wsgi_actions'] = actions
cls_dict['wsgi_extensions'] = extensions
if versioned_methods:
cls_dict[VER_METHOD_ATTR] = versioned_methods
return super(ControllerMetaclass, mcs).__new__(mcs, name, bases,
cls_dict)
@ -903,6 +1000,97 @@ class Controller(object):
else:
self._view_builder = None
def __getattribute__(self, key):
def version_select(*args, **kwargs):
"""Select and call the matching version of the specified method.
Look for the method which matches the name supplied and version
constraints and calls it with the supplied arguments.
:returns: Returns the result of the method called
:raises: VersionNotFoundForAPIMethod if there is no method which
matches the name and version constraints
"""
# The first arg to all versioned methods is always the request
# object. The version for the request is attached to the
# request object
if len(args) == 0:
ver = kwargs['req'].api_version_request
else:
ver = args[0].api_version_request
func_list = self.versioned_methods[key]
for func in func_list:
if ver.matches(func.start_version, func.end_version):
# Update the version_select wrapper function so
# other decorator attributes like wsgi.response
# are still respected.
functools.update_wrapper(version_select, func.func)
return func.func(self, *args, **kwargs)
# No version match
raise exception.VersionNotFoundForAPIMethod(version=ver)
try:
version_meth_dict = object.__getattribute__(self, VER_METHOD_ATTR)
except AttributeError:
# No versioning on this class
return object.__getattribute__(self, key)
if (version_meth_dict and
key in object.__getattribute__(self, VER_METHOD_ATTR)):
return version_select
return object.__getattribute__(self, key)
# NOTE(cyeoh): This decorator MUST appear first (the outermost
# decorator) on an API method for it to work correctly
@classmethod
def api_version(cls, min_ver, max_ver=None):
"""Decorator for versioning API methods.
Add the decorator to any method which takes a request object
as the first parameter and belongs to a class which inherits from
wsgi.Controller.
:param min_ver: string representing minimum version
:param max_ver: optional string representing maximum version
"""
def decorator(f):
obj_min_ver = api_version.APIVersionRequest(min_ver)
if max_ver:
obj_max_ver = api_version.APIVersionRequest(max_ver)
else:
obj_max_ver = api_version.APIVersionRequest()
# Add to list of versioned methods registered
func_name = f.__name__
new_func = versioned_method.VersionedMethod(
func_name, obj_min_ver, obj_max_ver, f)
func_dict = getattr(cls, VER_METHOD_ATTR, {})
if not func_dict:
setattr(cls, VER_METHOD_ATTR, func_dict)
func_list = func_dict.get(func_name, [])
if not func_list:
func_dict[func_name] = func_list
func_list.append(new_func)
# Ensure the list is sorted by minimum version (reversed)
# so later when we work through the list in order we find
# the method which has the latest version which supports
# the version requested.
# TODO(cyeoh): Add check to ensure that there are no overlapping
# ranges of valid versions as that is ambiguous
func_list.sort(reverse=True)
return f
return decorator
@staticmethod
def is_valid_body(body, entity_name):
if not (body and entity_name in body):
@ -954,8 +1142,10 @@ class Fault(webob.exc.HTTPException):
retry = self.wrapped_exc.headers['Retry-After']
fault_data[fault_name]['retryAfter'] = retry
# 'code' is an attribute on the fault tag itself
metadata = {'attributes': {fault_name: 'code'}}
if not req.api_version_request.is_null():
self.wrapped_exc.headers[API_VERSION_REQUEST_HEADER] = (
req.api_version_request.get_string())
self.wrapped_exc.headers['Vary'] = API_VERSION_REQUEST_HEADER
content_type = req.best_match_content_type()
serializer = {
@ -979,14 +1169,10 @@ def _set_request_id_header(req, headers):
class OverLimitFault(webob.exc.HTTPException):
"""
Rate-limited request response.
"""
"""Rate-limited request response."""
def __init__(self, message, details, retry_time):
"""
Initialize new `OverLimitFault` with relevant information.
"""
"""Initialize new `OverLimitFault` with relevant information."""
hdrs = OverLimitFault._retry_after(retry_time)
self.wrapped_exc = webob.exc.HTTPRequestEntityTooLarge(headers=hdrs)
self.content = {
@ -1006,8 +1192,9 @@ class OverLimitFault(webob.exc.HTTPException):
@webob.dec.wsgify(RequestClass=Request)
def __call__(self, request):
"""
Return the wrapped exception with a serialized body conforming to our
"""Wrap the exception.
Wrap the exception with a serialized body conforming to our
error format.
"""
content_type = request.best_match_content_type()

View File

@ -49,7 +49,7 @@ class APIRouter(manila.api.openstack.APIRouter):
self.resources['versions'] = versions.create_resource()
mapper.connect("versions", "/",
controller=self.resources['versions'],
action='show')
action='index')
mapper.redirect("", "/")

View File

@ -1,4 +1,5 @@
# Copyright 2010 OpenStack LLC.
# Copyright 2015 Clinton Knight
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
@ -13,8 +14,13 @@
# License for the specific language governing permissions and limitations
# under the License.
import copy
from oslo_config import cfg
from manila.api import extensions
from manila.api import openstack
from manila.api.openstack import api_version_request
from manila.api.openstack import wsgi
from manila.api.views import versions as views_versions
@ -22,103 +28,63 @@ CONF = cfg.CONF
_KNOWN_VERSIONS = {
"v2.0": {
"id": "v2.0",
"status": "CURRENT",
"updated": "2012-11-21T11:33:21Z",
"links": [
'v1.0': {
'id': 'v1.0',
'status': 'CURRENT',
'version': api_version_request._MAX_API_VERSION,
'min_version': api_version_request._MIN_API_VERSION,
'updated': '2015-07-30T11:33:21Z',
'links': [
{
"rel": "describedby",
"type": "application/pdf",
"href": "http://jorgew.github.com/block-storage-api/"
"content/os-block-storage-1.0.pdf",
},
{
"rel": "describedby",
"type": "application/vnd.sun.wadl+xml",
# (anthony) FIXME
"href": "http://docs.rackspacecloud.com/"
"servers/api/v1.1/application.wadl",
'rel': 'describedby',
'type': 'text/html',
'href': 'http://docs.openstack.org/',
},
],
"media-types": [
'media-types': [
{
"base": "application/json",
'base': 'application/json',
'type': 'application/vnd.openstack.share+json;version=1',
}
],
},
"v1.0": {
"id": "v1.0",
"status": "CURRENT",
"updated": "2012-01-04T11:33:21Z",
"links": [
{
"rel": "describedby",
"type": "application/pdf",
"href": "http://jorgew.github.com/block-storage-api/"
"content/os-block-storage-1.0.pdf",
},
{
"rel": "describedby",
"type": "application/vnd.sun.wadl+xml",
# (anthony) FIXME
"href": "http://docs.rackspacecloud.com/"
"servers/api/v1.1/application.wadl",
},
],
"media-types": [
{
"base": "application/json",
}
],
}
}
def get_supported_versions():
versions = {}
class VersionsRouter(openstack.APIRouter):
"""Route versions requests."""
if CONF.enable_v1_api:
versions['v1.0'] = _KNOWN_VERSIONS['v1.0']
if CONF.enable_v2_api:
versions['v2.0'] = _KNOWN_VERSIONS['v2.0']
ExtensionManager = extensions.ExtensionManager
return versions
def _setup_routes(self, mapper, ext_mgr):
self.resources['versions'] = create_resource()
mapper.connect('versions', '/',
controller=self.resources['versions'],
action='index')
mapper.redirect('', '/')
class Versions(wsgi.Resource):
class VersionsController(wsgi.Controller):
def __init__(self):
super(Versions, self).__init__(None)
super(VersionsController, self).__init__(None)
@wsgi.Controller.api_version('1.0', '1.0')
def index(self, req):
"""Return all versions."""
builder = views_versions.get_view_builder(req)
return builder.build_versions(get_supported_versions())
known_versions = copy.deepcopy(_KNOWN_VERSIONS)
known_versions['v1.0'].pop('min_version')
known_versions['v1.0'].pop('version')
return builder.build_versions(known_versions)
@wsgi.response(300)
def multi(self, req):
"""Return multiple choices."""
@wsgi.Controller.api_version('1.1') # noqa
def index(self, req): # pylint: disable=E0102
"""Return all versions."""
builder = views_versions.get_view_builder(req)
return builder.build_choices(get_supported_versions(), req)
def get_action_args(self, request_environment):
"""Parse dictionary created by routes library."""
args = {}
if request_environment['PATH_INFO'] == '/':
args['action'] = 'index'
else:
args['action'] = 'multi'
return args
class ShareVersionV1(object):
def show(self, req):
builder = views_versions.get_view_builder(req)
return builder.build_version(_KNOWN_VERSIONS['v1.0'])
known_versions = copy.deepcopy(_KNOWN_VERSIONS)
return builder.build_versions(known_versions)
def create_resource():
return wsgi.Resource(ShareVersionV1())
return wsgi.Resource(VersionsController())

View File

@ -1,4 +1,5 @@
# Copyright 2010-2011 OpenStack LLC.
# Copyright 2015 Clinton Knight
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
@ -14,12 +15,13 @@
# under the License.
import copy
import os
import re
from six.moves import urllib
def get_view_builder(req):
base_url = req.application_url
return ViewBuilder(base_url)
return ViewBuilder(req.application_url)
class ViewBuilder(object):
@ -30,52 +32,30 @@ class ViewBuilder(object):
"""
self.base_url = base_url
def build_choices(self, VERSIONS, req):
version_objs = []
for version in VERSIONS:
version = VERSIONS[version]
version_objs.append({
"id": version['id'],
"status": version['status'],
"links": [{"rel": "self",
"href": self.generate_href(req.path), }, ],
"media-types": version['media-types'], })
return dict(choices=version_objs)
def build_versions(self, versions):
version_objs = []
for version in sorted(versions.keys()):
version = versions[version]
version_objs.append({
"id": version['id'],
"status": version['status'],
"updated": version['updated'],
"links": self._build_links(version), })
views = [self._build_version(versions[key])
for key in sorted(list(versions.keys()))]
return dict(versions=views)
return dict(versions=version_objs)
def build_version(self, version):
reval = copy.deepcopy(version)
reval['links'].insert(0, {
"rel": "self",
"href": self.base_url.rstrip('/') + '/', })
return dict(version=reval)
def _build_version(self, version):
view = copy.deepcopy(version)
view['links'] = self._build_links(version)
return view
def _build_links(self, version_data):
"""Generate a container of links that refer to the provided version."""
href = self.generate_href()
links = [{'rel': 'self',
'href': href, }, ]
links = copy.deepcopy(version_data.get('links', {}))
links.append({'rel': 'self', 'href': self._generate_href()})
return links
def generate_href(self, path=None):
"""Create an url that refers to a specific version_number."""
version_number = 'v1'
def _generate_href(self, version='v1', path=None):
"""Create a URL that refers to a specific version_number."""
base_url = self._get_base_url_without_version()
href = urllib.parse.urljoin(base_url, version).rstrip('/') + '/'
if path:
path = path.strip('/')
return os.path.join(self.base_url, version_number, path)
else:
return os.path.join(self.base_url, version_number) + '/'
href += path.lstrip('/')
return href
def _get_base_url_without_version(self):
"""Get the base URL with out the /v1 suffix."""
return re.sub('v[1-9]+/?$', '', self.base_url)

View File

@ -64,11 +64,15 @@ global_opts = [
default='manila-share',
help='The topic share nodes listen on.'),
cfg.BoolOpt('enable_v1_api',
default=True,
help=_("Deploy v1 of the Manila API.")),
default=False,
help=_('Deploy v1 of the Manila API. This option is '
'deprecated, is not used, and will be removed '
'in a future release.')),
cfg.BoolOpt('enable_v2_api',
default=True,
help=_("Deploy v2 of the Manila API.")),
default=False,
help=_('Deploy v2 of the Manila API. This option is '
'deprecated, is not used, and will be removed '
'in a future release.')),
cfg.BoolOpt('api_rate_limit',
default=True,
help='Whether to rate limit the API.'),

View File

@ -170,6 +170,20 @@ class InvalidDriverMode(Invalid):
message = _("Invalid driver mode: %(driver_mode)s.")
class InvalidAPIVersionString(Invalid):
msg_fmt = _("API Version String %(version)s is of invalid format. Must "
"be of format MajorNum.MinorNum.")
class VersionNotFoundForAPIMethod(Invalid):
msg_fmt = _("API version %(version)s is not supported on this method.")
class InvalidGlobalAPIVersion(Invalid):
msg_fmt = _("Version %(req_ver)s is not supported by the API. Minimum "
"is %(min_ver)s and maximum is %(max_ver)s.")
class NotFound(ManilaException):
message = _("Resource could not be found.")
code = 404

View File

@ -24,6 +24,7 @@ import webob.request
from manila.api.middleware import auth
from manila.api.middleware import fault
from manila.api.openstack import api_version_request as api_version
from manila.api.openstack import wsgi as os_wsgi
from manila.api import urlmap
from manila.api.v1 import limits
@ -106,13 +107,16 @@ class HTTPRequest(os_wsgi.Request):
@classmethod
def blank(cls, *args, **kwargs):
kwargs['base_url'] = 'http://localhost/v1'
if not kwargs.get('base_url'):
kwargs['base_url'] = 'http://localhost/v1'
use_admin_context = kwargs.pop('use_admin_context', False)
version = kwargs.pop('version', api_version.DEFAULT_API_VERSION)
out = os_wsgi.Request.blank(*args, **kwargs)
out.environ['manila.context'] = FakeRequestContext(
'fake_user',
'fake',
is_admin=use_admin_context)
out.api_version_request = api_version.APIVersionRequest(version)
return out

View File

@ -0,0 +1,120 @@
# Copyright 2014 IBM Corp.
# Copyright 2015 Clinton Knight
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import ddt
import six
from manila.api.openstack import api_version_request
from manila import exception
from manila import test
@ddt.ddt
class APIVersionRequestTests(test.TestCase):
@ddt.data(
('1.1', 1, 1),
('2.10', 2, 10),
('5.234', 5, 234),
('12.5', 12, 5),
('2.0', 2, 0),
('2.200', 2, 200)
)
@ddt.unpack
def test_valid_version_strings(self, version_string, major, minor):
request = api_version_request.APIVersionRequest(version_string)
self.assertEqual(major, request.ver_major)
self.assertEqual(minor, request.ver_minor)
def test_null_version(self):
v = api_version_request.APIVersionRequest()
self.assertTrue(v.is_null())
@ddt.data('2', '200', '2.1.4', '200.23.66.3', '5 .3', '5. 3',
'5.03', '02.1', '2.001', '', ' 2.1', '2.1 ')
def test_invalid_version_strings(self, version_string):
self.assertRaises(exception.InvalidAPIVersionString,
api_version_request.APIVersionRequest,
version_string)
def test_cmpkey(self):
request = api_version_request.APIVersionRequest('1.2')
self.assertEqual((1, 2), request._cmpkey())
def test_version_comparisons(self):
v1 = api_version_request.APIVersionRequest('2.0')
v2 = api_version_request.APIVersionRequest('2.5')
v3 = api_version_request.APIVersionRequest('5.23')
v4 = api_version_request.APIVersionRequest('2.0')
v_null = api_version_request.APIVersionRequest()
self.assertTrue(v1 < v2)
self.assertTrue(v1 <= v2)
self.assertTrue(v3 > v2)
self.assertTrue(v3 >= v2)
self.assertTrue(v1 != v2)
self.assertTrue(v1 == v4)
self.assertTrue(v1 != v_null)
self.assertTrue(v_null == v_null)
self.assertFalse(v1 == '2.0')
def test_version_matches(self):
v1 = api_version_request.APIVersionRequest('2.0')
v2 = api_version_request.APIVersionRequest('2.5')
v3 = api_version_request.APIVersionRequest('2.45')
v4 = api_version_request.APIVersionRequest('3.3')
v5 = api_version_request.APIVersionRequest('3.23')
v6 = api_version_request.APIVersionRequest('2.0')
v7 = api_version_request.APIVersionRequest('3.3')
v8 = api_version_request.APIVersionRequest('4.0')
v_null = api_version_request.APIVersionRequest()
self.assertTrue(v2.matches(v1, v3))
self.assertTrue(v2.matches(v1, v_null))
self.assertTrue(v1.matches(v6, v2))
self.assertTrue(v4.matches(v2, v7))
self.assertTrue(v4.matches(v_null, v7))
self.assertTrue(v4.matches(v_null, v8))
self.assertFalse(v1.matches(v2, v3))
self.assertFalse(v5.matches(v2, v4))
self.assertFalse(v2.matches(v3, v1))
self.assertTrue(v1.matches(v_null, v_null))
self.assertRaises(ValueError, v_null.matches, v1, v3)
def test_get_string(self):
v1_string = '3.23'
v1 = api_version_request.APIVersionRequest(v1_string)
self.assertEqual(v1_string, v1.get_string())
self.assertRaises(ValueError,
api_version_request.APIVersionRequest().get_string)
@ddt.data(('1', '0'), ('1', '1'))
@ddt.unpack
def test_str(self, major, minor):
request_input = '%s.%s' % (major, minor)
request = api_version_request.APIVersionRequest(request_input)
request_string = six.text_type(request)
self.assertEqual('API Version Request '
'Major: %s, Minor: %s' % (major, minor),
request_string)

View File

@ -0,0 +1,35 @@
# Copyright 2015 Clinton Knight
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import six
from manila.api.openstack import versioned_method
from manila import test
class VersionedMethodTestCase(test.TestCase):
def test_str(self):
args = ('fake_name', 'fake_min', 'fake_max')
method = versioned_method.VersionedMethod(*(args + (None,)))
method_string = six.text_type(method)
self.assertEqual('Version Method %s: min: %s, max: %s' % args,
method_string)
def test_cmpkey(self):
method = versioned_method.VersionedMethod(
'fake_name', 'fake_start_version', 'fake_end_version', 'fake_func')
self.assertEqual('fake_start_version', method._cmpkey())

View File

@ -27,7 +27,7 @@ class RequestTest(test.TestCase):
request = wsgi.Request.blank('/tests/123')
request.headers["Content-Type"] = "application/json; charset=UTF-8"
result = request.get_content_type()
self.assertEqual(result, "application/json")
self.assertEqual("application/json", result)
def test_content_type_from_accept(self):
content_type = 'application/json'
@ -36,34 +36,34 @@ class RequestTest(test.TestCase):
result = request.best_match_content_type()
self.assertEqual(result, content_type)
self.assertEqual(content_type, result)
def test_content_type_from_accept_best(self):
request = wsgi.Request.blank('/tests/123')
request.headers["Accept"] = "application/xml, application/json"
result = request.best_match_content_type()
self.assertEqual(result, "application/json")
self.assertEqual("application/json", result)
request = wsgi.Request.blank('/tests/123')
request.headers["Accept"] = ("application/json; q=0.3, "
"application/xml; q=0.9")
result = request.best_match_content_type()
self.assertEqual(result, "application/json")
self.assertEqual("application/json", result)
def test_content_type_from_query_extension(self):
request = wsgi.Request.blank('/tests/123.json')
result = request.best_match_content_type()
self.assertEqual(result, "application/json")
self.assertEqual("application/json", result)
request = wsgi.Request.blank('/tests/123.invalid')
result = request.best_match_content_type()
self.assertEqual(result, "application/json")
self.assertEqual("application/json", result)
def test_content_type_accept_default(self):
request = wsgi.Request.blank('/tests/123.unsupported')
request.headers["Accept"] = "application/unsupported1"
result = request.best_match_content_type()
self.assertEqual(result, "application/json")
self.assertEqual("application/json", result)
def test_cache_and_retrieve_resources(self):
request = wsgi.Request.blank('/foo')
@ -131,25 +131,25 @@ class ActionDispatcherTest(test.TestCase):
def test_dispatch(self):
serializer = wsgi.ActionDispatcher()
serializer.create = lambda x: 'pants'
self.assertEqual(serializer.dispatch({}, action='create'), 'pants')
self.assertEqual('pants', serializer.dispatch({}, action='create'))
def test_dispatch_action_None(self):
serializer = wsgi.ActionDispatcher()
serializer.create = lambda x: 'pants'
serializer.default = lambda x: 'trousers'
self.assertEqual(serializer.dispatch({}, action=None), 'trousers')
self.assertEqual('trousers', serializer.dispatch({}, action=None))
def test_dispatch_default(self):
serializer = wsgi.ActionDispatcher()
serializer.create = lambda x: 'pants'
serializer.default = lambda x: 'trousers'
self.assertEqual(serializer.dispatch({}, action='update'), 'trousers')
self.assertEqual('trousers', serializer.dispatch({}, action='update'))
class DictSerializerTest(test.TestCase):
def test_dispatch_default(self):
serializer = wsgi.DictSerializer()
self.assertEqual(serializer.serialize({}, 'update'), '')
self.assertEqual('', serializer.serialize({}, 'update'))
class JSONDictSerializerTest(test.TestCase):
@ -158,14 +158,14 @@ class JSONDictSerializerTest(test.TestCase):
expected_json = six.b('{"servers":{"a":[2,3]}}')
serializer = wsgi.JSONDictSerializer()
result = serializer.serialize(input_dict)
result = result.replace(six.b('\n'),six.b('')).replace(six.b(' '),six.b(''))
self.assertEqual(result, expected_json)
result = result.replace(six.b('\n'), six.b('')).replace(six.b(' '), six.b(''))
self.assertEqual(expected_json, result)
class TextDeserializerTest(test.TestCase):
def test_dispatch_default(self):
deserializer = wsgi.TextDeserializer()
self.assertEqual(deserializer.deserialize({}, 'update'), {})
self.assertEqual({}, deserializer.deserialize({}, 'update'))
class JSONDeserializerTest(test.TestCase):
@ -188,7 +188,7 @@ class JSONDeserializerTest(test.TestCase):
},
}
deserializer = wsgi.JSONDeserializer()
self.assertEqual(deserializer.deserialize(data), as_dict)
self.assertEqual(as_dict, deserializer.deserialize(data))
class ResourceTest(test.TestCase):
@ -200,8 +200,8 @@ class ResourceTest(test.TestCase):
req = webob.Request.blank('/tests')
app = fakes.TestRouter(Controller())
response = req.get_response(app)
self.assertEqual(response.body, six.b('off'))
self.assertEqual(response.status_int, 200)
self.assertEqual(six.b('off'), response.body)
self.assertEqual(200, response.status_int)
def test_resource_not_authorized(self):
class Controller(object):
@ -211,7 +211,7 @@ class ResourceTest(test.TestCase):
req = webob.Request.blank('/tests')
app = fakes.TestRouter(Controller())
response = req.get_response(app)
self.assertEqual(response.status_int, 403)
self.assertEqual(403, response.status_int)
def test_dispatch(self):
class Controller(object):
@ -223,7 +223,7 @@ class ResourceTest(test.TestCase):
method, extensions = resource.get_method(None, 'index', None, '')
actual = resource.dispatch(method, None, {'pants': 'off'})
expected = 'off'
self.assertEqual(actual, expected)
self.assertEqual(expected, actual)
def test_get_method_undefined_controller_action(self):
class Controller(object):
@ -302,7 +302,7 @@ class ResourceTest(test.TestCase):
expected = {'action': 'update', 'id': 12}
self.assertEqual(resource.get_action_args(env), expected)
self.assertEqual(expected, resource.get_action_args(env))
def test_get_body_bad_content(self):
class Controller(object):
@ -317,8 +317,8 @@ class ResourceTest(test.TestCase):
request.body = six.b('foo')
content_type, body = resource.get_body(request)
self.assertEqual(content_type, None)
self.assertEqual(body, '')
self.assertIsNone(content_type)
self.assertEqual('', body)
def test_get_body_no_content_type(self):
class Controller(object):
@ -332,8 +332,8 @@ class ResourceTest(test.TestCase):
request.body = six.b('foo')
content_type, body = resource.get_body(request)
self.assertEqual(content_type, None)
self.assertEqual(body, '')
self.assertIsNone(content_type)
self.assertEqual('', body)
def test_get_body_no_content_body(self):
class Controller(object):
@ -348,8 +348,8 @@ class ResourceTest(test.TestCase):
request.body = six.b('')
content_type, body = resource.get_body(request)
self.assertEqual(content_type, None)
self.assertEqual(body, '')
self.assertIsNone(content_type)
self.assertEqual('', body)
def test_get_body(self):
class Controller(object):
@ -364,8 +364,8 @@ class ResourceTest(test.TestCase):
request.body = six.b('foo')
content_type, body = resource.get_body(request)
self.assertEqual(content_type, 'application/json')
self.assertEqual(body, six.b('foo'))
self.assertEqual('application/json', content_type)
self.assertEqual(six.b('foo'), body)
def test_deserialize_badtype(self):
class Controller(object):
@ -396,7 +396,7 @@ class ResourceTest(test.TestCase):
resource = wsgi.Resource(controller, json=JSONDeserializer)
obj = resource.deserialize(controller.index, 'application/json', 'foo')
self.assertEqual(obj, 'json')
self.assertEqual('json', obj)
def test_deserialize_decorator(self):
class JSONDeserializer(object):
@ -411,7 +411,7 @@ class ResourceTest(test.TestCase):
resource = wsgi.Resource(controller, json=JSONDeserializer)
obj = resource.deserialize(controller.index, 'application/json', 'foo')
self.assertEqual(obj, 'json')
self.assertEqual('json', obj)
def test_register_actions(self):
class Controller(object):
@ -476,8 +476,8 @@ class ResourceTest(test.TestCase):
resource = wsgi.Resource(controller)
resource.register_extensions(extended)
method, extensions = resource.get_method(None, 'index', None, '')
self.assertEqual(method, controller.index)
self.assertEqual(extensions, [extended.index])
self.assertEqual(controller.index, method)
self.assertEqual([extended.index], extensions)
def test_get_method_action_extensions(self):
class Controller(wsgi.Controller):
@ -500,8 +500,8 @@ class ResourceTest(test.TestCase):
method, extensions = resource.get_method(None, 'action',
'application/json',
'{"fooAction": true}')
self.assertEqual(method, controller._action_foo)
self.assertEqual(extensions, [extended._action_foo])
self.assertEqual(controller._action_foo, method)
self.assertEqual([extended._action_foo], extensions)
def test_get_method_action_whitelist_extensions(self):
class Controller(wsgi.Controller):
@ -525,12 +525,12 @@ class ResourceTest(test.TestCase):
method, extensions = resource.get_method(None, 'create',
'application/json',
'{"create": true}')
self.assertEqual(method, extended._create)
self.assertEqual(extensions, [])
self.assertEqual(extended._create, method)
self.assertEqual([], extensions)
method, extensions = resource.get_method(None, 'delete', None, None)
self.assertEqual(method, extended._delete)
self.assertEqual(extensions, [])
self.assertEqual(extended._delete, method)
self.assertEqual([], extensions)
def test_pre_process_extensions_regular(self):
class Controller(object):
@ -552,9 +552,9 @@ class ResourceTest(test.TestCase):
extensions = [extension1, extension2]
response, post = resource.pre_process_extensions(extensions, None, {})
self.assertEqual(called, [])
self.assertEqual(response, None)
self.assertEqual(list(post), [extension2, extension1])
self.assertEqual([], called)
self.assertIsNone(response)
self.assertEqual([extension2, extension1], list(post))
def test_pre_process_extensions_generator(self):
class Controller(object):
@ -579,9 +579,9 @@ class ResourceTest(test.TestCase):
extensions = [extension1, extension2]
response, post = resource.pre_process_extensions(extensions, None, {})
post = list(post)
self.assertEqual(called, ['pre1', 'pre2'])
self.assertEqual(response, None)
self.assertEqual(len(post), 2)
self.assertEqual(['pre1', 'pre2'], called)
self.assertIsNone(response)
self.assertEqual(2, len(post))
self.assertTrue(inspect.isgenerator(post[0]))
self.assertTrue(inspect.isgenerator(post[1]))
@ -591,7 +591,7 @@ class ResourceTest(test.TestCase):
except StopIteration:
continue
self.assertEqual(called, ['pre1', 'pre2', 'post2', 'post1'])
self.assertEqual(['pre1', 'pre2', 'post2', 'post1'], called)
def test_pre_process_extensions_generator_response(self):
class Controller(object):
@ -612,9 +612,9 @@ class ResourceTest(test.TestCase):
extensions = [extension1, extension2]
response, post = resource.pre_process_extensions(extensions, None, {})
self.assertEqual(called, ['pre1'])
self.assertEqual(response, 'foo')
self.assertEqual(post, [])
self.assertEqual(['pre1'], called)
self.assertEqual('foo', response)
self.assertEqual([], post)
def test_post_process_extensions_regular(self):
class Controller(object):
@ -636,8 +636,8 @@ class ResourceTest(test.TestCase):
response = resource.post_process_extensions([extension2, extension1],
None, None, {})
self.assertEqual(called, [2, 1])
self.assertEqual(response, None)
self.assertEqual([2, 1], called)
self.assertIsNone(response)
def test_post_process_extensions_regular_response(self):
class Controller(object):
@ -659,8 +659,30 @@ class ResourceTest(test.TestCase):
response = resource.post_process_extensions([extension2, extension1],
None, None, {})
self.assertEqual(called, [2])
self.assertEqual(response, 'foo')
self.assertEqual([2], called)
self.assertEqual('foo', response)
def test_post_process_extensions_version_not_found(self):
class Controller(object):
def index(self, req, pants=None):
return pants
controller = Controller()
resource = wsgi.Resource(controller)
called = []
def extension1(req, resp_obj):
called.append(1)
return 'bar'
def extension2(req, resp_obj):
raise exception.VersionNotFoundForAPIMethod(version='fake_version')
response = resource.post_process_extensions([extension2, extension1],
None, None, {})
self.assertEqual([1], called)
self.assertEqual('bar', response)
def test_post_process_extensions_generator(self):
class Controller(object):
@ -688,8 +710,8 @@ class ResourceTest(test.TestCase):
response = resource.post_process_extensions([ext2, ext1],
None, None, {})
self.assertEqual(called, [2, 1])
self.assertEqual(response, None)
self.assertEqual([2, 1], called)
self.assertIsNone(response)
def test_post_process_extensions_generator_response(self):
class Controller(object):
@ -718,38 +740,38 @@ class ResourceTest(test.TestCase):
response = resource.post_process_extensions([ext2, ext1],
None, None, {})
self.assertEqual(called, [2])
self.assertEqual(response, 'foo')
self.assertEqual([2], called)
self.assertEqual('foo', response)
class ResponseObjectTest(test.TestCase):
def test_default_code(self):
robj = wsgi.ResponseObject({})
self.assertEqual(robj.code, 200)
self.assertEqual(200, robj.code)
def test_modified_code(self):
robj = wsgi.ResponseObject({})
robj._default_code = 202
self.assertEqual(robj.code, 202)
self.assertEqual(202, robj.code)
def test_override_default_code(self):
robj = wsgi.ResponseObject({}, code=404)
self.assertEqual(robj.code, 404)
self.assertEqual(404, robj.code)
def test_override_modified_code(self):
robj = wsgi.ResponseObject({}, code=404)
robj._default_code = 202
self.assertEqual(robj.code, 404)
self.assertEqual(404, robj.code)
def test_set_header(self):
robj = wsgi.ResponseObject({})
robj['Header'] = 'foo'
self.assertEqual(robj.headers, {'header': 'foo'})
self.assertEqual({'header': 'foo'}, robj.headers)
def test_get_header(self):
robj = wsgi.ResponseObject({})
robj['Header'] = 'foo'
self.assertEqual(robj['hEADER'], 'foo')
self.assertEqual('foo', robj['hEADER'])
def test_del_header(self):
robj = wsgi.ResponseObject({})
@ -762,22 +784,22 @@ class ResponseObjectTest(test.TestCase):
robj['Header'] = 'foo'
hdrs = robj.headers
hdrs['hEADER'] = 'bar'
self.assertEqual(robj['hEADER'], 'foo')
self.assertEqual('foo', robj['hEADER'])
def test_default_serializers(self):
robj = wsgi.ResponseObject({})
self.assertEqual(robj.serializers, {})
self.assertEqual({}, robj.serializers)
def test_bind_serializers(self):
robj = wsgi.ResponseObject({}, json='foo')
robj._bind_method_serializers(dict(xml='bar', json='baz'))
self.assertEqual(robj.serializers, dict(xml='bar', json='foo'))
self.assertEqual(dict(xml='bar', json='foo'), robj.serializers)
def test_get_serializer(self):
robj = wsgi.ResponseObject({}, json='json', xml='xml', atom='atom')
for content_type, mtype in wsgi._MEDIA_TYPE_MAP.items():
_mtype, serializer = robj.get_serializer(content_type)
self.assertEqual(serializer, mtype)
self.assertEqual(mtype, serializer)
def test_get_serializer_defaults(self):
robj = wsgi.ResponseObject({})
@ -787,7 +809,7 @@ class ResponseObjectTest(test.TestCase):
robj.get_serializer, content_type)
_mtype, serializer = robj.get_serializer(content_type,
default_serializers)
self.assertEqual(serializer, mtype)
self.assertEqual(mtype, serializer)
def test_serialize(self):
class JSONSerializer(object):
@ -813,11 +835,11 @@ class ResponseObjectTest(test.TestCase):
request = wsgi.Request.blank('/tests/123')
response = robj.serialize(request, content_type)
self.assertEqual(response.headers['Content-Type'], content_type)
self.assertEqual(response.headers['X-header1'], 'header1')
self.assertEqual(response.headers['X-header2'], 'header2')
self.assertEqual(response.status_int, 202)
self.assertEqual(response.body, six.b(mtype))
self.assertEqual(content_type, response.headers['Content-Type'])
self.assertEqual('header1', response.headers['X-header1'])
self.assertEqual('header2', response.headers['X-header2'])
self.assertEqual(202, response.status_int)
self.assertEqual(six.b(mtype), response.body)
class ValidBodyTest(test.TestCase):
@ -831,19 +853,15 @@ class ValidBodyTest(test.TestCase):
self.assertTrue(self.controller.is_valid_body(body, 'foo'))
def test_is_valid_body_none(self):
resource = wsgi.Resource(controller=None)
self.assertFalse(self.controller.is_valid_body(None, 'foo'))
def test_is_valid_body_empty(self):
resource = wsgi.Resource(controller=None)
self.assertFalse(self.controller.is_valid_body({}, 'foo'))
def test_is_valid_body_no_entity(self):
resource = wsgi.Resource(controller=None)
body = {'bar': {}}
self.assertFalse(self.controller.is_valid_body(body, 'foo'))
def test_is_valid_body_malformed_entity(self):
resource = wsgi.Resource(controller=None)
body = {'foo': 'bar'}
self.assertFalse(self.controller.is_valid_body(body, 'foo'))

View File

@ -1,119 +0,0 @@
# Copyright 2011 Denali Systems, Inc.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from oslo_config import cfg
from oslo_log import log
from manila.api.openstack import wsgi
from manila.api.v1 import router
from manila.api import versions
from manila import test
from manila.tests.api import fakes
CONF = cfg.CONF
LOG = log.getLogger(__name__)
class FakeController(object):
def __init__(self, ext_mgr=None):
self.ext_mgr = ext_mgr
def index(self, req):
return {}
def detail(self, req):
return {}
def create_resource(ext_mgr):
return wsgi.Resource(FakeController(ext_mgr))
class VolumeRouterTestCase(test.TestCase):
def setUp(self):
super(VolumeRouterTestCase, self).setUp()
# NOTE(vish): versions is just returning text so, no need to stub.
self.app = router.APIRouter()
def test_versions(self):
req = fakes.HTTPRequest.blank('')
req.method = 'GET'
req.content_type = 'application/json'
response = req.get_response(self.app)
self.assertEqual(302, response.status_int)
req = fakes.HTTPRequest.blank('/')
req.method = 'GET'
req.content_type = 'application/json'
response = req.get_response(self.app)
self.assertEqual(200, response.status_int)
def test_versions_multi(self):
req = fakes.HTTPRequest.blank('/')
req.method = 'GET'
req.content_type = 'application/json'
resource = versions.Versions()
result = resource.dispatch(resource.multi, req, {})
ids = [v['id'] for v in result['choices']]
self.assertEqual(set(ids), set(['v1.0', 'v2.0']))
def test_versions_multi_disable_v1(self):
self.flags(enable_v1_api=False)
req = fakes.HTTPRequest.blank('/')
req.method = 'GET'
req.content_type = 'application/json'
resource = versions.Versions()
result = resource.dispatch(resource.multi, req, {})
ids = [v['id'] for v in result['choices']]
self.assertEqual(set(ids), set(['v2.0']))
def test_versions_multi_disable_v2(self):
self.flags(enable_v2_api=False)
req = fakes.HTTPRequest.blank('/')
req.method = 'GET'
req.content_type = 'application/json'
resource = versions.Versions()
result = resource.dispatch(resource.multi, req, {})
ids = [v['id'] for v in result['choices']]
self.assertEqual(set(ids), set(['v1.0']))
def test_versions_index(self):
req = fakes.HTTPRequest.blank('/')
req.method = 'GET'
req.content_type = 'application/json'
resource = versions.Versions()
result = resource.dispatch(resource.index, req, {})
ids = [v['id'] for v in result['versions']]
self.assertEqual(set(ids), set(['v1.0', 'v2.0']))
def test_versions_index_disable_v1(self):
self.flags(enable_v1_api=False)
req = fakes.HTTPRequest.blank('/')
req.method = 'GET'
req.content_type = 'application/json'
resource = versions.Versions()
result = resource.dispatch(resource.index, req, {})
ids = [v['id'] for v in result['versions']]
self.assertEqual(set(ids), set(['v2.0']))
def test_versions_index_disable_v2(self):
self.flags(enable_v2_api=False)
req = fakes.HTTPRequest.blank('/')
req.method = 'GET'
req.content_type = 'application/json'
resource = versions.Versions()
result = resource.dispatch(resource.index, req, {})
ids = [v['id'] for v in result['versions']]
self.assertEqual(set(ids), set(['v1.0']))

View File

@ -0,0 +1,156 @@
# Copyright 2015 Clinton Knight
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import ddt
import mock
from oslo_serialization import jsonutils
from manila.api.openstack import api_version_request
from manila.api.openstack import wsgi
from manila.api.v1 import router
from manila.api import versions
from manila import test
from manila.tests.api import fakes
@ddt.ddt
class VersionsControllerTestCase(test.TestCase):
version_header_name = 'X-OpenStack-Manila-API-Version'
def setUp(self):
super(VersionsControllerTestCase, self).setUp()
self.wsgi_apps = (versions.VersionsRouter(), router.APIRouter())
@ddt.data(('', 302), ('/', 200))
@ddt.unpack
def test_versions_return_codes(self, request_path, return_code):
req = fakes.HTTPRequest.blank(request_path)
req.method = 'GET'
req.content_type = 'application/json'
for app in self.wsgi_apps:
response = req.get_response(app)
self.assertEqual(return_code, response.status_int)
@ddt.data(
('http://localhost/', True),
(None, True),
('http://localhost/', False),
(None, False),
)
@ddt.unpack
def test_versions_index_v10(self, base_url, include_header):
req = fakes.HTTPRequest.blank('/', base_url=base_url)
req.method = 'GET'
req.content_type = 'application/json'
if include_header:
req.headers = {self.version_header_name: '1.0'}
for app in self.wsgi_apps:
response = req.get_response(app)
body = jsonutils.loads(response.body)
version_list = body['versions']
ids = [v['id'] for v in version_list]
self.assertEqual({'v1.0'}, set(ids))
self.assertEqual('1.0', response.headers[self.version_header_name])
self.assertEqual(self.version_header_name,
response.headers['Vary'])
self.assertIsNone(version_list[0].get('min_version'))
self.assertIsNone(version_list[0].get('version'))
@ddt.data(
('http://localhost/', '1.1'),
(None, '1.1'),
('http://localhost/', 'latest'),
(None, 'latest')
)
@ddt.unpack
def test_versions_index_v11(self, base_url, req_version):
req = fakes.HTTPRequest.blank('/', base_url=base_url)
req.method = 'GET'
req.content_type = 'application/json'
req.headers = {self.version_header_name: req_version}
for app in self.wsgi_apps:
response = req.get_response(app)
body = jsonutils.loads(response.body)
version_list = body['versions']
ids = [v['id'] for v in version_list]
self.assertEqual({'v1.0'}, set(ids))
if req_version == 'latest':
self.assertEqual(api_version_request._MAX_API_VERSION,
response.headers[self.version_header_name])
else:
self.assertEqual(req_version,
response.headers[self.version_header_name])
self.assertEqual(self.version_header_name,
response.headers['Vary'])
self.assertEqual(api_version_request._MIN_API_VERSION,
version_list[0].get('min_version'))
self.assertEqual(api_version_request._MAX_API_VERSION,
version_list[0].get('version'))
@ddt.data('http://localhost/', None)
def test_versions_index_v2(self, base_url):
req = fakes.HTTPRequest.blank('/', base_url=base_url)
req.method = 'GET'
req.content_type = 'application/json'
req.headers = {self.version_header_name: '2.0'}
for app in self.wsgi_apps:
response = req.get_response(app)
self.assertEqual(406, response.status_int)
self.assertEqual('2.0', response.headers[self.version_header_name])
self.assertEqual(self.version_header_name,
response.headers['Vary'])
@ddt.data('http://localhost/', None)
def test_versions_index_invalid_version_request(self, base_url):
req = fakes.HTTPRequest.blank('/', base_url=base_url)
req.method = 'GET'
req.content_type = 'application/json'
req.headers = {self.version_header_name: '2.0.1'}
for app in self.wsgi_apps:
response = req.get_response(app)
self.assertEqual(400, response.status_int)
self.assertEqual('1.0', response.headers[self.version_header_name])
self.assertEqual(self.version_header_name,
response.headers['Vary'])
def test_versions_version_not_found(self):
api_version_request_3_0 = api_version_request.APIVersionRequest('3.0')
self.mock_object(api_version_request,
'max_api_version',
mock.Mock(return_value=api_version_request_3_0))
class Controller(wsgi.Controller):
@wsgi.Controller.api_version('1.0', '1.0')
def index(self, req):
return 'off'
req = fakes.HTTPRequest.blank('/tests')
req.headers = {self.version_header_name: '2.0'}
app = fakes.TestRouter(Controller())
response = req.get_response(app)
self.assertEqual(404, response.status_int)

View File

@ -0,0 +1,155 @@
# Copyright 2015 Clinton Knight
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import copy
import ddt
import mock
from manila.api.views import versions
from manila import test
class FakeRequest(object):
def __init__(self, application_url):
self.application_url = application_url
URL_BASE = 'http://localhost/'
FAKE_HREF = URL_BASE + 'v1/'
FAKE_VERSIONS = {
"v1.0": {
"id": "v1.0",
"status": "CURRENT",
"version": "1.1",
"min_version": "1.0",
"updated": "2015-07-30T11:33:21Z",
"links": [
{
"rel": "describedby",
"type": "text/html",
"href": 'http://docs.openstack.org/',
},
],
"media-types": [
{
"base": "application/json",
"type": "application/vnd.openstack.share+json;version=1",
}
],
},
}
FAKE_LINKS = [
{
"rel": "describedby",
"type": "text/html",
"href": 'http://docs.openstack.org/',
},
{
'rel': 'self',
'href': FAKE_HREF
},
]
@ddt.ddt
class ViewBuilderTestCase(test.TestCase):
def _get_builder(self):
request = FakeRequest('fake')
return versions.get_view_builder(request)
def test_build_versions(self):
self.mock_object(versions.ViewBuilder,
'_build_links',
mock.Mock(return_value=FAKE_LINKS))
result = self._get_builder().build_versions(FAKE_VERSIONS)
expected = {'versions': list(FAKE_VERSIONS.values())}
expected['versions'][0]['links'] = FAKE_LINKS
self.assertEqual(expected, result)
def test_build_version(self):
self.mock_object(versions.ViewBuilder,
'_build_links',
mock.Mock(return_value=FAKE_LINKS))
result = self._get_builder()._build_version(FAKE_VERSIONS['v1.0'])
expected = copy.deepcopy(FAKE_VERSIONS['v1.0'])
expected['links'] = FAKE_LINKS
self.assertEqual(expected, result)
def test_build_links(self):
self.mock_object(versions.ViewBuilder,
'_generate_href',
mock.Mock(return_value=FAKE_HREF))
result = self._get_builder()._build_links(FAKE_VERSIONS['v1.0'])
self.assertEqual(FAKE_LINKS, result)
def test_generate_href_defaults(self):
self.mock_object(versions.ViewBuilder,
'_get_base_url_without_version',
mock.Mock(return_value=URL_BASE))
result = self._get_builder()._generate_href()
self.assertEqual('http://localhost/v1/', result)
@ddt.data(
('v2', None, URL_BASE + 'v2/'),
('/v2/', None, URL_BASE + 'v2/'),
('/v2/', 'fake_path', URL_BASE + 'v2/fake_path'),
('/v2/', '/fake_path/', URL_BASE + 'v2/fake_path/'),
)
@ddt.unpack
def test_generate_href_no_path(self, version, path, expected):
self.mock_object(versions.ViewBuilder,
'_get_base_url_without_version',
mock.Mock(return_value=URL_BASE))
result = self._get_builder()._generate_href(version=version,
path=path)
self.assertEqual(expected, result)
@ddt.data(
('http://1.1.1.1/', 'http://1.1.1.1/'),
('http://localhost/', 'http://localhost/'),
('http://1.1.1.1/v1/', 'http://1.1.1.1/'),
('http://1.1.1.1/v1', 'http://1.1.1.1/'),
('http://1.1.1.1/v11', 'http://1.1.1.1/'),
)
@ddt.unpack
def test_get_base_url_without_version(self, base_url, base_url_no_version):
request = FakeRequest(base_url)
builder = versions.get_view_builder(request)
result = builder._get_base_url_without_version()
self.assertEqual(base_url_no_version, result)

View File

@ -602,6 +602,56 @@ class IsValidIPVersion(test.TestCase):
self.assertFalse(utils.is_valid_ip_address(addr, vers))
class Comparable(utils.ComparableMixin):
def __init__(self, value):
self.value = value
def _cmpkey(self):
return self.value
class TestComparableMixin(test.TestCase):
def setUp(self):
super(TestComparableMixin, self).setUp()
self.one = Comparable(1)
self.two = Comparable(2)
def test_lt(self):
self.assertTrue(self.one < self.two)
self.assertFalse(self.two < self.one)
self.assertFalse(self.one < self.one)
def test_le(self):
self.assertTrue(self.one <= self.two)
self.assertFalse(self.two <= self.one)
self.assertTrue(self.one <= self.one)
def test_eq(self):
self.assertFalse(self.one == self.two)
self.assertFalse(self.two == self.one)
self.assertTrue(self.one == self.one)
def test_ge(self):
self.assertFalse(self.one >= self.two)
self.assertTrue(self.two >= self.one)
self.assertTrue(self.one >= self.one)
def test_gt(self):
self.assertFalse(self.one > self.two)
self.assertTrue(self.two > self.one)
self.assertFalse(self.one > self.one)
def test_ne(self):
self.assertTrue(self.one != self.two)
self.assertTrue(self.two != self.one)
self.assertFalse(self.one != self.one)
def test_compare(self):
self.assertEqual(NotImplemented,
self.one._compare(1, self.one._cmpkey))
class TestRetryDecorator(test.TestCase):
def setUp(self):
super(TestRetryDecorator, self).setUp()

View File

@ -476,6 +476,34 @@ class IsAMatcher(object):
return isinstance(actual_value, self.expected_value)
class ComparableMixin(object):
def _compare(self, other, method):
try:
return method(self._cmpkey(), other._cmpkey())
except (AttributeError, TypeError):
# _cmpkey not implemented, or return different type,
# so I can't compare with "other".
return NotImplemented
def __lt__(self, other):
return self._compare(other, lambda s, o: s < o)
def __le__(self, other):
return self._compare(other, lambda s, o: s <= o)
def __eq__(self, other):
return self._compare(other, lambda s, o: s == o)
def __ge__(self, other):
return self._compare(other, lambda s, o: s >= o)
def __gt__(self, other):
return self._compare(other, lambda s, o: s > o)
def __ne__(self, other):
return self._compare(other, lambda s, o: s != o)
def retry(exception, interval=1, retries=10, backoff_rate=2):
"""A wrapper around retrying library.