288 lines
10 KiB
Python
288 lines
10 KiB
Python
from sqlalchemy import Table, MetaData, Index, select, Column, \
|
|
ForeignKeyConstraint, cast
|
|
from sqlalchemy import types as sqltypes
|
|
from sqlalchemy import schema as sql_schema
|
|
from sqlalchemy.util import OrderedDict
|
|
from .. import util
|
|
from ..util.sqla_compat import _columns_for_constraint, _is_type_bound
|
|
|
|
|
|
class BatchOperationsImpl(object):
    """Queue operations against one table and apply them on ``flush()``.

    Every operation method only records an ``(opname, args, kwargs)``
    entry in ``self.batch``.  ``flush()`` later replays the entries,
    either directly on the underlying operations impl or through an
    ``ApplyBatchImpl`` that builds a replacement table.
    """

    def __init__(self, operations, table_name, schema, recreate,
                 copy_from, table_args, table_kwargs,
                 reflect_args, reflect_kwargs, naming_convention):
        """Store the batch configuration and start with an empty queue.

        :raises NotImplementedError: if ``util.sqla_08`` is false.
        :raises ValueError: if ``recreate`` is not ``'auto'``,
         ``'always'`` or ``'never'``.
        """
        if not util.sqla_08:
            raise NotImplementedError(
                "batch mode requires SQLAlchemy 0.8 or greater.")

        self.operations = operations
        self.table_name = table_name
        self.schema = schema

        recreate_is_valid = recreate in ('auto', 'always', 'never')
        if not recreate_is_valid:
            raise ValueError(
                "recreate may be one of 'auto', 'always', or 'never'.")
        self.recreate = recreate

        self.copy_from = copy_from
        self.table_args = table_args
        self.table_kwargs = table_kwargs
        self.reflect_args = reflect_args
        self.reflect_kwargs = reflect_kwargs
        self.naming_convention = naming_convention

        # queued operations, as (opname, args, kwargs) tuples
        self.batch = []

    @property
    def dialect(self):
        """The dialect of the underlying operations impl."""
        return self.operations.impl.dialect

    @property
    def impl(self):
        """The underlying operations impl."""
        return self.operations.impl

    def _should_recreate(self):
        """Return True if ``flush()`` should go through a table recreate.

        In ``'auto'`` mode the decision is delegated to the impl's
        ``requires_recreate_in_batch()``; otherwise only ``'always'``
        yields True.
        """
        mode = self.recreate
        if mode == 'auto':
            return self.operations.impl.requires_recreate_in_batch(self)
        return mode == 'always'

    def flush(self):
        """Apply every queued operation, in the order it was queued."""
        if not self._should_recreate():
            # no recreate: replay each entry on the real impl.
            for op_name, op_args, op_kwargs in self.batch:
                method = getattr(self.operations.impl, op_name)
                method(*op_args, **op_kwargs)
            return

        metadata_kwargs = {}
        if self.naming_convention:
            metadata_kwargs['naming_convention'] = self.naming_convention
        reflect_meta = MetaData(**metadata_kwargs)

        source_table = self.copy_from
        if source_table is None:
            # no Table was handed in, so load it from the current bind.
            source_table = Table(
                self.table_name, reflect_meta,
                schema=self.schema,
                autoload=True,
                autoload_with=self.operations.get_bind(),
                *self.reflect_args, **self.reflect_kwargs)

        rebuild = ApplyBatchImpl(
            source_table, self.table_args, self.table_kwargs)
        for op_name, op_args, op_kwargs in self.batch:
            method = getattr(rebuild, op_name)
            method(*op_args, **op_kwargs)

        rebuild._create(self.impl)

    def _enqueue(self, op_name, op_args, op_kwargs):
        """Append one ``(opname, args, kwargs)`` entry to the queue."""
        self.batch.append((op_name, op_args, op_kwargs))

    def alter_column(self, *arg, **kw):
        """Queue an ``alter_column`` operation."""
        self._enqueue("alter_column", arg, kw)

    def add_column(self, *arg, **kw):
        """Queue an ``add_column`` operation."""
        self._enqueue("add_column", arg, kw)

    def drop_column(self, *arg, **kw):
        """Queue a ``drop_column`` operation."""
        self._enqueue("drop_column", arg, kw)

    def add_constraint(self, const):
        """Queue an ``add_constraint`` operation for ``const``."""
        self._enqueue("add_constraint", (const,), {})

    def drop_constraint(self, const):
        """Queue a ``drop_constraint`` operation for ``const``."""
        self._enqueue("drop_constraint", (const,), {})

    def rename_table(self, *arg, **kw):
        """Queue a ``rename_table`` operation."""
        self._enqueue("rename_table", arg, kw)

    def create_index(self, idx):
        """Queue a ``create_index`` operation for ``idx``."""
        self._enqueue("create_index", (idx,), {})

    def drop_index(self, idx):
        """Queue a ``drop_index`` operation for ``idx``."""
        self._enqueue("drop_index", (idx,), {})

    def create_table(self, table):
        """Not supported in batch mode; always raises."""
        raise NotImplementedError("Can't create table in batch mode")

    def drop_table(self, table):
        """Not supported in batch mode; always raises."""
        raise NotImplementedError("Can't drop table in batch mode")
