Release 1.1.0

Change-Id: I7d49059a98ed36825f4be3324b165d196cf66474
This commit is contained in:
Anthony Michon 2015-04-15 11:44:34 +02:00
parent 24406d820a
commit f6176f3fd7
161 changed files with 19527 additions and 46 deletions

202
LICENSE-2.0.txt Normal file
View File

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -1,19 +1,20 @@
# -*- coding: utf-8 -*-
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import pbr.version
__version__ = pbr.version.VersionInfo(
'cerberus').version_string()

27
cerberus/api/__init__.py Normal file
View File

@ -0,0 +1,27 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from oslo.config import cfg
from cerberus.openstack.common._i18n import _ # noqa
keystone_opts = [
cfg.StrOpt('auth_strategy', default='keystone',
help=_('The strategy to use for authentication.'))
]
CONF = cfg.CONF
CONF.register_opts(keystone_opts)

112
cerberus/api/app.py Normal file
View File

@ -0,0 +1,112 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pecan
from wsgiref import simple_server
from oslo.config import cfg
import auth
from cerberus.api import config as api_config
from cerberus.api import hooks
from cerberus.openstack.common import log as logging
LOG = logging.getLogger(__name__)
auth_opts = [
cfg.StrOpt('api_paste_config',
default="api_paste.ini",
help="Configuration file for WSGI definition of API."
),
]
api_opts = [
cfg.StrOpt('host_ip',
default="0.0.0.0",
help="Host serving the API."
),
cfg.IntOpt('port',
default=8300,
help="Host port serving the API."
),
]
CONF = cfg.CONF
CONF.register_opts(auth_opts)
CONF.register_opts(api_opts, group='api')
def get_pecan_config():
# Set up the pecan configuration
filename = api_config.__file__.replace('.pyc', '.py')
return pecan.configuration.conf_from_file(filename)
def setup_app(pecan_config=None, extra_hooks=None):
if not pecan_config:
pecan_config = get_pecan_config()
app_hooks = [hooks.ConfigHook(),
hooks.DBHook(),
hooks.ContextHook(pecan_config.app.acl_public_routes),
hooks.NoExceptionTracebackHook()]
if pecan_config.app.enable_acl:
app_hooks.append(hooks.AuthorizationHook(
pecan_config.app.member_routes))
pecan.configuration.set_config(dict(pecan_config), overwrite=True)
app = pecan.make_app(
pecan_config.app.root,
static_root=pecan_config.app.static_root,
template_path=pecan_config.app.template_path,
debug=CONF.debug,
force_canonical=getattr(pecan_config.app, 'force_canonical', True),
hooks=app_hooks,
guess_content_type_from_ext=False
)
if pecan_config.app.enable_acl:
strategy = auth.strategy(CONF.auth_strategy)
return strategy.install(app,
cfg.CONF,
pecan_config.app.acl_public_routes)
return app
def build_server():
# Create the WSGI server and start it
host = CONF.api.host_ip
port = CONF.api.port
server_cls = simple_server.WSGIServer
handler_cls = simple_server.WSGIRequestHandler
pecan_config = get_pecan_config()
pecan_config.app.enable_acl = (CONF.auth_strategy == 'keystone')
app = setup_app(pecan_config=pecan_config)
srv = simple_server.make_server(
host,
port,
app,
server_cls,
handler_cls)
return srv

62
cerberus/api/auth.py Normal file
View File

@ -0,0 +1,62 @@
#
# Copyright (c) 2015 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from cerberus.api.middleware import auth_token
from cerberus.openstack.common import log
STRATEGIES = {}
LOG = log.getLogger(__name__)
OPT_GROUP_NAME = 'keystone_authtoken'
class KeystoneAuth(object):
@classmethod
def _register_opts(cls, conf):
"""Register keystoneclient middleware options."""
if OPT_GROUP_NAME not in conf:
conf.register_opts(auth_token.opts, group=OPT_GROUP_NAME)
auth_token.CONF = conf
@classmethod
def install(cls, app, conf, public_routes):
"""Install Auth check on application."""
LOG.debug(u'Installing Keystone\'s auth protocol')
cls._register_opts(conf)
conf = dict(conf.get(OPT_GROUP_NAME))
return auth_token.AuthTokenMiddleware(app,
conf=conf,
public_api_routes=public_routes)
STRATEGIES['keystone'] = KeystoneAuth
def strategy(strategy):
"""Returns the Auth Strategy.
:param strategy: String representing
the strategy to use
"""
try:
return STRATEGIES[strategy]
except KeyError:
raise RuntimeError

23
cerberus/api/config.py Normal file
View File

@ -0,0 +1,23 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# Pecan Application Configurations
app = {
'root': 'cerberus.api.root.RootController',
'modules': ['cerberus.api'],
'static_root': '%(confdir)s/public',
'template_path': '%(confdir)s/templates',
'debug': True,
'enable_acl': False,
'acl_public_routes': ['/', '/v1'],
'member_routes': ['/v1/security_reports', ]
}

147
cerberus/api/hooks.py Normal file
View File

@ -0,0 +1,147 @@
#
# Copyright (c) 2015 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from oslo.config import cfg
from pecan import hooks
from webob import exc
from cerberus.common import context
from cerberus.common import policy
from cerberus.db import api as dbapi
class ConfigHook(hooks.PecanHook):
"""Attach the config object to the request so controllers can get to it."""
def before(self, state):
state.request.cfg = cfg.CONF
class DBHook(hooks.PecanHook):
"""Attach the dbapi object to the request so controllers can get to it."""
def before(self, state):
state.request.dbapi = dbapi.get_instance()
class ContextHook(hooks.PecanHook):
"""Configures a request context and attaches it to the request.
The following HTTP request headers are used:
X-User-Id or X-User:
Used for context.user_id.
X-Tenant-Id or X-Tenant:
Used for context.tenant.
X-Auth-Token:
Used for context.auth_token.
X-Roles:
Used for setting context.is_admin flag to either True or False.
The flag is set to True, if X-Roles contains either an administrator
or admin substring. Otherwise it is set to False.
"""
def __init__(self, public_api_routes):
self.public_api_routes = public_api_routes
super(ContextHook, self).__init__()
def before(self, state):
user_id = state.request.headers.get('X-User-Id')
user_id = state.request.headers.get('X-User', user_id)
tenant_id = state.request.headers.get('X-Tenant-Id')
tenant = state.request.headers.get('X-Tenant', tenant_id)
domain_id = state.request.headers.get('X-User-Domain-Id')
domain_name = state.request.headers.get('X-User-Domain-Name')
auth_token = state.request.headers.get('X-Auth-Token')
roles = state.request.headers.get('X-Roles', '').split(',')
creds = {'roles': roles}
is_public_api = state.request.environ.get('is_public_api', False)
is_admin = policy.enforce('context_is_admin',
state.request.headers,
creds)
state.request.context = context.RequestContext(
auth_token=auth_token,
user=user_id,
tenant_id=tenant_id,
tenant=tenant,
domain_id=domain_id,
domain_name=domain_name,
is_admin=is_admin,
is_public_api=is_public_api,
roles=roles)
class AuthorizationHook(hooks.PecanHook):
"""Verify that the user has admin rights.
Checks whether the request context is an admin context and
rejects the request if the api is not public.
"""
def __init__(self, member_routes):
self.member_routes = member_routes
super(AuthorizationHook, self).__init__()
def before(self, state):
ctx = state.request.context
if not ctx.is_admin and not ctx.is_public_api and \
state.request.path not in self.member_routes:
raise exc.HTTPForbidden()
class NoExceptionTracebackHook(hooks.PecanHook):
"""Workaround rpc.common: deserialize_remote_exception.
deserialize_remote_exception builds rpc exception traceback into error
message which is then sent to the client. Such behavior is a security
concern so this hook is aimed to cut-off traceback from the error message.
"""
# NOTE(max_lobur): 'after' hook used instead of 'on_error' because
# 'on_error' never fired for wsme+pecan pair. wsme @wsexpose decorator
# catches and handles all the errors, so 'on_error' dedicated for unhandled
# exceptions never fired.
def after(self, state):
# Omit empty body. Some errors may not have body at this level yet.
if not state.response.body:
return
# Do nothing if there is no error.
if 200 <= state.response.status_int < 400:
return
json_body = state.response.json
# Do not remove traceback when server in debug mode (except 'Server'
# errors when 'debuginfo' will be used for traces).
if cfg.CONF.debug and json_body.get('faultcode') != 'Server':
return
faultsting = json_body.get('faultstring')
traceback_marker = 'Traceback (most recent call last):'
if faultsting and (traceback_marker in faultsting):
# Cut-off traceback.
faultsting = faultsting.split(traceback_marker, 1)[0]
# Remove trailing newlines and spaces if any.
json_body['faultstring'] = faultsting.rstrip()
# Replace the whole json. Cannot change original one beacause it's
# generated on the fly.
state.response.json = json_body

View File

@ -0,0 +1,20 @@
#
# Copyright (c) 2015 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from cerberus.api.middleware import auth_token
AuthTokenMiddleware = auth_token.AuthTokenMiddleware

View File

@ -0,0 +1,61 @@
#
# Copyright (c) 2015 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from keystoneclient.middleware import auth_token
from cerberus.common import exception
from cerberus.common import safe_utils
from cerberus.openstack.common import log
LOG = log.getLogger(__name__)
class AuthTokenMiddleware(auth_token.AuthProtocol):
"""A wrapper on Keystone auth_token middleware.
Does not perform verification of authentication tokens
for public routes in the API.
"""
def __init__(self, app, conf, public_api_routes=[]):
route_pattern_tpl = '%s(\.json|\.xml)?$'
try:
self.public_api_routes = [re.compile(route_pattern_tpl % route_tpl)
for route_tpl in public_api_routes]
except re.error as e:
msg = _('Cannot compile public API routes: %s') % e
LOG.error(msg)
raise exception.ConfigInvalid(error_msg=msg)
super(AuthTokenMiddleware, self).__init__(app, conf)
def __call__(self, env, start_response):
path = safe_utils.safe_rstrip(env.get('PATH_INFO'), '/')
# The information whether the API call is being performed against the
# public API is required for some other components. Saving it to the
# WSGI environment is reasonable thereby.
env['is_public_api'] = any(map(lambda pattern: re.match(pattern, path),
self.public_api_routes))
if env['is_public_api']:
return self.app(env, start_response)
return super(AuthTokenMiddleware, self).__call__(env, start_response)

140
cerberus/api/root.py Normal file
View File

@ -0,0 +1,140 @@
#
# Copyright (c) 2015 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pecan
from pecan import rest
from wsme import types as wtypes
import wsmeext.pecan as wsme_pecan
from cerberus.api.v1 import controllers as v1_api
from cerberus.openstack.common import log as logging
LOG = logging.getLogger(__name__)
VERSION_STATUS = wtypes.Enum(wtypes.text, 'EXPERIMENTAL', 'STABLE')
class APILink(wtypes.Base):
"""API link description.
"""
type = wtypes.text
"""Type of link."""
rel = wtypes.text
"""Relationship with this link."""
href = wtypes.text
"""URL of the link."""
@classmethod
def sample(cls):
version = 'v1'
sample = cls(
rel='self',
type='text/html',
href='http://127.0.0.1:8888/{id}'.format(
id=version))
return sample
class APIMediaType(wtypes.Base):
"""Media type description.
"""
base = wtypes.text
"""Base type of this media type."""
type = wtypes.text
"""Type of this media type."""
@classmethod
def sample(cls):
sample = cls(
base='application/json',
type='application/vnd.openstack.sticks-v1+json')
return sample
class APIVersion(wtypes.Base):
"""API Version description.
"""
id = wtypes.text
"""ID of the version."""
status = VERSION_STATUS
"""Status of the version."""
updated = wtypes.text
"Last update in iso8601 format."
links = [APILink]
"""List of links to API resources."""
media_types = [APIMediaType]
"""Types accepted by this API."""
@classmethod
def sample(cls):
version = 'v1'
updated = '2014-08-11T16:00:00Z'
links = [APILink.sample()]
media_types = [APIMediaType.sample()]
sample = cls(id=version,
status='STABLE',
updated=updated,
links=links,
media_types=media_types)
return sample
class RootController(rest.RestController):
"""Root REST Controller exposing versions of the API.
"""
v1 = v1_api.V1Controller()
@wsme_pecan.wsexpose([APIVersion])
def get(self):
"""Return the version list
"""
# TODO(sheeprine): Maybe we should store all the API version
# informations in every API modules
ver1 = APIVersion(
id='v1',
status='EXPERIMENTAL',
updated='2015-03-09T16:00:00Z',
links=[
APILink(
rel='self',
href='{scheme}://{host}/v1'.format(
scheme=pecan.request.scheme,
host=pecan.request.host,
)
)
],
media_types=[]
)
versions = []
versions.append(ver1)
return versions

View File

View File

@ -0,0 +1,31 @@
#
# Copyright (c) 2015 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pecan import rest
from cerberus.api.v1.controllers import alerts as alerts_api
from cerberus.api.v1.controllers import plugins as plugins_api
from cerberus.api.v1.controllers import security_reports as \
security_reports_api
from cerberus.api.v1.controllers import tasks as tasks_api
class V1Controller(rest.RestController):
"""API version 1 controller. """
alerts = alerts_api.AlertsController()
plugins = plugins_api.PluginsController()
security_reports = security_reports_api.SecurityReportsController()
tasks = tasks_api.TasksController()

View File

@ -0,0 +1,33 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pecan import rest
from oslo.config import cfg
from oslo import messaging
from cerberus.openstack.common import log
LOG = log.getLogger(__name__)
class BaseController(rest.RestController):
def __init__(self):
transport = messaging.get_transport(cfg.CONF)
target = messaging.Target(topic='test_rpc', server='server1')
self.client = messaging.RPCClient(transport, target)

View File

@ -0,0 +1,51 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
JSON_HOME = {
'resources': {
# -----------------------------------------------------------------
# Plugins
# -----------------------------------------------------------------
'rel/plugins': {
'href-template': '/v1/plugins{?id}',
'href-vars': {
'method_name': 'param/method_name',
'task_name': 'param/task_name',
'task_type': 'param/task_type',
'task_period': 'param/task_period',
},
'hints': {
'allow': ['GET', 'POST'],
'formats': {
'application/json': {},
},
},
}
}
}
class Resource(object):
def __init__(self):
document = json.dumps(JSON_HOME, ensure_ascii=False, indent=4)
self.document_utf8 = document.encode('utf-8')
def on_get(self, req, resp):
resp.data = self.document_utf8
resp.content_type = 'application/json-home'
resp.cache_control = ['max-age=86400']

View File

@ -0,0 +1,149 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import pecan
from webob import exc
from oslo import messaging
from cerberus.api.v1.controllers import base
from cerberus.common import errors
from cerberus import db
from cerberus.db.sqlalchemy import models
from cerberus.openstack.common import log
LOG = log.getLogger(__name__)
_ENFORCER = None
class PluginsController(base.BaseController):
def list_plugins(self):
""" List all the plugins installed on system """
# Get information about plugins loaded by Cerberus
try:
plugins = self._plugins()
except messaging.RemoteError as e:
LOG.exception(e)
raise
try:
# Get information about plugins stored in db
db_plugins_info = db.plugins_info_get()
except Exception as e:
LOG.exception(e)
raise
plugins_info = []
for plugin_info in db_plugins_info:
plugins_info.append(models.PluginInfoJsonSerializer().
serialize(plugin_info))
plugins_full_info = []
for plugin in plugins:
for plugin_info in plugins_info:
if (plugin.get('name') == plugin_info.get('name')):
plugins_full_info.append(dict(plugin.items()
+ plugin_info.items()))
return plugins_full_info
def _plugins(self):
""" Get a list of plugins loaded by Cerberus Manager """
ctx = pecan.request.context.to_dict()
try:
plugins = self.client.call(ctx, 'get_plugins')
except messaging.RemoteError as e:
LOG.exception(e)
raise
plugins_ = []
for plugin in plugins:
plugin_ = json.loads(plugin)
plugins_.append(plugin_)
return plugins_
@pecan.expose("json")
def get_all(self):
""" Get a list of plugins loaded by Cerberus manager
:return: a list of plugins loaded by Cerberus manager
:raises:
HTTPServiceUnavailable: an error occurred in Cerberus Manager or
the service is unavailable
HTTPNotFound: any other error
"""
# Get information about plugins loaded by Cerberus
try:
plugins = self.list_plugins()
except messaging.RemoteError:
raise exc.HTTPServiceUnavailable()
except Exception as e:
LOG.exception(e)
raise exc.HTTPNotFound()
return {'plugins': plugins}
def get_plugin(self, uuid):
""" Get information about plugin loaded by Cerberus"""
try:
plugin = self._plugin(uuid)
except messaging.RemoteError:
raise
except errors.PluginNotFound:
raise
try:
# Get information about plugin stored in db
db_plugin_info = db.plugin_info_get_from_uuid(uuid)
plugin_info = models.PluginInfoJsonSerializer().\
serialize(db_plugin_info)
except Exception as e:
LOG.exception(e)
raise
return dict(plugin_info.items() + plugin.items())
def _plugin(self, uuid):
""" Get a specific plugin thanks to its identifier """
ctx = pecan.request.context.to_dict()
try:
plugin = self.client.call(ctx, 'get_plugin_from_uuid', uuid=uuid)
except messaging.RemoteError as e:
LOG.exception(e)
raise
if plugin is None:
LOG.exception('Plugin %s not found.' % uuid)
raise errors.PluginNotFound(uuid)
return json.loads(plugin)
@pecan.expose("json")
def get_one(self, uuid):
""" Get details of a specific plugin whose identifier is uuid
:param uuid: the identifier of the plugin
:return: details of a specific plugin
:raises:
HTTPServiceUnavailable: an error occurred in Cerberus Manager or
the service is unavailable
HTTPNotFound: Plugin is not found. Also any other error
"""
try:
plugin = self.get_plugin(uuid)
except messaging.RemoteError:
raise exc.HTTPServiceUnavailable()
except errors.PluginNotFound:
raise exc.HTTPNotFound()
except Exception as e:
LOG.exception(e)
raise exc.HTTPNotFound()
return {'plugin': plugin}

View File

@ -0,0 +1,88 @@
#
# Copyright (c) 2015 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pecan
from webob import exc
from cerberus.api.v1.controllers import base
from cerberus.common import errors
from cerberus import db
from cerberus.db.sqlalchemy import models
from cerberus.openstack.common import log
LOG = log.getLogger(__name__)
class SecurityReportsController(base.BaseController):
def list_security_reports(self, project_id=None):
""" List all the security reports of all projects or just one. """
try:
security_reports = db.security_report_get_all(
project_id=project_id)
except Exception as e:
LOG.exception(e)
raise errors.DbError(
"Security reports could not be retrieved"
)
return security_reports
@pecan.expose("json")
def get_all(self):
""" Get stored security reports.
:return: list of security reports for one or all projects depending on
context of the token.
"""
ctx = pecan.request.context
try:
if ctx.is_admin:
security_reports = self.list_security_reports()
else:
security_reports = self.list_security_reports(ctx.tenant_id)
except errors.DbError:
raise exc.HTTPNotFound()
json_security_reports = []
for security_report in security_reports:
json_security_reports.append(models.SecurityReportJsonSerializer().
serialize(security_report))
return {'security_reports': json_security_reports}
def get_security_report(self, id):
try:
security_report = db.security_report_get(id)
except Exception as e:
LOG.exception(e)
raise errors.DbError(
"Security report %s could not be retrieved" % id
)
return security_report
@pecan.expose("json")
def get_one(self, id):
"""
Get security reports in db
:param req: the HTTP request
:param resp: the HTTP response
:return:
"""
try:
security_report = self.get_security_report(id)
except errors.DbError:
raise exc.HTTPNotFound()
s_report = models.SecurityReportJsonSerializer().\
serialize(security_report)
return {'security_report': s_report}

View File

@ -0,0 +1,312 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import pecan
from webob import exc
from wsme import types as wtypes
from oslo.messaging import rpc
from cerberus.api.v1.controllers import base
from cerberus.common import errors
from cerberus.openstack.common import log
LOG = log.getLogger(__name__)
action_kind = ["stop", "restart", "force_delete"]
action_kind_enum = wtypes.Enum(str, *action_kind)
class Task(wtypes.Base):
""" Representation of a task.
"""
name = wtypes.text
period = wtypes.IntegerType()
method = wtypes.text
plugin_id = wtypes.text
type = wtypes.text
class TasksController(base.BaseController):
@pecan.expose()
def _lookup(self, task_id, *remainder):
return TaskController(task_id), remainder
def list_tasks(self):
ctx = pecan.request.context.to_dict()
try:
tasks = self.client.call(ctx, 'get_tasks')
except rpc.RemoteError as e:
LOG.exception(e)
raise
tasks_ = []
for task in tasks:
task_ = json.loads(task)
tasks_.append(task_)
return tasks_
@pecan.expose("json")
def get(self):
""" List tasks
:return: list of tasks
:raises:
HTTPBadRequest
"""
try:
tasks = self.list_tasks()
except rpc.RemoteError:
raise exc.HTTPServiceUnavailable()
return {'tasks': tasks}
def create_task(self, body):
ctx = pecan.request.context.to_dict()
task = body.get('task', None)
if task is None:
LOG.exception("Task object not provided in request")
raise errors.TaskObjectNotProvided()
plugin_id = task.get('plugin_id', None)
if plugin_id is None:
LOG.exception("Plugin id not provided in request")
raise errors.PluginIdNotProvided()
method_ = task.get('method', None)
if method_ is None:
LOG.exception("Method not provided in request")
raise errors.MethodNotProvided()
try:
task['id'] = self.client.call(
ctx,
'add_task',
uuid=plugin_id,
method_=method_,
task_period=task.get('period', None),
task_name=task.get('name', "unknown"),
task_type=task.get('type', "unique")
)
except rpc.RemoteError as e:
LOG.exception(e)
raise
return task
@pecan.expose("json")
def post(self):
"""Ask Cerberus Manager to call a function of a plugin whose identifier
is uuid, either once or periodically.
:return:
:raises:
HTTPBadRequest: the request is not correct
"""
body_ = pecan.request.body
try:
body = json.loads(body_.decode('utf-8'))
except (ValueError, UnicodeDecodeError) as e:
LOG.exception(e)
raise exc.HTTPBadRequest()
try:
task = self.create_task(body)
except errors.TaskObjectNotProvided:
raise exc.HTTPBadRequest(
explanation='The task object is required.')
except errors.PluginIdNotProvided:
raise exc.HTTPBadRequest(
explanation='Plugin id must be provided as a string')
except errors.MethodNotProvided:
raise exc.HTTPBadRequest(
explanation='Method must be provided as a string')
except rpc.RemoteError as e:
LOG.exception(e)
raise exc.HTTPBadRequest(explanation=e.value)
except Exception as e:
LOG.exception(e)
raise exc.HTTPBadRequest()
return {'task': task}
class TaskController(base.BaseController):
"""Manages operation on a single task."""
_custom_actions = {
'action': ['POST']
}
def __init__(self, task_id):
super(TaskController, self).__init__()
pecan.request.context['task_id'] = task_id
try:
self._id = int(task_id)
except ValueError:
raise exc.HTTPBadRequest(
explanation='Task id must be an integer')
def get_task(self, id):
ctx = pecan.request.context.to_dict()
try:
task = self.client.call(ctx, 'get_task', id=int(id))
except ValueError as e:
LOG.exception(e)
raise
except rpc.RemoteError as e:
LOG.exception(e)
raise
return json.loads(task)
@pecan.expose("json")
def get(self):
""" Get details of a task whose id is id
:param id: the id of the task
:return:
:raises:
HTTPBadRequest
"""
try:
task = self.get_task(self._id)
except ValueError:
raise exc.HTTPBadRequest(
explanation='Task id must be an integer')
except rpc.RemoteError:
raise exc.HTTPNotFound()
except Exception as e:
LOG.exception(e)
raise
return {'task': task}
@pecan.expose("json")
def post(self):
"""
Enable to perform certain actions on a specific task (e.g; stop it)
:param req: the HTTP request, including the action to perform
:param resp: the HTTP response, including a description and the task id
:param id: the identifier of the task on which an action has to be
performed
:return:
:raises:
HTTPError: Incorrect JSON or not UTF-8 encoded
HTTPBadRequest: id not integer or task does not exist
"""
body_ = pecan.request.body
try:
body = json.loads(body_.decode('utf-8'))
except (ValueError, UnicodeDecodeError) as e:
LOG.exception(e)
raise exc.HTTPBadRequest()
if 'stop' in body:
try:
self.stop_task(self._id)
except ValueError:
raise exc.HTTPBadRequest(
explanation="Task id must be an integer")
except rpc.RemoteError:
raise exc.HTTPBadRequest(
explanation="Task can not be stopped")
elif 'forceDelete' in body:
try:
self.force_delete(self._id)
except ValueError:
raise exc.HTTPBadRequest(
explanation="Task id must be an integer")
except rpc.RemoteError as e:
raise exc.HTTPBadRequest(explanation=e.value)
elif 'restart' in body:
try:
self.restart(self._id)
except ValueError:
raise exc.HTTPBadRequest(
explanation="Task id must be an integer")
except rpc.RemoteError as e:
raise exc.HTTPBadRequest(explanation=e.value)
else:
raise exc.HTTPBadRequest()
def stop_task(self, id):
ctx = pecan.request.context.to_dict()
try:
self.client.call(ctx, 'stop_task', id=int(id))
except ValueError as e:
LOG.exception(e)
raise
except rpc.RemoteError as e:
LOG.exception(e)
raise
def force_delete(self, id):
ctx = pecan.request.context.to_dict()
try:
self.client.call(ctx,
'force_delete_recurrent_task',
id=int(id))
except ValueError as e:
LOG.exception(e)
raise
except rpc.RemoteError as e:
LOG.exception(e)
raise
def restart(self, id):
ctx = pecan.request.context.to_dict()
try:
self.client.call(ctx,
'restart_recurrent_task',
id=int(id))
except ValueError as e:
LOG.exception(e)
raise
except rpc.RemoteError as e:
LOG.exception(e)
raise
def delete_task(self, id):
ctx = pecan.request.context.to_dict()
try:
self.client.call(ctx, 'delete_recurrent_task', id=int(id))
except ValueError as e:
LOG.exception(e)
raise
except rpc.RemoteError as e:
LOG.exception(e)
raise
@pecan.expose("json")
def delete(self):
"""
Delete a task specified by its identifier. If the task is running, it
has to be stopped.
:param req: the HTTP request
:param resp: the HTTP response, including a description and the task id
:param id: the identifier of the task to be deleted
:return:
:raises:
HTTPBadRequest: id not an integer or task can't be deleted
"""
try:
self.delete_task(self._id)
except ValueError:
raise exc.HTTPBadRequest(explanation="Task id must be an integer")
except rpc.RemoteError as e:
raise exc.HTTPBadRequest(explanation=e.value)
except Exception as e:
LOG.exception(e)
raise

View File

@ -0,0 +1,15 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

View File

