# Source code for wsproto.extensions

# -*- coding: utf-8 -*-
"""
wsproto/extensions
~~~~~~~~~~~~~~~~~~

WebSocket extensions.
"""

import zlib

from .frame_protocol import CloseReason, Opcode, RsvBits


class Extension(object):
    """Base class for WebSocket extensions.

    Every hook is a no-op. Subclasses override the hooks they care about.
    """

    #: Extension token used during negotiation; subclasses set this.
    name = None

    def enabled(self):
        """Report whether the extension is active (never, for the base)."""
        return False

    def offer(self, connection):
        """Return negotiation parameters to offer; the base offers nothing."""
        return None

    def accept(self, connection, offer):
        """Respond to a peer's *offer*; the base does not accept."""
        return None

    def finalize(self, connection, offer):
        """Apply the peer's negotiation response; no-op in the base."""
        return None

    def frame_inbound_header(self, proto, opcode, rsv, payload_length):
        """Inspect an inbound frame header.

        Returns the RSV bits this extension claims; the base claims none.
        """
        return RsvBits(False, False, False)

    def frame_inbound_payload_data(self, proto, data):
        """Transform inbound payload bytes; the base passes them through."""
        return data

    def frame_inbound_complete(self, proto, fin):
        """Hook called when an inbound frame is complete; no extra data."""
        return None

    def frame_outbound(self, proto, opcode, rsv, data, fin):
        """Transform an outbound frame; the base returns it unchanged."""
        return rsv, data
