[tor-commits] [ooni-probe/master] Refactoring of NetTestLoader

art at torproject.org art at torproject.org
Tue Apr 30 13:01:44 UTC 2013


commit c17906c3c54feecdb05bbfb64c4293775f953f4e
Author: Arturo Filastò <art at fuffa.org>
Date:   Wed Feb 27 17:05:39 2013 +0100

    Refactoring of NetTestLoader
    
    * Make it clear that calling loadNetTestString with untrusted input can be extremely dangerous, since it exec()s its argument
    * Fix a bug that was spotted through unit testing
---
 ooni/nettest.py |   86 +++++++++++++++++++++++++++++-------------------------
 1 files changed, 46 insertions(+), 40 deletions(-)

diff --git a/ooni/nettest.py b/ooni/nettest.py
index bd6bd41..b264797 100644
--- a/ooni/nettest.py
+++ b/ooni/nettest.py
@@ -24,7 +24,10 @@ class NetTestLoader(object):
 
     def __init__(self, options):
         self.options = options
-        self.testCases = self.loadNetTest(options['test'])
+        if 'test_file' in options:
+            self.loadNetTestFile(options['test_file'])
+        elif 'test_string' in options:
+            self.loadNetTestString(options['test_string'])
 
     @property
     def testDetails(self):
@@ -110,7 +113,45 @@ class NetTestLoader(object):
                 assert usage_options == test_class.usageOptions
         return usage_options
 
-    def loadNetTest(self, net_test_file):
+    def loadNetTestString(self, net_test_string):
+        """
+        Load NetTest from a string.
+        WARNING input to this function *MUST* be sanitized and *NEVER* be
+        untrusted.
+        Failure to do so will result in code exec.
+
+        net_test_string:
+
+            a string that contains the net test to be run.
+        """
+        net_test_file_object = StringIO(net_test_string)
+
+        ns = {}
+        test_cases = []
+        exec net_test_file_object.read() in ns
+        for item in ns.itervalues():
+            test_cases.extend(self._get_test_methods(item))
+
+        if not test_cases:
+            raise NoTestCasesFound
+
+        self.setupTestCases(test_cases)
+
+    def loadNetTestFile(self, net_test_file):
+        """
+        Load NetTest from a file.
+        """
+        test_cases = []
+        module = filenameToModule(net_test_file)
+        for __, item in getmembers(module):
+            test_cases.extend(self._get_test_methods(item))
+
+        if not test_cases:
+            raise NoTestCasesFound
+
+        self.setupTestCases(test_cases)
+
+    def setupTestCases(self, test_cases):
         """
         Creates all the necessary test_cases (a list of tuples containing the
         NetTestCase (test_class, test_method))
@@ -131,25 +172,10 @@ class NetTestLoader(object):
             is either a file path or a file like object that will be used to
             generate the test_cases.
         """
-        test_cases = None
-        try:
-            if os.path.isfile(net_test_file):
-                test_cases = self._loadNetTestFile(net_test_file)
-            else:
-                net_test_file = StringIO(net_test_file)
-                raise TypeError("not a file path")
-
-        except TypeError:
-            if hasattr(net_test_file, 'read'):
-                test_cases = self._loadNetTestFromFileObject(net_test_file)
-
-        if not test_cases:
-            raise NoTestCasesFound
-
         test_class, _ = test_cases[0]
         self.testVersion = test_class.version
         self.testName = test_class.name.lower().replace(' ','_')
-        return test_cases
+        self.testCases = test_cases
 
     def checkOptions(self):
         """
@@ -161,7 +187,8 @@ class NetTestLoader(object):
 
         for klass in test_classes:
             options = self.usageOptions()
-            options.parseOptions(self.options['subargs'])
+            options.parseOptions(self.options)
+
             if options:
                 klass.localOptions = options
 
@@ -176,27 +203,6 @@ class NetTestLoader(object):
                 inputs = [None]
             klass.inputs = inputs
 
-    def _loadNetTestFromFileObject(self, net_test_string):
-        """
-        Load NetTest from a string
-        """
-        ns = {}
-        test_cases = []
-        exec net_test_string.read() in ns
-        for item in ns.itervalues():
-            test_cases.extend(self._get_test_methods(item))
-        return test_cases
-
-    def _loadNetTestFile(self, net_test_file):
-        """
-        Load NetTest from a file
-        """
-        test_cases = []
-        module = filenameToModule(net_test_file)
-        for __, item in getmembers(module):
-            test_cases.extend(self._get_test_methods(item))
-        return test_cases
-
     def _get_test_methods(self, item):
         """
         Look for test_ methods in subclasses of NetTestCase





More information about the tor-commits mailing list