[tor-commits] [stem/master] Replace "stem/socket.py" with its asynchronous implementation
atagar at torproject.org
atagar at torproject.org
Thu Jul 16 01:28:58 UTC 2020
commit 08d1c08dc39b9f2535fb185f4339496b8e3ea2de
Author: Illia Volochii <illia.volochii at gmail.com>
Date: Sun Apr 12 18:47:24 2020 +0300
Replace "stem/socket.py" with its asynchronous implementation
---
stem/async_socket.py | 717 ---------------------------------------------------
stem/socket.py | 177 +++++--------
2 files changed, 65 insertions(+), 829 deletions(-)
diff --git a/stem/async_socket.py b/stem/async_socket.py
deleted file mode 100644
index 512e6bde..00000000
--- a/stem/async_socket.py
+++ /dev/null
@@ -1,717 +0,0 @@
-# Copyright 2011-2020, Damian Johnson and The Tor Project
-# See LICENSE for licensing information
-
-"""
-Supports communication with sockets speaking Tor protocols. This
-allows us to send messages as basic strings, and receive responses as
-:class:`~stem.response.ControlMessage` instances.
-
-**This module only consists of low level components, and is not intended for
-users.** See our `tutorials <../tutorials.html>`_ and `Control Module
-<control.html>`_ if you're new to Stem and looking to get started.
-
-With that aside, these can still be used for raw socket communication with
-Tor...
-
-::
-
- import stem
- import stem.connection
- import stem.socket
-
- if __name__ == '__main__':
- try:
- control_socket = stem.socket.ControlPort(port = 9051)
- stem.connection.authenticate(control_socket)
- except stem.SocketError as exc:
- print 'Unable to connect to tor on port 9051: %s' % exc
- sys.exit(1)
- except stem.connection.AuthenticationFailure as exc:
- print 'Unable to authenticate: %s' % exc
- sys.exit(1)
-
- print "Issuing 'GETINFO version' query...\\n"
- control_socket.send('GETINFO version')
- print control_socket.recv()
-
-::
-
- % python example.py
- Issuing 'GETINFO version' query...
-
- version=0.2.4.10-alpha-dev (git-8be6058d8f31e578)
- OK
-
-**Module Overview:**
-
-::
-
- BaseSocket - Thread safe socket.
- |- RelaySocket - Socket for a relay's ORPort.
- | |- send - sends a message to the socket
- | +- recv - receives a response from the socket
- |
- |- ControlSocket - Socket wrapper that speaks the tor control protocol.
- | |- ControlPort - Control connection via a port.
- | |- ControlSocketFile - Control connection via a local file socket.
- | |
- | |- send - sends a message to the socket
- | +- recv - receives a ControlMessage from the socket
- |
- |- is_alive - reports if the socket is known to be closed
- |- is_localhost - returns if the socket is for the local system or not
- |- connection_time - timestamp when socket last connected or disconnected
- |- connect - connects a new socket
- |- close - shuts down the socket
- +- __enter__ / __exit__ - manages socket connection
-
- send_message - Writes a message to a control socket.
- recv_message - Reads a ControlMessage from a control socket.
- send_formatting - Performs the formatting expected from sent messages.
-"""
-
-from __future__ import absolute_import
-
-import asyncio
-import re
-import socket
-import ssl
-import threading
-import time
-
-import stem.response
-import stem.util.str_tools
-
-from stem.util import log
-
-MESSAGE_PREFIX = re.compile(b'^[a-zA-Z0-9]{3}[-+ ]')
-ERROR_MSG = 'Error while receiving a control message (%s): %s'
-
-# lines to limit our trace logging to, you can disable this by setting it to None
-
-TRUNCATE_LOGS = 10
-
-
-class BaseSocket(object):
- """
- Thread safe socket, providing common socket functionality.
- """
-
- def __init__(self):
- self._reader = None
- self._writer = None
- self._is_alive = False
- self._connection_time = 0.0 # time when we last connected or disconnected
-
- # Tracks sending and receiving separately. This should be safe, and doing
- # so prevents deadlock where we block writes because we're waiting to read
- # a message that isn't coming.
-
- self._send_lock = threading.RLock()
- self._recv_lock = threading.RLock()
-
- def is_alive(self):
- """
- Checks if the socket is known to be closed. We won't be aware if it is
- until we either use it or have explicitily shut it down.
-
- In practice a socket derived from a port knows about its disconnection
- after failing to receive data, whereas socket file derived connections
- know after either sending or receiving data.
-
- This means that to have reliable detection for when we're disconnected
- you need to continually pull from the socket (which is part of what the
- :class:`~stem.control.BaseController` does).
-
- :returns: **bool** that's **True** if our socket is connected and **False**
- otherwise
- """
-
- return self._is_alive
-
- def is_localhost(self):
- """
- Returns if the connection is for the local system or not.
-
- :returns: **bool** that's **True** if the connection is for the local host
- and **False** otherwise
- """
-
- return False
-
- def connection_time(self):
- """
- Provides the unix timestamp for when our socket was either connected or
- disconnected. That is to say, the time we connected if we're currently
- connected and the time we disconnected if we're not connected.
-
- .. versionadded:: 1.3.0
-
- :returns: **float** for when we last connected or disconnected, zero if
- we've never connected
- """
-
- return self._connection_time
-
- async def connect(self):
- """
- Connects to a new socket, closing our previous one if we're already
- attached.
-
- :raises: :class:`stem.SocketError` if unable to make a socket
- """
-
- with self._send_lock:
- # Closes the socket if we're currently attached to one. Once we're no
- # longer alive it'll be safe to acquire the recv lock because recv()
- # calls no longer block (raising SocketClosed instead).
-
- if self.is_alive():
- await self.close()
-
- with self._recv_lock:
- self._reader, self._writer = await self._open_connection()
- self._is_alive = True
- self._connection_time = time.time()
-
- # It's possible for this to have a transient failure...
- # SocketError: [Errno 4] Interrupted system call
- #
- # It's safe to retry, so give it another try if it fails.
-
- try:
- await self._connect()
- except stem.SocketError:
- await self._connect() # single retry
-
- async def close(self):
- """
- Shuts down the socket. If it's already closed then this is a no-op.
- """
-
- with self._send_lock:
- # Function is idempotent with one exception: we notify _close() if this
- # is causing our is_alive() state to change.
-
- is_change = self.is_alive()
-
- if self._writer:
- self._writer.close()
- # `StreamWriter.wait_closed` was added in Python 3.7.
- if hasattr(self._writer, 'wait_closed'):
- await self._writer.wait_closed()
-
- self._reader = None
- self._writer = None
- self._is_alive = False
- self._connection_time = time.time()
-
- if is_change:
- await self._close()
-
- async def _send(self, message, handler):
- """
- Send message in a thread safe manner. Handler is expected to be of the form...
-
- ::
-
- my_handler(socket, socket_file, message)
- """
-
- with self._send_lock:
- try:
- if not self.is_alive():
- raise stem.SocketClosed()
-
- await handler(self._writer, message)
- except stem.SocketClosed:
- # if send_message raises a SocketClosed then we should properly shut
- # everything down
-
- if self.is_alive():
- await self.close()
-
- raise
-
- async def _recv(self, handler):
- """
- Receives a message in a thread safe manner. Handler is expected to be of the form...
-
- ::
-
- my_handler(socket, socket_file)
- """
-
- with self._recv_lock:
- try:
- # makes a temporary reference to the _reader because connect()
- # and close() may set or unset it
-
- my_reader = self._reader
-
- if not my_reader:
- raise stem.SocketClosed()
-
- return await handler(my_reader)
- except stem.SocketClosed:
- # If recv_message raises a SocketClosed then we should properly shut
- # everything down. However, there's a couple cases where this will
- # cause deadlock...
- #
- # * This SocketClosed was *caused by* a close() call, which is joining
- # on our thread.
- #
- # * A send() call that's currently in flight is about to call close(),
- # also attempting to join on us.
- #
- # To resolve this we make a non-blocking call to acquire the send lock.
- # If we get it then great, we can close safely. If not then one of the
- # above are in progress and we leave the close to them.
-
- if self.is_alive():
- if self._send_lock.acquire(False):
- await self.close()
- self._send_lock.release()
-
- raise
-
- def _get_send_lock(self):
- """
- The send lock is useful to classes that interact with us at a deep level
- because it's used to lock :func:`stem.socket.ControlSocket.connect` /
- :func:`stem.socket.BaseSocket.close`, and by extension our
- :func:`stem.socket.BaseSocket.is_alive` state changes.
-
- :returns: **threading.RLock** that governs sending messages to our socket
- and state changes
- """
-
- return self._send_lock
-
- async def __aenter__(self):
- return self
-
- async def __aexit__(self, exit_type, value, traceback):
- await self.close()
-
- async def _connect(self):
- """
- Connection callback that can be overwritten by subclasses and wrappers.
- """
-
- pass
-
- async def _close(self):
- """
- Disconnection callback that can be overwritten by subclasses and wrappers.
- """
-
- pass
-
- async def _open_connection(self):
- raise NotImplementedError('Unsupported Operation: this should be implemented by the BaseSocket subclass')
-
-
-class RelaySocket(BaseSocket):
- """
- `Link-level connection
- <https://gitweb.torproject.org/torspec.git/tree/tor-spec.txt>`_ to a Tor
- relay.
-
- .. versionadded:: 1.7.0
-
- :var str address: address our socket connects to
- :var int port: ORPort our socket connects to
- """
-
- def __init__(self, address = '127.0.0.1', port = 9050):
- """
- RelaySocket constructor.
-
- :param str address: ip address of the relay
- :param int port: orport of the relay
- """
-
- super(RelaySocket, self).__init__()
- self.address = address
- self.port = port
-
- async def send(self, message):
- """
- Sends a message to the relay's ORPort.
-
- :param str message: message to be formatted and sent to the socket
-
- :raises:
- * :class:`stem.SocketError` if a problem arises in using the socket
- * :class:`stem.SocketClosed` if the socket is known to be shut down
- """
-
- await self._send(message, _write_to_socket)
-
- async def recv(self, timeout = None):
- """
- Receives a message from the relay.
-
- :param float timeout: maxiumum number of seconds to await a response, this
- blocks indefinitely if **None**
-
- :returns: bytes for the message received
-
- :raises:
- * :class:`stem.ProtocolError` the content from the socket is malformed
- * :class:`stem.SocketClosed` if the socket closes before we receive a complete message
- """
-
- async def wrapped_recv(reader):
- read_coroutine = reader.read(1024)
- if timeout is None:
- return await read_coroutine
- else:
- try:
- return await asyncio.wait_for(read_coroutine, timeout)
- except (asyncio.TimeoutError, ssl.SSLError, ssl.SSLWantReadError):
- return None
-
- return await self._recv(wrapped_recv)
-
- def is_localhost(self):
- return self.address == '127.0.0.1'
-
- async def _open_connection(self):
- try:
- return await asyncio.open_connection(self.address, self.port, ssl=ssl.SSLContext())
- except socket.error as exc:
- raise stem.SocketError(exc)
-
-
-class ControlSocket(BaseSocket):
- """
- Wrapper for a socket connection that speaks the Tor control protocol. To the
- better part this transparently handles the formatting for sending and
- receiving complete messages.
-
- Callers should not instantiate this class directly, but rather use subclasses
- which are expected to implement the **_make_socket()** method.
- """
-
- def __init__(self):
- super(ControlSocket, self).__init__()
-
- async def send(self, message):
- """
- Formats and sends a message to the control socket. For more information see
- the :func:`~stem.socket.send_message` function.
-
- :param str message: message to be formatted and sent to the socket
-
- :raises:
- * :class:`stem.SocketError` if a problem arises in using the socket
- * :class:`stem.SocketClosed` if the socket is known to be shut down
- """
-
- await self._send(message, send_message)
-
- async def recv(self):
- """
- Receives a message from the control socket, blocking until we've received
- one. For more information see the :func:`~stem.socket.recv_message` function.
-
- :returns: :class:`~stem.response.ControlMessage` for the message received
-
- :raises:
- * :class:`stem.ProtocolError` the content from the socket is malformed
- * :class:`stem.SocketClosed` if the socket closes before we receive a complete message
- """
-
- return await self._recv(recv_message)
-
-
-class ControlPort(ControlSocket):
- """
- Control connection to tor. For more information see tor's ControlPort torrc
- option.
-
- :var str address: address our socket connects to
- :var int port: ControlPort our socket connects to
- """
-
- def __init__(self, address = '127.0.0.1', port = 9051):
- """
- ControlPort constructor.
-
- :param str address: ip address of the controller
- :param int port: port number of the controller
- """
-
- super(ControlPort, self).__init__()
- self.address = address
- self.port = port
-
- def is_localhost(self):
- return self.address == '127.0.0.1'
-
- async def _open_connection(self):
- try:
- return await asyncio.open_connection(self.address, self.port)
- except socket.error as exc:
- raise stem.SocketError(exc)
-
-
-class ControlSocketFile(ControlSocket):
- """
- Control connection to tor. For more information see tor's ControlSocket torrc
- option.
-
- :var str path: filesystem path of the socket we connect to
- """
-
- def __init__(self, path = '/var/run/tor/control'):
- """
- ControlSocketFile constructor.
-
- :param str socket_path: path where the control socket is located
- """
-
- super(ControlSocketFile, self).__init__()
- self.path = path
-
- def is_localhost(self):
- return True
-
- async def _open_connection(self):
- try:
- return await asyncio.open_unix_connection(self.path)
- except socket.error as exc:
- raise stem.SocketError(exc)
-
-
-async def send_message(writer, message, raw = False):
- """
- Sends a message to the control socket, adding the expected formatting for
- single verses multi-line messages. Neither message type should contain an
- ending newline (if so it'll be treated as a multi-line message with a blank
- line at the end). If the message doesn't contain a newline then it's sent
- as...
-
- ::
-
- <message>\\r\\n
-
- and if it does contain newlines then it's split on ``\\n`` and sent as...
-
- ::
-
- +<line 1>\\r\\n
- <line 2>\\r\\n
- <line 3>\\r\\n
- .\\r\\n
-
- :param file control_file: file derived from the control socket (see the
- socket's makefile() method for more information)
- :param str message: message to be sent on the control socket
- :param bool raw: leaves the message formatting untouched, passing it to the
- socket as-is
-
- :raises:
- * :class:`stem.SocketError` if a problem arises in using the socket
- * :class:`stem.SocketClosed` if the socket is known to be shut down
- """
-
- if not raw:
- message = send_formatting(message)
-
- await _write_to_socket(writer, message)
-
- if log.is_tracing():
- log_message = message.replace('\r\n', '\n').rstrip()
- msg_div = '\n' if '\n' in log_message else ' '
- log.trace('Sent to tor:%s%s' % (msg_div, log_message))
-
-
-async def _write_to_socket(writer, message):
- try:
- writer.write(stem.util.str_tools._to_bytes(message))
- await writer.drain()
- except socket.error as exc:
- log.info('Failed to send: %s' % exc)
-
- # When sending there doesn't seem to be a reliable method for
- # distinguishing between failures from a disconnect verses other things.
- # Just accounting for known disconnection responses.
-
- if str(exc) == '[Errno 32] Broken pipe':
- raise stem.SocketClosed(exc)
- else:
- raise stem.SocketError(exc)
- except AttributeError:
- # if the control_file has been closed then flush will receive:
- # AttributeError: 'NoneType' object has no attribute 'sendall'
-
- log.info('Failed to send: file has been closed')
- raise stem.SocketClosed('file has been closed')
-
-
-async def recv_message(reader, arrived_at = None):
- """
- Pulls from a control socket until we either have a complete message or
- encounter a problem.
-
- :param file control_file: file derived from the control socket (see the
- socket's makefile() method for more information)
-
- :returns: :class:`~stem.response.ControlMessage` read from the socket
-
- :raises:
- * :class:`stem.ProtocolError` the content from the socket is malformed
- * :class:`stem.SocketClosed` if the socket closes before we receive
- a complete message
- """
-
- parsed_content, raw_content, first_line = None, None, True
-
- while True:
- try:
- line = await reader.readline()
- except AttributeError:
- # if the control_file has been closed then we will receive:
- # AttributeError: 'NoneType' object has no attribute 'recv'
-
- log.info(ERROR_MSG % ('SocketClosed', 'socket file has been closed'))
- raise stem.SocketClosed('socket file has been closed')
- except (OSError, ValueError) as exc:
- # when disconnected this errors with...
- #
- # * ValueError: I/O operation on closed file
- # * OSError: [Errno 107] Transport endpoint is not connected
- # * OSError: [Errno 9] Bad file descriptor
-
- log.info(ERROR_MSG % ('SocketClosed', 'received exception "%s"' % exc))
- raise stem.SocketClosed(exc)
-
- # Parses the tor control lines. These are of the form...
- # <status code><divider><content>\r\n
-
- if not line:
- # if the socket is disconnected then the readline() method will provide
- # empty content
-
- log.info(ERROR_MSG % ('SocketClosed', 'empty socket content'))
- raise stem.SocketClosed('Received empty socket content.')
- elif not MESSAGE_PREFIX.match(line):
- log.info(ERROR_MSG % ('ProtocolError', 'malformed status code/divider, "%s"' % log.escape(line)))
- raise stem.ProtocolError('Badly formatted reply line: beginning is malformed')
- elif not line.endswith(b'\r\n'):
- log.info(ERROR_MSG % ('ProtocolError', 'no CRLF linebreak, "%s"' % log.escape(line)))
- raise stem.ProtocolError('All lines should end with CRLF')
-
- status_code, divider, content = line[:3], line[3:4], line[4:-2] # strip CRLF off content
-
- status_code = stem.util.str_tools._to_unicode(status_code)
- divider = stem.util.str_tools._to_unicode(divider)
-
- # Most controller responses are single lines, in which case we don't need
- # so much overhead.
-
- if first_line:
- if divider == ' ':
- _log_trace(line)
- return stem.response.ControlMessage([(status_code, divider, content)], line, arrived_at = arrived_at)
- else:
- parsed_content, raw_content, first_line = [], bytearray(), False
-
- raw_content += line
-
- if divider == '-':
- # mid-reply line, keep pulling for more content
- parsed_content.append((status_code, divider, content))
- elif divider == ' ':
- # end of the message, return the message
- parsed_content.append((status_code, divider, content))
- _log_trace(bytes(raw_content))
- return stem.response.ControlMessage(parsed_content, bytes(raw_content), arrived_at = arrived_at)
- elif divider == '+':
- # data entry, all of the following lines belong to the content until we
- # get a line with just a period
-
- content_block = bytearray(content)
-
- while True:
- try:
- line = await reader.readline()
- raw_content += line
- except socket.error as exc:
- log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content)))))
- raise stem.SocketClosed(exc)
-
- if not line.endswith(b'\r\n'):
- log.info(ERROR_MSG % ('ProtocolError', 'CRLF linebreaks missing from a data reply, "%s"' % log.escape(bytes(raw_content))))
- raise stem.ProtocolError('All lines should end with CRLF')
- elif line == b'.\r\n':
- break # data block termination
-
- line = line[:-2] # strips off the CRLF
-
- # lines starting with a period are escaped by a second period (as per
- # section 2.4 of the control-spec)
-
- if line.startswith(b'..'):
- line = line[1:]
-
- content_block += b'\n' + line
-
- # joins the content using a newline rather than CRLF separator (more
- # conventional for multi-line string content outside the windows world)
-
- parsed_content.append((status_code, divider, bytes(content_block)))
- else:
- # this should never be reached due to the prefix regex, but might as well
- # be safe...
-
- log.warn(ERROR_MSG % ('ProtocolError', "\"%s\" isn't a recognized divider type" % divider))
- raise stem.ProtocolError("Unrecognized divider type '%s': %s" % (divider, stem.util.str_tools._to_unicode(line)))
-
-
-def send_formatting(message):
- """
- Performs the formatting expected from sent control messages. For more
- information see the :func:`~stem.socket.send_message` function.
-
- :param str message: message to be formatted
-
- :returns: **str** of the message wrapped by the formatting expected from
- controllers
- """
-
- # From control-spec section 2.2...
- # Command = Keyword OptArguments CRLF / "+" Keyword OptArguments CRLF CmdData
- # Keyword = 1*ALPHA
- # OptArguments = [ SP *(SP / VCHAR) ]
- #
- # A command is either a single line containing a Keyword and arguments, or a
- # multiline command whose initial keyword begins with +, and whose data
- # section ends with a single "." on a line of its own.
-
- # if we already have \r\n entries then standardize on \n to start with
- message = message.replace('\r\n', '\n')
-
- if '\n' in message:
- return '+%s\r\n.\r\n' % message.replace('\n', '\r\n')
- else:
- return message + '\r\n'
-
-
-def _log_trace(response):
- if not log.is_tracing():
- return
-
- log_message = stem.util.str_tools._to_unicode(response.replace(b'\r\n', b'\n').rstrip())
- log_message_lines = log_message.split('\n')
-
- if TRUNCATE_LOGS and len(log_message_lines) > TRUNCATE_LOGS:
- log_message = '\n'.join(log_message_lines[:TRUNCATE_LOGS] + ['... %i more lines...' % (len(log_message_lines) - TRUNCATE_LOGS)])
-
- if len(log_message_lines) > 2:
- log.trace('Received from tor:\n%s' % log_message)
- else:
- log.trace('Received from tor: %s' % log_message.replace('\n', '\\n'))
diff --git a/stem/socket.py b/stem/socket.py
index 81019cf2..dd123751 100644
--- a/stem/socket.py
+++ b/stem/socket.py
@@ -69,6 +69,7 @@ Tor...
send_formatting - Performs the formatting expected from sent messages.
"""
+import asyncio
import re
import socket
import ssl
@@ -96,8 +97,8 @@ class BaseSocket(object):
"""
def __init__(self) -> None:
- self._socket = None # type: Optional[Union[socket.socket, ssl.SSLSocket]]
- self._socket_file = None # type: Optional[BinaryIO]
+ self._reader = None
+ self._writer = None
self._is_alive = False
self._connection_time = 0.0 # time when we last connected or disconnected
@@ -151,7 +152,7 @@ class BaseSocket(object):
return self._connection_time
- def connect(self) -> None:
+ async def connect(self) -> None:
"""
Connects to a new socket, closing our previous one if we're already
attached.
@@ -165,11 +166,10 @@ class BaseSocket(object):
# calls no longer block (raising SocketClosed instead).
if self.is_alive():
- self.close()
+ await self.close()
with self._recv_lock:
- self._socket = self._make_socket()
- self._socket_file = self._socket.makefile(mode = 'rwb')
+ self._reader, self._writer = await self._open_connection()
self._is_alive = True
self._connection_time = time.time()
@@ -179,11 +179,11 @@ class BaseSocket(object):
# It's safe to retry, so give it another try if it fails.
try:
- self._connect()
+ await self._connect()
except stem.SocketError:
- self._connect() # single retry
+ await self._connect() # single retry
- def close(self) -> None:
+ async def close(self) -> None:
"""
Shuts down the socket. If it's already closed then this is a no-op.
"""
@@ -194,32 +194,21 @@ class BaseSocket(object):
is_change = self.is_alive()
- if self._socket:
- # if we haven't yet established a connection then this raises an error
- # socket.error: [Errno 107] Transport endpoint is not connected
+ if self._writer:
+ self._writer.close()
+ # `StreamWriter.wait_closed` was added in Python 3.7.
+ if hasattr(self._writer, 'wait_closed'):
+ await self._writer.wait_closed()
- try:
- self._socket.shutdown(socket.SHUT_RDWR)
- except socket.error:
- pass
-
- self._socket.close()
-
- if self._socket_file:
- try:
- self._socket_file.close()
- except BrokenPipeError:
- pass
-
- self._socket = None
- self._socket_file = None
+ self._reader = None
+ self._writer = None
self._is_alive = False
self._connection_time = time.time()
if is_change:
- self._close()
+ await self._close()
- def _send(self, message: Union[bytes, str], handler: Callable[[Union[socket.socket, ssl.SSLSocket], BinaryIO, Union[bytes, str]], None]) -> None:
+ async def _send(self, message: Union[bytes, str], handler: Callable[[Union[socket.socket, ssl.SSLSocket], BinaryIO, Union[bytes, str]], None]) -> None:
"""
Send message in a thread safe manner. Handler is expected to be of the form...
@@ -233,25 +222,25 @@ class BaseSocket(object):
if not self.is_alive():
raise stem.SocketClosed()
- handler(self._socket, self._socket_file, message)
+ await handler(self._writer, message)
except stem.SocketClosed:
# if send_message raises a SocketClosed then we should properly shut
# everything down
if self.is_alive():
- self.close()
+ await self.close()
raise
@overload
- def _recv(self, handler: Callable[[ssl.SSLSocket, BinaryIO], bytes]) -> bytes:
+ async def _recv(self, handler: Callable[[ssl.SSLSocket, BinaryIO], bytes]) -> bytes:
...
@overload
- def _recv(self, handler: Callable[[socket.socket, BinaryIO], stem.response.ControlMessage]) -> stem.response.ControlMessage:
+ async def _recv(self, handler: Callable[[socket.socket, BinaryIO], stem.response.ControlMessage]) -> stem.response.ControlMessage:
...
- def _recv(self, handler):
+ async def _recv(self, handler):
"""
Receives a message in a thread safe manner. Handler is expected to be of the form...
@@ -262,15 +251,15 @@ class BaseSocket(object):
with self._recv_lock:
try:
- # makes a temporary reference to the _socket_file because connect()
+ # makes a temporary reference to the _reader because connect()
# and close() may set or unset it
- my_socket, my_socket_file = self._socket, self._socket_file
+ my_reader = self._reader
- if not my_socket or not my_socket_file:
+ if not my_reader:
raise stem.SocketClosed()
- return handler(my_socket, my_socket_file)
+ return await handler(my_reader)
except stem.SocketClosed:
# If recv_message raises a SocketClosed then we should properly shut
# everything down. However, there's a couple cases where this will
@@ -288,7 +277,7 @@ class BaseSocket(object):
if self.is_alive():
if self._send_lock.acquire(False):
- self.close()
+ await self.close()
self._send_lock.release()
raise
@@ -306,35 +295,31 @@ class BaseSocket(object):
return self._send_lock
- def __enter__(self) -> 'stem.socket.BaseSocket':
+ async def __aenter__(self) -> 'stem.socket.BaseSocket':
return self
- def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]):
+ async def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]):
- self.close()
+ await self.close()
- def _connect(self) -> None:
+ async def _connect(self) -> None:
"""
Connection callback that can be overwritten by subclasses and wrappers.
"""
pass
- def _close(self) -> None:
+ async def _close(self) -> None:
"""
Disconnection callback that can be overwritten by subclasses and wrappers.
"""
pass
- def _make_socket(self) -> Union[socket.socket, ssl.SSLSocket]:
+ async def _open_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
"""
Constructs and connects new socket. This is implemented by subclasses.
- :returns: **socket.socket** for our configuration
-
- :raises:
- * :class:`stem.SocketError` if unable to make a socket
- * **NotImplementedError** if not implemented by a subclass
+ :returns: **tuple** with our reader and writer streams
"""
raise NotImplementedError('Unsupported Operation: this should be implemented by the BaseSocket subclass')
@@ -352,26 +337,19 @@ class RelaySocket(BaseSocket):
:var int port: ORPort our socket connects to
"""
- def __init__(self, address: str = '127.0.0.1', port: int = 9050, connect: bool = True) -> None:
+ def __init__(self, address: str = '127.0.0.1', port: int = 9050) -> None:
"""
RelaySocket constructor.
:param address: ip address of the relay
:param port: orport of the relay
- :param connect: connects to the socket if True, leaves it unconnected otherwise
-
- :raises: :class:`stem.SocketError` if connect is **True** and we're
- unable to establish a connection
"""
super(RelaySocket, self).__init__()
self.address = address
self.port = port
- if connect:
- self.connect()
-
- def send(self, message: Union[str, bytes]) -> None:
+ async def send(self, message: Union[str, bytes]) -> None:
"""
Sends a message to the relay's ORPort.
@@ -382,9 +360,9 @@ class RelaySocket(BaseSocket):
* :class:`stem.SocketClosed` if the socket is known to be shut down
"""
- self._send(message, lambda s, sf, msg: _write_to_socket(sf, msg))
+ await self._send(message, _write_to_socket)
- def recv(self, timeout: Optional[float] = None) -> bytes:
+ async def recv(self, timeout: Optional[float] = None) -> bytes:
"""
Receives a message from the relay.
@@ -398,30 +376,24 @@ class RelaySocket(BaseSocket):
* :class:`stem.SocketClosed` if the socket closes before we receive a complete message
"""
- def wrapped_recv(s: ssl.SSLSocket, sf: BinaryIO) -> bytes:
+ async def wrapped_recv(reader: asyncio.StreamReader) -> bytes:
+ read_coroutine = reader.read(1024)
if timeout is None:
- return s.recv(1024)
+ return await read_coroutine
else:
- s.setblocking(False)
- s.settimeout(timeout)
-
try:
- return s.recv(1024)
- except (socket.timeout, ssl.SSLError, ssl.SSLWantReadError):
+ return await asyncio.wait_for(read_coroutine, timeout)
+ except (asyncio.TimeoutError, ssl.SSLError, ssl.SSLWantReadError):
return None
- finally:
- s.setblocking(True)
- return self._recv(wrapped_recv)
+ return await self._recv(wrapped_recv)
def is_localhost(self) -> bool:
return self.address == '127.0.0.1'
- def _make_socket(self) -> ssl.SSLSocket:
+ async def _open_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
try:
- relay_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- relay_socket.connect((self.address, self.port))
- return ssl.wrap_socket(relay_socket)
+ return await asyncio.open_connection(self.address, self.port, ssl=ssl.SSLContext())
except socket.error as exc:
raise stem.SocketError(exc)
@@ -439,7 +411,7 @@ class ControlSocket(BaseSocket):
def __init__(self) -> None:
super(ControlSocket, self).__init__()
- def send(self, message: Union[bytes, str]) -> None:
+ async def send(self, message: Union[bytes, str]) -> None:
"""
Formats and sends a message to the control socket. For more information see
the :func:`~stem.socket.send_message` function.
@@ -451,9 +423,9 @@ class ControlSocket(BaseSocket):
* :class:`stem.SocketClosed` if the socket is known to be shut down
"""
- self._send(message, lambda s, sf, msg: send_message(sf, msg))
+ await self._send(message, send_message)
- def recv(self) -> stem.response.ControlMessage:
+ async def recv(self) -> stem.response.ControlMessage:
"""
Receives a message from the control socket, blocking until we've received
one. For more information see the :func:`~stem.socket.recv_message` function.
@@ -465,7 +437,7 @@ class ControlSocket(BaseSocket):
* :class:`stem.SocketClosed` if the socket closes before we receive a complete message
"""
- return self._recv(lambda s, sf: recv_message(sf))
+ return await self._recv(recv_message)
class ControlPort(ControlSocket):
@@ -477,33 +449,24 @@ class ControlPort(ControlSocket):
:var int port: ControlPort our socket connects to
"""
- def __init__(self, address: str = '127.0.0.1', port: int = 9051, connect: bool = True) -> None:
+ def __init__(self, address: str = '127.0.0.1', port: int = 9051) -> None:
"""
ControlPort constructor.
:param address: ip address of the controller
:param port: port number of the controller
- :param connect: connects to the socket if True, leaves it unconnected otherwise
-
- :raises: :class:`stem.SocketError` if connect is **True** and we're
- unable to establish a connection
"""
super(ControlPort, self).__init__()
self.address = address
self.port = port
- if connect:
- self.connect()
-
def is_localhost(self) -> bool:
return self.address == '127.0.0.1'
- def _make_socket(self) -> socket.socket:
+ async def _open_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
try:
- control_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- control_socket.connect((self.address, self.port))
- return control_socket
+ return await asyncio.open_connection(self.address, self.port)
except socket.error as exc:
raise stem.SocketError(exc)
@@ -516,36 +479,27 @@ class ControlSocketFile(ControlSocket):
:var str path: filesystem path of the socket we connect to
"""
- def __init__(self, path: str = '/var/run/tor/control', connect: bool = True) -> None:
+ def __init__(self, path: str = '/var/run/tor/control') -> None:
"""
ControlSocketFile constructor.
:param path: path where the control socket is located
- :param connect: connects to the socket if True, leaves it unconnected otherwise
-
- :raises: :class:`stem.SocketError` if connect is **True** and we're
- unable to establish a connection
"""
super(ControlSocketFile, self).__init__()
self.path = path
- if connect:
- self.connect()
-
def is_localhost(self) -> bool:
return True
- def _make_socket(self) -> socket.socket:
+ async def _open_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
try:
- control_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
- control_socket.connect(self.path)
- return control_socket
+ return await asyncio.open_unix_connection(self.path)
except socket.error as exc:
raise stem.SocketError(exc)
-def send_message(control_file: BinaryIO, message: Union[bytes, str], raw: bool = False) -> None:
+async def send_message(writer: asyncio.StreamWriter, message: Union[bytes, str], raw: bool = False) -> None:
"""
Sends a message to the control socket, adding the expected formatting for
single versus multi-line messages. Neither message type should contain an
@@ -566,8 +520,7 @@ def send_message(control_file: BinaryIO, message: Union[bytes, str], raw: bool =
<line 3>\\r\\n
.\\r\\n
- :param control_file: file derived from the control socket (see the
- socket's makefile() method for more information)
+ :param writer: stream derived from the control socket
:param message: message to be sent on the control socket
:param raw: leaves the message formatting untouched, passing it to the
socket as-is
@@ -582,7 +535,7 @@ def send_message(control_file: BinaryIO, message: Union[bytes, str], raw: bool =
if not raw:
message = send_formatting(message)
- _write_to_socket(control_file, message)
+ await _write_to_socket(writer, message)
if log.is_tracing():
log_message = message.replace('\r\n', '\n').rstrip()
@@ -590,10 +543,10 @@ def send_message(control_file: BinaryIO, message: Union[bytes, str], raw: bool =
log.trace('Sent to tor:%s%s' % (msg_div, log_message))
-def _write_to_socket(socket_file: BinaryIO, message: Union[str, bytes]) -> None:
+async def _write_to_socket(writer: asyncio.StreamWriter, message: Union[str, bytes]) -> None:
try:
- socket_file.write(stem.util.str_tools._to_bytes(message))
- socket_file.flush()
+ writer.write(stem.util.str_tools._to_bytes(message))
+ await writer.drain()
except socket.error as exc:
log.info('Failed to send: %s' % exc)
@@ -613,7 +566,7 @@ def _write_to_socket(socket_file: BinaryIO, message: Union[str, bytes]) -> None:
raise stem.SocketClosed('file has been closed')
-def recv_message(control_file: BinaryIO, arrived_at: Optional[float] = None) -> stem.response.ControlMessage:
+async def recv_message(reader: asyncio.StreamReader, arrived_at: Optional[float] = None) -> stem.response.ControlMessage:
"""
Pulls from a control socket until we either have a complete message or
encounter a problem.
@@ -635,7 +588,7 @@ def recv_message(control_file: BinaryIO, arrived_at: Optional[float] = None) ->
while True:
try:
- line = control_file.readline()
+ line = await reader.readline()
except AttributeError:
# if the control_file has been closed then we will receive:
# AttributeError: 'NoneType' object has no attribute 'recv'
@@ -701,7 +654,7 @@ def recv_message(control_file: BinaryIO, arrived_at: Optional[float] = None) ->
while True:
try:
- line = control_file.readline()
+ line = await reader.readline()
raw_content += line
except socket.error as exc:
log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content).decode('utf-8')))))
More information about the tor-commits
mailing list