Skip to content

feat: add MoonshotAI provider with Kimi-K2 model support #2211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
18 changes: 18 additions & 0 deletions docs/models/openai.md
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,24 @@ agent = Agent(model)
...
```

### MoonshotAI

Create an API key in the [Moonshot Console](https://platform.moonshot.ai/console).
With that key you can instantiate the [`MoonshotAIProvider`][pydantic_ai.providers.moonshotai.MoonshotAIProvider]:

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.moonshotai import MoonshotAIProvider

model = OpenAIModel(
'kimi-k2-0711-preview',
provider=MoonshotAIProvider(api_key='your-moonshot-api-key'),
)
agent = Agent(model)
...
```

### GitHub Models

To use [GitHub Models](https://docs.github.com/en/github-models), you'll need a GitHub personal access token with the `models: read` permission.
Expand Down
10 changes: 10 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,15 @@
'mistral:mistral-large-latest',
'mistral:mistral-moderation-latest',
'mistral:mistral-small-latest',
'moonshotai:moonshot-v1-8k',
'moonshotai:moonshot-v1-32k',
'moonshotai:moonshot-v1-128k',
'moonshotai:moonshot-v1-8k-vision-preview',
'moonshotai:moonshot-v1-32k-vision-preview',
'moonshotai:moonshot-v1-128k-vision-preview',
'moonshotai:kimi-latest',
'moonshotai:kimi-thinking-preview',
'moonshotai:kimi-k2-0711-preview',
'o1',
'o1-2024-12-17',
'o1-mini',
Expand Down Expand Up @@ -607,6 +616,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
'azure',
'openrouter',
'grok',
'moonshotai',
'fireworks',
'together',
'heroku',
Expand Down
16 changes: 14 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,16 @@ def __init__(
model_name: OpenAIModelName,
*,
provider: Literal[
'openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku', 'github'
'openai',
'deepseek',
'azure',
'openrouter',
'grok',
'moonshotai',
'fireworks',
'together',
'heroku',
'github',
]
| Provider[AsyncOpenAI] = 'openai',
profile: ModelProfileSpec | None = None,
Expand Down Expand Up @@ -289,7 +298,10 @@ async def _completions_create(
tools = self._get_tools(model_request_parameters)
if not tools:
tool_choice: Literal['none', 'required', 'auto'] | None = None
elif not model_request_parameters.allow_text_output:
elif (
not model_request_parameters.allow_text_output
and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required
):
tool_choice = 'required'
else:
tool_choice = 'auto'
Expand Down
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/profiles/moonshotai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from __future__ import annotations as _annotations

from . import ModelProfile


def moonshotai_model_profile(model_name: str) -> ModelProfile | None:
    """Get the model profile for a MoonshotAI model.

    Currently returns `None` unconditionally: no MoonshotAI model needs
    per-model overrides, so callers (the MoonshotAI provider) fall back to
    their own OpenAI-compatible defaults.
    """
    return None
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/profiles/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ class OpenAIModelProfile(ModelProfile):
openai_supports_sampling_settings: bool = True
"""Turn off to don't send sampling settings like `temperature` and `top_p` to models that don't support them, like OpenAI's o-series reasoning models."""

# Some OpenAI-compatible providers (e.g. MoonshotAI) currently do **not** accept
# `tool_choice="required"`. This flag lets the calling model know whether it's
# safe to pass that value along. Default is `True` to preserve existing
# behaviour for OpenAI itself and most providers.
openai_supports_tool_choice_required: bool = True
"""Whether the provider accepts the value ``tool_choice='required'`` in the
request payload."""


def openai_model_profile(model_name: str) -> ModelProfile:
"""Get the model profile for an OpenAI model."""
Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
from .grok import GrokProvider

return GrokProvider
elif provider == 'moonshotai':
from .moonshotai import MoonshotAIProvider

return MoonshotAIProvider
elif provider == 'fireworks':
from .fireworks import FireworksProvider

Expand Down
97 changes: 97 additions & 0 deletions pydantic_ai_slim/pydantic_ai/providers/moonshotai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from __future__ import annotations as _annotations

import os
from typing import Literal, overload

from httpx import AsyncClient as AsyncHTTPClient
from openai import AsyncOpenAI

from pydantic_ai.exceptions import UserError
from pydantic_ai.models import cached_async_http_client
from pydantic_ai.profiles import ModelProfile
from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
from pydantic_ai.profiles.openai import (
OpenAIJsonSchemaTransformer,
OpenAIModelProfile,
)
from pydantic_ai.providers import Provider

# Named `MoonshotAIModelName` for consistency with `MoonshotAIProvider`.
MoonshotAIModelName = Literal[
    'moonshot-v1-8k',
    'moonshot-v1-32k',
    'moonshot-v1-128k',
    'moonshot-v1-8k-vision-preview',
    'moonshot-v1-32k-vision-preview',
    'moonshot-v1-128k-vision-preview',
    'kimi-latest',
    'kimi-thinking-preview',
    'kimi-k2-0711-preview',
]

# Backward-compatible alias: existing code (e.g. tests) imports this name.
MoonshotModelName = MoonshotAIModelName


class MoonshotAIProvider(Provider[AsyncOpenAI]):
    """Provider for the MoonshotAI platform (Kimi models).

    MoonshotAI exposes an OpenAI-compatible chat completions API, so this
    provider configures an `AsyncOpenAI` client against the Moonshot endpoint
    and supplies a model profile describing the provider's deviations from
    OpenAI proper.
    """

    @property
    def name(self) -> str:
        """Provider identifier, as used in `moonshotai:<model>` model strings."""
        return 'moonshotai'

    @property
    def base_url(self) -> str:
        """OpenAI-compatible endpoint, see MoonshotAI docs."""
        return 'https://api.moonshot.ai/v1'

    @property
    def client(self) -> AsyncOpenAI:
        """The underlying OpenAI-compatible async client."""
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        base = moonshotai_model_profile(model_name)

        # The API is OpenAI-compatible, so default to OpenAIJsonSchemaTransformer
        # unless json_schema_transformer is set explicitly by the model profile.
        # Known deviations from OpenAI proper:
        #   * strict tool definitions are not supported
        #   * tool_choice='required' is rejected, see
        #     https://platform.moonshot.ai/docs/guide/migrating-from-openai-to-kimi#about-tool_choice
        #     "Please note that the current version of Kimi API does not support the tool_choice=required parameter."
        defaults = OpenAIModelProfile(
            json_schema_transformer=OpenAIJsonSchemaTransformer,
            openai_supports_strict_tool_definition=False,
            openai_supports_tool_choice_required=False,
        )
        return defaults.update(base)

    @overload
    def __init__(self) -> None: ...

    @overload
    def __init__(self, *, api_key: str) -> None: ...

    @overload
    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
    ) -> None:
        """Create a MoonshotAI provider.

        Args:
            api_key: Moonshot API key; falls back to the `MOONSHOT_API_KEY`
                environment variable when omitted.
            openai_client: A pre-configured `AsyncOpenAI` client to use as-is;
                when given, no API key is required here.
            http_client: Custom `httpx.AsyncClient` for the underlying OpenAI
                client; a shared cached client is used by default.

        Raises:
            UserError: If no API key is available and no `openai_client` is given.
        """
        api_key = api_key or os.getenv('MOONSHOT_API_KEY')
        if openai_client is None and not api_key:
            raise UserError(
                'Set the `MOONSHOT_API_KEY` environment variable or pass it via '
                '`MoonshotAIProvider(api_key=...)` to use the MoonshotAI provider.'
            )

        if openai_client is not None:
            self._client = openai_client
            return

        if http_client is None:
            http_client = cached_async_http_client(provider='moonshotai')
        self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ interactions:
response:
headers:
content-length:
- '545'
- '550'
content-security-policy:
- default-src 'none'; frame-ancestors 'none'
content-type:
Expand Down Expand Up @@ -46,6 +46,7 @@ interactions:
- text-to-text
- model_id: claude-4-sonnet
regions:
- eu
- us
type:
- text-to-text
Expand Down
8 changes: 8 additions & 0 deletions tests/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@
'github',
'OpenAIModel',
),
(
'MOONSHOT_API_KEY',
'moonshotai:kimi-k2-0711-preview',
'kimi-k2-0711-preview',
'moonshotai',
'moonshotai',
'OpenAIModel',
),
]


