[tor-commits] [ooni-probe/master] Refactoring tcp flag test

isis at torproject.org isis at torproject.org
Tue Dec 18 05:53:46 UTC 2012


commit 1029140ad660c12f4395df40baff641b534a062a
Author: Isis Lovecruft <isis at torproject.org>
Date:   Thu Dec 6 04:29:55 2012 +0000

    Refactoring tcp flag test
---
 nettests/bridge_reachability/tcpsyn.py |  193 +++++++++++++++-----------------
 ooni/nettest.py                        |   29 -----
 ooni/oonicli.py                        |    6 +-
 ooni/reporter.py                       |    3 +-
 ooni/runner.py                         |   82 ++++++--------
 5 files changed, 127 insertions(+), 186 deletions(-)

diff --git a/nettests/bridge_reachability/tcpsyn.py b/nettests/bridge_reachability/tcpsyn.py
index 548fc0e..6a4b8db 100644
--- a/nettests/bridge_reachability/tcpsyn.py
+++ b/nettests/bridge_reachability/tcpsyn.py
@@ -2,10 +2,10 @@
 # -*- coding: utf-8 -*-
 #
 #  +-----------+
-#  | tcpsyn.py |
+#  | tcpflags.py |
 #  +-----------+
-#     Send a TCP SYN packet to a test server to check that
-#     it is reachable.
+#     Send packets with various TCP flags set to a test server 
+#     to check that it is reachable.
 #
 # @authors: Isis Lovecruft, <isis at torproject.org>
 # @version: 0.0.1-pre-alpha
@@ -36,15 +36,14 @@ class TCPFlagOptions(usage.Options):
     optParameters = [
         ['dst', 'd', None, 'Host IP to ping'],
         ['port', 'p', None, 'Host port'],
+        ['flags', 's', None, 'Comma separated flags to set [S|A|F]'],
         ['count', 'c', 3, 'Number of SYN packets to send', int],
         ['interface', 'i', None, 'Network interface to use'],
         ['hexdump', 'x', False, 'Show hexdump of responses'],
         ['pdf', 'y', False,
-         'Create pdf of visual representation of packet conversations'],
-        ['cerealize', 'z', False,
-         'Cerealize scapy objects for further scripting']]
+         'Create pdf of visual representation of packet conversations']]
 
-class TCPFlagTest(nettest.NetTestCase):
+class TCPFlagsTest(nettest.NetTestCase):
     """
     Sends only a TCP SYN packet to a host IP:PORT, and waits for either a
     SYN/ACK, a RST, or an ICMP error.
@@ -52,21 +51,19 @@ class TCPFlagTest(nettest.NetTestCase):
     TCPSynTest can take an input file containing one IP:Port pair per line, or
     the commandline switches --dst <IP> and --port <PORT> can be used.
     """
-    name         = 'TCP Flag'
+    name         = 'TCP Flags'
     author       = 'Isis Lovecruft <isis at torproject.org>'
     description  = 'A TCP SYN/ACK/FIN test to see if a host is reachable.'
-    version      = '0.0.1'
+    version      = '0.1.1'
     requiresRoot = True
 
     usageOptions = TCPFlagOptions
     inputFile    = ['file', 'f', None, 'File of list of IP:PORTs to ping']
 
-    #destinations = {}
-
-    @log.catch
     def setUp(self, *a, **kw):
         """Configure commandline parameters for TCPSynTest."""
         self.report = {}
+        self.packets = {'results': [], 'unanswered': []}
 
         if self.localOptions:
             for key, value in self.localOptions.items():
@@ -78,7 +75,6 @@ class TCPFlagTest(nettest.NetTestCase):
                 log.warn("Could not find a working network interface!")
                 log.fail(ie)
             else:
-                log.msg("Using system default interface: %s" % iface)
                 self.interface = iface
         if config.advanced.debug:
             defer.setDebugging('on')