|
|
|
|
|
|
class ApplyBatchImpl(object):
    """Apply a batch of operations by building a replacement table.

    The operation methods only mutate in-memory bookkeeping (columns,
    constraints, indexes, and per-column transfer expressions).
    ``_create()`` then creates ``_alembic_batch_temp`` from that
    bookkeeping, copies rows across with INSERT .. FROM SELECT, drops the
    original table and renames the temp table to the original name.
    """

    def __init__(self, table, table_args, table_kwargs):
        """Snapshot ``table`` so that operations can be applied to it.

        :param table: the existing ``Table`` to be rebuilt.
        :param table_args: extra positional items passed to the new
         ``Table`` after its columns.
        :param table_kwargs: extra keyword arguments for the new ``Table``.
        """
        self.table = table  # this is a Table object
        self.table_args = table_args
        self.table_kwargs = table_kwargs
        self.new_table = None
        # original column name -> transfer info.  'expr' is the
        # expression selected from the old table for that column; columns
        # added during the batch have no 'expr' and are not copied.
        self.column_transfers = OrderedDict(
            (c.name, {'expr': c}) for c in self.table.c
        )
        self._grab_table_elements()

    def _grab_table_elements(self):
        """Copy columns and collect constraints/indexes of the table."""
        schema = self.table.schema
        self.columns = OrderedDict()
        for c in self.table.c:
            c_copy = c.copy(schema=schema)
            # clear the column-level flags on the copy; indexes and
            # constraints are tracked separately below.
            c_copy.unique = c_copy.index = False
            self.columns[c.name] = c_copy
        self.named_constraints = {}
        self.unnamed_constraints = []
        self.indexes = {}
        for const in self.table.constraints:
            if _is_type_bound(const):
                continue
            if const.name:
                self.named_constraints[const.name] = const
            else:
                self.unnamed_constraints.append(const)

        for idx in self.table.indexes:
            self.indexes[idx.name] = idx

    def _transfer_elements_to_new_table(self):
        """Build ``self.new_table`` from the current bookkeeping.

        May only be called once per instance.
        """
        assert self.new_table is None, "Can only create new table once"

        m = MetaData()
        schema = self.table.schema
        self.new_table = new_table = Table(
            '_alembic_batch_temp', m,
            *(list(self.columns.values()) + list(self.table_args)),
            schema=schema,
            **self.table_kwargs)

        for const in list(self.named_constraints.values()) + \
                self.unnamed_constraints:

            const_columns = set([
                c.key for c in _columns_for_constraint(const)])

            # skip constraints that mention a column which is no longer
            # tracked (e.g. one removed via drop_column()).
            if not const_columns.issubset(self.column_transfers):
                continue
            const_copy = const.copy(schema=schema, target_table=new_table)
            if isinstance(const, ForeignKeyConstraint):
                self._setup_referent(m, const)
            new_table.append_constraint(const_copy)

        for index in self.indexes.values():
            # constructing the Index against new_table's columns attaches
            # it to new_table; the return value is not needed.
            Index(index.name,
                  unique=index.unique,
                  *[new_table.c[col] for col in index.columns.keys()],
                  **index.kwargs)

    def _setup_referent(self, metadata, constraint):
        """Make sure the table a foreign key points at is in ``metadata``.

        A stub ``Table`` with ``NULLTYPE`` columns is created for the
        referred table, or the missing columns are appended to it if the
        stub already exists.  Self-references to ``_alembic_batch_temp``
        need no stub.
        """
        spec = constraint.elements[0]._get_colspec()
        parts = spec.split(".")
        tname = parts[-2]
        if len(parts) == 3:
            referent_schema = parts[0]
        else:
            referent_schema = None
        if tname != '_alembic_batch_temp':
            key = sql_schema._get_table_key(tname, referent_schema)
            if key in metadata.tables:
                t = metadata.tables[key]
                for elem in constraint.elements:
                    colname = elem._get_colspec().split(".")[-1]
                    # membership by *name* must use ``in``;
                    # contains_column() expects a Column object and does
                    # not match (or rejects) a plain string.
                    if colname not in t.c:
                        t.append_column(
                            Column(colname, sqltypes.NULLTYPE)
                        )
            else:
                Table(
                    tname, metadata,
                    *[Column(n, sqltypes.NULLTYPE) for n in
                      [elem._get_colspec().split(".")[-1]
                       for elem in constraint.elements]],
                    schema=referent_schema)

    def _create(self, op_impl):
        """Create the new table, copy the data over and swap it in.

        :param op_impl: the operations impl used to emit the statements.
        """
        self._transfer_elements_to_new_table()

        # backend hook; what it does is defined by op_impl.
        op_impl.prep_table_for_batch(self.table)
        op_impl.create_table(self.new_table)

        try:
            # only columns that have a source expression are copied;
            # columns added in this batch are left out of the INSERT.
            op_impl._exec(
                self.new_table.insert(inline=True).from_select(
                    [k for k, transfer in
                     self.column_transfers.items() if 'expr' in transfer],
                    select([
                        transfer['expr']
                        for transfer in self.column_transfers.values()
                        if 'expr' in transfer
                    ])
                )
            )
            op_impl.drop_table(self.table)
        except:
            # deliberately catches everything: the temp table is removed
            # on any failure and the original error is re-raised.
            op_impl.drop_table(self.new_table)
            raise
        else:
            op_impl.rename_table(
                "_alembic_batch_temp",
                self.table.name,
                schema=self.table.schema
            )

    def alter_column(self, table_name, column_name,
                     nullable=None,
                     server_default=False,
                     name=None,
                     type_=None,
                     autoincrement=None,
                     **kw
                     ):
        """Apply column changes to the tracked copy of ``column_name``.

        ``None`` (``False`` for ``server_default``) means "leave
        unchanged".

        :raises KeyError: if ``column_name`` is not a tracked column.
        """
        existing = self.columns[column_name]
        existing_transfer = self.column_transfers[column_name]
        if name is not None and name != column_name:
            # note that we don't change '.key' - we keep referring
            # to the renamed column by its old key in _create(). neat!
            existing.name = name
            existing_transfer["name"] = name

        if type_ is not None:
            type_ = sqltypes.to_instance(type_)
            existing.type = type_
            # a column added earlier in this batch has no source
            # expression, so there is nothing to CAST for it.
            if "expr" in existing_transfer:
                existing_transfer["expr"] = cast(
                    existing_transfer["expr"], type_)
        if nullable is not None:
            existing.nullable = nullable
        if server_default is not False:
            existing.server_default = server_default
        if autoincrement is not None:
            existing.autoincrement = bool(autoincrement)

    def add_column(self, table_name, column, **kw):
        """Track a new column; it gets no data-transfer expression."""
        # we copy the column because operations.add_column()
        # gives us a Column that is part of a Table already.
        self.columns[column.name] = column.copy(schema=self.table.schema)
        self.column_transfers[column.name] = {}

    def drop_column(self, table_name, column, **kw):
        """Stop tracking ``column`` so it is absent from the new table.

        :raises KeyError: if no column with that name is tracked.
        """
        del self.columns[column.name]
        del self.column_transfers[column.name]

    def add_constraint(self, const):
        """Track a named constraint for the new table.

        :raises ValueError: if ``const`` has no name.
        """
        if not const.name:
            raise ValueError("Constraint must have a name")
        self.named_constraints[const.name] = const

    def drop_constraint(self, const):
        """Stop tracking the named constraint ``const``.

        :raises ValueError: if ``const`` has no name or is not tracked.
        """
        if not const.name:
            raise ValueError("Constraint must have a name")
        try:
            del self.named_constraints[const.name]
        except KeyError:
            raise ValueError("No such constraint: '%s'" % const.name)

    def create_index(self, idx):
        """Track ``idx`` so it is created on the new table."""
        self.indexes[idx.name] = idx

    def drop_index(self, idx):
        """Stop tracking ``idx``.

        :raises ValueError: if no index with that name is tracked.
        """
        try:
            del self.indexes[idx.name]
        except KeyError:
            raise ValueError("No such index: '%s'" % idx.name)

    def rename_table(self, *arg, **kw):
        """Not implemented for the recreate path; always raises."""
        raise NotImplementedError("TODO")
|