[or-cvs] [bridgedb/master] Implement a sqlite replacement for our current db wrappers.
Nick Mathewson
nickm at seul.org
Fri Sep 25 06:03:22 UTC 2009
Author: Nick Mathewson <nickm at torproject.org>
Date: Fri, 25 Sep 2009 01:18:38 -0400
Subject: Implement a sqlite replacement for our current db wrappers.
Commit: 1d739d1bfc7b544382066ebf9c6df7895c95cd60
Now all we'll need to do is reverse-engineer our current DB usage,
design a schema, write a migration tool, and switch the code to use
sqlite.
Such fun!
---
lib/bridgedb/Storage.py | 92 +++++++++++++++++++++++++++++++++++++++++++++++
lib/bridgedb/Tests.py | 60 +++++++++++++++++++++++++++----
2 files changed, 145 insertions(+), 7 deletions(-)
create mode 100644 lib/bridgedb/Storage.py
diff --git a/lib/bridgedb/Storage.py b/lib/bridgedb/Storage.py
new file mode 100644
index 0000000..a0430c0
--- /dev/null
+++ b/lib/bridgedb/Storage.py
@@ -0,0 +1,92 @@
+# BridgeDB by Nick Mathewson.
+# Copyright (c) 2007-2009, The Tor Project, Inc.
+# See LICENSE for licensing information
+
+def _escapeValue(v):
+ return "'%s'" % v.replace("'", "''")
+
class SqliteDict:
    """
    A SqliteDict wraps a SQLite table and makes it look like a
    Python dictionary.  In addition to the single key and value
    columns, there can be a number of "fixed" columns, such that
    the dictionary only contains elements of the table where the
    fixed columns are set appropriately.

    Table and column names are interpolated into the SQL text as-is, so
    they must come from trusted code.  All values (keys, values and
    fixed-column values) are passed to SQLite as bound parameters.

    This class never commits on its own: call commit() or rollback() to
    end the current transaction.
    """
    def __init__(self, conn, cursor, table, fixedcolnames, fixedcolvalues,
                 keycol, valcol):
        """Wrap the rows of `table` as a key -> value mapping.

        conn -- connection object; used only by commit() and rollback().
        cursor -- cursor on which every statement is executed.
        table -- name of the underlying table.
        fixedcolnames, fixedcolvalues -- parallel sequences (tuples or
            lists).  Only rows whose fixed columns equal these values are
            visible, and new rows are written with these values.
        keycol, valcol -- names of the key column and the value column.
        """
        # Accept any sequence; tuples are needed for concatenation below.
        fixedcolnames = tuple(fixedcolnames)
        fixedcolvalues = tuple(fixedcolvalues)
        assert len(fixedcolnames) == len(fixedcolvalues)
        self._conn = conn
        self._cursor = cursor
        # Parameters for the fixed-column placeholders, in column order.
        self._fixedVals = fixedcolvalues

        # "WHERE key = ? AND fixed1 = ? AND ...": the key parameter comes
        # first, followed by the fixed-column parameters.
        constraint = "WHERE %s = ?" % keycol
        constraint += "".join(" AND %s = ?" % c for c in fixedcolnames)

        self._getStmt = "SELECT %s FROM %s %s" % (valcol, table, constraint)
        self._delStmt = "DELETE FROM %s %s" % (table, constraint)

        # Insert column order: fixed columns, then key, then value.
        allCols = fixedcolnames + (keycol, valcol)
        self._setStmt = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)" % (
            table, ", ".join(allCols), ", ".join(["?"] * len(allCols)))

        if fixedcolnames:
            whereClause = " WHERE " + " AND ".join(
                "%s = ?" % c for c in fixedcolnames)
        else:
            whereClause = ""
        self._keysStmt = "SELECT %s FROM %s%s" % (keycol, table, whereClause)

    def _keyParams(self, k):
        """Return the parameters for _getStmt/_delStmt: key, then fixed values."""
        return (k,) + self._fixedVals

    def _fetchRow(self, k):
        """Return the one-element (value,) row for `k`, or None if `k` is absent."""
        self._cursor.execute(self._getStmt, self._keyParams(k))
        return self._cursor.fetchone()

    def __setitem__(self, k, v):
        self._cursor.execute(self._setStmt, self._fixedVals + (k, v))

    def __delitem__(self, k):
        """Delete `k`; raise KeyError if no matching row existed."""
        self._cursor.execute(self._delStmt, self._keyParams(k))
        # rowcount is meaningful for DML statements such as DELETE.
        if self._cursor.rowcount == 0:
            raise KeyError(k)

    def __getitem__(self, k):
        """Return the value for `k`; raise KeyError if it is absent."""
        row = self._fetchRow(k)
        if row is None:
            raise KeyError(k)
        return row[0]

    def has_key(self, k):
        """Return True if `k` is present."""
        # Fetch a row instead of checking rowcount: sqlite3 reports
        # rowcount as -1 for SELECT statements.
        return self._fetchRow(k) is not None

    def __contains__(self, k):
        return self.has_key(k)

    def get(self, k, v=None):
        """Return the value for `k`, or `v` if `k` is absent."""
        row = self._fetchRow(k)
        if row is None:
            return v
        return row[0]

    def setdefault(self, k, v=None):
        """Return the value for `k`, storing `v` first if `k` is absent."""
        try:
            return self[k]
        except KeyError:
            self[k] = v
            return v

    def keys(self):
        """Return a list of every key visible through this dictionary."""
        self._cursor.execute(self._keysStmt, self._fixedVals)
        return [key for (key,) in self._cursor.fetchall()]

    def commit(self):
        self._conn.commit()

    def rollback(self):
        self._conn.rollback()
