Issue 34; preview_sql now correctly displays SQL on python and SQL scripts. (tests added, docs still missing)

This commit is contained in:
iElectric 2009-06-06 10:34:22 +00:00
parent 4356e8b582
commit 938bbf9bf3
10 changed files with 225 additions and 132 deletions

View File

@ -1,7 +1,9 @@
0.5.4
- fixed preview_sql parameter for downgrade/upgrade. Now it prints SQL if the step is SQL script
and runs step with mocked engine to only print SQL statements if ORM is used. [Domen Kozar]
- use entrypoints terminology to specify dotted model names (module.model.User) [Domen Kozar]
- added engine_dict and engine_arg_* parameters to all api functions [Domen Kozar]
- make --echo parameter a bit more forgivable (better Python API support) [Domen Kozar]
- added engine_dict and engine_arg_* parameters to all api functions (deprecated echo) [Domen Kozar]
- make --echo parameter a bit more forgivable (better Python API support) [Domen Kozar]
- apply patch to refactor cmd line parsing for Issue 54 by Domen Kozar
0.5.3

View File

@ -16,11 +16,10 @@ import sys
import inspect
import warnings
from sqlalchemy import create_engine
from migrate.versioning import (exceptions, repository, schema, version,
script as script_) # command name conflict
from migrate.versioning.util import asbool, catch_known_errors, guess_obj_type
from migrate.versioning.util import catch_known_errors, construct_engine
__all__ = [
'help',
@ -46,6 +45,7 @@ Repository = repository.Repository
ControlledSchema = schema.ControlledSchema
VerNum = version.VerNum
PythonScript = script_.PythonScript
SqlScript = script_.SqlScript
# deprecated
@ -117,7 +117,7 @@ def test(repository, url=None, **opts):
bad state. You should therefore better run the test on a copy of
your database.
"""
engine = _construct_engine(url, **opts)
engine = construct_engine(url, **opts)
repos = Repository(repository)
script = repos.version(None).script()
@ -179,7 +179,7 @@ def version_control(url, repository, version=None, **opts):
identical to what it would be if the database were created from
scratch.
"""
engine = _construct_engine(url, **opts)
engine = construct_engine(url, **opts)
ControlledSchema.create(engine, repository, version)
@ -192,7 +192,7 @@ def db_version(url, repository, **opts):
The url should be any valid SQLAlchemy connection string.
"""
engine = _construct_engine(url, **opts)
engine = construct_engine(url, **opts)
schema = ControlledSchema(engine, repository)
return schema.version
@ -236,7 +236,7 @@ def drop_version_control(url, repository, **opts):
Removes version control from a database.
"""
engine = _construct_engine(url, **opts)
engine = construct_engine(url, **opts)
schema = ControlledSchema(engine, repository)
schema.drop()
@ -268,7 +268,7 @@ def compare_model_to_db(url, model, repository, **opts):
NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label
engine = _construct_engine(url, **opts)
engine = construct_engine(url, **opts)
print ControlledSchema.compare_model_to_db(engine, model, repository)
@ -279,7 +279,7 @@ def create_model(url, repository, **opts):
NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label
engine = _construct_engine(url, **opts)
engine = construct_engine(url, **opts)
declarative = opts.get('declarative', False)
print ControlledSchema.create_model(engine, repository, declarative)
@ -294,7 +294,7 @@ def make_update_script_for_model(url, oldmodel, model, repository, **opts):
NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label
engine = _construct_engine(url, **opts)
engine = construct_engine(url, **opts)
print PythonScript.make_update_script_for_model(
engine, oldmodel, model, repository, **opts)
@ -308,30 +308,37 @@ def update_db_from_model(url, model, repository, **opts):
NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label
engine = _construct_engine(url, **opts)
engine = construct_engine(url, **opts)
schema = ControlledSchema(engine, repository)
schema.update_db_from_model(model)
def _migrate(url, repository, version, upgrade, err, **opts):
engine = _construct_engine(url, **opts)
engine = construct_engine(url, **opts)
schema = ControlledSchema(engine, repository)
version = _migrate_version(schema, version, upgrade, err)
changeset = schema.changeset(version)
for ver, change in changeset:
nextver = ver + changeset.step
print '%s -> %s... ' % (ver, nextver),
print '%s -> %s... ' % (ver, nextver)
if opts.get('preview_sql'):
print
print change.log
if isinstance(change, PythonScript):
print change.preview_sql(url, changeset.step, **opts)
elif isinstance(change, SqlScript):
print change.source()
elif opts.get('preview_py'):
source_ver = max(ver, nextver)
module = schema.repository.version(source_ver).script().module
funcname = upgrade and "upgrade" or "downgrade"
func = getattr(module, funcname)
print
print inspect.getsource(module.upgrade)
if isinstance(change, PythonScript):
print inspect.getsource(func)
else:
raise UsageError("Python source can be only displayed"
" for python migration files")
else:
schema.runchange(ver, change, changeset.step)
print 'done'
@ -352,39 +359,3 @@ def _migrate_version(schema, version, upgrade, err):
if not direction:
raise exceptions.KnownError(err % (cur, version))
return version
def _construct_engine(url, **opts):
"""Constructs and returns SQLAlchemy engine.
Currently, there are 2 ways to pass create_engine options to api functions:
* keyword parameters (starting with `engine_arg_*`)
* python dictionary of options (`engine_dict`)
NOTE: keyword parameters override `engine_dict` values.
.. versionadded:: 0.5.4
"""
# TODO: include docs
# get options for create_engine
if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict):
kwargs = opts['engine_dict']
else:
kwargs = dict()
# DEPRECATED: handle echo the old way
echo = asbool(opts.get('echo', False))
if echo:
warnings.warn('echo=True parameter is deprecated, pass '
'engine_arg_echo=True or engine_dict={"echo": True}',
DeprecationWarning)
kwargs['echo'] = echo
# parse keyword arguments
for key, value in opts.iteritems():
if key.startswith('engine_arg_'):
kwargs[key[11:]] = guess_obj_type(value)
return create_engine(url, **kwargs)

