From 25c4a42c8baeff3e3304a07c9a0c6ae3679f347f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=99=86=E9=80=8A?=
Date: Tue, 17 Mar 2026 23:53:22 +0800
Subject: [PATCH] feat(session): add TokenStats for session-level token usage accumulation

- Add TokenStats dataclass to track cumulative token usage per session
- Expose session.token_stats property for callers to access stats
- Accumulate token usage from StatusUpdate wire messages during prompt()
- Add 8 tests covering initial zero, accumulation, properties, and edge cases

The SDK now provides data-level token statistics without handling UI display,
allowing callers to decide how to present the information.
---
 python/src/kimi_agent_sdk/__init__.py |   3 +-
 python/src/kimi_agent_sdk/_session.py |  67 ++++++++++++-
 python/tests/test_token_stats.py      | 139 ++++++++++++++++++++++++++
 3 files changed, 207 insertions(+), 2 deletions(-)
 create mode 100644 python/tests/test_token_stats.py

diff --git a/python/src/kimi_agent_sdk/__init__.py b/python/src/kimi_agent_sdk/__init__.py
index a021188..0e96e4a 100644
--- a/python/src/kimi_agent_sdk/__init__.py
+++ b/python/src/kimi_agent_sdk/__init__.py
@@ -122,12 +122,13 @@ async def main() -> None:
 from kimi_agent_sdk._approval import ApprovalHandlerFn
 from kimi_agent_sdk._exception import PromptValidationError, SessionStateError
 from kimi_agent_sdk._prompt import prompt
