support default for sql scripts

This commit is contained in:
christian.simms 2008-05-31 11:16:23 +00:00
parent f4d358e3fb
commit b37afa2b6e
2 changed files with 40 additions and 4 deletions

View File

@ -149,12 +149,21 @@ class Version(object): # formerly inherit from: (pathed.Pathed):
#if database is None and operation is None:
# return self._script_py()
#print database,operation,self.sql
try:
# Try to return a .sql script first
ret = self._script_sql(database,operation)
return self._script_sql(database,operation)
except KeyError:
# No .sql script exists; return a python script
ret = self._script_py()
pass # No .sql script exists
try:
# Try to return the default .sql script
return self._script_sql('default',operation)
except KeyError:
pass # No .sql script exists
ret = self._script_py()
assert ret is not None
return ret
def _script_py(self):

View File

@ -1,6 +1,6 @@
from test import fixture
from migrate.versioning.script import *
from migrate.versioning import exceptions
from migrate.versioning import exceptions, version
import os,shutil
class TestPyScript(fixture.Pathed):
@ -55,3 +55,30 @@ class TestPyScript(fixture.Pathed):
self.cls.create(path)
self.cls.verify(path)
class TestSqlScript(fixture.Pathed):
def test_selection(self):
"""Verify right sql script is selected"""
# Create empty directory.
path=self.tmp_repos()
os.mkdir(path)
# Create files -- files must be present or you'll get an exception later.
sqlite_upgrade_file = '001_sqlite_upgrade.sql'
default_upgrade_file = '001_default_upgrade.sql'
for file in [sqlite_upgrade_file, default_upgrade_file]:
filepath = '%s/%s' % (path, file)
open(filepath, 'w').close()
ver = version.Version(1, path, [sqlite_upgrade_file])
self.assertEquals(os.path.basename(ver.script('sqlite', 'upgrade').path), sqlite_upgrade_file)
ver = version.Version(1, path, [default_upgrade_file])
self.assertEquals(os.path.basename(ver.script('default', 'upgrade').path), default_upgrade_file)
ver = version.Version(1, path, [sqlite_upgrade_file, default_upgrade_file])
self.assertEquals(os.path.basename(ver.script('sqlite', 'upgrade').path), sqlite_upgrade_file)
ver = version.Version(1, path, [sqlite_upgrade_file, default_upgrade_file])
self.assertEquals(os.path.basename(ver.script('postgres', 'upgrade').path), default_upgrade_file)