448 lines
17 KiB
Python
448 lines
17 KiB
Python
# Copyright (c) 2018 Orange. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
|
# not use this file except in compliance with the License. You may obtain
|
|
# a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
# License for the specific language governing permissions and limitations
|
|
# under the License.
|
|
#
|
|
|
|
# Conflict with flake8
|
|
# pylint: disable = bad-continuation
|
|
|
|
"""A theory that contains rules that must be treated by Z3."""
|
|
|
|
import time
|
|
|
|
import logging
|
|
import six
|
|
|
|
from congress.datalog import base
|
|
from congress.datalog import compile as ast
|
|
from congress.datalog import nonrecursive
|
|
from congress.datalog import ruleset
|
|
from congress.datalog import unify
|
|
from congress import exception
|
|
from congress.z3 import typechecker
|
|
from congress.z3 import z3types
|
|
|
|
# pylint: disable = ungrouped-imports
|
|
MYPY = False
|
|
if MYPY:
|
|
# pylint: disable = unused-import
|
|
from congress.datalog import topdown # noqa
|
|
from typing import Dict, Callable, Union, List, Any, Tuple # noqa
|
|
import z3 # noqa
|
|
Z3_RESULT = Tuple[Union[bool, List[List[z3.ExprRef]]],
|
|
List[ast.Variable],
|
|
List[z3types.Z3Type]]
|
|
|
|
LOG = logging.getLogger(__name__)
|
|
|
|
Z3_ENGINE_OPTIONS = {'engine': 'datalog'}
|
|
|
|
Z3OPT = z3types.z3
|
|
|
|
INTER_COMPILE_DELAY = 60.0
|
|
|
|
|
|
def cycle_not_contained_in_z3(theories, cycles):
|
|
# type: (Dict[str, base.Theory], List[List[str]]) -> bool
|
|
"""Check that there is a true cycle not through Z3 theory
|
|
|
|
A cycle is irreducible if it contains at least one element which is not a
|
|
Z3Theory for which recursion is allowed. Cycles are represented by lists of
|
|
qualified table names.
|
|
"""
|
|
acceptables = [
|
|
th.name
|
|
for th in six.itervalues(theories)
|
|
if isinstance(th, Z3Theory)]
|
|
return any(fullname[:fullname.index(':')] not in acceptables
|
|
for cycle in cycles for fullname in cycle)
|
|
|
|
|
|
# TODO(pcregut): Object constants should evolve to use the type system
|
|
# rather than custom types.
|
|
def congress_constant(val):
|
|
"""Creates an object constant from a value using its type"""
|
|
if isinstance(val, six.string_types):
|
|
typ = ast.ObjectConstant.STRING
|
|
elif isinstance(val, int):
|
|
typ = ast.ObjectConstant.INTEGER
|
|
elif isinstance(val, float):
|
|
typ = ast.ObjectConstant.FLOAT
|
|
else:
|
|
val = str(val)
|
|
typ = ast.ObjectConstant.STRING
|
|
return ast.ObjectConstant(val, typ)
|
|
|
|
|
|
def retrieve(theory, tablename):
|
|
# type: (topdown.TopDownTheory, str) -> List[ast.Literal]
|
|
"""Retrieves all the values of an external table.
|
|
|
|
Performs a select on the theory with a query computed from the schema
|
|
of the table.
|
|
"""
|
|
arity = theory.schema.arity(tablename)
|
|
table = ast.Tablename(tablename, theory.name)
|
|
args = [ast.Variable('X' + str(i)) for i in range(arity)]
|
|
query = ast.Literal(table, args)
|
|
return theory.select(query)
|
|
|
|
|
|
class Z3Theory(nonrecursive.RuleHandlingMixin, base.Theory):
|
|
"""Theory for Z3 engine
|
|
|
|
Z3Theory is a datalog theory interpreted by the Z3 engine instead of
|
|
the usual congress internal engine.
|
|
"""
|
|
|
|
def __init__(self, name=None, abbr=None,
|
|
schema=None, theories=None, desc=None, owner=None):
|
|
super(Z3Theory, self).__init__(
|
|
name=name, abbr=abbr, theories=theories,
|
|
schema=ast.Schema() if schema is None else schema,
|
|
desc=desc, owner=owner)
|
|
LOG.info('z3theory: create %s', name)
|
|
self.kind = base.Z3_POLICY_TYPE
|
|
self.rules = ruleset.RuleSet()
|
|
self.dirty = False
|
|
self.z3context = None
|
|
Z3Context.get_context().register(self)
|
|
|
|
def select(self, query, find_all=True):
|
|
"""Performs a query"""
|
|
return self.z3context.select(self, query, find_all)
|
|
|
|
def arity(self, tablename, modal=None):
|
|
"""Arity of a table"""
|
|
return self.schema.arity(tablename)
|
|
|
|
def drop(self):
|
|
"""To call when the theory is forgotten"""
|
|
self.z3context.drop(self)
|
|
|
|
def _top_down_eval(self,
|
|
context, # type: topdown.TopDownTheory.TopDownContext
|
|
caller # type: topdown.TopDownTheory.TopDownCaller
|
|
):
|
|
# type: (...) -> bool
|
|
"""Evaluation entry point for the non recursive engine
|
|
|
|
We must compute unifiers and clear off as soon as we can
|
|
giving back control to the theory context.
|
|
Returns true if we only need one binding and it has been found,
|
|
false otherwise.
|
|
"""
|
|
raw_lit = context.literals[context.literal_index]
|
|
query_lit = raw_lit.plug(context.binding)
|
|
answers, bvars, translators = self.z3context.eval(self, query_lit)
|
|
if isinstance(answers, bool):
|
|
if answers:
|
|
return (context.theory._top_down_finish(context, caller)
|
|
and not caller.find_all)
|
|
return False
|
|
for answer in answers:
|
|
changes = []
|
|
for (val, var, trans) in six.moves.zip(answer, bvars, translators):
|
|
chg = context.binding.add(var, trans.to_os(val), None)
|
|
changes.append(chg)
|
|
context.theory._top_down_finish(context, caller)
|
|
unify.undo_all(changes)
|
|
if not caller.find_all:
|
|
return True
|
|
return False
|
|
|
|
|
|
class Z3Context(object):
|
|
"""An instance of Z3 defined first by its execution context"""
|
|
|
|
_singleton = None
|
|
|
|
def __init__(self):
|
|
self.context = Z3OPT.Fixedpoint()
|
|
self.context.set(**Z3_ENGINE_OPTIONS)
|
|
self.z3theories = {} # type: Dict[str, Z3Theory]
|
|
self.relations = {} # type: Dict[str, z3.Function]
|
|
# back pointer on all theories extracted from registered theory.
|
|
self.theories = None # type: Dict[str, topdown.TopDownTheory]
|
|
self.externals = set() # type: Set[Tuple[str, str]]
|
|
self.type_registry = z3types.TypeRegistry()
|
|
self.last_compiled = 0
|
|
|
|
def register(self, theory):
|
|
# type: (Z3Theory) -> None
|
|
"""Registers a Z3 theory in the context"""
|
|
if self.theories is None:
|
|
self.theories = theory.theories
|
|
theory.z3context = self
|
|
self.z3theories[theory.name] = theory
|
|
|
|
def drop(self, theory):
|
|
# type: (Z3Theory) -> None
|
|
"""Unregister a Z3 theory from the context"""
|
|
del self.z3theories[theory.name]
|
|
|
|
@staticmethod
|
|
def get_context():
|
|
# type: () -> Z3Context
|
|
"""Gives back the unique instance of this class.
|
|
|
|
Users should not use the class constructor but this method.
|
|
"""
|
|
if Z3Context._singleton is None:
|
|
Z3Context._singleton = Z3Context()
|
|
return Z3Context._singleton
|
|
|
|
def eval(self,
|
|
theory, # type: Z3Theory
|
|
query # type: ast.Literal
|
|
):
|
|
# type: (...) -> Z3_RESULT
|
|
"""Solves a query and gives back a raw result
|
|
|
|
Result is in Z3 ast format with a translator
|
|
"""
|
|
theories_changed = any(t.dirty for t in self.z3theories.values())
|
|
# TODO(pcregut): replace either with an option or find something
|
|
# better for the refresh of datasources.
|
|
needs_refresh = time.time() - self.last_compiled > INTER_COMPILE_DELAY
|
|
if theories_changed or needs_refresh:
|
|
# There is no reset on Z3 context. Replace with a new one.
|
|
self.context = Z3OPT.Fixedpoint()
|
|
self.context.set(**Z3_ENGINE_OPTIONS)
|
|
self.typecheck()
|
|
self.compile_all()
|
|
self.synchronize_external()
|
|
z3query = self.compile_query(theory, query)
|
|
self.context.query(z3query)
|
|
z3answer = self.context.get_answer()
|
|
answer = z3types.z3_to_array(z3answer)
|
|
typ_args = theory.schema.types(query.table.table)
|
|
variables = []
|
|
translators = []
|
|
for arg, typ_arg in six.moves.zip(query.arguments, typ_args):
|
|
if isinstance(arg, ast.Variable) and arg not in variables:
|
|
translators.append(
|
|
self.type_registry.get_translator(str(typ_arg.type)))
|
|
variables.append(arg)
|
|
return (answer, variables, translators)
|
|
|
|
def select(self, theory, query, find_all):
|
|
# type: (Z3Theory, ast.Literal, bool) -> List[ast.Literal]
|
|
"""Query a theory"""
|
|
(answer, variables, trans) = self.eval(theory, query)
|
|
pattern = [
|
|
variables.index(arg) if isinstance(arg, ast.Variable) else arg
|
|
for arg in query.arguments]
|
|
|
|
def plug(row):
|
|
"""Plugs in found values in query litteral"""
|
|
args = [
|
|
(congress_constant(trans[arg].to_os(row[arg]))
|
|
if isinstance(arg, int) else arg)
|
|
for arg in pattern]
|
|
return ast.Literal(query.table, args)
|
|
|
|
if isinstance(answer, bool):
|
|
return [query] if answer else []
|
|
if find_all:
|
|
result = [plug(row) for row in answer]
|
|
else:
|
|
result = [plug(answer[0])]
|
|
return result
|
|
|
|
def declare_table(self, theory, tablename):
|
|
"""Declares a new table in Z3 context"""
|
|
fullname = theory.name + ':' + tablename
|
|
if fullname in self.relations:
|
|
return
|
|
typ_args = theory.schema.types(tablename)
|
|
param_types = [
|
|
self.type_registry.get_type(str(tArg.type))
|
|
for tArg in typ_args]
|
|
param_types.append(Z3OPT.BoolSort())
|
|
relation = Z3OPT.Function(fullname, *param_types)
|
|
self.context.register_relation(relation)
|
|
self.relations[fullname] = relation
|
|
|
|
def declare_tables(self):
|
|
"""Declares all tables defined in Z3 context"""
|
|
for theory in six.itervalues(self.z3theories):
|
|
for tablename in theory.schema.map.keys():
|
|
self.declare_table(theory, tablename)
|
|
|
|
def declare_external_tables(self):
|
|
"""Declares tables from other theories used in Z3 context"""
|
|
def declare_for_lit(lit):
|
|
"""Declares the table of a litteral if necessary"""
|
|
service = lit.table.service
|
|
table = lit.table.table
|
|
if service is not None and service not in self.z3theories:
|
|
self.externals.add((service, table))
|
|
for theory in six.itervalues(self.z3theories):
|
|
for rules in six.itervalues(theory.rules.rules):
|
|
for rule in rules:
|
|
for lit in rule.body:
|
|
declare_for_lit(lit)
|
|
for (service, table) in self.externals:
|
|
self.declare_table(self.theories[service], table)
|
|
|
|
def compile_facts(self, theory):
|
|
# type: (Z3Theory) -> None
|
|
"""Compiles the facts of a theory in Z3 context"""
|
|
for tname, facts in six.iteritems(theory.rules.facts):
|
|
translators = [
|
|
self.type_registry.get_translator(str(arg_type.type))
|
|
for arg_type in theory.schema.types(tname)]
|
|
fullname = theory.name + ':' + tname
|
|
z3func = self.relations[fullname]
|
|
for fact in facts:
|
|
z3args = (tr.to_z3(v, strict=True)
|
|
for (v, tr) in six.moves.zip(fact, translators))
|
|
z3fact = z3func(*z3args)
|
|
self.context.fact(z3fact)
|
|
|
|
def compile_atoms(self,
|
|
theory, # type: Z3Theory
|
|
head, # type: ast.Literal
|
|
body # type: List[ast.Literal]
|
|
):
|
|
# type: (...) -> Tuple[z3.Const, z3.ExprRef, List[z3.ExprRef]]
|
|
"""Compile a list of atoms belonging to a single variable scope
|
|
|
|
As it is used mainly for rules, the head is distinguished.
|
|
"""
|
|
variables = {} # type: Dict[str, z3.Const]
|
|
z3vars = []
|
|
|
|
def compile_expr(expr, translator):
|
|
"""Compiles an expression to Z3"""
|
|
if isinstance(expr, ast.Variable):
|
|
name = expr.name
|
|
if name in variables:
|
|
return variables[name]
|
|
var = Z3OPT.Const(name, translator.type())
|
|
variables[name] = var
|
|
z3vars.append(var)
|
|
return var
|
|
elif isinstance(expr, ast.ObjectConstant):
|
|
return translator.to_z3(expr.name)
|
|
else:
|
|
raise exception.PolicyException(
|
|
"Expr {} not handled by Z3".format(expr))
|
|
|
|
def compile_atom(literal):
|
|
"""Compiles an atom in Z3"""
|
|
name = literal.table.table
|
|
svc = literal.table.service
|
|
if svc == 'builtin':
|
|
raise exception.PolicyException(
|
|
"Builtins not handled by Z3 yet")
|
|
lit_theory = theory if svc is None else self.theories[svc]
|
|
translators = [
|
|
self.type_registry.get_translator(str(arg_type.type))
|
|
for arg_type in lit_theory.schema.types(name)]
|
|
fullname = lit_theory.name + ":" + name
|
|
try:
|
|
z3func = self.relations[fullname]
|
|
z3args = (compile_expr(arg, tr)
|
|
for (arg, tr) in six.moves.zip(literal.arguments,
|
|
translators))
|
|
z3lit = z3func(*z3args)
|
|
return (Z3OPT.Not(z3lit) if literal.negated
|
|
else z3lit)
|
|
except KeyError:
|
|
raise exception.PolicyException(
|
|
"Z3: Relation %s not registered" % fullname)
|
|
|
|
z3head = compile_atom(head)
|
|
z3body = [compile_atom(atom) for atom in body]
|
|
# We give back variables explicitely and do not rely on declare_var and
|
|
# abstract. Otherwise rules are cluttered with useless variables.
|
|
return (z3vars, z3head, z3body)
|
|
|
|
def compile_rule(self, theory, rule):
|
|
# type: (Z3Theory, ast.Rule) -> None
|
|
"""compiles a single rule
|
|
|
|
:param theory: the theory containing the rule
|
|
:param rule: the rule to compile.
|
|
"""
|
|
z3vars, z3head, z3body = self.compile_atoms(theory, rule.head,
|
|
rule.body)
|
|
term1 = (z3head if z3body == []
|
|
else Z3OPT.Implies(Z3OPT.And(*z3body), z3head))
|
|
term2 = term1 if z3vars == [] else Z3OPT.ForAll(z3vars, term1)
|
|
self.context.rule(term2)
|
|
|
|
def compile_query(self, theory, literal):
|
|
# type: (Z3Theory, ast.Literal) -> z3.ExprRef
|
|
"""compiles a query litteral
|
|
|
|
:param theory: theory used as the context of the query
|
|
:param litteral: the query
|
|
:returns: an existentially quantified litteral in Z3 format.
|
|
"""
|
|
z3vars, z3head, _ = self.compile_atoms(theory, literal, [])
|
|
return z3head if z3vars == [] else Z3OPT.Exists(z3vars, z3head)
|
|
|
|
def compile_theory(self, theory):
|
|
# type: (Z3Theory) -> None
|
|
"""Compiles all the rules of a theory
|
|
|
|
:param theory: theory to compile. Will be marked clean after.
|
|
"""
|
|
self.compile_facts(theory)
|
|
for rules in six.itervalues(theory.rules.rules):
|
|
for rule in rules:
|
|
self.compile_rule(theory, rule)
|
|
theory.dirty = False
|
|
|
|
def compile_all(self):
|
|
"""Compile all Z3 theories"""
|
|
self.relations = {}
|
|
self.externals.clear()
|
|
self.declare_tables()
|
|
self.declare_external_tables()
|
|
for theory in six.itervalues(self.z3theories):
|
|
self.compile_theory(theory)
|
|
self.last_compiled = time.time()
|
|
|
|
def typecheck(self):
|
|
"""Typechecker for rules defined"""
|
|
typer = typechecker.Typechecker(
|
|
self.z3theories.values(), self.theories)
|
|
typer.type_all()
|
|
|
|
def inject(self, theoryname, tablename):
|
|
# type: (str, str) -> None
|
|
"""Inject the values of an external table in the Z3 Context.
|
|
|
|
Loops over the literal retrieved from a standard query.
|
|
"""
|
|
theory = self.theories[theoryname]
|
|
translators = [
|
|
self.type_registry.get_translator(str(arg_type.type))
|
|
for arg_type in theory.schema.types(tablename)]
|
|
fullname = theory.name + ':' + tablename
|
|
z3func = self.relations[fullname]
|
|
for lit in retrieve(theory, tablename):
|
|
z3args = (tr.to_z3(v.name, strict=True)
|
|
for (v, tr) in six.moves.zip(lit.arguments, translators))
|
|
z3fact = z3func(*z3args)
|
|
self.context.fact(z3fact)
|
|
|
|
def synchronize_external(self):
|
|
"""Synchronize all external tables"""
|
|
for (theoryname, tablename) in self.externals:
|
|
self.inject(theoryname, tablename)
|