Correct query loss when using parse_qsl to dict
This commit is contained in:
parent
4c7b3be5a1
commit
ebe9ed0bbb
|
@ -179,6 +179,54 @@ def string_to_scopes(scopes):
|
|||
return scopes
|
||||
|
||||
|
||||
def parse_unique_urlencoded(content):
|
||||
"""Parses unique key-value parameters from urlencoded content.
|
||||
|
||||
Args:
|
||||
content: string, URL-encoded key-value pairs.
|
||||
|
||||
Returns:
|
||||
dict, The key-value pairs from ``content``.
|
||||
|
||||
Raises:
|
||||
ValueError: if one of the keys is repeated.
|
||||
"""
|
||||
urlencoded_params = urllib.parse.parse_qs(content)
|
||||
params = {}
|
||||
for key, value in six.iteritems(urlencoded_params):
|
||||
if len(value) != 1:
|
||||
msg = ('URL-encoded content contains a repeated value:'
|
||||
'%s -> %s' % (key, ', '.join(value)))
|
||||
raise ValueError(msg)
|
||||
params[key] = value[0]
|
||||
return params
|
||||
|
||||
|
||||
def update_query_params(uri, params):
|
||||
"""Updates a URI with new query parameters.
|
||||
|
||||
If a given key from ``params`` is repeated in the ``uri``, then
|
||||
the URI will be considered invalid and an error will occur.
|
||||
|
||||
If the URI is valid, then each value from ``params`` will
|
||||
replace the corresponding value in the query parameters (if
|
||||
it exists).
|
||||
|
||||
Args:
|
||||
uri: string, A valid URI, with potential existing query parameters.
|
||||
params: dict, A dictionary of query parameters.
|
||||
|
||||
Returns:
|
||||
The same URI but with the new query parameters added.
|
||||
"""
|
||||
parts = urllib.parse.urlparse(uri)
|
||||
query_params = parse_unique_urlencoded(parts.query)
|
||||
query_params.update(params)
|
||||
new_query = urllib.parse.urlencode(query_params)
|
||||
new_parts = parts._replace(query=new_query)
|
||||
return urllib.parse.urlunparse(new_parts)
|
||||
|
||||
|
||||
def _add_query_parameter(url, name, value):
|
||||
"""Adds a query parameter to a url.
|
||||
|
||||
|
@ -195,11 +243,7 @@ def _add_query_parameter(url, name, value):
|
|||
if value is None:
|
||||
return url
|
||||
else:
|
||||
parsed = list(urllib.parse.urlparse(url))
|
||||
query = dict(urllib.parse.parse_qsl(parsed[4]))
|
||||
query[name] = value
|
||||
parsed[4] = urllib.parse.urlencode(query)
|
||||
return urllib.parse.urlunparse(parsed)
|
||||
return update_query_params(url, {name: value})
|
||||
|
||||
|
||||
def validate_file(filename):
|
||||
|
|
|
@ -438,23 +438,6 @@ class Storage(object):
|
|||
self.release_lock()
|
||||
|
||||
|
||||
def _update_query_params(uri, params):
|
||||
"""Updates a URI with new query parameters.
|
||||
|
||||
Args:
|
||||
uri: string, A valid URI, with potential existing query parameters.
|
||||
params: dict, A dictionary of query parameters.
|
||||
|
||||
Returns:
|
||||
The same URI but with the new query parameters added.
|
||||
"""
|
||||
parts = urllib.parse.urlparse(uri)
|
||||
query_params = dict(urllib.parse.parse_qsl(parts.query))
|
||||
query_params.update(params)
|
||||
new_parts = parts._replace(query=urllib.parse.urlencode(query_params))
|
||||
return urllib.parse.urlunparse(new_parts)
|
||||
|
||||
|
||||
class OAuth2Credentials(Credentials):
|
||||
"""Credentials object for OAuth 2.0.
|
||||
|
||||
|
@ -850,7 +833,8 @@ class OAuth2Credentials(Credentials):
|
|||
"""
|
||||
logger.info('Revoking token')
|
||||
query_params = {'token': token}
|
||||
token_revoke_uri = _update_query_params(self.revoke_uri, query_params)
|
||||
token_revoke_uri = _helpers.update_query_params(
|
||||
self.revoke_uri, query_params)
|
||||
resp, content = transport.request(http, token_revoke_uri)
|
||||
if resp.status == http_client.OK:
|
||||
self.invalid = True
|
||||
|
@ -889,8 +873,8 @@ class OAuth2Credentials(Credentials):
|
|||
"""
|
||||
logger.info('Refreshing scopes')
|
||||
query_params = {'access_token': token, 'fields': 'scope'}
|
||||
token_info_uri = _update_query_params(self.token_info_uri,
|
||||
query_params)
|
||||
token_info_uri = _helpers.update_query_params(
|
||||
self.token_info_uri, query_params)
|
||||
resp, content = transport.request(http, token_info_uri)
|
||||
content = _helpers._from_bytes(content)
|
||||
if resp.status == http_client.OK:
|
||||
|
@ -1610,7 +1594,7 @@ def _parse_exchange_token_response(content):
|
|||
except Exception:
|
||||
# different JSON libs raise different exceptions,
|
||||
# so we just do a catch-all here
|
||||
resp = dict(urllib.parse.parse_qsl(content))
|
||||
resp = _helpers.parse_unique_urlencoded(content)
|
||||
|
||||
# some providers respond with 'expires', others with 'expires_in'
|
||||
if resp and 'expires' in resp:
|
||||
|
@ -1943,7 +1927,7 @@ class OAuth2WebServerFlow(Flow):
|
|||
query_params['code_challenge_method'] = 'S256'
|
||||
|
||||
query_params.update(self.params)
|
||||
return _update_query_params(self.auth_uri, query_params)
|
||||
return _helpers.update_query_params(self.auth_uri, query_params)
|
||||
|
||||
@_helpers.positional(1)
|
||||
def step1_get_device_and_user_codes(self, http=None):
|
||||
|
|
|
@ -122,16 +122,16 @@ class ClientRedirectHandler(BaseHTTPServer.BaseHTTPRequestHandler):
|
|||
if an error occurred.
|
||||
"""
|
||||
self.send_response(http_client.OK)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.send_header('Content-type', 'text/html')
|
||||
self.end_headers()
|
||||
query = self.path.split('?', 1)[-1]
|
||||
query = dict(urllib.parse.parse_qsl(query))
|
||||
parts = urllib.parse.urlparse(self.path)
|
||||
query = _helpers.parse_unique_urlencoded(parts.query)
|
||||
self.server.query_params = query
|
||||
self.wfile.write(
|
||||
b"<html><head><title>Authentication Status</title></head>")
|
||||
b'<html><head><title>Authentication Status</title></head>')
|
||||
self.wfile.write(
|
||||
b"<body><p>The authentication flow has completed.</p>")
|
||||
self.wfile.write(b"</body></html>")
|
||||
b'<body><p>The authentication flow has completed.</p>')
|
||||
self.wfile.write(b'</body></html>')
|
||||
|
||||
def log_message(self, format, *args):
|
||||
"""Do not log messages to stdout while running as cmd. line program."""
|
||||
|
|
|
@ -19,6 +19,7 @@ import unittest
|
|||
import mock
|
||||
|
||||
from oauth2client import _helpers
|
||||
from tests import test_client
|
||||
|
||||
|
||||
class PositionalTests(unittest.TestCase):
|
||||
|
@ -242,3 +243,42 @@ class Test__urlsafe_b64decode(unittest.TestCase):
|
|||
bad_string = b'+'
|
||||
with self.assertRaises((TypeError, binascii.Error)):
|
||||
_helpers._urlsafe_b64decode(bad_string)
|
||||
|
||||
|
||||
class Test_update_query_params(unittest.TestCase):
|
||||
|
||||
def test_update_query_params_no_params(self):
|
||||
uri = 'http://www.google.com'
|
||||
updated = _helpers.update_query_params(uri, {'a': 'b'})
|
||||
self.assertEqual(updated, uri + '?a=b')
|
||||
|
||||
def test_update_query_params_existing_params(self):
|
||||
uri = 'http://www.google.com?x=y'
|
||||
updated = _helpers.update_query_params(uri, {'a': 'b', 'c': 'd&'})
|
||||
hardcoded_update = uri + '&a=b&c=d%26'
|
||||
test_client.assertUrisEqual(self, updated, hardcoded_update)
|
||||
|
||||
def test_update_query_params_replace_param(self):
|
||||
base_uri = 'http://www.google.com'
|
||||
uri = base_uri + '?x=a'
|
||||
updated = _helpers.update_query_params(uri, {'x': 'b', 'y': 'c'})
|
||||
hardcoded_update = base_uri + '?x=b&y=c'
|
||||
test_client.assertUrisEqual(self, updated, hardcoded_update)
|
||||
|
||||
def test_update_query_params_repeated_params(self):
|
||||
uri = 'http://www.google.com?x=a&x=b'
|
||||
with self.assertRaises(ValueError):
|
||||
_helpers.update_query_params(uri, {'a': 'c'})
|
||||
|
||||
|
||||
class Test_parse_unique_urlencoded(unittest.TestCase):
|
||||
|
||||
def test_without_repeats(self):
|
||||
content = 'a=b&c=d'
|
||||
result = _helpers.parse_unique_urlencoded(content)
|
||||
self.assertEqual(result, {'a': 'b', 'c': 'd'})
|
||||
|
||||
def test_with_repeats(self):
|
||||
content = 'a=b&a=d'
|
||||
with self.assertRaises(ValueError):
|
||||
_helpers.parse_unique_urlencoded(content)
|
||||
|
|
|
@ -1364,7 +1364,7 @@ class BasicCredentialsTests(unittest.TestCase):
|
|||
self.assertEqual(credentials.scopes, set())
|
||||
self.assertEqual(exc_manager.exception.args, (error_msg,))
|
||||
|
||||
token_uri = client._update_query_params(
|
||||
token_uri = _helpers.update_query_params(
|
||||
oauth2client.GOOGLE_TOKEN_INFO_URI,
|
||||
{'fields': 'scope', 'access_token': token})
|
||||
|
||||
|
@ -1558,19 +1558,6 @@ class TestAssertionCredentials(unittest.TestCase):
|
|||
credentials.sign_blob(b'blob')
|
||||
|
||||
|
||||
class UpdateQueryParamsTest(unittest.TestCase):
|
||||
def test_update_query_params_no_params(self):
|
||||
uri = 'http://www.google.com'
|
||||
updated = client._update_query_params(uri, {'a': 'b'})
|
||||
self.assertEqual(updated, uri + '?a=b')
|
||||
|
||||
def test_update_query_params_existing_params(self):
|
||||
uri = 'http://www.google.com?x=y'
|
||||
updated = client._update_query_params(uri, {'a': 'b', 'c': 'd&'})
|
||||
hardcoded_update = uri + '&a=b&c=d%26'
|
||||
assertUrisEqual(self, updated, hardcoded_update)
|
||||
|
||||
|
||||
class ExtractIdTokenTest(unittest.TestCase):
|
||||
"""Tests client._extract_id_token()."""
|
||||
|
||||
|
@ -1670,7 +1657,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
|
|||
'access_type': 'offline',
|
||||
'response_type': 'code',
|
||||
}
|
||||
expected = client._update_query_params(flow.auth_uri, query_params)
|
||||
expected = _helpers.update_query_params(flow.auth_uri, query_params)
|
||||
assertUrisEqual(self, expected, result)
|
||||
# Check stubs.
|
||||
self.assertEqual(logger.warning.call_count, 1)
|
||||
|
@ -1735,7 +1722,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
|
|||
'access_type': 'offline',
|
||||
'response_type': 'code',
|
||||
}
|
||||
expected = client._update_query_params(flow.auth_uri, query_params)
|
||||
expected = _helpers.update_query_params(flow.auth_uri, query_params)
|
||||
assertUrisEqual(self, expected, result)
|
||||
|
||||
def test_step1_get_device_and_user_codes_wo_device_uri(self):
|
||||
|
|
Loading…
Reference in New Issue