diff --git a/src/migtd/src/migration/event.rs b/src/migtd/src/migration/event.rs index 796c287f..d8e5cda3 100644 --- a/src/migtd/src/migration/event.rs +++ b/src/migtd/src/migration/event.rs @@ -3,27 +3,15 @@ // SPDX-License-Identifier: BSD-2-Clause-Patent use crate::driver::vmcall_raw::panic_with_guest_crash_reg_report; -use alloc::collections::BTreeMap; use core::sync::atomic::{AtomicBool, Ordering}; -use lazy_static::lazy_static; -use spin::Mutex; use td_payload::arch::apic::*; use td_payload::arch::idt::{register_interrupt_callback, InterruptCallback, InterruptStack}; pub const VMCALL_SERVICE_VECTOR: u8 = 0x50; pub static VMCALL_SERVICE_FLAG: AtomicBool = AtomicBool::new(false); -lazy_static! { - pub static ref VMCALL_MIG_REPORTSTATUS_FLAGS: Mutex> = - Mutex::new(BTreeMap::new()); -} - fn vmcall_service_callback(_stack: &mut InterruptStack) { VMCALL_SERVICE_FLAG.store(true, Ordering::SeqCst); - - for (_key, flag) in VMCALL_MIG_REPORTSTATUS_FLAGS.lock().iter() { - flag.store(true, Ordering::SeqCst); - } } pub fn register_callback() { diff --git a/src/migtd/src/migration/session.rs b/src/migtd/src/migration/session.rs index 18103512..ccd157c8 100644 --- a/src/migtd/src/migration/session.rs +++ b/src/migtd/src/migration/session.rs @@ -2,8 +2,6 @@ // // SPDX-License-Identifier: BSD-2-Clause-Patent -#[cfg(feature = "vmcall-raw")] -use crate::migration::event::VMCALL_MIG_REPORTSTATUS_FLAGS; #[cfg(feature = "policy_v2")] use crate::migration::pre_session_data::pre_session_data_exchange; #[cfg(all(feature = "vmcall-raw", feature = "policy_v2"))] @@ -14,8 +12,7 @@ use crate::migration::transport::TransportType; #[cfg(feature = "policy_v2")] use alloc::boxed::Box; use alloc::collections::BTreeSet; -#[cfg(feature = "vmcall-raw")] -use core::sync::atomic::AtomicBool; + #[cfg(any(feature = "vmcall-interrupt", feature = "vmcall-raw"))] use core::sync::atomic::Ordering; use core::time::Duration; @@ -243,6 +240,19 @@ fn calculate_shared_page_nums(reqbufferhdrlen: usize) -> Result { Ok((total_size + PAGE_SIZE - 1) / PAGE_SIZE) } +#[cfg(feature = "vmcall-raw")] +fn try_accept_request( + mig_request_id: u64, + response: WaitForRequestResponse, +) -> Poll> { + let inserted = REQUESTS.lock().insert(mig_request_id); + if inserted { + Poll::Ready(Ok(response)) + } else { + Poll::Pending + } +} + #[cfg(feature = "vmcall-raw")] pub async fn wait_for_request() -> Result { let mut reqbufferhdr = RequestDataBufferHeader { @@ -314,10 +324,6 @@ pub async fn wait_for_request() -> Result { let slice = &data_buffer[reqbufferhdrlen..reqbufferhdrlen + data_length as usize]; let mig_request_id = u64::from_le_bytes(slice[0..8].try_into().unwrap()); - VMCALL_MIG_REPORTSTATUS_FLAGS - .lock() - .insert(mig_request_id, AtomicBool::new(false)); - let wfr_info = MigtdMigrationInformation { mig_request_id, migration_source: slice[8], @@ -328,26 +334,13 @@ pub async fn wait_for_request() -> Result { let wfr_info = MigrationInformation { mig_info: wfr_info }; - if REQUESTS.lock().contains(&mig_request_id) { - Poll::Pending - } else { - REQUESTS.lock().insert(mig_request_id); - Poll::Ready(Ok(WaitForRequestResponse::StartMigration(wfr_info))) - } + try_accept_request(mig_request_id, WaitForRequestResponse::StartMigration(wfr_info)) } else if operation == DataStatusOperation::StartRebinding as u8 { #[cfg(all(feature = "vmcall-raw", feature = "policy_v2"))] match RebindingInfo::read_from_bytes(&data_buffer[reqbufferhdrlen..]) { Some(rebinding_info) => { - VMCALL_MIG_REPORTSTATUS_FLAGS - .lock() - .insert(rebinding_info.mig_request_id, AtomicBool::new(false)); - - if REQUESTS.lock().contains(&rebinding_info.mig_request_id) { - Poll::Pending - } else { - REQUESTS.lock().insert(rebinding_info.mig_request_id); - Poll::Ready(Ok(WaitForRequestResponse::StartRebinding(rebinding_info))) - } + let req_id = rebinding_info.mig_request_id; + try_accept_request(req_id, WaitForRequestResponse::StartRebinding(rebinding_info)) } None => { if data_length >= size_of::() as u32 { @@ -390,21 +383,12 @@ pub async fn wait_for_request() -> Result { reportdata = slice[8..72].try_into().unwrap(); } - VMCALL_MIG_REPORTSTATUS_FLAGS - .lock() - .insert(mig_request_id, AtomicBool::new(false)); - let wfr_info = ReportInfo { mig_request_id, reportdata, }; - if REQUESTS.lock().contains(&mig_request_id) { - Poll::Pending - } else { - REQUESTS.lock().insert(mig_request_id); - Poll::Ready(Ok(WaitForRequestResponse::GetTdReport(wfr_info))) - } + try_accept_request(mig_request_id, WaitForRequestResponse::GetTdReport(wfr_info)) } else if operation == DataStatusOperation::EnableLogArea as u8 { let expected_datalength = size_of::(); if data_length != expected_datalength as u32 { @@ -421,22 +405,13 @@ pub async fn wait_for_request() -> Result { let slice = &data_buffer[reqbufferhdrlen..reqbufferhdrlen + data_length as usize]; let mig_request_id = u64::from_le_bytes(slice[0..8].try_into().unwrap()); - VMCALL_MIG_REPORTSTATUS_FLAGS - .lock() - .insert(mig_request_id, AtomicBool::new(false)); - let wfr_info = EnableLogAreaInfo { mig_request_id, log_max_level: slice[8], reserved: slice[9..16].try_into().unwrap(), }; - if REQUESTS.lock().contains(&mig_request_id) { - Poll::Pending - } else { - REQUESTS.lock().insert(mig_request_id); - Poll::Ready(Ok(WaitForRequestResponse::EnableLogArea(wfr_info))) - } + try_accept_request(mig_request_id, WaitForRequestResponse::EnableLogArea(wfr_info)) } else if operation == DataStatusOperation::GetMigtdData as u8 { #[cfg(all(feature = "vmcall-raw", feature = "policy_v2"))] { @@ -461,20 +436,12 @@ pub async fn wait_for_request() -> Result { reportdata = slice[8..72].try_into().unwrap(); } - VMCALL_MIG_REPORTSTATUS_FLAGS - .lock() - .insert(mig_request_id, AtomicBool::new(false)); - let wfr_info = MigtdDataInfo { mig_request_id, reportdata, }; - if REQUESTS.lock().contains(&mig_request_id) { - Poll::Pending - } else { - REQUESTS.lock().insert(mig_request_id); - Poll::Ready(Ok(WaitForRequestResponse::GetMigtdData(wfr_info))) - } + + try_accept_request(mig_request_id, WaitForRequestResponse::GetMigtdData(wfr_info)) } #[cfg(not(all(feature = "vmcall-raw", feature = "policy_v2")))] { @@ -720,25 +687,13 @@ pub async fn report_status(status: u8, request_id: u64, data: &Vec) -> Resul })?; poll_fn(|_cx| -> Poll> { - if let Some(flag) = VMCALL_MIG_REPORTSTATUS_FLAGS.lock().get(&request_id) { - if flag.load(Ordering::SeqCst) { - flag.store(false, Ordering::SeqCst); - } else { - return Poll::Pending; - } - } else { - return Poll::Pending; - } - reqbufferhdr = process_buffer(data_buffer); let data_status_bytes = &reqbufferhdr.datastatus.to_le_bytes(); if data_status_bytes[0] != TDX_VMCALL_VMM_SUCCESS { - log::error!(migration_request_id = request_id; "report_status: data_status byte[0] failure\n"); + log::info!(migration_request_id = request_id; "report_status: Pending confirmation\n"); return Poll::Pending; } - VMCALL_MIG_REPORTSTATUS_FLAGS.lock().remove(&request_id); - Poll::Ready(Ok(())) }) .await