View File

@ -1,6 +1,9 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from migrate.versioning.base import log,operations
from migrate.versioning import pathed,exceptions
# import migrate.run
class BaseScript(pathed.Pathed):
"""Base class for other types of scripts
@ -17,10 +20,10 @@ class BaseScript(pathed.Pathed):
"""
def __init__(self,path):
log.info('Loading script %s...'%path)
log.info('Loading script %s...' % path)
self.verify(path)
super(BaseScript,self).__init__(path)
log.info('Script %s loaded successfully'%path)
super(BaseScript, self).__init__(path)
log.info('Script %s loaded successfully' % path)
@classmethod
def verify(cls,path):
@ -33,10 +36,10 @@ class BaseScript(pathed.Pathed):
raise exceptions.InvalidScriptError(path)
def source(self):
fd=open(self.path)
ret=fd.read()
fd = open(self.path)
ret = fd.read()
fd.close()
return ret
def run(self,engine):
def run(self, engine):
raise NotImplementedError()

View File

@ -1,15 +1,20 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import shutil
#import migrate.run
from StringIO import StringIO
import migrate
from migrate.versioning import exceptions, genmodel, schemadiff
from migrate.versioning.base import operations
from migrate.versioning.template import template
from migrate.versioning.script import base
from migrate.versioning.util import import_path, loadModel
import migrate
from migrate.versioning.util import import_path, loadModel, construct_engine
class PythonScript(base.BaseScript):
@classmethod
def create(cls,path,**opts):
def create(cls, path, **opts):
"""Create an empty migration script"""
cls.require_notfound(path)
@ -18,29 +23,38 @@ class PythonScript(base.BaseScript):
# different one later.
template_file = None
src = template.get_script(template_file)
shutil.copy(src,path)
shutil.copy(src, path)
@classmethod
def make_update_script_for_model(cls,engine,oldmodel,model,repository,**opts):
def make_update_script_for_model(cls, engine, oldmodel,
model, repository, **opts):
"""Create a migration script"""
# Compute differences.
if isinstance(repository, basestring):
from migrate.versioning.repository import Repository # oh dear, an import cycle!
repository=Repository(repository)
# oh dear, an import cycle!
from migrate.versioning.repository import Repository
repository = Repository(repository)
oldmodel = loadModel(oldmodel)
model = loadModel(model)
diff = schemadiff.getDiffOfModelAgainstModel(oldmodel, model, engine, excludeTables=[repository.version_table])
decls, upgradeCommands, downgradeCommands = genmodel.ModelGenerator(diff).toUpgradeDowngradePython()
diff = schemadiff.getDiffOfModelAgainstModel(
oldmodel,
model,
engine,
excludeTables=[repository.version_table])
decls, upgradeCommands, downgradeCommands = \
genmodel.ModelGenerator(diff).toUpgradeDowngradePython()
# Store differences into file.
template_file = None
src = template.get_script(template_file)
contents = open(src).read()
search = 'def upgrade():'
contents = contents.replace(search, decls + '\n\n' + search, 1)
if upgradeCommands: contents = contents.replace(' pass', upgradeCommands, 1)
if downgradeCommands: contents = contents.replace(' pass', downgradeCommands, 1)
contents = contents.replace(search, '\n\n'.join((decls, search)), 1)
if upgradeCommands:
contents = contents.replace(' pass', upgradeCommands, 1)
if downgradeCommands:
contents = contents.replace(' pass', downgradeCommands, 1)
return contents
@classmethod
@ -54,31 +68,31 @@ class PythonScript(base.BaseScript):
raise
try:
assert callable(module.upgrade)
except Exception,e:
raise exceptions.InvalidScriptError(path+': %s'%str(e))
except Exception, e:
raise exceptions.InvalidScriptError(path + ': %s' % str(e))
return module
def _get_module(self):
if not hasattr(self,'_module'):
self._module = self.verify_module(self.path)
return self._module
module = property(_get_module)
def preview_sql(self, url, step, **args):
"""Mock engine to store all executable calls in a string \
and execute the step"""
buf = StringIO()
args['engine_arg_strategy'] = 'mock'
args['engine_arg_executor'] = lambda s, p='': buf.write(s + p)
engine = construct_engine(url, **args)
self.run(engine, step)
def _func(self,funcname):
fn = getattr(self.module, funcname, None)
if not fn:
msg = "The function %s is not defined in this script"
raise exceptions.ScriptError(msg%funcname)
return fn
return buf.getvalue()
def run(self,engine,step):
def run(self, engine, step):
"""Core method of Script file. \
Exectues update() or downgrade() function"""
if step > 0:
op = 'upgrade'
elif step < 0:
op = 'downgrade'
else:
raise exceptions.ScriptError("%d is not a valid step"%step)
raise exceptions.ScriptError("%d is not a valid step" % step)
funcname = base.operations[op]
migrate.migrate_engine = engine
@ -87,3 +101,17 @@ class PythonScript(base.BaseScript):
func()
migrate.migrate_engine = None
#migrate.run.migrate_engine = migrate.migrate_engine = None
def _get_module(self):
if not hasattr(self,'_module'):
self._module = self.verify_module(self.path)
return self._module
module = property(_get_module)
def _func(self, funcname):
fn = getattr(self.module, funcname, None)
if not fn:
msg = "The function %s is not defined in this script"
raise exceptions.ScriptError(msg%funcname)
return fn

