congress/congress/z3/z3theory.py

458 lines
18 KiB
Python

# Copyright (c) 2018 Orange. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# Conflict with flake8
# pylint: disable = bad-continuation
"""A theory that contains rules that must be treated by Z3."""
import time
import logging
import six
from congress.datalog import base
from congress.datalog import compile as ast
from congress.datalog import nonrecursive
from congress.datalog import ruleset
from congress.datalog import unify
from congress import exception
from congress.z3 import typechecker
from congress.z3 import z3builtins
from congress.z3 import z3types
# pylint: disable = ungrouped-imports
MYPY = False
if MYPY:
# pylint: disable = unused-import
from congress.datalog import topdown # noqa
from mypy_extensions import TypedDict # noqa
from typing import Dict, Callable, Optional, Union, List, Any, Tuple # noqa
import z3 # noqa
Z3_RESULT = Tuple[Union[bool, List[List[z3.ExprRef]]],
List[ast.Variable],
List[z3types.Z3Type]]
# Module-level logger.
LOG = logging.getLogger(__name__)
# Options applied to every Z3 Fixedpoint object created by Z3Context.
Z3_ENGINE_OPTIONS = {'engine': 'datalog'}
# The ``z3`` attribute of z3types: every Z3 API call made in this module
# goes through this name.
Z3OPT = z3types.z3
# Maximum age, in seconds, of the compiled Z3 state. Z3Context.eval compares
# it with the time elapsed since the last compilation and rebuilds everything
# when it is exceeded.
INTER_COMPILE_DELAY = 60.0
def cycle_not_contained_in_z3(theories, cycles):
    # type: (Dict[str, base.Theory], List[List[str]]) -> bool
    """Tell whether some cycle goes through a theory that is not a Z3Theory

    Recursion is only allowed inside Z3 theories. A cycle is therefore
    irreducible as soon as one of its tables belongs to a theory that is not
    a Z3Theory.

    :param theories: the known theories indexed by name
    :param cycles: the cycles to check, each one given as a list of
        qualified table names (``theory:table``)
    :returns: True if at least one table of at least one cycle is owned by
        a theory that is not a Z3Theory, False otherwise
    :raises ValueError: if an inspected table name contains no ``:``
    """
    z3_names = [
        theory.name
        for theory in six.itervalues(theories)
        if isinstance(theory, Z3Theory)]
    for cycle in cycles:
        for fullname in cycle:
            # The owner theory is the part before the first colon.
            separator = fullname.index(':')
            if fullname[:separator] not in z3_names:
                return True
    return False
# TODO(pcregut): Object constants should evolve to use the type system
# rather than custom types.
def congress_constant(val):
    """Creates an object constant from a value using its type

    Strings, integers and floats are wrapped with the matching
    ObjectConstant kind. Any other value is replaced by its ``str()``
    representation and wrapped as a STRING constant.

    :param val: the Python value to wrap
    :returns: an ast.ObjectConstant built from the value and its kind
    """
    if isinstance(val, six.string_types):
        typ = ast.ObjectConstant.STRING
    elif isinstance(val, six.integer_types):
        # six.integer_types also covers Python 2 ``long`` values. With a
        # plain ``int`` check they fell through to the last branch and were
        # turned into strings.
        typ = ast.ObjectConstant.INTEGER
    elif isinstance(val, float):
        typ = ast.ObjectConstant.FLOAT
    else:
        val = str(val)
        typ = ast.ObjectConstant.STRING
    return ast.ObjectConstant(val, typ)
def retrieve(theory, tablename):
    # type: (topdown.TopDownTheory, str) -> List[ast.Literal]
    """Fetch all the rows of an external table.

    A literal over the table with one fresh variable per column (``X0``,
    ``X1``, ...) is built from the arity recorded in the theory schema and
    submitted to the theory ``select`` method.

    :param theory: the theory owning the table
    :param tablename: the unqualified name of the table
    :returns: whatever ``theory.select`` answers for that literal
    """
    column_count = theory.schema.arity(tablename)
    qualified = ast.Tablename(tablename, theory.name)
    columns = []
    for index in range(column_count):
        columns.append(ast.Variable('X' + str(index)))
    return theory.select(ast.Literal(qualified, columns))
