Skip to content

Commit 470709e

Browse files
authored
feat: add oauth to shared (#22)
1 parent 70a863b commit 470709e

File tree

5 files changed

+226
-11
lines changed

5 files changed

+226
-11
lines changed

Cargo.toml

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,24 +35,36 @@ metrics-exporter-prometheus = "0.17.0"
3535
# Slot Calc
3636
chrono = "0.4.40"
3737

38+
# OAuth
39+
oauth2 = { version = "5.0.0", optional = true }
40+
tokio = { version = "1.36.0", optional = true }
41+
3842
# Other
3943
thiserror = "2.0.11"
4044
alloy = { version = "0.12.6", optional = true, default-features = false, features = ["std", "signer-aws", "signer-local", "consensus", "network"] }
4145
serde = { version = "1", features = ["derive"] }
4246
async-trait = { version = "0.1.80", optional = true }
4347

48+
4449
# AWS
4550
aws-config = { version = "1.1.7", optional = true }
4651
aws-sdk-kms = { version = "1.15.0", optional = true }
52+
reqwest = { version = "0.12.15", optional = true }
4753

4854
[dev-dependencies]
4955
ajj = "0.3.1"
5056
axum = "0.8.1"
57+
eyre = "0.6.12"
5158
serial_test = "3.2.0"
5259
signal-hook = "0.3.17"
5360
tokio = { version = "1.43.0", features = ["macros"] }
5461

5562
[features]
5663
default = ["alloy"]
5764
alloy = ["dep:alloy", "dep:async-trait", "dep:aws-config", "dep:aws-sdk-kms"]
58-
perms = []
65+
perms = ["dep:oauth2", "dep:tokio", "dep:reqwest"]
66+
67+
[[example]]
68+
name = "oauth"
69+
path = "examples/oauth.rs"
70+
required-features = ["perms"]

examples/oauth.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
use init4_bin_base::{perms::OAuthConfig, utils::from_env::FromEnv};
2+
3+
#[tokio::main]
4+
async fn main() -> eyre::Result<()> {
5+
let cfg = OAuthConfig::from_env()?;
6+
let authenticator = cfg.authenticator();
7+
let token = authenticator.token();
8+
9+
let _jh = authenticator.spawn();
10+
11+
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
12+
dbg!(token.read());
13+
14+
Ok(())
15+
}

src/lib.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,28 @@ pub mod perms;
1818

1919
/// Signet utilities.
2020
pub mod utils {
21-
/// Prometheus metrics utilities.
22-
pub mod metrics;
23-
24-
/// OpenTelemetry utilities.
25-
pub mod otlp;
21+
/// Slot calculator for determining the current slot and timepoint within a
22+
/// slot.
23+
pub mod calc;
2624

2725
/// [`FromEnv`], [`FromEnvVar`] traits and related utilities.
2826
///
2927
/// [`FromEnv`]: from_env::FromEnv
3028
/// [`FromEnvVar`]: from_env::FromEnvVar
3129
pub mod from_env;
3230

33-
/// Tracing utilities.
34-
pub mod tracing;
31+
/// Prometheus metrics utilities.
32+
pub mod metrics;
3533

36-
/// Slot calculator for determining the current slot and timepoint within a
37-
/// slot.
38-
pub mod calc;
34+
/// OpenTelemetry utilities.
35+
pub mod otlp;
3936

4037
#[cfg(feature = "alloy")]
4138
/// Signer using a local private key or AWS KMS key.
4239
pub mod signer;
40+
41+
/// Tracing utilities.
42+
pub mod tracing;
4343
}
4444

4545
/// Re-exports of common dependencies.

src/perms/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,6 @@ pub use builders::{Builder, BuilderPermissionError, Builders, BuildersEnvError};
33

44
pub(crate) mod config;
55
pub use config::{SlotAuthzConfig, SlotAuthzConfigEnvError};
6+
7+
pub(crate) mod oauth;
8+
pub use oauth::{Authenticator, OAuthConfig, SharedToken};

src/perms/oauth.rs

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
//! Service responsible for authenticating with the cache with Oauth tokens.
2+
//! This authenticator periodically fetches a new token every set amount of seconds.
3+
use crate::{
4+
deps::tracing::{error, info},
5+
utils::from_env::FromEnv,
6+
};
7+
use oauth2::{
8+
basic::{BasicClient, BasicTokenType},
9+
AuthUrl, ClientId, ClientSecret, EmptyExtraTokenFields, EndpointNotSet, EndpointSet,
10+
HttpClientError, RequestTokenError, StandardErrorResponse, StandardTokenResponse, TokenUrl,
11+
};
12+
use std::sync::{Arc, Mutex};
13+
use tokio::task::JoinHandle;
14+
15+
type Token = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
16+
17+
type MyOAuthClient =
18+
BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>;
19+
20+
/// Configuration for the OAuth2 client.
21+
#[derive(Debug, Clone, FromEnv)]
22+
#[from_env(crate)]
23+
pub struct OAuthConfig {
24+
/// OAuth client ID for the builder.
25+
#[from_env(var = "OAUTH_CLIENT_ID", desc = "OAuth client ID for the builder")]
26+
pub oauth_client_id: String,
27+
/// OAuth client secret for the builder.
28+
#[from_env(
29+
var = "OAUTH_CLIENT_SECRET",
30+
desc = "OAuth client secret for the builder"
31+
)]
32+
pub oauth_client_secret: String,
33+
/// OAuth authenticate URL for the builder for performing OAuth logins.
34+
#[from_env(
35+
var = "OAUTH_AUTHENTICATE_URL",
36+
desc = "OAuth authenticate URL for the builder for performing OAuth logins"
37+
)]
38+
pub oauth_authenticate_url: url::Url,
39+
/// OAuth token URL for the builder to get an OAuth2 access token
40+
#[from_env(
41+
var = "OAUTH_TOKEN_URL",
42+
desc = "OAuth token URL for the builder to get an OAuth2 access token"
43+
)]
44+
pub oauth_token_url: url::Url,
45+
/// The oauth token refresh interval in seconds.
46+
#[from_env(
47+
var = "AUTH_TOKEN_REFRESH_INTERVAL",
48+
desc = "The oauth token refresh interval in seconds"
49+
)]
50+
pub oauth_token_refresh_interval: u64,
51+
}
52+
53+
impl OAuthConfig {
54+
/// Create a new [`Authenticator`] from the provided config.
55+
pub fn authenticator(&self) -> Authenticator {
56+
Authenticator::new(self)
57+
}
58+
}
59+
60+
/// A shared token that can be read and written to by multiple threads.
61+
#[derive(Debug, Clone, Default)]
62+
pub struct SharedToken(Arc<Mutex<Option<Token>>>);
63+
64+
impl SharedToken {
65+
/// Read the token from the shared token.
66+
pub fn read(&self) -> Option<Token> {
67+
self.0.lock().unwrap().clone()
68+
}
69+
70+
/// Write a new token to the shared token.
71+
pub fn write(&self, token: Token) {
72+
let mut lock = self.0.lock().unwrap();
73+
*lock = Some(token);
74+
}
75+
76+
/// Check if the token is authenticated.
77+
pub fn is_authenticated(&self) -> bool {
78+
self.0.lock().unwrap().is_some()
79+
}
80+
}
81+
82+
/// A self-refreshing, periodically fetching authenticator for the block
83+
/// builder. This task periodically fetches a new token, and stores it in a
84+
/// [`SharedToken`].
85+
#[derive(Debug)]
86+
pub struct Authenticator {
87+
/// Configuration
88+
pub config: OAuthConfig,
89+
client: MyOAuthClient,
90+
token: SharedToken,
91+
reqwest: reqwest::Client,
92+
}
93+
94+
impl Authenticator {
95+
/// Creates a new Authenticator from the provided builder config.
96+
pub fn new(config: &OAuthConfig) -> Self {
97+
let client = BasicClient::new(ClientId::new(config.oauth_client_id.clone()))
98+
.set_client_secret(ClientSecret::new(config.oauth_client_secret.clone()))
99+
.set_auth_uri(AuthUrl::from_url(config.oauth_authenticate_url.clone()))
100+
.set_token_uri(TokenUrl::from_url(config.oauth_token_url.clone()));
101+
102+
let rq_client = reqwest::Client::builder()
103+
.redirect(reqwest::redirect::Policy::none())
104+
.build()
105+
.unwrap();
106+
107+
Self {
108+
config: config.clone(),
109+
client,
110+
token: Default::default(),
111+
reqwest: rq_client,
112+
}
113+
}
114+
115+
/// Requests a new authentication token and, if successful, sets it to as the token
116+
pub async fn authenticate(
117+
&self,
118+
) -> Result<
119+
(),
120+
RequestTokenError<
121+
HttpClientError<reqwest::Error>,
122+
StandardErrorResponse<oauth2::basic::BasicErrorResponseType>,
123+
>,
124+
> {
125+
let token = self.fetch_oauth_token().await?;
126+
self.set_token(token);
127+
Ok(())
128+
}
129+
130+
/// Returns true if there is Some token set
131+
pub fn is_authenticated(&self) -> bool {
132+
self.token.is_authenticated()
133+
}
134+
135+
/// Sets the Authenticator's token to the provided value
136+
fn set_token(&self, token: StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>) {
137+
self.token.write(token);
138+
}
139+
140+
/// Returns the currently set token
141+
pub fn token(&self) -> SharedToken {
142+
self.token.clone()
143+
}
144+
145+
/// Fetches an oauth token
146+
pub async fn fetch_oauth_token(
147+
&self,
148+
) -> Result<
149+
Token,
150+
RequestTokenError<
151+
HttpClientError<reqwest::Error>,
152+
StandardErrorResponse<oauth2::basic::BasicErrorResponseType>,
153+
>,
154+
> {
155+
let token_result = self
156+
.client
157+
.exchange_client_credentials()
158+
.request_async(&self.reqwest)
159+
.await?;
160+
161+
Ok(token_result)
162+
}
163+
164+
/// Spawns a task that periodically fetches a new token every 300 seconds.
165+
pub fn spawn(self) -> JoinHandle<()> {
166+
let interval = self.config.oauth_token_refresh_interval;
167+
168+
let handle: JoinHandle<()> = tokio::spawn(async move {
169+
loop {
170+
info!("Refreshing oauth token");
171+
match self.authenticate().await {
172+
Ok(_) => {
173+
info!("Successfully refreshed oauth token");
174+
}
175+
Err(e) => {
176+
error!(%e, "Failed to refresh oauth token");
177+
}
178+
};
179+
let _sleep = tokio::time::sleep(tokio::time::Duration::from_secs(interval)).await;
180+
}
181+
});
182+
183+
handle
184+
}
185+
}

0 commit comments

Comments
 (0)