2008-08-28 01:17:44 +00:00

359 lines
13 KiB
Python

import re
import sqlalchemy
from migrate.changeset.databases.visitor import get_engine_visitor
# Public helper functions exported by ``import *``. The Changeset* mixin
# classes are not exported; _patch() grafts them onto SQLAlchemy's classes.
__all__ = [
    'create_column',
    'drop_column',
    'alter_column',
    'rename_table',
    'rename_index',
]
def create_column(column, table=None, *p, **k):
    """Add *column* to the database.

    When *table* is given, delegates to ``table.create_column(column, ...)``;
    otherwise delegates to ``column.create(...)``. Extra positional and
    keyword arguments are passed through unchanged.
    """
    if table is None:
        return column.create(*p, **k)
    return table.create_column(column, *p, **k)
def drop_column(column, table=None, *p, **k):
    """Remove *column* from the database.

    When *table* is given, delegates to ``table.drop_column(column, ...)``;
    otherwise delegates to ``column.drop(...)``. Extra positional and keyword
    arguments are passed through unchanged.
    """
    if table is None:
        return column.drop(*p, **k)
    return table.drop_column(column, *p, **k)
def _to_table(table, engine=None):
    """Coerce *table* into a ``sqlalchemy.Table``.

    A Table instance is returned as-is. Anything else is treated as a table
    name: a new Table of that name, with no columns, is created on a fresh
    MetaData. That MetaData is bound to *engine* when one is supplied.
    """
    if not isinstance(table, sqlalchemy.Table):
        metadata = sqlalchemy.MetaData()
        if engine is not None:
            metadata.bind = engine
        table = sqlalchemy.Table(table, metadata)
    return table
def _to_index(index, table=None, engine=None):
    """Coerce *index* into a ``sqlalchemy.Index``.

    An Index instance is returned unchanged, and *table* / *engine* are then
    ignored. Otherwise *index* is taken as an index name. In that case a
    table (a Table object or a table name, resolved with _to_table) must be
    supplied; the new column-less Index is attached to it.
    """
    if isinstance(index, sqlalchemy.Index):
        return index
    owner = _to_table(table, engine)
    idx = sqlalchemy.Index(index)
    idx.table = owner
    return idx
def rename_table(table, name, engine=None):
    """Rename a table in the database.

    :param table: a Table object, or the table's current name
    :param name: the new table name
    :param engine: engine to bind to when *table* is given by name
    """
    _to_table(table, engine).rename(name)
def rename_index(index, name, table=None, engine=None):
    """Rename an index in the database.

    :param index: an Index object, or the index's current name
    :param name: the new index name
    :param table: Table object or table name; only needed when *index* is a name
    :param engine: engine to bind to; only needed when *table* is a name
    """
    _to_index(index, table, engine).rename(name)
def _engine_run_visitor(engine,visitorcallable,element,**kwargs):
conn = engine.connect()
try:
element.accept_schema_visitor(visitorcallable(engine.dialect,connection=conn))
finally:
conn.close()
def alter_column(*p,**k):
    """Alter a column: ALTER TABLE ... ALTER COLUMN.

    Positional and keyword arguments are passed to _ColumnDelta, which
    accepts these layouts:

    * ``alter_column(old_column, new_column, ...)`` - apply the differences
    * ``alter_column(column, [name, [type]], ...)`` - change the given properties
    * ``alter_column(current_name, [name, [type]], table=..., ...)``

    :param table: keyword; the Table object (or table name) that owns the
        column. Defaults to ``.table`` of the first positional argument when
        that is a Column object.
    :param engine: keyword; the engine to run the ALTER on. Defaults to
        ``table.bind``.

    Any other keywords name column properties to change. If the first
    positional argument is a Column object, that object is updated in place
    afterwards to reflect the change.

    :raises TypeError: if ``table`` is omitted and no Column object is
        given, or if ``engine`` is omitted and ``table`` has no ``bind``
        attribute (e.g. it is a plain name).
    """
    if p and isinstance(p[0],sqlalchemy.Column):
        col = p[0]
    else:
        col = None
    if 'table' not in k:
        if col is None:
            # Previously this crashed with an opaque AttributeError on None.table
            raise TypeError("alter_column(): 'table' is required when the "
                            "column is not given as a Column object")
        k['table'] = col.table
    if 'engine' not in k:
        table = k['table']
        if not hasattr(table,'bind'):
            # A table given by name cannot supply an engine by itself
            raise TypeError("alter_column(): 'engine' is required when the "
                            "table is given by name")
        k['engine'] = table.bind
    engine = k['engine']

    delta = _ColumnDelta(*p,**k)
    visitorcallable = get_engine_visitor(engine,'schemachanger')
    _engine_run_visitor(engine,visitorcallable,delta)

    # Mirror the ALTER on the in-memory Column object, if we were given one
    if col is not None:
        # Special case: on rename, also change the column's key (used by SA
        # as table.c.[key]) when the key was not set explicitly.
        #
        # This fails if the key was explicit AND equal to the column name:
        # the key is then renamed when it shouldn't be. Not much we can do
        # about it.
        if 'name' in delta and col.name == col.key:
            newname = delta['name']
            del col.table.c[col.key]
            col.key = newname
            col.table.c[col.key] = col
        # Copy every changed attribute (including the new name) onto the column
        for key,val in delta.items():
            setattr(col,key,val)
def _normalize_table(column,table):
if table is not None:
if table is not column.table:
# This is a bit of a hack: we end up with dupe PK columns here
pk_names = map(lambda c: c.name, table.primary_key)
if column.primary_key and pk_names.count(column.name):
index = pk_names.index(column_name)
del table.primary_key[index]
table.append_column(column)
return column.table
class _WrapRename(object):
def __init__(self,item,name):
self.item = item
self.name = name
def accept_schema_visitor(self,visitor):
if isinstance(self.item,sqlalchemy.Table):
suffix = 'table'
elif isinstance(self.item,sqlalchemy.Column):
suffix = 'column'
elif isinstance(self.item,sqlalchemy.Index):
suffix = 'index'
funcname = 'visit_%s'%suffix
func = getattr(visitor,funcname)
param = self.item,self.name
return func(param)
class _ColumnDelta(dict):
    """Extracts the differences between two columns/column-parameters.

    The instance itself is a dict mapping each alterable attribute name (see
    ``diff_keys``) to its new value. It also records:

    * ``current_name`` - the column's name before the change
    * ``_table`` - a Table object or a table name (see ``table`` and
      ``table_name``)
    * ``result_column`` - only when built from Column object(s): a copy of
      the (first) column with the changes applied
    """
    def __init__(self,*p,**k):
        """Extract ALTER-able differences from two columns.

        May receive parameters arranged in several different ways. In each
        layout, up to two positional arguments after the column argument(s)
        are taken as the new ``name`` and then the new ``type`` (see
        _init_normalize_params). Explicit keywords win over those.

        * old_column_object, new_column_object, *parameters
            Identifies attributes that differ between the two columns.
            Parameters specified outside of either column always take
            precedence over the column differences.
        * column_object, *parameters
            Parameters specified are changed. The table is taken from the
            ``table`` keyword, or else from the column object.
            If a ``current_name`` keyword is given, it is the column's
            current name, and the name is changed to column_object.name
            (unless a new ``name`` is given explicitly). Otherwise the
            current name is column_object.name.
        * current_name, *parameters, table=...
            ``table`` is a required keyword (KeyError if missing) and may be
            either a Table object or a table name.

        Keywords that are not in ``diff_keys`` are dropped by _set_diffs.
        """
        # Things are initialized differently depending on how many column
        # parameters are given. Figure out how many and call the appropriate
        # method.
        if len(p) >= 1 and isinstance(p[0],sqlalchemy.Column):
            # At least one column specified
            if len(p) >= 2 and isinstance(p[1],sqlalchemy.Column):
                # Two columns specified
                func = self._init_2col
            else:
                # Exactly one column specified
                func = self._init_1col
        else:
            # Zero columns specified
            func = self._init_0col
        diffs = func(*p,**k)
        self._set_diffs(diffs)
    # Column attributes that can be altered; only these keys are kept in self
    diff_keys = ('name','type','nullable','default','server_default','primary_key','foreign_key')
    def _get_table_name(self):
        """Return the table's name, whether _table is a name or a Table."""
        if isinstance(self._table,basestring):
            ret = self._table
        else:
            ret = self._table.name
        return ret
    table_name = property(_get_table_name)
    def _get_table(self):
        """Return the Table object, or None if only a table name is known."""
        if isinstance(self._table,basestring):
            ret = None
        else:
            ret = self._table
        return ret
    table = property(_get_table)
    def _init_0col(self,current_name,*p,**k):
        """Initialise from a column name; ``table`` keyword is required.

        Returns the requested changes (keyword dict) for _set_diffs.
        """
        p,k = self._init_normalize_params(p,k)
        table = k.pop('table')
        self.current_name = current_name
        self._table = table
        return k
    def _init_1col(self,col,*p,**k):
        """Initialise from one Column object plus the requested changes.

        Returns the requested changes (keyword dict) for _set_diffs.
        """
        p,k = self._init_normalize_params(p,k)
        # Explicit table keyword first, falling back to the column's table
        self._table = k.pop('table',None) or col.table
        self.result_column = col.copy()
        if 'current_name' in k:
            # Renamed: the column object already carries the new name
            self.current_name = k.pop('current_name')
            k.setdefault('name',col.name)
        else:
            self.current_name = col.name
        return k
    def _init_2col(self,start_col,end_col,*p,**k):
        """Initialise from an old and a new Column object.

        Attributes that differ between the two are added to the changes
        unless the caller already specified them explicitly (setdefault).
        Returns the changes (keyword dict) for _set_diffs.
        """
        p,k = self._init_normalize_params(p,k)
        self.result_column = start_col.copy()
        self._table = k.pop('table',None) or start_col.table or end_col.table
        self.current_name = start_col.name
        # Every diff key except 'type', which needs the looser comparison below.
        # NOTE(review): getattr(..., None) makes a missing attribute compare
        # equal on both sides; confirm the installed SQLAlchemy Column really
        # exposes 'foreign_key', or that key can never be detected here.
        for key in ('name','nullable','default','server_default','primary_key','foreign_key'):
            val = getattr(end_col,key,None)
            if getattr(start_col,key,None) != val:
                k.setdefault(key,val)
        if not self.column_types_eq(start_col.type,end_col.type):
            k.setdefault('type',end_col.type)
        return k
    def _init_normalize_params(self,p,k):
        """Move up to two leading positional args into k as 'name', then 'type'.

        Uses setdefault, so explicit keywords win. Updates k in place and
        returns ``(remaining_positionals, k)``; the _init_* callers do not
        use the remaining positionals.
        """
        p = list(p)
        if len(p):
            k.setdefault('name',p.pop(0))
        if len(p):
            k.setdefault('type',p.pop(0))
        # TODO: sequences? FKs?
        return p,k
    def _set_diffs(self,diffs):
        """Store the recognised changes (keys in diff_keys) in self.

        Each stored change is also applied to result_column when one exists
        (it is only set by the 1- and 2-column initialisers).
        """
        for key in self.diff_keys:
            if key in diffs:
                self[key] = diffs[key]
                if getattr(self,'result_column',None) is not None:
                    setattr(self.result_column,key,diffs[key])
    def column_types_eq(self,this,that):
        """Return True if two column types are considered equal.

        Types are equal when either is an instance of the other's class.
        When *that* is a String, their ``length`` values must also match
        (a missing length counts as None).
        """
        ret = isinstance(this,that.__class__)
        ret = ret or isinstance(that,this.__class__)
        # String length is a special case
        if ret and isinstance(that,sqlalchemy.types.String):
            ret = (getattr(this,'length',None) == getattr(that,'length',None))
        return ret
    def accept_schema_visitor(self,visitor):
        """Dispatch this delta to ``visitor.visit_column``."""
        return visitor.visit_column(self)
class ChangesetTable(object):
    """Changeset extensions to SQLAlchemy tables.

    _patch() appends this class to ``sqlalchemy.schema.Table.__bases__``, so
    these methods become available on every Table.
    """

    def create_column(self,column):
        """Create a column in the database.

        The column parameter may be a Column object or the name of a column
        already defined on this table (looked up in ``self.c``).
        """
        if not isinstance(column,sqlalchemy.Column):
            # It's a column name
            column = getattr(self.c,str(column))
        column.create(table=self)

    def drop_column(self,column):
        """Drop a column, given its name or a Column object.

        If a name is given that is not defined on this table, a placeholder
        Column with that name is dropped instead, because only the name is
        needed to drop a column.
        """
        if not isinstance(column,sqlalchemy.Column):
            # It's a column name
            name = str(column)
            try:
                # No default argument here: a missing column must raise
                # AttributeError so the fallback below runs. With a default
                # of None, ``column`` silently became None and .drop() crashed.
                column = getattr(self.c,name)
            except AttributeError:
                # That column isn't part of the table. We don't need its entire
                # definition to drop the column, just its name, so create a
                # dummy column with the same name.
                column = sqlalchemy.Column(name)
        column.drop(table=self)

    def _meta_key(self):
        """Return the key under which this table is stored in metadata.tables."""
        return sqlalchemy.schema._get_table_key(self.name,self.schema)

    def deregister(self):
        """Remove this table from its metadata (no-op if it is not registered)."""
        key = self._meta_key()
        meta = self.metadata
        if key in meta.tables:
            del meta.tables[key]

    def rename(self,name,*args,**kwargs):
        """Rename this table.

        This changes both the database name and the name of this Python
        object. Extra arguments are forwarded to _engine_run_visitor().
        """
        engine = self.bind
        visitorcallable = get_engine_visitor(engine,'schemachanger')
        param = _WrapRename(self,name)
        _engine_run_visitor(engine,visitorcallable,param,*args,**kwargs)
        # Fix metadata registration: remove the entry under the old key, then
        # re-attach via _set_parent (expected to register the new key).
        meta = self.metadata
        self.deregister()
        self.name = name
        self._set_parent(meta)

    def _get_fullname(self):
        """Return "schema.name", or just "name" when there is no schema.

        Computed on every access, so it stays correct after rename().
        """
        # Copied from Table constructor
        if self.schema is not None:
            ret = "%s.%s"%(self.schema,self.name)
        else:
            ret = self.name
        return ret
    # The setter discards assignments, so nothing can override the computed value
    fullname = property(_get_fullname,(lambda self,val: None))
