[tor-commits] [flashproxy/master] simplify Endpoints a bit

infinity0 at torproject.org infinity0 at torproject.org
Mon Oct 28 14:47:41 UTC 2013


commit e027db3d3348944c7f75cda52c41b020f02b2aba
Author: Ximin Luo <infinity0 at gmx.com>
Date:   Fri Oct 11 15:21:14 2013 +0100

    simplify Endpoints a bit
    - remove Endpoints.supports() and related code — the important thing is actually whether a proxy supports a transport, not whether we have relays that support it
    - use defaultdict to get rid of some boilerplate, and populate _indexes unconditionally
---
 facilitator/facilitator      |   89 +++++++++++++++++-------------------------
 facilitator/facilitator-test |   61 ++++++++---------------------
 2 files changed, 53 insertions(+), 97 deletions(-)

diff --git a/facilitator/facilitator b/facilitator/facilitator
index d011013..9e11826 100755
--- a/facilitator/facilitator
+++ b/facilitator/facilitator
@@ -8,6 +8,7 @@ import sys
 import threading
 import time
 import traceback
+from collections import defaultdict
 
 import fac
 from fac import Transport, Endpoint
@@ -26,6 +27,7 @@ CLIENT_TIMEOUT = 1.0
 READLINE_MAX_LENGTH = 10240
 
 MAX_PROXIES_PER_CLIENT = 5
+DEFAULT_OUTER_TRANSPORTS = ["websocket"]
 
 LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
 
@@ -45,6 +47,7 @@ class options(object):
     pid_filename = None
     privdrop_username = None
     safe_logging = True
+    outer_transports = DEFAULT_OUTER_TRANSPORTS
 
 def usage(f = sys.stdout):
     print >> f, """\
@@ -59,11 +62,15 @@ again. Listen on 127.0.0.1 and port PORT (by default %(port)d).
       --pidfile FILENAME    write PID to FILENAME after daemonizing.
       --privdrop-user USER  switch UID and GID to those of USER.
   -r, --relay-file RELAY    learn relays from FILE.
+      --outer-transports TRANSPORTS
+                            comma-sep list of outer transports to accept proxies
+                            for (by default %(outer-transports)s)
       --unsafe-logging      don't scrub IP addresses from logs.\
 """ % {
     "progname": sys.argv[0],
     "port": DEFAULT_LISTEN_PORT,
     "log": DEFAULT_LOG_FILENAME,
+    "outer-transports": ",".join(DEFAULT_OUTER_TRANSPORTS)
 }
 
 def safe_str(s):
@@ -87,16 +94,13 @@ class Endpoints(object):
 
     matchingLock = threading.Condition()
 
-    def __init__(self, af, maxserve=float("inf"), known_outer=("websocket",)):
+    def __init__(self, af, maxserve=float("inf")):
         self.af = af
         self._maxserve = maxserve
         self._endpoints = {} # address -> transport
-        self._indexes = {} # outer -> [ addresses ]
+        self._indexes = defaultdict(lambda: defaultdict(set)) # outer -> inner -> [ addresses ]
         self._served = {} # address -> num_times_served
         self._cv = threading.Condition()
-        self.known_outer = set(known_outer)
-        for outer in self.known_outer:
-            self._ensureIndexForOuter(outer)
 
     def getNumEndpoints(self):
         """:returns: the number of endpoints known to us."""
@@ -119,9 +123,10 @@ class Endpoints(object):
         transport = Transport.parse(transport)
         with self._cv:
             if addr in self._endpoints: return False
+            inner, outer = transport
             self._endpoints[addr] = transport
             self._served[addr] = 0
-            self._addAddrIntoIndexes(addr)
+            self._indexes[outer][inner].add(addr)
             self._cv.notify()
             return True
 