class PerMessageDeflate(Extension):
    """The permessage-deflate extension.

    Outgoing TEXT/BINARY messages are compressed with raw DEFLATE, and RSV1
    is set on their first frame. Incoming messages whose first frame has RSV1
    set are decompressed. Control frames are never compressed.
    """

    name = 'permessage-deflate'

    DEFAULT_CLIENT_MAX_WINDOW_BITS = 15
    DEFAULT_SERVER_MAX_WINDOW_BITS = 15

    def __init__(self, client_no_context_takeover=False,
                 client_max_window_bits=None,
                 server_no_context_takeover=False,
                 server_max_window_bits=None):
        self.client_no_context_takeover = client_no_context_takeover
        if client_max_window_bits is None:
            client_max_window_bits = self.DEFAULT_CLIENT_MAX_WINDOW_BITS
        self.client_max_window_bits = client_max_window_bits

        self.server_no_context_takeover = server_no_context_takeover
        if server_max_window_bits is None:
            server_max_window_bits = self.DEFAULT_SERVER_MAX_WINDOW_BITS
        self.server_max_window_bits = server_max_window_bits

        self._compressor = None
        self._decompressor = None
        # This refers to the current frame
        self._inbound_is_compressible = None
        # This refers to the ongoing message (which might span multiple
        # frames). Only the first frame in a fragmented message is flagged for
        # compression, so this carries that bit forward.
        self._inbound_compressed = None

        self._enabled = False

    def _compressible_opcode(self, opcode):
        """Return True for data-frame opcodes (TEXT, BINARY, CONTINUATION)."""
        return opcode in (Opcode.TEXT, Opcode.BINARY, Opcode.CONTINUATION)

    def enabled(self):
        """Return whether negotiation has enabled this extension."""
        return self._enabled

    def offer(self, connection):
        """Build the parameter string advertising this instance's settings."""
        parameters = [
            'client_max_window_bits=%d' % self.client_max_window_bits,
            'server_max_window_bits=%d' % self.server_max_window_bits,
        ]

        if self.client_no_context_takeover:
            parameters.append('client_no_context_takeover')
        if self.server_no_context_takeover:
            parameters.append('server_no_context_takeover')

        return '; '.join(parameters)

    def finalize(self, connection, offer):
        """Apply the peer's negotiation response and enable the extension.

        *offer* is a ``;``-separated parameter string whose first element
        (presumably the extension name) is skipped. ``*_max_window_bits``
        parameters are expected to carry an ``=value``.
        """
        bits = [b.strip() for b in offer.split(';')]
        for bit in bits[1:]:
            if bit.startswith('client_no_context_takeover'):
                self.client_no_context_takeover = True
            elif bit.startswith('server_no_context_takeover'):
                self.server_no_context_takeover = True
            elif bit.startswith('client_max_window_bits'):
                self.client_max_window_bits = int(bit.split('=', 1)[1].strip())
            elif bit.startswith('server_max_window_bits'):
                self.server_max_window_bits = int(bit.split('=', 1)[1].strip())

        self._enabled = True

    def _parse_params(self, params):
        """Parse an offer string.

        Sets the ``*_no_context_takeover`` flags as a side effect. Returns
        ``(client_max_window_bits, server_max_window_bits)``, where each value
        is None if the parameter was absent. A parameter given without a
        value falls back to this instance's current setting.
        """
        client_max_window_bits = None
        server_max_window_bits = None

        bits = [b.strip() for b in params.split(';')]
        for bit in bits[1:]:
            if bit.startswith('client_no_context_takeover'):
                self.client_no_context_takeover = True
            elif bit.startswith('server_no_context_takeover'):
                self.server_no_context_takeover = True
            elif bit.startswith('client_max_window_bits'):
                if '=' in bit:
                    client_max_window_bits = int(bit.split('=', 1)[1].strip())
                else:
                    client_max_window_bits = self.client_max_window_bits
            elif bit.startswith('server_max_window_bits'):
                if '=' in bit:
                    server_max_window_bits = int(bit.split('=', 1)[1].strip())
                else:
                    server_max_window_bits = self.server_max_window_bits

        return client_max_window_bits, server_max_window_bits

    def accept(self, connection, offer):
        """Accept a peer's *offer*, enable the extension and return the
        response parameter string."""
        client_max_window_bits, server_max_window_bits = \
            self._parse_params(offer)

        self._enabled = True

        parameters = []

        if self.client_no_context_takeover:
            parameters.append('client_no_context_takeover')
        if client_max_window_bits is not None:
            parameters.append('client_max_window_bits=%d' %
                              client_max_window_bits)
            self.client_max_window_bits = client_max_window_bits
        if self.server_no_context_takeover:
            parameters.append('server_no_context_takeover')
        if server_max_window_bits is not None:
            parameters.append('server_max_window_bits=%d' %
                              server_max_window_bits)
            self.server_max_window_bits = server_max_window_bits

        return '; '.join(parameters)

    def frame_inbound_header(self, proto, opcode, rsv, payload_length):
        """Validate RSV1 on an inbound frame and set up decompression.

        Returns ``CloseReason.PROTOCOL_ERROR`` if RSV1 is set on a control
        or CONTINUATION frame. Otherwise returns the RSV bits this extension
        claims (RSV1).
        """
        if rsv.rsv1 and opcode.iscontrol():
            return CloseReason.PROTOCOL_ERROR
        if rsv.rsv1 and opcode is Opcode.CONTINUATION:
            return CloseReason.PROTOCOL_ERROR

        self._inbound_is_compressible = self._compressible_opcode(opcode)

        # Control frames may be interleaved between the fragments of a data
        # message (RFC 6455, section 5.4). They must neither start nor alter
        # the per-message compression state; only data frames may do that.
        if self._inbound_is_compressible and self._inbound_compressed is None:
            self._inbound_compressed = rsv.rsv1
            if self._inbound_compressed:
                # Inbound data was compressed by the peer, so use the peer's
                # window size.
                if proto.client:
                    bits = self.server_max_window_bits
                else:
                    bits = self.client_max_window_bits
                if self._decompressor is None:
                    self._decompressor = zlib.decompressobj(-int(bits))

        return RsvBits(True, False, False)

    def frame_inbound_payload_data(self, proto, data):
        """Decompress payload bytes of a compressed message.

        Other payloads pass through unchanged.
        """
        if not self._inbound_compressed or not self._inbound_is_compressible:
            return data

        try:
            return self._decompressor.decompress(bytes(data))
        except zlib.error:
            return CloseReason.INVALID_FRAME_PAYLOAD_DATA

    def frame_inbound_complete(self, proto, fin):
        """Finish an inbound frame.

        At the end of a compressed message, returns the remaining
        decompressed bytes. Otherwise returns None, or a CloseReason on a
        decompression error.
        """
        if not fin:
            return None
        if not self._inbound_is_compressible:
            # A control frame finished. Any data message it interrupted is
            # still in progress, so its compression state must be kept.
            return None
        if not self._inbound_compressed:
            self._inbound_compressed = None
            return None

        try:
            # Re-append the sync-flush tail that the sender stripped.
            data = self._decompressor.decompress(b'\x00\x00\xff\xff')
            data += self._decompressor.flush()
        except zlib.error:
            return CloseReason.INVALID_FRAME_PAYLOAD_DATA

        if proto.client:
            no_context_takeover = self.server_no_context_takeover
        else:
            no_context_takeover = self.client_no_context_takeover

        if no_context_takeover:
            self._decompressor = None

        self._inbound_compressed = None

        return data

    def frame_outbound(self, proto, opcode, rsv, data, fin):
        """Compress an outbound data frame.

        Returns ``(rsv, data)``. RSV1 is set on the first frame of each
        compressed message. Control frames are returned unchanged.
        """
        if not self._compressible_opcode(opcode):
            return (rsv, data)

        if opcode is not Opcode.CONTINUATION:
            rsv = RsvBits(True, *rsv[1:])

        if self._compressor is None:
            assert opcode is not Opcode.CONTINUATION
            if proto.client:
                bits = self.client_max_window_bits
            else:
                bits = self.server_max_window_bits
            self._compressor = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
                                                zlib.DEFLATED, -int(bits))

        data = self._compressor.compress(bytes(data))

        if fin:
            data += self._compressor.flush(zlib.Z_SYNC_FLUSH)
            # Drop the trailing 0x00 0x00 0xff 0xff emitted by the sync flush.
            data = data[:-4]

            if proto.client:
                no_context_takeover = self.client_no_context_takeover
            else:
                no_context_takeover = self.server_no_context_takeover

            if no_context_takeover:
                self._compressor = None

        return (rsv, data)

    def __repr__(self):
        descr = ['client_max_window_bits=%d' % self.client_max_window_bits]
        if self.client_no_context_takeover:
            descr.append('client_no_context_takeover')
        descr.append('server_max_window_bits=%d' % self.server_max_window_bits)
        if self.server_no_context_takeover:
            descr.append('server_no_context_takeover')

        descr = '; '.join(descr)

        return '<%s %s>' % (self.__class__.__name__, descr)


#: SUPPORTED_EXTENSIONS maps all supported extension names to their class.
#: This can be used to iterate all supported extensions of wsproto, instantiate
#: new extensions based on their name, or check if a given extension is
#: supported or not.
SUPPORTED_EXTENSIONS = {
    PerMessageDeflate.name: PerMessageDeflate
}