Expand Down
3 changes: 3 additions & 0 deletions tests/models/test_model_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pydantic_ai.models.huggingface import HuggingFaceModelName
from pydantic_ai.models.mistral import MistralModelName
from pydantic_ai.models.openai import OpenAIModelName
from pydantic_ai.providers.moonshotai import MoonshotModelName

pytestmark = [
pytest.mark.skipif(not imports_successful(), reason='some model package was not installed'),
Expand Down Expand Up @@ -49,6 +50,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]:
f'google-vertex:{n}' for n in get_model_names(GeminiModelName)
]
groq_names = [f'groq:{n}' for n in get_model_names(GroqModelName)]
moonshotai_names = [f'moonshotai:{n}' for n in get_model_names(MoonshotModelName)]
mistral_names = [f'mistral:{n}' for n in get_model_names(MistralModelName)]
openai_names = [f'openai:{n}' for n in get_model_names(OpenAIModelName)] + [
n for n in get_model_names(OpenAIModelName) if n.startswith('o1') or n.startswith('gpt') or n.startswith('o3')
Expand All @@ -65,6 +67,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]:
+ google_names
+ groq_names
+ mistral_names
+ moonshotai_names
+ openai_names
+ bedrock_names
+ deepseek_names
Expand Down
56 changes: 56 additions & 0 deletions tests/models/test_tool_choice_fallback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Ensure `tool_choice='required'` is downgraded to `'auto'` when the profile says so."""

from __future__ import annotations

import types
from typing import Any

import pytest

from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.tools import ToolDefinition

from ..conftest import try_import

with try_import() as imports_successful:
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.profiles.openai import OpenAIModelProfile


pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed')


@pytest.mark.anyio()
async def test_tool_choice_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
    """When the profile disallows `tool_choice='required'`, the model sends `'auto'` instead."""
    # Dummy key so the provider can be constructed without a real credential.
    monkeypatch.setenv('OPENAI_API_KEY', 'dummy')

    model = OpenAIModel('stub', provider='openai')

    # Make profile report lack of `tool_choice='required'` support but keep sampling.
    # A SimpleNamespace stands in for the profile: only the two attributes the
    # request path reads are provided — TODO confirm no other profile attributes
    # are accessed on this code path.
    def fake_from_profile(_p: Any) -> types.SimpleNamespace:
        return types.SimpleNamespace(
            openai_supports_tool_choice_required=False,
            openai_supports_sampling_settings=True,
        )

    monkeypatch.setattr(OpenAIModelProfile, 'from_profile', fake_from_profile, raising=True)

    # Collect the keyword arguments that would have been sent to the API.
    captured: dict[str, Any] = {}

    async def fake_create(*_a: Any, **kw: Any) -> dict[str, Any]:
        captured.update(kw)
        return {}

    # Patch chat completions create so no network request is made.
    monkeypatch.setattr(model.client.chat.completions, 'create', fake_create, raising=True)

    # With tools present and text output disallowed, the model would normally
    # request tool_choice='required'.
    params = ModelRequestParameters(function_tools=[ToolDefinition(name='x')], allow_text_output=False)

    await model._completions_create(  # pyright: ignore[reportPrivateUsage]
        messages=[],
        stream=False,
        model_settings={},
        model_request_parameters=params,
    )

    # The profile flag must have downgraded 'required' to 'auto'.
    assert captured.get('tool_choice') == 'auto'
70 changes: 70 additions & 0 deletions tests/providers/test_moonshotai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import re

import httpx
import pytest

from pydantic_ai.exceptions import UserError
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile

from ..conftest import TestEnv, try_import

with try_import() as imports_successful:
import openai

from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.moonshotai import MoonshotAIProvider

pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed')


def test_moonshotai_provider():
    """The provider exposes the expected name, base URL and configured client."""
    provider = MoonshotAIProvider(api_key='api-key')
    client = provider.client
    assert isinstance(client, openai.AsyncOpenAI)
    assert client.api_key == 'api-key'
    assert provider.name == 'moonshotai'
    assert provider.base_url == 'https://api.moonshot.ai/v1'


def test_moonshotai_provider_need_api_key(env: TestEnv) -> None:
    """Without `MOONSHOT_API_KEY` and no explicit key, construction raises `UserError`."""
    env.remove('MOONSHOT_API_KEY')
    expected_message = (
        'Set the `MOONSHOT_API_KEY` environment variable or pass it via `MoonshotAIProvider(api_key=...)`'
        ' to use the MoonshotAI provider.'
    )
    with pytest.raises(UserError, match=re.escape(expected_message)):
        MoonshotAIProvider()


def test_moonshotai_provider_pass_http_client() -> None:
    """A caller-supplied httpx client is wired into the underlying OpenAI client."""
    custom_http_client = httpx.AsyncClient()
    provider = MoonshotAIProvider(http_client=custom_http_client, api_key='api-key')
    assert provider.client._client == custom_http_client  # type: ignore[reportPrivateUsage]


def test_moonshotai_pass_openai_client() -> None:
    """An explicitly provided AsyncOpenAI client is used verbatim."""
    custom_openai_client = openai.AsyncOpenAI(api_key='api-key')
    provider = MoonshotAIProvider(openai_client=custom_openai_client)
    assert provider.client == custom_openai_client


def test_moonshotai_provider_with_cached_http_client() -> None:
    """Test MoonshotAI provider using the cached shared HTTP client.

    With neither `http_client` nor `openai_client` supplied, the provider falls
    back to `cached_async_http_client(provider='moonshotai')`.
    """
    # This should use the else branch with cached_async_http_client
    provider = MoonshotAIProvider(api_key='api-key')
    assert isinstance(provider.client, openai.AsyncOpenAI)
    assert provider.client.api_key == 'api-key'


def test_moonshotai_model_profile():
    """The provider's model profile carries the MoonshotAI-specific OpenAI flags."""
    provider = MoonshotAIProvider(api_key='api-key')
    model = OpenAIModel('kimi-k2-0711-preview', provider=provider)
    assert isinstance(model.profile, OpenAIModelProfile)
    # OpenAI-compatible API, so the OpenAI JSON schema transformer applies.
    assert model.profile.json_schema_transformer == OpenAIJsonSchemaTransformer
    # MoonshotAI deviations: no strict tool definitions, no tool_choice='required'.
    assert model.profile.openai_supports_strict_tool_definition is False
    assert model.profile.openai_supports_tool_choice_required is False
Loading