View File

@ -5,6 +5,8 @@ import warnings
from decorator import decorator
from pkg_resources import EntryPoint
from sqlalchemy import create_engine
from migrate.versioning import exceptions
from migrate.versioning.util.keyedinstance import KeyedInstance
from migrate.versioning.util.importpath import import_path
@ -33,7 +35,10 @@ def asbool(obj):
return False
else:
raise ValueError("String is not true/false: %r" % obj)
return bool(obj)
if obj in (True, False):
return bool(obj)
else:
raise ValueError("String is not true/false: %r" % obj)
def guess_obj_type(obj):
"""Do everything to guess object type from string"""
@ -63,3 +68,38 @@ def catch_known_errors(f, *a, **kw):
f(*a, **kw)
except exceptions.PathFoundError, e:
raise exceptions.KnownError("The path %s already exists" % e.args[0])
def construct_engine(url, **opts):
"""Constructs and returns SQLAlchemy engine.
Currently, there are 2 ways to pass create_engine options to api functions:
* keyword parameters (starting with `engine_arg_*`)
* python dictionary of options (`engine_dict`)
NOTE: keyword parameters override `engine_dict` values.
.. versionadded:: 0.5.4
"""
# TODO: include docs
# get options for create_engine
if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict):
kwargs = opts['engine_dict']
else:
kwargs = dict()
# DEPRECATED: handle echo the old way
echo = asbool(opts.get('echo', False))
if echo:
warnings.warn('echo=True parameter is deprecated, pass '
'engine_arg_echo=True or engine_dict={"echo": True}',
DeprecationWarning)
kwargs['echo'] = echo
# parse keyword arguments
for key, value in opts.iteritems():
if key.startswith('engine_arg_'):
kwargs[key[11:]] = guess_obj_type(value)
return create_engine(url, **kwargs)

