diff --git a/frame/tools/target_language.py b/frame/tools/target_language.py new file mode 100644 index 0000000..8d85f6d --- /dev/null +++ b/frame/tools/target_language.py @@ -0,0 +1,300 @@ +import typing +from abc import ABC, abstractmethod + + +class TargetLanguage(ABC): + """Abstract base class for target language operator translations.""" + + @abstractmethod + def translate_arith_op(self, op: str) -> str: + """Translate arithmetic operators like +, -, *, /, %, ^""" + pass + + @abstractmethod + def translate_logical_op(self, op: str) -> str: + """Translate comparison operators like ==, >, <, >=, <=, !=""" + pass + + @abstractmethod + def translate_boolean_op(self, op: str) -> str: + """Translate boolean operators like And, Or, Implies, Not""" + pass + + @abstractmethod + def translate_quantifier(self, quantifier: str) -> str: + """Translate quantifiers like ForAll, Exists""" + pass + + @abstractmethod + def translate_boolean_literal(self, literal: str) -> str: + """Translate boolean literals like True, False""" + pass + + @abstractmethod + def format_function_call(self, function_name: str, args: typing.List[str]) -> str: + """Format a function call with given name and arguments""" + pass + + @abstractmethod + def format_quantified_expression(self, quantifier: str, bound_vars: typing.List[str], body: str) -> str: + """Format a quantified expression with bound variables and body""" + pass + + @abstractmethod + def get_preamble(self) -> str: + """Get preamble code to add at the beginning of the program""" + pass + + +class SMTLibTargetLanguage(TargetLanguage): + """Default SMT-Lib target language implementation.""" + + def translate_arith_op(self, op: str) -> str: + arith_op_map = { + "+": "+", + "-": "-", + "*": "*", + "/": "div", + "%": "mod", + "^": "^" + } + if op not in arith_op_map: + raise ValueError(f"Unknown arithmetic operator: {op}") + return arith_op_map[op] + + def translate_logical_op(self, op: str) -> str: + logical_op_map = { + "==": "=", + ">": ">", + "<": "<", + ">=": ">=", + "<=": "<=", + "!=": "distinct" + } + if op not in logical_op_map: + raise ValueError(f"Unknown logical operator: {op}") + return logical_op_map[op] + + def translate_boolean_op(self, op: str) -> str: + boolean_op_map = { + "And": "and", + "Or": "or", + "Implies": "=>", + "Not": "not" + } + if op not in boolean_op_map: + raise ValueError(f"Unknown boolean operator: {op}") + return boolean_op_map[op] + + def translate_quantifier(self, quantifier: str) -> str: + quantifier_map = { + "ForAll": "forall", + "Exists": "exists" + } + if quantifier not in quantifier_map: + raise ValueError(f"Unknown quantifier: {quantifier}") + return quantifier_map[quantifier] + + def translate_boolean_literal(self, literal: str) -> str: + literal_map = { + "True": "true", + "False": "false" + } + if literal not in literal_map: + raise ValueError(f"Unknown boolean literal: {literal}") + return literal_map[literal] + + def format_function_call(self, function_name: str, args: typing.List[str]) -> str: + if not args: + return function_name + args_str = " ".join(args) + return f"({function_name} {args_str})" + + def format_quantified_expression(self, quantifier: str, bound_vars: typing.List[str], body: str) -> str: + # SMT-Lib specific logic for quantified expressions + bound_vars_str = " ".join([f"({var} Int)" for var in bound_vars]) + + # Add constraints for variables >= 0 + constraints = [f"(<= 0 {var})" for var in bound_vars] + + if quantifier == "forall": + if constraints: + all_constraints = " ".join(constraints) + body = f"(=> {all_constraints} {body})" + elif quantifier == "exists": + if constraints: + all_constraints = " ".join(constraints) + body = f"(and {all_constraints} {body})" + + return f"({quantifier} ({bound_vars_str}) {body})" + + def get_preamble(self) -> str: + """SMT-Lib doesn't need any preamble""" + return "" + + +class ZnTargetLanguage(TargetLanguage): + """Z_n finite field target language implementation.""" + + def __init__(self, n: int): + """Initialize with modulus n for Z_n finite field""" + if n <= 1: + raise ValueError(f"Modulus n must be greater than 1, got {n}") + self.n = n + self.is_prime = self._is_prime(n) + + def _is_prime(self, n: int) -> bool: + """Check if n is prime""" + if n < 2: + return False + if n == 2: + return True + if n % 2 == 0: + return False + for i in range(3, int(n**0.5) + 1, 2): + if n % i == 0: + return False + return True + + def translate_arith_op(self, op: str) -> str: + arith_op_map = { + "+": "zn_add", + "-": "zn_sub", + "*": "zn_mul", + "/": "zn_div" if self.is_prime else "div", # Only support division for prime fields + "%": "mod", # Regular modulo for other uses + "^": "zn_pow" # Modular exponentiation + } + if op not in arith_op_map: + raise ValueError(f"Unknown arithmetic operator: {op}") + return arith_op_map[op] + + def translate_logical_op(self, op: str) -> str: + # Logical operations remain the same + logical_op_map = { + "==": "=", + ">": ">", + "<": "<", + ">=": ">=", + "<=": "<=", + "!=": "distinct" + } + if op not in logical_op_map: + raise ValueError(f"Unknown logical operator: {op}") + return logical_op_map[op] + + def translate_boolean_op(self, op: str) -> str: + # Boolean operations remain the same + boolean_op_map = { + "And": "and", + "Or": "or", + "Implies": "=>", + "Not": "not" + } + if op not in boolean_op_map: + raise ValueError(f"Unknown boolean operator: {op}") + return boolean_op_map[op] + + def translate_quantifier(self, quantifier: str) -> str: + # Quantifiers remain the same + quantifier_map = { + "ForAll": "forall", + "Exists": "exists" + } + if quantifier not in quantifier_map: + raise ValueError(f"Unknown quantifier: {quantifier}") + return quantifier_map[quantifier] + + def translate_boolean_literal(self, literal: str) -> str: + # Boolean literals remain the same + literal_map = { + "True": "true", + "False": "false" + } + if literal not in literal_map: + raise ValueError(f"Unknown boolean literal: {literal}") + return literal_map[literal] + + def format_function_call(self, function_name: str, args: typing.List[str]) -> str: + if not args: + return function_name + args_str = " ".join(args) + return f"({function_name} {args_str})" + + def format_quantified_expression(self, quantifier: str, bound_vars: typing.List[str], body: str) -> str: + # For Z_n, variables are constrained to [0, n-1] + bound_vars_str = " ".join([f"({var} Int)" for var in bound_vars]) + + # Add constraints for variables in range [0, n-1] + # Create individual constraints but don't wrap in (and ...) yet + constraints = [] + for var in bound_vars: + constraints.append(f"(<= 0 {var})") + constraints.append(f"(< {var} {self.n})") + + # Combine constraint with body based on quantifier + if quantifier == "forall": + # For forall: (=> (and constraints...) body) + if len(constraints) == 1: + constraint_expr = constraints[0] + else: + constraint_expr = f"(and {' '.join(constraints)})" + body = f"(=> {constraint_expr} {body})" + elif quantifier == "exists": + # For exists: (and constraints... body) + if len(constraints) == 0: + # No constraints, just body + pass + elif len(constraints) == 1: + body = f"(and {constraints[0]} {body})" + else: + all_parts = constraints + [body] + body = f"(and {' '.join(all_parts)})" + + return f"({quantifier} ({bound_vars_str}) {body})" + + def _get_basic_operations(self) -> str: + """Generate basic Z_n operations (add, sub, mul, pow)""" + return f"""; Z_{self.n} finite field operations +(define-fun zn_add ((x Int) (y Int)) Int (mod (+ x y) {self.n})) +(define-fun zn_sub ((x Int) (y Int)) Int (mod (- x y) {self.n})) +(define-fun zn_mul ((x Int) (y Int)) Int (mod (* x y) {self.n})) +(define-fun zn_pow ((x Int) (y Int)) Int (mod (^ x y) {self.n})) +""" + + def _get_inverse_function(self) -> str: + """Generate modular inverse function for prime fields""" + if not self.is_prime: + return "" + + inverse_def = "(define-fun zn_inv ((x Int)) Int\n" + inverse_def += " (ite (= x 0) 0\n" # 0 has no inverse + + for i in range(1, self.n): + for j in range(1, self.n): + if (i * j) % self.n == 1: + inverse_def += f" (ite (= x {i}) {j}\n" + break + + inverse_def += " 0" + ")" * self.n + ")\n" + return inverse_def + + def _get_division_function(self) -> str: + """Generate division function using modular inverse""" + if not self.is_prime: + return f"; Note: Division not supported for non-prime modulus {self.n}\n" + + return f"""(define-fun zn_div ((x Int) (y Int)) Int + (ite (= y 0) + 0 ; Division by zero returns 0 (undefined behavior) + (mod (* x (zn_inv y)) {self.n}))) +""" + + def get_preamble(self) -> str: + """Generate Z_n finite field function definitions""" + parts = [ + self._get_basic_operations(), + self._get_inverse_function(), # Must come before division + self._get_division_function() + ] + return "".join(part for part in parts if part.strip()) \ No newline at end of file diff --git a/frame/tools/z3_dsl.py b/frame/tools/z3_dsl.py index 63885db..0398f16 100644 --- a/frame/tools/z3_dsl.py +++ b/frame/tools/z3_dsl.py @@ -4,9 +4,12 @@ from parglare.parser import LRStackNode from frame.tools.grammar import Grammar from frame.tools.z3_runtime import Z3DslRuntime +from frame.tools.target_language import TargetLanguage, SMTLibTargetLanguage, ZnTargetLanguage from collections import namedtuple from abc import ABC, abstractmethod + + Z3DslRunResult = namedtuple( "Z3DslRunResult", [ @@ -61,6 +64,12 @@ def __repr__(self): def parent_program(self) -> typing.Optional['Z3Program']: return self._parent_executable + @property + def target_language(self) -> TargetLanguage: + if self._parent_executable is not None: + return self._parent_executable.target_language + return SMTLibTargetLanguage() + @property def params(self) -> int: return 0 @@ -520,14 +529,12 @@ def parse(context: LRStackNode, nodes) -> 'ArithExprOp': # assert isinstance(nodes[2], ArithExpr), "Compiler Bug: Invalid arith expr op" assert isinstance(nodes[1], str), "Compiler Bug: Invalid arith expr op" ops_available = ["+", "-", "*", "/", "%", "^"] - assert nodes[1] in ops_available, f"Compiler Bug: Invalid arith operator {nodes[1].value}" + assert nodes[1] in ops_available, f"Compiler Bug: Invalid arith operator {nodes[1]}" arith_expr_op = ArithExprOp(nodes[1]) arith_expr_op._program_code = context.input_str[context.start_position:context.end_position] - if nodes[1] == "/": - nodes[1] = "div" - elif nodes[1] == "%": - nodes[1] = "mod" - nodes[1] = LeafNode(nodes[1], True) + # Store original operator, translation will happen in return_expr() + original_op = nodes[1] + nodes[1] = LeafNode(original_op, True) arith_expr_op.add_child(nodes[1]) arith_expr_op.add_child(nodes[0]) arith_expr_op.add_child(nodes[2]) @@ -538,6 +545,20 @@ def __init__(self, op: str): assert op in ["+", "-", "*", "/", "%", "^"], f"Compiler Bug: Invalid arith operator {op}" self.op = op + def return_expr(self): + assert len(self._children) == 3, "Compiler Bug: Invalid arith expr op" + assert isinstance(self._children[0], LeafNode), "Compiler Bug: Invalid arith expr op" + assert isinstance(self._children[1], Z3ComposableDslParseNode), "Compiler Bug: Invalid arith expr op" + assert isinstance(self._children[2], Z3ComposableDslParseNode), "Compiler Bug: Invalid arith expr op" + original_op = self._children[0].value + translated_op = self.target_language.translate_arith_op(original_op) + left_expr = self._children[1].return_expr() + right_expr = self._children[2].return_expr() + return self.target_language.format_function_call(translated_op, [left_expr, right_expr]) + + def return_pred(self): + return self.return_expr() + class ArithExprCompund(Z3ComposableDslParseNode): def parse(context: LRStackNode, nodes) -> 'ArithExprCompund': assert isinstance(nodes, list), "Compiler Bug: Invalid arith expr compound" @@ -591,13 +612,13 @@ def parse(context: LRStackNode, nodes) -> 'LogicalExprOp': # assert isinstance(nodes[0], LogicalExpr), "Compiler Bug: Invalid logical expr op" # assert isinstance(nodes[2], LogicalExpr), "Compiler Bug: Invalid logical expr op" assert isinstance(nodes[1], str), "Compiler Bug: Invalid logical expr op" - ops_available = ["==", ">", "<", ">=", "<=", "==", "!="] - assert nodes[1] in ops_available, f"Compiler Bug: Invalid logical operator {nodes[1].value}" + ops_available = ["==", ">", "<", ">=", "<=", "!="] + assert nodes[1] in ops_available, f"Compiler Bug: Invalid logical operator {nodes[1]}" logical_expr_op = LogicalExprOp(nodes[1]) logical_expr_op._program_code = context.input_str[context.start_position:context.end_position] - if nodes[1] == "==": - nodes[1] = "=" - nodes[1] = LeafNode(nodes[1], True) + # Store original operator, translation will happen in return_pred() + original_op = nodes[1] + nodes[1] = LeafNode(original_op, True) logical_expr_op.add_child(nodes[1]) logical_expr_op.add_child(nodes[0]) logical_expr_op.add_child(nodes[2]) @@ -605,9 +626,20 @@ def parse(context: LRStackNode, nodes) -> 'LogicalExprOp': def __init__(self, op: str): super().__init__(op) - assert op in ["==", ">", "<", ">=", "<=", "==", "!="], f"Compiler Bug: Invalid logical operator {op}" + assert op in ["==", ">", "<", ">=", "<=", "!="], f"Compiler Bug: Invalid logical operator {op}" self.op = op + def return_pred(self): + assert len(self._children) == 3, "Compiler Bug: Invalid logical expr op" + assert isinstance(self._children[0], LeafNode), "Compiler Bug: Invalid logical expr op" + assert isinstance(self._children[1], Z3ComposableDslParseNode), "Compiler Bug: Invalid logical expr op" + assert isinstance(self._children[2], Z3ComposableDslParseNode), "Compiler Bug: Invalid logical expr op" + original_op = self._children[0].value + translated_op = self.target_language.translate_logical_op(original_op) + left_expr = self._children[1].return_pred() + right_expr = self._children[2].return_pred() + return self.target_language.format_function_call(translated_op, [left_expr, right_expr]) + class LogicalExprQuantifier(Z3ComposableDslParseNode): def parse(context: LRStackNode, nodes) -> 'LogicalExprQuantifier': assert isinstance(nodes, list), "Compiler Bug: Invalid logical expr quantifier" @@ -629,13 +661,10 @@ def return_pred(self): assert isinstance(self._children[0], LeafNode), "Compiler Bug: Invalid logical expr quantifier" assert isinstance(self._children[1], BoundedVarSeq), "Compiler Bug: Invalid logical expr quantifier" assert isinstance(self._children[2], Z3ComposableDslParseNode), "Compiler Bug: Invalid logical expr quantifier" - quantifier = self._children[0].return_pred() - if quantifier == "ForAll": - quantifier = "forall" - elif quantifier == "Exists": - quantifier = "exists" - else: - raise Exception(f"Compiler Bug: Invalid logical expr quantifier {quantifier}") + + original_quantifier = self._children[0].return_pred() + translated_quantifier = self.target_language.translate_quantifier(original_quantifier) + bounded_vars = [] bounded_var_seq = self._children[1] assert isinstance(bounded_var_seq, BoundedVarSeq), "Compiler Bug: Invalid logical expr quantifier" @@ -645,20 +674,12 @@ def return_pred(self): assert isinstance(child.value, str), "Compiler Bug: Invalid logical expr quantifier" assert child.value.startswith("b") and child.value[2:].isdigit(), "Compiler Bug: Invalid logical expr quantifier" bounded_vars.append(child.value) + parent_program = self.parent_program assert isinstance(parent_program, Z3Program), f"Compiler Bug: Parent program should be a Z3Program at {self._program_code}" - # assert len(bounded_vars) == parent_program.bounded_params.params, f"User Error: The number of bounded variables {len(bounded_vars)} does not match the specified number of bounded parameters {parent_program.bounded_params.params} at {self._program_code}" - bound_vars = " ".join([f"({var} Int)" for var in bounded_vars]) + logical_expr = self._children[2].return_pred() - if quantifier == "forall": - all_vars_gtr_than_zero = ' '.join([f'(<= 0 {var})' for var in bounded_vars]) - logical_expr = f"(=> {all_vars_gtr_than_zero} {logical_expr})" - elif quantifier == "exists": - all_vars_gtr_than_zero = ' '.join([f'(<= 0 {var})' for var in bounded_vars]) - logical_expr = f"(and {all_vars_gtr_than_zero} {logical_expr})" - else: - raise Exception(f"Compiler Bug: Invalid logical expr quantifier {quantifier}") - return f"({quantifier} ({bound_vars}) {logical_expr})" + return self.target_language.format_quantified_expression(translated_quantifier, bounded_vars, logical_expr) class LogicalExprBoolean(Z3ComposableDslParseNode): def parse(context: LRStackNode, nodes) -> 'LogicalExprBoolean': @@ -682,22 +703,14 @@ def return_pred(self): assert len(self._children) >= 3, "Compiler Bug: Invalid logical expr boolean" assert isinstance(self._children[0], LeafNode), "Compiler Bug: Invalid logical expr boolean" assert all(isinstance(child, LogicalExpr) for child in self._children[1:]), "Compiler Bug: Invalid logical expr boolean" - op_type = self._children[0].return_pred() - if op_type == "And": - op_type = "and" - elif op_type == "Or": - op_type = "or" - elif op_type == "Implies": - op_type = "=>" - else: - raise Exception(f"Compiler Bug: Invalid logical expr boolean {op_type}") + original_op = self._children[0].return_pred() + translated_op = self.target_language.translate_boolean_op(original_op) logical_exprs = [] for child in self._children[1:]: assert isinstance(child, LogicalExpr), "Compiler Bug: Invalid logical expr boolean" logical_expr = child.return_pred() logical_exprs.append(logical_expr) - logical_exprs_str = " ".join(logical_exprs) - return f"({op_type} {logical_exprs_str})" + return self.target_language.format_function_call(translated_op, logical_exprs) class LogicalExprNot(Z3ComposableDslParseNode): def parse(context: LRStackNode, nodes) -> 'LogicalExprNot': @@ -716,13 +729,10 @@ def return_pred(self): assert len(self._children) == 2, "Compiler Bug: Invalid logical expr not" assert isinstance(self._children[0], LeafNode), "Compiler Bug: Invalid logical expr not" assert isinstance(self._children[1], LogicalExpr), "Compiler Bug: Invalid logical expr not" - op_type = self._children[0].return_pred() - if op_type == "Not": - op_type = "not" - else: - raise Exception(f"Compiler Bug: Invalid logical expr not {op_type}") + original_op = self._children[0].return_pred() + translated_op = self.target_language.translate_boolean_op(original_op) logical_expr = self._children[1].return_pred() - return f"({op_type} {logical_expr})" + return self.target_language.format_function_call(translated_op, [logical_expr]) class LogicalExprIsMember(Z3ComposableDslParseNode): def parse(context: LRStackNode, nodes) -> 'LogicalExprIsMember': @@ -765,12 +775,8 @@ def return_pred(self): assert len(self._children) == 1, "Compiler Bug: Invalid logical expr true false" assert isinstance(self._children[0], LeafNode), "Compiler Bug: Invalid logical expr true false" assert self._children[0].value in ["True", "False"], "Compiler Bug: Invalid logical expr true false" - if self._children[0].value == "True": - return "true" - elif self._children[0].value == "False": - return "false" - else: - raise Exception(f"Compiler Bug: Invalid logical expr true false {self._children[0].value}") + original_literal = self._children[0].value + return self.target_language.translate_boolean_literal(original_literal) class LogicalExprPred(Z3ComposableDslParseNode): def parse(context: LRStackNode, nodes) -> 'LogicalExprPred': @@ -1001,11 +1007,16 @@ def parse(context: LRStackNode, nodes) -> 'Z3Program': program.add_child(pred_expr) return program - def __init__(self): + def __init__(self, target_language: typing.Optional[TargetLanguage] = None): super().__init__() self._name = None self.symbol_table : typing.Dict[str, 'Z3Program'] = {} self._populated = False + self._target_language = target_language or SMTLibTargetLanguage() + + @property + def target_language(self) -> TargetLanguage: + return self._target_language @property def name(self): @@ -1260,7 +1271,11 @@ def smt2(self): pred_expr_str = f"(define-fun return_pred () Bool {self.pred_expr.return_pred()})\n" else: pred_expr_str = "" + # Get preamble from target language + preamble = self.target_language.get_preamble() + smt_program_lines = [ + preamble, fun_decls_str, pred_decls_str, expr_str, @@ -1292,11 +1307,11 @@ def params(self): return self.unbounded_params.params @staticmethod - def from_code(code: str) -> 'Z3Program': + def from_code(code: str, target_language: typing.Optional[TargetLanguage] = None) -> 'Z3Program': """ Parse the code and return the Z3Program. """ - parser = Z3ComposableDsl() + parser = Z3ComposableDsl(target_language) result = parser.parse(code) assert isinstance(result, Z3Program), "Compiler Bug: Result should be a Z3Program" return result @@ -1357,7 +1372,8 @@ class Z3ComposableDsl(Grammar): "," ] - def __init__(self): + def __init__(self, target_language: typing.Optional[TargetLanguage] = None): + self._target_language = target_language or SMTLibTargetLanguage() with open(Z3ComposableDsl.grammar_path, "r") as f: grammar = f.read() super(Z3ComposableDsl, self).__init__( @@ -1406,6 +1422,8 @@ def parse(self, code: str) -> Z3ComposableDslParseNode: actions = self.get_action() parser = self._get_parser(actions) result = parser.parse(code) + if isinstance(result, Z3Program): + result._target_language = self._target_language return result def run(self, code: str) -> Z3DslRunResult: diff --git a/tests/test_zn.py b/tests/test_zn.py new file mode 100644 index 0000000..df4f321 --- /dev/null +++ b/tests/test_zn.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 + +import unittest +from frame.tools.z3_dsl import Z3Program +from frame.tools.target_language import ZnTargetLanguage + + + + +class TestZnTargetLanguage(unittest.TestCase): + + def test_z3_field(self): + """Test Z_3 finite field operations""" + # Create Z_3 target language (elements: {0, 1, 2}) + z3_lang = ZnTargetLanguage(3) + + # Test: 1 + 2 == 0 in Z_3 (since (1+2) mod 3 = 0) + dsl_code = """ + params 0; + bounded params 0; + ReturnExpr None; + ReturnPred 1 + 2 == 0; + """ + + program = Z3Program.from_code(dsl_code, z3_lang) + result = program.run() + + self.assertTrue(result.proved) + self.assertIsNone(result.counter_example) + self.assertIn("zn_add", result.smt2) + self.assertIn("define-fun zn_add", result.smt2) + + def test_z5_field(self): + """Test Z_5 finite field operations""" + # Create Z_5 target language (elements: {0, 1, 2, 3, 4}) + z5_lang = ZnTargetLanguage(5) + + # Test: 3 * 2 == 1 in Z_5 (since (3*2) mod 5 = 1) + dsl_code = """ + params 0; + bounded params 0; + ReturnExpr None; + ReturnPred 3 * 2 == 1; + """ + + program = Z3Program.from_code(dsl_code, z5_lang) + result = program.run() + + self.assertTrue(result.proved) + self.assertIsNone(result.counter_example) + self.assertIn("zn_mul", result.smt2) + self.assertIn("define-fun zn_mul", result.smt2) + + def test_z7_prime_field_division(self): + """Test Z_7 prime field with division""" + # Create Z_7 target language (prime field, supports division) + z7_lang = ZnTargetLanguage(7) + + # Test: 6 / 2 == 3 in Z_7 (since 6 * inv(2) mod 7 = 6 * 4 mod 7 = 3) + dsl_code = """ + params 0; + bounded params 0; + ReturnExpr None; + ReturnPred 6 / 2 == 3; + """ + + program = Z3Program.from_code(dsl_code, z7_lang) + result = program.run() + + self.assertTrue(result.proved) + self.assertIsNone(result.counter_example) + self.assertIn("zn_div", result.smt2) + self.assertIn("define-fun zn_div", result.smt2) + self.assertIn("define-fun zn_inv", result.smt2) + + def test_z6_composite_field(self): + """Test Z_6 composite field (no division)""" + # Create Z_6 target language (composite field, no division) + z6_lang = ZnTargetLanguage(6) + + # Test: 2 * 3 == 0 in Z_6 (since (2*3) mod 6 = 0) + dsl_code = """ + params 0; + bounded params 0; + ReturnExpr None; + ReturnPred 2 * 3 == 0; + """ + + program = Z3Program.from_code(dsl_code, z6_lang) + result = program.run() + + self.assertTrue(result.proved) + self.assertIsNone(result.counter_example) + self.assertIn("zn_mul", result.smt2) + self.assertIn("define-fun zn_mul", result.smt2) + # Should not have division for composite fields + self.assertNotIn("zn_div", result.smt2) + self.assertIn("Note: Division not supported for non-prime modulus 6", result.smt2) + + def test_zn_primality_check(self): + """Test primality checking in ZnTargetLanguage""" + z3_lang = ZnTargetLanguage(3) + z5_lang = ZnTargetLanguage(5) + z6_lang = ZnTargetLanguage(6) + z7_lang = ZnTargetLanguage(7) + + self.assertTrue(z3_lang.is_prime) + self.assertTrue(z5_lang.is_prime) + self.assertFalse(z6_lang.is_prime) + self.assertTrue(z7_lang.is_prime) + + def test_z3_multiplicative_inverses(self): + """Test that every non-zero element in Z_3 has a multiplicative inverse""" + z3_lang = ZnTargetLanguage(3) + + # Test: 1 * 1 == 1 (1 is its own inverse) + dsl_code1 = """ + params 0; + bounded params 0; + ReturnExpr None; + ReturnPred 1 * 1 == 1; + """ + program1 = Z3Program.from_code(dsl_code1, z3_lang) + result1 = program1.run() + self.assertTrue(result1.proved) + + # Test: 2 * 2 == 1 (2 is its own inverse in Z_3) + dsl_code2 = """ + params 0; + bounded params 0; + ReturnExpr None; + ReturnPred 2 * 2 == 1; + """ + program2 = Z3Program.from_code(dsl_code2, z3_lang) + result2 = program2.run() + self.assertTrue(result2.proved) + + def test_z3_additive_inverses(self): + """Test additive inverses in Z_3""" + z3_lang = ZnTargetLanguage(3) + + # Test: 1 + 2 == 0 (1 and 2 are additive inverses in Z_3) + dsl_code = """ + params 0; + bounded params 2; + ReturnExpr None; + ReturnPred ForAll([b_0], Exists([b_1], b_0 + b_1 == 0)); + """ + program = Z3Program.from_code(dsl_code, z3_lang) + generated_code = program.smt2() + print("Generated SMT2 Code for Additive Inverses in Z_3:") + print(generated_code) # For debugging purposes + result = program.run() + self.assertTrue(result.proved, "Additive inverses should hold in Z_3") + + def test_z3_multiplicative_inverses_z3(self): + """Test multiplicative inverses in Z_3""" + z3_lang = ZnTargetLanguage(3) + dsl_code = """ + params 0; + bounded params 2; + ReturnExpr None; + ReturnPred ForAll([b_0], Implies(b_0 != 0, Exists([b_1], b_0 * b_1 == 1))); + """ + program = Z3Program.from_code(dsl_code, z3_lang) + generated_code = program.smt2() + print("Generated SMT2 Code for Multiplicative Inverses in Z_3:") + print(generated_code) + result = program.run() + self.assertTrue(result.proved, "Multiplicative inverses should hold in Z_3") + + + def test_z5_multiplicative_inverses(self): + """Test multiplicative inverses in Z_5""" + z5_lang = ZnTargetLanguage(5) + + # Test all non-zero elements have inverses + inverse_pairs = [(1, 1), (2, 3), (3, 2), (4, 4)] + + for a, a_inv in inverse_pairs: + dsl_code = f""" + params 0; + bounded params 0; + ReturnExpr None; + ReturnPred {a} * {a_inv} == 1; + """ + program = Z3Program.from_code(dsl_code, z5_lang) + result = program.run() + self.assertTrue(result.proved, f"{a} * {a_inv} should equal 1 in Z_5") + + def test_invalid_zn(self): + """Test invalid Z_n values""" + with self.assertRaises(ValueError): + ZnTargetLanguage(1) + + with self.assertRaises(ValueError): + ZnTargetLanguage(0) + + with self.assertRaises(ValueError): + ZnTargetLanguage(-1) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file