# Copyright 2016 Massachusetts Open Cloud # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. import collections import six import requests from urllib3.util import retry import flask from flask import abort from mixmatch import config from mixmatch.config import LOG, CONF, service_providers from mixmatch.session import app from mixmatch.session import chunked_reader from mixmatch.session import request from mixmatch import auth from mixmatch import extend from mixmatch import model from mixmatch import services from mixmatch import utils METHODS_ACCEPTED = ['GET', 'POST', 'PUT', 'DELETE', 'HEAD', 'PATCH'] RESOURCES_AGGREGATE = ['images', 'volumes', 'snapshots'] def stream_response(response): yield response.raw.read() def get_service(a): """Determine service type based on path.""" # NOTE(knikolla): Workaround to fix glance requests that do not # contain image as the first part of the path. if a[0] in ['v1', 'v2']: return 'image' else: service = a.pop(0) if service in ['image', 'volume', 'network']: return service else: abort(404) def get_details(method, orig_path, headers): """Get details for a request.""" path = orig_path.split('/') # NOTE(knikolla): Request usually look like: # ///// # or # //// return {'method': method, 'service': get_service(path), 'version': utils.safe_pop(path), 'project_id': utils.pop_if_uuid(path), 'action': path[:], # NOTE(knikolla): This includes 'resource_type': utils.safe_pop(path), # this 'resource_id': utils.pop_if_uuid(path), # and this 'token': headers.get('X-AUTH-TOKEN', None), 'headers': dict(headers), 'path': orig_path} def is_json_response(response): return response.headers.get('Content-Type') == 'application/json' def is_token_header_key(string): return string.lower() in ['x-auth-token', 'x-service-token'] def strip_tokens_from_headers(headers): return {k: '' if is_token_header_key(k) else headers[k] for k in headers} def format_for_log(title=None, method=None, url=None, headers=None, status_code=None): return ''.join([ '{}:\n'.format(title) if title else '', 'Method: {}\n'.format(method) if method else '', 'URL: {}\n'.format(url) if url else '', 'Headers: {}\n'.format( strip_tokens_from_headers(headers)) if headers else '', 'Status Code: {}\n'.format(status_code) if status_code else '' ]) class RequestHandler(object): def __init__(self, method, path, headers): self.details = get_details(method, path, headers) self.extensions = extend.get_matched_extensions(self.details) self._set_strip_details(self.details) self.enabled_sps = filter( lambda sp: (self.details['service'] in service_providers.get(CONF, sp).enabled_services), CONF.service_providers ) for extension in self.extensions: extension.handle_request(self.details) if not self.details['version']: if CONF.aggregation: # unversioned calls with no action self._forward = self._list_api_versions else: self._forward = self._local_forward return if not self.details['resource_type']: # versioned calls with no action abort(400) mapping = None if self.details['resource_id']: mapping = model.ResourceMapping.find( resource_type=self.details['action'][0], resource_id=self.details['resource_id']) LOG.debug('Local Token: %s ' % self.details['token']) if 'MM-SERVICE-PROVIDER' in self.details['headers']: # The user wants a specific service provider, use that SP. # FIXME(knikolla): This isn't exercised by any unit test (self.service_provider, self.project_id) = ( self.details['headers']['MM-SERVICE-PROVIDER'], self.details['headers'].get('MM-PROJECT-ID', None) ) if self.service_provider not in self.enabled_sps: abort(400) if not self.project_id and self.service_provider != 'default': self.project_id = auth.get_projects_at_sp( self.service_provider, self.details['token'] )[0] self._forward = self._targeted_forward elif mapping: # Which we already know the location of, use that SP. self.service_provider = mapping.resource_sp self.project_id = mapping.project_id self._forward = self._targeted_forward else: self._forward = self._forward LOG.info(format_for_log(title="Request to proxy", method=self.details['method'], url=self.details['path'], headers=dict(self.details['headers']))) def _do_request_on(self, sp, project_id=None): headers = self._prepare_headers(self.details['headers']) if self.details['token']: if sp == 'default': auth_session = auth.get_local_auth(self.details['token']) else: auth_session = auth.get_sp_auth(sp, self.details['token'], project_id) headers['X-AUTH-TOKEN'] = auth_session.get_token() project_id = auth_session.get_project_id() else: project_id = None url = services.construct_url( sp, self.details['service'], self.details['version'], self.details['action'], project_id=project_id ) request_kwargs = { 'method': self.details['method'], 'url': url, 'headers': headers, 'params': self._prepare_args(request.args) } if self.chunked: resp = self.session.request(data=chunked_reader(), **request_kwargs) else: resp = self.session.request(data=request.data, stream=self.stream, **request_kwargs) LOG.info(format_for_log(title='Request from proxy', method=self.details['method'], url=url, headers=headers)) LOG.info(format_for_log(title='Response to proxy', status_code=resp.status_code, headers=resp.headers)) return resp def _finalize(self, response): if self.stream and not is_json_response(response): text = flask.stream_with_context( stream_response(response)) else: text = response.text final_response = flask.Response( text, response.status_code, headers=self._prepare_headers(response.headers) ) LOG.info(format_for_log(title='Response from proxy', status_code=final_response.status_code, headers=dict(final_response.headers))) return final_response def _local_forward(self): return self._finalize(self._do_request_on('default')) def _targeted_forward(self): return self._finalize( self._do_request_on(self.service_provider, self.project_id)) def _forward(self): if self.fallback_to_local: return self._local_forward() responses = {} errors = collections.defaultdict(lambda: []) for sp in self.enabled_sps: if sp == 'default': r = self._do_request_on('default') if 200 <= r.status_code < 300: responses['default'] = r if not self.aggregate: return self._finalize(r) else: errors[r.status_code].append(r) else: for p in auth.get_projects_at_sp(sp, self.details['token']): r = self._do_request_on(sp, p) if 200 <= r.status_code < 300: responses[(sp, p)] = r if not self.aggregate: return self._finalize(r) else: errors[r.status_code].append(r) # NOTE(knikolla): If we haven't returned yet, either we're aggregating # or there are errors. if not errors: # TODO(knikolla): Plug this into _finalize to have a common path # for everything that is returned. return flask.Response( services.aggregate(responses, self.details['action'][0], self.details['service'], version=self.details['version'], params=dict(request.args), path=request.base_url, strip_details=self.strip_details), 200, content_type='application/json' ) if six.viewkeys(errors) == {404}: return self._finalize(errors[404][0]) else: utils.safe_pop(errors, 404) if len(errors.keys()) == 1: return self._finalize(list(errors.values())[0][0]) # TODO(jfreud): log return flask.Response("Something strange happened.\n", 500) def _list_api_versions(self): return services.list_api_versions(self.details['service'], request.base_url) def forward(self): return self._forward() @staticmethod def _prepare_headers(user_headers): headers = dict() headers['Accept'] = user_headers.get('Accept', '') headers['Content-Type'] = user_headers.get('Content-Type', '') accepted_headers = ['openstack-api-version'] for key, value in user_headers.items(): k = key.lower() if ((k.startswith('x-') and not is_token_header_key(key)) or k in accepted_headers): headers[key] = value return headers @staticmethod def _prepare_args(user_args): """Prepare the GET arguments by removing the limit and marker. This is because the id of the marker will only be present in one of the service providers. """ args = dict(user_args) if CONF.aggregation: args.pop('limit', None) args.pop('marker', None) return args @utils.CachedProperty def chunked(self): encoding = self.details['headers'].get('Transfer-Encoding', '') return encoding.lower() == 'chunked' @utils.CachedProperty def stream(self): return self.details['method'] == 'GET' @utils.CachedProperty def fallback_to_local(self): """Return true if can't aggregate or search.""" return ((self.aggregate and not CONF.aggregation) or (not self.aggregate and not CONF.search_by_broadcast)) @utils.CachedProperty def aggregate(self): """Return true if this is a case where we should aggregate.""" return (not self.details['resource_id'] and self.details['method'] == 'GET' and self.details['action'][0] in RESOURCES_AGGREGATE) @utils.CachedProperty def session(self): requests_session = requests.Session() adapter = requests.adapters.HTTPAdapter( max_retries=retry.Retry(total=3, read=3, connect=3, backoff_factor=0.3, status_forcelist=[500, 502, 504]) ) requests_session.mount('http://', adapter=adapter) requests_session.mount('https://', adapter=adapter) return requests_session def _set_strip_details(self, details): # if request is to /volumes, change it # to /volumes/detail for aggregation # and set strip_details to true if (details['service'] == 'volume' and details['method'] == 'GET' and utils.safe_get(details['action'], -1) == 'volumes'): self.strip_details = True details['action'].insert(len(details['action']), 'detail') else: self.strip_details = False @app.route('/', defaults={'path': ''}, methods=METHODS_ACCEPTED) @app.route('/', methods=METHODS_ACCEPTED) def proxy(path): k2k_request = RequestHandler(request.method, path, request.headers) return k2k_request.forward() def main(): config.configure() model.create_tables() extend.load_extensions() if __name__ == "__main__": main() app.run(port=5001, threaded=True)