Skip to content

Commit d8fcc13

Browse files
author
Casey Quinn
committed
test(shared): assert method errors in helpers
1 parent f5bc4e3 commit d8fcc13

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

tests/shared/test_method_errors.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ def _assert_error(error: types.JSONRPCError, expected_code: int, expected_messag
2020
_ensure(error_payload.message == expected_message, f"unexpected error message: {error_payload.message}")
2121

2222

23-
async def _run_client_request(request: types.JSONRPCRequest) -> types.JSONRPCError:
23+
async def _run_client_request(
24+
request: types.JSONRPCRequest, *, expected_error: tuple[int, str] | None = None
25+
) -> types.JSONRPCError:
2426
request_send, request_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
2527
response_send, response_receive = anyio.create_memory_object_stream[SessionMessage](1)
2628

@@ -44,10 +46,18 @@ async def _run_client_request(request: types.JSONRPCRequest) -> types.JSONRPCErr
4446

4547
root = response_message.message.root
4648
_ensure(isinstance(root, types.JSONRPCError), "expected a JSON-RPC error response")
47-
return cast(types.JSONRPCError, root)
49+
error = cast(types.JSONRPCError, root)
4850

51+
if expected_error is not None:
52+
expected_code, expected_message = expected_error
53+
_assert_error(error, expected_code, expected_message)
4954

50-
async def _run_server_request(request: types.JSONRPCRequest) -> types.JSONRPCError:
55+
return error
56+
57+
58+
async def _run_server_request(
59+
request: types.JSONRPCRequest, *, expected_error: tuple[int, str] | None = None
60+
) -> types.JSONRPCError:
5161
request_send, request_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
5262
response_send, response_receive = anyio.create_memory_object_stream[SessionMessage](1)
5363

@@ -71,7 +81,13 @@ async def _run_server_request(request: types.JSONRPCRequest) -> types.JSONRPCErr
7181

7282
root = response_message.message.root
7383
_ensure(isinstance(root, types.JSONRPCError), "expected a JSON-RPC error response")
74-
return cast(types.JSONRPCError, root)
84+
error = cast(types.JSONRPCError, root)
85+
86+
if expected_error is not None:
87+
expected_code, expected_message = expected_error
88+
_assert_error(error, expected_code, expected_message)
89+
90+
return error
7591

7692

7793
@pytest.mark.anyio
@@ -85,27 +101,21 @@ async def _run_server_request(request: types.JSONRPCRequest) -> types.JSONRPCErr
85101
async def test_client_to_server_unknown_method_returns_method_not_found(method: str, request_id: int) -> None:
86102
request = types.JSONRPCRequest(jsonrpc="2.0", id=request_id, method=method, params=None)
87103

88-
error = await _run_client_request(request)
89-
90-
_assert_error(error, types.METHOD_NOT_FOUND, "Method not found")
104+
await _run_client_request(request, expected_error=(types.METHOD_NOT_FOUND, "Method not found"))
91105

92106

93107
@pytest.mark.anyio
94108
async def test_client_to_server_invalid_params_returns_invalid_params() -> None:
95109
request = types.JSONRPCRequest(jsonrpc="2.0", id=2, method="resources/read", params={})
96110

97-
error = await _run_client_request(request)
98-
99-
_assert_error(error, types.INVALID_PARAMS, "Invalid request parameters")
111+
await _run_client_request(request, expected_error=(types.INVALID_PARAMS, "Invalid request parameters"))
100112

101113

102114
@pytest.mark.anyio
103115
async def test_server_to_client_unknown_method_returns_method_not_found() -> None:
104116
request = types.JSONRPCRequest(jsonrpc="2.0", id=3, method="server/unknown", params=None)
105117

106-
error = await _run_server_request(request)
107-
108-
_assert_error(error, types.METHOD_NOT_FOUND, "Method not found")
118+
await _run_server_request(request, expected_error=(types.METHOD_NOT_FOUND, "Method not found"))
109119

110120

111121
@pytest.mark.anyio
@@ -121,6 +131,4 @@ async def test_server_to_client_invalid_params_returns_invalid_params(
121131
) -> None:
122132
request = types.JSONRPCRequest(jsonrpc="2.0", id=request_id, method=method, params=params)
123133

124-
error = await _run_server_request(request)
125-
126-
_assert_error(error, types.INVALID_PARAMS, "Invalid request parameters")
134+
await _run_server_request(request, expected_error=(types.INVALID_PARAMS, "Invalid request parameters"))

0 commit comments

Comments
 (0)