From 938bbf9bf38d30c5f7c69ac88896333486386d49 Mon Sep 17 00:00:00 2001 From: iElectric Date: Sat, 6 Jun 2009 10:34:22 +0000 Subject: [PATCH] Issue 34; preview_sql now correctly displays SQL on python and SQL scripts. (tests added, docs still missing) --- CHANGELOG | 6 ++- migrate/versioning/api.py | 77 +++++++++----------------- migrate/versioning/script/base.py | 17 +++--- migrate/versioning/script/py.py | 84 +++++++++++++++++++---------- migrate/versioning/util/__init__.py | 42 ++++++++++++++- test/fixture/base.py | 2 +- test/fixture/database.py | 2 + test/fixture/pathed.py | 44 ++++++++------- test/versioning/test_script.py | 61 +++++++++++++++++---- test/versioning/test_shell.py | 22 ++++---- 10 files changed, 225 insertions(+), 132 deletions(-) diff --git a/CHANGELOG b/CHANGELOG index 6958671..8968def 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,7 +1,9 @@ 0.5.4 +- fixed preview_sql parameter for downgrade/upgrade. Now it prints SQL if the step is SQL script +and runs step with mocked engine to only print SQL statements if ORM is used. [Domen Kozar] - use entrypoints terminology to specify dotted model names (module.model.User) [Domen Kozar] -- added engine_dict and engine_arg_* parameters to all api functions [Domen Kozar] -- make --echo parameter a bit more forgivable (better Python API support) [Domen Kozar] +- added engine_dict and engine_arg_* parameters to all api functions (deprecated echo) [Domen Kozar] +- make --echo parameter a bit more forgivable (better Python API support) [Domen Kozar] - apply patch to refactor cmd line parsing for Issue 54 by Domen Kozar 0.5.3 diff --git a/migrate/versioning/api.py b/migrate/versioning/api.py index c996b89..2651177 100644 --- a/migrate/versioning/api.py +++ b/migrate/versioning/api.py @@ -16,11 +16,10 @@ import sys import inspect import warnings -from sqlalchemy import create_engine - from migrate.versioning import (exceptions, repository, schema, version, script as script_) # command name conflict -from migrate.versioning.util import asbool, catch_known_errors, guess_obj_type +from migrate.versioning.util import catch_known_errors, construct_engine + __all__ = [ 'help', @@ -46,6 +45,7 @@ Repository = repository.Repository ControlledSchema = schema.ControlledSchema VerNum = version.VerNum PythonScript = script_.PythonScript +SqlScript = script_.SqlScript # deprecated @@ -117,7 +117,7 @@ def test(repository, url=None, **opts): bad state. You should therefore better run the test on a copy of your database. """ - engine = _construct_engine(url, **opts) + engine = construct_engine(url, **opts) repos = Repository(repository) script = repos.version(None).script() @@ -179,7 +179,7 @@ def version_control(url, repository, version=None, **opts): identical to what it would be if the database were created from scratch. """ - engine = _construct_engine(url, **opts) + engine = construct_engine(url, **opts) ControlledSchema.create(engine, repository, version) @@ -192,7 +192,7 @@ def db_version(url, repository, **opts): The url should be any valid SQLAlchemy connection string. """ - engine = _construct_engine(url, **opts) + engine = construct_engine(url, **opts) schema = ControlledSchema(engine, repository) return schema.version @@ -236,7 +236,7 @@ def drop_version_control(url, repository, **opts): Removes version control from a database. """ - engine = _construct_engine(url, **opts) + engine = construct_engine(url, **opts) schema = ControlledSchema(engine, repository) schema.drop() @@ -268,7 +268,7 @@ def compare_model_to_db(url, model, repository, **opts): NOTE: This is EXPERIMENTAL. """ # TODO: get rid of EXPERIMENTAL label - engine = _construct_engine(url, **opts) + engine = construct_engine(url, **opts) print ControlledSchema.compare_model_to_db(engine, model, repository) @@ -279,7 +279,7 @@ def create_model(url, repository, **opts): NOTE: This is EXPERIMENTAL. """ # TODO: get rid of EXPERIMENTAL label - engine = _construct_engine(url, **opts) + engine = construct_engine(url, **opts) declarative = opts.get('declarative', False) print ControlledSchema.create_model(engine, repository, declarative) @@ -294,7 +294,7 @@ def make_update_script_for_model(url, oldmodel, model, repository, **opts): NOTE: This is EXPERIMENTAL. """ # TODO: get rid of EXPERIMENTAL label - engine = _construct_engine(url, **opts) + engine = construct_engine(url, **opts) print PythonScript.make_update_script_for_model( engine, oldmodel, model, repository, **opts) @@ -308,30 +308,37 @@ def update_db_from_model(url, model, repository, **opts): NOTE: This is EXPERIMENTAL. """ # TODO: get rid of EXPERIMENTAL label - engine = _construct_engine(url, **opts) + engine = construct_engine(url, **opts) schema = ControlledSchema(engine, repository) schema.update_db_from_model(model) def _migrate(url, repository, version, upgrade, err, **opts): - engine = _construct_engine(url, **opts) + engine = construct_engine(url, **opts) schema = ControlledSchema(engine, repository) version = _migrate_version(schema, version, upgrade, err) changeset = schema.changeset(version) for ver, change in changeset: nextver = ver + changeset.step - print '%s -> %s... ' % (ver, nextver), + print '%s -> %s... ' % (ver, nextver) + if opts.get('preview_sql'): - print - print change.log + if isinstance(change, PythonScript): + print change.preview_sql(url, changeset.step, **opts) + elif isinstance(change, SqlScript): + print change.source() + elif opts.get('preview_py'): source_ver = max(ver, nextver) module = schema.repository.version(source_ver).script().module funcname = upgrade and "upgrade" or "downgrade" func = getattr(module, funcname) - print - print inspect.getsource(module.upgrade) + if isinstance(change, PythonScript): + print inspect.getsource(func) + else: + raise UsageError("Python source can be only displayed" + " for python migration files") else: schema.runchange(ver, change, changeset.step) print 'done' @@ -352,39 +359,3 @@ def _migrate_version(schema, version, upgrade, err): if not direction: raise exceptions.KnownError(err % (cur, version)) return version - - -def _construct_engine(url, **opts): - """Constructs and returns SQLAlchemy engine. - - Currently, there are 2 ways to pass create_engine options to api functions: - - * keyword parameters (starting with `engine_arg_*`) - * python dictionary of options (`engine_dict`) - - NOTE: keyword parameters override `engine_dict` values. - - .. versionadded:: 0.5.4 - """ - # TODO: include docs - - # get options for create_engine - if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict): - kwargs = opts['engine_dict'] - else: - kwargs = dict() - - # DEPRECATED: handle echo the old way - echo = asbool(opts.get('echo', False)) - if echo: - warnings.warn('echo=True parameter is deprecated, pass ' - 'engine_arg_echo=True or engine_dict={"echo": True}', - DeprecationWarning) - kwargs['echo'] = echo - - # parse keyword arguments - for key, value in opts.iteritems(): - if key.startswith('engine_arg_'): - kwargs[key[11:]] = guess_obj_type(value) - - return create_engine(url, **kwargs) diff --git a/migrate/versioning/script/base.py b/migrate/versioning/script/base.py index b34af55..55aadd3 100644 --- a/migrate/versioning/script/base.py +++ b/migrate/versioning/script/base.py @@ -1,6 +1,9 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + from migrate.versioning.base import log,operations from migrate.versioning import pathed,exceptions -# import migrate.run + class BaseScript(pathed.Pathed): """Base class for other types of scripts @@ -17,10 +20,10 @@ class BaseScript(pathed.Pathed): """ def __init__(self,path): - log.info('Loading script %s...'%path) + log.info('Loading script %s...' % path) self.verify(path) - super(BaseScript,self).__init__(path) - log.info('Script %s loaded successfully'%path) + super(BaseScript, self).__init__(path) + log.info('Script %s loaded successfully' % path) @classmethod def verify(cls,path): @@ -33,10 +36,10 @@ class BaseScript(pathed.Pathed): raise exceptions.InvalidScriptError(path) def source(self): - fd=open(self.path) - ret=fd.read() + fd = open(self.path) + ret = fd.read() fd.close() return ret - def run(self,engine): + def run(self, engine): raise NotImplementedError() diff --git a/migrate/versioning/script/py.py b/migrate/versioning/script/py.py index c2e9452..2deeb8d 100644 --- a/migrate/versioning/script/py.py +++ b/migrate/versioning/script/py.py @@ -1,15 +1,20 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + import shutil -#import migrate.run +from StringIO import StringIO + +import migrate from migrate.versioning import exceptions, genmodel, schemadiff from migrate.versioning.base import operations from migrate.versioning.template import template from migrate.versioning.script import base -from migrate.versioning.util import import_path, loadModel -import migrate +from migrate.versioning.util import import_path, loadModel, construct_engine class PythonScript(base.BaseScript): + @classmethod - def create(cls,path,**opts): + def create(cls, path, **opts): """Create an empty migration script""" cls.require_notfound(path) @@ -18,29 +23,38 @@ class PythonScript(base.BaseScript): # different one later. template_file = None src = template.get_script(template_file) - shutil.copy(src,path) + shutil.copy(src, path) @classmethod - def make_update_script_for_model(cls,engine,oldmodel,model,repository,**opts): + def make_update_script_for_model(cls, engine, oldmodel, + model, repository, **opts): """Create a migration script""" # Compute differences. if isinstance(repository, basestring): - from migrate.versioning.repository import Repository # oh dear, an import cycle! - repository=Repository(repository) + # oh dear, an import cycle! + from migrate.versioning.repository import Repository + repository = Repository(repository) oldmodel = loadModel(oldmodel) model = loadModel(model) - diff = schemadiff.getDiffOfModelAgainstModel(oldmodel, model, engine, excludeTables=[repository.version_table]) - decls, upgradeCommands, downgradeCommands = genmodel.ModelGenerator(diff).toUpgradeDowngradePython() + diff = schemadiff.getDiffOfModelAgainstModel( + oldmodel, + model, + engine, + excludeTables=[repository.version_table]) + decls, upgradeCommands, downgradeCommands = \ + genmodel.ModelGenerator(diff).toUpgradeDowngradePython() # Store differences into file. template_file = None src = template.get_script(template_file) contents = open(src).read() search = 'def upgrade():' - contents = contents.replace(search, decls + '\n\n' + search, 1) - if upgradeCommands: contents = contents.replace(' pass', upgradeCommands, 1) - if downgradeCommands: contents = contents.replace(' pass', downgradeCommands, 1) + contents = contents.replace(search, '\n\n'.join((decls, search)), 1) + if upgradeCommands: + contents = contents.replace(' pass', upgradeCommands, 1) + if downgradeCommands: + contents = contents.replace(' pass', downgradeCommands, 1) return contents @classmethod @@ -54,31 +68,31 @@ class PythonScript(base.BaseScript): raise try: assert callable(module.upgrade) - except Exception,e: - raise exceptions.InvalidScriptError(path+': %s'%str(e)) + except Exception, e: + raise exceptions.InvalidScriptError(path + ': %s' % str(e)) return module - def _get_module(self): - if not hasattr(self,'_module'): - self._module = self.verify_module(self.path) - return self._module - module = property(_get_module) + def preview_sql(self, url, step, **args): + """Mock engine to store all executable calls in a string \ + and execute the step""" + buf = StringIO() + args['engine_arg_strategy'] = 'mock' + args['engine_arg_executor'] = lambda s, p='': buf.write(s + p) + engine = construct_engine(url, **args) + self.run(engine, step) - def _func(self,funcname): - fn = getattr(self.module, funcname, None) - if not fn: - msg = "The function %s is not defined in this script" - raise exceptions.ScriptError(msg%funcname) - return fn + return buf.getvalue() - def run(self,engine,step): + def run(self, engine, step): + """Core method of Script file. \ + Exectues update() or downgrade() function""" if step > 0: op = 'upgrade' elif step < 0: op = 'downgrade' else: - raise exceptions.ScriptError("%d is not a valid step"%step) + raise exceptions.ScriptError("%d is not a valid step" % step) funcname = base.operations[op] migrate.migrate_engine = engine @@ -87,3 +101,17 @@ class PythonScript(base.BaseScript): func() migrate.migrate_engine = None #migrate.run.migrate_engine = migrate.migrate_engine = None + + def _get_module(self): + if not hasattr(self,'_module'): + self._module = self.verify_module(self.path) + return self._module + module = property(_get_module) + + + def _func(self, funcname): + fn = getattr(self.module, funcname, None) + if not fn: + msg = "The function %s is not defined in this script" + raise exceptions.ScriptError(msg%funcname) + return fn diff --git a/migrate/versioning/util/__init__.py b/migrate/versioning/util/__init__.py index 530bfd2..c8bb5b3 100644 --- a/migrate/versioning/util/__init__.py +++ b/migrate/versioning/util/__init__.py @@ -5,6 +5,8 @@ import warnings from decorator import decorator from pkg_resources import EntryPoint +from sqlalchemy import create_engine + from migrate.versioning import exceptions from migrate.versioning.util.keyedinstance import KeyedInstance from migrate.versioning.util.importpath import import_path @@ -33,7 +35,10 @@ def asbool(obj): return False else: raise ValueError("String is not true/false: %r" % obj) - return bool(obj) + if obj in (True, False): + return bool(obj) + else: + raise ValueError("String is not true/false: %r" % obj) def guess_obj_type(obj): """Do everything to guess object type from string""" @@ -63,3 +68,38 @@ def catch_known_errors(f, *a, **kw): f(*a, **kw) except exceptions.PathFoundError, e: raise exceptions.KnownError("The path %s already exists" % e.args[0]) + +def construct_engine(url, **opts): + """Constructs and returns SQLAlchemy engine. + + Currently, there are 2 ways to pass create_engine options to api functions: + + * keyword parameters (starting with `engine_arg_*`) + * python dictionary of options (`engine_dict`) + + NOTE: keyword parameters override `engine_dict` values. + + .. versionadded:: 0.5.4 + """ + # TODO: include docs + + # get options for create_engine + if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict): + kwargs = opts['engine_dict'] + else: + kwargs = dict() + + # DEPRECATED: handle echo the old way + echo = asbool(opts.get('echo', False)) + if echo: + warnings.warn('echo=True parameter is deprecated, pass ' + 'engine_arg_echo=True or engine_dict={"echo": True}', + DeprecationWarning) + kwargs['echo'] = echo + + # parse keyword arguments + for key, value in opts.iteritems(): + if key.startswith('engine_arg_'): + kwargs[key[11:]] = guess_obj_type(value) + + return create_engine(url, **kwargs) diff --git a/test/fixture/base.py b/test/fixture/base.py index d82c692..dd81033 100644 --- a/test/fixture/base.py +++ b/test/fixture/base.py @@ -21,7 +21,7 @@ class Base(unittest.TestCase): def createLines(s): s = s.replace(' ', '') lines = s.split('\n') - return [line for line in lines if line] + return filter(None, lines) lines1 = createLines(v1) lines2 = createLines(v2) self.assertEquals(len(lines1), len(lines2)) diff --git a/test/fixture/database.py b/test/fixture/database.py index cc405a3..5a04e94 100644 --- a/test/fixture/database.py +++ b/test/fixture/database.py @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- import os +from decorator import decorator from sqlalchemy import create_engine, Table, MetaData from sqlalchemy.orm import create_session @@ -74,6 +75,7 @@ def usedb(supported=None, not_supported=None): yield func, self self._teardown() entangle.__name__ = func.__name__ + entangle.__doc__ = func.__doc__ return entangle return dec diff --git a/test/fixture/pathed.py b/test/fixture/pathed.py index c5e74cd..2a0e7a6 100644 --- a/test/fixture/pathed.py +++ b/test/fixture/pathed.py @@ -1,49 +1,53 @@ -import os,shutil,tempfile -import base +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import shutil +import tempfile + +from test.fixture import base + class Pathed(base.Base): # Temporary files - #repos='/tmp/test_repos_091x10' - #config=repos+'/migrate.cfg' - #script='/tmp/test_migration_script.py' - _tmpdir=tempfile.mkdtemp() + _tmpdir = tempfile.mkdtemp() @classmethod - def _tmp(cls,prefix='',suffix=''): + def _tmp(cls, prefix='', suffix=''): """Generate a temporary file name that doesn't exist All filenames are generated inside a temporary directory created by tempfile.mkdtemp(); only the creating user has access to this directory. It should be secure to return a nonexistant temp filename in this directory, unless the user is messing with their own files. """ - file,ret = tempfile.mkstemp(suffix,prefix,cls._tmpdir) + file, ret = tempfile.mkstemp(suffix,prefix,cls._tmpdir) os.close(file) os.remove(ret) return ret @classmethod - def tmp(cls,*p,**k): - return cls._tmp(*p,**k) + def tmp(cls, *p, **k): + return cls._tmp(*p, **k) @classmethod - def tmp_py(cls,*p,**k): - return cls._tmp(suffix='.py',*p,**k) + def tmp_py(cls, *p, **k): + return cls._tmp(suffix='.py', *p, **k) @classmethod - def tmp_sql(cls,*p,**k): - return cls._tmp(suffix='.sql',*p,**k) + def tmp_sql(cls, *p, **k): + return cls._tmp(suffix='.sql', *p, **k) @classmethod - def tmp_named(cls,name): - return os.path.join(cls._tmpdir,name) + def tmp_named(cls, name): + return os.path.join(cls._tmpdir, name) @classmethod - def tmp_repos(cls,*p,**k): - return cls._tmp(*p,**k) + def tmp_repos(cls, *p, **k): + return cls._tmp(*p, **k) @classmethod - def purge(cls,path): + def purge(cls, path): """Removes this path if it exists, in preparation for tests Careful - all tests should take place in /tmp. We don't want to accidentally wipe stuff out... @@ -54,6 +58,6 @@ class Pathed(base.Base): else: os.remove(path) if path.endswith('.py'): - pyc = path+'c' + pyc = path + 'c' if os.path.exists(pyc): os.remove(pyc) diff --git a/test/versioning/test_script.py b/test/versioning/test_script.py index eace2fa..ab48a47 100644 --- a/test/versioning/test_script.py +++ b/test/versioning/test_script.py @@ -1,13 +1,19 @@ -from test import fixture +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import shutil + from migrate.versioning.script import * from migrate.versioning import exceptions, version -import os,shutil +from test import fixture + class TestPyScript(fixture.Pathed): cls = PythonScript def test_create(self): """We can create a migration script""" - path=self.tmp_py() + path = self.tmp_py() # Creating a file that doesn't exist should succeed self.cls.create(path) self.assert_(os.path.exists(path)) @@ -18,8 +24,8 @@ class TestPyScript(fixture.Pathed): def test_verify_notfound(self): """Correctly verify a python migration script: nonexistant file""" - path=self.tmp_py() - self.assert_(not os.path.exists(path)) + path = self.tmp_py() + self.assertFalse(os.path.exists(path)) # Fails on empty path self.assertRaises(exceptions.InvalidScriptError,self.cls.verify,path) self.assertRaises(exceptions.InvalidScriptError,self.cls,path) @@ -38,19 +44,52 @@ class TestPyScript(fixture.Pathed): def test_verify_nofuncs(self): """Correctly verify a python migration script: valid python file; no upgrade func""" - path=self.tmp_py() + path = self.tmp_py() # Create empty file - f=open(path,'w') + f = open(path, 'w') f.write("def zergling():\n\tprint 'rush'") f.close() - self.assertRaises(exceptions.InvalidScriptError,self.cls.verify_module,path) + self.assertRaises(exceptions.InvalidScriptError, self.cls.verify_module, path) # script isn't verified on creation, but on module reference py = self.cls(path) self.assertRaises(exceptions.InvalidScriptError,(lambda x: x.module),py) + @fixture.usedb(supported='sqlite') + def test_preview_sql(self): + """Preview SQL abstract from ORM layer (sqlite)""" + path = self.tmp_py() + + f = open(path, 'w') + content = """ +from migrate import * +from sqlalchemy import * + +metadata = MetaData(migrate_engine) + +UserGroup = Table('Link', metadata, + Column('link1ID', Integer), + Column('link2ID', Integer), + UniqueConstraint('link1ID', 'link2ID')) + +def upgrade(): + metadata.create_all() + """ + f.write(content) + f.close() + + pyscript = self.cls(path) + SQL = pyscript.preview_sql(self.url, 1) + self.assertEqualsIgnoreWhitespace(""" + CREATE TABLE "Link" + ("link1ID" INTEGER, + "link2ID" INTEGER, + UNIQUE ("link1ID", "link2ID")) + """, SQL) + # TODO: test: No SQL should be executed! + def test_verify_success(self): """Correctly verify a python migration script: success""" - path=self.tmp_py() + path = self.tmp_py() # Succeeds after creating self.cls.create(path) self.cls.verify(path) @@ -66,8 +105,8 @@ class TestSqlScript(fixture.Pathed): # Create files -- files must be present or you'll get an exception later. sqlite_upgrade_file = '001_sqlite_upgrade.sql' default_upgrade_file = '001_default_upgrade.sql' - for file in [sqlite_upgrade_file, default_upgrade_file]: - filepath = '%s/%s' % (path, file) + for file_ in [sqlite_upgrade_file, default_upgrade_file]: + filepath = '%s/%s' % (path, file_) open(filepath, 'w').close() ver = version.Version(1, path, [sqlite_upgrade_file]) diff --git a/test/versioning/test_shell.py b/test/versioning/test_shell.py index b16dccb..a0941b5 100644 --- a/test/versioning/test_shell.py +++ b/test/versioning/test_shell.py @@ -163,21 +163,21 @@ class TestShellCommands(Shell): """Construct engine the smart way""" url = 'sqlite://' - engine = api._construct_engine(url) + engine = api.construct_engine(url) self.assert_(engine.name == 'sqlite') # keyword arg - engine = api._construct_engine(url, engine_arg_assert_unicode=True) - self.assert_(engine.dialect.assert_unicode) + engine = api.construct_engine(url, engine_arg_assert_unicode=True) + self.assertTrue(engine.dialect.assert_unicode) # dict - engine = api._construct_engine(url, engine_dict={'assert_unicode': True}) - self.assert_(engine.dialect.assert_unicode) + engine = api.construct_engine(url, engine_dict={'assert_unicode': True}) + self.assertTrue(engine.dialect.assert_unicode) # test precedance - engine = api._construct_engine(url, engine_dict={'assert_unicode': False}, + engine = api.construct_engine(url, engine_dict={'assert_unicode': False}, engine_arg_assert_unicode=True) - self.assert_(engine.dialect.assert_unicode) + self.assertTrue(engine.dialect.assert_unicode) def test_manage(self): """Create a project management script""" @@ -327,8 +327,12 @@ class TestShellDatabase(Shell, fixture.DB): # Add a script to the repository; upgrade the db self.assertSuccess(self.cmd('script', '--repository=%s' % repos_path, 'Desc')) self.assertEquals(self.cmd_version(repos_path), 1) - self.assertEquals(self.cmd_db_version(self.url, repos_path), 0) + + # Test preview + self.assertSuccess(self.cmd('upgrade', self.url, repos_path, 0, "--preview_sql")) + self.assertSuccess(self.cmd('upgrade', self.url, repos_path, 0, "--preview_py")) + self.assertSuccess(self.cmd('upgrade', self.url, repos_path)) self.assertEquals(self.cmd_db_version(self.url, repos_path), 1) @@ -345,7 +349,7 @@ class TestShellDatabase(Shell, fixture.DB): self.assertEquals(self.cmd_db_version(self.url, repos_path), 0) self.assertSuccess(self.cmd('drop_version_control', self.url, repos_path)) - + def _run_test_sqlfile(self, upgrade_script, downgrade_script): # TODO: add test script that checks if db really changed