# deb-alembic/alembic/migration.py
#
# 693 lines
# 24 KiB
# Python

import logging
import sys
from contextlib import contextmanager
from collections import namedtuple
from sqlalchemy import MetaData, Table, Column, String, literal_column
from sqlalchemy import create_engine
from sqlalchemy.engine import url as sqla_url
from .compat import callable, EncodedIO, string_types
from . import ddl, util
from .revision import tuple_rev_as_scalar
log = logging.getLogger(__name__)
class MigrationContext(object):
    """Represent the database state made available to a migration
    script.

    :class:`.MigrationContext` is the front end to an actual
    database connection, or alternatively a string output
    stream given a particular database dialect,
    from an Alembic perspective.

    When inside the ``env.py`` script, the :class:`.MigrationContext`
    is available via the
    :meth:`.EnvironmentContext.get_context` method,
    which is available at ``alembic.context``::

        # from within env.py script
        from alembic import context
        migration_context = context.get_context()

    For usage outside of an ``env.py`` script, such as for
    utility routines that want to check the current version
    in the database, use the :meth:`.MigrationContext.configure`
    method to create new :class:`.MigrationContext` objects.
    For example, to get at the current revision in the
    database using :meth:`.MigrationContext.get_current_revision`::

        # in any application, outside of an env.py script
        from alembic.migration import MigrationContext
        from sqlalchemy import create_engine

        engine = create_engine("postgresql://mydatabase")
        conn = engine.connect()

        context = MigrationContext.configure(conn)
        current_rev = context.get_current_revision()

    The above context can also be used to produce
    Alembic migration operations with an :class:`.Operations`
    instance::

        # in any application, outside of the normal Alembic environment
        from alembic.operations import Operations
        op = Operations(context)
        op.alter_column("mytable", "somecolumn", nullable=True)

    """

    def __init__(self, dialect, connection, opts, environment_context=None):
        """Build a context from a dialect, optional connection and options.

        :param dialect: the SQLAlchemy dialect in use.
        :param connection: a live connection for "online" mode; ignored
         in favor of a mock engine when ``opts['as_sql']`` is true.
        :param opts: option dictionary (``as_sql``, ``fn``,
         ``output_buffer``, ``output_encoding``, ``version_table``,
         ``version_table_schema``, ``starting_rev``, ``compare_type``,
         ``compare_server_default``, ``transactional_ddl``,
         ``transaction_per_migration``, ``script``, ...).
        :param environment_context: the owning
         :class:`.EnvironmentContext`, if any.
        """
        self.environment_context = environment_context
        self.opts = opts
        self.dialect = dialect
        self.script = opts.get('script')

        as_sql = opts.get('as_sql', False)
        transactional_ddl = opts.get("transactional_ddl")

        self._transaction_per_migration = opts.get(
            "transaction_per_migration", False)

        if as_sql:
            # Offline mode: route SQL through a mock engine. Its executor
            # looks up self.impl lazily, so assigning impl further below
            # is fine.
            self.connection = self._stdout_connection(connection)
            assert self.connection is not None
        else:
            self.connection = connection
        self._migrations_fn = opts.get('fn')
        self.as_sql = as_sql

        if "output_encoding" in opts:
            self.output_buffer = EncodedIO(
                opts.get("output_buffer") or sys.stdout,
                opts['output_encoding']
            )
        else:
            self.output_buffer = opts.get("output_buffer", sys.stdout)

        self._user_compare_type = opts.get('compare_type', False)
        self._user_compare_server_default = opts.get(
            'compare_server_default',
            False)
        self.version_table = version_table = opts.get(
            'version_table', 'alembic_version')
        self.version_table_schema = version_table_schema = \
            opts.get('version_table_schema', None)
        # Table object describing the version table; one String(32)
        # column holds each head revision id.
        self._version = Table(
            version_table, MetaData(),
            Column('version_num', String(32), nullable=False),
            schema=version_table_schema)

        self._start_from_rev = opts.get("starting_rev")
        self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
            dialect, self.connection, self.as_sql,
            transactional_ddl,
            self.output_buffer,
            opts
        )
        log.info("Context impl %s.", self.impl.__class__.__name__)
        if self.as_sql:
            log.info("Generating static SQL")
        log.info("Will assume %s DDL.",
                 "transactional" if self.impl.transactional_ddl
                 else "non-transactional")

    @classmethod
    def configure(cls,
                  connection=None,
                  url=None,
                  dialect_name=None,
                  environment_context=None,
                  opts=None,
                  ):
        """Create a new :class:`.MigrationContext`.

        This is a factory method usually called
        by :meth:`.EnvironmentContext.configure`.

        :param connection: a :class:`~sqlalchemy.engine.Connection`
         to use for SQL execution in "online" mode.  When present,
         is also used to determine the type of dialect in use.

        :param url: a string database url, or a
         :class:`sqlalchemy.engine.url.URL` object.
         The type of dialect to be used will be derived from this if
         ``connection`` is not passed.

        :param dialect_name: string name of a dialect, such as
         "postgresql", "mssql", etc.  The type of dialect to be used will be
         derived from this if ``connection`` and ``url`` are not passed.

        :param environment_context: the owning
         :class:`.EnvironmentContext`, if any.

        :param opts: dictionary of options.  Most other options
         accepted by :meth:`.EnvironmentContext.configure` are passed via
         this dictionary.

        :return: an instance of ``cls`` (so subclasses get themselves).

        :raises Exception: if none of ``connection``, ``url`` or
         ``dialect_name`` is given.
        """
        if opts is None:
            opts = {}

        if connection:
            dialect = connection.dialect
        elif url:
            url = sqla_url.make_url(url)
            dialect = url.get_dialect()()
        elif dialect_name:
            url = sqla_url.make_url("%s://" % dialect_name)
            dialect = url.get_dialect()()
        else:
            raise Exception("Connection, url, or dialect_name is required.")

        # Instantiate ``cls`` rather than hard-coding MigrationContext so
        # that a subclass calling configure() receives its own type.
        return cls(dialect, connection, opts, environment_context)

    def begin_transaction(self, _per_migration=False):
        """Return a context manager delimiting a transaction, if one applies.

        ``_per_migration`` is True when wrapping a single migration step.
        A transaction is only begun when it equals the
        ``transaction_per_migration`` option and the impl reports
        transactional DDL; otherwise a no-op context manager is returned.

        Offline (``as_sql``), ``impl.emit_begin()`` / ``impl.emit_commit()``
        bracket the block; online, ``self.bind.begin()`` is returned.
        """
        transaction_now = _per_migration == self._transaction_per_migration

        # ``or`` short-circuits, so transactional_ddl is only consulted
        # when a transaction would otherwise be started.
        if not transaction_now or not self.impl.transactional_ddl:
            @contextmanager
            def do_nothing():
                yield
            return do_nothing()
        elif self.as_sql:
            @contextmanager
            def begin_commit():
                self.impl.emit_begin()
                yield
                # Deliberately not in a ``finally``: if the body raises,
                # no COMMIT is emitted.
                self.impl.emit_commit()
            return begin_commit()
        else:
            return self.bind.begin()

    def get_current_revision(self):
        """Return the current revision, usually that which is present
        in the ``alembic_version`` table in the database.

        This method intends to be used only for a migration stream that
        does not contain unmerged branches in the target database;
        if there are multiple branches present, an exception is raised.
        The :meth:`.MigrationContext.get_current_heads` should be preferred
        over this method going forward in order to be compatible with
        branch migration support.

        If this :class:`.MigrationContext` was configured in "offline"
        mode, that is with ``as_sql=True``, the ``starting_rev``
        parameter is returned instead, if any.

        :raises util.CommandError: if more than one head is present.
        """
        heads = self.get_current_heads()
        if not heads:
            return None
        elif len(heads) > 1:
            raise util.CommandError(
                "Version table '%s' has more than one head present; "
                "please use get_current_heads()" % self.version_table)
        else:
            return heads[0]

    def get_current_heads(self):
        """Return a tuple of the current 'head versions' that are represented
        in the target database.

        For a migration stream without branches, this will be a single
        value, synonymous with that of
        :meth:`.MigrationContext.get_current_revision`.   However when multiple
        unmerged branches exist within the target database, the returned tuple
        will contain a value for each head.

        If this :class:`.MigrationContext` was configured in "offline"
        mode, that is with ``as_sql=True``, the ``starting_rev``
        parameter is returned in a one-length tuple.

        If no version table is present, or if there are no revisions
        present, an empty tuple is returned.

        .. versionadded:: 0.7.0

        :raises util.CommandError: online, if ``starting_rev`` was given.
        """
        if self.as_sql:
            return util.to_tuple(self._start_from_rev, default=())
        else:
            if self._start_from_rev:
                raise util.CommandError(
                    "Can't specify current_rev to context "
                    "when using a database connection")
            if not self._has_version_table():
                return ()
        return tuple(
            row[0] for row in self.connection.execute(self._version.select())
        )

    def _ensure_version_table(self):
        """Create the version table unless it already exists."""
        self._version.create(self.connection, checkfirst=True)

    def _has_version_table(self):
        """Return whether the dialect reports the version table as present."""
        return self.connection.dialect.has_table(
            self.connection, self.version_table, self.version_table_schema)

    def stamp(self, script_directory, revision):
        """Stamp the version table with a specific revision.

        This method calculates those branches to which the given revision
        can apply, and updates those branches as though they were migrated
        towards that revision (either up or down).  If no current branches
        include the revision, it is added as a new branch head.

        .. versionadded:: 0.7.0

        """
        heads = self.get_current_heads()
        head_maintainer = HeadMaintainer(self, heads)
        for step in script_directory._steps_revs(revision, heads):
            head_maintainer.update_to_step(step)

    def run_migrations(self, **kw):
        r"""Run the migration scripts established for this
        :class:`.MigrationContext`, if any.

        The commands in :mod:`alembic.command` will set up a function
        that is ultimately passed to the :class:`.MigrationContext`
        as the ``fn`` argument.  This function represents the "work"
        that will be done when :meth:`.MigrationContext.run_migrations`
        is called, typically from within the ``env.py`` script of the
        migration environment.  The "work function" then provides an iterable
        of version callables and other version information which
        in the case of the ``upgrade`` or ``downgrade`` commands are the
        list of version scripts to invoke.  Other commands yield nothing,
        in the case that a command wants to run some other operation
        against the database such as the ``current`` or ``stamp`` commands.

        :param \**kw: keyword arguments here will be passed to each
         migration callable, that is the ``upgrade()`` or ``downgrade()``
         method within revision scripts.

        """
        self.impl.start_migrations()

        heads = self.get_current_heads()
        if not self.as_sql and not heads:
            self._ensure_version_table()

        head_maintainer = HeadMaintainer(self, heads)

        for step in self._migrations_fn(heads, self):
            with self.begin_transaction(_per_migration=True):
                if self.as_sql and not head_maintainer.heads:
                    # for offline mode, include a CREATE TABLE from
                    # the base
                    self._version.create(self.connection)
                log.info("Running %s", step)
                if self.as_sql:
                    self.impl.static_output("-- Running %s" % (step.short_log,))
                step.migration_fn(**kw)

                # previously, we wouldn't stamp per migration
                # if we were in a transaction, however given the more
                # complex model that involves any number of inserts
                # and row-targeted updates and deletes, it's simpler for now
                # just to run the operations on every version
                head_maintainer.update_to_step(step)

        # Offline: if no heads remain (e.g. downgraded to base), emit a
        # DROP for the version table.
        if self.as_sql and not head_maintainer.heads:
            self._version.drop(self.connection)

    def execute(self, sql, execution_options=None):
        """Execute a SQL construct or string statement.

        The underlying execution mechanics are used, that is
        if this is "offline mode" the SQL is written to the
        output buffer, otherwise the SQL is emitted on
        the current SQLAlchemy connection.

        """
        self.impl._exec(sql, execution_options)

    def _stdout_connection(self, connection):
        """Return a SQLAlchemy "mock" engine for offline mode.

        Every construct executed against it is forwarded to
        ``self.impl._exec``; the ``connection`` argument is unused.
        """
        def dump(construct, *multiparams, **params):
            self.impl._exec(construct)

        return create_engine("%s://" % self.dialect.name,
                             strategy="mock", executor=dump)

    @property
    def bind(self):
        """Return the current "bind".

        In online mode, this is an instance of
        :class:`sqlalchemy.engine.Connection`, and is suitable
        for ad-hoc execution of any kind of usage described
        in :ref:`sqlexpression_toplevel` as well as
        for usage with the :meth:`sqlalchemy.schema.Table.create`
        and :meth:`sqlalchemy.schema.MetaData.create_all` methods
        of :class:`~sqlalchemy.schema.Table`,
        :class:`~sqlalchemy.schema.MetaData`.

        Note that when "standard output" mode is enabled,
        this bind will be a "mock" connection handler that cannot
        return results and is only appropriate for a very limited
        subset of commands.

        """
        return self.connection

    @property
    def config(self):
        """Return the :class:`.Config` used by the current environment, if any.

        .. versionadded:: 0.6.6

        """
        if self.environment_context:
            return self.environment_context.config
        else:
            return None

    def _compare_type(self, inspector_column, metadata_column):
        """Decide whether two column types differ.

        ``compare_type=False`` disables the check. A callable is consulted
        first; a non-None result from it wins. Otherwise the dialect impl
        decides.
        """
        if self._user_compare_type is False:
            return False

        if callable(self._user_compare_type):
            user_value = self._user_compare_type(
                self,
                inspector_column,
                metadata_column,
                inspector_column.type,
                metadata_column.type
            )
            if user_value is not None:
                return user_value

        return self.impl.compare_type(
            inspector_column,
            metadata_column)

    def _compare_server_default(self, inspector_column,
                                metadata_column,
                                rendered_metadata_default,
                                rendered_column_default):
        """Decide whether two server defaults differ.

        Same precedence as :meth:`_compare_type`: ``False`` disables the
        check, a callable's non-None result wins, else the impl decides.
        """
        if self._user_compare_server_default is False:
            return False

        if callable(self._user_compare_server_default):
            user_value = self._user_compare_server_default(
                self,
                inspector_column,
                metadata_column,
                rendered_column_default,
                metadata_column.server_default,
                rendered_metadata_default
            )
            if user_value is not None:
                return user_value

        return self.impl.compare_server_default(
            inspector_column,
            metadata_column,
            rendered_metadata_default,
            rendered_column_default)
