[tor-commits] [flashproxy/master] Reimplement matching algorithm using an Endpoints datastruct for both client/server

infinity0 at torproject.org
Mon Oct 28 14:47:41 UTC 2013


commit 2658f6e3eac20a4f7c20e2c6b90ddf253731dce1
Author: Ximin Luo <infinity0 at gmx.com>
Date:   Mon Oct 7 16:47:53 2013 +0100

    Reimplement matching algorithm using an Endpoints datastruct for both client/server
    - select the prefix pool that is least well-served, rather than arbitrarily (see the sketch below)
    - don't attempt a match until a proxy is available to service the request
    - don't match IPv6 proxies to IPv4 servers
---
 facilitator/facilitator      |  428 +++++++++++++++++++++++-------------------
 facilitator/facilitator-test |  166 +++++++++++++++-
 2 files changed, 396 insertions(+), 198 deletions(-)

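The "least-well-served" selection above, as a minimal standalone sketch (not
part of the patch; the names served, pools and av_served are invented, and
mirror Endpoints._avServed() and the min() call in Endpoints.match()):

    # serve counts: how many times each client address has been handed out
    served = {"A": 2, "B": 0, "C": 1, "D": 5}
    # candidate client addresses, grouped by transport prefix
    pools = {"obfs3": set(["A", "B"]), "": set(["C", "D"])}

    def av_served(pool):
        # average serve count of a pool, as in Endpoints._avServed()
        return sum(served[a] for a in pool) / float(len(pool))

    # pick the prefix whose client pool is least well-served on average
    best = min(pools, key=lambda p: av_served(pools[p]))
    assert best == "obfs3"  # average 1.0 beats average 3.0
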
diff --git a/facilitator/facilitator b/facilitator/facilitator
index e44047e..d4088ff 100755
--- a/facilitator/facilitator
+++ b/facilitator/facilitator
@@ -8,6 +8,7 @@ import sys
 import threading
 import time
 import traceback
+from collections import namedtuple
 
 import fac
 
@@ -65,26 +66,6 @@ again. Listen on 127.0.0.1 and port PORT (by default %(port)d).
     "log": DEFAULT_LOG_FILENAME,
 }
 
-def num_relays():
-    return sum(len(x) for x in RELAYS.values())
-
-def parse_transport_chain(spec):
-    """Parse a transport chain string and return a tuple of individual
-    transports, each of which is a string.
-    >>> parse_transport_chain("obfs3|websocket")
-    ('obfs3', 'websocket')
-    """
-    assert(spec)
-    return tuple(spec.split("|"))
-
-def get_outermost_transport(transports):
-    """Given a transport chain tuple, return the last element.
-    >>> get_outermost_transport(("obfs3", "websocket"))
-    'websocket'
-    """
-    assert(transports)
-    return transports[-1]
-
 def safe_str(s):
     """Return "[scrubbed]" if options.safe_logging is true, and s otherwise."""
     if options.safe_logging:
@@ -98,87 +79,191 @@ def log(msg):
         print >> options.log_file, (u"%s %s" % (time.strftime(LOG_DATE_FORMAT), msg)).encode("UTF-8")
         options.log_file.flush()
 
-class TCPReg(object):
-    def __init__(self, host, port, transports):
-        self.host = host
-        self.port = port
-        self.transports = transports
-        # Get a relay for this registration. Throw UnknownTransport if
-        # could not be found.
-        self.relay = self._get_matching_relay()
 
-    def _get_matching_relay(self):
-        """Return a matching relay address for this registration. Raise
-        UnknownTransport if a relay with a matching transport chain could not be
-        found."""
-        if self.transports not in RELAYS:
-            raise UnknownTransport("Can't find relay with transport chain: %s" % self.transports)
-
-        # Maybe this should be a random pick from the set of all the
-        # eligible relays. But let's keep it deterministic for now,
-        # and return the first one.
-
-        # return random.choice(RELAYS[self.transports])
-        return RELAYS[self.transports][0]
+class Transport(namedtuple("Transport", "prefix suffix")):
+    @classmethod
+    def parse(cls, transport):
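+        """Parse a transport chain into a Transport tuple, splitting on
+        the last "|": "obfs3|websocket" gives ("obfs3", "websocket"), and
+        a bare "websocket" gives ("", "websocket")."""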
+        if isinstance(transport, cls):
+            return transport
+        elif type(transport) == str:
+            if "|" in transport:
+                prefix, suffix = transport.rsplit("|", 1)
+            else:
+                prefix, suffix = "", transport
+            return cls(prefix, suffix)
+        else:
+            raise ValueError("could not parse transport: %s" % transport)
 
-    def __unicode__(self):
-        return fac.format_addr((self.host, self.port))
+    def __init__(self, prefix, suffix):
+        if not suffix:
+            raise ValueError("suffix (proxy) part of transport must be non-empty: %s" % str(self))
 
     def __str__(self):
-        return unicode(self).encode("UTF-8")
+        return "%s|%s" % (self.prefix, self.suffix) if self.prefix else self.suffix
 
-    def __cmp__(self, other):
-        if isinstance(other, TCPReg):
-            # XXX is this correct comparison?
-            return cmp((self.host, self.port, self.transports), (other.host, other.port, other.transports))
-        else:
-            return False
 
-class Reg(object):
-    @staticmethod
-    def parse(spec, transports, defhost = None, defport = None):
+class Reg(namedtuple("Reg", "addr transport")):
+    @classmethod
+    def parse(cls, spec, transport, defhost = None, defport = None):
         host, port = fac.parse_addr_spec(spec, defhost, defport)
-        return TCPReg(host, port, transports)
+        return cls((host, port), Transport.parse(transport))
 
-class RegSet(object):
-    def __init__(self):
-        self.tiers = [[] for i in range(MAX_PROXIES_PER_CLIENT)]
-        self.cv = threading.Condition()
 
-    def add(self, reg):
-        self.cv.acquire()
-        try:
-            for tier in self.tiers:
-                if reg in tier:
-                    break
-            else:
-                self.tiers[0].append(reg)
-                self.cv.notify()
-                return True
-            return False
-        finally:
-            self.cv.release()
-
-    def fetch(self):
-        self.cv.acquire()
-        try:
-            for i in range(len(self.tiers)):
-                tier = self.tiers[i]
-                if tier:
-                    reg = tier.pop(0)
-                    if i + 1 < len(self.tiers):
-                        self.tiers[i+1].append(reg)
-                    return reg
+class Endpoints(object):
+    """
+    Tracks endpoints (either clients or servers) and the transport chains that
+    they support.
+    """
+
+    matchingLock = threading.Condition()
+
+    def __init__(self, af, maxserve=float("inf"), known_suf=("websocket",)):
+        self.af = af
+        self._maxserve = maxserve
+        self._endpoints = {} # address -> transport
+        self._indexes = {} # suffix -> [ addresses ]
+        self._served = {} # address -> num_times_served
+        self._cv = threading.Condition()
+        self.known_suf = set(known_suf)
+        for suf in self.known_suf:
+            self._ensureIndexForSuffix(suf)
+
+    def getNumEndpoints(self):
+        """:returns: the number of endpoints known to us."""
+        with self._cv:
+            return len(self._endpoints)
+
+    def getNumUnservedEndpoints(self):
+        """:returns: the number of unserved endpoints known to us."""
+        with self._cv:
+            return len(filter(lambda t: t == 0, self._served.itervalues()))
+
+    def addEndpoint(self, addr, transport):
+        """Add an endpoint.
+
+        :param addr: Address of endpoint, usage-dependent.
+        :param transport: Transport supported by the endpoint.
+        :returns: False if the address is already known, in which case no
+            update is made to its transport, else True.
+        """
+        transport = Transport.parse(transport)
+        with self._cv:
+            if addr in self._endpoints: return False
+            self._endpoints[addr] = transport
+            self._served[addr] = 0
+            self._addAddrIntoIndexes(addr)
+            self._cv.notify()
+            return True
+
+    def delEndpoint(self, addr):
+        """Forget an endpoint.
+
+        :param addr: Address of endpoint, usage-dependent.
+        :returns: False if the address was already forgotten, else True.
+        """
+        with self._cv:
+            if addr not in self._endpoints: return False
+            self._delAddrFromIndexes(addr)
+            del self._served[addr]
+            del self._endpoints[addr]
+            self._cv.notify()
+            return True
+
+    def supports(self, transport):
+        """
+        Estimate whether we support the given transport. May give false
+        positives, but doing a proper match later on will catch these.
+
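+        Hypothetical example: with one endpoint "a|p" and a proxy that
+        speaks "p" already observed, supports("a|websocket") returns
+        True, even though no endpoint actually offers that chain.
+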
+        :returns: True if we know, or have met, proxies that might be able
+            to satisfy the requested transport against our known endpoints.
+        """
+        transport = Transport.parse(transport)
+        with self._cv:
+            known_pre = self._findPrefixesForSuffixes(*self.known_suf).keys()
+            pre, suf = transport.prefix, transport.suffix
+            return pre in known_pre and suf in self.known_suf
+
+    def _findPrefixesForSuffixes(self, *supported_suf):
+        """
+        :returns: { prefix: set(addr) }, where each address supports some
+            suffix from supported_suf. Also registers supported_suf as
+            known suffixes and indexes them. TODO(infinity0): describe better
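+            Hypothetical example: with endpoints {A: "a|p", B: "b|q"},
+            calling this with ("p", "q") gives {"a": set([A]), "b": set([B])}.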
+        """
+        self.known_suf.update(supported_suf)
+        prefixes = {}
+        for suf in supported_suf:
+            self._ensureIndexForSuffix(suf)
+            for addr in self._indexes[suf]:
+                pre = self._endpoints[addr].prefix
+                prefixes.setdefault(pre, set()).add(addr)
+        return prefixes
+
+    def _avServed(self, addrpool):
+        return sum(self._served[a] for a in addrpool) / float(len(addrpool))
+
+    def _serveReg(self, addrpool):
+        """
+        :param list addrpool: List of candidate addresses.
+        :returns: An address of an endpoint from the given pool, or None if all
+            endpoints have already been served _maxserve times. The serve
+            counter for that address is also incremented.
+        """
+        if not addrpool: return None
+        prio_addr = min(addrpool, key=lambda a: self._served[a])
+        if self._served[prio_addr] < self._maxserve:
+            self._served[prio_addr] += 1
+            return prio_addr
+        else:
             return None
-        finally:
-            self.cv.release()
 
-    def __len__(self):
-        self.cv.acquire()
-        try:
-            return sum(len(tier) for tier in self.tiers)
-        finally:
-            self.cv.release()
+    def _ensureIndexForSuffix(self, suf):
+        if suf in self._indexes: return
+        addrs = set(addr for addr, transport in self._endpoints.iteritems()
+                         if transport.suffix == suf)
+        self._indexes[suf] = addrs
+
+    def _addAddrIntoIndexes(self, addr):
+        suf = self._endpoints[addr].suffix
+        if suf in self._indexes: self._indexes[suf].add(addr)
+
+    def _delAddrFromIndexes(self, addr):
+        suf = self._endpoints[addr].suffix
+        if suf in self._indexes: self._indexes[suf].remove(addr)
+
+    def _prefixesForTransport(self, transport, *supported_suf):
+        for suf in supported_suf:
+            if not suf:
+                yield transport
+            elif transport[-len(suf):] == suf:
+                yield transport[:-len(suf)]
+
+    EMPTY_MATCH = (None, None)
+    @staticmethod
+    def match(ptsClient, ptsServer, supported_suf):
+        """
+        :returns: A tuple (client Reg, server Reg) arbitrarily selected from
+            the available endpoints that can satisfy supported_suf.
+        """
+        if ptsClient.af != ptsServer.af:
+            raise ValueError("address family not equal!")
+        # need to operate on both structures
+        # so hold both locks plus a pair-wise lock
+        with Endpoints.matchingLock, ptsClient._cv, ptsServer._cv:
+            server_pre = ptsServer._findPrefixesForSuffixes(*supported_suf)
+            client_pre = ptsClient._findPrefixesForSuffixes(*supported_suf)
+            both = set(server_pre.keys()) & set(client_pre.keys())
+            if not both: return Endpoints.EMPTY_MATCH
+            # pick the prefix whose client address pool is least well-served
+            # TODO: this may be manipulated by clients, needs research
+            assert all(client_pre.itervalues()) # no pool is empty
+            pre = min(both, key=lambda p: ptsClient._avServed(client_pre[p]))
+            client_addr = ptsClient._serveReg(client_pre[pre])
+            if not client_addr: return Endpoints.EMPTY_MATCH
+            server_addr = ptsServer._serveReg(server_pre[pre])
+            # assume servers never run out
+            client_transport = ptsClient._endpoints[client_addr]
+            server_transport = ptsServer._endpoints[server_addr]
+            return Reg(client_addr, client_transport), Reg(server_addr, server_transport)
+
 
 class Handler(SocketServer.StreamRequestHandler):
     def __init__(self, *args, **kwargs):
@@ -277,16 +362,19 @@ class Handler(SocketServer.StreamRequestHandler):
             return self.error(u"TRANSPORT missing FROM param")
 
         try:
-            reg = get_reg_for_proxy(proxy_addr, transport_list)
+            client_reg, relay_reg = get_match_for_proxy(proxy_addr, transport_list)
         except Exception as e:
             return self.error(u"error getting match for proxy address %s: %%(cause)s" % safe_str(repr(proxy_spec)), e)
 
         check_back_in = get_check_back_in_for_proxy(proxy_addr)
 
-        if reg:
-            log(u"proxy (%s) gets client '%s' (transports: %s) (num relays: %s) (remaining regs: %d/%d)" %
-                (safe_str(repr(proxy_spec)), safe_str(unicode(reg)), reg.transports, num_relays(), num_unhandled_regs(), num_regs()))
-            print >> self.wfile, fac.render_transaction("OK", ("CLIENT", str(reg)), ("RELAY", fac.format_addr(reg.relay)), ("CHECK-BACK-IN", str(check_back_in)))
+        if client_reg:
+            log(u"proxy (%s) gets client '%s' (supported transports: %s) (num relays: %s) (remaining regs: %d/%d)" %
+                (safe_str(repr(proxy_spec)), safe_str(repr(client_reg.addr)), transport_list, num_relays(), num_unhandled_regs(), num_regs()))
+            print >> self.wfile, fac.render_transaction("OK",
+                ("CLIENT", fac.format_addr(client_reg.addr)),
+                ("RELAY", fac.format_addr(relay_reg.addr)),
+                ("CHECK-BACK-IN", str(check_back_in)))
         else:
             log(u"proxy (%s) gets none" % safe_str(repr(proxy_spec)))
             print >> self.wfile, fac.render_transaction("NONE", ("CHECK-BACK-IN", str(check_back_in)))
@@ -297,27 +385,21 @@ class Handler(SocketServer.StreamRequestHandler):
     # Example: PUT CLIENT="1.1.1.1:5555" TRANSPORT_CHAIN="obfs3|websocket"
     def do_PUT(self, params):
         # Check whether we recognize the transport chain in this registration request
-        transports_spec = fac.param_first("TRANSPORT_CHAIN", params)
-        if transports_spec is None:
+        transport = fac.param_first("TRANSPORT_CHAIN", params)
+        if transport is None:
             return self.error(u"PUT missing TRANSPORT_CHAIN param")
 
-        transports = parse_transport_chain(transports_spec)
-
+        transport = Transport.parse(transport)
         # See if we have relays that support this transport chain
-        if transports not in RELAYS:
-            log(u"Unrecognized transport chain: %s" % transports)
-            self.send_error() # XXX can we tell the flashproxy client of this error?
-            return False
-        # if we have relays that support this transport chain, we
-        # certainly have a regset for its outermost transport too.
-        assert(get_outermost_transport(transports) in REGSETS_IPV4)
+        if all(not pts.supports(transport) for pts in SERVERS.itervalues()):
+            return self.error(u"Unrecognized transport: %s" % transport)
 
         client_spec = fac.param_first("CLIENT", params)
         if client_spec is None:
             return self.error(u"PUT missing CLIENT param")
 
         try:
-            reg = Reg.parse(client_spec, transports)
+            reg = Reg.parse(client_spec, transport)
         except (UnknownTransport, ValueError) as e:
             # XXX should we throw a better error message to the client? Is it possible?
             return self.error(u"syntax error in %s: %%(cause)s" % safe_str(repr(client_spec)), e)
@@ -328,9 +410,9 @@ class Handler(SocketServer.StreamRequestHandler):
             return self.error(u"error putting reg %s: %%(cause)s" % safe_str(repr(client_spec)), e)
 
         if ok:
-            log(u"client %s (transports: %s) (remaining regs: %d/%d)" % (safe_str(unicode(reg)), reg.transports, num_unhandled_regs(), num_regs()))
+            log(u"client %s (transports: %s) (remaining regs: %d/%d)" % (safe_str(unicode(reg)), reg.transport, num_unhandled_regs(), num_regs()))
         else:
-            log(u"client %s (already present) (transports: %s) (remaining regs: %d/%d)" % (safe_str(unicode(reg)), reg.transports, num_unhandled_regs(), num_regs()))
+            log(u"client %s (already present) (transports: %s) (remaining regs: %d/%d)" % (safe_str(unicode(reg)), reg.transport, num_unhandled_regs(), num_regs()))
 
         self.send_ok()
         return True
@@ -341,48 +423,29 @@ class Server(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
     allow_reuse_address = True
 
 # Registration sets per-outermost-transport
-# {"websocket" : <RegSet for websocket>, "webrtc" : <RegSet for webrtc>}
-REGSETS_IPV4 = {}
-REGSETS_IPV6 = {}
+# Addresses are plain tuples (str(host), int(port))
 
-def num_regs():
-    """Return the total number of registrations."""
-    num_regs = 0
+CLIENTS = {
+    socket.AF_INET: Endpoints(af=socket.AF_INET, maxserve=MAX_PROXIES_PER_CLIENT),
+    socket.AF_INET6: Endpoints(af=socket.AF_INET6, maxserve=MAX_PROXIES_PER_CLIENT)
+}
+
+SERVERS = {
+    socket.AF_INET: Endpoints(af=socket.AF_INET),
+    socket.AF_INET6: Endpoints(af=socket.AF_INET6)
+}
 
-    # Iterate the regsets of each regset-dictionary, and count their
-    # registrations.
-    for regset in REGSETS_IPV4.values():
-        num_regs += len(regset)
-    for regset in REGSETS_IPV6.values():
-        num_regs += len(regset)
+def num_relays():
+    """Return the total number of relays."""
+    return sum(pts.getNumEndpoints() for pts in SERVERS.itervalues())
 
-    return num_regs
+def num_regs():
+    """Return the total number of registrations."""
+    return sum(pts.getNumEndpoints() for pts in CLIENTS.itervalues())
 
 def num_unhandled_regs():
     """Return the total number of unhandled registrations."""
-    num_regs = 0
-
-    # Iterate the regsets of each regset-dictionary, and count their
-    # unhandled registrations. The first tier of each regset contains
-    # the registrations with no assigned proxy.
-    for regset in REGSETS_IPV4.values():
-        num_regs += len(regset.tiers[0])
-    for regset in REGSETS_IPV6.values():
-        num_regs += len(regset.tiers[0])
-
-    return num_regs
-
-def get_regs(af, transport):
-    """Return the correct regs pool for the given address family and transport."""
-    if transport not in REGSETS_IPV4:
-        raise UnknownTransport("unknown transport '%s'" % transport)
-
-    if af == socket.AF_INET:
-        return REGSETS_IPV4[transport]
-    elif af == socket.AF_INET6:
-        return REGSETS_IPV6[transport]
-    else:
-        raise ValueError("unknown address family %d" % af)
+    return sum(pts.getNumUnservedEndpoints() for pts in CLIENTS.itervalues())
 
 def addr_af(addr_str):
     """Return the address family for an address string. This is a plain string,
@@ -390,26 +453,12 @@ def addr_af(addr_str):
     addrs = socket.getaddrinfo(addr_str, 0, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP, socket.AI_NUMERICHOST)
     return addrs[0][0]
 
-def get_reg_for_proxy(proxy_addr, transport_list):
-    """Get a client registration appropriate for the given proxy (one
-    of a matching address family). If 'transports' is set, try to find
-    a client registration that supports the outermost transport of a
-    transport chain."""
-    # XXX How should we prioritize transport matching? We currently
-    # just iterate the transport list that was provided by the flashproxy
-    for transport in transport_list:
-        addr_str = proxy_addr[0]
-        af = addr_af(addr_str)
-
-        try:
-            REGS = get_regs(af, transport)
-        except UnknownTransport as e:
-            log(u"%s" % e)
-            continue # move to the next transport
-
-        return REGS.fetch()
-
-    raise UnknownTransport("Could not find registration for transport list: %s" % str(transport_list))
+def get_match_for_proxy(proxy_addr, transport_list):
+    af = addr_af(proxy_addr[0])
+    try:
+        return Endpoints.match(CLIENTS[af], SERVERS[af], transport_list)
+    except ValueError as e:
+        raise UnknownTransport("Could not find registration for transport list: %s: %s" % (transport_list, e))
 
 def get_check_back_in_for_proxy(proxy_addr):
     """Get a CHECK-BACK-IN interval suitable for this proxy."""
@@ -417,29 +466,24 @@ def get_check_back_in_for_proxy(proxy_addr):
 
 def put_reg(reg):
     """Add a registration."""
-    addr_str = reg.host
-    af = addr_af(addr_str)
-    REGS = get_regs(af, get_outermost_transport(reg.transports))
-    return REGS.add(reg)
+    af = addr_af(reg.addr[0])
+    return CLIENTS[af].addEndpoint(reg.addr, reg.transport)
 
-def parse_relay_file(filename):
+def parse_relay_file(servers, fp):
     """Parse a file containing Tor relays that we can point proxies to.
     Throws ValueError on a parsing error. Each line contains a transport chain
     and an address, for example
         obfs2|websocket 1.4.6.1:4123
     """
-    relays = {}
-    with open(filename) as f:
-        for line in f:
-            try:
-                transport_spec, addr_spec = line.strip().split()
-            except ValueError, e:
-                raise ValueError("Wrong line format: %s." % repr(line))
-            addr = fac.parse_addr_spec(addr_spec, defport=DEFAULT_RELAY_PORT, resolve=True)
-            transports = parse_transport_chain(transport_spec)
-            relays.setdefault(transports, [])
-            relays[transports].append(addr)
-    return relays
+    for line in fp.readlines():
+        try:
+            transport_spec, addr_spec = line.strip().split()
+        except ValueError, e:
+            raise ValueError("Wrong line format: %s." % repr(line))
+        addr = fac.parse_addr_spec(addr_spec, defport=DEFAULT_RELAY_PORT, resolve=True)
+        transport = Transport.parse(transport_spec)
+        af = addr_af(addr[0])
+        servers[af].addEndpoint(addr, transport)
 
 def main():
     opts, args = getopt.gnu_getopt(sys.argv[1:], "dhl:p:r:",
@@ -474,17 +518,12 @@ obfs2|websocket 1.4.6.1:4123\
 """
         sys.exit(1)
 
-    RELAYS.update(parse_relay_file(options.relay_filename))
-
-    if not RELAYS:
-        print >> sys.stderr, u"Warning: no relays configured."
-
-    # Create RegSets for our supported transports
-    for transport in RELAYS.keys():
-        outermost_transport = get_outermost_transport(transport)
-        if outermost_transport not in REGSETS_IPV4:
-            REGSETS_IPV4[outermost_transport] = RegSet()
-            REGSETS_IPV6[outermost_transport] = RegSet()
+    try:
+        with open(options.relay_filename) as fp:
+            parse_relay_file(SERVERS, fp)
+    except ValueError as e:
+        print >> sys.stderr, u"Could not parse file '%s': %s" % (repr(a), str(e))
+        sys.exit(1)
 
     # Setup log file
     if options.log_filename:
@@ -499,7 +538,8 @@ obfs2|websocket 1.4.6.1:4123\
     server = Server(addrinfo[4], Handler)
 
     log(u"start on %s" % fac.format_addr(addrinfo[4]))
-    log(u"using relays %s" % str(RELAYS))
+    log(u"using IPv4 relays %s" % str(SERVERS[socket.AF_INET]._endpoints))
+    log(u"using IPv6 relays %s" % str(SERVERS[socket.AF_INET6]._endpoints))
 
     if options.daemonize:
         log(u"daemonizing")
diff --git a/facilitator/facilitator-test b/facilitator/facilitator-test
index e39cecd..b81b84e 100755
--- a/facilitator/facilitator-test
+++ b/facilitator/facilitator-test
@@ -7,7 +7,7 @@ import tempfile
 import time
 import unittest
 
-from facilitator import parse_transport_chain
+from StringIO import StringIO
+from facilitator import Transport, Reg, Endpoints, parse_relay_file
 import fac
 
 FACILITATOR_HOST = "127.0.0.1"
@@ -23,11 +23,169 @@ def gimme_socket(host, port):
     s.connect(addrinfo[4])
     return s
 
+class EndpointsTest(unittest.TestCase):
+
+    def setUp(self):
+        self.pts = Endpoints(af=socket.AF_INET)
+
+    def _observeProxySupporting(self, *supported_suf):
+        # semantically observe the existence of a proxy, to make our intent
+        # a bit clearer than calling _findPrefixesForSuffixes directly
+        self.pts._findPrefixesForSuffixes(*supported_suf)
+
+    def test_addEndpoints_twice(self):
+        self.pts.addEndpoint("A", "a|b|p")
+        self.assertFalse(self.pts.addEndpoint("A", "zzz"))
+        self.assertEquals(self.pts._endpoints["A"], Transport("a|b", "p"))
+
+    def test_addEndpoints_lazy_indexing(self):
+        self.pts.addEndpoint("A", "a|b|p")
+        default_index = {"websocket": set()} # we always index known_suffixes
+
+        # no index until we've asked for it
+        self.assertEquals(self.pts._indexes, default_index)
+        self._observeProxySupporting("p")
+        self.assertEquals(self.pts._indexes["p"], set("A"))
+
+        # indexes are updated correctly after observing new addresses
+        self.pts.addEndpoint("B", "c|p")
+        self.assertEquals(self.pts._indexes["p"], set("AB"))
+
+        # indexes are updated correctly after observing new proxies
+        self.pts.addEndpoint("C", "a|q")
+        self._observeProxySupporting("q")
+        self.assertEquals(self.pts._indexes["q"], set("C"))
+
+    def test_supports_default(self):
+        # we know there are websocket-capable proxies out there;
+        # support them implicitly without needing to see a proxy.
+        self.pts.addEndpoint("A", "obfs3|websocket")
+        self.assertTrue(self.pts.supports("obfs3|websocket"))
+        self.assertFalse(self.pts.supports("xxx|websocket"))
+        self.assertFalse(self.pts.supports("websocket"))
+        self.assertFalse(self.pts.supports("unknownwhat"))
+        # doesn't matter what the first part is
+        self.pts.addEndpoint("B", "xxx|websocket")
+        self.assertTrue(self.pts.supports("xxx|websocket"))
+
+    def test_supports_seen_proxy(self):
+        # OTOH if some 3rd-party proxy decides to implement its own transport
+        # we are fully capable of supporting them too, but only if we have
+        # an endpoint that also speaks it.
+        self.assertFalse(self.pts.supports("obfs3|unknownwhat"))
+        self._observeProxySupporting("unknownwhat")
+        self.assertFalse(self.pts.supports("obfs3|unknownwhat"))
+        self.pts.addEndpoint("A", "obfs3|unknownwhat")
+        self.assertTrue(self.pts.supports("obfs3|unknownwhat"))
+        self.assertFalse(self.pts.supports("obfs2|unknownwhat"))
+
+    def _test_serveReg_maxserve_infinite_roundrobin(self):
+        # case for servers, they never exhaust
+        self.pts.addEndpoint("A", "a|p")
+        self.pts.addEndpoint("B", "a|p")
+        self.pts.addEndpoint("C", "a|p")
+        for i in xrange(64): # 64 is infinite ;)
+            served = set()
+            served.add(self.pts._serveReg("ABC"))
+            served.add(self.pts._serveReg("ABC"))
+            served.add(self.pts._serveReg("ABC"))
+            self.assertEquals(served, set("ABC"))
+
+    def _test_serveReg_maxserve_finite_exhaustion(self):
+        # case for clients, we don't want to keep serving them
+        self.pts = Endpoints(af=socket.AF_INET, maxserve=5)
+        self.pts.addEndpoint("A", "a|p")
+        self.pts.addEndpoint("B", "a|p")
+        self.pts.addEndpoint("C", "a|p")
+        # test getNumUnservedEndpoints whilst we're at it
+        self.assertEquals(self.pts.getNumUnservedEndpoints(), 3)
+        for i in xrange(5):
+            served = set()
+            served.add(self.pts._serveReg("ABC"))
+            served.add(self.pts._serveReg("ABC"))
+            served.add(self.pts._serveReg("ABC"))
+            self.assertEquals(served, set("ABC"))
+        self.assertEquals(None, self.pts._serveReg("ABC"))
+        self.assertEquals(self.pts.getNumUnservedEndpoints(), 0)
+
+    def test_match_normal(self):
+        self.pts.addEndpoint("A", "a|p")
+        self.pts2 = Endpoints(af=socket.AF_INET)
+        self.pts2.addEndpoint("B", "a|p")
+        self.pts2.addEndpoint("C", "b|p")
+        self.pts2.addEndpoint("D", "a|q")
+        expected = (Reg("A", Transport("a","p")), Reg("B", Transport("a","p")))
+        empty = Endpoints.EMPTY_MATCH
+        self.assertEquals(expected, Endpoints.match(self.pts, self.pts2, ["p"]))
+        self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+
+    def test_match_unequal_client_server(self):
+        self.pts.addEndpoint("A", "a|p")
+        self.pts2 = Endpoints(af=socket.AF_INET)
+        self.pts2.addEndpoint("B", "a|q")
+        expected = (Reg("A", Transport("a","p")), Reg("B", Transport("a","q")))
+        empty = Endpoints.EMPTY_MATCH
+        self.assertEquals(expected, Endpoints.match(self.pts, self.pts2, ["p", "q"]))
+        self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["p"]))
+        self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["q"]))
+        self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+
+    def test_match_raw_server(self):
+        self.pts.addEndpoint("A", "p")
+        self.pts2 = Endpoints(af=socket.AF_INET)
+        self.pts2.addEndpoint("B", "p")
+        expected = (Reg("A", Transport("","p")), Reg("B", Transport("","p")))
+        empty = Endpoints.EMPTY_MATCH
+        self.assertEquals(expected, Endpoints.match(self.pts, self.pts2, ["p"]))
+        self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+
+    def test_match_many_prefixes(self):
+        self.pts.addEndpoint("A", "a|p")
+        self.pts.addEndpoint("B", "b|p")
+        self.pts.addEndpoint("C", "p")
+        self.pts2 = Endpoints(af=socket.AF_INET)
+        self.pts2.addEndpoint("D", "a|p")
+        self.pts2.addEndpoint("E", "b|p")
+        self.pts2.addEndpoint("F", "p")
+        # this test ensures we have a sane policy for selecting between prefix pools
+        expected = set()
+        expected.add((Reg("A", Transport("a","p")), Reg("D", Transport("a","p"))))
+        expected.add((Reg("B", Transport("b","p")), Reg("E", Transport("b","p"))))
+        expected.add((Reg("C", Transport("","p")), Reg("F", Transport("","p"))))
+        result = set()
+        result.add(Endpoints.match(self.pts, self.pts2, ["p"]))
+        result.add(Endpoints.match(self.pts, self.pts2, ["p"]))
+        result.add(Endpoints.match(self.pts, self.pts2, ["p"]))
+        empty = Endpoints.EMPTY_MATCH
+        self.assertEquals(expected, result)
+        self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+        self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+        self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+
 class FacilitatorTest(unittest.TestCase):
 
-    def test_parse_transport_chain(self):
-        self.assertEquals(parse_transport_chain("a"), ("a",))
-        self.assertEquals(parse_transport_chain("a|b|c"), ("a","b","c"))
+    def test_transport_parse(self):
+        self.assertEquals(Transport.parse("a"), Transport("", "a"))
+        self.assertEquals(Transport.parse("|a"), Transport("", "a"))
+        self.assertEquals(Transport.parse("a|b|c"), Transport("a|b","c"))
+        self.assertEquals(Transport.parse(Transport("a|b","c")), Transport("a|b","c"))
+        self.assertRaises(ValueError, Transport, "", "")
+        self.assertRaises(ValueError, Transport, "a", "")
+        self.assertRaises(ValueError, Transport.parse, "")
+        self.assertRaises(ValueError, Transport.parse, "|")
+        self.assertRaises(ValueError, Transport.parse, "a|")
+        self.assertRaises(ValueError, Transport.parse, ["a"])
+        self.assertRaises(ValueError, Transport.parse, [Transport("a", "b")])
+
+    def test_parse_relay_file(self):
+        fp = StringIO()
+        fp.write("websocket 0.0.1.0:1\n")
+        fp.flush()
+        fp.seek(0)
+        af = socket.AF_INET
+        servers = { af: Endpoints(af=af) }
+        parse_relay_file(servers, fp)
+        self.assertEquals(servers[af]._endpoints, {('0.0.1.0', 1): Transport('', 'websocket')})
 
 class FacilitatorProcTest(unittest.TestCase):
     IPV4_CLIENT_ADDR = ("1.1.1.1", 9000)
