From 5afe09442a954c4b8f15a7728737cd794ec31649 Mon Sep 17 00:00:00 2001 From: User Date: Wed, 18 Dec 2024 15:02:57 +0000 Subject: [PATCH] Initial commit --- .env | 29 + .gitignore | 20 + .gitmodules | 4 + Dockerfile | 32 + GroundSLASH | 1 + README.md | 62 + demos/2-MNIST_addition.ipynb | 359 ++ demos/4-queens.ipynb | 132 + demos/draw.ipynb | 411 +++ experiments/mnist_addition.py | 448 +++ experiments/n_queens.py | 283 ++ experiments/promis_paris.py | 285 ++ experiments/scripts/mnist_add.sh | 21 + .../scripts/mnist_batch_size_tradeoff.sh | 68 + experiments/scripts/mnist_dppl_comparison.sh | 55 + experiments/scripts/n_queens.sh | 11 + experiments/scripts/promis_paris.sh | 22 + imgs/logo_dark.png | Bin 0 -> 17655 bytes imgs/logo_light.png | Bin 0 -> 19438 bytes pyproject.toml | 33 + src/asn/__init__.py | 4 + src/asn/asn.py | 348 ++ src/asn/data/__init__.py | 1 + src/asn/data/datasets/download_datasets.py | 89 + src/asn/data/datasets/family_relations.py | 361 +++ src/asn/data/datasets/mnist_addition.py | 85 + src/asn/data/datasets/shapeworld4.py | 215 ++ src/asn/data/expression.py | 65 + src/asn/data/reasoning_graph.py | 1227 +++++++ src/asn/data/utils.py | 215 ++ src/asn/models/alexnet.py | 36 + src/asn/models/einsum_wrapper.py | 171 + src/asn/models/llm_wrapper.py | 49 + src/asn/models/promis_mock_model.py | 12 + src/asn/models/slot_attention.py | 258 ++ src/asn/solver/__init__.py | 6 + src/asn/solver/gnn/__init__.py | 13 + src/asn/solver/gnn/constr.py | 31 + src/asn/solver/gnn/gnn.py | 97 + src/asn/solver/gnn/message_passing.py | 136 + src/asn/solver/graph_block.py | 47 + src/asn/solver/npp_context.py | 18 + src/asn/solver/solver.py | 40 + src/asn/solver/solving_context.py | 345 ++ src/asn/solver/stable_model_context.py | 5 + src/asn/utils/__init__.py | 2 + src/asn/utils/collections.py | 23 + src/asn/utils/load_dotenv.py | 19 + src/asn/utils/relop.py | 26 + src/format.sh | 3 + src/tests/__init__.py | 0 src/tests/data/__init__.py | 0 src/tests/data/test_reasoning_graph.py | 2882 +++++++++++++++++ 53 files changed, 9105 insertions(+) create mode 100644 .env create mode 100644 .gitignore create mode 100644 .gitmodules create mode 100644 Dockerfile create mode 160000 GroundSLASH create mode 100644 README.md create mode 100644 demos/2-MNIST_addition.ipynb create mode 100644 demos/4-queens.ipynb create mode 100644 demos/draw.ipynb create mode 100644 experiments/mnist_addition.py create mode 100644 experiments/n_queens.py create mode 100644 experiments/promis_paris.py create mode 100644 experiments/scripts/mnist_add.sh create mode 100644 experiments/scripts/mnist_batch_size_tradeoff.sh create mode 100644 experiments/scripts/mnist_dppl_comparison.sh create mode 100644 experiments/scripts/n_queens.sh create mode 100644 experiments/scripts/promis_paris.sh create mode 100644 imgs/logo_dark.png create mode 100644 imgs/logo_light.png create mode 100644 pyproject.toml create mode 100644 src/asn/__init__.py create mode 100644 src/asn/asn.py create mode 100644 src/asn/data/__init__.py create mode 100644 src/asn/data/datasets/download_datasets.py create mode 100644 src/asn/data/datasets/family_relations.py create mode 100644 src/asn/data/datasets/mnist_addition.py create mode 100644 src/asn/data/datasets/shapeworld4.py create mode 100644 src/asn/data/expression.py create mode 100644 src/asn/data/reasoning_graph.py create mode 100644 src/asn/data/utils.py create mode 100644 src/asn/models/alexnet.py create mode 100644 src/asn/models/einsum_wrapper.py create mode 100644 src/asn/models/llm_wrapper.py create 
mode 100644 src/asn/models/promis_mock_model.py create mode 100644 src/asn/models/slot_attention.py create mode 100644 src/asn/solver/__init__.py create mode 100644 src/asn/solver/gnn/__init__.py create mode 100644 src/asn/solver/gnn/constr.py create mode 100644 src/asn/solver/gnn/gnn.py create mode 100644 src/asn/solver/gnn/message_passing.py create mode 100644 src/asn/solver/graph_block.py create mode 100644 src/asn/solver/npp_context.py create mode 100644 src/asn/solver/solver.py create mode 100644 src/asn/solver/solving_context.py create mode 100644 src/asn/solver/stable_model_context.py create mode 100644 src/asn/utils/__init__.py create mode 100644 src/asn/utils/collections.py create mode 100644 src/asn/utils/load_dotenv.py create mode 100644 src/asn/utils/relop.py create mode 100755 src/format.sh create mode 100644 src/tests/__init__.py create mode 100644 src/tests/data/__init__.py create mode 100644 src/tests/data/test_reasoning_graph.py diff --git a/.env b/.env new file mode 100644 index 0000000..c706cbb --- /dev/null +++ b/.env @@ -0,0 +1,29 @@ +#Environment file to specify workspace dir and API Keys +WORKSPACE_DIR="/workspaces/ASN_dev/AnswerSetNetworks" + + +#AIML cluster root +HF_HOME=/workspaces/ASN_dev/AnswerSetNetworks +HUGGINGFACE_HUB_CACHE=/workspaces/ASN_dev/AnswerSetNetworks + +#42 cluster root +#HF_HOME=/pfss/mlde/workspaces/mlde_wsp_Multimodal_on_42/multimodal_changes/LLaVA-changes/hfcache +#HUGGINGFACE_HUB_CACHE=/pfss/mlde/workspaces/mlde_wsp_Multimodal_on_42/multimodal_changes/LLaVA-changes/hfcache + +#wandb +WANDB_PROJECT=answer-set-networks +WANDB_API_KEY=cb7eae87a4af67e7866c061f0b05d574e98e9e82 +HF_TOKEN=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 + +#OpenAI +#Manuel +OPENAI_API_KEY=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 + +#AIML +#OPENAI_API_KEY=sk-TI4p-AGQ3kOtU1wdq92ZQDADlIsA9yZ7bwDbjy9gCWT3BlbkFJhfk3wC4Uqbara_JoFw0_TW3vhNUg_fT3T66ZRC9VEA + +#My own +# OPENAI_API_KEY=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 + +#perplexity +# OPENAI_API_KEY=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 \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7904e24 --- /dev/null +++ b/.gitignore @@ -0,0 +1,20 @@ +__pycache__/ +.ipynb_checkpoints/ +*.egg-info/ +experiments/logs/* +experiments/logs*/* +experiments/*logs/* +*/data/* +data/* +*/plots/* +build/ +.vscode/ +.devcontainer/ +.hypothesis/ +spin_up_containers.sh +hf/* +hfcache/* +**/checkpoints/** +**/wandb/** +install.sh +install_transformers.sh \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..e0d4b03 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "GroundSLASH"] + path = GroundSLASH + url = git@github.com:pdeibert/GroundSLASH.git + ignore = untracked diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..2dea91e --- /dev/null +++ b/Dockerfile @@ -0,0 +1,32 @@ +FROM nvcr.io/nvidia/pytorch:24.06-py3 +WORKDIR "/asn" +COPY . . +ARG DEBIAN_FRONTEND=noninteractive + +RUN apt-get update + +RUN echo "Upgrade pip" +RUN python -m pip install -U pip +RUN echo "Install torch==2.3.0 & dependencies..." 
+RUN python -m pip install torch==2.3.0 \ + torchvision==0.18.0 \ + torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121 +RUN echo "Install torch_geometric==2.5.3 & dependencies..." +RUN python -m pip install pyg_lib \ + torch_scatter \ + torch_sparse \ + torch_cluster \ + torch_spline_conv -f https://data.pyg.org/whl/torch-2.3.0+cu121.html +RUN python -m pip install torch_geometric==2.5.3 +RUN echo "Install ASN grounder" +WORKDIR "/asn/GroundSLASH/" +RUN python -m pip install -e . +WORKDIR "/asn" +RUN echo "Install Graphviz development libraries" +RUN apt-get update && apt-get install -y \ + graphviz \ + graphviz-dev \ + pkg-config +RUN echo "Install ASN" +RUN python -m pip install -e . +RUN echo "All installations completed successfully!" \ No newline at end of file diff --git a/GroundSLASH b/GroundSLASH new file mode 160000 index 0000000..a0b75df --- /dev/null +++ b/GroundSLASH @@ -0,0 +1 @@ +Subproject commit a0b75df16b010b16ed7a740df33bba9af0d00d8b diff --git a/README.md b/README.md new file mode 100644 index 0000000..b020ad5 --- /dev/null +++ b/README.md @@ -0,0 +1,62 @@ +# Answer Set Networks: Casting Answer Set Programming into Deep Learning +Arseny Skryagin, Daniel Ochs, Philipp Deibert, Simon Kohaut, Devendra Singh Dhami, Kristian Kersting
+ +![Fancy logo](./imgs/logo_dark.png#gh-dark-mode-only) +![Fancy logo](./imgs/logo_light.png#gh-light-mode-only) + +
+ +[![MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) + + +# Abstract +Although Answer Set Programming (ASP) allows constraining neural-symbolic (NeSy) systems, its employment is hindered by the prohibitive costs of computing stable models and the CPU-bound nature of state-of-the-art solvers. +To this end, we propose Answer Set Networks (ASN), a NeSy solver. +Based on Graph Neural Networks (GNN), ASNs are a scalable approach to ASP-based Deep Probabilistic Logic Programming (DPPL). +Specifically, we show how to translate ASPs into ASNs and demonstrate how ASNs can efficiently solve the encoded problem by leveraging the GPU's batching and parallelization capabilities. +Our experimental evaluations demonstrate that ASNs outperform state-of-the-art CPU-bound NeSy systems on multiple tasks. +Simultaneously, we make the following two contributions based on the strengths of ASNs. +Namely, we are the first to show the finetuning of Large Language Models (LLMs) with DPPLs, employing ASNs to guide the training with logic. +Further, we show the "constitutional navigation" of drones, i.e., encoding public aviation laws in an ASN for routing Unmanned Aerial Vehicles in uncertain environments. + + + + +# Installation + +### Environment +First, you need a PyTorch environment. You can either use our prebuilt Docker container available on Docker Hub (see [hansiwusti/asn:1.0](https://hub.docker.com/r/hansiwusti/asn)) or create an environment yourself. For this, we provide a `Dockerfile` and the `pyproject.toml`. Note that we found the environment with PyTorch==2.3.0 and PyTorch Geometric==2.5.3 to work well, but you may need to select a PyTorch version that fits your own GPU/CUDA environment. + +### Cloning the repo and ASP Grounder +After setting up your environment, you need to install ASN and a grounder. We use the GroundSLASH grounder from https://github.com/pdeibert/GroundSLASH for ASN. + +Start by cloning the ASN and GroundSLASH repositories: +``` +git clone git@github.com:pdeibert/AnswerSetNetworks.git +cd AnswerSetNetworks/ +git clone git@github.com:pdeibert/GroundSLASH.git +``` + +Then install all Python modules using: +``` +python -m pip install --upgrade pip +python -m pip install -e . +python -m pip install ./GroundSLASH +``` +This will also install PyTorch and the other requirements if they are not installed yet. + +### LLMs in ASN +If you want to use ASN to train LLMs, you have to install additional packages (Hugging Face Transformers, wandb, ...). In your project root, run: +``` +python -m pip install .[transformer_libs] +``` + + +# Run ASN +We put together a folder containing all experiment scripts in `experiments/scripts/`. +To start ASN for MNIST addition with two images, you can run: +``` +. experiments/scripts/mnist_add.sh +``` +This script is a good starting point for exploring all Python arguments. You can also check out the other scripts in the folder to get an idea of how to start ASN for the other experiments. The script calls `mnist_addition.py`, which exemplifies how to connect your PyTorch models to NPP objects, create a dataloader with `Constraint` objects as labels, and call the ASN forward and backward pass.
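For orientation, the following is a minimal sketch of that workflow. It is modeled on the API calls that appear in `experiments/promis_paris.py` and the demo notebooks in this commit (`ASN.from_string`, `configure_NPPs`, `encode_queries`, `npp_forward`, `SolvingContext`, `prepare_block`, `solve`, `p_Q`); the toy program, the tiny classifier, and the random input images are placeholders rather than the models and programs used in the experiments, and exact signatures may differ from the installed version:

```
import torch
import torch.nn as nn
import torch.optim as optim

from ground_slash.program import Constraint, Naf, PredLiteral, SymbolicConstant
from asn.asn import ASN
from asn.solver import SolvingContext

# Toy program: one NPP (digit classifier) per image plus a purely symbolic rule.
prog_str = r"""
img(i1). img(i2).
#npp(digit(X), [0,1,2,3,4,5,6,7,8,9]) :- img(X).
same(A,B) :- digit(A,N), digit(B,N).
"""
asn = ASN.from_string(prog_str)

# Attach a (placeholder) PyTorch model to every ground NPP rule; the model is
# shared, so only one optimizer is created (mirroring experiments/promis_paris.py).
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax(dim=-1))
asn.configure_NPPs(
    {
        npp_rule: {
            "model": model,
            "optimizer": optim.Adam(model.parameters(), lr=0.005) if not i else None,
        }
        for i, npp_rule in enumerate(asn.rg.npp_edges)
    }
)

# Labels are posed as query constraints, here ":- not same(i1,i2)."
queries = [
    Constraint(Naf(PredLiteral("same", SymbolicConstant("i1"), SymbolicConstant("i2"))))
]
rg = asn.encode_queries(queries)

# NPP forward pass: one input batch per ground NPP rule (random images here).
npp_ctx_dict = asn.npp_forward(
    npp_data={npp_rule: (torch.rand(1, 1, 28, 28),) for npp_rule in asn.rg.npp_edges}
)

# GNN-based solving and readout of the query probabilities.
solving_ctx = SolvingContext(len(queries), npp_ctx_dict)
graph_block = asn.prepare_block(queries=queries, rg=rg, phase=0, device="cpu")
graph_block = asn.solve(graph_block)
solving_ctx.update_SMs(graph_block)

print(solving_ctx.p_Q)  # probability of each query under the NPP outputs
```

The experiment scripts follow the same pattern on real data and additionally run the ASN backward pass to train the NPP models.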
diff --git a/demos/2-MNIST_addition.ipynb b/demos/2-MNIST_addition.ipynb new file mode 100644 index 0000000..8336d5e --- /dev/null +++ b/demos/2-MNIST_addition.ipynb @@ -0,0 +1,359 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "7cf02ea4", + "metadata": {}, + "source": [ + "# Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7686afd3", + "metadata": {}, + "outputs": [], + "source": [ + "# GroundSLASH\n", + "from ground_slash.program import Program, Choice\n", + "from ground_slash.grounding import Grounder\n", + "\n", + "# PyTorch\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.utils.data import DataLoader\n", + "\n", + "# PyTorch Geometric\n", + "import torch_geometric\n", + "from torch_geometric.data import HeteroData, Data, Batch" + ] + }, + { + "cell_type": "markdown", + "id": "b37e1d04", + "metadata": {}, + "source": [ + "### Initialize CUDA" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "291f18db", + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "print(device)" + ] + }, + { + "cell_type": "markdown", + "id": "1d4c0674", + "metadata": {}, + "source": [ + "# Program" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a955fda", + "metadata": {}, + "outputs": [], + "source": [ + "digits = list(range(10))\n", + "n_out = len(digits)\n", + "\n", + "prog_str = fr'''\n", + "img(i1). img(i2).\n", + "\n", + "#npp(digit(X), {digits}) :- img(X).\n", + "\n", + "addition(A,B,N1+N2):- digit(A,N1), digit(B,N2), A0.\n", + ":- q(X1,Y1), q(X2,Y2), n(N), X2=X1+N, Y1=Y2+N, N>0.\n", + "'''" + ] + }, + { + "cell_type": "markdown", + "id": "a1db5ca0", + "metadata": {}, + "source": [ + "# Solve" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e4d0b3c2", + "metadata": {}, + "outputs": [], + "source": [ + "asn = ASN.from_string(prog_str)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c628e23", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Solutions:\")\n", + "\n", + "for answer_set in asn.get_answer_sets():\n", + " print(\"\\t\", *answer_set)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc131f65", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/demos/draw.ipynb b/demos/draw.ipynb new file mode 100644 index 0000000..8330097 --- /dev/null +++ b/demos/draw.ipynb @@ -0,0 +1,411 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "30f9a7ba", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "from asn.data.reasoning_graph import ReasoningGraph\n", + "from ground_slash.program import Program\n", + "from ground_slash.grounding import Grounder" + ] + }, + { + "cell_type": "markdown", + "id": "553f85a8", + "metadata": {}, + "source": [ + "### Normal rules" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06badaec", + "metadata": {}, + "outputs": [], + "source": [ + "# normal 
fact\n", + "prog = Program.from_string(r\"\"\"\n", + "a.\n", + "\"\"\")\n", + "\n", + "ReasoningGraph(prog).draw()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "352cb72b", + "metadata": {}, + "outputs": [], + "source": [ + "# normal facts\n", + "prog = Program.from_string(r\"\"\"\n", + "a.\n", + "b.\n", + "\"\"\")\n", + "\n", + "ReasoningGraph(prog).draw()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68d6000e", + "metadata": {}, + "outputs": [], + "source": [ + "# normal rule\n", + "prog = Program.from_string(r\"\"\"\n", + "a :- b, not c.\n", + "\"\"\")\n", + "\n", + "ReasoningGraph(prog).draw()" + ] + }, + { + "cell_type": "markdown", + "id": "61af29ea", + "metadata": {}, + "source": [ + "### Disjunctive rules" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cfa4ebbb", + "metadata": {}, + "outputs": [], + "source": [ + "# disjunctive fact\n", + "prog = Program.from_string(r\"\"\"\n", + "a | b.\n", + "\"\"\")\n", + "\n", + "ReasoningGraph(prog).draw()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "adaff1ec", + "metadata": {}, + "outputs": [], + "source": [ + "# disjunctive rule\n", + "prog = Program.from_string(r\"\"\"\n", + "a | b :- c, not d.\n", + "\"\"\")\n", + "\n", + "ReasoningGraph(prog).draw()" + ] + }, + { + "cell_type": "markdown", + "id": "48fdc432", + "metadata": {}, + "source": [ + "### Constraint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61328fdc", + "metadata": {}, + "outputs": [], + "source": [ + "# constraint\n", + "prog = Program.from_string(r\"\"\"\n", + ":- a, not b.\n", + "\"\"\")\n", + "\n", + "ReasoningGraph(prog).draw()" + ] + }, + { + "cell_type": "markdown", + "id": "87041400", + "metadata": {}, + "source": [ + "### Aggregates" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3be9100b", + "metadata": {}, + "outputs": [], + "source": [ + "# count aggregate\n", + "prog = Program.from_string(r\"\"\"\n", + "a :- #count{1;2:b;2:c,not d;2:not d;3}.\n", + "\"\"\")\n", + "\n", + "ReasoningGraph(prog).draw()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc681e67", + "metadata": {}, + "outputs": [], + "source": [ + "# sum aggregate\n", + "prog = Program.from_string(r\"\"\"\n", + "a :- #sum{1;2:b;2:c,not d;2:not d;3}.\n", + "\"\"\")\n", + "\n", + "ReasoningGraph(prog).draw()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f0bfac1", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "# min aggregate\n", + "prog = Program.from_string(r\"\"\"\n", + "a :- #min{1;2:b;2:c,not d;2:not d;3}.\n", + "\"\"\")\n", + "\n", + "ReasoningGraph(prog).draw()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b472dbbc", + "metadata": {}, + "outputs": [], + "source": [ + "# max aggregate\n", + "prog = Program.from_string(r\"\"\"\n", + "a :- #max{1;2:b;2:c,not d;2:not d;3}.\n", + "\"\"\")\n", + "\n", + "ReasoningGraph(prog).draw()" + ] + }, + { + "cell_type": "markdown", + "id": "0f536e40", + "metadata": {}, + "source": [ + "### Choice rules" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10c41604", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "# choice fact\n", + "prog = Program.from_string(r\"\"\"\n", + "{a;b:d;b:e,not f;b:not f;c}.\n", + "\"\"\")\n", + "\n", + "ReasoningGraph(prog).draw()" + ] + }, + { + "cell_type": "markdown", + "id": "5dd45a4a", + "metadata": {}, + "source": [ + "# 
Strong/classical negation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e30059c7", + "metadata": {}, + "outputs": [], + "source": [ + "prog = Program.from_string(r\"\"\"\n", + "a | b :- c.\n", + "-a.\n", + "\"\"\")\n", + "\n", + "ReasoningGraph(prog).draw()" + ] + }, + { + "cell_type": "markdown", + "id": "e228d3c7", + "metadata": {}, + "source": [ + "# NPP rules" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d1bf6f9", + "metadata": {}, + "outputs": [], + "source": [ + "# normal fact\n", + "prog = Program.from_string(r\"\"\"\n", + "\n", + "img(i1).\n", + "\n", + "#npp(digit(i1), [0,1,2]) :- img(i1).\n", + "\"\"\")\n", + "\n", + "ReasoningGraph(prog).draw()" + ] + }, + { + "cell_type": "markdown", + "id": "03551153", + "metadata": {}, + "source": [ + "# MNIST-Addition" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d67fd0e", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "from asn.data.reasoning_graph import ReasoningGraph\n", + "from ground_slash.program import Program\n", + "from ground_slash.grounding import Grounder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b57a8a60", + "metadata": {}, + "outputs": [], + "source": [ + "prog = Program.from_string(r'''\n", + "img(i1). img(i2).\n", + "\n", + "#npp(digit(i1), [0,1,2]) :- img(i1).\n", + "#npp(digit(i2), [0,1,2]) :- img(i2).\n", + "\n", + "addition(i1,i2,0):- digit(i1,0), digit(i2,0), i1= 2 +), f"Number of digits must be greater of equal to two, but was {args.num_digits}." +# check classes +assert len(set(args.classes)) == len(args.classes), "Duplicate classes." +assert all(c >= 0 and c < 10 for c in args.classes), "Invalid classes." +# check learning rate +assert ( + args.learning_rate > 0.0 +), f"Learning rate must be positive, but was {args.learning_rate}." +if args.eval_batch_size is None: + args.eval_batch_size = args.batch_size +# check batch size +assert ( + args.batch_size > 0 +), f"Batch size must greater than zero, but was {args.batch_size}." +# check eval batch size +assert ( + args.eval_batch_size > 0 +), f"Evaluation batch size must greater than zero, but was {args.eval_batch_size}." +# check number of phases +assert ( + args.num_phases > 0 +), f"Number of phases must be greater than zero, but was {args.num_phases}." +# check number of epochs +assert ( + args.num_epochs >= 0 +), f"Number of epochs must be greater or equal to zero, but was {args.num_epochs}." +# check device +assert torch.device(args.device), f"{args.device} is no valid device." +# check number of runs +assert len(args.num_runs) in ( + 1, + 2, +), "Number of runs is expected to be one or two integers." +assert ( + args.num_runs[-1] > 0 +), f"Number of runs must be greater than zero, but was {args.num_runs[-1]}." +if len(args.num_runs) > 1: + assert ( + args.num_runs[0] >= 0 + ), f"Number of warmup runs must be greater or equal to zero, but was {args.num_runs[0]}." 
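# For the default arguments (--num-digits 2 and classes 0..9), the program string
# assembled below corresponds to the encoding used in demos/2-MNIST_addition.ipynb,
# roughly (sketch; parts of the string-building code are not fully legible in this patch):
#
#   img(i1). img(i2).
#   #npp(digit(X), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) :- img(X).
#   addition(i1,i2,N1+N2) :- digit(i1,N1), digit(i2,N2).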
+# create log path if it does not exist yet +if not os.path.exists(args.log_path): + os.makedirs(args.log_path) + +print("----- mnist-addition experiment -----") +print(f"title: {args.title}") +print(f"num-digits: {args.num_digits}") +print(f"classes: {args.classes}") +print(f"learning-rate: {args.learning_rate}") +print(f"batch-size: {args.batch_size}") +print(f"eval-batch-size: {args.eval_batch_size}") +print(f"num-phases: {args.num_phases}") +print(f"num-epochs: {args.num_epochs}") +print(f"seed: {args.seed}") +print(f"device: {args.device}") +print(f"num_runs: {args.num_runs}") +print(f"log-path: {args.log_path}") +print(f"data-path: {args.data_path}") +print("-----") + +# ----- set up experiment ----- + +print("initializing experiment...") + +# create program string +prog_str = "" + +# initialize images +for n in range(args.num_digits): + prog_str += f"img(i{n+1}).\n" + +# NPPs +prog_str += f"#npp(digit(X), {args.classes}) :- img(X).\n" + +# addition +prog_str += ( + "addition(" + # images + + ",".join([f"i{n+1}" for n in range(args.num_digits)]) + + "," + # sum of digits + + "+".join([f"N{n+1}" for n in range(args.num_digits)]) + + ") :- " + # individual digits + + ", ".join([f"digit(i{n+1},N{n+1})" for n in range(args.num_digits)]) + # + ", " + ## order of images + # + ", ".join([f"X{n}= sum(args.num_runs) - args.num_runs[-1]: + exp_log["runs"].append(run_log) + + # export statistics + with log_path.open("w") as f: + json.dump(exp_log, f, indent=4, cls=SetEncoder) + +exp_log["complete"] = True + +with log_path.open("w") as f: + json.dump(exp_log, f, indent=4, cls=SetEncoder) diff --git a/experiments/n_queens.py b/experiments/n_queens.py new file mode 100644 index 0000000..5e4bfa4 --- /dev/null +++ b/experiments/n_queens.py @@ -0,0 +1,283 @@ +# stlib +import argparse +import json +import os +import warnings +from pathlib import Path +from time import perf_counter, strftime + +warnings.filterwarnings("ignore") + +# clingo +import clingo + +# PyTorch +import torch + +# GroundSLASH +from ground_slash.grounding import Grounder +from ground_slash.program import Program + +# ASN +from asn.asn import ASN +from asn.solver import SolvingContext + +# ----- parse arguments ----- + +parser = argparse.ArgumentParser() + +parser.add_argument("--title", "--t", type=str, default="n_queens") +parser.add_argument("--num-queens", "--n", type=int, default=4) +parser.add_argument("--num-phases", "--p", type=int, default=1) +parser.add_argument("--device", "--d", type=str, default="cpu") +parser.add_argument("--num_runs", "--r", nargs="+", type=int, default=[2, 10]) +parser.add_argument("--log-path", "--lpath", type=str, default="./logs/") + +args = parser.parse_args() + +# check number of queens +assert ( + args.num_queens >= 0 +), f"Number of queens must be greater of equal to zero, but was {args.num_queens}." +# check number of phases +assert ( + args.num_phases > 0 +), f"Number of phases must be greater than zero, but was {args.num_phases}." +# check device +assert torch.device(args.device), f"{args.device} is no valid device." +# check number of runs +assert len(args.num_runs) in ( + 1, + 2, +), "Number of runs is expected to be one or two integers." +assert ( + args.num_runs[-1] > 0 +), f"Number of runs must be greater than zero, but was {args.num_runs[-1]}." +if len(args.num_runs) > 1: + assert ( + args.num_runs[0] >= 0 + ), f"Number of warmup runs must be greater or equal to zero, but was {args.num_runs[0]}." 
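# For --num-queens 4, the program string assembled below encodes the classic n-queens
# problem as in demos/4-queens.ipynb, roughly (sketch; parts of the string-building code
# and of the demo notebook are not fully legible in this patch):
#
#   n(0). n(1). n(2). n(3).
#   1={q(X,0);q(X,1);q(X,2);q(X,3)} :- n(X).
#   :- q(X1,Y), q(X2,Y), X1<X2.
#   :- q(X1,Y1), q(X2,Y2), n(N), X2=X1+N, Y2=Y1+N, N>0.
#   :- q(X1,Y1), q(X2,Y2), n(N), X2=X1+N, Y1=Y2+N, N>0.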
+# create log path if it does not exist yet +if not os.path.exists(args.log_path): + os.makedirs(args.log_path) + +print("----- n-queens experiment -----") +print(f"title: {args.title}") +print(f"num-queens: {args.num_queens}") +print(f"num-phases: {args.num_phases}") +print(f"device: {args.device}") +print(f"num_runs: {args.num_runs}") +print(f"path: {args.log_path}") +print("-----") + +# ----- set up experiment ----- + +print("initializing experiment...") + +# create program string +prog_str = "" + +# initialize rows +for n in range(args.num_queens): + prog_str += f"n({n}).\n" + +# choose a column for each row +prog_str += ( + "1={" + ";".join([f"q(X,{n})" for n in range(args.num_queens)]) + "} :- n(X).\n" +) + +# no column overlap +prog_str += ":- q(X1,Y), q(X2,Y), X1= sum(args.num_runs) - args.num_runs[-1]: + t_init_cum += t_init - t_start + t_solving_readout_cum += t_end - t_init + t_total_cum += t_end - t_start + +exp_log["clingo"]["t_init"] = t_init_cum / args.num_runs[-1] +exp_log["clingo"]["t_solving_readout"] = t_solving_readout_cum / args.num_runs[-1] +exp_log["clingo"]["t_total"] = t_total_cum / args.num_runs[-1] + +print("average time:", exp_log["clingo"]["t_total"]) + +with log_path.open("w") as f: + json.dump(exp_log, f, indent=4, cls=SetEncoder) + +# ----- ASN ----- + +print("asn...") + +t_init_cum = 0.0 +t_batching_cum = 0.0 +t_solving_cum = 0.0 +t_readout_cum = 0.0 +t_total_cum = 0.0 + +for r in range(sum(args.num_runs)): + print(f"run {r+1}...") + + t_batching_phase_cum = 0.0 + t_solving_phase_cum = 0.0 + + t_start = perf_counter() + + # initialize solver + asn = ASN(grnd_prog, False, grounder=grounder, num_phases=args.num_phases) + + # initialize solving context + solving_ctx = SolvingContext() + + t_init = perf_counter() + + for phase in range(args.num_phases): + t_phase_start = perf_counter() + + # prepare graph block + graph_block = asn.prepare_block( + rg=asn.rg, # pass pre-computed reasoning graph (avoids copying) + phase=phase, + device=args.device, + ) + + t_batching = perf_counter() + + # solve graph block + graph_block = asn.solve(graph_block) + + t_solving = perf_counter() + + solving_ctx.update_SMs(graph_block) + + t_batching_phase_cum += t_batching - t_phase_start + t_solving_phase_cum += t_solving - t_batching + + # get labels for SMs + exp_log["asn"]["solutions"] = set( + frozenset( + label + for label, atom in zip(asn.rg.node_dict["atom"]["label"], atoms) + if torch.isclose(atom, torch.ones_like(atom)) + ) + for atoms in solving_ctx.sm_ctx.atoms[solving_ctx.sm_ctx.is_SM[0].squeeze(-1)] + ) + + t_end = perf_counter() + + # accumulate timings if warmup has passed + if r >= sum(args.num_runs) - args.num_runs[-1]: + t_init_cum += t_init - t_start + t_batching_cum += t_batching_phase_cum + t_solving_cum += t_solving_phase_cum + t_readout_cum += t_end - t_solving + t_total_cum += t_end - t_start + +exp_log["asn"]["t_init"] = t_init_cum / args.num_runs[-1] +exp_log["asn"]["t_batching"] = t_batching_cum / args.num_runs[-1] +exp_log["asn"]["t_solving"] = t_solving_cum / args.num_runs[-1] +exp_log["asn"]["t_readout"] = t_readout_cum / args.num_runs[-1] +exp_log["asn"]["t_total"] = t_total_cum / args.num_runs[-1] + +# TODO: across all graph blocks +#exp_log["asn"]["num_nodes"] = { +# node_type: solving_ctx.node_dict[node_type]["num_nodes"] +# for node_type in ("atom", "disj", "conj", "count", "sum", "min", "max") +#} +#exp_log["asn"]["num_edges"] = { +# "\t".join(edge_type): edge_attrs["edge_index"].shape[1] +# for edge_type, edge_attrs in 
solving_ctx.edge_dict.items() +#} +print("average time:", exp_log["asn"]["t_total"]) + +with log_path.open("w") as f: + json.dump(exp_log, f, indent=4, cls=SetEncoder) + +# ----- compare ----- + +print("comparing stable models...", end="") + +exp_log["valid"] = exp_log["clingo"]["solutions"] == exp_log["asn"]["solutions"] +exp_log["complete"] = True + +print(exp_log["valid"]) + +with log_path.open("w") as f: + json.dump(exp_log, f, indent=4, cls=SetEncoder) diff --git a/experiments/promis_paris.py b/experiments/promis_paris.py new file mode 100644 index 0000000..12edf9d --- /dev/null +++ b/experiments/promis_paris.py @@ -0,0 +1,285 @@ +# PyTorch +import torch +import torch.optim as optim + + + +#ASN +from asn.asn import ASN +from typing import Iterable, List +from ground_slash.program import Constraint, Naf, PredLiteral, SymbolicConstant +from asn.models.promis_mock_model import PromisMockNet +from asn.solver import SolvingContext +from ground_slash.grounding import Grounder +from ground_slash.program import Program + +#Python libraries +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path +import argparse +import json +from time import time +from tqdm import tqdm + +#rtpt +from rtpt import RTPT + + +parser = argparse.ArgumentParser() + +parser.add_argument("--steps", type=int, default=500, help="the number of pixels to solve in each step") +parser.add_argument("--size", type=int, default=6500, help="the image width/height") +parser.add_argument("--title", "--t", type=str, default="promis") +parser.add_argument("--log-path", "--lpath", type=str, default="./logs/") +parser.add_argument("--device", "--d", type=str, default="cuda") +parser.add_argument("--credentials", type=str, default="DO") +parser.add_argument("--data-path", type=str, default="/workspaces/ASN/data/promis_paris/") + +args = parser.parse_args() + +rtpt = RTPT(name_initials=args.credentials, experiment_name='ProMis PARIS in ASN', max_iterations=1) +rtpt.start() + + +# # Program +prog_str = ''' +% Batch of probabilistic facts +point(x0,y0). + +#npp(over(X,Y,park), [1,0]):- point(X,Y). +#npp(over(X,Y,primary), [1,0]):- point(X,Y). +#npp(over(X,Y,eiffel_tower_stadium), [1,0]):- point(X,Y). +#npp(over(X,Y,bercy_arena), [1,0]):- point(X,Y). +#npp(over(X,Y,secondary), [1,0]):- point(X,Y). +#npp(embassy(X,Y), [1,0]):- point(X,Y). +#npp(government(X,Y), [1,0]):- point(X,Y). + +% Its okay to go over park areas +permission(X,Y) :- over(X,Y, park, 1). + +% Its okay to fly over major roads +permission(X,Y) :- over(X, Y, primary,1). +permission(X,Y) :- over(X, Y, secondary,1). + +% define sport sites +sport_sites(X,Y) :- over(X,Y, eiffel_tower_stadium,0). +sport_sites(X,Y) :- over(X,Y, bercy_arena,0). + +% define public buildings +public_building(X,Y) :- embassy(X,Y,0). +public_building(X,Y) :- government(X,Y,0). + +%it is not allowed to fly over sport sites and public buildings +permitted(X,Y) :- sport_sites(X,Y). +permitted(X,Y) :- public_building(X,Y). + +%we are only permitted to fly over the permitted areas which are not restricted otherwise +airspace(X,Y) :- permission(X,Y), not permitted(X,Y). + +''' +# % Query +# :- not landscape(x0), light_drone. +# :- not landscape(x0) +# :- not light_drone. 
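# At solve time, the query is posed per map point by to_queries() further below as the
# constraint ":- not airspace(x0,y0).", so solving_ctx.p_Q yields, for every point of the
# map, the probability that the drone is allowed to use the airspace above it.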
+ + +# ground program +grounder = Grounder(Program.from_string(prog_str)) +grnd_prog = grounder.ground() +print("Grounded program:\n",grnd_prog,"\n") + +#print("ASN program:\n",prog_str,"\n") +asn = ASN.from_string(prog_str) + +#mock network which is not trainable but acts as probabilsitc fact +model = PromisMockNet() +model.to(args.device) + +park_data = np.nan_to_num(np.load(args.data_path+"/park.npy"),0) +primary_data = np.load(args.data_path+"/primary.npy") +secondary_data = np.load(args.data_path+"/secondary.npy") +eifeltower_data = np.load(args.data_path+"/eifeltower.npy") +becyarena_data = np.load(args.data_path+"/bercyarena.npy") +government_data = np.load(args.data_path+"/government.npy") +embassy_data = np.load(args.data_path+"/embassy.npy") + + +print("Data loaded") +print("park_data shape:", park_data.shape, + "primary_data shape:", primary_data.shape, + "secondary_data shape:", secondary_data.shape, + "eifeltower_data shape:", eifeltower_data.shape, + "embassy_data shape:", embassy_data.shape, + "becyarena_data shape:", becyarena_data.shape) + +def process_data(data, path): + """ + Load and return the data from the numpy file and store it in the visualization folder + """ + + data = torch.tensor(data.flatten()).T + data = data.clamp_(0, 1) + + print('data min:', data.min()) + print('data max:', data.max()) + + if data.isnan().sum() > 0: + print("path", path) + return torch.stack((data, 1-data), dim=1).to(device='cuda') + + +park_data = process_data(park_data, "park") +primary_data = process_data(primary_data, "primary") +secondary_data = process_data(secondary_data, "secondary") +eifeltower_data = process_data(eifeltower_data, "eifeltower") +government_data = process_data(government_data, "government") +embassy_data = process_data(embassy_data, "embassy") +becyarena_data = process_data(becyarena_data, "bercyarena") +print("Data processed into tensors") +npp_rule_dict={} + + +total_steps = args.size**2 # 6500 = 262144 + +epochs = int(total_steps / args.steps) # 262144 / 512 = 512 or 262144 / 256 = 1024 +assert total_steps % args.steps == 0, "total steps must be divisible by steps" + + +print("total steps:", total_steps) +print("epochs:", epochs) +print("steps:", args.steps) +print("device:",args.device) + + +asn.configure_NPPs( + { + npp_rule: { + "model": model, + "optimizer": optim.Adam(model.parameters(), lr=0.1) + if not i + else None, + } + for i, npp_rule in enumerate(asn.rg.npp_edges) + } +) + + +def get_npp_data_dict(i, size): + """ + Returns a chunk of pixels from map to be processed. 
+ """ + npp_data_dict_str={} + npp_data_dict_str['#npp(over(x0,y0,park),[1,0]) :- point(x0,y0).']= [park_data[i*size:(i+1)*size,:]] + npp_data_dict_str['#npp(over(x0,y0,primary),[1,0]) :- point(x0,y0).']= [primary_data[i*size:(i+1)*size,:]] + npp_data_dict_str['#npp(over(x0,y0,secondary),[1,0]) :- point(x0,y0).']= [secondary_data[i*size:(i+1)*size,:]] + npp_data_dict_str['#npp(over(x0,y0,eiffel_tower_stadium),[1,0]) :- point(x0,y0).']= [eifeltower_data[i*size:(i+1)*size,:]] + npp_data_dict_str['#npp(government(x0,y0),[1,0]) :- point(x0,y0).']= [government_data[i*size:(i+1)*size,:]] + npp_data_dict_str['#npp(embassy(x0,y0),[1,0]) :- point(x0,y0).']= [embassy_data[i*size:(i+1)*size,:]] + npp_data_dict_str['#npp(over(x0,y0,bercy_arena),[1,0]) :- point(x0,y0).']= [becyarena_data[i*size:(i+1)*size,:]] + + npp_data_dict={} + for e in asn.rg.npp_edges: + npp_data_dict[e] = npp_data_dict_str[str(e)] + return npp_data_dict + + +def to_queries(y: Iterable[int]) -> List[Constraint]: + return [ + Constraint( + Naf( + PredLiteral( + "airspace", + *tuple((SymbolicConstant(f"x0"), SymbolicConstant(f"y0"))) + ), + ) + ) + for y_i in y + ] + + +start = time() + +#store the solved pixels for later stacking +solved_pixels = torch.zeros([total_steps]) +solving_steps = 0 + +#epochs are not epochs as used in the DL sense but rather the chunks of pixels to solve +for i in tqdm(range(epochs)): + + npp_data_dict = get_npp_data_dict(i, args.steps) + queries = to_queries(y=torch.ones(args.steps)) + rg = asn.encode_queries(queries) + + # NPP forward pass + npp_ctx_dict = asn.npp_forward( + npp_data={ + npp_rule: (npp_data_dict[npp_rule][0],) + for i,npp_rule in enumerate(asn.rg.npp_edges) + }, + ) + + # initialize solving context + solving_ctx = SolvingContext( + len(queries), + npp_ctx_dict, + ) + + for phase in range(1): + + # prepare graph block + graph_block = asn.prepare_block( + queries=queries, + rg=rg, # pass pre-computed reasoning graph + phase=phase, + device=args.device + ) + + # solve graph block + graph_block = asn.solve(graph_block) + + + # update stable models + solving_ctx.update_SMs(graph_block) + + Q = solving_ctx.p_Q + solved_pixels[solving_steps*args.steps:(solving_steps+1)*args.steps]= Q.squeeze() + solving_steps += 1 + + +end = time() +total_time = end - start + +log_path = Path(args.log_path) +log_path.mkdir(parents=True, exist_ok=True) + +#create log file +exp_log = {"title": args.title, + "time": total_time, + "steps": args.steps, + "epochs": epochs, + "image_size": args.size, + "total_steps": total_steps} + +with Path(log_path,f"{args.title}.json").open("w") as f: + json.dump(exp_log, f, indent=4) + + +#stack all pixels +#pixels_stacked = torch.stack(solved_pixels) +pixel_image = solved_pixels.view(args.size,args.size).cpu().numpy() +plt.imshow(pixel_image) + +path_folder = Path(log_path,"npy") +path_folder.mkdir(parents=True, exist_ok=True) +path_img = Path(log_path,"img") +path_img.mkdir(parents=True, exist_ok=True) + +#save generated map +path_img = Path(log_path,"img","landscape_{}.png".format(args.title)) +path_npy =Path(log_path,"npy","landscape_{}.npy".format(args.title)) +plt.savefig(path_img) + +np.save(path_npy, pixel_image) +print("saved image to", path_img) + diff --git a/experiments/scripts/mnist_add.sh b/experiments/scripts/mnist_add.sh new file mode 100644 index 0000000..4eb8a79 --- /dev/null +++ b/experiments/scripts/mnist_add.sh @@ -0,0 +1,21 @@ +# run mnist addition + +source .env +echo "workspace dir ${WORKSPACE_DIR}" + +# T1 +for batch_size in 512 +do + for seed in 42 + 
do + python ${WORKSPACE_DIR}/experiments/mnist_addition.py --title "mnist_addition_comparison_T1_seed_${seed}_batch_size_${batch_size}" \ + --num-digits 2 \ + --learning-rate 0.005 \ + --batch-size $batch_size \ + --num-epochs 10 \ + --num-runs 1 \ + --seed $seed \ + --device "cuda" \ + --log-path "${WORKSPACE_DIR}/experiments/logs/compare/" + done +done diff --git a/experiments/scripts/mnist_batch_size_tradeoff.sh b/experiments/scripts/mnist_batch_size_tradeoff.sh new file mode 100644 index 0000000..1f2a19c --- /dev/null +++ b/experiments/scripts/mnist_batch_size_tradeoff.sh @@ -0,0 +1,68 @@ +# TRADE-OFF EXPERIMENT: batchsize vs time +# run ASN with different batchsizes 64,128,256,...30k for 50 epochs +# collect time for each epoch and performance after each epoch +# do this for T1, T2, T3 + + +#for batch_size in 64 128 256 512 1024 2048 4096 8192 15000 30000 +# only run for 10, 20, 50 epochs for varying batchsizes as we do not need the full 50 epochs to reach 98% accuracy + +source .env +echo "workspace dir ${WORKSPACE_DIR}" + +for batch_size in 64 100 128 256 512 +do + for i in 2 # 3 4 + do + for seed in 0 1 2 3 4 + do + python ${WORKSPACE_DIR}/experiments/mnist_addition.py --title "mnist_addition_trade_off_T${i-1}_seed_${seed}_batch_size_${batch_size}" \ + --num-digits $i \ + --learning-rate 0.005 \ + --batch-size $batch_size \ + --num-epochs 10 \ + --num-runs 1 \ + --seed $seed \ + --device "cuda" \ + --log-path "${WORKSPACE_DIR}/experiments/logs/trade_off/" + done + done +done + +for batch_size in 1024 2048 +do + for i in 2 # 3 4 + do + for seed in 0 1 2 3 4 + do + python ${WORKSPACE_DIR}/experiments/mnist_addition.py --title "mnist_addition_trade_off_T${i-1}_seed_${seed}_batch_size_${batch_size}" \ + --num-digits $i \ + --learning-rate 0.005 \ + --batch-size $batch_size \ + --num-epochs 20 \ + --num-runs 1 \ + --seed $seed \ + --device "cuda" \ + --log-path "${WORKSPACE_DIR}/experiments/logs/trade_off/" + done + done +done + +for batch_size in 4096 8192 15000 30000 +do + for i in 2 # 3 4 + do + for seed in 0 1 2 3 4 + do + python ${WORKSPACE_DIR}/experiments/mnist_addition.py --title "mnist_addition_trade_off_T${i-1}_seed_${seed}_batch_size_${batch_size}" \ + --num-digits $i \ + --learning-rate 0.005 \ + --batch-size $batch_size \ + --num-epochs 50 \ + --num-runs 1 \ + --seed $seed \ + --device "cuda" \ + --log-path "${WORKSPACE_DIR}/experiments/logs/trade_off/" + done + done +done \ No newline at end of file diff --git a/experiments/scripts/mnist_dppl_comparison.sh b/experiments/scripts/mnist_dppl_comparison.sh new file mode 100644 index 0000000..38bee09 --- /dev/null +++ b/experiments/scripts/mnist_dppl_comparison.sh @@ -0,0 +1,55 @@ +# run mnist with 5 seeds with BS=100 and best performing batchsize for 20 epochs and collect time for each batch and performance after each batch + +source .env +echo "workspace dir ${WORKSPACE_DIR}" + +# T1 +for batch_size in 100 512 +do + for seed in 1 2 3 4 5 + do + python ${WORKSPACE_DIR}/experiments/mnist_addition.py --title "mnist_addition_comparison_T1_seed_${seed}_batch_size_${batch_size}" \ + --num-digits 2 \ + --learning-rate 0.005 \ + --batch-size $batch_size \ + --num-epochs 10 \ + --num-runs 1 \ + --seed $seed \ + --device "cuda" \ + --log-path "${WORKSPACE_DIR}/experiments/logs/compare/" + done +done + +# T2 +for batch_size in 100 1024 +do + for seed in 1 2 3 4 5 + do + python ${WORKSPACE_DIR}/experiments/mnist_addition.py --title "mnist_addition_comparison_T2_seed_${seed}_batch_size_${batch_size}" \ + --num-digits 3 \ + --learning-rate 
0.005 \ + --batch-size $batch_size \ + --num-epochs 10 \ + --num-runs 1 \ + --seed $seed \ + --device "cuda" \ + --log-path "${WORKSPACE_DIR}/experiments/logs/compare/" + done +done + +# T3 +for batch_size in 100 4096 +do + for seed in 1 2 3 4 5 + do + python ${WORKSPACE_DIR}/experiments/mnist_addition.py --title "mnist_addition_comparison_T3_seed_${seed}_batch_size_${batch_size}" \ + --num-digits 4 \ + --learning-rate 0.005 \ + --batch-size $batch_size \ + --num-epochs 10 \ + --num-runs 1 \ + --seed $seed \ + --device "cuda" \ + --log-path "${WORKSPACE_DIR}/experiments/logs/compare/" + done +done \ No newline at end of file diff --git a/experiments/scripts/n_queens.sh b/experiments/scripts/n_queens.sh new file mode 100644 index 0000000..ac1d7e4 --- /dev/null +++ b/experiments/scripts/n_queens.sh @@ -0,0 +1,11 @@ +source .env +echo "workspace dir ${WORKSPACE_DIR}" + +for nq in 4 5 6 +do + python ${WORKSPACE_DIR}/experiments/n_queens.py --title "n_queens_comparison_nq_${nq}" \ + --num-queens ${nq} \ + --device "cuda" \ + --log-path "${WORKSPACE_DIR}/experiments/queen_logs/" +done + diff --git a/experiments/scripts/promis_paris.sh b/experiments/scripts/promis_paris.sh new file mode 100644 index 0000000..5a594e1 --- /dev/null +++ b/experiments/scripts/promis_paris.sh @@ -0,0 +1,22 @@ +# Script to run the PROMIS experiment with different step sizes +# 64 128 256 512 1024 2048 4096 8192 16384 32768  + +# For promis paris you need to download the npy files containing the cartography data. +# We put all data on Huggingface +# https://huggingface.co/datasets/DanielOchs/AnswerSetNetworks +# download the bercyarena.npy,eifeltower.npy, embassy.npy, government.npy, park.npy, primary.npy and secondary.npy files and put them in the data/promis_paris folder +#TODO automatically download the data from Huggingface + + +source .env +echo "workspace dir ${WORKSPACE_DIR}" + +size=6500 +for steps in 250000 +do + python ${WORKSPACE_DIR}/experiments/promis_paris.py --title "promis_paris_${size}x${size}_stepsize_${steps}" \ + --steps ${steps} \ + --size ${size} \ + --log-path "/${WORKSPACE_DIR}/experiments/logs/promis/paris/" \ + --data-path "${WORKSPACE_DIR}/data/promis_paris/" +done \ No newline at end of file diff --git a/imgs/logo_dark.png b/imgs/logo_dark.png new file mode 100644 index 0000000000000000000000000000000000000000..b16e589946a797e1633622de8cf63528adaa3f0d GIT binary patch literal 17655 zcmXtAWmH>T)5VIry9A0CcZwD*UZl7~ad&rjD73|0io3f@ahKo(3Bg^yJnvfHUAakC zew=mAoO5UP?AeoORb^QWR8mwJ7#IwBIjL_jFtGj5?-j_1(09>Mm0{=yij$mB)kM@u(P6Bi2@Pft%aYX@6bGZQBZHb<8qndhHLVPL3X8k*ntY}~1B(?|2+=3DpVqvrqor+vGR zqAGB4{=1T!(y$a=uD=5XpR7b&T3yjM!gU$%*mvBeRw z(9yskqZC7I1V)K{nX`u<`mQh5)b+(gvJCy`G=RC5Mzt!QK2dDP=%b+@mi8_aJGNG2u>KNq_Xw&!V0b6T*-PC=H#Qe16qR-3$oPD9SptlL zGl{yPN1WH%fl|*hgIDX{c*)B_ zbVJeS!;4r02YhS|{5J!6xsGkLn8+kh9=P$5Gbc=NU3xj5@3~v<_*~q%0eI`{qAo?W zk$phC=~W6f!AJg509mP{U2uoU3_d#>nQOVTdW`>AHn{J?s+JB@Q9#0kx#vh42ogbl z?=}p+7yrSulV{X&(=WSgcoBs!Xkubi zWM6&{4a!DHFvG|Ma~OXeju-x>f6w%y`#!X-QpFa+;c0F@0oBoH&nRnJOzWk2#RC;k{-{ICH9raa? 
(GIT binary patch data for imgs/logo_dark.png and imgs/logo_light.png)
zGY$9e#e3u#Zjk3o;XC5x9kzrB+CtQG&=y7^$+pr`ELz8m>=)aXV4HOt^o0sPdR^i+ z3h}D$K5r&|&lwu3n>tMsPp|)ehB!l=Fr(h0CNTnLjwCIma#0{)E3GV9?6J5?E<{ZR ziU-1uBJGPZlKf*VqhO~@*iBMPdrHisN+Ws+{cs?EC*>k)teW{{O`7)G=T-FPY_^s+ zNrKf#CvxoBFZnAfosQKsz%JG&6T*f4i6Ug73Pk{b$Jsx6mdry8n9sq&v9B?fI@Z+4 z>~eNC=?)#~N&Zu3D?X?YD`ZUSHtbTJYN}XjAIrr*HL?CVr|#&^L9uU(x7ojhi2uhF zl9r0k{FRpHRdwu&xl$!+_F=BmQf_gMOUB@Aeb2>8=x<5LD&`g5*Qz}>{COPo=L%!V4V z(Q~RnY4Wx2$bCkg7*RB4RXhJyK7o&r=DZ2%N$MVxa4#Ay6X9*|Gi<5~755ozE|J?B zlA&zg4$!@&)ZAp*G8j#4C|QTf(3QV<&Vq9Md7%<>0im<1hb;=2h?es0V{%Hz-c_Pc zf5pE@2u-VU!4GKIOu1X0kMroKWMc_|!PFZqup@XV@RNu61|s5u&m(JnFW}ohx5~GgIB5;1QY*Q?&RNK9!~c2~aa1RPHPh$a!(ca% zb~si(6#$#UeJ&Ll1L9ULgsTKX{`rITFWxWnqzKgt1l7nXuOBVe#}pMPv>M8gwAI7D zSeE|Fndstkk&rsrZRR8L@2g5Ga?JTBs*#9nr-cHo*$))92;{9Tu#jjG+S z6U>)5e|u;*B-R=7x9UA<`mLaE{u9@>MIlVP|AYEc%oUx!Zz{aVs3r^ewr7l_tkiQo zx6DG^USeS~?!quw@s3u_G5nENcD+-VPkTAx`w2P*qKRai$^oYw*3#%Vcx={&lfL2` z+SKi4Cnx!lQ9tBON5IShRx=f%QD0%<5b4{|2lzLqf}&%Vr*=gHahL8hfHj~?jhzT; z9zc`vuOdm|Nd1>NjEnnj7V?w(*NlV1&bsVc6DkZlBSVb8{Z~iFDAgN6n0}ffIsCj!ES;~xmGrgFv0DvkM_1o{ow<;2Va|g zX67c0;iD7>3hIEYM&2=L-pAyGWWTw`P;dmfF1~jX5E1NkN+@gq0(HM(x|h*wJ6kfzH`+eC$DToA(A@=jx?$(1mRmkXWPFAboo zT8GQ|NPjC7M~ojvgUt~@o4)FF`&eZdpF80kmmtdfpV@TG_PBuv>!0zJyL^$4PJcpw zxE|5;57~K>Q+&yPAj-Wo-(_$4rinkGF6+$MT`!{RINGjr#Ac2nPMglZ21FQktD}8Z z6t&4`{4hA*Z!rGQo}*($Z(!zm|2c~(TGMyl%UI~ud6-WDGvE@sdMx?z3wqSrxA^vU znZ=KcvaEg6O2u{2okicDyXwiC(cvs&0u;pSr#N6mjLRneC-*nUGH5WU4FlMGzVMPR z_g3NpBa%JXdH#R`_WT`|*W57SV=0E&e8hL*x4;_X%1gw^ zsNvF}H5k%*CGc2bY0wBsbqp`s7qHOx_%aVJ7mHYo`k)0}6&h`5AD_u}UE41;~h zt{_$h4*Xk(@}A;c8h)Hp^^>AwJ7(OaW2^HG-Zq?M;dFTgY*AP^a!J83tqvy?Zh4Nx6$-{~J<>y>Z^yb7qf=If z^{XaIQdgFfki-$I>SwZg<1vF5ZuJ%us1_{#!HGs;N3Ud_siox0Y@l-(6`s1AKStil z;885O1Snb_rdd0%5>Wl}jeRvLO`Gr$?+_J+sIr@kL%*XS2>XVy9K4S<))6E>L4M+A zPAeZ;$&sFgRq?7gc$rHd%U1?%<4kj$^8}y)`VKcM;;hv3Xk<}NE#b((i4iyKWo~Dq zLj<&7(LV*m?tMuNqnosjmy>bSdFS9@o%oT@jf_>b(j+Y-gB`nYNM>FlfgNMIfNbiJ zH#Te;<_qPB@-K38V?HlQ%@u zAzM^IekE(Ni5!YtZiE9vr`r71JXw8nAg zMG(KOCW4jRV`?a4rJ+3)5)5$xQn=+;wck(`5}+JU$e`GfT151j4-eLInk`7dEcBaP zdsiHhIIR;Zg(eZOQTW-G<5oi?Mm*Bo;0Bp1Eg{8OkWd^?u~{-aMQKGVpAwyPDQ7%?>u22#X5 zjjHBX{%+BM@FTB0HB-Oj39GQX@KZA{S<0E(>f;#Um39-4%IUdqyncnxpUKG0eS(`_ zMZw|HI`<^Y;-**2pQM2H)s6>f#16j@ik1U=uSyVHZ6hFZRi~itzv-%-Q>^;ZQA{bd z;))wMxhWGi%Bmz4N&~TO)GO`JrrNJa}bK6f&k^b891tQO70yv4P7M%EoNlOE4ib=O5Oo;k0b)wX_ zY9LHbv)!vSw5@hRir}@W_-m-dY?3hi-|EPLXm@Y-z3A&ZCrTp?bY*sK{GC?WS}=3b zLJi-kS};Ug|DUmYEy*G)c0#ZC67LQxtUchK4sAW1j+KUPW(U`c zO5IETAq7Z}YMP*ahSln%R6j&^=1WJjUkT2CUqp){{OWTyd4@5_2emrKw*Z!yEPjD$x>t=y_6KkGn*f*f!M9u5`iG z@LuMlaWo%GBn#055vZ$Zc2dHMl)N0==Mej-Dt)V~$HF4iiQ%xOsuDLZNoD=(@Z& z10BaldkxEVC!n@hHl;gy*E4wwN1hc<1GCP()3f4ldsp>uWaU*A8RbAWIdJ$J?fRm(?io|9-4;;L=r)wo%-YEio)_5C(s4@4gb77STP5Sl?%QH_DiWD8!M3!T8hks98UlGhxD!eB}1;tOlV9K7aX!m z0+toLiVtO>G$7AD2}oMV7JRL0MRC%&V`rT-G(^K{CYb8-4h(IOn0@=!ZkwtH--arA ztYk=@gyG6#$4H(@fmCpSR=*Jf817zDjOKxa_ZW0-2~nhfzAM$~>wU zCOFWbh~fYll`{~c!60t&g2jq$zbwCqS9SNc3LkB&EZT;WlTIUZfCuaVbQwfAm(qoV z&`xCgk;1%Z!s!m_Enhhcbx@i^fbpL{m!(Gwsm?y+0kHVDE5?)tiSrfMsfy58cg>cNT^~h$>`FE z5uqg>3Ry?b$!iQZwdt#sT3I%||L68!lP5kqFS?CC9x+Pk(%qT5nWRLg|E~3u@EB{z z50nL5K)hg`DLzkWGVE>^pZyGz!Nm4L073fYB{=z?6(~PlO*vDf^OLf=Sl1oEyIpV8 z5w{X4nyM`yZ`wrYAO&j{|39{vYk5J^;Tdj&dNUbYq7ljx(MUhX%elXrbsAkRxF}9n z^(DqZ)xp-e9;|WI)s8i$m_w@J0G8;}>Zfvyji;c}VRlJL@!6!&MWf9xt{~2Am9yoi z2|95tW{$B`vQP9Zol=&|D>0}609N>}o#5Uro}x9M2hklt1ybnL)t&NZ7C&{Grx3_R zel&EFVOJ%_fb&ByOR7J`fzq$u_>DC)nI2}h8IfRWFJViIhjw@l5Zmk=~1Y@B&n`GT|L*Io>t^r=TtAB-cs*mV**7Hu(d-{JiX-) zar1Dud`=pOPUDomfQ~%3ziXV`NH@nEJ&xt=9p!X zNs@roMDTCJC+&MOj~WmA%N!%86NuOko%|k*z}MUkU!|cP)K#P&4*nv(g)>ZD%+128 
z4Qj)-vET^Cg!Wqj(A$sG)727=cTd{4W8kh{AMAP5)FkQn8vy`V08<7yAO1L|v#koh zQ^}iiCnDnSkJRJRZ4&xQ!0q`T8*S&eNg4-U=zHWzWAE?RX^?dh7D**^Tjzs7EfKPJ zD90bsbk*4H*GT}KJfpu-Gkh2Hf||$d9R~juZ=o>?z*UK@^NghksG@{V5W)`|R+~Cn zJV}OK9QlqMkI=Kd%z`ls%4Q9MI~jPM?pgjP`XR zIB-4H{2-7-(c2vBZEF*2qH!Z12=Pr!eQxp)O+P~S0xJgrJsV_%&IeZQCXq}n@wb@) zKO@tsWkSf9vjlhsodsxPaH2OxJ?%7o$VpIEEgtPJ?8-+yjcY9wYg8)Ts$;WE*~iLv z`U>3e&Ar7;Cfr1BNR?VRv&{T|FW_#4A%EA~q-^VwtaaH20B}Cgh20<;`&6J;_w`s_ z81aM=;7V71w#Am*gAOCK9bjz|=>jc`U=jrJ11ZU`oi4Hk{odZG@+u{Svbbv5gZ|rm ztihI5W17eV{23F~;>=7URvB#eSzjnx&}j!h!H#@2nyUyHL=|&Drw^QsN1|atDgauR zA=Vd?))bjbMh8zT%`UW$m2+q_n{Y~poRWI*Jt#HpHAwfT^JoGj5HtABBXAPS@&?6K zR-W9#zc04Tc}M#@^!G8~2V9EUFHpn6r0k9LYxaj4H1KR)(ytW0C_}~e_lVzS@w{7b z8gpcRIf$aFUlw-JNFf0=y|D+xu(3`fDDQ&jkI+ow@LCWSCK6lp>UmOYW>2#n%s|$O z9XI~nj69EpVjMP?TpN&QSC8>a;#Pv14U`0(t^2o?oW$(V;zXeA6Qt2=VPT{46pTn( zHnCiH0cqDA=n^pCWGdtS#hHel96^5!ra1aK)}5dR>C*+_e>9dv`4!E~DSrpDm!eur zZY}zp<&%p3{06Mk7gKavmWY`A{)`1cO`=td|JS7#H(?);n|{fnuN zUdh5G^OZU|gf<8z;k2Y;)hNxaB8blw!%3 zgRp9Mcv?g;zz|)}crzSpzy|;TuCw8%;WfpXx{y)#Q9kR-HN+M=3agM~bpo9ix_vm^ zg4L3U+vVj{YLLm^Q_K<(F*Jj|zqs}5a zX6euQLmSDt+-OTNaSL}lJ)|={+C8OUZm6;p82FPM08GY#D~j&DRIm`If2?>(^d~A0 zAa|GvbLc1h{w4r66?JH#QLk*9`)$!9UD(zDmTNd2d%dTtngL12s&UVwQ ze=xAAwRJ)Fry#+PZuL^YBB`~`hr#u=<-)%o+i{_cl>CMEpU2&oi&VMqk&XZsO)~-j zVtj!80q(iC{NC=j>~0@tyzJCj=FFcvhSc&XUiY_d#1%NS??E#b+k^rT3;Xs6ko3 z{KY_k5XMcl5cs^LD|`U=zz;^@!=eMJAf=8iK`c!5lAio9DxOd+$aAtEvB!7$MIszR?*d z(7CX=De^^^^Z!bu1AmuorZemzTHqwEfiQGfyMl8Sfl{gVLwe*ElyJ5q&@5A6n&puk zrCW8?30*O8&0jHSNtM-+>YCH$%=$1+#PP7|Nij#xUN7wMz>S_hjBH3En#_7gOb1&y zoBl_>$?1Z9Y@`$fCz9$pNZPAZ7aJ=Hb+y5j2w&YntadgTIWI~Ux5j0*>S{oE_~IZy zR{&m5-)FMd&C>&gDq*_NdoVwtx7{%AE~}6KZWvjlE%jw}SN8nSUt>p0)e|#=!Hw#U z=HDfMubBoV>3Nx0YtIZ{Im^#})D5mFN)WL6kT6c2d2M3^w+p3_lTv!=0|D}4%qS5e z%zQxsvoc0qe`}JGgg=O+y55OGpj4Bk8|zKv@IA#!s-`MiQs5zn*)s4t{MWAYgXxmK zz2mzs;&K$H@%y%K z-6&X4)4{_u;<4yjS`C0{L1pEh=v~J>r6yuMFX+hF@NROkxfcDii*ZQ+0IJD<{|n$l zA(`&UN3GLHfXv}H9X2QqC8)ZRjZ|QQgek=JdEz4GwQWYn8a{#p=gHt8f~+joqgE?> zQ2CpH*L47(KzfAupZ^eQODT5XeSmp{0l0awpgWIKYTN5mR%2*Ip)n_y>01AMEHPsu*(X#aD_68l@jyy>E+}L7+Ttl zXpU0gE1TiG2P)N#6AANRwe|&s1{5gvw~f$(6Q#drycD{Cr3r+a{{?);p+V;7&F|G^tMR#e$(c5cq}U^+ZAUl93BV2(vorM z4g;2-O#_%BN{0089*M0BH=jZ#O*tKttzsVRUS1IspuH17HH^ME_1edJ8z#)Qsl6hfpwfgxm!E zL~U-9{T=m-?*8^x;7rncJ^?Xf1ZEeTVceh~J?CGYYw6`A*|Ge!`{aTP11;$7~8wb!h0Apd$rmPk>SD)A; zY~N=RhB=Z1tCnf08~_d>MFJratjUtFIR%wyUoxK0Z$jAOnqk|lK|^XDG|VD7DqHrhs{w#X6%{vu`2iXKSYx#8 zIf%v&c0=o5a!lILy!SBTE$tmO8ug#KF%jLrb?MUkP}zNhi_X?KyF34#0*#%WTc|aU z>LE~$N0*20c&G&&S7AXVqfcSpzhl`lYsIP=00C7DeA8mHRaO~t0Aek`si3Wi$5aH9 z?L@Z1%sgrS{MXDXeB2$a2kODvXBHhFGvR4>EcU{oLmxJGePP?ur9Yc;+ikyB^YZ2q za4UdOea+9R%B=^EQ)u&^j9CSCNBfaR9C>Twfgu-8Kk6vRnJziaHW1yub?MSK_4vYj zR=1&-KsjK{w1KghbDTmL1y)DK!maaM1KpNAh%?bER@urV$l-W}CW| zAR?MPe||tCx(3WM7$_2ehl$z&9ARKPnEw{pxMOY8j)Qk7L`8NRk+A4d*e96{j%igle zUM4ZAynF+I85Wt_w5_sI_Mt9e9N(#ay=ui7S8WG4xm8^l|C3bvHW==9|MhSrQVr%3 zvu(wj+qmuE@!v7=qmi1bO?{8}l+LrO^_~CvqTDlL`Vw}!%8B{E0i;|lQAxHF@nd1+ z{Dy1Z`{batdg;=)0lZy2ckWPEPR`ZLd?A2h07D3HFwt>l-b+O5SkqpquB`khSw8>( zgf;C&j}Q+6C^Fp`QS~s7=N;4SWGKY^G!y;I3R_I{v=w$vK+KF46R)}*%#T}X_#~Pr z9vgGZu2t#N4ktq?DKEdCfS1g+6Awk^e)QZ=pEBD{SB4Fn;%>{y{ffr8nSisbv~3V7 z7;h5s)8V=eZ!mL4bX}iRUVbxx=S?@neATvP%ib{EPKH8r<|rkjM}Gidim7&h?ZZ5t zsjC+)vg|-#004!R=K&nH%IGN&KV=8O)L{FvWq%}KwaI4Q-{~9qjL9}m;mDCKkx)(W zg-C7iY-l_Qj70!8TkS`pFt8Sc+u14n;kw}Eq1vj~GbIQDK=uCpzX0+vX&gnKhT|1^~NEGI21b zD06txqNLu)!w0IW!*#*?!?nR_g479M{vN=uqN5^AQemK(fj@z8yCN_?QX4!wToZgU zvS(dJ%8PjAN>y_>t^shsBolj7W~(zEGm>qUm8+QO2c{Wd{!w*h!@*w+{bkUlxN;>>53&;K1`yg4b6O0Ac)k7!?P 
z`aLX{?>Usq6{bGbtA!-Bq*6N#o{n-|gZGQs>{FBVuf|R1$K27}d9_jc3b=ljqQBf- z{OF}wiRDdOYchG03&Y(}ygVcL(WFPe&t&pZci0?Z`^Gx*yY!}h|2aUQg z>?{rr`u^mINm3%6eh#w%symC9!7v*tG&a6iubm=YSbX13bgXgCKBY1zG0xEMbdQc~ zpbpjFtKDr|wyZcChHnB-FtZ*aGx(5(hMvN}z(pVb^SleSZQZ(R!nqwtGwpA7Nbk8I zc(yn=Sof{&g)EUuJuc}rYUt~|gkv(>^j$XlL6mt-*F}X8V0+6WH&b?d!QvYK|N8OK zb({OFzR!ug+O16{+k+r@1-Oq|3!=2&;Dtga^F^%5c?(i&I(@g9y=u~3HU7i|`xFGv z=ZA(4#+saWPJ4g<(!a*XHyfm>eaQR?AiGG)eUr(Yn)2+JOyI>vB zMdP%=#-8HH&dApEX4S2!R8tu`kvcHD1!*;vtq}u$Rb@iSh%1$L7jn6u>zl2&5Hu%~ zH-=%@iE|H`)FQZ+%FUPq#(=zFxa8cP!`rs~R^M#@zuLRI7oBcwyce^_fb~@RSFfxd zJ_Q^BKDQv)6)s!$alWr_&O8UdWFf2^mEKzO@RP!ZxcnKE!;WB6@8RJ|M=yEvS8IQN z!(6(ky@k(TbMu3`x None: + """TODO""" + # initialize program (optionally ground) + self.grounder = Grounder(prog) if grounder is None else grounder + self.prog = self.grounder.ground() if not prog.ground or always_ground else prog + self.certain_literals = self.grounder.certain_literals + + # initialize reasoning graph + self.rg = ReasoningGraph(self.prog, self.certain_literals) + + # initialize solver + self.solver = Solver() + self.npp_cfg_dict = dict() + + self.num_phases = num_phases + + if rank < 0 or world_size < 1 or rank >= world_size: + raise ValueError(f"Invalid rank '{rank}' or world size '{world_size}'.") + self.rank = rank + self.world_size = world_size + + # TODO: verify group + if world_size > 1 and group is None: + # create group with all processes + group = dist.new_group(list(range(self.world_size))) + self.group = group + + self.device = device + + @classmethod + def from_string( + cls, + prog_str: str, + ground: bool = True, + num_phases: int = 1, + rank: int = 0, + world_size: int = 1, + group: Optional["ProcessGroup"] = None, + device: Optional[torch.device] = None, + ) -> "ASN": + """TODO""" + # initialize program from string + return ASN( + Program.from_string(prog_str), + ground, + num_phases=num_phases, + rank=rank, + world_size=world_size, + group=group, + device=device, + ) + + def configure_NPPs(self, npp_cfg_dict: Dict[NPPRule, NPPContext]) -> None: + """TODO""" + self.npp_cfg_dict.update(npp_cfg_dict) + + def npp_forward( + self, npp_data: Dict[str, Tuple[Any]] + ) -> Dict["NPPRule", NPPContext]: + """TODO""" + return { + rule: NPPContext(npp_cfg["model"](*npp_data[rule])) + for rule, npp_cfg in self.npp_cfg_dict.items() + } + + def encode_queries(self, queries: Optional[List[Constraint]]) -> ReasoningGraph: + if queries is None: + queries = [] + + # copy original reasoning graph to encode queries + rg = deepcopy(self.rg) if queries else self.rg + + # encode all queries + for query in queries: + rg.encode_query(query) + + return rg + + def prepare_block( + self, + queries: Optional[List[Constraint]] = None, + rg: Optional[ReasoningGraph] = None, + device: Optional[torch.device] = None, + phase: int = 0, + ) -> GraphBlock: + """TODO""" + if queries is None: + queries = [] + + if rg is None: + # copy original reasoning graph to encode queries + rg = deepcopy(self.rg) if queries else self.rg + + # encode all queries + for query in queries: + rg.encode_query(query) + + if device is None: + device = self.device + + # get node ids for all query sinks (indicating SAT) + # NOTE: uses global sink in case no queries are specified + query_sinks = ( + torch.tensor(rg.query_sinks, device=device) + if queries + else torch.tensor([0], device=device) + ) + + # get powerset of choices for all non-deterministic rules + powerset_dict = { + rule: rule.powerset() + for rule in chain( + rg.npp_edges.keys(), + rg.choice_edges.keys(), + ) + } + + # 
total number of non-deterministic choice/outcome combinations + total_combinations = math.prod( + [len(powerset) for powerset in powerset_dict.values()] + ) + + # divide total number of combinations by world size + n_per_block, remainder = divmod( + total_combinations, self.world_size * self.num_phases + ) + + # block id + block_id = self.rank * self.num_phases + phase + + # start index for combinations + start_index = n_per_block * block_id + + # spread out remainder across chunks + if block_id < remainder: + start_index += self.rank + else: + start_index += remainder + + # end index for combinations (excluding) + end_index = start_index + n_per_block + + # spread out remainder across chunks + if block_id < remainder: + end_index += 1 + + combination_bounds = (start_index, end_index) + + # number of combinations in chunk + n_combinations = combination_bounds[1] - combination_bounds[0] + + if total_combinations == 0: + # deterministic program + total_combinations = 1 + + # create pyg "batch" from abstract reasoning graph + batch = rg.to_pyg(device=device, hard=True, copies=n_combinations) + + # set all choice edges to zero initially + # TODO: do not do this (no way to recover original sign for activation) + for edges in chain(rg.npp_edges.values(), rg.choice_edges.values()): + for _, edge_type, edge_id in edges: + # NOTE: zeros entries for all copies of the same edge! + # TODO: initialize directly to zero in RG to avoid doing it here + batch[edge_type].edge_weight[edge_id] = 0 + + # initialize NPP contexts + npp_choices_dict = {rule: [] for rule in rg.npp_edges} + + # set choices + for i, powerset_choices in enumerate( + torch.cartesian_prod( + *[torch.arange(len(powerset)) for powerset in powerset_dict.values()] + )[combination_bounds[0] : combination_bounds[1]] + ): + # set choices + for (rule, edges), powerset_choice in zip( + chain(rg.npp_edges.items(), rg.choice_edges.items()), + powerset_choices, + ): + # get selected powerset + choices = powerset_dict[rule][powerset_choice] + + # enable edges + for c in choices: + _, edge_type, edge_id = edges[c] + batch[edge_type].edge_weight[edge_id][i] = 1 + + if isinstance(rule, NPPRule): + npp_choices_dict[rule].append(choices) + + # initialize batch + batch = condense_edges_pyg(batch, device=device) + + # broadcast ids for certain atoms across each graph + certain_atom_ids = torch.tensor(rg.certain_atom_ids, device=device) + query_sinks_batch = query_sinks + + graph_block = GraphBlock( + dict(batch.node_items()), + dict(batch.edge_items()), + torch.cat( + [ + torch.tensor(choices, device=device) + for choices in npp_choices_dict.values() + ], + dim=-1, + ) + if npp_choices_dict + else torch.empty(n_combinations, 0, device=device), + certain_atom_ids=certain_atom_ids, + sink_ids=query_sinks_batch, + ) + + return graph_block + + def solve(self, graph_block: GraphBlock, max_iter: int = -1) -> GraphBlock: + return self.solver.solve(graph_block, max_iter) + + def zero_grad(self) -> None: + """TODO""" + for npp_cfg in self.npp_cfg_dict.values(): + if "optimizer" in npp_cfg and npp_cfg["optimizer"] is not None: + npp_cfg["optimizer"].zero_grad() + + def step(self) -> None: + """TODO""" + for npp_cfg in self.npp_cfg_dict.values(): + if "optimizer" in npp_cfg and npp_cfg["optimizer"] is not None: + npp_cfg["optimizer"].step() + + def get_answer_sets( + self, + queries: Optional[List[Constraint]] = None, + npp_data: Optional[Dict[str, Tuple[Any]]] = None, + device: Optional[torch.device] = None, + num_phases: int = 1, + ) -> Union[Dict[Constraint, 
List[Set[PredLiteral]]], List[Set[PredLiteral]]]: + if device is None: + device = self.device + + # NPP forward + npp_ctx_dict = self.npp_forward(npp_data) + + # initialize solving context + solving_ctx = SolvingContext( + len(queries) if queries else 1, + npp_ctx_dict, + ) + + for phase in range(num_phases): + # prepare batch (includes NPP forward) + graph_block = self.prepare_block( + queries=queries, + device=device, + phase=phase, + ) + + # solve graph block + graph_block = self.solve(graph_block) + + # update stable models + solving_ctx.update_SMs(graph_block) + + # synchronize SMs across processes + solving_ctx.synchronize_SMs() + + SM_dict = {} + + if queries is None: + queries = [None] + + for query, query_is_SM in zip(queries, solving_ctx.sm_ctx.is_SM): + # get labels for SMs + SM_dict[query] = [ + { + label + for label, atom in zip(self.rg.node_dict["atom"]["label"], atoms) + if torch.isclose(atom, torch.ones_like(atom)) + } + for atoms in solving_ctx.sm_ctx.atoms[query_is_SM.squeeze(-1)] + ] + + # gather all SMs in main process + if self.world_size > 1: + # main process + if self.rank == 0: + SM_dict_list = [None] * self.world_size + + # receive local SMs from secondary processes + dist.gather_object( + SM_dict, + SM_dict_list, + dst=0, + group=self.group, + ) + # merge local SM dictionaries into single global one + for query in SM_dict.keys(): + SM_dict[query] = sum([d[query] for d in SM_dict_list], []) + + # secondary process + else: + # send local SMs to main process + dist.gather_object( + SM_dict, + None, + dst=0, + group=self.group, + ) + + if queries == [None]: + return SM_dict[None] + + return SM_dict + + +# TODO: warning / error if not all NPPs configured diff --git a/src/asn/data/__init__.py b/src/asn/data/__init__.py new file mode 100644 index 0000000..13ab0d1 --- /dev/null +++ b/src/asn/data/__init__.py @@ -0,0 +1 @@ +from .reasoning_graph import ReasoningGraph diff --git a/src/asn/data/datasets/download_datasets.py b/src/asn/data/datasets/download_datasets.py new file mode 100644 index 0000000..35118f1 --- /dev/null +++ b/src/asn/data/datasets/download_datasets.py @@ -0,0 +1,89 @@ +import os +import tempfile +import urllib.request +import shutil + +from zipfile import ZipFile +import gzip +import errno + + +def mkdir_p(path): + """Linux mkdir -p""" + try: + os.makedirs(path) + except OSError as exc: # Python >2.5 + if exc.errno == errno.EEXIST and os.path.isdir(path): + pass + else: + raise + +def maybe_download(directory, url_base, filename, suffix='.zip'): + ''' + Downloads the specified dataset and extracts it + + @param directory: + @param url_base: URL where to find the file + @param filename: name of the file to be downloaded + @param suffix: suffix of the file + + :returns: true if nothing went wrong downloading + ''' + + filepath = os.path.join(directory, filename) + if os.path.isfile(filepath): + return False + + if not os.path.isdir(directory): + mkdir_p(directory) + + url = url_base +filename + + _, zipped_filepath = tempfile.mkstemp(suffix=suffix) + + print('Downloading {} to {}'.format(url, zipped_filepath)) + + urllib.request.urlretrieve(url, zipped_filepath) + print('{} Bytes'.format(os.path.getsize(zipped_filepath))) + + print('Move to {}'.format(filepath)) + shutil.move(zipped_filepath, filepath) + return True + + +def extract_dataset(directory, filepath, filepath_extracted): + if not os.path.isdir(filepath_extracted): + print('unzip ',filepath, " to", filepath_extracted) + with ZipFile(filepath, 'r') as zipObj: + # Extract all the contents of zip 
file in current directory + zipObj.extractall(directory) + + +def maybe_download_shapeworld4(): + ''' + Downloads the shapeworld4 dataset if it is not downloaded yet + ''' + + directory = "../../data/" + file_name= "shapeworld4.zip" + maybe_download(directory, "https://hessenbox.tu-darmstadt.de/dl/fiEE3hftM4n1gBGn4HJLKUkU/", file_name) + + filepath = os.path.join(directory, file_name) + filepath_extracted = os.path.join(directory,"shapeworld4") + + extract_dataset(directory, filepath, filepath_extracted) + + +def maybe_download_shapeworld_cogent(): + ''' + Downloads the shapeworld4 cogent dataset if it is not downloaded yet + ''' + + directory = "../../data/" + file_name= "shapeworld_cogent.zip" + maybe_download(directory, "https://hessenbox.tu-darmstadt.de/dl/fi3CDjPRsYgAvotHcC8GPaWj/", file_name) + + filepath = os.path.join(directory, file_name) + filepath_extracted = os.path.join(directory,"shapeworld_cogent") + + extract_dataset(directory, filepath, filepath_extracted) diff --git a/src/asn/data/datasets/family_relations.py b/src/asn/data/datasets/family_relations.py new file mode 100644 index 0000000..87b677d --- /dev/null +++ b/src/asn/data/datasets/family_relations.py @@ -0,0 +1,361 @@ + +import torch +from torch.utils.data import Dataset +from tqdm import tqdm +import pandas as pd +import numpy as np + +import csv + +sex_lookup = {} +with open('/workspaces/ASN/data/parent_relation/sex.csv') as f: + r = csv.reader(f, delimiter=',') + header = next(r) + for row in r: + child,sex = row + sex_lookup[child] = sex + + +def create_data_splits(for_asn_training): + + # read in child_dict from csv + child_df = pd.read_csv('/workspaces/ASN/data/parent_relation/celebtrity_dataset_split.csv') + + #read in the data + df = pd.read_csv('/workspaces/ASN/data/parent_relation/celebrity_relations_parent_child_pairs.csv') + + + train_data = [] + val_data = [] + + #includes the pairs from which we have seen the other side during training + parent_child_pairs_seen = [] + child_parent_pairs_seen = [] + + #includes the pairs from which we have not seen the other side during training + parent_child_pairs_unseen = [] + child_parent_pairs_unseen = [] + + #go through all 1513 entries in the dataset + for _, row in df.iterrows(): + #get the parent and child and the relation + child = row['child'] + parent = row['parent'] + relation = row['parent_type'] + + # obtain the reverse relation + if sex_lookup[child] == 'Male': + reverse_relation = 'son' + elif sex_lookup[child] == 'Female': + reverse_relation = 'daughter' + + + if relation == 'father': + sex_parent = 'Male' + elif relation == 'mother': + sex_parent = 'Female' + + + + # Add the (person1, person2, relation) tuple + if child_df.loc[child_df['child']==child]['dataset'].item() == 'train': #for some add both to train + + #flip a coin to decide if we add the reverse relation to the train data. 
We add the other way to the test data + if child_df.loc[child_df['child']==child]['direction'].item(): + train_data.append((parent, child, relation, sex_lookup[child])) + child_parent_pairs_seen.append((child, parent, reverse_relation, sex_parent)) + + if for_asn_training: #if we use SLASH training we obtain the other way around as well + train_data.append((child, parent, reverse_relation, sex_parent)) + else: + train_data.append((child, parent, reverse_relation, sex_parent)) + parent_child_pairs_seen.append((parent, child, relation, sex_lookup[child])) + if for_asn_training: #if we use SLASH training we obtain the other way around as well + train_data.append((parent, child, relation, sex_lookup[child])) + + elif child_df.loc[child_df['child']==child]['dataset'].item() == 'val': #for some add both to train + val_data.append((child, parent, reverse_relation, sex_parent)) + val_data.append((parent, child, relation, sex_lookup[child])) + + elif child_df.loc[child_df['child']==child]['dataset'].item() == 'test':#for some add only child parent to train and test the other way around + parent_child_pairs_unseen.append((parent, child, relation, sex_lookup[child])) + child_parent_pairs_unseen.append((child, parent, reverse_relation, sex_parent)) + + train_data = np.array(train_data) + val_data = np.array(val_data) + child_parent_data_seen= np.array(child_parent_pairs_seen) + parent_child_data_seen = np.array(parent_child_pairs_seen) + child_parent_data_unseen = np.array(child_parent_pairs_unseen) + parent_child_data_unseen = np.array(parent_child_pairs_unseen) + + return train_data, val_data, child_parent_data_seen, parent_child_data_seen, child_parent_data_unseen, parent_child_data_unseen + + +# Custom dataset +class FamilyRelationDataset(Dataset): + """ + Custom dataset for the family relation classification task. It returns a dictionary with the input IDs, attention mask, and labels. + """ + def __init__(self, data, tokenizer,split, max_length=128): + self.data = data + self.tokenizer = tokenizer + self.max_length = max_length + self.split = split + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + person1, person2, relation,_ = self.data[idx] + input_text = f"What is the relation between {person1} and {person2}? 
The answer is:" + + if self.split == 'train': + # Append the relation to the input text for training + full_text = f"{input_text} {relation}" + + # Tokenize the full text including the relation + inputs = self.tokenizer(full_text, return_tensors="pt", max_length=self.max_length, padding="max_length", truncation=True) + input_ids = inputs["input_ids"].squeeze() + attention_mask = inputs["attention_mask"].squeeze() + + # Create labels by shifting the input_ids + labels = input_ids.clone() + + # Set the labels of the input text (not including the relation) to -100 + input_length = len(self.tokenizer(input_text, add_special_tokens=False)['input_ids']) + #print(tokenizer.decode(labels[:input_length+1])) + labels[:input_length+1] = -100 + + + elif self.split == 'test': + # For testing, we don't append the relation to the input + inputs = self.tokenizer(input_text, return_tensors="pt", max_length=self.max_length, padding="max_length", truncation=True) + input_ids = inputs["input_ids"].squeeze() + attention_mask = inputs["attention_mask"].squeeze() + labels = relation # Keep the original label for evaluation + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels + } + +# Custom dataset +class FamilyRelationASNDataset(Dataset): + """ + Custom dataset for the family relation classification task. It returns a dictionary with the input IDs, attention mask, and labels. + """ + def __init__(self, data, tokenizer,split, max_length=128): + self.data = data + self.tokenizer = tokenizer + self.max_length = max_length + self.split = split + self.class_token_ids_relation = torch.tensor([tokenizer.encode(c, add_special_tokens=False)[0] for c in ['mother','father','daughter','son']]) + self.class_token_ids_sex = torch.tensor([tokenizer.encode(c, add_special_tokens=False)[0] for c in ['male','female']]) + + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + person1, person2, relation, sex_p2 = self.data[idx] + input_text_original = f"What is the relation between {person1} and {person2}? The answer is:" + input_text_derived = f"What is the relation between {person2} and {person1}? The answer is:" + input_text_sex_person1 = f"What is the sex of {person1}? The answer is:" + input_text_sex_person2 = f"What is the sex of {person2}? 
The answer is:" + + # # one hot encode the labels using the following encoding: mother, father, daughter, son + # if relation == 'father': + # labels = torch.tensor([1,0,0,0]) + # elif relation == 'mother': + # labels = torch.tensor([0,1,0,0]) + # elif relation == 'daughter': + # labels = torch.tensor([0,0,1,0]) + # elif relation == 'son': + # labels = torch.tensor([0,0,0,1]) + + # Tokenize the full text including the relation + inputs_original = self.tokenizer(input_text_original, return_tensors="pt", max_length=self.max_length, padding="max_length", truncation=True) + input_ids_original = inputs_original["input_ids"].squeeze() + attention_mask_original = inputs_original["attention_mask"].squeeze() + + inputs_derived = self.tokenizer(input_text_derived, return_tensors="pt", max_length=self.max_length, padding="max_length", truncation=True) + input_ids_derived = inputs_derived["input_ids"].squeeze() + attention_mask_derived = inputs_derived["attention_mask"].squeeze() + + inputs_sex_person1 = self.tokenizer(input_text_sex_person1, return_tensors="pt", max_length=self.max_length, padding="max_length", truncation=True) + input_ids_sex_person1 = inputs_sex_person1["input_ids"].squeeze() + attention_inputs_sex_person1 = inputs_sex_person1["attention_mask"].squeeze() + + inputs_sex_person2 = self.tokenizer(input_text_sex_person2, return_tensors="pt", max_length=self.max_length, padding="max_length", truncation=True) + input_ids_sex_person2 = inputs_sex_person2["input_ids"].squeeze() + attention_inputs_sex_person2 = inputs_sex_person2["attention_mask"].squeeze() + + # Create labels by shifting the input_ids + #labels = input_ids.clone() + + # # Set the labels of the input text (not including the relation) to -100 + # input_length = len(self.tokenizer(input_text_child_parent, add_special_tokens=False)['input_ids']) + # labels[:input_length+1] = -100 + + + return { + "relation": relation, + "sex_p2":sex_p2.lower(), + "#npp(relation(p1,p2),[mother,father,daughter,son]) :- person(p1),person(p2),p1!=p2.": { + "input_ids": input_ids_original, + "attention_mask": attention_mask_original, + "classes": self.class_token_ids_relation + }, + '#npp(relation(p2,p1),[mother,father,daughter,son]) :- person(p2),person(p1),p2!=p1.': { + "input_ids":input_ids_derived, + "attention_mask":attention_mask_derived, + "classes": self.class_token_ids_relation + }, + '#npp(sex(p1),[male,female]) :- person(p1).': { + 'input_ids':input_ids_sex_person1, + 'attention_mask':attention_inputs_sex_person1, + 'classes': self.class_token_ids_sex + }, + '#npp(sex(p2),[male,female]) :- person(p2).': { + 'input_ids':input_ids_sex_person2, + 'attention_mask':attention_inputs_sex_person2, + 'classes': self.class_token_ids_sex + } + } + +# Custom dataset +class FamilyRelationASNDataset2(Dataset): + """ + Custom dataset for the family relation classification task. It returns a dictionary with the input IDs, attention mask, and labels. 
+ """ + def __init__(self, data, tokenizer,split, max_length=128): + self.data = data + self.tokenizer = tokenizer + self.max_length = max_length + self.split = split + self.class_token_ids_relation = torch.tensor([tokenizer.encode(c, add_special_tokens=False)[0] for c in ['mother','father','daughter','son']]) + self.class_token_ids_sex = torch.tensor([tokenizer.encode(c, add_special_tokens=False)[0] for c in ['male','female']]) + + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + person1, person2, relation, sex_p2 = self.data[idx] + input_text_original = f"What is the relation between {person1} and {person2}? The answer is:" + input_text_derived = f"What is the relation between {person2} and {person1}? The answer is:" + input_text_sex_person1 = f"What is the sex of {person1}? The answer is:" + input_text_sex_person2 = f"What is the sex of {person2}? The answer is:" + + # # one hot encode the labels using the following encoding: mother, father, daughter, son + # if relation == 'father': + # labels = torch.tensor([1,0,0,0]) + # elif relation == 'mother': + # labels = torch.tensor([0,1,0,0]) + # elif relation == 'daughter': + # labels = torch.tensor([0,0,1,0]) + # elif relation == 'son': + # labels = torch.tensor([0,0,0,1]) + + # Tokenize the full text including the relation + inputs_original = self.tokenizer(input_text_original, return_tensors="pt", max_length=self.max_length, padding="max_length", truncation=True) + input_ids_original = inputs_original["input_ids"].squeeze() + attention_mask_original = inputs_original["attention_mask"].squeeze() + + inputs_derived = self.tokenizer(input_text_derived, return_tensors="pt", max_length=self.max_length, padding="max_length", truncation=True) + input_ids_derived = inputs_derived["input_ids"].squeeze() + attention_mask_derived = inputs_derived["attention_mask"].squeeze() + + inputs_sex_person1 = self.tokenizer(input_text_sex_person1, return_tensors="pt", max_length=self.max_length, padding="max_length", truncation=True) + input_ids_sex_person1 = inputs_sex_person1["input_ids"].squeeze() + attention_inputs_sex_person1 = inputs_sex_person1["attention_mask"].squeeze() + + inputs_sex_person2 = self.tokenizer(input_text_sex_person2, return_tensors="pt", max_length=self.max_length, padding="max_length", truncation=True) + input_ids_sex_person2 = inputs_sex_person2["input_ids"].squeeze() + attention_inputs_sex_person2 = inputs_sex_person2["attention_mask"].squeeze() + + # Create labels by shifting the input_ids + #labels = input_ids.clone() + + # # Set the labels of the input text (not including the relation) to -100 + # input_length = len(self.tokenizer(input_text_child_parent, add_special_tokens=False)['input_ids']) + # labels[:input_length+1] = -100 + + + return { + "relation": relation, + "sex_p2":sex_p2.lower(), + "#npp(relation(p1,p2),[mother,father,daughter,son]) :- person(p1),person(p2),p1!=p2.": { + "input_ids": input_ids_original, + "attention_mask": attention_mask_original, + "classes": self.class_token_ids_relation + }, + '#npp(relation(p2,p1),[mother,father,daughter,son]) :- person(p2),person(p1),p2!=p1.': { + "input_ids":input_ids_derived, + "attention_mask":attention_mask_derived, + "classes": self.class_token_ids_relation + }, + '#npp(sex(p1),[male,female]) :- person(p1).': { + 'input_ids':input_ids_sex_person1, + 'attention_mask':attention_inputs_sex_person1, + 'classes': self.class_token_ids_sex + }, + '#npp(sex(p2),[male,female]) :- person(p2).': { + 'input_ids':input_ids_sex_person2, + 
'attention_mask':attention_inputs_sex_person2, + 'classes': self.class_token_ids_sex + } + } + +from sklearn.metrics import accuracy_score, precision_recall_fscore_support + +def evaluate_model(model, test_loader, tokenizer, device, batch_size=8, print_preds=False): + """ + Evaluate the model on the test set and return the accuracy, precision, recall, and F1 score. + """ + model.eval() + + relations = ["mother", "father", "daughter", "son"] + # Define the relations we want to classify + relation_token_ids = [tokenizer.encode(rel, add_special_tokens=False)[0] for rel in relations] + + all_preds = [] + all_labels = [] + + with torch.no_grad(): + for batch in tqdm(test_loader, desc="Evaluating"): + input_ids = batch["input_ids"].to(device) + attention_mask = batch["attention_mask"].to(device) + + outputs = model(input_ids=input_ids, attention_mask=attention_mask) + next_token_pos = ((input_ids == 2).nonzero(as_tuple=True))[1][0] + logits = outputs.logits[:, next_token_pos, :] # Get logits for the last token -> get the logits at the end of the input text + + # Only consider logits for the specified relation tokens + relation_logits = logits[:, relation_token_ids] + predicted_relation_indices = torch.argmax(relation_logits, dim=1) + predictions = [relations[i] for i in predicted_relation_indices.cpu().numpy()] + all_preds.extend(predictions) + all_labels.extend(batch["labels"]) + + # Calculate metrics + accuracy = accuracy_score(all_labels, all_preds) + precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='weighted', zero_division=0) + + if print_preds: + print("ALL LABELS", all_labels) + print("ALL PREDS ", all_preds) + # Per-class metrics + class_precision, class_recall, class_f1, _ = precision_recall_fscore_support(all_labels, all_preds, zero_division=0, average=None) + + results = { + "accuracy": accuracy, + "precision": precision, + "recall": recall, + "f1": f1, + "per_class": {rel: {"precision": p, "recall": r, "f1": f} + for rel, p, r, f in zip(relations, class_precision, class_recall, class_f1)} + } + + return results \ No newline at end of file diff --git a/src/asn/data/datasets/mnist_addition.py b/src/asn/data/datasets/mnist_addition.py new file mode 100644 index 0000000..6234203 --- /dev/null +++ b/src/asn/data/datasets/mnist_addition.py @@ -0,0 +1,85 @@ +from functools import reduce +from typing import Callable, Iterable, List, Optional, Tuple + +import torch +from ground_slash.program import Constraint, Naf, Number, PredLiteral, SymbolicConstant +from torch.utils.data import Dataset, Subset +from torchvision.datasets import MNIST + + +class MNISTAddition(Dataset): + """TODO""" + + def __init__( + self, + n: int, + root: str, + train: bool = True, + transform: Optional[Callable] = None, + download: bool = False, + digits: Optional[Iterable[int]] = None, + seed: int = None, + ) -> None: + """TODO""" + self.n = n + self.transform = transform + self.train = train + self.root = root + + # get regular MNIST dataset + self.mnist = MNIST( + root=root, train=train, transform=transform, download=download + ) + + if digits is not None: + self.digits = set(digits) if not isinstance(digits, set) else digits + self.mnist = Subset( + self.mnist, + torch.where( + reduce( + torch.logical_or, + [self.mnist.targets == digit for digit in digits], + ) + )[0], + ) + else: + self.digits = set(range(10)) + + if seed is not None: + torch.manual_seed(seed) + + self.data = [] + + for ids in torch.split(torch.randperm(len(self.mnist)), self.n): + # chunk is not complete + 
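# NOTE: torch.split may leave a final chunk smaller than n (e.g. 10 images with n=3 give chunks of sizes 3,3,3,1); such a leftover chunk is skipped below +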
if len(ids) != self.n: + continue + + x, y = tuple(zip(*tuple(self.mnist[i] for i in ids))) + + self.data.append( + ( + x, + sum(y), + ) + ) + + def __getitem__(self, index: int) -> Tuple[Tuple[torch.Tensor, ...], int]: + return self.data[index] + + def __len__(self) -> int: + return len(self.data) + + def to_queries(self, y: Iterable[int]) -> List[Constraint]: + return [ + Constraint( + Naf( + PredLiteral( + "addition", + *tuple(SymbolicConstant(f"i{i+1}") for i in range(self.n)), + Number(y_i.item()), + ) + ) + ) + for y_i in y + ] diff --git a/src/asn/data/datasets/shapeworld4.py b/src/asn/data/datasets/shapeworld4.py new file mode 100644 index 0000000..e638603 --- /dev/null +++ b/src/asn/data/datasets/shapeworld4.py @@ -0,0 +1,215 @@ +import torch +import torchvision +from torch.utils.data import Dataset +from torchvision.transforms import transforms + +from torch.utils.data import Dataset +from torchvision import transforms +from skimage import io +import os +import numpy as np +import torch +from PIL import ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True +import json +from asn.data.datasets.download_datasets import maybe_download_shapeworld4 + +from functools import reduce +from typing import Callable, Iterable, List, Optional, Tuple +from ground_slash.program import Constraint, Naf, Number, PredLiteral, SymbolicConstant +from asn.data.expression import ComplexQuery + +def get_encoding(color, shape, shade, size): + + if color == 'red': + col_enc = [1,0,0,0,0,0,0,0,0] + elif color == 'blue': + col_enc = [0,1,0,0,0,0,0,0,0] + elif color == 'green': + col_enc = [0,0,1,0,0,0,0,0,0] + elif color == 'gray': + col_enc = [0,0,0,1,0,0,0,0,0] + elif color == 'brown': + col_enc = [0,0,0,0,1,0,0,0,0] + elif color == 'magenta': + col_enc = [0,0,0,0,0,1,0,0,0] + elif color == 'cyan': + col_enc = [0,0,0,0,0,0,1,0,0] + elif color == 'yellow': + col_enc = [0,0,0,0,0,0,0,1,0] + elif color == 'black': + col_enc = [0,0,0,0,0,0,0,0,1] + + + if shape == 'circle': + shape_enc = [1,0,0,0] + elif shape == 'triangle': + shape_enc = [0,1,0,0] + elif shape == 'square': + shape_enc = [0,0,1,0] + elif shape == 'bg': + shape_enc = [0,0,0,1] + + if shade == 'bright': + shade_enc = [1,0,0] + elif shade =='dark': + shade_enc = [0,1,0] + elif shade == 'bg': + shade_enc = [0,0,1] + + + if size == 'small': + size_enc = [1,0,0] + elif size == 'big': + size_enc = [0,1,0] + elif size == 'bg': + size_enc = [0,0,1] + + return col_enc + shape_enc + shade_enc + size_enc + [1] + + +class ShapeWorld4(Dataset): + def __init__(self, root, mode, ret_obj_encoding=False): + + maybe_download_shapeworld4() + + self.ret_obj_encoding = ret_obj_encoding + self.root = root + self.mode = mode + assert os.path.exists(root), 'Path {} does not exist'.format(root) + + #dictionary of the form {'image_idx':'img_path'} + self.img_paths = {} + + + for file in os.scandir(os.path.join(root, 'images', mode)): + img_path = file.path + + img_path_idx = img_path.split("/") + img_path_idx = img_path_idx[-1] + img_path_idx = img_path_idx[:-4][6:] + try: + img_path_idx = int(img_path_idx) + self.img_paths[img_path_idx] = img_path + except: + print("path:",img_path_idx) + + + count = 0 + + #target maps of the form {'target:idx': query string} or {'target:idx': obj encoding} + self.query_map = {} + self.obj_map = {} + + with open(os.path.join(root, 'labels', mode,"world_model.json")) as f: + worlds = json.load(f) + + #iterate over all objects + for world in worlds: + num_objects = 0 + target_query = [] + obj_enc = [] + for entity in world['entities']: + + 
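# NOTE: for every object in the world, build a ':- not object(name,color,shape,shade,size).' query constraint and a one-hot feature encoding +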
color = entity['color']['name'] + shape = entity['shape']['name'] + + shade_val = entity['color']['shade'] + if shade_val == 0.0: + shade = 'bright' + else: + shade = 'dark' + + size_val = entity['shape']['size']['x'] + if size_val == 0.075: + size = 'small' + elif size_val == 0.15: + size = 'big' + + name = 'o' + str(num_objects+1) + + q = Constraint( + Naf( + PredLiteral( + "object", + *tuple([SymbolicConstant(name), + SymbolicConstant(color), + SymbolicConstant(shape), + SymbolicConstant(shade), + SymbolicConstant(size)]), + ) + ) + ) + + #target_query = target_query+ ":- not object({},{},{},{},{}). ".format(name, color, shape, shade, size) + target_query.append(q) + obj_enc.append(get_encoding(color, shape, shade, size)) + num_objects += 1 + + #bg encodings + for i in range(num_objects, 4): + name = 'o' + str(num_objects+1) + + q = Constraint( + Naf( + PredLiteral( + "object", + *tuple([SymbolicConstant(name), + SymbolicConstant(f'black'), + SymbolicConstant(f'bg'), + SymbolicConstant(f'bg'), + SymbolicConstant(f'bg')]), + ) + ) + ) + target_query.append(q) + # target_query = target_query+ ":- not object({},black,bg, bg, bg). ".format(name) + obj_enc.append(get_encoding("black","bg","bg","bg")) + num_objects += 1 + + target_query = ComplexQuery(*target_query) + + self.query_map[count] = target_query + self.obj_map[count] = np.array(obj_enc) + count+=1 + + + + def __getitem__(self, index): + + #get the image + img_path = self.img_paths[index] + img = io.imread(img_path)[:, :, :3] + + transform = transforms.Compose([ + transforms.ToPILImage(), + transforms.ToTensor(), + ]) + img = transform(img) + img = (img - 0.5) * 2.0 # Rescale to [-1, 1]. + + if self.ret_obj_encoding: + return {'im':img}, self.query_map[index] ,self.obj_map[index] + else: + return {'im':img}, self.query_map[index] + + def __len__(self): + return len(self.img_paths) + + + # def to_queries(self, y: Iterable[int]) -> List[Constraint]: + + + + # return [ + # Constraint( + # Naf( + # PredLiteral( + # "object", + # *tuple(SymbolicConstant(f"i{i+1}") for i in range(self.n)), + # Number(y_i.item()), + # ) + # ) + # ) + # for y_i in y + # ] diff --git a/src/asn/data/expression.py b/src/asn/data/expression.py new file mode 100644 index 0000000..2276f52 --- /dev/null +++ b/src/asn/data/expression.py @@ -0,0 +1,65 @@ +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Set, Union + +from ground_slash.program import Expr, LiteralCollection, Naf + +if TYPE_CHECKING: + from ground_slash.program import Constraint, Query, Statement, Term, Variable + from ground_slash.progrm.safety_characterization import SafetyTriplet + + +class Disjunction(LiteralCollection): + # TODO: generalize to Literals as well as Disjunctions/Disjunctions + def __hash__(self) -> int: + return hash(("disjunction", frozenset(self.literals))) + + def __eq__(self, other: "Any") -> bool: + return ( + isinstance(other, Disjunction) + and len(self) == len(other) + and frozenset(self.literals) == frozenset(other.literals) + ) + + def as_conjunction(self) -> "Conjunction": + return Conjunction(Naf(literal, ~literal.naf) for literal in self) + + +class Conjunction(LiteralCollection): + # TODO: generalize to Literals as well as Disjunctions/Disjunctions + def __hash__(self) -> int: + return hash(("conjunction", frozenset(self.literals))) + + def __eq__(self, other: "Any") -> bool: + return ( + isinstance(other, Conjunction) + and len(self) == len(other) + and frozenset(self.literals) == frozenset(other.literals) + ) + + +class ComplexQuery(Expr): + def 
__init__(self, *constraints: Iterable["Constraint"]) -> None: + if len(constraints) < 2: + raise ValueError(f"Complex query must containt at least 2 constraints, but got {len(constraints)}.") + + self.constraints = tuple(constraints) + + def __str__(self) -> str: + return "\n".join([str(constr) for constr in self.constraints]) + + def __eq__(self, other: "ComplexQuery") -> bool: + return isinstance(other, ComplexQuery) and (frozenset(self.constraints) == frozenset(other.constraints)) + + def __hash__(self) -> int: + return hash(frozenset(self.constraints)) + + def vars(self) -> Set["Variable"]: + raise NotImplementedError() + + def global_vars(self) -> Set["Variable"]: + raise NotImplementedError() + + def safety(self, statement: Optional[Union["Statement", "Query"]]=None) -> "SafetyTriplet": + raise NotImplementedError() + + def substitute(self, subst: Dict[str, "Term"]) -> "Expr": + raise NotImplementedError() \ No newline at end of file diff --git a/src/asn/data/reasoning_graph.py b/src/asn/data/reasoning_graph.py new file mode 100644 index 0000000..d4af1bd --- /dev/null +++ b/src/asn/data/reasoning_graph.py @@ -0,0 +1,1227 @@ +import io +import itertools +from collections import defaultdict +from copy import deepcopy +from math import isfinite +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union + +import pygraphviz as pgv +import torch +from ground_slash.program import ( + NPP, + AggrCount, + AggrElement, + AggrLiteral, + BuiltinLiteral, + Choice, + ChoiceRule, + Constraint, + DisjunctiveRule, + Expr, + FalseConstant, + Guard, + Infimum, + Literal, + LiteralCollection, + Naf, + Neg, + NormalRule, + NPPRule, + Number, + PredLiteral, + Program, + Statement, + TermTuple, + TrueConstant, +) +from PIL import Image +from torch_geometric.data import HeteroData + +# TODO: refactor +from asn.utils import relop_dict +from asn.utils.collections import get_minimal_collections + +from .expression import ComplexQuery, Conjunction, Disjunction + + +class ReasoningGraph: + def __init__( + self, + prog: Program, + certain_atoms: Optional[Set[PredLiteral]] = None, + ) -> None: + """ + Args: + prog: SLASH `Program` instance. + certain_atoms: optional set of atoms (`PredLiteral` instances) whose nodes + are initialized to `True`. Can be used to reduce the number of + iterations during solving. + + Raises: + TODO: + """ + # ---------- init graph ---------- + self.node_types = ("atom", "disj", "conj", "count", "sum", "min", "max") + + if certain_atoms is None: + certain_atoms = set() + + # list of ids for query-specific SAT nodes (i.e. 'False') + self.query_sinks = [] + + # map for some specific unicode symbols + self.__unicode_symbols = { + "true": "\u22a5", + "false": "\u22a4", + "disj": "\u2228", + "conj": "\u2227", + "neq": "\u2260", + "leq": "\u2264", + "geq": "\u2265", + } + + # node & edge dictionaries + self.node_dict: Dict[str, Dict[str, List]] = defaultdict( + lambda: defaultdict(list) + ) + self.edge_dict: Dict[Tuple[str], Dict[str, List]] = defaultdict( + lambda: defaultdict(list) + ) + + # dictionaries mapping ASP constructs to node ids + self.node_id_dict: Dict[Expr, Tuple[str, int]] = dict() + self.edge_id_dict: Dict[Tuple[Expr, Expr], int] = dict() + + # dictionaries mapping conj. ids to list of edges for choices/disjs. 
& NPPs + # TODO: typing + self.choice_edges: Dict[ + Union[Choice, LiteralCollection], List[Tuple[Literal, int, int]] + ] = dict() + # TODO: typing + self.npp_edges: Dict[NPP, List[Tuple[Literal, int, int]]] = dict() + + # set of `Choice` & `NPP` instances that are already incorporated in the graph + # (to avoid duplicate encodings) + self.choices: Set[Choice, LiteralCollection] = set() + self.npps: Set[NPP] = set() + + # create literals for constants 'True' and 'False' + self.true_const = TrueConstant() + self.false_const = FalseConstant() + + # initialize constant nodes for 'True' and 'False' + # 'True' represented as conj. node (no inputs result in True) + self.add_node( + self.true_const, + "conj", + self.__unicode_symbols["true"], + x=1.0, + ) + # 'False' represented as disj. node (no inputs result in False) + self.add_node( + self.false_const, + "disj", + self.__unicode_symbols["false"], + ) + + # ---------- process program ---------- + for stmt in prog.statements: + self.encode_statement(stmt, certain_atoms) + + # map certain atoms to their node ids + self.certain_atom_ids = [self.node_id_dict[atom][1] for atom in certain_atoms] + + def encode_statement( + self, + statement: Statement, + certain_atoms: Optional[Set[PredLiteral]] = None, + ) -> None: + if certain_atoms is None: + certain_atoms = set() + + # check if statement is ground + if not statement.ground: + raise ValueError(f"Statement {str(statement)} is not ground.") + + # --------------- process body --------------- + + if any( + isinstance(literal, BuiltinLiteral) and not literal.eval() + for literal in statement.body + ): + # false built-in literal (i.e., body never satisfied) + # no need to process rule + return + + body_literals = [] + body_literal_signs = [] + + # pre-process body literals + for literal in statement.body: + # encode literal (if not exists) + sign = self.encode_literal(literal, certain_atoms) + + # predicate or aggregate literal + if sign != 0: + body_literals.append(literal) + body_literal_signs.append(sign) + # removes built-in literals from body + # we already know that these evaluate to 'True' + + # fact + if not body_literals: + body_literals.append(self.true_const) + body_literal_signs.append(1) + + # TODO: better way? 
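+ # NOTE: the `Conjunction` wrapper provides an order-independent, hashable key so the body can be looked up in `node_id_dict`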
+ body_literals = Conjunction(*body_literals) + + # single body literal + if len(body_literals) == 1: + # use literal directly + body_key = abs(body_literals[0]) + body_sign = body_literal_signs[0] + # conjunction of body literals + else: + body_key = body_literals + + # connect body literals to a conjunction node (if not exists) + if body_key not in self.node_id_dict: + self.add_node( + body_key, + "conj", + f"{self.__unicode_symbols['conj']}_{{{len(self.node_dict['conj']['x'])-1}}}", + ) + + for literal, sign in zip(body_literals, body_literal_signs): + self.add_edge( + abs(literal), + body_key, + edge_weight=float(sign), + ) + + body_sign = 1 + + # --------------- process head --------------- + + consequents = defaultdict(list) + + if isinstance(statement, ChoiceRule): + # choice rule + choice = statement.head + + for element in choice.elements: + consequents[element.atom].append(Conjunction(*element.literals)) + elif isinstance(statement, NPPRule): + # NPP rule + choice = statement.npp.as_choice() + + for element in choice.elements: + consequents[element.atom].append(Conjunction(*element.literals)) + else: + # normal/disjunctive rules + for atom in statement.head: + consequents[atom].append(Conjunction()) + + # normal/disjunctive rules + if not consequents: + consequents[self.false_const].append(Conjunction()) + + # dictionary to store the edges corresponding to a choice/disjunction + # NOTE: used later if statement actually non-deterministic + choice_edges = list() + # TODO + cond_map = dict() + + # iterate over all consequent literals + for i, (cond_literal, conditions) in enumerate(consequents.items()): + # if 'cond_literal' is not a query sink (i.e., Constraint expression) + if not isinstance(cond_literal, (Constraint, ComplexQuery)): + # encode or update literal node + # NOTE: is always positive due to language specifications + self.encode_literal(cond_literal, certain_atoms) + + # ----- process conditions ----- + + literal_conditions = [] + + # pre-process conditions + # remove builtin-literals and check their satisfiability + for cond in conditions: + literals = [] + + for literal in cond: + if isinstance(literal, BuiltinLiteral): + if not literal.eval(): + # condition can never be satisfied (remove) + break + + literals.append(literal) + else: + # condition can be satisfied (keep) + literal_conditions.append(Conjunction(*literals)) + + # get minimal conditions + # (supersets irrelevant if a subset already satisfies condition) + minimal_cond_candidates = get_minimal_collections(*literal_conditions) + + # list of tuples containing the keys to different conditions and their sign + cond_keys = [] + # list of aggregate elements for a choice aggregate + # NOTE: used later if statement is a choice rule + count_elements = [] + + # process final conditions + for cond in minimal_cond_candidates: + # encode literals (if not already) + # NOTE: do NOT need to check for existence of literals from here on out + for literal in cond: + self.encode_literal(literal, certain_atoms) + + # save aggregate element for choice rule/NPP + count_elements.append( + AggrElement( + TermTuple(Number(i)), + LiteralCollection(cond_literal, *cond), + ) + ) + + # empty condition (unconditional) + if len(cond) == 0: + continue + # single condition literal (use directly) + elif len(cond) == 1: + # sign depends on literal since we its node directly + cond_keys.append((abs(cond[0]), -1 if cond[0].naf else 1)) + # multiple condition literals (combine in conjunction) + else: + if cond not in self.node_id_dict: + # 
create new conj. node + self.add_node( + cond, + "conj", + f"{self.__unicode_symbols['conj']}_{{{len(self.node_dict['conj']['x'])-1}}}", + ) + + # add edges from literals to conj. + for literal in cond: + self.add_edge( + abs(literal), cond, edge_weight=-1 if literal.naf else 1 + ) + + # sign is always positive + cond_keys.append((cond, 1)) + + # no condition (unconditional) + if len(cond_keys) == 0: + # NOTE: even if body is just 'True', we need the edge here + choice_edges.append( + ( + cond_literal, + *self.add_edge(body_key, cond_literal, edge_weight=body_sign), + ) + ) + + cond_map[cond_literal] = None + + # single condition (use directly) + elif len(cond_keys) == 1: + # NOTE: condition already encoded as a conjunction (no need to check) + cond_key, sign = cond_keys[0] + conj_key = ( + Conjunction(*body_key, cond_key) + if isinstance(body_key, LiteralCollection) + else Conjunction(body_key, cond_key) + ) + + if conj_key not in self.node_id_dict: + self.add_node( + conj_key, + "conj", + f"{self.__unicode_symbols['conj']}_{{{len(self.node_dict['conj']['x'])-1}}}", + ) + + if body_key != self.true_const: + self.add_edge( + body_key, + conj_key, + edge_weight=body_sign, + ) + + self.add_edge(cond_key, conj_key, edge_weight=float(sign)) + + choice_edges.append( + ( + cond_literal, + *self.add_edge( + conj_key, + cond_literal, + ), + ) + ) + + cond_map[cond_literal] = conj_key + + # multiple conditions (combine in disjunction) + else: + disj_key = Disjunction(cond for cond, _ in cond_keys) + conj_key = ( + Conjunction(*body_key, disj_key) + if isinstance(body_key, LiteralCollection) + else Conjunction(body_key, disj_key) + ) + + if conj_key not in self.node_id_dict: + self.add_node( + conj_key, + "conj", + f"{self.__unicode_symbols['conj']}_{{{len(self.node_dict['conj']['x'])-1}}}", + ) + + if disj_key not in self.node_id_dict: + self.add_node( + disj_key, + "disj", + f"{self.__unicode_symbols['disj']}_{{{len(self.node_dict['disj']['x'])-1}}}", + ) + + for cond, sign in cond_keys: + self.add_edge( + cond, + disj_key, + edge_weight=sign, + ) + + self.add_edge( + disj_key, + conj_key, + ) + + if body_key != self.true_const: + self.add_edge( + body_key, + conj_key, + edge_weight=body_sign, + ) + + choice_edges.append( + ( + cond_literal, + *self.add_edge( + conj_key, + cond_literal, + ), + ) + ) + + cond_map[cond_literal] = conj_key + + if isinstance(statement, (NormalRule, Constraint)): + # done + return + + if isinstance(statement, DisjunctiveRule): + head_literals = statement.head + + # constraint that is active ONLY if the rules body is satisfied + # AND NONE of the head literals is satisfied + constr_key = Conjunction( + *[Naf(deepcopy(atom), True) for atom in head_literals], *body_literals + ) + + if constr_key not in self.node_id_dict: + self.add_node( + constr_key, + "conj", + rf"{self.__unicode_symbols['conj']}_{{{len(self.node_dict['conj']['x'])-1}}}", + ) + + self.add_edge(body_key, constr_key, edge_weight=body_sign) + + for literal in head_literals: + self.add_edge( + literal, + constr_key, + edge_weight=-1.0, + ) + + self.add_edge( + constr_key, + self.false_const, + ) + + if statement not in self.choices: + # TODO: necessary ??? + self.choices.add(statement) + # TODO: best way to store choices? 
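+                # the stored (head literal, edge type, edge id) triples identify
+                # the non-deterministic edges introduced by this statement; they
+                # are, e.g., rendered as dashed edges in 'to_graphviz'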
+ self.choice_edges[statement] = choice_edges + + elif isinstance(statement, (ChoiceRule, NPPRule)): + choice_aggr = AggrLiteral( + AggrCount(), + tuple(count_elements), + choice.guards, + ) + + # TODO: guard encoding + + # encode choice aggregate + if choice_aggr not in self.node_id_dict: + self.add_node( + choice_aggr, + "count", + f"\#count_{{{len(self.node_dict['count']['x'])}}}", + guards=tuple(self.encode_guards(choice_aggr.guards)), + ) + + for literal, cond_key in cond_map.items(): + if cond_key is None: + self.add_edge( + literal, + choice_aggr, + ) + else: + conj_key = Conjunction(literal, cond_key) + + if conj_key not in self.node_id_dict: + self.add_node( + conj_key, + "conj", + f"{self.__unicode_symbols['conj']}_{{{len(self.node_dict['conj']['x'])-1}}}", + ) + + self.add_edge( + cond_key, + conj_key, + ) + self.add_edge( + literal, + conj_key, + ) + + self.add_edge( + conj_key, + choice_aggr, + ) + + # add choice constraint + constr_key = Conjunction(*body_literals, Naf(choice_aggr)) + + if constr_key not in self.node_id_dict: + self.add_node( + constr_key, + "conj", + f"{self.__unicode_symbols['conj']}_{{{len(self.node_dict['conj']['x'])-1}}}", + ) + + self.add_edge( + body_key, + constr_key, + edge_weight=body_sign, + ) + self.add_edge( + choice_aggr, + constr_key, + edge_weight=-1, + ) + + self.add_edge(constr_key, self.false_const) + + if isinstance(statement, ChoiceRule) and statement not in self.choices: + # TODO: necessary ??? + self.choices.add(statement) + # TODO: best way to store choices? + self.choice_edges[statement] = choice_edges + + if isinstance(statement, NPPRule) and statement not in self.choices: + # TODO: necessary ??? + self.npps.add(statement) + # TODO: best way to store choices? + self.npp_edges[statement] = choice_edges + + def encode_literal( + self, + literal: Literal, + certain_atoms: Optional[Set[PredLiteral]] = None, + ) -> int: + if certain_atoms is None: + certain_atoms = set() + + if isinstance(literal, BuiltinLiteral): + # nothing to do here + return 0 + elif isinstance(literal, AggrLiteral): + aggr = abs(literal) + + # positive or negative aggregate + self.encode_aggregate(aggr) + else: + atom = abs(literal) + + # initialize probability with 1.0 if atom is certain (i.e., fact) + p = float(atom in certain_atoms) + + # register literal if not exits + try: + # update existing node + # use maximum possible value (a fact is not invalidated by a rule) + node_type, literal_id = self.node_id_dict[atom] + self.node_dict[node_type]["x"][literal_id] = max( + self.node_dict[node_type]["x"][literal_id], p + ) + # update value if it does + except KeyError: + # create new atom node + # since 'False' already registed, we can safely assume that all new head + # literals are atoms + self.add_node( + atom, + "atom", + label=str(atom), + x=p, + ) + + # check if strong negation is also encoded in the graph + neg_atom = Neg(atom, not atom.neg) + + if neg_atom in self.node_id_dict: + # add constraint that both cannot be true at the same time + # NOTE: since 'atom' is just encoded, we know there is no conj. 
yet + + conj_key = Conjunction(atom, neg_atom) + + self.add_node( + conj_key, + "conj", + f"{self.__unicode_symbols['conj']}_{{{len(self.node_dict['conj']['x'])-1}}}", + ) + + self.add_edge( + atom, + conj_key, + ) + self.add_edge( + neg_atom, + conj_key, + ) + self.add_edge(conj_key, self.false_const) + + return -1 if literal.naf else 1 + + def encode_aggregate( + self, + aggr: AggrLiteral, + certain_atoms: Optional[Set[PredLiteral]] = None, + ) -> int: + """Encodes an aggregate in the graph. + + Args: + aggr: `AggrLiteral` instance. + + Raises: + TODO + """ + # TODO: optimize & update to use new 'encode_literal' function! + + if certain_atoms is None: + certain_atoms = set() + + aggr_type = str(aggr.func)[1:] + + if aggr not in self.node_id_dict: + # create new aggregate node + self.add_node( + aggr, + aggr_type, + label=f"{str(aggr.func)}_{{{len(self.node_dict[aggr_type]['x'])-1}}}", + guards=tuple(self.encode_guards(aggr.guards)), + ) + + # ---------- sort elements ---------- + # dictionary mapping a tuple to possible conditions satisfying it + # (multiple possible; only one needs to hold) + cond_dict = defaultdict(lambda: defaultdict(lambda: None)) + + for elem in aggr.elements: + predicate_literals = [] + for literal in elem.literals: + if isinstance(literal, BuiltinLiteral): + if not literal.eval(): + # false built-in literal (i.e., body never satisfied) + break + else: + # keep classical literals + predicate_literals.append(literal) + # run if loop did not break early + else: + # keep track of condition for tuple + literals = Conjunction(*predicate_literals) + cond_dict[elem.terms][literals] + + # register atoms in aggregate element (if not already) + for literal in literals: + self.encode_literal(abs(literal), certain_atoms) + + # unconditional tuples (i.e., always satisfied) + uncond_tuples = [] + + # ---------- process tuples and conditions ---------- + for ( + tup, + cond_candidates, + ) in cond_dict.items(): + # condition always satisfied + # (check here to avoid construction of minimal collections) + if Conjunction() in cond_candidates: + uncond_tuples.append(tup) + continue + + # get minimal conditions + # (supersets irrelevant if a subset already satisfies condition) + minimal_cond_candidates = get_minimal_collections(*cond_candidates) + + # keep track of condition conjunctions + tuple_signature = Disjunction( + Conjunction(*condition) for condition in minimal_cond_candidates + ) + + if tuple_signature not in self.node_id_dict: + # create auxiliary atom representing satisfied tuple + self.add_node( + tuple_signature, + "disj", + label=f"{self.__unicode_symbols['disj']}_{{{len(self.node_dict['disj']['x'])-1}}}", + ) + + # connect auxiliary node to aggregate node + # TODO: better way to compute tuple weight? 
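+            # the edge weight is the aggregate function evaluated on the single
+            # tuple, i.e., 1 for '#count' and the tuple's weight term for '#sum';
+            # summing the weights of active incoming edges in the GNN therefore
+            # reproduces the aggregate value (cf. 'AggrSumUpdater')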
+ self.add_edge( + tuple_signature, aggr, edge_weight=float(aggr.func.eval({tup}).eval()) + ) + + # process tuple conditions + for condition in minimal_cond_candidates: + if len(condition) == 1: + literal = condition[0] + pos_literal = abs(literal) + + if pos_literal not in self.node_id_dict: + self.add_node( + pos_literal, + "atom", + str(pos_literal), + ) + + self.add_edge( + pos_literal, + tuple_signature, + edge_weight=1.0 if not literal.naf else -1.0, + ) + else: + # NOTE: already handled non-conditional tuples earlier + conj_signature = Conjunction(*condition) + + # check if equivalent conjunction exists + if conj_signature not in self.node_id_dict: + # create new conjunction node + self.add_node( + conj_signature, + "conj", + f"{self.__unicode_symbols['conj']}_{{{len(self.node_dict['conj']['x'])-1}}}", + ) + + # connect literals to conjunction node + for literal in condition: + self.add_edge( + abs(literal), + conj_signature, + edge_weight=1.0 if not literal.naf else -1.0, + ) + + self.add_edge( + conj_signature, + tuple_signature, + ) + + # edge from 'True' to aggregate auxiliary atom + # (weight based on aggregate of all certain tuples) + if uncond_tuples: + self.add_edge( + self.true_const, + aggr, + edge_weight=float(aggr.func.eval(set(uncond_tuples)).eval()), + ) + + def encode_query( + self, query: Union[Constraint, ComplexQuery], certain_atoms: Optional[Set[PredLiteral]] = None + ) -> int: + """Adds a query to the reasoning graph. + + Args: + query: `Constraint` instance. + certain_atoms: optional set of atoms (`PredLiteral` instances) whose nodes + are initialized to `True`. Can be used to reduce number of iterations. + + Raises: + TODO + """ + try: + _, sink_id = self.node_id_dict[query] + self.query_sinks.append(sink_id) + + return sink_id + except KeyError: + # add new query-specific sink + sink_id = self.add_node( + query, + "disj", + label=str(query), + ) + + # keep track of global sink + global_sink = self.false_const + + # set new sink to query + self.false_const = query + self.query_sinks.append(sink_id) + + # simple query (single constraint) + if isinstance(query, Constraint): + # process query as a regular constraint (new sink is used instead) + self.encode_statement(query, certain_atoms) + # complex query (multiple constraints) + else: + for q in query.constraints: + # process query as a regular constraint (new sink is used instead) + self.encode_statement(q, certain_atoms) + + # reset sink to global sink + self.false_const = global_sink + + # connect global sink to query sink + self.add_edge( + self.false_const, + query, + ) + + return sink_id + + def add_node( + self, + expr: Expr, + node_type: str, + label: str, + **attrs: Dict[str, Any], + ) -> None: + """TODO""" + + # if node_type == "conj" and len(self.node_dict['conj']['x']) == 2: + # raise Exception(label) + + # get node ID + node_id = len(self.node_dict[node_type]["x"]) + + if expr in self.node_id_dict: + raise ValueError(f"Node representing '{str(expr)}' already exists.") + + # add node attributes + self.node_dict[node_type]["label"].append(label) + + if "x" not in attrs: + attrs["x"] = 0.0 + + for attr, val in attrs.items(): + self.node_dict[node_type][attr].append(val) + + # track expression encoded by node + self.node_id_dict[expr] = (node_type, node_id) + + return node_id + + def get_node( + self, + expr: Expr, + ) -> Optional[Tuple[str, int]]: + """TODO""" + + try: + return self.node_id_dict[expr] + except KeyError: + return None + + def add_edge( + self, + src_expr: Expr, + dst_expr: Expr, + 
**attrs: Dict[str, Any], + ) -> Tuple[Tuple[str, str, str], int]: + """TODO""" + try: + src_type, src_id = self.node_id_dict[src_expr] + except KeyError: + raise ValueError(f"No node representing expression {str(src_expr)}") + + try: + dst_type, dst_id = self.node_id_dict[dst_expr] + except KeyError: + raise ValueError(f"No node representing expression {str(dst_expr)}") + + edge_type = (src_type, "to", dst_type) + + # get edge ID + edge_id = len(self.edge_dict[edge_type]["edge_index"]) + + # add edge attributes + if (src_id, dst_id) in self.edge_dict[edge_type]["edge_index"]: + # TODO + raise ValueError( + f"Edge from {str(src_expr)} to {str(dst_expr)} already exists." + ) + + self.edge_dict[edge_type]["edge_index"].append((src_id, dst_id)) + + if "edge_weight" not in attrs: + attrs["edge_weight"] = 1.0 + + for attr, val in attrs.items(): + self.edge_dict[edge_type][attr].append(val) + + return edge_type, edge_id + + def get_edge( + self, + edge_type: Tuple[str, str, str], + src_expr: Expr, + dst_expr: Expr, + ) -> Optional[int]: + """TODO""" + + try: + src_id = self.node_id_dict[src_expr] + dst_id = self.node_id_dict[dst_expr] + + return self.edge_dict[edge_type]["edge_index"].index((src_id, dst_id)) + except KeyError: + return None + except ValueError: + return None + + def encode_guards(self, guards: Tuple[Guard, Guard]) -> Tuple[int, int, int, int]: + guard_encoding = [] + + # parse guards + for guard in guards: + if guard is None: + guard_encoding += [-1, -1] + else: + if isinstance(guard.bound, Number): + bound = guard.bound.eval() + else: + # infimum is the only object that precedes numbers in + # the total ordering for terms + # use +-infinity respectively + bound = ( + -float("inf") + if isinstance(guard.bound, Infimum) + else float("inf") + ) + + guard_encoding += [ + relop_dict[guard.op], + bound, + ] + + return guard_encoding + + def to_pyg( + self, + device: Optional[torch.device] = None, + hard: bool = True, + copies: int = 1, + ) -> HeteroData: + """TODO""" + + if copies < 1: + raise ValueError( + f"Number of copies for reasoning graph must be larger than zero, but was: {copies}." + ) + + # TODO: use quantized int8 for 'soft' values? + + # number of nodes per type + node_types = ("atom", "disj", "conj", "count", "sum", "min", "max") + num_nodes = tuple( + [len(self.node_dict[node_type]["label"]) for node_type in node_types] + ) + num_nodes_dict = dict(zip(node_types, num_nodes)) + + # initialize heterogeneous PyG graph + data = HeteroData() + data.hard = hard + data.device = device + data.copies = copies + + # ----- node features ----- + + for node_type in node_types: + # keep track of number of nodes + data[node_type].num_nodes = num_nodes_dict[node_type] + + if num_nodes_dict[node_type]: + # NOTE: we repeat the tensor to represent different copies of the same graph + data[node_type].x = ( + torch.tensor( + self.node_dict[node_type]["x"], + device=device + ).type(dtype=torch.int8 if hard else torch.get_default_dtype()) + .unsqueeze(1) + .repeat(1, copies) + ) + + if node_type in self.node_types[3:]: + data[node_type].guards = torch.tensor( + self.node_dict[node_type]["guards"], + device=device, + ) + else: + # empty data + data[node_type].x = torch.empty( + 0, + copies, + dtype=torch.int8 if data.hard else torch.get_default_dtype(), + device=device, + ) + + if node_type in self.node_types[3:]: + data[node_type].guards = torch.empty( + 0, + 4, + device=device, + ) + + # ----- edge indices and weights ----- + + # atom / disj. / conj. 
-> * + for src_type in self.node_types[:3]: + for dst_type in self.node_types: + edge_type = (src_type, "to", dst_type) + + # existing edges + if len(self.edge_dict[edge_type]["edge_weight"]): + data[edge_type].edge_index = torch.tensor( + self.edge_dict[edge_type]["edge_index"], + dtype=torch.long, + device=device, + ).T.contiguous() + # NOTE: we repeat the tensor to represent different copies of the same graph + data[edge_type].edge_weight = ( + torch.tensor( + self.edge_dict[edge_type]["edge_weight"], + device=device, + ).type(dtype=torch.int8 + if hard and dst_type not in self.node_types[3:] + else torch.get_default_dtype(),) + .unsqueeze(1) + .repeat(1, copies) + ) + else: + # empty data + data[edge_type].edge_index = torch.empty( + 2, + 0, + dtype=torch.long, + device=device, + ) + data[edge_type].edge_weight = torch.empty( + 0, + copies, + dtype=torch.int8 + if data.hard and dst_type not in node_types[3:] + else torch.get_default_dtype(), + device=device, + ) + + # count / sum / min / max -> * + for src_type in self.node_types[3:]: + for dst_type in self.node_types[:3]: + edge_type = (src_type, "to", dst_type) + + # existing edges + if len(self.edge_dict[edge_type]["edge_weight"]): + data[edge_type].edge_index = torch.tensor( + self.edge_dict[edge_type]["edge_index"], + dtype=torch.long, + device=device, + ).T.contiguous() + # NOTE: we repeat the tensor to represent different copies of the same graph + data[edge_type].edge_weight = ( + torch.tensor( + self.edge_dict[edge_type]["edge_weight"], + dtype=torch.int8 if hard else torch.get_default_dtype(), + device=device, + ) + .unsqueeze(1) + .repeat(1, copies) + ) + else: + # empty data + data[edge_type].edge_index = torch.empty( + 2, + 0, + dtype=torch.long, + device=device, + ) + data[edge_type].edge_weight = torch.empty( + 0, + copies, + dtype=torch.int8 + if data.hard and edge_type[2] not in node_types[3:] + else torch.get_default_dtype(), + device=device, + ) + + return data + + def draw( + self, + save_as: Optional[str] = None, + direction: str = "TB", + ) -> None: + pgv_graph = self.to_graphviz(direction=direction) + + if save_as is not None: + pgv_graph.draw(path=save_as, prog="dot") + + # draw without specifying a path (returns bytes object of image) + img = Image.open( + io.BytesIO(pgv_graph.draw(prog="dot", format="png")), formats=("PNG",) + ) + + try: + # check if __IPYTHON__ is defined (a bit of a hack) + # see https://discourse.jupyter.org/t/find-out-if-my-code-runs-inside-a-notebook-or-jupyter-lab/6935/7 + __IPYTHON__ + + # display using IPython + from IPython.display import display + + display(img) + except NameError: + # show in external window + img.show() + + def to_graphviz(self, direction: str = "TB") -> pgv.AGraph: + # TODO: automatically test for self-loops and choose strictness + + # initialize directed graph + graph = pgv.AGraph(directed=True, rankdir=direction) + + # add nodes + graph.add_node( + self.node_dict["disj"]["label"][0], + style="filled", + fillcolor="lightgoldenrod", + shape="circle", + label=self.__unicode_symbols["true"], + ) + graph.add_nodes_from( + self.node_dict["disj"]["label"][1:], + style="filled", + fillcolor="gray40", + shape="circle", + label=self.__unicode_symbols["disj"], + fontcolor="white", + ) + graph.add_nodes_from( + self.node_dict["atom"]["label"], + style="filled", + fillcolor="darkslategray3", + shape="oval", + ) + graph.add_node( + self.node_dict["conj"]["label"][0], + shape="circle", + label=self.__unicode_symbols["false"], + ) + graph.add_nodes_from( + 
self.node_dict["conj"]["label"][1:], + shape="circle", + label=self.__unicode_symbols["conj"], + ) + + # map encoded relation operators to a symbol + symbol_dict = { + 0: "=", + 1: self.__unicode_symbols["neq"], + 2: "<", + 3: ">", + 4: self.__unicode_symbols["leq"], + 5: self.__unicode_symbols["geq"], + } + + for node_type in ("count", "sum", "min", "max"): + for node_key, guards in zip( + self.node_dict[node_type]["label"], self.node_dict[node_type]["guards"] + ): + # TODO: clearner way? + # NOTE: not a high priority as plotting does not need to be performant + label = f"\#{node_type}" + + if guards[0] != -1: + bound = int(guards[1]) if isfinite(guards[1]) else guards[1] + label = f"{bound}{symbol_dict[guards[0]]}" + label + if guards[2] != -1: + bound = int(guards[3]) if isfinite(guards[3]) else guards[3] + label = label + f"{symbol_dict[guards[2]]}{bound}" + + graph.add_node( + node_key, + shape="rectangle", + label=label, + ) + + choice_edges = [] + + for _, edges in itertools.chain( + self.choice_edges.items(), self.npp_edges.items() + ): + for _, edge_type, edge_id in edges: + src_id, dst_id = self.edge_dict[edge_type]["edge_index"][edge_id] + src_key = self.node_dict[edge_type[0]]["label"][src_id] + dst_key = self.node_dict[edge_type[-1]]["label"][dst_id] + + choice_edges.append((src_key, dst_key)) + + # add edges + for src_type, dst_type in itertools.product( + ("atom", "disj", "conj", "count", "sum", "min", "max"), + ("atom", "disj", "conj", "count", "sum", "min", "max"), + ): + edge_type = (src_type, "to", dst_type) + + for (src, dst), w in zip( + self.edge_dict[edge_type]["edge_index"], + self.edge_dict[edge_type]["edge_weight"], + ): + if dst_type in ("count", "sum", "min", "max") or w == 1: + color = "black" + elif w == 0: + color = "gray65" + else: + color = "orangered" + + if isfinite(w): + w = int(w) + + src_key = self.node_dict[src_type]["label"][src] + dst_key = self.node_dict[dst_type]["label"][dst] + + graph.add_edge( + src_key, + dst_key, + color=color, + style="dashed" if (src_key, dst_key) in choice_edges else "", + label=str(w) if dst_type in ("sum", "min", "max") else "", + ) + + return graph diff --git a/src/asn/data/utils.py b/src/asn/data/utils.py new file mode 100644 index 0000000..3fa7d91 --- /dev/null +++ b/src/asn/data/utils.py @@ -0,0 +1,215 @@ +import itertools +from typing import Optional + +import torch +from torch_geometric.data import HeteroData + +__node_types = ("atom", "disj", "conj", "count", "sum", "min", "max") + + +def condense_edges_pyg( + data: HeteroData, device: Optional[torch.device] = None +) -> HeteroData: + """TODO""" + + # initialize new heterogeneous PyG graph + data_condensed = HeteroData() + data_condensed.hard = data.hard + data_condensed.copies = data.copies + + # number of nodes per type + num_nodes_dict = { + node_type: data[node_type].num_nodes for node_type in __node_types + } + # accumulate numbers of nodes + num_nodes_dict["_"] = sum([n for n in num_nodes_dict.values()]) + num_nodes_dict["atom/disj"] = num_nodes_dict["atom"] + num_nodes_dict["disj"] + num_nodes_dict["count/sum"] = num_nodes_dict["count"] + num_nodes_dict["sum"] + num_nodes_dict["atom/disj/conj"] = ( + num_nodes_dict["atom/disj"] + num_nodes_dict["conj"] + ) + + if device is None: + # infer device of 'data' + # NOTE: assumes all tensors reside on the same device! 
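+        # use the device of the first non-empty node feature tensor; if the
+        # graph contains no nodes at all, fall back to 'None' (default device)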
+ for node_type, num_nodes in num_nodes_dict.items(): + if num_nodes > 0: + device = data[node_type].x.device + break + else: + device = None + + # infer number of copies in graph + copies = data["disj"].x.shape[-1] + + # offsets of node ids for each node type + offset_dict = { + node_type: cum_nodes + for node_type, cum_nodes in zip( + __node_types, + [0] + + list( + itertools.accumulate( + [num_nodes_dict[node_type] for node_type in __node_types] + ) + ), + ) + } + + # ----- edge indices ----- + + # -> atom/disj. + data_condensed[("_", "to", "atom/disj")].edge_index = torch.cat( + [ + torch.cat( + [ + torch.tensor([[offset_dict[src_type]], [0]], device=device) + + data[(src_type, "to", dst_type)].edge_index + if (src_type, "to", dst_type) in data.edge_types + else torch.empty(2, 0, device=device) + for src_type in __node_types + ], + dim=1, + ) + + torch.tensor([[0], [offset_dict[dst_type]]], device=device) + for dst_type in ("atom", "disj") + ], + dim=1, + ) + # -> conj. + data_condensed[("_", "to", "conj")].edge_index = torch.cat( + [ + torch.tensor([[offset_dict[src_type]], [0]], device=device) + + data[(src_type, "to", "conj")].edge_index + if (src_type, "to", "conj") in data.edge_types + else torch.empty(2, 0, device=device) + for src_type in __node_types + ], + dim=1, + ) + # -> count/sum + data_condensed[("atom/disj/conj", "to", "count/sum")].edge_index = torch.cat( + [ + torch.cat( + [ + torch.tensor([[offset_dict[src_type]], [0]], device=device) + + data[(src_type, "to", dst_type)].edge_index + if (src_type, "to", dst_type) in data.edge_types + else torch.empty(2, 0, device=device) + for src_type in ("atom", "disj", "conj") + ], + dim=1, + ) + + torch.tensor( + [[0], [offset_dict[dst_type] - offset_dict["count"]]], + device=device, + ) + for dst_type in ("count", "sum") + ], + dim=1, + ) + # -> min + data_condensed[("atom/disj/conj", "to", "min")].edge_index = torch.cat( + [ + torch.tensor([[offset_dict[src_type]], [0]], device=device) + + data[(src_type, "to", "min")].edge_index + if (src_type, "to", "min") in data.edge_types + else torch.empty(2, 0, device=device) + for src_type in ("atom", "disj", "conj") + ], + dim=1, + ) + # -> max + data_condensed[("atom/disj/conj", "to", "max")].edge_index = torch.cat( + [ + torch.tensor([[offset_dict[src_type]], [0]], device=device) + + data[(src_type, "to", "max")].edge_index + if (src_type, "to", "max") in data.edge_types + else torch.empty(2, 0, device=device) + for src_type in ("atom", "disj", "conj") + ], + dim=1, + ) + + # ----- node features ----- + + for node_type, num_nodes in num_nodes_dict.items(): + data_condensed[node_type].num_nodes = num_nodes_dict[node_type] + + for node_type in __node_types: + data_condensed[node_type].x = data[node_type].x + + if node_type in __node_types[3:]: + data_condensed[node_type].guards = data[node_type].guards + + # ----- edge features ----- + + # -> atom/disj. + if ("_", "to", "atom/disj") in data_condensed.edge_types: + data_condensed[("_", "to", "atom/disj")].edge_weight = torch.cat( + [ + data[(src_type, "to", dst_type)].edge_weight + if (src_type, "to", dst_type) in data.edge_types + else torch.empty( + 0, + data.copies, + dtype=torch.int8 if data.hard else torch.get_default_dtype(), + device=device, + ) + for dst_type in ("atom", "disj") + for src_type in __node_types + ], + dim=0, + ) + # -> conj. 
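+    # NOTE: the weights below are concatenated in the same (dst type, src type)
+    # order as the condensed edge indices above, so each weight stays aligned
+    # with its corresponding edge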
+ if ("_", "to", "conj") in data_condensed.edge_types: + data_condensed[("_", "to", "conj")].edge_weight = torch.cat( + [ + data[(src_type, "to", "conj")].edge_weight + if (src_type, "to", "conj") in data.edge_types + else torch.empty( + 0, + data.copies, + dtype=torch.int8 if data.hard else torch.get_default_dtype(), + device=device, + ) + for src_type in __node_types + ], + dim=0, + ) + # -> count/sum + if ("atom/disj/conj", "to", "count/sum") in data_condensed.edge_types: + data_condensed[("atom/disj/conj", "to", "count/sum")].edge_weight = torch.cat( + [ + data[(src_type, "to", dst_type)].edge_weight + if (src_type, "to", dst_type) in data.edge_types + else torch.empty(0, data.copies, device=device) + for dst_type in ("count", "sum") + for src_type in __node_types[:3] + ], + dim=0, + ) + # -> min + if ("atom/disj/conj", "to", "min") in data_condensed.edge_types: + data_condensed[("atom/disj/conj", "to", "min")].edge_weight = torch.cat( + [ + data[(src_type, "to", "min")].edge_weight + if (src_type, "to", "min") in data.edge_types + else torch.empty(0, data.copies, device=device) + for src_type in __node_types[:3] + ], + dim=0, + ) + # -> max + if ("atom/disj/conj", "to", "max") in data_condensed.edge_types: + data_condensed[("atom/disj/conj", "to", "max")].edge_weight = torch.cat( + [ + data[(src_type, "to", "max")].edge_weight + if (src_type, "to", "max") in data.edge_types + else torch.empty(0, data.copies, device=device) + for src_type in __node_types[:3] + ], + dim=0, + ) + + return data_condensed diff --git a/src/asn/models/alexnet.py b/src/asn/models/alexnet.py new file mode 100644 index 0000000..e7c1e55 --- /dev/null +++ b/src/asn/models/alexnet.py @@ -0,0 +1,36 @@ +import torch.nn as nn + + +class AlexNet(nn.Module): + """TODO""" + + def __init__(self, n_out: int = 10): + """TODO""" + super(AlexNet, self).__init__() + self.encoder = nn.Sequential( + nn.Conv2d( + 1, 6, 5 + ), # 6 is the output chanel size; 5 is the kernal size; 1 (chanel) 28 28 -> 6 24 24 + nn.MaxPool2d(2, 2), # kernal size 2; stride size 2; 6 24 24 -> 6 12 12 + nn.ReLU( + True + ), # inplace=True means that it will modify the input directly thus save memory + nn.Conv2d(6, 16, 5), # 6 12 12 -> 16 8 8 + nn.MaxPool2d(2, 2), # 16 8 8 -> 16 4 4 + nn.ReLU(True), + ) + self.classifier = nn.Sequential( + nn.Linear(16 * 4 * 4, 120), + nn.ReLU(), + nn.Linear(120, 84), + nn.ReLU(), + nn.Linear(84, n_out), + nn.Softmax(-1), + ) + + def forward(self, x): + x = self.encoder(x) + x = x.view(-1, 16 * 4 * 4) + x = self.classifier(x) + + return x diff --git a/src/asn/models/einsum_wrapper.py b/src/asn/models/einsum_wrapper.py new file mode 100644 index 0000000..eb6e1bb --- /dev/null +++ b/src/asn/models/einsum_wrapper.py @@ -0,0 +1,171 @@ +import torch +import torch.nn as nn +import numpy as np +# import EinsumNetworks +# import EinsumNetworks.src +# from EinsumNetworks.src import EinsumNetwork +from EinsumNetwork import EinsumNetwork +from EinsumNetwork import Graph +device = torch.device('cuda:0') + +#wrapper class to create an Einsum Network given a specific structure and parameters +class EiNet(EinsumNetwork.EinsumNetwork): + def __init__(self , + use_em, + structure = 'poon-domingos', + pd_num_pieces = [4], + depth = 8, + num_repetitions = 20, + num_var = 784, + class_count = 3, + K = 10, + num_sums = 10, + pd_height = 28, + pd_width = 28, + learn_prior = True + ): + + # super(EinsumNetwork, self).__init__() + # Structure + self.structure = structure + self.class_count = class_count + classes = np.arange(class_count) # 
[0,1,2,..,n-1] + + # Define the prior, i.e. P(C) and make it learnable. + self.learnable_prior = learn_prior + # P(C) is needed to apply the Bayes' theorem and to retrive + # P(C|X) = P(X|C)*(P(C) / P(X) + if self.class_count == 4: + self.prior = torch.tensor([(1/3)*(2/3), (1/3)*(2/3), (1/3)*(2/3), (1/3)], dtype=torch.float, requires_grad=True, device=device).log() + else: + self.prior = torch.ones(class_count, device=device, dtype=torch.float) + self.prior.fill_(1 / class_count) + self.prior.log_() + if self.learnable_prior: + print("P(C) is learnable.") + self.prior.requires_grad_() + #print(f"P(C) is {self.prior}") + + self.K = K + self.num_sums = num_sums + + # 'poon-domingos' + self.pd_num_pieces = pd_num_pieces # [10, 28],[4],[7] + self.pd_height = pd_height + self.pd_width = pd_width + + + # 'binary-trees' + self.depth = depth + self.num_repetitions = num_repetitions + self.num_var = num_var + + # drop-out rate + # self.drop_out = drop_out + # print("The drop-out rate:", self.drop_out) + + # EM-settings + self.use_em = use_em + online_em_frequency = 1 + online_em_stepsize = 0.05 # 0.05 + print("train SPN with EM:",self.use_em) + + + # exponential_family = EinsumNetwork.BinomialArray + # exponential_family = EinsumNetwork.CategoricalArray + exponential_family = EinsumNetwork.NormalArray + + exponential_family_args = None + if exponential_family == EinsumNetwork.BinomialArray: + exponential_family_args = {'N': 255} + if exponential_family == EinsumNetwork.CategoricalArray: + exponential_family_args = {'K': 1366120} + if exponential_family == EinsumNetwork.NormalArray: + exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1} + + # Make EinsumNetwork + if self.structure == 'poon-domingos': + pd_delta = [[self.pd_height / d, self.pd_width / d] for d in self.pd_num_pieces] + graph = Graph.poon_domingos_structure(shape=(self.pd_height, self.pd_width), delta=pd_delta) + elif self.structure == 'binary-trees': + graph = Graph.random_binary_trees(num_var=self.num_var, depth=self.depth, num_repetitions=self.num_repetitions) + else: + raise AssertionError("Unknown Structure") + + + args = EinsumNetwork.Args( + num_var=self.num_var, + num_dims=1, + num_classes=self.class_count, + num_sums=self.num_sums, + num_input_distributions=self.K, + exponential_family=exponential_family, + exponential_family_args=exponential_family_args, + use_em=self.use_em, + online_em_frequency=online_em_frequency, + online_em_stepsize=online_em_stepsize) + + super().__init__(graph, args) + super().initialize() + + def get_log_likelihoods(self, x): + log_likelihood = super().forward(x) + return log_likelihood + + def forward(self, x, marg_idx=None, type=1): + + # PRIOR + if type == 4: + expanded_prior = self.prior.expand(x.shape[0], self.prior.shape[0]) + return expanded_prior + + else: + # Obtain P(X|C) in log domain + if marg_idx: # If marginalisation mask is passed + self.set_marginalization_idx(marg_idx) + log_likelihood = super().forward(x) + self.set_marginalization_idx(None) + likelihood = torch.nn.functional.softmax(log_likelihood, dim=1) + else: + log_likelihood = super().forward(x) + #print("P(X|C):", likelihood.shape, likelihood[0], torch.sum(likelihood[0])) + + #LIKELIHOOD + if type == 2: + likelihood = torch.nn.functional.softmax(log_likelihood, dim=1) + # Sanity check for NaN-values + if torch.isnan(log_likelihood).sum() > 0: + print("likelihood nan") + + return likelihood + else: + # Apply Bayes' Theorem to obtain P(C|X) instead of P(X|C) + # as it is provided by the EiNet + # 1. 
Computation of the prior, i.e. P(C), is already being + # dealt with at the initialisation of the EiNet. + # 2. Compute the normalization constant P(X) + z = torch.logsumexp(log_likelihood + self.prior, dim=1) + # 3. Compute the posterior, i.e. P(C|X) = (P(X|C) * P(C)) / P(X) + posterior_log = (log_likelihood + self.prior - z[:, None]) # log domain + #posterior = posterior_log.exp() # decimal domain + #print("P(C|X)", posterior.shape, posterior[0]) + + + #POSTERIOR + if type == 1: + posterior = torch.nn.functional.softmax(posterior_log, dim=1) + + # Sanity check for NaN-values + if torch.isnan(z).sum() > 0: + print("z nan") + if torch.isnan(posterior).sum() > 0: + print("posterior nan") + return posterior + + #JOINT + elif type == 3: + #compute the joint P(X|C) * P(C) + joint = torch.nn.functional.softmax(log_likelihood + self.prior, dim=1) + #print("P(X,C)",(joint.shape), joint[0], torch.sum(joint[0])) + + return joint diff --git a/src/asn/models/llm_wrapper.py b/src/asn/models/llm_wrapper.py new file mode 100644 index 0000000..2ff87ac --- /dev/null +++ b/src/asn/models/llm_wrapper.py @@ -0,0 +1,49 @@ + + +from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW, get_linear_schedule_with_warmup +import torch +from transformers import LlamaForCausalLM +import torch +from peft import get_peft_model, LoraConfig, TaskType + + + +#wrapper class for AutoModelForSequenceClassification +class LLMWrapper(): + + def __init__(self, model): + self.model = model + + + def forward(self, kwargs): + #forward pass to obtain logits for the specified tokens + + return self.model(kwargs).logits.softmax(dim=1) + + + + + +# LLM wrapper class that inherits from nn.Module +class LLMWithClassProbs(torch.nn.Module): + def __init__(self, model, tokenizer): + super().__init__() + self.model = model + self.tokenizer = tokenizer + + def forward(self, *args, **kwargs): + input_ids = args[0] + attention_mask = args[1] + class_token_ids = args[2] + + seq_lengths = attention_mask.sum(dim=1) + + output = self.model(input_ids) + # Get the logits for the last token + logits = output.logits[0, seq_lengths, :] + + + # Get the probabilities for the class tokens + probs = torch.nn.functional.softmax(logits[:,class_token_ids[0]], dim=1) + return probs + diff --git a/src/asn/models/promis_mock_model.py b/src/asn/models/promis_mock_model.py new file mode 100644 index 0000000..9e32fcc --- /dev/null +++ b/src/asn/models/promis_mock_model.py @@ -0,0 +1,12 @@ +import torch.nn as nn +import torch +class PromisMockNet(nn.Module): + """TODO""" + + def __init__(self, type='house'): + """TODO""" + super(PromisMockNet, self).__init__() + self.param = nn.Parameter(torch.tensor([0.0])) + + def forward(self, x): + return x diff --git a/src/asn/models/slot_attention.py b/src/asn/models/slot_attention.py new file mode 100644 index 0000000..5351dae --- /dev/null +++ b/src/asn/models/slot_attention.py @@ -0,0 +1,258 @@ +""" +Slot attention model based on code of tkipf and the corresponding paper Locatello et al. 
2020 +""" +from torch import nn +import torch +import torch.nn.functional as F +import torchvision.models as models +import numpy as np +#from torchsummary import summary + + +def build_grid(resolution): + ranges = [np.linspace(0., 1., num=res) for res in resolution] + grid = np.meshgrid(*ranges, sparse=False, indexing="ij") + grid = np.stack(grid, axis=-1) + grid = np.reshape(grid, [resolution[0], resolution[1], -1]) + grid = np.expand_dims(grid, axis=0) + grid = grid.astype(np.float32) + return np.concatenate([grid, 1.0 - grid], axis=-1) + + +def spatial_broadcast(slots, resolution): + """Broadcast slot features to a 2D grid and collapse slot dimension.""" + # `slots` has shape: [batch_size, num_slots, slot_size]. + slots = torch.reshape(slots, [slots.shape[0] * slots.shape[1], 1, 1, slots.shape[2]]) + + grid = slots.repeat(1, resolution[0], resolution[1], 1) #repeat expands the data along differnt dimensions + # `grid` has shape: [batch_size*num_slots, width, height, slot_size]. + return grid + + +def unstack_and_split(x, batch_size, n_slots, num_channels=3): + """Unstack batch dimension and split into channels and alpha mask.""" + # unstacked = torch.reshape(x, [batch_size, -1] + list(x.shape[1:])) + # channels, masks = torch.split(unstacked, [num_channels, 1], dim=-1) + unstacked = torch.reshape(x, [batch_size, n_slots] + list(x.shape[1:])) + channels, masks = torch.split(unstacked, [num_channels, 1], dim=2) + return channels, masks + + +class SlotAttention(nn.Module): + def __init__(self, num_slots, dim, iters=3, eps=1e-8, hidden_dim=128): + super().__init__() + self.num_slots = num_slots + self.iters = iters + self.eps = eps + self.scale = dim ** -0.5 #named D in the paper + + self.slots_mu = nn.Parameter(torch.randn(1, 1, dim)) #randomly initialize sigma and mu + self.slots_log_sigma = nn.Parameter(torch.randn(1, 1, dim)).abs().to(device='cuda') + #self.slots_mu = nn.Parameter(torch.nn.init.xavier_uniform_(torch.empty(1,1,dim), gain=1.0)) #randomly initialize sigma and mu + #self.slots_log_sigma = nn.Parameter(torch.nn.init.xavier_uniform_(torch.empty(1,1,dim), gain=1.0)) + + self.project_q = nn.Linear(dim, dim, bias=True) #query projection + self.project_k = nn.Linear(dim, dim, bias=True) # + self.project_v = nn.Linear(dim, dim, bias=True) #feature key projection + + self.gru = nn.GRUCell(dim, dim) + + hidden_dim = max(dim, hidden_dim) + + self.mlp = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, dim) + ) + + self.norm_inputs = nn.LayerNorm(dim, eps=1e-05) + self.norm_slots = nn.LayerNorm(dim, eps=1e-05) + self.norm_mlp = nn.LayerNorm(dim, eps=1e-05) + + self.attn = 0 + + def forward(self, inputs, num_slots=None): + b, n, d = inputs.shape #b is the batchsize, n is the dimensionsize of the features, d is the amount of features([15, 1024, 32]) + n_s = num_slots if num_slots is not None else self.num_slots + + mu = self.slots_mu.expand(b, n_s, -1) #mu and sigma are shared by all slots + sigma = self.slots_log_sigma.expand(b, n_s, -1) + slots = torch.normal(mu, sigma) #sample slots from mu and sigma + #slots = torch.normal(mu, sigma.exp()) #sample slots from mu and sigma + + + inputs = self.norm_inputs(inputs) #layer normalization of inputs + k, v = self.project_k(inputs), self.project_v(inputs) #*self.scale + + + for _ in range(self.iters): + slots_prev = slots #store old slots + + slots = self.norm_slots(slots) #layer norm of slots + q = self.project_q(slots) #emit a query for all slots + + dots = torch.einsum('bid,bjd->bij', q, k) * 
self.scale #is M in the paper, has shape 1024(feature map)| 7(slot amount) + attn = dots.softmax(dim=1) + self.eps #calcualte the softmax for each slot which is also 1024 * 7 + attn = attn / attn.sum(dim=-1, keepdim=True) #weighted mean + + updates = torch.einsum('bjd,bij->bid', v, attn) + + #recurrently update the slots with the slot updates and the previous slots + slots = self.gru( + updates.reshape(-1, d), + slots_prev.reshape(-1, d) + ) + + #apply 2 layer relu mlp to GRU output + slots = slots.reshape(b, -1, d) + slots = slots + self.mlp(self.norm_mlp(slots)) + + self.attn = attn + + return slots + + +class SlotAttention_encoder(nn.Module): + def __init__(self, in_channels, hidden_channels, clevr_encoding): + super(SlotAttention_encoder, self).__init__() + + if clevr_encoding: + self.network = nn.Sequential( + nn.Conv2d(in_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(2, 2), padding=2), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(2, 2), padding=2), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2), + nn.ReLU(inplace=True)) + else: + self.network = nn.Sequential( + nn.Conv2d(in_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_channels, hidden_channels, (5, 5), stride=(1, 1), padding=2), + nn.ReLU(inplace=True)) + + + + + def forward(self, x): + return self.network(x) + + +class MLP(nn.Module): + def __init__(self, hidden_channels): + super(MLP, self).__init__() + self.network = nn.Sequential( + nn.Linear(hidden_channels, hidden_channels), + nn.ReLU(inplace=True), + nn.Linear(hidden_channels, hidden_channels), + ) + + def forward(self, x): + return self.network(x) + + +class SoftPositionEmbed(nn.Module): + """Adds soft positional embedding with learnable projection.""" + + def __init__(self, hidden_size, resolution, device="cuda:0"): + """Builds the soft position embedding layer. + Args: + hidden_size: Size of input feature dimension. + resolution: Tuple of integers specifying width and height of grid. 
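+          device: currently unused here and kept for compatibility; the
+            positional grid is registered as a buffer and moves with the module.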
+ """ + super().__init__() + self.dense = nn.Linear(4, hidden_size) + # self.grid = torch.FloatTensor(build_grid(resolution)) + # self.grid = self.grid.to(device) + # for nn.DataParallel + self.register_buffer("grid", torch.FloatTensor(build_grid(resolution))) + self.resolution = resolution[0] + self.hidden_size = hidden_size + + def forward(self, inputs): + #print("positional embedding",inputs.shape, self.dense(self.grid).view((-1, self.hidden_size, self.resolution, self.resolution)).shape ) + return inputs + self.dense(self.grid).view((-1, self.hidden_size, self.resolution, self.resolution)) + + +class SlotAttention_classifier(nn.Module): + def __init__(self, in_channels, out_channels): + super(SlotAttention_classifier, self).__init__() + self.network = nn.Sequential( + nn.Linear(in_channels, in_channels), # nn.Conv1d(in_channels, in_channels, 1, stride=1, groups=in_channels) + nn.ReLU(inplace=True), + nn.Linear(in_channels, out_channels), + nn.Sigmoid() + ) + + def forward(self, x): + return self.network(x) + + +class SlotAttention_model(nn.Module): + def __init__(self, n_slots, n_iters, n_attr, + in_channels=3, + encoder_hidden_channels=64, + attention_hidden_channels=128, + mlp_prediction = False, + device="cuda", + clevr_encoding=False): + super(SlotAttention_model, self).__init__() + self.n_slots = n_slots + self.n_iters = n_iters + self.n_attr = n_attr + self.n_attr = n_attr + 1 # additional slot to indicate if it is a object or empty slot + self.device = device + + self.encoder_cnn = SlotAttention_encoder(in_channels=in_channels, hidden_channels=encoder_hidden_channels , clevr_encoding=clevr_encoding) + self.encoder_pos = SoftPositionEmbed(encoder_hidden_channels, (32, 32), device=device)# changed from 128* 128 + self.layer_norm = nn.LayerNorm(encoder_hidden_channels, eps=1e-05) + self.mlp = MLP(hidden_channels=encoder_hidden_channels) + self.slot_attention = SlotAttention(num_slots=n_slots, dim=encoder_hidden_channels, iters=n_iters, eps=1e-8, + hidden_dim=attention_hidden_channels) + + #for set prediction baseline + self.mlp_prediction = mlp_prediction + self.mlp_classifier = SlotAttention_classifier(in_channels=encoder_hidden_channels, out_channels=self.n_attr) + + self.softmax = nn.Softmax(dim=1) + + def forward(self, img): + # `x` has shape: [batch_size, width, height, num_channels]. + #print("input img shape", img.shape) + + # SLOT ATTENTION ENCODER + x = self.encoder_cnn(img) + x = self.encoder_pos(x) + x = torch.flatten(x, start_dim=2) + + # permute channel dimensions + x = x.permute(0, 2, 1) + x = self.layer_norm(x) + x = self.mlp(x) + + #print("shape after mlp", x.shape) + + slots = self.slot_attention(x) + # slots has shape: [batch_size, num_slots, slot_size]. 
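+        # set-prediction baseline: map each slot to per-attribute probabilities
+        # (sigmoid outputs of the classifier head); otherwise return the raw slots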
+ if self.mlp_prediction: + x = self.mlp_classifier(slots) + return x + else: + return slots + + +if __name__ == "__main__": + x = torch.rand(15, 3, 32, 32).cuda() + net = SlotAttention_model(n_slots=11, n_iters=3, n_attr=18, + encoder_hidden_channels=32, attention_hidden_channels=64, + decoder_hidden_channels=32, decoder_initial_size=(8, 8)) + net = net.cuda() + output = net(x) + #summary(net, (3, 32, 32)) + diff --git a/src/asn/solver/__init__.py b/src/asn/solver/__init__.py new file mode 100644 index 0000000..3b6a03c --- /dev/null +++ b/src/asn/solver/__init__.py @@ -0,0 +1,6 @@ +from .gnn import ASNGNN +from .graph_block import GraphBlock +from .npp_context import NPPContext +from .solver import Solver +from .solving_context import SolvingContext +from .stable_model_context import StableModelContext diff --git a/src/asn/solver/gnn/__init__.py b/src/asn/solver/gnn/__init__.py new file mode 100644 index 0000000..3c15b18 --- /dev/null +++ b/src/asn/solver/gnn/__init__.py @@ -0,0 +1,13 @@ +from .gnn import ASNGNN # noqa +from .message_passing import AggrMaxUpdater # noqa +from .message_passing import AggrMinUpdater # noqa +from .message_passing import AggrSumUpdater # noqa +from .message_passing import AggrUpdater # noqa +from .message_passing import ConjUpdater # noqa +from .message_passing import DisjUpdater # noqa +from .message_passing import ASNUpdater # noqa +from .message_passing import SoftConjUpdater # noqa +from .message_passing import SoftDisjUpdater # noqa +from .message_passing import scatter_boltzmann # noqa + +# TODO: diff --git a/src/asn/solver/gnn/constr.py b/src/asn/solver/gnn/constr.py new file mode 100644 index 0000000..fad9a90 --- /dev/null +++ b/src/asn/solver/gnn/constr.py @@ -0,0 +1,31 @@ +import torch + +from asn.utils import eval_dict + + +def constr_eval( + vals: torch.Tensor, + lops: torch.Tensor, + lbounds: torch.Tensor, + rops: torch.Tensor, + rbounds: torch.Tensor, +) -> torch.Tensor: + res = torch.ones_like(vals, dtype=torch.bool) + + for lop in torch.unique(lops): + if lop == -1: + # no left guard specified + continue + + op_mask = (lops == lop).squeeze(-1) + res[op_mask] = eval_dict[lop.item()](lbounds[op_mask], vals[op_mask]) + + for rop in torch.unique(rops): + if rop == -1: + # no left guard specified + continue + + op_mask = (rops == rop).squeeze(-1) + res[op_mask] &= eval_dict[rop.item()](vals[op_mask], rbounds[op_mask]) + + return res diff --git a/src/asn/solver/gnn/gnn.py b/src/asn/solver/gnn/gnn.py new file mode 100644 index 0000000..d85dc22 --- /dev/null +++ b/src/asn/solver/gnn/gnn.py @@ -0,0 +1,97 @@ +import itertools +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn + +from .message_passing import ( + AggrMaxUpdater, + AggrMinUpdater, + AggrSumUpdater, + ConjUpdater, + DisjUpdater, +) + +if TYPE_CHECKING: + from torch import Tensor + from torch_geometric.nn import MessagePassing + + +_node_types = ("atom", "disj", "conj", "count", "sum", "min", "max") + + +class ASNGNN(nn.Module): + def __init__( + self, + updaters: Optional[Dict[str, "MessagePassing"]] = None, + ) -> None: + super().__init__() + + if updaters is None: + updaters = { + **dict.fromkeys(["atom", "disj"], DisjUpdater()), + "conj": ConjUpdater(), + **dict.fromkeys(["count", "sum"], AggrSumUpdater()), + "min": AggrMinUpdater(), + "max": AggrMaxUpdater(), + } + + # message passing node updaters + self.updaters = updaters + + def forward( + self, + node_dict: Dict[str, Dict[str, "Tensor"]], + edge_dict: Dict[str, Dict[str, "Tensor"]], + 
certain_atom_ids: Optional["Tensor"] = None, + ) -> Tuple["Tensor", bool]: + # flag signifying whether or not any node values changed + converged = True + + # for all different edge types + for edge_type in edge_dict.keys(): + # parse source and destination node types + src_types = edge_type[0].split("/") if edge_type[0] != "_" else _node_types + dst_types = edge_type[2].split("/") + + # update nodes + for dst_type, x_prime in zip( + dst_types, + torch.split( + # assumes that all target nodes share same update function + self.updaters[dst_types[0]]( + x=( + torch.cat( + [node_dict[src_type].x for src_type in src_types], + dim=0, + ), + torch.cat( + [node_dict[dst_type].x for dst_type in dst_types], + dim=0, + ), + ), + edge_index=edge_dict[edge_type].edge_index, + edge_weight=edge_dict[edge_type].edge_weight, + guards=( + torch.cat( + [node_dict[dst_type].guards for dst_type in dst_types], + dim=0, + ) + if dst_types[0] in _node_types[3:] + else None + ), + ), + tuple(node_dict[dst_type].num_nodes for dst_type in dst_types), + dim=0, + ), + ): + if converged and not torch.allclose(node_dict[dst_type].x, x_prime): + converged = False + + node_dict[dst_type].x = x_prime + + # manually set value for certain atoms to 1 (True) + if certain_atom_ids is not None: + node_dict["atom"].x[certain_atom_ids] = 1.0 + + return node_dict, converged diff --git a/src/asn/solver/gnn/message_passing.py b/src/asn/solver/gnn/message_passing.py new file mode 100644 index 0000000..930a13e --- /dev/null +++ b/src/asn/solver/gnn/message_passing.py @@ -0,0 +1,136 @@ +from abc import ABC, abstractmethod +from typing import Optional, Tuple + +import torch +from torch import Tensor +from torch_geometric.nn import MessagePassing +from torch_scatter import scatter_add, scatter_max, scatter_min, scatter_softmax + +from .constr import constr_eval + + +def scatter_boltzmann( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + temp: float = 1.0, + out: torch.Tensor = None, + dim_size: int = None, +) -> torch.Tensor: + softmax = scatter_softmax(temp * src, index, dim=dim) + return scatter_add(src * softmax, index, dim=dim, out=out, dim_size=dim_size) + + +class ASNUpdater(MessagePassing, ABC): + def forward( + self, + x: Tuple[Tensor, Tensor], + edge_index: Tensor, + edge_weight: Tensor, + guards: Optional[Tensor] = None, + ) -> Tensor: + return self.propagate( + edge_index, + x=x, + edge_weight=edge_weight, + num_nodes=x[1].shape[0], + guards=guards, + ) + + def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: + return (edge_weight < 0) + x_j * edge_weight + + @abstractmethod + def aggregate(self, *args, **kwargs) -> Tensor: + pass + + +class ConjUpdater(ASNUpdater): + def aggregate(self, inputs: Tensor, index: Tensor, num_nodes: int) -> Tensor: + out = torch.ones( + num_nodes, inputs.shape[-1], dtype=inputs.dtype, device=inputs.device + ) + # compute logical AND for boolean values + return scatter_min(inputs, index, dim=0, out=out)[0] + + +class SoftConjUpdater(ASNUpdater): + def aggregate(self, inputs: Tensor, index: Tensor, num_nodes: int) -> Tensor: + out = torch.ones( + num_nodes, inputs.shape[-1], dtype=inputs.dtype, device=inputs.device + ) + # compute soft (i.e., differentiable) logical AND + return scatter_boltzmann(inputs, index, temp=8.0, dim=0, out=out) + + +class DisjUpdater(ASNUpdater): + def aggregate(self, inputs: Tensor, index: Tensor, num_nodes: int) -> Tensor: + out = torch.zeros( + num_nodes, inputs.shape[-1], dtype=inputs.dtype, device=inputs.device + ) + # compute logical OR for 
boolean values + return scatter_max(inputs, index, dim=0, out=out)[0] + + +class SoftDisjUpdater(ASNUpdater): + def aggregate(self, inputs: Tensor, index: Tensor, num_nodes: int) -> Tensor: + out = torch.zeros( + num_nodes, inputs.shape[-1], dtype=inputs.dtype, device=inputs.device + ) + # compute soft (i.e., differentiable) logical OR + return scatter_boltzmann(inputs, index, temp=-8.0, dim=0, out=out) + + +class AggrUpdater(ASNUpdater, ABC): + @abstractmethod + def message(self, *args, **kwargs) -> Tensor: + pass + + @abstractmethod + def aggregate(self, *args, **kwargs) -> Tensor: + pass + + def update(self, aggr_out: Tensor, guards: Tensor, x_i: Tensor) -> Tensor: + # check bounds for aggregated value + return constr_eval( + aggr_out, + torch.index_select(guards, -1, torch.tensor(0, device=aggr_out.device)), + torch.index_select(guards, -1, torch.tensor(1, device=aggr_out.device)), + torch.index_select(guards, -1, torch.tensor(2, device=aggr_out.device)), + torch.index_select(guards, -1, torch.tensor(3, device=aggr_out.device)), + ).type(x_i.dtype) + + +class AggrSumUpdater(AggrUpdater): + def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: + return edge_weight * x_j + + def aggregate(self, inputs: Tensor, index: Tensor, num_nodes: int) -> Tensor: + # sum up active incoming edge weights + return scatter_add(inputs, index, dim=0, dim_size=num_nodes) + + +class AggrMinUpdater(AggrUpdater): + def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: + return torch.where( + torch.isclose(x_j.type(torch.get_default_dtype()), torch.tensor(1.0)), + edge_weight, + float("inf"), + ) + + def aggregate(self, inputs: Tensor, index: Tensor, num_nodes: int) -> Tensor: + # compute minimum accross all values + return scatter_min(inputs, index, dim=0, dim_size=num_nodes)[0] + + +class AggrMaxUpdater(AggrUpdater): + def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: + return torch.where( + torch.isclose(x_j.type(torch.get_default_dtype()), torch.tensor(1.0)), + edge_weight, + -float("inf"), + ) + + def aggregate(self, inputs: Tensor, index: Tensor, num_nodes: int) -> Tensor: + # compute maximum accross all values + return scatter_max(inputs, index, dim=0, dim_size=num_nodes)[0] diff --git a/src/asn/solver/graph_block.py b/src/asn/solver/graph_block.py new file mode 100644 index 0000000..2691db6 --- /dev/null +++ b/src/asn/solver/graph_block.py @@ -0,0 +1,47 @@ +from functools import cached_property +from typing import TYPE_CHECKING, Optional + +import torch + +if TYPE_CHECKING: + from torch_geometric.data.storage import EdgeStorage, NodeStorage + + +class GraphBlock: + """TODO""" + + def __init__( + self, + node_dict: "NodeStorage", + edge_dict: "EdgeStorage", + npp_choices: torch.Tensor, + certain_atom_ids: Optional[torch.Tensor] = None, + sink_ids: Optional[torch.Tensor] = None, + ) -> None: + """TODO""" + self.node_dict = node_dict + self.edge_dict = edge_dict + self.npp_choices = npp_choices + self.sink_ids = ( + sink_ids if sink_ids is not None else torch.tensor([0], device=self.device) + ) + # get unique sink ids and the inverse indicices + self.unique_sink_ids, self.inverse_unique_sink_ids = self.sink_ids.unique( + sorted=False, return_inverse=True + ) + self.certain_atom_ids = certain_atom_ids + self.device = self.node_dict["conj"]["x"].device + + @cached_property + def block_size(self) -> int: + # infer number of combinations in graph block + return self.node_dict["conj"]["x"].shape[1] + + @property + def atoms(self) -> torch.Tensor: + """TODO""" + return 
self.node_dict["atom"]["x"].T + + @property + def is_model(self) -> torch.Tensor: + return ~self.node_dict["disj"].x[self.unique_sink_ids].unsqueeze(-1).bool() diff --git a/src/asn/solver/npp_context.py b/src/asn/solver/npp_context.py new file mode 100644 index 0000000..a2ca752 --- /dev/null +++ b/src/asn/solver/npp_context.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass +class NPPContext: + """TODO""" + + p: Optional[torch.Tensor] = None # batch_size x n_out + + def to(self, *args, **kwargs) -> "NPPContext": + """TODO""" + if self.p is not None: + self.p.to(*args, **kwargs) + + return self diff --git a/src/asn/solver/solver.py b/src/asn/solver/solver.py new file mode 100644 index 0000000..ca104d9 --- /dev/null +++ b/src/asn/solver/solver.py @@ -0,0 +1,40 @@ +from typing import Optional + +from .gnn import ASNGNN +from .graph_block import GraphBlock + +_node_types = ("atom", "disj", "conj", "count", "sum", "min", "max") + + +# TODO: repackage in GNN class ??? +class Solver: + """TODO""" + + def __init__(self): + """TODO""" + # initialize GNN + self.gnn = ASNGNN() + + def solve(self, graph_block: GraphBlock, max_iter: int = -1) -> GraphBlock: + """TODO""" + if max_iter == -1: + max_iter = float("inf") + + converged = False + + while not converged and max_iter > 0: + # forward step + _, converged = self.gnn( + graph_block.node_dict, + graph_block.edge_dict, + graph_block.certain_atom_ids, + ) + + # decrease iteration counter + max_iter -= 1 + + # check for convergence + if converged: + break + + return graph_block diff --git a/src/asn/solver/solving_context.py b/src/asn/solver/solving_context.py new file mode 100644 index 0000000..4f964b7 --- /dev/null +++ b/src/asn/solver/solving_context.py @@ -0,0 +1,345 @@ +from functools import cached_property +from typing import TYPE_CHECKING, Dict, Iterable, Optional + +import torch +import torch.distributed as dist + +from .stable_model_context import StableModelContext + +if TYPE_CHECKING: + from ground_slash.program import NPPRule + from torch.distributed.distributed_c10d import ProcessGroup + + from asn.solver.graph_block import GraphBlock + + +class SolvingContext: + """TODO""" + + def __init__( + self, + batch_size: int = 1, + npp_ctx_dict: Optional[Dict["NPPRule", torch.Tensor]] = None, + sm_ctx: Optional[StableModelContext] = None, + rank: int = 0, + world_size: int = 1, + group: Optional["ProcessGroup"] = None, + ) -> None: + """TODO""" + self.batch_size = batch_size + self.npp_ctx_dict = npp_ctx_dict + + if world_size < 1: + raise ValueError(f"World size should be positive, but was {world_size}.") + self.world_size = world_size + if rank >= world_size: + raise ValueError(f"Rank should be smaller than world_size, but was {rank}.") + self.rank = rank + + # create group with all processes + if group is None and world_size > 1: + group = dist.new_group(list(range(self.world_size))) + self.group = group + + # stable model context + self.sm_ctx = sm_ctx + + # flag indicating whether or not the solving context is synchronized across processes + self.synchronized = False + + def clear_cache(self, attrs: Optional[Iterable[str]] = None) -> None: + """TODO""" + if attrs is None: + # per default clear all + attrs = ("p_I", "p_Q") + + for attr in attrs: + # clear cached properties + self.__dict__.pop(attr, None) + + def filter_SMs(self, atoms: torch.Tensor, is_SM: torch.Tensor) -> None: + """TODO""" + # NOTE: 'is_SM' should hold the preliminary indicators for stable models + # at the very 
least it should be initialized to 'is_model' + + # filter for stable models + for k in range(is_SM.shape[1]): + # an interpretation is SM for a given query iff it is a model for the given query + # AND it is not a superset of any other (subsequent) model for the given query + # mult. assignment so that we only consider interpretations that are models themselves + is_SM[:, k, :] *= ( + # test if interpretation is NOT a superset of any subsequent model + ~torch.any( + # check if element-wise OR is identical to itself (i.e., is superset) + # I_k LOR I_j = I_k => I_k SUPSETEQ I_j + torch.all( + torch.eq( + # element-wise OR with all subsequent interpretations + # I_k LOR I_j for j > k + torch.logical_or(atoms[[k], :], atoms[k + 1 :, :]), + # -> n_combinations-(k+1) x n_atoms + atoms[[k], :], + ), + # -> n_combinations-(k+1) x n_atoms + dim=1, + keepdims=True, + ).T + # -> 1 x n_combinations-(k+1) + # multiply (LAND) to only consider other models + # NOTE: 'is_SM' is initialized to 'is_model' + * is_SM[:, k + 1 :, :].squeeze(-1), + # -> n_unique_queries x n_combinations-(k+1) + dim=1, + keepdims=True, + ) + # -> n_unique_queries x 1 + ) + + def update_SMs(self, graph_block: "GraphBlock") -> None: + atoms = graph_block.atoms + # -> n_combinations x n_atoms + + # initialize mask to indicate whether interpretation is a model or not + is_SM = graph_block.is_model + # -> n_unique_queries x n_combinations x 1 + + # NPP contexts + npp_choices = graph_block.npp_choices + # m x n_NPPs + + # filter out any interpretations which are not a model for SOME query + is_SM_for_some = is_SM.any(dim=0).squeeze(-1) + atoms = atoms[is_SM_for_some, :] + is_SM = is_SM[:, is_SM_for_some, :] + npp_choices = npp_choices[is_SM_for_some, :] + + if self.sm_ctx is not None: + # append current stable model (candidates) + atoms = torch.cat((atoms, self.sm_ctx.atoms), dim=0) + is_SM = torch.cat((is_SM, self.sm_ctx.is_SM), dim=1) + npp_choices = torch.cat((npp_choices, self.sm_ctx.npp_choices), dim=0) + + # filter for stable models + self.filter_SMs(atoms, is_SM) + + # mask indicating whether interpretation is a stable model for ANY query + is_SM_for_some = is_SM.any(dim=0).squeeze(-1) + + # store all relevant information to keep track of updated stable models + self.sm_ctx = StableModelContext( + atoms[is_SM_for_some, :], + is_SM[:, is_SM_for_some, :], + npp_choices[is_SM_for_some, :], + graph_block.inverse_unique_sink_ids, + ) + + # indicate that solving context may not be synchronized anymore + self.synchronized = False + + def synchronize_SMs(self) -> None: + if self.world_size > 1: + atoms = self.sm_ctx.atoms + # -> m x n_atoms + # (where m is the number of combinations where the interpretation is considered a SM for some query) + is_SM = self.sm_ctx.is_SM + # -> n_unique_queries x m x 1 + + # NPP contexts + npp_choices = self.sm_ctx.npp_choices + # m x n_NPPs + + # main process + if self.rank == 0: + # receive 'filtered_atoms' from all processes + atoms_list = [None] * self.world_size + dist.gather_object( + atoms, + atoms_list, + dst=0, + group=self.group, + ) + # receive 'filtered_is_model' from all processes + is_SM_list = [None] * self.world_size + dist.gather_object( + is_SM, + is_SM_list, + dst=0, + group=self.group, + ) + # infer number of interpretations received from all processes (for splitting later) + m_list = [t.shape[0] for t in atoms_list] + device = atoms.device + # concatenate received tensors together + atoms_concat = torch.concat([t.to(device) for t in atoms_list], dim=0) + is_SM_concat = 
torch.concat([t.to(device) for t in is_SM_list], dim=1) + + # filter for stable models + self.filter_SMs(atoms_concat, is_SM_concat) + + # scatter final 'is_SM' entries + scattered_list = [None] + dist.scatter_object_list( + scattered_list, + list(torch.split(is_SM_concat, m_list, dim=1)), + src=0, + group=self.group, + ) + is_SM = scattered_list[0].to(device) + + # mask indicating whether interpretation is a stable model for ANY query + is_SM_for_some = is_SM.any(dim=0).squeeze(-1) + + # store all relevant information to keep track of updated stable models + self.sm_ctx = StableModelContext( + atoms[is_SM_for_some, :], + is_SM[:, is_SM_for_some, :], + npp_choices[is_SM_for_some, :], + self.sm_ctx.inverse_sink_ids, + ) + # secondary process + else: + # send 'filtered_atoms' to main process + dist.gather_object( + atoms, + None, + dst=0, + group=self.group, + ) + # send 'filtered_is_model' to main process + dist.gather_object( + is_SM, + None, + dst=0, + group=self.group, + ) + # receive updated 'is_SM' entries + scattered_list = [None] + dist.scatter_object_list( + scattered_list, + [None] * self.world_size, + src=0, + group=self.group, + ) + # update SM mask + is_SM = scattered_list[0].to(atoms.device) + + # mask indicating whether interpretation is a stable model for ANY query + is_SM_for_some = is_SM.any(dim=0).squeeze(-1) + + # store all relevant information to keep track of updated stable models + self.sm_ctx = StableModelContext( + atoms[is_SM_for_some, :], + is_SM[:, is_SM_for_some, :], + npp_choices[is_SM_for_some, :], + self.sm_ctx.inverse_sink_ids, + ) + + # indicate that solving context is synchronized + self.synchronized = True + + @cached_property + def p_I(self) -> torch.Tensor: + """TODO""" + is_SM = self.sm_ctx.is_SM + + if self.npp_ctx_dict: + return ( + is_SM[self.sm_ctx.inverse_sink_ids] + * torch.cat( + [ + torch.gather( + npp_ctx.p, + # -> batch_size x n_out + -1, + npp_choices.repeat(npp_ctx.p.shape[0], 1) + # -> batch_size x n_combinations + ).unsqueeze(-1) + # -> batch_size x n_combinations x 1 + for npp_ctx, npp_choices in zip( + self.npp_ctx_dict.values(), self.sm_ctx.npp_choices.T + ) + ], + dim=-1, + ).prod(dim=-1, keepdims=True) + / torch.tensor(len(self.npp_ctx_dict)) + ) + # -> batch_size x n_combinations x 1 + else: + return is_SM[self.sm_ctx.inverse_sink_ids] + # -> batch_size x n_combinations x 1 + + @cached_property + def p_Q(self) -> torch.Tensor: + """TODO""" + p_Q = self.p_I.sum(dim=-2) + + if self.world_size > 1: + if self.synchronized: + # reduce incomplete probabilities across all processes + dist.all_reduce( + p_Q, + group=self.group, + ) + else: + raise Exception( + "Computing query probabilities from unsynchronized solving context." 
+                )
+
+        return p_Q
+
+    @property
+    def npp_grads(self) -> Dict["NPPRule", torch.Tensor]:
+        """Manually computed gradients of the query probabilities w.r.t. each NPP's output distribution."""
+        p_I = self.p_I
+        p_Q = self.p_Q
+
+        npp_grads = {}
+
+        for (npp, npp_ctx), npp_choices in zip(
+            self.npp_ctx_dict.items(), self.sm_ctx.npp_choices.T
+        ):
+            with torch.no_grad():
+                p_c_eq_vi = npp_ctx.p.unsqueeze(1)
+                # -> batch_size x 1 x n_out
+
+                # TODO: we already have 'npp_choices' stored
+                choices_mask = torch.zeros(
+                    npp_choices.shape[0],
+                    npp_ctx.p.shape[1],
+                    device=npp_ctx.p.device,
+                    dtype=torch.bool,
+                ).scatter_(index=npp_choices.unsqueeze(-1), dim=1, value=1.0)
+                # -> n_combinations x n_out
+
+                # TODO: convert 'npp_choices' to 'choices_mask'
+                c_eq_vi = p_c_eq_vi * choices_mask
+                # -> batch_size x n_combinations x n_out
+
+                p_interp_div_c_eq_vi = p_I / c_eq_vi
+                p_interp_div_c_eq_vi[c_eq_vi == 0.0] = 0.0
+                # -> batch_size x n_combinations x n_out
+                pos_grads = (p_interp_div_c_eq_vi).sum(dim=-2)
+                # -> batch_size x n_out
+                neg_grads = pos_grads - pos_grads.sum(dim=-1, keepdims=True)
+                # -> batch_size x n_out
+                # TODO: if p(Q) is zero -> division by zero -> NaNs
+                npp_grads[npp] = (pos_grads + neg_grads) / p_Q
+                # -> batch_size x n_out
+
+        return npp_grads
+
+    @property
+    def npp_loss(self) -> torch.Tensor:
+        """Surrogate loss whose backward pass reproduces the manually computed NPP gradients (normalized by batch size)."""
+        npp_grads = self.npp_grads
+        loss = torch.tensor(0.0, device=self.p_I.device)
+
+        for npp, grads in npp_grads.items():
+            # multiply manual gradients by NPP outputs
+            # this way the gradients during backward are exactly our gradients
+            preped_grads = self.npp_ctx_dict[npp].p * grads
+            # update loss
+            loss += preped_grads[preped_grads.abs() != 0.0].sum()
+
+        # return normalized loss
+        return loss / self.batch_size
diff --git a/src/asn/solver/stable_model_context.py b/src/asn/solver/stable_model_context.py
new file mode 100644
index 0000000..986c2b5
--- /dev/null
+++ b/src/asn/solver/stable_model_context.py
@@ -0,0 +1,5 @@
+from collections import namedtuple
+
+StableModelContext = namedtuple(
+    "StableModelContext", ["atoms", "is_SM", "npp_choices", "inverse_sink_ids"]
+)
diff --git a/src/asn/utils/__init__.py b/src/asn/utils/__init__.py
new file mode 100644
index 0000000..81f903f
--- /dev/null
+++ b/src/asn/utils/__init__.py
@@ -0,0 +1,2 @@
+from .collections import get_minimal_collections
+from .relop import eval_dict, relop_dict
diff --git a/src/asn/utils/collections.py b/src/asn/utils/collections.py
new file mode 100644
index 0000000..71c36d1
--- /dev/null
+++ b/src/asn/utils/collections.py
@@ -0,0 +1,23 @@
+from collections import defaultdict
+from typing import Hashable, Iterable, Tuple
+
+
+def get_minimal_collections(
+    *collections: Iterable[Hashable],
+) -> Tuple[Iterable[Hashable], ...]:
+    minimal_collections = defaultdict(lambda: None)  # order-preserving
+
+    for i, collection in enumerate(collections):
+        for j, collection_other in enumerate(collections):
+            if i == j:
+                continue
+            if collection >= collection_other:
+                break
+        else:
+            minimal_collections[collection]
+
+    return tuple(minimal_collections)
+
+
+# TODO: rename file (no sets are even used)
+# collections?
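+
+
+# A minimal usage sketch, assuming the inputs support set comparison operators
+# (e.g., frozensets): only collections that are not supersets of any other
+# input are kept, in insertion order.
+if __name__ == "__main__":
+    print(
+        get_minimal_collections(frozenset({1, 2}), frozenset({1}), frozenset({3}))
+    )
+    # -> (frozenset({1}), frozenset({3}))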
diff --git a/src/asn/utils/load_dotenv.py b/src/asn/utils/load_dotenv.py new file mode 100644 index 0000000..847c293 --- /dev/null +++ b/src/asn/utils/load_dotenv.py @@ -0,0 +1,19 @@ +from dotenv import load_dotenv +import os + +def load_env(): + load_dotenv() + + # Access the environment variables + hf_home = os.getenv('HF_HOME') + huggingface_hub_cache = os.getenv('HUGGINGFACE_HUB_CACHE') + wandb_api_key = os.getenv('WANDB_API_KEY') + wandb_project = os.getenv('WANDB_PROJECT') + + print("-----------------------------") + print("Environment variables loaded:") + print("HF_HOME",hf_home) + print("HUGGINGFACE_HUB_CACHE",huggingface_hub_cache) + #print("WANDB_API_KEY",wandb_api_key + print("WANDB_PROJECT",wandb_project) + print("-----------------------------") \ No newline at end of file diff --git a/src/asn/utils/relop.py b/src/asn/utils/relop.py new file mode 100644 index 0000000..e57366f --- /dev/null +++ b/src/asn/utils/relop.py @@ -0,0 +1,26 @@ +import torch +from ground_slash.program.operators import RelOp + +relop_dict = { + op: i + for i, op in enumerate( + ( + RelOp.EQUAL, + RelOp.UNEQUAL, + RelOp.LESS, + RelOp.GREATER, + RelOp.LESS_OR_EQ, + RelOp.GREATER_OR_EQ, + ) + ) +} + + +eval_dict = { + 0: torch.eq, + 1: torch.ne, + 2: torch.lt, + 3: torch.gt, + 4: torch.le, + 5: torch.ge, +} diff --git a/src/format.sh b/src/format.sh new file mode 100755 index 0000000..ee811bc --- /dev/null +++ b/src/format.sh @@ -0,0 +1,3 @@ +#!/bin/bash +isort . --profile black +black . \ No newline at end of file diff --git a/src/tests/__init__.py b/src/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/tests/data/__init__.py b/src/tests/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/tests/data/test_reasoning_graph.py b/src/tests/data/test_reasoning_graph.py new file mode 100644 index 0000000..0ebc51c --- /dev/null +++ b/src/tests/data/test_reasoning_graph.py @@ -0,0 +1,2882 @@ +import unittest +from collections import namedtuple +from typing import Any, Dict, List + +from ground_slash.program import ( + NPP, + AggrCount, + AggrElement, + AggrLiteral, + AggrMax, + AggrMin, + AggrSum, + Choice, + ChoiceElement, + FalseConstant, + Functional, + Guard, + LiteralCollection, + Naf, + Number, + PredLiteral, + Program, + RelOp, + SymbolicConstant, + TrueConstant, +) + +from asn.data.reasoning_graph import ReasoningGraph + +# can then also conveniently add defaults !!! (really nice to generate/manipulate rg???) +# TODO: id ??? 
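+
+# The namedtuples below mirror the reasoning graph's node types ("atom", "conj",
+# "count", "sum", "min", "max") and its heterogeneous edge types; zip_nodes and
+# zip_edges repack rg.node_dict / rg.edge_dict into per-node and per-edge tuples
+# so the tests can compare graph contents via set equality where order may differ.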
+ +AtomNode = namedtuple("atom", ["label", "x", "aux"]) +ConjNode = namedtuple("conj", ["label", "x"]) + +CountNode = namedtuple("count", ["label", "x", "guards"]) +SumNode = namedtuple("sum", ["label", "x", "guards"]) +MinNode = namedtuple("min", ["label", "x", "guards"]) +MaxNode = namedtuple("max", ["label", "x", "guards"]) + +# atom -> conj +Atom2ConjEdge = namedtuple("atom_in_conj", ["edge_index", "edge_weight"]) +# conj -> atom +Conj2AtomEdge = namedtuple("conj_defines_atom", ["edge_index", "edge_weight"]) +# atom -> aggr +Atom2CountEdge = namedtuple("atom_in_count", ["edge_index", "edge_weight"]) +Atom2SumEdge = namedtuple("atom_in_sum", ["edge_index", "edge_weight"]) +Atom2MinEdge = namedtuple("atom_in_min", ["edge_index", "edge_weight"]) +Atom2MaxEdge = namedtuple("atom_in_max", ["edge_index", "edge_weight"]) +# aggr -> atom +Count2AtomEdge = namedtuple("count_defines_atom", ["edge_index", "edge_weight"]) +Sum2AtomEdge = namedtuple("sum_defines_atom", ["edge_index", "edge_weight"]) +Min2AtomEdge = namedtuple("min_defines_atom", ["edge_index", "edge_weight"]) +Max2AtomEdge = namedtuple("max_defines_atom", ["edge_index", "edge_weight"]) + + +def zip_nodes(node_dict: Dict[str, Dict[str, List[Any]]]) -> Dict: + return { + node.__name__: list( + node(*tup) + for tup in zip(*[node_dict[node.__name__][attr] for attr in node._fields]) + ) + for node in (AtomNode, ConjNode, CountNode, SumNode, MinNode, MaxNode) + } + + +def zip_edges(edge_dict: Dict[str, Dict[str, List[Any]]]) -> Dict: + return { + tuple(edge.__name__.split("_")): list( + edge(*tup) + for tup in zip( + *[ + edge_dict[tuple(edge.__name__.split("_"))][attr] + for attr in edge._fields + ] + ) + ) + for edge in ( + Atom2ConjEdge, + Conj2AtomEdge, + Atom2CountEdge, + Atom2SumEdge, + Atom2MinEdge, + Atom2MaxEdge, + Count2AtomEdge, + Sum2AtomEdge, + Min2AtomEdge, + Max2AtomEdge, + ) + } + + +class TestRegionGraph(unittest.TestCase): + def test_init(self): + # empty program + prog = Program(()) + + # create reasoning graph + # (empty) except for basic initialization + rg = ReasoningGraph(prog) + + # basic attributes + self.assertEqual(rg.aux_counter, 0) + self.assertEqual(rg.atom_in_conj, ("atom", "in", "conj")) + self.assertEqual(rg.conj_defines_atom, ("conj", "defines", "atom")) + self.assertEqual(rg.atom_in_count, ("atom", "in", "count")) + self.assertEqual(rg.atom_in_sum, ("atom", "in", "sum")) + self.assertEqual(rg.atom_in_min, ("atom", "in", "min")) + self.assertEqual(rg.atom_in_max, ("atom", "in", "max")) + self.assertEqual(rg.count_defines_atom, ("count", "defines", "atom")) + self.assertEqual(rg.sum_defines_atom, ("sum", "defines", "atom")) + self.assertEqual(rg.min_defines_atom, ("min", "defines", "atom")) + self.assertEqual(rg.max_defines_atom, ("max", "defines", "atom")) + self.assertEqual( + rg.aggr_map, + { + AggrCount(): ("count", r"\#"), + AggrSum(): ("sum", r"\Sigma"), + AggrMin(): ("min", "MIN"), + AggrMax(): ("max", "MAX"), + }, + ) + self.assertEqual(rg.true_const, TrueConstant()) + self.assertEqual(rg.false_const, FalseConstant()) + + # zip node and edges attributes for convenience (since order may differ) + nodes = zip_nodes(rg.node_dict) + edges = zip_edges(rg.edge_dict) + + # atom nodes + self.assertEqual( + set(nodes["atom"]), + {AtomNode(r"$\top$", 1.0, True), AtomNode(r"$\bot$", 0.0, True)}, + ) + # conjunction nodes + self.assertFalse(nodes["conj"]) + # aggregate nodes + self.assertFalse(nodes["count"]) + self.assertFalse(nodes["sum"]) + self.assertFalse(nodes["min"]) + self.assertFalse(nodes["max"]) + 
+ # atom -> conj + self.assertFalse(edges[("atom", "in", "conj")]) + # conj -> atom + self.assertFalse(edges[("conj", "defines", "atom")]) + # atom -> aggr + self.assertFalse(edges[("atom", "in", "count")]) + self.assertFalse(edges[("atom", "in", "sum")]) + self.assertFalse(edges[("atom", "in", "min")]) + self.assertFalse(edges[("atom", "in", "max")]) + # aggr -> atom + self.assertFalse(edges[("count", "defines", "atom")]) + self.assertFalse(edges[("sum", "defines", "atom")]) + self.assertFalse(edges[("min", "defines", "atom")]) + self.assertFalse(edges[("max", "defines", "atom")]) + + # node id dictionaries + self.assertEqual(rg.atom_ids[rg.true_const], 0) + self.assertEqual(rg.atom_ids[rg.false_const], 1) + self.assertEqual(len(rg.conj_ids), 0) + self.assertEqual(len(rg.aggr_ids), 0) + + # choice tracking + self.assertEqual(len(rg.choices), 0) + self.assertEqual(len(rg.choice_edges), 0) + + def test_normal_fact(self): + prog = Program.from_string( + r""" + + a. + + """ + ) + + # create reasoning graph + rg = ReasoningGraph(prog) + + # zip node and edges attributes for convenience (since order may differ) + nodes = zip_nodes(rg.node_dict) + edges = zip_edges(rg.edge_dict) + + # atom nodes + self.assertEqual( + set(nodes["atom"]), + { + AtomNode(r"$\top$", 1.0, True), + AtomNode(r"$\bot$", 0.0, True), + AtomNode("a", 1.0, False), + }, + ) + # conjunction nodes + self.assertEqual(set(nodes["conj"]), {ConjNode(r"$\wedge_0$", 0.0)}) + # aggregate nodes + self.assertFalse(nodes["count"]) + self.assertFalse(nodes["sum"]) + self.assertFalse(nodes["min"]) + self.assertFalse(nodes["max"]) + + # atom -> conj + self.assertEqual( + set(edges[("atom", "in", "conj")]), {Atom2ConjEdge((0, 0), 1.0)} + ) + # conj -> atom + self.assertEqual( + set(edges[("conj", "defines", "atom")]), {Conj2AtomEdge((0, 2), 1.0)} + ) + # atom -> aggr + self.assertFalse(edges[("atom", "in", "count")]) + self.assertFalse(edges[("atom", "in", "sum")]) + self.assertFalse(edges[("atom", "in", "min")]) + self.assertFalse(edges[("atom", "in", "max")]) + # aggr -> atom + self.assertFalse(edges[("count", "defines", "atom")]) + self.assertFalse(edges[("sum", "defines", "atom")]) + self.assertFalse(edges[("min", "defines", "atom")]) + self.assertFalse(edges[("max", "defines", "atom")]) + + # node id dictionaries + self.assertEqual(rg.atom_ids[rg.true_const], 0) + self.assertEqual(rg.atom_ids[rg.false_const], 1) + self.assertEqual(rg.conj_ids[LiteralCollection(TrueConstant())], 0) + self.assertEqual(len(rg.aggr_ids), 0) + + # choice tracking + self.assertEqual(len(rg.choices), 0) + self.assertEqual(len(rg.choice_edges), 0) + + def test_normal_facts(self): + prog = Program.from_string( + r""" + + a. + b. + + """ + ) + # conj. 
should be reused for both facts + + # create reasoning graph + rg = ReasoningGraph(prog) + + # zip node and edges attributes for convenience (since order may differ) + nodes = zip_nodes(rg.node_dict) + edges = zip_edges(rg.edge_dict) + + # atom nodes + self.assertEqual( + set(nodes["atom"]), + { + AtomNode(r"$\top$", 1.0, True), + AtomNode(r"$\bot$", 0.0, True), + AtomNode("a", 1.0, False), + AtomNode("b", 1.0, False), + }, + ) + # conjunction nodes + self.assertEqual(set(nodes["conj"]), {ConjNode(r"$\wedge_0$", 0.0)}) + # aggregate nodes + self.assertFalse(nodes["count"]) + self.assertFalse(nodes["sum"]) + self.assertFalse(nodes["min"]) + self.assertFalse(nodes["max"]) + + # atom -> conj + self.assertEqual( + set(edges[("atom", "in", "conj")]), {Atom2ConjEdge((0, 0), 1.0)} + ) + # conj -> atom + self.assertEqual( + set(edges[("conj", "defines", "atom")]), + {Conj2AtomEdge((0, 2), 1.0), Conj2AtomEdge((0, 3), 1.0)}, + ) + # atom -> aggr + self.assertFalse(edges[("atom", "in", "count")]) + self.assertFalse(edges[("atom", "in", "sum")]) + self.assertFalse(edges[("atom", "in", "min")]) + self.assertFalse(edges[("atom", "in", "max")]) + # aggr -> atom + self.assertFalse(edges[("count", "defines", "atom")]) + self.assertFalse(edges[("sum", "defines", "atom")]) + self.assertFalse(edges[("min", "defines", "atom")]) + self.assertFalse(edges[("max", "defines", "atom")]) + + # node id dictionaries + self.assertEqual(rg.atom_ids[rg.true_const], 0) + self.assertEqual(rg.atom_ids[rg.false_const], 1) + self.assertEqual(rg.conj_ids[LiteralCollection(TrueConstant())], 0) + self.assertEqual(len(rg.aggr_ids), 0) + + # choice tracking + self.assertEqual(len(rg.choices), 0) + self.assertEqual(len(rg.choice_edges), 0) + + def test_normal_rule(self): + prog = Program.from_string( + r""" + + a :- b, not c. 
+ + """ + ) + + # create reasoning graph + rg = ReasoningGraph(prog) + + # zip node and edges attributes for convenience (since order may differ) + nodes = zip_nodes(rg.node_dict) + edges = zip_edges(rg.edge_dict) + + # atom nodes + self.assertEqual( + nodes["atom"], + [ + AtomNode(r"$\top$", 1.0, True), + AtomNode(r"$\bot$", 0.0, True), + AtomNode("a", 0.0, False), + AtomNode("b", 0.0, False), + AtomNode("c", 0.0, False), + ], + ) + # conjunction nodes + self.assertEqual(nodes["conj"], [ConjNode(r"$\wedge_0$", 0.0)]) + # aggregate nodes + self.assertFalse(nodes["count"]) + self.assertFalse(nodes["sum"]) + self.assertFalse(nodes["min"]) + self.assertFalse(nodes["max"]) + + # atom -> conj + self.assertEqual( + set(edges[("atom", "in", "conj")]), + {Atom2ConjEdge((3, 0), 1.0), Atom2ConjEdge((4, 0), -1.0)}, + ) + # conj -> atom + self.assertEqual( + set(edges[("conj", "defines", "atom")]), + {Conj2AtomEdge((0, 2), 1.0)}, + ) + # atom -> aggr + self.assertFalse(edges[("atom", "in", "count")]) + self.assertFalse(edges[("atom", "in", "sum")]) + self.assertFalse(edges[("atom", "in", "min")]) + self.assertFalse(edges[("atom", "in", "max")]) + # aggr -> atom + self.assertFalse(edges[("count", "defines", "atom")]) + self.assertFalse(edges[("sum", "defines", "atom")]) + self.assertFalse(edges[("min", "defines", "atom")]) + self.assertFalse(edges[("max", "defines", "atom")]) + + # node id dictionaries + self.assertEqual( + rg.atom_ids, + { + rg.true_const: 0, + rg.false_const: 1, + PredLiteral("a"): 2, + PredLiteral("b"): 3, + PredLiteral("c"): 4, + }, + ) + self.assertEqual( + rg.conj_ids, {LiteralCollection(PredLiteral("b"), Naf(PredLiteral("c"))): 0} + ) + self.assertEqual(len(rg.aggr_ids), 0) + + # choice tracking + self.assertEqual(len(rg.choices), 0) + self.assertEqual(len(rg.choice_edges), 0) + + def test_disjunctive_fact(self): + prog = Program.from_string( + r""" + + a | b. 
+ + """ + ) + + # create reasoning graph + rg = ReasoningGraph(prog) + + # zip node and edges attributes for convenience (since order may differ) + nodes = zip_nodes(rg.node_dict) + edges = zip_edges(rg.edge_dict) + + # atom nodes + self.assertEqual( + set(nodes["atom"]), + { + AtomNode(r"$\top$", 1.0, True), + AtomNode(r"$\bot$", 0.0, True), + AtomNode("a", 0.0, False), + AtomNode("b", 0.0, False), + AtomNode(r"$\vee_0$", 0.0, True), + }, + ) + # conjunction nodes + self.assertEqual( + set(nodes["conj"]), + {ConjNode(r"$\wedge_0$", 0.0), ConjNode(r"$\wedge_1$", 0.0)}, + ) + # aggregate nodes + self.assertFalse(nodes["count"]) + self.assertFalse(nodes["sum"]) + self.assertFalse(nodes["min"]) + self.assertFalse(nodes["max"]) + + # atom -> conj + self.assertEqual( + set(edges[("atom", "in", "conj")]), + { + Atom2ConjEdge((0, 0), 1.0), + Atom2ConjEdge((2, 1), -1.0), + Atom2ConjEdge((3, 1), -1.0), + Atom2ConjEdge((4, 1), 1.0), + }, + ) + # conj -> atom + self.assertEqual( + set(edges[("conj", "defines", "atom")]), + { + Conj2AtomEdge((0, 2), 1.0), + Conj2AtomEdge((0, 3), 1.0), + Conj2AtomEdge((0, 4), 1.0), + Conj2AtomEdge((1, 1), 1.0), + }, + ) + # atom -> aggr + self.assertFalse(edges[("atom", "in", "count")]) + self.assertFalse(edges[("atom", "in", "sum")]) + self.assertFalse(edges[("atom", "in", "min")]) + self.assertFalse(edges[("atom", "in", "max")]) + # aggr -> atom + self.assertFalse(edges[("count", "defines", "atom")]) + self.assertFalse(edges[("sum", "defines", "atom")]) + self.assertFalse(edges[("min", "defines", "atom")]) + self.assertFalse(edges[("max", "defines", "atom")]) + + # node id dictionaries + self.assertEqual(rg.atom_ids[rg.true_const], 0) + self.assertEqual(rg.atom_ids[rg.false_const], 1) + self.assertEqual(rg.conj_ids[LiteralCollection(TrueConstant())], 0) + self.assertEqual(len(rg.aggr_ids), 0) + + # choice tracking + self.assertEqual( + rg.choices, {LiteralCollection(PredLiteral("a"), PredLiteral("b"))} + ) + self.assertEqual( + rg.choice_edges, + { + LiteralCollection(PredLiteral("a"), PredLiteral("b")): [ + 0, # (0, 2) + 1, # (0, 3) + ] + # '$\wedge_0$' to 'a', 'b' + }, + ) + + def test_disjunctive_rule(self): + prog = Program.from_string( + r""" + + a | b :- c, not d. + + """ + ) + + # create reasoning graph + rg = ReasoningGraph(prog) + + # zip node and edges attributes for convenience (since order may differ) + nodes = zip_nodes(rg.node_dict) + edges = zip_edges(rg.edge_dict) + + # atom nodes + self.assertEqual( + nodes["atom"], + [ + AtomNode(r"$\top$", 1.0, True), + AtomNode(r"$\bot$", 0.0, True), + AtomNode("a", 0.0, False), + AtomNode("b", 0.0, False), + AtomNode("c", 0.0, False), + AtomNode("d", 0.0, False), + AtomNode(r"$\vee_0$", 0.0, True), + ], + ) + # conjunction nodes + self.assertEqual( + nodes["conj"], [ConjNode(r"$\wedge_0$", 0.0), ConjNode(r"$\wedge_1$", 0.0)] + ) + # aggregate nodes + self.assertFalse(nodes["count"]) + self.assertFalse(nodes["sum"]) + self.assertFalse(nodes["min"]) + self.assertFalse(nodes["max"]) + + # atom -> conj + self.assertEqual( + set(edges[("atom", "in", "conj")]), + { + # body conjunction + Atom2ConjEdge((4, 0), 1.0), + Atom2ConjEdge((5, 0), -1.0), + # disjunction constraint + Atom2ConjEdge((2, 1), -1.0), + Atom2ConjEdge((3, 1), -1.0), + Atom2ConjEdge((6, 1), 1.0), + }, + ) + # conj -> atom + self.assertEqual( + set(edges[("conj", "defines", "atom")]), + { + # body to head + Conj2AtomEdge((0, 2), 1.0), + Conj2AtomEdge((0, 3), 1.0), + # body to aux. 
atom + Conj2AtomEdge((0, 6), 1.0), + # disjunction constraint + Conj2AtomEdge((1, 1), 1.0), + }, + ) + # atom -> aggr + self.assertFalse(edges[("atom", "in", "count")]) + self.assertFalse(edges[("atom", "in", "sum")]) + self.assertFalse(edges[("atom", "in", "min")]) + self.assertFalse(edges[("atom", "in", "max")]) + # aggr -> atom + self.assertFalse(edges[("count", "defines", "atom")]) + self.assertFalse(edges[("sum", "defines", "atom")]) + self.assertFalse(edges[("min", "defines", "atom")]) + self.assertFalse(edges[("max", "defines", "atom")]) + + # node id dictionaries + self.assertEqual( + rg.atom_ids, + { + rg.true_const: 0, + rg.false_const: 1, + PredLiteral("a"): 2, + PredLiteral("b"): 3, + PredLiteral("c"): 4, + PredLiteral("d"): 5, + LiteralCollection(PredLiteral("c"), Naf(PredLiteral("d"))): 6, + }, + ) + self.assertEqual( + rg.conj_ids, + { + LiteralCollection(PredLiteral("c"), Naf(PredLiteral("d"))): 0, + LiteralCollection( + Naf(PredLiteral("a")), + Naf(PredLiteral("b")), + PredLiteral("c"), + Naf(PredLiteral("d")), + ): 1, + }, + ) + self.assertEqual(len(rg.aggr_ids), 0) + + # choice tracking + self.assertEqual( + rg.choices, {LiteralCollection(PredLiteral("a"), PredLiteral("b"))} + ) + self.assertEqual( + rg.choice_edges, + {LiteralCollection(PredLiteral("a"), PredLiteral("b")): [0, 1]}, + ) + + def test_constraint(self): + prog = Program.from_string( + r""" + + :- a, not b. + + """ + ) + + # create reasoning graph + rg = ReasoningGraph(prog) + + # zip node and edges attributes for convenience (since order may differ) + nodes = zip_nodes(rg.node_dict) + edges = zip_edges(rg.edge_dict) + + # atom nodes + self.assertEqual( + nodes["atom"], + [ + AtomNode(r"$\top$", 1.0, True), + AtomNode(r"$\bot$", 0.0, True), + AtomNode("a", 0.0, False), + AtomNode("b", 0.0, False), + ], + ) + # conjunction nodes + self.assertEqual(nodes["conj"], [ConjNode(r"$\wedge_0$", 0.0)]) + # aggregate nodes + self.assertFalse(nodes["count"]) + self.assertFalse(nodes["sum"]) + self.assertFalse(nodes["min"]) + self.assertFalse(nodes["max"]) + + # atom -> conj + self.assertEqual( + set(edges[("atom", "in", "conj")]), + {Atom2ConjEdge((2, 0), 1.0), Atom2ConjEdge((3, 0), -1.0)}, + ) + # conj -> atom + self.assertEqual( + set(edges[("conj", "defines", "atom")]), + {Conj2AtomEdge((0, 1), 1.0)}, + ) + # atom -> aggr + self.assertFalse(edges[("atom", "in", "count")]) + self.assertFalse(edges[("atom", "in", "sum")]) + self.assertFalse(edges[("atom", "in", "min")]) + self.assertFalse(edges[("atom", "in", "max")]) + # aggr -> atom + self.assertFalse(edges[("count", "defines", "atom")]) + self.assertFalse(edges[("sum", "defines", "atom")]) + self.assertFalse(edges[("min", "defines", "atom")]) + self.assertFalse(edges[("max", "defines", "atom")]) + + # node id dictionaries + self.assertEqual( + rg.atom_ids, + { + rg.true_const: 0, + rg.false_const: 1, + PredLiteral("a"): 2, + PredLiteral("b"): 3, + }, + ) + self.assertEqual( + rg.conj_ids, {LiteralCollection(PredLiteral("a"), Naf(PredLiteral("b"))): 0} + ) + self.assertEqual(len(rg.aggr_ids), 0) + + # choice tracking + self.assertEqual(len(rg.choices), 0) + self.assertEqual(len(rg.choice_edges), 0) + + def test_empty_constraint(self): + # TODO: not supported by ground_slash yet + pass + + def test_choice_fact(self): + prog = Program.from_string( + r""" + + {a;b:d;b:e,not f;b:not f;c}. 
+ + """ + ) + + # create reasoning graph + rg = ReasoningGraph(prog) + + # zip node and edges attributes for convenience (since order may differ) + nodes = zip_nodes(rg.node_dict) + edges = zip_edges(rg.edge_dict) + + # atom nodes + self.assertEqual( + nodes["atom"], + [ + AtomNode(r"$\top$", 1.0, True), + AtomNode(r"$\bot$", 0.0, True), + AtomNode("a", 0.0, False), + AtomNode("b", 0.0, False), + AtomNode("c", 0.0, False), + AtomNode(r"$\vee_0$", 0.0, True), + AtomNode(r"$\vee_1$", 0.0, True), + AtomNode("d", 0.0, False), + AtomNode("e", 0.0, False), + AtomNode("f", 0.0, False), + AtomNode(r"$\vee_2$", 0.0, True), + AtomNode(r"$\vee_3$", 0.0, True), + AtomNode(r"$\vee_4$", 0.0, True), + ], + ) + # conjunction nodes + self.assertEqual( + nodes["conj"], + [ + ConjNode(r"$\wedge_0$", 0.0), + ConjNode(r"$\wedge_1$", 0.0), + ConjNode(r"$\wedge_2$", 0.0), + ConjNode(r"$\wedge_3$", 0.0), + ConjNode(r"$\wedge_4$", 0.0), + ConjNode(r"$\wedge_5$", 0.0), + ConjNode(r"$\wedge_6$", 0.0), + ], + ) + # aggregate nodes + self.assertEqual(set(nodes["count"]), {CountNode(r"$\#_0$", 0.0, (-1,) * 4)}) + self.assertFalse(nodes["sum"]) + self.assertFalse(nodes["min"]) + self.assertFalse(nodes["max"]) + + # atom -> conj + self.assertEqual( + set(edges[("atom", "in", "conj")]), + { + # body to conj. + Atom2ConjEdge((0, 0), 1.0), # 'True' to 'conj0' ('True') + # global constraint for choice (given sat. body) + Atom2ConjEdge((5, 1), 1.0), # 'disj0' (body) to 'conj1' + Atom2ConjEdge((6, 1), -1.0), # 'disj1' (aggr) to 'conj1' + # cond. constraint for 'a' + Atom2ConjEdge((5, 2), 1.0), # 'disj0' (body) to 'conj4' + Atom2ConjEdge((2, 2), 1.0), # 'a' to 'conj4' + Atom2ConjEdge((10, 2), -1.0), # 'disj5' ('a' condition) to 'conj4' + # cond. constraint for 'b' + Atom2ConjEdge((5, 5), 1.0), # 'disj0' (body) o 'conj5' + Atom2ConjEdge((3, 5), 1.0), # 'b' to 'conj5' + Atom2ConjEdge((11, 5), -1.0), # 'disj6' ('b' condition) to 'conj5' + # cond. candidates for 'b' choice + Atom2ConjEdge((7, 3), 1.0), # 'd' to 'conj2' + Atom2ConjEdge((9, 4), -1.0), # 'not f' to 'conj3' + # cond. constraint for 'c' + Atom2ConjEdge((5, 6), 1.0), # 'disj0' (body) to 'conj6' + Atom2ConjEdge((4, 6), 1.0), # 'c' to 'conj6' + Atom2ConjEdge((12, 6), -1.0), # 'disj7' ('c' condition) to 'conj6' + }, + ) + # conj -> atom + self.assertEqual( + set(edges[("conj", "defines", "atom")]), + { + # conj. to head + Conj2AtomEdge((0, 2), 1.0), # 'conj0' ('True') to 'a' + Conj2AtomEdge((0, 3), 1.0), # 'conj0' ('True') to 'b' + Conj2AtomEdge((0, 4), 1.0), # 'conj0' ('True') to 'c' + # conj. to aux. body atom + Conj2AtomEdge((0, 5), 1.0), # 'conj0' ('True') to 'disj0' + # cond. for 'a' + Conj2AtomEdge((0, 10), 1.0), # 'conj0' ('True') to 'disj10' + # cond. for 'b' + Conj2AtomEdge((3, 11), 1.0), # 'conj3' ('True') to 'disj11' + Conj2AtomEdge((4, 11), 1.0), # 'conj4' ('True') to 'disj11' + # cond. for 'c' + Conj2AtomEdge((0, 12), 1.0), # 'conj0' ('True') to 'disj12' + # global choice constraint + Conj2AtomEdge((1, 1), 1.0), # 'conj1' to 'False' + # local constr. for 'a' + Conj2AtomEdge((2, 1), 1.0), # 'conj1' to 'False' + # local constr. for 'b' + Conj2AtomEdge((5, 1), 1.0), # 'conj5' to 'False' + # local constr. 
for 'c' + Conj2AtomEdge((6, 1), 1.0), # 'conj6' to 'False' + }, + ) + # atom -> aggr + self.assertEqual( + set(edges[("atom", "in", "count")]), + { + Atom2CountEdge((2, 0), 1.0), # 'a' to choice count + Atom2CountEdge((3, 0), 1.0), # 'b' to choice count + Atom2CountEdge((4, 0), 1.0), # 'c' to choice count + }, + ) + self.assertFalse(edges[("atom", "in", "sum")]) + self.assertFalse(edges[("atom", "in", "min")]) + self.assertFalse(edges[("atom", "in", "max")]) + # aggr -> atom + self.assertEqual( + edges[("count", "defines", "atom")], + [ + Count2AtomEdge((0, 6), 1.0), # choice count to 'disj1' (aggr) + ], + ) + self.assertFalse(edges[("sum", "defines", "atom")]) + self.assertFalse(edges[("min", "defines", "atom")]) + self.assertFalse(edges[("max", "defines", "atom")]) + + # node id dictionaries + self.assertEqual(rg.atom_ids[rg.true_const], 0) + self.assertEqual(rg.atom_ids[rg.false_const], 1) + self.assertEqual( + rg.atom_ids, + { + rg.true_const: 0, + rg.false_const: 1, + PredLiteral("a"): 2, + PredLiteral("b"): 3, + PredLiteral("c"): 4, + LiteralCollection(rg.true_const): 5, + AggrLiteral( + AggrCount(), + ( + AggrElement([Functional("a")], [PredLiteral("a")]), + AggrElement([Functional("b")], [PredLiteral("b")]), + AggrElement([Functional("c")], [PredLiteral("c")]), + ), + guards=(None, None), + ): 6, + PredLiteral("d"): 7, + PredLiteral("e"): 8, + PredLiteral("f"): 9, + }, + ) + self.assertEqual( + rg.conj_ids, + { + LiteralCollection(rg.true_const): 0, + LiteralCollection(PredLiteral("d")): 3, + LiteralCollection(Naf(PredLiteral("f"))): 4, + # NOTE: constraint conjunctions are not re-used + }, + ) + self.assertEqual( + rg.aggr_ids, + { + AggrLiteral( + AggrCount(), + ( + AggrElement([Functional("a")], [PredLiteral("a")]), + AggrElement([Functional("b")], [PredLiteral("b")]), + AggrElement([Functional("c")], [PredLiteral("c")]), + ), + guards=(None, None), + ): 0, + }, + ) + + # choice tracking + self.assertEqual( + rg.choices, + { + Choice( + ( + ChoiceElement(PredLiteral("a")), + ChoiceElement(PredLiteral("b"), [PredLiteral("d")]), + ChoiceElement( + PredLiteral("b"), [PredLiteral("e"), Naf(PredLiteral("f"))] + ), + ChoiceElement(PredLiteral("b"), [Naf(PredLiteral("f"))]), + ChoiceElement(PredLiteral("c")), + ), + guards=(None, None), + ), + }, + ) + self.assertEqual( + rg.choice_edges, + { + Choice( + ( + ChoiceElement(PredLiteral("a")), + ChoiceElement(PredLiteral("b"), [PredLiteral("d")]), + ChoiceElement( + PredLiteral("b"), [PredLiteral("e"), Naf(PredLiteral("f"))] + ), + ChoiceElement(PredLiteral("b"), [Naf(PredLiteral("f"))]), + ChoiceElement(PredLiteral("c")), + ), + guards=(None, None), + ): [ + 0, # (0, 2) + 1, # (0, 3) + 2, # (0, 4) + ] + # '$\wedge_0$' to 'a', 'b', 'c' + }, + ) + + def test_count_aggregate(self): + prog = Program.from_string( + r""" + + a :- #count{1;2:b;2:c,not d;2:not d;3}. 
+ + """ + ) + + # create reasoning graph + rg = ReasoningGraph(prog) + + # zip node and edges attributes for convenience (since order may differ) + nodes = zip_nodes(rg.node_dict) + edges = zip_edges(rg.edge_dict) + + # atom nodes + self.assertEqual( + nodes["atom"], + [ + AtomNode(r"$\top$", 1.0, True), + AtomNode(r"$\bot$", 0.0, True), + AtomNode("a", 0.0, False), + AtomNode(r"$\vee_0$", 0.0, True), + AtomNode("b", 0.0, False), + AtomNode("c", 0.0, False), + AtomNode("d", 0.0, False), + AtomNode(r"$\vee_1$", 0.0, True), + ], + ) + # conjunction nodes + self.assertEqual( + set(nodes["conj"]), + { + ConjNode(r"$\wedge_0$", 0.0), + ConjNode(r"$\wedge_1$", 0.0), + ConjNode(r"$\wedge_2$", 0.0), + }, + ) + # aggregate nodes + self.assertEqual(set(nodes["count"]), {SumNode(r"$\#_0$", 0.0, (-1,) * 4)}) + self.assertFalse(nodes["sum"]) + self.assertFalse(nodes["min"]) + self.assertFalse(nodes["max"]) + + # atom -> conj + self.assertEqual( + set(edges[("atom", "in", "conj")]), + { + # body + Atom2ConjEdge((3, 2), 1.0), # 'disj0' (aggr) to 'conj2' (body) + # conditions for '2' + Atom2ConjEdge((4, 0), 1.0), # 'b' to 'conj0' + Atom2ConjEdge((6, 1), -1.0), # 'f' to 'conj1' + }, + ) + # conj -> atom + self.assertEqual( + set(edges[("conj", "defines", "atom")]), + { + # body to head + Conj2AtomEdge((2, 2), 1.0), # 'conj2' (body) to 'a' + # conditions for '2' + Conj2AtomEdge((0, 7), 1.0), # 'conj0' to 'disj1' + Conj2AtomEdge((1, 7), 1.0), # 'conj1' to 'disj1' + }, + ) + # atom -> aggr + self.assertEqual( + set(edges[("atom", "in", "count")]), + { + Atom2CountEdge((0, 0), 2.0), # 'True' to choice count ('1','3') + Atom2CountEdge((7, 0), 1.0), # 'disj1' to choice count ('2') + }, + ) + self.assertFalse(edges[("atom", "in", "sum")]) + self.assertFalse(edges[("atom", "in", "min")]) + self.assertFalse(edges[("atom", "in", "max")]) + # aggr -> atom + self.assertEqual( + edges[("count", "defines", "atom")], + [ + Count2AtomEdge((0, 3), 1.0), # choice count to 'disj0' (aggr) + ], + ) + self.assertFalse(edges[("sum", "defines", "atom")]) + self.assertFalse(edges[("min", "defines", "atom")]) + self.assertFalse(edges[("max", "defines", "atom")]) + + # node id dictionaries + self.assertEqual(rg.atom_ids[rg.true_const], 0) + self.assertEqual(rg.atom_ids[rg.false_const], 1) + self.assertEqual( + rg.atom_ids, + { + rg.true_const: 0, + rg.false_const: 1, + PredLiteral("a"): 2, + AggrLiteral( + AggrCount(), + ( + AggrElement([Number(1)]), + AggrElement([Number(2)], [PredLiteral("b")]), + AggrElement( + [Number(2)], [PredLiteral("c"), Naf(PredLiteral("d"))] + ), + AggrElement([Number(2)], [Naf(PredLiteral("d"))]), + AggrElement([Number(3)]), + ), + guards=(None, None), + ): 3, + PredLiteral("b"): 4, + PredLiteral("c"): 5, + PredLiteral("d"): 6, + # TODO: local conditions not reused + }, + ) + self.assertEqual( + rg.conj_ids, + { + LiteralCollection(PredLiteral("b")): 0, + LiteralCollection(Naf(PredLiteral("d"))): 1, + LiteralCollection( + AggrLiteral( + AggrCount(), + ( + AggrElement([Number(1)]), + AggrElement([Number(2)], [PredLiteral("b")]), + AggrElement( + [Number(2)], [PredLiteral("c"), Naf(PredLiteral("d"))] + ), + AggrElement([Number(2)], [Naf(PredLiteral("d"))]), + AggrElement([Number(3)]), + ), + guards=(None, None), + ) + ): 2, + }, + ) + self.assertEqual( + rg.aggr_ids, + { + AggrLiteral( + AggrCount(), + ( + AggrElement([Number(1)]), + AggrElement([Number(2)], [PredLiteral("b")]), + AggrElement( + [Number(2)], [PredLiteral("c"), Naf(PredLiteral("d"))] + ), + AggrElement([Number(2)], [Naf(PredLiteral("d"))]), 
+ AggrElement([Number(3)]), + ), + guards=(None, None), + ): 0 + }, + ) + + # choice tracking + self.assertFalse(rg.choices) + self.assertFalse(rg.choice_edges) + + def test_sum_aggregate(self): + prog = Program.from_string( + r""" + + a :- #sum{1;2:b;2:c,not d;2:not d;3}. + + """ + ) + + # create reasoning graph + rg = ReasoningGraph(prog) + + # zip node and edges attributes for convenience (since order may differ) + nodes = zip_nodes(rg.node_dict) + edges = zip_edges(rg.edge_dict) + + # atom nodes + self.assertEqual( + nodes["atom"], + [ + AtomNode(r"$\top$", 1.0, True), + AtomNode(r"$\bot$", 0.0, True), + AtomNode("a", 0.0, False), + AtomNode(r"$\vee_0$", 0.0, True), + AtomNode("b", 0.0, False), + AtomNode("c", 0.0, False), + AtomNode("d", 0.0, False), + AtomNode(r"$\vee_1$", 0.0, True), + ], + ) + # conjunction nodes + self.assertEqual( + set(nodes["conj"]), + { + ConjNode(r"$\wedge_0$", 0.0), + ConjNode(r"$\wedge_1$", 0.0), + ConjNode(r"$\wedge_2$", 0.0), + }, + ) + # aggregate nodes + self.assertFalse(nodes["count"]) + self.assertEqual(set(nodes["sum"]), {SumNode(r"$\Sigma_0$", 0.0, (-1,) * 4)}) + self.assertFalse(nodes["min"]) + self.assertFalse(nodes["max"]) + + # atom -> conj + self.assertEqual( + set(edges[("atom", "in", "conj")]), + { + # body + Atom2ConjEdge((3, 2), 1.0), # 'disj0' (aggr) to 'conj2' (body) + # conditions for '2' + Atom2ConjEdge((4, 0), 1.0), # 'b' to 'conj0' + Atom2ConjEdge((6, 1), -1.0), # 'f' to 'conj1' + }, + ) + # conj -> atom + self.assertEqual( + set(edges[("conj", "defines", "atom")]), + { + # body to head + Conj2AtomEdge((2, 2), 1.0), # 'conj2' (body) to 'a' + # conditions for '2' + Conj2AtomEdge((0, 7), 1.0), # 'conj0' to 'disj1' + Conj2AtomEdge((1, 7), 1.0), # 'conj1' to 'disj1' + }, + ) + # atom -> aggr + self.assertEqual( + set(edges[("atom", "in", "sum")]), + { + Atom2SumEdge((0, 0), 4.0), # 'True' to aggregate ('1','3') + Atom2SumEdge((7, 0), 2.0), # 'disj1' to aggregate ('2') + }, + ) + # aggr -> atom + self.assertFalse(edges[("count", "defines", "atom")]) + self.assertEqual( + edges[("sum", "defines", "atom")], + [ + Sum2AtomEdge((0, 3), 1.0), # choice count to 'disj0' (aggr) + ], + ) + self.assertFalse(edges[("min", "defines", "atom")]) + self.assertFalse(edges[("max", "defines", "atom")]) + + # node id dictionaries + self.assertEqual(rg.atom_ids[rg.true_const], 0) + self.assertEqual(rg.atom_ids[rg.false_const], 1) + self.assertEqual( + rg.atom_ids, + { + rg.true_const: 0, + rg.false_const: 1, + PredLiteral("a"): 2, + AggrLiteral( + AggrSum(), + ( + AggrElement([Number(1)]), + AggrElement([Number(2)], [PredLiteral("b")]), + AggrElement( + [Number(2)], [PredLiteral("c"), Naf(PredLiteral("d"))] + ), + AggrElement([Number(2)], [Naf(PredLiteral("d"))]), + AggrElement([Number(3)]), + ), + guards=(None, None), + ): 3, + PredLiteral("b"): 4, + PredLiteral("c"): 5, + PredLiteral("d"): 6, + # TODO: local conditions not reused + }, + ) + self.assertEqual( + rg.conj_ids, + { + LiteralCollection(PredLiteral("b")): 0, + LiteralCollection(Naf(PredLiteral("d"))): 1, + LiteralCollection( + AggrLiteral( + AggrSum(), + ( + AggrElement([Number(1)]), + AggrElement([Number(2)], [PredLiteral("b")]), + AggrElement( + [Number(2)], [PredLiteral("c"), Naf(PredLiteral("d"))] + ), + AggrElement([Number(2)], [Naf(PredLiteral("d"))]), + AggrElement([Number(3)]), + ), + guards=(None, None), + ) + ): 2, + }, + ) + self.assertEqual( + rg.aggr_ids, + { + AggrLiteral( + AggrSum(), + ( + AggrElement([Number(1)]), + AggrElement([Number(2)], [PredLiteral("b")]), + AggrElement( + 
[Number(2)], [PredLiteral("c"), Naf(PredLiteral("d"))] + ), + AggrElement([Number(2)], [Naf(PredLiteral("d"))]), + AggrElement([Number(3)]), + ), + guards=(None, None), + ): 0 + }, + ) + + # choice tracking + self.assertFalse(rg.choices) + self.assertFalse(rg.choice_edges) + + def test_min_aggregate(self): + prog = Program.from_string( + r""" + + a :- #min{1;2:b;2:c,not d;2:not d;3}. + + """ + ) + + # create reasoning graph + rg = ReasoningGraph(prog) + + # zip node and edges attributes for convenience (since order may differ) + nodes = zip_nodes(rg.node_dict) + edges = zip_edges(rg.edge_dict) + + # atom nodes + self.assertEqual( + nodes["atom"], + [ + AtomNode(r"$\top$", 1.0, True), + AtomNode(r"$\bot$", 0.0, True), + AtomNode("a", 0.0, False), + AtomNode(r"$\vee_0$", 0.0, True), + AtomNode("b", 0.0, False), + AtomNode("c", 0.0, False), + AtomNode("d", 0.0, False), + AtomNode(r"$\vee_1$", 0.0, True), + ], + ) + # conjunction nodes + self.assertEqual( + set(nodes["conj"]), + { + ConjNode(r"$\wedge_0$", 0.0), + ConjNode(r"$\wedge_1$", 0.0), + ConjNode(r"$\wedge_2$", 0.0), + }, + ) + # aggregate nodes + self.assertFalse(nodes["count"]) + self.assertFalse(nodes["sum"]) + self.assertEqual(set(nodes["min"]), {MinNode(r"$MIN_0$", 0.0, (-1,) * 4)}) + self.assertFalse(nodes["max"]) + + # atom -> conj + self.assertEqual( + set(edges[("atom", "in", "conj")]), + { + # body + Atom2ConjEdge((3, 2), 1.0), # 'disj0' (aggr) to 'conj2' (body) + # conditions for '2' + Atom2ConjEdge((4, 0), 1.0), # 'b' to 'conj0' + Atom2ConjEdge((6, 1), -1.0), # 'f' to 'conj1' + }, + ) + # conj -> atom + self.assertEqual( + set(edges[("conj", "defines", "atom")]), + { + # body to head + Conj2AtomEdge((2, 2), 1.0), # 'conj2' (body) to 'a' + # conditions for '2' + Conj2AtomEdge((0, 7), 1.0), # 'conj0' to 'disj1' + Conj2AtomEdge((1, 7), 1.0), # 'conj1' to 'disj1' + }, + ) + # atom -> aggr + self.assertFalse(edges[("atom", "in", "count")]) + self.assertFalse(edges[("atom", "in", "sum")]) + self.assertEqual( + set(edges[("atom", "in", "min")]), + { + Atom2MinEdge((0, 0), 1.0), # 'True' to choice count ('1','3') + Atom2MinEdge((7, 0), 2.0), # 'disj1' to choice count ('2') + }, + ) + self.assertFalse(edges[("atom", "in", "max")]) + # aggr -> atom + self.assertEqual( + edges[("min", "defines", "atom")], + [ + Min2AtomEdge((0, 3), 1.0), # choice count to 'disj0' (aggr) + ], + ) + + # node id dictionaries + self.assertEqual(rg.atom_ids[rg.true_const], 0) + self.assertEqual(rg.atom_ids[rg.false_const], 1) + self.assertEqual( + rg.atom_ids, + { + rg.true_const: 0, + rg.false_const: 1, + PredLiteral("a"): 2, + AggrLiteral( + AggrMin(), + ( + AggrElement([Number(1)]), + AggrElement([Number(2)], [PredLiteral("b")]), + AggrElement( + [Number(2)], [PredLiteral("c"), Naf(PredLiteral("d"))] + ), + AggrElement([Number(2)], [Naf(PredLiteral("d"))]), + AggrElement([Number(3)]), + ), + guards=(None, None), + ): 3, + PredLiteral("b"): 4, + PredLiteral("c"): 5, + PredLiteral("d"): 6, + # TODO: local conditions not reused + }, + ) + self.assertEqual( + rg.conj_ids, + { + LiteralCollection(PredLiteral("b")): 0, + LiteralCollection(Naf(PredLiteral("d"))): 1, + LiteralCollection( + AggrLiteral( + AggrMin(), + ( + AggrElement([Number(1)]), + AggrElement([Number(2)], [PredLiteral("b")]), + AggrElement( + [Number(2)], [PredLiteral("c"), Naf(PredLiteral("d"))] + ), + AggrElement([Number(2)], [Naf(PredLiteral("d"))]), + AggrElement([Number(3)]), + ), + guards=(None, None), + ) + ): 2, + }, + ) + self.assertEqual( + rg.aggr_ids, + { + AggrLiteral( + 
AggrMin(), + ( + AggrElement([Number(1)]), + AggrElement([Number(2)], [PredLiteral("b")]), + AggrElement( + [Number(2)], [PredLiteral("c"), Naf(PredLiteral("d"))] + ), + AggrElement([Number(2)], [Naf(PredLiteral("d"))]), + AggrElement([Number(3)]), + ), + guards=(None, None), + ): 0 + }, + ) + + # choice tracking + self.assertFalse(rg.choices) + self.assertFalse(rg.choice_edges) + + def test_max_aggregate(self): + prog = Program.from_string( + r""" + + a :- #max{1;2:b;2:c,not d;2:not d;3}. + + """ + ) + + # create reasoning graph + rg = ReasoningGraph(prog) + + # zip node and edges attributes for convenience (since order may differ) + nodes = zip_nodes(rg.node_dict) + edges = zip_edges(rg.edge_dict) + + # atom nodes + self.assertEqual( + nodes["atom"], + [ + AtomNode(r"$\top$", 1.0, True), + AtomNode(r"$\bot$", 0.0, True), + AtomNode("a", 0.0, False), + AtomNode(r"$\vee_0$", 0.0, True), + AtomNode("b", 0.0, False), + AtomNode("c", 0.0, False), + AtomNode("d", 0.0, False), + AtomNode(r"$\vee_1$", 0.0, True), + ], + ) + # conjunction nodes + self.assertEqual( + set(nodes["conj"]), + { + ConjNode(r"$\wedge_0$", 0.0), + ConjNode(r"$\wedge_1$", 0.0), + ConjNode(r"$\wedge_2$", 0.0), + }, + ) + # aggregate nodes + self.assertFalse(nodes["count"]) + self.assertFalse(nodes["sum"]) + self.assertFalse(nodes["min"]) + self.assertEqual(set(nodes["max"]), {MaxNode(r"$MAX_0$", 0.0, (-1,) * 4)}) + + # atom -> conj + self.assertEqual( + set(edges[("atom", "in", "conj")]), + { + # body + Atom2ConjEdge((3, 2), 1.0), # 'disj0' (aggr) to 'conj2' (body) + # conditions for '2' + Atom2ConjEdge((4, 0), 1.0), # 'b' to 'conj0' + Atom2ConjEdge((6, 1), -1.0), # 'f' to 'conj1' + }, + ) + # conj -> atom + self.assertEqual( + set(edges[("conj", "defines", "atom")]), + { + # body to head + Conj2AtomEdge((2, 2), 1.0), # 'conj2' (body) to 'a' + # conditions for '2' + Conj2AtomEdge((0, 7), 1.0), # 'conj0' to 'disj1' + Conj2AtomEdge((1, 7), 1.0), # 'conj1' to 'disj1' + }, + ) + # atom -> aggr + self.assertFalse(edges[("atom", "in", "count")]) + self.assertFalse(edges[("atom", "in", "sum")]) + self.assertFalse(edges[("atom", "in", "min")]) + self.assertEqual( + set(edges[("atom", "in", "max")]), + { + Atom2MaxEdge((0, 0), 3.0), # 'True' to choice count ('1','3') + Atom2MaxEdge((7, 0), 2.0), # 'disj1' to choice count ('2') + }, + ) + # aggr -> atom + self.assertFalse(edges[("count", "defines", "atom")]) + self.assertFalse(edges[("sum", "defines", "atom")]) + self.assertFalse(edges[("min", "defines", "atom")]) + self.assertEqual( + edges[("max", "defines", "atom")], + [ + Max2AtomEdge((0, 3), 1.0), # choice count to 'disj0' (aggr) + ], + ) + + # node id dictionaries + self.assertEqual(rg.atom_ids[rg.true_const], 0) + self.assertEqual(rg.atom_ids[rg.false_const], 1) + self.assertEqual( + rg.atom_ids, + { + rg.true_const: 0, + rg.false_const: 1, + PredLiteral("a"): 2, + AggrLiteral( + AggrMax(), + ( + AggrElement([Number(1)]), + AggrElement([Number(2)], [PredLiteral("b")]), + AggrElement( + [Number(2)], [PredLiteral("c"), Naf(PredLiteral("d"))] + ), + AggrElement([Number(2)], [Naf(PredLiteral("d"))]), + AggrElement([Number(3)]), + ), + guards=(None, None), + ): 3, + PredLiteral("b"): 4, + PredLiteral("c"): 5, + PredLiteral("d"): 6, + # TODO: local conditions not reused + }, + ) + self.assertEqual( + rg.conj_ids, + { + LiteralCollection(PredLiteral("b")): 0, + LiteralCollection(Naf(PredLiteral("d"))): 1, + LiteralCollection( + AggrLiteral( + AggrMax(), + ( + AggrElement([Number(1)]), + AggrElement([Number(2)], [PredLiteral("b")]), 
+ AggrElement( + [Number(2)], [PredLiteral("c"), Naf(PredLiteral("d"))] + ), + AggrElement([Number(2)], [Naf(PredLiteral("d"))]), + AggrElement([Number(3)]), + ), + guards=(None, None), + ) + ): 2, + }, + ) + self.assertEqual( + rg.aggr_ids, + { + AggrLiteral( + AggrMax(), + ( + AggrElement([Number(1)]), + AggrElement([Number(2)], [PredLiteral("b")]), + AggrElement( + [Number(2)], [PredLiteral("c"), Naf(PredLiteral("d"))] + ), + AggrElement([Number(2)], [Naf(PredLiteral("d"))]), + AggrElement([Number(3)]), + ), + guards=(None, None), + ): 0 + }, + ) + + # choice tracking + self.assertFalse(rg.choices) + self.assertFalse(rg.choice_edges) + + def test_NPP_fact(self): + prog = Program.from_string( + r""" + + #npp(h(a,b), [0,1,2]). + + """ + ) + + # create reasoning graph + rg = ReasoningGraph(prog) + + # zip node and edges attributes for convenience (since order may differ) + nodes = zip_nodes(rg.node_dict) + edges = zip_edges(rg.edge_dict) + + # atom nodes + self.assertEqual( + nodes["atom"], + [ + AtomNode(r"$\top$", 1.0, True), + AtomNode(r"$\bot$", 0.0, True), + AtomNode("h(a,b,0)", 0.0, False), + AtomNode("h(a,b,1)", 0.0, False), + AtomNode("h(a,b,2)", 0.0, False), + AtomNode(r"$\vee_0$", 0.0, True), + AtomNode(r"$\vee_1$", 0.0, True), + AtomNode(r"$\vee_2$", 0.0, True), + AtomNode(r"$\vee_3$", 0.0, True), + AtomNode(r"$\vee_4$", 0.0, True), + ], + ) + # conjunction nodes + self.assertEqual( + nodes["conj"], + [ + ConjNode(r"$\wedge_0$", 0.0), + ConjNode(r"$\wedge_1$", 0.0), + ConjNode(r"$\wedge_2$", 0.0), + ConjNode(r"$\wedge_3$", 0.0), + ConjNode(r"$\wedge_4$", 0.0), + ], + ) + # aggregate nodes + self.assertEqual( + set(nodes["count"]), {CountNode(r"$\#_0$", 0.0, (0, 1, -1, -1))} + ) + self.assertFalse(nodes["sum"]) + self.assertFalse(nodes["min"]) + self.assertFalse(nodes["max"]) + + # atom -> conj + self.assertEqual( + set(edges[("atom", "in", "conj")]), + { + # body to conj. + Atom2ConjEdge((0, 0), 1.0), # 'True' to 'conj0' ('True') + # global constraint for choice (given sat. body) + Atom2ConjEdge((5, 1), 1.0), # 'disj0' (body) to 'conj1' + Atom2ConjEdge((6, 1), -1.0), # 'disj1' (aggr) to 'conj1' + # local constraint for 'h(a,b,0)' + Atom2ConjEdge((2, 2), 1.0), # 'h(a,b,0)' to 'conj2' + Atom2ConjEdge((5, 2), 1.0), # 'disj0' (body) to 'conj2' + Atom2ConjEdge((7, 2), -1.0), # 'disj1' (aggr) to 'conj2' + # local constraint for 'h(a,b,1)' + Atom2ConjEdge((3, 3), 1.0), # 'h(a,b,0)' to 'conj3' + Atom2ConjEdge((5, 3), 1.0), # 'disj0' (body) to 'conj3' + Atom2ConjEdge((8, 3), -1.0), # 'disj1' (aggr) to 'conj3' + # local constraint for 'h(a,b,2)' + Atom2ConjEdge((4, 4), 1.0), # 'h(a,b,0)' to 'conj4' + Atom2ConjEdge((5, 4), 1.0), # 'disj0' (body) to 'conj4' + Atom2ConjEdge((9, 4), -1.0), # 'disj1' (aggr) to 'conj4' + }, + ) + # conj -> atom + self.assertEqual( + set(edges[("conj", "defines", "atom")]), + { + # conj. to head + Conj2AtomEdge((0, 2), 1.0), # 'conj0' ('True') to 'h(a,b,0)' + Conj2AtomEdge((0, 3), 1.0), # 'conj0' ('True') to 'h(a,b,1)' + Conj2AtomEdge((0, 4), 1.0), # 'conj0' ('True') to 'h(a,b,2)' + # conj. to aux. body atom + Conj2AtomEdge((0, 5), 1.0), # 'conj0' ('True') to 'disj0' + # global constraint + Conj2AtomEdge((1, 1), 1.0), # 'conj1' ('True') to 'False' + # cond. for 'h(a,b,0)' + Conj2AtomEdge((2, 1), 1.0), # 'conj2' ('True') to 'False' + Conj2AtomEdge((0, 7), 1.0), # 'conj0' ('True') to 'disj2' + # cond. for 'h(a,b,1)' + Conj2AtomEdge((3, 1), 1.0), # 'conj3' ('True') to 'False' + Conj2AtomEdge((0, 8), 1.0), # 'conj0' ('True') to 'disj3' + # cond. 
for 'h(a,b,2)' + Conj2AtomEdge((4, 1), 1.0), # 'conj4' ('True') to 'False' + Conj2AtomEdge((0, 9), 1.0), # 'conj0' ('True') to 'disj4' + }, + ) + # atom -> aggr + self.assertEqual( + set(edges[("atom", "in", "count")]), + { + Atom2CountEdge((2, 0), 1.0), # 'a' to choice count + Atom2CountEdge((3, 0), 1.0), # 'b' to choice count + Atom2CountEdge((4, 0), 1.0), # 'c' to choice count + }, + ) + self.assertFalse(edges[("atom", "in", "sum")]) + self.assertFalse(edges[("atom", "in", "min")]) + self.assertFalse(edges[("atom", "in", "max")]) + # aggr -> atom + self.assertEqual( + edges[("count", "defines", "atom")], + [ + Count2AtomEdge((0, 6), 1.0), # choice count to 'disj1' (aggr) + ], + ) + self.assertFalse(edges[("sum", "defines", "atom")]) + self.assertFalse(edges[("min", "defines", "atom")]) + self.assertFalse(edges[("max", "defines", "atom")]) + + # node id dictionaries + self.assertEqual(rg.atom_ids[rg.true_const], 0) + self.assertEqual(rg.atom_ids[rg.false_const], 1) + self.assertEqual( + rg.atom_ids, + { + rg.true_const: 0, + rg.false_const: 1, + PredLiteral( + "h", SymbolicConstant("a"), SymbolicConstant("b"), Number(0) + ): 2, + PredLiteral( + "h", SymbolicConstant("a"), SymbolicConstant("b"), Number(1) + ): 3, + PredLiteral( + "h", SymbolicConstant("a"), SymbolicConstant("b"), Number(2) + ): 4, + LiteralCollection(rg.true_const): 5, + AggrLiteral( + AggrCount(), + ( + AggrElement( + [ + Functional( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(0), + ) + ], + [ + PredLiteral( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(0), + ) + ], + ), + AggrElement( + [ + Functional( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(1), + ) + ], + [ + PredLiteral( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(1), + ) + ], + ), + AggrElement( + [ + Functional( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(2), + ) + ], + [ + PredLiteral( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(2), + ) + ], + ), + ), + guards=(Guard(RelOp.EQUAL, Number(1), False), None), + ): 6, + }, + ) + self.assertEqual( + rg.conj_ids, + { + LiteralCollection(rg.true_const): 0, + # NOTE: constraint conjunctions are not re-used + }, + ) + self.assertEqual( + rg.aggr_ids, + { + AggrLiteral( + AggrCount(), + ( + AggrElement( + [ + Functional( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(0), + ) + ], + [ + PredLiteral( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(0), + ) + ], + ), + AggrElement( + [ + Functional( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(1), + ) + ], + [ + PredLiteral( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(1), + ) + ], + ), + AggrElement( + [ + Functional( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(2), + ) + ], + [ + PredLiteral( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(2), + ) + ], + ), + ), + guards=(Guard(RelOp.EQUAL, Number(1), False), None), + ): 0 + }, + ) + + # choice tracking + self.assertFalse(rg.choices) + self.assertFalse(rg.choice_edges) + + # NPP tracking + self.assertEqual( + rg.npps, + { + NPP( + "h", + [SymbolicConstant("a"), SymbolicConstant("b")], + [Number(0), Number(1), Number(2)], + ), + }, + ) + self.assertEqual( + rg.npp_edges, + { + NPP( + "h", + [SymbolicConstant("a"), SymbolicConstant("b")], + [Number(0), Number(1), Number(2)], + ): [ + 0, # (0, 2) + 1, # (0, 3) + 2, # (0, 4) + ] + # '$\wedge_0$' to 'a', 'b', 'c' + }, + ) + + def 
test_NPP_rule(self): + prog = Program.from_string( + r""" + + #npp(h(a,b), [0,1,2]) :- a, not b. + + """ + ) + + # create reasoning graph + rg = ReasoningGraph(prog) + + # zip node and edges attributes for convenience (since order may differ) + nodes = zip_nodes(rg.node_dict) + edges = zip_edges(rg.edge_dict) + + # atom nodes + self.assertEqual( + nodes["atom"], + [ + AtomNode(r"$\top$", 1.0, True), + AtomNode(r"$\bot$", 0.0, True), + AtomNode("h(a,b,0)", 0.0, False), + AtomNode("h(a,b,1)", 0.0, False), + AtomNode("h(a,b,2)", 0.0, False), + AtomNode("a", 0.0, False), + AtomNode("b", 0.0, False), + AtomNode(r"$\vee_0$", 0.0, True), + AtomNode(r"$\vee_1$", 0.0, True), + AtomNode(r"$\vee_2$", 0.0, True), + AtomNode(r"$\vee_3$", 0.0, True), + AtomNode(r"$\vee_4$", 0.0, True), + ], + ) + # conjunction nodes + self.assertEqual( + nodes["conj"], + [ + ConjNode(r"$\wedge_0$", 0.0), + ConjNode(r"$\wedge_1$", 0.0), + ConjNode(r"$\wedge_2$", 0.0), + ConjNode(r"$\wedge_3$", 0.0), + ConjNode(r"$\wedge_4$", 0.0), + ConjNode(r"$\wedge_5$", 0.0), + ], + ) + # aggregate nodes + self.assertEqual( + set(nodes["count"]), {CountNode(r"$\#_0$", 0.0, (0, 1, -1, -1))} + ) + self.assertFalse(nodes["sum"]) + self.assertFalse(nodes["min"]) + self.assertFalse(nodes["max"]) + + # atom -> conj + self.assertEqual( + set(edges[("atom", "in", "conj")]), + { + # condition for 'choices' + Atom2ConjEdge((0, 2), 1.0), # 'True' to 'conj2' + # body to conj. + Atom2ConjEdge((5, 0), 1.0), # 'a' to 'conj2' ('a, not b') + Atom2ConjEdge((6, 0), -1.0), # 'b' to 'conj2' ('a, not b') + # global constraint for choice (given sat. body) + Atom2ConjEdge((7, 1), 1.0), # 'disj0' (body) to 'conj1' + Atom2ConjEdge((8, 1), -1.0), # 'disj1' (aggr) to 'conj1' + # local constraint for 'h(a,b,0)' + Atom2ConjEdge((2, 3), 1.0), # 'h(a,b,0)' to 'conj2' + Atom2ConjEdge((7, 3), 1.0), # 'disj0' (body) to 'conj2' + Atom2ConjEdge((9, 3), -1.0), # 'disj1' (aggr) to 'conj2' + # local constraint for 'h(a,b,1)' + Atom2ConjEdge((3, 4), 1.0), # 'h(a,b,0)' to 'conj3' + Atom2ConjEdge((7, 4), 1.0), # 'disj0' (body) to 'conj3' + Atom2ConjEdge((10, 4), -1.0), # 'disj1' (aggr) to 'conj3' + # local constraint for 'h(a,b,2)' + Atom2ConjEdge((4, 5), 1.0), # 'h(a,b,0)' to 'conj4' + Atom2ConjEdge((7, 5), 1.0), # 'disj0' (body) to 'conj4' + Atom2ConjEdge((11, 5), -1.0), # 'disj1' (aggr) to 'conj4' + }, + ) + # conj -> atom + self.assertEqual( + set(edges[("conj", "defines", "atom")]), + { + # conj. to head + Conj2AtomEdge((0, 2), 1.0), # 'conj0' ('True') to 'h(a,b,0)' + Conj2AtomEdge((0, 3), 1.0), # 'conj0' ('True') to 'h(a,b,1)' + Conj2AtomEdge((0, 4), 1.0), # 'conj0' ('True') to 'h(a,b,2)' + # conj. to aux. body atom + Conj2AtomEdge((0, 7), 1.0), # 'conj0' ('True') to 'disj0' + # global constraint + Conj2AtomEdge((1, 1), 1.0), # 'conj1' ('True') to 'False' + # cond. for 'h(a,b,0)' + Conj2AtomEdge((3, 1), 1.0), # 'conj2' ('True') to 'False' + Conj2AtomEdge((2, 9), 1.0), # 'conj0' ('True') to 'disj2' + # cond. for 'h(a,b,1)' + Conj2AtomEdge((4, 1), 1.0), # 'conj3' ('True') to 'False' + Conj2AtomEdge((2, 10), 1.0), # 'conj0' ('True') to 'disj3' + # cond. 
for 'h(a,b,2)' + Conj2AtomEdge((5, 1), 1.0), # 'conj4' ('True') to 'False' + Conj2AtomEdge((2, 11), 1.0), # 'conj0' ('True') to 'disj4' + }, + ) + # atom -> aggr + self.assertEqual( + set(edges[("atom", "in", "count")]), + { + Atom2CountEdge((2, 0), 1.0), # 'a' to choice count + Atom2CountEdge((3, 0), 1.0), # 'b' to choice count + Atom2CountEdge((4, 0), 1.0), # 'c' to choice count + }, + ) + self.assertFalse(edges[("atom", "in", "sum")]) + self.assertFalse(edges[("atom", "in", "min")]) + self.assertFalse(edges[("atom", "in", "max")]) + # aggr -> atom + self.assertEqual( + edges[("count", "defines", "atom")], + [ + Count2AtomEdge((0, 8), 1.0), # choice count to 'disj1' (aggr) + ], + ) + self.assertFalse(edges[("sum", "defines", "atom")]) + self.assertFalse(edges[("min", "defines", "atom")]) + self.assertFalse(edges[("max", "defines", "atom")]) + + # node id dictionaries + self.assertEqual(rg.atom_ids[rg.true_const], 0) + self.assertEqual(rg.atom_ids[rg.false_const], 1) + self.assertEqual( + rg.atom_ids, + { + rg.true_const: 0, + rg.false_const: 1, + PredLiteral( + "h", SymbolicConstant("a"), SymbolicConstant("b"), Number(0) + ): 2, + PredLiteral( + "h", SymbolicConstant("a"), SymbolicConstant("b"), Number(1) + ): 3, + PredLiteral( + "h", SymbolicConstant("a"), SymbolicConstant("b"), Number(2) + ): 4, + PredLiteral("a"): 5, + PredLiteral("b"): 6, + LiteralCollection(PredLiteral("a"), Naf(PredLiteral("b"))): 7, + AggrLiteral( + AggrCount(), + ( + AggrElement( + [ + Functional( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(0), + ) + ], + [ + PredLiteral( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(0), + ) + ], + ), + AggrElement( + [ + Functional( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(1), + ) + ], + [ + PredLiteral( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(1), + ) + ], + ), + AggrElement( + [ + Functional( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(2), + ) + ], + [ + PredLiteral( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(2), + ) + ], + ), + ), + guards=(Guard(RelOp.EQUAL, Number(1), False), None), + ): 8, + }, + ) + self.assertEqual( + rg.conj_ids, + { + LiteralCollection(PredLiteral("a"), Naf(PredLiteral("b"))): 0, + LiteralCollection(rg.true_const): 2, + # NOTE: constraint conjunctions are not re-used + }, + ) + self.assertEqual( + rg.aggr_ids, + { + AggrLiteral( + AggrCount(), + ( + AggrElement( + [ + Functional( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(0), + ) + ], + [ + PredLiteral( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(0), + ) + ], + ), + AggrElement( + [ + Functional( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(1), + ) + ], + [ + PredLiteral( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(1), + ) + ], + ), + AggrElement( + [ + Functional( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(2), + ) + ], + [ + PredLiteral( + "h", + SymbolicConstant("a"), + SymbolicConstant("b"), + Number(2), + ) + ], + ), + ), + guards=(Guard(RelOp.EQUAL, Number(1), False), None), + ): 0 + }, + ) + + # choice tracking + self.assertFalse(rg.choices) + self.assertFalse(rg.choice_edges) + + # NPP tracking + self.assertEqual( + rg.npps, + { + NPP( + "h", + [SymbolicConstant("a"), SymbolicConstant("b")], + [Number(0), Number(1), Number(2)], + ), + }, + ) + self.assertEqual( + rg.npp_edges, + { + NPP( + "h", + [SymbolicConstant("a"), SymbolicConstant("b")], + 
[Number(0), Number(1), Number(2)], + ): [ + 0, # (0, 2) + 1, # (0, 3) + 2, # (0, 4) + ] + # '$\wedge_0$' to 'a', 'b', 'c' + }, + ) + + def test_4_queens(self): + prog = Program.from_string( + r""" + n(0). n(1). n(2). n(3). + + % choose one row per queen + 1={q(0,0);q(0,1);q(0,2);q(0,3)} :- n(0). + 1={q(1,0);q(1,1);q(1,2);q(1,3)} :- n(1). + 1={q(2,0);q(2,1);q(2,2);q(2,3)} :- n(2). + 1={q(3,0);q(3,1);q(3,2);q(3,3)} :- n(3). + + % check columns + :- q(0,0), q(1,0), 0 < 1. + :- q(0,0), q(2,0), 0 < 2. + :- q(0,0), q(3,0), 0 < 3. + :- q(1,0), q(2,0), 1 < 2. + :- q(1,0), q(3,0), 1 < 3. + :- q(2,0), q(3,0), 2 < 3. + + :- q(0,1), q(1,1), 0 < 1. + :- q(0,1), q(2,1), 0 < 2. + :- q(0,1), q(3,1), 0 < 3. + :- q(1,1), q(2,1), 1 < 2. + :- q(1,1), q(3,1), 1 < 3. + :- q(2,1), q(3,1), 2 < 3. + + :- q(0,2), q(1,2), 0 < 1. + :- q(0,2), q(2,2), 0 < 2. + :- q(0,2), q(3,2), 0 < 3. + :- q(1,2), q(2,2), 1 < 2. + :- q(1,2), q(3,2), 1 < 3. + :- q(2,2), q(3,2), 2 < 3. + + :- q(0,3), q(1,3), 0 < 1. + :- q(0,3), q(2,3), 0 < 2. + :- q(0,3), q(3,3), 0 < 3. + :- q(1,3), q(2,3), 1 < 2. + :- q(1,3), q(3,3), 1 < 3. + :- q(2,3), q(3,3), 2 < 3. + + % check diagonals 1 + :- q(0,0), q(1,1), n(1), 1=0+1, 1=0+1, 1 > 0. + :- q(1,1), q(2,2), n(1), 2=1+1, 2=1+1, 1 > 0. + :- q(2,2), q(3,3), n(1), 3=2+1, 3=2+1, 1 > 0. + + :- q(0,1), q(1,2), n(1), 1=0+1, 2=1+1, 1 > 0. + :- q(1,2), q(2,3), n(1), 2=1+1, 3=2+1, 1 > 0. + + :- q(0,2), q(1,3), n(1), 1=0+1, 3=2+1, 1 > 0. + + :- q(1,0), q(2,1), n(1), 2=1+1, 1=0+1, 1 > 0. + :- q(2,1), q(3,2), n(1), 3=2+1, 2=1+1, 1 > 0. + + :- q(2,0), q(3,1), n(1), 3=2+1, 1=0+1, 1 > 0. + + :- q(0,0), q(2,2), n(2), 2=0+2, 2=0+2, 2 > 0. + :- q(1,1), q(3,3), n(2), 3=1+2, 3=1+2, 2 > 0. + + :- q(0,1), q(2,3), n(2), 2=0+2, 3=1+2, 2 > 0. + + :- q(1,0), q(3,2), n(2), 3=1+2, 2=0+2, 2 > 0. + + :- q(0,0), q(3,3), n(3), 3=0+3, 3=0+3, 3 > 0. + + % check diagonals 2 + :- q(0,3), q(1,2), n(1), 1=0+1, 3=2+1, 1 > 0. + :- q(1,2), q(2,1), n(1), 2=1+1, 2=1+1, 1 > 0. + :- q(2,1), q(3,0), n(1), 3=2+1, 1=0+1, 1 > 0. + + :- q(0,2), q(1,1), n(1), 1=0+1, 2=1+1, 1 > 0. + :- q(1,1), q(2,0), n(1), 2=1+1, 1=0+1, 1 > 0. + + :- q(0,1), q(1,0), n(1), 1=0+1, 1=0+1, 1 > 0. + + :- q(1,3), q(2,2), n(1), 2=1+1, 3=2+1, 1 > 0. + :- q(2,2), q(3,1), n(1), 3=2+1, 2=1+1, 1 > 0. + + :- q(2,3), q(3,2), n(1), 3=2+1, 3=2+1, 1 > 0. + + :- q(0,3), q(2,1), n(2), 2=0+2, 3=1+2, 2 > 0. + :- q(1,2), q(3,0), n(2), 3=1+2, 2=0+2, 2 > 0. + + :- q(0,2), q(2,0), n(2), 2=0+2, 2=0+2, 2 > 0. + + :- q(1,3), q(3,1), n(2), 3=1+2, 3=1+2, 2 > 0. + + :- q(0,3), q(3,0), n(3), 3=0+3, 3=0+3, 3 > 0. 
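+ % each pair of queens above shares a diagonal: X2 = X1+D and Y2 = Y1+D ("diagonals 1") or Y1 = Y2+D ("diagonals 2"), for some D > 0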
+ """ + ) + + # create reasoning graph + rg = ReasoningGraph(prog) + + # zip node and edges attributes for convenience (since order may differ) + nodes = zip_nodes(rg.node_dict) + edges = zip_edges(rg.edge_dict) + + # atom nodes + self.assertEqual( + nodes["atom"], + [ + AtomNode(r"$\top$", 1.0, True), + AtomNode(r"$\bot$", 0.0, True), + AtomNode("n(0)", 1.0, False), + AtomNode("n(1)", 1.0, False), + AtomNode("n(2)", 1.0, False), + AtomNode("n(3)", 1.0, False), + AtomNode("q(0,0)", 0.0, False), + AtomNode("q(0,1)", 0.0, False), + AtomNode("q(0,2)", 0.0, False), + AtomNode("q(0,3)", 0.0, False), + AtomNode(r"$\vee_0$", 0.0, True), + AtomNode(r"$\vee_1$", 0.0, True), + AtomNode(r"$\vee_2$", 0.0, True), + AtomNode(r"$\vee_3$", 0.0, True), + AtomNode(r"$\vee_4$", 0.0, True), + AtomNode(r"$\vee_5$", 0.0, True), + AtomNode("q(1,0)", 0.0, False), + AtomNode("q(1,1)", 0.0, False), + AtomNode("q(1,2)", 0.0, False), + AtomNode("q(1,3)", 0.0, False), + AtomNode(r"$\vee_6$", 0.0, True), + AtomNode(r"$\vee_7$", 0.0, True), + AtomNode(r"$\vee_8$", 0.0, True), + AtomNode(r"$\vee_9$", 0.0, True), + AtomNode(r"$\vee_10$", 0.0, True), + AtomNode(r"$\vee_11$", 0.0, True), + AtomNode("q(2,0)", 0.0, False), + AtomNode("q(2,1)", 0.0, False), + AtomNode("q(2,2)", 0.0, False), + AtomNode("q(2,3)", 0.0, False), + AtomNode(r"$\vee_12$", 0.0, True), + AtomNode(r"$\vee_13$", 0.0, True), + AtomNode(r"$\vee_14$", 0.0, True), + AtomNode(r"$\vee_15$", 0.0, True), + AtomNode(r"$\vee_16$", 0.0, True), + AtomNode(r"$\vee_17$", 0.0, True), + AtomNode("q(3,0)", 0.0, False), + AtomNode("q(3,1)", 0.0, False), + AtomNode("q(3,2)", 0.0, False), + AtomNode("q(3,3)", 0.0, False), + AtomNode(r"$\vee_18$", 0.0, True), + AtomNode(r"$\vee_19$", 0.0, True), + AtomNode(r"$\vee_20$", 0.0, True), + AtomNode(r"$\vee_21$", 0.0, True), + AtomNode(r"$\vee_22$", 0.0, True), + AtomNode(r"$\vee_23$", 0.0, True), + ], + ) + # conjunction nodes + self.assertEqual( + nodes["conj"], + [ + ConjNode(r"$\wedge_0$", 0.0), + ConjNode(r"$\wedge_1$", 0.0), + ConjNode(r"$\wedge_2$", 0.0), + ConjNode(r"$\wedge_3$", 0.0), + ConjNode(r"$\wedge_4$", 0.0), + ConjNode(r"$\wedge_5$", 0.0), + ConjNode(r"$\wedge_6$", 0.0), + ConjNode(r"$\wedge_7$", 0.0), + ConjNode(r"$\wedge_8$", 0.0), + ConjNode(r"$\wedge_9$", 0.0), + ConjNode(r"$\wedge_10$", 0.0), + ConjNode(r"$\wedge_11$", 0.0), + ConjNode(r"$\wedge_12$", 0.0), + ConjNode(r"$\wedge_13$", 0.0), + ConjNode(r"$\wedge_14$", 0.0), + ConjNode(r"$\wedge_15$", 0.0), + ConjNode(r"$\wedge_16$", 0.0), + ConjNode(r"$\wedge_17$", 0.0), + ConjNode(r"$\wedge_18$", 0.0), + ConjNode(r"$\wedge_19$", 0.0), + ConjNode(r"$\wedge_20$", 0.0), + ConjNode(r"$\wedge_21$", 0.0), + ConjNode(r"$\wedge_22$", 0.0), + ConjNode(r"$\wedge_23$", 0.0), + ConjNode(r"$\wedge_24$", 0.0), + # columns constraints + ConjNode(r"$\wedge_25$", 0.0), + ConjNode(r"$\wedge_26$", 0.0), + ConjNode(r"$\wedge_27$", 0.0), + ConjNode(r"$\wedge_28$", 0.0), + ConjNode(r"$\wedge_29$", 0.0), + ConjNode(r"$\wedge_30$", 0.0), + ConjNode(r"$\wedge_31$", 0.0), + ConjNode(r"$\wedge_32$", 0.0), + ConjNode(r"$\wedge_33$", 0.0), + ConjNode(r"$\wedge_34$", 0.0), + ConjNode(r"$\wedge_35$", 0.0), + ConjNode(r"$\wedge_36$", 0.0), + ConjNode(r"$\wedge_37$", 0.0), + ConjNode(r"$\wedge_38$", 0.0), + ConjNode(r"$\wedge_39$", 0.0), + ConjNode(r"$\wedge_40$", 0.0), + ConjNode(r"$\wedge_41$", 0.0), + ConjNode(r"$\wedge_42$", 0.0), + ConjNode(r"$\wedge_43$", 0.0), + ConjNode(r"$\wedge_44$", 0.0), + ConjNode(r"$\wedge_45$", 0.0), + ConjNode(r"$\wedge_46$", 0.0), + ConjNode(r"$\wedge_47$", 
0.0), + ConjNode(r"$\wedge_48$", 0.0), + # diagonal constraints 1 + ConjNode(r"$\wedge_49$", 0.0), + ConjNode(r"$\wedge_50$", 0.0), + ConjNode(r"$\wedge_51$", 0.0), + ConjNode(r"$\wedge_52$", 0.0), + ConjNode(r"$\wedge_53$", 0.0), + ConjNode(r"$\wedge_54$", 0.0), + ConjNode(r"$\wedge_55$", 0.0), + ConjNode(r"$\wedge_56$", 0.0), + ConjNode(r"$\wedge_57$", 0.0), + ConjNode(r"$\wedge_58$", 0.0), + ConjNode(r"$\wedge_59$", 0.0), + ConjNode(r"$\wedge_60$", 0.0), + ConjNode(r"$\wedge_61$", 0.0), + ConjNode(r"$\wedge_62$", 0.0), + # diagonal constraints 2 + ConjNode(r"$\wedge_63$", 0.0), + ConjNode(r"$\wedge_64$", 0.0), + ConjNode(r"$\wedge_65$", 0.0), + ConjNode(r"$\wedge_66$", 0.0), + ConjNode(r"$\wedge_67$", 0.0), + ConjNode(r"$\wedge_68$", 0.0), + ConjNode(r"$\wedge_69$", 0.0), + ConjNode(r"$\wedge_70$", 0.0), + ConjNode(r"$\wedge_71$", 0.0), + ConjNode(r"$\wedge_72$", 0.0), + ConjNode(r"$\wedge_73$", 0.0), + ConjNode(r"$\wedge_74$", 0.0), + ConjNode(r"$\wedge_75$", 0.0), + ConjNode(r"$\wedge_76$", 0.0), + ], + ) + # aggregate nodes + self.assertEqual( + set(nodes["count"]), {CountNode(r"$\#_0$", 0.0, (0, 1, -1, -1))} + ) + self.assertFalse(nodes["sum"]) + self.assertFalse(nodes["min"]) + self.assertFalse(nodes["max"]) + + # atom -> conj + self.assertEqual( + set(edges[("atom", "in", "conj")]), + { + Atom2ConjEdge((0, 0), 1.0), # 'True' to 'conj0' ('True') + Atom2ConjEdge((2, 1), 1.0), + Atom2ConjEdge((6, 3), 1.0), + Atom2ConjEdge((7, 4), 1.0), + Atom2ConjEdge((8, 5), 1.0), + Atom2ConjEdge((9, 6), 1.0), + Atom2ConjEdge((10, 2), 1.0), + Atom2ConjEdge((10, 3), 1.0), + Atom2ConjEdge((10, 4), 1.0), + Atom2ConjEdge((10, 5), 1.0), + Atom2ConjEdge((10, 6), 1.0), + Atom2ConjEdge((11, 2), -1.0), + Atom2ConjEdge((12, 3), -1.0), + Atom2ConjEdge((13, 4), -1.0), + Atom2ConjEdge((14, 5), -1.0), + Atom2ConjEdge((15, 6), -1.0), + Atom2ConjEdge((3, 7), 1.0), + Atom2ConjEdge((16, 9), 1.0), + Atom2ConjEdge((17, 10), 1.0), + Atom2ConjEdge((18, 11), 1.0), + Atom2ConjEdge((19, 12), 1.0), + Atom2ConjEdge((20, 8), 1.0), + Atom2ConjEdge((20, 9), 1.0), + Atom2ConjEdge((20, 10), 1.0), + Atom2ConjEdge((20, 11), 1.0), + Atom2ConjEdge((20, 12), 1.0), + Atom2ConjEdge((21, 8), -1.0), + Atom2ConjEdge((22, 9), -1.0), + Atom2ConjEdge((23, 10), -1.0), + Atom2ConjEdge((24, 11), -1.0), + Atom2ConjEdge((25, 12), -1.0), + Atom2ConjEdge((4, 13), 1.0), + Atom2ConjEdge((26, 15), 1.0), + Atom2ConjEdge((27, 16), 1.0), + Atom2ConjEdge((28, 17), 1.0), + Atom2ConjEdge((29, 18), 1.0), + Atom2ConjEdge((30, 14), 1.0), + Atom2ConjEdge((30, 15), 1.0), + Atom2ConjEdge((30, 16), 1.0), + Atom2ConjEdge((30, 17), 1.0), + Atom2ConjEdge((30, 18), 1.0), + Atom2ConjEdge((31, 14), -1.0), + Atom2ConjEdge((32, 15), -1.0), + Atom2ConjEdge((33, 16), -1.0), + Atom2ConjEdge((34, 17), -1.0), + Atom2ConjEdge((35, 18), -1.0), + Atom2ConjEdge((5, 19), 1.0), + Atom2ConjEdge((36, 21), 1.0), + Atom2ConjEdge((37, 22), 1.0), + Atom2ConjEdge((38, 23), 1.0), + Atom2ConjEdge((39, 24), 1.0), + Atom2ConjEdge((40, 20), 1.0), + Atom2ConjEdge((40, 21), 1.0), + Atom2ConjEdge((40, 22), 1.0), + Atom2ConjEdge((40, 23), 1.0), + Atom2ConjEdge((40, 24), 1.0), + Atom2ConjEdge((41, 20), -1.0), + Atom2ConjEdge((42, 21), -1.0), + Atom2ConjEdge((43, 22), -1.0), + Atom2ConjEdge((44, 23), -1.0), + Atom2ConjEdge((45, 24), -1.0), + # column constraints + Atom2ConjEdge((6, 25), 1.0), + Atom2ConjEdge((16, 25), 1.0), + Atom2ConjEdge((6, 26), 1.0), + Atom2ConjEdge((26, 26), 1.0), + Atom2ConjEdge((6, 27), 1.0), + Atom2ConjEdge((36, 27), 1.0), + Atom2ConjEdge((16, 28), 1.0), + Atom2ConjEdge((26, 
28), 1.0), + Atom2ConjEdge((16, 29), 1.0), + Atom2ConjEdge((36, 29), 1.0), + Atom2ConjEdge((26, 30), 1.0), + Atom2ConjEdge((36, 30), 1.0), + Atom2ConjEdge((7, 31), 1.0), + Atom2ConjEdge((17, 31), 1.0), + Atom2ConjEdge((7, 32), 1.0), + Atom2ConjEdge((27, 32), 1.0), + Atom2ConjEdge((7, 33), 1.0), + Atom2ConjEdge((37, 33), 1.0), + Atom2ConjEdge((17, 34), 1.0), + Atom2ConjEdge((27, 34), 1.0), + Atom2ConjEdge((17, 35), 1.0), + Atom2ConjEdge((37, 35), 1.0), + Atom2ConjEdge((27, 36), 1.0), + Atom2ConjEdge((37, 36), 1.0), + Atom2ConjEdge((8, 37), 1.0), + Atom2ConjEdge((18, 37), 1.0), + Atom2ConjEdge((8, 38), 1.0), + Atom2ConjEdge((28, 38), 1.0), + Atom2ConjEdge((8, 39), 1.0), + Atom2ConjEdge((38, 39), 1.0), + Atom2ConjEdge((18, 40), 1.0), + Atom2ConjEdge((28, 40), 1.0), + Atom2ConjEdge((18, 41), 1.0), + Atom2ConjEdge((38, 41), 1.0), + Atom2ConjEdge((28, 42), 1.0), + Atom2ConjEdge((38, 42), 1.0), + Atom2ConjEdge((9, 43), 1.0), + Atom2ConjEdge((19, 43), 1.0), + Atom2ConjEdge((9, 44), 1.0), + Atom2ConjEdge((29, 44), 1.0), + Atom2ConjEdge((9, 45), 1.0), + Atom2ConjEdge((39, 45), 1.0), + Atom2ConjEdge((19, 46), 1.0), + Atom2ConjEdge((29, 46), 1.0), + Atom2ConjEdge((19, 47), 1.0), + Atom2ConjEdge((39, 47), 1.0), + Atom2ConjEdge((29, 48), 1.0), + Atom2ConjEdge((39, 48), 1.0), + # diagonal constraints 1 + # 1 + Atom2ConjEdge((3, 49), 1.0), + Atom2ConjEdge((6, 49), 1.0), + Atom2ConjEdge((17, 49), 1.0), + Atom2ConjEdge((3, 50), 1.0), + Atom2ConjEdge((17, 50), 1.0), + Atom2ConjEdge((28, 50), 1.0), + Atom2ConjEdge((3, 51), 1.0), + Atom2ConjEdge((28, 51), 1.0), + Atom2ConjEdge((39, 51), 1.0), + Atom2ConjEdge((3, 52), 1.0), + Atom2ConjEdge((7, 52), 1.0), + Atom2ConjEdge((18, 52), 1.0), + Atom2ConjEdge((3, 53), 1.0), + Atom2ConjEdge((18, 53), 1.0), + Atom2ConjEdge((29, 53), 1.0), + Atom2ConjEdge((3, 54), 1.0), + Atom2ConjEdge((8, 54), 1.0), + Atom2ConjEdge((19, 54), 1.0), + Atom2ConjEdge((3, 55), 1.0), + Atom2ConjEdge((16, 55), 1.0), + Atom2ConjEdge((27, 55), 1.0), + Atom2ConjEdge((3, 56), 1.0), + Atom2ConjEdge((27, 56), 1.0), + Atom2ConjEdge((38, 56), 1.0), + Atom2ConjEdge((3, 57), 1.0), + Atom2ConjEdge((26, 57), 1.0), + Atom2ConjEdge((37, 57), 1.0), + # 2 + Atom2ConjEdge((4, 58), 1.0), + Atom2ConjEdge((6, 58), 1.0), + Atom2ConjEdge((28, 58), 1.0), + Atom2ConjEdge((4, 59), 1.0), + Atom2ConjEdge((17, 59), 1.0), + Atom2ConjEdge((39, 59), 1.0), + Atom2ConjEdge((4, 60), 1.0), + Atom2ConjEdge((7, 60), 1.0), + Atom2ConjEdge((29, 60), 1.0), + Atom2ConjEdge((4, 61), 1.0), + Atom2ConjEdge((16, 61), 1.0), + Atom2ConjEdge((38, 61), 1.0), + # 3 + Atom2ConjEdge((5, 62), 1.0), + Atom2ConjEdge((6, 62), 1.0), + Atom2ConjEdge((39, 62), 1.0), + # diagonal constraints 2 + # 1 + Atom2ConjEdge((3, 63), 1.0), + Atom2ConjEdge((9, 63), 1.0), + Atom2ConjEdge((18, 63), 1.0), + Atom2ConjEdge((3, 64), 1.0), + Atom2ConjEdge((18, 64), 1.0), + Atom2ConjEdge((27, 64), 1.0), + Atom2ConjEdge((3, 65), 1.0), + Atom2ConjEdge((27, 65), 1.0), + Atom2ConjEdge((36, 65), 1.0), + Atom2ConjEdge((3, 66), 1.0), + Atom2ConjEdge((8, 66), 1.0), + Atom2ConjEdge((17, 66), 1.0), + Atom2ConjEdge((3, 67), 1.0), + Atom2ConjEdge((17, 67), 1.0), + Atom2ConjEdge((26, 67), 1.0), + Atom2ConjEdge((3, 68), 1.0), + Atom2ConjEdge((7, 68), 1.0), + Atom2ConjEdge((16, 68), 1.0), + Atom2ConjEdge((3, 69), 1.0), + Atom2ConjEdge((19, 69), 1.0), + Atom2ConjEdge((28, 69), 1.0), + Atom2ConjEdge((3, 70), 1.0), + Atom2ConjEdge((28, 70), 1.0), + Atom2ConjEdge((37, 70), 1.0), + Atom2ConjEdge((3, 71), 1.0), + Atom2ConjEdge((29, 71), 1.0), + Atom2ConjEdge((38, 71), 1.0), + # 2 + 
Atom2ConjEdge((4, 72), 1.0), + Atom2ConjEdge((9, 72), 1.0), + Atom2ConjEdge((27, 72), 1.0), + Atom2ConjEdge((4, 73), 1.0), + Atom2ConjEdge((18, 73), 1.0), + Atom2ConjEdge((36, 73), 1.0), + Atom2ConjEdge((4, 74), 1.0), + Atom2ConjEdge((8, 74), 1.0), + Atom2ConjEdge((26, 74), 1.0), + Atom2ConjEdge((4, 75), 1.0), + Atom2ConjEdge((19, 75), 1.0), + Atom2ConjEdge((37, 75), 1.0), + # 3 + Atom2ConjEdge((5, 76), 1.0), + Atom2ConjEdge((9, 76), 1.0), + Atom2ConjEdge((36, 76), 1.0), + }, + ) + # conj -> atom + self.assertEqual( + set(edges[("conj", "defines", "atom")]), + { + # facts + Conj2AtomEdge((0, 2), 1.0), # conj0 (true) -> n(0) + Conj2AtomEdge((0, 3), 1.0), # conj0 (true) -> n(1) + Conj2AtomEdge((0, 4), 1.0), # conj0 (true) -> n(2) + Conj2AtomEdge((0, 5), 1.0), # conj0 (true) -> n(3) + Conj2AtomEdge((0, 12), 1.0), + Conj2AtomEdge((0, 13), 1.0), + Conj2AtomEdge((0, 14), 1.0), + Conj2AtomEdge((0, 15), 1.0), + Conj2AtomEdge((2, 1), 1.0), + Conj2AtomEdge((3, 1), 1.0), + Conj2AtomEdge((4, 1), 1.0), + Conj2AtomEdge((5, 1), 1.0), + Conj2AtomEdge((6, 1), 1.0), + Conj2AtomEdge((1, 6), 1.0), # conj1 (body) -> q(0,0) + Conj2AtomEdge((1, 7), 1.0), # conj1 (body) -> q(0,1) + Conj2AtomEdge((1, 8), 1.0), # conj1 (body) -> q(0,2) + Conj2AtomEdge((1, 9), 1.0), # conj1 (body) -> q(0,3) + Conj2AtomEdge((1, 10), 1.0), # conj1 (body) -> aux. body atom + Conj2AtomEdge((0, 22), 1.0), + Conj2AtomEdge((0, 23), 1.0), + Conj2AtomEdge((0, 24), 1.0), + Conj2AtomEdge((0, 25), 1.0), + Conj2AtomEdge((8, 1), 1.0), + Conj2AtomEdge((9, 1), 1.0), + Conj2AtomEdge((10, 1), 1.0), + Conj2AtomEdge((11, 1), 1.0), + Conj2AtomEdge((12, 1), 1.0), + Conj2AtomEdge((7, 16), 1.0), # conj7 (body) -> q(1,0) + Conj2AtomEdge((7, 17), 1.0), # conj7 (body) -> q(1,1) + Conj2AtomEdge((7, 18), 1.0), # conj7 (body) -> q(1,2) + Conj2AtomEdge((7, 19), 1.0), # conj7 (body) -> q(1,3) + Conj2AtomEdge((7, 20), 1.0), # conj7 (body) -> aux. body atom + Conj2AtomEdge((0, 32), 1.0), + Conj2AtomEdge((0, 33), 1.0), + Conj2AtomEdge((0, 34), 1.0), + Conj2AtomEdge((0, 35), 1.0), + Conj2AtomEdge((14, 1), 1.0), + Conj2AtomEdge((15, 1), 1.0), + Conj2AtomEdge((16, 1), 1.0), + Conj2AtomEdge((17, 1), 1.0), + Conj2AtomEdge((18, 1), 1.0), + Conj2AtomEdge((13, 26), 1.0), # conj13 (body) -> q(2,0) + Conj2AtomEdge((13, 27), 1.0), # conj13 (body) -> q(2,1) + Conj2AtomEdge((13, 28), 1.0), # conj13 (body) -> q(2,2) + Conj2AtomEdge((13, 29), 1.0), # conj13 (body) -> q(2,3) + Conj2AtomEdge((13, 30), 1.0), # conj13 (body) -> aux. body atom + Conj2AtomEdge((0, 42), 1.0), + Conj2AtomEdge((0, 43), 1.0), + Conj2AtomEdge((0, 44), 1.0), + Conj2AtomEdge((0, 45), 1.0), + Conj2AtomEdge((20, 1), 1.0), + Conj2AtomEdge((21, 1), 1.0), + Conj2AtomEdge((22, 1), 1.0), + Conj2AtomEdge((23, 1), 1.0), + Conj2AtomEdge((24, 1), 1.0), + Conj2AtomEdge((19, 36), 1.0), # conj19 (body) -> q(3,0) + Conj2AtomEdge((19, 37), 1.0), # conj19 (body) -> q(3,1) + Conj2AtomEdge((19, 38), 1.0), # conj19 (body) -> q(3,2) + Conj2AtomEdge((19, 39), 1.0), # conj19 (body) -> q(3,3) + Conj2AtomEdge((19, 40), 1.0), # conj19 (body) -> aux. 
body atom + # column constraints + Conj2AtomEdge((25, 1), 1.0), + Conj2AtomEdge((26, 1), 1.0), + Conj2AtomEdge((27, 1), 1.0), + Conj2AtomEdge((28, 1), 1.0), + Conj2AtomEdge((29, 1), 1.0), + Conj2AtomEdge((30, 1), 1.0), + Conj2AtomEdge((31, 1), 1.0), + Conj2AtomEdge((32, 1), 1.0), + Conj2AtomEdge((33, 1), 1.0), + Conj2AtomEdge((34, 1), 1.0), + Conj2AtomEdge((35, 1), 1.0), + Conj2AtomEdge((36, 1), 1.0), + Conj2AtomEdge((37, 1), 1.0), + Conj2AtomEdge((38, 1), 1.0), + Conj2AtomEdge((39, 1), 1.0), + Conj2AtomEdge((40, 1), 1.0), + Conj2AtomEdge((41, 1), 1.0), + Conj2AtomEdge((42, 1), 1.0), + Conj2AtomEdge((43, 1), 1.0), + Conj2AtomEdge((44, 1), 1.0), + Conj2AtomEdge((45, 1), 1.0), + Conj2AtomEdge((46, 1), 1.0), + Conj2AtomEdge((47, 1), 1.0), + Conj2AtomEdge((48, 1), 1.0), + # diagonal constraints 1 + Conj2AtomEdge((49, 1), 1.0), + Conj2AtomEdge((50, 1), 1.0), + Conj2AtomEdge((51, 1), 1.0), + Conj2AtomEdge((52, 1), 1.0), + Conj2AtomEdge((53, 1), 1.0), + Conj2AtomEdge((54, 1), 1.0), + Conj2AtomEdge((55, 1), 1.0), + Conj2AtomEdge((56, 1), 1.0), + Conj2AtomEdge((57, 1), 1.0), + Conj2AtomEdge((58, 1), 1.0), + Conj2AtomEdge((59, 1), 1.0), + Conj2AtomEdge((60, 1), 1.0), + Conj2AtomEdge((61, 1), 1.0), + Conj2AtomEdge((62, 1), 1.0), + # diagonal constraints 2 + Conj2AtomEdge((63, 1), 1.0), + Conj2AtomEdge((64, 1), 1.0), + Conj2AtomEdge((65, 1), 1.0), + Conj2AtomEdge((66, 1), 1.0), + Conj2AtomEdge((67, 1), 1.0), + Conj2AtomEdge((68, 1), 1.0), + Conj2AtomEdge((69, 1), 1.0), + Conj2AtomEdge((70, 1), 1.0), + Conj2AtomEdge((71, 1), 1.0), + Conj2AtomEdge((72, 1), 1.0), + Conj2AtomEdge((73, 1), 1.0), + Conj2AtomEdge((74, 1), 1.0), + Conj2AtomEdge((75, 1), 1.0), + Conj2AtomEdge((76, 1), 1.0), + }, + ) + # atom -> aggr + self.assertEqual( + set(edges[("atom", "in", "count")]), + { + Atom2CountEdge((6, 0), 1.0), + Atom2CountEdge((7, 0), 1.0), + Atom2CountEdge((8, 0), 1.0), + Atom2CountEdge((9, 0), 1.0), + Atom2CountEdge((16, 1), 1.0), + Atom2CountEdge((17, 1), 1.0), + Atom2CountEdge((18, 1), 1.0), + Atom2CountEdge((19, 1), 1.0), + Atom2CountEdge((26, 2), 1.0), + Atom2CountEdge((27, 2), 1.0), + Atom2CountEdge((28, 2), 1.0), + Atom2CountEdge((29, 2), 1.0), + Atom2CountEdge((36, 3), 1.0), + Atom2CountEdge((37, 3), 1.0), + Atom2CountEdge((38, 3), 1.0), + Atom2CountEdge((39, 3), 1.0), + }, + ) + self.assertFalse(edges[("atom", "in", "sum")]) + self.assertFalse(edges[("atom", "in", "min")]) + self.assertFalse(edges[("atom", "in", "max")]) + # aggr -> atom + self.assertEqual( + edges[("count", "defines", "atom")], + [ + Count2AtomEdge((0, 11), 1.0), + Count2AtomEdge((1, 21), 1.0), + Count2AtomEdge((2, 31), 1.0), + Count2AtomEdge((3, 41), 1.0), + ], + ) + self.assertFalse(edges[("sum", "defines", "atom")]) + self.assertFalse(edges[("min", "defines", "atom")]) + self.assertFalse(edges[("max", "defines", "atom")]) + + # choice tracking + self.assertEqual( + rg.choices, + { + Choice( + ( + ChoiceElement(PredLiteral("q", Number(0), Number(0))), + ChoiceElement(PredLiteral("q", Number(0), Number(1))), + ChoiceElement(PredLiteral("q", Number(0), Number(2))), + ChoiceElement(PredLiteral("q", Number(0), Number(3))), + ), + guards=(Guard(RelOp.EQUAL, Number(1), False), None), + ), + Choice( + ( + ChoiceElement(PredLiteral("q", Number(1), Number(0))), + ChoiceElement(PredLiteral("q", Number(1), Number(1))), + ChoiceElement(PredLiteral("q", Number(1), Number(2))), + ChoiceElement(PredLiteral("q", Number(1), Number(3))), + ), + guards=(Guard(RelOp.EQUAL, Number(1), False), None), + ), + Choice( + ( + ChoiceElement(PredLiteral("q", 
Number(2), Number(0))), + ChoiceElement(PredLiteral("q", Number(2), Number(1))), + ChoiceElement(PredLiteral("q", Number(2), Number(2))), + ChoiceElement(PredLiteral("q", Number(2), Number(3))), + ), + guards=(Guard(RelOp.EQUAL, Number(1), False), None), + ), + Choice( + ( + ChoiceElement(PredLiteral("q", Number(3), Number(0))), + ChoiceElement(PredLiteral("q", Number(3), Number(1))), + ChoiceElement(PredLiteral("q", Number(3), Number(2))), + ChoiceElement(PredLiteral("q", Number(3), Number(3))), + ), + guards=(Guard(RelOp.EQUAL, Number(1), False), None), + ), + }, + ) + self.assertEqual(len(rg.choice_edges), 4) + self.assertEqual( + set( + edges[("conj", "defines", "atom")][i] + for i in rg.choice_edges[ + Choice( + ( + ChoiceElement(PredLiteral("q", Number(0), Number(0))), + ChoiceElement(PredLiteral("q", Number(0), Number(1))), + ChoiceElement(PredLiteral("q", Number(0), Number(2))), + ChoiceElement(PredLiteral("q", Number(0), Number(3))), + ), + guards=(Guard(RelOp.EQUAL, Number(1), False), None), + ) + ] + ), + { + Conj2AtomEdge((1, 6), 1.0), + Conj2AtomEdge((1, 7), 1.0), + Conj2AtomEdge((1, 8), 1.0), + Conj2AtomEdge((1, 9), 1.0), + }, + ) + self.assertEqual( + set( + edges[("conj", "defines", "atom")][i] + for i in rg.choice_edges[ + Choice( + ( + ChoiceElement(PredLiteral("q", Number(1), Number(0))), + ChoiceElement(PredLiteral("q", Number(1), Number(1))), + ChoiceElement(PredLiteral("q", Number(1), Number(2))), + ChoiceElement(PredLiteral("q", Number(1), Number(3))), + ), + guards=(Guard(RelOp.EQUAL, Number(1), False), None), + ) + ] + ), + { + Conj2AtomEdge((7, 16), 1.0), + Conj2AtomEdge((7, 17), 1.0), + Conj2AtomEdge((7, 18), 1.0), + Conj2AtomEdge((7, 19), 1.0), + }, + ) + self.assertEqual( + set( + edges[("conj", "defines", "atom")][i] + for i in rg.choice_edges[ + Choice( + ( + ChoiceElement(PredLiteral("q", Number(2), Number(0))), + ChoiceElement(PredLiteral("q", Number(2), Number(1))), + ChoiceElement(PredLiteral("q", Number(2), Number(2))), + ChoiceElement(PredLiteral("q", Number(2), Number(3))), + ), + guards=(Guard(RelOp.EQUAL, Number(1), False), None), + ) + ] + ), + { + Conj2AtomEdge((13, 26), 1.0), + Conj2AtomEdge((13, 27), 1.0), + Conj2AtomEdge((13, 28), 1.0), + Conj2AtomEdge((13, 29), 1.0), + }, + ) + self.assertEqual( + set( + edges[("conj", "defines", "atom")][i] + for i in rg.choice_edges[ + Choice( + ( + ChoiceElement(PredLiteral("q", Number(3), Number(0))), + ChoiceElement(PredLiteral("q", Number(3), Number(1))), + ChoiceElement(PredLiteral("q", Number(3), Number(2))), + ChoiceElement(PredLiteral("q", Number(3), Number(3))), + ), + guards=(Guard(RelOp.EQUAL, Number(1), False), None), + ) + ] + ), + { + Conj2AtomEdge((19, 36), 1.0), + Conj2AtomEdge((19, 37), 1.0), + Conj2AtomEdge((19, 38), 1.0), + Conj2AtomEdge((19, 39), 1.0), + }, + ) + + +# TODO: check that certain_atoms works +# TODO: check individual methods +# TODO: negative aggregates + +if __name__ == "__main__": # pragma: no cover + unittest.main()
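A possible starting point for the first TODO above ("check that certain_atoms works"), sketched here rather than implemented: it assumes ReasoningGraph exposes a certain_atoms collection holding the atoms that are certainly true (the facts shown with value 1.0 in the atom-node assertions, e.g. n(0)); the attribute name and the expected contents are assumptions, not verified against the implementation.

    def test_certain_atoms(self):
        # Sketch only: 'certain_atoms' is assumed to collect the program's
        # facts, which appear with value 1.0 in the atom-node assertions above.
        prog = Program.from_string(
            r"""
            a. b.
            """
        )
        rg = ReasoningGraph(prog)
        self.assertEqual(
            set(rg.certain_atoms),
            {PredLiteral("a"), PredLiteral("b")},
        )

If added, this method would belong inside the test case class above, before the __main__ guard.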