class HeadMaintainer(object):
    """Keep an in-memory set of head revisions in step with the version
    table.

    Each mutation updates ``self.heads`` and then executes the matching
    INSERT / UPDATE / DELETE on ``context._version`` through
    ``context.impl._exec``. Revision ids are rendered inline as quoted
    literals via ``literal_column``.
    """

    def __init__(self, context, heads):
        self.context = context
        # Copy into a mutable set; callers in this file pass the tuple
        # from get_current_heads().
        self.heads = set(heads)

    def _insert_version(self, version):
        """Record ``version`` as a new head and INSERT its row."""
        assert version not in self.heads
        self.heads.add(version)

        table = self.context._version
        insert_stmt = table.insert().values(
            version_num=literal_column("'%s'" % version))
        self.context.impl._exec(insert_stmt)

    def _delete_version(self, version):
        """Drop head ``version`` and DELETE its row.

        :raises util.CommandError: when not ``as_sql`` and the DELETE did
         not affect exactly one row.
        """
        self.heads.remove(version)

        table = self.context._version
        delete_stmt = table.delete().where(
            table.c.version_num == literal_column("'%s'" % version))
        result = self.context.impl._exec(delete_stmt)

        # The row count is only checked in online (non-as_sql) mode.
        if self.context.as_sql or result.rowcount == 1:
            return
        raise util.CommandError(
            "Online migration expected to match one "
            "row when deleting '%s' in '%s'; "
            "%d found"
            % (version,
               self.context.version_table, result.rowcount))

    def _update_version(self, from_, to_):
        """Replace head ``from_`` with ``to_`` and UPDATE its row.

        :raises util.CommandError: when not ``as_sql`` and the UPDATE did
         not affect exactly one row.
        """
        assert to_ not in self.heads
        self.heads.remove(from_)
        self.heads.add(to_)

        table = self.context._version
        update_stmt = table.update().values(
            version_num=literal_column("'%s'" % to_)).where(
            table.c.version_num == literal_column("'%s'" % from_))
        result = self.context.impl._exec(update_stmt)

        if self.context.as_sql or result.rowcount == 1:
            return
        raise util.CommandError(
            "Online migration expected to match one "
            "row when updating '%s' to '%s' in '%s'; "
            "%d found"
            % (from_, to_, self.context.version_table, result.rowcount))

    def update_to_step(self, step):
        """Apply the version-table change implied by ``step``.

        The step's predicates are consulted in order (delete branch,
        create branch, merge, unmerge). If none applies, the step is a
        plain single-head update.
        """
        if step.should_delete_branch(self.heads):
            doomed = step.delete_version_num
            log.debug("branch delete %s", doomed)
            self._delete_version(doomed)
            return

        if step.should_create_branch(self.heads):
            fresh = step.insert_version_num
            log.debug("new branch insert %s", fresh)
            self._insert_version(fresh)
            return

        if step.should_merge_branches(self.heads):
            # delete revs, update from rev, update to rev
            to_delete, merge_from, merge_to = step.merge_branch_idents
            log.debug(
                "merge, delete %s, update %s to %s",
                to_delete, merge_from, merge_to)
            for rev in to_delete:
                self._delete_version(rev)
            self._update_version(merge_from, merge_to)
            return

        if step.should_unmerge_branches(self.heads):
            # update from rev, update to rev, insert revs
            unmerge_from, unmerge_to, to_insert = step.unmerge_branch_idents
            log.debug(
                "unmerge, insert %s, update %s to %s",
                to_insert, unmerge_from, unmerge_to)
            for rev in to_insert:
                self._insert_version(rev)
            self._update_version(unmerge_from, unmerge_to)
            return

        old, new = step.update_version_num
        log.debug("update %s to %s", old, new)
        self._update_version(old, new)