@ -0,0 +1,65 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
from keystoneclient.v2_0 import client as keystone_client_v2_0
from oslo.config import cfg
from cerberus.openstack.common import log
cfg.CONF.import_group('service_credentials', 'cerberus.service')
LOG = log.getLogger(__name__)
def logged(func):
@functools.wraps(func)
def with_logging(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
LOG.exception(e)
raise
return with_logging
class Client(object):
"""A client which gets information via python-keystoneclient."""
def __init__(self):
"""Initialize a keystone client object."""
conf = cfg.CONF.service_credentials
self.keystone_client_v2_0 = keystone_client_v2_0.Client(
username=conf.os_username,
password=conf.os_password,
tenant_name=conf.os_tenant_name,
auth_url=conf.os_auth_url,
region_name=conf.os_region_name,
)
@logged
def user_detail_get(self, user):
"""Returns details for a user."""
return self.keystone_client_v2_0.users.get(user)
@logged
def roles_for_user(self, user, tenant=None):
"""Returns role for a given id."""
return self.keystone_client_v2_0.roles.roles_for_user(user, tenant)

View File

@ -0,0 +1,109 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
from neutronclient.v2_0 import client as neutron_client
from oslo.config import cfg
from cerberus.openstack.common import log
cfg.CONF.import_group('service_credentials', 'cerberus.service')
LOG = log.getLogger(__name__)
def logged(func):
@functools.wraps(func)
def with_logging(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
LOG.exception(e)
raise
return with_logging
class Client(object):
"""A client which gets information via python-neutronclient."""
def __init__(self):
"""Initialize a neutron client object."""
conf = cfg.CONF.service_credentials
self.neutronClient = neutron_client.Client(
username=conf.os_username,
password=conf.os_password,
tenant_name=conf.os_tenant_name,
auth_url=conf.os_auth_url,
)
@logged
def list_networks(self, tenant_id):
"""Returns the list of networks of a given tenant"""
return self.neutronClient.list_networks(
tenant_id=tenant_id).get("networks", None)
@logged
def list_floatingips(self, tenant_id):
"""Returns the list of networks of a given tenant"""
return self.neutronClient.list_floatingips(
tenant_id=tenant_id).get("floatingips", None)
@logged
def list_associated_floatingips(self, **params):
"""Returns the list of associated floating ips of a given tenant"""
floating_ips = self.neutronClient.list_floatingips(
**params).get("floatingips", None)
# A floating IP is an IP address on an external network, which is
# associated with a specific port, and optionally a specific IP
# address, on a private OpenStack Networking network. Therefore a
# floating IP allows access to an instance on a private network from an
# external network. Floating IPs can only be defined on networks for
# which the attribute router:external (by the external network
# extension) has been set to True.
associated_floating_ips = []
for floating_ip in floating_ips:
if floating_ip.get("port_id") is not None:
associated_floating_ips.append(floating_ip)
return associated_floating_ips
@logged
def net_ips_get(self, network_id):
"""
Return ip pools used in all subnets of a network
:param network_id:
:return: list of pools
"""
subnets = self.neutronClient.show_network(
network_id)["network"]["subnets"]
ips = []
for subnet in subnets:
ips.append(self.subnet_ips_get(subnet))
return ips
@logged
def get_net_of_subnet(self, subnet_id):
return self.neutronClient.show_subnet(
subnet_id)["subnet"]["network_id"]
@logged
def subnet_ips_get(self, network_id):
"""Returns ip pool of a subnet."""
return self.neutronClient.show_subnet(
network_id)["subnet"]["allocation_pools"]

View File

@ -0,0 +1,109 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
from novaclient.v3 import client as nova_client
from oslo.config import cfg
from cerberus.openstack.common import log
OPTS = [
cfg.BoolOpt('nova_http_log_debug',
default=False,
help='Allow novaclient\'s debug log output.'),
]
SERVICE_OPTS = [
cfg.StrOpt('nova',
default='compute',
help='Nova service type.'),
]
cfg.CONF.register_opts(OPTS)
cfg.CONF.register_opts(SERVICE_OPTS, group='service_types')
# cfg.CONF.import_opt('http_timeout', 'cerberus.service')
cfg.CONF.import_group('service_credentials', 'cerberus.service')
LOG = log.getLogger(__name__)
def logged(func):
@functools.wraps(func)
def with_logging(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
LOG.exception(e)
raise
return with_logging
class Client(object):
"""A client which gets information via python-novaclient."""
def __init__(self, bypass_url=None, auth_token=None):
"""Initialize a nova client object."""
conf = cfg.CONF.service_credentials
tenant = conf.os_tenant_id or conf.os_tenant_name
self.nova_client = nova_client.Client(
username=conf.os_username,
project_id=tenant,
auth_url=conf.os_auth_url,
password=conf.os_password,
region_name=conf.os_region_name,
endpoint_type=conf.os_endpoint_type,
service_type=cfg.CONF.service_types.nova,
bypass_url=bypass_url,
cacert=conf.os_cacert,
insecure=conf.insecure,
http_log_debug=cfg.CONF.nova_http_log_debug,
no_cache=True)
@logged
def instance_get_all(self):
"""Returns list of all instances."""
search_opts = {'all_tenants': True}
return self.nova_client.servers.list(
detailed=True,
search_opts=search_opts)
@logged
def get_instance_details_from_floating_ip(self, ip):
"""
Get instance_id which is associated to the floating ip "ip"
:param ip: the floating ip that should belong to an instance
:return instance_id if ip is found, else None
"""
instances = self.instance_get_all()
try:
for instance in instances:
# An instance can belong to many networks. An instance can
# have two ips in a network:
# at least a private ip and potentially a floating ip
addresses_in_networks = instance.addresses.values()
for addresses_in_network in addresses_in_networks:
for address_in_network in addresses_in_network:
if ((address_in_network.get('OS-EXT-IPS:type', None)
== 'floating')
and (address_in_network['addr'] == ip)):
return instance
except Exception as e:
LOG.exception(e)
raise
return None

15
cerberus/cmd/__init__.py Normal file
View File

@ -0,0 +1,15 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

46
cerberus/cmd/agent.py Normal file
View File

@ -0,0 +1,46 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import eventlet
import sys
from oslo.config import cfg
from cerberus.common import config
from cerberus import manager
from cerberus.openstack.common import log
from cerberus.openstack.common import service
eventlet.monkey_patch()
LOG = log.getLogger(__name__)
def main():
log.set_defaults(cfg.CONF.default_log_levels)
argv = sys.argv
config.parse_args(argv)
log.setup(cfg.CONF, 'cerberus')
launcher = service.ProcessLauncher()
c_manager = manager.CerberusManager()
launcher.launch_service(c_manager)
launcher.wait()
if __name__ == '__main__':
main()

45
cerberus/cmd/api.py Normal file
View File

@ -0,0 +1,45 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
from oslo.config import cfg
from cerberus.api import app
from cerberus.common import config
from cerberus.openstack.common import log
CONF = cfg.CONF
CONF.import_opt('auth_strategy', 'cerberus.api')
LOG = log.getLogger(__name__)
def main():
argv = sys.argv
config.parse_args(argv)
log.setup(cfg.CONF, 'cerberus')
server = app.build_server()
log.set_defaults(cfg.CONF.default_log_levels)
try:
server.serve_forever()
except KeyboardInterrupt:
pass
LOG.info("cerberus-api starting...")
if __name__ == '__main__':
main()

43
cerberus/cmd/db_create.py Normal file
View File

@ -0,0 +1,43 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
from oslo.config import cfg
from sqlalchemy import create_engine
from cerberus.common import config
from cerberus.db.sqlalchemy import models
def main():
argv = sys.argv
config.parse_args(argv)
engine = create_engine(cfg.CONF.database.connection)
conn = engine.connect()
try:
conn.execute("CREATE DATABASE cerberus")
except Exception:
pass
models.BASE.metadata.create_all(engine)
conn.close()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,15 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

26
cerberus/common/config.py Normal file
View File

@ -0,0 +1,26 @@
#
# Copyright (c) 2015 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from oslo.config import cfg
from cerberus import version
def parse_args(argv, default_config_files=None):
cfg.CONF(argv[1:],
project='cerberus',
version=version.version_info.release_string(),
default_config_files=default_config_files)

View File

@ -0,0 +1,64 @@
# -*- encoding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from cerberus.openstack.common import context
class RequestContext(context.RequestContext):
"""Extends security contexts from the OpenStack common library."""
def __init__(self, auth_token=None, domain_id=None, domain_name=None,
user=None, tenant_id=None, tenant=None, is_admin=False,
is_public_api=False, read_only=False, show_deleted=False,
request_id=None, roles=None):
"""Stores several additional request parameters:
:param domain_id: The ID of the domain.
:param domain_name: The name of the domain.
:param is_public_api: Specifies whether the request should be processed
without authentication.
"""
self.tenant_id = tenant_id
self.is_public_api = is_public_api
self.domain_id = domain_id
self.domain_name = domain_name
self.roles = roles or []
super(RequestContext, self).__init__(auth_token=auth_token,
user=user, tenant=tenant,
is_admin=is_admin,
read_only=read_only,
show_deleted=show_deleted,
request_id=request_id)
def to_dict(self):
return {'auth_token': self.auth_token,
'user': self.user,
'tenant_id': self.tenant_id,
'tenant': self.tenant,
'is_admin': self.is_admin,
'read_only': self.read_only,
'show_deleted': self.show_deleted,
'request_id': self.request_id,
'domain_id': self.domain_id,
'roles': self.roles,
'domain_name': self.domain_name,
'is_public_api': self.is_public_api}
@classmethod
def from_dict(cls, values):
values.pop('user', None)
values.pop('tenant', None)
return cls(**values)

124
cerberus/common/errors.py Normal file
View File

@ -0,0 +1,124 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from cerberus.openstack.common._i18n import _ # noqa
class InvalidOperation(Exception):
def __init__(self, description):
super(InvalidOperation, self).__init__(description)
class PluginNotFound(InvalidOperation):
def __init__(self, uuid):
super(PluginNotFound, self).__init__("Plugin %s does not exist"
% str(uuid))
class TaskPeriodNotInteger(InvalidOperation):
def __init__(self):
super(TaskPeriodNotInteger, self).__init__(
"The period of the task must be provided as an integer"
)
class TaskNotFound(InvalidOperation):
def __init__(self, _id):
super(TaskNotFound, self).__init__(
_('Task %s does not exist') % _id
)
class TaskDeletionNotAllowed(InvalidOperation):
def __init__(self, _id):
super(TaskDeletionNotAllowed, self).__init__(
_("Deletion of task %s is not allowed because either it "
"does not exist or it is not recurrent") % _id
)
class TaskRestartNotAllowed(InvalidOperation):
def __init__(self, _id):
super(TaskRestartNotAllowed, self).__init__(
_("Restarting task %s is not allowed because either it "
"does not exist or it is not recurrent") % _id
)
class TaskRestartNotPossible(InvalidOperation):
def __init__(self, _id):
super(TaskRestartNotPossible, self).__init__(
_("Restarting task %s is not possible because it is running") % _id
)
class MethodNotString(InvalidOperation):
def __init__(self):
super(MethodNotString, self).__init__(
"Method must be provided as a string"
)
class MethodNotCallable(InvalidOperation):
def __init__(self, method, name):
super(MethodNotCallable, self).__init__(
"Method named %s is not callable by plugin %s"
% (str(method), str(name))
)
class TaskObjectNotProvided(InvalidOperation):
def __init__(self):
super(TaskObjectNotProvided, self).__init__(
"Task object not provided in request"
)
class PluginIdNotProvided(InvalidOperation):
def __init__(self):
super(PluginIdNotProvided, self).__init__(
"Plugin id not provided in request"
)
class MethodNotProvided(InvalidOperation):
def __init__(self):
super(MethodNotProvided, self).__init__(
"Method not provided in request"
)
class PolicyEnforcementError(Exception):
def __init__(self):
super(PolicyEnforcementError, self).__init__(
"Policy enforcement error"
)
class DbError(Exception):
def __init__(self, description):
super(DbError, self).__init__(description)

View File

@ -0,0 +1,154 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Cerberus base exception handling.
Includes decorator for re-raising Nova-type exceptions.
SHOULD include dedicated exception logging.
"""
import functools
import gettext as t
import logging
import sys
import webob.exc
from oslo.config import cfg
from cerberus.common import safe_utils
from cerberus.openstack.common import excutils
_ = t.gettext
LOG = logging.getLogger(__name__)
exc_log_opts = [
cfg.BoolOpt('fatal_exception_format_errors',
default=False,
help='Make exception message format errors fatal'),
]
CONF = cfg.CONF
CONF.register_opts(exc_log_opts)
class ConvertedException(webob.exc.WSGIHTTPException):
def __init__(self, code=0, title="", explanation=""):
self.code = code
self.title = title
self.explanation = explanation
super(ConvertedException, self).__init__()
def _cleanse_dict(original):
"""Strip all admin_password, new_pass, rescue_pass keys from a dict."""
return dict((k, v) for k, v in original.iteritems() if "_pass" not in k)
def wrap_exception(notifier=None, get_notifier=None):
"""This decorator wraps a method to catch any exceptions that may
get thrown. It logs the exception as well as optionally sending
it to the notification system.
"""
def inner(f):
def wrapped(self, context, *args, **kw):
# Don't store self or context in the payload, it now seems to
# contain confidential information.
try:
return f(self, context, *args, **kw)
except Exception as e:
with excutils.save_and_reraise_exception():
if notifier or get_notifier:
payload = dict(exception=e)
call_dict = safe_utils.getcallargs(f, context,
*args, **kw)
cleansed = _cleanse_dict(call_dict)
payload.update({'args': cleansed})
# If f has multiple decorators, they must use
# functools.wraps to ensure the name is
# propagated.
event_type = f.__name__
(notifier or get_notifier()).error(context,
event_type,
payload)
return functools.wraps(f)(wrapped)
return inner
class CerberusException(Exception):
"""Base Cerberus Exception
To correctly use this class, inherit from it and define
a 'msg_fmt' property. That msg_fmt will get printf'd
with the keyword arguments provided to the constructor.
"""
msg_fmt = _("An unknown exception occurred.")
code = 500
headers = {}
safe = False
def __init__(self, message=None, **kwargs):
self.kwargs = kwargs
if 'code' not in self.kwargs:
try:
self.kwargs['code'] = self.code
except AttributeError:
pass
if not message:
try:
message = self.msg_fmt % kwargs
except Exception:
exc_info = sys.exc_info()
# kwargs doesn't match a variable in the message
# log the issue and the kwargs
LOG.exception(_('Exception in string format operation'))
for name, value in kwargs.iteritems():
LOG.error("%s: %s" % (name, value)) # noqa
if CONF.fatal_exception_format_errors:
raise exc_info[0], exc_info[1], exc_info[2]
else:
# at least get the core message out if something happened
message = self.msg_fmt
super(CerberusException, self).__init__(message)
def format_message(self):
# NOTE(mrodden): use the first argument to the python Exception object
# which should be our full NovaException message, (see __init__)
return self.args[0]
class AlertExists(CerberusException):
msg_fmt = _("Alert %(alert_id)s already exists.")
class ReportExists(CerberusException):
msg_fmt = _("Report %(report_id)s already exists.")
class PluginInfoExists(CerberusException):
msg_fmt = _("Plugin info %(id)s already exists.")

View File

@ -0,0 +1,26 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import json
class DateTimeEncoder(json.JSONEncoder):
def default(self, obj):
"""JSON serializer for objects not serializable by default json code"""
if isinstance(obj, datetime.datetime):
serial = obj.isoformat()
return serial

67
cerberus/common/policy.py Normal file
View File

@ -0,0 +1,67 @@
# Copyright (c) 2011 OpenStack Foundation
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Policy Engine For Cerberus."""
from oslo.config import cfg
from cerberus.openstack.common import policy
_ENFORCER = None
CONF = cfg.CONF
def init_enforcer(policy_file=None, rules=None,
default_rule=None, use_conf=True):
"""Synchronously initializes the policy enforcer
:param policy_file: Custom policy file to use, if none is specified,
`CONF.policy_file` will be used.
:param rules: Default dictionary / Rules to use. It will be
considered just in the first instantiation.
:param default_rule: Default rule to use, CONF.default_rule will
be used if none is specified.
:param use_conf: Whether to load rules from config file.
"""
global _ENFORCER
if _ENFORCER:
return
_ENFORCER = policy.Enforcer(policy_file=policy_file,
rules=rules,
default_rule=default_rule,
use_conf=use_conf)
def get_enforcer():
"""Provides access to the single instance of Policy enforcer."""
if not _ENFORCER:
init_enforcer()
return _ENFORCER
def enforce(rule, target, creds, do_raise=False, exc=None, *args, **kwargs):
"""A shortcut for policy.Enforcer.enforce()
Checks authorization of a rule against the target and credentials.
"""
enforcer = get_enforcer()
return enforcer.enforce(rule, target, creds, do_raise=do_raise,
exc=exc, *args, **kwargs)

View File

@ -0,0 +1,70 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Utilities and helper functions that won't produce circular imports."""
import inspect
import six
from cerberus.openstack.common import log
LOG = log.getLogger(__name__)
def getcallargs(function, *args, **kwargs):
"""This is a simplified inspect.getcallargs (2.7+).
It should be replaced when python >= 2.7 is standard.
"""
keyed_args = {}
argnames, varargs, keywords, defaults = inspect.getargspec(function)
keyed_args.update(kwargs)
# NOTE(alaski) the implicit 'self' or 'cls' argument shows up in
# argnames but not in args or kwargs. Uses 'in' rather than '==' because
# some tests use 'self2'.
if 'self' in argnames[0] or 'cls' == argnames[0]:
# The function may not actually be a method or have im_self.
# Typically seen when it's stubbed with mox.
if inspect.ismethod(function) and hasattr(function, 'im_self'):
keyed_args[argnames[0]] = function.im_self
else:
keyed_args[argnames[0]] = None
remaining_argnames = filter(lambda x: x not in keyed_args, argnames)
keyed_args.update(dict(zip(remaining_argnames, args)))
if defaults:
num_defaults = len(defaults)
for argname, value in zip(argnames[-num_defaults:], defaults):
if argname not in keyed_args:
keyed_args[argname] = value
return keyed_args
def safe_rstrip(value, chars=None):
"""Removes trailing characters from a string if that does not make it empty
:param value: A string value that will be stripped.
:param chars: Characters to remove.
:return: Stripped value.
"""
if not isinstance(value, six.string_types):
LOG.warn(("Failed to remove trailing character. Returning original "
"object. Supplied object is not a string: %s,") % value)
return value
return value.rstrip(chars) or value

View File

@ -0,0 +1,110 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import parser
class JsonSerializer(object):
"""A serializer that provides methods to serialize and deserialize JSON
dictionaries.
Note, one of the assumptions this serializer makes is that all objects that
it is used to deserialize have a constructor that can take all of the
attribute arguments. I.e. If you have an object with 3 attributes, the
constructor needs to take those three attributes as keyword arguments.
"""
__attributes__ = None
"""The attributes to be serialized by the seralizer.
The implementor needs to provide these."""
__required__ = None
"""The attributes that are required when deserializing.
The implementor needs to provide these."""
__attribute_serializer__ = None
"""The serializer to use for a specified attribute. If an attribute is not
included here, no special serializer will be user.
The implementor needs to provide these."""
__object_class__ = None
"""The class that the deserializer should generate.
The implementor needs to provide these."""
serializers = dict(
date=dict(
serialize=lambda x: x.isoformat(),
deserialize=lambda x: parser.parse(x)
)
)
def deserialize(self, json, **kwargs):
"""Deserialize a JSON dictionary and return a populated object.
This takes the JSON data, and deserializes it appropriately and then
calls the constructor of the object to be created with all of the
attributes.
Args:
json: The JSON dict with all of the data
**kwargs: Optional values that can be used as defaults if they are
not present in the JSON data
Returns:
The deserialized object.
Raises:
ValueError: If any of the required attributes are not present
"""
d = dict()
for attr in self.__attributes__:
if attr in json:
val = json[attr]
elif attr in self.__required__:
try:
val = kwargs[attr]
except KeyError:
raise ValueError("{} must be set".format(attr))
serializer = self.__attribute_serializer__.get(attr)
if serializer:
d[attr] = self.serializers[serializer]['deserialize'](val)
else:
d[attr] = val
return self.__object_class__(**d)
def serialize(self, obj):
"""Serialize an object to a dictionary.
Take all of the attributes defined in self.__attributes__ and create
a dictionary containing those values.
Args:
obj: The object to serialize
Returns:
A dictionary containing all of the serialized data from the object.
"""
d = dict()
for attr in self.__attributes__:
val = getattr(obj, attr)
if val is None:
continue
serializer = self.__attribute_serializer__.get(attr)
if serializer:
d[attr] = self.serializers[serializer]['serialize'](val)
else:
d[attr] = val
return d

17
cerberus/db/__init__.py Normal file
View File

@ -0,0 +1,17 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from cerberus.db.api import * # noqa

115
cerberus/db/api.py Normal file
View File

@ -0,0 +1,115 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from oslo.config import cfg
from cerberus.openstack.common.db import api as db_api
CONF = cfg.CONF
CONF.import_opt('backend', 'cerberus.openstack.common.db.options',
group='database')
_BACKEND_MAPPING = {'sqlalchemy': 'cerberus.db.sqlalchemy.api'}
IMPL = db_api.DBAPI(CONF.database.backend, backend_mapping=_BACKEND_MAPPING,
lazy=True)
''' JUNO:
IMPL = db_api.DBAPI.from_config(cfg.CONF,
backend_mapping=_BACKEND_MAPPING,
lazy=True)
'''
def get_instance():
"""Return a DB API instance."""
return IMPL
def get_engine():
return IMPL.get_engine()
def get_session():
return IMPL.get_session()
def db_sync(engine, version=None):
"""Migrate the database to `version` or the most recent version."""
return IMPL.db_sync(engine, version=version)
def alert_create(values):
"""Create an instance from the values dictionary."""
return IMPL.alert_create(values)
def alert_get_all():
"""Get all alerts"""
return IMPL.alert_get_all()
def security_report_create(values):
"""Create an instance from the values dictionary."""
return IMPL.security_report_create(values)
def security_report_update_last_report_date(id, date):
"""Create an instance from the values dictionary."""
return IMPL.security_report_update_last_report_date(id, date)
def security_report_get_all(project_id=None):
"""Get all alerts"""
return IMPL.security_report_get_all(project_id=project_id)
def security_report_get(id):
"""Get all alerts"""
return IMPL.security_report_get(id)
def security_report_get_from_report_id(report_id):
"""Get all alerts"""
return IMPL.security_report_get_from_report_id(report_id)
def plugins_info_get():
"""Get information about plugins stored in db"""
return IMPL.plugins_info_get()
def plugin_info_get_from_uuid(id):
"""
Get information about plugin stored in db
:param id: the uuid of the plugin
"""
return IMPL.plugin_info_get_from_uuid(id)
def plugin_version_update(id, version):
return IMPL.plugin_version_update(id, version)
def security_alarm_create(values):
return IMPL.security_alarm_create(values)
def security_alarm_get_all():
return IMPL.security_alarm_get_all()
def security_alarm_get(id):
return IMPL.security_alarm_get(id)

View File

@ -0,0 +1,15 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

View File

@ -0,0 +1,271 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sqlalchemy
import sys
import threading
from oslo.config import cfg
from cerberus.common import exception
from cerberus.db.sqlalchemy import migration
from cerberus.db.sqlalchemy import models
from cerberus.openstack.common.db import exception as db_exc
from cerberus.openstack.common.db.sqlalchemy import session as db_session
from cerberus.openstack.common import log
CONF = cfg.CONF
LOG = log.getLogger(__name__)
_ENGINE_FACADE = None
_LOCK = threading.Lock()
_FACADE = None
def _create_facade_lazily():
global _FACADE
if _FACADE is None:
_FACADE = db_session.EngineFacade(
CONF.database.connection,
**dict(CONF.database.iteritems())
)
return _FACADE
def get_engine():
facade = _create_facade_lazily()
return facade.get_engine()
def get_session(**kwargs):
facade = _create_facade_lazily()
return facade.get_session(**kwargs)
def get_backend():
"""The backend is this module itself."""
return sys.modules[__name__]
def model_query(model, *args, **kwargs):
"""Query helper for simpler session usage.
:param session: if present, the session to use
"""
session = kwargs.get('session') or get_session()
query = session.query(model, *args)
return query
def _alert_get_all(session=None):
session = get_session()
return model_query(models.Alert, read_deleted="no",
session=session).all()
def alert_create(values):
alert_ref = models.Alert()
alert_ref.update(values)
try:
alert_ref.save()
except db_exc.DBDuplicateEntry:
raise exception.AlertExists(id=values['id'])
return alert_ref
def alert_get_all():
return _alert_get_all()
def _security_report_get_all(project_id=None):
session = get_session()
try:
if project_id is None:
return model_query(models.SecurityReport, read_deleted="no",
session=session).all()
else:
return model_query(models.SecurityReport, read_deleted="no",
session=session).\
filter(models.SecurityReport.project_id == project_id).all()
except Exception as e:
LOG.exception(e)
raise e
def _security_report_get(id):
session = get_session()
return model_query(models.SecurityReport, read_deleted="no",
session=session).filter(models.SecurityReport.
id == id).first()
def _security_report_get_from_report_id(report_id):
session = get_session()
return model_query(models.SecurityReport, read_deleted="no",
session=session).filter(models.SecurityReport.report_id
== report_id).first()
def security_report_create(values):
security_report_ref = models.SecurityReport()
security_report_ref.update(values)
try:
security_report_ref.save()
except sqlalchemy.exc.OperationalError as e:
LOG.exception(e)
raise db_exc.ColumnError
return security_report_ref
def security_report_update_last_report_date(id, date):
session = get_session()
report = model_query(models.SecurityReport, read_deleted="no",
session=session).filter(models.SecurityReport.id
== id).first()
report.last_report_date = date
try:
report.save(session)
except sqlalchemy.exc.OperationalError as e:
LOG.exception(e)
raise db_exc.ColumnError
def security_report_get_all(project_id=None):
return _security_report_get_all(project_id=project_id)
def security_report_get(id):
return _security_report_get(id)
def security_report_get_from_report_id(report_id):
return _security_report_get_from_report_id(report_id)
def _plugin_info_get(name):
session = get_session()
return model_query(models.PluginInfo,
read_deleted="no",
session=session).filter(models.PluginInfo.name ==
name).first()
def _plugin_info_get_from_uuid(id):
session = get_session()
return model_query(models.PluginInfo,
read_deleted="no",
session=session).filter(models.PluginInfo.uuid ==
id).first()
def _plugins_info_get():
session = get_session()
return model_query(models.PluginInfo,
read_deleted="no",
session=session).all()
def plugin_info_create(values):
plugin_info_ref = models.PluginInfo()
plugin_info_ref.update(values)
try:
plugin_info_ref.save()
except db_exc.DBDuplicateEntry:
raise exception.PluginInfoExists(id=values['id'])
return plugin_info_ref
def plugin_info_get(name):
return _plugin_info_get(name)
def plugin_info_get_from_uuid(id):
return _plugin_info_get_from_uuid(id)
def plugins_info_get():
return _plugins_info_get()
def plugin_version_update(id, version):
session = get_session()
plugin = model_query(models.PluginInfo, read_deleted="no",
session=session).filter(models.PluginInfo.id ==
id).first()
plugin.version = version
try:
plugin.save(session)
except sqlalchemy.exc.OperationalError as e:
LOG.exception(e)
raise db_exc.ColumnError
def db_sync(engine, version=None):
"""Migrate the database to `version` or the most recent version."""
return migration.db_sync(engine, version=version)
def db_version(engine):
"""Display the current database version."""
return migration.db_version(engine)
def _security_alarm_get_all():
session = get_session()
try:
return model_query(models.SecurityAlarm, read_deleted="no",
session=session).all()
except Exception as e:
LOG.exception(e)
raise e
def _security_alarm_get(id):
session = get_session()
return model_query(models.SecurityAlarm, read_deleted="no",
session=session).filter(models.SecurityAlarm.
id == id).first()
def security_alarm_create(values):
security_alarm_ref = models.SecurityAlarm()
security_alarm_ref.update(values)
try:
security_alarm_ref.save()
except sqlalchemy.exc.OperationalError as e:
LOG.exception(e)
raise db_exc.ColumnError
return security_alarm_ref
def security_alarm_get_all():
return _security_alarm_get_all()
def security_alarm_get(id):
return _security_alarm_get(id)

View File

@ -0,0 +1,4 @@
This is a database migration repository.
More information at
http://code.google.com/p/sqlalchemy-migrate/

View File

@ -0,0 +1,5 @@
#!/usr/bin/env python
from migrate.versioning.shell import main
if __name__ == '__main__':
main(debug='False')

View File

@ -0,0 +1,25 @@
[db_settings]
# Used to identify which repository this database is versioned under.
# You can use the name of your project.
repository_id=cerberus
# The name of the database table used to track the schema version.
# This name shouldn't already be used by your project.
# If this is changed once a database is under version control, you'll need to
# change the table name in each database too.
version_table=migrate_version
# When committing a change script, Migrate will attempt to generate the
# sql for all supported databases; normally, if one of them fails - probably
# because you don't have that database installed - it is ignored and the
# commit continues, perhaps ending successfully.
# Databases in this list MUST compile successfully during a commit, or the
# entire commit will fail. List the databases your application will actually
# be using to ensure your updates to that database work properly.
# This must be a list; example: ['postgres','sqlite']
required_dbs=[]
# When creating new change scripts, Migrate will stamp the new script with
# a version number. By default this is latest_version + 1. You can set this
# to 'true' to tell Migrate to use the UTC timestamp instead.
use_timestamp_numbering=False

View File

@ -0,0 +1,102 @@
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import sqlalchemy
def upgrade(migrate_engine):
meta = sqlalchemy.MetaData()
meta.bind = migrate_engine
alert = sqlalchemy.Table(
'alert', meta,
sqlalchemy.Column('id', sqlalchemy.Integer, primary_key=True,
nullable=False),
sqlalchemy.Column('title', sqlalchemy.Text),
sqlalchemy.Column('status', sqlalchemy.Text),
sqlalchemy.Column('severity', sqlalchemy.Integer),
sqlalchemy.Column('acknowledged_at', sqlalchemy.DateTime),
sqlalchemy.Column('plugin_id', sqlalchemy.Text),
sqlalchemy.Column('description', sqlalchemy.Text),
sqlalchemy.Column('resource_id', sqlalchemy.Text),
sqlalchemy.Column('issue_link', sqlalchemy.Text),
sqlalchemy.Column('created_at', sqlalchemy.DateTime),
sqlalchemy.Column('updated_at', sqlalchemy.DateTime),
sqlalchemy.Column('deleted_at', sqlalchemy.DateTime),
sqlalchemy.Column('deleted', sqlalchemy.Integer),
mysql_engine='InnoDB',
mysql_charset='utf8'
)
plugin_info = sqlalchemy.Table(
'plugin_info', meta,
sqlalchemy.Column('id', sqlalchemy.Integer, primary_key=True,
nullable=False),
sqlalchemy.Column('uuid', sqlalchemy.Text),
sqlalchemy.Column('name', sqlalchemy.Text),
sqlalchemy.Column('version', sqlalchemy.Integer),
sqlalchemy.Column('provider', sqlalchemy.DateTime),
sqlalchemy.Column('type', sqlalchemy.Text),
sqlalchemy.Column('description', sqlalchemy.Text),
sqlalchemy.Column('tool_name', sqlalchemy.Text),
sqlalchemy.Column('created_at', sqlalchemy.DateTime),
sqlalchemy.Column('updated_at', sqlalchemy.DateTime),
sqlalchemy.Column('deleted_at', sqlalchemy.DateTime),
sqlalchemy.Column('deleted', sqlalchemy.Integer),
mysql_engine='InnoDB',
mysql_charset='utf8'
)
security_report = sqlalchemy.Table(
'security_report', meta,
sqlalchemy.Column('id', sqlalchemy.Integer, primary_key=True,
nullable=False),
sqlalchemy.Column('plugin_id', sqlalchemy.Text),
sqlalchemy.Column('report_id', sqlalchemy.Text, unique=True),
sqlalchemy.Column('component_id', sqlalchemy.Text),
sqlalchemy.Column('component_type', sqlalchemy.Text),
sqlalchemy.Column('component_name', sqlalchemy.Text),
sqlalchemy.Column('project_id', sqlalchemy.Text),
sqlalchemy.Column('title', sqlalchemy.Text),
sqlalchemy.Column('description', sqlalchemy.Text),
sqlalchemy.Column('security_rating', sqlalchemy.Float),
sqlalchemy.Column('vulnerabilities', sqlalchemy.Text),
sqlalchemy.Column('vulnerabilities_number', sqlalchemy.Integer),
sqlalchemy.Column('last_report_date', sqlalchemy.DateTime),
sqlalchemy.Column('created_at', sqlalchemy.DateTime),
sqlalchemy.Column('updated_at', sqlalchemy.DateTime),
sqlalchemy.Column('deleted_at', sqlalchemy.DateTime),
sqlalchemy.Column('deleted', sqlalchemy.Integer),
mysql_engine='InnoDB',
mysql_charset='utf8'
)
tables = (
security_report,
alert,
plugin_info,
)
for index, table in enumerate(tables):
try:
table.create()
except Exception:
# If an error occurs, drop all tables created so far to return
# to the previously existing state.
meta.drop_all(tables=tables[:index])
raise
def downgrade(migrate_engine):
raise NotImplementedError('Database downgrade not supported - '
'would drop all tables')

View File

@ -0,0 +1,38 @@
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
from cerberus.openstack.common.db.sqlalchemy import migration as oslo_migration
INIT_VERSION = 14
def db_sync(engine, version=None):
path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
'migrate_repo')
return oslo_migration.db_sync(engine, path, version,
init_version=INIT_VERSION)
def db_version(engine):
path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
'migrate_repo')
return oslo_migration.db_version(engine, path, INIT_VERSION)
def db_version_control(engine, version=None):
path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
'migrate_repo')
return oslo_migration.db_version_control(engine, path, version)

View File

@ -0,0 +1,171 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
SQLAlchemy models for cerberus data.
"""
from sqlalchemy import Column, String, Integer, DateTime, Float, Text
from sqlalchemy.ext.declarative import declarative_base
from oslo.config import cfg
from cerberus.common import serialize
from cerberus.openstack.common.db.sqlalchemy import models
CONF = cfg.CONF
BASE = declarative_base()
class CerberusBase(models.SoftDeleteMixin,
models.TimestampMixin,
models.ModelBase):
metadata = None
def save(self, session=None):
from cerberus.db.sqlalchemy import api
if session is None:
session = api.get_session()
super(CerberusBase, self).save(session=session)
class Alert(BASE, CerberusBase):
"""Security alert"""
__tablename__ = 'alert'
__table_args__ = ()
id = Column(Integer, primary_key=True)
title = Column(String(255))
status = Column(String(255))
severity = Column(Integer)
acknowledged_at = Column(DateTime)
plugin_id = Column(String(255))
description = Column(String(255))
resource_id = Column(String(255))
issue_link = Column(String(255))
class AlertJsonSerializer(serialize.JsonSerializer):
"""Alert serializer"""
__attributes__ = ['id', 'title', 'description', 'status', 'severity',
'created_at', 'deleted_at', 'updated_at',
'acknowledged_at', 'plugin_id', 'resource_id',
'issue_link', 'deleted']
__required__ = ['id', 'title']
__attribute_serializer__ = dict(created_at='date', deleted_at='date',
updated_at='date', acknowledged_at='date')
__object_class__ = Alert
class PluginInfo(BASE, CerberusBase):
"""Plugin info"""
__tablename__ = 'plugin_info'
__table_args__ = ()
id = Column(Integer, primary_key=True)
uuid = Column(String(255))
name = Column(String(255))
version = Column(String(255))
provider = Column(String(255))
type = Column(String(255))
description = Column(String(255))
tool_name = Column(String(255))
class PluginInfoJsonSerializer(serialize.JsonSerializer):
"""Plugin info serializer"""
__attributes__ = ['id', 'uuid', 'name', 'version', 'provider',
'type', 'description', 'tool_name']
__required__ = ['id']
__attribute_serializer__ = dict(created_at='date', deleted_at='date',
acknowledged_at='date')
__object_class__ = PluginInfo
class SecurityReport(BASE, CerberusBase):
"""Security Report"""
__tablename__ = 'security_report'
__table_args__ = ()
id = Column(Integer, primary_key=True)
plugin_id = Column(String(255))
report_id = Column(String(255), unique=True)
component_id = Column(String(255))
component_type = Column(String(255))
component_name = Column(String(255))
project_id = Column(String(255))
title = Column(String(255))
description = Column(String(255))
security_rating = Column(Float)
vulnerabilities = Column(Text)
vulnerabilities_number = Column(Integer)
last_report_date = Column(DateTime)
class SecurityReportJsonSerializer(serialize.JsonSerializer):
"""Security report serializer"""
__attributes__ = ['id', 'title', 'description', 'plugin_id', 'report_id',
'component_id', 'component_type', 'component_name',
'project_id', 'security_rating', 'vulnerabilities',
'vulnerabilities_number', 'last_report_date', 'deleted',
'created_at', 'deleted_at', 'updated_at']
__required__ = ['id', 'title', 'component_id']
__attribute_serializer__ = dict(created_at='date', deleted_at='date',
acknowledged_at='date')
__object_class__ = SecurityReport
class SecurityAlarm(BASE, CerberusBase):
"""Security alarm coming from Security Information and Event Manager
for example
"""
__tablename__ = 'security_alarm'
__table_args__ = ()
id = Column(Integer, primary_key=True)
plugin_id = Column(String(255))
alarm_id = Column(String(255), unique=True)
timestamp = Column(DateTime)
status = Column(String(255))
severity = Column(String(255))
component_id = Column(String(255))
summary = Column(String(255))
description = Column(String(255))
class SecurityAlarmJsonSerializer(serialize.JsonSerializer):
"""Security report serializer"""
__attributes__ = ['id', 'plugin_id', 'alarm_id', 'timestamp', 'status',
'severity', 'component_id', 'summary',
'project_id', 'security_rating', 'vulnerabilities',
'description', 'deleted', 'created_at', 'deleted_at',
'updated_at']
__required__ = ['id', 'title']
__attribute_serializer__ = dict(created_at='date', deleted_at='date',
acknowledged_at='date')
__object_class__ = SecurityAlarm

426
cerberus/manager.py Normal file
View File

@ -0,0 +1,426 @@
#
# Copyright (c) 2014 EUROGICIEL
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import uuid
from oslo.config import cfg
from oslo import messaging
from stevedore import extension
from cerberus.common import errors
from cerberus.db.sqlalchemy import api
from cerberus.openstack.common import log
from cerberus.openstack.common import loopingcall
from cerberus.openstack.common import service
from cerberus.openstack.common import threadgroup
from plugins import base
LOG = log.getLogger(__name__)
OPTS = [
cfg.MultiStrOpt('messaging_urls',
default=[],
help="Messaging URLs to listen for notifications. "
"Example: transport://user:pass@host1:port"
"[,hostN:portN]/virtual_host "
"(DEFAULT/transport_url is used if empty)"),
cfg.ListOpt('notification-topics', default=['designate']),
cfg.ListOpt('cerberus_control_exchange', default=['cerberus']),
]
cfg.CONF.register_opts(OPTS)
class CerberusManager(service.Service):
TASK_NAMESPACE = 'cerberus.plugins'
@classmethod
def _get_cerberus_manager(cls):
return extension.ExtensionManager(
namespace=cls.TASK_NAMESPACE,
invoke_on_load=True,
)
def __init__(self):
self.task_id = 0
super(CerberusManager, self).__init__()
def _register_plugin(self, extension):
# Record plugin in database
version = extension.entry_point.dist.version
plugin = extension.obj
db_plugin_info = api.plugin_info_get(plugin._name)
if db_plugin_info is None:
db_plugin_info = api.plugin_info_create({'name': plugin._name,
'uuid': uuid.uuid4(),
'version': version,
'provider':
plugin.PROVIDER,
'type': plugin.TYPE,
'description':
plugin.DESCRIPTION,
'tool_name':
plugin.TOOL_NAME
})
else:
api.plugin_version_update(db_plugin_info.id, version)
plugin._uuid = db_plugin_info.uuid
def start(self):
self.rpc_server = None
self.notification_server = None
super(CerberusManager, self).start()
targets = []
plugins = []
self.cerberus_manager = self._get_cerberus_manager()
if not list(self.cerberus_manager):
LOG.warning('Failed to load any task handlers for %s',
self.TASK_NAMESPACE)
for extension in self.cerberus_manager:
handler = extension.obj
LOG.debug('Plugin loaded: ' + extension.name)
LOG.debug(('Event types from %(name)s: %(type)s')
% {'name': extension.name,
'type': ', '.join(handler._subscribedEvents)})
self._register_plugin(extension)
handler.register_manager(self)
targets.extend(handler.get_targets(cfg.CONF))
plugins.append(handler)
transport = messaging.get_transport(cfg.CONF)
if transport:
rpc_target = messaging.Target(topic='test_rpc', server='server1')
self.rpc_server = messaging.get_rpc_server(transport, rpc_target,
[self],
executor='eventlet')
self.notification_server = messaging.get_notification_listener(
transport, targets, plugins, executor='eventlet')
LOG.info("RPC Server starting...")
self.rpc_server.start()
self.notification_server.start()
def _get_unique_task(self, id):
try:
unique_task = next(
thread for thread in self.tg.threads
if (thread.kw.get('task_id', None) == id))
except StopIteration:
return None
return unique_task
def _get_recurrent_task(self, id):
try:
recurrent_task = next(timer for timer in self.tg.timers if
(timer.kw.get('task_id', None) == id))
except StopIteration:
return None
return recurrent_task
def _add_unique_task(self, callback, *args, **kwargs):
"""
Add a simple task executing only once without delay
:param callback: Callable function to call when it's necessary
:param args: list of positional arguments to call the callback with
:param kwargs: dict of keyword arguments to call the callback with
:return the thread object that is created
"""
self.tg.add_thread(callback, *args, **kwargs)
def _add_recurrent_task(self, callback, period, initial_delay=None, *args,
**kwargs):
"""
Add a recurrent task executing periodically with or without an initial
delay
:param callback: Callable function to call when it's necessary
:param period: the time in seconds during two executions of the task
:param initial_delay: the time after the first execution of the task
occurs
:param args: list of positional arguments to call the callback with
:param kwargs: dict of keyword arguments to call the callback with
"""
self.tg.add_timer(period, callback, initial_delay, *args, **kwargs)
def get_plugins(self, ctx):
'''
This method is designed to be called by an rpc client.
E.g: Cerberus-api
It is used to get information about plugins
'''
json_plugins = []
for extension in self.cerberus_manager:
plugin = extension.obj
res = json.dumps(plugin, cls=base.PluginEncoder)
json_plugins.append(res)
return json_plugins
def _get_plugin_from_uuid(self, uuid):
for extension in self.cerberus_manager:
plugin = extension.obj
if (plugin._uuid == uuid):
return plugin
return None
def get_plugin_from_uuid(self, ctx, uuid):
plugin = self._get_plugin_from_uuid(uuid)
if plugin is not None:
return json.dumps(plugin, cls=base.PluginEncoder)
else:
return None
def add_task(self, ctx, uuid, method_, *args, **kwargs):
'''
This method is designed to be called by an rpc client.
E.g: Cerberus-api
It is used to call a method of a plugin back
:param ctx: a request context dict supplied by client
:param uuid: the uuid of the plugin to call method onto
:param method_: the method to call back
:param task_type: the type of task to create
:param args: some extra arguments
:param kwargs: some extra keyworded arguments
'''
self.task_id += 1
kwargs['task_id'] = self.task_id
kwargs['plugin_id'] = uuid
task_type = kwargs.get('task_type', "unique")
plugin = self._get_plugin_from_uuid(uuid)
if plugin is None:
raise errors.PluginNotFound(uuid)
if (task_type.lower() == 'recurrent'):
try:
task_period = int(kwargs.get('task_period', None))
except (TypeError, ValueError) as e:
LOG.exception(e)
raise errors.TaskPeriodNotInteger()
try:
self._add_recurrent_task(getattr(plugin, method_),
task_period,
*args,
**kwargs)
except TypeError as e:
LOG.exception(e)
raise errors.MethodNotString()
except AttributeError as e:
LOG.exception(e)
raise errors.MethodNotCallable(method_,
plugin.__class__.__name__)
else:
try:
self._add_unique_task(
getattr(plugin, method_),
*args,
**kwargs)
except TypeError as e:
LOG.exception(e)
raise errors.MethodNotString()
except AttributeError as e:
LOG.exception(e)
raise errors.MethodNotCallable(method_,
plugin.__class__.__name__)
return self.task_id
def _stop_recurrent_task(self, id):
"""
Stop the recurrent task but does not remove it from the ThreadGroup.
I.e, the task still exists and could be restarted
Plus, if the task is running, wait for the end of its execution
:param id: the id of the recurrent task to stop
:return:
:raises:
StopIteration: the task is not found
"""
recurrent_task = self._get_recurrent_task(id)
if recurrent_task is None:
raise errors.TaskNotFound(id)
recurrent_task.stop()
def _stop_unique_task(self, id):
unique_task = self._get_unique_task(id)
if unique_task is None:
raise errors.TaskNotFound(id)
unique_task.stop()
def _stop_task(self, id):
task = self._get_task(id)
if isinstance(task, loopingcall.FixedIntervalLoopingCall):
try:
self._stop_recurrent_task(id)
except errors.InvalidOperation:
raise
elif isinstance(task, threadgroup.Thread):
try:
self._stop_unique_task(id)
except errors.InvalidOperation:
raise
def stop_task(self, ctx, id):
try:
self._stop_task(id)
except errors.InvalidOperation:
raise
return id
def _delete_recurrent_task(self, id):
"""
Stop the task and delete the recurrent task from the ThreadGroup.
If the task is running, wait for the end of its execution
:param id: the identifier of the task to delete
:return:
"""
recurrent_task = self._get_recurrent_task(id)
if (recurrent_task is None):
raise errors.TaskDeletionNotAllowed(id)
recurrent_task.stop()
try:
self.tg.timers.remove(recurrent_task)
except ValueError:
raise
def delete_recurrent_task(self, ctx, id):
'''
This method is designed to be called by an rpc client.
E.g: Cerberus-api
Stop the task and delete the recurrent task from the ThreadGroup.
If the task is running, wait for the end of its execution
:param ctx: a request context dict supplied by client
:param id: the identifier of the task to delete
'''
try:
self._delete_recurrent_task(id)
except errors.InvalidOperation:
raise
return id
def _force_delete_recurrent_task(self, id):
"""
Stop the task even if it is running and delete the recurrent task from
the ThreadGroup.
:param id: the identifier of the task to force delete
:return:
"""
recurrent_task = self._get_recurrent_task(id)
if (recurrent_task is None):
raise errors.TaskDeletionNotAllowed(id)
recurrent_task.stop()
recurrent_task.gt.kill()
try:
self.tg.timers.remove(recurrent_task)
except ValueError:
raise
def force_delete_recurrent_task(self, ctx, id):
'''
This method is designed to be called by an rpc client.
E.g: Cerberus-api
Stop the task even if it is running and delete the recurrent task
from the ThreadGroup.
:param ctx: a request context dict supplied by client
:param id: the identifier of the task to force delete
'''
try:
self._force_delete_recurrent_task(id)
except errors.InvalidOperation:
raise
return id
def _get_tasks(self):
tasks = []
for timer in self.tg.timers:
tasks.append(timer)
for thread in self.tg.threads:
tasks.append(thread)
return tasks
def _get_task(self, id):
task = self._get_unique_task(id)
task_ = self._get_recurrent_task(id)
if (task is None and task_ is None):
raise errors.TaskNotFound(id)
return task if task is not None else task_
def get_tasks(self, ctx):
tasks_ = []
tasks = self._get_tasks()
for task in tasks:
if (isinstance(task, loopingcall.FixedIntervalLoopingCall)):
tasks_.append(
json.dumps(task,
cls=base.FixedIntervalLoopingCallEncoder))
elif (isinstance(task, threadgroup.Thread)):
tasks_.append(
json.dumps(task,
cls=base.ThreadEncoder))
return tasks_
def get_task(self, ctx, id):
try:
task = self._get_task(id)
except errors.InvalidOperation:
raise
if isinstance(task, loopingcall.FixedIntervalLoopingCall):
return json.dumps(task,
cls=base.FixedIntervalLoopingCallEncoder)
elif isinstance(task, threadgroup.Thread):
return json.dumps(task,
cls=base.ThreadEncoder)
def _restart_recurrent_task(self, id):
"""
Restart the task
:param id: the identifier of the task to restart
:return:
"""
recurrent_task = self._get_recurrent_task(id)
if (recurrent_task is None):
raise errors.TaskRestartNotAllowed(str(id))
period = recurrent_task.kw.get("task_period", None)
if recurrent_task._running is True:
raise errors.TaskRestartNotPossible(str(id))
else:
try:
recurrent_task.start(int(period))
except ValueError as e:
LOG.exception(e)
def restart_recurrent_task(self, ctx, id):
'''
This method is designed to be called by an rpc client.
E.g: Cerberus-api
Restart a recurrent task after it's being stopped
:param ctx: a request context dict supplied by client
:param id: the identifier of the task to restart
'''
try:
self._restart_recurrent_task(id)
except errors.InvalidOperation:
raise
return id

View File

View File

@ -0,0 +1,17 @@
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import six
six.add_move(six.MovedModule('mox', 'mox', 'mox3.mox'))

View File

@ -0,0 +1,40 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""oslo.i18n integration module.
See http://docs.openstack.org/developer/oslo.i18n/usage.html
"""
import oslo.i18n
# NOTE(dhellmann): This reference to o-s-l-o will be replaced by the
# application name when this module is synced into the separate
# repository. It is OK to have more than one translation function
# using the same domain, since there will still only be one message
# catalog.
_translators = oslo.i18n.TranslatorFactory(domain='oslo')
# The primary translation function using the well-known name "_"
_ = _translators.primary
# Translators for log levels.
#
# The abbreviated names are meant to reflect the usual use of a short
# name like '_'. The "L" is for "log" and the other letter comes from
# the level.
_LI = _translators.log_info
_LW = _translators.log_warning
_LE = _translators.log_error
_LC = _translators.log_critical

