feat: add MoonshotAI provider with Kimi-K2 model support #2211
base: main
Changes from all commits
dcd7d87
ab62dba
390b2f8
b9ccddc
5ac96cf
3bb2886
7360d6e
9e6db4f
15255cd
0853043
5e67bac
befbef1
fa81e0f
New file — `pydantic_ai.profiles.moonshotai` (@@ -0,0 +1,8 @@):

from __future__ import annotations as _annotations

from . import ModelProfile


def moonshotai_model_profile(model_name: str) -> ModelProfile | None:
    """Get the model profile for a MoonshotAI model."""
    return None
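The profile hook is a stub for now: it returns `None` for every model name, and the provider below layers its OpenAI-compatibility settings on top via `OpenAIModelProfile(...).update(profile)`. A minimal check of that behaviour (a sketch assuming this branch is installed; nothing here is part of the diff):

```python
from pydantic_ai.profiles.moonshotai import moonshotai_model_profile

# Every Moonshot/Kimi model name currently maps to no base profile;
# per-model overrides can be added in this function later without
# touching the provider class.
assert moonshotai_model_profile('kimi-k2-0711-preview') is None
```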
New file — `pydantic_ai.providers.moonshotai` (@@ -0,0 +1,97 @@):

from __future__ import annotations as _annotations

import os
from typing import Literal, overload

from httpx import AsyncClient as AsyncHTTPClient
from openai import AsyncOpenAI

from pydantic_ai.exceptions import UserError
from pydantic_ai.models import cached_async_http_client
from pydantic_ai.profiles import ModelProfile
from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
from pydantic_ai.profiles.openai import (
    OpenAIJsonSchemaTransformer,
    OpenAIModelProfile,
)
from pydantic_ai.providers import Provider

MoonshotModelName = Literal[
    'moonshot-v1-8k',
    'moonshot-v1-32k',
    'moonshot-v1-128k',
    'moonshot-v1-8k-vision-preview',
    'moonshot-v1-32k-vision-preview',
    'moonshot-v1-128k-vision-preview',
    'kimi-latest',
    'kimi-thinking-preview',
    'kimi-k2-0711-preview',
]


class MoonshotAIProvider(Provider[AsyncOpenAI]):
    """Provider for MoonshotAI platform (Kimi models)."""

    @property
    def name(self) -> str:
        return 'moonshotai'

    @property
    def base_url(self) -> str:
        # OpenAI-compatible endpoint, see MoonshotAI docs
        return 'https://api.moonshot.ai/v1'

    @property
    def client(self) -> AsyncOpenAI:
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        profile = moonshotai_model_profile(model_name)

        # As the MoonshotAI API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer,
        # unless json_schema_transformer is set explicitly.
        # Also, MoonshotAI does not support strict tool definitions.
[Review thread on the "does not support strict tool definitions" comment above]
- "Is this actually the case, or was it just a misinterpretation of the …"
- "Looking at their docs: my reading of this is that they don't support `strict` mode. You can ask it for a JSON object, but you need to verbally describe the fields you want in the object."
- "Oh hmm, this parameter is different than …"
- "Ah, looks like they don't support the … I don't see anything about …"
        # https://platform.moonshot.ai/docs/guide/migrating-from-openai-to-kimi#about-tool_choice
        # "Please note that the current version of Kimi API does not support the tool_choice=required parameter."
        return OpenAIModelProfile(
            json_schema_transformer=OpenAIJsonSchemaTransformer,
            openai_supports_strict_tool_definition=False,
            openai_supports_tool_choice_required=False,
        ).update(profile)

    # ---------------------------------------------------------------------
    # Construction helpers
    # ---------------------------------------------------------------------
    @overload
    def __init__(self) -> None: ...

    @overload
    def __init__(self, *, api_key: str) -> None: ...

    @overload
    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
    ) -> None:
        api_key = api_key or os.getenv('MOONSHOT_API_KEY')
        if not api_key and openai_client is None:
            raise UserError(
                'Set the `MOONSHOT_API_KEY` environment variable or pass it via '
                '`MoonshotAIProvider(api_key=...)` to use the MoonshotAI provider.'
            )

        if openai_client is not None:
            self._client = openai_client
        elif http_client is not None:
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
        else:
            http_client = cached_async_http_client(provider='moonshotai')
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
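For reviewers who want to try the branch end to end, here is a minimal usage sketch. It is not part of the diff; it assumes this branch is installed and `MOONSHOT_API_KEY` is set, and it reuses `OpenAIModel` plus the `kimi-k2-0711-preview` name from `MoonshotModelName` above:

```python
import asyncio

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.moonshotai import MoonshotAIProvider

# MoonshotAIProvider() reads MOONSHOT_API_KEY from the environment,
# exactly as the __init__ above implements.
model = OpenAIModel('kimi-k2-0711-preview', provider=MoonshotAIProvider())
agent = Agent(model)


async def main() -> None:
    result = await agent.run('Reply with a one-sentence greeting.')
    print(result.output)  # `result.data` on older pydantic-ai versions


if __name__ == '__main__':
    asyncio.run(main())
```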
New test file — tool_choice fallback (@@ -0,0 +1,56 @@):

"""Ensure `tool_choice='required'` is downgraded to `'auto'` when the profile says so."""

from __future__ import annotations

import types
from typing import Any

import pytest

from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.tools import ToolDefinition

from ..conftest import try_import

with try_import() as imports_successful:
    from pydantic_ai.models.openai import OpenAIModel
    from pydantic_ai.profiles.openai import OpenAIModelProfile


pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed')


@pytest.mark.anyio()
async def test_tool_choice_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setenv('OPENAI_API_KEY', 'dummy')

    model = OpenAIModel('stub', provider='openai')

    # Make profile report lack of `tool_choice='required'` support but keep sampling
    def fake_from_profile(_p: Any) -> types.SimpleNamespace:
        return types.SimpleNamespace(
            openai_supports_tool_choice_required=False,
            openai_supports_sampling_settings=True,
        )

    monkeypatch.setattr(OpenAIModelProfile, 'from_profile', fake_from_profile, raising=True)

    captured: dict[str, Any] = {}

    async def fake_create(*_a: Any, **kw: Any) -> dict[str, Any]:
        captured.update(kw)
        return {}

    # Patch chat completions create
    monkeypatch.setattr(model.client.chat.completions, 'create', fake_create, raising=True)

    params = ModelRequestParameters(function_tools=[ToolDefinition(name='x')], allow_text_output=False)

    await model._completions_create(  # pyright: ignore[reportPrivateUsage]
        messages=[],
        stream=False,
        model_settings={},
        model_request_parameters=params,
    )

    assert captured.get('tool_choice') == 'auto'
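The assertion above pins down the intended semantics; roughly, the request-building code is expected to do something like the following (an illustrative sketch only, not the actual logic in `models/openai.py`):

```python
# Illustrative only: when a profile reports that tool_choice='required'
# is unsupported, the request should fall back to 'auto' instead of
# sending a parameter the API would reject.
def resolve_tool_choice(tools_present: bool, allow_text_output: bool, supports_required: bool) -> str | None:
    if not tools_present:
        return None
    if not allow_text_output and supports_required:
        return 'required'
    return 'auto'


assert resolve_tool_choice(tools_present=True, allow_text_output=False, supports_required=False) == 'auto'
```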
New test file — MoonshotAI provider tests (@@ -0,0 +1,70 @@):

import re

import httpx
import pytest

from pydantic_ai.exceptions import UserError
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile

from ..conftest import TestEnv, try_import

with try_import() as imports_successful:
    import openai

    from pydantic_ai.models.openai import OpenAIModel
    from pydantic_ai.providers.moonshotai import MoonshotAIProvider

pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed')


def test_moonshotai_provider():
    """Test basic MoonshotAI provider initialization."""
    provider = MoonshotAIProvider(api_key='api-key')
    assert provider.name == 'moonshotai'
    assert provider.base_url == 'https://api.moonshot.ai/v1'
    assert isinstance(provider.client, openai.AsyncOpenAI)
    assert provider.client.api_key == 'api-key'


def test_moonshotai_provider_need_api_key(env: TestEnv) -> None:
    """Test that MoonshotAI provider requires an API key."""
    env.remove('MOONSHOT_API_KEY')
    with pytest.raises(
        UserError,
        match=re.escape(
            'Set the `MOONSHOT_API_KEY` environment variable or pass it via `MoonshotAIProvider(api_key=...)`'
            ' to use the MoonshotAI provider.'
        ),
    ):
        MoonshotAIProvider()


def test_moonshotai_provider_pass_http_client() -> None:
    """Test passing a custom HTTP client to MoonshotAI provider."""
    http_client = httpx.AsyncClient()
    provider = MoonshotAIProvider(http_client=http_client, api_key='api-key')
    assert provider.client._client == http_client  # type: ignore[reportPrivateUsage]


def test_moonshotai_pass_openai_client() -> None:
    """Test passing a custom OpenAI client to MoonshotAI provider."""
    openai_client = openai.AsyncOpenAI(api_key='api-key')
    provider = MoonshotAIProvider(openai_client=openai_client)
    assert provider.client == openai_client


def test_moonshotai_provider_with_cached_http_client() -> None:
    """Test MoonshotAI provider using cached HTTP client (covers line 76)."""
    # This should use the else branch with cached_async_http_client
    provider = MoonshotAIProvider(api_key='api-key')
    assert isinstance(provider.client, openai.AsyncOpenAI)
    assert provider.client.api_key == 'api-key'


def test_moonshotai_model_profile():
    provider = MoonshotAIProvider(api_key='api-key')
    model = OpenAIModel('kimi-k2-0711-preview', provider=provider)
    assert isinstance(model.profile, OpenAIModelProfile)
    assert model.profile.json_schema_transformer == OpenAIJsonSchemaTransformer
    assert model.profile.openai_supports_strict_tool_definition is False
    assert model.profile.openai_supports_tool_choice_required is False
[Review comment on the test file]
- "For consistency with the provider class: …"
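A closing note on the `http_client` path exercised by `test_moonshotai_provider_pass_http_client`: it is also the hook for proxies and custom timeouts. A hedged sketch (the proxy URL and timeout are placeholders, not anything from the PR):

```python
import httpx

from pydantic_ai.providers.moonshotai import MoonshotAIProvider

# Placeholder proxy and timeout values; any httpx.AsyncClient configuration
# can be passed through to the provider this way.
# Note: `proxy=` requires httpx >= 0.26; older versions use `proxies=`.
http_client = httpx.AsyncClient(
    proxy='http://proxy.example.internal:8080',
    timeout=httpx.Timeout(30.0),
)
provider = MoonshotAIProvider(api_key='api-key', http_client=http_client)
```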