[tor-commits] [stem/master] Fix unit tests

atagar at torproject.org atagar at torproject.org
Thu Jul 16 01:28:59 UTC 2020


commit 1db0e6b84e870a5f228f3a770daca542bdef5d4e
Author: Illia Volochii <illia.volochii at gmail.com>
Date:   Sun Apr 26 22:31:12 2020 +0300

    Fix unit tests
---
 test/unit/connection/authentication.py |  36 +++--
 test/unit/connection/connect.py        |  19 +--
 test/unit/control/controller.py        | 254 +++++++++++++++++++--------------
 test/unit/response/control_message.py  |  10 +-
 4 files changed, 188 insertions(+), 131 deletions(-)

diff --git a/test/unit/connection/authentication.py b/test/unit/connection/authentication.py
index f6241e0e..596fa50c 100644
--- a/test/unit/connection/authentication.py
+++ b/test/unit/connection/authentication.py
@@ -14,41 +14,52 @@ import unittest
 import stem.connection
 import test
 
-from unittest.mock import Mock, patch
+from unittest.mock import patch
 
 from stem.response import ControlMessage
 from stem.util import log
+from test.unit.util.asynchronous import (
+  async_test,
+  coro_func_raising_exc,
+  coro_func_returning_value,
+)
 
 
 class TestAuthenticate(unittest.TestCase):
   @patch('stem.connection.get_protocolinfo')
-  @patch('stem.connection.authenticate_none', Mock())
-  def test_with_get_protocolinfo(self, get_protocolinfo_mock):
+  @patch('stem.connection.authenticate_none')
+  @async_test
+  async def test_with_get_protocolinfo(self, authenticate_none_mock, get_protocolinfo_mock):
     """
     Tests the authenticate() function when it needs to make a get_protocolinfo.
     """
 
     # tests where get_protocolinfo succeeds
 
+    authenticate_none_mock.side_effect = coro_func_returning_value(None)
+
     protocolinfo_message = ControlMessage.from_str('250-PROTOCOLINFO 1\r\n250 OK\r\n', 'PROTOCOLINFO')
     protocolinfo_message.auth_methods = (stem.connection.AuthMethod.NONE, )
-    get_protocolinfo_mock.return_value = protocolinfo_message
+    get_protocolinfo_mock.side_effect = coro_func_returning_value(protocolinfo_message)
 
-    stem.connection.authenticate(None)
+    await stem.connection.authenticate(None)
 
     # tests where get_protocolinfo raises an exception
 
     get_protocolinfo_mock.side_effect = stem.ProtocolError
-    self.assertRaises(stem.connection.IncorrectSocketType, stem.connection.authenticate, None)
+    with self.assertRaises(stem.connection.IncorrectSocketType):
+      await stem.connection.authenticate(None)
 
     get_protocolinfo_mock.side_effect = stem.SocketError
-    self.assertRaises(stem.connection.AuthenticationFailure, stem.connection.authenticate, None)
+    with self.assertRaises(stem.connection.AuthenticationFailure):
+      await stem.connection.authenticate(None)
 
   @patch('stem.connection.authenticate_none')
   @patch('stem.connection.authenticate_password')
   @patch('stem.connection.authenticate_cookie')
   @patch('stem.connection.authenticate_safecookie')
