[tor-commits] [flashproxy/master] drop registrations that hit _maxserve, and fix tests to run properly
infinity0 at torproject.org
infinity0 at torproject.org
Mon Oct 28 14:47:41 UTC 2013
commit d658db1f29d698a7820ef0d4bdf1120e01fb35af
Author: Ximin Luo <infinity0 at gmx.com>
Date: Wed Oct 9 23:49:38 2013 +0100
drop registrations that hit _maxserve, and fix tests to run properly
- the old behaviour was based on an incorrect understanding of the previous iteration of the code
---
facilitator/facilitator | 20 ++++++++++----------
facilitator/facilitator-test | 23 +++++++++++++++++++----
2 files changed, 29 insertions(+), 14 deletions(-)
diff --git a/facilitator/facilitator b/facilitator/facilitator
index f3ae79b..07038d1 100755
--- a/facilitator/facilitator
+++ b/facilitator/facilitator
@@ -173,17 +173,18 @@ class Endpoints(object):
def _serveReg(self, addrpool):
"""
:param list addrpool: List of candidate addresses.
- :returns: An address of an endpoint from the given pool, or None if all
- endpoints have already been served _maxserve times. The serve
- counter for that address is also incremented.
+ :returns: An address of an endpoint from the given pool. The serve
+ counter for that address is also incremented, and if it hits
+ self._maxserve the endpoint is removed from this collection.
+ :raises: KeyError if any address is not registered with this collection
"""
- if not addrpool: return None
+ if not addrpool: raise ValueError("gave empty address pool")
prio_addr = min(addrpool, key=lambda a: self._served[a])
- if self._served[prio_addr] < self._maxserve:
- self._served[prio_addr] += 1
- return prio_addr
- else:
- return None
+ assert self._served[prio_addr] < self._maxserve
+ self._served[prio_addr] += 1
+ if self._served[prio_addr] == self._maxserve:
+ self.delEndpoint(prio_addr)
+ return prio_addr
def _ensureIndexForSuffix(self, suf):
if suf in self._indexes: return
@@ -227,7 +228,6 @@ class Endpoints(object):
assert all(client_pre.itervalues()) # no pool is empty
pre = min(both, key=lambda p: ptsClient._avServed(client_pre[p]))
client_addr = ptsClient._serveReg(client_pre[pre])
- if not client_addr: return Endpoints.EMPTY_MATCH
server_addr = ptsServer._serveReg(server_pre[pre])
# assume servers never run out
client_transport = ptsClient._endpoints[client_addr]
diff --git a/facilitator/facilitator-test b/facilitator/facilitator-test
index 6709221..3f2fbef 100755
--- a/facilitator/facilitator-test
+++ b/facilitator/facilitator-test
@@ -80,7 +80,7 @@ class EndpointsTest(unittest.TestCase):
self.assertTrue(self.pts.supports("obfs3|unknownwhat"))
self.assertFalse(self.pts.supports("obfs2|unknownwhat"))
- def _test_serveReg_maxserve_infinite_roundrobin(self):
+ def test_serveReg_maxserve_infinite_roundrobin(self):
# case for servers, they never exhaust
self.pts.addEndpoint("A", "a|p")
self.pts.addEndpoint("B", "a|p")
@@ -92,7 +92,7 @@ class EndpointsTest(unittest.TestCase):
served.add(self.pts._serveReg("ABC"))
self.assertEquals(served, set("ABC"))
- def _test_serveReg_maxserve_finite_exhaustion(self):
+ def test_serveReg_maxserve_finite_exhaustion(self):
# case for clients, we don't want to keep serving them
self.pts = Endpoints(af=socket.AF_INET, maxserve=5)
self.pts.addEndpoint("A", "a|p")
@@ -100,13 +100,28 @@ class EndpointsTest(unittest.TestCase):
self.pts.addEndpoint("C", "a|p")
# test getNumUnservedEndpoints whilst we're at it
self.assertEquals(self.pts.getNumUnservedEndpoints(), 3)
- for i in xrange(5):
+ served = set()
+ served.add(self.pts._serveReg("ABC"))
+ self.assertEquals(self.pts.getNumUnservedEndpoints(), 2)
+ served.add(self.pts._serveReg("ABC"))
+ self.assertEquals(self.pts.getNumUnservedEndpoints(), 1)
+ served.add(self.pts._serveReg("ABC"))
+ self.assertEquals(self.pts.getNumUnservedEndpoints(), 0)
+ self.assertEquals(served, set("ABC"))
+ for i in xrange(5-2):
served = set()
served.add(self.pts._serveReg("ABC"))
served.add(self.pts._serveReg("ABC"))
served.add(self.pts._serveReg("ABC"))
self.assertEquals(served, set("ABC"))
- self.assertEquals(None, self.pts._serveReg("ABC"))
+ remaining = set("ABC")
+ remaining.remove(self.pts._serveReg(remaining))
+ self.assertRaises(KeyError, self.pts._serveReg, "ABC")
+ remaining.remove(self.pts._serveReg(remaining))
+ self.assertRaises(KeyError, self.pts._serveReg, "ABC")
+ remaining.remove(self.pts._serveReg(remaining))
+ self.assertRaises(KeyError, self.pts._serveReg, "ABC")
+ self.assertEquals(remaining, set())
self.assertEquals(self.pts.getNumUnservedEndpoints(), 0)
def test_match_normal(self):
More information about the tor-commits mailing list