diff --git a/src/scheduler.rs b/src/scheduler.rs index c209f83..61a7f2a 100644 --- a/src/scheduler.rs +++ b/src/scheduler.rs @@ -1,7 +1,6 @@ -use std::sync::{ - Arc, - atomic::{AtomicU64, Ordering}, -}; +#[cfg(feature = "multithreaded")] +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; use crate::world::SendSync; use parking_lot::RwLock; @@ -16,34 +15,41 @@ use crate::world::World; #[derive(Copy, Clone)] pub struct SysId(u64); -pub trait SystemFn: Fn(&World) + SendSync {} +pub trait SystemFn: FnMut(&World) + SendSync {} + +impl SystemFn for T {} + +pub type System = (SysId, Option>); + +#[cfg(feature = "multithreaded")] +pub trait ParallelSystemFn: Fn(&World) + SendSync {} -impl SystemFn for T {} +#[cfg(feature = "multithreaded")] +impl ParallelSystemFn for T {} -pub type System = (SysId, Arc); +#[cfg(feature = "multithreaded")] +pub type ParallelSystem = (SysId, Arc); #[derive(Default)] pub struct Scheduler { next_id: AtomicU64, #[cfg(feature = "multithreaded")] - parallel_systems: RwLock>, + parallel_systems: RwLock>, systems: RwLock>, } impl Scheduler { - fn add_to(&self, systems: &RwLock>, system: impl SystemFn) -> SysId { + #[cfg(feature = "multithreaded")] + pub fn register_parallel(&self, system: impl ParallelSystemFn) -> SysId { let id = SysId(self.next_id.fetch_add(1, Ordering::Relaxed)); - systems.write().push((id, Arc::new(system))); + self.parallel_systems.write().push((id, Arc::new(system))); id } - #[cfg(feature = "multithreaded")] - pub fn register_parallel(&self, system: impl SystemFn) -> SysId { - self.add_to(&self.parallel_systems, system) - } - pub fn register(&self, system: impl SystemFn) -> SysId { - self.add_to(&self.systems, system) + let id = SysId(self.next_id.fetch_add(1, Ordering::Relaxed)); + self.systems.write().push((id, Some(Box::new(system)))); + id } pub fn deregister(&self, system: SysId) { @@ -81,8 +87,19 @@ impl Scheduler { .for_each(|sys| sys(world)); let len = self.systems.read().len(); - (0..len) - .filter_map(|i| Some(self.systems.read().get(i)?.1.clone())) - .for_each(|sys| sys(world)); + for i in 0..len { + let mut guard = self.systems.write(); + let Some((_, sys)) = guard.get_mut(i) else { + break; + }; + let mut sys = sys.take().unwrap(); + drop(guard); + sys(world); + let mut guard = self.systems.write(); + let Some((_, entry)) = guard.get_mut(i) else { + break; + }; + *entry = Some(sys); + } } } diff --git a/src/world.rs b/src/world.rs index f100828..bb5b6bf 100644 --- a/src/world.rs +++ b/src/world.rs @@ -11,6 +11,8 @@ use std::{ sync::atomic::{AtomicU64, Ordering}, }; +#[cfg(feature = "multithreaded")] +use crate::scheduler::ParallelSystemFn; use crate::{ components::AttachComponents, query::Query, @@ -238,7 +240,7 @@ impl World { /// Add a system that will run in parallel on threads with all /// other parallel systems. #[cfg(feature = "multithreaded")] - pub fn add_parallel_system(&self, system: impl SystemFn) { + pub fn add_parallel_system(&self, system: impl ParallelSystemFn) { self.scheduler.register_parallel(system); } diff --git a/tests/system.rs b/tests/system.rs index bc90c31..285a17c 100644 --- a/tests/system.rs +++ b/tests/system.rs @@ -84,6 +84,13 @@ fn query_system() { *i *= 2; }); + let mut state = 5_u32; + + world.add_system(move |_world| { + state += 1; + assert!(state <= 8); + }); + for _ in 0..3 { world.run_systems(); }