View File

@ -21,7 +21,7 @@ class Base(unittest.TestCase):
def createLines(s):
s = s.replace(' ', '')
lines = s.split('\n')
return [line for line in lines if line]
return filter(None, lines)
lines1 = createLines(v1)
lines2 = createLines(v2)
self.assertEquals(len(lines1), len(lines2))

View File

@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
import os
from decorator import decorator
from sqlalchemy import create_engine, Table, MetaData
from sqlalchemy.orm import create_session
@ -74,6 +75,7 @@ def usedb(supported=None, not_supported=None):
yield func, self
self._teardown()
entangle.__name__ = func.__name__
entangle.__doc__ = func.__doc__
return entangle
return dec

View File

@ -1,49 +1,53 @@
import os,shutil,tempfile
import base
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import shutil
import tempfile
from test.fixture import base
class Pathed(base.Base):
# Temporary files
#repos='/tmp/test_repos_091x10'
#config=repos+'/migrate.cfg'
#script='/tmp/test_migration_script.py'
_tmpdir=tempfile.mkdtemp()
_tmpdir = tempfile.mkdtemp()
@classmethod
def _tmp(cls,prefix='',suffix=''):
def _tmp(cls, prefix='', suffix=''):
"""Generate a temporary file name that doesn't exist
All filenames are generated inside a temporary directory created by
tempfile.mkdtemp(); only the creating user has access to this directory.
It should be secure to return a nonexistant temp filename in this
directory, unless the user is messing with their own files.
"""
file,ret = tempfile.mkstemp(suffix,prefix,cls._tmpdir)
file, ret = tempfile.mkstemp(suffix,prefix,cls._tmpdir)
os.close(file)
os.remove(ret)
return ret
@classmethod
def tmp(cls,*p,**k):
return cls._tmp(*p,**k)
def tmp(cls, *p, **k):
return cls._tmp(*p, **k)
@classmethod
def tmp_py(cls,*p,**k):
return cls._tmp(suffix='.py',*p,**k)
def tmp_py(cls, *p, **k):
return cls._tmp(suffix='.py', *p, **k)
@classmethod
def tmp_sql(cls,*p,**k):
return cls._tmp(suffix='.sql',*p,**k)
def tmp_sql(cls, *p, **k):
return cls._tmp(suffix='.sql', *p, **k)
@classmethod
def tmp_named(cls,name):
return os.path.join(cls._tmpdir,name)
def tmp_named(cls, name):
return os.path.join(cls._tmpdir, name)
@classmethod
def tmp_repos(cls,*p,**k):
return cls._tmp(*p,**k)
def tmp_repos(cls, *p, **k):
return cls._tmp(*p, **k)
@classmethod
def purge(cls,path):
def purge(cls, path):
"""Removes this path if it exists, in preparation for tests
Careful - all tests should take place in /tmp.
We don't want to accidentally wipe stuff out...
@ -54,6 +58,6 @@ class Pathed(base.Base):
else:
os.remove(path)
if path.endswith('.py'):
pyc = path+'c'
pyc = path + 'c'
if os.path.exists(pyc):
os.remove(pyc)

View File

@ -1,13 +1,19 @@
from test import fixture
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import shutil
from migrate.versioning.script import *
from migrate.versioning import exceptions, version
import os,shutil
from test import fixture
class TestPyScript(fixture.Pathed):
cls = PythonScript
def test_create(self):
"""We can create a migration script"""
path=self.tmp_py()
path = self.tmp_py()
# Creating a file that doesn't exist should succeed
self.cls.create(path)
self.assert_(os.path.exists(path))
@ -18,8 +24,8 @@ class TestPyScript(fixture.Pathed):
def test_verify_notfound(self):
"""Correctly verify a python migration script: nonexistant file"""
path=self.tmp_py()
self.assert_(not os.path.exists(path))
path = self.tmp_py()
self.assertFalse(os.path.exists(path))
# Fails on empty path
self.assertRaises(exceptions.InvalidScriptError,self.cls.verify,path)
self.assertRaises(exceptions.InvalidScriptError,self.cls,path)
@ -38,19 +44,52 @@ class TestPyScript(fixture.Pathed):
def test_verify_nofuncs(self):
"""Correctly verify a python migration script: valid python file; no upgrade func"""
path=self.tmp_py()
path = self.tmp_py()
# Create empty file
f=open(path,'w')
f = open(path, 'w')
f.write("def zergling():\n\tprint 'rush'")
f.close()
self.assertRaises(exceptions.InvalidScriptError,self.cls.verify_module,path)
self.assertRaises(exceptions.InvalidScriptError, self.cls.verify_module, path)
# script isn't verified on creation, but on module reference
py = self.cls(path)
self.assertRaises(exceptions.InvalidScriptError,(lambda x: x.module),py)
@fixture.usedb(supported='sqlite')
def test_preview_sql(self):
"""Preview SQL abstract from ORM layer (sqlite)"""
path = self.tmp_py()
f = open(path, 'w')
content = """
from migrate import *
from sqlalchemy import *
metadata = MetaData(migrate_engine)
UserGroup = Table('Link', metadata,
Column('link1ID', Integer),
Column('link2ID', Integer),
UniqueConstraint('link1ID', 'link2ID'))
def upgrade():
metadata.create_all()
"""
f.write(content)
f.close()
pyscript = self.cls(path)
SQL = pyscript.preview_sql(self.url, 1)
self.assertEqualsIgnoreWhitespace("""
CREATE TABLE "Link"
("link1ID" INTEGER,
"link2ID" INTEGER,
UNIQUE ("link1ID", "link2ID"))
""", SQL)
# TODO: test: No SQL should be executed!
def test_verify_success(self):
"""Correctly verify a python migration script: success"""
path=self.tmp_py()
path = self.tmp_py()
# Succeeds after creating
self.cls.create(path)
self.cls.verify(path)
@ -66,8 +105,8 @@ class TestSqlScript(fixture.Pathed):
# Create files -- files must be present or you'll get an exception later.
sqlite_upgrade_file = '001_sqlite_upgrade.sql'
default_upgrade_file = '001_default_upgrade.sql'
for file in [sqlite_upgrade_file, default_upgrade_file]:
filepath = '%s/%s' % (path, file)
for file_ in [sqlite_upgrade_file, default_upgrade_file]:
filepath = '%s/%s' % (path, file_)
open(filepath, 'w').close()
ver = version.Version(1, path, [sqlite_upgrade_file])

View File

@ -163,21 +163,21 @@ class TestShellCommands(Shell):
"""Construct engine the smart way"""
url = 'sqlite://'
engine = api._construct_engine(url)
engine = api.construct_engine(url)
self.assert_(engine.name == 'sqlite')
# keyword arg
engine = api._construct_engine(url, engine_arg_assert_unicode=True)
self.assert_(engine.dialect.assert_unicode)
engine = api.construct_engine(url, engine_arg_assert_unicode=True)
self.assertTrue(engine.dialect.assert_unicode)
# dict
engine = api._construct_engine(url, engine_dict={'assert_unicode': True})
self.assert_(engine.dialect.assert_unicode)
engine = api.construct_engine(url, engine_dict={'assert_unicode': True})
self.assertTrue(engine.dialect.assert_unicode)
# test precedance
engine = api._construct_engine(url, engine_dict={'assert_unicode': False},
engine = api.construct_engine(url, engine_dict={'assert_unicode': False},
engine_arg_assert_unicode=True)
self.assert_(engine.dialect.assert_unicode)
self.assertTrue(engine.dialect.assert_unicode)
def test_manage(self):
"""Create a project management script"""
@ -327,8 +327,12 @@ class TestShellDatabase(Shell, fixture.DB):
# Add a script to the repository; upgrade the db
self.assertSuccess(self.cmd('script', '--repository=%s' % repos_path, 'Desc'))
self.assertEquals(self.cmd_version(repos_path), 1)
self.assertEquals(self.cmd_db_version(self.url, repos_path), 0)
# Test preview
self.assertSuccess(self.cmd('upgrade', self.url, repos_path, 0, "--preview_sql"))
self.assertSuccess(self.cmd('upgrade', self.url, repos_path, 0, "--preview_py"))
self.assertSuccess(self.cmd('upgrade', self.url, repos_path))
self.assertEquals(self.cmd_db_version(self.url, repos_path), 1)
@ -345,7 +349,7 @@ class TestShellDatabase(Shell, fixture.DB):
self.assertEquals(self.cmd_db_version(self.url, repos_path), 0)
self.assertSuccess(self.cmd('drop_version_control', self.url, repos_path))
def _run_test_sqlfile(self, upgrade_script, downgrade_script):
# TODO: add test script that checks if db really changed