+
+#
+# The old DB system was a plain key->value mapping DB, with special key
+# prefixes to indicate which logical database each entry belonged to.
+#
+# sp|<HEXID> -- given to BridgeSplitter; maps a bridge ID to a ring name.
+# em|<emailaddr> -- given to EmailBasedDistributor; maps an email address
+#      to concatenated hex IDs.
+# fs|<HEXID> -- given to BridgeTracker; maps to the time when a router was
+#      first seen (YYYY-MM-DD HH:MM).
+# ls|<HEXID> -- given to BridgeTracker; maps to the time when a router was
+#      last seen (YYYY-MM-DD HH:MM).
+#
diff --git a/lib/bridgedb/Tests.py b/lib/bridgedb/Tests.py
index 9b9b1b9..865c91e 100644
--- a/lib/bridgedb/Tests.py
+++ b/lib/bridgedb/Tests.py
@@ -3,22 +3,22 @@
# See LICENSE for licensing information
import doctest
+import os
+import random
+import sqlite3
+import tempfile
import unittest
import warnings
-import random
import bridgedb.Bridges
import bridgedb.Main
import bridgedb.Dist
import bridgedb.Time
+import bridgedb.Storage
def suppressWarnings():
    """Install a filter that ignores warnings mentioning ``tmpnam``.

    The filter is added at the front of the warnings filter list, so it
    takes precedence over filters installed earlier.
    """
    warnings.filterwarnings(action='ignore', message='.*tmpnam.*')
-class TestCase0(unittest.TestCase):
- def testFooIsFooish(self):
- self.assert_(True)
-
def randomIP():
    """Return a random dotted-quad IPv4 address string.

    Each of the four octets is drawn uniformly from 1..255, so no octet
    is ever 0.
    """
    # range() rather than xrange() keeps this working on Python 2 and 3.
    return ".".join([str(random.randrange(1, 256)) for _ in range(4)])
@@ -82,11 +82,57 @@ class IPBridgeDistTests(unittest.TestCase):
self.assertEquals(len(fps), 5)
self.assertTrue(count >= 1)
class StorageTests(unittest.TestCase):
    """Exercise bridgedb.Storage.SqliteDict against a scratch SQLite file."""

    def setUp(self):
        # mkstemp creates the file; sqlite3 then opens it again by name.
        self.fd, self.fname = tempfile.mkstemp()
        try:
            self.conn = sqlite3.Connection(self.fname)
        except Exception:
            # tearDown is not run when setUp raises, so clean up here.
            os.close(self.fd)
            os.unlink(self.fname)
            raise

    def tearDown(self):
        self.conn.close()
        os.close(self.fd)
        os.unlink(self.fname)

    def _makeDict(self, table, fixedNames, fixedValues, keycol, valcol):
        """Build a SqliteDict over `table` using a fresh cursor."""
        return bridgedb.Storage.SqliteDict(self.conn, self.conn.cursor(),
                                           table, fixedNames, fixedValues,
                                           keycol, valcol)

    def testSimpleDict(self):
        self.conn.execute("CREATE TABLE A ( X PRIMARY KEY, Y )")
        self.basictests(self._makeDict("A", (), (), "X", "Y"))

    def testComplexDict(self):
        self.conn.execute("CREATE TABLE B ( X, Y, Z, "
                          "CONSTRAINT B_PK PRIMARY KEY (X,Y) )")
        # Two dicts sharing one table, partitioned by the fixed column X.
        d = self._makeDict("B", ("X",), ("x1",), "Y", "Z")
        d2 = self._makeDict("B", ("X",), ("x2",), "Y", "Z")
        self.basictests(d)
        self.basictests(d2)

    def basictests(self, d):
        """Run the shared mapping/transaction checks against `d`.

        Commits after the first batch of writes, then checks that
        rollback() undoes the later insert and delete.
        """
        d["hello"] = "goodbye"
        d["hola"] = "adios"
        self.assertEqual(d["hola"], "adios")
        d["hola"] = "hasta luego"
        self.assertEqual(d["hola"], "hasta luego")
        self.assertEqual(sorted(d.keys()), [u"hello", u"hola"])
        self.assertRaises(KeyError, d.__getitem__, "buongiorno")
        self.assertEqual(d.get("buongiorno", "ciao"), "ciao")
        self.conn.commit()
        # The next changes stay uncommitted and are discarded by rollback().
        d["buongiorno"] = "ciao"
        del d['hola']
        self.assertRaises(KeyError, d.__getitem__, "hola")
        self.conn.rollback()
        self.assertEqual(d["hola"], "hasta luego")
        self.assertEqual(d.setdefault("hola", "bye"), "hasta luego")
        self.assertEqual(d.setdefault("yo", "bye"), "bye")
        self.assertEqual(d['yo'], "bye")
+
def testSuite():
suite = unittest.TestSuite()
loader = unittest.TestLoader()
- for klass in [ TestCase0, IPBridgeDistTests ]:
+ for klass in [ IPBridgeDistTests, StorageTests ]:
suite.addTest(loader.loadTestsFromTestCase(klass))
for module in [ bridgedb.Bridges,
@@ -99,7 +145,7 @@ def testSuite():
def main():
suppressWarnings()
-
+
unittest.TextTestRunner(verbosity=1).run(testSuite())
--
1.5.6.5
More information about the tor-commits
mailing list