class MigrationStep(object):
    """Base class for one unit of migration work.

    Subclasses supply ``migration_fn``, ``from_revisions``,
    ``to_revisions`` and ``doc``. This base derives the human-readable
    strings used in logging and offline SQL comments from them.
    """

    @property
    def name(self):
        """Name of the callable that performs this step."""
        fn = self.migration_fn
        return fn.__name__

    @classmethod
    def upgrade_from_script(cls, revision_map, script):
        """Build a step that runs ``script``'s upgrade callable."""
        return RevisionStep(revision_map, script, True)

    @classmethod
    def downgrade_from_script(cls, revision_map, script):
        """Build a step that runs ``script``'s downgrade callable."""
        return RevisionStep(revision_map, script, False)

    def _display_args(self):
        # (name, source revision(s), destination revision(s)), with the
        # revision tuples passed through tuple_rev_as_scalar for display.
        source = tuple_rev_as_scalar(self.from_revisions)
        dest = tuple_rev_as_scalar(self.to_revisions)
        return self.name, source, dest

    @property
    def short_log(self):
        """``"<name> <from> -> <to>"`` summary of this step."""
        return "%s %s -> %s" % self._display_args()

    def __str__(self):
        doc = self.doc
        if not doc:
            return self.short_log
        return "%s %s -> %s, %s" % (self._display_args() + (doc,))