-  def test_all_use_cases(self, authenticate_safecookie_mock, authenticate_cookie_mock, authenticate_password_mock, authenticate_none_mock):
+  @async_test
+  async def test_all_use_cases(self, authenticate_safecookie_mock, authenticate_cookie_mock, authenticate_password_mock, authenticate_none_mock):
     """
     Does basic validation that all valid use cases for the PROTOCOLINFO input
     and dependent functions result in either success or a AuthenticationFailed
@@ -133,15 +144,16 @@ class TestAuthenticate(unittest.TestCase):
                 auth_mock, raised_exc = authenticate_safecookie_mock, auth_cookie_exc
 
               if raised_exc:
-                auth_mock.side_effect = raised_exc
+                auth_mock.side_effect = coro_func_raising_exc(raised_exc)
               else:
-                auth_mock.side_effect = None
+                auth_mock.side_effect = coro_func_returning_value(None)
                 expect_success = True
 
             if expect_success:
-              stem.connection.authenticate(None, 'blah', None, protocolinfo)
+              await stem.connection.authenticate(None, 'blah', None, protocolinfo)
             else:
-              self.assertRaises(stem.connection.AuthenticationFailure, stem.connection.authenticate, None, 'blah', None, protocolinfo)
+              with self.assertRaises(stem.connection.AuthenticationFailure):
+                await stem.connection.authenticate(None, 'blah', None, protocolinfo)
 
     # revert logging back to normal
     stem_logger.setLevel(log.logging_level(log.TRACE))
diff --git a/test/unit/connection/connect.py b/test/unit/connection/connect.py
index 175a1ebd..d2a22f18 100644
--- a/test/unit/connection/connect.py
+++ b/test/unit/connection/connect.py
@@ -11,6 +11,8 @@ import stem.socket
 
 from unittest.mock import Mock, patch
 
+from test.unit.util.asynchronous import coro_func_raising_exc, coro_func_returning_value
+
 
 class TestConnect(unittest.TestCase):
   @patch('sys.stdout', new_callable = io.StringIO)
@@ -85,6 +87,7 @@ class TestConnect(unittest.TestCase):
 
   @patch('stem.connection.authenticate')
   def test_auth_success(self, authenticate_mock):
+    authenticate_mock.side_effect = coro_func_returning_value(None)
     control_socket = Mock()
 
     stem.connection._connect_auth(control_socket, None, False, None, None)
@@ -99,7 +102,7 @@ class TestConnect(unittest.TestCase):
   def test_auth_success_with_password_prompt(self, authenticate_mock, getpass_mock):
     control_socket = Mock()
 
-    def authenticate_mock_func(controller, password, *args):
+    async def authenticate_mock_func(controller, password, *args):
       if password is None:
         raise stem.connection.MissingPassword('no password')
       elif password == 'my_password':
@@ -117,25 +120,25 @@ class TestConnect(unittest.TestCase):
   @patch('sys.stdout', new_callable = io.StringIO)
   @patch('stem.connection.authenticate')
   def test_auth_failure(self, authenticate_mock, stdout_mock):
-    control_socket = stem.socket.ControlPort(connect = False)
+    control_socket = stem.socket.ControlPort()
 
-    authenticate_mock.side_effect = stem.connection.IncorrectSocketType('unable to connect to socket')
+    authenticate_mock.side_effect = coro_func_raising_exc(stem.connection.IncorrectSocketType('unable to connect to socket'))
     self._assert_authenticate_fails_with(control_socket, stdout_mock, 'Please check in your torrc that 9051 is the ControlPort.')
 
-    control_socket = stem.socket.ControlSocketFile(connect = False)
+    control_socket = stem.socket.ControlSocketFile()
 
     self._assert_authenticate_fails_with(control_socket, stdout_mock, 'Are you sure the interface you specified belongs to')
 
-    authenticate_mock.side_effect = stem.connection.UnrecognizedAuthMethods('unable to connect', ['telepathy'])
+    authenticate_mock.side_effect = coro_func_raising_exc(stem.connection.UnrecognizedAuthMethods('unable to connect', ['telepathy']))
     self._assert_authenticate_fails_with(control_socket, stdout_mock, 'Tor is using a type of authentication we do not recognize...\n\n  telepathy')
 
-    authenticate_mock.side_effect = stem.connection.IncorrectPassword('password rejected')
+    authenticate_mock.side_effect = coro_func_raising_exc(stem.connection.IncorrectPassword('password rejected'))
     self._assert_authenticate_fails_with(control_socket, stdout_mock, 'Incorrect password')
 
-    authenticate_mock.side_effect = stem.connection.UnreadableCookieFile('permission denied', '/tmp/my_cookie', False)
+    authenticate_mock.side_effect = coro_func_raising_exc(stem.connection.UnreadableCookieFile('permission denied', '/tmp/my_cookie', False))
     self._assert_authenticate_fails_with(control_socket, stdout_mock, "We were unable to read tor's authentication cookie...\n\n  Path: /tmp/my_cookie\n  Issue: permission denied")
 
-    authenticate_mock.side_effect = stem.connection.OpenAuthRejected('crazy failure')
+    authenticate_mock.side_effect = coro_func_raising_exc(stem.connection.OpenAuthRejected('crazy failure'))
     self._assert_authenticate_fails_with(control_socket, stdout_mock, 'Unable to authenticate: crazy failure')
 
   def _assert_authenticate_fails_with(self, control_socket, stdout_mock, msg):
diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py
index c0a07e2a..d09b5ca8 100644
--- a/test/unit/control/controller.py
+++ b/test/unit/control/controller.py
@@ -3,6 +3,7 @@ Unit tests for the stem.control module. The module's primarily exercised via
 integ tests, but a few bits lend themselves to unit testing.
 """
 
+import asyncio
 import datetime
 import io
 import unittest
@@ -20,6 +21,11 @@ from stem import ControllerError, DescriptorUnavailable, InvalidArguments, Inval
 from stem.control import MALFORMED_EVENTS, _parse_circ_path, Listener, Controller, EventType
 from stem.response import ControlMessage
 from stem.exit_policy import ExitPolicy
+from test.unit.util.asynchronous import (
+  async_test,
+  coro_func_raising_exc,
+  coro_func_returning_value,
+)
 
 NS_DESC = 'r %s %s u5lTXJKGsLKufRLnSyVqT7TdGYw 2012-12-30 22:02:49 77.223.43.54 9001 0\ns Fast Named Running Stable Valid\nw Bandwidth=75'
 TEST_TIMESTAMP = 12345
@@ -36,8 +42,9 @@ class TestControl(unittest.TestCase):
     # When initially constructing a controller we need to suppress msg, so our
     # constructor's SETEVENTS requests pass.
 
-    with patch('stem.control.BaseController.msg', Mock()):
+    with patch('stem.control.BaseController.msg', Mock(side_effect = coro_func_returning_value(None))):
       self.controller = Controller(socket)
+      self.async_controller = self.controller._async_controller
 
       self.circ_listener = Mock()
       self.controller.add_event_listener(self.circ_listener, EventType.CIRC)
@@ -59,18 +66,24 @@ class TestControl(unittest.TestCase):
     for event in stem.control.EventType:
       self.assertTrue(stem.control.event_description(event) is not None)
 
-  @patch('stem.control.Controller.msg')
+  @patch('stem.control.AsyncController.msg')
   def test_get_info(self, msg_mock):
