From b4a5d441ca4185f368f447673bda47c17e60ae10 Mon Sep 17 00:00:00 2001 From: Heath Stewart Date: Mon, 24 Nov 2025 15:55:01 -0800 Subject: [PATCH] Eliminating locking and improve safety with Pagers, Pollers Resolves #3294 --- sdk/core/azure_core/src/http/pager.rs | 727 +++++++++--------- sdk/core/azure_core/src/lib.rs | 20 + .../generated/clients/certificate_client.rs | 8 +- .../src/generated/clients/secret_client.rs | 6 +- .../tests/secret_client.rs | 1 + 5 files changed, 410 insertions(+), 352 deletions(-) diff --git a/sdk/core/azure_core/src/http/pager.rs b/sdk/core/azure_core/src/http/pager.rs index 5e858ecbd9..8401e264ba 100644 --- a/sdk/core/azure_core/src/http/pager.rs +++ b/sdk/core/azure_core/src/http/pager.rs @@ -3,7 +3,11 @@ //! Types and methods for pageable responses. +// TODO: Remove once tests re-enabled! +#![allow(missing_docs, unexpected_cfgs)] + use crate::{ + conditional_send::ConditionalSend, error::ErrorKind, http::{ headers::HeaderName, policies::create_public_api_span, response::Response, Context, @@ -12,16 +16,9 @@ use crate::{ tracing::{Span, SpanStatus}, }; use async_trait::async_trait; -use futures::{stream::unfold, FutureExt, Stream}; -use std::{ - fmt, - future::Future, - ops::Deref, - pin::Pin, - str::FromStr, - sync::{Arc, Mutex}, - task, -}; +use futures::{pin_mut, stream::FusedStream, FutureExt, Stream, StreamExt}; +use pin_project::pin_project; +use std::{fmt, future::Future, ops::Deref, pin::Pin, str::FromStr, sync::Arc, task}; /// Represents the state of a [`Pager`] or [`PageIterator`]. #[derive(Debug, Default, PartialEq, Eq)] @@ -97,7 +94,7 @@ pub enum PagerResult> { } impl PagerResult, String> { - /// Creates a [`PagerResult`] from the provided response, extracting the continuation value from the provided header. + /// Creates a [`PagerResult`] from the provided response, extracting the continuation value from the provided header. /// /// If the provided response has a header with the matching name, this returns [`PagerResult::More`], using the value from the header as the continuation. /// If the provided response does not have a header with the matching name, this returns [`PagerResult::Done`]. @@ -154,7 +151,7 @@ where /// Represents a paginated stream of items returned by a collection request to a service. /// -/// Specifically, this is a [`ItemIterator`] that yields [`Response`] items. +/// Specifically, this is a [`ItemIterator`] that yields [`Response`] items. /// /// # Examples /// @@ -204,7 +201,60 @@ where /// } /// # Ok(()) } /// ``` -pub type Pager = ItemIterator>; +#[derive(Debug)] +#[pin_project(project = PagerProjection, project_replace = PagerProjectedOwned)] +pub struct Pager +where + P: Page + DeserializeWith + ConditionalSend + 'static, + F: Format + ConditionalSend + 'static, +{ + #[pin] + iter: Pin> + 'static>>, +} + +impl Pager +where + P: Page + DeserializeWith + ConditionalSend + 'static, + F: Format + ConditionalSend + 'static, +{ + pub fn new(make_request: Func, options: Option>) -> Self + where + C: AsRef + Clone + FromStr + ConditionalSend + 'static, + Func: Fn(PagerState, PagerOptions<'static>) -> Fut + ConditionalSend + 'static, + Fut: Future, C>>> + + ConditionalSend + + 'static, + ::Err: std::error::Error, + { + Self { + iter: Box::pin(ItemIterator::new(make_request, options)), + } + } + + pub fn continuation_token(&self) -> Option<&str> { + self.iter.continuation_token() + } + + pub fn into_pages(self) -> Pin> + 'static>> { + todo!() + } +} + +impl Stream for Pager +where + P: Page + DeserializeWith + ConditionalSend + 'static, + F: Format + ConditionalSend + 'static, +{ + type Item = crate::Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + let mut this = self.project(); + this.iter.poll_next_unpin(cx) + } +} /// Options for configuring the behavior of a [`Pager`]. #[derive(Clone, Debug, Default)] @@ -236,7 +286,7 @@ pub struct PagerOptions<'a> { /// // which is the first page in this example. /// let options = SecretClientListSecretPropertiesOptions { /// method_options: PagerOptions { - /// continuation_token: pager.continuation_token(), + /// continuation_token: pager.continuation_token().map(Into::into), /// ..Default::default() /// }, /// ..Default::default() @@ -251,12 +301,6 @@ pub struct PagerOptions<'a> { pub continuation_token: Option, } -#[cfg(not(target_arch = "wasm32"))] -type BoxedStream

= Box> + Send>; - -#[cfg(target_arch = "wasm32")] -type BoxedStream

= Box>>; - /// Iterates over a collection of items or individual pages of items from a service. /// /// You can asynchronously iterate over items returned by a collection request to a service, @@ -310,21 +354,36 @@ type BoxedStream

