220 lines
8.5 KiB
Python
220 lines
8.5 KiB
Python
"""
|
|
Schema differencing support.
|
|
"""
|
|
import logging
|
|
|
|
import sqlalchemy
|
|
from migrate.changeset import SQLA_06
|
|
|
|
|
|
# Logger named after this module.
log = logging.getLogger(name=__name__)
|
|
|
|
def getDiffOfModelAgainstDatabase(model, conn, excludeTables=None):
    """
    Return differences of model against database.

    :param model: Python model's metadata.
    :param conn: database connection; its schema is reflected and compared
        with ``model``.
    :param excludeTables: optional names of tables to leave out of the
        comparison.
    :return: :class:`SchemaDiff` object which will evaluate to
        :keyword:`True` if there are differences else :keyword:`False`.
    """
    diff = SchemaDiff(model, conn, excludeTables=excludeTables)
    return diff
|
|
|
|
|
|
def getDiffOfModelAgainstModel(oldmodel, model, conn, excludeTables=None):
    """
    Return differences of model against another model.

    :param oldmodel: metadata used as the baseline instead of reflecting
        the database.
    :param model: Python model's metadata.
    :param conn: database connection; its dialect is used to render the
        column declarations being compared.
    :param excludeTables: optional names of tables to leave out of the
        comparison.
    :return: :class:`SchemaDiff` object which will evaluate to
        :keyword:`True` if there are differences else :keyword:`False`.
    """
    diff = SchemaDiff(model, conn,
                      excludeTables=excludeTables,
                      oldmodel=oldmodel)
    return diff
