Skip to content

Commit 05d8fb9

Browse files
committed
Update rust implementation, make more consistent with core implementation
Signed-off-by: currantw <[email protected]>
1 parent b44f959 commit 05d8fb9

File tree

1 file changed

+169
-55
lines changed

1 file changed

+169
-55
lines changed

rust/src/lib.rs

Lines changed: 169 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use glide_core::{
1111
};
1212
use std::{
1313
ffi::{CStr, CString, c_char, c_void},
14-
slice,
14+
slice::from_raw_parts,
1515
sync::Arc,
1616
};
1717
use tokio::runtime::{Builder, Runtime};
@@ -444,9 +444,9 @@ pub unsafe extern "C" fn init(level: Option<Level>, file_name: *const c_char) ->
444444
///
445445
/// # Safety
446446
/// * `client_ptr` must be a valid Client pointer from create_client
447-
/// * `cursor` must be a valid C string
447+
/// * `cursor` must be "0" for initial scan or a valid cursor ID from previous scan
448448
/// * `args` and `arg_lengths` must be valid arrays of length `arg_count`
449-
/// * `args` array format: alternating parameter names and values (e.g., [b"MATCH", pattern, b"COUNT", count_str])
449+
/// * `args` format: [b"MATCH", pattern, b"COUNT", count, b"TYPE", type] (all optional)
450450
#[unsafe(no_mangle)]
451451
pub unsafe extern "C-unwind" fn request_cluster_scan(
452452
client_ptr: *const c_void,
@@ -456,6 +456,7 @@ pub unsafe extern "C-unwind" fn request_cluster_scan(
456456
args: *const usize,
457457
arg_lengths: *const u64,
458458
) {
459+
// Build client and add panic guard.
459460
let client = unsafe {
460461
Arc::increment_strong_count(client_ptr);
461462
Arc::from_raw(client_ptr as *mut Client)
@@ -468,21 +469,34 @@ pub unsafe extern "C-unwind" fn request_cluster_scan(
468469
callback_index,
469470
};
470471

472+
// Build arguments and get the cluster scan state.
471473
let cursor_id = unsafe { CStr::from_ptr(cursor) }
472474
.to_str()
473475
.unwrap_or("0")
474476
.to_owned();
475477

476-
let cluster_scan_args = unsafe { parse_cluster_scan_args(args, arg_lengths, arg_count) };
478+
let cluster_scan_args = match unsafe {
479+
build_cluster_scan_args(
480+
arg_count,
481+
args,
482+
arg_lengths,
483+
core.failure_callback,
484+
callback_index,
485+
)
486+
} {
487+
Some(args) => args,
488+
None => return,
489+
};
477490

478491
let scan_state_cursor =
479492
match glide_core::cluster_scan_container::get_cluster_scan_cursor(cursor_id) {
480493
Ok(existing_cursor) => existing_cursor,
481494
Err(_error) => redis::ScanStateRC::new(),
482495
};
483496

497+
// Run cluster scan.
484498
client.runtime.spawn(async move {
485-
let mut panic_guard = PanicGuard {
499+
let mut async_panic_guard = PanicGuard {
486500
panicked: true,
487501
failure_callback: core.failure_callback,
488502
callback_index,
@@ -502,12 +516,13 @@ pub unsafe extern "C-unwind" fn request_cluster_scan(
502516
report_error(
503517
core.failure_callback,
504518
callback_index,
505-
error_message(&err),
506-
error_type(&err),
519+
glide_core::errors::error_message(&err),
520+
glide_core::errors::error_type(&err),
507521
);
508522
},
509523
};
510-
panic_guard.panicked = false;
524+
525+
async_panic_guard.panicked = false;
511526
});
512527

513528
panic_guard.panicked = false;
@@ -531,69 +546,168 @@ pub unsafe extern "C" fn remove_cluster_scan_cursor(cursor_id: *const c_char) {
531546
}
532547
}
533548

534-
/// Parse cluster scan arguments from C-style arrays.
549+
/// Build cluster scan arguments from C-style arrays.
535550
///
536551
/// # Safety
537552
/// * `args` and `arg_lengths` must be valid arrays of length `arg_count`
538553
/// * Each pointer in `args` must point to valid memory of the corresponding length
539-
unsafe fn parse_cluster_scan_args(
554+
unsafe fn build_cluster_scan_args(
555+
arg_count: u64,
540556
args: *const usize,
541557
arg_lengths: *const u64,
542-
arg_count: u64,
543-
) -> redis::ClusterScanArgs {
558+
failure_callback: FailureCallback,
559+
callback_index: usize,
560+
) -> Option<redis::ClusterScanArgs> {
544561
if arg_count == 0 {
545-
return redis::ClusterScanArgs::builder().build();
562+
return Some(redis::ClusterScanArgs::builder().build());
546563
}
547564

548-
let mut pattern: Option<&[u8]> = None;
549-
let mut object_type: Option<&[u8]> = None;
550-
let mut count: Option<&[u8]> = None;
551-
552-
let mut i = 0;
553-
while i < arg_count as usize {
554-
let arg_ptr = unsafe { *args.add(i) as *const u8 };
555-
let arg_len = unsafe { *arg_lengths.add(i) as usize };
556-
let arg = unsafe { slice::from_raw_parts(arg_ptr, arg_len) };
557-
558-
match arg {
559-
b"MATCH" if i + 1 < arg_count as usize => {
560-
i += 1;
561-
let pattern_ptr = unsafe { *args.add(i) as *const u8 };
562-
let pattern_len = unsafe { *arg_lengths.add(i) as usize };
563-
pattern = Some(unsafe { slice::from_raw_parts(pattern_ptr, pattern_len) });
565+
let arg_vec = unsafe { convert_double_pointer_to_vec(args, arg_count, arg_lengths) };
566+
567+
let mut pattern: &[u8] = &[];
568+
let mut object_type: &[u8] = &[];
569+
let mut count: &[u8] = &[];
570+
571+
let mut iter = arg_vec.iter().peekable();
572+
while let Some(arg) = iter.next() {
573+
match *arg {
574+
b"MATCH" => match iter.next() {
575+
Some(pat) => pattern = pat,
576+
None => {
577+
unsafe {
578+
report_error(
579+
failure_callback,
580+
callback_index,
581+
"No argument following MATCH.".into(),
582+
RequestErrorType::Unspecified,
583+
);
584+
}
585+
return None;
586+
}
587+
},
588+
b"TYPE" => match iter.next() {
589+
Some(obj_type) => object_type = obj_type,
590+
None => {
591+
unsafe {
592+
report_error(
593+
failure_callback,
594+
callback_index,
595+
"No argument following TYPE.".into(),
596+
RequestErrorType::Unspecified,
597+
);
598+
}
599+
return None;
600+
}
601+
},
602+
b"COUNT" => match iter.next() {
603+
Some(c) => count = c,
604+
None => {
605+
unsafe {
606+
report_error(
607+
failure_callback,
608+
callback_index,
609+
"No argument following COUNT.".into(),
610+
RequestErrorType::Unspecified,
611+
);
612+
}
613+
return None;
614+
}
615+
},
616+
_ => {
617+
unsafe {
618+
report_error(
619+
failure_callback,
620+
callback_index,
621+
"Unknown cluster scan argument".into(),
622+
RequestErrorType::Unspecified,
623+
);
624+
}
625+
return None;
564626
}
565-
b"TYPE" if i + 1 < arg_count as usize => {
566-
i += 1;
567-
let type_ptr = unsafe { *args.add(i) as *const u8 };
568-
let type_len = unsafe { *arg_lengths.add(i) as usize };
569-
object_type = Some(unsafe { slice::from_raw_parts(type_ptr, type_len) });
627+
}
628+
}
629+
630+
// Convert back to proper types
631+
let converted_count = match std::str::from_utf8(count) {
632+
Ok(v) => {
633+
if !count.is_empty() {
634+
match v.parse::<u32>() {
635+
Ok(v) => v,
636+
Err(_) => {
637+
unsafe {
638+
report_error(
639+
failure_callback,
640+
callback_index,
641+
"Invalid COUNT value".into(),
642+
RequestErrorType::Unspecified,
643+
);
644+
}
645+
return None;
646+
}
647+
}
648+
} else {
649+
10 // default count value
570650
}
571-
b"COUNT" if i + 1 < arg_count as usize => {
572-
i += 1;
573-
let count_ptr = unsafe { *args.add(i) as *const u8 };
574-
let count_len = unsafe { *arg_lengths.add(i) as usize };
575-
count = Some(unsafe { slice::from_raw_parts(count_ptr, count_len) });
651+
}
652+
Err(_) => {
653+
unsafe {
654+
report_error(
655+
failure_callback,
656+
callback_index,
657+
"Invalid UTF-8 in COUNT argument".into(),
658+
RequestErrorType::Unspecified,
659+
);
576660
}
577-
_ => {}
661+
return None;
578662
}
579-
i += 1;
580-
}
663+
};
581664

582-
let mut builder = redis::ClusterScanArgs::builder();
583-
if let Some(pattern) = pattern {
584-
builder = builder.with_match_pattern(pattern);
585-
}
586-
if let Some(count_bytes) = count {
587-
if let Ok(count_str) = std::str::from_utf8(count_bytes) {
588-
if let Ok(count_val) = count_str.parse::<u32>() {
589-
builder = builder.with_count(count_val);
665+
let converted_type = match std::str::from_utf8(object_type) {
666+
Ok(v) => redis::ObjectType::from(v.to_string()),
667+
Err(_) => {
668+
unsafe {
669+
report_error(
670+
failure_callback,
671+
callback_index,
672+
"Invalid UTF-8 in TYPE argument".into(),
673+
RequestErrorType::Unspecified,
674+
);
590675
}
676+
return None;
591677
}
678+
};
679+
680+
let mut cluster_scan_args_builder = redis::ClusterScanArgs::builder();
681+
if !count.is_empty() {
682+
cluster_scan_args_builder = cluster_scan_args_builder.with_count(converted_count);
592683
}
593-
if let Some(type_bytes) = object_type {
594-
if let Ok(type_str) = std::str::from_utf8(type_bytes) {
595-
builder = builder.with_object_type(redis::ObjectType::from(type_str.to_string()));
596-
}
684+
if !pattern.is_empty() {
685+
cluster_scan_args_builder = cluster_scan_args_builder.with_match_pattern(pattern);
686+
}
687+
if !object_type.is_empty() {
688+
cluster_scan_args_builder = cluster_scan_args_builder.with_object_type(converted_type);
689+
}
690+
Some(cluster_scan_args_builder.build())
691+
}
692+
693+
/// Converts a double pointer to a vec.
694+
///
695+
/// # Safety
696+
///
697+
/// `convert_double_pointer_to_vec` returns a `Vec` of u8 slice which holds pointers of C
698+
/// strings. The returned `Vec<&'a [u8]>` is meant to be copied into Rust code. Storing them
699+
/// for later use will cause the program to crash as the pointers will be freed by the caller.
700+
unsafe fn convert_double_pointer_to_vec<'a>(
701+
data: *const usize,
702+
len: u64,
703+
data_len: *const u64,
704+
) -> Vec<&'a [u8]> {
705+
let string_ptrs = unsafe { from_raw_parts(data, len as usize) };
706+
let string_lengths = unsafe { from_raw_parts(data_len, len as usize) };
707+
let mut result = Vec::<&[u8]>::with_capacity(string_ptrs.len());
708+
for (i, &str_ptr) in string_ptrs.iter().enumerate() {
709+
let slice = unsafe { from_raw_parts(str_ptr as *const u8, string_lengths[i] as usize) };
710+
result.push(slice);
597711
}
598-
builder.build()
712+
result
599713
}

0 commit comments

Comments
 (0)