class RevisionStep(MigrationStep):
    """A migration step that runs one revision script's ``upgrade()`` or
    ``downgrade()`` callable.

    The ``should_*`` predicates inspect the current set of heads and
    report which kind of version-table change this step implies. The
    ``*_version_num`` / ``*_branch_idents`` properties supply the revision
    ids for that change.
    """

    def __init__(self, revision_map, revision, is_upgrade):
        self.revision_map = revision_map
        self.revision = revision
        self.is_upgrade = is_upgrade
        if is_upgrade:
            self.migration_fn = revision.module.upgrade
        else:
            self.migration_fn = revision.module.downgrade

    def __eq__(self, other):
        return isinstance(other, RevisionStep) and \
            other.revision == self.revision and \
            self.is_upgrade == other.is_upgrade

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; without this,
        # two equal steps would still report ``a != b`` as True.
        return not self.__eq__(other)

    @property
    def doc(self):
        return self.revision.doc

    @property
    def from_revisions(self):
        """Revision ids the database is at before this step runs."""
        if self.is_upgrade:
            return self.revision._down_revision_tuple
        else:
            return (self.revision.revision, )

    @property
    def to_revisions(self):
        """Revision ids the database is at after this step runs."""
        if self.is_upgrade:
            return (self.revision.revision, )
        else:
            return self.revision._down_revision_tuple

    @property
    def is_downgrade(self):
        return not self.is_upgrade

    @property
    def _has_scalar_down_revision(self):
        return len(self.revision._down_revision_tuple) == 1

    def should_delete_branch(self, heads):
        """Return True if downgrading this revision removes its head row.

        Applies only to a downgrade of a revision that is currently a
        head, when the revision is a base, or when its single down
        revision is a branch point whose other descendants still hold
        heads.
        """
        if not self.is_downgrade:
            return False

        if self.revision.revision not in heads:
            return False

        downrevs = self.revision._down_revision_tuple

        if not downrevs:
            # is a base
            return True
        elif len(downrevs) == 1:
            downrev = self.revision_map.get_revision(downrevs[0])

            if not downrev.is_branch_point:
                return False

            descendants = set(
                r.revision for r in self.revision_map._get_descendant_nodes(
                    self.revision_map.get_revisions(downrev.nextrev),
                    check=False
                )
            )

            # the downrev is a branchpoint, and other members or descendants
            # of the branch are still in heads; so delete this branch.
            # the reason this occurs is because traversal tries to stay
            # fully on one branch down to the branchpoint before starting
            # the other; so if we have a->b->(c1->d1->e1, c2->d2->e2),
            # on a downgrade from the top we may go e1, d1, c1, now heads
            # are at c1 and e2, with the current method, we don't know that
            # "e2" is important unless we get all descendants of c1/c2.
            # Checking only downrev.nextrev against heads is therefore not
            # sufficient; all descendants are considered.
            return bool(
                descendants.intersection(heads).difference(
                    [self.revision.revision]))
        else:
            # is a merge point
            return False

    def should_create_branch(self, heads):
        """Return True if upgrading to this revision adds a new head row.

        This is the case for a base revision, or for a revision whose
        single down revision is not currently a head.
        """
        if not self.is_upgrade:
            return False

        downrevs = self.revision._down_revision_tuple

        if not downrevs:
            # is a base
            return True
        elif len(downrevs) == 1:
            return downrevs[0] not in heads
        else:
            # is a merge point
            return False

    def should_merge_branches(self, heads):
        """Return True if this upgrade merges two or more current heads."""
        if not self.is_upgrade:
            return False

        downrevs = self.revision._down_revision_tuple

        return len(downrevs) > 1 and \
            len(heads.intersection(downrevs)) > 1

    def should_unmerge_branches(self, heads):
        """Return True if this downgrade splits a merge head into branches."""
        if not self.is_downgrade:
            return False

        downrevs = self.revision._down_revision_tuple

        return self.revision.revision in heads and len(downrevs) > 1

    @property
    def update_version_num(self):
        """``(from, to)`` pair for a plain single-head update."""
        assert self._has_scalar_down_revision
        if self.is_upgrade:
            return self.revision.down_revision, self.revision.revision
        else:
            return self.revision.revision, self.revision.down_revision

    @property
    def delete_version_num(self):
        return self.revision.revision

    @property
    def insert_version_num(self):
        return self.revision.revision

    @property
    def merge_branch_idents(self):
        return (
            # delete revs, update from rev, update to rev
            self.from_revisions[0:-1], self.from_revisions[-1],
            self.to_revisions[0]
        )

    @property
    def unmerge_branch_idents(self):
        return (
            # update from rev, update to rev, insert revs
            self.from_revisions[0], self.to_revisions[-1],
            self.to_revisions[0:-1]
        )
class StampStep(MigrationStep):
    """Step subclass with no behavior of its own in the visible code.

    NOTE(review): the name suggests a version-table-only "stamp" step
    (no migration script executed), but as visible here the body is just
    ``pass`` and nothing shown instantiates it. Confirm the intended usage.
    """
    pass