View File

@ -0,0 +1,221 @@
# Copyright 2013 OpenStack Foundation
# Copyright 2013 Spanish National Research Council.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# E0202: An attribute inherited from %s hide this method
# pylint: disable=E0202
import abc
import argparse
import os
import six
from stevedore import extension
from cerberus.openstack.common.apiclient import exceptions
_discovered_plugins = {}
def discover_auth_systems():
"""Discover the available auth-systems.
This won't take into account the old style auth-systems.
"""
global _discovered_plugins
_discovered_plugins = {}
def add_plugin(ext):
_discovered_plugins[ext.name] = ext.plugin
ep_namespace = "cerberus.openstack.common.apiclient.auth"
mgr = extension.ExtensionManager(ep_namespace)
mgr.map(add_plugin)
def load_auth_system_opts(parser):
"""Load options needed by the available auth-systems into a parser.
This function will try to populate the parser with options from the
available plugins.
"""
group = parser.add_argument_group("Common auth options")
BaseAuthPlugin.add_common_opts(group)
for name, auth_plugin in six.iteritems(_discovered_plugins):
group = parser.add_argument_group(
"Auth-system '%s' options" % name,
conflict_handler="resolve")
auth_plugin.add_opts(group)
def load_plugin(auth_system):
try:
plugin_class = _discovered_plugins[auth_system]
except KeyError:
raise exceptions.AuthSystemNotFound(auth_system)
return plugin_class(auth_system=auth_system)
def load_plugin_from_args(args):
"""Load required plugin and populate it with options.
Try to guess auth system if it is not specified. Systems are tried in
alphabetical order.
:type args: argparse.Namespace
:raises: AuthPluginOptionsMissing
"""
auth_system = args.os_auth_system
if auth_system:
plugin = load_plugin(auth_system)
plugin.parse_opts(args)
plugin.sufficient_options()
return plugin
for plugin_auth_system in sorted(six.iterkeys(_discovered_plugins)):
plugin_class = _discovered_plugins[plugin_auth_system]
plugin = plugin_class()
plugin.parse_opts(args)
try:
plugin.sufficient_options()
except exceptions.AuthPluginOptionsMissing:
continue
return plugin
raise exceptions.AuthPluginOptionsMissing(["auth_system"])
@six.add_metaclass(abc.ABCMeta)
class BaseAuthPlugin(object):
"""Base class for authentication plugins.
An authentication plugin needs to override at least the authenticate
method to be a valid plugin.
"""
auth_system = None
opt_names = []
common_opt_names = [
"auth_system",
"username",
"password",
"tenant_name",
"token",
"auth_url",
]
def __init__(self, auth_system=None, **kwargs):
self.auth_system = auth_system or self.auth_system
self.opts = dict((name, kwargs.get(name))
for name in self.opt_names)
@staticmethod
def _parser_add_opt(parser, opt):
"""Add an option to parser in two variants.
:param opt: option name (with underscores)
"""
dashed_opt = opt.replace("_", "-")
env_var = "OS_%s" % opt.upper()
arg_default = os.environ.get(env_var, "")
arg_help = "Defaults to env[%s]." % env_var
parser.add_argument(
"--os-%s" % dashed_opt,
metavar="<%s>" % dashed_opt,
default=arg_default,
help=arg_help)
parser.add_argument(
"--os_%s" % opt,
metavar="<%s>" % dashed_opt,
help=argparse.SUPPRESS)
@classmethod
def add_opts(cls, parser):
"""Populate the parser with the options for this plugin.
"""
for opt in cls.opt_names:
# use `BaseAuthPlugin.common_opt_names` since it is never
# changed in child classes
if opt not in BaseAuthPlugin.common_opt_names:
cls._parser_add_opt(parser, opt)
@classmethod
def add_common_opts(cls, parser):
"""Add options that are common for several plugins.
"""
for opt in cls.common_opt_names:
cls._parser_add_opt(parser, opt)
@staticmethod
def get_opt(opt_name, args):
"""Return option name and value.
:param opt_name: name of the option, e.g., "username"
:param args: parsed arguments
"""
return (opt_name, getattr(args, "os_%s" % opt_name, None))
def parse_opts(self, args):
"""Parse the actual auth-system options if any.
This method is expected to populate the attribute `self.opts` with a
dict containing the options and values needed to make authentication.
"""
self.opts.update(dict(self.get_opt(opt_name, args)
for opt_name in self.opt_names))
def authenticate(self, http_client):
"""Authenticate using plugin defined method.
The method usually analyses `self.opts` and performs
a request to authentication server.
:param http_client: client object that needs authentication
:type http_client: HTTPClient
:raises: AuthorizationFailure
"""
self.sufficient_options()
self._do_authenticate(http_client)
@abc.abstractmethod
def _do_authenticate(self, http_client):
"""Protected method for authentication.
"""
def sufficient_options(self):
"""Check if all required options are present.
:raises: AuthPluginOptionsMissing
"""
missing = [opt
for opt in self.opt_names
if not self.opts.get(opt)]
if missing:
raise exceptions.AuthPluginOptionsMissing(missing)
@abc.abstractmethod
def token_and_endpoint(self, endpoint_type, service_type):
"""Return token and endpoint.
:param service_type: Service type of the endpoint
:type service_type: string
:param endpoint_type: Type of endpoint.
Possible values: public or publicURL,
internal or internalURL,
admin or adminURL
:type endpoint_type: string
:returns: tuple of token and endpoint strings
:raises: EndpointException
"""

View File

@ -0,0 +1,500 @@
# Copyright 2010 Jacob Kaplan-Moss
# Copyright 2011 OpenStack Foundation
# Copyright 2012 Grid Dynamics
# Copyright 2013 OpenStack Foundation
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Base utilities to build API operation managers and objects on top of.
"""
# E1102: %s is not callable
# pylint: disable=E1102
import abc
import copy
import six
from six.moves.urllib import parse
from cerberus.openstack.common.apiclient import exceptions
from cerberus.openstack.common import strutils
def getid(obj):
"""Return id if argument is a Resource.
Abstracts the common pattern of allowing both an object or an object's ID
(UUID) as a parameter when dealing with relationships.
"""
try:
if obj.uuid:
return obj.uuid
except AttributeError:
pass
try:
return obj.id
except AttributeError:
return obj
# TODO(aababilov): call run_hooks() in HookableMixin's child classes
class HookableMixin(object):
"""Mixin so classes can register and run hooks."""
_hooks_map = {}
@classmethod
def add_hook(cls, hook_type, hook_func):
"""Add a new hook of specified type.
:param cls: class that registers hooks
:param hook_type: hook type, e.g., '__pre_parse_args__'
:param hook_func: hook function
"""
if hook_type not in cls._hooks_map:
cls._hooks_map[hook_type] = []
cls._hooks_map[hook_type].append(hook_func)
@classmethod
def run_hooks(cls, hook_type, *args, **kwargs):
"""Run all hooks of specified type.
:param cls: class that registers hooks
:param hook_type: hook type, e.g., '__pre_parse_args__'
:param **args: args to be passed to every hook function
:param **kwargs: kwargs to be passed to every hook function
"""
hook_funcs = cls._hooks_map.get(hook_type) or []
for hook_func in hook_funcs:
hook_func(*args, **kwargs)
class BaseManager(HookableMixin):
"""Basic manager type providing common operations.
Managers interact with a particular type of API (servers, flavors, images,
etc.) and provide CRUD operations for them.
"""
resource_class = None
def __init__(self, client):
"""Initializes BaseManager with `client`.
:param client: instance of BaseClient descendant for HTTP requests
"""
super(BaseManager, self).__init__()
self.client = client
def _list(self, url, response_key, obj_class=None, json=None):
"""List the collection.
:param url: a partial URL, e.g., '/servers'
:param response_key: the key to be looked up in response dictionary,
e.g., 'servers'
:param obj_class: class for constructing the returned objects
(self.resource_class will be used by default)
:param json: data that will be encoded as JSON and passed in POST
request (GET will be sent by default)
"""
if json:
body = self.client.post(url, json=json).json()
else:
body = self.client.get(url).json()
if obj_class is None:
obj_class = self.resource_class
data = body[response_key]
# NOTE(ja): keystone returns values as list as {'values': [ ... ]}
# unlike other services which just return the list...
try:
data = data['values']
except (KeyError, TypeError):
pass
return [obj_class(self, res, loaded=True) for res in data if res]
def _get(self, url, response_key):
"""Get an object from collection.
:param url: a partial URL, e.g., '/servers'
:param response_key: the key to be looked up in response dictionary,
e.g., 'server'
"""
body = self.client.get(url).json()
return self.resource_class(self, body[response_key], loaded=True)
def _head(self, url):
"""Retrieve request headers for an object.
:param url: a partial URL, e.g., '/servers'
"""
resp = self.client.head(url)
return resp.status_code == 204
def _post(self, url, json, response_key, return_raw=False):
"""Create an object.
:param url: a partial URL, e.g., '/servers'
:param json: data that will be encoded as JSON and passed in POST
request (GET will be sent by default)
:param response_key: the key to be looked up in response dictionary,
e.g., 'servers'
:param return_raw: flag to force returning raw JSON instead of
Python object of self.resource_class
"""
body = self.client.post(url, json=json).json()
if return_raw:
return body[response_key]
return self.resource_class(self, body[response_key])
def _put(self, url, json=None, response_key=None):
"""Update an object with PUT method.
:param url: a partial URL, e.g., '/servers'
:param json: data that will be encoded as JSON and passed in POST
request (GET will be sent by default)
:param response_key: the key to be looked up in response dictionary,
e.g., 'servers'
"""
resp = self.client.put(url, json=json)
# PUT requests may not return a body
if resp.content:
body = resp.json()
if response_key is not None:
return self.resource_class(self, body[response_key])
else:
return self.resource_class(self, body)
def _patch(self, url, json=None, response_key=None):
"""Update an object with PATCH method.
:param url: a partial URL, e.g., '/servers'
:param json: data that will be encoded as JSON and passed in POST
request (GET will be sent by default)
:param response_key: the key to be looked up in response dictionary,
e.g., 'servers'
"""
body = self.client.patch(url, json=json).json()
if response_key is not None:
return self.resource_class(self, body[response_key])
else:
return self.resource_class(self, body)
def _delete(self, url):
"""Delete an object.
:param url: a partial URL, e.g., '/servers/my-server'
"""
return self.client.delete(url)
@six.add_metaclass(abc.ABCMeta)
class ManagerWithFind(BaseManager):
"""Manager with additional `find()`/`findall()` methods."""
@abc.abstractmethod
def list(self):
pass
def find(self, **kwargs):
"""Find a single item with attributes matching ``**kwargs``.
This isn't very efficient: it loads the entire list then filters on
the Python side.
"""
matches = self.findall(**kwargs)
num_matches = len(matches)
if num_matches == 0:
msg = "No %s matching %s." % (self.resource_class.__name__, kwargs)
raise exceptions.NotFound(msg)
elif num_matches > 1:
raise exceptions.NoUniqueMatch()
else:
return matches[0]
def findall(self, **kwargs):
"""Find all items with attributes matching ``**kwargs``.
This isn't very efficient: it loads the entire list then filters on
the Python side.
"""
found = []
searches = kwargs.items()
for obj in self.list():
try:
if all(getattr(obj, attr) == value
for (attr, value) in searches):
found.append(obj)
except AttributeError:
continue
return found
class CrudManager(BaseManager):
"""Base manager class for manipulating entities.
Children of this class are expected to define a `collection_key` and `key`.
- `collection_key`: Usually a plural noun by convention (e.g. `entities`);
used to refer collections in both URL's (e.g. `/v3/entities`) and JSON
objects containing a list of member resources (e.g. `{'entities': [{},
{}, {}]}`).
- `key`: Usually a singular noun by convention (e.g. `entity`); used to
refer to an individual member of the collection.
"""
collection_key = None
key = None
def build_url(self, base_url=None, **kwargs):
"""Builds a resource URL for the given kwargs.
Given an example collection where `collection_key = 'entities'` and
`key = 'entity'`, the following URL's could be generated.
By default, the URL will represent a collection of entities, e.g.::
/entities
If kwargs contains an `entity_id`, then the URL will represent a
specific member, e.g.::
/entities/{entity_id}
:param base_url: if provided, the generated URL will be appended to it
"""
url = base_url if base_url is not None else ''
url += '/%s' % self.collection_key
# do we have a specific entity?
entity_id = kwargs.get('%s_id' % self.key)
if entity_id is not None:
url += '/%s' % entity_id
return url
def _filter_kwargs(self, kwargs):
"""Drop null values and handle ids."""
for key, ref in six.iteritems(kwargs.copy()):
if ref is None:
kwargs.pop(key)
else:
if isinstance(ref, Resource):
kwargs.pop(key)
kwargs['%s_id' % key] = getid(ref)
return kwargs
def create(self, **kwargs):
kwargs = self._filter_kwargs(kwargs)
return self._post(
self.build_url(**kwargs),
{self.key: kwargs},
self.key)
def get(self, **kwargs):
kwargs = self._filter_kwargs(kwargs)
return self._get(
self.build_url(**kwargs),
self.key)
def head(self, **kwargs):
kwargs = self._filter_kwargs(kwargs)
return self._head(self.build_url(**kwargs))
def list(self, base_url=None, **kwargs):
"""List the collection.
:param base_url: if provided, the generated URL will be appended to it
"""
kwargs = self._filter_kwargs(kwargs)
return self._list(
'%(base_url)s%(query)s' % {
'base_url': self.build_url(base_url=base_url, **kwargs),
'query': '?%s' % parse.urlencode(kwargs) if kwargs else '',
},
self.collection_key)
def put(self, base_url=None, **kwargs):
"""Update an element.
:param base_url: if provided, the generated URL will be appended to it
"""
kwargs = self._filter_kwargs(kwargs)
return self._put(self.build_url(base_url=base_url, **kwargs))
def update(self, **kwargs):
kwargs = self._filter_kwargs(kwargs)
params = kwargs.copy()
params.pop('%s_id' % self.key)
return self._patch(
self.build_url(**kwargs),
{self.key: params},
self.key)
def delete(self, **kwargs):
kwargs = self._filter_kwargs(kwargs)
return self._delete(
self.build_url(**kwargs))
def find(self, base_url=None, **kwargs):
"""Find a single item with attributes matching ``**kwargs``.
:param base_url: if provided, the generated URL will be appended to it
"""
kwargs = self._filter_kwargs(kwargs)
rl = self._list(
'%(base_url)s%(query)s' % {
'base_url': self.build_url(base_url=base_url, **kwargs),
'query': '?%s' % parse.urlencode(kwargs) if kwargs else '',
},
self.collection_key)
num = len(rl)
if num == 0:
msg = "No %s matching %s." % (self.resource_class.__name__, kwargs)
raise exceptions.NotFound(404, msg)
elif num > 1:
raise exceptions.NoUniqueMatch
else:
return rl[0]
class Extension(HookableMixin):
"""Extension descriptor."""
SUPPORTED_HOOKS = ('__pre_parse_args__', '__post_parse_args__')
manager_class = None
def __init__(self, name, module):
super(Extension, self).__init__()
self.name = name
self.module = module
self._parse_extension_module()
def _parse_extension_module(self):
self.manager_class = None
for attr_name, attr_value in self.module.__dict__.items():
if attr_name in self.SUPPORTED_HOOKS:
self.add_hook(attr_name, attr_value)
else:
try:
if issubclass(attr_value, BaseManager):
self.manager_class = attr_value
except TypeError:
pass
def __repr__(self):
return "<Extension '%s'>" % self.name
class Resource(object):
"""Base class for OpenStack resources (tenant, user, etc.).
This is pretty much just a bag for attributes.
"""
HUMAN_ID = False
NAME_ATTR = 'name'
def __init__(self, manager, info, loaded=False):
"""Populate and bind to a manager.
:param manager: BaseManager object
:param info: dictionary representing resource attributes
:param loaded: prevent lazy-loading if set to True
"""
self.manager = manager
self._info = info
self._add_details(info)
self._loaded = loaded
def __repr__(self):
reprkeys = sorted(k
for k in self.__dict__.keys()
if k[0] != '_' and k != 'manager')
info = ", ".join("%s=%s" % (k, getattr(self, k)) for k in reprkeys)
return "<%s %s>" % (self.__class__.__name__, info)
@property
def human_id(self):
"""Human-readable ID which can be used for bash completion.
"""
if self.NAME_ATTR in self.__dict__ and self.HUMAN_ID:
return strutils.to_slug(getattr(self, self.NAME_ATTR))
return None
def _add_details(self, info):
for (k, v) in six.iteritems(info):
try:
setattr(self, k, v)
self._info[k] = v
except AttributeError:
# In this case we already defined the attribute on the class
pass
def __getattr__(self, k):
if k not in self.__dict__:
#NOTE(bcwaldon): disallow lazy-loading if already loaded once
if not self.is_loaded():
self.get()
return self.__getattr__(k)
raise AttributeError(k)
else:
return self.__dict__[k]
def get(self):
"""Support for lazy loading details.
Some clients, such as novaclient have the option to lazy load the
details, details which can be loaded with this function.
"""
# set_loaded() first ... so if we have to bail, we know we tried.
self.set_loaded(True)
if not hasattr(self.manager, 'get'):
return
new = self.manager.get(self.id)
if new:
self._add_details(new._info)
def __eq__(self, other):
if not isinstance(other, Resource):
return NotImplemented
# two resources of different types are not equal
if not isinstance(other, self.__class__):
return False
if hasattr(self, 'id') and hasattr(other, 'id'):
return self.id == other.id
return self._info == other._info
def is_loaded(self):
return self._loaded
def set_loaded(self, val):
self._loaded = val
def to_dict(self):
return copy.deepcopy(self._info)

View File

@ -0,0 +1,358 @@
# Copyright 2010 Jacob Kaplan-Moss
# Copyright 2011 OpenStack Foundation
# Copyright 2011 Piston Cloud Computing, Inc.
# Copyright 2013 Alessio Ababilov
# Copyright 2013 Grid Dynamics
# Copyright 2013 OpenStack Foundation
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
OpenStack Client interface. Handles the REST calls and responses.
"""
# E0202: An attribute inherited from %s hide this method
# pylint: disable=E0202
import logging
import time
try:
import simplejson as json
except ImportError:
import json
import requests
from cerberus.openstack.common.apiclient import exceptions
from cerberus.openstack.common import importutils
_logger = logging.getLogger(__name__)
class HTTPClient(object):
"""This client handles sending HTTP requests to OpenStack servers.
Features:
- share authentication information between several clients to different
services (e.g., for compute and image clients);
- reissue authentication request for expired tokens;
- encode/decode JSON bodies;
- raise exceptions on HTTP errors;
- pluggable authentication;
- store authentication information in a keyring;
- store time spent for requests;
- register clients for particular services, so one can use
`http_client.identity` or `http_client.compute`;
- log requests and responses in a format that is easy to copy-and-paste
into terminal and send the same request with curl.
"""
user_agent = "cerberus.openstack.common.apiclient"
def __init__(self,
auth_plugin,
region_name=None,
endpoint_type="publicURL",
original_ip=None,
verify=True,
cert=None,
timeout=None,
timings=False,
keyring_saver=None,
debug=False,
user_agent=None,
http=None):
self.auth_plugin = auth_plugin
self.endpoint_type = endpoint_type
self.region_name = region_name
self.original_ip = original_ip
self.timeout = timeout
self.verify = verify
self.cert = cert
self.keyring_saver = keyring_saver
self.debug = debug
self.user_agent = user_agent or self.user_agent
self.times = [] # [("item", starttime, endtime), ...]
self.timings = timings
# requests within the same session can reuse TCP connections from pool
self.http = http or requests.Session()
self.cached_token = None
def _http_log_req(self, method, url, kwargs):
if not self.debug:
return
string_parts = [
"curl -i",
"-X '%s'" % method,
"'%s'" % url,
]
for element in kwargs['headers']:
header = "-H '%s: %s'" % (element, kwargs['headers'][element])
string_parts.append(header)
_logger.debug("REQ: %s" % " ".join(string_parts))
if 'data' in kwargs:
_logger.debug("REQ BODY: %s\n" % (kwargs['data']))
def _http_log_resp(self, resp):
if not self.debug:
return
_logger.debug(
"RESP: [%s] %s\n",
resp.status_code,
resp.headers)
if resp._content_consumed:
_logger.debug(
"RESP BODY: %s\n",
resp.text)
def serialize(self, kwargs):
if kwargs.get('json') is not None:
kwargs['headers']['Content-Type'] = 'application/json'
kwargs['data'] = json.dumps(kwargs['json'])
try:
del kwargs['json']
except KeyError:
pass
def get_timings(self):
return self.times
def reset_timings(self):
self.times = []
def request(self, method, url, **kwargs):
"""Send an http request with the specified characteristics.
Wrapper around `requests.Session.request` to handle tasks such as
setting headers, JSON encoding/decoding, and error handling.
:param method: method of HTTP request
:param url: URL of HTTP request
:param kwargs: any other parameter that can be passed to
' requests.Session.request (such as `headers`) or `json`
that will be encoded as JSON and used as `data` argument
"""
kwargs.setdefault("headers", kwargs.get("headers", {}))
kwargs["headers"]["User-Agent"] = self.user_agent
if self.original_ip:
kwargs["headers"]["Forwarded"] = "for=%s;by=%s" % (
self.original_ip, self.user_agent)
if self.timeout is not None:
kwargs.setdefault("timeout", self.timeout)
kwargs.setdefault("verify", self.verify)
if self.cert is not None:
kwargs.setdefault("cert", self.cert)
self.serialize(kwargs)
self._http_log_req(method, url, kwargs)
if self.timings:
start_time = time.time()
resp = self.http.request(method, url, **kwargs)
if self.timings:
self.times.append(("%s %s" % (method, url),
start_time, time.time()))
self._http_log_resp(resp)
if resp.status_code >= 400:
_logger.debug(
"Request returned failure status: %s",
resp.status_code)
raise exceptions.from_response(resp, method, url)
return resp
@staticmethod
def concat_url(endpoint, url):
"""Concatenate endpoint and final URL.
E.g., "http://keystone/v2.0/" and "/tokens" are concatenated to
"http://keystone/v2.0/tokens".
:param endpoint: the base URL
:param url: the final URL
"""
return "%s/%s" % (endpoint.rstrip("/"), url.strip("/"))
def client_request(self, client, method, url, **kwargs):
"""Send an http request using `client`'s endpoint and specified `url`.
If request was rejected as unauthorized (possibly because the token is
expired), issue one authorization attempt and send the request once
again.
:param client: instance of BaseClient descendant
:param method: method of HTTP request
:param url: URL of HTTP request
:param kwargs: any other parameter that can be passed to
' `HTTPClient.request`
"""
filter_args = {
"endpoint_type": client.endpoint_type or self.endpoint_type,
"service_type": client.service_type,
}
token, endpoint = (self.cached_token, client.cached_endpoint)
just_authenticated = False
if not (token and endpoint):
try:
token, endpoint = self.auth_plugin.token_and_endpoint(
**filter_args)
except exceptions.EndpointException:
pass
if not (token and endpoint):
self.authenticate()
just_authenticated = True
token, endpoint = self.auth_plugin.token_and_endpoint(
**filter_args)
if not (token and endpoint):
raise exceptions.AuthorizationFailure(
"Cannot find endpoint or token for request")
old_token_endpoint = (token, endpoint)
kwargs.setdefault("headers", {})["X-Auth-Token"] = token
self.cached_token = token
client.cached_endpoint = endpoint
# Perform the request once. If we get Unauthorized, then it
# might be because the auth token expired, so try to
# re-authenticate and try again. If it still fails, bail.
try:
return self.request(
method, self.concat_url(endpoint, url), **kwargs)
except exceptions.Unauthorized as unauth_ex:
if just_authenticated:
raise
self.cached_token = None
client.cached_endpoint = None
self.authenticate()
try:
token, endpoint = self.auth_plugin.token_and_endpoint(
**filter_args)
except exceptions.EndpointException:
raise unauth_ex
if (not (token and endpoint) or
old_token_endpoint == (token, endpoint)):
raise unauth_ex
self.cached_token = token
client.cached_endpoint = endpoint
kwargs["headers"]["X-Auth-Token"] = token
return self.request(
method, self.concat_url(endpoint, url), **kwargs)
def add_client(self, base_client_instance):
"""Add a new instance of :class:`BaseClient` descendant.
`self` will store a reference to `base_client_instance`.
Example:
>>> def test_clients():
... from keystoneclient.auth import keystone
... from openstack.common.apiclient import client
... auth = keystone.KeystoneAuthPlugin(
... username="user", password="pass", tenant_name="tenant",
... auth_url="http://auth:5000/v2.0")
... openstack_client = client.HTTPClient(auth)
... # create nova client
... from novaclient.v1_1 import client
... client.Client(openstack_client)
... # create keystone client
... from keystoneclient.v2_0 import client
... client.Client(openstack_client)
... # use them
... openstack_client.identity.tenants.list()
... openstack_client.compute.servers.list()
"""
service_type = base_client_instance.service_type
if service_type and not hasattr(self, service_type):
setattr(self, service_type, base_client_instance)
def authenticate(self):
self.auth_plugin.authenticate(self)
# Store the authentication results in the keyring for later requests
if self.keyring_saver:
self.keyring_saver.save(self)
class BaseClient(object):
"""Top-level object to access the OpenStack API.
This client uses :class:`HTTPClient` to send requests. :class:`HTTPClient`
will handle a bunch of issues such as authentication.
"""
service_type = None
endpoint_type = None # "publicURL" will be used
cached_endpoint = None
def __init__(self, http_client, extensions=None):
self.http_client = http_client
http_client.add_client(self)
# Add in any extensions...
if extensions:
for extension in extensions:
if extension.manager_class:
setattr(self, extension.name,
extension.manager_class(self))
def client_request(self, method, url, **kwargs):
return self.http_client.client_request(
self, method, url, **kwargs)
def head(self, url, **kwargs):
return self.client_request("HEAD", url, **kwargs)
def get(self, url, **kwargs):
return self.client_request("GET", url, **kwargs)
def post(self, url, **kwargs):
return self.client_request("POST", url, **kwargs)
def put(self, url, **kwargs):
return self.client_request("PUT", url, **kwargs)
def delete(self, url, **kwargs):
return self.client_request("DELETE", url, **kwargs)
def patch(self, url, **kwargs):
return self.client_request("PATCH", url, **kwargs)
@staticmethod
def get_class(api_name, version, version_map):
"""Returns the client class for the requested API version
:param api_name: the name of the API, e.g. 'compute', 'image', etc
:param version: the requested API version
:param version_map: a dict of client classes keyed by version
:rtype: a client class for the requested API version
"""
try:
client_path = version_map[str(version)]
except (KeyError, ValueError):
msg = "Invalid %s client version '%s'. must be one of: %s" % (
(api_name, version, ', '.join(version_map.keys())))
raise exceptions.UnsupportedVersion(msg)
return importutils.import_class(client_path)

View File

@ -0,0 +1,459 @@
# Copyright 2010 Jacob Kaplan-Moss
# Copyright 2011 Nebula, Inc.
# Copyright 2013 Alessio Ababilov
# Copyright 2013 OpenStack Foundation
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Exception definitions.
"""
import inspect
import sys
import six
class ClientException(Exception):
"""The base exception class for all exceptions this library raises.
"""
pass
class MissingArgs(ClientException):
"""Supplied arguments are not sufficient for calling a function."""
def __init__(self, missing):
self.missing = missing
msg = "Missing argument(s): %s" % ", ".join(missing)
super(MissingArgs, self).__init__(msg)
class ValidationError(ClientException):
"""Error in validation on API client side."""
pass
class UnsupportedVersion(ClientException):
"""User is trying to use an unsupported version of the API."""
pass
class CommandError(ClientException):
"""Error in CLI tool."""
pass
class AuthorizationFailure(ClientException):
"""Cannot authorize API client."""
pass
class ConnectionRefused(ClientException):
"""Cannot connect to API service."""
pass
class AuthPluginOptionsMissing(AuthorizationFailure):
"""Auth plugin misses some options."""
def __init__(self, opt_names):
super(AuthPluginOptionsMissing, self).__init__(
"Authentication failed. Missing options: %s" %
", ".join(opt_names))
self.opt_names = opt_names
class AuthSystemNotFound(AuthorizationFailure):
"""User has specified a AuthSystem that is not installed."""
def __init__(self, auth_system):
super(AuthSystemNotFound, self).__init__(
"AuthSystemNotFound: %s" % repr(auth_system))
self.auth_system = auth_system
class NoUniqueMatch(ClientException):
"""Multiple entities found instead of one."""
pass
class EndpointException(ClientException):
"""Something is rotten in Service Catalog."""
pass
class EndpointNotFound(EndpointException):
"""Could not find requested endpoint in Service Catalog."""
pass
class AmbiguousEndpoints(EndpointException):
"""Found more than one matching endpoint in Service Catalog."""
def __init__(self, endpoints=None):
super(AmbiguousEndpoints, self).__init__(
"AmbiguousEndpoints: %s" % repr(endpoints))
self.endpoints = endpoints
class HttpError(ClientException):
"""The base exception class for all HTTP exceptions.
"""
http_status = 0
message = "HTTP Error"
def __init__(self, message=None, details=None,
response=None, request_id=None,
url=None, method=None, http_status=None):
self.http_status = http_status or self.http_status
self.message = message or self.message
self.details = details
self.request_id = request_id
self.response = response
self.url = url
self.method = method
formatted_string = "%s (HTTP %s)" % (self.message, self.http_status)
if request_id:
formatted_string += " (Request-ID: %s)" % request_id
super(HttpError, self).__init__(formatted_string)
class HTTPRedirection(HttpError):
"""HTTP Redirection."""
message = "HTTP Redirection"
class HTTPClientError(HttpError):
"""Client-side HTTP error.
Exception for cases in which the client seems to have erred.
"""
message = "HTTP Client Error"
class HttpServerError(HttpError):
"""Server-side HTTP error.
Exception for cases in which the server is aware that it has
erred or is incapable of performing the request.
"""
message = "HTTP Server Error"
class MultipleChoices(HTTPRedirection):
"""HTTP 300 - Multiple Choices.
Indicates multiple options for the resource that the client may follow.
"""
http_status = 300
message = "Multiple Choices"
class BadRequest(HTTPClientError):
"""HTTP 400 - Bad Request.
The request cannot be fulfilled due to bad syntax.
"""
http_status = 400
message = "Bad Request"
class Unauthorized(HTTPClientError):
"""HTTP 401 - Unauthorized.
Similar to 403 Forbidden, but specifically for use when authentication
is required and has failed or has not yet been provided.
"""
http_status = 401
message = "Unauthorized"
class PaymentRequired(HTTPClientError):
"""HTTP 402 - Payment Required.
Reserved for future use.
"""
http_status = 402
message = "Payment Required"
class Forbidden(HTTPClientError):
"""HTTP 403 - Forbidden.
The request was a valid request, but the server is refusing to respond
to it.
"""
http_status = 403
message = "Forbidden"
class NotFound(HTTPClientError):
"""HTTP 404 - Not Found.
The requested resource could not be found but may be available again
in the future.
"""
http_status = 404
message = "Not Found"
class MethodNotAllowed(HTTPClientError):
"""HTTP 405 - Method Not Allowed.
A request was made of a resource using a request method not supported
by that resource.
"""
http_status = 405
message = "Method Not Allowed"
class NotAcceptable(HTTPClientError):
"""HTTP 406 - Not Acceptable.
The requested resource is only capable of generating content not
acceptable according to the Accept headers sent in the request.
"""
http_status = 406
message = "Not Acceptable"
class ProxyAuthenticationRequired(HTTPClientError):
"""HTTP 407 - Proxy Authentication Required.
The client must first authenticate itself with the proxy.
"""
http_status = 407
message = "Proxy Authentication Required"
class RequestTimeout(HTTPClientError):
"""HTTP 408 - Request Timeout.
The server timed out waiting for the request.
"""
http_status = 408
message = "Request Timeout"
class Conflict(HTTPClientError):
"""HTTP 409 - Conflict.
Indicates that the request could not be processed because of conflict
in the request, such as an edit conflict.
"""
http_status = 409
message = "Conflict"
class Gone(HTTPClientError):
"""HTTP 410 - Gone.
Indicates that the resource requested is no longer available and will
not be available again.
"""
http_status = 410
message = "Gone"
class LengthRequired(HTTPClientError):
"""HTTP 411 - Length Required.
The request did not specify the length of its content, which is
required by the requested resource.
"""
http_status = 411
message = "Length Required"
class PreconditionFailed(HTTPClientError):
"""HTTP 412 - Precondition Failed.
The server does not meet one of the preconditions that the requester
put on the request.
"""
http_status = 412
message = "Precondition Failed"
class RequestEntityTooLarge(HTTPClientError):
"""HTTP 413 - Request Entity Too Large.
The request is larger than the server is willing or able to process.
"""
http_status = 413
message = "Request Entity Too Large"
def __init__(self, *args, **kwargs):
try:
self.retry_after = int(kwargs.pop('retry_after'))
except (KeyError, ValueError):
self.retry_after = 0
super(RequestEntityTooLarge, self).__init__(*args, **kwargs)
class RequestUriTooLong(HTTPClientError):
"""HTTP 414 - Request-URI Too Long.
The URI provided was too long for the server to process.
"""
http_status = 414
message = "Request-URI Too Long"
class UnsupportedMediaType(HTTPClientError):
"""HTTP 415 - Unsupported Media Type.
The request entity has a media type which the server or resource does
not support.
"""
http_status = 415
message = "Unsupported Media Type"
class RequestedRangeNotSatisfiable(HTTPClientError):
"""HTTP 416 - Requested Range Not Satisfiable.
The client has asked for a portion of the file, but the server cannot
supply that portion.
"""
http_status = 416
message = "Requested Range Not Satisfiable"
class ExpectationFailed(HTTPClientError):
"""HTTP 417 - Expectation Failed.
The server cannot meet the requirements of the Expect request-header field.
"""
http_status = 417
message = "Expectation Failed"
class UnprocessableEntity(HTTPClientError):
"""HTTP 422 - Unprocessable Entity.
The request was well-formed but was unable to be followed due to semantic
errors.
"""
http_status = 422
message = "Unprocessable Entity"
class InternalServerError(HttpServerError):
"""HTTP 500 - Internal Server Error.
A generic error message, given when no more specific message is suitable.
"""
http_status = 500
message = "Internal Server Error"
# NotImplemented is a python keyword.
class HttpNotImplemented(HttpServerError):
"""HTTP 501 - Not Implemented.
The server either does not recognize the request method, or it lacks
the ability to fulfill the request.
"""
http_status = 501
message = "Not Implemented"
class BadGateway(HttpServerError):
"""HTTP 502 - Bad Gateway.
The server was acting as a gateway or proxy and received an invalid
response from the upstream server.
"""
http_status = 502
message = "Bad Gateway"
class ServiceUnavailable(HttpServerError):
"""HTTP 503 - Service Unavailable.
The server is currently unavailable.
"""
http_status = 503
message = "Service Unavailable"
class GatewayTimeout(HttpServerError):
"""HTTP 504 - Gateway Timeout.
The server was acting as a gateway or proxy and did not receive a timely
response from the upstream server.
"""
http_status = 504
message = "Gateway Timeout"
class HttpVersionNotSupported(HttpServerError):
"""HTTP 505 - HttpVersion Not Supported.
The server does not support the HTTP protocol version used in the request.
"""
http_status = 505
message = "HTTP Version Not Supported"
# _code_map contains all the classes that have http_status attribute.
_code_map = dict(
(getattr(obj, 'http_status', None), obj)
for name, obj in six.iteritems(vars(sys.modules[__name__]))
if inspect.isclass(obj) and getattr(obj, 'http_status', False)
)
def from_response(response, method, url):
"""Returns an instance of :class:`HttpError` or subclass based on response.
:param response: instance of `requests.Response` class
:param method: HTTP method used for request
:param url: URL used for request
"""
kwargs = {
"http_status": response.status_code,
"response": response,
"method": method,
"url": url,
"request_id": response.headers.get("x-compute-request-id"),
}
if "retry-after" in response.headers:
kwargs["retry_after"] = response.headers["retry-after"]
content_type = response.headers.get("Content-Type", "")
if content_type.startswith("application/json"):
try:
body = response.json()
except ValueError:
pass
else:
if isinstance(body, dict):
error = list(body.values())[0]
kwargs["message"] = error.get("message")
kwargs["details"] = error.get("details")
elif content_type.startswith("text/"):
kwargs["details"] = response.text
try:
cls = _code_map[response.status_code]
except KeyError:
if 500 <= response.status_code < 600:
cls = HttpServerError
elif 400 <= response.status_code < 500:
cls = HTTPClientError
else:
cls = HttpError
return cls(**kwargs)

View File

@ -0,0 +1,173 @@
# Copyright 2013 OpenStack Foundation
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
A fake server that "responds" to API methods with pre-canned responses.
All of these responses come from the spec, so if for some reason the spec's
wrong the tests might raise AssertionError. I've indicated in comments the
places where actual behavior differs from the spec.
"""
# W0102: Dangerous default value %s as argument
# pylint: disable=W0102
import json
import requests
import six
from six.moves.urllib import parse
from cerberus.openstack.common.apiclient import client
def assert_has_keys(dct, required=[], optional=[]):
for k in required:
try:
assert k in dct
except AssertionError:
extra_keys = set(dct.keys()).difference(set(required + optional))
raise AssertionError("found unexpected keys: %s" %
list(extra_keys))
class TestResponse(requests.Response):
"""Wrap requests.Response and provide a convenient initialization.
"""
def __init__(self, data):
super(TestResponse, self).__init__()
self._content_consumed = True
if isinstance(data, dict):
self.status_code = data.get('status_code', 200)
# Fake the text attribute to streamline Response creation
text = data.get('text', "")
if isinstance(text, (dict, list)):
self._content = json.dumps(text)
default_headers = {
"Content-Type": "application/json",
}
else:
self._content = text
default_headers = {}
if six.PY3 and isinstance(self._content, six.string_types):
self._content = self._content.encode('utf-8', 'strict')
self.headers = data.get('headers') or default_headers
else:
self.status_code = data
def __eq__(self, other):
return (self.status_code == other.status_code and
self.headers == other.headers and
self._content == other._content)
class FakeHTTPClient(client.HTTPClient):
def __init__(self, *args, **kwargs):
self.callstack = []
self.fixtures = kwargs.pop("fixtures", None) or {}
if not args and not "auth_plugin" in kwargs:
args = (None, )
super(FakeHTTPClient, self).__init__(*args, **kwargs)
def assert_called(self, method, url, body=None, pos=-1):
"""Assert than an API method was just called.
"""
expected = (method, url)
called = self.callstack[pos][0:2]
assert self.callstack, \
"Expected %s %s but no calls were made." % expected
assert expected == called, 'Expected %s %s; got %s %s' % \
(expected + called)
if body is not None:
if self.callstack[pos][3] != body:
raise AssertionError('%r != %r' %
(self.callstack[pos][3], body))
def assert_called_anytime(self, method, url, body=None):
"""Assert than an API method was called anytime in the test.
"""
expected = (method, url)
assert self.callstack, \
"Expected %s %s but no calls were made." % expected
found = False
entry = None
for entry in self.callstack:
if expected == entry[0:2]:
found = True
break
assert found, 'Expected %s %s; got %s' % \
(method, url, self.callstack)
if body is not None:
assert entry[3] == body, "%s != %s" % (entry[3], body)
self.callstack = []
def clear_callstack(self):
self.callstack = []
def authenticate(self):
pass
def client_request(self, client, method, url, **kwargs):
# Check that certain things are called correctly
if method in ["GET", "DELETE"]:
assert "json" not in kwargs
# Note the call
self.callstack.append(
(method,
url,
kwargs.get("headers") or {},
kwargs.get("json") or kwargs.get("data")))
try:
fixture = self.fixtures[url][method]
except KeyError:
pass
else:
return TestResponse({"headers": fixture[0],
"text": fixture[1]})
# Call the method
args = parse.parse_qsl(parse.urlparse(url)[4])
kwargs.update(args)
munged_url = url.rsplit('?', 1)[0]
munged_url = munged_url.strip('/').replace('/', '_').replace('.', '_')
munged_url = munged_url.replace('-', '_')
callback = "%s_%s" % (method.lower(), munged_url)
if not hasattr(self, callback):
raise AssertionError('Called unknown API method: %s %s, '
'expected fakes method name: %s' %
(method, url, callback))
resp = getattr(self, callback)(**kwargs)
if len(resp) == 3:
status, headers, body = resp
else:
status, body = resp
headers = {}
return TestResponse({
"status_code": status,
"text": body,
"headers": headers,
})

