Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to kotlin style callbacks. Close #14 #36

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Add the dependency:

```kotlin
dependencies {
implementation("io.modelcontextprotocol:kotlin-sdk:0.2.0")
implementation("io.modelcontextprotocol:kotlin-sdk:0.3.0")
}
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ public class SSEClientTransport(
private var session: ClientSSESession by Delegates.notNull()
private val endpoint = CompletableDeferred<String>()

override var onClose: (() -> Unit)? = null
override var onError: ((Throwable) -> Unit)? = null
override var onMessage: (suspend ((JSONRPCMessage) -> Unit))? = null
private var _onClose: (() -> Unit) = {}
private var _onError: ((Throwable) -> Unit) = {}
private var _onMessage: (suspend ((JSONRPCMessage) -> Unit)) = {}

private var job: Job? = null

Expand Down Expand Up @@ -67,7 +67,7 @@ public class SSEClientTransport(
when (event.event) {
"error" -> {
val e = IllegalStateException("SSE error: ${event.data}")
onError?.invoke(e)
_onError(e)
throw e
}

Expand All @@ -84,7 +84,7 @@ public class SSEClientTransport(

endpoint.complete(maybeEndpoint.toString())
} catch (e: Exception) {
onError?.invoke(e)
_onError(e)
close()
error(e)
}
Expand All @@ -93,9 +93,9 @@ public class SSEClientTransport(
else -> {
try {
val message = McpJson.decodeFromString<JSONRPCMessage>(event.data ?: "")
onMessage?.invoke(message)
_onMessage(message)
} catch (e: Exception) {
onError?.invoke(e)
_onError(e)
}
}
}
Expand All @@ -122,7 +122,7 @@ public class SSEClientTransport(
error("Error POSTing to endpoint (HTTP ${response.status}): $text")
}
} catch (e: Exception) {
onError?.invoke(e)
_onError(e)
throw e
}
}
Expand All @@ -133,7 +133,31 @@ public class SSEClientTransport(
}

session.cancel()
onClose?.invoke()
_onClose()
job?.cancelAndJoin()
}

override fun onClose(block: () -> Unit) {
val old = _onClose
_onClose = {
old()
block()
}
}

override fun onError(block: (Throwable) -> Unit) {
val old = _onError
_onError = { e ->
old(e)
block(e)
}
}

override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) {
val old = _onMessage
_onMessage = { message ->
old(message)
block(message)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ public class ServerOptions(
public open class Server(
private val serverInfo: Implementation,
options: ServerOptions,
public var onCloseCallback: (() -> Unit)? = null
) : Protocol(options) {
private var _onInitialized: (() -> Unit) = {}
private var _onClose: () -> Unit = {}

/**
* The client's reported capabilities after initialization.
Expand All @@ -49,18 +50,13 @@ public open class Server(
*/
public var clientVersion: Implementation? = null
private set

private val capabilities: ServerCapabilities = options.capabilities

private val tools = mutableMapOf<String, RegisteredTool>()
private val prompts = mutableMapOf<String, RegisteredPrompt>()
private val resources = mutableMapOf<String, RegisteredResource>()

/**
* A callback invoked when the server has completed the initialization sequence.
* After initialization, the server is ready to handle requests.
*/
public var onInitialized: (() -> Unit)? = null

init {
logger.debug { "Initializing MCP server with capabilities: $capabilities" }

Expand All @@ -69,7 +65,7 @@ public open class Server(
handleInitialize(request)
}
setNotificationHandler<InitializedNotification>(Method.Defined.NotificationsInitialized) {
onInitialized?.invoke()
_onInitialized()
CompletableDeferred(Unit)
}

Expand Down Expand Up @@ -107,13 +103,35 @@ public open class Server(
}
}

/**
* Registers a callback to be invoked when the server has completed initialization.
*/
public fun onInitalized(block: () -> Unit) {
val old = _onInitialized
_onInitialized = {
old()
block()
}
}

/**
* Registers a callback to be invoked when the server connection is closing.
*/
public fun onClose(block: () -> Unit) {
val old = _onClose
_onClose = {
old()
block()
}
}

/**
* Called when the server connection is closing.
* Invokes [onCloseCallback] if set.
*/
override fun onClose() {
logger.info { "Server connection closing" }
onCloseCallback?.invoke()
_onClose()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,18 @@ public interface Transport {
*
* This should be invoked when close() is called as well.
*/
public var onClose: (() -> Unit)?
public fun onClose(block: () -> Unit)

/**
* Callback for when an error occurs.
*
* Note that errors are not necessarily fatal; they are used for reporting any kind of
* exceptional condition out of band.
*/
public var onError: ((Throwable) -> Unit)?
public fun onError(block: (Throwable) -> Unit)

/**
* Callback for when a message (request or response) is received over the connection.
*/
public var onMessage: (suspend ((JSONRPCMessage) -> Unit))?
public fun onMessage(block: suspend (JSONRPCMessage) -> Unit)
}
Loading