diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index c0f0797..0d07d80 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -11,8 +11,7 @@ jobs: build: strategy: matrix: - python-version: - - "3.10" + python-version: ["3.10", "3.11", "3.12"] runs-on: ubuntu-latest steps: - name: Checkout @@ -38,3 +37,5 @@ jobs: run: | . .venv/bin/activate scripts/test.sh + env: + AWS_DEFAULT_REGION: us-east-1 diff --git a/.gitignore b/.gitignore index eca740b..3f4a801 100644 --- a/.gitignore +++ b/.gitignore @@ -162,3 +162,6 @@ cython_debug/ #.idea/ temp/ + +# directory created by VS Code "local history" extension +.history/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f2e0cd9..a738037 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,13 @@ repos: - repo: local hooks: - - id: black_check - name: black check - entry: black + - id: ruff-check + name: Run ruff check + entry: ruff check + args: [--diff] language: python - 'types_or': [python, pyi] - args: [--diff, --check, src/] + types_or: [python, pyi] + pass_filenames: true require_serial: true - id: check-added-large-files name: Check for added large files @@ -27,17 +28,21 @@ repos: entry: end-of-file-fixer language: system types: [text] - stages: [commit, push, manual] + stages: [pre-commit, pre-push, manual] - id: trailing-whitespace name: Trim Trailing Whitespace entry: trailing-whitespace-fixer language: system types: [text] - stages: [commit, push, manual] - - id: pyright - name: pyright - entry: pyright - language: python - 'types_or': [python, pyi] - args: [--verbose, .] - require_serial: true + stages: [pre-commit, pre-push, manual] + # Pre-commit runs pyright-python in its own virtual environment by + # default which means it does not detect installed dependencies. The + # virtual env can be specified in pyrightconfig.json via the "venvPath" + # and "venv" variables. However, this doesn't seem to work with + # conda/mamba environments and is less robust anyway. + # - id: pyright + # name: pyright + # entry: pyright src/ --verbose + # language: python + # 'types_or': [python, pyi] + # require_serial: true diff --git a/README.md b/README.md index 50fb8f4..41a2cf4 100644 --- a/README.md +++ b/README.md @@ -3,11 +3,18 @@ Contains Geospatial AI/ML related code ## Developing -1. Install python and [uv](https://github.com/astral-sh/uv) -2. Checkout the code -3. Run `scripts/recreate_venv.sh` -4. Run `pre-commit install` to install the pre commit changes -5. Make changes -6. Verify linting passes `scripts/lint.sh` -7. Verify tests pass `scripts/test.sh` -8. Commit and push your changes +1. Checkout the code. +1. Create/activate your Python environment of choice. +1. Install uv: `pip install uv`. +1. Install dependencies: `uv pip install -r pyproject.toml`. +1. Install dev dependencies: `uv pip install -r pyproject.toml --extra dev`. +1. Run `pre-commit install` to install pre-commit hooks. +1. Configure your editor for realtime linting: + - For VS Code: + - Set the correct Python environment for the workspace via `ctrl+shift+P` > `Python: Select Interpreter`. + - Install the Pylance and Ruff extensions. +1. Make changes. +1. Verify linting passes `scripts/lint.sh`. +1. Verify tests pass `scripts/test.sh`. +1. Commit and push your changes. + - Note: if using Gitkraken, launch it from the terminal (with `gitkraken`) with the correct python environment activated to ensure that it can use the pre-commit hooks. 
diff --git a/pyproject.toml b/pyproject.toml index 99eade5..99904bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,8 @@ dependencies = [ "boto3-stubs[bedrock-runtime]>=1.35.20", "pydantic>=2.9.1", "shapely>=2.0.6", - "types-shapely>=2.0.0.20240820" + "types-shapely>=2.0.0.20240820", + "function_schema>=0.4.4", ] dynamic = ["version"] @@ -24,23 +25,81 @@ dynamic = ["version"] [tool.pytest.ini_options] pythonpath = "src" -testpaths=[ - "tests" -] +testpaths = ["tests"] [project.urls] Github = "https://github.com/Element84/e84-geoai-common" [project.optional-dependencies] -debugging = [ - "folium>=0.17.0" -] +debugging = ["folium>=0.17.0"] dev = [ "pytest>=8.3.3", "ipykernel>=6.29.5", - "black>=24.8.0", + "ruff>=0.6.8", "pyright>=1.1.381", "build>=1.2.2", "pre-commit>=3.8.0", - "pre-commit-hooks>=4.6.0" + "pre-commit-hooks>=4.6.0", + "moto>=5.0.20", +] + + +[tool.pyright] +include = ["src/"] +ignore = ["**/tests/**", "**/venv/**", "*.pyc"] +typeCheckingMode = "strict" +reportGeneralTypeIssues = true +reportImplicitStringConcatenation = "none" +reportPropertyTypeMismatch = "error" +reportShadowedImports = "error" +reportTypedDictNotRequiredAccess = "none" +reportUninitializedInstanceVariable = "error" +reportUnknownArgumentType = "error" +reportUnknownMemberType = "error" +reportUnknownVariableType = "error" +reportUnnecessaryComparison = "error" +reportIncompatibleVariableOverride = "none" + +[tool.ruff] +line-length = 79 + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.ruff.lint] +# http://docs.astral.sh/ruff/rules/ +select = ['ALL'] +ignore = [ + # Unnecessary assignment before return statement + 'RET504', + # Trailing comma missing + 'COM812', + # Missing docstring for module + 'D100', + # Missing docstring in magic method + 'D105', + # 1 blank line required before class docstring + 'D203', + # Multi-line docstring summary should start at the second line + 'D213', +] + +[tool.ruff.lint.per-file-ignores] +'__init__.py' = [ + # Module level import not at top of cell + 'E402', + # Imported but unused + 'F401', +] +'tests/**/*' = [ + # Use of assert detected + 'S101', + # Missing return type annotation for public function + 'ANN201', + # Missing docstrings + 'D1', + # Private member accessed + 'SLF001', + # magic values + 'PLR2004', ] diff --git a/pyrightconfig.json b/pyrightconfig.json deleted file mode 100644 index 874eb8b..0000000 --- a/pyrightconfig.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "pythonVersion": "3.10", - "venvPath": ".", - "venv": ".venv", - "typeCheckingMode": "strict", - - "reportGeneralTypeIssues": true, - "reportImplicitStringConcatenation": "none", - "reportPropertyTypeMismatch": "error", - "reportShadowedImports": "error", - "reportTypedDictNotRequiredAccess": "none", - "reportUninitializedInstanceVariable": "error", - "reportUnknownArgumentType": "error", - "reportUnknownMemberType": "error", - "reportUnknownVariableType": "error", - "reportUnnecessaryComparison": "error" -} diff --git a/requirements.txt b/requirements.txt index 7ad2596..cd488ce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,21 +2,20 @@ # uv pip compile --refresh --all-extras pyproject.toml -o requirements.txt annotated-types==0.7.0 # via pydantic -appnope==0.1.4 - # via ipykernel -asttokens==2.4.1 +asttokens==3.0.0 # via stack-data -black==24.10.0 - # via e84-geoai-common (pyproject.toml) -boto3==1.35.71 - # via e84-geoai-common (pyproject.toml) -boto3-stubs==1.35.71 +boto3==1.35.75 + # via + # e84-geoai-common (pyproject.toml) + # moto +boto3-stubs==1.35.75 # 
via e84-geoai-common (pyproject.toml) -botocore==1.35.71 +botocore==1.35.75 # via # boto3 + # moto # s3transfer -botocore-stubs==1.35.71 +botocore-stubs==1.35.74 # via boto3-stubs branca==0.8.0 # via folium @@ -24,30 +23,30 @@ build==1.2.2.post1 # via e84-geoai-common (pyproject.toml) certifi==2024.8.30 # via requests +cffi==1.17.1 + # via cryptography cfgv==3.4.0 # via pre-commit charset-normalizer==3.4.0 # via requests -click==8.1.7 - # via black comm==0.2.2 # via ipykernel +cryptography==44.0.0 + # via moto debugpy==1.8.9 # via ipykernel decorator==5.1.1 # via ipython distlib==0.3.9 # via virtualenv -exceptiongroup==1.2.2 - # via - # ipython - # pytest executing==2.1.0 # via stack-data filelock==3.16.1 # via virtualenv folium==0.18.0 # via e84-geoai-common (pyproject.toml) +function-schema==0.4.5 + # via e84-geoai-common (pyproject.toml) identify==2.6.3 # via pre-commit idna==3.10 @@ -64,6 +63,7 @@ jinja2==3.1.4 # via # branca # folium + # moto jmespath==1.0.1 # via # boto3 @@ -75,15 +75,17 @@ jupyter-core==5.7.2 # ipykernel # jupyter-client markupsafe==3.0.2 - # via jinja2 + # via + # jinja2 + # werkzeug matplotlib-inline==0.1.7 # via # ipykernel # ipython -mypy-boto3-bedrock-runtime==1.35.56 +moto==5.0.22 + # via e84-geoai-common (pyproject.toml) +mypy-boto3-bedrock-runtime==1.35.75 # via boto3-stubs -mypy-extensions==1.0.0 - # via black nest-asyncio==1.6.0 # via ipykernel nodeenv==1.9.1 @@ -97,19 +99,15 @@ numpy==2.1.3 # types-shapely packaging==24.2 # via - # black # build # ipykernel # pytest parso==0.8.4 # via jedi -pathspec==0.12.1 - # via black pexpect==4.9.0 # via ipython platformdirs==4.3.6 # via - # black # jupyter-core # virtualenv pluggy==1.5.0 @@ -126,7 +124,9 @@ ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data -pydantic==2.10.2 +pycparser==2.22 + # via cffi +pydantic==2.10.3 # via e84-geoai-common (pyproject.toml) pydantic-core==2.27.1 # via pydantic @@ -134,42 +134,44 @@ pygments==2.18.0 # via ipython pyproject-hooks==1.2.0 # via build -pyright==1.1.389 +pyright==1.1.390 # via e84-geoai-common (pyproject.toml) -pytest==8.3.3 +pytest==8.3.4 # via e84-geoai-common (pyproject.toml) python-dateutil==2.9.0.post0 # via # botocore # jupyter-client + # moto pyyaml==6.0.2 - # via pre-commit + # via + # pre-commit + # responses pyzmq==26.2.0 # via # ipykernel # jupyter-client requests==2.32.3 - # via folium + # via + # folium + # moto + # responses +responses==0.25.3 + # via moto ruamel-yaml==0.18.6 # via pre-commit-hooks ruamel-yaml-clib==0.2.12 # via ruamel-yaml +ruff==0.8.1 + # via e84-geoai-common (pyproject.toml) s3transfer==0.10.4 # via boto3 shapely==2.0.6 # via e84-geoai-common (pyproject.toml) -six==1.16.0 - # via - # asttokens - # python-dateutil +six==1.17.0 + # via python-dateutil stack-data==0.6.3 # via ipython -tomli==2.2.1 - # via - # black - # build - # pre-commit-hooks - # pytest tornado==6.4.2 # via # ipykernel @@ -182,7 +184,7 @@ traitlets==5.14.3 # jupyter-client # jupyter-core # matplotlib-inline -types-awscrt==0.23.1 +types-awscrt==0.23.3 # via botocore-stubs types-s3transfer==0.10.4 # via boto3-stubs @@ -190,10 +192,6 @@ types-shapely==2.0.0.20241112 # via e84-geoai-common (pyproject.toml) typing-extensions==4.12.2 # via - # black - # boto3-stubs - # ipython - # mypy-boto3-bedrock-runtime # pydantic # pydantic-core # pyright @@ -201,9 +199,14 @@ urllib3==2.2.3 # via # botocore # requests + # responses virtualenv==20.28.0 # via pre-commit wcwidth==0.2.13 # via prompt-toolkit +werkzeug==3.1.3 + # via moto +xmltodict==0.14.2 + # via moto 
xyzservices==2024.9.0 # via folium diff --git a/scripts/lint.sh b/scripts/lint.sh index 5d8c230..b937200 100755 --- a/scripts/lint.sh +++ b/scripts/lint.sh @@ -6,8 +6,8 @@ set -e -o pipefail -echo "Running black" -black --diff --check src/ +echo "Running Ruff" +ruff check src/ --diff echo "Running pyright" pyright . diff --git a/scripts/refresh_requirements.sh b/scripts/refresh_requirements.sh index 33a9ad6..c769325 100755 --- a/scripts/refresh_requirements.sh +++ b/scripts/refresh_requirements.sh @@ -6,8 +6,6 @@ set -e -o pipefail -rm -rf .venv - uv pip compile \ --refresh \ --all-extras \ @@ -15,6 +13,3 @@ uv pip compile \ pyproject.toml \ -o requirements.txt -uv venv - -uv pip install -r requirements.txt diff --git a/scripts/refresh_requirements_venv.sh b/scripts/refresh_requirements_venv.sh new file mode 100755 index 0000000..33a9ad6 --- /dev/null +++ b/scripts/refresh_requirements_venv.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +#################################################################################################### +# Pulls downs the latest requirements as defined in the pyproject.toml and requirements.in files. +#################################################################################################### + +set -e -o pipefail + +rm -rf .venv + +uv pip compile \ + --refresh \ + --all-extras \ + --upgrade \ + pyproject.toml \ + -o requirements.txt + +uv venv + +uv pip install -r requirements.txt diff --git a/src/e84_geoai_common/__init__.py b/src/e84_geoai_common/__init__.py index e69de29..2989ffd 100644 --- a/src/e84_geoai_common/__init__.py +++ b/src/e84_geoai_common/__init__.py @@ -0,0 +1,9 @@ +"""Common Geospatial AI/ML code for Element 84 projects.""" + +import logging + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) diff --git a/src/e84_geoai_common/debugging.py b/src/e84_geoai_common/debugging.py index e077fb3..a854522 100644 --- a/src/e84_geoai_common/debugging.py +++ b/src/e84_geoai_common/debugging.py @@ -1,10 +1,14 @@ -import folium # type: ignore +import folium # type: ignore[reportMissingImports] from shapely import GeometryCollection from shapely.geometry.base import BaseGeometry from e84_geoai_common.util import timed_function -SEARCH_AREA_STYLE = {"fillColor": "transparent", "color": "#FF0000", "weight": 3} +SEARCH_AREA_STYLE = { + "fillColor": "transparent", + "color": "#FF0000", + "weight": 3, +} @timed_function @@ -14,24 +18,32 @@ def display_geometry( selected_geometry: BaseGeometry | None = None, search_area: BaseGeometry | None = None, ) -> folium.Map: - """ - Display the provided geometries on a folium Map with optional highlighting of a selected geometry or search area. + """Display provided geometries on a folium Map. - Parameters: - - geoms: List of shapely BaseGeometry objects to be displayed on the map. - - selected_geometry: Optional BaseGeometry object to highlight on the map. - - search_area: Optional BaseGeometry object to display as a search area on the map. + Displays the provided geometries on a folium Map with optional highlighting + of a selected geometry or search area. - Returns: - Folium Map object displaying the provided geometries with optional highlighting. + If a selected_geometry is provided, it will be highlighted on the map. If a + search_area is provided, it will also be displayed on the map. If neither + is provided, the center of the bounding box of all geometries will be + calculated and used as the center of the map. 
- If a selected_geometry is provided, it will be highlighted on the map. If a search_area is provided, - it will also be displayed on the map. If neither is provided, the center of the bounding box of all - geometries will be calculated and used as the center of the map. + The map will be fitted to display all provided geometries with the + calculated center point. - The map will be fitted to display all provided geometries with the calculated center point. + Note: The SEARCH_AREA_STYLE dictionary defines the style for the search + area displayed on the map. - Note: The SEARCH_AREA_STYLE dictionary defines the style for the search area displayed on the map. + Args: + geoms: List of shapely BaseGeometry objects to be displayed on the map. + selected_geometry: Optional BaseGeometry object to highlight on the + map. + search_area: Optional BaseGeometry object to display as a search area + on the map. + + Returns: + Folium Map object displaying the provided geometries with optional + highlighting. """ if selected_geometry: min_lon, min_lat, max_lon, max_lat = selected_geometry.bounds @@ -57,13 +69,13 @@ def display_geometry( for geom in geoms: g = folium.GeoJson(geom.__geo_interface__) - g.add_to(m) # type: ignore + g.add_to(m) if search_area: g = folium.GeoJson( search_area.__geo_interface__, - style_function=lambda _x: SEARCH_AREA_STYLE, # type: ignore + style_function=lambda _x: SEARCH_AREA_STYLE, # type: ignore[reportUnknownLambdaType] ) - g.add_to(m) # type: ignore + g.add_to(m) return m diff --git a/src/e84_geoai_common/geojson.py b/src/e84_geoai_common/geojson.py index 3770a4c..f30d590 100644 --- a/src/e84_geoai_common/geojson.py +++ b/src/e84_geoai_common/geojson.py @@ -1,32 +1,36 @@ -"""Pydantic models for GeoJSON features""" +"""Pydantic models for GeoJSON features.""" from typing import Annotated, Any, Generic, Literal, TypeVar, cast + from pydantic import ( BaseModel, ConfigDict, SkipValidation, - field_validator, field_serializer, + field_validator, ) from shapely.geometry.base import BaseGeometry from e84_geoai_common.geometry import geometry_from_geojson_dict - T = TypeVar("T") class Feature(BaseModel, Generic[T]): - """ - Represents a feature object as defined in the GeoJSON format. + """Represents a feature object as defined in the GeoJSON format. Attributes: - type: Literal["Feature"] - specifies the type of the GeoJSON object as 'Feature'. - geometry: Annotated[BaseGeometry, SkipValidation] - represents the geometry of the feature. - properties: T - generic type representing the properties associated with the feature. + type: Literal["Feature"] - specifies the type of the GeoJSON object as + 'Feature'. + geometry: Annotated[BaseGeometry, SkipValidation] - represents the + geometry of the feature. + properties: T - generic type representing the properties associated + with the feature. 
""" - model_config = ConfigDict(strict=True, frozen=True, arbitrary_types_allowed=True) + model_config = ConfigDict( + strict=True, frozen=True, arbitrary_types_allowed=True + ) type: Literal["Feature"] = "Feature" geometry: Annotated[BaseGeometry, SkipValidation] @@ -34,15 +38,16 @@ class Feature(BaseModel, Generic[T]): @field_validator("geometry", mode="before") @classmethod - def _parse_shapely_geometry(cls, d: Any) -> BaseGeometry: + def _parse_shapely_geometry(cls, d: Any) -> BaseGeometry: # noqa: ANN401 if isinstance(d, dict): return geometry_from_geojson_dict(cast(dict[str, Any], d)) - elif isinstance(d, BaseGeometry): + if isinstance(d, BaseGeometry): return d - else: - raise Exception( - "geometry must be a geojson feature dictionary or a shapely geometry" - ) + msg = ( + "geometry must be a geojson feature dictionary or " + "a shapely geometry." + ) + raise TypeError(msg) @field_serializer("geometry") def _shapely_geometry_to_json(self, g: BaseGeometry) -> dict[str, Any]: @@ -50,13 +55,14 @@ def _shapely_geometry_to_json(self, g: BaseGeometry) -> dict[str, Any]: class FeatureCollection(BaseModel, Generic[T]): - """ - Represents a collection of feature objects defined in the GeoJSON format. + """Represents a collection of feature objects in the GeoJSON format. Attributes: model_config: ConfigDict - configuration settings for the model. - type: Literal["FeatureCollection"] - specifies the type of the GeoJSON object as 'FeatureCollection'. - features: list[Feature[T]] - a list of features included in the collection. + type: Literal["FeatureCollection"] - specifies the type of the GeoJSON + object as 'FeatureCollection'. + features: list[Feature[T]] - a list of features included in the + collection. """ # Extra fields are allowed to be open like GeoJSON is. diff --git a/src/e84_geoai_common/geometry.py b/src/e84_geoai_common/geometry.py index 17c51bc..6a75fad 100644 --- a/src/e84_geoai_common/geometry.py +++ b/src/e84_geoai_common/geometry.py @@ -1,11 +1,11 @@ -"""This contains helpers for dealing with geometry.""" +"""Helpers for dealing with geometry.""" import json import math from typing import Any + import shapely import shapely.geometry - from shapely import GeometryCollection from shapely.geometry.base import BaseGeometry @@ -13,15 +13,12 @@ def geometry_from_wkt(wkt: str) -> BaseGeometry: - """ - Convert a Well-Known Text (WKT) string representation of a geometry into a shapely BaseGeometry object. - """ - return shapely.from_wkt(wkt) # type: ignore + """Create shapely geometry from Well-Known Text (WKT) string.""" + return shapely.from_wkt(wkt) # type: ignore[reportUnknownVariableType] def geometry_from_geojson_dict(geom: dict[str, Any]) -> BaseGeometry: - """ - Construct a shapely BaseGeometry object from a dictionary representation of a GeoJSON geometry. + """Create shapely geometry from GeoJSON dict. Example: Sample usage of the function: @@ -38,13 +35,12 @@ def geometry_from_geojson_dict(geom: dict[str, Any]) -> BaseGeometry: def geometry_from_geojson(geojson: str) -> BaseGeometry: - """Construct a shapely BaseGeometry object from a string representation of a GeoJSON geometry.""" + """Create shapely geometry from GeoJSON string.""" return geometry_from_geojson_dict(json.loads(geojson)) def geometry_to_geojson(geom: BaseGeometry) -> str: - """ - Convert a shapely BaseGeometry object into a GeoJSON string representation. + """Convert shapely geometry to GeoJSON string. 
Example: Sample usage of the function: @@ -58,64 +54,73 @@ def geometry_to_geojson(geom: BaseGeometry) -> str: return json.dumps(geom.__geo_interface__) -def geometry_point_count(geom: BaseGeometry) -> int: - """Returns the number of points in both exterior and interior (if any) in a Geometry.""" +def geometry_point_count(geom: BaseGeometry) -> int: # noqa: PLR0911 + """Count number of points in exterior and interior (if any) of geometry.""" if isinstance(geom, shapely.geometry.Point): return 1 - elif isinstance(geom, shapely.geometry.MultiPoint): + if isinstance(geom, shapely.geometry.MultiPoint): return sum(geometry_point_count(g) for g in geom.geoms) - elif isinstance(geom, shapely.geometry.Polygon): + if isinstance(geom, shapely.geometry.Polygon): exterior_count = geometry_point_count(geom.exterior) interior_count = sum( geometry_point_count(interior) for interior in geom.interiors ) return exterior_count + interior_count - elif isinstance(geom, shapely.geometry.MultiPolygon): + if isinstance(geom, shapely.geometry.MultiPolygon): return sum(geometry_point_count(g) for g in geom.geoms) - elif isinstance(geom, shapely.geometry.LinearRing): + if isinstance(geom, shapely.geometry.LinearRing): return len(geom.coords) - 1 - elif isinstance(geom, shapely.geometry.LineString): + if isinstance(geom, shapely.geometry.LineString): return len(geom.coords) - elif isinstance(geom, shapely.geometry.MultiLineString): + if isinstance(geom, shapely.geometry.MultiLineString): return sum(len(line.coords) for line in geom.geoms) - elif isinstance(geom, shapely.geometry.GeometryCollection): - return sum(geometry_point_count(g) for g in geom.geoms) # type: ignore - else: - raise TypeError(f"Unsupported geometry type: {type(geom).__name__}") + if isinstance(geom, shapely.geometry.GeometryCollection): + return sum(geometry_point_count(g) for g in geom.geoms) # type: ignore[reportUnknownVariableType] + msg = f"Unsupported geometry type: {type(geom).__name__}" + raise TypeError(msg) @timed_function -def simplify_geometry(geom: BaseGeometry, max_points: int = 3_000) -> BaseGeometry: - """ - Simplify a shapely geometry object by reducing the number of points in the geometry while - preserving its overall shape. +def simplify_geometry( + geom: BaseGeometry, max_points: int = 3_000 +) -> BaseGeometry: + """Simplify geometry. + + Simplifies a shapely geometry object by reducing the number of points in + the geometry while preserving its overall shape. - Parameters: + Args: geom (BaseGeometry): The shapely geometry object to be simplified. - max_points (int): The maximum number of points allowed in the simplified geometry. + max_points (int): The maximum number of points allowed in the + simplified geometry. + + Raises: + ValueError: If geometry cannot be simplified to under max_points + points. """ num_points = geometry_point_count(geom) if num_points < max_points: return geom - # Repeatedly try different tolerances to simplify it just until it reaches below the maximum - # number of points. + # Repeatedly try different tolerances to simplify it just until it reaches + # below the maximum number of points. 
for power in range(-7, 0): tolerance: float = pow(10, power) simplified = geom.simplify(tolerance) if geometry_point_count(simplified) < max_points: return simplified - raise Exception( - "Unable to simplify the geometry enough to get it under the maximum number of points" + msg = ( + "Unable to simplify the geometry enough to get it under the maximum " + "number of points" ) + raise ValueError(msg) -def BoundingBox( +def BoundingBox( # noqa: N802 *, west: float, south: float, east: float, north: float ) -> BaseGeometry: - """ - Construct a bounding box geometry using the provided coordinates. + """Construct a bounding box geometry using the provided coordinates. - Parameters: + Args: west (float): The western longitude of the bounding box. south (float): The southern latitude of the bounding box. east (float): The eastern longitude of the bounding box. @@ -135,35 +140,39 @@ def BoundingBox( def between(g1: BaseGeometry, g2: BaseGeometry) -> BaseGeometry: - """Returns the geometry between the two geometries""" + """Return the geometry between the two geometries.""" coll = GeometryCollection([g1, g2]) return coll.convex_hull - g1.convex_hull - g2.convex_hull def degrees_to_radians(deg: float) -> float: - """Converts degrees to radians""" + """Convert degrees to radians.""" return deg * (math.pi / 180.0) def add_buffer(g: BaseGeometry, distance_km: float) -> BaseGeometry: - """ - Adds a buffer around the input geometry. + """Add a buffer around the input geometry. + + This function calculates the latitude and longitude distances based on the + input distance in kilometers. The longitude distance is adjusted based on + the average latitude of the input geometry. A buffer of the calculated + average distance between longitude and latitude is added around the input + geometry. - Parameters: - g (BaseGeometry): The input geometry for which the buffer is to be added. + Args: + g (BaseGeometry): The input geometry for which the buffer is to be + added. distance_km (float): The distance in kilometers for the buffer. Returns: - BaseGeometry: A new geometry object with the buffer added around the input geometry. - - Description: - This function calculates the latitude and longitude distances based on the input distance in kilometers. - The longitude distance is adjusted based on the average latitude of the input geometry. - A buffer of the calculated average distance between longitude and latitude is added around the input geometry. + BaseGeometry: A new geometry object with the buffer added around the + input geometry. Note: - - This function assumes a basic circle buffer and may not work correctly near the poles. - - For accurate distance calculations, consider using a more sophisticated approach. + - This function assumes a basic circle buffer and may not work + correctly near the poles. + - For accurate distance calculations, consider using a more + sophisticated approach. Example: Sample usage of the function: @@ -180,8 +189,8 @@ def add_buffer(g: BaseGeometry, distance_km: float) -> BaseGeometry: avg_lat_rad = degrees_to_radians(avg_lat) lon_distance = distance_km / (math.cos(avg_lat_rad) * 111.32) - # Since we're creating a basic circle around the geometry we'll do something dumb here - # and just average the longitude and latitude distance. This will fall apart at the poles - # but works for our current use cases. + # Since we're creating a basic circle around the geometry we'll do + # something dumb here and just average the longitude and latitude distance. 
+ # This will fall apart at the poles but works for our current use cases. return g.buffer((lon_distance + lat_distance) / 2.0) diff --git a/src/e84_geoai_common/llm/README.md b/src/e84_geoai_common/llm/README.md new file mode 100644 index 0000000..b4e191d --- /dev/null +++ b/src/e84_geoai_common/llm/README.md @@ -0,0 +1,100 @@ +# Usage + +## LLM + +```py +from e84_geoai_common.llm.core import LLMInferenceConfig, LLMMessage +from e84_geoai_common.llm.models.claude import BedrockClaudeLLM + + +def get_temperature_celsius() -> float: + """Returns the current temperature in celsius.""" + return 23.5 + + +llm = BedrockClaudeLLM() + +inference_cfg = LLMInferenceConfig(tools=[get_temperature_celsius]) +messages = [ + LLMMessage(role="user", content="What is the current temperature?") +] + +llm.prompt( + messages=messages, + inference_cfg=inference_cfg, + auto_use_tools=True, +) + +# Output: +# ClaudeAssistantMessage(role='assistant', content=[ClaudeTextContent(type='text', text='The current temperature is 23.5 degrees Celsius.')]) +``` + +## Agent + +```py +from pydantic import BaseModel, Field +from rich import print + +from e84_geoai_common.llm.agents.data_extraction_agent import DataExtractionAgent +from e84_geoai_common.llm.models.claude import BedrockClaudeLLM + + +class ExtractedLocation(BaseModel): + """A location extracted from a piece of text.""" + + name: str = Field( + ..., + description="Name or address of the location.", + examples=[ + "New York", + "Germany", + "1435 Walnut St Philadelphia, PA, USA", + ], + ) + role: str = Field( + ..., + description="The role of the location in the text.", + examples=[ + "The city the game was played in.", + "Country whose national team participated in the game.", + ], + ) + + +class ExtractedLocations(BaseModel): + """List of locations extracted from a piece of text. Can be empty.""" + + locations: list[ExtractedLocation] = [] + + +llm = BedrockClaudeLLM() +extraction_agent = DataExtractionAgent(llm, ExtractedLocations) + +# source: https://en.wikinews.org/wiki/New_Zealand_defeats_South_Africa_to_win_2024_women%27s_T20_cricket_world_cup +text = """\ +On October 20, New Zealand won the 2024 ICC Women's T20 World Cup, defeating South Africa by 32 runs in the tournament's final at the Dubai International Cricket Stadium. This marked the ninth edition of the tournament and New Zealand's first Women's T20 World Cup title. + +South Africa won the toss and chose to field first. New Zealand, batting first, posted a total of 158 runs for the loss of five wickets. Amelia Kerr was the top scorer with 43 runs, while Brooke Halliday contributed 38 runs. Nonkululeko Mlaba was the most successful bowler for South Africa, taking two wickets. + +In their chase, South Africa struggled to keep up with the required run rate, managing 126 runs for the loss of nine wickets. Laura Wolvaardt, the South African captain, scored 33 runs from 27 balls, but the team could not reach the target. For New Zealand, both Amelia Kerr and Rosemary Mair took three wickets each. + +Amelia Kerr was named both Player of the Match and Player of the Tournament for her all-round performance with both bat and ball. + +After the match, South African captain Laura Wolvaardt commented, "We had a really good semi-final, the focus was to reset, but we didn't nail our cricket today." 
New Zealand's Amelia Kerr expressed her thoughts, saying, "I'm a little bit speechless and I'm just so stoked to get the win after all the team has been through."\ +""" # noqa: E501 + +data = extraction_agent.run(text) +print(data) + +# Output: +# ExtractedLocations( +# locations=[ +# ExtractedLocation(name='New Zealand', role='The country whose national team won the tournament.'), +# ExtractedLocation( +# name='South Africa', +# role='The country whose national team participated in the final and lost.' +# ), +# ExtractedLocation(name='Dubai International Cricket Stadium', role='The venue where the final was played.') +# ] +# ) +``` diff --git a/src/e84_geoai_common/llm/__init__.py b/src/e84_geoai_common/llm/__init__.py index 36796ab..ff01e0e 100644 --- a/src/e84_geoai_common/llm/__init__.py +++ b/src/e84_geoai_common/llm/__init__.py @@ -1,12 +1 @@ -from e84_geoai_common.llm.core import LLM as _LLM -from e84_geoai_common.llm.bedrock import BedrockClaudeLLM as _BedrockClaudeLLM -from e84_geoai_common.llm.extraction import ( - extract_data_from_text as _extract_data_from_text, - ExtractDataExample as _ExtractDataExample, -) - -LLM = _LLM - -BedrockClaudeLLM = _BedrockClaudeLLM -ExtractDataExample = _ExtractDataExample -extract_data_from_text = _extract_data_from_text +"""LLM-related modules.""" diff --git a/src/e84_geoai_common/llm/agents/__init__.py b/src/e84_geoai_common/llm/agents/__init__.py new file mode 100644 index 0000000..c8caf71 --- /dev/null +++ b/src/e84_geoai_common/llm/agents/__init__.py @@ -0,0 +1,9 @@ +"""LLM Agents.""" + +from e84_geoai_common.llm.agents.data_extraction_agent import ( + DataExtractionAgent, +) + +__all__ = [ + "DataExtractionAgent", +] diff --git a/src/e84_geoai_common/llm/agents/data_extraction_agent.py b/src/e84_geoai_common/llm/agents/data_extraction_agent.py new file mode 100644 index 0000000..48baaa8 --- /dev/null +++ b/src/e84_geoai_common/llm/agents/data_extraction_agent.py @@ -0,0 +1,88 @@ +import json +from string import Template +from typing import TYPE_CHECKING, Generic, TypeVar + +from pydantic import BaseModel + +from e84_geoai_common.llm.core import ( + LLM, + Agent, + LLMInferenceConfig, + LLMMessage, +) + +if TYPE_CHECKING: + from e84_geoai_common.llm.core import LLM + +DATA_EXTRACTION_PROMPT_TEMPLATE = Template("""\ +Extract relevant information from the text provided, following the schema \ +given below. + +Schema: +${schema} + +Text: +${text} +""") + +ModelT = TypeVar("ModelT", bound=BaseModel) + + +class DataExtractionAgent(Agent, Generic[ModelT]): + """Extracts structured information from text following a schema.""" + + def __init__( + self, + llm: "LLM", + data_model: type[ModelT], + inference_cfg: LLMInferenceConfig | None = None, + ) -> None: + """Construct. + + Args: + llm (LLM): An LLM instance. + data_model (type[ModelT]): A Pydantic model. + inference_cfg (LLMInferenceConfig | None): Inference config. + Defaults to None. 
+ """ + self.llm = llm + self.tools = [] + self.data_model = data_model + if inference_cfg is None: + inference_cfg = LLMInferenceConfig( + system_prompt="You are an LLM specializing in extracting " + "structured information from text.", + temperature=0, + json_mode=True, + ) + self.inference_cfg = inference_cfg + model_json_schema = json.dumps( + data_model.model_json_schema(), indent=2 + ) + # partially fill out template + self._prompt_template: Template = Template( + DATA_EXTRACTION_PROMPT_TEMPLATE.safe_substitute( + schema=model_json_schema + ) + ) + + def run(self, text: str) -> ModelT: + """Extract information from the given text.""" + prompt = self._prompt_template.safe_substitute(text=text) + response = self.llm.prompt( + messages=[LLMMessage(role="user", content=prompt)], + inference_cfg=self.inference_cfg, + auto_use_tools=True, + ) + response = response[-1] + if isinstance(response.content, str): + output_json = response.content + else: + output_json = str(response.content[-1]) + out = self.data_model.model_validate_json(output_json) + return out + + @property + def prompt_template(self) -> str: + """The prompt template used by the agent.""" + return self._prompt_template.template diff --git a/src/e84_geoai_common/llm/bedrock.py b/src/e84_geoai_common/llm/bedrock.py deleted file mode 100644 index 932dff4..0000000 --- a/src/e84_geoai_common/llm/bedrock.py +++ /dev/null @@ -1,61 +0,0 @@ -import json - -import boto3 -import botocore.exceptions -from mypy_boto3_bedrock_runtime import BedrockRuntimeClient - -from e84_geoai_common.llm.core import LLM, InvokeLLMRequest -from e84_geoai_common.util import timed_function - - -class BedrockClaudeLLM(LLM): - """Implements the LLM class for Bedrock Claude.""" - - client: BedrockRuntimeClient - - def __init__( - self, - model_id: str = "anthropic.claude-3-5-sonnet-20240620-v1:0", - client: BedrockRuntimeClient | None = None, - ) -> None: - self.model_id = model_id - self.client = client or boto3.client("bedrock-runtime") # type: ignore - - def _llm_request_to_body(self, request: InvokeLLMRequest) -> str: - messages = [ - {"role": msg.role, "content": msg.content} for msg in request.messages - ] - if request.json_mode: - # Force Claude into JSON mode - messages.append({"role": "assistant", "content": "{"}) - - return json.dumps( - { - "anthropic_version": "bedrock-2023-05-31", - "max_tokens": request.max_tokens, - "system": request.system, - "messages": messages, - "temperature": request.temperature, - "top_p": request.top_p, - "top_k": request.top_k, - } - ) - - @timed_function - def invoke_model_with_request(self, request: InvokeLLMRequest) -> str: - if len(request.messages) == 0: - raise Exception("Must specify at least one message") - req_body = self._llm_request_to_body(request) - try: - resp = self.client.invoke_model(modelId=self.model_id, body=req_body) - except botocore.exceptions.ClientError as vex: - print("Failed with", vex) - print("Request body:", req_body) - raise vex - body = str(resp["body"].read(), "UTF-8") - parsed = json.loads(body) - llm_response = parsed["content"][0]["text"] - if request.json_mode: - return "{" + llm_response - else: - return llm_response diff --git a/src/e84_geoai_common/llm/core.py b/src/e84_geoai_common/llm/core.py deleted file mode 100644 index 8966a26..0000000 --- a/src/e84_geoai_common/llm/core.py +++ /dev/null @@ -1,86 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Literal - -from pydantic import BaseModel, ConfigDict, Field - - -class LLMMessage(BaseModel): - """ - Represents a 
single message sent to or received from an LLM. - - "assistant" refers to the LLM. - """ - - model_config = ConfigDict(strict=True, extra="forbid", frozen=True) - - role: Literal["user", "assistant"] = "user" - - # FUTURE: This could be changed to allow for multiple items following the anthropic content style - content: str - - -class InvokeLLMRequest(BaseModel): - """Represents a request to invoke an LLM and get a response back.""" - - model_config = ConfigDict(strict=True, extra="forbid") - - system: str | None = Field(default=None, description="System Prompt") - max_tokens: int = Field(default=1000, description="Maximum number of output tokens") - temperature: float = Field( - default=0, description="Temperature control for randomness" - ) - top_p: float = Field(default=0, description="Top P for nucleus sampling") - top_k: int = Field(default=0, description="Top K for sampling") - json_mode: bool = Field(default=False, description="Turn on/off json mode") - messages: list[LLMMessage] = Field( - default_factory=list, description="List of LLM Messages" - ) - - -class LLM(ABC): - """ - An abstract base class for interacting with an LLM. - """ - - def invoke_model( - self, - *, - user_prompt: str, - system: str | None = None, - max_tokens: int = 1000, - temperature: float = 0, - top_p: float = 0, - top_k: int = 0, - json_mode: bool = False, - ) -> str: - """ - This function prepares a request to invoke an LLM and receives the response back. - - Parameters: - user_prompt (str): The user's prompt to the LLM. - system (str): An optional system prompt. - max_tokens (int): Defines the maximum number of output tokens. Default is 1000 tokens. - temperature (float): Controls randomness in the model's output. A value of 0 means deterministic. Default is 0. - top_p (float): Defines the cumulative probability below which the possible next tokens are discarded. Default is 0. - top_k (int): Chooses the next token from the top K probable tokens. Default is 0 which implies no limit. - json_mode (bool): A flag to specify JSON mode. If True, the model outputs in JSON mode. Default is False. - - Returns: - str: The LLM's response as a string. - """ - return self.invoke_model_with_request( - InvokeLLMRequest( - system=system, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - top_k=top_k, - json_mode=json_mode, - messages=[LLMMessage(content=user_prompt)], - ) - ) - - @abstractmethod - def invoke_model_with_request(self, request: InvokeLLMRequest) -> str: - """Invokes the model with the given request""" - ... 
diff --git a/src/e84_geoai_common/llm/core/__init__.py b/src/e84_geoai_common/llm/core/__init__.py new file mode 100644 index 0000000..a324d0a --- /dev/null +++ b/src/e84_geoai_common/llm/core/__init__.py @@ -0,0 +1,11 @@ +"""Core LLM classes.""" + +from e84_geoai_common.llm.core.agent import Agent +from e84_geoai_common.llm.core.llm import LLM, LLMInferenceConfig, LLMMessage + +__all__ = [ + "LLM", + "Agent", + "LLMInferenceConfig", + "LLMMessage", +] diff --git a/src/e84_geoai_common/llm/core/agent.py b/src/e84_geoai_common/llm/core/agent.py new file mode 100644 index 0000000..57c025f --- /dev/null +++ b/src/e84_geoai_common/llm/core/agent.py @@ -0,0 +1,18 @@ +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any + +from e84_geoai_common.llm.core.llm import LLM, LLMInferenceConfig + + +class Agent(ABC): + """A language model with instructions and tools.""" + + llm: LLM + inference_cfg: LLMInferenceConfig + tools: list[Callable[..., Any]] + + @property + @abstractmethod + def prompt_template(self) -> str: + """The prompt template used by the agent.""" diff --git a/src/e84_geoai_common/llm/core/llm.py b/src/e84_geoai_common/llm/core/llm.py new file mode 100644 index 0000000..835e87e --- /dev/null +++ b/src/e84_geoai_common/llm/core/llm.py @@ -0,0 +1,71 @@ +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence +from typing import Any + +from pydantic import BaseModel, Field + + +class LLMMessage(BaseModel): + """Standard representation of an LLM message. + + Specific LLM implementations should implement logic to translate to and + from this representation. + """ + + role: str + content: str | Sequence[Any] + + +class LLMInferenceConfig(BaseModel): + """Common inference options for LLMs. + + Specific LLM implementations should implement logic to translate these + parameters to their respective APIs. + """ + + system_prompt: str | None = Field( + default=None, description="System Prompt" + ) + tools: list[Callable[..., Any]] | None = Field( + default=None, description="List of tools that the model may call." + ) + tool_choice: str | None = Field( + default=None, + description="Whether the model should use a specific " + "tool, or any tool, or decide by itself.", + ) + max_tokens: int = Field( + default=1000, description="Maximum number of output tokens" + ) + temperature: float = Field( + default=0, + description="Temperature control for randomness. " + "Closer to zero = more deterministic.", + ) + top_p: float | None = Field( + default=None, description="Top P for nucleus sampling." 
+ ) + top_k: int | None = Field(default=None, description="Top K for sampling") + json_mode: bool = Field( + default=False, + description="If True, forces model to only output valid JSON.", + ) + response_prefix: str | None = Field( + default=None, + description="Continue a pre-filled response instead of " + "starting from scratch.", + ) + + +class LLM(ABC): + """An abstract base class for interacting with an LLM.""" + + @abstractmethod + def prompt( + self, + messages: Sequence[LLMMessage], + inference_cfg: LLMInferenceConfig, + *, + auto_use_tools: bool = False, + ) -> Sequence[LLMMessage]: + """Prompt the LLM with a message and optional conversation history.""" diff --git a/src/e84_geoai_common/llm/extraction.py b/src/e84_geoai_common/llm/extraction.py index e29d0b1..f6457a9 100644 --- a/src/e84_geoai_common/llm/extraction.py +++ b/src/e84_geoai_common/llm/extraction.py @@ -1,22 +1,22 @@ +import logging from typing import Generic, TypeVar from pydantic import BaseModel, ConfigDict, ValidationError -from e84_geoai_common.llm.core import LLM, InvokeLLMRequest, LLMMessage - +from e84_geoai_common.llm.core import LLM, LLMInferenceConfig, LLMMessage Model = TypeVar("Model", bound=BaseModel) +log = logging.getLogger(__name__) + class ExtractDataExample(BaseModel, Generic[Model]): - """ - Represents an example data extraction scenario that can be used for building system prompts for - data extraction. + """Example data extraction scenario. Attributes: - - name (str): Name of the example. - - user_query (str): User's query for data extraction. - - structure (Model): Data structure to extract. + name (str): Name of the example. + user_query (str): User's query for data extraction. + structure (Model): Data structure to extract. """ model_config = ConfigDict(strict=True, extra="forbid", frozen=True) @@ -26,14 +26,18 @@ class ExtractDataExample(BaseModel, Generic[Model]): structure: Model def to_str(self) -> str: - """ - Returns a formatted string representation of the example data extraction scenario. + """Return formatted string representation of the example scenario. Returns: - str: Formatted string with example name, user query, and data structure in JSON format. + str: Formatted string with example name, user query, and data + structure in JSON format. """ - query_json = f"```json\n{self.structure.model_dump_json(indent=2, exclude_none=True)}\n```" - return f'Example: {self.name}\nUser Query: "{self.user_query}"\n\n{query_json}' + json_str = self.structure.model_dump_json(indent=2, exclude_none=True) + query_json = f"```json\n{json_str}\n```" + return ( + f'Example: {self.name}\nUser Query: "{self.user_query}"' + f"\n\n{query_json}" + ) def extract_data_from_text( @@ -43,24 +47,34 @@ def extract_data_from_text( system_prompt: str, user_prompt: str, ) -> Model: - """ - Extracts data from text using a Language Model (LLM) by providing system and user prompts. + """Extract data from text using an LLM given system and user prompts. Args: - - llm (LLM): The Language Model instance used for data extraction. - - model_type (Type[Model]): The type of data model to be used for validation. - - system_prompt (str): The prompt for the system to process the user input. - - user_prompt (str): The user input text for data extraction. + llm (LLM): The Language Model instance used for data extraction. + model_type (Type[Model]): The type of data model to be used for + validation. + system_prompt (str): The prompt for the system to process the user + input. 
+ user_prompt (str): The user input text for data extraction. Returns: - Model: The extracted data model validated against the specified model type. + Model: The extracted data model validated against the specified model + type. """ - request = InvokeLLMRequest( - system=system_prompt, json_mode=True, messages=[LLMMessage(content=user_prompt)] + inference_cfg = LLMInferenceConfig( + system_prompt=system_prompt, + json_mode=True, ) - resp = llm.invoke_model_with_request(request) + messages = [LLMMessage(role="user", content=user_prompt)] + resp = llm.prompt(messages=messages, inference_cfg=inference_cfg) + resp = resp[-1] + if isinstance(resp.content, str): + output_json = resp.content + else: + output_json = str(resp.content[-1]) try: - return model_type.model_validate_json(resp) - except ValidationError as e: - print("Unable to parse response:", resp) - raise e + out = model_type.model_validate_json(output_json) + except ValidationError: + log.exception("Unable to parse LLM response: %s", resp) + raise + return out diff --git a/src/e84_geoai_common/llm/models/__init__.py b/src/e84_geoai_common/llm/models/__init__.py new file mode 100644 index 0000000..fcf2727 --- /dev/null +++ b/src/e84_geoai_common/llm/models/__init__.py @@ -0,0 +1,13 @@ +"""Wrappers for LLM APIs.""" + +from e84_geoai_common.llm.models.claude import ( + CLAUDE_BEDROCK_MODEL_IDS, + BedrockClaudeLLM, + ClaudeInvokeLLMRequest, +) + +__all__ = [ + "CLAUDE_BEDROCK_MODEL_IDS", + "BedrockClaudeLLM", + "ClaudeInvokeLLMRequest", +] diff --git a/src/e84_geoai_common/llm/models/claude.py b/src/e84_geoai_common/llm/models/claude.py new file mode 100644 index 0000000..f5aee3b --- /dev/null +++ b/src/e84_geoai_common/llm/models/claude.py @@ -0,0 +1,357 @@ +import logging +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any, Literal + +import boto3 +import botocore.exceptions +from function_schema.core import ( # type: ignore[reportMissingTypeStubs] + get_function_schema, # type: ignore[reportUnknownVariableType] +) +from mypy_boto3_bedrock_runtime import BedrockRuntimeClient +from pydantic import BaseModel, ConfigDict, Field + +from e84_geoai_common.llm.core.llm import LLM, LLMInferenceConfig, LLMMessage +from e84_geoai_common.util import timed_function + +if TYPE_CHECKING: + from typing import Self + +log = logging.getLogger(__name__) + +# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +ANTHROPIC_API_VERSION = "bedrock-2023-05-31" +# https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns +CLAUDE_BEDROCK_MODEL_IDS = { + "Claude 3 Haiku": "anthropic.claude-3-haiku-20240307-v1:0", + "Claude 3.5 Sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "Claude 3 Sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", + "Claude 3 Opus": "anthropic.claude-3-opus-20240229-v1:0", + "Claude Instant": "anthropic.claude-instant-v1", + "Claude 3.5 Haiku": "anthropic.claude-3-5-haiku-20241022-v1:0", + "Claude 3.5 Sonnet v2": "anthropic.claude-3-5-sonnet-20241022-v2:0", +} + + +class ClaudeTextContent(BaseModel): + """Claude text context model.""" + + type: Literal["text"] = "text" + text: str + + def __str__(self) -> str: + return self.text + + +class ClaudeToolUseContent(BaseModel): + """Claude tool-use request model.""" + + type: Literal["tool_use"] = "tool_use" + id: str + name: str + input: dict[str, Any] + + +class ClaudeToolResultContent(BaseModel): + """Claude tool result model.""" + + type: Literal["tool_result"] = 
"tool_result" + tool_use_id: str + content: str + + +class ClaudeMessage(LLMMessage): + """Claude message base model.""" + + role: Literal["assistant", "user"] + content: ( + str + | Sequence[ + ClaudeTextContent | ClaudeToolUseContent | ClaudeToolResultContent + ] + ) + + @classmethod + def from_llm_message(cls, message: LLMMessage) -> "Self": + """Construct from an LLMMessage.""" + return cls.model_validate(message.model_dump()) + + +class ClaudeUserMessage(ClaudeMessage): + """Claude user message model.""" + + role: Literal["user"] = "user" + content: str | Sequence[ClaudeTextContent | ClaudeToolResultContent] + + +class ClaudeAssistantMessage(ClaudeMessage): + """Claude assistant message model.""" + + role: Literal["assistant"] = "assistant" + content: str | Sequence[ClaudeTextContent | ClaudeToolUseContent] + + +class ClaudeUsageInfo(BaseModel): + """Claude usage-info model.""" + + input_tokens: int + output_tokens: int + + +class ClaudeResponse(BaseModel): + """Claude response model.""" + + id: str + type: Literal["message"] = "message" + role: Literal["assistant"] = "assistant" + content: Sequence[ClaudeTextContent | ClaudeToolUseContent] + model: str + stop_reason: Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"] + stop_sequence: str | None + usage: ClaudeUsageInfo + + def to_message(self) -> ClaudeMessage: + """Convert to a ClaudeAssistantMessage.""" + return ClaudeAssistantMessage(role=self.role, content=self.content) + + +class ClaudeTool(BaseModel): + """Representation of a tool that Claude can use.""" + + name: str + description: str + input_schema: dict[str, Any] + _func: Callable[..., Any] + + @classmethod + def from_function(cls, func: Callable[..., Any]) -> "Self": + """Construct from a Python funtion.""" + schema = get_function_schema(func, format="claude") # type: ignore[reportUnknownVariableType] + out = cls.model_validate(schema) + out._func = func # noqa: SLF001 + return out + + def use(self, tool_use: ClaudeToolUseContent) -> ClaudeUserMessage: + """Use tool and return the result as a ClaudeUserMessage.""" + func_out = self._func(**tool_use.input) + result = ClaudeToolResultContent( + tool_use_id=tool_use.id, content=str(func_out) + ) + msg = ClaudeUserMessage(content=[result]) + return msg + + +class ClaudeToolChoice(BaseModel): + """Claude tool choice model.""" + + type: Literal["auto", "any", "tool"] + name: str | None = None + # disable_parallel_tool_use is documented in Anthropic docs but seems to + # not be supported in Bedrock + # disable_parallel_tool_use: bool | None = None # noqa: ERA001 + + +class ClaudeInvokeLLMRequest(BaseModel): + """Represents a request to invoke Claude and get a response back.""" + + model_config = ConfigDict(strict=True, extra="forbid") + + anthropic_version: str = ANTHROPIC_API_VERSION + messages: list[ClaudeMessage] = Field( + default_factory=list, description="List of LLM Messages" + ) + system: str | None = Field(default=None, description="System Prompt") + tools: list[ClaudeTool] | None = Field( + default=None, description="List of tools that the model may call." + ) + tool_choice: ClaudeToolChoice | None = Field( + default=None, + description="Whether the model should use a specific " + "tool, or any tool, or decide by itself.", + ) + max_tokens: int = Field( + default=1000, description="Maximum number of output tokens" + ) + temperature: float = Field( + default=0, + description="Temperature control for randomness. 
" + "Closer to zero = more deterministic.", + ) + top_p: float | None = Field( + default=None, description="Top P for nucleus sampling" + ) + top_k: int | None = Field(default=None, description="Top K for sampling") + response_prefix: str | None = Field( + default=None, + description="Make Claude continue a pre-filled response instead of " + 'starting from sratch. Can be set to "{" to force "JSON mode".', + ) + + @classmethod + def from_inference_config( + cls, + cfg: LLMInferenceConfig, + messages: Sequence[ClaudeMessage] | None = None, + ) -> "Self": + """Construct from an LLMInferenceConfig.""" + messages = [] if messages is None else list(messages) + response_prefix = cfg.response_prefix + if cfg.json_mode: + if response_prefix is not None: + msg = "response_prefix not supported with json_mode=True." + raise ValueError(msg) + response_prefix = "{" + + tools = None + tool_choice = None + if cfg.tools is not None: + tools = [ClaudeTool.from_function(f) for f in cfg.tools] + if cfg.tool_choice is None: + tool_choice = ClaudeToolChoice(type="auto") + elif cfg.tool_choice in ("auto", "any"): + tool_choice = ClaudeToolChoice(type=cfg.tool_choice) + else: + tool_choice = ClaudeToolChoice( + type="tool", name=cfg.tool_choice + ) + log.info(tool_choice) + req = cls( + messages=messages, + system=cfg.system_prompt, + tools=tools, + tool_choice=tool_choice, + max_tokens=cfg.max_tokens, + temperature=cfg.temperature, + top_k=cfg.top_k, + top_p=cfg.top_p, + response_prefix=response_prefix, + ) + return req + + def to_request_body(self) -> str: + """Convert to JSON request body.""" + if len(self.messages) == 0: + msg = "Must specify at least one message." + raise ValueError(msg) + if self.response_prefix is not None: + prefilled_response = ClaudeAssistantMessage( + content=self.response_prefix + ) + self.messages.append(prefilled_response) + body = self.model_dump_json( + exclude_none=True, exclude={"response_prefix"} + ) + return body + + +class BedrockClaudeLLM(LLM): + """Implements the LLM class for Bedrock Claude.""" + + client: BedrockRuntimeClient + + def __init__( + self, + model_id: str = CLAUDE_BEDROCK_MODEL_IDS["Claude 3 Haiku"], + client: BedrockRuntimeClient | None = None, + ) -> None: + """Initialize. + + Args: + model_id: Model ID. Defaults to the model ID for Claude 3 Haiku. + client: Optional pre-initialized boto3 client. Defaults to None. + """ + self.model_id = model_id + self.client = client or boto3.client("bedrock-runtime") # type: ignore[reportUnknownMemberType] + + @timed_function + def prompt( + self, + messages: Sequence[LLMMessage], + inference_cfg: LLMInferenceConfig, + *, + auto_use_tools: bool = False, + ) -> list[ClaudeMessage]: + """Prompt the LLM with a message and optional conversation history.""" + if len(messages) == 0: + msg = "Must specify at least one message." 
+ raise ValueError(msg) + messages = [ClaudeMessage.from_llm_message(m) for m in messages] + request = ClaudeInvokeLLMRequest.from_inference_config( + inference_cfg, messages + ) + response = self.invoke_model_with_request(request) + if response.stop_reason == "tool_use" and auto_use_tools: + assert request.tools is not None # noqa: S101 + log.info("Tool-use requested:") + log.info(response.content) + tool_result_msgs = self.use_tools(response.content, request.tools) + log.info("Tool-use results:") + log.info(tool_result_msgs) + new_messages = [ + *messages, + response.to_message(), + *tool_result_msgs, + ] + return self.prompt( + new_messages, + inference_cfg, + ) + return [*messages, response.to_message()] + + @timed_function + def invoke_model_with_request( + self, request: ClaudeInvokeLLMRequest + ) -> ClaudeResponse: + """Invoke model with request and get a response back.""" + response_body = self._make_client_request(request) + claude_response = self._parse_response(response_body, request) + return claude_response + + def use_tools( + self, + content: Sequence[ClaudeTextContent | ClaudeToolUseContent], + tools: list[ClaudeTool], + ) -> list[ClaudeUserMessage]: + """Fulfill all tool-use requests and return response messages.""" + tools_dict = {t.name: t for t in tools} + out_messages: list[ClaudeUserMessage] = [] + for block in content: + if not isinstance(block, ClaudeToolUseContent): + continue + tool = tools_dict[block.name] + out_messages.append(tool.use(block)) + return out_messages + + def _parse_response( + self, response_body: str, request: ClaudeInvokeLLMRequest + ) -> ClaudeResponse: + """Parse raw JSON response into a ClaudeResponse.""" + response = ClaudeResponse.model_validate_json(response_body) + if request.response_prefix is not None: + response = self._add_prefix_to_response( + response, request.response_prefix + ) + return response + + def _make_client_request(self, request: ClaudeInvokeLLMRequest) -> str: + """Make model invocation request and return raw JSON response.""" + request_body = request.to_request_body() + try: + response = self.client.invoke_model( + modelId=self.model_id, body=request_body + ) + except botocore.exceptions.ClientError as e: + log.error("Failed with %s", e) # noqa: TRY400 + log.error("Request body: %s", request_body) # noqa: TRY400 + raise + response_body = response["body"].read().decode("UTF-8") + return response_body + + def _add_prefix_to_response( + self, response: ClaudeResponse, prefix: str + ) -> ClaudeResponse: + """Prepend prefix to the text of the first text-content block.""" + for content_block in response.content: + if isinstance(content_block, ClaudeTextContent): + content_block.text = prefix + content_block.text + break + return response diff --git a/src/e84_geoai_common/util.py b/src/e84_geoai_common/util.py index 82cfb51..a4ad055 100644 --- a/src/e84_geoai_common/util.py +++ b/src/e84_geoai_common/util.py @@ -1,97 +1,94 @@ -import time +import logging import os import textwrap -from typing import Any, Callable, TypeVar +from collections.abc import Callable +from time import perf_counter +from typing import Any, TypeVar + +log = logging.getLogger(__name__) + + +T = TypeVar("T", bound=Callable[..., Any]) def get_env_var(name: str, default: str | None = None) -> str: - """ - Retrieves the value of an environment variable. 
- """ + """Retrieve the value of an environment variable.""" value = os.getenv(name) or default - if value is None: - raise Exception(f"Env var {name} must be set") + msg = f"Env var {name} must be set" + raise ValueError(msg) return value def dedent(text: str) -> str: - """ - Remove common leading whitespace from every line in a multi-line string. + """Remove common leading whitespace from every line in a multi-line string. - Parameters: - text (str): The multi-line string with potentially uneven indentation. + Args: + text (str): The multi-line string with potentially uneven indentation. Returns: - str: The modified string with common leading whitespace removed from every line. - - Raises: - None + str: The modified string with common leading whitespace removed from + every line. Example: - text = ''' - Lorem ipsum dolor sit amet, - consectetur adipiscing elit, - sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. - ''' - result = dedent(text) - print(result) - # Output: - # 'Lorem ipsum dolor sit amet, - # consectetur adipiscing elit, - # sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.' + text = ''' + Lorem ipsum dolor sit amet, + consectetur adipiscing elit, + sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. + ''' + result = dedent(text) + print(result) + # Output: + # 'Lorem ipsum dolor sit amet, + # consectetur adipiscing elit, + # sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.' """ return textwrap.dedent(text).strip() def singleline(text: str) -> str: - """ - Remove common leading whitespace from every line in a multi-line string and convert it into a single line. + """Remove common leading whitespace from every line in a multi-line string. - Parameters: - text (str): The multi-line string with potentially uneven indentation. + Args: + text (str): The multi-line string with potentially uneven indentation. Returns: - str: The modified string with common leading whitespace removed from every line and converted into a single line. - - Raises: - None + str: The modified string with common leading whitespace removed from + every line and converted into a single line. Example: - text = ''' - Lorem ipsum dolor sit amet, - consectetur adipiscing elit, - sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. - ''' - result = singleline(text) - print(result) - # Output: - # 'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.' - """ + text = ''' + Lorem ipsum dolor sit amet, + consectetur adipiscing elit, + sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. + ''' + result = singleline(text) + print(result) + # Output: + # 'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.' + """ # noqa: E501 return dedent(text).replace("\n", " ") -T = TypeVar("T", bound=Callable[..., Any]) - - def timed_function(func: T) -> T: - """ - A decorator for timing a function call. + """Decorate a function to log execution time. - This decorator will print the execution time of the decorated function after it runs. + This decorator will print the execution time of the decorated function + after it finishes executing. - Parameters: - func (Callable): The function to be timed. + Args: + func (Callable): The function to be timed. Returns: - Callable: The decorated function. + Callable: The decorated function. 
""" - def wrapper(*args: Any, **kwargs: Any) -> Any: - start_time = time.time() # capture the start time before executing - result = func(*args, **kwargs) # execute the function - end_time = time.time() - print(f"{func.__name__} took {end_time - start_time} seconds to run.") + def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> Any: # noqa: ANN401 + start_time = perf_counter() + result = func(*args, **kwargs) + end_time = perf_counter() + diff = end_time - start_time + log.info("%s took %f seconds to run.", func.__name__, diff) return result - return wrapper # type: ignore + return wrapper # type: ignore[reportReturnType] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/llm/__init__.py b/tests/llm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/llm/models/__init__.py b/tests/llm/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/llm/models/test_claude.py b/tests/llm/models/test_claude.py new file mode 100644 index 0000000..bcb58ec --- /dev/null +++ b/tests/llm/models/test_claude.py @@ -0,0 +1,28 @@ +from moto import mock_aws + +from e84_geoai_common.llm.models.claude import ( + BedrockClaudeLLM, + ClaudeResponse, + ClaudeTextContent, + ClaudeUsageInfo, +) + + +@mock_aws +def test_response_prefix(): + prefix = "__prefix__" + llm = BedrockClaudeLLM() + content_in = [ClaudeTextContent(text="abc"), ClaudeTextContent(text="def")] + response_in = ClaudeResponse( + id="", + content=content_in, + model="", + stop_reason="end_turn", + stop_sequence=None, + usage=ClaudeUsageInfo(input_tokens=0, output_tokens=0), + ) + response_out = llm._add_prefix_to_response(response_in, prefix=prefix) + content_out = response_out.content + assert len(content_out) == len(content_in) + assert content_out[0].text.startswith(prefix) + assert content_out[1].text == content_in[1].text diff --git a/tests/test_geometry.py b/tests/test_geometry.py index 8629d97..4328080 100644 --- a/tests/test_geometry.py +++ b/tests/test_geometry.py @@ -1,6 +1,7 @@ +from math import cos, pi, sin + from shapely import GeometryCollection, LineString, Point from shapely.geometry.polygon import Polygon -from math import sin, cos, pi from e84_geoai_common.geometry import geometry_point_count, simplify_geometry @@ -62,7 +63,8 @@ def test_simplify_geometry(): assert simplify_geometry(point) == point polygon = generate_circle(num_points=200) - # Simplifies to the same thing if already less than the set number of points + # Simplifies to the same thing if already less than the set number of + # points assert simplify_geometry(polygon, 300) == polygon assert geometry_point_count(simplify_geometry(polygon, 199)) <= 199