diff --git a/migrate/versioning/version.py b/migrate/versioning/version.py index 75a7a4f..dfaabf2 100644 --- a/migrate/versioning/version.py +++ b/migrate/versioning/version.py @@ -149,12 +149,21 @@ class Version(object): # formerly inherit from: (pathed.Pathed): #if database is None and operation is None: # return self._script_py() #print database,operation,self.sql + try: # Try to return a .sql script first - ret = self._script_sql(database,operation) + return self._script_sql(database,operation) except KeyError: - # No .sql script exists; return a python script - ret = self._script_py() + pass # No .sql script exists + + try: + # Try to return the default .sql script + return self._script_sql('default',operation) + except KeyError: + pass # No .sql script exists + + ret = self._script_py() + assert ret is not None return ret def _script_py(self): diff --git a/test/versioning/test_script.py b/test/versioning/test_script.py index 47c6d7d..eace2fa 100644 --- a/test/versioning/test_script.py +++ b/test/versioning/test_script.py @@ -1,6 +1,6 @@ from test import fixture from migrate.versioning.script import * -from migrate.versioning import exceptions +from migrate.versioning import exceptions, version import os,shutil class TestPyScript(fixture.Pathed): @@ -55,3 +55,30 @@ class TestPyScript(fixture.Pathed): self.cls.create(path) self.cls.verify(path) +class TestSqlScript(fixture.Pathed): + def test_selection(self): + """Verify right sql script is selected""" + + # Create empty directory. + path=self.tmp_repos() + os.mkdir(path) + + # Create files -- files must be present or you'll get an exception later. + sqlite_upgrade_file = '001_sqlite_upgrade.sql' + default_upgrade_file = '001_default_upgrade.sql' + for file in [sqlite_upgrade_file, default_upgrade_file]: + filepath = '%s/%s' % (path, file) + open(filepath, 'w').close() + + ver = version.Version(1, path, [sqlite_upgrade_file]) + self.assertEquals(os.path.basename(ver.script('sqlite', 'upgrade').path), sqlite_upgrade_file) + + ver = version.Version(1, path, [default_upgrade_file]) + self.assertEquals(os.path.basename(ver.script('default', 'upgrade').path), default_upgrade_file) + + ver = version.Version(1, path, [sqlite_upgrade_file, default_upgrade_file]) + self.assertEquals(os.path.basename(ver.script('sqlite', 'upgrade').path), sqlite_upgrade_file) + + ver = version.Version(1, path, [sqlite_upgrade_file, default_upgrade_file]) + self.assertEquals(os.path.basename(ver.script('postgres', 'upgrade').path), default_upgrade_file) +