From a13dced8a8f081119bea4d0836ebc8a7128c095c Mon Sep 17 00:00:00 2001 From: Antoine Delignat-Lavaud Date: Wed, 16 Oct 2024 17:44:37 +0000 Subject: [PATCH 1/6] Support multiple key ID dynamically based on client request. Implement OHTTP configuration cache --- cgpuvm-attest/src/lib.rs | 26 ++-- ohttp-server/Cargo.toml | 2 + ohttp-server/src/main.rs | 315 ++++++++++++++++++++++----------------- 3 files changed, 199 insertions(+), 144 deletions(-) diff --git a/cgpuvm-attest/src/lib.rs b/cgpuvm-attest/src/lib.rs index 1f0cfa4..80cb298 100755 --- a/cgpuvm-attest/src/lib.rs +++ b/cgpuvm-attest/src/lib.rs @@ -12,15 +12,21 @@ extern "C" { ) -> c_int; } -pub fn attest(data: &[u8], pcrs: u32, endpoint_url: &str) -> Option> { - let endpoint_url_cstring = CString::new(endpoint_url).expect("CString::new failed"); - unsafe { - let url_ptr = endpoint_url_cstring.as_ptr(); - let mut dstlen = 32 * 1024; - let mut dst = Vec::with_capacity(dstlen); - let pdst = dst.as_mut_ptr(); - let res = get_attestation_token(data.as_ptr(), pcrs, pdst, &mut dstlen, url_ptr); - dst.set_len(dstlen); - (res == 0).then_some(dst) +pub fn attest(data: &[u8], pcrs: u32, endpoint_url: &str) -> Result, String> { + match CString::new(endpoint_url) { + Ok(endpoint_url_cstring) => + unsafe { + let url_ptr = endpoint_url_cstring.as_ptr(); + let mut dstlen = 32 * 1024; + let mut dst = Vec::with_capacity(dstlen); + let pdst = dst.as_mut_ptr(); + if get_attestation_token(data.as_ptr(), pcrs, pdst, &mut dstlen, url_ptr) == 0 { + dst.set_len(dstlen); + Ok(dst) + } else { + Err("CVM guest attestation library returned a non-0 code.".to_owned()) + } + }, + _e => Err("Failed to convert endpoint URL to CString.".to_owned()) } } diff --git a/ohttp-server/Cargo.toml b/ohttp-server/Cargo.toml index d7f4fc5..fceb14c 100644 --- a/ohttp-server/Cargo.toml +++ b/ohttp-server/Cargo.toml @@ -12,6 +12,8 @@ rust-hpke = ["ohttp/rust-hpke"] [dependencies] env_logger = {version = "0.10", default-features = false} hex = "0.4" +lazy_static = "1.4" +moka = { version = "0.12", features = ["future"] } tokio = { version = "1", features = ["full"] } serde = { version = "1.0", features = ["derive"] } elliptic-curve = { version = "0.13.8", features = ["jwk"] } diff --git a/ohttp-server/src/main.rs b/ohttp-server/src/main.rs index 93cb52b..c395abd 100755 --- a/ohttp-server/src/main.rs +++ b/ohttp-server/src/main.rs @@ -2,12 +2,14 @@ use std::{io::Cursor, net::SocketAddr, sync::Arc}; +use lazy_static::lazy_static; +use moka::future::Cache; + use futures_util::stream::unfold; use reqwest::{ header::{HeaderMap, HeaderName, HeaderValue}, Method, Response, Url, }; -use tokio::sync::Mutex; use bhttp::{Message, Mode}; use clap::Parser; @@ -44,7 +46,7 @@ const DEFAULT_KMS_URL: &str = "https://acceu-aml-504.confidential-ledger.azure.c const DEFAULT_MAA_URL: &str = "https://sharedeus2.eus2.attest.azure.net"; const FILTERED_RESPONSE_HEADERS: [&str; 2] = ["content-type", "content-length"]; -#[derive(Debug, Parser)] +#[derive(Debug, Parser, Clone)] #[command(name = "ohttp-server", about = "Serve oblivious HTTP requests.")] struct Args { /// The address to bind to. @@ -61,7 +63,7 @@ struct Args { target: Url, /// Obtain key configuration from a KMS after attestation - #[arg(long, short = 'a')] + #[arg(long, short = 'a', default_value_t = true)] attest: bool, /// MAA endpoint @@ -86,12 +88,23 @@ impl Args { } } -async fn import_config(maa: &str, kms: &str) -> Res { +lazy_static! { + static ref cache : Arc> = Arc::new(Cache::builder() + .time_to_live(Duration::from_secs(24 * 60 * 60)) + .build()); +} + +async fn import_config(maa: &str, kms: &str, kid: i32) -> Res { + // Check if the key configuration is in cache + if let Some(config) = cache.get(&kid).await { + info!("Found OHTTP configuration for KID {kid} in cache."); + return Ok(config); + } + // Get MAA token from CVM guest attestation library - let Some(tok) = attest("{}".as_bytes(), 0xffff, maa) else { - panic!("Failed to get MAA token. You must be root to access TPM.") - }; - let token = String::from_utf8(tok).unwrap(); + let token = attest("{}".as_bytes(), 0xffff, maa)?; + + let token = String::from_utf8(token).unwrap(); info!("Fetched MAA token"); trace!("{token}"); @@ -103,43 +116,61 @@ async fn import_config(maa: &str, kms: &str) -> Res { let max_retries = 3; let mut retries = 0; let key: String; - let mut kid: u8 = 0; loop { + // kid<0 will get the latest, this is used by the discover endpoint + let url = if kid>=0 { format!("{kms}?kid={kid}") }else{ kms.to_owned() }; + info!("Sending SKR request to {url}"); + // Get HPKE private key from Azure KMS + // FIXME(adl) kid should be an input of the SKR request let response = client - .post(kms) + .post(url) .header("Authorization", format!("Bearer {token}")) .send() .await?; // We may have to wait for receipt to be ready - if response.status() == 202 { - if retries < max_retries { - retries += 1; - trace!( - "Received 202 status code, retrying... (attempt {}/{})", - retries, - max_retries - ); - sleep(Duration::from_secs(1)).await; - } else { - panic!("Max retries reached, giving up. Cannot reach key management service"); + match response.status().as_u16() { + 202 => { + if retries < max_retries { + retries += 1; + trace!( + "Received 202 status code, retrying... (attempt {}/{})", + retries, + max_retries + ); + sleep(Duration::from_secs(1)).await; + } else { + Err("Max retries reached, giving up. Cannot reach key management service")?; + } + }, + 200 => { + let skr_body = response.text().await?; + info!("SKR successful {}", skr_body); + + let skr: ExportedKey = from_str(&skr_body)?; + trace!("requested KID={}, returned KID={}, Receipt={}", kid, skr.kid, skr.receipt); + + if kid >= 0 && skr.kid as i32 != kid { + Err("KMS returned a different key ID from the one requested")? + } + + key = skr.key; + break; + }, + e => { + info!("KMS returned an unexpected status code: {e}"); + key = "".to_string(); + break } - } else { - let skr_body = response.text().await?; - let skr: ExportedKey = - from_str(&skr_body).expect("Failed to deserialize SKR response. Check KMS version"); - - info!("SKR successful"); - trace!("KID={}, Receipt={}", skr.kid, skr.receipt); - key = skr.key; - break; } } - let cwk = hex::decode(&key).expect("Failed to decode hex key"); - let cwk_map: Value = serde_cbor::from_slice(&cwk).expect("Invalid CBOR in key from KMS"); + + let cwk = hex::decode(&key)?; + let cwk_map: Value = serde_cbor::from_slice(&cwk)?; let mut d = None; + let mut returned_kid : u8 = 0; // Parse the returned CBOR key (in CWK-like format) if let Value::Map(map) = cwk_map { @@ -149,9 +180,12 @@ async fn import_config(maa: &str, kms: &str) -> Res { // key identifier 4 => { if let Value::Integer(k) = value { - kid = u8::try_from(k).unwrap(); + returned_kid = u8::try_from(k).unwrap(); + if kid >=0 && returned_kid as i32 != kid { + Err("Server returned a different KID from the one requested")?; + } } else { - panic!("Bad KID"); + Err("Bad key identifier in SKR response")? } } @@ -160,7 +194,7 @@ async fn import_config(maa: &str, kms: &str) -> Res { if let Value::Bytes(vec) = value { d = Some(vec); } else { - panic!("Invalid private key"); + Err("Invalid secret exponent in SKR response")? } } @@ -168,32 +202,29 @@ async fn import_config(maa: &str, kms: &str) -> Res { -1 => { if value == Value::Integer(2) { } else { - panic!("Bad CBOR key type, expected P-384(2)"); + Err("Bad CBOR key type, expected P-384(2)")? } } // Ignore public key (x,y) as we recompute it from d anyway -2 | -3 => (), - _ => panic!("Unexpected field in exported private key from KMS"), + _ => Err("Unexpected field in exported private key from KMS")?, }; }; } } else { - panic!("Incorrect CBOR encoding in returned private key"); + Err("Incorrect CBOR encoding in returned private key")?; }; - let (sk, pk) = if let Some(key) = d { - let s = ::PrivateKey::from_bytes(&key) - .expect("Failed to create HPKE private key"); - let p = ::sk_to_pk(&s); - (s, p) - } else { - panic!("Missing private exponent in key returned from KMS"); - }; + let sk = match d { + Some(key) => ::PrivateKey::from_bytes(&key), + None => Err("Private key missing from SKR response")? + }?; + let pk = ::sk_to_pk(&sk); let config = KeyConfig::import_p384( - kid, + returned_kid, Kem::P384Sha384, sk, pk, @@ -204,17 +235,17 @@ async fn import_config(maa: &str, kms: &str) -> Res { ], )?; + cache.insert(kid, config.clone()).await; Ok(config) } async fn generate_reply( - ohttp_ref: &Arc>, + ohttp: &OhttpServer, inject_headers: HeaderMap, enc_request: &[u8], target: Url, _mode: Mode, ) -> Res<(Response, ServerResponse)> { - let ohttp = ohttp_ref.lock().await; let (request, server_response) = ohttp.decapsulate(enc_request)?; let bin_request = Message::read_bhttp(&mut Cursor::new(&request[..]))?; @@ -281,88 +312,130 @@ fn compute_injected_headers(headers: &HeaderMap, keys: Vec) -> HeaderMap result } -#[allow(clippy::unused_async)] async fn score( headers: warp::hyper::HeaderMap, - inject_request_headers: Vec, body: warp::hyper::body::Bytes, - ohttp: Arc>, - target: Url, - mode: Mode, + args: Arc, ) -> Result { + + let kms_url = args.kms_url.clone().unwrap_or(DEFAULT_KMS_URL.to_string()); + let maa_url = args.maa_url.clone().unwrap_or(DEFAULT_MAA_URL.to_string()); + let mode = args.mode(); + let target = args.target.clone(); + let inject_request_headers = args.inject_request_headers.clone(); + info!("Received encapsulated score request for target {}", target); info!("Request headers"); for (key, value) in &headers { info!("{}: {}", key, value.to_str().unwrap()); } - let inject_headers = compute_injected_headers(&headers, inject_request_headers); - let reply = generate_reply(&ohttp, inject_headers, &body[..], target, mode); - - match reply.await { - Ok((response, server_response)) => { - let mut builder = - warp::http::Response::builder().header("Content-Type", "message/ohttp-chunked-res"); - - // Move headers from the inner response into the outer response - info!("Response headers:"); - for (key, value) in response.headers() { - if !FILTERED_RESPONSE_HEADERS - .iter() - .any(|h| h.eq_ignore_ascii_case(key.as_str())) - { - info!( - "{}: {}", - key, - std::str::from_utf8(value.as_bytes()).unwrap() - ); - builder = builder.header(key.as_str(), value.as_bytes()); - } - } + let kid : i32 = match body.get(0).copied() {None => -1, Some(kid) => kid as i32}; - let stream = Box::pin(unfold(response, |mut response| async move { - match response.chunk().await { - Ok(Some(chunk)) => { - Some((Ok::, ohttp::Error>(chunk.to_vec()), response)) + let ohttp = match import_config(&maa_url, &kms_url, kid).await { + Ok(config) => match OhttpServer::new(config) { + Ok(ohttp) => Some(ohttp), + _ => None + }, + _ => None + }; + + match ohttp { + None => Ok(warp::http::Response::builder() + .status(500) + .body(Body::from(&b"Failed to get or load the OHTTP coniguration from local cache or key management service."[..]))), + + Some(ohttp) => { + let inject_headers = compute_injected_headers(&headers, inject_request_headers); + let reply = generate_reply(&ohttp, inject_headers, &body[..], target, mode).await; + + match reply { + Ok((response, server_response)) => { + let mut builder = + warp::http::Response::builder().header("Content-Type", "message/ohttp-chunked-res"); + + // Move headers from the inner response into the outer response + info!("Response headers:"); + for (key, value) in response.headers() { + if !FILTERED_RESPONSE_HEADERS + .iter() + .any(|h| h.eq_ignore_ascii_case(key.as_str())) + { + info!( + "{}: {}", + key, + std::str::from_utf8(value.as_bytes()).unwrap() + ); + builder = builder.header(key.as_str(), value.as_bytes()); } - _ => None, } - })); - let stream = server_response.encapsulate_stream(stream); - Ok(builder.body(Body::wrap_stream(stream))) - } - Err(e) => { - error!("400 {}", e.to_string()); - if let Ok(oe) = e.downcast::<::ohttp::Error>() { - Ok(warp::http::Response::builder() - .status(422) - .body(Body::from(format!("Error: {oe:?}")))) - } else { - Ok(warp::http::Response::builder() - .status(400) - .body(Body::from(&b"Request error"[..]))) + let stream = Box::pin(unfold(response, |mut response| async move { + match response.chunk().await { + Ok(Some(chunk)) => { + Some((Ok::, ohttp::Error>(chunk.to_vec()), response)) + } + _ => None, + } + })); + + let stream = server_response.encapsulate_stream(stream); + Ok(builder.body(Body::wrap_stream(stream))) + } + Err(e) => { + error!("400 {}", e.to_string()); + if let Ok(oe) = e.downcast::<::ohttp::Error>() { + Ok(warp::http::Response::builder() + .status(422) + .body(Body::from(format!("Error: {oe:?}")))) + } else { + Ok(warp::http::Response::builder() + .status(400) + .body(Body::from(&b"Request error"[..]))) + } } } + } } } #[allow(clippy::unused_async)] -async fn discover(config: String) -> Result { - Ok(warp::http::Response::builder() - .status(200) - .body(Vec::from(config))) -} +async fn discover(args: Arc) -> Result { + let kms_url = &args.kms_url.clone().unwrap_or(DEFAULT_KMS_URL.to_string()); + let maa_url = &args.maa_url.clone().unwrap_or(DEFAULT_MAA_URL.to_string()); + + match import_config(maa_url, kms_url, -1).await { + Err(_e) => + Ok(warp::http::Response::builder() + .status(500) + .body(Body::from(&b"Failed to get OHTTP coniguration from local cache or key management service. The key ID of the request may be invalid or expired."[..]))), + + Ok(config) => + match KeyConfig::encode_list(&[config]) { + Err(_e) => + Ok(warp::http::Response::builder() + .status(500) + .body(Body::from(&b"Failed to get OHTTP coniguration from local cache or key management service. The key ID of the request may be invalid or expired."[..]))), + + Ok(list) => { + let hex = hex::encode(list); + trace!("Discover config: {}", hex); -fn with_ohttp( - ohttp: Arc>, -) -> impl Filter>,), Error = std::convert::Infallible> + Clone { - warp::any().map(move || Arc::clone(&ohttp)) + Ok(warp::http::Response::builder() + .status(200) + .body(Vec::from(hex).into())) + } + } + } } #[tokio::main] async fn main() -> Res<()> { let args = Args::parse(); + let address = args.address.clone(); + let argsc = Arc::new(args); + let args1 = Arc::clone(&argsc); + let args2 = Arc::clone(&argsc); ::ohttp::init(); // Build a simple subscriber that outputs to stdout @@ -374,49 +447,23 @@ async fn main() -> Res<()> { // Set the subscriber as global default tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); - let config = if args.attest { - let kms_url = &args.kms_url.clone().unwrap_or(DEFAULT_KMS_URL.to_string()); - let maa_url = &args.maa_url.clone().unwrap_or(DEFAULT_MAA_URL.to_string()); - import_config(maa_url, kms_url).await? - } else { - KeyConfig::new( - 0, - Kem::X25519Sha256, - vec![ - SymmetricSuite::new(Kdf::HkdfSha256, Aead::Aes128Gcm), - SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305), - ], - )? - }; - - let ohttp = OhttpServer::new(config)?; - let config = hex::encode(KeyConfig::encode_list(&[ohttp.config()])?); - trace!("Config: {}", config); - - let mode = args.mode(); - let target = args.target; - let inject_request_headers = args.inject_request_headers; - let score = warp::post() .and(warp::path::path("score")) .and(warp::path::end()) .and(warp::header::headers_cloned()) - .and(warp::any().map(move || inject_request_headers.clone())) .and(warp::body::bytes()) - .and(with_ohttp(Arc::new(Mutex::new(ohttp)))) - .and(warp::any().map(move || target.clone())) - .and(warp::any().map(move || mode)) + .and(warp::any().map(move || Arc::clone(&args1))) .and_then(score); let discover = warp::get() .and(warp::path("discover")) .and(warp::path::end()) - .and(warp::any().map(move || config.clone())) + .and(warp::any().map(move || Arc::clone(&args2))) .and_then(discover); let routes = score.or(discover); - warp::serve(routes).run(args.address).await; + warp::serve(routes).run(address).await; Ok(()) } From f49bdc9c95f136458c637d44cf62319576e601f4 Mon Sep 17 00:00:00 2001 From: Antoine Delignat-Lavaud Date: Wed, 16 Oct 2024 17:45:02 +0000 Subject: [PATCH 2/6] Wait for receipt in OHTTP client --- ohttp-client/src/main.rs | 48 +++++++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/ohttp-client/src/main.rs b/ohttp-client/src/main.rs index 231a2e3..ecaa5d6 100644 --- a/ohttp-client/src/main.rs +++ b/ohttp-client/src/main.rs @@ -154,17 +154,43 @@ async fn get_kms_config(kms_url: String, cert: &str) -> Res { .build()?; info!("Contacting key management service at {kms_url}..."); - - // Make the GET request - let response = client - .get(kms_url + "/listpubkeys") - .send() - .await? - .error_for_status()?; - - let body = response.text().await?; - assert!(!body.is_empty()); - Ok(body) + let max_retries = 3; + let mut retries = 0; + let url = kms_url + "/listpubkeys"; + + loop { + // Make the GET request + let response = client + .get(url.clone()) + .send() + .await? + .error_for_status()?; + + // We may have to wait for receipt to be ready + match response.status().as_u16() { + 202 => { + if retries < max_retries { + retries += 1; + trace!( + "Received 202 status code, retrying... (attempt {}/{})", + retries, + max_retries + ); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + } else { + Err("Max retries reached, giving up. Cannot reach key management service")?; + } + }, + 200 => { + let body = response.text().await?; + assert!(!body.is_empty()); + return Ok(body) + }, + e => { + Err(format!("KMS returned unexpected {} status code.", e))?; + } + } + } } #[derive(Deserialize)] From 84bb26ffc19c0912d90cbd8af688524a8b9fd48f Mon Sep 17 00:00:00 2001 From: Antoine Delignat-Lavaud Date: Thu, 17 Oct 2024 12:54:15 +0000 Subject: [PATCH 3/6] Add local keying mode for testing without KMS. Also added MAA token caching and optional MAA header in inference response --- ohttp-server/src/main.rs | 84 ++++++++++++++++++++++++++++------------ 1 file changed, 60 insertions(+), 24 deletions(-) diff --git a/ohttp-server/src/main.rs b/ohttp-server/src/main.rs index c395abd..2d57c11 100755 --- a/ohttp-server/src/main.rs +++ b/ohttp-server/src/main.rs @@ -62,9 +62,9 @@ struct Args { #[arg(long, short = 't', default_value = "http://127.0.0.1:8000")] target: Url, - /// Obtain key configuration from a KMS after attestation - #[arg(long, short = 'a', default_value_t = true)] - attest: bool, + /// Use locally generated key, for testing without KMS + #[arg(long, short = 'l')] + local_key: bool, /// MAA endpoint #[arg(long, short = 'm')] @@ -89,16 +89,16 @@ impl Args { } lazy_static! { - static ref cache : Arc> = Arc::new(Cache::builder() + static ref cache : Arc> = Arc::new(Cache::builder() .time_to_live(Duration::from_secs(24 * 60 * 60)) .build()); } -async fn import_config(maa: &str, kms: &str, kid: i32) -> Res { +async fn import_config(maa: &str, kms: &str, kid: i32) -> Res<(KeyConfig, String)> { // Check if the key configuration is in cache - if let Some(config) = cache.get(&kid).await { + if let Some((config,token)) = cache.get(&kid).await { info!("Found OHTTP configuration for KID {kid} in cache."); - return Ok(config); + return Ok((config,token)); } // Get MAA token from CVM guest attestation library @@ -235,8 +235,8 @@ async fn import_config(maa: &str, kms: &str, kid: i32) -> Res { ], )?; - cache.insert(kid, config.clone()).await; - Ok(config) + cache.insert(kid, (config.clone(), token.clone())).await; + Ok((config, token)) } async fn generate_reply( @@ -323,29 +323,42 @@ async fn score( let mode = args.mode(); let target = args.target.clone(); let inject_request_headers = args.inject_request_headers.clone(); + let mut return_token = false; info!("Received encapsulated score request for target {}", target); info!("Request headers"); + for (key, value) in &headers { info!("{}: {}", key, value.to_str().unwrap()); + if key == "x-attestation-token" { return_token = true; } } - let kid : i32 = match body.get(0).copied() {None => -1, Some(kid) => kid as i32}; - - let ohttp = match import_config(&maa_url, &kms_url, kid).await { - Ok(config) => match OhttpServer::new(config) { - Ok(ohttp) => Some(ohttp), - _ => None - }, - _ => None + // The KID is normally the first byte of the request + let kid : i32 = match body.get(0).copied() { + None => -1, + Some(kid) => kid as i32 }; + let ohttp = + if args.local_key && kid != 0 { + info!("Ignoring non-0 KID {kid} with local keying configuration"); + None + } else { + match import_config(&maa_url, &kms_url, kid).await { + Ok((config,token)) => match OhttpServer::new(config) { + Ok(ohttp) => Some((ohttp,token)), + _ => None + }, + _ => { info!("Failed to load KID {kid} from KMS"); None } + } + }; + match ohttp { None => Ok(warp::http::Response::builder() .status(500) .body(Body::from(&b"Failed to get or load the OHTTP coniguration from local cache or key management service."[..]))), - Some(ohttp) => { + Some((ohttp,token)) => { let inject_headers = compute_injected_headers(&headers, inject_request_headers); let reply = generate_reply(&ohttp, inject_headers, &body[..], target, mode).await; @@ -354,6 +367,12 @@ async fn score( let mut builder = warp::http::Response::builder().header("Content-Type", "message/ohttp-chunked-res"); + + // Add HTTP header with MAA token, for client auditing. + if return_token { + builder = builder.header(HeaderName::from_static("x-attestation-token"), token.clone()); + } + // Move headers from the inner response into the outer response info!("Response headers:"); for (key, value) in response.headers() { @@ -399,23 +418,27 @@ async fn score( } } -#[allow(clippy::unused_async)] async fn discover(args: Arc) -> Result { let kms_url = &args.kms_url.clone().unwrap_or(DEFAULT_KMS_URL.to_string()); let maa_url = &args.maa_url.clone().unwrap_or(DEFAULT_MAA_URL.to_string()); - match import_config(maa_url, kms_url, -1).await { + // The discovery endpoint is only enabled for local testing + if !args.local_key { + return Ok(warp::http::Response::builder().status(404).body(Body::from(&b"Not found"[..]))); + } + + match import_config(maa_url, kms_url, 0).await { Err(_e) => Ok(warp::http::Response::builder() .status(500) - .body(Body::from(&b"Failed to get OHTTP coniguration from local cache or key management service. The key ID of the request may be invalid or expired."[..]))), + .body(Body::from(&b"KID 0 missing from cache (should be impossible with local keying)"[..]))), - Ok(config) => + Ok((config, _)) => match KeyConfig::encode_list(&[config]) { Err(_e) => Ok(warp::http::Response::builder() .status(500) - .body(Body::from(&b"Failed to get OHTTP coniguration from local cache or key management service. The key ID of the request may be invalid or expired."[..]))), + .body(Body::from(&b"Invalid key configuration (check KeyConfig written to initial cache)"[..]))), Ok(list) => { let hex = hex::encode(list); @@ -432,12 +455,26 @@ async fn discover(args: Arc) -> Result Res<()> { let args = Args::parse(); + let is_local = args.local_key; let address = args.address.clone(); + let argsc = Arc::new(args); let args1 = Arc::clone(&argsc); let args2 = Arc::clone(&argsc); ::ohttp::init(); + // Generate a fresh key for local testing. KID is set to 0. + if is_local { + let config = KeyConfig::new( + 0, + Kem::P384Sha384, + vec![ + SymmetricSuite::new(Kdf::HkdfSha256, Aead::Aes128Gcm), + SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305), + ])?; + cache.insert(0, (config, "".to_owned())).await; + } + // Build a simple subscriber that outputs to stdout let subscriber = FmtSubscriber::builder() .with_max_level(tracing::Level::INFO) @@ -462,7 +499,6 @@ async fn main() -> Res<()> { .and_then(discover); let routes = score.or(discover); - warp::serve(routes).run(address).await; Ok(()) From cfa30bb8716952bed7c353248863a726278c3d54 Mon Sep 17 00:00:00 2001 From: Antoine Delignat-Lavaud Date: Tue, 5 Nov 2024 20:09:23 +0000 Subject: [PATCH 4/6] Use stateful CVM guest attestation to decrypt key --- cgpuvm-attest-cli/src/main.rs | 2 +- cgpuvm-attest/src/lib.rs | 97 +++++++++++++++++++++++++++++++++-- ohttp-server/Cargo.toml | 1 + ohttp-server/src/main.rs | 70 +++++++++++++++++-------- 4 files changed, 146 insertions(+), 24 deletions(-) diff --git a/cgpuvm-attest-cli/src/main.rs b/cgpuvm-attest-cli/src/main.rs index cfbbe06..0636ae4 100644 --- a/cgpuvm-attest-cli/src/main.rs +++ b/cgpuvm-attest-cli/src/main.rs @@ -9,7 +9,7 @@ pub fn main() { let maa_url = &args[1]; let s = "{\"a\":1}"; - let Some(token) = attest(s.as_bytes(), 0xffff, maa_url) else { + let Some(token) = attest(s.as_bytes(), 0xffff, None, maa_url) else { panic!("Failed to get MAA token") }; println!("Got MAA token: {}", String::from_utf8(token).unwrap()); diff --git a/cgpuvm-attest/src/lib.rs b/cgpuvm-attest/src/lib.rs index 80cb298..7ffee2e 100755 --- a/cgpuvm-attest/src/lib.rs +++ b/cgpuvm-attest/src/lib.rs @@ -1,4 +1,4 @@ -use libc::{c_char, c_int, size_t}; +use libc::{c_char, c_int, size_t, c_void}; use std::ffi::CString; #[link(name = "azguestattestation")] @@ -10,6 +10,29 @@ extern "C" { jwt_len: *mut size_t, endpoint_url: *const c_char, ) -> c_int; + + fn ga_create( + st: *mut *mut c_void + ) -> c_int; + + fn ga_free( + st: *mut c_void + ); + + fn ga_get_token( + st: *mut c_void, + app_data: *const u8, + pcr: u32, + jwt: *mut u8, + jwt_len: *mut size_t, + endpoint_url: *const c_char + ) -> c_int; + + fn ga_decrypt( + st: *mut c_void, + cipher: *mut u8, + len: *mut size_t + ) -> c_int; } pub fn attest(data: &[u8], pcrs: u32, endpoint_url: &str) -> Result, String> { @@ -20,13 +43,81 @@ pub fn attest(data: &[u8], pcrs: u32, endpoint_url: &str) -> Result, Str let mut dstlen = 32 * 1024; let mut dst = Vec::with_capacity(dstlen); let pdst = dst.as_mut_ptr(); + if get_attestation_token(data.as_ptr(), pcrs, pdst, &mut dstlen, url_ptr) == 0 { dst.set_len(dstlen); Ok(dst) } else { - Err("CVM guest attestation library returned a non-0 code.".to_owned()) + Err("CVM guest attestation library returned a non-0 code.")? } }, - _e => Err("Failed to convert endpoint URL to CString.".to_owned()) + _ => Err("Failed to convert endpoint URL or ephemeral key to CString.")? } } + +pub struct AttestationClient { + st: *mut c_void +} + +impl AttestationClient { + pub fn new() -> Result { + let mut c = AttestationClient { st: std::ptr::null_mut() }; + unsafe { + let rc = ga_create(&mut c.st); + + if rc == 0 { + return Ok(c); + } + + return Err("Failed to initialize attestation library with error code {rc}".to_string()); + } + } + + pub fn attest(&mut self, data: &[u8], pcrs: u32, endpoint_url: &str) -> Result, String> { + match CString::new(endpoint_url) { + Ok(endpoint_url_cstring) => + unsafe { + let url_ptr = endpoint_url_cstring.as_ptr(); + let mut dstlen = 32 * 1024; + let mut dst = Vec::with_capacity(dstlen); + let pdst = dst.as_mut_ptr(); + let rc = ga_get_token(self.st, data.as_ptr(), pcrs, pdst, &mut dstlen, url_ptr); + + if rc == 0 { + dst.set_len(dstlen); + Ok(dst) + } else { + Err("CVM guest attestation library returned an error: {rc}.")? + } + }, + _ => Err("Failed to convert endpoint URL or ephemeral key to CString.")? + } + } + + pub fn decrypt(&mut self, data: &[u8]) -> Result, String> { + unsafe { + let mut buf = Vec::from(data); + let mut len = data.len(); + let rc = ga_decrypt(self.st, buf.as_mut_ptr(), &mut len); + + if rc == 0 { + buf.set_len(len); + Ok(buf) + } else { + Err("CVM guest attestation library returned an error: {rc}.")? + } + } + } +} + +impl Drop for AttestationClient { + fn drop(&mut self) { + unsafe { + ga_free(self.st); + } + } +} + +unsafe impl Send for AttestationClient { + +} \ No newline at end of file diff --git a/ohttp-server/Cargo.toml b/ohttp-server/Cargo.toml index fceb14c..e64bba9 100644 --- a/ohttp-server/Cargo.toml +++ b/ohttp-server/Cargo.toml @@ -13,6 +13,7 @@ rust-hpke = ["ohttp/rust-hpke"] env_logger = {version = "0.10", default-features = false} hex = "0.4" lazy_static = "1.4" +base64 = "0.22.1" moka = { version = "0.12", features = ["future"] } tokio = { version = "1", features = ["full"] } serde = { version = "1.0", features = ["derive"] } diff --git a/ohttp-server/src/main.rs b/ohttp-server/src/main.rs index 2d57c11..6c8c8d6 100755 --- a/ohttp-server/src/main.rs +++ b/ohttp-server/src/main.rs @@ -21,7 +21,7 @@ use warp::{hyper::Body, Filter}; use tokio::time::{sleep, Duration}; -use cgpuvm_attest::attest; +use cgpuvm_attest::{AttestationClient}; use reqwest::Client; type Res = Result>; @@ -34,6 +34,7 @@ use serde::Deserialize; use tracing::{error, info, trace}; use tracing_subscriber::FmtSubscriber; +use base64::{Engine as _, engine::general_purpose::STANDARD as b64}; #[derive(Deserialize)] struct ExportedKey { @@ -88,25 +89,47 @@ impl Args { } } +// We cache both successful key releases from the KMS as well as SKR errors, +// as guest attestation is very expensive (IMDS + TPM createPrimary + RSA decrypt x2) +// ValidKey expire based on the TTL of the cache (24 hours) +// SKRError are manually invalidated (see import_config), after 60 seconds +#[derive(Clone)] +enum CachedKey { + ValidKey(KeyConfig, String), + SKRError(std::time::SystemTime) +} + +// Lazily initialized shared globals lazy_static! { - static ref cache : Arc> = Arc::new(Cache::builder() + // Key cache for Oblivious HTTP, by key id + static ref cache : Arc> = Arc::new(Cache::builder() .time_to_live(Duration::from_secs(24 * 60 * 60)) .build()); } async fn import_config(maa: &str, kms: &str, kid: i32) -> Res<(KeyConfig, String)> { // Check if the key configuration is in cache - if let Some((config,token)) = cache.get(&kid).await { - info!("Found OHTTP configuration for KID {kid} in cache."); - return Ok((config,token)); + if let Some(entry) = cache.get(&kid).await { + match entry { + CachedKey::ValidKey(config, token) => { + info!("Found OHTTP configuration for KID {kid} in cache."); + return Ok((config,token)); + }, + CachedKey::SKRError(ts) => { + if ts.elapsed()? > Duration::from_secs(60) { + cache.invalidate(&kid).await; + } else { + Err(format!("SKR for KID {kid} has failed in the past 60 seconds, waiting to retry."))? + } + } + } } - // Get MAA token from CVM guest attestation library - let token = attest("{}".as_bytes(), 0xffff, maa)?; + let mut attest_cli = AttestationClient::new().expect("Failed to create attestation client object"); + let t = attest_cli.attest("{}".as_bytes(), 0xff, maa)?; - let token = String::from_utf8(token).unwrap(); - info!("Fetched MAA token"); - trace!("{token}"); + let token = String::from_utf8(t).unwrap(); + info!("Fetched MAA token: {token}"); let client = Client::builder() .danger_accept_invalid_certs(true) @@ -119,11 +142,10 @@ async fn import_config(maa: &str, kms: &str, kid: i32) -> Res<(KeyConfig, String loop { // kid<0 will get the latest, this is used by the discover endpoint - let url = if kid>=0 { format!("{kms}?kid={kid}") }else{ kms.to_owned() }; + let url = if kid>=0 { format!("{kms}?encrypted=true&kid={kid}") }else{ format!("{kms}?encrypted=true") }; info!("Sending SKR request to {url}"); // Get HPKE private key from Azure KMS - // FIXME(adl) kid should be an input of the SKR request let response = client .post(url) .header("Authorization", format!("Bearer {token}")) @@ -167,8 +189,11 @@ async fn import_config(maa: &str, kms: &str, kid: i32) -> Res<(KeyConfig, String } } - let cwk = hex::decode(&key)?; - let cwk_map: Value = serde_cbor::from_slice(&cwk)?; + // The KMS returns the base64-encoded, RSA2048-OAEP-SHA256 encrypted CBOR key + let enc_key = b64.decode(&key)?; + let decrypted_key = attest_cli.decrypt(enc_key.as_slice())?; + let cwk_map: Value = serde_cbor::from_slice(&decrypted_key)?; + let mut d = None; let mut returned_kid : u8 = 0; @@ -235,7 +260,7 @@ async fn import_config(maa: &str, kms: &str, kid: i32) -> Res<(KeyConfig, String ], )?; - cache.insert(kid, (config.clone(), token.clone())).await; + cache.insert(kid, CachedKey::ValidKey(config.clone(), token.clone())).await; Ok((config, token)) } @@ -349,14 +374,19 @@ async fn score( Ok(ohttp) => Some((ohttp,token)), _ => None }, - _ => { info!("Failed to load KID {kid} from KMS"); None } + Err(e) => { + info!("Failed to load KID {kid} from KMS: {e}"); + None + } } }; match ohttp { - None => Ok(warp::http::Response::builder() - .status(500) - .body(Body::from(&b"Failed to get or load the OHTTP coniguration from local cache or key management service."[..]))), + None => { + cache.insert(kid, CachedKey::SKRError(std::time::SystemTime::now())).await; + Ok(warp::http::Response::builder().status(500) + .body(Body::from(&b"Failed to get or load the OHTTP coniguration from local cache or key management service."[..]))) + }, Some((ohttp,token)) => { let inject_headers = compute_injected_headers(&headers, inject_request_headers); @@ -472,7 +502,7 @@ async fn main() -> Res<()> { SymmetricSuite::new(Kdf::HkdfSha256, Aead::Aes128Gcm), SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305), ])?; - cache.insert(0, (config, "".to_owned())).await; + cache.insert(0, CachedKey::ValidKey(config, "".to_owned())).await; } // Build a simple subscriber that outputs to stdout From bb7a3478e8b59754b61cc03ecd18d18dec971996 Mon Sep 17 00:00:00 2001 From: Antoine Delignat-Lavaud Date: Tue, 5 Nov 2024 20:47:48 +0000 Subject: [PATCH 5/6] Merge changes --- ohttp-server/Cargo.toml | 7 +- ohttp-server/src/err.rs | 27 ++ ohttp-server/src/main.rs | 527 ++++++++++++++++++++++----------------- 3 files changed, 326 insertions(+), 235 deletions(-) create mode 100644 ohttp-server/src/err.rs mode change 100755 => 100644 ohttp-server/src/main.rs diff --git a/ohttp-server/Cargo.toml b/ohttp-server/Cargo.toml index e64bba9..fae27f5 100644 --- a/ohttp-server/Cargo.toml +++ b/ohttp-server/Cargo.toml @@ -12,8 +12,8 @@ rust-hpke = ["ohttp/rust-hpke"] [dependencies] env_logger = {version = "0.10", default-features = false} hex = "0.4" -lazy_static = "1.4" base64 = "0.22.1" +lazy_static = "1.4" moka = { version = "0.12", features = ["future"] } tokio = { version = "1", features = ["full"] } serde = { version = "1.0", features = ["derive"] } @@ -28,9 +28,10 @@ futures-util = "0.3.30" futures = "0.3.30" log = "0.4.22" clap = { version = "4.5.18", features = ["derive"] } -json_log = { version = "0.1" } tracing = "0.1" -tracing-subscriber = { version = "0.3.18", features = ["default", "json"] } +tracing-subscriber = { version = "0.3.18", features = ["default", "json", "env-filter"] } +thiserror = "1" +uuid = { version = "1.0", features = ["v4"] } [dependencies.bhttp] path= "../bhttp" diff --git a/ohttp-server/src/err.rs b/ohttp-server/src/err.rs new file mode 100644 index 0000000..c9436f5 --- /dev/null +++ b/ohttp-server/src/err.rs @@ -0,0 +1,27 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum ServerError { + #[error("Incorrect CBOR encoding in returned private key")] + KMSCBOREncoding, + #[error("Bad CBOR key type, expected P-384(2)")] + KMSCBORKeyType, + #[error("Unexpected field in exported private key from KMS")] + KMSField, + #[error("Bad key identifier in SKR response")] + KMSKeyId, + #[error("Invalid secret exponent in SKR response")] + KMSExponent, + #[error("KMS returned an unexpected status code: {0}")] + KMSUnexpected(u16), + #[error("Max retries reached, giving up. Cannot reach key management service")] + KMSUnreachable, + #[error("Private key missing from SKR response")] + PrivateKeyMissing, + #[error("CVM guest attestation library initialization failure")] + AttestationLibraryInit, + #[error("Guest attestation library failed to decrypt HPKE private key")] + TPMDecryptionFailure, + #[error("SKR for the requested KID has failed in the past 60 seconds, waiting to retry.")] + CachedSKRError, +} diff --git a/ohttp-server/src/main.rs b/ohttp-server/src/main.rs old mode 100755 new mode 100644 index 6c8c8d6..793745b --- a/ohttp-server/src/main.rs +++ b/ohttp-server/src/main.rs @@ -1,5 +1,7 @@ #![deny(clippy::pedantic)] +pub mod err; + use std::{io::Cursor, net::SocketAddr, sync::Arc}; use lazy_static::lazy_static; @@ -13,15 +15,17 @@ use reqwest::{ use bhttp::{Message, Mode}; use clap::Parser; +use base64::{Engine as _, engine::general_purpose::STANDARD as b64}; + use ohttp::{ hpke::{Aead, Kdf, Kem}, - KeyConfig, Server as OhttpServer, ServerResponse, SymmetricSuite, + Error, KeyConfig, Server as OhttpServer, ServerResponse, SymmetricSuite, }; use warp::{hyper::Body, Filter}; use tokio::time::{sleep, Duration}; -use cgpuvm_attest::{AttestationClient}; +use cgpuvm_attest::AttestationClient; use reqwest::Client; type Res = Result>; @@ -32,9 +36,12 @@ use serde_json::from_str; use hpke::Deserializable; use serde::Deserialize; -use tracing::{error, info, trace}; -use tracing_subscriber::FmtSubscriber; -use base64::{Engine as _, engine::general_purpose::STANDARD as b64}; +use err::ServerError; +use tracing::{error, info, instrument, trace}; +use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter, FmtSubscriber}; +use uuid::Uuid; + +const VERSION: &str = "1.0.0"; #[derive(Deserialize)] struct ExportedKey { @@ -43,16 +50,15 @@ struct ExportedKey { receipt: String, } -const DEFAULT_KMS_URL: &str = "https://acceu-aml-504.confidential-ledger.azure.com/key"; -const DEFAULT_MAA_URL: &str = "https://sharedeus2.eus2.attest.azure.net"; +const DEFAULT_KMS_URL: &str = "https://accconfinferencedebug.confidential-ledger.azure.com/app/key"; +const DEFAULT_MAA_URL: &str = "https://maanosecureboottestyfu.eus.attest.azure.net"; const FILTERED_RESPONSE_HEADERS: [&str; 2] = ["content-type", "content-length"]; #[derive(Debug, Parser, Clone)] #[command(name = "ohttp-server", about = "Serve oblivious HTTP requests.")] struct Args { /// The address to bind to. - // #[arg(default_value = "127.0.0.1:9443")] - #[arg(default_value = "0.0.0.0:9443")] + #[arg(default_value = "127.0.0.1:9443")] address: SocketAddr, /// When creating message/bhttp, use the indeterminate-length form. @@ -102,35 +108,67 @@ enum CachedKey { // Lazily initialized shared globals lazy_static! { // Key cache for Oblivious HTTP, by key id - static ref cache : Arc> = Arc::new(Cache::builder() + static ref cache : Arc> = Arc::new(Cache::builder() .time_to_live(Duration::from_secs(24 * 60 * 60)) .build()); } -async fn import_config(maa: &str, kms: &str, kid: i32) -> Res<(KeyConfig, String)> { - // Check if the key configuration is in cache - if let Some(entry) = cache.get(&kid).await { - match entry { - CachedKey::ValidKey(config, token) => { - info!("Found OHTTP configuration for KID {kid} in cache."); - return Ok((config,token)); - }, - CachedKey::SKRError(ts) => { - if ts.elapsed()? > Duration::from_secs(60) { - cache.invalidate(&kid).await; - } else { - Err(format!("SKR for KID {kid} has failed in the past 60 seconds, waiting to retry."))? - } - } - } - } +fn parse_cbor_key(key: &str, kid: u8) -> Res<(Option>, u8)> { + let cwk = hex::decode(key)?; + let cwk_map: Value = serde_cbor::from_slice(&cwk)?; + let mut d = None; + let mut returned_kid: u8 = 0; + if let Value::Map(map) = cwk_map { + for (key, value) in map { + if let Value::Integer(key) = key { + match key { + // key identifier + 4 => { + if let Value::Integer(k) = value { + returned_kid = u8::try_from(k).unwrap(); + if returned_kid != kid { + return Err(Box::new(Error::KeyIdMismatch(returned_kid, kid))); + } + } else { + return Err(Box::new(ServerError::KMSKeyId)); + } + } - let mut attest_cli = AttestationClient::new().expect("Failed to create attestation client object"); - let t = attest_cli.attest("{}".as_bytes(), 0xff, maa)?; + // private exponent + -4 => { + if let Value::Bytes(vec) = value { + d = Some(vec); + } else { + return Err(Box::new(ServerError::KMSExponent)); + } + } - let token = String::from_utf8(t).unwrap(); - info!("Fetched MAA token: {token}"); + // key type, must be P-384(2) + -1 => { + if value == Value::Integer(2) { + } else { + return Err(Box::new(ServerError::KMSCBORKeyType)); + } + } + + // Ignore public key (x,y) as we recompute it from d anyway + -2 | -3 => (), + + _ => { + return Err(Box::new(ServerError::KMSField)); + } + }; + }; + } + } else { + return Err(Box::new(ServerError::KMSCBOREncoding)); + }; + Ok((d, returned_kid)) +} +/// Retrieves the HPKE private key from Azure KMS. +/// +async fn get_hpke_private_key_from_kms(kms: &str, kid: u8, token: &str) -> Res { let client = Client::builder() .danger_accept_invalid_certs(true) .build()?; @@ -138,11 +176,9 @@ async fn import_config(maa: &str, kms: &str, kid: i32) -> Res<(KeyConfig, String // Retrying logic for receipt let max_retries = 3; let mut retries = 0; - let key: String; loop { - // kid<0 will get the latest, this is used by the discover endpoint - let url = if kid>=0 { format!("{kms}?encrypted=true&kid={kid}") }else{ format!("{kms}?encrypted=true") }; + let url = format!("{kms}?kid={kid}&encrypted=true"); info!("Sending SKR request to {url}"); // Get HPKE private key from Azure KMS @@ -164,87 +200,72 @@ async fn import_config(maa: &str, kms: &str, kid: i32) -> Res<(KeyConfig, String ); sleep(Duration::from_secs(1)).await; } else { - Err("Max retries reached, giving up. Cannot reach key management service")?; + return Err(Box::new(ServerError::KMSUnreachable)); } - }, + } 200 => { let skr_body = response.text().await?; - info!("SKR successful {}", skr_body); + info!("SKR successful"); let skr: ExportedKey = from_str(&skr_body)?; - trace!("requested KID={}, returned KID={}, Receipt={}", kid, skr.kid, skr.receipt); + trace!("requested KID={kid}, returned KID={}, receipt={}", skr.kid, skr.receipt); - if kid >= 0 && skr.kid as i32 != kid { - Err("KMS returned a different key ID from the one requested")? + if skr.kid != kid { + return Err(Box::new(Error::KeyIdMismatch(skr.kid, kid))); } - key = skr.key; - break; - }, + return Ok(skr.key); + } e => { - info!("KMS returned an unexpected status code: {e}"); - key = "".to_string(); - break + return Err(Box::new(ServerError::KMSUnexpected(e))); } } } +} - // The KMS returns the base64-encoded, RSA2048-OAEP-SHA256 encrypted CBOR key - let enc_key = b64.decode(&key)?; - let decrypted_key = attest_cli.decrypt(enc_key.as_slice())?; - let cwk_map: Value = serde_cbor::from_slice(&decrypted_key)?; - - let mut d = None; - let mut returned_kid : u8 = 0; - - // Parse the returned CBOR key (in CWK-like format) - if let Value::Map(map) = cwk_map { - for (key, value) in map { - if let Value::Integer(key) = key { - match key { - // key identifier - 4 => { - if let Value::Integer(k) = value { - returned_kid = u8::try_from(k).unwrap(); - if kid >=0 && returned_kid as i32 != kid { - Err("Server returned a different KID from the one requested")?; - } - } else { - Err("Bad key identifier in SKR response")? - } - } +/// Try to load an OHTTP key configuration from the cache, or from the KMS if not found +/// +async fn load_config(maa: &str, kms: &str, kid: u8) -> Res<(KeyConfig, String)> { + // Check if the key configuration is in cache + if let Some(entry) = cache.get(&kid).await { + match entry { + CachedKey::ValidKey(config, token) => { + info!("Found OHTTP configuration for KID {kid} in cache."); + return Ok((config,token)); + }, + CachedKey::SKRError(ts) => { + if ts.elapsed()? > Duration::from_secs(60) { + cache.invalidate(&kid).await; + } else { + Err(Box::new(ServerError::CachedSKRError))? + } + } + } + } - // private exponent - -4 => { - if let Value::Bytes(vec) = value { - d = Some(vec); - } else { - Err("Invalid secret exponent in SKR response")? - } - } + let mut attest_cli = match AttestationClient::new() { + Ok(cli) => cli, + _ => Err(Box::new(ServerError::AttestationLibraryInit))? + }; - // key type, must be P-384(2) - -1 => { - if value == Value::Integer(2) { - } else { - Err("Bad CBOR key type, expected P-384(2)")? - } - } + let t = attest_cli.attest("{}".as_bytes(), 0xff, maa)?; + let token = String::from_utf8(t).unwrap(); + info!("Fetched MAA token: {token}"); - // Ignore public key (x,y) as we recompute it from d anyway - -2 | -3 => (), + // The KMS returns the base64-encoded, RSA2048-OAEP-SHA256 encrypted CBOR key + let key = get_hpke_private_key_from_kms(kms, kid, &token).await?; + let enc_key = b64.decode(&key)?; - _ => Err("Unexpected field in exported private key from KMS")?, - }; - }; - } - } else { - Err("Incorrect CBOR encoding in returned private key")?; + let decrypted_key = match attest_cli.decrypt(enc_key.as_slice()) { + Ok(k) => k, + _ => Err(Box::new(ServerError::TPMDecryptionFailure))? }; + let (d, returned_kid) = parse_cbor_key(&decrypted_key, kid)?; + let sk = match d { Some(key) => ::PrivateKey::from_bytes(&key), - None => Err("Private key missing from SKR response")? + None => Err(Box::new(ServerError::PrivateKeyMissing))?, }?; let pk = ::sk_to_pk(&sk); @@ -264,11 +285,32 @@ async fn import_config(maa: &str, kms: &str, kid: i32) -> Res<(KeyConfig, String Ok((config, token)) } +/// Copies headers from the encapsulated request and logs them. +/// +fn get_headers_from_request(bin_request: &Message) -> HeaderMap { + info!("Inner request headers"); + let mut headers = HeaderMap::new(); + for field in bin_request.header().fields() { + info!( + " {}: {}", + std::str::from_utf8(field.name()).unwrap(), + std::str::from_utf8(field.value()).unwrap() + ); + + headers.append( + HeaderName::from_bytes(field.name()).unwrap(), + HeaderValue::from_bytes(field.value()).unwrap(), + ); + } + headers +} + async fn generate_reply( ohttp: &OhttpServer, inject_headers: HeaderMap, enc_request: &[u8], target: Url, + target_path: Option<&HeaderValue>, _mode: Mode, ) -> Res<(Response, ServerResponse)> { let (request, server_response) = ohttp.decapsulate(enc_request)?; @@ -281,32 +323,28 @@ async fn generate_reply( }; // Copy headers from the encapsulated request - info!("Inner request headers"); - let mut headers = HeaderMap::new(); - for field in bin_request.header().fields() { - info!( - "{}: {}", - std::str::from_utf8(field.name()).unwrap(), - std::str::from_utf8(field.value()).unwrap() - ); - - headers.append( - HeaderName::from_bytes(field.name()).unwrap(), - HeaderValue::from_bytes(field.value()).unwrap(), - ); - } + let mut headers = get_headers_from_request(&bin_request); // Inject additional headers from the outer request - info!("Inner request injected headers"); - for (key, value) in inject_headers { - if let Some(key) = key { - info!("{}: {}", key.as_str(), value.to_str().unwrap()); - headers.append(key, value); + if !inject_headers.is_empty() { + info!("Appending injected headers"); + for (key, value) in inject_headers { + if let Some(key) = key { + info!(" {}: {}", key.as_str(), value.to_str().unwrap()); + headers.append(key, value); + } } - } + }; let mut t = target; - if let Some(path_bytes) = bin_request.control().path() { + + // Set resource path to either the one provided in the outer request header + // If none provided, use the path set by the client + if let Some(path_bytes) = target_path { + if let Ok(path_str) = std::str::from_utf8(path_bytes.as_bytes()) { + t.set_path(path_str); + } + } else if let Some(path_bytes) = bin_request.control().path() { if let Ok(path_str) = std::str::from_utf8(path_bytes) { t.set_path(path_str); } @@ -337,115 +375,128 @@ fn compute_injected_headers(headers: &HeaderMap, keys: Vec) -> HeaderMap result } +#[instrument(skip(headers, body, args), fields(version = %VERSION))] async fn score( headers: warp::hyper::HeaderMap, body: warp::hyper::body::Bytes, args: Arc, + x_ms_request_id: Uuid, ) -> Result { - - let kms_url = args.kms_url.clone().unwrap_or(DEFAULT_KMS_URL.to_string()); - let maa_url = args.maa_url.clone().unwrap_or(DEFAULT_MAA_URL.to_string()); - let mode = args.mode(); let target = args.target.clone(); - let inject_request_headers = args.inject_request_headers.clone(); - let mut return_token = false; - info!("Received encapsulated score request for target {}", target); - info!("Request headers"); - for (key, value) in &headers { - info!("{}: {}", key, value.to_str().unwrap()); - if key == "x-attestation-token" { return_token = true; } - } + info!("Request headers length = {}", headers.len()); + let return_token = headers.contains_key("x-attestation-token"); // The KID is normally the first byte of the request - let kid : i32 = match body.get(0).copied() { - None => -1, - Some(kid) => kid as i32 + let kid = match body.first().copied() { + None => { + let error_msg = "No key found in request."; + error!("{error_msg}"); + return Ok(warp::http::Response::builder() + .status(500) + .body(Body::from(error_msg.as_bytes()))); + } + Some(kid) => kid, }; + let maa_url = args.maa_url.clone().unwrap_or(DEFAULT_MAA_URL.to_string()); + let kms_url = args.kms_url.clone().unwrap_or(DEFAULT_KMS_URL.to_string()); - let ohttp = - if args.local_key && kid != 0 { - info!("Ignoring non-0 KID {kid} with local keying configuration"); - None - } else { - match import_config(&maa_url, &kms_url, kid).await { - Ok((config,token)) => match OhttpServer::new(config) { - Ok(ohttp) => Some((ohttp,token)), - _ => None - }, - Err(e) => { - info!("Failed to load KID {kid} from KMS: {e}"); - None - } - } - }; - - match ohttp { - None => { - cache.insert(kid, CachedKey::SKRError(std::time::SystemTime::now())).await; - Ok(warp::http::Response::builder().status(500) - .body(Body::from(&b"Failed to get or load the OHTTP coniguration from local cache or key management service."[..]))) - }, - - Some((ohttp,token)) => { - let inject_headers = compute_injected_headers(&headers, inject_request_headers); - let reply = generate_reply(&ohttp, inject_headers, &body[..], target, mode).await; - - match reply { - Ok((response, server_response)) => { - let mut builder = - warp::http::Response::builder().header("Content-Type", "message/ohttp-chunked-res"); - + let (ohttp, token) = match load_config(&maa_url, &kms_url, kid).await { + Err(e) => { + let error_msg = format!("Failed to get or load OHTTP configuration: {e}"); + error!(error_msg); + cache.insert(kid, CachedKey::SKRError(std::time::SystemTime::now())).await; - // Add HTTP header with MAA token, for client auditing. - if return_token { - builder = builder.header(HeaderName::from_static("x-attestation-token"), token.clone()); - } + return Ok(warp::http::Response::builder() + .status(500) + .body(Body::from(error_msg.as_bytes()))); + } + Ok((config, token)) => match OhttpServer::new(config) { + Ok(server) => (server, token), + Err(e) => { + let error_msg = "Failed to create OHTTP server from config."; + error!("{error_msg} {e}"); + return Ok(warp::http::Response::builder() + .status(500) + .body(Body::from(error_msg.as_bytes()))); + } + }, + }; - // Move headers from the inner response into the outer response - info!("Response headers:"); - for (key, value) in response.headers() { - if !FILTERED_RESPONSE_HEADERS - .iter() - .any(|h| h.eq_ignore_ascii_case(key.as_str())) - { - info!( - "{}: {}", - key, - std::str::from_utf8(value.as_bytes()).unwrap() - ); - builder = builder.header(key.as_str(), value.as_bytes()); - } - } + let inject_request_headers = args.inject_request_headers.clone(); + info!( + "Request inject headers length = {}", + inject_request_headers.len() + ); + for key in &inject_request_headers { + info!(" {}", key); + } - let stream = Box::pin(unfold(response, |mut response| async move { - match response.chunk().await { - Ok(Some(chunk)) => { - Some((Ok::, ohttp::Error>(chunk.to_vec()), response)) - } - _ => None, - } - })); + let inject_headers = compute_injected_headers(&headers, inject_request_headers); + info!("Injected headers length = {}", inject_headers.len()); + for (key, value) in &inject_headers { + info!(" {}: {}", key, value.to_str().unwrap()); + } - let stream = server_response.encapsulate_stream(stream); - Ok(builder.body(Body::wrap_stream(stream))) - } + let target_path = headers.get("enginetarget"); + let mode = args.mode(); + let (response, server_response) = + match generate_reply(&ohttp, inject_headers, &body[..], target, target_path, mode).await { + Ok(s) => s, Err(e) => { - error!("400 {}", e.to_string()); + error!(e); + if let Ok(oe) = e.downcast::<::ohttp::Error>() { - Ok(warp::http::Response::builder() + return Ok(warp::http::Response::builder() .status(422) - .body(Body::from(format!("Error: {oe:?}")))) - } else { - Ok(warp::http::Response::builder() - .status(400) - .body(Body::from(&b"Request error"[..]))) + .body(Body::from(format!("Error: {oe:?}")))); } + + let error_msg = "Request error."; + error!("{error_msg}"); + return Ok(warp::http::Response::builder() + .status(400) + .body(Body::from(error_msg.as_bytes()))); } + }; + + let mut builder = + warp::http::Response::builder().header("Content-Type", "message/ohttp-chunked-res"); + + // Add HTTP header with MAA token, for client auditing. + if return_token { + builder = builder.header( + HeaderName::from_static("x-attestation-token"), + token.clone(), + ); + } + + // Move headers from the inner response into the outer response + info!("Response headers:"); + for (key, value) in response.headers() { + if !FILTERED_RESPONSE_HEADERS + .iter() + .any(|h| h.eq_ignore_ascii_case(key.as_str())) + { + info!( + " {}: {}", + key, + std::str::from_utf8(value.as_bytes()).unwrap() + ); + builder = builder.header(key.as_str(), value.as_bytes()); } - } } + + let stream = Box::pin(unfold(response, |mut response| async move { + match response.chunk().await { + Ok(Some(chunk)) => Some((Ok::, ohttp::Error>(chunk.to_vec()), response)), + _ => None, + } + })); + + let stream = server_response.encapsulate_stream(stream); + Ok(builder.body(Body::wrap_stream(stream))) } async fn discover(args: Arc) -> Result { @@ -454,22 +505,13 @@ async fn discover(args: Arc) -> Result - Ok(warp::http::Response::builder() - .status(500) - .body(Body::from(&b"KID 0 missing from cache (should be impossible with local keying)"[..]))), - - Ok((config, _)) => - match KeyConfig::encode_list(&[config]) { - Err(_e) => - Ok(warp::http::Response::builder() - .status(500) - .body(Body::from(&b"Invalid key configuration (check KeyConfig written to initial cache)"[..]))), - + match load_config(maa_url, kms_url, 0).await { + Ok((config, _)) => match KeyConfig::encode_list(&[config]) { Ok(list) => { let hex = hex::encode(list); trace!("Discover config: {}", hex); @@ -478,50 +520,71 @@ async fn discover(args: Arc) -> Result { + error!("{e}"); + Ok(warp::http::Response::builder().status(500).body(Body::from( + &b"Invalid key configuration (check KeyConfig written to initial cache)"[..], + ))) + } + }, + Err(e) => { + error!(e); + Ok(warp::http::Response::builder().status(500).body(Body::from( + &b"KID 0 missing from cache (should be impossible with local keying)"[..], + ))) } } } #[tokio::main] async fn main() -> Res<()> { - let args = Args::parse(); - let is_local = args.local_key; - let address = args.address.clone(); + // Build a simple subscriber that outputs to stdout + let subscriber = FmtSubscriber::builder() + .with_env_filter(EnvFilter::from_default_env()) + .with_file(true) + .with_line_number(true) + .with_thread_ids(true) + .with_span_events(FmtSpan::NEW) + .json() + .finish(); - let argsc = Arc::new(args); - let args1 = Arc::clone(&argsc); - let args2 = Arc::clone(&argsc); + // Set the subscriber as global default + tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); ::ohttp::init(); + let args = Args::parse(); + let address = args.address; + // Generate a fresh key for local testing. KID is set to 0. - if is_local { + if args.local_key { let config = KeyConfig::new( 0, Kem::P384Sha384, vec![ + SymmetricSuite::new(Kdf::HkdfSha384, Aead::Aes256Gcm), SymmetricSuite::new(Kdf::HkdfSha256, Aead::Aes128Gcm), SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305), - ])?; + ], + ) + .map_err(|e| { + error!("{e}"); + e + })?; cache.insert(0, CachedKey::ValidKey(config, "".to_owned())).await; } - // Build a simple subscriber that outputs to stdout - let subscriber = FmtSubscriber::builder() - .with_max_level(tracing::Level::INFO) - .json() - .finish(); - - // Set the subscriber as global default - tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); - + let argsc = Arc::new(args); + let args1 = Arc::clone(&argsc); let score = warp::post() .and(warp::path::path("score")) .and(warp::path::end()) .and(warp::header::headers_cloned()) .and(warp::body::bytes()) .and(warp::any().map(move || Arc::clone(&args1))) + .and(warp::any().map(Uuid::new_v4)) .and_then(score); + let args2 = Arc::clone(&argsc); let discover = warp::get() .and(warp::path("discover")) .and(warp::path::end()) From b1be7c011ebc2dd9411bfb61d764bbf6307bc182 Mon Sep 17 00:00:00 2001 From: Antoine Delignat-Lavaud Date: Tue, 5 Nov 2024 21:43:49 +0000 Subject: [PATCH 6/6] Merge conflict + type errors --- ohttp-server/src/main.rs | 50 +++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 19 deletions(-) diff --git a/ohttp-server/src/main.rs b/ohttp-server/src/main.rs index 6f0304d..3c0c547 100644 --- a/ohttp-server/src/main.rs +++ b/ohttp-server/src/main.rs @@ -113,11 +113,11 @@ lazy_static! { .build()); } -fn parse_cbor_key(key: &str, kid: u8) -> Res<(Option>, u8)> { - let cwk = hex::decode(key)?; - let cwk_map: Value = serde_cbor::from_slice(&cwk)?; +fn parse_cbor_key(key: &[u8], kid: u8) -> Res<(Option>, u8)> { + let cwk_map: Value = serde_cbor::from_slice(&key)?; let mut d = None; let mut returned_kid: u8 = 0; + if let Value::Map(map) = cwk_map { for (key, value) in map { if let Value::Integer(key) = key { @@ -247,7 +247,6 @@ async fn load_config(maa: &str, kms: &str, kid: u8) -> Res<(KeyConfig, String)> Ok(cli) => cli, _ => Err(Box::new(ServerError::AttestationLibraryInit))? }; - Ok((d, returned_kid)) let t = attest_cli.attest("{}".as_bytes(), 0xff, maa)?; let token = String::from_utf8(t).unwrap(); @@ -376,6 +375,18 @@ fn compute_injected_headers(headers: &HeaderMap, keys: Vec) -> HeaderMap result } +// Serialize Box as it lacks `Send` trait +async fn load_config_safe(maa: &str, kms: &str, kid: u8) -> Result<(KeyConfig, String), String> { + match load_config(maa, kms, kid).await { + Ok(r) => Ok(r), + Err(e) => { + let err = format!("Error loading OHTTP key configuration {kid}: {e:?}"); + error!("{err}"); + Err(err) + } + } +} + #[instrument(skip(headers, body, args), fields(version = %VERSION))] async fn score( headers: warp::hyper::HeaderMap, @@ -400,29 +411,30 @@ async fn score( } Some(kid) => kid, }; + let maa_url = args.maa_url.clone().unwrap_or(DEFAULT_MAA_URL.to_string()); let kms_url = args.kms_url.clone().unwrap_or(DEFAULT_KMS_URL.to_string()); - let (ohttp, token) = match load_config(&maa_url, &kms_url, kid).await { - Err(e) => { - let error_msg = format!("Failed to get or load OHTTP configuration: {e}"); - error!(error_msg); + let (config, token) = match load_config_safe(&maa_url, &kms_url, kid).await { + Ok((config, token)) => (config, token), + Err(_e) => { cache.insert(kid, CachedKey::SKRError(std::time::SystemTime::now())).await; + let error_msg = "Failed to load the requested OHTTP key identifier."; + return Ok(warp::http::Response::builder() + .status(500) + .body(Body::from(error_msg.as_bytes()))); + } + }; + let ohttp = match OhttpServer::new(config) { + Ok(server) => server, + Err(e) => { + let error_msg = "Failed to create OHTTP server from config."; + error!("{error_msg} {e}"); return Ok(warp::http::Response::builder() .status(500) .body(Body::from(error_msg.as_bytes()))); } - Ok((config, token)) => match OhttpServer::new(config) { - Ok(server) => (server, token), - Err(e) => { - let error_msg = "Failed to create OHTTP server from config."; - error!("{error_msg} {e}"); - return Ok(warp::http::Response::builder() - .status(500) - .body(Body::from(error_msg.as_bytes()))); - } - }, }; let inject_request_headers = args.inject_request_headers.clone(); @@ -571,7 +583,7 @@ async fn main() -> Res<()> { error!("{e}"); e })?; - + cache.insert(0, CachedKey::ValidKey(config, "".to_owned())).await; }