Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions src/migtd/src/migration/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,15 @@
// SPDX-License-Identifier: BSD-2-Clause-Patent

use crate::driver::vmcall_raw::panic_with_guest_crash_reg_report;
use alloc::collections::BTreeMap;
use core::sync::atomic::{AtomicBool, Ordering};
use lazy_static::lazy_static;
use spin::Mutex;
use td_payload::arch::apic::*;
use td_payload::arch::idt::{register_interrupt_callback, InterruptCallback, InterruptStack};

pub const VMCALL_SERVICE_VECTOR: u8 = 0x50;
pub static VMCALL_SERVICE_FLAG: AtomicBool = AtomicBool::new(false);

lazy_static! {
pub static ref VMCALL_MIG_REPORTSTATUS_FLAGS: Mutex<BTreeMap<u64, AtomicBool>> =
Mutex::new(BTreeMap::new());
}

fn vmcall_service_callback(_stack: &mut InterruptStack) {
VMCALL_SERVICE_FLAG.store(true, Ordering::SeqCst);

for (_key, flag) in VMCALL_MIG_REPORTSTATUS_FLAGS.lock().iter() {
flag.store(true, Ordering::SeqCst);
}
}

pub fn register_callback() {
Expand Down
89 changes: 22 additions & 67 deletions src/migtd/src/migration/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
//
// SPDX-License-Identifier: BSD-2-Clause-Patent

#[cfg(feature = "vmcall-raw")]
use crate::migration::event::VMCALL_MIG_REPORTSTATUS_FLAGS;
#[cfg(feature = "policy_v2")]
use crate::migration::pre_session_data::pre_session_data_exchange;
#[cfg(all(feature = "vmcall-raw", feature = "policy_v2"))]
Expand All @@ -14,8 +12,7 @@ use crate::migration::transport::TransportType;
#[cfg(feature = "policy_v2")]
use alloc::boxed::Box;
use alloc::collections::BTreeSet;
#[cfg(feature = "vmcall-raw")]
use core::sync::atomic::AtomicBool;

#[cfg(any(feature = "vmcall-interrupt", feature = "vmcall-raw"))]
use core::sync::atomic::Ordering;
use core::time::Duration;
Expand Down Expand Up @@ -243,6 +240,19 @@ fn calculate_shared_page_nums(reqbufferhdrlen: usize) -> Result<usize> {
Ok((total_size + PAGE_SIZE - 1) / PAGE_SIZE)
}

#[cfg(feature = "vmcall-raw")]
fn try_accept_request(
mig_request_id: u64,
response: WaitForRequestResponse,
) -> Poll<Result<WaitForRequestResponse>> {
let inserted = REQUESTS.lock().insert(mig_request_id);
if inserted {
Poll::Ready(Ok(response))
} else {
Poll::Pending
}
}

#[cfg(feature = "vmcall-raw")]
pub async fn wait_for_request() -> Result<WaitForRequestResponse> {
let mut reqbufferhdr = RequestDataBufferHeader {
Expand Down Expand Up @@ -314,10 +324,6 @@ pub async fn wait_for_request() -> Result<WaitForRequestResponse> {
let slice = &data_buffer[reqbufferhdrlen..reqbufferhdrlen + data_length as usize];
let mig_request_id = u64::from_le_bytes(slice[0..8].try_into().unwrap());

VMCALL_MIG_REPORTSTATUS_FLAGS
.lock()
.insert(mig_request_id, AtomicBool::new(false));

let wfr_info = MigtdMigrationInformation {
mig_request_id,
migration_source: slice[8],
Expand All @@ -328,26 +334,13 @@ pub async fn wait_for_request() -> Result<WaitForRequestResponse> {

let wfr_info = MigrationInformation { mig_info: wfr_info };

if REQUESTS.lock().contains(&mig_request_id) {
Poll::Pending
} else {
REQUESTS.lock().insert(mig_request_id);
Poll::Ready(Ok(WaitForRequestResponse::StartMigration(wfr_info)))
}
try_accept_request(mig_request_id, WaitForRequestResponse::StartMigration(wfr_info))
} else if operation == DataStatusOperation::StartRebinding as u8 {
#[cfg(all(feature = "vmcall-raw", feature = "policy_v2"))]
match RebindingInfo::read_from_bytes(&data_buffer[reqbufferhdrlen..]) {
Some(rebinding_info) => {
VMCALL_MIG_REPORTSTATUS_FLAGS
.lock()
.insert(rebinding_info.mig_request_id, AtomicBool::new(false));

if REQUESTS.lock().contains(&rebinding_info.mig_request_id) {
Poll::Pending
} else {
REQUESTS.lock().insert(rebinding_info.mig_request_id);
Poll::Ready(Ok(WaitForRequestResponse::StartRebinding(rebinding_info)))
}
let req_id = rebinding_info.mig_request_id;
try_accept_request(req_id, WaitForRequestResponse::StartRebinding(rebinding_info))
}
None => {
if data_length >= size_of::<u64>() as u32 {
Expand Down Expand Up @@ -390,21 +383,12 @@ pub async fn wait_for_request() -> Result<WaitForRequestResponse> {
reportdata = slice[8..72].try_into().unwrap();
}

VMCALL_MIG_REPORTSTATUS_FLAGS
.lock()
.insert(mig_request_id, AtomicBool::new(false));

let wfr_info = ReportInfo {
mig_request_id,
reportdata,
};

if REQUESTS.lock().contains(&mig_request_id) {
Poll::Pending
} else {
REQUESTS.lock().insert(mig_request_id);
Poll::Ready(Ok(WaitForRequestResponse::GetTdReport(wfr_info)))
}
try_accept_request(mig_request_id, WaitForRequestResponse::GetTdReport(wfr_info))
} else if operation == DataStatusOperation::EnableLogArea as u8 {
let expected_datalength = size_of::<EnableLogAreaInfo>();
if data_length != expected_datalength as u32 {
Expand All @@ -421,22 +405,13 @@ pub async fn wait_for_request() -> Result<WaitForRequestResponse> {
let slice = &data_buffer[reqbufferhdrlen..reqbufferhdrlen + data_length as usize];
let mig_request_id = u64::from_le_bytes(slice[0..8].try_into().unwrap());

VMCALL_MIG_REPORTSTATUS_FLAGS
.lock()
.insert(mig_request_id, AtomicBool::new(false));

let wfr_info = EnableLogAreaInfo {
mig_request_id,
log_max_level: slice[8],
reserved: slice[9..16].try_into().unwrap(),
};

if REQUESTS.lock().contains(&mig_request_id) {
Poll::Pending
} else {
REQUESTS.lock().insert(mig_request_id);
Poll::Ready(Ok(WaitForRequestResponse::EnableLogArea(wfr_info)))
}
try_accept_request(mig_request_id, WaitForRequestResponse::EnableLogArea(wfr_info))
} else if operation == DataStatusOperation::GetMigtdData as u8 {
#[cfg(all(feature = "vmcall-raw", feature = "policy_v2"))]
{
Expand All @@ -461,20 +436,12 @@ pub async fn wait_for_request() -> Result<WaitForRequestResponse> {
reportdata = slice[8..72].try_into().unwrap();
}

VMCALL_MIG_REPORTSTATUS_FLAGS
.lock()
.insert(mig_request_id, AtomicBool::new(false));

let wfr_info = MigtdDataInfo {
mig_request_id,
reportdata,
};
if REQUESTS.lock().contains(&mig_request_id) {
Poll::Pending
} else {
REQUESTS.lock().insert(mig_request_id);
Poll::Ready(Ok(WaitForRequestResponse::GetMigtdData(wfr_info)))
}

try_accept_request(mig_request_id, WaitForRequestResponse::GetMigtdData(wfr_info))
}
#[cfg(not(all(feature = "vmcall-raw", feature = "policy_v2")))]
{
Expand Down Expand Up @@ -720,25 +687,13 @@ pub async fn report_status(status: u8, request_id: u64, data: &Vec<u8>) -> Resul
})?;

poll_fn(|_cx| -> Poll<Result<()>> {
if let Some(flag) = VMCALL_MIG_REPORTSTATUS_FLAGS.lock().get(&request_id) {
if flag.load(Ordering::SeqCst) {
flag.store(false, Ordering::SeqCst);
} else {
return Poll::Pending;
}
} else {
return Poll::Pending;
}

reqbufferhdr = process_buffer(data_buffer);
let data_status_bytes = &reqbufferhdr.datastatus.to_le_bytes();
if data_status_bytes[0] != TDX_VMCALL_VMM_SUCCESS {
log::error!(migration_request_id = request_id; "report_status: data_status byte[0] failure\n");
log::info!(migration_request_id = request_id; "report_status: Pending confirmation\n");
return Poll::Pending;
}

VMCALL_MIG_REPORTSTATUS_FLAGS.lock().remove(&request_id);

Poll::Ready(Ok(()))
})
.await
Expand Down
Loading