@@ -133,43 +138,26 @@ class Endpoints(object):
         """
         with self._cv:
             if addr not in self._endpoints: return False
-            self._delAddrFromIndexes(addr)
+            inner, outer = self._endpoints[addr]
+            self._indexes[outer][inner].remove(addr) # TODO(infinity0): maybe delete empty bins
             del self._served[addr]
             del self._endpoints[addr]
             self._cv.notify()
             return True
 
-    def supports(self, transport):
-        """
-        Estimate whether we support the given transport. May give false
-        positives, but doing a proper match later on will catch these.
-
-        :returns: True if we know, or have met, proxies that might be able
-            to satisfy the requested transport against our known endpoints.
-        """
-        transport = Transport.parse(transport)
-        with self._cv:
-            known_inner = self._findInnerForOuter(*self.known_outer).keys()
-            inner, outer = transport.inner, transport.outer
-            return inner in known_inner and outer in self.known_outer
-
     def _findInnerForOuter(self, *supported_outer):
         """
         :returns: { inner: [addr] }, where each address supports some outer
             from supported_outer. TODO(infinity0): describe better
         """
-        self.known_outer.update(supported_outer)
-        inners = {}
-        for outer in supported_outer:
-            self._ensureIndexForOuter(outer)
-            for addr in self._indexes[outer]:
-                inner = self._endpoints[addr].inner
-                inners.setdefault(inner, set()).add(addr)
+        inners = defaultdict(set)
+        for outer in set(supported_outer) & set(self._indexes.iterkeys()):
+            for inner, addrs in self._indexes[outer].iteritems():
+                if addrs:
+                    # don't add empty bins, to avoid false-positive key checks
+                    inners[inner].update(addrs)
         return inners
 
-    def _avServed(self, addrpool):
-        return sum(self._served[a] for a in addrpool) / float(len(addrpool))
-
     def _serveReg(self, addrpool):
         """
         :param list addrpool: List of candidate addresses.
@@ -186,20 +174,6 @@ class Endpoints(object):
             self.delEndpoint(prio_addr)
         return prio_addr
 
-    def _ensureIndexForOuter(self, outer):
-        if outer in self._indexes: return
-        addrs = set(addr for addr, transport in self._endpoints.iteritems()
-                         if transport.outer == outer)
-        self._indexes[outer] = addrs
-
-    def _addAddrIntoIndexes(self, addr):
-        outer = self._endpoints[addr].outer
-        if outer in self._indexes: self._indexes[outer].add(addr)
-
-    def _delAddrFromIndexes(self, addr):
-        outer = self._endpoints[addr].outer
-        if outer in self._indexes: self._indexes[outer].remove(addr)
-
     EMPTY_MATCH = (None, None)
     @staticmethod
     def match(ptsClient, ptsServer, supported_outer):