View File

@ -0,0 +1,309 @@
# Copyright 2012 Red Hat, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# W0603: Using the global statement
# W0621: Redefining name %s from outer scope
# pylint: disable=W0603,W0621
from __future__ import print_function
import getpass
import inspect
import os
import sys
import textwrap
import prettytable
import six
from six import moves
from cerberus.openstack.common.apiclient import exceptions
from cerberus.openstack.common.gettextutils import _
from cerberus.openstack.common import strutils
from cerberus.openstack.common import uuidutils
def validate_args(fn, *args, **kwargs):
"""Check that the supplied args are sufficient for calling a function.
>>> validate_args(lambda a: None)
Traceback (most recent call last):
...
MissingArgs: Missing argument(s): a
>>> validate_args(lambda a, b, c, d: None, 0, c=1)
Traceback (most recent call last):
...
MissingArgs: Missing argument(s): b, d
:param fn: the function to check
:param arg: the positional arguments supplied
:param kwargs: the keyword arguments supplied
"""
argspec = inspect.getargspec(fn)
num_defaults = len(argspec.defaults or [])
required_args = argspec.args[:len(argspec.args) - num_defaults]
def isbound(method):
return getattr(method, 'im_self', None) is not None
if isbound(fn):
required_args.pop(0)
missing = [arg for arg in required_args if arg not in kwargs]
missing = missing[len(args):]
if missing:
raise exceptions.MissingArgs(missing)
def arg(*args, **kwargs):
"""Decorator for CLI args.
Example:
>>> @arg("name", help="Name of the new entity")
... def entity_create(args):
... pass
"""
def _decorator(func):
add_arg(func, *args, **kwargs)
return func
return _decorator
def env(*args, **kwargs):
"""Returns the first environment variable set.
If all are empty, defaults to '' or keyword arg `default`.
"""
for arg in args:
value = os.environ.get(arg)
if value:
return value
return kwargs.get('default', '')
def add_arg(func, *args, **kwargs):
"""Bind CLI arguments to a shell.py `do_foo` function."""
if not hasattr(func, 'arguments'):
func.arguments = []
# NOTE(sirp): avoid dups that can occur when the module is shared across
# tests.
if (args, kwargs) not in func.arguments:
# Because of the semantics of decorator composition if we just append
# to the options list positional options will appear to be backwards.
func.arguments.insert(0, (args, kwargs))
def unauthenticated(func):
"""Adds 'unauthenticated' attribute to decorated function.
Usage:
>>> @unauthenticated
... def mymethod(f):
... pass
"""
func.unauthenticated = True
return func
def isunauthenticated(func):
"""Checks if the function does not require authentication.
Mark such functions with the `@unauthenticated` decorator.
:returns: bool
"""
return getattr(func, 'unauthenticated', False)
def print_list(objs, fields, formatters=None, sortby_index=0,
mixed_case_fields=None):
"""Print a list or objects as a table, one row per object.
:param objs: iterable of :class:`Resource`
:param fields: attributes that correspond to columns, in order
:param formatters: `dict` of callables for field formatting
:param sortby_index: index of the field for sorting table rows
:param mixed_case_fields: fields corresponding to object attributes that
have mixed case names (e.g., 'serverId')
"""
formatters = formatters or {}
mixed_case_fields = mixed_case_fields or []
if sortby_index is None:
kwargs = {}
else:
kwargs = {'sortby': fields[sortby_index]}
pt = prettytable.PrettyTable(fields, caching=False)
pt.align = 'l'
for o in objs:
row = []
for field in fields:
if field in formatters:
row.append(formatters[field](o))
else:
if field in mixed_case_fields:
field_name = field.replace(' ', '_')
else:
field_name = field.lower().replace(' ', '_')
data = getattr(o, field_name, '')
row.append(data)
pt.add_row(row)
print(strutils.safe_encode(pt.get_string(**kwargs)))
def print_dict(dct, dict_property="Property", wrap=0):
"""Print a `dict` as a table of two columns.
:param dct: `dict` to print
:param dict_property: name of the first column
:param wrap: wrapping for the second column
"""
pt = prettytable.PrettyTable([dict_property, 'Value'], caching=False)
pt.align = 'l'
for k, v in six.iteritems(dct):
# convert dict to str to check length
if isinstance(v, dict):
v = six.text_type(v)
if wrap > 0:
v = textwrap.fill(six.text_type(v), wrap)
# if value has a newline, add in multiple rows
# e.g. fault with stacktrace
if v and isinstance(v, six.string_types) and r'\n' in v:
lines = v.strip().split(r'\n')
col1 = k
for line in lines:
pt.add_row([col1, line])
col1 = ''
else:
pt.add_row([k, v])
print(strutils.safe_encode(pt.get_string()))
def get_password(max_password_prompts=3):
"""Read password from TTY."""
verify = strutils.bool_from_string(env("OS_VERIFY_PASSWORD"))
pw = None
if hasattr(sys.stdin, "isatty") and sys.stdin.isatty():
# Check for Ctrl-D
try:
for __ in moves.range(max_password_prompts):
pw1 = getpass.getpass("OS Password: ")
if verify:
pw2 = getpass.getpass("Please verify: ")
else:
pw2 = pw1
if pw1 == pw2 and pw1:
pw = pw1
break
except EOFError:
pass
return pw
def find_resource(manager, name_or_id, **find_args):
"""Look for resource in a given manager.
Used as a helper for the _find_* methods.
Example:
def _find_hypervisor(cs, hypervisor):
#Get a hypervisor by name or ID.
return cliutils.find_resource(cs.hypervisors, hypervisor)
"""
# first try to get entity as integer id
try:
return manager.get(int(name_or_id))
except (TypeError, ValueError, exceptions.NotFound):
pass
# now try to get entity as uuid
try:
tmp_id = strutils.safe_encode(name_or_id)
if uuidutils.is_uuid_like(tmp_id):
return manager.get(tmp_id)
except (TypeError, ValueError, exceptions.NotFound):
pass
# for str id which is not uuid
if getattr(manager, 'is_alphanum_id_allowed', False):
try:
return manager.get(name_or_id)
except exceptions.NotFound:
pass
try:
try:
return manager.find(human_id=name_or_id, **find_args)
except exceptions.NotFound:
pass
# finally try to find entity by name
try:
resource = getattr(manager, 'resource_class', None)
name_attr = resource.NAME_ATTR if resource else 'name'
kwargs = {name_attr: name_or_id}
kwargs.update(find_args)
return manager.find(**kwargs)
except exceptions.NotFound:
msg = _("No %(name)s with a name or "
"ID of '%(name_or_id)s' exists.") % \
{
"name": manager.resource_class.__name__.lower(),
"name_or_id": name_or_id
}
raise exceptions.CommandError(msg)
except exceptions.NoUniqueMatch:
msg = _("Multiple %(name)s matches found for "
"'%(name_or_id)s', use an ID to be more specific.") % \
{
"name": manager.resource_class.__name__.lower(),
"name_or_id": name_or_id
}
raise exceptions.CommandError(msg)
def service_type(stype):
"""Adds 'service_type' attribute to decorated function.
Usage:
@service_type('volume')
def mymethod(f):
...
"""
def inner(f):
f.service_type = stype
return f
return inner
def get_service_type(f):
"""Retrieves service type from function."""
return getattr(f, 'service_type', None)
def pretty_choice_list(l):
return ', '.join("'%s'" % i for i in l)
def exit(msg=''):
if msg:
print (msg, file=sys.stderr)
sys.exit(1)

View File

@ -0,0 +1,307 @@
# Copyright 2012 SINA Corporation
# Copyright 2014 Cisco Systems, Inc.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
"""Extracts OpenStack config option info from module(s)."""
from __future__ import print_function
import argparse
import imp
import os
import re
import socket
import sys
import textwrap
from oslo.config import cfg
import six
import stevedore.named
from cerberus.openstack.common import gettextutils
from cerberus.openstack.common import importutils
gettextutils.install('cerberus')
STROPT = "StrOpt"
BOOLOPT = "BoolOpt"
INTOPT = "IntOpt"
FLOATOPT = "FloatOpt"
LISTOPT = "ListOpt"
DICTOPT = "DictOpt"
MULTISTROPT = "MultiStrOpt"
OPT_TYPES = {
STROPT: 'string value',
BOOLOPT: 'boolean value',
INTOPT: 'integer value',
FLOATOPT: 'floating point value',
LISTOPT: 'list value',
DICTOPT: 'dict value',
MULTISTROPT: 'multi valued',
}
OPTION_REGEX = re.compile(r"(%s)" % "|".join([STROPT, BOOLOPT, INTOPT,
FLOATOPT, LISTOPT, DICTOPT,
MULTISTROPT]))
PY_EXT = ".py"
BASEDIR = os.path.abspath(os.path.join(os.path.dirname(__file__),
"../../../../"))
WORDWRAP_WIDTH = 60
def raise_extension_exception(extmanager, ep, err):
raise
def generate(argv):
parser = argparse.ArgumentParser(
description='generate sample configuration file',
)
parser.add_argument('-m', dest='modules', action='append')
parser.add_argument('-l', dest='libraries', action='append')
parser.add_argument('srcfiles', nargs='*')
parsed_args = parser.parse_args(argv)
mods_by_pkg = dict()
for filepath in parsed_args.srcfiles:
pkg_name = filepath.split(os.sep)[1]
mod_str = '.'.join(['.'.join(filepath.split(os.sep)[:-1]),
os.path.basename(filepath).split('.')[0]])
mods_by_pkg.setdefault(pkg_name, list()).append(mod_str)
# NOTE(lzyeval): place top level modules before packages
pkg_names = sorted(pkg for pkg in mods_by_pkg if pkg.endswith(PY_EXT))
ext_names = sorted(pkg for pkg in mods_by_pkg if pkg not in pkg_names)
pkg_names.extend(ext_names)
# opts_by_group is a mapping of group name to an options list
# The options list is a list of (module, options) tuples
opts_by_group = {'DEFAULT': []}
if parsed_args.modules:
for module_name in parsed_args.modules:
module = _import_module(module_name)
if module:
for group, opts in _list_opts(module):
opts_by_group.setdefault(group, []).append((module_name,
opts))
# Look for entry points defined in libraries (or applications) for
# option discovery, and include their return values in the output.
#
# Each entry point should be a function returning an iterable
# of pairs with the group name (or None for the default group)
# and the list of Opt instances for that group.
if parsed_args.libraries:
loader = stevedore.named.NamedExtensionManager(
'oslo.config.opts',
names=list(set(parsed_args.libraries)),
invoke_on_load=False,
on_load_failure_callback=raise_extension_exception
)
for ext in loader:
for group, opts in ext.plugin():
opt_list = opts_by_group.setdefault(group or 'DEFAULT', [])
opt_list.append((ext.name, opts))
for pkg_name in pkg_names:
mods = mods_by_pkg.get(pkg_name)
mods.sort()
for mod_str in mods:
if mod_str.endswith('.__init__'):
mod_str = mod_str[:mod_str.rfind(".")]
mod_obj = _import_module(mod_str)
if not mod_obj:
raise RuntimeError("Unable to import module %s" % mod_str)
for group, opts in _list_opts(mod_obj):
opts_by_group.setdefault(group, []).append((mod_str, opts))
print_group_opts('DEFAULT', opts_by_group.pop('DEFAULT', []))
for group in sorted(opts_by_group.keys()):
print_group_opts(group, opts_by_group[group])
def _import_module(mod_str):
try:
if mod_str.startswith('bin.'):
imp.load_source(mod_str[4:], os.path.join('bin', mod_str[4:]))
return sys.modules[mod_str[4:]]
else:
return importutils.import_module(mod_str)
except Exception as e:
sys.stderr.write("Error importing module %s: %s\n" % (mod_str, str(e)))
return None
def _is_in_group(opt, group):
"Check if opt is in group."
for value in group._opts.values():
# NOTE(llu): Temporary workaround for bug #1262148, wait until
# newly released oslo.config support '==' operator.
if not(value['opt'] != opt):
return True
return False
def _guess_groups(opt, mod_obj):
# is it in the DEFAULT group?
if _is_in_group(opt, cfg.CONF):
return 'DEFAULT'
# what other groups is it in?
for value in cfg.CONF.values():
if isinstance(value, cfg.CONF.GroupAttr):
if _is_in_group(opt, value._group):
return value._group.name
raise RuntimeError(
"Unable to find group for option %s, "
"maybe it's defined twice in the same group?"
% opt.name
)
def _list_opts(obj):
def is_opt(o):
return (isinstance(o, cfg.Opt) and
not isinstance(o, cfg.SubCommandOpt))
opts = list()
for attr_str in dir(obj):
attr_obj = getattr(obj, attr_str)
if is_opt(attr_obj):
opts.append(attr_obj)
elif (isinstance(attr_obj, list) and
all(map(lambda x: is_opt(x), attr_obj))):
opts.extend(attr_obj)
ret = {}
for opt in opts:
ret.setdefault(_guess_groups(opt, obj), []).append(opt)
return ret.items()
def print_group_opts(group, opts_by_module):
print("[%s]" % group)
print('')
for mod, opts in opts_by_module:
print('#')
print('# Options defined in %s' % mod)
print('#')
print('')
for opt in opts:
_print_opt(opt)
print('')
def _get_my_ip():
try:
csock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
csock.connect(('8.8.8.8', 80))
(addr, port) = csock.getsockname()
csock.close()
return addr
except socket.error:
return None
def _sanitize_default(name, value):
"""Set up a reasonably sensible default for pybasedir, my_ip and host."""
if value.startswith(sys.prefix):
# NOTE(jd) Don't use os.path.join, because it is likely to think the
# second part is an absolute pathname and therefore drop the first
# part.
value = os.path.normpath("/usr/" + value[len(sys.prefix):])
elif value.startswith(BASEDIR):
return value.replace(BASEDIR, '/usr/lib/python/site-packages')
elif BASEDIR in value:
return value.replace(BASEDIR, '')
elif value == _get_my_ip():
return '10.0.0.1'
elif value in (socket.gethostname(), socket.getfqdn()) and 'host' in name:
return 'cerberus'
elif value.strip() != value:
return '"%s"' % value
return value
def _print_opt(opt):
opt_name, opt_default, opt_help = opt.dest, opt.default, opt.help
if not opt_help:
sys.stderr.write('WARNING: "%s" is missing help string.\n' % opt_name)
opt_help = ""
opt_type = None
try:
opt_type = OPTION_REGEX.search(str(type(opt))).group(0)
except (ValueError, AttributeError) as err:
sys.stderr.write("%s\n" % str(err))
sys.exit(1)
opt_help = u'%s (%s)' % (opt_help,
OPT_TYPES[opt_type])
print('#', "\n# ".join(textwrap.wrap(opt_help, WORDWRAP_WIDTH)))
if opt.deprecated_opts:
for deprecated_opt in opt.deprecated_opts:
if deprecated_opt.name:
deprecated_group = (deprecated_opt.group if
deprecated_opt.group else "DEFAULT")
print('# Deprecated group/name - [%s]/%s' %
(deprecated_group,
deprecated_opt.name))
try:
if opt_default is None:
print('#%s=<None>' % opt_name)
elif opt_type == STROPT:
assert(isinstance(opt_default, six.string_types))
print('#%s=%s' % (opt_name, _sanitize_default(opt_name,
opt_default)))
elif opt_type == BOOLOPT:
assert(isinstance(opt_default, bool))
print('#%s=%s' % (opt_name, str(opt_default).lower()))
elif opt_type == INTOPT:
assert(isinstance(opt_default, int) and
not isinstance(opt_default, bool))
print('#%s=%s' % (opt_name, opt_default))
elif opt_type == FLOATOPT:
assert(isinstance(opt_default, float))
print('#%s=%s' % (opt_name, opt_default))
elif opt_type == LISTOPT:
assert(isinstance(opt_default, list))
print('#%s=%s' % (opt_name, ','.join(opt_default)))
elif opt_type == DICTOPT:
assert(isinstance(opt_default, dict))
opt_default_strlist = [str(key) + ':' + str(value)
for (key, value) in opt_default.items()]
print('#%s=%s' % (opt_name, ','.join(opt_default_strlist)))
elif opt_type == MULTISTROPT:
assert(isinstance(opt_default, list))
if not opt_default:
opt_default = ['']
for default in opt_default:
print('#%s=%s' % (opt_name, default))
print('')
except Exception:
sys.stderr.write('Error in option "%s"\n' % opt_name)
sys.exit(1)
def main():
generate(sys.argv[1:])
if __name__ == '__main__':
main()

View File

