[tor-commits] [stem/master] Remove Query's Synchronous usage
atagar at torproject.org
atagar at torproject.org
Sun Nov 8 01:24:38 UTC 2020
commit 7ce8a5e090fc95bfb874299d61c824638d5242f4
Author: Damian Johnson <atagar at torproject.org>
Date: Sat Nov 7 17:18:53 2020 -0800
Remove Query's Synchronous usage
First step to remove our asyncio metaprogramming...
https://github.com/torproject/stem/issues/77
Our Query class now provides a run() method for synchronous callers and a
run_async() method for asyncio callers. This also adds a stop() method that
aborts an in-progress download and cleans up its underlying resources.
---
stem/descriptor/remote.py | 114 ++++++++++++++++++++++++++++++++---------
test/unit/descriptor/remote.py | 53 ++++++++++++++-----
2 files changed, 130 insertions(+), 37 deletions(-)
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index 136b9d15..50b3065c 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -100,8 +100,7 @@ import stem.util.tor_tools
from stem.descriptor import Compression
from stem.util import log, str_tools
-from stem.util.asyncio import Synchronous
-from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple, Union
# Tor has a limited number of descriptors we can fetch explicitly by their
# fingerprint or hashes due to a limit on the url length by squid proxies.
@@ -227,7 +226,7 @@ def get_detached_signatures(**query_args: Any) -> 'stem.descriptor.remote.Query'
return get_instance().get_detached_signatures(**query_args)
-class Query(Synchronous):
+class Query(object):
"""
Asynchronous request for descriptor content from a directory authority or
mirror. These can either be made through the
@@ -369,7 +368,6 @@ class Query(Synchronous):
super(Query, self).__init__()
if not resource.startswith('/'):
- self.stop()
raise ValueError("Resources should start with a '/': %s" % resource)
if resource.endswith('.z'):
@@ -380,7 +378,6 @@ class Query(Synchronous):
elif isinstance(compression, stem.descriptor._Compression):
compression = [compression] # caller provided only a single option
else:
- self.stop()
raise ValueError('Compression should be a list of stem.descriptor.Compression, was %s (%s)' % (compression, type(compression).__name__))
if Compression.ZSTD in compression and not Compression.ZSTD.available:
@@ -404,7 +401,6 @@ class Query(Synchronous):
if isinstance(endpoint, (stem.ORPort, stem.DirPort)):
self.endpoints.append(endpoint)
else:
- self.stop()
raise ValueError("Endpoints must be an stem.ORPort or stem.DirPort. '%s' is a %s." % (endpoint, type(endpoint).__name__))
self.resource = resource
@@ -428,6 +424,12 @@ class Query(Synchronous):
self._downloader_task = None # type: Optional[asyncio.Task]
self._downloader_lock = threading.RLock()
+ # background thread if outside an asyncio context
+
+ self._loop = None # type: Optional[asyncio.AbstractEventLoop]
+ self._loop_thread = None # type: Optional[threading.Thread]
+ self._loop_lock = threading.RLock()
+
if start:
self.start()
@@ -441,9 +443,38 @@ class Query(Synchronous):
with self._downloader_lock:
if self._downloader_task is None:
- self._downloader_task = self._loop.create_task(Query._download_descriptors(self, self.retries, self.timeout))
+ with self._loop_lock:
+ if self._loop is None:
+ try:
+ self._loop = asyncio.get_running_loop()
+ except RuntimeError:
+ self._loop = asyncio.new_event_loop()
+ self._loop_thread = threading.Thread(
+ name = 'stem.descriptor.remote query',
+ target = self._loop.run_forever,
+ daemon = True,
+ )
+
+ self._loop_thread.start()
+
+ self._downloader_task = self._loop.create_task(self._download_descriptors(self.retries, self.timeout))
+
+ def stop(self) -> None:
+ """
+ Aborts our download if it's in progress, and cleans up underlying
+ resources.
+ """
- async def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
+ with self._downloader_lock:
+ if self._downloader_task and not self._downloader_task.done():
+ self._downloader_task.cancel()
+
+ with self._loop_lock:
+ if self._loop_thread and self._loop_thread.is_alive():
+ self._loop.call_soon_threadsafe(self._loop.stop)
+ self._loop_thread.join()
+
+ def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
"""
Blocks until our request is complete then provides the descriptors. If we
haven't yet started our request then this does so.
@@ -461,12 +492,43 @@ class Query(Synchronous):
* :class:`~stem.DownloadFailed` if our request fails
"""
- try:
- return [desc async for desc in self._run(suppress)]
- finally:
- self.stop()
+ if not self.downloaded and not self.error:
+ with self._loop_lock:
+ if self._loop is None:
+ self.start()
+
+ async def run_wrapper():
+ return [desc async for desc in self.run_async(suppress = True)]
+
+ asyncio.run_coroutine_threadsafe(run_wrapper(), self._loop).result()
+
+ self.stop()
+
+ if self.error:
+ if suppress:
+ return []
+
+ raise self.error
+ else:
+ return list(self.downloaded)
+
+ async def run_async(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]:
+ """
+ Asynchronous counterpart of :func:`stem.descriptor.remote.Query.run`
+
+ :param suppress: avoids raising exceptions if **True**
+
+ :returns: iterator for the requested :class:`~stem.descriptor.__init__.Descriptor` instances
+
+ :raises:
+ Using the iterator can fail with the following if **suppress** is
+ **False**...
+
+ * **ValueError** if the descriptor contents are malformed
+ * :class:`~stem.DownloadTimeout` if our request timed out
+ * :class:`~stem.DownloadFailed` if our request fails
+ """
- async def _run(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]:
with self._downloader_lock:
if not self.downloaded and not self.error:
if not self._downloader_task:
@@ -477,17 +539,21 @@ class Query(Synchronous):
except Exception as exc:
self.error = exc
- if self.error:
- if suppress:
- return
+ if self.error:
+ if suppress:
+ return
- raise self.error
- else:
- for desc in self.downloaded:
- yield desc
+ raise self.error
+ else:
+ for desc in self.downloaded:
+ yield desc
+
+ def __iter__(self) -> Iterator[stem.descriptor.Descriptor]:
+ for desc in self.run(True):
+ yield desc
async def __aiter__(self) -> AsyncIterator[stem.descriptor.Descriptor]:
- async for desc in self._run(True):
+ async for desc in self.run_async(True):
yield desc
def _pick_endpoint(self, use_authority: bool = False) -> stem.Endpoint:
@@ -620,7 +686,7 @@ class DescriptorDownloader(object):
directories = [auth for auth in stem.directory.Authority.from_cache().values() if auth.nickname not in DIR_PORT_BLACKLIST]
new_endpoints = set([stem.DirPort(directory.address, directory.dir_port) for directory in directories])
- consensus = list(self.get_consensus(document_handler = stem.descriptor.DocumentHandler.DOCUMENT).run())[0] # type: ignore
+ consensus = list(self.get_consensus(document_handler = stem.descriptor.DocumentHandler.DOCUMENT).run())[0]
for desc in consensus.routers.values():
if stem.Flag.V2DIR in desc.flags and desc.dir_port:
@@ -630,7 +696,7 @@ class DescriptorDownloader(object):
self._endpoints = list(new_endpoints)
- return consensus
+ return consensus # type: ignore
def their_server_descriptor(self, **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
@@ -785,7 +851,7 @@ class DescriptorDownloader(object):
# authority key certificates
if consensus_query.validate and consensus_query.document_handler == stem.descriptor.DocumentHandler.DOCUMENT:
- consensus = list(consensus_query.run())[0] # type: ignore
+ consensus = list(consensus_query.run())[0]
key_certs = self.get_key_certificates(**query_args).run()
try:
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index 58c7276a..8635d6bd 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -2,6 +2,7 @@
Unit tests for stem.descriptor.remote.
"""
+import time
import unittest
import stem
@@ -87,12 +88,50 @@ class TestDescriptorDownloader(unittest.TestCase):
query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False)
self.assertTrue(query._downloader_task is None)
- query.stop()
query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = True)
self.assertTrue(query._downloader_task is not None)
query.stop()
+ def test_stop(self):
+ """
+ Stop a complete, in-process, and unstarted query.
+ """
+
+ # stop a completed query
+
+ with mock_download(TEST_DESCRIPTOR):
+ query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31')
+ self.assertTrue(query._loop_thread.is_alive())
+
+ query.run() # complete the query
+ self.assertFalse(query._loop_thread.is_alive())
+ self.assertFalse(query._downloader_task.cancelled())
+
+ query.stop() # nothing to do
+ self.assertFalse(query._loop_thread.is_alive())
+ self.assertFalse(query._downloader_task.cancelled())
+
+ # stop an in-process query
+
+ def pause(*args):
+ time.sleep(5)
+
+ with patch('stem.descriptor.remote.Query._download_from', Mock(side_effect = pause)):
+ query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31')
+
+ query.stop() # terminates in-process query
+ self.assertFalse(query._loop_thread.is_alive())
+ self.assertTrue(query._downloader_task.cancelled())
+
+ # stop an unstarted query
+
+ query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False)
+
+ query.stop() # nothing to do
+ self.assertTrue(query._loop_thread is None)
+ self.assertTrue(query._downloader_task is None)
+
@mock_download(TEST_DESCRIPTOR)
def test_download(self):
"""
@@ -115,8 +154,6 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertEqual('9695DFC35FFEB861329B9F1AB04C46397020CE31', desc.fingerprint)
self.assertEqual(TEST_DESCRIPTOR.rstrip(), desc.get_bytes())
- reply.stop()
-
def test_response_header_code(self):
"""
When successful Tor provides a '200 OK' status, but we should accept other 2xx
@@ -165,13 +202,11 @@ class TestDescriptorDownloader(unittest.TestCase):
descriptors = list(query)
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
- query.stop()
def test_gzip_url_override(self):
query = stem.descriptor.remote.Query(TEST_RESOURCE + '.z', compression = Compression.PLAINTEXT, start = False)
self.assertEqual([stem.descriptor.Compression.GZIP], query.compression)
self.assertEqual(TEST_RESOURCE, query.resource)
- query.stop()
@mock_download(read_resource('compressed_identity'), encoding = 'identity')
def test_compression_plaintext(self):
@@ -187,7 +222,6 @@ class TestDescriptorDownloader(unittest.TestCase):
)
descriptors = list(query)
- query.stop()
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
@@ -206,7 +240,6 @@ class TestDescriptorDownloader(unittest.TestCase):
)
descriptors = list(query)
- query.stop()
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
@@ -227,7 +260,6 @@ class TestDescriptorDownloader(unittest.TestCase):
)
descriptors = list(query)
- query.stop()
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
@@ -248,7 +280,6 @@ class TestDescriptorDownloader(unittest.TestCase):
)
descriptors = list(query)
- query.stop()
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
@@ -300,8 +331,6 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertRaises(ValueError, query.run)
- query.stop()
-
def test_query_with_invalid_endpoints(self):
invalid_endpoints = {
'hello': "'h' is a str.",
@@ -330,5 +359,3 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertEqual(1, len(list(query)))
self.assertEqual(1, len(list(query)))
self.assertEqual(1, len(list(query)))
-
- query.stop()
More information about the tor-commits
mailing list