500 lines
22 KiB
Python

import sys
import traceback
from StringIO import StringIO
import os,shutil
from test import fixture
from migrate.versioning.repository import Repository
from migrate.versioning import genmodel, shell
from sqlalchemy import MetaData,Table
# Running interpreter's "major.minor" version string (e.g. "2.5").
# Built from sys.version_info rather than slicing sys.version, which would
# yield "3.1" for Python 3.10 and later.
python_version = '%d.%d' % sys.version_info[:2]
class Shell(fixture.Shell):
    """Shell fixture that fakes migrate's command-line tool in-process.

    Command lines built by ``cmd()`` are executed by calling
    ``migrate.versioning.shell.main`` directly with stdout/stderr captured,
    which avoids spawning a process; anything else is delegated to
    ``fixture.Shell.execute``.
    """
    # Prefix of every command line produced by cmd(); execute() recognises
    # commands it can fake by this prefix and strips it off again.
    _cmd = os.path.join('python migrate', 'versioning', 'shell.py')

    @classmethod
    def cmd(cls, *p):
        """Return a command line: the shell.py prefix followed by ``p``.

        Every argument is converted with ``str()`` and the pieces are joined
        with single spaces (no quoting is performed).
        """
        return ' '.join([cls._cmd] + [str(s) for s in p])

    @staticmethod
    def _fake_exit_status(fd, status):
        """Patch ``fd.close`` so it closes ``fd`` and then returns ``status``.

        This mimics a pipe whose close() reports the exit status of a real
        subprocess (presumably what assertSuccess/assertFailure inspect --
        NOTE(review): confirm against fixture.Shell).
        """
        fd_close = fd.close
        def close_():
            fd_close()
            return status
        fd.close = close_

    def execute(self, shell_cmd, runshell=None):
        """A crude simulation of a shell command, to speed things up.

        :param shell_cmd: a command line (normally built by ``cmd()``), or an
            already-produced ``file``/``StringIO`` result, which is returned
            unchanged.
        :param runshell: if true, always delegate to the real shell.
        :returns: a file-like object rewound to the start of the captured
            output; its ``close()`` returns the simulated exit status.
        """
        # If we get an fd, the command is already done
        if isinstance(shell_cmd, (file, StringIO)):
            return shell_cmd
        # Only command lines built by cmd() can be faked; anything else, or
        # anything explicitly forced into the shell, goes to the real shell.
        if (runshell or not isinstance(shell_cmd, basestring)
                or not shell_cmd.startswith(self._cmd)):
            return super(Shell, self).execute(shell_cmd)
        # Strip the prefix and the space after it. Splitting on single spaces
        # mirrors how cmd() joins its arguments (without quoting).
        params = shell_cmd[len(self._cmd) + 1:].split(' ')

        # Redirect stdout to an object; redirect stderr to stdout
        fd = StringIO()
        orig_stdout = sys.stdout
        orig_stderr = sys.stderr
        sys.stdout = fd
        sys.stderr = fd
        # Execute this command
        try:
            try:
                shell.main(params)
            except SystemExit:
                # Simulate the exit status. Use .code rather than .args[0]:
                # a bare sys.exit() has empty args (IndexError on close())
                # while its .code is None, i.e. success.
                self._fake_exit_status(fd, sys.exc_info()[1].code)
            except Exception:
                # Print the exception (into fd, since stderr is redirected),
                # but don't re-raise it; simulate a nonzero exit status.
                traceback.print_exc()
                self._fake_exit_status(fd, 2)
        finally:
            # Clean up
            sys.stdout = orig_stdout
            sys.stderr = orig_stderr
        fd.seek(0)
        return fd

    def cmd_version(self, repos_path):
        """Return the repository's version, via the ``version`` command, as an int."""
        fd = self.execute(self.cmd('version', repos_path))
        ret = int(fd.read().strip())
        self.assertSuccess(fd)
        return ret

    def cmd_db_version(self, url, repos_path):
        """Return the database's version, via the ``db_version`` command, as an int."""
        fd = self.execute(self.cmd('db_version', url, repos_path))
        ret = int(fd.read().strip())
        self.assertSuccess(fd)
        return ret