@ -0,0 +1,111 @@
# Copyright 2011 OpenStack Foundation.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Simple class that stores security context information in the web request.
Projects should subclass this class if they wish to enhance the request
context or provide additional information in their specific WSGI pipeline.
"""
import itertools
import uuid
def generate_request_id():
return 'req-%s' % str(uuid.uuid4())
class RequestContext(object):
"""Helper class to represent useful information about a request context.
Stores information about the security context under which the user
accesses the system, as well as additional request information.
"""
user_idt_format = '{user} {tenant} {domain} {user_domain} {p_domain}'
def __init__(self, auth_token=None, user=None, tenant=None, domain=None,
user_domain=None, project_domain=None, is_admin=False,
read_only=False, show_deleted=False, request_id=None,
instance_uuid=None):
self.auth_token = auth_token
self.user = user
self.tenant = tenant
self.domain = domain
self.user_domain = user_domain
self.project_domain = project_domain
self.is_admin = is_admin
self.read_only = read_only
self.show_deleted = show_deleted
self.instance_uuid = instance_uuid
if not request_id:
request_id = generate_request_id()
self.request_id = request_id
def to_dict(self):
user_idt = (
self.user_idt_format.format(user=self.user or '-',
tenant=self.tenant or '-',
domain=self.domain or '-',
user_domain=self.user_domain or '-',
p_domain=self.project_domain or '-'))
return {'user': self.user,
'tenant': self.tenant,
'domain': self.domain,
'user_domain': self.user_domain,
'project_domain': self.project_domain,
'is_admin': self.is_admin,
'read_only': self.read_only,
'show_deleted': self.show_deleted,
'auth_token': self.auth_token,
'request_id': self.request_id,
'instance_uuid': self.instance_uuid,
'user_identity': user_idt}
def get_admin_context(show_deleted=False):
context = RequestContext(None,
tenant=None,
is_admin=True,
show_deleted=show_deleted)
return context
def get_context_from_function_and_args(function, args, kwargs):
"""Find an arg of type RequestContext and return it.
This is useful in a couple of decorators where we don't
know much about the function we're wrapping.
"""
for arg in itertools.chain(kwargs.values(), args):
if isinstance(arg, RequestContext):
return arg
return None
def is_user_context(context):
"""Indicates if the request context is a normal user."""
if not context:
return False
if context.is_admin:
return False
if not context.user_id or not context.project_id:
return False
return True

View File

View File

@ -0,0 +1,162 @@
# Copyright (c) 2013 Rackspace Hosting
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Multiple DB API backend support.
A DB backend module should implement a method named 'get_backend' which
takes no arguments. The method can return any object that implements DB
API methods.
"""
import functools
import logging
import threading
import time
from cerberus.openstack.common.db import exception
from cerberus.openstack.common.gettextutils import _LE
from cerberus.openstack.common import importutils
LOG = logging.getLogger(__name__)
def safe_for_db_retry(f):
"""Enable db-retry for decorated function, if config option enabled."""
f.__dict__['enable_retry'] = True
return f
class wrap_db_retry(object):
"""Retry db.api methods, if DBConnectionError() raised
Retry decorated db.api methods. If we enabled `use_db_reconnect`
in config, this decorator will be applied to all db.api functions,
marked with @safe_for_db_retry decorator.
Decorator catchs DBConnectionError() and retries function in a
loop until it succeeds, or until maximum retries count will be reached.
"""
def __init__(self, retry_interval, max_retries, inc_retry_interval,
max_retry_interval):
super(wrap_db_retry, self).__init__()
self.retry_interval = retry_interval
self.max_retries = max_retries
self.inc_retry_interval = inc_retry_interval
self.max_retry_interval = max_retry_interval
def __call__(self, f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
next_interval = self.retry_interval
remaining = self.max_retries
while True:
try:
return f(*args, **kwargs)
except exception.DBConnectionError as e:
if remaining == 0:
LOG.exception(_LE('DB exceeded retry limit.'))
raise exception.DBError(e)
if remaining != -1:
remaining -= 1
LOG.exception(_LE('DB connection error.'))
# NOTE(vsergeyev): We are using patched time module, so
# this effectively yields the execution
# context to another green thread.
time.sleep(next_interval)
if self.inc_retry_interval:
next_interval = min(
next_interval * 2,
self.max_retry_interval
)
return wrapper
class DBAPI(object):
def __init__(self, backend_name, backend_mapping=None, lazy=False,
**kwargs):
"""Initialize the chosen DB API backend.
:param backend_name: name of the backend to load
:type backend_name: str
:param backend_mapping: backend name -> module/class to load mapping
:type backend_mapping: dict
:param lazy: load the DB backend lazily on the first DB API method call
:type lazy: bool
Keyword arguments:
:keyword use_db_reconnect: retry DB transactions on disconnect or not
:type use_db_reconnect: bool
:keyword retry_interval: seconds between transaction retries
:type retry_interval: int
:keyword inc_retry_interval: increase retry interval or not
:type inc_retry_interval: bool
:keyword max_retry_interval: max interval value between retries
:type max_retry_interval: int
:keyword max_retries: max number of retries before an error is raised
:type max_retries: int
"""
self._backend = None
self._backend_name = backend_name
self._backend_mapping = backend_mapping or {}
self._lock = threading.Lock()
if not lazy:
self._load_backend()
self.use_db_reconnect = kwargs.get('use_db_reconnect', False)
self.retry_interval = kwargs.get('retry_interval', 1)
self.inc_retry_interval = kwargs.get('inc_retry_interval', True)
self.max_retry_interval = kwargs.get('max_retry_interval', 10)
self.max_retries = kwargs.get('max_retries', 20)
def _load_backend(self):
with self._lock:
if not self._backend:
# Import the untranslated name if we don't have a mapping
backend_path = self._backend_mapping.get(self._backend_name,
self._backend_name)
backend_mod = importutils.import_module(backend_path)
self._backend = backend_mod.get_backend()
def __getattr__(self, key):
if not self._backend:
self._load_backend()
attr = getattr(self._backend, key)
if not hasattr(attr, '__call__'):
return attr
# NOTE(vsergeyev): If `use_db_reconnect` option is set to True, retry
# DB API methods, decorated with @safe_for_db_retry
# on disconnect.
if self.use_db_reconnect and hasattr(attr, 'enable_retry'):
attr = wrap_db_retry(
retry_interval=self.retry_interval,
max_retries=self.max_retries,
inc_retry_interval=self.inc_retry_interval,
max_retry_interval=self.max_retry_interval)(attr)
return attr

View File

@ -0,0 +1,56 @@
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""DB related custom exceptions."""
import six
from cerberus.openstack.common.gettextutils import _
class DBError(Exception):
"""Wraps an implementation specific exception."""
def __init__(self, inner_exception=None):
self.inner_exception = inner_exception
super(DBError, self).__init__(six.text_type(inner_exception))
class DBDuplicateEntry(DBError):
"""Wraps an implementation specific exception."""
def __init__(self, columns=[], inner_exception=None):
self.columns = columns
super(DBDuplicateEntry, self).__init__(inner_exception)
class DBDeadlock(DBError):
def __init__(self, inner_exception=None):
super(DBDeadlock, self).__init__(inner_exception)
class DBInvalidUnicodeParameter(Exception):
message = _("Invalid Parameter: "
"Unicode is not supported by the current database.")
class DbMigrationError(DBError):
"""Wraps migration specific exception."""
def __init__(self, message=None):
super(DbMigrationError, self).__init__(message)
class DBConnectionError(DBError):
"""Wraps connection specific exception."""
pass

View File

@ -0,0 +1,171 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import copy
from oslo.config import cfg
database_opts = [
cfg.StrOpt('sqlite_db',
deprecated_group='DEFAULT',
default='cerberus.sqlite',
help='The file name to use with SQLite'),
cfg.BoolOpt('sqlite_synchronous',
deprecated_group='DEFAULT',
default=True,
help='If True, SQLite uses synchronous mode'),
cfg.StrOpt('backend',
default='sqlalchemy',
deprecated_name='db_backend',
deprecated_group='DEFAULT',
help='The backend to use for db'),
cfg.StrOpt('connection',
help='The SQLAlchemy connection string used to connect to the '
'database',
secret=True,
deprecated_opts=[cfg.DeprecatedOpt('sql_connection',
group='DEFAULT'),
cfg.DeprecatedOpt('sql_connection',
group='DATABASE'),
cfg.DeprecatedOpt('connection',
group='sql'), ]),
cfg.StrOpt('mysql_sql_mode',
default='TRADITIONAL',
help='The SQL mode to be used for MySQL sessions. '
'This option, including the default, overrides any '
'server-set SQL mode. To use whatever SQL mode '
'is set by the server configuration, '
'set this to no value. Example: mysql_sql_mode='),
cfg.IntOpt('idle_timeout',
default=3600,
deprecated_opts=[cfg.DeprecatedOpt('sql_idle_timeout',
group='DEFAULT'),
cfg.DeprecatedOpt('sql_idle_timeout',
group='DATABASE'),
cfg.DeprecatedOpt('idle_timeout',
group='sql')],
help='Timeout before idle sql connections are reaped'),
cfg.IntOpt('min_pool_size',
default=1,
deprecated_opts=[cfg.DeprecatedOpt('sql_min_pool_size',
group='DEFAULT'),
cfg.DeprecatedOpt('sql_min_pool_size',
group='DATABASE')],
help='Minimum number of SQL connections to keep open in a '
'pool'),
cfg.IntOpt('max_pool_size',
default=None,
deprecated_opts=[cfg.DeprecatedOpt('sql_max_pool_size',
group='DEFAULT'),
cfg.DeprecatedOpt('sql_max_pool_size',
group='DATABASE')],
help='Maximum number of SQL connections to keep open in a '
'pool'),
cfg.IntOpt('max_retries',
default=10,
deprecated_opts=[cfg.DeprecatedOpt('sql_max_retries',
group='DEFAULT'),
cfg.DeprecatedOpt('sql_max_retries',
group='DATABASE')],
help='Maximum db connection retries during startup. '
'(setting -1 implies an infinite retry count)'),
cfg.IntOpt('retry_interval',
default=10,
deprecated_opts=[cfg.DeprecatedOpt('sql_retry_interval',
group='DEFAULT'),
cfg.DeprecatedOpt('reconnect_interval',
group='DATABASE')],
help='Interval between retries of opening a sql connection'),
cfg.IntOpt('max_overflow',
default=None,
deprecated_opts=[cfg.DeprecatedOpt('sql_max_overflow',
group='DEFAULT'),
cfg.DeprecatedOpt('sqlalchemy_max_overflow',
group='DATABASE')],
help='If set, use this value for max_overflow with sqlalchemy'),
cfg.IntOpt('connection_debug',
default=0,
deprecated_opts=[cfg.DeprecatedOpt('sql_connection_debug',
group='DEFAULT')],
help='Verbosity of SQL debugging information. 0=None, '
'100=Everything'),
cfg.BoolOpt('connection_trace',
default=False,
deprecated_opts=[cfg.DeprecatedOpt('sql_connection_trace',
group='DEFAULT')],
help='Add python stack traces to SQL as comment strings'),
cfg.IntOpt('pool_timeout',
default=None,
deprecated_opts=[cfg.DeprecatedOpt('sqlalchemy_pool_timeout',
group='DATABASE')],
help='If set, use this value for pool_timeout with sqlalchemy'),
cfg.BoolOpt('use_db_reconnect',
default=False,
help='Enable the experimental use of database reconnect '
'on connection lost'),
cfg.IntOpt('db_retry_interval',
default=1,
help='seconds between db connection retries'),
cfg.BoolOpt('db_inc_retry_interval',
default=True,
help='Whether to increase interval between db connection '
'retries, up to db_max_retry_interval'),
cfg.IntOpt('db_max_retry_interval',
default=10,
help='max seconds between db connection retries, if '
'db_inc_retry_interval is enabled'),
cfg.IntOpt('db_max_retries',
default=20,
help='maximum db connection retries before error is raised. '
'(setting -1 implies an infinite retry count)'),
]
CONF = cfg.CONF
CONF.register_opts(database_opts, 'database')
def set_defaults(sql_connection, sqlite_db, max_pool_size=None,
max_overflow=None, pool_timeout=None):
"""Set defaults for configuration variables."""
cfg.set_defaults(database_opts,
connection=sql_connection,
sqlite_db=sqlite_db)
# Update the QueuePool defaults
if max_pool_size is not None:
cfg.set_defaults(database_opts,
max_pool_size=max_pool_size)
if max_overflow is not None:
cfg.set_defaults(database_opts,
max_overflow=max_overflow)
if pool_timeout is not None:
cfg.set_defaults(database_opts,
pool_timeout=pool_timeout)
def list_opts():
"""Returns a list of oslo.config options available in the library.
The returned list includes all oslo.config options which may be registered
at runtime by the library.
Each element of the list is a tuple. The first element is the name of the
group under which the list of elements in the second element will be
registered. A group name of None corresponds to the [DEFAULT] group in
config files.
The purpose of this is to allow tools like the Oslo sample config file
generator to discover the options exposed to users by this library.
:returns: a list of (group_name, opts) tuples
"""
return [('database', copy.deepcopy(database_opts))]

View File

@ -0,0 +1,278 @@
# coding: utf-8
#
# Copyright (c) 2013 OpenStack Foundation
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# Base on code in migrate/changeset/databases/sqlite.py which is under
# the following license:
#
# The MIT License
#
# Copyright (c) 2009 Evan Rosson, Jan Dittberner, Domen Kožar
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import os
import re
from migrate.changeset import ansisql
from migrate.changeset.databases import sqlite
from migrate import exceptions as versioning_exceptions
from migrate.versioning import api as versioning_api
from migrate.versioning.repository import Repository
import sqlalchemy
from sqlalchemy.schema import UniqueConstraint
from cerberus.openstack.common.db import exception
from cerberus.openstack.common.gettextutils import _
def _get_unique_constraints(self, table):
"""Retrieve information about existing unique constraints of the table
This feature is needed for _recreate_table() to work properly.
Unfortunately, it's not available in sqlalchemy 0.7.x/0.8.x.
"""
data = table.metadata.bind.execute(
"""SELECT sql
FROM sqlite_master
WHERE
type='table' AND
name=:table_name""",
table_name=table.name
).fetchone()[0]
UNIQUE_PATTERN = "CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)"
return [
UniqueConstraint(
*[getattr(table.columns, c.strip(' "')) for c in cols.split(",")],
name=name
)
for name, cols in re.findall(UNIQUE_PATTERN, data)
]
def _recreate_table(self, table, column=None, delta=None, omit_uniques=None):
"""Recreate the table properly
Unlike the corresponding original method of sqlalchemy-migrate this one
doesn't drop existing unique constraints when creating a new one.
"""
table_name = self.preparer.format_table(table)
# we remove all indexes so as not to have
# problems during copy and re-create
for index in table.indexes:
index.drop()
# reflect existing unique constraints
for uc in self._get_unique_constraints(table):
table.append_constraint(uc)
# omit given unique constraints when creating a new table if required
table.constraints = set([
cons for cons in table.constraints
if omit_uniques is None or cons.name not in omit_uniques
])
self.append('ALTER TABLE %s RENAME TO migration_tmp' % table_name)
self.execute()
insertion_string = self._modify_table(table, column, delta)
table.create(bind=self.connection)
self.append(insertion_string % {'table_name': table_name})
self.execute()
self.append('DROP TABLE migration_tmp')
self.execute()
def _visit_migrate_unique_constraint(self, *p, **k):
"""Drop the given unique constraint
The corresponding original method of sqlalchemy-migrate just
raises NotImplemented error
"""
self.recreate_table(p[0].table, omit_uniques=[p[0].name])
def patch_migrate():
"""A workaround for SQLite's inability to alter things
SQLite abilities to alter tables are very limited (please read
http://www.sqlite.org/lang_altertable.html for more details).
E. g. one can't drop a column or a constraint in SQLite. The
workaround for this is to recreate the original table omitting
the corresponding constraint (or column).
sqlalchemy-migrate library has recreate_table() method that
implements this workaround, but it does it wrong:
- information about unique constraints of a table
is not retrieved. So if you have a table with one
unique constraint and a migration adding another one
you will end up with a table that has only the
latter unique constraint, and the former will be lost
- dropping of unique constraints is not supported at all
The proper way to fix this is to provide a pull-request to
sqlalchemy-migrate, but the project seems to be dead. So we
can go on with monkey-patching of the lib at least for now.
"""
# this patch is needed to ensure that recreate_table() doesn't drop
# existing unique constraints of the table when creating a new one
helper_cls = sqlite.SQLiteHelper
helper_cls.recreate_table = _recreate_table
helper_cls._get_unique_constraints = _get_unique_constraints
# this patch is needed to be able to drop existing unique constraints
constraint_cls = sqlite.SQLiteConstraintDropper
constraint_cls.visit_migrate_unique_constraint = \
_visit_migrate_unique_constraint
constraint_cls.__bases__ = (ansisql.ANSIColumnDropper,
sqlite.SQLiteConstraintGenerator)
def db_sync(engine, abs_path, version=None, init_version=0, sanity_check=True):
"""Upgrade or downgrade a database.
Function runs the upgrade() or downgrade() functions in change scripts.
:param engine: SQLAlchemy engine instance for a given database
:param abs_path: Absolute path to migrate repository.
:param version: Database will upgrade/downgrade until this version.
If None - database will update to the latest
available version.
:param init_version: Initial database version
:param sanity_check: Require schema sanity checking for all tables
"""
if version is not None:
try:
version = int(version)
except ValueError:
raise exception.DbMigrationError(
message=_("version should be an integer"))
current_version = db_version(engine, abs_path, init_version)
repository = _find_migrate_repo(abs_path)
if sanity_check:
_db_schema_sanity_check(engine)
if version is None or version > current_version:
return versioning_api.upgrade(engine, repository, version)
else:
return versioning_api.downgrade(engine, repository,
version)
def _db_schema_sanity_check(engine):
"""Ensure all database tables were created with required parameters.
:param engine: SQLAlchemy engine instance for a given database
"""
if engine.name == 'mysql':
onlyutf8_sql = ('SELECT TABLE_NAME,TABLE_COLLATION '
'from information_schema.TABLES '
'where TABLE_SCHEMA=%s and '
'TABLE_COLLATION NOT LIKE "%%utf8%%"')
# NOTE(morganfainberg): exclude the sqlalchemy-migrate and alembic
# versioning tables from the tables we need to verify utf8 status on.
# Non-standard table names are not supported.
EXCLUDED_TABLES = ['migrate_version', 'alembic_version']
table_names = [res[0] for res in
engine.execute(onlyutf8_sql, engine.url.database) if
res[0].lower() not in EXCLUDED_TABLES]
if len(table_names) > 0:
raise ValueError(_('Tables "%s" have non utf8 collation, '
'please make sure all tables are CHARSET=utf8'
) % ','.join(table_names))
def db_version(engine, abs_path, init_version):
"""Show the current version of the repository.
:param engine: SQLAlchemy engine instance for a given database
:param abs_path: Absolute path to migrate repository
:param version: Initial database version
"""
repository = _find_migrate_repo(abs_path)
try:
return versioning_api.db_version(engine, repository)
except versioning_exceptions.DatabaseNotControlledError:
meta = sqlalchemy.MetaData()
meta.reflect(bind=engine)
tables = meta.tables
if len(tables) == 0 or 'alembic_version' in tables:
db_version_control(engine, abs_path, version=init_version)
return versioning_api.db_version(engine, repository)
else:
raise exception.DbMigrationError(
message=_(
"The database is not under version control, but has "
"tables. Please stamp the current version of the schema "
"manually."))
def db_version_control(engine, abs_path, version=None):
"""Mark a database as under this repository's version control.
Once a database is under version control, schema changes should
only be done via change scripts in this repository.
:param engine: SQLAlchemy engine instance for a given database
:param abs_path: Absolute path to migrate repository
:param version: Initial database version
"""
repository = _find_migrate_repo(abs_path)
versioning_api.version_control(engine, repository, version)
return version
def _find_migrate_repo(abs_path):
"""Get the project's change script repository
:param abs_path: Absolute path to migrate repository
"""
if not os.path.exists(abs_path):
raise exception.DbMigrationError("Path %s not found" % abs_path)
return Repository(abs_path)

View File

@ -0,0 +1,78 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
import alembic
from alembic import config as alembic_config
import alembic.migration as alembic_migration
from cerberus.openstack.common.db.sqlalchemy.migration_cli import ext_base
from cerberus.openstack.common.db.sqlalchemy import session as db_session
class AlembicExtension(ext_base.MigrationExtensionBase):
order = 2
@property
def enabled(self):
return os.path.exists(self.alembic_ini_path)
def __init__(self, migration_config):
"""Extension to provide alembic features.
:param migration_config: Stores specific configuration for migrations
:type migration_config: dict
"""
self.alembic_ini_path = migration_config.get('alembic_ini_path', '')
self.config = alembic_config.Config(self.alembic_ini_path)
# option should be used if script is not in default directory
repo_path = migration_config.get('alembic_repo_path')
if repo_path:
self.config.set_main_option('script_location', repo_path)
self.db_url = migration_config['db_url']
def upgrade(self, version):
return alembic.command.upgrade(self.config, version or 'head')
def downgrade(self, version):
if isinstance(version, int) or version is None or version.isdigit():
version = 'base'
return alembic.command.downgrade(self.config, version)
def version(self):
engine = db_session.create_engine(self.db_url)
with engine.connect() as conn:
context = alembic_migration.MigrationContext.configure(conn)
return context.get_current_revision()
def revision(self, message='', autogenerate=False):
"""Creates template for migration.
:param message: Text that will be used for migration title
:type message: string
:param autogenerate: If True - generates diff based on current database
state
:type autogenerate: bool
"""
return alembic.command.revision(self.config, message=message,
autogenerate=autogenerate)
def stamp(self, revision):
"""Stamps database with provided revision.
:param revision: Should match one from repository or head - to stamp
database with most recent revision
:type revision: string
"""
return alembic.command.stamp(self.config, revision=revision)

View File

@ -0,0 +1,79 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import abc
import six
@six.add_metaclass(abc.ABCMeta)
class MigrationExtensionBase(object):
#used to sort migration in logical order
order = 0
@property
def enabled(self):
"""Used for availability verification of a plugin.
:rtype: bool
"""
return False
@abc.abstractmethod
def upgrade(self, version):
"""Used for upgrading database.
:param version: Desired database version
:type version: string
"""
@abc.abstractmethod
def downgrade(self, version):
"""Used for downgrading database.
:param version: Desired database version
:type version: string
"""
@abc.abstractmethod
def version(self):
"""Current database version.
:returns: Databse version
:rtype: string
"""
def revision(self, *args, **kwargs):
"""Used to generate migration script.
In migration engines that support this feature, it should generate
new migration script.
Accept arbitrary set of arguments.
"""
raise NotImplementedError()
def stamp(self, *args, **kwargs):
"""Stamps database based on plugin features.
Accept arbitrary set of arguments.
"""
raise NotImplementedError()
def __cmp__(self, other):
"""Used for definition of plugin order.
:param other: MigrationExtensionBase instance
:rtype: bool
"""
return self.order > other.order

View File

@ -0,0 +1,69 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import logging
import os
from cerberus.openstack.common.db.sqlalchemy import migration
from cerberus.openstack.common.db.sqlalchemy.migration_cli import ext_base
from cerberus.openstack.common.db.sqlalchemy import session as db_session
from cerberus.openstack.common.gettextutils import _LE
LOG = logging.getLogger(__name__)
class MigrateExtension(ext_base.MigrationExtensionBase):
"""Extension to provide sqlalchemy-migrate features.
:param migration_config: Stores specific configuration for migrations
:type migration_config: dict
"""
order = 1
def __init__(self, migration_config):
self.repository = migration_config.get('migration_repo_path', '')
self.init_version = migration_config.get('init_version', 0)
self.db_url = migration_config['db_url']
self.engine = db_session.create_engine(self.db_url)
@property
def enabled(self):
return os.path.exists(self.repository)
def upgrade(self, version):
version = None if version == 'head' else version
return migration.db_sync(
self.engine, self.repository, version,
init_version=self.init_version)
def downgrade(self, version):
try:
#version for migrate should be valid int - else skip
if version in ('base', None):
version = self.init_version
version = int(version)
return migration.db_sync(
self.engine, self.repository, version,
init_version=self.init_version)
except ValueError:
LOG.error(
_LE('Migration number for migrate plugin must be valid '
'integer or empty, if you want to downgrade '
'to initial state')
)
raise
def version(self):
return migration.db_version(
self.engine, self.repository, init_version=self.init_version)

View File

@ -0,0 +1,71 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from stevedore import enabled
MIGRATION_NAMESPACE = 'cerberus.openstack.common.migration'
def check_plugin_enabled(ext):
"""Used for EnabledExtensionManager"""
return ext.obj.enabled
class MigrationManager(object):
def __init__(self, migration_config):
self._manager = enabled.EnabledExtensionManager(
MIGRATION_NAMESPACE,
check_plugin_enabled,
invoke_kwds={'migration_config': migration_config},
invoke_on_load=True
)
if not self._plugins:
raise ValueError('There must be at least one plugin active.')
@property
def _plugins(self):
return sorted(ext.obj for ext in self._manager.extensions)
def upgrade(self, revision):
"""Upgrade database with all available backends."""
results = []
for plugin in self._plugins:
results.append(plugin.upgrade(revision))
return results
def downgrade(self, revision):
"""Downgrade database with available backends."""
#downgrading should be performed in reversed order
results = []
for plugin in reversed(self._plugins):
results.append(plugin.downgrade(revision))
return results
def version(self):
"""Return last version of db."""
last = None
for plugin in self._plugins:
version = plugin.version()
if version:
last = version
return last
def revision(self, message, autogenerate):
"""Generate template or autogenerated revision."""
#revision should be done only by last plugin
return self._plugins[-1].revision(message, autogenerate)
def stamp(self, revision):
"""Create stamp for a given revision."""
return self._plugins[-1].stamp(revision)

View File

@ -0,0 +1,119 @@
# Copyright (c) 2011 X.commerce, a business unit of eBay Inc.
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# Copyright 2011 Piston Cloud Computing, Inc.
# Copyright 2012 Cloudscaling Group, Inc.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
SQLAlchemy models.
"""
import six
from sqlalchemy import Column, Integer
from sqlalchemy import DateTime
from sqlalchemy.orm import object_mapper
from cerberus.openstack.common import timeutils
class ModelBase(six.Iterator):
"""Base class for models."""
__table_initialized__ = False
def save(self, session):
"""Save this object."""
# NOTE(boris-42): This part of code should be look like:
# session.add(self)
# session.flush()
# But there is a bug in sqlalchemy and eventlet that
# raises NoneType exception if there is no running
# transaction and rollback is called. As long as
# sqlalchemy has this bug we have to create transaction
# explicitly.
with session.begin(subtransactions=True):
session.add(self)
session.flush()
def __setitem__(self, key, value):
setattr(self, key, value)
def __getitem__(self, key):
return getattr(self, key)
def get(self, key, default=None):
return getattr(self, key, default)
@property
def _extra_keys(self):
"""Specifies custom fields
Subclasses can override this property to return a list
of custom fields that should be included in their dict
representation.
For reference check tests/db/sqlalchemy/test_models.py
"""
return []
def __iter__(self):
columns = dict(object_mapper(self).columns).keys()
# NOTE(russellb): Allow models to specify other keys that can be looked
# up, beyond the actual db columns. An example would be the 'name'
# property for an Instance.
columns.extend(self._extra_keys)
self._i = iter(columns)
return self
# In Python 3, __next__() has replaced next().
def __next__(self):
n = six.advance_iterator(self._i)
return n, getattr(self, n)
def next(self):
return self.__next__()
def update(self, values):
"""Make the model object behave like a dict."""
for k, v in six.iteritems(values):
setattr(self, k, v)
def iteritems(self):
"""Make the model object behave like a dict.
Includes attributes from joins.
"""
local = dict(self)
joined = dict([(k, v) for k, v in six.iteritems(self.__dict__)
if not k[0] == '_'])
local.update(joined)
return six.iteritems(local)
class TimestampMixin(object):
created_at = Column(DateTime, default=lambda: timeutils.utcnow())
updated_at = Column(DateTime, onupdate=lambda: timeutils.utcnow())
class SoftDeleteMixin(object):
deleted_at = Column(DateTime)
deleted = Column(Integer, default=0)
def soft_delete(self, session):
"""Mark this object as deleted."""
self.deleted = self.id
self.deleted_at = timeutils.utcnow()
self.save(session=session)

View File

@ -0,0 +1,157 @@
# Copyright 2013 Mirantis.inc
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Provision test environment for specific DB backends"""
import argparse
import logging
import os
import random
import string
from six import moves
import sqlalchemy
from cerberus.openstack.common.db import exception as exc
LOG = logging.getLogger(__name__)
def get_engine(uri):
"""Engine creation
Call the function without arguments to get admin connection. Admin
connection required to create temporary user and database for each
particular test. Otherwise use existing connection to recreate connection
to the temporary database.
"""
return sqlalchemy.create_engine(uri, poolclass=sqlalchemy.pool.NullPool)
def _execute_sql(engine, sql, driver):
"""Initialize connection, execute sql query and close it."""
try:
with engine.connect() as conn:
if driver == 'postgresql':
conn.connection.set_isolation_level(0)
for s in sql:
conn.execute(s)
except sqlalchemy.exc.OperationalError:
msg = ('%s does not match database admin '
'credentials or database does not exist.')
LOG.exception(msg % engine.url)
raise exc.DBConnectionError(msg % engine.url)
def create_database(engine):
"""Provide temporary user and database for each particular test."""
driver = engine.name
auth = {
'database': ''.join(random.choice(string.ascii_lowercase)
for i in moves.range(10)),
'user': engine.url.username,
'passwd': engine.url.password,
}
sqls = [
"drop database if exists %(database)s;",
"create database %(database)s;"
]
if driver == 'sqlite':
return 'sqlite:////tmp/%s' % auth['database']
elif driver in ['mysql', 'postgresql']:
sql_query = map(lambda x: x % auth, sqls)
_execute_sql(engine, sql_query, driver)
else:
raise ValueError('Unsupported RDBMS %s' % driver)
params = auth.copy()
params['backend'] = driver
return "%(backend)s://%(user)s:%(passwd)s@localhost/%(database)s" % params
def drop_database(admin_engine, current_uri):
"""Drop temporary database and user after each particular test."""
engine = get_engine(current_uri)
driver = engine.name
auth = {'database': engine.url.database, 'user': engine.url.username}
if driver == 'sqlite':
try:
os.remove(auth['database'])
except OSError:
pass
elif driver in ['mysql', 'postgresql']:
sql = "drop database if exists %(database)s;"
_execute_sql(admin_engine, [sql % auth], driver)
else:
raise ValueError('Unsupported RDBMS %s' % driver)
def main():
"""Controller to handle commands
::create: Create test user and database with random names.
::drop: Drop user and database created by previous command.
"""
parser = argparse.ArgumentParser(
description='Controller to handle database creation and dropping'
' commands.',
epilog='Under normal circumstances is not used directly.'
' Used in .testr.conf to automate test database creation'
' and dropping processes.')
subparsers = parser.add_subparsers(
help='Subcommands to manipulate temporary test databases.')
create = subparsers.add_parser(
'create',
help='Create temporary test '
'databases and users.')
create.set_defaults(which='create')
create.add_argument(
'instances_count',
type=int,
help='Number of databases to create.')
drop = subparsers.add_parser(
'drop',
help='Drop temporary test databases and users.')
drop.set_defaults(which='drop')
drop.add_argument(
'instances',
nargs='+',
help='List of databases uri to be dropped.')
args = parser.parse_args()
connection_string = os.getenv('OS_TEST_DBAPI_ADMIN_CONNECTION',
'sqlite://')
engine = get_engine(connection_string)
which = args.which
if which == "create":
for i in range(int(args.instances_count)):
print(create_database(engine))
elif which == "drop":
for db in args.instances:
drop_database(engine, db)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,933 @@
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Session Handling for SQLAlchemy backend.
Recommended ways to use sessions within this framework:
* Don't use them explicitly; this is like running with ``AUTOCOMMIT=1``.
`model_query()` will implicitly use a session when called without one
supplied. This is the ideal situation because it will allow queries
to be automatically retried if the database connection is interrupted.
.. note:: Automatic retry will be enabled in a future patch.
It is generally fine to issue several queries in a row like this. Even though
they may be run in separate transactions and/or separate sessions, each one
will see the data from the prior calls. If needed, undo- or rollback-like
functionality should be handled at a logical level. For an example, look at
the code around quotas and `reservation_rollback()`.
Examples:
.. code:: python
def get_foo(context, foo):
return (model_query(context, models.Foo).
filter_by(foo=foo).
first())
def update_foo(context, id, newfoo):
(model_query(context, models.Foo).
filter_by(id=id).
update({'foo': newfoo}))
def create_foo(context, values):
foo_ref = models.Foo()
foo_ref.update(values)
foo_ref.save()
return foo_ref
* Within the scope of a single method, keep all the reads and writes within
the context managed by a single session. In this way, the session's
`__exit__` handler will take care of calling `flush()` and `commit()` for
you. If using this approach, you should not explicitly call `flush()` or
`commit()`. Any error within the context of the session will cause the
session to emit a `ROLLBACK`. Database errors like `IntegrityError` will be
raised in `session`'s `__exit__` handler, and any try/except within the
context managed by `session` will not be triggered. And catching other
non-database errors in the session will not trigger the ROLLBACK, so
exception handlers should always be outside the session, unless the
developer wants to do a partial commit on purpose. If the connection is
dropped before this is possible, the database will implicitly roll back the
transaction.
.. note:: Statements in the session scope will not be automatically retried.
If you create models within the session, they need to be added, but you
do not need to call `model.save()`:
.. code:: python
def create_many_foo(context, foos):
session = sessionmaker()
with session.begin():
for foo in foos:
foo_ref = models.Foo()
foo_ref.update(foo)
session.add(foo_ref)
def update_bar(context, foo_id, newbar):
session = sessionmaker()
with session.begin():
foo_ref = (model_query(context, models.Foo, session).
filter_by(id=foo_id).
first())
(model_query(context, models.Bar, session).
filter_by(id=foo_ref['bar_id']).
update({'bar': newbar}))
.. note:: `update_bar` is a trivially simple example of using
``with session.begin``. Whereas `create_many_foo` is a good example of
when a transaction is needed, it is always best to use as few queries as
possible.
The two queries in `update_bar` can be better expressed using a single query
which avoids the need for an explicit transaction. It can be expressed like
so:
.. code:: python
def update_bar(context, foo_id, newbar):
subq = (model_query(context, models.Foo.id).
filter_by(id=foo_id).
limit(1).
subquery())
(model_query(context, models.Bar).
filter_by(id=subq.as_scalar()).
update({'bar': newbar}))
For reference, this emits approximately the following SQL statement:
.. code:: sql
UPDATE bar SET bar = ${newbar}
WHERE id=(SELECT bar_id FROM foo WHERE id = ${foo_id} LIMIT 1);
.. note:: `create_duplicate_foo` is a trivially simple example of catching an
exception while using ``with session.begin``. Here create two duplicate
instances with same primary key, must catch the exception out of context
managed by a single session:
.. code:: python
def create_duplicate_foo(context):
foo1 = models.Foo()
foo2 = models.Foo()
foo1.id = foo2.id = 1
session = sessionmaker()
try:
with session.begin():
session.add(foo1)
session.add(foo2)
except exception.DBDuplicateEntry as e:
handle_error(e)
* Passing an active session between methods. Sessions should only be passed
to private methods. The private method must use a subtransaction; otherwise
SQLAlchemy will throw an error when you call `session.begin()` on an existing
transaction. Public methods should not accept a session parameter and should
not be involved in sessions within the caller's scope.
Note that this incurs more overhead in SQLAlchemy than the above means
due to nesting transactions, and it is not possible to implicitly retry
failed database operations when using this approach.
This also makes code somewhat more difficult to read and debug, because a
single database transaction spans more than one method. Error handling
becomes less clear in this situation. When this is needed for code clarity,
it should be clearly documented.
.. code:: python
def myfunc(foo):
session = sessionmaker()
with session.begin():
# do some database things
bar = _private_func(foo, session)
return bar
def _private_func(foo, session=None):
if not session:
session = sessionmaker()
with session.begin(subtransaction=True):
# do some other database things
return bar
There are some things which it is best to avoid:
* Don't keep a transaction open any longer than necessary.
This means that your ``with session.begin()`` block should be as short
as possible, while still containing all the related calls for that
transaction.
* Avoid ``with_lockmode('UPDATE')`` when possible.
In MySQL/InnoDB, when a ``SELECT ... FOR UPDATE`` query does not match
any rows, it will take a gap-lock. This is a form of write-lock on the
"gap" where no rows exist, and prevents any other writes to that space.
This can effectively prevent any INSERT into a table by locking the gap
at the end of the index. Similar problems will occur if the SELECT FOR UPDATE
has an overly broad WHERE clause, or doesn't properly use an index.
One idea proposed at ODS Fall '12 was to use a normal SELECT to test the
number of rows matching a query, and if only one row is returned,
then issue the SELECT FOR UPDATE.
The better long-term solution is to use
``INSERT .. ON DUPLICATE KEY UPDATE``.
However, this can not be done until the "deleted" columns are removed and
proper UNIQUE constraints are added to the tables.
Enabling soft deletes:
* To use/enable soft-deletes, the `SoftDeleteMixin` must be added
to your model class. For example:
.. code:: python
class NovaBase(models.SoftDeleteMixin, models.ModelBase):
pass
Efficient use of soft deletes:
* There are two possible ways to mark a record as deleted:
`model.soft_delete()` and `query.soft_delete()`.
The `model.soft_delete()` method works with a single already-fetched entry.
`query.soft_delete()` makes only one db request for all entries that
correspond to the query.
* In almost all cases you should use `query.soft_delete()`. Some examples:
.. code:: python
def soft_delete_bar():
count = model_query(BarModel).find(some_condition).soft_delete()
if count == 0:
raise Exception("0 entries were soft deleted")
def complex_soft_delete_with_synchronization_bar(session=None):
if session is None:
session = sessionmaker()
with session.begin(subtransactions=True):
count = (model_query(BarModel).
find(some_condition).
soft_delete(synchronize_session=True))
# Here synchronize_session is required, because we
# don't know what is going on in outer session.
if count == 0:
raise Exception("0 entries were soft deleted")
* There is only one situation where `model.soft_delete()` is appropriate: when
you fetch a single record, work with it, and mark it as deleted in the same
transaction.
.. code:: python
def soft_delete_bar_model():
session = sessionmaker()
with session.begin():
bar_ref = model_query(BarModel).find(some_condition).first()
# Work with bar_ref
bar_ref.soft_delete(session=session)
However, if you need to work with all entries that correspond to query and
then soft delete them you should use the `query.soft_delete()` method:
.. code:: python
def soft_delete_multi_models():
session = sessionmaker()
with session.begin():
query = (model_query(BarModel, session=session).
find(some_condition))
model_refs = query.all()
# Work with model_refs
query.soft_delete(synchronize_session=False)
# synchronize_session=False should be set if there is no outer
# session and these entries are not used after this.
When working with many rows, it is very important to use query.soft_delete,
which issues a single query. Using `model.soft_delete()`, as in the following
example, is very inefficient.
.. code:: python
for bar_ref in bar_refs:
bar_ref.soft_delete(session=session)
# This will produce count(bar_refs) db requests.
"""
import functools
import logging
import re
import time
import six
from sqlalchemy import exc as sqla_exc
from sqlalchemy.interfaces import PoolListener
import sqlalchemy.orm
from sqlalchemy.pool import NullPool, StaticPool
from sqlalchemy.sql.expression import literal_column
from cerberus.openstack.common.db import exception
from cerberus.openstack.common.gettextutils import _LE, _LW
from cerberus.openstack.common import timeutils
LOG = logging.getLogger(__name__)
class SqliteForeignKeysListener(PoolListener):
"""Ensures that the foreign key constraints are enforced in SQLite.
The foreign key constraints are disabled by default in SQLite,
so the foreign key constraints will be enabled here for every
database connection
"""
def connect(self, dbapi_con, con_record):
dbapi_con.execute('pragma foreign_keys=ON')
# note(boris-42): In current versions of DB backends unique constraint
# violation messages follow the structure:
#
# sqlite:
# 1 column - (IntegrityError) column c1 is not unique
# N columns - (IntegrityError) column c1, c2, ..., N are not unique
#
# sqlite since 3.7.16:
# 1 column - (IntegrityError) UNIQUE constraint failed: tbl.k1
#
# N columns - (IntegrityError) UNIQUE constraint failed: tbl.k1, tbl.k2
#
# postgres:
# 1 column - (IntegrityError) duplicate key value violates unique
# constraint "users_c1_key"
# N columns - (IntegrityError) duplicate key value violates unique
# constraint "name_of_our_constraint"
#
# mysql:
# 1 column - (IntegrityError) (1062, "Duplicate entry 'value_of_c1' for key
# 'c1'")
# N columns - (IntegrityError) (1062, "Duplicate entry 'values joined
# with -' for key 'name_of_our_constraint'")
#
# ibm_db_sa:
# N columns - (IntegrityError) SQL0803N One or more values in the INSERT
# statement, UPDATE statement, or foreign key update caused by a
# DELETE statement are not valid because the primary key, unique
# constraint or unique index identified by "2" constrains table
# "NOVA.KEY_PAIRS" from having duplicate values for the index
# key.
_DUP_KEY_RE_DB = {
"sqlite": (re.compile(r"^.*columns?([^)]+)(is|are)\s+not\s+unique$"),
re.compile(r"^.*UNIQUE\s+constraint\s+failed:\s+(.+)$")),
"postgresql": (re.compile(r"^.*duplicate\s+key.*\"([^\"]+)\"\s*\n.*$"),),
"mysql": (re.compile(r"^.*\(1062,.*'([^\']+)'\"\)$"),),
"ibm_db_sa": (re.compile(r"^.*SQL0803N.*$"),),
}
def _raise_if_duplicate_entry_error(integrity_error, engine_name):
"""Raise exception if two entries are duplicated.
In this function will be raised DBDuplicateEntry exception if integrity
error wrap unique constraint violation.
"""
def get_columns_from_uniq_cons_or_name(columns):
# note(vsergeyev): UniqueConstraint name convention: "uniq_t0c10c2"
# where `t` it is table name and columns `c1`, `c2`
# are in UniqueConstraint.
uniqbase = "uniq_"
if not columns.startswith(uniqbase):
if engine_name == "postgresql":
return [columns[columns.index("_") + 1:columns.rindex("_")]]
return [columns]
return columns[len(uniqbase):].split("0")[1:]
if engine_name not in ("ibm_db_sa", "mysql", "sqlite", "postgresql"):
return
# FIXME(johannes): The usage of the .message attribute has been
# deprecated since Python 2.6. However, the exceptions raised by
# SQLAlchemy can differ when using unicode() and accessing .message.
# An audit across all three supported engines will be necessary to
# ensure there are no regressions.
for pattern in _DUP_KEY_RE_DB[engine_name]:
match = pattern.match(integrity_error.message)
if match:
break
else:
return
# NOTE(mriedem): The ibm_db_sa integrity error message doesn't provide the
# columns so we have to omit that from the DBDuplicateEntry error.
columns = ''
if engine_name != 'ibm_db_sa':
columns = match.group(1)
if engine_name == "sqlite":
columns = [c.split('.')[-1] for c in columns.strip().split(", ")]
else:
columns = get_columns_from_uniq_cons_or_name(columns)
raise exception.DBDuplicateEntry(columns, integrity_error)
# NOTE(comstud): In current versions of DB backends, Deadlock violation
# messages follow the structure:
#
# mysql:
# (OperationalError) (1213, 'Deadlock found when trying to get lock; try '
# 'restarting transaction') <query_str> <query_args>
_DEADLOCK_RE_DB = {
"mysql": re.compile(r"^.*\(1213, 'Deadlock.*")
}
def _raise_if_deadlock_error(operational_error, engine_name):
"""Raise exception on deadlock condition.
Raise DBDeadlock exception if OperationalError contains a Deadlock
condition.
"""
re = _DEADLOCK_RE_DB.get(engine_name)
if re is None:
return
# FIXME(johannes): The usage of the .message attribute has been
# deprecated since Python 2.6. However, the exceptions raised by
# SQLAlchemy can differ when using unicode() and accessing .message.
# An audit across all three supported engines will be necessary to
# ensure there are no regressions.
m = re.match(operational_error.message)
if not m:
return
raise exception.DBDeadlock(operational_error)
def _wrap_db_error(f):
@functools.wraps(f)
def _wrap(self, *args, **kwargs):
try:
assert issubclass(
self.__class__, (
sqlalchemy.orm.session.Session, SessionTransactionWrapper)
), ('_wrap_db_error() can only be applied to methods of '
'subclasses of sqlalchemy.orm.session.Session or '
' SessionTransactionWrapper')
return f(self, *args, **kwargs)
except UnicodeEncodeError:
raise exception.DBInvalidUnicodeParameter()
except sqla_exc.OperationalError as e:
_raise_if_db_connection_lost(e, self.bind)
_raise_if_deadlock_error(e, self.bind.dialect.name)
# NOTE(comstud): A lot of code is checking for OperationalError
# so let's not wrap it for now.
raise
# note(boris-42): We should catch unique constraint violation and
# wrap it by our own DBDuplicateEntry exception. Unique constraint
# violation is wrapped by IntegrityError.
except sqla_exc.IntegrityError as e:
# note(boris-42): SqlAlchemy doesn't unify errors from different
# DBs so we must do this. Also in some tables (for example
# instance_types) there are more than one unique constraint. This
# means we should get names of columns, which values violate
# unique constraint, from error message.
_raise_if_duplicate_entry_error(e, self.bind.dialect.name)
raise exception.DBError(e)
except exception.DBError:
# note(zzzeek) - if _wrap_db_error is applied to nested functions,
# ensure an existing DBError is propagated outwards
raise
except Exception as e:
LOG.exception(_LE('DB exception wrapped.'))
raise exception.DBError(e)
return _wrap
def _synchronous_switch_listener(dbapi_conn, connection_rec):
"""Switch sqlite connections to non-synchronous mode."""
dbapi_conn.execute("PRAGMA synchronous = OFF")
def _add_regexp_listener(dbapi_con, con_record):
"""Add REGEXP function to sqlite connections."""
def regexp(expr, item):
reg = re.compile(expr)
return reg.search(six.text_type(item)) is not None
dbapi_con.create_function('regexp', 2, regexp)
def _thread_yield(dbapi_con, con_record):
"""Ensure other greenthreads get a chance to be executed.
If we use eventlet.monkey_patch(), eventlet.greenthread.sleep(0) will
execute instead of time.sleep(0).
Force a context switch. With common database backends (eg MySQLdb and
sqlite), there is no implicit yield caused by network I/O since they are
implemented by C libraries that eventlet cannot monkey patch.
"""
time.sleep(0)
def _ping_listener(engine, dbapi_conn, connection_rec, connection_proxy):
"""Ensures that MySQL, PostgreSQL or DB2 connections are alive.
Borrowed from:
http://groups.google.com/group/sqlalchemy/msg/a4ce563d802c929f
"""
cursor = dbapi_conn.cursor()
try:
ping_sql = 'select 1'
if engine.name == 'ibm_db_sa':
# DB2 requires a table expression
ping_sql = 'select 1 from (values (1)) AS t1'
cursor.execute(ping_sql)
except Exception as ex:
if engine.dialect.is_disconnect(ex, dbapi_conn, cursor):
msg = _LW('Database server has gone away: %s') % ex
LOG.warning(msg)
# if the database server has gone away, all connections in the pool
# have become invalid and we can safely close all of them here,
# rather than waste time on checking of every single connection
engine.dispose()
# this will be handled by SQLAlchemy and will force it to create
# a new connection and retry the original action
raise sqla_exc.DisconnectionError(msg)
else:
raise
def _set_session_sql_mode(dbapi_con, connection_rec, sql_mode=None):
"""Set the sql_mode session variable.
MySQL supports several server modes. The default is None, but sessions
may choose to enable server modes like TRADITIONAL, ANSI,
several STRICT_* modes and others.
Note: passing in '' (empty string) for sql_mode clears
the SQL mode for the session, overriding a potentially set
server default.
"""
cursor = dbapi_con.cursor()
cursor.execute("SET SESSION sql_mode = %s", [sql_mode])
def _mysql_get_effective_sql_mode(engine):
"""Returns the effective SQL mode for connections from the engine pool.
Returns ``None`` if the mode isn't available, otherwise returns the mode.
"""
# Get the real effective SQL mode. Even when unset by
# our own config, the server may still be operating in a specific
# SQL mode as set by the server configuration.
# Also note that the checkout listener will be called on execute to
# set the mode if it's registered.
row = engine.execute("SHOW VARIABLES LIKE 'sql_mode'").fetchone()
if row is None:
return
return row[1]
def _mysql_check_effective_sql_mode(engine):
"""Logs a message based on the effective SQL mode for MySQL connections."""
realmode = _mysql_get_effective_sql_mode(engine)
if realmode is None:
LOG.warning(_LW('Unable to detect effective SQL mode'))
return
LOG.debug('MySQL server mode set to %s', realmode)
# 'TRADITIONAL' mode enables several other modes, so
# we need a substring match here
if not ('TRADITIONAL' in realmode.upper() or
'STRICT_ALL_TABLES' in realmode.upper()):
LOG.warning(_LW("MySQL SQL mode is '%s', "
"consider enabling TRADITIONAL or STRICT_ALL_TABLES"),
realmode)
def _mysql_set_mode_callback(engine, sql_mode):
if sql_mode is not None:
mode_callback = functools.partial(_set_session_sql_mode,
sql_mode=sql_mode)
sqlalchemy.event.listen(engine, 'connect', mode_callback)
_mysql_check_effective_sql_mode(engine)
def _is_db_connection_error(args):
"""Return True if error in connecting to db."""
# NOTE(adam_g): This is currently MySQL specific and needs to be extended
# to support Postgres and others.
# For the db2, the error code is -30081 since the db2 is still not ready
conn_err_codes = ('2002', '2003', '2006', '2013', '-30081')
for err_code in conn_err_codes:
if args.find(err_code) != -1:
return True
return False
def _raise_if_db_connection_lost(error, engine):
# NOTE(vsergeyev): Function is_disconnect(e, connection, cursor)
# requires connection and cursor in incoming parameters,
# but we have no possibility to create connection if DB
# is not available, so in such case reconnect fails.
# But is_disconnect() ignores these parameters, so it
# makes sense to pass to function None as placeholder
# instead of connection and cursor.
if engine.dialect.is_disconnect(error, None, None):
raise exception.DBConnectionError(error)
def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None,
idle_timeout=3600,
connection_debug=0, max_pool_size=None, max_overflow=None,
pool_timeout=None, sqlite_synchronous=True,
connection_trace=False, max_retries=10, retry_interval=10):
"""Return a new SQLAlchemy engine."""
connection_dict = sqlalchemy.engine.url.make_url(sql_connection)
engine_args = {
"pool_recycle": idle_timeout,
'convert_unicode': True,
}
logger = logging.getLogger('sqlalchemy.engine')
# Map SQL debug level to Python log level
if connection_debug >= 100:
logger.setLevel(logging.DEBUG)
elif connection_debug >= 50:
logger.setLevel(logging.INFO)
else:
logger.setLevel(logging.WARNING)
if "sqlite" in connection_dict.drivername:
if sqlite_fk:
engine_args["listeners"] = [SqliteForeignKeysListener()]
engine_args["poolclass"] = NullPool
if sql_connection == "sqlite://":
engine_args["poolclass"] = StaticPool
engine_args["connect_args"] = {'check_same_thread': False}
else:
if max_pool_size is not None:
engine_args['pool_size'] = max_pool_size
if max_overflow is not None:
engine_args['max_overflow'] = max_overflow
if pool_timeout is not None:
engine_args['pool_timeout'] = pool_timeout
engine = sqlalchemy.create_engine(sql_connection, **engine_args)
sqlalchemy.event.listen(engine, 'checkin', _thread_yield)
if engine.name in ('ibm_db_sa', 'mysql', 'postgresql'):
ping_callback = functools.partial(_ping_listener, engine)
sqlalchemy.event.listen(engine, 'checkout', ping_callback)
if engine.name == 'mysql':
if mysql_sql_mode:
_mysql_set_mode_callback(engine, mysql_sql_mode)
elif 'sqlite' in connection_dict.drivername:
if not sqlite_synchronous:
sqlalchemy.event.listen(engine, 'connect',
_synchronous_switch_listener)
sqlalchemy.event.listen(engine, 'connect', _add_regexp_listener)
if connection_trace and engine.dialect.dbapi.__name__ == 'MySQLdb':
_patch_mysqldb_with_stacktrace_comments()
try:
engine.connect()
except sqla_exc.OperationalError as e:
if not _is_db_connection_error(e.args[0]):
raise
remaining = max_retries
if remaining == -1:
remaining = 'infinite'
while True:
msg = _LW('SQL connection failed. %s attempts left.')
LOG.warning(msg % remaining)
if remaining != 'infinite':
remaining -= 1
time.sleep(retry_interval)
try:
engine.connect()
break
except sqla_exc.OperationalError as e:
if (remaining != 'infinite' and remaining == 0) or \
not _is_db_connection_error(e.args[0]):
raise
return engine
class Query(sqlalchemy.orm.query.Query):
"""Subclass of sqlalchemy.query with soft_delete() method."""
def soft_delete(self, synchronize_session='evaluate'):
return self.update({'deleted': literal_column('id'),
'updated_at': literal_column('updated_at'),
'deleted_at': timeutils.utcnow()},
synchronize_session=synchronize_session)
class Session(sqlalchemy.orm.session.Session):
"""Custom Session class to avoid SqlAlchemy Session monkey patching."""
@_wrap_db_error
def query(self, *args, **kwargs):
return super(Session, self).query(*args, **kwargs)
@_wrap_db_error
def flush(self, *args, **kwargs):
return super(Session, self).flush(*args, **kwargs)
@_wrap_db_error
def execute(self, *args, **kwargs):
return super(Session, self).execute(*args, **kwargs)
@_wrap_db_error
def commit(self, *args, **kwargs):
return super(Session, self).commit(*args, **kwargs)
def begin(self, **kw):
trans = super(Session, self).begin(**kw)
trans.__class__ = SessionTransactionWrapper
return trans
class SessionTransactionWrapper(sqlalchemy.orm.session.SessionTransaction):
@property
def bind(self):
return self.session.bind
@_wrap_db_error
def commit(self, *args, **kwargs):
return super(SessionTransactionWrapper, self).commit(*args, **kwargs)
@_wrap_db_error
def rollback(self, *args, **kwargs):
return super(SessionTransactionWrapper, self).rollback(*args, **kwargs)
def get_maker(engine, autocommit=True, expire_on_commit=False):
"""Return a SQLAlchemy sessionmaker using the given engine."""
return sqlalchemy.orm.sessionmaker(bind=engine,
class_=Session,
autocommit=autocommit,
expire_on_commit=expire_on_commit,
query_cls=Query)
def _patch_mysqldb_with_stacktrace_comments():
"""Adds current stack trace as a comment in queries.
Patches MySQLdb.cursors.BaseCursor._do_query.
"""
import MySQLdb.cursors
import traceback
old_mysql_do_query = MySQLdb.cursors.BaseCursor._do_query
def _do_query(self, q):
stack = ''
for filename, line, method, function in traceback.extract_stack():
# exclude various common things from trace
if filename.endswith('session.py') and method == '_do_query':
continue
if filename.endswith('api.py') and method == 'wrapper':
continue
if filename.endswith('utils.py') and method == '_inner':
continue
if filename.endswith('exception.py') and method == '_wrap':
continue
# db/api is just a wrapper around db/sqlalchemy/api
if filename.endswith('db/api.py'):
continue
# only trace inside cerberus
index = filename.rfind('cerberus')
if index == -1:
continue
stack += "File:%s:%s Method:%s() Line:%s | " \
% (filename[index:], line, method, function)
# strip trailing " | " from stack
if stack:
stack = stack[:-3]
qq = "%s /* %s */" % (q, stack)
else:
qq = q
old_mysql_do_query(self, qq)
setattr(MySQLdb.cursors.BaseCursor, '_do_query', _do_query)
class EngineFacade(object):
"""A helper class for removing of global engine instances from cerberus.db.
As a library, cerberus.db can't decide where to store/when to create engine
and sessionmaker instances, so this must be left for a target application.
On the other hand, in order to simplify the adoption of cerberus.db changes,
we'll provide a helper class, which creates engine and sessionmaker
on its instantiation and provides get_engine()/get_session() methods
that are compatible with corresponding utility functions that currently
exist in target projects, e.g. in Nova.
engine/sessionmaker instances will still be global (and they are meant to
be global), but they will be stored in the app context, rather that in the
cerberus.db context.
Note: using of this helper is completely optional and you are encouraged to
integrate engine/sessionmaker instances into your apps any way you like
(e.g. one might want to bind a session to a request context). Two important
things to remember:
1. An Engine instance is effectively a pool of DB connections, so it's
meant to be shared (and it's thread-safe).
2. A Session instance is not meant to be shared and represents a DB
transactional context (i.e. it's not thread-safe). sessionmaker is
a factory of sessions.
"""
def __init__(self, sql_connection,
sqlite_fk=False, autocommit=True,
expire_on_commit=False, **kwargs):
"""Initialize engine and sessionmaker instances.
:param sqlite_fk: enable foreign keys in SQLite
:type sqlite_fk: bool
:param autocommit: use autocommit mode for created Session instances
:type autocommit: bool
:param expire_on_commit: expire session objects on commit
:type expire_on_commit: bool
Keyword arguments:
:keyword mysql_sql_mode: the SQL mode to be used for MySQL sessions.
(defaults to TRADITIONAL)
:keyword idle_timeout: timeout before idle sql connections are reaped
(defaults to 3600)
:keyword connection_debug: verbosity of SQL debugging information.
0=None, 100=Everything (defaults to 0)
:keyword max_pool_size: maximum number of SQL connections to keep open
in a pool (defaults to SQLAlchemy settings)
:keyword max_overflow: if set, use this value for max_overflow with
sqlalchemy (defaults to SQLAlchemy settings)
:keyword pool_timeout: if set, use this value for pool_timeout with
sqlalchemy (defaults to SQLAlchemy settings)
:keyword sqlite_synchronous: if True, SQLite uses synchronous mode
(defaults to True)
:keyword connection_trace: add python stack traces to SQL as comment
strings (defaults to False)
:keyword max_retries: maximum db connection retries during startup.
(setting -1 implies an infinite retry count)
(defaults to 10)
:keyword retry_interval: interval between retries of opening a sql
connection (defaults to 10)
"""
super(EngineFacade, self).__init__()
self._engine = create_engine(
sql_connection=sql_connection,
sqlite_fk=sqlite_fk,
mysql_sql_mode=kwargs.get('mysql_sql_mode', 'TRADITIONAL'),
idle_timeout=kwargs.get('idle_timeout', 3600),
connection_debug=kwargs.get('connection_debug', 0),
max_pool_size=kwargs.get('max_pool_size'),
max_overflow=kwargs.get('max_overflow'),
pool_timeout=kwargs.get('pool_timeout'),
sqlite_synchronous=kwargs.get('sqlite_synchronous', True),
connection_trace=kwargs.get('connection_trace', False),
max_retries=kwargs.get('max_retries', 10),
retry_interval=kwargs.get('retry_interval', 10))
self._session_maker = get_maker(
engine=self._engine,
autocommit=autocommit,
expire_on_commit=expire_on_commit)
def get_engine(self):
"""Get the engine instance (note, that it's shared)."""
return self._engine
def get_session(self, **kwargs):
"""Get a Session instance.
If passed, keyword arguments values override the ones used when the
sessionmaker instance was created.
:keyword autocommit: use autocommit mode for created Session instances
:type autocommit: bool
:keyword expire_on_commit: expire session objects on commit
:type expire_on_commit: bool
"""
for arg in kwargs:
if arg not in ('autocommit', 'expire_on_commit'):
del kwargs[arg]
return self._session_maker(**kwargs)
@classmethod
def from_config(cls, connection_string, conf,
sqlite_fk=False, autocommit=True, expire_on_commit=False):
"""Initialize EngineFacade using oslo.config config instance options.
:param connection_string: SQLAlchemy connection string
:type connection_string: string
:param conf: oslo.config config instance
:type conf: oslo.config.cfg.ConfigOpts
:param sqlite_fk: enable foreign keys in SQLite
:type sqlite_fk: bool
:param autocommit: use autocommit mode for created Session instances
:type autocommit: bool
:param expire_on_commit: expire session objects on commit
:type expire_on_commit: bool
"""
return cls(sql_connection=connection_string,
sqlite_fk=sqlite_fk,
autocommit=autocommit,
expire_on_commit=expire_on_commit,
**dict(conf.database.items()))

View File

@ -0,0 +1,153 @@
# Copyright (c) 2013 OpenStack Foundation
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import abc
import functools
import os
import fixtures
import six
from cerberus.openstack.common.db.sqlalchemy import session
from cerberus.openstack.common.db.sqlalchemy import utils
from cerberus.openstack.common.fixture import lockutils
from cerberus.openstack.common import test
class DbFixture(fixtures.Fixture):
"""Basic database fixture.
Allows to run tests on various db backends, such as SQLite, MySQL and
PostgreSQL. By default use sqlite backend. To override default backend
uri set env variable OS_TEST_DBAPI_CONNECTION with database admin
credentials for specific backend.
"""
def _get_uri(self):
return os.getenv('OS_TEST_DBAPI_CONNECTION', 'sqlite://')
def __init__(self, test):
super(DbFixture, self).__init__()
self.test = test
def setUp(self):
super(DbFixture, self).setUp()
self.test.engine = session.create_engine(self._get_uri())
self.test.sessionmaker = session.get_maker(self.test.engine)
self.addCleanup(self.test.engine.dispose)
class DbTestCase(test.BaseTestCase):
"""Base class for testing of DB code.
Using `DbFixture`. Intended to be the main database test case to use all
the tests on a given backend with user defined uri. Backend specific
tests should be decorated with `backend_specific` decorator.
"""
FIXTURE = DbFixture
def setUp(self):
super(DbTestCase, self).setUp()
self.useFixture(self.FIXTURE(self))
ALLOWED_DIALECTS = ['sqlite', 'mysql', 'postgresql']
def backend_specific(*dialects):
"""Decorator to skip backend specific tests on inappropriate engines.
::dialects: list of dialects names under which the test will be launched.
"""
def wrap(f):
@functools.wraps(f)
def ins_wrap(self):
if not set(dialects).issubset(ALLOWED_DIALECTS):
raise ValueError(
"Please use allowed dialects: %s" % ALLOWED_DIALECTS)
if self.engine.name not in dialects:
msg = ('The test "%s" can be run '
'only on %s. Current engine is %s.')
args = (f.__name__, ' '.join(dialects), self.engine.name)
self.skip(msg % args)
else:
return f(self)
return ins_wrap
return wrap
@six.add_metaclass(abc.ABCMeta)
class OpportunisticFixture(DbFixture):
"""Base fixture to use default CI databases.
The databases exist in OpenStack CI infrastructure. But for the
correct functioning in local environment the databases must be
created manually.
"""
DRIVER = abc.abstractproperty(lambda: None)
DBNAME = PASSWORD = USERNAME = 'openstack_citest'
def _get_uri(self):
return utils.get_connect_string(backend=self.DRIVER,
user=self.USERNAME,
passwd=self.PASSWORD,
database=self.DBNAME)
@six.add_metaclass(abc.ABCMeta)
class OpportunisticTestCase(DbTestCase):
"""Base test case to use default CI databases.
The subclasses of the test case are running only when openstack_citest
database is available otherwise a tests will be skipped.
"""
FIXTURE = abc.abstractproperty(lambda: None)
def setUp(self):
# TODO(bnemec): Remove this once infra is ready for
# https://review.openstack.org/#/c/74963/ to merge.
self.useFixture(lockutils.LockFixture('opportunistic-db'))
credentials = {
'backend': self.FIXTURE.DRIVER,
'user': self.FIXTURE.USERNAME,
'passwd': self.FIXTURE.PASSWORD,
'database': self.FIXTURE.DBNAME}
if self.FIXTURE.DRIVER and not utils.is_backend_avail(**credentials):
msg = '%s backend is not available.' % self.FIXTURE.DRIVER
return self.skip(msg)
super(OpportunisticTestCase, self).setUp()
class MySQLOpportunisticFixture(OpportunisticFixture):
DRIVER = 'mysql'
class PostgreSQLOpportunisticFixture(OpportunisticFixture):
DRIVER = 'postgresql'
class MySQLOpportunisticTestCase(OpportunisticTestCase):
FIXTURE = MySQLOpportunisticFixture
class PostgreSQLOpportunisticTestCase(OpportunisticTestCase):
FIXTURE = PostgreSQLOpportunisticFixture

View File

@ -0,0 +1,269 @@
# Copyright 2010-2011 OpenStack Foundation
# Copyright 2012-2013 IBM Corp.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import functools
import logging
import os
import subprocess
import lockfile
from six import moves
from six.moves.urllib import parse
import sqlalchemy
import sqlalchemy.exc
from cerberus.openstack.common.db.sqlalchemy import utils
from cerberus.openstack.common.gettextutils import _LE
from cerberus.openstack.common import test
LOG = logging.getLogger(__name__)
def _have_mysql(user, passwd, database):
present = os.environ.get('TEST_MYSQL_PRESENT')
if present is None:
return utils.is_backend_avail(backend='mysql',
user=user,
passwd=passwd,
database=database)
return present.lower() in ('', 'true')
def _have_postgresql(user, passwd, database):
present = os.environ.get('TEST_POSTGRESQL_PRESENT')
if present is None:
return utils.is_backend_avail(backend='postgres',
user=user,
passwd=passwd,
database=database)
return present.lower() in ('', 'true')
def _set_db_lock(lock_path=None, lock_prefix=None):
def decorator(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
try:
path = lock_path or os.environ.get("CERBERUS_LOCK_PATH")
lock = lockfile.FileLock(os.path.join(path, lock_prefix))
with lock:
LOG.debug('Got lock "%s"' % f.__name__)
return f(*args, **kwargs)
finally:
LOG.debug('Lock released "%s"' % f.__name__)
return wrapper
return decorator
class BaseMigrationTestCase(test.BaseTestCase):
"""Base class fort testing of migration utils."""
def __init__(self, *args, **kwargs):
super(BaseMigrationTestCase, self).__init__(*args, **kwargs)
self.DEFAULT_CONFIG_FILE = os.path.join(os.path.dirname(__file__),
'test_migrations.conf')
# Test machines can set the TEST_MIGRATIONS_CONF variable
# to override the location of the config file for migration testing
self.CONFIG_FILE_PATH = os.environ.get('TEST_MIGRATIONS_CONF',
self.DEFAULT_CONFIG_FILE)
self.test_databases = {}
self.migration_api = None
def setUp(self):
super(BaseMigrationTestCase, self).setUp()
# Load test databases from the config file. Only do this
# once. No need to re-run this on each test...
LOG.debug('config_path is %s' % self.CONFIG_FILE_PATH)
if os.path.exists(self.CONFIG_FILE_PATH):
cp = moves.configparser.RawConfigParser()
try:
cp.read(self.CONFIG_FILE_PATH)
defaults = cp.defaults()
for key, value in defaults.items():
self.test_databases[key] = value
except moves.configparser.ParsingError as e:
self.fail("Failed to read test_migrations.conf config "
"file. Got error: %s" % e)
else:
self.fail("Failed to find test_migrations.conf config "
"file.")
self.engines = {}
for key, value in self.test_databases.items():
self.engines[key] = sqlalchemy.create_engine(value)
# We start each test case with a completely blank slate.
self._reset_databases()
def tearDown(self):
# We destroy the test data store between each test case,
# and recreate it, which ensures that we have no side-effects
# from the tests
self._reset_databases()
super(BaseMigrationTestCase, self).tearDown()
def execute_cmd(self, cmd=None):
process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
output = process.communicate()[0]
LOG.debug(output)
self.assertEqual(0, process.returncode,
"Failed to run: %s\n%s" % (cmd, output))
def _reset_pg(self, conn_pieces):
(user,
password,
database,
host) = utils.get_db_connection_info(conn_pieces)
os.environ['PGPASSWORD'] = password
os.environ['PGUSER'] = user
# note(boris-42): We must create and drop database, we can't
# drop database which we have connected to, so for such
# operations there is a special database template1.
sqlcmd = ("psql -w -U %(user)s -h %(host)s -c"
" '%(sql)s' -d template1")
sql = ("drop database if exists %s;") % database
droptable = sqlcmd % {'user': user, 'host': host, 'sql': sql}
self.execute_cmd(droptable)
sql = ("create database %s;") % database
createtable = sqlcmd % {'user': user, 'host': host, 'sql': sql}
self.execute_cmd(createtable)
os.unsetenv('PGPASSWORD')
os.unsetenv('PGUSER')
@_set_db_lock(lock_prefix='migration_tests-')
def _reset_databases(self):
for key, engine in self.engines.items():
conn_string = self.test_databases[key]
conn_pieces = parse.urlparse(conn_string)
engine.dispose()
if conn_string.startswith('sqlite'):
# We can just delete the SQLite database, which is
# the easiest and cleanest solution
db_path = conn_pieces.path.strip('/')
if os.path.exists(db_path):
os.unlink(db_path)
# No need to recreate the SQLite DB. SQLite will
# create it for us if it's not there...
elif conn_string.startswith('mysql'):
# We can execute the MySQL client to destroy and re-create
# the MYSQL database, which is easier and less error-prone
# than using SQLAlchemy to do this via MetaData...trust me.
(user, password, database, host) = \
utils.get_db_connection_info(conn_pieces)
sql = ("drop database if exists %(db)s; "
"create database %(db)s;") % {'db': database}
cmd = ("mysql -u \"%(user)s\" -p\"%(password)s\" -h %(host)s "
"-e \"%(sql)s\"") % {'user': user, 'password': password,
'host': host, 'sql': sql}
self.execute_cmd(cmd)
elif conn_string.startswith('postgresql'):
self._reset_pg(conn_pieces)
class WalkVersionsMixin(object):
def _walk_versions(self, engine=None, snake_walk=False, downgrade=True):
# Determine latest version script from the repo, then
# upgrade from 1 through to the latest, with no data
# in the databases. This just checks that the schema itself
# upgrades successfully.
# Place the database under version control
self.migration_api.version_control(engine, self.REPOSITORY,
self.INIT_VERSION)
self.assertEqual(self.INIT_VERSION,
self.migration_api.db_version(engine,
self.REPOSITORY))
LOG.debug('latest version is %s' % self.REPOSITORY.latest)
versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1)
for version in versions:
# upgrade -> downgrade -> upgrade
self._migrate_up(engine, version, with_data=True)
if snake_walk:
downgraded = self._migrate_down(
engine, version - 1, with_data=True)
if downgraded:
self._migrate_up(engine, version)
if downgrade:
# Now walk it back down to 0 from the latest, testing
# the downgrade paths.
for version in reversed(versions):
# downgrade -> upgrade -> downgrade
downgraded = self._migrate_down(engine, version - 1)
if snake_walk and downgraded:
self._migrate_up(engine, version)
self._migrate_down(engine, version - 1)
def _migrate_down(self, engine, version, with_data=False):
try:
self.migration_api.downgrade(engine, self.REPOSITORY, version)
except NotImplementedError:
# NOTE(sirp): some migrations, namely release-level
# migrations, don't support a downgrade.
return False
self.assertEqual(
version, self.migration_api.db_version(engine, self.REPOSITORY))
# NOTE(sirp): `version` is what we're downgrading to (i.e. the 'target'
# version). So if we have any downgrade checks, they need to be run for
# the previous (higher numbered) migration.
if with_data:
post_downgrade = getattr(
self, "_post_downgrade_%03d" % (version + 1), None)
if post_downgrade:
post_downgrade(engine)
return True
def _migrate_up(self, engine, version, with_data=False):
"""migrate up to a new version of the db.
We allow for data insertion and post checks at every
migration version with special _pre_upgrade_### and
_check_### functions in the main test.
"""
# NOTE(sdague): try block is here because it's impossible to debug
# where a failed data migration happens otherwise
try:
if with_data:
data = None
pre_upgrade = getattr(
self, "_pre_upgrade_%03d" % version, None)
if pre_upgrade:
data = pre_upgrade(engine)
self.migration_api.upgrade(engine, self.REPOSITORY, version)
self.assertEqual(version,
self.migration_api.db_version(engine,
self.REPOSITORY))
if with_data:
check = getattr(self, "_check_%03d" % version, None)
if check:
check(engine, data)
except Exception:
LOG.error(_LE("Failed to migrate to version %s on engine %s") %
(version, engine))
raise

View File

@ -0,0 +1,647 @@
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# Copyright 2010-2011 OpenStack Foundation.
# Copyright 2012 Justin Santa Barbara
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import logging
import re
import sqlalchemy
from sqlalchemy import Boolean
from sqlalchemy import CheckConstraint
from sqlalchemy import Column
from sqlalchemy.engine import reflection
from sqlalchemy.ext.compiler import compiles
from sqlalchemy import func
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import or_
from sqlalchemy.sql.expression import literal_column
from sqlalchemy.sql.expression import UpdateBase
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy.types import NullType
from cerberus.openstack.common import context as request_context
from cerberus.openstack.common.db.sqlalchemy import models
from cerberus.openstack.common.gettextutils import _, _LI, _LW
from cerberus.openstack.common import timeutils
LOG = logging.getLogger(__name__)
_DBURL_REGEX = re.compile(r"[^:]+://([^:]+):([^@]+)@.+")
def sanitize_db_url(url):
match = _DBURL_REGEX.match(url)
if match:
return '%s****:****%s' % (url[:match.start(1)], url[match.end(2):])
return url
class InvalidSortKey(Exception):
message = _("Sort key supplied was not valid.")
# copy from glance/db/sqlalchemy/api.py
def paginate_query(query, model, limit, sort_keys, marker=None,
sort_dir=None, sort_dirs=None):
"""Returns a query with sorting / pagination criteria added.
Pagination works by requiring a unique sort_key, specified by sort_keys.
(If sort_keys is not unique, then we risk looping through values.)
We use the last row in the previous page as the 'marker' for pagination.
So we must return values that follow the passed marker in the order.
With a single-valued sort_key, this would be easy: sort_key > X.
With a compound-values sort_key, (k1, k2, k3) we must do this to repeat
the lexicographical ordering:
(k1 > X1) or (k1 == X1 && k2 > X2) or (k1 == X1 && k2 == X2 && k3 > X3)
We also have to cope with different sort_directions.
Typically, the id of the last row is used as the client-facing pagination
marker, then the actual marker object must be fetched from the db and
passed in to us as marker.
:param query: the query object to which we should add paging/sorting
:param model: the ORM model class
:param limit: maximum number of items to return
:param sort_keys: array of attributes by which results should be sorted
:param marker: the last item of the previous page; we returns the next
results after this value.
:param sort_dir: direction in which results should be sorted (asc, desc)
:param sort_dirs: per-column array of sort_dirs, corresponding to sort_keys
:rtype: sqlalchemy.orm.query.Query
:return: The query with sorting/pagination added.
"""
if 'id' not in sort_keys:
# TODO(justinsb): If this ever gives a false-positive, check
# the actual primary key, rather than assuming its id
LOG.warning(_LW('Id not in sort_keys; is sort_keys unique?'))
assert(not (sort_dir and sort_dirs))
# Default the sort direction to ascending
if sort_dirs is None and sort_dir is None:
sort_dir = 'asc'
# Ensure a per-column sort direction
if sort_dirs is None:
sort_dirs = [sort_dir for _sort_key in sort_keys]
assert(len(sort_dirs) == len(sort_keys))
# Add sorting
for current_sort_key, current_sort_dir in zip(sort_keys, sort_dirs):
try:
sort_dir_func = {
'asc': sqlalchemy.asc,
'desc': sqlalchemy.desc,
}[current_sort_dir]
except KeyError:
raise ValueError(_("Unknown sort direction, "
"must be 'desc' or 'asc'"))
try:
sort_key_attr = getattr(model, current_sort_key)
except AttributeError:
raise InvalidSortKey()
query = query.order_by(sort_dir_func(sort_key_attr))
# Add pagination
if marker is not None:
marker_values = []
for sort_key in sort_keys:
v = getattr(marker, sort_key)
marker_values.append(v)
# Build up an array of sort criteria as in the docstring
criteria_list = []
for i in range(len(sort_keys)):
crit_attrs = []
for j in range(i):
model_attr = getattr(model, sort_keys[j])
crit_attrs.append((model_attr == marker_values[j]))
model_attr = getattr(model, sort_keys[i])
if sort_dirs[i] == 'desc':
crit_attrs.append((model_attr < marker_values[i]))
else:
crit_attrs.append((model_attr > marker_values[i]))
criteria = sqlalchemy.sql.and_(*crit_attrs)
criteria_list.append(criteria)
f = sqlalchemy.sql.or_(*criteria_list)
query = query.filter(f)
if limit is not None:
query = query.limit(limit)
return query
def _read_deleted_filter(query, db_model, read_deleted):
if 'deleted' not in db_model.__table__.columns:
raise ValueError(_("There is no `deleted` column in `%s` table. "
"Project doesn't use soft-deleted feature.")
% db_model.__name__)
default_deleted_value = db_model.__table__.c.deleted.default.arg
if read_deleted == 'no':
query = query.filter(db_model.deleted == default_deleted_value)
elif read_deleted == 'yes':
pass # omit the filter to include deleted and active
elif read_deleted == 'only':
query = query.filter(db_model.deleted != default_deleted_value)
else:
raise ValueError(_("Unrecognized read_deleted value '%s'")
% read_deleted)
return query
def _project_filter(query, db_model, context, project_only):
if project_only and 'project_id' not in db_model.__table__.columns:
raise ValueError(_("There is no `project_id` column in `%s` table.")
% db_model.__name__)
if request_context.is_user_context(context) and project_only:
if project_only == 'allow_none':
is_none = None
query = query.filter(or_(db_model.project_id == context.project_id,
db_model.project_id == is_none))
else:
query = query.filter(db_model.project_id == context.project_id)
return query
def model_query(context, model, session, args=None, project_only=False,
read_deleted=None):
"""Query helper that accounts for context's `read_deleted` field.
:param context: context to query under
:param model: Model to query. Must be a subclass of ModelBase.
:type model: models.ModelBase
:param session: The session to use.
:type session: sqlalchemy.orm.session.Session
:param args: Arguments to query. If None - model is used.
:type args: tuple
:param project_only: If present and context is user-type, then restrict
query to match the context's project_id. If set to
'allow_none', restriction includes project_id = None.
:type project_only: bool
:param read_deleted: If present, overrides context's read_deleted field.
:type read_deleted: bool
Usage:
..code:: python
result = (utils.model_query(context, models.Instance, session=session)
.filter_by(uuid=instance_uuid)
.all())
query = utils.model_query(
context, Node,
session=session,
args=(func.count(Node.id), func.sum(Node.ram))
).filter_by(project_id=project_id)
"""
if not read_deleted:
if hasattr(context, 'read_deleted'):
# NOTE(viktors): some projects use `read_deleted` attribute in
# their contexts instead of `show_deleted`.
read_deleted = context.read_deleted
else:
read_deleted = context.show_deleted
if not issubclass(model, models.ModelBase):
raise TypeError(_("model should be a subclass of ModelBase"))
query = session.query(model) if not args else session.query(*args)
query = _read_deleted_filter(query, model, read_deleted)
query = _project_filter(query, model, context, project_only)
return query
def get_table(engine, name):
"""Returns an sqlalchemy table dynamically from db.
Needed because the models don't work for us in migrations
as models will be far out of sync with the current data.
"""
metadata = MetaData()
metadata.bind = engine
return Table(name, metadata, autoload=True)
class InsertFromSelect(UpdateBase):
"""Form the base for `INSERT INTO table (SELECT ... )` statement."""
def __init__(self, table, select):
self.table = table
self.select = select
@compiles(InsertFromSelect)
def visit_insert_from_select(element, compiler, **kw):
"""Form the `INSERT INTO table (SELECT ... )` statement."""
return "INSERT INTO %s %s" % (
compiler.process(element.table, asfrom=True),
compiler.process(element.select))
class ColumnError(Exception):
"""Error raised when no column or an invalid column is found."""
def _get_not_supported_column(col_name_col_instance, column_name):
try:
column = col_name_col_instance[column_name]
except KeyError:
msg = _("Please specify column %s in col_name_col_instance "
"param. It is required because column has unsupported "
"type by sqlite).")
raise ColumnError(msg % column_name)
if not isinstance(column, Column):
msg = _("col_name_col_instance param has wrong type of "
"column instance for column %s It should be instance "
"of sqlalchemy.Column.")
raise ColumnError(msg % column_name)
return column
def drop_unique_constraint(migrate_engine, table_name, uc_name, *columns,
**col_name_col_instance):
"""Drop unique constraint from table.
DEPRECATED: this function is deprecated and will be removed from cerberus.db
in a few releases. Please use UniqueConstraint.drop() method directly for
sqlalchemy-migrate migration scripts.
This method drops UC from table and works for mysql, postgresql and sqlite.
In mysql and postgresql we are able to use "alter table" construction.
Sqlalchemy doesn't support some sqlite column types and replaces their
type with NullType in metadata. We process these columns and replace
NullType with the correct column type.
:param migrate_engine: sqlalchemy engine
:param table_name: name of table that contains uniq constraint.
:param uc_name: name of uniq constraint that will be dropped.
:param columns: columns that are in uniq constraint.
:param col_name_col_instance: contains pair column_name=column_instance.
column_instance is instance of Column. These params
are required only for columns that have unsupported
types by sqlite. For example BigInteger.
"""
from migrate.changeset import UniqueConstraint
meta = MetaData()
meta.bind = migrate_engine
t = Table(table_name, meta, autoload=True)
if migrate_engine.name == "sqlite":
override_cols = [
_get_not_supported_column(col_name_col_instance, col.name)
for col in t.columns
if isinstance(col.type, NullType)
]
for col in override_cols:
t.columns.replace(col)
uc = UniqueConstraint(*columns, table=t, name=uc_name)
uc.drop()
def drop_old_duplicate_entries_from_table(migrate_engine, table_name,
use_soft_delete, *uc_column_names):
"""Drop all old rows having the same values for columns in uc_columns.
This method drop (or mark ad `deleted` if use_soft_delete is True) old
duplicate rows form table with name `table_name`.
:param migrate_engine: Sqlalchemy engine
:param table_name: Table with duplicates
:param use_soft_delete: If True - values will be marked as `deleted`,
if False - values will be removed from table
:param uc_column_names: Unique constraint columns
"""
meta = MetaData()
meta.bind = migrate_engine
table = Table(table_name, meta, autoload=True)
columns_for_group_by = [table.c[name] for name in uc_column_names]
columns_for_select = [func.max(table.c.id)]
columns_for_select.extend(columns_for_group_by)
duplicated_rows_select = sqlalchemy.sql.select(
columns_for_select, group_by=columns_for_group_by,
having=func.count(table.c.id) > 1)
for row in migrate_engine.execute(duplicated_rows_select):
# NOTE(boris-42): Do not remove row that has the biggest ID.
delete_condition = table.c.id != row[0]
is_none = None # workaround for pyflakes
delete_condition &= table.c.deleted_at == is_none
for name in uc_column_names:
delete_condition &= table.c[name] == row[name]
rows_to_delete_select = sqlalchemy.sql.select(
[table.c.id]).where(delete_condition)
for row in migrate_engine.execute(rows_to_delete_select).fetchall():
LOG.info(_LI("Deleting duplicated row with id: %(id)s from table: "
"%(table)s") % dict(id=row[0], table=table_name))
if use_soft_delete:
delete_statement = table.update().\
where(delete_condition).\
values({
'deleted': literal_column('id'),
'updated_at': literal_column('updated_at'),
'deleted_at': timeutils.utcnow()
})
else:
delete_statement = table.delete().where(delete_condition)
migrate_engine.execute(delete_statement)
def _get_default_deleted_value(table):
if isinstance(table.c.id.type, Integer):
return 0
if isinstance(table.c.id.type, String):
return ""
raise ColumnError(_("Unsupported id columns type"))
def _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes):
table = get_table(migrate_engine, table_name)
insp = reflection.Inspector.from_engine(migrate_engine)
real_indexes = insp.get_indexes(table_name)
existing_index_names = dict(
[(index['name'], index['column_names']) for index in real_indexes])
# NOTE(boris-42): Restore indexes on `deleted` column
for index in indexes:
if 'deleted' not in index['column_names']:
continue
name = index['name']
if name in existing_index_names:
column_names = [table.c[c] for c in existing_index_names[name]]
old_index = Index(name, *column_names, unique=index["unique"])
old_index.drop(migrate_engine)
column_names = [table.c[c] for c in index['column_names']]
new_index = Index(index["name"], *column_names, unique=index["unique"])
new_index.create(migrate_engine)
def change_deleted_column_type_to_boolean(migrate_engine, table_name,
**col_name_col_instance):
if migrate_engine.name == "sqlite":
return _change_deleted_column_type_to_boolean_sqlite(
migrate_engine, table_name, **col_name_col_instance)
insp = reflection.Inspector.from_engine(migrate_engine)
indexes = insp.get_indexes(table_name)
table = get_table(migrate_engine, table_name)
old_deleted = Column('old_deleted', Boolean, default=False)
old_deleted.create(table, populate_default=False)
table.update().\
where(table.c.deleted == table.c.id).\
values(old_deleted=True).\
execute()
table.c.deleted.drop()
table.c.old_deleted.alter(name="deleted")
_restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes)
def _change_deleted_column_type_to_boolean_sqlite(migrate_engine, table_name,
**col_name_col_instance):
insp = reflection.Inspector.from_engine(migrate_engine)
table = get_table(migrate_engine, table_name)
columns = []
for column in table.columns:
column_copy = None
if column.name != "deleted":
if isinstance(column.type, NullType):
column_copy = _get_not_supported_column(col_name_col_instance,
column.name)
else:
column_copy = column.copy()
else:
column_copy = Column('deleted', Boolean, default=0)
columns.append(column_copy)
constraints = [constraint.copy() for constraint in table.constraints]
meta = table.metadata
new_table = Table(table_name + "__tmp__", meta,
*(columns + constraints))
new_table.create()
indexes = []
for index in insp.get_indexes(table_name):
column_names = [new_table.c[c] for c in index['column_names']]
indexes.append(Index(index["name"], *column_names,
unique=index["unique"]))
c_select = []
for c in table.c:
if c.name != "deleted":
c_select.append(c)
else:
c_select.append(table.c.deleted == table.c.id)
ins = InsertFromSelect(new_table, sqlalchemy.sql.select(c_select))
migrate_engine.execute(ins)
table.drop()
[index.create(migrate_engine) for index in indexes]
new_table.rename(table_name)
new_table.update().\
where(new_table.c.deleted == new_table.c.id).\
values(deleted=True).\
execute()
def change_deleted_column_type_to_id_type(migrate_engine, table_name,
**col_name_col_instance):
if migrate_engine.name == "sqlite":
return _change_deleted_column_type_to_id_type_sqlite(
migrate_engine, table_name, **col_name_col_instance)
insp = reflection.Inspector.from_engine(migrate_engine)
indexes = insp.get_indexes(table_name)
table = get_table(migrate_engine, table_name)
new_deleted = Column('new_deleted', table.c.id.type,
default=_get_default_deleted_value(table))
new_deleted.create(table, populate_default=True)
deleted = True # workaround for pyflakes
table.update().\
where(table.c.deleted == deleted).\
values(new_deleted=table.c.id).\
execute()
table.c.deleted.drop()
table.c.new_deleted.alter(name="deleted")
_restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes)
def _change_deleted_column_type_to_id_type_sqlite(migrate_engine, table_name,
**col_name_col_instance):
# NOTE(boris-42): sqlaclhemy-migrate can't drop column with check
# constraints in sqlite DB and our `deleted` column has
# 2 check constraints. So there is only one way to remove
# these constraints:
# 1) Create new table with the same columns, constraints
# and indexes. (except deleted column).
# 2) Copy all data from old to new table.
# 3) Drop old table.
# 4) Rename new table to old table name.
insp = reflection.Inspector.from_engine(migrate_engine)
meta = MetaData(bind=migrate_engine)
table = Table(table_name, meta, autoload=True)
default_deleted_value = _get_default_deleted_value(table)
columns = []
for column in table.columns:
column_copy = None
if column.name != "deleted":
if isinstance(column.type, NullType):
column_copy = _get_not_supported_column(col_name_col_instance,
column.name)
else:
column_copy = column.copy()
else:
column_copy = Column('deleted', table.c.id.type,
default=default_deleted_value)
columns.append(column_copy)
def is_deleted_column_constraint(constraint):
# NOTE(boris-42): There is no other way to check is CheckConstraint
# associated with deleted column.
if not isinstance(constraint, CheckConstraint):
return False
sqltext = str(constraint.sqltext)
return (sqltext.endswith("deleted in (0, 1)") or
sqltext.endswith("deleted IN (:deleted_1, :deleted_2)"))
constraints = []
for constraint in table.constraints:
if not is_deleted_column_constraint(constraint):
constraints.append(constraint.copy())
new_table = Table(table_name + "__tmp__", meta,
*(columns + constraints))
new_table.create()
indexes = []
for index in insp.get_indexes(table_name):
column_names = [new_table.c[c] for c in index['column_names']]
indexes.append(Index(index["name"], *column_names,
unique=index["unique"]))
ins = InsertFromSelect(new_table, table.select())
migrate_engine.execute(ins)
table.drop()
[index.create(migrate_engine) for index in indexes]
new_table.rename(table_name)
deleted = True # workaround for pyflakes
new_table.update().\
where(new_table.c.deleted == deleted).\
values(deleted=new_table.c.id).\
execute()
# NOTE(boris-42): Fix value of deleted column: False -> "" or 0.
deleted = False # workaround for pyflakes
new_table.update().\
where(new_table.c.deleted == deleted).\
values(deleted=default_deleted_value).\
execute()
def get_connect_string(backend, database, user=None, passwd=None):
"""Get database connection
Try to get a connection with a very specific set of values, if we get
these then we'll run the tests, otherwise they are skipped
"""
args = {'backend': backend,
'user': user,
'passwd': passwd,
'database': database}
if backend == 'sqlite':
template = '%(backend)s:///%(database)s'
else:
template = "%(backend)s://%(user)s:%(passwd)s@localhost/%(database)s"
return template % args
def is_backend_avail(backend, database, user=None, passwd=None):
try:
connect_uri = get_connect_string(backend=backend,
database=database,
user=user,
passwd=passwd)
engine = sqlalchemy.create_engine(connect_uri)
connection = engine.connect()
except Exception:
# intentionally catch all to handle exceptions even if we don't
# have any backend code loaded.
return False
else:
connection.close()
engine.dispose()
return True
def get_db_connection_info(conn_pieces):
database = conn_pieces.path.strip('/')
loc_pieces = conn_pieces.netloc.split('@')
host = loc_pieces[1]
auth_pieces = loc_pieces[0].split(':')
user = auth_pieces[0]
password = ""
if len(auth_pieces) > 1:
password = auth_pieces[1].strip()
return (user, password, database, host)

