# -*- coding: utf-8 -*-
"""
wsproto/handshake
~~~~~~~~~~~~~~~~~~
An implementation of WebSocket handshakes.
"""
from collections import deque
import h11
from .connection import Connection, ConnectionState, ConnectionType
from .events import AcceptConnection, RejectConnection, RejectData, Request
from .utilities import (
generate_accept_token,
generate_nonce,
LocalProtocolError,
normed_header_dict,
RemoteProtocolError,
split_comma_header,
)
# The protocol version this module negotiates. RFC 6455, Section 4.2.1
# (item 6) requires the client's |Sec-WebSocket-Version| header to carry 13.
WEBSOCKET_VERSION = b"13"
class H11Handshake(object):
    """A Handshake implementation for HTTP/1.1 connections.

    One instance drives one side of the WebSocket opening handshake. As a
    client it serialises a :class:`Request` and interprets the server's
    response; as a server it validates an incoming request and answers it
    with :class:`AcceptConnection` or :class:`RejectConnection`. HTTP
    parsing and serialisation are delegated to an :mod:`h11` connection.
    Once the handshake succeeds, :attr:`connection` holds the
    :class:`Connection` to use for the WebSocket itself.
    """

    def __init__(self, connection_type):
        # type: (ConnectionType) -> None
        self.client = connection_type is ConnectionType.CLIENT
        self._state = ConnectionState.CONNECTING
        if self.client:
            self._h11_connection = h11.Connection(h11.CLIENT)
        else:
            self._h11_connection = h11.Connection(h11.SERVER)
        # Set once the handshake has succeeded.
        self._connection = None
        # Events produced by receive_data() and drained by events().
        self._events = deque()
        # The Request that opened the handshake: the one sent (client) or
        # the one received and validated (server).
        self._initiating_request = None
        # Client only: the Sec-WebSocket-Key that was sent, needed to verify
        # the server's Sec-WebSocket-Accept token.
        self._nonce = None

    @property
    def state(self):
        # type: () -> ConnectionState
        """The current :class:`ConnectionState` of the handshake."""
        return self._state

    @property
    def connection(self):
        # type: () -> Optional[Connection]
        """Return the established connection.

        This is None until the handshake has completed successfully, i.e.
        until a server has sent :class:`AcceptConnection` or a client has
        received a valid ``101 Switching Protocols`` response.
        """
        return self._connection

    def initiate_upgrade_connection(self, headers, path):
        # type: (List[Tuple[bytes, bytes]], str) -> None
        """Initiate an upgrade connection.

        This should be used if the request has already been received and
        parsed elsewhere. The request is re-serialised and fed through
        :meth:`receive_data`, so the resulting :class:`Request` event (or
        RemoteProtocolError) is produced as if it had arrived over the
        wire. Only valid when acting as the server.
        """
        if self.client:
            raise LocalProtocolError(
                "Cannot initiate an upgrade connection when acting as the client"
            )
        upgrade_request = h11.Request(method=b"GET", target=path, headers=headers)
        h11_client = h11.Connection(h11.CLIENT)
        self.receive_data(h11_client.send(upgrade_request))

    def send(self, event):
        # type: (Event) -> bytes
        """Send an event to the remote.

        This will return the bytes to send based on the event or raise
        a LocalProtocolError if the event is not valid given the
        state.
        """
        data = b""
        if isinstance(event, Request):
            data += self._initiate_connection(event)
        elif isinstance(event, AcceptConnection):
            data += self._accept(event)
        elif isinstance(event, RejectConnection):
            data += self._reject(event)
        elif isinstance(event, RejectData):
            data += self._send_reject_data(event)
        else:
            raise LocalProtocolError(
                "Event {} cannot be sent during the handshake".format(event)
            )
        return data

    def receive_data(self, data):
        # type: (bytes) -> None
        """Receive data from the remote.

        A list of events that the remote peer triggered by sending
        this data can be retrieved with :meth:`events`.

        Raises RemoteProtocolError (with a RejectConnection event_hint) if
        the data is not a valid HTTP message or not a valid WebSocket
        handshake.
        """
        self._h11_connection.receive_data(data)
        while True:
            try:
                event = self._h11_connection.next_event()
            except h11.RemoteProtocolError:
                raise RemoteProtocolError(
                    "Bad HTTP message", event_hint=RejectConnection()
                )
            if (
                isinstance(event, h11.ConnectionClosed)
                or event is h11.NEED_DATA
                or event is h11.PAUSED
            ):
                break
            try:
                if self.client:
                    self._process_client_event(event)
                elif isinstance(event, h11.Request):
                    self._events.append(self._process_connection_request(event))
            except UnicodeDecodeError:
                # h11 accepts non-ASCII (obs-text) bytes in header values, but
                # the handshake decodes them as ASCII; report such input as a
                # protocol error instead of leaking a UnicodeDecodeError.
                raise RemoteProtocolError(
                    "Non-ASCII value in handshake header",
                    event_hint=RejectConnection(),
                )

    def events(self):
        # type: () -> Generator[Event, None, None]
        """Yield, and remove, the events queued by :meth:`receive_data`."""
        while self._events:
            yield self._events.popleft()

    ############ Server mode methods

    def _process_connection_request(self, event):
        # type: (h11.Request) -> Request
        """Validate a client's opening handshake and return a Request event.

        Raises RemoteProtocolError, whose event_hint is a RejectConnection
        suitable for sending back, if the request is not a valid WebSocket
        upgrade request.
        """
        if event.method != b"GET":
            raise RemoteProtocolError(
                "Request method must be GET", event_hint=RejectConnection()
            )
        connection_tokens = []
        extensions = []
        host = None
        key = None
        subprotocols = []
        upgrade = b""
        version = None
        headers = []
        for name, value in event.headers:
            name = name.lower()
            # Connection, Sec-WebSocket-Extensions and Sec-WebSocket-Protocol
            # are comma-separated lists that may be split over several header
            # lines, so their values are accumulated instead of letting the
            # last line win.
            if name == b"connection":
                connection_tokens.extend(split_comma_header(value))
            elif name == b"host":
                host = value.decode("ascii")
                continue  # Skip appending to headers
            elif name == b"sec-websocket-extensions":
                extensions.extend(split_comma_header(value))
                continue  # Skip appending to headers
            elif name == b"sec-websocket-key":
                key = value
            elif name == b"sec-websocket-protocol":
                subprotocols.extend(split_comma_header(value))
                continue  # Skip appending to headers
            elif name == b"sec-websocket-version":
                version = value
            elif name == b"upgrade":
                upgrade = value
            headers.append((name, value))
        if not any(token.lower() == "upgrade" for token in connection_tokens):
            raise RemoteProtocolError(
                "Missing header, 'Connection: Upgrade'", event_hint=RejectConnection()
            )
        if version != WEBSOCKET_VERSION:
            # A missing version makes the request malformed (400); a version
            # we don't speak gets 426 Upgrade Required. Both responses
            # advertise the supported version (RFC 6455, Section 4.2.2).
            if version:
                message = "Unsupported 'Sec-WebSocket-Version' {!r}".format(version)
                status_code = 426
            else:
                message = "Missing header, 'Sec-WebSocket-Version'"
                status_code = 400
            raise RemoteProtocolError(
                message,
                event_hint=RejectConnection(
                    headers=[(b"Sec-WebSocket-Version", WEBSOCKET_VERSION)],
                    status_code=status_code,
                ),
            )
        if key is None:
            raise RemoteProtocolError(
                "Missing header, 'Sec-WebSocket-Key'", event_hint=RejectConnection()
            )
        if upgrade.lower() != b"websocket":
            raise RemoteProtocolError(
                "Missing header, 'Upgrade: WebSocket'", event_hint=RejectConnection()
            )
        self._initiating_request = Request(
            extensions=extensions,
            extra_headers=headers,
            host=host,
            subprotocols=subprotocols,
            target=event.target.decode("ascii"),
        )
        return self._initiating_request

    def _accept(self, event):
        # type: (AcceptConnection) -> bytes
        """Build the 101 Switching Protocols response and open the connection.

        The response carries the Sec-WebSocket-Accept token derived from the
        client's key, the chosen subprotocol (which must be one the client
        offered) and any agreed extensions.
        """
        if self.client:
            raise LocalProtocolError(
                "Cannot accept a connection when acting as the client"
            )
        if self.state != ConnectionState.CONNECTING:
            raise LocalProtocolError(
                "Connection cannot be accepted in state %s" % self.state
            )
        if self._initiating_request is None:
            raise LocalProtocolError(
                "Cannot accept a connection before a request has been received"
            )
        request_headers = normed_header_dict(self._initiating_request.extra_headers)
        nonce = request_headers[b"sec-websocket-key"]
        accept_token = generate_accept_token(nonce)
        headers = [
            (b"Upgrade", b"WebSocket"),
            (b"Connection", b"Upgrade"),
            (b"Sec-WebSocket-Accept", accept_token),
        ]
        if event.subprotocol is not None:
            if event.subprotocol not in self._initiating_request.subprotocols:
                raise LocalProtocolError(
                    "unexpected subprotocol {}".format(event.subprotocol)
                )
            headers.append(
                (b"Sec-WebSocket-Protocol", event.subprotocol.encode("ascii"))
            )
        if event.extensions:
            accepts = server_extensions_handshake(
                self._initiating_request.extensions, event.extensions
            )
            if accepts:
                headers.append((b"Sec-WebSocket-Extensions", accepts))
        response = h11.InformationalResponse(
            status_code=101, headers=headers + event.extra_headers
        )
        # Serialise first so that a failure in h11 does not leave the
        # handshake marked OPEN.
        data = self._h11_connection.send(response)
        # Only a server reaches this point (see the guard above).
        self._connection = Connection(ConnectionType.SERVER, event.extensions)
        self._state = ConnectionState.OPEN
        return data

    def _reject(self, event):
        # type: (RejectConnection) -> bytes
        """Send a rejection response.

        Without a body, ``content-length: 0`` and the end of message are
        sent at once and the handshake becomes CLOSED; with a body the
        state becomes REJECTING and the body is sent via RejectData events.
        """
        if self.state != ConnectionState.CONNECTING:
            raise LocalProtocolError(
                "Connection cannot be rejected in state %s" % self.state
            )
        # Copy so that the caller's (possibly shared) header list is never
        # modified by the content-length header added below.
        headers = list(event.headers)
        if not event.has_body:
            headers.append((b"content-length", b"0"))
        response = h11.Response(status_code=event.status_code, headers=headers)
        data = self._h11_connection.send(response)
        self._state = ConnectionState.REJECTING
        if not event.has_body:
            data += self._h11_connection.send(h11.EndOfMessage())
            self._state = ConnectionState.CLOSED
        return data

    def _send_reject_data(self, event):
        # type: (RejectData) -> bytes
        """Send part of a rejection body, finishing it if body_finished."""
        if self.state != ConnectionState.REJECTING:
            raise LocalProtocolError(
                "Cannot send rejection data in state {}".format(self.state)
            )
        data = self._h11_connection.send(h11.Data(data=event.data))
        if event.body_finished:
            data += self._h11_connection.send(h11.EndOfMessage())
            self._state = ConnectionState.CLOSED
        return data

    ############ Client mode methods

    def _initiate_connection(self, request):
        # type: (Request) -> bytes
        """Serialise the client's opening handshake for ``request``."""
        if not self.client:
            raise LocalProtocolError(
                "Cannot send a Request when acting as the server"
            )
        nonce = generate_nonce()
        headers = [
            (b"Host", request.host.encode("ascii")),
            (b"Upgrade", b"WebSocket"),
            (b"Connection", b"Upgrade"),
            (b"Sec-WebSocket-Key", nonce),
            (b"Sec-WebSocket-Version", WEBSOCKET_VERSION),
        ]
        if request.subprotocols:
            headers.append(
                (
                    b"Sec-WebSocket-Protocol",
                    (", ".join(request.subprotocols)).encode("ascii"),
                )
            )
        if request.extensions:
            offers = {e.name: e.offer() for e in request.extensions}
            extensions = []
            for name, params in offers.items():
                if params is True:
                    extensions.append(name.encode("ascii"))
                elif params:
                    # py34 annoyance: doesn't support bytestring formatting
                    extensions.append(("%s; %s" % (name, params)).encode("ascii"))
            if extensions:
                headers.append((b"Sec-WebSocket-Extensions", b", ".join(extensions)))
        upgrade = h11.Request(
            method=b"GET",
            target=request.target.encode("ascii"),
            headers=headers + request.extra_headers,
        )
        data = self._h11_connection.send(upgrade)
        # Record the request and nonce only after h11 accepted the event, so a
        # failed send leaves _initiating_request and _nonce untouched.
        self._initiating_request = request
        self._nonce = nonce
        return data

    def _process_client_event(self, event):
        """Translate one h11 event received by a client into wsproto events."""
        if isinstance(event, h11.InformationalResponse):
            if event.status_code == 101:
                self._events.append(self._establish_client_connection(event))
            else:
                self._events.append(
                    RejectConnection(
                        headers=event.headers,
                        status_code=event.status_code,
                        has_body=False,
                    )
                )
                self._state = ConnectionState.CLOSED
        elif isinstance(event, h11.Response):
            self._state = ConnectionState.REJECTING
            self._events.append(
                RejectConnection(
                    headers=event.headers,
                    status_code=event.status_code,
                    has_body=True,
                )
            )
        elif isinstance(event, h11.Data):
            self._events.append(RejectData(data=event.data, body_finished=False))
        elif isinstance(event, h11.EndOfMessage):
            self._events.append(RejectData(data=b"", body_finished=True))
            self._state = ConnectionState.CLOSED

    def _establish_client_connection(self, event):  # noqa: MC0001
        """Validate the server's 101 response and open the connection.

        Returns an AcceptConnection event. Raises RemoteProtocolError if the
        response lacks the upgrade headers, has a wrong accept token, or
        selects a subprotocol or extension the client did not offer.
        """
        accept = None
        connection_tokens = []
        accepts = []
        subprotocol = None
        upgrade = b""
        headers = []
        for name, value in event.headers:
            name = name.lower()
            if name == b"connection":
                # Connection is a list header that may span several lines.
                connection_tokens.extend(split_comma_header(value))
                continue  # Skip appending to headers
            elif name == b"sec-websocket-extensions":
                accepts = split_comma_header(value)
                continue  # Skip appending to headers
            elif name == b"sec-websocket-accept":
                accept = value
                continue  # Skip appending to headers
            elif name == b"sec-websocket-protocol":
                subprotocol = value
                continue  # Skip appending to headers
            elif name == b"upgrade":
                upgrade = value
                continue  # Skip appending to headers
            headers.append((name, value))
        if not any(token.lower() == "upgrade" for token in connection_tokens):
            raise RemoteProtocolError(
                "Missing header, 'Connection: Upgrade'", event_hint=RejectConnection()
            )
        if upgrade.lower() != b"websocket":
            raise RemoteProtocolError(
                "Missing header, 'Upgrade: WebSocket'", event_hint=RejectConnection()
            )
        accept_token = generate_accept_token(self._nonce)
        if accept != accept_token:
            raise RemoteProtocolError("Bad accept token", event_hint=RejectConnection())
        if subprotocol is not None:
            subprotocol = subprotocol.decode("ascii")
            if subprotocol not in self._initiating_request.subprotocols:
                raise RemoteProtocolError(
                    "unrecognized subprotocol {}".format(subprotocol),
                    event_hint=RejectConnection(),
                )
        extensions = client_extensions_handshake(
            accepts, self._initiating_request.extensions
        )
        # Only a client reaches this point; bytes received after the 101
        # response are handed to the connection as its initial data.
        self._connection = Connection(
            ConnectionType.CLIENT,
            extensions,
            self._h11_connection.trailing_data[0],
        )
        self._state = ConnectionState.OPEN
        return AcceptConnection(
            extensions=extensions, extra_headers=headers, subprotocol=subprotocol
        )

    def __repr__(self):
        return "{}(client={}, state={})".format(
            self.__class__.__name__, self.client, self.state
        )