@@ -208,7 +182,9 @@ class Endpoints(object):
             the available endpoints that can satisfy supported_outer.
         """
         if ptsClient.af != ptsServer.af:
-            raise ValueError("address family not equal!")
+            raise ValueError("address family not equal")
+        if ptsServer._maxserve < float("inf"):
+            raise ValueError("servers mustn't run out")
         # need to operate on both structures
         # so hold both locks plus a pair-wise lock
         with Endpoints.matchingLock, ptsClient._cv, ptsServer._cv:
@@ -216,16 +192,19 @@ class Endpoints(object):
             client_inner = ptsClient._findInnerForOuter(*supported_outer)
             both = set(server_inner.keys()) & set(client_inner.keys())
             if not both: return Endpoints.EMPTY_MATCH
-            # pick the inner whose client address pool is least well-served
-            # TODO: this may be manipulated by clients, needs research
-            assert all(client_inner.itervalues()) # no pool is empty
-            inner = min(both, key=lambda p: ptsClient._avServed(client_inner[p]))
-            client_addr = ptsClient._serveReg(client_inner[inner])
+            # find a client to serve
+            client_pool = [addr for inner in both for addr in client_inner[inner]]
+            assert len(client_pool)
+            client_addr = ptsClient._serveReg(client_pool)
+            # find a server to serve that has the same inner transport
+            inner = ptsClient._endpoints[client_addr].inner
+            assert inner in server_inner and len(server_inner[inner])
             server_addr = ptsServer._serveReg(server_inner[inner])
             # assume servers never run out
             client_transport = ptsClient._endpoints[client_addr]
             server_transport = ptsServer._endpoints[server_addr]
-            return Endpoint(client_addr, client_transport), Endpoint(server_addr, server_transport)
+            return (Endpoint(client_addr, client_transport),
+                    Endpoint(server_addr, server_transport))
 
 
 class Handler(SocketServer.StreamRequestHandler):
@@ -354,7 +333,7 @@ class Handler(SocketServer.StreamRequestHandler):
 
         transport = Transport.parse(transport)
         # See if we have relays that support this transport
-        if all(not pts.supports(transport) for pts in SERVERS.itervalues()):
+        if transport.outer not in options.outer_transports:
             return self.error(u"Unrecognized transport: %s" % transport)
 
         client_spec = fac.param_first("CLIENT", params)
@@ -445,6 +424,8 @@ def parse_relay_file(servers, fp):
             raise ValueError("Wrong line format: %s." % repr(line))
         addr = fac.parse_addr_spec(addr_spec, defport=DEFAULT_RELAY_PORT, resolve=True)
         transport = Transport.parse(transport_spec)
+        if transport.outer not in options.outer_transports:
+            raise ValueError(u"Unrecognized transport: %s" % transport)
         af = addr_af(addr[0])
         servers[af].addEndpoint(addr, transport)
 
@@ -468,6 +449,8 @@ def main():
             options.privdrop_username = a
         elif o == "-r" or o == "--relay-file":
             options.relay_filename = a
+        elif o == "--outer-transports":
+            options.outer_transports = a.split(",")
         elif o == "--unsafe-logging":
             options.safe_logging = False
 
diff --git a/facilitator/facilitator-test b/facilitator/facilitator-test
index 8e06053..8143348 100755
--- a/facilitator/facilitator-test
+++ b/facilitator/facilitator-test
@@ -29,56 +29,29 @@ class EndpointsTest(unittest.TestCase):
     def setUp(self):
         self.pts = Endpoints(af=socket.AF_INET)
 
-    def _observeProxySupporting(self, *supported_outer):
-        # semantically observe the existence of a proxy, to make our intent
-        # a bit clearer than simply calling _findInnerForOuter
-        self.pts._findInnerForOuter(*supported_outer)
-
     def test_addEndpoints_twice(self):
         self.pts.addEndpoint("A", "a|b|p")
         self.assertFalse(self.pts.addEndpoint("A", "zzz"))
         self.assertEquals(self.pts._endpoints["A"], Transport("a|b", "p"))
 
-    def test_addEndpoints_lazy_indexing(self):
+    def test_delEndpoints_twice(self):
+        self.pts.addEndpoint("A", "a|b|p")
+        self.assertTrue(self.pts.delEndpoint("A"))
+        self.assertFalse(self.pts.delEndpoint("A"))
+        self.assertEquals(self.pts._endpoints.get("A"), None)
+
+    def test_Endpoints_indexing(self):
+        self.assertEquals(self.pts._indexes.get("p"), None)
+        # test defaultdict works as expected
+        self.assertEquals(self.pts._indexes["p"]["a|b"], set(""))
         self.pts.addEndpoint("A", "a|b|p")
-        default_index = {"websocket": set()} # we always index known_outer
-
-        # no index until we've asked for it
-        self.assertEquals(self.pts._indexes, default_index)
-        self._observeProxySupporting("p")
-        self.assertEquals(self.pts._indexes["p"], set("A"))
-
-        # indexes are updated correctly after observing new addresses
-        self.pts.addEndpoint("B", "c|p")
-        self.assertEquals(self.pts._indexes["p"], set("AB"))
-
-        # indexes are updated correctly after observing new proxies
-        self.pts.addEndpoint("C", "a|q")
-        self._observeProxySupporting("q")
-        self.assertEquals(self.pts._indexes["q"], set("C"))
-
-    def test_supports_default(self):
-        # we know there are websocket-capable proxies out there;
-        # support them implicitly without needing to see a proxy.
-        self.pts.addEndpoint("A", "obfs3|websocket")
-        self.assertTrue(self.pts.supports("obfs3|websocket"))
-        self.assertFalse(self.pts.supports("xxx|websocket"))
-        self.assertFalse(self.pts.supports("websocket"))
-        self.assertFalse(self.pts.supports("unknownwhat"))
-        # doesn't matter what the first part is
-        self.pts.addEndpoint("B", "xxx|websocket")
-        self.assertTrue(self.pts.supports("xxx|websocket"))
-
-    def test_supports_seen_proxy(self):
-        # OTOH if some 3rd-party proxy decides to implement its own transport
-        # we are fully capable of supporting them too, but only if we have
-        # an endpoint that also speaks it.
-        self.assertFalse(self.pts.supports("obfs3|unknownwhat"))
-        self._observeProxySupporting("unknownwhat")
-        self.assertFalse(self.pts.supports("obfs3|unknownwhat"))
-        self.pts.addEndpoint("A", "obfs3|unknownwhat")
-        self.assertTrue(self.pts.supports("obfs3|unknownwhat"))
-        self.assertFalse(self.pts.supports("obfs2|unknownwhat"))
+        self.assertEquals(self.pts._indexes["p"]["a|b"], set("A"))
+        self.pts.addEndpoint("B", "a|b|p")
+        self.assertEquals(self.pts._indexes["p"]["a|b"], set("AB"))
+        self.pts.delEndpoint("A")
+        self.assertEquals(self.pts._indexes["p"]["a|b"], set("B"))
+        self.pts.delEndpoint("B")
+        self.assertEquals(self.pts._indexes["p"]["a|b"], set(""))
 
     def test_serveReg_maxserve_infinite_roundrobin(self):
         # case for servers, they never exhaust





More information about the tor-commits mailing list