Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(cdc): pass ownership of source message sender to java thread #20353

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import com.risingwave.connector.source.common.DbzConnectorConfig;
import com.risingwave.connector.source.common.DbzSourceUtils;
import com.risingwave.java.binding.Binding;
import com.risingwave.java.binding.CdcSourceChannel;
import com.risingwave.proto.ConnectorServiceProto.GetEventStreamResponse;
import io.debezium.config.CommonConnectorConfig;
import io.grpc.stub.StreamObserver;
Expand Down Expand Up @@ -69,7 +69,7 @@ public static DbzCdcEngineRunner newCdcEngineRunner(
return runner;
}

public static DbzCdcEngineRunner create(DbzConnectorConfig config, long channelPtr) {
public static DbzCdcEngineRunner create(DbzConnectorConfig config, CdcSourceChannel channel) {
DbzCdcEngineRunner runner = new DbzCdcEngineRunner(config);
try {
var sourceId = config.getSourceId();
Expand All @@ -90,8 +90,7 @@ public static DbzCdcEngineRunner create(DbzConnectorConfig config, long channelP
(error != null && error.getMessage() != null
? error.getMessage()
: message);
if (!Binding.sendCdcSourceErrorToChannel(
channelPtr, errorMsg)) {
if (!channel.sendError(errorMsg)) {
LOG.warn(
"engine#{} unable to send error message: {}",
sourceId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import com.risingwave.connector.source.common.CdcConnectorException;
import com.risingwave.connector.source.common.DbzConnectorConfig;
import com.risingwave.connector.source.common.DbzSourceUtils;
import com.risingwave.java.binding.Binding;
import com.risingwave.java.binding.CdcSourceChannel;
import com.risingwave.metrics.ConnectorNodeMetrics;
import com.risingwave.proto.ConnectorServiceProto;
import com.risingwave.proto.ConnectorServiceProto.GetEventStreamResponse;
Expand All @@ -41,9 +41,9 @@ public class JniDbzSourceHandler {
private final DbzConnectorConfig config;
private final DbzCdcEngineRunner runner;

public JniDbzSourceHandler(DbzConnectorConfig config, long channelPtr) {
public JniDbzSourceHandler(DbzConnectorConfig config, CdcSourceChannel channel) {
this.config = config;
this.runner = DbzCdcEngineRunner.create(config, channelPtr);
this.runner = DbzCdcEngineRunner.create(config, channel);

if (runner == null) {
throw new CdcConnectorException("Failed to create engine runner");
Expand All @@ -52,6 +52,9 @@ public JniDbzSourceHandler(DbzConnectorConfig config, long channelPtr) {

public static void runJniDbzSourceThread(byte[] getEventStreamRequestBytes, long channelPtr)
throws Exception {

var channel = CdcSourceChannel.fromOwnedPointer(channelPtr);

var request =
ConnectorServiceProto.GetEventStreamRequest.parseFrom(getEventStreamRequestBytes);
// userProps extracted from request, underlying implementation is UnmodifiableMap
Expand All @@ -72,10 +75,10 @@ public static void runJniDbzSourceThread(byte[] getEventStreamRequestBytes, long
mutableUserProps,
request.getSnapshotDone(),
isCdcSourceJob);
JniDbzSourceHandler handler = new JniDbzSourceHandler(config, channelPtr);
JniDbzSourceHandler handler = new JniDbzSourceHandler(config, channel);
// register handler to the registry
JniDbzSourceRegistry.register(config.getSourceId(), handler);
handler.start(channelPtr);
handler.start(channel);
}

public void commitOffset(String encodedOffset) throws InterruptedException {
Expand All @@ -96,12 +99,12 @@ public void commitOffset(String encodedOffset) throws InterruptedException {
}
}

public void start(long channelPtr) {
public void start(CdcSourceChannel channel) {

try {
// Start the engine
var startOk = runner.start();
if (!sendHandshakeMessage(runner, channelPtr, startOk)) {
if (!sendHandshakeMessage(runner, channel, startOk)) {
LOG.error(
"Failed to send handshake message to channel. sourceId={}",
config.getSourceId());
Expand All @@ -125,10 +128,10 @@ public void start(long channelPtr) {
"Engine#{}: emit one chunk {} events to network ",
config.getSourceId(),
resp.getEventsCount());
success = Binding.sendCdcSourceMsgToChannel(channelPtr, resp.toByteArray());
success = channel.send(resp.toByteArray());
} else {
// If resp is null means just check whether channel is closed.
success = Binding.sendCdcSourceMsgToChannel(channelPtr, null);
success = channel.send(null);
}
if (!success) {
LOG.info(
Expand All @@ -152,7 +155,7 @@ public void start(long channelPtr) {
}

private boolean sendHandshakeMessage(
DbzCdcEngineRunner runner, long channelPtr, boolean startOk) throws Exception {
DbzCdcEngineRunner runner, CdcSourceChannel channel, boolean startOk) throws Exception {
// send a handshake message to notify the Source executor
// if the handshake is not ok, the split reader will return error to source actor
var controlInfo =
Expand All @@ -163,7 +166,7 @@ private boolean sendHandshakeMessage(
.setSourceId(config.getSourceId())
.setControl(controlInfo)
.build();
var success = Binding.sendCdcSourceMsgToChannel(channelPtr, handshakeMsg.toByteArray());
var success = channel.send(handshakeMsg.toByteArray());
if (!success) {
LOG.info(
"Engine#{}: JNI sender broken detected, stop the engine", config.getSourceId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ public static native void tracingSlf4jEvent(

public static native boolean sendCdcSourceErrorToChannel(long channelPtr, String errorMsg);

public static native void cdcSourceSenderClose(long channelPtr);

public static native com.risingwave.java.binding.JniSinkWriterStreamRequest
recvSinkWriterRequestFromChannel(long channelPtr);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.risingwave.java.binding;

public class CdcSourceChannel implements AutoCloseable {
private final long pointer;

CdcSourceChannel(long pointer) {
this.pointer = pointer;
}

public static CdcSourceChannel fromOwnedPointer(long pointer) {
return new CdcSourceChannel(pointer);
}

public boolean send(byte[] msg) {
return Binding.sendCdcSourceMsgToChannel(pointer, msg);
}

public boolean sendError(String errorMsg) {
return Binding.sendCdcSourceErrorToChannel(pointer, errorMsg);
}

@Override
public void close() {
Binding.cdcSourceSenderClose(pointer);
}
}
10 changes: 7 additions & 3 deletions src/connector/src/source/cdc/source/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use risingwave_common::bail;
use risingwave_common::metrics::GLOBAL_ERROR_METRICS;
use risingwave_common::util::addr::HostAddr;
use risingwave_jni_core::jvm_runtime::{execute_with_jni_env, JVM};
use risingwave_jni_core::{call_static_method, JniReceiverType, JniSenderType};
use risingwave_jni_core::{call_static_method, JniReceiverType, OwnedPointer};
use risingwave_pb::connector_service::{GetEventStreamRequest, GetEventStreamResponse};
use thiserror_ext::AsReport;
use tokio::sync::mpsc;
Expand Down Expand Up @@ -98,7 +98,7 @@ impl<T: CdcSourceTypeTrait> SplitReader for CdcSplitReader<T> {

let source_id = split.split_id() as u64;
let source_type = conn_props.get_source_type_pb();
let (mut tx, mut rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
let (tx, mut rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);

let jvm = JVM.get_or_init()?;
let get_event_stream_request = GetEventStreamRequest {
Expand Down Expand Up @@ -129,12 +129,16 @@ impl<T: CdcSourceTypeTrait> SplitReader for CdcSplitReader<T> {
}
};

// `runJniDbzSourceThread` will take ownership of `tx`, and release it later in
// `Java_com_risingwave_java_binding_Binding_cdcSourceSenderClose` via `AutoClosable`.
let tx: OwnedPointer<_> = tx.into();

let result = call_static_method!(
env,
{com.risingwave.connector.source.core.JniDbzSourceHandler},
{void runJniDbzSourceThread(byte[] getEventStreamRequestBytes, long channelPtr)},
&get_event_stream_request_bytes,
&mut tx as *mut JniSenderType<GetEventStreamResponse>
tx.into_pointer()
);

match result {
Expand Down
30 changes: 24 additions & 6 deletions src/jni_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,21 +152,31 @@ impl<T> From<T> for Pointer<'static, T> {

impl<'a, T> Pointer<'a, T> {
fn as_ref(&self) -> &'a T {
debug_assert!(self.pointer != 0);
assert!(self.pointer != 0);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't seem to be a heavy operation.

unsafe { &*(self.pointer as *const T) }
}

fn as_mut(&mut self) -> &'a mut T {
debug_assert!(self.pointer != 0);
assert!(self.pointer != 0);
unsafe { &mut *(self.pointer as *mut T) }
}
}

/// A pointer that owns the object it points to.
///
/// Note that dropping an `OwnedPointer` does not release the object.
/// Instead, you should call [`OwnedPointer::release`] manually.
pub type OwnedPointer<T> = Pointer<'static, T>;

impl<T> OwnedPointer<T> {
fn drop(self) {
debug_assert!(self.pointer != 0);
/// Consume `self` and return the pointer value. Used for passing to JNI.
pub fn into_pointer(self) -> jlong {
self.pointer
}

/// Release the object behind the pointer.
fn release(self) {
assert!(self.pointer != 0);
unsafe { drop(Box::from_raw(self.pointer as *mut T)) }
}
}
Expand Down Expand Up @@ -389,7 +399,7 @@ extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorClose<'a>(
_env: EnvParam<'a>,
pointer: OwnedPointer<JavaBindingIterator<'a>>,
) {
pointer.drop()
pointer.release()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed to avoid confusion with Drop trait or std::mem::drop function.

}

#[no_mangle]
Expand Down Expand Up @@ -419,7 +429,7 @@ extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkClose(
_env: EnvParam<'_>,
chunk: OwnedPointer<StreamChunk>,
) {
chunk.drop()
chunk.release()
}

#[no_mangle]
Expand Down Expand Up @@ -1052,6 +1062,14 @@ extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceErrorTo
})
}

#[no_mangle]
extern "system" fn Java_com_risingwave_java_binding_Binding_cdcSourceSenderClose<'a>(
_env: EnvParam<'a>,
channel: OwnedPointer<JniSenderType<GetEventStreamResponse>>,
) {
channel.release();
}

pub enum JniSinkWriterStreamRequest {
PbRequest(SinkWriterStreamRequest),
Chunk {
Expand Down
Loading