2008-02-06 18:39:07 +00:00

157 lines
4.8 KiB
Python

from base import Base
from pathed import Pathed
from sqlalchemy import create_engine,Table
from sqlalchemy.orm import create_session
from pkg_resources import resource_stream
import os
def readurls():
    """Read the database URLs used by the test suite from ``test_db.cfg``.

    The file is looked up in the current directory (``os.curdir``). Every
    line that is neither blank nor a ``#`` comment is one database URL; the
    placeholder ``__tmp__`` in a URL is replaced by the path returned by
    ``Pathed.tmp()``.

    @return: list of URL strings, in file order
    @raise IOError: if ``test_db.cfg`` cannot be opened; instructions for
        creating it from the template are printed first
    """
    filename = 'test_db.cfg'
    fullpath = os.path.join(os.curdir, filename)
    tmpfile = Pathed.tmp()
    try:
        fd = open(fullpath)
    except IOError:
        tmplfile = "%s.tmpl" % filename
        print("You must specify the databases to use for testing!")
        # tmplfile already carries the ".tmpl" suffix; don't append it again
        print("Copy %s to %s and edit your database URLs." % (tmplfile, filename))
        raise
    ret = []
    try:
        for line in fd:
            stripped = line.strip()
            # Skip blank lines and comments (even indented ones): they would
            # otherwise become bogus URLs that create_engine() rejects.
            if not stripped or stripped.startswith('#'):
                continue
            ret.append(line.replace('__tmp__', tmpfile).strip())
    finally:
        # close the config file even if reading it fails
        fd.close()
    return ret
def is_supported(url, supported, not_supported):
    """Tell whether a database URL passes a supported/not_supported filter.

    The database name is the URL scheme, i.e. everything before the first
    ':' of *url*. Each filter may be a single name (a string) or any
    container of names.

    @param url: database URL, e.g. 'sqlite:///path/to/file.db'
    @param supported: if not None, ONLY these databases pass
    @param not_supported: consulted only when *supported* is None; if not
        None, every database EXCEPT these passes
    @return: True or False; True when both filters are None
    """
    scheme = url.split(':', 1)[0]
    if supported is None and not_supported is None:
        # no filter at all: everything is supported
        return True
    if supported is not None:
        allowed = supported
        if isinstance(allowed, basestring):
            allowed = (allowed,)
        return scheme in allowed
    excluded = not_supported
    if isinstance(excluded, basestring):
        excluded = (excluded,)
    return scheme not in excluded
def usedb(supported=None, not_supported=None):
    """Decorates tests to be run with a database connection

    These tests are run once for each available database: the decorated
    method becomes a generator yielding one callable per URL in ``DB.urls``
    that passes ``is_supported``. Each callable connects to its URL, runs
    ``setup_method``, the test and ``teardown_method``, and always
    disconnects afterwards.

    @param supported: run tests for ONLY these databases
    @param not_supported: run tests for all databases EXCEPT these
    If both supported and not_supported are None, all dbs are assumed
    to be supported
    @raise AssertionError: if both supported and not_supported are given
    """
    if supported is not None and not_supported is not None:
        msg = "Can't specify both supported and not_supported in fixture.db()"
        # raise explicitly: a bare ``assert`` is stripped under ``python -O``
        raise AssertionError(msg)
    urls = [url for url in DB.urls
            if is_supported(url, supported, not_supported)]

    def entangle(func):
        def run(self, *p, **k):
            for url in urls:
                # Bind url as a default argument: the yielded callables may be
                # invoked after this loop has moved on (e.g. if the runner
                # collects them all first), and a plain closure would then see
                # only the last url.
                def run_one(url=url):
                    self._connect(url)
                    try:
                        self.setup_method(func)
                        try:
                            func(self, *p, **k)
                        finally:
                            self.teardown_method(func)
                    finally:
                        # release the connection even if setup/teardown fails
                        self._disconnect()
                yield run_one
        return run
    return entangle
class DB(Base):
    """Test base class that runs each test against the configured databases.

    The database URLs are read by ``readurls()`` once, when this class is
    defined, and an engine is created for every URL up front. ``level``
    selects how much ``_connect`` sets up before each test run.
    """
    # Constants: connection level
    NONE = 0     # No connection; just set self.url (and self.engine)
    CONNECT = 1  # Create a session; no transaction
    TXN = 2      # Everything in a transaction, rolled back on disconnect
    level = TXN

    urls = readurls()
    # url: engine
    engines = dict([(url, create_engine(url)) for url in urls])

    def shortDescription(self, *p, **k):
        """List database connection info with description of the test"""
        ret = super(DB, self).shortDescription(*p, **k) or str(self)
        engine = self._engineInfo()
        if engine is not None:
            ret = "(%s) %s" % (engine, ret)
        return ret

    def _engineInfo(self, url=None):
        """Return the label shown for a database: *url*, else ``self.url``."""
        if url is None:
            url = self.url
        return url

    def _connect(self, url):
        """Prepare this test instance to run against *url*.

        Always sets ``self.url`` and ``self.engine``. At level ``CONNECT`` or
        above also opens ``self.session``; at ``TXN`` also starts
        ``self.txn`` on that session.
        """
        self.url = url
        self.engine = self.engines[url]
        if self.level < self.CONNECT:
            return
        self.session = create_session(bind=self.engine)
        if self.level < self.TXN:
            return
        self.txn = self.session.create_transaction()

    def _disconnect(self):
        """Roll back the transaction (if any) and close the session (if any)."""
        try:
            if hasattr(self, 'txn'):
                self.txn.rollback()
        finally:
            # close the session even if the rollback itself fails
            if hasattr(self, 'session'):
                self.session.close()

    def run(self, *p, **k):
        """Run one test for each connection string"""
        for url in self.urls:
            self._run_one(url, *p, **k)

    def _supported(self, url):
        """Return True if the current test method may run against *url*.

        The test method may carry ``supported`` / ``not_supported``
        attributes; they are interpreted exactly like the arguments of
        ``is_supported`` (a single name or a sequence of names). Without
        either attribute every database is supported.
        """
        # Newer unittest versions store the method name in _testMethodName;
        # older ones used the name-mangled private __testMethodName.
        name = getattr(self, '_testMethodName', None)
        if name is None:
            name = self._TestCase__testMethodName
        func = getattr(self, name)
        return is_supported(url,
                            getattr(func, 'supported', None),
                            getattr(func, 'not_supported', None))

    def _not_supported(self, url):
        """Inverse of ``_supported``."""
        return not self._supported(url)

    def _run_one(self, url, *p, **k):
        """Run the test once against *url*, unless the test excludes it."""
        if self._not_supported(url):
            return
        self._connect(url)
        try:
            super(DB, self).run(*p, **k)
        finally:
            self._disconnect()

    def refresh_table(self, name=None):
        """Reload the table from the database

        Assumes we're working with only a single table, self.table, and
        metadata self.meta (name defaults to self.table.name).
        Working w/ multiple tables is not possible, as tables can only be
        reloaded with meta.clear()
        """
        if name is None:
            name = self.table.name
        self.meta.clear()
        self.table = Table(name, self.meta, autoload=True)