@@ -94,13 +90,10 @@ class TCPFlagTest(nettest.NetTestCase):
         @returns: A 2-tuple containing the address and port.
         """
         dst, dport = net.checkIPandPort(addr, port)
-        #if not dst in self.destinations.keys():
         if not dst in self.report.keys():
-            #self.destinations[dst] = {'dst': dst, 'dport': [dport]}
             self.report[dst] = {'dst': dst, 'dport': [dport]}
         else:
             log.debug("Got additional port for destination.")
-            #self.destinations[dst]['dport'].append(dport)
             self.report[dst]['dport'].append(dport)
         return (dst, dport)
 
@@ -112,87 +105,20 @@ class TCPFlagTest(nettest.NetTestCase):
         """
         if self.localOptions['dst'] is not None \
                 and self.localOptions['port'] is not None:
-            log.debug("processing commandline destination input")
+            log.debug("Processing commandline destination")
             yield self.addToDestinations(self.localOptions['dst'],
                                          self.localOptions['port'])
         if input_file and os.path.isfile(input_file):
-            log.debug("processing input file %s" % input_file)
+            log.debug("Processing input file %s" % input_file)
             with open(input_file) as f:
                 for line in f.readlines():
                     if line.startswith('#'):
                         continue
                     one = line.strip()
-                    raw_ip, raw_port = one.rsplit(':', 1) ## XXX not ipv6 safe!
+                    raw_ip, raw_port = one.rsplit(':', 1)
                     yield self.addToDestinations(raw_ip, raw_port)
 
-    @log.catch
-    def createPDF(self, results):
-        pdfname = self.name + '_' + timestamp()
-        results.pdfdump(pdfname)
-        log.msg("Visual packet conversation saved to %s.pdf" % pdfname)
-
-    @staticmethod
-    def build_packets(addr, port, flags=None, count=3):
-        """Construct a list of packets to send out."""
-        packets = []
-        for x in xrange(count):
-            packets.append( IP(dst=addr)/TCP(dport=port, flags=flags) )
-        return packets
-
-    @staticmethod
-    def process_packets(packet_list):
-        """
-        If the source address of packet in :param:packet_list matches one of our input
-        destinations, then extract some of the information from it to the test report.
-
-        @param packet_list:
-            A :class:scapy.plist.PacketList
-        """
-        results, unanswered = packet_list
-
-        if self.pdf:
-            self.createPDF(results)
-
-        for (q, r) in results:
-            request_data = {'dst': q.dst,
-                            'dport': q.dport,
-                            'summary': q.summary(),
-                            'command': q.command(),
-                            'sent_time': q.time}
-            response_data = {'src': r['IP'].src,
-                             'flags': r['IP'].flags,
-                             'summary': r.summary(),
-                             'command': r.command(),
-                             'recv_time': r.time,
-                             'delay': r.time - q.time}
-            if self.hexdump:
-                request_data.update('hexdump', q.hexdump())
-                response_data.update('hexdump', r.hexdump())
-            for dest, data in self.destinations.items():
-                if data['dst'] == response_data['src']:
-                    if not 'reachable' in data:
-                        if self.hexdump:
-                            log.msg("%s\n%s" % (q.hexdump(), r.hexdump()))
-                        else:
-                            log.msg(" Received response:\n%s ==> %s"
-                                    % (q.mysummary(), r.mysummary()))
-                        data.update( {'reachable': True,
-                                      'request': request_data,
-                                      'response': response_data} )
-        return unanswered
-
-    @staticmethod
-    def process_unanswered(unanswered):
-        """Callback function to process unanswered packets."""
-        if unanswered is not None and len(unanswered) > 0:
-            log.msg("Waiting on responses from\n%s" %
-                    '\n'.join( [unans.summary() for unans in unanswered] ))
-        log.msg("Writing response packet information to report...")
-        self.report = (self.destinations)
-        return self.destinations
-
-    @log.catch
-    def tcp_flags(self, flags="S"):
+    def tcp_flags(self, flags=None):
         """
         Generate, send, and listen for responses to, a list of TCP/IP packets
         to an address and port pair taken from the current input, and a string
@@ -202,25 +128,88 @@ class TCPFlagTest(nettest.NetTestCase):
             A string representing the TCP flags to be set, i.e. "SA" or "F".
             Defaults to "S".
         """