-from kimi_agent_sdk._session import Session
+from kimi_agent_sdk._session import Session, TokenStats
 
 __all__ = [
     # Core API
     "prompt",
     "Session",
+    "TokenStats",
     # Approval
     "ApprovalHandlerFn",
     "ApprovalRequest",
diff --git a/python/src/kimi_agent_sdk/_session.py b/python/src/kimi_agent_sdk/_session.py
index 10ad26d..7550876 100644
--- a/python/src/kimi_agent_sdk/_session.py
+++ b/python/src/kimi_agent_sdk/_session.py
@@ -3,6 +3,7 @@
 import asyncio
 import inspect
 from collections.abc import AsyncGenerator
+from dataclasses import dataclass, field
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
 
@@ -11,7 +12,7 @@
 from kimi_cli.config import Config
 from kimi_cli.session import Session as CliSession
 from kimi_cli.soul import StatusSnapshot
-from kimi_cli.wire.types import ContentPart, WireMessage
+from kimi_cli.wire.types import ContentPart, StatusUpdate, TokenUsage, WireMessage
 
 from kimi_agent_sdk._exception import SessionStateError
 
@@ -24,6 +25,58 @@ def _ensure_type(name: str, value: object, expected: type) -> None:
         raise TypeError(f"{name} must be {expected.__name__}, got {type(value).__name__}")
 
 
+@dataclass
+class TokenStats:
+    """Cumulative token usage statistics for a session.
+
+    Accumulates token usage across multiple prompts in a session.
+    """
+
+    _input_other: int = field(default=0, init=False, repr=False)
+    _output: int = field(default=0, init=False, repr=False)
+    _input_cache_read: int = field(default=0, init=False, repr=False)
+    _input_cache_creation: int = field(default=0, init=False, repr=False)
+
+    @property
+    def input_other(self) -> int:
+        """Accumulated non-cached input tokens."""
+        return self._input_other
+
+    @property
+    def output(self) -> int:
+        """Accumulated output tokens."""
+        return self._output
+
+    @property
+    def input_cache_read(self) -> int:
+        """Accumulated cache-read input tokens."""
+        return self._input_cache_read
+
+    @property
+    def input_cache_creation(self) -> int:
+        """Accumulated cache-creation input tokens."""
+        return self._input_cache_creation
+
+    def add(self, usage: TokenUsage | None) -> None:
+        """Add a TokenUsage to the cumulative stats; ``None`` is a no-op."""
+        if usage is None:
+            return
+        self._input_other += usage.input_other
+        self._output += usage.output
+        self._input_cache_read += usage.input_cache_read
+        self._input_cache_creation += usage.input_cache_creation
+
+    @property
+    def input(self) -> int:
+        """Total input tokens (non-cached + cache read + cache creation)."""
+        return self._input_other + self._input_cache_read + self._input_cache_creation
+
+    @property
+    def total(self) -> int:
+        """Total tokens (input + output)."""
+        return self.input + self._output
+
+
 class Session:
     """
     Kimi Agent session with low-level control.
@@ -36,6 +89,7 @@ def __init__(self, cli: KimiCLI) -> None:
         self._cli = cli
         self._cancel_event: asyncio.Event | None = None
         self._closed = False
+        self._token_stats = TokenStats()
 
     @staticmethod
     async def create(
@@ -198,6 +252,14 @@ def status(self) -> StatusSnapshot:
         """Current status snapshot (context usage, yolo state, etc.)."""
         return self._cli.soul.status
 
+    @property
+    def token_stats(self) -> TokenStats:
+        """Cumulative token usage statistics for this session.
+
+        Accumulates token usage across all prompts sent in this session.
+        """
+        return self._token_stats
+
     async def prompt(
         self,
         user_input: str | list[ContentPart],
@@ -237,6 +299,9 @@ async def prompt(
                 cancel_event,
                 merge_wire_messages=merge_wire_messages,
             ):
+                # Accumulate token usage from StatusUpdate messages
+                if isinstance(msg, StatusUpdate) and msg.token_usage is not None:
+                    self._token_stats.add(msg.token_usage)
                 yield msg
         finally:
             if self._cancel_event is cancel_event:
diff --git a/python/tests/test_token_stats.py b/python/tests/test_token_stats.py
new file mode 100644
index 0000000..83d0dc2
--- /dev/null
+++ b/python/tests/test_token_stats.py
@@ -0,0 +1,139 @@
+"""Tests for TokenStats — session-level token usage accumulation."""
+
+from __future__ import annotations
+
+import pytest
+from kimi_cli.wire.types import TokenUsage
+
+from kimi_agent_sdk._session import TokenStats
+
+# ─── TokenStats standalone tests ─────────────────────────────────────────────
+
+
+def test_initial_zero() -> None:
+    """TokenStats starts at zero."""
+    stats = TokenStats()
+    assert stats.input_other == 0
+    assert stats.output == 0
+    assert stats.input_cache_read == 0
+    assert stats.input_cache_creation == 0
+    assert stats.input == 0
+    assert stats.total == 0
+
+
+def test_add_accumulates() -> None:
+    """add() accumulates TokenUsage."""
+    stats = TokenStats()
+    usage = TokenUsage(
+        input_other=100,
+        output=50,
+        input_cache_read=10,
+        input_cache_creation=5,
+    )
+
+    stats.add(usage)
+
+    assert stats.input_other == 100
+    assert stats.output == 50
+    assert stats.input_cache_read == 10
+    assert stats.input_cache_creation == 5
+    assert stats.input == 115  # 100 + 10 + 5
+    assert stats.total == 165  # 115 + 50
+
+
+def test_add_multiple_times() -> None:
+    """add() can be called multiple times to accumulate."""
+    stats = TokenStats()
+
+    usage1 = TokenUsage(input_other=100, output=50)
+    usage2 = TokenUsage(
+        input_other=200,
+        output=100,
+        input_cache_read=50,
+    )
+
+    stats.add(usage1)
+    stats.add(usage2)
+
+    assert stats.input_other == 300
+    assert stats.output == 150
+    assert stats.input_cache_read == 50
+    assert stats.total == 500  # 300 + 150 + 50
+
+
+def test_add_none_is_noop() -> None:
+    """add(None) is a no-op."""
+    stats = TokenStats()
+    stats.add(TokenUsage(input_other=100, output=0))
+
+    stats.add(None)
+
+    assert stats.input_other == 100
+    assert stats.total == 100
+
+
+def test_properties_readonly() -> None:
+    """TokenStats properties are read-only."""
+    stats = TokenStats()
+
+    # Can read
+    _ = stats.input_other
+    _ = stats.output
+    _ = stats.input_cache_read
+    _ = stats.input_cache_creation
+    _ = stats.input
+    _ = stats.total
+
+    # Cannot set (AttributeError)
+    with pytest.raises(AttributeError):
+        stats.input_other = 100  # type: ignore[misc]
+
+
+def test_input_property() -> None:
+    """input property sums all input tokens."""
+    stats = TokenStats()
+    stats.add(
+        TokenUsage(
+            input_other=100,
+            output=0,
+            input_cache_read=20,
+            input_cache_creation=5,
+        )
+    )
+
+    assert stats.input == 125
+
+
+def test_total_property() -> None:
+    """total property sums input and output tokens."""
+    stats = TokenStats()
+    stats.add(
+        TokenUsage(
+            input_other=100,
+            output=50,
+            input_cache_read=20,
+        )
+    )
+
+    assert stats.input == 120
+    assert stats.total == 170
+
+
+def test_add_token_usage_with_all_fields() -> None:
+    """TokenUsage with all fields populated works correctly."""
+    stats = TokenStats()
+    usage = TokenUsage(
+        input_other=1000,
+        output=500,
+        input_cache_read=200,
+        input_cache_creation=50,
+    )
+
+    stats.add(usage)
+
+    assert stats.input_other == 1000
+    assert stats.output == 500
+    assert stats.input_cache_read == 200
+    assert stats.input_cache_creation == 50
+    assert stats.input == 1250  # 1000 + 200 + 50
+    assert stats.total == 1750  # 1250 + 500