From 64510a1f18c95a1ef841e0b5a2df9908ed7701d0 Mon Sep 17 00:00:00 2001
From: Haitao Huang
Date: Sat, 21 Feb 2026 19:49:10 -0800
Subject: [PATCH] MigTD: Remove lock for REPORTSTATUS flags
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The interrupt handler (vmcall_service_callback) previously held a
spin::Mutex on VMCALL_MIG_REPORTSTATUS_FLAGS (BTreeMap). On the
single-threaded TDX guest, if an interrupt fires while non-ISR code
holds the same lock, the ISR spins forever — guaranteed deadlock.

Replace the per-request BTreeMap + Mutex with direct buffer polling:
each report_status poll_fn checks its own private shared-memory
buffer's datastatus field on every poll cycle. This is correct because:

- The buffer is the ground truth for VMM completion
- Any interrupt (WFR or otherwise) wakes the CPU from HLT, triggering
  poll_tasks() which re-polls all pending futures
- Each consumer checks its own buffer — no shared state, no
  interference between concurrent report_status calls

The ISR now only sets VMCALL_SERVICE_FLAG (for wait_for_request).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Haitao Huang --- src/migtd/src/migration/event.rs | 12 ---- src/migtd/src/migration/session.rs | 89 ++++++++---------------------- 2 files changed, 22 insertions(+), 79 deletions(-) diff --git a/src/migtd/src/migration/event.rs b/src/migtd/src/migration/event.rs index 796c287f..d8e5cda3 100644 --- a/src/migtd/src/migration/event.rs +++ b/src/migtd/src/migration/event.rs @@ -3,27 +3,15 @@ // SPDX-License-Identifier: BSD-2-Clause-Patent use crate::driver::vmcall_raw::panic_with_guest_crash_reg_report; -use alloc::collections::BTreeMap; use core::sync::atomic::{AtomicBool, Ordering}; -use lazy_static::lazy_static; -use spin::Mutex; use td_payload::arch::apic::*; use td_payload::arch::idt::{register_interrupt_callback, InterruptCallback, InterruptStack}; pub const VMCALL_SERVICE_VECTOR: u8 = 0x50; pub static VMCALL_SERVICE_FLAG: AtomicBool = AtomicBool::new(false); -lazy_static! { - pub static ref VMCALL_MIG_REPORTSTATUS_FLAGS: Mutex> = - Mutex::new(BTreeMap::new()); -} - fn vmcall_service_callback(_stack: &mut InterruptStack) { VMCALL_SERVICE_FLAG.store(true, Ordering::SeqCst); - - for (_key, flag) in VMCALL_MIG_REPORTSTATUS_FLAGS.lock().iter() { - flag.store(true, Ordering::SeqCst); - } } pub fn register_callback() { diff --git a/src/migtd/src/migration/session.rs b/src/migtd/src/migration/session.rs index 18103512..ccd157c8 100644 --- a/src/migtd/src/migration/session.rs +++ b/src/migtd/src/migration/session.rs @@ -2,8 +2,6 @@ // // SPDX-License-Identifier: BSD-2-Clause-Patent -#[cfg(feature = "vmcall-raw")] -use crate::migration::event::VMCALL_MIG_REPORTSTATUS_FLAGS; #[cfg(feature = "policy_v2")] use crate::migration::pre_session_data::pre_session_data_exchange; #[cfg(all(feature = "vmcall-raw", feature = "policy_v2"))] @@ -14,8 +12,7 @@ use crate::migration::transport::TransportType; #[cfg(feature = "policy_v2")] use alloc::boxed::Box; use alloc::collections::BTreeSet; 
-#[cfg(feature = "vmcall-raw")] -use core::sync::atomic::AtomicBool; + #[cfg(any(feature = "vmcall-interrupt", feature = "vmcall-raw"))] use core::sync::atomic::Ordering; use core::time::Duration; @@ -243,6 +240,19 @@ fn calculate_shared_page_nums(reqbufferhdrlen: usize) -> Result { Ok((total_size + PAGE_SIZE - 1) / PAGE_SIZE) } +#[cfg(feature = "vmcall-raw")] +fn try_accept_request( + mig_request_id: u64, + response: WaitForRequestResponse, +) -> Poll> { + let inserted = REQUESTS.lock().insert(mig_request_id); + if inserted { + Poll::Ready(Ok(response)) + } else { + Poll::Pending + } +} + #[cfg(feature = "vmcall-raw")] pub async fn wait_for_request() -> Result { let mut reqbufferhdr = RequestDataBufferHeader { @@ -314,10 +324,6 @@ pub async fn wait_for_request() -> Result { let slice = &data_buffer[reqbufferhdrlen..reqbufferhdrlen + data_length as usize]; let mig_request_id = u64::from_le_bytes(slice[0..8].try_into().unwrap()); - VMCALL_MIG_REPORTSTATUS_FLAGS - .lock() - .insert(mig_request_id, AtomicBool::new(false)); - let wfr_info = MigtdMigrationInformation { mig_request_id, migration_source: slice[8], @@ -328,26 +334,13 @@ pub async fn wait_for_request() -> Result { let wfr_info = MigrationInformation { mig_info: wfr_info }; - if REQUESTS.lock().contains(&mig_request_id) { - Poll::Pending - } else { - REQUESTS.lock().insert(mig_request_id); - Poll::Ready(Ok(WaitForRequestResponse::StartMigration(wfr_info))) - } + try_accept_request(mig_request_id, WaitForRequestResponse::StartMigration(wfr_info)) } else if operation == DataStatusOperation::StartRebinding as u8 { #[cfg(all(feature = "vmcall-raw", feature = "policy_v2"))] match RebindingInfo::read_from_bytes(&data_buffer[reqbufferhdrlen..]) { Some(rebinding_info) => { - VMCALL_MIG_REPORTSTATUS_FLAGS - .lock() - .insert(rebinding_info.mig_request_id, AtomicBool::new(false)); - - if REQUESTS.lock().contains(&rebinding_info.mig_request_id) { - Poll::Pending - } else { - 
REQUESTS.lock().insert(rebinding_info.mig_request_id); - Poll::Ready(Ok(WaitForRequestResponse::StartRebinding(rebinding_info))) - } + let req_id = rebinding_info.mig_request_id; + try_accept_request(req_id, WaitForRequestResponse::StartRebinding(rebinding_info)) } None => { if data_length >= size_of::() as u32 { @@ -390,21 +383,12 @@ pub async fn wait_for_request() -> Result { reportdata = slice[8..72].try_into().unwrap(); } - VMCALL_MIG_REPORTSTATUS_FLAGS - .lock() - .insert(mig_request_id, AtomicBool::new(false)); - let wfr_info = ReportInfo { mig_request_id, reportdata, }; - if REQUESTS.lock().contains(&mig_request_id) { - Poll::Pending - } else { - REQUESTS.lock().insert(mig_request_id); - Poll::Ready(Ok(WaitForRequestResponse::GetTdReport(wfr_info))) - } + try_accept_request(mig_request_id, WaitForRequestResponse::GetTdReport(wfr_info)) } else if operation == DataStatusOperation::EnableLogArea as u8 { let expected_datalength = size_of::(); if data_length != expected_datalength as u32 { @@ -421,22 +405,13 @@ pub async fn wait_for_request() -> Result { let slice = &data_buffer[reqbufferhdrlen..reqbufferhdrlen + data_length as usize]; let mig_request_id = u64::from_le_bytes(slice[0..8].try_into().unwrap()); - VMCALL_MIG_REPORTSTATUS_FLAGS - .lock() - .insert(mig_request_id, AtomicBool::new(false)); - let wfr_info = EnableLogAreaInfo { mig_request_id, log_max_level: slice[8], reserved: slice[9..16].try_into().unwrap(), }; - if REQUESTS.lock().contains(&mig_request_id) { - Poll::Pending - } else { - REQUESTS.lock().insert(mig_request_id); - Poll::Ready(Ok(WaitForRequestResponse::EnableLogArea(wfr_info))) - } + try_accept_request(mig_request_id, WaitForRequestResponse::EnableLogArea(wfr_info)) } else if operation == DataStatusOperation::GetMigtdData as u8 { #[cfg(all(feature = "vmcall-raw", feature = "policy_v2"))] { @@ -461,20 +436,12 @@ pub async fn wait_for_request() -> Result { reportdata = slice[8..72].try_into().unwrap(); } - VMCALL_MIG_REPORTSTATUS_FLAGS - 
.lock() - .insert(mig_request_id, AtomicBool::new(false)); - let wfr_info = MigtdDataInfo { mig_request_id, reportdata, }; - if REQUESTS.lock().contains(&mig_request_id) { - Poll::Pending - } else { - REQUESTS.lock().insert(mig_request_id); - Poll::Ready(Ok(WaitForRequestResponse::GetMigtdData(wfr_info))) - } + + try_accept_request(mig_request_id, WaitForRequestResponse::GetMigtdData(wfr_info)) } #[cfg(not(all(feature = "vmcall-raw", feature = "policy_v2")))] { @@ -720,25 +687,13 @@ pub async fn report_status(status: u8, request_id: u64, data: &Vec) -> Resul })?; poll_fn(|_cx| -> Poll> { - if let Some(flag) = VMCALL_MIG_REPORTSTATUS_FLAGS.lock().get(&request_id) { - if flag.load(Ordering::SeqCst) { - flag.store(false, Ordering::SeqCst); - } else { - return Poll::Pending; - } - } else { - return Poll::Pending; - } - reqbufferhdr = process_buffer(data_buffer); let data_status_bytes = &reqbufferhdr.datastatus.to_le_bytes(); if data_status_bytes[0] != TDX_VMCALL_VMM_SUCCESS { - log::error!(migration_request_id = request_id; "report_status: data_status byte[0] failure\n"); + log::info!(migration_request_id = request_id; "report_status: Pending confirmation\n"); return Poll::Pending; } - VMCALL_MIG_REPORTSTATUS_FLAGS.lock().remove(&request_id); - Poll::Ready(Ok(())) }) .await