+        def build_packets(addr, port, flags=None, count=3):
+            """Construct a list of packets to send out."""
+            packets = []
+            for x in xrange(count):
+                packets.append( IP(dst=addr)/TCP(dport=port, flags=flags) )
+            return packets
+
+        def process_packets(packet_list):
+            """
+            If the source address of packet in :param:packet_list matches one of
+            our input destinations, then extract some of the information from it
+            to the test report.
+
+            @param packet_list:
+                A :class:scapy.plist.PacketList
+            """
+            results, unanswered = packet_list
+            self.packets['results'].append([r for r in results])
+            self.packets['unanswered'].append([u for u in unanswered])
+    
+            for (q, r) in results:
+                request_data = {'dst': q.dst,
+                                'dport': q.dport,
+                                'summary': q.summary(),
+                                'command': q.command(),
+                                'hexdump': None,
+                                'sent_time': q.time}
+                response_data = {'src': r['IP'].src,
+                                 'flags': r['IP'].flags,
+                                 'summary': r.summary(),
+                                 'command': r.command(),
+                                 'hexdump': None,
+                                 'recv_time': r.time,
+                                 'delay': r.time - q.time}
+                if self.hexdump:
+                    request_data.update('hexdump', q.hexdump())
+                    response_data.update('hexdump', r.hexdump())
+
+                for dest, data in self.report.items():
+                    if data['dst'] == response_data['src']:
+                        if not 'reachable' in data:
+                            if self.hexdump:
+                                log.msg("%s\n%s" % (q.hexdump(), r.hexdump()))
+                            else:
+                                log.msg(" Received response:\n%s ==> %s"
+                                        % (q.mysummary(), r.mysummary()))
+                            data.update( {'reachable': True,
+                                          'request': request_data,
+                                          'response': response_data} )
+                            self.report[response_data['src']['data'].update(data)
+
+            if unanswered is not None and len(unanswered) > 0:
+                log.msg("Waiting on responses from\n%s" %
+                        '\n'.join( [unans.summary() for unans in unanswered] ))
+            log.msg("Writing response packet information to report...")
+ 
         (addr, port) = self.input
-        packets = self.build_packets(addr, port, str(flags), self.count)
+        packets = build_packets(addr, port, str(flags), self.count)
         d = txscapy.sr(packets, iface=self.interface)
-        d.addCallbacks(self.process_packets, log.exception)
-        d.addCallbacks(self.process_unanswered, log.exception)
+        #d.addCallbacks(process_packets, log.exception)
+        #d.addCallbacks(process_unanswered, log.exception)
+        d.addCallback(process_packets)
+        d.addErrback(process_unanswered)
+
         return d
 
-    def test_tcp_fin(self):
-        """Send a list of FIN packets to an address and port pair from inputs."""
-        return self.tcp_flags("F")
+    @log.catch
+    def createPDF(self):
+        pdfname = self.name + '_' + timestamp()
+        self.packets['results'].pdfdump(pdfname)
+        log.msg("Visual packet conversation saved to %s.pdf" % pdfname)
+
+    def test_tcp_flags(self):
+        """Send packets with given TCP flags to an address:port pair."""
+        flag_list = self.flags.split(',')
 
-    def test_tcp_syn(self):
-        """Send a list of SYN packets to an address and port pair from inputs."""
-        return self.tcp_flags("S")
+        dl = []
+        for flag in flag_list:
+            dl.append(self.tcp_flags(flag))
+        d = defer.DeferredList(dl)
 
-    def test_tcp_synack(self):
-        """Send a list of SYN/ACK packets to an address and port pair from inputs."""
-        return self.tcp_flags("SA")
+        if self.pdf:
+            d.addCallback(self.createPDF)
 
-    def test_tcp_ack(self):
-        """Send a list of SYN packets to an address and port pair from inputs."""
-        return self.tcp_flags("A")
+        return d
diff --git a/ooni/nettest.py b/ooni/nettest.py
index 1d1477d..bd6ef9b 100644
--- a/ooni/nettest.py
+++ b/ooni/nettest.py
@@ -171,35 +171,6 @@ class NetTestCase(object):
     def __repr__(self):
         return "<%s inputs=%s>" % (self.__class__, self.inputs)
 
-    def _getSkip(self):
-        return txtrutil.acquireAttribute(self._parents, 'skip', None)
-    
-    def _getSkipReason(self, method, skip):
-        return super(TestCase, self)._getSkipReason(self, method, skip)
-    
-    def _getTimeout(self):
-        """
-        Returns the timeout value set on this test. Check on the instance
-        first, the the class, then the module, then package. As soon as it
-        finds something with a timeout attribute, returns that. Returns
-        twisted.trial.util.DEFAULT_TIMEOUT_DURATION if it cannot find
-        anything. See TestCase docstring for more details.
-        """
-        try:
-            testMethod = getattr(self, methodName)
-        except:
-            testMethod = self.setUp
-        self._parents = [testMethod, self]
-        self._parents.extend(txtrutil.getPythonContainers(testMethod))
-        timeout = txtrutil.acquireAttribute(self._parents, 'timeout', 
-                                            txtrutil.DEFAULT_TIMEOUT_DURATION)
-        try:
-            return float(timeout)
-        except (ValueError, TypeError):
-            warnings.warn("'timeout' attribute needs to be a number.",
-                          category=DeprecationWarning)
-            return txtrutil.DEFAULT_TIMEOUT_DURATION
-    
     def _abort(self, reason):
         """
 
diff --git a/ooni/oonicli.py b/ooni/oonicli.py
index 3362d06..c64e445 100644
--- a/ooni/oonicli.py
+++ b/ooni/oonicli.py
@@ -81,7 +81,7 @@ class Options(usage.Options, app.ReactorSelectionMixin):
 
 def testsEnded(*arg, **kw):
     """You can place here all the post shutdown tasks."""
-    log.debug("testsEnded: Finished running all tests")
+    log.debug("Finished running all tests")
 
 def run():
     """Call me to begin testing from a file."""
@@ -133,7 +133,3 @@ def run():
     tests_d = runner.runTestCases(test_cases, options,
                                   cmd_line_options, yamloo_filename)
     tests_d.addBoth(testsEnded)
-
-    ## it appears that tests run without this?
-    #reactor.run()
-
diff --git a/ooni/reporter.py b/ooni/reporter.py
index 193d056..6fdc142 100644
--- a/ooni/reporter.py
+++ b/ooni/reporter.py
@@ -111,8 +111,7 @@ class OReporter(object):
         pass
 
     def testDone(self, test, test_name):
-        log.debug("Finished running %s" % test_name)
-        log.debug("Writing report")
+        log.debug("Calling reporter to record results")
         test_report = dict(test.report)
 
         if isinstance(test.input, packet.Packet):
diff --git a/ooni/runner.py b/ooni/runner.py
index 2b41d59..4214360 100644
--- a/ooni/runner.py
+++ b/ooni/runner.py
@@ -18,8 +18,8 @@ import itertools
 from twisted.python import reflect, usage, failure
 from twisted.internet import defer
 from twisted.trial.runner import filenameToModule
-from twisted.trial import util as txtrutil
 from twisted.trial import reporter as txreporter
+from twisted.trial import util as txtrutil
 from twisted.trial.unittest import utils as txtrutils
 from twisted.trial.unittest import SkipTest
 from twisted.internet import reactor, threads
@@ -144,37 +144,31 @@ def loadTestsAndOptions(classes, cmd_line_options):
 
     return test_cases, options
 
-def abortTestRun(test_class, warn_err_fail, test_input, oreporter):
-    """
-    Abort the entire test, and record the error, failure, or warning for why
-    it could not be completed.
+def getTimeout(test_instance, test_method):
     """
-    log.warn("Aborting remaining tests for %s" % test_name)
+    Returns the timeout value set on this test. Check on the instance first,
+    then the class, then the module, then the package. As soon as it finds
+    something with a timeout attribute, returns that. Returns
+    twisted.trial.util.DEFAULT_TIMEOUT_DURATION if it cannot find anything.
 
-def abortTestWasCalled(abort_reason, abort_what, test_class, test_instance, 
-                       test_method, test_input, oreporter):
+    See twisted.trial.unittest.TestCase docstring for more details.
     """
-    XXX
-    """
-    if not abort_what in ['class', 'method', 'input']:
-        log.warn("__test_abort__() must specify 'class', 'method', or 'input'")
-        abort_what = 'input'    
-
-    if not isinstance(abort_reason, Exception):
-        abort_reason = Exception(str(abort_reason))
-    if abort_what == 'input':
-        log.msg("%s test requested to abort for input: %s"
-                % (test_instance.name, test_input))
-        d = defer.maybeDeferred(lambda x: object)
-
-    if hasattr(test_instance, "abort_all"):
-        log.msg("%s test requested to abort all remaining inputs"
-                % test_instance.name)
-    #else:
-    #    d = defer.Deferred()
-    #    d.cancel()
-    #    d = abortTestRun(test_class, reason, test_input, oreporter)
-    
+    try:
+        testMethod = getattr(test_instance, test_method)
+    except:
+        log.debug("_getTimeout couldn't find self.methodName!")
+        return txtrutil.DEFAULT_TIMEOUT_DURATION
+    else:
+        test_instance._parents = [testMethod, test_instance]
+        test_instance._parents.extend(txtrutil.getPythonContainers(testMethod))
+        timeout = txtrutil.acquireAttribute(test_instance._parents, 'timeout', 
+                                            txtrutil.DEFAULT_TIMEOUT_DURATION)
+        try:
+            return float(timeout)
+        except (ValueError, TypeError):
+            warnings.warn("'timeout' attribute needs to be a number.",
+                          category=DeprecationWarning)
+            return txtrutil.DEFAULT_TIMEOUT_DURATION
 
 def runTestWithInput(test_class, test_method, test_input, oreporter):
     """
@@ -205,6 +199,9 @@ def runTestWithInput(test_class, test_method, test_input, oreporter):
 
     def test_error(error, test_instance, test_name):
         if isinstance(error, SkipTest):
+            if len(error.args) > 0:
+                skip_what = error.args[1]
+                # XXX we'll need to handle methods and classes
             log.info("%s" % error.message)
         else:
             log.exception(error)
@@ -212,32 +209,23 @@ def runTestWithInput(test_class, test_method, test_input, oreporter):
     test_instance = test_class()
     test_instance.input = test_input
     test_instance.report = {}
-    # XXX TODO
-    # the twisted.trial.reporter.TestResult is expected by test_timeout(),
-    # but we should eventually replace it with a stub class
+    # XXX TODO the twisted.trial.reporter.TestResult is expected by
+    # test_timeout(), but we should eventually replace it with a stub class
     test_instance._test_result = txreporter.TestResult()
     # use this to keep track of the test runtime
     test_instance._start_time = time.time()
-    test_instance.timeout = test_instance._getTimeout()
+    test_instance.timeout = getTimeout(test_instance, test_method)
     # call setups on the test
     test_instance._setUp()
     test_instance.setUp()
 
-    # check that we haven't inherited a skip
-    test_ignored = txtrutil.acquireAttribute(
+    test_skip = txtrutil.acquireAttribute(
         test_instance._parents, 'skip', None)
-    if test_ignored is not None:
+    if test_skip is not None:
         # XXX we'll need to do something more than warn
-        log.warn("test_skip is %s" % test_ignored)
-
-    # now check our instance for test_methods set to be skipped:
-    skip_list = test_instance._getSkip()
-    if skip_list is not None:
-        log.debug("%s marked these tests to be skipped: %s"
-                  % (test_instance.name, skip_list))
-    else:
-        log.debug("No tests marked as skip")
-    skip_list = [skip_list]
+        log.warn("%s marked these tests to be skipped: %s"
+                  % (test_instance.name, test_skip))
+    skip_list = [test_skip]
 
     if not test_method in skip_list:
         test = getattr(test_instance, test_method)
@@ -249,10 +237,8 @@ def runTestWithInput(test_class, test_method, test_input, oreporter):
     
         d.addCallback(test_done, test_instance, test_method)
         d.addErrback(test_error, test_instance, test_method)
-        log.debug("returning %s input" % test_method)
     else:
         d = defer.Deferred()
-
     return d
 
 def runTestWithInputUnit(test_class, test_method, input_unit, oreporter):





More information about the tor-commits mailing list