diff --git a/cgpuvm-attest-cli/src/main.rs b/cgpuvm-attest-cli/src/main.rs index 2e02d13..101a063 100644 --- a/cgpuvm-attest-cli/src/main.rs +++ b/cgpuvm-attest-cli/src/main.rs @@ -12,5 +12,6 @@ pub fn main() { let Ok(token) = attest(s.as_bytes(), 0xffff, maa_url) else { panic!("Failed to get MAA token") }; + println!("Got MAA token: {}", String::from_utf8(token).unwrap()); } diff --git a/cgpuvm-attest/src/err.rs b/cgpuvm-attest/src/err.rs index 29d6749..bcfdb6e 100644 --- a/cgpuvm-attest/src/err.rs +++ b/cgpuvm-attest/src/err.rs @@ -2,8 +2,10 @@ use thiserror::Error; #[derive(Error, Debug)] pub enum AttestError { + #[error("Failed to initialize guest attestation library")] + Initialization, #[error("Failed to convert endpoint URL to CString")] Convertion, #[error("CVM guest attestation library returned error: {0}")] - MAAToken(i32), + LibraryError(i32), } diff --git a/cgpuvm-attest/src/lib.rs b/cgpuvm-attest/src/lib.rs index 215a20b..9a5e487 100755 --- a/cgpuvm-attest/src/lib.rs +++ b/cgpuvm-attest/src/lib.rs @@ -1,7 +1,7 @@ pub mod err; use err::AttestError; -use libc::{c_char, c_int, size_t}; +use libc::{c_char, c_int, size_t, c_void}; use std::ffi::CString; type Res = Result>; @@ -15,6 +15,29 @@ extern "C" { jwt_len: *mut size_t, endpoint_url: *const c_char, ) -> c_int; + + fn ga_create( + st: *mut *mut c_void + ) -> c_int; + + fn ga_free( + st: *mut c_void + ); + + fn ga_get_token( + st: *mut c_void, + app_data: *const u8, + pcr: u32, + jwt: *mut u8, + jwt_len: *mut size_t, + endpoint_url: *const c_char + ) -> c_int; + + fn ga_decrypt( + st: *mut c_void, + cipher: *mut u8, + len: *mut size_t + ) -> c_int; } pub fn attest(data: &[u8], pcrs: u32, endpoint_url: &str) -> Res> { @@ -23,7 +46,6 @@ pub fn attest(data: &[u8], pcrs: u32, endpoint_url: &str) -> Res> { let mut dstlen = 32 * 1024; let mut dst = Vec::with_capacity(dstlen); let pdst = dst.as_mut_ptr(); - let url_ptr = endpoint_url_cstring.as_ptr(); let ret = get_attestation_token(data.as_ptr(), pcrs, pdst, &mut dstlen, url_ptr); @@ -31,9 +53,76 @@ pub fn attest(data: &[u8], pcrs: u32, endpoint_url: &str) -> Res> { dst.set_len(dstlen); Ok(dst) } else { - Err(Box::new(AttestError::MAAToken(ret))) + Err(Box::new(AttestError::LibraryError(ret))) } }, _e => Err(Box::new(AttestError::Convertion)), } } + +pub struct AttestationClient { + st: *mut c_void +} + +impl AttestationClient { + pub fn new() -> Res { + let mut c = AttestationClient { st: std::ptr::null_mut() }; + unsafe { + let rc = ga_create(&mut c.st); + + if rc == 0 { + return Ok(c); + } + + return Err(Box::new(AttestError::Initialization)); + } + } + + pub fn attest(&mut self, data: &[u8], pcrs: u32, endpoint_url: &str) -> Res> { + match CString::new(endpoint_url) { + Ok(endpoint_url_cstring) => + unsafe { + let url_ptr = endpoint_url_cstring.as_ptr(); + let mut dstlen = 32 * 1024; + let mut dst = Vec::with_capacity(dstlen); + let pdst = dst.as_mut_ptr(); + let rc = ga_get_token(self.st, data.as_ptr(), pcrs, pdst, &mut dstlen, url_ptr); + + if rc == 0 { + dst.set_len(dstlen); + Ok(dst) + } else { + Err(Box::new(AttestError::LibraryError(rc))) + } + }, + _ => Err(Box::new(AttestError::Convertion)), + } + } + + pub fn decrypt(&mut self, data: &[u8]) -> Res> { + unsafe { + let mut buf = Vec::from(data); + let mut len = data.len(); + let rc = ga_decrypt(self.st, buf.as_mut_ptr(), &mut len); + + if rc == 0 { + buf.set_len(len); + Ok(buf) + } else { + Err(Box::new(AttestError::LibraryError(rc))) + } + } + } +} + +impl Drop for AttestationClient { + fn drop(&mut self) { + unsafe { + ga_free(self.st); + } + } +} + +unsafe impl Send for AttestationClient { + +} \ No newline at end of file diff --git a/ohttp-client/src/main.rs b/ohttp-client/src/main.rs index b3d5005..713c1d5 100644 --- a/ohttp-client/src/main.rs +++ b/ohttp-client/src/main.rs @@ -240,7 +240,11 @@ async fn get_kms_config(kms_url: String, cert: &str) -> Res { loop { // Make the GET request - let response = client.get(url.clone()).send().await?.error_for_status()?; + let response = client + .get(url.clone()) + .send() + .await? + .error_for_status()?; // We may have to wait for receipt to be ready match response.status().as_u16() { @@ -256,12 +260,12 @@ async fn get_kms_config(kms_url: String, cert: &str) -> Res { } else { Err("Max retries reached, giving up. Cannot reach key management service")?; } - } + }, 200 => { let body = response.text().await?; assert!(!body.is_empty()); - return Ok(body); - } + return Ok(body) + }, e => { Err(format!("KMS returned unexpected {} status code.", e))?; } diff --git a/ohttp-server/Cargo.toml b/ohttp-server/Cargo.toml index e557b72..fae27f5 100644 --- a/ohttp-server/Cargo.toml +++ b/ohttp-server/Cargo.toml @@ -12,6 +12,7 @@ rust-hpke = ["ohttp/rust-hpke"] [dependencies] env_logger = {version = "0.10", default-features = false} hex = "0.4" +base64 = "0.22.1" lazy_static = "1.4" moka = { version = "0.12", features = ["future"] } tokio = { version = "1", features = ["full"] } diff --git a/ohttp-server/src/err.rs b/ohttp-server/src/err.rs index 6a4db12..c9436f5 100644 --- a/ohttp-server/src/err.rs +++ b/ohttp-server/src/err.rs @@ -18,4 +18,10 @@ pub enum ServerError { KMSUnreachable, #[error("Private key missing from SKR response")] PrivateKeyMissing, + #[error("CVM guest attestation library initialization failure")] + AttestationLibraryInit, + #[error("Guest attestation library failed to decrypt HPKE private key")] + TPMDecryptionFailure, + #[error("SKR for the requested KID has failed in the past 60 seconds, waiting to retry.")] + CachedSKRError, } diff --git a/ohttp-server/src/main.rs b/ohttp-server/src/main.rs old mode 100755 new mode 100644 index 65a7aa1..3c0c547 --- a/ohttp-server/src/main.rs +++ b/ohttp-server/src/main.rs @@ -15,6 +15,8 @@ use reqwest::{ use bhttp::{Message, Mode}; use clap::Parser; +use base64::{Engine as _, engine::general_purpose::STANDARD as b64}; + use ohttp::{ hpke::{Aead, Kdf, Kem}, Error, KeyConfig, Server as OhttpServer, ServerResponse, SymmetricSuite, @@ -23,7 +25,7 @@ use warp::{hyper::Body, Filter}; use tokio::time::{sleep, Duration}; -use cgpuvm_attest::attest; +use cgpuvm_attest::AttestationClient; use reqwest::Client; type Res = Result>; @@ -93,19 +95,29 @@ impl Args { } } +// We cache both successful key releases from the KMS as well as SKR errors, +// as guest attestation is very expensive (IMDS + TPM createPrimary + RSA decrypt x2) +// ValidKey expire based on the TTL of the cache (24 hours) +// SKRError are manually invalidated (see import_config), after 60 seconds +#[derive(Clone)] +enum CachedKey { + ValidKey(KeyConfig, String), + SKRError(std::time::SystemTime) +} + +// Lazily initialized shared globals lazy_static! { - static ref cache: Arc> = Arc::new( - Cache::builder() - .time_to_live(Duration::from_secs(24 * 60 * 60)) - .build() - ); + // Key cache for Oblivious HTTP, by key id + static ref cache : Arc> = Arc::new(Cache::builder() + .time_to_live(Duration::from_secs(24 * 60 * 60)) + .build()); } -fn parse_cbor_key(key: &str, kid: u8) -> Res<(Option>, u8)> { - let cwk = hex::decode(key)?; - let cwk_map: Value = serde_cbor::from_slice(&cwk)?; +fn parse_cbor_key(key: &[u8], kid: u8) -> Res<(Option>, u8)> { + let cwk_map: Value = serde_cbor::from_slice(&key)?; let mut d = None; let mut returned_kid: u8 = 0; + if let Value::Map(map) = cwk_map { for (key, value) in map { if let Value::Integer(key) = key { @@ -154,18 +166,6 @@ fn parse_cbor_key(key: &str, kid: u8) -> Res<(Option>, u8)> { Ok((d, returned_kid)) } -/// Fetches the MAA token from the CVM guest attestation library. -/// -fn fetch_maa_token(maa: &str) -> Res { - // Get MAA token from CVM guest attestation library - info!("Fetching MAA token from {maa}"); - let token = attest("{}".as_bytes(), 0xffff, maa)?; - - let token = String::from_utf8(token).unwrap(); - trace!("{token}"); - Ok(token) -} - /// Retrieves the HPKE private key from Azure KMS. /// async fn get_hpke_private_key_from_kms(kms: &str, kid: u8, token: &str) -> Res { @@ -178,7 +178,7 @@ async fn get_hpke_private_key_from_kms(kms: &str, kid: u8, token: &str) -> Res Res Res Res<(KeyConfig, String)> { // Check if the key configuration is in cache - if let Some((config, token)) = cache.get(&kid).await { - info!("Found OHTTP configuration for KID {kid} in cache."); - return Ok((config, token)); + if let Some(entry) = cache.get(&kid).await { + match entry { + CachedKey::ValidKey(config, token) => { + info!("Found OHTTP configuration for KID {kid} in cache."); + return Ok((config,token)); + }, + CachedKey::SKRError(ts) => { + if ts.elapsed()? > Duration::from_secs(60) { + cache.invalidate(&kid).await; + } else { + Err(Box::new(ServerError::CachedSKRError))? + } + } + } } - // Get MAA token from CVM guest attestation library - let token = fetch_maa_token(maa)?; + let mut attest_cli = match AttestationClient::new() { + Ok(cli) => cli, + _ => Err(Box::new(ServerError::AttestationLibraryInit))? + }; + + let t = attest_cli.attest("{}".as_bytes(), 0xff, maa)?; + let token = String::from_utf8(t).unwrap(); + info!("Fetched MAA token: {token}"); + + // The KMS returns the base64-encoded, RSA2048-OAEP-SHA256 encrypted CBOR key let key = get_hpke_private_key_from_kms(kms, kid, &token).await?; - let (d, returned_kid) = parse_cbor_key(&key, kid)?; + let enc_key = b64.decode(&key)?; + + let decrypted_key = match attest_cli.decrypt(enc_key.as_slice()) { + Ok(k) => k, + _ => Err(Box::new(ServerError::TPMDecryptionFailure))? + }; + + let (d, returned_kid) = parse_cbor_key(&decrypted_key, kid)?; let sk = match d { Some(key) => ::PrivateKey::from_bytes(&key), @@ -258,7 +281,7 @@ async fn load_config(maa: &str, kms: &str, kid: u8) -> Res<(KeyConfig, String)> ], )?; - cache.insert(kid, (config.clone(), token.clone())).await; + cache.insert(kid, CachedKey::ValidKey(config.clone(), token.clone())).await; Ok((config, token)) } @@ -352,6 +375,18 @@ fn compute_injected_headers(headers: &HeaderMap, keys: Vec) -> HeaderMap result } +// Serialize Box as it lacks `Send` trait +async fn load_config_safe(maa: &str, kms: &str, kid: u8) -> Result<(KeyConfig, String), String> { + match load_config(maa, kms, kid).await { + Ok(r) => Ok(r), + Err(e) => { + let err = format!("Error loading OHTTP key configuration {kid}: {e:?}"); + error!("{err}"); + Err(err) + } + } +} + #[instrument(skip(headers, body, args), fields(version = %VERSION))] async fn score( headers: warp::hyper::HeaderMap, @@ -376,26 +411,30 @@ async fn score( } Some(kid) => kid, }; + let maa_url = args.maa_url.clone().unwrap_or(DEFAULT_MAA_URL.to_string()); let kms_url = args.kms_url.clone().unwrap_or(DEFAULT_KMS_URL.to_string()); - let (ohttp, token) = match load_config(&maa_url, &kms_url, kid).await { + + let (config, token) = match load_config_safe(&maa_url, &kms_url, kid).await { + Ok((config, token)) => (config, token), + Err(_e) => { + cache.insert(kid, CachedKey::SKRError(std::time::SystemTime::now())).await; + let error_msg = "Failed to load the requested OHTTP key identifier."; + return Ok(warp::http::Response::builder() + .status(500) + .body(Body::from(error_msg.as_bytes()))); + } + }; + + let ohttp = match OhttpServer::new(config) { + Ok(server) => server, Err(e) => { - let error_msg = "Failed to get or load OHTTP configuration."; + let error_msg = "Failed to create OHTTP server from config."; error!("{error_msg} {e}"); return Ok(warp::http::Response::builder() .status(500) .body(Body::from(error_msg.as_bytes()))); } - Ok((config, token)) => match OhttpServer::new(config) { - Ok(server) => (server, token), - Err(e) => { - let error_msg = "Failed to create OHTTP server from config."; - error!("{error_msg} {e}"); - return Ok(warp::http::Response::builder() - .status(500) - .body(Body::from(error_msg.as_bytes()))); - } - }, }; let inject_request_headers = args.inject_request_headers.clone(); @@ -544,7 +583,8 @@ async fn main() -> Res<()> { error!("{e}"); e })?; - cache.insert(0, (config, String::new())).await; + + cache.insert(0, CachedKey::ValidKey(config, "".to_owned())).await; } let argsc = Arc::new(args);