diff --git a/Cargo.toml b/Cargo.toml index d326209..c0cfd81 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,8 +49,10 @@ sha1 = { version = "0.10", optional = true } sha2 = { version = "0.10", optional = true } sprintf = { version = ">=0.3, <0.5", optional = true } parse-size = { version = "1", features = ["std"], optional = true } -serde_yaml = { version = "0.9.1", optional = true } +serde_yml = { version = "0.0.12", optional = true } form_urlencoded = { version = "1", optional = true } +http = { version = "1", optional = true } +reqwest = { version = "0.12", features = ["json"], optional = true } urlencoding = { version = "2", optional = true } chrono = { version = "0.4.31", optional = true, default-features = false, features = [ "std", @@ -69,6 +71,7 @@ wasmtime = { version = ">=22, <32", default-features = false, features = [ "cranelift", ] } insta = { version = "1", features = ["yaml"] } +httpmock = "0.7" [build-dependencies] # We would like at least this version of rayon, because older versions depend on old rand, @@ -103,6 +106,7 @@ fast = ["wasmtime/cranelift", "wasmtime/parallel-compilation"] rng = ["dep:rand"] time = ["dep:chrono"] +http = ["dep:http", "dep:reqwest"] base64url-builtins = ["dep:base64", "dep:hex"] crypto-digest-builtins = ["dep:digest", "dep:hex"] @@ -116,9 +120,10 @@ sprintf-builtins = ["dep:sprintf"] json-builtins = ["dep:json-patch"] units-builtins = ["dep:parse-size"] rand-builtins = ["rng"] -yaml-builtins = ["dep:serde_yaml"] +yaml-builtins = ["dep:serde_yml"] urlquery-builtins = ["dep:form_urlencoded", "dep:urlencoding"] time-builtins = ["time", "dep:chrono-tz", "dep:duration-str", "dep:chronoutil"] +http-builtins = ["http", "dep:serde_yml"] all-crypto-builtins = [ "crypto-digest-builtins", @@ -140,6 +145,7 @@ all-builtins = [ "yaml-builtins", "urlquery-builtins", "time-builtins", + "http-builtins", ] [[test]] diff --git a/features.txt b/features.txt index a609839..db75636 100644 --- a/features.txt +++ b/features.txt @@ -2,6 +2,7 @@ 
loader cli rng +http base64url-builtins crypto-digest-builtins crypto-md5-builtins crypto-digest-builtins crypto-sha1-builtins @@ -17,5 +18,6 @@ units-builtins rand-builtins yaml-builtins time-builtins +http-builtins all-crypto-builtins all-builtins diff --git a/src/builtins/impls/http.rs b/src/builtins/impls/http.rs index 097f1d3..7f83319 100644 --- a/src/builtins/impls/http.rs +++ b/src/builtins/impls/http.rs @@ -14,10 +14,218 @@ //! Builtins used to make HTTP request -use anyhow::{bail, Result}; +use std::{collections::HashMap, future::Future, pin::Pin, time::Duration}; + +use anyhow::{Context, Result}; +use http; +use serde_json::{self, Map}; +use serde_yml; +use tokio::time::sleep; + +use crate::{builtins::traits::Builtin, EvaluationContext}; + +/// This builtin is needed because the wrapper in traits.rs doesn't work when +/// dealing with async+context. +pub struct HttpSendBuiltin {} + +impl Builtin for HttpSendBuiltin +where + C: EvaluationContext, +{ + fn call<'a>( + &'a self, + context: &'a mut C, + args: &'a [&'a [u8]], + ) -> Pin, anyhow::Error>> + Send + 'a>> { + Box::pin(async move { + let [opa_req]: [&'a [u8]; 1] = args.try_into().ok().context("invalid arguments")?; + + let opa_req: serde_json::Value = serde_json::from_slice(opa_req) + .context(concat!("failed to convert opa_req argument"))?; + + let res = send(context, opa_req).await?; + let res = serde_json::to_vec(&res).context("could not serialize result")?; + Ok(res) + }) + } +} /// Returns a HTTP response to the given HTTP request. -#[tracing::instrument(name = "http.send", err)] -pub fn send(request: serde_json::Value) -> Result { - bail!("not implemented"); +/// +/// Wraps [`internal_send`] to add error handling regarding the `raise_error` +/// field in the OPA request. 
+#[tracing::instrument(name = "http.send", skip(ctx), err)]
+pub async fn send<C: EvaluationContext>(
+    ctx: &mut C,
+    opa_req: serde_json::Value,
+) -> Result<serde_json::Value> {
+    // Errors are raised unless the request explicitly carries
+    // `"raise_error": false`. The flag is read before the request is handed
+    // over, because `internal_send` takes ownership of it.
+    let raise_error = opa_req
+        .get("raise_error")
+        .and_then(serde_json::Value::as_bool)
+        .unwrap_or(true);
+
+    match internal_send(ctx, opa_req).await {
+        // With `raise_error` disabled, the failure is reported as data: a
+        // zero status code plus the error message.
+        Err(e) if !raise_error => Ok(serde_json::json!({
+            "status_code": 0,
+            "error": {
+                "message": e.to_string(),
+            },
+        })),
+        other => other,
+    }
+}
+
+/// Returns the delay to wait after the failed attempt number `attempt`
+/// (zero-based): 500ms, doubled for every earlier attempt.
+///
+/// The arithmetic saturates at `u64::MAX` milliseconds instead of truncating
+/// the exponent or overflowing.
+fn retry_delay(attempt: u64) -> Duration {
+    let millis = u32::try_from(attempt)
+        .ok()
+        .and_then(|exponent| 2_u64.checked_pow(exponent))
+        .and_then(|factor| factor.checked_mul(500))
+        .unwrap_or(u64::MAX);
+    Duration::from_millis(millis)
+}
+
+/// Renders a JSON value as the text to put on the wire.
+///
+/// JSON strings are used verbatim; going through `to_string()` would
+/// serialize them back to JSON and add surrounding quotes and escapes. Any
+/// other value falls back to its JSON serialization.
+fn plain_string(value: &serde_json::Value) -> String {
+    value
+        .as_str()
+        .map_or_else(|| value.to_string(), str::to_owned)
+}
+
+/// Sends a HTTP request and returns the response.
+///
+/// Besides the fields consumed by [`convert_opa_req_to_http_req`], the
+/// following optional fields of the OPA request are honoured: `timeout`,
+/// `enable_redirect`, `max_retry_attempts`, `force_json_decode` and
+/// `force_yaml_decode`.
+///
+/// # Errors
+///
+/// Returns an error if the request is not a JSON object, if it cannot be
+/// converted to an HTTP request, or if sending still fails after the initial
+/// attempt and `max_retry_attempts` retries.
+async fn internal_send<C: EvaluationContext>(
+    ctx: &mut C,
+    opa_req: serde_json::Value,
+) -> Result<serde_json::Value> {
+    let opa_req = opa_req
+        .as_object()
+        .context("request must be a JSON object")?;
+
+    let http_req = convert_opa_req_to_http_req(opa_req)?;
+
+    // A number is a count of nanoseconds, a string is a duration string.
+    // Anything else (including a string that does not parse) is passed on as
+    // `None`, which leaves the choice of timeout to the context.
+    let timeout = opa_req.get("timeout").and_then(|value| {
+        value.as_u64().map(Duration::from_nanos).or_else(|| {
+            value
+                .as_str()
+                .and_then(|text| duration_str::parse(text).ok())
+        })
+    });
+
+    let enable_redirect = opa_req
+        .get("enable_redirect")
+        .and_then(serde_json::Value::as_bool);
+
+    let max_retry_attempts = opa_req
+        .get("max_retry_attempts")
+        .and_then(serde_json::Value::as_u64)
+        .unwrap_or(0);
+
+    // One initial attempt plus up to `max_retry_attempts` retries, with an
+    // exponential back-off between attempts. There is no wait after the last
+    // attempt: its error is returned right away.
+    let mut attempt: u64 = 0;
+    let http_resp = loop {
+        let outcome = ctx
+            .send_http(http_req.clone(), timeout, enable_redirect)
+            .await;
+        match outcome {
+            Ok(resp) => break resp,
+            Err(e) if attempt >= max_retry_attempts => return Err(e),
+            Err(_) => {}
+        }
+        sleep(retry_delay(attempt)).await;
+        attempt += 1;
+    };
+
+    let force_json_decode = opa_req
+        .get("force_json_decode")
+        .and_then(serde_json::Value::as_bool)
+        .unwrap_or(false);
+    let force_yaml_decode = opa_req
+        .get("force_yaml_decode")
+        .and_then(serde_json::Value::as_bool)
+        .unwrap_or(false);
+
+    Ok(convert_http_resp_to_opa_resp(
+        http_resp,
+        force_json_decode,
+        force_yaml_decode,
+    ))
+}
+
+/// Converts an OPA request to an HTTP request.
+///
+/// Reads `url` and `method` (both required strings; the method is
+/// upper-cased), the optional `headers` object, and the body. The body is the
+/// JSON serialization of `body` when that field is present; otherwise it is
+/// `raw_body`, and empty when neither field is set.
+///
+/// # Errors
+///
+/// Returns an error if `url` or `method` is missing or not a string, or if
+/// the HTTP request cannot be built from the given parts.
+fn convert_opa_req_to_http_req(
+    opa_req: &Map<String, serde_json::Value>,
+) -> Result<http::Request<String>> {
+    let url = opa_req
+        .get("url")
+        .context("missing url")?
+        .as_str()
+        .context("url must be a string")?;
+    let method = opa_req
+        .get("method")
+        .context("missing method")?
+        .as_str()
+        .map(str::to_uppercase)
+        .context("method must be a string")?;
+
+    let mut req_builder = http::Request::builder().method(method.as_str()).uri(url);
+    if let Some(headers) = opa_req
+        .get("headers")
+        .and_then(serde_json::Value::as_object)
+    {
+        for (key, value) in headers {
+            // String header values must be sent as-is, without JSON quoting.
+            req_builder = req_builder.header(key, plain_string(value));
+        }
+    }
+
+    // `body` wins over `raw_body` when both are given.
+    let body = match (opa_req.get("body"), opa_req.get("raw_body")) {
+        (Some(json_body), _) => json_body.to_string(),
+        (None, Some(raw_body)) => plain_string(raw_body),
+        (None, None) => String::new(),
+    };
+
+    let http_req = req_builder.body(body)?;
+    Ok(http_req)
+}
+
+/// Converts an HTTP response to an OPA response.
+fn convert_http_resp_to_opa_resp( + response: http::Response, + force_json_decode: bool, + force_yaml_decode: bool, +) -> serde_json::Value { + let response_headers = response + .headers() + .iter() + .map(|(k, v)| (k.as_str(), v.to_str().unwrap_or(""))) + .collect::>(); + + let mut opa_resp = serde_json::json!({ + "status_code": response.status().as_u16(), + "headers": response_headers, + }); + + let raw_resp_body: &String = response.body(); + opa_resp["raw_body"] = serde_json::Value::String(raw_resp_body.clone()); + + let content_type = response + .headers() + .get("content-type") + .map(|v| v.to_str().unwrap_or_default()); + + if force_json_decode || content_type == Some("application/json") { + if let Ok(parsed_body) = serde_json::from_str::(raw_resp_body) { + opa_resp["body"] = parsed_body; + } + } else if force_yaml_decode + || content_type == Some("application/yaml") + || content_type == Some("application/x-yaml") + { + if let Ok(parsed_body) = serde_yml::from_str::(raw_resp_body) { + opa_resp["body"] = parsed_body; + } + } + + opa_resp } diff --git a/src/builtins/impls/mod.rs b/src/builtins/impls/mod.rs index 8013b1a..4f228b0 100644 --- a/src/builtins/impls/mod.rs +++ b/src/builtins/impls/mod.rs @@ -27,6 +27,7 @@ pub mod graph; pub mod graphql; #[cfg(feature = "hex-builtins")] pub mod hex; +#[cfg(feature = "http-builtins")] pub mod http; pub mod io; #[cfg(feature = "json-builtins")] diff --git a/src/builtins/impls/rand.rs b/src/builtins/impls/rand.rs index dbedd2a..d949ace 100644 --- a/src/builtins/impls/rand.rs +++ b/src/builtins/impls/rand.rs @@ -35,7 +35,7 @@ pub fn intn(ctx: &mut C, str: String, n: i64) -> Result bool { - let parse: Result = serde_yaml::from_str(&x); + let parse: Result = serde_yml::from_str(&x); parse.is_ok() } /// Serializes the input term to YAML. 
#[tracing::instrument(name = "yaml.marshal", err)] -pub fn marshal(x: serde_yaml::Value) -> Result { - let parse: String = serde_yaml::to_string(&x)?; +pub fn marshal(x: serde_yml::Value) -> Result { + let parse: String = serde_yml::to_string(&x)?; Ok(parse) } /// Deserializes the input string. #[tracing::instrument(name = "yaml.unmarshal", err)] pub fn unmarshal(x: String) -> Result { - let parse: serde_json::Value = serde_yaml::from_str(&x)?; + let parse: serde_json::Value = serde_yml::from_str(&x)?; Ok(parse) } diff --git a/src/builtins/mod.rs b/src/builtins/mod.rs index 392f7f6..455b650 100644 --- a/src/builtins/mod.rs +++ b/src/builtins/mod.rs @@ -80,7 +80,9 @@ pub fn resolve(name: &str) -> Result>> #[cfg(feature = "hex-builtins")] "hex.encode" => Ok(self::impls::hex::encode.wrap()), - "http.send" => Ok(self::impls::http::send.wrap()), + #[cfg(feature = "http-builtins")] + "http.send" => Ok(Box::new(self::impls::http::HttpSendBuiltin {})), + "indexof_n" => Ok(self::impls::indexof_n.wrap()), "io.jwt.decode" => Ok(self::impls::io::jwt::decode.wrap()), "io.jwt.decode_verify" => Ok(self::impls::io::jwt::decode_verify.wrap()), diff --git a/src/context.rs b/src/context.rs index 110cb0f..89119be 100644 --- a/src/context.rs +++ b/src/context.rs @@ -16,7 +16,7 @@ #![allow(clippy::module_name_repetitions)] -use std::collections::HashMap; +use std::{collections::HashMap, time::Duration}; use anyhow::Result; #[cfg(feature = "time")] @@ -37,6 +37,14 @@ pub trait EvaluationContext: Send + 'static { #[cfg(feature = "time")] fn now(&self) -> chrono::DateTime; + /// Send an HTTP request + fn send_http( + &self, + req: http::Request, + timeout: Option, + enable_redirect: Option, + ) -> impl std::future::Future>> + Send + Sync; + /// Notify the context on evaluation start, so it can clean itself up fn evaluation_start(&mut self); @@ -63,8 +71,32 @@ pub struct DefaultContext { /// The time at which the evaluation started #[cfg(feature = "time")] evaluation_time: chrono::DateTime, 
+ + /// The client used to send HTTP requests + #[cfg(feature = "http")] + http_client: reqwest::Client, } +/// Builds a [`reqwest::Client`] with the given timeout and redirect policy. +fn build_reqwest_client(timeout: Duration, enable_redirect: bool) -> reqwest::Client { + let mut client_builder = reqwest::Client::builder(); + client_builder = client_builder.timeout(timeout); + client_builder = client_builder.redirect(if enable_redirect { + reqwest::redirect::Policy::default() + } else { + reqwest::redirect::Policy::none() + }); + #[allow(clippy::unwrap_used)] + client_builder.build().unwrap() +} + +/// The default HTTP timeout (5 seconds as specified in the OPA specification) +static DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(5); + +/// The default HTTP redirect policy (disabled as specified in the OPA +/// specification) +static DEFAULT_HTTP_ENABLE_REDIRECT: bool = false; + #[allow(clippy::derivable_impls)] impl Default for DefaultContext { fn default() -> Self { @@ -73,6 +105,25 @@ impl Default for DefaultContext { #[cfg(feature = "time")] evaluation_time: chrono::Utc.timestamp_nanos(0), + + #[cfg(feature = "http")] + http_client: build_reqwest_client(DEFAULT_HTTP_TIMEOUT, DEFAULT_HTTP_ENABLE_REDIRECT), + } + } +} + +impl DefaultContext { + /// Returns a [`reqwest::Client`] with the given timeout and redirect + /// policy. + /// + /// If the timeout or redirect policy is different from the default ones, a + /// new [`reqwest::Client`] is built. Otherwise, the existing client is + /// returned. 
+ fn get_reqwest_client(&self, timeout: Duration, enable_redirect: bool) -> reqwest::Client { + if timeout != DEFAULT_HTTP_TIMEOUT || enable_redirect != DEFAULT_HTTP_ENABLE_REDIRECT { + build_reqwest_client(timeout, enable_redirect) + } else { + self.http_client.clone() } } } @@ -91,6 +142,32 @@ impl EvaluationContext for DefaultContext { self.evaluation_time } + #[cfg(feature = "http")] + async fn send_http( + &self, + req: http::Request, + timeout: Option, + enable_redirect: Option, + ) -> Result> { + let client = self.get_reqwest_client( + timeout.unwrap_or(DEFAULT_HTTP_TIMEOUT), + enable_redirect.unwrap_or(DEFAULT_HTTP_ENABLE_REDIRECT), + ); + + let response: reqwest::Response = client.execute(reqwest::Request::try_from(req)?).await?; + + let mut builder = http::Response::builder().status(response.status()); + for (name, value) in response.headers() { + builder = builder.header(name, value); + } + + let bytes_body = response.bytes().await?; + let string_body = String::from_utf8(bytes_body.to_vec())?; + builder + .body(string_body) + .map_err(|e| anyhow::anyhow!("Failed to build response: {}", e)) + } + fn evaluation_start(&mut self) { // Clear the cache self.cache = HashMap::new(); @@ -122,6 +199,8 @@ impl EvaluationContext for DefaultContext { /// Test utilities pub mod tests { + use std::time::Duration; + use anyhow::Result; #[cfg(feature = "time")] use chrono::TimeZone; @@ -182,6 +261,16 @@ pub mod tests { rand::rngs::StdRng::seed_from_u64(self.seed) } + #[cfg(feature = "http")] + async fn send_http( + &self, + req: http::Request, + timeout: Option, + enable_redirect: Option, + ) -> Result> { + self.inner.send_http(req, timeout, enable_redirect).await + } + fn cache_get(&mut self, key: &K) -> Result> { self.inner.cache_get(key) } diff --git a/tests/fixtures/test-http.rego b/tests/fixtures/test-http.rego new file mode 100644 index 0000000..220b4c1 --- /dev/null +++ b/tests/fixtures/test-http.rego @@ -0,0 +1,15 @@ +package fixtures + +import rego.v1 + +# Test 
that automatic ser/der is working fine for request and response +get_json := http.send({"url": sprintf("%s/json", [input.base_url]), "method": "get"}) +get_yaml := http.send({"url": sprintf("%s/yaml", [input.base_url]), "method": "get"}) +post_json := http.send({"url": sprintf("%s/post", [input.base_url]), "method": "post", "body": {"key": "value"}}) + +# Test a connection error doesn't error out the whole policy when using raise_error=false +get_no_conn := http.send({"url": "https://cahbe8ang5umaiwavai1shuchiehae7u.com", "method": "get", "raise_error": false}) + +# Test automatic redirection +get_redirect := http.send({"url": sprintf("%s/redirect", [input.base_url]), "method": "get"}) +get_redirect_follow := http.send({"url": sprintf("%s/redirect", [input.base_url]), "method": "get", "enable_redirect": true}) diff --git a/tests/smoke_test.rs b/tests/smoke_test.rs index e728501..bda2ae3 100644 --- a/tests/smoke_test.rs +++ b/tests/smoke_test.rs @@ -17,6 +17,7 @@ use std::path::Path; use anyhow::Result as AnyResult; use insta::assert_yaml_snapshot; use opa_wasm::{read_bundle, Runtime, TestContext}; +use serde_json::json; use wasmtime::{Config, Engine, Module, Store}; macro_rules! integration_test { @@ -31,7 +32,7 @@ macro_rules! integration_test { ($name:ident, $suite:expr, input = $input:expr) => { #[tokio::test] async fn $name() { - assert_yaml_snapshot!(test_policy($suite, Some($input)) + assert_yaml_snapshot!(test_policy_with_datafile($suite, Some($input)) .await .expect("error in test suite")); } @@ -82,10 +83,26 @@ fn input(name: &str) -> String { .into() } -async fn test_policy(bundle_name: &str, data: Option<&str>) -> AnyResult { +async fn test_policy_with_datafile( + bundle_name: &str, + datafile_path: Option<&str>, +) -> AnyResult { + let input = match datafile_path { + Some(path) => { + let input_bytes = tokio::fs::read(input(path)).await?; + Some(serde_json::from_slice(&input_bytes[..])?) 
+ } + None => None, + }; + test_policy(bundle_name, input).await +} + +async fn test_policy( + bundle_name: &str, + data: Option, +) -> AnyResult { let input = if let Some(data) = data { - let input_bytes = tokio::fs::read(input(&format!("{}.json", data))).await?; - serde_json::from_slice(&input_bytes[..])? + data } else { serde_json::Value::Object(serde_json::Map::default()) }; @@ -112,9 +129,13 @@ async fn infra_loader_works() { integration_test!( test_loader_false, "test-loader", - input = "test-loader.false" + input = "test-loader.false.json" +); +integration_test!( + test_loader_true, + "test-loader", + input = "test-loader.true.json" ); -integration_test!(test_loader_true, "test-loader", input = "test-loader.true"); integration_test!(test_loader_empty, "test-loader"); integration_test!(test_units, "test-units"); integration_test!(test_rand, "test-rand"); @@ -122,6 +143,93 @@ integration_test!(test_yaml, "test-yaml"); integration_test!(test_urlquery, "test-urlquery"); integration_test!(test_time, "test-time"); +#[tokio::test] +async fn test_http() { + use httpmock; + let server = httpmock::MockServer::start(); + + let content_value: serde_json::Value = json!({ "key": "value" }); + + let get_json_mock = server.mock(|when, then| { + when.method(httpmock::Method::GET).path("/json"); + then.status(200) + .header("Content-Type", "application/json") + .body(content_value.to_string()); + }); + + let get_yaml_mock = server.mock(|when, then| { + when.method(httpmock::Method::GET).path("/yaml"); + then.status(200) + .header("Content-Type", "application/yaml") + .body(serde_yml::to_string(&content_value).unwrap()); + }); + + let post_json_mock = server.mock(|when, then| { + when.method(httpmock::Method::POST) + .path("/post") + .json_body(content_value.clone()); + then.status(200); + }); + + let redirect_json_mock = server.mock(|when, then| { + when.method(httpmock::Method::GET).path("/redirect"); + then.status(302).header("Location", "/target"); + }); + + let 
target_json_mock = server.mock(|when, then| { + when.method(httpmock::Method::GET).path("/target"); + then.status(200); + }); + + let res = test_policy("test-http", Some(json!({"base_url": server.url("")}))) + .await + .expect("error in test suite"); + + let result = res.as_array().unwrap()[0] + .as_object() + .unwrap() + .get("result") + .unwrap() + .as_object() + .unwrap(); + + get_json_mock.assert(); + let get_json_res = result.get("get_json").unwrap().as_object().unwrap(); + assert_eq!( + get_json_res.get("raw_body").unwrap(), + &content_value.to_string() + ); + assert_eq!(get_json_res.get("body").unwrap(), &content_value); + + get_yaml_mock.assert(); + let get_yaml_res = result.get("get_yaml").unwrap().as_object().unwrap(); + assert_eq!( + get_yaml_res.get("raw_body").unwrap(), + &serde_yml::to_string(&content_value).unwrap() + ); + assert_eq!(get_yaml_res.get("body").unwrap(), &content_value); + + post_json_mock.assert(); + let post_json_res = result.get("post_json").unwrap().as_object().unwrap(); + assert_eq!(post_json_res.get("status_code").unwrap(), &200); + + let get_no_conn_res = result.get("get_no_conn").unwrap().as_object().unwrap(); + assert_eq!(get_no_conn_res.get("status_code").unwrap(), &0); + + redirect_json_mock.assert_hits(2); + target_json_mock.assert(); + + let get_redirect_res = result.get("get_redirect").unwrap().as_object().unwrap(); + assert_eq!(get_redirect_res.get("status_code").unwrap(), &302); + + let get_redirect_follow_res = result + .get("get_redirect_follow") + .unwrap() + .as_object() + .unwrap(); + assert_eq!(get_redirect_follow_res.get("status_code").unwrap(), &200); +} + /* #[tokio::test] async fn test_uuid() {