diff --git a/LICENSE-2.0.txt b/LICENSE-2.0.txt new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/LICENSE-2.0.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/cerberus/__init__.py b/cerberus/__init__.py index 9028bce..6081a8d 100644 --- a/cerberus/__init__.py +++ b/cerberus/__init__.py @@ -1,19 +1,20 @@ -# -*- coding: utf-8 -*- - -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. import pbr.version - __version__ = pbr.version.VersionInfo( 'cerberus').version_string() diff --git a/cerberus/api/__init__.py b/cerberus/api/__init__.py new file mode 100644 index 0000000..08d938c --- /dev/null +++ b/cerberus/api/__init__.py @@ -0,0 +1,27 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from oslo.config import cfg + +from cerberus.openstack.common._i18n import _ # noqa + +keystone_opts = [ + cfg.StrOpt('auth_strategy', default='keystone', + help=_('The strategy to use for authentication.')) +] + +CONF = cfg.CONF +CONF.register_opts(keystone_opts) diff --git a/cerberus/api/app.py b/cerberus/api/app.py new file mode 100644 index 0000000..e709c20 --- /dev/null +++ b/cerberus/api/app.py @@ -0,0 +1,112 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import pecan
+from wsgiref import simple_server
+
+from oslo.config import cfg
+
+from cerberus.api import auth
+from cerberus.api import config as api_config
+from cerberus.api import hooks
+from cerberus.openstack.common import log as logging
+
+LOG = logging.getLogger(__name__)
+
+auth_opts = [
+    cfg.StrOpt('api_paste_config',
+               default="api_paste.ini",
+               help="Configuration file for WSGI definition of API."
+               ),
+]
+
+api_opts = [
+    cfg.StrOpt('host_ip',
+               default="0.0.0.0",
+               help="Host serving the API."
+               ),
+    cfg.IntOpt('port',
+               default=8300,
+               help="Host port serving the API."
+               ),
+]
+
+CONF = cfg.CONF
+CONF.register_opts(auth_opts)
+CONF.register_opts(api_opts, group='api')
+
+
+def get_pecan_config():
+    # Set up the pecan configuration
+    filename = api_config.__file__.replace('.pyc', '.py')
+    return pecan.configuration.conf_from_file(filename)
+
+
+def setup_app(pecan_config=None, extra_hooks=None):
+
+    if not pecan_config:
+        pecan_config = get_pecan_config()
+
+    app_hooks = [hooks.ConfigHook(),
+                 hooks.DBHook(),
+                 hooks.ContextHook(pecan_config.app.acl_public_routes),
+                 hooks.NoExceptionTracebackHook()]
+
+    if pecan_config.app.enable_acl:
+        app_hooks.append(hooks.AuthorizationHook(
+            pecan_config.app.member_routes))
+
+    pecan.configuration.set_config(dict(pecan_config), overwrite=True)
+
+    app = pecan.make_app(
+        pecan_config.app.root,
+        static_root=pecan_config.app.static_root,
+        template_path=pecan_config.app.template_path,
+        debug=CONF.debug,
+        force_canonical=getattr(pecan_config.app, 'force_canonical', True),
+        hooks=app_hooks,
+        guess_content_type_from_ext=False
+    )
+
+    if pecan_config.app.enable_acl:
+        strategy = auth.strategy(CONF.auth_strategy)
+        return strategy.install(app,
+                                cfg.CONF,
+                                pecan_config.app.acl_public_routes)
+
+    return app
+
+
+def build_server():
+    # Create the WSGI server (the caller starts it)
+    host = CONF.api.host_ip
+    port = CONF.api.port
+
+    server_cls = simple_server.WSGIServer
+    handler_cls = simple_server.WSGIRequestHandler
+
+    pecan_config = get_pecan_config()
+    pecan_config.app.enable_acl = (CONF.auth_strategy == 'keystone')
+
+    app = setup_app(pecan_config=pecan_config)
+
+    srv = simple_server.make_server(
+        host,
+        port,
+        app,
+        server_cls,
+        handler_cls)
+
+    return srv
diff --git a/cerberus/api/auth.py b/cerberus/api/auth.py
new file mode 100644
index 0000000..47d5a74
--- /dev/null
+++ b/cerberus/api/auth.py
@@ -0,0 +1,62 @@
+#
+# Copyright (c) 2015 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from cerberus.api.middleware import auth_token
+
+from cerberus.openstack.common import log
+
+STRATEGIES = {}
+
+LOG = log.getLogger(__name__)
+
+
+OPT_GROUP_NAME = 'keystone_authtoken'
+
+
+class KeystoneAuth(object):
+
+    @classmethod
+    def _register_opts(cls, conf):
+        """Register keystoneclient middleware options."""
+
+        if OPT_GROUP_NAME not in conf:
+            conf.register_opts(auth_token.opts, group=OPT_GROUP_NAME)
+            auth_token.CONF = conf
+
+    @classmethod
+    def install(cls, app, conf, public_routes):
+        """Install Auth check on application."""
+        LOG.debug(u'Installing Keystone\'s auth protocol')
+        cls._register_opts(conf)
+        conf = dict(conf.get(OPT_GROUP_NAME))
+        return auth_token.AuthTokenMiddleware(app,
+                                              conf=conf,
+                                              public_api_routes=public_routes)
+
+
+STRATEGIES['keystone'] = KeystoneAuth
+
+
+def strategy(strategy):
+    """Returns the Auth Strategy.
+
+    :param strategy: String representing
+                     the strategy to use
+    """
+    try:
+        return STRATEGIES[strategy]
+    except KeyError:
+        raise RuntimeError('Unknown authentication strategy: %s' % strategy)
diff --git a/cerberus/api/config.py b/cerberus/api/config.py
new file mode 100644
index 0000000..0e2a213
--- /dev/null
+++ b/cerberus/api/config.py
@@ -0,0 +1,23 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+# Pecan Application Configurations
+app = {
+    'root': 'cerberus.api.root.RootController',
+    'modules': ['cerberus.api'],
+    'static_root': '%(confdir)s/public',
+    'template_path': '%(confdir)s/templates',
+    'debug': True,
+    'enable_acl': False,
+    'acl_public_routes': ['/', '/v1'],
+    'member_routes': ['/v1/security_reports', ]
+}
diff --git a/cerberus/api/hooks.py b/cerberus/api/hooks.py
new file mode 100644
index 0000000..7c7d9af
--- /dev/null
+++ b/cerberus/api/hooks.py
@@ -0,0 +1,147 @@
+#
+# Copyright (c) 2015 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from oslo.config import cfg
+from pecan import hooks
+from webob import exc
+
+from cerberus.common import context
+from cerberus.common import policy
+from cerberus.db import api as dbapi
+
+
+class ConfigHook(hooks.PecanHook):
+    """Attach the config object to the request so controllers can get to it."""
+
+    def before(self, state):
+        state.request.cfg = cfg.CONF
+
+
+class DBHook(hooks.PecanHook):
+    """Attach the dbapi object to the request so controllers can get to it."""
+
+    def before(self, state):
+        state.request.dbapi = dbapi.get_instance()
+
+
+class ContextHook(hooks.PecanHook):
+    """Configures a request context and attaches it to the request.
+
+    The following HTTP request headers are used:
+
+    X-User-Id or X-User:
+        Used for context.user_id.
+
+    X-Tenant-Id or X-Tenant:
+        Used for context.tenant.
+
+    X-Auth-Token:
+        Used for context.auth_token.
+
+    X-Roles:
+        Used for setting the context.is_admin flag to either True or False.
+        The flag is set to True if X-Roles contains either an 'administrator'
+        or 'admin' substring. Otherwise it is set to False.
+
+    """
+    def __init__(self, public_api_routes):
+        self.public_api_routes = public_api_routes
+        super(ContextHook, self).__init__()
+
+    def before(self, state):
+        user_id = state.request.headers.get('X-User-Id')
+        user_id = state.request.headers.get('X-User', user_id)
+        tenant_id = state.request.headers.get('X-Tenant-Id')
+        tenant = state.request.headers.get('X-Tenant', tenant_id)
+        domain_id = state.request.headers.get('X-User-Domain-Id')
+        domain_name = state.request.headers.get('X-User-Domain-Name')
+        auth_token = state.request.headers.get('X-Auth-Token')
+        roles = state.request.headers.get('X-Roles', '').split(',')
+        creds = {'roles': roles}
+
+        is_public_api = state.request.environ.get('is_public_api', False)
+        is_admin = policy.enforce('context_is_admin',
+                                  state.request.headers,
+                                  creds)
+
+        state.request.context = context.RequestContext(
+            auth_token=auth_token,
+            user=user_id,
+            tenant_id=tenant_id,
+            tenant=tenant,
+            domain_id=domain_id,
+            domain_name=domain_name,
+            is_admin=is_admin,
+            is_public_api=is_public_api,
+            roles=roles)
+
+
+class AuthorizationHook(hooks.PecanHook):
+    """Verify that the user has admin rights.
+
+    Checks whether the request context is an admin context and
+    rejects the request if the api is not public.
+
+    """
+    def __init__(self, member_routes):
+        self.member_routes = member_routes
+        super(AuthorizationHook, self).__init__()
+
+    def before(self, state):
+        ctx = state.request.context
+
+        if not ctx.is_admin and not ctx.is_public_api and \
+                state.request.path not in self.member_routes:
+            raise exc.HTTPForbidden()
+
+
+class NoExceptionTracebackHook(hooks.PecanHook):
+    """Workaround rpc.common: deserialize_remote_exception.
+
+    deserialize_remote_exception builds the rpc exception traceback into the
+    error message which is then sent to the client. Such behavior is a
+    security concern, so this hook cuts the traceback off the error message.
+
+    """
+    # NOTE(max_lobur): the 'after' hook is used instead of 'on_error' because
+    # 'on_error' is never fired for the wsme+pecan pair: the wsme @wsexpose
+    # decorator catches and handles all errors, so 'on_error', which is
+    # dedicated to unhandled exceptions, never fires.
+    def after(self, state):
+        # Omit empty body. Some errors may not have a body at this level yet.
+        if not state.response.body:
+            return
+
+        # Do nothing if there is no error.
+        if 200 <= state.response.status_int < 400:
+            return
+
+        json_body = state.response.json
+        # Do not remove the traceback when the server is in debug mode
+        # (except for 'Server' errors, when 'debuginfo' will be used for
+        # traces).
+        if cfg.CONF.debug and json_body.get('faultcode') != 'Server':
+            return
+
+        faultstring = json_body.get('faultstring')
+        traceback_marker = 'Traceback (most recent call last):'
+        if faultstring and (traceback_marker in faultstring):
+            # Cut off the traceback.
+            faultstring = faultstring.split(traceback_marker, 1)[0]
+            # Remove trailing newlines and spaces if any.
+            json_body['faultstring'] = faultstring.rstrip()
+            # Replace the whole json. The original cannot be changed in
+            # place because it is generated on the fly.
+            state.response.json = json_body
diff --git a/cerberus/api/middleware/__init__.py b/cerberus/api/middleware/__init__.py
new file mode 100644
index 0000000..d15c2ea
--- /dev/null
+++ b/cerberus/api/middleware/__init__.py
@@ -0,0 +1,20 @@
+#
+# Copyright (c) 2015 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from cerberus.api.middleware import auth_token
+
+AuthTokenMiddleware = auth_token.AuthTokenMiddleware
diff --git a/cerberus/api/middleware/auth_token.py b/cerberus/api/middleware/auth_token.py
new file mode 100644
index 0000000..989e4f7
--- /dev/null
+++ b/cerberus/api/middleware/auth_token.py
@@ -0,0 +1,62 @@
+#
+# Copyright (c) 2015 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import re
+
+from keystoneclient.middleware import auth_token
+
+from cerberus.common import exception
+from cerberus.common import safe_utils
+from cerberus.openstack.common._i18n import _  # noqa
+from cerberus.openstack.common import log
+
+LOG = log.getLogger(__name__)
+
+
+class AuthTokenMiddleware(auth_token.AuthProtocol):
+    """A wrapper on Keystone auth_token middleware.
+
+    Does not perform verification of authentication tokens
+    for public routes in the API.
+
+    """
+    def __init__(self, app, conf, public_api_routes=[]):
+        route_pattern_tpl = r'%s(\.json|\.xml)?$'
+
+        try:
+            self.public_api_routes = [re.compile(route_pattern_tpl % route_tpl)
+                                      for route_tpl in public_api_routes]
+        except re.error as e:
+            msg = _('Cannot compile public API routes: %s') % e
+
+            LOG.error(msg)
+            raise exception.ConfigInvalid(error_msg=msg)
+
+        super(AuthTokenMiddleware, self).__init__(app, conf)
+
+    def __call__(self, env, start_response):
+        path = safe_utils.safe_rstrip(env.get('PATH_INFO'), '/')
+
+        # Whether the API call is being performed against the public API is
+        # required by some other components, so it is saved to the WSGI
+        # environment.
+        env['is_public_api'] = any(map(lambda pattern: re.match(pattern, path),
+                                       self.public_api_routes))
+
+        if env['is_public_api']:
+            return self.app(env, start_response)
+
+        return super(AuthTokenMiddleware, self).__call__(env, start_response)
diff --git a/cerberus/api/root.py b/cerberus/api/root.py
new file mode 100644
index 0000000..ffb362c
--- /dev/null
+++ b/cerberus/api/root.py
@@ -0,0 +1,140 @@
+#
+# Copyright (c) 2015 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import pecan
+from pecan import rest
+from wsme import types as wtypes
+import wsmeext.pecan as wsme_pecan
+
+from cerberus.api.v1 import controllers as v1_api
+from cerberus.openstack.common import log as logging
+
+LOG = logging.getLogger(__name__)
+
+VERSION_STATUS = wtypes.Enum(wtypes.text, 'EXPERIMENTAL', 'STABLE')
+
+
+class APILink(wtypes.Base):
+    """API link description.
+
+    """
+
+    type = wtypes.text
+    """Type of link."""
+
+    rel = wtypes.text
+    """Relationship with this link."""
+
+    href = wtypes.text
+    """URL of the link."""
+
+    @classmethod
+    def sample(cls):
+        version = 'v1'
+        sample = cls(
+            rel='self',
+            type='text/html',
+            href='http://127.0.0.1:8888/{id}'.format(
+                id=version))
+        return sample
+
+
+class APIMediaType(wtypes.Base):
+    """Media type description.
+
+    """
+
+    base = wtypes.text
+    """Base type of this media type."""
+
+    type = wtypes.text
+    """Type of this media type."""
+
+    @classmethod
+    def sample(cls):
+        sample = cls(
+            base='application/json',
+            type='application/vnd.openstack.cerberus-v1+json')
+        return sample
+
+
+class APIVersion(wtypes.Base):
+    """API Version description.
+
+    """
+
+    id = wtypes.text
+    """ID of the version."""
+
+    status = VERSION_STATUS
+    """Status of the version."""
+
+    updated = wtypes.text
+    """Last update in iso8601 format."""
+
+    links = [APILink]
+    """List of links to API resources."""
+
+    media_types = [APIMediaType]
+    """Types accepted by this API."""
+
+    @classmethod
+    def sample(cls):
+        version = 'v1'
+        updated = '2014-08-11T16:00:00Z'
+        links = [APILink.sample()]
+        media_types = [APIMediaType.sample()]
+        sample = cls(id=version,
+                     status='STABLE',
+                     updated=updated,
+                     links=links,
+                     media_types=media_types)
+        return sample
+
+
+class RootController(rest.RestController):
+    """Root REST Controller exposing versions of the API.
+
+    """
+
+    v1 = v1_api.V1Controller()
+
+    @wsme_pecan.wsexpose([APIVersion])
+    def get(self):
+        """Return the version list
+
+        """
+        # TODO(sheeprine): Maybe we should store all the API version
+        # information in every API module
+        ver1 = APIVersion(
+            id='v1',
+            status='EXPERIMENTAL',
+            updated='2015-03-09T16:00:00Z',
+            links=[
+                APILink(
+                    rel='self',
+                    href='{scheme}://{host}/v1'.format(
+                        scheme=pecan.request.scheme,
+                        host=pecan.request.host,
+                    )
+                )
+            ],
+            media_types=[]
+        )
+
+        versions = []
+        versions.append(ver1)
+
+        return versions
diff --git a/cerberus/api/v1/__init__.py b/cerberus/api/v1/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/cerberus/api/v1/controllers/__init__.py b/cerberus/api/v1/controllers/__init__.py
new file mode 100644
index 0000000..060d34a
--- /dev/null
+++ b/cerberus/api/v1/controllers/__init__.py
@@ -0,0 +1,31 @@
+#
+# Copyright (c) 2015 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pecan import rest + +from cerberus.api.v1.controllers import alerts as alerts_api +from cerberus.api.v1.controllers import plugins as plugins_api +from cerberus.api.v1.controllers import security_reports as \ + security_reports_api +from cerberus.api.v1.controllers import tasks as tasks_api + + +class V1Controller(rest.RestController): + """API version 1 controller. """ + alerts = alerts_api.AlertsController() + plugins = plugins_api.PluginsController() + security_reports = security_reports_api.SecurityReportsController() + tasks = tasks_api.TasksController() diff --git a/cerberus/api/v1/controllers/base.py b/cerberus/api/v1/controllers/base.py new file mode 100644 index 0000000..de782ac --- /dev/null +++ b/cerberus/api/v1/controllers/base.py @@ -0,0 +1,33 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pecan import rest + +from oslo.config import cfg +from oslo import messaging + +from cerberus.openstack.common import log + + +LOG = log.getLogger(__name__) + + +class BaseController(rest.RestController): + + def __init__(self): + transport = messaging.get_transport(cfg.CONF) + target = messaging.Target(topic='test_rpc', server='server1') + self.client = messaging.RPCClient(transport, target) diff --git a/cerberus/api/v1/controllers/other_controllers_to_change/home.py b/cerberus/api/v1/controllers/other_controllers_to_change/home.py new file mode 100644 index 0000000..ff62c5f --- /dev/null +++ b/cerberus/api/v1/controllers/other_controllers_to_change/home.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import json
+
+JSON_HOME = {
+    'resources': {
+        # -----------------------------------------------------------------
+        # Plugins
+        # -----------------------------------------------------------------
+        'rel/plugins': {
+            'href-template': '/v1/plugins{?id}',
+            'href-vars': {
+                'method_name': 'param/method_name',
+                'task_name': 'param/task_name',
+                'task_type': 'param/task_type',
+                'task_period': 'param/task_period',
+            },
+            'hints': {
+                'allow': ['GET', 'POST'],
+                'formats': {
+                    'application/json': {},
+                },
+            },
+        }
+    }
+}
+
+
+class Resource(object):
+    def __init__(self):
+        document = json.dumps(JSON_HOME, ensure_ascii=False, indent=4)
+        self.document_utf8 = document.encode('utf-8')
+
+    def on_get(self, req, resp):
+        resp.data = self.document_utf8
+        resp.content_type = 'application/json-home'
+        resp.cache_control = ['max-age=86400']
diff --git a/cerberus/api/v1/controllers/plugins.py b/cerberus/api/v1/controllers/plugins.py
new file mode 100644
index 0000000..3b9a3e6
--- /dev/null
+++ b/cerberus/api/v1/controllers/plugins.py
@@ -0,0 +1,149 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import pecan
+from webob import exc
+
+from oslo import messaging
+
+from cerberus.api.v1.controllers import base
+from cerberus.common import errors
+from cerberus import db
+from cerberus.db.sqlalchemy import models
+from cerberus.openstack.common import log
+
+
+LOG = log.getLogger(__name__)
+
+_ENFORCER = None
+
+
+class PluginsController(base.BaseController):
+
+    def list_plugins(self):
+        """ List all the plugins installed on the system """
+
+        # Get information about plugins loaded by Cerberus
+        try:
+            plugins = self._plugins()
+        except messaging.RemoteError as e:
+            LOG.exception(e)
+            raise
+        try:
+            # Get information about plugins stored in db
+            db_plugins_info = db.plugins_info_get()
+        except Exception as e:
+            LOG.exception(e)
+            raise
+        plugins_info = []
+        for plugin_info in db_plugins_info:
+            plugins_info.append(models.PluginInfoJsonSerializer().
+                                serialize(plugin_info))
+        plugins_full_info = []
+        for plugin in plugins:
+            for plugin_info in plugins_info:
+                if (plugin.get('name') == plugin_info.get('name')):
+                    plugins_full_info.append(dict(plugin.items() +
+                                                  plugin_info.items()))
+        return plugins_full_info
+
+    def _plugins(self):
+        """ Get a list of plugins loaded by Cerberus Manager """
+        ctx = pecan.request.context.to_dict()
+        try:
+            plugins = self.client.call(ctx, 'get_plugins')
+        except messaging.RemoteError as e:
+            LOG.exception(e)
+            raise
+        plugins_ = []
+        for plugin in plugins:
+            plugin_ = json.loads(plugin)
+            plugins_.append(plugin_)
+        return plugins_
+
+    @pecan.expose("json")
+    def get_all(self):
+        """ Get a list of plugins loaded by Cerberus manager
+        :return: a list of plugins loaded by Cerberus manager
+        :raises:
+            HTTPServiceUnavailable: an error occurred in Cerberus Manager or
+            the service is unavailable
+            HTTPNotFound: any other error
+        """
+
+        # Get information about plugins loaded by Cerberus
+        try:
+            plugins = self.list_plugins()
+        except messaging.RemoteError:
+            raise exc.HTTPServiceUnavailable()
+        except Exception as e:
+            LOG.exception(e)
+            raise exc.HTTPNotFound()
+        return {'plugins': plugins}
+
+    def get_plugin(self, uuid):
+        """ Get information about a plugin loaded by Cerberus """
+        try:
+            plugin = self._plugin(uuid)
+        except messaging.RemoteError:
+            raise
+        except errors.PluginNotFound:
+            raise
+        try:
+            # Get information about the plugin stored in db
+            db_plugin_info = db.plugin_info_get_from_uuid(uuid)
+            plugin_info = models.PluginInfoJsonSerializer().\
+                serialize(db_plugin_info)
+        except Exception as e:
+            LOG.exception(e)
+            raise
+        return dict(plugin_info.items() + plugin.items())
+
+    def _plugin(self, uuid):
+        """ Get a specific plugin thanks to its identifier """
+        ctx = pecan.request.context.to_dict()
+        try:
+            plugin = self.client.call(ctx, 'get_plugin_from_uuid', uuid=uuid)
+        except messaging.RemoteError as e:
+            LOG.exception(e)
+            raise
+
+        if plugin is None:
+            LOG.error('Plugin %s not found.' % uuid)
+            raise errors.PluginNotFound(uuid)
+        return json.loads(plugin)
+
+    @pecan.expose("json")
+    def get_one(self, uuid):
+        """ Get details of a specific plugin whose identifier is uuid
+        :param uuid: the identifier of the plugin
+        :return: details of a specific plugin
+        :raises:
+            HTTPServiceUnavailable: an error occurred in Cerberus Manager or
+            the service is unavailable
+            HTTPNotFound: the plugin is not found, or any other error
+        """
+        try:
+            plugin = self.get_plugin(uuid)
+        except messaging.RemoteError:
+            raise exc.HTTPServiceUnavailable()
+        except errors.PluginNotFound:
+            raise exc.HTTPNotFound()
+        except Exception as e:
+            LOG.exception(e)
+            raise exc.HTTPNotFound()
+        return {'plugin': plugin}
diff --git a/cerberus/api/v1/controllers/security_reports.py b/cerberus/api/v1/controllers/security_reports.py
new file mode 100644
index 0000000..a394136
--- /dev/null
+++ b/cerberus/api/v1/controllers/security_reports.py
@@ -0,0 +1,88 @@
+#
+# Copyright (c) 2015 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pecan
+from webob import exc
+
+from cerberus.api.v1.controllers import base
+from cerberus.common import errors
+from cerberus import db
+from cerberus.db.sqlalchemy import models
+from cerberus.openstack.common import log
+
+LOG = log.getLogger(__name__)
+
+
+class SecurityReportsController(base.BaseController):
+
+    def list_security_reports(self, project_id=None):
+        """ List the security reports of all projects, or of one project. """
+        try:
+            security_reports = db.security_report_get_all(
+                project_id=project_id)
+        except Exception as e:
+            LOG.exception(e)
+            raise errors.DbError(
+                "Security reports could not be retrieved"
+            )
+        return security_reports
+
+    @pecan.expose("json")
+    def get_all(self):
+        """ Get stored security reports.
+        :return: list of security reports for one or all projects, depending
+        on the context of the token.
+        """
+        ctx = pecan.request.context
+        try:
+            if ctx.is_admin:
+                security_reports = self.list_security_reports()
+            else:
+                security_reports = self.list_security_reports(ctx.tenant_id)
+        except errors.DbError:
+            raise exc.HTTPNotFound()
+        json_security_reports = []
+        for security_report in security_reports:
+            json_security_reports.append(models.SecurityReportJsonSerializer().
+                                         serialize(security_report))
+        return {'security_reports': json_security_reports}
+
+    def get_security_report(self, id):
+        try:
+            security_report = db.security_report_get(id)
+        except Exception as e:
+            LOG.exception(e)
+            raise errors.DbError(
+                "Security report %s could not be retrieved" % id
+            )
+        return security_report
+
+    @pecan.expose("json")
+    def get_one(self, id):
+        """ Get a stored security report by its identifier.
+        :param id: the identifier of the security report
+        :return: the security report
+        :raises:
+            HTTPNotFound: the security report could not be retrieved
+        """
+        try:
+            security_report = self.get_security_report(id)
+        except errors.DbError:
+            raise exc.HTTPNotFound()
+        s_report = models.SecurityReportJsonSerializer().\
+            serialize(security_report)
+
+        return {'security_report': s_report}
diff --git a/cerberus/api/v1/controllers/tasks.py b/cerberus/api/v1/controllers/tasks.py
new file mode 100644
index 0000000..fde3542
--- /dev/null
+++ b/cerberus/api/v1/controllers/tasks.py
@@ -0,0 +1,312 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import pecan
+from webob import exc
+from wsme import types as wtypes
+
+from oslo.messaging import rpc
+
+from cerberus.api.v1.controllers import base
+from cerberus.common import errors
+from cerberus.openstack.common import log
+
+
+LOG = log.getLogger(__name__)
+
+
+action_kind = ["stop", "restart", "force_delete"]
+action_kind_enum = wtypes.Enum(str, *action_kind)
+
+
+class Task(wtypes.Base):
+    """ Representation of a task.
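+
+    A task wraps a call to a plugin method, run either once (type
+    "unique", the default) or periodically (type "recurrent");
+    ``period`` only applies to recurrent tasks. The fields mirror the
+    task object accepted by ``TasksController.create_task``.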
+ """ + name = wtypes.text + period = wtypes.IntegerType() + method = wtypes.text + plugin_id = wtypes.text + type = wtypes.text + + +class TasksController(base.BaseController): + + @pecan.expose() + def _lookup(self, task_id, *remainder): + return TaskController(task_id), remainder + + def list_tasks(self): + ctx = pecan.request.context.to_dict() + try: + tasks = self.client.call(ctx, 'get_tasks') + except rpc.RemoteError as e: + LOG.exception(e) + raise + tasks_ = [] + for task in tasks: + task_ = json.loads(task) + tasks_.append(task_) + return tasks_ + + @pecan.expose("json") + def get(self): + """ List tasks + :return: list of tasks + :raises: + HTTPBadRequest + """ + try: + tasks = self.list_tasks() + except rpc.RemoteError: + raise exc.HTTPServiceUnavailable() + return {'tasks': tasks} + + def create_task(self, body): + + ctx = pecan.request.context.to_dict() + + task = body.get('task', None) + if task is None: + LOG.exception("Task object not provided in request") + raise errors.TaskObjectNotProvided() + + plugin_id = task.get('plugin_id', None) + if plugin_id is None: + LOG.exception("Plugin id not provided in request") + raise errors.PluginIdNotProvided() + + method_ = task.get('method', None) + if method_ is None: + LOG.exception("Method not provided in request") + raise errors.MethodNotProvided() + + try: + task['id'] = self.client.call( + ctx, + 'add_task', + uuid=plugin_id, + method_=method_, + task_period=task.get('period', None), + task_name=task.get('name', "unknown"), + task_type=task.get('type', "unique") + ) + except rpc.RemoteError as e: + LOG.exception(e) + raise + + return task + + @pecan.expose("json") + def post(self): + """Ask Cerberus Manager to call a function of a plugin whose identifier + is uuid, either once or periodically. 
+        :return:
+        :raises:
+            HTTPBadRequest: the request is not correct
+        """
+        body_ = pecan.request.body
+        try:
+            body = json.loads(body_.decode('utf-8'))
+        except (ValueError, UnicodeDecodeError) as e:
+            LOG.exception(e)
+            raise exc.HTTPBadRequest()
+        try:
+            task = self.create_task(body)
+        except errors.TaskObjectNotProvided:
+            raise exc.HTTPBadRequest(
+                explanation='The task object is required.')
+        except errors.PluginIdNotProvided:
+            raise exc.HTTPBadRequest(
+                explanation='Plugin id must be provided as a string')
+        except errors.MethodNotProvided:
+            raise exc.HTTPBadRequest(
+                explanation='Method must be provided as a string')
+        except rpc.RemoteError as e:
+            LOG.exception(e)
+            raise exc.HTTPBadRequest(explanation=e.value)
+        except Exception as e:
+            LOG.exception(e)
+            raise exc.HTTPBadRequest()
+        return {'task': task}
+
+
+class TaskController(base.BaseController):
+    """Manages operation on a single task."""
+
+    _custom_actions = {
+        'action': ['POST']
+    }
+
+    def __init__(self, task_id):
+        super(TaskController, self).__init__()
+        pecan.request.context['task_id'] = task_id
+        try:
+            self._id = int(task_id)
+        except ValueError:
+            raise exc.HTTPBadRequest(
+                explanation='Task id must be an integer')
+
+    def get_task(self, id):
+        ctx = pecan.request.context.to_dict()
+        try:
+            task = self.client.call(ctx, 'get_task', id=int(id))
+        except ValueError as e:
+            LOG.exception(e)
+            raise
+        except rpc.RemoteError as e:
+            LOG.exception(e)
+            raise
+        return json.loads(task)
+
+    @pecan.expose("json")
+    def get(self):
+        """ Get details of the task bound to this controller
+        :return: details of the task
+        :raises:
+            HTTPBadRequest: the task id is not an integer
+            HTTPNotFound: the task does not exist
+        """
+        try:
+            task = self.get_task(self._id)
+        except ValueError:
+            raise exc.HTTPBadRequest(
+                explanation='Task id must be an integer')
+        except rpc.RemoteError:
+            raise exc.HTTPNotFound()
+        except Exception as e:
+            LOG.exception(e)
+            raise
+        return {'task': task}
+
+    @pecan.expose("json")
+    def post(self):
+        """
+        Perform an action on a specific task (e.g. stop it). The action,
+        one of 'stop', 'forceDelete' or 'restart', is read from the JSON
+        request body.
+        :return:
+        :raises:
+            HTTPError: Incorrect JSON or not UTF-8 encoded
+            HTTPBadRequest: id not integer, task does not exist or the
+            action failed
+        """
+        body_ = pecan.request.body
+        try:
+            body = json.loads(body_.decode('utf-8'))
+        except (ValueError, UnicodeDecodeError) as e:
+            LOG.exception(e)
+            raise exc.HTTPBadRequest()
+
+        if 'stop' in body:
+            try:
+                self.stop_task(self._id)
+            except ValueError:
+                raise exc.HTTPBadRequest(
+                    explanation="Task id must be an integer")
+            except rpc.RemoteError:
+                raise exc.HTTPBadRequest(
+                    explanation="Task can not be stopped")
+        elif 'forceDelete' in body:
+            try:
+                self.force_delete(self._id)
+            except ValueError:
+                raise exc.HTTPBadRequest(
+                    explanation="Task id must be an integer")
+            except rpc.RemoteError as e:
+                raise exc.HTTPBadRequest(explanation=e.value)
+
+        elif 'restart' in body:
+            try:
+                self.restart(self._id)
+            except ValueError:
+                raise exc.HTTPBadRequest(
+                    explanation="Task id must be an integer")
+            except rpc.RemoteError as e:
+                raise exc.HTTPBadRequest(explanation=e.value)
+        else:
+            raise exc.HTTPBadRequest()
+
+    def stop_task(self, id):
+        ctx = pecan.request.context.to_dict()
+        try:
+            self.client.call(ctx, 'stop_task', id=int(id))
+        except ValueError as e:
+            LOG.exception(e)
+            raise
+        except rpc.RemoteError as e:
+            LOG.exception(e)
+            raise
+
+    def force_delete(self, id):
+        ctx = pecan.request.context.to_dict()
+        try:
+            self.client.call(ctx,
+                             'force_delete_recurrent_task',
+                             id=int(id))
+        except ValueError as e:
+            LOG.exception(e)
+            raise
+        except rpc.RemoteError as e:
+            LOG.exception(e)
+            raise
+
+    def restart(self, id):
+        ctx = pecan.request.context.to_dict()
+        try:
+            self.client.call(ctx,
+                             'restart_recurrent_task',
+                             id=int(id))
+        except ValueError as e:
+            LOG.exception(e)
+            raise
+        except rpc.RemoteError as e:
+            LOG.exception(e)
+            raise
+
+    def delete_task(self, id):
+        ctx = pecan.request.context.to_dict()
+        try:
+            self.client.call(ctx, 'delete_recurrent_task', id=int(id))
+        except ValueError as e:
+            LOG.exception(e)
+            raise
+        except rpc.RemoteError as e:
+            LOG.exception(e)
+            raise
+
+    @pecan.expose("json")
+    def delete(self):
+        """
+        Delete a task specified by its identifier. If the task is running,
+        it has to be stopped first.
+        :return:
+        :raises:
+            HTTPBadRequest: id not an integer or task can't be deleted
+        """
+        try:
+            self.delete_task(self._id)
+        except ValueError:
+            raise exc.HTTPBadRequest(explanation="Task id must be an integer")
+        except rpc.RemoteError as e:
+            raise exc.HTTPBadRequest(explanation=e.value)
+        except Exception as e:
+            LOG.exception(e)
+            raise
diff --git a/cerberus/client/__init__.py b/cerberus/client/__init__.py
new file mode 100644
index 0000000..73ca62b
--- /dev/null
+++ b/cerberus/client/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/cerberus/client/keystone_client.py b/cerberus/client/keystone_client.py
new file mode 100644
index 0000000..013e84c
--- /dev/null
+++ b/cerberus/client/keystone_client.py
@@ -0,0 +1,65 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import functools
+
+from keystoneclient.v2_0 import client as keystone_client_v2_0
+from oslo.config import cfg
+
+from cerberus.openstack.common import log
+
+
+cfg.CONF.import_group('service_credentials', 'cerberus.service')
+
+LOG = log.getLogger(__name__)
+
+
+def logged(func):
+
+    @functools.wraps(func)
+    def with_logging(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except Exception as e:
+            LOG.exception(e)
+            raise
+
+    return with_logging
+
+
+class Client(object):
+    """A client which gets information via python-keystoneclient."""
+
+    def __init__(self):
+        """Initialize a keystone client object."""
+        conf = cfg.CONF.service_credentials
+        self.keystone_client_v2_0 = keystone_client_v2_0.Client(
+            username=conf.os_username,
+            password=conf.os_password,
+            tenant_name=conf.os_tenant_name,
+            auth_url=conf.os_auth_url,
+            region_name=conf.os_region_name,
+        )
+
+    @logged
+    def user_detail_get(self, user):
+        """Returns details for a user."""
+        return self.keystone_client_v2_0.users.get(user)
+
+    @logged
+    def roles_for_user(self, user, tenant=None):
+        """Returns the roles of a user, optionally scoped to a tenant."""
+        return self.keystone_client_v2_0.roles.roles_for_user(user, tenant)
diff --git a/cerberus/client/neutron_client.py b/cerberus/client/neutron_client.py
new file mode 100644
index 0000000..cf3cf22
--- /dev/null
+++ b/cerberus/client/neutron_client.py
@@ -0,0 +1,109 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import functools
+
+from neutronclient.v2_0 import client as neutron_client
+from oslo.config import cfg
+
+from cerberus.openstack.common import log
+
+
+cfg.CONF.import_group('service_credentials', 'cerberus.service')
+
+LOG = log.getLogger(__name__)
+
+
+def logged(func):
+
+    @functools.wraps(func)
+    def with_logging(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except Exception as e:
+            LOG.exception(e)
+            raise
+
+    return with_logging
+
+
+class Client(object):
+    """A client which gets information via python-neutronclient."""
+
+    def __init__(self):
+        """Initialize a neutron client object."""
+        conf = cfg.CONF.service_credentials
+        self.neutronClient = neutron_client.Client(
+            username=conf.os_username,
+            password=conf.os_password,
+            tenant_name=conf.os_tenant_name,
+            auth_url=conf.os_auth_url,
+        )
+
+    @logged
+    def list_networks(self, tenant_id):
+        """Returns the list of networks of a given tenant"""
+        return self.neutronClient.list_networks(
+            tenant_id=tenant_id).get("networks", None)
+
+    @logged
+    def list_floatingips(self, tenant_id):
+        """Returns the list of floating ips of a given tenant"""
+        return self.neutronClient.list_floatingips(
+            tenant_id=tenant_id).get("floatingips", None)
+
+    @logged
+    def list_associated_floatingips(self, **params):
+        """Returns the list of associated floating ips of a given tenant"""
+        floating_ips = self.neutronClient.list_floatingips(
+            **params).get("floatingips", None)
+        # A floating IP is an IP address on an external network, which is
+        # associated with a specific port, and optionally a specific IP
+        # address, on a private OpenStack Networking network. Therefore a
+        # floating IP allows access to an instance on a private network from an
+        # external network. Floating IPs can only be defined on networks for
+        # which the attribute router:external (by the external network
+        # extension) has been set to True.
+        associated_floating_ips = []
+        for floating_ip in floating_ips:
+            if floating_ip.get("port_id") is not None:
+                associated_floating_ips.append(floating_ip)
+        return associated_floating_ips
+
+    @logged
+    def net_ips_get(self, network_id):
+        """
+        Return ip pools used in all subnets of a network
+        :param network_id: the identifier of the network
+        :return: list of pools
+        """
+        subnets = self.neutronClient.show_network(
+            network_id)["network"]["subnets"]
+        ips = []
+        for subnet in subnets:
+            ips.append(self.subnet_ips_get(subnet))
+        return ips
+
+    @logged
+    def get_net_of_subnet(self, subnet_id):
+        return self.neutronClient.show_subnet(
+            subnet_id)["subnet"]["network_id"]
+
+    @logged
+    def subnet_ips_get(self, subnet_id):
+        """Returns the ip pool of a subnet."""
+        return self.neutronClient.show_subnet(
+            subnet_id)["subnet"]["allocation_pools"]
diff --git a/cerberus/client/nova_client.py b/cerberus/client/nova_client.py
new file mode 100644
index 0000000..d7a5750
--- /dev/null
+++ b/cerberus/client/nova_client.py
@@ -0,0 +1,109 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import functools
+
+from novaclient.v3 import client as nova_client
+from oslo.config import cfg
+
+from cerberus.openstack.common import log
+
+
+OPTS = [
+    cfg.BoolOpt('nova_http_log_debug',
+                default=False,
+                help='Allow novaclient\'s debug log output.'),
+]
+
+SERVICE_OPTS = [
+    cfg.StrOpt('nova',
+               default='compute',
+               help='Nova service type.'),
+]
+
+cfg.CONF.register_opts(OPTS)
+cfg.CONF.register_opts(SERVICE_OPTS, group='service_types')
+# cfg.CONF.import_opt('http_timeout', 'cerberus.service')
+cfg.CONF.import_group('service_credentials', 'cerberus.service')
+LOG = log.getLogger(__name__)
+
+
+def logged(func):
+
+    @functools.wraps(func)
+    def with_logging(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except Exception as e:
+            LOG.exception(e)
+            raise
+
+    return with_logging
+
+
+class Client(object):
+    """A client which gets information via python-novaclient."""
+    def __init__(self, bypass_url=None, auth_token=None):
+        """Initialize a nova client object."""
+        conf = cfg.CONF.service_credentials
+        tenant = conf.os_tenant_id or conf.os_tenant_name
+        self.nova_client = nova_client.Client(
+            username=conf.os_username,
+            project_id=tenant,
+            auth_url=conf.os_auth_url,
+            password=conf.os_password,
+            region_name=conf.os_region_name,
+            endpoint_type=conf.os_endpoint_type,
+            service_type=cfg.CONF.service_types.nova,
+            bypass_url=bypass_url,
+            cacert=conf.os_cacert,
+            insecure=conf.insecure,
+            http_log_debug=cfg.CONF.nova_http_log_debug,
+            no_cache=True)
+
+    @logged
+    def instance_get_all(self):
+        """Returns list of all instances."""
+        search_opts = {'all_tenants': True}
+        return self.nova_client.servers.list(
+            detailed=True,
+            search_opts=search_opts)
+
+    @logged
+    def get_instance_details_from_floating_ip(self, ip):
+        """
+        Get the instance which is associated to the floating ip "ip"
+        :param ip: the floating ip that should belong to an instance
+        :return: the instance if ip is found, else None
+        """
+        instances = self.instance_get_all()
+
+        try:
+            for instance in instances:
+                # An instance can belong to many networks. An instance can
+                # have two ips in a network:
+                # at least a private ip and potentially a floating ip
+                addresses_in_networks = instance.addresses.values()
+                for addresses_in_network in addresses_in_networks:
+                    for address_in_network in addresses_in_network:
+                        if ((address_in_network.get('OS-EXT-IPS:type', None)
+                                == 'floating')
+                                and (address_in_network['addr'] == ip)):
+                            return instance
+        except Exception as e:
+            LOG.exception(e)
+            raise
+        return None
diff --git a/cerberus/cmd/__init__.py b/cerberus/cmd/__init__.py
new file mode 100644
index 0000000..73ca62b
--- /dev/null
+++ b/cerberus/cmd/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/cerberus/cmd/agent.py b/cerberus/cmd/agent.py
new file mode 100644
index 0000000..42084f4
--- /dev/null
+++ b/cerberus/cmd/agent.py
@@ -0,0 +1,46 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import eventlet
+import sys
+
+from oslo.config import cfg
+
+from cerberus.common import config
+from cerberus import manager
+from cerberus.openstack.common import log
+from cerberus.openstack.common import service
+
+
+eventlet.monkey_patch()
+
+LOG = log.getLogger(__name__)
+
+
+def main():
+
+    log.set_defaults(cfg.CONF.default_log_levels)
+    argv = sys.argv
+    config.parse_args(argv)
+    log.setup(cfg.CONF, 'cerberus')
+    launcher = service.ProcessLauncher()
+    c_manager = manager.CerberusManager()
+    launcher.launch_service(c_manager)
+    launcher.wait()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/cerberus/cmd/api.py b/cerberus/cmd/api.py
new file mode 100644
index 0000000..7b0dc45
--- /dev/null
+++ b/cerberus/cmd/api.py
@@ -0,0 +1,45 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+
+from oslo.config import cfg
+
+from cerberus.api import app
+from cerberus.common import config
+from cerberus.openstack.common import log
+
+
+CONF = cfg.CONF
+CONF.import_opt('auth_strategy', 'cerberus.api')
+LOG = log.getLogger(__name__)
+
+
+def main():
+    argv = sys.argv
+    config.parse_args(argv)
+    log.set_defaults(cfg.CONF.default_log_levels)
+    log.setup(cfg.CONF, 'cerberus')
+    server = app.build_server()
+
+    LOG.info("cerberus-api starting...")
+    try:
+        server.serve_forever()
+    except KeyboardInterrupt:
+        pass
+
+
+if __name__ == '__main__':
+    main()
diff --git a/cerberus/cmd/db_create.py b/cerberus/cmd/db_create.py
new file mode 100644
index 0000000..36b4b01
--- /dev/null
+++ b/cerberus/cmd/db_create.py
@@ -0,0 +1,43 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
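cerberus-api is the RPC-client side of the service pair above: it drives the task API that `CerberusManager` exposes (see cerberus/manager.py later in this change). A hedged sketch of such a call with oslo.messaging; the topic and server names mirror the `rpc_target` registered in `CerberusManager.start()`, and the plugin UUID is invented:

```python
from oslo.config import cfg
from oslo import messaging

transport = messaging.get_transport(cfg.CONF)
target = messaging.Target(topic='test_rpc', server='server1')
client = messaging.RPCClient(transport, target)

ctx = {}  # request context dict supplied by the caller
task_id = client.call(ctx, 'add_task', uuid='an-invented-plugin-uuid',
                      method_='process', task_type='recurrent',
                      task_period=60)
client.call(ctx, 'stop_task', id=task_id)
client.call(ctx, 'restart_recurrent_task', id=task_id)
```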
+# + +import sys + +from oslo.config import cfg +from sqlalchemy import create_engine + +from cerberus.common import config +from cerberus.db.sqlalchemy import models + + +def main(): + argv = sys.argv + config.parse_args(argv) + + engine = create_engine(cfg.CONF.database.connection) + + conn = engine.connect() + try: + conn.execute("CREATE DATABASE cerberus") + except Exception: + pass + + models.BASE.metadata.create_all(engine) + + conn.close() + +if __name__ == '__main__': + main() diff --git a/cerberus/common/__init__.py b/cerberus/common/__init__.py new file mode 100644 index 0000000..73ca62b --- /dev/null +++ b/cerberus/common/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/cerberus/common/config.py b/cerberus/common/config.py new file mode 100644 index 0000000..d0e790d --- /dev/null +++ b/cerberus/common/config.py @@ -0,0 +1,26 @@ +# +# Copyright (c) 2015 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from oslo.config import cfg + +from cerberus import version + + +def parse_args(argv, default_config_files=None): + cfg.CONF(argv[1:], + project='cerberus', + version=version.version_info.release_string(), + default_config_files=default_config_files) diff --git a/cerberus/common/context.py b/cerberus/common/context.py new file mode 100644 index 0000000..c14c422 --- /dev/null +++ b/cerberus/common/context.py @@ -0,0 +1,64 @@ +# -*- encoding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from cerberus.openstack.common import context + + +class RequestContext(context.RequestContext): + """Extends security contexts from the OpenStack common library.""" + + def __init__(self, auth_token=None, domain_id=None, domain_name=None, + user=None, tenant_id=None, tenant=None, is_admin=False, + is_public_api=False, read_only=False, show_deleted=False, + request_id=None, roles=None): + """Stores several additional request parameters: + + :param domain_id: The ID of the domain. 
+ :param domain_name: The name of the domain. + :param is_public_api: Specifies whether the request should be processed + without authentication. + + """ + self.tenant_id = tenant_id + self.is_public_api = is_public_api + self.domain_id = domain_id + self.domain_name = domain_name + self.roles = roles or [] + + super(RequestContext, self).__init__(auth_token=auth_token, + user=user, tenant=tenant, + is_admin=is_admin, + read_only=read_only, + show_deleted=show_deleted, + request_id=request_id) + + def to_dict(self): + return {'auth_token': self.auth_token, + 'user': self.user, + 'tenant_id': self.tenant_id, + 'tenant': self.tenant, + 'is_admin': self.is_admin, + 'read_only': self.read_only, + 'show_deleted': self.show_deleted, + 'request_id': self.request_id, + 'domain_id': self.domain_id, + 'roles': self.roles, + 'domain_name': self.domain_name, + 'is_public_api': self.is_public_api} + + @classmethod + def from_dict(cls, values): + values.pop('user', None) + values.pop('tenant', None) + return cls(**values) diff --git a/cerberus/common/errors.py b/cerberus/common/errors.py new file mode 100644 index 0000000..7462edf --- /dev/null +++ b/cerberus/common/errors.py @@ -0,0 +1,124 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
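A quick round trip through the `RequestContext` above: `to_dict` flattens the context to a plain dict (e.g. to ship over RPC), and `from_dict` rebuilds it after popping the `user` and `tenant` entries so they fall back to the constructor defaults. Values are illustrative:

```python
from cerberus.common import context

ctx = context.RequestContext(auth_token='a-token', tenant_id='a-tenant',
                             domain_name='Default', roles=['admin'])
payload = ctx.to_dict()  # plain dict, JSON/RPC friendly
restored = context.RequestContext.from_dict(payload)
assert restored.tenant_id == 'a-tenant'
assert restored.roles == ['admin']
```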
+# + +from cerberus.openstack.common._i18n import _ # noqa + + +class InvalidOperation(Exception): + + def __init__(self, description): + super(InvalidOperation, self).__init__(description) + + +class PluginNotFound(InvalidOperation): + + def __init__(self, uuid): + super(PluginNotFound, self).__init__("Plugin %s does not exist" + % str(uuid)) + + +class TaskPeriodNotInteger(InvalidOperation): + + def __init__(self): + super(TaskPeriodNotInteger, self).__init__( + "The period of the task must be provided as an integer" + ) + + +class TaskNotFound(InvalidOperation): + + def __init__(self, _id): + super(TaskNotFound, self).__init__( + _('Task %s does not exist') % _id + ) + + +class TaskDeletionNotAllowed(InvalidOperation): + def __init__(self, _id): + super(TaskDeletionNotAllowed, self).__init__( + _("Deletion of task %s is not allowed because either it " + "does not exist or it is not recurrent") % _id + ) + + +class TaskRestartNotAllowed(InvalidOperation): + def __init__(self, _id): + super(TaskRestartNotAllowed, self).__init__( + _("Restarting task %s is not allowed because either it " + "does not exist or it is not recurrent") % _id + ) + + +class TaskRestartNotPossible(InvalidOperation): + def __init__(self, _id): + super(TaskRestartNotPossible, self).__init__( + _("Restarting task %s is not possible because it is running") % _id + ) + + +class MethodNotString(InvalidOperation): + + def __init__(self): + super(MethodNotString, self).__init__( + "Method must be provided as a string" + ) + + +class MethodNotCallable(InvalidOperation): + + def __init__(self, method, name): + super(MethodNotCallable, self).__init__( + "Method named %s is not callable by plugin %s" + % (str(method), str(name)) + ) + + +class TaskObjectNotProvided(InvalidOperation): + + def __init__(self): + super(TaskObjectNotProvided, self).__init__( + "Task object not provided in request" + ) + + +class PluginIdNotProvided(InvalidOperation): + + def __init__(self): + super(PluginIdNotProvided, self).__init__( + "Plugin id not provided in request" + ) + + +class MethodNotProvided(InvalidOperation): + + def __init__(self): + super(MethodNotProvided, self).__init__( + "Method not provided in request" + ) + + +class PolicyEnforcementError(Exception): + + def __init__(self): + super(PolicyEnforcementError, self).__init__( + "Policy enforcement error" + ) + + +class DbError(Exception): + + def __init__(self, description): + super(DbError, self).__init__(description) diff --git a/cerberus/common/exception.py b/cerberus/common/exception.py new file mode 100644 index 0000000..1b4a8f7 --- /dev/null +++ b/cerberus/common/exception.py @@ -0,0 +1,154 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Cerberus base exception handling. + +Includes decorator for re-raising Nova-type exceptions. + +SHOULD include dedicated exception logging. 
+ +""" + +import functools +import gettext as t +import logging +import sys +import webob.exc + +from oslo.config import cfg + +from cerberus.common import safe_utils +from cerberus.openstack.common import excutils + + +_ = t.gettext + +LOG = logging.getLogger(__name__) + +exc_log_opts = [ + cfg.BoolOpt('fatal_exception_format_errors', + default=False, + help='Make exception message format errors fatal'), +] + +CONF = cfg.CONF +CONF.register_opts(exc_log_opts) + + +class ConvertedException(webob.exc.WSGIHTTPException): + def __init__(self, code=0, title="", explanation=""): + self.code = code + self.title = title + self.explanation = explanation + super(ConvertedException, self).__init__() + + +def _cleanse_dict(original): + """Strip all admin_password, new_pass, rescue_pass keys from a dict.""" + return dict((k, v) for k, v in original.iteritems() if "_pass" not in k) + + +def wrap_exception(notifier=None, get_notifier=None): + """This decorator wraps a method to catch any exceptions that may + get thrown. It logs the exception as well as optionally sending + it to the notification system. + """ + def inner(f): + def wrapped(self, context, *args, **kw): + # Don't store self or context in the payload, it now seems to + # contain confidential information. + try: + return f(self, context, *args, **kw) + except Exception as e: + with excutils.save_and_reraise_exception(): + if notifier or get_notifier: + payload = dict(exception=e) + call_dict = safe_utils.getcallargs(f, context, + *args, **kw) + cleansed = _cleanse_dict(call_dict) + payload.update({'args': cleansed}) + + # If f has multiple decorators, they must use + # functools.wraps to ensure the name is + # propagated. + event_type = f.__name__ + + (notifier or get_notifier()).error(context, + event_type, + payload) + + return functools.wraps(f)(wrapped) + return inner + + +class CerberusException(Exception): + """Base Cerberus Exception + + To correctly use this class, inherit from it and define + a 'msg_fmt' property. That msg_fmt will get printf'd + with the keyword arguments provided to the constructor. 
+
+    """
+    msg_fmt = _("An unknown exception occurred.")
+    code = 500
+    headers = {}
+    safe = False
+
+    def __init__(self, message=None, **kwargs):
+        self.kwargs = kwargs
+
+        if 'code' not in self.kwargs:
+            try:
+                self.kwargs['code'] = self.code
+            except AttributeError:
+                pass
+
+        if not message:
+            try:
+                message = self.msg_fmt % kwargs
+
+            except Exception:
+                exc_info = sys.exc_info()
+                # kwargs doesn't match a variable in the message
+                # log the issue and the kwargs
+                LOG.exception(_('Exception in string format operation'))
+                for name, value in kwargs.iteritems():
+                    LOG.error("%s: %s" % (name, value))  # noqa
+
+                if CONF.fatal_exception_format_errors:
+                    raise exc_info[0], exc_info[1], exc_info[2]
+                else:
+                    # at least get the core message out if something happened
+                    message = self.msg_fmt
+
+        super(CerberusException, self).__init__(message)
+
+    def format_message(self):
+        # NOTE(mrodden): use the first argument to the python Exception object
+        # which should be our full CerberusException message (see __init__)
+        return self.args[0]
+
+
+class AlertExists(CerberusException):
+    msg_fmt = _("Alert %(alert_id)s already exists.")
+
+
+class ReportExists(CerberusException):
+    msg_fmt = _("Report %(report_id)s already exists.")
+
+
+class PluginInfoExists(CerberusException):
+    msg_fmt = _("Plugin info %(id)s already exists.")
diff --git a/cerberus/common/json_encoders.py b/cerberus/common/json_encoders.py
new file mode 100644
index 0000000..51826bf
--- /dev/null
+++ b/cerberus/common/json_encoders.py
@@ -0,0 +1,26 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import datetime
+import json
+
+
+class DateTimeEncoder(json.JSONEncoder):
+    def default(self, obj):
+        """JSON serializer for objects not serializable by default json code."""
+        if isinstance(obj, datetime.datetime):
+            return obj.isoformat()
+        # Defer to the base class, which raises TypeError for unsupported
+        # types instead of silently encoding them as null.
+        return super(DateTimeEncoder, self).default(obj)
diff --git a/cerberus/common/policy.py b/cerberus/common/policy.py
new file mode 100644
index 0000000..61a66d7
--- /dev/null
+++ b/cerberus/common/policy.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2011 OpenStack Foundation
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
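To make the `msg_fmt` contract above concrete: a subclass declares a printf-style template, and the constructor interpolates the keyword arguments given at the raise site (falling back to the bare template when they do not match, unless `fatal_exception_format_errors` is set). A short sketch using the `AlertExists` class defined above, with an invented alert ID:

```python
from cerberus.common import exception

try:
    raise exception.AlertExists(alert_id='a-1')
except exception.CerberusException as e:
    print(e.format_message())  # "Alert a-1 already exists."
    print(e.kwargs['code'])    # 500, injected by __init__
```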
+ +"""Policy Engine For Cerberus.""" + +from oslo.config import cfg + +from cerberus.openstack.common import policy + +_ENFORCER = None +CONF = cfg.CONF + + +def init_enforcer(policy_file=None, rules=None, + default_rule=None, use_conf=True): + """Synchronously initializes the policy enforcer + + :param policy_file: Custom policy file to use, if none is specified, + `CONF.policy_file` will be used. + :param rules: Default dictionary / Rules to use. It will be + considered just in the first instantiation. + :param default_rule: Default rule to use, CONF.default_rule will + be used if none is specified. + :param use_conf: Whether to load rules from config file. + + """ + global _ENFORCER + + if _ENFORCER: + return + + _ENFORCER = policy.Enforcer(policy_file=policy_file, + rules=rules, + default_rule=default_rule, + use_conf=use_conf) + + +def get_enforcer(): + """Provides access to the single instance of Policy enforcer.""" + + if not _ENFORCER: + init_enforcer() + + return _ENFORCER + + +def enforce(rule, target, creds, do_raise=False, exc=None, *args, **kwargs): + """A shortcut for policy.Enforcer.enforce() + + Checks authorization of a rule against the target and credentials. + + """ + enforcer = get_enforcer() + return enforcer.enforce(rule, target, creds, do_raise=do_raise, + exc=exc, *args, **kwargs) diff --git a/cerberus/common/safe_utils.py b/cerberus/common/safe_utils.py new file mode 100644 index 0000000..1efb582 --- /dev/null +++ b/cerberus/common/safe_utils.py @@ -0,0 +1,70 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Utilities and helper functions that won't produce circular imports.""" + +import inspect +import six + +from cerberus.openstack.common import log + +LOG = log.getLogger(__name__) + + +def getcallargs(function, *args, **kwargs): + """This is a simplified inspect.getcallargs (2.7+). + + It should be replaced when python >= 2.7 is standard. + """ + keyed_args = {} + argnames, varargs, keywords, defaults = inspect.getargspec(function) + + keyed_args.update(kwargs) + + # NOTE(alaski) the implicit 'self' or 'cls' argument shows up in + # argnames but not in args or kwargs. Uses 'in' rather than '==' because + # some tests use 'self2'. + if 'self' in argnames[0] or 'cls' == argnames[0]: + # The function may not actually be a method or have im_self. + # Typically seen when it's stubbed with mox. 
+        if inspect.ismethod(function) and hasattr(function, 'im_self'):
+            keyed_args[argnames[0]] = function.im_self
+        else:
+            keyed_args[argnames[0]] = None
+
+    remaining_argnames = filter(lambda x: x not in keyed_args, argnames)
+    keyed_args.update(dict(zip(remaining_argnames, args)))
+
+    if defaults:
+        num_defaults = len(defaults)
+        for argname, value in zip(argnames[-num_defaults:], defaults):
+            if argname not in keyed_args:
+                keyed_args[argname] = value
+
+    return keyed_args
+
+
+def safe_rstrip(value, chars=None):
+    """Removes trailing characters from a string if that does not make it empty
+
+    :param value: A string value that will be stripped.
+    :param chars: Characters to remove.
+    :return: Stripped value.
+    """
+    if not isinstance(value, six.string_types):
+        LOG.warn(("Failed to remove trailing character. Returning original "
+                  "object. Supplied object is not a string: %s.") % value)
+        return value
+    return value.rstrip(chars) or value
diff --git a/cerberus/common/serialize.py b/cerberus/common/serialize.py
new file mode 100644
index 0000000..ef8075a
--- /dev/null
+++ b/cerberus/common/serialize.py
@@ -0,0 +1,110 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# dateutil's parser, not the (deprecated) stdlib 'parser' module: the date
+# serializers below rely on dateutil.parser.parse() for ISO 8601 strings.
+from dateutil import parser
+
+
+class JsonSerializer(object):
+    """A serializer that provides methods to serialize and deserialize JSON
+    dictionaries.
+
+    Note, one of the assumptions this serializer makes is that all objects
+    that it is used to deserialize have a constructor that can take all of
+    the attribute arguments. I.e. if you have an object with 3 attributes,
+    the constructor needs to take those three attributes as keyword
+    arguments.
+    """
+
+    __attributes__ = None
+    """The attributes to be serialized by the serializer.
+    The implementor needs to provide these."""
+
+    __required__ = None
+    """The attributes that are required when deserializing.
+    The implementor needs to provide these."""
+
+    __attribute_serializer__ = None
+    """The serializer to use for a specified attribute. If an attribute is not
+    included here, no special serializer will be used.
+    The implementor needs to provide these."""
+
+    __object_class__ = None
+    """The class that the deserializer should generate.
+    The implementor needs to provide these."""
+
+    serializers = dict(
+        date=dict(
+            serialize=lambda x: x.isoformat(),
+            deserialize=lambda x: parser.parse(x)
+        )
+    )
+
+    def deserialize(self, json, **kwargs):
+        """Deserialize a JSON dictionary and return a populated object.
+
+        This takes the JSON data, and deserializes it appropriately and then
+        calls the constructor of the object to be created with all of the
+        attributes.
+
+        Args:
+            json: The JSON dict with all of the data
+            **kwargs: Optional values that can be used as defaults if they are
+                not present in the JSON data
+        Returns:
+            The deserialized object.
+        Raises:
+            ValueError: If any of the required attributes are not present
+        """
+        d = dict()
+        for attr in self.__attributes__:
+            if attr in json:
+                val = json[attr]
+            elif attr in self.__required__:
+                try:
+                    val = kwargs[attr]
+                except KeyError:
+                    raise ValueError("{} must be set".format(attr))
+            else:
+                # Optional attribute absent from both the JSON data and the
+                # defaults: skip it rather than reusing a stale 'val' from a
+                # previous iteration.
+                continue
+
+            serializer = self.__attribute_serializer__.get(attr)
+            if serializer:
+                d[attr] = self.serializers[serializer]['deserialize'](val)
+            else:
+                d[attr] = val
+
+        return self.__object_class__(**d)
+
+    def serialize(self, obj):
+        """Serialize an object to a dictionary.
+
+        Take all of the attributes defined in self.__attributes__ and create
+        a dictionary containing those values.
+
+        Args:
+            obj: The object to serialize
+        Returns:
+            A dictionary containing all of the serialized data from the
+            object.
+        """
+        d = dict()
+        for attr in self.__attributes__:
+            val = getattr(obj, attr)
+            if val is None:
+                continue
+            serializer = self.__attribute_serializer__.get(attr)
+            if serializer:
+                d[attr] = self.serializers[serializer]['serialize'](val)
+            else:
+                d[attr] = val
+
+        return d
diff --git a/cerberus/db/__init__.py b/cerberus/db/__init__.py
new file mode 100644
index 0000000..2261e4c
--- /dev/null
+++ b/cerberus/db/__init__.py
@@ -0,0 +1,17 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from cerberus.db.api import *  # noqa
diff --git a/cerberus/db/api.py b/cerberus/db/api.py
new file mode 100644
index 0000000..f0fb92a
--- /dev/null
+++ b/cerberus/db/api.py
@@ -0,0 +1,115 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
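To make the `JsonSerializer` contract concrete, here is a hypothetical subclass (the `Note` class and its fields are invented); the real serializers in cerberus/db/sqlalchemy/models.py later in this change follow the same pattern:

```python
import datetime

from cerberus.common import serialize


class Note(object):
    def __init__(self, id=None, text=None, created_at=None):
        self.id = id
        self.text = text
        self.created_at = created_at


class NoteJsonSerializer(serialize.JsonSerializer):
    __attributes__ = ['id', 'text', 'created_at']
    __required__ = ['id']
    __attribute_serializer__ = dict(created_at='date')
    __object_class__ = Note


s = NoteJsonSerializer()
doc = s.serialize(Note(id=1, text='hi',
                       created_at=datetime.datetime(2014, 1, 1)))
# doc == {'id': 1, 'text': 'hi', 'created_at': '2014-01-01T00:00:00'}
note = s.deserialize(doc)
assert isinstance(note.created_at, datetime.datetime)
```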
+# + +from oslo.config import cfg + +from cerberus.openstack.common.db import api as db_api + + +CONF = cfg.CONF +CONF.import_opt('backend', 'cerberus.openstack.common.db.options', + group='database') +_BACKEND_MAPPING = {'sqlalchemy': 'cerberus.db.sqlalchemy.api'} + +IMPL = db_api.DBAPI(CONF.database.backend, backend_mapping=_BACKEND_MAPPING, + lazy=True) +''' JUNO: +IMPL = db_api.DBAPI.from_config(cfg.CONF, + backend_mapping=_BACKEND_MAPPING, + lazy=True) +''' + + +def get_instance(): + """Return a DB API instance.""" + return IMPL + + +def get_engine(): + return IMPL.get_engine() + + +def get_session(): + return IMPL.get_session() + + +def db_sync(engine, version=None): + """Migrate the database to `version` or the most recent version.""" + return IMPL.db_sync(engine, version=version) + + +def alert_create(values): + """Create an instance from the values dictionary.""" + return IMPL.alert_create(values) + + +def alert_get_all(): + """Get all alerts""" + return IMPL.alert_get_all() + + +def security_report_create(values): + """Create an instance from the values dictionary.""" + return IMPL.security_report_create(values) + + +def security_report_update_last_report_date(id, date): + """Create an instance from the values dictionary.""" + return IMPL.security_report_update_last_report_date(id, date) + + +def security_report_get_all(project_id=None): + """Get all alerts""" + return IMPL.security_report_get_all(project_id=project_id) + + +def security_report_get(id): + """Get all alerts""" + return IMPL.security_report_get(id) + + +def security_report_get_from_report_id(report_id): + """Get all alerts""" + return IMPL.security_report_get_from_report_id(report_id) + + +def plugins_info_get(): + """Get information about plugins stored in db""" + return IMPL.plugins_info_get() + + +def plugin_info_get_from_uuid(id): + """ + Get information about plugin stored in db + :param id: the uuid of the plugin + """ + return IMPL.plugin_info_get_from_uuid(id) + + +def plugin_version_update(id, version): + return IMPL.plugin_version_update(id, version) + + +def security_alarm_create(values): + return IMPL.security_alarm_create(values) + + +def security_alarm_get_all(): + return IMPL.security_alarm_get_all() + + +def security_alarm_get(id): + return IMPL.security_alarm_get(id) diff --git a/cerberus/db/sqlalchemy/__init__.py b/cerberus/db/sqlalchemy/__init__.py new file mode 100644 index 0000000..73ca62b --- /dev/null +++ b/cerberus/db/sqlalchemy/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/cerberus/db/sqlalchemy/api.py b/cerberus/db/sqlalchemy/api.py new file mode 100644 index 0000000..82626c1 --- /dev/null +++ b/cerberus/db/sqlalchemy/api.py @@ -0,0 +1,271 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sqlalchemy +import sys +import threading + +from oslo.config import cfg + +from cerberus.common import exception +from cerberus.db.sqlalchemy import migration +from cerberus.db.sqlalchemy import models +from cerberus.openstack.common.db import exception as db_exc +from cerberus.openstack.common.db.sqlalchemy import session as db_session +from cerberus.openstack.common import log + + +CONF = cfg.CONF + +LOG = log.getLogger(__name__) + + +_ENGINE_FACADE = None +_LOCK = threading.Lock() + + +_FACADE = None + + +def _create_facade_lazily(): + global _FACADE + if _FACADE is None: + _FACADE = db_session.EngineFacade( + CONF.database.connection, + **dict(CONF.database.iteritems()) + ) + return _FACADE + + +def get_engine(): + facade = _create_facade_lazily() + return facade.get_engine() + + +def get_session(**kwargs): + facade = _create_facade_lazily() + return facade.get_session(**kwargs) + + +def get_backend(): + """The backend is this module itself.""" + return sys.modules[__name__] + + +def model_query(model, *args, **kwargs): + """Query helper for simpler session usage. + :param session: if present, the session to use + """ + session = kwargs.get('session') or get_session() + query = session.query(model, *args) + return query + + +def _alert_get_all(session=None): + session = get_session() + + return model_query(models.Alert, read_deleted="no", + session=session).all() + + +def alert_create(values): + alert_ref = models.Alert() + alert_ref.update(values) + try: + alert_ref.save() + except db_exc.DBDuplicateEntry: + raise exception.AlertExists(id=values['id']) + return alert_ref + + +def alert_get_all(): + return _alert_get_all() + + +def _security_report_get_all(project_id=None): + session = get_session() + + try: + if project_id is None: + return model_query(models.SecurityReport, read_deleted="no", + session=session).all() + else: + return model_query(models.SecurityReport, read_deleted="no", + session=session).\ + filter(models.SecurityReport.project_id == project_id).all() + except Exception as e: + LOG.exception(e) + raise e + + +def _security_report_get(id): + session = get_session() + + return model_query(models.SecurityReport, read_deleted="no", + session=session).filter(models.SecurityReport. 
+ id == id).first() + + +def _security_report_get_from_report_id(report_id): + session = get_session() + return model_query(models.SecurityReport, read_deleted="no", + session=session).filter(models.SecurityReport.report_id + == report_id).first() + + +def security_report_create(values): + security_report_ref = models.SecurityReport() + security_report_ref.update(values) + try: + security_report_ref.save() + except sqlalchemy.exc.OperationalError as e: + LOG.exception(e) + raise db_exc.ColumnError + return security_report_ref + + +def security_report_update_last_report_date(id, date): + session = get_session() + report = model_query(models.SecurityReport, read_deleted="no", + session=session).filter(models.SecurityReport.id + == id).first() + report.last_report_date = date + try: + report.save(session) + except sqlalchemy.exc.OperationalError as e: + LOG.exception(e) + raise db_exc.ColumnError + + +def security_report_get_all(project_id=None): + return _security_report_get_all(project_id=project_id) + + +def security_report_get(id): + return _security_report_get(id) + + +def security_report_get_from_report_id(report_id): + return _security_report_get_from_report_id(report_id) + + +def _plugin_info_get(name): + session = get_session() + + return model_query(models.PluginInfo, + read_deleted="no", + session=session).filter(models.PluginInfo.name == + name).first() + + +def _plugin_info_get_from_uuid(id): + session = get_session() + + return model_query(models.PluginInfo, + read_deleted="no", + session=session).filter(models.PluginInfo.uuid == + id).first() + + +def _plugins_info_get(): + session = get_session() + + return model_query(models.PluginInfo, + read_deleted="no", + session=session).all() + + +def plugin_info_create(values): + plugin_info_ref = models.PluginInfo() + plugin_info_ref.update(values) + try: + plugin_info_ref.save() + except db_exc.DBDuplicateEntry: + raise exception.PluginInfoExists(id=values['id']) + return plugin_info_ref + + +def plugin_info_get(name): + return _plugin_info_get(name) + + +def plugin_info_get_from_uuid(id): + return _plugin_info_get_from_uuid(id) + + +def plugins_info_get(): + return _plugins_info_get() + + +def plugin_version_update(id, version): + session = get_session() + plugin = model_query(models.PluginInfo, read_deleted="no", + session=session).filter(models.PluginInfo.id == + id).first() + plugin.version = version + try: + plugin.save(session) + except sqlalchemy.exc.OperationalError as e: + LOG.exception(e) + raise db_exc.ColumnError + + +def db_sync(engine, version=None): + """Migrate the database to `version` or the most recent version.""" + return migration.db_sync(engine, version=version) + + +def db_version(engine): + """Display the current database version.""" + return migration.db_version(engine) + + +def _security_alarm_get_all(): + + session = get_session() + try: + return model_query(models.SecurityAlarm, read_deleted="no", + session=session).all() + except Exception as e: + LOG.exception(e) + raise e + + +def _security_alarm_get(id): + + session = get_session() + return model_query(models.SecurityAlarm, read_deleted="no", + session=session).filter(models.SecurityAlarm. 
+ id == id).first() + + +def security_alarm_create(values): + security_alarm_ref = models.SecurityAlarm() + security_alarm_ref.update(values) + try: + security_alarm_ref.save() + except sqlalchemy.exc.OperationalError as e: + LOG.exception(e) + raise db_exc.ColumnError + return security_alarm_ref + + +def security_alarm_get_all(): + return _security_alarm_get_all() + + +def security_alarm_get(id): + return _security_alarm_get(id) diff --git a/cerberus/db/sqlalchemy/migrate_repo/README b/cerberus/db/sqlalchemy/migrate_repo/README new file mode 100644 index 0000000..6218f8c --- /dev/null +++ b/cerberus/db/sqlalchemy/migrate_repo/README @@ -0,0 +1,4 @@ +This is a database migration repository. + +More information at +http://code.google.com/p/sqlalchemy-migrate/ diff --git a/cerberus/db/sqlalchemy/migrate_repo/__init__.py b/cerberus/db/sqlalchemy/migrate_repo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cerberus/db/sqlalchemy/migrate_repo/manage.py b/cerberus/db/sqlalchemy/migrate_repo/manage.py new file mode 100755 index 0000000..39fa389 --- /dev/null +++ b/cerberus/db/sqlalchemy/migrate_repo/manage.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python +from migrate.versioning.shell import main + +if __name__ == '__main__': + main(debug='False') diff --git a/cerberus/db/sqlalchemy/migrate_repo/migrate.cfg b/cerberus/db/sqlalchemy/migrate_repo/migrate.cfg new file mode 100644 index 0000000..b5a1683 --- /dev/null +++ b/cerberus/db/sqlalchemy/migrate_repo/migrate.cfg @@ -0,0 +1,25 @@ +[db_settings] +# Used to identify which repository this database is versioned under. +# You can use the name of your project. +repository_id=cerberus + +# The name of the database table used to track the schema version. +# This name shouldn't already be used by your project. +# If this is changed once a database is under version control, you'll need to +# change the table name in each database too. +version_table=migrate_version + +# When committing a change script, Migrate will attempt to generate the +# sql for all supported databases; normally, if one of them fails - probably +# because you don't have that database installed - it is ignored and the +# commit continues, perhaps ending successfully. +# Databases in this list MUST compile successfully during a commit, or the +# entire commit will fail. List the databases your application will actually +# be using to ensure your updates to that database work properly. +# This must be a list; example: ['postgres','sqlite'] +required_dbs=[] + +# When creating new change scripts, Migrate will stamp the new script with +# a version number. By default this is latest_version + 1. You can set this +# to 'true' to tell Migrate to use the UTC timestamp instead. +use_timestamp_numbering=False diff --git a/cerberus/db/sqlalchemy/migrate_repo/versions/015_initial.py b/cerberus/db/sqlalchemy/migrate_repo/versions/015_initial.py new file mode 100644 index 0000000..6b71db3 --- /dev/null +++ b/cerberus/db/sqlalchemy/migrate_repo/versions/015_initial.py @@ -0,0 +1,102 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import sqlalchemy
+
+
+def upgrade(migrate_engine):
+    meta = sqlalchemy.MetaData()
+    meta.bind = migrate_engine
+
+    alert = sqlalchemy.Table(
+        'alert', meta,
+        sqlalchemy.Column('id', sqlalchemy.Integer, primary_key=True,
+                          nullable=False),
+        sqlalchemy.Column('title', sqlalchemy.Text),
+        sqlalchemy.Column('status', sqlalchemy.Text),
+        sqlalchemy.Column('severity', sqlalchemy.Integer),
+        sqlalchemy.Column('acknowledged_at', sqlalchemy.DateTime),
+        sqlalchemy.Column('plugin_id', sqlalchemy.Text),
+        sqlalchemy.Column('description', sqlalchemy.Text),
+        sqlalchemy.Column('resource_id', sqlalchemy.Text),
+        sqlalchemy.Column('issue_link', sqlalchemy.Text),
+        sqlalchemy.Column('created_at', sqlalchemy.DateTime),
+        sqlalchemy.Column('updated_at', sqlalchemy.DateTime),
+        sqlalchemy.Column('deleted_at', sqlalchemy.DateTime),
+        sqlalchemy.Column('deleted', sqlalchemy.Integer),
+        mysql_engine='InnoDB',
+        mysql_charset='utf8'
+    )
+
+    plugin_info = sqlalchemy.Table(
+        'plugin_info', meta,
+        sqlalchemy.Column('id', sqlalchemy.Integer, primary_key=True,
+                          nullable=False),
+        sqlalchemy.Column('uuid', sqlalchemy.Text),
+        sqlalchemy.Column('name', sqlalchemy.Text),
+        # version and provider are free-form strings (cf. PluginInfo in
+        # cerberus/db/sqlalchemy/models.py), so store them as text.
+        sqlalchemy.Column('version', sqlalchemy.Text),
+        sqlalchemy.Column('provider', sqlalchemy.Text),
+        sqlalchemy.Column('type', sqlalchemy.Text),
+        sqlalchemy.Column('description', sqlalchemy.Text),
+        sqlalchemy.Column('tool_name', sqlalchemy.Text),
+        sqlalchemy.Column('created_at', sqlalchemy.DateTime),
+        sqlalchemy.Column('updated_at', sqlalchemy.DateTime),
+        sqlalchemy.Column('deleted_at', sqlalchemy.DateTime),
+        sqlalchemy.Column('deleted', sqlalchemy.Integer),
+        mysql_engine='InnoDB',
+        mysql_charset='utf8'
+    )
+
+    security_report = sqlalchemy.Table(
+        'security_report', meta,
+        sqlalchemy.Column('id', sqlalchemy.Integer, primary_key=True,
+                          nullable=False),
+        sqlalchemy.Column('plugin_id', sqlalchemy.Text),
+        sqlalchemy.Column('report_id', sqlalchemy.Text, unique=True),
+        sqlalchemy.Column('component_id', sqlalchemy.Text),
+        sqlalchemy.Column('component_type', sqlalchemy.Text),
+        sqlalchemy.Column('component_name', sqlalchemy.Text),
+        sqlalchemy.Column('project_id', sqlalchemy.Text),
+        sqlalchemy.Column('title', sqlalchemy.Text),
+        sqlalchemy.Column('description', sqlalchemy.Text),
+        sqlalchemy.Column('security_rating', sqlalchemy.Float),
+        sqlalchemy.Column('vulnerabilities', sqlalchemy.Text),
+        sqlalchemy.Column('vulnerabilities_number', sqlalchemy.Integer),
+        sqlalchemy.Column('last_report_date', sqlalchemy.DateTime),
+        sqlalchemy.Column('created_at', sqlalchemy.DateTime),
+        sqlalchemy.Column('updated_at', sqlalchemy.DateTime),
+        sqlalchemy.Column('deleted_at', sqlalchemy.DateTime),
+        sqlalchemy.Column('deleted', sqlalchemy.Integer),
+        mysql_engine='InnoDB',
+        mysql_charset='utf8'
+    )
+
+    tables = (
+        security_report,
+        alert,
+        plugin_info,
+    )
+
+    for index, table in enumerate(tables):
+        try:
+            table.create()
+        except Exception:
+            # If an error occurs, drop all tables created so far to return
+            # to the previously existing state.
+ meta.drop_all(tables=tables[:index]) + raise + + +def downgrade(migrate_engine): + raise NotImplementedError('Database downgrade not supported - ' + 'would drop all tables') diff --git a/cerberus/db/sqlalchemy/migrate_repo/versions/__init__.py b/cerberus/db/sqlalchemy/migrate_repo/versions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cerberus/db/sqlalchemy/migration.py b/cerberus/db/sqlalchemy/migration.py new file mode 100644 index 0000000..4a53f35 --- /dev/null +++ b/cerberus/db/sqlalchemy/migration.py @@ -0,0 +1,38 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os + +from cerberus.openstack.common.db.sqlalchemy import migration as oslo_migration + + +INIT_VERSION = 14 + + +def db_sync(engine, version=None): + path = os.path.join(os.path.abspath(os.path.dirname(__file__)), + 'migrate_repo') + return oslo_migration.db_sync(engine, path, version, + init_version=INIT_VERSION) + + +def db_version(engine): + path = os.path.join(os.path.abspath(os.path.dirname(__file__)), + 'migrate_repo') + return oslo_migration.db_version(engine, path, INIT_VERSION) + + +def db_version_control(engine, version=None): + path = os.path.join(os.path.abspath(os.path.dirname(__file__)), + 'migrate_repo') + return oslo_migration.db_version_control(engine, path, version) diff --git a/cerberus/db/sqlalchemy/models.py b/cerberus/db/sqlalchemy/models.py new file mode 100644 index 0000000..255d3c6 --- /dev/null +++ b/cerberus/db/sqlalchemy/models.py @@ -0,0 +1,171 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +SQLAlchemy models for cerberus data. 
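Since `INIT_VERSION = 14`, a fresh schema is stamped at version 14 and the `015_initial` script above is the first migration applied. A hedged sketch tying migration to the `cerberus.db.api` facade from earlier in this change (assumes `database.connection` is configured; all field values are invented):

```python
from cerberus import db  # re-exports cerberus.db.api
from cerberus.db.sqlalchemy import migration

engine = db.get_engine()
migration.db_sync(engine)            # applies 015_initial and later scripts
print(migration.db_version(engine))  # 15 once 015_initial has run

# The facade is then the entry point for queries:
report = db.security_report_create({
    'report_id': 'r-1',
    'plugin_id': 'p-1',
    'title': 'Sample report',
    'component_id': 'c-1',
})
print(db.security_report_get_from_report_id('r-1').title)
```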
+""" + +from sqlalchemy import Column, String, Integer, DateTime, Float, Text +from sqlalchemy.ext.declarative import declarative_base + +from oslo.config import cfg + +from cerberus.common import serialize +from cerberus.openstack.common.db.sqlalchemy import models + + +CONF = cfg.CONF +BASE = declarative_base() + + +class CerberusBase(models.SoftDeleteMixin, + models.TimestampMixin, + models.ModelBase): + + metadata = None + + def save(self, session=None): + from cerberus.db.sqlalchemy import api + + if session is None: + session = api.get_session() + + super(CerberusBase, self).save(session=session) + + +class Alert(BASE, CerberusBase): + """Security alert""" + + __tablename__ = 'alert' + __table_args__ = () + + id = Column(Integer, primary_key=True) + title = Column(String(255)) + status = Column(String(255)) + severity = Column(Integer) + acknowledged_at = Column(DateTime) + plugin_id = Column(String(255)) + description = Column(String(255)) + resource_id = Column(String(255)) + issue_link = Column(String(255)) + + +class AlertJsonSerializer(serialize.JsonSerializer): + """Alert serializer""" + + __attributes__ = ['id', 'title', 'description', 'status', 'severity', + 'created_at', 'deleted_at', 'updated_at', + 'acknowledged_at', 'plugin_id', 'resource_id', + 'issue_link', 'deleted'] + __required__ = ['id', 'title'] + __attribute_serializer__ = dict(created_at='date', deleted_at='date', + updated_at='date', acknowledged_at='date') + __object_class__ = Alert + + +class PluginInfo(BASE, CerberusBase): + """Plugin info""" + + __tablename__ = 'plugin_info' + __table_args__ = () + + id = Column(Integer, primary_key=True) + uuid = Column(String(255)) + name = Column(String(255)) + version = Column(String(255)) + provider = Column(String(255)) + type = Column(String(255)) + description = Column(String(255)) + tool_name = Column(String(255)) + + +class PluginInfoJsonSerializer(serialize.JsonSerializer): + """Plugin info serializer""" + + __attributes__ = ['id', 'uuid', 'name', 'version', 'provider', + 'type', 'description', 'tool_name'] + __required__ = ['id'] + __attribute_serializer__ = dict(created_at='date', deleted_at='date', + acknowledged_at='date') + __object_class__ = PluginInfo + + +class SecurityReport(BASE, CerberusBase): + """Security Report""" + + __tablename__ = 'security_report' + __table_args__ = () + + id = Column(Integer, primary_key=True) + plugin_id = Column(String(255)) + report_id = Column(String(255), unique=True) + component_id = Column(String(255)) + component_type = Column(String(255)) + component_name = Column(String(255)) + project_id = Column(String(255)) + title = Column(String(255)) + description = Column(String(255)) + security_rating = Column(Float) + vulnerabilities = Column(Text) + vulnerabilities_number = Column(Integer) + last_report_date = Column(DateTime) + + +class SecurityReportJsonSerializer(serialize.JsonSerializer): + """Security report serializer""" + + __attributes__ = ['id', 'title', 'description', 'plugin_id', 'report_id', + 'component_id', 'component_type', 'component_name', + 'project_id', 'security_rating', 'vulnerabilities', + 'vulnerabilities_number', 'last_report_date', 'deleted', + 'created_at', 'deleted_at', 'updated_at'] + __required__ = ['id', 'title', 'component_id'] + __attribute_serializer__ = dict(created_at='date', deleted_at='date', + acknowledged_at='date') + __object_class__ = SecurityReport + + +class SecurityAlarm(BASE, CerberusBase): + """Security alarm coming from Security Information and Event Manager + for example + """ + 
+ __tablename__ = 'security_alarm' + __table_args__ = () + + id = Column(Integer, primary_key=True) + plugin_id = Column(String(255)) + alarm_id = Column(String(255), unique=True) + timestamp = Column(DateTime) + status = Column(String(255)) + severity = Column(String(255)) + component_id = Column(String(255)) + summary = Column(String(255)) + description = Column(String(255)) + + +class SecurityAlarmJsonSerializer(serialize.JsonSerializer): + """Security report serializer""" + + __attributes__ = ['id', 'plugin_id', 'alarm_id', 'timestamp', 'status', + 'severity', 'component_id', 'summary', + 'project_id', 'security_rating', 'vulnerabilities', + 'description', 'deleted', 'created_at', 'deleted_at', + 'updated_at'] + __required__ = ['id', 'title'] + __attribute_serializer__ = dict(created_at='date', deleted_at='date', + acknowledged_at='date') + __object_class__ = SecurityAlarm diff --git a/cerberus/manager.py b/cerberus/manager.py new file mode 100644 index 0000000..897b052 --- /dev/null +++ b/cerberus/manager.py @@ -0,0 +1,426 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import uuid + +from oslo.config import cfg +from oslo import messaging +from stevedore import extension + +from cerberus.common import errors +from cerberus.db.sqlalchemy import api +from cerberus.openstack.common import log +from cerberus.openstack.common import loopingcall +from cerberus.openstack.common import service +from cerberus.openstack.common import threadgroup +from plugins import base + + +LOG = log.getLogger(__name__) + +OPTS = [ + + cfg.MultiStrOpt('messaging_urls', + default=[], + help="Messaging URLs to listen for notifications. 
" + "Example: transport://user:pass@host1:port" + "[,hostN:portN]/virtual_host " + "(DEFAULT/transport_url is used if empty)"), + cfg.ListOpt('notification-topics', default=['designate']), + cfg.ListOpt('cerberus_control_exchange', default=['cerberus']), +] + +cfg.CONF.register_opts(OPTS) + + +class CerberusManager(service.Service): + + TASK_NAMESPACE = 'cerberus.plugins' + + @classmethod + def _get_cerberus_manager(cls): + return extension.ExtensionManager( + namespace=cls.TASK_NAMESPACE, + invoke_on_load=True, + ) + + def __init__(self): + self.task_id = 0 + super(CerberusManager, self).__init__() + + def _register_plugin(self, extension): + # Record plugin in database + version = extension.entry_point.dist.version + plugin = extension.obj + db_plugin_info = api.plugin_info_get(plugin._name) + if db_plugin_info is None: + db_plugin_info = api.plugin_info_create({'name': plugin._name, + 'uuid': uuid.uuid4(), + 'version': version, + 'provider': + plugin.PROVIDER, + 'type': plugin.TYPE, + 'description': + plugin.DESCRIPTION, + 'tool_name': + plugin.TOOL_NAME + }) + else: + api.plugin_version_update(db_plugin_info.id, version) + + plugin._uuid = db_plugin_info.uuid + + def start(self): + + self.rpc_server = None + self.notification_server = None + super(CerberusManager, self).start() + + targets = [] + plugins = [] + self.cerberus_manager = self._get_cerberus_manager() + if not list(self.cerberus_manager): + LOG.warning('Failed to load any task handlers for %s', + self.TASK_NAMESPACE) + + for extension in self.cerberus_manager: + handler = extension.obj + LOG.debug('Plugin loaded: ' + extension.name) + LOG.debug(('Event types from %(name)s: %(type)s') + % {'name': extension.name, + 'type': ', '.join(handler._subscribedEvents)}) + + self._register_plugin(extension) + handler.register_manager(self) + targets.extend(handler.get_targets(cfg.CONF)) + plugins.append(handler) + + transport = messaging.get_transport(cfg.CONF) + + if transport: + rpc_target = messaging.Target(topic='test_rpc', server='server1') + self.rpc_server = messaging.get_rpc_server(transport, rpc_target, + [self], + executor='eventlet') + + self.notification_server = messaging.get_notification_listener( + transport, targets, plugins, executor='eventlet') + + LOG.info("RPC Server starting...") + self.rpc_server.start() + self.notification_server.start() + + def _get_unique_task(self, id): + + try: + unique_task = next( + thread for thread in self.tg.threads + if (thread.kw.get('task_id', None) == id)) + except StopIteration: + return None + return unique_task + + def _get_recurrent_task(self, id): + try: + recurrent_task = next(timer for timer in self.tg.timers if + (timer.kw.get('task_id', None) == id)) + except StopIteration: + return None + return recurrent_task + + def _add_unique_task(self, callback, *args, **kwargs): + """ + Add a simple task executing only once without delay + :param callback: Callable function to call when it's necessary + :param args: list of positional arguments to call the callback with + :param kwargs: dict of keyword arguments to call the callback with + :return the thread object that is created + """ + self.tg.add_thread(callback, *args, **kwargs) + + def _add_recurrent_task(self, callback, period, initial_delay=None, *args, + **kwargs): + """ + Add a recurrent task executing periodically with or without an initial + delay + :param callback: Callable function to call when it's necessary + :param period: the time in seconds during two executions of the task + :param initial_delay: the time after 
the first execution of the task + occurs + :param args: list of positional arguments to call the callback with + :param kwargs: dict of keyword arguments to call the callback with + """ + self.tg.add_timer(period, callback, initial_delay, *args, **kwargs) + + def get_plugins(self, ctx): + ''' + This method is designed to be called by an rpc client. + E.g: Cerberus-api + It is used to get information about plugins + ''' + json_plugins = [] + for extension in self.cerberus_manager: + plugin = extension.obj + res = json.dumps(plugin, cls=base.PluginEncoder) + json_plugins.append(res) + return json_plugins + + def _get_plugin_from_uuid(self, uuid): + for extension in self.cerberus_manager: + plugin = extension.obj + if (plugin._uuid == uuid): + return plugin + return None + + def get_plugin_from_uuid(self, ctx, uuid): + plugin = self._get_plugin_from_uuid(uuid) + if plugin is not None: + return json.dumps(plugin, cls=base.PluginEncoder) + else: + return None + + def add_task(self, ctx, uuid, method_, *args, **kwargs): + ''' + This method is designed to be called by an rpc client. + E.g: Cerberus-api + It is used to call a method of a plugin back + :param ctx: a request context dict supplied by client + :param uuid: the uuid of the plugin to call method onto + :param method_: the method to call back + :param task_type: the type of task to create + :param args: some extra arguments + :param kwargs: some extra keyworded arguments + ''' + self.task_id += 1 + kwargs['task_id'] = self.task_id + kwargs['plugin_id'] = uuid + task_type = kwargs.get('task_type', "unique") + plugin = self._get_plugin_from_uuid(uuid) + + if plugin is None: + raise errors.PluginNotFound(uuid) + + if (task_type.lower() == 'recurrent'): + try: + task_period = int(kwargs.get('task_period', None)) + except (TypeError, ValueError) as e: + LOG.exception(e) + raise errors.TaskPeriodNotInteger() + try: + self._add_recurrent_task(getattr(plugin, method_), + task_period, + *args, + **kwargs) + except TypeError as e: + LOG.exception(e) + raise errors.MethodNotString() + + except AttributeError as e: + LOG.exception(e) + raise errors.MethodNotCallable(method_, + plugin.__class__.__name__) + else: + try: + self._add_unique_task( + getattr(plugin, method_), + *args, + **kwargs) + except TypeError as e: + LOG.exception(e) + raise errors.MethodNotString() + except AttributeError as e: + LOG.exception(e) + raise errors.MethodNotCallable(method_, + plugin.__class__.__name__) + return self.task_id + + def _stop_recurrent_task(self, id): + """ + Stop the recurrent task but does not remove it from the ThreadGroup. 
+ I.e, the task still exists and could be restarted + Plus, if the task is running, wait for the end of its execution + :param id: the id of the recurrent task to stop + :return: + :raises: + StopIteration: the task is not found + """ + recurrent_task = self._get_recurrent_task(id) + if recurrent_task is None: + raise errors.TaskNotFound(id) + recurrent_task.stop() + + def _stop_unique_task(self, id): + unique_task = self._get_unique_task(id) + if unique_task is None: + raise errors.TaskNotFound(id) + unique_task.stop() + + def _stop_task(self, id): + task = self._get_task(id) + if isinstance(task, loopingcall.FixedIntervalLoopingCall): + try: + self._stop_recurrent_task(id) + except errors.InvalidOperation: + raise + elif isinstance(task, threadgroup.Thread): + try: + self._stop_unique_task(id) + except errors.InvalidOperation: + raise + + def stop_task(self, ctx, id): + try: + self._stop_task(id) + except errors.InvalidOperation: + raise + return id + + def _delete_recurrent_task(self, id): + """ + Stop the task and delete the recurrent task from the ThreadGroup. + If the task is running, wait for the end of its execution + :param id: the identifier of the task to delete + :return: + """ + recurrent_task = self._get_recurrent_task(id) + if (recurrent_task is None): + raise errors.TaskDeletionNotAllowed(id) + recurrent_task.stop() + try: + self.tg.timers.remove(recurrent_task) + except ValueError: + raise + + def delete_recurrent_task(self, ctx, id): + ''' + This method is designed to be called by an rpc client. + E.g: Cerberus-api + Stop the task and delete the recurrent task from the ThreadGroup. + If the task is running, wait for the end of its execution + :param ctx: a request context dict supplied by client + :param id: the identifier of the task to delete + ''' + try: + self._delete_recurrent_task(id) + except errors.InvalidOperation: + raise + return id + + def _force_delete_recurrent_task(self, id): + """ + Stop the task even if it is running and delete the recurrent task from + the ThreadGroup. + :param id: the identifier of the task to force delete + :return: + """ + recurrent_task = self._get_recurrent_task(id) + if (recurrent_task is None): + raise errors.TaskDeletionNotAllowed(id) + recurrent_task.stop() + recurrent_task.gt.kill() + try: + self.tg.timers.remove(recurrent_task) + except ValueError: + raise + + def force_delete_recurrent_task(self, ctx, id): + ''' + This method is designed to be called by an rpc client. + E.g: Cerberus-api + Stop the task even if it is running and delete the recurrent task + from the ThreadGroup. 
+        :param ctx: a request context dict supplied by client
+        :param id: the identifier of the task to force delete
+        '''
+        try:
+            self._force_delete_recurrent_task(id)
+        except errors.InvalidOperation:
+            raise
+        return id
+
+    def _get_tasks(self):
+        tasks = []
+        for timer in self.tg.timers:
+            tasks.append(timer)
+        for thread in self.tg.threads:
+            tasks.append(thread)
+        return tasks
+
+    def _get_task(self, id):
+        task = self._get_unique_task(id)
+        task_ = self._get_recurrent_task(id)
+        if task is None and task_ is None:
+            raise errors.TaskNotFound(id)
+        return task if task is not None else task_
+
+    def get_tasks(self, ctx):
+        tasks_ = []
+        tasks = self._get_tasks()
+        for task in tasks:
+            if isinstance(task, loopingcall.FixedIntervalLoopingCall):
+                tasks_.append(
+                    json.dumps(task,
+                               cls=base.FixedIntervalLoopingCallEncoder))
+            elif isinstance(task, threadgroup.Thread):
+                tasks_.append(
+                    json.dumps(task,
+                               cls=base.ThreadEncoder))
+        return tasks_
+
+    def get_task(self, ctx, id):
+        try:
+            task = self._get_task(id)
+        except errors.InvalidOperation:
+            raise
+        if isinstance(task, loopingcall.FixedIntervalLoopingCall):
+            return json.dumps(task,
+                              cls=base.FixedIntervalLoopingCallEncoder)
+        elif isinstance(task, threadgroup.Thread):
+            return json.dumps(task,
+                              cls=base.ThreadEncoder)
+
+    def _restart_recurrent_task(self, id):
+        """
+        Restart the recurrent task.
+        :param id: the identifier of the task to restart
+        :return:
+        """
+        recurrent_task = self._get_recurrent_task(id)
+        if recurrent_task is None:
+            raise errors.TaskRestartNotAllowed(str(id))
+        period = recurrent_task.kw.get("task_period", None)
+        if recurrent_task._running:
+            raise errors.TaskRestartNotPossible(str(id))
+        else:
+            try:
+                recurrent_task.start(int(period))
+            except ValueError as e:
+                LOG.exception(e)
+
+    def restart_recurrent_task(self, ctx, id):
+        '''
+        This method is designed to be called by an RPC client,
+        e.g., cerberus-api.
+        Restart a recurrent task after it has been stopped.
+        :param ctx: a request context dict supplied by client
+        :param id: the identifier of the task to restart
+        '''
+        try:
+            self._restart_recurrent_task(id)
+        except errors.InvalidOperation:
+            raise
+        return id
diff --git a/cerberus/openstack/__init__.py b/cerberus/openstack/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/cerberus/openstack/common/__init__.py b/cerberus/openstack/common/__init__.py
new file mode 100644
index 0000000..d1223ea
--- /dev/null
+++ b/cerberus/openstack/common/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import six
+
+
+six.add_move(six.MovedModule('mox', 'mox', 'mox3.mox'))
diff --git a/cerberus/openstack/common/_i18n.py b/cerberus/openstack/common/_i18n.py
new file mode 100644
index 0000000..4a6691f
--- /dev/null
+++ b/cerberus/openstack/common/_i18n.py
@@ -0,0 +1,40 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""oslo.i18n integration module. + +See http://docs.openstack.org/developer/oslo.i18n/usage.html + +""" + +import oslo.i18n + + +# NOTE(dhellmann): This reference to o-s-l-o will be replaced by the +# application name when this module is synced into the separate +# repository. It is OK to have more than one translation function +# using the same domain, since there will still only be one message +# catalog. +_translators = oslo.i18n.TranslatorFactory(domain='oslo') + +# The primary translation function using the well-known name "_" +_ = _translators.primary + +# Translators for log levels. +# +# The abbreviated names are meant to reflect the usual use of a short +# name like '_'. The "L" is for "log" and the other letter comes from +# the level. +_LI = _translators.log_info +_LW = _translators.log_warning +_LE = _translators.log_error +_LC = _translators.log_critical diff --git a/cerberus/openstack/common/apiclient/__init__.py b/cerberus/openstack/common/apiclient/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cerberus/openstack/common/apiclient/auth.py b/cerberus/openstack/common/apiclient/auth.py new file mode 100644 index 0000000..1763818 --- /dev/null +++ b/cerberus/openstack/common/apiclient/auth.py @@ -0,0 +1,221 @@ +# Copyright 2013 OpenStack Foundation +# Copyright 2013 Spanish National Research Council. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +# E0202: An attribute inherited from %s hide this method +# pylint: disable=E0202 + +import abc +import argparse +import os + +import six +from stevedore import extension + +from cerberus.openstack.common.apiclient import exceptions + + +_discovered_plugins = {} + + +def discover_auth_systems(): + """Discover the available auth-systems. + + This won't take into account the old style auth-systems. + """ + global _discovered_plugins + _discovered_plugins = {} + + def add_plugin(ext): + _discovered_plugins[ext.name] = ext.plugin + + ep_namespace = "cerberus.openstack.common.apiclient.auth" + mgr = extension.ExtensionManager(ep_namespace) + mgr.map(add_plugin) + + +def load_auth_system_opts(parser): + """Load options needed by the available auth-systems into a parser. + + This function will try to populate the parser with options from the + available plugins. 
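+
+    A minimal sketch, assuming the entry-point plugins were discovered
+    first::
+
+        import argparse
+
+        discover_auth_systems()
+        parser = argparse.ArgumentParser()
+        load_auth_system_opts(parser)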
+ """ + group = parser.add_argument_group("Common auth options") + BaseAuthPlugin.add_common_opts(group) + for name, auth_plugin in six.iteritems(_discovered_plugins): + group = parser.add_argument_group( + "Auth-system '%s' options" % name, + conflict_handler="resolve") + auth_plugin.add_opts(group) + + +def load_plugin(auth_system): + try: + plugin_class = _discovered_plugins[auth_system] + except KeyError: + raise exceptions.AuthSystemNotFound(auth_system) + return plugin_class(auth_system=auth_system) + + +def load_plugin_from_args(args): + """Load required plugin and populate it with options. + + Try to guess auth system if it is not specified. Systems are tried in + alphabetical order. + + :type args: argparse.Namespace + :raises: AuthPluginOptionsMissing + """ + auth_system = args.os_auth_system + if auth_system: + plugin = load_plugin(auth_system) + plugin.parse_opts(args) + plugin.sufficient_options() + return plugin + + for plugin_auth_system in sorted(six.iterkeys(_discovered_plugins)): + plugin_class = _discovered_plugins[plugin_auth_system] + plugin = plugin_class() + plugin.parse_opts(args) + try: + plugin.sufficient_options() + except exceptions.AuthPluginOptionsMissing: + continue + return plugin + raise exceptions.AuthPluginOptionsMissing(["auth_system"]) + + +@six.add_metaclass(abc.ABCMeta) +class BaseAuthPlugin(object): + """Base class for authentication plugins. + + An authentication plugin needs to override at least the authenticate + method to be a valid plugin. + """ + + auth_system = None + opt_names = [] + common_opt_names = [ + "auth_system", + "username", + "password", + "tenant_name", + "token", + "auth_url", + ] + + def __init__(self, auth_system=None, **kwargs): + self.auth_system = auth_system or self.auth_system + self.opts = dict((name, kwargs.get(name)) + for name in self.opt_names) + + @staticmethod + def _parser_add_opt(parser, opt): + """Add an option to parser in two variants. + + :param opt: option name (with underscores) + """ + dashed_opt = opt.replace("_", "-") + env_var = "OS_%s" % opt.upper() + arg_default = os.environ.get(env_var, "") + arg_help = "Defaults to env[%s]." % env_var + parser.add_argument( + "--os-%s" % dashed_opt, + metavar="<%s>" % dashed_opt, + default=arg_default, + help=arg_help) + parser.add_argument( + "--os_%s" % opt, + metavar="<%s>" % dashed_opt, + help=argparse.SUPPRESS) + + @classmethod + def add_opts(cls, parser): + """Populate the parser with the options for this plugin. + """ + for opt in cls.opt_names: + # use `BaseAuthPlugin.common_opt_names` since it is never + # changed in child classes + if opt not in BaseAuthPlugin.common_opt_names: + cls._parser_add_opt(parser, opt) + + @classmethod + def add_common_opts(cls, parser): + """Add options that are common for several plugins. + """ + for opt in cls.common_opt_names: + cls._parser_add_opt(parser, opt) + + @staticmethod + def get_opt(opt_name, args): + """Return option name and value. + + :param opt_name: name of the option, e.g., "username" + :param args: parsed arguments + """ + return (opt_name, getattr(args, "os_%s" % opt_name, None)) + + def parse_opts(self, args): + """Parse the actual auth-system options if any. + + This method is expected to populate the attribute `self.opts` with a + dict containing the options and values needed to make authentication. + """ + self.opts.update(dict(self.get_opt(opt_name, args) + for opt_name in self.opt_names)) + + def authenticate(self, http_client): + """Authenticate using plugin defined method. 
+ + The method usually analyses `self.opts` and performs + a request to authentication server. + + :param http_client: client object that needs authentication + :type http_client: HTTPClient + :raises: AuthorizationFailure + """ + self.sufficient_options() + self._do_authenticate(http_client) + + @abc.abstractmethod + def _do_authenticate(self, http_client): + """Protected method for authentication. + """ + + def sufficient_options(self): + """Check if all required options are present. + + :raises: AuthPluginOptionsMissing + """ + missing = [opt + for opt in self.opt_names + if not self.opts.get(opt)] + if missing: + raise exceptions.AuthPluginOptionsMissing(missing) + + @abc.abstractmethod + def token_and_endpoint(self, endpoint_type, service_type): + """Return token and endpoint. + + :param service_type: Service type of the endpoint + :type service_type: string + :param endpoint_type: Type of endpoint. + Possible values: public or publicURL, + internal or internalURL, + admin or adminURL + :type endpoint_type: string + :returns: tuple of token and endpoint strings + :raises: EndpointException + """ diff --git a/cerberus/openstack/common/apiclient/base.py b/cerberus/openstack/common/apiclient/base.py new file mode 100644 index 0000000..bd2a48e --- /dev/null +++ b/cerberus/openstack/common/apiclient/base.py @@ -0,0 +1,500 @@ +# Copyright 2010 Jacob Kaplan-Moss +# Copyright 2011 OpenStack Foundation +# Copyright 2012 Grid Dynamics +# Copyright 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Base utilities to build API operation managers and objects on top of. +""" + +# E1102: %s is not callable +# pylint: disable=E1102 + +import abc +import copy + +import six +from six.moves.urllib import parse + +from cerberus.openstack.common.apiclient import exceptions +from cerberus.openstack.common import strutils + + +def getid(obj): + """Return id if argument is a Resource. + + Abstracts the common pattern of allowing both an object or an object's ID + (UUID) as a parameter when dealing with relationships. + """ + try: + if obj.uuid: + return obj.uuid + except AttributeError: + pass + try: + return obj.id + except AttributeError: + return obj + + +# TODO(aababilov): call run_hooks() in HookableMixin's child classes +class HookableMixin(object): + """Mixin so classes can register and run hooks.""" + _hooks_map = {} + + @classmethod + def add_hook(cls, hook_type, hook_func): + """Add a new hook of specified type. + + :param cls: class that registers hooks + :param hook_type: hook type, e.g., '__pre_parse_args__' + :param hook_func: hook function + """ + if hook_type not in cls._hooks_map: + cls._hooks_map[hook_type] = [] + + cls._hooks_map[hook_type].append(hook_func) + + @classmethod + def run_hooks(cls, hook_type, *args, **kwargs): + """Run all hooks of specified type. 
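+
+        A hypothetical registration/run cycle::
+
+            class MyClient(HookableMixin):
+                pass
+
+            MyClient.add_hook('__pre_parse_args__', lambda: None)
+            MyClient.run_hooks('__pre_parse_args__')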
+ + :param cls: class that registers hooks + :param hook_type: hook type, e.g., '__pre_parse_args__' + :param **args: args to be passed to every hook function + :param **kwargs: kwargs to be passed to every hook function + """ + hook_funcs = cls._hooks_map.get(hook_type) or [] + for hook_func in hook_funcs: + hook_func(*args, **kwargs) + + +class BaseManager(HookableMixin): + """Basic manager type providing common operations. + + Managers interact with a particular type of API (servers, flavors, images, + etc.) and provide CRUD operations for them. + """ + resource_class = None + + def __init__(self, client): + """Initializes BaseManager with `client`. + + :param client: instance of BaseClient descendant for HTTP requests + """ + super(BaseManager, self).__init__() + self.client = client + + def _list(self, url, response_key, obj_class=None, json=None): + """List the collection. + + :param url: a partial URL, e.g., '/servers' + :param response_key: the key to be looked up in response dictionary, + e.g., 'servers' + :param obj_class: class for constructing the returned objects + (self.resource_class will be used by default) + :param json: data that will be encoded as JSON and passed in POST + request (GET will be sent by default) + """ + if json: + body = self.client.post(url, json=json).json() + else: + body = self.client.get(url).json() + + if obj_class is None: + obj_class = self.resource_class + + data = body[response_key] + # NOTE(ja): keystone returns values as list as {'values': [ ... ]} + # unlike other services which just return the list... + try: + data = data['values'] + except (KeyError, TypeError): + pass + + return [obj_class(self, res, loaded=True) for res in data if res] + + def _get(self, url, response_key): + """Get an object from collection. + + :param url: a partial URL, e.g., '/servers' + :param response_key: the key to be looked up in response dictionary, + e.g., 'server' + """ + body = self.client.get(url).json() + return self.resource_class(self, body[response_key], loaded=True) + + def _head(self, url): + """Retrieve request headers for an object. + + :param url: a partial URL, e.g., '/servers' + """ + resp = self.client.head(url) + return resp.status_code == 204 + + def _post(self, url, json, response_key, return_raw=False): + """Create an object. + + :param url: a partial URL, e.g., '/servers' + :param json: data that will be encoded as JSON and passed in POST + request (GET will be sent by default) + :param response_key: the key to be looked up in response dictionary, + e.g., 'servers' + :param return_raw: flag to force returning raw JSON instead of + Python object of self.resource_class + """ + body = self.client.post(url, json=json).json() + if return_raw: + return body[response_key] + return self.resource_class(self, body[response_key]) + + def _put(self, url, json=None, response_key=None): + """Update an object with PUT method. + + :param url: a partial URL, e.g., '/servers' + :param json: data that will be encoded as JSON and passed in POST + request (GET will be sent by default) + :param response_key: the key to be looked up in response dictionary, + e.g., 'servers' + """ + resp = self.client.put(url, json=json) + # PUT requests may not return a body + if resp.content: + body = resp.json() + if response_key is not None: + return self.resource_class(self, body[response_key]) + else: + return self.resource_class(self, body) + + def _patch(self, url, json=None, response_key=None): + """Update an object with PATCH method. 
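+
+        For example (URL and payload are illustrative)::
+
+            self._patch('/servers/1234',
+                        json={'server': {'name': 'renamed'}},
+                        response_key='server')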
+ + :param url: a partial URL, e.g., '/servers' + :param json: data that will be encoded as JSON and passed in POST + request (GET will be sent by default) + :param response_key: the key to be looked up in response dictionary, + e.g., 'servers' + """ + body = self.client.patch(url, json=json).json() + if response_key is not None: + return self.resource_class(self, body[response_key]) + else: + return self.resource_class(self, body) + + def _delete(self, url): + """Delete an object. + + :param url: a partial URL, e.g., '/servers/my-server' + """ + return self.client.delete(url) + + +@six.add_metaclass(abc.ABCMeta) +class ManagerWithFind(BaseManager): + """Manager with additional `find()`/`findall()` methods.""" + + @abc.abstractmethod + def list(self): + pass + + def find(self, **kwargs): + """Find a single item with attributes matching ``**kwargs``. + + This isn't very efficient: it loads the entire list then filters on + the Python side. + """ + matches = self.findall(**kwargs) + num_matches = len(matches) + if num_matches == 0: + msg = "No %s matching %s." % (self.resource_class.__name__, kwargs) + raise exceptions.NotFound(msg) + elif num_matches > 1: + raise exceptions.NoUniqueMatch() + else: + return matches[0] + + def findall(self, **kwargs): + """Find all items with attributes matching ``**kwargs``. + + This isn't very efficient: it loads the entire list then filters on + the Python side. + """ + found = [] + searches = kwargs.items() + + for obj in self.list(): + try: + if all(getattr(obj, attr) == value + for (attr, value) in searches): + found.append(obj) + except AttributeError: + continue + + return found + + +class CrudManager(BaseManager): + """Base manager class for manipulating entities. + + Children of this class are expected to define a `collection_key` and `key`. + + - `collection_key`: Usually a plural noun by convention (e.g. `entities`); + used to refer collections in both URL's (e.g. `/v3/entities`) and JSON + objects containing a list of member resources (e.g. `{'entities': [{}, + {}, {}]}`). + - `key`: Usually a singular noun by convention (e.g. `entity`); used to + refer to an individual member of the collection. + + """ + collection_key = None + key = None + + def build_url(self, base_url=None, **kwargs): + """Builds a resource URL for the given kwargs. + + Given an example collection where `collection_key = 'entities'` and + `key = 'entity'`, the following URL's could be generated. + + By default, the URL will represent a collection of entities, e.g.:: + + /entities + + If kwargs contains an `entity_id`, then the URL will represent a + specific member, e.g.:: + + /entities/{entity_id} + + :param base_url: if provided, the generated URL will be appended to it + """ + url = base_url if base_url is not None else '' + + url += '/%s' % self.collection_key + + # do we have a specific entity? 
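+        # e.g., with key='entity' and kwargs={'entity_id': 'abc'} this
+        # yields '/entities/abc'; without an id the collection URL
+        # ('/entities') is returned unchanged.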
+ entity_id = kwargs.get('%s_id' % self.key) + if entity_id is not None: + url += '/%s' % entity_id + + return url + + def _filter_kwargs(self, kwargs): + """Drop null values and handle ids.""" + for key, ref in six.iteritems(kwargs.copy()): + if ref is None: + kwargs.pop(key) + else: + if isinstance(ref, Resource): + kwargs.pop(key) + kwargs['%s_id' % key] = getid(ref) + return kwargs + + def create(self, **kwargs): + kwargs = self._filter_kwargs(kwargs) + return self._post( + self.build_url(**kwargs), + {self.key: kwargs}, + self.key) + + def get(self, **kwargs): + kwargs = self._filter_kwargs(kwargs) + return self._get( + self.build_url(**kwargs), + self.key) + + def head(self, **kwargs): + kwargs = self._filter_kwargs(kwargs) + return self._head(self.build_url(**kwargs)) + + def list(self, base_url=None, **kwargs): + """List the collection. + + :param base_url: if provided, the generated URL will be appended to it + """ + kwargs = self._filter_kwargs(kwargs) + + return self._list( + '%(base_url)s%(query)s' % { + 'base_url': self.build_url(base_url=base_url, **kwargs), + 'query': '?%s' % parse.urlencode(kwargs) if kwargs else '', + }, + self.collection_key) + + def put(self, base_url=None, **kwargs): + """Update an element. + + :param base_url: if provided, the generated URL will be appended to it + """ + kwargs = self._filter_kwargs(kwargs) + + return self._put(self.build_url(base_url=base_url, **kwargs)) + + def update(self, **kwargs): + kwargs = self._filter_kwargs(kwargs) + params = kwargs.copy() + params.pop('%s_id' % self.key) + + return self._patch( + self.build_url(**kwargs), + {self.key: params}, + self.key) + + def delete(self, **kwargs): + kwargs = self._filter_kwargs(kwargs) + + return self._delete( + self.build_url(**kwargs)) + + def find(self, base_url=None, **kwargs): + """Find a single item with attributes matching ``**kwargs``. + + :param base_url: if provided, the generated URL will be appended to it + """ + kwargs = self._filter_kwargs(kwargs) + + rl = self._list( + '%(base_url)s%(query)s' % { + 'base_url': self.build_url(base_url=base_url, **kwargs), + 'query': '?%s' % parse.urlencode(kwargs) if kwargs else '', + }, + self.collection_key) + num = len(rl) + + if num == 0: + msg = "No %s matching %s." % (self.resource_class.__name__, kwargs) + raise exceptions.NotFound(404, msg) + elif num > 1: + raise exceptions.NoUniqueMatch + else: + return rl[0] + + +class Extension(HookableMixin): + """Extension descriptor.""" + + SUPPORTED_HOOKS = ('__pre_parse_args__', '__post_parse_args__') + manager_class = None + + def __init__(self, name, module): + super(Extension, self).__init__() + self.name = name + self.module = module + self._parse_extension_module() + + def _parse_extension_module(self): + self.manager_class = None + for attr_name, attr_value in self.module.__dict__.items(): + if attr_name in self.SUPPORTED_HOOKS: + self.add_hook(attr_name, attr_value) + else: + try: + if issubclass(attr_value, BaseManager): + self.manager_class = attr_value + except TypeError: + pass + + def __repr__(self): + return "" % self.name + + +class Resource(object): + """Base class for OpenStack resources (tenant, user, etc.). + + This is pretty much just a bag for attributes. + """ + + HUMAN_ID = False + NAME_ATTR = 'name' + + def __init__(self, manager, info, loaded=False): + """Populate and bind to a manager. 
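+
+        A bare-bones example (attribute values are illustrative)::
+
+            server = Resource(manager, {'id': '1', 'name': 'srv1'})
+            server.name   # 'srv1', populated by _add_details()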
+ + :param manager: BaseManager object + :param info: dictionary representing resource attributes + :param loaded: prevent lazy-loading if set to True + """ + self.manager = manager + self._info = info + self._add_details(info) + self._loaded = loaded + + def __repr__(self): + reprkeys = sorted(k + for k in self.__dict__.keys() + if k[0] != '_' and k != 'manager') + info = ", ".join("%s=%s" % (k, getattr(self, k)) for k in reprkeys) + return "<%s %s>" % (self.__class__.__name__, info) + + @property + def human_id(self): + """Human-readable ID which can be used for bash completion. + """ + if self.NAME_ATTR in self.__dict__ and self.HUMAN_ID: + return strutils.to_slug(getattr(self, self.NAME_ATTR)) + return None + + def _add_details(self, info): + for (k, v) in six.iteritems(info): + try: + setattr(self, k, v) + self._info[k] = v + except AttributeError: + # In this case we already defined the attribute on the class + pass + + def __getattr__(self, k): + if k not in self.__dict__: + #NOTE(bcwaldon): disallow lazy-loading if already loaded once + if not self.is_loaded(): + self.get() + return self.__getattr__(k) + + raise AttributeError(k) + else: + return self.__dict__[k] + + def get(self): + """Support for lazy loading details. + + Some clients, such as novaclient have the option to lazy load the + details, details which can be loaded with this function. + """ + # set_loaded() first ... so if we have to bail, we know we tried. + self.set_loaded(True) + if not hasattr(self.manager, 'get'): + return + + new = self.manager.get(self.id) + if new: + self._add_details(new._info) + + def __eq__(self, other): + if not isinstance(other, Resource): + return NotImplemented + # two resources of different types are not equal + if not isinstance(other, self.__class__): + return False + if hasattr(self, 'id') and hasattr(other, 'id'): + return self.id == other.id + return self._info == other._info + + def is_loaded(self): + return self._loaded + + def set_loaded(self, val): + self._loaded = val + + def to_dict(self): + return copy.deepcopy(self._info) diff --git a/cerberus/openstack/common/apiclient/client.py b/cerberus/openstack/common/apiclient/client.py new file mode 100644 index 0000000..5bc0c7d --- /dev/null +++ b/cerberus/openstack/common/apiclient/client.py @@ -0,0 +1,358 @@ +# Copyright 2010 Jacob Kaplan-Moss +# Copyright 2011 OpenStack Foundation +# Copyright 2011 Piston Cloud Computing, Inc. +# Copyright 2013 Alessio Ababilov +# Copyright 2013 Grid Dynamics +# Copyright 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +OpenStack Client interface. Handles the REST calls and responses. 
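+
+A minimal wiring sketch (the auth plugin instance is assumed to exist)::
+
+    http = HTTPClient(auth_plugin=my_auth_plugin, debug=True)
+    resp = http.request("GET", "http://example.com/v1/resources")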
+""" + +# E0202: An attribute inherited from %s hide this method +# pylint: disable=E0202 + +import logging +import time + +try: + import simplejson as json +except ImportError: + import json + +import requests + +from cerberus.openstack.common.apiclient import exceptions +from cerberus.openstack.common import importutils + + +_logger = logging.getLogger(__name__) + + +class HTTPClient(object): + """This client handles sending HTTP requests to OpenStack servers. + + Features: + - share authentication information between several clients to different + services (e.g., for compute and image clients); + - reissue authentication request for expired tokens; + - encode/decode JSON bodies; + - raise exceptions on HTTP errors; + - pluggable authentication; + - store authentication information in a keyring; + - store time spent for requests; + - register clients for particular services, so one can use + `http_client.identity` or `http_client.compute`; + - log requests and responses in a format that is easy to copy-and-paste + into terminal and send the same request with curl. + """ + + user_agent = "cerberus.openstack.common.apiclient" + + def __init__(self, + auth_plugin, + region_name=None, + endpoint_type="publicURL", + original_ip=None, + verify=True, + cert=None, + timeout=None, + timings=False, + keyring_saver=None, + debug=False, + user_agent=None, + http=None): + self.auth_plugin = auth_plugin + + self.endpoint_type = endpoint_type + self.region_name = region_name + + self.original_ip = original_ip + self.timeout = timeout + self.verify = verify + self.cert = cert + + self.keyring_saver = keyring_saver + self.debug = debug + self.user_agent = user_agent or self.user_agent + + self.times = [] # [("item", starttime, endtime), ...] + self.timings = timings + + # requests within the same session can reuse TCP connections from pool + self.http = http or requests.Session() + + self.cached_token = None + + def _http_log_req(self, method, url, kwargs): + if not self.debug: + return + + string_parts = [ + "curl -i", + "-X '%s'" % method, + "'%s'" % url, + ] + + for element in kwargs['headers']: + header = "-H '%s: %s'" % (element, kwargs['headers'][element]) + string_parts.append(header) + + _logger.debug("REQ: %s" % " ".join(string_parts)) + if 'data' in kwargs: + _logger.debug("REQ BODY: %s\n" % (kwargs['data'])) + + def _http_log_resp(self, resp): + if not self.debug: + return + _logger.debug( + "RESP: [%s] %s\n", + resp.status_code, + resp.headers) + if resp._content_consumed: + _logger.debug( + "RESP BODY: %s\n", + resp.text) + + def serialize(self, kwargs): + if kwargs.get('json') is not None: + kwargs['headers']['Content-Type'] = 'application/json' + kwargs['data'] = json.dumps(kwargs['json']) + try: + del kwargs['json'] + except KeyError: + pass + + def get_timings(self): + return self.times + + def reset_timings(self): + self.times = [] + + def request(self, method, url, **kwargs): + """Send an http request with the specified characteristics. + + Wrapper around `requests.Session.request` to handle tasks such as + setting headers, JSON encoding/decoding, and error handling. 
+
+        :param method: method of HTTP request
+        :param url: URL of HTTP request
+        :param kwargs: any other parameter that can be passed to
+            `requests.Session.request` (such as `headers`) or `json`
+            that will be encoded as JSON and used as `data` argument
+        """
+        kwargs.setdefault("headers", {})
+        kwargs["headers"]["User-Agent"] = self.user_agent
+        if self.original_ip:
+            kwargs["headers"]["Forwarded"] = "for=%s;by=%s" % (
+                self.original_ip, self.user_agent)
+        if self.timeout is not None:
+            kwargs.setdefault("timeout", self.timeout)
+        kwargs.setdefault("verify", self.verify)
+        if self.cert is not None:
+            kwargs.setdefault("cert", self.cert)
+        self.serialize(kwargs)
+
+        self._http_log_req(method, url, kwargs)
+        if self.timings:
+            start_time = time.time()
+        resp = self.http.request(method, url, **kwargs)
+        if self.timings:
+            self.times.append(("%s %s" % (method, url),
+                               start_time, time.time()))
+        self._http_log_resp(resp)
+
+        if resp.status_code >= 400:
+            _logger.debug(
+                "Request returned failure status: %s",
+                resp.status_code)
+            raise exceptions.from_response(resp, method, url)
+
+        return resp
+
+    @staticmethod
+    def concat_url(endpoint, url):
+        """Concatenate endpoint and final URL.
+
+        E.g., "http://keystone/v2.0/" and "/tokens" are concatenated to
+        "http://keystone/v2.0/tokens".
+
+        :param endpoint: the base URL
+        :param url: the final URL
+        """
+        return "%s/%s" % (endpoint.rstrip("/"), url.strip("/"))
+
+    def client_request(self, client, method, url, **kwargs):
+        """Send an HTTP request using `client`'s endpoint and specified `url`.
+
+        If the request is rejected as unauthorized (possibly because the
+        token is expired), issue one authentication attempt and send the
+        request once again.
+
+        :param client: instance of BaseClient descendant
+        :param method: method of HTTP request
+        :param url: URL of HTTP request
+        :param kwargs: any other parameter that can be passed to
+            `HTTPClient.request`
+        """
+
+        filter_args = {
+            "endpoint_type": client.endpoint_type or self.endpoint_type,
+            "service_type": client.service_type,
+        }
+        token, endpoint = (self.cached_token, client.cached_endpoint)
+        just_authenticated = False
+        if not (token and endpoint):
+            try:
+                token, endpoint = self.auth_plugin.token_and_endpoint(
+                    **filter_args)
+            except exceptions.EndpointException:
+                pass
+            if not (token and endpoint):
+                self.authenticate()
+                just_authenticated = True
+                token, endpoint = self.auth_plugin.token_and_endpoint(
+                    **filter_args)
+                if not (token and endpoint):
+                    raise exceptions.AuthorizationFailure(
+                        "Cannot find endpoint or token for request")
+
+        old_token_endpoint = (token, endpoint)
+        kwargs.setdefault("headers", {})["X-Auth-Token"] = token
+        self.cached_token = token
+        client.cached_endpoint = endpoint
+        # Perform the request once. If we get Unauthorized, then it
+        # might be because the auth token expired, so try to
+        # re-authenticate and try again. If it still fails, bail.
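+        # Concretely (an assumed sequence): a 401 response triggers
+        # authenticate(), a fresh token/endpoint pair is fetched, the
+        # X-Auth-Token header is replaced, and the request is replayed
+        # exactly once before the Unauthorized error is re-raised.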
+ try: + return self.request( + method, self.concat_url(endpoint, url), **kwargs) + except exceptions.Unauthorized as unauth_ex: + if just_authenticated: + raise + self.cached_token = None + client.cached_endpoint = None + self.authenticate() + try: + token, endpoint = self.auth_plugin.token_and_endpoint( + **filter_args) + except exceptions.EndpointException: + raise unauth_ex + if (not (token and endpoint) or + old_token_endpoint == (token, endpoint)): + raise unauth_ex + self.cached_token = token + client.cached_endpoint = endpoint + kwargs["headers"]["X-Auth-Token"] = token + return self.request( + method, self.concat_url(endpoint, url), **kwargs) + + def add_client(self, base_client_instance): + """Add a new instance of :class:`BaseClient` descendant. + + `self` will store a reference to `base_client_instance`. + + Example: + + >>> def test_clients(): + ... from keystoneclient.auth import keystone + ... from openstack.common.apiclient import client + ... auth = keystone.KeystoneAuthPlugin( + ... username="user", password="pass", tenant_name="tenant", + ... auth_url="http://auth:5000/v2.0") + ... openstack_client = client.HTTPClient(auth) + ... # create nova client + ... from novaclient.v1_1 import client + ... client.Client(openstack_client) + ... # create keystone client + ... from keystoneclient.v2_0 import client + ... client.Client(openstack_client) + ... # use them + ... openstack_client.identity.tenants.list() + ... openstack_client.compute.servers.list() + """ + service_type = base_client_instance.service_type + if service_type and not hasattr(self, service_type): + setattr(self, service_type, base_client_instance) + + def authenticate(self): + self.auth_plugin.authenticate(self) + # Store the authentication results in the keyring for later requests + if self.keyring_saver: + self.keyring_saver.save(self) + + +class BaseClient(object): + """Top-level object to access the OpenStack API. + + This client uses :class:`HTTPClient` to send requests. :class:`HTTPClient` + will handle a bunch of issues such as authentication. + """ + + service_type = None + endpoint_type = None # "publicURL" will be used + cached_endpoint = None + + def __init__(self, http_client, extensions=None): + self.http_client = http_client + http_client.add_client(self) + + # Add in any extensions... + if extensions: + for extension in extensions: + if extension.manager_class: + setattr(self, extension.name, + extension.manager_class(self)) + + def client_request(self, method, url, **kwargs): + return self.http_client.client_request( + self, method, url, **kwargs) + + def head(self, url, **kwargs): + return self.client_request("HEAD", url, **kwargs) + + def get(self, url, **kwargs): + return self.client_request("GET", url, **kwargs) + + def post(self, url, **kwargs): + return self.client_request("POST", url, **kwargs) + + def put(self, url, **kwargs): + return self.client_request("PUT", url, **kwargs) + + def delete(self, url, **kwargs): + return self.client_request("DELETE", url, **kwargs) + + def patch(self, url, **kwargs): + return self.client_request("PATCH", url, **kwargs) + + @staticmethod + def get_class(api_name, version, version_map): + """Returns the client class for the requested API version + + :param api_name: the name of the API, e.g. 
'compute', 'image', etc + :param version: the requested API version + :param version_map: a dict of client classes keyed by version + :rtype: a client class for the requested API version + """ + try: + client_path = version_map[str(version)] + except (KeyError, ValueError): + msg = "Invalid %s client version '%s'. must be one of: %s" % ( + (api_name, version, ', '.join(version_map.keys()))) + raise exceptions.UnsupportedVersion(msg) + + return importutils.import_class(client_path) diff --git a/cerberus/openstack/common/apiclient/exceptions.py b/cerberus/openstack/common/apiclient/exceptions.py new file mode 100644 index 0000000..ada1344 --- /dev/null +++ b/cerberus/openstack/common/apiclient/exceptions.py @@ -0,0 +1,459 @@ +# Copyright 2010 Jacob Kaplan-Moss +# Copyright 2011 Nebula, Inc. +# Copyright 2013 Alessio Ababilov +# Copyright 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Exception definitions. +""" + +import inspect +import sys + +import six + + +class ClientException(Exception): + """The base exception class for all exceptions this library raises. + """ + pass + + +class MissingArgs(ClientException): + """Supplied arguments are not sufficient for calling a function.""" + def __init__(self, missing): + self.missing = missing + msg = "Missing argument(s): %s" % ", ".join(missing) + super(MissingArgs, self).__init__(msg) + + +class ValidationError(ClientException): + """Error in validation on API client side.""" + pass + + +class UnsupportedVersion(ClientException): + """User is trying to use an unsupported version of the API.""" + pass + + +class CommandError(ClientException): + """Error in CLI tool.""" + pass + + +class AuthorizationFailure(ClientException): + """Cannot authorize API client.""" + pass + + +class ConnectionRefused(ClientException): + """Cannot connect to API service.""" + pass + + +class AuthPluginOptionsMissing(AuthorizationFailure): + """Auth plugin misses some options.""" + def __init__(self, opt_names): + super(AuthPluginOptionsMissing, self).__init__( + "Authentication failed. 
Missing options: %s" % + ", ".join(opt_names)) + self.opt_names = opt_names + + +class AuthSystemNotFound(AuthorizationFailure): + """User has specified a AuthSystem that is not installed.""" + def __init__(self, auth_system): + super(AuthSystemNotFound, self).__init__( + "AuthSystemNotFound: %s" % repr(auth_system)) + self.auth_system = auth_system + + +class NoUniqueMatch(ClientException): + """Multiple entities found instead of one.""" + pass + + +class EndpointException(ClientException): + """Something is rotten in Service Catalog.""" + pass + + +class EndpointNotFound(EndpointException): + """Could not find requested endpoint in Service Catalog.""" + pass + + +class AmbiguousEndpoints(EndpointException): + """Found more than one matching endpoint in Service Catalog.""" + def __init__(self, endpoints=None): + super(AmbiguousEndpoints, self).__init__( + "AmbiguousEndpoints: %s" % repr(endpoints)) + self.endpoints = endpoints + + +class HttpError(ClientException): + """The base exception class for all HTTP exceptions. + """ + http_status = 0 + message = "HTTP Error" + + def __init__(self, message=None, details=None, + response=None, request_id=None, + url=None, method=None, http_status=None): + self.http_status = http_status or self.http_status + self.message = message or self.message + self.details = details + self.request_id = request_id + self.response = response + self.url = url + self.method = method + formatted_string = "%s (HTTP %s)" % (self.message, self.http_status) + if request_id: + formatted_string += " (Request-ID: %s)" % request_id + super(HttpError, self).__init__(formatted_string) + + +class HTTPRedirection(HttpError): + """HTTP Redirection.""" + message = "HTTP Redirection" + + +class HTTPClientError(HttpError): + """Client-side HTTP error. + + Exception for cases in which the client seems to have erred. + """ + message = "HTTP Client Error" + + +class HttpServerError(HttpError): + """Server-side HTTP error. + + Exception for cases in which the server is aware that it has + erred or is incapable of performing the request. + """ + message = "HTTP Server Error" + + +class MultipleChoices(HTTPRedirection): + """HTTP 300 - Multiple Choices. + + Indicates multiple options for the resource that the client may follow. + """ + + http_status = 300 + message = "Multiple Choices" + + +class BadRequest(HTTPClientError): + """HTTP 400 - Bad Request. + + The request cannot be fulfilled due to bad syntax. + """ + http_status = 400 + message = "Bad Request" + + +class Unauthorized(HTTPClientError): + """HTTP 401 - Unauthorized. + + Similar to 403 Forbidden, but specifically for use when authentication + is required and has failed or has not yet been provided. + """ + http_status = 401 + message = "Unauthorized" + + +class PaymentRequired(HTTPClientError): + """HTTP 402 - Payment Required. + + Reserved for future use. + """ + http_status = 402 + message = "Payment Required" + + +class Forbidden(HTTPClientError): + """HTTP 403 - Forbidden. + + The request was a valid request, but the server is refusing to respond + to it. + """ + http_status = 403 + message = "Forbidden" + + +class NotFound(HTTPClientError): + """HTTP 404 - Not Found. + + The requested resource could not be found but may be available again + in the future. + """ + http_status = 404 + message = "Not Found" + + +class MethodNotAllowed(HTTPClientError): + """HTTP 405 - Method Not Allowed. + + A request was made of a resource using a request method not supported + by that resource. 
+ """ + http_status = 405 + message = "Method Not Allowed" + + +class NotAcceptable(HTTPClientError): + """HTTP 406 - Not Acceptable. + + The requested resource is only capable of generating content not + acceptable according to the Accept headers sent in the request. + """ + http_status = 406 + message = "Not Acceptable" + + +class ProxyAuthenticationRequired(HTTPClientError): + """HTTP 407 - Proxy Authentication Required. + + The client must first authenticate itself with the proxy. + """ + http_status = 407 + message = "Proxy Authentication Required" + + +class RequestTimeout(HTTPClientError): + """HTTP 408 - Request Timeout. + + The server timed out waiting for the request. + """ + http_status = 408 + message = "Request Timeout" + + +class Conflict(HTTPClientError): + """HTTP 409 - Conflict. + + Indicates that the request could not be processed because of conflict + in the request, such as an edit conflict. + """ + http_status = 409 + message = "Conflict" + + +class Gone(HTTPClientError): + """HTTP 410 - Gone. + + Indicates that the resource requested is no longer available and will + not be available again. + """ + http_status = 410 + message = "Gone" + + +class LengthRequired(HTTPClientError): + """HTTP 411 - Length Required. + + The request did not specify the length of its content, which is + required by the requested resource. + """ + http_status = 411 + message = "Length Required" + + +class PreconditionFailed(HTTPClientError): + """HTTP 412 - Precondition Failed. + + The server does not meet one of the preconditions that the requester + put on the request. + """ + http_status = 412 + message = "Precondition Failed" + + +class RequestEntityTooLarge(HTTPClientError): + """HTTP 413 - Request Entity Too Large. + + The request is larger than the server is willing or able to process. + """ + http_status = 413 + message = "Request Entity Too Large" + + def __init__(self, *args, **kwargs): + try: + self.retry_after = int(kwargs.pop('retry_after')) + except (KeyError, ValueError): + self.retry_after = 0 + + super(RequestEntityTooLarge, self).__init__(*args, **kwargs) + + +class RequestUriTooLong(HTTPClientError): + """HTTP 414 - Request-URI Too Long. + + The URI provided was too long for the server to process. + """ + http_status = 414 + message = "Request-URI Too Long" + + +class UnsupportedMediaType(HTTPClientError): + """HTTP 415 - Unsupported Media Type. + + The request entity has a media type which the server or resource does + not support. + """ + http_status = 415 + message = "Unsupported Media Type" + + +class RequestedRangeNotSatisfiable(HTTPClientError): + """HTTP 416 - Requested Range Not Satisfiable. + + The client has asked for a portion of the file, but the server cannot + supply that portion. + """ + http_status = 416 + message = "Requested Range Not Satisfiable" + + +class ExpectationFailed(HTTPClientError): + """HTTP 417 - Expectation Failed. + + The server cannot meet the requirements of the Expect request-header field. + """ + http_status = 417 + message = "Expectation Failed" + + +class UnprocessableEntity(HTTPClientError): + """HTTP 422 - Unprocessable Entity. + + The request was well-formed but was unable to be followed due to semantic + errors. + """ + http_status = 422 + message = "Unprocessable Entity" + + +class InternalServerError(HttpServerError): + """HTTP 500 - Internal Server Error. + + A generic error message, given when no more specific message is suitable. 
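+
+    For instance, ``from_response()`` below maps a 500 response to this
+    class through ``_code_map``.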
+ """ + http_status = 500 + message = "Internal Server Error" + + +# NotImplemented is a python keyword. +class HttpNotImplemented(HttpServerError): + """HTTP 501 - Not Implemented. + + The server either does not recognize the request method, or it lacks + the ability to fulfill the request. + """ + http_status = 501 + message = "Not Implemented" + + +class BadGateway(HttpServerError): + """HTTP 502 - Bad Gateway. + + The server was acting as a gateway or proxy and received an invalid + response from the upstream server. + """ + http_status = 502 + message = "Bad Gateway" + + +class ServiceUnavailable(HttpServerError): + """HTTP 503 - Service Unavailable. + + The server is currently unavailable. + """ + http_status = 503 + message = "Service Unavailable" + + +class GatewayTimeout(HttpServerError): + """HTTP 504 - Gateway Timeout. + + The server was acting as a gateway or proxy and did not receive a timely + response from the upstream server. + """ + http_status = 504 + message = "Gateway Timeout" + + +class HttpVersionNotSupported(HttpServerError): + """HTTP 505 - HttpVersion Not Supported. + + The server does not support the HTTP protocol version used in the request. + """ + http_status = 505 + message = "HTTP Version Not Supported" + + +# _code_map contains all the classes that have http_status attribute. +_code_map = dict( + (getattr(obj, 'http_status', None), obj) + for name, obj in six.iteritems(vars(sys.modules[__name__])) + if inspect.isclass(obj) and getattr(obj, 'http_status', False) +) + + +def from_response(response, method, url): + """Returns an instance of :class:`HttpError` or subclass based on response. + + :param response: instance of `requests.Response` class + :param method: HTTP method used for request + :param url: URL used for request + """ + kwargs = { + "http_status": response.status_code, + "response": response, + "method": method, + "url": url, + "request_id": response.headers.get("x-compute-request-id"), + } + if "retry-after" in response.headers: + kwargs["retry_after"] = response.headers["retry-after"] + + content_type = response.headers.get("Content-Type", "") + if content_type.startswith("application/json"): + try: + body = response.json() + except ValueError: + pass + else: + if isinstance(body, dict): + error = list(body.values())[0] + kwargs["message"] = error.get("message") + kwargs["details"] = error.get("details") + elif content_type.startswith("text/"): + kwargs["details"] = response.text + + try: + cls = _code_map[response.status_code] + except KeyError: + if 500 <= response.status_code < 600: + cls = HttpServerError + elif 400 <= response.status_code < 500: + cls = HTTPClientError + else: + cls = HttpError + return cls(**kwargs) diff --git a/cerberus/openstack/common/apiclient/fake_client.py b/cerberus/openstack/common/apiclient/fake_client.py new file mode 100644 index 0000000..c1dfdbe --- /dev/null +++ b/cerberus/openstack/common/apiclient/fake_client.py @@ -0,0 +1,173 @@ +# Copyright 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+A fake server that "responds" to API methods with pre-canned responses.
+
+All of these responses come from the spec, so if for some reason the spec's
+wrong the tests might raise AssertionError. I've indicated in comments the
+places where actual behavior differs from the spec.
+"""
+
+# W0102: Dangerous default value %s as argument
+# pylint: disable=W0102
+
+import json
+
+import requests
+import six
+from six.moves.urllib import parse
+
+from cerberus.openstack.common.apiclient import client
+
+
+def assert_has_keys(dct, required=[], optional=[]):
+    for k in required:
+        try:
+            assert k in dct
+        except AssertionError:
+            extra_keys = set(dct.keys()).difference(set(required + optional))
+            raise AssertionError("found unexpected keys: %s" %
+                                 list(extra_keys))
+
+
+class TestResponse(requests.Response):
+    """Wrap requests.Response and provide a convenient initialization.
+    """
+
+    def __init__(self, data):
+        super(TestResponse, self).__init__()
+        self._content_consumed = True
+        if isinstance(data, dict):
+            self.status_code = data.get('status_code', 200)
+            # Fake the text attribute to streamline Response creation
+            text = data.get('text', "")
+            if isinstance(text, (dict, list)):
+                self._content = json.dumps(text)
+                default_headers = {
+                    "Content-Type": "application/json",
+                }
+            else:
+                self._content = text
+                default_headers = {}
+            if six.PY3 and isinstance(self._content, six.string_types):
+                self._content = self._content.encode('utf-8', 'strict')
+            self.headers = data.get('headers') or default_headers
+        else:
+            self.status_code = data
+
+    def __eq__(self, other):
+        return (self.status_code == other.status_code and
+                self.headers == other.headers and
+                self._content == other._content)
+
+
+class FakeHTTPClient(client.HTTPClient):
+
+    def __init__(self, *args, **kwargs):
+        self.callstack = []
+        self.fixtures = kwargs.pop("fixtures", None) or {}
+        if not args and "auth_plugin" not in kwargs:
+            args = (None, )
+        super(FakeHTTPClient, self).__init__(*args, **kwargs)
+
+    def assert_called(self, method, url, body=None, pos=-1):
+        """Assert that an API method was just called.
+        """
+        expected = (method, url)
+        assert self.callstack, \
+            "Expected %s %s but no calls were made." % expected
+        called = self.callstack[pos][0:2]
+
+        assert expected == called, 'Expected %s %s; got %s %s' % \
+            (expected + called)
+
+        if body is not None:
+            if self.callstack[pos][3] != body:
+                raise AssertionError('%r != %r' %
+                                     (self.callstack[pos][3], body))
+
+    def assert_called_anytime(self, method, url, body=None):
+        """Assert that an API method was called at some point in the test.
+        """
+        expected = (method, url)
+
+        assert self.callstack, \
+            "Expected %s %s but no calls were made."
% expected + + found = False + entry = None + for entry in self.callstack: + if expected == entry[0:2]: + found = True + break + + assert found, 'Expected %s %s; got %s' % \ + (method, url, self.callstack) + if body is not None: + assert entry[3] == body, "%s != %s" % (entry[3], body) + + self.callstack = [] + + def clear_callstack(self): + self.callstack = [] + + def authenticate(self): + pass + + def client_request(self, client, method, url, **kwargs): + # Check that certain things are called correctly + if method in ["GET", "DELETE"]: + assert "json" not in kwargs + + # Note the call + self.callstack.append( + (method, + url, + kwargs.get("headers") or {}, + kwargs.get("json") or kwargs.get("data"))) + try: + fixture = self.fixtures[url][method] + except KeyError: + pass + else: + return TestResponse({"headers": fixture[0], + "text": fixture[1]}) + + # Call the method + args = parse.parse_qsl(parse.urlparse(url)[4]) + kwargs.update(args) + munged_url = url.rsplit('?', 1)[0] + munged_url = munged_url.strip('/').replace('/', '_').replace('.', '_') + munged_url = munged_url.replace('-', '_') + + callback = "%s_%s" % (method.lower(), munged_url) + + if not hasattr(self, callback): + raise AssertionError('Called unknown API method: %s %s, ' + 'expected fakes method name: %s' % + (method, url, callback)) + + resp = getattr(self, callback)(**kwargs) + if len(resp) == 3: + status, headers, body = resp + else: + status, body = resp + headers = {} + return TestResponse({ + "status_code": status, + "text": body, + "headers": headers, + }) diff --git a/cerberus/openstack/common/cliutils.py b/cerberus/openstack/common/cliutils.py new file mode 100644 index 0000000..a99ea4d --- /dev/null +++ b/cerberus/openstack/common/cliutils.py @@ -0,0 +1,309 @@ +# Copyright 2012 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +# W0603: Using the global statement +# W0621: Redefining name %s from outer scope +# pylint: disable=W0603,W0621 + +from __future__ import print_function + +import getpass +import inspect +import os +import sys +import textwrap + +import prettytable +import six +from six import moves + +from cerberus.openstack.common.apiclient import exceptions +from cerberus.openstack.common.gettextutils import _ +from cerberus.openstack.common import strutils +from cerberus.openstack.common import uuidutils + + +def validate_args(fn, *args, **kwargs): + """Check that the supplied args are sufficient for calling a function. + + >>> validate_args(lambda a: None) + Traceback (most recent call last): + ... + MissingArgs: Missing argument(s): a + >>> validate_args(lambda a, b, c, d: None, 0, c=1) + Traceback (most recent call last): + ... 
+ MissingArgs: Missing argument(s): b, d + + :param fn: the function to check + :param arg: the positional arguments supplied + :param kwargs: the keyword arguments supplied + """ + argspec = inspect.getargspec(fn) + + num_defaults = len(argspec.defaults or []) + required_args = argspec.args[:len(argspec.args) - num_defaults] + + def isbound(method): + return getattr(method, 'im_self', None) is not None + + if isbound(fn): + required_args.pop(0) + + missing = [arg for arg in required_args if arg not in kwargs] + missing = missing[len(args):] + if missing: + raise exceptions.MissingArgs(missing) + + +def arg(*args, **kwargs): + """Decorator for CLI args. + + Example: + + >>> @arg("name", help="Name of the new entity") + ... def entity_create(args): + ... pass + """ + def _decorator(func): + add_arg(func, *args, **kwargs) + return func + return _decorator + + +def env(*args, **kwargs): + """Returns the first environment variable set. + + If all are empty, defaults to '' or keyword arg `default`. + """ + for arg in args: + value = os.environ.get(arg) + if value: + return value + return kwargs.get('default', '') + + +def add_arg(func, *args, **kwargs): + """Bind CLI arguments to a shell.py `do_foo` function.""" + + if not hasattr(func, 'arguments'): + func.arguments = [] + + # NOTE(sirp): avoid dups that can occur when the module is shared across + # tests. + if (args, kwargs) not in func.arguments: + # Because of the semantics of decorator composition if we just append + # to the options list positional options will appear to be backwards. + func.arguments.insert(0, (args, kwargs)) + + +def unauthenticated(func): + """Adds 'unauthenticated' attribute to decorated function. + + Usage: + + >>> @unauthenticated + ... def mymethod(f): + ... pass + """ + func.unauthenticated = True + return func + + +def isunauthenticated(func): + """Checks if the function does not require authentication. + + Mark such functions with the `@unauthenticated` decorator. + + :returns: bool + """ + return getattr(func, 'unauthenticated', False) + + +def print_list(objs, fields, formatters=None, sortby_index=0, + mixed_case_fields=None): + """Print a list or objects as a table, one row per object. + + :param objs: iterable of :class:`Resource` + :param fields: attributes that correspond to columns, in order + :param formatters: `dict` of callables for field formatting + :param sortby_index: index of the field for sorting table rows + :param mixed_case_fields: fields corresponding to object attributes that + have mixed case names (e.g., 'serverId') + """ + formatters = formatters or {} + mixed_case_fields = mixed_case_fields or [] + if sortby_index is None: + kwargs = {} + else: + kwargs = {'sortby': fields[sortby_index]} + pt = prettytable.PrettyTable(fields, caching=False) + pt.align = 'l' + + for o in objs: + row = [] + for field in fields: + if field in formatters: + row.append(formatters[field](o)) + else: + if field in mixed_case_fields: + field_name = field.replace(' ', '_') + else: + field_name = field.lower().replace(' ', '_') + data = getattr(o, field_name, '') + row.append(data) + pt.add_row(row) + + print(strutils.safe_encode(pt.get_string(**kwargs))) + + +def print_dict(dct, dict_property="Property", wrap=0): + """Print a `dict` as a table of two columns. 
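+
+    Example (illustrative)::
+
+        print_dict({'id': '42', 'name': 'test'}, wrap=72)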
+ + :param dct: `dict` to print + :param dict_property: name of the first column + :param wrap: wrapping for the second column + """ + pt = prettytable.PrettyTable([dict_property, 'Value'], caching=False) + pt.align = 'l' + for k, v in six.iteritems(dct): + # convert dict to str to check length + if isinstance(v, dict): + v = six.text_type(v) + if wrap > 0: + v = textwrap.fill(six.text_type(v), wrap) + # if value has a newline, add in multiple rows + # e.g. fault with stacktrace + if v and isinstance(v, six.string_types) and r'\n' in v: + lines = v.strip().split(r'\n') + col1 = k + for line in lines: + pt.add_row([col1, line]) + col1 = '' + else: + pt.add_row([k, v]) + print(strutils.safe_encode(pt.get_string())) + + +def get_password(max_password_prompts=3): + """Read password from TTY.""" + verify = strutils.bool_from_string(env("OS_VERIFY_PASSWORD")) + pw = None + if hasattr(sys.stdin, "isatty") and sys.stdin.isatty(): + # Check for Ctrl-D + try: + for __ in moves.range(max_password_prompts): + pw1 = getpass.getpass("OS Password: ") + if verify: + pw2 = getpass.getpass("Please verify: ") + else: + pw2 = pw1 + if pw1 == pw2 and pw1: + pw = pw1 + break + except EOFError: + pass + return pw + + +def find_resource(manager, name_or_id, **find_args): + """Look for resource in a given manager. + + Used as a helper for the _find_* methods. + Example: + + def _find_hypervisor(cs, hypervisor): + #Get a hypervisor by name or ID. + return cliutils.find_resource(cs.hypervisors, hypervisor) + """ + # first try to get entity as integer id + try: + return manager.get(int(name_or_id)) + except (TypeError, ValueError, exceptions.NotFound): + pass + + # now try to get entity as uuid + try: + tmp_id = strutils.safe_encode(name_or_id) + + if uuidutils.is_uuid_like(tmp_id): + return manager.get(tmp_id) + except (TypeError, ValueError, exceptions.NotFound): + pass + + # for str id which is not uuid + if getattr(manager, 'is_alphanum_id_allowed', False): + try: + return manager.get(name_or_id) + except exceptions.NotFound: + pass + + try: + try: + return manager.find(human_id=name_or_id, **find_args) + except exceptions.NotFound: + pass + + # finally try to find entity by name + try: + resource = getattr(manager, 'resource_class', None) + name_attr = resource.NAME_ATTR if resource else 'name' + kwargs = {name_attr: name_or_id} + kwargs.update(find_args) + return manager.find(**kwargs) + except exceptions.NotFound: + msg = _("No %(name)s with a name or " + "ID of '%(name_or_id)s' exists.") % \ + { + "name": manager.resource_class.__name__.lower(), + "name_or_id": name_or_id + } + raise exceptions.CommandError(msg) + except exceptions.NoUniqueMatch: + msg = _("Multiple %(name)s matches found for " + "'%(name_or_id)s', use an ID to be more specific.") % \ + { + "name": manager.resource_class.__name__.lower(), + "name_or_id": name_or_id + } + raise exceptions.CommandError(msg) + + +def service_type(stype): + """Adds 'service_type' attribute to decorated function. + + Usage: + @service_type('volume') + def mymethod(f): + ... 
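`find_resource()` above tries progressively looser lookups, raising `CommandError` only after the final name lookup fails or is ambiguous. A rough sketch of the cascade (`manager` is a hypothetical apiclient manager bound elsewhere):

    from cerberus.openstack.common import cliutils

    # Lookup order inside find_resource():
    #   1. manager.get(int(name_or_id))   - integer IDs
    #   2. manager.get(name_or_id)        - UUID-like strings
    #   3. manager.get(name_or_id)        - other string IDs, only if
    #                                       manager.is_alphanum_id_allowed
    #   4. manager.find(human_id=name_or_id)
    #   5. manager.find(name=name_or_id)  - or the class's NAME_ATTR
    server = cliutils.find_resource(manager, 'web-01')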
+ """ + def inner(f): + f.service_type = stype + return f + return inner + + +def get_service_type(f): + """Retrieves service type from function.""" + return getattr(f, 'service_type', None) + + +def pretty_choice_list(l): + return ', '.join("'%s'" % i for i in l) + + +def exit(msg=''): + if msg: + print (msg, file=sys.stderr) + sys.exit(1) diff --git a/cerberus/openstack/common/config/__init__.py b/cerberus/openstack/common/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cerberus/openstack/common/config/generator.py b/cerberus/openstack/common/config/generator.py new file mode 100644 index 0000000..8808dcf --- /dev/null +++ b/cerberus/openstack/common/config/generator.py @@ -0,0 +1,307 @@ +# Copyright 2012 SINA Corporation +# Copyright 2014 Cisco Systems, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# + +"""Extracts OpenStack config option info from module(s).""" + +from __future__ import print_function + +import argparse +import imp +import os +import re +import socket +import sys +import textwrap + +from oslo.config import cfg +import six +import stevedore.named + +from cerberus.openstack.common import gettextutils +from cerberus.openstack.common import importutils + +gettextutils.install('cerberus') + +STROPT = "StrOpt" +BOOLOPT = "BoolOpt" +INTOPT = "IntOpt" +FLOATOPT = "FloatOpt" +LISTOPT = "ListOpt" +DICTOPT = "DictOpt" +MULTISTROPT = "MultiStrOpt" + +OPT_TYPES = { + STROPT: 'string value', + BOOLOPT: 'boolean value', + INTOPT: 'integer value', + FLOATOPT: 'floating point value', + LISTOPT: 'list value', + DICTOPT: 'dict value', + MULTISTROPT: 'multi valued', +} + +OPTION_REGEX = re.compile(r"(%s)" % "|".join([STROPT, BOOLOPT, INTOPT, + FLOATOPT, LISTOPT, DICTOPT, + MULTISTROPT])) + +PY_EXT = ".py" +BASEDIR = os.path.abspath(os.path.join(os.path.dirname(__file__), + "../../../../")) +WORDWRAP_WIDTH = 60 + + +def raise_extension_exception(extmanager, ep, err): + raise + + +def generate(argv): + parser = argparse.ArgumentParser( + description='generate sample configuration file', + ) + parser.add_argument('-m', dest='modules', action='append') + parser.add_argument('-l', dest='libraries', action='append') + parser.add_argument('srcfiles', nargs='*') + parsed_args = parser.parse_args(argv) + + mods_by_pkg = dict() + for filepath in parsed_args.srcfiles: + pkg_name = filepath.split(os.sep)[1] + mod_str = '.'.join(['.'.join(filepath.split(os.sep)[:-1]), + os.path.basename(filepath).split('.')[0]]) + mods_by_pkg.setdefault(pkg_name, list()).append(mod_str) + # NOTE(lzyeval): place top level modules before packages + pkg_names = sorted(pkg for pkg in mods_by_pkg if pkg.endswith(PY_EXT)) + ext_names = sorted(pkg for pkg in mods_by_pkg if pkg not in pkg_names) + pkg_names.extend(ext_names) + + # opts_by_group is a mapping of group name to an options list + # The options list is a list of (module, options) tuples + opts_by_group = {'DEFAULT': []} + + if parsed_args.modules: + for module_name in parsed_args.modules: + module = 
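`OPTION_REGEX` above recovers an option's type from the class name embedded in `str(type(opt))`; `_print_opt()` later maps that token through `OPT_TYPES` for the sample-file comments. A small sketch of the matching, assuming this generator module's constants are in scope:

    from oslo.config import cfg

    from cerberus.openstack.common.config import generator

    opt = cfg.BoolOpt('verbose', default=False, help='Print more output')
    # str(type(opt)) looks like "<class 'oslo.config.cfg.BoolOpt'>",
    # so the regex picks the 'BoolOpt' token out of it.
    opt_type = generator.OPTION_REGEX.search(str(type(opt))).group(0)
    assert opt_type == generator.BOOLOPT
    assert generator.OPT_TYPES[opt_type] == 'boolean value'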
_import_module(module_name) + if module: + for group, opts in _list_opts(module): + opts_by_group.setdefault(group, []).append((module_name, + opts)) + + # Look for entry points defined in libraries (or applications) for + # option discovery, and include their return values in the output. + # + # Each entry point should be a function returning an iterable + # of pairs with the group name (or None for the default group) + # and the list of Opt instances for that group. + if parsed_args.libraries: + loader = stevedore.named.NamedExtensionManager( + 'oslo.config.opts', + names=list(set(parsed_args.libraries)), + invoke_on_load=False, + on_load_failure_callback=raise_extension_exception + ) + for ext in loader: + for group, opts in ext.plugin(): + opt_list = opts_by_group.setdefault(group or 'DEFAULT', []) + opt_list.append((ext.name, opts)) + + for pkg_name in pkg_names: + mods = mods_by_pkg.get(pkg_name) + mods.sort() + for mod_str in mods: + if mod_str.endswith('.__init__'): + mod_str = mod_str[:mod_str.rfind(".")] + + mod_obj = _import_module(mod_str) + if not mod_obj: + raise RuntimeError("Unable to import module %s" % mod_str) + + for group, opts in _list_opts(mod_obj): + opts_by_group.setdefault(group, []).append((mod_str, opts)) + + print_group_opts('DEFAULT', opts_by_group.pop('DEFAULT', [])) + for group in sorted(opts_by_group.keys()): + print_group_opts(group, opts_by_group[group]) + + +def _import_module(mod_str): + try: + if mod_str.startswith('bin.'): + imp.load_source(mod_str[4:], os.path.join('bin', mod_str[4:])) + return sys.modules[mod_str[4:]] + else: + return importutils.import_module(mod_str) + except Exception as e: + sys.stderr.write("Error importing module %s: %s\n" % (mod_str, str(e))) + return None + + +def _is_in_group(opt, group): + "Check if opt is in group." + for value in group._opts.values(): + # NOTE(llu): Temporary workaround for bug #1262148, wait until + # newly released oslo.config support '==' operator. + if not(value['opt'] != opt): + return True + return False + + +def _guess_groups(opt, mod_obj): + # is it in the DEFAULT group? + if _is_in_group(opt, cfg.CONF): + return 'DEFAULT' + + # what other groups is it in? + for value in cfg.CONF.values(): + if isinstance(value, cfg.CONF.GroupAttr): + if _is_in_group(opt, value._group): + return value._group.name + + raise RuntimeError( + "Unable to find group for option %s, " + "maybe it's defined twice in the same group?" 
+ % opt.name + ) + + +def _list_opts(obj): + def is_opt(o): + return (isinstance(o, cfg.Opt) and + not isinstance(o, cfg.SubCommandOpt)) + + opts = list() + for attr_str in dir(obj): + attr_obj = getattr(obj, attr_str) + if is_opt(attr_obj): + opts.append(attr_obj) + elif (isinstance(attr_obj, list) and + all(map(lambda x: is_opt(x), attr_obj))): + opts.extend(attr_obj) + + ret = {} + for opt in opts: + ret.setdefault(_guess_groups(opt, obj), []).append(opt) + return ret.items() + + +def print_group_opts(group, opts_by_module): + print("[%s]" % group) + print('') + for mod, opts in opts_by_module: + print('#') + print('# Options defined in %s' % mod) + print('#') + print('') + for opt in opts: + _print_opt(opt) + print('') + + +def _get_my_ip(): + try: + csock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + csock.connect(('8.8.8.8', 80)) + (addr, port) = csock.getsockname() + csock.close() + return addr + except socket.error: + return None + + +def _sanitize_default(name, value): + """Set up a reasonably sensible default for pybasedir, my_ip and host.""" + if value.startswith(sys.prefix): + # NOTE(jd) Don't use os.path.join, because it is likely to think the + # second part is an absolute pathname and therefore drop the first + # part. + value = os.path.normpath("/usr/" + value[len(sys.prefix):]) + elif value.startswith(BASEDIR): + return value.replace(BASEDIR, '/usr/lib/python/site-packages') + elif BASEDIR in value: + return value.replace(BASEDIR, '') + elif value == _get_my_ip(): + return '10.0.0.1' + elif value in (socket.gethostname(), socket.getfqdn()) and 'host' in name: + return 'cerberus' + elif value.strip() != value: + return '"%s"' % value + return value + + +def _print_opt(opt): + opt_name, opt_default, opt_help = opt.dest, opt.default, opt.help + if not opt_help: + sys.stderr.write('WARNING: "%s" is missing help string.\n' % opt_name) + opt_help = "" + opt_type = None + try: + opt_type = OPTION_REGEX.search(str(type(opt))).group(0) + except (ValueError, AttributeError) as err: + sys.stderr.write("%s\n" % str(err)) + sys.exit(1) + opt_help = u'%s (%s)' % (opt_help, + OPT_TYPES[opt_type]) + print('#', "\n# ".join(textwrap.wrap(opt_help, WORDWRAP_WIDTH))) + if opt.deprecated_opts: + for deprecated_opt in opt.deprecated_opts: + if deprecated_opt.name: + deprecated_group = (deprecated_opt.group if + deprecated_opt.group else "DEFAULT") + print('# Deprecated group/name - [%s]/%s' % + (deprecated_group, + deprecated_opt.name)) + try: + if opt_default is None: + print('#%s=' % opt_name) + elif opt_type == STROPT: + assert(isinstance(opt_default, six.string_types)) + print('#%s=%s' % (opt_name, _sanitize_default(opt_name, + opt_default))) + elif opt_type == BOOLOPT: + assert(isinstance(opt_default, bool)) + print('#%s=%s' % (opt_name, str(opt_default).lower())) + elif opt_type == INTOPT: + assert(isinstance(opt_default, int) and + not isinstance(opt_default, bool)) + print('#%s=%s' % (opt_name, opt_default)) + elif opt_type == FLOATOPT: + assert(isinstance(opt_default, float)) + print('#%s=%s' % (opt_name, opt_default)) + elif opt_type == LISTOPT: + assert(isinstance(opt_default, list)) + print('#%s=%s' % (opt_name, ','.join(opt_default))) + elif opt_type == DICTOPT: + assert(isinstance(opt_default, dict)) + opt_default_strlist = [str(key) + ':' + str(value) + for (key, value) in opt_default.items()] + print('#%s=%s' % (opt_name, ','.join(opt_default_strlist))) + elif opt_type == MULTISTROPT: + assert(isinstance(opt_default, list)) + if not opt_default: + opt_default = [''] + for 
default in opt_default: + print('#%s=%s' % (opt_name, default)) + print('') + except Exception: + sys.stderr.write('Error in option "%s"\n' % opt_name) + sys.exit(1) + + +def main(): + generate(sys.argv[1:]) + +if __name__ == '__main__': + main() diff --git a/cerberus/openstack/common/context.py b/cerberus/openstack/common/context.py new file mode 100644 index 0000000..09019ee --- /dev/null +++ b/cerberus/openstack/common/context.py @@ -0,0 +1,111 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Simple class that stores security context information in the web request. + +Projects should subclass this class if they wish to enhance the request +context or provide additional information in their specific WSGI pipeline. +""" + +import itertools +import uuid + + +def generate_request_id(): + return 'req-%s' % str(uuid.uuid4()) + + +class RequestContext(object): + + """Helper class to represent useful information about a request context. + + Stores information about the security context under which the user + accesses the system, as well as additional request information. + """ + + user_idt_format = '{user} {tenant} {domain} {user_domain} {p_domain}' + + def __init__(self, auth_token=None, user=None, tenant=None, domain=None, + user_domain=None, project_domain=None, is_admin=False, + read_only=False, show_deleted=False, request_id=None, + instance_uuid=None): + self.auth_token = auth_token + self.user = user + self.tenant = tenant + self.domain = domain + self.user_domain = user_domain + self.project_domain = project_domain + self.is_admin = is_admin + self.read_only = read_only + self.show_deleted = show_deleted + self.instance_uuid = instance_uuid + if not request_id: + request_id = generate_request_id() + self.request_id = request_id + + def to_dict(self): + user_idt = ( + self.user_idt_format.format(user=self.user or '-', + tenant=self.tenant or '-', + domain=self.domain or '-', + user_domain=self.user_domain or '-', + p_domain=self.project_domain or '-')) + + return {'user': self.user, + 'tenant': self.tenant, + 'domain': self.domain, + 'user_domain': self.user_domain, + 'project_domain': self.project_domain, + 'is_admin': self.is_admin, + 'read_only': self.read_only, + 'show_deleted': self.show_deleted, + 'auth_token': self.auth_token, + 'request_id': self.request_id, + 'instance_uuid': self.instance_uuid, + 'user_identity': user_idt} + + +def get_admin_context(show_deleted=False): + context = RequestContext(None, + tenant=None, + is_admin=True, + show_deleted=show_deleted) + return context + + +def get_context_from_function_and_args(function, args, kwargs): + """Find an arg of type RequestContext and return it. + + This is useful in a couple of decorators where we don't + know much about the function we're wrapping. 
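Taken together, `print_group_opts()` and `_print_opt()` above emit a fully commented sample configuration file on stdout. For the hypothetical `verbose` option from the earlier sketch (module name also illustrative), the output would look roughly like:

    [DEFAULT]

    #
    # Options defined in cerberus.api
    #

    # Print more output (boolean value)
    #verbose=false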
+ """ + + for arg in itertools.chain(kwargs.values(), args): + if isinstance(arg, RequestContext): + return arg + + return None + + +def is_user_context(context): + """Indicates if the request context is a normal user.""" + if not context: + return False + if context.is_admin: + return False + if not context.user_id or not context.project_id: + return False + return True diff --git a/cerberus/openstack/common/db/__init__.py b/cerberus/openstack/common/db/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cerberus/openstack/common/db/api.py b/cerberus/openstack/common/db/api.py new file mode 100644 index 0000000..0025c34 --- /dev/null +++ b/cerberus/openstack/common/db/api.py @@ -0,0 +1,162 @@ +# Copyright (c) 2013 Rackspace Hosting +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Multiple DB API backend support. + +A DB backend module should implement a method named 'get_backend' which +takes no arguments. The method can return any object that implements DB +API methods. +""" + +import functools +import logging +import threading +import time + +from cerberus.openstack.common.db import exception +from cerberus.openstack.common.gettextutils import _LE +from cerberus.openstack.common import importutils + + +LOG = logging.getLogger(__name__) + + +def safe_for_db_retry(f): + """Enable db-retry for decorated function, if config option enabled.""" + f.__dict__['enable_retry'] = True + return f + + +class wrap_db_retry(object): + """Retry db.api methods, if DBConnectionError() raised + + Retry decorated db.api methods. If we enabled `use_db_reconnect` + in config, this decorator will be applied to all db.api functions, + marked with @safe_for_db_retry decorator. + Decorator catchs DBConnectionError() and retries function in a + loop until it succeeds, or until maximum retries count will be reached. + """ + + def __init__(self, retry_interval, max_retries, inc_retry_interval, + max_retry_interval): + super(wrap_db_retry, self).__init__() + + self.retry_interval = retry_interval + self.max_retries = max_retries + self.inc_retry_interval = inc_retry_interval + self.max_retry_interval = max_retry_interval + + def __call__(self, f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + next_interval = self.retry_interval + remaining = self.max_retries + + while True: + try: + return f(*args, **kwargs) + except exception.DBConnectionError as e: + if remaining == 0: + LOG.exception(_LE('DB exceeded retry limit.')) + raise exception.DBError(e) + if remaining != -1: + remaining -= 1 + LOG.exception(_LE('DB connection error.')) + # NOTE(vsergeyev): We are using patched time module, so + # this effectively yields the execution + # context to another green thread. + time.sleep(next_interval) + if self.inc_retry_interval: + next_interval = min( + next_interval * 2, + self.max_retry_interval + ) + return wrapper + + +class DBAPI(object): + def __init__(self, backend_name, backend_mapping=None, lazy=False, + **kwargs): + """Initialize the chosen DB API backend. 
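One caveat about `is_user_context()` above: it checks `context.user_id` and `context.project_id`, while this module's `RequestContext` only defines `user` and `tenant`, so the helper works as intended only for subclasses that provide those names. A sketch of such a subclass (hypothetical):

    from cerberus.openstack.common import context

    class ProjectRequestContext(context.RequestContext):
        """Aliases matching what is_user_context() expects."""

        @property
        def user_id(self):
            return self.user

        @property
        def project_id(self):
            return self.tenant

    ctx = ProjectRequestContext(user='alice', tenant='demo')
    assert context.is_user_context(ctx)
    assert ctx.to_dict()['user_identity'] == 'alice demo - - -'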
+ + :param backend_name: name of the backend to load + :type backend_name: str + + :param backend_mapping: backend name -> module/class to load mapping + :type backend_mapping: dict + + :param lazy: load the DB backend lazily on the first DB API method call + :type lazy: bool + + Keyword arguments: + + :keyword use_db_reconnect: retry DB transactions on disconnect or not + :type use_db_reconnect: bool + + :keyword retry_interval: seconds between transaction retries + :type retry_interval: int + + :keyword inc_retry_interval: increase retry interval or not + :type inc_retry_interval: bool + + :keyword max_retry_interval: max interval value between retries + :type max_retry_interval: int + + :keyword max_retries: max number of retries before an error is raised + :type max_retries: int + + """ + + self._backend = None + self._backend_name = backend_name + self._backend_mapping = backend_mapping or {} + self._lock = threading.Lock() + + if not lazy: + self._load_backend() + + self.use_db_reconnect = kwargs.get('use_db_reconnect', False) + self.retry_interval = kwargs.get('retry_interval', 1) + self.inc_retry_interval = kwargs.get('inc_retry_interval', True) + self.max_retry_interval = kwargs.get('max_retry_interval', 10) + self.max_retries = kwargs.get('max_retries', 20) + + def _load_backend(self): + with self._lock: + if not self._backend: + # Import the untranslated name if we don't have a mapping + backend_path = self._backend_mapping.get(self._backend_name, + self._backend_name) + backend_mod = importutils.import_module(backend_path) + self._backend = backend_mod.get_backend() + + def __getattr__(self, key): + if not self._backend: + self._load_backend() + + attr = getattr(self._backend, key) + if not hasattr(attr, '__call__'): + return attr + # NOTE(vsergeyev): If `use_db_reconnect` option is set to True, retry + # DB API methods, decorated with @safe_for_db_retry + # on disconnect. + if self.use_db_reconnect and hasattr(attr, 'enable_retry'): + attr = wrap_db_retry( + retry_interval=self.retry_interval, + max_retries=self.max_retries, + inc_retry_interval=self.inc_retry_interval, + max_retry_interval=self.max_retry_interval)(attr) + + return attr diff --git a/cerberus/openstack/common/db/exception.py b/cerberus/openstack/common/db/exception.py new file mode 100644 index 0000000..1be2db5 --- /dev/null +++ b/cerberus/openstack/common/db/exception.py @@ -0,0 +1,56 @@ +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
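A usage sketch for `DBAPI` above; the backend module path and the `security_report_get` function are hypothetical. With `use_db_reconnect=True`, any backend function marked `@safe_for_db_retry` is wrapped in `wrap_db_retry`, which doubles the sleep between attempts up to `max_retry_interval`:

    from cerberus.openstack.common.db.api import DBAPI

    # Hypothetical backend module cerberus/db/sqlalchemy/api.py:
    #
    #     def get_backend():
    #         return sys.modules[__name__]
    #
    #     @safe_for_db_retry
    #     def security_report_get(report_id):
    #         ...

    IMPL = DBAPI('sqlalchemy',
                 backend_mapping={'sqlalchemy': 'cerberus.db.sqlalchemy.api'},
                 lazy=True,              # import on first attribute access
                 use_db_reconnect=True,  # retry marked methods on disconnect
                 retry_interval=1,       # sleeps: 1, 2, 4, 8, 10, 10, ...
                 max_retry_interval=10)

    report = IMPL.security_report_get(42)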
+ +"""DB related custom exceptions.""" + +import six + +from cerberus.openstack.common.gettextutils import _ + + +class DBError(Exception): + """Wraps an implementation specific exception.""" + def __init__(self, inner_exception=None): + self.inner_exception = inner_exception + super(DBError, self).__init__(six.text_type(inner_exception)) + + +class DBDuplicateEntry(DBError): + """Wraps an implementation specific exception.""" + def __init__(self, columns=[], inner_exception=None): + self.columns = columns + super(DBDuplicateEntry, self).__init__(inner_exception) + + +class DBDeadlock(DBError): + def __init__(self, inner_exception=None): + super(DBDeadlock, self).__init__(inner_exception) + + +class DBInvalidUnicodeParameter(Exception): + message = _("Invalid Parameter: " + "Unicode is not supported by the current database.") + + +class DbMigrationError(DBError): + """Wraps migration specific exception.""" + def __init__(self, message=None): + super(DbMigrationError, self).__init__(message) + + +class DBConnectionError(DBError): + """Wraps connection specific exception.""" + pass diff --git a/cerberus/openstack/common/db/options.py b/cerberus/openstack/common/db/options.py new file mode 100644 index 0000000..61e4ce1 --- /dev/null +++ b/cerberus/openstack/common/db/options.py @@ -0,0 +1,171 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import copy + +from oslo.config import cfg + + +database_opts = [ + cfg.StrOpt('sqlite_db', + deprecated_group='DEFAULT', + default='cerberus.sqlite', + help='The file name to use with SQLite'), + cfg.BoolOpt('sqlite_synchronous', + deprecated_group='DEFAULT', + default=True, + help='If True, SQLite uses synchronous mode'), + cfg.StrOpt('backend', + default='sqlalchemy', + deprecated_name='db_backend', + deprecated_group='DEFAULT', + help='The backend to use for db'), + cfg.StrOpt('connection', + help='The SQLAlchemy connection string used to connect to the ' + 'database', + secret=True, + deprecated_opts=[cfg.DeprecatedOpt('sql_connection', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_connection', + group='DATABASE'), + cfg.DeprecatedOpt('connection', + group='sql'), ]), + cfg.StrOpt('mysql_sql_mode', + default='TRADITIONAL', + help='The SQL mode to be used for MySQL sessions. ' + 'This option, including the default, overrides any ' + 'server-set SQL mode. To use whatever SQL mode ' + 'is set by the server configuration, ' + 'set this to no value. 
Example: mysql_sql_mode='), + cfg.IntOpt('idle_timeout', + default=3600, + deprecated_opts=[cfg.DeprecatedOpt('sql_idle_timeout', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_idle_timeout', + group='DATABASE'), + cfg.DeprecatedOpt('idle_timeout', + group='sql')], + help='Timeout before idle sql connections are reaped'), + cfg.IntOpt('min_pool_size', + default=1, + deprecated_opts=[cfg.DeprecatedOpt('sql_min_pool_size', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_min_pool_size', + group='DATABASE')], + help='Minimum number of SQL connections to keep open in a ' + 'pool'), + cfg.IntOpt('max_pool_size', + default=None, + deprecated_opts=[cfg.DeprecatedOpt('sql_max_pool_size', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_max_pool_size', + group='DATABASE')], + help='Maximum number of SQL connections to keep open in a ' + 'pool'), + cfg.IntOpt('max_retries', + default=10, + deprecated_opts=[cfg.DeprecatedOpt('sql_max_retries', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_max_retries', + group='DATABASE')], + help='Maximum db connection retries during startup. ' + '(setting -1 implies an infinite retry count)'), + cfg.IntOpt('retry_interval', + default=10, + deprecated_opts=[cfg.DeprecatedOpt('sql_retry_interval', + group='DEFAULT'), + cfg.DeprecatedOpt('reconnect_interval', + group='DATABASE')], + help='Interval between retries of opening a sql connection'), + cfg.IntOpt('max_overflow', + default=None, + deprecated_opts=[cfg.DeprecatedOpt('sql_max_overflow', + group='DEFAULT'), + cfg.DeprecatedOpt('sqlalchemy_max_overflow', + group='DATABASE')], + help='If set, use this value for max_overflow with sqlalchemy'), + cfg.IntOpt('connection_debug', + default=0, + deprecated_opts=[cfg.DeprecatedOpt('sql_connection_debug', + group='DEFAULT')], + help='Verbosity of SQL debugging information. 0=None, ' + '100=Everything'), + cfg.BoolOpt('connection_trace', + default=False, + deprecated_opts=[cfg.DeprecatedOpt('sql_connection_trace', + group='DEFAULT')], + help='Add python stack traces to SQL as comment strings'), + cfg.IntOpt('pool_timeout', + default=None, + deprecated_opts=[cfg.DeprecatedOpt('sqlalchemy_pool_timeout', + group='DATABASE')], + help='If set, use this value for pool_timeout with sqlalchemy'), + cfg.BoolOpt('use_db_reconnect', + default=False, + help='Enable the experimental use of database reconnect ' + 'on connection lost'), + cfg.IntOpt('db_retry_interval', + default=1, + help='seconds between db connection retries'), + cfg.BoolOpt('db_inc_retry_interval', + default=True, + help='Whether to increase interval between db connection ' + 'retries, up to db_max_retry_interval'), + cfg.IntOpt('db_max_retry_interval', + default=10, + help='max seconds between db connection retries, if ' + 'db_inc_retry_interval is enabled'), + cfg.IntOpt('db_max_retries', + default=20, + help='maximum db connection retries before error is raised. 
' + '(setting -1 implies an infinite retry count)'), +] + +CONF = cfg.CONF +CONF.register_opts(database_opts, 'database') + + +def set_defaults(sql_connection, sqlite_db, max_pool_size=None, + max_overflow=None, pool_timeout=None): + """Set defaults for configuration variables.""" + cfg.set_defaults(database_opts, + connection=sql_connection, + sqlite_db=sqlite_db) + # Update the QueuePool defaults + if max_pool_size is not None: + cfg.set_defaults(database_opts, + max_pool_size=max_pool_size) + if max_overflow is not None: + cfg.set_defaults(database_opts, + max_overflow=max_overflow) + if pool_timeout is not None: + cfg.set_defaults(database_opts, + pool_timeout=pool_timeout) + + +def list_opts(): + """Returns a list of oslo.config options available in the library. + + The returned list includes all oslo.config options which may be registered + at runtime by the library. + + Each element of the list is a tuple. The first element is the name of the + group under which the list of elements in the second element will be + registered. A group name of None corresponds to the [DEFAULT] group in + config files. + + The purpose of this is to allow tools like the Oslo sample config file + generator to discover the options exposed to users by this library. + + :returns: a list of (group_name, opts) tuples + """ + return [('database', copy.deepcopy(database_opts))] diff --git a/cerberus/openstack/common/db/sqlalchemy/__init__.py b/cerberus/openstack/common/db/sqlalchemy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cerberus/openstack/common/db/sqlalchemy/migration.py b/cerberus/openstack/common/db/sqlalchemy/migration.py new file mode 100644 index 0000000..f728ae2 --- /dev/null +++ b/cerberus/openstack/common/db/sqlalchemy/migration.py @@ -0,0 +1,278 @@ +# coding: utf-8 +# +# Copyright (c) 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# Base on code in migrate/changeset/databases/sqlite.py which is under +# the following license: +# +# The MIT License +# +# Copyright (c) 2009 Evan Rosson, Jan Dittberner, Domen Kožar +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import os +import re + +from migrate.changeset import ansisql +from migrate.changeset.databases import sqlite +from migrate import exceptions as versioning_exceptions +from migrate.versioning import api as versioning_api +from migrate.versioning.repository import Repository +import sqlalchemy +from sqlalchemy.schema import UniqueConstraint + +from cerberus.openstack.common.db import exception +from cerberus.openstack.common.gettextutils import _ + + +def _get_unique_constraints(self, table): + """Retrieve information about existing unique constraints of the table + + This feature is needed for _recreate_table() to work properly. + Unfortunately, it's not available in sqlalchemy 0.7.x/0.8.x. + + """ + + data = table.metadata.bind.execute( + """SELECT sql + FROM sqlite_master + WHERE + type='table' AND + name=:table_name""", + table_name=table.name + ).fetchone()[0] + + UNIQUE_PATTERN = "CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)" + return [ + UniqueConstraint( + *[getattr(table.columns, c.strip(' "')) for c in cols.split(",")], + name=name + ) + for name, cols in re.findall(UNIQUE_PATTERN, data) + ] + + +def _recreate_table(self, table, column=None, delta=None, omit_uniques=None): + """Recreate the table properly + + Unlike the corresponding original method of sqlalchemy-migrate this one + doesn't drop existing unique constraints when creating a new one. + + """ + + table_name = self.preparer.format_table(table) + + # we remove all indexes so as not to have + # problems during copy and re-create + for index in table.indexes: + index.drop() + + # reflect existing unique constraints + for uc in self._get_unique_constraints(table): + table.append_constraint(uc) + # omit given unique constraints when creating a new table if required + table.constraints = set([ + cons for cons in table.constraints + if omit_uniques is None or cons.name not in omit_uniques + ]) + + self.append('ALTER TABLE %s RENAME TO migration_tmp' % table_name) + self.execute() + + insertion_string = self._modify_table(table, column, delta) + + table.create(bind=self.connection) + self.append(insertion_string % {'table_name': table_name}) + self.execute() + self.append('DROP TABLE migration_tmp') + self.execute() + + +def _visit_migrate_unique_constraint(self, *p, **k): + """Drop the given unique constraint + + The corresponding original method of sqlalchemy-migrate just + raises NotImplemented error + + """ + + self.recreate_table(p[0].table, omit_uniques=[p[0].name]) + + +def patch_migrate(): + """A workaround for SQLite's inability to alter things + + SQLite abilities to alter tables are very limited (please read + http://www.sqlite.org/lang_altertable.html for more details). + E. g. one can't drop a column or a constraint in SQLite. The + workaround for this is to recreate the original table omitting + the corresponding constraint (or column). + + sqlalchemy-migrate library has recreate_table() method that + implements this workaround, but it does it wrong: + + - information about unique constraints of a table + is not retrieved. 
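`_get_unique_constraints()` above parses constraint definitions straight out of the DDL that SQLite stores in `sqlite_master`. A quick demonstration of `UNIQUE_PATTERN` against a sample statement (shown here as a raw string, since `\w` in a plain literal is an invalid escape on newer Pythons):

    import re

    UNIQUE_PATTERN = r"CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)"
    ddl = ('CREATE TABLE plugin_info ('
           'id INTEGER NOT NULL, uuid VARCHAR(255), name VARCHAR(255), '
           'CONSTRAINT uniq_plugin0uuid UNIQUE (uuid))')

    assert re.findall(UNIQUE_PATTERN, ddl) == [('uniq_plugin0uuid', 'uuid')]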
So if you have a table with one + unique constraint and a migration adding another one + you will end up with a table that has only the + latter unique constraint, and the former will be lost + + - dropping of unique constraints is not supported at all + + The proper way to fix this is to provide a pull-request to + sqlalchemy-migrate, but the project seems to be dead. So we + can go on with monkey-patching of the lib at least for now. + + """ + + # this patch is needed to ensure that recreate_table() doesn't drop + # existing unique constraints of the table when creating a new one + helper_cls = sqlite.SQLiteHelper + helper_cls.recreate_table = _recreate_table + helper_cls._get_unique_constraints = _get_unique_constraints + + # this patch is needed to be able to drop existing unique constraints + constraint_cls = sqlite.SQLiteConstraintDropper + constraint_cls.visit_migrate_unique_constraint = \ + _visit_migrate_unique_constraint + constraint_cls.__bases__ = (ansisql.ANSIColumnDropper, + sqlite.SQLiteConstraintGenerator) + + +def db_sync(engine, abs_path, version=None, init_version=0, sanity_check=True): + """Upgrade or downgrade a database. + + Function runs the upgrade() or downgrade() functions in change scripts. + + :param engine: SQLAlchemy engine instance for a given database + :param abs_path: Absolute path to migrate repository. + :param version: Database will upgrade/downgrade until this version. + If None - database will update to the latest + available version. + :param init_version: Initial database version + :param sanity_check: Require schema sanity checking for all tables + """ + + if version is not None: + try: + version = int(version) + except ValueError: + raise exception.DbMigrationError( + message=_("version should be an integer")) + + current_version = db_version(engine, abs_path, init_version) + repository = _find_migrate_repo(abs_path) + if sanity_check: + _db_schema_sanity_check(engine) + if version is None or version > current_version: + return versioning_api.upgrade(engine, repository, version) + else: + return versioning_api.downgrade(engine, repository, + version) + + +def _db_schema_sanity_check(engine): + """Ensure all database tables were created with required parameters. + + :param engine: SQLAlchemy engine instance for a given database + + """ + + if engine.name == 'mysql': + onlyutf8_sql = ('SELECT TABLE_NAME,TABLE_COLLATION ' + 'from information_schema.TABLES ' + 'where TABLE_SCHEMA=%s and ' + 'TABLE_COLLATION NOT LIKE "%%utf8%%"') + + # NOTE(morganfainberg): exclude the sqlalchemy-migrate and alembic + # versioning tables from the tables we need to verify utf8 status on. + # Non-standard table names are not supported. + EXCLUDED_TABLES = ['migrate_version', 'alembic_version'] + + table_names = [res[0] for res in + engine.execute(onlyutf8_sql, engine.url.database) if + res[0].lower() not in EXCLUDED_TABLES] + + if len(table_names) > 0: + raise ValueError(_('Tables "%s" have non utf8 collation, ' + 'please make sure all tables are CHARSET=utf8' + ) % ','.join(table_names)) + + +def db_version(engine, abs_path, init_version): + """Show the current version of the repository. 
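A usage sketch for `db_sync()` above (the repository path is hypothetical). `version=None` upgrades to the latest change script; an integer below the current version routes to `downgrade()`:

    import sqlalchemy

    from cerberus.openstack.common.db.sqlalchemy import migration

    engine = sqlalchemy.create_engine('sqlite:///cerberus.sqlite')
    repo = '/opt/cerberus/db/sqlalchemy/migrate_repo'  # hypothetical

    migration.db_sync(engine, repo)             # up to the latest version
    migration.db_sync(engine, repo, version=5)  # up or down to version 5
    # migration.db_sync(engine, repo, version='abc') -> DbMigrationError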
+ + :param engine: SQLAlchemy engine instance for a given database + :param abs_path: Absolute path to migrate repository + :param version: Initial database version + """ + repository = _find_migrate_repo(abs_path) + try: + return versioning_api.db_version(engine, repository) + except versioning_exceptions.DatabaseNotControlledError: + meta = sqlalchemy.MetaData() + meta.reflect(bind=engine) + tables = meta.tables + if len(tables) == 0 or 'alembic_version' in tables: + db_version_control(engine, abs_path, version=init_version) + return versioning_api.db_version(engine, repository) + else: + raise exception.DbMigrationError( + message=_( + "The database is not under version control, but has " + "tables. Please stamp the current version of the schema " + "manually.")) + + +def db_version_control(engine, abs_path, version=None): + """Mark a database as under this repository's version control. + + Once a database is under version control, schema changes should + only be done via change scripts in this repository. + + :param engine: SQLAlchemy engine instance for a given database + :param abs_path: Absolute path to migrate repository + :param version: Initial database version + """ + repository = _find_migrate_repo(abs_path) + versioning_api.version_control(engine, repository, version) + return version + + +def _find_migrate_repo(abs_path): + """Get the project's change script repository + + :param abs_path: Absolute path to migrate repository + """ + if not os.path.exists(abs_path): + raise exception.DbMigrationError("Path %s not found" % abs_path) + return Repository(abs_path) diff --git a/cerberus/openstack/common/db/sqlalchemy/migration_cli/__init__.py b/cerberus/openstack/common/db/sqlalchemy/migration_cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cerberus/openstack/common/db/sqlalchemy/migration_cli/ext_alembic.py b/cerberus/openstack/common/db/sqlalchemy/migration_cli/ext_alembic.py new file mode 100644 index 0000000..039ed47 --- /dev/null +++ b/cerberus/openstack/common/db/sqlalchemy/migration_cli/ext_alembic.py @@ -0,0 +1,78 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os + +import alembic +from alembic import config as alembic_config +import alembic.migration as alembic_migration + +from cerberus.openstack.common.db.sqlalchemy.migration_cli import ext_base +from cerberus.openstack.common.db.sqlalchemy import session as db_session + + +class AlembicExtension(ext_base.MigrationExtensionBase): + + order = 2 + + @property + def enabled(self): + return os.path.exists(self.alembic_ini_path) + + def __init__(self, migration_config): + """Extension to provide alembic features. 
+ + :param migration_config: Stores specific configuration for migrations + :type migration_config: dict + """ + self.alembic_ini_path = migration_config.get('alembic_ini_path', '') + self.config = alembic_config.Config(self.alembic_ini_path) + # option should be used if script is not in default directory + repo_path = migration_config.get('alembic_repo_path') + if repo_path: + self.config.set_main_option('script_location', repo_path) + self.db_url = migration_config['db_url'] + + def upgrade(self, version): + return alembic.command.upgrade(self.config, version or 'head') + + def downgrade(self, version): + if isinstance(version, int) or version is None or version.isdigit(): + version = 'base' + return alembic.command.downgrade(self.config, version) + + def version(self): + engine = db_session.create_engine(self.db_url) + with engine.connect() as conn: + context = alembic_migration.MigrationContext.configure(conn) + return context.get_current_revision() + + def revision(self, message='', autogenerate=False): + """Creates template for migration. + + :param message: Text that will be used for migration title + :type message: string + :param autogenerate: If True - generates diff based on current database + state + :type autogenerate: bool + """ + return alembic.command.revision(self.config, message=message, + autogenerate=autogenerate) + + def stamp(self, revision): + """Stamps database with provided revision. + + :param revision: Should match one from repository or head - to stamp + database with most recent revision + :type revision: string + """ + return alembic.command.stamp(self.config, revision=revision) diff --git a/cerberus/openstack/common/db/sqlalchemy/migration_cli/ext_base.py b/cerberus/openstack/common/db/sqlalchemy/migration_cli/ext_base.py new file mode 100644 index 0000000..271cd0a --- /dev/null +++ b/cerberus/openstack/common/db/sqlalchemy/migration_cli/ext_base.py @@ -0,0 +1,79 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import abc + +import six + + +@six.add_metaclass(abc.ABCMeta) +class MigrationExtensionBase(object): + + #used to sort migration in logical order + order = 0 + + @property + def enabled(self): + """Used for availability verification of a plugin. + + :rtype: bool + """ + return False + + @abc.abstractmethod + def upgrade(self, version): + """Used for upgrading database. + + :param version: Desired database version + :type version: string + """ + + @abc.abstractmethod + def downgrade(self, version): + """Used for downgrading database. + + :param version: Desired database version + :type version: string + """ + + @abc.abstractmethod + def version(self): + """Current database version. + + :returns: Databse version + :rtype: string + """ + + def revision(self, *args, **kwargs): + """Used to generate migration script. + + In migration engines that support this feature, it should generate + new migration script. + + Accept arbitrary set of arguments. 
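The extension above is driven entirely by its `migration_config` dict; a sketch of a configuration that enables it (paths and URL illustrative):

    from cerberus.openstack.common.db.sqlalchemy.migration_cli import (
        ext_alembic)

    migration_config = {
        'alembic_ini_path': '/opt/cerberus/alembic.ini',  # enables plugin
        'alembic_repo_path': '/opt/cerberus/alembic',     # optional override
        'db_url': 'mysql://cerberus:secret@localhost/cerberus',
    }

    ext = ext_alembic.AlembicExtension(migration_config)
    if ext.enabled:            # True only if alembic.ini exists
        ext.upgrade(None)      # None is normalized to 'head'
        print(ext.version())   # current revision, or None if unstamped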
+ """ + raise NotImplementedError() + + def stamp(self, *args, **kwargs): + """Stamps database based on plugin features. + + Accept arbitrary set of arguments. + """ + raise NotImplementedError() + + def __cmp__(self, other): + """Used for definition of plugin order. + + :param other: MigrationExtensionBase instance + :rtype: bool + """ + return self.order > other.order diff --git a/cerberus/openstack/common/db/sqlalchemy/migration_cli/ext_migrate.py b/cerberus/openstack/common/db/sqlalchemy/migration_cli/ext_migrate.py new file mode 100644 index 0000000..4758c4f --- /dev/null +++ b/cerberus/openstack/common/db/sqlalchemy/migration_cli/ext_migrate.py @@ -0,0 +1,69 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import logging +import os + +from cerberus.openstack.common.db.sqlalchemy import migration +from cerberus.openstack.common.db.sqlalchemy.migration_cli import ext_base +from cerberus.openstack.common.db.sqlalchemy import session as db_session +from cerberus.openstack.common.gettextutils import _LE + + +LOG = logging.getLogger(__name__) + + +class MigrateExtension(ext_base.MigrationExtensionBase): + """Extension to provide sqlalchemy-migrate features. + + :param migration_config: Stores specific configuration for migrations + :type migration_config: dict + """ + + order = 1 + + def __init__(self, migration_config): + self.repository = migration_config.get('migration_repo_path', '') + self.init_version = migration_config.get('init_version', 0) + self.db_url = migration_config['db_url'] + self.engine = db_session.create_engine(self.db_url) + + @property + def enabled(self): + return os.path.exists(self.repository) + + def upgrade(self, version): + version = None if version == 'head' else version + return migration.db_sync( + self.engine, self.repository, version, + init_version=self.init_version) + + def downgrade(self, version): + try: + #version for migrate should be valid int - else skip + if version in ('base', None): + version = self.init_version + version = int(version) + return migration.db_sync( + self.engine, self.repository, version, + init_version=self.init_version) + except ValueError: + LOG.error( + _LE('Migration number for migrate plugin must be valid ' + 'integer or empty, if you want to downgrade ' + 'to initial state') + ) + raise + + def version(self): + return migration.db_version( + self.engine, self.repository, init_version=self.init_version) diff --git a/cerberus/openstack/common/db/sqlalchemy/migration_cli/manager.py b/cerberus/openstack/common/db/sqlalchemy/migration_cli/manager.py new file mode 100644 index 0000000..1184293 --- /dev/null +++ b/cerberus/openstack/common/db/sqlalchemy/migration_cli/manager.py @@ -0,0 +1,71 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
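`order` and `__cmp__()` above decide plugin precedence: sqlalchemy-migrate (order 1) runs before alembic (order 2). Note that Python 3 ignores `__cmp__`, so sorting extension instances directly relies on it only under Python 2; an explicit key is the portable spelling, sketched here with a hypothetical stub:

    from cerberus.openstack.common.db.sqlalchemy.migration_cli import ext_base

    class _StubExtension(ext_base.MigrationExtensionBase):
        def __init__(self, order):
            self.order = order

        def upgrade(self, version):
            pass

        def downgrade(self, version):
            pass

        def version(self):
            pass

    plugins = sorted([_StubExtension(2), _StubExtension(1)],
                     key=lambda p: p.order)
    assert [p.order for p in plugins] == [1, 2]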
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from stevedore import enabled + + +MIGRATION_NAMESPACE = 'cerberus.openstack.common.migration' + + +def check_plugin_enabled(ext): + """Used for EnabledExtensionManager""" + return ext.obj.enabled + + +class MigrationManager(object): + + def __init__(self, migration_config): + self._manager = enabled.EnabledExtensionManager( + MIGRATION_NAMESPACE, + check_plugin_enabled, + invoke_kwds={'migration_config': migration_config}, + invoke_on_load=True + ) + if not self._plugins: + raise ValueError('There must be at least one plugin active.') + + @property + def _plugins(self): + return sorted(ext.obj for ext in self._manager.extensions) + + def upgrade(self, revision): + """Upgrade database with all available backends.""" + results = [] + for plugin in self._plugins: + results.append(plugin.upgrade(revision)) + return results + + def downgrade(self, revision): + """Downgrade database with available backends.""" + #downgrading should be performed in reversed order + results = [] + for plugin in reversed(self._plugins): + results.append(plugin.downgrade(revision)) + return results + + def version(self): + """Return last version of db.""" + last = None + for plugin in self._plugins: + version = plugin.version() + if version: + last = version + return last + + def revision(self, message, autogenerate): + """Generate template or autogenerated revision.""" + #revision should be done only by last plugin + return self._plugins[-1].revision(message, autogenerate) + + def stamp(self, revision): + """Create stamp for a given revision.""" + return self._plugins[-1].stamp(revision) diff --git a/cerberus/openstack/common/db/sqlalchemy/models.py b/cerberus/openstack/common/db/sqlalchemy/models.py new file mode 100644 index 0000000..ccc77f6 --- /dev/null +++ b/cerberus/openstack/common/db/sqlalchemy/models.py @@ -0,0 +1,119 @@ +# Copyright (c) 2011 X.commerce, a business unit of eBay Inc. +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2011 Piston Cloud Computing, Inc. +# Copyright 2012 Cloudscaling Group, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +SQLAlchemy models. 
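Putting the pieces together, a sketch of driving `MigrationManager` above. It assumes both plugins are registered as `cerberus.openstack.common.migration` entry points in setup.cfg; paths are illustrative:

    from cerberus.openstack.common.db.sqlalchemy.migration_cli import manager

    config = {
        'db_url': 'sqlite:///cerberus.sqlite',
        'migration_repo_path': '/opt/cerberus/migrate_repo',  # migrate
        'alembic_ini_path': '/opt/cerberus/alembic.ini',      # alembic
    }

    mgr = manager.MigrationManager(config)
    mgr.upgrade('head')      # migrate runs first (order 1), then alembic
    print(mgr.version())     # version from the last plugin that reports one
    mgr.downgrade('base')    # plugins run in reverse order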
+""" + +import six + +from sqlalchemy import Column, Integer +from sqlalchemy import DateTime +from sqlalchemy.orm import object_mapper + +from cerberus.openstack.common import timeutils + + +class ModelBase(six.Iterator): + """Base class for models.""" + __table_initialized__ = False + + def save(self, session): + """Save this object.""" + + # NOTE(boris-42): This part of code should be look like: + # session.add(self) + # session.flush() + # But there is a bug in sqlalchemy and eventlet that + # raises NoneType exception if there is no running + # transaction and rollback is called. As long as + # sqlalchemy has this bug we have to create transaction + # explicitly. + with session.begin(subtransactions=True): + session.add(self) + session.flush() + + def __setitem__(self, key, value): + setattr(self, key, value) + + def __getitem__(self, key): + return getattr(self, key) + + def get(self, key, default=None): + return getattr(self, key, default) + + @property + def _extra_keys(self): + """Specifies custom fields + + Subclasses can override this property to return a list + of custom fields that should be included in their dict + representation. + + For reference check tests/db/sqlalchemy/test_models.py + """ + return [] + + def __iter__(self): + columns = dict(object_mapper(self).columns).keys() + # NOTE(russellb): Allow models to specify other keys that can be looked + # up, beyond the actual db columns. An example would be the 'name' + # property for an Instance. + columns.extend(self._extra_keys) + self._i = iter(columns) + return self + + # In Python 3, __next__() has replaced next(). + def __next__(self): + n = six.advance_iterator(self._i) + return n, getattr(self, n) + + def next(self): + return self.__next__() + + def update(self, values): + """Make the model object behave like a dict.""" + for k, v in six.iteritems(values): + setattr(self, k, v) + + def iteritems(self): + """Make the model object behave like a dict. + + Includes attributes from joins. + """ + local = dict(self) + joined = dict([(k, v) for k, v in six.iteritems(self.__dict__) + if not k[0] == '_']) + local.update(joined) + return six.iteritems(local) + + +class TimestampMixin(object): + created_at = Column(DateTime, default=lambda: timeutils.utcnow()) + updated_at = Column(DateTime, onupdate=lambda: timeutils.utcnow()) + + +class SoftDeleteMixin(object): + deleted_at = Column(DateTime) + deleted = Column(Integer, default=0) + + def soft_delete(self, session): + """Mark this object as deleted.""" + self.deleted = self.id + self.deleted_at = timeutils.utcnow() + self.save(session=session) diff --git a/cerberus/openstack/common/db/sqlalchemy/provision.py b/cerberus/openstack/common/db/sqlalchemy/provision.py new file mode 100644 index 0000000..4a29e70 --- /dev/null +++ b/cerberus/openstack/common/db/sqlalchemy/provision.py @@ -0,0 +1,157 @@ +# Copyright 2013 Mirantis.inc +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
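A sketch of a concrete model built from the mixins above; the declarative base and table are hypothetical. `SoftDeleteMixin` sets `deleted` to the row's `id` rather than a flag so a composite unique constraint such as `(name, deleted)` keeps working: live rows all carry `deleted == 0`, while each soft-deleted row gets a distinct value:

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base

    from cerberus.openstack.common.db.sqlalchemy import models

    Base = declarative_base()

    class Plugin(Base, models.ModelBase, models.TimestampMixin,
                 models.SoftDeleteMixin):
        __tablename__ = 'plugins'
        id = Column(Integer, primary_key=True)
        name = Column(String(255))

    # plugin.soft_delete(session) sets deleted_at to utcnow() and
    # deleted to plugin.id, then saves within a transaction.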
+ +"""Provision test environment for specific DB backends""" + +import argparse +import logging +import os +import random +import string + +from six import moves +import sqlalchemy + +from cerberus.openstack.common.db import exception as exc + + +LOG = logging.getLogger(__name__) + + +def get_engine(uri): + """Engine creation + + Call the function without arguments to get admin connection. Admin + connection required to create temporary user and database for each + particular test. Otherwise use existing connection to recreate connection + to the temporary database. + """ + return sqlalchemy.create_engine(uri, poolclass=sqlalchemy.pool.NullPool) + + +def _execute_sql(engine, sql, driver): + """Initialize connection, execute sql query and close it.""" + try: + with engine.connect() as conn: + if driver == 'postgresql': + conn.connection.set_isolation_level(0) + for s in sql: + conn.execute(s) + except sqlalchemy.exc.OperationalError: + msg = ('%s does not match database admin ' + 'credentials or database does not exist.') + LOG.exception(msg % engine.url) + raise exc.DBConnectionError(msg % engine.url) + + +def create_database(engine): + """Provide temporary user and database for each particular test.""" + driver = engine.name + + auth = { + 'database': ''.join(random.choice(string.ascii_lowercase) + for i in moves.range(10)), + 'user': engine.url.username, + 'passwd': engine.url.password, + } + + sqls = [ + "drop database if exists %(database)s;", + "create database %(database)s;" + ] + + if driver == 'sqlite': + return 'sqlite:////tmp/%s' % auth['database'] + elif driver in ['mysql', 'postgresql']: + sql_query = map(lambda x: x % auth, sqls) + _execute_sql(engine, sql_query, driver) + else: + raise ValueError('Unsupported RDBMS %s' % driver) + + params = auth.copy() + params['backend'] = driver + return "%(backend)s://%(user)s:%(passwd)s@localhost/%(database)s" % params + + +def drop_database(admin_engine, current_uri): + """Drop temporary database and user after each particular test.""" + + engine = get_engine(current_uri) + driver = engine.name + auth = {'database': engine.url.database, 'user': engine.url.username} + + if driver == 'sqlite': + try: + os.remove(auth['database']) + except OSError: + pass + elif driver in ['mysql', 'postgresql']: + sql = "drop database if exists %(database)s;" + _execute_sql(admin_engine, [sql % auth], driver) + else: + raise ValueError('Unsupported RDBMS %s' % driver) + + +def main(): + """Controller to handle commands + + ::create: Create test user and database with random names. + ::drop: Drop user and database created by previous command. + """ + parser = argparse.ArgumentParser( + description='Controller to handle database creation and dropping' + ' commands.', + epilog='Under normal circumstances is not used directly.' 
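A sketch of using the provisioning helpers above directly (admin URI illustrative). `create_database()` returns the URI of a freshly created throw-away database, which `drop_database()` later removes along with its data:

    from cerberus.openstack.common.db.sqlalchemy import provision

    admin = provision.get_engine('mysql://root:secret@localhost/')
    test_uri = provision.create_database(admin)  # random 10-letter db name
    try:
        engine = provision.get_engine(test_uri)
        # ... exercise the temporary database ...
    finally:
        provision.drop_database(admin, test_uri)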
+ ' Used in .testr.conf to automate test database creation' + ' and dropping processes.') + subparsers = parser.add_subparsers( + help='Subcommands to manipulate temporary test databases.') + + create = subparsers.add_parser( + 'create', + help='Create temporary test ' + 'databases and users.') + create.set_defaults(which='create') + create.add_argument( + 'instances_count', + type=int, + help='Number of databases to create.') + + drop = subparsers.add_parser( + 'drop', + help='Drop temporary test databases and users.') + drop.set_defaults(which='drop') + drop.add_argument( + 'instances', + nargs='+', + help='List of databases uri to be dropped.') + + args = parser.parse_args() + + connection_string = os.getenv('OS_TEST_DBAPI_ADMIN_CONNECTION', + 'sqlite://') + engine = get_engine(connection_string) + which = args.which + + if which == "create": + for i in range(int(args.instances_count)): + print(create_database(engine)) + elif which == "drop": + for db in args.instances: + drop_database(engine, db) + + +if __name__ == "__main__": + main() diff --git a/cerberus/openstack/common/db/sqlalchemy/session.py b/cerberus/openstack/common/db/sqlalchemy/session.py new file mode 100644 index 0000000..7a0324a --- /dev/null +++ b/cerberus/openstack/common/db/sqlalchemy/session.py @@ -0,0 +1,933 @@ +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Session Handling for SQLAlchemy backend. + +Recommended ways to use sessions within this framework: + +* Don't use them explicitly; this is like running with ``AUTOCOMMIT=1``. + `model_query()` will implicitly use a session when called without one + supplied. This is the ideal situation because it will allow queries + to be automatically retried if the database connection is interrupted. + + .. note:: Automatic retry will be enabled in a future patch. + + It is generally fine to issue several queries in a row like this. Even though + they may be run in separate transactions and/or separate sessions, each one + will see the data from the prior calls. If needed, undo- or rollback-like + functionality should be handled at a logical level. For an example, look at + the code around quotas and `reservation_rollback()`. + + Examples: + + .. code:: python + + def get_foo(context, foo): + return (model_query(context, models.Foo). + filter_by(foo=foo). + first()) + + def update_foo(context, id, newfoo): + (model_query(context, models.Foo). + filter_by(id=id). + update({'foo': newfoo})) + + def create_foo(context, values): + foo_ref = models.Foo() + foo_ref.update(values) + foo_ref.save() + return foo_ref + + +* Within the scope of a single method, keep all the reads and writes within + the context managed by a single session. In this way, the session's + `__exit__` handler will take care of calling `flush()` and `commit()` for + you. 
If using this approach, you should not explicitly call `flush()` or
+  `commit()`. Any error within the context of the session will cause the
+  session to emit a `ROLLBACK`. Database errors like `IntegrityError` will be
+  raised in `session`'s `__exit__` handler, and any try/except within the
+  context managed by `session` will not be triggered. Catching other
+  non-database errors in the session will not trigger the ROLLBACK, so
+  exception handlers should always be outside the session, unless the
+  developer wants to do a partial commit on purpose. If the connection is
+  dropped before this is possible, the database will implicitly roll back the
+  transaction.
+
+  .. note:: Statements in the session scope will not be automatically retried.
+
+  If you create models within the session, they need to be added, but you
+  do not need to call `model.save()`:
+
+  .. code:: python
+
+    def create_many_foo(context, foos):
+        session = sessionmaker()
+        with session.begin():
+            for foo in foos:
+                foo_ref = models.Foo()
+                foo_ref.update(foo)
+                session.add(foo_ref)
+
+    def update_bar(context, foo_id, newbar):
+        session = sessionmaker()
+        with session.begin():
+            foo_ref = (model_query(context, models.Foo, session).
+                       filter_by(id=foo_id).
+                       first())
+            (model_query(context, models.Bar, session).
+             filter_by(id=foo_ref['bar_id']).
+             update({'bar': newbar}))
+
+  .. note:: `update_bar` is a trivially simple example of using
+    ``with session.begin``. While `create_many_foo` is a good example of
+    when a transaction is needed, it is always best to use as few queries as
+    possible.
+
+  The two queries in `update_bar` can be better expressed using a single query
+  which avoids the need for an explicit transaction. It can be expressed like
+  so:
+
+  .. code:: python
+
+    def update_bar(context, foo_id, newbar):
+        subq = (model_query(context, models.Foo.id).
+                filter_by(id=foo_id).
+                limit(1).
+                subquery())
+        (model_query(context, models.Bar).
+         filter_by(id=subq.as_scalar()).
+         update({'bar': newbar}))
+
+  For reference, this emits approximately the following SQL statement:
+
+  .. code:: sql
+
+    UPDATE bar SET bar = ${newbar}
+        WHERE id=(SELECT bar_id FROM foo WHERE id = ${foo_id} LIMIT 1);
+
+  .. note:: `create_duplicate_foo` is a trivially simple example of catching an
+    exception while using ``with session.begin``. Here we create two duplicate
+    instances with the same primary key; the exception must be caught outside
+    the context managed by a single session:
+
+  .. code:: python
+
+    def create_duplicate_foo(context):
+        foo1 = models.Foo()
+        foo2 = models.Foo()
+        foo1.id = foo2.id = 1
+        session = sessionmaker()
+        try:
+            with session.begin():
+                session.add(foo1)
+                session.add(foo2)
+        except exception.DBDuplicateEntry as e:
+            handle_error(e)
+
+* Passing an active session between methods. Sessions should only be passed
+  to private methods. The private method must use a subtransaction; otherwise
+  SQLAlchemy will throw an error when you call `session.begin()` on an existing
+  transaction. Public methods should not accept a session parameter and should
+  not be involved in sessions within the caller's scope.
+
+  Note that this incurs more overhead in SQLAlchemy than the approaches above
+  due to nesting transactions, and it is not possible to implicitly retry
+  failed database operations when using this approach.
+
+  This also makes code somewhat more difficult to read and debug, because a
+  single database transaction spans more than one method. Error handling
+  becomes less clear in this situation.
When this is needed for code clarity,
+  it should be clearly documented.
+
+  .. code:: python
+
+    def myfunc(foo):
+        session = sessionmaker()
+        with session.begin():
+            # do some database things
+            bar = _private_func(foo, session)
+        return bar
+
+    def _private_func(foo, session=None):
+        if not session:
+            session = sessionmaker()
+        with session.begin(subtransactions=True):
+            # do some other database things
+        return bar
+
+
+There are some things which it is best to avoid:
+
+* Don't keep a transaction open any longer than necessary.
+
+  This means that your ``with session.begin()`` block should be as short
+  as possible, while still containing all the related calls for that
+  transaction.
+
+* Avoid ``with_lockmode('UPDATE')`` when possible.
+
+  In MySQL/InnoDB, when a ``SELECT ... FOR UPDATE`` query does not match
+  any rows, it will take a gap-lock. This is a form of write-lock on the
+  "gap" where no rows exist, and prevents any other writes to that space.
+  This can effectively prevent any INSERT into a table by locking the gap
+  at the end of the index. Similar problems will occur if the SELECT FOR UPDATE
+  has an overly broad WHERE clause, or doesn't properly use an index.
+
+  One idea proposed at ODS Fall '12 was to use a normal SELECT to test the
+  number of rows matching a query, and if only one row is returned,
+  then issue the SELECT FOR UPDATE.
+
+  The better long-term solution is to use
+  ``INSERT .. ON DUPLICATE KEY UPDATE``.
+  However, this cannot be done until the "deleted" columns are removed and
+  proper UNIQUE constraints are added to the tables.
+
+
+Enabling soft deletes:
+
+* To use/enable soft-deletes, the `SoftDeleteMixin` must be added
+  to your model class. For example:
+
+  .. code:: python
+
+      class NovaBase(models.SoftDeleteMixin, models.ModelBase):
+          pass
+
+
+Efficient use of soft deletes:
+
+* There are two possible ways to mark a record as deleted:
+  `model.soft_delete()` and `query.soft_delete()`.
+
+  The `model.soft_delete()` method works with a single already-fetched entry.
+  `query.soft_delete()` makes only one db request for all entries that
+  correspond to the query.
+
+* In almost all cases you should use `query.soft_delete()`. Some examples:
+
+  .. code:: python
+
+    def soft_delete_bar():
+        count = model_query(BarModel).find(some_condition).soft_delete()
+        if count == 0:
+            raise Exception("0 entries were soft deleted")
+
+    def complex_soft_delete_with_synchronization_bar(session=None):
+        if session is None:
+            session = sessionmaker()
+        with session.begin(subtransactions=True):
+            count = (model_query(BarModel).
+                     find(some_condition).
+                     soft_delete(synchronize_session=True))
+            # Here synchronize_session is required, because we
+            # don't know what is going on in outer session.
+            if count == 0:
+                raise Exception("0 entries were soft deleted")
+
+* There is only one situation where `model.soft_delete()` is appropriate: when
+  you fetch a single record, work with it, and mark it as deleted in the same
+  transaction.
+
+  .. code:: python
+
+    def soft_delete_bar_model():
+        session = sessionmaker()
+        with session.begin():
+            bar_ref = model_query(BarModel).find(some_condition).first()
+            # Work with bar_ref
+            bar_ref.soft_delete(session=session)
+
+  However, if you need to work with all entries that correspond to the query
+  and then soft delete them you should use the `query.soft_delete()` method:
+
+  .. code:: python
+
+    def soft_delete_multi_models():
+        session = sessionmaker()
+        with session.begin():
+            query = (model_query(BarModel, session=session).
+                     find(some_condition))
+            model_refs = query.all()
+            # Work with model_refs
+            query.soft_delete(synchronize_session=False)
+            # synchronize_session=False should be set if there is no outer
+            # session and these entries are not used after this.
+
+  When working with many rows, it is very important to use query.soft_delete,
+  which issues a single query. Using `model.soft_delete()`, as in the following
+  example, is very inefficient.
+
+  .. code:: python
+
+    for bar_ref in bar_refs:
+        bar_ref.soft_delete(session=session)
+    # This will produce count(bar_refs) db requests.
+
+"""
+
+import functools
+import logging
+import re
+import time
+
+import six
+from sqlalchemy import exc as sqla_exc
+from sqlalchemy.interfaces import PoolListener
+import sqlalchemy.orm
+from sqlalchemy.pool import NullPool, StaticPool
+from sqlalchemy.sql.expression import literal_column
+
+from cerberus.openstack.common.db import exception
+from cerberus.openstack.common.gettextutils import _LE, _LW
+from cerberus.openstack.common import timeutils
+
+
+LOG = logging.getLogger(__name__)
+
+
+class SqliteForeignKeysListener(PoolListener):
+    """Ensures that foreign key constraints are enforced in SQLite.
+
+    Foreign key constraints are disabled by default in SQLite, so they
+    are enabled here for every new database connection.
+    """
+    def connect(self, dbapi_con, con_record):
+        dbapi_con.execute('pragma foreign_keys=ON')
+
+
+# note(boris-42): In current versions of DB backends unique constraint
+# violation messages follow the structure:
+#
+# sqlite:
+# 1 column - (IntegrityError) column c1 is not unique
+# N columns - (IntegrityError) column c1, c2, ..., N are not unique
+#
+# sqlite since 3.7.16:
+# 1 column - (IntegrityError) UNIQUE constraint failed: tbl.k1
+#
+# N columns - (IntegrityError) UNIQUE constraint failed: tbl.k1, tbl.k2
+#
+# postgres:
+# 1 column - (IntegrityError) duplicate key value violates unique
+#               constraint "users_c1_key"
+# N columns - (IntegrityError) duplicate key value violates unique
+#               constraint "name_of_our_constraint"
+#
+# mysql:
+# 1 column - (IntegrityError) (1062, "Duplicate entry 'value_of_c1' for key
+#               'c1'")
+# N columns - (IntegrityError) (1062, "Duplicate entry 'values joined
+#               with -' for key 'name_of_our_constraint'")
+#
+# ibm_db_sa:
+# N columns - (IntegrityError) SQL0803N One or more values in the INSERT
+#                statement, UPDATE statement, or foreign key update caused by a
+#                DELETE statement are not valid because the primary key, unique
+#                constraint or unique index identified by "2" constrains table
+#                "NOVA.KEY_PAIRS" from having duplicate values for the index
+#                key.
+_DUP_KEY_RE_DB = {
+    "sqlite": (re.compile(r"^.*columns?([^)]+)(is|are)\s+not\s+unique$"),
+               re.compile(r"^.*UNIQUE\s+constraint\s+failed:\s+(.+)$")),
+    "postgresql": (re.compile(r"^.*duplicate\s+key.*\"([^\"]+)\"\s*\n.*$"),),
+    "mysql": (re.compile(r"^.*\(1062,.*'([^\']+)'\"\)$"),),
+    "ibm_db_sa": (re.compile(r"^.*SQL0803N.*$"),),
+}
+
+
+def _raise_if_duplicate_entry_error(integrity_error, engine_name):
+    """Raise exception if two entries are duplicated.
+
+    Raises DBDuplicateEntry if the integrity error wraps a unique
+    constraint violation.
+    """
+
+    def get_columns_from_uniq_cons_or_name(columns):
+        # note(vsergeyev): UniqueConstraint name convention: "uniq_t0c10c2"
+        #                  where `t` is the table name and `c1`, `c2`
+        #                  are the columns in the UniqueConstraint.
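+        #                  For example, "uniq_instances0host0name" parses
+        #                  to ["host", "name"].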
+ uniqbase = "uniq_" + if not columns.startswith(uniqbase): + if engine_name == "postgresql": + return [columns[columns.index("_") + 1:columns.rindex("_")]] + return [columns] + return columns[len(uniqbase):].split("0")[1:] + + if engine_name not in ("ibm_db_sa", "mysql", "sqlite", "postgresql"): + return + + # FIXME(johannes): The usage of the .message attribute has been + # deprecated since Python 2.6. However, the exceptions raised by + # SQLAlchemy can differ when using unicode() and accessing .message. + # An audit across all three supported engines will be necessary to + # ensure there are no regressions. + for pattern in _DUP_KEY_RE_DB[engine_name]: + match = pattern.match(integrity_error.message) + if match: + break + else: + return + + # NOTE(mriedem): The ibm_db_sa integrity error message doesn't provide the + # columns so we have to omit that from the DBDuplicateEntry error. + columns = '' + + if engine_name != 'ibm_db_sa': + columns = match.group(1) + + if engine_name == "sqlite": + columns = [c.split('.')[-1] for c in columns.strip().split(", ")] + else: + columns = get_columns_from_uniq_cons_or_name(columns) + raise exception.DBDuplicateEntry(columns, integrity_error) + + +# NOTE(comstud): In current versions of DB backends, Deadlock violation +# messages follow the structure: +# +# mysql: +# (OperationalError) (1213, 'Deadlock found when trying to get lock; try ' +# 'restarting transaction') +_DEADLOCK_RE_DB = { + "mysql": re.compile(r"^.*\(1213, 'Deadlock.*") +} + + +def _raise_if_deadlock_error(operational_error, engine_name): + """Raise exception on deadlock condition. + + Raise DBDeadlock exception if OperationalError contains a Deadlock + condition. + """ + re = _DEADLOCK_RE_DB.get(engine_name) + if re is None: + return + # FIXME(johannes): The usage of the .message attribute has been + # deprecated since Python 2.6. However, the exceptions raised by + # SQLAlchemy can differ when using unicode() and accessing .message. + # An audit across all three supported engines will be necessary to + # ensure there are no regressions. + m = re.match(operational_error.message) + if not m: + return + raise exception.DBDeadlock(operational_error) + + +def _wrap_db_error(f): + @functools.wraps(f) + def _wrap(self, *args, **kwargs): + try: + assert issubclass( + self.__class__, ( + sqlalchemy.orm.session.Session, SessionTransactionWrapper) + ), ('_wrap_db_error() can only be applied to methods of ' + 'subclasses of sqlalchemy.orm.session.Session or ' + ' SessionTransactionWrapper') + + return f(self, *args, **kwargs) + except UnicodeEncodeError: + raise exception.DBInvalidUnicodeParameter() + except sqla_exc.OperationalError as e: + _raise_if_db_connection_lost(e, self.bind) + _raise_if_deadlock_error(e, self.bind.dialect.name) + # NOTE(comstud): A lot of code is checking for OperationalError + # so let's not wrap it for now. + raise + # note(boris-42): We should catch unique constraint violation and + # wrap it by our own DBDuplicateEntry exception. Unique constraint + # violation is wrapped by IntegrityError. + except sqla_exc.IntegrityError as e: + # note(boris-42): SqlAlchemy doesn't unify errors from different + # DBs so we must do this. Also in some tables (for example + # instance_types) there are more than one unique constraint. This + # means we should get names of columns, which values violate + # unique constraint, from error message. 
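+            # For example, a mysql IntegrityError whose message ends with
+            # "Duplicate entry 'a-b' for key 'uniq_tbl0c10c2'" is re-raised
+            # below as DBDuplicateEntry(columns=['c1', 'c2']).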
+ _raise_if_duplicate_entry_error(e, self.bind.dialect.name) + raise exception.DBError(e) + except exception.DBError: + # note(zzzeek) - if _wrap_db_error is applied to nested functions, + # ensure an existing DBError is propagated outwards + raise + except Exception as e: + LOG.exception(_LE('DB exception wrapped.')) + raise exception.DBError(e) + return _wrap + + +def _synchronous_switch_listener(dbapi_conn, connection_rec): + """Switch sqlite connections to non-synchronous mode.""" + dbapi_conn.execute("PRAGMA synchronous = OFF") + + +def _add_regexp_listener(dbapi_con, con_record): + """Add REGEXP function to sqlite connections.""" + + def regexp(expr, item): + reg = re.compile(expr) + return reg.search(six.text_type(item)) is not None + dbapi_con.create_function('regexp', 2, regexp) + + +def _thread_yield(dbapi_con, con_record): + """Ensure other greenthreads get a chance to be executed. + + If we use eventlet.monkey_patch(), eventlet.greenthread.sleep(0) will + execute instead of time.sleep(0). + Force a context switch. With common database backends (eg MySQLdb and + sqlite), there is no implicit yield caused by network I/O since they are + implemented by C libraries that eventlet cannot monkey patch. + """ + time.sleep(0) + + +def _ping_listener(engine, dbapi_conn, connection_rec, connection_proxy): + """Ensures that MySQL, PostgreSQL or DB2 connections are alive. + + Borrowed from: + http://groups.google.com/group/sqlalchemy/msg/a4ce563d802c929f + """ + cursor = dbapi_conn.cursor() + try: + ping_sql = 'select 1' + if engine.name == 'ibm_db_sa': + # DB2 requires a table expression + ping_sql = 'select 1 from (values (1)) AS t1' + cursor.execute(ping_sql) + except Exception as ex: + if engine.dialect.is_disconnect(ex, dbapi_conn, cursor): + msg = _LW('Database server has gone away: %s') % ex + LOG.warning(msg) + + # if the database server has gone away, all connections in the pool + # have become invalid and we can safely close all of them here, + # rather than waste time on checking of every single connection + engine.dispose() + + # this will be handled by SQLAlchemy and will force it to create + # a new connection and retry the original action + raise sqla_exc.DisconnectionError(msg) + else: + raise + + +def _set_session_sql_mode(dbapi_con, connection_rec, sql_mode=None): + """Set the sql_mode session variable. + + MySQL supports several server modes. The default is None, but sessions + may choose to enable server modes like TRADITIONAL, ANSI, + several STRICT_* modes and others. + + Note: passing in '' (empty string) for sql_mode clears + the SQL mode for the session, overriding a potentially set + server default. + """ + + cursor = dbapi_con.cursor() + cursor.execute("SET SESSION sql_mode = %s", [sql_mode]) + + +def _mysql_get_effective_sql_mode(engine): + """Returns the effective SQL mode for connections from the engine pool. + + Returns ``None`` if the mode isn't available, otherwise returns the mode. + + """ + # Get the real effective SQL mode. Even when unset by + # our own config, the server may still be operating in a specific + # SQL mode as set by the server configuration. + # Also note that the checkout listener will be called on execute to + # set the mode if it's registered. 
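+    # The result row is a (variable_name, value) pair, e.g.
+    # ('sql_mode', 'STRICT_TRANS_TABLES,NO_ENGINE_SUBSTITUTION'),
+    # so row[1] below is the effective mode string.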
+    row = engine.execute("SHOW VARIABLES LIKE 'sql_mode'").fetchone()
+    if row is None:
+        return
+    return row[1]
+
+
+def _mysql_check_effective_sql_mode(engine):
+    """Logs a message based on the effective SQL mode for MySQL connections."""
+    realmode = _mysql_get_effective_sql_mode(engine)
+
+    if realmode is None:
+        LOG.warning(_LW('Unable to detect effective SQL mode'))
+        return
+
+    LOG.debug('MySQL server mode set to %s', realmode)
+    # 'TRADITIONAL' mode enables several other modes, so
+    # we need a substring match here
+    if not ('TRADITIONAL' in realmode.upper() or
+            'STRICT_ALL_TABLES' in realmode.upper()):
+        LOG.warning(_LW("MySQL SQL mode is '%s', "
+                        "consider enabling TRADITIONAL or STRICT_ALL_TABLES"),
+                    realmode)
+
+
+def _mysql_set_mode_callback(engine, sql_mode):
+    if sql_mode is not None:
+        mode_callback = functools.partial(_set_session_sql_mode,
+                                          sql_mode=sql_mode)
+        sqlalchemy.event.listen(engine, 'connect', mode_callback)
+    _mysql_check_effective_sql_mode(engine)
+
+
+def _is_db_connection_error(args):
+    """Return True if the error is about failing to connect to the db."""
+    # NOTE(adam_g): This is currently MySQL specific and needs to be extended
+    #               to support Postgres and others.
+    # For DB2, the error code is -30081, which is returned while the
+    # database is not yet ready to accept connections.
+    conn_err_codes = ('2002', '2003', '2006', '2013', '-30081')
+    for err_code in conn_err_codes:
+        if args.find(err_code) != -1:
+            return True
+    return False
+
+
+def _raise_if_db_connection_lost(error, engine):
+    # NOTE(vsergeyev): The function is_disconnect(e, connection, cursor)
+    #                  expects a connection and a cursor, but when the DB is
+    #                  unavailable no connection can be created, so a
+    #                  reconnect would fail. Since is_disconnect() ignores
+    #                  these parameters anyway, it makes sense to pass None
+    #                  as placeholders instead of a connection and cursor.
+ if engine.dialect.is_disconnect(error, None, None): + raise exception.DBConnectionError(error) + + +def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None, + idle_timeout=3600, + connection_debug=0, max_pool_size=None, max_overflow=None, + pool_timeout=None, sqlite_synchronous=True, + connection_trace=False, max_retries=10, retry_interval=10): + """Return a new SQLAlchemy engine.""" + + connection_dict = sqlalchemy.engine.url.make_url(sql_connection) + + engine_args = { + "pool_recycle": idle_timeout, + 'convert_unicode': True, + } + + logger = logging.getLogger('sqlalchemy.engine') + + # Map SQL debug level to Python log level + if connection_debug >= 100: + logger.setLevel(logging.DEBUG) + elif connection_debug >= 50: + logger.setLevel(logging.INFO) + else: + logger.setLevel(logging.WARNING) + + if "sqlite" in connection_dict.drivername: + if sqlite_fk: + engine_args["listeners"] = [SqliteForeignKeysListener()] + engine_args["poolclass"] = NullPool + + if sql_connection == "sqlite://": + engine_args["poolclass"] = StaticPool + engine_args["connect_args"] = {'check_same_thread': False} + else: + if max_pool_size is not None: + engine_args['pool_size'] = max_pool_size + if max_overflow is not None: + engine_args['max_overflow'] = max_overflow + if pool_timeout is not None: + engine_args['pool_timeout'] = pool_timeout + + engine = sqlalchemy.create_engine(sql_connection, **engine_args) + + sqlalchemy.event.listen(engine, 'checkin', _thread_yield) + + if engine.name in ('ibm_db_sa', 'mysql', 'postgresql'): + ping_callback = functools.partial(_ping_listener, engine) + sqlalchemy.event.listen(engine, 'checkout', ping_callback) + if engine.name == 'mysql': + if mysql_sql_mode: + _mysql_set_mode_callback(engine, mysql_sql_mode) + elif 'sqlite' in connection_dict.drivername: + if not sqlite_synchronous: + sqlalchemy.event.listen(engine, 'connect', + _synchronous_switch_listener) + sqlalchemy.event.listen(engine, 'connect', _add_regexp_listener) + + if connection_trace and engine.dialect.dbapi.__name__ == 'MySQLdb': + _patch_mysqldb_with_stacktrace_comments() + + try: + engine.connect() + except sqla_exc.OperationalError as e: + if not _is_db_connection_error(e.args[0]): + raise + + remaining = max_retries + if remaining == -1: + remaining = 'infinite' + while True: + msg = _LW('SQL connection failed. 
%s attempts left.') + LOG.warning(msg % remaining) + if remaining != 'infinite': + remaining -= 1 + time.sleep(retry_interval) + try: + engine.connect() + break + except sqla_exc.OperationalError as e: + if (remaining != 'infinite' and remaining == 0) or \ + not _is_db_connection_error(e.args[0]): + raise + return engine + + +class Query(sqlalchemy.orm.query.Query): + """Subclass of sqlalchemy.query with soft_delete() method.""" + def soft_delete(self, synchronize_session='evaluate'): + return self.update({'deleted': literal_column('id'), + 'updated_at': literal_column('updated_at'), + 'deleted_at': timeutils.utcnow()}, + synchronize_session=synchronize_session) + + +class Session(sqlalchemy.orm.session.Session): + """Custom Session class to avoid SqlAlchemy Session monkey patching.""" + @_wrap_db_error + def query(self, *args, **kwargs): + return super(Session, self).query(*args, **kwargs) + + @_wrap_db_error + def flush(self, *args, **kwargs): + return super(Session, self).flush(*args, **kwargs) + + @_wrap_db_error + def execute(self, *args, **kwargs): + return super(Session, self).execute(*args, **kwargs) + + @_wrap_db_error + def commit(self, *args, **kwargs): + return super(Session, self).commit(*args, **kwargs) + + def begin(self, **kw): + trans = super(Session, self).begin(**kw) + trans.__class__ = SessionTransactionWrapper + return trans + + +class SessionTransactionWrapper(sqlalchemy.orm.session.SessionTransaction): + @property + def bind(self): + return self.session.bind + + @_wrap_db_error + def commit(self, *args, **kwargs): + return super(SessionTransactionWrapper, self).commit(*args, **kwargs) + + @_wrap_db_error + def rollback(self, *args, **kwargs): + return super(SessionTransactionWrapper, self).rollback(*args, **kwargs) + + +def get_maker(engine, autocommit=True, expire_on_commit=False): + """Return a SQLAlchemy sessionmaker using the given engine.""" + return sqlalchemy.orm.sessionmaker(bind=engine, + class_=Session, + autocommit=autocommit, + expire_on_commit=expire_on_commit, + query_cls=Query) + + +def _patch_mysqldb_with_stacktrace_comments(): + """Adds current stack trace as a comment in queries. + + Patches MySQLdb.cursors.BaseCursor._do_query. + """ + import MySQLdb.cursors + import traceback + + old_mysql_do_query = MySQLdb.cursors.BaseCursor._do_query + + def _do_query(self, q): + stack = '' + for filename, line, method, function in traceback.extract_stack(): + # exclude various common things from trace + if filename.endswith('session.py') and method == '_do_query': + continue + if filename.endswith('api.py') and method == 'wrapper': + continue + if filename.endswith('utils.py') and method == '_inner': + continue + if filename.endswith('exception.py') and method == '_wrap': + continue + # db/api is just a wrapper around db/sqlalchemy/api + if filename.endswith('db/api.py'): + continue + # only trace inside cerberus + index = filename.rfind('cerberus') + if index == -1: + continue + stack += "File:%s:%s Method:%s() Line:%s | " \ + % (filename[index:], line, method, function) + + # strip trailing " | " from stack + if stack: + stack = stack[:-3] + qq = "%s /* %s */" % (q, stack) + else: + qq = q + old_mysql_do_query(self, qq) + + setattr(MySQLdb.cursors.BaseCursor, '_do_query', _do_query) + + +class EngineFacade(object): + """A helper class for removing of global engine instances from cerberus.db. + + As a library, cerberus.db can't decide where to store/when to create engine + and sessionmaker instances, so this must be left for a target application. 
+ + On the other hand, in order to simplify the adoption of cerberus.db changes, + we'll provide a helper class, which creates engine and sessionmaker + on its instantiation and provides get_engine()/get_session() methods + that are compatible with corresponding utility functions that currently + exist in target projects, e.g. in Nova. + + engine/sessionmaker instances will still be global (and they are meant to + be global), but they will be stored in the app context, rather that in the + cerberus.db context. + + Note: using of this helper is completely optional and you are encouraged to + integrate engine/sessionmaker instances into your apps any way you like + (e.g. one might want to bind a session to a request context). Two important + things to remember: + + 1. An Engine instance is effectively a pool of DB connections, so it's + meant to be shared (and it's thread-safe). + 2. A Session instance is not meant to be shared and represents a DB + transactional context (i.e. it's not thread-safe). sessionmaker is + a factory of sessions. + + """ + + def __init__(self, sql_connection, + sqlite_fk=False, autocommit=True, + expire_on_commit=False, **kwargs): + """Initialize engine and sessionmaker instances. + + :param sqlite_fk: enable foreign keys in SQLite + :type sqlite_fk: bool + + :param autocommit: use autocommit mode for created Session instances + :type autocommit: bool + + :param expire_on_commit: expire session objects on commit + :type expire_on_commit: bool + + Keyword arguments: + + :keyword mysql_sql_mode: the SQL mode to be used for MySQL sessions. + (defaults to TRADITIONAL) + :keyword idle_timeout: timeout before idle sql connections are reaped + (defaults to 3600) + :keyword connection_debug: verbosity of SQL debugging information. + 0=None, 100=Everything (defaults to 0) + :keyword max_pool_size: maximum number of SQL connections to keep open + in a pool (defaults to SQLAlchemy settings) + :keyword max_overflow: if set, use this value for max_overflow with + sqlalchemy (defaults to SQLAlchemy settings) + :keyword pool_timeout: if set, use this value for pool_timeout with + sqlalchemy (defaults to SQLAlchemy settings) + :keyword sqlite_synchronous: if True, SQLite uses synchronous mode + (defaults to True) + :keyword connection_trace: add python stack traces to SQL as comment + strings (defaults to False) + :keyword max_retries: maximum db connection retries during startup. + (setting -1 implies an infinite retry count) + (defaults to 10) + :keyword retry_interval: interval between retries of opening a sql + connection (defaults to 10) + + """ + + super(EngineFacade, self).__init__() + + self._engine = create_engine( + sql_connection=sql_connection, + sqlite_fk=sqlite_fk, + mysql_sql_mode=kwargs.get('mysql_sql_mode', 'TRADITIONAL'), + idle_timeout=kwargs.get('idle_timeout', 3600), + connection_debug=kwargs.get('connection_debug', 0), + max_pool_size=kwargs.get('max_pool_size'), + max_overflow=kwargs.get('max_overflow'), + pool_timeout=kwargs.get('pool_timeout'), + sqlite_synchronous=kwargs.get('sqlite_synchronous', True), + connection_trace=kwargs.get('connection_trace', False), + max_retries=kwargs.get('max_retries', 10), + retry_interval=kwargs.get('retry_interval', 10)) + self._session_maker = get_maker( + engine=self._engine, + autocommit=autocommit, + expire_on_commit=expire_on_commit) + + def get_engine(self): + """Get the engine instance (note, that it's shared).""" + + return self._engine + + def get_session(self, **kwargs): + """Get a Session instance. 
+ + If passed, keyword arguments values override the ones used when the + sessionmaker instance was created. + + :keyword autocommit: use autocommit mode for created Session instances + :type autocommit: bool + + :keyword expire_on_commit: expire session objects on commit + :type expire_on_commit: bool + + """ + + for arg in kwargs: + if arg not in ('autocommit', 'expire_on_commit'): + del kwargs[arg] + + return self._session_maker(**kwargs) + + @classmethod + def from_config(cls, connection_string, conf, + sqlite_fk=False, autocommit=True, expire_on_commit=False): + """Initialize EngineFacade using oslo.config config instance options. + + :param connection_string: SQLAlchemy connection string + :type connection_string: string + + :param conf: oslo.config config instance + :type conf: oslo.config.cfg.ConfigOpts + + :param sqlite_fk: enable foreign keys in SQLite + :type sqlite_fk: bool + + :param autocommit: use autocommit mode for created Session instances + :type autocommit: bool + + :param expire_on_commit: expire session objects on commit + :type expire_on_commit: bool + + """ + + return cls(sql_connection=connection_string, + sqlite_fk=sqlite_fk, + autocommit=autocommit, + expire_on_commit=expire_on_commit, + **dict(conf.database.items())) diff --git a/cerberus/openstack/common/db/sqlalchemy/test_base.py b/cerberus/openstack/common/db/sqlalchemy/test_base.py new file mode 100644 index 0000000..199326a --- /dev/null +++ b/cerberus/openstack/common/db/sqlalchemy/test_base.py @@ -0,0 +1,153 @@ +# Copyright (c) 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import abc +import functools +import os + +import fixtures +import six + +from cerberus.openstack.common.db.sqlalchemy import session +from cerberus.openstack.common.db.sqlalchemy import utils +from cerberus.openstack.common.fixture import lockutils +from cerberus.openstack.common import test + + +class DbFixture(fixtures.Fixture): + """Basic database fixture. + + Allows to run tests on various db backends, such as SQLite, MySQL and + PostgreSQL. By default use sqlite backend. To override default backend + uri set env variable OS_TEST_DBAPI_CONNECTION with database admin + credentials for specific backend. + """ + + def _get_uri(self): + return os.getenv('OS_TEST_DBAPI_CONNECTION', 'sqlite://') + + def __init__(self, test): + super(DbFixture, self).__init__() + + self.test = test + + def setUp(self): + super(DbFixture, self).setUp() + + self.test.engine = session.create_engine(self._get_uri()) + self.test.sessionmaker = session.get_maker(self.test.engine) + self.addCleanup(self.test.engine.dispose) + + +class DbTestCase(test.BaseTestCase): + """Base class for testing of DB code. + + Using `DbFixture`. Intended to be the main database test case to use all + the tests on a given backend with user defined uri. Backend specific + tests should be decorated with `backend_specific` decorator. 
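+
+    For example (the test method name here is illustrative)::
+
+        class MyBackendTest(DbTestCase):
+
+            @backend_specific('mysql', 'postgresql')
+            def test_server_side_behaviour(self):
+                ...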
+ """ + + FIXTURE = DbFixture + + def setUp(self): + super(DbTestCase, self).setUp() + self.useFixture(self.FIXTURE(self)) + + +ALLOWED_DIALECTS = ['sqlite', 'mysql', 'postgresql'] + + +def backend_specific(*dialects): + """Decorator to skip backend specific tests on inappropriate engines. + + ::dialects: list of dialects names under which the test will be launched. + """ + def wrap(f): + @functools.wraps(f) + def ins_wrap(self): + if not set(dialects).issubset(ALLOWED_DIALECTS): + raise ValueError( + "Please use allowed dialects: %s" % ALLOWED_DIALECTS) + if self.engine.name not in dialects: + msg = ('The test "%s" can be run ' + 'only on %s. Current engine is %s.') + args = (f.__name__, ' '.join(dialects), self.engine.name) + self.skip(msg % args) + else: + return f(self) + return ins_wrap + return wrap + + +@six.add_metaclass(abc.ABCMeta) +class OpportunisticFixture(DbFixture): + """Base fixture to use default CI databases. + + The databases exist in OpenStack CI infrastructure. But for the + correct functioning in local environment the databases must be + created manually. + """ + + DRIVER = abc.abstractproperty(lambda: None) + DBNAME = PASSWORD = USERNAME = 'openstack_citest' + + def _get_uri(self): + return utils.get_connect_string(backend=self.DRIVER, + user=self.USERNAME, + passwd=self.PASSWORD, + database=self.DBNAME) + + +@six.add_metaclass(abc.ABCMeta) +class OpportunisticTestCase(DbTestCase): + """Base test case to use default CI databases. + + The subclasses of the test case are running only when openstack_citest + database is available otherwise a tests will be skipped. + """ + + FIXTURE = abc.abstractproperty(lambda: None) + + def setUp(self): + # TODO(bnemec): Remove this once infra is ready for + # https://review.openstack.org/#/c/74963/ to merge. + self.useFixture(lockutils.LockFixture('opportunistic-db')) + credentials = { + 'backend': self.FIXTURE.DRIVER, + 'user': self.FIXTURE.USERNAME, + 'passwd': self.FIXTURE.PASSWORD, + 'database': self.FIXTURE.DBNAME} + + if self.FIXTURE.DRIVER and not utils.is_backend_avail(**credentials): + msg = '%s backend is not available.' % self.FIXTURE.DRIVER + return self.skip(msg) + + super(OpportunisticTestCase, self).setUp() + + +class MySQLOpportunisticFixture(OpportunisticFixture): + DRIVER = 'mysql' + + +class PostgreSQLOpportunisticFixture(OpportunisticFixture): + DRIVER = 'postgresql' + + +class MySQLOpportunisticTestCase(OpportunisticTestCase): + FIXTURE = MySQLOpportunisticFixture + + +class PostgreSQLOpportunisticTestCase(OpportunisticTestCase): + FIXTURE = PostgreSQLOpportunisticFixture diff --git a/cerberus/openstack/common/db/sqlalchemy/test_migrations.py b/cerberus/openstack/common/db/sqlalchemy/test_migrations.py new file mode 100644 index 0000000..0fe0f12 --- /dev/null +++ b/cerberus/openstack/common/db/sqlalchemy/test_migrations.py @@ -0,0 +1,269 @@ +# Copyright 2010-2011 OpenStack Foundation +# Copyright 2012-2013 IBM Corp. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
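A backend-specific suite built on the opportunistic classes above needs only a subclass; a minimal sketch, with an illustrative test class name:

    from cerberus.openstack.common.db.sqlalchemy import test_base


    class TestFooMySQL(test_base.MySQLOpportunisticTestCase):
        def test_engine_is_mysql(self):
            # setUp() skips the test when the openstack_citest MySQL
            # database is not available.
            self.assertEqual('mysql', self.engine.name)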
+ +import functools +import logging +import os +import subprocess + +import lockfile +from six import moves +from six.moves.urllib import parse +import sqlalchemy +import sqlalchemy.exc + +from cerberus.openstack.common.db.sqlalchemy import utils +from cerberus.openstack.common.gettextutils import _LE +from cerberus.openstack.common import test + +LOG = logging.getLogger(__name__) + + +def _have_mysql(user, passwd, database): + present = os.environ.get('TEST_MYSQL_PRESENT') + if present is None: + return utils.is_backend_avail(backend='mysql', + user=user, + passwd=passwd, + database=database) + return present.lower() in ('', 'true') + + +def _have_postgresql(user, passwd, database): + present = os.environ.get('TEST_POSTGRESQL_PRESENT') + if present is None: + return utils.is_backend_avail(backend='postgres', + user=user, + passwd=passwd, + database=database) + return present.lower() in ('', 'true') + + +def _set_db_lock(lock_path=None, lock_prefix=None): + def decorator(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + try: + path = lock_path or os.environ.get("CERBERUS_LOCK_PATH") + lock = lockfile.FileLock(os.path.join(path, lock_prefix)) + with lock: + LOG.debug('Got lock "%s"' % f.__name__) + return f(*args, **kwargs) + finally: + LOG.debug('Lock released "%s"' % f.__name__) + return wrapper + return decorator + + +class BaseMigrationTestCase(test.BaseTestCase): + """Base class fort testing of migration utils.""" + + def __init__(self, *args, **kwargs): + super(BaseMigrationTestCase, self).__init__(*args, **kwargs) + + self.DEFAULT_CONFIG_FILE = os.path.join(os.path.dirname(__file__), + 'test_migrations.conf') + # Test machines can set the TEST_MIGRATIONS_CONF variable + # to override the location of the config file for migration testing + self.CONFIG_FILE_PATH = os.environ.get('TEST_MIGRATIONS_CONF', + self.DEFAULT_CONFIG_FILE) + self.test_databases = {} + self.migration_api = None + + def setUp(self): + super(BaseMigrationTestCase, self).setUp() + + # Load test databases from the config file. Only do this + # once. No need to re-run this on each test... + LOG.debug('config_path is %s' % self.CONFIG_FILE_PATH) + if os.path.exists(self.CONFIG_FILE_PATH): + cp = moves.configparser.RawConfigParser() + try: + cp.read(self.CONFIG_FILE_PATH) + defaults = cp.defaults() + for key, value in defaults.items(): + self.test_databases[key] = value + except moves.configparser.ParsingError as e: + self.fail("Failed to read test_migrations.conf config " + "file. Got error: %s" % e) + else: + self.fail("Failed to find test_migrations.conf config " + "file.") + + self.engines = {} + for key, value in self.test_databases.items(): + self.engines[key] = sqlalchemy.create_engine(value) + + # We start each test case with a completely blank slate. 
+ self._reset_databases() + + def tearDown(self): + # We destroy the test data store between each test case, + # and recreate it, which ensures that we have no side-effects + # from the tests + self._reset_databases() + super(BaseMigrationTestCase, self).tearDown() + + def execute_cmd(self, cmd=None): + process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + output = process.communicate()[0] + LOG.debug(output) + self.assertEqual(0, process.returncode, + "Failed to run: %s\n%s" % (cmd, output)) + + def _reset_pg(self, conn_pieces): + (user, + password, + database, + host) = utils.get_db_connection_info(conn_pieces) + os.environ['PGPASSWORD'] = password + os.environ['PGUSER'] = user + # note(boris-42): We must create and drop database, we can't + # drop database which we have connected to, so for such + # operations there is a special database template1. + sqlcmd = ("psql -w -U %(user)s -h %(host)s -c" + " '%(sql)s' -d template1") + + sql = ("drop database if exists %s;") % database + droptable = sqlcmd % {'user': user, 'host': host, 'sql': sql} + self.execute_cmd(droptable) + + sql = ("create database %s;") % database + createtable = sqlcmd % {'user': user, 'host': host, 'sql': sql} + self.execute_cmd(createtable) + + os.unsetenv('PGPASSWORD') + os.unsetenv('PGUSER') + + @_set_db_lock(lock_prefix='migration_tests-') + def _reset_databases(self): + for key, engine in self.engines.items(): + conn_string = self.test_databases[key] + conn_pieces = parse.urlparse(conn_string) + engine.dispose() + if conn_string.startswith('sqlite'): + # We can just delete the SQLite database, which is + # the easiest and cleanest solution + db_path = conn_pieces.path.strip('/') + if os.path.exists(db_path): + os.unlink(db_path) + # No need to recreate the SQLite DB. SQLite will + # create it for us if it's not there... + elif conn_string.startswith('mysql'): + # We can execute the MySQL client to destroy and re-create + # the MYSQL database, which is easier and less error-prone + # than using SQLAlchemy to do this via MetaData...trust me. + (user, password, database, host) = \ + utils.get_db_connection_info(conn_pieces) + sql = ("drop database if exists %(db)s; " + "create database %(db)s;") % {'db': database} + cmd = ("mysql -u \"%(user)s\" -p\"%(password)s\" -h %(host)s " + "-e \"%(sql)s\"") % {'user': user, 'password': password, + 'host': host, 'sql': sql} + self.execute_cmd(cmd) + elif conn_string.startswith('postgresql'): + self._reset_pg(conn_pieces) + + +class WalkVersionsMixin(object): + def _walk_versions(self, engine=None, snake_walk=False, downgrade=True): + # Determine latest version script from the repo, then + # upgrade from 1 through to the latest, with no data + # in the databases. This just checks that the schema itself + # upgrades successfully. + + # Place the database under version control + self.migration_api.version_control(engine, self.REPOSITORY, + self.INIT_VERSION) + self.assertEqual(self.INIT_VERSION, + self.migration_api.db_version(engine, + self.REPOSITORY)) + + LOG.debug('latest version is %s' % self.REPOSITORY.latest) + versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1) + + for version in versions: + # upgrade -> downgrade -> upgrade + self._migrate_up(engine, version, with_data=True) + if snake_walk: + downgraded = self._migrate_down( + engine, version - 1, with_data=True) + if downgraded: + self._migrate_up(engine, version) + + if downgrade: + # Now walk it back down to 0 from the latest, testing + # the downgrade paths. 
+ for version in reversed(versions): + # downgrade -> upgrade -> downgrade + downgraded = self._migrate_down(engine, version - 1) + + if snake_walk and downgraded: + self._migrate_up(engine, version) + self._migrate_down(engine, version - 1) + + def _migrate_down(self, engine, version, with_data=False): + try: + self.migration_api.downgrade(engine, self.REPOSITORY, version) + except NotImplementedError: + # NOTE(sirp): some migrations, namely release-level + # migrations, don't support a downgrade. + return False + + self.assertEqual( + version, self.migration_api.db_version(engine, self.REPOSITORY)) + + # NOTE(sirp): `version` is what we're downgrading to (i.e. the 'target' + # version). So if we have any downgrade checks, they need to be run for + # the previous (higher numbered) migration. + if with_data: + post_downgrade = getattr( + self, "_post_downgrade_%03d" % (version + 1), None) + if post_downgrade: + post_downgrade(engine) + + return True + + def _migrate_up(self, engine, version, with_data=False): + """migrate up to a new version of the db. + + We allow for data insertion and post checks at every + migration version with special _pre_upgrade_### and + _check_### functions in the main test. + """ + # NOTE(sdague): try block is here because it's impossible to debug + # where a failed data migration happens otherwise + try: + if with_data: + data = None + pre_upgrade = getattr( + self, "_pre_upgrade_%03d" % version, None) + if pre_upgrade: + data = pre_upgrade(engine) + + self.migration_api.upgrade(engine, self.REPOSITORY, version) + self.assertEqual(version, + self.migration_api.db_version(engine, + self.REPOSITORY)) + if with_data: + check = getattr(self, "_check_%03d" % version, None) + if check: + check(engine, data) + except Exception: + LOG.error(_LE("Failed to migrate to version %s on engine %s") % + (version, engine)) + raise diff --git a/cerberus/openstack/common/db/sqlalchemy/utils.py b/cerberus/openstack/common/db/sqlalchemy/utils.py new file mode 100644 index 0000000..6da194e --- /dev/null +++ b/cerberus/openstack/common/db/sqlalchemy/utils.py @@ -0,0 +1,647 @@ +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2010-2011 OpenStack Foundation. +# Copyright 2012 Justin Santa Barbara +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
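The `_pre_upgrade_###`/`_check_###` hooks mentioned in `_migrate_up` are resolved by zero-padded migration number on the test class itself. A sketch of what a consuming migration test might define; the repository setup is elided, and the table and column names are illustrative:

    class TestMyMigrations(BaseMigrationTestCase, WalkVersionsMixin):

        def _pre_upgrade_002(self, engine):
            # Seed data before migration 002 runs; the return value is
            # handed to _check_002() as `data`.
            return {'name': 'scanner'}

        def _check_002(self, engine, data):
            # Inspect the schema after migration 002 has been applied.
            plugins = utils.get_table(engine, 'plugins')
            self.assertIn('name', plugins.c)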
+
+import logging
+import re
+
+import sqlalchemy
+from sqlalchemy import Boolean
+from sqlalchemy import CheckConstraint
+from sqlalchemy import Column
+from sqlalchemy.engine import reflection
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy import func
+from sqlalchemy import Index
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import or_
+from sqlalchemy.sql.expression import literal_column
+from sqlalchemy.sql.expression import UpdateBase
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy.types import NullType
+
+from cerberus.openstack.common import context as request_context
+from cerberus.openstack.common.db.sqlalchemy import models
+from cerberus.openstack.common.gettextutils import _, _LI, _LW
+from cerberus.openstack.common import timeutils
+
+
+LOG = logging.getLogger(__name__)
+
+_DBURL_REGEX = re.compile(r"[^:]+://([^:]+):([^@]+)@.+")
+
+
+def sanitize_db_url(url):
+    match = _DBURL_REGEX.match(url)
+    if match:
+        return '%s****:****%s' % (url[:match.start(1)], url[match.end(2):])
+    return url
+
+
+class InvalidSortKey(Exception):
+    message = _("Sort key supplied was not valid.")
+
+
+# copy from glance/db/sqlalchemy/api.py
+def paginate_query(query, model, limit, sort_keys, marker=None,
+                   sort_dir=None, sort_dirs=None):
+    """Returns a query with sorting / pagination criteria added.
+
+    Pagination works by requiring a unique sort_key, specified by sort_keys.
+    (If sort_keys is not unique, then we risk looping through values.)
+    We use the last row in the previous page as the 'marker' for pagination.
+    So we must return values that follow the passed marker in the order.
+    With a single-valued sort_key, this would be easy: sort_key > X.
+    With a compound sort_key (k1, k2, k3), we must emulate lexicographical
+    ordering:
+    (k1 > X1) or (k1 == X1 && k2 > X2) or (k1 == X1 && k2 == X2 && k3 > X3)
+
+    We also have to cope with different sort directions.
+
+    Typically, the id of the last row is used as the client-facing pagination
+    marker; the actual marker object must then be fetched from the db and
+    passed to us as marker.
+
+    :param query: the query object to which we should add paging/sorting
+    :param model: the ORM model class
+    :param limit: maximum number of items to return
+    :param sort_keys: array of attributes by which results should be sorted
+    :param marker: the last item of the previous page; we return the next
+                    results after this value.
+    :param sort_dir: direction in which results should be sorted (asc, desc)
+    :param sort_dirs: per-column array of sort_dirs, corresponding to sort_keys
+
+    :rtype: sqlalchemy.orm.query.Query
+    :return: The query with sorting/pagination added.
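+
+    A typical call (the model and marker names are illustrative)::
+
+        query = paginate_query(query, models.Instance, limit=20,
+                               sort_keys=['created_at', 'id'],
+                               marker=marker_instance,
+                               sort_dirs=['desc', 'asc'])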
+ """ + + if 'id' not in sort_keys: + # TODO(justinsb): If this ever gives a false-positive, check + # the actual primary key, rather than assuming its id + LOG.warning(_LW('Id not in sort_keys; is sort_keys unique?')) + + assert(not (sort_dir and sort_dirs)) + + # Default the sort direction to ascending + if sort_dirs is None and sort_dir is None: + sort_dir = 'asc' + + # Ensure a per-column sort direction + if sort_dirs is None: + sort_dirs = [sort_dir for _sort_key in sort_keys] + + assert(len(sort_dirs) == len(sort_keys)) + + # Add sorting + for current_sort_key, current_sort_dir in zip(sort_keys, sort_dirs): + try: + sort_dir_func = { + 'asc': sqlalchemy.asc, + 'desc': sqlalchemy.desc, + }[current_sort_dir] + except KeyError: + raise ValueError(_("Unknown sort direction, " + "must be 'desc' or 'asc'")) + try: + sort_key_attr = getattr(model, current_sort_key) + except AttributeError: + raise InvalidSortKey() + query = query.order_by(sort_dir_func(sort_key_attr)) + + # Add pagination + if marker is not None: + marker_values = [] + for sort_key in sort_keys: + v = getattr(marker, sort_key) + marker_values.append(v) + + # Build up an array of sort criteria as in the docstring + criteria_list = [] + for i in range(len(sort_keys)): + crit_attrs = [] + for j in range(i): + model_attr = getattr(model, sort_keys[j]) + crit_attrs.append((model_attr == marker_values[j])) + + model_attr = getattr(model, sort_keys[i]) + if sort_dirs[i] == 'desc': + crit_attrs.append((model_attr < marker_values[i])) + else: + crit_attrs.append((model_attr > marker_values[i])) + + criteria = sqlalchemy.sql.and_(*crit_attrs) + criteria_list.append(criteria) + + f = sqlalchemy.sql.or_(*criteria_list) + query = query.filter(f) + + if limit is not None: + query = query.limit(limit) + + return query + + +def _read_deleted_filter(query, db_model, read_deleted): + if 'deleted' not in db_model.__table__.columns: + raise ValueError(_("There is no `deleted` column in `%s` table. " + "Project doesn't use soft-deleted feature.") + % db_model.__name__) + + default_deleted_value = db_model.__table__.c.deleted.default.arg + if read_deleted == 'no': + query = query.filter(db_model.deleted == default_deleted_value) + elif read_deleted == 'yes': + pass # omit the filter to include deleted and active + elif read_deleted == 'only': + query = query.filter(db_model.deleted != default_deleted_value) + else: + raise ValueError(_("Unrecognized read_deleted value '%s'") + % read_deleted) + return query + + +def _project_filter(query, db_model, context, project_only): + if project_only and 'project_id' not in db_model.__table__.columns: + raise ValueError(_("There is no `project_id` column in `%s` table.") + % db_model.__name__) + + if request_context.is_user_context(context) and project_only: + if project_only == 'allow_none': + is_none = None + query = query.filter(or_(db_model.project_id == context.project_id, + db_model.project_id == is_none)) + else: + query = query.filter(db_model.project_id == context.project_id) + + return query + + +def model_query(context, model, session, args=None, project_only=False, + read_deleted=None): + """Query helper that accounts for context's `read_deleted` field. + + :param context: context to query under + + :param model: Model to query. Must be a subclass of ModelBase. + :type model: models.ModelBase + + :param session: The session to use. + :type session: sqlalchemy.orm.session.Session + + :param args: Arguments to query. If None - model is used. 
+    :type args: tuple
+
+    :param project_only: If present and context is user-type, then restrict
+                         query to match the context's project_id. If set to
+                         'allow_none', restriction includes project_id = None.
+    :type project_only: bool
+
+    :param read_deleted: If present, overrides context's read_deleted field
+                         ('no', 'yes' or 'only').
+    :type read_deleted: str
+
+    Usage:
+
+    .. code:: python
+
+        result = (utils.model_query(context, models.Instance, session=session)
+                       .filter_by(uuid=instance_uuid)
+                       .all())
+
+        query = utils.model_query(
+                    context, Node,
+                    session=session,
+                    args=(func.count(Node.id), func.sum(Node.ram))
+                ).filter_by(project_id=project_id)
+
+    """
+
+    if not read_deleted:
+        if hasattr(context, 'read_deleted'):
+            # NOTE(viktors): some projects use `read_deleted` attribute in
+            # their contexts instead of `show_deleted`.
+            read_deleted = context.read_deleted
+        else:
+            read_deleted = context.show_deleted
+
+    if not issubclass(model, models.ModelBase):
+        raise TypeError(_("model should be a subclass of ModelBase"))
+
+    query = session.query(model) if not args else session.query(*args)
+    query = _read_deleted_filter(query, model, read_deleted)
+    query = _project_filter(query, model, context, project_only)
+
+    return query
+
+
+def get_table(engine, name):
+    """Returns an sqlalchemy table dynamically from the db.
+
+    Needed because the models don't work for us in migrations,
+    as models will be far out of sync with the current data.
+    """
+    metadata = MetaData()
+    metadata.bind = engine
+    return Table(name, metadata, autoload=True)
+
+
+class InsertFromSelect(UpdateBase):
+    """Form the base for `INSERT INTO table (SELECT ... )` statement."""
+    def __init__(self, table, select):
+        self.table = table
+        self.select = select
+
+
+@compiles(InsertFromSelect)
+def visit_insert_from_select(element, compiler, **kw):
+    """Form the `INSERT INTO table (SELECT ... )` statement."""
+    return "INSERT INTO %s %s" % (
+        compiler.process(element.table, asfrom=True),
+        compiler.process(element.select))
+
+
+class ColumnError(Exception):
+    """Error raised when no column or an invalid column is found."""
+
+
+def _get_not_supported_column(col_name_col_instance, column_name):
+    try:
+        column = col_name_col_instance[column_name]
+    except KeyError:
+        msg = _("Please specify column %s in col_name_col_instance "
+                "param. It is required because the column has a type "
+                "that is unsupported by sqlite.")
+        raise ColumnError(msg % column_name)
+
+    if not isinstance(column, Column):
+        msg = _("col_name_col_instance param has wrong type of "
+                "column instance for column %s. It should be an "
+                "instance of sqlalchemy.Column.")
+        raise ColumnError(msg % column_name)
+    return column
+
+
+def drop_unique_constraint(migrate_engine, table_name, uc_name, *columns,
+                           **col_name_col_instance):
+    """Drop unique constraint from table.
+
+    DEPRECATED: this function is deprecated and will be removed from cerberus.db
+    in a few releases. Please use the UniqueConstraint.drop() method directly
+    for sqlalchemy-migrate migration scripts.
+
+    This method drops the UC from the table and works for mysql, postgresql
+    and sqlite. In mysql and postgresql we are able to use the "alter table"
+    construct. Sqlalchemy doesn't support some sqlite column types and
+    replaces their type with NullType in metadata. We process these columns
+    and replace NullType with the correct column type.
+
+    :param migrate_engine: sqlalchemy engine
+    :param table_name: name of the table that contains the unique constraint.
+    :param uc_name: name of the unique constraint that will be dropped.
+    :param columns: columns that are in the unique constraint.
+    :param col_name_col_instance: contains pairs column_name=column_instance,
+                                  where column_instance is an instance of
+                                  Column. These params are required only for
+                                  columns that have types unsupported by
+                                  sqlite, for example BigInteger.
+    """
+
+    from migrate.changeset import UniqueConstraint
+
+    meta = MetaData()
+    meta.bind = migrate_engine
+    t = Table(table_name, meta, autoload=True)
+
+    if migrate_engine.name == "sqlite":
+        override_cols = [
+            _get_not_supported_column(col_name_col_instance, col.name)
+            for col in t.columns
+            if isinstance(col.type, NullType)
+        ]
+        for col in override_cols:
+            t.columns.replace(col)
+
+    uc = UniqueConstraint(*columns, table=t, name=uc_name)
+    uc.drop()
+
+
+def drop_old_duplicate_entries_from_table(migrate_engine, table_name,
+                                          use_soft_delete, *uc_column_names):
+    """Drop all old rows having the same values for columns in uc_column_names.
+
+    This method drops (or marks as `deleted`, if use_soft_delete is True) the
+    old duplicate rows from the table with name `table_name`.
+
+    :param migrate_engine: Sqlalchemy engine
+    :param table_name: Table with duplicates
+    :param use_soft_delete: If True, rows are marked as `deleted`;
+                            if False, rows are removed from the table
+    :param uc_column_names: Unique constraint columns
+    """
+    meta = MetaData()
+    meta.bind = migrate_engine
+
+    table = Table(table_name, meta, autoload=True)
+    columns_for_group_by = [table.c[name] for name in uc_column_names]
+
+    columns_for_select = [func.max(table.c.id)]
+    columns_for_select.extend(columns_for_group_by)
+
+    duplicated_rows_select = sqlalchemy.sql.select(
+        columns_for_select, group_by=columns_for_group_by,
+        having=func.count(table.c.id) > 1)
+
+    for row in migrate_engine.execute(duplicated_rows_select):
+        # NOTE(boris-42): Do not remove the row that has the biggest ID.
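+        #                 For example (an illustrative sketch, not the
+        #                 generated SQL): with uc_column_names = ('name',),
+        #                 the condition built below is roughly
+        #                 "id != <max id> AND deleted_at IS NULL
+        #                  AND name = <row's name>".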
+ delete_condition = table.c.id != row[0] + is_none = None # workaround for pyflakes + delete_condition &= table.c.deleted_at == is_none + for name in uc_column_names: + delete_condition &= table.c[name] == row[name] + + rows_to_delete_select = sqlalchemy.sql.select( + [table.c.id]).where(delete_condition) + for row in migrate_engine.execute(rows_to_delete_select).fetchall(): + LOG.info(_LI("Deleting duplicated row with id: %(id)s from table: " + "%(table)s") % dict(id=row[0], table=table_name)) + + if use_soft_delete: + delete_statement = table.update().\ + where(delete_condition).\ + values({ + 'deleted': literal_column('id'), + 'updated_at': literal_column('updated_at'), + 'deleted_at': timeutils.utcnow() + }) + else: + delete_statement = table.delete().where(delete_condition) + migrate_engine.execute(delete_statement) + + +def _get_default_deleted_value(table): + if isinstance(table.c.id.type, Integer): + return 0 + if isinstance(table.c.id.type, String): + return "" + raise ColumnError(_("Unsupported id columns type")) + + +def _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes): + table = get_table(migrate_engine, table_name) + + insp = reflection.Inspector.from_engine(migrate_engine) + real_indexes = insp.get_indexes(table_name) + existing_index_names = dict( + [(index['name'], index['column_names']) for index in real_indexes]) + + # NOTE(boris-42): Restore indexes on `deleted` column + for index in indexes: + if 'deleted' not in index['column_names']: + continue + name = index['name'] + if name in existing_index_names: + column_names = [table.c[c] for c in existing_index_names[name]] + old_index = Index(name, *column_names, unique=index["unique"]) + old_index.drop(migrate_engine) + + column_names = [table.c[c] for c in index['column_names']] + new_index = Index(index["name"], *column_names, unique=index["unique"]) + new_index.create(migrate_engine) + + +def change_deleted_column_type_to_boolean(migrate_engine, table_name, + **col_name_col_instance): + if migrate_engine.name == "sqlite": + return _change_deleted_column_type_to_boolean_sqlite( + migrate_engine, table_name, **col_name_col_instance) + insp = reflection.Inspector.from_engine(migrate_engine) + indexes = insp.get_indexes(table_name) + + table = get_table(migrate_engine, table_name) + + old_deleted = Column('old_deleted', Boolean, default=False) + old_deleted.create(table, populate_default=False) + + table.update().\ + where(table.c.deleted == table.c.id).\ + values(old_deleted=True).\ + execute() + + table.c.deleted.drop() + table.c.old_deleted.alter(name="deleted") + + _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes) + + +def _change_deleted_column_type_to_boolean_sqlite(migrate_engine, table_name, + **col_name_col_instance): + insp = reflection.Inspector.from_engine(migrate_engine) + table = get_table(migrate_engine, table_name) + + columns = [] + for column in table.columns: + column_copy = None + if column.name != "deleted": + if isinstance(column.type, NullType): + column_copy = _get_not_supported_column(col_name_col_instance, + column.name) + else: + column_copy = column.copy() + else: + column_copy = Column('deleted', Boolean, default=0) + columns.append(column_copy) + + constraints = [constraint.copy() for constraint in table.constraints] + + meta = table.metadata + new_table = Table(table_name + "__tmp__", meta, + *(columns + constraints)) + new_table.create() + + indexes = [] + for index in insp.get_indexes(table_name): + column_names = [new_table.c[c] for c in 
index['column_names']]
+        indexes.append(Index(index["name"], *column_names,
+                             unique=index["unique"]))
+
+    c_select = []
+    for c in table.c:
+        if c.name != "deleted":
+            c_select.append(c)
+        else:
+            c_select.append(table.c.deleted == table.c.id)
+
+    ins = InsertFromSelect(new_table, sqlalchemy.sql.select(c_select))
+    migrate_engine.execute(ins)
+
+    table.drop()
+    for index in indexes:
+        index.create(migrate_engine)
+
+    new_table.rename(table_name)
+    new_table.update().\
+        where(new_table.c.deleted == new_table.c.id).\
+        values(deleted=True).\
+        execute()
+
+
+def change_deleted_column_type_to_id_type(migrate_engine, table_name,
+                                          **col_name_col_instance):
+    if migrate_engine.name == "sqlite":
+        return _change_deleted_column_type_to_id_type_sqlite(
+            migrate_engine, table_name, **col_name_col_instance)
+    insp = reflection.Inspector.from_engine(migrate_engine)
+    indexes = insp.get_indexes(table_name)
+
+    table = get_table(migrate_engine, table_name)
+
+    new_deleted = Column('new_deleted', table.c.id.type,
+                         default=_get_default_deleted_value(table))
+    new_deleted.create(table, populate_default=True)
+
+    deleted = True  # workaround for pyflakes
+    table.update().\
+        where(table.c.deleted == deleted).\
+        values(new_deleted=table.c.id).\
+        execute()
+    table.c.deleted.drop()
+    table.c.new_deleted.alter(name="deleted")
+
+    _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes)
+
+
+def _change_deleted_column_type_to_id_type_sqlite(migrate_engine, table_name,
+                                                  **col_name_col_instance):
+    # NOTE(boris-42): sqlalchemy-migrate can't drop a column with check
+    #                 constraints in a sqlite DB, and our `deleted` column has
+    #                 2 check constraints. So there is only one way to remove
+    #                 these constraints:
+    #                 1) Create a new table with the same columns, constraints
+    #                    and indexes (except the deleted column).
+    #                 2) Copy all data from the old table to the new one.
+    #                 3) Drop the old table.
+    #                 4) Rename the new table to the old table's name.
+    insp = reflection.Inspector.from_engine(migrate_engine)
+    meta = MetaData(bind=migrate_engine)
+    table = Table(table_name, meta, autoload=True)
+    default_deleted_value = _get_default_deleted_value(table)
+
+    columns = []
+    for column in table.columns:
+        column_copy = None
+        if column.name != "deleted":
+            if isinstance(column.type, NullType):
+                column_copy = _get_not_supported_column(col_name_col_instance,
+                                                        column.name)
+            else:
+                column_copy = column.copy()
+        else:
+            column_copy = Column('deleted', table.c.id.type,
+                                 default=default_deleted_value)
+        columns.append(column_copy)
+
+    def is_deleted_column_constraint(constraint):
+        # NOTE(boris-42): There is no other way to check whether a
+        #                 CheckConstraint is associated with the deleted
+        #                 column.
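+        #                 For example (illustrative): sqlite renders such a
+        #                 constraint as "CHECK (deleted in (0, 1))", which is
+        #                 what the sqltext suffix checks below match against.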
+        if not isinstance(constraint, CheckConstraint):
+            return False
+        sqltext = str(constraint.sqltext)
+        return (sqltext.endswith("deleted in (0, 1)") or
+                sqltext.endswith("deleted IN (:deleted_1, :deleted_2)"))
+
+    constraints = []
+    for constraint in table.constraints:
+        if not is_deleted_column_constraint(constraint):
+            constraints.append(constraint.copy())
+
+    new_table = Table(table_name + "__tmp__", meta,
+                      *(columns + constraints))
+    new_table.create()
+
+    indexes = []
+    for index in insp.get_indexes(table_name):
+        column_names = [new_table.c[c] for c in index['column_names']]
+        indexes.append(Index(index["name"], *column_names,
+                             unique=index["unique"]))
+
+    ins = InsertFromSelect(new_table, table.select())
+    migrate_engine.execute(ins)
+
+    table.drop()
+    for index in indexes:
+        index.create(migrate_engine)
+
+    new_table.rename(table_name)
+    deleted = True  # workaround for pyflakes
+    new_table.update().\
+        where(new_table.c.deleted == deleted).\
+        values(deleted=new_table.c.id).\
+        execute()
+
+    # NOTE(boris-42): Fix value of deleted column: False -> "" or 0.
+    deleted = False  # workaround for pyflakes
+    new_table.update().\
+        where(new_table.c.deleted == deleted).\
+        values(deleted=default_deleted_value).\
+        execute()
+
+
+def get_connect_string(backend, database, user=None, passwd=None):
+    """Build a database connection string.
+
+    Used by the tests: they try to connect with a very specific set of
+    values. If a connection can be established the tests are run,
+    otherwise they are skipped.
+    """
+    args = {'backend': backend,
+            'user': user,
+            'passwd': passwd,
+            'database': database}
+    if backend == 'sqlite':
+        template = '%(backend)s:///%(database)s'
+    else:
+        template = "%(backend)s://%(user)s:%(passwd)s@localhost/%(database)s"
+    return template % args
+
+
+def is_backend_avail(backend, database, user=None, passwd=None):
+    try:
+        connect_uri = get_connect_string(backend=backend,
+                                         database=database,
+                                         user=user,
+                                         passwd=passwd)
+        engine = sqlalchemy.create_engine(connect_uri)
+        connection = engine.connect()
+    except Exception:
+        # intentionally catch all to handle exceptions even if we don't
+        # have any backend code loaded.
+        return False
+    else:
+        connection.close()
+        engine.dispose()
+        return True
+
+
+def get_db_connection_info(conn_pieces):
+    database = conn_pieces.path.strip('/')
+    loc_pieces = conn_pieces.netloc.split('@')
+    host = loc_pieces[1]
+
+    auth_pieces = loc_pieces[0].split(':')
+    user = auth_pieces[0]
+    password = ""
+    if len(auth_pieces) > 1:
+        password = auth_pieces[1].strip()
+
+    return (user, password, database, host)
diff --git a/cerberus/openstack/common/eventlet_backdoor.py b/cerberus/openstack/common/eventlet_backdoor.py
new file mode 100644
index 0000000..95b443b
--- /dev/null
+++ b/cerberus/openstack/common/eventlet_backdoor.py
@@ -0,0 +1,146 @@
+# Copyright (c) 2012 OpenStack Foundation.
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
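+
+# A quick usage sketch (values are illustrative): with
+# backdoor_port = "8000:8010" in the service configuration, the process
+# listens on the first free port in that range, and you can then attach to
+# the offered Python prompt with a plain TCP client, e.g.
+# `telnet 127.0.0.1 <chosen port>`.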
+
+from __future__ import print_function
+
+import errno
+import gc
+import os
+import pprint
+import socket
+import sys
+import traceback
+
+import eventlet
+import eventlet.backdoor
+import greenlet
+from oslo.config import cfg
+
+from cerberus.openstack.common.gettextutils import _LI
+from cerberus.openstack.common import log as logging
+
+help_for_backdoor_port = (
+    "Acceptable values are 0, <port>, and <start>:<end>, where 0 results "
+    "in listening on a random tcp port number; <port> results in listening "
+    "on the specified port number (and not enabling backdoor if that port "
+    "is in use); and <start>:<end> results in listening on the smallest "
+    "unused port number within the specified range of port numbers. The "
+    "chosen port is displayed in the service's log file.")
+eventlet_backdoor_opts = [
+    cfg.StrOpt('backdoor_port',
+               default=None,
+               help="Enable eventlet backdoor. %s" % help_for_backdoor_port)
+]
+
+CONF = cfg.CONF
+CONF.register_opts(eventlet_backdoor_opts)
+LOG = logging.getLogger(__name__)
+
+
+class EventletBackdoorConfigValueError(Exception):
+    def __init__(self, port_range, help_msg, ex):
+        msg = ('Invalid backdoor_port configuration %(range)s: %(ex)s. '
+               '%(help)s' %
+               {'range': port_range, 'ex': ex, 'help': help_msg})
+        super(EventletBackdoorConfigValueError, self).__init__(msg)
+        self.port_range = port_range
+
+
+def _dont_use_this():
+    print("Don't use this, just disconnect instead")
+
+
+def _find_objects(t):
+    return [o for o in gc.get_objects() if isinstance(o, t)]
+
+
+def _print_greenthreads():
+    for i, gt in enumerate(_find_objects(greenlet.greenlet)):
+        print(i, gt)
+        traceback.print_stack(gt.gr_frame)
+        print()
+
+
+def _print_nativethreads():
+    for threadId, stack in sys._current_frames().items():
+        print(threadId)
+        traceback.print_stack(stack)
+        print()
+
+
+def _parse_port_range(port_range):
+    if ':' not in port_range:
+        start, end = port_range, port_range
+    else:
+        start, end = port_range.split(':', 1)
+    try:
+        start, end = int(start), int(end)
+        if end < start:
+            raise ValueError
+        return start, end
+    except ValueError as ex:
+        # Pass the arguments in the order the constructor expects:
+        # (port_range, help_msg, ex).
+        raise EventletBackdoorConfigValueError(port_range,
+                                               help_for_backdoor_port, ex)
+
+
+def _listen(host, start_port, end_port, listen_func):
+    try_port = start_port
+    while True:
+        try:
+            return listen_func((host, try_port))
+        except socket.error as exc:
+            if (exc.errno != errno.EADDRINUSE or
+                    try_port >= end_port):
+                raise
+            try_port += 1
+
+
+def initialize_if_enabled():
+    backdoor_locals = {
+        'exit': _dont_use_this,      # So we don't exit the entire process
+        'quit': _dont_use_this,      # So we don't exit the entire process
+        'fo': _find_objects,
+        'pgt': _print_greenthreads,
+        'pnt': _print_nativethreads,
+    }
+
+    if CONF.backdoor_port is None:
+        return None
+
+    start_port, end_port = _parse_port_range(str(CONF.backdoor_port))
+
+    # NOTE(johannes): The standard sys.displayhook will print the value of
+    # the last expression and set it to __builtin__._, which overwrites
+    # the __builtin__._ that gettext sets. Let's switch to using pprint
+    # since it won't interact poorly with gettext, and it's easier to
+    # read the output too.
+    def displayhook(val):
+        if val is not None:
+            pprint.pprint(val)
+    sys.displayhook = displayhook
+
+    sock = _listen('localhost', start_port, end_port, eventlet.listen)
+
+    # In the case of the backdoor port being zero, a port number is assigned
+    # by listen(). In any case, pull the port number out here.
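+    # For example (illustrative): with start_port=8000 and end_port=8010,
+    # _listen() binds the first free port in that range, and
+    # sock.getsockname() might return ('127.0.0.1', 8002), i.e. port 8002.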
+ port = sock.getsockname()[1] + LOG.info( + _LI('Eventlet backdoor listening on %(port)s for process %(pid)d') % + {'port': port, 'pid': os.getpid()} + ) + eventlet.spawn_n(eventlet.backdoor.backdoor_server, sock, + locals=backdoor_locals) + return port diff --git a/cerberus/openstack/common/excutils.py b/cerberus/openstack/common/excutils.py new file mode 100644 index 0000000..01f8b8e --- /dev/null +++ b/cerberus/openstack/common/excutils.py @@ -0,0 +1,113 @@ +# Copyright 2011 OpenStack Foundation. +# Copyright 2012, Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Exception related utilities. +""" + +import logging +import sys +import time +import traceback + +import six + +from cerberus.openstack.common.gettextutils import _LE + + +class save_and_reraise_exception(object): + """Save current exception, run some code and then re-raise. + + In some cases the exception context can be cleared, resulting in None + being attempted to be re-raised after an exception handler is run. This + can happen when eventlet switches greenthreads or when running an + exception handler, code raises and catches an exception. In both + cases the exception context will be cleared. + + To work around this, we save the exception state, run handler code, and + then re-raise the original exception. If another exception occurs, the + saved exception is logged and the new exception is re-raised. + + In some cases the caller may not want to re-raise the exception, and + for those circumstances this context provides a reraise flag that + can be used to suppress the exception. For example:: + + except Exception: + with save_and_reraise_exception() as ctxt: + decide_if_need_reraise() + if not should_be_reraised: + ctxt.reraise = False + + If another exception occurs and reraise flag is False, + the saved exception will not be logged. 
+ + If the caller wants to raise new exception during exception handling + he/she sets reraise to False initially with an ability to set it back to + True if needed:: + + except Exception: + with save_and_reraise_exception(reraise=False) as ctxt: + [if statements to determine whether to raise a new exception] + # Not raising a new exception, so reraise + ctxt.reraise = True + """ + def __init__(self, reraise=True): + self.reraise = reraise + + def __enter__(self): + self.type_, self.value, self.tb, = sys.exc_info() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None: + if self.reraise: + logging.error(_LE('Original exception being dropped: %s'), + traceback.format_exception(self.type_, + self.value, + self.tb)) + return False + if self.reraise: + six.reraise(self.type_, self.value, self.tb) + + +def forever_retry_uncaught_exceptions(infunc): + def inner_func(*args, **kwargs): + last_log_time = 0 + last_exc_message = None + exc_count = 0 + while True: + try: + return infunc(*args, **kwargs) + except Exception as exc: + this_exc_message = six.u(str(exc)) + if this_exc_message == last_exc_message: + exc_count += 1 + else: + exc_count = 1 + # Do not log any more frequently than once a minute unless + # the exception message changes + cur_time = int(time.time()) + if (cur_time - last_log_time > 60 or + this_exc_message != last_exc_message): + logging.exception( + _LE('Unexpected exception occurred %d time(s)... ' + 'retrying.') % exc_count) + last_log_time = cur_time + last_exc_message = this_exc_message + exc_count = 0 + # This should be a very rare event. In case it isn't, do + # a sleep. + time.sleep(1) + return inner_func diff --git a/cerberus/openstack/common/fileutils.py b/cerberus/openstack/common/fileutils.py new file mode 100644 index 0000000..2b804a2 --- /dev/null +++ b/cerberus/openstack/common/fileutils.py @@ -0,0 +1,135 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import contextlib +import errno +import os +import tempfile + +from cerberus.openstack.common import excutils +from cerberus.openstack.common import log as logging + +LOG = logging.getLogger(__name__) + +_FILE_CACHE = {} + + +def ensure_tree(path): + """Create a directory (and any ancestor directories required) + + :param path: Directory to create + """ + try: + os.makedirs(path) + except OSError as exc: + if exc.errno == errno.EEXIST: + if not os.path.isdir(path): + raise + else: + raise + + +def read_cached_file(filename, force_reload=False): + """Read from a file if it has been modified. + + :param force_reload: Whether to reload the file. + :returns: A tuple with a boolean specifying if the data is fresh + or not. 
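+
+    A minimal usage sketch (the path and the consumer are hypothetical)::
+
+        reloaded, data = read_cached_file('/etc/myservice/policy.json')
+        if reloaded:
+            refresh_policy(data)  # hypothetical consumer of the new contents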
+ """ + global _FILE_CACHE + + if force_reload and filename in _FILE_CACHE: + del _FILE_CACHE[filename] + + reloaded = False + mtime = os.path.getmtime(filename) + cache_info = _FILE_CACHE.setdefault(filename, {}) + + if not cache_info or mtime > cache_info.get('mtime', 0): + LOG.debug("Reloading cached file %s" % filename) + with open(filename) as fap: + cache_info['data'] = fap.read() + cache_info['mtime'] = mtime + reloaded = True + return (reloaded, cache_info['data']) + + +def delete_if_exists(path, remove=os.unlink): + """Delete a file, but ignore file not found error. + + :param path: File to delete + :param remove: Optional function to remove passed path + """ + + try: + remove(path) + except OSError as e: + if e.errno != errno.ENOENT: + raise + + +@contextlib.contextmanager +def remove_path_on_error(path, remove=delete_if_exists): + """Protect code that wants to operate on PATH atomically. + Any exception will cause PATH to be removed. + + :param path: File to work with + :param remove: Optional function to remove passed path + """ + + try: + yield + except Exception: + with excutils.save_and_reraise_exception(): + remove(path) + + +def file_open(*args, **kwargs): + """Open file + + see built-in file() documentation for more details + + Note: The reason this is kept in a separate module is to easily + be able to provide a stub module that doesn't alter system + state at all (for unit tests) + """ + return file(*args, **kwargs) + + +def write_to_tempfile(content, path=None, suffix='', prefix='tmp'): + """Create temporary file or use existing file. + + This util is needed for creating temporary file with + specified content, suffix and prefix. If path is not None, + it will be used for writing content. If the path doesn't + exist it'll be created. + + :param content: content for temporary file. + :param path: same as parameter 'dir' for mkstemp + :param suffix: same as parameter 'suffix' for mkstemp + :param prefix: same as parameter 'prefix' for mkstemp + + For example: it can be used in database tests for creating + configuration files. + """ + if path: + ensure_tree(path) + + (fd, path) = tempfile.mkstemp(suffix=suffix, dir=path, prefix=prefix) + try: + os.write(fd, content) + finally: + os.close(fd) + return path diff --git a/cerberus/openstack/common/fixture/__init__.py b/cerberus/openstack/common/fixture/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cerberus/openstack/common/fixture/config.py b/cerberus/openstack/common/fixture/config.py new file mode 100644 index 0000000..9489b85 --- /dev/null +++ b/cerberus/openstack/common/fixture/config.py @@ -0,0 +1,85 @@ +# +# Copyright 2013 Mirantis, Inc. +# Copyright 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import fixtures +from oslo.config import cfg +import six + + +class Config(fixtures.Fixture): + """Allows overriding configuration settings for the test. + + `conf` will be reset on cleanup. 
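+
+    A minimal sketch (the option name here is hypothetical)::
+
+        conf_fixture = self.useFixture(Config())
+        conf_fixture.register_opt(cfg.BoolOpt('verbose', default=False))
+        conf_fixture.config(verbose=True)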
+
+    """
+
+    def __init__(self, conf=cfg.CONF):
+        self.conf = conf
+
+    def setUp(self):
+        super(Config, self).setUp()
+        # NOTE(morganfainberg): unregister must be added to cleanup before
+        # reset is added, because cleanup works in reverse order of
+        # registered items, and a reset must occur before the options can
+        # be unregistered.
+        self.addCleanup(self._unregister_config_opts)
+        self.addCleanup(self.conf.reset)
+        self._registered_config_opts = {}
+
+    def config(self, **kw):
+        """Override configuration values.
+
+        The keyword arguments are the names of configuration options to
+        override and their values.
+
+        If a `group` argument is supplied, the overrides are applied to
+        the specified configuration option group, otherwise the overrides
+        are applied to the ``default`` group.
+
+        """
+
+        group = kw.pop('group', None)
+        for k, v in six.iteritems(kw):
+            self.conf.set_override(k, v, group)
+
+    def _unregister_config_opts(self):
+        for group in self._registered_config_opts:
+            self.conf.unregister_opts(self._registered_config_opts[group],
+                                      group=group)
+
+    def register_opt(self, opt, group=None):
+        """Register a single option for the test run.
+
+        Options registered in this manner will automatically be unregistered
+        during cleanup.
+
+        If a `group` argument is supplied, it will register the new option
+        to that group, otherwise the option is registered to the ``default``
+        group.
+        """
+        self.conf.register_opt(opt, group=group)
+        self._registered_config_opts.setdefault(group, set()).add(opt)
+
+    def register_opts(self, opts, group=None):
+        """Register multiple options for the test run.
+
+        This works in the same manner as register_opt() but takes a list of
+        options as the first argument. All arguments will be registered to the
+        same group if the ``group`` argument is supplied, otherwise all options
+        will be registered to the ``default`` group.
+        """
+        for opt in opts:
+            self.register_opt(opt, group=group)
diff --git a/cerberus/openstack/common/fixture/lockutils.py b/cerberus/openstack/common/fixture/lockutils.py
new file mode 100644
index 0000000..2ecd0dc
--- /dev/null
+++ b/cerberus/openstack/common/fixture/lockutils.py
@@ -0,0 +1,51 @@
+# Copyright 2011 OpenStack Foundation.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import fixtures
+
+from cerberus.openstack.common import lockutils
+
+
+class LockFixture(fixtures.Fixture):
+    """External locking fixture.
+
+    This fixture is basically an alternative to the synchronized decorator
+    with the external flag so that tearDowns and addCleanups will be included
+    in the lock context for locking between tests. The fixture is recommended
+    to be the first line in a test method, like so::
+
+        def test_method(self):
+            self.useFixture(LockFixture('lock_name'))
+            ...
+
+    or the first line in setUp if all the test methods in the class are
+    required to be serialized. Something like::
+
+        class TestCase(testtools.TestCase):
+            def setUp(self):
+                self.useFixture(LockFixture('lock_name'))
+                super(TestCase, self).setUp()
+                ...
+ + This is because addCleanups are put on a LIFO queue that gets run after the + test method exits. (either by completing or raising an exception) + """ + def __init__(self, name, lock_file_prefix=None): + self.mgr = lockutils.lock(name, lock_file_prefix, True) + + def setUp(self): + super(LockFixture, self).setUp() + self.addCleanup(self.mgr.__exit__, None, None, None) + self.lock = self.mgr.__enter__() diff --git a/cerberus/openstack/common/fixture/logging.py b/cerberus/openstack/common/fixture/logging.py new file mode 100644 index 0000000..3823a03 --- /dev/null +++ b/cerberus/openstack/common/fixture/logging.py @@ -0,0 +1,34 @@ +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import fixtures + + +def get_logging_handle_error_fixture(): + """returns a fixture to make logging raise formatting exceptions. + + Usage: + self.useFixture(logging.get_logging_handle_error_fixture()) + """ + return fixtures.MonkeyPatch('logging.Handler.handleError', + _handleError) + + +def _handleError(self, record): + """Monkey patch for logging.Handler.handleError. + + The default handleError just logs the error to stderr but we want + the option of actually raising an exception. + """ + raise diff --git a/cerberus/openstack/common/fixture/mockpatch.py b/cerberus/openstack/common/fixture/mockpatch.py new file mode 100644 index 0000000..f6316ef --- /dev/null +++ b/cerberus/openstack/common/fixture/mockpatch.py @@ -0,0 +1,62 @@ +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2013 Hewlett-Packard Development Company, L.P. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +############################################################################## +############################################################################## +## +## DO NOT MODIFY THIS FILE +## +## This file is being graduated to the cerberustest library. Please make all +## changes there, and only backport critical fixes here. 
- dhellmann +## +############################################################################## +############################################################################## + +import fixtures +import mock + + +class PatchObject(fixtures.Fixture): + """Deal with code around mock.""" + + def __init__(self, obj, attr, new=mock.DEFAULT, **kwargs): + self.obj = obj + self.attr = attr + self.kwargs = kwargs + self.new = new + + def setUp(self): + super(PatchObject, self).setUp() + _p = mock.patch.object(self.obj, self.attr, self.new, **self.kwargs) + self.mock = _p.start() + self.addCleanup(_p.stop) + + +class Patch(fixtures.Fixture): + + """Deal with code around mock.patch.""" + + def __init__(self, obj, new=mock.DEFAULT, **kwargs): + self.obj = obj + self.kwargs = kwargs + self.new = new + + def setUp(self): + super(Patch, self).setUp() + _p = mock.patch(self.obj, self.new, **self.kwargs) + self.mock = _p.start() + self.addCleanup(_p.stop) diff --git a/cerberus/openstack/common/fixture/moxstubout.py b/cerberus/openstack/common/fixture/moxstubout.py new file mode 100644 index 0000000..15b35bd --- /dev/null +++ b/cerberus/openstack/common/fixture/moxstubout.py @@ -0,0 +1,43 @@ +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2013 Hewlett-Packard Development Company, L.P. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +############################################################################## +############################################################################## +## +## DO NOT MODIFY THIS FILE +## +## This file is being graduated to the cerberustest library. Please make all +## changes there, and only backport critical fixes here. - dhellmann +## +############################################################################## +############################################################################## + +import fixtures +from six.moves import mox + + +class MoxStubout(fixtures.Fixture): + """Deal with code around mox and stubout as a fixture.""" + + def setUp(self): + super(MoxStubout, self).setUp() + # emulate some of the mox stuff, we can't use the metaclass + # because it screws with our generators + self.mox = mox.Mox() + self.stubs = self.mox.stubs + self.addCleanup(self.mox.UnsetStubs) + self.addCleanup(self.mox.VerifyAll) diff --git a/cerberus/openstack/common/gettextutils.py b/cerberus/openstack/common/gettextutils.py new file mode 100644 index 0000000..a69ed04 --- /dev/null +++ b/cerberus/openstack/common/gettextutils.py @@ -0,0 +1,448 @@ +# Copyright 2012 Red Hat, Inc. +# Copyright 2013 IBM Corp. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +gettext for openstack-common modules. + +Usual usage in an openstack.common module: + + from cerberus.openstack.common.gettextutils import _ +""" + +import copy +import functools +import gettext +import locale +from logging import handlers +import os + +from babel import localedata +import six + +_localedir = os.environ.get('cerberus'.upper() + '_LOCALEDIR') +_t = gettext.translation('cerberus', localedir=_localedir, fallback=True) + +# We use separate translation catalogs for each log level, so set up a +# mapping between the log level name and the translator. The domain +# for the log level is project_name + "-log-" + log_level so messages +# for each level end up in their own catalog. +_t_log_levels = dict( + (level, gettext.translation('cerberus' + '-log-' + level, + localedir=_localedir, + fallback=True)) + for level in ['info', 'warning', 'error', 'critical'] +) + +_AVAILABLE_LANGUAGES = {} +USE_LAZY = False + + +def enable_lazy(): + """Convenience function for configuring _() to use lazy gettext + + Call this at the start of execution to enable the gettextutils._ + function to use lazy gettext functionality. This is useful if + your project is importing _ directly instead of using the + gettextutils.install() way of importing the _ function. + """ + global USE_LAZY + USE_LAZY = True + + +def _(msg): + if USE_LAZY: + return Message(msg, domain='cerberus') + else: + if six.PY3: + return _t.gettext(msg) + return _t.ugettext(msg) + + +def _log_translation(msg, level): + """Build a single translation of a log message + """ + if USE_LAZY: + return Message(msg, domain='cerberus' + '-log-' + level) + else: + translator = _t_log_levels[level] + if six.PY3: + return translator.gettext(msg) + return translator.ugettext(msg) + +# Translators for log levels. +# +# The abbreviated names are meant to reflect the usual use of a short +# name like '_'. The "L" is for "log" and the other letter comes from +# the level. +_LI = functools.partial(_log_translation, level='info') +_LW = functools.partial(_log_translation, level='warning') +_LE = functools.partial(_log_translation, level='error') +_LC = functools.partial(_log_translation, level='critical') + + +def install(domain, lazy=False): + """Install a _() function using the given translation domain. + + Given a translation domain, install a _() function using gettext's + install() function. + + The main difference from gettext.install() is that we allow + overriding the default localedir (e.g. /usr/share/locale) using + a translation-domain-specific environment variable (e.g. + NOVA_LOCALEDIR). + + :param domain: the translation domain + :param lazy: indicates whether or not to install the lazy _() function. + The lazy _() introduces a way to do deferred translation + of messages by installing a _ that builds Message objects, + instead of strings, which can then be lazily translated into + any available locale. + """ + if lazy: + # NOTE(mrodden): Lazy gettext functionality. + # + # The following introduces a deferred way to do translations on + # messages in OpenStack. 
We override the standard _() function + # and % (format string) operation to build Message objects that can + # later be translated when we have more information. + def _lazy_gettext(msg): + """Create and return a Message object. + + Lazy gettext function for a given domain, it is a factory method + for a project/module to get a lazy gettext function for its own + translation domain (i.e. nova, glance, cinder, etc.) + + Message encapsulates a string so that we can translate + it later when needed. + """ + return Message(msg, domain=domain) + + from six import moves + moves.builtins.__dict__['_'] = _lazy_gettext + else: + localedir = '%s_LOCALEDIR' % domain.upper() + if six.PY3: + gettext.install(domain, + localedir=os.environ.get(localedir)) + else: + gettext.install(domain, + localedir=os.environ.get(localedir), + unicode=True) + + +class Message(six.text_type): + """A Message object is a unicode object that can be translated. + + Translation of Message is done explicitly using the translate() method. + For all non-translation intents and purposes, a Message is simply unicode, + and can be treated as such. + """ + + def __new__(cls, msgid, msgtext=None, params=None, + domain='cerberus', *args): + """Create a new Message object. + + In order for translation to work gettext requires a message ID, this + msgid will be used as the base unicode text. It is also possible + for the msgid and the base unicode text to be different by passing + the msgtext parameter. + """ + # If the base msgtext is not given, we use the default translation + # of the msgid (which is in English) just in case the system locale is + # not English, so that the base text will be in that locale by default. + if not msgtext: + msgtext = Message._translate_msgid(msgid, domain) + # We want to initialize the parent unicode with the actual object that + # would have been plain unicode if 'Message' was not enabled. + msg = super(Message, cls).__new__(cls, msgtext) + msg.msgid = msgid + msg.domain = domain + msg.params = params + return msg + + def translate(self, desired_locale=None): + """Translate this message to the desired locale. + + :param desired_locale: The desired locale to translate the message to, + if no locale is provided the message will be + translated to the system's default locale. + + :returns: the translated message in unicode + """ + + translated_message = Message._translate_msgid(self.msgid, + self.domain, + desired_locale) + if self.params is None: + # No need for more translation + return translated_message + + # This Message object may have been formatted with one or more + # Message objects as substitution arguments, given either as a single + # argument, part of a tuple, or as one or more values in a dictionary. 
+ # When translating this Message we need to translate those Messages too + translated_params = _translate_args(self.params, desired_locale) + + translated_message = translated_message % translated_params + + return translated_message + + @staticmethod + def _translate_msgid(msgid, domain, desired_locale=None): + if not desired_locale: + system_locale = locale.getdefaultlocale() + # If the system locale is not available to the runtime use English + if not system_locale[0]: + desired_locale = 'en_US' + else: + desired_locale = system_locale[0] + + locale_dir = os.environ.get(domain.upper() + '_LOCALEDIR') + lang = gettext.translation(domain, + localedir=locale_dir, + languages=[desired_locale], + fallback=True) + if six.PY3: + translator = lang.gettext + else: + translator = lang.ugettext + + translated_message = translator(msgid) + return translated_message + + def __mod__(self, other): + # When we mod a Message we want the actual operation to be performed + # by the parent class (i.e. unicode()), the only thing we do here is + # save the original msgid and the parameters in case of a translation + params = self._sanitize_mod_params(other) + unicode_mod = super(Message, self).__mod__(params) + modded = Message(self.msgid, + msgtext=unicode_mod, + params=params, + domain=self.domain) + return modded + + def _sanitize_mod_params(self, other): + """Sanitize the object being modded with this Message. + + - Add support for modding 'None' so translation supports it + - Trim the modded object, which can be a large dictionary, to only + those keys that would actually be used in a translation + - Snapshot the object being modded, in case the message is + translated, it will be used as it was when the Message was created + """ + if other is None: + params = (other,) + elif isinstance(other, dict): + # Merge the dictionaries + # Copy each item in case one does not support deep copy. + params = {} + if isinstance(self.params, dict): + for key, val in self.params.items(): + params[key] = self._copy_param(val) + for key, val in other.items(): + params[key] = self._copy_param(val) + else: + params = self._copy_param(other) + return params + + def _copy_param(self, param): + try: + return copy.deepcopy(param) + except Exception: + # Fallback to casting to unicode this will handle the + # python code-like objects that can't be deep-copied + return six.text_type(param) + + def __add__(self, other): + msg = _('Message objects do not support addition.') + raise TypeError(msg) + + def __radd__(self, other): + return self.__add__(other) + + def __str__(self): + # NOTE(luisg): Logging in python 2.6 tries to str() log records, + # and it expects specifically a UnicodeError in order to proceed. + msg = _('Message objects do not support str() because they may ' + 'contain non-ascii characters. ' + 'Please use unicode() or translate() instead.') + raise UnicodeError(msg) + + +def get_available_languages(domain): + """Lists the available languages for the given translation domain. 
+ + :param domain: the domain to get languages for + """ + if domain in _AVAILABLE_LANGUAGES: + return copy.copy(_AVAILABLE_LANGUAGES[domain]) + + localedir = '%s_LOCALEDIR' % domain.upper() + find = lambda x: gettext.find(domain, + localedir=os.environ.get(localedir), + languages=[x]) + + # NOTE(mrodden): en_US should always be available (and first in case + # order matters) since our in-line message strings are en_US + language_list = ['en_US'] + # NOTE(luisg): Babel <1.0 used a function called list(), which was + # renamed to locale_identifiers() in >=1.0, the requirements master list + # requires >=0.9.6, uncapped, so defensively work with both. We can remove + # this check when the master list updates to >=1.0, and update all projects + list_identifiers = (getattr(localedata, 'list', None) or + getattr(localedata, 'locale_identifiers')) + locale_identifiers = list_identifiers() + + for i in locale_identifiers: + if find(i) is not None: + language_list.append(i) + + # NOTE(luisg): Babel>=1.0,<1.3 has a bug where some OpenStack supported + # locales (e.g. 'zh_CN', and 'zh_TW') aren't supported even though they + # are perfectly legitimate locales: + # https://github.com/mitsuhiko/babel/issues/37 + # In Babel 1.3 they fixed the bug and they support these locales, but + # they are still not explicitly "listed" by locale_identifiers(). + # That is why we add the locales here explicitly if necessary so that + # they are listed as supported. + aliases = {'zh': 'zh_CN', + 'zh_Hant_HK': 'zh_HK', + 'zh_Hant': 'zh_TW', + 'fil': 'tl_PH'} + for (locale, alias) in six.iteritems(aliases): + if locale in language_list and alias not in language_list: + language_list.append(alias) + + _AVAILABLE_LANGUAGES[domain] = language_list + return copy.copy(language_list) + + +def translate(obj, desired_locale=None): + """Gets the translated unicode representation of the given object. + + If the object is not translatable it is returned as-is. + If the locale is None the object is translated to the system locale. + + :param obj: the object to translate + :param desired_locale: the locale to translate the message to, if None the + default system locale will be used + :returns: the translated object in unicode, or the original object if + it could not be translated + """ + message = obj + if not isinstance(message, Message): + # If the object to translate is not already translatable, + # let's first get its unicode representation + message = six.text_type(obj) + if isinstance(message, Message): + # Even after unicoding() we still need to check if we are + # running with translatable unicode before translating + return message.translate(desired_locale) + return obj + + +def _translate_args(args, desired_locale=None): + """Translates all the translatable elements of the given arguments object. + + This method is used for translating the translatable values in method + arguments which include values of tuples or dictionaries. + If the object is not a tuple or a dictionary the object itself is + translated if it is translatable. + + If the locale is None the object is translated to the system locale. 
+ + :param args: the args to translate + :param desired_locale: the locale to translate the args to, if None the + default system locale will be used + :returns: a new args object with the translated contents of the original + """ + if isinstance(args, tuple): + return tuple(translate(v, desired_locale) for v in args) + if isinstance(args, dict): + translated_dict = {} + for (k, v) in six.iteritems(args): + translated_v = translate(v, desired_locale) + translated_dict[k] = translated_v + return translated_dict + return translate(args, desired_locale) + + +class TranslationHandler(handlers.MemoryHandler): + """Handler that translates records before logging them. + + The TranslationHandler takes a locale and a target logging.Handler object + to forward LogRecord objects to after translating them. This handler + depends on Message objects being logged, instead of regular strings. + + The handler can be configured declaratively in the logging.conf as follows: + + [handlers] + keys = translatedlog, translator + + [handler_translatedlog] + class = handlers.WatchedFileHandler + args = ('/var/log/api-localized.log',) + formatter = context + + [handler_translator] + class = openstack.common.log.TranslationHandler + target = translatedlog + args = ('zh_CN',) + + If the specified locale is not available in the system, the handler will + log in the default locale. + """ + + def __init__(self, locale=None, target=None): + """Initialize a TranslationHandler + + :param locale: locale to use for translating messages + :param target: logging.Handler object to forward + LogRecord objects to after translation + """ + # NOTE(luisg): In order to allow this handler to be a wrapper for + # other handlers, such as a FileHandler, and still be able to + # configure it using logging.conf, this handler has to extend + # MemoryHandler because only the MemoryHandlers' logging.conf + # parsing is implemented such that it accepts a target handler. + handlers.MemoryHandler.__init__(self, capacity=0, target=target) + self.locale = locale + + def setFormatter(self, fmt): + self.target.setFormatter(fmt) + + def emit(self, record): + # We save the message from the original record to restore it + # after translation, so other handlers are not affected by this + original_msg = record.msg + original_args = record.args + + try: + self._translate_and_log_record(record) + finally: + record.msg = original_msg + record.args = original_args + + def _translate_and_log_record(self, record): + record.msg = translate(record.msg, self.locale) + + # In addition to translating the message, we also need to translate + # arguments that were passed to the log method that were not part + # of the main message e.g., log.info(_('Some message %s'), this_one)) + record.args = _translate_args(record.args, self.locale) + + self.target.emit(record) diff --git a/cerberus/openstack/common/importutils.py b/cerberus/openstack/common/importutils.py new file mode 100644 index 0000000..af78d95 --- /dev/null +++ b/cerberus/openstack/common/importutils.py @@ -0,0 +1,73 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain
+# a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Import related utilities and helper functions.
+"""
+
+import sys
+import traceback
+
+
+def import_class(import_str):
+    """Returns a class from a string including module and class."""
+    mod_str, _sep, class_str = import_str.rpartition('.')
+    try:
+        __import__(mod_str)
+        return getattr(sys.modules[mod_str], class_str)
+    except (ValueError, AttributeError):
+        raise ImportError('Class %s cannot be found (%s)' %
+                          (class_str,
+                           traceback.format_exception(*sys.exc_info())))
+
+
+def import_object(import_str, *args, **kwargs):
+    """Import a class and return an instance of it."""
+    return import_class(import_str)(*args, **kwargs)
+
+
+def import_object_ns(name_space, import_str, *args, **kwargs):
+    """Tries to import an object from the default namespace.
+
+    Imports a class and returns an instance of it, first by trying
+    to find the class in a default namespace, then falling back to
+    the full path if it is not found in the default namespace.
+    """
+    import_value = "%s.%s" % (name_space, import_str)
+    try:
+        return import_class(import_value)(*args, **kwargs)
+    except ImportError:
+        return import_class(import_str)(*args, **kwargs)
+
+
+def import_module(import_str):
+    """Import a module."""
+    __import__(import_str)
+    return sys.modules[import_str]
+
+
+def import_versioned_module(version, submodule=None):
+    module = 'cerberus.v%s' % version
+    if submodule:
+        module = '.'.join((module, submodule))
+    return import_module(module)
+
+
+def try_import(import_str, default=None):
+    """Try to import a module and if it fails return default."""
+    try:
+        return import_module(import_str)
+    except ImportError:
+        return default
diff --git a/cerberus/openstack/common/jsonutils.py b/cerberus/openstack/common/jsonutils.py
new file mode 100644
index 0000000..fa7073b
--- /dev/null
+++ b/cerberus/openstack/common/jsonutils.py
@@ -0,0 +1,186 @@
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# Copyright 2011 Justin Santa Barbara
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+'''
+JSON related utilities.
+
+This module provides a few things:
+
+    1) A handy function for getting an object down to something that can be
+    JSON serialized. See to_primitive().
+
+    2) Wrappers around loads() and dumps(). The dumps() wrapper will
+    automatically use to_primitive() for you if needed.
+
+    3) This sets up anyjson to use the loads() and dumps() wrappers if anyjson
+    is available.
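+
+A small usage sketch (the dict contents are illustrative)::
+
+    import datetime
+
+    from cerberus.openstack.common import jsonutils
+
+    jsonutils.dumps({'when': datetime.datetime(2014, 1, 1)})
+    # the datetime value is converted by to_primitive() before serialization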
+'''
+
+
+import codecs
+import datetime
+import functools
+import inspect
+import itertools
+import sys
+
+if sys.version_info < (2, 7):
+    # On Python <= 2.6, json module is not C boosted, so try to use
+    # simplejson module if available
+    try:
+        import simplejson as json
+    except ImportError:
+        import json
+else:
+    import json
+
+import six
+import six.moves.xmlrpc_client as xmlrpclib
+
+from cerberus.openstack.common import gettextutils
+from cerberus.openstack.common import importutils
+from cerberus.openstack.common import strutils
+from cerberus.openstack.common import timeutils
+
+netaddr = importutils.try_import("netaddr")
+
+_nasty_type_tests = [inspect.ismodule, inspect.isclass, inspect.ismethod,
+                     inspect.isfunction, inspect.isgeneratorfunction,
+                     inspect.isgenerator, inspect.istraceback, inspect.isframe,
+                     inspect.iscode, inspect.isbuiltin, inspect.isroutine,
+                     inspect.isabstract]
+
+_simple_types = (six.string_types + six.integer_types
+                 + (type(None), bool, float))
+
+
+def to_primitive(value, convert_instances=False, convert_datetime=True,
+                 level=0, max_depth=3):
+    """Convert a complex object into primitives.
+
+    Handy for JSON serialization. We can optionally handle instances,
+    but since this is a recursive function, we could have cyclical
+    data structures.
+
+    To handle cyclical data structures we could track the actual objects
+    visited in a set, but not all objects are hashable. Instead we just
+    track the depth of the object inspections and don't go too deep.
+
+    Therefore, convert_instances=True is lossy ... be aware.
+
+    """
+    # Handle obvious types first - the order of the basic types below was
+    # determined by running full tests on the nova project and ordering
+    # the types by how frequently each occurred.
+    if isinstance(value, _simple_types):
+        return value
+
+    if isinstance(value, datetime.datetime):
+        if convert_datetime:
+            return timeutils.strtime(value)
+        else:
+            return value
+
+    # value of itertools.count doesn't get caught by nasty_type_tests
+    # and results in infinite loop when list(value) is called.
+    if type(value) == itertools.count:
+        return six.text_type(value)
+
+    # FIXME(vish): Workaround for LP bug 852095. Without this workaround,
+    #              tests that raise an exception in a mocked method that
+    #              has a @wrap_exception with a notifier will fail. If
+    #              we up the dependency to 0.5.4 (when it is released) we
+    #              can remove this workaround.
+    if getattr(value, '__module__', None) == 'mox':
+        return 'mock'
+
+    if level > max_depth:
+        return '?'
+
+    # The try block may not be necessary after the class check above,
+    # but just in case ...
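+    # An illustrative sketch of the recursion below:
+    #     to_primitive({'a': (1, datetime.datetime(2014, 1, 1))})
+    # recurses into the dict and the tuple and converts the datetime with
+    # timeutils.strtime(), leaving only JSON-friendly primitives.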
+ try: + recursive = functools.partial(to_primitive, + convert_instances=convert_instances, + convert_datetime=convert_datetime, + level=level, + max_depth=max_depth) + if isinstance(value, dict): + return dict((k, recursive(v)) for k, v in six.iteritems(value)) + elif isinstance(value, (list, tuple)): + return [recursive(lv) for lv in value] + + # It's not clear why xmlrpclib created their own DateTime type, but + # for our purposes, make it a datetime type which is explicitly + # handled + if isinstance(value, xmlrpclib.DateTime): + value = datetime.datetime(*tuple(value.timetuple())[:6]) + + if convert_datetime and isinstance(value, datetime.datetime): + return timeutils.strtime(value) + elif isinstance(value, gettextutils.Message): + return value.data + elif hasattr(value, 'iteritems'): + return recursive(dict(value.iteritems()), level=level + 1) + elif hasattr(value, '__iter__'): + return recursive(list(value)) + elif convert_instances and hasattr(value, '__dict__'): + # Likely an instance of something. Watch for cycles. + # Ignore class member vars. + return recursive(value.__dict__, level=level + 1) + elif netaddr and isinstance(value, netaddr.IPAddress): + return six.text_type(value) + else: + if any(test(value) for test in _nasty_type_tests): + return six.text_type(value) + return value + except TypeError: + # Class objects are tricky since they may define something like + # __iter__ defined but it isn't callable as list(). + return six.text_type(value) + + +def dumps(value, default=to_primitive, **kwargs): + return json.dumps(value, default=default, **kwargs) + + +def loads(s, encoding='utf-8'): + return json.loads(strutils.safe_decode(s, encoding)) + + +def load(fp, encoding='utf-8'): + return json.load(codecs.getreader(encoding)(fp)) + + +try: + import anyjson +except ImportError: + pass +else: + anyjson._modules.append((__name__, 'dumps', TypeError, + 'loads', ValueError, 'load')) + anyjson.force_implementation(__name__) diff --git a/cerberus/openstack/common/local.py b/cerberus/openstack/common/local.py new file mode 100644 index 0000000..0819d5b --- /dev/null +++ b/cerberus/openstack/common/local.py @@ -0,0 +1,45 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Local storage of variables using weak references""" + +import threading +import weakref + + +class WeakLocal(threading.local): + def __getattribute__(self, attr): + rval = super(WeakLocal, self).__getattribute__(attr) + if rval: + # NOTE(mikal): this bit is confusing. What is stored is a weak + # reference, not the value itself. We therefore need to lookup + # the weak reference and return the inner value here. 
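+            # Illustrative sketch: after `store.context = ctx`,
+            # `store.context` returns ctx while a strong reference to ctx
+            # exists elsewhere, and None once ctx has been garbage collected.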
+ rval = rval() + return rval + + def __setattr__(self, attr, value): + value = weakref.ref(value) + return super(WeakLocal, self).__setattr__(attr, value) + + +# NOTE(mikal): the name "store" should be deprecated in the future +store = WeakLocal() + +# A "weak" store uses weak references and allows an object to fall out of scope +# when it falls out of scope in the code that uses the thread local storage. A +# "strong" store will hold a reference to the object so that it never falls out +# of scope. +weak_store = WeakLocal() +strong_store = threading.local() diff --git a/cerberus/openstack/common/lockutils.py b/cerberus/openstack/common/lockutils.py new file mode 100644 index 0000000..3a54542 --- /dev/null +++ b/cerberus/openstack/common/lockutils.py @@ -0,0 +1,377 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import contextlib +import errno +import fcntl +import functools +import os +import shutil +import subprocess +import sys +import tempfile +import threading +import time +import weakref + +from oslo.config import cfg + +from cerberus.openstack.common import fileutils +from cerberus.openstack.common.gettextutils import _, _LE, _LI +from cerberus.openstack.common import log as logging + + +LOG = logging.getLogger(__name__) + + +util_opts = [ + cfg.BoolOpt('disable_process_locking', default=False, + help='Whether to disable inter-process locks'), + cfg.StrOpt('lock_path', + default=os.environ.get("CERBERUS_LOCK_PATH"), + help=('Directory to use for lock files.')) +] + + +CONF = cfg.CONF +CONF.register_opts(util_opts) + + +def set_defaults(lock_path): + cfg.set_defaults(util_opts, lock_path=lock_path) + + +class _FileLock(object): + """Lock implementation which allows multiple locks, working around + issues like bugs.debian.org/cgi-bin/bugreport.cgi?bug=632857 and does + not require any cleanup. Since the lock is always held on a file + descriptor rather than outside of the process, the lock gets dropped + automatically if the process crashes, even if __exit__ is not executed. + + There are no guarantees regarding usage by multiple green threads in a + single process here. This lock works only between processes. Exclusive + access between local threads should be achieved using the semaphores + in the @synchronized decorator. + + Note these locks are released when the descriptor is closed, so it's not + safe to close the file descriptor while another green thread holds the + lock. Just opening and closing the lock file can break synchronisation, + so lock files must be accessed only using this abstraction. 
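+
+    A minimal usage sketch (hypothetical path)::
+
+        with _FcntlLock('/var/lock/cerberus/mylock'):
+            ...  # critical section; the lock is dropped on exit or crash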
+    """
+
+    def __init__(self, name):
+        self.lockfile = None
+        self.fname = name
+
+    def acquire(self):
+        basedir = os.path.dirname(self.fname)
+
+        if not os.path.exists(basedir):
+            fileutils.ensure_tree(basedir)
+            LOG.info(_LI('Created lock path: %s'), basedir)
+
+        self.lockfile = open(self.fname, 'w')
+
+        while True:
+            try:
+                # Using non-blocking locks since green threads are not
+                # patched to deal with blocking locking calls.
+                # Also upon reading the MSDN docs for locking(), it seems
+                # to have a laughable 10 attempts "blocking" mechanism.
+                self.trylock()
+                LOG.debug('Got file lock "%s"', self.fname)
+                return True
+            except IOError as e:
+                if e.errno in (errno.EACCES, errno.EAGAIN):
+                    # external locks synchronise things like iptables
+                    # updates - give it some time to prevent busy spinning
+                    time.sleep(0.01)
+                else:
+                    raise threading.ThreadError(_("Unable to acquire lock on"
+                                                  " `%(filename)s` due to"
+                                                  " %(exception)s") %
+                                                {
+                                                    'filename': self.fname,
+                                                    'exception': e,
+                                                })
+
+    def __enter__(self):
+        self.acquire()
+        return self
+
+    def release(self):
+        try:
+            self.unlock()
+            self.lockfile.close()
+            LOG.debug('Released file lock "%s"', self.fname)
+        except IOError:
+            LOG.exception(_LE("Could not release the acquired lock `%s`"),
+                          self.fname)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.release()
+
+    def exists(self):
+        return os.path.exists(self.fname)
+
+    def trylock(self):
+        raise NotImplementedError()
+
+    def unlock(self):
+        raise NotImplementedError()
+
+
+class _WindowsLock(_FileLock):
+    def trylock(self):
+        msvcrt.locking(self.lockfile.fileno(), msvcrt.LK_NBLCK, 1)
+
+    def unlock(self):
+        msvcrt.locking(self.lockfile.fileno(), msvcrt.LK_UNLCK, 1)
+
+
+class _FcntlLock(_FileLock):
+    def trylock(self):
+        fcntl.lockf(self.lockfile, fcntl.LOCK_EX | fcntl.LOCK_NB)
+
+    def unlock(self):
+        fcntl.lockf(self.lockfile, fcntl.LOCK_UN)
+
+
+class _PosixLock(object):
+    def __init__(self, name):
+        # Hash the name because it's not valid to have POSIX semaphore
+        # names with things like / in them. Then use base64 to encode
+        # the digest() instead of taking the hexdigest() because the
+        # result is shorter and most systems can't have shm semaphore
+        # names longer than 31 characters.
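+        # (A 20-byte SHA1 digest encodes to 28 urlsafe base64 characters,
+        # so the resulting '/'-prefixed name is 29 characters long and
+        # stays under that limit.)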
+        h = hashlib.sha1()
+        h.update(name.encode('ascii'))
+        self.name = str((b'/' + base64.urlsafe_b64encode(
+            h.digest())).decode('ascii'))
+
+    def acquire(self, timeout=None):
+        self.semaphore = posix_ipc.Semaphore(self.name,
+                                             flags=posix_ipc.O_CREAT,
+                                             initial_value=1)
+        self.semaphore.acquire(timeout)
+        return self
+
+    def __enter__(self):
+        self.acquire()
+        return self
+
+    def release(self):
+        self.semaphore.release()
+        self.semaphore.close()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.release()
+
+    def exists(self):
+        try:
+            semaphore = posix_ipc.Semaphore(self.name)
+        except posix_ipc.ExistentialError:
+            return False
+        else:
+            semaphore.close()
+        return True
+
+
+if os.name == 'nt':
+    import msvcrt
+    InterProcessLock = _WindowsLock
+    FileLock = _WindowsLock
+else:
+    import base64
+    import hashlib
+    import posix_ipc
+    InterProcessLock = _PosixLock
+    FileLock = _FcntlLock
+
+_semaphores = weakref.WeakValueDictionary()
+_semaphores_lock = threading.Lock()
+
+
+def _get_lock_path(name, lock_file_prefix, lock_path=None):
+    # NOTE(mikal): the lock name cannot contain directory
+    # separators
+    name = name.replace(os.sep, '_')
+    if lock_file_prefix:
+        sep = '' if lock_file_prefix.endswith('-') else '-'
+        name = '%s%s%s' % (lock_file_prefix, sep, name)
+
+    local_lock_path = lock_path or CONF.lock_path
+
+    if not local_lock_path:
+        # NOTE(bnemec): Create a fake lock path for posix locks so we don't
+        # unnecessarily raise the RequiredOptError below.
+        if InterProcessLock is not _PosixLock:
+            raise cfg.RequiredOptError('lock_path')
+        local_lock_path = 'posixlock:/'
+
+    return os.path.join(local_lock_path, name)
+
+
+def external_lock(name, lock_file_prefix=None, lock_path=None):
+    LOG.debug('Attempting to grab external lock "%(lock)s"',
+              {'lock': name})
+
+    lock_file_path = _get_lock_path(name, lock_file_prefix, lock_path)
+
+    # NOTE(bnemec): If an explicit lock_path was passed to us then it
+    # means the caller is relying on file-based locking behavior, so
+    # we can't use posix locks for those calls.
+    if lock_path:
+        return FileLock(lock_file_path)
+    return InterProcessLock(lock_file_path)
+
+
+def remove_external_lock_file(name, lock_file_prefix=None):
+    """Remove an external lock file when it's not used anymore.
+
+    This will be helpful when we have a lot of lock files.
+    """
+    with internal_lock(name):
+        lock_file_path = _get_lock_path(name, lock_file_prefix)
+        try:
+            os.remove(lock_file_path)
+        except OSError:
+            LOG.info(_LI('Failed to remove file %(file)s'),
+                     {'file': lock_file_path})
+
+
+def internal_lock(name):
+    with _semaphores_lock:
+        try:
+            sem = _semaphores[name]
+        except KeyError:
+            sem = threading.Semaphore()
+            _semaphores[name] = sem
+
+    LOG.debug('Got semaphore "%(lock)s"', {'lock': name})
+    return sem
+
+
+@contextlib.contextmanager
+def lock(name, lock_file_prefix=None, external=False, lock_path=None):
+    """Context based lock
+
+    This function yields a `threading.Semaphore` instance (if we don't use
+    eventlet.monkey_patch(), else `semaphore.Semaphore`) unless external is
+    True, in which case, it'll yield an InterProcessLock instance.
+
+    :param lock_file_prefix: The lock_file_prefix argument is used to provide
+    lock files on disk with a meaningful prefix.
+
+    :param external: The external keyword argument denotes whether this lock
+    should work across multiple processes. This means that if two different
+    workers both run a method decorated with @synchronized('mylock',
+    external=True), only one of them will execute at a time.
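+
+    A usage sketch (hypothetical lock name)::
+
+        with lock('report-run', lock_file_prefix='cerberus-', external=True):
+            ...  # at most one process executes this block at a time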
+ """ + int_lock = internal_lock(name) + with int_lock: + if external and not CONF.disable_process_locking: + ext_lock = external_lock(name, lock_file_prefix, lock_path) + with ext_lock: + yield ext_lock + else: + yield int_lock + + +def synchronized(name, lock_file_prefix=None, external=False, lock_path=None): + """Synchronization decorator. + + Decorating a method like so:: + + @synchronized('mylock') + def foo(self, *args): + ... + + ensures that only one thread will execute the foo method at a time. + + Different methods can share the same lock:: + + @synchronized('mylock') + def foo(self, *args): + ... + + @synchronized('mylock') + def bar(self, *args): + ... + + This way only one of either foo or bar can be executing at a time. + """ + + def wrap(f): + @functools.wraps(f) + def inner(*args, **kwargs): + try: + with lock(name, lock_file_prefix, external, lock_path): + LOG.debug('Got semaphore / lock "%(function)s"', + {'function': f.__name__}) + return f(*args, **kwargs) + finally: + LOG.debug('Semaphore / lock released "%(function)s"', + {'function': f.__name__}) + return inner + return wrap + + +def synchronized_with_prefix(lock_file_prefix): + """Partial object generator for the synchronization decorator. + + Redefine @synchronized in each project like so:: + + (in nova/utils.py) + from nova.openstack.common import lockutils + + synchronized = lockutils.synchronized_with_prefix('nova-') + + + (in nova/foo.py) + from nova import utils + + @utils.synchronized('mylock') + def bar(self, *args): + ... + + The lock_file_prefix argument is used to provide lock files on disk with a + meaningful prefix. + """ + + return functools.partial(synchronized, lock_file_prefix=lock_file_prefix) + + +def main(argv): + """Create a dir for locks and pass it to command from arguments + + If you run this: + python -m openstack.common.lockutils python setup.py testr + + a temporary directory will be created for all your locks and passed to all + your tests in an environment variable. The temporary dir will be deleted + afterwards and the return value will be preserved. + """ + + lock_dir = tempfile.mkdtemp() + os.environ["CERBERUS_LOCK_PATH"] = lock_dir + try: + ret_val = subprocess.call(argv[1:]) + finally: + shutil.rmtree(lock_dir, ignore_errors=True) + return ret_val + + +if __name__ == '__main__': + sys.exit(main(sys.argv)) diff --git a/cerberus/openstack/common/log.py b/cerberus/openstack/common/log.py new file mode 100644 index 0000000..8cef7af --- /dev/null +++ b/cerberus/openstack/common/log.py @@ -0,0 +1,713 @@ +# Copyright 2011 OpenStack Foundation. +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""OpenStack logging handler. + +This module adds to logging functionality by adding the option to specify +a context object when calling the various log methods. If the context object +is not specified, default formatting is used. 
Additionally, an instance uuid
+may be passed as part of the log message, which is intended to make it easier
+for admins to find messages related to a specific instance.
+
+It also allows setting of formatting information through conf.
+
+"""
+
+import inspect
+import itertools
+import logging
+import logging.config
+import logging.handlers
+import os
+import re
+import sys
+import traceback
+
+from oslo.config import cfg
+import six
+from six import moves
+
+from cerberus.openstack.common.gettextutils import _
+from cerberus.openstack.common import importutils
+from cerberus.openstack.common import jsonutils
+from cerberus.openstack.common import local
+
+
+_DEFAULT_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
+
+_SANITIZE_KEYS = ['adminPass', 'admin_pass', 'password', 'admin_password']
+
+# NOTE(ldbragst): Let's build a list of regex objects using the list of
+# _SANITIZE_KEYS we already have. This way, we only have to add the new key
+# to the list of _SANITIZE_KEYS and we can generate regular expressions
+# for XML and JSON automatically.
+_SANITIZE_PATTERNS = []
+_FORMAT_PATTERNS = [r'(%(key)s\s*[=]\s*[\"\']).*?([\"\'])',
+                    r'(<%(key)s>).*?(</%(key)s>)',
+                    r'([\"\']%(key)s[\"\']\s*:\s*[\"\']).*?([\"\'])',
+                    r'([\'"].*?%(key)s[\'"]\s*:\s*u?[\'"]).*?([\'"])']
+
+for key in _SANITIZE_KEYS:
+    for pattern in _FORMAT_PATTERNS:
+        reg_ex = re.compile(pattern % {'key': key}, re.DOTALL)
+        _SANITIZE_PATTERNS.append(reg_ex)
+
+
+common_cli_opts = [
+    cfg.BoolOpt('debug',
+                short='d',
+                default=False,
+                help='Print debugging output (set logging level to '
+                     'DEBUG instead of default WARNING level).'),
+    cfg.BoolOpt('verbose',
+                short='v',
+                default=False,
+                help='Print more verbose output (set logging level to '
+                     'INFO instead of default WARNING level).'),
+]
+
+logging_cli_opts = [
+    cfg.StrOpt('log-config-append',
+               metavar='PATH',
+               deprecated_name='log-config',
+               help='The name of logging configuration file. It does not '
+                    'disable existing loggers, but just appends specified '
+                    'logging configuration to any other existing logging '
+                    'options. Please see the Python logging module '
+                    'documentation for details on logging configuration '
+                    'files.'),
+    cfg.StrOpt('log-format',
+               default=None,
+               metavar='FORMAT',
+               help='DEPRECATED. '
+                    'A logging.Formatter log message format string which may '
+                    'use any of the available logging.LogRecord attributes. '
+                    'This option is deprecated. Please use '
+                    'logging_context_format_string and '
+                    'logging_default_format_string instead.'),
+    cfg.StrOpt('log-date-format',
+               default=_DEFAULT_LOG_DATE_FORMAT,
+               metavar='DATE_FORMAT',
+               help='Format string for %%(asctime)s in log records. '
+                    'Default: %(default)s'),
+    cfg.StrOpt('log-file',
+               metavar='PATH',
+               deprecated_name='logfile',
+               help='(Optional) Name of log file to output to. '
+                    'If no default is set, logging will go to stdout.'),
+    cfg.StrOpt('log-dir',
+               deprecated_name='logdir',
+               help='(Optional) The base directory used for relative '
+                    '--log-file paths'),
+    cfg.BoolOpt('use-syslog',
+                default=False,
+                help='Use syslog for logging. '
+                     'Existing syslog format is DEPRECATED during I, '
+                     'and then will be changed in J to honor RFC5424'),
+    cfg.BoolOpt('use-syslog-rfc-format',
+                # TODO(bogdando) remove or use True after existing
+                # syslog format deprecation in J
+                default=False,
+                help='(Optional) Use syslog rfc5424 format for logging. '
+                     'If enabled, will add APP-NAME (RFC5424) before the '
+                     'MSG part of the syslog message. 
The old format ' + 'without APP-NAME is deprecated in I, ' + 'and will be removed in J.'), + cfg.StrOpt('syslog-log-facility', + default='LOG_USER', + help='Syslog facility to receive log lines') +] + +generic_log_opts = [ + cfg.BoolOpt('use_stderr', + default=True, + help='Log output to standard error') +] + +log_opts = [ + cfg.StrOpt('logging_context_format_string', + default='%(asctime)s.%(msecs)03d %(process)d %(levelname)s ' + '%(name)s [%(request_id)s %(user_identity)s] ' + '%(instance)s%(message)s', + help='Format string to use for log messages with context'), + cfg.StrOpt('logging_default_format_string', + default='%(asctime)s.%(msecs)03d %(process)d %(levelname)s ' + '%(name)s [-] %(instance)s%(message)s', + help='Format string to use for log messages without context'), + cfg.StrOpt('logging_debug_format_suffix', + default='%(funcName)s %(pathname)s:%(lineno)d', + help='Data to append to log format when level is DEBUG'), + cfg.StrOpt('logging_exception_prefix', + default='%(asctime)s.%(msecs)03d %(process)d TRACE %(name)s ' + '%(instance)s', + help='Prefix each line of exception output with this format'), + cfg.ListOpt('default_log_levels', + default=[ + 'amqp=WARN', + 'amqplib=WARN', + 'boto=WARN', + 'qpid=WARN', + 'sqlalchemy=WARN', + 'suds=INFO', + 'oslo.messaging=INFO', + 'iso8601=WARN', + 'requests.packages.urllib3.connectionpool=WARN' + ], + help='List of logger=LEVEL pairs'), + cfg.BoolOpt('publish_errors', + default=False, + help='Publish error events'), + cfg.BoolOpt('fatal_deprecations', + default=False, + help='Make deprecations fatal'), + + # NOTE(mikal): there are two options here because sometimes we are handed + # a full instance (and could include more information), and other times we + # are just handed a UUID for the instance. + cfg.StrOpt('instance_format', + default='[instance: %(uuid)s] ', + help='If an instance is passed with the log message, format ' + 'it like this'), + cfg.StrOpt('instance_uuid_format', + default='[instance: %(uuid)s] ', + help='If an instance UUID is passed with the log message, ' + 'format it like this'), +] + +CONF = cfg.CONF +CONF.register_cli_opts(common_cli_opts) +CONF.register_cli_opts(logging_cli_opts) +CONF.register_opts(generic_log_opts) +CONF.register_opts(log_opts) + +# our new audit level +# NOTE(jkoelker) Since we synthesized an audit level, make the logging +# module aware of it so it acts like other levels. +logging.AUDIT = logging.INFO + 1 +logging.addLevelName(logging.AUDIT, 'AUDIT') + + +try: + NullHandler = logging.NullHandler +except AttributeError: # NOTE(jkoelker) NullHandler added in Python 2.7 + class NullHandler(logging.Handler): + def handle(self, record): + pass + + def emit(self, record): + pass + + def createLock(self): + self.lock = None + + +def _dictify_context(context): + if context is None: + return None + if not isinstance(context, dict) and getattr(context, 'to_dict', None): + context = context.to_dict() + return context + + +def _get_binary_name(): + return os.path.basename(inspect.stack()[-1][1]) + + +def _get_log_file_path(binary=None): + logfile = CONF.log_file + logdir = CONF.log_dir + + if logfile and not logdir: + return logfile + + if logfile and logdir: + return os.path.join(logdir, logfile) + + if logdir: + binary = binary or _get_binary_name() + return '%s.log' % (os.path.join(logdir, binary),) + + return None + + +def mask_password(message, secret="***"): + """Replace password with 'secret' in message. + + :param message: The string which includes security information. 
+ :param secret: value with which to replace passwords. + :returns: The unicode value of message with the password fields masked. + + For example: + + >>> mask_password("'adminPass' : 'aaaaa'") + "'adminPass' : '***'" + >>> mask_password("'admin_pass' : 'aaaaa'") + "'admin_pass' : '***'" + >>> mask_password('"password" : "aaaaa"') + '"password" : "***"' + >>> mask_password("'original_password' : 'aaaaa'") + "'original_password' : '***'" + >>> mask_password("u'original_password' : u'aaaaa'") + "u'original_password' : u'***'" + """ + message = six.text_type(message) + + # NOTE(ldbragst): Check to see if anything in message contains any key + # specified in _SANITIZE_KEYS, if not then just return the message since + # we don't have to mask any passwords. + if not any(key in message for key in _SANITIZE_KEYS): + return message + + secret = r'\g<1>' + secret + r'\g<2>' + for pattern in _SANITIZE_PATTERNS: + message = re.sub(pattern, secret, message) + return message + + +class BaseLoggerAdapter(logging.LoggerAdapter): + + def audit(self, msg, *args, **kwargs): + self.log(logging.AUDIT, msg, *args, **kwargs) + + +class LazyAdapter(BaseLoggerAdapter): + def __init__(self, name='unknown', version='unknown'): + self._logger = None + self.extra = {} + self.name = name + self.version = version + + @property + def logger(self): + if not self._logger: + self._logger = getLogger(self.name, self.version) + return self._logger + + +class ContextAdapter(BaseLoggerAdapter): + warn = logging.LoggerAdapter.warning + + def __init__(self, logger, project_name, version_string): + self.logger = logger + self.project = project_name + self.version = version_string + self._deprecated_messages_sent = dict() + + @property + def handlers(self): + return self.logger.handlers + + def deprecated(self, msg, *args, **kwargs): + """Call this method when a deprecated feature is used. + + If the system is configured for fatal deprecations then the message + is logged at the 'critical' level and :class:`DeprecatedConfig` will + be raised. + + Otherwise, the message will be logged (once) at the 'warn' level. + + :raises: :class:`DeprecatedConfig` if the system is configured for + fatal deprecations. + + """ + stdmsg = _("Deprecated: %s") % msg + if CONF.fatal_deprecations: + self.critical(stdmsg, *args, **kwargs) + raise DeprecatedConfig(msg=stdmsg) + + # Using a list because a tuple with dict can't be stored in a set. + sent_args = self._deprecated_messages_sent.setdefault(msg, list()) + + if args in sent_args: + # Already logged this message, so don't log it again. 
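+            # (Keyed on the formatting args, so the same template logged
+            # with different args still warns once per distinct args tuple.)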
+ return + + sent_args.append(args) + self.warn(stdmsg, *args, **kwargs) + + def process(self, msg, kwargs): + # NOTE(mrodden): catch any Message/other object and + # coerce to unicode before they can get + # to the python logging and possibly + # cause string encoding trouble + if not isinstance(msg, six.string_types): + msg = six.text_type(msg) + + if 'extra' not in kwargs: + kwargs['extra'] = {} + extra = kwargs['extra'] + + context = kwargs.pop('context', None) + if not context: + context = getattr(local.store, 'context', None) + if context: + extra.update(_dictify_context(context)) + + instance = kwargs.pop('instance', None) + instance_uuid = (extra.get('instance_uuid') or + kwargs.pop('instance_uuid', None)) + instance_extra = '' + if instance: + instance_extra = CONF.instance_format % instance + elif instance_uuid: + instance_extra = (CONF.instance_uuid_format + % {'uuid': instance_uuid}) + extra['instance'] = instance_extra + + extra.setdefault('user_identity', kwargs.pop('user_identity', None)) + + extra['project'] = self.project + extra['version'] = self.version + extra['extra'] = extra.copy() + return msg, kwargs + + +class JSONFormatter(logging.Formatter): + def __init__(self, fmt=None, datefmt=None): + # NOTE(jkoelker) we ignore the fmt argument, but its still there + # since logging.config.fileConfig passes it. + self.datefmt = datefmt + + def formatException(self, ei, strip_newlines=True): + lines = traceback.format_exception(*ei) + if strip_newlines: + lines = [moves.filter( + lambda x: x, + line.rstrip().splitlines()) for line in lines] + lines = list(itertools.chain(*lines)) + return lines + + def format(self, record): + message = {'message': record.getMessage(), + 'asctime': self.formatTime(record, self.datefmt), + 'name': record.name, + 'msg': record.msg, + 'args': record.args, + 'levelname': record.levelname, + 'levelno': record.levelno, + 'pathname': record.pathname, + 'filename': record.filename, + 'module': record.module, + 'lineno': record.lineno, + 'funcname': record.funcName, + 'created': record.created, + 'msecs': record.msecs, + 'relative_created': record.relativeCreated, + 'thread': record.thread, + 'thread_name': record.threadName, + 'process_name': record.processName, + 'process': record.process, + 'traceback': None} + + if hasattr(record, 'extra'): + message['extra'] = record.extra + + if record.exc_info: + message['traceback'] = self.formatException(record.exc_info) + + return jsonutils.dumps(message) + + +def _create_logging_excepthook(product_name): + def logging_excepthook(exc_type, value, tb): + extra = {} + if CONF.verbose or CONF.debug: + extra['exc_info'] = (exc_type, value, tb) + getLogger(product_name).critical( + "".join(traceback.format_exception_only(exc_type, value)), + **extra) + return logging_excepthook + + +class LogConfigError(Exception): + + message = _('Error loading logging config %(log_config)s: %(err_msg)s') + + def __init__(self, log_config, err_msg): + self.log_config = log_config + self.err_msg = err_msg + + def __str__(self): + return self.message % dict(log_config=self.log_config, + err_msg=self.err_msg) + + +def _load_log_config(log_config_append): + try: + logging.config.fileConfig(log_config_append, + disable_existing_loggers=False) + except moves.configparser.Error as exc: + raise LogConfigError(log_config_append, str(exc)) + + +def setup(product_name, version='unknown'): + """Setup logging.""" + if CONF.log_config_append: + _load_log_config(CONF.log_config_append) + else: + _setup_logging_from_conf(product_name, version) + 
sys.excepthook = _create_logging_excepthook(product_name) + + +def set_defaults(logging_context_format_string): + cfg.set_defaults(log_opts, + logging_context_format_string= + logging_context_format_string) + + +def _find_facility_from_conf(): + facility_names = logging.handlers.SysLogHandler.facility_names + facility = getattr(logging.handlers.SysLogHandler, + CONF.syslog_log_facility, + None) + + if facility is None and CONF.syslog_log_facility in facility_names: + facility = facility_names.get(CONF.syslog_log_facility) + + if facility is None: + valid_facilities = facility_names.keys() + consts = ['LOG_AUTH', 'LOG_AUTHPRIV', 'LOG_CRON', 'LOG_DAEMON', + 'LOG_FTP', 'LOG_KERN', 'LOG_LPR', 'LOG_MAIL', 'LOG_NEWS', + 'LOG_AUTH', 'LOG_SYSLOG', 'LOG_USER', 'LOG_UUCP', + 'LOG_LOCAL0', 'LOG_LOCAL1', 'LOG_LOCAL2', 'LOG_LOCAL3', + 'LOG_LOCAL4', 'LOG_LOCAL5', 'LOG_LOCAL6', 'LOG_LOCAL7'] + valid_facilities.extend(consts) + raise TypeError(_('syslog facility must be one of: %s') % + ', '.join("'%s'" % fac + for fac in valid_facilities)) + + return facility + + +class RFCSysLogHandler(logging.handlers.SysLogHandler): + def __init__(self, *args, **kwargs): + self.binary_name = _get_binary_name() + super(RFCSysLogHandler, self).__init__(*args, **kwargs) + + def format(self, record): + msg = super(RFCSysLogHandler, self).format(record) + msg = self.binary_name + ' ' + msg + return msg + + +def _setup_logging_from_conf(project, version): + log_root = getLogger(None).logger + for handler in log_root.handlers: + log_root.removeHandler(handler) + + if CONF.use_syslog: + facility = _find_facility_from_conf() + # TODO(bogdando) use the format provided by RFCSysLogHandler + # after existing syslog format deprecation in J + if CONF.use_syslog_rfc_format: + syslog = RFCSysLogHandler(address='/dev/log', + facility=facility) + else: + syslog = logging.handlers.SysLogHandler(address='/dev/log', + facility=facility) + log_root.addHandler(syslog) + + logpath = _get_log_file_path() + if logpath: + filelog = logging.handlers.WatchedFileHandler(logpath) + log_root.addHandler(filelog) + + if CONF.use_stderr: + streamlog = ColorHandler() + log_root.addHandler(streamlog) + + elif not logpath: + # pass sys.stdout as a positional argument + # python2.6 calls the argument strm, in 2.7 it's stream + streamlog = logging.StreamHandler(sys.stdout) + log_root.addHandler(streamlog) + + if CONF.publish_errors: + handler = importutils.import_object( + "cerberus.openstack.common.log_handler.PublishErrorsHandler", + logging.ERROR) + log_root.addHandler(handler) + + datefmt = CONF.log_date_format + for handler in log_root.handlers: + # NOTE(alaski): CONF.log_format overrides everything currently. This + # should be deprecated in favor of context aware formatting. 
+ if CONF.log_format: + handler.setFormatter(logging.Formatter(fmt=CONF.log_format, + datefmt=datefmt)) + log_root.info('Deprecated: log_format is now deprecated and will ' + 'be removed in the next release') + else: + handler.setFormatter(ContextFormatter(project=project, + version=version, + datefmt=datefmt)) + + if CONF.debug: + log_root.setLevel(logging.DEBUG) + elif CONF.verbose: + log_root.setLevel(logging.INFO) + else: + log_root.setLevel(logging.WARNING) + + for pair in CONF.default_log_levels: + mod, _sep, level_name = pair.partition('=') + level = logging.getLevelName(level_name) + logger = logging.getLogger(mod) + logger.setLevel(level) + +_loggers = {} + + +def getLogger(name='unknown', version='unknown'): + if name not in _loggers: + _loggers[name] = ContextAdapter(logging.getLogger(name), + name, + version) + return _loggers[name] + + +def getLazyLogger(name='unknown', version='unknown'): + """Returns lazy logger. + + Creates a pass-through logger that does not create the real logger + until it is really needed and delegates all calls to the real logger + once it is created. + """ + return LazyAdapter(name, version) + + +class WritableLogger(object): + """A thin wrapper that responds to `write` and logs.""" + + def __init__(self, logger, level=logging.INFO): + self.logger = logger + self.level = level + + def write(self, msg): + self.logger.log(self.level, msg.rstrip()) + + +class ContextFormatter(logging.Formatter): + """A context.RequestContext aware formatter configured through flags. + + The flags used to set format strings are: logging_context_format_string + and logging_default_format_string. You can also specify + logging_debug_format_suffix to append extra formatting if the log level is + debug. + + For information about what variables are available for the formatter see: + http://docs.python.org/library/logging.html#formatter + + If available, uses the context value stored in TLS - local.store.context + + """ + + def __init__(self, *args, **kwargs): + """Initialize ContextFormatter instance + + Takes additional keyword arguments which can be used in the message + format string. 
+ + :keyword project: project name + :type project: string + :keyword version: project version + :type version: string + + """ + + self.project = kwargs.pop('project', 'unknown') + self.version = kwargs.pop('version', 'unknown') + + logging.Formatter.__init__(self, *args, **kwargs) + + def format(self, record): + """Uses contextstring if request_id is set, otherwise default.""" + + # store project info + record.project = self.project + record.version = self.version + + # store request info + context = getattr(local.store, 'context', None) + if context: + d = _dictify_context(context) + for k, v in d.items(): + setattr(record, k, v) + + # NOTE(sdague): default the fancier formatting params + # to an empty string so we don't throw an exception if + # they get used + for key in ('instance', 'color', 'user_identity'): + if key not in record.__dict__: + record.__dict__[key] = '' + + if record.__dict__.get('request_id'): + self._fmt = CONF.logging_context_format_string + else: + self._fmt = CONF.logging_default_format_string + + if (record.levelno == logging.DEBUG and + CONF.logging_debug_format_suffix): + self._fmt += " " + CONF.logging_debug_format_suffix + + # Cache this on the record, Logger will respect our formatted copy + if record.exc_info: + record.exc_text = self.formatException(record.exc_info, record) + return logging.Formatter.format(self, record) + + def formatException(self, exc_info, record=None): + """Format exception output with CONF.logging_exception_prefix.""" + if not record: + return logging.Formatter.formatException(self, exc_info) + + stringbuffer = moves.StringIO() + traceback.print_exception(exc_info[0], exc_info[1], exc_info[2], + None, stringbuffer) + lines = stringbuffer.getvalue().split('\n') + stringbuffer.close() + + if CONF.logging_exception_prefix.find('%(asctime)') != -1: + record.asctime = self.formatTime(record, self.datefmt) + + formatted_lines = [] + for line in lines: + pl = CONF.logging_exception_prefix % record.__dict__ + fl = '%s%s' % (pl, line) + formatted_lines.append(fl) + return '\n'.join(formatted_lines) + + +class ColorHandler(logging.StreamHandler): + LEVEL_COLORS = { + logging.DEBUG: '\033[00;32m', # GREEN + logging.INFO: '\033[00;36m', # CYAN + logging.AUDIT: '\033[01;36m', # BOLD CYAN + logging.WARN: '\033[01;33m', # BOLD YELLOW + logging.ERROR: '\033[01;31m', # BOLD RED + logging.CRITICAL: '\033[01;31m', # BOLD RED + } + + def format(self, record): + record.color = self.LEVEL_COLORS[record.levelno] + return logging.StreamHandler.format(self, record) + + +class DeprecatedConfig(Exception): + message = _("Fatal call to deprecated config: %(msg)s") + + def __init__(self, msg): + super(Exception, self).__init__(self.message % dict(msg=msg)) diff --git a/cerberus/openstack/common/log_handler.py b/cerberus/openstack/common/log_handler.py new file mode 100644 index 0000000..836eab3 --- /dev/null +++ b/cerberus/openstack/common/log_handler.py @@ -0,0 +1,30 @@ +# Copyright 2013 IBM Corp. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
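+
+# A wiring sketch (hypothetical): once a notification driver is configured,
+# attach this handler so that ERROR records are published as notifications,
+# e.g.
+#
+#     handler = log_handler.PublishErrorsHandler(logging.ERROR)
+#     logging.getLogger().addHandler(handler)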
+ +import logging + +from oslo.config import cfg + +from cerberus.openstack.common import notifier + + +class PublishErrorsHandler(logging.Handler): + def emit(self, record): + if ('cerberus.openstack.common.notifier.log_notifier' in + cfg.CONF.notification_driver): + return + notifier.api.notify(None, 'error.publisher', + 'error_notification', + notifier.api.ERROR, + dict(error=record.getMessage())) diff --git a/cerberus/openstack/common/loopingcall.py b/cerberus/openstack/common/loopingcall.py new file mode 100644 index 0000000..3b4496c --- /dev/null +++ b/cerberus/openstack/common/loopingcall.py @@ -0,0 +1,145 @@ +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2011 Justin Santa Barbara +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import sys + +from eventlet import event +from eventlet import greenthread + +from cerberus.openstack.common.gettextutils import _LE, _LW +from cerberus.openstack.common import log as logging +from cerberus.openstack.common import timeutils + +LOG = logging.getLogger(__name__) + + +class LoopingCallDone(Exception): + """Exception to break out and stop a LoopingCall. + + The poll-function passed to LoopingCall can raise this exception to + break out of the loop normally. This is somewhat analogous to + StopIteration. 
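+
+    A sketch of the pattern (hypothetical poll function)::
+
+        def _poll():
+            if _work_finished():
+                raise LoopingCallDone(retvalue=True)
+
+        timer = FixedIntervalLoopingCall(_poll)
+        timer.start(interval=1.0)
+        timer.wait()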
+ + An optional return-value can be included as the argument to the exception; + this return-value will be returned by LoopingCall.wait() + + """ + + def __init__(self, retvalue=True): + """:param retvalue: Value that LoopingCall.wait() should return.""" + self.retvalue = retvalue + + +class LoopingCallBase(object): + def __init__(self, f=None, *args, **kw): + self.args = args + self.kw = kw + self.f = f + self._running = False + self.done = None + + def stop(self): + self._running = False + + def wait(self): + return self.done.wait() + + +class FixedIntervalLoopingCall(LoopingCallBase): + """A fixed interval looping call.""" + + def start(self, interval, initial_delay=None): + self._running = True + done = event.Event() + + def _inner(): + if initial_delay: + greenthread.sleep(initial_delay) + + try: + while self._running: + start = timeutils.utcnow() + self.f(*self.args, **self.kw) + end = timeutils.utcnow() + if not self._running: + break + delay = interval - timeutils.delta_seconds(start, end) + if delay <= 0: + LOG.warn(_LW('task run outlasted interval by %s sec') % + -delay) + greenthread.sleep(delay if delay > 0 else 0) + except LoopingCallDone as e: + self.stop() + done.send(e.retvalue) + except Exception: + LOG.exception(_LE('in fixed duration looping call')) + done.send_exception(*sys.exc_info()) + return + else: + done.send(True) + + self.done = done + + self.gt = greenthread.spawn(_inner) + return self.done + + +# TODO(mikal): this class name is deprecated in Havana and should be removed +# in the I release +LoopingCall = FixedIntervalLoopingCall + + +class DynamicLoopingCall(LoopingCallBase): + """A looping call which sleeps until the next known event. + + The function called should return how long to sleep for before being + called again. + """ + + def start(self, initial_delay=None, periodic_interval_max=None): + self._running = True + done = event.Event() + + def _inner(): + if initial_delay: + greenthread.sleep(initial_delay) + + try: + while self._running: + idle = self.f(*self.args, **self.kw) + if not self._running: + break + + if periodic_interval_max is not None: + idle = min(idle, periodic_interval_max) + LOG.debug('Dynamic looping call sleeping for %.02f ' + 'seconds', idle) + greenthread.sleep(idle) + except LoopingCallDone as e: + self.stop() + done.send(e.retvalue) + except Exception: + LOG.exception(_LE('in dynamic looping call')) + done.send_exception(*sys.exc_info()) + return + else: + done.send(True) + + self.done = done + + greenthread.spawn(_inner) + return self.done diff --git a/cerberus/openstack/common/network_utils.py b/cerberus/openstack/common/network_utils.py new file mode 100644 index 0000000..fa812b2 --- /dev/null +++ b/cerberus/openstack/common/network_utils.py @@ -0,0 +1,108 @@ +# Copyright 2012 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Network-related utilities and helper functions. 
+""" + +# TODO(jd) Use six.moves once +# https://bitbucket.org/gutworth/six/pull-request/28 +# is merged +try: + import urllib.parse + SplitResult = urllib.parse.SplitResult +except ImportError: + import urlparse + SplitResult = urlparse.SplitResult + +from six.moves.urllib import parse + + +def parse_host_port(address, default_port=None): + """Interpret a string as a host:port pair. + + An IPv6 address MUST be escaped if accompanied by a port, + because otherwise ambiguity ensues: 2001:db8:85a3::8a2e:370:7334 + means both [2001:db8:85a3::8a2e:370:7334] and + [2001:db8:85a3::8a2e:370]:7334. + + >>> parse_host_port('server01:80') + ('server01', 80) + >>> parse_host_port('server01') + ('server01', None) + >>> parse_host_port('server01', default_port=1234) + ('server01', 1234) + >>> parse_host_port('[::1]:80') + ('::1', 80) + >>> parse_host_port('[::1]') + ('::1', None) + >>> parse_host_port('[::1]', default_port=1234) + ('::1', 1234) + >>> parse_host_port('2001:db8:85a3::8a2e:370:7334', default_port=1234) + ('2001:db8:85a3::8a2e:370:7334', 1234) + + """ + if address[0] == '[': + # Escaped ipv6 + _host, _port = address[1:].split(']') + host = _host + if ':' in _port: + port = _port.split(':')[1] + else: + port = default_port + else: + if address.count(':') == 1: + host, port = address.split(':') + else: + # 0 means ipv4, >1 means ipv6. + # We prohibit unescaped ipv6 addresses with port. + host = address + port = default_port + + return (host, None if port is None else int(port)) + + +class ModifiedSplitResult(SplitResult): + """Split results class for urlsplit.""" + + # NOTE(dims): The functions below are needed for Python 2.6.x. + # We can remove these when we drop support for 2.6.x. + @property + def hostname(self): + netloc = self.netloc.split('@', 1)[-1] + host, port = parse_host_port(netloc) + return host + + @property + def port(self): + netloc = self.netloc.split('@', 1)[-1] + host, port = parse_host_port(netloc) + return port + + +def urlsplit(url, scheme='', allow_fragments=True): + """Parse a URL using urlparse.urlsplit(), splitting query and fragments. + This function papers over Python issue9374 when needed. + + The parameters are the same as urlparse.urlsplit. + """ + scheme, netloc, path, query, fragment = parse.urlsplit( + url, scheme, allow_fragments) + if allow_fragments and '#' in path: + path, fragment = path.split('#', 1) + if '?' in path: + path, query = path.split('?', 1) + return ModifiedSplitResult(scheme, netloc, + path, query, fragment) diff --git a/cerberus/openstack/common/notifier/__init__.py b/cerberus/openstack/common/notifier/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cerberus/openstack/common/notifier/api.py b/cerberus/openstack/common/notifier/api.py new file mode 100644 index 0000000..3603fd9 --- /dev/null +++ b/cerberus/openstack/common/notifier/api.py @@ -0,0 +1,173 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +import socket +import uuid + +from oslo.config import cfg + +from cerberus.openstack.common import context +from cerberus.openstack.common.gettextutils import _, _LE +from cerberus.openstack.common import importutils +from cerberus.openstack.common import jsonutils +from cerberus.openstack.common import log as logging +from cerberus.openstack.common import timeutils + + +LOG = logging.getLogger(__name__) + +notifier_opts = [ + cfg.MultiStrOpt('notification_driver', + default=[], + help='Driver or drivers to handle sending notifications'), + cfg.StrOpt('default_notification_level', + default='INFO', + help='Default notification level for outgoing notifications'), + cfg.StrOpt('default_publisher_id', + default=None, + help='Default publisher_id for outgoing notifications'), +] + +CONF = cfg.CONF +CONF.register_opts(notifier_opts) + +WARN = 'WARN' +INFO = 'INFO' +ERROR = 'ERROR' +CRITICAL = 'CRITICAL' +DEBUG = 'DEBUG' + +log_levels = (DEBUG, WARN, INFO, ERROR, CRITICAL) + + +class BadPriorityException(Exception): + pass + + +def notify_decorator(name, fn): + """Decorator for notify which is used from utils.monkey_patch(). + + :param name: name of the function + :param function: - object of the function + :returns: function -- decorated function + + """ + def wrapped_func(*args, **kwarg): + body = {} + body['args'] = [] + body['kwarg'] = {} + for arg in args: + body['args'].append(arg) + for key in kwarg: + body['kwarg'][key] = kwarg[key] + + ctxt = context.get_context_from_function_and_args(fn, args, kwarg) + notify(ctxt, + CONF.default_publisher_id or socket.gethostname(), + name, + CONF.default_notification_level, + body) + return fn(*args, **kwarg) + return wrapped_func + + +def publisher_id(service, host=None): + if not host: + try: + host = CONF.host + except AttributeError: + host = CONF.default_publisher_id or socket.gethostname() + return "%s.%s" % (service, host) + + +def notify(context, publisher_id, event_type, priority, payload): + """Sends a notification using the specified driver + + :param publisher_id: the source worker_type.host of the message + :param event_type: the literal type of event (ex. Instance Creation) + :param priority: patterned after the enumeration of Python logging + levels in the set (DEBUG, WARN, INFO, ERROR, CRITICAL) + :param payload: A python dictionary of attributes + + Outgoing message format includes the above parameters, and appends the + following: + + message_id + a UUID representing the id for this notification + + timestamp + the GMT timestamp the notification was sent at + + The composite message will be constructed as a dictionary of the above + attributes, which will then be sent via the transport mechanism defined + by the driver. + + Message example:: + + {'message_id': str(uuid.uuid4()), + 'publisher_id': 'compute.host1', + 'timestamp': timeutils.utcnow(), + 'priority': 'WARN', + 'event_type': 'compute.create_instance', + 'payload': {'instance_id': 12, ... }} + + """ + if priority not in log_levels: + raise BadPriorityException( + _('%s not in valid priorities') % priority) + + # Ensure everything is JSON serializable. + payload = jsonutils.to_primitive(payload, convert_instances=True) + + msg = dict(message_id=str(uuid.uuid4()), + publisher_id=publisher_id, + event_type=event_type, + priority=priority, + payload=payload, + timestamp=str(timeutils.utcnow())) + + for driver in _get_drivers(): + try: + driver.notify(context, msg) + except Exception as e: + LOG.exception(_LE("Problem '%(e)s' attempting to " + "send to notification system. 
" + "Payload=%(payload)s") + % dict(e=e, payload=payload)) + + +_drivers = None + + +def _get_drivers(): + """Instantiate, cache, and return drivers based on the CONF.""" + global _drivers + if _drivers is None: + _drivers = {} + for notification_driver in CONF.notification_driver: + try: + driver = importutils.import_module(notification_driver) + _drivers[notification_driver] = driver + except ImportError: + LOG.exception(_LE("Failed to load notifier %s. " + "These notifications will not be sent.") % + notification_driver) + return _drivers.values() + + +def _reset_drivers(): + """Used by unit tests to reset the drivers.""" + global _drivers + _drivers = None diff --git a/cerberus/openstack/common/notifier/log_notifier.py b/cerberus/openstack/common/notifier/log_notifier.py new file mode 100644 index 0000000..62e9497 --- /dev/null +++ b/cerberus/openstack/common/notifier/log_notifier.py @@ -0,0 +1,37 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo.config import cfg + +from cerberus.openstack.common import jsonutils +from cerberus.openstack.common import log as logging + + +CONF = cfg.CONF + + +def notify(_context, message): + """Notifies the recipient of the desired event given the model. + + Log notifications using OpenStack's default logging system. + """ + + priority = message.get('priority', + CONF.default_notification_level) + priority = priority.lower() + logger = logging.getLogger( + 'cerberus.openstack.common.notification.%s' % + message['event_type']) + getattr(logger, priority)(jsonutils.dumps(message)) diff --git a/cerberus/openstack/common/notifier/no_op_notifier.py b/cerberus/openstack/common/notifier/no_op_notifier.py new file mode 100644 index 0000000..13d946e --- /dev/null +++ b/cerberus/openstack/common/notifier/no_op_notifier.py @@ -0,0 +1,19 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +def notify(_context, message): + """Notifies the recipient of the desired event given the model.""" + pass diff --git a/cerberus/openstack/common/notifier/proxy.py b/cerberus/openstack/common/notifier/proxy.py new file mode 100644 index 0000000..948c300 --- /dev/null +++ b/cerberus/openstack/common/notifier/proxy.py @@ -0,0 +1,77 @@ +# Copyright 2013 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +A temporary helper which emulates oslo.messaging.Notifier. + +This helper method allows us to do the tedious porting to the new Notifier API +as a standalone commit so that the commit which switches us to oslo.messaging +is smaller and easier to review. This file will be removed as part of that +commit. +""" + +from oslo.config import cfg + +from cerberus.openstack.common.notifier import api as notifier_api + +CONF = cfg.CONF + + +class Notifier(object): + + def __init__(self, publisher_id): + super(Notifier, self).__init__() + self.publisher_id = publisher_id + + _marker = object() + + def prepare(self, publisher_id=_marker): + ret = self.__class__(self.publisher_id) + if publisher_id is not self._marker: + ret.publisher_id = publisher_id + return ret + + def _notify(self, ctxt, event_type, payload, priority): + notifier_api.notify(ctxt, + self.publisher_id, + event_type, + priority, + payload) + + def audit(self, ctxt, event_type, payload): + # No audit in old notifier. + self._notify(ctxt, event_type, payload, 'INFO') + + def debug(self, ctxt, event_type, payload): + self._notify(ctxt, event_type, payload, 'DEBUG') + + def info(self, ctxt, event_type, payload): + self._notify(ctxt, event_type, payload, 'INFO') + + def warn(self, ctxt, event_type, payload): + self._notify(ctxt, event_type, payload, 'WARN') + + warning = warn + + def error(self, ctxt, event_type, payload): + self._notify(ctxt, event_type, payload, 'ERROR') + + def critical(self, ctxt, event_type, payload): + self._notify(ctxt, event_type, payload, 'CRITICAL') + + +def get_notifier(service=None, host=None, publisher_id=None): + if not publisher_id: + publisher_id = "%s.%s" % (service, host or CONF.host) + return Notifier(publisher_id) diff --git a/cerberus/openstack/common/notifier/rpc_notifier.py b/cerberus/openstack/common/notifier/rpc_notifier.py new file mode 100644 index 0000000..4a713fd --- /dev/null +++ b/cerberus/openstack/common/notifier/rpc_notifier.py @@ -0,0 +1,47 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
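+
+# Notifications are published on per-priority AMQP topics: with the default
+# 'notification_topics' value, a WARN-priority message is sent on the
+# 'notifications.warn' topic.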
+ +from oslo.config import cfg + +from cerberus.openstack.common import context as req_context +from cerberus.openstack.common.gettextutils import _LE +from cerberus.openstack.common import log as logging +from cerberus.openstack.common import rpc + +LOG = logging.getLogger(__name__) + +notification_topic_opt = cfg.ListOpt( + 'notification_topics', default=['notifications', ], + help='AMQP topic used for OpenStack notifications') + +CONF = cfg.CONF +CONF.register_opt(notification_topic_opt) + + +def notify(context, message): + """Sends a notification via RPC.""" + if not context: + context = req_context.get_admin_context() + priority = message.get('priority', + CONF.default_notification_level) + priority = priority.lower() + for topic in CONF.notification_topics: + topic = '%s.%s' % (topic, priority) + try: + rpc.notify(context, topic, message) + except Exception: + LOG.exception(_LE("Could not send notification to %(topic)s. " + "Payload=%(message)s"), + {"topic": topic, "message": message}) diff --git a/cerberus/openstack/common/notifier/rpc_notifier2.py b/cerberus/openstack/common/notifier/rpc_notifier2.py new file mode 100644 index 0000000..5ec99be --- /dev/null +++ b/cerberus/openstack/common/notifier/rpc_notifier2.py @@ -0,0 +1,53 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +'''messaging based notification driver, with message envelopes''' + +from oslo.config import cfg + +from cerberus.openstack.common import context as req_context +from cerberus.openstack.common.gettextutils import _LE +from cerberus.openstack.common import log as logging +from cerberus.openstack.common import rpc + +LOG = logging.getLogger(__name__) + +notification_topic_opt = cfg.ListOpt( + 'topics', default=['notifications', ], + help='AMQP topic(s) used for OpenStack notifications') + +opt_group = cfg.OptGroup(name='rpc_notifier2', + title='Options for rpc_notifier2') + +CONF = cfg.CONF +CONF.register_group(opt_group) +CONF.register_opt(notification_topic_opt, opt_group) + + +def notify(context, message): + """Sends a notification via RPC.""" + if not context: + context = req_context.get_admin_context() + priority = message.get('priority', + CONF.default_notification_level) + priority = priority.lower() + for topic in CONF.rpc_notifier2.topics: + topic = '%s.%s' % (topic, priority) + try: + rpc.notify(context, topic, message, envelope=True) + except Exception: + LOG.exception(_LE("Could not send notification to %(topic)s. " + "Payload=%(message)s"), + {"topic": topic, "message": message}) diff --git a/cerberus/openstack/common/notifier/test_notifier.py b/cerberus/openstack/common/notifier/test_notifier.py new file mode 100644 index 0000000..11fc21f --- /dev/null +++ b/cerberus/openstack/common/notifier/test_notifier.py @@ -0,0 +1,21 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +NOTIFICATIONS = [] + + +def notify(_context, message): + """Test notifier, stores notifications in memory for unittests.""" + NOTIFICATIONS.append(message) diff --git a/cerberus/openstack/common/periodic_task.py b/cerberus/openstack/common/periodic_task.py new file mode 100644 index 0000000..5311a40 --- /dev/null +++ b/cerberus/openstack/common/periodic_task.py @@ -0,0 +1,183 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import time + +from oslo.config import cfg +import six + +from cerberus.openstack.common.gettextutils import _, _LE, _LI +from cerberus.openstack.common import log as logging + + +periodic_opts = [ + cfg.BoolOpt('run_external_periodic_tasks', + default=True, + help=('Some periodic tasks can be run in a separate process. ' + 'Should we run them here?')), +] + +CONF = cfg.CONF +CONF.register_opts(periodic_opts) + +LOG = logging.getLogger(__name__) + +DEFAULT_INTERVAL = 60.0 + + +class InvalidPeriodicTaskArg(Exception): + message = _("Unexpected argument for periodic task creation: %(arg)s.") + + +def periodic_task(*args, **kwargs): + """Decorator to indicate that a method is a periodic task. + + This decorator can be used in two ways: + + 1. Without arguments '@periodic_task', this will be run on every cycle + of the periodic scheduler. + + 2. With arguments: + @periodic_task(spacing=N [, run_immediately=[True|False]]) + this will be run on approximately every N seconds. If this number is + negative the periodic task will be disabled. If the run_immediately + argument is provided and has a value of 'True', the first run of the + task will be shortly after task scheduler starts. If + run_immediately is omitted or set to 'False', the first time the + task runs will be approximately N seconds after the task scheduler + starts. + """ + def decorator(f): + # Test for old style invocation + if 'ticks_between_runs' in kwargs: + raise InvalidPeriodicTaskArg(arg='ticks_between_runs') + + # Control if run at all + f._periodic_task = True + f._periodic_external_ok = kwargs.pop('external_process_ok', False) + if f._periodic_external_ok and not CONF.run_external_periodic_tasks: + f._periodic_enabled = False + else: + f._periodic_enabled = kwargs.pop('enabled', True) + + # Control frequency + f._periodic_spacing = kwargs.pop('spacing', 0) + f._periodic_immediate = kwargs.pop('run_immediately', False) + if f._periodic_immediate: + f._periodic_last_run = None + else: + f._periodic_last_run = time.time() + return f + + # NOTE(sirp): The `if` is necessary to allow the decorator to be used with + # and without parents. 
+ # + # In the 'with-parents' case (with kwargs present), this function needs to + # return a decorator function since the interpreter will invoke it like: + # + # periodic_task(*args, **kwargs)(f) + # + # In the 'without-parents' case, the original function will be passed + # in as the first argument, like: + # + # periodic_task(f) + if kwargs: + return decorator + else: + return decorator(args[0]) + + +class _PeriodicTasksMeta(type): + def __init__(cls, names, bases, dict_): + """Metaclass that allows us to collect decorated periodic tasks.""" + super(_PeriodicTasksMeta, cls).__init__(names, bases, dict_) + + # NOTE(sirp): if the attribute is not present then we must be the base + # class, so, go ahead an initialize it. If the attribute is present, + # then we're a subclass so make a copy of it so we don't step on our + # parent's toes. + try: + cls._periodic_tasks = cls._periodic_tasks[:] + except AttributeError: + cls._periodic_tasks = [] + + try: + cls._periodic_spacing = cls._periodic_spacing.copy() + except AttributeError: + cls._periodic_spacing = {} + + for value in cls.__dict__.values(): + if getattr(value, '_periodic_task', False): + task = value + name = task.__name__ + + if task._periodic_spacing < 0: + LOG.info(_LI('Skipping periodic task %(task)s because ' + 'its interval is negative'), + {'task': name}) + continue + if not task._periodic_enabled: + LOG.info(_LI('Skipping periodic task %(task)s because ' + 'it is disabled'), + {'task': name}) + continue + + # A periodic spacing of zero indicates that this task should + # be run every pass + if task._periodic_spacing == 0: + task._periodic_spacing = None + + cls._periodic_tasks.append((name, task)) + cls._periodic_spacing[name] = task._periodic_spacing + + +@six.add_metaclass(_PeriodicTasksMeta) +class PeriodicTasks(object): + def __init__(self): + super(PeriodicTasks, self).__init__() + self._periodic_last_run = {} + for name, task in self._periodic_tasks: + self._periodic_last_run[name] = task._periodic_last_run + + def run_periodic_tasks(self, context, raise_on_error=False): + """Tasks to be run at a periodic interval.""" + idle_for = DEFAULT_INTERVAL + for task_name, task in self._periodic_tasks: + full_task_name = '.'.join([self.__class__.__name__, task_name]) + + spacing = self._periodic_spacing[task_name] + last_run = self._periodic_last_run[task_name] + + # If a periodic task is _nearly_ due, then we'll run it early + if spacing is not None: + idle_for = min(idle_for, spacing) + if last_run is not None: + delta = last_run + spacing - time.time() + if delta > 0.2: + idle_for = min(idle_for, delta) + continue + + LOG.debug("Running periodic task %(full_task_name)s", + {"full_task_name": full_task_name}) + self._periodic_last_run[task_name] = time.time() + + try: + task(self, context) + except Exception as e: + if raise_on_error: + raise + LOG.exception(_LE("Error during %(full_task_name)s: %(e)s"), + {"full_task_name": full_task_name, "e": e}) + time.sleep(0) + + return idle_for diff --git a/cerberus/openstack/common/policy.py b/cerberus/openstack/common/policy.py new file mode 100644 index 0000000..d6d1747 --- /dev/null +++ b/cerberus/openstack/common/policy.py @@ -0,0 +1,897 @@ +# Copyright (c) 2012 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
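A minimal sketch of how the periodic_task decorator and the PeriodicTasks base class above are meant to be combined; the spacing value and task bodies are invented for illustration.

import time

from cerberus.openstack.common import periodic_task


class Manager(periodic_task.PeriodicTasks):

    @periodic_task.periodic_task
    def every_cycle(self, context):
        pass  # no arguments: runs on every scheduler pass

    @periodic_task.periodic_task(spacing=30, run_immediately=True)
    def every_30s(self, context):
        pass  # first run shortly after startup, then roughly every 30s


manager = Manager()
for _ in range(3):
    # run_periodic_tasks() returns how long the caller may sleep before
    # the next task is due (at most DEFAULT_INTERVAL seconds).
    idle = manager.run_periodic_tasks(context=None)
    time.sleep(idle)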
You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Common Policy Engine Implementation
+
+Policies can be expressed in one of two forms: A list of lists, or a
+string written in the new policy language.
+
+In the list-of-lists representation, each check inside the innermost
+list is combined as with an "and" conjunction--for that check to pass,
+all the specified checks must pass.  These innermost lists are then
+combined as with an "or" conjunction.  This is the original way of
+expressing policies, but there now exists a new way: the policy
+language.
+
+In the policy language, each check is specified the same way as in the
+list-of-lists representation: a simple "a:b" pair that is matched to
+the correct code to perform that check.  However, conjunction
+operators are available, allowing for more expressiveness in crafting
+policies.
+
+As an example, take the following rule, expressed in the list-of-lists
+representation::
+
+    [["role:admin"], ["project_id:%(project_id)s", "role:projectadmin"]]
+
+In the policy language, this becomes::
+
+    role:admin or (project_id:%(project_id)s and role:projectadmin)
+
+The policy language also has the "not" operator, allowing a richer
+policy rule::
+
+    project_id:%(project_id)s and not role:dunce
+
+It is possible to perform policy checks on the following user
+attributes (obtained through the token): user_id, domain_id or
+project_id::
+
+    domain_id:<some_value>
+
+Attributes sent along with API calls can be used by the policy engine
+(on the right side of the expression), by using the following syntax::
+
+    <some_value>:user.id
+
+Contextual attributes of objects identified by their IDs are loaded
+from the database. They are also available to the policy engine and
+can be checked through the `target` keyword::
+
+    <some_value>:target.role.name
+
+All these attributes (related to users, API calls, and context) can be
+checked against each other or against constants, be it literals (True,
+<a_number>) or strings.
+
+Finally, two special policy checks should be mentioned; the policy
+check "@" will always accept an access, and the policy check "!" will
+always reject an access.  (Note that if a rule is either the empty
+list ("[]") or the empty string, this is equivalent to the "@" policy
+check.)  Of these, the "!" policy check is probably the most useful,
+as it allows particular rules to be explicitly disabled.
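The equivalence described in the docstring can be checked directly with parse_rule(), defined later in this module. A sketch (the enforcer argument may be None here because neither rule refers to another named rule):

from cerberus.openstack.common import policy

legacy = [["role:admin"], ["project_id:%(project_id)s", "role:projectadmin"]]
text = "role:admin or (project_id:%(project_id)s and role:projectadmin)"

check_a = policy.parse_rule(legacy)  # list-of-lists form
check_b = policy.parse_rule(text)    # policy-language form

target = {'project_id': 'p1'}
creds = {'project_id': 'p1', 'roles': ['projectadmin']}
print(check_a(target, creds, None))  # True
print(check_b(target, creds, None))  # True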
+""" + +import abc +import ast +import re + +from oslo.config import cfg +import six +import six.moves.urllib.parse as urlparse +import six.moves.urllib.request as urlrequest + +from cerberus.openstack.common import fileutils +from cerberus.openstack.common.gettextutils import _, _LE +from cerberus.openstack.common import jsonutils +from cerberus.openstack.common import log as logging + + +policy_opts = [ + cfg.StrOpt('policy_file', + default='policy.json', + help=_('JSON file containing policy')), + cfg.StrOpt('policy_default_rule', + default='default', + help=_('Rule enforced when requested rule is not found')), +] + +CONF = cfg.CONF +CONF.register_opts(policy_opts) + +LOG = logging.getLogger(__name__) + +_checks = {} + + +class PolicyNotAuthorized(Exception): + + def __init__(self, rule): + msg = _("Policy doesn't allow %s to be performed.") % rule + super(PolicyNotAuthorized, self).__init__(msg) + + +class Rules(dict): + """A store for rules. Handles the default_rule setting directly.""" + + @classmethod + def load_json(cls, data, default_rule=None): + """Allow loading of JSON rule data.""" + + # Suck in the JSON data and parse the rules + rules = dict((k, parse_rule(v)) for k, v in + jsonutils.loads(data).items()) + + return cls(rules, default_rule) + + def __init__(self, rules=None, default_rule=None): + """Initialize the Rules store.""" + + super(Rules, self).__init__(rules or {}) + self.default_rule = default_rule + + def __missing__(self, key): + """Implements the default rule handling.""" + + if isinstance(self.default_rule, dict): + raise KeyError(key) + + # If the default rule isn't actually defined, do something + # reasonably intelligent + if not self.default_rule: + raise KeyError(key) + + if isinstance(self.default_rule, BaseCheck): + return self.default_rule + + # We need to check this or we can get infinite recursion + if self.default_rule not in self: + raise KeyError(key) + + elif isinstance(self.default_rule, six.string_types): + return self[self.default_rule] + + def __str__(self): + """Dumps a string representation of the rules.""" + + # Start by building the canonical strings for the rules + out_rules = {} + for key, value in self.items(): + # Use empty string for singleton TrueCheck instances + if isinstance(value, TrueCheck): + out_rules[key] = '' + else: + out_rules[key] = str(value) + + # Dump a pretty-printed JSON representation + return jsonutils.dumps(out_rules, indent=4) + + +class Enforcer(object): + """Responsible for loading and enforcing rules. + + :param policy_file: Custom policy file to use, if none is + specified, `CONF.policy_file` will be + used. + :param rules: Default dictionary / Rules to use. It will be + considered just in the first instantiation. If + `load_rules(True)`, `clear()` or `set_rules(True)` + is called this will be overwritten. + :param default_rule: Default rule to use, CONF.default_rule will + be used if none is specified. + :param use_conf: Whether to load rules from cache or config file. + """ + + def __init__(self, policy_file=None, rules=None, + default_rule=None, use_conf=True): + self.rules = Rules(rules, default_rule) + self.default_rule = default_rule or CONF.policy_default_rule + + self.policy_path = None + self.policy_file = policy_file or CONF.policy_file + self.use_conf = use_conf + + def set_rules(self, rules, overwrite=True, use_conf=False): + """Create a new Rules object based on the provided dict of rules. + + :param rules: New rules to use. It should be an instance of dict. 
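A sketch of the __missing__ fallback just defined: unknown rule names resolve to the named default rule when one is set. The rule names below are invented.

from cerberus.openstack.common import policy

rules = policy.Rules.load_json(
    '{"default": "!", "admin_only": "role:admin"}', default_rule='default')

print(rules['admin_only'])    # role:admin
print(rules['no_such_rule'])  # ! (falls back to the "default" rule)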
+ :param overwrite: Whether to overwrite current rules or update them + with the new rules. + :param use_conf: Whether to reload rules from cache or config file. + """ + + if not isinstance(rules, dict): + raise TypeError(_("Rules must be an instance of dict or Rules, " + "got %s instead") % type(rules)) + self.use_conf = use_conf + if overwrite: + self.rules = Rules(rules, self.default_rule) + else: + self.rules.update(rules) + + def clear(self): + """Clears Enforcer rules, policy's cache and policy's path.""" + self.set_rules({}) + self.default_rule = None + self.policy_path = None + + def load_rules(self, force_reload=False): + """Loads policy_path's rules. + + Policy file is cached and will be reloaded if modified. + + :param force_reload: Whether to overwrite current rules. + """ + + if force_reload: + self.use_conf = force_reload + + if self.use_conf: + if not self.policy_path: + self.policy_path = self._get_policy_path() + + reloaded, data = fileutils.read_cached_file( + self.policy_path, force_reload=force_reload) + if reloaded or not self.rules: + rules = Rules.load_json(data, self.default_rule) + self.set_rules(rules) + LOG.debug("Rules successfully reloaded") + + def _get_policy_path(self): + """Locate the policy json data file. + + :param policy_file: Custom policy file to locate. + + :returns: The policy path + + :raises: ConfigFilesNotFoundError if the file couldn't + be located. + """ + policy_file = CONF.find_file(self.policy_file) + + if policy_file: + return policy_file + + raise cfg.ConfigFilesNotFoundError((self.policy_file,)) + + def enforce(self, rule, target, creds, do_raise=False, + exc=None, *args, **kwargs): + """Checks authorization of a rule against the target and credentials. + + :param rule: A string or BaseCheck instance specifying the rule + to evaluate. + :param target: As much information about the object being operated + on as possible, as a dictionary. + :param creds: As much information about the user performing the + action as possible, as a dictionary. + :param do_raise: Whether to raise an exception or not if check + fails. + :param exc: Class of the exception to raise if the check fails. + Any remaining arguments passed to check() (both + positional and keyword arguments) will be passed to + the exception class. If not specified, PolicyNotAuthorized + will be used. + + :return: Returns False if the policy does not allow the action and + exc is not provided; otherwise, returns a value that + evaluates to True. Note: for rules using the "case" + expression, this True value will be the specified string + from the expression. + """ + + # NOTE(flaper87): Not logging target or creds to avoid + # potential security issues. 
+ LOG.debug("Rule %s will be now enforced" % rule) + + self.load_rules() + + # Allow the rule to be a Check tree + if isinstance(rule, BaseCheck): + result = rule(target, creds, self) + elif not self.rules: + # No rules to reference means we're going to fail closed + result = False + else: + try: + # Evaluate the rule + result = self.rules[rule](target, creds, self) + except KeyError: + LOG.debug("Rule [%s] doesn't exist" % rule) + # If the rule doesn't exist, fail closed + result = False + + # If it is False, raise the exception if requested + if do_raise and not result: + if exc: + raise exc(*args, **kwargs) + + raise PolicyNotAuthorized(rule) + + return result + + +@six.add_metaclass(abc.ABCMeta) +class BaseCheck(object): + """Abstract base class for Check classes.""" + + @abc.abstractmethod + def __str__(self): + """String representation of the Check tree rooted at this node.""" + + pass + + @abc.abstractmethod + def __call__(self, target, cred, enforcer): + """Triggers if instance of the class is called. + + Performs the check. Returns False to reject the access or a + true value (not necessary True) to accept the access. + """ + + pass + + +class FalseCheck(BaseCheck): + """A policy check that always returns False (disallow).""" + + def __str__(self): + """Return a string representation of this check.""" + + return "!" + + def __call__(self, target, cred, enforcer): + """Check the policy.""" + + return False + + +class TrueCheck(BaseCheck): + """A policy check that always returns True (allow).""" + + def __str__(self): + """Return a string representation of this check.""" + + return "@" + + def __call__(self, target, cred, enforcer): + """Check the policy.""" + + return True + + +class Check(BaseCheck): + """A base class to allow for user-defined policy checks.""" + + def __init__(self, kind, match): + """Initiates Check instance. + + :param kind: The kind of the check, i.e., the field before the + ':'. + :param match: The match of the check, i.e., the field after + the ':'. + """ + + self.kind = kind + self.match = match + + def __str__(self): + """Return a string representation of this check.""" + + return "%s:%s" % (self.kind, self.match) + + +class NotCheck(BaseCheck): + """Implements the "not" logical operator. + + A policy check that inverts the result of another policy check. + """ + + def __init__(self, rule): + """Initialize the 'not' check. + + :param rule: The rule to negate. Must be a Check. + """ + + self.rule = rule + + def __str__(self): + """Return a string representation of this check.""" + + return "not %s" % self.rule + + def __call__(self, target, cred, enforcer): + """Check the policy. + + Returns the logical inverse of the wrapped check. + """ + + return not self.rule(target, cred, enforcer) + + +class AndCheck(BaseCheck): + """Implements the "and" logical operator. + + A policy check that requires that a list of other checks all return True. + """ + + def __init__(self, rules): + """Initialize the 'and' check. + + :param rules: A list of rules that will be tested. + """ + + self.rules = rules + + def __str__(self): + """Return a string representation of this check.""" + + return "(%s)" % ' and '.join(str(r) for r in self.rules) + + def __call__(self, target, cred, enforcer): + """Check the policy. + + Requires that all rules accept in order to return True. + """ + + for rule in self.rules: + if not rule(target, cred, enforcer): + return False + + return True + + def add_check(self, rule): + """Adds rule to be tested. 
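Putting enforce() together with set_rules(): a minimal in-memory sketch that sidesteps the policy.json lookup entirely. The rule name and roles are invented.

from cerberus.openstack.common import policy

enforcer = policy.Enforcer(use_conf=False)
enforcer.set_rules(policy.Rules.load_json(
    '{"compute:start": "role:admin or role:operator"}'))

print(enforcer.enforce("compute:start", {}, {'roles': ['operator']}))  # True

# do_raise=True turns a failed check into an exception:
try:
    enforcer.enforce("compute:start", {}, {'roles': []}, do_raise=True)
except policy.PolicyNotAuthorized as exc:
    print(exc)  # Policy doesn't allow compute:start to be performed.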
+ + Allows addition of another rule to the list of rules that will + be tested. Returns the AndCheck object for convenience. + """ + + self.rules.append(rule) + return self + + +class OrCheck(BaseCheck): + """Implements the "or" operator. + + A policy check that requires that at least one of a list of other + checks returns True. + """ + + def __init__(self, rules): + """Initialize the 'or' check. + + :param rules: A list of rules that will be tested. + """ + + self.rules = rules + + def __str__(self): + """Return a string representation of this check.""" + + return "(%s)" % ' or '.join(str(r) for r in self.rules) + + def __call__(self, target, cred, enforcer): + """Check the policy. + + Requires that at least one rule accept in order to return True. + """ + + for rule in self.rules: + if rule(target, cred, enforcer): + return True + return False + + def add_check(self, rule): + """Adds rule to be tested. + + Allows addition of another rule to the list of rules that will + be tested. Returns the OrCheck object for convenience. + """ + + self.rules.append(rule) + return self + + +def _parse_check(rule): + """Parse a single base check rule into an appropriate Check object.""" + + # Handle the special checks + if rule == '!': + return FalseCheck() + elif rule == '@': + return TrueCheck() + + try: + kind, match = rule.split(':', 1) + except Exception: + LOG.exception(_LE("Failed to understand rule %s") % rule) + # If the rule is invalid, we'll fail closed + return FalseCheck() + + # Find what implements the check + if kind in _checks: + return _checks[kind](kind, match) + elif None in _checks: + return _checks[None](kind, match) + else: + LOG.error(_LE("No handler for matches of kind %s") % kind) + return FalseCheck() + + +def _parse_list_rule(rule): + """Translates the old list-of-lists syntax into a tree of Check objects. + + Provided for backwards compatibility. + """ + + # Empty rule defaults to True + if not rule: + return TrueCheck() + + # Outer list is joined by "or"; inner list by "and" + or_list = [] + for inner_rule in rule: + # Elide empty inner lists + if not inner_rule: + continue + + # Handle bare strings + if isinstance(inner_rule, six.string_types): + inner_rule = [inner_rule] + + # Parse the inner rules into Check objects + and_list = [_parse_check(r) for r in inner_rule] + + # Append the appropriate check to the or_list + if len(and_list) == 1: + or_list.append(and_list[0]) + else: + or_list.append(AndCheck(and_list)) + + # If we have only one check, omit the "or" + if not or_list: + return FalseCheck() + elif len(or_list) == 1: + return or_list[0] + + return OrCheck(or_list) + + +# Used for tokenizing the policy language +_tokenize_re = re.compile(r'\s+') + + +def _parse_tokenize(rule): + """Tokenizer for the policy language. + + Most of the single-character tokens are specified in the + _tokenize_re; however, parentheses need to be handled specially, + because they can appear inside a check string. Thankfully, those + parentheses that appear inside a check string can never occur at + the very beginning or end ("%(variable)s" is the correct syntax). 
+ """ + + for tok in _tokenize_re.split(rule): + # Skip empty tokens + if not tok or tok.isspace(): + continue + + # Handle leading parens on the token + clean = tok.lstrip('(') + for i in range(len(tok) - len(clean)): + yield '(', '(' + + # If it was only parentheses, continue + if not clean: + continue + else: + tok = clean + + # Handle trailing parens on the token + clean = tok.rstrip(')') + trail = len(tok) - len(clean) + + # Yield the cleaned token + lowered = clean.lower() + if lowered in ('and', 'or', 'not'): + # Special tokens + yield lowered, clean + elif clean: + # Not a special token, but not composed solely of ')' + if len(tok) >= 2 and ((tok[0], tok[-1]) in + [('"', '"'), ("'", "'")]): + # It's a quoted string + yield 'string', tok[1:-1] + else: + yield 'check', _parse_check(clean) + + # Yield the trailing parens + for i in range(trail): + yield ')', ')' + + +class ParseStateMeta(type): + """Metaclass for the ParseState class. + + Facilitates identifying reduction methods. + """ + + def __new__(mcs, name, bases, cls_dict): + """Create the class. + + Injects the 'reducers' list, a list of tuples matching token sequences + to the names of the corresponding reduction methods. + """ + + reducers = [] + + for key, value in cls_dict.items(): + if not hasattr(value, 'reducers'): + continue + for reduction in value.reducers: + reducers.append((reduction, key)) + + cls_dict['reducers'] = reducers + + return super(ParseStateMeta, mcs).__new__(mcs, name, bases, cls_dict) + + +def reducer(*tokens): + """Decorator for reduction methods. + + Arguments are a sequence of tokens, in order, which should trigger running + this reduction method. + """ + + def decorator(func): + # Make sure we have a list of reducer sequences + if not hasattr(func, 'reducers'): + func.reducers = [] + + # Add the tokens to the list of reducer sequences + func.reducers.append(list(tokens)) + + return func + + return decorator + + +@six.add_metaclass(ParseStateMeta) +class ParseState(object): + """Implement the core of parsing the policy language. + + Uses a greedy reduction algorithm to reduce a sequence of tokens into + a single terminal, the value of which will be the root of the Check tree. + + Note: error reporting is rather lacking. The best we can get with + this parser formulation is an overall "parse failed" error. + Fortunately, the policy language is simple enough that this + shouldn't be that big a problem. + """ + + def __init__(self): + """Initialize the ParseState.""" + + self.tokens = [] + self.values = [] + + def reduce(self): + """Perform a greedy reduction of the token stream. + + If a reducer method matches, it will be executed, then the + reduce() method will be called recursively to search for any more + possible reductions. + """ + + for reduction, methname in self.reducers: + if (len(self.tokens) >= len(reduction) and + self.tokens[-len(reduction):] == reduction): + # Get the reduction method + meth = getattr(self, methname) + + # Reduce the token stream + results = meth(*self.values[-len(reduction):]) + + # Update the tokens and values + self.tokens[-len(reduction):] = [r[0] for r in results] + self.values[-len(reduction):] = [r[1] for r in results] + + # Check for any more reductions + return self.reduce() + + def shift(self, tok, value): + """Adds one more token to the state. Calls reduce().""" + + self.tokens.append(tok) + self.values.append(value) + + # Do a greedy reduce... + self.reduce() + + @property + def result(self): + """Obtain the final result of the parse. 
+ + Raises ValueError if the parse failed to reduce to a single result. + """ + + if len(self.values) != 1: + raise ValueError("Could not parse rule") + return self.values[0] + + @reducer('(', 'check', ')') + @reducer('(', 'and_expr', ')') + @reducer('(', 'or_expr', ')') + def _wrap_check(self, _p1, check, _p2): + """Turn parenthesized expressions into a 'check' token.""" + + return [('check', check)] + + @reducer('check', 'and', 'check') + def _make_and_expr(self, check1, _and, check2): + """Create an 'and_expr'. + + Join two checks by the 'and' operator. + """ + + return [('and_expr', AndCheck([check1, check2]))] + + @reducer('and_expr', 'and', 'check') + def _extend_and_expr(self, and_expr, _and, check): + """Extend an 'and_expr' by adding one more check.""" + + return [('and_expr', and_expr.add_check(check))] + + @reducer('check', 'or', 'check') + def _make_or_expr(self, check1, _or, check2): + """Create an 'or_expr'. + + Join two checks by the 'or' operator. + """ + + return [('or_expr', OrCheck([check1, check2]))] + + @reducer('or_expr', 'or', 'check') + def _extend_or_expr(self, or_expr, _or, check): + """Extend an 'or_expr' by adding one more check.""" + + return [('or_expr', or_expr.add_check(check))] + + @reducer('not', 'check') + def _make_not_expr(self, _not, check): + """Invert the result of another check.""" + + return [('check', NotCheck(check))] + + +def _parse_text_rule(rule): + """Parses policy to the tree. + + Translates a policy written in the policy language into a tree of + Check objects. + """ + + # Empty rule means always accept + if not rule: + return TrueCheck() + + # Parse the token stream + state = ParseState() + for tok, value in _parse_tokenize(rule): + state.shift(tok, value) + + try: + return state.result + except ValueError: + # Couldn't parse the rule + LOG.exception(_LE("Failed to understand rule %r") % rule) + + # Fail closed + return FalseCheck() + + +def parse_rule(rule): + """Parses a policy rule into a tree of Check objects.""" + + # If the rule is a string, it's in the policy language + if isinstance(rule, six.string_types): + return _parse_text_rule(rule) + return _parse_list_rule(rule) + + +def register(name, func=None): + """Register a function or Check class as a policy check. + + :param name: Gives the name of the check type, e.g., 'rule', + 'role', etc. If name is None, a default check type + will be registered. + :param func: If given, provides the function or class to register. + If not given, returns a function taking one argument + to specify the function or class to register, + allowing use as a decorator. + """ + + # Perform the actual decoration by registering the function or + # class. Returns the function or class for compliance with the + # decorator interface. 
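To see the token stream that the ParseState machinery consumes, the private tokenizer can be driven by hand. A sketch (note how the trailing ')' is peeled off the last check, as the tokenizer docstring explains):

from cerberus.openstack.common import policy

rule = "role:admin or (project_id:%(project_id)s and not role:dunce)"
print([(tok, str(value)) for tok, value in policy._parse_tokenize(rule)])
# [('check', 'role:admin'), ('or', 'or'), ('(', '('),
#  ('check', 'project_id:%(project_id)s'), ('and', 'and'),
#  ('not', 'not'), ('check', 'role:dunce'), (')', ')')]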
+ def decorator(func): + _checks[name] = func + return func + + # If the function or class is given, do the registration + if func: + return decorator(func) + + return decorator + + +@register("rule") +class RuleCheck(Check): + def __call__(self, target, creds, enforcer): + """Recursively checks credentials based on the defined rules.""" + + try: + return enforcer.rules[self.match](target, creds, enforcer) + except KeyError: + # We don't have any matching rule; fail closed + return False + + +@register("role") +class RoleCheck(Check): + def __call__(self, target, creds, enforcer): + """Check that there is a matching role in the cred dict.""" + + return self.match.lower() in [x.lower() for x in creds['roles']] + + +@register('http') +class HttpCheck(Check): + def __call__(self, target, creds, enforcer): + """Check http: rules by calling to a remote server. + + This example implementation simply verifies that the response + is exactly 'True'. + """ + + url = ('http:' + self.match) % target + data = {'target': jsonutils.dumps(target), + 'credentials': jsonutils.dumps(creds)} + post_data = urlparse.urlencode(data) + f = urlrequest.urlopen(url, post_data) + return f.read() == "True" + + +@register(None) +class GenericCheck(Check): + def __call__(self, target, creds, enforcer): + """Check an individual match. + + Matches look like: + + tenant:%(tenant_id)s + role:compute:admin + True:%(user.enabled)s + 'Member':%(role.name)s + """ + + # TODO(termie): do dict inspection via dot syntax + try: + match = self.match % target + except KeyError: + # While doing GenericCheck if key not + # present in Target return false + return False + + try: + # Try to interpret self.kind as a literal + leftval = ast.literal_eval(self.kind) + except ValueError: + try: + leftval = creds[self.kind] + except KeyError: + return False + return match == six.text_type(leftval) diff --git a/cerberus/openstack/common/processutils.py b/cerberus/openstack/common/processutils.py new file mode 100644 index 0000000..de617bc --- /dev/null +++ b/cerberus/openstack/common/processutils.py @@ -0,0 +1,272 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +System-level utilities and helper functions. 
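The register() hook in policy.py above is the extension point for site-specific checks. A sketch of a custom check; the 'is_owner' kind and the field names are invented for illustration.

from cerberus.openstack.common import policy


@policy.register('is_owner')
class IsOwnerCheck(policy.Check):
    def __call__(self, target, creds, enforcer):
        # self.match is the text after ':' in the rule, used here as a
        # key into the target dict.
        return target.get(self.match) == creds.get('user_id')


check = policy.parse_rule("is_owner:owner_id")
print(check({'owner_id': 'u1'}, {'user_id': 'u1'}, None))  # True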
+""" + +import errno +import logging +import os +import random +import shlex +import signal + +from eventlet.green import subprocess +from eventlet import greenthread +import six + +from cerberus.openstack.common.gettextutils import _ # noqa +from cerberus.openstack.common import strutils + + +LOG = logging.getLogger(__name__) + + +class InvalidArgumentError(Exception): + def __init__(self, message=None): + super(InvalidArgumentError, self).__init__(message) + + +class UnknownArgumentError(Exception): + def __init__(self, message=None): + super(UnknownArgumentError, self).__init__(message) + + +class ProcessExecutionError(Exception): + def __init__(self, stdout=None, stderr=None, exit_code=None, cmd=None, + description=None): + self.exit_code = exit_code + self.stderr = stderr + self.stdout = stdout + self.cmd = cmd + self.description = description + + if description is None: + description = _("Unexpected error while running command.") + if exit_code is None: + exit_code = '-' + message = _('%(description)s\n' + 'Command: %(cmd)s\n' + 'Exit code: %(exit_code)s\n' + 'Stdout: %(stdout)r\n' + 'Stderr: %(stderr)r') % {'description': description, + 'cmd': cmd, + 'exit_code': exit_code, + 'stdout': stdout, + 'stderr': stderr} + super(ProcessExecutionError, self).__init__(message) + + +class NoRootWrapSpecified(Exception): + def __init__(self, message=None): + super(NoRootWrapSpecified, self).__init__(message) + + +def _subprocess_setup(): + # Python installs a SIGPIPE handler by default. This is usually not what + # non-Python subprocesses expect. + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + +def execute(*cmd, **kwargs): + """Helper method to shell out and execute a command through subprocess. + + Allows optional retry. + + :param cmd: Passed to subprocess.Popen. + :type cmd: string + :param process_input: Send to opened process. + :type process_input: string + :param check_exit_code: Single bool, int, or list of allowed exit + codes. Defaults to [0]. Raise + :class:`ProcessExecutionError` unless + program exits with one of these code. + :type check_exit_code: boolean, int, or [int] + :param delay_on_retry: True | False. Defaults to True. If set to True, + wait a short amount of time before retrying. + :type delay_on_retry: boolean + :param attempts: How many times to retry cmd. + :type attempts: int + :param run_as_root: True | False. Defaults to False. If set to True, + the command is prefixed by the command specified + in the root_helper kwarg. + :type run_as_root: boolean + :param root_helper: command to prefix to commands called with + run_as_root=True + :type root_helper: string + :param shell: whether or not there should be a shell used to + execute this command. Defaults to false. + :type shell: boolean + :param loglevel: log level for execute commands. + :type loglevel: int. 
(Should be logging.DEBUG or logging.INFO) + :returns: (stdout, stderr) from process execution + :raises: :class:`UnknownArgumentError` on + receiving unknown arguments + :raises: :class:`ProcessExecutionError` + """ + + process_input = kwargs.pop('process_input', None) + check_exit_code = kwargs.pop('check_exit_code', [0]) + ignore_exit_code = False + delay_on_retry = kwargs.pop('delay_on_retry', True) + attempts = kwargs.pop('attempts', 1) + run_as_root = kwargs.pop('run_as_root', False) + root_helper = kwargs.pop('root_helper', '') + shell = kwargs.pop('shell', False) + loglevel = kwargs.pop('loglevel', logging.DEBUG) + + if isinstance(check_exit_code, bool): + ignore_exit_code = not check_exit_code + check_exit_code = [0] + elif isinstance(check_exit_code, int): + check_exit_code = [check_exit_code] + + if kwargs: + raise UnknownArgumentError(_('Got unknown keyword args ' + 'to utils.execute: %r') % kwargs) + + if run_as_root and hasattr(os, 'geteuid') and os.geteuid() != 0: + if not root_helper: + raise NoRootWrapSpecified( + message=_('Command requested root, but did not ' + 'specify a root helper.')) + cmd = shlex.split(root_helper) + list(cmd) + + cmd = map(str, cmd) + sanitized_cmd = strutils.mask_password(' '.join(cmd)) + + while attempts > 0: + attempts -= 1 + try: + LOG.log(loglevel, _('Running cmd (subprocess): %s'), sanitized_cmd) + _PIPE = subprocess.PIPE # pylint: disable=E1101 + + if os.name == 'nt': + preexec_fn = None + close_fds = False + else: + preexec_fn = _subprocess_setup + close_fds = True + + obj = subprocess.Popen(cmd, + stdin=_PIPE, + stdout=_PIPE, + stderr=_PIPE, + close_fds=close_fds, + preexec_fn=preexec_fn, + shell=shell) + result = None + for _i in six.moves.range(20): + # NOTE(russellb) 20 is an arbitrary number of retries to + # prevent any chance of looping forever here. + try: + if process_input is not None: + result = obj.communicate(process_input) + else: + result = obj.communicate() + except OSError as e: + if e.errno in (errno.EAGAIN, errno.EINTR): + continue + raise + break + obj.stdin.close() # pylint: disable=E1101 + _returncode = obj.returncode # pylint: disable=E1101 + LOG.log(loglevel, _('Result was %s') % _returncode) + if not ignore_exit_code and _returncode not in check_exit_code: + (stdout, stderr) = result + sanitized_stdout = strutils.mask_password(stdout) + sanitized_stderr = strutils.mask_password(stderr) + raise ProcessExecutionError(exit_code=_returncode, + stdout=sanitized_stdout, + stderr=sanitized_stderr, + cmd=sanitized_cmd) + return result + except ProcessExecutionError: + if not attempts: + raise + else: + LOG.log(loglevel, _('%r failed. Retrying.'), sanitized_cmd) + if delay_on_retry: + greenthread.sleep(random.randint(20, 200) / 100.0) + finally: + # NOTE(termie): this appears to be necessary to let the subprocess + # call clean something up in between calls, without + # it two execute calls in a row hangs the second one + greenthread.sleep(0) + + +def trycmd(*args, **kwargs): + """A wrapper around execute() to more easily handle warnings and errors. + + Returns an (out, err) tuple of strings containing the output of + the command's stdout and stderr. If 'err' is not empty then the + command can be considered to have failed. + + :discard_warnings True | False. Defaults to False. 
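Typical calls into execute() given the options documented above; a sketch that assumes a POSIX host with eventlet installed.

from cerberus.openstack.common import processutils

# stdout/stderr come back as a tuple once the exit code is acceptable.
out, err = processutils.execute('echo', 'hello')

# grep exits 1 on "no match"; accept that, and retry real failures once.
out, err = processutils.execute('grep', 'needle', '/etc/hostname',
                                check_exit_code=[0, 1], attempts=2)

# run_as_root prefixes the command with root_helper:
# processutils.execute('ls', '/root', run_as_root=True, root_helper='sudo')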
If set to True, + then for succeeding commands, stderr is cleared + + """ + discard_warnings = kwargs.pop('discard_warnings', False) + + try: + out, err = execute(*args, **kwargs) + failed = False + except ProcessExecutionError as exn: + out, err = '', str(exn) + failed = True + + if not failed and discard_warnings and err: + # Handle commands that output to stderr but otherwise succeed + err = '' + + return out, err + + +def ssh_execute(ssh, cmd, process_input=None, + addl_env=None, check_exit_code=True): + sanitized_cmd = strutils.mask_password(cmd) + LOG.debug('Running cmd (SSH): %s', sanitized_cmd) + if addl_env: + raise InvalidArgumentError(_('Environment not supported over SSH')) + + if process_input: + # This is (probably) fixable if we need it... + raise InvalidArgumentError(_('process_input not supported over SSH')) + + stdin_stream, stdout_stream, stderr_stream = ssh.exec_command(cmd) + channel = stdout_stream.channel + + # NOTE(justinsb): This seems suspicious... + # ...other SSH clients have buffering issues with this approach + stdout = stdout_stream.read() + sanitized_stdout = strutils.mask_password(stdout) + stderr = stderr_stream.read() + sanitized_stderr = strutils.mask_password(stderr) + + stdin_stream.close() + + exit_status = channel.recv_exit_status() + + # exit_status == -1 if no exit code was returned + if exit_status != -1: + LOG.debug('Result was %s' % exit_status) + if check_exit_code and exit_status != 0: + raise ProcessExecutionError(exit_code=exit_status, + stdout=sanitized_stdout, + stderr=sanitized_stderr, + cmd=sanitized_cmd) + + return (sanitized_stdout, sanitized_stderr) diff --git a/cerberus/openstack/common/service.py b/cerberus/openstack/common/service.py new file mode 100644 index 0000000..6b5aefc --- /dev/null +++ b/cerberus/openstack/common/service.py @@ -0,0 +1,504 @@ +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2011 Justin Santa Barbara +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Generic Node base class for all workers that run on hosts.""" + +import errno +import logging as std_logging +import os +import random +import signal +import sys +import time + +try: + # Importing just the symbol here because the io module does not + # exist in Python 2.6. 
+ from io import UnsupportedOperation # noqa +except ImportError: + # Python 2.6 + UnsupportedOperation = None + +import eventlet +from eventlet import event +from oslo.config import cfg + +from cerberus.openstack.common import eventlet_backdoor +from cerberus.openstack.common.gettextutils import _LE, _LI, _LW +from cerberus.openstack.common import importutils +from cerberus.openstack.common import log as logging +from cerberus.openstack.common import systemd +from cerberus.openstack.common import threadgroup + + +rpc = importutils.try_import('cerberus.openstack.common.rpc') +CONF = cfg.CONF +LOG = logging.getLogger(__name__) + + +def _sighup_supported(): + return hasattr(signal, 'SIGHUP') + + +def _is_daemon(): + # The process group for a foreground process will match the + # process group of the controlling terminal. If those values do + # not match, or ioctl() fails on the stdout file handle, we assume + # the process is running in the background as a daemon. + # http://www.gnu.org/software/bash/manual/bashref.html#Job-Control-Basics + try: + is_daemon = os.getpgrp() != os.tcgetpgrp(sys.stdout.fileno()) + except OSError as err: + if err.errno == errno.ENOTTY: + # Assume we are a daemon because there is no terminal. + is_daemon = True + else: + raise + except UnsupportedOperation: + # Could not get the fileno for stdout, so we must be a daemon. + is_daemon = True + return is_daemon + + +def _is_sighup_and_daemon(signo): + if not (_sighup_supported() and signo == signal.SIGHUP): + # Avoid checking if we are a daemon, because the signal isn't + # SIGHUP. + return False + return _is_daemon() + + +def _signo_to_signame(signo): + signals = {signal.SIGTERM: 'SIGTERM', + signal.SIGINT: 'SIGINT'} + if _sighup_supported(): + signals[signal.SIGHUP] = 'SIGHUP' + return signals[signo] + + +def _set_signals_handler(handler): + signal.signal(signal.SIGTERM, handler) + signal.signal(signal.SIGINT, handler) + if _sighup_supported(): + signal.signal(signal.SIGHUP, handler) + + +class Launcher(object): + """Launch one or more services and wait for them to complete.""" + + def __init__(self): + """Initialize the service launcher. + + :returns: None + + """ + self.services = Services() + self.backdoor_port = eventlet_backdoor.initialize_if_enabled() + + def launch_service(self, service): + """Load and start the given service. + + :param service: The service you would like to start. + :returns: None + + """ + service.backdoor_port = self.backdoor_port + self.services.add(service) + + def stop(self): + """Stop all services which are currently running. + + :returns: None + + """ + self.services.stop() + + def wait(self): + """Waits until all services have been stopped, and then returns. + + :returns: None + + """ + self.services.wait() + + def restart(self): + """Reload config files and restart service. 
+ + :returns: None + + """ + cfg.CONF.reload_config_files() + self.services.restart() + + +class SignalExit(SystemExit): + def __init__(self, signo, exccode=1): + super(SignalExit, self).__init__(exccode) + self.signo = signo + + +class ServiceLauncher(Launcher): + def _handle_signal(self, signo, frame): + # Allow the process to be killed again and die from natural causes + _set_signals_handler(signal.SIG_DFL) + raise SignalExit(signo) + + def handle_signal(self): + _set_signals_handler(self._handle_signal) + + def _wait_for_exit_or_signal(self, ready_callback=None): + status = None + signo = 0 + + LOG.debug('Full set of CONF:') + CONF.log_opt_values(LOG, std_logging.DEBUG) + + try: + if ready_callback: + ready_callback() + super(ServiceLauncher, self).wait() + except SignalExit as exc: + signame = _signo_to_signame(exc.signo) + LOG.info(_LI('Caught %s, exiting'), signame) + status = exc.code + signo = exc.signo + except SystemExit as exc: + status = exc.code + finally: + self.stop() + if rpc: + try: + rpc.cleanup() + except Exception: + # We're shutting down, so it doesn't matter at this point. + LOG.exception(_LE('Exception during rpc cleanup.')) + + return status, signo + + def wait(self, ready_callback=None): + systemd.notify_once() + while True: + self.handle_signal() + status, signo = self._wait_for_exit_or_signal(ready_callback) + if not _is_sighup_and_daemon(signo): + return status + self.restart() + + +class ServiceWrapper(object): + def __init__(self, service, workers): + self.service = service + self.workers = workers + self.children = set() + self.forktimes = [] + + +class ProcessLauncher(object): + def __init__(self, wait_interval=0.01): + """Constructor. + + :param wait_interval: The interval to sleep for between checks + of child process exit. + """ + self.children = {} + self.sigcaught = None + self.running = True + self.wait_interval = wait_interval + rfd, self.writepipe = os.pipe() + self.readpipe = eventlet.greenio.GreenPipe(rfd, 'r') + self.handle_signal() + + def handle_signal(self): + _set_signals_handler(self._handle_signal) + + def _handle_signal(self, signo, frame): + self.sigcaught = signo + self.running = False + + # Allow the process to be killed again and die from natural causes + _set_signals_handler(signal.SIG_DFL) + + def _pipe_watcher(self): + # This will block until the write end is closed when the parent + # dies unexpectedly + self.readpipe.read() + + LOG.info(_LI('Parent process has died unexpectedly, exiting')) + + sys.exit(1) + + def _child_process_handle_signal(self): + # Setup child signal handlers differently + def _sigterm(*args): + signal.signal(signal.SIGTERM, signal.SIG_DFL) + raise SignalExit(signal.SIGTERM) + + def _sighup(*args): + signal.signal(signal.SIGHUP, signal.SIG_DFL) + raise SignalExit(signal.SIGHUP) + + signal.signal(signal.SIGTERM, _sigterm) + if _sighup_supported(): + signal.signal(signal.SIGHUP, _sighup) + # Block SIGINT and let the parent send us a SIGTERM + signal.signal(signal.SIGINT, signal.SIG_IGN) + + def _child_wait_for_exit_or_signal(self, launcher): + status = 0 + signo = 0 + + # NOTE(johannes): All exceptions are caught to ensure this + # doesn't fallback into the loop spawning children. It would + # be bad for a child to spawn more children. 
+ try: + launcher.wait() + except SignalExit as exc: + signame = _signo_to_signame(exc.signo) + LOG.info(_LI('Caught %s, exiting'), signame) + status = exc.code + signo = exc.signo + except SystemExit as exc: + status = exc.code + except BaseException: + LOG.exception(_LE('Unhandled exception')) + status = 2 + finally: + launcher.stop() + + return status, signo + + def _child_process(self, service): + self._child_process_handle_signal() + + # Reopen the eventlet hub to make sure we don't share an epoll + # fd with parent and/or siblings, which would be bad + eventlet.hubs.use_hub() + + # Close write to ensure only parent has it open + os.close(self.writepipe) + # Create greenthread to watch for parent to close pipe + eventlet.spawn_n(self._pipe_watcher) + + # Reseed random number generator + random.seed() + + launcher = Launcher() + launcher.launch_service(service) + return launcher + + def _start_child(self, wrap): + if len(wrap.forktimes) > wrap.workers: + # Limit ourselves to one process a second (over the period of + # number of workers * 1 second). This will allow workers to + # start up quickly but ensure we don't fork off children that + # die instantly too quickly. + if time.time() - wrap.forktimes[0] < wrap.workers: + LOG.info(_LI('Forking too fast, sleeping')) + time.sleep(1) + + wrap.forktimes.pop(0) + + wrap.forktimes.append(time.time()) + + pid = os.fork() + if pid == 0: + launcher = self._child_process(wrap.service) + while True: + self._child_process_handle_signal() + status, signo = self._child_wait_for_exit_or_signal(launcher) + if not _is_sighup_and_daemon(signo): + break + launcher.restart() + + os._exit(status) + + LOG.info(_LI('Started child %d'), pid) + + wrap.children.add(pid) + self.children[pid] = wrap + + return pid + + def launch_service(self, service, workers=1): + wrap = ServiceWrapper(service, workers) + + LOG.info(_LI('Starting %d workers'), wrap.workers) + while self.running and len(wrap.children) < wrap.workers: + self._start_child(wrap) + + def _wait_child(self): + try: + # Don't block if no child processes have exited + pid, status = os.waitpid(0, os.WNOHANG) + if not pid: + return None + except OSError as exc: + if exc.errno not in (errno.EINTR, errno.ECHILD): + raise + return None + + if os.WIFSIGNALED(status): + sig = os.WTERMSIG(status) + LOG.info(_LI('Child %(pid)d killed by signal %(sig)d'), + dict(pid=pid, sig=sig)) + else: + code = os.WEXITSTATUS(status) + LOG.info(_LI('Child %(pid)s exited with status %(code)d'), + dict(pid=pid, code=code)) + + if pid not in self.children: + LOG.warning(_LW('pid %d not in child list'), pid) + return None + + wrap = self.children.pop(pid) + wrap.children.remove(pid) + return wrap + + def _respawn_children(self): + while self.running: + wrap = self._wait_child() + if not wrap: + # Yield to other threads if no children have exited + # Sleep for a short time to avoid excessive CPU usage + # (see bug #1095346) + eventlet.greenthread.sleep(self.wait_interval) + continue + while self.running and len(wrap.children) < wrap.workers: + self._start_child(wrap) + + def wait(self): + """Loop waiting on children to die and respawning as necessary.""" + + systemd.notify_once() + LOG.debug('Full set of CONF:') + CONF.log_opt_values(LOG, std_logging.DEBUG) + + try: + while True: + self.handle_signal() + self._respawn_children() + if self.sigcaught: + signame = _signo_to_signame(self.sigcaught) + LOG.info(_LI('Caught %s, stopping children'), signame) + if not _is_sighup_and_daemon(self.sigcaught): + break + + for pid in 
self.children: + os.kill(pid, signal.SIGHUP) + self.running = True + self.sigcaught = None + except eventlet.greenlet.GreenletExit: + LOG.info(_LI("Wait called after thread killed. Cleaning up.")) + + for pid in self.children: + try: + os.kill(pid, signal.SIGTERM) + except OSError as exc: + if exc.errno != errno.ESRCH: + raise + + # Wait for children to die + if self.children: + LOG.info(_LI('Waiting on %d children to exit'), len(self.children)) + while self.children: + self._wait_child() + + +class Service(object): + """Service object for binaries running on hosts.""" + + def __init__(self, threads=1000): + self.tg = threadgroup.ThreadGroup(threads) + + # signal that the service is done shutting itself down: + self._done = event.Event() + + def reset(self): + # NOTE(Fengqian): docs for Event.reset() recommend against using it + self._done = event.Event() + + def start(self): + pass + + def stop(self): + self.tg.stop() + self.tg.wait() + # Signal that service cleanup is done: + if not self._done.ready(): + self._done.send() + + def wait(self): + self._done.wait() + + +class Services(object): + + def __init__(self): + self.services = [] + self.tg = threadgroup.ThreadGroup() + self.done = event.Event() + + def add(self, service): + self.services.append(service) + self.tg.add_thread(self.run_service, service, self.done) + + def stop(self): + # wait for graceful shutdown of services: + for service in self.services: + service.stop() + service.wait() + + # Each service has performed cleanup, now signal that the run_service + # wrapper threads can now die: + if not self.done.ready(): + self.done.send() + + # reap threads: + self.tg.stop() + + def wait(self): + self.tg.wait() + + def restart(self): + self.stop() + self.done = event.Event() + for restart_service in self.services: + restart_service.reset() + self.tg.add_thread(self.run_service, restart_service, self.done) + + @staticmethod + def run_service(service, done): + """Service start wrapper. + + :param service: service to run + :param done: event to wait on until a shutdown is triggered + :returns: None + + """ + service.start() + done.wait() + + +def launch(service, workers=1): + if workers is None or workers == 1: + launcher = ServiceLauncher() + launcher.launch_service(service) + else: + launcher = ProcessLauncher() + launcher.launch_service(service, workers=workers) + + return launcher diff --git a/cerberus/openstack/common/sslutils.py b/cerberus/openstack/common/sslutils.py new file mode 100644 index 0000000..5ad2766 --- /dev/null +++ b/cerberus/openstack/common/sslutils.py @@ -0,0 +1,98 @@ +# Copyright 2013 IBM Corp. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
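How the service pieces above fit together from a binary's main(): a sketch service built on Service and launch(). The timer interval and task body are invented.

from cerberus.openstack.common import service


class HeartbeatService(service.Service):
    def start(self):
        # self.tg is the ThreadGroup created by Service.__init__().
        self.tg.add_timer(10, self._beat)

    def _beat(self):
        print('still alive')


# workers == 1 -> ServiceLauncher (runs in-process);
# workers > 1  -> ProcessLauncher (forked children, SIGHUP-aware restart).
launcher = service.launch(HeartbeatService(), workers=1)
launcher.wait()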
+ +import os +import ssl + +from oslo.config import cfg + +from cerberus.openstack.common.gettextutils import _ + + +ssl_opts = [ + cfg.StrOpt('ca_file', + default=None, + help="CA certificate file to use to verify " + "connecting clients."), + cfg.StrOpt('cert_file', + default=None, + help="Certificate file to use when starting " + "the server securely."), + cfg.StrOpt('key_file', + default=None, + help="Private key file to use when starting " + "the server securely."), +] + + +CONF = cfg.CONF +CONF.register_opts(ssl_opts, "ssl") + + +def is_enabled(): + cert_file = CONF.ssl.cert_file + key_file = CONF.ssl.key_file + ca_file = CONF.ssl.ca_file + use_ssl = cert_file or key_file + + if cert_file and not os.path.exists(cert_file): + raise RuntimeError(_("Unable to find cert_file : %s") % cert_file) + + if ca_file and not os.path.exists(ca_file): + raise RuntimeError(_("Unable to find ca_file : %s") % ca_file) + + if key_file and not os.path.exists(key_file): + raise RuntimeError(_("Unable to find key_file : %s") % key_file) + + if use_ssl and (not cert_file or not key_file): + raise RuntimeError(_("When running server in SSL mode, you must " + "specify both a cert_file and key_file " + "option value in your configuration file")) + + return use_ssl + + +def wrap(sock): + ssl_kwargs = { + 'server_side': True, + 'certfile': CONF.ssl.cert_file, + 'keyfile': CONF.ssl.key_file, + 'cert_reqs': ssl.CERT_NONE, + } + + if CONF.ssl.ca_file: + ssl_kwargs['ca_certs'] = CONF.ssl.ca_file + ssl_kwargs['cert_reqs'] = ssl.CERT_REQUIRED + + return ssl.wrap_socket(sock, **ssl_kwargs) + + +_SSL_PROTOCOLS = { + "tlsv1": ssl.PROTOCOL_TLSv1, + "sslv23": ssl.PROTOCOL_SSLv23, + "sslv3": ssl.PROTOCOL_SSLv3 +} + +try: + _SSL_PROTOCOLS["sslv2"] = ssl.PROTOCOL_SSLv2 +except AttributeError: + pass + + +def validate_ssl_version(version): + key = version.lower() + try: + return _SSL_PROTOCOLS[key] + except KeyError: + raise RuntimeError(_("Invalid SSL version : %s") % version) diff --git a/cerberus/openstack/common/strutils.py b/cerberus/openstack/common/strutils.py new file mode 100644 index 0000000..e50c9b7 --- /dev/null +++ b/cerberus/openstack/common/strutils.py @@ -0,0 +1,322 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +System-level utilities and helper functions. 
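Usage sketch for the [ssl] options and helpers in sslutils above; the file paths are placeholders, and the socket setup is illustrative only.

import socket

from cerberus.openstack.common import sslutils

# In the service's config file (values are placeholders):
# [ssl]
# cert_file = /etc/cerberus/server.crt
# key_file  = /etc/cerberus/server.key

sock = socket.socket()
sock.bind(('0.0.0.0', 9090))
sock.listen(1)
if sslutils.is_enabled():
    # is_enabled() raises RuntimeError if only one of cert_file/key_file
    # is set, or if a configured file does not exist.
    sock = sslutils.wrap(sock)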
+
+"""
+
+import math
+import re
+import sys
+import unicodedata
+
+import six
+
+from cerberus.openstack.common.gettextutils import _
+
+
+UNIT_PREFIX_EXPONENT = {
+    'k': 1,
+    'K': 1,
+    'Ki': 1,
+    'M': 2,
+    'Mi': 2,
+    'G': 3,
+    'Gi': 3,
+    'T': 4,
+    'Ti': 4,
+}
+UNIT_SYSTEM_INFO = {
+    'IEC': (1024, re.compile(r'(^[-+]?\d*\.?\d+)([KMGT]i?)?(b|bit|B)$')),
+    'SI': (1000, re.compile(r'(^[-+]?\d*\.?\d+)([kMGT])?(b|bit|B)$')),
+}
+
+TRUE_STRINGS = ('1', 't', 'true', 'on', 'y', 'yes')
+FALSE_STRINGS = ('0', 'f', 'false', 'off', 'n', 'no')
+
+SLUGIFY_STRIP_RE = re.compile(r"[^\w\s-]")
+SLUGIFY_HYPHENATE_RE = re.compile(r"[-\s]+")
+
+
+# NOTE(flaper87): The following globals are used by `mask_password`
+_SANITIZE_KEYS = ['adminPass', 'admin_pass', 'password', 'admin_password']
+
+# NOTE(ldbragst): Let's build a list of regex objects using the list of
+# _SANITIZE_KEYS we already have. This way, we only have to add the new key
+# to the list of _SANITIZE_KEYS and we can generate regular expressions
+# for XML and JSON automatically.
+_SANITIZE_PATTERNS_2 = []
+_SANITIZE_PATTERNS_1 = []
+
+# NOTE(amrith): Some regular expressions have only one parameter, some
+# have two parameters. Use different lists of patterns here.
+_FORMAT_PATTERNS_1 = [r'(%(key)s\s*[=]\s*)[^\s^\'^\"]+']
+_FORMAT_PATTERNS_2 = [r'(%(key)s\s*[=]\s*[\"\']).*?([\"\'])',
+                      r'(%(key)s\s+[\"\']).*?([\"\'])',
+                      r'([-]{2}%(key)s\s+)[^\'^\"^=^\s]+([\s]*)',
+                      r'(<%(key)s>).*?(</%(key)s>)',
+                      r'([\"\']%(key)s[\"\']\s*:\s*[\"\']).*?([\"\'])',
+                      r'([\'"].*?%(key)s[\'"]\s*:\s*u?[\'"]).*?([\'"])',
+                      r'([\'"].*?%(key)s[\'"]\s*,\s*\'--?[A-z]+\'\s*,\s*u?'
+                      '[\'"]).*?([\'"])',
+                      r'(%(key)s\s*--?[A-z]+\s*)\S+(\s*)']
+
+for key in _SANITIZE_KEYS:
+    for pattern in _FORMAT_PATTERNS_2:
+        reg_ex = re.compile(pattern % {'key': key}, re.DOTALL)
+        _SANITIZE_PATTERNS_2.append(reg_ex)
+
+    for pattern in _FORMAT_PATTERNS_1:
+        reg_ex = re.compile(pattern % {'key': key}, re.DOTALL)
+        _SANITIZE_PATTERNS_1.append(reg_ex)
+
+
+def int_from_bool_as_string(subject):
+    """Interpret a string as a boolean and return either 1 or 0.
+
+    Any string value in:
+
+        ('True', 'true', 'On', 'on', '1')
+
+    is interpreted as a boolean True.
+
+    Useful for JSON-decoded stuff and config file parsing
+    """
+    return bool_from_string(subject) and 1 or 0
+
+
+def bool_from_string(subject, strict=False, default=False):
+    """Interpret a string as a boolean.
+
+    A case-insensitive match is performed such that strings matching 't',
+    'true', 'on', 'y', 'yes', or '1' are considered True and, when
+    `strict=False`, anything else returns the value specified by 'default'.
+
+    Useful for JSON-decoded stuff and config file parsing.
+
+    If `strict=True`, unrecognized values, including None, will raise a
+    ValueError which is useful when parsing values passed in from an API call.
+    Strings yielding False are 'f', 'false', 'off', 'n', 'no', or '0'.
+    """
+    if not isinstance(subject, six.string_types):
+        subject = str(subject)
+
+    lowered = subject.strip().lower()
+
+    if lowered in TRUE_STRINGS:
+        return True
+    elif lowered in FALSE_STRINGS:
+        return False
+    elif strict:
+        acceptable = ', '.join(
+            "'%s'" % s for s in sorted(TRUE_STRINGS + FALSE_STRINGS))
+        msg = _("Unrecognized value '%(val)s', acceptable values are:"
+                " %(acceptable)s") % {'val': subject,
+                                      'acceptable': acceptable}
+        raise ValueError(msg)
+    else:
+        return default
+
+
+def safe_decode(text, incoming=None, errors='strict'):
+    """Decodes incoming text/bytes string using `incoming` if they're not
+    already unicode.
+ + :param incoming: Text's current encoding + :param errors: Errors handling policy. See here for valid + values http://docs.python.org/2/library/codecs.html + :returns: text or a unicode `incoming` encoded + representation of it. + :raises TypeError: If text is not an instance of str + """ + if not isinstance(text, (six.string_types, six.binary_type)): + raise TypeError("%s can't be decoded" % type(text)) + + if isinstance(text, six.text_type): + return text + + if not incoming: + incoming = (sys.stdin.encoding or + sys.getdefaultencoding()) + + try: + return text.decode(incoming, errors) + except UnicodeDecodeError: + # Note(flaper87) If we get here, it means that + # sys.stdin.encoding / sys.getdefaultencoding + # didn't return a suitable encoding to decode + # text. This happens mostly when global LANG + # var is not set correctly and there's no + # default encoding. In this case, most likely + # python will use ASCII or ANSI encoders as + # default encodings but they won't be capable + # of decoding non-ASCII characters. + # + # Also, UTF-8 is being used since it's an ASCII + # extension. + return text.decode('utf-8', errors) + + +def safe_encode(text, incoming=None, + encoding='utf-8', errors='strict'): + """Encodes incoming text/bytes string using `encoding`. + + If incoming is not specified, text is expected to be encoded with + current python's default encoding. (`sys.getdefaultencoding`) + + :param incoming: Text's current encoding + :param encoding: Expected encoding for text (Default UTF-8) + :param errors: Errors handling policy. See here for valid + values http://docs.python.org/2/library/codecs.html + :returns: text or a bytestring `encoding` encoded + representation of it. + :raises TypeError: If text is not an instance of str + """ + if not isinstance(text, (six.string_types, six.binary_type)): + raise TypeError("%s can't be encoded" % type(text)) + + if not incoming: + incoming = (sys.stdin.encoding or + sys.getdefaultencoding()) + + if isinstance(text, six.text_type): + if six.PY3: + return text.encode(encoding, errors).decode(incoming) + else: + return text.encode(encoding, errors) + elif text and encoding != incoming: + # Decode text before encoding it with `encoding` + text = safe_decode(text, incoming, errors) + if six.PY3: + return text.encode(encoding, errors).decode(incoming) + else: + return text.encode(encoding, errors) + + return text + + +def string_to_bytes(text, unit_system='IEC', return_int=False): + """Converts a string into an float representation of bytes. + + The units supported for IEC :: + + Kb(it), Kib(it), Mb(it), Mib(it), Gb(it), Gib(it), Tb(it), Tib(it) + KB, KiB, MB, MiB, GB, GiB, TB, TiB + + The units supported for SI :: + + kb(it), Mb(it), Gb(it), Tb(it) + kB, MB, GB, TB + + Note that the SI unit system does not support capital letter 'K' + + :param text: String input for bytes size conversion. + :param unit_system: Unit system for byte size conversion. + :param return_int: If True, returns integer representation of text + in bytes. (default: decimal) + :returns: Numerical representation of text in bytes. + :raises ValueError: If text has an invalid value. 
+ + """ + try: + base, reg_ex = UNIT_SYSTEM_INFO[unit_system] + except KeyError: + msg = _('Invalid unit system: "%s"') % unit_system + raise ValueError(msg) + match = reg_ex.match(text) + if match: + magnitude = float(match.group(1)) + unit_prefix = match.group(2) + if match.group(3) in ['b', 'bit']: + magnitude /= 8 + else: + msg = _('Invalid string format: %s') % text + raise ValueError(msg) + if not unit_prefix: + res = magnitude + else: + res = magnitude * pow(base, UNIT_PREFIX_EXPONENT[unit_prefix]) + if return_int: + return int(math.ceil(res)) + return res + + +def to_slug(value, incoming=None, errors="strict"): + """Normalize string. + + Convert to lowercase, remove non-word characters, and convert spaces + to hyphens. + + Inspired by Django's `slugify` filter. + + :param value: Text to slugify + :param incoming: Text's current encoding + :param errors: Errors handling policy. See here for valid + values http://docs.python.org/2/library/codecs.html + :returns: slugified unicode representation of `value` + :raises TypeError: If text is not an instance of str + """ + value = safe_decode(value, incoming, errors) + # NOTE(aababilov): no need to use safe_(encode|decode) here: + # encodings are always "ascii", error handling is always "ignore" + # and types are always known (first: unicode; second: str) + value = unicodedata.normalize("NFKD", value).encode( + "ascii", "ignore").decode("ascii") + value = SLUGIFY_STRIP_RE.sub("", value).strip().lower() + return SLUGIFY_HYPHENATE_RE.sub("-", value) + + +def mask_password(message, secret="***"): + """Replace password with 'secret' in message. + + :param message: The string which includes security information. + :param secret: value with which to replace passwords. + :returns: The unicode value of message with the password fields masked. + + For example: + + >>> mask_password("'adminPass' : 'aaaaa'") + "'adminPass' : '***'" + >>> mask_password("'admin_pass' : 'aaaaa'") + "'admin_pass' : '***'" + >>> mask_password('"password" : "aaaaa"') + '"password" : "***"' + >>> mask_password("'original_password' : 'aaaaa'") + "'original_password' : '***'" + >>> mask_password("u'original_password' : u'aaaaa'") + "u'original_password' : u'***'" + """ + try: + message = six.text_type(message) + except UnicodeDecodeError: + # NOTE(jecarey): Temporary fix to handle cases where message is a + # byte string. A better solution will be provided in Kilo. + pass + + # NOTE(ldbragst): Check to see if anything in message contains any key + # specified in _SANITIZE_KEYS, if not then just return the message since + # we don't have to mask any passwords. + if not any(key in message for key in _SANITIZE_KEYS): + return message + + substitute = r'\g<1>' + secret + r'\g<2>' + for pattern in _SANITIZE_PATTERNS_2: + message = re.sub(pattern, substitute, message) + + substitute = r'\g<1>' + secret + for pattern in _SANITIZE_PATTERNS_1: + message = re.sub(pattern, substitute, message) + + return message diff --git a/cerberus/openstack/common/systemd.py b/cerberus/openstack/common/systemd.py new file mode 100644 index 0000000..a5707cb --- /dev/null +++ b/cerberus/openstack/common/systemd.py @@ -0,0 +1,104 @@ +# Copyright 2012-2014 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain
+# a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Helper module for systemd service readiness notification.
+"""
+
+import os
+import socket
+import sys
+
+from cerberus.openstack.common import log as logging
+
+
+LOG = logging.getLogger(__name__)
+
+
+def _abstractify(socket_name):
+    if socket_name.startswith('@'):
+        # abstract namespace socket
+        socket_name = '\0%s' % socket_name[1:]
+    return socket_name
+
+
+def _sd_notify(unset_env, msg):
+    notify_socket = os.getenv('NOTIFY_SOCKET')
+    if notify_socket:
+        sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+        try:
+            sock.connect(_abstractify(notify_socket))
+            sock.sendall(msg)
+            if unset_env:
+                del os.environ['NOTIFY_SOCKET']
+        except EnvironmentError:
+            LOG.debug("Systemd notification failed", exc_info=True)
+        finally:
+            sock.close()
+
+
+def notify():
+    """Send notification to systemd that the service is ready.
+
+    For details see
+    http://www.freedesktop.org/software/systemd/man/sd_notify.html
+    """
+    _sd_notify(False, 'READY=1')
+
+
+def notify_once():
+    """Send notification to systemd once that the service is ready.
+
+    Systemd sets the NOTIFY_SOCKET environment variable with the name of
+    the socket listening for notifications from services.
+    This method removes the NOTIFY_SOCKET environment variable to ensure
+    the notification is sent only once.
+    """
+    _sd_notify(True, 'READY=1')
+
+
+def onready(notify_socket, timeout):
+    """Wait for a systemd-style notification on the socket.
+
+    :param notify_socket: local socket address
+    :type notify_socket: string
+    :param timeout: socket timeout
+    :type timeout: float
+    :returns: 0 service ready
+              1 service not ready
+              2 timeout occurred
+    """
+    sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+    sock.settimeout(timeout)
+    sock.bind(_abstractify(notify_socket))
+    try:
+        msg = sock.recv(512)
+    except socket.timeout:
+        return 2
+    finally:
+        sock.close()
+    if 'READY=1' in msg:
+        return 0
+    else:
+        return 1
+
+
+if __name__ == '__main__':
+    # simple CLI for testing
+    if len(sys.argv) == 1:
+        notify()
+    elif len(sys.argv) >= 2:
+        timeout = float(sys.argv[1])
+        notify_socket = os.getenv('NOTIFY_SOCKET')
+        if notify_socket:
+            retval = onready(notify_socket, timeout)
+            sys.exit(retval)
diff --git a/cerberus/openstack/common/test.py b/cerberus/openstack/common/test.py
new file mode 100644
index 0000000..a391f54
--- /dev/null
+++ b/cerberus/openstack/common/test.py
@@ -0,0 +1,99 @@
+# Copyright (c) 2013 Hewlett-Packard Development Company, L.P.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
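For readers unfamiliar with the sd_notify protocol that the systemd helper above wraps: the service process sends a datagram containing `READY=1` to the AF_UNIX socket named by `NOTIFY_SOCKET`, where a leading `@` marks a Linux abstract-namespace address, exactly the translation `_abstractify()` performs. A minimal, self-contained sketch of that exchange (Linux-only; the `DEMO_SOCKET` name is illustrative and not part of the module):

```python
import os
import socket

DEMO_SOCKET = '@cerberus-notify-demo'      # '@' -> abstract namespace

# What systemd does: create the notification socket and export its name.
listener = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
listener.bind('\0' + DEMO_SOCKET[1:])      # the _abstractify() translation
listener.settimeout(2.0)
os.environ['NOTIFY_SOCKET'] = DEMO_SOCKET

# What the service does: this is the heart of _sd_notify(False, 'READY=1').
target = os.environ['NOTIFY_SOCKET']
sender = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
sender.connect('\0' + target[1:])
sender.sendall(b'READY=1')
sender.close()

print(listener.recv(512))                  # b'READY=1'
listener.close()
```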
+ +############################################################################## +############################################################################## +## +## DO NOT MODIFY THIS FILE +## +## This file is being graduated to the cerberustest library. Please make all +## changes there, and only backport critical fixes here. - dhellmann +## +############################################################################## +############################################################################## + +"""Common utilities used in testing""" + +import logging +import os +import tempfile + +import fixtures +import testtools + +_TRUE_VALUES = ('True', 'true', '1', 'yes') +_LOG_FORMAT = "%(levelname)8s [%(name)s] %(message)s" + + +class BaseTestCase(testtools.TestCase): + + def setUp(self): + super(BaseTestCase, self).setUp() + self._set_timeout() + self._fake_output() + self._fake_logs() + self.useFixture(fixtures.NestedTempfile()) + self.useFixture(fixtures.TempHomeDir()) + self.tempdirs = [] + + def _set_timeout(self): + test_timeout = os.environ.get('OS_TEST_TIMEOUT', 0) + try: + test_timeout = int(test_timeout) + except ValueError: + # If timeout value is invalid do not set a timeout. + test_timeout = 0 + if test_timeout > 0: + self.useFixture(fixtures.Timeout(test_timeout, gentle=True)) + + def _fake_output(self): + if os.environ.get('OS_STDOUT_CAPTURE') in _TRUE_VALUES: + stdout = self.useFixture(fixtures.StringStream('stdout')).stream + self.useFixture(fixtures.MonkeyPatch('sys.stdout', stdout)) + if os.environ.get('OS_STDERR_CAPTURE') in _TRUE_VALUES: + stderr = self.useFixture(fixtures.StringStream('stderr')).stream + self.useFixture(fixtures.MonkeyPatch('sys.stderr', stderr)) + + def _fake_logs(self): + if os.environ.get('OS_DEBUG') in _TRUE_VALUES: + level = logging.DEBUG + else: + level = logging.INFO + capture_logs = os.environ.get('OS_LOG_CAPTURE') in _TRUE_VALUES + if capture_logs: + self.useFixture( + fixtures.FakeLogger( + format=_LOG_FORMAT, + level=level, + nuke_handlers=capture_logs, + ) + ) + else: + logging.basicConfig(format=_LOG_FORMAT, level=level) + + def create_tempfiles(self, files, ext='.conf'): + tempfiles = [] + for (basename, contents) in files: + if not os.path.isabs(basename): + (fd, path) = tempfile.mkstemp(prefix=basename, suffix=ext) + else: + path = basename + ext + fd = os.open(path, os.O_CREAT | os.O_WRONLY) + tempfiles.append(path) + try: + os.write(fd, contents) + finally: + os.close(fd) + return tempfiles diff --git a/cerberus/openstack/common/threadgroup.py b/cerberus/openstack/common/threadgroup.py new file mode 100644 index 0000000..b068095 --- /dev/null +++ b/cerberus/openstack/common/threadgroup.py @@ -0,0 +1,149 @@ +# Copyright 2012 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
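A usage sketch for the BaseTestCase above, assuming testtools and fixtures are installed and the cerberus tree is importable (the test body itself is illustrative, and Python 2 semantics are assumed, as in the rest of this tree). `create_tempfiles` writes each `(basename, contents)` pair to a file, and the `NestedTempfile` fixture installed in `setUp()` confines and cleans up those files; `OS_DEBUG`, `OS_LOG_CAPTURE`, and `OS_STDOUT_CAPTURE` control logging and stream capture via the environment:

```python
from cerberus.openstack.common import test


class TempfileTest(test.BaseTestCase):

    def test_create_tempfiles(self):
        # Each (basename, contents) pair becomes a real file; cleanup is
        # handled by the NestedTempfile fixture from setUp().
        paths = self.create_tempfiles(
            [('demo', '[DEFAULT]\nverbose = True\n')])
        self.assertTrue(paths[0].endswith('.conf'))
        with open(paths[0]) as conf:
            self.assertIn('verbose', conf.read())
```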
+import threading
+
+import eventlet
+from eventlet import greenpool
+
+from cerberus.openstack.common import log as logging
+from cerberus.openstack.common import loopingcall
+
+
+LOG = logging.getLogger(__name__)
+
+
+def _thread_done(gt, *args, **kwargs):
+    """Callback function to be passed to GreenThread.link() when we spawn().
+
+    Calls the :class:`ThreadGroup` to notify it.
+    """
+    kwargs['group'].thread_done(kwargs['thread'])
+
+
+class Thread(object):
+    """Wrapper around a greenthread that holds a reference to the
+    :class:`ThreadGroup`. The Thread will notify the :class:`ThreadGroup`
+    when it has finished so that it can be removed from the threads list.
+    """
+    def __init__(self, thread, group, *args, **kwargs):
+        self.args = args
+        self.kw = kwargs
+        self.thread = thread
+        self.thread.link(_thread_done, group=group, thread=self)
+
+    def stop(self):
+        self.thread.kill()
+
+    def wait(self):
+        return self.thread.wait()
+
+    def link(self, func, *args, **kwargs):
+        self.thread.link(func, *args, **kwargs)
+
+
+class ThreadGroup(object):
+    """The point of the ThreadGroup class is to:
+
+    * keep track of timers and greenthreads (making it easier to stop them
+      when need be).
+    * provide an easy API to add timers.
+    """
+    def __init__(self, thread_pool_size=10):
+        self.pool = greenpool.GreenPool(thread_pool_size)
+        self.threads = []
+        self.timers = []
+
+    def add_dynamic_timer(self, callback, initial_delay=None,
+                          periodic_interval_max=None, *args, **kwargs):
+        timer = loopingcall.DynamicLoopingCall(callback, *args, **kwargs)
+        timer.start(initial_delay=initial_delay,
+                    periodic_interval_max=periodic_interval_max)
+        self.timers.append(timer)
+
+    def add_timer(self, interval, callback, initial_delay=None,
+                  *args, **kwargs):
+        pulse = loopingcall.FixedIntervalLoopingCall(callback, *args, **kwargs)
+        pulse.start(interval=interval,
+                    initial_delay=initial_delay)
+        self.timers.append(pulse)
+
+    def add_thread(self, callback, *args, **kwargs):
+        gt = self.pool.spawn(callback, *args, **kwargs)
+        th = Thread(gt, self, *args, **kwargs)
+        self.threads.append(th)
+        return th
+
+    def thread_done(self, thread):
+        self.threads.remove(thread)
+
+    def _stop_threads(self):
+        current = threading.current_thread()
+
+        # Iterate over a copy of self.threads so thread_done doesn't
+        # modify the list while we're iterating
+        for x in self.threads[:]:
+            if x is current:
+                # don't kill the current thread.
+                continue
+            try:
+                x.stop()
+            except Exception as ex:
+                LOG.exception(ex)
+
+    def stop_timers(self):
+        for x in self.timers:
+            try:
+                x.stop()
+            except Exception as ex:
+                LOG.exception(ex)
+        self.timers = []
+
+    def stop(self, graceful=False):
+        """Stop the group; `graceful` selects how threads are stopped.
+
+        * If graceful=True, wait for all threads to finish and never
+          kill them.
+        * If graceful=False, kill threads immediately.
+ """ + self.stop_timers() + if graceful: + # In case of graceful=True, wait for all threads to be + # finished, never kill threads + self.wait() + else: + # In case of graceful=False(Default), kill threads + # immediately + self._stop_threads() + + def wait(self): + for x in self.timers: + try: + x.wait() + except eventlet.greenlet.GreenletExit: + pass + except Exception as ex: + LOG.exception(ex) + current = threading.current_thread() + + # Iterate over a copy of self.threads so thread_done doesn't + # modify the list while we're iterating + for x in self.threads[:]: + if x is current: + continue + try: + x.wait() + except eventlet.greenlet.GreenletExit: + pass + except Exception as ex: + LOG.exception(ex) diff --git a/cerberus/openstack/common/timeutils.py b/cerberus/openstack/common/timeutils.py new file mode 100644 index 0000000..52688a0 --- /dev/null +++ b/cerberus/openstack/common/timeutils.py @@ -0,0 +1,210 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Time related utilities and helper functions. +""" + +import calendar +import datetime +import time + +import iso8601 +import six + + +# ISO 8601 extended time format with microseconds +_ISO8601_TIME_FORMAT_SUBSECOND = '%Y-%m-%dT%H:%M:%S.%f' +_ISO8601_TIME_FORMAT = '%Y-%m-%dT%H:%M:%S' +PERFECT_TIME_FORMAT = _ISO8601_TIME_FORMAT_SUBSECOND + + +def isotime(at=None, subsecond=False): + """Stringify time in ISO 8601 format.""" + if not at: + at = utcnow() + st = at.strftime(_ISO8601_TIME_FORMAT + if not subsecond + else _ISO8601_TIME_FORMAT_SUBSECOND) + tz = at.tzinfo.tzname(None) if at.tzinfo else 'UTC' + st += ('Z' if tz == 'UTC' else tz) + return st + + +def parse_isotime(timestr): + """Parse time from ISO 8601 format.""" + try: + return iso8601.parse_date(timestr) + except iso8601.ParseError as e: + raise ValueError(six.text_type(e)) + except TypeError as e: + raise ValueError(six.text_type(e)) + + +def strtime(at=None, fmt=PERFECT_TIME_FORMAT): + """Returns formatted utcnow.""" + if not at: + at = utcnow() + return at.strftime(fmt) + + +def parse_strtime(timestr, fmt=PERFECT_TIME_FORMAT): + """Turn a formatted time back into a datetime.""" + return datetime.datetime.strptime(timestr, fmt) + + +def normalize_time(timestamp): + """Normalize time in arbitrary timezone to UTC naive object.""" + offset = timestamp.utcoffset() + if offset is None: + return timestamp + return timestamp.replace(tzinfo=None) - offset + + +def is_older_than(before, seconds): + """Return True if before is older than seconds.""" + if isinstance(before, six.string_types): + before = parse_strtime(before).replace(tzinfo=None) + else: + before = before.replace(tzinfo=None) + + return utcnow() - before > datetime.timedelta(seconds=seconds) + + +def is_newer_than(after, seconds): + """Return True if after is newer than seconds.""" + if isinstance(after, six.string_types): + after = parse_strtime(after).replace(tzinfo=None) + else: + after = after.replace(tzinfo=None) + + return after - utcnow() 
> datetime.timedelta(seconds=seconds)
+
+
+def utcnow_ts():
+    """Timestamp version of our utcnow function."""
+    if utcnow.override_time is None:
+        # NOTE(kgriffs): This is several times faster
+        # than going through calendar.timegm(...)
+        return int(time.time())
+
+    return calendar.timegm(utcnow().timetuple())
+
+
+def utcnow():
+    """Overridable version of utils.utcnow."""
+    if utcnow.override_time:
+        try:
+            return utcnow.override_time.pop(0)
+        except AttributeError:
+            return utcnow.override_time
+    return datetime.datetime.utcnow()
+
+
+def iso8601_from_timestamp(timestamp):
+    """Returns an ISO 8601 formatted date from a timestamp."""
+    return isotime(datetime.datetime.utcfromtimestamp(timestamp))
+
+
+utcnow.override_time = None
+
+
+def set_time_override(override_time=None):
+    """Overrides utils.utcnow.
+
+    Make it return a constant time or a list thereof, one at a time.
+
+    :param override_time: datetime instance or list thereof. If not
+        given, defaults to the current UTC time.
+    """
+    utcnow.override_time = override_time or datetime.datetime.utcnow()
+
+
+def advance_time_delta(timedelta):
+    """Advance overridden time using a datetime.timedelta."""
+    assert utcnow.override_time is not None
+    try:
+        # NOTE: rebind the list; augmenting the loop variable (as an
+        # earlier revision did) left the stored datetimes unchanged.
+        utcnow.override_time = [dt + timedelta
+                                for dt in utcnow.override_time]
+    except TypeError:
+        utcnow.override_time += timedelta
+
+
+def advance_time_seconds(seconds):
+    """Advance overridden time by seconds."""
+    advance_time_delta(datetime.timedelta(0, seconds))
+
+
+def clear_time_override():
+    """Remove the overridden time."""
+    utcnow.override_time = None
+
+
+def marshall_now(now=None):
+    """Make an rpc-safe datetime with microseconds.
+
+    Note: tzinfo is stripped, but not required for relative times.
+    """
+    if not now:
+        now = utcnow()
+    return dict(day=now.day, month=now.month, year=now.year, hour=now.hour,
+                minute=now.minute, second=now.second,
+                microsecond=now.microsecond)
+
+
+def unmarshall_time(tyme):
+    """Unmarshall a datetime dict."""
+    return datetime.datetime(day=tyme['day'],
+                             month=tyme['month'],
+                             year=tyme['year'],
+                             hour=tyme['hour'],
+                             minute=tyme['minute'],
+                             second=tyme['second'],
+                             microsecond=tyme['microsecond'])
+
+
+def delta_seconds(before, after):
+    """Return the difference between two timing objects.
+
+    Compute the difference in seconds between two date, time, or
+    datetime objects (as a float, to microsecond resolution).
+    """
+    delta = after - before
+    return total_seconds(delta)
+
+
+def total_seconds(delta):
+    """Return the total seconds of a datetime.timedelta object.
+
+    datetime.timedelta has no total_seconds() method in Python 2.6, so
+    fall back to computing it manually in that case.
+    """
+    try:
+        return delta.total_seconds()
+    except AttributeError:
+        return ((delta.days * 24 * 3600) + delta.seconds +
+                float(delta.microseconds) / (10 ** 6))
+
+
+def is_soon(dt, window):
+    """Determine whether a time will occur within the next window seconds.
+
+    :param dt: the time
+    :param window: number of seconds before `dt` within which it is
+        already considered soon
+
+    :return: True if expiration is within the given duration
+    """
+    soon = (utcnow() + datetime.timedelta(seconds=window))
+    return normalize_time(dt) <= soon
diff --git a/cerberus/openstack/common/uuidutils.py b/cerberus/openstack/common/uuidutils.py
new file mode 100644
index 0000000..234b880
--- /dev/null
+++ b/cerberus/openstack/common/uuidutils.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2012 Intel Corporation.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+UUID related utilities and helper functions.
+"""
+
+import uuid
+
+
+def generate_uuid():
+    return str(uuid.uuid4())
+
+
+def is_uuid_like(val):
+    """Returns True if val is a canonical-form UUID string.
+
+    For our purposes, a UUID is a string in the canonical form:
+
+        aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa
+
+    """
+    try:
+        return str(uuid.UUID(val)) == val
+    except (TypeError, ValueError, AttributeError):
+        return False
diff --git a/cerberus/openstack/common/versionutils.py b/cerberus/openstack/common/versionutils.py
new file mode 100644
index 0000000..0ed0452
--- /dev/null
+++ b/cerberus/openstack/common/versionutils.py
@@ -0,0 +1,148 @@
+# Copyright (c) 2013 OpenStack Foundation
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Helpers for comparing version strings.
+"""
+
+import functools
+import pkg_resources
+
+from cerberus.openstack.common.gettextutils import _
+from cerberus.openstack.common import log as logging
+
+
+LOG = logging.getLogger(__name__)
+
+
+class deprecated(object):
+    """A decorator to mark callables as deprecated.
+
+    This decorator logs a deprecation message when the callable it
+    decorates is used. The message will include the release where the
+    callable was deprecated, the release where it may be removed, and
+    possibly an optional replacement.
+
+    Examples:
+
+    1. Specifying the required deprecated release
+
+    >>> @deprecated(as_of=deprecated.ICEHOUSE)
+    ... def a(): pass
+
+    2. Specifying a replacement:
+
+    >>> @deprecated(as_of=deprecated.ICEHOUSE, in_favor_of='f()')
+    ... def b(): pass
+
+    3. Specifying the release where the functionality may be removed:
+
+    >>> @deprecated(as_of=deprecated.ICEHOUSE, remove_in=+1)
+    ... def c(): pass
+
+    """
+
+    FOLSOM = 'F'
+    GRIZZLY = 'G'
+    HAVANA = 'H'
+    ICEHOUSE = 'I'
+
+    _RELEASES = {
+        'F': 'Folsom',
+        'G': 'Grizzly',
+        'H': 'Havana',
+        'I': 'Icehouse',
+    }
+
+    _deprecated_msg_with_alternative = _(
+        '%(what)s is deprecated as of %(as_of)s in favor of '
+        '%(in_favor_of)s and may be removed in %(remove_in)s.')
+
+    _deprecated_msg_no_alternative = _(
+        '%(what)s is deprecated as of %(as_of)s and may be '
+        'removed in %(remove_in)s. It will not be superseded.')
+
+    def __init__(self, as_of, in_favor_of=None, remove_in=2, what=None):
+        """Initialize decorator.
+
+        :param as_of: the release deprecating the callable. Constants
+            are defined in this class for convenience.
+ :param in_favor_of: the replacement for the callable (optional) + :param remove_in: an integer specifying how many releases to wait + before removing (default: 2) + :param what: name of the thing being deprecated (default: the + callable's name) + + """ + self.as_of = as_of + self.in_favor_of = in_favor_of + self.remove_in = remove_in + self.what = what + + def __call__(self, func): + if not self.what: + self.what = func.__name__ + '()' + + @functools.wraps(func) + def wrapped(*args, **kwargs): + msg, details = self._build_message() + LOG.deprecated(msg, details) + return func(*args, **kwargs) + return wrapped + + def _get_safe_to_remove_release(self, release): + # TODO(dstanek): this method will have to be reimplemented once + # when we get to the X release because once we get to the Y + # release, what is Y+2? + new_release = chr(ord(release) + self.remove_in) + if new_release in self._RELEASES: + return self._RELEASES[new_release] + else: + return new_release + + def _build_message(self): + details = dict(what=self.what, + as_of=self._RELEASES[self.as_of], + remove_in=self._get_safe_to_remove_release(self.as_of)) + + if self.in_favor_of: + details['in_favor_of'] = self.in_favor_of + msg = self._deprecated_msg_with_alternative + else: + msg = self._deprecated_msg_no_alternative + return msg, details + + +def is_compatible(requested_version, current_version, same_major=True): + """Determine whether `requested_version` is satisfied by + `current_version`; in other words, `current_version` is >= + `requested_version`. + + :param requested_version: version to check for compatibility + :param current_version: version to check against + :param same_major: if True, the major version must be identical between + `requested_version` and `current_version`. This is used when a + major-version difference indicates incompatibility between the two + versions. Since this is the common-case in practice, the default is + True. + :returns: True if compatible, False if not + """ + requested_parts = pkg_resources.parse_version(requested_version) + current_parts = pkg_resources.parse_version(current_version) + + if same_major and (requested_parts[0] != current_parts[0]): + return False + + return current_parts >= requested_parts diff --git a/cerberus/plugins/__init__.py b/cerberus/plugins/__init__.py new file mode 100644 index 0000000..73ca62b --- /dev/null +++ b/cerberus/plugins/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/cerberus/plugins/base.py b/cerberus/plugins/base.py new file mode 100644 index 0000000..1098d2f --- /dev/null +++ b/cerberus/plugins/base.py @@ -0,0 +1,154 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain
+# a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+import fnmatch
+import json
+
+import six
+
+import oslo.messaging
+
+from cerberus.openstack.common import log
+from cerberus.openstack.common import loopingcall
+from cerberus.openstack.common import threadgroup
+
+
+LOG = log.getLogger(__name__)
+
+
+@six.add_metaclass(abc.ABCMeta)
+class PluginBase(object):
+    """Base class for all plugins."""
+
+    TOOL_NAME = ""
+    TYPE = ""
+    PROVIDER = ""
+    DESCRIPTION = ""
+
+    _name = None
+
+    _uuid = None
+
+    _event_groups = {
+        'INSTANCE': [
+            'compute.instance.created',
+            # NOTE: the trailing commas matter; without them adjacent
+            # string literals are silently concatenated into one event.
+            'compute.instance.deleted',
+            'compute.instance.updated',
+        ],
+        'NETWORK': [
+            'network.created',
+        ],
+        'PROJECT': [
+            'project.created'
+        ]
+    }
+
+    def __init__(self, description=None, provider=None, type=None,
+                 tool_name=None):
+        self._subscribedEvents = []
+        self._name = "{0}.{1}".format(self.__class__.__module__,
+                                      self.__class__.__name__)
+
+    def subscribe_event(self, event):
+        if event not in self._subscribedEvents:
+            self._subscribedEvents.append(event)
+
+    def register_manager(self, manager):
+        """Enable the plugin to add tasks to the manager.
+
+        :param manager: the task manager to add tasks to
+        """
+        self.manager = manager
+
+    @staticmethod
+    def _handle_event_type(subscribed_events, event_type):
+        """Check whether event_type should be handled.
+
+        Returns True when event_type matches at least one of the
+        subscribed event patterns.
+        """
+        return any(map(lambda e: fnmatch.fnmatch(event_type, e),
+                       subscribed_events))
+
+    @staticmethod
+    def get_targets(conf):
+        """Return a sequence of oslo.messaging.Target objects.
+
+        The sequence defines the exchange and topics to be connected to
+        for this plugin.
+ """ + return [oslo.messaging.Target(topic=topic) + for topic in conf.notification_topics] + + @abc.abstractmethod + def process_notification(self, ctxt, publisher_id, event_type, payload, + metadata): + pass + + def info(self, ctxt, publisher_id, event_type, payload, metadata): + # Check if event is registered for plugin + if self._handle_event_type(self._subscribedEvents, event_type): + self.process_notification(ctxt, publisher_id, event_type, payload, + metadata) + ''' + http://stackoverflow.com/questions/3378949/ + python-decorators-and-class-inheritance + http://stackoverflow.com/questions/338101/ + python-function-attributes-uses-and-abuses + ''' + @staticmethod + def webmethod(func): + func.is_webmethod = True + return func + + +class PluginEncoder(json.JSONEncoder): + def default(self, obj): + if not isinstance(obj, PluginBase): + return super(PluginEncoder, self).default(obj) + methods = [method for method in dir(obj) + if hasattr(getattr(obj, method), 'is_webmethod')] + return {'name': obj._name, + 'subscribed_events': obj._subscribedEvents, + 'methods': methods} + + +class FixedIntervalLoopingCallEncoder(json.JSONEncoder): + def default(self, obj): + if not isinstance(obj, loopingcall.FixedIntervalLoopingCall): + return super(FixedIntervalLoopingCallEncoder, self).default(obj) + if obj._running is True: + state = 'running' + else: + state = 'stopped' + return {'id': obj.kw.get('task_id', None), + 'name': obj.kw.get('task_name', None), + 'period': obj.kw.get('task_period', None), + 'type': obj.kw.get('task_type', None), + 'plugin_id': obj.kw.get('plugin_id', None), + 'state': state} + + +class ThreadEncoder(json.JSONEncoder): + def default(self, obj): + if not isinstance(obj, threadgroup.Thread): + return super(ThreadEncoder, self).default(obj) + return {'id': obj.kw.get('task_id', None), + 'name': obj.kw.get('task_name', None), + 'type': obj.kw.get('task_type', None), + 'plugin_id': obj.kw.get('plugin_id', None), + 'state': 'running'} diff --git a/cerberus/plugins/extension.py b/cerberus/plugins/extension.py new file mode 100644 index 0000000..c4f9305 --- /dev/null +++ b/cerberus/plugins/extension.py @@ -0,0 +1,55 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+from __future__ import print_function
+
+import argparse
+
+from stevedore import extension
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--width',
+        default=60,
+        type=int,
+        help='maximum output width for text',
+    )
+    parsed_args = parser.parse_args()
+
+    data = {
+        'a': 'A',
+        'b': 'B',
+        'long': 'word ' * 80,
+    }
+
+    mgr = extension.ExtensionManager(
+        namespace='stevedore.example.formatter',
+        invoke_on_load=True,
+        invoke_args=(parsed_args.width,),
+    )
+
+    def format_data(ext, data):
+        return (ext.name, ext.obj.format(data))
+
+    results = mgr.map(format_data, data)
+
+    for name, result in results:
+        print('Formatter: {0}'.format(name))
+        for chunk in result:
+            print(chunk, end='')
+        print('')
diff --git a/cerberus/plugins/openvas.py b/cerberus/plugins/openvas.py
new file mode 100644
index 0000000..26979d3
--- /dev/null
+++ b/cerberus/plugins/openvas.py
@@ -0,0 +1,94 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from oslo.config import cfg
+
+from cerberus.client import keystone_client
+from cerberus.client import neutron_client
+from cerberus.client import nova_client
+from cerberus.openstack.common import log
+from cerberus.plugins import base
+import openvas_lib
+
+
+LOG = log.getLogger(__name__)
+
+
+# Register options for the service
+OPENVAS_OPTS = [
+    cfg.StrOpt('openvas_admin',
+               default='admin',
+               help='The admin user for the RPC server',
+               ),
+    cfg.StrOpt('openvas_passwd',
+               default='admin',
+               help='The password for the RPC server',
+               ),
+    cfg.StrOpt('openvas_url',
+               default='https://',
+               help='URL of the RPC server',
+               ),
+]
+
+opt_group = cfg.OptGroup(name='openvas',
+                         title='Options for the OpenVas client')
+
+cfg.CONF.register_group(opt_group)
+cfg.CONF.register_opts(OPENVAS_OPTS, opt_group)
+cfg.CONF.import_group('openvas', 'cerberus.service')
+
+_FLOATINGIP_UPDATED = 'floatingip.update.end'
+_ROLE_ASSIGNMENT_CREATED = 'identity.created.role_assignment'
+_ROLE_ASSIGNMENT_DELETED = 'identity.deleted.role_assignment'
+_PROJECT_DELETED = 'identity.project.deleted'
+
+
+class OpenVasPlugin(base.PluginBase):
+
+    def __init__(self):
+        self.task_id = None
+        super(OpenVasPlugin, self).__init__()
+        self.subscribe_event(_ROLE_ASSIGNMENT_CREATED)
+        self.subscribe_event(_ROLE_ASSIGNMENT_DELETED)
+        self.subscribe_event(_FLOATINGIP_UPDATED)
+        self.subscribe_event(_PROJECT_DELETED)
+        self.kc = keystone_client.Client()
+        self.nc = neutron_client.Client()
+        self.nova_client = nova_client.Client()
+        self.conf = cfg.CONF.openvas
+
+    @base.PluginBase.webmethod
+    def get_security_reports(self, **kwargs):
+        security_reports = []
+        try:
+            scanner = openvas_lib.VulnscanManager(self.conf.openvas_url,
+                                                  self.conf.openvas_admin,
+                                                  self.conf.openvas_passwd)
+            finished_scans = scanner.get_finished_scans
+            for scan_key, scan_id in finished_scans.iteritems():
+                report_id = scanner.get_report_id(scan_id)
+                report = scanner.get_report_html(report_id)
+
+                security_reports.append(report)
+
+        except Exception as e:
+            LOG.exception(e)
+        return security_reports
+
+    def process_notification(self, ctxt, publisher_id, event_type, payload,
+                             metadata):
+        pass
diff --git a/cerberus/plugins/task_plugin.py b/cerberus/plugins/task_plugin.py
new file mode 100644
index 0000000..adfd0dd
--- /dev/null
+++ b/cerberus/plugins/task_plugin.py
@@ -0,0 +1,61 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import datetime
+import eventlet
+
+
+from cerberus.openstack.common import log
+from cerberus.plugins import base
+
+
+LOG = log.getLogger(__name__)
+
+_IMAGE_UPDATE = 'image.update'
+
+
+class TaskPlugin(base.PluginBase):
+
+    def __init__(self):
+        super(TaskPlugin, self).__init__()
+
+    @base.PluginBase.webmethod
+    def act_long(self, *args, **kwargs):
+        '''Log the current time once per second for 3600 seconds.
+
+        :param args:
+        :param kwargs:
+        :return:
+        '''
+        LOG.info(str(kwargs.get('task_name', 'unknown')) + " :" +
+                 str(datetime.datetime.time(datetime.datetime.now())))
+        i = 0
+        while i < 3600:
+            LOG.info(str(kwargs.get('task_name', 'unknown')) + " :" +
+                     str(datetime.datetime.time(datetime.datetime.now())))
+            i += 1
+            eventlet.sleep(1)
+        LOG.info(str(kwargs.get('task_name', 'unknown')) + " :" +
+                 str(datetime.datetime.time(datetime.datetime.now())))
+
+    @base.PluginBase.webmethod
+    def act_short(self, *args, **kwargs):
+        LOG.info(str(kwargs.get('task_name', 'unknown')) + " :" +
+                 str(datetime.datetime.time(datetime.datetime.now())))
+
+    def process_notification(self, ctxt, publisher_id, event_type, payload,
+                             metadata):
+        pass
diff --git a/cerberus/plugins/test_plugin.py b/cerberus/plugins/test_plugin.py
new file mode 100644
index 0000000..bf57ead
--- /dev/null
+++ b/cerberus/plugins/test_plugin.py
@@ -0,0 +1,59 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
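task_plugin.py relies on the function-attribute trick noted in base.py: `@PluginBase.webmethod` merely tags a function, and `PluginEncoder` later discovers tagged methods by introspection. A minimal standalone sketch of the same mechanism, with an illustrative `DemoPlugin` class:

```python
def webmethod(func):
    # Tag the function; behavior is unchanged.
    func.is_webmethod = True
    return func


class DemoPlugin(object):

    @webmethod
    def exposed(self):
        return 'callable over the API'

    def internal(self):
        return 'not exported'


# The same dir()/hasattr() scan PluginEncoder performs.
methods = [name for name in dir(DemoPlugin)
           if hasattr(getattr(DemoPlugin, name), 'is_webmethod')]
print(methods)  # ['exposed']
```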
+#
+
+import datetime
+
+from cerberus.openstack.common import log
+from cerberus.plugins import base
+
+
+LOG = log.getLogger(__name__)
+
+
+class TestPlugin(base.PluginBase):
+
+    def __init__(self):
+        self.task_id = None
+        super(TestPlugin, self).__init__()
+        super(TestPlugin, self).subscribe_event('image.update')
+
+    def act_short(self, *args, **kwargs):
+        LOG.info(str(kwargs.get('task_name', 'unknown')) + " :" +
+                 str(datetime.datetime.time(datetime.datetime.now())))
+
+    def process_notification(self, ctxt, publisher_id, event_type, payload,
+                             metadata):
+
+        LOG.info('--> Plugin %(plugin)s managed event %(event)s '
+                 'payload %(payload)s'
+                 % {'plugin': self._name,
+                    'event': event_type,
+                    'payload': payload})
+        if 'START' in payload['name'] and self.task_id is None:
+            self.task_id = self.manager._add_recurrent_task(
+                self.act_short,
+                1,
+                task_name='TEST_PLUGIN_START_PAYLOAD')
+            LOG.info('Start cycling task id %s', self.task_id)
+        if 'STOP' in payload['name']:
+            try:
+                self.manager._force_delete_recurrent_task(self.task_id)
+                LOG.info('Stop cycling task id %s', self.task_id)
+                self.task_id = None
+            except StopIteration as e:
+                LOG.debug('Error when stopping task')
+                LOG.exception(e)
+        return self._name
diff --git a/cerberus/service.py b/cerberus/service.py
new file mode 100644
index 0000000..8f01030
--- /dev/null
+++ b/cerberus/service.py
@@ -0,0 +1,143 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import socket
+import sys
+
+from oslo.config import cfg
+from oslo.messaging import rpc
+from stevedore import named
+
+# NOTE: the _() translation helper is used by the log messages below, so
+# import it explicitly rather than relying on a gettext builtin.
+from cerberus.openstack.common.gettextutils import _
+from cerberus.openstack.common import log
+from cerberus import utils
+
+
+OPTS = [
+    cfg.StrOpt('host',
+               default=socket.gethostname(),
+               help='Name of this node, which must be valid in an AMQP '
+                    'key. Can be an opaque identifier. For ZeroMQ only, must '
+                    'be a valid host name, FQDN, or IP address.'),
+    cfg.MultiStrOpt('dispatcher',
+                    deprecated_group="collector",
+                    default=['database'],
+                    help='Dispatcher to process data.'),
+    cfg.IntOpt('collector_workers',
+               default=1,
+               help='Number of workers for collector service. A single '
+                    'collector is enabled by default.'),
+    cfg.IntOpt('notification_workers',
+               default=1,
+               help='Number of workers for notification service. 
A single ' + 'notification agent is enabled by default.'), +] +cfg.CONF.register_opts(OPTS) + +CLI_OPTIONS = [ + cfg.StrOpt('os-username', + deprecated_group="DEFAULT", + default=os.environ.get('OS_USERNAME', 'cerberus'), + help='User name to use for OpenStack service access.'), + cfg.StrOpt('os-password', + deprecated_group="DEFAULT", + secret=True, + default=os.environ.get('OS_PASSWORD', 'admin'), + help='Password to use for OpenStack service access.'), + cfg.StrOpt('os-tenant-id', + deprecated_group="DEFAULT", + default=os.environ.get('OS_TENANT_ID', ''), + help='Tenant ID to use for OpenStack service access.'), + cfg.StrOpt('os-tenant-name', + deprecated_group="DEFAULT", + default=os.environ.get('OS_TENANT_NAME', 'admin'), + help='Tenant name to use for OpenStack service access.'), + cfg.StrOpt('os-cacert', + default=os.environ.get('OS_CACERT'), + help='Certificate chain for SSL validation.'), + cfg.StrOpt('os-auth-url', + deprecated_group="DEFAULT", + default=os.environ.get('OS_AUTH_URL', + 'http://localhost:5000/v2.0'), + help='Auth URL to use for OpenStack service access.'), + cfg.StrOpt('os-region-name', + deprecated_group="DEFAULT", + default=os.environ.get('OS_REGION_NAME'), + help='Region name to use for OpenStack service endpoints.'), + cfg.StrOpt('os-endpoint-type', + default=os.environ.get('OS_ENDPOINT_TYPE', 'publicURL'), + help='Type of endpoint in Identity service catalog to use for ' + 'communication with OpenStack services.'), + cfg.BoolOpt('insecure', + default=False, + help='Disables X.509 certificate validation when an ' + 'SSL connection to Identity Service is established.'), +] +cfg.CONF.register_opts(CLI_OPTIONS, group="service_credentials") + + +LOG = log.getLogger(__name__) + + +class WorkerException(Exception): + """Exception for errors relating to service workers + """ + + +class DispatchedService(object): + + DISPATCHER_NAMESPACE = 'cerberus.dispatcher' + + def start(self): + super(DispatchedService, self).start() + LOG.debug(_('loading dispatchers from %s'), + self.DISPATCHER_NAMESPACE) + self.dispatcher_manager = named.NamedExtensionManager( + namespace=self.DISPATCHER_NAMESPACE, + names=cfg.CONF.dispatcher, + invoke_on_load=True, + invoke_args=[cfg.CONF]) + if not list(self.dispatcher_manager): + LOG.warning(_('Failed to load any dispatchers for %s'), + self.DISPATCHER_NAMESPACE) + + +def get_workers(name): + workers = (cfg.CONF.get('%s_workers' % name) or + utils.cpu_count()) + if workers and workers < 1: + msg = (_("%(worker_name)s value of %(workers)s is invalid, " + "must be greater than 0") % + {'worker_name': '%s_workers' % name, 'workers': str(workers)}) + raise WorkerException(msg) + return workers + + +def prepare_service(argv=None): + rpc.set_defaults(control_exchange='cerberus') + cfg.set_defaults(log.log_opts, + default_log_levels=['amqplib=WARN', + 'qpid.messaging=INFO', + 'sqlalchemy=WARN', + 'keystoneclient=INFO', + 'stevedore=INFO', + 'eventlet.wsgi.server=WARN', + 'iso8601=WARN' + ]) + if argv is None: + argv = sys.argv + cfg.CONF(argv[1:], project='cerberus') + log.setup('cerberus') diff --git a/cerberus/tests/__init__.py b/cerberus/tests/__init__.py index e69de29..73ca62b 100644 --- a/cerberus/tests/__init__.py +++ b/cerberus/tests/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/cerberus/tests/api/__init__.py b/cerberus/tests/api/__init__.py new file mode 100644 index 0000000..73ca62b --- /dev/null +++ b/cerberus/tests/api/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/cerberus/tests/api/base.py b/cerberus/tests/api/base.py new file mode 100644 index 0000000..85b4b5a --- /dev/null +++ b/cerberus/tests/api/base.py @@ -0,0 +1,208 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from oslo.config import cfg +import pecan.testing + +from cerberus.api import auth +from cerberus.db import api as dbapi +from cerberus.tests import base + + +PATH_PREFIX = '/v1' + + +class TestApiBase(base.TestBase): + + def setUp(self): + super(TestApiBase, self).setUp() + self.app = self._make_app() + self.dbapi = dbapi.get_instance() + cfg.CONF.set_override("auth_version", + "v2.0", + group=auth.OPT_GROUP_NAME) + + def _make_app(self, enable_acl=False): + + root_dir = self.path_get() + + self.config = { + 'app': { + 'root': 'cerberus.api.root.RootController', + 'modules': ['cerberus.api'], + 'static_root': '%s/public' % root_dir, + 'template_path': '%s/api/templates' % root_dir, + 'enable_acl': enable_acl, + 'acl_public_routes': ['/', '/v1', '/security_reports'] + }, + } + return pecan.testing.load_test_app(self.config) + + def _request_json(self, path, params, expect_errors=False, headers=None, + method="post", extra_environ=None, status=None, + path_prefix=PATH_PREFIX): + """Sends simulated HTTP request to Pecan test app. + + :param path: url path of target service + :param params: content for wsgi.input of request + :param expect_errors: Boolean value; whether an error is expected based + on request + :param headers: a dictionary of headers to send along with the request + :param method: Request method type. Appropriate method function call + should be used rather than passing attribute in. 
+ :param extra_environ: a dictionary of environ variables to send along + with the request + :param status: expected status code of response + :param path_prefix: prefix of the url path + """ + full_path = path_prefix + path + print('%s: %s %s' % (method.upper(), full_path, params)) + response = getattr(self.app, "%s_json" % method)( + str(full_path), + params=params, + headers=headers, + status=status, + extra_environ=extra_environ, + expect_errors=expect_errors + ) + print('GOT:%s' % response) + return response + + def put_json(self, path, params, expect_errors=False, headers=None, + extra_environ=None, status=None): + """Sends simulated HTTP PUT request to Pecan test app. + + :param path: url path of target service + :param params: content for wsgi.input of request + :param expect_errors: Boolean value; whether an error is expected based + on request + :param headers: a dictionary of headers to send along with the request + :param extra_environ: a dictionary of environ variables to send along + with the request + :param status: expected status code of response + """ + return self._request_json(path=path, params=params, + expect_errors=expect_errors, + headers=headers, extra_environ=extra_environ, + status=status, method="put") + + def post_json(self, path, params, expect_errors=False, headers=None, + extra_environ=None, status=None): + """Sends simulated HTTP POST request to Pecan test app. + + :param path: url path of target service + :param params: content for wsgi.input of request + :param expect_errors: Boolean value; whether an error is expected based + on request + :param headers: a dictionary of headers to send along with the request + :param extra_environ: a dictionary of environ variables to send along + with the request + :param status: expected status code of response + """ + return self._request_json(path=path, params=params, + expect_errors=expect_errors, + headers=headers, extra_environ=extra_environ, + status=status, method="post") + + def patch_json(self, path, params, expect_errors=False, headers=None, + extra_environ=None, status=None): + """Sends simulated HTTP PATCH request to Pecan test app. + + :param path: url path of target service + :param params: content for wsgi.input of request + :param expect_errors: Boolean value; whether an error is expected based + on request + :param headers: a dictionary of headers to send along with the request + :param extra_environ: a dictionary of environ variables to send along + with the request + :param status: expected status code of response + """ + return self._request_json(path=path, params=params, + expect_errors=expect_errors, + headers=headers, extra_environ=extra_environ, + status=status, method="patch") + + def delete(self, path, expect_errors=False, headers=None, + extra_environ=None, status=None, path_prefix=PATH_PREFIX): + """Sends simulated HTTP DELETE request to Pecan test app. 
+
+        :param path: url path of target service
+        :param expect_errors: Boolean value; whether an error is expected
+            based on request
+        :param headers: a dictionary of headers to send along with the
+            request
+        :param extra_environ: a dictionary of environ variables to send
+            along with the request
+        :param status: expected status code of response
+        :param path_prefix: prefix of the url path
+        """
+        full_path = path_prefix + path
+        print('DELETE: %s' % full_path)
+        response = self.app.delete(str(full_path),
+                                   headers=headers,
+                                   status=status,
+                                   extra_environ=extra_environ,
+                                   expect_errors=expect_errors)
+        print('GOT:%s' % response)
+        return response
+
+    def get_json(self, path, expect_errors=False, headers=None,
+                 extra_environ=None, q=[], path_prefix=PATH_PREFIX,
+                 **params):
+        """Sends simulated HTTP GET request to Pecan test app.
+
+        :param path: url path of target service
+        :param expect_errors: Boolean value; whether an error is expected
+            based on request
+        :param headers: a dictionary of headers to send along with the
+            request
+        :param extra_environ: a dictionary of environ variables to send
+            along with the request
+        :param q: list of queries consisting of: field, value, op, and type
+            keys
+        :param path_prefix: prefix of the url path
+        :param params: content for wsgi.input of request
+        """
+        full_path = path_prefix + path
+        query_params = {'q.field': [],
+                        'q.value': [],
+                        'q.op': [],
+                        }
+        for query in q:
+            for name in ['field', 'op', 'value']:
+                query_params['q.%s' % name].append(query.get(name, ''))
+        all_params = {}
+        all_params.update(params)
+        if q:
+            all_params.update(query_params)
+        print('GET: %s %r' % (full_path, all_params))
+        response = self.app.get(full_path,
+                                params=all_params,
+                                headers=headers,
+                                extra_environ=extra_environ,
+                                expect_errors=expect_errors)
+        if not expect_errors:
+            response = response.json
+        print('GOT:%s' % response)
+        return response
+
+    def validate_link(self, link):
+        """Checks if the given link can get correct data."""
+
+        # removes the 'http://localhost' part
+        full_path = link.split('localhost', 1)[1]
+        try:
+            self.get_json(full_path, path_prefix='')
+            return True
+        except Exception:
+            return False
diff --git a/cerberus/tests/api/utils.py b/cerberus/tests/api/utils.py
new file mode 100644
index 0000000..574b56c
--- /dev/null
+++ b/cerberus/tests/api/utils.py
@@ -0,0 +1,66 @@
+#
+# Copyright (c) 2015 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Utils for testing the API service.
+""" + +import datetime +import json + +ADMIN_TOKEN = '4562138218392831' +MEMBER_TOKEN = '4562138218392832' + + +class FakeMemcache(object): + """Fake cache that is used for keystone tokens lookup.""" + + _cache = { + 'tokens/%s' % ADMIN_TOKEN: { + 'access': { + 'token': {'id': ADMIN_TOKEN, + 'expires': '2100-09-11T00:00:00'}, + 'user': {'id': 'user_id1', + 'name': 'user_name1', + 'tenantId': '123i2910', + 'tenantName': 'mytenant', + 'roles': [{'name': 'admin'}]}, + } + }, + 'tokens/%s' % MEMBER_TOKEN: { + 'access': { + 'token': {'id': MEMBER_TOKEN, + 'expires': '2100-09-11T00:00:00'}, + 'user': {'id': 'user_id2', + 'name': 'user-good', + 'tenantId': 'project-good', + 'tenantName': 'goodies', + 'roles': [{'name': 'Member'}]} + } + } + } + + def __init__(self): + self.set_key = None + self.set_value = None + self.token_expiration = None + + def get(self, key): + dt = datetime.datetime.utcnow() + datetime.timedelta(minutes=5) + return json.dumps((self._cache.get(key), dt.isoformat())) + + def set(self, key, value, time=0, min_compress_len=0): + self.set_value = value + self.set_key = key diff --git a/cerberus/tests/api/v1/__init__.py b/cerberus/tests/api/v1/__init__.py new file mode 100644 index 0000000..b06b406 --- /dev/null +++ b/cerberus/tests/api/v1/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2015 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/cerberus/tests/api/v1/test_plugins.py b/cerberus/tests/api/v1/test_plugins.py new file mode 100644 index 0000000..f33bbe4 --- /dev/null +++ b/cerberus/tests/api/v1/test_plugins.py @@ -0,0 +1,106 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import json +import mock +from sqlalchemy import exc + +from oslo import messaging + +from cerberus import db +from cerberus.tests.api import base +from cerberus.tests.db import utils as db_utils + + +PLUGIN_ID_1 = 1 +PLUGIN_ID_2 = 2 +PLUGIN_NAME_2 = 'toolyx' + + +class TestPlugins(base.TestApiBase): + + def setUp(self): + super(TestPlugins, self).setUp() + self.fake_plugin = db_utils.get_test_plugin( + id=PLUGIN_ID_1 + ) + self.fake_plugins = [] + self.fake_plugins.append(self.fake_plugin) + self.fake_plugins.append(db_utils.get_test_plugin( + id=PLUGIN_ID_2, + name=PLUGIN_NAME_2 + )) + self.fake_plugin_model = db_utils.get_plugin_model( + id=PLUGIN_ID_1 + ) + self.fake_plugins_model = [] + self.fake_plugins_model.append( + self.fake_plugin_model) + self.fake_plugins_model.append( + db_utils.get_plugin_model( + id=PLUGIN_ID_2, + name=PLUGIN_NAME_2 + ) + ) + self.plugins_path = '/plugins' + self.plugin_path = '/plugins/%s' % self.fake_plugin['uuid'] + + def test_list(self): + + rpc_plugins = [] + for plugin in self.fake_plugins: + rpc_plugins.append(json.dumps(plugin)) + + messaging.RPCClient.call = mock.MagicMock( + return_value=rpc_plugins) + db.plugins_info_get = mock.MagicMock( + return_value=self.fake_plugins_model) + + plugins = self.get_json(self.plugins_path) + self.assertEqual({'plugins': self.fake_plugins}, + plugins) + + def test_get(self): + rpc_plugin = json.dumps(self.fake_plugin) + messaging.RPCClient.call = mock.MagicMock(return_value=rpc_plugin) + db.plugin_info_get_from_uuid = mock.MagicMock( + return_value=self.fake_plugin_model) + plugin = self.get_json(self.plugin_path) + self.assertEqual({'plugin': self.fake_plugin}, plugin) + + def test_list_plugins_remote_error(self): + messaging.RPCClient.call = mock.MagicMock( + side_effect=messaging.RemoteError) + res = self.get_json(self.plugins_path, expect_errors=True) + self.assertEqual(503, res.status_code) + + def test_get_plugin_not_existing(self): + messaging.RPCClient.call = mock.MagicMock( + side_effect=messaging.RemoteError) + res = self.get_json(self.plugin_path, expect_errors=True) + self.assertEqual(503, res.status_code) + + def test_list_plugins_db_error(self): + messaging.RPCClient.call = mock.MagicMock(return_value=None) + db.plugins_info_get = mock.MagicMock(side_effect=exc.OperationalError) + res = self.get_json(self.plugins_path, expect_errors=True) + self.assertEqual(404, res.status_code) + + def test_get_plugin_remote_error(self): + messaging.RPCClient.call = mock.MagicMock( + side_effect=messaging.RemoteError) + res = self.get_json(self.plugin_path, expect_errors=True) + self.assertEqual(503, res.status_code) diff --git a/cerberus/tests/api/v1/test_security_reports.py b/cerberus/tests/api/v1/test_security_reports.py new file mode 100644 index 0000000..e2652aa --- /dev/null +++ b/cerberus/tests/api/v1/test_security_reports.py @@ -0,0 +1,95 @@ +# +# Copyright (c) 2015 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
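The tests above stub the RPC layer by assigning a MagicMock directly to `messaging.RPCClient.call`, which leaks into later tests unless it is reassigned. A sketch of the same stub using `mock.patch.object`, which restores the original attribute automatically (assumes the mock library and oslo.messaging are importable):

```python
import mock
from oslo import messaging

with mock.patch.object(messaging.RPCClient, 'call',
                       return_value='{"id": 1}'):
    # Any code under test that calls RPCClient.call now gets the canned
    # value instead of performing a real RPC round trip.
    assert messaging.RPCClient.call(None, {}, 'get_task') == '{"id": 1}'
# On exit, the real RPCClient.call is back in place.
```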
+# + +import mock +from sqlalchemy import exc as sql_exc + +from cerberus import db +from cerberus.tests.api import base +from cerberus.tests.db import utils as db_utils + + +def get_tasks(): + tasks = [] + return tasks + + +def get_task(): + task = {} + return task + + +SECURITY_REPORT_ID = 'abc123' +SECURITY_REPORT_ID_2 = 'xyz789' + + +class TestSecurityReports(base.TestApiBase): + + def setUp(self): + super(TestSecurityReports, self).setUp() + self.fake_security_report = db_utils.get_test_security_report( + id=SECURITY_REPORT_ID + ) + self.fake_security_reports = [] + self.fake_security_reports.append(self.fake_security_report) + self.fake_security_reports.append(db_utils.get_test_security_report( + id=SECURITY_REPORT_ID_2 + )) + self.fake_security_report_model = db_utils.get_security_report_model( + id=SECURITY_REPORT_ID + ) + self.fake_security_reports_model = [] + self.fake_security_reports_model.append( + self.fake_security_report_model) + self.fake_security_reports_model.append( + db_utils.get_security_report_model( + id=SECURITY_REPORT_ID_2 + ) + ) + self.security_reports_path = '/security_reports' + self.security_report_path = '/security_reports/%s' \ + % self.fake_security_report['report_id'] + + def test_get(self): + + db.security_report_get = mock.MagicMock( + return_value=self.fake_security_report_model) + security_report = self.get_json(self.security_report_path) + self.assertEqual({'security_report': self.fake_security_report}, + security_report) + + def test_list(self): + + db.security_report_get_all = mock.MagicMock( + return_value=self.fake_security_reports_model) + + security_reports = self.get_json(self.security_reports_path) + + self.assertEqual({'security_reports': self.fake_security_reports}, + security_reports) + + def test_get_sreports_db_error(self): + db.security_report_get_all = mock.MagicMock( + side_effect=sql_exc.NoSuchTableError) + + res = self.get_json(self.security_reports_path, expect_errors=True) + self.assertEqual(404, res.status_code) + + def test_get_sreport_db_error(self): + db.security_report_get = mock.MagicMock( + side_effect=sql_exc.OperationalError) + res = self.get_json(self.security_report_path, expect_errors=True) + self.assertEqual(404, res.status_code) diff --git a/cerberus/tests/api/v1/test_tasks.py b/cerberus/tests/api/v1/test_tasks.py new file mode 100644 index 0000000..95dfc1d --- /dev/null +++ b/cerberus/tests/api/v1/test_tasks.py @@ -0,0 +1,221 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
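The task tests that follow drive the /tasks resource with JSON bodies; the recurrent-task payload they rely on has this shape (values are the illustrative ones used in the tests):

    # Shape of the payload POSTed to /tasks in the tests below.
    valid_recurrent_task = {
        'task': {
            'method': 'act_long',      # plugin method to schedule
            'name': 'task1',
            'type': 'recurrent',
            'period': 60,              # seconds between executions
            'plugin_id': 'plugin-test',
        }
    }
    # Dropping 'method', 'period' or 'plugin_id', or sending a body that
    # is not valid JSON, is expected to produce HTTP 400.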
+#
+
+import json
+import mock
+
+from oslo import messaging
+
+from cerberus.tests.api import base
+from cerberus.tests.db import utils as db_utils
+
+
+class TestTasks(base.TestApiBase):
+
+    def setUp(self):
+        super(TestTasks, self).setUp()
+        self.fake_task = db_utils.get_test_task()
+        self.fake_tasks = []
+        self.fake_tasks.append(self.fake_task)
+        self.fake_tasks.append(db_utils.get_test_task(
+            task_id=2,
+            task_type='recurrent',
+            task_name='recurrent_task',
+            task_period=20
+        ))
+        self.tasks_path = '/tasks'
+        self.task_path = '/tasks/%s' % self.fake_task['task_id']
+
+    def test_list(self):
+        rpc_tasks = []
+        for task in self.fake_tasks:
+            rpc_tasks.append(json.dumps(task))
+
+        messaging.RPCClient.call = mock.MagicMock(return_value=rpc_tasks)
+        tasks = self.get_json(self.tasks_path)
+        self.assertEqual({'tasks': self.fake_tasks}, tasks)
+
+    def test_create(self):
+        task_id = 1
+        task = {
+            "method": "act_long",
+            "name": "task1",
+            "type": "recurrent",
+            "period": 60,
+            "plugin_id": "test"
+        }
+        expected_task = task
+        expected_task['id'] = task_id
+        messaging.RPCClient.call = mock.MagicMock(return_value=task_id)
+        task = self.post_json(self.tasks_path, {'task': task})
+        self.assertEqual({'task': expected_task}, task.json_body)
+
+    def test_get(self):
+        rpc_task = json.dumps(self.fake_task)
+        messaging.RPCClient.call = mock.MagicMock(
+            return_value=rpc_task)
+        task = self.get_json(self.task_path)
+        self.assertEqual({'task': self.fake_task}, task)
+
+    def test_stop(self):
+        request_body = {'stop': 'null'}
+        messaging.RPCClient.call = mock.MagicMock(return_value=1)
+        response = self.post_json(self.task_path, request_body)
+        self.assertEqual(200, response.status_code)
+
+    def test_delete(self):
+        messaging.RPCClient.call = mock.MagicMock(return_value=1)
+        response = self.delete(self.task_path)
+        self.assertEqual(200, response.status_code)
+
+    def test_list_tasks_remote_error(self):
+        messaging.RPCClient.call = mock.MagicMock(
+            side_effect=messaging.RemoteError)
+        response = self.get_json(self.task_path, expect_errors=True)
+        self.assertEqual(404, response.status_code)
+
+    def test_create_task_incorrect_json(self):
+        request_body = "INCORRECT JSON"
+        response = self.post_json(self.tasks_path,
+                                  request_body,
+                                  expect_errors=True)
+        self.assertEqual(400, response.status_code)
+
+    def test_create_recurrent_task_without_task_object(self):
+        task_id = 1
+        request_body = {
+            "method": "act_long",
+            "name": "task1",
+            "type": "recurrent",
+        }
+        messaging.RPCClient.call = mock.MagicMock(return_value=task_id)
+        response = self.post_json(self.tasks_path,
+                                  request_body,
+                                  expect_errors=True)
+        self.assertEqual(400, response.status_code)
+
+    def test_create_recurrent_task_without_plugin_id(self):
+        task_id = 1
+        task = {
+            "method": "act_long",
+            "name": "task1",
+            "type": "recurrent",
+            "period": 60,
+        }
+        request_body = {'task': task}
+        messaging.RPCClient.call = mock.MagicMock(return_value=task_id)
+        response = self.post_json(self.tasks_path,
+                                  request_body,
+                                  expect_errors=True)
+        self.assertEqual(400, response.status_code)
+
+    def test_create_recurrent_task_without_method(self):
+        task_id = 1
+        task = {
+            "name": "task1",
+            "type": "recurrent",
+            "period": 60,
+            "plugin_id": "plugin-test"
+        }
+        request_body = {'task': task}
+        messaging.RPCClient.call = mock.MagicMock(return_value=task_id)
+        response = self.post_json(self.tasks_path,
+                                  request_body,
+                                  expect_errors=True)
+        self.assertEqual(400, response.status_code)
+
+    def test_create_recurrent_task_remote_error(self):
+        task = {
+            "method":
"act_long", + "name": "task1", + "type": "recurrent", + "period": 60, + "plugin_id": "plugin-test" + } + request_body = {'task': task} + messaging.RPCClient.call = mock.MagicMock( + side_effect=messaging.RemoteError) + response = self.post_json(self.tasks_path, + request_body, + expect_errors=True) + self.assertEqual(400, response.status_code) + + def test_get_task_bad_id(self): + response = self.get_json('/tasks/toto', expect_errors=True) + self.assertEqual(400, response.status_code) + + def test_get_task_remote_error(self): + messaging.RPCClient.call = mock.MagicMock( + side_effect=messaging.RemoteError) + response = self.get_json(self.task_path, expect_errors=True) + self.assertEqual(404, response.status_code) + + def test_stop_task_wrong_json(self): + request_body = "INCORRECT JSON" + response = self.post_json(self.task_path, + request_body, + expect_errors=True) + self.assertEqual(400, response.status_code) + + def test_stop_task_wrong_id(self): + request_body = json.dumps({ + "stop": "null" + }) + messaging.RPCClient.call = mock.MagicMock( + side_effect=messaging.RemoteError) + response = self.post_json(self.task_path, + request_body, + expect_errors=True) + self.assertEqual(400, response.status_code) + + def test_stop_task_id_not_integer(self): + request_body = json.dumps({ + "stop": "null" + }) + response = self.post_json('/tasks/toto', + request_body, + expect_errors=True) + self.assertEqual(400, response.status_code) + + def test_force_delete_task_wrong_id(self): + request_body = json.dumps({ + "forceDelete": "null" + }) + messaging.RPCClient.call = mock.MagicMock( + side_effect=messaging.RemoteError) + response = self.post_json(self.task_path, + request_body, + expect_errors=True) + self.assertEqual(400, response.status_code) + + def test_force_delete_task_id_not_integer(self): + request_body = json.dumps({ + "forceDelete": "null" + }) + response = self.post_json('/tasks/toto', + request_body, + expect_errors=True) + self.assertEqual(400, response.status_code) + + def test_delete_task_id_not_integer(self): + response = self.delete('/tasks/toto', expect_errors=True) + self.assertEqual(400, response.status_code) + + def test_delete_task_not_existing(self): + messaging.RPCClient.call = mock.MagicMock( + side_effect=messaging.RemoteError) + response = self.delete(self.task_path, expect_errors=True) + self.assertEqual(400, response.status_code) diff --git a/cerberus/tests/base.py b/cerberus/tests/base.py index 1c30cdb..57bd8ef 100644 --- a/cerberus/tests/base.py +++ b/cerberus/tests/base.py @@ -1,23 +1,56 @@ -# -*- coding: utf-8 -*- - -# Copyright 2010-2011 OpenStack Foundation -# Copyright (c) 2013 Hewlett-Packard Development Company, L.P. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +from oslo.config import cfg from oslotest import base +from cerberus.tests import config_fixture +from cerberus.tests import policy_fixture +from cerberus.tests import utils -class TestCase(base.BaseTestCase): + +CONF = cfg.CONF + + +class TestBase(base.BaseTestCase): """Test case base class for all unit tests.""" + def setUp(self): + super(TestBase, self).setUp() + utils.setup_dummy_db() + self.addCleanup(utils.reset_dummy_db) + self.useFixture(config_fixture.ConfigFixture(CONF)) + self.policy = self.useFixture(policy_fixture.PolicyFixture()) + + def path_get(self, project_file=None): + """Get the absolute path to a file. Used for testing the API. + :param project_file: File whose path to return. Default: None. + :returns: path to the specified file, or path to project root. + """ + root = os.path.abspath(os.path.join(os.path.dirname(__file__), + '..', + '..', + ) + ) + if project_file: + return os.path.join(root, project_file) + else: + return root + + +class TestBaseFaulty(TestBase): + """This test ensures we aren't letting any exceptions go unhandled.""" diff --git a/cerberus/tests/client/__init__.py b/cerberus/tests/client/__init__.py new file mode 100644 index 0000000..4aa294d --- /dev/null +++ b/cerberus/tests/client/__init__.py @@ -0,0 +1 @@ +__author__ = 'svcdev' diff --git a/cerberus/tests/client/test_keystone_client.py b/cerberus/tests/client/test_keystone_client.py new file mode 100644 index 0000000..0e077c1 --- /dev/null +++ b/cerberus/tests/client/test_keystone_client.py @@ -0,0 +1,59 @@ +# +# Copyright (c) 2015 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
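The keystone client tests below patch the real keystoneclient v2 constructor away and stub methods directly on the wrapper's inner client attribute; roughly:

    import mock
    from cerberus.client import keystone_client

    # A sketch of the wrapper-level stubbing used in the tests below.
    with mock.patch('keystoneclient.v2_0.client.Client'):
        kc = keystone_client.Client()
        kc.keystone_client_v2_0.users.get = mock.MagicMock(
            return_value={'user': {'id': 'u1000'}})
        assert kc.user_detail_get('user')['user']['id'] == 'u1000'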
+# + +import mock + +from oslo.config import cfg + +from cerberus.client import keystone_client +from cerberus.tests import base + +cfg.CONF.import_group('service_credentials', 'cerberus.service') + + +class TestKeystoneClient(base.TestBase): + + def setUp(self): + super(TestKeystoneClient, self).setUp() + + @staticmethod + def fake_get_user(): + return { + 'user': { + "id": "u1000", + "name": "jqsmith", + "email": "john.smith@example.org", + "enabled": True + } + } + + @mock.patch('keystoneclient.v2_0.client.Client') + def test_get_user(self, mock_client): + kc = keystone_client.Client() + user = self.fake_get_user() + kc.keystone_client_v2_0.users.get = mock.MagicMock( + return_value=user) + user = kc.user_detail_get("user") + self.assertEqual("u1000", user['user'].get('id')) + + @mock.patch('keystoneclient.v2_0.client.Client') + def test_roles_for_user(self, mock_client): + kc = keystone_client.Client() + kc.keystone_client_v2_0.roles.roles_for_user = mock.MagicMock( + return_value="role" + ) + role = kc.roles_for_user("user", "tenant") + self.assertEqual("role", role) diff --git a/cerberus/tests/client/test_neutron_client.py b/cerberus/tests/client/test_neutron_client.py new file mode 100644 index 0000000..865c952 --- /dev/null +++ b/cerberus/tests/client/test_neutron_client.py @@ -0,0 +1,212 @@ +# +# Copyright (c) 2015 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
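Among the neutron wrapper tests below, test_list_associated_floatingips checks that only floating IPs bound to a port come back; the filtering it implies can be pictured as follows (a sketch, not the wrapper's actual implementation):

    def associated_floatingips(floatingips):
        # Keep only floating IPs attached to a port, which is what
        # distinguishes the two fixtures in fake_floating_ips_list().
        return [fip for fip in floatingips if fip.get('port_id') is not None]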
+# + +import mock + +from oslo.config import cfg + +from cerberus.client import neutron_client +from cerberus.tests import base + +cfg.CONF.import_group('service_credentials', 'cerberus.service') + + +class TestNeutronClient(base.TestBase): + + def setUp(self): + super(TestNeutronClient, self).setUp() + + @staticmethod + def fake_networks_list(): + return {'networks': + [{'admin_state_up': True, + 'id': '298a3088-a446-4d5a-bad8-f92ecacd786b', + 'name': 'public', + 'provider:network_type': 'gre', + 'provider:physical_network': None, + 'provider:segmentation_id': 2, + 'router:external': True, + 'shared': False, + 'status': 'ACTIVE', + 'subnets': [u'c4b6f5b8-3508-4896-b238-a441f25fb492'], + 'tenant_id': '62d6f08bbd3a44f6ad6f00ca15cce4e5'}, + ]} + + @staticmethod + def fake_network_get(): + return {"network": { + "status": "ACTIVE", + "subnets": [ + "54d6f61d-db07-451c-9ab3-b9609b6b6f0b"], + "name": "private-network", + "provider:physical_network": None, + "admin_state_up": True, + "tenant_id": "4fd44f30292945e481c7b8a0c8908869", + "provider:network_type": "local", + "router:external": True, + "shared": True, + "id": "d32019d3-bc6e-4319-9c1d-6722fc136a22", + "provider:segmentation_id": None + } + } + + @staticmethod + def fake_subnets_list(): + return {"subnets": [ + { + "name": "private-subnet", + "enable_dhcp": True, + "network_id": "db193ab3-96e3-4cb3-8fc5-05f4296d0324", + "tenant_id": "26a7980765d0414dbc1fc1f88cdb7e6e", + "dns_nameservers": [], + "allocation_pools": [ + { + "start": "10.0.0.2", + "end": "10.0.0.254" + } + ], + "host_routes": [], + "ip_version": 4, + "gateway_ip": "10.0.0.1", + "cidr": "10.0.0.0/24", + "id": "08eae331-0402-425a-923c-34f7cfe39c1b"}, + { + "name": "my_subnet", + "enable_dhcp": True, + "network_id": "d32019d3-bc6e-4319-9c1d-6722fc136a22", + "tenant_id": "4fd44f30292945e481c7b8a0c8908869", + "dns_nameservers": [], + "allocation_pools": [ + { + "start": "192.0.0.2", + "end": "192.255.255.254" + } + ], + "host_routes": [], + "ip_version": 4, + "gateway_ip": "192.0.0.1", + "cidr": "192.0.0.0/8", + "id": "54d6f61d-db07-451c-9ab3-b9609b6b6f0b" + } + ] + } + + @staticmethod + def fake_subnet_get(): + return {"subnet": { + "name": "my_subnet", + "enable_dhcp": True, + "network_id": "d32019d3-bc6e-4319-9c1d-6722fc136a22", + "tenant_id": "4fd44f30292945e481c7b8a0c8908869", + "dns_nameservers": [], + "allocation_pools": [ + { + "start": "192.0.0.2", + "end": "192.255.255.254" + }], + "host_routes": [], + "ip_version": 4, + "gateway_ip": "192.0.0.1", + "cidr": "192.0.0.0/8", + "id": "54d6f61d-db07-451c-9ab3-b9609b6b6f0b" + } + } + + @staticmethod + def fake_floating_ips_list(): + return {'floatingips': [ + { + 'router_id': 'd23abc8d-2991-4a55-ba98-2aaea84cc72f', + 'tenant_id': '4969c491a3c74ee4af974e6d800c62de', + 'floating_network_id': '376da547-b977-4cfe-9cba-275c80debf57', + 'fixed_ip_address': '10.0.0.3', + 'floating_ip_address': '172.24.4.228', + 'port_id': 'ce705c24-c1ef-408a-bda3-7bbd946164ab', + 'id': '2f245a7b-796b-4f26-9cf9-9e82d248fda7'}, + { + 'router_id': None, + 'tenant_id': '4969c491a3c74ee4af974e6d800c62de', + 'floating_network_id': '376da547-b977-4cfe-9cba-275c80debf57', + 'fixed_ip_address': None, + 'floating_ip_address': '172.24.4.227', + 'port_id': None, + 'id': '61cea855-49cb-4846-997d-801b70c71bdd' + } + ]} + + @mock.patch('neutronclient.v2_0.client.Client') + def test_list_networks(self, mock_client): + nc = neutron_client.Client() + nc.neutronClient.list_networks = mock.MagicMock( + return_value=self.fake_networks_list()) + networks = 
nc.list_networks('tenant') + self.assertTrue(len(networks) == 1) + self.assertEqual('298a3088-a446-4d5a-bad8-f92ecacd786b', + networks[0].get('id')) + + @mock.patch('neutronclient.v2_0.client.Client') + def test_list_floatingips(self, mock_client): + nc = neutron_client.Client() + nc.neutronClient.list_floatingips = mock.MagicMock( + return_value=self.fake_floating_ips_list()) + floating_ips = nc.list_floatingips('tenant') + self.assertTrue(len(floating_ips) == 2) + self.assertEqual('2f245a7b-796b-4f26-9cf9-9e82d248fda7', + floating_ips[0].get('id')) + self.assertEqual('61cea855-49cb-4846-997d-801b70c71bdd', + floating_ips[1].get('id')) + + @mock.patch('neutronclient.v2_0.client.Client') + def test_list_associated_floatingips(self, mock_client): + nc = neutron_client.Client() + nc.neutronClient.list_floatingips = mock.MagicMock( + return_value=self.fake_floating_ips_list()) + floating_ips = nc.list_associated_floatingips() + self.assertTrue(len(floating_ips) == 1) + self.assertEqual('2f245a7b-796b-4f26-9cf9-9e82d248fda7', + floating_ips[0].get('id')) + + @mock.patch('neutronclient.v2_0.client.Client') + def test_subnet_ips_get(self, mock_client): + nc = neutron_client.Client() + nc.neutronClient.show_subnet = mock.MagicMock( + return_value=self.fake_subnet_get()) + subnet_ips = nc.subnet_ips_get("d32019d3-bc6e-4319-9c1d-6722fc136a22") + self.assertTrue(len(subnet_ips) == 1) + self.assertEqual("192.0.0.2", subnet_ips[0].get("start", None)) + self.assertEqual("192.255.255.254", subnet_ips[0].get("end", None)) + + @mock.patch('neutronclient.v2_0.client.Client') + def test_net_ips_get(self, mock_client): + nc = neutron_client.Client() + nc.neutronClient.show_network = mock.MagicMock( + return_value=self.fake_network_get()) + nc.neutronClient.show_subnet = mock.MagicMock( + return_value=self.fake_subnet_get()) + ips = nc.net_ips_get("d32019d3-bc6e-4319-9c1d-6722fc136a22") + self.assertTrue(len(ips) == 1) + self.assertTrue(len(ips[0]) == 1) + self.assertEqual("192.0.0.2", ips[0][0].get("start", None)) + self.assertEqual("192.255.255.254", ips[0][0].get("end", None)) + + @mock.patch('neutronclient.v2_0.client.Client') + def test_get_net_of_subnet(self, mock_client): + nc = neutron_client.Client() + nc.neutronClient.show_subnet = mock.MagicMock( + return_value=self.fake_subnet_get()) + network_id = nc.get_net_of_subnet( + "54d6f61d-db07-451c-9ab3-b9609b6b6f0b") + self.assertEqual("d32019d3-bc6e-4319-9c1d-6722fc136a22", network_id) diff --git a/cerberus/tests/client/test_nova_client.py b/cerberus/tests/client/test_nova_client.py new file mode 100644 index 0000000..e2d6f51 --- /dev/null +++ b/cerberus/tests/client/test_nova_client.py @@ -0,0 +1,130 @@ +# +# Copyright (c) 2015 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
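The nova client tests below build MagicMock servers whose addresses dicts mimic the compute API; get_instance_details_from_floating_ip presumably scans those entries for a matching floating address, along these lines (a sketch under that assumption, with a hypothetical helper name):

    def find_by_floating_ip(servers, ip):
        # servers: objects with an `addresses` mapping of network name to
        # address entries, as built in fake_servers_list() below.
        for server in servers:
            for entries in getattr(server, 'addresses', {}).values():
                for entry in entries:
                    if (entry.get('OS-EXT-IPS:type') == 'floating'
                            and entry.get('addr') == ip):
                        return server
        return None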
+# + +import mock + +from oslo.config import cfg + +from cerberus.client import nova_client +from cerberus.tests import base + +cfg.CONF.import_group('service_credentials', 'cerberus.service') + + +class TestNovaClient(base.TestBase): + + @staticmethod + def fake_servers_list(*args, **kwargs): + a = mock.MagicMock() + a.id = 42 + a.flavor = {'id': 1} + a.image = {'id': 1} + a_addresses = [] + a_addresses.append({"addr": "10.0.0.1", "version": 4, + 'OS-EXT-IPS:type': 'floating'}) + a.addresses = {'private': a_addresses} + b = mock.MagicMock() + b.id = 43 + b.flavor = {'id': 2} + b.image = {'id': 2} + return [a, b] + + @staticmethod + def fake_detailed_servers_list(): + return \ + {"servers": [ + { + "accessIPv4": "", + "accessIPv6": "", + "addresses": { + "private": [ + { + "addr": "192.168.0.3", + "version": 4 + } + ] + }, + "created": "2012-09-07T16:56:37Z", + "flavor": { + "id": "1", + "links": [ + { + "href": "http://openstack.example.com/" + "openstack/flavors/1", + "rel": "bookmark" + } + ] + }, + "hostId": "16d193736a5cfdb60c697ca27ad071d6126fa13baeb670f" + "c9d10645e", + "id": "05184ba3-00ba-4fbc-b7a2-03b62b884931", + "image": { + "id": "70a599e0-31e7-49b7-b260-868f441e862b", + "links": [ + { + "href": "http://openstack.example.com/" + "openstack/images/70a599e0-31e7-49b7-" + "b260-868f441e862b", + "rel": "bookmark" + } + ] + }, + "links": [ + { + "href": "http://openstack.example.com/v2/" + "openstack/servers/05184ba3-00ba-4fbc-" + "b7a2-03b62b884931", + "rel": "self" + }, + { + "href": "http://openstack.example.com/openstack/" + "servers/05184ba3-00ba-4fbc-b7a2-" + "03b62b884931", + "rel": "bookmark" + } + ], + "metadata": { + "My Server Name": "Apache1" + }, + "name": "new-server-test", + "progress": 0, + "status": "ACTIVE", + "tenant_id": "openstack", + "updated": "2012-09-07T16:56:37Z", + "user_id": "fake" + } + ] + } + + def setUp(self): + super(TestNovaClient, self).setUp() + self.nova_client = nova_client.Client() + + def test_instance_get_all(self): + self.nova_client.nova_client.servers.list = mock.MagicMock( + return_value=self.fake_servers_list()) + instances = self.nova_client.instance_get_all() + self.assertTrue(instances is not None) + + def test_get_instance_details_from_floating_ip(self): + self.nova_client.nova_client.servers.list = mock.MagicMock( + return_value=self.fake_servers_list()) + instance_1 = self.nova_client.get_instance_details_from_floating_ip( + "10.0.0.1") + instance_2 = self.nova_client.get_instance_details_from_floating_ip( + "10.0.0.2") + self.assertTrue(instance_1 is not None) + self.assertTrue(instance_2 is None) diff --git a/cerberus/tests/config_fixture.py b/cerberus/tests/config_fixture.py new file mode 100644 index 0000000..46da1ab --- /dev/null +++ b/cerberus/tests/config_fixture.py @@ -0,0 +1,35 @@ +# +# Copyright (c) 2015 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
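config_fixture.py below packages oslo.config handling as a fixtures.Fixture, so every test gets parsed defaults and an automatic reset. A test case consumes it with useFixture, e.g. (hypothetical test class):

    from oslo.config import cfg
    from oslotest import base

    from cerberus.tests import config_fixture

    class ExampleTest(base.BaseTestCase):
        # Hypothetical test showing how ConfigFixture (defined below) is used.
        def setUp(self):
            super(ExampleTest, self).setUp()
            self.useFixture(config_fixture.ConfigFixture(cfg.CONF))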
+# + +import fixtures +from oslo.config import cfg + +from cerberus.common import config + +CONF = cfg.CONF + + +class ConfigFixture(fixtures.Fixture): + """Fixture to manage global conf settings.""" + + def __init__(self, conf): + self.conf = conf + + def setUp(self): + super(ConfigFixture, self).setUp() + self.conf.set_default('verbose', True) + config.parse_args([], default_config_files=[]) + self.addCleanup(self.conf.reset) diff --git a/cerberus/tests/db/__init__.py b/cerberus/tests/db/__init__.py new file mode 100644 index 0000000..b06b406 --- /dev/null +++ b/cerberus/tests/db/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2015 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/cerberus/tests/db/utils.py b/cerberus/tests/db/utils.py new file mode 100644 index 0000000..54d1f18 --- /dev/null +++ b/cerberus/tests/db/utils.py @@ -0,0 +1,113 @@ +# +# Copyright (c) 2015 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
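db/utils.py below provides factory helpers that pair a plain-dict representation with the matching SQLAlchemy model, every field defaulting unless overridden through kwargs; typical usage:

    from cerberus.tests.db import utils as db_utils

    # Override only the fields a test cares about; everything else defaults.
    report = db_utils.get_test_security_report(id='abc123')
    report_model = db_utils.get_security_report_model(id='abc123')
    plugin_model = db_utils.get_plugin_model(id=2, name='toolyx')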
+# + +import datetime + +from cerberus.db.sqlalchemy import models + + +def get_test_security_report(**kwargs): + return { + 'id': kwargs.get('id', 1), + 'plugin_id': kwargs.get('plugin_id', + '228df8e8-d5f4-4eb9-a547-dfc649dd1017'), + 'report_id': kwargs.get('report_id', '1234'), + 'component_id': kwargs.get('component_id', + '422zb9d5-c5g3-8wy9-a547-hhc885dd8548'), + 'component_type': kwargs.get('component_type', 'instance'), + 'component_name': kwargs.get('component_name', 'instance-test'), + 'project_id': kwargs.get('project_id', + '28c6f9e6add24c29a589a9967432fede'), + 'title': kwargs.get('title', 'test-security-report'), + 'description': kwargs.get('description', + 'no fear, this is just a test'), + 'security_rating': kwargs.get('security_rating', 5), + 'vulnerabilities': kwargs.get('vulnerabilities', 'vulns'), + 'vulnerabilities_number': kwargs.get('vulnerabilities_number', 1), + 'last_report_date': kwargs.get('last_report_date', + '2015-01-01 00:00:00') + } + + +def get_security_report_model(**kwargs): + security_report = models.SecurityReport() + security_report.id = kwargs.get('id', 1) + security_report.plugin_id = kwargs.get( + 'plugin_id', + '228df8e8-d5f4-4eb9-a547-dfc649dd1017' + ) + security_report.report_id = kwargs.get('report_id', '1234') + security_report.component_id = kwargs.get( + 'component_id', + '422zb9d5-c5g3-8wy9-a547-hhc885dd8548') + security_report.component_type = kwargs.get('component_type', 'instance') + security_report.component_name = kwargs.get('component_name', + 'instance-test') + security_report.project_id = kwargs.get('project_id', + '28c6f9e6add24c29a589a9967432fede') + security_report.title = kwargs.get('title', 'test-security-report') + security_report.description = kwargs.get('description', + 'no fear, this is just a test') + security_report.security_rating = kwargs.get('security_rating', 5) + security_report.vulnerabilities = kwargs.get('vulnerabilities', 'vulns') + security_report.vulnerabilities_number = kwargs.get( + 'vulnerabilities_number', 1) + security_report.last_report_date = kwargs.get( + 'last_report_date', + datetime.datetime(2015, 1, 1) + ) + return security_report + + +def get_test_plugin(**kwargs): + return { + 'id': kwargs.get('id', 1), + 'provider': kwargs.get('provider', 'provider'), + 'tool_name': kwargs.get('tool_name', 'toolbox'), + 'type': kwargs.get('type', 'tool_whatever'), + 'description': kwargs.get('description', 'This is a tool'), + 'uuid': kwargs.get('uuid', '490cc562-9e60-46a7-9b5f-c7619aca2e07'), + 'version': kwargs.get('version', '0.1a'), + 'name': kwargs.get('name', 'tooly'), + 'subscribed_events': kwargs.get('subscribed_events', + ["compute.instance.updated"]), + 'methods': kwargs.get('methods', []) + } + + +def get_plugin_model(**kwargs): + plugin = models.PluginInfo() + plugin.id = kwargs.get('id', 1) + plugin.provider = kwargs.get('provider', 'provider') + plugin.tool_name = kwargs.get('tool_name', 'toolbox') + plugin.type = kwargs.get('type', 'tool_whatever') + plugin.description = kwargs.get('description', 'This is a tool') + plugin.uuid = kwargs.get('uuid', '490cc562-9e60-46a7-9b5f-c7619aca2e07') + plugin.version = kwargs.get('version', '0.1a') + plugin.name = kwargs.get('name', 'tooly') + plugin.subscribed_events = kwargs.get('subscribed_events', + ["compute.instance.updated"]) + plugin.methods = kwargs.get('methods', []) + return plugin + + +def get_test_task(**kwargs): + return { + 'task_id': kwargs.get('task_id', 1), + 'task_type': kwargs.get('task_type', 'unique'), + 'task_name': 
kwargs.get('task_name', 'No Name'), + 'task_period': kwargs.get('task_period', ''), + } diff --git a/cerberus/tests/fake_policy.py b/cerberus/tests/fake_policy.py new file mode 100644 index 0000000..64bcecb --- /dev/null +++ b/cerberus/tests/fake_policy.py @@ -0,0 +1,22 @@ +# +# Copyright (c) 2015 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +policy_data = """ +{ + "context_is_admin": "role:admin", + "default": "" +} +""" diff --git a/cerberus/tests/policy_fixture.py b/cerberus/tests/policy_fixture.py new file mode 100644 index 0000000..815f39b --- /dev/null +++ b/cerberus/tests/policy_fixture.py @@ -0,0 +1,45 @@ +# +# Copyright (c) 2015 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import fixtures +import os + +from oslo.config import cfg + +from cerberus.common import policy as cerberus_policy +from cerberus.openstack.common import policy as common_policy +from cerberus.tests import fake_policy + +CONF = cfg.CONF + + +class PolicyFixture(fixtures.Fixture): + + def setUp(self): + super(PolicyFixture, self).setUp() + self.policy_dir = self.useFixture(fixtures.TempDir()) + self.policy_file_name = os.path.join(self.policy_dir.path, + 'policy.json') + with open(self.policy_file_name, 'w') as policy_file: + policy_file.write(fake_policy.policy_data) + CONF.set_override('policy_file', self.policy_file_name) + cerberus_policy._ENFORCER = None + self.addCleanup(cerberus_policy.get_enforcer().clear) + + def set_rules(self, rules): + common_policy.set_rules(common_policy.Rules( + dict((k, common_policy.parse_rule(v)) + for k, v in rules.items()))) diff --git a/cerberus/tests/test_cerberus_manager.py b/cerberus/tests/test_cerberus_manager.py new file mode 100644 index 0000000..e54fd36 --- /dev/null +++ b/cerberus/tests/test_cerberus_manager.py @@ -0,0 +1,581 @@ +# +# Copyright (c) 2014 EUROGICIEL +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +test_cerberus manager +---------------------------------- + +Tests for `cerberus` module. 
+""" +import json + +from eventlet import greenpool +import mock +from oslo import messaging +import pkg_resources +from stevedore import extension + +from cerberus.common import errors +from cerberus.db.sqlalchemy import api +from cerberus import manager +from cerberus.openstack.common import loopingcall +from cerberus.openstack.common import threadgroup +from cerberus.plugins import base as base_plugin +from cerberus.tests import base + + +PLUGIN_UUID = 'UUID' + + +class FakePlugin(base_plugin.PluginBase): + + def __init__(self): + super(FakePlugin, self).__init__() + self._uuid = PLUGIN_UUID + + def fake_function(self, *args, **kwargs): + return(args, kwargs) + + @base_plugin.PluginBase.webmethod + def another_fake_but_web_method(self): + pass + + def process_notification(self, ctxt, publisher_id, event_type, payload, + metadata): + pass + + +class DbPluginInfo(object): + def __init__(self, id, uuid): + self.id = id + self.uuid = uuid + + +class EntryPoint(object): + def __init__(self): + self.dist = pkg_resources.Distribution.from_filename( + "FooPkg-1.2-py2.4.egg") + + +class TestCerberusManager(base.TestBase): + + def setUp(self): + super(TestCerberusManager, self).setUp() + self.plugin = FakePlugin() + self.extension_mgr = extension.ExtensionManager.make_test_instance( + [ + extension.Extension( + 'plugin', + EntryPoint(), + None, + self.plugin, ), + ] + ) + self.db_plugin_info = DbPluginInfo(1, PLUGIN_UUID) + self.manager = manager.CerberusManager() + self.manager.cerberus_manager = self.extension_mgr + + def test_register_plugin(self): + with mock.patch('cerberus.db.sqlalchemy.api.plugin_info_create') \ + as MockClass: + MockClass.return_value = DbPluginInfo(1, PLUGIN_UUID) + self.manager._register_plugin(self.manager. + cerberus_manager['plugin']) + self.assertEqual(self.db_plugin_info.uuid, + self.manager.cerberus_manager['plugin'].obj._uuid) + + def test_register_plugin_new_version(self): + with mock.patch('cerberus.db.sqlalchemy.api.plugin_info_get') \ + as MockClass: + MockClass.return_value = DbPluginInfo(1, PLUGIN_UUID) + api.plugin_version_update = mock.MagicMock() + self.manager._register_plugin( + self.manager.cerberus_manager['plugin']) + self.assertEqual(self.db_plugin_info.uuid, + self.manager.cerberus_manager['plugin'].obj._uuid) + + @mock.patch.object(messaging.MessageHandlingServer, 'start') + def test_start(self, rpc_start): + manager.CerberusManager._register_plugin = mock.MagicMock() + mgr = manager.CerberusManager() + mgr.start() + rpc_start.assert_called_with() + assert(rpc_start.call_count == 2) + + @mock.patch.object(greenpool.GreenPool, "spawn") + def test_add_task_without_args(self, mock): + self.manager._add_unique_task( + self.manager.cerberus_manager['plugin'].obj.fake_function) + assert(len(self.manager.tg.threads) == 1) + mock.assert_called_with( + self.manager.cerberus_manager['plugin'].obj.fake_function) + + @mock.patch.object(greenpool.GreenPool, "spawn") + def test_add_task_with_args(self, mock): + self.manager._add_unique_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + name="fake") + assert(len(self.manager.tg.threads) == 1) + mock.assert_called_with( + self.manager.cerberus_manager['plugin'].obj.fake_function, + name="fake") + + @mock.patch.object(loopingcall.FixedIntervalLoopingCall, "start") + def test_add_recurrent_task_without_delay(self, mock): + self.manager._add_recurrent_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + 15) + assert(len(self.manager.tg.timers) == 1) + 
mock.assert_called_with(initial_delay=None, interval=15) + + @mock.patch.object(loopingcall.FixedIntervalLoopingCall, "start") + def test_add_recurrent_task_with_delay(self, mock): + self.manager._add_recurrent_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + 15, + 200) + assert(len(self.manager.tg.timers) == 1) + mock.assert_called_with(initial_delay=200, interval=15) + + @mock.patch.object(greenpool.GreenPool, "spawn") + def test_add_task(self, mock): + ctx = {"some": "context"} + self.manager.add_task(ctx, PLUGIN_UUID, 'fake_function') + assert(len(self.manager.tg.threads) == 1) + mock.assert_called_with(self.manager.cerberus_manager['plugin'].obj. + fake_function, + plugin_id=PLUGIN_UUID, + task_id=1) + + @mock.patch.object(greenpool.GreenPool, "spawn") + def test_add_task_incorrect_task_type(self, mock): + ctx = {"some": "context"} + self.manager.add_task(ctx, PLUGIN_UUID, 'fake_function', + task_type='INCORRECT') + assert(len(self.manager.tg.threads) == 1) + mock.assert_called_with(self.manager.cerberus_manager[ + 'plugin'].obj.fake_function, + plugin_id=PLUGIN_UUID, + task_type='INCORRECT', + task_id=1) + + @mock.patch.object(loopingcall.FixedIntervalLoopingCall, "start") + def test_add_recurrent_task_with_interval(self, mock): + ctx = {"some": "context"} + self.manager.add_task(ctx, PLUGIN_UUID, 'fake_function', + task_type='recurrent', task_period=5) + assert(len(self.manager.tg.timers) == 1) + mock.assert_called_with(initial_delay=None, interval=5) + + def test_get_recurrent_task(self): + task_id = self.manager._add_recurrent_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + 15) + recurrent_task = self.manager._get_recurrent_task(task_id) + assert(isinstance(recurrent_task, + loopingcall.FixedIntervalLoopingCall)) + + def test_get_recurrent_task_wrong_id(self): + task_id = 1 + self.manager._add_recurrent_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + 15, + task_id=task_id) + self.assertTrue(self.manager._get_recurrent_task(task_id + 1) is None) + + def test_get_plugins(self): + ctx = {"some": "context"} + json_plugin1 = { + "name": "cerberus.tests.test_cerberus_manager.FakePlugin", + "subscribed_events": + [ + ], + "methods": + [ + "another_fake_but_web_method" + ] + } + expected_json_plugins = [] + jplugin1 = json.dumps(json_plugin1) + expected_json_plugins.append(jplugin1) + json_plugins = self.manager.get_plugins(ctx) + self.assertEqual(json_plugins, expected_json_plugins) + + def test_get_plugin(self): + ctx = {"some": "context"} + c_manager = manager.CerberusManager() + c_manager.cerberus_manager = self.extension_mgr + + json_plugin1 = { + "name": "cerberus.tests.test_cerberus_manager.FakePlugin", + "subscribed_events": + [ + ], + "methods": + [ + "another_fake_but_web_method" + ] + } + jplugin1 = json.dumps(json_plugin1) + json_plugin = c_manager.get_plugin_from_uuid(ctx, PLUGIN_UUID) + self.assertEqual(json_plugin, jplugin1) + + def test_get_plugin_wrong_id(self): + ctx = {"some": "context"} + self.assertEqual(self.manager.get_plugin_from_uuid(ctx, 'wrong_test'), + None) + + def test_get_tasks(self): + recurrent_task_id = 1 + unique_task_id = 2 + task_period = 5 + self.manager._add_recurrent_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + task_period, + task_id=recurrent_task_id) + self.manager._add_unique_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + task_id=unique_task_id) + tasks = self.manager._get_tasks() + self.assertTrue(len(tasks) == 2) + 
self.assertTrue(isinstance(tasks[0], + loopingcall.FixedIntervalLoopingCall)) + self.assertTrue(isinstance(tasks[1], threadgroup.Thread)) + + def test_get_tasks_(self): + recurrent_task_id = 1 + unique_task_id = 2 + task_period = 5 + self.manager._add_recurrent_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + task_period, + task_id=recurrent_task_id) + self.manager._add_unique_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + task_id=unique_task_id) + tasks = self.manager.get_tasks({'some': 'context'}) + self.assertTrue(len(tasks) == 2) + + def test_get_task_reccurent(self): + task_id = 1 + task_period = 5 + self.manager._add_recurrent_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + task_period, + task_id=task_id) + task = self.manager._get_task(task_id) + self.assertTrue(isinstance(task, loopingcall.FixedIntervalLoopingCall)) + + def test_get_task_unique(self): + task_id = 1 + self.manager._add_unique_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + task_id=task_id) + task = self.manager._get_task(task_id) + self.assertTrue(isinstance(task, threadgroup.Thread)) + + def test_get_task(self): + recurrent_task_id = 1 + recurrent_task_name = "recurrent_task" + unique_task_id = 2 + unique_task_name = "unique_task" + task_period = 5 + self.manager._add_recurrent_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + task_period, + task_name=recurrent_task_name, + task_period=task_period, + task_id=recurrent_task_id) + self.manager._add_unique_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + task_id=unique_task_id, + task_name=unique_task_name) + task = self.manager.get_task({'some': 'context'}, 1) + self.assertTrue(json.loads(task).get('name') == recurrent_task_name) + self.assertTrue(json.loads(task).get('id') == recurrent_task_id) + task_2 = self.manager.get_task({'some': 'context'}, 2) + self.assertTrue(json.loads(task_2).get('name') == unique_task_name) + self.assertTrue(json.loads(task_2).get('id') == unique_task_id) + + def test_stop_unique_task(self): + task_id = 1 + self.manager._add_unique_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + task_id=task_id) + assert(len(self.manager.tg.threads) == 1) + self.manager._stop_unique_task(task_id) + assert(len(self.manager.tg.threads) == 0) + + def test_stop_recurrent_task(self): + task_id = 1 + self.manager._add_recurrent_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + 5, + task_id=task_id) + assert(self.manager.tg.timers[0]._running is True) + self.manager._stop_recurrent_task(task_id) + assert(self.manager.tg.timers[0]._running is False) + + def test_stop_task_recurrent(self): + recurrent_task_id = 1 + unique_task_id = 2 + task_period = 5 + self.manager._add_recurrent_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + task_period, + task_id=recurrent_task_id) + self.manager._add_unique_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + task_id=unique_task_id) + self.assertTrue(len(self.manager.tg.timers) == 1) + assert(self.manager.tg.timers[0]._running is True) + self.assertTrue(len(self.manager.tg.threads) == 1) + self.manager._stop_task(recurrent_task_id) + self.assertTrue(len(self.manager.tg.timers) == 1) + assert(self.manager.tg.timers[0]._running is False) + self.assertTrue(len(self.manager.tg.threads) == 1) + self.manager._stop_task(unique_task_id) + self.assertTrue(len(self.manager.tg.timers) == 1) + assert(self.manager.tg.timers[0]._running is 
False) + self.assertTrue(len(self.manager.tg.threads) == 0) + + @mock.patch.object(manager.CerberusManager, "_stop_task") + def test_stop_task(self, mock): + self.manager.stop_task({'some': 'context'}, 1) + mock.assert_called_with(1) + + def test_delete_recurrent_task(self): + ctx = {"some": "context"} + task_id = 1 + self.manager._add_recurrent_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + 5, + task_id=task_id) + recurrent_task = self.manager._get_recurrent_task(task_id) + assert(self.manager.tg.timers[0]._running is True) + assert(recurrent_task.gt.dead is False) + self.manager.delete_recurrent_task(ctx, task_id) + assert(recurrent_task.gt.dead is False) + assert(len(self.manager.tg.timers) == 0) + + def test_force_delete_recurrent_task(self): + task_id = 1 + ctx = {"some": "ctx"} + self.manager._add_recurrent_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + 5, + task_id=task_id) + recurrent_task = self.manager._get_recurrent_task(task_id) + assert(self.manager.tg.timers[0]._running is True) + assert(recurrent_task.gt.dead is False) + self.manager.force_delete_recurrent_task(ctx, task_id) + assert(recurrent_task.gt.dead is True) + assert(len(self.manager.tg.timers) == 0) + + def test_restart_recurrent_task(self): + ctxt = {'some': 'context'} + task_id = 1 + task_period = 5 + self.manager._add_recurrent_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + task_period, + task_id=task_id, + task_period=task_period) + assert(self.manager.tg.timers[0]._running is True) + self.manager._stop_recurrent_task(task_id) + assert(self.manager.tg.timers[0]._running is False) + self.manager.restart_recurrent_task(ctxt, task_id) + assert(self.manager.tg.timers[0]._running is True) + + +class FaultyTestCerberusManager(base.TestBaseFaulty): + + def setUp(self): + super(FaultyTestCerberusManager, self).setUp() + self.plugin = FakePlugin() + self.extension_mgr = extension.ExtensionManager.make_test_instance( + [ + extension.Extension( + 'plugin', + EntryPoint(), + None, + self.plugin, ), + ] + ) + self.db_plugin_info = DbPluginInfo(1, PLUGIN_UUID) + self.manager = manager.CerberusManager() + self.manager.cerberus_manager = self.extension_mgr + + def test_add_task_wrong_plugin_id(self): + ctx = {"some": "context"} + self.assertRaises(errors.PluginNotFound, self.manager.add_task, + ctx, 'WRONG_UUID', 'fake_function') + assert(len(self.manager.tg.threads) == 0) + + def test_add_task_incorrect_period(self): + ctx = {"some": "context"} + self.assertRaises(errors.TaskPeriodNotInteger, + self.manager.add_task, + ctx, + PLUGIN_UUID, + 'fake_function', + task_type='recurrent', + task_period='NOT_INTEGER') + assert(len(self.manager.tg.threads) == 0) + + def test_add_task_wrong_plugin_method(self): + ctx = {"some": "context"} + self.assertRaises(errors.MethodNotCallable, + self.manager.add_task, ctx, PLUGIN_UUID, 'fake') + assert(len(self.manager.tg.threads) == 0) + + def test_add_task_method_not_as_string(self): + ctx = {"some": "context"} + self.assertRaises(errors.MethodNotString, + self.manager.add_task, + ctx, + PLUGIN_UUID, + self.manager.cerberus_manager[ + 'plugin'].obj.fake_function) + assert(len(self.manager.tg.threads) == 0) + + def test_add_recurrent_task_without_period(self): + ctx = {"some": "context"} + self.assertRaises(errors.TaskPeriodNotInteger, + self.manager.add_task, + ctx, + PLUGIN_UUID, + 'fake_function', + task_type='recurrent') + assert(len(self.manager.tg.timers) == 0) + + def test_add_recurrent_task_wrong_plugin_method(self): + ctx 
= {"some": "context"} + self.assertRaises(errors.MethodNotCallable, + self.manager.add_task, ctx, PLUGIN_UUID, 'fake', + task_type='recurrent', task_period=5) + assert(len(self.manager.tg.timers) == 0) + + def test_add_recurrent_task_method_not_as_string(self): + ctx = {"some": "context"} + self.assertRaises(errors.MethodNotString, + self.manager.add_task, + ctx, + PLUGIN_UUID, + self.manager.cerberus_manager[ + 'plugin'].obj.fake_function, + task_type='recurrent', + task_period=5) + assert(len(self.manager.tg.timers) == 0) + + def test_get_task_unique_wrong_id(self): + task_id = 1 + ctx = {"some": "context"} + self.manager._add_unique_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + 5, + task_id=task_id) + self.assertRaises(errors.TaskNotFound, + self.manager.get_task, + ctx, + task_id + 1) + + def test_stop_unique_task_wrong_id(self): + task_id = 1 + self.manager._add_unique_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + task_id=task_id) + assert(len(self.manager.tg.threads) == 1) + self.assertRaises(errors.TaskNotFound, + self.manager._stop_unique_task, + task_id + 1) + assert(len(self.manager.tg.threads) == 1) + + def test_stop_recurrent_task_wrong_id(self): + task_id = 1 + self.manager._add_recurrent_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + 5, + task_id=task_id) + assert(self.manager.tg.timers[0]._running is True) + self.assertRaises(errors.TaskNotFound, + self.manager._stop_recurrent_task, + task_id + 1) + assert(self.manager.tg.timers[0]._running is True) + + def test_delete_recurrent_task_wrong_id(self): + ctx = {"some": "context"} + task_id = 1 + self.manager._add_recurrent_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + 5, + task_id=task_id) + recurrent_task = self.manager._get_recurrent_task(task_id) + assert(self.manager.tg.timers[0]._running is True) + assert(recurrent_task.gt.dead is False) + self.assertRaises(errors.TaskDeletionNotAllowed, + self.manager.delete_recurrent_task, + ctx, + task_id + 1) + assert(self.manager.tg.timers[0]._running is True) + assert(recurrent_task.gt.dead is False) + + def test_force_delete_recurrent_task_wrong_id(self): + ctx = {"some": "ctx"} + task_id = 1 + self.manager._add_recurrent_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + 5, + task_id=task_id) + recurrent_task = self.manager._get_recurrent_task(task_id) + assert(self.manager.tg.timers[0]._running is True) + assert(recurrent_task.gt.dead is False) + self.assertRaises(errors.TaskDeletionNotAllowed, + self.manager.force_delete_recurrent_task, + ctx, + task_id + 1) + assert(recurrent_task.gt.dead is False) + assert(len(self.manager.tg.timers) == 1) + + def test_restart_recurrent_task_wrong_id(self): + ctxt = {"some": "ctx"} + task_id = 1 + self.manager._add_recurrent_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + 5, + task_id=task_id) + assert(self.manager.tg.timers[0]._running is True) + self.manager._stop_recurrent_task(task_id) + assert(self.manager.tg.timers[0]._running is False) + self.assertRaises(errors.TaskRestartNotAllowed, + self.manager.restart_recurrent_task, + ctxt, + task_id + 1) + assert(self.manager.tg.timers[0]._running is False) + + def test_restart_recurrent_task_running(self): + ctxt = {"some": "ctx"} + task_id = 1 + self.manager._add_recurrent_task( + self.manager.cerberus_manager['plugin'].obj.fake_function, + 5, + task_id=task_id) + assert(self.manager.tg.timers[0]._running is True) + self.assertRaises(errors.TaskRestartNotPossible, + 
self.manager.restart_recurrent_task,
+                          ctxt,
+                          task_id)
+        assert(self.manager.tg.timers[0]._running is True)
diff --git a/cerberus/tests/test_db_api.py b/cerberus/tests/test_db_api.py
new file mode 100644
index 0000000..ebaa74c
--- /dev/null
+++ b/cerberus/tests/test_db_api.py
@@ -0,0 +1,78 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Tests for `db api` module.
+"""
+
+import mock
+
+from oslo.config import fixture as fixture_config
+
+from cerberus.db.sqlalchemy import api
+from cerberus.db.sqlalchemy import models
+from cerberus.openstack.common.db.sqlalchemy import models as db_models
+from cerberus.tests import base
+
+
+class DbApiTestCase(base.TestBase):
+
+    def test_alert_create(self):
+        self.CONF = self.useFixture(fixture_config.Config()).conf
+        self.CONF([], project='cerberus')
+        al = api.alert_create({'title': 'TitleAlert'})
+        self.assertTrue(al.id >= 0)
+
+    def test_alert_get_all(self):
+        self.CONF = self.useFixture(fixture_config.Config()).conf
+        self.CONF([], project='cerberus')
+        self.test_alert_create()
+        al = api.alert_get_all()
+        for a in al:
+            dec = models.AlertJsonSerializer().serialize(a)
+            self.assertEqual(1, dec['id'])
+            self.assertEqual('TitleAlert', dec['title'])
+
+    def test_security_report_create(self):
+        self.CONF = self.useFixture(fixture_config.Config()).conf
+        self.CONF([], project='cerberus')
+        db_models.ModelBase.save = mock.MagicMock()
+        report = api.security_report_create({'title': 'TitleSecurityReport',
+                                             'plugin_id': '123456789',
+                                             'description': 'The first',
+                                             'component_id': '1234'})
+
+        self.assertEqual('TitleSecurityReport', report.title)
+        self.assertEqual('123456789', report.plugin_id)
+        self.assertEqual('The first', report.description)
+        self.assertEqual('1234', report.component_id)
+
+    def test_plugin_info_create(self):
+        self.CONF = self.useFixture(fixture_config.Config()).conf
+        self.CONF([], project='cerberus')
+        pi = api.plugin_info_create({'name': 'NameOfPlugin',
+                                     'uuid': '0000-aaaa-1111-bbbb'})
+        self.assertTrue(pi.id >= 0)
+
+    def test_plugin_info_get(self):
+        self.CONF = self.useFixture(fixture_config.Config()).conf
+        self.CONF([], project='cerberus')
+
+        api.plugin_info_create({'name': 'NameOfPluginToGet',
+                                'uuid': '3333-aaaa-1111-bbbb'})
+
+        pi = api.plugin_info_get('NameOfPluginToGet')
+        self.assertEqual('NameOfPluginToGet', pi.name)
diff --git a/cerberus/tests/test_utils.py b/cerberus/tests/test_utils.py
new file mode 100644
index 0000000..b10cfe1
--- /dev/null
+++ b/cerberus/tests/test_utils.py
@@ -0,0 +1,97 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
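test_utils.py below covers the helpers in cerberus/utils.py; the central one, restore_nesting, rebuilds nested dicts from separator-delimited flat keys, e.g.:

    from cerberus import utils

    flat = {'a': 'A', 'nested:twice:c': 'C'}
    # Colon is the default separator; a dot can be passed explicitly.
    assert utils.restore_nesting(flat) == {
        'a': 'A', 'nested': {'twice': {'c': 'C'}}}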
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Tests for cerberus/utils.py
+"""
+import datetime
+import decimal
+
+from oslotest import base
+
+from cerberus import utils
+
+
+class TestUtils(base.BaseTestCase):
+
+    def test_datetime_to_decimal(self):
+        expected = 1356093296.12
+        utc_datetime = datetime.datetime.utcfromtimestamp(expected)
+        actual = utils.dt_to_decimal(utc_datetime)
+        self.assertAlmostEqual(expected, float(actual), places=5)
+
+    def test_decimal_to_datetime(self):
+        expected = 1356093296.12
+        dexpected = decimal.Decimal(str(expected))  # Python 2.6 wants str()
+        expected_datetime = datetime.datetime.utcfromtimestamp(expected)
+        actual_datetime = utils.decimal_to_dt(dexpected)
+        # Python 3 has rounding issues with this, so use float
+        self.assertAlmostEqual(utils.dt_to_decimal(expected_datetime),
+                               utils.dt_to_decimal(actual_datetime),
+                               places=5)
+
+    def test_restore_nesting_unnested(self):
+        metadata = {'a': 'A', 'b': 'B'}
+        unwound = utils.restore_nesting(metadata)
+        self.assertIs(metadata, unwound)
+
+    def test_restore_nesting(self):
+        metadata = {'a': 'A', 'b': 'B',
+                    'nested:a': 'A',
+                    'nested:b': 'B',
+                    'nested:twice:c': 'C',
+                    'nested:twice:d': 'D',
+                    'embedded:e': 'E'}
+        unwound = utils.restore_nesting(metadata)
+        expected = {'a': 'A', 'b': 'B',
+                    'nested': {'a': 'A', 'b': 'B',
+                               'twice': {'c': 'C', 'd': 'D'}},
+                    'embedded': {'e': 'E'}}
+        self.assertEqual(expected, unwound)
+        self.assertIsNot(metadata, unwound)
+
+    def test_restore_nesting_with_separator(self):
+        metadata = {'a': 'A', 'b': 'B',
+                    'nested.a': 'A',
+                    'nested.b': 'B',
+                    'nested.twice.c': 'C',
+                    'nested.twice.d': 'D',
+                    'embedded.e': 'E'}
+        unwound = utils.restore_nesting(metadata, separator='.')
+        expected = {'a': 'A', 'b': 'B',
+                    'nested': {'a': 'A', 'b': 'B',
+                               'twice': {'c': 'C', 'd': 'D'}},
+                    'embedded': {'e': 'E'}}
+        self.assertEqual(expected, unwound)
+        self.assertIsNot(metadata, unwound)
+
+    def test_decimal_to_dt_with_none_parameter(self):
+        self.assertIsNone(utils.decimal_to_dt(None))
+
+    def test_dict_to_kv(self):
+        data = {'a': 'A',
+                'b': 'B',
+                'nested': {'a': 'A',
+                           'b': 'B',
+                           },
+                'nested2': [{'c': 'A'}, {'c': 'B'}]
+                }
+        pairs = list(utils.dict_to_keyval(data))
+        self.assertEqual([('a', 'A'),
+                          ('b', 'B'),
+                          ('nested.a', 'A'),
+                          ('nested.b', 'B'),
+                          ('nested2[0].c', 'A'),
+                          ('nested2[1].c', 'B')],
+                         sorted(pairs, key=lambda x: x[0]))
diff --git a/cerberus/tests/utils.py b/cerberus/tests/utils.py
new file mode 100644
index 0000000..c85a208
--- /dev/null
+++ b/cerberus/tests/utils.py
@@ -0,0 +1,42 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
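tests/utils.py below owns the throwaway SQLite database behind TestBase: setup_dummy_db syncs the schema, reset_dummy_db empties every table except migrate_version. The pairing mirrors what TestBase.setUp registers:

    from cerberus.tests import utils

    utils.setup_dummy_db()        # create and migrate the test database
    try:
        pass                      # ...exercise cerberus.db here...
    finally:
        utils.reset_dummy_db()    # truncate all tables but migrate_version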
+#
+
+import sqlalchemy
+
+from cerberus.db import api as db_api
+from cerberus.openstack.common.db import options
+
+
+get_engine = db_api.get_engine
+
+
+def setup_dummy_db():
+    options.cfg.set_defaults(options.database_opts, sqlite_synchronous=False)
+    options.set_defaults("sqlite://", sqlite_db='cerberus.db')
+    engine = get_engine()
+    db_api.db_sync(engine)
+    engine.connect()
+
+
+def reset_dummy_db():
+    engine = get_engine()
+    meta = sqlalchemy.MetaData()
+    meta.reflect(bind=engine)
+
+    for table in reversed(meta.sorted_tables):
+        if table.name == 'migrate_version':
+            continue
+        engine.execute(table.delete())
diff --git a/cerberus/utils.py b/cerberus/utils.py
new file mode 100644
index 0000000..97b4092
--- /dev/null
+++ b/cerberus/utils.py
@@ -0,0 +1,163 @@
+#
+# Copyright (c) 2014 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Utilities and helper functions."""
+
+import calendar
+import copy
+import datetime
+import decimal
+import multiprocessing
+
+from oslo.utils import timeutils
+from oslo.utils import units
+
+
+def restore_nesting(d, separator=':'):
+    """Unwinds a flattened dict to restore nesting.
+    """
+    d = copy.copy(d) if any([separator in k for k in d.keys()]) else d
+    for k, v in d.items():
+        if separator in k:
+            top, rem = k.split(separator, 1)
+            nest = d[top] if isinstance(d.get(top), dict) else {}
+            nest[rem] = v
+            d[top] = restore_nesting(nest, separator)
+            del d[k]
+    return d
+
+
+def dt_to_decimal(utc):
+    """Datetime to Decimal.
+
+    Some databases don't store microseconds in datetime,
+    so we always store as Decimal unixtime.
+    """
+    if utc is None:
+        return None
+
+    decimal.getcontext().prec = 30
+    return decimal.Decimal(str(calendar.timegm(utc.utctimetuple()))) + \
+        (decimal.Decimal(str(utc.microsecond)) /
+         decimal.Decimal("1000000.0"))
+
+
+def decimal_to_dt(dec):
+    """Return a datetime from Decimal unixtime format.
+    """
+    if dec is None:
+        return None
+
+    integer = int(dec)
+    micro = (dec - decimal.Decimal(integer)) * decimal.Decimal(units.M)
+    dt = datetime.datetime.utcfromtimestamp(integer)
+    return dt.replace(microsecond=int(round(micro)))
+
+
+def sanitize_timestamp(timestamp):
+    """Return a naive utc datetime object."""
+    if not timestamp:
+        return timestamp
+    if not isinstance(timestamp, datetime.datetime):
+        timestamp = timeutils.parse_isotime(timestamp)
+    return timeutils.normalize_time(timestamp)
+
+
+def stringify_timestamps(data):
+    """Stringify any datetimes in the given dict."""
+    isa_timestamp = lambda v: isinstance(v, datetime.datetime)
+    return dict((k, v.isoformat() if isa_timestamp(v) else v)
+                for (k, v) in data.iteritems())
+
+
+def dict_to_keyval(value, key_base=None):
+    """Expand a given dict to its corresponding key-value pairs.
+
+    Generated keys are fully qualified, delimited using dot notation,
+    i.e. key = 'key.child_key.grandchild_key[0]'.
+    """
+    val_iter, key_func = None, None
+    if isinstance(value, dict):
+        val_iter = value.iteritems()
+        key_func = lambda k: key_base + '.' + k if key_base else k
+    elif isinstance(value, (tuple, list)):
+        val_iter = enumerate(value)
+        key_func = lambda k: key_base + '[%d]' % k
+
+    if val_iter:
+        for k, v in val_iter:
+            key_gen = key_func(k)
+            if isinstance(v, (dict, tuple, list)):
+                for key_gen, v in dict_to_keyval(v, key_gen):
+                    yield key_gen, v
+            else:
+                yield key_gen, v
+
+
+def lowercase_keys(mapping):
+    """Converts the keys in the mapping dict to lowercase."""
+    items = mapping.items()
+    for key, value in items:
+        del mapping[key]
+        mapping[key.lower()] = value
+
+
+def lowercase_values(mapping):
+    """Converts the values in the mapping dict to lowercase."""
+    items = mapping.items()
+    for key, value in items:
+        mapping[key] = value.lower()
+
+
+def update_nested(original_dict, updates):
+    """Updates the leaf nodes in a nested dict, without replacing
+       entire sub-dicts.
+    """
+    dict_to_update = copy.deepcopy(original_dict)
+    for key, value in updates.iteritems():
+        if isinstance(value, dict):
+            sub_dict = update_nested(dict_to_update.get(key, {}), value)
+            dict_to_update[key] = sub_dict
+        else:
+            dict_to_update[key] = updates[key]
+    return dict_to_update
+
+
+def cpu_count():
+    try:
+        return multiprocessing.cpu_count() or 1
+    except NotImplementedError:
+        return 1
+
+
+def uniq(dupes, attrs):
+    """Exclude elements of dupes with a duplicated set of attribute values."""
+    key = lambda d: '/'.join([getattr(d, a) or '' for a in attrs])
+    keys = []
+    deduped = []
+    for d in dupes:
+        if key(d) not in keys:
+            deduped.append(d)
+            keys.append(key(d))
+    return deduped
+
+
+def create_datetime_obj(date):
+    """Build a datetime object from a date string.
+
+    :param date: the date string to parse, e.g. '20150109T10:53:50'
+    :return: a datetime object
+    """
+    return datetime.datetime.strptime(date, '%Y%m%dT%H:%M:%S')
diff --git a/cerberus/version.py b/cerberus/version.py
new file mode 100644
index 0000000..b5322f9
--- /dev/null
+++ b/cerberus/version.py
@@ -0,0 +1,19 @@
+#
+# Copyright (c) 2015 EUROGICIEL
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
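
The Decimal timestamp helpers in cerberus/utils.py above are easiest to grasp as a round trip; here is a self-contained sketch mirroring dt_to_decimal and decimal_to_dt, with oslo's units.M spelled out as 1000000 (illustration only, not part of the patch):

    import calendar
    import datetime
    import decimal

    def dt_to_decimal(utc):
        # Decimal unixtime keeps microseconds that some databases drop.
        decimal.getcontext().prec = 30
        return (decimal.Decimal(calendar.timegm(utc.utctimetuple())) +
                decimal.Decimal(utc.microsecond) / decimal.Decimal(1000000))

    def decimal_to_dt(dec):
        integer = int(dec)
        micro = (dec - decimal.Decimal(integer)) * decimal.Decimal(1000000)
        return datetime.datetime.utcfromtimestamp(integer).replace(
            microsecond=int(round(micro)))

    dt = datetime.datetime(2012, 12, 21, 12, 34, 56, 120000)
    assert decimal_to_dt(dt_to_decimal(dt)) == dt
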
+# + +import pbr.version + +version_info = pbr.version.VersionInfo('cerberus') diff --git a/contrib/devstack/README.rst b/contrib/devstack/README.rst new file mode 100644 index 0000000..e69de29 diff --git a/contrib/devstack/extras.d/50-cerberus.sh b/contrib/devstack/extras.d/50-cerberus.sh new file mode 100644 index 0000000..be9b8fc --- /dev/null +++ b/contrib/devstack/extras.d/50-cerberus.sh @@ -0,0 +1,39 @@ +# cerberus.sh - Devstack extras script to install Cerberus + +if is_service_enabled cerberus-api cerberus-agent; then + if [[ "$1" == "source" ]]; then + # Initial source + source $TOP_DIR/lib/cerberus + elif [[ "$1" == "stack" && "$2" == "install" ]]; then + echo_summary "Installing Cerberus" + install_cerberus + install_cerberusclient + + if is_service_enabled cerberus-dashboard; then + install_cerberusdashboard + fi + cleanup_cerberus + elif [[ "$1" == "stack" && "$2" == "post-config" ]]; then + echo_summary "Configuring Cerberus" + configure_cerberus + if is_service_enabled cerberus-dashboard; then + configure_cerberusdashboard + fi + if is_service_enabled key; then + create_cerberus_accounts + fi + + elif [[ "$1" == "stack" && "$2" == "extra" ]]; then + # Initialize cerberus + echo_summary "Initializing Cerberus" + init_cerberus + + # Start the Cerberus API and Cerberus agent components + echo_summary "Starting Cerberus" + start_cerberus + fi + + if [[ "$1" == "unstack" ]]; then + stop_cerberus + fi +fi diff --git a/doc/source/conf.py b/doc/source/conf.py index ba7de3e..fb8ffe2 100755 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -72,4 +72,4 @@ latex_documents = [ ] # Example configuration for intersphinx: refer to the Python standard library. -#intersphinx_mapping = {'http://docs.python.org/': None} +#intersphinx_mapping = {'http://docs.python.org/': None} \ No newline at end of file diff --git a/etc/cerberus/cerberus.conf.sample b/etc/cerberus/cerberus.conf.sample new file mode 100644 index 0000000..e9a404a --- /dev/null +++ b/etc/cerberus/cerberus.conf.sample @@ -0,0 +1,32 @@ + +[DEFAULT] +policy_file = /etc/cerberus/policy.json +debug = True +verbose = True +notification_topics = svc_notifications +rabbit_password = guest +rabbit_hosts = localhost +# rpc_backend = cerberus.openstack.common.rpc.impl_kombu + +[service_credentials] +os_tenant_name = service +os_password = svc +os_username = cerberus + +[keystone_authtoken] +signing_dir = /var/cache/cerberus +admin_tenant_name = service +admin_password = svc +admin_user = cerberus +auth_protocol = http +auth_port = 5000 +auth_host = localhost + +[database] +connection = mysql://root:svc@localhost/cerberus?charset=utf8 + +[ikare] +ikare_admin=NONE +ikare_password=NONE +ikare_url=HOST +ikare_role_name=ikare diff --git a/etc/cerberus/policy.json b/etc/cerberus/policy.json new file mode 100644 index 0000000..4ac0d25 --- /dev/null +++ b/etc/cerberus/policy.json @@ -0,0 +1,4 @@ +{ + "context_is_admin": "role:admin", + "default": "" +} \ No newline at end of file diff --git a/openstack-common.conf b/openstack-common.conf index 1079159..65befd4 100644 --- a/openstack-common.conf +++ b/openstack-common.conf @@ -1,6 +1,36 @@ [DEFAULT] # The list of modules to copy from oslo-incubator.git +module=cliutils +module=config +module=config.generator +module=context +module=db +module=db.sqlalchemy +module=db.sqlalchemy.migration_cli +module=eventlet_backdoor +module=excutils +module=fileutils +module=flakes +module=gettextutils +module=importutils +module=install_venv_common +module=jsonutils +module=local +module=lockutils +module=log 
+module=log_handler +module=network_utils +module=notifier +module=periodic_task +module=policy +module=processutils +module=py3kcompat +module=service +module=setup +module=strutils +module=timeutils +module=test # The base module to hold the copy of openstack.common base=cerberus \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index bc7131e..7beb3ea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,21 @@ # The order of packages is significant, because pip processes them in the order # of appearance. Changing the order has an impact on the overall integration # process, which may cause wedges in the gate later. - pbr>=0.6,!=0.7,<1.0 -Babel>=1.3 \ No newline at end of file +Babel>=1.3 +eventlet>=0.15.1 +greenlet>=0.3.2 +lockfile>=0.8 +mysql-python +oslo.config>=1.2.0,<1.5 +oslo.messaging>=1.3.0,<=1.4.1 +pecan>=0.4.5 +posix_ipc +python-keystoneclient>=0.4.2,<0.12 +python-neutronclient>=2.3 +python-novaclient==2.20 +six>=1.6.0 +SQLAlchemy>=0.7.8,!=0.9.5,<=0.9.99 +sqlalchemy-migrate>=0.8.2,!=0.8.4,<=0.9.1 +webob>=1.2.3 +WSME>=0.6 diff --git a/setup.cfg b/setup.cfg index 61e2cd0..0e9baa5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,7 @@ [metadata] name = cerberus summary = Cerberus security component +version = 2014.1 description-file = README.rst author = OpenStack @@ -44,4 +45,15 @@ input_file = cerberus/locale/cerberus.pot [extract_messages] keywords = _ gettext ngettext l_ lazy_gettext mapping_file = babel.cfg -output_file = cerberus/locale/cerberus.pot \ No newline at end of file +output_file = cerberus/locale/cerberus.pot + +[entry_points] +console_scripts = + cerberus-api = cerberus.cmd.api:main + cerberus-agent = cerberus.cmd.agent:main + dbcreate = cerberus.cmd.db_create:main + +cerberus.plugins = + testplugin = cerberus.plugins.test_plugin:TestPlugin + openvasplugin = cerberus.plugins.openvas:OpenVasPlugin + taskplugin = cerberus.plugins.task_plugin:TaskPlugin diff --git a/test-requirements.txt b/test-requirements.txt index 7b79352..ec7d4f3 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -3,13 +3,17 @@ # process, which may cause wedges in the gate later. hacking>=0.9.2,<0.10 - +# mock object framework +mock>=1.0 coverage>=3.6 discover +# fixture stubbing +fixtures>=0.3.14 +oslotest>=1.1.0 # Apache-2.0 python-subunit +nose +nose-exclude +nosexcover sphinx>=1.1.2 oslosphinx -oslotest>=1.1.0.0a1 -testrepository>=0.0.18 -testscenarios>=0.4 -testtools>=0.9.34 \ No newline at end of file + diff --git a/tools/config/check_uptodate.sh b/tools/config/check_uptodate.sh new file mode 100755 index 0000000..1885e70 --- /dev/null +++ b/tools/config/check_uptodate.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +PROJECT_NAME=${PROJECT_NAME:-cerberus} +CFGFILE_NAME=${PROJECT_NAME}.conf.sample + +if [ -e etc/${PROJECT_NAME}/${CFGFILE_NAME} ]; then + CFGFILE=etc/${PROJECT_NAME}/${CFGFILE_NAME} +elif [ -e etc/${CFGFILE_NAME} ]; then + CFGFILE=etc/${CFGFILE_NAME} +else + echo "${0##*/}: can not find config file" + exit 1 +fi + +TEMPDIR=`mktemp -d /tmp/${PROJECT_NAME}.XXXXXX` +trap "rm -rf $TEMPDIR" EXIT + +tools/config/generate_sample.sh -b ./ -p ${PROJECT_NAME} -o ${TEMPDIR} + +if ! diff -u ${TEMPDIR}/${CFGFILE_NAME} ${CFGFILE} +then + echo "${0##*/}: ${PROJECT_NAME}.conf.sample is not up to date." + echo "${0##*/}: Please run ${0%%${0##*/}}generate_sample.sh." 
+ exit 1 +fi diff --git a/tools/config/generate_sample.sh b/tools/config/generate_sample.sh new file mode 100755 index 0000000..ba63071 --- /dev/null +++ b/tools/config/generate_sample.sh @@ -0,0 +1,119 @@ +#!/usr/bin/env bash + +print_hint() { + echo "Try \`${0##*/} --help' for more information." >&2 +} + +PARSED_OPTIONS=$(getopt -n "${0##*/}" -o hb:p:m:l:o: \ + --long help,base-dir:,package-name:,output-dir:,module:,library: -- "$@") + +if [ $? != 0 ] ; then print_hint ; exit 1 ; fi + +eval set -- "$PARSED_OPTIONS" + +while true; do + case "$1" in + -h|--help) + echo "${0##*/} [options]" + echo "" + echo "options:" + echo "-h, --help show brief help" + echo "-b, --base-dir=DIR project base directory" + echo "-p, --package-name=NAME project package name" + echo "-o, --output-dir=DIR file output directory" + echo "-m, --module=MOD extra python module to interrogate for options" + echo "-l, --library=LIB extra library that registers options for discovery" + exit 0 + ;; + -b|--base-dir) + shift + BASEDIR=`echo $1 | sed -e 's/\/*$//g'` + shift + ;; + -p|--package-name) + shift + PACKAGENAME=`echo $1` + shift + ;; + -o|--output-dir) + shift + OUTPUTDIR=`echo $1 | sed -e 's/\/*$//g'` + shift + ;; + -m|--module) + shift + MODULES="$MODULES -m $1" + shift + ;; + -l|--library) + shift + LIBRARIES="$LIBRARIES -l $1" + shift + ;; + --) + break + ;; + esac +done + +BASEDIR=${BASEDIR:-`pwd`} +if ! [ -d $BASEDIR ] +then + echo "${0##*/}: missing project base directory" >&2 ; print_hint ; exit 1 +elif [[ $BASEDIR != /* ]] +then + BASEDIR=$(cd "$BASEDIR" && pwd) +fi + +PACKAGENAME=${PACKAGENAME:-$(python setup.py --name)} +TARGETDIR=$BASEDIR/$PACKAGENAME +if ! [ -d $TARGETDIR ] +then + echo "${0##*/}: invalid project package name" >&2 ; print_hint ; exit 1 +fi + +OUTPUTDIR=${OUTPUTDIR:-$BASEDIR/etc} +# NOTE(bnemec): Some projects put their sample config in etc/, +# some in etc/$PACKAGENAME/ +if [ -d $OUTPUTDIR/$PACKAGENAME ] +then + OUTPUTDIR=$OUTPUTDIR/$PACKAGENAME +elif ! [ -d $OUTPUTDIR ] +then + echo "${0##*/}: cannot access \`$OUTPUTDIR': No such file or directory" >&2 + exit 1 +fi + +BASEDIRESC=`echo $BASEDIR | sed -e 's/\//\\\\\//g'` +find $TARGETDIR -type f -name "*.pyc" -delete +FILES=$(find $TARGETDIR -type f -name "*.py" ! -path "*/tests/*" \ + -exec grep -l "Opt(" {} + | sed -e "s/^$BASEDIRESC\///g" | sort -u) + +RC_FILE="`dirname $0`/oslo.config.generator.rc" +if test -r "$RC_FILE" +then + source "$RC_FILE" +fi + +for mod in ${CERBERUS_CONFIG_GENERATOR_EXTRA_MODULES}; do + MODULES="$MODULES -m $mod" +done + +for lib in ${CERBERUS_CONFIG_GENERATOR_EXTRA_LIBRARIES}; do + LIBRARIES="$LIBRARIES -l $lib" +done + +export EVENTLET_NO_GREENDNS=yes + +OS_VARS=$(set | sed -n '/^OS_/s/=[^=]*$//gp' | xargs) +[ "$OS_VARS" ] && eval "unset \$OS_VARS" +DEFAULT_MODULEPATH=cerberus.openstack.common.config.generator +MODULEPATH=${MODULEPATH:-$DEFAULT_MODULEPATH} +OUTPUTFILE=$OUTPUTDIR/$PACKAGENAME.conf.sample +python -m $MODULEPATH $MODULES $LIBRARIES $FILES > $OUTPUTFILE + +# Hook to allow projects to append custom config file snippets +CONCAT_FILES=$(ls $BASEDIR/tools/config/*.conf.sample 2>/dev/null) +for CONCAT_FILE in $CONCAT_FILES; do + cat $CONCAT_FILE >> $OUTPUTFILE +done diff --git a/tools/install_venv_common.py b/tools/install_venv_common.py new file mode 100644 index 0000000..46822e3 --- /dev/null +++ b/tools/install_venv_common.py @@ -0,0 +1,172 @@ +# Copyright 2013 OpenStack Foundation +# Copyright 2013 IBM Corp. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Provides methods needed by installation script for OpenStack development +virtual environments. + +Since this script is used to bootstrap a virtualenv from the system's Python +environment, it should be kept strictly compatible with Python 2.6. + +Synced in from openstack-common +""" + +from __future__ import print_function + +import optparse +import os +import subprocess +import sys + + +class InstallVenv(object): + + def __init__(self, root, venv, requirements, + test_requirements, py_version, + project): + self.root = root + self.venv = venv + self.requirements = requirements + self.test_requirements = test_requirements + self.py_version = py_version + self.project = project + + def die(self, message, *args): + print(message % args, file=sys.stderr) + sys.exit(1) + + def check_python_version(self): + if sys.version_info < (2, 6): + self.die("Need Python Version >= 2.6") + + def run_command_with_code(self, cmd, redirect_output=True, + check_exit_code=True): + """Runs a command in an out-of-process shell. + + Returns the output of that command. Working directory is self.root. + """ + if redirect_output: + stdout = subprocess.PIPE + else: + stdout = None + + proc = subprocess.Popen(cmd, cwd=self.root, stdout=stdout) + output = proc.communicate()[0] + if check_exit_code and proc.returncode != 0: + self.die('Command "%s" failed.\n%s', ' '.join(cmd), output) + return (output, proc.returncode) + + def run_command(self, cmd, redirect_output=True, check_exit_code=True): + return self.run_command_with_code(cmd, redirect_output, + check_exit_code)[0] + + def get_distro(self): + if (os.path.exists('/etc/fedora-release') or + os.path.exists('/etc/redhat-release')): + return Fedora( + self.root, self.venv, self.requirements, + self.test_requirements, self.py_version, self.project) + else: + return Distro( + self.root, self.venv, self.requirements, + self.test_requirements, self.py_version, self.project) + + def check_dependencies(self): + self.get_distro().install_virtualenv() + + def create_virtualenv(self, no_site_packages=True): + """Creates the virtual environment and installs PIP. + + Creates the virtual environment and installs PIP only into the + virtual environment. 
+ """ + if not os.path.isdir(self.venv): + print('Creating venv...', end=' ') + if no_site_packages: + self.run_command(['virtualenv', '-q', '--no-site-packages', + self.venv]) + else: + self.run_command(['virtualenv', '-q', self.venv]) + print('done.') + else: + print("venv already exists...") + pass + + def pip_install(self, *args): + self.run_command(['tools/with_venv.sh', + 'pip', 'install', '--upgrade'] + list(args), + redirect_output=False) + + def install_dependencies(self): + print('Installing dependencies with pip (this can take a while)...') + + # First things first, make sure our venv has the latest pip and + # setuptools and pbr + self.pip_install('pip>=1.4') + self.pip_install('setuptools') + self.pip_install('pbr') + + self.pip_install('-r', self.requirements, '-r', self.test_requirements) + + def parse_args(self, argv): + """Parses command-line arguments.""" + parser = optparse.OptionParser() + parser.add_option('-n', '--no-site-packages', + action='store_true', + help="Do not inherit packages from global Python " + "install") + return parser.parse_args(argv[1:])[0] + + +class Distro(InstallVenv): + + def check_cmd(self, cmd): + return bool(self.run_command(['which', cmd], + check_exit_code=False).strip()) + + def install_virtualenv(self): + if self.check_cmd('virtualenv'): + return + + if self.check_cmd('easy_install'): + print('Installing virtualenv via easy_install...', end=' ') + if self.run_command(['easy_install', 'virtualenv']): + print('Succeeded') + return + else: + print('Failed') + + self.die('ERROR: virtualenv not found.\n\n%s development' + ' requires virtualenv, please install it using your' + ' favorite package management tool' % self.project) + + +class Fedora(Distro): + """This covers all Fedora-based distributions. + + Includes: Fedora, RHEL, CentOS, Scientific Linux + """ + + def check_pkg(self, pkg): + return self.run_command_with_code(['rpm', '-q', pkg], + check_exit_code=False)[1] == 0 + + def install_virtualenv(self): + if self.check_cmd('virtualenv'): + return + + if not self.check_pkg('python-virtualenv'): + self.die("Please install 'python-virtualenv'.") + + super(Fedora, self).install_virtualenv() diff --git a/tools/pretty_tox.sh b/tools/pretty_tox.sh new file mode 100644 index 0000000..6c4759b --- /dev/null +++ b/tools/pretty_tox.sh @@ -0,0 +1,6 @@ +#! /bin/sh + +TESTRARGS=$1 + +exec 3>&1 +status=$(exec 4>&1 >&3; ( python setup.py testr --slowest --testr-args="--subunit $TESTRARGS"; echo $? >&4 ) | subunit2junitxml --output-to=junitxml-result.xml) && exit $status diff --git a/tools/subunit-trace.py b/tools/subunit-trace.py new file mode 100755 index 0000000..73f2f10 --- /dev/null +++ b/tools/subunit-trace.py @@ -0,0 +1,307 @@ +#!/usr/bin/env python + +# Copyright 2014 Hewlett-Packard Development Company, L.P. +# Copyright 2014 Samsung Electronics +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +"""Trace a subunit stream in reasonable detail and high accuracy.""" + +import argparse +import functools +import os +import re +import sys + +import mimeparse +import subunit +import testtools + +DAY_SECONDS = 60 * 60 * 24 +FAILS = [] +RESULTS = {} + + +class Starts(testtools.StreamResult): + + def __init__(self, output): + super(Starts, self).__init__() + self._output = output + + def startTestRun(self): + self._neednewline = False + self._emitted = set() + + def status(self, test_id=None, test_status=None, test_tags=None, + runnable=True, file_name=None, file_bytes=None, eof=False, + mime_type=None, route_code=None, timestamp=None): + super(Starts, self).status( + test_id, test_status, + test_tags=test_tags, runnable=runnable, file_name=file_name, + file_bytes=file_bytes, eof=eof, mime_type=mime_type, + route_code=route_code, timestamp=timestamp) + if not test_id: + if not file_bytes: + return + if not mime_type or mime_type == 'test/plain;charset=utf8': + mime_type = 'text/plain; charset=utf-8' + primary, sub, parameters = mimeparse.parse_mime_type(mime_type) + content_type = testtools.content_type.ContentType( + primary, sub, parameters) + content = testtools.content.Content( + content_type, lambda: [file_bytes]) + text = content.as_text() + if text and text[-1] not in '\r\n': + self._neednewline = True + self._output.write(text) + elif test_status == 'inprogress' and test_id not in self._emitted: + if self._neednewline: + self._neednewline = False + self._output.write('\n') + worker = '' + for tag in test_tags or (): + if tag.startswith('worker-'): + worker = '(' + tag[7:] + ') ' + if timestamp: + timestr = timestamp.isoformat() + else: + timestr = '' + self._output.write('%s: %s%s [start]\n' % + (timestr, worker, test_id)) + self._emitted.add(test_id) + + +def cleanup_test_name(name, strip_tags=True, strip_scenarios=False): + """Clean up the test name for display. + + By default we strip out the tags in the test because they don't help us + in identifying the test that is run to it's result. + + Make it possible to strip out the testscenarios information (not to + be confused with tempest scenarios) however that's often needed to + indentify generated negative tests. + """ + if strip_tags: + tags_start = name.find('[') + tags_end = name.find(']') + if tags_start > 0 and tags_end > tags_start: + newname = name[:tags_start] + newname += name[tags_end + 1:] + name = newname + + if strip_scenarios: + tags_start = name.find('(') + tags_end = name.find(')') + if tags_start > 0 and tags_end > tags_start: + newname = name[:tags_start] + newname += name[tags_end + 1:] + name = newname + + return name + + +def get_duration(timestamps): + start, end = timestamps + if not start or not end: + duration = '' + else: + delta = end - start + duration = '%d.%06ds' % ( + delta.days * DAY_SECONDS + delta.seconds, delta.microseconds) + return duration + + +def find_worker(test): + for tag in test['tags']: + if tag.startswith('worker-'): + return int(tag[7:]) + return 'NaN' + + +# Print out stdout/stderr if it exists, always +def print_attachments(stream, test, all_channels=False): + """Print out subunit attachments. + + Print out subunit attachments that contain content. This + runs in 2 modes, one for successes where we print out just stdout + and stderr, and an override that dumps all the attachments. 
+ """ + channels = ('stdout', 'stderr') + for name, detail in test['details'].items(): + # NOTE(sdague): the subunit names are a little crazy, and actually + # are in the form pythonlogging:'' (with the colon and quotes) + name = name.split(':')[0] + if detail.content_type.type == 'test': + detail.content_type.type = 'text' + if (all_channels or name in channels) and detail.as_text(): + title = "Captured %s:" % name + stream.write("\n%s\n%s\n" % (title, ('~' * len(title)))) + # indent attachment lines 4 spaces to make them visually + # offset + for line in detail.as_text().split('\n'): + stream.write(" %s\n" % line) + + +def show_outcome(stream, test, print_failures=False, failonly=False): + global RESULTS + status = test['status'] + # TODO(sdague): ask lifeless why on this? + if status == 'exists': + return + + worker = find_worker(test) + name = cleanup_test_name(test['id']) + duration = get_duration(test['timestamps']) + + if worker not in RESULTS: + RESULTS[worker] = [] + RESULTS[worker].append(test) + + # don't count the end of the return code as a fail + if name == 'process-returncode': + return + + if status == 'fail': + FAILS.append(test) + stream.write('{%s} %s [%s] ... FAILED\n' % ( + worker, name, duration)) + if not print_failures: + print_attachments(stream, test, all_channels=True) + elif not failonly: + if status == 'success': + stream.write('{%s} %s [%s] ... ok\n' % ( + worker, name, duration)) + print_attachments(stream, test) + elif status == 'skip': + stream.write('{%s} %s ... SKIPPED: %s\n' % ( + worker, name, test['details']['reason'].as_text())) + else: + stream.write('{%s} %s [%s] ... %s\n' % ( + worker, name, duration, test['status'])) + if not print_failures: + print_attachments(stream, test, all_channels=True) + + stream.flush() + + +def print_fails(stream): + """Print summary failure report. + + Currently unused, however there remains debate on inline vs. at end + reporting, so leave the utility function for later use. + """ + if not FAILS: + return + stream.write("\n==============================\n") + stream.write("Failed %s tests - output below:" % len(FAILS)) + stream.write("\n==============================\n") + for f in FAILS: + stream.write("\n%s\n" % f['id']) + stream.write("%s\n" % ('-' * len(f['id']))) + print_attachments(stream, f, all_channels=True) + stream.write('\n') + + +def count_tests(key, value): + count = 0 + for k, v in RESULTS.items(): + for item in v: + if key in item: + if re.search(value, item[key]): + count += 1 + return count + + +def run_time(): + runtime = 0.0 + for k, v in RESULTS.items(): + for test in v: + runtime += float(get_duration(test['timestamps']).strip('s')) + return runtime + + +def worker_stats(worker): + tests = RESULTS[worker] + num_tests = len(tests) + delta = tests[-1]['timestamps'][1] - tests[0]['timestamps'][0] + return num_tests, delta + + +def print_summary(stream): + stream.write("\n======\nTotals\n======\n") + stream.write("Run: %s in %s sec.\n" % (count_tests('status', '.*'), + run_time())) + stream.write(" - Passed: %s\n" % count_tests('status', 'success')) + stream.write(" - Skipped: %s\n" % count_tests('status', 'skip')) + stream.write(" - Failed: %s\n" % count_tests('status', 'fail')) + + # we could have no results, especially as we filter out the process-codes + if RESULTS: + stream.write("\n==============\nWorker Balance\n==============\n") + + for w in range(max(RESULTS.keys()) + 1): + if w not in RESULTS: + stream.write( + " - WARNING: missing Worker %s! 
" + "Race in testr accounting.\n" % w) + else: + num, time = worker_stats(w) + stream.write(" - Worker %s (%s tests) => %ss\n" % + (w, num, time)) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--no-failure-debug', '-n', action='store_true', + dest='print_failures', help='Disable printing failure ' + 'debug information in realtime') + parser.add_argument('--fails', '-f', action='store_true', + dest='post_fails', help='Print failure debug ' + 'information after the stream is proccesed') + parser.add_argument('--failonly', action='store_true', + dest='failonly', help="Don't print success items", + default=( + os.environ.get('TRACE_FAILONLY', False) + is not False)) + return parser.parse_args() + + +def main(): + args = parse_args() + stream = subunit.ByteStreamToStreamResult( + sys.stdin, non_subunit_name='stdout') + starts = Starts(sys.stdout) + outcomes = testtools.StreamToDict( + functools.partial(show_outcome, sys.stdout, + print_failures=args.print_failures, + failonly=args.failonly + )) + summary = testtools.StreamSummary() + result = testtools.CopyStreamResult([starts, outcomes, summary]) + result.startTestRun() + try: + stream.run(result) + finally: + result.stopTestRun() + if count_tests('status', '.*') == 0: + print("The test run didn't actually run any tests") + return 1 + if args.post_fails: + print_fails(sys.stdout) + print_summary(sys.stdout) + return (0 if summary.wasSuccessful() else 1) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/tox.ini b/tox.ini index 9be310a..38d3390 100644 --- a/tox.ini +++ b/tox.ini @@ -10,7 +10,6 @@ setenv = VIRTUAL_ENV={envdir} deps = -r{toxinidir}/requirements.txt -r{toxinidir}/test-requirements.txt -commands = python setup.py testr --slowest --testr-args='{posargs}' [testenv:pep8] commands = flake8 @@ -19,16 +18,33 @@ commands = flake8 commands = {posargs} [testenv:cover] -commands = python setup.py testr --coverage --testr-args='{posargs}' +commands = + nosetests --with-xunit --with-xcoverage --cover-package=cerberus --nocapture --cover-tests --cover-branches --cover-min-percentage=50 [testenv:docs] commands = python setup.py build_sphinx [flake8] -# H803 skipped on purpose per list discussion. -# E123, E125 skipped as they are invalid PEP-8. 
-
-show-source = True
-ignore = E123,E125,H803
+# E125 continuation line does not distinguish itself from next logical line
+# E126 continuation line over-indented for hanging indent
+# E128 continuation line under-indented for visual indent
+# E129 visually indented line with same indent as next logical line
+# E265 block comment should start with '# '
+# E713 test for membership should be 'not in'
+# F402 import module shadowed by loop variable
+# F811 redefinition of unused variable
+# F812 list comprehension redefines name from line
+# H104 file contains nothing but comments
+# H237 module is removed in Python 3
+# H305 imports not grouped correctly
+# H307 like imports should be grouped together
+# H401 docstring should not start with a space
+# H402 one line docstring needs punctuation
+# H405 multi line docstring summary not separated with an empty line
+# H904 Wrap long lines in parentheses instead of a backslash
+# TODO(marun) H404 multi line docstring should start with a summary
+ignore = E125,E126,E128,E129,E265,E713,F402,F811,F812,H104,H237,H305,H307,H401,H402,H404,H405,H904
+show-source = true
 builtins = _
-exclude=.venv,.git,.tox,dist,doc,*openstack/common*,*lib/python*,*egg,build
\ No newline at end of file
+exclude = .venv,.git,.tox,dist,doc,*openstack/common*,*lib/python*,*egg,build,tools
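
Finally, the cerberus.plugins entry points registered in setup.cfg earlier are what make the agent's plugins discoverable at runtime. A quick sketch of that discovery using pkg_resources (whether cerberus loads them through pkg_resources directly or via a helper such as stevedore is an assumption here):

    import pkg_resources

    # List the plugin classes registered under the 'cerberus.plugins'
    # entry-point group (testplugin, openvasplugin, taskplugin).
    for ep in pkg_resources.iter_entry_points('cerberus.plugins'):
        plugin_cls = ep.load()
        print('%s -> %s.%s' % (ep.name, plugin_cls.__module__,
                               plugin_cls.__name__))
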