129 lines
5.1 KiB
Python
129 lines
5.1 KiB
Python
import re
|
|
|
|
from .. import compat
|
|
from .. import util
|
|
from .base import compiles, alter_table, format_table_name, RenameTable
|
|
from .impl import DefaultImpl
|
|
from sqlalchemy.dialects.postgresql import INTEGER, BIGINT
|
|
from sqlalchemy import text, Numeric, Column
|
|
|
|
if compat.sqla_08:
|
|
from sqlalchemy.sql.expression import UnaryExpression
|
|
else:
|
|
from sqlalchemy.sql.expression import _UnaryExpression as UnaryExpression
|
|
|
|
import logging
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
class PostgresqlImpl(DefaultImpl):
|
|
__dialect__ = 'postgresql'
|
|
transactional_ddl = True
|
|
|
|
def prep_table_for_batch(self, table):
|
|
for constraint in table.constraints:
|
|
self.drop_constraint(constraint)
|
|
|
|
def compare_server_default(self, inspector_column,
|
|
metadata_column,
|
|
rendered_metadata_default,
|
|
rendered_inspector_default):
|
|
# don't do defaults for SERIAL columns
|
|
if metadata_column.primary_key and \
|
|
metadata_column is metadata_column.table._autoincrement_column:
|
|
return False
|
|
|
|
conn_col_default = rendered_inspector_default
|
|
|
|
if None in (conn_col_default, rendered_metadata_default):
|
|
return conn_col_default != rendered_metadata_default
|
|
|
|
if metadata_column.server_default is not None and \
|
|
isinstance(metadata_column.server_default.arg,
|
|
compat.string_types) and \
|
|
not re.match(r"^'.+'$", rendered_metadata_default) and \
|
|
not isinstance(inspector_column.type, Numeric):
|
|
# don't single quote if the column type is float/numeric,
|
|
# otherwise a comparison such as SELECT 5 = '5.0' will fail
|
|
rendered_metadata_default = "'%s'" % rendered_metadata_default
|
|
|
|
return not self.connection.scalar(
|
|
"SELECT %s = %s" % (
|
|
conn_col_default,
|
|
rendered_metadata_default
|
|
)
|
|
)
|
|
|
|
def autogen_column_reflect(self, inspector, table, column_info):
|
|
if column_info.get('default') and \
|
|
isinstance(column_info['type'], (INTEGER, BIGINT)):
|
|
seq_match = re.match(
|
|
r"nextval\('(.+?)'::regclass\)",
|
|
column_info['default'])
|
|
if seq_match:
|
|
info = inspector.bind.execute(text(
|
|
"select c.relname, a.attname "
|
|
"from pg_class as c join pg_depend d on d.objid=c.oid and "
|
|
"d.classid='pg_class'::regclass and "
|
|
"d.refclassid='pg_class'::regclass "
|
|
"join pg_class t on t.oid=d.refobjid "
|
|
"join pg_attribute a on a.attrelid=t.oid and "
|
|
"a.attnum=d.refobjsubid "
|
|
"where c.relkind='S' and c.relname=:seqname"
|
|
), seqname=seq_match.group(1)).first()
|
|
if info:
|
|
seqname, colname = info
|
|
if colname == column_info['name']:
|
|
log.info(
|
|
"Detected sequence named '%s' as "
|
|
"owned by integer column '%s(%s)', "
|
|
"assuming SERIAL and omitting" % (
|
|
seqname, table.name, colname
|
|
))
|
|
# sequence, and the owner is this column,
|
|
# its a SERIAL - whack it!
|
|
del column_info['default']
|
|
|
|
def correct_for_autogen_constraints(self, conn_unique_constraints,
|
|
conn_indexes,
|
|
metadata_unique_constraints,
|
|
metadata_indexes):
|
|
conn_uniques_by_name = dict(
|
|
(c.name, c) for c in conn_unique_constraints)
|
|
conn_indexes_by_name = dict(
|
|
(c.name, c) for c in conn_indexes)
|
|
|
|
# TODO: if SQLA 1.0, make use of "duplicates_constraint"
|
|
# metadata
|
|
doubled_constraints = dict(
|
|
(name, (conn_uniques_by_name[name], conn_indexes_by_name[name]))
|
|
for name in set(conn_uniques_by_name).intersection(
|
|
conn_indexes_by_name)
|
|
)
|
|
for name, (uq, ix) in doubled_constraints.items():
|
|
conn_indexes.remove(ix)
|
|
|
|
for idx in list(metadata_indexes):
|
|
if idx.name in conn_indexes_by_name:
|
|
continue
|
|
if compat.sqla_08:
|
|
exprs = idx.expressions
|
|
else:
|
|
exprs = idx.columns
|
|
for expr in exprs:
|
|
if not isinstance(expr, (Column, UnaryExpression)):
|
|
util.warn(
|
|
"autogenerate skipping functional index %s; "
|
|
"not supported by SQLAlchemy reflection" % idx.name
|
|
)
|
|
metadata_indexes.discard(idx)
|
|
|
|
|
|
@compiles(RenameTable, "postgresql")
|
|
def visit_rename_table(element, compiler, **kw):
|
|
return "%s RENAME TO %s" % (
|
|
alter_table(compiler, element.table_name, element.schema),
|
|
format_table_name(compiler, element.new_table_name, None)
|
|
)
|