-    msg_mock.return_value = ControlMessage.from_str('250-hello=hi right back!\r\n250 OK\r\n', 'GETINFO')
+    message = ControlMessage.from_str('250-hello=hi right back!\r\n250 OK\r\n', 'GETINFO')
+    msg_mock.side_effect = coro_func_returning_value(message)
     self.assertEqual('hi right back!', self.controller.get_info('hello'))
 
-  @patch('stem.control.Controller.msg')
-  def test_get_info_address_caching(self, msg_mock):
-    msg_mock.return_value = ControlMessage.from_str('551 Address unknown\r\n')
+  @patch('stem.control.AsyncController.msg')
+  @async_test
+  async def test_get_info_address_caching(self, msg_mock):
+    def set_message(*args):
+      message = ControlMessage.from_str(*args)
+      msg_mock.side_effect = coro_func_returning_value(message)
 
-    self.assertEqual(None, self.controller._last_address_exc)
+    set_message('551 Address unknown\r\n')
+
+    self.assertEqual(None, self.async_controller._last_address_exc)
     self.assertRaisesWith(stem.OperationFailed, 'Address unknown', self.controller.get_info, 'address')
-    self.assertEqual('Address unknown', str(self.controller._last_address_exc))
+    self.assertEqual('Address unknown', str(self.async_controller._last_address_exc))
     self.assertEqual(1, msg_mock.call_count)
 
     # now that we have a cached failure we should provide that back
@@ -80,27 +93,28 @@ class TestControl(unittest.TestCase):
 
     # invalidates the cache, transitioning from no address to having one
 
-    msg_mock.return_value = ControlMessage.from_str('250-address=17.2.89.80\r\n250 OK\r\n', 'GETINFO')
+    set_message('250-address=17.2.89.80\r\n250 OK\r\n', 'GETINFO')
     self.assertRaisesWith(stem.OperationFailed, 'Address unknown', self.controller.get_info, 'address')
-    self.controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=17.2.89.80 METHOD=DIRSERV\r\n'))
+    await self.async_controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=17.2.89.80 METHOD=DIRSERV\r\n'))
     self.assertEqual('17.2.89.80', self.controller.get_info('address'))
 
     # invalidates the cache, transitioning from one address to another
 
-    msg_mock.return_value = ControlMessage.from_str('250-address=80.89.2.17\r\n250 OK\r\n', 'GETINFO')
+    set_message('250-address=80.89.2.17\r\n250 OK\r\n', 'GETINFO')
     self.assertEqual('17.2.89.80', self.controller.get_info('address'))
-    self.controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=80.89.2.17 METHOD=DIRSERV\r\n'))
+    await self.async_controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=80.89.2.17 METHOD=DIRSERV\r\n'))
     self.assertEqual('80.89.2.17', self.controller.get_info('address'))
 
-  @patch('stem.control.Controller.msg')
-  @patch('stem.control.Controller.get_conf')
+  @patch('stem.control.AsyncController.msg')
+  @patch('stem.control.AsyncController.get_conf')
   def test_get_info_without_fingerprint(self, get_conf_mock, msg_mock):
-    msg_mock.return_value = ControlMessage.from_str('551 Not running in server mode\r\n')
+    message = ControlMessage.from_str('551 Not running in server mode\r\n')
+    msg_mock.side_effect = coro_func_returning_value(message)
     get_conf_mock.return_value = None
 
-    self.assertEqual(None, self.controller._last_fingerprint_exc)
+    self.assertEqual(None, self.async_controller._last_fingerprint_exc)
     self.assertRaisesWith(stem.OperationFailed, 'Not running in server mode', self.controller.get_info, 'fingerprint')
-    self.assertEqual('Not running in server mode', str(self.controller._last_fingerprint_exc))
+    self.assertEqual('Not running in server mode', str(self.async_controller._last_fingerprint_exc))
     self.assertEqual(1, msg_mock.call_count)
 
     # now that we have a cached failure we should provide that back
@@ -114,7 +128,7 @@ class TestControl(unittest.TestCase):
     self.assertRaisesWith(stem.OperationFailed, 'Not running in server mode', self.controller.get_info, 'fingerprint')
     self.assertEqual(2, msg_mock.call_count)
 
