From 3fe1b4a3be05fe5184feb8bc237ceded5a98e8dc Mon Sep 17 00:00:00 2001 From: Gabriel Date: Tue, 5 Jul 2011 00:44:57 +0200 Subject: [PATCH] Fix and test issue 118. Clarify genmodel transformations. --- migrate/tests/versioning/test_genmodel.py | 64 ++++++++++++++--------- migrate/tests/versioning/test_script.py | 51 +++++++++++++----- migrate/versioning/genmodel.py | 57 ++++++++++++++------ migrate/versioning/schema.py | 6 +-- migrate/versioning/schemadiff.py | 64 +++++++++++------------ migrate/versioning/script/py.py | 16 +++--- 6 files changed, 159 insertions(+), 99 deletions(-) diff --git a/migrate/tests/versioning/test_genmodel.py b/migrate/tests/versioning/test_genmodel.py index cf5a378..aa3ac06 100644 --- a/migrate/tests/versioning/test_genmodel.py +++ b/migrate/tests/versioning/test_genmodel.py @@ -35,7 +35,7 @@ class TestSchemaDiff(fixture.DB): def _applyLatestModel(self): diff = schemadiff.getDiffOfModelAgainstDatabase(self.meta, self.engine, excludeTables=['migrate_version']) - genmodel.ModelGenerator(diff,self.engine).applyModel() + genmodel.ModelGenerator(diff,self.engine).runB2A() @fixture.usedb() def test_functional(self): @@ -57,30 +57,44 @@ class TestSchemaDiff(fixture.DB): # Check Python upgrade and downgrade of database from updated model. diff = schemadiff.getDiffOfModelAgainstDatabase(self.meta, self.engine, excludeTables=['migrate_version']) - decls, upgradeCommands, downgradeCommands = genmodel.ModelGenerator(diff,self.engine).toUpgradeDowngradePython() - self.assertEqualsIgnoreWhitespace(decls, ''' - from migrate.changeset import schema - meta = MetaData() - tmp_schemadiff = Table('tmp_schemadiff', meta, - Column('id', Integer(), primary_key=True, nullable=False), - Column('name', UnicodeText(length=None)), - Column('data', UnicodeText(length=None)), - ) - ''') + decls, upgradeCommands, downgradeCommands = genmodel.ModelGenerator(diff,self.engine).genB2AMigration() + + # Feature test for a recent SQLa feature; + # expect different output in that case. + if repr(String()) == 'String()': + self.assertEqualsIgnoreWhitespace(decls, ''' + from migrate.changeset import schema + meta = MetaData() + tmp_schemadiff = Table('tmp_schemadiff', meta, + Column('id', Integer, primary_key=True, nullable=False), + Column('name', UnicodeText), + Column('data', UnicodeText), + ) + ''') + else: + self.assertEqualsIgnoreWhitespace(decls, ''' + from migrate.changeset import schema + meta = MetaData() + tmp_schemadiff = Table('tmp_schemadiff', meta, + Column('id', Integer(), primary_key=True, nullable=False), + Column('name', UnicodeText(length=None)), + Column('data', UnicodeText(length=None)), + ) + ''') self.assertEqualsIgnoreWhitespace(upgradeCommands, '''meta.bind = migrate_engine tmp_schemadiff.create()''') self.assertEqualsIgnoreWhitespace(downgradeCommands, '''meta.bind = migrate_engine tmp_schemadiff.drop()''') - + # Create table in database, now model should match database. self._applyLatestModel() assertDiff(False, [], [], []) - + # Check Python code gen from database. diff = schemadiff.getDiffOfModelAgainstDatabase(MetaData(), self.engine, excludeTables=['migrate_version']) - src = genmodel.ModelGenerator(diff,self.engine).toPython() + src = genmodel.ModelGenerator(diff,self.engine).genBDefinition() exec src in locals() @@ -105,25 +119,25 @@ class TestSchemaDiff(fixture.DB): Column('data2',Integer(),nullable=True), ) assertDiff(True, [], [], [self.table_name]) - + # Apply latest model changes and find no more diffs. self._applyLatestModel() assertDiff(False, [], [], []) - + if not self.engine.name == 'oracle': # Make sure data is still present. result = self.engine.execute(self.table.select(self.table.c.id==dataId)) rows = result.fetchall() eq_(len(rows), 1) eq_(rows[0].name, 'mydata') - + # Add data, later we'll make sure it's still present. result = self.engine.execute(self.table.insert(), id=2, name=u'mydata2', data2=123) if SQLA_06: dataId2 = result.inserted_primary_key[0] else: dataId2 = result.last_inserted_ids()[0] - + # Change column type in model. self.meta.remove(self.table) self.table = Table(self.table_name,self.meta, @@ -134,13 +148,13 @@ class TestSchemaDiff(fixture.DB): # XXX test type diff return - + assertDiff(True, [], [], [self.table_name]) - + # Apply latest model changes and find no more diffs. self._applyLatestModel() assertDiff(False, [], [], []) - + if not self.engine.name == 'oracle': # Make sure data is still present. result = self.engine.execute(self.table.select(self.table.c.id==dataId2)) @@ -148,11 +162,11 @@ class TestSchemaDiff(fixture.DB): self.assertEquals(len(rows), 1) self.assertEquals(rows[0].name, 'mydata2') self.assertEquals(rows[0].data2, '123') - + # Delete data, since we're about to make a required column. # Not even using sqlalchemy.PassiveDefault helps because we're doing explicit column select. self.engine.execute(self.table.delete(), id=dataId) - + if not self.engine.name == 'firebird': # Change column nullable in model. self.meta.remove(self.table) @@ -162,11 +176,11 @@ class TestSchemaDiff(fixture.DB): Column('data2',String(255),nullable=False), ) assertDiff(True, [], [], [self.table_name]) # TODO test nullable diff - + # Apply latest model changes and find no more diffs. self._applyLatestModel() assertDiff(False, [], [], []) - + # Remove table from model. self.meta.remove(self.table) assertDiff(True, [], [self.table_name], []) diff --git a/migrate/tests/versioning/test_script.py b/migrate/tests/versioning/test_script.py index 021b880..6c9d585 100644 --- a/migrate/tests/versioning/test_script.py +++ b/migrate/tests/versioning/test_script.py @@ -161,23 +161,44 @@ def upgrade(migrate_engine): self.assertTrue('User.create()' in source_script) self.assertTrue('User.drop()' in source_script) - #@fixture.usedb() - #def test_make_update_script_for_model_equals(self): - # """Try to make update script from two identical models""" + @fixture.usedb() + def test_make_update_script_for_equal_models(self): + """Try to make update script from two identical models""" - # self.setup_model_params() - # self.write_file(self.first_model_path, self.base_source + self.model_source) - # self.write_file(self.second_model_path, self.base_source + self.model_source) + self.setup_model_params() + self.write_file(self.first_model_path, self.base_source + self.model_source) + self.write_file(self.second_model_path, self.base_source + self.model_source) - # source_script = self.pyscript.make_update_script_for_model( - # engine=self.engine, - # oldmodel=load_model('testmodel_first:meta'), - # model=load_model('testmodel_second:meta'), - # repository=self.repo_path, - # ) + source_script = self.pyscript.make_update_script_for_model( + engine=self.engine, + oldmodel=load_model('testmodel_first:meta'), + model=load_model('testmodel_second:meta'), + repository=self.repo_path, + ) - # self.assertFalse('User.create()' in source_script) - # self.assertFalse('User.drop()' in source_script) + self.assertFalse('User.create()' in source_script) + self.assertFalse('User.drop()' in source_script) + + @fixture.usedb() + def test_make_update_script_direction(self): + """Check update scripts go in the right direction""" + + self.setup_model_params() + self.write_file(self.first_model_path, self.base_source) + self.write_file(self.second_model_path, self.base_source + self.model_source) + + source_script = self.pyscript.make_update_script_for_model( + engine=self.engine, + oldmodel=load_model('testmodel_first:meta'), + model=load_model('testmodel_second:meta'), + repository=self.repo_path, + ) + + self.assertTrue(0 + < source_script.find('upgrade') + < source_script.find('User.create()') + < source_script.find('downgrade') + < source_script.find('User.drop()')) def setup_model_params(self): self.script_path = self.tmp_py() @@ -195,6 +216,8 @@ User = Table('User', meta, self.repo = repository.Repository.create(self.repo_path, 'repo') self.pyscript = PythonScript.create(self.script_path) + sys.modules.pop('testmodel_first', None) + sys.modules.pop('testmodel_second', None) def write_file(self, path, contents): f = open(path, 'w') diff --git a/migrate/versioning/genmodel.py b/migrate/versioning/genmodel.py index 6cb8e09..cfe9996 100644 --- a/migrate/versioning/genmodel.py +++ b/migrate/versioning/genmodel.py @@ -1,9 +1,9 @@ """ - Code to generate a Python model from a database or differences - between a model and database. +Code to generate a Python model from a database or differences +between a model and database. - Some of this is borrowed heavily from the AutoCode project at: - http://code.google.com/p/sqlautocode/ +Some of this is borrowed heavily from the AutoCode project at: +http://code.google.com/p/sqlautocode/ """ import sys @@ -34,6 +34,13 @@ Base = declarative.declarative_base() class ModelGenerator(object): + """Various transformations from an A, B diff. + + In the implementation, A tends to be called the model and B + the database (although this is not true of all diffs). + The diff is directionless, but transformations apply the diff + in a particular direction, described in the method name. + """ def __init__(self, diff, engine, declarative=False): self.diff = diff @@ -89,7 +96,7 @@ class ModelGenerator(object): else: return """Column(%(name)r, %(commonStuff)s)""" % data - def getTableDefn(self, table): + def _getTableDefn(self, table): out = [] tableName = table.name if self.declarative: @@ -117,9 +124,15 @@ class ModelGenerator(object): if bool_: for name in names: yield metadata.tables.get(name) - - def toPython(self): - """Assume database is current and model is empty.""" + + def genBDefinition(self): + """Generates the source code for a definition of B. + + Assumes a diff where A is empty. + + Was: toPython. Assume database (B) is current and model (A) is empty. + """ + out = [] if self.declarative: out.append(DECLARATIVE_HEADER) @@ -127,17 +140,22 @@ class ModelGenerator(object): out.append(HEADER) out.append("") for table in self._get_tables(missingA=True): - out.extend(self.getTableDefn(table)) + out.extend(self._getTableDefn(table)) return '\n'.join(out) - def toUpgradeDowngradePython(self, indent=' '): - ''' Assume model is most current and database is out-of-date. ''' + def genB2AMigration(self, indent=' '): + '''Generate a migration from B to A. + + Was: toUpgradeDowngradePython + Assume model (A) is most current and database (B) is out-of-date. + ''' + decls = ['from migrate.changeset import schema', 'meta = MetaData()'] for table in self._get_tables( missingA=True,missingB=True,modified=True ): - decls.extend(self.getTableDefn(table)) + decls.extend(self._getTableDefn(table)) upgradeCommands, downgradeCommands = [], [] for tableName in self.diff.tables_missing_from_A: @@ -175,16 +193,21 @@ class ModelGenerator(object): '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in downgradeCommands])) def _db_can_handle_this_change(self,td): + """Check if the database can handle going from B to A.""" + if (td.columns_missing_from_B and not td.columns_missing_from_A and not td.columns_different): - # Even sqlite can handle this. + # Even sqlite can handle column additions. return True else: return not self.engine.url.drivername.startswith('sqlite') - def applyModel(self): - """Apply model to current database.""" + def runB2A(self): + """Goes from B to A. + + Was: applyModel. Apply model (A) to current database (B). + """ meta = sqlalchemy.MetaData(self.engine) @@ -200,9 +223,9 @@ class ModelGenerator(object): dbTable = self.diff.metadataB.tables[tableName] td = self.diff.tables_different[tableName] - + if self._db_can_handle_this_change(td): - + for col in td.columns_missing_from_B: modelTable.columns[col].create() for col in td.columns_missing_from_A: diff --git a/migrate/versioning/schema.py b/migrate/versioning/schema.py index 1085538..e4d9365 100644 --- a/migrate/versioning/schema.py +++ b/migrate/versioning/schema.py @@ -71,7 +71,7 @@ class ControlledSchema(object): def changeset(self, version=None): """API to Changeset creation. - + Uses self.version for start version and engine.name to get database name. """ @@ -117,7 +117,7 @@ class ControlledSchema(object): diff = schemadiff.getDiffOfModelAgainstDatabase( model, self.engine, excludeTables=[self.repository.version_table] ) - genmodel.ModelGenerator(diff,self.engine).applyModel() + genmodel.ModelGenerator(diff,self.engine).runB2A() self.update_repository_table(self.version, int(self.repository.latest)) @@ -217,4 +217,4 @@ class ControlledSchema(object): diff = schemadiff.getDiffOfModelAgainstDatabase( MetaData(), engine, excludeTables=[repository.version_table] ) - return genmodel.ModelGenerator(diff, engine, declarative).toPython() + return genmodel.ModelGenerator(diff, engine, declarative).genBDefinition() diff --git a/migrate/versioning/schemadiff.py b/migrate/versioning/schemadiff.py index 17c2d8e..77661d3 100644 --- a/migrate/versioning/schemadiff.py +++ b/migrate/versioning/schemadiff.py @@ -39,11 +39,11 @@ class ColDiff(object): Container for differences in one :class:`~sqlalchemy.schema.Column` between two :class:`~sqlalchemy.schema.Table` instances, ``A`` and ``B``. - + .. attribute:: col_A The :class:`~sqlalchemy.schema.Column` object for A. - + .. attribute:: col_B The :class:`~sqlalchemy.schema.Column` object for B. @@ -51,15 +51,15 @@ class ColDiff(object): .. attribute:: type_A The most generic type of the :class:`~sqlalchemy.schema.Column` - object in A. - + object in A. + .. attribute:: type_B The most generic type of the :class:`~sqlalchemy.schema.Column` - object in A. - + object in A. + """ - + diff = False def __init__(self,col_A,col_B): @@ -87,10 +87,10 @@ class ColDiff(object): if not (A is None or B is None) and A!=B: self.diff=True return - + def __nonzero__(self): return self.diff - + class TableDiff(object): """ Container for differences in one :class:`~sqlalchemy.schema.Table` @@ -101,12 +101,12 @@ class TableDiff(object): A sequence of column names that were found in B but weren't in A. - + .. attribute:: columns_missing_from_B A sequence of column names that were found in A but weren't in B. - + .. attribute:: columns_different A dictionary containing information about columns that were @@ -126,7 +126,7 @@ class TableDiff(object): self.columns_missing_from_B or self.columns_different ) - + class SchemaDiff(object): """ Compute the difference between two :class:`~sqlalchemy.schema.MetaData` @@ -139,34 +139,34 @@ class SchemaDiff(object): The length of a :class:`SchemaDiff` will give the number of changes found, enabling it to be used much like a boolean in expressions. - + :param metadataA: First :class:`~sqlalchemy.schema.MetaData` to compare. - + :param metadataB: Second :class:`~sqlalchemy.schema.MetaData` to compare. - + :param labelA: The label to use in messages about the first - :class:`~sqlalchemy.schema.MetaData`. - - :param labelB: + :class:`~sqlalchemy.schema.MetaData`. + + :param labelB: The label to use in messages about the second - :class:`~sqlalchemy.schema.MetaData`. - + :class:`~sqlalchemy.schema.MetaData`. + :param excludeTables: A sequence of table names to exclude. - + .. attribute:: tables_missing_from_A A sequence of table names that were found in B but weren't in A. - + .. attribute:: tables_missing_from_B A sequence of table names that were found in A but weren't in B. - + .. attribute:: tables_different A dictionary containing information about tables that were found @@ -195,26 +195,26 @@ class SchemaDiff(object): self.tables_missing_from_B = sorted( A_table_names - B_table_names - excludeTables ) - + self.tables_different = {} for table_name in A_table_names.intersection(B_table_names): td = TableDiff() - + A_table = metadataA.tables[table_name] B_table = metadataB.tables[table_name] - + A_column_names = set(A_table.columns.keys()) B_column_names = set(B_table.columns.keys()) td.columns_missing_from_A = sorted( B_column_names - A_column_names ) - + td.columns_missing_from_B = sorted( A_column_names - B_column_names ) - + td.columns_different = {} for col_name in A_column_names.intersection(B_column_names): @@ -226,7 +226,7 @@ class SchemaDiff(object): if cd: td.columns_different[col_name]=cd - + # XXX - index and constraint differences should # be checked for here @@ -237,7 +237,7 @@ class SchemaDiff(object): ''' Summarize differences. ''' out = [] column_template =' %%%is: %%r' % self.label_width - + for names,label in ( (self.tables_missing_from_A,self.labelA), (self.tables_missing_from_B,self.labelB), @@ -248,7 +248,7 @@ class SchemaDiff(object): label,', '.join(sorted(names)) ) ) - + for name,td in sorted(self.tables_different.items()): out.append( ' table with differences: %s' % name @@ -267,7 +267,7 @@ class SchemaDiff(object): out.append(' column with differences: %s' % name) out.append(column_template % (self.labelA,cd.col_A)) out.append(column_template % (self.labelB,cd.col_B)) - + if out: out.insert(0, 'Schema diffs:') return '\n'.join(out) diff --git a/migrate/versioning/script/py.py b/migrate/versioning/script/py.py index 35fe4aa..3a090d4 100644 --- a/migrate/versioning/script/py.py +++ b/migrate/versioning/script/py.py @@ -25,7 +25,7 @@ class PythonScript(base.BaseScript): @classmethod def create(cls, path, **opts): """Create an empty migration script at specified path - + :returns: :class:`PythonScript instance `""" cls.require_notfound(path) @@ -38,7 +38,7 @@ class PythonScript(base.BaseScript): def make_update_script_for_model(cls, engine, oldmodel, model, repository, **opts): """Create a migration script based on difference between two SA models. - + :param repository: path to migrate repository :param oldmodel: dotted.module.name:SAClass or SAClass object :param model: dotted.module.name:SAClass or SAClass object @@ -50,7 +50,7 @@ class PythonScript(base.BaseScript): :returns: Upgrade / Downgrade script :rtype: string """ - + if isinstance(repository, basestring): # oh dear, an import cycle! from migrate.versioning.repository import Repository @@ -61,12 +61,12 @@ class PythonScript(base.BaseScript): # Compute differences. diff = schemadiff.getDiffOfModelAgainstModel( - oldmodel, model, + oldmodel, excludeTables=[repository.version_table]) # TODO: diff can be False (there is no difference?) decls, upgradeCommands, downgradeCommands = \ - genmodel.ModelGenerator(diff,engine).toUpgradeDowngradePython() + genmodel.ModelGenerator(diff,engine).genB2AMigration() # Store differences into file. src = Template(opts.pop('templates_path', None)).get_script(opts.pop('templates_theme', None)) @@ -86,7 +86,7 @@ class PythonScript(base.BaseScript): @classmethod def verify_module(cls, path): """Ensure path is a valid script - + :param path: Script location :type path: string :raises: :exc:`InvalidScriptError ` @@ -101,7 +101,7 @@ class PythonScript(base.BaseScript): return module def preview_sql(self, url, step, **args): - """Mocks SQLAlchemy Engine to store all executed calls in a string + """Mocks SQLAlchemy Engine to store all executed calls in a string and runs :meth:`PythonScript.run ` :returns: SQL file @@ -119,7 +119,7 @@ class PythonScript(base.BaseScript): return go(url, step, **args) def run(self, engine, step): - """Core method of Script file. + """Core method of Script file. Exectues :func:`update` or :func:`downgrade` functions :param engine: SQLAlchemy Engine