class TestShellCommands(Shell):
    """Tests migrate.py commands"""

    def test_run(self):
        """Runs; displays help"""
        # Force this to run in shell...
        self.assertSuccess(self.cmd('-h'), runshell=True)
        self.assertSuccess(self.cmd('--help'), runshell=True)

    def test_help(self):
        """Display help on a specific command"""
        self.assertSuccess(self.cmd('-h'), runshell=True)
        self.assertSuccess(self.cmd('--help'), runshell=True)
        for command in shell.api.__all__:
            fd = self.execute(self.cmd('help', command))
            # Description may change, so best we can do is ensure it shows up
            output = fd.read()
            self.assertNotEquals(output, '')
            self.assertSuccess(fd)

    def test_create(self):
        """Repositories are created successfully"""
        repos = self.tmp_repos()
        name = 'name'
        # Creating a file that doesn't exist should succeed
        create_cmd = self.cmd('create', repos, name)
        self.assertSuccess(create_cmd)
        # Files should actually be created
        self.assert_(os.path.exists(repos))
        # The default table should not be None
        repos_ = Repository(repos)
        self.assertNotEquals(repos_.config.get('db_settings', 'version_table'), 'None')
        # Can't create it again: it already exists
        self.assertFailure(create_cmd)

    def test_script(self):
        """We can create a migration script via the command line"""
        repos = self.tmp_repos()
        self.assertSuccess(self.cmd('create', repos, 'repository_name'))
        self.assertSuccess(self.cmd('script', '--repository=%s' % repos, 'Desc'))
        self.assert_(os.path.exists('%s/versions/001_Desc.py' % repos))
        # A second script is numbered after the first one
        self.assertSuccess(self.cmd('script', '--repository=%s' % repos, 'More'))
        self.assert_(os.path.exists('%s/versions/002_More.py' % repos))

    def test_script_sql(self):
        """We can create a migration sql script via the command line"""
        repos = self.tmp_repos()
        self.assertSuccess(self.cmd('create', repos, 'repository_name'))
        self.assertSuccess(self.cmd('script_sql', '--repository=%s' % repos, 'mydb'))
        self.assert_(os.path.exists('%s/versions/001_mydb_upgrade.sql' % repos))
        self.assert_(os.path.exists('%s/versions/001_mydb_downgrade.sql' % repos))
        # Can't create it again: it already exists
        self.assertFailure(self.cmd('script_sql', '--repository=%s' % repos, 'mydb'))

    def test_manage(self):
        """Create a project management script"""
        script = self.tmp_py()
        self.assert_(not os.path.exists(script))
        # No attempt is made to verify correctness of the repository path here
        self.assertSuccess(self.cmd('manage', script, '--repository=/path/to/repository'))
        self.assert_(os.path.exists(script))
class TestShellRepository(Shell):
    """Shell commands on an existing repository/python script"""

    def setUp(self):
        """Create an empty repository for each test (no change scripts yet)."""
        self.path_repos = self.tmp_repos()
        self.assertSuccess(self.cmd('create', self.path_repos, 'repository_name'))

    @staticmethod
    def _read_file(path):
        """Return the contents of the file at ``path``, closing it afterwards."""
        f = open(path)
        try:
            return f.read()
        finally:
            f.close()

    def test_version(self):
        """Correctly detect repository version"""
        # Version: 0 (no scripts yet); successful execution
        fd = self.execute(self.cmd('version', '--repository=%s' % self.path_repos))
        self.assertEquals(fd.read().strip(), "0")
        self.assertSuccess(fd)
        # Also works as a positional param
        fd = self.execute(self.cmd('version', self.path_repos))
        self.assertEquals(fd.read().strip(), "0")
        self.assertSuccess(fd)
        # Create a script and version should increment
        self.assertSuccess(self.cmd('script', '--repository=%s' % self.path_repos, 'Desc'))
        fd = self.execute(self.cmd('version', self.path_repos))
        self.assertEquals(fd.read().strip(), "1")
        self.assertSuccess(fd)

    def test_source(self):
        """Correctly fetch a script's source"""
        self.assertSuccess(self.cmd('script', '--repository=%s' % self.path_repos, 'Desc'))
        filename = '%s/versions/001_Desc.py' % self.path_repos
        source = self._read_file(filename)
        self.assert_('def upgrade' in source)
        # Version is now 1
        fd = self.execute(self.cmd('version', self.path_repos))
        self.assertEquals(fd.read().strip(), "1")
        self.assertSuccess(fd)
        # Output/verify the source of version 1
        fd = self.execute(self.cmd('source', 1, '--repository=%s' % self.path_repos))
        result = fd.read()
        self.assertSuccess(fd)
        self.assertEquals(result.strip(), source.strip())
        # We can also send the source to a file. Use a fresh path: writing to
        # the version script itself would overwrite the repository's script
        # and make the existence check below trivially true.
        dest = self.tmp_py()
        self.assert_(not os.path.exists(dest))
        self.assertSuccess(self.cmd('source', 1, dest, '--repository=%s' % self.path_repos))
        self.assert_(os.path.exists(dest))
        self.assertEquals(self._read_file(dest).strip(), source.strip())