-  @patch('stem.control.Controller.get_info')
+  @patch('stem.control.AsyncController.get_info')
   def test_get_version(self, get_info_mock):
     """
     Exercises the get_version() method.
@@ -124,7 +138,7 @@ class TestControl(unittest.TestCase):
       # Use one version for first check.
       version_2_1 = '0.2.1.32'
       version_2_1_object = stem.version.Version(version_2_1)
-      get_info_mock.return_value = version_2_1
+      get_info_mock.side_effect = coro_func_returning_value(version_2_1)
 
       # Return a version with a cold cache.
       self.assertEqual(version_2_1_object, self.controller.get_version())
@@ -132,23 +146,23 @@ class TestControl(unittest.TestCase):
       # Use a different version for second check.
       version_2_2 = '0.2.2.39'
       version_2_2_object = stem.version.Version(version_2_2)
-      get_info_mock.return_value = version_2_2
+      get_info_mock.side_effect = coro_func_returning_value(version_2_2)
 
       # Return a version with a hot cache, so it will be the old version.
       self.assertEqual(version_2_1_object, self.controller.get_version())
 
       # Turn off caching.
-      self.controller._is_caching_enabled = False
+      self.async_controller._is_caching_enabled = False
       # Return a version without caching, so it will be the new version.
       self.assertEqual(version_2_2_object, self.controller.get_version())
 
       # Spec says the getinfo response may optionally be prefixed by 'Tor '. In
       # practice it doesn't but we should accept that.
-      get_info_mock.return_value = 'Tor 0.2.1.32'
+      get_info_mock.side_effect = coro_func_returning_value('Tor 0.2.1.32')
       self.assertEqual(version_2_1_object, self.controller.get_version())
 
       # Raise an exception in the get_info() call.
-      get_info_mock.side_effect = InvalidArguments
+      get_info_mock.side_effect = coro_func_raising_exc(InvalidArguments)
 
       # Get a default value when the call fails.
       self.assertEqual(
@@ -161,22 +175,24 @@ class TestControl(unittest.TestCase):
 
       # Give a bad version.  The stem.version.Version ValueError should bubble up.
       version_A_42 = '0.A.42.spam'
-      get_info_mock.return_value = version_A_42
-      get_info_mock.side_effect = None
+      get_info_mock.side_effect = coro_func_returning_value(version_A_42)
       self.assertRaises(ValueError, self.controller.get_version)
     finally:
       # Turn caching back on before we leave.
       self.controller._is_caching_enabled = True
 
-  @patch('stem.control.Controller.get_info')
+  @patch('stem.control.AsyncController.get_info')
   def test_get_exit_policy(self, get_info_mock):
     """
     Exercises the get_exit_policy() method.
     """
 
-    get_info_mock.side_effect = lambda param, default = None: {
-      'exit-policy/full': 'reject *:25,reject *:119,reject *:135-139,reject *:445,reject *:563,reject *:1214,reject *:4661-4666,reject *:6346-6429,reject *:6699,reject *:6881-6999,accept *:*',
-    }[param]
+    async def get_info_mock_side_effect(param, default = None):
+      return {
+        'exit-policy/full': 'reject *:25,reject *:119,reject *:135-139,reject *:445,reject *:563,reject *:1214,reject *:4661-4666,reject *:6346-6429,reject *:6699,reject *:6881-6999,accept *:*',
+      }[param]
+
+    get_info_mock.side_effect = get_info_mock_side_effect
 
     expected = ExitPolicy(
       'reject *:25',
@@ -194,8 +210,8 @@ class TestControl(unittest.TestCase):
 
     self.assertEqual(str(expected), str(self.controller.get_exit_policy()))
 
-  @patch('stem.control.Controller.get_info')
-  @patch('stem.control.Controller.get_conf')
+  @patch('stem.control.AsyncController.get_info')
+  @patch('stem.control.AsyncController.get_conf')
   def test_get_ports(self, get_conf_mock, get_info_mock):
     """
     Exercises the get_ports() and get_listeners() methods.
@@ -204,12 +220,15 @@ class TestControl(unittest.TestCase):
     # Exercise as an old version of tor that doesn't support the 'GETINFO
     # net/listeners/*' options.
 
-    get_info_mock.side_effect = InvalidArguments
+    get_info_mock.side_effect = coro_func_raising_exc(InvalidArguments)
+
+    async def get_conf_mock_side_effect(param, **kwargs):
+      return {
+        'ControlPort': '9050',
+        'ControlListenAddress': ['127.0.0.1'],
+      }[param]
 
-    get_conf_mock.side_effect = lambda param, *args, **kwargs: {
-      'ControlPort': '9050',
-      'ControlListenAddress': ['127.0.0.1'],
-    }[param]
+    get_conf_mock.side_effect = get_conf_mock_side_effect
 
     self.assertEqual([('127.0.0.1', 9050)], self.controller.get_listeners(Listener.CONTROL))
     self.assertEqual([9050], self.controller.get_ports(Listener.CONTROL))
@@ -217,10 +236,13 @@ class TestControl(unittest.TestCase):
 
     # non-local addresss
 
-    get_conf_mock.side_effect = lambda param, *args, **kwargs: {
-      'ControlPort': '9050',
-      'ControlListenAddress': ['27.4.4.1'],
-    }[param]
+    async def get_conf_mock_side_effect(param, **kwargs):
+      return {
+        'ControlPort': '9050',
+        'ControlListenAddress': ['27.4.4.1'],
+      }[param]
+
+    get_conf_mock.side_effect = get_conf_mock_side_effect
 
     self.assertEqual([('27.4.4.1', 9050)], self.controller.get_listeners(Listener.CONTROL))
     self.assertEqual([], self.controller.get_ports(Listener.CONTROL))
@@ -228,8 +250,8 @@ class TestControl(unittest.TestCase):
 
     # exercise via the GETINFO option
 