View File

@ -0,0 +1,146 @@
# Copyright (c) 2012 OpenStack Foundation.
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import print_function
import errno
import gc
import os
import pprint
import socket
import sys
import traceback
import eventlet
import eventlet.backdoor
import greenlet
from oslo.config import cfg
from cerberus.openstack.common.gettextutils import _LI
from cerberus.openstack.common import log as logging
help_for_backdoor_port = (
"Acceptable values are 0, <port>, and <start>:<end>, where 0 results "
"in listening on a random tcp port number; <port> results in listening "
"on the specified port number (and not enabling backdoor if that port "
"is in use); and <start>:<end> results in listening on the smallest "
"unused port number within the specified range of port numbers. The "
"chosen port is displayed in the service's log file.")
eventlet_backdoor_opts = [
cfg.StrOpt('backdoor_port',
default=None,
help="Enable eventlet backdoor. %s" % help_for_backdoor_port)
]
CONF = cfg.CONF
CONF.register_opts(eventlet_backdoor_opts)
LOG = logging.getLogger(__name__)
class EventletBackdoorConfigValueError(Exception):
def __init__(self, port_range, help_msg, ex):
msg = ('Invalid backdoor_port configuration %(range)s: %(ex)s. '
'%(help)s' %
{'range': port_range, 'ex': ex, 'help': help_msg})
super(EventletBackdoorConfigValueError, self).__init__(msg)
self.port_range = port_range
def _dont_use_this():
print("Don't use this, just disconnect instead")
def _find_objects(t):
return [o for o in gc.get_objects() if isinstance(o, t)]
def _print_greenthreads():
for i, gt in enumerate(_find_objects(greenlet.greenlet)):
print(i, gt)
traceback.print_stack(gt.gr_frame)
print()
def _print_nativethreads():
for threadId, stack in sys._current_frames().items():
print(threadId)
traceback.print_stack(stack)
print()
def _parse_port_range(port_range):
if ':' not in port_range:
start, end = port_range, port_range
else:
start, end = port_range.split(':', 1)
try:
start, end = int(start), int(end)
if end < start:
raise ValueError
return start, end
except ValueError as ex:
raise EventletBackdoorConfigValueError(port_range, ex,
help_for_backdoor_port)
def _listen(host, start_port, end_port, listen_func):
try_port = start_port
while True:
try:
return listen_func((host, try_port))
except socket.error as exc:
if (exc.errno != errno.EADDRINUSE or
try_port >= end_port):
raise
try_port += 1
def initialize_if_enabled():
backdoor_locals = {
'exit': _dont_use_this, # So we don't exit the entire process
'quit': _dont_use_this, # So we don't exit the entire process
'fo': _find_objects,
'pgt': _print_greenthreads,
'pnt': _print_nativethreads,
}
if CONF.backdoor_port is None:
return None
start_port, end_port = _parse_port_range(str(CONF.backdoor_port))
# NOTE(johannes): The standard sys.displayhook will print the value of
# the last expression and set it to __builtin__._, which overwrites
# the __builtin__._ that gettext sets. Let's switch to using pprint
# since it won't interact poorly with gettext, and it's easier to
# read the output too.
def displayhook(val):
if val is not None:
pprint.pprint(val)
sys.displayhook = displayhook
sock = _listen('localhost', start_port, end_port, eventlet.listen)
# In the case of backdoor port being zero, a port number is assigned by
# listen(). In any case, pull the port number out here.
port = sock.getsockname()[1]
LOG.info(
_LI('Eventlet backdoor listening on %(port)s for process %(pid)d') %
{'port': port, 'pid': os.getpid()}
)
eventlet.spawn_n(eventlet.backdoor.backdoor_server, sock,
locals=backdoor_locals)
return port

