Skip to content

Commit da6dee5

Browse files
fix(data-pipeline): Apply suggestions
1 parent ad49ce7 commit da6dee5

File tree

9 files changed

+126
-124
lines changed

9 files changed

+126
-124
lines changed

data-pipeline/src/agent_info/fetcher.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
use super::{schema::AgentInfo, AgentInfoArc};
77
use anyhow::{anyhow, Result};
88
use arc_swap::ArcSwapOption;
9-
use ddcommon::hyper_migration;
10-
use ddcommon::Endpoint;
9+
use ddcommon::{hyper_migration, worker::Worker, Endpoint};
1110
use http_body_util::BodyExt;
1211
use hyper::{self, body::Buf, header::HeaderName};
1312
use log::{error, info};
@@ -140,11 +139,20 @@ impl AgentInfoFetcher {
140139
}
141140
}
142141

142+
/// Return an AgentInfoArc storing the info received from the agent.
143+
///
144+
/// When the fetcher is running it updates the AgentInfoArc when the agent's info changes.
145+
pub fn get_info(&self) -> AgentInfoArc {
146+
self.info.clone()
147+
}
148+
}
149+
150+
impl Worker for AgentInfoFetcher {
143151
/// Start fetching the info endpoint with the given interval.
144152
///
145153
/// # Warning
146154
/// This method does not return and should be called within a dedicated task.
147-
pub async fn run(&self) {
155+
async fn run(&mut self) {
148156
loop {
149157
let current_info = self.info.load();
150158
let current_hash = current_info.as_ref().map(|info| info.state_hash.as_str());
@@ -164,13 +172,6 @@ impl AgentInfoFetcher {
164172
sleep(self.refresh_interval).await;
165173
}
166174
}
167-
168-
/// Return an AgentInfoArc storing the info received by the agent.
169-
///
170-
/// When the fetcher is running it updates the AgentInfoArc when the agent's info changes.
171-
pub fn get_info(&self) -> AgentInfoArc {
172-
self.info.clone()
173-
}
174175
}
175176

176177
#[cfg(test)]
@@ -329,7 +330,7 @@ mod tests {
329330
})
330331
.await;
331332
let endpoint = Endpoint::from_url(server.url("/info").parse().unwrap());
332-
let fetcher = AgentInfoFetcher::new(endpoint.clone(), Duration::from_millis(100));
333+
let mut fetcher = AgentInfoFetcher::new(endpoint.clone(), Duration::from_millis(100));
333334
let info = fetcher.get_info();
334335
assert!(info.load().is_none());
335336
tokio::spawn(async move {

data-pipeline/src/pausable_worker.rs

Lines changed: 49 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,42 +2,29 @@
22
// SPDX-License-Identifier: Apache-2.0
33

44
//! Defines a pausable worker to be able to stop background processes before forks
5-
use anyhow::{anyhow, Result};
6-
use ddtelemetry::worker::TelemetryWorker;
7-
use tokio::{runtime::Runtime, select, task::JoinHandle};
8-
use tokio_util::sync::CancellationToken;
9-
10-
use crate::{agent_info::AgentInfoFetcher, stats_exporter::StatsExporter};
11-
12-
/// Trait representing a worker which can be wrapped by `PausableWorker`
13-
pub trait Worker {
14-
/// Main worker loop
15-
fn run(&mut self) -> impl std::future::Future<Output = ()> + Send;
16-
}
17-
18-
impl Worker for StatsExporter {
19-
async fn run(&mut self) {
20-
Self::run(self).await
21-
}
22-
}
235
24-
impl Worker for AgentInfoFetcher {
25-
async fn run(&mut self) {
26-
Self::run(self).await
27-
}
28-
}
29-
30-
impl Worker for TelemetryWorker {
31-
async fn run(&mut self) {
32-
Self::run(self).await
33-
}
34-
}
6+
use ddcommon::worker::Worker;
7+
use std::fmt::Display;
8+
use tokio::{
9+
runtime::Runtime,
10+
select,
11+
task::{JoinError, JoinHandle},
12+
};
13+
use tokio_util::sync::CancellationToken;
3514

36-
/// A pausable worker which can be paused and restarded on forks.
15+
/// A pausable worker which can be paused and restarted on forks.
16+
///
17+
/// Used to allow a [`ddcommon::worker::Worker`] to be paused while saving its state when dropping
18+
/// a tokio runtime to be able to restart with the same state on a new runtime. This is used to
19+
/// stop all threads before a fork to avoid deadlocks in the child process.
3720
///
38-
/// # Requirements
39-
/// When paused the worker will exit on the next awaited call. To be able to safely restart the
40-
/// worker must be in a valid state on every call to `.await`.
21+
/// # Time-to-pause
22+
/// This loop should yield regularly to reduce time-to-pause. See [`tokio::task::yield_now`].
23+
///
24+
/// # Cancellation safety
25+
/// The main loop can be interrupted at any yield point (`.await`ed call). The state of the worker
26+
/// at this point will be saved and used to restart the worker. To be able to safely restart, the
27+
/// worker must be in a valid state on every call to `.await`. See [`tokio::select#cancellation-safety`] for more details.
4128
#[derive(Debug)]
4229
pub enum PausableWorker<T: Worker + Send + Sync + 'static> {
4330
Running {
@@ -50,6 +37,25 @@ pub enum PausableWorker<T: Worker + Send + Sync + 'static> {
5037
InvalidState,
5138
}
5239

40+
#[derive(Debug)]
41+
pub enum PausableWorkerError {
42+
InvalidState,
43+
TaskAborted,
44+
}
45+
46+
impl Display for PausableWorkerError {
47+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48+
match self {
49+
PausableWorkerError::InvalidState => {
50+
write!(f, "Worker is in an invalid state and must be recreated.")
51+
}
52+
PausableWorkerError::TaskAborted => {
53+
write!(f, "Worker task has been aborted and state has been lost.")
54+
}
55+
}
56+
}
57+
}
58+
5359
impl<T: Worker + Send + Sync + 'static> PausableWorker<T> {
5460
/// Create a new pausable worker from the given worker.
5561
pub fn new(worker: T) -> Self {
@@ -58,13 +64,15 @@ impl<T: Worker + Send + Sync + 'static> PausableWorker<T> {
5864

5965
/// Start the worker on the given runtime.
6066
///
67+
/// The worker's main loop will be run on the runtime.
68+
///
6169
/// # Errors
6270
/// Fails if the worker is in an invalid state.
63-
pub fn start(&mut self, rt: &Runtime) -> Result<()> {
71+
pub fn start(&mut self, rt: &Runtime) -> Result<(), PausableWorkerError> {
6472
if let Self::Running { .. } = self {
6573
Ok(())
6674
} else if let Self::Paused { mut worker } = std::mem::replace(self, Self::InvalidState) {
67-
// Worker is temporarly in an invalid state, but since this block is failsafe it will
75+
// Worker is temporarily in an invalid state, but since this block is failsafe it will
6876
// be replaced by a valid state.
6977
let stop_token = CancellationToken::new();
7078
let cloned_token = stop_token.clone();
@@ -78,34 +86,34 @@ impl<T: Worker + Send + Sync + 'static> PausableWorker<T> {
7886
*self = PausableWorker::Running { handle, stop_token };
7987
Ok(())
8088
} else {
81-
Err(anyhow!("Failed to start service"))
89+
Err(PausableWorkerError::InvalidState)
8290
}
8391
}
8492

8593
/// Pause the worker, saving its state so that it can be restarted.
8694
///
8795
/// # Errors
8896
/// Fails if the worker handle has been aborted preventing the worker from being retrieved.
89-
pub async fn pause(&mut self) -> Result<()> {
97+
pub async fn pause(&mut self) -> Result<(), PausableWorkerError> {
9098
match self {
9199
PausableWorker::Running { handle, stop_token } => {
92100
stop_token.cancel();
93101
if let Ok(worker) = handle.await {
94102
*self = PausableWorker::Paused { worker };
95103
Ok(())
96104
} else {
97-
// Worker isn't retrieved and can't be restarted.
105+
// The task has been aborted and the worker can't be retrieved.
98106
*self = PausableWorker::InvalidState;
99-
Err(anyhow!("Failed to stop worker. Worker must be recreated"))
107+
Err(PausableWorkerError::TaskAborted)
100108
}
101109
}
102110
PausableWorker::Paused { .. } => Ok(()),
103-
PausableWorker::InvalidState => Err(anyhow!("Worker is in invalid state")),
111+
PausableWorker::InvalidState => Err(PausableWorkerError::InvalidState),
104112
}
105113
}
106114

107115
/// Wait for the run method of the worker to exit.
108-
pub async fn join(self) -> Result<()> {
116+
pub async fn join(self) -> Result<(), JoinError> {
109117
if let PausableWorker::Running { handle, .. } = self {
110118
handle.await?;
111119
}

data-pipeline/src/stats_exporter.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use std::{
1313

1414
use datadog_trace_protobuf::pb;
1515
use datadog_trace_utils::send_with_retry::{send_with_retry, RetryStrategy};
16-
use ddcommon::Endpoint;
16+
use ddcommon::{worker::Worker, Endpoint};
1717
use hyper;
1818
use log::error;
1919
use tokio::select;
@@ -128,13 +128,15 @@ impl StatsExporter {
128128
.flush(time::SystemTime::now(), force_flush),
129129
)
130130
}
131+
}
131132

133+
impl Worker for StatsExporter {
132134
/// Run loop of the stats exporter
133135
///
134136
/// Once started, the stats exporter will flush and send stats on every `self.flush_interval`.
135137
/// If the `self.cancellation_token` is cancelled, the exporter will force flush all stats and
136138
/// return.
137-
pub async fn run(&self) {
139+
async fn run(&mut self) {
138140
loop {
139141
select! {
140142
_ = self.cancellation_token.cancelled() => {
@@ -316,7 +318,7 @@ mod tests {
316318
})
317319
.await;
318320

319-
let stats_exporter = StatsExporter::new(
321+
let mut stats_exporter = StatsExporter::new(
320322
BUCKETS_DURATION,
321323
Arc::new(Mutex::new(get_test_concentrator())),
322324
get_test_metadata(),
@@ -356,7 +358,7 @@ mod tests {
356358
let buckets_duration = Duration::from_secs(10);
357359
let cancellation_token = CancellationToken::new();
358360

359-
let stats_exporter = StatsExporter::new(
361+
let mut stats_exporter = StatsExporter::new(
360362
buckets_duration,
361363
Arc::new(Mutex::new(get_test_concentrator())),
362364
get_test_metadata(),

data-pipeline/src/telemetry/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ impl TelemetryClient {
266266

267267
#[cfg(test)]
268268
mod tests {
269-
use ddcommon::hyper_migration;
269+
use ddcommon::{hyper_migration, worker::Worker};
270270
use httpmock::Method::POST;
271271
use httpmock::MockServer;
272272
use hyper::{Response, StatusCode};

data-pipeline/src/trace_exporter/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,9 @@ impl TraceExporter {
395395
);
396396
let mut stats_worker = PausableWorker::new(stats_exporter);
397397
let runtime = self.runtime()?;
398-
stats_worker.start(&runtime)?;
398+
stats_worker.start(&runtime).map_err(|e| {
399+
TraceExporterError::Internal(InternalErrorKind::InvalidWorkerState(e.to_string()))
400+
})?;
399401

400402
self.workers.lock_or_panic().stats = Some(stats_worker);
401403

data-pipeline/tests/test_fetch_info.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ mod tracing_integration_tests {
66
use arc_swap::access::Access;
77
use data_pipeline::agent_info::{fetch_info, AgentInfoFetcher};
88
use datadog_trace_utils::test_utils::datadog_test_agent::DatadogTestAgent;
9-
use ddcommon::Endpoint;
9+
use ddcommon::{worker::Worker, Endpoint};
1010
use std::time::Duration;
1111

1212
#[cfg_attr(miri, ignore)]
@@ -30,7 +30,7 @@ mod tracing_integration_tests {
3030
async fn test_agent_info_fetcher_with_test_agent() {
3131
let test_agent = DatadogTestAgent::new(None, None, &[]).await;
3232
let endpoint = Endpoint::from_url(test_agent.get_uri_for_endpoint("info", None).await);
33-
let fetcher = AgentInfoFetcher::new(endpoint, Duration::from_secs(1));
33+
let mut fetcher = AgentInfoFetcher::new(endpoint, Duration::from_secs(1));
3434
let info_arc = fetcher.get_info();
3535
tokio::spawn(async move { fetcher.run().await });
3636
let info_received = async {

ddcommon/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pub mod rate_limiter;
2626
pub mod tag;
2727
pub mod tracer_metadata;
2828
pub mod unix_utils;
29+
pub mod worker;
2930

3031
/// Extension trait for `Mutex` to provide a method that acquires a lock, panicking if the lock is
3132
/// poisoned.

ddcommon/src/worker.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
/// Trait representing a generic worker.
5+
///
6+
/// The worker runs an async looping function running periodic tasks.
7+
///
8+
/// This trait can be used to provide a wrapper around a worker.
9+
pub trait Worker {
10+
/// Main worker loop
11+
fn run(&mut self) -> impl std::future::Future<Output = ()> + Send;
12+
}

0 commit comments

Comments
 (0)