class TestShellDatabase(Shell,fixture.DB):
"""Commands associated with a particular database"""
# We'll need to clean up after ourselves: the shell creates its own transaction,
# so we must connect to the DB ourselves to see whether things worked.
level=fixture.DB.CONNECT
@fixture.usedb()
def test_version_control(self):
    """Ensure we can set version control on a database"""
    path_repos = self.tmp_repos()
    self.assertSuccess(self.cmd('create', path_repos, 'repository_name'))
    # Exit status deliberately not asserted -- presumably this just clears
    # version control left over from an earlier run.
    self.exitcode(self.cmd('drop_version_control', self.url, path_repos))
    self.assertSuccess(self.cmd('version_control', self.url, path_repos))
    # Clean up
    self.assertSuccess(self.cmd('drop_version_control', self.url, path_repos))
    # Attempting to drop vc from a database without it should fail
    self.assertFailure(self.cmd('drop_version_control', self.url, path_repos))
@fixture.usedb()
def test_version_control_specified(self):
    """Ensure we can set version control to a particular version"""
    path_repos = self.tmp_repos()
    self.assertSuccess(self.cmd('create', path_repos, 'repository_name'))
    # Exit status deliberately not asserted -- presumably this just clears
    # version control left over from an earlier run.
    self.exitcode(self.cmd('drop_version_control', self.url, path_repos))
    # Fill the repository with `version` change scripts
    version = 1
    for _ in range(version):
        self.assertSuccess(self.cmd('script', '--repository=%s' % path_repos, 'Desc'))
    # Repository version is correct
    fd = self.execute(self.cmd('version', path_repos))
    self.assertEquals(fd.read().strip(), str(version))
    self.assertSuccess(fd)
    # Apply versioning to DB at that explicit version
    self.assertSuccess(self.cmd('version_control', self.url, path_repos, version))
    # Test version number
    fd = self.execute(self.cmd('db_version', self.url, path_repos))
    self.assertEquals(fd.read().strip(), str(version))
    self.assertSuccess(fd)
    # Clean up
    self.assertSuccess(self.cmd('drop_version_control', self.url, path_repos))
@fixture.usedb()
def test_upgrade(self):
    """Upgrade and downgrade a version-controlled database via the shell."""
    # Start from a freshly created, script-less repository (version 0)
    repos = self.tmp()
    self.assertSuccess(self.cmd('create', repos, 'repos_name'))
    self.assertEquals(self.cmd_version(repos), 0)

    def check_db_version(expected):
        # Compare the database's reported version with `expected`
        self.assertEquals(self.cmd_db_version(self.url, repos), expected)

    def step(direction, *version):
        # Build an 'upgrade'/'downgrade' command; the target version is optional
        return self.cmd(direction, self.url, repos, *version)

    # Version the DB
    self.exitcode(self.cmd('drop_version_control', self.url, repos))
    self.assertSuccess(self.cmd('version_control', self.url, repos))

    # Latest version is 0: upgrading stays at 0, targets beyond it fail
    check_db_version(0)
    self.assertSuccess(step('upgrade'))
    check_db_version(0)
    self.assertSuccess(step('upgrade', 0))
    check_db_version(0)
    for bad_target in (1, -1):
        self.assertFailure(step('upgrade', bad_target))

    # One script in the repository: a plain upgrade moves the DB to 1
    self.assertSuccess(self.cmd('script', '--repository=%s' % repos, 'Desc'))
    self.assertEquals(self.cmd_version(repos), 1)
    check_db_version(0)
    self.assertSuccess(step('upgrade'))
    check_db_version(1)

    # Downgrading requires a valid target: missing, too high or negative fail
    for bad_args in ((), (2,), (-1,)):
        self.assertFailure(step('downgrade', *bad_args))
    check_db_version(1)
    self.assertSuccess(step('downgrade', 0))
    check_db_version(0)
    self.assertFailure(step('downgrade', 1))
    check_db_version(0)

    self.assertSuccess(self.cmd('drop_version_control', self.url, repos))
