@@ -71,3 +71,66 @@ async def test_async_sqs_logger_flush():
71
71
assert len (payload_data ["messages" ]) == 1
72
72
assert payload_data ["messages" ][0 ]["role" ] == "user"
73
73
assert payload_data ["messages" ][0 ]["content" ] == "hello"
74
+
75
+
76
+ @pytest .mark .asyncio
77
+ async def test_async_sqs_logger_error_flush ():
78
+ expected_queue_url = "https://sqs.us-east-1.amazonaws.com/123456789012/test-queue"
79
+ expected_region = "us-east-1"
80
+
81
+ sqs_logger = SQSLogger (
82
+ sqs_queue_url = expected_queue_url ,
83
+ sqs_region_name = expected_region ,
84
+ sqs_flush_interval = 1 ,
85
+ )
86
+
87
+ # Mock the httpx client
88
+ mock_response = MagicMock ()
89
+ mock_response .raise_for_status = Exception ("Something went wrong" )
90
+ sqs_logger .async_httpx_client .post = AsyncMock (return_value = mock_response )
91
+
92
+ litellm .callbacks = [sqs_logger ]
93
+
94
+ await litellm .acompletion (
95
+ model = "gpt-4o" ,
96
+ messages = [{"role" : "user" , "content" : "hello" }],
97
+ mock_response = "Error occurred"
98
+ )
99
+
100
+ await asyncio .sleep (2 )
101
+
102
+ # Verify that httpx post was called
103
+ sqs_logger .async_httpx_client .post .assert_called ()
104
+
105
+ # Get the call arguments
106
+ call_args = sqs_logger .async_httpx_client .post .call_args
107
+
108
+ # Verify the URL is correct
109
+ called_url = call_args [0 ][0 ] # First positional argument
110
+ assert called_url == expected_queue_url , f"Expected URL { expected_queue_url } , got { called_url } "
111
+
112
+ # Verify the payload contains StandardLoggingPayload data
113
+ called_data = call_args .kwargs ['data' ]
114
+
115
+ # Extract the MessageBody from the URL-encoded data
116
+ # Format: "Action=SendMessage&Version=2012-11-05&MessageBody=<url_encoded_json>"
117
+ assert "Action=SendMessage" in called_data
118
+ assert "Version=2012-11-05" in called_data
119
+ assert "MessageBody=" in called_data
120
+
121
+ # Extract and decode the message body
122
+ message_body_start = called_data .find ("MessageBody=" ) + len ("MessageBody=" )
123
+ message_body_encoded = called_data [message_body_start :]
124
+ message_body_json = unquote (message_body_encoded )
125
+
126
+ # Parse the JSON to verify it's a StandardLoggingPayload
127
+ payload_data = json .loads (message_body_json )
128
+
129
+ # Verify it has the expected StandardLoggingPayload structure
130
+ assert "model" in payload_data
131
+ assert "messages" in payload_data
132
+ assert "response" in payload_data
133
+ assert payload_data ["model" ] == "gpt-4o"
134
+ assert len (payload_data ["messages" ]) == 1
135
+ assert payload_data ["messages" ][0 ]["role" ] == "user"
136
+ assert payload_data ["messages" ][0 ]["content" ] == "hello"
0 commit comments