693 lines
24 KiB
Python
693 lines
24 KiB
Python
import logging
|
|
import sys
|
|
from contextlib import contextmanager
|
|
from collections import namedtuple
|
|
|
|
from sqlalchemy import MetaData, Table, Column, String, literal_column
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.engine import url as sqla_url
|
|
|
|
from .compat import callable, EncodedIO, string_types
|
|
from . import ddl, util
|
|
from .revision import tuple_rev_as_scalar
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
|
|
|
|
|
|
class MigrationContext(object):

    """Represent the database state made available to a migration
    script.

    :class:`.MigrationContext` is the front end to an actual
    database connection, or alternatively a string output
    stream given a particular database dialect,
    from an Alembic perspective.

    When inside the ``env.py`` script, the :class:`.MigrationContext`
    is available via the
    :meth:`.EnvironmentContext.get_context` method,
    which is available at ``alembic.context``::

        # from within env.py script
        from alembic import context
        migration_context = context.get_context()

    For usage outside of an ``env.py`` script, such as for
    utility routines that want to check the current version
    in the database, use the :meth:`.MigrationContext.configure`
    method to create new :class:`.MigrationContext` objects.
    For example, to get at the current revision in the
    database using :meth:`.MigrationContext.get_current_revision`::

        # in any application, outside of an env.py script
        from alembic.migration import MigrationContext
        from sqlalchemy import create_engine

        engine = create_engine("postgresql://mydatabase")
        conn = engine.connect()

        context = MigrationContext.configure(conn)
        current_rev = context.get_current_revision()

    The above context can also be used to produce
    Alembic migration operations with an :class:`.Operations`
    instance::

        # in any application, outside of the normal Alembic environment
        from alembic.operations import Operations
        op = Operations(context)
        op.alter_column("mytable", "somecolumn", nullable=True)

    """

    def __init__(self, dialect, connection, opts, environment_context=None):
        self.environment_context = environment_context
        self.opts = opts
        self.dialect = dialect
        self.script = opts.get('script')

        as_sql = opts.get('as_sql', False)
        transactional_ddl = opts.get("transactional_ddl")

        self._transaction_per_migration = opts.get(
            "transaction_per_migration", False)

        if as_sql:
            # offline mode: replace the connection with a mock engine whose
            # executor hands every construct to self.impl._exec
            self.connection = self._stdout_connection(connection)
            assert self.connection is not None
        else:
            self.connection = connection
        self._migrations_fn = opts.get('fn')
        self.as_sql = as_sql

        # Fall back to stdout when "output_buffer" is missing *or* explicitly
        # None, consistently for both the encoded and the plain branch.
        output_buffer = opts.get("output_buffer") or sys.stdout
        if "output_encoding" in opts:
            self.output_buffer = EncodedIO(
                output_buffer,
                opts['output_encoding']
            )
        else:
            self.output_buffer = output_buffer

        self._user_compare_type = opts.get('compare_type', False)
        self._user_compare_server_default = opts.get(
            'compare_server_default',
            False)
        self.version_table = version_table = opts.get(
            'version_table', 'alembic_version')
        self.version_table_schema = version_table_schema = \
            opts.get('version_table_schema', None)
        self._version = Table(
            version_table, MetaData(),
            Column('version_num', String(32), nullable=False),
            schema=version_table_schema)

        self._start_from_rev = opts.get("starting_rev")
        self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
            dialect, self.connection, self.as_sql,
            transactional_ddl,
            self.output_buffer,
            opts
        )
        log.info("Context impl %s.", self.impl.__class__.__name__)
        if self.as_sql:
            log.info("Generating static SQL")
        log.info("Will assume %s DDL.",
                 "transactional" if self.impl.transactional_ddl
                 else "non-transactional")

    @classmethod
    def configure(cls,
                  connection=None,
                  url=None,
                  dialect_name=None,
                  environment_context=None,
                  opts=None,
                  ):
        """Create a new :class:`.MigrationContext`.

        This is a factory method usually called
        by :meth:`.EnvironmentContext.configure`.  It returns an instance
        of ``cls``, so subclasses of :class:`.MigrationContext` get
        instances of themselves.

        :param connection: a :class:`~sqlalchemy.engine.Connection`
         to use for SQL execution in "online" mode.  When present,
         is also used to determine the type of dialect in use.
        :param url: a string database url, or a
         :class:`sqlalchemy.engine.url.URL` object.
         The type of dialect to be used will be derived from this if
         ``connection`` is not passed.
        :param dialect_name: string name of a dialect, such as
         "postgresql", "mssql", etc.  The type of dialect to be used will be
         derived from this if ``connection`` and ``url`` are not passed.
        :param environment_context: the :class:`.EnvironmentContext` this
         context belongs to, if any; its ``config`` is exposed via
         :attr:`.MigrationContext.config`.
        :param opts: dictionary of options.  Most other options
         accepted by :meth:`.EnvironmentContext.configure` are passed via
         this dictionary.

        :raises Exception: if none of ``connection``, ``url`` or
         ``dialect_name`` is given.

        """
        if opts is None:
            opts = {}

        if connection is not None:
            dialect = connection.dialect
        elif url:
            url = sqla_url.make_url(url)
            dialect = url.get_dialect()()
        elif dialect_name:
            url = sqla_url.make_url("%s://" % dialect_name)
            dialect = url.get_dialect()()
        else:
            raise Exception("Connection, url, or dialect_name is required.")

        return cls(dialect, connection, opts, environment_context)

    def begin_transaction(self, _per_migration=False):
        """Return a context manager that brackets a unit of work.

        * When ``_per_migration`` does not match the
          ``transaction_per_migration`` option, or the dialect impl does
          not use transactional DDL, the manager does nothing.
        * In offline (``as_sql``) mode it emits BEGIN / COMMIT through
          ``self.impl.emit_begin()`` / ``self.impl.emit_commit()``.
        * Otherwise it is the real transaction from ``self.bind.begin()``.

        """
        transaction_now = _per_migration == self._transaction_per_migration

        # short-circuit keeps self.impl untouched when transaction_now is
        # false, matching the original branch order
        if not transaction_now or not self.impl.transactional_ddl:
            return self._no_transaction()
        elif self.as_sql:
            return self._emit_begin_commit()
        else:
            return self.bind.begin()

    @contextmanager
    def _no_transaction(self):
        # placeholder context manager: no BEGIN/COMMIT of any kind
        yield

    @contextmanager
    def _emit_begin_commit(self):
        # offline mode: write BEGIN before the body and COMMIT after it
        self.impl.emit_begin()
        yield
        self.impl.emit_commit()

    def get_current_revision(self):
        """Return the current revision, usually that which is present
        in the ``alembic_version`` table in the database.

        This method intends to be used only for a migration stream that
        does not contain unmerged branches in the target database;
        if there are multiple branches present, an exception is raised.
        The :meth:`.MigrationContext.get_current_heads` should be preferred
        over this method going forward in order to be compatible with
        branch migration support.

        If this :class:`.MigrationContext` was configured in "offline"
        mode, that is with ``as_sql=True``, the ``starting_rev``
        parameter is returned instead, if any.

        :return: the single head, or ``None`` when there are no heads.
        :raises util.CommandError: if more than one head is present.

        """
        heads = self.get_current_heads()
        if not heads:
            return None
        elif len(heads) > 1:
            raise util.CommandError(
                "Version table '%s' has more than one head present; "
                "please use get_current_heads()" % self.version_table)
        else:
            return heads[0]

    def get_current_heads(self):
        """Return a tuple of the current 'head versions' that are represented
        in the target database.

        For a migration stream without branches, this will be a single
        value, synonymous with that of
        :meth:`.MigrationContext.get_current_revision`.   However when multiple
        unmerged branches exist within the target database, the returned tuple
        will contain a value for each head.

        If this :class:`.MigrationContext` was configured in "offline"
        mode, that is with ``as_sql=True``, the ``starting_rev``
        parameter is returned in a one-length tuple.

        If no version table is present, or if there are no revisions
        present, an empty tuple is returned.

        :raises util.CommandError: in online mode, if ``starting_rev`` was
         also configured.

        .. versionadded:: 0.7.0

        """
        if self.as_sql:
            return util.to_tuple(self._start_from_rev, default=())
        else:
            if self._start_from_rev:
                raise util.CommandError(
                    "Can't specify current_rev to context "
                    "when using a database connection")
            if not self._has_version_table():
                return ()
        return tuple(
            row[0] for row in self.connection.execute(self._version.select())
        )

    def _ensure_version_table(self):
        # CREATE TABLE for the version table unless it already exists
        self._version.create(self.connection, checkfirst=True)

    def _has_version_table(self):
        # ask the dialect whether the (optionally schema-qualified) version
        # table exists on the current connection
        return self.connection.dialect.has_table(
            self.connection, self.version_table, self.version_table_schema)

    def stamp(self, script_directory, revision):
        """Stamp the version table with a specific revision.

        This method calculates those branches to which the given revision
        can apply, and updates those branches as though they were migrated
        towards that revision (either up or down).  If no current branches
        include the revision, it is added as a new branch head.

        .. versionadded:: 0.7.0

        """
        heads = self.get_current_heads()
        head_maintainer = HeadMaintainer(self, heads)
        for step in script_directory._steps_revs(revision, heads):
            head_maintainer.update_to_step(step)

    def run_migrations(self, **kw):
        """Run the migration scripts established for this
        :class:`.MigrationContext`, if any.

        The commands in :mod:`alembic.command` will set up a function
        that is ultimately passed to the :class:`.MigrationContext`
        as the ``fn`` argument.  This function represents the "work"
        that will be done when :meth:`.MigrationContext.run_migrations`
        is called, typically from within the ``env.py`` script of the
        migration environment.  The "work function" then provides an iterable
        of version callables and other version information which
        in the case of the ``upgrade`` or ``downgrade`` commands are the
        list of version scripts to invoke.  Other commands yield nothing,
        in the case that a command wants to run some other operation
        against the database such as the ``current`` or ``stamp`` commands.

        :param \\**kw: keyword arguments here will be passed to each
         migration callable, that is the ``upgrade()`` or ``downgrade()``
         method within revision scripts.

        """
        self.impl.start_migrations()

        heads = self.get_current_heads()
        if not self.as_sql and not heads:
            self._ensure_version_table()

        head_maintainer = HeadMaintainer(self, heads)

        for step in self._migrations_fn(heads, self):
            with self.begin_transaction(_per_migration=True):
                if self.as_sql and not head_maintainer.heads:
                    # for offline mode, include a CREATE TABLE from
                    # the base
                    self._version.create(self.connection)
                log.info("Running %s", step)
                if self.as_sql:
                    self.impl.static_output("-- Running %s" % (step.short_log,))
                step.migration_fn(**kw)

                # previously, we wouldn't stamp per migration
                # if we were in a transaction, however given the more
                # complex model that involves any number of inserts
                # and row-targeted updates and deletes, it's simpler for now
                # just to run the operations on every version
                head_maintainer.update_to_step(step)

        if self.as_sql and not head_maintainer.heads:
            # offline mode ended with no heads left: emit DROP TABLE
            self._version.drop(self.connection)

    def execute(self, sql, execution_options=None):
        """Execute a SQL construct or string statement.

        The underlying execution mechanics are used, that is
        if this is "offline mode" the SQL is written to the
        output buffer, otherwise the SQL is emitted on
        the current SQLAlchemy connection.

        """
        self.impl._exec(sql, execution_options)

    def _stdout_connection(self, connection):
        # Build a "mock" engine for this dialect whose executor forwards
        # each construct to self.impl._exec; ``connection`` is not used.
        def dump(construct, *multiparams, **params):
            self.impl._exec(construct)

        return create_engine("%s://" % self.dialect.name,
                             strategy="mock", executor=dump)

    @property
    def bind(self):
        """Return the current "bind".

        In online mode, this is an instance of
        :class:`sqlalchemy.engine.Connection`, and is suitable
        for ad-hoc execution of any kind of usage described
        in :ref:`sqlexpression_toplevel` as well as
        for usage with the :meth:`sqlalchemy.schema.Table.create`
        and :meth:`sqlalchemy.schema.MetaData.create_all` methods
        of :class:`~sqlalchemy.schema.Table`,
        :class:`~sqlalchemy.schema.MetaData`.

        Note that when "standard output" mode is enabled,
        this bind will be a "mock" connection handler that cannot
        return results and is only appropriate for a very limited
        subset of commands.

        """
        return self.connection

    @property
    def config(self):
        """Return the :class:`.Config` used by the current environment, if any.

        .. versionadded:: 0.6.6

        """
        if self.environment_context:
            return self.environment_context.config
        else:
            return None

    def _compare_type(self, inspector_column, metadata_column):
        # compare_type=False disables type comparison; a callable gets the
        # first say and may return None to defer to the dialect impl.
        if self._user_compare_type is False:
            return False

        if callable(self._user_compare_type):
            user_value = self._user_compare_type(
                self,
                inspector_column,
                metadata_column,
                inspector_column.type,
                metadata_column.type
            )
            if user_value is not None:
                return user_value

        return self.impl.compare_type(
            inspector_column,
            metadata_column)

    def _compare_server_default(self, inspector_column,
                                metadata_column,
                                rendered_metadata_default,
                                rendered_column_default):
        # Same policy as _compare_type: False disables, a callable may
        # decide (None defers), otherwise the dialect impl decides.
        if self._user_compare_server_default is False:
            return False

        if callable(self._user_compare_server_default):
            user_value = self._user_compare_server_default(
                self,
                inspector_column,
                metadata_column,
                rendered_column_default,
                metadata_column.server_default,
                rendered_metadata_default
            )
            if user_value is not None:
                return user_value

        return self.impl.compare_server_default(
            inspector_column,
            metadata_column,
            rendered_metadata_default,
            rendered_column_default)