def _run_test_sqlfile(self, upgrade_script, downgrade_script):
    """Apply and revert a .sql migration through the shell.

    Creates a version-controlled repository and adds a ``script_sql`` version
    for 'postgres'. Appends the given SQL to the generated
    ``001_postgres_upgrade.sql`` / ``001_postgres_downgrade.sql`` files.
    Then checks that ``t_table`` cannot be queried before the upgrade, can be
    queried after it, and cannot be queried again after downgrading to 0.

    :param upgrade_script: SQL text expected to create ``t_table``.
    :param downgrade_script: SQL text expected to drop ``t_table``.
    """
    repos_path = self.tmp()
    repos_name = 'repos'
    self.assertSuccess(self.cmd('create', repos_path, repos_name))
    self.exitcode(self.cmd('drop_version_control', self.url, repos_path))
    self.assertSuccess(self.cmd('version_control', self.url, repos_path))
    self.assertEquals(self.cmd_version(repos_path), 0)
    self.assertEquals(self.cmd_db_version(self.url, repos_path), 0)

    versions_dir = os.path.join(repos_path, 'versions')
    # Count existing entries instead of assuming an empty directory; this
    # number has been seen to vary (e.g. when running from an svn checkout).
    before_count = len(os.listdir(versions_dir))
    self.assertSuccess(self.cmd('script_sql', '--repository=%s' % repos_path, 'postgres'))
    self.assertEquals(self.cmd_version(repos_path), 1)
    # script_sql must add exactly two files: the upgrade and downgrade scripts
    self.assertEquals(len(os.listdir(versions_dir)), before_count + 2)

    # Append the caller's SQL. Close each file explicitly rather than relying
    # on garbage collection to flush it before the shell reads it.
    for filename, text in (('001_postgres_upgrade.sql', upgrade_script),
                           ('001_postgres_downgrade.sql', downgrade_script)):
        sql_file = open(os.path.join(versions_dir, filename), 'a')
        try:
            sql_file.write(text)
        finally:
            sql_file.close()

    self.assertEquals(self.cmd_db_version(self.url, repos_path), 0)
    self.assertRaises(Exception, self.engine.text('select * from t_table').execute)
    self.assertSuccess(self.cmd('upgrade', self.url, repos_path))
    self.assertEquals(self.cmd_db_version(self.url, repos_path), 1)
    self.engine.text('select * from t_table').execute()
    self.assertSuccess(self.cmd('downgrade', self.url, repos_path, 0))
    self.assertEquals(self.cmd_db_version(self.url, repos_path), 0)
    self.assertRaises(Exception, self.engine.text('select * from t_table').execute)
# The tests below use some PostgreSQL syntax, but the feature under test
# (.sql migration files) ought to work with any database.
@fixture.usedb(supported='postgres')
def test_sqlfile(self):
    """Run the shared .sql migration check with scripts that create/drop t_table."""
    create_sql = ("\n"
                  "create table t_table (\n"
                  "id serial,\n"
                  "primary key(id)\n"
                  ");\n")
    drop_sql = "\ndrop table t_table;\n"
    self._run_test_sqlfile(upgrade_script=create_sql,
                           downgrade_script=drop_sql)
