diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 6538ca91..6d78745c 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.8.0" + ".": "0.9.0" } \ No newline at end of file diff --git a/.stats.yml b/.stats.yml index b868f520..ed702969 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 34 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/contextual-ai%2Fsunrise-db7245c74772a8cd47c02886619fed0568fbb58b1fa8aba0dc77524b924a4fb6.yml -openapi_spec_hash: ca3de8d7b14b78683e39464fe7d4b1e1 -config_hash: 410f8a2f86f605885911277be47c3c78 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/contextual-ai%2Fsunrise-c8152db455001be3f09a3bc60d63711699d2c2a4ea5f7bbc1d71726efda0fd9b.yml +openapi_spec_hash: 97719df292ca220de5d35d36f9756b95 +config_hash: ae81af9b7eb88a788a80bcf3480e0b6b diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a96ddb7..10ba2b3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,35 @@ # Changelog +## 0.9.0 (2025-10-28) + +Full Changelog: [v0.8.0...v0.9.0](https://github.com/ContextualAI/contextual-client-python/compare/v0.8.0...v0.9.0) + +### Features + +* **api:** update via SDK Studio ([3ebbcab](https://github.com/ContextualAI/contextual-client-python/commit/3ebbcab780e0391c420126b8cbf11589aba78470)) +* improve future compat with pydantic v3 ([2837532](https://github.com/ContextualAI/contextual-client-python/commit/2837532cb8930994be7d02c356421a1a3e990c78)) +* **types:** replace List[str] with SequenceNotStr in params ([ee66bc5](https://github.com/ContextualAI/contextual-client-python/commit/ee66bc5ce67fefb92b384ea979475ed54a53af9d)) + + +### Bug Fixes + +* avoid newer type syntax ([551b56e](https://github.com/ContextualAI/contextual-client-python/commit/551b56e22af03e8305e599814f484a5ed64b9cb3)) +* **compat:** compat with `pydantic<2.8.0` when using additional fields 
([ceb597f](https://github.com/ContextualAI/contextual-client-python/commit/ceb597f1da87f0e2ae718a6143ec57c60b8f4d3d)) + + +### Chores + +* bump `httpx-aiohttp` version to 0.1.9 ([bdcc7c6](https://github.com/ContextualAI/contextual-client-python/commit/bdcc7c6b9c7784c25b49debbcadca2307c43e5b6)) +* do not install brew dependencies in ./scripts/bootstrap by default ([41397b2](https://github.com/ContextualAI/contextual-client-python/commit/41397b25ce468bf58ed53e8e78e7cea2fcf41a47)) +* **internal:** add Sequence related utils ([f00b892](https://github.com/ContextualAI/contextual-client-python/commit/f00b892536ef8c2d6f67965d7db28a1956adebb5)) +* **internal:** detect missing future annotations with ruff ([6958d77](https://github.com/ContextualAI/contextual-client-python/commit/6958d772079b5f5571e7db3c39255d146b11dd5b)) +* **internal:** improve examples ([c6f06b9](https://github.com/ContextualAI/contextual-client-python/commit/c6f06b9b0859a68bb32fa96294443abd139070e4)) +* **internal:** move mypy configurations to `pyproject.toml` file ([57b4284](https://github.com/ContextualAI/contextual-client-python/commit/57b42849dd5340b2ce21aa8b6b8fb0c7e15529ba)) +* **internal:** update pydantic dependency ([35223af](https://github.com/ContextualAI/contextual-client-python/commit/35223af9a91cc39d4800294b7821480ac2d2b0ee)) +* **internal:** update pyright exclude list ([e89669e](https://github.com/ContextualAI/contextual-client-python/commit/e89669e93ed4a4e74993adfee5756b3502719e8c)) +* **tests:** simplify `get_platform` test ([1f089bd](https://github.com/ContextualAI/contextual-client-python/commit/1f089bdf7319dee3c726d844d11f35a924cfdcc4)) +* **types:** change optional parameter type from NotGiven to Omit ([07ee8a4](https://github.com/ContextualAI/contextual-client-python/commit/07ee8a4cecd02070a7fd44d1daec9687af2fce45)) + ## 0.8.0 (2025-08-26) Full Changelog: [v0.7.0...v0.8.0](https://github.com/ContextualAI/contextual-client-python/compare/v0.7.0...v0.8.0) diff --git a/api.md 
b/api.md index 5ce1f467..225b992d 100644 --- a/api.md +++ b/api.md @@ -43,7 +43,7 @@ Methods: - client.datastores.documents.get_parse_result(document_id, \*, datastore_id, \*\*params) -> DocumentGetParseResultResponse - client.datastores.documents.ingest(datastore_id, \*\*params) -> IngestionResponse - client.datastores.documents.metadata(document_id, \*, datastore_id) -> DocumentMetadata -- client.datastores.documents.set_metadata(document_id, \*, datastore_id, \*\*params) -> DocumentMetadata +- client.datastores.documents.set_metadata(document_id, \*, datastore_id, \*\*params) -> DocumentMetadata # Agents @@ -79,13 +79,18 @@ Methods: Types: ```python -from contextual.types.agents import QueryResponse, RetrievalInfoResponse, QueryMetricsResponse +from contextual.types.agents import ( + QueryResponse, + RetrievalInfoResponse, + QueryFeedbackResponse, + QueryMetricsResponse, +) ``` Methods: - client.agents.query.create(agent_id, \*\*params) -> QueryResponse -- client.agents.query.feedback(agent_id, \*\*params) -> object +- client.agents.query.feedback(agent_id, \*\*params) -> QueryFeedbackResponse - client.agents.query.metrics(agent_id, \*\*params) -> QueryMetricsResponse - client.agents.query.retrieval_info(message_id, \*, agent_id, \*\*params) -> RetrievalInfoResponse diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index c6b994be..00000000 --- a/mypy.ini +++ /dev/null @@ -1,50 +0,0 @@ -[mypy] -pretty = True -show_error_codes = True - -# Exclude _files.py because mypy isn't smart enough to apply -# the correct type narrowing and as this is an internal module -# it's fine to just use Pyright. -# -# We also exclude our `tests` as mypy doesn't always infer -# types correctly and Pyright will still catch any type errors. 
-exclude = ^(src/contextual/_files\.py|_dev/.*\.py|tests/.*)$ - -strict_equality = True -implicit_reexport = True -check_untyped_defs = True -no_implicit_optional = True - -warn_return_any = True -warn_unreachable = True -warn_unused_configs = True - -# Turn these options off as it could cause conflicts -# with the Pyright options. -warn_unused_ignores = False -warn_redundant_casts = False - -disallow_any_generics = True -disallow_untyped_defs = True -disallow_untyped_calls = True -disallow_subclassing_any = True -disallow_incomplete_defs = True -disallow_untyped_decorators = True -cache_fine_grained = True - -# By default, mypy reports an error if you assign a value to the result -# of a function call that doesn't return anything. We do this in our test -# cases: -# ``` -# result = ... -# assert result is None -# ``` -# Changing this codegen to make mypy happy would increase complexity -# and would not be worth it. -disable_error_code = func-returns-value,overload-cannot-match - -# https://github.com/python/mypy/issues/12162 -[mypy.overrides] -module = "black.files.*" -ignore_errors = true -ignore_missing_imports = true diff --git a/pyproject.toml b/pyproject.toml index 1169a3e4..5d15d3e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "contextual-client" -version = "0.8.0" +version = "0.9.0" description = "The official Python library for the Contextual AI API" dynamic = ["readme"] license = "Apache-2.0" @@ -39,7 +39,7 @@ Homepage = "https://github.com/ContextualAI/contextual-client-python" Repository = "https://github.com/ContextualAI/contextual-client-python" [project.optional-dependencies] -aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.8"] +aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.9"] [tool.rye] managed = true @@ -56,7 +56,6 @@ dev-dependencies = [ "dirty-equals>=0.6.0", "importlib-metadata>=6.7.0", "rich>=13.7.1", - "nest_asyncio==1.6.0", "pytest-xdist>=3.6.1", ] @@ -148,6 +147,7 @@ exclude = [ "_dev", ".venv", ".nox", + ".git", 
] reportImplicitOverride = true @@ -156,6 +156,58 @@ reportOverlappingOverload = false reportImportCycles = false reportPrivateUsage = false +[tool.mypy] +pretty = true +show_error_codes = true + +# Exclude _files.py because mypy isn't smart enough to apply +# the correct type narrowing and as this is an internal module +# it's fine to just use Pyright. +# +# We also exclude our `tests` as mypy doesn't always infer +# types correctly and Pyright will still catch any type errors. +exclude = ['src/contextual/_files.py', '_dev/.*.py', 'tests/.*'] + +strict_equality = true +implicit_reexport = true +check_untyped_defs = true +no_implicit_optional = true + +warn_return_any = true +warn_unreachable = true +warn_unused_configs = true + +# Turn these options off as it could cause conflicts +# with the Pyright options. +warn_unused_ignores = false +warn_redundant_casts = false + +disallow_any_generics = true +disallow_untyped_defs = true +disallow_untyped_calls = true +disallow_subclassing_any = true +disallow_incomplete_defs = true +disallow_untyped_decorators = true +cache_fine_grained = true + +# By default, mypy reports an error if you assign a value to the result +# of a function call that doesn't return anything. We do this in our test +# cases: +# ``` +# result = ... +# assert result is None +# ``` +# Changing this codegen to make mypy happy would increase complexity +# and would not be worth it. 
+disable_error_code = "func-returns-value,overload-cannot-match" + +# https://github.com/python/mypy/issues/12162 +[[tool.mypy.overrides]] +module = "black.files.*" +ignore_errors = true +ignore_missing_imports = true + + [tool.ruff] line-length = 120 output-format = "grouped" @@ -172,6 +224,8 @@ select = [ "B", # remove unused imports "F401", + # check for missing future annotations + "FA102", # bare except statements "E722", # unused arguments @@ -194,6 +248,8 @@ unfixable = [ "T203", ] +extend-safe-fixes = ["FA102"] + [tool.ruff.lint.flake8-tidy-imports.banned-api] "functools.lru_cache".msg = "This function does not retain type information for the wrapped function's arguments; The `lru_cache` function from `_utils` should be used instead" diff --git a/requirements-dev.lock b/requirements-dev.lock index 1a6388b1..d95129c3 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -56,7 +56,7 @@ httpx==0.28.1 # via contextual-client # via httpx-aiohttp # via respx -httpx-aiohttp==0.1.8 +httpx-aiohttp==0.1.9 # via contextual-client idna==3.4 # via anyio @@ -75,7 +75,6 @@ multidict==6.4.4 mypy==1.14.1 mypy-extensions==1.0.0 # via mypy -nest-asyncio==1.6.0 nodeenv==1.8.0 # via pyright nox==2023.4.22 @@ -89,9 +88,9 @@ pluggy==1.5.0 propcache==0.3.1 # via aiohttp # via yarl -pydantic==2.10.3 +pydantic==2.11.9 # via contextual-client -pydantic-core==2.27.1 +pydantic-core==2.33.2 # via pydantic pygments==2.18.0 # via rich @@ -127,6 +126,9 @@ typing-extensions==4.12.2 # via pydantic # via pydantic-core # via pyright + # via typing-inspection +typing-inspection==0.4.1 + # via pydantic virtualenv==20.24.5 # via nox yarl==1.20.0 diff --git a/requirements.lock b/requirements.lock index 321707d9..b5c6c839 100644 --- a/requirements.lock +++ b/requirements.lock @@ -43,7 +43,7 @@ httpcore==1.0.9 httpx==0.28.1 # via contextual-client # via httpx-aiohttp -httpx-aiohttp==0.1.8 +httpx-aiohttp==0.1.9 # via contextual-client idna==3.4 # via anyio @@ -55,9 +55,9 @@ 
multidict==6.4.4 propcache==0.3.1 # via aiohttp # via yarl -pydantic==2.10.3 +pydantic==2.11.9 # via contextual-client -pydantic-core==2.27.1 +pydantic-core==2.33.2 # via pydantic sniffio==1.3.0 # via anyio @@ -68,5 +68,8 @@ typing-extensions==4.12.2 # via multidict # via pydantic # via pydantic-core + # via typing-inspection +typing-inspection==0.4.1 + # via pydantic yarl==1.20.0 # via aiohttp diff --git a/scripts/bootstrap b/scripts/bootstrap index e84fe62c..b430fee3 100755 --- a/scripts/bootstrap +++ b/scripts/bootstrap @@ -4,10 +4,18 @@ set -e cd "$(dirname "$0")/.." -if ! command -v rye >/dev/null 2>&1 && [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ]; then +if [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ] && [ "$SKIP_BREW" != "1" ] && [ -t 0 ]; then brew bundle check >/dev/null 2>&1 || { - echo "==> Installing Homebrew dependencies…" - brew bundle + echo -n "==> Install Homebrew dependencies? (y/N): " + read -r response + case "$response" in + [yY][eE][sS]|[yY]) + brew bundle + ;; + *) + ;; + esac + echo } fi diff --git a/src/contextual/__init__.py b/src/contextual/__init__.py index 831570ca..658e0753 100644 --- a/src/contextual/__init__.py +++ b/src/contextual/__init__.py @@ -3,7 +3,7 @@ import typing as _t from . 
import types -from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes +from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes, omit, not_given from ._utils import file_from_path from ._client import ( Client, @@ -48,7 +48,9 @@ "ProxiesTypes", "NotGiven", "NOT_GIVEN", + "not_given", "Omit", + "omit", "ContextualAIError", "APIError", "APIStatusError", diff --git a/src/contextual/_base_client.py b/src/contextual/_base_client.py index 2d5b5fa2..c7fd3cb7 100644 --- a/src/contextual/_base_client.py +++ b/src/contextual/_base_client.py @@ -42,7 +42,6 @@ from ._qs import Querystring from ._files import to_httpx_files, async_to_httpx_files from ._types import ( - NOT_GIVEN, Body, Omit, Query, @@ -57,9 +56,10 @@ RequestOptions, HttpxRequestFiles, ModelBuilderProtocol, + not_given, ) from ._utils import is_dict, is_list, asyncify, is_given, lru_cache, is_mapping -from ._compat import PYDANTIC_V2, model_copy, model_dump +from ._compat import PYDANTIC_V1, model_copy, model_dump from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type from ._response import ( APIResponse, @@ -145,9 +145,9 @@ def __init__( def __init__( self, *, - url: URL | NotGiven = NOT_GIVEN, - json: Body | NotGiven = NOT_GIVEN, - params: Query | NotGiven = NOT_GIVEN, + url: URL | NotGiven = not_given, + json: Body | NotGiven = not_given, + params: Query | NotGiven = not_given, ) -> None: self.url = url self.json = json @@ -232,7 +232,7 @@ def _set_private_attributes( model: Type[_T], options: FinalRequestOptions, ) -> None: - if PYDANTIC_V2 and getattr(self, "__pydantic_private__", None) is None: + if (not PYDANTIC_V1) and getattr(self, "__pydantic_private__", None) is None: self.__pydantic_private__ = {} self._model = model @@ -320,7 +320,7 @@ def _set_private_attributes( client: AsyncAPIClient, options: FinalRequestOptions, ) -> None: - if PYDANTIC_V2 and getattr(self, "__pydantic_private__", None) is None: + if (not 
PYDANTIC_V1) and getattr(self, "__pydantic_private__", None) is None: self.__pydantic_private__ = {} self._model = model @@ -595,7 +595,7 @@ def _maybe_override_cast_to(self, cast_to: type[ResponseT], options: FinalReques # we internally support defining a temporary header to override the # default `cast_to` type for use with `.with_raw_response` and `.with_streaming_response` # see _response.py for implementation details - override_cast_to = headers.pop(OVERRIDE_CAST_TO_HEADER, NOT_GIVEN) + override_cast_to = headers.pop(OVERRIDE_CAST_TO_HEADER, not_given) if is_given(override_cast_to): options.headers = headers return cast(Type[ResponseT], override_cast_to) @@ -825,7 +825,7 @@ def __init__( version: str, base_url: str | URL, max_retries: int = DEFAULT_MAX_RETRIES, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.Client | None = None, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, @@ -1356,7 +1356,7 @@ def __init__( base_url: str | URL, _strict_response_validation: bool, max_retries: int = DEFAULT_MAX_RETRIES, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.AsyncClient | None = None, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, @@ -1818,8 +1818,8 @@ def make_request_options( extra_query: Query | None = None, extra_body: Body | None = None, idempotency_key: str | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - post_parser: PostParser | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + post_parser: PostParser | NotGiven = not_given, ) -> RequestOptions: """Create a dict of type RequestOptions without keys of NotGiven values.""" options: RequestOptions = {} diff --git a/src/contextual/_client.py b/src/contextual/_client.py 
index 9665f1b9..433b434d 100644 --- a/src/contextual/_client.py +++ b/src/contextual/_client.py @@ -3,7 +3,7 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Any, Union, Mapping +from typing import TYPE_CHECKING, Any, Mapping from typing_extensions import Self, override import httpx @@ -11,13 +11,13 @@ from . import _exceptions from ._qs import Querystring from ._types import ( - NOT_GIVEN, Omit, Timeout, NotGiven, Transport, ProxiesTypes, RequestOptions, + not_given, ) from ._utils import is_given, get_async_library from ._compat import cached_property @@ -63,7 +63,7 @@ def __init__( *, api_key: str | None = None, base_url: str | httpx.URL | None = None, - timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -195,9 +195,9 @@ def copy( *, api_key: str | None = None, base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.Client | None = None, - max_retries: int | NotGiven = NOT_GIVEN, + max_retries: int | NotGiven = not_given, default_headers: Mapping[str, str] | None = None, set_default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -286,7 +286,7 @@ def __init__( *, api_key: str | None = None, base_url: str | httpx.URL | None = None, - timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -418,9 +418,9 @@ def copy( *, api_key: str | None = None, base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: 
float | Timeout | None | NotGiven = not_given, http_client: httpx.AsyncClient | None = None, - max_retries: int | NotGiven = NOT_GIVEN, + max_retries: int | NotGiven = not_given, default_headers: Mapping[str, str] | None = None, set_default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, diff --git a/src/contextual/_compat.py b/src/contextual/_compat.py index 92d9ee61..bdef67f0 100644 --- a/src/contextual/_compat.py +++ b/src/contextual/_compat.py @@ -12,14 +12,13 @@ _T = TypeVar("_T") _ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel) -# --------------- Pydantic v2 compatibility --------------- +# --------------- Pydantic v2, v3 compatibility --------------- # Pyright incorrectly reports some of our functions as overriding a method when they don't # pyright: reportIncompatibleMethodOverride=false -PYDANTIC_V2 = pydantic.VERSION.startswith("2.") +PYDANTIC_V1 = pydantic.VERSION.startswith("1.") -# v1 re-exports if TYPE_CHECKING: def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001 @@ -44,90 +43,92 @@ def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001 ... 
else: - if PYDANTIC_V2: - from pydantic.v1.typing import ( + # v1 re-exports + if PYDANTIC_V1: + from pydantic.typing import ( get_args as get_args, is_union as is_union, get_origin as get_origin, is_typeddict as is_typeddict, is_literal_type as is_literal_type, ) - from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime + from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime else: - from pydantic.typing import ( + from ._utils import ( get_args as get_args, is_union as is_union, get_origin as get_origin, + parse_date as parse_date, is_typeddict as is_typeddict, + parse_datetime as parse_datetime, is_literal_type as is_literal_type, ) - from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime # refactored config if TYPE_CHECKING: from pydantic import ConfigDict as ConfigDict else: - if PYDANTIC_V2: - from pydantic import ConfigDict - else: + if PYDANTIC_V1: # TODO: provide an error message here? 
ConfigDict = None + else: + from pydantic import ConfigDict as ConfigDict # renamed methods / properties def parse_obj(model: type[_ModelT], value: object) -> _ModelT: - if PYDANTIC_V2: - return model.model_validate(value) - else: + if PYDANTIC_V1: return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + else: + return model.model_validate(value) def field_is_required(field: FieldInfo) -> bool: - if PYDANTIC_V2: - return field.is_required() - return field.required # type: ignore + if PYDANTIC_V1: + return field.required # type: ignore + return field.is_required() def field_get_default(field: FieldInfo) -> Any: value = field.get_default() - if PYDANTIC_V2: - from pydantic_core import PydanticUndefined - - if value == PydanticUndefined: - return None + if PYDANTIC_V1: return value + from pydantic_core import PydanticUndefined + + if value == PydanticUndefined: + return None return value def field_outer_type(field: FieldInfo) -> Any: - if PYDANTIC_V2: - return field.annotation - return field.outer_type_ # type: ignore + if PYDANTIC_V1: + return field.outer_type_ # type: ignore + return field.annotation def get_model_config(model: type[pydantic.BaseModel]) -> Any: - if PYDANTIC_V2: - return model.model_config - return model.__config__ # type: ignore + if PYDANTIC_V1: + return model.__config__ # type: ignore + return model.model_config def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]: - if PYDANTIC_V2: - return model.model_fields - return model.__fields__ # type: ignore + if PYDANTIC_V1: + return model.__fields__ # type: ignore + return model.model_fields def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT: - if PYDANTIC_V2: - return model.model_copy(deep=deep) - return model.copy(deep=deep) # type: ignore + if PYDANTIC_V1: + return model.copy(deep=deep) # type: ignore + return model.model_copy(deep=deep) def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> 
str: - if PYDANTIC_V2: - return model.model_dump_json(indent=indent) - return model.json(indent=indent) # type: ignore + if PYDANTIC_V1: + return model.json(indent=indent) # type: ignore + return model.model_dump_json(indent=indent) def model_dump( @@ -139,14 +140,14 @@ def model_dump( warnings: bool = True, mode: Literal["json", "python"] = "python", ) -> dict[str, Any]: - if PYDANTIC_V2 or hasattr(model, "model_dump"): + if (not PYDANTIC_V1) or hasattr(model, "model_dump"): return model.model_dump( mode=mode, exclude=exclude, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, # warnings are not supported in Pydantic v1 - warnings=warnings if PYDANTIC_V2 else True, + warnings=True if PYDANTIC_V1 else warnings, ) return cast( "dict[str, Any]", @@ -159,9 +160,9 @@ def model_dump( def model_parse(model: type[_ModelT], data: Any) -> _ModelT: - if PYDANTIC_V2: - return model.model_validate(data) - return model.parse_obj(data) # pyright: ignore[reportDeprecated] + if PYDANTIC_V1: + return model.parse_obj(data) # pyright: ignore[reportDeprecated] + return model.model_validate(data) # generic models @@ -170,17 +171,16 @@ def model_parse(model: type[_ModelT], data: Any) -> _ModelT: class GenericModel(pydantic.BaseModel): ... else: - if PYDANTIC_V2: + if PYDANTIC_V1: + import pydantic.generics + + class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ... + else: # there no longer needs to be a distinction in v2 but # we still have to create our own subclass to avoid # inconsistent MRO ordering errors class GenericModel(pydantic.BaseModel): ... - else: - import pydantic.generics - - class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ... 
- # cached properties if TYPE_CHECKING: diff --git a/src/contextual/_models.py b/src/contextual/_models.py index b8387ce9..6a3cd1d2 100644 --- a/src/contextual/_models.py +++ b/src/contextual/_models.py @@ -50,7 +50,7 @@ strip_annotated_type, ) from ._compat import ( - PYDANTIC_V2, + PYDANTIC_V1, ConfigDict, GenericModel as BaseGenericModel, get_args, @@ -81,11 +81,7 @@ class _ConfigProtocol(Protocol): class BaseModel(pydantic.BaseModel): - if PYDANTIC_V2: - model_config: ClassVar[ConfigDict] = ConfigDict( - extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")) - ) - else: + if PYDANTIC_V1: @property @override @@ -95,6 +91,10 @@ def model_fields_set(self) -> set[str]: class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] extra: Any = pydantic.Extra.allow # type: ignore + else: + model_config: ClassVar[ConfigDict] = ConfigDict( + extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")) + ) def to_dict( self, @@ -215,25 +215,25 @@ def construct( # pyright: ignore[reportIncompatibleMethodOverride] if key not in model_fields: parsed = construct_type(value=value, type_=extra_field_type) if extra_field_type is not None else value - if PYDANTIC_V2: - _extra[key] = parsed - else: + if PYDANTIC_V1: _fields_set.add(key) fields_values[key] = parsed + else: + _extra[key] = parsed object.__setattr__(m, "__dict__", fields_values) - if PYDANTIC_V2: - # these properties are copied from Pydantic's `model_construct()` method - object.__setattr__(m, "__pydantic_private__", None) - object.__setattr__(m, "__pydantic_extra__", _extra) - object.__setattr__(m, "__pydantic_fields_set__", _fields_set) - else: + if PYDANTIC_V1: # init_private_attributes() does not exist in v2 m._init_private_attributes() # type: ignore # copied from Pydantic v1's `construct()` method object.__setattr__(m, "__fields_set__", _fields_set) + else: + # these properties are copied from Pydantic's `model_construct()` method + 
object.__setattr__(m, "__pydantic_private__", None) + object.__setattr__(m, "__pydantic_extra__", _extra) + object.__setattr__(m, "__pydantic_fields_set__", _fields_set) return m @@ -243,7 +243,7 @@ def construct( # pyright: ignore[reportIncompatibleMethodOverride] # although not in practice model_construct = construct - if not PYDANTIC_V2: + if PYDANTIC_V1: # we define aliases for some of the new pydantic v2 methods so # that we can just document these methods without having to specify # a specific pydantic version as some users may not know which @@ -256,7 +256,7 @@ def model_dump( mode: Literal["json", "python"] | str = "python", include: IncEx | None = None, exclude: IncEx | None = None, - by_alias: bool = False, + by_alias: bool | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, @@ -264,6 +264,7 @@ def model_dump( warnings: bool | Literal["none", "warn", "error"] = True, context: dict[str, Any] | None = None, serialize_as_any: bool = False, + fallback: Callable[[Any], Any] | None = None, ) -> dict[str, Any]: """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump @@ -295,16 +296,18 @@ def model_dump( raise ValueError("context is only supported in Pydantic v2") if serialize_as_any != False: raise ValueError("serialize_as_any is only supported in Pydantic v2") + if fallback is not None: + raise ValueError("fallback is only supported in Pydantic v2") dumped = super().dict( # pyright: ignore[reportDeprecated] include=include, exclude=exclude, - by_alias=by_alias, + by_alias=by_alias if by_alias is not None else False, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, ) - return cast(dict[str, Any], json_safe(dumped)) if mode == "json" else dumped + return cast("dict[str, Any]", json_safe(dumped)) if mode == "json" else dumped @override def model_dump_json( @@ -313,13 +316,14 @@ def model_dump_json( indent: int | None = None, include: IncEx | 
None = None, exclude: IncEx | None = None, - by_alias: bool = False, + by_alias: bool | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, warnings: bool | Literal["none", "warn", "error"] = True, context: dict[str, Any] | None = None, + fallback: Callable[[Any], Any] | None = None, serialize_as_any: bool = False, ) -> str: """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json @@ -348,11 +352,13 @@ def model_dump_json( raise ValueError("context is only supported in Pydantic v2") if serialize_as_any != False: raise ValueError("serialize_as_any is only supported in Pydantic v2") + if fallback is not None: + raise ValueError("fallback is only supported in Pydantic v2") return super().json( # type: ignore[reportDeprecated] indent=indent, include=include, exclude=exclude, - by_alias=by_alias, + by_alias=by_alias if by_alias is not None else False, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, @@ -363,10 +369,10 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object: if value is None: return field_get_default(field) - if PYDANTIC_V2: - type_ = field.annotation - else: + if PYDANTIC_V1: type_ = cast(type, field.outer_type_) # type: ignore + else: + type_ = field.annotation # type: ignore if type_ is None: raise RuntimeError(f"Unexpected field type is None for {key}") @@ -375,7 +381,7 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object: def _get_extra_fields_type(cls: type[pydantic.BaseModel]) -> type | None: - if not PYDANTIC_V2: + if PYDANTIC_V1: # TODO return None @@ -628,30 +634,30 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, for variant in get_args(union): variant = strip_annotated_type(variant) if is_basemodel_type(variant): - if PYDANTIC_V2: - field = _extract_field_schema_pv2(variant, discriminator_field_name) - if not field: 
+ if PYDANTIC_V1: + field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + if not field_info: continue # Note: if one variant defines an alias then they all should - discriminator_alias = field.get("serialization_alias") - - field_schema = field["schema"] + discriminator_alias = field_info.alias - if field_schema["type"] == "literal": - for entry in cast("LiteralSchema", field_schema)["expected"]: + if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(annotation): + for entry in get_args(annotation): if isinstance(entry, str): mapping[entry] = variant else: - field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] - if not field_info: + field = _extract_field_schema_pv2(variant, discriminator_field_name) + if not field: continue # Note: if one variant defines an alias then they all should - discriminator_alias = field_info.alias + discriminator_alias = field.get("serialization_alias") - if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(annotation): - for entry in get_args(annotation): + field_schema = field["schema"] + + if field_schema["type"] == "literal": + for entry in cast("LiteralSchema", field_schema)["expected"]: if isinstance(entry, str): mapping[entry] = variant @@ -714,7 +720,7 @@ class GenericModel(BaseGenericModel, BaseModel): pass -if PYDANTIC_V2: +if not PYDANTIC_V1: from pydantic import TypeAdapter as _TypeAdapter _CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter)) @@ -782,12 +788,12 @@ class FinalRequestOptions(pydantic.BaseModel): json_data: Union[Body, None] = None extra_json: Union[AnyMapping, None] = None - if PYDANTIC_V2: - model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) - else: + if PYDANTIC_V1: class Config(pydantic.BaseConfig): # 
pyright: ignore[reportDeprecated] arbitrary_types_allowed: bool = True + else: + model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) def get_max_retries(self, max_retries: int) -> int: if isinstance(self.max_retries, NotGiven): @@ -820,9 +826,9 @@ def construct( # type: ignore key: strip_not_given(value) for key, value in values.items() } - if PYDANTIC_V2: - return super().model_construct(_fields_set, **kwargs) - return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated] + if PYDANTIC_V1: + return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated] + return super().model_construct(_fields_set, **kwargs) if not TYPE_CHECKING: # type checkers incorrectly complain about this assignment diff --git a/src/contextual/_qs.py b/src/contextual/_qs.py index 274320ca..ada6fd3f 100644 --- a/src/contextual/_qs.py +++ b/src/contextual/_qs.py @@ -4,7 +4,7 @@ from urllib.parse import parse_qs, urlencode from typing_extensions import Literal, get_args -from ._types import NOT_GIVEN, NotGiven, NotGivenOr +from ._types import NotGiven, not_given from ._utils import flatten _T = TypeVar("_T") @@ -41,8 +41,8 @@ def stringify( self, params: Params, *, - array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN, - nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN, + array_format: ArrayFormat | NotGiven = not_given, + nested_format: NestedFormat | NotGiven = not_given, ) -> str: return urlencode( self.stringify_items( @@ -56,8 +56,8 @@ def stringify_items( self, params: Params, *, - array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN, - nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN, + array_format: ArrayFormat | NotGiven = not_given, + nested_format: NestedFormat | NotGiven = not_given, ) -> list[tuple[str, str]]: opts = Options( qs=self, @@ -143,8 +143,8 @@ def __init__( self, qs: Querystring = _qs, *, - array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN, - 
nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN, + array_format: ArrayFormat | NotGiven = not_given, + nested_format: NestedFormat | NotGiven = not_given, ) -> None: self.array_format = qs.array_format if isinstance(array_format, NotGiven) else array_format self.nested_format = qs.nested_format if isinstance(nested_format, NotGiven) else nested_format diff --git a/src/contextual/_types.py b/src/contextual/_types.py index 46a038e6..a6a1882a 100644 --- a/src/contextual/_types.py +++ b/src/contextual/_types.py @@ -13,10 +13,21 @@ Mapping, TypeVar, Callable, + Iterator, Optional, Sequence, ) -from typing_extensions import Set, Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable +from typing_extensions import ( + Set, + Literal, + Protocol, + TypeAlias, + TypedDict, + SupportsIndex, + overload, + override, + runtime_checkable, +) import httpx import pydantic @@ -106,18 +117,21 @@ class RequestOptions(TypedDict, total=False): # Sentinel class used until PEP 0661 is accepted class NotGiven: """ - A sentinel singleton class used to distinguish omitted keyword arguments - from those passed in with the value None (which may have different behavior). + For parameters with a meaningful None value, we need to distinguish between + the user explicitly passing None, and the user not passing the parameter at + all. + + User code shouldn't need to use not_given directly. For example: ```py - def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ... + def create(timeout: Timeout | None | NotGiven = not_given): ... - get(timeout=1) # 1s timeout - get(timeout=None) # No timeout - get() # Default timeout behavior, which may not be statically known at the method definition. 
+ create(timeout=1) # 1s timeout + create(timeout=None) # No timeout + create() # Default timeout behavior ``` """ @@ -129,13 +143,14 @@ def __repr__(self) -> str: return "NOT_GIVEN" -NotGivenOr = Union[_T, NotGiven] +not_given = NotGiven() +# for backwards compatibility: NOT_GIVEN = NotGiven() class Omit: - """In certain situations you need to be able to represent a case where a default value has - to be explicitly removed and `None` is not an appropriate substitute, for example: + """ + To explicitly omit something from being sent in a request, use `omit`. ```py # as the default `Content-Type` header is `application/json` that will be sent @@ -145,8 +160,8 @@ class Omit: # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983' client.post(..., headers={"Content-Type": "multipart/form-data"}) - # instead you can remove the default `application/json` header by passing Omit - client.post(..., headers={"Content-Type": Omit()}) + # instead you can remove the default `application/json` header by passing omit + client.post(..., headers={"Content-Type": omit}) ``` """ @@ -154,6 +169,9 @@ def __bool__(self) -> Literal[False]: return False +omit = Omit() + + @runtime_checkable class ModelBuilderProtocol(Protocol): @classmethod @@ -217,3 +235,26 @@ class _GenericAlias(Protocol): class HttpxSendArgs(TypedDict, total=False): auth: httpx.Auth follow_redirects: bool + + +_T_co = TypeVar("_T_co", covariant=True) + + +if TYPE_CHECKING: + # This works because str.__contains__ does not accept object (either in typeshed or at runtime) + # https://github.com/hauntsaninja/useful_types/blob/5e9710f3875107d068e7679fd7fec9cfab0eff3b/useful_types/__init__.py#L285 + class SequenceNotStr(Protocol[_T_co]): + @overload + def __getitem__(self, index: SupportsIndex, /) -> _T_co: ... + @overload + def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ... + def __contains__(self, value: object, /) -> bool: ... + def __len__(self) -> int: ... 
+ def __iter__(self) -> Iterator[_T_co]: ... + def index(self, value: Any, start: int = 0, stop: int = ..., /) -> int: ... + def count(self, value: Any, /) -> int: ... + def __reversed__(self) -> Iterator[_T_co]: ... +else: + # just point this to a normal `Sequence` at runtime to avoid having to special case + # deserializing our custom sequence type + SequenceNotStr = Sequence diff --git a/src/contextual/_utils/__init__.py b/src/contextual/_utils/__init__.py index d4fda26f..dc64e29a 100644 --- a/src/contextual/_utils/__init__.py +++ b/src/contextual/_utils/__init__.py @@ -10,7 +10,6 @@ lru_cache as lru_cache, is_mapping as is_mapping, is_tuple_t as is_tuple_t, - parse_date as parse_date, is_iterable as is_iterable, is_sequence as is_sequence, coerce_float as coerce_float, @@ -23,7 +22,6 @@ coerce_boolean as coerce_boolean, coerce_integer as coerce_integer, file_from_path as file_from_path, - parse_datetime as parse_datetime, strip_not_given as strip_not_given, deepcopy_minimal as deepcopy_minimal, get_async_library as get_async_library, @@ -32,12 +30,20 @@ maybe_coerce_boolean as maybe_coerce_boolean, maybe_coerce_integer as maybe_coerce_integer, ) +from ._compat import ( + get_args as get_args, + is_union as is_union, + get_origin as get_origin, + is_typeddict as is_typeddict, + is_literal_type as is_literal_type, +) from ._typing import ( is_list_type as is_list_type, is_union_type as is_union_type, extract_type_arg as extract_type_arg, is_iterable_type as is_iterable_type, is_required_type as is_required_type, + is_sequence_type as is_sequence_type, is_annotated_type as is_annotated_type, is_type_alias_type as is_type_alias_type, strip_annotated_type as strip_annotated_type, @@ -55,3 +61,4 @@ function_has_argument as function_has_argument, assert_signatures_in_sync as assert_signatures_in_sync, ) +from ._datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime diff --git a/src/contextual/_utils/_compat.py 
b/src/contextual/_utils/_compat.py new file mode 100644 index 00000000..dd703233 --- /dev/null +++ b/src/contextual/_utils/_compat.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import sys +import typing_extensions +from typing import Any, Type, Union, Literal, Optional +from datetime import date, datetime +from typing_extensions import get_args as _get_args, get_origin as _get_origin + +from .._types import StrBytesIntFloat +from ._datetime_parse import parse_date as _parse_date, parse_datetime as _parse_datetime + +_LITERAL_TYPES = {Literal, typing_extensions.Literal} + + +def get_args(tp: type[Any]) -> tuple[Any, ...]: + return _get_args(tp) + + +def get_origin(tp: type[Any]) -> type[Any] | None: + return _get_origin(tp) + + +def is_union(tp: Optional[Type[Any]]) -> bool: + if sys.version_info < (3, 10): + return tp is Union # type: ignore[comparison-overlap] + else: + import types + + return tp is Union or tp is types.UnionType + + +def is_typeddict(tp: Type[Any]) -> bool: + return typing_extensions.is_typeddict(tp) + + +def is_literal_type(tp: Type[Any]) -> bool: + return get_origin(tp) in _LITERAL_TYPES + + +def parse_date(value: Union[date, StrBytesIntFloat]) -> date: + return _parse_date(value) + + +def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: + return _parse_datetime(value) diff --git a/src/contextual/_utils/_datetime_parse.py b/src/contextual/_utils/_datetime_parse.py new file mode 100644 index 00000000..7cb9d9e6 --- /dev/null +++ b/src/contextual/_utils/_datetime_parse.py @@ -0,0 +1,136 @@ +""" +This file contains code from https://github.com/pydantic/pydantic/blob/main/pydantic/v1/datetime_parse.py +without the Pydantic v1 specific errors. 
+""" + +from __future__ import annotations + +import re +from typing import Dict, Union, Optional +from datetime import date, datetime, timezone, timedelta + +from .._types import StrBytesIntFloat + +date_expr = r"(?P\d{4})-(?P\d{1,2})-(?P\d{1,2})" +time_expr = ( + r"(?P\d{1,2}):(?P\d{1,2})" + r"(?::(?P\d{1,2})(?:\.(?P\d{1,6})\d{0,6})?)?" + r"(?PZ|[+-]\d{2}(?::?\d{2})?)?$" +) + +date_re = re.compile(f"{date_expr}$") +datetime_re = re.compile(f"{date_expr}[T ]{time_expr}") + + +EPOCH = datetime(1970, 1, 1) +# if greater than this, the number is in ms, if less than or equal it's in seconds +# (in seconds this is 11th October 2603, in ms it's 20th August 1970) +MS_WATERSHED = int(2e10) +# slightly more than datetime.max in ns - (datetime.max - EPOCH).total_seconds() * 1e9 +MAX_NUMBER = int(3e20) + + +def _get_numeric(value: StrBytesIntFloat, native_expected_type: str) -> Union[None, int, float]: + if isinstance(value, (int, float)): + return value + try: + return float(value) + except ValueError: + return None + except TypeError: + raise TypeError(f"invalid type; expected {native_expected_type}, string, bytes, int or float") from None + + +def _from_unix_seconds(seconds: Union[int, float]) -> datetime: + if seconds > MAX_NUMBER: + return datetime.max + elif seconds < -MAX_NUMBER: + return datetime.min + + while abs(seconds) > MS_WATERSHED: + seconds /= 1000 + dt = EPOCH + timedelta(seconds=seconds) + return dt.replace(tzinfo=timezone.utc) + + +def _parse_timezone(value: Optional[str]) -> Union[None, int, timezone]: + if value == "Z": + return timezone.utc + elif value is not None: + offset_mins = int(value[-2:]) if len(value) > 3 else 0 + offset = 60 * int(value[1:3]) + offset_mins + if value[0] == "-": + offset = -offset + return timezone(timedelta(minutes=offset)) + else: + return None + + +def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: + """ + Parse a datetime/int/float/string and return a datetime.datetime. 
+ + This function supports time zone offsets. When the input contains one, + the output uses a timezone with a fixed offset from UTC. + + Raise ValueError if the input is well formatted but not a valid datetime. + Raise ValueError if the input isn't well formatted. + """ + if isinstance(value, datetime): + return value + + number = _get_numeric(value, "datetime") + if number is not None: + return _from_unix_seconds(number) + + if isinstance(value, bytes): + value = value.decode() + + assert not isinstance(value, (float, int)) + + match = datetime_re.match(value) + if match is None: + raise ValueError("invalid datetime format") + + kw = match.groupdict() + if kw["microsecond"]: + kw["microsecond"] = kw["microsecond"].ljust(6, "0") + + tzinfo = _parse_timezone(kw.pop("tzinfo")) + kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None} + kw_["tzinfo"] = tzinfo + + return datetime(**kw_) # type: ignore + + +def parse_date(value: Union[date, StrBytesIntFloat]) -> date: + """ + Parse a date/int/float/string and return a datetime.date. + + Raise ValueError if the input is well formatted but not a valid date. + Raise ValueError if the input isn't well formatted. 
+ """ + if isinstance(value, date): + if isinstance(value, datetime): + return value.date() + else: + return value + + number = _get_numeric(value, "date") + if number is not None: + return _from_unix_seconds(number).date() + + if isinstance(value, bytes): + value = value.decode() + + assert not isinstance(value, (float, int)) + match = date_re.match(value) + if match is None: + raise ValueError("invalid date format") + + kw = {k: int(v) for k, v in match.groupdict().items()} + + try: + return date(**kw) + except ValueError: + raise ValueError("invalid date format") from None diff --git a/src/contextual/_utils/_transform.py b/src/contextual/_utils/_transform.py index b0cc20a7..52075492 100644 --- a/src/contextual/_utils/_transform.py +++ b/src/contextual/_utils/_transform.py @@ -16,18 +16,20 @@ lru_cache, is_mapping, is_iterable, + is_sequence, ) from .._files import is_base64_file_input +from ._compat import get_origin, is_typeddict from ._typing import ( is_list_type, is_union_type, extract_type_arg, is_iterable_type, is_required_type, + is_sequence_type, is_annotated_type, strip_annotated_type, ) -from .._compat import get_origin, model_dump, is_typeddict _T = TypeVar("_T") @@ -167,6 +169,8 @@ def _transform_recursive( Defaults to the same value as the `annotation` argument. """ + from .._compat import model_dump + if inner_type is None: inner_type = annotation @@ -184,6 +188,8 @@ def _transform_recursive( (is_list_type(stripped_type) and is_list(data)) # Iterable[T] or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) + # Sequence[T] + or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str)) ): # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually # intended as an iterable, so we don't transform it. 
@@ -262,7 +268,7 @@ def _transform_typeddict( annotations = get_type_hints(expected_type, include_extras=True) for key, value in data.items(): if not is_given(value): - # we don't need to include `NotGiven` values here as they'll + # we don't need to include omitted values here as they'll # be stripped out before the request is sent anyway continue @@ -329,6 +335,8 @@ async def _async_transform_recursive( Defaults to the same value as the `annotation` argument. """ + from .._compat import model_dump + if inner_type is None: inner_type = annotation @@ -346,6 +354,8 @@ async def _async_transform_recursive( (is_list_type(stripped_type) and is_list(data)) # Iterable[T] or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) + # Sequence[T] + or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str)) ): # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually # intended as an iterable, so we don't transform it. 
@@ -424,7 +434,7 @@ async def _async_transform_typeddict( annotations = get_type_hints(expected_type, include_extras=True) for key, value in data.items(): if not is_given(value): - # we don't need to include `NotGiven` values here as they'll + # we don't need to include omitted values here as they'll # be stripped out before the request is sent anyway continue diff --git a/src/contextual/_utils/_typing.py b/src/contextual/_utils/_typing.py index 1bac9542..193109f3 100644 --- a/src/contextual/_utils/_typing.py +++ b/src/contextual/_utils/_typing.py @@ -15,7 +15,7 @@ from ._utils import lru_cache from .._types import InheritsGeneric -from .._compat import is_union as _is_union +from ._compat import is_union as _is_union def is_annotated_type(typ: type) -> bool: @@ -26,6 +26,11 @@ def is_list_type(typ: type) -> bool: return (get_origin(typ) or typ) == list +def is_sequence_type(typ: type) -> bool: + origin = get_origin(typ) or typ + return origin == typing_extensions.Sequence or origin == typing.Sequence or origin == _c_abc.Sequence + + def is_iterable_type(typ: type) -> bool: """If the given type is `typing.Iterable[T]`""" origin = get_origin(typ) or typ diff --git a/src/contextual/_utils/_utils.py b/src/contextual/_utils/_utils.py index ea3cf3f2..50d59269 100644 --- a/src/contextual/_utils/_utils.py +++ b/src/contextual/_utils/_utils.py @@ -21,8 +21,7 @@ import sniffio -from .._types import NotGiven, FileTypes, NotGivenOr, HeadersLike -from .._compat import parse_date as parse_date, parse_datetime as parse_datetime +from .._types import Omit, NotGiven, FileTypes, HeadersLike _T = TypeVar("_T") _TupleT = TypeVar("_TupleT", bound=Tuple[object, ...]) @@ -64,7 +63,7 @@ def _extract_items( try: key = path[index] except IndexError: - if isinstance(obj, NotGiven): + if not is_given(obj): # no value was provided - we can safely ignore return [] @@ -127,8 +126,8 @@ def _extract_items( return [] -def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]: - return not isinstance(obj, 
NotGiven) +def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]: + return not isinstance(obj, NotGiven) and not isinstance(obj, Omit) # Type safe methods for narrowing types with TypeVars. diff --git a/src/contextual/_version.py b/src/contextual/_version.py index a855209f..1cf85a09 100644 --- a/src/contextual/_version.py +++ b/src/contextual/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "contextual" -__version__ = "0.8.0" # x-release-please-version +__version__ = "0.9.0" # x-release-please-version diff --git a/src/contextual/resources/agents/agents.py b/src/contextual/resources/agents/agents.py index 0a2c3a93..0bce0ebc 100644 --- a/src/contextual/resources/agents/agents.py +++ b/src/contextual/resources/agents/agents.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, List, cast +from typing import Any, cast import httpx @@ -15,7 +15,7 @@ AsyncQueryResourceWithStreamingResponse, ) from ...types import agent_list_params, agent_create_params, agent_update_params -from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from ..._types import Body, Omit, Query, Headers, NotGiven, SequenceNotStr, omit, not_given from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource @@ -63,20 +63,21 @@ def create( self, *, name: str, - agent_configs: AgentConfigsParam | NotGiven = NOT_GIVEN, - datastore_ids: List[str] | NotGiven = NOT_GIVEN, - description: str | NotGiven = NOT_GIVEN, - filter_prompt: str | NotGiven = NOT_GIVEN, - multiturn_system_prompt: str | NotGiven = NOT_GIVEN, - no_retrieval_system_prompt: str | NotGiven = NOT_GIVEN, - suggested_queries: List[str] | NotGiven = NOT_GIVEN, - system_prompt: str | NotGiven = NOT_GIVEN, + agent_configs: AgentConfigsParam | Omit = omit, + datastore_ids: SequenceNotStr[str] | Omit = omit, + description: str | Omit = 
omit, + filter_prompt: str | Omit = omit, + multiturn_system_prompt: str | Omit = omit, + no_retrieval_system_prompt: str | Omit = omit, + suggested_queries: SequenceNotStr[str] | Omit = omit, + system_prompt: str | Omit = omit, + template_name: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> CreateAgentOutput: """ Create a new `Agent` with a specific configuration. @@ -121,6 +122,8 @@ def create( system_prompt: Instructions that your agent references when generating responses. Note that we do not guarantee that the system will follow these instructions exactly. + template_name: The template defining the base configuration for the agent. 
+ extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -142,6 +145,7 @@ def create( "no_retrieval_system_prompt": no_retrieval_system_prompt, "suggested_queries": suggested_queries, "system_prompt": system_prompt, + "template_name": template_name, }, agent_create_params.AgentCreateParams, ), @@ -155,20 +159,21 @@ def update( self, agent_id: str, *, - agent_configs: AgentConfigsParam | NotGiven = NOT_GIVEN, - datastore_ids: List[str] | NotGiven = NOT_GIVEN, - filter_prompt: str | NotGiven = NOT_GIVEN, - llm_model_id: str | NotGiven = NOT_GIVEN, - multiturn_system_prompt: str | NotGiven = NOT_GIVEN, - no_retrieval_system_prompt: str | NotGiven = NOT_GIVEN, - suggested_queries: List[str] | NotGiven = NOT_GIVEN, - system_prompt: str | NotGiven = NOT_GIVEN, + agent_configs: AgentConfigsParam | Omit = omit, + datastore_ids: SequenceNotStr[str] | Omit = omit, + description: str | Omit = omit, + filter_prompt: str | Omit = omit, + multiturn_system_prompt: str | Omit = omit, + name: str | Omit = omit, + no_retrieval_system_prompt: str | Omit = omit, + suggested_queries: SequenceNotStr[str] | Omit = omit, + system_prompt: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> object: """ Modify a given `Agent` to utilize the provided configuration. @@ -182,15 +187,15 @@ def update( datastore_ids: IDs of the datastore to associate with the agent. 
+ description: Description of the agent + filter_prompt: The prompt to an LLM which determines whether retrieved chunks are relevant to a given query and filters out irrelevant chunks. - llm_model_id: The model ID to use for generation. Tuned models can only be used for the agents - on which they were tuned. If no model is specified, the default model is used. - Set to `default` to switch from a tuned model to the default model. - multiturn_system_prompt: Instructions on how the agent should handle multi-turn conversations. + name: Name of the agent + no_retrieval_system_prompt: Instructions on how the agent should respond when there are no relevant retrievals that can be used to answer a query. @@ -218,9 +223,10 @@ def update( { "agent_configs": agent_configs, "datastore_ids": datastore_ids, + "description": description, "filter_prompt": filter_prompt, - "llm_model_id": llm_model_id, "multiturn_system_prompt": multiturn_system_prompt, + "name": name, "no_retrieval_system_prompt": no_retrieval_system_prompt, "suggested_queries": suggested_queries, "system_prompt": system_prompt, @@ -236,14 +242,14 @@ def update( def list( self, *, - cursor: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, + cursor: str | Omit = omit, + limit: int | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> SyncPage[Agent]: """ Retrieve a list of all `Agents`. 
@@ -290,7 +296,7 @@ def delete( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> object: """Delete a given `Agent`. @@ -330,7 +336,7 @@ def copy( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> CreateAgentOutput: """ Copy an existing agent with all its configurations and datastore associations. @@ -366,7 +372,7 @@ def metadata( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> AgentMetadataResponse: """ Get metadata and configuration of a given `Agent`. @@ -406,7 +412,7 @@ def reset( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> object: """ Reset a given `Agent` to default configuration. 
@@ -461,20 +467,21 @@ async def create( self, *, name: str, - agent_configs: AgentConfigsParam | NotGiven = NOT_GIVEN, - datastore_ids: List[str] | NotGiven = NOT_GIVEN, - description: str | NotGiven = NOT_GIVEN, - filter_prompt: str | NotGiven = NOT_GIVEN, - multiturn_system_prompt: str | NotGiven = NOT_GIVEN, - no_retrieval_system_prompt: str | NotGiven = NOT_GIVEN, - suggested_queries: List[str] | NotGiven = NOT_GIVEN, - system_prompt: str | NotGiven = NOT_GIVEN, + agent_configs: AgentConfigsParam | Omit = omit, + datastore_ids: SequenceNotStr[str] | Omit = omit, + description: str | Omit = omit, + filter_prompt: str | Omit = omit, + multiturn_system_prompt: str | Omit = omit, + no_retrieval_system_prompt: str | Omit = omit, + suggested_queries: SequenceNotStr[str] | Omit = omit, + system_prompt: str | Omit = omit, + template_name: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> CreateAgentOutput: """ Create a new `Agent` with a specific configuration. @@ -519,6 +526,8 @@ async def create( system_prompt: Instructions that your agent references when generating responses. Note that we do not guarantee that the system will follow these instructions exactly. + template_name: The template defining the base configuration for the agent. 
+ extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -540,6 +549,7 @@ async def create( "no_retrieval_system_prompt": no_retrieval_system_prompt, "suggested_queries": suggested_queries, "system_prompt": system_prompt, + "template_name": template_name, }, agent_create_params.AgentCreateParams, ), @@ -553,20 +563,21 @@ async def update( self, agent_id: str, *, - agent_configs: AgentConfigsParam | NotGiven = NOT_GIVEN, - datastore_ids: List[str] | NotGiven = NOT_GIVEN, - filter_prompt: str | NotGiven = NOT_GIVEN, - llm_model_id: str | NotGiven = NOT_GIVEN, - multiturn_system_prompt: str | NotGiven = NOT_GIVEN, - no_retrieval_system_prompt: str | NotGiven = NOT_GIVEN, - suggested_queries: List[str] | NotGiven = NOT_GIVEN, - system_prompt: str | NotGiven = NOT_GIVEN, + agent_configs: AgentConfigsParam | Omit = omit, + datastore_ids: SequenceNotStr[str] | Omit = omit, + description: str | Omit = omit, + filter_prompt: str | Omit = omit, + multiturn_system_prompt: str | Omit = omit, + name: str | Omit = omit, + no_retrieval_system_prompt: str | Omit = omit, + suggested_queries: SequenceNotStr[str] | Omit = omit, + system_prompt: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> object: """ Modify a given `Agent` to utilize the provided configuration. @@ -580,15 +591,15 @@ async def update( datastore_ids: IDs of the datastore to associate with the agent. 
+ description: Description of the agent + filter_prompt: The prompt to an LLM which determines whether retrieved chunks are relevant to a given query and filters out irrelevant chunks. - llm_model_id: The model ID to use for generation. Tuned models can only be used for the agents - on which they were tuned. If no model is specified, the default model is used. - Set to `default` to switch from a tuned model to the default model. - multiturn_system_prompt: Instructions on how the agent should handle multi-turn conversations. + name: Name of the agent + no_retrieval_system_prompt: Instructions on how the agent should respond when there are no relevant retrievals that can be used to answer a query. @@ -616,9 +627,10 @@ async def update( { "agent_configs": agent_configs, "datastore_ids": datastore_ids, + "description": description, "filter_prompt": filter_prompt, - "llm_model_id": llm_model_id, "multiturn_system_prompt": multiturn_system_prompt, + "name": name, "no_retrieval_system_prompt": no_retrieval_system_prompt, "suggested_queries": suggested_queries, "system_prompt": system_prompt, @@ -634,14 +646,14 @@ async def update( def list( self, *, - cursor: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, + cursor: str | Omit = omit, + limit: int | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> AsyncPaginator[Agent, AsyncPage[Agent]]: """ Retrieve a list of all `Agents`. 
@@ -688,7 +700,7 @@ async def delete( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> object: """Delete a given `Agent`. @@ -728,7 +740,7 @@ async def copy( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> CreateAgentOutput: """ Copy an existing agent with all its configurations and datastore associations. @@ -764,7 +776,7 @@ async def metadata( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> AgentMetadataResponse: """ Get metadata and configuration of a given `Agent`. @@ -804,7 +816,7 @@ async def reset( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> object: """ Reset a given `Agent` to default configuration. 
diff --git a/src/contextual/resources/agents/query.py b/src/contextual/resources/agents/query.py index d321b346..40a9e7b9 100644 --- a/src/contextual/resources/agents/query.py +++ b/src/contextual/resources/agents/query.py @@ -2,13 +2,13 @@ from __future__ import annotations -from typing import List, Union, Iterable +from typing import Union, Iterable from datetime import datetime from typing_extensions import Literal import httpx -from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from ..._types import Body, Omit, Query, Headers, NotGiven, SequenceNotStr, omit, not_given from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource @@ -27,6 +27,7 @@ ) from ...types.agents.query_response import QueryResponse from ...types.agents.query_metrics_response import QueryMetricsResponse +from ...types.agents.query_feedback_response import QueryFeedbackResponse from ...types.agents.retrieval_info_response import RetrievalInfoResponse __all__ = ["QueryResource", "AsyncQueryResource"] @@ -57,20 +58,20 @@ def create( agent_id: str, *, messages: Iterable[query_create_params.Message], - include_retrieval_content_text: bool | NotGiven = NOT_GIVEN, - retrievals_only: bool | NotGiven = NOT_GIVEN, - conversation_id: str | NotGiven = NOT_GIVEN, - documents_filters: query_create_params.DocumentsFilters | NotGiven = NOT_GIVEN, - llm_model_id: str | NotGiven = NOT_GIVEN, - override_configuration: query_create_params.OverrideConfiguration | NotGiven = NOT_GIVEN, - stream: bool | NotGiven = NOT_GIVEN, - structured_output: query_create_params.StructuredOutput | NotGiven = NOT_GIVEN, + include_retrieval_content_text: bool | Omit = omit, + retrievals_only: bool | Omit = omit, + conversation_id: str | Omit = omit, + documents_filters: query_create_params.DocumentsFilters | Omit = omit, + llm_model_id: str | Omit = omit, + override_configuration: 
query_create_params.OverrideConfiguration | Omit = omit, + stream: bool | Omit = omit, + structured_output: query_create_params.StructuredOutput | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> QueryResponse: """ Start a conversation with an `Agent` and receive its generated response, along @@ -98,8 +99,10 @@ def create( query will be ignored. documents_filters: Defines an Optional custom metadata filter, which can be a list of filters or - nested filters. The expected input is a nested JSON object that can represent a - single filter or a composite (logical) combination of filters. + nested filters. Use **lowercase** for `value` and/or **field.keyword** for + `field` when not using `equals` operator. The expected input is a nested JSON + object that can represent a single filter or a composite (logical) combination + of filters. Unnested Example: @@ -188,27 +191,21 @@ def feedback( *, feedback: Literal["thumbs_up", "thumbs_down", "flagged", "removed"], message_id: str, - content_id: str | NotGiven = NOT_GIVEN, - explanation: str | NotGiven = NOT_GIVEN, + content_id: str | Omit = omit, + explanation: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> object: + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> QueryFeedbackResponse: """Provide feedback for a generation or a retrieval. Feedback can be used to track overall `Agent` performance through the `Feedback` page in the Contextual UI, and as a basis for model fine-tuning. - If providing feedback on a retrieval, include the `message_id` from the `/query` - response, and a `content_id` returned in the query's `retrieval_contents` list. - - For feedback on generations, include `message_id` and do not include a - `content_id`. - Args: agent_id: ID of the agent for which to provide feedback @@ -246,25 +243,25 @@ def feedback( options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=object, + cast_to=QueryFeedbackResponse, ) def metrics( self, agent_id: str, *, - conversation_ids: List[str] | NotGiven = NOT_GIVEN, - created_after: Union[str, datetime] | NotGiven = NOT_GIVEN, - created_before: Union[str, datetime] | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - offset: int | NotGiven = NOT_GIVEN, - user_emails: List[str] | NotGiven = NOT_GIVEN, + conversation_ids: SequenceNotStr[str] | Omit = omit, + created_after: Union[str, datetime] | Omit = omit, + created_before: Union[str, datetime] | Omit = omit, + limit: int | Omit = omit, + offset: int | Omit = omit, + user_emails: SequenceNotStr[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> QueryMetricsResponse: """Returns usage and user-provided feedback data. @@ -278,7 +275,15 @@ def metrics( created_after: Filters messages that are created after the specified timestamp. - created_before: Filters messages that are created before specified timestamp. + created_before: Filters messages that are created before specified timestamp. If both + `created_after` and `created_before` are not provided, then `created_before` + will be set to the current time and `created_after` will be set to the + `created_before` - 2 days. If only `created_after` is provided, then + `created_before` will be set to the `created_after` + 2 days. If only + `created_before` is provided, then `created_after` will be set to the + `created_before` - 2 days. If both `created_after` and `created_before` are + provided, and the difference between them is more than 2 days, then + `created_after` will be set to the `created_before` - 2 days. limit: Limits the number of messages to return. @@ -323,13 +328,13 @@ def retrieval_info( message_id: str, *, agent_id: str, - content_ids: List[str], + content_ids: SequenceNotStr[str], # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> RetrievalInfoResponse: """ Return metadata of the contents used to generate the response for a given @@ -394,20 +399,20 @@ async def create( agent_id: str, *, messages: Iterable[query_create_params.Message], - include_retrieval_content_text: bool | NotGiven = NOT_GIVEN, - retrievals_only: bool | NotGiven = NOT_GIVEN, - conversation_id: str | NotGiven = NOT_GIVEN, - documents_filters: query_create_params.DocumentsFilters | NotGiven = NOT_GIVEN, - llm_model_id: str | NotGiven = NOT_GIVEN, - override_configuration: query_create_params.OverrideConfiguration | NotGiven = NOT_GIVEN, - stream: bool | NotGiven = NOT_GIVEN, - structured_output: query_create_params.StructuredOutput | NotGiven = NOT_GIVEN, + include_retrieval_content_text: bool | Omit = omit, + retrievals_only: bool | Omit = omit, + conversation_id: str | Omit = omit, + documents_filters: query_create_params.DocumentsFilters | Omit = omit, + llm_model_id: str | Omit = omit, + override_configuration: query_create_params.OverrideConfiguration | Omit = omit, + stream: bool | Omit = omit, + structured_output: query_create_params.StructuredOutput | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> QueryResponse: """ Start a conversation with an `Agent` and receive its generated response, along @@ -435,8 +440,10 @@ async def create( query will be ignored. 
documents_filters: Defines an Optional custom metadata filter, which can be a list of filters or - nested filters. The expected input is a nested JSON object that can represent a - single filter or a composite (logical) combination of filters. + nested filters. Use **lowercase** for `value` and/or **field.keyword** for + `field` when not using `equals` operator. The expected input is a nested JSON + object that can represent a single filter or a composite (logical) combination + of filters. Unnested Example: @@ -525,27 +532,21 @@ async def feedback( *, feedback: Literal["thumbs_up", "thumbs_down", "flagged", "removed"], message_id: str, - content_id: str | NotGiven = NOT_GIVEN, - explanation: str | NotGiven = NOT_GIVEN, + content_id: str | Omit = omit, + explanation: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> object: + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> QueryFeedbackResponse: """Provide feedback for a generation or a retrieval. Feedback can be used to track overall `Agent` performance through the `Feedback` page in the Contextual UI, and as a basis for model fine-tuning. - If providing feedback on a retrieval, include the `message_id` from the `/query` - response, and a `content_id` returned in the query's `retrieval_contents` list. - - For feedback on generations, include `message_id` and do not include a - `content_id`. 
- Args: agent_id: ID of the agent for which to provide feedback @@ -583,25 +584,25 @@ async def feedback( options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=object, + cast_to=QueryFeedbackResponse, ) async def metrics( self, agent_id: str, *, - conversation_ids: List[str] | NotGiven = NOT_GIVEN, - created_after: Union[str, datetime] | NotGiven = NOT_GIVEN, - created_before: Union[str, datetime] | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - offset: int | NotGiven = NOT_GIVEN, - user_emails: List[str] | NotGiven = NOT_GIVEN, + conversation_ids: SequenceNotStr[str] | Omit = omit, + created_after: Union[str, datetime] | Omit = omit, + created_before: Union[str, datetime] | Omit = omit, + limit: int | Omit = omit, + offset: int | Omit = omit, + user_emails: SequenceNotStr[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> QueryMetricsResponse: """Returns usage and user-provided feedback data. @@ -615,7 +616,15 @@ async def metrics( created_after: Filters messages that are created after the specified timestamp. - created_before: Filters messages that are created before specified timestamp. + created_before: Filters messages that are created before specified timestamp. If both + `created_after` and `created_before` are not provided, then `created_before` + will be set to the current time and `created_after` will be set to the + `created_before` - 2 days. 
If only `created_after` is provided, then + `created_before` will be set to the `created_after` + 2 days. If only + `created_before` is provided, then `created_after` will be set to the + `created_before` - 2 days. If both `created_after` and `created_before` are + provided, and the difference between them is more than 2 days, then + `created_after` will be set to the `created_before` - 2 days. limit: Limits the number of messages to return. @@ -660,13 +669,13 @@ async def retrieval_info( message_id: str, *, agent_id: str, - content_ids: List[str], + content_ids: SequenceNotStr[str], # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> RetrievalInfoResponse: """ Return metadata of the contents used to generate the response for a given diff --git a/src/contextual/resources/datastores/datastores.py b/src/contextual/resources/datastores/datastores.py index 12d34683..3a0b964c 100644 --- a/src/contextual/resources/datastores/datastores.py +++ b/src/contextual/resources/datastores/datastores.py @@ -5,7 +5,7 @@ import httpx from ...types import datastore_list_params, datastore_create_params, datastore_update_params -from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from ..._types import Body, Omit, Query, Headers, NotGiven, omit, not_given from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from .documents import ( @@ -61,13 +61,13 @@ def create( self, *, name: str, - configuration: datastore_create_params.Configuration | NotGiven = NOT_GIVEN, + configuration: datastore_create_params.Configuration | Omit = 
omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> CreateDatastoreResponse: """Create a new `Datastore`. @@ -119,14 +119,14 @@ def update( self, datastore_id: str, *, - configuration: datastore_update_params.Configuration | NotGiven = NOT_GIVEN, - name: str | NotGiven = NOT_GIVEN, + configuration: datastore_update_params.Configuration | Omit = omit, + name: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DatastoreUpdateResponse: """ Edit Datastore Configuration @@ -167,15 +167,15 @@ def update( def list( self, *, - agent_id: str | NotGiven = NOT_GIVEN, - cursor: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, + agent_id: str | Omit = omit, + cursor: str | Omit = omit, + limit: int | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> SyncDatastoresPage[Datastore]: """ Retrieve a list of `Datastores`. @@ -230,7 +230,7 @@ def delete( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> object: """Delete a given `Datastore`, including all the documents ingested into it. @@ -270,7 +270,7 @@ def metadata( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DatastoreMetadata: """ Get the details of a given `Datastore`, including its name, create time, and the @@ -306,7 +306,7 @@ def reset( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> object: """Reset the give `Datastore`. @@ -363,13 +363,13 @@ async def create( self, *, name: str, - configuration: datastore_create_params.Configuration | NotGiven = NOT_GIVEN, + configuration: datastore_create_params.Configuration | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> CreateDatastoreResponse: """Create a new `Datastore`. @@ -421,14 +421,14 @@ async def update( self, datastore_id: str, *, - configuration: datastore_update_params.Configuration | NotGiven = NOT_GIVEN, - name: str | NotGiven = NOT_GIVEN, + configuration: datastore_update_params.Configuration | Omit = omit, + name: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DatastoreUpdateResponse: """ Edit Datastore Configuration @@ -469,15 +469,15 @@ async def update( def list( self, *, - agent_id: str | NotGiven = NOT_GIVEN, - cursor: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, + agent_id: str | Omit = omit, + cursor: str | Omit = omit, + limit: int | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> AsyncPaginator[Datastore, AsyncDatastoresPage[Datastore]]: """ Retrieve a list of `Datastores`. 
@@ -532,7 +532,7 @@ async def delete( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> object: """Delete a given `Datastore`, including all the documents ingested into it. @@ -572,7 +572,7 @@ async def metadata( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DatastoreMetadata: """ Get the details of a given `Datastore`, including its name, create time, and the @@ -608,7 +608,7 @@ async def reset( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> object: """Reset the give `Datastore`. 
diff --git a/src/contextual/resources/datastores/documents.py b/src/contextual/resources/datastores/documents.py index 5fb3604e..98c2b702 100644 --- a/src/contextual/resources/datastores/documents.py +++ b/src/contextual/resources/datastores/documents.py @@ -2,13 +2,13 @@ from __future__ import annotations -from typing import Dict, List, Union, Mapping, cast +from typing import Dict, List, Union, Mapping, Iterable, cast from datetime import datetime from typing_extensions import Literal import httpx -from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes +from ..._types import Body, Omit, Query, Headers, NotGiven, FileTypes, omit, not_given from ..._utils import extract_files, maybe_transform, deepcopy_minimal, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource @@ -57,19 +57,19 @@ def list( self, datastore_id: str, *, - cursor: str | NotGiven = NOT_GIVEN, - document_name_prefix: str | NotGiven = NOT_GIVEN, + cursor: str | Omit = omit, + document_name_prefix: str | Omit = omit, ingestion_job_status: List[Literal["pending", "processing", "retrying", "completed", "failed", "cancelled"]] - | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - uploaded_after: Union[str, datetime] | NotGiven = NOT_GIVEN, - uploaded_before: Union[str, datetime] | NotGiven = NOT_GIVEN, + | Omit = omit, + limit: int | Omit = omit, + uploaded_after: Union[str, datetime] | Omit = omit, + uploaded_before: Union[str, datetime] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> SyncDocumentsPage[DocumentMetadata]: """ Get list of documents in a given `Datastore`, including document `id`, `name`, @@ -140,7 +140,7 @@ def delete( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> object: """Delete a given document from its `Datastore`. @@ -176,13 +176,13 @@ def get_parse_result( document_id: str, *, datastore_id: str, - output_types: List[Literal["markdown-document", "markdown-per-page", "blocks-per-page"]] | NotGiven = NOT_GIVEN, + output_types: List[Literal["markdown-document", "markdown-per-page", "blocks-per-page"]] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DocumentGetParseResultResponse: """ Get the parse results that are generated during ingestion for a given document. @@ -232,13 +232,14 @@ def ingest( datastore_id: str, *, file: FileTypes, - metadata: str | NotGiven = NOT_GIVEN, + configuration: str | Omit = omit, + metadata: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> IngestionResponse: """Ingest a document into a given `Datastore`. @@ -260,25 +261,29 @@ def ingest( file: File to ingest. - metadata: Metadata request in JSON format. `custom_metadata` is a flat dictionary - containing one or more key-value pairs, where each value must be a primitive - type (`str`, `bool`, `float`, or `int`). The default maximum metadata fields - that can be used is 15, contact support if more is needed.The combined size of - the metadata must not exceed **2 KB** when encoded as JSON.The strings with date - format must stay in date format or be avoided if not in date format.The - `custom_metadata.url` field is automatically included in returned attributions - during query time, if provided. - - **Example Request Body:** - - ```json - { - "custom_metadata": { - "topic": "science", - "difficulty": 3 - } - } - ``` + configuration: Overrides the datastore's default configuration for this specific document. This + allows applying optimized settings tailored to the document's characteristics + without changing the global datastore configuration. + + metadata: Metadata request in stringified JSON format. `custom_metadata` is a flat + dictionary containing one or more key-value pairs, where each value must be a + primitive type (`str`, `bool`, `float`, or `int`). The default maximum metadata + fields that can be used is 15, contact support@contextual.ai if more is needed. + The combined size of the metadata must not exceed **2 KB** when encoded as JSON. + The strings with date format must stay in date format or be avoided if not in + date format. The `custom_metadata.url` or `link` field is automatically included + in returned attributions during query time, if provided. 
+ + **Example Request Body (as returned by `json.dumps`):** + + ```json + "{{ + \"custom_metadata\": {{ + \"topic\": \"science\", + \"difficulty\": 3 + }} + }}" + ``` extra_headers: Send extra headers @@ -293,6 +298,7 @@ def ingest( body = deepcopy_minimal( { "file": file, + "configuration": configuration, "metadata": metadata, } ) @@ -321,7 +327,7 @@ def metadata( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DocumentMetadata: """ Get details of a given document, including its `name` and ingestion job @@ -357,14 +363,14 @@ def set_metadata( document_id: str, *, datastore_id: str, - custom_metadata: Dict[str, Union[bool, float, str]] | NotGiven = NOT_GIVEN, - custom_metadata_config: Dict[str, document_set_metadata_params.CustomMetadataConfig] | NotGiven = NOT_GIVEN, + custom_metadata: Dict[str, Union[bool, float, str, Iterable[float]]] | Omit = omit, + custom_metadata_config: Dict[str, document_set_metadata_params.CustomMetadataConfig] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DocumentMetadata: """ Post details of a given document that will enrich the chunk and be added to the @@ -385,17 +391,12 @@ def set_metadata( used is 15, contact support if more is needed. custom_metadata_config: A dictionary mapping metadata field names to the configuration to use for each - field. 
- - - If a metadata field is not present in the dictionary, the default configuration will be used. - - - If the dictionary is not provided, metadata will be added in chunks but will not be retrievable. - - Limits: - Maximum characters per metadata field (for prompt or rerank): 400 - - - Maximum number of metadata fields (for prompt or retrieval): 10 - - Contact support@contextual.ai to request quota increases. + field. If a metadata field is not present in the dictionary, the default + configuration will be used. If the dictionary is not provided, metadata will be + added in context for rerank and generation but will not be returned back to the + user in retrievals in query API. Limits: - Maximum characters per metadata field + (for prompt or rerank): **400** - Maximum number of metadata fields (for prompt + or retrieval): **10** Contact support@contextual.ai to request quota increases. extra_headers: Send extra headers @@ -409,7 +410,7 @@ def set_metadata( raise ValueError(f"Expected a non-empty value for `datastore_id` but received {datastore_id!r}") if not document_id: raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}") - return self._post( + return self._put( f"/datastores/{datastore_id}/documents/{document_id}/metadata", body=maybe_transform( { @@ -449,19 +450,19 @@ def list( self, datastore_id: str, *, - cursor: str | NotGiven = NOT_GIVEN, - document_name_prefix: str | NotGiven = NOT_GIVEN, + cursor: str | Omit = omit, + document_name_prefix: str | Omit = omit, ingestion_job_status: List[Literal["pending", "processing", "retrying", "completed", "failed", "cancelled"]] - | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - uploaded_after: Union[str, datetime] | NotGiven = NOT_GIVEN, - uploaded_before: Union[str, datetime] | NotGiven = NOT_GIVEN, + | Omit = omit, + limit: int | Omit = omit, + uploaded_after: Union[str, datetime] | Omit = omit, + uploaded_before: Union[str, datetime] | Omit = omit, # Use the 
following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> AsyncPaginator[DocumentMetadata, AsyncDocumentsPage[DocumentMetadata]]: """ Get list of documents in a given `Datastore`, including document `id`, `name`, @@ -532,7 +533,7 @@ async def delete( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> object: """Delete a given document from its `Datastore`. @@ -568,13 +569,13 @@ async def get_parse_result( document_id: str, *, datastore_id: str, - output_types: List[Literal["markdown-document", "markdown-per-page", "blocks-per-page"]] | NotGiven = NOT_GIVEN, + output_types: List[Literal["markdown-document", "markdown-per-page", "blocks-per-page"]] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DocumentGetParseResultResponse: """ Get the parse results that are generated during ingestion for a given document. 
@@ -624,13 +625,14 @@ async def ingest( datastore_id: str, *, file: FileTypes, - metadata: str | NotGiven = NOT_GIVEN, + configuration: str | Omit = omit, + metadata: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> IngestionResponse: """Ingest a document into a given `Datastore`. @@ -652,25 +654,29 @@ async def ingest( file: File to ingest. - metadata: Metadata request in JSON format. `custom_metadata` is a flat dictionary - containing one or more key-value pairs, where each value must be a primitive - type (`str`, `bool`, `float`, or `int`). The default maximum metadata fields - that can be used is 15, contact support if more is needed.The combined size of - the metadata must not exceed **2 KB** when encoded as JSON.The strings with date - format must stay in date format or be avoided if not in date format.The - `custom_metadata.url` field is automatically included in returned attributions - during query time, if provided. - - **Example Request Body:** - - ```json - { - "custom_metadata": { - "topic": "science", - "difficulty": 3 - } - } - ``` + configuration: Overrides the datastore's default configuration for this specific document. This + allows applying optimized settings tailored to the document's characteristics + without changing the global datastore configuration. + + metadata: Metadata request in stringified JSON format. `custom_metadata` is a flat + dictionary containing one or more key-value pairs, where each value must be a + primitive type (`str`, `bool`, `float`, or `int`). 
The default maximum metadata + fields that can be used is 15, contact support@contextual.ai if more is needed. + The combined size of the metadata must not exceed **2 KB** when encoded as JSON. + The strings with date format must stay in date format or be avoided if not in + date format. The `custom_metadata.url` or `link` field is automatically included + in returned attributions during query time, if provided. + + **Example Request Body (as returned by `json.dumps`):** + + ```json + "{{ + \"custom_metadata\": {{ + \"topic\": \"science\", + \"difficulty\": 3 + }} + }}" + ``` extra_headers: Send extra headers @@ -685,6 +691,7 @@ async def ingest( body = deepcopy_minimal( { "file": file, + "configuration": configuration, "metadata": metadata, } ) @@ -713,7 +720,7 @@ async def metadata( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DocumentMetadata: """ Get details of a given document, including its `name` and ingestion job @@ -749,14 +756,14 @@ async def set_metadata( document_id: str, *, datastore_id: str, - custom_metadata: Dict[str, Union[bool, float, str]] | NotGiven = NOT_GIVEN, - custom_metadata_config: Dict[str, document_set_metadata_params.CustomMetadataConfig] | NotGiven = NOT_GIVEN, + custom_metadata: Dict[str, Union[bool, float, str, Iterable[float]]] | Omit = omit, + custom_metadata_config: Dict[str, document_set_metadata_params.CustomMetadataConfig] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DocumentMetadata: """ Post details of a given document that will enrich the chunk and be added to the @@ -777,17 +784,12 @@ async def set_metadata( used is 15, contact support if more is needed. custom_metadata_config: A dictionary mapping metadata field names to the configuration to use for each - field. - - - If a metadata field is not present in the dictionary, the default configuration will be used. - - - If the dictionary is not provided, metadata will be added in chunks but will not be retrievable. - - Limits: - Maximum characters per metadata field (for prompt or rerank): 400 - - - Maximum number of metadata fields (for prompt or retrieval): 10 - - Contact support@contextual.ai to request quota increases. + field. If a metadata field is not present in the dictionary, the default + configuration will be used. If the dictionary is not provided, metadata will be + added in context for rerank and generation but will not be returned back to the + user in retrievals in query API. Limits: - Maximum characters per metadata field + (for prompt or rerank): **400** - Maximum number of metadata fields (for prompt + or retrieval): **10** Contact support@contextual.ai to request quota increases. 
extra_headers: Send extra headers @@ -801,7 +803,7 @@ async def set_metadata( raise ValueError(f"Expected a non-empty value for `datastore_id` but received {datastore_id!r}") if not document_id: raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}") - return await self._post( + return await self._put( f"/datastores/{datastore_id}/documents/{document_id}/metadata", body=await async_maybe_transform( { diff --git a/src/contextual/resources/generate.py b/src/contextual/resources/generate.py index 3dae6172..1da6b689 100644 --- a/src/contextual/resources/generate.py +++ b/src/contextual/resources/generate.py @@ -2,12 +2,12 @@ from __future__ import annotations -from typing import List, Iterable +from typing import Iterable import httpx from ..types import generate_create_params -from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from .._types import Body, Omit, Query, Headers, NotGiven, SequenceNotStr, omit, not_given from .._utils import maybe_transform, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource @@ -46,20 +46,20 @@ def with_streaming_response(self) -> GenerateResourceWithStreamingResponse: def create( self, *, - knowledge: List[str], + knowledge: SequenceNotStr[str], messages: Iterable[generate_create_params.Message], model: str, - avoid_commentary: bool | NotGiven = NOT_GIVEN, - max_new_tokens: int | NotGiven = NOT_GIVEN, - system_prompt: str | NotGiven = NOT_GIVEN, - temperature: float | NotGiven = NOT_GIVEN, - top_p: float | NotGiven = NOT_GIVEN, + avoid_commentary: bool | Omit = omit, + max_new_tokens: int | Omit = omit, + system_prompt: str | Omit = omit, + temperature: float | Omit = omit, + top_p: float | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> GenerateCreateResponse: """ Generate a response using Contextual's Grounded Language Model (GLM), an LLM @@ -154,20 +154,20 @@ def with_streaming_response(self) -> AsyncGenerateResourceWithStreamingResponse: async def create( self, *, - knowledge: List[str], + knowledge: SequenceNotStr[str], messages: Iterable[generate_create_params.Message], model: str, - avoid_commentary: bool | NotGiven = NOT_GIVEN, - max_new_tokens: int | NotGiven = NOT_GIVEN, - system_prompt: str | NotGiven = NOT_GIVEN, - temperature: float | NotGiven = NOT_GIVEN, - top_p: float | NotGiven = NOT_GIVEN, + avoid_commentary: bool | Omit = omit, + max_new_tokens: int | Omit = omit, + system_prompt: str | Omit = omit, + temperature: float | Omit = omit, + top_p: float | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> GenerateCreateResponse: """ Generate a response using Contextual's Grounded Language Model (GLM), an LLM diff --git a/src/contextual/resources/lmunit.py b/src/contextual/resources/lmunit.py index ceb000cc..99e677d8 100644 --- a/src/contextual/resources/lmunit.py +++ b/src/contextual/resources/lmunit.py @@ -5,7 +5,7 @@ import httpx from ..types import lmunit_create_params -from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from .._types import Body, Query, Headers, NotGiven, not_given from .._utils import maybe_transform, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource @@ -52,7 +52,7 @@ def create( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> LMUnitCreateResponse: """ Given a `query`, `response`, and a `unit_test`, return the response's `score` on @@ -129,7 +129,7 @@ async def create( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> LMUnitCreateResponse: """ Given a `query`, `response`, and a `unit_test`, return the response's `score` on diff --git a/src/contextual/resources/parse.py b/src/contextual/resources/parse.py index 2485b22f..eeb87af2 100644 --- a/src/contextual/resources/parse.py +++ b/src/contextual/resources/parse.py @@ -9,7 +9,7 @@ import httpx from ..types import parse_jobs_params, parse_create_params, parse_job_results_params -from .._types import NOT_GIVEN, Body, 
Query, Headers, NotGiven, FileTypes +from .._types import Body, Omit, Query, Headers, NotGiven, FileTypes, omit, not_given from .._utils import extract_files, maybe_transform, deepcopy_minimal, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource @@ -52,18 +52,18 @@ def create( self, *, raw_file: FileTypes, - enable_document_hierarchy: bool | NotGiven = NOT_GIVEN, - enable_split_tables: bool | NotGiven = NOT_GIVEN, - figure_caption_mode: Literal["concise", "detailed"] | NotGiven = NOT_GIVEN, - max_split_table_cells: int | NotGiven = NOT_GIVEN, - page_range: str | NotGiven = NOT_GIVEN, - parse_mode: Literal["basic", "standard"] | NotGiven = NOT_GIVEN, + enable_document_hierarchy: bool | Omit = omit, + enable_split_tables: bool | Omit = omit, + figure_caption_mode: Literal["concise", "detailed"] | Omit = omit, + max_split_table_cells: int | Omit = omit, + page_range: str | Omit = omit, + parse_mode: Literal["basic", "standard"] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ParseCreateResponse: """Parse a file into a structured Markdown and/or JSON. @@ -141,13 +141,13 @@ def job_results( self, job_id: str, *, - output_types: List[Literal["markdown-document", "markdown-per-page", "blocks-per-page"]] | NotGiven = NOT_GIVEN, + output_types: List[Literal["markdown-document", "markdown-per-page", "blocks-per-page"]] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. 
# The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ParseJobResultsResponse: """ Get the results of a parse job. @@ -196,7 +196,7 @@ def job_status( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ParseJobStatusResponse: """ Get the status of a parse job. @@ -228,15 +228,15 @@ def job_status( def jobs( self, *, - cursor: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - uploaded_after: Union[str, datetime] | NotGiven = NOT_GIVEN, + cursor: str | Omit = omit, + limit: int | Omit = omit, + uploaded_after: Union[str, datetime] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ParseJobsResponse: """ Get list of parse jobs, sorted from most recent to oldest. 
@@ -307,18 +307,18 @@ async def create( self, *, raw_file: FileTypes, - enable_document_hierarchy: bool | NotGiven = NOT_GIVEN, - enable_split_tables: bool | NotGiven = NOT_GIVEN, - figure_caption_mode: Literal["concise", "detailed"] | NotGiven = NOT_GIVEN, - max_split_table_cells: int | NotGiven = NOT_GIVEN, - page_range: str | NotGiven = NOT_GIVEN, - parse_mode: Literal["basic", "standard"] | NotGiven = NOT_GIVEN, + enable_document_hierarchy: bool | Omit = omit, + enable_split_tables: bool | Omit = omit, + figure_caption_mode: Literal["concise", "detailed"] | Omit = omit, + max_split_table_cells: int | Omit = omit, + page_range: str | Omit = omit, + parse_mode: Literal["basic", "standard"] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ParseCreateResponse: """Parse a file into a structured Markdown and/or JSON. @@ -396,13 +396,13 @@ async def job_results( self, job_id: str, *, - output_types: List[Literal["markdown-document", "markdown-per-page", "blocks-per-page"]] | NotGiven = NOT_GIVEN, + output_types: List[Literal["markdown-document", "markdown-per-page", "blocks-per-page"]] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ParseJobResultsResponse: """ Get the results of a parse job. @@ -453,7 +453,7 @@ async def job_status( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ParseJobStatusResponse: """ Get the status of a parse job. @@ -485,15 +485,15 @@ async def job_status( async def jobs( self, *, - cursor: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - uploaded_after: Union[str, datetime] | NotGiven = NOT_GIVEN, + cursor: str | Omit = omit, + limit: int | Omit = omit, + uploaded_after: Union[str, datetime] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ParseJobsResponse: """ Get list of parse jobs, sorted from most recent to oldest. 
diff --git a/src/contextual/resources/rerank.py b/src/contextual/resources/rerank.py index 81d48cc0..1d643c5b 100644 --- a/src/contextual/resources/rerank.py +++ b/src/contextual/resources/rerank.py @@ -2,12 +2,10 @@ from __future__ import annotations -from typing import List - import httpx from ..types import rerank_create_params -from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from .._types import Body, Omit, Query, Headers, NotGiven, SequenceNotStr, omit, not_given from .._utils import maybe_transform, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource @@ -46,18 +44,18 @@ def with_streaming_response(self) -> RerankResourceWithStreamingResponse: def create( self, *, - documents: List[str], + documents: SequenceNotStr[str], model: str, query: str, - instruction: str | NotGiven = NOT_GIVEN, - metadata: List[str] | NotGiven = NOT_GIVEN, - top_n: int | NotGiven = NOT_GIVEN, + instruction: str | Omit = omit, + metadata: SequenceNotStr[str] | Omit = omit, + top_n: int | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> RerankCreateResponse: """ Rank a list of documents according to their relevance to a query primarily and @@ -152,18 +150,18 @@ def with_streaming_response(self) -> AsyncRerankResourceWithStreamingResponse: async def create( self, *, - documents: List[str], + documents: SequenceNotStr[str], model: str, query: str, - instruction: str | NotGiven = NOT_GIVEN, - metadata: List[str] | NotGiven = NOT_GIVEN, - top_n: int | NotGiven = NOT_GIVEN, + instruction: str | Omit = omit, + metadata: SequenceNotStr[str] | Omit = omit, + top_n: int | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> RerankCreateResponse: """ Rank a list of documents according to their relevance to a query primarily and diff --git a/src/contextual/resources/users.py b/src/contextual/resources/users.py index e699f9cc..c025598c 100644 --- a/src/contextual/resources/users.py +++ b/src/contextual/resources/users.py @@ -8,7 +8,7 @@ import httpx from ..types import user_list_params, user_invite_params, user_update_params, user_deactivate_params -from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given from .._utils import maybe_transform, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource @@ -51,29 +51,32 @@ def update( self, *, email: str, - agent_level_roles: List[Literal["AGENT_LEVEL_USER"]] | NotGiven = NOT_GIVEN, - is_tenant_admin: bool | NotGiven = NOT_GIVEN, - per_agent_roles: Iterable[user_update_params.PerAgentRole] | NotGiven = NOT_GIVEN, + agent_level_roles: List[Literal["AGENT_LEVEL_USER"]] | Omit = omit, + is_tenant_admin: bool | Omit = omit, + per_agent_roles: Iterable[user_update_params.PerAgentRole] | Omit = omit, roles: List[ Literal[ "VISITOR", "AGENT_USER", + "CUSTOMER_USER", "CUSTOMER_INTERNAL_USER", "CONTEXTUAL_STAFF_USER", "CONTEXTUAL_EXTERNAL_STAFF_USER", "CONTEXTUAL_INTERNAL_STAFF_USER", "TENANT_ADMIN", + "CUSTOMER_ADMIN", + "CONTEXTUAL_ADMIN", "SUPER_ADMIN", "SERVICE_ACCOUNT", ] ] - | NotGiven = NOT_GIVEN, + | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> object: """ Modify a given `User`. @@ -122,16 +125,16 @@ def update( def list( self, *, - cursor: str | NotGiven = NOT_GIVEN, - deactivated: bool | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - search: str | NotGiven = NOT_GIVEN, + cursor: str | Omit = omit, + deactivated: bool | Omit = omit, + limit: int | Omit = omit, + search: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> SyncUsersPage[User]: """ Retrieve a list of `users`. @@ -183,7 +186,7 @@ def deactivate( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> object: """ Delete a given `user`. @@ -218,7 +221,7 @@ def invite( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> InviteUsersResponse: """Invite users to the tenant. 
@@ -279,29 +282,32 @@ async def update( self, *, email: str, - agent_level_roles: List[Literal["AGENT_LEVEL_USER"]] | NotGiven = NOT_GIVEN, - is_tenant_admin: bool | NotGiven = NOT_GIVEN, - per_agent_roles: Iterable[user_update_params.PerAgentRole] | NotGiven = NOT_GIVEN, + agent_level_roles: List[Literal["AGENT_LEVEL_USER"]] | Omit = omit, + is_tenant_admin: bool | Omit = omit, + per_agent_roles: Iterable[user_update_params.PerAgentRole] | Omit = omit, roles: List[ Literal[ "VISITOR", "AGENT_USER", + "CUSTOMER_USER", "CUSTOMER_INTERNAL_USER", "CONTEXTUAL_STAFF_USER", "CONTEXTUAL_EXTERNAL_STAFF_USER", "CONTEXTUAL_INTERNAL_STAFF_USER", "TENANT_ADMIN", + "CUSTOMER_ADMIN", + "CONTEXTUAL_ADMIN", "SUPER_ADMIN", "SERVICE_ACCOUNT", ] ] - | NotGiven = NOT_GIVEN, + | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> object: """ Modify a given `User`. @@ -350,16 +356,16 @@ async def update( def list( self, *, - cursor: str | NotGiven = NOT_GIVEN, - deactivated: bool | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - search: str | NotGiven = NOT_GIVEN, + cursor: str | Omit = omit, + deactivated: bool | Omit = omit, + limit: int | Omit = omit, + search: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> AsyncPaginator[User, AsyncUsersPage[User]]: """ Retrieve a list of `users`. @@ -411,7 +417,7 @@ async def deactivate( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> object: """ Delete a given `user`. @@ -446,7 +452,7 @@ async def invite( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> InviteUsersResponse: """Invite users to the tenant. diff --git a/src/contextual/types/__init__.py b/src/contextual/types/__init__.py index 6ac75ea7..1dec3682 100644 --- a/src/contextual/types/__init__.py +++ b/src/contextual/types/__init__.py @@ -55,13 +55,13 @@ # This ensures that, when building the deferred (due to cyclical references) model schema, # Pydantic can resolve the necessary references. # See: https://github.com/pydantic/pydantic/issues/11250 for more context. 
-if _compat.PYDANTIC_V2: - datastores.composite_metadata_filter.CompositeMetadataFilter.model_rebuild(_parent_namespace_depth=0) - agent_configs.AgentConfigs.model_rebuild(_parent_namespace_depth=0) - agent_metadata.AgentMetadata.model_rebuild(_parent_namespace_depth=0) - filter_and_rerank_config.FilterAndRerankConfig.model_rebuild(_parent_namespace_depth=0) -else: +if _compat.PYDANTIC_V1: datastores.composite_metadata_filter.CompositeMetadataFilter.update_forward_refs() # type: ignore agent_configs.AgentConfigs.update_forward_refs() # type: ignore agent_metadata.AgentMetadata.update_forward_refs() # type: ignore filter_and_rerank_config.FilterAndRerankConfig.update_forward_refs() # type: ignore +else: + datastores.composite_metadata_filter.CompositeMetadataFilter.model_rebuild(_parent_namespace_depth=0) + agent_configs.AgentConfigs.model_rebuild(_parent_namespace_depth=0) + agent_metadata.AgentMetadata.model_rebuild(_parent_namespace_depth=0) + filter_and_rerank_config.FilterAndRerankConfig.model_rebuild(_parent_namespace_depth=0) diff --git a/src/contextual/types/agent_configs.py b/src/contextual/types/agent_configs.py index 6acecb7e..1b7f9d43 100644 --- a/src/contextual/types/agent_configs.py +++ b/src/contextual/types/agent_configs.py @@ -9,7 +9,15 @@ from .retrieval_config import RetrievalConfig from .generate_response_config import GenerateResponseConfig -__all__ = ["AgentConfigs", "ReformulationConfig"] +__all__ = ["AgentConfigs", "ACLConfig", "ReformulationConfig", "TranslationConfig"] + + +class ACLConfig(BaseModel): + acl_active: Optional[bool] = None + """Whether to enable ACL.""" + + acl_yaml: Optional[str] = None + """The YAML file to use for ACL.""" class ReformulationConfig(BaseModel): @@ -26,7 +34,18 @@ class ReformulationConfig(BaseModel): """The prompt to use for query expansion.""" +class TranslationConfig(BaseModel): + translate_confidence: Optional[float] = None + """The confidence threshold for translation.""" + + translate_needed: 
Optional[bool] = None + """Whether to enable translation for the agent's responses.""" + + class AgentConfigs(BaseModel): + acl_config: Optional[ACLConfig] = None + """Parameters that affect the agent's ACL workflow""" + filter_and_rerank_config: Optional["FilterAndRerankConfig"] = None """Parameters that affect filtering and reranking of retrieved knowledge""" @@ -42,5 +61,8 @@ class AgentConfigs(BaseModel): retrieval_config: Optional[RetrievalConfig] = None """Parameters that affect how the agent retrieves from datastore(s)""" + translation_config: Optional[TranslationConfig] = None + """Parameters that affect the agent's translation workflow""" + from .filter_and_rerank_config import FilterAndRerankConfig diff --git a/src/contextual/types/agent_configs_param.py b/src/contextual/types/agent_configs_param.py index 909c860e..822e02a0 100644 --- a/src/contextual/types/agent_configs_param.py +++ b/src/contextual/types/agent_configs_param.py @@ -8,7 +8,15 @@ from .retrieval_config_param import RetrievalConfigParam from .generate_response_config_param import GenerateResponseConfigParam -__all__ = ["AgentConfigsParam", "ReformulationConfig"] +__all__ = ["AgentConfigsParam", "ACLConfig", "ReformulationConfig", "TranslationConfig"] + + +class ACLConfig(TypedDict, total=False): + acl_active: bool + """Whether to enable ACL.""" + + acl_yaml: str + """The YAML file to use for ACL.""" class ReformulationConfig(TypedDict, total=False): @@ -25,7 +33,18 @@ class ReformulationConfig(TypedDict, total=False): """The prompt to use for query expansion.""" +class TranslationConfig(TypedDict, total=False): + translate_confidence: float + """The confidence threshold for translation.""" + + translate_needed: bool + """Whether to enable translation for the agent's responses.""" + + class AgentConfigsParam(TypedDict, total=False): + acl_config: ACLConfig + """Parameters that affect the agent's ACL workflow""" + filter_and_rerank_config: "FilterAndRerankConfigParam" """Parameters that 
affect filtering and reranking of retrieved knowledge""" @@ -41,5 +60,8 @@ class AgentConfigsParam(TypedDict, total=False): retrieval_config: RetrievalConfigParam """Parameters that affect how the agent retrieves from datastore(s)""" + translation_config: TranslationConfig + """Parameters that affect the agent's translation workflow""" + from .filter_and_rerank_config_param import FilterAndRerankConfigParam diff --git a/src/contextual/types/agent_create_params.py b/src/contextual/types/agent_create_params.py index f6613de1..90503f38 100644 --- a/src/contextual/types/agent_create_params.py +++ b/src/contextual/types/agent_create_params.py @@ -2,9 +2,10 @@ from __future__ import annotations -from typing import List from typing_extensions import Required, TypedDict +from .._types import SequenceNotStr + __all__ = ["AgentCreateParams"] @@ -15,7 +16,7 @@ class AgentCreateParams(TypedDict, total=False): agent_configs: "AgentConfigsParam" """The following advanced parameters are experimental and subject to change.""" - datastore_ids: List[str] + datastore_ids: SequenceNotStr[str] """The IDs of the datastore to associate with this agent.""" description: str @@ -36,7 +37,7 @@ class AgentCreateParams(TypedDict, total=False): retrievals that can be used to answer a query. """ - suggested_queries: List[str] + suggested_queries: SequenceNotStr[str] """ These queries will show up as suggestions in the Contextual UI when users load the agent. We recommend including common queries that users will ask, as well as @@ -51,5 +52,8 @@ class AgentCreateParams(TypedDict, total=False): exactly. 
""" + template_name: str + """The template defining the base configuration for the agent.""" + from .agent_configs_param import AgentConfigsParam diff --git a/src/contextual/types/agent_metadata.py b/src/contextual/types/agent_metadata.py index a6b12e88..6598f8e6 100644 --- a/src/contextual/types/agent_metadata.py +++ b/src/contextual/types/agent_metadata.py @@ -45,14 +45,6 @@ class AgentMetadata(BaseModel): given query and filters out irrelevant chunks. This prompt is applied per chunk. """ - llm_model_id: Optional[str] = None - """The model ID to use for generation. - - Tuned models can only be used for the agents on which they were tuned. If no - model is specified, the default model is used. Set to `default` to switch from a - tuned model to the default model. - """ - multiturn_system_prompt: Optional[str] = None """Instructions on how the agent should handle multi-turn conversations.""" diff --git a/src/contextual/types/agent_metadata_response.py b/src/contextual/types/agent_metadata_response.py index e514568e..0bb47340 100644 --- a/src/contextual/types/agent_metadata_response.py +++ b/src/contextual/types/agent_metadata_response.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import List, Union, Optional +from typing import Dict, List, Union, Optional from typing_extensions import TypeAlias from .._models import BaseModel @@ -30,7 +30,7 @@ class GetTwilightAgentResponse(BaseModel): template_name: str - agent_configs: Optional[object] = None + agent_configs: Optional[Dict[str, object]] = None """The following advanced parameters are experimental and subject to change.""" agent_usages: Optional[GetTwilightAgentResponseAgentUsages] = None diff --git a/src/contextual/types/agent_update_params.py b/src/contextual/types/agent_update_params.py index 9f58a1c3..31dc0fac 100644 --- a/src/contextual/types/agent_update_params.py +++ b/src/contextual/types/agent_update_params.py @@ -2,9 +2,10 @@ from __future__ import annotations -from typing import 
List from typing_extensions import TypedDict +from .._types import SequenceNotStr + __all__ = ["AgentUpdateParams"] @@ -12,33 +13,31 @@ class AgentUpdateParams(TypedDict, total=False): agent_configs: "AgentConfigsParam" """The following advanced parameters are experimental and subject to change.""" - datastore_ids: List[str] + datastore_ids: SequenceNotStr[str] """IDs of the datastore to associate with the agent.""" + description: str + """Description of the agent""" + filter_prompt: str """ The prompt to an LLM which determines whether retrieved chunks are relevant to a given query and filters out irrelevant chunks. """ - llm_model_id: str - """The model ID to use for generation. - - Tuned models can only be used for the agents on which they were tuned. If no - model is specified, the default model is used. Set to `default` to switch from a - tuned model to the default model. - """ - multiturn_system_prompt: str """Instructions on how the agent should handle multi-turn conversations.""" + name: str + """Name of the agent""" + no_retrieval_system_prompt: str """ Instructions on how the agent should respond when there are no relevant retrievals that can be used to answer a query. """ - suggested_queries: List[str] + suggested_queries: SequenceNotStr[str] """ These queries will show up as suggestions in the Contextual UI when users load the agent. 
We recommend including common queries that users will ask, as well as diff --git a/src/contextual/types/agents/__init__.py b/src/contextual/types/agents/__init__.py index 561c07db..6f824035 100644 --- a/src/contextual/types/agents/__init__.py +++ b/src/contextual/types/agents/__init__.py @@ -7,5 +7,6 @@ from .query_metrics_params import QueryMetricsParams as QueryMetricsParams from .query_feedback_params import QueryFeedbackParams as QueryFeedbackParams from .query_metrics_response import QueryMetricsResponse as QueryMetricsResponse +from .query_feedback_response import QueryFeedbackResponse as QueryFeedbackResponse from .retrieval_info_response import RetrievalInfoResponse as RetrievalInfoResponse from .query_retrieval_info_params import QueryRetrievalInfoParams as QueryRetrievalInfoParams diff --git a/src/contextual/types/agents/query_create_params.py b/src/contextual/types/agents/query_create_params.py index 550b0f1f..ed08da17 100644 --- a/src/contextual/types/agents/query_create_params.py +++ b/src/contextual/types/agents/query_create_params.py @@ -2,9 +2,10 @@ from __future__ import annotations -from typing import Union, Iterable +from typing import Dict, Union, Iterable from typing_extensions import Literal, Required, TypeAlias, TypedDict +from ..._types import SequenceNotStr from ..datastores.base_metadata_filter_param import BaseMetadataFilterParam __all__ = ["QueryCreateParams", "Message", "DocumentsFilters", "OverrideConfiguration", "StructuredOutput"] @@ -44,8 +45,10 @@ class QueryCreateParams(TypedDict, total=False): documents_filters: DocumentsFilters """ Defines an Optional custom metadata filter, which can be a list of filters or - nested filters. The expected input is a nested JSON object that can represent a - single filter or a composite (logical) combination of filters. + nested filters. 
Use **lowercase** for `value` and/or **field.keyword** for + `field` when not using `equals` operator.The expected input is a nested JSON + object that can represent a single filter or a composite (logical) combination + of filters. Unnested Example: @@ -105,6 +108,9 @@ class Message(TypedDict, total=False): role: Required[Literal["user", "system", "assistant", "knowledge"]] """Role of the sender""" + custom_tags: SequenceNotStr[str] + """Custom tags for the message""" + DocumentsFilters: TypeAlias = Union[BaseMetadataFilterParam, "CompositeMetadataFilterParam"] @@ -178,7 +184,7 @@ class OverrideConfiguration(TypedDict, total=False): class StructuredOutput(TypedDict, total=False): - json_schema: Required[object] + json_schema: Required[Dict[str, object]] """The output json structure.""" type: Literal["JSON"] diff --git a/src/contextual/types/agents/query_feedback_response.py b/src/contextual/types/agents/query_feedback_response.py new file mode 100644 index 00000000..303902b3 --- /dev/null +++ b/src/contextual/types/agents/query_feedback_response.py @@ -0,0 +1,10 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +from ..._models import BaseModel + +__all__ = ["QueryFeedbackResponse"] + + +class QueryFeedbackResponse(BaseModel): + feedback_id: str + """ID of the submitted or updated feedback.""" diff --git a/src/contextual/types/agents/query_metrics_params.py b/src/contextual/types/agents/query_metrics_params.py index a24c2c90..38e8f22e 100644 --- a/src/contextual/types/agents/query_metrics_params.py +++ b/src/contextual/types/agents/query_metrics_params.py @@ -2,24 +2,35 @@ from __future__ import annotations -from typing import List, Union +from typing import Union from datetime import datetime from typing_extensions import Annotated, TypedDict +from ..._types import SequenceNotStr from ..._utils import PropertyInfo __all__ = ["QueryMetricsParams"] class QueryMetricsParams(TypedDict, total=False): - conversation_ids: List[str] + conversation_ids: SequenceNotStr[str] """Filter messages by conversation ids.""" created_after: Annotated[Union[str, datetime], PropertyInfo(format="iso8601")] """Filters messages that are created after the specified timestamp.""" created_before: Annotated[Union[str, datetime], PropertyInfo(format="iso8601")] - """Filters messages that are created before specified timestamp.""" + """Filters messages that are created before specified timestamp. + + If both `created_after` and `created_before` are not provided, then + `created_before` will be set to the current time and `created_after` will be set + to the `created_before` - 2 days. If only `created_after` is provided, then + `created_before` will be set to the `created_after` + 2 days. If only + `created_before` is provided, then `created_after` will be set to the + `created_before` - 2 days. If both `created_after` and `created_before` are + provided, and the difference between them is more than 2 days, then + `created_after` will be set to the `created_before` - 2 days. 
+ """ limit: int """Limits the number of messages to return.""" @@ -27,5 +38,5 @@ class QueryMetricsParams(TypedDict, total=False): offset: int """Offset for pagination.""" - user_emails: List[str] + user_emails: SequenceNotStr[str] """Filter messages by users.""" diff --git a/src/contextual/types/agents/query_response.py b/src/contextual/types/agents/query_response.py index aca2d227..82e1b411 100644 --- a/src/contextual/types/agents/query_response.py +++ b/src/contextual/types/agents/query_response.py @@ -49,12 +49,17 @@ class RetrievalContentCtxlMetadata(BaseModel): section_title: Optional[str] = None """Title of the section.""" - __pydantic_extra__: Dict[str, object] = FieldInfo(init=False) # pyright: ignore[reportIncompatibleVariableOverride] if TYPE_CHECKING: + # Some versions of Pydantic <2.8.0 have a bug and don’t allow assigning a + # value to this field, so for compatibility we avoid doing it at runtime. + __pydantic_extra__: Dict[str, object] = FieldInfo(init=False) # pyright: ignore[reportIncompatibleVariableOverride] + # Stub to indicate that arbitrary properties are accepted. # To access properties that are not valid identifiers you can use `getattr`, e.g. # `getattr(obj, '$type')` def __getattr__(self, attr: str) -> object: ... 
+ else: + __pydantic_extra__: Dict[str, object] class RetrievalContentCustomMetadataConfig(BaseModel): @@ -102,7 +107,7 @@ class RetrievalContent(BaseModel): ctxl_metadata: Optional[RetrievalContentCtxlMetadata] = None """Default metadata from the retrieval""" - custom_metadata: Optional[Dict[str, Union[bool, float, str]]] = None + custom_metadata: Optional[Dict[str, Union[bool, float, str, List[float]]]] = None """ Custom metadata for the document, provided by the user at ingestion time.Must be a JSON-serializable dictionary with string keys and simple primitive values @@ -116,21 +121,17 @@ class RetrievalContent(BaseModel): custom_metadata_config: Optional[Dict[str, RetrievalContentCustomMetadataConfig]] = None """ A dictionary mapping metadata field names to the configuration to use for each - field. - - - If a metadata field is not present in the dictionary, the default configuration will be used. - - - If the dictionary is not provided, metadata will be added in chunks but will not be retrievable. - - - Limits: - Maximum characters per metadata field (for prompt or rerank): 400 - - - Maximum number of metadata fields (for prompt or retrieval): 10 - - - Contact support@contextual.ai to request quota increases. + field. If a metadata field is not present in the dictionary, the default + configuration will be used. If the dictionary is not provided, metadata will be + added in context for rerank and generation but will not be returned back to the + user in retrievals in query API. Limits: - Maximum characters per metadata field + (for prompt or rerank): **400** - Maximum number of metadata fields (for prompt + or retrieval): **10** Contact support@contextual.ai to request quota increases. 
""" + datastore_id: Optional[str] = None + """Unique identifier of the datastore""" + number: Optional[int] = None """Index of the retrieved item in the retrieval_contents list (starting from 1)""" @@ -173,6 +174,9 @@ class Message(BaseModel): role: Literal["user", "system", "assistant", "knowledge"] """Role of the sender""" + custom_tags: Optional[List[str]] = None + """Custom tags for the message""" + class QueryResponse(BaseModel): conversation_id: str diff --git a/src/contextual/types/agents/query_retrieval_info_params.py b/src/contextual/types/agents/query_retrieval_info_params.py index 14d5fef7..ffa7b004 100644 --- a/src/contextual/types/agents/query_retrieval_info_params.py +++ b/src/contextual/types/agents/query_retrieval_info_params.py @@ -2,9 +2,10 @@ from __future__ import annotations -from typing import List from typing_extensions import Required, TypedDict +from ..._types import SequenceNotStr + __all__ = ["QueryRetrievalInfoParams"] @@ -12,5 +13,5 @@ class QueryRetrievalInfoParams(TypedDict, total=False): agent_id: Required[str] """ID of the agent which sent the provided message.""" - content_ids: Required[List[str]] + content_ids: Required[SequenceNotStr[str]] """List of content ids for which to get the metadata.""" diff --git a/src/contextual/types/agents/retrieval_info_response.py b/src/contextual/types/agents/retrieval_info_response.py index 07358922..6dfbf044 100644 --- a/src/contextual/types/agents/retrieval_info_response.py +++ b/src/contextual/types/agents/retrieval_info_response.py @@ -11,6 +11,7 @@ "ContentMetadata", "ContentMetadataUnstructuredContentMetadata", "ContentMetadataStructuredContentMetadata", + "ContentMetadataFileAnalysisContentMetadata", ] @@ -61,8 +62,25 @@ class ContentMetadataStructuredContentMetadata(BaseModel): content_type: Optional[Literal["structured"]] = None +class ContentMetadataFileAnalysisContentMetadata(BaseModel): + content_id: str + """Id of the content.""" + + file_format: str + """Format of the file.""" + + 
gcp_location: str + """GCP location of the file.""" + + content_type: Optional[Literal["file_analysis"]] = None + + ContentMetadata: TypeAlias = Annotated[ - Union[ContentMetadataUnstructuredContentMetadata, ContentMetadataStructuredContentMetadata], + Union[ + ContentMetadataUnstructuredContentMetadata, + ContentMetadataStructuredContentMetadata, + ContentMetadataFileAnalysisContentMetadata, + ], PropertyInfo(discriminator="content_type"), ] diff --git a/src/contextual/types/datastore_metadata.py b/src/contextual/types/datastore_metadata.py index bacf5d9d..f1a00357 100644 --- a/src/contextual/types/datastore_metadata.py +++ b/src/contextual/types/datastore_metadata.py @@ -112,7 +112,7 @@ class DatastoreMetadata(BaseModel): """Name of the datastore""" configuration: Optional[Configuration] = None - """Configuration of the datastore. Not set if default configuration is in use.""" + """Configuration for unstructured datastores.""" datastore_type: Optional[Literal["UNSTRUCTURED"]] = None """Type of the datastore""" diff --git a/src/contextual/types/datastores/base_metadata_filter_param.py b/src/contextual/types/datastores/base_metadata_filter_param.py index 66ab145f..0006f7f1 100644 --- a/src/contextual/types/datastores/base_metadata_filter_param.py +++ b/src/contextual/types/datastores/base_metadata_filter_param.py @@ -2,9 +2,11 @@ from __future__ import annotations -from typing import List, Union +from typing import Union from typing_extensions import Literal, Required, TypedDict +from ..._types import SequenceNotStr + __all__ = ["BaseMetadataFilterParam"] @@ -29,7 +31,7 @@ class BaseMetadataFilterParam(TypedDict, total=False): ] """Operator to be used for the filter.""" - value: Union[str, float, bool, List[Union[str, float, bool]], None] + value: Union[str, float, bool, SequenceNotStr[Union[str, float, bool]], None] """The value to be searched for in the field. In case of exists operator, it is not needed. 
diff --git a/src/contextual/types/datastores/composite_metadata_filter.py b/src/contextual/types/datastores/composite_metadata_filter.py index 513d2fc4..049f45a0 100644 --- a/src/contextual/types/datastores/composite_metadata_filter.py +++ b/src/contextual/types/datastores/composite_metadata_filter.py @@ -5,13 +5,13 @@ from typing import TYPE_CHECKING, List, Union, Optional from typing_extensions import Literal, TypeAlias, TypeAliasType -from ..._compat import PYDANTIC_V2 +from ..._compat import PYDANTIC_V1 from ..._models import BaseModel from .base_metadata_filter import BaseMetadataFilter __all__ = ["CompositeMetadataFilter", "Filter"] -if TYPE_CHECKING or PYDANTIC_V2: +if TYPE_CHECKING or not PYDANTIC_V1: Filter = TypeAliasType("Filter", Union[BaseMetadataFilter, "CompositeMetadataFilter"]) else: Filter: TypeAlias = Union[BaseMetadataFilter, "CompositeMetadataFilter"] diff --git a/src/contextual/types/datastores/composite_metadata_filter_param.py b/src/contextual/types/datastores/composite_metadata_filter_param.py index 20a30cd4..b5a11f89 100644 --- a/src/contextual/types/datastores/composite_metadata_filter_param.py +++ b/src/contextual/types/datastores/composite_metadata_filter_param.py @@ -5,12 +5,12 @@ from typing import TYPE_CHECKING, Union, Iterable, Optional from typing_extensions import Literal, Required, TypeAlias, TypedDict, TypeAliasType -from ..._compat import PYDANTIC_V2 +from ..._compat import PYDANTIC_V1 from .base_metadata_filter_param import BaseMetadataFilterParam __all__ = ["CompositeMetadataFilterParam", "Filter"] -if TYPE_CHECKING or PYDANTIC_V2: +if TYPE_CHECKING or not PYDANTIC_V1: Filter = TypeAliasType("Filter", Union[BaseMetadataFilterParam, "CompositeMetadataFilterParam"]) else: Filter: TypeAlias = Union[BaseMetadataFilterParam, "CompositeMetadataFilterParam"] diff --git a/src/contextual/types/datastores/document_ingest_params.py b/src/contextual/types/datastores/document_ingest_params.py index 5179f7d7..f5e02108 100644 --- 
a/src/contextual/types/datastores/document_ingest_params.py +++ b/src/contextual/types/datastores/document_ingest_params.py @@ -13,25 +13,33 @@ class DocumentIngestParams(TypedDict, total=False): file: Required[FileTypes] """File to ingest.""" + configuration: str + """Overrides the datastore's default configuration for this specific document. + + This allows applying optimized settings tailored to the document's + characteristics without changing the global datastore configuration. + """ + metadata: str - """Metadata request in JSON format. + """Metadata request in stringified JSON format. `custom_metadata` is a flat dictionary containing one or more key-value pairs, where each value must be a primitive type (`str`, `bool`, `float`, or `int`). - The default maximum metadata fields that can be used is 15, contact support if - more is needed.The combined size of the metadata must not exceed **2 KB** when - encoded as JSON.The strings with date format must stay in date format or be - avoided if not in date format.The `custom_metadata.url` field is automatically - included in returned attributions during query time, if provided. - - **Example Request Body:** - - ```json - { - "custom_metadata": { - "topic": "science", - "difficulty": 3 - } - } - ``` + The default maximum metadata fields that can be used is 15, contact + support@contextual.ai if more is needed. The combined size of the metadata must + not exceed **2 KB** when encoded as JSON. The strings with date format must stay + in date format or be avoided if not in date format. The `custom_metadata.url` or + `link` field is automatically included in returned attributions during query + time, if provided. 
+ + **Example Request Body (as returned by `json.dumps`):** + + ```json + "{{ + \"custom_metadata\": {{ + \"topic\": \"science\", + \"difficulty\": 3 + }} + }}" + ``` """ diff --git a/src/contextual/types/datastores/document_metadata.py b/src/contextual/types/datastores/document_metadata.py index 790f0ad0..c1bdd5ac 100644 --- a/src/contextual/types/datastores/document_metadata.py +++ b/src/contextual/types/datastores/document_metadata.py @@ -1,6 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. -from typing import Dict, Union, Optional +from typing import Dict, List, Union, Optional from typing_extensions import Literal from ..._models import BaseModel @@ -38,7 +38,7 @@ class DocumentMetadata(BaseModel): status: Literal["pending", "processing", "retrying", "completed", "failed", "cancelled"] """Status of this document's ingestion job""" - custom_metadata: Optional[Dict[str, Union[bool, float, str]]] = None + custom_metadata: Optional[Dict[str, Union[bool, float, str, List[float]]]] = None """ Custom metadata for the document, provided by the user at ingestion time.Must be a JSON-serializable dictionary with string keys and simple primitive values @@ -52,25 +52,18 @@ class DocumentMetadata(BaseModel): custom_metadata_config: Optional[Dict[str, CustomMetadataConfig]] = None """ A dictionary mapping metadata field names to the configuration to use for each - field. - - - If a metadata field is not present in the dictionary, the default configuration will be used. - - - If the dictionary is not provided, metadata will be added in chunks but will not be retrievable. - - - Limits: - Maximum characters per metadata field (for prompt or rerank): 400 - - - Maximum number of metadata fields (for prompt or retrieval): 10 - - - Contact support@contextual.ai to request quota increases. + field. If a metadata field is not present in the dictionary, the default + configuration will be used. 
If the dictionary is not provided, metadata will be + added in context for rerank and generation but will not be returned back to the + user in retrievals in query API. Limits: - Maximum characters per metadata field + (for prompt or rerank): **400** - Maximum number of metadata fields (for prompt + or retrieval): **10** Contact support@contextual.ai to request quota increases. """ has_access: Optional[bool] = None """Whether the user has access to this document.""" - ingestion_config: Optional[object] = None + ingestion_config: Optional[Dict[str, object]] = None """Ingestion configuration for the document when the document was ingested. It may be different from the current datastore configuration. diff --git a/src/contextual/types/datastores/document_set_metadata_params.py b/src/contextual/types/datastores/document_set_metadata_params.py index 4c218225..9582a46f 100644 --- a/src/contextual/types/datastores/document_set_metadata_params.py +++ b/src/contextual/types/datastores/document_set_metadata_params.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Dict, Union +from typing import Dict, Union, Iterable from typing_extensions import Required, TypedDict __all__ = ["DocumentSetMetadataParams", "CustomMetadataConfig"] @@ -12,7 +12,7 @@ class DocumentSetMetadataParams(TypedDict, total=False): datastore_id: Required[str] """Datastore ID of the datastore from which to retrieve the document""" - custom_metadata: Dict[str, Union[bool, float, str]] + custom_metadata: Dict[str, Union[bool, float, str, Iterable[float]]] """ Custom metadata for the document, provided by the user at ingestion time.Must be a JSON-serializable dictionary with string keys and simple primitive values @@ -26,19 +26,12 @@ class DocumentSetMetadataParams(TypedDict, total=False): custom_metadata_config: Dict[str, CustomMetadataConfig] """ A dictionary mapping metadata field names to the configuration to use for each - field. 
- - - If a metadata field is not present in the dictionary, the default configuration will be used. - - - If the dictionary is not provided, metadata will be added in chunks but will not be retrievable. - - - Limits: - Maximum characters per metadata field (for prompt or rerank): 400 - - - Maximum number of metadata fields (for prompt or retrieval): 10 - - - Contact support@contextual.ai to request quota increases. + field. If a metadata field is not present in the dictionary, the default + configuration will be used. If the dictionary is not provided, metadata will be + added in context for rerank and generation but will not be returned back to the + user in retrievals in query API. Limits: - Maximum characters per metadata field + (for prompt or rerank): **400** - Maximum number of metadata fields (for prompt + or retrieval): **10** Contact support@contextual.ai to request quota increases. """ diff --git a/src/contextual/types/generate_create_params.py b/src/contextual/types/generate_create_params.py index 61dafc8c..662a561b 100644 --- a/src/contextual/types/generate_create_params.py +++ b/src/contextual/types/generate_create_params.py @@ -2,14 +2,16 @@ from __future__ import annotations -from typing import List, Iterable +from typing import Iterable from typing_extensions import Literal, Required, TypedDict +from .._types import SequenceNotStr + __all__ = ["GenerateCreateParams", "Message"] class GenerateCreateParams(TypedDict, total=False): - knowledge: Required[List[str]] + knowledge: Required[SequenceNotStr[str]] """The knowledge sources the model can use when generating a response.""" messages: Required[Iterable[Message]] diff --git a/src/contextual/types/list_users_response.py b/src/contextual/types/list_users_response.py index aac19ee0..2cebb635 100644 --- a/src/contextual/types/list_users_response.py +++ b/src/contextual/types/list_users_response.py @@ -33,11 +33,14 @@ class User(BaseModel): Literal[ "VISITOR", "AGENT_USER", + "CUSTOMER_USER", 
"CUSTOMER_INTERNAL_USER", "CONTEXTUAL_STAFF_USER", "CONTEXTUAL_EXTERNAL_STAFF_USER", "CONTEXTUAL_INTERNAL_STAFF_USER", "TENANT_ADMIN", + "CUSTOMER_ADMIN", + "CONTEXTUAL_ADMIN", "SUPER_ADMIN", "SERVICE_ACCOUNT", ] @@ -61,11 +64,14 @@ class User(BaseModel): Literal[ "VISITOR", "AGENT_USER", + "CUSTOMER_USER", "CUSTOMER_INTERNAL_USER", "CONTEXTUAL_STAFF_USER", "CONTEXTUAL_EXTERNAL_STAFF_USER", "CONTEXTUAL_INTERNAL_STAFF_USER", "TENANT_ADMIN", + "CUSTOMER_ADMIN", + "CONTEXTUAL_ADMIN", "SUPER_ADMIN", "SERVICE_ACCOUNT", ] diff --git a/src/contextual/types/new_user_param.py b/src/contextual/types/new_user_param.py index 09bc557e..027d7594 100644 --- a/src/contextual/types/new_user_param.py +++ b/src/contextual/types/new_user_param.py @@ -41,11 +41,14 @@ class NewUserParam(TypedDict, total=False): Literal[ "VISITOR", "AGENT_USER", + "CUSTOMER_USER", "CUSTOMER_INTERNAL_USER", "CONTEXTUAL_STAFF_USER", "CONTEXTUAL_EXTERNAL_STAFF_USER", "CONTEXTUAL_INTERNAL_STAFF_USER", "TENANT_ADMIN", + "CUSTOMER_ADMIN", + "CONTEXTUAL_ADMIN", "SUPER_ADMIN", "SERVICE_ACCOUNT", ] diff --git a/src/contextual/types/rerank_create_params.py b/src/contextual/types/rerank_create_params.py index 5e59ef6e..582ff86b 100644 --- a/src/contextual/types/rerank_create_params.py +++ b/src/contextual/types/rerank_create_params.py @@ -2,14 +2,15 @@ from __future__ import annotations -from typing import List from typing_extensions import Required, TypedDict +from .._types import SequenceNotStr + __all__ = ["RerankCreateParams"] class RerankCreateParams(TypedDict, total=False): - documents: Required[List[str]] + documents: Required[SequenceNotStr[str]] """ The texts to be reranked according to their relevance to the query and the optional instruction @@ -37,7 +38,7 @@ class RerankCreateParams(TypedDict, total=False): portal content supersedes distributor communications." """ - metadata: List[str] + metadata: SequenceNotStr[str] """Metadata for documents being passed to the reranker. 
Must be the same length as the documents list. If a document does not have diff --git a/src/contextual/types/user_update_params.py b/src/contextual/types/user_update_params.py index 86561832..229d81d1 100644 --- a/src/contextual/types/user_update_params.py +++ b/src/contextual/types/user_update_params.py @@ -30,11 +30,14 @@ class UserUpdateParams(TypedDict, total=False): Literal[ "VISITOR", "AGENT_USER", + "CUSTOMER_USER", "CUSTOMER_INTERNAL_USER", "CONTEXTUAL_STAFF_USER", "CONTEXTUAL_EXTERNAL_STAFF_USER", "CONTEXTUAL_INTERNAL_STAFF_USER", "TENANT_ADMIN", + "CUSTOMER_ADMIN", + "CONTEXTUAL_ADMIN", "SUPER_ADMIN", "SERVICE_ACCOUNT", ] diff --git a/tests/api_resources/agents/test_query.py b/tests/api_resources/agents/test_query.py index d706d780..c126460c 100644 --- a/tests/api_resources/agents/test_query.py +++ b/tests/api_resources/agents/test_query.py @@ -13,6 +13,7 @@ from contextual.types.agents import ( QueryResponse, QueryMetricsResponse, + QueryFeedbackResponse, RetrievalInfoResponse, ) @@ -43,13 +44,20 @@ def test_method_create_with_all_params(self, client: ContextualAI) -> None: { "content": "content", "role": "user", + "custom_tags": ["string"], } ], include_retrieval_content_text=True, retrievals_only=True, conversation_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", documents_filters={ - "filters": [], + "filters": [ + { + "field": "field1", + "operator": "equals", + "value": "value1", + } + ], "operator": "AND", }, llm_model_id="llm_model_id", @@ -73,7 +81,7 @@ def test_method_create_with_all_params(self, client: ContextualAI) -> None: }, stream=True, structured_output={ - "json_schema": {}, + "json_schema": {"foo": "bar"}, "type": "JSON", }, ) @@ -135,7 +143,7 @@ def test_method_feedback(self, client: ContextualAI) -> None: feedback="thumbs_up", message_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", ) - assert_matches_type(object, query, path=["response"]) + assert_matches_type(QueryFeedbackResponse, query, path=["response"]) @parametrize def 
test_method_feedback_with_all_params(self, client: ContextualAI) -> None: @@ -146,7 +154,7 @@ def test_method_feedback_with_all_params(self, client: ContextualAI) -> None: content_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", explanation="explanation", ) - assert_matches_type(object, query, path=["response"]) + assert_matches_type(QueryFeedbackResponse, query, path=["response"]) @parametrize def test_raw_response_feedback(self, client: ContextualAI) -> None: @@ -159,7 +167,7 @@ def test_raw_response_feedback(self, client: ContextualAI) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" query = response.parse() - assert_matches_type(object, query, path=["response"]) + assert_matches_type(QueryFeedbackResponse, query, path=["response"]) @parametrize def test_streaming_response_feedback(self, client: ContextualAI) -> None: @@ -172,7 +180,7 @@ def test_streaming_response_feedback(self, client: ContextualAI) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" query = response.parse() - assert_matches_type(object, query, path=["response"]) + assert_matches_type(QueryFeedbackResponse, query, path=["response"]) assert cast(Any, response.is_closed) is True @@ -316,13 +324,20 @@ async def test_method_create_with_all_params(self, async_client: AsyncContextual { "content": "content", "role": "user", + "custom_tags": ["string"], } ], include_retrieval_content_text=True, retrievals_only=True, conversation_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", documents_filters={ - "filters": [], + "filters": [ + { + "field": "field1", + "operator": "equals", + "value": "value1", + } + ], "operator": "AND", }, llm_model_id="llm_model_id", @@ -346,7 +361,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncContextual }, stream=True, structured_output={ - "json_schema": {}, + "json_schema": {"foo": "bar"}, "type": "JSON", }, ) @@ -408,7 +423,7 @@ async def 
test_method_feedback(self, async_client: AsyncContextualAI) -> None: feedback="thumbs_up", message_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", ) - assert_matches_type(object, query, path=["response"]) + assert_matches_type(QueryFeedbackResponse, query, path=["response"]) @parametrize async def test_method_feedback_with_all_params(self, async_client: AsyncContextualAI) -> None: @@ -419,7 +434,7 @@ async def test_method_feedback_with_all_params(self, async_client: AsyncContextu content_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", explanation="explanation", ) - assert_matches_type(object, query, path=["response"]) + assert_matches_type(QueryFeedbackResponse, query, path=["response"]) @parametrize async def test_raw_response_feedback(self, async_client: AsyncContextualAI) -> None: @@ -432,7 +447,7 @@ async def test_raw_response_feedback(self, async_client: AsyncContextualAI) -> N assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" query = await response.parse() - assert_matches_type(object, query, path=["response"]) + assert_matches_type(QueryFeedbackResponse, query, path=["response"]) @parametrize async def test_streaming_response_feedback(self, async_client: AsyncContextualAI) -> None: @@ -445,7 +460,7 @@ async def test_streaming_response_feedback(self, async_client: AsyncContextualAI assert response.http_request.headers.get("X-Stainless-Lang") == "python" query = await response.parse() - assert_matches_type(object, query, path=["response"]) + assert_matches_type(QueryFeedbackResponse, query, path=["response"]) assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/datastores/test_documents.py b/tests/api_resources/datastores/test_documents.py index d18c95cd..86caad78 100644 --- a/tests/api_resources/datastores/test_documents.py +++ b/tests/api_resources/datastores/test_documents.py @@ -192,6 +192,7 @@ def test_method_ingest_with_all_params(self, client: ContextualAI) -> None: document = 
client.datastores.documents.ingest( datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", file=b"raw file contents", + configuration="configuration", metadata="metadata", ) assert_matches_type(IngestionResponse, document, path=["response"]) @@ -517,6 +518,7 @@ async def test_method_ingest_with_all_params(self, async_client: AsyncContextual document = await async_client.datastores.documents.ingest( datastore_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", file=b"raw file contents", + configuration="configuration", metadata="metadata", ) assert_matches_type(IngestionResponse, document, path=["response"]) diff --git a/tests/api_resources/test_agents.py b/tests/api_resources/test_agents.py index f9467ac5..f892cce2 100644 --- a/tests/api_resources/test_agents.py +++ b/tests/api_resources/test_agents.py @@ -34,14 +34,30 @@ def test_method_create_with_all_params(self, client: ContextualAI) -> None: agent = client.agents.create( name="xxx", agent_configs={ + "acl_config": { + "acl_active": True, + "acl_yaml": "acl_yaml", + }, "filter_and_rerank_config": { "default_metadata_filters": { - "filters": [], + "filters": [ + { + "field": "field1", + "operator": "equals", + "value": "value1", + } + ], "operator": "AND", }, "per_datastore_metadata_filters": { "d49609d9-61c3-4a67-b3bd-5196b10da560": { - "filters": [], + "filters": [ + { + "field": "field1", + "operator": "equals", + "value": "value1", + } + ], "operator": "AND", } }, @@ -75,6 +91,10 @@ def test_method_create_with_all_params(self, client: ContextualAI) -> None: "semantic_alpha": 0, "top_k_retrieved_chunks": 0, }, + "translation_config": { + "translate_confidence": 0, + "translate_needed": True, + }, }, datastore_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], description="description", @@ -83,6 +103,7 @@ def test_method_create_with_all_params(self, client: ContextualAI) -> None: no_retrieval_system_prompt="no_retrieval_system_prompt", suggested_queries=["string"], system_prompt="system_prompt", + 
template_name="template_name", ) assert_matches_type(CreateAgentOutput, agent, path=["response"]) @@ -122,14 +143,30 @@ def test_method_update_with_all_params(self, client: ContextualAI) -> None: agent = client.agents.update( agent_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", agent_configs={ + "acl_config": { + "acl_active": True, + "acl_yaml": "acl_yaml", + }, "filter_and_rerank_config": { "default_metadata_filters": { - "filters": [], + "filters": [ + { + "field": "field1", + "operator": "equals", + "value": "value1", + } + ], "operator": "AND", }, "per_datastore_metadata_filters": { "d49609d9-61c3-4a67-b3bd-5196b10da560": { - "filters": [], + "filters": [ + { + "field": "field1", + "operator": "equals", + "value": "value1", + } + ], "operator": "AND", } }, @@ -163,11 +200,16 @@ def test_method_update_with_all_params(self, client: ContextualAI) -> None: "semantic_alpha": 0, "top_k_retrieved_chunks": 0, }, + "translation_config": { + "translate_confidence": 0, + "translate_needed": True, + }, }, datastore_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + description="description", filter_prompt="filter_prompt", - llm_model_id="llm_model_id", multiturn_system_prompt="multiturn_system_prompt", + name="xxx", no_retrieval_system_prompt="no_retrieval_system_prompt", suggested_queries=["string"], system_prompt="system_prompt", @@ -408,14 +450,30 @@ async def test_method_create_with_all_params(self, async_client: AsyncContextual agent = await async_client.agents.create( name="xxx", agent_configs={ + "acl_config": { + "acl_active": True, + "acl_yaml": "acl_yaml", + }, "filter_and_rerank_config": { "default_metadata_filters": { - "filters": [], + "filters": [ + { + "field": "field1", + "operator": "equals", + "value": "value1", + } + ], "operator": "AND", }, "per_datastore_metadata_filters": { "d49609d9-61c3-4a67-b3bd-5196b10da560": { - "filters": [], + "filters": [ + { + "field": "field1", + "operator": "equals", + "value": "value1", + } + ], "operator": "AND", } }, @@ -449,6 
+507,10 @@ async def test_method_create_with_all_params(self, async_client: AsyncContextual "semantic_alpha": 0, "top_k_retrieved_chunks": 0, }, + "translation_config": { + "translate_confidence": 0, + "translate_needed": True, + }, }, datastore_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], description="description", @@ -457,6 +519,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncContextual no_retrieval_system_prompt="no_retrieval_system_prompt", suggested_queries=["string"], system_prompt="system_prompt", + template_name="template_name", ) assert_matches_type(CreateAgentOutput, agent, path=["response"]) @@ -496,14 +559,30 @@ async def test_method_update_with_all_params(self, async_client: AsyncContextual agent = await async_client.agents.update( agent_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", agent_configs={ + "acl_config": { + "acl_active": True, + "acl_yaml": "acl_yaml", + }, "filter_and_rerank_config": { "default_metadata_filters": { - "filters": [], + "filters": [ + { + "field": "field1", + "operator": "equals", + "value": "value1", + } + ], "operator": "AND", }, "per_datastore_metadata_filters": { "d49609d9-61c3-4a67-b3bd-5196b10da560": { - "filters": [], + "filters": [ + { + "field": "field1", + "operator": "equals", + "value": "value1", + } + ], "operator": "AND", } }, @@ -537,11 +616,16 @@ async def test_method_update_with_all_params(self, async_client: AsyncContextual "semantic_alpha": 0, "top_k_retrieved_chunks": 0, }, + "translation_config": { + "translate_confidence": 0, + "translate_needed": True, + }, }, datastore_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + description="description", filter_prompt="filter_prompt", - llm_model_id="llm_model_id", multiturn_system_prompt="multiturn_system_prompt", + name="xxx", no_retrieval_system_prompt="no_retrieval_system_prompt", suggested_queries=["string"], system_prompt="system_prompt", diff --git a/tests/test_client.py b/tests/test_client.py index df3fc8da..60b3a988 100644 --- 
a/tests/test_client.py +++ b/tests/test_client.py @@ -6,13 +6,10 @@ import os import sys import json -import time import asyncio import inspect -import subprocess import tracemalloc from typing import Any, Union, cast -from textwrap import dedent from unittest import mock from typing_extensions import Literal @@ -23,14 +20,17 @@ from contextual import ContextualAI, AsyncContextualAI, APIResponseValidationError from contextual._types import Omit +from contextual._utils import asyncify from contextual._models import BaseModel, FinalRequestOptions from contextual._exceptions import APIStatusError, APITimeoutError, ContextualAIError, APIResponseValidationError from contextual._base_client import ( DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, + OtherPlatform, DefaultHttpxClient, DefaultAsyncHttpxClient, + get_platform, make_request_options, ) @@ -1643,52 +1643,9 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: assert response.http_request.headers.get("x-stainless-retry-count") == "42" - def test_get_platform(self) -> None: - # A previous implementation of asyncify could leave threads unterminated when - # used with nest_asyncio. - # - # Since nest_asyncio.apply() is global and cannot be un-applied, this - # test is run in a separate process to avoid affecting other tests. 
- test_code = dedent( - """ - import asyncio - import nest_asyncio - import threading - - from contextual._utils import asyncify - from contextual._base_client import get_platform - - async def test_main() -> None: - result = await asyncify(get_platform)() - print(result) - for thread in threading.enumerate(): - print(thread.name) - - nest_asyncio.apply() - asyncio.run(test_main()) - """ - ) - with subprocess.Popen( - [sys.executable, "-c", test_code], - text=True, - ) as process: - timeout = 10 # seconds - - start_time = time.monotonic() - while True: - return_code = process.poll() - if return_code is not None: - if return_code != 0: - raise AssertionError("calling get_platform using asyncify resulted in a non-zero exit code") - - # success - break - - if time.monotonic() - start_time > timeout: - process.kill() - raise AssertionError("calling get_platform using asyncify resulted in a hung process") - - time.sleep(0.1) + async def test_get_platform(self) -> None: + platform = await asyncify(get_platform)() + assert isinstance(platform, (str, OtherPlatform)) async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None: # Test that the proxy environment variables are set correctly diff --git a/tests/test_models.py b/tests/test_models.py index ae4b3f01..45f5759f 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -8,7 +8,7 @@ from pydantic import Field from contextual._utils import PropertyInfo -from contextual._compat import PYDANTIC_V2, parse_obj, model_dump, model_json +from contextual._compat import PYDANTIC_V1, parse_obj, model_dump, model_json from contextual._models import BaseModel, construct_type @@ -294,12 +294,12 @@ class Model(BaseModel): assert cast(bool, m.foo) is True m = Model.construct(foo={"name": 3}) - if PYDANTIC_V2: - assert isinstance(m.foo, Submodel1) - assert m.foo.name == 3 # type: ignore - else: + if PYDANTIC_V1: assert isinstance(m.foo, Submodel2) assert m.foo.name == "3" + else: + assert 
isinstance(m.foo, Submodel1) + assert m.foo.name == 3 # type: ignore def test_list_of_unions() -> None: @@ -426,10 +426,10 @@ class Model(BaseModel): expected = datetime(2019, 12, 27, 18, 11, 19, 117000, tzinfo=timezone.utc) - if PYDANTIC_V2: - expected_json = '{"created_at":"2019-12-27T18:11:19.117000Z"}' - else: + if PYDANTIC_V1: expected_json = '{"created_at": "2019-12-27T18:11:19.117000+00:00"}' + else: + expected_json = '{"created_at":"2019-12-27T18:11:19.117000Z"}' model = Model.construct(created_at="2019-12-27T18:11:19.117Z") assert model.created_at == expected @@ -531,7 +531,7 @@ class Model2(BaseModel): assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)} assert m4.to_dict(mode="json") == {"created_at": time_str} - if not PYDANTIC_V2: + if PYDANTIC_V1: with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): m.to_dict(warnings=False) @@ -556,7 +556,7 @@ class Model(BaseModel): assert m3.model_dump() == {"foo": None} assert m3.model_dump(exclude_none=True) == {} - if not PYDANTIC_V2: + if PYDANTIC_V1: with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"): m.model_dump(round_trip=True) @@ -580,10 +580,10 @@ class Model(BaseModel): assert json.loads(m.to_json()) == {"FOO": "hello"} assert json.loads(m.to_json(use_api_names=False)) == {"foo": "hello"} - if PYDANTIC_V2: - assert m.to_json(indent=None) == '{"FOO":"hello"}' - else: + if PYDANTIC_V1: assert m.to_json(indent=None) == '{"FOO": "hello"}' + else: + assert m.to_json(indent=None) == '{"FOO":"hello"}' m2 = Model() assert json.loads(m2.to_json()) == {} @@ -595,7 +595,7 @@ class Model(BaseModel): assert json.loads(m3.to_json()) == {"FOO": None} assert json.loads(m3.to_json(exclude_none=True)) == {} - if not PYDANTIC_V2: + if PYDANTIC_V1: with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): m.to_json(warnings=False) @@ -622,7 +622,7 @@ class Model(BaseModel): assert 
json.loads(m3.model_dump_json()) == {"foo": None} assert json.loads(m3.model_dump_json(exclude_none=True)) == {} - if not PYDANTIC_V2: + if PYDANTIC_V1: with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"): m.model_dump_json(round_trip=True) @@ -679,12 +679,12 @@ class B(BaseModel): ) assert isinstance(m, A) assert m.type == "a" - if PYDANTIC_V2: - assert m.data == 100 # type: ignore[comparison-overlap] - else: + if PYDANTIC_V1: # pydantic v1 automatically converts inputs to strings # if the expected type is a str assert m.data == "100" + else: + assert m.data == 100 # type: ignore[comparison-overlap] def test_discriminated_unions_unknown_variant() -> None: @@ -768,12 +768,12 @@ class B(BaseModel): ) assert isinstance(m, A) assert m.foo_type == "a" - if PYDANTIC_V2: - assert m.data == 100 # type: ignore[comparison-overlap] - else: + if PYDANTIC_V1: # pydantic v1 automatically converts inputs to strings # if the expected type is a str assert m.data == "100" + else: + assert m.data == 100 # type: ignore[comparison-overlap] def test_discriminated_unions_overlapping_discriminators_invalid_data() -> None: @@ -833,7 +833,7 @@ class B(BaseModel): assert UnionType.__discriminator__ is discriminator -@pytest.mark.skipif(not PYDANTIC_V2, reason="TypeAliasType is not supported in Pydantic v1") +@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1") def test_type_alias_type() -> None: Alias = TypeAliasType("Alias", str) # pyright: ignore @@ -849,7 +849,7 @@ class Model(BaseModel): assert m.union == "bar" -@pytest.mark.skipif(not PYDANTIC_V2, reason="TypeAliasType is not supported in Pydantic v1") +@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1") def test_field_named_cls() -> None: class Model(BaseModel): cls: str @@ -936,7 +936,7 @@ class Type2(BaseModel): assert isinstance(model.value, InnerType2) -@pytest.mark.skipif(not PYDANTIC_V2, reason="this is only supported in 
pydantic v2 for now") +@pytest.mark.skipif(PYDANTIC_V1, reason="this is only supported in pydantic v2 for now") def test_extra_properties() -> None: class Item(BaseModel): prop: int diff --git a/tests/test_transform.py b/tests/test_transform.py index 6fe1ce3f..035d26e7 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -8,14 +8,14 @@ import pytest -from contextual._types import NOT_GIVEN, Base64FileInput +from contextual._types import Base64FileInput, omit, not_given from contextual._utils import ( PropertyInfo, transform as _transform, parse_datetime, async_transform as _async_transform, ) -from contextual._compat import PYDANTIC_V2 +from contextual._compat import PYDANTIC_V1 from contextual._models import BaseModel _T = TypeVar("_T") @@ -189,7 +189,7 @@ class DateModel(BaseModel): @pytest.mark.asyncio async def test_iso8601_format(use_async: bool) -> None: dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") - tz = "Z" if PYDANTIC_V2 else "+00:00" + tz = "+00:00" if PYDANTIC_V1 else "Z" assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692" + tz} # type: ignore[comparison-overlap] @@ -297,11 +297,11 @@ async def test_pydantic_unknown_field(use_async: bool) -> None: @pytest.mark.asyncio async def test_pydantic_mismatched_types(use_async: bool) -> None: model = MyModel.construct(foo=True) - if PYDANTIC_V2: + if PYDANTIC_V1: + params = await transform(model, Any, use_async) + else: with pytest.warns(UserWarning): params = await transform(model, Any, use_async) - else: - params = await transform(model, Any, use_async) assert cast(Any, params) == {"foo": True} @@ -309,11 +309,11 @@ async def test_pydantic_mismatched_types(use_async: bool) -> None: @pytest.mark.asyncio async def test_pydantic_mismatched_object_type(use_async: bool) -> None: model = 
MyModel.construct(foo=MyModel.construct(hello="world")) - if PYDANTIC_V2: + if PYDANTIC_V1: + params = await transform(model, Any, use_async) + else: with pytest.warns(UserWarning): params = await transform(model, Any, use_async) - else: - params = await transform(model, Any, use_async) assert cast(Any, params) == {"foo": {"hello": "world"}} @@ -450,4 +450,11 @@ async def test_transform_skipping(use_async: bool) -> None: @pytest.mark.asyncio async def test_strips_notgiven(use_async: bool) -> None: assert await transform({"foo_bar": "bar"}, Foo1, use_async) == {"fooBar": "bar"} - assert await transform({"foo_bar": NOT_GIVEN}, Foo1, use_async) == {} + assert await transform({"foo_bar": not_given}, Foo1, use_async) == {} + + +@parametrize +@pytest.mark.asyncio +async def test_strips_omit(use_async: bool) -> None: + assert await transform({"foo_bar": "bar"}, Foo1, use_async) == {"fooBar": "bar"} + assert await transform({"foo_bar": omit}, Foo1, use_async) == {} diff --git a/tests/test_utils/test_datetime_parse.py b/tests/test_utils/test_datetime_parse.py new file mode 100644 index 00000000..c6f158fc --- /dev/null +++ b/tests/test_utils/test_datetime_parse.py @@ -0,0 +1,110 @@ +""" +Copied from https://github.com/pydantic/pydantic/blob/v1.10.22/tests/test_datetime_parse.py +with modifications so it works without pydantic v1 imports. 
+""" + +from typing import Type, Union +from datetime import date, datetime, timezone, timedelta + +import pytest + +from contextual._utils import parse_date, parse_datetime + + +def create_tz(minutes: int) -> timezone: + return timezone(timedelta(minutes=minutes)) + + +@pytest.mark.parametrize( + "value,result", + [ + # Valid inputs + ("1494012444.883309", date(2017, 5, 5)), + (b"1494012444.883309", date(2017, 5, 5)), + (1_494_012_444.883_309, date(2017, 5, 5)), + ("1494012444", date(2017, 5, 5)), + (1_494_012_444, date(2017, 5, 5)), + (0, date(1970, 1, 1)), + ("2012-04-23", date(2012, 4, 23)), + (b"2012-04-23", date(2012, 4, 23)), + ("2012-4-9", date(2012, 4, 9)), + (date(2012, 4, 9), date(2012, 4, 9)), + (datetime(2012, 4, 9, 12, 15), date(2012, 4, 9)), + # Invalid inputs + ("x20120423", ValueError), + ("2012-04-56", ValueError), + (19_999_999_999, date(2603, 10, 11)), # just before watershed + (20_000_000_001, date(1970, 8, 20)), # just after watershed + (1_549_316_052, date(2019, 2, 4)), # nowish in s + (1_549_316_052_104, date(2019, 2, 4)), # nowish in ms + (1_549_316_052_104_324, date(2019, 2, 4)), # nowish in μs + (1_549_316_052_104_324_096, date(2019, 2, 4)), # nowish in ns + ("infinity", date(9999, 12, 31)), + ("inf", date(9999, 12, 31)), + (float("inf"), date(9999, 12, 31)), + ("infinity ", date(9999, 12, 31)), + (int("1" + "0" * 100), date(9999, 12, 31)), + (1e1000, date(9999, 12, 31)), + ("-infinity", date(1, 1, 1)), + ("-inf", date(1, 1, 1)), + ("nan", ValueError), + ], +) +def test_date_parsing(value: Union[str, bytes, int, float], result: Union[date, Type[Exception]]) -> None: + if type(result) == type and issubclass(result, Exception): # pyright: ignore[reportUnnecessaryIsInstance] + with pytest.raises(result): + parse_date(value) + else: + assert parse_date(value) == result + + +@pytest.mark.parametrize( + "value,result", + [ + # Valid inputs + # values in seconds + ("1494012444.883309", datetime(2017, 5, 5, 19, 27, 24, 883_309, 
tzinfo=timezone.utc)), + (1_494_012_444.883_309, datetime(2017, 5, 5, 19, 27, 24, 883_309, tzinfo=timezone.utc)), + ("1494012444", datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + (b"1494012444", datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + (1_494_012_444, datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + # values in ms + ("1494012444000.883309", datetime(2017, 5, 5, 19, 27, 24, 883, tzinfo=timezone.utc)), + ("-1494012444000.883309", datetime(1922, 8, 29, 4, 32, 35, 999117, tzinfo=timezone.utc)), + (1_494_012_444_000, datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + ("2012-04-23T09:15:00", datetime(2012, 4, 23, 9, 15)), + ("2012-4-9 4:8:16", datetime(2012, 4, 9, 4, 8, 16)), + ("2012-04-23T09:15:00Z", datetime(2012, 4, 23, 9, 15, 0, 0, timezone.utc)), + ("2012-4-9 4:8:16-0320", datetime(2012, 4, 9, 4, 8, 16, 0, create_tz(-200))), + ("2012-04-23T10:20:30.400+02:30", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(150))), + ("2012-04-23T10:20:30.400+02", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(120))), + ("2012-04-23T10:20:30.400-02", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(-120))), + (b"2012-04-23T10:20:30.400-02", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(-120))), + (datetime(2017, 5, 5), datetime(2017, 5, 5)), + (0, datetime(1970, 1, 1, 0, 0, 0, tzinfo=timezone.utc)), + # Invalid inputs + ("x20120423091500", ValueError), + ("2012-04-56T09:15:90", ValueError), + ("2012-04-23T11:05:00-25:00", ValueError), + (19_999_999_999, datetime(2603, 10, 11, 11, 33, 19, tzinfo=timezone.utc)), # just before watershed + (20_000_000_001, datetime(1970, 8, 20, 11, 33, 20, 1000, tzinfo=timezone.utc)), # just after watershed + (1_549_316_052, datetime(2019, 2, 4, 21, 34, 12, 0, tzinfo=timezone.utc)), # nowish in s + (1_549_316_052_104, datetime(2019, 2, 4, 21, 34, 12, 104_000, tzinfo=timezone.utc)), # nowish in ms + (1_549_316_052_104_324, datetime(2019, 2, 4, 21, 34, 12, 104_324, tzinfo=timezone.utc)), # 
nowish in μs + (1_549_316_052_104_324_096, datetime(2019, 2, 4, 21, 34, 12, 104_324, tzinfo=timezone.utc)), # nowish in ns + ("infinity", datetime(9999, 12, 31, 23, 59, 59, 999999)), + ("inf", datetime(9999, 12, 31, 23, 59, 59, 999999)), + ("inf ", datetime(9999, 12, 31, 23, 59, 59, 999999)), + (1e50, datetime(9999, 12, 31, 23, 59, 59, 999999)), + (float("inf"), datetime(9999, 12, 31, 23, 59, 59, 999999)), + ("-infinity", datetime(1, 1, 1, 0, 0)), + ("-inf", datetime(1, 1, 1, 0, 0)), + ("nan", ValueError), + ], +) +def test_datetime_parsing(value: Union[str, bytes, int, float], result: Union[datetime, Type[Exception]]) -> None: + if type(result) == type and issubclass(result, Exception): # pyright: ignore[reportUnnecessaryIsInstance] + with pytest.raises(result): + parse_datetime(value) + else: + assert parse_datetime(value) == result diff --git a/tests/utils.py b/tests/utils.py index 1171455c..a047fcd6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,7 +4,7 @@ import inspect import traceback import contextlib -from typing import Any, TypeVar, Iterator, cast +from typing import Any, TypeVar, Iterator, Sequence, cast from datetime import date, datetime from typing_extensions import Literal, get_args, get_origin, assert_type @@ -15,10 +15,11 @@ is_list_type, is_union_type, extract_type_arg, + is_sequence_type, is_annotated_type, is_type_alias_type, ) -from contextual._compat import PYDANTIC_V2, field_outer_type, get_model_fields +from contextual._compat import PYDANTIC_V1, field_outer_type, get_model_fields from contextual._models import BaseModel BaseModelT = TypeVar("BaseModelT", bound=BaseModel) @@ -27,12 +28,12 @@ def assert_matches_model(model: type[BaseModelT], value: BaseModelT, *, path: list[str]) -> bool: for name, field in get_model_fields(model).items(): field_value = getattr(value, name) - if PYDANTIC_V2: - allow_none = False - else: + if PYDANTIC_V1: # in v1 nullability was structured differently # 
https://docs.pydantic.dev/2.0/migration/#required-optional-and-nullable-fields allow_none = getattr(field, "allow_none", False) + else: + allow_none = False assert_matches_type( field_outer_type(field), @@ -71,6 +72,13 @@ def assert_matches_type( if is_list_type(type_): return _assert_list_type(type_, value) + if is_sequence_type(type_): + assert isinstance(value, Sequence) + inner_type = get_args(type_)[0] + for entry in value: # type: ignore + assert_type(inner_type, entry) # type: ignore + return + if origin == str: assert isinstance(value, str) elif origin == int: