This commit is contained in:
iElectric 2009-08-10 12:24:01 +02:00
commit ccf7be3372
38 changed files with 730 additions and 484 deletions

13
.hgignore Normal file
View File

@ -0,0 +1,13 @@
syntax: glob
*.pyc
*data/*
*build/*
*dist/*
*ez_setup.py
*.egg/*
*egg-info/*
*bin/*
*include/*
*lib/*
sa06/*

View File

@ -1,14 +1,15 @@
0.5.5 0.6.0
----- -----
- added option to define custom templates through option ``--templates_path`` and ``--templates_theme``, read more in :ref:`tutorial section <custom-templates>` - added option to define custom templates through option ``--templates_path`` and ``--templates_theme``, read more in :ref:`tutorial section <custom-templates>`
- url parameter can also be an Engine instance (this usage is discouraged though sometimes necessary) - use Python logging for output, can be shut down by passing ``--disable_logging`` to :func:`migrate.versioning.shell.main`
- `url` parameter can also be an :class:`Engine` instance (this usage is discouraged though sometimes necessary)
- added support for SQLAlchemy 0.6 (missing oracle and firebird) by Michael Bayer - added support for SQLAlchemy 0.6 (missing oracle and firebird) by Michael Bayer
- alter, create, drop column / rename table / rename index constructs now accept `alter_metadata` parameter. If True, it will modify Column/Table objects according to changes. Otherwise, everything will be untouched. - alter, create, drop column / rename table / rename index constructs now accept `alter_metadata` parameter. If True, it will modify Column/Table objects according to changes. Otherwise, everything will be untouched.
- complete refactoring of :class:`~migrate.changeset.schema.ColumnDelta` (fixes issue 23) - complete refactoring of :class:`~migrate.changeset.schema.ColumnDelta` (fixes issue 23)
- added support for :ref:`firebird <firebird-d>` - added support for :ref:`firebird <firebird-d>`
- fixed bug when :meth:`Column.alter <migrate.changeset.schema.ChangesetColumn.alter>`\(server_default='string') was not properly set - fixed bug when :meth:`Column.alter <migrate.changeset.schema.ChangesetColumn.alter>`\(server_default='string') was not properly set
- `server_defaults` passed to :meth:`Column.create <migrate.changeset.schema.ChangesetColumn.create>` are now issued correctly - `server_defaults` passed to :meth:`Column.create <migrate.changeset.schema.ChangesetColumn.create>` are now issued correctly
- added `populate_default` bool argument to :meth:`Column.create <migrate.changeset.schema.ChangesetColumn.create>` which issues corresponding UPDATE statements to set defaults after column creation
- constraints passed to :meth:`Column.create <migrate.changeset.schema.ChangesetColumn.create>` are correctly interpreted (``ALTER TABLE ADD CONSTRAINT`` is issued after ``ATLER TABLE ADD COLUMN``) - constraints passed to :meth:`Column.create <migrate.changeset.schema.ChangesetColumn.create>` are correctly interpreted (``ALTER TABLE ADD CONSTRAINT`` is issued after ``ATLER TABLE ADD COLUMN``)
- :meth:`Column.create <migrate.changeset.schema.ChangesetColumn.create>` accepts `primary_key_name`, `unique_name` and `index_name` as string value which is used as contraint name when adding a column - :meth:`Column.create <migrate.changeset.schema.ChangesetColumn.create>` accepts `primary_key_name`, `unique_name` and `index_name` as string value which is used as contraint name when adding a column
- Constraint classes have `cascade=True` keyword argument to issue ``DROP CASCADE`` where supported - Constraint classes have `cascade=True` keyword argument to issue ``DROP CASCADE`` where supported
@ -19,10 +20,11 @@
- majoy update to documentation - majoy update to documentation
- :ref:`dialect support <dialect-support>` table was added to documentation - :ref:`dialect support <dialect-support>` table was added to documentation
.. _backwards-055: .. _backwards-06:
**Backward incompatible changes**: **Backward incompatible changes**:
- :func:`api.test` and schema comparison functions now all accept `url` as first parameter and `repository` as second.
- python upgrade/downgrade scripts do not import `migrate_engine` magically, but recieve engine as the only parameter to function (eg. ``def upgrade(migrate_engine):``) - python upgrade/downgrade scripts do not import `migrate_engine` magically, but recieve engine as the only parameter to function (eg. ``def upgrade(migrate_engine):``)
- :meth:`Column.alter <migrate.changeset.schema.ChangesetColumn.alter>` does not accept `current_name` anymore, it extracts name from the old column. - :meth:`Column.alter <migrate.changeset.schema.ChangesetColumn.alter>` does not accept `current_name` anymore, it extracts name from the old column.

View File

@ -39,8 +39,8 @@ Given a standard SQLAlchemy table::
:meth:`Create a column <ChangesetColumn.create>`:: :meth:`Create a column <ChangesetColumn.create>`::
col = Column('col1', String) col = Column('col1', String, default='foobar')
col.create(table) col.create(table, populate_default=True)
# Column is added to table based on its name # Column is added to table based on its name
assert col is table.c.col1 assert col is table.c.col1
@ -72,7 +72,7 @@ Given a standard SQLAlchemy table::
.. note:: .. note::
Since version ``0.5.5`` you can pass primary_key_name, index_name and unique_name to column.create method to issue ALTER TABLE ADD CONSTRAINT after changing the column. Note for multi columns constraints and other advanced configuration, check :ref:`constraint tutorial <constraint-tutorial>`. Since version ``0.6.0`` you can pass primary_key_name, index_name and unique_name to column.create method to issue ALTER TABLE ADD CONSTRAINT after changing the column. Note for multi columns constraints and other advanced configuration, check :ref:`constraint tutorial <constraint-tutorial>`.
.. _table-rename: .. _table-rename:

View File

@ -29,7 +29,7 @@
.. warning:: .. warning::
Version **0.5.5** breaks backward compatability, please read :ref:`changelog <backwards-055>` for more info. Version **0.6.0** breaks backward compatability, please read :ref:`changelog <backwards-06>` for more info.
Download and Development Download and Development

View File

@ -173,7 +173,7 @@ class UniqueConstraint(ConstraintChangeset, schema.UniqueConstraint):
:type table: Table instance :type table: Table instance
:type cols: strings or Column instances :type cols: strings or Column instances
.. versionadded:: 0.5.5 .. versionadded:: 0.6.0
""" """
__migrate_visit_name__ = 'migrate_unique_constraint' __migrate_visit_name__ = 'migrate_unique_constraint'

View File

@ -485,12 +485,16 @@ class ChangesetColumn(object):
:param primary_key_name: Creates :class:\ :param primary_key_name: Creates :class:\
`~migrate.changeset.constraint.PrimaryKeyConstraint` on this column. `~migrate.changeset.constraint.PrimaryKeyConstraint` on this column.
:param alter_metadata: If True, column will be added to table object. :param alter_metadata: If True, column will be added to table object.
:param populate_default: If True, created column will be \
populated with defaults
:type table: Table instance :type table: Table instance
:type index_name: string :type index_name: string
:type unique_name: string :type unique_name: string
:type primary_key_name: string :type primary_key_name: string
:type alter_metadata: bool :type alter_metadata: bool
:type populate_default: bool
""" """
self.populate_default = kwargs.pop('populate_default', False)
self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA) self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
self.index_name = index_name self.index_name = index_name
self.unique_name = unique_name self.unique_name = unique_name
@ -503,6 +507,11 @@ class ChangesetColumn(object):
engine = self.table.bind engine = self.table.bind
visitorcallable = get_engine_visitor(engine, 'columngenerator') visitorcallable = get_engine_visitor(engine, 'columngenerator')
engine._run_visitor(visitorcallable, self, *args, **kwargs) engine._run_visitor(visitorcallable, self, *args, **kwargs)
if self.populate_default and self.default is not None:
stmt = table.update().values({self: engine._execute_default(self.default)})
engine.execute(stmt)
return self return self
def drop(self, table=None, *args, **kwargs): def drop(self, table=None, *args, **kwargs):

View File

@ -1,12 +1,17 @@
""" """
This module provides an external API to the versioning system. This module provides an external API to the versioning system.
.. versionchanged:: 0.4.5 .. versionchanged:: 0.6.0
:func:`migrate.versioning.api.test` and schema diff functions
changed order of positional arguments so all accept `url` and `repository`
as first arguments.
.. versionchanged:: 0.5.4
``--preview_sql`` displays source file when using SQL scripts. ``--preview_sql`` displays source file when using SQL scripts.
If Python script is used, it runs the action with mocked engine and If Python script is used, it runs the action with mocked engine and
returns captured SQL statements. returns captured SQL statements.
.. versionchanged:: 0.4.5 .. versionchanged:: 0.5.4
Deprecated ``--echo`` parameter in favour of new Deprecated ``--echo`` parameter in favour of new
:func:`migrate.versioning.util.construct_engine` behavior. :func:`migrate.versioning.util.construct_engine` behavior.
""" """
@ -23,12 +28,14 @@
import sys import sys
import inspect import inspect
import warnings import warnings
import logging
from migrate.versioning import (exceptions, repository, schema, version, from migrate.versioning import (exceptions, repository, schema, version,
script as script_) # command name conflict script as script_) # command name conflict
from migrate.versioning.util import catch_known_errors, construct_engine from migrate.versioning.util import catch_known_errors, construct_engine
log = logging.getLogger(__name__)
command_desc = { command_desc = {
'help': 'displays help on a given command', 'help': 'displays help on a given command',
'create': 'create an empty repository at the specified path', 'create': 'create an empty repository at the specified path',
@ -193,8 +200,8 @@ def downgrade(url, repository, version, **opts):
"Try 'upgrade' instead." "Try 'upgrade' instead."
return _migrate(url, repository, version, upgrade=False, err=err, **opts) return _migrate(url, repository, version, upgrade=False, err=err, **opts)
def test(repository, url, **opts): def test(url, repository, **opts):
"""%prog test REPOSITORY_PATH URL [VERSION] """%prog test URL REPOSITORY_PATH [VERSION]
Performs the upgrade and downgrade option on the given Performs the upgrade and downgrade option on the given
database. This is not a real test and may leave the database in a database. This is not a real test and may leave the database in a
@ -206,14 +213,14 @@ def test(repository, url, **opts):
script = repos.version(None).script() script = repos.version(None).script()
# Upgrade # Upgrade
print "Upgrading...", log.info("Upgrading...")
script.run(engine, 1) script.run(engine, 1)
print "done" log.info("done")
print "Downgrading...", log.info("Downgrading...")
script.run(engine, -1) script.run(engine, -1)
print "done" log.info("done")
print "Success" log.info("Success")
def version_control(url, repository, version=None, **opts): def version_control(url, repository, version=None, **opts):
@ -268,8 +275,8 @@ def manage(file, **opts):
Repository.create_manage_file(file, **opts) Repository.create_manage_file(file, **opts)
def compare_model_to_db(url, model, repository, **opts): def compare_model_to_db(url, repository, model, **opts):
"""%prog compare_model_to_db URL MODEL REPOSITORY_PATH """%prog compare_model_to_db URL REPOSITORY_PATH MODEL
Compare the current model (assumed to be a module level variable Compare the current model (assumed to be a module level variable
of type sqlalchemy.MetaData) against the current database. of type sqlalchemy.MetaData) against the current database.
@ -277,7 +284,7 @@ def compare_model_to_db(url, model, repository, **opts):
NOTE: This is EXPERIMENTAL. NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label """ # TODO: get rid of EXPERIMENTAL label
engine = construct_engine(url, **opts) engine = construct_engine(url, **opts)
print ControlledSchema.compare_model_to_db(engine, model, repository) return ControlledSchema.compare_model_to_db(engine, model, repository)
def create_model(url, repository, **opts): def create_model(url, repository, **opts):
@ -289,12 +296,11 @@ def create_model(url, repository, **opts):
""" # TODO: get rid of EXPERIMENTAL label """ # TODO: get rid of EXPERIMENTAL label
engine = construct_engine(url, **opts) engine = construct_engine(url, **opts)
declarative = opts.get('declarative', False) declarative = opts.get('declarative', False)
print ControlledSchema.create_model(engine, repository, declarative) return ControlledSchema.create_model(engine, repository, declarative)
# TODO: get rid of this? if we don't add back path param
@catch_known_errors @catch_known_errors
def make_update_script_for_model(url, oldmodel, model, repository, **opts): def make_update_script_for_model(url, repository, oldmodel, model, **opts):
"""%prog make_update_script_for_model URL OLDMODEL MODEL REPOSITORY_PATH """%prog make_update_script_for_model URL OLDMODEL MODEL REPOSITORY_PATH
Create a script changing the old Python model to the new (current) Create a script changing the old Python model to the new (current)
@ -303,12 +309,12 @@ def make_update_script_for_model(url, oldmodel, model, repository, **opts):
NOTE: This is EXPERIMENTAL. NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label """ # TODO: get rid of EXPERIMENTAL label
engine = construct_engine(url, **opts) engine = construct_engine(url, **opts)
print PythonScript.make_update_script_for_model( return PythonScript.make_update_script_for_model(
engine, oldmodel, model, repository, **opts) engine, oldmodel, model, repository, **opts)
def update_db_from_model(url, model, repository, **opts): def update_db_from_model(url, repository, model, **opts):
"""%prog update_db_from_model URL MODEL REPOSITORY_PATH """%prog update_db_from_model URL REPOSITORY_PATH MODEL
Modify the database to match the structure of the current Python Modify the database to match the structure of the current Python
model. This also sets the db_version number to the latest in the model. This also sets the db_version number to the latest in the
@ -329,27 +335,26 @@ def _migrate(url, repository, version, upgrade, err, **opts):
changeset = schema.changeset(version) changeset = schema.changeset(version)
for ver, change in changeset: for ver, change in changeset:
nextver = ver + changeset.step nextver = ver + changeset.step
print '%s -> %s... ' % (ver, nextver) log.info('%s -> %s... ', ver, nextver)
if opts.get('preview_sql'): if opts.get('preview_sql'):
if isinstance(change, PythonScript): if isinstance(change, PythonScript):
print change.preview_sql(url, changeset.step, **opts) log.info(change.preview_sql(url, changeset.step, **opts))
elif isinstance(change, SqlScript): elif isinstance(change, SqlScript):
print change.source() log.info(change.source())
elif opts.get('preview_py'): elif opts.get('preview_py'):
if not isinstance(change, PythonScript):
raise exceptions.UsageError("Python source can be only displayed"
" for python migration files")
source_ver = max(ver, nextver) source_ver = max(ver, nextver)
module = schema.repository.version(source_ver).script().module module = schema.repository.version(source_ver).script().module
funcname = upgrade and "upgrade" or "downgrade" funcname = upgrade and "upgrade" or "downgrade"
func = getattr(module, funcname) func = getattr(module, funcname)
if isinstance(change, PythonScript): log.info(inspect.getsource(func))
print inspect.getsource(func)
else:
raise UsageError("Python source can be only displayed"
" for python migration files")
else: else:
schema.runchange(ver, change, changeset.step) schema.runchange(ver, change, changeset.step)
print 'done' log.info('done')
def _migrate_version(schema, version, upgrade, err): def _migrate_version(schema, version, upgrade, err):

View File

@ -7,11 +7,13 @@
""" """
import sys import sys
import logging
import migrate import migrate
import sqlalchemy import sqlalchemy
log = logging.getLogger(__name__)
HEADER = """ HEADER = """
## File autogenerated by genmodel.py ## File autogenerated by genmodel.py
@ -140,7 +142,7 @@ class ModelGenerator(object):
upgradeCommands.append("%(table)s.create()" % {'table': tableName}) upgradeCommands.append("%(table)s.create()" % {'table': tableName})
downgradeCommands.append("%(table)s.drop()" % {'table': tableName}) downgradeCommands.append("%(table)s.drop()" % {'table': tableName})
pre_command = 'meta.bind(migrate_engine)' pre_command = ' meta.bind = migrate_engine'
return ( return (
'\n'.join(decls), '\n'.join(decls),

View File

@ -10,8 +10,8 @@ from migrate.versioning import exceptions
from migrate.versioning.config import * from migrate.versioning.config import *
from migrate.versioning.util import KeyedInstance from migrate.versioning.util import KeyedInstance
log = logging.getLogger(__name__)
log = logging.getLogger(__name__)
class Pathed(KeyedInstance): class Pathed(KeyedInstance):
""" """

View File

@ -5,8 +5,8 @@ import os
import shutil import shutil
import string import string
import logging import logging
from pkg_resources import resource_filename
from pkg_resources import resource_filename
from tempita import Template as TempitaTemplate from tempita import Template as TempitaTemplate
from migrate.versioning import exceptions, script, version, pathed, cfgparse from migrate.versioning import exceptions, script, version, pathed, cfgparse

View File

@ -1,6 +1,9 @@
""" """
Database schema version management. Database schema version management.
""" """
import sys
import logging
from sqlalchemy import (Table, Column, MetaData, String, Text, Integer, from sqlalchemy import (Table, Column, MetaData, String, Text, Integer,
create_engine) create_engine)
from sqlalchemy.sql import and_ from sqlalchemy.sql import and_
@ -13,6 +16,8 @@ from migrate.versioning.util import load_model
from migrate.versioning.version import VerNum from migrate.versioning.version import VerNum
log = logging.getLogger(__name__)
class ControlledSchema(object): class ControlledSchema(object):
"""A database under version control""" """A database under version control"""
@ -32,22 +37,17 @@ class ControlledSchema(object):
def load(self): def load(self):
"""Load controlled schema version info from DB""" """Load controlled schema version info from DB"""
tname = self.repository.version_table tname = self.repository.version_table
if not hasattr(self, 'table') or self.table is None:
try:
self.table = Table(tname, self.meta, autoload=True)
except (sa_exceptions.NoSuchTableError,
AssertionError):
# assertionerror is raised if no table is found in oracle db
raise exceptions.DatabaseNotControlledError(tname)
# TODO?: verify that the table is correct (# cols, etc.)
result = self.engine.execute(self.table.select(
self.table.c.repository_id == str(self.repository.id)))
try: try:
if not hasattr(self, 'table') or self.table is None:
self.table = Table(tname, self.meta, autoload=True)
result = self.engine.execute(self.table.select(
self.table.c.repository_id == str(self.repository.id)))
data = list(result)[0] data = list(result)[0]
except IndexError: except Exception:
raise exceptions.DatabaseNotControlledError(tname) cls, exc, tb = sys.exc_info()
raise exceptions.DatabaseNotControlledError, exc.message, tb
self.version = data['version'] self.version = data['version']
return data return data

View File

@ -1,9 +1,14 @@
""" """
Schema differencing support. Schema differencing support.
""" """
import logging
import sqlalchemy import sqlalchemy
from migrate.changeset import SQLA_06 from migrate.changeset import SQLA_06
log = logging.getLogger(__name__)
def getDiffOfModelAgainstDatabase(model, conn, excludeTables=None): def getDiffOfModelAgainstDatabase(model, conn, excludeTables=None):
""" """
Return differences of model against database. Return differences of model against database.

View File

@ -1,13 +1,12 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import logging import logging
from migrate.versioning.config import operations from migrate.versioning.config import operations
from migrate.versioning import pathed, exceptions from migrate.versioning import pathed, exceptions
log = logging.getLogger(__name__)
log = logging.getLogger(__name__)
class BaseScript(pathed.Pathed): class BaseScript(pathed.Pathed):
"""Base class for other types of scripts. """Base class for other types of scripts.

View File

@ -2,6 +2,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import shutil import shutil
import warnings
import logging
from StringIO import StringIO from StringIO import StringIO
import migrate import migrate
@ -12,6 +14,9 @@ from migrate.versioning.script import base
from migrate.versioning.util import import_path, load_model, construct_engine from migrate.versioning.util import import_path, load_model, construct_engine
log = logging.getLogger(__name__)
__all__ = ['PythonScript']
class PythonScript(base.BaseScript): class PythonScript(base.BaseScript):
"""Base for Python scripts""" """Base for Python scripts"""
@ -83,16 +88,11 @@ class PythonScript(base.BaseScript):
:param path: Script location :param path: Script location
:type path: string :type path: string
:raises: :exc:`InvalidScriptError <migrate.versioning.exceptions.InvalidScriptError>` :raises: :exc:`InvalidScriptError <migrate.versioning.exceptions.InvalidScriptError>`
:returns: Python module :returns: Python module
""" """
# Try to import and get the upgrade() func # Try to import and get the upgrade() func
try: module = import_path(path)
module = import_path(path)
except:
# If the script itself has errors, that's not our problem
raise
try: try:
assert callable(module.upgrade) assert callable(module.upgrade)
except Exception, e: except Exception, e:
@ -129,13 +129,15 @@ class PythonScript(base.BaseScript):
op = 'downgrade' op = 'downgrade'
else: else:
raise exceptions.ScriptError("%d is not a valid step" % step) raise exceptions.ScriptError("%d is not a valid step" % step)
funcname = base.operations[op] funcname = base.operations[op]
script_func = self._func(funcname)
func = self._func(funcname)
try: try:
func(engine) script_func(engine)
except TypeError: except TypeError:
print "upgrade/downgrade functions must accept engine parameter (since ver 0.5.5)" warnings.warn("upgrade/downgrade functions must accept engine"
" parameter (since version > 0.5.4)")
raise raise
@property @property
@ -148,8 +150,7 @@ class PythonScript(base.BaseScript):
return self._module return self._module
def _func(self, funcname): def _func(self, funcname):
try: if not hasattr(self.module, funcname):
return getattr(self.module, funcname) msg = "Function '%s' is not defined in this script"
except AttributeError:
msg = "The function %s is not defined in this script"
raise exceptions.ScriptError(msg % funcname) raise exceptions.ScriptError(msg % funcname)
return getattr(self.module, funcname)

View File

@ -1,14 +1,30 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import logging
import shutil
from migrate.versioning.script import base from migrate.versioning.script import base
from migrate.versioning.template import Template
log = logging.getLogger(__name__)
class SqlScript(base.BaseScript): class SqlScript(base.BaseScript):
"""A file containing plain SQL statements.""" """A file containing plain SQL statements."""
@classmethod
def create(cls, path, **opts):
"""Create an empty migration script at specified path
:returns: :class:`SqlScript instance <migrate.versioning.script.sql.SqlScript>`"""
cls.require_notfound(path)
src = Template(opts.pop('templates_path', None)).get_sql_script(theme=opts.pop('templates_theme', None))
shutil.copy(src, path)
return cls(path)
# TODO: why is step parameter even here? # TODO: why is step parameter even here?
def run(self, engine, step=None): def run(self, engine, step=None, executemany=True):
"""Runs SQL script through raw dbapi execute call""" """Runs SQL script through raw dbapi execute call"""
text = self.source() text = self.source()
# Don't rely on SA's autocommit here # Don't rely on SA's autocommit here
@ -21,7 +37,7 @@ class SqlScript(base.BaseScript):
# HACK: SQLite doesn't allow multiple statements through # HACK: SQLite doesn't allow multiple statements through
# its execute() method, but it provides executescript() instead # its execute() method, but it provides executescript() instead
dbapi = conn.engine.raw_connection() dbapi = conn.engine.raw_connection()
if getattr(dbapi, 'executescript', None): if executemany and getattr(dbapi, 'executescript', None):
dbapi.executescript(text) dbapi.executescript(text)
else: else:
conn.execute(text) conn.execute(text)

View File

@ -1,14 +1,15 @@
#!/usr/bin/python #!/usr/bin/python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
"""The migrate command-line tool.""" """The migrate command-line tool."""
import sys import sys
import inspect import inspect
import logging
from optparse import OptionParser, BadOptionError from optparse import OptionParser, BadOptionError
from migrate.versioning.config import *
from migrate.versioning import api, exceptions from migrate.versioning import api, exceptions
from migrate.versioning.config import *
from migrate.versioning.util import asbool
alias = dict( alias = dict(
@ -53,10 +54,14 @@ class PassiveOptionParser(OptionParser):
del rargs[0] del rargs[0]
def main(argv=None, **kwargs): def main(argv=None, **kwargs):
"""kwargs are default options that can be overriden with passing """Shell interface to :mod:`migrate.versioning.api`.
--some_option to cmdline
"""
kwargs are default options that can be overriden with passing
--some_option as command line option
:param disable_logging: Let migrate configure logging
:type disable_logging: bool
"""
argv = argv or list(sys.argv[1:]) argv = argv or list(sys.argv[1:])
commands = list(api.__all__) commands = list(api.__all__)
commands.sort() commands.sort()
@ -70,9 +75,16 @@ def main(argv=None, **kwargs):
""" % '\n\t'.join([u"%s%s" % (command.ljust(28), api.command_desc.get(command)) for command in commands]) """ % '\n\t'.join([u"%s%s" % (command.ljust(28), api.command_desc.get(command)) for command in commands])
parser = PassiveOptionParser(usage=usage) parser = PassiveOptionParser(usage=usage)
parser.add_option("-v", "--verbose", action="store_true", dest="verbose") parser.add_option("-d", "--debug",
parser.add_option("-d", "--debug", action="store_true", dest="debug") action="store_true",
parser.add_option("-f", "--force", action="store_true", dest="force") dest="debug",
default=False,
help="Shortcut to turn on DEBUG mode for logging")
parser.add_option("-q", "--disable_logging",
action="store_true",
dest="disable_logging",
default=False,
help="Use this option to disable logging configuration")
help_commands = ['help', '-h', '--help'] help_commands = ['help', '-h', '--help']
HELP = False HELP = False
@ -142,6 +154,21 @@ def main(argv=None, **kwargs):
# apply overrides # apply overrides
kwargs.update(override_kwargs) kwargs.update(override_kwargs)
# configure options
for key, value in options.__dict__.iteritems():
kwargs.setdefault(key, value)
# configure logging
if not asbool(kwargs.pop('disable_logging', False)):
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(message)s")
ch = logging.StreamHandler(sys.stdout)
ch.setFormatter(formatter)
logger.addHandler(ch)
log = logging.getLogger(__name__)
# check if all args are given # check if all args are given
try: try:
num_defaults = len(f_defaults) num_defaults = len(f_defaults)
@ -157,10 +184,8 @@ def main(argv=None, **kwargs):
try: try:
ret = command_func(**kwargs) ret = command_func(**kwargs)
if ret is not None: if ret is not None:
print ret log.info(ret)
except (exceptions.UsageError, exceptions.KnownError), e: except (exceptions.UsageError, exceptions.KnownError), e:
if e.args[0] is None:
parser.print_help()
parser.error(e.args[0]) parser.error(e.args[0])
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -28,6 +28,8 @@ class ScriptCollection(Collection):
class ManageCollection(Collection): class ManageCollection(Collection):
_mask = '%s.py_tmpl' _mask = '%s.py_tmpl'
class SQLScriptCollection(Collection):
_mask = '%s.py_tmpl'
class Template(pathed.Pathed): class Template(pathed.Pathed):
"""Finds the paths/packages of various Migrate templates. """Finds the paths/packages of various Migrate templates.
@ -50,6 +52,7 @@ class Template(pathed.Pathed):
self.repository = RepositoryCollection(os.path.join(path, 'repository')) self.repository = RepositoryCollection(os.path.join(path, 'repository'))
self.script = ScriptCollection(os.path.join(path, 'script')) self.script = ScriptCollection(os.path.join(path, 'script'))
self.manage = ManageCollection(os.path.join(path, 'manage')) self.manage = ManageCollection(os.path.join(path, 'manage'))
self.sql_script = SQLScriptCollection(os.path.join(path, 'sql_script'))
@classmethod @classmethod
def _find_path(cls, pkg): def _find_path(cls, pkg):
@ -82,6 +85,10 @@ class Template(pathed.Pathed):
"""Calls self._get_item('script', *a, **kw)""" """Calls self._get_item('script', *a, **kw)"""
return self._get_item('script', *a, **kw) return self._get_item('script', *a, **kw)
def get_sql_script(self, *a, **kw):
"""Calls self._get_item('sql_script', *a, **kw)"""
return self._get_item('sql_script', *a, **kw)
def get_manage(self, *a, **kw): def get_manage(self, *a, **kw):
"""Calls self._get_item('manage', *a, **kw)""" """Calls self._get_item('manage', *a, **kw)"""
return self._get_item('manage', *a, **kw) return self._get_item('manage', *a, **kw)

View File

@ -0,0 +1,4 @@
#!/usr/bin/env python
from migrate.versioning.shell import main
main(%(defaults)s)

View File

@ -81,7 +81,7 @@ def catch_known_errors(f, *a, **kw):
""" """
try: try:
f(*a, **kw) return f(*a, **kw)
except exceptions.PathFoundError, e: except exceptions.PathFoundError, e:
raise exceptions.KnownError("The path %s already exists" % e.args[0]) raise exceptions.KnownError("The path %s already exists" % e.args[0])
@ -130,3 +130,19 @@ def construct_engine(engine, **opts):
kwargs[key[11:]] = guess_obj_type(value) kwargs[key[11:]] = guess_obj_type(value)
return create_engine(engine, **kwargs) return create_engine(engine, **kwargs)
class Memoize:
"""Memoize(fn) - an instance which acts like fn but memoizes its arguments
Will only work on functions with non-mutable arguments
ActiveState Code 52201
"""
def __init__(self, fn):
self.fn = fn
self.memo = {}
def __call__(self, *args):
if not self.memo.has_key(args):
self.memo[args] = self.fn(*args)
return self.memo[args]

View File

@ -4,10 +4,13 @@
import os import os
import re import re
import shutil import shutil
import logging
from migrate.versioning import exceptions, pathed, script from migrate.versioning import exceptions, pathed, script
log = logging.getLogger(__name__)
class VerNum(object): class VerNum(object):
"""A version number that behaves like a string and int at the same time""" """A version number that behaves like a string and int at the same time"""
@ -98,11 +101,7 @@ class Collection(pathed.Pathed):
filename = '%03d%s.py' % (ver, extra) filename = '%03d%s.py' % (ver, extra)
filepath = self._version_path(filename) filepath = self._version_path(filename)
if os.path.exists(filepath): script.PythonScript.create(filepath, **k)
raise Exception('Script already exists: %s' % filepath)
else:
script.PythonScript.create(filepath, **k)
self.versions[ver] = Version(ver, self.path, [filename]) self.versions[ver] = Version(ver, self.path, [filename])
def create_new_sql_version(self, database, **k): def create_new_sql_version(self, database, **k):
@ -114,10 +113,7 @@ class Collection(pathed.Pathed):
for op in ('upgrade', 'downgrade'): for op in ('upgrade', 'downgrade'):
filename = '%03d_%s_%s.sql' % (ver, database, op) filename = '%03d_%s_%s.sql' % (ver, database, op)
filepath = self._version_path(filename) filepath = self._version_path(filename)
if os.path.exists(filepath): script.SqlScript.create(filepath, **k)
raise Exception('Script already exists: %s' % filepath)
else:
open(filepath, "w").close()
self.versions[ver].add_script(filepath) self.versions[ver].add_script(filepath)
def version(self, vernum=None): def version(self, vernum=None):
@ -137,7 +133,14 @@ class Collection(pathed.Pathed):
class Version(object): class Version(object):
"""A single version in a collection """ """A single version in a collection
:param vernum: Version Number
:param path: Path to script files
:param filelist: List of scripts
:type vernum: int, VerNum
:type path: string
:type filelist: list
"""
def __init__(self, vernum, path, filelist): def __init__(self, vernum, path, filelist):
self.version = VerNum(vernum) self.version = VerNum(vernum)
@ -165,22 +168,6 @@ class Version(object):
"There is no script for %d version" % self.version "There is no script for %d version" % self.version
return ret return ret
# deprecated?
@classmethod
def create(cls, path):
os.mkdir(path)
# create the version as a proper Python package
initfile = os.path.join(path, "__init__.py")
if not os.path.exists(initfile):
# just touch the file
open(initfile, "w").close()
try:
ret = cls(path)
except:
os.rmdir(path)
raise
return ret
def add_script(self, path): def add_script(self, path):
"""Add script to Collection/Version""" """Add script to Collection/Version"""
if path.endswith(Extensions.py): if path.endswith(Extensions.py):
@ -203,10 +190,11 @@ class Version(object):
def _add_script_py(self, path): def _add_script_py(self, path):
if self.python is not None: if self.python is not None:
raise Exception('You can only have one Python script per version,' raise exceptions.ScriptError('You can only have one Python script '
' but you have: %s and %s' % (self.python, path)) 'per version, but you have: %s and %s' % (self.python, path))
self.python = script.PythonScript(path) self.python = script.PythonScript(path)
class Extensions: class Extensions:
"""A namespace for file extensions""" """A namespace for file extensions"""
py = 'py' py = 'py'

View File

@ -7,8 +7,8 @@ tag_svn_revision = 1
tag_build = .dev tag_build = .dev
[nosetests] [nosetests]
#pdb = true pdb = true
#pdb-failures = true pdb-failures = true
#stop = true #stop = true
[aliases] [aliases]

View File

@ -14,7 +14,7 @@ try:
except ImportError: except ImportError:
pass pass
test_requirements = ['nose >= 0.10'] test_requirements = ['nose >= 0.10', 'ScriptTest']
required_deps = ['sqlalchemy >= 0.5', 'decorator', 'tempita'] required_deps = ['sqlalchemy >= 0.5', 'decorator', 'tempita']
readme_file = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'README')) readme_file = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'README'))

View File

@ -273,7 +273,21 @@ class TestAddDropColumn(fixture.DB):
self.assertEqual(u'foobar', row['data']) self.assertEqual(u'foobar', row['data'])
col.drop() col.drop()
@fixture.usedb()
def test_populate_default(self):
"""Test populate_default=True"""
def default():
return 'foobar'
col = Column('data', String(244), default=default)
col.create(self.table, populate_default=True)
self.table.insert(values={'id': 10}).execute()
row = self.table.select(autocommit=True).execute().fetchone()
self.assertEqual(u'foobar', row['data'])
col.drop()
# TODO: test sequence # TODO: test sequence
# TODO: test quoting # TODO: test quoting
# TODO: test non-autoname constraints # TODO: test non-autoname constraints

View File

@ -8,15 +8,17 @@ from sqlalchemy import create_engine, Table, MetaData
from sqlalchemy.orm import create_session from sqlalchemy.orm import create_session
from sqlalchemy.pool import StaticPool from sqlalchemy.pool import StaticPool
from migrate.versioning.util import Memoize
from test.fixture.base import Base from test.fixture.base import Base
from test.fixture.pathed import Pathed from test.fixture.pathed import Pathed
@Memoize
def readurls(): def readurls():
"""read URLs from config file return a list""" """read URLs from config file return a list"""
# TODO: remove tmpfile since sqlite can store db in memory
filename = 'test_db.cfg' filename = 'test_db.cfg'
ret = list() ret = list()
# TODO: remove tmpfile since sqlite can store db in memory
tmpfile = Pathed.tmp() tmpfile = Pathed.tmp()
fullpath = os.path.join(os.curdir, filename) fullpath = os.path.join(os.curdir, filename)
@ -24,7 +26,7 @@ def readurls():
fd = open(fullpath) fd = open(fullpath)
except IOError: except IOError:
raise IOError("""You must specify the databases to use for testing! raise IOError("""You must specify the databases to use for testing!
Copy %(filename)s.tmpl to %(filename)s and edit your database URLs.""" % locals()) Copy %(filename)s.tmpl to %(filename)s and edit your database URLs.""" % locals())
for line in fd: for line in fd:
if line.startswith('#'): if line.startswith('#'):
@ -49,10 +51,6 @@ def is_supported(url, supported, not_supported):
return not (db in not_supported) return not (db in not_supported)
return True return True
# we make the engines global, which should make the tests run a bit faster
urls = readurls()
engines = dict([(url, create_engine(url, echo=True, poolclass=StaticPool)) for url in urls])
def usedb(supported=None, not_supported=None): def usedb(supported=None, not_supported=None):
"""Decorates tests to be run with a database connection """Decorates tests to be run with a database connection
@ -67,6 +65,7 @@ def usedb(supported=None, not_supported=None):
if supported is not None and not_supported is not None: if supported is not None and not_supported is not None:
raise AssertionError("Can't specify both supported and not_supported in fixture.db()") raise AssertionError("Can't specify both supported and not_supported in fixture.db()")
urls = readurls()
my_urls = [url for url in urls if is_supported(url, supported, not_supported)] my_urls = [url for url in urls if is_supported(url, supported, not_supported)]
@decorator @decorator
@ -99,7 +98,7 @@ class DB(Base):
def _connect(self, url): def _connect(self, url):
self.url = url self.url = url
self.engine = engines[url] self.engine = create_engine(url, echo=True, poolclass=StaticPool)
self.meta = MetaData(bind=self.engine) self.meta = MetaData(bind=self.engine)
if self.level < self.CONNECT: if self.level < self.CONNECT:
return return
@ -128,6 +127,7 @@ class DB(Base):
return not (db in func.not_supported) return not (db in func.not_supported)
# Neither list assigned; assume all are supported # Neither list assigned; assume all are supported
return True return True
def _not_supported(self, url): def _not_supported(self, url):
return not self._supported(url) return not self._supported(url)

14
test/fixture/models.py Normal file
View File

@ -0,0 +1,14 @@
from sqlalchemy import *
# test rundiffs in shell
meta_old_rundiffs = MetaData()
meta_rundiffs = MetaData()
meta = MetaData()
tmp_account_rundiffs = Table('tmp_account_rundiffs', meta_rundiffs,
Column('id', Integer, primary_key=True),
Column('login', String(40)),
Column('passwd', String(40)),
)
tmp_sql_table = Table('tmp_sql_table', meta, Column('id', Integer))

View File

@ -6,51 +6,22 @@ import shutil
import sys import sys
import types import types
from scripttest import TestFileEnvironment
from test.fixture.pathed import * from test.fixture.pathed import *
class Shell(Pathed): class Shell(Pathed):
"""Base class for command line tests""" """Base class for command line tests"""
def execute(self, command, *p, **k):
"""Return the fd of a command; can get output (stdout/err) and exitcode"""
# We might be passed a file descriptor for some reason; if so, just return it
if isinstance(command, types.FileType):
return command
# Redirect stderr to stdout def setUp(self):
# This is a bit of a hack, but I've not found a better way super(Shell, self).setUp()
py_path = os.environ.get('PYTHONPATH', '') self.env = TestFileEnvironment(os.path.join(self.temp_usable_dir, 'env'))
py_path_list = py_path.split(':')
py_path_list.append(os.path.abspath('.'))
os.environ['PYTHONPATH'] = ':'.join(py_path_list)
fd = os.popen(command + ' 2>&1')
if py_path: def run_version(self, repos_path):
py_path = os.environ['PYTHONPATH'] = py_path result = self.env.run('migrate version %s' % repos_path)
else: return int(result.stdout.strip())
del os.environ['PYTHONPATH']
return fd
def output_and_exitcode(self, *p, **k): def run_db_version(self, url, repos_path):
fd=self.execute(*p, **k) result = self.env.run('migrate db_version %s %s' % (url, repos_path))
output = fd.read().strip() return int(result.stdout.strip())
exitcode = fd.close()
if k.pop('emit',False):
print output
return (output, exitcode)
def exitcode(self, *p, **k):
"""Execute a command and return its exit code
...without printing its output/errors
"""
ret = self.output_and_exitcode(*p, **k)
return ret[1]
def assertFailure(self, *p, **k):
output,exitcode = self.output_and_exitcode(*p, **k)
assert (exitcode), output
def assertSuccess(self, *p, **k):
output,exitcode = self.output_and_exitcode(*p, **k)
#self.assert_(not exitcode, output)
assert (not exitcode), output

120
test/versioning/test_api.py Normal file
View File

@ -0,0 +1,120 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
from migrate.versioning import api
from migrate.versioning.exceptions import *
from test.fixture.pathed import *
from test.fixture import models
from test import fixture
class TestAPI(Pathed):
def test_help(self):
self.assertTrue(isinstance(api.help('help'), basestring))
self.assertRaises(UsageError, api.help)
self.assertRaises(UsageError, api.help, 'foobar')
self.assert_(isinstance(api.help('create'), str))
# test that all commands return some text
for cmd in api.__all__:
content = api.help(cmd)
self.assertTrue(content)
def test_create(self):
tmprepo = self.tmp_repos()
api.create(tmprepo, 'temp')
# repository already exists
self.assertRaises(KnownError, api.create, tmprepo, 'temp')
def test_script(self):
repo = self.tmp_repos()
api.create(repo, 'temp')
api.script('first version', repo)
def test_script_sql(self):
repo = self.tmp_repos()
api.create(repo, 'temp')
api.script_sql('postgres', repo)
def test_version(self):
repo = self.tmp_repos()
api.create(repo, 'temp')
api.version(repo)
def test_source(self):
repo = self.tmp_repos()
api.create(repo, 'temp')
api.script('first version', repo)
api.script_sql('default', repo)
# no repository
self.assertRaises(UsageError, api.source, 1)
# stdout
out = api.source(1, dest=None, repository=repo)
self.assertTrue(out)
# file
out = api.source(1, dest=self.tmp_repos(), repository=repo)
self.assertFalse(out)
def test_manage(self):
output = api.manage(os.path.join(self.temp_usable_dir, 'manage.py'))
class TestSchemaAPI(fixture.DB, Pathed):
def _setup(self, url):
super(TestSchemaAPI, self)._setup(url)
self.repo = self.tmp_repos()
api.create(self.repo, 'temp')
self.schema = api.version_control(url, self.repo)
def _teardown(self):
self.schema = api.drop_version_control(self.url, self.repo)
super(TestSchemaAPI, self)._teardown()
@fixture.usedb()
def test_workflow(self):
self.assertEqual(api.db_version(self.url, self.repo), 0)
api.script('First Version', self.repo)
self.assertEqual(api.db_version(self.url, self.repo), 0)
api.upgrade(self.url, self.repo, 1)
self.assertEqual(api.db_version(self.url, self.repo), 1)
api.downgrade(self.url, self.repo, 0)
self.assertEqual(api.db_version(self.url, self.repo), 0)
api.test(self.url, self.repo)
self.assertEqual(api.db_version(self.url, self.repo), 0)
# preview
# TODO: test output
out = api.upgrade(self.url, self.repo, preview_py=True)
out = api.upgrade(self.url, self.repo, preview_sql=True)
api.upgrade(self.url, self.repo, 1)
api.script_sql('default', self.repo)
self.assertRaises(UsageError, api.upgrade, self.url, self.repo, 2, preview_py=True)
out = api.upgrade(self.url, self.repo, 2, preview_sql=True)
# cant upgrade to version 1, already at version 1
self.assertEqual(api.db_version(self.url, self.repo), 1)
self.assertRaises(KnownError, api.upgrade, self.url, self.repo, 0)
@fixture.usedb()
def test_compare_model_to_db(self):
diff = api.compare_model_to_db(self.url, self.repo, models.meta)
@fixture.usedb()
def test_create_model(self):
model = api.create_model(self.url, self.repo)
@fixture.usedb()
def test_make_update_script_for_model(self):
model = api.make_update_script_for_model(self.url, self.repo, models.meta_old_rundiffs, models.meta_rundiffs)
@fixture.usedb()
def test_update_db_from_model(self):
model = api.update_db_from_model(self.url, self.repo, models.meta_rundiffs)

View File

@ -1,3 +1,6 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
from test import fixture from test import fixture
from migrate.versioning.util.keyedinstance import * from migrate.versioning.util.keyedinstance import *
@ -38,3 +41,5 @@ class TestKeydInstance(fixture.Base):
Uniq1.clear() Uniq1.clear()
a12 = Uniq1('a') a12 = Uniq1('a')
self.assert_(a10 is not a12) self.assert_(a10 is not a12)
self.assertRaises(NotImplementedError, KeyedInstance._key)

View File

@ -101,13 +101,15 @@ class TestVersionedRepository(fixture.Pathed):
# Load repository and commit script # Load repository and commit script
repo = Repository(self.path_repos) repo = Repository(self.path_repos)
repo.create_script('') repo.create_script('')
repo.create_script_sql('postgres')
# Get script object
source = repo.version(1).script().source()
# Source is valid: script must have an upgrade function # Source is valid: script must have an upgrade function
# (not a very thorough test, but should be plenty) # (not a very thorough test, but should be plenty)
self.assert_(source.find('def upgrade') >= 0) source = repo.version(1).script().source()
self.assertTrue(source.find('def upgrade') >= 0)
source = repo.version(2).script('postgres', 'upgrade').source()
self.assertEqual(source.strip(), '')
def test_latestversion(self): def test_latestversion(self):
"""Repository.version() (no params) returns the latest version""" """Repository.version() (no params) returns the latest version"""

View File

@ -16,11 +16,10 @@ class TestControlledSchema(fixture.Pathed, fixture.DB):
# Transactions break postgres in this test; we'll clean up after ourselves # Transactions break postgres in this test; we'll clean up after ourselves
level = fixture.DB.CONNECT level = fixture.DB.CONNECT
def setUp(self): def setUp(self):
super(TestControlledSchema, self).setUp() super(TestControlledSchema, self).setUp()
path_repos = self.temp_usable_dir + '/repo/' self.path_repos = self.temp_usable_dir + '/repo/'
self.repos = Repository.create(path_repos, 'repo_name') self.repos = Repository.create(self.path_repos, 'repo_name')
def _setup(self, url): def _setup(self, url):
self.setUp() self.setUp()
@ -116,7 +115,7 @@ class TestControlledSchema(fixture.Pathed, fixture.DB):
#self.assertRaises(ControlledSchema.InvalidVersionError, #self.assertRaises(ControlledSchema.InvalidVersionError,
# Can't have custom errors with assertRaises... # Can't have custom errors with assertRaises...
try: try:
ControlledSchema.create(self.engine,self.repos,version) ControlledSchema.create(self.engine, self.repos, version)
self.assert_(False, repr(version)) self.assert_(False, repr(version))
except exceptions.InvalidVersionError: except exceptions.InvalidVersionError:
pass pass

View File

@ -63,10 +63,10 @@ class TestSchemaDiff(fixture.DB):
) )
''') ''')
self.assertEqualsIgnoreWhitespace(upgradeCommands, self.assertEqualsIgnoreWhitespace(upgradeCommands,
'''meta.bind(migrate_engine) '''meta.bind = migrate_engine
tmp_schemadiff.create()''') tmp_schemadiff.create()''')
self.assertEqualsIgnoreWhitespace(downgradeCommands, self.assertEqualsIgnoreWhitespace(downgradeCommands,
'''meta.bind(migrate_engine) '''meta.bind = migrate_engine
tmp_schemadiff.drop()''') tmp_schemadiff.drop()''')
# Create table in database, now model should match database. # Create table in database, now model should match database.

View File

@ -10,6 +10,7 @@ from migrate.versioning.script import *
from migrate.versioning.util import * from migrate.versioning.util import *
from test import fixture from test import fixture
from test.fixture.models import tmp_sql_table
class TestBaseScript(fixture.Pathed): class TestBaseScript(fixture.Pathed):
@ -48,6 +49,25 @@ class TestPyScript(fixture.Pathed, fixture.DB):
self.assertRaises(exceptions.ScriptError, pyscript.run, self.engine, 0) self.assertRaises(exceptions.ScriptError, pyscript.run, self.engine, 0)
self.assertRaises(exceptions.ScriptError, pyscript._func, 'foobar') self.assertRaises(exceptions.ScriptError, pyscript._func, 'foobar')
# clean pyc file
os.remove(script_path + 'c')
# test deprecated upgrade/downgrade with no arguments
contents = open(script_path, 'r').read()
f = open(script_path, 'w')
f.write(contents.replace("upgrade(migrate_engine)", "upgrade()"))
f.close()
pyscript = PythonScript(script_path)
pyscript._module = None
try:
pyscript.run(self.engine, 1)
pyscript.run(self.engine, -1)
except exceptions.ScriptError:
pass
else:
self.fail()
def test_verify_notfound(self): def test_verify_notfound(self):
"""Correctly verify a python migration script: nonexistant file""" """Correctly verify a python migration script: nonexistant file"""
path = self.tmp_py() path = self.tmp_py()
@ -60,7 +80,7 @@ class TestPyScript(fixture.Pathed, fixture.DB):
"""Correctly verify a python migration script: invalid python file""" """Correctly verify a python migration script: invalid python file"""
path=self.tmp_py() path=self.tmp_py()
# Create empty file # Create empty file
f=open(path,'w') f = open(path,'w')
f.write("def fail") f.write("def fail")
f.close() f.close()
self.assertRaises(Exception,self.cls.verify_module,path) self.assertRaises(Exception,self.cls.verify_module,path)
@ -86,7 +106,7 @@ class TestPyScript(fixture.Pathed, fixture.DB):
path = self.tmp_py() path = self.tmp_py()
f = open(path, 'w') f = open(path, 'w')
content = """ content = '''
from migrate import * from migrate import *
from sqlalchemy import * from sqlalchemy import *
@ -99,7 +119,7 @@ UserGroup = Table('Link', metadata,
def upgrade(migrate_engine): def upgrade(migrate_engine):
metadata.create_all(migrate_engine) metadata.create_all(migrate_engine)
""" '''
f.write(content) f.write(content)
f.close() f.close()
@ -130,7 +150,6 @@ def upgrade(migrate_engine):
self.write_file(self.first_model_path, self.base_source) self.write_file(self.first_model_path, self.base_source)
self.write_file(self.second_model_path, self.base_source + self.model_source) self.write_file(self.second_model_path, self.base_source + self.model_source)
source_script = self.pyscript.make_update_script_for_model( source_script = self.pyscript.make_update_script_for_model(
engine=self.engine, engine=self.engine,
oldmodel=load_model('testmodel_first:meta'), oldmodel=load_model('testmodel_first:meta'),
@ -176,7 +195,6 @@ User = Table('User', meta,
self.repo = repository.Repository.create(self.repo_path, 'repo') self.repo = repository.Repository.create(self.repo_path, 'repo')
self.pyscript = PythonScript.create(self.script_path) self.pyscript = PythonScript.create(self.script_path)
def write_file(self, path, contents): def write_file(self, path, contents):
f = open(path, 'w') f = open(path, 'w')
f.write(contents) f.write(contents)
@ -196,3 +214,31 @@ class TestSqlScript(fixture.Pathed, fixture.DB):
sqls = SqlScript(src) sqls = SqlScript(src)
self.assertRaises(Exception, sqls.run, self.engine) self.assertRaises(Exception, sqls.run, self.engine)
@fixture.usedb()
def test_success(self):
"""Test sucessful SQL execution"""
# cleanup and prepare python script
tmp_sql_table.metadata.drop_all(self.engine, checkfirst=True)
script_path = self.tmp_py()
pyscript = PythonScript.create(script_path)
# populate python script
contents = open(script_path, 'r').read()
contents = contents.replace("pass", "tmp_sql_table.create(migrate_engine)")
contents = 'from test.fixture.models import tmp_sql_table\n' + contents
f = open(script_path, 'w')
f.write(contents)
f.close()
# write SQL script from python script preview
pyscript = PythonScript(script_path)
src = self.tmp()
f = open(src, 'w')
f.write(pyscript.preview_sql(self.url, 1))
f.close()
# run the change
sqls = SqlScript(src)
sqls.run(self.engine, executemany=False)
tmp_sql_table.metadata.drop_all(self.engine, checkfirst=True)

View File

@ -2,128 +2,83 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os import os
import sys import tempfile
import shutil from runpy import run_module
import traceback
from types import FileType
from StringIO import StringIO
from sqlalchemy import MetaData,Table from sqlalchemy import MetaData, Table
from migrate.versioning.repository import Repository from migrate.versioning.repository import Repository
from migrate.versioning import genmodel, shell, api from migrate.versioning import genmodel, shell, api
from migrate.versioning.exceptions import * from migrate.versioning.exceptions import *
from test import fixture from test.fixture import Shell, DB, usedb
class Shell(fixture.Shell):
_cmd = os.path.join(sys.executable + ' migrate', 'versioning', 'shell.py')
@classmethod
def cmd(cls, *args):
safe_parameters = map(lambda arg: str(arg), args)
return ' '.join([cls._cmd] + safe_parameters)
def execute(self, shell_cmd, runshell=None, **kwargs):
"""A crude simulation of a shell command, to speed things up"""
# If we get an fd, the command is already done
if isinstance(shell_cmd, (FileType, StringIO)):
return shell_cmd
# Analyze the command; see if we can 'fake' the shell
try:
# Forced to run in shell?
# if runshell or '--runshell' in sys.argv:
if runshell:
raise Exception
# Remove the command prefix
if not shell_cmd.startswith(self._cmd):
raise Exception
cmd = shell_cmd[(len(self._cmd) + 1):]
params = cmd.split(' ')
command = params[0]
except:
return super(Shell, self).execute(shell_cmd)
# Redirect stdout to an object; redirect stderr to stdout
fd = StringIO()
orig_stdout = sys.stdout
orig_stderr = sys.stderr
sys.stdout = fd
sys.stderr = fd
# Execute this command
try:
try:
shell.main(params, **kwargs)
except SystemExit, e:
# Simulate the exit status
fd_close = fd.close
def close_():
fd_close()
return e.args[0]
fd.close = close_
except Exception, e:
# Print the exception, but don't re-raise it
traceback.print_exc()
# Simulate a nonzero exit status
fd_close = fd.close
def close_():
fd_close()
return 2
fd.close = close_
finally:
# Clean up
sys.stdout = orig_stdout
sys.stderr = orig_stderr
fd.seek(0)
return fd
def cmd_version(self, repos_path):
fd = self.execute(self.cmd('version', repos_path))
result = int(fd.read().strip())
self.assertSuccess(fd)
return result
def cmd_db_version(self, url, repos_path):
fd = self.execute(self.cmd('db_version', url, repos_path))
txt = fd.read()
#print txt
ret = int(txt.strip())
self.assertSuccess(fd)
return ret
class TestShellCommands(Shell): class TestShellCommands(Shell):
"""Tests migrate.py commands""" """Tests migrate.py commands"""
def test_help(self): def test_help(self):
"""Displays default help dialog""" """Displays default help dialog"""
self.assertSuccess(self.cmd('-h'), runshell=True) self.assertEqual(self.env.run('migrate -h').returncode, 0)
self.assertSuccess(self.cmd('--help'), runshell=True) self.assertEqual(self.env.run('migrate --help').returncode, 0)
self.assertSuccess(self.cmd('help'), runshell=True) self.assertEqual(self.env.run('migrate help').returncode, 0)
self.assertSuccess(self.cmd('help'))
self.assertRaises(UsageError, api.help)
self.assertRaises(UsageError, api.help, 'foobar')
self.assert_(isinstance(api.help('create'), str))
def test_help_commands(self): def test_help_commands(self):
"""Display help on a specific command""" """Display help on a specific command"""
for cmd in shell.api.__all__: # we can only test that we get some output
fd = self.execute(self.cmd('help', cmd)) for cmd in api.__all__:
# Description may change, so best we can do is ensure it shows up result = self.env.run('migrate help %s' % cmd)
output = fd.read() self.assertTrue(isinstance(result.stdout, basestring))
self.assertNotEquals(output, '') self.assertTrue(result.stdout)
self.assertSuccess(fd) self.assertFalse(result.stderr)
def test_shutdown_logging(self):
"""Try to shutdown logging output"""
repos = self.tmp_repos()
result = self.env.run('migrate create %s repository_name' % repos)
result = self.env.run('migrate version %s --disable_logging' % repos)
self.assertEqual(result.stdout, '')
result = self.env.run('migrate version %s -q' % repos)
self.assertEqual(result.stdout, '')
# TODO: assert logging messages to 0
shell.main(['version', repos], logging=False)
def test_main(self):
"""Test main() function"""
# TODO: test output?
try:
run_module('migrate.versioning.shell', run_name='__main__')
except:
pass
repos = self.tmp_repos()
shell.main(['help'])
shell.main(['help', 'create'])
shell.main(['create', 'repo_name', '--preview_sql'], repository=repos)
shell.main(['version', '--', '--repository=%s' % repos])
shell.main(['version', '-d', '--repository=%s' % repos, '--version=2'])
try:
shell.main(['foobar'])
except SystemExit, e:
pass
try:
shell.main(['create', 'f', 'o', 'o'])
except SystemExit, e:
pass
try:
shell.main(['create'])
except SystemExit, e:
pass
try:
shell.main(['create', 'repo_name'], repository=repos)
except SystemExit, e:
pass
def test_create(self): def test_create(self):
"""Repositories are created successfully""" """Repositories are created successfully"""
repos = self.tmp_repos() repos = self.tmp_repos()
# Creating a file that doesn't exist should succeed # Creating a file that doesn't exist should succeed
cmd = self.cmd('create', repos, 'repository_name') result = self.env.run('migrate create %s repository_name' % repos)
self.assertSuccess(cmd)
# Files should actually be created # Files should actually be created
self.assert_(os.path.exists(repos)) self.assert_(os.path.exists(repos))
@ -133,241 +88,253 @@ class TestShellCommands(Shell):
self.assertNotEquals(repos_.config.get('db_settings', 'version_table'), 'None') self.assertNotEquals(repos_.config.get('db_settings', 'version_table'), 'None')
# Can't create it again: it already exists # Can't create it again: it already exists
self.assertFailure(cmd) result = self.env.run('migrate create %s repository_name' % repos,
expect_error=True)
self.assertEqual(result.returncode, 2)
def test_script(self): def test_script(self):
"""We can create a migration script via the command line""" """We can create a migration script via the command line"""
repos = self.tmp_repos() repos = self.tmp_repos()
self.assertSuccess(self.cmd('create', repos, 'repository_name')) result = self.env.run('migrate create %s repository_name' % repos)
self.assertSuccess(self.cmd('script', '--repository=%s' % repos, 'Desc')) result = self.env.run('migrate script --repository=%s Desc' % repos)
self.assert_(os.path.exists('%s/versions/001_Desc.py' % repos)) self.assert_(os.path.exists('%s/versions/001_Desc.py' % repos))
self.assertSuccess(self.cmd('script', '--repository=%s' % repos, 'More')) result = self.env.run('migrate script More %s' % repos)
self.assert_(os.path.exists('%s/versions/002_More.py' % repos)) self.assert_(os.path.exists('%s/versions/002_More.py' % repos))
self.assertSuccess(self.cmd('script', '--repository=%s' % repos, '"Some Random name"'), runshell=True) result = self.env.run('migrate script "Some Random name" %s' % repos)
self.assert_(os.path.exists('%s/versions/003_Some_Random_name.py' % repos)) self.assert_(os.path.exists('%s/versions/003_Some_Random_name.py' % repos))
def test_script_sql(self): def test_script_sql(self):
"""We can create a migration sql script via the command line""" """We can create a migration sql script via the command line"""
repos = self.tmp_repos() repos = self.tmp_repos()
self.assertSuccess(self.cmd('create', repos, 'repository_name')) result = self.env.run('migrate create %s repository_name' % repos)
self.assertSuccess(self.cmd('script_sql', '--repository=%s' % repos, 'mydb')) result = self.env.run('migrate script_sql mydb %s' % repos)
self.assert_(os.path.exists('%s/versions/001_mydb_upgrade.sql' % repos)) self.assert_(os.path.exists('%s/versions/001_mydb_upgrade.sql' % repos))
self.assert_(os.path.exists('%s/versions/001_mydb_downgrade.sql' % repos)) self.assert_(os.path.exists('%s/versions/001_mydb_downgrade.sql' % repos))
# Test creating a second # Test creating a second
self.assertSuccess(self.cmd('script_sql', '--repository=%s' % repos, 'postgres')) result = self.env.run('migrate script_sql postgres --repository=%s' % repos)
self.assert_(os.path.exists('%s/versions/002_postgres_upgrade.sql' % repos)) self.assert_(os.path.exists('%s/versions/002_postgres_upgrade.sql' % repos))
self.assert_(os.path.exists('%s/versions/002_postgres_downgrade.sql' % repos)) self.assert_(os.path.exists('%s/versions/002_postgres_downgrade.sql' % repos))
# TODO: test --previews
def test_manage(self): def test_manage(self):
"""Create a project management script""" """Create a project management script"""
script = self.tmp_py() script = self.tmp_py()
self.assert_(not os.path.exists(script)) self.assert_(not os.path.exists(script))
# No attempt is made to verify correctness of the repository path here # No attempt is made to verify correctness of the repository path here
self.assertSuccess(self.cmd('manage', script, '--repository=/path/to/repository')) result = self.env.run('migrate manage %s --repository=/bla/' % script)
self.assert_(os.path.exists(script)) self.assert_(os.path.exists(script))
class TestShellRepository(Shell): class TestShellRepository(Shell):
"""Shell commands on an existing repository/python script""" """Shell commands on an existing repository/python script"""
def setUp(self): def setUp(self):
"""Create repository, python change script""" """Create repository, python change script"""
super(TestShellRepository, self).setUp() super(TestShellRepository, self).setUp()
self.path_repos = repos = self.tmp_repos() self.path_repos = self.tmp_repos()
self.assertSuccess(self.cmd('create', repos, 'repository_name')) result = self.env.run('migrate create %s repository_name' % self.path_repos)
def test_version(self): def test_version(self):
"""Correctly detect repository version""" """Correctly detect repository version"""
# Version: 0 (no scripts yet); successful execution # Version: 0 (no scripts yet); successful execution
fd = self.execute(self.cmd('version','--repository=%s' % self.path_repos)) result = self.env.run('migrate version --repository=%s' % self.path_repos)
self.assertEquals(fd.read().strip(), "0") self.assertEqual(result.stdout.strip(), "0")
self.assertSuccess(fd)
# Also works as a positional param # Also works as a positional param
fd = self.execute(self.cmd('version', self.path_repos)) result = self.env.run('migrate version %s' % self.path_repos)
self.assertEquals(fd.read().strip(), "0") self.assertEqual(result.stdout.strip(), "0")
self.assertSuccess(fd)
# Create a script and version should increment # Create a script and version should increment
self.assertSuccess(self.cmd('script', '--repository=%s' % self.path_repos, 'Desc')) result = self.env.run('migrate script Desc %s' % self.path_repos)
fd = self.execute(self.cmd('version',self.path_repos)) result = self.env.run('migrate version %s' % self.path_repos)
self.assertEquals(fd.read().strip(), "1") self.assertEqual(result.stdout.strip(), "1")
self.assertSuccess(fd)
def test_source(self): def test_source(self):
"""Correctly fetch a script's source""" """Correctly fetch a script's source"""
self.assertSuccess(self.cmd('script', '--repository=%s' % self.path_repos, 'Desc')) result = self.env.run('migrate script Desc --repository=%s' % self.path_repos)
filename = '%s/versions/001_Desc.py' % self.path_repos filename = '%s/versions/001_Desc.py' % self.path_repos
source = open(filename).read() source = open(filename).read()
self.assert_(source.find('def upgrade') >= 0) self.assert_(source.find('def upgrade') >= 0)
# Version is now 1 # Version is now 1
fd = self.execute(self.cmd('version', self.path_repos)) result = self.env.run('migrate version %s' % self.path_repos)
self.assert_(fd.read().strip() == "1") self.assertEqual(result.stdout.strip(), "1")
self.assertSuccess(fd)
# Output/verify the source of version 1 # Output/verify the source of version 1
fd = self.execute(self.cmd('source', 1, '--repository=%s' % self.path_repos)) result = self.env.run('migrate source 1 --repository=%s' % self.path_repos)
result = fd.read() self.assertEqual(result.stdout.strip(), source.strip())
self.assertSuccess(fd)
self.assert_(result.strip() == source.strip())
# We can also send the source to a file... test that too # We can also send the source to a file... test that too
self.assertSuccess(self.cmd('source', 1, filename, '--repository=%s'%self.path_repos)) result = self.env.run('migrate source 1 %s --repository=%s' %
(filename, self.path_repos))
self.assert_(os.path.exists(filename)) self.assert_(os.path.exists(filename))
fd = open(filename) fd = open(filename)
result = fd.read() result = fd.read()
self.assert_(result.strip() == source.strip()) self.assert_(result.strip() == source.strip())
class TestShellDatabase(Shell, fixture.DB):
class TestShellDatabase(Shell, DB):
"""Commands associated with a particular database""" """Commands associated with a particular database"""
# We'll need to clean up after ourself, since the shell creates its own txn; # We'll need to clean up after ourself, since the shell creates its own txn;
# we need to connect to the DB to see if things worked # we need to connect to the DB to see if things worked
level = fixture.DB.CONNECT level = DB.CONNECT
@fixture.usedb() @usedb()
def test_version_control(self): def test_version_control(self):
"""Ensure we can set version control on a database""" """Ensure we can set version control on a database"""
path_repos = repos = self.tmp_repos() path_repos = repos = self.tmp_repos()
self.assertSuccess(self.cmd('create', path_repos, 'repository_name')) url = self.url
self.exitcode(self.cmd('drop_version_control', self.url, path_repos)) result = self.env.run('migrate create %s repository_name' % repos)
self.assertSuccess(self.cmd('version_control', self.url, path_repos))
result = self.env.run('migrate drop_version_control %(url)s %(repos)s'\
% locals(), expect_error=True)
self.assertEqual(result.returncode, 1)
result = self.env.run('migrate version_control %(url)s %(repos)s' % locals())
# Clean up # Clean up
self.assertSuccess(self.cmd('drop_version_control',self.url,path_repos)) result = self.env.run('migrate drop_version_control %(url)s %(repos)s' % locals())
# Attempting to drop vc from a database without it should fail # Attempting to drop vc from a database without it should fail
self.assertFailure(self.cmd('drop_version_control',self.url,path_repos)) result = self.env.run('migrate drop_version_control %(url)s %(repos)s'\
% locals(), expect_error=True)
self.assertEqual(result.returncode, 1)
@fixture.usedb() @usedb()
def test_wrapped_kwargs(self): def test_wrapped_kwargs(self):
"""Commands with default arguments set by manage.py""" """Commands with default arguments set by manage.py"""
path_repos = repos = self.tmp_repos() path_repos = repos = self.tmp_repos()
self.assertSuccess(self.cmd('create', '--', '--name=repository_name'), repository=path_repos) url = self.url
self.exitcode(self.cmd('drop_version_control'), url=self.url, repository=path_repos) result = self.env.run('migrate create --name=repository_name %s' % repos)
self.assertSuccess(self.cmd('version_control'), url=self.url, repository=path_repos) result = self.env.run('migrate drop_version_control %(url)s %(repos)s' % locals(), expect_error=True)
self.assertEqual(result.returncode, 1)
result = self.env.run('migrate version_control %(url)s %(repos)s' % locals())
# Clean up result = self.env.run('migrate drop_version_control %(url)s %(repos)s' % locals())
self.assertSuccess(self.cmd('drop_version_control'), url=self.url, repository=path_repos)
# Attempting to drop vc from a database without it should fail
self.assertFailure(self.cmd('drop_version_control'), url=self.url, repository=path_repos)
@fixture.usedb() @usedb()
def test_version_control_specified(self): def test_version_control_specified(self):
"""Ensure we can set version control to a particular version""" """Ensure we can set version control to a particular version"""
path_repos = self.tmp_repos() path_repos = self.tmp_repos()
self.assertSuccess(self.cmd('create', path_repos, 'repository_name')) url = self.url
self.exitcode(self.cmd('drop_version_control', self.url, path_repos)) result = self.env.run('migrate create --name=repository_name %s' % path_repos)
result = self.env.run('migrate drop_version_control %(url)s %(path_repos)s' % locals(), expect_error=True)
self.assertEqual(result.returncode, 1)
# Fill the repository # Fill the repository
path_script = self.tmp_py() path_script = self.tmp_py()
version = 1 version = 2
for i in range(version): for i in range(version):
self.assertSuccess(self.cmd('script', '--repository=%s' % path_repos, 'Desc')) result = self.env.run('migrate script Desc --repository=%s' % path_repos)
# Repository version is correct # Repository version is correct
fd = self.execute(self.cmd('version', path_repos)) result = self.env.run('migrate version %s' % path_repos)
self.assertEquals(fd.read().strip(), str(version)) self.assertEqual(result.stdout.strip(), str(version))
self.assertSuccess(fd)
# Apply versioning to DB # Apply versioning to DB
self.assertSuccess(self.cmd('version_control', self.url, path_repos, version)) result = self.env.run('migrate version_control %(url)s %(path_repos)s %(version)s' % locals())
# Test version number # Test db version number (should start at 2)
fd = self.execute(self.cmd('db_version', self.url, path_repos)) result = self.env.run('migrate db_version %(url)s %(path_repos)s' % locals())
self.assertEquals(fd.read().strip(), str(version)) self.assertEqual(result.stdout.strip(), str(version))
self.assertSuccess(fd)
# Clean up # Clean up
self.assertSuccess(self.cmd('drop_version_control', self.url, path_repos)) result = self.env.run('migrate drop_version_control %(url)s %(path_repos)s' % locals())
@fixture.usedb() @usedb()
def test_upgrade(self): def test_upgrade(self):
"""Can upgrade a versioned database""" """Can upgrade a versioned database"""
# Create a repository # Create a repository
repos_name = 'repos_name' repos_name = 'repos_name'
repos_path = self.tmp() repos_path = self.tmp()
self.assertSuccess(self.cmd('create', repos_path,repos_name)) result = self.env.run('migrate create %(repos_path)s %(repos_name)s' % locals())
self.assertEquals(self.cmd_version(repos_path), 0) self.assertEquals(self.run_version(repos_path), 0)
# Version the DB # Version the DB
self.exitcode(self.cmd('drop_version_control', self.url, repos_path)) result = self.env.run('migrate drop_version_control %s %s' % (self.url, repos_path), expect_error=True)
self.assertSuccess(self.cmd('version_control', self.url, repos_path)) result = self.env.run('migrate version_control %s %s' % (self.url, repos_path))
# Upgrades with latest version == 0 # Upgrades with latest version == 0
self.assertEquals(self.cmd_db_version(self.url, repos_path), 0) self.assertEquals(self.run_db_version(self.url, repos_path), 0)
self.assertSuccess(self.cmd('upgrade', self.url, repos_path)) result = self.env.run('migrate upgrade %s %s' % (self.url, repos_path))
self.assertEquals(self.cmd_db_version(self.url, repos_path), 0) self.assertEquals(self.run_db_version(self.url, repos_path), 0)
self.assertSuccess(self.cmd('upgrade', self.url, repos_path, 0)) result = self.env.run('migrate upgrade %s %s' % (self.url, repos_path))
self.assertEquals(self.cmd_db_version(self.url, repos_path), 0) self.assertEquals(self.run_db_version(self.url, repos_path), 0)
self.assertFailure(self.cmd('upgrade', self.url, repos_path, 1)) result = self.env.run('migrate upgrade %s %s 1' % (self.url, repos_path), expect_error=True)
self.assertFailure(self.cmd('upgrade', self.url, repos_path, -1)) self.assertEquals(result.returncode, 1)
result = self.env.run('migrate upgrade %s %s -1' % (self.url, repos_path), expect_error=True)
self.assertEquals(result.returncode, 2)
# Add a script to the repository; upgrade the db # Add a script to the repository; upgrade the db
self.assertSuccess(self.cmd('script', '--repository=%s' % repos_path, 'Desc')) result = self.env.run('migrate script Desc --repository=%s' % (repos_path))
self.assertEquals(self.cmd_version(repos_path), 1) self.assertEquals(self.run_version(repos_path), 1)
self.assertEquals(self.cmd_db_version(self.url, repos_path), 0) self.assertEquals(self.run_db_version(self.url, repos_path), 0)
# Test preview # Test preview
self.assertSuccess(self.cmd('upgrade', self.url, repos_path, 0, "--preview_sql")) result = self.env.run('migrate upgrade %s %s 0 --preview_sql' % (self.url, repos_path))
self.assertSuccess(self.cmd('upgrade', self.url, repos_path, 0, "--preview_py")) result = self.env.run('migrate upgrade %s %s 0 --preview_py' % (self.url, repos_path))
self.assertSuccess(self.cmd('upgrade', self.url, repos_path)) result = self.env.run('migrate upgrade %s %s' % (self.url, repos_path))
self.assertEquals(self.cmd_db_version(self.url, repos_path), 1) self.assertEquals(self.run_db_version(self.url, repos_path), 1)
# Downgrade must have a valid version specified # Downgrade must have a valid version specified
self.assertFailure(self.cmd('downgrade', self.url, repos_path)) result = self.env.run('migrate downgrade %s %s' % (self.url, repos_path), expect_error=True)
self.assertFailure(self.cmd('downgrade', self.url, repos_path, '-1', 2)) self.assertEquals(result.returncode, 2)
#self.assertFailure(self.cmd('downgrade', self.url, repos_path, '1', 2)) result = self.env.run('migrate downgrade %s %s -1' % (self.url, repos_path), expect_error=True)
self.assertEquals(self.cmd_db_version(self.url, repos_path), 1) self.assertEquals(result.returncode, 2)
result = self.env.run('migrate downgrade %s %s 2' % (self.url, repos_path), expect_error=True)
self.assertEquals(result.returncode, 2)
self.assertEquals(self.run_db_version(self.url, repos_path), 1)
self.assertSuccess(self.cmd('downgrade', self.url, repos_path, 0)) result = self.env.run('migrate downgrade %s %s 0' % (self.url, repos_path))
self.assertEquals(self.cmd_db_version(self.url, repos_path), 0) self.assertEquals(self.run_db_version(self.url, repos_path), 0)
self.assertFailure(self.cmd('downgrade',self.url, repos_path, 1)) result = self.env.run('migrate downgrade %s %s 1' % (self.url, repos_path), expect_error=True)
self.assertEquals(self.cmd_db_version(self.url, repos_path), 0) self.assertEquals(result.returncode, 2)
self.assertEquals(self.run_db_version(self.url, repos_path), 0)
self.assertSuccess(self.cmd('drop_version_control', self.url, repos_path)) result = self.env.run('migrate drop_version_control %s %s' % (self.url, repos_path))
def _run_test_sqlfile(self, upgrade_script, downgrade_script): def _run_test_sqlfile(self, upgrade_script, downgrade_script):
# TODO: add test script that checks if db really changed # TODO: add test script that checks if db really changed
repos_path = self.tmp() repos_path = self.tmp()
repos_name = 'repos' repos_name = 'repos'
self.assertSuccess(self.cmd('create', repos_path, repos_name))
self.exitcode(self.cmd('drop_version_control', self.url, repos_path)) result = self.env.run('migrate create %s %s' % (repos_path, repos_name))
self.assertSuccess(self.cmd('version_control', self.url, repos_path)) result = self.env.run('migrate drop_version_control %s %s' % (self.url, repos_path), expect_error=True)
self.assertEquals(self.cmd_version(repos_path), 0) result = self.env.run('migrate version_control %s %s' % (self.url, repos_path))
self.assertEquals(self.cmd_db_version(self.url,repos_path), 0) self.assertEquals(self.run_version(repos_path), 0)
self.assertEquals(self.run_db_version(self.url, repos_path), 0)
beforeCount = len(os.listdir(os.path.join(repos_path, 'versions'))) # hmm, this number changes sometimes based on running from svn beforeCount = len(os.listdir(os.path.join(repos_path, 'versions'))) # hmm, this number changes sometimes based on running from svn
self.assertSuccess(self.cmd('script_sql', '--repository=%s' % repos_path, 'postgres')) result = self.env.run('migrate script_sql %s --repository=%s' % ('postgres', repos_path))
self.assertEquals(self.cmd_version(repos_path), 1) self.assertEquals(self.run_version(repos_path), 1)
self.assertEquals(len(os.listdir(os.path.join(repos_path,'versions'))), beforeCount + 2) self.assertEquals(len(os.listdir(os.path.join(repos_path, 'versions'))), beforeCount + 2)
open('%s/versions/001_postgres_upgrade.sql' % repos_path, 'a').write(upgrade_script) open('%s/versions/001_postgres_upgrade.sql' % repos_path, 'a').write(upgrade_script)
open('%s/versions/001_postgres_downgrade.sql' % repos_path, 'a').write(downgrade_script) open('%s/versions/001_postgres_downgrade.sql' % repos_path, 'a').write(downgrade_script)
self.assertEquals(self.cmd_db_version(self.url, repos_path), 0) self.assertEquals(self.run_db_version(self.url, repos_path), 0)
self.assertRaises(Exception, self.engine.text('select * from t_table').execute) self.assertRaises(Exception, self.engine.text('select * from t_table').execute)
self.assertSuccess(self.cmd('upgrade', self.url,repos_path)) result = self.env.run('migrate upgrade %s %s' % (self.url, repos_path))
self.assertEquals(self.cmd_db_version(self.url,repos_path), 1) self.assertEquals(self.run_db_version(self.url, repos_path), 1)
self.engine.text('select * from t_table').execute() self.engine.text('select * from t_table').execute()
self.assertSuccess(self.cmd('downgrade', self.url, repos_path, 0)) result = self.env.run('migrate downgrade %s %s 0' % (self.url, repos_path))
self.assertEquals(self.cmd_db_version(self.url, repos_path), 0) self.assertEquals(self.run_db_version(self.url, repos_path), 0)
self.assertRaises(Exception, self.engine.text('select * from t_table').execute) self.assertRaises(Exception, self.engine.text('select * from t_table').execute)
# The tests below are written with some postgres syntax, but the stuff # The tests below are written with some postgres syntax, but the stuff
# being tested (.sql files) ought to work with any db. # being tested (.sql files) ought to work with any db.
@fixture.usedb(supported='postgres') @usedb(supported='postgres')
def test_sqlfile(self): def test_sqlfile(self):
upgrade_script = """ upgrade_script = """
create table t_table ( create table t_table (
@ -381,8 +348,7 @@ class TestShellDatabase(Shell, fixture.DB):
self.meta.drop_all() self.meta.drop_all()
self._run_test_sqlfile(upgrade_script, downgrade_script) self._run_test_sqlfile(upgrade_script, downgrade_script)
@usedb(supported='postgres')
@fixture.usedb(supported='postgres')
def test_sqlfile_comment(self): def test_sqlfile_comment(self):
upgrade_script = """ upgrade_script = """
-- Comments in SQL break postgres autocommit -- Comments in SQL break postgres autocommit
@ -395,28 +361,28 @@ class TestShellDatabase(Shell, fixture.DB):
-- Comments in SQL break postgres autocommit -- Comments in SQL break postgres autocommit
drop table t_table; drop table t_table;
""" """
self._run_test_sqlfile(upgrade_script,downgrade_script) self._run_test_sqlfile(upgrade_script, downgrade_script)
@fixture.usedb() @usedb()
def test_command_test(self): def test_command_test(self):
repos_name = 'repos_name' repos_name = 'repos_name'
repos_path = self.tmp() repos_path = self.tmp()
self.assertSuccess(self.cmd('create', repos_path, repos_name)) result = self.env.run('migrate create repository_name --repository=%s' % repos_path)
self.exitcode(self.cmd('drop_version_control', self.url, repos_path)) result = self.env.run('migrate drop_version_control %s %s' % (self.url, repos_path), expect_error=True)
self.assertSuccess(self.cmd('version_control', self.url, repos_path)) result = self.env.run('migrate version_control %s %s' % (self.url, repos_path))
self.assertEquals(self.cmd_version(repos_path), 0) self.assertEquals(self.run_version(repos_path), 0)
self.assertEquals(self.cmd_db_version(self.url, repos_path), 0) self.assertEquals(self.run_db_version(self.url, repos_path), 0)
# Empty script should succeed # Empty script should succeed
self.assertSuccess(self.cmd('script', '--repository=%s' % repos_path, 'Desc')) result = self.env.run('migrate script Desc %s' % repos_path)
self.assertSuccess(self.cmd('test', repos_path, self.url)) result = self.env.run('migrate test %s %s' % (self.url, repos_path))
self.assertEquals(self.cmd_version(repos_path), 1) self.assertEquals(self.run_version(repos_path), 1)
self.assertEquals(self.cmd_db_version(self.url, repos_path), 0) self.assertEquals(self.run_db_version(self.url, repos_path), 0)
# Error script should fail # Error script should fail
script_path = self.tmp_py() script_path = self.tmp_py()
script_text=""" script_text='''
from sqlalchemy import * from sqlalchemy import *
from migrate import * from migrate import *
@ -427,26 +393,27 @@ class TestShellDatabase(Shell, fixture.DB):
def downgrade(): def downgrade():
print 'sdfsgf' print 'sdfsgf'
raise Exception() raise Exception()
""".replace("\n ","\n") '''.replace("\n ", "\n")
file = open(script_path, 'w') file = open(script_path, 'w')
file.write(script_text) file.write(script_text)
file.close() file.close()
self.assertFailure(self.cmd('test', repos_path, self.url, 'blah blah')) result = self.env.run('migrate test %s %s bla' % (self.url, repos_path), expect_error=True)
self.assertEquals(self.cmd_version(repos_path), 1) self.assertEqual(result.returncode, 2)
self.assertEquals(self.cmd_db_version(self.url, repos_path),0) self.assertEquals(self.run_version(repos_path), 1)
self.assertEquals(self.run_db_version(self.url, repos_path), 0)
# Nonempty script using migrate_engine should succeed # Nonempty script using migrate_engine should succeed
script_path = self.tmp_py() script_path = self.tmp_py()
script_text=""" script_text = '''
from sqlalchemy import * from sqlalchemy import *
from migrate import * from migrate import *
meta = MetaData(migrate_engine) meta = MetaData(migrate_engine)
account = Table('account',meta, account = Table('account', meta,
Column('id',Integer,primary_key=True), Column('id', Integer, primary_key=True),
Column('login',String(40)), Column('login', String(40)),
Column('passwd',String(40)), Column('passwd', String(40)),
) )
def upgrade(): def upgrade():
# Upgrade operations go here. Don't create your own engine; use the engine # Upgrade operations go here. Don't create your own engine; use the engine
@ -456,113 +423,104 @@ class TestShellDatabase(Shell, fixture.DB):
def downgrade(): def downgrade():
# Operations to reverse the above upgrade go here. # Operations to reverse the above upgrade go here.
meta.drop_all() meta.drop_all()
""".replace("\n ","\n") '''.replace("\n ", "\n")
file = open(script_path, 'w') file = open(script_path, 'w')
file.write(script_text) file.write(script_text)
file.close() file.close()
self.assertSuccess(self.cmd('test', repos_path, self.url)) result = self.env.run('migrate test %s %s' % (self.url, repos_path))
self.assertEquals(self.cmd_version(repos_path), 1) self.assertEquals(self.run_version(repos_path), 1)
self.assertEquals(self.cmd_db_version(self.url, repos_path), 0) self.assertEquals(self.run_db_version(self.url, repos_path), 0)
@fixture.usedb() @usedb()
def test_rundiffs_in_shell(self): def test_rundiffs_in_shell(self):
# This is a variant of the test_schemadiff tests but run through the shell level. # This is a variant of the test_schemadiff tests but run through the shell level.
# These shell tests are hard to debug (since they keep forking processes), so they shouldn't replace the lower-level tests. # These shell tests are hard to debug (since they keep forking processes), so they shouldn't replace the lower-level tests.
repos_name = 'repos_name' repos_name = 'repos_name'
repos_path = self.tmp() repos_path = self.tmp()
script_path = self.tmp_py() script_path = self.tmp_py()
old_model_path = self.tmp_named('oldtestmodel.py') model_module = 'test.fixture.models:meta_rundiffs'
model_path = self.tmp_named('testmodel.py') old_model_module = 'test.fixture.models:meta_old_rundiffs'
# Create empty repository. # Create empty repository.
self.meta = MetaData(self.engine, reflect=True) self.meta = MetaData(self.engine, reflect=True)
self.meta.reflect()
self.meta.drop_all() # in case junk tables are lying around in the test database self.meta.drop_all() # in case junk tables are lying around in the test database
self.assertSuccess(self.cmd('create',repos_path,repos_name))
self.exitcode(self.cmd('drop_version_control',self.url,repos_path)) result = self.env.run('migrate create %s %s' % (repos_path, repos_name))
self.assertSuccess(self.cmd('version_control',self.url,repos_path)) result = self.env.run('migrate drop_version_control %s %s' % (self.url, repos_path), expect_error=True)
self.assertEquals(self.cmd_version(repos_path),0) result = self.env.run('migrate version_control %s %s' % (self.url, repos_path))
self.assertEquals(self.cmd_db_version(self.url,repos_path),0) self.assertEquals(self.run_version(repos_path), 0)
self.assertEquals(self.run_db_version(self.url, repos_path), 0)
# Setup helper script. # Setup helper script.
model_module = 'testmodel:meta' result = self.env.run('migrate manage %s --repository=%s --url=%s --model=%s'\
self.assertSuccess(self.cmd('manage',script_path,'--repository=%s --url=%s --model=%s' % (repos_path, self.url, model_module))) % (script_path, repos_path, self.url, model_module))
self.assert_(os.path.exists(script_path)) self.assert_(os.path.exists(script_path))
# Write old and new model to disk - old model is empty!
script_preamble="""
from sqlalchemy import *
meta = MetaData()
""".replace("\n ","\n")
script_text="""
""".replace("\n ","\n")
open(old_model_path, 'w').write(script_preamble + script_text)
script_text="""
tmp_account_rundiffs = Table('tmp_account_rundiffs',meta,
Column('id',Integer,primary_key=True),
Column('login',String(40)),
Column('passwd',String(40)),
)
""".replace("\n ","\n")
open(model_path, 'w').write(script_preamble + script_text)
# Model is defined but database is empty. # Model is defined but database is empty.
output, exitcode = self.output_and_exitcode('%s %s compare_model_to_db' % (sys.executable, script_path)) result = self.env.run('migrate compare_model_to_db %s %s --model=%s' % (self.url, repos_path, model_module))
assert "tables missing in database: tmp_account_rundiffs" in output, output self.assert_("tables missing in database: tmp_account_rundiffs" in result.stdout)
# Test Deprecation # Test Deprecation
output, exitcode = self.output_and_exitcode('%s %s compare_model_to_db --model=testmodel.meta' % (sys.executable, script_path)) result = self.env.run('migrate compare_model_to_db %s %s --model=%s' % (self.url, repos_path, model_module.replace(":", ".")), expect_error=True)
assert "tables missing in database: tmp_account_rundiffs" in output, output self.assertEqual(result.returncode, 0)
self.assertTrue("DeprecationWarning" in result.stderr)
self.assert_("tables missing in database: tmp_account_rundiffs" in result.stdout)
# Update db to latest model. # Update db to latest model.
output, exitcode = self.output_and_exitcode('%s %s update_db_from_model' % (sys.executable, script_path)) result = self.env.run('migrate update_db_from_model %s %s %s'\
self.assertEquals(exitcode, None) % (self.url, repos_path, model_module))
self.assertEquals(self.cmd_version(repos_path),0) self.assertEquals(self.run_version(repos_path), 0)
self.assertEquals(self.cmd_db_version(self.url,repos_path),0) # version did not get bumped yet because new version not yet created self.assertEquals(self.run_db_version(self.url, repos_path), 0) # version did not get bumped yet because new version not yet created
output, exitcode = self.output_and_exitcode('%s %s compare_model_to_db' % (sys.executable, script_path))
assert "No schema diffs" in output, output result = self.env.run('migrate compare_model_to_db %s %s %s'\
output, exitcode = self.output_and_exitcode('%s %s create_model' % (sys.executable, script_path)) % (self.url, repos_path, model_module))
output = output.replace(genmodel.HEADER.strip(), '') # need strip b/c output_and_exitcode called strip self.assert_("No schema diffs" in result.stdout)
assert """tmp_account_rundiffs = Table('tmp_account_rundiffs', meta,
result = self.env.run('migrate drop_version_control %s %s' % (self.url, repos_path), expect_error=True)
result = self.env.run('migrate version_control %s %s' % (self.url, repos_path))
result = self.env.run('migrate create_model %s %s' % (self.url, repos_path))
self.assertTrue("""tmp_account_rundiffs = Table('tmp_account_rundiffs', meta,
Column('id', Integer(), primary_key=True, nullable=False), Column('id', Integer(), primary_key=True, nullable=False),
Column('login', String(length=None, convert_unicode=False, assert_unicode=None)), Column('login', String(length=None, convert_unicode=False, assert_unicode=None)),
Column('passwd', String(length=None, convert_unicode=False, assert_unicode=None)),""" in output.strip(), output Column('passwd', String(length=None, convert_unicode=False, assert_unicode=None))""" in result.stdout)
# We're happy with db changes, make first db upgrade script to go from version 0 -> 1. # We're happy with db changes, make first db upgrade script to go from version 0 -> 1.
output, exitcode = self.output_and_exitcode('%s %s make_update_script_for_model' % (sys.executable, script_path)) # intentionally omit a parameter result = self.env.run('migrate make_update_script_for_model', expect_error=True)
self.assertEquals('Not enough arguments' in output, True) self.assertTrue('Not enough arguments' in result.stderr)
output, exitcode = self.output_and_exitcode('%s %s make_update_script_for_model --oldmodel=oldtestmodel:meta' % (sys.executable, script_path))
self.assertEqualsIgnoreWhitespace(output,
"""from sqlalchemy import *
from migrate import *
meta = MetaData() result_script = self.env.run('migrate make_update_script_for_model %s %s %s %s'\
tmp_account_rundiffs = Table('tmp_account_rundiffs', meta, % (self.url, repos_path, old_model_module, model_module))
Column('id', Integer(), primary_key=True, nullable=False), self.assertEqualsIgnoreWhitespace(result_script.stdout,
Column('login', String(length=40, convert_unicode=False, assert_unicode=None)), '''from sqlalchemy import *
Column('passwd', String(length=40, convert_unicode=False, assert_unicode=None)), from migrate import *
)
def upgrade(migrate_engine): meta = MetaData()
# Upgrade operations go here. Don't create your own engine; bind migrate_engine tmp_account_rundiffs = Table('tmp_account_rundiffs', meta,
# to your metadata Column('id', Integer(), primary_key=True, nullable=False),
meta.bind(migrate_engine) Column('login', String(length=40, convert_unicode=False, assert_unicode=None)),
tmp_account_rundiffs.create() Column('passwd', String(length=40, convert_unicode=False, assert_unicode=None)),
)
def downgrade(migrate_engine): def upgrade(migrate_engine):
# Operations to reverse the above upgrade go here. # Upgrade operations go here. Don't create your own engine; bind migrate_engine
meta.bind(migrate_engine) # to your metadata
tmp_account_rundiffs.drop()""") meta.bind = migrate_engine
tmp_account_rundiffs.create()
def downgrade(migrate_engine):
# Operations to reverse the above upgrade go here.
meta.bind = migrate_engine
tmp_account_rundiffs.drop()''')
# Save the upgrade script. # Save the upgrade script.
self.assertSuccess(self.cmd('script', '--repository=%s' % repos_path, 'Desc')) result = self.env.run('migrate script Desc %s' % repos_path)
upgrade_script_path = '%s/versions/001_Desc.py' % repos_path upgrade_script_path = '%s/versions/001_Desc.py' % repos_path
open(upgrade_script_path, 'w').write(output) open(upgrade_script_path, 'w').write(result_script.stdout)
#output, exitcode = self.output_and_exitcode('%s %s test %s' % (sys.executable, script_path, upgrade_script_path)) # no, we already upgraded the db above
#self.assertEquals(output, "") result = self.env.run('migrate compare_model_to_db %s %s %s'\
output, exitcode = self.output_and_exitcode('%s %s update_db_from_model' % (sys.executable, script_path)) # bump the db_version % (self.url, repos_path, model_module))
self.assertEquals(exitcode, None) self.assert_("No schema diffs" in result.stdout)
self.assertEquals(self.cmd_version(repos_path),1)
self.assertEquals(self.cmd_db_version(self.url,repos_path),1) self.meta.drop_all() # in case junk tables are lying around in the test database

View File

@ -36,10 +36,13 @@ class TestUtil(fixture.Pathed):
engine_arg_assert_unicode=True) engine_arg_assert_unicode=True)
self.assertTrue(engine.dialect.assert_unicode) self.assertTrue(engine.dialect.assert_unicode)
# deprecated echo= parameter # deprecated echo=True parameter
engine = construct_engine(url, echo='True') engine = construct_engine(url, echo='True')
self.assertTrue(engine.echo) self.assertTrue(engine.echo)
# unsupported argument
self.assertRaises(ValueError, construct_engine, 1)
def test_asbool(self): def test_asbool(self):
"""test asbool parsing""" """test asbool parsing"""
result = asbool(True) result = asbool(True)

View File

@ -3,6 +3,7 @@
from test import fixture from test import fixture
from migrate.versioning.version import * from migrate.versioning.version import *
from migrate.versioning.exceptions import *
class TestVerNum(fixture.Base): class TestVerNum(fixture.Base):
@ -12,6 +13,11 @@ class TestVerNum(fixture.Base):
for version in versions: for version in versions:
self.assertRaises(ValueError, VerNum, version) self.assertRaises(ValueError, VerNum, version)
def test_str(self):
"""Test str and repr version numbers"""
self.assertEqual(str(VerNum(2)), '2')
self.assertEqual(repr(VerNum(2)), '<VerNum(2)>')
def test_is(self): def test_is(self):
"""Two version with the same number should be equal""" """Two version with the same number should be equal"""
a = VerNum(1) a = VerNum(1)
@ -62,6 +68,7 @@ class TestVerNum(fixture.Base):
self.assert_(VerNum(2) >= 1) self.assert_(VerNum(2) >= 1)
self.assertFalse(VerNum(1) >= 2) self.assertFalse(VerNum(1) >= 2)
class TestVersion(fixture.Pathed): class TestVersion(fixture.Pathed):
def setUp(self): def setUp(self):
@ -91,12 +98,18 @@ class TestVersion(fixture.Pathed):
coll2 = Collection(self.temp_usable_dir) coll2 = Collection(self.temp_usable_dir)
self.assertEqual(coll.versions, coll2.versions) self.assertEqual(coll.versions, coll2.versions)
#def test_collection_unicode(self): Collection.clear()
def test_old_repository(self):
open(os.path.join(self.temp_usable_dir, '1'), 'w')
self.assertRaises(Exception, Collection, self.temp_usable_dir)
#TODO: def test_collection_unicode(self):
# pass # pass
def test_create_new_python_version(self): def test_create_new_python_version(self):
coll = Collection(self.temp_usable_dir) coll = Collection(self.temp_usable_dir)
coll.create_new_python_version("foo bar") coll.create_new_python_version("'")
ver = coll.version() ver = coll.version()
self.assert_(ver.script().source()) self.assert_(ver.script().source())
@ -140,3 +153,12 @@ class TestVersion(fixture.Pathed):
ver = Version(1, path, [sqlite_upgrade_file, python_file]) ver = Version(1, path, [sqlite_upgrade_file, python_file])
self.assertEquals(os.path.basename(ver.script('postgres', 'upgrade').path), python_file) self.assertEquals(os.path.basename(ver.script('postgres', 'upgrade').path), python_file)
def test_bad_version(self):
ver = Version(1, self.temp_usable_dir, [])
self.assertRaises(ScriptError, ver.add_script, '123.sql')
pyscript = os.path.join(self.temp_usable_dir, 'bla.py')
open(pyscript, 'w')
ver.add_script(pyscript)
self.assertRaises(ScriptError, ver.add_script, 'bla.py')