@fixture.usedb(supported='postgres')
def test_sqlfile_comment(self):
    """Run the shared .sql migration check with scripts that open with an SQL comment."""
    comment_line = "-- Comments in SQL break postgres autocommit\n"
    create_sql = ("\n" + comment_line +
                  "create table t_table (\n"
                  "id serial,\n"
                  "primary key(id)\n"
                  ");\n")
    drop_sql = "\n" + comment_line + "drop table t_table;\n"
    self._run_test_sqlfile(upgrade_script=create_sql,
                           downgrade_script=drop_sql)
@fixture.usedb()
def test_test(self):
    """Exercise the 'test' command against a one-script repository.

    After every 'test' run the repository must still be at version 1 and
    the database at version 0.
    """
    repos_name = 'repos_name'
    repos_path = self.tmp()
    self.assertSuccess(self.cmd('create', repos_path, repos_name))
    self.exitcode(self.cmd('drop_version_control', self.url, repos_path))
    self.assertSuccess(self.cmd('version_control', self.url, repos_path))
    self.assertEquals(self.cmd_version(repos_path), 0)
    self.assertEquals(self.cmd_db_version(self.url, repos_path), 0)

    def write_script(lines):
        # Write `lines` as a Python module to a fresh temp path. The file is
        # closed even if the write fails, and the builtin `file` (used by
        # Shell.execute) is not shadowed.
        path = self.tmp_py()
        script_file = open(path, 'w')
        try:
            script_file.write('\n'.join(lines) + '\n')
        finally:
            script_file.close()
        return path

    def assert_versions_unchanged():
        self.assertEquals(self.cmd_version(repos_path), 1)
        self.assertEquals(self.cmd_db_version(self.url, repos_path), 0)

    # Empty script should succeed
    self.assertSuccess(self.cmd('script', '--repository=%s' % repos_path, 'Desc'))
    self.assertSuccess(self.cmd('test', repos_path, self.url))
    assert_versions_unchanged()

    # Error script should fail
    # NOTE(review): neither script written below is ever passed to the 'test'
    # command (its only arguments are repos_path and the URL), and
    # Shell.execute splits 'blah blah' into two extra arguments. So this
    # failure may come from argument handling rather than from the raising
    # script. Consider writing these scripts over versions/001_Desc.py
    # instead -- TODO confirm intended behavior.
    write_script([
        '',
        'from sqlalchemy import *',
        'from migrate import *',
        'def upgrade():',
        "    print 'fgsfds'",
        '    raise Exception()',
        'def downgrade():',
        "    print 'sdfsgf'",
        '    raise Exception()',
    ])
    self.assertFailure(self.cmd('test', repos_path, self.url, 'blah blah'))
    assert_versions_unchanged()

    # Nonempty script using migrate_engine should succeed
    write_script([
        '',
        'from sqlalchemy import *',
        'from migrate import *',
        'meta = MetaData(migrate_engine)',
        "account = Table('account',meta,",
        "    Column('id',Integer,primary_key=True),",
        "    Column('login',String(40)),",
        "    Column('passwd',String(40)),",
        ')',
        'def upgrade():',
        "    # Upgrade operations go here. Don't create your own engine; use the engine",
        "    # named 'migrate_engine' imported from migrate.",
        '    meta.create_all()',
        'def downgrade():',
        '    # Operations to reverse the above upgrade go here.',
        '    meta.drop_all()',
    ])
    self.assertSuccess(self.cmd('test', repos_path, self.url))
    assert_versions_unchanged()
