[tor-commits] [flashproxy/master] - fix match() trying to access info of deleted address
infinity0 at torproject.org
infinity0 at torproject.org
Mon Oct 28 14:47:41 UTC 2013
commit 9b155c9b2029ec779666a661e5eecf8226cef6c8
Author: Ximin Luo <infinity0 at gmx.com>
Date: Fri Oct 11 17:12:07 2013 +0100
- fix match() trying to access info of deleted address
---
facilitator/facilitator | 16 +++++++---------
facilitator/facilitator-test | 38 ++++++++++++++++++++++++++------------
2 files changed, 33 insertions(+), 21 deletions(-)
diff --git a/facilitator/facilitator b/facilitator/facilitator
index 9e11826..844836a 100755
--- a/facilitator/facilitator
+++ b/facilitator/facilitator
@@ -161,7 +161,7 @@ class Endpoints(object):
def _serveReg(self, addrpool):
"""
:param list addrpool: List of candidate addresses.
- :returns: An address of an endpoint from the given pool. The serve
+ :returns: An Endpoint whose address is from the given pool. The serve
counter for that address is also incremented, and if it hits
self._maxserve the endpoint is removed from this collection.
:raises: KeyError if any address is not registered with this collection
@@ -170,9 +170,10 @@ class Endpoints(object):
prio_addr = min(addrpool, key=lambda a: self._served[a])
assert self._served[prio_addr] < self._maxserve
self._served[prio_addr] += 1
+ transport = self._endpoints[prio_addr]
if self._served[prio_addr] == self._maxserve:
self.delEndpoint(prio_addr)
- return prio_addr
+ return Endpoint(prio_addr, transport)
EMPTY_MATCH = (None, None)
@staticmethod
@@ -195,16 +196,13 @@ class Endpoints(object):
# find a client to serve
client_pool = [addr for inner in both for addr in client_inner[inner]]
assert len(client_pool)
- client_addr = ptsClient._serveReg(client_pool)
+ client_reg = ptsClient._serveReg(client_pool)
# find a server to serve that has the same inner transport
- inner = ptsClient._endpoints[client_addr].inner
+ inner = client_reg.transport.inner
assert inner in server_inner and len(server_inner[inner])
- server_addr = ptsServer._serveReg(server_inner[inner])
+ server_reg = ptsServer._serveReg(server_inner[inner])
# assume servers never run out
- client_transport = ptsClient._endpoints[client_addr]
- server_transport = ptsServer._endpoints[server_addr]
- return (Endpoint(client_addr, client_transport),
- Endpoint(server_addr, server_transport))
+ return (client_reg, server_reg)
class Handler(SocketServer.StreamRequestHandler):
diff --git a/facilitator/facilitator-test b/facilitator/facilitator-test
index 8143348..3efe34d 100755
--- a/facilitator/facilitator-test
+++ b/facilitator/facilitator-test
@@ -60,9 +60,9 @@ class EndpointsTest(unittest.TestCase):
self.pts.addEndpoint("C", "a|p")
for i in xrange(64): # 64 is infinite ;)
served = set()
- served.add(self.pts._serveReg("ABC"))
- served.add(self.pts._serveReg("ABC"))
- served.add(self.pts._serveReg("ABC"))
+ served.add(self.pts._serveReg("ABC").addr)
+ served.add(self.pts._serveReg("ABC").addr)
+ served.add(self.pts._serveReg("ABC").addr)
self.assertEquals(served, set("ABC"))
def test_serveReg_maxserve_finite_exhaustion(self):
@@ -74,25 +74,25 @@ class EndpointsTest(unittest.TestCase):
# test getNumUnservedEndpoints whilst we're at it
self.assertEquals(self.pts.getNumUnservedEndpoints(), 3)
served = set()
- served.add(self.pts._serveReg("ABC"))
+ served.add(self.pts._serveReg("ABC").addr)
self.assertEquals(self.pts.getNumUnservedEndpoints(), 2)
- served.add(self.pts._serveReg("ABC"))
+ served.add(self.pts._serveReg("ABC").addr)
self.assertEquals(self.pts.getNumUnservedEndpoints(), 1)
- served.add(self.pts._serveReg("ABC"))
+ served.add(self.pts._serveReg("ABC").addr)
self.assertEquals(self.pts.getNumUnservedEndpoints(), 0)
self.assertEquals(served, set("ABC"))
for i in xrange(5-2):
served = set()
- served.add(self.pts._serveReg("ABC"))
- served.add(self.pts._serveReg("ABC"))
- served.add(self.pts._serveReg("ABC"))
+ served.add(self.pts._serveReg("ABC").addr)
+ served.add(self.pts._serveReg("ABC").addr)
+ served.add(self.pts._serveReg("ABC").addr)
self.assertEquals(served, set("ABC"))
remaining = set("ABC")
- remaining.remove(self.pts._serveReg(remaining))
+ remaining.remove(self.pts._serveReg(remaining).addr)
self.assertRaises(KeyError, self.pts._serveReg, "ABC")
- remaining.remove(self.pts._serveReg(remaining))
+ remaining.remove(self.pts._serveReg(remaining).addr)
self.assertRaises(KeyError, self.pts._serveReg, "ABC")
- remaining.remove(self.pts._serveReg(remaining))
+ remaining.remove(self.pts._serveReg(remaining).addr)
self.assertRaises(KeyError, self.pts._serveReg, "ABC")
self.assertEquals(remaining, set())
self.assertEquals(self.pts.getNumUnservedEndpoints(), 0)
@@ -151,6 +151,20 @@ class EndpointsTest(unittest.TestCase):
self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
self.assertEquals(empty, Endpoints.match(self.pts, self.pts2, ["x"]))
+ def test_match_exhaustion(self):
+ self.pts.addEndpoint("A", "p")
+ self.pts2 = Endpoints(af=socket.AF_INET, maxserve=2)
+ self.pts2.addEndpoint("B", "p")
+ print self.pts2._indexes, self.pts2._served
+ Endpoints.match(self.pts2, self.pts, ["p"])
+ print self.pts2._indexes, self.pts2._served
+ Endpoints.match(self.pts2, self.pts, ["p"])
+ empty = Endpoints.EMPTY_MATCH
+ self.assertTrue("B" not in self.pts2._endpoints)
+ self.assertTrue("B" not in self.pts2._indexes["p"][""])
+ self.assertEquals(empty, Endpoints.match(self.pts2, self.pts, ["p"]))
+
+
class FacilitatorTest(unittest.TestCase):
def test_transport_parse(self):
More information about the tor-commits mailing list