def server_extensions_handshake(requested, supported):
    # type: (List[str], List[Extension]) -> Optional[bytes]
    """Agree on the extensions to use, returning an appropriate header value.

    Each entry of ``requested`` is one extension offer from the client
    (the extension name optionally followed by ``;``-separated parameters).
    For every offer whose name matches an extension in ``supported``, that
    extension's ``accept`` method is called with the full offer string. It
    may return ``True`` (accept without parameters), a parameter string to
    echo back, or ``False``/``None`` to decline the offer.

    A client may offer the same extension several times, in order of
    preference. Only the first offer accepted for each extension is used;
    later offers for an already-accepted extension are not passed to
    ``accept``, so they cannot reconfigure it.

    Accepted extensions appear in the header in the order the client
    offered them. Returns the header value as bytes, or None if no
    extensions were agreed.
    """
    agreed = []  # (name, params) pairs, in the client's offer order
    agreed_names = set()
    for offer in requested:
        name = offer.split(";", 1)[0].strip()
        if name in agreed_names:
            # An earlier, preferred offer for this extension was accepted.
            continue
        for extension in supported:
            if extension.name != name:
                continue
            result = extension.accept(offer)
            if result is True:
                params = ""
            elif result is False or result is None:
                continue
            else:
                params = result
            agreed.append((extension.name, params))
            agreed_names.add(name)
            break
    if not agreed:
        return None
    elements = []
    for name, params in agreed:
        # py34 annoyance: doesn't support bytestring formatting
        if params:
            elements.append(("%s; %s" % (name, params)).encode("ascii"))
        else:
            elements.append(name.encode("ascii"))
    return b", ".join(elements)
def client_extensions_handshake(accepted, supported):
    # type: (List[str], List[Extension]) -> List[Extension]
    """Finalize the extensions the server accepted and return them.

    ``accepted`` holds the elements of the server's extensions response
    header; ``supported`` holds the client's extensions. Each accepted
    element is matched by name to a supported extension, whose
    ``finalize`` method is called with the full element string. The
    finalized extensions are returned in the order the server listed them.

    Raises RemoteProtocolError if the server accepted an extension that is
    not supported, or accepted the same extension more than once (which
    would otherwise finalize it twice and apply it twice).
    """
    extensions = []
    for accept in accepted:
        name = accept.split(";", 1)[0].strip()
        matched = next((ext for ext in supported if ext.name == name), None)
        if matched is None:
            raise RemoteProtocolError(
                "unrecognized extension {}".format(name), event_hint=RejectConnection()
            )
        if any(ext is matched for ext in extensions):
            raise RemoteProtocolError(
                "duplicate extension {}".format(name), event_hint=RejectConnection()
            )
        matched.finalize(accept)
        extensions.append(matched)
    return extensions