|
|
|
|
|
|
class HeadMaintainer(object):
    """Keep an in-memory set of head revisions and the version table in step.

    ``heads`` starts as a set copy of the revisions passed in.  Each of
    ``_insert_version`` / ``_delete_version`` / ``_update_version`` first
    adjusts ``self.heads`` and then emits the matching INSERT / DELETE /
    UPDATE against ``context._version`` through ``context.impl._exec``.
    In online mode (``context.as_sql`` false) DELETE and UPDATE must hit
    exactly one row, otherwise :class:`util.CommandError` is raised.
    """

    def __init__(self, context, heads):
        self.context = context
        self.heads = set(heads)

    @staticmethod
    def _as_literal(version):
        # The revision id is inlined as a quoted SQL literal rather than a
        # bound parameter.  NOTE(review): presumably so the statement can be
        # rendered in offline (as_sql) mode; no quote escaping is done.
        return literal_column("'%s'" % version)

    def _expect_single_row(self, result, template, *idents):
        # Online mode only: require exactly one affected row.  ``template``
        # receives ``idents`` followed by the table name and the row count.
        if not self.context.as_sql and result.rowcount != 1:
            raise util.CommandError(
                template % (idents + (self.context.version_table,
                                      result.rowcount)))

    def _insert_version(self, version):
        assert version not in self.heads
        self.heads.add(version)

        statement = self.context._version.insert().values(
            version_num=self._as_literal(version))
        self.context.impl._exec(statement)

    def _delete_version(self, version):
        self.heads.remove(version)

        table = self.context._version
        statement = table.delete().where(
            table.c.version_num == self._as_literal(version))
        outcome = self.context.impl._exec(statement)
        self._expect_single_row(
            outcome,
            "Online migration expected to match one "
            "row when deleting '%s' in '%s'; "
            "%d found",
            version)

    def _update_version(self, from_, to_):
        assert to_ not in self.heads
        self.heads.remove(from_)
        self.heads.add(to_)

        table = self.context._version
        statement = table.update().values(
            version_num=self._as_literal(to_)).where(
            table.c.version_num == self._as_literal(from_))
        outcome = self.context.impl._exec(statement)
        self._expect_single_row(
            outcome,
            "Online migration expected to match one "
            "row when updating '%s' to '%s' in '%s'; "
            "%d found",
            from_, to_)

    def update_to_step(self, step):
        """Apply one migration ``step`` to the heads / version table.

        The step's predicates are consulted in a fixed order (delete branch,
        create branch, merge, unmerge); the first that matches wins, and a
        plain single-row update is the fallback.
        """
        if step.should_delete_branch(self.heads):
            doomed = step.delete_version_num
            log.debug("branch delete %s", doomed)
            self._delete_version(doomed)
            return

        if step.should_create_branch(self.heads):
            fresh = step.insert_version_num
            log.debug("new branch insert %s", fresh)
            self._insert_version(fresh)
            return

        if step.should_merge_branches(self.heads):
            # delete revs, update from rev, update to rev
            doomed_revs, source, target = step.merge_branch_idents
            log.debug(
                "merge, delete %s, update %s to %s",
                doomed_revs, source, target)
            for rev in doomed_revs:
                self._delete_version(rev)
            self._update_version(source, target)
            return

        if step.should_unmerge_branches(self.heads):
            source, target, fresh_revs = step.unmerge_branch_idents
            log.debug(
                "unmerge, insert %s, update %s to %s",
                fresh_revs, source, target)
            for rev in fresh_revs:
                self._insert_version(rev)
            self._update_version(source, target)
            return

        source, target = step.update_version_num
        log.debug("update %s to %s", source, target)
        self._update_version(source, target)
