@@ -20,7 +20,9 @@ def _assert_error(error: types.JSONRPCError, expected_code: int, expected_messag
2020 _ensure (error_payload .message == expected_message , f"unexpected error message: { error_payload .message } " )
2121
2222
23- async def _run_client_request (request : types .JSONRPCRequest ) -> types .JSONRPCError :
23+ async def _run_client_request (
24+ request : types .JSONRPCRequest , * , expected_error : tuple [int , str ] | None = None
25+ ) -> types .JSONRPCError :
2426 request_send , request_receive = anyio .create_memory_object_stream [SessionMessage | Exception ](1 )
2527 response_send , response_receive = anyio .create_memory_object_stream [SessionMessage ](1 )
2628
@@ -44,10 +46,18 @@ async def _run_client_request(request: types.JSONRPCRequest) -> types.JSONRPCErr
4446
4547 root = response_message .message .root
4648 _ensure (isinstance (root , types .JSONRPCError ), "expected a JSON-RPC error response" )
47- return cast (types .JSONRPCError , root )
49+ error = cast (types .JSONRPCError , root )
4850
51+ if expected_error is not None :
52+ expected_code , expected_message = expected_error
53+ _assert_error (error , expected_code , expected_message )
4954
50- async def _run_server_request (request : types .JSONRPCRequest ) -> types .JSONRPCError :
55+ return error
56+
57+
58+ async def _run_server_request (
59+ request : types .JSONRPCRequest , * , expected_error : tuple [int , str ] | None = None
60+ ) -> types .JSONRPCError :
5161 request_send , request_receive = anyio .create_memory_object_stream [SessionMessage | Exception ](1 )
5262 response_send , response_receive = anyio .create_memory_object_stream [SessionMessage ](1 )
5363
@@ -71,7 +81,13 @@ async def _run_server_request(request: types.JSONRPCRequest) -> types.JSONRPCErr
7181
7282 root = response_message .message .root
7383 _ensure (isinstance (root , types .JSONRPCError ), "expected a JSON-RPC error response" )
74- return cast (types .JSONRPCError , root )
84+ error = cast (types .JSONRPCError , root )
85+
86+ if expected_error is not None :
87+ expected_code , expected_message = expected_error
88+ _assert_error (error , expected_code , expected_message )
89+
90+ return error
7591
7692
7793@pytest .mark .anyio
@@ -85,27 +101,21 @@ async def _run_server_request(request: types.JSONRPCRequest) -> types.JSONRPCErr
85101async def test_client_to_server_unknown_method_returns_method_not_found (method : str , request_id : int ) -> None :
86102 request = types .JSONRPCRequest (jsonrpc = "2.0" , id = request_id , method = method , params = None )
87103
88- error = await _run_client_request (request )
89-
90- _assert_error (error , types .METHOD_NOT_FOUND , "Method not found" )
104+ await _run_client_request (request , expected_error = (types .METHOD_NOT_FOUND , "Method not found" ))
91105
92106
93107@pytest .mark .anyio
94108async def test_client_to_server_invalid_params_returns_invalid_params () -> None :
95109 request = types .JSONRPCRequest (jsonrpc = "2.0" , id = 2 , method = "resources/read" , params = {})
96110
97- error = await _run_client_request (request )
98-
99- _assert_error (error , types .INVALID_PARAMS , "Invalid request parameters" )
111+ await _run_client_request (request , expected_error = (types .INVALID_PARAMS , "Invalid request parameters" ))
100112
101113
102114@pytest .mark .anyio
103115async def test_server_to_client_unknown_method_returns_method_not_found () -> None :
104116 request = types .JSONRPCRequest (jsonrpc = "2.0" , id = 3 , method = "server/unknown" , params = None )
105117
106- error = await _run_server_request (request )
107-
108- _assert_error (error , types .METHOD_NOT_FOUND , "Method not found" )
118+ await _run_server_request (request , expected_error = (types .METHOD_NOT_FOUND , "Method not found" ))
109119
110120
111121@pytest .mark .anyio
@@ -121,6 +131,4 @@ async def test_server_to_client_invalid_params_returns_invalid_params(
121131) -> None :
122132 request = types .JSONRPCRequest (jsonrpc = "2.0" , id = request_id , method = method , params = params )
123133
124- error = await _run_server_request (request )
125-
126- _assert_error (error , types .INVALID_PARAMS , "Invalid request parameters" )
134+ await _run_server_request (request , expected_error = (types .INVALID_PARAMS , "Invalid request parameters" ))
0 commit comments