-    get_info_mock.side_effect = None
-    get_info_mock.return_value = '"127.0.0.1:1112" "127.0.0.1:1114"'
+    listeners = '"127.0.0.1:1112" "127.0.0.1:1114"'
+    get_info_mock.side_effect = coro_func_returning_value(listeners)
 
     self.assertEqual(
       [('127.0.0.1', 1112), ('127.0.0.1', 1114)],
@@ -241,15 +263,16 @@ class TestControl(unittest.TestCase):
 
     # with all localhost addresses, including a couple that aren't
 
-    get_info_mock.side_effect = None
-    get_info_mock.return_value = '"27.4.4.1:1113" "127.0.0.5:1114" "0.0.0.0:1115" "[::]:1116" "[::1]:1117" "[10::]:1118"'
+    listeners = '"27.4.4.1:1113" "127.0.0.5:1114" "0.0.0.0:1115" "[::]:1116" "[::1]:1117" "[10::]:1118"'
+    get_info_mock.side_effect = coro_func_returning_value(listeners)
 
     self.assertEqual([1114, 1115, 1116, 1117], self.controller.get_ports(Listener.OR))
     self.controller.clear_cache()
 
     # IPv6 address
 
-    get_info_mock.return_value = '"0.0.0.0:9001" "[fe80:0000:0000:0000:0202:b3ff:fe1e:8329]:9001"'
+    listeners = '"0.0.0.0:9001" "[fe80:0000:0000:0000:0202:b3ff:fe1e:8329]:9001"'
+    get_info_mock.side_effect = coro_func_returning_value(listeners)
 
     self.assertEqual(
       [('0.0.0.0', 9001), ('fe80:0000:0000:0000:0202:b3ff:fe1e:8329', 9001)],
@@ -259,25 +282,28 @@ class TestControl(unittest.TestCase):
     # unix socket file
 
     self.controller.clear_cache()
-    get_info_mock.return_value = '"unix:/tmp/tor/socket"'
+    get_info_mock.side_effect = coro_func_returning_value('"unix:/tmp/tor/socket"')
 
     self.assertEqual([], self.controller.get_listeners(Listener.CONTROL))
     self.assertEqual([], self.controller.get_ports(Listener.CONTROL))
 
-  @patch('stem.control.Controller.get_info')
+  @patch('stem.control.AsyncController.get_info')
   @patch('time.time', Mock(return_value = 1410723598.276578))
   def test_get_accounting_stats(self, get_info_mock):
     """
     Exercises the get_accounting_stats() method.
     """
 
-    get_info_mock.side_effect = lambda param, **kwargs: {
-      'accounting/enabled': '1',
-      'accounting/hibernating': 'awake',
-      'accounting/interval-end': '2014-09-14 19:41:00',
-      'accounting/bytes': '4837 2050',
-      'accounting/bytes-left': '102944 7440',
-    }[param]
+    async def get_info_mock_side_effect(param, **kwargs):
+      return {
+        'accounting/enabled': '1',
+        'accounting/hibernating': 'awake',
+        'accounting/interval-end': '2014-09-14 19:41:00',
+        'accounting/bytes': '4837 2050',
+        'accounting/bytes-left': '102944 7440',
+      }[param]
+
+    get_info_mock.side_effect = get_info_mock_side_effect
 
     expected = stem.control.AccountingStats(
       1410723598.276578,
@@ -290,7 +316,7 @@ class TestControl(unittest.TestCase):
 
     self.assertEqual(expected, self.controller.get_accounting_stats())
 
-    get_info_mock.side_effect = ControllerError('nope, too bad')
+    get_info_mock.side_effect = coro_func_raising_exc(ControllerError('nope, too bad'))
     self.assertRaises(ControllerError, self.controller.get_accounting_stats)
     self.assertEqual('my default', self.controller.get_accounting_stats('my default'))
 
@@ -303,7 +329,7 @@ class TestControl(unittest.TestCase):
     # use the handy mocked protocolinfo response
 
     protocolinfo_msg = ControlMessage.from_str('250-PROTOCOLINFO 1\r\n250 OK\r\n', 'PROTOCOLINFO')
-    get_protocolinfo_mock.return_value = protocolinfo_msg
+    get_protocolinfo_mock.side_effect = coro_func_returning_value(protocolinfo_msg)
 
     # compare the str representation of these object, because the class
     # does not have, nor need, a direct comparison operator
@@ -315,7 +341,7 @@ class TestControl(unittest.TestCase):
 
     # raise an exception in the stem.connection.get_protocolinfo() call
 
-    get_protocolinfo_mock.side_effect = ProtocolError
+    get_protocolinfo_mock.side_effect = coro_func_raising_exc(ProtocolError)
 
     # get a default value when the call fails
 
@@ -338,7 +364,7 @@ class TestControl(unittest.TestCase):
     self.assertEqual(123, self.controller.get_user(123))
 
   @patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
-  @patch('stem.control.Controller.get_info', Mock(return_value = 'atagar'))
+  @patch('stem.control.AsyncController.get_info', Mock(side_effect = coro_func_returning_value('atagar')))
   def test_get_user_by_getinfo(self):
     """
     Exercise the get_user() resolution via its getinfo option.
@@ -366,7 +392,7 @@ class TestControl(unittest.TestCase):
     self.assertEqual(123, self.controller.get_pid(123))
 
   @patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
-  @patch('stem.control.Controller.get_info', Mock(return_value = '321'))
+  @patch('stem.control.AsyncController.get_info', Mock(side_effect = coro_func_returning_value('321')))
   def test_get_pid_by_getinfo(self):
     """
     Exercise the get_pid() resolution via its getinfo option.
@@ -375,14 +401,14 @@ class TestControl(unittest.TestCase):
     self.assertEqual(321, self.controller.get_pid())
 
   @patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
-  @patch('stem.control.Controller.get_conf')
+  @patch('stem.control.AsyncController.get_conf')
   @patch('stem.control.open', create = True)
   def test_get_pid_by_pid_file(self, open_mock, get_conf_mock):
     """
     Exercise the get_pid() resolution via a PidFile.
     """
 
-    get_conf_mock.return_value = '/tmp/pid_file'
+    get_conf_mock.side_effect = coro_func_returning_value('/tmp/pid_file')
     open_mock.return_value = io.BytesIO(b'432')
 
     self.assertEqual(432, self.controller.get_pid())
@@ -397,25 +423,25 @@ class TestControl(unittest.TestCase):
 
     self.assertEqual(432, self.controller.get_pid())
 
-  @patch('stem.control.Controller.get_version', Mock(return_value = stem.version.Version('0.5.0.14')))
+  @patch('stem.control.AsyncController.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.5.0.14'))))
   @patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = False))
