Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/src/kimi_agent_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,13 @@ async def main() -> None:
from kimi_agent_sdk._approval import ApprovalHandlerFn
from kimi_agent_sdk._exception import PromptValidationError, SessionStateError
from kimi_agent_sdk._prompt import prompt
from kimi_agent_sdk._session import Session
from kimi_agent_sdk._session import Session, TokenStats

__all__ = [
# Core API
"prompt",
"Session",
"TokenStats",
# Approval
"ApprovalHandlerFn",
"ApprovalRequest",
Expand Down
67 changes: 66 additions & 1 deletion python/src/kimi_agent_sdk/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import inspect
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any

Expand All @@ -11,7 +12,7 @@
from kimi_cli.config import Config
from kimi_cli.session import Session as CliSession
from kimi_cli.soul import StatusSnapshot
from kimi_cli.wire.types import ContentPart, WireMessage
from kimi_cli.wire.types import ContentPart, StatusUpdate, TokenUsage, WireMessage

from kimi_agent_sdk._exception import SessionStateError

Expand All @@ -24,6 +25,58 @@ def _ensure_type(name: str, value: object, expected: type) -> None:
raise TypeError(f"{name} must be {expected.__name__}, got {type(value).__name__}")


@dataclass
class TokenStats:
"""Cumulative token usage statistics for a session.

Accumulates token usage across multiple prompts in a session.
"""

_input_other: int = field(default=0, init=False, repr=False)
_output: int = field(default=0, init=False, repr=False)
_input_cache_read: int = field(default=0, init=False, repr=False)
_input_cache_creation: int = field(default=0, init=False, repr=False)

@property
def input_other(self) -> int:
"""Non-cached input tokens."""
return self._input_other

@property
def output(self) -> int:
"""Output tokens."""
return self._output

@property
def input_cache_read(self) -> int:
"""Cache read tokens."""
return self._input_cache_read

@property
def input_cache_creation(self) -> int:
"""Cache creation tokens."""
return self._input_cache_creation

def add(self, usage: TokenUsage | None) -> None:
"""Add a TokenUsage to the cumulative stats."""
if usage is None:
return
self._input_other += usage.input_other
self._output += usage.output
self._input_cache_read += usage.input_cache_read
self._input_cache_creation += usage.input_cache_creation

@property
def input(self) -> int:
"""Total input tokens (including cache)."""
return self._input_other + self._input_cache_read + self._input_cache_creation

@property
def total(self) -> int:
"""Total tokens (input + output)."""
return self.input + self._output


class Session:
"""
Kimi Agent session with low-level control.
Expand All @@ -36,6 +89,7 @@ def __init__(self, cli: KimiCLI) -> None:
self._cli = cli
self._cancel_event: asyncio.Event | None = None
self._closed = False
self._token_stats = TokenStats()

@staticmethod
async def create(
Expand Down Expand Up @@ -198,6 +252,14 @@ def status(self) -> StatusSnapshot:
"""Current status snapshot (context usage, yolo state, etc.)."""
return self._cli.soul.status

@property
def token_stats(self) -> TokenStats:
"""Cumulative token usage statistics for this session.

Accumulates token usage across all prompts sent in this session.
"""
return self._token_stats

async def prompt(
self,
user_input: str | list[ContentPart],
Expand Down Expand Up @@ -237,6 +299,9 @@ async def prompt(
cancel_event,
merge_wire_messages=merge_wire_messages,
):
# Accumulate token usage from StatusUpdate messages
if isinstance(msg, StatusUpdate) and msg.token_usage is not None:
self._token_stats.add(msg.token_usage)
yield msg
finally:
if self._cancel_event is cancel_event:
Expand Down
139 changes: 139 additions & 0 deletions python/tests/test_token_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""Tests for TokenStats — session-level token usage accumulation."""

from __future__ import annotations

import pytest
from kimi_cli.wire.types import TokenUsage

from kimi_agent_sdk._session import TokenStats

# ─── TokenStats standalone tests ─────────────────────────────────────────────


def test_initial_zero() -> None:
"""TokenStats starts at zero."""
stats = TokenStats()
assert stats.input_other == 0
assert stats.output == 0
assert stats.input_cache_read == 0
assert stats.input_cache_creation == 0
assert stats.input == 0
assert stats.total == 0


def test_add_accumulates() -> None:
"""add() accumulates TokenUsage."""
stats = TokenStats()
usage = TokenUsage(
input_other=100,
output=50,
input_cache_read=10,
input_cache_creation=5,
)

stats.add(usage)

assert stats.input_other == 100
assert stats.output == 50
assert stats.input_cache_read == 10
assert stats.input_cache_creation == 5
assert stats.input == 115 # 100 + 10 + 5
assert stats.total == 165 # 115 + 50


def test_add_multiple_times() -> None:
"""add() can be called multiple times to accumulate."""
stats = TokenStats()

usage1 = TokenUsage(input_other=100, output=50)
usage2 = TokenUsage(
input_other=200,
output=100,
input_cache_read=50,
)

stats.add(usage1)
stats.add(usage2)

assert stats.input_other == 300
assert stats.output == 150
assert stats.input_cache_read == 50
assert stats.total == 500 # 300 + 150 + 50


def test_add_none_is_noop() -> None:
"""add(None) is a no-op."""
stats = TokenStats()
stats.add(TokenUsage(input_other=100, output=0))

stats.add(None)

assert stats.input_other == 100
assert stats.total == 100


def test_properties_readonly() -> None:
"""TokenStats properties are read-only."""
stats = TokenStats()

# Can read
_ = stats.input_other
_ = stats.output
_ = stats.input_cache_read
_ = stats.input_cache_creation
_ = stats.input
_ = stats.total

# Cannot set (AttributeError)
with pytest.raises(AttributeError):
stats.input_other = 100 # type: ignore[misc]


def test_input_property() -> None:
"""input property sums all input tokens."""
stats = TokenStats()
stats.add(
TokenUsage(
input_other=100,
output=0,
input_cache_read=20,
input_cache_creation=5,
)
)

assert stats.input == 125


def test_total_property() -> None:
"""total property sums input and output tokens."""
stats = TokenStats()
stats.add(
TokenUsage(
input_other=100,
output=50,
input_cache_read=20,
)
)

assert stats.input == 120
assert stats.total == 170


def test_add_token_usage_with_all_fields() -> None:
"""TokenUsage with all fields populated works correctly."""
stats = TokenStats()
usage = TokenUsage(
input_other=1000,
output=500,
input_cache_read=200,
input_cache_creation=50,
)

stats.add(usage)

assert stats.input_other == 1000
assert stats.output == 500
assert stats.input_cache_read == 200
assert stats.input_cache_creation == 50
assert stats.input == 1250 # 1000 + 200 + 50
assert stats.total == 1750 # 1250 + 500