= Box>>; /// } /// # Ok(()) } /// ``` -#[pin_project::pin_project] -pub struct ItemIterator { +#[pin_project(project = ItemIteratorProjection, project_replace = ItemIteratorProjectionOwned)] +pub struct ItemIterator +where + // These type constraints are copied for all implementations to ease maintenance; + // though, not all are necessary in every case. + P: Page + 'static, + C: AsRef + Clone + FromStr + ConditionalSend + 'static, + F: Fn(PagerState, PagerOptions<'static>) -> Fut + ConditionalSend + 'static, + Fut: Future>> + ConditionalSend + 'static, + ::Err: std::error::Error, +{ #[pin] - stream: Pin>, - continuation_token: Option, - next_token: Arc>>, + iter: PageIterator, + next_token: Option, current: Option, } -impl ItemIterator

{ - /// Creates a [`ItemIterator

`] from a callback that will be called repeatedly to request each page. +impl ItemIterator +where + P: Page + 'static, + C: AsRef + Clone + FromStr + ConditionalSend + 'static, + F: Fn(PagerState, PagerOptions<'static>) -> Fut + ConditionalSend + 'static, + Fut: Future>> + ConditionalSend + 'static, + ::Err: std::error::Error, +{ + /// Creates a [`ItemIterator`] from a callback that will be called repeatedly to request each page. /// - /// This method expect a callback that accepts a single [`PagerState`] parameter, and returns a [`PagerResult`] value asynchronously. + /// This method expect a callback that accepts a single [`PagerState`] parameter, and returns a [`PagerResult`] value asynchronously. /// The `C` type parameter is the type of the next link/continuation token. It may be any [`Send`]able type. - /// The result will be an asynchronous stream of [`Result`](crate::Result) values. + /// The result will be an asynchronous stream of [`Result`](crate::Result) values. /// /// The first time your callback is called, it will be called with [`Option::None`], indicating no next link/continuation token is present. /// @@ -356,7 +415,7 @@ impl ItemIterator

{ /// } /// let url = "https://example.com/my_paginated_api".parse().unwrap(); /// let mut base_req = Request::new(url, Method::Get); - /// let pager = ItemIterator::from_callback(move |next_link: PagerState, options: PagerOptions| { + /// let pager = ItemIterator::new(move |next_link: PagerState, options: PagerOptions| { /// // The callback must be 'static, so you have to clone and move any values you want to use. /// let pipeline = pipeline.clone(); /// let api_version = api_version.clone(); @@ -411,7 +470,7 @@ impl ItemIterator

{ /// } /// let url = "https://example.com/my_paginated_api".parse().unwrap(); /// let mut base_req = Request::new(url, Method::Get); - /// let pager = ItemIterator::from_callback(move |continuation, options| { + /// let pager = ItemIterator::new(move |continuation, options| { /// // The callback must be 'static, so you have to clone and move any values you want to use. /// let pipeline = pipeline.clone(); /// let mut req = base_req.clone(); @@ -427,121 +486,95 @@ impl ItemIterator

{ /// } /// }, None); /// ``` - pub fn from_callback< - // This is a bit gnarly, but the only thing that differs between the WASM/non-WASM configs is the presence of Send bounds. - #[cfg(not(target_arch = "wasm32"))] C: AsRef + FromStr + Send + 'static, - #[cfg(not(target_arch = "wasm32"))] F: Fn(PagerState, PagerOptions<'static>) -> Fut + Send + 'static, - #[cfg(not(target_arch = "wasm32"))] Fut: Future>> + Send + 'static, - #[cfg(target_arch = "wasm32")] C: AsRef + FromStr + 'static, - #[cfg(target_arch = "wasm32")] F: Fn(PagerState, PagerOptions<'static>) -> Fut + 'static, - #[cfg(target_arch = "wasm32")] Fut: Future>> + 'static, - >( - make_request: F, - options: Option>, - ) -> Self - where - ::Err: std::error::Error, - { + pub fn new(make_request: F, options: Option>) -> Self { let options = options.unwrap_or_default(); // Start from the optional `PagerOptions::continuation_token`. let continuation_token = options.continuation_token.clone(); - let next_token = Arc::new(Mutex::new(continuation_token.clone())); - let stream = iter_from_callback(make_request, options, next_token.clone()); + let next_token = continuation_token.clone(); Self { - stream: Box::pin(stream), - continuation_token, + iter: PageIterator { + make_request: Box::pin(make_request), + continuation_token, + options, + state: State::Init, + added_span: false, + }, next_token, current: None, } } - /// Creates a [`ItemIterator

`] from a raw stream of [`Result

`](crate::Result

) values. - /// - /// This constructor is used when you are implementing a completely custom stream and want to use it as a pager. - pub fn from_stream< - // This is a bit gnarly, but the only thing that differs between the WASM/non-WASM configs is the presence of Send bounds. - #[cfg(not(target_arch = "wasm32"))] S: Stream> + Send + 'static, - #[cfg(target_arch = "wasm32")] S: Stream> + 'static, - >( - stream: S, - ) -> Self { - Self { - stream: Box::pin(stream), - continuation_token: None, - next_token: Default::default(), - current: None, - } + pub fn continuation_token(&self) -> Option<&str> { + self.iter.continuation_token() } - /// Gets a [`PageIterator

`] to iterate over a collection of pages from a service. - /// - /// You can use this to asynchronously iterate pages returned by a collection request to a service. - /// This allows you to get the individual pages' [`Response

`], from which you can iterate items in each page - /// or deserialize the raw response as appropriate. - /// - /// The returned `PageIterator` resumes from the current page until _after_ all items are processed. - /// It does not continue on the next page until you call `next()` after the last item in the current page - /// because of how iterators are implemented. This may yield duplicates but will reduce the likelihood of skipping items instead. - pub fn into_pages(self) -> PageIterator

{ - // Attempt to start paging from the current page so that we don't skip items, - // assuming the service collection hasn't changed (most services don't create ephemeral snapshots). - if let Ok(mut token) = self.next_token.lock() { - *token = self.continuation_token; - } + pub fn into_pages(self) -> Box + 'static> { + Box::new(self.iter) + } +} - PageIterator { - stream: self.stream, - continuation_token: self.next_token, - } +pub trait ItemPager: Stream> + fmt::Debug { + fn continuation_token(&self) -> Option<&str>; + fn into_pages(self) -> Box + 'static>; +} + +impl ItemPager

for ItemIterator +where + P: Page + 'static, + C: AsRef + Clone + FromStr + ConditionalSend + 'static, + F: Fn(PagerState, PagerOptions<'static>) -> Fut + ConditionalSend + 'static, + Fut: Future>> + ConditionalSend + 'static, + ::Err: std::error::Error, +{ + fn continuation_token(&self) -> Option<&str> { + Self::continuation_token(self) } - /// Gets the continuation token for the current page. - /// - /// Pass this in [`PagerOptions::continuation_token`] to create a `ItemIterator` that, when first iterated, - /// will return the current page until _after_ all items are iterated. - /// It does not continue on the next page until you call `next()` after the last item in the current page - /// because of how iterators are implemented. This may yield duplicates but will reduce the likelihood of skipping items instead. - pub fn continuation_token(&self) -> Option { - // Get the continuation_token because that will be used to start over with the current page. - self.continuation_token.clone() + fn into_pages(self) -> Box> { + Self::into_pages(self) } } -impl futures::Stream for ItemIterator

{ +impl Stream for ItemIterator +where + P: Page + 'static, + C: AsRef + Clone + FromStr + ConditionalSend + 'static, + F: Fn(PagerState, PagerOptions<'static>) -> Fut + ConditionalSend + 'static, + Fut: Future>> + ConditionalSend + 'static, + ::Err: std::error::Error, +{ type Item = crate::Result; fn poll_next( self: Pin<&mut Self>, - cx: &mut task::Context<'_>, - ) -> task::Poll> { - let mut projected_self = self.project(); + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let mut this = self.project(); loop { - if let Some(current) = projected_self.current.as_mut() { + if let Some(current) = this.current.as_mut() { if let Some(item) = current.next() { return task::Poll::Ready(Some(Ok(item))); } // Reset the iterator and poll for the next page. - *projected_self.current = None; + *this.current = None; } // Set the current_token to the next page only after iterating through all items. - if let Ok(token) = projected_self.next_token.lock() { - tracing::trace!( - "updating continuation_token from {:?} to {:?}", - projected_self.continuation_token, - token - ); - *projected_self.continuation_token = token.clone(); - } - - match projected_self.stream.as_mut().poll_next(cx) { + tracing::trace!( + "updating continuation_token from {:?} to {:?}", + &this.iter.continuation_token, + &this.next_token + ); + this.iter.continuation_token = this.next_token.clone(); + + match this.iter.as_mut().poll_next(cx) { task::Poll::Ready(page) => match page { Some(Ok(page)) => match page.into_items().poll_unpin(cx) { task::Poll::Ready(Ok(iter)) => { - *projected_self.current = Some(iter); + *this.current = Some(iter); continue; } task::Poll::Ready(Err(err)) => return task::Poll::Ready(Some(Err(err))), @@ -556,14 +589,35 @@ impl futures::Stream for ItemIterator

{ } } -impl fmt::Debug for ItemIterator

{ +impl fmt::Debug for ItemIterator +where + P: Page + 'static, + C: AsRef + Clone + FromStr + ConditionalSend + 'static, + F: Fn(PagerState, PagerOptions<'static>) -> Fut + ConditionalSend + 'static, + Fut: Future>> + ConditionalSend + 'static, + ::Err: std::error::Error, +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ItemIterator") - .field("continuation_token", &self.continuation_token) + .field("iter", &self.iter) + .field("next_token", &self.next_token) .finish_non_exhaustive() } } +impl FusedStream for ItemIterator +where + P: Page + 'static, + C: AsRef + Clone + FromStr + ConditionalSend + 'static, + F: Fn(PagerState, PagerOptions<'static>) -> Fut + ConditionalSend + 'static, + Fut: Future>> + ConditionalSend + 'static, + ::Err: std::error::Error, +{ + fn is_terminated(&self) -> bool { + self.iter.is_terminated() + } +} + /// Iterates over a collection pages of items from a service. /// /// # Examples @@ -594,19 +648,37 @@ impl fmt::Debug for ItemIterator

{ /// } /// # Ok(()) } /// ``` -#[pin_project::pin_project] -pub struct PageIterator

{ +#[must_use = "streams do nothing unless polled"] +#[pin_project(project = PageIteratorProjection, project_replace = PageIteratorProjectionOwned)] +pub struct PageIterator +where + // These type constraints are copied for all implementations to ease maintenance; + // though, not all are necessary in every case. + C: AsRef + Clone + FromStr + ConditionalSend + 'static, + F: Fn(PagerState, PagerOptions<'static>) -> Fut + ConditionalSend + 'static, + Fut: Future>> + ConditionalSend + 'static, + ::Err: std::error::Error, +{ #[pin] - stream: Pin>, - continuation_token: Arc>>, + make_request: Pin>, + continuation_token: Option, + options: PagerOptions<'static>, + state: State, + added_span: bool, } -impl

PageIterator

{ - /// Creates a [`PageIterator

`] from a callback that will be called repeatedly to request each page. +impl PageIterator +where + C: AsRef + Clone + FromStr + ConditionalSend + 'static, + F: Fn(PagerState, PagerOptions<'static>) -> Fut + ConditionalSend + 'static, + Fut: Future>> + ConditionalSend + 'static, + ::Err: std::error::Error, +{ + /// Creates a [`PageIterator`] from a callback that will be called repeatedly to request each page. /// - /// This method expect a callback that accepts a single [`PagerState`] parameter, and returns a [`PagerResult`] value asynchronously. + /// This method expect a callback that accepts a single [`PagerState`] parameter, and returns a [`PagerResult`] value asynchronously. /// The `C` type parameter is the type of the next link/continuation token. It may be any [`Send`]able type. - /// The result will be an asynchronous stream of [`Result`](crate::Result) values. + /// The result will be an asynchronous stream of [`Result`](crate::Result) values. /// /// The first time your callback is called, it will be called with [`PagerState::Initial`], indicating no next link/continuation token is present. /// @@ -629,7 +701,7 @@ impl

PageIterator

{ /// } /// let url = "https://example.com/my_paginated_api".parse().unwrap(); /// let mut base_req = Request::new(url, Method::Get); - /// let pager = PageIterator::from_callback(move |next_link: PagerState, options: PagerOptions<'static>| { + /// let pager = PageIterator::new(move |next_link: PagerState, options: PagerOptions<'static>| { /// // The callback must be 'static, so you have to clone and move any values you want to use. /// let pipeline = pipeline.clone(); /// let api_version = api_version.clone(); @@ -675,7 +747,7 @@ impl

PageIterator

{ /// } /// let url = "https://example.com/my_paginated_api".parse().unwrap(); /// let mut base_req = Request::new(url, Method::Get); - /// let pager = PageIterator::from_callback(move |continuation, options| { + /// let pager = PageIterator::new(move |continuation, options| { /// // The callback must be 'static, so you have to clone and move any values you want to use. /// let pipeline = pipeline.clone(); /// let mut req = base_req.clone(); @@ -691,76 +763,198 @@ impl

PageIterator

{ /// } /// }, None); /// ``` - pub fn from_callback< - // This is a bit gnarly, but the only thing that differs between the WASM/non-WASM configs is the presence of Send bounds. - #[cfg(not(target_arch = "wasm32"))] C: AsRef + FromStr + Send + 'static, - #[cfg(not(target_arch = "wasm32"))] F: Fn(PagerState, PagerOptions<'static>) -> Fut + Send + 'static, - #[cfg(not(target_arch = "wasm32"))] Fut: Future>> + Send + 'static, - #[cfg(target_arch = "wasm32")] C: AsRef + FromStr + 'static, - #[cfg(target_arch = "wasm32")] F: Fn(PagerState, PagerOptions<'static>) -> Fut + 'static, - #[cfg(target_arch = "wasm32")] Fut: Future>> + 'static, - >( - make_request: F, - options: Option>, - ) -> Self - where - ::Err: std::error::Error, - { + #[cfg(test)] + fn new(make_request: F, options: Option>) -> Self { let options = options.unwrap_or_default(); // Start from the optional `PagerOptions::continuation_token`. - let continuation_token = Arc::new(Mutex::new(options.continuation_token.clone())); - let stream = iter_from_callback(make_request, options, continuation_token.clone()); + let continuation_token = options.continuation_token.clone(); Self { - stream: Box::pin(stream), + make_request: Box::pin(make_request), continuation_token, + options, + state: State::Init, + added_span: false, } } - /// Creates a [`PageIterator

`] from a raw stream of [`Result

`](crate::Result

) values. - /// - /// This constructor is used when you are implementing a completely custom stream and want to use it as a pager. - pub fn from_stream< - // This is a bit gnarly, but the only thing that differs between the WASM/non-WASM configs is the presence of Send bounds. - #[cfg(not(target_arch = "wasm32"))] S: Stream> + Send + 'static, - #[cfg(target_arch = "wasm32")] S: Stream> + 'static, - >( - stream: S, - ) -> Self { - Self { - stream: Box::pin(stream), - continuation_token: Default::default(), - } - } - - /// Gets the continuation token for the current page. - /// - /// Pass this to [`PagerOptions::continuation_token`] to create a `PageIterator` that, when first iterated, - /// will return the next page. You can use this to page results across separate processes. - pub fn continuation_token(&self) -> Option { - if let Ok(token) = self.continuation_token.lock() { - return token.clone(); - } - - None + pub fn continuation_token(&self) -> Option<&str> { + self.continuation_token.as_deref() } } -impl

futures::Stream for PageIterator

{ +impl Stream for PageIterator +where + C: AsRef + Clone + FromStr + ConditionalSend + 'static, + F: Fn(PagerState, PagerOptions<'static>) -> Fut + ConditionalSend + 'static, + Fut: Future>> + ConditionalSend + 'static, + ::Err: std::error::Error, +{ type Item = crate::Result

; fn poll_next( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - self.project().stream.poll_next(cx) + let this = self.project(); + + // When in the "Init" state, we are either starting fresh or resuming from a continuation token. + // In either case, attach a span to the context for the entire paging operation. + if *this.state == State::Init { + tracing::debug!("establish a public API span for new pager."); + + // At the very start of polling, create a span for the entire request, and attach it to the context + let span = create_public_api_span(&this.options.context, None, None); + if let Some(s) = span { + *this.added_span = true; + let old_context = std::mem::take(this.options); + *this.options = PagerOptions { + context: old_context.context.with_value(s), + continuation_token: old_context.continuation_token, + }; + } + } + + // Get the `continuation_token` to pick up where we left off, or None for the initial page, + // but don't override the terminal `State::Done`. + if *this.state != State::Done { + let next_state = match this.continuation_token.as_deref() { + Some(n) => match n.parse() { + Ok(s) => State::More(s), + Err(err) => { + let error = + crate::Error::with_message_fn(ErrorKind::DataConversion, || { + format!("invalid continuation token: {err}") + }); + *this.state = State::Done; + return task::Poll::Ready(Some(Err(error))); + } + }, + // Restart the pager if `continuation_token` is None indicating we resumed from before or within the first page. + None => State::Init, + }; + *this.state = next_state; + } + + let result = match *this.state { + State::Init => { + tracing::debug!("initial page request"); + let options = this.options.clone(); + let fut = (this.make_request)(PagerState::Initial, options); + pin_mut!(fut); + + match fut.poll(cx) { + task::Poll::Ready(result) => result, + task::Poll::Pending => return task::Poll::Pending, + } + } + State::More(ref n) => { + tracing::debug!("subsequent page request to {:?}", AsRef::::as_ref(n)); + let options = this.options.clone(); + let fut = (this.make_request)(PagerState::More(n.clone()), options); + pin_mut!(fut); + + match fut.poll(cx) { + task::Poll::Ready(result) => result, + task::Poll::Pending => return task::Poll::Pending, + } + } + State::Done => { + tracing::debug!("done"); + // Set the `continuation_token` to None now that we are done. + *this.continuation_token = None; + return task::Poll::Ready(None); + } + }; + + // Update continuation token and instrumentation. + match result { + Err(e) => { + if *this.added_span { + if let Some(span) = this.options.context.value::>() { + // Mark the span as an error with an appropriate description. + span.set_status(SpanStatus::Error { + description: e.to_string(), + }); + span.set_attribute("error.type", e.kind().to_string().into()); + span.end(); + } + } + + *this.state = State::Done; + task::Poll::Ready(Some(Err(e))) + } + + Ok(PagerResult::More { + response, + continuation: next_token, + }) => { + // Set the `continuation_token` to the next page. + *this.continuation_token = Some(next_token.as_ref().into()); + task::Poll::Ready(Some(Ok(response))) + } + + Ok(PagerResult::Done { response }) => { + // Set the `continuation_token` to None now that we are done. + *this.continuation_token = None; + + // When the result is done, finalize the span. Note that we only do that if we created the span in the first place; + // otherwise, it is the responsibility of the caller to end their span. + if *this.added_span { + if let Some(span) = this.options.context.value::>() { + span.end(); + } + } + + task::Poll::Ready(Some(Ok(response))) + } + } } } -impl

fmt::Debug for PageIterator

{ +impl fmt::Debug for PageIterator +where + C: AsRef + Clone + FromStr + ConditionalSend + 'static, + F: Fn(PagerState, PagerOptions<'static>) -> Fut + ConditionalSend + 'static, + Fut: Future>> + ConditionalSend + 'static, + ::Err: std::error::Error, +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("PageIterator").finish_non_exhaustive() + f.debug_struct("PageIterator") + .field("continuation_token", &self.continuation_token) + .field("options", &self.options) + .field("state", &self.state) + .field("added_span", &self.added_span) + .finish_non_exhaustive() + } +} + +impl FusedStream for PageIterator +where + C: AsRef + Clone + FromStr + ConditionalSend + 'static, + F: Fn(PagerState, PagerOptions<'static>) -> Fut + ConditionalSend + 'static, + Fut: Future>> + ConditionalSend + 'static, + ::Err: std::error::Error, +{ + fn is_terminated(&self) -> bool { + self.state == State::Done + } +} + +pub trait PagePager

: Stream> { + fn continuation_token(&self) -> Option<&str>; +} + +impl PagePager

for PageIterator +where + C: AsRef + Clone + FromStr + ConditionalSend + 'static, + F: Fn(PagerState, PagerOptions<'static>) -> Fut + ConditionalSend + 'static, + Fut: Future>> + ConditionalSend + 'static, + ::Err: std::error::Error, +{ + fn continuation_token(&self) -> Option<&str> { + Self::continuation_token(self) } } @@ -800,163 +994,13 @@ where } } -#[derive(Debug)] -struct StreamState<'a, C, F> -where - C: AsRef, -{ - state: State, - make_request: F, - continuation_token: Arc>>, - options: PagerOptions<'a>, - added_span: bool, -} - -fn iter_from_callback< - P, - // This is a bit gnarly, but the only thing that differs between the WASM/non-WASM configs is the presence of Send bounds. - #[cfg(not(target_arch = "wasm32"))] C: AsRef + FromStr + Send + 'static, - #[cfg(not(target_arch = "wasm32"))] F: Fn(PagerState, PagerOptions<'static>) -> Fut + Send + 'static, - #[cfg(not(target_arch = "wasm32"))] Fut: Future>> + Send + 'static, - #[cfg(target_arch = "wasm32")] C: AsRef + FromStr + 'static, - #[cfg(target_arch = "wasm32")] F: Fn(PagerState, PagerOptions<'static>) -> Fut + 'static, - #[cfg(target_arch = "wasm32")] Fut: Future>> + 'static, ->( - make_request: F, - options: PagerOptions<'static>, - continuation_token: Arc>>, -) -> impl Stream> + 'static -where - ::Err: std::error::Error, -{ - unfold( - StreamState { - state: State::Init, - make_request, - continuation_token, - options, - added_span: false, - }, - |mut stream_state| async move { - // When in the "Init" state, we are either starting fresh or resuming from a continuation token. In either case, - // attach a span to the context for the entire paging operation. - if stream_state.state == State::Init { - tracing::debug!("establish a public API span for new pager."); - - // At the very start of polling, create a span for the entire request, and attach it to the context - let span = create_public_api_span(&stream_state.options.context, None, None); - if let Some(ref s) = span { - stream_state.added_span = true; - stream_state.options.context = - stream_state.options.context.with_value(s.clone()); - } - } - - // Get the `continuation_token` to pick up where we left off, or None for the initial page, - // but don't override the terminal `State::Done`. - if stream_state.state != State::Done { - let result = match stream_state.continuation_token.lock() { - Ok(next_token) => match next_token.as_deref() { - Some(n) => match n.parse() { - Ok(s) => Ok(State::More(s)), - Err(err) => Err(crate::Error::with_message_fn( - ErrorKind::DataConversion, - || format!("invalid continuation token: {err}"), - )), - }, - // Restart the pager if `next_token` is None indicating we resumed from before or within the first page. - None => Ok(State::Init), - }, - Err(err) => Err(crate::Error::with_message_fn(ErrorKind::Other, || { - format!("continuation token lock: {err}") - })), - }; - - match result { - Ok(state) => stream_state.state = state, - Err(err) => { - stream_state.state = State::Done; - return Some((Err(err), stream_state)); - } - } - } - let result = match stream_state.state { - State::Init => { - tracing::debug!("initial page request"); - (stream_state.make_request)(PagerState::Initial, stream_state.options.clone()) - .await - } - State::More(n) => { - tracing::debug!("subsequent page request to {:?}", AsRef::::as_ref(&n)); - (stream_state.make_request)(PagerState::More(n), stream_state.options.clone()) - .await - } - State::Done => { - tracing::debug!("done"); - // Set the `continuation_token` to None now that we are done. - if let Ok(mut token) = stream_state.continuation_token.lock() { - *token = None; - } - return None; - } - }; - let (item, next_state) = match result { - Err(e) => { - if stream_state.added_span { - if let Some(span) = stream_state.options.context.value::>() { - // Mark the span as an error with an appropriate description. - span.set_status(SpanStatus::Error { - description: e.to_string(), - }); - span.set_attribute("error.type", e.kind().to_string().into()); - span.end(); - } - } - - stream_state.state = State::Done; - return Some((Err(e), stream_state)); - } - Ok(PagerResult::More { - response, - continuation: next_token, - }) => { - // Set the `continuation_token` to the next page. - if let Ok(mut token) = stream_state.continuation_token.lock() { - *token = Some(next_token.as_ref().into()); - } - (Ok(response), State::More(next_token)) - } - Ok(PagerResult::Done { response }) => { - // Set the `continuation_token` to None now that we are done. - if let Ok(mut token) = stream_state.continuation_token.lock() { - *token = None; - } - // When the result is done, finalize the span. Note that we only do that if we created the span in the first place, - // otherwise it is the responsibility of the caller to end their span. - if stream_state.added_span { - if let Some(span) = stream_state.options.context.value::>() { - // P is unconstrained, so it's not possible to retrieve the status code for now. - - span.end(); - } - } - (Ok(response), State::Done) - } - }; - - stream_state.state = next_state; - Some((item, stream_state)) - }, - ) -} - #[cfg(test)] mod tests { + use super::{ItemIterator, PageIterator, Pager, PagerOptions, PagerResult, PagerState}; use crate::{ error::ErrorKind, http::{ headers::{HeaderName, HeaderValue}, - pager::{PageIterator, Pager, PagerOptions, PagerResult, PagerState}, RawResponse, Response, StatusCode, }, }; @@ -984,7 +1028,7 @@ mod tests { #[tokio::test] async fn callback_item_pagination() { - let pager: Pager = Pager::from_callback( + let pager: Pager = Pager::new( |continuation: PagerState, _ctx| async move { match continuation { PagerState::Initial => Ok(PagerResult::More { @@ -1038,7 +1082,7 @@ mod tests { #[tokio::test] async fn callback_item_pagination_error() { - let pager: Pager = Pager::from_callback( + let pager = ItemIterator::new( |continuation: PagerState, _options| async move { match continuation { PagerState::Initial => Ok(PagerResult::More { @@ -1098,8 +1142,7 @@ mod tests { #[tokio::test] async fn page_iterator_with_continuation_token() { // Create the first PageIterator. - let mut first_pager: PageIterator> = - PageIterator::from_callback(make_three_page_callback(), None); + let mut first_pager = PageIterator::new(make_three_page_callback(), None); // Should start with no continuation_token. assert_eq!(first_pager.continuation_token(), None); @@ -1122,19 +1165,16 @@ mod tests { assert_eq!(continuation_token, "next-token-1"); // Create the second PageIterator. - let mut second_pager: PageIterator> = PageIterator::from_callback( + let mut second_pager = PageIterator::new( make_three_page_callback(), Some(PagerOptions { - continuation_token: Some(continuation_token), + continuation_token: Some(continuation_token.into()), ..Default::default() }), ); // Should start with link to second page. - assert_eq!( - second_pager.continuation_token(), - Some("next-token-1".into()) - ); + assert_eq!(second_pager.continuation_token(), Some("next-token-1")); // Advance to second page. let second_page = second_pager @@ -1146,10 +1186,7 @@ mod tests { .expect("expected page"); assert_eq!(second_page.page, Some(2)); assert_eq!(second_page.items, vec![4, 5, 6]); - assert_eq!( - second_pager.continuation_token(), - Some("next-token-2".into()) - ); + assert_eq!(second_pager.continuation_token(), Some("next-token-2")); // Advance to last page. let last_page = second_pager @@ -1167,7 +1204,7 @@ mod tests { #[tokio::test] async fn page_iterator_from_item_iterator_after_first_page() { // Create an ItemIterator and consume all items from first page. - let mut item_pager: Pager = Pager::from_callback(make_three_page_callback(), None); + let mut item_pager = ItemIterator::new(make_three_page_callback(), None); // Should start with no continuation_token. assert_eq!(item_pager.continuation_token(), None); @@ -1221,7 +1258,7 @@ mod tests { #[tokio::test] async fn page_iterator_from_item_iterator_second_page_first_item() { // Create an ItemIterator and consume items up to first item of second page. - let mut item_pager: Pager = Pager::from_callback(make_three_page_callback(), None); + let mut item_pager = ItemIterator::new(make_three_page_callback(), None); // Should start with no continuation_token. assert_eq!(item_pager.continuation_token(), None); @@ -1283,7 +1320,7 @@ mod tests { #[tokio::test] async fn item_iterator_with_continuation_token() { // Create the first ItemIterator. - let mut first_pager: Pager = Pager::from_callback(make_three_page_callback(), None); + let mut first_pager = ItemIterator::new(make_three_page_callback(), None); // Should start with no continuation_token. assert_eq!(first_pager.continuation_token(), None); @@ -1309,7 +1346,7 @@ mod tests { assert_eq!(continuation_token, None); // Create the second ItemIterator with continuation token. - let mut second_pager: Pager = Pager::from_callback( + let mut second_pager = ItemIterator::new( make_three_page_callback(), Some(PagerOptions { continuation_token, @@ -1339,7 +1376,7 @@ mod tests { #[tokio::test] async fn item_iterator_continuation_second_page_second_item() { // Create the first ItemIterator. - let mut first_pager: Pager = Pager::from_callback(make_three_page_callback(), None); + let mut first_pager = ItemIterator::new(make_three_page_callback(), None); // Should start with no continuation_token. assert_eq!(first_pager.continuation_token(), None); @@ -1388,7 +1425,7 @@ mod tests { assert_eq!(continuation_token.as_deref(), Some("next-token-1")); // Create the second ItemIterator with continuation token. - let mut second_pager: Pager = Pager::from_callback( + let mut second_pager = ItemIterator::new( make_three_page_callback(), Some(PagerOptions { continuation_token, @@ -1414,7 +1451,7 @@ mod tests { #[tokio::test] async fn item_iterator_continuation_after_first_page() { // Create the first ItemIterator. - let mut first_pager: Pager = Pager::from_callback(make_three_page_callback(), None); + let mut first_pager = ItemIterator::new(make_three_page_callback(), None); // Should start with no continuation_token. assert_eq!(first_pager.continuation_token(), None); @@ -1446,7 +1483,7 @@ mod tests { assert_eq!(continuation_token, None); // Create the second ItemIterator with continuation token. - let mut second_pager: Pager = Pager::from_callback( + let mut second_pager = ItemIterator::new( make_three_page_callback(), Some(PagerOptions { continuation_token, @@ -1493,7 +1530,7 @@ mod tests { #[tokio::test] async fn callback_item_pagination_from_str_error() { - let mut pager: Pager = Pager::from_callback( + let mut pager = ItemIterator::new( |continuation: PagerState, _ctx| async move { match continuation { PagerState::Initial => Ok(PagerResult::More { diff --git a/sdk/core/azure_core/src/lib.rs b/sdk/core/azure_core/src/lib.rs index 213bbec228..e1a7deff55 100644 --- a/sdk/core/azure_core/src/lib.rs +++ b/sdk/core/azure_core/src/lib.rs @@ -36,6 +36,26 @@ pub mod tracing { #[cfg(feature = "xml")] pub use typespec_client_core::xml; +#[cfg(not(target_arch = "wasm32"))] +mod conditional_send { + /// Conditionally implements [`Send`] based on the `target_arch`. + /// + /// This implementation requires `Send`. + pub trait ConditionalSend: Send {} + + impl ConditionalSend for T where T: Send {} +} + +#[cfg(target_arch = "wasm32")] +mod conditional_send { + /// Conditionally implements [`Send`] based on the `target_arch`. + /// + /// This implementation does not require `Send`. + pub trait ConditionalSend {} + + impl ConditionalSend for T {} +} + mod private { pub trait Sealed {} } diff --git a/sdk/keyvault/azure_security_keyvault_certificates/src/generated/clients/certificate_client.rs b/sdk/keyvault/azure_security_keyvault_certificates/src/generated/clients/certificate_client.rs index 9ab368ed1c..c48f62f9e2 100644 --- a/sdk/keyvault/azure_security_keyvault_certificates/src/generated/clients/certificate_client.rs +++ b/sdk/keyvault/azure_security_keyvault_certificates/src/generated/clients/certificate_client.rs @@ -689,7 +689,7 @@ impl CertificateClient { .append_pair("maxresults", &maxresults.to_string()); } let api_version = self.api_version.clone(); - Ok(Pager::from_callback( + Ok(Pager::new( move |next_link: PagerState, pager_options| { let url = match next_link { PagerState::More(next_link) => { @@ -774,7 +774,7 @@ impl CertificateClient { .append_pair("maxresults", &maxresults.to_string()); } let api_version = self.api_version.clone(); - Ok(Pager::from_callback( + Ok(Pager::new( move |next_link: PagerState, pager_options| { let url = match next_link { PagerState::More(next_link) => { @@ -855,7 +855,7 @@ impl CertificateClient { .append_pair("maxresults", &maxresults.to_string()); } let api_version = self.api_version.clone(); - Ok(Pager::from_callback( + Ok(Pager::new( move |next_link: PagerState, pager_options| { let url = match next_link { PagerState::More(next_link) => { @@ -930,7 +930,7 @@ impl CertificateClient { .append_pair("maxresults", &maxresults.to_string()); } let api_version = self.api_version.clone(); - Ok(Pager::from_callback( + Ok(Pager::new( move |next_link: PagerState, pager_options| { let url = match next_link { PagerState::More(next_link) => { diff --git a/sdk/keyvault/azure_security_keyvault_secrets/src/generated/clients/secret_client.rs b/sdk/keyvault/azure_security_keyvault_secrets/src/generated/clients/secret_client.rs index 7cfd1d9571..70364e1cd2 100644 --- a/sdk/keyvault/azure_security_keyvault_secrets/src/generated/clients/secret_client.rs +++ b/sdk/keyvault/azure_security_keyvault_secrets/src/generated/clients/secret_client.rs @@ -312,7 +312,7 @@ impl SecretClient { .append_pair("maxresults", &maxresults.to_string()); } let api_version = self.api_version.clone(); - Ok(Pager::from_callback( + Ok(Pager::new( move |next_link: PagerState, pager_options| { let url = match next_link { PagerState::More(next_link) => { @@ -388,7 +388,7 @@ impl SecretClient { .append_pair("maxresults", &maxresults.to_string()); } let api_version = self.api_version.clone(); - Ok(Pager::from_callback( + Ok(Pager::new( move |next_link: PagerState, pager_options| { let url = match next_link { PagerState::More(next_link) => { @@ -473,7 +473,7 @@ impl SecretClient { .append_pair("maxresults", &maxresults.to_string()); } let api_version = self.api_version.clone(); - Ok(Pager::from_callback( + Ok(Pager::new( move |next_link: PagerState, pager_options| { let url = match next_link { PagerState::More(next_link) => { diff --git a/sdk/keyvault/azure_security_keyvault_secrets/tests/secret_client.rs b/sdk/keyvault/azure_security_keyvault_secrets/tests/secret_client.rs index dd0c0a36b0..b03f7fda06 100644 --- a/sdk/keyvault/azure_security_keyvault_secrets/tests/secret_client.rs +++ b/sdk/keyvault/azure_security_keyvault_secrets/tests/secret_client.rs @@ -529,6 +529,7 @@ async fn list_secrets_verify_telemetry_rehydrated(ctx: TestContext) -> Result<() first_pager .continuation_token() .expect("expected continuation token to be created after first page") + .into() }; let options = SecretClientListSecretPropertiesOptions { method_options: PagerOptions {