-  @patch('stem.control.Controller.get_info')
+  @patch('stem.control.AsyncController.get_info')
   @patch('time.time', Mock(return_value = 1000.0))
   def test_get_uptime_by_getinfo(self, getinfo_mock):
     """
     Exercise the get_uptime() resolution via a GETINFO query.
     """
 
-    getinfo_mock.return_value = '321'
+    getinfo_mock.side_effect = coro_func_returning_value('321')
     self.assertEqual(321.0, self.controller.get_uptime())
     self.controller.clear_cache()
 
-    getinfo_mock.return_value = 'abc'
+    getinfo_mock.side_effect = coro_func_returning_value('abc')
     self.assertRaisesWith(ValueError, "'GETINFO uptime' did not provide a valid numeric response: abc", self.controller.get_uptime)
 
   @patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
-  @patch('stem.control.Controller.get_version', Mock(return_value = stem.version.Version('0.1.0.14')))
-  @patch('stem.control.Controller.get_pid', Mock(return_value = '12'))
+  @patch('stem.control.AsyncController.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.1.0.14'))))
+  @patch('stem.control.AsyncController.get_pid', Mock(side_effect = coro_func_returning_value('12')))
   @patch('stem.util.system.start_time', Mock(return_value = 5000.0))
   @patch('time.time', Mock(return_value = 5200.0))
   def test_get_uptime_by_process(self):
@@ -425,7 +451,7 @@ class TestControl(unittest.TestCase):
 
     self.assertEqual(200.0, self.controller.get_uptime())
 
-  @patch('stem.control.Controller.get_info')
+  @patch('stem.control.AsyncController.get_info')
   def test_get_network_status_for_ourselves(self, get_info_mock):
     """
     Exercises the get_network_status() method for getting our own relay.
@@ -433,7 +459,7 @@ class TestControl(unittest.TestCase):
 
     # when there's an issue getting our fingerprint
 
-    get_info_mock.side_effect = ControllerError('nope, too bad')
+    get_info_mock.side_effect = coro_func_raising_exc(ControllerError('nope, too bad'))
 
     exc_msg = 'Unable to determine our own fingerprint: nope, too bad'
     self.assertRaisesWith(ControllerError, exc_msg, self.controller.get_network_status)
@@ -443,25 +469,29 @@ class TestControl(unittest.TestCase):
 
     desc = NS_DESC % ('moria1', '/96bKo4soysolMgKn5Hex2nyFSY')
 
-    get_info_mock.side_effect = lambda param, **kwargs: {
-      'fingerprint': '9695DFC35FFEB861329B9F1AB04C46397020CE31',
-      'ns/id/9695DFC35FFEB861329B9F1AB04C46397020CE31': desc,
-    }[param]
+    async def get_info_mock_side_effect(param, **kwargs):
+      return {
+        'fingerprint': '9695DFC35FFEB861329B9F1AB04C46397020CE31',
+        'ns/id/9695DFC35FFEB861329B9F1AB04C46397020CE31': desc,
+      }[param]
+
+    get_info_mock.side_effect = get_info_mock_side_effect
 
     self.assertEqual(stem.descriptor.router_status_entry.RouterStatusEntryV3(desc), self.controller.get_network_status())
 
-  @patch('stem.control.Controller.get_info')
+  @patch('stem.control.AsyncController.get_info')
   def test_get_network_status_when_unavailable(self, get_info_mock):
     """
     Exercises the get_network_status() method.
     """
 
-    get_info_mock.side_effect = InvalidArguments(None, 'GETINFO request contained unrecognized keywords: ns/id/5AC9C5AA75BA1F18D8459B326B4B8111A856D290')
+    exc = InvalidArguments(None, 'GETINFO request contained unrecognized keywords: ns/id/5AC9C5AA75BA1F18D8459B326B4B8111A856D290')
+    get_info_mock.side_effect = coro_func_raising_exc(exc)
 
     exc_msg = "Tor was unable to provide the descriptor for '5AC9C5AA75BA1F18D8459B326B4B8111A856D290'"
     self.assertRaisesWith(DescriptorUnavailable, exc_msg, self.controller.get_network_status, '5AC9C5AA75BA1F18D8459B326B4B8111A856D290')
 
-  @patch('stem.control.Controller.get_info')
+  @patch('stem.control.AsyncController.get_info')
   def test_get_network_status(self, get_info_mock):
     """
     Exercises the get_network_status() method.
