From 89a6fa3c1436a490d3c664ee5d37aa95e2d8a9a9 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 18 Jun 2015 17:11:49 -0400 Subject: [PATCH] - factor schema object creator functions into a separate object --- alembic/operations/base.py | 194 +++++++------------------------- alembic/operations/schemaobj.py | 137 ++++++++++++++++++++++ tests/test_batch.py | 18 +-- 3 files changed, 185 insertions(+), 164 deletions(-) create mode 100644 alembic/operations/schemaobj.py diff --git a/alembic/operations/base.py b/alembic/operations/base.py index cda5aa2..59fa7af 100644 --- a/alembic/operations/base.py +++ b/alembic/operations/base.py @@ -1,11 +1,11 @@ from contextlib import contextmanager -from sqlalchemy.types import NULLTYPE, Integer +from sqlalchemy.types import NULLTYPE from sqlalchemy import schema as sa_schema from .. import util from . import batch -from ..util.compat import string_types +from . import schemaobj from ..ddl import impl __all__ = ('Operations', 'BatchOperations') @@ -56,6 +56,8 @@ class Operations(object): else: self.impl = impl + self.schema_obj = schemaobj.SchemaObjects(migration_context) + @classmethod @contextmanager def context(cls, migration_context): @@ -65,131 +67,6 @@ class Operations(object): yield op _remove_proxy() - def _primary_key_constraint(self, name, table_name, cols, schema=None): - m = self._metadata() - columns = [sa_schema.Column(n, NULLTYPE) for n in cols] - t1 = sa_schema.Table(table_name, m, - *columns, - schema=schema) - p = sa_schema.PrimaryKeyConstraint(*columns, name=name) - t1.append_constraint(p) - return p - - def _foreign_key_constraint(self, name, source, referent, - local_cols, remote_cols, - onupdate=None, ondelete=None, - deferrable=None, source_schema=None, - referent_schema=None, initially=None, - match=None, **dialect_kw): - m = self._metadata() - if source == referent: - t1_cols = local_cols + remote_cols - else: - t1_cols = local_cols - sa_schema.Table( - referent, m, - *[sa_schema.Column(n, NULLTYPE) for n in remote_cols], - schema=referent_schema) - - t1 = sa_schema.Table( - source, m, - *[sa_schema.Column(n, NULLTYPE) for n in t1_cols], - schema=source_schema) - - tname = "%s.%s" % (referent_schema, referent) if referent_schema \ - else referent - - if util.sqla_08: - # "match" kw unsupported in 0.7 - dialect_kw['match'] = match - - f = sa_schema.ForeignKeyConstraint(local_cols, - ["%s.%s" % (tname, n) - for n in remote_cols], - name=name, - onupdate=onupdate, - ondelete=ondelete, - deferrable=deferrable, - initially=initially, - **dialect_kw - ) - t1.append_constraint(f) - - return f - - def _unique_constraint(self, name, source, local_cols, schema=None, **kw): - t = sa_schema.Table( - source, self._metadata(), - *[sa_schema.Column(n, NULLTYPE) for n in local_cols], - schema=schema) - kw['name'] = name - uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw) - # TODO: need event tests to ensure the event - # is fired off here - t.append_constraint(uq) - return uq - - def _check_constraint(self, name, source, condition, schema=None, **kw): - t = sa_schema.Table(source, self._metadata(), - sa_schema.Column('x', Integer), schema=schema) - ck = sa_schema.CheckConstraint(condition, name=name, **kw) - t.append_constraint(ck) - return ck - - def _metadata(self): - kw = {} - if 'target_metadata' in self.migration_context.opts: - mt = self.migration_context.opts['target_metadata'] - if hasattr(mt, 'naming_convention'): - kw['naming_convention'] = mt.naming_convention - return sa_schema.MetaData(**kw) - - def _table(self, name, *columns, **kw): - m = self._metadata() - t = sa_schema.Table(name, m, *columns, **kw) - for f in t.foreign_keys: - self._ensure_table_for_fk(m, f) - return t - - def _column(self, name, type_, **kw): - return sa_schema.Column(name, type_, **kw) - - def _index(self, name, tablename, columns, schema=None, **kw): - t = sa_schema.Table( - tablename or 'no_table', self._metadata(), - schema=schema - ) - idx = sa_schema.Index( - name, - *[impl._textual_index_column(t, n) for n in columns], - **kw) - return idx - - def _parse_table_key(self, table_key): - if '.' in table_key: - tokens = table_key.split('.') - sname = ".".join(tokens[0:-1]) - tname = tokens[-1] - else: - tname = table_key - sname = None - return (sname, tname) - - def _ensure_table_for_fk(self, metadata, fk): - """create a placeholder Table object for the referent of a - ForeignKey. - - """ - if isinstance(fk._colspec, string_types): - table_key, cname = fk._colspec.rsplit('.', 1) - sname, tname = self._parse_table_key(table_key) - if table_key not in metadata.tables: - rel_t = sa_schema.Table(tname, metadata, schema=sname) - else: - rel_t = metadata.tables[table_key] - if cname not in rel_t.c: - rel_t.append_column(sa_schema.Column(cname, NULLTYPE)) - @contextmanager def batch_alter_table( self, table_name, schema=None, recreate="auto", copy_from=None, @@ -458,10 +335,11 @@ class Operations(object): constraint._create_rule(compiler)) if existing_type and type_: - t = self._table(table_name, - sa_schema.Column(column_name, existing_type), - schema=schema - ) + t = self.schema_obj.table( + table_name, + sa_schema.Column(column_name, existing_type), + schema=schema + ) for constraint in t.constraints: if _count_constraint(constraint): self.impl.drop_constraint(constraint) @@ -480,10 +358,11 @@ class Operations(object): ) if type_: - t = self._table(table_name, - sa_schema.Column(column_name, type_), - schema=schema - ) + t = self.schema_obj.table( + table_name, + self.schema_obj.column(column_name, type_), + schema=schema + ) for constraint in t.constraints: if _count_constraint(constraint): self.impl.add_constraint(constraint) @@ -589,7 +468,7 @@ class Operations(object): """ - t = self._table(table_name, column, schema=schema) + t = self.schema_obj.table(table_name, column, schema=schema) self.impl.add_column( table_name, column, @@ -647,7 +526,7 @@ class Operations(object): self.impl.drop_column( table_name, - self._column(column_name, NULLTYPE), + self.schema_obj.column(column_name, NULLTYPE), **kw ) @@ -692,8 +571,8 @@ class Operations(object): """ self.impl.add_constraint( - self._primary_key_constraint(name, table_name, cols, - schema) + self.schema_obj.primary_key_constraint( + name, table_name, cols, schema) ) def create_foreign_key(self, name, source, referent, local_cols, @@ -747,14 +626,15 @@ class Operations(object): """ self.impl.add_constraint( - self._foreign_key_constraint(name, source, referent, - local_cols, remote_cols, - onupdate=onupdate, ondelete=ondelete, - deferrable=deferrable, - source_schema=source_schema, - referent_schema=referent_schema, - initially=initially, match=match, - **dialect_kw) + self.schema_obj.foreign_key_constraint( + name, source, referent, + local_cols, remote_cols, + onupdate=onupdate, ondelete=ondelete, + deferrable=deferrable, + source_schema=source_schema, + referent_schema=referent_schema, + initially=initially, match=match, + **dialect_kw) ) def create_unique_constraint(self, name, source, local_cols, @@ -802,8 +682,9 @@ class Operations(object): """ self.impl.add_constraint( - self._unique_constraint(name, source, local_cols, - schema=schema, **kw) + self.schema_obj.unique_constraint( + name, source, local_cols, + schema=schema, **kw) ) def create_check_constraint(self, name, source, condition, @@ -852,7 +733,7 @@ class Operations(object): """ self.impl.add_constraint( - self._check_constraint( + self.schema_obj.check_constraint( name, source, condition, schema=schema, **kw) ) @@ -941,7 +822,7 @@ class Operations(object): object is returned. """ - table = self._table(name, *columns, **kw) + table = self.schema_obj.table(name, *columns, **kw) self.impl.create_table(table) return table @@ -968,7 +849,7 @@ class Operations(object): """ self.impl.drop_table( - self._table(name, **kw) + self.schema_obj.table(name, **kw) ) def create_index(self, name, table_name, columns, schema=None, @@ -1024,8 +905,9 @@ class Operations(object): """ self.impl.create_index( - self._index(name, table_name, columns, schema=schema, - unique=unique, quote=quote, **kw) + self.schema_obj.index( + name, table_name, columns, schema=schema, + unique=unique, quote=quote, **kw) ) @util._with_legacy_names([('tablename', 'table_name')]) @@ -1052,7 +934,7 @@ class Operations(object): # need a dummy column name here since SQLAlchemy # 0.7.6 and further raises on Index with no columns self.impl.drop_index( - self._index(name, table_name, ['x'], schema=schema) + self.schema_obj.index(name, table_name, ['x'], schema=schema) ) @util._with_legacy_names([("type", "type_")]) @@ -1073,7 +955,7 @@ class Operations(object): """ - t = self._table(table_name, schema=schema) + t = self.schema_obj.table(table_name, schema=schema) types = { 'foreignkey': lambda name: sa_schema.ForeignKeyConstraint( [], [], name=name), diff --git a/alembic/operations/schemaobj.py b/alembic/operations/schemaobj.py new file mode 100644 index 0000000..b5a8e08 --- /dev/null +++ b/alembic/operations/schemaobj.py @@ -0,0 +1,137 @@ +from sqlalchemy import schema as sa_schema +from sqlalchemy.types import NULLTYPE, Integer +from ..util.compat import string_types +from .. import util +from ..ddl import impl + + +class SchemaObjects(object): + + def __init__(self, migration_context): + self.migration_context = migration_context + + def primary_key_constraint(self, name, table_name, cols, schema=None): + m = self.metadata() + columns = [sa_schema.Column(n, NULLTYPE) for n in cols] + t1 = sa_schema.Table(table_name, m, + *columns, + schema=schema) + p = sa_schema.PrimaryKeyConstraint(*columns, name=name) + t1.append_constraint(p) + return p + + def foreign_key_constraint( + self, name, source, referent, + local_cols, remote_cols, + onupdate=None, ondelete=None, + deferrable=None, source_schema=None, + referent_schema=None, initially=None, + match=None, **dialect_kw): + m = self.metadata() + if source == referent: + t1_cols = local_cols + remote_cols + else: + t1_cols = local_cols + sa_schema.Table( + referent, m, + *[sa_schema.Column(n, NULLTYPE) for n in remote_cols], + schema=referent_schema) + + t1 = sa_schema.Table( + source, m, + *[sa_schema.Column(n, NULLTYPE) for n in t1_cols], + schema=source_schema) + + tname = "%s.%s" % (referent_schema, referent) if referent_schema \ + else referent + + if util.sqla_08: + # "match" kw unsupported in 0.7 + dialect_kw['match'] = match + + f = sa_schema.ForeignKeyConstraint(local_cols, + ["%s.%s" % (tname, n) + for n in remote_cols], + name=name, + onupdate=onupdate, + ondelete=ondelete, + deferrable=deferrable, + initially=initially, + **dialect_kw + ) + t1.append_constraint(f) + + return f + + def unique_constraint(self, name, source, local_cols, schema=None, **kw): + t = sa_schema.Table( + source, self.metadata(), + *[sa_schema.Column(n, NULLTYPE) for n in local_cols], + schema=schema) + kw['name'] = name + uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw) + # TODO: need event tests to ensure the event + # is fired off here + t.append_constraint(uq) + return uq + + def check_constraint(self, name, source, condition, schema=None, **kw): + t = sa_schema.Table(source, self.metadata(), + sa_schema.Column('x', Integer), schema=schema) + ck = sa_schema.CheckConstraint(condition, name=name, **kw) + t.append_constraint(ck) + return ck + + def metadata(self): + kw = {} + if 'target_metadata' in self.migration_context.opts: + mt = self.migration_context.opts['target_metadata'] + if hasattr(mt, 'naming_convention'): + kw['naming_convention'] = mt.naming_convention + return sa_schema.MetaData(**kw) + + def table(self, name, *columns, **kw): + m = self.metadata() + t = sa_schema.Table(name, m, *columns, **kw) + for f in t.foreign_keys: + self._ensure_table_for_fk(m, f) + return t + + def column(self, name, type_, **kw): + return sa_schema.Column(name, type_, **kw) + + def index(self, name, tablename, columns, schema=None, **kw): + t = sa_schema.Table( + tablename or 'no_table', self.metadata(), + schema=schema + ) + idx = sa_schema.Index( + name, + *[impl._textual_index_column(t, n) for n in columns], + **kw) + return idx + + def _parse_table_key(self, table_key): + if '.' in table_key: + tokens = table_key.split('.') + sname = ".".join(tokens[0:-1]) + tname = tokens[-1] + else: + tname = table_key + sname = None + return (sname, tname) + + def _ensure_table_for_fk(self, metadata, fk): + """create a placeholder Table object for the referent of a + ForeignKey. + + """ + if isinstance(fk._colspec, string_types): + table_key, cname = fk._colspec.rsplit('.', 1) + sname, tname = self._parse_table_key(table_key) + if table_key not in metadata.tables: + rel_t = sa_schema.Table(tname, metadata, schema=sname) + else: + rel_t = metadata.tables[table_key] + if cname not in rel_t.c: + rel_t.append_column(sa_schema.Column(cname, NULLTYPE)) diff --git a/tests/test_batch.py b/tests/test_batch.py index c827ac4..4226c8e 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -328,7 +328,7 @@ class BatchApplyTest(TestBase): impl = self._simple_fixture() col = Column('g', Integer) # operations.add_column produces a table - t = self.op._table('tname', col) # noqa + t = self.op.schema_obj.table('tname', col) # noqa impl.add_column('tname', col) new_table = self._assert_impl(impl, colnames=['id', 'x', 'y', 'g']) eq_(new_table.c.g.name, 'g') @@ -418,7 +418,7 @@ class BatchApplyTest(TestBase): def test_add_fk(self): impl = self._simple_fixture() impl.add_column('tname', Column('user_id', Integer)) - fk = self.op._foreign_key_constraint( + fk = self.op.schema_obj.foreign_key_constraint( 'fk1', 'tname', 'user', ['user_id'], ['id']) impl.add_constraint(fk) @@ -445,7 +445,7 @@ class BatchApplyTest(TestBase): def test_add_uq(self): impl = self._simple_fixture() - uq = self.op._unique_constraint( + uq = self.op.schema_obj.unique_constraint( 'uq1', 'tname', ['y'] ) @@ -457,7 +457,7 @@ class BatchApplyTest(TestBase): def test_drop_uq(self): impl = self._uq_fixture() - uq = self.op._unique_constraint( + uq = self.op.schema_obj.unique_constraint( 'uq1', 'tname', ['y'] ) impl.drop_constraint(uq) @@ -467,7 +467,7 @@ class BatchApplyTest(TestBase): def test_create_index(self): impl = self._simple_fixture() - ix = self.op._index('ix1', 'tname', ['y']) + ix = self.op.schema_obj.index('ix1', 'tname', ['y']) impl.create_index(ix) self._assert_impl( @@ -477,7 +477,7 @@ class BatchApplyTest(TestBase): def test_drop_index(self): impl = self._ix_fixture() - ix = self.op._index('ix1', 'tname', ['y']) + ix = self.op.schema_obj.index('ix1', 'tname', ['y']) impl.drop_index(ix) self._assert_impl( impl, colnames=['id', 'x', 'y'], @@ -501,8 +501,10 @@ class BatchAPITest(TestBase): batch = op.batch_alter_table( 'tname', recreate='never', schema=schema).__enter__() - with mock.patch("alembic.operations.base.sa_schema") as mock_schema: - yield batch + mock_schema = mock.MagicMock() + with mock.patch("alembic.operations.schemaobj.sa_schema", mock_schema): + with mock.patch("alembic.operations.base.sa_schema", mock_schema): + yield batch batch.impl.flush() self.mock_schema = mock_schema