@fixture.usedb()
def test_rundiffs_in_shell(self):
# This is a variant of the test_schemadiff tests but run through the shell level.
# These shell tests are hard to debug (since they keep forking processes), so they shouldn't replace the lower-level tests.
repos_name = 'repos_name'
repos_path = self.tmp()
script_path = self.tmp_py()
old_model_path = self.tmp_named('oldtestmodel.py')
model_path = self.tmp_named('testmodel.py')
# Create empty repository.
self.meta = MetaData(self.engine, reflect=True)
self.meta.drop_all() # in case junk tables are lying around in the test database
self.assertSuccess(self.cmd('create',repos_path,repos_name))
self.exitcode(self.cmd('drop_version_control',self.url,repos_path))
self.assertSuccess(self.cmd('version_control',self.url,repos_path))
self.assertEquals(self.cmd_version(repos_path),0)
self.assertEquals(self.cmd_db_version(self.url,repos_path),0)
# Setup helper script.
model_module = 'testmodel.meta'
self.assertSuccess(self.cmd('manage',script_path,'--repository=%s --url=%s --model=%s' % (repos_path, self.url, model_module)))
self.assert_(os.path.exists(script_path))
# Write old and new model to disk - old model is empty!
script_preamble="""
from sqlalchemy import *
meta = MetaData()
""".replace("\n ","\n")
script_text="""
""".replace("\n ","\n")
open(old_model_path, 'w').write(script_preamble + script_text)
script_text="""
tmp_account_rundiffs = Table('tmp_account_rundiffs',meta,
Column('id',Integer,primary_key=True),
Column('login',String(40)),
Column('passwd',String(40)),
)
""".replace("\n ","\n")
open(model_path, 'w').write(script_preamble + script_text)
# Model is defined but database is empty.
output, exitcode = self.output_and_exitcode('python %s compare_model_to_db' % script_path)
self.assertEquals(output, "Schema diffs:\n tables missing in database: tmp_account_rundiffs")
# Update db to latest model.
output, exitcode = self.output_and_exitcode('python %s update_db_from_model' % script_path)
self.assertEquals(output, "")
self.assertEquals(self.cmd_version(repos_path),0)
self.assertEquals(self.cmd_db_version(self.url,repos_path),0) # version did not get bumped yet because new version not yet created
output, exitcode = self.output_and_exitcode('python %s compare_model_to_db' % script_path)
self.assertEquals(output, "No schema diffs")
output, exitcode = self.output_and_exitcode('python %s create_model' % script_path)
output = output.replace(genmodel.HEADER.strip(), '') # need strip b/c output_and_exitcode called strip
self.assertEqualsIgnoreWhitespace(output, """
tmp_account_rundiffs = Table('tmp_account_rundiffs',meta,
Column('id',Integer(),primary_key=True,nullable=False),
Column('login',String(length=None,convert_unicode=False,assert_unicode=None)),
Column('passwd',String(length=None,convert_unicode=False,assert_unicode=None)),
)
""") # TODO: length shouldn't be None above
# We're happy with db changes, make first db upgrade script to go from version 0 -> 1.
output, exitcode = self.output_and_exitcode('python %s make_update_script_for_model' % script_path) # intentionally omit a parameter
self.assertEquals('Error: Too few arguments' in output, True)
output, exitcode = self.output_and_exitcode('python %s make_update_script_for_model --oldmodel=oldtestmodel.meta' % script_path)
self.assertEqualsIgnoreWhitespace(output, """
from sqlalchemy import *
from migrate import *
meta = MetaData(migrate_engine)
tmp_account_rundiffs = Table('tmp_account_rundiffs', meta,
Column('id', Integer() , primary_key=True, nullable=False),
Column('login', String(length=40, convert_unicode=False, assert_unicode=None) ),
Column('passwd', String(length=40, convert_unicode=False, assert_unicode=None) ),
)
def upgrade():
# Upgrade operations go here. Don't create your own engine; use the engine
# named 'migrate_engine' imported from migrate.
tmp_account_rundiffs.create()
def downgrade():
# Operations to reverse the above upgrade go here.
tmp_account_rundiffs.drop()
""")
# Save the upgrade script.
self.assertSuccess(self.cmd('script', '--repository=%s' % repos_path, 'Desc'))
upgrade_script_path = '%s/versions/001_Desc.py' % repos_path
open(upgrade_script_path, 'w').write(output)
#output, exitcode = self.output_and_exitcode('python %s test %s' % (script_path, upgrade_script_path)) # no, we already upgraded the db above
#self.assertEquals(output, "")
output, exitcode = self.output_and_exitcode('python %s update_db_from_model' % script_path) # bump the db_version
self.assertEquals(output, "")
self.assertEquals(self.cmd_version(repos_path),1)
self.assertEquals(self.cmd_db_version(self.url,repos_path),1)