@@ -476,7 +506,7 @@ class TestControl(unittest.TestCase):
 
     # always return the same router status entry
 
-    get_info_mock.return_value = desc
+    get_info_mock.side_effect = coro_func_returning_value(desc)
 
     # pretend to get the router status entry with its name
 
@@ -494,7 +524,7 @@ class TestControl(unittest.TestCase):
 
     # raise an exception in the get_info() call
 
-    get_info_mock.side_effect = InvalidArguments
+    get_info_mock.side_effect = coro_func_raising_exc(InvalidArguments)
 
     # get a default value when the call fails
 
@@ -507,22 +537,28 @@ class TestControl(unittest.TestCase):
 
     self.assertRaises(InvalidArguments, self.controller.get_network_status, nickname)
 
-  @patch('stem.control.Controller.is_authenticated', Mock(return_value = True))
-  @patch('stem.control.Controller._attach_listeners', Mock(return_value = ([], [])))
-  @patch('stem.control.Controller.get_version')
-  def test_add_event_listener(self, get_version_mock):
+  @patch('stem.control.AsyncController.is_authenticated', Mock(return_value = True))
+  @patch('stem.control.AsyncController._attach_listeners')
+  @patch('stem.control.AsyncController.get_version')
+  def test_add_event_listener(self, get_version_mock, attach_listeners_mock):
     """
     Exercises the add_event_listener and remove_event_listener methods.
     """
 
+    attach_listeners_mock.side_effect = coro_func_returning_value(([], []))
+
+    def set_version(version_str):
+      version = stem.version.Version(version_str)
+      get_version_mock.side_effect = coro_func_returning_value(version)
+
     # set up for failure to create any events
 
-    get_version_mock.return_value = stem.version.Version('0.1.0.14')
+    set_version('0.1.0.14')
     self.assertRaises(InvalidRequest, self.controller.add_event_listener, Mock(), EventType.BW)
 
     # set up to only fail newer events
 
-    get_version_mock.return_value = stem.version.Version('0.2.0.35')
+    set_version('0.2.0.35')
 
     # EventType.BW is one of the earliest events
 
@@ -551,7 +587,7 @@ class TestControl(unittest.TestCase):
     event thread.
     """
 
-    self.circ_listener.side_effect = ValueError('boom')
+    self.circ_listener.side_effect = coro_func_raising_exc(ValueError('boom'))
 
     self._emit_event(CIRC_EVENT)
     self.circ_listener.assert_called_once_with(CIRC_EVENT)
@@ -582,10 +618,10 @@ class TestControl(unittest.TestCase):
     self._emit_event(BW_EVENT)
     self.bw_listener.assert_called_once_with(BW_EVENT)
 
-  @patch('stem.control.Controller.get_version', Mock(return_value = stem.version.Version('0.5.0.14')))
-  @patch('stem.control.Controller.msg', Mock(return_value = ControlMessage.from_str('250 OK\r\n')))
-  @patch('stem.control.Controller.add_event_listener', Mock())
-  @patch('stem.control.Controller.remove_event_listener', Mock())
+  @patch('stem.control.AsyncController.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.5.0.14'))))
+  @patch('stem.control.AsyncController.msg', Mock(side_effect = coro_func_returning_value(ControlMessage.from_str('250 OK\r\n'))))
+  @patch('stem.control.AsyncController.add_event_listener', Mock(side_effect = coro_func_returning_value(None)))
+  @patch('stem.control.AsyncController.remove_event_listener', Mock(side_effect = coro_func_returning_value(None)))
   def test_timeout(self):
     """
     Methods that have an 'await' argument also have an optional timeout. Check
@@ -607,8 +643,9 @@ class TestControl(unittest.TestCase):
     )
 
     response = ''.join(['%s\r\n' % ' '.join(entry) for entry in valid_streams])
+    get_info_mock = Mock(side_effect = coro_func_returning_value(response))
 
-    with patch('stem.control.Controller.get_info', Mock(return_value = response)):
+    with patch('stem.control.AsyncController.get_info', get_info_mock):
       streams = self.controller.get_streams()
       self.assertEqual(len(valid_streams), len(streams))
 
@@ -627,8 +664,9 @@ class TestControl(unittest.TestCase):
     # instance, it's already open).
 
     response = stem.response.ControlMessage.from_str('555 Connection is not managed by controller.\r\n')
+    msg_mock = Mock(side_effect = coro_func_returning_value(response))
 
-    with patch('stem.control.Controller.msg', Mock(return_value = response)):
+    with patch('stem.control.AsyncController.msg', msg_mock):
       self.assertRaises(UnsatisfiableRequest, self.controller.attach_stream, 'stream_id', 'circ_id')
 
   def test_parse_circ_path(self):
@@ -671,7 +709,7 @@ class TestControl(unittest.TestCase):
     for test_input in malformed_inputs:
       self.assertRaises(ProtocolError, _parse_circ_path, test_input)
 
