[tor-commits] [stem/master] Correct rebase discrepancies
atagar at torproject.org
atagar at torproject.org
Thu Jul 16 01:29:00 UTC 2020
commit 18a3280d9cb27a81bb01d8e964449adda3dc734e
Author: Damian Johnson <atagar at torproject.org>
Date: Mon May 18 14:43:18 2020 -0700
Correct rebase discrepancies
To make Illia's branch cleanly mergeable I rebased it onto our present master.
Manually resolving the conflicts produced a slightly different result than
he had. This delta makes us perfectly match his commit c788fd8.
---
stem/client/__init__.py | 11 +++---
stem/connection.py | 25 ++++++-------
stem/control.py | 63 ++++++++++++++++++---------------
stem/descriptor/remote.py | 16 ++++-----
stem/interpreter/__init__.py | 4 +--
stem/interpreter/autocomplete.py | 8 ++---
stem/interpreter/commands.py | 12 +++----
stem/interpreter/help.py | 9 ++---
stem/response/__init__.py | 2 +-
stem/socket.py | 45 +++++++++++------------
test/integ/connection/authentication.py | 6 +++-
test/unit/control/controller.py | 6 ++--
12 files changed, 105 insertions(+), 102 deletions(-)
diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index 941f0ee7..8ea7b3c1 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -33,7 +33,7 @@ import stem.socket
import stem.util.connection
from types import TracebackType
-from typing import Dict, Iterator, List, Optional, Sequence, Type, Union
+from typing import AsyncIterator, Dict, List, Optional, Sequence, Type, Union
from stem.client.cell import (
CELL_TYPE_SIZE,
@@ -70,7 +70,8 @@ class Relay(object):
self.link_protocol = LinkProtocol(link_protocol)
self._orport = orport
self._orport_buffer = b'' # unread bytes
- self._circuits = {}
+ self._orport_lock = stem.util.CombinedReentrantAndAsyncioLock()
+ self._circuits = {} # type: Dict[int, stem.client.Circuit]
@staticmethod
async def connect(address: str, port: int, link_protocols: Sequence['stem.client.datatype.LinkProtocol'] = DEFAULT_LINK_PROTOCOLS) -> 'stem.client.Relay': # type: ignore
@@ -191,7 +192,7 @@ class Relay(object):
cell, self._orport_buffer = Cell.pop(self._orport_buffer, self.link_protocol)
return cell
- async def _msg(self, cell: 'stem.client.cell.Cell') -> Iterator['stem.client.cell.Cell']:
+ async def _msg(self, cell: 'stem.client.cell.Cell') -> AsyncIterator['stem.client.cell.Cell']:
"""
Sends a cell on the ORPort and provides the response we receive in reply.
@@ -283,7 +284,7 @@ class Relay(object):
return circ
- async def __aiter__(self) -> Iterator['stem.client.Circuit']:
+ async def __aiter__(self) -> AsyncIterator['stem.client.Circuit']:
async with self._orport_lock:
for circ in self._circuits.values():
yield circ
@@ -381,7 +382,7 @@ class Circuit(object):
self.forward_digest = forward_digest
self.forward_key = forward_key
- async def close(self)- > None:
+ async def close(self) -> None:
async with self.relay._orport_lock:
await self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol))
del self.relay._circuits[self.id]
diff --git a/stem/connection.py b/stem/connection.py
index de76a345..213ba010 100644
--- a/stem/connection.py
+++ b/stem/connection.py
@@ -159,7 +159,7 @@ import stem.util.str_tools
import stem.util.system
import stem.version
-from typing import Any, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, cast, List, Optional, Sequence, Tuple, Type, Union
from stem.util import log
AuthMethod = stem.util.enum.Enum('NONE', 'PASSWORD', 'COOKIE', 'SAFECOOKIE', 'UNKNOWN')
@@ -227,7 +227,7 @@ COMMON_TOR_COMMANDS = (
)
-def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: Type = stem.control.Controller) -> Any:
+def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: Type[stem.control.Controller] = stem.control.Controller) -> Any:
"""
Convenience function for quickly getting a control connection for synchronous
usage. This is very handy for debugging or CLI setup, handling setup and
@@ -269,7 +269,7 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
# TODO: change this function's API so we can provide a concrete type
if controller is None or not issubclass(controller, stem.control.Controller):
- raise ValueError('Controller should be a stem.control.BaseController subclass.')
+ raise ValueError('Controller should be a stem.control.Controller subclass.')
async_controller_thread = stem.util.ThreadForWrappedAsyncClass()
async_controller_thread.start()
@@ -326,7 +326,7 @@ async def connect_async(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1'
return await _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller)
-async def _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller):
+async def _connect_async(control_port: Tuple[str, Union[str, int]], control_socket: str, password: Optional[str], password_prompt: bool, chroot_path: Optional[str], controller: Type[Union[stem.control.BaseController, stem.control.Controller]]) -> Any:
if control_port is None and control_socket is None:
raise ValueError('Neither a control port nor control socket were provided. Nothing to connect to.')
elif control_port:
@@ -377,7 +377,7 @@ async def _connect_async(control_port, control_socket, password, password_prompt
return await _connect_auth(control_connection, password, password_prompt, chroot_path, controller)
-async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[stem.control.BaseController]]) -> Any:
+async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[Union[stem.control.BaseController, stem.control.Controller]]]) -> Any:
"""
Helper for the connect_* functions that authenticates the socket and
constructs the controller.
@@ -402,7 +402,7 @@ async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str
elif issubclass(controller, stem.control.BaseController):
return controller(control_socket, is_authenticated = True)
elif issubclass(controller, stem.control.Controller):
- return controller(control_socket, is_authenticated = True, started_async_controller_thread = threading.current_thread())
+ return controller(control_socket, is_authenticated = True, started_async_controller_thread = cast(stem.util.ThreadForWrappedAsyncClass, threading.current_thread()))
except IncorrectSocketType:
if isinstance(control_socket, stem.socket.ControlPort):
print(CONNECT_MESSAGES['wrong_port_type'].format(port = control_socket.port))
@@ -995,7 +995,7 @@ async def authenticate_safecookie(controller: Union[stem.control.BaseController,
auth_response = await _msg(controller, 'AUTHENTICATE %s' % stem.util.str_tools._to_unicode(binascii.b2a_hex(client_hash)))
except stem.ControllerError as exc:
try:
- controller.connect()
+ await controller.connect()
except:
pass
@@ -1007,7 +1007,7 @@ async def authenticate_safecookie(controller: Union[stem.control.BaseController,
# if we got anything but an OK response then err
if not auth_response.is_ok():
try:
- controller.connect()
+ await controller.connect()
except:
pass
@@ -1051,9 +1051,7 @@ async def get_protocolinfo(controller: Union[stem.control.BaseController, stem.s
# next followed by authentication. Transparently reconnect if that happens.
if not protocolinfo_response or str(protocolinfo_response) == 'Authentication required.':
- potential_coroutine = controller.connect()
- if asyncio.iscoroutine(potential_coroutine):
- await potential_coroutine
+ await controller.connect()
try:
protocolinfo_response = await _msg(controller, 'PROTOCOLINFO 1')
@@ -1074,10 +1072,7 @@ async def _msg(controller: Union[stem.control.BaseController, stem.socket.Contro
await controller.send(message)
return await controller.recv()
else:
- message = controller.msg(message)
- if asyncio.iscoroutine(message):
- message = await message
- return message
+ return await controller.msg(message)
def _connection_for_default_port(address: str) -> stem.socket.ControlPort:
diff --git a/stem/control.py b/stem/control.py
index 6e15c16c..84f8f39b 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -271,7 +271,7 @@ import stem.version
from stem import UNDEFINED, CircStatus, Signal
from stem.util import log
from types import TracebackType
-from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
# When closing the controller we attempt to finish processing enqueued events,
# but if it takes longer than this we terminate.
@@ -554,6 +554,8 @@ def event_description(event: str) -> str:
class _BaseControllerSocketMixin:
+ _socket: stem.socket.ControlSocket
+
def is_alive(self) -> bool:
"""
Checks if our socket is currently connected. This is a pass-through for our
@@ -589,7 +591,7 @@ class _BaseControllerSocketMixin:
return self._socket.connection_time()
- def get_socket(self):
+ def get_socket(self) -> stem.socket.ControlSocket:
"""
Provides the socket used to speak with the tor process. Communicating with
the socket directly isn't advised since it may confuse this controller.
@@ -839,7 +841,7 @@ class BaseController(_BaseControllerSocketMixin):
async def __aenter__(self) -> 'stem.control.BaseController':
return self
- await def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
+ async def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
await self.close()
async def _handle_event(self, event_message: stem.response.ControlMessage) -> None:
@@ -997,7 +999,7 @@ class AsyncController(BaseController):
"""
@classmethod
- def from_port(cls: Type, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.AsyncController':
+ def from_port(cls, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'AsyncController':
"""
Constructs a :class:`~stem.socket.ControlPort` based AsyncController.
@@ -1017,7 +1019,7 @@ class AsyncController(BaseController):
return cls(control_socket)
@classmethod
- def from_socket_file(cls: Type, path: str = '/var/run/tor/control') -> 'stem.control.AsyncController':
+ def from_socket_file(cls, path: str = '/var/run/tor/control') -> 'AsyncController':
"""
Constructs a :class:`~stem.socket.ControlSocketFile` based AsyncController.
@@ -1189,8 +1191,8 @@ class AsyncController(BaseController):
return list(reply.values())[0]
try:
- response = stem.response._convert_to_getinfo(await self.msg('GETINFO %s' % ' '.join(params)))
- response._assert_matches(params)
+ response = stem.response._convert_to_getinfo(await self.msg('GETINFO %s' % ' '.join(param_set)))
+ response._assert_matches(param_set)
# usually we want unicode values under python 3.x
@@ -1765,7 +1767,7 @@ class AsyncController(BaseController):
return stem.descriptor.microdescriptor.Microdescriptor(desc_content)
@with_default(yields = True)
- async def get_microdescriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.microdescriptor.Microdescriptor]:
+ async def get_microdescriptors(self, default: Any = UNDEFINED) -> AsyncIterator[stem.descriptor.microdescriptor.Microdescriptor]:
"""
get_microdescriptors(default = UNDEFINED)
@@ -1859,7 +1861,7 @@ class AsyncController(BaseController):
return stem.descriptor.server_descriptor.RelayDescriptor(desc_content)
@with_default(yields = True)
- async def get_server_descriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.server_descriptor.RelayDescriptor]:
+ async def get_server_descriptors(self, default: Any = UNDEFINED) -> AsyncIterator[stem.descriptor.server_descriptor.RelayDescriptor]:
"""
get_server_descriptors(default = UNDEFINED)
@@ -1954,7 +1956,7 @@ class AsyncController(BaseController):
return stem.descriptor.router_status_entry.RouterStatusEntryV3(desc_content)
@with_default(yields = True)
- async def get_network_statuses(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.router_status_entry.RouterStatusEntryV3]:
+ async def get_network_statuses(self, default: Any = UNDEFINED) -> AsyncIterator[stem.descriptor.router_status_entry.RouterStatusEntryV3]:
"""
get_network_statuses(default = UNDEFINED)
@@ -2061,6 +2063,7 @@ class AsyncController(BaseController):
request += ' ' + ' '.join(['SERVER=%s' % s for s in servers])
response = stem.response._convert_to_single_line(await self.msg(request))
+ stem.response.convert('SINGLELINE', response)
if not response.is_ok():
raise stem.ProtocolError('HSFETCH returned unexpected response code: %s' % response.code)
@@ -2153,7 +2156,7 @@ class AsyncController(BaseController):
async def _get_conf_multiple(self, param: str, default: Any = UNDEFINED) -> List[str]:
return await self.get_conf(param, default, multiple = True) # type: ignore
- await def get_conf_map(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, multiple: bool = True) -> Dict[str, Union[str, Sequence[str]]]:
+ async def get_conf_map(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, multiple: bool = True) -> Dict[str, Union[str, Sequence[str]]]:
"""
get_conf_map(params, default = UNDEFINED, multiple = True)
@@ -2971,7 +2974,7 @@ class AsyncController(BaseController):
else:
request += ' ClientAuth=%s' % client_name
- response = stem.response._convert_to_add_onion(await self.msg(request))
+ response = stem.response._convert_to_add_onion(stem.response._convert_to_add_onion(await self.msg(request)))
if await_publication:
# We should receive five UPLOAD events, followed by up to another five
@@ -3024,7 +3027,7 @@ class AsyncController(BaseController):
else:
raise stem.ProtocolError('DEL_ONION returned unexpected response code: %s' % response.code)
- async def add_event_listener(self, listener: Callable[[stem.response.events.Event], None], *events: 'stem.control.EventType') -> None:
+ async def add_event_listener(self, listener: Callable[[stem.response.events.Event], Union[None, Awaitable[None]]], *events: 'stem.control.EventType') -> None:
"""
Directs further tor controller events to a given function. The function is
expected to take a single argument, which is a
@@ -3082,7 +3085,7 @@ class AsyncController(BaseController):
if failed_events:
raise stem.ProtocolError('SETEVENTS rejected %s' % ', '.join(failed_events))
- async def remove_event_listener(self, listener: Callable[[stem.response.events.Event], None]) -> None:
+ async def remove_event_listener(self, listener: Callable[[stem.response.events.Event], Union[None, Awaitable[None]]]) -> None:
"""
Stops a listener from being notified of further tor events.
@@ -3253,7 +3256,7 @@ class AsyncController(BaseController):
:raises: :class:`stem.ControllerError` if the call fails
"""
- response = stem.response._convert_to_single_line(async self.msg('LOADCONF\n%s' % configtext))
+ response = stem.response._convert_to_single_line(await self.msg('LOADCONF\n%s' % configtext))
if response.code in ('552', '553'):
if response.code == '552' and response.message.startswith('Invalid config file: Failed to parse/validate config: Unknown option'):
@@ -3379,7 +3382,7 @@ class AsyncController(BaseController):
response = await self.get_info('circuit-status')
for circ in response.splitlines():
- circ_message = stem.response._convert_to_event(await stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes('650 CIRC %s\r\n' % circ))))
+ circ_message = stem.response._convert_to_event(stem.socket.recv_message_from_bytes_io(io.BytesIO(stem.util.str_tools._to_bytes('650 CIRC %s\r\n' % circ))))
circuits.append(circ_message) # type: ignore
return circuits
@@ -3563,7 +3566,7 @@ class AsyncController(BaseController):
response = await self.get_info('stream-status')
for stream in response.splitlines():
- message = stem.response._convert_to_event(await stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes('650 STREAM %s\r\n' % stream))))
+ message = stem.response._convert_to_event(stem.socket.recv_message_from_bytes_io(io.BytesIO(stem.util.str_tools._to_bytes('650 STREAM %s\r\n' % stream))))
streams.append(message) # type: ignore
return streams
@@ -3744,7 +3747,7 @@ class AsyncController(BaseController):
response = await self.msg('MAPADDRESS %s' % mapaddress_arg)
return stem.response._convert_to_mapaddress(response).entries
- await def drop_guards(self) -> None:
+ async def drop_guards(self) -> None:
"""
Drops our present guard nodes and picks a new set.
@@ -3812,7 +3815,7 @@ class AsyncController(BaseController):
if listener_type == event_type:
for listener in event_listeners:
try:
- potential_coroutine = listener(event_message)
+ potential_coroutine = listener(event)
if asyncio.iscoroutine(potential_coroutine):
await potential_coroutine
except Exception as exc:
@@ -3874,7 +3877,7 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
"""
@classmethod
- def from_port(cls: Type, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.Controller':
+ def from_port(cls, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'Controller':
"""
Constructs a :class:`~stem.socket.ControlPort` based Controller.
@@ -3885,8 +3888,8 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
.. versionchanged:: 1.5.0
Use both port 9051 and 9151 by default.
- :param str address: ip address of the controller
- :param int port: port number of the controller
+ :param address: ip address of the controller
+ :param port: port number of the controller
:returns: :class:`~stem.control.Controller` attached to the given port
@@ -3899,7 +3902,7 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
return controller
@classmethod
- def from_socket_file(cls: Type, path: str = '/var/run/tor/control') -> 'stem.control.Controller':
+ def from_socket_file(cls, path: str = '/var/run/tor/control') -> 'Controller':
"""
Constructs a :class:`~stem.socket.ControlSocketFile` based Controller.
@@ -3915,15 +3918,19 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
controller.connect()
return controller
- def __init__(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool = False, started_async_controller_thread: Optional['threading.Thread'] = None) -> None:
- def __init__(self, control_socket, is_authenticated = False, started_async_controller_thread = None):
+ def __init__(
+ self,
+ control_socket: stem.socket.ControlSocket,
+ is_authenticated: bool = False,
+ started_async_controller_thread: stem.util.ThreadForWrappedAsyncClass = None,
+ ) -> None:
if started_async_controller_thread:
self._thread_for_wrapped_class = started_async_controller_thread
else:
self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass()
self._thread_for_wrapped_class.start()
- self._wrapped_instance: AsyncController = self._init_async_class(AsyncController, control_socket, is_authenticated)
+ self._wrapped_instance: AsyncController = self._init_async_class(AsyncController, control_socket, is_authenticated) # type: ignore
self._socket = self._wrapped_instance._socket
@_set_doc_from_async_controller
@@ -3956,7 +3963,7 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
@_set_doc_from_async_controller
def remove_status_listener(self, callback: Callable[['stem.control.Controller', 'stem.control.State', float], None]) -> bool:
- self._wrapped_instance.remove_status_listener(callback)
+ return self._wrapped_instance.remove_status_listener(callback)
@_set_doc_from_async_controller
def authenticate(self, *args: Any, **kwargs: Any) -> None:
@@ -4306,7 +4313,7 @@ def _case_insensitive_lookup(entries: Union[Sequence[str], Mapping[str, Any]], k
raise ValueError("key '%s' doesn't exist in dict: %s" % (key, entries))
-async def _get_with_timeout(event_queue: queue.Queue, timeout: float, start_time: float) -> Any:
+async def _get_with_timeout(event_queue: asyncio.Queue, timeout: Optional[float], start_time: float) -> Any:
"""
Pulls an item from a queue with a given timeout.
"""
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index d7a833a4..e7ccaa24 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -104,7 +104,7 @@ import stem.util.tor_tools
from stem.descriptor import Compression
from stem.util import log, str_tools
-from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple, Union
# Tor has a limited number of descriptors we can fetch explicitly by their
# fingerprint or hashes due to a limit on the url length by squid proxies.
@@ -392,7 +392,7 @@ class AsyncQuery(object):
self.reply_headers = None # type: Optional[Dict[str, str]]
self.kwargs = kwargs
- self._downloader_task = None
+ self._downloader_task = None # type: Optional[asyncio.Task]
self._downloader_lock = threading.RLock()
self._asyncio_loop = asyncio.get_event_loop()
@@ -401,7 +401,7 @@ class AsyncQuery(object):
self.start()
if block:
- self._asyncio_loop.create_task(self.run(True))
+ self.run(True)
def start(self) -> None:
"""
@@ -432,7 +432,7 @@ class AsyncQuery(object):
return [desc async for desc in self._run(suppress)]
- async def _run(self, suppress: bool) -> Iterator[stem.descriptor.Descriptor]:
+ async def _run(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]:
with self._downloader_lock:
self.start()
await self._downloader_task
@@ -468,7 +468,7 @@ class AsyncQuery(object):
raise self.error
- async def __aiter__(self) -> Iterator[stem.descriptor.Descriptor]:
+ async def __aiter__(self) -> AsyncIterator[stem.descriptor.Descriptor]:
async for desc in self._run(True):
yield desc
@@ -665,7 +665,7 @@ class Query(stem.util.AsyncClassWrapper):
def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass()
self._thread_for_wrapped_class.start()
- self._wrapped_instance: AsyncQuery = self._init_async_class(
+ self._wrapped_instance: AsyncQuery = self._init_async_class( # type: ignore
AsyncQuery,
resource,
descriptor_type,
@@ -688,7 +688,7 @@ class Query(stem.util.AsyncClassWrapper):
self._call_async_method_soon('start')
- def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
+ def run(self, suppress = False) -> List['stem.descriptor.Descriptor']:
"""
Blocks until our request is complete then provides the descriptors. If we
haven't yet started our request then this does so.
@@ -708,7 +708,7 @@ class Query(stem.util.AsyncClassWrapper):
return self._execute_async_method('run', suppress)
- def __iter__(self):
+ def __iter__(self) -> Iterator[stem.descriptor.Descriptor]:
for desc in self._execute_async_generator_method('__aiter__'):
yield desc
diff --git a/stem/interpreter/__init__.py b/stem/interpreter/__init__.py
index ae064a0a..370b9aa6 100644
--- a/stem/interpreter/__init__.py
+++ b/stem/interpreter/__init__.py
@@ -124,10 +124,10 @@ def main() -> None:
if args.run_cmd:
if args.run_cmd.upper().startswith('SETEVENTS '):
- async def handle_event(event_message):
+ async def handle_event(event_message: stem.response.ControlMessage) -> None:
print(format(str(event_message), *STANDARD_OUTPUT))
- controller._wrapped_instance._handle_event = handle_event
+ controller._wrapped_instance._handle_event = handle_event # type: ignore
if sys.stdout.isatty():
events = args.run_cmd.upper().split(' ', 1)[1]
diff --git a/stem/interpreter/autocomplete.py b/stem/interpreter/autocomplete.py
index 54642472..ed51fd3d 100644
--- a/stem/interpreter/autocomplete.py
+++ b/stem/interpreter/autocomplete.py
@@ -11,7 +11,7 @@ import stem.control
import stem.util.conf
from stem.interpreter import uses_settings
-from typing import List, Optional
+from typing import cast, List, Optional
@uses_settings
@@ -28,7 +28,7 @@ def _get_commands(controller: stem.control.Controller, config: stem.util.conf.Co
# GETINFO commands. Lines are of the form '[option] -- [description]'. This
# strips '*' from options that accept values.
- results = controller.get_info('info/names', None)
+ results = cast(str, controller.get_info('info/names', None))
if results:
for line in results.splitlines():
@@ -40,7 +40,7 @@ def _get_commands(controller: stem.control.Controller, config: stem.util.conf.Co
# GETCONF, SETCONF, and RESETCONF commands. Lines are of the form
# '[option] [type]'.
- results = controller.get_info('config/names', None)
+ results = cast(str, controller.get_info('config/names', None))
if results:
for line in results.splitlines():
@@ -62,7 +62,7 @@ def _get_commands(controller: stem.control.Controller, config: stem.util.conf.Co
)
for prefix, getinfo_cmd in options:
- results = controller.get_info(getinfo_cmd, None)
+ results = cast(str, controller.get_info(getinfo_cmd, None))
if results:
commands += [prefix + value for value in results.split()]
diff --git a/stem/interpreter/commands.py b/stem/interpreter/commands.py
index edbcca70..99f1219d 100644
--- a/stem/interpreter/commands.py
+++ b/stem/interpreter/commands.py
@@ -21,7 +21,7 @@ import stem.util.tor_tools
from stem.interpreter import STANDARD_OUTPUT, BOLD_OUTPUT, ERROR_OUTPUT, uses_settings, msg
from stem.util.term import format
-from typing import Iterator, List, TextIO
+from typing import cast, Iterator, List, TextIO
MAX_EVENTS = 100
@@ -45,7 +45,7 @@ def _get_fingerprint(arg: str, controller: stem.control.Controller) -> str:
if not arg:
try:
- return controller.get_info('fingerprint')
+ return cast(str, controller.get_info('fingerprint'))
except:
raise ValueError("We aren't a relay, no information to provide")
elif stem.util.tor_tools.is_valid_fingerprint(arg):
@@ -132,14 +132,14 @@ class ControlInterpreter(code.InteractiveConsole):
async def handle_event_wrapper(event_message: stem.response.ControlMessage) -> None:
await handle_event_real(event_message)
- self._received_events.insert(0, event_message)
+ self._received_events.insert(0, event_message) # type: ignore
if len(self._received_events) > MAX_EVENTS:
self._received_events.pop()
# type check disabled due to https://github.com/python/mypy/issues/708
- self._controller._wrapped_instance._handle_event = handle_event_wrapper
+ self._controller._wrapped_instance._handle_event = handle_event_wrapper # type: ignore
def get_events(self, *event_types: stem.control.EventType) -> List[stem.response.events.Event]:
events = list(self._received_events)
@@ -207,7 +207,7 @@ class ControlInterpreter(code.InteractiveConsole):
extrainfo_desc_query = downloader.get_extrainfo_descriptors(fingerprint)
for desc in server_desc_query:
- server_desc = desc
+ server_desc = cast(stem.descriptor.server_descriptor.RelayDescriptor, desc)
for desc in extrainfo_desc_query:
extrainfo_desc = desc
@@ -220,7 +220,7 @@ class ControlInterpreter(code.InteractiveConsole):
pass
try:
- address_extrainfo.append(self._controller.get_info('ip-to-country/%s' % ns_desc.address))
+ address_extrainfo.append(cast(str, self._controller.get_info('ip-to-country/%s' % ns_desc.address)))
except:
pass
diff --git a/stem/interpreter/help.py b/stem/interpreter/help.py
index 14b46e35..f2bbbafd 100644
--- a/stem/interpreter/help.py
+++ b/stem/interpreter/help.py
@@ -6,6 +6,7 @@ Provides our /help responses.
"""
import functools
+from typing import cast
import stem.control
import stem.util.conf
@@ -74,7 +75,7 @@ def _response(controller: stem.control.Controller, arg: str, config: stem.util.c
output += '\n'
if arg == 'GETINFO':
- results = controller.get_info('info/names', None)
+ results = cast(str, controller.get_info('info/names', None))
if results:
for line in results.splitlines():
@@ -84,7 +85,7 @@ def _response(controller: stem.control.Controller, arg: str, config: stem.util.c
output += format('%-33s' % opt, *BOLD_OUTPUT)
output += format(' - %s' % summary, *STANDARD_OUTPUT) + '\n'
elif arg == 'GETCONF':
- results = controller.get_info('config/names', None)
+ results = cast(str, controller.get_info('config/names', None))
if results:
options = [opt.split(' ', 1)[0] for opt in results.splitlines()]
@@ -103,7 +104,7 @@ def _response(controller: stem.control.Controller, arg: str, config: stem.util.c
output += format('%-15s' % signal, *BOLD_OUTPUT)
output += format(' - %s' % summary, *STANDARD_OUTPUT) + '\n'
elif arg == 'SETEVENTS':
- results = controller.get_info('events/names', None)
+ results = cast(str, controller.get_info('events/names', None))
if results:
entries = results.split()
@@ -118,7 +119,7 @@ def _response(controller: stem.control.Controller, arg: str, config: stem.util.c
output += format(line.rstrip(), *STANDARD_OUTPUT) + '\n'
elif arg == 'USEFEATURE':
- results = controller.get_info('features/names', None)
+ results = cast(str, controller.get_info('features/names', None))
if results:
output += format(results, *STANDARD_OUTPUT) + '\n'
diff --git a/stem/response/__init__.py b/stem/response/__init__.py
index 52dc74e4..2e251144 100644
--- a/stem/response/__init__.py
+++ b/stem/response/__init__.py
@@ -202,7 +202,7 @@ class ControlMessage(object):
content = re.sub(b'([\r]?)\n', b'\r\n', content)
- msg = stem.socket.recv_message_from_bytes_io(io.BytesIO(content), arrived_at = kwargs.pop('arrived_at', None))
+ msg = stem.socket.recv_message_from_bytes_io(io.BytesIO(stem.util.str_tools._to_bytes(content)), arrived_at = kwargs.pop('arrived_at', None))
if msg_type is not None:
convert(msg_type, msg, **kwargs)
diff --git a/stem/socket.py b/stem/socket.py
index 0feae831..ff99c5b1 100644
--- a/stem/socket.py
+++ b/stem/socket.py
@@ -85,6 +85,7 @@ import asyncio
import re
import socket
import ssl
+import sys
import threading
import time
@@ -93,7 +94,7 @@ import stem.util.str_tools
from stem.util import log
from types import TracebackType
-from typing import BinaryIO, Callable, List, Optional, Tuple, Type, Union, overload
+from typing import Awaitable, BinaryIO, Callable, List, Optional, Tuple, Type, Union, overload
MESSAGE_PREFIX = re.compile(b'^[a-zA-Z0-9]{3}[-+ ]')
ERROR_MSG = 'Error while receiving a control message (%s): %s'
@@ -109,8 +110,8 @@ class BaseSocket(object):
"""
def __init__(self) -> None:
- self._reader = None
- self._writer = None
+ self._reader = None # type: Optional[asyncio.StreamReader]
+ self._writer = None # type: Optional[asyncio.StreamWriter]
self._is_alive = False
self._connection_time = 0.0 # time when we last connected or disconnected
@@ -209,7 +210,7 @@ class BaseSocket(object):
if self._writer:
self._writer.close()
# `StreamWriter.wait_closed` was added in Python 3.7.
- if hasattr(self._writer, 'wait_closed'):
+ if sys.version_info >= (3, 7):
await self._writer.wait_closed()
self._reader = None
@@ -220,7 +221,7 @@ class BaseSocket(object):
if is_change:
await self._close()
- async def _send(self, message: Union[bytes, str], handler: Callable[[Union[socket.socket, ssl.SSLSocket], BinaryIO, Union[bytes, str]], None]) -> None:
+ async def _send(self, message: Union[bytes, str], handler: Callable[[asyncio.StreamWriter, Union[bytes, str]], Awaitable[None]]) -> None:
"""
Send message in a thread safe manner.
"""
@@ -241,11 +242,11 @@ class BaseSocket(object):
raise
@overload
- async def _recv(self, handler: Callable[[ssl.SSLSocket, BinaryIO], bytes]) -> bytes:
+ async def _recv(self, handler: Callable[[asyncio.StreamReader], Awaitable[bytes]]) -> bytes:
...
@overload
- async def _recv(self, handler: Callable[[socket.socket, BinaryIO], stem.response.ControlMessage]) -> stem.response.ControlMessage:
+ async def _recv(self, handler: Callable[[asyncio.StreamReader], Awaitable[stem.response.ControlMessage]]) -> stem.response.ControlMessage:
...
async def _recv(self, handler):
@@ -303,7 +304,7 @@ class BaseSocket(object):
return self
async def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]):
- self.close()
+ await self.close()
async def _connect(self) -> None:
"""
@@ -320,12 +321,6 @@ class BaseSocket(object):
pass
async def _open_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
- """
- Constructs and connects new socket. This is implemented by subclasses.
-
- :returns: **tuple** with our reader and writer streams
- """
-
raise NotImplementedError('Unsupported Operation: this should be implemented by the BaseSocket subclass')
@@ -380,7 +375,7 @@ class RelaySocket(BaseSocket):
* :class:`stem.SocketClosed` if the socket closes before we receive a complete message
"""
- async def wrapped_recv(s: ssl.SSLSocket, sf: BinaryIO) -> bytes:
+ async def wrapped_recv(reader: asyncio.StreamReader) -> Optional[bytes]:
read_coroutine = reader.read(1024)
if timeout is None:
return await read_coroutine
@@ -524,7 +519,7 @@ async def send_message(writer: asyncio.StreamWriter, message: Union[bytes, str],
<line 3>\\r\\n
.\\r\\n
- :param writer: stream derived from the control socket
+ :param writer: writer object
:param message: message to be sent on the control socket
:param raw: leaves the message formatting untouched, passing it to the
socket as-is
@@ -591,9 +586,7 @@ async def recv_message(reader: asyncio.StreamReader, arrived_at: Optional[float]
while True:
try:
- line = reader.readline()
- if asyncio.iscoroutine(line):
- line = await line
+ line = await reader.readline()
except AttributeError:
# if the control_file has been closed then we will receive:
# AttributeError: 'NoneType' object has no attribute 'recv'
@@ -693,7 +686,7 @@ async def recv_message(reader: asyncio.StreamReader, arrived_at: Optional[float]
raise stem.ProtocolError("Unrecognized divider type '%s': %s" % (divider, stem.util.str_tools._to_unicode(line)))
-def recv_message_from_bytes_io(reader: asyncio.StreamReader, arrived_at: Optional[float] = None) -> 'stem.response.ControlMessage':
+def recv_message_from_bytes_io(reader: BinaryIO, arrived_at: Optional[float] = None) -> stem.response.ControlMessage:
"""
Pulls from an I/O stream until we either have a complete message or
encounter a problem.
@@ -708,7 +701,9 @@ def recv_message_from_bytes_io(reader: asyncio.StreamReader, arrived_at: Optiona
a complete message
"""
- parsed_content, raw_content, first_line = None, None, True
+ parsed_content = [] # type: List[Tuple[str, str, bytes]]
+ raw_content = bytearray()
+ first_line = True
while True:
try:
@@ -739,10 +734,10 @@ def recv_message_from_bytes_io(reader: asyncio.StreamReader, arrived_at: Optiona
log.info(ERROR_MSG % ('SocketClosed', 'empty socket content'))
raise stem.SocketClosed('Received empty socket content.')
elif not MESSAGE_PREFIX.match(line):
- log.info(ERROR_MSG % ('ProtocolError', 'malformed status code/divider, "%s"' % log.escape(line)))
+ log.info(ERROR_MSG % ('ProtocolError', 'malformed status code/divider, "%s"' % log.escape(line.decode('utf-8'))))
raise stem.ProtocolError('Badly formatted reply line: beginning is malformed')
elif not line.endswith(b'\r\n'):
- log.info(ERROR_MSG % ('ProtocolError', 'no CRLF linebreak, "%s"' % log.escape(line)))
+ log.info(ERROR_MSG % ('ProtocolError', 'no CRLF linebreak, "%s"' % log.escape(line.decode('utf-8'))))
raise stem.ProtocolError('All lines should end with CRLF')
status_code, divider, content = line[:3], line[3:4], line[4:-2] # strip CRLF off content
@@ -781,11 +776,11 @@ def recv_message_from_bytes_io(reader: asyncio.StreamReader, arrived_at: Optiona
line = reader.readline()
raw_content += line
except socket.error as exc:
- log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content)))))
+ log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content).decode('utf-8')))))
raise stem.SocketClosed(exc)
if not line.endswith(b'\r\n'):
- log.info(ERROR_MSG % ('ProtocolError', 'CRLF linebreaks missing from a data reply, "%s"' % log.escape(bytes(raw_content))))
+ log.info(ERROR_MSG % ('ProtocolError', 'CRLF linebreaks missing from a data reply, "%s"' % log.escape(bytes(raw_content).decode('utf-8'))))
raise stem.ProtocolError('All lines should end with CRLF')
elif line == b'.\r\n':
break # data block termination
diff --git a/test/integ/connection/authentication.py b/test/integ/connection/authentication.py
index 683e555f..3eaae8d9 100644
--- a/test/integ/connection/authentication.py
+++ b/test/integ/connection/authentication.py
@@ -3,6 +3,7 @@ Integration tests for authenticating to the control socket via
stem.connection.authenticate* functions.
"""
+import asyncio
import os
import unittest
@@ -121,7 +122,10 @@ class TestAuthenticate(unittest.TestCase):
runner = test.runner.get_runner()
with await runner.get_tor_controller(False) as controller:
- await stem.connection.authenticate(controller, test.runner.CONTROL_PASSWORD, runner.get_chroot())
+ asyncio.run_coroutine_threadsafe(
+ stem.connection.authenticate(controller._wrapped_instance, test.runner.CONTROL_PASSWORD, runner.get_chroot()),
+ controller._thread_for_wrapped_class.loop,
+ ).result()
await test.runner.exercise_controller(self, controller)
@test.require.controller
diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py
index a11aba45..e4b11788 100644
--- a/test/unit/control/controller.py
+++ b/test/unit/control/controller.py
@@ -222,7 +222,7 @@ class TestControl(unittest.TestCase):
get_info_mock.side_effect = coro_func_raising_exc(InvalidArguments)
- async def get_conf_mock_side_effect(param, **kwargs):
+ async def get_conf_mock_side_effect(param, *args, **kwargs):
return {
'ControlPort': '9050',
'ControlListenAddress': ['127.0.0.1'],
@@ -236,7 +236,7 @@ class TestControl(unittest.TestCase):
# non-local addresss
- async def get_conf_mock_side_effect(param, **kwargs):
+ async def get_conf_mock_side_effect(param, *args, **kwargs):
return {
'ControlPort': '9050',
'ControlListenAddress': ['27.4.4.1'],
@@ -717,7 +717,7 @@ class TestControl(unittest.TestCase):
# check default if nothing was set
- async def get_conf_mock_side_effect(param, **kwargs):
+ async def get_conf_mock_side_effect(param, *args, **kwargs):
return {
'BandwidthRate': '1073741824',
'BandwidthBurst': '1073741824',
More information about the tor-commits mailing list