Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

feat: chunks certification #24

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
lint + fmt + now decode_body returns body sha
3cL1p5e7 committed Mar 2, 2022
commit 346fbdd2691731f89e6eaf7cbaea0e399db61136
126 changes: 69 additions & 57 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::config::dns_canister_config::DnsCanisterConfig;
use clap::{crate_authors, crate_version, AppSettings, Parser};
use flate2::read::{DeflateDecoder, GzDecoder};
use hyper::{
body,
body::Bytes,
@@ -24,6 +25,7 @@ use ic_utils::{
use lazy_regex::regex_captures;
use sha2::{Digest, Sha256};
use slog::Drain;
use std::io::prelude::Read;
use std::{
convert::Infallible,
error::Error,
@@ -35,8 +37,6 @@ use std::{
Arc, Mutex,
},
};
use std::io::prelude::{Read};
use flate2::read::{GzDecoder, DeflateDecoder};

mod config;
mod logging;
@@ -177,18 +177,12 @@ fn resolve_canister_id(
fn decode_hash_tree(
name: &str,
value: Option<String>,
logger: &slog::Logger
logger: &slog::Logger,
) -> Result<Vec<u8>, ()> {
match value {
Some(tree) => base64::decode(tree)
.map_err(|e| {
slog::warn!(
logger,
"Unable to decode {} from base64: {}",
name,
e
);
}),
Some(tree) => base64::decode(tree).map_err(|e| {
slog::warn!(logger, "Unable to decode {} from base64: {}", name, e);
}),
_ => Err(()),
}
}
@@ -201,10 +195,7 @@ struct HeadersData {
encoding: Option<String>,
}

fn extract_headers_data(
headers: &Vec<HeaderField>,
logger: &slog::Logger
) -> HeadersData {
fn extract_headers_data(headers: &[HeaderField], logger: &slog::Logger) -> HeadersData {
let mut headers_data = HeadersData {
certificate: None,
tree: None,
@@ -267,11 +258,14 @@ fn extract_headers_data(
(Some(chunk_tree), Ok(bytes)) => {
slog::warn!(logger, "duplicate chunk_tree field: {:?}", bytes);
Some(chunk_tree)
},
}
(Some(chunk_tree), Err(_)) => {
slog::warn!(logger, "duplicate chunk_tree field (failed to decode)");
slog::warn!(
logger,
"duplicate chunk_tree field (failed to decode)"
);
Some(chunk_tree)
},
}
};
}
}
@@ -400,15 +394,17 @@ async fn forward_request(
for HeaderField(name, value) in &http_response.headers {
builder = builder.header(name, value);
}

let headers_data = extract_headers_data(&http_response.headers, &logger);
let body = if logger.is_trace_enabled() {
Some(http_response.body.clone())
} else {
None
};

let is_streaming = http_response.streaming_strategy.is_some();
// No need to stream when get 206 HTTP partial response
let is_streaming =
http_response.streaming_strategy.is_some() && http_response.status_code != 206;
let body_valid = validate(
&headers_data,
&canister_id,
@@ -425,7 +421,8 @@ async fn forward_request(
.unwrap());
}

let response = if let Some(streaming_strategy) = http_response.streaming_strategy {
let response = if is_streaming {
let streaming_strategy = http_response.streaming_strategy.unwrap();
let (mut sender, body) = body::Body::channel();
let agent = agent.as_ref().clone();
sender.send_data(Bytes::from(http_response.body)).await?;
@@ -453,8 +450,13 @@ async fn forward_request(
.call()
.await
{
Ok((StreamingCallbackHttpResponse { body, token, chunk_tree },)) => {
let decoded_chunk_tree = decode_hash_tree("chunk_tree", chunk_tree, &logger);
Ok((StreamingCallbackHttpResponse {
body,
token,
chunk_tree,
},)) => {
let decoded_chunk_tree =
decode_hash_tree("chunk_tree", chunk_tree, &logger);
let chunk_headers_data = HeadersData {
certificate: headers_data.certificate.clone(),
tree: headers_data.tree.clone(),
@@ -472,7 +474,9 @@ async fn forward_request(
logger.clone(),
);

if body_valid.is_err() || sender.send_data(Bytes::from(body)).await.is_err() {
if body_valid.is_err()
|| sender.send_data(Bytes::from(body)).await.is_err()
{
sender.abort();
break;
}
@@ -542,7 +546,7 @@ fn validate(
is_streaming: bool,
logger: slog::Logger,
) -> Result<(), String> {
let decoded_body = decode_body(response_body, headers_data.encoding.clone());
let body_sha = decode_body(response_body, headers_data.encoding.clone());
let body_valid = match (headers_data.certificate.clone(), headers_data.tree.clone()) {
(Some(Ok(certificate)), Some(Ok(tree))) => match validate_body(
Certificates {
@@ -551,17 +555,19 @@ fn validate(
chunk_tree: headers_data.chunk_tree.clone(),
chunk_index: headers_data.chunk_index.clone(),
},
&canister_id,
&agent,
&uri,
&decoded_body,
canister_id,
agent,
uri,
&body_sha,
logger.clone(),
) {
Ok(valid) => if valid {
Ok(())
} else {
Err("Body does not pass verification".to_string())
},
Ok(valid) => {
if valid {
Ok(())
} else {
Err("Body does not pass verification".to_string())
}
}
Err(e) => Err(format!("Certificate validation failed: {}", e)),
},
(Some(_), _) | (_, Some(_)) => Err("Body does not pass verification".to_string()),
@@ -571,35 +577,43 @@ fn validate(

if body_valid.is_err() && !cfg!(feature = "skip_body_verification") {
match (is_streaming, headers_data.chunk_tree.is_some()) {
(true, false) => {}, // backward compatibility. Headers could not contain chunk_tree witness for streaming
(true, false) => {} // backward compatibility. Headers could not contain chunk_tree witness for streaming
_ => {
return Err(body_valid.unwrap_err());
},
return body_valid;
}
}
}

Ok(())
}

fn decode_body(body: &[u8], encoding: Option<String>) -> Vec<u8> {
fn decode_body(body: &[u8], encoding: Option<String>) -> [u8; 32] {
let mut sha256 = Sha256::new();
match encoding {
Some(enc) => match enc.as_str() {
"gzip" => {
let decoded: &mut Vec<u8> = &mut vec![];
let decoded: &mut Vec<u8> = &mut vec![];
let decoder = GzDecoder::new(body);
decoder.take(MAX_BYTES_SIZE_TO_DECOMPRESS).read_to_end(decoded).unwrap();
decoded.to_vec()
},
decoder
.take(MAX_BYTES_SIZE_TO_DECOMPRESS)
.read_to_end(decoded)
.unwrap();
sha256.update(decoded);
}
"deflate" => {
let decoded: &mut Vec<u8> = &mut vec![];
let decoder = DeflateDecoder::new(body);
decoder.take(MAX_BYTES_SIZE_TO_DECOMPRESS).read_to_end(decoded).unwrap();
decoded.to_vec()
},
_ => body.to_vec(),
decoder
.take(MAX_BYTES_SIZE_TO_DECOMPRESS)
.read_to_end(decoded)
.unwrap();
sha256.update(decoded);
}
_ => sha256.update(body),
},
_ => body.to_vec(),
}
_ => sha256.update(body),
};
sha256.finalize().into()
}

struct Certificates {
@@ -614,12 +628,13 @@ fn validate_body(
canister_id: &Principal,
agent: &Agent,
uri: &Uri,
response_body: &[u8],
body_sha: &[u8; 32],
logger: slog::Logger,
) -> anyhow::Result<bool> {
let cert: Certificate =
serde_cbor::from_slice(&certificates.certificate).map_err(AgentError::InvalidCborData)?;
let tree: HashTree = serde_cbor::from_slice(&certificates.tree).map_err(AgentError::InvalidCborData)?;
let tree: HashTree =
serde_cbor::from_slice(&certificates.tree).map_err(AgentError::InvalidCborData)?;

if let Err(e) = agent.verify(&cert) {
slog::trace!(logger, ">> certificate failed verification: {}", e);
@@ -671,12 +686,9 @@ fn validate_body(
},
};

let mut sha256 = Sha256::new();
sha256.update(response_body);
let body_sha: [u8; 32] = sha256.finalize().into();

if let Some(tree) = certificates.chunk_tree {
let chunk_tree: HashTree = serde_cbor::from_slice(&tree).map_err(AgentError::InvalidCborData)?;
let chunk_tree: HashTree =
serde_cbor::from_slice(&tree).map_err(AgentError::InvalidCborData)?;

let chunk_tree_digest = chunk_tree.digest();

@@ -700,7 +712,7 @@ fn validate_body(
return Ok(false);
}
};

Ok(body_sha == chunk_sha)
} else {
Ok(body_sha == tree_sha)