class ChangesetColumn(object):
    """Changeset extensions to SQLAlchemy columns (mixed into Column by _patch())."""

    def alter(self, *p, **k):
        """Alter this column's definition: ALTER TABLE ALTER COLUMN.

        Accepts either a replacement Column object or the properties to
        change. For example, the following are equivalent::

            col.alter(Column('myint',Integer,nullable=False))
            col.alter('myint',Integer,nullable=False)
            col.alter(name='myint',type=Integer,nullable=False)

        Column name, type, default, and nullable may be changed here. Note
        that for column defaults, only PassiveDefaults are managed by the
        database; changing others doesn't make sense.

        The ``table`` keyword defaults to this column's table and ``engine``
        defaults to that table's ``bind``. The call is then delegated to
        alter_column().
        """
        options = k
        if 'table' not in options:
            options['table'] = self.table
        owner = options['table']
        if 'engine' not in options:
            options['engine'] = owner.bind
        return alter_column(self, *p, **options)

    def create(self, table=None, *args, **kwargs):
        """Create this column in the database. Assumes the given table exists.

        ALTER TABLE ADD COLUMN, for most databases. Returns this column.
        """
        target = _normalize_table(self, table)
        bind = target.bind
        generator = get_engine_visitor(bind, 'columngenerator')
        bind._run_visitor(generator, self, *args, **kwargs)
        return self

    def drop(self, table=None, *args, **kwargs):
        """Drop this column from the database, leaving its table intact.

        ALTER TABLE DROP COLUMN, for most databases. Returns this column.
        """
        target = _normalize_table(self, table)
        bind = target.bind
        dropper = get_engine_visitor(bind, 'columndropper')

        def build_dropper(dialect, conn):
            # Unlike create(), the dropper visitor is built from the
            # connection alone; the dialect argument is discarded.
            return dropper(conn)

        bind._run_visitor(build_dropper, self, *args, **kwargs)
        # NOTE: the column is not removed from the in-memory Table object here.
        return self
class ChangesetIndex(object):
    """Changeset extensions to SQLAlchemy indexes (mixed into Index by _patch())."""

    def rename(self, name, *args, **kwargs):
        """Rename this index.

        Issues the rename against the database through the index table's
        bound engine, then updates this Python object's name. Extra
        arguments are forwarded to _engine_run_visitor().
        """
        bind = self.table.bind
        changer = get_engine_visitor(bind, 'schemachanger')
        _engine_run_visitor(bind, changer, _WrapRename(self, name), *args, **kwargs)
        self.name = name
def _patch():
    """Graft the Changeset* mixins onto SQLAlchemy's schema classes.

    This is the one place that patches SQLAlchemy's internals: it appends a
    mixin to the ``__bases__`` of sqlalchemy.schema.Table, Column and Index.
    Because it changes the classes themselves, every instance gains the
    extra methods.
    """
    schema = sqlalchemy.schema
    grafts = (
        (schema.Table, ChangesetTable),
        (schema.Column, ChangesetColumn),
        (schema.Index, ChangesetIndex),
    )
    for target, mixin in grafts:
        target.__bases__ += (mixin,)


# Applied when this module is imported.
_patch()