|
|
|
|
|
|
class SchemaDiff(object):
    """
    Differences of model against database.

    The comparison runs as soon as the object is constructed.  Results are
    stored in these attributes:

    * ``tablesMissingInDatabase`` -- model tables with no counterpart in
      the reflected (or old) model;
    * ``tablesMissingInModel`` -- reflected tables with no counterpart in
      the model;
    * ``tablesWithDiff`` -- model tables whose columns differ;
    * ``colDiffs`` -- maps a table name to a
      ``(missingInDatabase, missingInModel, diffDecl)`` tuple of lists.

    The object is falsy when no differences were found (see ``__len__``).
    """

    def __init__(self, model, conn, excludeTables=None, oldmodel=None):
        """
        :param model: Python model's metadata
        :param conn: active database connection; its ``dialect`` is used to
            render column declarations.
        :param excludeTables: optional names of tables to skip on both
            sides of the comparison.
        :param oldmodel: optional metadata to compare against instead of
            reflecting the database through ``conn``.
        """
        self.model = model
        self.conn = conn
        if not excludeTables:
            # [] can't be default value in Python parameter
            excludeTables = []
        self.excludeTables = excludeTables
        # Test against None explicitly: the truthiness of a metadata object
        # must not decide whether the database gets reflected.
        if oldmodel is not None:
            self.reflected_model = oldmodel
        else:
            self.reflected_model = sqlalchemy.MetaData(conn, reflect=True)
        self.tablesMissingInDatabase, self.tablesMissingInModel, \
            self.tablesWithDiff = [], [], []
        self.colDiffs = {}
        self.compareModelToDatabase()

    def compareModelToDatabase(self):
        """
        Do actual comparison.

        Fills ``tablesMissingInDatabase``, ``tablesMissingInModel``,
        ``tablesWithDiff`` and ``colDiffs``.  Tables named in
        ``excludeTables`` are skipped on both sides.
        """
        # When SQLA_06 is false, column DDL is rendered by the dialect's
        # schemagenerator, which needs a connection.  Acquire one only in
        # that case and release it when the comparison is done, instead of
        # leaking it.
        connection = None
        if SQLA_06:
            get_column_specification = self._columnSpecificationFunc()
        else:
            connection = self.conn.contextual_connect()
            schemagenerator = self.conn.dialect.schemagenerator(
                self.conn.dialect, connection)
            get_column_specification = \
                schemagenerator.get_column_specification
        try:
            self._findTablesMissingInDatabase()
            self._compareReflectedTables(get_column_specification)
        finally:
            if connection is not None:
                connection.close()

    def _columnSpecificationFunc(self):
        """Return a callable rendering a column's DDL declaration (SQLA_06)."""
        from sqlalchemy.ext import compiler
        from sqlalchemy.schema import DDLElement

        class DefineColumn(DDLElement):
            def __init__(self, col):
                self.col = col

        @compiler.compiles(DefineColumn)
        def compile(elem, ddlCompiler, **kw):
            return ddlCompiler.get_column_specification(elem.col)

        dialect = self.conn.dialect

        def get_column_specification(col):
            return str(DefineColumn(col).compile(dialect=dialect))

        return get_column_specification

    def _findTablesMissingInDatabase(self):
        """Record model tables that have no reflected counterpart."""
        for modelName, modelTable in self.model.tables.items():
            if modelName in self.excludeTables:
                continue
            if self.reflected_model.tables.get(modelName, None) is None:
                self.tablesMissingInDatabase.append(modelTable)

    def _compareReflectedTables(self, get_column_specification):
        """Record reflected tables missing in the model; compare the rest."""
        for reflectedName, reflectedTable in \
                self.reflected_model.tables.items():
            if reflectedName in self.excludeTables:
                continue
            modelTable = self.model.tables.get(reflectedName, None)
            if modelTable is None:
                self.tablesMissingInModel.append(reflectedTable)
            else:
                self._compareColumns(modelTable, reflectedTable,
                                     get_column_specification)

    def _compareColumns(self, modelTable, reflectedTable,
                        get_column_specification):
        """Record column differences between two versions of one table."""
        # Find missing columns in database.
        for modelCol in modelTable.columns:
            if reflectedTable.columns.get(modelCol.name, None) is None:
                self.storeColumnMissingInDatabase(modelTable, modelCol)

        # Find missing columns in model and compare columns on both sides.
        # TODO: no test coverage here? (mrb)
        for databaseCol in reflectedTable.columns:
            modelCol = modelTable.columns.get(databaseCol.name, None)
            if modelCol is None:
                self.storeColumnMissingInModel(modelTable, databaseCol)
                continue
            modelDecl = get_column_specification(modelCol)
            databaseDecl = get_column_specification(databaseCol)
            if modelDecl == databaseDecl:
                continue
            # Unfortunately, sometimes the database decl won't quite match
            # the model, even though they're the same: accept the pair when
            # one type class derives from the other and nullability agrees.
            mc = modelCol.type.__class__
            dc = databaseCol.type.__class__
            if (issubclass(mc, dc) or issubclass(dc, mc)) \
                    and modelCol.nullable == databaseCol.nullable:
                continue
            self.storeColumnDiff(modelTable, modelCol, databaseCol,
                                 modelDecl, databaseDecl)

    def __str__(self):
        """Summarize differences."""
        out = []
        if self.tablesMissingInDatabase:
            out.append(' tables missing in database: %s' % ', '.join(
                [table.name for table in self.tablesMissingInDatabase]))
        if self.tablesMissingInModel:
            out.append(' tables missing in model: %s' % ', '.join(
                [table.name for table in self.tablesMissingInModel]))
        if self.tablesWithDiff:
            out.append(' tables with differences: %s' % ', '.join(
                [table.name for table in self.tablesWithDiff]))

        if not out:
            return 'No schema diffs'
        out.insert(0, 'Schema diffs:')
        out.extend(self._colDiffDetails())
        return '\n'.join(out)

    def _colDiffDetails(self):
        """Return the per-table column difference lines used by __str__."""
        colout = []
        for table in self.tablesWithDiff:
            tableName = table.name
            missingInDatabase, missingInModel, diffDecl = \
                self.colDiffs[tableName]
            if missingInDatabase:
                colout.append(' %s missing columns in database: %s' % (
                    tableName,
                    ', '.join([col.name for col in missingInDatabase])))
            if missingInModel:
                colout.append(' %s missing columns in model: %s' % (
                    tableName,
                    ', '.join([col.name for col in missingInModel])))
            if diffDecl:
                # Adjacent literals rather than a backslash inside the
                # string, so source indentation cannot leak into the text.
                colout.append(
                    ' %s with different declaration of columns'
                    ' in database: %s' % (tableName, str(diffDecl)))
        return colout

    def __len__(self):
        """
        Used in bool evaluation, return of 0 means no diffs.
        """
        return len(self.tablesMissingInDatabase) + \
            len(self.tablesMissingInModel) + len(self.tablesWithDiff)

    def _colDiffsFor(self, table):
        """Mark *table* as differing and return its three colDiffs lists."""
        if table not in self.tablesWithDiff:
            self.tablesWithDiff.append(table)
        return self.colDiffs.setdefault(table.name, ([], [], []))

    def storeColumnMissingInDatabase(self, table, col):
        """Record model column *col* of *table* as absent from the database."""
        self._colDiffsFor(table)[0].append(col)

    def storeColumnMissingInModel(self, table, col):
        """Record database column *col* of *table* as absent from the model."""
        self._colDiffsFor(table)[1].append(col)

    def storeColumnDiff(self, table, modelCol, databaseCol, modelDecl,
                        databaseDecl):
        """Record a column whose declaration differs between both sides."""
        self._colDiffsFor(table)[2].append(
            (modelCol, databaseCol, modelDecl, databaseDecl))
|