View File

@ -0,0 +1,113 @@
# Copyright 2011 OpenStack Foundation.
# Copyright 2012, Red Hat, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Exception related utilities.
"""
import logging
import sys
import time
import traceback
import six
from cerberus.openstack.common.gettextutils import _LE
class save_and_reraise_exception(object):
"""Save current exception, run some code and then re-raise.
In some cases the exception context can be cleared, resulting in None
being attempted to be re-raised after an exception handler is run. This
can happen when eventlet switches greenthreads or when running an
exception handler, code raises and catches an exception. In both
cases the exception context will be cleared.
To work around this, we save the exception state, run handler code, and
then re-raise the original exception. If another exception occurs, the
saved exception is logged and the new exception is re-raised.
In some cases the caller may not want to re-raise the exception, and
for those circumstances this context provides a reraise flag that
can be used to suppress the exception. For example::
except Exception:
with save_and_reraise_exception() as ctxt:
decide_if_need_reraise()
if not should_be_reraised:
ctxt.reraise = False
If another exception occurs and reraise flag is False,
the saved exception will not be logged.
If the caller wants to raise new exception during exception handling
he/she sets reraise to False initially with an ability to set it back to
True if needed::
except Exception:
with save_and_reraise_exception(reraise=False) as ctxt:
[if statements to determine whether to raise a new exception]
# Not raising a new exception, so reraise
ctxt.reraise = True
"""
def __init__(self, reraise=True):
self.reraise = reraise
def __enter__(self):
self.type_, self.value, self.tb, = sys.exc_info()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None:
if self.reraise:
logging.error(_LE('Original exception being dropped: %s'),
traceback.format_exception(self.type_,
self.value,
self.tb))
return False
if self.reraise:
six.reraise(self.type_, self.value, self.tb)
def forever_retry_uncaught_exceptions(infunc):
def inner_func(*args, **kwargs):
last_log_time = 0
last_exc_message = None
exc_count = 0
while True:
try:
return infunc(*args, **kwargs)
except Exception as exc:
this_exc_message = six.u(str(exc))
if this_exc_message == last_exc_message:
exc_count += 1
else:
exc_count = 1
# Do not log any more frequently than once a minute unless
# the exception message changes
cur_time = int(time.time())
if (cur_time - last_log_time > 60 or
this_exc_message != last_exc_message):
logging.exception(
_LE('Unexpected exception occurred %d time(s)... '
'retrying.') % exc_count)
last_log_time = cur_time
last_exc_message = this_exc_message
exc_count = 0
# This should be a very rare event. In case it isn't, do
# a sleep.
time.sleep(1)
return inner_func

View File

@ -0,0 +1,135 @@
# Copyright 2011 OpenStack Foundation.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import contextlib
import errno
import os
import tempfile
from cerberus.openstack.common import excutils
from cerberus.openstack.common import log as logging
LOG = logging.getLogger(__name__)
_FILE_CACHE = {}
def ensure_tree(path):
"""Create a directory (and any ancestor directories required)
:param path: Directory to create
"""
try:
os.makedirs(path)
except OSError as exc:
if exc.errno == errno.EEXIST:
if not os.path.isdir(path):
raise
else:
raise
def read_cached_file(filename, force_reload=False):
"""Read from a file if it has been modified.
:param force_reload: Whether to reload the file.
:returns: A tuple with a boolean specifying if the data is fresh
or not.
"""
global _FILE_CACHE
if force_reload and filename in _FILE_CACHE:
del _FILE_CACHE[filename]
reloaded = False
mtime = os.path.getmtime(filename)
cache_info = _FILE_CACHE.setdefault(filename, {})
if not cache_info or mtime > cache_info.get('mtime', 0):
LOG.debug("Reloading cached file %s" % filename)
with open(filename) as fap:
cache_info['data'] = fap.read()
cache_info['mtime'] = mtime
reloaded = True
return (reloaded, cache_info['data'])
def delete_if_exists(path, remove=os.unlink):
"""Delete a file, but ignore file not found error.
:param path: File to delete
:param remove: Optional function to remove passed path
"""
try:
remove(path)
except OSError as e:
if e.errno != errno.ENOENT:
raise
@contextlib.contextmanager
def remove_path_on_error(path, remove=delete_if_exists):
"""Protect code that wants to operate on PATH atomically.
Any exception will cause PATH to be removed.
:param path: File to work with
:param remove: Optional function to remove passed path
"""
try:
yield
except Exception:
with excutils.save_and_reraise_exception():
remove(path)
def file_open(*args, **kwargs):
"""Open file
see built-in file() documentation for more details
Note: The reason this is kept in a separate module is to easily
be able to provide a stub module that doesn't alter system
state at all (for unit tests)
"""
return file(*args, **kwargs)
def write_to_tempfile(content, path=None, suffix='', prefix='tmp'):
"""Create temporary file or use existing file.
This util is needed for creating temporary file with
specified content, suffix and prefix. If path is not None,
it will be used for writing content. If the path doesn't
exist it'll be created.
:param content: content for temporary file.
:param path: same as parameter 'dir' for mkstemp
:param suffix: same as parameter 'suffix' for mkstemp
:param prefix: same as parameter 'prefix' for mkstemp
For example: it can be used in database tests for creating
configuration files.
"""
if path:
ensure_tree(path)
(fd, path) = tempfile.mkstemp(suffix=suffix, dir=path, prefix=prefix)
try:
os.write(fd, content)
finally:
os.close(fd)
return path

View File

@ -0,0 +1,85 @@
#
# Copyright 2013 Mirantis, Inc.
# Copyright 2013 OpenStack Foundation
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import fixtures
from oslo.config import cfg
import six
class Config(fixtures.Fixture):
"""Allows overriding configuration settings for the test.
`conf` will be reset on cleanup.
"""
def __init__(self, conf=cfg.CONF):
self.conf = conf
def setUp(self):
super(Config, self).setUp()
# NOTE(morganfainberg): unregister must be added to cleanup before
# reset is because cleanup works in reverse order of registered items,
# and a reset must occur before unregistering options can occur.
self.addCleanup(self._unregister_config_opts)
self.addCleanup(self.conf.reset)
self._registered_config_opts = {}
def config(self, **kw):
"""Override configuration values.
The keyword arguments are the names of configuration options to
override and their values.
If a `group` argument is supplied, the overrides are applied to
the specified configuration option group, otherwise the overrides
are applied to the ``default`` group.
"""
group = kw.pop('group', None)
for k, v in six.iteritems(kw):
self.conf.set_override(k, v, group)
def _unregister_config_opts(self):
for group in self._registered_config_opts:
self.conf.unregister_opts(self._registered_config_opts[group],
group=group)
def register_opt(self, opt, group=None):
"""Register a single option for the test run.
Options registered in this manner will automatically be unregistered
during cleanup.
If a `group` argument is supplied, it will register the new option
to that group, otherwise the option is registered to the ``default``
group.
"""
self.conf.register_opt(opt, group=group)
self._registered_config_opts.setdefault(group, set()).add(opt)
def register_opts(self, opts, group=None):
"""Register multiple options for the test run.
This works in the same manner as register_opt() but takes a list of
options as the first argument. All arguments will be registered to the
same group if the ``group`` argument is supplied, otherwise all options
will be registered to the ``default`` group.
"""
for opt in opts:
self.register_opt(opt, group=group)

View File

@ -0,0 +1,51 @@
# Copyright 2011 OpenStack Foundation.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import fixtures
from cerberus.openstack.common import lockutils
class LockFixture(fixtures.Fixture):
"""External locking fixture.
This fixture is basically an alternative to the synchronized decorator with
the external flag so that tearDowns and addCleanups will be included in
the lock context for locking between tests. The fixture is recommended to
be the first line in a test method, like so::
def test_method(self):
self.useFixture(LockFixture)
...
or the first line in setUp if all the test methods in the class are
required to be serialized. Something like::
class TestCase(testtools.testcase):
def setUp(self):
self.useFixture(LockFixture)
super(TestCase, self).setUp()
...
This is because addCleanups are put on a LIFO queue that gets run after the
test method exits. (either by completing or raising an exception)
"""
def __init__(self, name, lock_file_prefix=None):
self.mgr = lockutils.lock(name, lock_file_prefix, True)
def setUp(self):
super(LockFixture, self).setUp()
self.addCleanup(self.mgr.__exit__, None, None, None)
self.lock = self.mgr.__enter__()

View File

@ -0,0 +1,34 @@
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import fixtures
def get_logging_handle_error_fixture():
"""returns a fixture to make logging raise formatting exceptions.
Usage:
self.useFixture(logging.get_logging_handle_error_fixture())
"""
return fixtures.MonkeyPatch('logging.Handler.handleError',
_handleError)
def _handleError(self, record):
"""Monkey patch for logging.Handler.handleError.
The default handleError just logs the error to stderr but we want
the option of actually raising an exception.
"""
raise

View File

@ -0,0 +1,62 @@
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# Copyright 2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
##############################################################################
##############################################################################
##
## DO NOT MODIFY THIS FILE
##
## This file is being graduated to the cerberustest library. Please make all
## changes there, and only backport critical fixes here. - dhellmann
##
##############################################################################
##############################################################################
import fixtures
import mock
class PatchObject(fixtures.Fixture):
"""Deal with code around mock."""
def __init__(self, obj, attr, new=mock.DEFAULT, **kwargs):
self.obj = obj
self.attr = attr
self.kwargs = kwargs
self.new = new
def setUp(self):
super(PatchObject, self).setUp()
_p = mock.patch.object(self.obj, self.attr, self.new, **self.kwargs)
self.mock = _p.start()
self.addCleanup(_p.stop)
class Patch(fixtures.Fixture):
"""Deal with code around mock.patch."""
def __init__(self, obj, new=mock.DEFAULT, **kwargs):
self.obj = obj
self.kwargs = kwargs
self.new = new
def setUp(self):
super(Patch, self).setUp()
_p = mock.patch(self.obj, self.new, **self.kwargs)
self.mock = _p.start()
self.addCleanup(_p.stop)

View File

@ -0,0 +1,43 @@
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# Copyright 2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
##############################################################################
##############################################################################
##
## DO NOT MODIFY THIS FILE
##
## This file is being graduated to the cerberustest library. Please make all
## changes there, and only backport critical fixes here. - dhellmann
##
##############################################################################
##############################################################################
import fixtures
from six.moves import mox
class MoxStubout(fixtures.Fixture):
"""Deal with code around mox and stubout as a fixture."""
def setUp(self):
super(MoxStubout, self).setUp()
# emulate some of the mox stuff, we can't use the metaclass
# because it screws with our generators
self.mox = mox.Mox()
self.stubs = self.mox.stubs
self.addCleanup(self.mox.UnsetStubs)
self.addCleanup(self.mox.VerifyAll)

View File

@ -0,0 +1,448 @@
# Copyright 2012 Red Hat, Inc.
# Copyright 2013 IBM Corp.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
gettext for openstack-common modules.
Usual usage in an openstack.common module:
from cerberus.openstack.common.gettextutils import _
"""
import copy
import functools
import gettext
import locale
from logging import handlers
import os
from babel import localedata
import six
_localedir = os.environ.get('cerberus'.upper() + '_LOCALEDIR')
_t = gettext.translation('cerberus', localedir=_localedir, fallback=True)
# We use separate translation catalogs for each log level, so set up a
# mapping between the log level name and the translator. The domain
# for the log level is project_name + "-log-" + log_level so messages
# for each level end up in their own catalog.
_t_log_levels = dict(
(level, gettext.translation('cerberus' + '-log-' + level,
localedir=_localedir,
fallback=True))
for level in ['info', 'warning', 'error', 'critical']
)
_AVAILABLE_LANGUAGES = {}
USE_LAZY = False
def enable_lazy():
"""Convenience function for configuring _() to use lazy gettext
Call this at the start of execution to enable the gettextutils._
function to use lazy gettext functionality. This is useful if
your project is importing _ directly instead of using the
gettextutils.install() way of importing the _ function.
"""
global USE_LAZY
USE_LAZY = True
def _(msg):
if USE_LAZY:
return Message(msg, domain='cerberus')
else:
if six.PY3:
return _t.gettext(msg)
return _t.ugettext(msg)
def _log_translation(msg, level):
"""Build a single translation of a log message
"""
if USE_LAZY:
return Message(msg, domain='cerberus' + '-log-' + level)
else:
translator = _t_log_levels[level]
if six.PY3:
return translator.gettext(msg)
return translator.ugettext(msg)
# Translators for log levels.
#
# The abbreviated names are meant to reflect the usual use of a short
# name like '_'. The "L" is for "log" and the other letter comes from
# the level.
_LI = functools.partial(_log_translation, level='info')
_LW = functools.partial(_log_translation, level='warning')
_LE = functools.partial(_log_translation, level='error')
_LC = functools.partial(_log_translation, level='critical')
def install(domain, lazy=False):
"""Install a _() function using the given translation domain.
Given a translation domain, install a _() function using gettext's
install() function.
The main difference from gettext.install() is that we allow
overriding the default localedir (e.g. /usr/share/locale) using
a translation-domain-specific environment variable (e.g.
NOVA_LOCALEDIR).
:param domain: the translation domain
:param lazy: indicates whether or not to install the lazy _() function.
The lazy _() introduces a way to do deferred translation
of messages by installing a _ that builds Message objects,
instead of strings, which can then be lazily translated into
any available locale.
"""
if lazy:
# NOTE(mrodden): Lazy gettext functionality.
#
# The following introduces a deferred way to do translations on
# messages in OpenStack. We override the standard _() function
# and % (format string) operation to build Message objects that can
# later be translated when we have more information.
def _lazy_gettext(msg):
"""Create and return a Message object.
Lazy gettext function for a given domain, it is a factory method
for a project/module to get a lazy gettext function for its own
translation domain (i.e. nova, glance, cinder, etc.)
Message encapsulates a string so that we can translate
it later when needed.
"""
return Message(msg, domain=domain)
from six import moves
moves.builtins.__dict__['_'] = _lazy_gettext
else:
localedir = '%s_LOCALEDIR' % domain.upper()
if six.PY3:
gettext.install(domain,
localedir=os.environ.get(localedir))
else:
gettext.install(domain,
localedir=os.environ.get(localedir),
unicode=True)
class Message(six.text_type):
"""A Message object is a unicode object that can be translated.
Translation of Message is done explicitly using the translate() method.
For all non-translation intents and purposes, a Message is simply unicode,
and can be treated as such.
"""
def __new__(cls, msgid, msgtext=None, params=None,
domain='cerberus', *args):
"""Create a new Message object.
In order for translation to work gettext requires a message ID, this
msgid will be used as the base unicode text. It is also possible
for the msgid and the base unicode text to be different by passing
the msgtext parameter.
"""
# If the base msgtext is not given, we use the default translation
# of the msgid (which is in English) just in case the system locale is
# not English, so that the base text will be in that locale by default.
if not msgtext:
msgtext = Message._translate_msgid(msgid, domain)
# We want to initialize the parent unicode with the actual object that
# would have been plain unicode if 'Message' was not enabled.
msg = super(Message, cls).__new__(cls, msgtext)
msg.msgid = msgid
msg.domain = domain
msg.params = params
return msg
def translate(self, desired_locale=None):
"""Translate this message to the desired locale.
:param desired_locale: The desired locale to translate the message to,
if no locale is provided the message will be
translated to the system's default locale.
:returns: the translated message in unicode
"""
translated_message = Message._translate_msgid(self.msgid,
self.domain,
desired_locale)
if self.params is None:
# No need for more translation
return translated_message
# This Message object may have been formatted with one or more
# Message objects as substitution arguments, given either as a single
# argument, part of a tuple, or as one or more values in a dictionary.
# When translating this Message we need to translate those Messages too
translated_params = _translate_args(self.params, desired_locale)
translated_message = translated_message % translated_params
return translated_message
@staticmethod
def _translate_msgid(msgid, domain, desired_locale=None):
if not desired_locale:
system_locale = locale.getdefaultlocale()
# If the system locale is not available to the runtime use English
if not system_locale[0]:
desired_locale = 'en_US'
else:
desired_locale = system_locale[0]
locale_dir = os.environ.get(domain.upper() + '_LOCALEDIR')
lang = gettext.translation(domain,
localedir=locale_dir,
languages=[desired_locale],
fallback=True)
if six.PY3:
translator = lang.gettext
else:
translator = lang.ugettext
translated_message = translator(msgid)
return translated_message
def __mod__(self, other):
# When we mod a Message we want the actual operation to be performed
# by the parent class (i.e. unicode()), the only thing we do here is
# save the original msgid and the parameters in case of a translation
params = self._sanitize_mod_params(other)
unicode_mod = super(Message, self).__mod__(params)
modded = Message(self.msgid,
msgtext=unicode_mod,
params=params,
domain=self.domain)
return modded
def _sanitize_mod_params(self, other):
"""Sanitize the object being modded with this Message.
- Add support for modding 'None' so translation supports it
- Trim the modded object, which can be a large dictionary, to only
those keys that would actually be used in a translation
- Snapshot the object being modded, in case the message is
translated, it will be used as it was when the Message was created
"""
if other is None:
params = (other,)
elif isinstance(other, dict):
# Merge the dictionaries
# Copy each item in case one does not support deep copy.
params = {}
if isinstance(self.params, dict):
for key, val in self.params.items():
params[key] = self._copy_param(val)
for key, val in other.items():
params[key] = self._copy_param(val)
else:
params = self._copy_param(other)
return params
def _copy_param(self, param):
try:
return copy.deepcopy(param)
except Exception:
# Fallback to casting to unicode this will handle the
# python code-like objects that can't be deep-copied
return six.text_type(param)
def __add__(self, other):
msg = _('Message objects do not support addition.')
raise TypeError(msg)
def __radd__(self, other):
return self.__add__(other)
def __str__(self):
# NOTE(luisg): Logging in python 2.6 tries to str() log records,
# and it expects specifically a UnicodeError in order to proceed.
msg = _('Message objects do not support str() because they may '
'contain non-ascii characters. '
'Please use unicode() or translate() instead.')
raise UnicodeError(msg)
def get_available_languages(domain):
"""Lists the available languages for the given translation domain.
:param domain: the domain to get languages for
"""
if domain in _AVAILABLE_LANGUAGES:
return copy.copy(_AVAILABLE_LANGUAGES[domain])
localedir = '%s_LOCALEDIR' % domain.upper()
find = lambda x: gettext.find(domain,
localedir=os.environ.get(localedir),
languages=[x])
# NOTE(mrodden): en_US should always be available (and first in case
# order matters) since our in-line message strings are en_US
language_list = ['en_US']
# NOTE(luisg): Babel <1.0 used a function called list(), which was
# renamed to locale_identifiers() in >=1.0, the requirements master list
# requires >=0.9.6, uncapped, so defensively work with both. We can remove
# this check when the master list updates to >=1.0, and update all projects
list_identifiers = (getattr(localedata, 'list', None) or
getattr(localedata, 'locale_identifiers'))
locale_identifiers = list_identifiers()
for i in locale_identifiers:
if find(i) is not None:
language_list.append(i)
# NOTE(luisg): Babel>=1.0,<1.3 has a bug where some OpenStack supported
# locales (e.g. 'zh_CN', and 'zh_TW') aren't supported even though they
# are perfectly legitimate locales:
# https://github.com/mitsuhiko/babel/issues/37
# In Babel 1.3 they fixed the bug and they support these locales, but
# they are still not explicitly "listed" by locale_identifiers().
# That is why we add the locales here explicitly if necessary so that
# they are listed as supported.
aliases = {'zh': 'zh_CN',
'zh_Hant_HK': 'zh_HK',
'zh_Hant': 'zh_TW',
'fil': 'tl_PH'}
for (locale, alias) in six.iteritems(aliases):
if locale in language_list and alias not in language_list:
language_list.append(alias)
_AVAILABLE_LANGUAGES[domain] = language_list
return copy.copy(language_list)
def translate(obj, desired_locale=None):
"""Gets the translated unicode representation of the given object.
If the object is not translatable it is returned as-is.
If the locale is None the object is translated to the system locale.
:param obj: the object to translate
:param desired_locale: the locale to translate the message to, if None the
default system locale will be used
:returns: the translated object in unicode, or the original object if
it could not be translated
"""
message = obj
if not isinstance(message, Message):
# If the object to translate is not already translatable,
# let's first get its unicode representation
message = six.text_type(obj)
if isinstance(message, Message):
# Even after unicoding() we still need to check if we are
# running with translatable unicode before translating
return message.translate(desired_locale)
return obj
def _translate_args(args, desired_locale=None):
"""Translates all the translatable elements of the given arguments object.
This method is used for translating the translatable values in method
arguments which include values of tuples or dictionaries.
If the object is not a tuple or a dictionary the object itself is
translated if it is translatable.
If the locale is None the object is translated to the system locale.
:param args: the args to translate
:param desired_locale: the locale to translate the args to, if None the
default system locale will be used
:returns: a new args object with the translated contents of the original
"""
if isinstance(args, tuple):
return tuple(translate(v, desired_locale) for v in args)
if isinstance(args, dict):
translated_dict = {}
for (k, v) in six.iteritems(args):
translated_v = translate(v, desired_locale)
translated_dict[k] = translated_v
return translated_dict
return translate(args, desired_locale)
class TranslationHandler(handlers.MemoryHandler):
"""Handler that translates records before logging them.
The TranslationHandler takes a locale and a target logging.Handler object
to forward LogRecord objects to after translating them. This handler
depends on Message objects being logged, instead of regular strings.
The handler can be configured declaratively in the logging.conf as follows:
[handlers]
keys = translatedlog, translator
[handler_translatedlog]
class = handlers.WatchedFileHandler
args = ('/var/log/api-localized.log',)
formatter = context
[handler_translator]
class = openstack.common.log.TranslationHandler
target = translatedlog
args = ('zh_CN',)
If the specified locale is not available in the system, the handler will
log in the default locale.
"""
def __init__(self, locale=None, target=None):
"""Initialize a TranslationHandler
:param locale: locale to use for translating messages
:param target: logging.Handler object to forward
LogRecord objects to after translation
"""
# NOTE(luisg): In order to allow this handler to be a wrapper for
# other handlers, such as a FileHandler, and still be able to
# configure it using logging.conf, this handler has to extend
# MemoryHandler because only the MemoryHandlers' logging.conf
# parsing is implemented such that it accepts a target handler.
handlers.MemoryHandler.__init__(self, capacity=0, target=target)
self.locale = locale
def setFormatter(self, fmt):
self.target.setFormatter(fmt)
def emit(self, record):
# We save the message from the original record to restore it
# after translation, so other handlers are not affected by this
original_msg = record.msg
original_args = record.args
try:
self._translate_and_log_record(record)
finally:
record.msg = original_msg
record.args = original_args
def _translate_and_log_record(self, record):
record.msg = translate(record.msg, self.locale)
# In addition to translating the message, we also need to translate
# arguments that were passed to the log method that were not part
# of the main message e.g., log.info(_('Some message %s'), this_one))
record.args = _translate_args(record.args, self.locale)
self.target.emit(record)

View File

@ -0,0 +1,73 @@
# Copyright 2011 OpenStack Foundation.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Import related utilities and helper functions.
"""
import sys
import traceback
def import_class(import_str):
"""Returns a class from a string including module and class."""
mod_str, _sep, class_str = import_str.rpartition('.')
try:
__import__(mod_str)
return getattr(sys.modules[mod_str], class_str)
except (ValueError, AttributeError):
raise ImportError('Class %s cannot be found (%s)' %
(class_str,
traceback.format_exception(*sys.exc_info())))
def import_object(import_str, *args, **kwargs):
"""Import a class and return an instance of it."""
return import_class(import_str)(*args, **kwargs)
def import_object_ns(name_space, import_str, *args, **kwargs):
"""Tries to import object from default namespace.
Imports a class and return an instance of it, first by trying
to find the class in a default namespace, then failing back to
a full path if not found in the default namespace.
"""
import_value = "%s.%s" % (name_space, import_str)
try:
return import_class(import_value)(*args, **kwargs)
except ImportError:
return import_class(import_str)(*args, **kwargs)
def import_module(import_str):
"""Import a module."""
__import__(import_str)
return sys.modules[import_str]
def import_versioned_module(version, submodule=None):
module = 'cerberus.v%s' % version
if submodule:
module = '.'.join((module, submodule))
return import_module(module)
def try_import(import_str, default=None):
"""Try to import a module and if it fails return default."""
try:
return import_module(import_str)
except ImportError:
return default

View File

@ -0,0 +1,186 @@
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# Copyright 2011 Justin Santa Barbara
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
'''
JSON related utilities.
This module provides a few things:
1) A handy function for getting an object down to something that can be
JSON serialized. See to_primitive().
2) Wrappers around loads() and dumps(). The dumps() wrapper will
automatically use to_primitive() for you if needed.
3) This sets up anyjson to use the loads() and dumps() wrappers if anyjson
is available.
'''
import codecs
import datetime
import functools
import inspect
import itertools
import sys
if sys.version_info < (2, 7):
# On Python <= 2.6, json module is not C boosted, so try to use
# simplejson module if available
try:
import simplejson as json
except ImportError:
import json
else:
import json
import six
import six.moves.xmlrpc_client as xmlrpclib
from cerberus.openstack.common import gettextutils
from cerberus.openstack.common import importutils
from cerberus.openstack.common import strutils
from cerberus.openstack.common import timeutils
netaddr = importutils.try_import("netaddr")
_nasty_type_tests = [inspect.ismodule, inspect.isclass, inspect.ismethod,
inspect.isfunction, inspect.isgeneratorfunction,
inspect.isgenerator, inspect.istraceback, inspect.isframe,
inspect.iscode, inspect.isbuiltin, inspect.isroutine,
inspect.isabstract]
_simple_types = (six.string_types + six.integer_types
+ (type(None), bool, float))
def to_primitive(value, convert_instances=False, convert_datetime=True,
level=0, max_depth=3):
"""Convert a complex object into primitives.
Handy for JSON serialization. We can optionally handle instances,
but since this is a recursive function, we could have cyclical
data structures.
To handle cyclical data structures we could track the actual objects
visited in a set, but not all objects are hashable. Instead we just
track the depth of the object inspections and don't go too deep.
Therefore, convert_instances=True is lossy ... be aware.
"""
# handle obvious types first - order of basic types determined by running
# full tests on nova project, resulting in the following counts:
# 572754 <type 'NoneType'>
# 460353 <type 'int'>
# 379632 <type 'unicode'>
# 274610 <type 'str'>
# 199918 <type 'dict'>
# 114200 <type 'datetime.datetime'>
# 51817 <type 'bool'>
# 26164 <type 'list'>
# 6491 <type 'float'>
# 283 <type 'tuple'>
# 19 <type 'long'>
if isinstance(value, _simple_types):
return value
if isinstance(value, datetime.datetime):
if convert_datetime:
return timeutils.strtime(value)
else:
return value
# value of itertools.count doesn't get caught by nasty_type_tests
# and results in infinite loop when list(value) is called.
if type(value) == itertools.count:
return six.text_type(value)
# FIXME(vish): Workaround for LP bug 852095. Without this workaround,
# tests that raise an exception in a mocked method that
# has a @wrap_exception with a notifier will fail. If
# we up the dependency to 0.5.4 (when it is released) we
# can remove this workaround.
if getattr(value, '__module__', None) == 'mox':
return 'mock'
if level > max_depth:
return '?'
# The try block may not be necessary after the class check above,
# but just in case ...
try:
recursive = functools.partial(to_primitive,
convert_instances=convert_instances,
convert_datetime=convert_datetime,
level=level,
max_depth=max_depth)
if isinstance(value, dict):
return dict((k, recursive(v)) for k, v in six.iteritems(value))
elif isinstance(value, (list, tuple)):
return [recursive(lv) for lv in value]
# It's not clear why xmlrpclib created their own DateTime type, but
# for our purposes, make it a datetime type which is explicitly
# handled
if isinstance(value, xmlrpclib.DateTime):
value = datetime.datetime(*tuple(value.timetuple())[:6])
if convert_datetime and isinstance(value, datetime.datetime):
return timeutils.strtime(value)
elif isinstance(value, gettextutils.Message):
return value.data
elif hasattr(value, 'iteritems'):
return recursive(dict(value.iteritems()), level=level + 1)
elif hasattr(value, '__iter__'):
return recursive(list(value))
elif convert_instances and hasattr(value, '__dict__'):
# Likely an instance of something. Watch for cycles.
# Ignore class member vars.
return recursive(value.__dict__, level=level + 1)
elif netaddr and isinstance(value, netaddr.IPAddress):
return six.text_type(value)
else:
if any(test(value) for test in _nasty_type_tests):
return six.text_type(value)
return value
except TypeError:
# Class objects are tricky since they may define something like
# __iter__ defined but it isn't callable as list().
return six.text_type(value)
def dumps(value, default=to_primitive, **kwargs):
return json.dumps(value, default=default, **kwargs)
def loads(s, encoding='utf-8'):
return json.loads(strutils.safe_decode(s, encoding))
def load(fp, encoding='utf-8'):
return json.load(codecs.getreader(encoding)(fp))
try:
import anyjson
except ImportError:
pass
else:
anyjson._modules.append((__name__, 'dumps', TypeError,
'loads', ValueError, 'load'))
anyjson.force_implementation(__name__)

View File

@ -0,0 +1,45 @@
# Copyright 2011 OpenStack Foundation.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Local storage of variables using weak references"""
import threading
import weakref
class WeakLocal(threading.local):
def __getattribute__(self, attr):
rval = super(WeakLocal, self).__getattribute__(attr)
if rval:
# NOTE(mikal): this bit is confusing. What is stored is a weak
# reference, not the value itself. We therefore need to lookup
# the weak reference and return the inner value here.
rval = rval()
return rval
def __setattr__(self, attr, value):
value = weakref.ref(value)
return super(WeakLocal, self).__setattr__(attr, value)
# NOTE(mikal): the name "store" should be deprecated in the future
store = WeakLocal()
# A "weak" store uses weak references and allows an object to fall out of scope
# when it falls out of scope in the code that uses the thread local storage. A
# "strong" store will hold a reference to the object so that it never falls out
# of scope.
weak_store = WeakLocal()
strong_store = threading.local()

View File

