Rework Auth Plugins to Support HTTP Auth

This commit reworks auth plugins slightly to enable
support for HTTP authentication.  By raising an
AuthenticationError, auth plugins can now return
HTTP responses to the upgrade request (such as 401).

Related to kanaka/noVNC#522
This commit is contained in:
Solly Ross 2015-08-25 16:44:24 -04:00
parent 1e894f0d29
commit 997e2151b3
4 changed files with 88 additions and 21 deletions

View File

@ -106,11 +106,11 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
def lookup(self, token):
return (self.source + token).split(',')
self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy',
lambda *args, **kwargs: None)
self.stubs.Set(websocketproxy.ProxyRequestHandler, 'send_auth_error',
staticmethod(lambda *args, **kwargs: None))
self.handler.server.token_plugin = TestPlugin("somehost,")
self.handler.new_websocket_client()
self.handler.validate_connection()
self.assertEqual(self.handler.server.target_host, "somehost")
self.assertEqual(self.handler.server.target_port, "blah")
@ -119,9 +119,9 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
class TestPlugin(auth_plugins.BasePlugin):
def authenticate(self, headers, target_host, target_port):
if target_host == self.source:
raise auth_plugins.AuthenticationError("some error")
raise auth_plugins.AuthenticationError(response_msg="some_error")
self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy',
self.stubs.Set(websocketproxy.ProxyRequestHandler, 'send_auth_error',
staticmethod(lambda *args, **kwargs: None))
self.handler.server.auth_plugin = TestPlugin("somehost")
@ -129,8 +129,8 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
self.handler.server.target_port = "someport"
self.assertRaises(auth_plugins.AuthenticationError,
self.handler.new_websocket_client)
self.handler.validate_connection)
self.handler.server.target_host = "someotherhost"
self.handler.new_websocket_client()
self.handler.validate_connection()

View File

@ -7,7 +7,15 @@ class BasePlugin(object):
class AuthenticationError(Exception):
pass
def __init__(self, log_msg=None, response_code=403, response_headers={}, response_msg=None):
self.code = response_code
self.headers = response_headers
self.msg = response_msg
if log_msg is None:
log_msg = response_msg
super(AuthenticationError, self).__init__('%s %s' % (self.code, log_msg))
class InvalidOriginError(AuthenticationError):
@ -16,8 +24,44 @@ class InvalidOriginError(AuthenticationError):
self.actual_origin = actual
super(InvalidOriginError, self).__init__(
"Invalid Origin Header: Expected one of "
"%s, got '%s'" % (expected, actual))
response_msg='Invalid Origin',
log_msg="Invalid Origin Header: Expected one of "
"%s, got '%s'" % (expected, actual))
class BasicHTTPAuth(object):
def __init__(self, src=None):
self.src = src
def authenticate(self, headers, target_host, target_port):
import base64
auth_header = headers.get('Authorization')
if auth_header:
if not auth_header.startswith('Basic '):
raise AuthenticationError(response_code=403)
try:
user_pass_raw = base64.b64decode(auth_header[6:])
except TypeError:
raise AuthenticationError(response_code=403)
user_pass = user_pass_raw.split(':', 1)
if len(user_pass) != 2:
raise AuthenticationError(response_code=403)
if not self.validate_creds:
raise AuthenticationError(response_code=403)
else:
raise AuthenticationError(response_code=401,
response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'})
def validate_creds(username, password):
if '%s:%s' % (username, password) == self.src:
return True
else:
return False
class ExpectOrigin(object):
def __init__(self, src=None):

View File

@ -474,9 +474,13 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
"""Upgrade a connection to Websocket, if requested. If this succeeds,
new_websocket_client() will be called. Otherwise, False is returned.
"""
if (self.headers.get('upgrade') and
self.headers.get('upgrade').lower() == 'websocket'):
# ensure connection is authorized, and determine the target
self.validate_connection()
if not self.do_websocket_handshake():
return False
@ -549,6 +553,10 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
""" Do something with a WebSockets client connection. """
raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded")
def validate_connection(self):
""" Ensure that the connection is a valid connection, and set the target. """
pass
def do_HEAD(self):
if self.only_upgrade:
self.send_error(405, "Method Not Allowed")
@ -789,7 +797,7 @@ class WebSocketServer(object):
"""
ready = select.select([sock], [], [], 3)[0]
if not ready:
raise self.EClose("ignoring socket not ready")
# Peek, but do not read the data so that we have a opportunity
@ -903,7 +911,7 @@ class WebSocketServer(object):
def top_new_client(self, startsock, address):
""" Do something with a WebSockets client connection. """
# handler process
# handler process
client = None
try:
try:

View File

@ -18,6 +18,7 @@ try: from http.server import HTTPServer
except: from BaseHTTPServer import HTTPServer
import select
from websockify import websocket
from websockify import auth_plugins as auth
try:
from urllib.parse import parse_qs, urlparse
except:
@ -37,20 +38,34 @@ Traffic Legend:
< - Client send
<. - Client send partial
"""
def send_auth_error(self, ex):
self.send_response(ex.code, ex.msg)
self.send_header('Content-Type', 'text/html')
for name, val in ex.headers.items():
self.send_header(name, val)
self.end_headers()
def validate_connection(self):
if self.server.token_plugin:
(self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path)
if self.server.auth_plugin:
try:
self.server.auth_plugin.authenticate(
headers=self.headers, target_host=self.server.target_host,
target_port=self.server.target_port)
except auth.AuthenticationError:
ex = sys.exc_info()[1]
self.send_auth_error(ex)
raise
def new_websocket_client(self):
"""
Called after a new WebSocket connection has been established.
"""
# Checks if we receive a token, and look
# for a valid target for it then
if self.server.token_plugin:
(self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path)
if self.server.auth_plugin:
self.server.auth_plugin.authenticate(
headers=self.headers, target_host=self.server.target_host,
target_port=self.server.target_port)
# Checking for a token is done in validate_connection()
# Connect to the target
if self.server.wrap_cmd: