Ensure domain stored in memcached gets utf8 decoded on py2

Change-Id: I73b5af9645f3f7349144384609bf18a79620e92f
Closes-Bug: #1862115
This commit is contained in:
Charles Hsu 2020-02-06 15:35:11 +08:00 committed by Tim Burke
parent c0b4d644df
commit 61bf5ee1c4
2 changed files with 20 additions and 6 deletions

View File

@ -27,7 +27,7 @@ maximum lookup depth. If a match is found, the environment's Host header is
rewritten and the request is passed further down the WSGI chain. rewritten and the request is passed further down the WSGI chain.
""" """
from six.moves import range import six
from swift import gettext_ as _ from swift import gettext_ as _
@ -41,7 +41,8 @@ else: # executed if the try block finishes with no errors
MODULE_DEPENDENCY_MET = True MODULE_DEPENDENCY_MET = True
from swift.common.middleware import RewriteContext from swift.common.middleware import RewriteContext
from swift.common.swob import Request, HTTPBadRequest from swift.common.swob import Request, HTTPBadRequest, \
str_to_wsgi, wsgi_to_str
from swift.common.utils import cache_from_env, get_logger, is_valid_ip, \ from swift.common.utils import cache_from_env, get_logger, is_valid_ip, \
list_from_csv, parse_socket_string, register_swift_info list_from_csv, parse_socket_string, register_swift_info
@ -130,9 +131,10 @@ class CNAMELookupMiddleware(object):
if not self.storage_domain: if not self.storage_domain:
return self.app(env, start_response) return self.app(env, start_response)
if 'HTTP_HOST' in env: if 'HTTP_HOST' in env:
requested_host = given_domain = env['HTTP_HOST'] requested_host = env['HTTP_HOST']
else: else:
requested_host = given_domain = env['SERVER_NAME'] requested_host = env['SERVER_NAME']
given_domain = wsgi_to_str(requested_host)
port = '' port = ''
if ':' in given_domain: if ':' in given_domain:
given_domain, port = given_domain.rsplit(':', 1) given_domain, port = given_domain.rsplit(':', 1)
@ -148,6 +150,8 @@ class CNAMELookupMiddleware(object):
if self.memcache: if self.memcache:
memcache_key = ''.join(['cname-', a_domain]) memcache_key = ''.join(['cname-', a_domain])
found_domain = self.memcache.get(memcache_key) found_domain = self.memcache.get(memcache_key)
if six.PY2 and found_domain:
found_domain = found_domain.encode('utf-8')
if found_domain is None: if found_domain is None:
ttl, found_domain = lookup_cname(a_domain, self.resolver) ttl, found_domain = lookup_cname(a_domain, self.resolver)
if self.memcache and ttl > 0: if self.memcache and ttl > 0:
@ -166,9 +170,10 @@ class CNAMELookupMiddleware(object):
{'given_domain': given_domain, {'given_domain': given_domain,
'found_domain': found_domain}) 'found_domain': found_domain})
if port: if port:
env['HTTP_HOST'] = ':'.join([found_domain, port]) env['HTTP_HOST'] = ':'.join([
str_to_wsgi(found_domain), port])
else: else:
env['HTTP_HOST'] = found_domain env['HTTP_HOST'] = str_to_wsgi(found_domain)
error = False error = False
break break
else: else:

View File

@ -170,6 +170,10 @@ class TestCNAMELookup(unittest.TestCase):
return self.cache.get(key, None) return self.cache.get(key, None)
def set(self, key, value, *a, **kw): def set(self, key, value, *a, **kw):
# real memcache client will JSON-serialize, so our mock
# should be sure to return unicode
if isinstance(value, bytes):
value = value.decode('utf-8')
self.cache[key] = value self.cache[key] = value
module = 'swift.common.middleware.cname_lookup.lookup_cname' module = 'swift.common.middleware.cname_lookup.lookup_cname'
@ -186,6 +190,9 @@ class TestCNAMELookup(unittest.TestCase):
self.assertEqual(m.call_count, 1) self.assertEqual(m.call_count, 1)
self.assertEqual(memcache.cache.get('cname-mysite2.com'), self.assertEqual(memcache.cache.get('cname-mysite2.com'),
'c.example.com') 'c.example.com')
self.assertIsInstance(req.environ['HTTP_HOST'], str)
self.assertEqual(req.environ['HTTP_HOST'], 'c.example.com')
req = Request.blank('/', environ={'REQUEST_METHOD': 'GET', req = Request.blank('/', environ={'REQUEST_METHOD': 'GET',
'swift.cache': memcache}, 'swift.cache': memcache},
headers={'Host': 'mysite2.com'}) headers={'Host': 'mysite2.com'})
@ -194,6 +201,8 @@ class TestCNAMELookup(unittest.TestCase):
self.assertEqual(m.call_count, 1) self.assertEqual(m.call_count, 1)
self.assertEqual(memcache.cache.get('cname-mysite2.com'), self.assertEqual(memcache.cache.get('cname-mysite2.com'),
'c.example.com') 'c.example.com')
self.assertIsInstance(req.environ['HTTP_HOST'], str)
self.assertEqual(req.environ['HTTP_HOST'], 'c.example.com')
for exc, num in ((dns.resolver.NXDOMAIN(), 3), for exc, num in ((dns.resolver.NXDOMAIN(), 3),
(dns.resolver.NoAnswer(), 4)): (dns.resolver.NoAnswer(), 4)):