diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6b5814f --- /dev/null +++ b/.gitignore @@ -0,0 +1,442 @@ +# Python-generated files +__pycache__/ +*.py[oc] +build/ +dist/ +wheels/ +scratchpad/ +*.egg-info +.vscode/ + +# IDE +.idea + +# Virtual environments +.venv + +# emacs backup +*~ +\#* + +vllm_backup/ + +# Created by https://www.toptal.com/developers/gitignore/api/python,direnv,visualstudiocode,pycharm,macos,jetbrains +# Edit at https://www.toptal.com/developers/gitignore?templates=python,direnv,visualstudiocode,pycharm,macos,jetbrains + +### direnv ### +.direnv +.envrc + +### JetBrains ### +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# AWS User-specific +.idea/**/aws.xml + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. 
+# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# SonarLint plugin +.idea/sonarlint/ + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + +### JetBrains Patch ### +# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 + +# *.iml +# modules.xml +# .idea/misc.xml +# *.ipr + +# Sonarlint plugin +# https://plugins.jetbrains.com/plugin/7973-sonarlint +.idea/**/sonarlint/ + +# SonarQube Plugin +# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin +.idea/**/sonarIssues.xml + +# Markdown Navigator plugin +# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced +.idea/**/markdown-navigator.xml +.idea/**/markdown-navigator-enh.xml +.idea/**/markdown-navigator/ + +# Cache file creation bug +# See https://youtrack.jetbrains.com/issue/JBR-2257 +.idea/$CACHE_FILE$ + +# CodeStream plugin +# https://plugins.jetbrains.com/plugin/12206-codestream +.idea/codestream.xml + +# Azure Toolkit for IntelliJ plugin +# https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij +.idea/**/azureSettings.xml + +### macOS ### +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns 
+.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +### macOS Patch ### +# iCloud generated files +*.icloud + +### PyCharm ### +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff + +# AWS User-specific + +# Generated files + +# Sensitive or high-churn files + +# Gradle + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. +# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake + +# Mongo Explorer plugin + +# File-based project format + +# IntelliJ + +# mpeltonen/sbt-idea plugin + +# JIRA plugin + +# Cursive Clojure plugin + +# SonarLint plugin + +# Crashlytics plugin (for Android Studio and IntelliJ) + +# Editor-based Rest Client + +# Android studio 3.1+ serialized cache file + +### PyCharm Patch ### +# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 + +# *.iml +# modules.xml +# .idea/misc.xml +# *.ipr + +# Sonarlint plugin +# https://plugins.jetbrains.com/plugin/7973-sonarlint + +# SonarQube Plugin +# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin + +# Markdown Navigator plugin +# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced + +# Cache file creation bug +# See https://youtrack.jetbrains.com/issue/JBR-2257 + +# CodeStream plugin +# https://plugins.jetbrains.com/plugin/12206-codestream + +# Azure Toolkit for IntelliJ plugin +# https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij + +### Python ### +# Byte-compiled / optimized / DLL 
files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. 
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +# End of https://www.toptal.com/developers/gitignore/api/python,direnv,visualstudiocode,pycharm,macos,jetbrains diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py new file mode 100644 index 0000000..b0204df --- /dev/null +++ b/mellea_contribs/va/__init__.py @@ -0,0 +1,9 @@ + + +from .core import Core +from .relation import Relation +from .sequence import Sequence +from .cluster import Cluster +from .subset import Subset +from .pareto import Pareto + diff --git a/mellea_contribs/va/cluster.py b/mellea_contribs/va/cluster.py new file mode 100644 index 0000000..11795ef --- /dev/null +++ b/mellea_contribs/va/cluster.py @@ -0,0 +1,149 @@ +import random +import functools +import itertools +import asyncio +import networkx as nx +import numpy as np +import matplotlib.pyplot as plt +from mellea import MelleaSession +from mellea.backends import Backend +from mellea.stdlib.base import Context +from mellea.stdlib.functional import ainstruct +from mellea.helpers.fancy_logger import FancyLogger +from mellea.helpers.event_loop_helper import _run_async_in_thread + +from pydantic import BaseModel + +from typing import ( + Literal, + Callable, + TypeVar, + List, +) + +import numpy as np + +from .util import session_wrapper, sync_wrapper +from .core import Core, abool + +T = TypeVar("T") + +async def delaunay(elems:list[T], criterion:Callable[[T,T,T],bool], k:int=3) -> nx.Graph: + + assert len(elems) 
>= 2 + + def select(elems:list[T]): + assert len(elems) >= 2 + _elems = elems.copy() + i1 = random.randint(0, len(_elems)-1) + r1 = _elems.pop(i1) + i2 = random.randint(0, len(_elems)-1) + r2 = _elems.pop(i2) + return r1, r2, _elems + + async def split(x:T, y:T, S:list[T]): + """Split a set S into Sx and Sy, which are closer to x or y, respectively""" + + Sx = [] + Sy = [] + tasks = [ + criterion(x, y, z) + for z in S + ] + results = await asyncio.gather(*tasks) + for z, r in zip(S, results): + if r: + Sx.append(z) + else: + Sy.append(z) + + return Sx, Sy + + async def construct(parent, elems): + if len(elems) < 2: + g = nx.Graph() + for elem in elems: + g.add_edge(parent, elem) + return g + + c1, c2, _elems = select(elems) + elems1, elems2 = await split(c1, c2, _elems) + g1, g2 = await asyncio.gather( + construct(c1, elems1), + construct(c2, elems2)) + + g = nx.compose(g1, g2) + g.add_edge(parent, c1) + g.add_edge(parent, c2) + return g + + async def tree(): + r1, r2, _elems = select(elems) + elems1, elems2 = await split(r1, r2, _elems) + g1, g2 = await asyncio.gather( + construct(r1, elems1), + construct(r2, elems2)) + g = nx.compose(g1, g2) + g.add_edge(r1, r2) + return g + + g = nx.Graph() + trees = await asyncio.gather(*[tree() for _ in range(k)]) + for t in trees: + g = nx.compose(g, t) + return g + + +async def atriplet(backend:Backend, ctx:Context, prompt: str, x:str, y:str, z:str, **kwargs) -> bool: + """Given a triplet comparison query, perform the query using the LLM. + + It returns True if Z is closer to X than is to Y. + """ + return await abool(backend, ctx, prompt + f"\nZ: {z}\nX: {x}\nY: {y}\n", **kwargs) + + +async def acluster(backend:Backend, ctx:Context, criterion:str, elems:list[str], + *, + k:int = 3, + n_clusters:int = 3, + **kwargs) -> list[set[str]]: + """Generate an approximate Delaunay graph and perform graph clustering on it. + + The graph construction method follows the n log n algorithm by Haghiri et. al. 
[1] + + Args: + criterion: triplet comparison criterion + elems: list of strings to cluster + k: k for k-ANNS for Delaunay Graph + n_clusters: the number of clusters + + **kwargs: accepts vote, positional, shuffle. + + Returns: + A cluster representation as list[set[str]] + + + [1] Haghiri, Siavash, Debarghya Ghoshdastidar, and Ulrike von Luxburg. + "Comparison-based nearest neighbor search." Artificial Intelligence and Statistics. PMLR, 2017. + + """ + + async def fn(x:str, y:str, z:str) -> bool: + return await atriplet(backend, ctx, criterion, x, y, z, **kwargs) + + g = await delaunay(elems, fn, k=k) + + communities = list(nx.algorithms.community.greedy_modularity_communities(g, cutoff=n_clusters, best_n=n_clusters)) + + return communities + + +class Cluster(Core): + pass + + +Cluster.atriplet = session_wrapper(atriplet) +Cluster.acluster = session_wrapper(acluster) +Cluster.triplet = sync_wrapper(Cluster.atriplet) +Cluster.cluster = sync_wrapper(Cluster.acluster) + diff --git a/mellea_contribs/va/core.py b/mellea_contribs/va/core.py new file mode 100644 index 0000000..e367fb0 --- /dev/null +++ b/mellea_contribs/va/core.py @@ -0,0 +1,104 @@ +import collections +import random +import functools +import itertools +import asyncio +from mellea import MelleaSession +from mellea.backends import Backend +from mellea.stdlib.base import Context +from mellea.stdlib.functional import ainstruct +from mellea.helpers.fancy_logger import FancyLogger +from mellea.helpers.event_loop_helper import _run_async_in_thread + +from pydantic import BaseModel + +from typing import Literal + +from .util import session_wrapper, sync_wrapper + +logger = FancyLogger.get_logger() + +class YesNo(BaseModel): + answer : Literal["yes","no"] + +async def abool(backend:Backend, ctx:Context, prompt:str, vote:int=3, **kwargs) -> bool: + """ + Answers a yes/no question. + """ + + if vote % 2 == 0: + logger.warning( + "the specified number of votes in a majority vote is even, making ties possible. 
Increasing the value by one to avoid this."
+        )
+        vote += 1
+
+    async def fn():
+
+        output, _ = await ainstruct(f"{prompt} Answer yes or no.",
+                                    ctx, backend,
+                                    format=YesNo, **kwargs)
+
+        yesno = YesNo.model_validate_json(output.value)
+
+        return yesno.answer == "yes"
+
+    tasks = [fn() for _ in range(vote)]
+    results = await asyncio.gather(*tasks)
+    return results.count(True) >= (vote // 2 + 1)
+
+
+async def achoice(backend:Backend, ctx:Context, prompt:str, choices:list[str], *, vote:int=3, positional:bool=True, **kwargs) -> str:
+    """
+    Answers a multiple-choice question. Returns an element of choices.
+
+    Args:
+        vote: When >=1, it samples multiple selections in each turn, and perform a majority voting.
+        positional: Shuffle the order to present the elements to the LLM in order to mitigate the positional bias.
+
+    """
+
+    # note: constraint decoding does not respect pydantic.conint
+    L = len(choices)
+    class Choice(BaseModel):
+        answer : Literal[*[ str(i) for i in range(L)]]
+
+    async def choose(choices:list[str]) -> str:
+        output, _ = await ainstruct(f"{prompt}\n" +
+                                    f"Answer the index (0-{L-1}) of one of the following choices: \n" +
+                                    "\n".join([f"index {i}: {c}" for i, c in enumerate(choices)]),
+                                    ctx, backend,
+                                    format=Choice, **kwargs)
+        index = int(Choice.model_validate_json(output.value).answer)
+        return choices[index]
+
+    if positional:
+        # enumerate random permutations while avoiding duplicates (cap at L! so this cannot loop forever)
+        shuffled = set()
+        while len(shuffled) < min(vote, functools.reduce(lambda a, b: a * b, range(1, L + 1), 1)):
+            _choices = choices.copy()
+            random.shuffle(_choices)
+            shuffled.add(tuple(_choices))
+        inputs = list(shuffled)
+    else:
+        inputs = [ choices for _ in range(vote) ]
+
+    tasks = [choose(_choices) for _choices in inputs]
+
+    chosen = await asyncio.gather(*tasks)
+
+    counter = collections.Counter(chosen)
+
+    return counter.most_common(1)[0][0]
+
+
+class Core:
+    """
+    The Core powerup provides a core functionality for extracting the embedded reward model in the model.
+ """ + pass + +Core.abool = session_wrapper(abool) +Core.achoice = session_wrapper(achoice) +Core.bool = sync_wrapper(Core.abool) +Core.choice = sync_wrapper(Core.achoice) + diff --git a/mellea_contribs/va/pareto.py b/mellea_contribs/va/pareto.py new file mode 100644 index 0000000..5b416e9 --- /dev/null +++ b/mellea_contribs/va/pareto.py @@ -0,0 +1,170 @@ +import random +import functools +import itertools +import asyncio +import networkx as nx +import numpy as np +import matplotlib.pyplot as plt +from mellea import MelleaSession +from mellea.backends import Backend +from mellea.stdlib.base import Context +from mellea.stdlib.functional import ainstruct +from mellea.helpers.fancy_logger import FancyLogger +from mellea.helpers.event_loop_helper import _run_async_in_thread + +from pydantic import BaseModel + +from typing import ( + Literal, + Callable, + TypeVar, + List, + Iterable, + Awaitable, +) + +import numpy as np + +from .util import session_wrapper, sync_wrapper +from .sequence import Sequence, asort +from .relation import agt + +T = TypeVar("T") + +async def async_all(awaitables: Iterable[Awaitable[bool]]) -> bool: + """ + Asynchronously evaluate awaitables like built-in all(). 
+ + - Returns False immediately when any awaitable resolves to False + - Cancels all remaining tasks when returning early + - Returns True only if all awaitables resolve to True + """ + tasks = {asyncio.create_task(aw) for aw in awaitables} + + try: + for completed in asyncio.as_completed(tasks): + result = await completed + if not result: + # Cancel remaining tasks + for task in tasks: + if not task.done(): + task.cancel() + # Ensure cancellation is propagated + await asyncio.gather(*tasks, return_exceptions=True) + return False + + return True + + finally: + # Safety net: cancel anything still pending + for task in tasks: + if not task.done(): + task.cancel() + +async def async_any(awaitables: Iterable[Awaitable[bool]]) -> bool: + """ + Asynchronously evaluate awaitables like built-in any(). + + - Returns True immediately when any awaitable resolves to True + - Cancels all remaining tasks when returning early + - Returns False only if all awaitables resolve to False + """ + tasks = {asyncio.create_task(aw) for aw in awaitables} + + try: + for completed in asyncio.as_completed(tasks): + result = await completed + if result: + # Cancel remaining tasks + for task in tasks: + if not task.done(): + task.cancel() + # Ensure cancellation is propagated + await asyncio.gather(*tasks, return_exceptions=True) + return True + + return False + + finally: + # Safety net: cancel anything still pending + for task in tasks: + if not task.done(): + task.cancel() + +async def async_kung(criteria:list[str], elems:list[str], acmp, asort) -> list[str]: + """ + Kung's divide-and-conquer pareto-front computation algorithm, runtime n log^{d-2} n. 
+ """ + + assert len(criteria) > 0 + + if len(criteria) == 1: + return await asort(criteria[0], elems) + + elems = await asort(criteria[0], elems) + + half = len(elems)//2 + elems1, elems2 = elems[:half], elems[half:] + + front1 = await async_kung(criteria[1:], elems1, acmp, asort) + + async def dominated(elem): + return await \ + async_any([ + async_all([ + acmp(criterion, elem, front_elem) + for criterion in criteria[1:] + ]) + for front_elem in front1 + ]) + + elems2_pruned = [] + for elem2 in elems2: + if not await dominated(elem2): + elems2_pruned.append(elem2) + + front2 = await async_kung(criteria[1:], elems2_pruned, acmp, asort) + + return front1 + front2 + + +async def apareto(backend:Backend, ctx:Context, + criteria:list[str], + elems:list[str], + vote:int=3, + positional:bool=True, + shuffle:bool=True, + **kwargs) -> list[str]: + """Compute the pareto front of elements based on multiple comparison criteria. + + The algorithm follows the divide-and-conquer algorithm in [1]. + + Args: + criteria: a list of natural language comparison criteria between X and Y. + elems: list of strings to compute the pareto front for. + + **kwargs: accepts vote, positional, shuffle. + + Returns: + The pareto front as list[str] + + [1] Kung, Luccio, and Preparata "On Finding the Maxima of a Set of Vectors", + Journal of the ACM (JACM) 22.4 (1975): 469-476. 
+
+    """
+
+    async def acmp(criterion, x, y):
+        return await agt(backend, ctx, criterion, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs)
+
+    async def _asort(criterion, elems:list[str]):
+        return await asort(backend, ctx, criterion, elems, vote=vote, positional=positional, shuffle=shuffle, **kwargs)
+
+    return await async_kung(criteria, elems, acmp, _asort)
+
+
+class Pareto(Sequence):
+    pass
+
+Pareto.apareto = session_wrapper(apareto)
+Pareto.pareto = sync_wrapper(Pareto.apareto)
+
diff --git a/mellea_contribs/va/relation.py b/mellea_contribs/va/relation.py
new file mode 100644
index 0000000..4d7d6ce
--- /dev/null
+++ b/mellea_contribs/va/relation.py
@@ -0,0 +1,262 @@
+import random
+import functools
+import itertools
+import asyncio
+from mellea import MelleaSession
+from mellea.backends import Backend
+from mellea.stdlib.base import Context
+from mellea.stdlib.functional import ainstruct
+from mellea.helpers.fancy_logger import FancyLogger
+from mellea.helpers.event_loop_helper import _run_async_in_thread
+
+from pydantic import BaseModel
+
+from typing import Literal
+
+from .util import session_wrapper, sync_wrapper
+from .core import Core, abool
+
+logger = FancyLogger.get_logger()
+
+
+async def abinary(backend:Backend, ctx:Context, criterion:str, x:str, y:str, *,
+                  vote:int=3,
+                  symmetric:bool=False,
+                  asymmetric:bool=False,
+                  reflexive:bool=False,
+                  irreflexive:bool=False,
+                  positional:bool=True,
+                  shuffle:bool=True, **kwargs) -> bool:
+    """Evaluates a query that evaluates a binary relation.
+
+    Args:
+        criterion: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this criteria, and this function returns yes if so.
+        x: the first element
+        y: the second element
+        vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number.
+        symmetric: Declares the relation to be symmetric.
Half of the queries swap x and y. + asymmetric: Declares the relation to be asymmetric. Half of the queries swap x and y, and asks if they violate the criterion. This mitigates LLM's psycophancy bias toward answering "yes". + reflexive: Declares the relation to be reflexive, i.e., if x == y, returns True immediately. + irreflexive: Declares the relation to be irreflexive, i.e., if x == y, returns False immediately. + positional: Permute the order of presenting x and y. This mitigates the positional bias. + shuffle: It shuffles the variation of queries (symmetric/positional variations). + This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). + For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. + Returns: + bool. + """ + + assert not (symmetric and asymmetric), "symmetric and asymmetric flags are mutually exclusive" + + if x == y: + if reflexive: + return True + if irreflexive: + return False + + if vote % 2 == 0: + logger.warning( + "the specified number of votes in a majority vote is even, making ties possible. Increasing the value by one to avoid this." + ) + vote += 1 + + if symmetric: + args = [(x,y),(y,x)] + target = [True,True] + elif asymmetric: + args = [(x,y),(y,x)] + target = [True,False] + else: + args = [(x,y)] + target = [True] + + prompts = [] + for (x, y), t in zip(args, target): + prompts.append((f"Do X and Y satisfy the following criterion? \nCriterion: {criterion}\nX:{x}\nY:{y}", t)) + if positional: + prompts.append((f"Do X and Y satisfy the following criterion? 
\nCriterion: {criterion}\nY:{y}\nX:{x}", t)) + + if shuffle: + random.shuffle(prompts) + + tasks = [ + abool(backend,ctx,p) + for i, (p, t) in zip(range(vote),itertools.cycle(prompts)) + ] + + answers = await asyncio.gather(*tasks) + + answers = [ t == a for (p, t), a in zip(itertools.cycle(prompts), answers) ] + + return answers.count(True) >= (vote // 2) + 1 + + +async def aternary(backend:Backend, ctx:Context, criterion:str, x:str, y:str, z:str, *, + vote:int=3, + symmetric:bool=False, + asymmetric:bool=False, + positional:bool=True, + shuffle:bool=True, + **kwargs) -> bool: + """Evaluates a query that evaluates a ternary relation. + + Args: + criterion: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. + x: the first element + y: the second element + z: the third element + vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. + symmetric: Declares the relation to be symmetric wrto x and y. Half of the queries swap x and y. + asymmetric: Declares the relation to be asymmetric wrto x and y. Half of the queries swap x and y, and asks if they violate the criterion. This mitigates LLM's psycophancy bias toward answering "yes". + positional: The queries permutes the order of presenting x, y, z. This mitigates the positional bias. + shuffle: It shuffles the variation of queries (symmetric/positional variations). + This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). + For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. + Returns: + bool. 
+ """ + + assert not (symmetric and asymmetric), "symmetric and asymmetric flags are mutually exclusive" + + if vote % 2 == 0: + logger.warning( + "the specified number of votes in a majority vote is even, making ties possible. Increasing the value by one to avoid this." + ) + vote += 1 + + if symmetric: + args = [(x,y,z),(y,x,z)] + target = [True,True] + elif asymmetric: + args = [(x,y,z),(y,x,z)] + target = [True,False] + else: + args = [(x,y,z)] + target = [True] + + prompts = [] + for (x, y, z), t in zip(args, target): + parts = [f"X:{x}", f"Y:{y}", f"Z:{z}"] + if positional: + for _parts in itertools.permutations(parts): + prompts.append(("\n".join([f"Do X, Y and Z satisfy the following criterion?", f"Criterion: {criterion}", *_parts]), t)) + else: + prompts.append(("\n".join([f"Do X, Y and Z satisfy the following criterion?", f"Criterion: {criterion}", *parts]), t)) + + if shuffle: + random.shuffle(prompts) + + tasks = [ + abool(backend,ctx,p) + for i, (p, t) in zip(range(vote),itertools.cycle(prompts)) + ] + + answers = await asyncio.gather(*tasks) + + answers = [ t == a for (p, t), a in zip(itertools.cycle(prompts), answers) ] + + return answers.count(True) >= (vote // 2) + 1 + + +async def agt(backend:Backend, ctx:Context, criterion:str, x:str, y:str, *, + vote:int=3, + positional:bool=True, + shuffle:bool=True, **kwargs) -> bool: + """Evaluates a query that evaluates a "greater-than" relation. + + Args: + criterion: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. + x: the first element + y: the second element + vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. + positional: Permute the order of presenting x and y. This mitigates the positional bias. 
+ shuffle: It shuffles the variation of queries (symmetric/positional variations). + This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). + For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. + Returns: + bool. + """ + return await abinary(backend, ctx, criterion, x, y, + vote=vote, + symmetric=False, + asymmetric=True, + reflexive=False, + irreflexive=True, + shuffle=shuffle, **kwargs) + + +async def age(backend:Backend, ctx:Context, criterion:str, x:str, y:str, *, + vote:int=3, + positional:bool=True, + shuffle:bool=True, **kwargs) -> bool: + """Evaluates a query that evaluates a "greater-than-equal" relation. + + Args: + criterion: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. + x: the first element + y: the second element + vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. + positional: Permute the order of presenting x and y. This mitigates the positional bias. + shuffle: It shuffles the variation of queries (symmetric/positional variations). + This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). + For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. + Returns: + bool. + """ + return await abinary(backend, ctx, criterion, x, y, + vote=vote, + symmetric=False, + asymmetric=True, + reflexive=True, + irreflexive=False, + shuffle=shuffle, **kwargs) + + +async def aeq(backend:Backend, ctx:Context, criterion:str, x:str, y:str, *, + vote:int=3, + positional:bool=True, + shuffle:bool=True, **kwargs) -> bool: + """Evaluates a query that evaluates an equivalence relation. 
+ + Args: + criterion: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. + x: the first element + y: the second element + vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. + positional: Permute the order of presenting x and y. This mitigates the positional bias. + shuffle: It shuffles the variation of queries (symmetric/positional variations). + This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). + For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. + Returns: + bool. + """ + return await abinary(backend, ctx, criterion, x, y, + vote=vote, + symmetric=True, + asymmetric=False, + reflexive=True, + irreflexive=False, + shuffle=shuffle, **kwargs) + + + +class Relation(Core): + """ + The Relation powerup defines methods for binary and ternary predicates. + Options can be used to declare the property of the predicate, + such as being symmetric or reflexive with regard to certain arguments. 
+ """ + pass + + +Relation.abinary = session_wrapper(abinary) +Relation.aternary = session_wrapper(aternary) +Relation.agt = session_wrapper(agt) +Relation.age = session_wrapper(age) +Relation.aeq = session_wrapper(aeq) + +Relation.binary = sync_wrapper(Relation.abinary) +Relation.ternary = sync_wrapper(Relation.aternary) +Relation.gt = sync_wrapper(Relation.agt) +Relation.ge = sync_wrapper(Relation.age) +Relation.eq = sync_wrapper(Relation.aeq) + diff --git a/mellea_contribs/va/sequence.py b/mellea_contribs/va/sequence.py new file mode 100644 index 0000000..e5cfac6 --- /dev/null +++ b/mellea_contribs/va/sequence.py @@ -0,0 +1,223 @@ +import random +import functools +import itertools +import asyncio +from mellea import MelleaSession +from mellea.backends import Backend +from mellea.stdlib.base import Context +from mellea.stdlib.functional import ainstruct +from mellea.helpers.fancy_logger import FancyLogger +from mellea.helpers.event_loop_helper import _run_async_in_thread + +from pydantic import BaseModel + +from typing import Literal + +from .util import session_wrapper, sync_wrapper +from .core import abool +from .relation import Relation, agt + + +async def async_merge_sort(lst:list[str], acmp): + if len(lst) <= 1: + return lst + mid = len(lst) // 2 + left = await async_merge_sort(lst[:mid], acmp) + right = await async_merge_sort(lst[mid:], acmp) + return await async_merge(left, right, acmp) + +async def async_merge(left:list[str], right:list[str], acmp): + result = [] + while left and right: + if await acmp(left[0], right[0]): + result.append(left.pop(0)) + else: + result.append(right.pop(0)) + return result + left + right + +async def async_max(lst:list[str], acmp): + if len(lst) <= 1: + return lst[0] + mid = len(lst) // 2 + left = await async_max(lst[:mid], acmp) + right = await async_max(lst[mid:], acmp) + if await acmp(left, right): + return left + else: + return right + +async def async_mom(seq:list[str], asort, block_size=5): + """ + Median of medians 
algorithm for finding an approximate median. Worst-case runtime O(n) + """ + + async def median_fixed(seq): + return (await asort(seq))[len(seq)//2] + + if len(seq) <= block_size: + return await median_fixed(seq) + + # Step 1: Divide into groups of block_size + groups = itertools.batched(seq, block_size) + + # Step 2: Find median of each group + medians = await asyncio.gather(*[median_fixed(g) for g in groups]) + + # Step 3: Recursively find the pivot + return await async_mom(medians, asort, block_size=block_size) + +async def async_quickselect(seq:list[str], k, acmp, asort, block_size=5): + """ + Quickselect algorithm that uses median-of-medians for pivot selection. Worst-case runtime O(n^2) + """ + + pivot = await async_mom(medians, asort, block_size=block_size) + + # Step 4: Partition + lows, highs = [], [] + for x in seq: + if await acmp(x, pivot): + lows.append(x) + else: + highs.append(x) + + # Step block_size: Recurse + if k < len(lows): + return await async_quickselect(lows, k, acmp, asort, block_size=block_size) + elif k == len(lows): + return pivot + else: + return await async_quickselect(highs, k - len(lows), acmp, asort, block_size=block_size) + + +class MapOutput(BaseModel): + answer : str + + +async def amap(backend:Backend, ctx:Context, variable:str, output:str, elems:list[str], **kwargs) -> list[str]: + + tasks = [ + ainstruct(f"Given a value of {variable}, answer the answer of the output. \n"+ + f"Z: 3\n" + + f"Output: twice the value of Z\n" + + '{"answer":"6"}\n' + + f"{variable}: {elem}\n" + + f"Output: {output}", + ctx, backend, + format=MapOutput) + for elem in elems + ] + + return [MapOutput.model_validate_json(o.value).answer for o,_ in await asyncio.gather(*tasks)] + +async def afind(backend:Backend, ctx:Context, variable:str, criterion:str, elems:list[str], **kwargs) -> str | None: + + """ + Returns any element which satisfies the criterion about the variable. 
+ It checks the criterion over the elements concurrently and returns the earliest element that satisfied the criterion, + cancelling all running or pending LLM calls. + + Args: + vote: When >=1, it samples multiple selections in each turn, and perform a majority voting. + """ + + async def fn(elem): + answer = await abool(backend, ctx, f"Does {variable} satisfy the criterion?\n"+ + f"{variable}: {elem}\n"+ + f"Criterion: {criterion}", + **kwargs) + return answer, elem + + tasks = [fn(elem) for elem in elems] + + for task in asyncio.as_completed(tasks): + answer, elem = await task + if answer: + return elem + + pass + +async def amerge(backend:Backend, ctx:Context, criterion:str, elems1:list[str], elems2:list[str], *, + vote:int=3, + positional:bool=True, + shuffle:bool=True, **kwargs) -> list[str]: + """ + Given two lists already sorted according to the criterion, + merge them into a list so that the resulting list is also sorted according to the criterion. + """ + + async def acmp(x, y): + return await agt(backend, ctx, criterion, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs) + + return await async_merge(elems1, elems2, acmp) + +async def asort(backend:Backend, ctx:Context, criterion:str, elems:list[str], *, + vote:int=3, + positional:bool=True, + shuffle:bool=True, **kwargs) -> list[str]: + + async def acmp(x, y): + return await agt(backend, ctx, criterion, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs) + + return await async_merge_sort(elems, acmp) + +async def amax(backend:Backend, ctx:Context, criterion:str, elems:list[str], *, + vote:int=3, + positional:bool=True, + shuffle:bool=True, **kwargs) -> str: + + async def acmp(x, y): + return await agt(backend, ctx, criterion, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs) + + return await async_max(elems, acmp) + +async def amedian(backend:Backend, ctx:Context, criterion:str, elems:list[str], *, + exact = False, + vote:int=3, + positional:bool=True, + 
                  shuffle:bool=True,
                  block_size:int=5,
                  **kwargs) -> str:
    """
    Return the (approximate) median of elems under the LLM-induced ordering.

    If exact = True, use quickselect.
    Otherwise, return the approximate median returned by median of medians.
    """

    # Pairwise "greater-than" comparison delegated to the LLM.
    async def acmp(x, y):
        return await agt(backend, ctx, criterion, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs)

    # Full LLM-based sort; median-of-medians applies it to small blocks only.
    async def _asort(elems:list[str]):
        return await asort(backend, ctx, criterion, elems, vote=vote, positional=positional, shuffle=shuffle, **kwargs)

    if exact:
        return await async_quickselect(elems, len(elems)//2, acmp, _asort, block_size=block_size)
    else:
        return await async_mom(elems, _asort, block_size=block_size)


class Sequence(Relation):
    """
    Sequence powerup provides a set of sequence operations, such as
    mapping a list of strings,
    sorting a list of strings,
    selecting an element, or extracting the median according to some criterion.
    """

    pass



Sequence.amap = session_wrapper(amap)
Sequence.afind = session_wrapper(afind)
Sequence.amerge = session_wrapper(amerge)
Sequence.asort = session_wrapper(asort)
Sequence.amax = session_wrapper(amax)
Sequence.amedian = session_wrapper(amedian)

Sequence.map = sync_wrapper(Sequence.amap)
Sequence.find = sync_wrapper(Sequence.afind)
Sequence.merge = sync_wrapper(Sequence.amerge)
Sequence.sort = sync_wrapper(Sequence.asort)
Sequence.max = sync_wrapper(Sequence.amax)
Sequence.median = sync_wrapper(Sequence.amedian)

diff --git a/mellea_contribs/va/subset.py b/mellea_contribs/va/subset.py
new file mode 100644
index 0000000..4aba733
--- /dev/null
+++ b/mellea_contribs/va/subset.py
@@ -0,0 +1,132 @@
import random
import functools
import itertools
import asyncio
from mellea import MelleaSession
from mellea.backends import Backend
from mellea.stdlib.base import Context
from mellea.stdlib.functional import ainstruct
from mellea.helpers.fancy_logger import FancyLogger
from mellea.helpers.event_loop_helper import _run_async_in_thread

from pydantic import 
BaseModel + +from typing import Literal + +from .util import session_wrapper, sync_wrapper +from .core import Core, abool, achoice + + +async def afilter(backend:Backend, ctx:Context, + variable: str, + criterion: str, + elems:list[str], + *, + vote:int=3, + **kwargs) -> list[str]: + """ + Returns a subset whose elements all satisfy the criterion. + + Args: + vote: When >=1, it samples multiple selections in each turn, and perform a majority voting. + """ + + if vote % 2 == 0: + logger.warning( + "the specified number of votes in a majority vote is even, making ties possible. Increasing the value by one to avoid this." + ) + vote += 1 + + async def per_elem(elem): + tasks = [ + abool(backend, ctx, f"Does {variable} satisfy the criterion?\n"+ + f"{variable}: {elem}\n"+ + f"Criterion: {criterion}") + for _ in range(vote) + ] + return asyncio.gather(*tasks).count(True) >= (vote // 2 + 1) + + tasks = [ + per_elem(elem) + for elem in elems + ] + + results = [] + for answer, elem in zip(asyncio.gather(*tasks), elems): + if answer: + results.append(elem) + + return results + + +async def asubset(backend:Backend, ctx:Context, + description:str, + criterion: str, + elems:list[str], + k:int, + *, + vote:int=3, + positional:bool=True, + **kwargs) -> list[str]: + """ + Greedily select a k-elements subset from elems, maximizing the given criterion. + + Args: + description: A decription of what the current and the output subset is meant to represent. + criterion: A decription of the desired property of the returned subset. + elems: The universe to select the subset from. + k: The number of elements to select from elems. + vote: When >=1, it samples multiple selections in each turn, and perform a majority voting. + positional: Shuffle the order to present the elements to the LLM in order to mitigate the positional bias. + + The criterion is assumed to be contain a modular or submodular aspect. 
+ + + Example 1: + + description = "We are building a team of culturally diverse members." + + criterion = "Maximize the cultural diversity among the members." + + + Example 2: + + description = ("We need set of past legal cases that helps defending our case. " + "In our case, the defandant has ..." + "We want to see a variety of cases that are relevant to ours but" + "are also different from each other.") + + criterion = "Minimize the ovelap with the documents in the current set while staying relevant to our case." + """ + + current = [] + remaining = elems.copy() + + for _ in range(k): + chosen = await achoice(backend, ctx, f"{description}\n" + "Choose the best element to add to the current set following the criterion:\n" + f"Criterion: {criterion}\n" + + "Current set:\n" + + "\n".join(current) + "\n", + remaining, + vote=vote, + positional=positional, + **kwargs) + current.append(chosen) + remaining.remove(chosen) + + return current + + +class Subset(Core): + """ + Subset powerup provides methods for selecting a subset of the input set. 
+ """ + pass + + +Subset.afilter = session_wrapper(afilter) +Subset.asubset = session_wrapper(asubset) + +Subset.filter = sync_wrapper(Subset.afilter) +Subset.subset = sync_wrapper(Subset.asubset) diff --git a/mellea_contribs/va/util.py b/mellea_contribs/va/util.py new file mode 100644 index 0000000..485395f --- /dev/null +++ b/mellea_contribs/va/util.py @@ -0,0 +1,22 @@ + +import functools +from mellea.helpers.event_loop_helper import _run_async_in_thread + + +def session_wrapper(fn): + """Wrap an async function so it can be called synchronously.""" + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + return fn(self.backend, self.ctx, *args, **kwargs) + return wrapper + + + +def sync_wrapper(async_fn): + """Wrap an async function so it can be called synchronously.""" + @functools.wraps(async_fn) + def wrapper(*args, **kwargs): + return _run_async_in_thread(async_fn(*args, **kwargs)) + return wrapper + + diff --git a/test/va/test_cluster.py b/test/va/test_cluster.py new file mode 100644 index 0000000..31fc8e6 --- /dev/null +++ b/test/va/test_cluster.py @@ -0,0 +1,97 @@ +import asyncio +import pytest + +from mellea_va import Cluster + +from mellea import MelleaSession, start_session +from mellea.backends.ollama import OllamaModelBackend +from mellea.backends.types import ModelOption +from mellea.stdlib.base import CBlock, SimpleContext +from mellea.stdlib.requirement import Requirement, simple_validate + +# @pytest.fixture(scope="module") +# def m() -> MelleaSession: +# return MelleaSession(backend=OllamaModelBackend(), ctx=ChatContext()) + +@pytest.fixture(scope="function") +def m(): + """Fresh Ollama session for each test.""" + session = start_session() + yield session + session.reset() + + +def rand_index(clusters_pred:dict[tuple[float, float], int], + clusters_true:dict[tuple[float, float], int]): + """ + Measures the agreement between two clusters. 
+ """ + + count = 0 + total = 0 + for i1, (e1, cid_pred1) in enumerate(sorted(clusters_pred.items())): + cid_true1 = clusters_true[e1] + for i2, (e2, cid_pred2) in enumerate(sorted(clusters_pred.items())): + cid_true2 = clusters_true[e2] + + if i1 >= i2: + continue + + same_pred = (cid_pred1 == cid_pred2) + same_true = (cid_true1 == cid_true2) + + total += 1 + if same_pred == same_true: # both true or both false + count += 1 + + N = len(clusters_pred) + assert total == ((N * (N-1)) // 2) + return count / ((N * (N-1)) // 2) + + +async def test_cluster(m: MelleaSession): + """""" + + MelleaSession.powerup(Cluster) + + assert await m.atriplet( + "Is country Z culturally closer to country X than is to country Y?", + "United States", "Japan", "Canada"), "cultural similarity test US/Canada/Japan" + + communities = await m.acluster( + # "Is country Z geographically closer to country X than is to country Y?", + # ["Canada", "United States", + # "Spain", "France", + # "China", "Japan"], + "Is color Z more similar to color X than is to color Y?", + ["red", "pink", + "blue", "cyan", + "green", "lime"], + k=5, + n_clusters=3) + + clusters_pred = dict() + for cid, comm in enumerate(communities): + print(cid, comm) + for node in comm: + clusters_pred[node] = cid + + clusters_true = { + # "Canada":0, + # "United States":0, + # "Spain":1, + # "France":1, + # "China":2, + # "Japan":2, + "red":0, + "pink":0, + "blue":1, + "cyan":1, + "green":2, + "lime":2, + } + assert rand_index(clusters_pred, clusters_true) > 0.9 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/va/test_cluster_2d.py b/test/va/test_cluster_2d.py new file mode 100644 index 0000000..5e8e371 --- /dev/null +++ b/test/va/test_cluster_2d.py @@ -0,0 +1,213 @@ +""" +Testing Delaunay Graph Clustering approach on 2D points. 
+(for VA, we replace the triplet comparison with LLM-based one) +""" + +import pytest +import random +import networkx as nx +import numpy as np +import matplotlib.pyplot as plt + +Point = tuple[float,float] + +def points( + n_clusters=5, + points_per_cluster=20, + radius=10.0, + cluster_std=0.5, + seed=None, +): + """ + Generate 2D points in clusters centered at the vertices of a regular polyhedra. + + Returns + ------- + points : np.ndarray of shape (n_clusters*points_per_cluster, 2) + The generated 2D points. + """ + if seed is not None: + np.random.seed(seed) + + angles = np.linspace(0, 2 * np.pi, n_clusters, endpoint=False) + centers = np.column_stack([radius * np.cos(angles), + radius * np.sin(angles)]) + + # Generate clusters + points = [] + for cx, cy in centers: + cluster = np.random.normal( + loc=[cx, cy], + scale=cluster_std, + size=(points_per_cluster, 2) + ) + points.append(cluster) + + return np.vstack(points), points + +def criteria(x, y, z): + x = np.array(x) + y = np.array(y) + z = np.array(z) + return np.square(x-z).sum() < np.square(y-z).sum() + +def delaunay(elems, k=7): + + assert len(elems) >= 2 + + g = nx.Graph() + for elem in elems: + g.add_node(elem) + + for _ in range(k): + + def select(elems:list[Point]): + assert len(elems) >= 2 + _elems = elems.copy() + i1 = random.randint(0, len(_elems)-1) + r1 = _elems.pop(i1) + i2 = random.randint(0, len(_elems)-1) + r2 = _elems.pop(i2) + return r1, r2, _elems + + def split(x:Point, y:Point, S:list[Point]): + """Split a set S into Sx and Sy, which are closer to x or y, respectively""" + + Sx = [] + Sy = [] + for z in S: + if criteria(x, y, z): + Sx.append(z) + else: + Sy.append(z) + + return Sx, Sy + + def construct(parent, elems): + if len(elems) < 2: + for elem in elems: + g.add_edge(parent, elem) + return + + c1, c2, _elems = select(elems) + g.add_edge(parent, c1) + g.add_edge(parent, c2) + + elems1, elems2 = split(c1, c2, _elems) + construct(c1, elems1) + construct(c2, elems2) + + r1, r2, _elems = 
select(elems) + g.add_edge(r1, r2) + elems1, elems2 = split(r1, r2, _elems) + construct(r1, elems1) + construct(r2, elems2) + + return g + +def plot(g, communities): + """ + Plot the graph g whose nodes are 2D points [x, y]. + Also compute greedy modularity communities and color + nodes by community assignment. + """ + + # Assign a color index to each node + node_color = {} + for cid, comm in enumerate(communities): + for node in comm: + node_color[node] = cid + + # Color palette + # If many clusters, matplotlib cycles automatically + colors = [node_color[n] for n in g.nodes] + + # --- Extract node positions --- + xs = [node[0] for node in g.nodes] + ys = [node[1] for node in g.nodes] + + plt.figure(figsize=(7, 7)) + + # --- Draw edges --- + for u, v in g.edges: + plt.plot([u[0], v[0]], [u[1], v[1]], linewidth=0.8, color="gray", alpha=0.2) + + # --- Draw nodes --- + sc = plt.scatter(xs, ys, c=colors, cmap="tab10", s=35) + + plt.gca().set_aspect('equal', 'box') + plt.title("Graph with Greedy Modularity Communities") + plt.xlabel("x") + plt.ylabel("y") + plt.grid(True) + + cbar = plt.colorbar(sc) + cbar.set_label("Community ID") + + plt.savefig("test_cluster_2d.png") + # plt.show() + + pass + +def rand_index(clusters_pred:dict[tuple[float, float], int], + clusters_true:dict[tuple[float, float], int]): + """ + Measures the agreement between two clusters. 
+ """ + + count = 0 + total = 0 + for i1, (e1, cid_pred1) in enumerate(sorted(clusters_pred.items())): + cid_true1 = clusters_true[e1] + for i2, (e2, cid_pred2) in enumerate(sorted(clusters_pred.items())): + cid_true2 = clusters_true[e2] + + if i1 >= i2: + continue + + same_pred = (cid_pred1 == cid_pred2) + same_true = (cid_true1 == cid_true2) + + total += 1 + if same_pred == same_true: # both true or both false + count += 1 + + N = len(clusters_pred) + assert total == ((N * (N-1)) // 2) + return count / ((N * (N-1)) // 2) + +def test_cluster_2d(): + elems, clusters = points(points_per_cluster=30) + elems = [ tuple(p) for p in elems ] + g = delaunay(elems, k=7) + + # To obtain exactly n communities, set both cutoff and best_n to n. + # https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.community.modularity_max.greedy_modularity_communities.html + communities = list(nx.algorithms.community.greedy_modularity_communities(g, cutoff=5, best_n=5)) + + plot(g, communities) + + clusters_pred = dict() + for cid, comm in enumerate(communities): + for node in comm: + clusters_pred[(node[0], node[1])] = cid + + clusters_true = dict() + for cid, comm in enumerate(clusters): + for node in comm: + clusters_true[(node[0], node[1])] = cid + + + # for (e1, cid_pred1), (e2, cid_pred2) in zip(sorted(clusters_pred.items()), + # sorted(clusters_true.items()),): + # assert e1 == e2 + # print(e1, cid_pred1, cid_pred2) + + assert len(clusters) == len(communities) + assert len(clusters_true) == len(clusters_pred) + assert set(clusters_true.keys()) == set(clusters_pred.keys()) + + assert rand_index(clusters_pred, clusters_true) > 0.9 + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/va/test_core.py b/test/va/test_core.py new file mode 100644 index 0000000..c9c58f2 --- /dev/null +++ b/test/va/test_core.py @@ -0,0 +1,58 @@ +import asyncio +import pytest + +import time + +from mellea_va import Core + +from mellea import 
MelleaSession, start_session +from mellea.backends.ollama import OllamaModelBackend +from mellea.backends.types import ModelOption +from mellea.stdlib.base import CBlock, SimpleContext +from mellea.stdlib.requirement import Requirement, simple_validate + +# @pytest.fixture(scope="module") +# def m() -> MelleaSession: +# return MelleaSession(backend=OllamaModelBackend(), ctx=ChatContext()) + +@pytest.fixture(scope="function") +def m(): + """Fresh Ollama session for each test.""" + session = start_session() + yield session + session.reset() + + + +async def test_core(m: MelleaSession): + """""" + + MelleaSession.powerup(Core) + + assert await m.abool("Is 1+1=2?") + + assert await m.achoice("Which city is in the United States?", + ["Tokyo","Boston","Paris","Melbourne"]) == "Boston" + + t1 = time.time() + await m.abool("Is 1+1=2?", vote=1) + t2 = time.time() + dt1 = t2-t1 + + t1 = time.time() + await m.abool("Is 1+1=2?", vote=5) + t2 = time.time() + dt2 = t2-t1 + + r = dt2 / dt1 + print(f"asynchronouos efficiency for 5 calls: {r}") + assert 3 < r < 5, "asynchronous call efficiency test" + + assert m.bool("Is 1+1=2?") + + assert m.choice("Which city is in the United States?", + ["Tokyo","Boston","Paris","Melbourne"]) == "Boston" + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/va/test_pareto.py b/test/va/test_pareto.py new file mode 100644 index 0000000..b786672 --- /dev/null +++ b/test/va/test_pareto.py @@ -0,0 +1,41 @@ +import asyncio +import pytest + +from mellea_va import Pareto + +from mellea import MelleaSession, start_session +from mellea.backends.ollama import OllamaModelBackend +from mellea.backends.types import ModelOption +from mellea.stdlib.base import CBlock, SimpleContext +from mellea.stdlib.requirement import Requirement, simple_validate + +# @pytest.fixture(scope="module") +# def m() -> MelleaSession: +# return MelleaSession(backend=OllamaModelBackend(), ctx=ChatContext()) + +@pytest.fixture(scope="function") +def m(): + 
"""Fresh Ollama session for each test.""" + session = start_session() + yield session + session.reset() + + + +async def test_pareto(m: MelleaSession): + """""" + + MelleaSession.powerup(Pareto) + + assert await m.amax("Is country X larger than country Y by area?", + ["France", "United States", "Australia", "Singapore"]) == "United States" + assert await m.amax("Is the country X more Asian than the country Y?", + ["France", "United States", "Australia", "Singapore"]) == "Singapore" + + assert set(await m.apareto(["Is country X larger than country Y by area?", + "Is the country X more Asian than the country Y?"], + ["France", "United States", "Australia", "Singapore"])) == set(["United States", "Singapore"]) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/va/test_relation.py b/test/va/test_relation.py new file mode 100644 index 0000000..47d6054 --- /dev/null +++ b/test/va/test_relation.py @@ -0,0 +1,39 @@ +import asyncio +import pytest + +from mellea_va import Relation + +from mellea import MelleaSession, start_session +from mellea.backends.ollama import OllamaModelBackend +from mellea.backends.types import ModelOption +from mellea.stdlib.base import CBlock, SimpleContext +from mellea.stdlib.requirement import Requirement, simple_validate + +# @pytest.fixture(scope="module") +# def m() -> MelleaSession: +# return MelleaSession(backend=OllamaModelBackend(), ctx=ChatContext()) + +@pytest.fixture(scope="function") +def m(): + """Fresh Ollama session for each test.""" + session = start_session() + yield session + session.reset() + + + +async def test_relation(m: MelleaSession): + """""" + + MelleaSession.powerup(Relation) + + assert await m.agt("Is the number X larger than the number Y?", "2", "1"), "number test 2>1" + assert await m.agt("Is the number X smaller than the number Y?", "1", "2"), "number test 1<2" + assert await m.agt("Is the country X larger than the country Y by area?", "United States", "Japan"), "area test US > Japan" + 
assert await m.agt("Is the country X more densely populated than the country Y?", "Japan", "United States"), "population density test US > Japan" + + + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/va/test_sequence.py b/test/va/test_sequence.py new file mode 100644 index 0000000..2193899 --- /dev/null +++ b/test/va/test_sequence.py @@ -0,0 +1,43 @@ +import asyncio +import pytest + +from mellea_va import Sequence + +from mellea import MelleaSession, start_session +from mellea.backends.ollama import OllamaModelBackend +from mellea.backends.types import ModelOption +from mellea.stdlib.base import CBlock, SimpleContext +from mellea.stdlib.requirement import Requirement, simple_validate + +# @pytest.fixture(scope="module") +# def m() -> MelleaSession: +# return MelleaSession(backend=OllamaModelBackend(), ctx=ChatContext()) + +@pytest.fixture(scope="function") +def m(): + """Fresh Ollama session for each test.""" + session = start_session() + yield session + session.reset() + + + +async def test_sequence(m: MelleaSession): + """""" + + MelleaSession.powerup(Sequence) + + assert await m.amap("X", "a noun corresponding to X with the opposite sex", ["ox", "rooster"]) == ["cow", "hen"] + + assert await m.afind("X", "X is a plant", ["ox", "hen", "carrot", "car"]) == "carrot" + + assert await m.asort("Is country X larger than country Y by area?", + ["France", "United States", "Nigeria", "Singapore"]) == ["United States", "Nigeria", "France", "Singapore"] + assert await m.amax("Is country X larger than country Y by area?", + ["France", "United States", "Australia", "Singapore"]) == "United States" + assert await m.amedian("Is the latitude of city X larger than the latitude of city Y?", + ["Yakutsk","Taipei","Edinburgh","Singapore","Melbourne"]) == "Taipei" + + +if __name__ == "__main__": + pytest.main([__file__])