-  @patch('stem.control.Controller.get_conf')
+  @patch('stem.control.AsyncController.get_conf')
   def test_get_effective_rate(self, get_conf_mock):
     """
     Exercise the get_effective_rate() method.
@@ -679,18 +717,21 @@ class TestControl(unittest.TestCase):
 
     # check default if nothing was set
 
-    get_conf_mock.side_effect = lambda param, *args, **kwargs: {
-      'BandwidthRate': '1073741824',
-      'BandwidthBurst': '1073741824',
-      'RelayBandwidthRate': '0',
-      'RelayBandwidthBurst': '0',
-      'MaxAdvertisedBandwidth': '1073741824',
-    }[param]
+    async def get_conf_mock_side_effect(param, **kwargs):
+      return {
+        'BandwidthRate': '1073741824',
+        'BandwidthBurst': '1073741824',
+        'RelayBandwidthRate': '0',
+        'RelayBandwidthBurst': '0',
+        'MaxAdvertisedBandwidth': '1073741824',
+      }[param]
+
+    get_conf_mock.side_effect = get_conf_mock_side_effect
 
     self.assertEqual(1073741824, self.controller.get_effective_rate())
     self.assertEqual(1073741824, self.controller.get_effective_rate(burst = True))
 
-    get_conf_mock.side_effect = ControllerError('nope, too bad')
+    get_conf_mock.side_effect = coro_func_raising_exc(ControllerError('nope, too bad'))
     self.assertRaises(ControllerError, self.controller.get_effective_rate)
     self.assertEqual('my_default', self.controller.get_effective_rate('my_default'))
 
@@ -705,18 +746,19 @@ class TestControl(unittest.TestCase):
     #      with its work is to join on the thread.
 
     with patch('time.time', Mock(return_value = TEST_TIMESTAMP)):
-      with patch('stem.control.Controller.is_alive') as is_alive_mock:
+      with patch('stem.control.AsyncController.is_alive') as is_alive_mock:
         is_alive_mock.return_value = True
-        self.controller._create_loop_tasks()
+        loop = self.controller._asyncio_loop
+        asyncio.run_coroutine_threadsafe(self.async_controller._event_loop(), loop)
 
         try:
           # Converting an event back into an uncast ControlMessage, then feeding it
           # into our controller's event queue.
 
           uncast_event = ControlMessage.from_str(event.raw_content())
-          self.controller._event_queue.put(uncast_event)
-          self.controller._event_notice.set()
-          self.controller._event_queue.join()  # block until the event is consumed
+          event_queue = self.async_controller._event_queue
+          asyncio.run_coroutine_threadsafe(event_queue.put(uncast_event), loop).result()
+          asyncio.run_coroutine_threadsafe(event_queue.join(), loop).result()  # block until the event is consumed
         finally:
           is_alive_mock.return_value = False
-          self.controller._close()
+          asyncio.run_coroutine_threadsafe(self.async_controller._close(), loop).result()
diff --git a/test/unit/response/control_message.py b/test/unit/response/control_message.py
index abf5debf..414dcf63 100644
--- a/test/unit/response/control_message.py
+++ b/test/unit/response/control_message.py
@@ -126,7 +126,7 @@ class TestControlMessage(unittest.TestCase):
       # replace the CRLF for the line
       infonames_lines[index] = line.rstrip('\r\n') + '\n'
       test_socket_file = io.BytesIO(stem.util.str_tools._to_bytes(''.join(infonames_lines)))
-      self.assertRaises(stem.ProtocolError, stem.socket.recv_message, test_socket_file)
+      self.assertRaises(stem.ProtocolError, stem.socket.recv_message_from_bytes_io, test_socket_file)
 
       # puts the CRLF back
       infonames_lines[index] = infonames_lines[index].rstrip('\n') + '\r\n'
@@ -151,8 +151,8 @@ class TestControlMessage(unittest.TestCase):
         # - this is part of the message prefix
         # - this is disrupting the line ending
 
-        self.assertRaises(stem.ProtocolError, stem.socket.recv_message, io.BytesIO(stem.util.str_tools._to_bytes(removal_test_input)))
-        self.assertRaises(stem.ProtocolError, stem.socket.recv_message, io.BytesIO(stem.util.str_tools._to_bytes(replacement_test_input)))
+        self.assertRaises(stem.ProtocolError, stem.socket.recv_message_from_bytes_io, io.BytesIO(stem.util.str_tools._to_bytes(removal_test_input)))
+        self.assertRaises(stem.ProtocolError, stem.socket.recv_message_from_bytes_io, io.BytesIO(stem.util.str_tools._to_bytes(replacement_test_input)))
       else:
         # otherwise the data will be malformed, but this goes undetected
         self._assert_message_parses(removal_test_input)
@@ -166,7 +166,7 @@ class TestControlMessage(unittest.TestCase):
 
     control_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     control_socket_file = control_socket.makefile()
-    self.assertRaises(stem.SocketClosed, stem.socket.recv_message, control_socket_file)
+    self.assertRaises(stem.SocketClosed, stem.socket.recv_message_from_bytes_io, control_socket_file)
 
   def test_equality(self):
     msg = stem.response.ControlMessage.from_str(EVENT_BW)
@@ -200,7 +200,7 @@ class TestControlMessage(unittest.TestCase):
       stem.response.ControlMessage for the given input
     """
 
-    message = stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes(controller_reply)))
+    message = stem.socket.recv_message_from_bytes_io(io.BytesIO(stem.util.str_tools._to_bytes(controller_reply)))
 
     # checks that the raw_content equals the input value
     self.assertEqual(controller_reply, message.raw_content())





More information about the tor-commits mailing list