[tor-commits] [stem/master] Make it possible to use a function to connect to the async controller
atagar at torproject.org
atagar at torproject.org
Thu Jul 16 01:28:59 UTC 2020
commit 459612e63181218d79d2a42ab5b0eebd0cb206bf
Author: Illia Volochii <illia.volochii at gmail.com>
Date: Tue Apr 28 23:54:52 2020 +0300
Make it possible to use a function to connect to the async controller
---
stem/connection.py | 76 ++++++++++++++++++++++++-----------------
stem/control.py | 11 ++++--
test/unit/connection/connect.py | 48 ++++++++++++--------------
3 files changed, 74 insertions(+), 61 deletions(-)
diff --git a/stem/connection.py b/stem/connection.py
index 12330ca3..3d240070 100644
--- a/stem/connection.py
+++ b/stem/connection.py
@@ -133,6 +133,7 @@ import getpass
import hashlib
import hmac
import os
+import threading
import stem.control
import stem.response
@@ -253,6 +254,31 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
# TODO: change this function's API so we can provide a concrete type
+  if controller is None or not issubclass(controller, stem.control.Controller):
+    raise ValueError('Controller should be a stem.control.Controller subclass.')
+
+  async_controller_thread = stem.control._AsyncControllerThread()
+  async_controller_thread.start()
+  # Only a returned controller keeps this thread alive; otherwise it's joined below.
+  connect_coroutine = _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller)
+  try:
+    connection = asyncio.run_coroutine_threadsafe(connect_coroutine, async_controller_thread.loop).result()
+  except BaseException:
+    if async_controller_thread.is_alive():
+      async_controller_thread.join()
+    raise
+  if connection is None and async_controller_thread.is_alive():
+    async_controller_thread.join()
+  return connection
+
+
+async def connect_async(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: Optional[Type[stem.control.BaseController]] = stem.control.AsyncController) -> Any:
+  if controller and not issubclass(controller, stem.control.BaseController):
+    raise ValueError('The provided controller should be a stem.control.BaseController subclass.')
+  return await _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller)
+
+
+async def _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller):
if control_port is None and control_socket is None:
raise ValueError('Neither a control port nor control socket were provided. Nothing to connect to.')
elif control_port:
@@ -266,17 +292,11 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
control_connection = None # type: Optional[stem.socket.ControlSocket]
error_msg = ''
- async_controller_thread = stem.control._AsyncControllerThread()
- async_controller_thread.start()
-
- def connect_socket(socket):
- asyncio.run_coroutine_threadsafe(socket.connect(), async_controller_thread.loop).result()
-
if control_socket:
if os.path.exists(control_socket):
try:
control_connection = stem.socket.ControlSocketFile(control_socket)
- connect_socket(control_connection)
+ await control_connection.connect()
except stem.SocketError as exc:
error_msg = CONNECT_MESSAGES['unable_to_use_socket'].format(path = control_socket, error = exc)
else:
@@ -290,7 +310,7 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
control_connection = _connection_for_default_port(address)
else:
control_connection = stem.socket.ControlPort(address, int(port))
- connect_socket(control_connection)
+ await control_connection.connect()
except stem.SocketError as exc:
error_msg = CONNECT_MESSAGES['unable_to_use_port'].format(address = address, port = port, error = exc)
@@ -304,14 +324,12 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
error_msg = CONNECT_MESSAGES['no_control_port'] if is_tor_running else CONNECT_MESSAGES['tor_isnt_running']
print(error_msg)
- if async_controller_thread.is_alive():
- async_controller_thread.join()
return None
- return _connect_auth(control_connection, password, password_prompt, chroot_path, controller, async_controller_thread)
+ return await _connect_auth(control_connection, password, password_prompt, chroot_path, controller)
-def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[stem.control.BaseController]], async_controller_thread: 'threading.Thread') -> Any:
+async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[stem.control.BaseController]]) -> Any:
"""
Helper for the connect_* functions that authenticates the socket and
constructs the controller.
@@ -327,61 +345,55 @@ def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, pass
:returns: authenticated control connection, the type based on the controller argument
"""
- def run_coroutine(coroutine):
- asyncio.run_coroutine_threadsafe(coroutine, async_controller_thread.loop).result()
-
- def close_control_socket():
- run_coroutine(control_socket.close())
- if async_controller_thread.is_alive():
- async_controller_thread.join()
-
try:
- run_coroutine(authenticate(control_socket, password, chroot_path))
+ await authenticate(control_socket, password, chroot_path)
if controller is None:
return control_socket
- else:
- return controller(control_socket, is_authenticated = True, started_async_controller_thread = async_controller_thread)
+ elif issubclass(controller, stem.control.BaseController):
+ return controller(control_socket, is_authenticated = True)
+ elif issubclass(controller, stem.control.Controller):
+ return controller(control_socket, is_authenticated = True, started_async_controller_thread = threading.current_thread())
except IncorrectSocketType:
if isinstance(control_socket, stem.socket.ControlPort):
print(CONNECT_MESSAGES['wrong_port_type'].format(port = control_socket.port))
else:
print(CONNECT_MESSAGES['wrong_socket_type'])
- close_control_socket()
+ await control_socket.close()
return None
except UnrecognizedAuthMethods as exc:
print(CONNECT_MESSAGES['uncrcognized_auth_type'].format(auth_methods = ', '.join(exc.unknown_auth_methods)))
- close_control_socket()
+ await control_socket.close()
return None
except IncorrectPassword:
print(CONNECT_MESSAGES['incorrect_password'])
- close_control_socket()
+ await control_socket.close()
return None
except MissingPassword:
if password is not None:
- close_control_socket()
+ await control_socket.close()
raise ValueError(CONNECT_MESSAGES['missing_password_bug'])
if password_prompt:
try:
password = getpass.getpass(CONNECT_MESSAGES['password_prompt'] + ' ')
except KeyboardInterrupt:
- close_control_socket()
+ await control_socket.close()
return None
- return _connect_auth(control_socket, password, password_prompt, chroot_path, controller, async_controller_thread)
+ return await _connect_auth(control_socket, password, password_prompt, chroot_path, controller)
else:
print(CONNECT_MESSAGES['needs_password'])
- close_control_socket()
+ await control_socket.close()
return None
except UnreadableCookieFile as exc:
print(CONNECT_MESSAGES['unreadable_cookie_file'].format(path = exc.cookie_path, issue = str(exc)))
- close_control_socket()
+ await control_socket.close()
return None
except AuthenticationFailure as exc:
print(CONNECT_MESSAGES['general_auth_failure'].format(error = exc))
- close_control_socket()
+ await control_socket.close()
return None
diff --git a/stem/control.py b/stem/control.py
index 21a89a5a..b2d2d9d7 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -3946,10 +3946,15 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
self._socket = self._async_controller._socket
def _init_async_controller(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool) -> 'stem.control.AsyncController':
- async def init_async_controller():
- return AsyncController(control_socket, is_authenticated)
+ # The asynchronous controller should be initialized in the thread where its
+ # methods will be executed.
+ if self._async_controller_thread != threading.current_thread():
+ async def init_async_controller() -> 'stem.control.AsyncController':
+ return AsyncController(control_socket, is_authenticated)
- return asyncio.run_coroutine_threadsafe(init_async_controller(), self._asyncio_loop).result()
+ return asyncio.run_coroutine_threadsafe(init_async_controller(), self._asyncio_loop).result()
+
+ return AsyncController(control_socket, is_authenticated)
def _execute_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
return asyncio.run_coroutine_threadsafe(
diff --git a/test/unit/connection/connect.py b/test/unit/connection/connect.py
index 8ba2770c..2112f678 100644
--- a/test/unit/connection/connect.py
+++ b/test/unit/connection/connect.py
@@ -1,7 +1,7 @@
"""
Unit tests for the stem.connection.connect function.
"""
-import contextlib
+
import io
import unittest
@@ -11,7 +11,11 @@ import stem.socket
from unittest.mock import Mock, patch
-from test.unit.async_util import coro_func_raising_exc, coro_func_returning_value
+from test.unit.async_util import (
+ async_test,
+ coro_func_raising_exc,
+ coro_func_returning_value,
+)
class TestConnect(unittest.TestCase):
@@ -20,7 +24,7 @@ class TestConnect(unittest.TestCase):
@patch('os.path.exists', Mock(return_value = True))
@patch('stem.socket.ControlSocketFile', Mock(side_effect = stem.SocketError('failed')))
@patch('stem.socket.ControlPort', Mock(side_effect = stem.SocketError('failed')))
- @patch('stem.connection._connect_auth', Mock())
+ @patch('stem.connection._connect_auth', Mock(side_effect = coro_func_returning_value(None)))
def test_failue_with_the_default_endpoint(self, is_running_mock, stdout_mock):
is_running_mock.return_value = False
self._assert_connect_fails_with({}, stdout_mock, "Unable to connect to tor. Are you sure it's running?")
@@ -33,7 +37,7 @@ class TestConnect(unittest.TestCase):
@patch('stem.util.system.is_running', Mock(return_value = True))
@patch('stem.socket.ControlSocketFile', Mock(side_effect = stem.SocketError('failed')))
@patch('stem.socket.ControlPort', Mock(side_effect = stem.SocketError('failed')))
- @patch('stem.connection._connect_auth', Mock())
+ @patch('stem.connection._connect_auth', Mock(side_effect = coro_func_returning_value(None)))
def test_failure_with_a_custom_endpoint(self, path_exists_mock, stdout_mock):
path_exists_mock.return_value = True
self._assert_connect_fails_with({'control_port': ('127.0.0.1', 80), 'control_socket': None}, stdout_mock, "Unable to connect to 127.0.0.1:80: failed")
@@ -45,7 +49,7 @@ class TestConnect(unittest.TestCase):
@patch('stem.socket.ControlPort')
@patch('os.path.exists', Mock(return_value = False))
- @patch('stem.connection._connect_auth', Mock())
+ @patch('stem.connection._connect_auth', Mock(side_effect = coro_func_returning_value(None)))
def test_getting_a_control_port(self, port_mock):
port_connect_mock = port_mock.return_value.connect
port_connect_mock.side_effect = coro_func_returning_value(None)
@@ -59,7 +63,7 @@ class TestConnect(unittest.TestCase):
@patch('stem.socket.ControlSocketFile')
@patch('os.path.exists', Mock(return_value = True))
- @patch('stem.connection._connect_auth', Mock())
+ @patch('stem.connection._connect_auth', Mock(side_effect = coro_func_returning_value(None)))
def test_getting_a_control_socket(self, socket_mock):
socket_connect_mock = socket_mock.return_value.connect
socket_connect_mock.side_effect = coro_func_returning_value(None)
@@ -92,21 +96,22 @@ class TestConnect(unittest.TestCase):
self.assertEqual(msg, stdout_output.strip().lstrip('\x00'))
@patch('stem.connection.authenticate')
- def test_auth_success(self, authenticate_mock):
+ @async_test
+ async def test_auth_success(self, authenticate_mock):
authenticate_mock.side_effect = coro_func_returning_value(None)
control_socket = Mock()
- with self._get_thread() as thread:
- stem.connection._connect_auth(control_socket, None, False, None, None, thread)
- authenticate_mock.assert_called_with(control_socket, None, None)
- authenticate_mock.reset_mock()
+ await stem.connection._connect_auth(control_socket, None, False, None, None)
+ authenticate_mock.assert_called_with(control_socket, None, None)
+ authenticate_mock.reset_mock()
- stem.connection._connect_auth(control_socket, 's3krit!!!', False, '/my/chroot', None, thread)
+ await stem.connection._connect_auth(control_socket, 's3krit!!!', False, '/my/chroot', None)
authenticate_mock.assert_called_with(control_socket, 's3krit!!!', '/my/chroot')
@patch('getpass.getpass')
@patch('stem.connection.authenticate')
- def test_auth_success_with_password_prompt(self, authenticate_mock, getpass_mock):
+ @async_test
+ async def test_auth_success_with_password_prompt(self, authenticate_mock, getpass_mock):
control_socket = Mock()
async def authenticate_mock_func(controller, password, *args):
@@ -120,8 +125,7 @@ class TestConnect(unittest.TestCase):
authenticate_mock.side_effect = authenticate_mock_func
getpass_mock.return_value = 'my_password'
- with self._get_thread() as thread:
- stem.connection._connect_auth(control_socket, None, True, None, None, thread)
+ await stem.connection._connect_auth(control_socket, None, True, None, None)
authenticate_mock.assert_any_call(control_socket, None, None)
authenticate_mock.assert_any_call(control_socket, 'my_password', None)
@@ -149,9 +153,9 @@ class TestConnect(unittest.TestCase):
authenticate_mock.side_effect = coro_func_raising_exc(stem.connection.OpenAuthRejected('crazy failure'))
self._assert_authenticate_fails_with(control_socket, stdout_mock, 'Unable to authenticate: crazy failure')
- def _assert_authenticate_fails_with(self, control_socket, stdout_mock, msg):
- with self._get_thread() as thread:
- result = stem.connection._connect_auth(control_socket, None, False, None, None, thread)
+ @async_test
+ async def _assert_authenticate_fails_with(self, control_socket, stdout_mock, msg):
+ result = await stem.connection._connect_auth(control_socket, None, False, None, None)
if result is not None:
self.fail() # _connect_auth() was successful
@@ -161,11 +165,3 @@ class TestConnect(unittest.TestCase):
if msg not in stdout_output:
self.fail("Expected...\n\n%s\n\n... which couldn't be found in...\n\n%s" % (msg, stdout_output))
-
- @contextlib.contextmanager
- def _get_thread(self):
- thread = stem.control._AsyncControllerThread()
- thread.start()
- yield thread
- if thread.is_alive():
- thread.join()
More information about the tor-commits mailing list