@@ -231,16 +231,23 @@ class KafkaRequestHandlerTest {
231
231
})
232
232
}
233
233
234
- def makeRequest (time : Time , metrics : RequestChannelMetrics ): RequestChannel .Request = {
234
+ def makeRequest (
235
+ time : Time ,
236
+ metrics : RequestChannelMetrics ,
237
+ apiKeys : ApiKeys = ApiKeys .API_VERSIONS ,
238
+ version : Short = 0 ,
239
+ buffer : ByteBuffer = ByteBuffer .allocate(0 ),
240
+ memoryPool : MemoryPool = mock(classOf [MemoryPool ])
241
+ ): RequestChannel .Request = {
235
242
// Make unsupported API versions request to avoid having to parse a real request
236
243
val requestHeader = mock(classOf [RequestHeader ])
237
- when(requestHeader.apiKey()).thenReturn(ApiKeys . API_VERSIONS )
238
- when(requestHeader.apiVersion()).thenReturn(0 .toShort )
244
+ when(requestHeader.apiKey()).thenReturn(apiKeys )
245
+ when(requestHeader.apiVersion()).thenReturn(version )
239
246
240
247
val context = new RequestContext (requestHeader, " 0" , mock(classOf [InetAddress ]), new KafkaPrincipal (" " , " " ),
241
248
new ListenerName (" " ), SecurityProtocol .PLAINTEXT , mock(classOf [ClientInformation ]), false )
242
249
new RequestChannel .Request (0 , context, time.nanoseconds(),
243
- mock( classOf [ MemoryPool ]), ByteBuffer .allocate( 0 ) , metrics)
250
+ memoryPool, buffer , metrics)
244
251
}
245
252
246
253
def setupBrokerTopicMetrics (systemRemoteStorageEnabled : Boolean = true ): BrokerTopicMetrics = {
@@ -699,4 +706,32 @@ class KafkaRequestHandlerTest {
699
706
// cleanup
700
707
brokerTopicStats.close()
701
708
}
709
+
710
+ @ Test
711
+ def testRequestBufferRelease (): Unit = {
712
+ val time = new MockTime ()
713
+ val metrics = new RequestChannelMetrics (Collections .emptySet[ApiKeys ])
714
+ val requestChannel = new RequestChannel (10 , time, metrics)
715
+ val apiHandler = mock(classOf [ApiRequestHandler ])
716
+ val memoryPool = mock(classOf [MemoryPool ])
717
+ val buffer = ByteBuffer .allocate(1024 )
718
+
719
+ val handler = new KafkaRequestHandler (0 , 0 , mock(classOf [Meter ]), new AtomicInteger (1 ), requestChannel, apiHandler, time)
720
+
721
+ val request = makeRequest(time, metrics, ApiKeys .PRODUCE , 3 , buffer, memoryPool)
722
+ requestChannel.sendRequest(request)
723
+
724
+ val shutdownThread = new Thread (() => {
725
+ try {
726
+ Thread .sleep(1000 )
727
+ requestChannel.sendShutdownRequest()
728
+ } catch {
729
+ case _ : InterruptedException =>
730
+ }
731
+ })
732
+
733
+ shutdownThread.start()
734
+ handler.run()
735
+ verify(memoryPool, times(1 )).release(buffer)
736
+ }
702
737
}
0 commit comments