From fa6e6910e7f7e31252bbf43957d8bb491f2295db Mon Sep 17 00:00:00 2001 From: iElectric Date: Sat, 6 Jun 2009 17:02:39 +0000 Subject: [PATCH] added tests for versioning.version.py, refactored the module --- TODO | 2 + migrate/versioning/repository.py | 6 +- migrate/versioning/util/__init__.py | 2 +- migrate/versioning/version.py | 147 ++++++++++++---------------- setup.cfg | 1 + test/fixture/pathed.py | 8 ++ test/versioning/test_repository.py | 1 + test/versioning/test_shell.py | 1 + test/versioning/test_version.py | 142 +++++++++++++++++++++------ 9 files changed, 189 insertions(+), 121 deletions(-) diff --git a/TODO b/TODO index 75ee438..d29c879 100644 --- a/TODO +++ b/TODO @@ -4,3 +4,5 @@ - document dotted_name parsing changes - document shell parsing - document engine parameters usage/parsing + +- better SQL scripts support (testing, source viewing) diff --git a/migrate/versioning/repository.py b/migrate/versioning/repository.py index 47c056c..4e705a4 100644 --- a/migrate/versioning/repository.py +++ b/migrate/versioning/repository.py @@ -61,9 +61,7 @@ class Changeset(dict): class Repository(pathed.Pathed): """A project's change script repository""" - # Configuration file, inside repository _config = 'migrate.cfg' - # Version information, inside repository _versions = 'versions' def __init__(self, path): @@ -133,10 +131,10 @@ class Repository(pathed.Pathed): return cls(path) def create_script(self, description, **k): - self.versions.createNewVersion(description, **k) + self.versions.create_new_python_version(description, **k) def create_script_sql(self, database, **k): - self.versions.createNewSQLVersion(database, **k) + self.versions.create_new_sql_version(database, **k) latest=property(lambda self: self.versions.latest) version_table=property(lambda self: self.config.get('db_settings', diff --git a/migrate/versioning/util/__init__.py b/migrate/versioning/util/__init__.py index ee51d58..982c756 100644 --- a/migrate/versioning/util/__init__.py +++ b/migrate/versioning/util/__init__.py @@ -17,7 +17,7 @@ def load_model(dotted_name): if isinstance(dotted_name, basestring): if ':' not in dotted_name: # backwards compatibility - warnings.warn('model should be in form of module.model:User' + warnings.warn('model should be in form of module.model:User ' 'and not module.model.User', DeprecationWarning) dotted_name = ':'.join(dotted_name.rsplit('.', 1)) return EntryPoint.parse('x=%s' % dotted_name).load(False) diff --git a/migrate/versioning/version.py b/migrate/versioning/version.py index dac4ac3..9c67a8c 100644 --- a/migrate/versioning/version.py +++ b/migrate/versioning/version.py @@ -36,7 +36,7 @@ class VerNum(object): return int(self) - int(value) def __repr__(self): - return str(self.value) + return "" % self.value def __str__(self): return str(self.value) @@ -44,32 +44,25 @@ class VerNum(object): def __int__(self): return int(self.value) -def str_to_filename(s): - """Replaces spaces, (double and single) quotes - and double underscores to underscores - """ - - s = s.replace(' ', '_').replace('"', '_').replace("'", '_') - while '__' in s: - s = s.replace('__', '_') - return s - class Collection(pathed.Pathed): """A collection of versioning scripts in a repository""" - FILENAME_WITH_VERSION = re.compile(r'^(\d+).*') + FILENAME_WITH_VERSION = re.compile(r'^(\d{3,}).*') def __init__(self, path): + """Collect current version scripts in repository""" super(Collection, self).__init__(path) # Create temporary list of files, allowing skipped version numbers. files = os.listdir(path) - tempVersions = dict() if '1' in files: + # deprecation raise Exception('It looks like you have a repository in the old ' 'format (with directories for each version). ' 'Please convert repository before proceeding.') + + tempVersions = dict() for filename in files: match = self.FILENAME_WITH_VERSION.match(filename) if match: @@ -83,27 +76,14 @@ class Collection(pathed.Pathed): self.versions = dict() for num, files in tempVersions.items(): self.versions[VerNum(num)] = Version(num, path, files) - # calculate latest version - self.latest = max([VerNum(0)] + self.versions.keys()) - def version_path(self, ver): - return os.path.join(self.path, str(ver)) - - def version(self, vernum=None): - if vernum is None: - vernum = self.latest - return self.versions[VerNum(vernum)] + @property + def latest(self): + return max([VerNum(0)] + self.versions.keys()) - def getNewVersion(self): + def create_new_python_version(self, description, **k): + """Create Python files for new version""" ver = self.latest + 1 - # No change scripts exist for 0 (even though it's a valid version) - if ver <= 0: - raise exceptions.InvalidVersionError() - self.latest = ver - return ver - - def createNewVersion(self, description, **k): - ver = self.getNewVersion() extra = str_to_filename(description) if extra: @@ -113,7 +93,7 @@ class Collection(pathed.Pathed): extra = '_%s' % extra filename = '%03d%s.py' % (ver, extra) - filepath = self.version_path(filename) + filepath = self._version_path(filename) if os.path.exists(filepath): raise Exception('Script already exists: %s' % filepath) @@ -122,82 +102,70 @@ class Collection(pathed.Pathed): self.versions[ver] = Version(ver, self.path, [filename]) - def createNewSQLVersion(self, database, **k): - # Determine version number to use. - # fix from Issue 29 - ver = self.getNewVersion() + def create_new_sql_version(self, database, **k): + """Create SQL files for new version""" + ver = self.latest + 1 self.versions[ver] = Version(ver, self.path, []) # Create new files. for op in ('upgrade', 'downgrade'): filename = '%03d_%s_%s.sql' % (ver, database, op) - filepath = self.version_path(filename) + filepath = self._version_path(filename) if os.path.exists(filepath): raise Exception('Script already exists: %s' % filepath) else: open(filepath, "w").close() - self.versions[ver]._add_script(filepath) + self.versions[ver].add_script(filepath) + def version(self, vernum=None): + """Returns latest Version if vernum is not given. \ + Otherwise, returns wanted version""" + if vernum is None: + vernum = self.latest + return self.versions[VerNum(vernum)] + @classmethod def clear(cls): super(Collection, cls).clear() - -class Extensions: - """A namespace for file extensions""" - py = 'py' - sql = 'sql' + def _version_path(self, ver): + """Returns path of file in versions repository""" + return os.path.join(self.path, str(ver)) -class Version(object): # formerly inherit from: (pathed.Pathed): - """A single version in a repository """ +class Version(object): + """A single version in a collection """ def __init__(self, vernum, path, filelist): - # Version must be numeric - try: - self.version = VerNum(vernum) - except: - raise exceptions.InvalidVersionError(vernum) + self.version = VerNum(vernum) # Collect scripts in this folder self.sql = dict() self.python = None for script in filelist: - # skip __init__.py, because we assume that it's - # just there to mark the package - if script == '__init__.py': - continue - self._add_script(os.path.join(path, script)) + self.add_script(os.path.join(path, script)) def script(self, database=None, operation=None): - # Try to return a .sql script first - try: - return self._script_sql(database, operation) - except KeyError: - pass # No .sql script exists - - # Try to return the default .sql script - try: - return self._script_sql('default', operation) - except KeyError: - pass # No .sql script exists + """Returns SQL or Python Script""" + for db in (database, 'default'): + # Try to return a .sql script first + try: + return self.sql[db][operation] + except KeyError: + continue # No .sql script exists - ret = self._script_py() + # TODO: maybe add force Python parameter? + ret = self.python assert ret is not None return ret - def _script_py(self): - return self.python - - def _script_sql(self, database, operation): - return self.sql[database][operation] - + # deprecated? @classmethod def create(cls, path): os.mkdir(path) - # craete the version as a proper Python package + # create the version as a proper Python package initfile = os.path.join(path, "__init__.py") if not os.path.exists(initfile): # just touch the file @@ -209,7 +177,8 @@ class Version(object): # formerly inherit from: (pathed.Pathed): raise return ret - def _add_script(self, path): + def add_script(self, path): + """Add script to Collection/Version""" if path.endswith(Extensions.py): self._add_script_py(path) elif path.endswith(Extensions.sql): @@ -223,14 +192,10 @@ class Version(object): # formerly inherit from: (pathed.Pathed): if match: version, dbms, op = match.group(1), match.group(2), match.group(3) else: - raise exceptions.ScriptError("Invalid sql script name %s" % path) + raise exceptions.ScriptError("Invalid SQL script name %s" % path) # File the script into a dictionary - dbmses = self.sql - if dbms not in dbmses: - dbmses[dbms] = dict() - ops = dbmses[dbms] - ops[op] = script.SqlScript(path) + self.sql.setdefault(dbms, {})[op] = script.SqlScript(path) def _add_script_py(self, path): if self.python is not None: @@ -238,9 +203,17 @@ class Version(object): # formerly inherit from: (pathed.Pathed): ' but you have: %s and %s' % (self.python, path)) self.python = script.PythonScript(path) - def _rm_ignore(self, path): - """Try to remove a path; ignore failure""" - try: - os.remove(path) - except OSError: - pass +class Extensions: + """A namespace for file extensions""" + py = 'py' + sql = 'sql' + +def str_to_filename(s): + """Replaces spaces, (double and single) quotes + and double underscores to underscores + """ + + s = s.replace(' ', '_').replace('"', '_').replace("'", '_') + while '__' in s: + s = s.replace('__', '_') + return s diff --git a/setup.cfg b/setup.cfg index d31d1f7..a62ecf0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,7 @@ tag_build = .dev [nosetests] pdb = true +#pdb-failures = true [aliases] release = egg_info -RDb '' diff --git a/test/fixture/pathed.py b/test/fixture/pathed.py index 2a0e7a6..728a613 100644 --- a/test/fixture/pathed.py +++ b/test/fixture/pathed.py @@ -13,6 +13,14 @@ class Pathed(base.Base): _tmpdir = tempfile.mkdtemp() + def setUp(self): + super(Pathed, self).setUp() + self.temp_usable_dir = tempfile.mkdtemp() + + def tearDown(self): + super(Pathed, self).tearDown() + self.temp_usable_dir = tempfile.mkdtemp() + @classmethod def _tmp(cls, prefix='', suffix=''): """Generate a temporary file name that doesn't exist diff --git a/test/versioning/test_repository.py b/test/versioning/test_repository.py index af573e7..53f78a2 100644 --- a/test/versioning/test_repository.py +++ b/test/versioning/test_repository.py @@ -52,6 +52,7 @@ class TestVersionedRepository(fixture.Pathed): """Tests on an existing repository with a single python script""" script_cls = script.PythonScript def setUp(self): + super(TestVersionedRepository, self).setUp() Repository.clear() self.path_repos=self.tmp_repos() # Create repository, script diff --git a/test/versioning/test_shell.py b/test/versioning/test_shell.py index 08f0e9a..baba8be 100644 --- a/test/versioning/test_shell.py +++ b/test/versioning/test_shell.py @@ -173,6 +173,7 @@ class TestShellRepository(Shell): def setUp(self): """Create repository, python change script""" + super(TestShellRepository, self).setUp() self.path_repos = repos = self.tmp_repos() self.assertSuccess(self.cmd('create', repos, 'repository_name')) diff --git a/test/versioning/test_version.py b/test/versioning/test_version.py index 0d6f2fe..a4855fe 100644 --- a/test/versioning/test_version.py +++ b/test/versioning/test_version.py @@ -8,51 +8,135 @@ from migrate.versioning.version import * class TestVerNum(fixture.Base): def test_invalid(self): """Disallow invalid version numbers""" - versions = ('-1',-1,'Thirteen','') + versions = ('-1', -1, 'Thirteen', '') for version in versions: - self.assertRaises(ValueError,VerNum,version) + self.assertRaises(ValueError, VerNum, version) + def test_is(self): - a=VerNum(1) - b=VerNum(1) + """Two version with the same number should be equal""" + a = VerNum(1) + b = VerNum(1) self.assert_(a is b) + + self.assertEqual(VerNum(VerNum(2)), VerNum(2)) + def test_add(self): - self.assert_(VerNum(1)+VerNum(1)==VerNum(2)) - self.assert_(VerNum(1)+1==2) - self.assert_(VerNum(1)+1=='2') + self.assertEqual(VerNum(1) + VerNum(1), VerNum(2)) + self.assertEqual(VerNum(1) + 1, 2) + self.assertEqual(VerNum(1) + 1, '2') + self.assert_(isinstance(VerNum(1) + 1, VerNum)) + def test_sub(self): - self.assert_(VerNum(1)-1==0) - self.assertRaises(ValueError,lambda:VerNum(0)-1) + self.assertEqual(VerNum(1) - 1, 0) + self.assert_(isinstance(VerNum(1) - 1, VerNum)) + self.assertRaises(ValueError, lambda: VerNum(0) - 1) + def test_eq(self): - self.assert_(VerNum(1)==VerNum('1')) - self.assert_(VerNum(1)==1) - self.assert_(VerNum(1)=='1') - self.assert_(not VerNum(1)==2) + """Two versions are equal""" + self.assertEqual(VerNum(1), VerNum('1')) + self.assertEqual(VerNum(1), 1) + self.assertEqual(VerNum(1), '1') + self.assertNotEqual(VerNum(1), 2) + def test_ne(self): - self.assert_(VerNum(1)!=2) - self.assert_(not VerNum(1)!=1) + self.assert_(VerNum(1) != 2) + self.assertFalse(VerNum(1) != 1) + def test_lt(self): - self.assert_(not VerNum(1)<1) - self.assert_(VerNum(1)<2) - self.assert_(not VerNum(2)<1) + self.assertFalse(VerNum(1) < 1) + self.assert_(VerNum(1) < 2) + self.assertFalse(VerNum(2) < 1) + def test_le(self): - self.assert_(VerNum(1)<=1) - self.assert_(VerNum(1)<=2) - self.assert_(not VerNum(2)<=1) + self.assert_(VerNum(1) <= 1) + self.assert_(VerNum(1) <= 2) + self.assertFalse(VerNum(2) <= 1) + def test_gt(self): - self.assert_(not VerNum(1)>1) - self.assert_(not VerNum(1)>2) - self.assert_(VerNum(2)>1) + self.assertFalse(VerNum(1) > 1) + self.assertFalse(VerNum(1) > 2) + self.assert_(VerNum(2) > 1) + def test_ge(self): - self.assert_(VerNum(1)>=1) - self.assert_(not VerNum(1)>=2) - self.assert_(VerNum(2)>=1) + self.assert_(VerNum(1) >= 1) + self.assert_(VerNum(2) >= 1) + self.assertFalse(VerNum(1) >= 2) -class TestDescriptionNaming(fixture.Base): - def test_names(self): +class TestVersion(fixture.Pathed): + + def setUp(self): + super(TestVersion, self).setUp() + + def test_str_to_filename(self): self.assertEquals(str_to_filename(''), '') + self.assertEquals(str_to_filename('__'), '_') self.assertEquals(str_to_filename('a'), 'a') self.assertEquals(str_to_filename('Abc Def'), 'Abc_Def') self.assertEquals(str_to_filename('Abc "D" Ef'), 'Abc_D_Ef') self.assertEquals(str_to_filename("Abc's Stuff"), 'Abc_s_Stuff') self.assertEquals(str_to_filename("a b"), 'a_b') + + def test_collection(self): + """Let's see how we handle versions collection""" + coll = Collection(self.temp_usable_dir) + coll.create_new_python_version("foo bar") + coll.create_new_sql_version("postgres") + coll.create_new_sql_version("sqlite") + coll.create_new_python_version("") + + self.assertEqual(coll.latest, 4) + self.assertEqual(len(coll.versions), 4) + self.assertEqual(coll.version(4), coll.version(coll.latest)) + + coll2 = Collection(self.temp_usable_dir) + self.assertEqual(coll.versions, coll2.versions) + + #def test_collection_unicode(self): + # pass + + def test_create_new_python_version(self): + coll = Collection(self.temp_usable_dir) + coll.create_new_python_version("foo bar") + + ver = coll.version() + self.assert_(ver.script().source()) + + def test_create_new_sql_version(self): + coll = Collection(self.temp_usable_dir) + coll.create_new_sql_version("sqlite") + + ver = coll.version() + ver_up = ver.script('sqlite', 'upgrade') + ver_down = ver.script('sqlite', 'downgrade') + ver_up.source() + ver_down.source() + + def test_selection(self): + """Verify right sql script is selected""" + # Create empty directory. + path = self.tmp_repos() + os.mkdir(path) + + # Create files -- files must be present or you'll get an exception later. + python_file = '001_initial_.py' + sqlite_upgrade_file = '001_sqlite_upgrade.sql' + default_upgrade_file = '001_default_upgrade.sql' + for file_ in [sqlite_upgrade_file, default_upgrade_file, python_file]: + filepath = '%s/%s' % (path, file_) + open(filepath, 'w').close() + + ver = Version(1, path, [sqlite_upgrade_file]) + self.assertEquals(os.path.basename(ver.script('sqlite', 'upgrade').path), sqlite_upgrade_file) + + ver = Version(1, path, [default_upgrade_file]) + self.assertEquals(os.path.basename(ver.script('default', 'upgrade').path), default_upgrade_file) + + ver = Version(1, path, [sqlite_upgrade_file, default_upgrade_file]) + self.assertEquals(os.path.basename(ver.script('sqlite', 'upgrade').path), sqlite_upgrade_file) + + ver = Version(1, path, [sqlite_upgrade_file, default_upgrade_file, python_file]) + self.assertEquals(os.path.basename(ver.script('postgres', 'upgrade').path), default_upgrade_file) + + ver = Version(1, path, [sqlite_upgrade_file, python_file]) + self.assertEquals(os.path.basename(ver.script('postgres', 'upgrade').path), python_file)