From eaae9b035fd7d37cb4e045f4ed872208e4b5c2c0 Mon Sep 17 00:00:00 2001
From: likun
Date: Mon, 31 Mar 2025 17:55:29 +0800
Subject: [PATCH] Reduce excessive off-heap memory consumption caused by occasional large keys.

---
 .../java/io/lettuce/core/ClientOptions.java   |  62 ++++++++++-
 .../lettuce/core/protocol/CommandHandler.java |  37 ++++++-
 .../protocol/CommandHandlerUnitTests.java     | 101 ++++++++++++++++++
 3 files changed, 198 insertions(+), 2 deletions(-)

diff --git a/src/main/java/io/lettuce/core/ClientOptions.java b/src/main/java/io/lettuce/core/ClientOptions.java
index a16845b00..170dae34e 100644
--- a/src/main/java/io/lettuce/core/ClientOptions.java
+++ b/src/main/java/io/lettuce/core/ClientOptions.java
@@ -94,6 +94,10 @@ public class ClientOptions implements Serializable {
 
     public static final boolean DEFAULT_USE_HASH_INDEX_QUEUE = true;
 
+    public static final int DEFAULT_READ_BUFFER_SIZE = 64 * 1024;
+
+    private static final boolean DEFAULT_CREATE_BYTEBUF_WHEN_RECV_LARGE_KEY = false;
+
     private final boolean autoReconnect;
 
     private final Predicate<RedisCommand<?, ?, ?>> replayFilter;
 
@@ -130,6 +134,10 @@ public class ClientOptions implements Serializable {
 
     private final boolean useHashIndexedQueue;
 
+    private final boolean createByteBufWhenRecvLargeKey;
+
+    private final int readBufferSize;
+
     protected ClientOptions(Builder builder) {
         this.autoReconnect = builder.autoReconnect;
         this.replayFilter = builder.replayFilter;
@@ -149,6 +157,8 @@ protected ClientOptions(Builder builder) {
         this.suspendReconnectOnProtocolFailure = builder.suspendReconnectOnProtocolFailure;
         this.timeoutOptions = builder.timeoutOptions;
         this.useHashIndexedQueue = builder.useHashIndexedQueue;
+        this.createByteBufWhenRecvLargeKey = builder.createByteBufWhenRecvLargeKey;
+        this.readBufferSize = builder.readBufferSize;
     }
 
     protected ClientOptions(ClientOptions original) {
@@ -170,6 +180,8 @@ protected ClientOptions(ClientOptions original) {
         this.suspendReconnectOnProtocolFailure = original.isSuspendReconnectOnProtocolFailure();
         this.timeoutOptions = original.getTimeoutOptions();
         this.useHashIndexedQueue = original.isUseHashIndexedQueue();
+        this.createByteBufWhenRecvLargeKey = original.isCreateByteBufWhenRecvLargeKey();
+        this.readBufferSize = original.getReadBufferSize();
     }
 
     /**
@@ -241,6 +253,10 @@ public static class Builder {
 
         private boolean useHashIndexedQueue = DEFAULT_USE_HASH_INDEX_QUEUE;
 
+        private boolean createByteBufWhenRecvLargeKey = DEFAULT_CREATE_BYTEBUF_WHEN_RECV_LARGE_KEY;
+
+        private int readBufferSize = DEFAULT_READ_BUFFER_SIZE;
+
         protected Builder() {
         }
 
@@ -529,6 +545,31 @@ public Builder useHashIndexQueue(boolean useHashIndexedQueue) {
             return this;
         }
 
+        /**
+         * If enabled, the command handler allocates a temporary {@link io.netty.buffer.ByteBuf} to hold a large key, so
+         * that occasional large keys do not cause the read buffer to retain excessive memory.
+         *
+         * @param createByteBufWhenRecvLargeKey {@code true} to enable the temporary large-key buffer.
+         * @return {@code this}
+         * @see io.lettuce.core.protocol.CommandHandler
+         */
+        public Builder createByteBufWhenRecvLargeKey(boolean createByteBufWhenRecvLargeKey) {
+            this.createByteBufWhenRecvLargeKey = createByteBufWhenRecvLargeKey;
+            return this;
+        }
+
+        /**
+         * Set the size of the buffer, in bytes, used to read data from the Redis server. See {@link #DEFAULT_READ_BUFFER_SIZE}.
+         *
+         * @param readBufferSize the read buffer size in bytes.
+         * @return {@code this}
+         * @see io.lettuce.core.protocol.CommandHandler
+         */
+        public Builder readBufferSize(int readBufferSize) {
+            this.readBufferSize = readBufferSize;
+            return this;
+        }
+
         /**
          * Create a new instance of {@link ClientOptions}.
         *
@@ -558,7 +599,8 @@ public ClientOptions.Builder mutate() {
                 .pingBeforeActivateConnection(isPingBeforeActivateConnection()).protocolVersion(getConfiguredProtocolVersion())
                 .requestQueueSize(getRequestQueueSize()).scriptCharset(getScriptCharset()).jsonParser(getJsonParser())
                 .socketOptions(getSocketOptions()).sslOptions(getSslOptions())
-                .suspendReconnectOnProtocolFailure(isSuspendReconnectOnProtocolFailure()).timeoutOptions(getTimeoutOptions());
+                .suspendReconnectOnProtocolFailure(isSuspendReconnectOnProtocolFailure()).timeoutOptions(getTimeoutOptions())
+                .createByteBufWhenRecvLargeKey(isCreateByteBufWhenRecvLargeKey()).readBufferSize(getReadBufferSize());
 
         return builder;
     }
 
@@ -773,6 +815,24 @@ public TimeoutOptions getTimeoutOptions() {
         return timeoutOptions;
     }
 
+    /**
+     * Get the initial read buffer size, in bytes, used by {@link io.lettuce.core.protocol.CommandHandler}.
+     *
+     * @return the initial read buffer size used by {@link io.lettuce.core.protocol.CommandHandler}.
+     */
+    public int getReadBufferSize() {
+        return readBufferSize;
+    }
+
+    /**
+     * If {@code true}, the client creates a temporary {@link io.netty.buffer.ByteBuf} when it receives a large key whose
+     * size exceeds {@link #readBufferSize}.
+     *
+     * @return {@code true} if a temporary buffer is created for large keys.
+     */
+    public boolean isCreateByteBufWhenRecvLargeKey() {
+        return createByteBufWhenRecvLargeKey;
+    }
+
     /**
      * Defines the re-authentication behavior of the Redis client.
     *
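Usage sketch: the builder methods below are the ones introduced by this patch; the surrounding class, the connection URI, and the chosen sizes are illustrative assumptions only, not part of the change.

    import io.lettuce.core.ClientOptions;
    import io.lettuce.core.RedisClient;

    public class LargeKeyOptionsExample {

        public static void main(String[] args) {
            ClientOptions options = ClientOptions.builder()
                    .createByteBufWhenRecvLargeKey(true) // allocate a temporary buffer for occasional large keys
                    .readBufferSize(64 * 1024)           // steady-state read buffer size in bytes (the default)
                    .build();

            // Hypothetical endpoint; RedisClient.create(String) and setOptions(ClientOptions) are existing Lettuce API.
            RedisClient client = RedisClient.create("redis://localhost:6379");
            client.setOptions(options);
        }

    }
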
diff --git a/src/main/java/io/lettuce/core/protocol/CommandHandler.java b/src/main/java/io/lettuce/core/protocol/CommandHandler.java
index 59aee61e0..4b09b12e6 100644
--- a/src/main/java/io/lettuce/core/protocol/CommandHandler.java
+++ b/src/main/java/io/lettuce/core/protocol/CommandHandler.java
@@ -55,6 +55,7 @@
 import io.lettuce.core.tracing.Tracer;
 import io.lettuce.core.tracing.Tracing;
 import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufAllocator;
 import io.netty.channel.Channel;
 import io.netty.channel.ChannelDuplexHandler;
 import io.netty.channel.ChannelHandler;
@@ -119,6 +120,10 @@ public class CommandHandler extends ChannelDuplexHandler implements HasQueuedCom
 
     private final BackpressureSource backpressureSource = new BackpressureSource();
 
+    private final int defaultReadBufferSize;
+
+    private final boolean createByteBufWhenRecvLargeKey;
+
     private RedisStateMachine rsm;
 
     private Channel channel;
 
@@ -139,6 +144,10 @@ public class CommandHandler extends ChannelDuplexHandler implements HasQueuedCom
 
     private Tracing.Endpoint tracedEndpoint;
 
+    private ByteBuf tmpReadBuffer;
+
+    private ByteBufAllocator byteBufAllocator;
+
     /**
      * Initialize a new instance that handles commands from the supplied queue.
      *
@@ -165,6 +174,10 @@ public CommandHandler(ClientOptions clientOptions, ClientResources clientResourc
         this.tracingEnabled = tracing.isEnabled();
 
         this.decodeBufferPolicy = clientOptions.getDecodeBufferPolicy();
+
+        this.defaultReadBufferSize = clientOptions.getReadBufferSize();
+
+        this.createByteBufWhenRecvLargeKey = clientOptions.isCreateByteBufWhenRecvLargeKey();
     }
 
     public Endpoint getEndpoint() {
@@ -222,7 +235,8 @@ public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
 
         setState(LifecycleState.REGISTERED);
 
-        readBuffer = ctx.alloc().buffer(8192 * 8);
+        byteBufAllocator = ctx.alloc();
+        readBuffer = ctx.alloc().buffer(defaultReadBufferSize);
         rsm = new RedisStateMachine();
         ctx.fireChannelRegistered();
     }
@@ -615,6 +629,15 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
             logger.trace("{} Buffer: {}", logPrefix(), input.toString(Charset.defaultCharset()).trim());
         }
 
+        // If the incoming data does not fit into the read buffer at its default capacity, switch to a direct buffer of double the capacity and park the original buffer until decoding completes.
+        if (createByteBufWhenRecvLargeKey && readBuffer.capacity() == defaultReadBufferSize
+                && readBuffer.writableBytes() < input.readableBytes() && byteBufAllocator != null) {
+            ByteBuf byteBuf = byteBufAllocator.directBuffer(readBuffer.capacity() << 1);
+            byteBuf.writeBytes(readBuffer);
+            tmpReadBuffer = readBuffer;
+            readBuffer = byteBuf;
+        }
+
         readBuffer.touch("CommandHandler.read(…)");
         readBuffer.writeBytes(input);
 
@@ -642,6 +665,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer) throws Interrup
             }
         }
 
+        boolean decodeComplete = false;
         while (canDecode(buffer)) {
 
             if (isPushDecode(buffer)) {
@@ -702,6 +726,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer) throws Interrup
                         logger.debug("{} Completing command {}", logPrefix(), command);
                     }
                     complete(command);
+                    decodeComplete = true;
                 } catch (Exception e) {
                     logger.warn("{} Unexpected exception during request: {}", logPrefix, e.toString(), e);
                 }
@@ -712,6 +737,16 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer) throws Interrup
         }
 
         decodeBufferPolicy.afterDecoding(buffer);
+        releaseLargeBufferIfNecessary(buffer, decodeComplete);
+    }
+
+    private void releaseLargeBufferIfNecessary(ByteBuf buffer, boolean decodeComplete) {
+        if (decodeComplete && this.tmpReadBuffer != null) {
+            buffer.release();
+            this.readBuffer = tmpReadBuffer;
+            this.readBuffer.clear();
+            this.tmpReadBuffer = null;
+        }
+    }
 
     protected void notifyPushListeners(PushMessage notification) {
diff --git a/src/test/java/io/lettuce/core/protocol/CommandHandlerUnitTests.java b/src/test/java/io/lettuce/core/protocol/CommandHandlerUnitTests.java
index e7f7cfec6..f6d7f620a 100644
--- a/src/test/java/io/lettuce/core/protocol/CommandHandlerUnitTests.java
+++ b/src/test/java/io/lettuce/core/protocol/CommandHandlerUnitTests.java
@@ -26,11 +26,20 @@
 import static org.mockito.Mockito.*;
 import static org.mockito.Mockito.eq;
 
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileWriter;
 import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
 import java.net.Inet4Address;
 import java.net.InetSocketAddress;
 import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
 import java.time.Duration;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
@@ -65,6 +74,7 @@
 import io.lettuce.core.output.KeyValueListOutput;
 import io.lettuce.core.output.StatusOutput;
 import io.lettuce.core.output.ValueListOutput;
+import io.lettuce.core.output.ValueOutput;
 import io.lettuce.core.resource.ClientResources;
 import io.lettuce.core.tracing.Tracing;
 import io.lettuce.test.Delay;
@@ -650,4 +660,95 @@ void shouldHandleNullBuffers() throws Exception {
 
         sut.channelUnregistered(context);
     }
 
+    /**
+     * If a large key is received, a larger temporary buffer is created and then released once decoding completes.
+     */
+    @Test
+    void shouldLargeBufferCreatedAndRelease() throws Exception {
+        ChannelPromise channelPromise = new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE);
+        channelPromise.setSuccess();
+        ClientOptions clientOptions = ClientOptions.builder().createByteBufWhenRecvLargeKey(true).build();
+        sut = new CommandHandler(clientOptions, clientResources, endpoint);
+        sut.channelRegistered(context);
+        sut.channelActive(context);
+        sut.getStack().add(new Command<>(CommandType.GET, new ValueOutput(StringCodec.UTF8)));
+
+        ByteBuf internalBuffer = ReflectionTestUtils.getField(sut, "readBuffer");
+
+        // step1 Receive the first TCP packet
+        ByteBuf msg = context.alloc().buffer(13);
+        // $ (1) + length (5) + \r\n (2) + first part of the value (7)
+        msg.writeBytes("$65536\r\nval_abc".getBytes(StandardCharsets.UTF_8));
+        sut.channelRead(context, msg);
+
+        int markedReaderIndex = ReflectionTestUtils.getField(internalBuffer, "markedReaderIndex");
+        assertThat(markedReaderIndex).isEqualTo(8);
+        assertThat(internalBuffer.readerIndex()).isEqualTo(8);
+        assertThat(internalBuffer.writerIndex()).isEqualTo(15);
+
+        // step2 Receive the second TCP packet
+        ByteBuf msg2 = context.alloc().buffer(64 * 1024);
+        StringBuilder sb = new StringBuilder();
+        // remaining 65536 - 7 = 65529 bytes of the value
+        for (int i = 0; i < 65529; i++) {
+            sb.append((char) ('a' + i % 26));
+        }
+        sb.append("\r\n");
+        msg2.writeBytes(sb.toString().getBytes(StandardCharsets.UTF_8));
+        sut.channelRead(context, msg2);
+
+        // step3 Result: readBuffer is back at its 64k default capacity and tmpReadBuffer is null
+        ByteBuf readBuffer = ReflectionTestUtils.getField(sut, "readBuffer");
+        assertThat(readBuffer.capacity()).isEqualTo(64 * 1024);
+        assertThat(readBuffer.readerIndex()).isZero();
+        assertThat(readBuffer.writerIndex()).isZero();
+
+        ByteBuf tmpBuffer = ReflectionTestUtils.getField(sut, "tmpReadBuffer");
+        assertThat(tmpBuffer).isNull();
+    }
+
+    /**
+     * With {@code readBufferSize = 16} and {@code createByteBufWhenRecvLargeKey = true}, a reply
+     * larger than {@code readBufferSize} triggers creation of a temporary buffer, which is released once decoding completes.
+     */
+    @Test
+    void shouldLargeThan16ThenCreateAndRelease() throws Exception {
+        ChannelPromise channelPromise = new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE);
+        channelPromise.setSuccess();
+        int readBufferSize = 16;
+        ClientOptions clientOptions = ClientOptions.builder().createByteBufWhenRecvLargeKey(true).readBufferSize(readBufferSize)
+                .build();
+        sut = new CommandHandler(clientOptions, clientResources, endpoint);
+        sut.channelRegistered(context);
+        sut.channelActive(context);
+        sut.getStack().add(new Command<>(CommandType.GET, new ValueOutput(StringCodec.UTF8)));
+
+        ByteBuf internalBuffer = ReflectionTestUtils.getField(sut, "readBuffer");
+
+        // step1 Receive the first TCP packet
+        ByteBuf msg = context.alloc().buffer(13);
+        // $ (1) + length (2) + \r\n (2) + first part of the value (7)
+        msg.writeBytes("$16\r\nabc_val".getBytes(StandardCharsets.UTF_8));
+        sut.channelRead(context, msg);
+
+        int markedReaderIndex = ReflectionTestUtils.getField(internalBuffer, "markedReaderIndex");
+        assertThat(markedReaderIndex).isEqualTo(5);
+        assertThat(internalBuffer.readerIndex()).isEqualTo(5);
+        assertThat(internalBuffer.writerIndex()).isEqualTo(12);
+
+        // step2 Receive the second TCP packet
+        ByteBuf msg2 = context.alloc().buffer(32);
+        msg2.writeBytes("abcd_abcd\r\n".getBytes(StandardCharsets.UTF_8));
+        sut.channelRead(context, msg2);
+
+        // step3 Result: readBuffer is back at its configured 16-byte capacity and tmpReadBuffer is null
+        ByteBuf readBuffer = ReflectionTestUtils.getField(sut, "readBuffer");
+        assertThat(readBuffer.capacity()).isEqualTo(readBufferSize);
+        assertThat(readBuffer.readerIndex()).isZero();
+        assertThat(readBuffer.writerIndex()).isZero();
+
+        ByteBuf tmpBuffer = ReflectionTestUtils.getField(sut, "tmpReadBuffer");
+        assertThat(tmpBuffer).isNull();
+    }
+
 }
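
For reference, the grow-then-restore strategy that the CommandHandler changes apply to readBuffer can be illustrated in isolation with plain Netty buffers. This is a minimal sketch only; the allocator, sizes, and variable names are illustrative and not part of the patch.

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.ByteBufAllocator;

    public class GrowThenRestoreSketch {

        public static void main(String[] args) {
            ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
            int defaultReadBufferSize = 16;

            ByteBuf readBuffer = alloc.buffer(defaultReadBufferSize); // steady-state read buffer
            ByteBuf tmpReadBuffer = null;                             // parks the original buffer while a large reply is in flight
            ByteBuf input = alloc.buffer().writeBytes(new byte[64]);  // an incoming chunk larger than the read buffer

            // Grow: the read buffer is still at its default capacity but cannot hold the input,
            // so switch to a larger direct buffer and keep the original aside.
            if (readBuffer.capacity() == defaultReadBufferSize && readBuffer.writableBytes() < input.readableBytes()) {
                ByteBuf larger = alloc.directBuffer(readBuffer.capacity() << 1);
                larger.writeBytes(readBuffer);
                tmpReadBuffer = readBuffer;
                readBuffer = larger;
            }
            readBuffer.writeBytes(input);
            input.release();

            // ... the large reply would be decoded from readBuffer here ...

            // Restore: once the command completes, release the large buffer and return to the small one,
            // so an occasional large key does not pin the larger allocation.
            if (tmpReadBuffer != null) {
                readBuffer.release();
                readBuffer = tmpReadBuffer;
                readBuffer.clear();
                tmpReadBuffer = null;
            }
            System.out.println("restored capacity = " + readBuffer.capacity());
            readBuffer.release();
        }

    }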