[tor-commits] [flashproxy/master] Add WebSocketEncoder and tests.
dcf at torproject.org
dcf at torproject.org
Mon Apr 9 04:08:42 UTC 2012
commit 7ea31933aef27276cb443eee49369bccb6661ffa
Author: David Fifield <david at bamsoftware.com>
Date: Wed Mar 28 00:14:51 2012 -0700
Add WebSocketEncoder and tests.
---
connector-test.py | 28 +++++++++++++++++++++++++++-
connector.py | 51 ++++++++++++++++++++++++++++++++++++++++++---------
2 files changed, 69 insertions(+), 10 deletions(-)
diff --git a/connector-test.py b/connector-test.py
index c0479bd..a15c1ad 100755
--- a/connector-test.py
+++ b/connector-test.py
@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
import unittest
-from connector import WebSocketDecoder
+from connector import WebSocketDecoder, WebSocketEncoder
def read_frames(dec):
frames = []
@@ -147,5 +147,31 @@ class TestWebSocketDecoder(unittest.TestCase):
dec.feed("\x82\x7f\x00\x00\x00\x00\x01\x00\x00\x00")
self.assertRaises(ValueError, dec.read_frame)
+class TestWebSocketEncoder(unittest.TestCase):
+    """Tests for WebSocketEncoder framing, including round trips through
+    WebSocketDecoder."""
+
+    def test_length(self):
+        """Test that payload lengths are encoded using the smallest number of
+        bytes."""
+        # (payload length, number of extended payload-length bytes expected)
+        TESTS = [(0, 0), (125, 0), (126, 2), (65535, 2), (65536, 8)]
+        for length, encoded_length in TESTS:
+            # Unmasked: 1 byte FIN/opcode + 1 byte mask/length + ext + payload.
+            enc = WebSocketEncoder(use_mask = False)
+            eframe = enc.encode_frame(2, "\x00" * length)
+            self.assertEqual(len(eframe), 1 + 1 + encoded_length + length)
+            # Masked frames additionally carry a 4-byte masking key.
+            enc = WebSocketEncoder(use_mask = True)
+            eframe = enc.encode_frame(2, "\x00" * length)
+            self.assertEqual(len(eframe), 1 + 1 + encoded_length + 4 + length)
+
+    def test_header(self):
+        """Test the exact first two header bytes of short frames."""
+        enc = WebSocketEncoder(use_mask = False)
+        # FIN set with opcode 2; mask bit clear with length 5.
+        self.assertEqual(enc.encode_frame(2, "hello")[:2], "\x82\x05")
+        enc = WebSocketEncoder(use_mask = True)
+        # FIN set with opcode 1; mask bit set with length 5.
+        self.assertEqual(enc.encode_frame(1, "hello")[:2], "\x81\x85")
+
+    def test_opcode_range(self):
+        """Test that opcodes that do not fit in four bits are rejected."""
+        for use_mask in (False, True):
+            enc = WebSocketEncoder(use_mask = use_mask)
+            for opcode in (-1, 16, 255):
+                self.assertRaises(ValueError, enc.encode_frame, opcode, "")
+
+    def test_roundtrip(self):
+        """Test that encoded messages decode back to the original payload."""
+        TESTS = [
+            (1, u"Hello world"),
+            (1, u"Hello \N{WHITE SMILING FACE}"),
+        ]
+        for opcode, payload in TESTS:
+            for use_mask in (False, True):
+                enc = WebSocketEncoder(use_mask = use_mask)
+                enc_message = enc.encode_message(opcode, payload)
+                dec = WebSocketDecoder(use_mask = use_mask)
+                dec.feed(enc_message)
+                self.assertEqual(read_messages(dec), [(opcode, payload)])
+
if __name__ == "__main__":
unittest.main()
diff --git a/connector.py b/connector.py
index 39d89e8..2a06523 100755
--- a/connector.py
+++ b/connector.py
@@ -131,6 +131,13 @@ class BufferSocket(object):
return time.time() - self.birthday > timeout
+def apply_mask(payload, mask_key):
+    """Return payload XORed byte-by-byte with the repeating mask_key.
+
+    Byte i of payload is XORed with byte i % 4 of mask_key, so only the
+    first four bytes of mask_key are used. mask_key must be at least that
+    long whenever payload is. Because XOR is its own inverse, the same call
+    both masks and unmasks a payload."""
+    return "".join(chr(ord(c) ^ ord(mask_key[i % 4]))
+                   for i, c in enumerate(payload))
+
class WebSocketFrame(object):
def __init__(self):
self.fin = False
@@ -171,14 +178,6 @@ class WebSocketDecoder(object):
def feed(self, data):
self.buf += data
- @staticmethod
- def mask(payload, mask_key):
- result = []
- for i, c in enumerate(payload):
- mc = chr(ord(payload[i]) ^ ord(mask_key[i%4]))
- result.append(mc)
- return "".join(result)
-
def read_frame(self):
"""Read a frame from the internal buffer, if one is available. Returns a
WebSocketFrame object, or None if there are no complete frames to
@@ -226,7 +225,7 @@ class WebSocketDecoder(object):
if len(self.buf) < offset + payload_len:
return None
- payload = WebSocketDecoder.mask(self.buf[offset:offset+payload_len], mask_key)
+ payload = apply_mask(self.buf[offset:offset+payload_len], mask_key)
self.buf = self.buf[offset+payload_len:]
frame = WebSocketFrame()
@@ -283,6 +282,40 @@ class WebSocketDecoder(object):
message.payload = message.payload.decode("utf-8")
return message
+class WebSocketEncoder(object):
+    """Encodes payloads as single WebSocket frames with the FIN bit set.
+
+    If use_mask is true, every frame is masked with a fresh 4-byte key
+    taken from os.urandom and the MASK bit is set in the header."""
+
+    def __init__(self, use_mask = False):
+        self.use_mask = use_mask
+
+    def encode_frame(self, opcode, payload):
+        """Return payload wrapped in one frame with the given opcode.
+
+        payload is a byte string. Its length is written in the shortest
+        form: in the 7-bit field for lengths < 126, as a 16-bit extension
+        for lengths < 2**16, and as a 64-bit extension otherwise.
+
+        Raises ValueError if opcode is outside 0-15, or if payload is
+        2**63 bytes or longer (not representable in a frame)."""
+        if opcode < 0:
+            raise ValueError("Opcode of %d is negative" % opcode)
+        if opcode >= 16:
+            raise ValueError("Opcode of %d is >= 16" % opcode)
+        length = len(payload)
+
+        # The 64-bit length form must have its most significant bit clear
+        # (RFC 6455 section 5.2), so the largest legal length is 2**63 - 1.
+        if length < 126:
+            len_b, len_ext = length, ""
+        elif length < 0x10000:
+            len_b, len_ext = 126, struct.pack(">H", length)
+        elif length < 0x8000000000000000:
+            len_b, len_ext = 127, struct.pack(">Q", length)
+        else:
+            raise ValueError("payload length of %d is too long" % length)
+
+        # Mask only after validation, so a rejected frame neither consumes
+        # random bytes nor pays for a masked copy of the payload.
+        if self.use_mask:
+            mask_key = os.urandom(4)
+            payload = apply_mask(payload, mask_key)
+            mask_bit = 0x80
+        else:
+            mask_key = ""
+            mask_bit = 0x00
+
+        # Byte 0: FIN bit plus opcode in the low nibble.
+        # Byte 1: MASK bit plus the 7-bit length or the 126/127 marker.
+        return chr(0x80 | opcode) + chr(mask_bit | len_b) + len_ext + mask_key + payload
+
+    def encode_message(self, opcode, payload):
+        """Return a whole message encoded as a single frame.
+
+        For text messages (opcode 1), payload is first encoded to UTF-8.
+        For any other opcode, payload is framed as given."""
+        if opcode == 1:
+            payload = payload.encode("utf-8")
+        return self.encode_frame(opcode, payload)
+
+
def listen_socket(addr):
"""Return a nonblocking socket listening on the given address."""
addrinfo = socket.getaddrinfo(addr[0], addr[1], 0, socket.SOCK_STREAM, socket.IPPROTO_TCP)[0]
More information about the tor-commits
mailing list