class Z3Theory(nonrecursive.RuleHandlingMixin, base.Theory):
    """Theory for Z3 engine

    Z3Theory is a datalog theory interpreted by the Z3 engine instead of
    the usual congress internal engine. Every instance registers itself in
    the shared Z3Context and delegates its queries to it.
    """

    def __init__(self, name=None, abbr=None,
                 schema=None, theories=None, desc=None, owner=None):
        """Creates the theory and registers it in the shared Z3 context

        An empty schema is created when none is provided.
        """
        if schema is None:
            schema = ast.Schema()
        super(Z3Theory, self).__init__(
            name=name, abbr=abbr, theories=theories, schema=schema,
            desc=desc, owner=owner)
        LOG.info('z3theory: create %s', name)
        self.kind = base.Z3_POLICY_TYPE
        self.rules = ruleset.RuleSet()
        # Z3Context.eval recompiles everything when a theory is dirty.
        self.dirty = False
        # Set by Z3Context.register below.
        self.z3context = None
        Z3Context.get_context().register(self)

    def select(self, query, find_all=True):
        """Performs a query by delegating it to the Z3 context"""
        return self.z3context.select(self, query, find_all)

    def arity(self, tablename, modal=None):
        """Arity of a table as recorded in the schema (modal is ignored)"""
        return self.schema.arity(tablename)

    def drop(self):
        """To call when the theory is forgotten"""
        self.z3context.drop(self)

    def _top_down_eval(self,
                       context,  # type: topdown.TopDownTheory.TopDownContext
                       caller  # type: topdown.TopDownTheory.TopDownCaller
                       ):
        # type: (...) -> bool
        """Evaluation entry point for the non recursive engine

        We must compute unifiers and clear off as soon as we can
        giving back control to the theory context.

        Returns true if we only need one binding and it has been found,
        false otherwise.
        """
        current = context.literals[context.literal_index]
        query_lit = current.plug(context.binding)
        answers, bvars, translators = self.z3context.eval(self, query_lit)

        if isinstance(answers, bool):
            # Boolean answer: there is no variable to bind.
            if not answers:
                return False
            finished = context.theory._top_down_finish(context, caller)
            return finished and not caller.find_all

        for answer in answers:
            # Bind every variable to its value translated back from Z3,
            # let the calling theory continue, then restore the binding.
            changes = [
                context.binding.add(var, trans.to_os(val), None)
                for (val, var, trans)
                in six.moves.zip(answer, bvars, translators)]
            context.theory._top_down_finish(context, caller)
            unify.undo_all(changes)
            if not caller.find_all:
                return True
        return False
