From 6dfceac5d26c9e757a11562fea1dda525f37b79b Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 15 Nov 2014 14:28:59 -0500 Subject: [PATCH] - add the concept of named branches and @ syntax --- alembic/command.py | 8 ++- alembic/config.py | 7 ++ alembic/revision.py | 113 +++++++++++++++++++++++++++----- alembic/script.py | 31 +++++++-- tests/test_revision.py | 87 ++++++++++++++++++++++++ tests/test_version_traversal.py | 4 +- 6 files changed, 224 insertions(+), 26 deletions(-) diff --git a/alembic/command.py b/alembic/command.py index a38a1ad..708d6cd 100644 --- a/alembic/command.py +++ b/alembic/command.py @@ -64,7 +64,7 @@ def init(config, directory, template='generic'): "settings in %r before proceeding." % config_file) -def revision(config, message=None, autogenerate=False, sql=False): +def revision(config, message=None, autogenerate=False, sql=False, head="head"): """Create a new revision file.""" script = ScriptDirectory.from_config(config) @@ -99,8 +99,10 @@ def revision(config, message=None, autogenerate=False, sql=False): template_args=template_args, ): script.run_env() - return script.generate_revision(util.rev_id(), message, refresh=True, - **template_args) + return script.generate_revision( + util.rev_id(), message, refresh=True, + head=head, + **template_args) def upgrade(config, revision, sql=False, tag=None): diff --git a/alembic/config.py b/alembic/config.py index 4f3e42d..7d46a00 100644 --- a/alembic/config.py +++ b/alembic/config.py @@ -212,6 +212,13 @@ class CommandLine(object): type=str, help="Arbitrary 'tag' name - can be used by " "custom env.py scripts.") + if 'head' in kwargs: + parser.add_argument( + "--head", + type=str, + help="Specify head revision or @head " + "to base new revision on." + ) if 'autogenerate' in kwargs: parser.add_argument( "--autogenerate", diff --git a/alembic/revision.py b/alembic/revision.py index 50daaa4..d64e18a 100644 --- a/alembic/revision.py +++ b/alembic/revision.py @@ -66,10 +66,12 @@ class RevisionMap(object): self.bases = () for revision in self._generator(): + if revision.revision in map_: util.warn("Revision %s is present more than once" % revision.revision) map_[revision.revision] = revision + self._add_branches(revision, map_) heads.add(revision.revision) if revision.is_base: self.bases += (revision.revision, ) @@ -85,6 +87,18 @@ class RevisionMap(object): self.heads = tuple(heads) return map_ + def _add_branches(self, revision, map_): + if revision.branch_names: + for branch_name in util.to_tuple(revision.branch_names, ()): + if branch_name in map_: + raise RevisionError( + "Branch name '%s' in revision %s already " + "used by revision %s" % + (branch_name, revision.revision, + map_[branch_name].revision) + ) + map_[branch_name] = revision + def add_revision(self, revision, _replace=False): """add a single revision to an existing map. @@ -100,6 +114,7 @@ class RevisionMap(object): raise Exception("revision %s not in map" % revision.revision) map_[revision.revision] = revision + self._add_branches(revision, map_) if revision.is_base: self.bases += (revision.revision, ) for downrev in revision.down_revision: @@ -116,7 +131,7 @@ class RevisionMap(object): set(revision.down_revision).union([revision.revision]) ) + (revision.revision,) - def get_current_head(self): + def get_current_head(self, branch_name=None): """Return the current head revision. If the script directory has multiple heads @@ -132,6 +147,12 @@ class RevisionMap(object): """ current_heads = self.heads + if branch_name: + current_heads = [ + h for h in current_heads + if self._shares_lineage(h, branch_name) + ] + if len(current_heads) > 1: raise MultipleHeads( "Multiple heads are present; please use current_heads()") @@ -155,8 +176,10 @@ class RevisionMap(object): full revision. """ - resolved_id = self._resolve_revision_number(id_) or () - return tuple(self.get_revision(rev_id) for rev_id in resolved_id) + resolved_id, branch_name = self._resolve_revision_number(id_) + return tuple( + self._revision_for_ident(rev_id, branch_name) + for rev_id in resolved_id) def get_revision(self, id_): """Return the :class:`.Revision` instance with the given rev id. @@ -172,40 +195,93 @@ class RevisionMap(object): """ - resolved_id = self._resolve_revision_number(id_) or () + resolved_id, branch_name = self._resolve_revision_number(id_) if len(resolved_id) > 1: raise MultipleHeads( "Identifier %r corresponds to multiple revisions" % id_) elif resolved_id: resolved_id = resolved_id[0] + return self._revision_for_ident(resolved_id, branch_name) + + def _resolve_branch(self, branch_name): try: - return self._revision_map[resolved_id] + branch_rev = self._revision_map[branch_name] + except KeyError: + try: + nonbranch_rev = self._revision_for_ident(branch_name) + except ResolutionError: + raise ResolutionError("No such branch: '%s'" % branch_name) + else: + return nonbranch_rev + else: + return branch_rev + + def _revision_for_ident(self, resolved_id, check_branch=None): + if check_branch: + branch_rev = self._resolve_branch(check_branch) + else: + branch_rev = None + + try: + revision = self._revision_map[resolved_id] except KeyError: # do a partial lookup revs = [x for x in self._revision_map if x and x.startswith(resolved_id)] + if branch_rev: + revs = [ + x for x in revs if + self._shares_lineage(x, check_branch)] if not revs: - raise ResolutionError("No such revision '%s'" % id_) + raise ResolutionError("No such revision '%s'" % resolved_id) elif len(revs) > 1: raise ResolutionError( "Multiple revisions start " "with '%s': %s..." % ( - id_, + resolved_id, ", ".join("'%s'" % r for r in revs[0:3]) )) else: - return self._revision_map[revs[0]] + revision = self._revision_map[revs[0]] + + if check_branch and revision is not None: + if not self._shares_lineage( + revision.revision, branch_rev.revision): + raise ResolutionError( + "Revision %s is not a member of branch '%s'" % + (revision.revision, check_branch)) + return revision + + def _shares_lineage(self, reva, revb): + if not isinstance(reva, Revision): + reva = self._revision_for_ident(reva) + if not isinstance(revb, Revision): + revb = self._revision_for_ident(revb) + + return revb in set( + self._get_descendant_nodes([reva])).union( + self._get_ancestor_nodes([reva])) def _resolve_revision_number(self, id_): - if id_ == 'heads': - return self.heads - elif id_ == 'head': - return (self.get_current_head(), ) - elif id_ == 'base': - return None + if isinstance(id_, compat.string_types) and "@" in id_: + branch_name, id_ = id_.split('@', 1) else: - return util.to_tuple(id_, default=None) + branch_name = None + + # ensure map is loaded + self._revision_map + if id_ == 'heads': + if branch_name: + raise RevisionError( + "Branch name given with 'heads' makes no sense") + return self.heads, branch_name + elif id_ == 'head': + return (self.get_current_head(branch_name), ), branch_name + elif id_ == 'base' or id_ is None: + return (), branch_name + else: + return util.to_tuple(id_, default=None), branch_name def iterate_revisions(self, upper, lower): """Iterate through script revisions, starting at the given @@ -365,9 +441,14 @@ class Revision(object): down_revision = None """The ``down_revision`` identifier(s) within the migration script.""" - def __init__(self, revision, down_revision): + branch_names = None + """Optional string/tuple of symbolic names to apply to this + revision's branch""" + + def __init__(self, revision, down_revision, branch_names=None): self.revision = revision self.down_revision = down_revision + self.branch_names = branch_names def add_nextrev(self, rev): self.nextrev = self.nextrev.union([rev]) diff --git a/alembic/script.py b/alembic/script.py index 1d0e071..8da419a 100644 --- a/alembic/script.py +++ b/alembic/script.py @@ -243,7 +243,9 @@ class ScriptDirectory(object): shutil.copy, src, dest) - def generate_revision(self, revid, message, head=None, refresh=False, **kw): + def generate_revision( + self, revid, message, head=None, + refresh=False, splice=False, **kw): """Generate a new revision file. This runs the ``script.py.mako`` template, given @@ -266,20 +268,37 @@ class ScriptDirectory(object): If False, the file is created but the state of the :class:`.ScriptDirectory` is unmodified; ``None`` is returned. + :param splice: if True, allow the "head" version to not be an + actual head; otherwise, the selected head must be a head + (e.g. endpoint) revision. """ if head is None: - head = self.get_current_head() + head = "head" - heads = util.to_tuple(head, default=()) + try: + heads = self.revision_map.get_revisions(head) + except revision.MultipleHeads: + raise util.CommandError( + "Multiple heads are present; please specify the head " + "revision on which the new revision should be based, " + "or perform a merge.") create_date = datetime.datetime.now() path = self._rev_path(revid, message, create_date) + + if not splice: + for head in heads: + if head is not None and not head.is_head: + raise util.CommandError( + "Revision %s is not a head revision" % head.revision) + self._generate_template( os.path.join(self.dir, "script.py.mako"), path, up_revision=str(revid), - down_revision=revision.tuple_rev_as_scalar(heads), + down_revision=revision.tuple_rev_as_scalar( + tuple(h.revision if h is not None else None for h in heads)), create_date=create_date, message=message if message is not None else ("empty message"), **kw @@ -324,7 +343,9 @@ class Script(revision.Revision): self.path = path super(Script, self).__init__( rev_id, - util.to_tuple(module.down_revision, default=())) + util.to_tuple(module.down_revision, default=()), + branch_names=util.to_tuple( + getattr(module, 'branch_names', None), default=())) module = None """The Python module representing the actual script itself.""" diff --git a/tests/test_revision.py b/tests/test_revision.py index d852bcc..ba53ad7 100644 --- a/tests/test_revision.py +++ b/tests/test_revision.py @@ -107,6 +107,93 @@ class DiamondTest(DownIterateTest): ) +class NamedBranchTest(TestBase): + def test_dupe_branch_collection(self): + fn = lambda: [ + Revision('a', ()), + Revision('b', ('a',)), + Revision('c', ('b',), branch_names=['xy1']), + Revision('d', ()), + Revision('e', ('d',), branch_names=['xy1']), + Revision('f', ('e',)) + ] + assert_raises_message( + RevisionError, + "Branch name 'xy1' in revision e already used by revision c", + getattr, RevisionMap(fn), "_revision_map" + ) + + def setUp(self): + self.map_ = RevisionMap(lambda: [ + Revision('a', (), branch_names='abranch'), + Revision('b', ('a',)), + Revision('somelongername', ('b',)), + Revision('c', ('somelongername',)), + Revision('d', ()), + Revision('e', ('d',), branch_names=['ebranch']), + Revision('someothername', ('e',)), + Revision('f', ('someothername',)), + ]) + + def test_partial_id_resolve(self): + eq_(self.map_.get_revision("ebranch@some").revision, "someothername") + eq_(self.map_.get_revision("abranch@some").revision, "somelongername") + + def test_branch_at_heads(self): + assert_raises_message( + RevisionError, + "Branch name given with 'heads' makes no sense", + self.map_.get_revision, "abranch@heads" + ) + + def test_branch_at_syntax(self): + eq_(self.map_.get_revision("abranch@head").revision, 'c') + eq_(self.map_.get_revision("abranch@base"), None) + eq_(self.map_.get_revision("ebranch@head").revision, 'f') + eq_(self.map_.get_revision("abranch@base"), None) + eq_(self.map_.get_revision("ebranch@d").revision, 'd') + + def test_branch_at_self(self): + eq_(self.map_.get_revision("ebranch@ebranch").revision, 'e') + + def test_retrieve_branch_revision(self): + eq_(self.map_.get_revision("abranch").revision, 'a') + eq_(self.map_.get_revision("ebranch").revision, 'e') + + def test_rev_not_in_branch(self): + assert_raises_message( + RevisionError, + "Revision b is not a member of branch 'ebranch'", + self.map_.get_revision, "ebranch@b" + ) + + assert_raises_message( + RevisionError, + "Revision d is not a member of branch 'abranch'", + self.map_.get_revision, "abranch@d" + ) + + def test_no_revision_exists(self): + assert_raises_message( + RevisionError, + "No such revision 'q'", + self.map_.get_revision, "abranch@q" + ) + + def test_not_actually_a_branch(self): + eq_(self.map_.get_revision("e@d").revision, "d") + + def test_not_actually_a_branch_partial_resolution(self): + eq_(self.map_.get_revision("someoth@d").revision, "d") + + def test_no_such_branch(self): + assert_raises_message( + RevisionError, + "No such branch: 'x'", + self.map_.get_revision, "x@d" + ) + + class MultipleBranchTest(DownIterateTest): def setUp(self): self.map = RevisionMap( diff --git a/tests/test_version_traversal.py b/tests/test_version_traversal.py index af70a36..7ef710e 100644 --- a/tests/test_version_traversal.py +++ b/tests/test_version_traversal.py @@ -150,7 +150,7 @@ class BranchedPathTest(TestBase): cls.c2 = env.generate_revision( util.rev_id(), 'b->c2', - head=cls.b.revision, refresh=True) + head=cls.b.revision, refresh=True, splice=True) cls.d2 = env.generate_revision( util.rev_id(), 'c2->d2', head=cls.c2.revision, refresh=True) @@ -228,7 +228,7 @@ class MergedPathTest(TestBase): cls.c2 = env.generate_revision( util.rev_id(), 'b->c2', - head=cls.b.revision, refresh=True) + head=cls.b.revision, refresh=True, splice=True) cls.d2 = env.generate_revision( util.rev_id(), 'c2->d2', head=cls.c2.revision, refresh=True)