diff --git a/plugins/vikings/pyproject.toml b/plugins/vikings/pyproject.toml new file mode 100644 index 000000000..7fb331e33 --- /dev/null +++ b/plugins/vikings/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "vikings" +description = "A collection of bearded and aggressive assistants." +version = "0.1.0" + +[tool.setuptools.packages.find] +where = ["src"] + +[project.entry-points."ragna.assistants"] +assistants = "vikings.assistants" diff --git a/plugins/vikings/src/vikings/__init__.py b/plugins/vikings/src/vikings/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/plugins/vikings/src/vikings/assistants/__init__.py b/plugins/vikings/src/vikings/assistants/__init__.py new file mode 100644 index 000000000..295d0301f --- /dev/null +++ b/plugins/vikings/src/vikings/assistants/__init__.py @@ -0,0 +1,20 @@ +__all__ = ["Ivar"] + +from ragna.core import Assistant, Source + + +class IvarTheBoneless(Assistant): + """Ivar the Boneless""" + + @classmethod + def display_name(cls) -> str: + return "Vikings/IvarTheBoneless" + + @property + def max_input_size(self) -> int: + return 873 + + def answer( + self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 + ) -> str: + return "I am Ivar the Boneless! " diff --git a/ragna/_cli/config.py b/ragna/_cli/config.py index f1fddc1a2..68a5276c1 100644 --- a/ragna/_cli/config.py +++ b/ragna/_cli/config.py @@ -11,6 +11,7 @@ from rich.table import Table import ragna +from ragna._compat import importlib_metadata_entry_points from ragna.core import ( Assistant, Config, @@ -237,6 +238,24 @@ def _handle_unmet_requirements(components: Iterable[Type[Component]]) -> None: def _wizard_common() -> Config: config = _wizard_builtin() + if questionary.confirm( + "Do you want to install any ragna assistant plugins?", + default=False, + qmark=QMARK, + ).unsafe_ask(): + plugin_modules = [ + plugin.load() + for plugin in importlib_metadata_entry_points(group="ragna.assistants") + ] + for plugin_module in plugin_modules: + plugin_assistants = _select_components( + "assistants", + plugin_module, + Assistant, # type: ignore[type-abstract] + ) + for assistant in plugin_assistants: + config.core.assistants.append(assistant) + config.local_cache_root = Path( questionary.path( "Where should local files be stored?", diff --git a/ragna/_compat.py b/ragna/_compat.py index 58b1d4865..0a6fa5a53 100644 --- a/ragna/_compat.py +++ b/ragna/_compat.py @@ -1,7 +1,27 @@ import sys -from typing import Callable, Iterable, Iterator, Mapping, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + Iterator, + Mapping, + Protocol, + TypeVar, +) -__all__ = ["itertools_pairwise", "importlib_metadata_package_distributions"] +if TYPE_CHECKING: + if sys.version_info[:2] >= (3, 10): + from importlib.metadata import EntryPoints + else: + from importlib_metadata import EntryPoints + + +__all__ = [ + "itertools_pairwise", + "importlib_metadata_package_distributions", + "importlib_metadata_entry_points", +] T = TypeVar("T") @@ -38,3 +58,20 @@ def _importlib_metadata_package_distributions() -> ( importlib_metadata_package_distributions = _importlib_metadata_package_distributions() + + +class EntryPointsCallable(Protocol): + def __call__(self, **kwargs: Any) -> "EntryPoints": + ... + + +def _importlib_metadata_entry_points() -> EntryPointsCallable: + if sys.version_info[:2] >= (3, 10): + from importlib.metadata import entry_points + else: + from importlib_metadata import entry_points + + return entry_points + + +importlib_metadata_entry_points = _importlib_metadata_entry_points()