@ -0,0 +1,377 @@
# Copyright 2011 OpenStack Foundation.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import contextlib
import errno
import fcntl
import functools
import os
import shutil
import subprocess
import sys
import tempfile
import threading
import time
import weakref
from oslo.config import cfg
from cerberus.openstack.common import fileutils
from cerberus.openstack.common.gettextutils import _, _LE, _LI
from cerberus.openstack.common import log as logging
LOG = logging.getLogger(__name__)
util_opts = [
cfg.BoolOpt('disable_process_locking', default=False,
help='Whether to disable inter-process locks'),
cfg.StrOpt('lock_path',
default=os.environ.get("CERBERUS_LOCK_PATH"),
help=('Directory to use for lock files.'))
]
CONF = cfg.CONF
CONF.register_opts(util_opts)
def set_defaults(lock_path):
cfg.set_defaults(util_opts, lock_path=lock_path)
class _FileLock(object):
"""Lock implementation which allows multiple locks, working around
issues like bugs.debian.org/cgi-bin/bugreport.cgi?bug=632857 and does
not require any cleanup. Since the lock is always held on a file
descriptor rather than outside of the process, the lock gets dropped
automatically if the process crashes, even if __exit__ is not executed.
There are no guarantees regarding usage by multiple green threads in a
single process here. This lock works only between processes. Exclusive
access between local threads should be achieved using the semaphores
in the @synchronized decorator.
Note these locks are released when the descriptor is closed, so it's not
safe to close the file descriptor while another green thread holds the
lock. Just opening and closing the lock file can break synchronisation,
so lock files must be accessed only using this abstraction.
"""
def __init__(self, name):
self.lockfile = None
self.fname = name
def acquire(self):
basedir = os.path.dirname(self.fname)
if not os.path.exists(basedir):
fileutils.ensure_tree(basedir)
LOG.info(_LI('Created lock path: %s'), basedir)
self.lockfile = open(self.fname, 'w')
while True:
try:
# Using non-blocking locks since green threads are not
# patched to deal with blocking locking calls.
# Also upon reading the MSDN docs for locking(), it seems
# to have a laughable 10 attempts "blocking" mechanism.
self.trylock()
LOG.debug('Got file lock "%s"', self.fname)
return True
except IOError as e:
if e.errno in (errno.EACCES, errno.EAGAIN):
# external locks synchronise things like iptables
# updates - give it some time to prevent busy spinning
time.sleep(0.01)
else:
raise threading.ThreadError(_("Unable to acquire lock on"
" `%(filename)s` due to"
" %(exception)s") %
{
'filename': self.fname,
'exception': e,
})
def __enter__(self):
self.acquire()
return self
def release(self):
try:
self.unlock()
self.lockfile.close()
LOG.debug('Released file lock "%s"', self.fname)
except IOError:
LOG.exception(_LE("Could not release the acquired lock `%s`"),
self.fname)
def __exit__(self, exc_type, exc_val, exc_tb):
self.release()
def exists(self):
return os.path.exists(self.fname)
def trylock(self):
raise NotImplementedError()
def unlock(self):
raise NotImplementedError()
class _WindowsLock(_FileLock):
def trylock(self):
msvcrt.locking(self.lockfile.fileno(), msvcrt.LK_NBLCK, 1)
def unlock(self):
msvcrt.locking(self.lockfile.fileno(), msvcrt.LK_UNLCK, 1)
class _FcntlLock(_FileLock):
def trylock(self):
fcntl.lockf(self.lockfile, fcntl.LOCK_EX | fcntl.LOCK_NB)
def unlock(self):
fcntl.lockf(self.lockfile, fcntl.LOCK_UN)
class _PosixLock(object):
def __init__(self, name):
# Hash the name because it's not valid to have POSIX semaphore
# names with things like / in them. Then use base64 to encode
# the digest() instead taking the hexdigest() because the
# result is shorter and most systems can't have shm sempahore
# names longer than 31 characters.
h = hashlib.sha1()
h.update(name.encode('ascii'))
self.name = str((b'/' + base64.urlsafe_b64encode(
h.digest())).decode('ascii'))
def acquire(self, timeout=None):
self.semaphore = posix_ipc.Semaphore(self.name,
flags=posix_ipc.O_CREAT,
initial_value=1)
self.semaphore.acquire(timeout)
return self
def __enter__(self):
self.acquire()
return self
def release(self):
self.semaphore.release()
self.semaphore.close()
def __exit__(self, exc_type, exc_val, exc_tb):
self.release()
def exists(self):
try:
semaphore = posix_ipc.Semaphore(self.name)
except posix_ipc.ExistentialError:
return False
else:
semaphore.close()
return True
if os.name == 'nt':
import msvcrt
InterProcessLock = _WindowsLock
FileLock = _WindowsLock
else:
import base64
import hashlib
import posix_ipc
InterProcessLock = _PosixLock
FileLock = _FcntlLock
_semaphores = weakref.WeakValueDictionary()
_semaphores_lock = threading.Lock()
def _get_lock_path(name, lock_file_prefix, lock_path=None):
# NOTE(mikal): the lock name cannot contain directory
# separators
name = name.replace(os.sep, '_')
if lock_file_prefix:
sep = '' if lock_file_prefix.endswith('-') else '-'
name = '%s%s%s' % (lock_file_prefix, sep, name)
local_lock_path = lock_path or CONF.lock_path
if not local_lock_path:
# NOTE(bnemec): Create a fake lock path for posix locks so we don't
# unnecessarily raise the RequiredOptError below.
if InterProcessLock is not _PosixLock:
raise cfg.RequiredOptError('lock_path')
local_lock_path = 'posixlock:/'
return os.path.join(local_lock_path, name)
def external_lock(name, lock_file_prefix=None, lock_path=None):
LOG.debug('Attempting to grab external lock "%(lock)s"',
{'lock': name})
lock_file_path = _get_lock_path(name, lock_file_prefix, lock_path)
# NOTE(bnemec): If an explicit lock_path was passed to us then it
# means the caller is relying on file-based locking behavior, so
# we can't use posix locks for those calls.
if lock_path:
return FileLock(lock_file_path)
return InterProcessLock(lock_file_path)
def remove_external_lock_file(name, lock_file_prefix=None):
"""Remove a external lock file when it's not used anymore
This will be helpful when we have a lot of lock files
"""
with internal_lock(name):
lock_file_path = _get_lock_path(name, lock_file_prefix)
try:
os.remove(lock_file_path)
except OSError:
LOG.info(_LI('Failed to remove file %(file)s'),
{'file': lock_file_path})
def internal_lock(name):
with _semaphores_lock:
try:
sem = _semaphores[name]
except KeyError:
sem = threading.Semaphore()
_semaphores[name] = sem
LOG.debug('Got semaphore "%(lock)s"', {'lock': name})
return sem
@contextlib.contextmanager
def lock(name, lock_file_prefix=None, external=False, lock_path=None):
"""Context based lock
This function yields a `threading.Semaphore` instance (if we don't use
eventlet.monkey_patch(), else `semaphore.Semaphore`) unless external is
True, in which case, it'll yield an InterProcessLock instance.
:param lock_file_prefix: The lock_file_prefix argument is used to provide
lock files on disk with a meaningful prefix.
:param external: The external keyword argument denotes whether this lock
should work across multiple processes. This means that if two different
workers both run a a method decorated with @synchronized('mylock',
external=True), only one of them will execute at a time.
"""
int_lock = internal_lock(name)
with int_lock:
if external and not CONF.disable_process_locking:
ext_lock = external_lock(name, lock_file_prefix, lock_path)
with ext_lock:
yield ext_lock
else:
yield int_lock
def synchronized(name, lock_file_prefix=None, external=False, lock_path=None):
"""Synchronization decorator.
Decorating a method like so::
@synchronized('mylock')
def foo(self, *args):
...
ensures that only one thread will execute the foo method at a time.
Different methods can share the same lock::
@synchronized('mylock')
def foo(self, *args):
...
@synchronized('mylock')
def bar(self, *args):
...
This way only one of either foo or bar can be executing at a time.
"""
def wrap(f):
@functools.wraps(f)
def inner(*args, **kwargs):
try:
with lock(name, lock_file_prefix, external, lock_path):
LOG.debug('Got semaphore / lock "%(function)s"',
{'function': f.__name__})
return f(*args, **kwargs)
finally:
LOG.debug('Semaphore / lock released "%(function)s"',
{'function': f.__name__})
return inner
return wrap
def synchronized_with_prefix(lock_file_prefix):
"""Partial object generator for the synchronization decorator.
Redefine @synchronized in each project like so::
(in nova/utils.py)
from nova.openstack.common import lockutils
synchronized = lockutils.synchronized_with_prefix('nova-')
(in nova/foo.py)
from nova import utils
@utils.synchronized('mylock')
def bar(self, *args):
...
The lock_file_prefix argument is used to provide lock files on disk with a
meaningful prefix.
"""
return functools.partial(synchronized, lock_file_prefix=lock_file_prefix)
def main(argv):
"""Create a dir for locks and pass it to command from arguments
If you run this:
python -m openstack.common.lockutils python setup.py testr <etc>
a temporary directory will be created for all your locks and passed to all
your tests in an environment variable. The temporary dir will be deleted
afterwards and the return value will be preserved.
"""
lock_dir = tempfile.mkdtemp()
os.environ["CERBERUS_LOCK_PATH"] = lock_dir
try:
ret_val = subprocess.call(argv[1:])
finally:
shutil.rmtree(lock_dir, ignore_errors=True)
return ret_val
if __name__ == '__main__':
sys.exit(main(sys.argv))

View File

@ -0,0 +1,713 @@
# Copyright 2011 OpenStack Foundation.
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""OpenStack logging handler.
This module adds to logging functionality by adding the option to specify
a context object when calling the various log methods. If the context object
is not specified, default formatting is used. Additionally, an instance uuid
may be passed as part of the log message, which is intended to make it easier
for admins to find messages related to a specific instance.
It also allows setting of formatting information through conf.
"""
import inspect
import itertools
import logging
import logging.config
import logging.handlers
import os
import re
import sys
import traceback
from oslo.config import cfg
import six
from six import moves
from cerberus.openstack.common.gettextutils import _
from cerberus.openstack.common import importutils
from cerberus.openstack.common import jsonutils
from cerberus.openstack.common import local
_DEFAULT_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
_SANITIZE_KEYS = ['adminPass', 'admin_pass', 'password', 'admin_password']
# NOTE(ldbragst): Let's build a list of regex objects using the list of
# _SANITIZE_KEYS we already have. This way, we only have to add the new key
# to the list of _SANITIZE_KEYS and we can generate regular expressions
# for XML and JSON automatically.
_SANITIZE_PATTERNS = []
_FORMAT_PATTERNS = [r'(%(key)s\s*[=]\s*[\"\']).*?([\"\'])',
r'(<%(key)s>).*?(</%(key)s>)',
r'([\"\']%(key)s[\"\']\s*:\s*[\"\']).*?([\"\'])',
r'([\'"].*?%(key)s[\'"]\s*:\s*u?[\'"]).*?([\'"])']
for key in _SANITIZE_KEYS:
for pattern in _FORMAT_PATTERNS:
reg_ex = re.compile(pattern % {'key': key}, re.DOTALL)
_SANITIZE_PATTERNS.append(reg_ex)
common_cli_opts = [
cfg.BoolOpt('debug',
short='d',
default=False,
help='Print debugging output (set logging level to '
'DEBUG instead of default WARNING level).'),
cfg.BoolOpt('verbose',
short='v',
default=False,
help='Print more verbose output (set logging level to '
'INFO instead of default WARNING level).'),
]
logging_cli_opts = [
cfg.StrOpt('log-config-append',
metavar='PATH',
deprecated_name='log-config',
help='The name of logging configuration file. It does not '
'disable existing loggers, but just appends specified '
'logging configuration to any other existing logging '
'options. Please see the Python logging module '
'documentation for details on logging configuration '
'files.'),
cfg.StrOpt('log-format',
default=None,
metavar='FORMAT',
help='DEPRECATED. '
'A logging.Formatter log message format string which may '
'use any of the available logging.LogRecord attributes. '
'This option is deprecated. Please use '
'logging_context_format_string and '
'logging_default_format_string instead.'),
cfg.StrOpt('log-date-format',
default=_DEFAULT_LOG_DATE_FORMAT,
metavar='DATE_FORMAT',
help='Format string for %%(asctime)s in log records. '
'Default: %(default)s'),
cfg.StrOpt('log-file',
metavar='PATH',
deprecated_name='logfile',
help='(Optional) Name of log file to output to. '
'If no default is set, logging will go to stdout.'),
cfg.StrOpt('log-dir',
deprecated_name='logdir',
help='(Optional) The base directory used for relative '
'--log-file paths'),
cfg.BoolOpt('use-syslog',
default=False,
help='Use syslog for logging. '
'Existing syslog format is DEPRECATED during I, '
'and then will be changed in J to honor RFC5424'),
cfg.BoolOpt('use-syslog-rfc-format',
# TODO(bogdando) remove or use True after existing
# syslog format deprecation in J
default=False,
help='(Optional) Use syslog rfc5424 format for logging. '
'If enabled, will add APP-NAME (RFC5424) before the '
'MSG part of the syslog message. The old format '
'without APP-NAME is deprecated in I, '
'and will be removed in J.'),
cfg.StrOpt('syslog-log-facility',
default='LOG_USER',
help='Syslog facility to receive log lines')
]
generic_log_opts = [
cfg.BoolOpt('use_stderr',
default=True,
help='Log output to standard error')
]
log_opts = [
cfg.StrOpt('logging_context_format_string',
default='%(asctime)s.%(msecs)03d %(process)d %(levelname)s '
'%(name)s [%(request_id)s %(user_identity)s] '
'%(instance)s%(message)s',
help='Format string to use for log messages with context'),
cfg.StrOpt('logging_default_format_string',
default='%(asctime)s.%(msecs)03d %(process)d %(levelname)s '
'%(name)s [-] %(instance)s%(message)s',
help='Format string to use for log messages without context'),
cfg.StrOpt('logging_debug_format_suffix',
default='%(funcName)s %(pathname)s:%(lineno)d',
help='Data to append to log format when level is DEBUG'),
cfg.StrOpt('logging_exception_prefix',
default='%(asctime)s.%(msecs)03d %(process)d TRACE %(name)s '
'%(instance)s',
help='Prefix each line of exception output with this format'),
cfg.ListOpt('default_log_levels',
default=[
'amqp=WARN',
'amqplib=WARN',
'boto=WARN',
'qpid=WARN',
'sqlalchemy=WARN',
'suds=INFO',
'oslo.messaging=INFO',
'iso8601=WARN',
'requests.packages.urllib3.connectionpool=WARN'
],
help='List of logger=LEVEL pairs'),
cfg.BoolOpt('publish_errors',
default=False,
help='Publish error events'),
cfg.BoolOpt('fatal_deprecations',
default=False,
help='Make deprecations fatal'),
# NOTE(mikal): there are two options here because sometimes we are handed
# a full instance (and could include more information), and other times we
# are just handed a UUID for the instance.
cfg.StrOpt('instance_format',
default='[instance: %(uuid)s] ',
help='If an instance is passed with the log message, format '
'it like this'),
cfg.StrOpt('instance_uuid_format',
default='[instance: %(uuid)s] ',
help='If an instance UUID is passed with the log message, '
'format it like this'),
]
CONF = cfg.CONF
CONF.register_cli_opts(common_cli_opts)
CONF.register_cli_opts(logging_cli_opts)
CONF.register_opts(generic_log_opts)
CONF.register_opts(log_opts)
# our new audit level
# NOTE(jkoelker) Since we synthesized an audit level, make the logging
# module aware of it so it acts like other levels.
logging.AUDIT = logging.INFO + 1
logging.addLevelName(logging.AUDIT, 'AUDIT')
try:
NullHandler = logging.NullHandler
except AttributeError: # NOTE(jkoelker) NullHandler added in Python 2.7
class NullHandler(logging.Handler):
def handle(self, record):
pass
def emit(self, record):
pass
def createLock(self):
self.lock = None
def _dictify_context(context):
if context is None:
return None
if not isinstance(context, dict) and getattr(context, 'to_dict', None):
context = context.to_dict()
return context
def _get_binary_name():
return os.path.basename(inspect.stack()[-1][1])
def _get_log_file_path(binary=None):
logfile = CONF.log_file
logdir = CONF.log_dir
if logfile and not logdir:
return logfile
if logfile and logdir:
return os.path.join(logdir, logfile)
if logdir:
binary = binary or _get_binary_name()
return '%s.log' % (os.path.join(logdir, binary),)
return None
def mask_password(message, secret="***"):
"""Replace password with 'secret' in message.
:param message: The string which includes security information.
:param secret: value with which to replace passwords.
:returns: The unicode value of message with the password fields masked.
For example:
>>> mask_password("'adminPass' : 'aaaaa'")
"'adminPass' : '***'"
>>> mask_password("'admin_pass' : 'aaaaa'")
"'admin_pass' : '***'"
>>> mask_password('"password" : "aaaaa"')
'"password" : "***"'
>>> mask_password("'original_password' : 'aaaaa'")
"'original_password' : '***'"
>>> mask_password("u'original_password' : u'aaaaa'")
"u'original_password' : u'***'"
"""
message = six.text_type(message)
# NOTE(ldbragst): Check to see if anything in message contains any key
# specified in _SANITIZE_KEYS, if not then just return the message since
# we don't have to mask any passwords.
if not any(key in message for key in _SANITIZE_KEYS):
return message
secret = r'\g<1>' + secret + r'\g<2>'
for pattern in _SANITIZE_PATTERNS:
message = re.sub(pattern, secret, message)
return message
class BaseLoggerAdapter(logging.LoggerAdapter):
def audit(self, msg, *args, **kwargs):
self.log(logging.AUDIT, msg, *args, **kwargs)
class LazyAdapter(BaseLoggerAdapter):
def __init__(self, name='unknown', version='unknown'):
self._logger = None
self.extra = {}
self.name = name
self.version = version
@property
def logger(self):
if not self._logger:
self._logger = getLogger(self.name, self.version)
return self._logger
class ContextAdapter(BaseLoggerAdapter):
warn = logging.LoggerAdapter.warning
def __init__(self, logger, project_name, version_string):
self.logger = logger
self.project = project_name
self.version = version_string
self._deprecated_messages_sent = dict()
@property
def handlers(self):
return self.logger.handlers
def deprecated(self, msg, *args, **kwargs):
"""Call this method when a deprecated feature is used.
If the system is configured for fatal deprecations then the message
is logged at the 'critical' level and :class:`DeprecatedConfig` will
be raised.
Otherwise, the message will be logged (once) at the 'warn' level.
:raises: :class:`DeprecatedConfig` if the system is configured for
fatal deprecations.
"""
stdmsg = _("Deprecated: %s") % msg
if CONF.fatal_deprecations:
self.critical(stdmsg, *args, **kwargs)
raise DeprecatedConfig(msg=stdmsg)
# Using a list because a tuple with dict can't be stored in a set.
sent_args = self._deprecated_messages_sent.setdefault(msg, list())
if args in sent_args:
# Already logged this message, so don't log it again.
return
sent_args.append(args)
self.warn(stdmsg, *args, **kwargs)
def process(self, msg, kwargs):
# NOTE(mrodden): catch any Message/other object and
# coerce to unicode before they can get
# to the python logging and possibly
# cause string encoding trouble
if not isinstance(msg, six.string_types):
msg = six.text_type(msg)
if 'extra' not in kwargs:
kwargs['extra'] = {}
extra = kwargs['extra']
context = kwargs.pop('context', None)
if not context:
context = getattr(local.store, 'context', None)
if context:
extra.update(_dictify_context(context))
instance = kwargs.pop('instance', None)
instance_uuid = (extra.get('instance_uuid') or
kwargs.pop('instance_uuid', None))
instance_extra = ''
if instance:
instance_extra = CONF.instance_format % instance
elif instance_uuid:
instance_extra = (CONF.instance_uuid_format
% {'uuid': instance_uuid})
extra['instance'] = instance_extra
extra.setdefault('user_identity', kwargs.pop('user_identity', None))
extra['project'] = self.project
extra['version'] = self.version
extra['extra'] = extra.copy()
return msg, kwargs
class JSONFormatter(logging.Formatter):
def __init__(self, fmt=None, datefmt=None):
# NOTE(jkoelker) we ignore the fmt argument, but its still there
# since logging.config.fileConfig passes it.
self.datefmt = datefmt
def formatException(self, ei, strip_newlines=True):
lines = traceback.format_exception(*ei)
if strip_newlines:
lines = [moves.filter(
lambda x: x,
line.rstrip().splitlines()) for line in lines]
lines = list(itertools.chain(*lines))
return lines
def format(self, record):
message = {'message': record.getMessage(),
'asctime': self.formatTime(record, self.datefmt),
'name': record.name,
'msg': record.msg,
'args': record.args,
'levelname': record.levelname,
'levelno': record.levelno,
'pathname': record.pathname,
'filename': record.filename,
'module': record.module,
'lineno': record.lineno,
'funcname': record.funcName,
'created': record.created,
'msecs': record.msecs,
'relative_created': record.relativeCreated,
'thread': record.thread,
'thread_name': record.threadName,
'process_name': record.processName,
'process': record.process,
'traceback': None}
if hasattr(record, 'extra'):
message['extra'] = record.extra
if record.exc_info:
message['traceback'] = self.formatException(record.exc_info)
return jsonutils.dumps(message)
def _create_logging_excepthook(product_name):
def logging_excepthook(exc_type, value, tb):
extra = {}
if CONF.verbose or CONF.debug:
extra['exc_info'] = (exc_type, value, tb)
getLogger(product_name).critical(
"".join(traceback.format_exception_only(exc_type, value)),
**extra)
return logging_excepthook
class LogConfigError(Exception):
message = _('Error loading logging config %(log_config)s: %(err_msg)s')
def __init__(self, log_config, err_msg):
self.log_config = log_config
self.err_msg = err_msg
def __str__(self):
return self.message % dict(log_config=self.log_config,
err_msg=self.err_msg)
def _load_log_config(log_config_append):
try:
logging.config.fileConfig(log_config_append,
disable_existing_loggers=False)
except moves.configparser.Error as exc:
raise LogConfigError(log_config_append, str(exc))
def setup(product_name, version='unknown'):
"""Setup logging."""
if CONF.log_config_append:
_load_log_config(CONF.log_config_append)
else:
_setup_logging_from_conf(product_name, version)
sys.excepthook = _create_logging_excepthook(product_name)
def set_defaults(logging_context_format_string):
cfg.set_defaults(log_opts,
logging_context_format_string=
logging_context_format_string)
def _find_facility_from_conf():
facility_names = logging.handlers.SysLogHandler.facility_names
facility = getattr(logging.handlers.SysLogHandler,
CONF.syslog_log_facility,
None)
if facility is None and CONF.syslog_log_facility in facility_names:
facility = facility_names.get(CONF.syslog_log_facility)
if facility is None:
valid_facilities = facility_names.keys()
consts = ['LOG_AUTH', 'LOG_AUTHPRIV', 'LOG_CRON', 'LOG_DAEMON',
'LOG_FTP', 'LOG_KERN', 'LOG_LPR', 'LOG_MAIL', 'LOG_NEWS',
'LOG_AUTH', 'LOG_SYSLOG', 'LOG_USER', 'LOG_UUCP',
'LOG_LOCAL0', 'LOG_LOCAL1', 'LOG_LOCAL2', 'LOG_LOCAL3',
'LOG_LOCAL4', 'LOG_LOCAL5', 'LOG_LOCAL6', 'LOG_LOCAL7']
valid_facilities.extend(consts)
raise TypeError(_('syslog facility must be one of: %s') %
', '.join("'%s'" % fac
for fac in valid_facilities))
return facility
class RFCSysLogHandler(logging.handlers.SysLogHandler):
def __init__(self, *args, **kwargs):
self.binary_name = _get_binary_name()
super(RFCSysLogHandler, self).__init__(*args, **kwargs)
def format(self, record):
msg = super(RFCSysLogHandler, self).format(record)
msg = self.binary_name + ' ' + msg
return msg
def _setup_logging_from_conf(project, version):
log_root = getLogger(None).logger
for handler in log_root.handlers:
log_root.removeHandler(handler)
if CONF.use_syslog:
facility = _find_facility_from_conf()
# TODO(bogdando) use the format provided by RFCSysLogHandler
# after existing syslog format deprecation in J
if CONF.use_syslog_rfc_format:
syslog = RFCSysLogHandler(address='/dev/log',
facility=facility)
else:
syslog = logging.handlers.SysLogHandler(address='/dev/log',
facility=facility)
log_root.addHandler(syslog)
logpath = _get_log_file_path()
if logpath:
filelog = logging.handlers.WatchedFileHandler(logpath)
log_root.addHandler(filelog)
if CONF.use_stderr:
streamlog = ColorHandler()
log_root.addHandler(streamlog)
elif not logpath:
# pass sys.stdout as a positional argument
# python2.6 calls the argument strm, in 2.7 it's stream
streamlog = logging.StreamHandler(sys.stdout)
log_root.addHandler(streamlog)
if CONF.publish_errors:
handler = importutils.import_object(
"cerberus.openstack.common.log_handler.PublishErrorsHandler",
logging.ERROR)
log_root.addHandler(handler)
datefmt = CONF.log_date_format
for handler in log_root.handlers:
# NOTE(alaski): CONF.log_format overrides everything currently. This
# should be deprecated in favor of context aware formatting.
if CONF.log_format:
handler.setFormatter(logging.Formatter(fmt=CONF.log_format,
datefmt=datefmt))
log_root.info('Deprecated: log_format is now deprecated and will '
'be removed in the next release')
else:
handler.setFormatter(ContextFormatter(project=project,
version=version,
datefmt=datefmt))
if CONF.debug:
log_root.setLevel(logging.DEBUG)
elif CONF.verbose:
log_root.setLevel(logging.INFO)
else:
log_root.setLevel(logging.WARNING)
for pair in CONF.default_log_levels:
mod, _sep, level_name = pair.partition('=')
level = logging.getLevelName(level_name)
logger = logging.getLogger(mod)
logger.setLevel(level)
_loggers = {}
def getLogger(name='unknown', version='unknown'):
if name not in _loggers:
_loggers[name] = ContextAdapter(logging.getLogger(name),
name,
version)
return _loggers[name]
def getLazyLogger(name='unknown', version='unknown'):
"""Returns lazy logger.
Creates a pass-through logger that does not create the real logger
until it is really needed and delegates all calls to the real logger
once it is created.
"""
return LazyAdapter(name, version)
class WritableLogger(object):
"""A thin wrapper that responds to `write` and logs."""
def __init__(self, logger, level=logging.INFO):
self.logger = logger
self.level = level
def write(self, msg):
self.logger.log(self.level, msg.rstrip())
class ContextFormatter(logging.Formatter):
"""A context.RequestContext aware formatter configured through flags.
The flags used to set format strings are: logging_context_format_string
and logging_default_format_string. You can also specify
logging_debug_format_suffix to append extra formatting if the log level is
debug.
For information about what variables are available for the formatter see:
http://docs.python.org/library/logging.html#formatter
If available, uses the context value stored in TLS - local.store.context
"""
def __init__(self, *args, **kwargs):
"""Initialize ContextFormatter instance
Takes additional keyword arguments which can be used in the message
format string.
:keyword project: project name
:type project: string
:keyword version: project version
:type version: string
"""
self.project = kwargs.pop('project', 'unknown')
self.version = kwargs.pop('version', 'unknown')
logging.Formatter.__init__(self, *args, **kwargs)
def format(self, record):
"""Uses contextstring if request_id is set, otherwise default."""
# store project info
record.project = self.project
record.version = self.version
# store request info
context = getattr(local.store, 'context', None)
if context:
d = _dictify_context(context)
for k, v in d.items():
setattr(record, k, v)
# NOTE(sdague): default the fancier formatting params
# to an empty string so we don't throw an exception if
# they get used
for key in ('instance', 'color', 'user_identity'):
if key not in record.__dict__:
record.__dict__[key] = ''
if record.__dict__.get('request_id'):
self._fmt = CONF.logging_context_format_string
else:
self._fmt = CONF.logging_default_format_string
if (record.levelno == logging.DEBUG and
CONF.logging_debug_format_suffix):
self._fmt += " " + CONF.logging_debug_format_suffix
# Cache this on the record, Logger will respect our formatted copy
if record.exc_info:
record.exc_text = self.formatException(record.exc_info, record)
return logging.Formatter.format(self, record)
def formatException(self, exc_info, record=None):
"""Format exception output with CONF.logging_exception_prefix."""
if not record:
return logging.Formatter.formatException(self, exc_info)
stringbuffer = moves.StringIO()
traceback.print_exception(exc_info[0], exc_info[1], exc_info[2],
None, stringbuffer)
lines = stringbuffer.getvalue().split('\n')
stringbuffer.close()
if CONF.logging_exception_prefix.find('%(asctime)') != -1:
record.asctime = self.formatTime(record, self.datefmt)
formatted_lines = []
for line in lines:
pl = CONF.logging_exception_prefix % record.__dict__
fl = '%s%s' % (pl, line)
formatted_lines.append(fl)
return '\n'.join(formatted_lines)
class ColorHandler(logging.StreamHandler):
LEVEL_COLORS = {
logging.DEBUG: '\033[00;32m', # GREEN
logging.INFO: '\033[00;36m', # CYAN
logging.AUDIT: '\033[01;36m', # BOLD CYAN
logging.WARN: '\033[01;33m', # BOLD YELLOW
logging.ERROR: '\033[01;31m', # BOLD RED
logging.CRITICAL: '\033[01;31m', # BOLD RED
}
def format(self, record):
record.color = self.LEVEL_COLORS[record.levelno]
return logging.StreamHandler.format(self, record)
class DeprecatedConfig(Exception):
message = _("Fatal call to deprecated config: %(msg)s")
def __init__(self, msg):
super(Exception, self).__init__(self.message % dict(msg=msg))

View File

@ -0,0 +1,30 @@
# Copyright 2013 IBM Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import logging
from oslo.config import cfg
from cerberus.openstack.common import notifier
class PublishErrorsHandler(logging.Handler):
def emit(self, record):
if ('cerberus.openstack.common.notifier.log_notifier' in
cfg.CONF.notification_driver):
return
notifier.api.notify(None, 'error.publisher',
'error_notification',
notifier.api.ERROR,
dict(error=record.getMessage()))

View File

@ -0,0 +1,145 @@
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# Copyright 2011 Justin Santa Barbara
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import sys
from eventlet import event
from eventlet import greenthread
from cerberus.openstack.common.gettextutils import _LE, _LW
from cerberus.openstack.common import log as logging
from cerberus.openstack.common import timeutils
LOG = logging.getLogger(__name__)
class LoopingCallDone(Exception):
"""Exception to break out and stop a LoopingCall.
The poll-function passed to LoopingCall can raise this exception to
break out of the loop normally. This is somewhat analogous to
StopIteration.
An optional return-value can be included as the argument to the exception;
this return-value will be returned by LoopingCall.wait()
"""
def __init__(self, retvalue=True):
""":param retvalue: Value that LoopingCall.wait() should return."""
self.retvalue = retvalue
class LoopingCallBase(object):
def __init__(self, f=None, *args, **kw):
self.args = args
self.kw = kw
self.f = f
self._running = False
self.done = None
def stop(self):
self._running = False
def wait(self):
return self.done.wait()
class FixedIntervalLoopingCall(LoopingCallBase):
"""A fixed interval looping call."""
def start(self, interval, initial_delay=None):
self._running = True
done = event.Event()
def _inner():
if initial_delay:
greenthread.sleep(initial_delay)
try:
while self._running:
start = timeutils.utcnow()
self.f(*self.args, **self.kw)
end = timeutils.utcnow()
if not self._running:
break
delay = interval - timeutils.delta_seconds(start, end)
if delay <= 0:
LOG.warn(_LW('task run outlasted interval by %s sec') %
-delay)
greenthread.sleep(delay if delay > 0 else 0)
except LoopingCallDone as e:
self.stop()
done.send(e.retvalue)
except Exception:
LOG.exception(_LE('in fixed duration looping call'))
done.send_exception(*sys.exc_info())
return
else:
done.send(True)
self.done = done
self.gt = greenthread.spawn(_inner)
return self.done
# TODO(mikal): this class name is deprecated in Havana and should be removed
# in the I release
LoopingCall = FixedIntervalLoopingCall
class DynamicLoopingCall(LoopingCallBase):
"""A looping call which sleeps until the next known event.
The function called should return how long to sleep for before being
called again.
"""
def start(self, initial_delay=None, periodic_interval_max=None):
self._running = True
done = event.Event()
def _inner():
if initial_delay:
greenthread.sleep(initial_delay)
try:
while self._running:
idle = self.f(*self.args, **self.kw)
if not self._running:
break
if periodic_interval_max is not None:
idle = min(idle, periodic_interval_max)
LOG.debug('Dynamic looping call sleeping for %.02f '
'seconds', idle)
greenthread.sleep(idle)
except LoopingCallDone as e:
self.stop()
done.send(e.retvalue)
except Exception:
LOG.exception(_LE('in dynamic looping call'))
done.send_exception(*sys.exc_info())
return
else:
done.send(True)
self.done = done
greenthread.spawn(_inner)
return self.done

View File

@ -0,0 +1,108 @@
# Copyright 2012 OpenStack Foundation.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Network-related utilities and helper functions.
"""
# TODO(jd) Use six.moves once
# https://bitbucket.org/gutworth/six/pull-request/28
# is merged
try:
import urllib.parse
SplitResult = urllib.parse.SplitResult
except ImportError:
import urlparse
SplitResult = urlparse.SplitResult
from six.moves.urllib import parse
def parse_host_port(address, default_port=None):
"""Interpret a string as a host:port pair.
An IPv6 address MUST be escaped if accompanied by a port,
because otherwise ambiguity ensues: 2001:db8:85a3::8a2e:370:7334
means both [2001:db8:85a3::8a2e:370:7334] and
[2001:db8:85a3::8a2e:370]:7334.
>>> parse_host_port('server01:80')
('server01', 80)
>>> parse_host_port('server01')
('server01', None)
>>> parse_host_port('server01', default_port=1234)
('server01', 1234)
>>> parse_host_port('[::1]:80')
('::1', 80)
>>> parse_host_port('[::1]')
('::1', None)
>>> parse_host_port('[::1]', default_port=1234)
('::1', 1234)
>>> parse_host_port('2001:db8:85a3::8a2e:370:7334', default_port=1234)
('2001:db8:85a3::8a2e:370:7334', 1234)
"""
if address[0] == '[':
# Escaped ipv6
_host, _port = address[1:].split(']')
host = _host
if ':' in _port:
port = _port.split(':')[1]
else:
port = default_port
else:
if address.count(':') == 1:
host, port = address.split(':')
else:
# 0 means ipv4, >1 means ipv6.
# We prohibit unescaped ipv6 addresses with port.
host = address
port = default_port
return (host, None if port is None else int(port))
class ModifiedSplitResult(SplitResult):
"""Split results class for urlsplit."""
# NOTE(dims): The functions below are needed for Python 2.6.x.
# We can remove these when we drop support for 2.6.x.
@property
def hostname(self):
netloc = self.netloc.split('@', 1)[-1]
host, port = parse_host_port(netloc)
return host
@property
def port(self):
netloc = self.netloc.split('@', 1)[-1]
host, port = parse_host_port(netloc)
return port
def urlsplit(url, scheme='', allow_fragments=True):
"""Parse a URL using urlparse.urlsplit(), splitting query and fragments.
This function papers over Python issue9374 when needed.
The parameters are the same as urlparse.urlsplit.
"""
scheme, netloc, path, query, fragment = parse.urlsplit(
url, scheme, allow_fragments)
if allow_fragments and '#' in path:
path, fragment = path.split('#', 1)
if '?' in path:
path, query = path.split('?', 1)
return ModifiedSplitResult(scheme, netloc,
path, query, fragment)

View File

@ -0,0 +1,173 @@
# Copyright 2011 OpenStack Foundation.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import socket
import uuid
from oslo.config import cfg
from cerberus.openstack.common import context
from cerberus.openstack.common.gettextutils import _, _LE
from cerberus.openstack.common import importutils
from cerberus.openstack.common import jsonutils
from cerberus.openstack.common import log as logging
from cerberus.openstack.common import timeutils
LOG = logging.getLogger(__name__)
notifier_opts = [
cfg.MultiStrOpt('notification_driver',
default=[],
help='Driver or drivers to handle sending notifications'),
cfg.StrOpt('default_notification_level',
default='INFO',
help='Default notification level for outgoing notifications'),
cfg.StrOpt('default_publisher_id',
default=None,
help='Default publisher_id for outgoing notifications'),
]
CONF = cfg.CONF
CONF.register_opts(notifier_opts)
WARN = 'WARN'
INFO = 'INFO'
ERROR = 'ERROR'
CRITICAL = 'CRITICAL'
DEBUG = 'DEBUG'
log_levels = (DEBUG, WARN, INFO, ERROR, CRITICAL)
class BadPriorityException(Exception):
pass
def notify_decorator(name, fn):
"""Decorator for notify which is used from utils.monkey_patch().
:param name: name of the function
:param function: - object of the function
:returns: function -- decorated function
"""
def wrapped_func(*args, **kwarg):
body = {}
body['args'] = []
body['kwarg'] = {}
for arg in args:
body['args'].append(arg)
for key in kwarg:
body['kwarg'][key] = kwarg[key]
ctxt = context.get_context_from_function_and_args(fn, args, kwarg)
notify(ctxt,
CONF.default_publisher_id or socket.gethostname(),
name,
CONF.default_notification_level,
body)
return fn(*args, **kwarg)
return wrapped_func
def publisher_id(service, host=None):
if not host:
try:
host = CONF.host
except AttributeError:
host = CONF.default_publisher_id or socket.gethostname()
return "%s.%s" % (service, host)
def notify(context, publisher_id, event_type, priority, payload):
"""Sends a notification using the specified driver
:param publisher_id: the source worker_type.host of the message
:param event_type: the literal type of event (ex. Instance Creation)
:param priority: patterned after the enumeration of Python logging
levels in the set (DEBUG, WARN, INFO, ERROR, CRITICAL)
:param payload: A python dictionary of attributes
Outgoing message format includes the above parameters, and appends the
following:
message_id
a UUID representing the id for this notification
timestamp
the GMT timestamp the notification was sent at
The composite message will be constructed as a dictionary of the above
attributes, which will then be sent via the transport mechanism defined
by the driver.
Message example::
{'message_id': str(uuid.uuid4()),
'publisher_id': 'compute.host1',
'timestamp': timeutils.utcnow(),
'priority': 'WARN',
'event_type': 'compute.create_instance',
'payload': {'instance_id': 12, ... }}
"""
if priority not in log_levels:
raise BadPriorityException(
_('%s not in valid priorities') % priority)
# Ensure everything is JSON serializable.
payload = jsonutils.to_primitive(payload, convert_instances=True)
msg = dict(message_id=str(uuid.uuid4()),
publisher_id=publisher_id,
event_type=event_type,
priority=priority,
payload=payload,
timestamp=str(timeutils.utcnow()))
for driver in _get_drivers():
try:
driver.notify(context, msg)
except Exception as e:
LOG.exception(_LE("Problem '%(e)s' attempting to "
"send to notification system. "
"Payload=%(payload)s")
% dict(e=e, payload=payload))
_drivers = None
def _get_drivers():
"""Instantiate, cache, and return drivers based on the CONF."""
global _drivers
if _drivers is None:
_drivers = {}
for notification_driver in CONF.notification_driver:
try:
driver = importutils.import_module(notification_driver)
_drivers[notification_driver] = driver
except ImportError:
LOG.exception(_LE("Failed to load notifier %s. "
"These notifications will not be sent.") %
notification_driver)
return _drivers.values()
def _reset_drivers():
"""Used by unit tests to reset the drivers."""
global _drivers
_drivers = None

View File

@ -0,0 +1,37 @@
# Copyright 2011 OpenStack Foundation.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from oslo.config import cfg
from cerberus.openstack.common import jsonutils
from cerberus.openstack.common import log as logging
CONF = cfg.CONF
def notify(_context, message):
"""Notifies the recipient of the desired event given the model.
Log notifications using OpenStack's default logging system.
"""
priority = message.get('priority',
CONF.default_notification_level)
priority = priority.lower()
logger = logging.getLogger(
'cerberus.openstack.common.notification.%s' %
message['event_type'])
getattr(logger, priority)(jsonutils.dumps(message))

View File

@ -0,0 +1,19 @@
# Copyright 2011 OpenStack Foundation.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
def notify(_context, message):
"""Notifies the recipient of the desired event given the model."""
pass

View File

@ -0,0 +1,77 @@
# Copyright 2013 Red Hat, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
A temporary helper which emulates oslo.messaging.Notifier.
This helper method allows us to do the tedious porting to the new Notifier API
as a standalone commit so that the commit which switches us to oslo.messaging
is smaller and easier to review. This file will be removed as part of that
commit.
"""
from oslo.config import cfg
from cerberus.openstack.common.notifier import api as notifier_api
CONF = cfg.CONF
class Notifier(object):
def __init__(self, publisher_id):
super(Notifier, self).__init__()
self.publisher_id = publisher_id
_marker = object()
def prepare(self, publisher_id=_marker):
ret = self.__class__(self.publisher_id)
if publisher_id is not self._marker:
ret.publisher_id = publisher_id
return ret
def _notify(self, ctxt, event_type, payload, priority):
notifier_api.notify(ctxt,
self.publisher_id,
event_type,
priority,
payload)
def audit(self, ctxt, event_type, payload):
# No audit in old notifier.
self._notify(ctxt, event_type, payload, 'INFO')
def debug(self, ctxt, event_type, payload):
self._notify(ctxt, event_type, payload, 'DEBUG')
def info(self, ctxt, event_type, payload):
self._notify(ctxt, event_type, payload, 'INFO')
def warn(self, ctxt, event_type, payload):
self._notify(ctxt, event_type, payload, 'WARN')
warning = warn
def error(self, ctxt, event_type, payload):
self._notify(ctxt, event_type, payload, 'ERROR')
def critical(self, ctxt, event_type, payload):
self._notify(ctxt, event_type, payload, 'CRITICAL')
def get_notifier(service=None, host=None, publisher_id=None):
if not publisher_id:
publisher_id = "%s.%s" % (service, host or CONF.host)
return Notifier(publisher_id)

Some files were not shown because too many files have changed in this diff Show More