diff --git a/sdk/core/azure_core_test/src/lib.rs b/sdk/core/azure_core_test/src/lib.rs index fba572f048..5ae2720878 100644 --- a/sdk/core/azure_core_test/src/lib.rs +++ b/sdk/core/azure_core_test/src/lib.rs @@ -17,15 +17,17 @@ pub use azure_core::test::TestMode; #[derive(Clone, Debug)] pub struct TestContext { test_mode: TestMode, + crate_dir: &'static str, test_name: &'static str, } impl TestContext { /// Not intended for use outside the `azure_core` crates. #[doc(hidden)] - pub fn new(test_mode: TestMode, test_name: &'static str) -> Self { + pub fn new(test_mode: TestMode, crate_dir: &'static str, test_name: &'static str) -> Self { Self { test_mode, + crate_dir, test_name, } } @@ -35,6 +37,11 @@ impl TestContext { self.test_mode } + /// Gets the root directory of the crate under test. + pub fn crate_dir(&self) -> &'static str { + self.crate_dir + } + /// Gets the current test function name. pub fn test_name(&self) -> &'static str { self.test_name @@ -47,8 +54,16 @@ mod tests { #[test] fn test_content_new() { - let ctx = TestContext::new(TestMode::default(), "test_content_new"); + let ctx = TestContext::new( + TestMode::default(), + env!("CARGO_MANIFEST_DIR"), + "test_content_new", + ); assert_eq!(ctx.test_mode(), TestMode::Playback); + assert!(ctx + .crate_dir() + .replace("\\", "/") + .ends_with("sdk/core/azure_core_test")); assert_eq!(ctx.test_name(), "test_content_new"); } } diff --git a/sdk/core/azure_core_test_macros/src/test.rs b/sdk/core/azure_core_test_macros/src/test.rs index ec14b09d34..1d0ef41f63 100644 --- a/sdk/core/azure_core_test_macros/src/test.rs +++ b/sdk/core/azure_core_test_macros/src/test.rs @@ -9,7 +9,8 @@ use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, Meta, PatType, Result, const INVALID_RECORDED_ATTRIBUTE_MESSAGE: &str = "expected `#[recorded::test]` or `#[recorded::test(live)]`"; -const INVALID_RECORDED_FUNCTION_MESSAGE: &str = "expected `fn(TestContext)` function signature"; +const INVALID_RECORDED_FUNCTION_MESSAGE: &str = + "expected `async fn(TestContext)` function signature with optional `Result` return"; // cspell:ignore asyncness pub fn parse_test(attr: TokenStream, item: TokenStream) -> Result { @@ -17,15 +18,18 @@ pub fn parse_test(attr: TokenStream, item: TokenStream) -> Result { let ItemFn { attrs, vis, - mut sig, + sig: original_sig, block, } = syn::parse2(item)?; - // Use #[tokio::test] for async functions; otherwise, #[test]. - let mut test_attr: TokenStream = if sig.asyncness.is_some() { - quote! { #[::tokio::test] } - } else { - quote! { #[::core::prelude::v1::test] } + let mut test_attr: TokenStream = match original_sig.asyncness { + Some(_) => quote! { #[::tokio::test] }, + None => { + return Err(syn::Error::new( + original_sig.span(), + INVALID_RECORDED_FUNCTION_MESSAGE, + )) + } }; // Ignore live-only tests if not running live tests. @@ -36,21 +40,23 @@ pub fn parse_test(attr: TokenStream, item: TokenStream) -> Result { }); } - let mut inputs = sig.inputs.iter(); - let preamble = match inputs.next() { - None if recorded_attrs.live => TokenStream::new(), - Some(FnArg::Typed(PatType { pat, ty, .. })) if is_test_context(ty.as_ref()) => { + let fn_name = &original_sig.ident; + let mut inputs = original_sig.inputs.iter(); + let setup = match inputs.next() { + None if recorded_attrs.live => quote! { + #fn_name().await + }, + Some(FnArg::Typed(PatType { ty, .. })) if is_test_context(ty.as_ref()) => { let test_mode = test_mode_to_tokens(test_mode); - let fn_name = &sig.ident; - quote! { #[allow(dead_code)] - let #pat = #ty::new(#test_mode, stringify!(#fn_name)); + let ctx = #ty::new(#test_mode, env!("CARGO_MANIFEST_DIR"), stringify!(#fn_name)); + #fn_name(ctx).await } } _ => { return Err(syn::Error::new( - sig.ident.span(), + original_sig.ident.span(), INVALID_RECORDED_FUNCTION_MESSAGE, )) } @@ -63,14 +69,18 @@ pub fn parse_test(attr: TokenStream, item: TokenStream) -> Result { )); } - // Empty the parameters and return our rewritten test function. - sig.inputs.clear(); + // Clear the actual test method parameters. + let mut outer_sig = original_sig.clone(); + outer_sig.inputs.clear(); + Ok(quote! { #test_attr #(#attrs)* - #vis #sig { - #preamble - #block + #vis #outer_sig { + #original_sig { + #block + } + #setup } }) } diff --git a/sdk/eventhubs/azure_messaging_eventhubs/tests/producer.rs b/sdk/eventhubs/azure_messaging_eventhubs/tests/producer.rs index 0ca6752ef8..46e82ee585 100644 --- a/sdk/eventhubs/azure_messaging_eventhubs/tests/producer.rs +++ b/sdk/eventhubs/azure_messaging_eventhubs/tests/producer.rs @@ -139,7 +139,7 @@ async fn test_get_partition_properties() { } #[recorded::test(live)] -fn test_create_eventdata() { +async fn test_create_eventdata() { common::setup(); let data = b"hello world"; let ed1 = azure_messaging_eventhubs::models::EventData::builder()