|
|
|
|
|
|
class MigrationStep(object):
    """Base class for one unit of work yielded by a migration stream.

    Subclasses provide ``migration_fn``, ``from_revisions``,
    ``to_revisions`` and ``doc``; this base class formats them for logs.
    """

    @property
    def name(self):
        """The ``__name__`` of the callable this step will invoke."""
        return self.migration_fn.__name__

    @classmethod
    def upgrade_from_script(cls, revision_map, script):
        """Return a :class:`.RevisionStep` running ``script``'s upgrade."""
        return RevisionStep(revision_map, script, True)

    @classmethod
    def downgrade_from_script(cls, revision_map, script):
        """Return a :class:`.RevisionStep` running ``script``'s downgrade."""
        return RevisionStep(revision_map, script, False)

    @property
    def short_log(self):
        """``"<name> <from> -> <to>"`` with single-element tuples unwrapped."""
        return "%s %s -> %s" % (
            self.name, tuple_rev_as_scalar(self.from_revisions),
            tuple_rev_as_scalar(self.to_revisions)
        )

    def __str__(self):
        # same text as short_log, with ", <doc>" appended when doc is set
        if self.doc:
            return "%s, %s" % (self.short_log, self.doc)
        return self.short_log


class RevisionStep(MigrationStep):
    """A step that runs a revision script's ``upgrade()`` or ``downgrade()``.

    The ``should_*`` predicates take the current set of heads and tell
    :class:`.HeadMaintainer` how the version table must change.
    """

    def __init__(self, revision_map, revision, is_upgrade):
        self.revision_map = revision_map
        self.revision = revision
        self.is_upgrade = is_upgrade
        if is_upgrade:
            self.migration_fn = revision.module.upgrade
        else:
            self.migration_fn = revision.module.downgrade

    def __eq__(self, other):
        if not isinstance(other, RevisionStep):
            return NotImplemented
        return other.revision == self.revision and \
            self.is_upgrade == other.is_upgrade

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep them consistent.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Must agree with __eq__; without this the class is unhashable on
        # Python 3 (and identity-hashed on Python 2).
        return hash((self.revision, self.is_upgrade))

    @property
    def doc(self):
        return self.revision.doc

    @property
    def from_revisions(self):
        if self.is_upgrade:
            return self.revision._down_revision_tuple
        else:
            return (self.revision.revision, )

    @property
    def to_revisions(self):
        if self.is_upgrade:
            return (self.revision.revision, )
        else:
            return self.revision._down_revision_tuple

    @property
    def is_downgrade(self):
        return not self.is_upgrade

    @property
    def _has_scalar_down_revision(self):
        return len(self.revision._down_revision_tuple) == 1

    def should_delete_branch(self, heads):
        """True when downgrading this revision removes a whole branch head."""
        if not self.is_downgrade:
            return False

        if self.revision.revision not in heads:
            return False

        downrevs = self.revision._down_revision_tuple
        if not downrevs:
            # is a base
            return True
        if len(downrevs) > 1:
            # is a merge point
            return False

        downrev = self.revision_map.get_revision(downrevs[0])
        if not downrev.is_branch_point:
            return False

        descendants = set(
            r.revision for r in self.revision_map._get_descendant_nodes(
                self.revision_map.get_revisions(downrev.nextrev),
                check=False
            )
        )

        # the downrev is a branchpoint, and other members or descendants
        # of the branch are still in heads; so delete this branch.
        # the reason this occurs is because traversal tries to stay
        # fully on one branch down to the branchpoint before starting
        # the other; so if we have a->b->(c1->d1->e1, c2->d2->e2),
        # on a downgrade from the top we may go e1, d1, c1, now heads
        # are at c1 and e2, with the current method, we don't know that
        # "e2" is important unless we get all descendants of c1/c2.
        # (An earlier TODO noted that intersecting only downrev.nextrev
        # with heads does not work.)
        return bool(descendants.intersection(heads).difference(
            [self.revision.revision]))

    def should_create_branch(self, heads):
        """True when upgrading to this revision starts a new head row."""
        if not self.is_upgrade:
            return False

        downrevs = self.revision._down_revision_tuple
        if not downrevs:
            # is a base
            return True
        if len(downrevs) == 1:
            return downrevs[0] not in heads
        # is a merge point
        return False

    def should_merge_branches(self, heads):
        """True when upgrading merges two or more current heads."""
        if not self.is_upgrade:
            return False

        downrevs = self.revision._down_revision_tuple
        return len(downrevs) > 1 and len(heads.intersection(downrevs)) > 1

    def should_unmerge_branches(self, heads):
        """True when downgrading a merge point that is currently a head."""
        if not self.is_downgrade:
            return False

        downrevs = self.revision._down_revision_tuple
        return self.revision.revision in heads and len(downrevs) > 1

    @property
    def update_version_num(self):
        """``(from, to)`` pair for a plain single-row version update."""
        assert self._has_scalar_down_revision
        if self.is_upgrade:
            return self.revision.down_revision, self.revision.revision
        else:
            return self.revision.revision, self.revision.down_revision

    @property
    def delete_version_num(self):
        return self.revision.revision

    @property
    def insert_version_num(self):
        return self.revision.revision

    @property
    def merge_branch_idents(self):
        return (
            # delete revs, update from rev, update to rev
            self.from_revisions[0:-1], self.from_revisions[-1],
            self.to_revisions[0]
        )

    @property
    def unmerge_branch_idents(self):
        return (
            # update from rev, update to rev, insert revs
            self.from_revisions[0], self.to_revisions[-1],
            self.to_revisions[0:-1]
        )


class StampStep(MigrationStep):
    """Marker subclass of :class:`.MigrationStep`; adds no behavior here."""