class Z3Context(object):
    """An instance of Z3 defined first by its execution context

    The context owns a single Z3 Fixedpoint object shared by all the
    registered Z3 theories. The rules and facts of those theories and the
    contents of the external tables they use are (re)compiled into it by
    ``eval`` whenever the compiled state is considered stale.
    """

    _singleton = None

    def __init__(self):
        self.context = Z3OPT.Fixedpoint()
        self.context.set(**Z3_ENGINE_OPTIONS)
        # Registered Z3 theories indexed by name.
        self.z3theories = {}  # type: Dict[str, Z3Theory]
        # Declared Z3 relations indexed by qualified name (theory:table).
        self.relations = {}  # type: Dict[str, z3.Function]
        # back pointer on all theories extracted from registered theory.
        self.theories = None  # type: Dict[str, topdown.TopDownTheory]
        # (theory name, table name) of the non Z3 tables used by Z3 rules.
        self.externals = set()  # type: Set[Tuple[str, str]]
        self.type_registry = z3types.TypeRegistry()
        # Timestamp (time.time()) of the last full compilation. Zero forces
        # a compilation at the next evaluation.
        self.last_compiled = 0

    def register(self, theory):
        # type: (Z3Theory) -> None
        """Registers a Z3 theory in the context

        The first registered theory also provides the dictionary of all
        theories. The compiled state is invalidated so that the tables of
        the new theory are declared at the next evaluation.
        """
        if self.theories is None:
            self.theories = theory.theories
        theory.z3context = self
        self.z3theories[theory.name] = theory
        # The set of theories changed: force a rebuild at the next eval.
        self.last_compiled = 0

    def drop(self, theory):
        # type: (Z3Theory) -> None
        """Unregister a Z3 theory from the context

        :raises KeyError: if the theory is not registered
        """
        del self.z3theories[theory.name]
        # Rules and facts of the dropped theory are still loaded in the
        # Fixedpoint object: force a rebuild at the next eval so that they
        # cannot leak into later answers.
        self.last_compiled = 0

    @staticmethod
    def get_context():
        # type: () -> Z3Context
        """Gives back the unique instance of this class.

        Users should not use the class constructor but this method.
        """
        if Z3Context._singleton is None:
            Z3Context._singleton = Z3Context()
        return Z3Context._singleton

    def eval(self,
             theory,  # type: Z3Theory
             query  # type: ast.Literal
             ):
        # type: (...) -> Z3_RESULT
        """Solves a query and gives back a raw result

        Result is in Z3 ast format with a translator

        Everything is recompiled in a fresh Fixedpoint object first when a
        registered theory is dirty or when the last compilation is older
        than INTER_COMPILE_DELAY seconds.

        :param theory: theory used as the context of the query
        :param query: the literal to solve
        :returns: a triple made of the answer converted by
            ``z3types.z3_to_array``, the distinct variables of the query in
            order of first occurrence and, for each of them, the translator
            of the column where it first occurs.
        """
        theories_changed = any(t.dirty for t in self.z3theories.values())
        # TODO(pcregut): replace either with an option or find something
        # better for the refresh of datasources.
        needs_refresh = time.time() - self.last_compiled > INTER_COMPILE_DELAY
        if theories_changed or needs_refresh:
            # There is no reset on Z3 context. Replace with a new one.
            self.context = Z3OPT.Fixedpoint()
            self.context.set(**Z3_ENGINE_OPTIONS)
            type_env = self.typecheck()
            self.compile_all(type_env)
            self.synchronize_external()
        z3query = self.compile_query(theory, query)
        self.context.query(z3query)
        z3answer = self.context.get_answer()
        answer = z3types.z3_to_array(z3answer)
        typ_args = theory.schema.types(query.table.table)
        variables = []  # type: List[ast.Variable]
        translators = []  # type: List[z3types.Z3Type]
        for arg, typ_arg in six.moves.zip(query.arguments, typ_args):
            # A repeated variable is only reported once.
            if isinstance(arg, ast.Variable) and arg not in variables:
                translators.append(
                    self.type_registry.get_translator(str(typ_arg.type)))
                variables.append(arg)
        return (answer, variables, translators)

    def select(self, theory, query, find_all):
        # type: (Z3Theory, ast.Literal, bool) -> List[ast.Literal]
        """Query a theory

        :param theory: theory used as the context of the query
        :param query: the literal to solve
        :param find_all: when false, only the first solution is given back
        :returns: the instances of the query literal that hold. A boolean
            answer gives back the query itself (true) or nothing (false).
        """
        (answer, variables, trans) = self.eval(theory, query)
        # Variables are replaced by their rank in the answer rows. Other
        # arguments are kept as they are.
        pattern = [
            variables.index(arg) if isinstance(arg, ast.Variable) else arg
            for arg in query.arguments]

        def plug(row):
            """Plugs in found values in query literal"""
            args = [
                (congress_constant(trans[arg].to_os(row[arg]))
                 if isinstance(arg, int) else arg)
                for arg in pattern]
            return ast.Literal(query.table, args)

        if isinstance(answer, bool):
            return [query] if answer else []
        if find_all:
            return [plug(row) for row in answer]
        # Only the first solution is wanted. Slicing instead of indexing
        # keeps an empty list of rows from raising IndexError.
        return [plug(row) for row in answer[:1]]

    def declare_table(self, theory, tablename):
        """Declares a new table in Z3 context

        Nothing is done if the relation is already declared. Otherwise a
        boolean Z3 function typed after the schema of the table is
        registered as a relation under the name ``theory:table``.
        """
        fullname = theory.name + ':' + tablename
        if fullname in self.relations:
            return
        typ_args = theory.schema.types(tablename)
        param_types = [
            self.type_registry.get_type(str(tArg.type))
            for tArg in typ_args]
        # A relation is a function from its columns to booleans.
        param_types.append(Z3OPT.BoolSort())
        relation = Z3OPT.Function(fullname, *param_types)
        self.context.register_relation(relation)
        self.relations[fullname] = relation

    def declare_tables(self):
        """Declares all tables defined in Z3 context"""
        for theory in six.itervalues(self.z3theories):
            for tablename in theory.schema.map.keys():
                self.declare_table(theory, tablename)

    def declare_external_tables(self):
        """Declares tables from other theories used in Z3 context

        The bodies of all the rules of the Z3 theories are scanned. Tables
        qualified by a service that is neither ``builtin`` nor a Z3 theory
        are recorded in ``self.externals`` and declared.
        """
        def declare_for_lit(lit):
            """Records the table of a literal if it is external"""
            service = lit.table.service
            table = lit.table.table
            if (service is not None and service != 'builtin' and
                    service not in self.z3theories):
                self.externals.add((service, table))

        for theory in six.itervalues(self.z3theories):
            for rules in six.itervalues(theory.rules.rules):
                for rule in rules:
                    for lit in rule.body:
                        declare_for_lit(lit)

        for (service, table) in self.externals:
            self.declare_table(self.theories[service], table)

    def compile_facts(self, theory):
        # type: (Z3Theory) -> None
        """Compiles the facts of a theory in Z3 context"""
        for tname, facts in six.iteritems(theory.rules.facts):
            translators = [
                self.type_registry.get_translator(str(arg_type.type))
                for arg_type in theory.schema.types(tname)]
            fullname = theory.name + ':' + tname
            z3func = self.relations[fullname]
            for fact in facts:
                z3args = (tr.to_z3(v, strict=True)
                          for (v, tr) in six.moves.zip(fact, translators))
                z3fact = z3func(*z3args)
                self.context.fact(z3fact)

    def compile_atoms(self,
                      type_env,
                      theory,  # type: Z3Theory
                      head,  # type: ast.Literal
                      body  # type: List[ast.Literal]
                      ):
        # type: (...) -> Tuple[List[z3.Const], z3.ExprRef, List[z3.ExprRef]]
        """Compile a list of atoms belonging to a single variable scope

        As it is used mainly for rules, the head is distinguished.

        :param type_env: typing of the builtin atoms, indexed by the
            position of the atom in the body
        :param theory: theory used for the unqualified table names
        :param head: the distinguished atom
        :param body: the other atoms
        :returns: the Z3 variables created in order of first occurrence,
            the compiled head and the list of compiled body atoms
        :raises PolicyException: on an argument that is neither a variable
            nor an object constant or on an unregistered relation
        """
        variables = {}  # type: Dict[str, z3.Const]
        z3vars = []

        def compile_expr(expr, translator):
            """Compiles an expression to Z3"""
            if isinstance(expr, ast.Variable):
                name = expr.name
                if name in variables:
                    return variables[name]
                # First occurrence: the variable is typed after the column
                # where it appears.
                var = Z3OPT.Const(name, translator.type())
                variables[name] = var
                z3vars.append(var)
                return var
            elif isinstance(expr, ast.ObjectConstant):
                return translator.to_z3(expr.name)
            else:
                raise exception.PolicyException(
                    "Expr {} not handled by Z3".format(expr))

        def compile_atom(literal, pos=-1):
            """Compiles an atom in Z3"""
            name = literal.table.table
            svc = literal.table.service
            if svc == 'builtin':
                translators = [
                    self.type_registry.get_translator(str(arg_type['type']))
                    for arg_type in type_env[pos]
                ]
                fullname = 'builtin:'+name
            else:
                lit_theory = theory if svc is None else self.theories[svc]
                translators = [
                    self.type_registry.get_translator(str(arg_type.type))
                    for arg_type in lit_theory.schema.types(name)]
                fullname = lit_theory.name + ":" + name
            try:
                z3func = (
                    z3builtins.BUILTINS[name].z3 if svc == 'builtin'
                    else self.relations[fullname])
                z3args = (compile_expr(arg, tr)
                          for (arg, tr) in six.moves.zip(literal.arguments,
                                                         translators))
                z3lit = z3func(*z3args)
                return (Z3OPT.Not(z3lit) if literal.negated
                        else z3lit)
            except KeyError:
                raise exception.PolicyException(
                    "Z3: Relation %s not registered" % fullname)

        z3head = compile_atom(head)
        z3body = [compile_atom(atom, pos) for (pos, atom) in enumerate(body)]
        # We give back variables explicitly and do not rely on declare_var and
        # abstract. Otherwise rules are cluttered with useless variables.
        return (z3vars, z3head, z3body)

    def compile_rule(self, type_env, theory, rule):
        # type: (typechecker.GEN_TYPE_ENV, Z3Theory, ast.Rule) -> None
        """compiles a single rule

        :param type_env: typing environment indexed by rule identifier
        :param theory: the theory containing the rule
        :param rule: the rule to compile.
        """
        z3vars, z3head, z3body = self.compile_atoms(
            type_env.get(rule.id, {}), theory, rule.head, rule.body)
        # A rule without body is the head alone. A rule without variable
        # does not need a quantifier.
        term1 = (Z3OPT.Implies(Z3OPT.And(*z3body), z3head) if z3body
                 else z3head)
        term2 = Z3OPT.ForAll(z3vars, term1) if z3vars else term1
        self.context.rule(term2)

    def compile_query(self, theory, literal):
        # type: (Z3Theory, ast.Literal) -> z3.ExprRef
        """compiles a query literal

        :param theory: theory used as the context of the query
        :param literal: the query
        :returns: an existentially quantified literal in Z3 format.
        """
        z3vars, z3head, _ = self.compile_atoms({}, theory, literal, [])
        return Z3OPT.Exists(z3vars, z3head) if z3vars else z3head

    def compile_theory(self, type_env, theory):
        # type: (typechecker.GEN_TYPE_ENV, Z3Theory) -> None
        """Compiles all the rules of a theory

        :param type_env: typing environment indexed by rule identifier
        :param theory: theory to compile. Will be marked clean after.
        """
        self.compile_facts(theory)
        for rules in six.itervalues(theory.rules.rules):
            for rule in rules:
                self.compile_rule(type_env, theory, rule)
        theory.dirty = False

    def compile_all(self, type_env):
        """Compile all Z3 theories

        Known relations and external tables are forgotten and declared
        again before the compilation. The compilation time is recorded.
        """
        self.relations = {}
        self.externals.clear()
        self.declare_tables()
        self.declare_external_tables()
        for theory in six.itervalues(self.z3theories):
            self.compile_theory(type_env, theory)
        self.last_compiled = time.time()

    def typecheck(self):
        """Typechecker for rules defined"""
        typer = typechecker.Typechecker(
            self.z3theories.values(), self.theories)
        return typer.type_all()

    def inject(self, theoryname, tablename):
        # type: (str, str) -> None
        """Inject the values of an external table in the Z3 Context.

        Loops over the literal retrieved from a standard query.
        """
        theory = self.theories[theoryname]
        translators = [
            self.type_registry.get_translator(str(arg_type.type))
            for arg_type in theory.schema.types(tablename)]
        fullname = theory.name + ':' + tablename
        z3func = self.relations[fullname]
        for lit in retrieve(theory, tablename):
            z3args = (tr.to_z3(v.name, strict=True)
                      for (v, tr) in six.moves.zip(lit.arguments, translators))
            z3fact = z3func(*z3args)
            self.context.fact(z3fact)

    def synchronize_external(self):
        """Synchronize all external tables"""
        for (theoryname, tablename) in self.externals:
            self.inject(theoryname, tablename)