diff --git a/examples/api_showcase/src/main.rs b/examples/api_showcase/src/main.rs index a1b3eb8..757b5f6 100644 --- a/examples/api_showcase/src/main.rs +++ b/examples/api_showcase/src/main.rs @@ -1,15 +1,15 @@ -#![allow(unused, clippy::no_effect)] +#![allow(unused, clippy::no_effect, clippy::unnecessary_operation)] use shame as sm; use shame::prelude::*; use shame::aliases::*; #[rustfmt::skip] fn make_pipeline(some_param: u32) -> Result { - + // start a pipeline encoding with the default settings. // (in the `shame_wgpu` examples, this is wrapped by the `Gpu` object) - // - // compared to earlier versions of `shame`, pipeline + // + // compared to earlier versions of `shame`, pipeline // encoding is no longer based on a closure, but based on a // RAII guard `sm::EncodingGuard<...>` instead. // That way you can use the `?` operator for your own @@ -60,7 +60,7 @@ fn make_pipeline(some_param: u32) -> Result Result Result = group0.next(); let xforms_uni: sm::Buffer = group0.next(); - + // conditional code generation based on pipeline parameter if some_param > 0 { // if not further specified, defaults to `sm::mem::Storage` @@ -99,7 +99,7 @@ fn make_pipeline(some_param: u32) -> Result Result Result Result Result Result Result().set_with_alpha_to_coverage(rg); + // targets.next::().set_with_alpha_to_coverage(rg); // finish the encoding and obtain the pipeline setup info + shader code. 
encoder.finish() @@ -486,27 +486,31 @@ struct Mat2([[f32; 2]; 2]); // tell `shame` about the layout semantics of your cpu types // Mat2::layout() == sm::f32x2x2::layout() impl sm::CpuLayout for Mat2 { - fn cpu_layout() -> sm::TypeLayout { sm::f32x2x2::gpu_layout() } + fn cpu_layout() -> sm::TypeLayout { sm::gpu_layout::() } } #[repr(C, align(16))] struct Mat4([[f32; 4]; 4]); impl sm::CpuLayout for Mat4 { - fn cpu_layout() -> sm::TypeLayout { sm::f32x4x4::gpu_layout() } + fn cpu_layout() -> sm::TypeLayout { sm::gpu_layout::() } } // using "duck-traiting" allows you to define layouts for foreign cpu-types, // sidestepping the orphan-rule: +// // if you want to try this, add `glam = "0.30.4"` to the Cargo.toml of this +// // example, comment the `Mat4` definition above and uncomment below // use glam::Mat4; // // declare your own trait with a `layout()` function like this -// // This function will be used by the `derive(GpuLayout)` proc macro +// // This function will be used by the `derive(sm::CpuLayout)` proc macro // pub trait MyCpuLayoutTrait { -// fn layout() -> shame::TypeLayout; +// fn cpu_layout() -> shame::TypeLayout; // } // // tell `shame` about the layout semantics of `glam` types +// // here make a promise to `shame` that `glam::Mat4` has identical memory layout +// // to `sm::f32x4x4` // impl MyCpuLayoutTrait for glam::Mat4 { -// fn layout() -> shame::TypeLayout { sm::f32x4x4::layout() } +// fn cpu_layout() -> shame::TypeLayout { sm::gpu_layout::() } // } diff --git a/examples/hello_triangles/src/util/shame_glam.rs b/examples/hello_triangles/src/util/shame_glam.rs index 7f85092..cff249b 100644 --- a/examples/hello_triangles/src/util/shame_glam.rs +++ b/examples/hello_triangles/src/util/shame_glam.rs @@ -14,14 +14,14 @@ pub trait CpuLayoutExt { // glam::Vec4 matches sm::f32x4 in size and alignment impl CpuLayoutExt for glam::Vec4 { - fn cpu_layout() -> sm::TypeLayout { sm::f32x4::gpu_layout() } + fn cpu_layout() -> sm::TypeLayout { sm::gpu_layout::() } } 
// glam::Vec2 only matches sm::f32x2 if it has 8 byte alignment impl CpuLayoutExt for glam::Vec2 { fn cpu_layout() -> sm::TypeLayout { if align_of::() == 8 { - sm::f32x2::gpu_layout() + sm::gpu_layout::() } else { panic!("glam needs to use the `cuda` crate feature for Vec2 to be 8 byte aligned"); } @@ -30,5 +30,5 @@ impl CpuLayoutExt for glam::Vec2 { // glam::Mat4 matches sm::f32x4x4 in size and alignment impl CpuLayoutExt for glam::Mat4 { - fn cpu_layout() -> sm::TypeLayout { sm::f32x4x4::gpu_layout() } + fn cpu_layout() -> sm::TypeLayout { sm::gpu_layout::() } } diff --git a/examples/shame_wgpu/src/conversion.rs b/examples/shame_wgpu/src/conversion.rs index 16c7ad6..8ff55e0 100644 --- a/examples/shame_wgpu/src/conversion.rs +++ b/examples/shame_wgpu/src/conversion.rs @@ -101,7 +101,7 @@ pub fn sample_type(st: sm::TextureSampleUsageType) -> wgpu::TextureSampleType { } } -/// converts `tf` into a `wgpu::TextureFormat` if supported. +/// converts `tf` into a `wgpu::TextureFormat` if supported. /// If `tf` is `ExtraTextureFormats::SurfaceFormat`, then the provided `surface_format` argument /// is returned if it is `Some`. Otherwise an error is returned. 
#[rustfmt::skip] @@ -194,7 +194,7 @@ pub fn texture_format(tf: &dyn sm::TextureFormatId, surface_format: Option wgpu::TextureFormat::EacR11Snorm, SmTf::EacRg11Unorm => wgpu::TextureFormat::EacRg11Unorm, SmTf::EacRg11Snorm => wgpu::TextureFormat::EacRg11Snorm, - SmTf::Astc { block, channel } => wgpu::TextureFormat::Astc { + SmTf::Astc { block, channel } => wgpu::TextureFormat::Astc { block: match block { SmASTCb::B4x4 => wgpu::AstcBlock::B4x4, SmASTCb::B5x4 => wgpu::AstcBlock::B5x4, @@ -210,12 +210,12 @@ pub fn texture_format(tf: &dyn sm::TextureFormatId, surface_format: Option wgpu::AstcBlock::B10x10, SmASTCb::B12x10 => wgpu::AstcBlock::B12x10, SmASTCb::B12x12 => wgpu::AstcBlock::B12x12, - }, + }, channel: match channel { SmASTCc::Unorm => wgpu::AstcChannel::Unorm, SmASTCc::UnormSrgb => wgpu::AstcChannel::UnormSrgb, SmASTCc::Hdr => wgpu::AstcChannel::Hdr, - } + } }, }; Ok(wtf) @@ -449,7 +449,7 @@ fn color_writes(write_mask: smr::ChannelWrites) -> wgpu::ColorWrites { #[rustfmt::skip] fn vertex_format(format: smr::VertexAttribFormat) -> Result { - use smr::ScalarType as S; + use shame::layout::ScalarType as S; use smr::Len as L; use wgpu::VertexFormat as W; let unsupported = Err(ShameToWgpuError::UnsupportedVertexAttribFormat(format)); @@ -479,10 +479,8 @@ fn vertex_format(format: smr::VertexAttribFormat) -> Result W::Sint32x2, (S::I32, L::X3) => W::Sint32x3, (S::I32, L::X4) => W::Sint32x4, - - (S::Bool, _) => return unsupported, }, - + smr::VertexAttribFormat::Coarse(p) => { use smr::PackedScalarType as PS; use smr::PackedFloat as Norm; diff --git a/examples/shame_wgpu/src/lib.rs b/examples/shame_wgpu/src/lib.rs index 40e58d9..7177494 100644 --- a/examples/shame_wgpu/src/lib.rs +++ b/examples/shame_wgpu/src/lib.rs @@ -1,6 +1,9 @@ //! shame wgpu integration //! //! 
bind-group and pipeline glue code +#![allow(mismatched_lifetime_syntaxes)] +#![deny(unsafe_code)] + pub use shame::*; pub mod bind_group; pub mod binding; diff --git a/shame/Cargo.toml b/shame/Cargo.toml index f838cf9..6ceb070 100644 --- a/shame/Cargo.toml +++ b/shame/Cargo.toml @@ -26,4 +26,4 @@ shame_derive = { path = "../shame_derive/" } [dev-dependencies] static_assertions = "1.1.0" pretty_assertions = "1.4.1" -glam = "0.29.0" \ No newline at end of file +glam = "0.29.0" diff --git a/shame/src/backend/language.rs b/shame/src/backend/language.rs index 3b4e4bd..83fc396 100644 --- a/shame/src/backend/language.rs +++ b/shame/src/backend/language.rs @@ -14,3 +14,11 @@ pub enum Language { Wgsl, // SpirV } + +impl std::fmt::Display for Language { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Language::Wgsl => write!(f, "wgsl"), + } + } +} diff --git a/shame/src/common/format.rs b/shame/src/common/format.rs index 76606fe..47c558b 100644 --- a/shame/src/common/format.rs +++ b/shame/src/common/format.rs @@ -1,7 +1,8 @@ -use std::fmt::Write; +use std::fmt::{Display, Write}; use crate::ir::recording::CallInfo; +/// Ordinal formatting suffix for numbers. /// for "1st" "2nd" "3rd", and the likes. for `1` returns `"st"` pub fn numeral_suffix(i: usize) -> &'static str { match i { @@ -134,3 +135,15 @@ pub fn write_error_excerpt(f: &mut impl Write, call_info: CallInfo, use_colors: Ok(()) } + +/// Turn a closure into a struct implementing [`Display`]. 
Code borrowed from `Typst` +pub fn display std::fmt::Result>(f: F) -> impl Display { + struct Wrapper(F); + + impl std::fmt::Result> Display for Wrapper { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0(f) + } + } + Wrapper(f) +} \ No newline at end of file diff --git a/shame/src/common/po2.rs b/shame/src/common/po2.rs index 846e6c0..ccbfae8 100644 --- a/shame/src/common/po2.rs +++ b/shame/src/common/po2.rs @@ -38,10 +38,11 @@ pub enum U32PowerOf2 { _2147483648, } -impl From for u32 { +impl U32PowerOf2 { + /// Returns the corresponding u32. #[rustfmt::skip] - fn from(value: U32PowerOf2) -> Self { - match value { + pub const fn as_u32(self) -> u32 { + match self { U32PowerOf2::_1 => 1_u32 , U32PowerOf2::_2 => 2_u32 , U32PowerOf2::_4 => 4_u32 , @@ -76,8 +77,21 @@ impl From for u32 { U32PowerOf2::_2147483648 => 2147483648_u32, } } + + /// Returns the corresponding u64. + pub const fn as_u64(self) -> u64 { self.as_u32() as u64 } +} + +impl From for u32 { + fn from(value: U32PowerOf2) -> Self { value.as_u32() } } +impl U32PowerOf2 { + /// Returns the maximum between `self` and `other`. + pub const fn max(self, other: Self) -> Self { if self as u32 > other as u32 { self } else { other } } +} + +#[allow(missing_docs)] #[derive(Debug)] pub struct NotAU32PowerOf2(u32); @@ -89,11 +103,10 @@ impl Display for NotAU32PowerOf2 { impl std::error::Error for NotAU32PowerOf2 {} -impl TryFrom for U32PowerOf2 { - type Error = NotAU32PowerOf2; - - fn try_from(value: u32) -> Result { - Ok(match value { +impl U32PowerOf2 { + /// Tries to convert a u32 to U32PowerOf2. 
+ pub const fn try_from_u32(value: u32) -> Option { + Some(match value { 1 => U32PowerOf2::_1, 2 => U32PowerOf2::_2, 4 => U32PowerOf2::_4, @@ -126,9 +139,20 @@ impl TryFrom for U32PowerOf2 { 536870912 => U32PowerOf2::_536870912, 1073741824 => U32PowerOf2::_1073741824, 2147483648 => U32PowerOf2::_2147483648, - n => return Err(NotAU32PowerOf2(n)), + n => return None, }) } + + /// Tries to convert a usize to U32PowerOf2. + pub const fn try_from_usize(value: usize) -> Option { Self::try_from_u32(value as u32) } +} + +impl TryFrom for U32PowerOf2 { + type Error = NotAU32PowerOf2; + + fn try_from(value: u32) -> Result { + U32PowerOf2::try_from_u32(value).ok_or(NotAU32PowerOf2(value)) + } } impl From for u64 { diff --git a/shame/src/common/prettify.rs b/shame/src/common/prettify.rs index 719f383..09b95c1 100644 --- a/shame/src/common/prettify.rs +++ b/shame/src/common/prettify.rs @@ -29,3 +29,14 @@ pub fn set_color(w: &mut W, hexcode: Option<&str>, use_256_c } } } + +/// Implements `Display` to print `Some(T)` as `T` and `None` as the provided &'static str. 
+pub(crate) struct UnwrapDisplayOr(pub Option, pub &'static str); +impl std::fmt::Display for UnwrapDisplayOr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + UnwrapDisplayOr(Some(s), _) => s.fmt(f), + UnwrapDisplayOr(None, s) => s.fmt(f), + } + } +} diff --git a/shame/src/common/proc_macro_reexports.rs b/shame/src/common/proc_macro_reexports.rs index 30170bf..56e2789 100644 --- a/shame/src/common/proc_macro_reexports.rs +++ b/shame/src/common/proc_macro_reexports.rs @@ -24,12 +24,10 @@ pub use crate::frontend::rust_types::reference::Ref; pub use crate::frontend::rust_types::struct_::BufferFields; pub use crate::frontend::rust_types::struct_::SizedFields; pub use crate::frontend::rust_types::type_layout::FieldLayout; -pub use crate::frontend::rust_types::type_layout::FieldLayoutWithOffset; +pub use crate::frontend::rust_types::type_layout::recipe::FieldOptions; pub use crate::frontend::rust_types::type_layout::StructLayout; -pub use crate::frontend::rust_types::type_layout::StructLayoutError; +pub use crate::frontend::rust_types::type_layout::Repr; pub use crate::frontend::rust_types::type_layout::TypeLayout; -pub use crate::frontend::rust_types::type_layout::TypeLayoutRules; -pub use crate::frontend::rust_types::type_layout::TypeLayoutSemantics; pub use crate::frontend::rust_types::type_traits::BindingArgs; pub use crate::frontend::rust_types::type_traits::GpuAligned; pub use crate::frontend::rust_types::type_traits::GpuSized; @@ -40,6 +38,11 @@ pub use crate::frontend::rust_types::type_traits::NoBools; pub use crate::frontend::rust_types::type_traits::NoHandles; pub use crate::frontend::rust_types::type_traits::VertexAttribute; pub use crate::frontend::rust_types::type_traits::GpuLayoutField; +pub use crate::frontend::rust_types::type_layout::recipe::SizedStruct; +pub use crate::frontend::rust_types::type_layout::recipe::TypeLayoutRecipe; +pub use crate::frontend::rust_types::type_layout::recipe::SizedType; +pub use 
crate::frontend::rust_types::type_layout::recipe::SizedOrArray; +pub use crate::frontend::rust_types::type_layout::recipe::builder::StructFromPartsError; pub use crate::frontend::rust_types::AsAny; pub use crate::frontend::rust_types::GpuType; #[allow(missing_docs)] @@ -53,3 +56,4 @@ pub use crate::ir::pipeline::StageMask; pub use crate::ir::recording::CallInfo; pub use crate::ir::recording::CallInfoScope; pub use crate::mem::AddressSpace; +pub use crate::any::U32PowerOf2; diff --git a/shame/src/common/proc_macro_utils.rs b/shame/src/common/proc_macro_utils.rs index f392681..843759a 100644 --- a/shame/src/common/proc_macro_utils.rs +++ b/shame/src/common/proc_macro_utils.rs @@ -1,13 +1,12 @@ use std::rc::Rc; use crate::{ + call_info, frontend::{ any::{Any, InvalidReason}, rust_types::{ error::FrontendError, - type_layout::{ - FieldLayout, FieldLayoutWithOffset, StructLayout, TypeLayout, TypeLayoutRules, TypeLayoutSemantics, - }, + type_layout::{FieldLayout, StructLayout, TypeLayout}, }, }, ir::{ @@ -56,9 +55,10 @@ pub fn collect_into_array_exact(mut it: impl Iterator, + /// size that was provided by the (maybe user defined) impl of `CpuLayout` + cpu_layout_provided_field_size: Option, + }, +} + +impl std::fmt::Display for CpuLayoutImplMismatch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CpuLayoutImplMismatch::UnexpectedSize { + field_index, + field_type_name: t, + struct_name, + expected_field_size: std_mem_size, + cpu_layout_provided_field_size: cpu_layout_impl_size, + } => { + let field_index_count_from_1 = field_index + 1; + write!(f, "Field {field_index_count_from_1} of struct `{struct_name}` has a type with a `CpuLayout` implementation that \ + claims this field is a `{t}` "); + + match cpu_layout_impl_size { + Some(s) => write!(f, "with a byte-size of {s},"), + None => write!(f, "with a size unknown at compile time,"), + }; + + match std_mem_size { + Some(s) => write!(f, "\nbut this field has an actual byte-size 
of {s}"), + None => write!(f, "\nbut the size of this field is actually unknown at compile-time (unsized)"), + }; + writeln!(f, ".")?; + writeln!( + f, + "This is most likely caused by a mistake in the `shame::CpuLayout` implementation of this field's type \ + or in the implementation of one of the types it is composed of. \ + The size of the layout returned in the `CpuLayout` implementation must be equal to what `std::mem::size_of` returns." + )?; + } + } + Ok(()) + } +} + +#[track_caller] +fn try_report_cpu_layout_impl_mismatch(err: CpuLayoutImplMismatch) { + let caller = call_info!(); + let success = Context::try_with(caller, |ctx| { + ctx.push_error(crate::frontend::encoding::EncodingErrorKind::LayoutError( + err.clone().into(), + )); + }) + .unwrap_or_else(|| { + if crate::__private::DEBUG_PRINT_ENABLED { + println!("`shame` warning @ {caller}:\n{err}"); + } else { + // unable to report assumed implementation mistake of `CpuLayout` for a given type + panic!("shame error at {caller} \n{err}"); + } + }); +} + +#[track_caller] pub fn repr_c_struct_layout( - repr_c_align_attribute: Option, + repr_c_align_attribute: Option, struct_name: &'static str, first_fields_with_offsets_and_sizes: &[(ReprCField, usize, usize)], - last_field: ReprCField, - last_field_size: Option, + mut last_field: ReprCField, + // the size of the last field according to the `CpuAligned` trait's associated constant + last_field_trait_size: Option, ) -> Result { let last_field_offset = match first_fields_with_offsets_and_sizes.last() { None => 0, @@ -80,59 +150,84 @@ pub fn repr_c_struct_layout( return Err(ReprCError::SecondLastElementIsUnsized); }; round_up( - last_field.alignment as u64, + last_field.alignment.as_u64(), *_2nd_last_offset as u64 + *_2nd_last_size as u64, ) } }; - let max_alignment = first_fields_with_offsets_and_sizes - .iter() - .map(|(f, _, _)| f.alignment) - .fold(last_field.alignment, ::std::cmp::max) as u64; - - let struct_alignment = match repr_c_align_attribute { - 
Some(repr_c_align) => max_alignment.max(repr_c_align), - None => max_alignment, + let struct_alignment = { + let max_alignment = first_fields_with_offsets_and_sizes + .iter() + .map(|(f, _, _)| f.alignment) + .fold(last_field.alignment, ::std::cmp::max); + match repr_c_align_attribute { + Some(repr_c_align_attribute) => max_alignment.max(repr_c_align_attribute), + None => max_alignment, + } }; - let last_field_size = last_field_size.map(|s| s as u64); - let total_struct_size = last_field_size.map(|last_size| round_up(struct_alignment, last_field_offset + last_size)); + /// the size of the last field according to the `CpuAligned` trait's associated constant + let last_field_trait_size = last_field_trait_size.map(|s| s as u64); + + let total_struct_size = + last_field_trait_size.map(|last_size| round_up(struct_alignment.as_u64(), last_field_offset + last_size)); let mut fields = first_fields_with_offsets_and_sizes .iter() - .map(|(field, offset, size)| (field, *offset as u64, *size as u64)) - .map(|(field, offset, size)| FieldLayoutWithOffset { - field: FieldLayout { - custom_min_align: None.into(), - custom_min_size: (field.layout.byte_size() != Some(size)).then_some(size).into(), + .map(|(field, std_mem_offset_of, std_mem_size_of)| (field, *std_mem_offset_of as u64, *std_mem_size_of as u64)) + .map(|(mut field, std_mem_offset_of, std_mem_size_of)| { + let mut layout = field.layout.clone(); + // here `std::mem::size_of` is prioritized over `<#field_type>::cpu_layout().byte_size()`. + // They can disagree if the user-driven `cpu_layout()` implementation is broken. 
TODO(release) reconsider this, especially in the case of f32x3 + layout.set_byte_size(std_mem_size_of); + FieldLayout { + rel_byte_offset: std_mem_offset_of, name: field.name.into(), - ty: field.layout.clone(), - }, - rel_byte_offset: offset, + ty: layout, + } }) - .chain(std::iter::once(FieldLayoutWithOffset { - field: FieldLayout { - custom_min_align: None.into(), - custom_min_size: (last_field.layout.byte_size() != last_field_size) - .then_some(last_field_size) - .flatten() - .into(), + .chain(std::iter::once({ + if last_field.layout.byte_size() != last_field_trait_size { + try_report_cpu_layout_impl_mismatch(CpuLayoutImplMismatch::UnexpectedSize { + field_index: first_fields_with_offsets_and_sizes.len(), + struct_name, + field_type_name: last_field.layout.short_name(), + expected_field_size: last_field_trait_size, + cpu_layout_provided_field_size: last_field.layout.byte_size(), + }); + } + + // here `<#last_field_type as CpuAligned>::CPU_SIZE` is prioritized over `<#field_type>::cpu_layout().byte_size()`. + // if the reporting above failed. The two can disagree if the user-driven `cpu_layout()` implementation is + // broken. + // + // If the error could not be reported above, the user cannot be informed and they + // have to accept the consequences of their broken `CpuLayout` impl. + match (last_field.layout.removable_byte_size_mut(), last_field_trait_size) { + (Ok(layout_maybe_size), trait_maybe_size) => *layout_maybe_size = trait_maybe_size, + (Err(layout_size), Some(trait_size)) => *layout_size = trait_size, + (Err(layout_size), None) => { + // in this case the rust type is an always-sized type like vector/matrix/packedvec, + // but the `CpuLayout` impl claims it is an unsized struct or array. 
+ } + }; + + FieldLayout { + rel_byte_offset: last_field_offset, name: last_field.name.into(), ty: last_field.layout, - }, - rel_byte_offset: last_field_offset, + } })) .collect::>(); - Ok(TypeLayout::new( - total_struct_size, - struct_alignment, - TypeLayoutSemantics::Structure(Rc::new(StructLayout { - name: struct_name.into(), - fields, - })), - )) + Ok(StructLayout { + byte_size: total_struct_size, + align: struct_alignment.into(), + name: struct_name.into(), + fields, + } + .into()) } #[track_caller] diff --git a/shame/src/common/small_vec_actual.rs b/shame/src/common/small_vec_actual.rs index 939cbb9..ccdca81 100644 --- a/shame/src/common/small_vec_actual.rs +++ b/shame/src/common/small_vec_actual.rs @@ -222,6 +222,23 @@ impl std::borrow::Borrow<[T]> for SmallVec { fn borrow(&self) -> &[T] { self } } +impl std::fmt::Display for SmallVec +where + T: std::fmt::Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "[")?; + let mut iter = self.iter(); + if let Some(first) = iter.next() { + write!(f, "{}", first)?; + for item in iter { + write!(f, ", {}", item)?; + } + } + write!(f, "]") + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/shame/src/frontend/any/render_io.rs b/shame/src/frontend/any/render_io.rs index 5640f27..014fb10 100644 --- a/shame/src/frontend/any/render_io.rs +++ b/shame/src/frontend/any/render_io.rs @@ -2,8 +2,9 @@ use std::{fmt::Display, rc::Rc}; use thiserror::Error; +use crate::any::layout::{Repr}; use crate::frontend::any::Any; -use crate::frontend::rust_types::type_layout::TypeLayout; +use crate::frontend::rust_types::type_layout::{recipe, TypeLayout}; use crate::{ call_info, common::iterator_ext::try_collect, @@ -13,7 +14,6 @@ use crate::{ io_iter::LocationCounter, EncodingErrorKind, }, - rust_types::type_layout::TypeLayoutSemantics, }, ir::{ expr::{BuiltinShaderIn, BuiltinShaderIo, Expr, Interpolator, ShaderIo}, @@ -108,7 +108,7 @@ impl Any { #[derive(Debug, Clone, Copy, PartialEq, Eq, 
Hash)] pub enum VertexAttribFormat { /// regular [`crate::vec`] types - Fine(Len, ScalarType), + Fine(Len, recipe::ScalarType), /// packed [`crate::packed::PackedVec`] types Coarse(PackedVector), } @@ -238,7 +238,7 @@ impl VertexAttribFormat { #[allow(missing_docs)] pub fn type_in_shader(self) -> SizedType { match self { - VertexAttribFormat::Fine(l, t) => SizedType::Vector(l, t), + VertexAttribFormat::Fine(l, t) => SizedType::Vector(l, t.into()), VertexAttribFormat::Coarse(coarse) => coarse.decompressed_ty(), } } @@ -248,36 +248,36 @@ impl Attrib { pub(crate) fn get_attribs_and_stride( layout: &TypeLayout, mut location_counter: &LocationCounter, + stride_repr: Repr, ) -> Option<(Box<[Attrib]>, u64)> { let stride = { let size = layout.byte_size()?; - stride_of_array_from_element_align_size(layout.align(), size) + recipe::array_stride(layout.align(), size, stride_repr) }; - use TypeLayoutSemantics as TLS; + use TypeLayout::*; - let attribs: Box<[Attrib]> = match &layout.kind { - TLS::Matrix(..) | TLS::Array(..) | TLS::Vector(_, ScalarType::Bool) => return None, - TLS::Vector(len, non_bool) => [Attrib { + let attribs: Box<[Attrib]> = match &layout { + Matrix(..) | Array(..) 
=> return None, + Vector(v) => [Attrib { offset: 0, location: location_counter.next(), - format: VertexAttribFormat::Fine(*len, *non_bool), + format: VertexAttribFormat::Fine(v.ty.len, v.ty.scalar), }] .into(), - TLS::PackedVector(packed_vector) => [Attrib { + PackedVector(v) => [Attrib { offset: 0, location: location_counter.next(), - format: VertexAttribFormat::Coarse(*packed_vector), + format: VertexAttribFormat::Coarse(v.ty), }] .into(), - TLS::Structure(rc) => try_collect(rc.fields.iter().map(|f| { + Struct(rc) => try_collect(rc.fields.iter().map(|f| { Some(Attrib { offset: f.rel_byte_offset, location: location_counter.next(), - format: match f.field.ty.kind { - TLS::Vector(_, ScalarType::Bool) => return None, - TLS::Vector(len, non_bool) => Some(VertexAttribFormat::Fine(len, non_bool)), - TLS::PackedVector(packed_vector) => Some(VertexAttribFormat::Coarse(packed_vector)), - TLS::Matrix(..) | TLS::Array(..) | TLS::Structure(..) => None, + format: match &f.ty { + Vector(v) => Some(VertexAttribFormat::Fine(v.ty.len, v.ty.scalar)), + PackedVector(v) => Some(VertexAttribFormat::Coarse(v.ty)), + Matrix(..) | Array(..) | Struct(..) 
=> None, }?, }) }))?, diff --git a/shame/src/frontend/encoding/buffer.rs b/shame/src/frontend/encoding/buffer.rs index 50ad927..c94b640 100644 --- a/shame/src/frontend/encoding/buffer.rs +++ b/shame/src/frontend/encoding/buffer.rs @@ -17,7 +17,7 @@ use crate::frontend::rust_types::{reference::AccessModeReadable, scalar_type::Sc use crate::ir::pipeline::StageMask; use crate::ir::recording::{Context, MemoryRegion}; use crate::ir::Type; -use crate::{self as shame, call_info, ir}; +use crate::{self as shame, call_info, ir, GpuLayout}; use std::borrow::Borrow; use std::marker::PhantomData; @@ -30,9 +30,32 @@ use super::binding::Binding; /// Implemented by the marker types /// - [`mem::Uniform`] /// - [`mem::Storage`] -pub trait BufferAddressSpace: AddressSpace + SupportsAccess {} -impl BufferAddressSpace for mem::Uniform {} -impl BufferAddressSpace for mem::Storage {} +pub trait BufferAddressSpace: AddressSpace + SupportsAccess { + /// Either Storage or Uniform address space. + const BUFFER_ADDRESS_SPACE: BufferAddressSpaceEnum; +} +/// Either Storage or Uniform address space. +#[derive(Debug, Clone, Copy)] +pub enum BufferAddressSpaceEnum { + /// Storage address space + Storage, + /// Uniform address space + Uniform, +} +impl BufferAddressSpace for mem::Uniform { + const BUFFER_ADDRESS_SPACE: BufferAddressSpaceEnum = BufferAddressSpaceEnum::Uniform; +} +impl BufferAddressSpace for mem::Storage { + const BUFFER_ADDRESS_SPACE: BufferAddressSpaceEnum = BufferAddressSpaceEnum::Storage; +} +impl std::fmt::Display for BufferAddressSpaceEnum { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BufferAddressSpaceEnum::Storage => write!(f, "storage"), + BufferAddressSpaceEnum::Uniform => write!(f, "uniform"), + } + } +} /// A read-only buffer binding, for writeable buffers and atomics use [`BufferRef`] instead. 
/// @@ -97,14 +120,13 @@ where impl Buffer where - T: GpuStore + NoHandles + NoAtomics + NoBools, + T: GpuStore + NoHandles + NoAtomics + NoBools + GpuLayout, AS: BufferAddressSpace, { #[track_caller] fn new(args: Result) -> Self { - let skip_stride_check = true; // not a vertex buffer Context::try_with(call_info!(), |ctx| { - get_layout_compare_with_cpu_push_error::(ctx, skip_stride_check) + get_layout_compare_with_cpu_push_error::(ctx, None) }); Self { inner: T::instantiate_buffer_inner(args, BufferInner::::binding_type(DYN_OFFSET)), @@ -114,15 +136,14 @@ where impl BufferRef where - T: GpuStore + NoHandles + NoBools, + T: GpuStore + NoHandles + NoBools + GpuLayout, AS: BufferAddressSpace, AM: AccessModeReadable, { #[track_caller] fn new(args: Result) -> Self { - let skip_stride_check = true; // not a vertex buffer Context::try_with(call_info!(), |ctx| { - get_layout_compare_with_cpu_push_error::(ctx, skip_stride_check) + get_layout_compare_with_cpu_push_error::(ctx, None) }); Self { inner: T::instantiate_buffer_ref_inner(args, BufferRefInner::::binding_type(DYN_OFFSET)), @@ -294,7 +315,7 @@ impl BufferRefInner #[rustfmt::skip] impl Binding for Buffer where - T: GpuSized + T: GpuSized+ GpuLayout { fn binding_type() -> BindingType { BufferInner::::binding_type(DYN_OFFSET) } #[track_caller] @@ -316,9 +337,9 @@ fn store_type_from_impl_category(category: GpuStoreImplCategory) -> ir::StoreTyp } #[rustfmt::skip] impl -Binding for Buffer, AS, DYN_OFFSET> -where - T: GpuType + GpuSized +Binding for Buffer, AS, DYN_OFFSET> +where + T: GpuType + GpuSized + GpuLayout { fn binding_type() -> BindingType { BufferInner::::binding_type(DYN_OFFSET) } #[track_caller] @@ -464,7 +485,7 @@ where /// /// // field access returns references /// let world: sm::Ref = buffer.world; -/// +/// /// // get fields via `.get()` /// let matrix: f32x4x4 = buffer.world.get(); /// @@ -488,7 +509,7 @@ where pub(crate) inner: BufferRefInner, } -#[rustfmt::skip] impl +#[rustfmt::skip] impl Binding for 
BufferRef where AS: BufferAddressSpace + SupportsAccess, @@ -497,7 +518,7 @@ where fn binding_type() -> BindingType { BufferRefInner::::binding_type(DYN_OFFSET) } #[track_caller] fn new_binding(args: Result) -> Self { BufferRef::new(args) } - + fn store_ty() -> ir::StoreType { store_type_from_impl_category(T::impl_category()) } diff --git a/shame/src/frontend/encoding/io_iter.rs b/shame/src/frontend/encoding/io_iter.rs index 67a8511..7d541e9 100644 --- a/shame/src/frontend/encoding/io_iter.rs +++ b/shame/src/frontend/encoding/io_iter.rs @@ -2,42 +2,25 @@ use std::{cell::Cell, iter, marker::PhantomData, rc::Rc}; use crate::{ - call_info, - common::integer::post_inc_u32, - frontend::{ + any::layout::Repr, call_info, common::integer::post_inc_u32, frontend::{ any::{ - render_io::{Attrib, Location, VertexAttribFormat, VertexBufferLayout}, - shared_io::{BindPath, BindingType}, - Any, InvalidReason, + Any, InvalidReason, render_io::{Attrib, Location, VertexAttribFormat, VertexBufferLayout}, shared_io::{BindPath, BindingType} }, error::InternalError, rust_types::{ - error::FrontendError, - layout_traits::{ - cpu_type_name_and_layout, get_layout_compare_with_cpu_push_error, ArrayElementsUnsizedError, FromAnys, - GpuLayout, VertexLayout, - }, - reference::AccessMode, - struct_::SizedFields, - type_traits::{BindingArgs, GpuSized, GpuStore, GpuStoreImplCategory, NoAtomics, NoBools}, - GpuType, + GpuType, error::FrontendError, layout_traits::{ + ArrayElementsUnsizedError, FromAnys, GpuLayout, VertexLayout, cpu_type_name_and_layout, get_layout_compare_with_cpu_push_error + }, reference::AccessMode, struct_::SizedFields, type_traits::{BindingArgs, GpuSized, GpuStore, GpuStoreImplCategory, NoAtomics, NoBools} }, texture::{ - texture_array::{StorageTextureArray, TextureArray}, - texture_traits::{ + Sampler, Texture, TextureKind, texture_array::{StorageTextureArray, TextureArray}, texture_traits::{ LayerCoords, SamplingFormat, SamplingMethod, Spp, StorageTextureCoords, 
StorageTextureFormat, SupportsCoords, SupportsSpp, TextureCoords, - }, - Sampler, Texture, TextureKind, + } }, - }, - ir::{ - self, - ir_type::{Field, LayoutError}, - pipeline::{PipelineError, StageMask}, - recording::Context, - TextureFormatWrapper, - }, + }, ir::{ + self, TextureFormatWrapper, ir_type::{Field, LayoutError}, pipeline::{PipelineError, StageMask}, recording::Context + } }; use super::{binding::Binding, rasterizer::VertexIndex}; @@ -106,10 +89,18 @@ impl VertexBuffer<'_, T> { fn new(slot: u32, location_counter: Rc) -> Self { let call_info = call_info!(); let attribs_and_stride = Context::try_with(call_info, |ctx| { - let skip_stride_check = false; // it is implied that T is in an array, the strides must match - let gpu_layout = get_layout_compare_with_cpu_push_error::(ctx, skip_stride_check); - - let attribs_and_stride = Attrib::get_attribs_and_stride(&gpu_layout, &location_counter).ok_or_else(|| { + // it is implied that T is in an array, the strides must match + // + // the stride check repr only affects vertex buffers where `T = f32x3`. + // In those cases we assume a stride of 16 bytes, so that the stride of `T` is + // identical to what it would be in an `array`. If the `T` itself is a struct that + // uses #[gpu_repr(packed)], that makes `T`s alignment equal to 1 and therefore the + // chosen repr here doesn't matter. 
+ let stride_repr = Repr::default(); + + let gpu_layout = get_layout_compare_with_cpu_push_error::(ctx, Some(stride_repr)); + + let attribs_and_stride = Attrib::get_attribs_and_stride(&gpu_layout, &location_counter, stride_repr).ok_or_else(|| { ctx.push_error(FrontendError::MalformedVertexBufferLayout(gpu_layout).into()); InvalidReason::ErrorThatWasPushed }); @@ -454,13 +445,13 @@ impl BindingIter<'_> { /// let texarr: sm::TextureArray = bind_group.next(); /// let texarr: sm::TextureArray, 4> = bind_group.next(); /// ``` - /// --- + /// --- /// ## storage textures /// ``` /// let texsto: sm::StorageTexture = bind_group.next(); /// let texsto: sm::StorageTexture = bind_group.next(); /// ``` - /// --- + /// --- /// ## Arrays of storage textures /// ``` /// let texstoarr: sm::StorageTextureArray = bind_group.next(); @@ -558,14 +549,13 @@ impl PushConstants<'_> { #[track_caller] pub fn get(self) -> T where - T: GpuStore + GpuSized + NoAtomics + NoBools, + T: GpuStore + GpuSized + NoAtomics + NoBools + GpuLayout, { let _caller_scope = Context::call_info_scope(); // the push constants structure as a whole doesn't need to have the same stride - let skip_stride_check = true; Context::try_with(call_info!(), |ctx| { - let _ = get_layout_compare_with_cpu_push_error::(ctx, skip_stride_check); + let _ = get_layout_compare_with_cpu_push_error::(ctx, None); match T::impl_category() { GpuStoreImplCategory::Fields(buffer_block) => match buffer_block.last_unsized_field() { diff --git a/shame/src/frontend/rust_types/array.rs b/shame/src/frontend/rust_types/array.rs index 853e6e1..7c66084 100644 --- a/shame/src/frontend/rust_types/array.rs +++ b/shame/src/frontend/rust_types/array.rs @@ -5,7 +5,7 @@ use super::len::x1; use super::mem::AddressSpace; use super::reference::{AccessMode, AccessModeReadable, AccessModeWritable, Read}; use super::scalar_type::ScalarTypeInteger; -use super::type_layout::{ElementLayout, TypeLayout, TypeLayoutRules, TypeLayoutSemantics}; +use 
super::type_layout::{self, recipe, TypeLayout, ArrayLayout}; use super::type_traits::{ BindingArgs, EmptyRefFields, GpuAligned, GpuSized, GpuStore, GpuStoreImplCategory, NoAtomics, NoBools, NoHandles, }; @@ -159,15 +159,18 @@ impl GpuSized for Array> { impl ToGpuType for Array { type Gpu = Self; - - fn to_gpu(&self) -> Self::Gpu { self.clone() } fn as_gpu_type_ref(&self) -> Option<&Self::Gpu> { Some(self) } } impl GpuLayout for Array { - fn gpu_layout() -> TypeLayout { TypeLayout::from_array(TypeLayoutRules::Wgsl, &T::sized_ty(), N::LEN) } + fn layout_recipe() -> recipe::TypeLayoutRecipe { + match N::LEN { + Some(n) => recipe::SizedArray::new(Rc::new(T::layout_recipe_sized()), n).into(), + None => recipe::RuntimeSizedArray::new(T::layout_recipe_sized()).into(), + } + } fn cpu_type_name_and_layout() -> Option, TypeLayout), ArrayElementsUnsizedError>> { let (t_cpu_name, t_cpu_layout) = match T::cpu_type_name_and_layout()? { @@ -185,17 +188,16 @@ impl GpuLayout for Array { let result = ( name.into(), - TypeLayout::new( - N::LEN.map(|n| n.get() as u64 * t_cpu_size), - t_cpu_layout.align(), - TypeLayoutSemantics::Array( - Rc::new(ElementLayout { - byte_stride: stride_of_array_from_element_align_size(t_cpu_layout.align(), t_cpu_size), - ty: t_cpu_layout, - }), - N::LEN.map(NonZeroU32::get), - ), - ), + ArrayLayout { + byte_size: N::LEN.map(|n| n.get() as u64 * t_cpu_size), + align: t_cpu_layout.align().into(), + // array stride is element size according to + // https://doc.rust-lang.org/reference/type-layout.html#r-layout.properties.size + byte_stride: t_cpu_size, + element_ty: t_cpu_layout, + len: N::LEN.map(NonZeroU32::get), + } + .into(), ); Some(Ok(result)) diff --git a/shame/src/frontend/rust_types/atomic.rs b/shame/src/frontend/rust_types/atomic.rs index e391b6b..f6a465c 100644 --- a/shame/src/frontend/rust_types/atomic.rs +++ b/shame/src/frontend/rust_types/atomic.rs @@ -6,7 +6,7 @@ use super::{ mem::{AddressSpace, AddressSpaceAtomic}, reference::{AccessMode, 
AccessModeReadable, ReadWrite}, scalar_type::{ScalarType, ScalarTypeInteger}, - type_layout::{TypeLayout, TypeLayoutRules}, + type_layout::{self, TypeLayout}, type_traits::{ BindingArgs, EmptyRefFields, GpuAligned, GpuSized, GpuStore, GpuStoreImplCategory, NoAtomics, NoBools, NoHandles, @@ -14,7 +14,7 @@ use super::{ vec::vec, AsAny, GpuType, To, ToGpuType, }; -use crate::frontend::rust_types::reference::Ref; +use crate::frontend::rust_types::{reference::Ref, type_layout::recipe}; use crate::{ boolx1, frontend::{ @@ -130,7 +130,12 @@ impl GetAllFields for Atomic { } impl GpuLayout for Atomic { - fn gpu_layout() -> TypeLayout { TypeLayout::from_sized_ty(TypeLayoutRules::Wgsl, &::sized_ty()) } + fn layout_recipe() -> recipe::TypeLayoutRecipe { + recipe::Atomic { + scalar: T::SCALAR_TYPE_INTEGER, + } + .into() + } fn cpu_type_name_and_layout() -> Option, TypeLayout), ArrayElementsUnsizedError>> { diff --git a/shame/src/frontend/rust_types/layout_traits.rs b/shame/src/frontend/rust_types/layout_traits.rs index 96e9373..ac18208 100644 --- a/shame/src/frontend/rust_types/layout_traits.rs +++ b/shame/src/frontend/rust_types/layout_traits.rs @@ -1,3 +1,4 @@ +use crate::any::layout::{TypeLayoutRecipe, Repr, SizedType}; use crate::call_info; use crate::common::po2::U32PowerOf2; use crate::common::proc_macro_utils::{self, repr_c_struct_layout, ReprCError, ReprCField}; @@ -9,6 +10,7 @@ use crate::frontend::encoding::buffer::{BufferAddressSpace, BufferInner, BufferR use crate::frontend::encoding::{EncodingError, EncodingErrorKind}; use crate::frontend::error::InternalError; use crate::frontend::rust_types::len::*; +use crate::frontend::rust_types::type_layout::{ArrayLayout, VectorLayout}; use crate::ir::ir_type::{ align_of_array, align_of_array_from_element_alignment, byte_size_of_array_from_stride_len, round_up, stride_of_array_from_element_align_size, CanonName, LayoutError, ScalarTypeFp, ScalarTypeInteger, @@ -21,18 +23,16 @@ use super::error::FrontendError; use 
super::mem::AddressSpace; use super::reference::{AccessMode, AccessModeReadable}; use super::struct_::{BufferFields, SizedFields, Struct}; -use super::type_layout::{ - ElementLayout, FieldLayout, FieldLayoutWithOffset, StructLayout, TypeLayout, TypeLayoutError, TypeLayoutRules, - TypeLayoutSemantics, -}; +use super::type_layout::recipe::{self, array_stride, Vector, ScalarType}; +use super::type_layout::{self, FieldLayout, StructLayout, TypeLayout}; use super::type_traits::{ BindingArgs, GpuAligned, GpuSized, GpuStore, GpuStoreImplCategory, NoAtomics, NoBools, NoHandles, VertexAttribute, }; -use super::{len::Len, scalar_type::ScalarType, vec::vec}; +use super::{len::Len, vec::vec}; use super::{AsAny, GpuType, ToGpuType}; use crate::frontend::any::{shared_io::BindPath, shared_io::BindingType}; use crate::frontend::rust_types::reference::Ref; -use crate::ir::{self, AlignedType, ScalarType as ST, SizedStruct, SizedType, StoreType}; +use crate::ir::{self, AlignedType, ScalarType as ST, SizedStruct, StoreType}; use std::borrow::{Borrow, Cow}; use std::iter::Empty; use std::mem::size_of; @@ -114,7 +114,7 @@ use std::rc::Rc; /// # Layout comparison of different types /// /// The layouts of different [`GpuLayout`]/[`CpuLayout`] types can be compared -/// by comparing [`TypeLayout`] objects returned by `.gpu_layout()`/`.cpu_layout()` +/// by comparing [`TypeLayout`] objects returned by `.gpu_layout()`/`.cpu_layout()` /// ``` /// use shame as sm; /// use sm::{ GpuLayout, CpuLayout }; @@ -136,33 +136,33 @@ use std::rc::Rc; /// [`StorageTexture`]: crate::StorageTexture /// pub trait GpuLayout { - /// returns a [`TypeLayout`] object that can be used to inspect the layout - /// of a type on the gpu. 
- /// - /// # Layout comparison of different types - /// - /// The layouts of different [`GpuLayout`]/[`CpuLayout`] types can be compared - /// by comparing [`TypeLayout`] objects returned by `.gpu_layout()`/`.cpu_layout()` - /// ``` - /// use shame as sm; - /// use sm::{ GpuLayout, CpuLayout }; - /// - /// type OnGpu = sm::Array>; - /// type OnCpu = [f32; 16]; - /// - /// if OnGpu::gpu_layout() == OnCpu::cpu_layout() { - /// println!("same layout") - /// } - /// println!("OnGpu:\n{}\n", OnGpu::gpu_layout()); - /// println!("OnCpu:\n{}\n", OnCpu::cpu_layout()); - /// ``` - fn gpu_layout() -> TypeLayout; + /// Returns a [`TypeLayoutRecipe`] that describes how a layout algorithm (repr) should layout this type in memory. + fn layout_recipe() -> TypeLayoutRecipe; + + /// For `GpuSized` types, this returns the [`SizedType`] that describes the type's layout. + fn layout_recipe_sized() -> SizedType + where + Self: GpuSized, + { + match Self::layout_recipe() { + TypeLayoutRecipe::Sized(s) => s, + TypeLayoutRecipe::RuntimeSizedArray(_) | TypeLayoutRecipe::UnsizedStruct(_) => { + unreachable!("Self is GpuSized, which these TypeLayoutRecipe variants aren't.") + } + } + } /// the `#[cpu(...)]` in `#[derive(GpuLayout)]` allows the definition of a /// corresponding Cpu type to the Gpu type that the derive macro is used on. /// /// If this association exists, this function returns the name and layout of /// that Cpu type, otherwise `None` is returned. 
+ /// examples: + /// - vec: has no association like that + /// - PackedVec: has no association like that + /// - mat: has no association like that + /// - Array: has such an association if the inner type does + /// - Struct: has such an association if `T` does /// /// implementor note: if a nested type's `cpu_type_name_and_layout` returns `Some` /// this function _MUST NOT_ return `None`, as it would throw away assumptions @@ -171,6 +171,36 @@ pub trait GpuLayout { fn cpu_type_name_and_layout() -> Option, TypeLayout), ArrayElementsUnsizedError>>; } +/// returns a [`TypeLayout`] object that can be used to inspect the layout +/// of a type on the gpu. +/// +/// # Layout comparison of different types +/// +/// The layouts of different [`GpuLayout`]/[`CpuLayout`] types can be compared +/// by comparing [`TypeLayout`] objects returned by `.gpu_layout()`/`.cpu_layout()` +/// ``` +/// use shame as sm; +/// use sm::{ GpuLayout, CpuLayout }; +/// +/// type OnGpu = sm::Array>; +/// type OnCpu = [f32; 16]; +/// +/// if OnGpu::gpu_layout() == OnCpu::cpu_layout() { +/// println!("same layout") +/// } +/// println!("OnGpu:\n{}\n", OnGpu::gpu_layout()); +/// println!("OnCpu:\n{}\n", OnCpu::cpu_layout()); +/// ``` +#[track_caller] +pub fn gpu_layout() -> TypeLayout { T::layout_recipe().layout() } + +/// (no documentation yet) +// `CpuLayout::cpu_layout` exists, but this function exists for consistency with +// the `gpu_layout` function. `GpuLayout::gpu_layout` does not exist, so that implementors +// of `GpuLayout` can't overwrite it. 
+#[track_caller] +pub fn cpu_layout() -> TypeLayout { T::cpu_layout() } + pub(crate) fn cpu_type_name_and_layout(ctx: &Context) -> Option<(Cow<'static, str>, TypeLayout)> { match T::cpu_type_name_and_layout().transpose() { Ok(t) => t, @@ -190,50 +220,60 @@ pub(crate) fn cpu_type_name_and_layout(ctx: &Context) -> Option<(C /// returns the `TypeLayout` of `T` and pushes an error to the provided context if it is incompatible with its associated cpu layout pub(crate) fn get_layout_compare_with_cpu_push_error( ctx: &Context, - skip_stride_check: bool, + treat_as_array_element_with_stride_repr: Option, ) -> TypeLayout { const ERR_COMMENT: &str = "`GpuLayout` uses WGSL layout rules unless #[gpu_repr(packed)] is used.\nsee https://www.w3.org/TR/WGSL/#structure-member-layout\n`CpuLayout` uses #[repr(C)].\nsee https://doc.rust-lang.org/reference/type-layout.html#r-layout.repr.c.struct"; - let gpu_layout = T::gpu_layout(); + let gpu_layout = gpu_layout::(); if let Some((cpu_name, cpu_layout)) = cpu_type_name_and_layout::(ctx) { - check_layout_push_error(ctx, &cpu_name, &cpu_layout, &gpu_layout, skip_stride_check, ERR_COMMENT).ok(); + check_layout_push_error(ctx, &cpu_name, &cpu_layout, &gpu_layout, treat_as_array_element_with_stride_repr, ERR_COMMENT).ok(); } gpu_layout } +fn repr_c_array_stride_from_array_element_size(element_size: u64) -> u64 { + // in repr(C) the stride is equal to the element size + element_size +} + pub(crate) fn check_layout_push_error( ctx: &Context, cpu_name: &str, cpu_layout: &TypeLayout, gpu_layout: &TypeLayout, - skip_stride_check: bool, + treat_as_array_element_with_stride_repr: Option, comment_on_mismatch_error: &str, ) -> Result<(), InvalidReason> { - TypeLayout::check_eq(("cpu", cpu_layout), ("gpu", gpu_layout)) + type_layout::eq::check_eq(("cpu", cpu_layout), ("gpu", gpu_layout)) .map_err(|e| LayoutError::LayoutMismatch(e, Some(comment_on_mismatch_error.to_string()))) .and_then(|_| { - if skip_stride_check { - Ok(()) - } else { - // the layout 
is an element in an array, so the strides need to match too - match (cpu_layout.byte_size(), gpu_layout.byte_size()) { - (None, None) | (None, Some(_)) => Err(LayoutError::UnsizedStride { name: cpu_name.into() }), - (Some(_), None) => Err(LayoutError::UnsizedStride { - name: gpu_layout.short_name(), - }), - (Some(cpu_size), Some(gpu_size)) => { - let cpu_stride = stride_of_array_from_element_align_size(cpu_layout.align(), cpu_size); - let gpu_stride = stride_of_array_from_element_align_size(gpu_layout.align(), gpu_size); - - if cpu_stride != gpu_stride { - Err(LayoutError::StrideMismatch { - cpu_name: cpu_name.into(), - cpu_stride, - gpu_name: gpu_layout.short_name(), - gpu_stride, - }) - } else { - Ok(()) + match treat_as_array_element_with_stride_repr { + None => { + Ok(()) + } + Some(stride_repr) => { + // the layout is an element in an array, so the strides need to match too + match (cpu_layout.byte_size(), gpu_layout.byte_size()) { + (None, None) | (None, Some(_)) => Err(LayoutError::UnsizedStride { name: cpu_name.into() }), + (Some(_), None) => Err(LayoutError::UnsizedStride { + name: gpu_layout.short_name(), + }), + + (Some(cpu_size), Some(gpu_size)) => { + let cpu_stride = repr_c_array_stride_from_array_element_size(cpu_size); + + let gpu_stride = array_stride(gpu_layout.align(), gpu_size, stride_repr); + + if cpu_stride != gpu_stride { + Err(LayoutError::StrideMismatch { + cpu_name: cpu_name.into(), + cpu_stride, + gpu_name: gpu_layout.short_name(), + gpu_stride, + }) + } else { + Ok(()) + } } } } @@ -418,7 +458,7 @@ where ); match struct_ { Ok(s) => s, - Err(ir::StructureFieldNamesMustBeUnique) => unreachable!("field names are assumed unique"), + Err(ir::StructureFieldNamesMustBeUnique { .. 
}) => unreachable!("field names are assumed unique"), } } } @@ -512,7 +552,9 @@ impl BufferFields for GpuT { Ok(t) => t, Err(e) => match e { E::MustHaveAtLeastOneField => unreachable!(">= 1 field is ensured by derive macro"), - E::FieldNamesMustBeUnique => unreachable!("unique field idents are ensured by rust struct definition"), + E::FieldNamesMustBeUnique(_) => { + unreachable!("unique field idents are ensured by rust struct definition") + } }, } } @@ -534,27 +576,7 @@ where } impl GpuLayout for GpuT { - fn gpu_layout() -> TypeLayout { - use crate::__private::proc_macro_reexports as rx; - // compiler_error! if the struct has zero fields! - let is_packed = false; - let result = rx::TypeLayout::struct_from_parts( - rx::TypeLayoutRules::Wgsl, - is_packed, - std::stringify!(GpuT).into(), - [FieldLayout { - name: unreachable!(), - custom_min_size: unreachable!(), - custom_min_align: unreachable!(), - ty: as rx::GpuLayout>::gpu_layout(), - }] - .into_iter(), - ); - match result { - Ok(layout) => layout, - Err(e @ rx::StructLayoutError::UnsizedFieldMustBeLast { .. 
}) => unreachable!(), - } - } + fn layout_recipe() -> recipe::TypeLayoutRecipe { todo!() } fn cpu_type_name_and_layout() -> Option, TypeLayout), ArrayElementsUnsizedError>> { Some(Ok(( @@ -699,80 +721,78 @@ where //fn gpu_type_layout() -> Option> { Some(Ok(GpuT::gpu_layout())) } } -impl CpuLayout for f32 { - fn cpu_layout() -> TypeLayout { - TypeLayout::from_rust_sized::(TypeLayoutSemantics::Vector(ir::Len::X1, Self::SCALAR_TYPE)) +fn cpu_layout_of_scalar(scalar: ScalarType) -> TypeLayout { + let (size, align) = match scalar { + ScalarType::F32 => (size_of::(), align_of::()), + ScalarType::F64 => (size_of::(), align_of::()), + ScalarType::U32 => (size_of::(), align_of::()), + ScalarType::I32 => (size_of::(), align_of::()), + // Waiting for f16 to become stable + // ScalarType::F16 => (size_of::(), align_of::()), + ScalarType::F16 => (2, 2), + }; + VectorLayout { + byte_size: size as u64, + // https://doc.rust-lang.org/reference/type-layout.html#r-layout.properties.align + align: U32PowerOf2::try_from(align as u32) + .expect("aligns are power of 2s in rust") + .into(), + ty: Vector::new(scalar, recipe::Len::X1), + debug_is_atomic: false, } - // fn gpu_type_layout() -> Option> { - // Some(Ok(vec::::gpu_layout())) - // } + .into() } +impl CpuLayout for f32 { + fn cpu_layout() -> TypeLayout { cpu_layout_of_scalar(ScalarType::F32) } +} impl CpuLayout for f64 { - fn cpu_layout() -> TypeLayout { - TypeLayout::from_rust_sized::(TypeLayoutSemantics::Vector(ir::Len::X1, Self::SCALAR_TYPE)) - } - // fn gpu_type_layout() -> Option> { - // Some(Ok(vec::::gpu_layout())) - // } + fn cpu_layout() -> TypeLayout { cpu_layout_of_scalar(ScalarType::F64) } } - impl CpuLayout for u32 { - fn cpu_layout() -> TypeLayout { - TypeLayout::from_rust_sized::(TypeLayoutSemantics::Vector(ir::Len::X1, Self::SCALAR_TYPE)) - } - // fn gpu_type_layout() -> Option> { - // Some(Ok(vec::::gpu_layout())) - // } + fn cpu_layout() -> TypeLayout { cpu_layout_of_scalar(ScalarType::U32) } } - impl CpuLayout 
for i32 { - fn cpu_layout() -> TypeLayout { - TypeLayout::from_rust_sized::(TypeLayoutSemantics::Vector(ir::Len::X1, Self::SCALAR_TYPE)) - } - // fn gpu_type_layout() -> Option> { - // Some(Ok(vec::::gpu_layout())) - // } + fn cpu_layout() -> TypeLayout { cpu_layout_of_scalar(ScalarType::I32) } } /// (no documentation yet) pub trait CpuAligned { /// (no documentation yet) - const CPU_ALIGNMENT: usize; + const CPU_ALIGNMENT: U32PowerOf2; /// (no documentation yet) const CPU_SIZE: Option; /// (no documentation yet) - fn alignment() -> usize; + fn alignment() -> U32PowerOf2; } impl CpuAligned for T { - const CPU_ALIGNMENT: usize = std::mem::align_of::(); + const CPU_ALIGNMENT: U32PowerOf2 = + U32PowerOf2::try_from_usize(std::mem::align_of::()).expect("alignment of types is always a power of 2"); const CPU_SIZE: Option = Some(std::mem::size_of::()); - fn alignment() -> usize { std::mem::align_of::() } + fn alignment() -> U32PowerOf2 { Self::CPU_ALIGNMENT } } impl CpuAligned for [T] { // must be same as align of `T` since `std::slice::from_ref` and `&slice[0]` exist - const CPU_ALIGNMENT: usize = std::mem::align_of::(); + const CPU_ALIGNMENT: U32PowerOf2 = T::CPU_ALIGNMENT; const CPU_SIZE: Option = None; - fn alignment() -> usize { std::mem::align_of_val::<[T]>(&[]) } + fn alignment() -> U32PowerOf2 { + U32PowerOf2::try_from_usize(std::mem::align_of_val::<[T]>(&[])) + .expect("alignment of types is always a power of 2") + } } impl CpuLayout for [T; N] { fn cpu_layout() -> TypeLayout { - let align = ::alignment() as u64; - - TypeLayout::new( - Some(std::mem::size_of::() as u64), - align, - TypeLayoutSemantics::Array( - Rc::new(ElementLayout { - byte_stride: std::mem::size_of::() as u64, - ty: T::cpu_layout(), - }), - Some(u32::try_from(N).expect("arrays larger than u32::MAX elements are not supported by WGSL")), - ), - ) + ArrayLayout { + byte_size: Some(std::mem::size_of::() as u64), + align: ::alignment().into(), + byte_stride: std::mem::size_of::() as u64, + 
element_ty: T::cpu_layout(), + len: Some(u32::try_from(N).expect("arrays larger than u32::MAX elements are not supported by WGSL")), + } + .into() } // fn gpu_type_layout() -> Option> { @@ -803,19 +823,14 @@ impl CpuLayout for [T; N] { impl CpuLayout for [T] { fn cpu_layout() -> TypeLayout { - let align = ::alignment() as u64; - - TypeLayout::new( - None, - align, - TypeLayoutSemantics::Array( - Rc::new(ElementLayout { - byte_stride: std::mem::size_of::() as u64, - ty: T::cpu_layout(), - }), - None, - ), - ) + ArrayLayout { + byte_size: None, + align: ::alignment().into(), + byte_stride: std::mem::size_of::() as u64, + element_ty: T::cpu_layout(), + len: None, + } + .into() } // TODO(release) remove if we decide to not support this function on the `CpuLayout` trait diff --git a/shame/src/frontend/rust_types/mat.rs b/shame/src/frontend/rust_types/mat.rs index 08c72ab..8ccc8d0 100644 --- a/shame/src/frontend/rust_types/mat.rs +++ b/shame/src/frontend/rust_types/mat.rs @@ -7,7 +7,7 @@ use super::{ mem::AddressSpace, reference::{AccessMode, AccessModeReadable}, scalar_type::{ScalarType, ScalarTypeFp}, - type_layout::{TypeLayout, TypeLayoutRules}, + type_layout::{self, recipe, TypeLayout}, type_traits::{ BindingArgs, EmptyRefFields, GpuAligned, GpuSized, GpuStore, GpuStoreImplCategory, NoAtomics, NoBools, NoHandles, @@ -51,7 +51,14 @@ impl Default for mat { } impl GpuLayout for mat { - fn gpu_layout() -> TypeLayout { TypeLayout::from_sized_ty(TypeLayoutRules::Wgsl, &::sized_ty()) } + fn layout_recipe() -> recipe::TypeLayoutRecipe { + recipe::Matrix { + columns: C::LEN2, + rows: R::LEN2, + scalar: T::SCALAR_TYPE_FP, + } + .into() + } fn cpu_type_name_and_layout() -> Option, TypeLayout), ArrayElementsUnsizedError>> { None } } @@ -181,7 +188,7 @@ impl mat { /// sm::vec!(2.0, 5.0) // row 3 /// ]) /// - /// let m: f32x3x2 = sm::mat::new([ + /// let m: f32x3x2 = sm::mat::new([ /// 0.0, 3.0 // column 0, becomes row 0 /// 1.0, 4.0 // column 1, becomes row 1 /// 2.0, 5.0 // 
column 2, becomes row 2 @@ -235,7 +242,7 @@ impl mat { /// sm::vec!(2.0, 5.0) // row 3 /// ]) /// - /// let m: f32x3x2 = sm::mat::new([ + /// let m: f32x3x2 = sm::mat::new([ /// 0.0, 3.0 // column 0, becomes row 0 /// 1.0, 4.0 // column 1, becomes row 1 /// 2.0, 5.0 // column 2, becomes row 2 @@ -496,7 +503,7 @@ impl mat { /// sm::vec!(2.0, 5.0) // row 3 /// ]) /// - /// let m: f32x3x2 = sm::mat::new([ + /// let m: f32x3x2 = sm::mat::new([ /// 0.0, 3.0 // column 0, becomes row 0 /// 1.0, 4.0 // column 1, becomes row 1 /// 2.0, 5.0 // column 2, becomes row 2 diff --git a/shame/src/frontend/rust_types/mem.rs b/shame/src/frontend/rust_types/mem.rs index 2050947..f9fb2f6 100644 --- a/shame/src/frontend/rust_types/mem.rs +++ b/shame/src/frontend/rust_types/mem.rs @@ -57,12 +57,12 @@ pub struct PushConstant(()); /// the only source of uniform values, and reading in it does not necessarily /// produce uniform values (i.e. during array lookup, if the array index is /// not uniform). -#[derive(Clone, Copy)] +#[derive(Debug, Clone, Copy)] pub struct Uniform(()); /// ### the address space of storage buffer bindings /// /// readable and writeable, visible across all threads of a dispatch/drawcall -#[derive(Clone, Copy)] +#[derive(Debug, Clone, Copy)] pub struct Storage(()); /// ### the address space of texture-/sampler bindings /// diff --git a/shame/src/frontend/rust_types/mod.rs b/shame/src/frontend/rust_types/mod.rs index 8a79162..598a562 100644 --- a/shame/src/frontend/rust_types/mod.rs +++ b/shame/src/frontend/rust_types/mod.rs @@ -51,7 +51,7 @@ pub mod vec_range_traits; //[old-doc] - `StorageTexture<…>` /// (no documentation yet) /// -pub trait GpuType: ToGpuType + From + AsAny + Clone + GpuLayout { +pub trait GpuType: ToGpuType + From + AsAny + Clone { /// (no documentation yet) #[doc(hidden)] // returns a type from the `any` api fn ty() -> ir::Type; @@ -62,7 +62,9 @@ pub trait GpuType: ToGpuType + From + AsAny + Clone + GpuLayout /// (no documentation yet) 
#[track_caller] - fn from_any(any: Any) -> Self { typecheck_downcast(any, Self::ty(), Self::from_any_unchecked) } + fn from_any(any: Any) -> Self { + typecheck_downcast(any, Self::ty(), Self::from_any_unchecked) + } } /// (no documentation yet) @@ -113,11 +115,15 @@ pub trait ToGpuType { /// (no documentation yet) #[track_caller] - fn to_any(&self) -> Any { self.to_gpu().as_any() } + fn to_any(&self) -> Any { + self.to_gpu().as_any() + } /// (no documentation yet) #[track_caller] - fn as_gpu_type_ref(&self) -> Option<&Self::Gpu> { None } + fn as_gpu_type_ref(&self) -> Option<&Self::Gpu> { + None + } /// convenience function for [`shame::Cell::new(...)`] /// diff --git a/shame/src/frontend/rust_types/packed_vec.rs b/shame/src/frontend/rust_types/packed_vec.rs index 7331f5f..6bd79c5 100644 --- a/shame/src/frontend/rust_types/packed_vec.rs +++ b/shame/src/frontend/rust_types/packed_vec.rs @@ -7,7 +7,7 @@ use std::{ use crate::{ any::{AsAny, DataPackingFn}, common::floating_point::f16, - f32x2, f32x4, i32x4, u32x1, u32x4, + f32x2, f32x4, gpu_layout, i32x4, u32x1, u32x4, }; use crate::frontend::rust_types::len::{x1, x2, x3, x4}; use crate::frontend::rust_types::vec::vec; @@ -22,7 +22,7 @@ use super::{ layout_traits::{from_single_any, ArrayElementsUnsizedError, FromAnys, GpuLayout}, len::LenEven, scalar_type::ScalarType, - type_layout::{TypeLayout, TypeLayoutRules, TypeLayoutSemantics}, + type_layout::{self, recipe, Repr, TypeLayout}, type_traits::{GpuAligned, GpuSized, NoAtomics, NoBools, NoHandles, VertexAttribute}, vec::IsVec, GpuType, @@ -104,7 +104,7 @@ impl PackedVec { } } -fn get_type_description() -> PackedVector { +pub(crate) fn get_type_description() -> PackedVector { PackedVector { len: L::LEN_EVEN, bits_per_component: T::BITS_PER_COMPONENT, @@ -132,20 +132,17 @@ impl NoHandles for PackedVec {} impl NoAtomics for PackedVec {} impl GpuLayout for PackedVec { - fn gpu_layout() -> TypeLayout { - let packed_vec = get_type_description::(); - TypeLayout::new( - 
Some(u8::from(packed_vec.byte_size()) as u64), - packed_vec.align(), - TypeLayoutSemantics::PackedVector(get_type_description::()), - ) + fn layout_recipe() -> recipe::TypeLayoutRecipe { + recipe::PackedVector { + scalar_type: T::SCALAR_TYPE, + bits_per_component: T::BITS_PER_COMPONENT, + len: L::LEN_EVEN, + } + .into() } fn cpu_type_name_and_layout() -> Option, TypeLayout), ArrayElementsUnsizedError>> { - let sized_ty = Self::sized_ty_equivalent(); - let name = sized_ty.to_string().into(); - let layout = TypeLayout::from_sized_ty(TypeLayoutRules::Wgsl, &sized_ty); - Some(Ok((name, layout))) + None } } @@ -162,7 +159,7 @@ impl From for PackedVec { let inner = Context::try_with(call_info!(), |ctx| { let err = |ty| { ctx.push_error_get_invalid_any( - FrontendError::InvalidDowncastToNonShaderType(ty, Self::gpu_layout()).into(), + FrontendError::InvalidDowncastToNonShaderType(ty, gpu_layout::()).into(), ) }; match any.ty() { diff --git a/shame/src/frontend/rust_types/struct_.rs b/shame/src/frontend/rust_types/struct_.rs index f5615dc..ed9f9ff 100644 --- a/shame/src/frontend/rust_types/struct_.rs +++ b/shame/src/frontend/rust_types/struct_.rs @@ -22,14 +22,13 @@ use std::{ }; use super::layout_traits::{GetAllFields, GpuLayout}; -use super::type_layout::TypeLayout; +use super::type_layout::{self, recipe, TypeLayout}; use super::type_traits::{GpuAligned, GpuSized, GpuStore, GpuStoreImplCategory, NoBools}; use super::{ error::FrontendError, layout_traits::{ArrayElementsUnsizedError, FromAnys}, mem::AddressSpace, reference::{AccessMode, AccessModeReadable}, - type_layout::TypeLayoutSemantics, type_traits::{BindingArgs, NoAtomics, NoHandles}, typecheck_downcast, AsAny, }; @@ -134,8 +133,8 @@ impl Deref for Struct { fn deref(&self) -> &Self::Target { &self.fields } } -impl GpuLayout for Struct { - fn gpu_layout() -> TypeLayout { T::gpu_layout() } +impl GpuLayout for Struct { + fn layout_recipe() -> recipe::TypeLayoutRecipe { T::layout_recipe() } fn cpu_type_name_and_layout() 
-> Option, TypeLayout), ArrayElementsUnsizedError>> { T::cpu_type_name_and_layout().map(|x| x.map(|(name, l)| (format!("Struct<{name}>").into(), l))) diff --git a/shame/src/frontend/rust_types/type_layout.rs b/shame/src/frontend/rust_types/type_layout.rs deleted file mode 100644 index 14811ce..0000000 --- a/shame/src/frontend/rust_types/type_layout.rs +++ /dev/null @@ -1,906 +0,0 @@ -use std::{ - fmt::{Debug, Display, Write}, - num::NonZeroU32, - ops::Deref, - rc::Rc, -}; - -use crate::{ - call_info, - common::{ - ignore_eq::{IgnoreInEqOrdHash, InEqOrd}, - prettify::set_color, - }, - ir::{ - self, - ir_type::{ - align_of_array, byte_size_of_array, round_up, stride_of_array, AlignedType, CanonName, LenEven, - PackedVectorByteSize, ScalarTypeFp, ScalarTypeInteger, - }, - recording::Context, - Len, SizedType, Type, - }, -}; -use thiserror::Error; - -/// The type contained in the bytes of a `TypeLayout` -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum TypeLayoutSemantics { - /// `vec` - Vector(ir::Len, ir::ScalarType), - /// special compressed vectors for vertex attribute types - /// - /// see the [`crate::packed`] module - PackedVector(ir::PackedVector), - /// `mat`, first `Len2` is cols, 2nd `Len2` is rows - Matrix(ir::Len2, ir::Len2, ScalarTypeFp), - /// `Array` and `Array>` - Array(Rc, Option), // not NonZeroU32, since for rust `CpuLayout`s the array size may be 0. - /// structures which may be empty and may have an unsized last field - Structure(Rc), -} - -/// The memory layout of a type. -/// -/// This models only the layout, not other characteristics of the types. -/// For example an `Atomic>` is treated like a regular `vec` layout wise. -/// -/// The `PartialEq + Eq` implementation of `TypeLayout` is designed to answer the question -/// "do these two types have the same layout" so that uploading a type to the gpu -/// will result in no memory errors. 
-/// -/// a layout comparison looks like this: -/// ``` -/// assert!(f32::cpu_layout() == vec::gpu_layout()); -/// // or, more explicitly -/// assert_eq!( -/// ::cpu_layout(), -/// as GpuLayout>::gpu_layout(), -/// ); -/// ``` -/// -#[derive(Clone, PartialEq, Eq, Hash)] -pub struct TypeLayout { - /// size in bytes (Some), or unsized (None) - pub byte_size: Option, - /// the byte alignment - /// - /// top level alignment is not considered relevant in some checks, but relevant in others (vertex array elements) - pub byte_align: IgnoreInEqOrdHash, - /// the type contained in the bytes of this type layout - pub kind: TypeLayoutSemantics, -} - -impl TypeLayout { - pub(crate) fn new(byte_size: Option, byte_align: u64, kind: TypeLayoutSemantics) -> Self { - Self { - byte_size, - byte_align: byte_align.into(), - kind, - } - } - - pub(crate) fn from_rust_sized(kind: TypeLayoutSemantics) -> Self { - Self::new(Some(size_of::() as u64), align_of::() as u64, kind) - } - - fn first_line_of_display_with_ellipsis(&self) -> String { - let string = format!("{}", self); - string.split_once('\n').map(|(s, _)| format!("{s}…")).unwrap_or(string) - } - - /// a short name for this `TypeLayout`, useful for printing inline - pub fn short_name(&self) -> String { - match &self.kind { - TypeLayoutSemantics::Vector { .. } | - TypeLayoutSemantics::PackedVector { .. } | - TypeLayoutSemantics::Matrix { .. 
} => format!("{}", self), - TypeLayoutSemantics::Array(element_layout, n) => match n { - Some(n) => format!("array<{}, {n}>", element_layout.ty.short_name()), - None => format!("array<{}, runtime-sized>", element_layout.ty.short_name()), - }, - TypeLayoutSemantics::Structure(s) => s.name.to_string(), - } - } -} - -/// a sized or unsized struct type with 0 or more fields -#[allow(missing_docs)] -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct StructLayout { - pub name: IgnoreInEqOrdHash, - pub fields: Vec, -} - -impl StructLayout { - /// this exists, because if in the future a refactor happens that separates - /// fields into sized and unsized fields, the intention of this function is - /// clear - fn all_fields(&self) -> &[FieldLayoutWithOffset] { &self.fields } -} - -#[allow(missing_docs)] -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct FieldLayoutWithOffset { - pub field: FieldLayout, - pub rel_byte_offset: u64, // this being relative is used in TypeLayout::byte_size -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct ElementLayout { - pub byte_stride: u64, - pub ty: TypeLayout, -} - -/// the layout rules used when calculating the byte offsets and alignment of a type -#[derive(Debug, Clone, Copy)] -pub enum TypeLayoutRules { - /// wgsl type layout rules, see https://www.w3.org/TR/WGSL/#memory-layouts - Wgsl, - // reprC, - // Std140, - // Std430, - // Scalar, -} - -#[allow(missing_docs)] -#[derive(Error, Debug, Clone)] -pub enum TypeLayoutError { - #[error("An array cannot contain elements of an unsized type {elements}")] - ArrayOfUnsizedElements { elements: TypeLayout }, - #[error("the type `{0}` has no defined {1:?} layout in shaders")] - LayoutUndefined(Type, TypeLayoutRules), - #[error("in `{parent_name}` at field `{field_name}`: {error}")] - AtField { - parent_name: CanonName, - field_name: CanonName, - error: Rc, - }, - #[error("in element of array: {0}")] - InArrayElement(Rc), -} - -#[allow(missing_docs)] // TODO(docs) low 
priority docs, add after release -impl TypeLayout { - pub fn from_ty(rules: TypeLayoutRules, ty: &ir::Type) -> Result { - match ty { - Type::Unit | Type::Ptr(_, _, _) | Type::Ref(_, _, _) => { - Err(TypeLayoutError::LayoutUndefined(ty.clone(), rules)) - } - Type::Store(ty) => Self::from_store_ty(rules, ty), - } - } - - pub fn from_store_ty(rules: TypeLayoutRules, ty: &ir::StoreType) -> Result { - match ty { - ir::StoreType::Sized(sized) => Ok(TypeLayout::from_sized_ty(rules, sized)), - ir::StoreType::Handle(handle) => Err(TypeLayoutError::LayoutUndefined(ir::Type::Store(ty.clone()), rules)), - ir::StoreType::RuntimeSizedArray(element) => Ok(Self::from_array(rules, element, None)), - ir::StoreType::BufferBlock(s) => Ok(Self::from_struct(rules, s)), - } - } - - pub fn from_sized_ty(rules: TypeLayoutRules, ty: &ir::SizedType) -> TypeLayout { - pub use TypeLayoutSemantics as Sem; - let size = ty.byte_size(); - let align = ty.align(); - match ty { - ir::SizedType::Vector(l, t) => - // we treat bool as a type that has a layout to allow for an - // `Eq` operator on `TypeLayout` that behaves intuitively. - // The layout of `bool`s is not actually observable in any part of the api. 
- { - TypeLayout::new(Some(size), align, Sem::Vector(*l, *t)) - } - ir::SizedType::Matrix(c, r, t) => TypeLayout::new(Some(size), align, Sem::Matrix(*c, *r, *t)), - ir::SizedType::Array(sized, l) => Self::from_array(rules, sized, Some(*l)), - ir::SizedType::Atomic(t) => Self::from_sized_ty(rules, &ir::SizedType::Vector(ir::Len::X1, (*t).into())), - ir::SizedType::Structure(s) => Self::from_struct(rules, s), - } - } - - pub fn from_array(rules: TypeLayoutRules, element: &ir::SizedType, len: Option) -> TypeLayout { - TypeLayout::new( - len.map(|n| byte_size_of_array(element, n)), - align_of_array(element), - TypeLayoutSemantics::Array( - Rc::new(ElementLayout { - byte_stride: match rules { - TypeLayoutRules::Wgsl => stride_of_array(element), - }, - ty: Self::from_sized_ty(rules, element), - }), - len.map(NonZeroU32::get), - ), - ) - } - - pub fn from_struct(rules: TypeLayoutRules, s: &ir::Struct) -> TypeLayout { - let (size, align, struct_layout) = StructLayout::from_ir_struct(rules, s); - TypeLayout::new(size, align, TypeLayoutSemantics::Structure(Rc::new(struct_layout))) - } - - pub fn struct_from_parts( - rules: TypeLayoutRules, - packed: bool, - name: CanonName, - fields: impl ExactSizeIterator, - ) -> Result { - let (byte_size, byte_align, struct_) = StructLayout::new(rules, packed, name, fields)?; - let layout = TypeLayout::new(byte_size, byte_align, TypeLayoutSemantics::Structure(Rc::new(struct_))); - Ok(layout) - } - - pub fn from_aligned_type(rules: TypeLayoutRules, ty: &AlignedType) -> TypeLayout { - match ty { - AlignedType::Sized(sized) => Self::from_sized_ty(rules, sized), - AlignedType::RuntimeSizedArray(element) => Self::from_array(rules, element, None), - } - } - - pub(crate) fn writeln(&self, indent: &str, colored: bool, f: &mut W) -> std::fmt::Result { - self.write(indent, colored, f)?; - writeln!(f) - } - - //TODO(low prio) try to figure out a cleaner way of writing these. 
- pub(crate) fn write(&self, indent: &str, colored: bool, f: &mut W) -> std::fmt::Result { - let tab = " "; - let use_256_color_mode = false; - let color = |f_: &mut W, hex| match colored { - true => set_color(f_, Some(hex), use_256_color_mode), - false => Ok(()), - }; - let reset = |f_: &mut W| match colored { - true => set_color(f_, None, use_256_color_mode), - false => Ok(()), - }; - - use TypeLayoutSemantics as Sem; - - match &self.kind { - Sem::Vector(l, t) => match l { - Len::X1 => write!(f, "{t}")?, - l => write!(f, "{t}x{}", u64::from(*l))?, - }, - Sem::PackedVector(c) => write!(f, "{}", c)?, - Sem::Matrix(c, r, t) => write!(f, "{}", ir::SizedType::Matrix(*c, *r, *t))?, - Sem::Array(t, n) => { - let stride = t.byte_stride; - write!(f, "array<")?; - t.ty.write(&(indent.to_string() + tab), colored, f)?; - if let Some(n) = n { - write!(f, ", {n}")?; - } - write!(f, "> stride={stride}")?; - } - Sem::Structure(s) => { - writeln!(f, "struct {} {{", s.name)?; - { - let indent = indent.to_string() + tab; - for field in &s.fields { - let offset = field.rel_byte_offset; - let field = &field.field; - write!(f, "{indent}{offset:3} {}: ", field.name)?; - field.ty.write(&(indent.to_string() + tab), colored, f)?; - if let Some(size) = field.ty.byte_size { - let size = size.max(field.custom_min_size.unwrap_or(0)); - write!(f, " size={size}")?; - } else { - write!(f, " size=?")?; - } - writeln!(f, ",")?; - } - } - write!(f, "{indent}}}")?; - write!(f, " align={}", self.byte_align)?; - if let Some(size) = self.byte_size { - write!(f, " size={size}")?; - } else { - write!(f, " size=?")?; - } - } - }; - Ok(()) - } - - pub fn align(&self) -> u64 { *self.byte_align } - - pub fn byte_size(&self) -> Option { self.byte_size } -} - -#[allow(missing_docs)] -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct FieldLayout { - pub name: CanonName, - pub custom_min_size: IgnoreInEqOrdHash>, // whether size/align is custom doesn't matter for the layout equality. 
- pub custom_min_align: IgnoreInEqOrdHash>, - pub ty: TypeLayout, -} - -impl FieldLayout { - fn byte_size(&self) -> Option { - self.ty - .byte_size() - .map(|byte_size| byte_size.max(self.custom_min_size.unwrap_or(0))) - } - fn align(&self) -> u64 { self.ty.align().max(self.custom_min_align.map(u64::from).unwrap_or(1)) } -} - -#[allow(missing_docs)] -#[derive(Error, Debug)] -pub enum StructLayoutError { - #[error( - "field #{unsized_field_index} in struct `{struct_name}` with {num_fields} is unsized. Only the last field may be unsized." - )] - UnsizedFieldMustBeLast { - struct_name: CanonName, - unsized_field_index: usize, - num_fields: usize, - }, -} - -impl StructLayout { - /// returns a `(byte_size, byte_alignment, struct_layout)` tuple or an error - /// - /// this was created for the `#[derive(GpuLayout)]` macro to support the - /// non-GpuType `PackedVec` for gpu_repr(packed) and non-packed. - /// - // TODO(low prio) find a way to merge all struct layout calculation functions in this codebase. This is very redundand. 
- pub(crate) fn new( - rules: TypeLayoutRules, - packed: bool, - name: CanonName, - fields: impl ExactSizeIterator, - ) -> Result<(Option, u64, StructLayout), StructLayoutError> { - let mut total_byte_size = None; - let mut total_align = 1; - let num_fields = fields.len(); - let struct_layout = StructLayout { - name: name.clone().into(), - fields: { - let mut offset_so_far = 0; - let mut fields_with_offset = Vec::new(); - for (i, field) in fields.enumerate() { - let is_last = i + 1 == num_fields; - fields_with_offset.push(FieldLayoutWithOffset { - field: field.clone(), - rel_byte_offset: match rules { - TypeLayoutRules::Wgsl => { - let field_offset = match (packed, *field.custom_min_align) { - (true, None) => offset_so_far, - (true, Some(custom_align)) => round_up(custom_align, offset_so_far), - (false, _) => round_up(field.align(), offset_so_far), - }; - match (field.byte_size(), is_last) { - (Some(field_size), _) => { - offset_so_far = field_offset + field_size; - Ok(()) - } - (None, true) => Ok(()), - (None, false) => Err(StructLayoutError::UnsizedFieldMustBeLast { - struct_name: name.clone(), - unsized_field_index: i, - num_fields, - }), - }?; - field_offset - } - }, - }); - total_align = total_align.max(field.align()); - if is_last { - // wgsl spec: - // roundUp(AlignOf(S), justPastLastMember) - // where justPastLastMember = OffsetOfMember(S,N) + SizeOfMember(S,N) - - // if the last field size is None (= unsized), just_past_last is None (= unsized) - let just_past_last = field.byte_size().map(|_| offset_so_far); - total_byte_size = just_past_last.map(|just_past_last| round_up(total_align, just_past_last)); - } - } - fields_with_offset - }, - }; - Ok((total_byte_size, total_align, struct_layout)) - } - - /// returns a `(byte_size, byte_alignment, struct_layout)` tuple - #[doc(hidden)] - pub fn from_ir_struct(rules: TypeLayoutRules, s: &ir::Struct) -> (Option, u64, StructLayout) { - let mut total_byte_size = None; - let struct_layout = StructLayout { - name: 
s.name().clone().into(), - fields: { - let mut offset = 0; - let mut fields = Vec::new(); - for field in s.sized_fields() { - fields.push(FieldLayoutWithOffset { - field: FieldLayout { - name: field.name.clone(), - ty: TypeLayout::from_sized_ty(rules, &field.ty), - custom_min_size: field.custom_min_size.into(), - custom_min_align: field.custom_min_align.map(u64::from).into(), - }, - rel_byte_offset: match rules { - TypeLayoutRules::Wgsl => { - let rel_byte_offset = round_up(field.align(), offset); - offset = rel_byte_offset + field.byte_size(); - rel_byte_offset - } - }, - }) - } - if let Some(unsized_array) = s.last_unsized_field() { - fields.push(FieldLayoutWithOffset { - field: FieldLayout { - name: unsized_array.name.clone(), - custom_min_align: unsized_array.custom_min_align.map(u64::from).into(), - custom_min_size: None.into(), - ty: TypeLayout::from_array(rules, &unsized_array.element_ty, None), - }, - rel_byte_offset: round_up(unsized_array.align(), offset), - }) - } else { - total_byte_size = Some(s.min_byte_size()); - } - fields - }, - }; - (total_byte_size, s.align(), struct_layout) - } -} - -impl Display for TypeLayout { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let colored = Context::try_with(call_info!(), |ctx| ctx.settings().colored_error_messages).unwrap_or(false); - self.write("", colored, f) - } -} - -impl Debug for TypeLayout { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.write("", false, f) } -} - -#[derive(Clone)] -pub struct LayoutMismatch { - /// 2 (name, layout) pairs - layouts: [(String, TypeLayout); 2], - colored_error: bool, -} - - -impl std::fmt::Debug for LayoutMismatch { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self) // use Display - } -} - -impl LayoutMismatch { - fn pad_width(name_a: &str, name_b: &str) -> usize { name_a.chars().count().max(name_b.chars().count()) + SEP.len() } -} - -impl Display for LayoutMismatch { - fn 
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let colored = self.colored_error; - let [(a_name, a), (b_name, b)] = &self.layouts; - write!(f, "{:width$}", ' ', width = Self::pad_width(a_name, b_name))?; - let layouts = [(a_name.as_str(), a), (b_name.as_str(), b)]; - match LayoutMismatch::write("", layouts, colored, f) { - Err(MismatchWasFound) => Ok(()), - Ok(KeepWriting) => { - writeln!( - f, - "" - )?; - writeln!(f, "the full type layouts in question are:")?; - for (name, layout) in layouts { - writeln!(f, "`{}`:", name)?; - writeln!(f, "{}", layout)?; - } - Ok(()) - } - } - } -} - -/// layout mismatch diff name separator -/// used in the display impl of LayoutMismatch to show the actual place where the layouts mismatch -/// ``` -/// layout mismatch: -/// cpu{SEP} f32 -/// gpu{SEP} i32 -/// ``` -const SEP: &str = ": "; - -/// whether the mismatching part of the TypeLayouts in a LayoutMismatch was already expressed via writes. -/// indicates that the `write` function should stop writing. -pub(crate) struct MismatchWasFound; -pub(crate) struct KeepWriting; - -impl LayoutMismatch { - //TODO(low prio) try to figure out a cleaner way of writing these. - - /// this function uses the `Err(MismatchWasFound)` to halt traversing the typelayout. - /// It does not constitute an error of this function, it is just so the ? operator can be used to propagate the abort. 
- #[allow(clippy::needless_return)] - pub(crate) fn write( - indent: &str, - layouts: [(&str, &TypeLayout); 2], - colored: bool, - f: &mut W, - ) -> Result { - let tab = " "; - let [(a_name, a), (b_name, b)] = layouts; - - if a == b { - a.write(indent, colored, f); - return Ok(KeepWriting); - } - - let use_256_color_mode = false; - let hex_color = |f_: &mut W, hex| match colored { - true => set_color(f_, Some(hex), use_256_color_mode), - false => Ok(()), - }; - let color_reset = |f_: &mut W| match colored { - true => set_color(f_, None, use_256_color_mode), - false => Ok(()), - }; - - let color_a_hex = "#DF5853"; - let color_b_hex = "#9A639C"; - let color_a = |f| hex_color(f, color_a_hex); - let color_b = |f| hex_color(f, color_b_hex); - - let pad_width = Self::pad_width(a_name, b_name); - - use TypeLayoutSemantics as S; - match (&a.kind, &b.kind) { - (S::Structure(sa), S::Structure(sb)) => { - let max_fields = sa.all_fields().len().max(sb.all_fields().len()); - { - write!(f, "struct "); - hex_color(f, color_a_hex); - write!(f, "{}", sa.name); - color_reset(f); - write!(f, " / "); - hex_color(f, color_b_hex); - write!(f, "{}", sb.name); - color_reset(f); - writeln!(f, " {{"); - } - - let mut sa_fields = sa.all_fields().iter(); - let mut sb_fields = sb.all_fields().iter(); - - loop { - //TODO(low prio) get a hold of the code duplication here - match (sa_fields.next(), sb_fields.next()) { - (Some(a_field), Some(b_field)) => { - let offsets_match = a_field.rel_byte_offset == b_field.rel_byte_offset; - let types_match = a_field.field.ty == b_field.field.ty; - if !offsets_match && types_match { - // only write this mismatch if the types are also the same, otherwise display the detailed type mismatch further below - let a_ty_string = a_field.field.ty.first_line_of_display_with_ellipsis(); - let b_ty_string = b_field.field.ty.first_line_of_display_with_ellipsis(); - color_a(f); - writeln!( - f, - "{a_name}{SEP}{indent}{:3} {}: {a_ty_string} align={}", - 
a_field.rel_byte_offset, - a_field.field.name, - a_field.field.align() - ); - color_b(f); - writeln!( - f, - "{b_name}{SEP}{indent}{:3} {}: {b_ty_string} align={}", - b_field.rel_byte_offset, - b_field.field.name, - b_field.field.align() - ); - color_reset(f); - writeln!( - f, - "field offset is different on {a_name} ({}) and {b_name} ({}).", - a_field.rel_byte_offset, b_field.rel_byte_offset - ); - return Err(MismatchWasFound); - } - let offset = a_field.rel_byte_offset; - let a_field = &a_field.field; - let b_field = &b_field.field; - let field = &a_field; - - if offsets_match { - write!(f, "{:width$}{indent}{offset:3} ", ' ', width = pad_width); - } else { - write!(f, "{:width$}{indent} ? ", ' ', width = pad_width); - } - if a_field.name != b_field.name { - writeln!(f); - color_a(f); - writeln!(f, "{a_name}{SEP}{indent} {}: …", a_field.name); - color_b(f); - writeln!(f, "{b_name}{SEP}{indent} {}: …", b_field.name); - color_reset(f); - writeln!( - f, - "identifier mismatch, either\nfield '{}' is missing on {a_name}, or\nfield '{}' is missing on {b_name}.", - b_field.name, a_field.name - ); - return Err(MismatchWasFound); - } - write!(f, "{}: ", field.name); - if a_field.ty != b_field.ty { - Self::write( - &format!("{indent}{tab}"), - [(a_name, &a_field.ty), (b_name, &b_field.ty)], - colored, - f, - )?; - return Err(MismatchWasFound); - } - write!(f, "{}", field.ty.first_line_of_display_with_ellipsis()); - if a_field.byte_size() != b_field.byte_size() { - writeln!(f); - color_a(f); - writeln!( - f, - "{a_name}{SEP}{indent} size={}", - a_field - .byte_size() - .as_ref() - .map(|x| x as &dyn Display) - .unwrap_or(&"?" as _) - ); - color_b(f); - writeln!( - f, - "{b_name}{SEP}{indent} size={}", - b_field - .byte_size() - .as_ref() - .map(|x| x as &dyn Display) - .unwrap_or(&"?" as _) - ); - color_reset(f); - return Err(MismatchWasFound); - } - - write!( - f, - " size={}", - field - .byte_size() - .as_ref() - .map(|x| x as &dyn Display) - .unwrap_or(&"?" 
as _) - ); - writeln!(f, ","); - } - (Some(a_field), None) => { - let offset = a_field.rel_byte_offset; - let a_field = &a_field.field; - let field = &a_field; - color_a(f); - write!(f, "{a_name}{SEP}{indent}{offset:3} "); - write!(f, "{}: ", field.name); - write!(f, "{}", field.ty.first_line_of_display_with_ellipsis()); - write!( - f, - " size={}", - field - .byte_size() - .as_ref() - .map(|x| x as &dyn Display) - .unwrap_or(&"?" as _) - ); - writeln!(f, ","); - color_b(f); - writeln!(f, "{b_name}{SEP}{indent}", a_field.name); - color_reset(f); - return Err(MismatchWasFound); - } - (None, Some(b_field)) => { - let offset = b_field.rel_byte_offset; - let b_field = &b_field.field; - color_a(f); - writeln!(f, "{a_name}{SEP}{indent}", b_field.name); - let field = &b_field; - color_b(f); - write!(f, "{b_name}{SEP}{indent}{offset:3} "); - write!(f, "{}: ", field.name); - write!(f, "{}", field.ty.first_line_of_display_with_ellipsis()); - write!( - f, - " size={}", - field - .byte_size() - .as_ref() - .map(|x| x as &dyn Display) - .unwrap_or(&"?" as _) - ); - writeln!(f, ","); - color_reset(f); - return Err(MismatchWasFound); - } - (None, None) => break, - } - } - - write!(f, "{:width$}{indent}}}", ' ', width = pad_width); - let align_matches = a.align() == b.align(); - let size_matches = a.byte_size() == b.byte_size(); - if !align_matches && size_matches { - writeln!(f); - color_a(f); - writeln!(f, "{a_name}{SEP}{indent}align={}", a.align()); - color_b(f); - writeln!(f, "{b_name}{SEP}{indent}align={}", b.align()); - color_reset(f); - return Err(MismatchWasFound); - } else { - match align_matches { - true => write!(f, " align={}", a.align()), - false => write!(f, " align=?"), - }; - } - if !size_matches { - writeln!(f); - color_a(f); - writeln!( - f, - "{a_name}{SEP}{indent}size={}", - a.byte_size().as_ref().map(|x| x as &dyn Display).unwrap_or(&"?" 
as _) - ); - color_b(f); - writeln!( - f, - "{b_name}{SEP}{indent}size={}", - b.byte_size().as_ref().map(|x| x as &dyn Display).unwrap_or(&"?" as _) - ); - color_reset(f); - return Err(MismatchWasFound); - } else { - match a.byte_size() { - Some(size) => write!(f, " size={size}"), - None => write!(f, " size=?"), - }; - } - // this should never happen, returning Ok(KeepWriting) will trigger the internal error in the Display impl - return Ok(KeepWriting); - } - (S::Array(ta, na), S::Array(tb, nb)) => { - if na != nb { - writeln!(f); - color_a(f); - write!(f, "{a_name}{SEP}"); - - write!( - f, - "array<…, {}>", - match na { - Some(n) => n as &dyn Display, - None => (&"runtime-sized") as &dyn Display, - } - ); - - //a.writeln(indent, colored, f); - writeln!(f); - color_b(f); - write!(f, "{b_name}{SEP}"); - - write!( - f, - "array<…, {}>", - match nb { - Some(n) => n as &dyn Display, - None => (&"runtime-sized") as &dyn Display, - } - ); - - //b.writeln(indent, colored, f); - color_reset(f); - Err(MismatchWasFound) - } else { - write!(f, "array<"); - //ta.ty.write(&(indent.to_string() + tab), colored, f); - - Self::write( - &format!("{indent}{tab}"), - [(a_name, &ta.ty), (b_name, &tb.ty)], - colored, - f, - )?; - - if let Some(na) = na { - write!(f, ", {na}"); - } - write!(f, ">"); - - if ta.byte_stride != tb.byte_stride { - writeln!(f); - color_a(f); - writeln!(f, "{a_name}{SEP}{indent}> stride={}", ta.byte_stride); - color_b(f); - writeln!(f, "{b_name}{SEP}{indent}> stride={}", tb.byte_stride); - color_reset(f); - Err(MismatchWasFound) - } else { - // this should never happen, returning Ok(KeepWriting) will trigger the internal error in the Display impl - write!(f, "> stride={}", ta.byte_stride); - return Ok(KeepWriting); - } - } - } - (S::Vector(na, ta), S::Vector(nb, tb)) => { - writeln!(f); - color_a(f); - write!(f, "{a_name}{SEP}"); - a.writeln(indent, colored, f); - color_b(f); - write!(f, "{b_name}{SEP}"); - b.writeln(indent, colored, f); - color_reset(f); - 
Err(MismatchWasFound) - } - (S::Matrix(c, r, t), S::Matrix(c1, r1, t1)) => { - writeln!(f); - color_a(f); - write!(f, "{a_name}{SEP}"); - a.writeln(indent, colored, f); - color_b(f); - write!(f, "{b_name}{SEP}"); - b.writeln(indent, colored, f); - color_reset(f); - Err(MismatchWasFound) - } - (S::PackedVector(p), S::PackedVector(p1)) => { - writeln!(f); - color_a(f); - write!(f, "{a_name}{SEP}"); - a.writeln(indent, colored, f); - color_b(f); - write!(f, "{b_name}{SEP}"); - b.writeln(indent, colored, f); - color_reset(f); - Err(MismatchWasFound) - } - ( - // its written like this so that exhaustiveness checks lead us to this match statement if a type is added - S::Structure { .. } | S::Array { .. } | S::Vector { .. } | S::Matrix { .. } | S::PackedVector { .. }, - _, - ) => { - // TypeLayoutSemantics mismatch - writeln!(f); - color_a(f); - write!(f, "{a_name}{SEP}"); - a.writeln(indent, colored, f); - color_b(f); - write!(f, "{b_name}{SEP}"); - b.writeln(indent, colored, f); - color_reset(f); - Err(MismatchWasFound) - } - } - } -} - -impl TypeLayout { - /// takes two pairs of `(debug_name, layout)` and compares them for equality. - /// - /// if the two layouts are not equal it uses the debug names in the returned - /// error to tell the two layouts apart. 
- pub(crate) fn check_eq(a: (&str, &TypeLayout), b: (&str, &TypeLayout)) -> Result<(), LayoutMismatch> { - match a.1 == b.1 { - true => Ok(()), - false => Err(LayoutMismatch { - layouts: [(a.0.into(), a.1.clone()), (b.0.into(), b.1.clone())], - colored_error: Context::try_with(call_info!(), |ctx| ctx.settings().colored_error_messages) - .unwrap_or(false), - }), - } - } -} diff --git a/shame/src/frontend/rust_types/type_layout/compatible_with.rs b/shame/src/frontend/rust_types/type_layout/compatible_with.rs new file mode 100644 index 0000000..560dfa3 --- /dev/null +++ b/shame/src/frontend/rust_types/type_layout/compatible_with.rs @@ -0,0 +1,721 @@ +use std::fmt::Write; + +use crate::{ + BufferAddressSpace, Language, TypeLayout, any::layout::StructLayout, call_info, common::{format::display, prettify::{UnwrapDisplayOr, set_color}}, frontend::{ + encoding::buffer::BufferAddressSpaceEnum, + rust_types::type_layout::{ + ArrayLayout, display::{self, LayoutInfoFlags}, eq::{LayoutMismatch, StructMismatch, TopLevelMismatch, try_find_mismatch}, recipe::to_layout::RecipeContains + }, + }, ir::{ir_type::max_u64_po2_dividing, recording::Context}, mem +}; + +use super::{recipe::TypeLayoutRecipe, Repr}; + +pub use mem::{Storage, Uniform}; + +/// `TypeLayoutCompatibleWith` is a [`TypeLayoutRecipe`] with the additional +/// guarantee that the [`TypeLayout`] it produces is compatible with the specified `AddressSpace`. +/// +/// Address space requirements are language specific, which is why `TypeLayoutCompatibleWith` constructors +/// additionally take a [`Language`] parameter. 
+/// +/// To be "compatible with" an address space means that +/// - the recipe is **valid** ([`TypeLayoutRecipe::layout`] succeeds) +/// - the type layout **satisfies the layout requirements** of the address space +/// - the type layout recipe is **representable** in the target language +/// +/// To be representable in a language means that the type layout recipe can be expressed in the +/// language's type system: +/// 1. all types in the recipe can be expressed in the target language (for example `bool` (with guaranteed align=1, size=1) +/// or `PackedVector` can't be expressed in wgsl) +/// 2. the available layout algorithms in the target language can produce the same layout as the one produced by the recipe +/// 3. support for the custom attributes the recipe uses, such as `#[align(N)]` and `#[size(N)]`. +/// Custom attributes may be rejected by the target language itself (NotRepresentable error) +/// or by the layout algorithms specified in the recipe (InvalidRecipe error during `TypeLayoutRecipe -> TypeLayout` conversion). +/// +/// For example for wgsl we have +/// 1. PackedVector can be part of a recipe, but can not be expressed in wgsl, +/// so a recipe containing a PackedVector is not representable in wgsl. +/// 2. Wgsl has only one layout algorithm (`Repr::Wgsl`) - there is no choice between std140 and std430 +/// like in glsl - so to be representable in wgsl the type layout produced by the recipe +/// has to be the same as the one produced by the same recipe but using exclusively the Repr::Wgsl +/// layout algorithm instead of the layout algorithms specified in the recipe. +/// 3. Wgsl only supports custom struct field attributes `#[align(N)]` and `#[size(N)]` currently. 
+#[derive(Debug, Clone)] +pub struct TypeLayoutCompatibleWith { + recipe: TypeLayoutRecipe, + _phantom: std::marker::PhantomData, +} + +impl TypeLayoutCompatibleWith { + pub fn try_from(language: Language, recipe: TypeLayoutRecipe) -> Result { + let address_space = AS::BUFFER_ADDRESS_SPACE; + let layout = recipe.layout(); + + match (language, address_space, layout.byte_size()) { + // Must be sized in wgsl's uniform address space + (Language::Wgsl, BufferAddressSpaceEnum::Uniform, None) => { + return Err(RequirementsNotSatisfied::MustBeSized(recipe, language, address_space).into()); + } + (Language::Wgsl, BufferAddressSpaceEnum::Uniform, Some(_)) | + (Language::Wgsl, BufferAddressSpaceEnum::Storage, _) => {} + } + + // Check that the recipe is representable in the target language. + // See `TypeLayoutCompatibleWith` docs for more details on what it means to be representable. + match (language, address_space) { + (Language::Wgsl, BufferAddressSpaceEnum::Storage | BufferAddressSpaceEnum::Uniform) => { + // We match like this, so that future additions to `RecipeContains` lead us here. 
+ match RecipeContains::CustomFieldAlign { + // supported in wgsl + RecipeContains::CustomFieldAlign | RecipeContains::CustomFieldSize | + // not supported in wgsl + RecipeContains::PackedVector => { + if recipe.contains(RecipeContains::PackedVector) { + return Err(NotRepresentable::MayNotContain( + recipe, + language, + address_space, + RecipeContains::PackedVector, + ) + .into()); + } + } + } + + // Wgsl has only one layout algorithm + let recipe_wgsl = recipe.to_unified_repr(Repr::Wgsl); + let layout_wgsl = recipe_wgsl.layout_with_default_repr(Repr::Wgsl); + if layout != layout_wgsl { + match try_find_mismatch(&layout, &layout_wgsl) { + Some(mismatch) => { + return Err(NotRepresentable::LayoutError(LayoutError { + recipe, + kind: LayoutErrorKind::NotRepresentable, + language, + address_space, + mismatch, + colored: Context::try_with(call_info!(), |ctx| ctx.settings().colored_error_messages) + .unwrap_or(false), + }) + .into()); + } + None => return Err(NotRepresentable::UnknownLayoutError(recipe, address_space).into()), + } + } + } + } + + // Check that the type layout satisfies the requirements of the address space + match (language, address_space) { + (Language::Wgsl, BufferAddressSpaceEnum::Storage) => { + // As long as the recipe is representable in wgsl, it satifies the storage address space requirements. + // We already checked that the recipe is representable in wgsl above. + } + (Language::Wgsl, BufferAddressSpaceEnum::Uniform) => { + // Repr::WgslUniform is made for exactly this purpose: to check that the type layout + // satisfies the requirements of wgsl's uniform address space. 
+ let recipe_unified = recipe.to_unified_repr(Repr::WgslUniform); + let layout_unified = recipe_unified.layout(); + if layout != layout_unified { + match try_find_mismatch(&layout, &layout_unified) { + Some(mismatch) => { + return Err(RequirementsNotSatisfied::LayoutError(LayoutError { + recipe, + kind: LayoutErrorKind::RequirementsNotSatisfied, + language, + address_space, + mismatch, + colored: Context::try_with(call_info!(), |ctx| ctx.settings().colored_error_messages) + .unwrap_or(false), + }) + .into()); + } + None => return Err(RequirementsNotSatisfied::UnknownLayoutError(recipe, address_space).into()), + } + } + } + } + + Ok(Self { + recipe, + _phantom: std::marker::PhantomData, + }) + } +} + +#[derive(thiserror::Error, Debug, Clone)] +pub enum AddressSpaceError { + #[error("{0}")] + NotRepresentable(#[from] NotRepresentable), + #[error("{0}")] + RequirementsNotSatisfied(#[from] RequirementsNotSatisfied), +} + +#[derive(thiserror::Error, Debug, Clone)] +pub enum NotRepresentable { + #[error("{0}")] + LayoutError(LayoutError), + #[error("{0} contains a {3}, which is not allowed in {1}'s {2} address space.")] + MayNotContain(TypeLayoutRecipe, Language, BufferAddressSpaceEnum, RecipeContains), + #[error("Unknown layout error occured for {0} in {1} address space.")] + UnknownLayoutError(TypeLayoutRecipe, BufferAddressSpaceEnum), +} + +#[derive(thiserror::Error, Debug, Clone)] +pub enum RequirementsNotSatisfied { + #[error("{0}")] + LayoutError(LayoutError), + #[error( + "The size of `{0}` on the gpu is not known at compile time. {1}'s {2} address space \ + requires that the size of {0} on the gpu is known at compile time." 
+ )] + MustBeSized(TypeLayoutRecipe, Language, BufferAddressSpaceEnum), + #[error("Unknown layout error occured for {0} in {1} address space.")] + UnknownLayoutError(TypeLayoutRecipe, BufferAddressSpaceEnum), +} + +#[derive(Debug, Clone)] +pub struct LayoutError { + recipe: TypeLayoutRecipe, + mismatch: LayoutMismatch, + + /// Used to adjust the error message + /// to fit `NotRepresentable` or `RequirementsNotSatisfied`. + kind: LayoutErrorKind, + language: Language, + address_space: BufferAddressSpaceEnum, + + colored: bool, +} + +#[derive(Debug, Clone, Copy)] +pub enum LayoutErrorKind { + NotRepresentable, + RequirementsNotSatisfied, +} + +impl LayoutError { + /// Returns the context the error occurred in. The "{language}" in case of a `NotRepresentable` error, + /// or the "{language}'s {address_space}" in case of a `RequirementsNotSatisfied` error. + fn context(&self) -> &'static str { + match self.kind { + LayoutErrorKind::NotRepresentable => match self.language { + Language::Wgsl => "wgsl", + }, + LayoutErrorKind::RequirementsNotSatisfied => match (self.language, self.address_space) { + (Language::Wgsl, BufferAddressSpaceEnum::Storage) => "wgsl's storage address space", + (Language::Wgsl, BufferAddressSpaceEnum::Uniform) => "wgsl's uniform address space", + }, + } + } +} + +impl std::error::Error for LayoutError {} +impl std::fmt::Display for LayoutError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.kind { + LayoutErrorKind::NotRepresentable => { + writeln!(f, "`{}` is not representable in {}:", self.recipe, self.language)?; + } + LayoutErrorKind::RequirementsNotSatisfied => { + writeln!( + f, + "`{}` does not satisfy the layout requirements of {}:", + self.recipe, + self.context() + )?; + } + } + + match &self.mismatch { + LayoutMismatch::TopLevel { + layout_left, + layout_right, + mismatch, + } => write_top_level_mismatch(f, self, layout_left, layout_right, mismatch), + LayoutMismatch::Struct { + struct_left, + 
struct_right, + mismatch, + } => write_struct_mismatch(f, self, struct_left, struct_right, mismatch), + } + } +} + +fn write_top_level_mismatch( + f: &mut std::fmt::Formatter<'_>, + error: &LayoutError, + layout_left: &TypeLayout, + layout_right: &TypeLayout, + mismatch: &TopLevelMismatch, +) -> Result<(), std::fmt::Error> { + match mismatch { + TopLevelMismatch::Type => { + // unreachable: The LayoutError is produced by comparing two semantically equivalent TypeLayouts, so all (nested) types are the same + writeln!(f, "[internal error trying to find top-level type layout mismatch: the types seem identical, please report this error]"); + }, + TopLevelMismatch::ArrayStride { + array_left, + array_right, + } => { + let outer_most_array_has_mismatch = array_left.short_name() == layout_left.short_name(); + if outer_most_array_has_mismatch { + writeln!( + f, + "`{}` requires a stride of {} in {}, but has a stride of {}.", + array_left.short_name(), + array_right.byte_stride, + error.context(), + array_left.byte_stride + )?; + } else { + writeln!( + f, + "`{}` within `{}` requires a stride of {} in {}, but has a stride of {}.", + array_left.short_name(), + layout_left.short_name(), + array_right.byte_stride, + error.context(), + array_left.byte_stride + )?; + } + } + TopLevelMismatch::ByteSize { left, right } => { + let outer_most_array_has_mismatch = left.short_name() == layout_left.short_name(); + + let left_name = left.short_name(); + let within_layout_left = display(|f| match outer_most_array_has_mismatch { + true => Ok(()), // don't mention nesting + false => write!(f, " within `{}`", layout_left.short_name()) + }); + let requires_a_byte_size_of_right_size = display(|f| match right.byte_size() { + Some(size) => write!(f, "requires a byte size of {size}"), + None => write!(f, "must be runtime-sized"), + }); + let constraint = error.context(); + let has_a_byte_size_of_left_size = display(|f| match left.byte_size() { + Some(size) => write!(f, "has a byte size of 
{size}"), + None => write!(f, "is runtime-sized"), + }); + + writeln!( + f, + "`{left_name}`{within_layout_left} {requires_a_byte_size_of_right_size} in {constraint}, but {has_a_byte_size_of_left_size}.", + )?; + } + } + Ok(()) +} + +fn write_struct_mismatch( + f: &mut std::fmt::Formatter<'_>, + error: &LayoutError, + struct_left: &StructLayout, + struct_right: &StructLayout, + mismatch: &StructMismatch, +) -> Result<(), std::fmt::Error> { + match mismatch { + StructMismatch::FieldLayout { + field_index, + field_left, + mismatch: + TopLevelMismatch::ArrayStride { + array_left, + array_right, + }, + .. + } => { + let outer_most_array_has_mismatch = field_left.ty.short_name() == array_left.short_name(); + let layout_info = if outer_most_array_has_mismatch { + LayoutInfoFlags::STRIDE + } else { + // if an inner array has the stride mismatch, showing the outer array's stride could be confusing + LayoutInfoFlags::NONE + }; + + writeln!( + f, + "`{}` in `{}` requires a stride of {} in {}, but has a stride of {}.", + array_left.short_name(), + struct_left.short_name(), + array_right.byte_stride, + error.context(), + array_left.byte_stride + )?; + writeln!(f, "The full layout of `{}` is:\n", struct_left.short_name())?; + write_struct(f, struct_left, layout_info, Some(*field_index), error.colored)?; + } + StructMismatch::FieldLayout { + field_index, + field_left, + field_right, + mismatch: TopLevelMismatch::ByteSize { left, right }, + } => { + let outer_most_array_has_mismatch = field_left.ty.short_name() == left.short_name(); + let layout_info = if outer_most_array_has_mismatch { + LayoutInfoFlags::SIZE + } else { + // if an inner array has the byte size mismatch, showing the outer array's byte size could be confusing + LayoutInfoFlags::NONE + }; + + if !outer_most_array_has_mismatch { + write!(f, "`{}` in field", left.short_name())?; + } else { + write!(f, "Field")?; + } + + let field_left_name = &field_left.name; + let struct_left_name = &struct_left.name; + let 
requires_a_byte_size_of_right_size = display(|f| match field_right.ty.byte_size() { + Some(size) => write!(f, "requires a byte size of {size}"), + None => write!(f, "must be runtime-sized"), + }); + let constraint = error.context(); + let has_a_byte_size_of_left_size = display(|f| match field_left.ty.byte_size() { + Some(size) => write!(f, "has a byte size of {size}"), + None => write!(f, "is actually runtime-sized"), + }); + + writeln!( + f, + " `{field_left_name}` of `{struct_left_name}` {requires_a_byte_size_of_right_size} in {constraint}, but {has_a_byte_size_of_left_size}", + )?; + writeln!(f, "The full layout of `{}` is:", struct_left.short_name())?; + write_struct(f, struct_left, layout_info, Some(*field_index), error.colored)?; + } + StructMismatch::FieldOffset { + field_index, + field_left, + field_right, + } => { + let field_name = &field_left.name; + let offset = field_left.rel_byte_offset; + let expected_align = field_right.ty.align().as_u64(); + let actual_align = max_u64_po2_dividing(field_left.rel_byte_offset); + + writeln!( + f, + "Field `{}` of `{}` needs to be {} byte aligned in {}, but has a byte-offset of {}, which is only {} byte aligned", + field_name, + struct_left.name, + expected_align, + error.context(), + offset, + actual_align + )?; + + writeln!(f, "The full layout of `{}` is:\n", struct_left.short_name())?; + write_struct( + f, + struct_left, + LayoutInfoFlags::OFFSET | LayoutInfoFlags::ALIGN | LayoutInfoFlags::SIZE, + Some(*field_index), + error.colored, + )?; + + writeln!(f, "\nPotential solutions include:")?; + writeln!( + f, + "- add an #[align({})] attribute to the definition of `{}`", + field_right.ty.align().as_u32(), + field_name + )?; + writeln!( + f, + "- increase the offset of `{field_name}` until it is divisible by {expected_align} by making previous fields larger or adding fields before it" + )?; + match (error.kind, error.language, error.address_space) { + (LayoutErrorKind::RequirementsNotSatisfied, Language::Wgsl, 
BufferAddressSpaceEnum::Uniform) => { + writeln!(f, "- use a storage binding instead of a uniform binding")?; + } + ( + LayoutErrorKind::NotRepresentable | LayoutErrorKind::RequirementsNotSatisfied, + Language::Wgsl, + BufferAddressSpaceEnum::Storage | BufferAddressSpaceEnum::Uniform, + ) => {} + } + writeln!(f)?; + + match error.language { + Language::Wgsl => { + match error.address_space { + BufferAddressSpaceEnum::Uniform => writeln!( + f, + "In {}, structs, arrays and array elements must be at least 16 byte aligned.", + error.context() + )?, + BufferAddressSpaceEnum::Storage => {} + } + + match (error.kind, error.address_space) { + ( + LayoutErrorKind::RequirementsNotSatisfied, + BufferAddressSpaceEnum::Uniform | BufferAddressSpaceEnum::Storage, + ) => writeln!( + f, + "More info about the wgsl's {} address space can be found at https://www.w3.org/TR/WGSL/#address-space-layout-constraints", + error.address_space + )?, + (LayoutErrorKind::NotRepresentable, _) => writeln!( + f, + "More info about the wgsl's layout algorithm can be found at https://www.w3.org/TR/WGSL/#alignment-and-size" + )?, + } + } + } + } + StructMismatch::FieldCount | + StructMismatch::FieldName { .. } | + StructMismatch::FieldLayout { + mismatch: TopLevelMismatch::Type, + .. 
+ } => { + // unreachable: The LayoutError is produced by comparing two semantically equivalent TypeLayouts, so all (nested) types are the same + writeln!(f, "[internal error trying to find struct layout mismatch: the structs seem identical, please report this error]"); + } + } + Ok(()) +} + +fn write_struct( + f: &mut W, + s: &StructLayout, + layout_info: LayoutInfoFlags, + highlight_field: Option, + colored: bool, +) -> std::fmt::Result +where + W: Write, +{ + let use_256_color_mode = false; + + let mut writer = s.writer(layout_info); + writer.writeln_header(f); + writer.writeln_struct_declaration(f); + for field_index in 0..s.fields.len() { + if Some(field_index) == highlight_field { + if colored { + set_color(f, Some("#508EE3"), use_256_color_mode)?; + } + writer.write_field(f, field_index)?; + writeln!(f, " <--")?; + if colored { + set_color(f, None, use_256_color_mode)?; + } + } else { + writer.writeln_field(f, field_index); + } + } + writer.writeln_struct_end(f)?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::pipeline_kind::Render; + use crate::{self as shame, EncodingGuard, ThreadIsAlreadyEncoding}; + use shame as sm; + use shame::{aliases::*, GpuLayout}; + + const PRINT: bool = false; + + macro_rules! is_struct_mismatch { + ($result:expr, $as_error:ident, $mismatch:pat) => {{ + if let Err(e) = &$result + && PRINT + { + println!("{e}"); + } + matches!( + $result, + Err(AddressSpaceError::$as_error($as_error::LayoutError(LayoutError { + mismatch: LayoutMismatch::Struct { + mismatch: $mismatch, + .. + }, + .. 
+ }))) + ) + }}; + } + + fn enable_color() -> Option, ThreadIsAlreadyEncoding>> { + PRINT.then(|| sm::start_encoding(sm::Settings::default())) + } + + #[test] + fn field_offset_error_not_representable() { + let _guard = enable_color(); + + #[derive(sm::GpuLayout)] + #[gpu_repr(packed)] + struct A { + a: f32x1, + // has offset 4, but in wgsl's storage/uniform address space, it needs to be 16 byte aligned + b: f32x3, + } + + // The error variant is NotRepresentable, because there is no way to represent it in wgsl, + // because an offset of 4 is not possible for f32x3, because needs to be 16 byte aligned. + assert!(is_struct_mismatch!( + TypeLayoutCompatibleWith::::try_from(Language::Wgsl, A::layout_recipe()), + NotRepresentable, + StructMismatch::FieldOffset { .. } + )); + assert!(is_struct_mismatch!( + TypeLayoutCompatibleWith::::try_from(Language::Wgsl, A::layout_recipe()), + NotRepresentable, + StructMismatch::FieldOffset { .. } + )); + } + + #[test] + fn wgsl_uniform_array_stride_requirements_not_satisfied() { + let _guard = enable_color(); + + #[derive(sm::GpuLayout)] + struct A { + a: f32x1, + // has stride 4, but wgsl's uniform address space requires a stride of 16. + // also, has wrong offset, because array align is multiple of 16 in wgsl's uniform address space. + // array stride error has higher priority than field offset error, + b: sm::Array>, + } + + // The error variant is RequirementsNotSatisfied, because the array has a stride of 4 in Repr::Packed, + // but wgsl's uniform address space requires a stride of 16. + assert!(is_struct_mismatch!( + TypeLayoutCompatibleWith::::try_from(Language::Wgsl, A::layout_recipe()), + RequirementsNotSatisfied, + StructMismatch::FieldLayout { + mismatch: TopLevelMismatch::ArrayStride { .. }, + .. 
+ } + )); + + // Testing that the error remains the same when nested in another struct + #[derive(sm::GpuLayout)] + struct B { + a: sm::Struct, + } + assert!(is_struct_mismatch!( + TypeLayoutCompatibleWith::::try_from(Language::Wgsl, B::layout_recipe()), + RequirementsNotSatisfied, + StructMismatch::FieldLayout { + mismatch: TopLevelMismatch::ArrayStride { .. }, + .. + } + )); + + // Testing that the error remains the same when nested in an array + assert!(is_struct_mismatch!( + TypeLayoutCompatibleWith::::try_from( + Language::Wgsl, + , sm::Size<1>>>::layout_recipe() + ), + RequirementsNotSatisfied, + StructMismatch::FieldLayout { + mismatch: TopLevelMismatch::ArrayStride { .. }, + .. + } + )); + } + + #[test] + fn wgsl_uniform_field_offset_requirements_not_satisfied() { + let _guard = enable_color(); + + #[derive(sm::GpuLayout)] + struct A { + a: f32x1, + b: sm::Struct, + } + #[derive(sm::GpuLayout)] + struct B { + a: f32x1, + } + + // The error variant is RequirementsNotSatisfied, because the array has a stride of 4 in Repr::Packed, + // but wgsl's uniform address space requires a stride of 16. + assert!(is_struct_mismatch!( + TypeLayoutCompatibleWith::::try_from(Language::Wgsl, A::layout_recipe()), + RequirementsNotSatisfied, + StructMismatch::FieldOffset { .. } + )); + + // Testing that the error remains the same when nested in another struct + #[derive(sm::GpuLayout)] + struct C { + a: sm::Struct, + } + assert!(is_struct_mismatch!( + TypeLayoutCompatibleWith::::try_from(Language::Wgsl, C::layout_recipe()), + RequirementsNotSatisfied, + StructMismatch::FieldOffset { .. } + )); + + // Testing that the error remains the same when nested in an array + assert!(is_struct_mismatch!( + TypeLayoutCompatibleWith::::try_from( + Language::Wgsl, + , sm::Size<1>>>::layout_recipe() + ), + RequirementsNotSatisfied, + StructMismatch::FieldOffset { .. 
} + )); + } + + #[test] + fn wgsl_uniform_must_be_sized() { + let _guard = enable_color(); + + #[derive(sm::GpuLayout)] + struct A { + a: sm::Array, + } + + let e = TypeLayoutCompatibleWith::::try_from(Language::Wgsl, A::layout_recipe()).unwrap_err(); + if PRINT { + println!("{e}"); + } + assert!(matches!( + e, + AddressSpaceError::RequirementsNotSatisfied(RequirementsNotSatisfied::MustBeSized( + _, + Language::Wgsl, + BufferAddressSpaceEnum::Uniform + )) + )); + + // Storage address space should allow unsized types + assert!(TypeLayoutCompatibleWith::::try_from(Language::Wgsl, A::layout_recipe()).is_ok()); + } + + #[test] + fn wgsl_storage_may_not_contain_packed_vec() { + let _guard = enable_color(); + + #[derive(sm::GpuLayout)] + #[gpu_repr(packed)] + struct A { + a: sm::packed::snorm16x2, + } + let e = TypeLayoutCompatibleWith::::try_from(Language::Wgsl, A::layout_recipe()).unwrap_err(); + if PRINT { + println!("{e}"); + } + assert!(matches!( + e, + AddressSpaceError::NotRepresentable(NotRepresentable::MayNotContain( + _, + Language::Wgsl, + BufferAddressSpaceEnum::Storage, + RecipeContains::PackedVector + )) + )); + } +} diff --git a/shame/src/frontend/rust_types/type_layout/display.rs b/shame/src/frontend/rust_types/type_layout/display.rs new file mode 100644 index 0000000..eb64d08 --- /dev/null +++ b/shame/src/frontend/rust_types/type_layout/display.rs @@ -0,0 +1,264 @@ +//! This module provides the `Display` implementation for `TypeLayout` and also contains +//! a `StructWriter` which can be used to write the layout of a struct piece by piece +//! with configurable layout information. 
+ +use std::fmt::{Display, Write}; + +use crate::{ + any::{ + layout::{ArrayLayout, StructLayout}, + U32PowerOf2, + }, + common::prettify::UnwrapDisplayOr, + TypeLayout, +}; + +impl Display for TypeLayout { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.write(f, LayoutInfoFlags::ALL) } +} + +impl TypeLayout { + /// a short name for this `TypeLayout`, useful for printing inline + pub fn short_name(&self) -> String { + use TypeLayout::*; + + match &self { + Vector(v) => { + if v.debug_is_atomic { + format!("atomic<{}>", v.ty.scalar) + } else { + v.ty.to_string() + } + } + PackedVector(v) => v.ty.to_string(), + Matrix(m) => m.ty.to_string(), + Array(a) => a.short_name(), + Struct(s) => s.short_name(), + } + } + + pub(crate) fn write(&self, f: &mut W, layout_info: LayoutInfoFlags) -> std::fmt::Result { + use TypeLayout::*; + + match self { + Vector(_) | PackedVector(_) | Matrix(_) | Array(_) => { + let plain = self.short_name(); + + let stride = match self { + Array(a) => Some(a.byte_stride), + Vector(_) | PackedVector(_) | Matrix(_) | Struct(_) => None, + }; + let info_offset = plain.len() + 1; + + // Write header if some layout information is requested + if layout_info != LayoutInfoFlags::NONE { + writeln!(f, "{:info_offset$}{}", "", layout_info.header())?; + } + + // Write the type name and layout information + let info = layout_info.format(None, self.align(), self.byte_size(), stride); + writeln!(f, "{plain:info_offset$}{info}")?; + } + Struct(s) => s.write(f, layout_info)?, + }; + + Ok(()) + } +} + +impl StructLayout { + /// a short name for this `StructLayout`, useful for printing inline + pub fn short_name(&self) -> String { self.name.to_string() } + + pub(crate) fn writer(&self, layout_info: LayoutInfoFlags) -> StructWriter<'_> { StructWriter::new(self, layout_info) } + + pub(crate) fn write(&self, f: &mut W, layout_info: LayoutInfoFlags) -> std::fmt::Result { + use TypeLayout::*; + + let mut writer = self.writer(layout_info); + 
writer.writeln_header(f)?; + writer.writeln_struct_declaration(f)?; + for i in 0..self.fields.len() { + writer.writeln_field(f, i)?; + } + writer.writeln_struct_end(f) + } +} + +impl ArrayLayout { + /// a short name for this `ArrayLayout`, useful for printing inline + pub fn short_name(&self) -> String { + match self.len { + Some(n) => format!("array<{}, {n}>", self.element_ty.short_name()), + None => format!("array<{}, runtime-sized>", self.element_ty.short_name()), + } + } +} + +/// A bitmask that indicates which layout information should be displayed. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct LayoutInfoFlags(u8); +#[rustfmt::skip] +impl LayoutInfoFlags { + pub const NONE: Self = Self(0); + pub const OFFSET: Self = Self(1 << 0); + pub const ALIGN: Self = Self(1 << 1); + pub const SIZE: Self = Self(1 << 2); + pub const STRIDE: Self = Self(1 << 3); + pub const ALL: Self = Self(Self::OFFSET.0 | Self::ALIGN.0 | Self::SIZE.0 | Self::STRIDE.0); +} +impl std::ops::BitOr for LayoutInfoFlags { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self::Output { LayoutInfoFlags(self.0 | rhs.0) } +} +impl LayoutInfoFlags { + pub fn contains(&self, other: Self) -> bool { (self.0 & other.0) == other.0 } + + pub fn header(&self) -> String { + let mut parts = Vec::with_capacity(4); + for (info, info_str) in [ + (Self::OFFSET, "offset"), + (Self::ALIGN, "align"), + (Self::SIZE, "size"), + (Self::STRIDE, "stride"), + ] { + if self.contains(info) { + parts.push(info_str); + } + } + parts.join(" ") + } + + pub fn format(&self, offset: Option, align: U32PowerOf2, size: Option, stride: Option) -> String { + let infos: [(Self, &'static str, &dyn Display); 4] = [ + (Self::OFFSET, "offset", &UnwrapDisplayOr(offset, "")), + (Self::ALIGN, "align", &align.as_u32()), + (Self::SIZE, "size", &UnwrapDisplayOr(size, "")), + (Self::STRIDE, "stride", &UnwrapDisplayOr(stride, "")), + ]; + let mut parts = Vec::with_capacity(4); + for (info, info_str, value) in infos { + if 
self.contains(info) { + parts.push(format!("{:>info_width$}", value, info_width = info_str.len())); + } + } + parts.join(" ") + } +} + +pub struct StructWriter<'a> { + s: &'a StructLayout, + tab: &'static str, + layout_info: LayoutInfoFlags, + layout_info_offset: usize, +} + +impl<'a> StructWriter<'a> { + pub fn new(s: &'a StructLayout, layout_info: LayoutInfoFlags) -> Self { + let mut this = Self { + s, + // Could make this configurable + tab: " ", + layout_info, + layout_info_offset: 0, + }; + this.set_layout_info_offset_auto(None); + this + } + + /// By setting `max_fields` to `Some(n)`, the writer will adjust Self::layout_info_offset + /// to only take into account the first `n` fields of the struct. + pub(crate) fn set_layout_info_offset_auto(&mut self, max_fields: Option) { + let fields = match max_fields { + Some(n) => n.min(self.s.fields.len()), + None => self.s.fields.len(), + }; + let layout_info_offset = 1 + + (0..fields) + .map(|i| self.field_declaration(i).len()) + .max() + .unwrap_or(0) + .max(self.struct_declaration().len()); + self.layout_info_offset = layout_info_offset; + } + + pub(crate) fn ensure_layout_info_offset(&mut self, min_layout_info_offset: usize) { + self.layout_info_offset = self.layout_info_offset.max(min_layout_info_offset) + } + + pub(crate) fn layout_info_offset(&self) -> usize { self.layout_info_offset } + + pub(crate) fn tab(&self) -> &'static str { self.tab } + + fn struct_declaration(&self) -> String { format!("struct {} {{", self.s.name) } + + fn field_declaration(&self, field_index: usize) -> String { + match self.s.fields.get(field_index) { + Some(field) => format!("{}{}: {},", self.tab, field.name, field.ty.short_name()), + None => format!("{}field {field_index} not found,", self.tab), + } + } + + pub(crate) fn write_header(&self, f: &mut W) -> std::fmt::Result { + if self.layout_info != LayoutInfoFlags::NONE { + let info_offset = self.layout_info_offset(); + write!(f, "{:info_offset$}{}", "", self.layout_info.header())?; 
+ } + Ok(()) + } + + pub(crate) fn write_struct_declaration(&self, f: &mut W) -> std::fmt::Result { + let info = self.layout_info.format(None, *self.s.align, self.s.byte_size, None); + let info_offset = self.layout_info_offset(); + write!(f, "{:info_offset$}{info}", self.struct_declaration()) + } + + pub(crate) fn write_field(&self, f: &mut W, field_index: usize) -> std::fmt::Result { + use TypeLayout::*; + + match self.s.fields.get(field_index) { + Some(field) => { + let info = self.layout_info.format( + Some(field.rel_byte_offset), + field.ty.align(), + field.ty.byte_size(), + match &field.ty { + Array(array) => Some(array.byte_stride), + Vector(_) | PackedVector(_) | Matrix(_) | Struct(_) => None, + }, + ); + let info_offset = self.layout_info_offset(); + write!(f, "{:info_offset$}{info}", self.field_declaration(field_index)) + } + None => { + write!(f, "{}field {field_index} not found", self.tab) + } + } + } + + pub(crate) fn write_struct_end(&self, f: &mut W) -> std::fmt::Result { write!(f, "}}") } + + pub(crate) fn writeln_header(&self, f: &mut W) -> std::fmt::Result { + if self.layout_info != LayoutInfoFlags::NONE { + self.write_header(f)?; + writeln!(f)?; + } + Ok(()) + } + + pub(crate) fn writeln_struct_declaration(&self, f: &mut W) -> std::fmt::Result { + self.write_struct_declaration(f)?; + writeln!(f) + } + + pub(crate) fn writeln_field(&self, f: &mut W, field_index: usize) -> std::fmt::Result { + self.write_field(f, field_index)?; + writeln!(f) + } + + pub(crate) fn writeln_struct_end(&self, f: &mut W) -> std::fmt::Result { + self.write_struct_end(f)?; + writeln!(f) + } +} diff --git a/shame/src/frontend/rust_types/type_layout/eq.rs b/shame/src/frontend/rust_types/type_layout/eq.rs new file mode 100644 index 0000000..afea5a9 --- /dev/null +++ b/shame/src/frontend/rust_types/type_layout/eq.rs @@ -0,0 +1,790 @@ +use crate::common::format::display; +use crate::common::prettify::UnwrapDisplayOr; +use 
crate::frontend::rust_types::type_layout::display::LayoutInfoFlags; + +use super::*; + +/// Contains information about the layout mismatch between two `TypeLayout`s. +/// +/// The type layouts are traversed depth-first and the first mismatch encountered +/// is reported at its deepest level. +/// +/// In case of nested structs this means that if the field `a` in +/// ``` +/// struct A { +/// a: struct B { ... } +/// } +/// struct AOther { +/// a: struct BOther { ... } +/// } +/// ``` +/// mismatches, because `B` and `BOther` don't match, then the exact mismatch (some field mismatch) +/// between `B` and `BOther` is reported and not the field mismatch of `a` in `A` and `AOther`. +/// +/// Nested arrays are reported in two levels: `LayoutMismatch::TopLevel` contains the +/// top level / outer most array layout and a `TopLevelMismatch`, which contains the +/// inner type layout where the mismatch is happening. +/// For example if there is an array stride mismatch of the inner array of `Array>`, +/// then `LayoutMismatch::TopLevel` contains the layout of `Array>` and +/// a `TopLevelMismatch::ArrayStride` with the layout of `Array`. +/// +/// A field of nested arrays in a struct is handled in the same way by +/// `LayoutMismatch::Struct` containing the mismatching field index and `FieldLayout`, +/// which let's us access the outer most array, and a `TopLevelMismatch`, +/// which let's us access the inner type layout where the mismatch is happening. +#[derive(Debug, Clone)] +pub enum LayoutMismatch { + TopLevel { + layout_left: TypeLayout, + layout_right: TypeLayout, + mismatch: TopLevelMismatch, + }, + Struct { + struct_left: StructLayout, + struct_right: StructLayout, + mismatch: StructMismatch, + }, +} + +#[derive(Debug, Clone)] +pub enum TopLevelMismatch { + Type, + ByteSize { + left: TypeLayout, + right: TypeLayout, + }, + ArrayStride { + array_left: ArrayLayout, + array_right: ArrayLayout, + }, +} + +/// Field count is checked last. 
+#[derive(Debug, Clone)] +pub enum StructMismatch { + FieldName { + field_index: usize, + field_left: FieldLayout, + field_right: FieldLayout, + }, + FieldLayout { + field_index: usize, + field_left: FieldLayout, + field_right: FieldLayout, + mismatch: TopLevelMismatch, + }, + FieldOffset { + field_index: usize, + field_left: FieldLayout, + field_right: FieldLayout, + }, + FieldCount, +} + +/// Find the first depth first layout mismatch +pub(crate) fn try_find_mismatch(layout1: &TypeLayout, layout2: &TypeLayout) -> Option { + use TypeLayout::*; + + let make_mismatch = |mismatch: TopLevelMismatch| LayoutMismatch::TopLevel { + layout_left: layout1.clone(), + layout_right: layout2.clone(), + mismatch, + }; + + // First check if the kinds are the same type + match (&layout1, &layout2) { + (Vector(v1), Vector(v2)) => { + if v1.ty != v2.ty { + return Some(make_mismatch(TopLevelMismatch::Type)); + } + } + (PackedVector(p1), PackedVector(p2)) => { + if p1.ty != p2.ty { + return Some(make_mismatch(TopLevelMismatch::Type)); + } + } + (Matrix(m1), Matrix(m2)) => { + if m1.ty != m2.ty { + return Some(make_mismatch(TopLevelMismatch::Type)); + } + } + (Array(a1), Array(a2)) => { + // Recursively check element types + match try_find_mismatch(&a1.element_ty, &a2.element_ty) { + // Update the top level layouts and propagate the LayoutMismatch + Some(LayoutMismatch::TopLevel { mismatch, .. }) => { + return Some(make_mismatch(mismatch)); + } + // Struct mismatch, so it's not a top-level mismatch + m @ Some(LayoutMismatch::Struct { .. 
}) => return m, + None => {} + } + + // Check array lengths, which are a type mismatch if they differ + if a1.len != a2.len { + return Some(make_mismatch(TopLevelMismatch::Type)); + } + + // Check array stride + if a1.byte_stride != a2.byte_stride { + return Some(make_mismatch(TopLevelMismatch::ArrayStride { + array_left: (**a1).clone(), + array_right: (**a2).clone(), + })); + } + } + (Struct(s1), Struct(s2)) => { + return try_find_struct_mismatch(s1, s2); + } + // Different kinds entirely. Matching exhaustively, so that changes to TypeLayout lead us here. + (Vector(_) | PackedVector(_) | Matrix(_) | Array(_) | Struct(_), _) => { + return Some(make_mismatch(TopLevelMismatch::Type)); + } + } + + // Check byte size. + // We do this at the end, because type mismatches should have priority over byte size mismatches. + if layout1.byte_size() != layout2.byte_size() { + return Some(make_mismatch(TopLevelMismatch::ByteSize { + left: layout1.clone(), + right: layout2.clone(), + })); + } + + None +} + +fn try_find_struct_mismatch(struct1: &StructLayout, struct2: &StructLayout) -> Option { + let make_mismatch = |mismatch: StructMismatch| LayoutMismatch::Struct { + struct_left: struct1.clone(), + struct_right: struct2.clone(), + mismatch, + }; + + for (field_index, (field1, field2)) in struct1.fields.iter().zip(struct2.fields.iter()).enumerate() { + // Order of checks is important here. We check in order + // - field name + // - field inner mismatch + // - field offset + if field1.name != field2.name { + return Some(make_mismatch(StructMismatch::FieldName { + field_index, + field_left: field1.clone(), + field_right: field2.clone(), + })); + } + + // Recursively check field types + if let Some(inner_mismatch) = try_find_mismatch(&field1.ty, &field2.ty) { + match inner_mismatch { + // If it's a top-level mismatch, convert it to a field mismatch + LayoutMismatch::TopLevel { mismatch, .. 
} => { + return Some(make_mismatch(StructMismatch::FieldLayout { + field_index, + field_left: field1.clone(), + field_right: field2.clone(), + mismatch, + })); + } + // Pass through struct mismatches + struct_mismatch @ LayoutMismatch::Struct { .. } => return Some(struct_mismatch), + } + } + + // Check field offset + if field1.rel_byte_offset != field2.rel_byte_offset { + return Some(make_mismatch(StructMismatch::FieldOffset { + field_index, + field_left: field1.clone(), + field_right: field2.clone(), + })); + } + } + + // Check field count. + // We do this at the end, because fields are checked in order and a field count mismatch + // can be viewed as a field mismatch of one field beyond the last field of the smaller struct. + if struct1.fields.len() != struct2.fields.len() { + return Some(make_mismatch(StructMismatch::FieldCount)); + } + + None +} + +/// Error of two layouts mismatching. Implements Display for a visualization of the mismatch. +#[derive(Clone)] +pub struct CheckEqLayoutMismatch { + /// 2 (name, layout) pairs + layouts: [(String, TypeLayout); 2], + colored_error: bool, +} + +impl std::fmt::Debug for CheckEqLayoutMismatch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{self}") } +} + +impl Display for CheckEqLayoutMismatch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let [(a_name, a), (b_name, b)] = &self.layouts; + let layouts = [(a_name.as_str(), a), (b_name.as_str(), b)]; + match CheckEqLayoutMismatch::write(f, layouts, self.colored_error) { + Ok(()) => {} + Err(DisplayMismatchError::FmtError(e)) => return Err(e), + Err(DisplayMismatchError::NotFound) => { + writeln!( + f, + "" + )?; + writeln!(f, "the full type layouts in question are:")?; + for (name, layout) in layouts { + writeln!(f, "`{name}`:")?; + writeln!(f, "{layout}")?; + } + } + } + + Ok(()) + } +} + +pub(crate) enum DisplayMismatchError { + NotFound, + FmtError(std::fmt::Error), +} +impl From for DisplayMismatchError { 
+ fn from(err: std::fmt::Error) -> Self { DisplayMismatchError::FmtError(err) } +} + +impl CheckEqLayoutMismatch { + #[allow(clippy::needless_return)] + pub(crate) fn write( + f: &mut W, + layouts: [(&str, &TypeLayout); 2], + colored: bool, + ) -> Result<(), DisplayMismatchError> { + let [(a_name, a), (b_name, b)] = layouts; + + let use_256_color_mode = false; + let enable_color = |f_: &mut W, hex| match colored { + true => set_color(f_, Some(hex), use_256_color_mode), + false => Ok(()), + }; + let reset_color = |f_: &mut W| match colored { + true => set_color(f_, None, use_256_color_mode), + false => Ok(()), + }; + let hex_left = "#DF5853"; // red + let hex_right = "#9A639C"; // purple + + // Using try_find_mismatch to find the first mismatching type / struct field. + let Some(mismatch) = try_find_mismatch(a, b) else { + return Err(DisplayMismatchError::NotFound); + }; + + match mismatch { + LayoutMismatch::TopLevel { + layout_left, + layout_right, + mismatch, + } => match mismatch { + TopLevelMismatch::Type => writeln!( + f, + "The layouts of `{}` ({a_name}) and `{}` ({b_name}) do not match, because their types are semantically different.", + layout_left.short_name(), + layout_right.short_name() + )?, + TopLevelMismatch::ArrayStride { + array_left, + array_right, + } => { + writeln!( + f, + "The layouts of `{}` ({a_name}) and `{}` ({b_name}) do not match.", + layout_left.short_name(), + layout_right.short_name() + )?; + writeln!( + f, + "`{}` ({a_name}) has a stride of {}, while `{}` ({b_name}) has a stride of {}.", + array_left.short_name(), + array_left.byte_stride, + array_right.short_name(), + array_right.byte_stride + )?; + } + TopLevelMismatch::ByteSize { left, right } => { + writeln!( + f, + "The layouts of `{}` ({a_name}) and `{}` ({b_name}) do not match.", + layout_left.short_name(), + layout_right.short_name() + )?; + writeln!( + f, + "`{}` ({a_name}) {}, while `{}` ({b_name}) {}.", + left.short_name(), + display(|f| match left.byte_size() { + 
Some(size) => write!(f, "has a byte-size of {size}"), + None => write!(f, "is runtime-sized"), + }), + right.short_name(), + display(|f| match right.byte_size() { + Some(size) => write!(f, "has a byte-size of {size}"), + None => write!(f, "is runtime-sized"), + }), + )?; + } + }, + LayoutMismatch::Struct { + struct_left, + struct_right, + mismatch, + } => { + let is_top_level = &TypeLayout::Struct(Rc::new(struct_left.clone())) == a; + if is_top_level { + writeln!( + f, + "The layouts of `{}` and `{}` do not match, because the", + struct_left.name, struct_right.name + )?; + } else { + writeln!( + f, + "The layouts of `{}` and `{}`, contained in `{}` and `{}` respectively, do not match, because the", + struct_left.name, + struct_right.name, + a.short_name(), + b.short_name() + )?; + } + let (mismatch_field_index, layout_info) = match &mismatch { + StructMismatch::FieldName { field_index, .. } => { + writeln!(f, "names of field {field_index} are different.")?; + (Some(field_index), LayoutInfoFlags::NONE) + } + StructMismatch::FieldLayout { + field_index, + field_left, + mismatch: TopLevelMismatch::Type, + .. + } => { + writeln!(f, "type of `{}` is different.", field_left.name)?; + (Some(field_index), LayoutInfoFlags::NONE) + } + StructMismatch::FieldLayout { + field_index, + field_left, + mismatch: TopLevelMismatch::ByteSize { left, right }, + .. + } => { + // Inner type in (nested) array has mismatching byte size + if &field_left.ty != left { + writeln!( + f, + "byte size of `{}` is {} in `{}` and the byte size of `{}` is {} in `{}`.", + left.short_name(), + UnwrapDisplayOr(left.byte_size(), "runtime-sized"), + struct_left.name, + right.short_name(), + UnwrapDisplayOr(right.byte_size(), "runtime-sized"), + struct_right.name, + )?; + // Not showing byte size info, because it can be misleading since + // the inner type is the one that has mismatching byte size. 
+ (Some(field_index), LayoutInfoFlags::NONE) + } else { + writeln!(f, "byte size of `{}` is different.", field_left.name)?; + (Some(field_index), LayoutInfoFlags::SIZE) + } + } + StructMismatch::FieldLayout { + field_index, + field_left, + mismatch: + TopLevelMismatch::ArrayStride { + array_left, + array_right, + }, + .. + } => { + // Inner type in (nested) array has mismatching stride + if field_left.ty.short_name() != array_left.short_name() { + writeln!( + f, + "stride of `{}` is {} in `{}` and {} in `{}`.", + array_left.short_name(), + array_left.byte_stride, + struct_left.name, + array_right.byte_stride, + struct_right.name, + )?; + // Not showing stride info, because it can be misleading since + // the inner type is the one that has mismatching stride. + (Some(field_index), LayoutInfoFlags::NONE) + } else { + writeln!(f, "array stride of {} is different.", field_left.name)?; + (Some(field_index), LayoutInfoFlags::STRIDE) + } + } + StructMismatch::FieldOffset { + field_index, + field_left, + .. + } => { + writeln!(f, "offset of {} is different.", field_left.name)?; + ( + Some(field_index), + LayoutInfoFlags::OFFSET | LayoutInfoFlags::ALIGN | LayoutInfoFlags::SIZE, + ) + } + StructMismatch::FieldCount => { + writeln!(f, "number of fields is different.")?; + (None, LayoutInfoFlags::NONE) + } + }; + writeln!(f)?; + + let fields_without_mismatch = match mismatch_field_index { + Some(index) => struct_left.fields.len().min(struct_right.fields.len()).min(*index), + None => struct_left.fields.len().min(struct_right.fields.len()), + }; + + // Start writing the structs with the mismatch highlighted + let mut writer_left = struct_left.writer(layout_info); + let mut writer_right = struct_right.writer(layout_info); + + // Make it so layout info offset only takes into account the fields before and + // including the mismatch, because those are the only fields that will be written below. 
+ if let Some(mismatch_field_index) = mismatch_field_index { + let max_fields = Some(mismatch_field_index + 1); + writer_left.set_layout_info_offset_auto(max_fields); + writer_right.set_layout_info_offset_auto(max_fields); + } + // Make sure layout info offset is large enough to fit the custom struct declaration + let struct_declaration = format!("struct {a_name} / {b_name} {{"); + let layout_info_offset = writer_left + .layout_info_offset() + .max(writer_right.layout_info_offset()) + .max(struct_declaration.len()); + writer_left.ensure_layout_info_offset(layout_info_offset); + writer_right.ensure_layout_info_offset(layout_info_offset); + + // Write header + writer_left.writeln_header(f)?; + + // Write custom struct declaration + write!(f, "struct ")?; + enable_color(f, hex_left)?; + write!(f, "{}", struct_left.name)?; + reset_color(f)?; + write!(f, " / ")?; + enable_color(f, hex_right)?; + write!(f, "{}", struct_right.name)?; + reset_color(f)?; + writeln!(f, " {{")?; + + // Write matching fields + for field_index in 0..fields_without_mismatch { + writer_left.writeln_field(f, field_index)?; + } + + match mismatch { + StructMismatch::FieldName { field_index, .. } | + StructMismatch::FieldLayout { field_index, .. } | + StructMismatch::FieldOffset { field_index, .. 
} => { + // Write mismatching field + enable_color(f, hex_left)?; + writer_left.write_field(f, field_index)?; + writeln!(f, " <-- {a_name}")?; + reset_color(f)?; + enable_color(f, hex_right)?; + writer_right.write_field(f, field_index)?; + writeln!(f, " <-- {b_name}")?; + reset_color(f)?; + if struct_left.fields.len() > field_index + 1 || struct_right.fields.len() > field_index + 1 { + // Write ellipsis if there are more fields after the mismatch + writeln!(f, "{}...", writer_left.tab())?; + } + } + StructMismatch::FieldCount => { + // Write the remaining fields of the larger struct + let (writer, len, hex) = match struct_left.fields.len() > struct_right.fields.len() { + true => (&mut writer_left, struct_left.fields.len(), hex_left), + false => (&mut writer_right, struct_right.fields.len(), hex_right), + }; + + enable_color(f, hex)?; + for field_index in fields_without_mismatch..len { + writer.writeln_field(f, field_index)?; + } + reset_color(f)?; + } + } + + // Write closing bracket + writer_left.writeln_struct_end(f)?; + } + } + + Ok(()) + } +} + +/// takes two pairs of `(debug_name, layout)` and compares them for equality. +/// +/// if the two layouts are not equal it uses the debug names in the returned +/// error to tell the two layouts apart. 
+pub(crate) fn check_eq(a: (&str, &TypeLayout), b: (&str, &TypeLayout)) -> Result<(), CheckEqLayoutMismatch> +where + TypeLayout: PartialEq, +{ + match a.1 == b.1 { + true => Ok(()), + false => Err(CheckEqLayoutMismatch { + layouts: [(a.0.into(), a.1.to_owned()), (b.0.into(), b.1.to_owned())], + colored_error: Context::try_with(call_info!(), |ctx| ctx.settings().colored_error_messages) + .unwrap_or(false), + }), + } +} + +#[cfg(test)] +mod tests { + use crate::pipeline_kind::Render; + use crate::{self as shame, EncodingGuard, ThreadIsAlreadyEncoding}; + use shame as sm; + use shame::{CpuLayout, GpuLayout, gpu_layout, cpu_layout}; + use crate::aliases::*; + + const PRINT: bool = true; + + #[derive(Clone, Copy)] + #[repr(C)] + #[allow(non_camel_case_types)] + struct f32x3_align4(pub [f32; 3]); + + impl CpuLayout for f32x3_align4 { + fn cpu_layout() -> shame::TypeLayout { + let mut layout = gpu_layout::(); + *layout.align_mut() = shame::any::U32PowerOf2::_4; + layout + } + } + + #[derive(Clone, Copy)] + #[repr(C)] + #[allow(non_camel_case_types)] + struct f32x3_size16(pub [f32; 4]); + + impl CpuLayout for f32x3_size16 { + fn cpu_layout() -> shame::TypeLayout { + let mut layout = gpu_layout::(); + layout.set_byte_size(size_of::() as u64); + layout + } + } + + fn check_mismatch() { + let mismatch = super::check_eq(("gpu", &gpu_layout::()), ("cpu", &cpu_layout::())).unwrap_err(); + if PRINT { + println!("{mismatch}"); + } + } + + fn enable_color() -> Option, ThreadIsAlreadyEncoding>> { + PRINT.then(|| sm::start_encoding(sm::Settings::default())) + } + + #[test] + fn test_field_name_mismatch() { + let _guard = enable_color(); + + #[derive(GpuLayout)] + pub struct A { + a: u32x1, + } + #[derive(CpuLayout)] + #[repr(C)] + pub struct ACpu { + b: u32, + } + check_mismatch::(); + } + + #[test] + fn test_field_type_mismatch() { + let _guard = enable_color(); + + if PRINT { + println!("The error also shows how \"...\" is used if there are more fields after the mismatching 
field\n"); + } + #[derive(GpuLayout)] + pub struct B { + a: f32x1, + b: f32x1, + } + #[derive(CpuLayout)] + #[repr(C)] + pub struct BCpu { + a: u32, + b: f32, + } + check_mismatch::(); + } + + #[test] + fn test_field_offset_mismatch() { + let _guard = enable_color(); + + #[derive(GpuLayout)] + pub struct C { + a: f32x1, + b: f32x3, + } + #[derive(CpuLayout)] + #[repr(C)] + pub struct CCpu { + a: f32, + b: f32x3_align4, + } + check_mismatch::(); + } + + #[test] + fn test_field_byte_size_mismatch() { + let _guard = enable_color(); + + #[derive(GpuLayout)] + pub struct D { + a: f32x3, + } + #[derive(CpuLayout)] + #[repr(C)] + pub struct DCpu { + a: f32x3_size16, + } + check_mismatch::(); + } + + #[test] + fn test_field_nested_byte_size_mismatch() { + let _guard = enable_color(); + + if PRINT { + println!( + "The error does not show the `size` column, because it could be confusing, since the type in the array is where the mismatch happens:\n" + ); + } + #[derive(GpuLayout)] + pub struct E { + a: sm::Array>, + } + #[derive(CpuLayout)] + #[repr(C)] + pub struct ECpu { + a: [f32x3_size16; 4], + } + check_mismatch::(); + } + + #[test] + fn test_field_stride_mismatch() { + let _guard = enable_color(); + + #[derive(GpuLayout)] + pub struct F { + a: sm::Array>, + } + #[derive(CpuLayout)] + #[repr(C)] + pub struct FCpu { + a: [f32x3_align4; 4], + } + check_mismatch::(); + } + + #[test] + fn test_stride_mismatch() { + let _guard = enable_color(); + check_mismatch::>, [f32x3_align4; 4]>(); + } + + #[test] + fn test_nested_stride_mismatch() { + let _guard = enable_color(); + check_mismatch::>, sm::Size<2>>, [[f32x3_align4; 4]; 2]>(); + } + + #[test] + fn test_nested_stride_in_struct_mismatch() { + let _guard = enable_color(); + + if PRINT { + println!( + "The error does not show the `stride` column, because it could be confusing, since the type in the array is where the mismatch happens:\n" + ); + } + #[derive(GpuLayout)] + pub struct G { + a: sm::Array>, sm::Size<2>>, + } + 
#[derive(CpuLayout)] + #[repr(C)] + pub struct GCpu { + a: [[f32x3_align4; 4]; 2], + } + check_mismatch::(); + } + + #[test] + fn test_struct_in_struct_mismatch() { + let _guard = enable_color(); + + #[derive(GpuLayout)] + pub struct Inner { + x: f32x1, + y: f32x1, + } + + #[derive(GpuLayout)] + pub struct Outer { + inner: sm::Struct, + z: u32x1, + } + + #[derive(CpuLayout)] + #[repr(C)] + pub struct InnerCpu { + x: f32, + y: u32, // Type mismatch here + } + + #[derive(CpuLayout)] + #[repr(C)] + pub struct OuterCpu { + inner: InnerCpu, + z: u32, + } + + check_mismatch::(); + } + + #[test] + fn test_field_count_mismatch() { + let _guard = enable_color(); + + #[derive(GpuLayout)] + pub struct H { + a: f32x1, + b: u32x1, + } + #[derive(CpuLayout)] + #[repr(C)] + pub struct HCpu { + a: f32, + b: u32, + c: f32, // Extra field + } + check_mismatch::(); + } +} diff --git a/shame/src/frontend/rust_types/type_layout/mod.rs b/shame/src/frontend/rust_types/type_layout/mod.rs new file mode 100644 index 0000000..b67492f --- /dev/null +++ b/shame/src/frontend/rust_types/type_layout/mod.rs @@ -0,0 +1,469 @@ +//! Everything related to type layouts. + +use std::{ + fmt::{Debug, Display, Write}, + hash::Hash, + rc::Rc, +}; + +use crate::{ + any::U32PowerOf2, + call_info, + common::{ignore_eq::IgnoreInEqOrdHash, prettify::set_color}, + ir::{self, ir_type::CanonName, recording::Context}, +}; +use recipe::{Matrix, Vector, PackedVector}; + +pub(crate) mod compatible_with; +pub(crate) mod display; +pub(crate) mod eq; +pub(crate) mod recipe; + +/// The memory layout of a type. +/// +/// This models only the layout, not other characteristics of the types. +/// For example an `Atomic>` is treated like a regular `vec` layout wise. +/// +/// ### Layout comparison +/// +/// The `PartialEq + Eq` implementation of `TypeLayout` is designed to answer the question +/// "do these two types have the same layout" so that uploading a type to the gpu +/// will result in no memory errors. 
+/// +/// a layout comparison looks like this: +/// ``` +/// use shame as sm; +/// assert_eq!(sm::cpu_layout::(), sm::gpu_layout>()); +/// ``` +#[derive(Clone, PartialEq, Eq, Hash)] +pub enum TypeLayout { + /// `vec` + Vector(VectorLayout), + /// special compressed vectors for vertex attribute types + /// + /// see the [`crate::packed`] module + PackedVector(PackedVectorLayout), + /// `mat`, first `Len2` is cols, 2nd `Len2` is rows + Matrix(MatrixLayout), + /// `Array` and `Array>` + Array(Rc), + /// structures which may be empty and may have an unsized last field + Struct(Rc), +} + +impl Debug for TypeLayout { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // debug assertions should display the proper presentation of diffs, + // so we us the Display trait here, too + write!(f, "{}", self) + } +} + +#[allow(missing_docs)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct VectorLayout { + pub byte_size: u64, + pub align: IgnoreInEqOrdHash, + pub ty: Vector, + + // debug information + pub debug_is_atomic: bool, +} + +#[allow(missing_docs)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct PackedVectorLayout { + pub byte_size: u64, + pub align: IgnoreInEqOrdHash, + pub ty: PackedVector, +} + +#[allow(missing_docs)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MatrixLayout { + pub byte_size: u64, + pub align: IgnoreInEqOrdHash, + pub ty: Matrix, +} + +#[allow(missing_docs)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ArrayLayout { + pub byte_size: Option, + pub align: IgnoreInEqOrdHash, + pub byte_stride: u64, + pub element_ty: TypeLayout, + // not NonZeroU32, since for rust `CpuLayout`s the array size may be 0. 
+ pub len: Option, +} + +/// a sized or unsized struct type with 0 or more fields +#[allow(missing_docs)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct StructLayout { + pub byte_size: Option, + pub align: IgnoreInEqOrdHash, + /// The canonical name of the structure type, ignored in equality/hash comparisons + pub name: IgnoreInEqOrdHash, + /// The fields of the structure with their memory offsets + pub fields: Vec, +} + +#[allow(missing_docs)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct FieldLayout { + /// The relative byte offset of this field from the start of its containing structure + pub rel_byte_offset: u64, + pub name: CanonName, + pub ty: TypeLayout, +} + +/// Enum of layout algorithms. +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Repr { + /// WGSL's layout algorithm + /// https://www.w3.org/TR/WGSL/#alignment-and-size + #[default] + Wgsl, + /// Modified layout algorithm based on [`Repr::Wgsl`], but with different type + /// alignments and array strides that make the resulting Layout match wgsl's + /// uniform address space requirements. + /// + /// https://www.w3.org/TR/WGSL/#address-space-layout-constraints + /// + /// (matrix strides remain unchanged however, which makes this different from the std140 layout for mat2x2) + /// + /// Internally used for checking whether a type can be used in the wgsl's + /// uniform address space + WgslUniform, + /// byte-alignment of everything is 1. Custom alignment attributes + /// in [`TypeLayoutRecipe`] are unsupported. + Packed, +} + +impl std::fmt::Display for Repr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Repr::Wgsl => write!(f, "wgsl"), + Repr::WgslUniform => write!(f, "wgsl uniform"), + Repr::Packed => write!(f, "packed"), + } + } +} + +impl TypeLayout { + /// Returns the byte size of the represented type. 
+ /// + /// For sized types, this returns Some(size), while for unsized types + /// (like runtime-sized arrays), this returns None. + pub fn byte_size(&self) -> Option { + match self { + TypeLayout::Vector(v) => Some(v.byte_size), + TypeLayout::PackedVector(p) => Some(p.byte_size), + TypeLayout::Matrix(m) => Some(m.byte_size), + TypeLayout::Array(a) => a.byte_size, + TypeLayout::Struct(s) => s.byte_size, + } + } + + /// Returns the alignment requirement of the represented type. + pub fn align(&self) -> U32PowerOf2 { + match self { + TypeLayout::Vector(v) => *v.align, + TypeLayout::PackedVector(p) => *p.align, + TypeLayout::Matrix(m) => *m.align, + TypeLayout::Array(a) => *a.align, + TypeLayout::Struct(s) => *s.align, + } + } + + /// mutable reference to the alignment requirement of the represented type. + /// + /// may allocate a new `Rc` (via `Rc::make_mut`) for `Array`/`Struct` layouts + pub fn align_mut(&mut self) -> &mut U32PowerOf2 { + match self { + TypeLayout::Vector(v) => &mut v.align, + TypeLayout::Matrix(m) => &mut m.align, + TypeLayout::PackedVector(v) => &mut v.align, + TypeLayout::Array(a) => &mut Rc::make_mut(a).align, + TypeLayout::Struct(s) => &mut Rc::make_mut(s).align, + } + } + + /// mutable access to the `byte_size` if it exists + pub fn try_byte_size_mut(&mut self) -> Option<&mut u64> { + match self.removable_byte_size_mut() { + Ok(removable) => removable.as_mut(), + Err(fixed) => Some(fixed), + } + } + + /// set the byte size of the `TypeLayout` to `new_size` (or `Some(new_size)` if the type can be unsized) + /// + /// use [`removable_byte_size_mut`] if you need more control + pub fn set_byte_size(&mut self, new_size: u64) { + match self.removable_byte_size_mut() { + Ok(removable) => *removable = Some(new_size), + Err(fixed) => *fixed = new_size, + } + } + + /// mutable access to the size of a `self` + /// + /// returns + /// - `Ok(&mut option)` if `self`'s represented type can exist in an unsized configuration (e.g. 
struct, array) + /// - `Err(&mut size)` if `self`'s represented type is always sized (e.g. vector, matrix) + pub fn removable_byte_size_mut(&mut self) -> Result<&mut Option, &mut u64> { + match self { + TypeLayout::Vector(v) => Err(&mut v.byte_size), + TypeLayout::Matrix(m) => Err(&mut m.byte_size), + TypeLayout::PackedVector(v) => Err(&mut v.byte_size), + TypeLayout::Array(a) => Ok(&mut Rc::make_mut(a).byte_size), + TypeLayout::Struct(s) => Ok(&mut Rc::make_mut(s).byte_size), + } + } + + // TODO(chronicl) this should be removed with improved any api for storage/uniform bindings + pub(crate) fn from_store_ty(store_type: ir::StoreType) -> Result { + let t: recipe::TypeLayoutRecipe = store_type.try_into()?; + Ok(t.layout()) + } +} + +impl From for TypeLayout { + fn from(layout: VectorLayout) -> Self { TypeLayout::Vector(layout) } +} + +impl From for TypeLayout { + fn from(layout: PackedVectorLayout) -> Self { TypeLayout::PackedVector(layout) } +} + +impl From for TypeLayout { + fn from(layout: MatrixLayout) -> Self { TypeLayout::Matrix(layout) } +} + +impl From for TypeLayout { + fn from(layout: ArrayLayout) -> Self { TypeLayout::Array(Rc::new(layout)) } +} + +impl From for TypeLayout { + fn from(layout: StructLayout) -> Self { TypeLayout::Struct(Rc::new(layout)) } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + any::U32PowerOf2, + frontend::rust_types::type_layout::{ + recipe::{*}, + Repr, *, + }, + }; + use std::{rc::Rc, num::NonZeroU32}; + + #[test] + fn test_array_alignment() { + let array: TypeLayoutRecipe = SizedArray::new( + Rc::new(Vector::new(ScalarType::F32, Len::X1).into()), + NonZeroU32::new(1).unwrap(), + ) + .into(); + + // To change the top level arrays repr, we need to set the default repr, + // because non-structs inherit repr. 
+ let storage = array.layout_with_default_repr(Repr::Wgsl); + let uniform = array.layout_with_default_repr(Repr::WgslUniform); + let packed = array.layout_with_default_repr(Repr::Packed); + + assert_eq!(storage.align(), U32PowerOf2::_4); + assert_eq!(uniform.align(), U32PowerOf2::_16); + assert_eq!(packed.align(), U32PowerOf2::_1); + + assert_eq!(storage.byte_size(), Some(4)); + assert_eq!(uniform.byte_size(), Some(16)); + assert_eq!(packed.byte_size(), Some(4)); + + match (storage, uniform, packed) { + (TypeLayout::Array(storage), TypeLayout::Array(uniform), TypeLayout::Array(packed)) => { + assert_eq!(storage.len, Some(1)); + assert_eq!(uniform.len, Some(1)); + assert_eq!(packed.len, Some(1)); + assert_eq!(storage.byte_stride, 4); + assert_eq!(uniform.byte_stride, 16); + assert_eq!(packed.byte_stride, 4); + } + _ => panic!("Unexpected layout kind"), + } + } + + #[test] + fn test_struct_alignment() { + let s = |repr| -> TypeLayoutRecipe { + SizedStruct::new("A", "a", Vector::new(ScalarType::F32, Len::X1), repr).into() + }; + + let storage = s(Repr::Wgsl).layout(); + let uniform = s(Repr::WgslUniform).layout(); + let packed = s(Repr::Packed).layout(); + + assert_eq!(storage.align(), U32PowerOf2::_4); + assert_eq!(uniform.align(), U32PowerOf2::_16); + assert_eq!(packed.align(), U32PowerOf2::_1); + + assert_eq!(storage.byte_size(), Some(4)); + assert_eq!(uniform.byte_size(), Some(16)); + assert_eq!(packed.byte_size(), Some(4)); + } + + #[test] + fn test_nested_struct_field_offset() { + let s = |repr| -> TypeLayoutRecipe { + let a = SizedStruct::new("A", "a", Vector::new(ScalarType::F32, Len::X1), repr); + SizedStruct::new("B", "a", Vector::new(ScalarType::F32, Len::X1), repr) + .extend("b", a) // offset 4 for storage and packed, offset 16 for uniform + .into() + }; + + let storage = s(Repr::Wgsl).layout(); + let uniform = s(Repr::WgslUniform).layout(); + let packed = s(Repr::Packed).layout(); + + assert_eq!(storage.align(), U32PowerOf2::_4); + 
assert_eq!(uniform.align(), U32PowerOf2::_16); + assert_eq!(packed.align(), U32PowerOf2::_1); + + assert_eq!(storage.byte_size(), Some(8)); + // field b is bytes 16..=19 and struct size must be a multiple of the struct align (16) + assert_eq!(uniform.byte_size(), Some(32)); + assert_eq!(packed.byte_size(), Some(8)); + + match (storage, uniform, packed) { + (TypeLayout::Struct(storage), TypeLayout::Struct(uniform), TypeLayout::Struct(packed)) => { + assert_eq!(storage.fields[1].rel_byte_offset, 4); + assert_eq!(uniform.fields[1].rel_byte_offset, 16); + assert_eq!(packed.fields[1].rel_byte_offset, 4); + } + _ => panic!("Unexpected layout kind"), + } + } + + #[test] + fn test_array_in_struct_field_offset() { + let s = |repr| -> TypeLayoutRecipe { + SizedStruct::new("B", "a", Vector::new(ScalarType::F32, Len::X1), repr) + .extend( + "b", + SizedArray::new( + Rc::new(Vector::new(ScalarType::F32, Len::X1).into()), + NonZeroU32::new(1).unwrap(), + ), + ) // offset 4 for storage and packed, offset 16 for uniform + .into() + }; + + let storage = s(Repr::Wgsl).layout(); + let uniform = s(Repr::WgslUniform).layout(); + let packed = s(Repr::Packed).layout(); + + assert_eq!(storage.align(), U32PowerOf2::_4); + assert_eq!(uniform.align(), U32PowerOf2::_16); + assert_eq!(packed.align(), U32PowerOf2::_1); + + assert_eq!(storage.byte_size(), Some(8)); + // field b is bytes 16..=19 and struct size must be a multiple of the struct align (16) + assert_eq!(uniform.byte_size(), Some(32)); + assert_eq!(packed.byte_size(), Some(8)); + + match (storage, uniform, packed) { + (TypeLayout::Struct(storage), TypeLayout::Struct(uniform), TypeLayout::Struct(packed)) => { + assert_eq!(storage.fields[1].rel_byte_offset, 4); + assert_eq!(uniform.fields[1].rel_byte_offset, 16); + assert_eq!(packed.fields[1].rel_byte_offset, 4); + } + _ => panic!("Unexpected layout kind"), + } + } + + #[test] + fn test_unsized_struct_layout() { + let mut unsized_struct = UnsizedStruct { + name: 
CanonName::from("TestStruct"), + repr: Repr::Wgsl, + sized_fields: vec![ + SizedField { + name: CanonName::from("field1"), + custom_min_size: None, + custom_min_align: None, + ty: Vector::new(ScalarType::F32, Len::X2).into(), + }, + SizedField { + name: CanonName::from("field2"), + custom_min_size: None, + custom_min_align: None, + ty: Vector::new(ScalarType::F32, Len::X1).into(), + }, + ], + last_unsized: RuntimeSizedArrayField { + name: CanonName::from("dynamic_array"), + custom_min_align: None, + array: RuntimeSizedArray { + element: Vector::new(ScalarType::F32, Len::X1).into(), + }, + }, + }; + let recipe: TypeLayoutRecipe = unsized_struct.clone().into(); + + let layout = recipe.layout(); + assert_eq!(layout.byte_size(), None); + assert!(layout.align().as_u64() == 8); // align of vec2 + match &layout { + TypeLayout::Struct(struct_layout) => { + assert_eq!(struct_layout.fields.len(), 3); + assert_eq!(struct_layout.fields[0].name, CanonName::from("field1")); + assert_eq!(struct_layout.fields[1].name, CanonName::from("field2")); + assert_eq!(struct_layout.fields[2].name, CanonName::from("dynamic_array")); + + assert_eq!(struct_layout.fields[0].rel_byte_offset, 0); // vec2 + assert_eq!(struct_layout.fields[1].rel_byte_offset, 8); // f32 + assert_eq!(struct_layout.fields[2].rel_byte_offset, 12); // Array + // The last field should be an unsized array + match &struct_layout.fields[2].ty { + TypeLayout::Array(array) => { + assert_eq!(array.byte_size, None); + assert_eq!(array.byte_stride, 4) + } + _ => panic!("Expected runtime-sized array for last field"), + } + } + _ => panic!("Expected structure layout"), + } + + // Testing uniform representation + unsized_struct.repr = Repr::WgslUniform; + let recipe: TypeLayoutRecipe = unsized_struct.into(); + println!("{recipe:#?}"); + let layout = recipe.layout(); + assert_eq!(layout.byte_size(), None); + // Struct alignmment has to be a multiple of 16, but the runtime sized array + // also has an alignment of 16, which 
transfers to the struct alignment. + assert!(layout.align().as_u64() == 16); + match &layout { + TypeLayout::Struct(struct_layout) => { + assert_eq!(struct_layout.fields[0].rel_byte_offset, 0); // vec2 + assert_eq!(struct_layout.fields[1].rel_byte_offset, 8); // f32 + // array has alignment of 16, so offset should be 16 + assert_eq!(struct_layout.fields[2].rel_byte_offset, 16); // Array + match &struct_layout.fields[2].ty { + // Stride has to be a multiple of 16 in uniform address space + TypeLayout::Array(array) => { + assert_eq!(array.byte_size, None); + assert_eq!(array.byte_stride, 16); + } + _ => panic!("Expected runtime-sized array for last field"), + } + } + _ => panic!("Expected structure layout"), + } + } +} diff --git a/shame/src/frontend/rust_types/type_layout/recipe/align_size.rs b/shame/src/frontend/rust_types/type_layout/recipe/align_size.rs new file mode 100644 index 0000000..1c66cef --- /dev/null +++ b/shame/src/frontend/rust_types/type_layout/recipe/align_size.rs @@ -0,0 +1,1163 @@ +use super::super::{Repr}; +use super::*; + +// Size and align of layout recipe types // +// https://www.w3.org/TR/WGSL/#address-space-layout-constraints // + +pub(crate) const PACKED_ALIGN: U32PowerOf2 = U32PowerOf2::_1; + +impl TypeLayoutRecipe { + /// This is expensive for structs. Prefer `byte_size_and_align` if you also need the align. + pub fn byte_size(&self, default_repr: Repr) -> Option { + match self { + TypeLayoutRecipe::Sized(s) => Some(s.byte_size(default_repr)), + TypeLayoutRecipe::UnsizedStruct(_) | TypeLayoutRecipe::RuntimeSizedArray(_) => None, + } + } + + /// This is expensive for structs. Prefer `byte_size_and_align` if you also need the size. 
+ pub fn align(&self, default_repr: Repr) -> U32PowerOf2 { + match self { + TypeLayoutRecipe::Sized(s) => s.align(default_repr), + TypeLayoutRecipe::UnsizedStruct(s) => s.align(), + TypeLayoutRecipe::RuntimeSizedArray(a) => a.align(default_repr), + } + } + + /// This is expensive for structs as it calculates the byte size and align by traversing all fields recursively. + pub fn byte_size_and_align(&self, default_repr: Repr) -> (Option, U32PowerOf2) { + match self { + TypeLayoutRecipe::Sized(s) => { + let (size, align) = s.byte_size_and_align(default_repr); + (Some(size), align) + } + TypeLayoutRecipe::UnsizedStruct(s) => (None, s.align()), + TypeLayoutRecipe::RuntimeSizedArray(a) => (None, a.align(default_repr)), + } + } + + /// Returns a copy of self, but with all struct reprs changed to `repr`. + pub fn to_unified_repr(&self, repr: Repr) -> Self { + let mut this = self.clone(); + this.change_all_repr(repr); + this + } + + /// Recursively changes all struct reprs to the given `repr`. + pub fn change_all_repr(&mut self, repr: Repr) { + match self { + TypeLayoutRecipe::Sized(s) => s.change_all_repr(repr), + TypeLayoutRecipe::UnsizedStruct(s) => s.change_all_repr(repr), + TypeLayoutRecipe::RuntimeSizedArray(a) => a.change_all_repr(repr), + } + } +} + +impl SizedType { + /// This is expensive for structs. Prefer `byte_size_and_align` if you also need the align. + pub fn byte_size(&self, parent_repr: Repr) -> u64 { self.byte_size_and_align(parent_repr).0 } + + /// This is expensive for structs. Prefer `byte_size_and_align` if you also need the size. + pub fn align(&self, parent_repr: Repr) -> U32PowerOf2 { self.byte_size_and_align(parent_repr).1 } + + /// This is expensive for structs as it calculates the byte size and align by traversing all fields recursively. 
+ pub fn byte_size_and_align(&self, parent_repr: Repr) -> (u64, U32PowerOf2) { + let repr = parent_repr; + match self { + SizedType::Array(a) => (a.byte_size(parent_repr), a.align(parent_repr)), + SizedType::Vector(v) => (v.byte_size(parent_repr), v.align(parent_repr)), + SizedType::Matrix(m) => (m.byte_size(parent_repr), m.align(parent_repr)), + SizedType::Atomic(a) => (a.byte_size(), a.align(parent_repr)), + SizedType::PackedVec(v) => (u8::from(v.byte_size()) as u64, v.align(parent_repr)), + SizedType::Struct(s) => s.byte_size_and_align(), + } + } + + /// Recursively changes all struct reprs to the given `repr`. + pub fn change_all_repr(&mut self, repr: Repr) { + match self { + SizedType::Struct(s) => s.change_all_repr(repr), + SizedType::Array(s) => s.change_all_repr(repr), + SizedType::Atomic(_) | SizedType::PackedVec(_) | SizedType::Vector(_) | SizedType::Matrix(_) => { + // No repr to change for these types. + } + } + } +} + +impl SizedStruct { + /// Returns [`FieldOffsetsSized`], which serves as an iterator over the offsets of the + /// fields of this struct. `FieldOffsetsSized::struct_byte_size_and_align` can be + /// used to efficiently obtain the byte_size and align. + pub fn field_offsets(&self) -> FieldOffsetsSized<'_> { + FieldOffsetsSized(FieldOffsets::new(self.fields(), self.repr)) + } + + /// Returns (byte_size, align) + /// + /// This is expensive for structs as it calculates the byte size and align by traversing all fields recursively. + pub fn byte_size_and_align(&self) -> (u64, U32PowerOf2) { self.field_offsets().struct_byte_size_and_align() } + + /// Recursively changes all struct reprs to the given `repr`. + pub fn change_all_repr(&mut self, repr: Repr) { + self.repr = repr; + for field in &mut self.fields { + field.ty.change_all_repr(repr); + } + } +} + +impl UnsizedStruct { + /// Returns [`FieldOffsetsUnsized`]. + /// + /// - Use [`FieldOffsetsUnsized::sized_field_offsets`] for an iterator over the sized field offsets. 
+ /// - Use [`FieldOffsetsUnsized::last_field_offset_and_struct_align`] for the last field's offset + /// and the struct's align + pub fn field_offsets(&self) -> FieldOffsetsUnsized<'_> { + FieldOffsetsUnsized::new(&self.sized_fields, &self.last_unsized, self.repr) + } + + /// This is expensive as it calculates the byte align by traversing all fields recursively. + pub fn align(&self) -> U32PowerOf2 { self.field_offsets().last_field_offset_and_struct_align().1 } + + /// Recursively changes all struct reprs to the given `repr`. + pub fn change_all_repr(&mut self, repr: Repr) { + self.repr = repr; + for field in &mut self.sized_fields { + field.ty.change_all_repr(repr); + } + self.last_unsized.array.change_all_repr(repr); + } +} + +/// An iterator over the offsets of sized fields. +#[derive(Debug)] +pub struct FieldOffsets<'a> { + fields: &'a [SizedField], + field_index: usize, + calc: StructLayoutCalculator, + repr: Repr, +} +impl Iterator for FieldOffsets<'_> { + type Item = u64; + + fn next(&mut self) -> Option { + self.field_index += 1; + self.fields.get(self.field_index - 1).map(|field| { + let (size, align) = field.ty.byte_size_and_align(self.repr); + let is_struct = matches!(field.ty, SizedType::Struct(_)); + + self.calc + .extend(size, align, field.custom_min_size, field.custom_min_align, is_struct) + }) + } +} +impl<'a> FieldOffsets<'a> { + fn new(fields: &'a [SizedField], repr: Repr) -> Self { + Self { + fields, + field_index: 0, + calc: StructLayoutCalculator::new(repr), + repr, + } + } +} + +/// Iterator over the field offsets of a `SizedStruct`. +// The difference to `FieldOffsets` is that it also offers a `struct_byte_size_and_align` method. 
+#[derive(Debug)] +pub struct FieldOffsetsSized<'a>(FieldOffsets<'a>); +impl Iterator for FieldOffsetsSized<'_> { + type Item = u64; + fn next(&mut self) -> Option { self.0.next() } +} +impl<'a> FieldOffsetsSized<'a> { + /// Consumes self and calculates the byte size and align of a struct + /// with exactly the sized fields that this FieldOffsets was created with. + pub fn struct_byte_size_and_align(mut self) -> (u64, U32PowerOf2) { + // Finishing layout calculations + // using count only to advance iterator to the end + (&mut self.0).count(); + (self.0.calc.byte_size(), self.0.calc.align()) + } + + /// Returns the inner iterator over sized fields. + pub fn into_inner(self) -> FieldOffsets<'a> { self.0 } +} + +/// The field offsets of an `UnsizedStruct`. +/// +/// - Use [`FieldOffsetsUnsized::sized_field_offsets`] for an iterator over the sized field offsets. +/// - Use [`FieldOffsetsUnsized::last_field_offset_and_struct_align`] for the last field's offset +/// and the struct's align +pub struct FieldOffsetsUnsized<'a> { + sized: FieldOffsets<'a>, + last_unsized: &'a RuntimeSizedArrayField, +} + +impl<'a> FieldOffsetsUnsized<'a> { + fn new(sized_fields: &'a [SizedField], last_unsized: &'a RuntimeSizedArrayField, repr: Repr) -> Self { + Self { + sized: FieldOffsets::new(sized_fields, repr), + last_unsized, + } + } + + /// Returns an iterator over the sized field offsets. + pub fn sized_field_offsets(&mut self) -> &mut FieldOffsets<'a> { &mut self.sized } + + /// Returns the last field's offset and the struct's align. + pub fn last_field_offset_and_struct_align(mut self) -> (u64, U32PowerOf2) { + // Finishing layout calculations + // using count only to advance iterator to the end + (&mut self.sized).count(); + let array_align = self.last_unsized.array.align(self.sized.repr); + let custom_min_align = self.last_unsized.custom_min_align; + self.sized.calc.extend_unsized(array_align, custom_min_align) + } + + /// Returns the inner iterator over sized fields. 
+ pub fn into_inner(self) -> FieldOffsets<'a> { self.sized } +} + + +#[allow(missing_docs)] +impl Vector { + pub const fn new(scalar: ScalarType, len: Len) -> Self { Self { scalar, len } } + + pub const fn byte_size(&self, repr: Repr) -> u64 { + match repr { + Repr::Wgsl | Repr::WgslUniform | Repr::Packed => self.len.as_u64() * self.scalar.byte_size(), + } + } + + pub const fn align(&self, repr: Repr) -> U32PowerOf2 { + match repr { + Repr::Packed => PACKED_ALIGN, + Repr::Wgsl | Repr::WgslUniform => { + let po2_len = match self.len { + Len::X1 | Len::X2 | Len::X4 => self.len.as_u32(), + Len::X3 => 4, + }; + let po2_align = self.scalar.align(repr); + U32PowerOf2::try_from_u32(po2_len * po2_align.as_u32()).expect( + "power of 2 * power of 2 = power of 2. Highest operands are around 4 * 16 so overflow is unlikely", + ) + } + } + } +} + +#[allow(missing_docs)] +impl ScalarType { + pub const fn byte_size(&self) -> u64 { + match self { + ScalarType::F16 => 2, + ScalarType::F32 | ScalarType::U32 | ScalarType::I32 => 4, + ScalarType::F64 => 8, + } + } + + pub const fn align(&self, repr: Repr) -> U32PowerOf2 { + match repr { + Repr::Packed => PACKED_ALIGN, + Repr::Wgsl | Repr::WgslUniform => match self { + ScalarType::F16 => U32PowerOf2::_2, + ScalarType::F32 | ScalarType::U32 | ScalarType::I32 => U32PowerOf2::_4, + ScalarType::F64 => U32PowerOf2::_8, + }, + } + } +} + +#[allow(missing_docs)] +#[derive(Debug, Clone, Copy)] +pub enum MatrixMajor { + Row, + Column, +} + +#[allow(missing_docs)] +impl Matrix { + pub const fn byte_size(&self, repr: Repr) -> u64 { + let (vec, array_len) = self.as_vector_array(); + // According to https://www.w3.org/TR/WGSL/#alignment-and-size + // SizeOf(matCxR) = SizeOf(array) = C × roundUp(AlignOf(vecR), SizeOf(vecR)) + array_len.get() as u64 * round_up(vec.align(repr).as_u64(), vec.byte_size(repr)) + } + + pub const fn align(&self, repr: Repr) -> U32PowerOf2 { + let (vec, _) = self.as_vector_array(); + // AlignOf(vecR) + vec.align(repr) + } + 
+ const fn as_vector_array(&self) -> (Vector, NonZeroU32) { + let major = MatrixMajor::Column; // This can be made a parameter in the future. + let (vec_len, array_len): (Len, NonZeroU32) = match major { + MatrixMajor::Column => (self.rows.as_len(), self.columns.as_non_zero_u32()), + MatrixMajor::Row => (self.columns.as_len(), self.rows.as_non_zero_u32()), + }; + ( + Vector { + len: vec_len, + scalar: self.scalar.as_scalar_type(), + }, + array_len, + ) + } +} + +#[allow(missing_docs)] +impl Atomic { + pub const fn byte_size(&self) -> u64 { self.scalar.as_scalar_type().byte_size() } + pub const fn align(&self, repr: Repr) -> U32PowerOf2 { + match repr { + Repr::Packed => return PACKED_ALIGN, + Repr::Wgsl | Repr::WgslUniform => {} + } + + self.scalar.as_scalar_type().align(repr) + } +} + +#[allow(missing_docs)] +impl SizedArray { + pub fn byte_size(&self, repr: Repr) -> u64 { array_size(self.byte_stride(repr), self.len) } + + pub fn align(&self, repr: Repr) -> U32PowerOf2 { array_align(self.element.align(repr), repr) } + + pub fn byte_stride(&self, repr: Repr) -> u64 { + let (element_size, element_align) = self.element.byte_size_and_align(repr); + array_stride(element_align, element_size, repr) + } + + // Recursively changes all struct reprs to the given `repr`. + pub fn change_all_repr(&mut self, repr: Repr) { + let mut element = (*self.element).clone(); + element.change_all_repr(repr); + self.element = Rc::new(element); + } +} + +/// Returns an array's size given it's stride and length. +/// +/// Note, this is independent of layout rules (`Repr`). +pub const fn array_size(array_stride: u64, len: NonZeroU32) -> u64 { array_stride * len.get() as u64 } + +/// Returns an array's size given the alignment of it's elements. +pub const fn array_align(element_align: U32PowerOf2, repr: Repr) -> U32PowerOf2 { + match repr { + // Packedness is ensured by the `LayoutCalculator`. 
+ Repr::Wgsl => element_align, + Repr::WgslUniform => round_up_align(U32PowerOf2::_16, element_align), + Repr::Packed => PACKED_ALIGN, + } +} + +/// Returns an array's stride (the distance between consecutive elements) given the alignment and size of its elements. +pub const fn array_stride(element_align: U32PowerOf2, element_size: u64, repr: Repr) -> u64 { + let element_align = match repr { + Repr::Wgsl => element_align, + // This should already be the case, but doesn't hurt to ensure. + Repr::Packed => PACKED_ALIGN, + // The uniform address space also requires that: + // Array elements are aligned to 16 byte boundaries. + // That is, StrideOf(array) = 16 × k’ for some positive integer k'. + // - https://www.w3.org/TR/WGSL/#address-space-layout-constraints + Repr::WgslUniform => round_up_align(U32PowerOf2::_16, element_align), + }; + + round_up(element_align.as_u64(), element_size) +} + +#[allow(missing_docs)] +impl RuntimeSizedArray { + pub fn align(&self, parent_repr: Repr) -> U32PowerOf2 { array_align(self.element.align(parent_repr), parent_repr) } + + pub fn byte_stride(&self, parent_repr: Repr) -> u64 { + array_stride( + self.align(parent_repr), + self.element.byte_size(parent_repr), + parent_repr, + ) + } + + // Recursively changes all struct reprs to the given `repr`. 
+ pub fn change_all_repr(&mut self, repr: Repr) { self.element.change_all_repr(repr); } +} + +#[allow(missing_docs)] +impl SizedField { + pub fn byte_size(&self, repr: Repr) -> u64 { + StructLayoutCalculator::calculate_byte_size(self.ty.byte_size(repr), self.custom_min_size) + } + pub fn align(&self, repr: Repr) -> U32PowerOf2 { + StructLayoutCalculator::calculate_align(self.ty.align(repr), self.custom_min_align, repr) + } +} + +#[allow(missing_docs)] +impl RuntimeSizedArrayField { + pub fn align(&self, repr: Repr) -> U32PowerOf2 { + StructLayoutCalculator::calculate_align(self.array.align(repr), self.custom_min_align, repr) + } +} + +pub const fn round_up(multiple_of: u64, n: u64) -> u64 { + match multiple_of { + 0 => match n { + 0 => 0, + _ => panic!("cannot round up n to a multiple of 0"), + }, + k @ 1.. => n.div_ceil(k) * k, + } +} + +pub const fn round_up_align(multiple_of: U32PowerOf2, n: U32PowerOf2) -> U32PowerOf2 { + let rounded_up = round_up(multiple_of.as_u64(), n.as_u64()); + // n <= multiple_of -> rounded_up = multiple_of + // n > multiple_of -> rounded_up = n, since both are powers of 2, n must already + // be a multiple of multiple_of + // In both cases rounded_up is a power of 2 + U32PowerOf2::try_from_u32(rounded_up as u32).unwrap() +} + +/// `LayoutCalculator` helps calculate the size, align and the field offsets of a struct. +/// +/// If `LayoutCalculator` is created with `repr == Repr::Packed`, provided `field_align`s +/// are ignored and the field is inserted directly after the previous field. However, +/// a `custom_min_align` that is `Some` overwrites the "packedness" of the field. +#[derive(Debug, Clone)] +pub struct StructLayoutCalculator { + next_offset_min: u64, + align: U32PowerOf2, + repr: Repr, +} + +impl StructLayoutCalculator { + /// Creates a new `LayoutCalculator`, which calculates the size, align and + /// the field offsets of a gpu struct. 
+ pub const fn new(repr: Repr) -> Self { + Self { + next_offset_min: 0, + align: U32PowerOf2::_1, + repr, + } + } + + /// Extends the layout by a field. + /// + /// `is_struct` must be true if the field is a struct. + /// + /// Returns the field's offset. + pub const fn extend( + &mut self, + field_size: u64, + mut field_align: U32PowerOf2, + custom_min_size: Option, + custom_min_align: Option, + is_struct: bool, + ) -> u64 { + // Just in case the user didn't already do this. + match self.repr { + Repr::Packed => field_align = PACKED_ALIGN, + Repr::Wgsl | Repr::WgslUniform => {} + } + + let size = Self::calculate_byte_size(field_size, custom_min_size); + let align = Self::calculate_align(field_align, custom_min_align, self.repr); + + let offset = self.next_field_offset(align, custom_min_align); + self.next_offset_min = match (self.repr, is_struct) { + // The uniform address space requires that: + // - If a structure member itself has a structure type S, then the number of + // bytes between the start of that member and the start of any following + // member must be at least roundUp(16, SizeOf(S)). + (Repr::WgslUniform, true) => round_up(16, offset + size), + (Repr::Wgsl | Repr::Packed, _) | (Repr::WgslUniform, false) => offset + size, + }; + self.align = self.align.max(align); + + offset + } + + /// Extends the layout by a runtime sized array field given it's align. + /// + /// Returns (last field offset, align) + /// + /// `self` is consumed, so that no further fields may be extended, because + /// only the last field may be unsized. + pub const fn extend_unsized( + mut self, + mut field_align: U32PowerOf2, + custom_min_align: Option, + ) -> (u64, U32PowerOf2) { + // Just in case the user didn't already do this. 
+ match self.repr { + Repr::Packed => field_align = PACKED_ALIGN, + Repr::Wgsl | Repr::WgslUniform => {} + } + + let align = Self::calculate_align(field_align, custom_min_align, self.repr); + + let offset = self.next_field_offset(align, custom_min_align); + self.align = self.align.max(align); + + (offset, self.align()) + } + + /// Returns the byte size of the struct. + // wgsl spec: + // roundUp(AlignOf(S), justPastLastMember) + // where justPastLastMember = OffsetOfMember(S,N) + SizeOfMember(S,N) + // + // self.next_offset_min is justPastLastMember already. + pub const fn byte_size(&self) -> u64 { round_up(self.align().as_u64(), self.next_offset_min) } + + /// Returns the align of the struct. + pub const fn align(&self) -> U32PowerOf2 { Self::adjust_struct_alignment_for_repr(self.align, self.repr) } + + const fn next_field_offset(&self, field_align: U32PowerOf2, field_custom_min_align: Option) -> u64 { + let field_align = Self::calculate_align(field_align, field_custom_min_align, self.repr); + match (self.repr, field_custom_min_align) { + // Packed always returns self.next_offset_min regardless of custom_min_align + (Repr::Packed, _) => self.next_offset_min, + (Repr::Wgsl | Repr::WgslUniform, _) => round_up(field_align.as_u64(), self.next_offset_min), + } + } + + pub(crate) const fn calculate_byte_size(byte_size: u64, custom_min_size: Option) -> u64 { + // const byte_size.max(custom_min_size.unwrap_or(0)) + if let Some(min_size) = custom_min_size { + if min_size > byte_size { + return min_size; + } + } + byte_size + } + + pub(crate) const fn calculate_align( + align: U32PowerOf2, + custom_min_align: Option, + repr: Repr, + ) -> U32PowerOf2 { + match repr { + Repr::Wgsl | Repr::WgslUniform => { + // const align.max(custom_min_align.unwrap_or(U32PowerOf2::_1)) + if let Some(min_align) = custom_min_align { + align.max(min_align) + } else { + align + } + } + // custom_min_align is ignored in packed structs and the align is always 1. 
+ Repr::Packed => PACKED_ALIGN, + } + } + + const fn adjust_struct_alignment_for_repr(align: U32PowerOf2, repr: Repr) -> U32PowerOf2 { + match repr { + // Packedness is ensured by the `LayoutCalculator`. + Repr::Wgsl => align, + Repr::WgslUniform => round_up_align(U32PowerOf2::_16, align), + Repr::Packed => PACKED_ALIGN, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::any::U32PowerOf2; + use crate::frontend::rust_types::type_layout::Repr; + use crate::ir::{Len, Len2, ScalarTypeFp, ScalarTypeInteger}; + use std::num::NonZeroU32; + use std::rc::Rc; + use super::super::builder::FieldOptions; + + #[test] + fn test_primitives_layout() { + // Testing all aligns and sizes found here (and some more that aren't in the spec like f64) + // https://www.w3.org/TR/WGSL/#alignment-and-size + + // i32, u32, or f32: AlignOf(T) = 4, SizeOf(T) = 4 + assert_eq!(ScalarType::I32.byte_size(), 4); + assert_eq!(ScalarType::I32.align(Repr::Wgsl), U32PowerOf2::_4); + assert_eq!(ScalarType::U32.byte_size(), 4); + assert_eq!(ScalarType::U32.align(Repr::Wgsl), U32PowerOf2::_4); + assert_eq!(ScalarType::F32.byte_size(), 4); + assert_eq!(ScalarType::F32.align(Repr::Wgsl), U32PowerOf2::_4); + // f16: AlignOf(T) = 2, SizeOf(T) = 2 + assert_eq!(ScalarType::F16.byte_size(), 2); + assert_eq!(ScalarType::F16.align(Repr::Wgsl), U32PowerOf2::_2); + // not found in spec + assert_eq!(ScalarType::F64.byte_size(), 8); + assert_eq!(ScalarType::F64.align(Repr::Wgsl), U32PowerOf2::_8); + + // Test atomics + let atomic_u32 = Atomic { + scalar: ScalarTypeInteger::U32, + }; + let atomic_i32 = Atomic { + scalar: ScalarTypeInteger::I32, + }; + // atomic: AlignOf(T) = 4, SizeOf(T) = 4 + assert_eq!(atomic_u32.align(Repr::Wgsl), U32PowerOf2::_4); + assert_eq!(atomic_u32.byte_size(), 4); + assert_eq!(atomic_i32.align(Repr::Wgsl), U32PowerOf2::_4); + assert_eq!(atomic_i32.byte_size(), 4); + + // Test vectors + let vec2_f32 = Vector::new(ScalarType::F32, Len::X2); + let vec2_f16 = 
Vector::new(ScalarType::F16, Len::X2); + let vec3_f32 = Vector::new(ScalarType::F32, Len::X3); + let vec3_f16 = Vector::new(ScalarType::F16, Len::X3); + let vec4_f32 = Vector::new(ScalarType::F32, Len::X4); + let vec4_f16 = Vector::new(ScalarType::F16, Len::X4); + // vec2, T is i32, u32, or f32: AlignOf(T) = 8, SizeOf(T) = 8 + assert_eq!(vec2_f32.align(Repr::Wgsl), U32PowerOf2::_8); + assert_eq!(vec2_f32.byte_size(Repr::Wgsl), 8); + // vec2: AlignOf(T) = 4, SizeOf(T) = 4 + assert_eq!(vec2_f16.align(Repr::Wgsl), U32PowerOf2::_4); + assert_eq!(vec2_f16.byte_size(Repr::Wgsl), 4); + // vec3, T is i32, u32, or f32: AlignOf(T) = 16, SizeOf(T) = 12 + assert_eq!(vec3_f32.align(Repr::Wgsl), U32PowerOf2::_16); + assert_eq!(vec3_f32.byte_size(Repr::Wgsl), 12); + // vec3: AlignOf(T) = 8, SizeOf(T) = 6 + assert_eq!(vec3_f16.align(Repr::Wgsl), U32PowerOf2::_8); + assert_eq!(vec3_f16.byte_size(Repr::Wgsl), 6); + // vec4, T is i32, u32, or f32: AlignOf(T) = 16, SizeOf(T) = 16 + assert_eq!(vec4_f32.align(Repr::Wgsl), U32PowerOf2::_16); + assert_eq!(vec4_f32.byte_size(Repr::Wgsl), 16); + // vec4: AlignOf(T) = 8, SizeOf(T) = 8 + assert_eq!(vec4_f16.align(Repr::Wgsl), U32PowerOf2::_8); + assert_eq!(vec4_f16.byte_size(Repr::Wgsl), 8); + + // Test matrices + let mat2x2_f32 = Matrix { + scalar: ScalarTypeFp::F32, + columns: Len2::X2, + rows: Len2::X2, + }; + let mat2x2_f16 = Matrix { + scalar: ScalarTypeFp::F16, + columns: Len2::X2, + rows: Len2::X2, + }; + let mat3x2_f32 = Matrix { + scalar: ScalarTypeFp::F32, + columns: Len2::X3, + rows: Len2::X2, + }; + let mat3x2_f16 = Matrix { + scalar: ScalarTypeFp::F16, + columns: Len2::X3, + rows: Len2::X2, + }; + let mat4x2_f32 = Matrix { + scalar: ScalarTypeFp::F32, + columns: Len2::X4, + rows: Len2::X2, + }; + let mat4x2_f16 = Matrix { + scalar: ScalarTypeFp::F16, + columns: Len2::X4, + rows: Len2::X2, + }; + let mat2x3_f32 = Matrix { + scalar: ScalarTypeFp::F32, + columns: Len2::X2, + rows: Len2::X3, + }; + let mat2x3_f16 = Matrix { + scalar: 
ScalarTypeFp::F16, + columns: Len2::X2, + rows: Len2::X3, + }; + let mat3x3_f32 = Matrix { + scalar: ScalarTypeFp::F32, + columns: Len2::X3, + rows: Len2::X3, + }; + let mat3x3_f16 = Matrix { + scalar: ScalarTypeFp::F16, + columns: Len2::X3, + rows: Len2::X3, + }; + let mat4x3_f32 = Matrix { + scalar: ScalarTypeFp::F32, + columns: Len2::X4, + rows: Len2::X3, + }; + let mat4x3_f16 = Matrix { + scalar: ScalarTypeFp::F16, + columns: Len2::X4, + rows: Len2::X3, + }; + let mat2x4_f32 = Matrix { + scalar: ScalarTypeFp::F32, + columns: Len2::X2, + rows: Len2::X4, + }; + let mat2x4_f16 = Matrix { + scalar: ScalarTypeFp::F16, + columns: Len2::X2, + rows: Len2::X4, + }; + let mat3x4_f32 = Matrix { + scalar: ScalarTypeFp::F32, + columns: Len2::X3, + rows: Len2::X4, + }; + let mat3x4_f16 = Matrix { + scalar: ScalarTypeFp::F16, + columns: Len2::X3, + rows: Len2::X4, + }; + let mat4x4_f32 = Matrix { + scalar: ScalarTypeFp::F32, + columns: Len2::X4, + rows: Len2::X4, + }; + let mat4x4_f16 = Matrix { + scalar: ScalarTypeFp::F16, + columns: Len2::X4, + rows: Len2::X4, + }; + // mat2x2: AlignOf(T) = 8, SizeOf(T) = 16 + assert_eq!(mat2x2_f32.align(Repr::Wgsl), U32PowerOf2::_8); + assert_eq!(mat2x2_f32.byte_size(Repr::Wgsl), 16); + // mat2x2: AlignOf(T) = 4, SizeOf(T) = 8 + assert_eq!(mat2x2_f16.align(Repr::Wgsl), U32PowerOf2::_4); + assert_eq!(mat2x2_f16.byte_size(Repr::Wgsl), 8); + // mat3x2: AlignOf(T) = 8, SizeOf(T) = 24 + assert_eq!(mat3x2_f32.align(Repr::Wgsl), U32PowerOf2::_8); + assert_eq!(mat3x2_f32.byte_size(Repr::Wgsl), 24); + // mat3x2: AlignOf(T) = 4, SizeOf(T) = 12 + assert_eq!(mat3x2_f16.align(Repr::Wgsl), U32PowerOf2::_4); + assert_eq!(mat3x2_f16.byte_size(Repr::Wgsl), 12); + // mat4x2: AlignOf(T) = 8, SizeOf(T) = 32 + assert_eq!(mat4x2_f32.align(Repr::Wgsl), U32PowerOf2::_8); + assert_eq!(mat4x2_f32.byte_size(Repr::Wgsl), 32); + // mat4x2: AlignOf(T) = 4, SizeOf(T) = 16 + assert_eq!(mat4x2_f16.align(Repr::Wgsl), U32PowerOf2::_4); + 
assert_eq!(mat4x2_f16.byte_size(Repr::Wgsl), 16); + // mat2x3: AlignOf(T) = 16, SizeOf(T) = 32 + assert_eq!(mat2x3_f32.align(Repr::Wgsl), U32PowerOf2::_16); + assert_eq!(mat2x3_f32.byte_size(Repr::Wgsl), 32); + // mat2x3: AlignOf(T) = 8, SizeOf(T) = 16 + assert_eq!(mat2x3_f16.align(Repr::Wgsl), U32PowerOf2::_8); + assert_eq!(mat2x3_f16.byte_size(Repr::Wgsl), 16); + // mat3x3: AlignOf(T) = 16, SizeOf(T) = 48 + assert_eq!(mat3x3_f32.align(Repr::Wgsl), U32PowerOf2::_16); + assert_eq!(mat3x3_f32.byte_size(Repr::Wgsl), 48); + // mat3x3: AlignOf(T) = 8, SizeOf(T) = 24 + assert_eq!(mat3x3_f16.align(Repr::Wgsl), U32PowerOf2::_8); + assert_eq!(mat3x3_f16.byte_size(Repr::Wgsl), 24); + // mat4x3: AlignOf(T) = 16, SizeOf(T) = 64 + assert_eq!(mat4x3_f32.align(Repr::Wgsl), U32PowerOf2::_16); + assert_eq!(mat4x3_f32.byte_size(Repr::Wgsl), 64); + // mat4x3: AlignOf(T) = 8, SizeOf(T) = 32 + assert_eq!(mat4x3_f16.align(Repr::Wgsl), U32PowerOf2::_8); + assert_eq!(mat4x3_f16.byte_size(Repr::Wgsl), 32); + // mat2x4: AlignOf(T) = 16, SizeOf(T) = 32 + assert_eq!(mat2x4_f32.align(Repr::Wgsl), U32PowerOf2::_16); + assert_eq!(mat2x4_f32.byte_size(Repr::Wgsl), 32); + // mat2x4: AlignOf(T) = 8, SizeOf(T) = 16 + assert_eq!(mat2x4_f16.align(Repr::Wgsl), U32PowerOf2::_8); + assert_eq!(mat2x4_f16.byte_size(Repr::Wgsl), 16); + // mat3x4: AlignOf(T) = 16, SizeOf(T) = 48 + assert_eq!(mat3x4_f32.align(Repr::Wgsl), U32PowerOf2::_16); + assert_eq!(mat3x4_f32.byte_size(Repr::Wgsl), 48); + // mat3x4: AlignOf(T) = 8, SizeOf(T) = 24 + assert_eq!(mat3x4_f16.align(Repr::Wgsl), U32PowerOf2::_8); + assert_eq!(mat3x4_f16.byte_size(Repr::Wgsl), 24); + // mat4x4: AlignOf(T) = 16, SizeOf(T) = 64 + assert_eq!(mat4x4_f32.align(Repr::Wgsl), U32PowerOf2::_16); + assert_eq!(mat4x4_f32.byte_size(Repr::Wgsl), 64); + // mat4x4: AlignOf(T) = 8, SizeOf(T) = 32 + assert_eq!(mat4x4_f16.align(Repr::Wgsl), U32PowerOf2::_8); + assert_eq!(mat4x4_f16.byte_size(Repr::Wgsl), 32); + + // Testing Repr::WgslUniform and Repr::Packed // 
+ + let scalars = [ + ScalarType::F16, + ScalarType::F32, + ScalarType::F64, + ScalarType::I32, + ScalarType::U32, + ]; + let atomics = [atomic_u32, atomic_i32]; + let vectors = [vec2_f32, vec2_f16, vec3_f32, vec3_f16, vec4_f32, vec4_f16]; + let matrices = [ + mat2x2_f32, mat2x2_f16, mat3x2_f32, mat3x2_f16, mat4x2_f32, mat4x2_f16, mat2x3_f32, mat2x3_f16, mat3x3_f32, + mat3x3_f16, mat4x3_f32, mat4x3_f16, mat2x4_f32, mat2x4_f16, mat3x4_f32, mat3x4_f16, mat4x4_f32, mat4x4_f16, + ]; + + // Testing + // - byte size for Storage, Uniform and Packed is the same + // - align of Storage and Uniform is the same + // - align of Packed is 1 + for scalar in scalars { + // Looks silly, because byte_size doesn't have a repr argument. + assert_eq!(scalar.byte_size(), scalar.byte_size()); + assert_eq!(scalar.align(Repr::Wgsl), scalar.align(Repr::WgslUniform)); + assert_eq!(scalar.align(Repr::Packed), U32PowerOf2::_1); + } + for atomic in atomics { + // Looks silly, because byte_size doesn't have a repr argument. 
+ assert_eq!(atomic.byte_size(), atomic.byte_size()); + assert_eq!(atomic.align(Repr::Wgsl), atomic.align(Repr::WgslUniform)); + assert_eq!(atomic.align(Repr::Packed), U32PowerOf2::_1); + } + for vector in vectors { + assert_eq!(vector.byte_size(Repr::Wgsl), vector.byte_size(Repr::WgslUniform)); + assert_eq!(vector.align(Repr::Wgsl), vector.align(Repr::WgslUniform)); + assert_eq!(vector.align(Repr::Packed), U32PowerOf2::_1); + } + for matrix in matrices { + assert_eq!(matrix.byte_size(Repr::Wgsl), matrix.byte_size(Repr::WgslUniform)); + assert_eq!(matrix.align(Repr::Wgsl), matrix.align(Repr::WgslUniform)); + assert_eq!(matrix.align(Repr::Packed), U32PowerOf2::_1); + } + } + + #[test] + fn test_sized_array_layout() { + let element = SizedType::Vector(Vector::new(ScalarType::F32, Len::X2)); + let array = SizedArray { + element: Rc::new(element), + len: NonZeroU32::new(5).unwrap(), + }; + + // vec2 is 8 bytes, aligned to 8 bytes + assert_eq!(array.byte_stride(Repr::Wgsl), 8); + assert_eq!(array.byte_size(Repr::Wgsl), 40); // 5 * 8 + assert_eq!(array.align(Repr::Wgsl), U32PowerOf2::_8); + + // Uniform requires 16-byte alignment for array elements + assert_eq!(array.byte_stride(Repr::WgslUniform), 16); + assert_eq!(array.byte_size(Repr::WgslUniform), 80); // 5 * 16 + assert_eq!(array.align(Repr::WgslUniform), U32PowerOf2::_16); + + // Packed has 1-byte alignment + assert_eq!(array.byte_stride(Repr::Packed), 8); + assert_eq!(array.byte_size(Repr::Packed), 40); // 5 * 8 + assert_eq!(array.align(Repr::Packed), U32PowerOf2::_1); + } + + #[test] + fn test_runtime_sized_array_layout() { + let element = SizedType::Vector(Vector::new(ScalarType::F32, Len::X2)); + let array = RuntimeSizedArray { element }; + + assert_eq!(array.byte_stride(Repr::Wgsl), 8); + assert_eq!(array.align(Repr::Wgsl), U32PowerOf2::_8); + + assert_eq!(array.byte_stride(Repr::WgslUniform), 16); + assert_eq!(array.align(Repr::WgslUniform), U32PowerOf2::_16); + + assert_eq!(array.byte_stride(Repr::Packed), 
8); + assert_eq!(array.align(Repr::Packed), U32PowerOf2::_1); + } + + #[test] + fn test_array_size() { + let len = NonZeroU32::new(5).unwrap(); + assert_eq!(array_size(8, len), 40); + assert_eq!(array_size(16, len), 80); + assert_eq!(array_size(1, len), 5); + } + + #[test] + fn test_array_align() { + let element_align = U32PowerOf2::_8; + assert_eq!(array_align(element_align, Repr::Wgsl), U32PowerOf2::_8); + assert_eq!(array_align(element_align, Repr::WgslUniform), U32PowerOf2::_16); + assert_eq!(array_align(element_align, Repr::Packed), U32PowerOf2::_1); + + let small_align = U32PowerOf2::_4; + assert_eq!(array_align(small_align, Repr::Wgsl), U32PowerOf2::_4); + assert_eq!(array_align(small_align, Repr::WgslUniform), U32PowerOf2::_16); + assert_eq!(array_align(small_align, Repr::Packed), U32PowerOf2::_1); + } + + #[test] + fn test_array_stride() { + let element_align = U32PowerOf2::_8; + let element_size = 12; + + // Storage: round up to element alignment + assert_eq!(array_stride(element_align, element_size, Repr::Wgsl), 16); + // Uniform: round up to 16-byte alignment + assert_eq!(array_stride(element_align, element_size, Repr::WgslUniform), 16); + // Packed: round up to 1-byte alignment (no padding) + assert_eq!(array_stride(element_align, element_size, Repr::Packed), 12); + } + + #[test] + fn test_layout_calculator_basic() { + let mut calc = StructLayoutCalculator::new(Repr::Wgsl); + + // Add a u32 field + let offset1 = calc.extend(4, U32PowerOf2::_4, None, None, false); + assert_eq!(offset1, 0); + // Add another u32 field + let offset2 = calc.extend(4, U32PowerOf2::_4, None, None, false); + assert_eq!(offset2, 4); + // Add a vec2 field (8 bytes, 8-byte aligned) + let offset3 = calc.extend(8, U32PowerOf2::_8, None, None, false); + assert_eq!(offset3, 8); + + assert_eq!(calc.byte_size(), 16); + assert_eq!(calc.align(), U32PowerOf2::_8); + } + + #[test] + fn test_layout_calculator_packed() { + let mut calc = StructLayoutCalculator::new(Repr::Packed); + + // Add 
a u32 field - should be packed without padding + let offset1 = calc.extend(4, U32PowerOf2::_4, None, None, false); + assert_eq!(offset1, 0); + // Add a vec2 field - should be packed directly after + let offset2 = calc.extend(8, U32PowerOf2::_8, None, None, false); + assert_eq!(offset2, 4); + + assert_eq!(calc.byte_size(), 12); + assert_eq!(calc.align(), U32PowerOf2::_1); + + // Add a vec2 field - but with custom min align, which is ignored because of Repr::Packed + let offset3 = calc.extend(8, U32PowerOf2::_8, None, Some(U32PowerOf2::_16), false); + assert_eq!(offset3, 12); + assert_eq!(calc.align(), U32PowerOf2::_1); + } + + #[test] + fn test_layout_calculator_uniform_struct_padding() { + let mut calc = StructLayoutCalculator::new(Repr::WgslUniform); + + // Add a nested struct with size 12 + let offset1 = calc.extend(12, U32PowerOf2::_4, None, None, true); + assert_eq!(offset1, 0); + // Add another field - should be padded to 16-byte boundary from struct + let offset2 = calc.extend(4, U32PowerOf2::_4, None, None, false); + assert_eq!(offset2, 16); + + assert_eq!(calc.align(), U32PowerOf2::_16); // Uniform struct alignment is multiple of 16 + assert_eq!(calc.byte_size(), 32); // Byte size of struct is a multiple of it's align + } + + #[test] + fn test_layout_calculator_custom_sizes_and_aligns() { + let mut calc = StructLayoutCalculator::new(Repr::Wgsl); + + // Add field with custom minimum size + let offset1 = calc.extend(4, U32PowerOf2::_4, Some(33), None, false); + assert_eq!(offset1, 0); + assert_eq!(calc.byte_size(), 36); // 33 rounded up to multiple of align + // Add field with custom minimum alignment + let offset2 = calc.extend(4, U32PowerOf2::_4, None, Some(U32PowerOf2::_16), false); + assert_eq!(offset2, 48); + + // 33 -> placed at 48 due to 16 align -> 64 size because rounded up to multiple of align + assert_eq!(calc.byte_size(), 64); + assert_eq!(calc.align(), U32PowerOf2::_16); + } + + #[test] + fn test_layout_calculator_extend_unsized() { + let mut 
calc = StructLayoutCalculator::new(Repr::Wgsl); + + // Add some sized fields first + calc.extend(4, U32PowerOf2::_4, None, None, false); + calc.extend(8, U32PowerOf2::_8, None, None, false); + // Add unsized field + let (offset, align) = calc.extend_unsized(U32PowerOf2::_4, None); + assert_eq!(offset, 16); + assert_eq!(align, U32PowerOf2::_8); + } + + #[test] + fn test_sized_field_calculations() { + // Test custom size + let field = SizedField::new( + FieldOptions::new("test_field", None, Some(16)), + SizedType::Vector(Vector::new(ScalarType::F32, Len::X2)), + ); + // Vector is 8 bytes, but field has custom min size of 16 + assert_eq!(field.byte_size(Repr::Wgsl), 16); + assert_eq!(field.align(Repr::Wgsl), U32PowerOf2::_8); + + // Test custom alignment + let field2 = SizedField::new( + FieldOptions::new("test_field2", Some(U32PowerOf2::_16), None), + SizedType::Vector(Vector::new(ScalarType::F32, Len::X2)), + ); + assert_eq!(field2.align(Repr::Wgsl), U32PowerOf2::_16); + } + + #[test] + fn test_runtime_sized_array_field_align() { + let field = RuntimeSizedArrayField::new( + "test_array", + Some(U32PowerOf2::_16), + SizedType::Vector(Vector::new(ScalarType::F32, Len::X2)), + ); + + // Array has 8-byte alignment, but field has custom min align of 16 + assert_eq!(field.align(Repr::Wgsl), U32PowerOf2::_16); + // Custom min align is ignored by packed + assert_eq!(field.align(Repr::Packed), U32PowerOf2::_1); + } + + #[test] + fn test_sized_struct_layout() { + // Create a struct with mixed field types + let sized_struct = SizedStruct::new( + "TestStruct", + "field1", + SizedType::Vector(Vector::new(ScalarType::F32, Len::X1)), // 4 bytes, 4-byte aligned + Repr::Wgsl, + ) + .extend( + "field2", + SizedType::Vector(Vector::new(ScalarType::F32, Len::X2)), // 8 bytes, 8-byte aligned + ) + .extend( + "field3", + SizedType::Vector(Vector::new(ScalarType::F32, Len::X1)), // 4 bytes, 4-byte aligned + ); + + // Test field offsets + let mut field_offsets = 
sized_struct.field_offsets(); + + // Field 1: offset 0 + assert_eq!(field_offsets.next(), Some(0)); + // Field 2: offset 8 (aligned to 8-byte boundary) + assert_eq!(field_offsets.next(), Some(8)); + // Field 3: offset 16 (directly after field 2) + assert_eq!(field_offsets.next(), Some(16)); + // No more fields + assert_eq!(field_offsets.next(), None); + + // Test struct size and alignment + let (size, align) = sized_struct.byte_size_and_align(); + assert_eq!(size, 24); // Round up to 8-byte alignment: round_up(8, 20) = 24 + assert_eq!(align, U32PowerOf2::_8); + } + + #[test] + fn test_uniform_struct_alignment() { + let sized_struct = SizedStruct::new( + "TestStruct", + "field1", + SizedType::Vector(Vector::new(ScalarType::F32, Len::X2)), // 8 bytes, 8-byte aligned + Repr::WgslUniform, + ); + + let (size, align) = sized_struct.byte_size_and_align(); + + assert_eq!(align, U32PowerOf2::_16); // Alignment adjusted for uniform to multiple of 16 + assert_eq!(size, 16); // Byte size of struct is a multiple of its alignment + } + + #[test] + fn test_unsized_struct_layout() { + // Test UnsizedStruct with sized fields and a runtime sized array + let mut unsized_struct = SizedStruct::new( + "UnsizedStruct", + "field1", + SizedType::Vector(Vector::new(ScalarType::F32, Len::X2)), // 8 bytes, 8-byte aligned + Repr::Wgsl, + ) + .extend( + "field2", + SizedType::Vector(Vector::new(ScalarType::F32, Len::X1)), // 4 bytes, 4-byte aligned + ) + .extend_unsized( + "runtime_array", + None, + SizedType::Vector(Vector::new(ScalarType::F32, Len::X1)), // 4 bytes per element + ); + + // Test field offsets + let mut field_offsets = unsized_struct.field_offsets(); + let sized_offsets: Vec = field_offsets.sized_field_offsets().collect(); + assert_eq!(sized_offsets, vec![0, 8]); // First field at 0, second at 8 + + // Test last field offset and struct alignment + let (last_offset, struct_align) = field_offsets.last_field_offset_and_struct_align(); + assert_eq!(last_offset, 12); // Runtime 
array starts at offset 12 + assert_eq!(struct_align, U32PowerOf2::_8); // Struct alignment is 8 + + // Test struct alignment method + assert_eq!(unsized_struct.align(), U32PowerOf2::_8); + + // Test with different repr + unsized_struct.change_all_repr(Repr::WgslUniform); + let mut field_offsets_uniform = unsized_struct.field_offsets(); + let (last_offset_uniform, struct_align_uniform) = field_offsets_uniform.last_field_offset_and_struct_align(); + assert_eq!(last_offset_uniform, 16); // Different offset in uniform, because array's alignment is 16 + assert_eq!(struct_align_uniform, U32PowerOf2::_16); // Uniform struct alignment + } + + #[test] + fn test_packed_struct_layout() { + let sized_struct = SizedStruct::new( + "TestStruct", + "field1", + SizedType::Vector(Vector::new(ScalarType::F32, Len::X3)), // 8 bytes, 16 align (when not packed) + Repr::Packed, + ) + .extend( + "field2", + SizedType::Vector(Vector::new(ScalarType::F32, Len::X1)), // 4 bytes + ); + + let mut field_offsets = sized_struct.field_offsets(); + + // Field 1: offset 0 + assert_eq!(field_offsets.next(), Some(0)); + // Field 2: offset 12, packed directly after field 1, despite 16 alignment of field1, because packed + assert_eq!(field_offsets.next(), Some(12)); + + let (size, align) = sized_struct.byte_size_and_align(); + assert_eq!(size, 16); + assert_eq!(align, U32PowerOf2::_1); + } + + #[test] + fn test_packed_ignores_custom_min_align() { + let mut calc = StructLayoutCalculator::new(Repr::Packed); + + // Add a u32 field with custom min align of 16 + let offset1 = calc.extend(4, U32PowerOf2::_4, None, Some(U32PowerOf2::_16), false); + assert_eq!(offset1, 0); + // Add a vec2 field - should be packed directly after + let offset2 = calc.extend(8, U32PowerOf2::_8, None, None, false); + assert_eq!(offset2, 4); + + assert_eq!(calc.byte_size(), 12); + assert_eq!(calc.align(), U32PowerOf2::_1); // Packed structs always have align of 1 + + let s = SizedStruct::new( + "TestStruct", + "field1", + 
SizedType::Vector(Vector::new(ScalarType::F32, Len::X2)), // 8 bytes, 8-byte aligned + Repr::Packed, + ) + .extend( + FieldOptions::new("field2", Some(U32PowerOf2::_16), None), + SizedType::Vector(Vector::new(ScalarType::F32, Len::X1)), // 4 bytes, 4-byte aligned + ); + + // The custom min align is ignored in packed structs + assert_eq!(s.byte_size_and_align().1, U32PowerOf2::_1); + let mut offsets = s.field_offsets(); + assert_eq!(offsets.next(), Some(0)); // field1 at offset 0 + assert_eq!(offsets.next(), Some(8)); // field2 at offset 8, because min align is ignored + } +} diff --git a/shame/src/frontend/rust_types/type_layout/recipe/builder.rs b/shame/src/frontend/rust_types/type_layout/recipe/builder.rs new file mode 100644 index 0000000..be1300a --- /dev/null +++ b/shame/src/frontend/rust_types/type_layout/recipe/builder.rs @@ -0,0 +1,259 @@ +use super::*; + +impl TypeLayoutRecipe { + /// Fallibly creates a new `TypeLayoutRecipe` of a struct. + /// + /// An error is returned if the following rules aren't followed: + /// - There must be at least one field. + /// - None of the fields must be an `UnsizedStruct`. + /// - Only the last field may be unsized (a runtime sized array). 
+ pub fn struct_from_parts( + struct_name: impl Into, + fields: impl IntoIterator, + repr: Repr, + ) -> Result { + use StructFromPartsError::*; + + enum Field { + Sized(SizedField), + Unsized(RuntimeSizedArrayField), + } + + let mut fields = fields + .into_iter() + .map(|(options, ty)| { + Ok(match ty { + TypeLayoutRecipe::Sized(s) => Field::Sized(SizedField::new(options, s)), + TypeLayoutRecipe::RuntimeSizedArray(a) => Field::Unsized(RuntimeSizedArrayField::new( + options.name, + options.custom_min_align, + a.element, + )), + TypeLayoutRecipe::UnsizedStruct(_) => return Err(MustNotHaveUnsizedStructField), + }) + }) + .peekable(); + + let mut sized_fields = Vec::new(); + let mut last_unsized = None; + while let Some(field) = fields.next() { + let field = field?; + match field { + Field::Sized(sized) => sized_fields.push(sized), + Field::Unsized(a) => { + last_unsized = Some(a); + if fields.peek().is_some() { + return Err(OnlyLastFieldMayBeUnsized); + } + } + } + } + + let field_count = sized_fields.len() + last_unsized.is_some() as usize; + if field_count == 0 { + return Err(MustHaveAtLeastOneField); + } + + if let Some(last_unsized) = last_unsized { + Ok(UnsizedStruct { + name: struct_name.into(), + sized_fields, + last_unsized, + repr, + } + .into()) + } else { + Ok(SizedStruct::from_parts(struct_name, sized_fields, repr).into()) + } + } +} + +#[allow(missing_docs)] +#[derive(thiserror::Error, Debug)] +pub enum StructFromPartsError { + #[error("Struct must have at least one field.")] + MustHaveAtLeastOneField, + #[error("Only the last field of a struct may be unsized.")] + OnlyLastFieldMayBeUnsized, + #[error("A field of the struct is an unsized struct, which isn't allowed.")] + MustNotHaveUnsizedStructField, +} + +impl SizedStruct { + /// Creates a new `SizedStruct` with one field. + /// + /// To add additional fields to it, use [`SizedStruct::extend`] or [`SizedStruct::extend_unsized`]. 
+ pub fn new( + name: impl Into, + field_options: impl Into, + ty: impl Into, + repr: Repr, + ) -> Self { + Self { + name: name.into(), + fields: vec![SizedField::new(field_options, ty)], + repr, + } + } + + /// Adds a sized field to the struct. + pub fn extend(mut self, field_options: impl Into, ty: impl Into) -> Self { + self.fields.push(SizedField::new(field_options, ty)); + self + } + + /// Adds a runtime sized array field to the struct. This can only be the last + /// field of a struct, which is ensured by transitioning to an UnsizedStruct. + pub fn extend_unsized( + self, + name: impl Into, + custom_min_align: Option, + element_ty: impl Into, + ) -> UnsizedStruct { + UnsizedStruct { + name: self.name, + sized_fields: self.fields, + last_unsized: RuntimeSizedArrayField::new(name, custom_min_align, element_ty), + repr: self.repr, + } + } + + /// Adds either a `SizedType` or a `RuntimeSizedArray` field to the struct. + /// + /// Returns a `TypeLayoutRecipe`, because the `Self` may either stay + /// a `SizedStruct` or become an `UnsizedStruct` depending on the field's type. + pub fn extend_sized_or_array( + self, + field_options: impl Into, + field: SizedOrArray, + ) -> TypeLayoutRecipe { + let options = field_options.into(); + match field { + SizedOrArray::Sized(ty) => self.extend(options, ty).into(), + SizedOrArray::RuntimeSizedArray(a) => self + .extend_unsized(options.name, options.custom_min_align, a.element) + .into(), + } + } + + /// The fields of this struct. 
+ pub fn fields(&self) -> &[SizedField] { &self.fields } + + pub(crate) fn from_parts(name: impl Into, fields: Vec, repr: Repr) -> Self { + Self { + name: name.into(), + fields, + repr, + } + } +} + +#[allow(missing_docs)] +pub enum SizedOrArray { + Sized(SizedType), + RuntimeSizedArray(RuntimeSizedArray), +} + +#[allow(missing_docs)] +#[derive(thiserror::Error, Debug)] +#[error("`LayoutType` is `UnsizedStruct`, which is not a variant of `SizedOrArray`")] +pub struct IsUnsizedStructError; +impl TryFrom for SizedOrArray { + type Error = IsUnsizedStructError; + + fn try_from(value: TypeLayoutRecipe) -> Result { + match value { + TypeLayoutRecipe::Sized(sized) => Ok(SizedOrArray::Sized(sized)), + TypeLayoutRecipe::RuntimeSizedArray(array) => Ok(SizedOrArray::RuntimeSizedArray(array)), + TypeLayoutRecipe::UnsizedStruct(_) => Err(IsUnsizedStructError), + } + } +} + +impl SizedField { + /// Creates a new `SizedField`. + pub fn new(options: impl Into, ty: impl Into) -> Self { + let options = options.into(); + Self { + name: options.name, + custom_min_size: options.custom_min_size, + custom_min_align: options.custom_min_align, + ty: ty.into(), + } + } +} + +impl RuntimeSizedArrayField { + /// Creates a new `RuntimeSizedArrayField` given its field name, + /// an optional custom minimum align and its element type. + pub fn new( + name: impl Into, + custom_min_align: Option, + element_ty: impl Into, + ) -> Self { + Self { + name: name.into(), + custom_min_align, + array: RuntimeSizedArray { + element: element_ty.into(), + }, + } + } +} + +impl SizedArray { + /// Creates a new `SizedArray` from its element type and length. + pub fn new(element_ty: Rc, len: NonZeroU32) -> Self { + Self { + element: element_ty, + len, + } + } +} + +impl RuntimeSizedArray { + /// Creates a new `RuntimeSizedArray` from its element type. + pub fn new(element_ty: impl Into) -> Self { + RuntimeSizedArray { + element: element_ty.into(), + } + } +} + +/// Options for the field of a struct. 
+/// +/// If you only want to customize the field's name, you can convert most string types +/// to `FieldOptions` using `Into::into`. For methods that take `impl Into` +/// parameters you can just pass the string type directly. +#[derive(Debug, Clone)] +pub struct FieldOptions { + /// Name of the field + pub name: CanonName, + /// Custom minimum align of the field. + pub custom_min_align: Option, + /// Custom minimum size of the field. + pub custom_min_size: Option, +} + +impl FieldOptions { + /// Creates new `FieldOptions`. + /// + /// If you only want to customize the field's name, you can convert most string types + /// to `FieldOptions` using `Into::into`. For methods that take `impl Into` + /// parameters you can just pass the string type directly. + pub fn new( + name: impl Into, + custom_min_align: Option, + custom_min_size: Option, + ) -> Self { + Self { + name: name.into(), + custom_min_align, + custom_min_size, + } + } +} + +impl> From for FieldOptions { + fn from(name: T) -> Self { Self::new(name, None, None) } +} diff --git a/shame/src/frontend/rust_types/type_layout/recipe/ir_compat.rs b/shame/src/frontend/rust_types/type_layout/recipe/ir_compat.rs new file mode 100644 index 0000000..329a68b --- /dev/null +++ b/shame/src/frontend/rust_types/type_layout/recipe/ir_compat.rs @@ -0,0 +1,397 @@ +use crate::GpuLayout; + +use super::*; + +// Conversions to ir types // + +/// Errors that can occur when converting recipe types to IR types. +#[derive(thiserror::Error, Debug)] +pub enum IRConversionError { + /// Packed vectors do not exist in the shader type system. + #[error("Type is or contains a packed vector, which does not exist in the shader type system.")] + ContainsPackedVector, + /// Struct field names must be unique in the shader type system. 
+ #[error("{0}")] + DuplicateFieldName(#[from] DuplicateFieldNameError), +} + +#[derive(Debug)] +pub struct DuplicateFieldNameError { + pub struct_type: StructKind, + pub first_occurence: usize, + pub second_occurence: usize, + pub use_color: bool, +} + +impl std::error::Error for DuplicateFieldNameError {} + +impl std::fmt::Display for DuplicateFieldNameError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let (struct_name, sized_fields, last_unsized) = match &self.struct_type { + StructKind::Sized(s) => (&s.name, s.fields(), None), + StructKind::Unsized(s) => (&s.name, s.sized_fields.as_slice(), Some(&s.last_unsized)), + }; + + let indent = " "; + let is_duplicate = |i| self.first_occurence == i || self.second_occurence == i; + let arrow = |i| match is_duplicate(i) { + true => " <--", + false => "", + }; + let color = |f: &mut Formatter<'_>, i| { + if (self.use_color && is_duplicate(i)) { + set_color(f, Some("#508EE3"), false) + } else { + Ok(()) + } + }; + let color_reset = |f: &mut Formatter<'_>, i| { + if self.use_color && is_duplicate(i) { + set_color(f, None, false) + } else { + Ok(()) + } + }; + + writeln!( + f, + "Type contains or is a struct with duplicate field names.\ + Field names must be unique in the shader type system.\n\ + The following struct contains duplicate field names:" + )?; + let header = writeln!(f, "struct {struct_name} {{"); + for (i, field) in sized_fields.iter().enumerate() { + color(f, i)?; + writeln!(f, "{indent}{}: {},{}", field.name, field.ty, arrow(i))?; + color_reset(f, i)?; + } + if let Some(field) = last_unsized { + let i = sized_fields.len(); + color(f, i)?; + writeln!(f, "{indent}{}: {},{}", field.name, field.array, arrow(i))?; + color_reset(f, i)?; + } + writeln!(f, "}}")?; + + Ok(()) + } +} + +#[track_caller] +fn should_use_color() -> bool { + Context::try_with(call_info!(), |ctx| ctx.settings().colored_error_messages).unwrap_or(false) +} + +impl TryFrom for ir::StoreType { + type Error = IRConversionError; 
+ + fn try_from(ty: TypeLayoutRecipe) -> Result { + match ty { + TypeLayoutRecipe::Sized(s) => Ok(ir::StoreType::Sized(s.try_into()?)), + TypeLayoutRecipe::RuntimeSizedArray(s) => Ok(ir::StoreType::RuntimeSizedArray(s.element.try_into()?)), + TypeLayoutRecipe::UnsizedStruct(s) => Ok(ir::StoreType::BufferBlock(s.try_into()?)), + } + } +} + +impl TryFrom for ir::SizedType { + type Error = IRConversionError; + + fn try_from(host: SizedType) -> Result { + Ok(match host { + SizedType::Vector(v) => ir::SizedType::Vector(v.len, v.scalar.into()), + SizedType::Matrix(m) => ir::SizedType::Matrix(m.columns, m.rows, m.scalar), + SizedType::Array(a) => { + let element = Rc::unwrap_or_clone(a.element); + let converted_element = element.try_into()?; + ir::SizedType::Array(Rc::new(converted_element), a.len) + } + SizedType::Atomic(i) => ir::SizedType::Atomic(i.scalar), + SizedType::PackedVec(_) => return Err(IRConversionError::ContainsPackedVector), + SizedType::Struct(s) => ir::SizedType::Structure(s.try_into()?), + }) + } +} + +impl From for ir::ScalarType { + fn from(scalar_type: ScalarType) -> Self { + match scalar_type { + ScalarType::F16 => ir::ScalarType::F16, + ScalarType::F32 => ir::ScalarType::F32, + ScalarType::F64 => ir::ScalarType::F64, + ScalarType::U32 => ir::ScalarType::U32, + ScalarType::I32 => ir::ScalarType::I32, + } + } +} + +impl SizedStruct { + fn fields_split_last(&self) -> (&SizedField, &[SizedField]) { + self.fields.split_last().expect("guaranteed to have at least one field") + } +} + +impl TryFrom for ir::ir_type::SizedStruct { + type Error = IRConversionError; + + fn try_from(ty: SizedStruct) -> Result { + let (last_field, first_fields) = ty.fields_split_last(); + let first_fields: Result, _> = first_fields.iter().map(|f| f.clone().try_into()).collect(); + let last_field_ir = last_field.clone().try_into()?; + + match ir::ir_type::SizedStruct::new_nonempty(ty.name.clone(), first_fields?, last_field_ir) { + Ok(s) => Ok(s), + 
Err(StructureFieldNamesMustBeUnique { + first_occurence, + second_occurence, + }) => Err(IRConversionError::DuplicateFieldName(DuplicateFieldNameError { + struct_type: StructKind::Sized(ty), + first_occurence, + second_occurence, + use_color: should_use_color(), + })), + } + } +} + +impl TryFrom for ir::ir_type::BufferBlock { + type Error = IRConversionError; + + fn try_from(ty: UnsizedStruct) -> Result { + let mut sized_fields: Vec = Vec::new(); + + for field in ty.sized_fields.iter() { + sized_fields.push(field.clone().try_into()?); + } + let last_unsized = ty.last_unsized.clone().try_into()?; + + match ir::ir_type::BufferBlock::new(ty.name.clone(), sized_fields, Some(last_unsized)) { + Ok(b) => Ok(b), + Err(BufferBlockDefinitionError::FieldNamesMustBeUnique(StructureFieldNamesMustBeUnique { + first_occurence, + second_occurence, + })) => Err(IRConversionError::DuplicateFieldName(DuplicateFieldNameError { + struct_type: StructKind::Unsized(ty), + first_occurence, + second_occurence, + use_color: should_use_color(), + })), + Err(BufferBlockDefinitionError::MustHaveAtLeastOneField) => { + unreachable!("last_unsized is at least one field") + } + } + } +} + +impl TryFrom for ir::StoreType { + type Error = IRConversionError; + + fn try_from(array: RuntimeSizedArray) -> Result { + Ok(ir::StoreType::RuntimeSizedArray(array.element.try_into()?)) + } +} + +impl TryFrom for ir::ir_type::SizedField { + type Error = IRConversionError; + + fn try_from(f: SizedField) -> Result { + Ok(ir::SizedField::new( + f.name, + f.custom_min_size, + f.custom_min_align, + f.ty.try_into()?, + )) + } +} + +impl TryFrom for ir::ir_type::RuntimeSizedArrayField { + type Error = IRConversionError; + + fn try_from(f: RuntimeSizedArrayField) -> Result { + Ok(ir::RuntimeSizedArrayField::new( + f.name, + f.custom_min_align, + f.array.element.try_into()?, + )) + } +} + + +// Conversions from ir types // + +/// Type contains bools, which doesn't have a known layout. 
+#[derive(thiserror::Error, Debug)] +#[error("Type contains bools, which doesn't have a known layout.")] +pub struct ContainsBoolsError; + +/// Errors that can occur when converting IR types to recipe types. +#[allow(missing_docs)] +#[derive(thiserror::Error, Debug)] +pub enum RecipeConversionError { + #[error("Type contains bools, which don't have a standardized memory layout on the gpu.")] + ContainsBool, + #[error("Type is a handle, which don't have a standardized memory layout.")] + IsHandle, +} + +impl TryFrom for ScalarType { + type Error = ContainsBoolsError; + + fn try_from(value: ir::ScalarType) -> Result { + Ok(match value { + ir::ScalarType::F16 => ScalarType::F16, + ir::ScalarType::F32 => ScalarType::F32, + ir::ScalarType::F64 => ScalarType::F64, + ir::ScalarType::U32 => ScalarType::U32, + ir::ScalarType::I32 => ScalarType::I32, + ir::ScalarType::Bool => return Err(ContainsBoolsError), + }) + } +} + +impl TryFrom for SizedType { + type Error = ContainsBoolsError; + + fn try_from(value: ir::SizedType) -> Result { + Ok(match value { + ir::SizedType::Vector(len, scalar) => SizedType::Vector(Vector { + scalar: scalar.try_into()?, + len, + }), + ir::SizedType::Matrix(columns, rows, scalar) => SizedType::Matrix(Matrix { scalar, columns, rows }), + ir::SizedType::Array(element, len) => SizedType::Array(SizedArray { + element: Rc::new((*element).clone().try_into()?), + len, + }), + ir::SizedType::Atomic(scalar_type) => SizedType::Atomic(Atomic { scalar: scalar_type }), + ir::SizedType::Structure(structure) => SizedType::Struct(structure.try_into()?), + }) + } +} + +impl TryFrom for SizedStruct { + type Error = ContainsBoolsError; + + fn try_from(structure: ir::ir_type::SizedStruct) -> Result { + let mut fields = Vec::new(); + + for field in structure.fields() { + fields.push(SizedField { + name: field.name.clone(), + custom_min_size: field.custom_min_size, + custom_min_align: field.custom_min_align, + ty: field.ty.clone().try_into()?, + }); + } + + 
Ok(SizedStruct { + name: structure.name().clone(), + fields, + // TODO(chronicl) hardcoding this is a temporary solution. This whole + // TryFrom should be removed in future PRs. + repr: Repr::Wgsl, + }) + } +} + +impl From for RecipeConversionError { + fn from(_: ContainsBoolsError) -> Self { Self::ContainsBool } +} + +impl TryFrom for TypeLayoutRecipe { + type Error = RecipeConversionError; + + fn try_from(value: ir::StoreType) -> Result { + Ok(match value { + ir::StoreType::Sized(sized_type) => TypeLayoutRecipe::Sized(sized_type.try_into()?), + ir::StoreType::RuntimeSizedArray(element) => TypeLayoutRecipe::RuntimeSizedArray(RuntimeSizedArray { + element: element.try_into()?, + }), + ir::StoreType::BufferBlock(buffer_block) => buffer_block.try_into()?, + ir::StoreType::Handle(_) => return Err(RecipeConversionError::IsHandle), + }) + } +} + +impl TryFrom for TypeLayoutRecipe { + type Error = ContainsBoolsError; + + fn try_from(buffer_block: ir::ir_type::BufferBlock) -> Result { + let mut sized_fields = Vec::new(); + + for field in buffer_block.sized_fields() { + sized_fields.push(SizedField { + name: field.name.clone(), + custom_min_size: field.custom_min_size, + custom_min_align: field.custom_min_align, + ty: field.ty.clone().try_into()?, + }); + } + + let last_unsized = if let Some(last_field) = buffer_block.last_unsized_field() { + RuntimeSizedArrayField { + name: last_field.name.clone(), + custom_min_align: last_field.custom_min_align, + array: RuntimeSizedArray { + element: last_field.element_ty.clone().try_into()?, + }, + } + } else { + return Ok(SizedStruct { + name: buffer_block.name().clone(), + fields: sized_fields, + // TODO(chronicl) hardcoding this is a temporary solution. This whole + // TryFrom should be removed in future PRs. + repr: Repr::Wgsl, + } + .into()); + }; + + Ok(UnsizedStruct { + name: buffer_block.name().clone(), + sized_fields, + last_unsized, + // TODO(chronicl) hardcoding this is a temporary solution. 
This whole + // TryFrom should be removed in future PRs. + repr: Repr::Wgsl, + } + .into()) + } +} + +#[derive(Debug, Clone)] +pub enum StructKind { + Sized(SizedStruct), + Unsized(UnsizedStruct), +} + +impl From for StructKind { + fn from(value: SizedStruct) -> Self { StructKind::Sized(value) } +} +impl From for StructKind { + fn from(value: UnsizedStruct) -> Self { StructKind::Unsized(value) } +} + +#[test] +fn test_ir_conversion_error() { + use crate::{f32x1, packed::unorm8x2}; + + let ty: TypeLayoutRecipe = SizedStruct::new("A", "a", f32x1::layout_recipe_sized(), Repr::Wgsl) + .extend("b", f32x1::layout_recipe_sized()) + .extend("a", f32x1::layout_recipe_sized()) + .into(); + let result: Result = ty.try_into(); + assert!(matches!( + result, + Err(IRConversionError::DuplicateFieldName(DuplicateFieldNameError { + struct_type: StructKind::Sized(_), + first_occurence: 0, + second_occurence: 2, + .. + })) + )); + + let ty: TypeLayoutRecipe = SizedStruct::new("A", "a", unorm8x2::layout_recipe_sized(), Repr::Wgsl).into(); + let result: Result = ty.try_into(); + assert!(matches!(result, Err(IRConversionError::ContainsPackedVector))); +} diff --git a/shame/src/frontend/rust_types/type_layout/recipe/mod.rs b/shame/src/frontend/rust_types/type_layout/recipe/mod.rs new file mode 100644 index 0000000..04c349e --- /dev/null +++ b/shame/src/frontend/rust_types/type_layout/recipe/mod.rs @@ -0,0 +1,270 @@ +//! This module defines types that can be laid out in memory. 
+ +use std::{fmt::Formatter, num::NonZeroU32, rc::Rc}; + +use crate::{ + any::U32PowerOf2, + call_info, + common::prettify::set_color, + ir::{self, ir_type::BufferBlockDefinitionError, recording::Context, StructureFieldNamesMustBeUnique}, + GpuSized, +}; + +pub use crate::ir::{Len, Len2, PackedVector, ScalarTypeFp, ScalarTypeInteger, ir_type::CanonName}; +use super::{Repr}; + +pub(crate) mod align_size; +pub(crate) mod builder; +pub(crate) mod ir_compat; +pub(crate) mod to_layout; + +pub use align_size::{FieldOffsets, MatrixMajor, StructLayoutCalculator, array_size, array_stride, array_align}; +pub use builder::{SizedOrArray, FieldOptions}; + +/// `TypeLayoutRecipe` describes how a type should be laid out in memory. +/// +/// It does not contain any layout information itself, but can be converted to a `TypeLayout` +/// using the `TypeLayoutRecipe::layout` method. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum TypeLayoutRecipe { + /// A type with a known size. + Sized(SizedType), + /// A struct with a runtime sized array as it's last field. + UnsizedStruct(UnsizedStruct), + /// An array whose size is determined at runtime. + RuntimeSizedArray(RuntimeSizedArray), +} + +/// Types that have a size which is known at shader creation time. 
+#[allow(missing_docs)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum SizedType { + Vector(Vector), + Matrix(Matrix), + Array(SizedArray), + Atomic(Atomic), + PackedVec(PackedVector), + Struct(SizedStruct), +} + +#[allow(missing_docs)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Vector { + pub scalar: ScalarType, + pub len: Len, +} + +#[allow(missing_docs)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Matrix { + pub scalar: ScalarTypeFp, + pub columns: Len2, + pub rows: Len2, +} + +#[allow(missing_docs)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct SizedArray { + pub element: Rc, + pub len: NonZeroU32, +} + +#[allow(missing_docs)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Atomic { + pub scalar: ScalarTypeInteger, +} + +#[allow(missing_docs)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct RuntimeSizedArray { + pub element: SizedType, +} + +/// Scalar types with known memory layout. +/// +/// Same as `ir::ScalarType`, but without `ScalarType::Bool` since booleans +/// don't have a standardized memory representation. +#[allow(missing_docs)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ScalarType { + F16, + F32, + U32, + I32, + F64, +} + +/// A struct with a known fixed size. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct SizedStruct { + /// The canonical name of the struct. + pub name: CanonName, + // This is private to ensure a `SizedStruct` always has at least one field. + fields: Vec, + /// The representation/layout rules for this struct. See [`Repr`] for more details. + pub repr: Repr, +} + +/// A struct whose size is not known before shader runtime. +/// +/// This struct has a runtime sized array as it's last field. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct UnsizedStruct { + /// The canonical name of the struct. 
+ pub name: CanonName, + /// Fixed-size fields that come before the unsized field + pub sized_fields: Vec, + /// Last runtime sized array field of the struct. + pub last_unsized: RuntimeSizedArrayField, + /// The representation/layout rules for this struct. See [`Repr`] for more details. + pub repr: Repr, +} + +#[allow(missing_docs)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct SizedField { + pub name: CanonName, + pub custom_min_size: Option, + pub custom_min_align: Option, + pub ty: SizedType, +} + +#[allow(missing_docs)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct RuntimeSizedArrayField { + pub name: CanonName, + pub custom_min_align: Option, + pub array: RuntimeSizedArray, +} + +// Conversions to ScalarType, SizedType and TypeLayoutRecipe // + +macro_rules! impl_into_sized_type { + ($($ty:ident -> $variant:path),*) => { + $( + impl From<$ty> for SizedType { + fn from(v: $ty) -> Self { $variant(v) } + } + )* + }; +} +impl_into_sized_type!( + Vector -> SizedType::Vector, + Matrix -> SizedType::Matrix, + SizedArray -> SizedType::Array, + Atomic -> SizedType::Atomic, + SizedStruct -> SizedType::Struct, + PackedVector -> SizedType::PackedVec +); + +impl From for TypeLayoutRecipe +where + SizedType: From, +{ + fn from(value: T) -> Self { TypeLayoutRecipe::Sized(SizedType::from(value)) } +} + +impl From for TypeLayoutRecipe { + fn from(s: UnsizedStruct) -> Self { TypeLayoutRecipe::UnsizedStruct(s) } +} +impl From for TypeLayoutRecipe { + fn from(a: RuntimeSizedArray) -> Self { TypeLayoutRecipe::RuntimeSizedArray(a) } +} + +impl ScalarTypeInteger { + pub const fn as_scalar_type(self) -> ScalarType { + match self { + ScalarTypeInteger::I32 => ScalarType::I32, + ScalarTypeInteger::U32 => ScalarType::U32, + } + } +} +impl From for ScalarType { + fn from(int: ScalarTypeInteger) -> Self { int.as_scalar_type() } +} +impl ScalarTypeFp { + pub const fn as_scalar_type(self) -> ScalarType { + match self { + ScalarTypeFp::F16 => ScalarType::F16, + 
ScalarTypeFp::F32 => ScalarType::F32, + ScalarTypeFp::F64 => ScalarType::F64, + } + } +} +impl From for ScalarType { + fn from(int: ScalarTypeFp) -> Self { int.as_scalar_type() } +} + +// Display impls + +impl std::fmt::Display for TypeLayoutRecipe { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + TypeLayoutRecipe::Sized(s) => s.fmt(f), + TypeLayoutRecipe::RuntimeSizedArray(a) => a.fmt(f), + TypeLayoutRecipe::UnsizedStruct(s) => s.fmt(f), + } + } +} + +impl std::fmt::Display for SizedType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + SizedType::Vector(v) => v.fmt(f), + SizedType::Matrix(m) => m.fmt(f), + SizedType::Array(a) => a.fmt(f), + SizedType::Atomic(a) => a.fmt(f), + SizedType::PackedVec(p) => p.fmt(f), + SizedType::Struct(s) => s.fmt(f), + } + } +} + +impl std::fmt::Display for Vector { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}{}", self.scalar, self.len) } +} + +impl std::fmt::Display for Matrix { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "mat<{}, {}, {}>", + self.scalar, + Len::from(self.columns), + Len::from(self.rows) + ) + } +} + +impl std::fmt::Display for SizedArray { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "Array<{}, {}>", &*self.element, self.len) } +} + +impl std::fmt::Display for Atomic { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "Atomic<{}>", ScalarType::from(self.scalar)) } +} + +impl std::fmt::Display for SizedStruct { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.name) } +} + +impl std::fmt::Display for UnsizedStruct { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.name) } +} + +impl std::fmt::Display for RuntimeSizedArray { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "Array<{}>", self.element) } +} + +impl std::fmt::Display for ScalarType { + fn fmt(&self, 
f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + ScalarType::F16 => "f16", + ScalarType::F32 => "f32", + ScalarType::F64 => "f64", + ScalarType::U32 => "u32", + ScalarType::I32 => "i32", + }) + } +} diff --git a/shame/src/frontend/rust_types/type_layout/recipe/to_layout.rs b/shame/src/frontend/rust_types/type_layout/recipe/to_layout.rs new file mode 100644 index 0000000..2cac7ee --- /dev/null +++ b/shame/src/frontend/rust_types/type_layout/recipe/to_layout.rs @@ -0,0 +1,268 @@ +use std::rc::Rc; +use crate::{ + frontend::rust_types::type_layout::{ + ArrayLayout, FieldLayout, MatrixLayout, PackedVectorLayout, Repr, StructLayout, VectorLayout, + }, + ir, TypeLayout, +}; +use super::{ + Atomic, TypeLayoutRecipe, Matrix, PackedVector, RuntimeSizedArray, SizedArray, SizedField, SizedStruct, SizedType, + UnsizedStruct, Vector, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RecipeContains { + CustomFieldAlign, + CustomFieldSize, + PackedVector, +} + +impl std::fmt::Display for RecipeContains { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RecipeContains::CustomFieldAlign => write!(f, "custom field alignment attribute"), + RecipeContains::CustomFieldSize => write!(f, "custom field size attribute"), + RecipeContains::PackedVector => write!(f, "packed vector"), + } + } +} + +impl TypeLayoutRecipe { + /// Returns the layout of this type recipe using `Repr::default` for top level types + /// that aren't structs. + pub fn layout(&self) -> TypeLayout { self.layout_with_default_repr(Repr::default()) } + + /// Returns the layout of this type recipe using the given `default_repr` for top level types + /// that aren't structs. 
+ pub fn layout_with_default_repr(&self, default_repr: Repr) -> TypeLayout { + match self { + TypeLayoutRecipe::Sized(ty) => ty.layout(default_repr), + TypeLayoutRecipe::UnsizedStruct(ty) => ty.layout().into(), + TypeLayoutRecipe::RuntimeSizedArray(ty) => ty.layout(default_repr).into(), + } + } + + /// Checks whether `RecipeContains` is contained in this type recipe. + pub fn contains(&self, c: RecipeContains) -> bool { + match self { + TypeLayoutRecipe::Sized(ty) => ty.contains(c), + TypeLayoutRecipe::UnsizedStruct(ty) => ty.contains(c), + TypeLayoutRecipe::RuntimeSizedArray(ty) => ty.contains(c), + } + } +} + +#[allow(missing_docs)] +impl SizedType { + pub fn layout(&self, parent_repr: Repr) -> TypeLayout { + match &self { + SizedType::Vector(v) => v.layout(parent_repr).into(), + SizedType::Atomic(a) => a.layout(parent_repr).into(), + SizedType::Matrix(m) => m.layout(parent_repr).into(), + SizedType::Array(a) => a.layout(parent_repr).into(), + SizedType::PackedVec(v) => v.layout(parent_repr).into(), + SizedType::Struct(s) => s.layout().into(), + } + } + + pub fn contains(&self, c: RecipeContains) -> bool { + match c { + RecipeContains::CustomFieldAlign | RecipeContains::CustomFieldSize | RecipeContains::PackedVector => { + match self { + SizedType::Vector(_) => false, + SizedType::Atomic(_) => false, + SizedType::Matrix(_) => false, + SizedType::Array(a) => a.contains(c), + SizedType::PackedVec(_) => c == RecipeContains::PackedVector, + SizedType::Struct(s) => s.contains(c), + } + } + } + } +} + +#[allow(missing_docs)] +impl Vector { + pub fn layout(&self, parent_repr: Repr) -> VectorLayout { + VectorLayout { + byte_size: self.byte_size(parent_repr), + align: self.align(parent_repr).into(), + ty: *self, + debug_is_atomic: false, + } + } +} + +#[allow(missing_docs)] +impl Matrix { + pub fn layout(&self, parent_repr: Repr) -> MatrixLayout { + MatrixLayout { + byte_size: self.byte_size(parent_repr), + align: self.align(parent_repr).into(), + ty: *self, + } + } +} + 
+#[allow(missing_docs)] +impl Atomic { + pub fn layout(&self, parent_repr: Repr) -> VectorLayout { + // Atomic types are represented as vectors of length 1. + let vector = Vector::new(self.scalar.into(), ir::Len::X1); + let mut layout = vector.layout(parent_repr); + layout.debug_is_atomic = true; + layout + } +} + +impl PackedVector { + pub fn layout(&self, parent_repr: Repr) -> PackedVectorLayout { + PackedVectorLayout { + byte_size: self.byte_size().as_u64(), + align: self.align(parent_repr).into(), + ty: *self, + } + } +} + +#[allow(missing_docs)] +impl SizedArray { + pub fn layout(&self, parent_repr: Repr) -> ArrayLayout { + ArrayLayout { + byte_size: self.byte_size(parent_repr).into(), + align: self.align(parent_repr).into(), + byte_stride: self.byte_stride(parent_repr), + element_ty: self.element.layout(parent_repr), + len: Some(self.len.get()), + } + } + + pub fn contains(&self, c: RecipeContains) -> bool { + match c { + RecipeContains::CustomFieldAlign | RecipeContains::CustomFieldSize | RecipeContains::PackedVector => { + self.element.contains(c) + } + } + } +} + +#[allow(missing_docs)] +impl SizedStruct { + pub fn layout(&self) -> StructLayout { + let mut field_offsets = self.field_offsets(); + let fields = (&mut field_offsets) + .zip(self.fields()) + .map(|(offset, field)| sized_field_to_field_layout(field, offset, self.repr)) + .collect::>(); + + let (byte_size, align) = field_offsets.struct_byte_size_and_align(); + + StructLayout { + byte_size: Some(byte_size), + align: align.into(), + name: self.name.clone().into(), + fields, + } + } + + pub fn contains(&self, c: RecipeContains) -> bool { + let mut contains = false; + match c { + RecipeContains::CustomFieldAlign => { + contains |= self.fields.iter().any(|f| f.custom_min_align.is_some()); + } + RecipeContains::CustomFieldSize => { + contains |= self.fields.iter().any(|f| f.custom_min_size.is_some()); + } + RecipeContains::PackedVector => {} + } + for field in self.fields.iter() { + contains |= 
field.ty.contains(c); + } + contains + } +} + +fn sized_field_to_field_layout(field: &SizedField, offset: u64, repr: Repr) -> FieldLayout { + let mut ty = field.ty.layout(repr); + // VERY IMPORTANT: TypeLayout::from_sized_type does not take into account + // custom_min_align and custom_min_size, but field.byte_size and field.align do. + ty.set_byte_size(field.byte_size(repr)); + *ty.align_mut() = field.align(repr); + + FieldLayout { + rel_byte_offset: offset, + name: field.name.clone(), + ty, + } +} + +#[allow(missing_docs)] +impl UnsizedStruct { + pub fn layout(&self) -> StructLayout { + let mut field_offsets = self.field_offsets(); + let mut fields = (&mut field_offsets.sized_field_offsets()) + .zip(self.sized_fields.iter()) + .map(|(offset, field)| sized_field_to_field_layout(field, offset, self.repr)) + .collect::>(); + + let (field_offset, align) = field_offsets.last_field_offset_and_struct_align(); + + let mut ty = self.last_unsized.array.layout(self.repr); + // VERY IMPORTANT: TypeLayout::from_runtime_sized_array does not take into account + // custom_min_align, but s.last_unsized.align does. 
+ ty.align = self.last_unsized.align(self.repr).into(); + + fields.push(FieldLayout { + rel_byte_offset: field_offset, + name: self.last_unsized.name.clone(), + ty: ty.into(), + }); + + StructLayout { + byte_size: None, + align: align.into(), + name: self.name.clone().into(), + fields, + } + } + + pub fn contains(&self, c: RecipeContains) -> bool { + let mut contains = false; + match c { + RecipeContains::CustomFieldAlign => { + contains |= self.sized_fields.iter().any(|f| f.custom_min_align.is_some()); + contains |= self.last_unsized.custom_min_align.is_some(); + } + RecipeContains::CustomFieldSize => { + contains |= self.sized_fields.iter().any(|f| f.custom_min_size.is_some()); + } + RecipeContains::PackedVector => {} + } + for field in self.sized_fields.iter() { + contains |= field.ty.contains(c); + } + contains + } +} + +#[allow(missing_docs)] +impl RuntimeSizedArray { + pub fn layout(&self, parent_repr: Repr) -> ArrayLayout { + ArrayLayout { + byte_size: None, + align: self.align(parent_repr).into(), + byte_stride: self.byte_stride(parent_repr), + element_ty: self.element.layout(parent_repr), + len: None, + } + } + + pub fn contains(&self, c: RecipeContains) -> bool { + match c { + RecipeContains::CustomFieldAlign | RecipeContains::CustomFieldSize | RecipeContains::PackedVector => { + self.element.contains(c) + } + } + } +} diff --git a/shame/src/frontend/rust_types/type_traits.rs b/shame/src/frontend/rust_types/type_traits.rs index 0157f57..927effc 100644 --- a/shame/src/frontend/rust_types/type_traits.rs +++ b/shame/src/frontend/rust_types/type_traits.rs @@ -3,9 +3,13 @@ use super::{ layout_traits::{FromAnys, GetAllFields, GpuLayout}, mem::{self, AddressSpace}, reference::{AccessMode, AccessModeReadable}, + type_layout::{self}, AsAny, GpuType, ToGpuType, }; -use crate::frontend::any::shared_io::{BindPath, BindingType}; +use crate::{ + frontend::any::shared_io::{BindPath, BindingType}, + TypeLayout, +}; use crate::{ call_info, 
common::proc_macro_utils::push_wrong_amount_of_args_error, @@ -168,7 +172,7 @@ pub trait GpuSized: GpuAligned { )] /// ## known byte-alignment on the gpu /// types that have a byte-alignment on the graphics device that is known at rust compile-time -pub trait GpuAligned: GpuLayout { +pub trait GpuAligned { #[doc(hidden)] // runtime api fn aligned_ty() -> AlignedType where @@ -179,41 +183,44 @@ pub trait GpuAligned: GpuLayout { message = "`{Self}` may contain `bool`s, which have an unspecified memory footprint on the graphics device." )] // implementor note: -// NoXYZ traits must require GpuLayout or some other base trait, so that the -// error message isn't misleading for user provided types `T`. Those types will show +// NoXYZ traits should require some other base trait, so that the +// error message isn't misleading for user provided types `T`. Those types will then show // the base trait diagnostic, instead of "`T` contains `XYZ`" which it doesn't. /// types that don't contain booleans at any nesting level /// /// boolean types do not have a defined size on gpus. /// You may want to use unsigned integers for transferring boolean data instead. -pub trait NoBools: GpuLayout {} +pub trait NoBools {} /// (no documentation yet) #[diagnostic::on_unimplemented( message = "`{Self}` may be or contain a `shame::Atomic` type. Atomics are usable via `shame::BufferRef<_, Storage, ReadWrite>` or via allocations in workgroup memory" )] // implementor note: -// NoXYZ traits must require GpuLayout or some other base trait, so that the -// error message isn't misleading for user provided types `T`. Those types will show +// NoXYZ traits should require some other base trait, so that the +// error message isn't misleading for user provided types `T`. Those types will then show // the base trait diagnostic, instead of "`T` contains `XYZ`" which it doesn't. 
/// types that don't contain atomics at any nesting level -pub trait NoAtomics: GpuLayout {} +pub trait NoAtomics {} -// implementor note: -// NoXYZ traits must require GpuLayout or some other base trait, so that the -// error message isn't misleading for user provided types `T`. Those types will show -// the base trait diagnostic, instead of "`T` contains `XYZ`" which it doesn't. #[diagnostic::on_unimplemented( message = "`{Self}` may be or contain a handle type such as `Texture`, `Sampler`, `StorageTexture`." )] +// implementor note: +// NoXYZ traits should require some other base trait, so that the +// error message isn't misleading for user provided types `T`. Those types will then show +// the base trait diagnostic, instead of "`T` contains `XYZ`" which it doesn't. + /// Implemented by types that aren't/contain no textures, storage textures, their array variants or samplers -pub trait NoHandles: GpuLayout {} +pub trait NoHandles {} /// this trait is only implemented by: /// /// * `sm::vec`s of non-boolean type (e.g. `sm::f32x4`) /// * `sm::packed::PackedVec`s (e.g. `sm::packed::unorm8x4`) -pub trait VertexAttribute: GpuLayout + FromAnys { +// Is at most 16 bytes according to https://www.w3.org/TR/WGSL/#input-output-locations +// and thus GpuSized. 
+pub trait VertexAttribute: GpuLayout + FromAnys + GpuSized { #[doc(hidden)] // runtime api fn vertex_attrib_format() -> VertexAttribFormat; } diff --git a/shame/src/frontend/rust_types/vec.rs b/shame/src/frontend/rust_types/vec.rs index 977c2e4..d06b34b 100644 --- a/shame/src/frontend/rust_types/vec.rs +++ b/shame/src/frontend/rust_types/vec.rs @@ -7,11 +7,11 @@ use super::{ mem::AddressSpace, reference::{AccessMode, AccessModeReadable}, scalar_type::{dtype_as_scalar_from_f64, ScalarType, ScalarTypeInteger, ScalarTypeNumber}, - type_layout::TypeLayoutRules, type_traits::{BindingArgs, GpuAligned, GpuStoreImplCategory, NoAtomics, NoHandles, VertexAttribute}, AsAny, GpuType, To, ToGpuType, }; use crate::{ + any::layout::{self}, call_info, common::{ proc_macro_utils::{collect_into_array_exact, push_wrong_amount_of_args_error}, @@ -68,7 +68,7 @@ pub type scalar = vec; /// let my_vec3 = sm::vec!(1.0, 2.0, 3.0); /// let my_vec4 = sm::vec!(my_vec3, 0.0); // component concatenation, like usual in shaders /// let my_vec4 = my_vec3.extend(0.0); // or like this -/// +/// /// let my_normal = sm::vec!(1.0, 1.0, 0.0).normalize(); /// let rgb = my_normal.remap(-1.0..=1.0, 0.0..=1.0); // remap linear ranges (instead of " * 0.5 + 0.5") /// @@ -572,8 +572,19 @@ impl GpuStore for vec { fn impl_category() -> GpuStoreImplCategory { GpuStoreImplCategory::GpuType(Self::store_ty()) } } -impl GpuLayout for vec { - fn gpu_layout() -> TypeLayout { TypeLayout::from_sized_ty(TypeLayoutRules::Wgsl, &::sized_ty()) } +impl GpuLayout for vec +where + vec: NoBools, +{ + fn layout_recipe() -> layout::TypeLayoutRecipe { + layout::Vector::new( + T::SCALAR_TYPE + .try_into() + .expect("guaranteed via `NoBools` trait bound above"), + L::LEN, + ) + .into() + } fn cpu_type_name_and_layout() -> Option, TypeLayout), super::layout_traits::ArrayElementsUnsizedError>> @@ -1095,7 +1106,14 @@ impl VertexAttribute for vec where Self: NoBools, { - fn vertex_attrib_format() -> VertexAttribFormat { 
VertexAttribFormat::Fine(L::LEN, T::SCALAR_TYPE) } + fn vertex_attrib_format() -> VertexAttribFormat { + VertexAttribFormat::Fine( + L::LEN, + T::SCALAR_TYPE + .try_into() + .expect("Self: NoBools bound on impl ensures no bools"), + ) + } } impl FromAnys for vec { diff --git a/shame/src/ir/ir_type/align_size.rs b/shame/src/ir/ir_type/align_size.rs index 3916e09..1ed703a 100644 --- a/shame/src/ir/ir_type/align_size.rs +++ b/shame/src/ir/ir_type/align_size.rs @@ -8,11 +8,11 @@ use thiserror::Error; use super::{CanonName, Len2, ScalarType, ScalarTypeFp, SizedType, StoreType}; use super::{Len, Std}; -pub fn round_up(multiple_of: u64, n: u64) -> u64 { +pub const fn round_up(multiple_of: u64, n: u64) -> u64 { match multiple_of { 0 => match n { 0 => 0, - n => panic!("cannot round up {n} to a multiple of 0"), + n => panic!("cannot round up n to a multiple of 0"), }, k @ 1.. => n.div_ceil(k) * k, } diff --git a/shame/src/ir/ir_type/layout_constraints.rs b/shame/src/ir/ir_type/layout_constraints.rs index 8d7b80f..7dcedcc 100644 --- a/shame/src/ir/ir_type/layout_constraints.rs +++ b/shame/src/ir/ir_type/layout_constraints.rs @@ -6,14 +6,15 @@ use std::{ use thiserror::Error; -use crate::{common::proc_macro_reexports::TypeLayoutRules, frontend::rust_types::type_layout::TypeLayout}; +use crate::{ + common::proc_macro_utils::CpuLayoutImplMismatch, + frontend::rust_types::type_layout::{display::LayoutInfoFlags, eq::CheckEqLayoutMismatch, TypeLayout}, +}; use crate::{ backend::language::Language, call_info, common::prettify::set_color, - frontend::{ - any::shared_io::BufferBindingType, encoding::EncodingErrorKind, rust_types::type_layout::LayoutMismatch, - }, + frontend::{any::shared_io::BufferBindingType, encoding::EncodingErrorKind}, ir::{ ir_type::{max_u64_po2_dividing, AccessModeReadable}, recording::{Context, MemoryRegion}, @@ -470,7 +471,7 @@ Type `{}` contains type `{struct_or_block_name}` which has a custom byte-alignme #[error("custom size of {custom} is too small. 
`{ty}` must have a size of at least {required}")] CustomSizeTooSmall { custom: u64, required: u64, ty: Type }, #[error("memory layout mismatch:\n{0}\n{}", if let Some(comment) = .1 {comment.as_str()} else {""})] - LayoutMismatch(LayoutMismatch, Option), + LayoutMismatch(CheckEqLayoutMismatch, Option), #[error("runtime-sized type {name} cannot be element in an array buffer")] UnsizedStride { name: String }, #[error( @@ -482,6 +483,8 @@ Type `{}` contains type `{struct_or_block_name}` which has a custom byte-alignme gpu_name: String, gpu_stride: u64, }, + #[error(transparent)] + CpuLayoutImplMismatch(#[from] CpuLayoutImplMismatch), } #[allow(missing_docs)] @@ -507,9 +510,9 @@ impl Display for ArrayStrideAlignmentError { "The array with `{}` elements requires that every element is {expected_align}-byte aligned, but the array has a stride of {actual_stride} bytes, which means subsequent elements are not {expected_align}-byte aligned.", self.element_ty ); - if let Ok(layout) = TypeLayout::from_store_ty(TypeLayoutRules::Wgsl, &self.ctx.top_level_type) { + if let Ok(layout) = TypeLayout::from_store_ty(self.ctx.top_level_type.clone()) { writeln!(f, "The full layout of `{}` is:", self.ctx.top_level_type); - layout.write("", self.ctx.use_color, f)?; + layout.write(f, LayoutInfoFlags::ALL)?; writeln!(f); }; writeln!( @@ -542,9 +545,9 @@ impl Display for ArrayStrideError { "The array with `{}` elements requires stride {}, but has stride {}.", self.element_ty, self.expected, self.actual ); - if let Ok(layout) = TypeLayout::from_store_ty(TypeLayoutRules::Wgsl, &self.ctx.top_level_type) { + if let Ok(layout) = TypeLayout::from_store_ty(self.ctx.top_level_type.clone()) { writeln!(f, "The full layout of `{}` is:", self.ctx.top_level_type); - layout.write("", self.ctx.use_color, f)?; + layout.write(f, LayoutInfoFlags::ALL)?; writeln!(f); }; writeln!( @@ -577,9 +580,9 @@ impl Display for ArrayAlignmentError { "The array with `{}` elements requires alignment {}, but has alignment 
{}.", self.element_ty, self.expected, self.actual ); - if let Ok(layout) = TypeLayout::from_store_ty(TypeLayoutRules::Wgsl, &self.ctx.top_level_type) { + if let Ok(layout) = TypeLayout::from_store_ty(self.ctx.top_level_type.clone()) { writeln!(f, "The full layout of `{}` is:", self.ctx.top_level_type); - layout.write("", self.ctx.use_color, f)?; + layout.write(f, LayoutInfoFlags::ALL)?; writeln!(f); }; writeln!( diff --git a/shame/src/ir/ir_type/struct_.rs b/shame/src/ir/ir_type/struct_.rs index fbfc9f1..25b0863 100644 --- a/shame/src/ir/ir_type/struct_.rs +++ b/shame/src/ir/ir_type/struct_.rs @@ -10,7 +10,7 @@ use thiserror::Error; use super::{align_of_array, canon_name::CanonName, round_up, LayoutError, SizedType, StoreType, Type}; use crate::{ call_info, - common::{iterator_ext::IteratorExt, po2::U32PowerOf2, pool::Key}, + common::{format::numeral_suffix, iterator_ext::IteratorExt, po2::U32PowerOf2, pool::Key}, ir::recording::{Context, Ident}, }; use crate::{ @@ -41,8 +41,8 @@ pub enum StructureDefinitionError { "runtime sized arrays are only allowed as the last field of a buffer-block struct. They are not allowed in sized structs." 
)] RuntimeSizedArrayNotAllowedInSizedStruct, - #[error("field names must be unique within a structure definition")] - FieldNamesMustBeUnique, + #[error(transparent)] + FieldNamesMustBeUnique(#[from] StructureFieldNamesMustBeUnique), } pub trait Field { @@ -164,10 +164,7 @@ impl Struct { first_sized_fields.push(last_sized_field); let sized_fields = first_sized_fields; assert!(!sized_fields.is_empty()); - use crate::common::iterator_ext::IteratorExt; - if !sized_fields.iter().all_unique_by(|a, b| a.name == b.name) { - return Err(StructureFieldNamesMustBeUnique); - } + check_for_duplicate_field_names(&sized_fields, None)?; let struct_ = Rc::new(Self { kind: StructKind::Sized, name, @@ -210,9 +207,8 @@ impl Struct { ctx.latest_user_caller(), ); }); - if !struct_.fields().all_unique_by(|a, b| a.name() == b.name()) { - return Err(StructureDefinitionError::FieldNamesMustBeUnique); - } + check_for_duplicate_field_names(&struct_.sized_fields, struct_.last_unsized.as_ref())?; + Ok(struct_) } @@ -425,8 +421,49 @@ impl TryFrom> for BufferBlock { } } + + /// an error created if a struct contains two or more fields of the same name -pub struct StructureFieldNamesMustBeUnique; +#[allow(missing_docs)] +#[derive(Debug, Clone, PartialEq, Eq, Error)] +#[error("{} and {} struct fields have the same name. Field names must be unique within a structure definition", + numeral_suffix(self.first_occurence + 1), + numeral_suffix(self.second_occurence + 1) +)] +pub struct StructureFieldNamesMustBeUnique { + pub first_occurence: usize, + pub second_occurence: usize, +} + +fn check_for_duplicate_field_names( + sized_fields: &[SizedField], + last_unsized: Option<&RuntimeSizedArrayField>, +) -> Result<(), StructureFieldNamesMustBeUnique> { + // Brute force search > HashMap for the amount of fields + // we'd usually deal with. 
+ let mut duplicate_fields = None; + for (i, field1) in sized_fields.iter().enumerate() { + for (j, field2) in sized_fields.iter().enumerate().skip(i + 1) { + if field1.name == field2.name { + duplicate_fields = Some((i, j)); + break; + } + } + if let Some(last_unsized) = last_unsized { + if field1.name == last_unsized.name { + duplicate_fields = Some((i, sized_fields.len())); + break; + } + } + } + match duplicate_fields { + Some((first_occurence, second_occurence)) => Err(StructureFieldNamesMustBeUnique { + first_occurence, + second_occurence, + }), + None => Ok(()), + } +} impl SizedStruct { #[track_caller] @@ -467,12 +504,12 @@ impl SizedStruct { } #[allow(missing_docs)] -#[derive(Error, Debug, Clone, Copy)] +#[derive(Error, Debug, Clone)] pub enum BufferBlockDefinitionError { #[error("buffer block must have at least one field")] MustHaveAtLeastOneField, #[error("field names of a buffer block must be unique")] - FieldNamesMustBeUnique, + FieldNamesMustBeUnique(StructureFieldNamesMustBeUnique), } impl BufferBlock { @@ -489,7 +526,7 @@ impl BufferBlock { E::WrongStructKindForStructType(_, _) => unreachable!("error never created by Struct::new"), E::RuntimeSizedArrayNotAllowedInSizedStruct => unreachable!("not a sized struct"), E::MustHaveAtLeastOneField(_) => BufferBlockDefinitionError::MustHaveAtLeastOneField, - E::FieldNamesMustBeUnique => BufferBlockDefinitionError::FieldNamesMustBeUnique, + E::FieldNamesMustBeUnique(e) => BufferBlockDefinitionError::FieldNamesMustBeUnique(e), }), } } diff --git a/shame/src/ir/ir_type/tensor.rs b/shame/src/ir/ir_type/tensor.rs index 9933b87..4b87fa6 100644 --- a/shame/src/ir/ir_type/tensor.rs +++ b/shame/src/ir/ir_type/tensor.rs @@ -1,7 +1,9 @@ use std::{fmt::Display, num::NonZeroU32}; use crate::{ + any::{U32PowerOf2}, common::floating_point::{f16, f32_eq_where_nans_are_equal, f64_eq_where_nans_are_equal}, + frontend::rust_types::type_layout::{self, recipe::align_size::PACKED_ALIGN, Repr}, ir::Comp4, }; @@ -73,23 +75,11 @@ 
impl PartialEq for Len { } impl From for u64 { - fn from(value: Len2) -> Self { - match value { - Len2::X2 => 2, - Len2::X3 => 3, - Len2::X4 => 4, - } - } + fn from(value: Len2) -> Self { value.as_u64() } } impl From for NonZeroU32 { - fn from(value: Len2) -> Self { - match value { - Len2::X2 => NonZeroU32::new(2).unwrap(), - Len2::X3 => NonZeroU32::new(3).unwrap(), - Len2::X4 => NonZeroU32::new(4).unwrap(), - } - } + fn from(value: Len2) -> Self { value.as_non_zero_u32() } } impl PartialEq for Len2 { @@ -102,6 +92,54 @@ impl Len { use Comp4::*; [X, Y, Z, W].into_iter().take(self.into()) } + + /// as u8 + pub const fn as_u8(self) -> u8 { + match self { + Len::X1 => 1, + Len::X2 => 2, + Len::X3 => 3, + Len::X4 => 4, + } + } + /// as u32 + pub const fn as_u32(self) -> u32 { self.as_u8() as u32 } + /// as u64 + pub const fn as_u64(self) -> u64 { self.as_u8() as u64 } + /// as `NonZeroU32` + pub const fn as_non_zero_u32(self) -> NonZeroU32 { + match self { + Len::X1 => NonZeroU32::new(1).unwrap(), + Len::X2 => NonZeroU32::new(2).unwrap(), + Len::X3 => NonZeroU32::new(3).unwrap(), + Len::X4 => NonZeroU32::new(4).unwrap(), + } + } +} + +impl Len2 { + /// as u8 + pub const fn as_u8(self) -> u8 { + match self { + Len2::X2 => 2, + Len2::X3 => 3, + Len2::X4 => 4, + } + } + /// as u32 + pub const fn as_u32(self) -> u32 { self.as_u8() as u32 } + /// as u64 + pub const fn as_u64(self) -> u64 { self.as_u8() as u64 } + /// as `NonZeroU32` + pub const fn as_non_zero_u32(self) -> NonZeroU32 { self.as_len().as_non_zero_u32() } + /// as [`Len`] + pub const fn as_len(self) -> Len { + match self { + Len2::X2 => Len::X2, + Len2::X3 => Len::X3, + Len2::X4 => Len::X4, + } + } } /// (no documentation yet) @@ -442,6 +480,7 @@ pub struct PackedVector { } /// exhaustive list of all byte sizes a `packed_vec` can have +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum PackedVectorByteSize { _2, _4, @@ -458,6 +497,12 @@ impl From for u8 { } } +impl PackedVectorByteSize { + pub fn 
as_u32(self) -> u32 { u8::from(self) as u32 } + + pub fn as_u64(self) -> u64 { u8::from(self) as u64 } +} + impl Display for PackedVector { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let stype = match self.scalar_type { @@ -485,11 +530,17 @@ impl PackedVector { } } - pub fn align(&self) -> u64 { - match self.byte_size() { - PackedVectorByteSize::_2 => SizedType::Vector(Len::X1, ScalarType::F16).align(), - PackedVectorByteSize::_4 => SizedType::Vector(Len::X1, ScalarType::U32).align(), - PackedVectorByteSize::_8 => SizedType::Vector(Len::X2, ScalarType::U32).align(), + pub fn align(&self, repr: Repr) -> U32PowerOf2 { + match repr { + Repr::Packed => PACKED_ALIGN, + Repr::Wgsl | Repr::WgslUniform => { + let align = match self.byte_size() { + PackedVectorByteSize::_2 => SizedType::Vector(Len::X1, ScalarType::F16).align(), + PackedVectorByteSize::_4 => SizedType::Vector(Len::X1, ScalarType::U32).align(), + PackedVectorByteSize::_8 => SizedType::Vector(Len::X2, ScalarType::U32).align(), + }; + U32PowerOf2::try_from(align as u32).expect("the above all have power of 2 align") + } } } } diff --git a/shame/src/ir/ir_type/ty.rs b/shame/src/ir/ir_type/ty.rs index ead22d6..5647974 100644 --- a/shame/src/ir/ir_type/ty.rs +++ b/shame/src/ir/ir_type/ty.rs @@ -1,6 +1,6 @@ use self::struct_::SizedStruct; use crate::{ - frontend::{any::shared_io, rust_types::type_layout::TypeLayoutRules}, + frontend::{any::shared_io}, ir::recording::MemoryRegion, }; diff --git a/shame/src/ir/pipeline/wip_pipeline.rs b/shame/src/ir/pipeline/wip_pipeline.rs index 4e9e1f0..c694b37 100644 --- a/shame/src/ir/pipeline/wip_pipeline.rs +++ b/shame/src/ir/pipeline/wip_pipeline.rs @@ -10,6 +10,7 @@ use thiserror::Error; use super::{PossibleStages, ShaderStage, StageMask}; use crate::{ + any::layout::Repr, call_info, common::{ integer::post_inc_usize, @@ -32,7 +33,7 @@ use crate::{ error::InternalError, rust_types::{ len::x3, - type_layout::{StructLayout, TypeLayoutRules}, + 
type_layout::{self, recipe, StructLayout}, }, }, ir::{ @@ -43,7 +44,7 @@ use crate::{ StructureFieldNamesMustBeUnique, TextureFormatWrapper, Type, }, results::DepthStencilState, - BindingIter, DepthLhs, StencilMasking, Test, + BindingIter, DepthLhs, StencilMasking, Test, TypeLayout, }; @@ -71,7 +72,7 @@ macro_rules! stringify_checked { #[derive(Error, Debug, Clone)] pub enum PipelineError { - #[error("Missing pipeline specialization. Use either the `{}` or `{}` method to start a compute- or render-pipeline encoding.", + #[error("Missing pipeline specialization. Use either the `{}` or `{}` method to start a compute- or render-pipeline encoding.", stringify_checked!(expr: EncodingGuard::new_compute_pipeline::<3>).replace("3", "_"), stringify_checked!(expr: EncodingGuard::new_render_pipeline), )] @@ -352,17 +353,16 @@ impl WipPushConstantsField { let byte_size = sized_struct.byte_size(); // TODO(release) the `.expect()` calls here can be removed by building a `std::alloc::Layout`-like builder for struct layouts. - let (_, _, layout) = StructLayout::from_ir_struct(TypeLayoutRules::Wgsl, &sized_struct); + let sized_struct: recipe::SizedStruct = sized_struct + .try_into() + .map_err(|e| InternalError::new(true, format!("{e}")))?; + let layout = sized_struct.layout(); let mut ranges = ByteRangesPerStage::default(); for (field, node) in layout.fields.iter().zip(fields.iter().map(|f| f.node)) { let stages = nodes[node].stages.must_appear_in(); - let field_size = field - .field - .ty - .byte_size() - .expect("SizedStruct type enforces Some(size)"); + let field_size = field.ty.byte_size().expect("SizedStruct type enforces Some(size)"); let start = field.rel_byte_offset; let end = start + field_size; @@ -408,11 +408,12 @@ impl WipPushConstantsField { // here we have to allocate unique name strings for each field, // so we don't fail the name uniqueness check, even though we don't need those names. 
- SizedStruct::new_nonempty("PushConstants".into(), + #[allow(clippy::match_single_binding)] + SizedStruct::new_nonempty("PushConstants".into(), fields.iter().map(&mut to_sized_field).collect(), to_sized_field(last) ).map_err(|err| match err { - StructureFieldNamesMustBeUnique => { + StructureFieldNamesMustBeUnique { .. } => { InternalError::new(true, format!("intermediate push constants structure field names are not unique. fields: {fields:?}, last: {last:?}")) } }) diff --git a/shame/src/ir/recording/builtin_templates.rs b/shame/src/ir/recording/builtin_templates.rs index 9eca60b..682d923 100644 --- a/shame/src/ir/recording/builtin_templates.rs +++ b/shame/src/ir/recording/builtin_templates.rs @@ -156,7 +156,7 @@ impl BuiltinTemplateStructs { new_field("exp", SizedType::Vector(len, ir::ScalarType::I32)), ); match struc { - Err(StructureFieldNamesMustBeUnique) => unreachable!("field names above are unique"), + Err(StructureFieldNamesMustBeUnique { .. }) => unreachable!("field names above are unique"), Ok(s) => s, } } @@ -167,7 +167,7 @@ impl BuiltinTemplateStructs { new_field("whole", SizedType::Vector(len, fp.into())), ); match struc { - Err(StructureFieldNamesMustBeUnique) => unreachable!("field names above are unique"), + Err(StructureFieldNamesMustBeUnique { .. }) => unreachable!("field names above are unique"), Ok(s) => s, } } @@ -179,7 +179,7 @@ impl BuiltinTemplateStructs { new_field("exchanged", SizedType::Vector(ir::Len::X1, ir::ScalarType::Bool)), ); match struc { - Err(StructureFieldNamesMustBeUnique) => unreachable!("field names above are unique"), + Err(StructureFieldNamesMustBeUnique { .. 
}) => unreachable!("field names above are unique"), Ok(s) => s, } } diff --git a/shame/src/lib.rs b/shame/src/lib.rs index 9d316df..e6d9325 100644 --- a/shame/src/lib.rs +++ b/shame/src/lib.rs @@ -2,7 +2,11 @@ #![forbid(unsafe_code)] //#![warn(clippy::cast_lossless)] #![deny(missing_docs)] -#![allow(clippy::match_like_matches_macro, clippy::diverging_sub_expression)] +#![allow( + mismatched_lifetime_syntaxes, + clippy::match_like_matches_macro, + clippy::diverging_sub_expression +)] #![allow(unused)] mod backend; @@ -184,7 +188,7 @@ pub use frontend::rust_types::vec_range_traits::VecRangeBoundsInclusive; pub mod aliases { use crate::frontend::rust_types::aliases; - #[rustfmt::skip] + #[rustfmt::skip] pub use aliases::rust_simd::{ f16x1, f32x1, f64x1, u32x1, i32x1, boolx1, f16x2, f32x2, f64x2, u32x2, i32x2, boolx2, @@ -192,16 +196,16 @@ pub mod aliases { f16x4, f32x4, f64x4, u32x4, i32x4, boolx4, }; - #[rustfmt::skip] + #[rustfmt::skip] pub use aliases::rust_simd::{ f16x2x2, f32x2x2, f64x2x2, f16x2x3, f32x2x3, f64x2x3, f16x2x4, f32x2x4, f64x2x4, - + f16x3x2, f32x3x2, f64x3x2, f16x3x3, f32x3x3, f64x3x3, f16x3x4, f32x3x4, f64x3x4, - + f16x4x2, f32x4x2, f64x4x2, f16x4x3, f32x4x3, f64x4x3, f16x4x4, f32x4x4, f64x4x4, @@ -315,10 +319,12 @@ pub use frontend::texture::texture_formats as tf; pub use shame_derive::CpuLayout; pub use shame_derive::GpuLayout; pub use frontend::rust_types::layout_traits::GpuLayout; +pub use frontend::rust_types::layout_traits::gpu_layout; pub use frontend::rust_types::layout_traits::CpuLayout; +pub use frontend::rust_types::layout_traits::cpu_layout; pub use frontend::rust_types::type_layout::TypeLayout; -pub use frontend::rust_types::type_layout::TypeLayoutError; -pub use frontend::rust_types::layout_traits::ArrayElementsUnsizedError; +pub use common::po2::U32PowerOf2; +pub use common::po2::NotAU32PowerOf2; // derived traits pub use frontend::rust_types::type_traits::GpuStore; @@ -331,6 +337,7 @@ pub use 
frontend::rust_types::struct_::SizedFields; pub use frontend::rust_types::type_traits::NoBools; pub use frontend::rust_types::type_traits::NoAtomics; pub use frontend::rust_types::type_traits::NoHandles; +pub use frontend::rust_types::type_traits::VertexAttribute; pub use frontend::rust_types::layout_traits::VertexLayout; @@ -416,6 +423,11 @@ pub mod results { pub type Dict = std::collections::BTreeMap; } +/// everything related to type layouts +pub mod layout { + pub use crate::frontend::rust_types::type_layout::recipe::ScalarType; +} + // #[doc(hidden)] interface starts here // (not part of the public api) @@ -462,6 +474,56 @@ pub mod any { pub use crate::ir::ir_type::StructureDefinitionError; pub use crate::ir::ir_type::StructureFieldNamesMustBeUnique; + pub mod layout { + use crate::frontend::rust_types::type_layout; + + // type layout + pub use type_layout::TypeLayout; + pub use type_layout::Repr; + pub mod repr { + use crate::frontend::rust_types::type_layout; + } + pub use type_layout::VectorLayout; + pub use type_layout::PackedVectorLayout; + pub use type_layout::MatrixLayout; + pub use type_layout::ArrayLayout; + pub use type_layout::StructLayout; + pub use type_layout::FieldLayout; + + // recipe types + pub use type_layout::recipe::TypeLayoutRecipe; + pub use type_layout::recipe::UnsizedStruct; + pub use type_layout::recipe::RuntimeSizedArray; + pub use type_layout::recipe::SizedType; + pub use type_layout::recipe::Vector; + pub use type_layout::recipe::Matrix; + pub use type_layout::recipe::MatrixMajor; + pub use type_layout::recipe::SizedArray; + pub use type_layout::recipe::Atomic; + pub use type_layout::recipe::PackedVector; + pub use type_layout::recipe::SizedStruct; + + // recipe type parts + pub use type_layout::recipe::ScalarType; + pub use type_layout::recipe::ScalarTypeFp; + pub use type_layout::recipe::ScalarTypeInteger; + pub use type_layout::recipe::Len; + pub use type_layout::recipe::Len2; + pub use type_layout::recipe::SizedField; + pub use 
type_layout::recipe::RuntimeSizedArrayField; + pub use type_layout::recipe::CanonName; + pub use type_layout::recipe::SizedOrArray; + pub use type_layout::recipe::FieldOptions; + + // layout calculation utility + pub use type_layout::recipe::StructLayoutCalculator; + pub use type_layout::recipe::FieldOffsets; + + // conversion and builder errors + pub use type_layout::recipe::builder::IsUnsizedStructError; + pub use type_layout::recipe::builder::StructFromPartsError; + } + // runtime binding api pub use any::shared_io::BindPath; pub use any::shared_io::BindingType; diff --git a/shame/tests/test_layout.rs b/shame/tests/test_layout.rs index b137a23..d0b6414 100644 --- a/shame/tests/test_layout.rs +++ b/shame/tests/test_layout.rs @@ -1,7 +1,8 @@ #![allow(non_camel_case_types, unused)] use pretty_assertions::{assert_eq, assert_ne}; +use sm::__private::proc_macro_reexports::CpuAligned; -use shame as sm; +use shame::{self as sm, cpu_layout, gpu_layout}; use sm::{aliases::*, CpuLayout, GpuLayout}; #[test] @@ -21,7 +22,7 @@ fn basic_layout_eq() { c: i32, } - assert_eq!(OnGpu::gpu_layout(), OnCpu::cpu_layout()); + assert_eq!(gpu_layout::(), cpu_layout::()); } #[test] @@ -50,9 +51,9 @@ fn attributes_dont_contribute_to_eq() { c: i32, } - assert_eq!(OnGpuA::gpu_layout(), OnCpu::cpu_layout()); - assert_eq!(OnGpuB::gpu_layout(), OnCpu::cpu_layout()); - assert_eq!(OnGpuA::gpu_layout(), OnGpuB::gpu_layout()); + assert_eq!(gpu_layout::(), cpu_layout::()); + assert_eq!(gpu_layout::(), cpu_layout::()); + assert_eq!(gpu_layout::(), gpu_layout::()); } #[test] @@ -74,7 +75,7 @@ fn fixed_by_align_size_attribute() { c: i32, } - assert_eq!(OnGpu::gpu_layout(), OnCpu::cpu_layout()); + assert_eq!(gpu_layout::(), cpu_layout::()); } { @@ -94,7 +95,88 @@ fn fixed_by_align_size_attribute() { c: f32x3_size32, } - assert_eq!(OnGpu::gpu_layout(), OnCpu::cpu_layout()); + assert_eq!(gpu_layout::(), cpu_layout::()); + } + + { + #[derive(sm::GpuLayout)] + struct OnGpu { + a: f32x1, + #[size(16)] + b: 
f32x3, + c: i32x1, + } + + #[derive(sm::CpuLayout)] + #[repr(C)] + struct OnCpu { + a: f32, + b: f32x3_size16, + c: i32, + } + + assert_eq!(gpu_layout::(), cpu_layout::()); + } + + { + #[derive(sm::GpuLayout)] + struct OnGpu { + a: f32x1, + b: i32x1, + #[size(16)] + c: f32x3, // TODO(release) this should work even without #[size(16)], no? + } + + #[derive(sm::CpuLayout)] + #[repr(C)] + struct OnCpu { + a: f32, + b: i32, + c: f32x3_size16, + } + + assert_eq!(gpu_layout::(), cpu_layout::()); + } + + { + #[derive(sm::GpuLayout)] + struct OnGpu { + a: f32x4, + b: f32x3, // align 16 + c: i32x1, + } + + #[derive(sm::CpuLayout)] + #[repr(C)] + struct OnCpu { + a: f32x4_cpu, + b: f32x3_align4, // de-facto 16 aligned + c: i32, + } + + assert_eq!(gpu_layout::(), cpu_layout::()); + } + + { + // this is the case where rust's idea that `size` must be multiple of `align` + // clashes with wgsl's `vec3f` + + #[derive(sm::GpuLayout)] + struct OnGpu { + a: f32x1, + b: f32x3, + c: i32x1, + } + + #[derive(sm::CpuLayout)] + #[repr(C)] + struct OnCpu { + a: f32, + b: f32x3_cpu, + c: i32, + } + + assert_ne!(gpu_layout::(), cpu_layout::()); } } @@ -115,7 +197,7 @@ fn different_align_struct_eq() { c: i32, } - assert_eq!(OnGpu::gpu_layout(), OnCpu::cpu_layout()); + assert_eq!(gpu_layout::(), cpu_layout::()); } #[test] @@ -135,33 +217,54 @@ fn unsized_struct_layout_eq() { c: [i32], } - assert_eq!(OnGpu::gpu_layout(), OnCpu::cpu_layout()); + assert_eq!(gpu_layout::(), cpu_layout::()); } #[derive(Clone, Copy)] #[repr(C, align(16))] struct f32x4_cpu(pub [f32; 4]); +impl CpuLayout for f32x4_cpu { + fn cpu_layout() -> shame::TypeLayout { + rust_layout_with_shame_semantics::() + } +} + #[derive(Clone, Copy)] #[repr(C, align(16))] struct f32x3_cpu(pub [f32; 3]); impl CpuLayout for f32x3_cpu { - fn cpu_layout() -> shame::TypeLayout { f32x3::gpu_layout() } + fn cpu_layout() -> shame::TypeLayout { + println!("this impl of `CpuLayout` is wrong, do not copy-paste it into your application"); + // This 
impl of `CpuLayout` is wrong. It claims that `Self` has size 12 + // and not size 16, as `std::mem::size_of::()` would return. + // It still represents a way a user might try to replicate wgsl's vec3f. + // At the time of writing it is undecided how we want to deal with this. + // The user can have an align4 and an align16 implementation of Vec3, similar + // to `glam`. The user could then choose depending on the desired packing. + // Atm this does not cause an actual memory bug, because actual offsets and + // sizes are used in cpu-layout checks. + + // TODO(release): decide on the above issue and, depending on decision, remove `f32x3_cpu` entirely + let mut layout = gpu_layout::(); // size 12 + *layout.align_mut() = Self::CPU_ALIGNMENT; // align 16 + layout + } } #[derive(Clone, Copy)] #[repr(C, align(8))] struct f32x2_cpu(pub [f32; 2]); impl CpuLayout for f32x2_cpu { - fn cpu_layout() -> shame::TypeLayout { f32x2::gpu_layout() } + fn cpu_layout() -> shame::TypeLayout { rust_layout_with_shame_semantics::() } } #[derive(Clone, Copy)] #[repr(C)] struct f32x2_align4(pub [f32; 2]); impl CpuLayout for f32x2_align4 { - fn cpu_layout() -> shame::TypeLayout { f32x2::gpu_layout() } + fn cpu_layout() -> shame::TypeLayout { rust_layout_with_shame_semantics::() } } #[derive(Clone, Copy)] @@ -169,7 +272,7 @@ impl CpuLayout for f32x2_align4 { struct f32x4_align4(pub [f32; 4]); impl CpuLayout for f32x4_align4 { - fn cpu_layout() -> shame::TypeLayout { f32x4::gpu_layout() } + fn cpu_layout() -> shame::TypeLayout { rust_layout_with_shame_semantics::() } } #[derive(Clone, Copy)] @@ -182,7 +285,23 @@ static_assertions::assert_eq_align!(glam::Vec3, f32x3_align4); static_assertions::assert_eq_align!(glam::Vec4, f32x4_cpu); impl CpuLayout for f32x3_align4 { - fn cpu_layout() -> shame::TypeLayout { f32x3::gpu_layout() } + fn cpu_layout() -> shame::TypeLayout { + // TODO(release): replace this with `rust_layout_with_shame_semantics::()` + // and find a proper solution to the 
consequences. Its size is 16, and not 12. + let mut layout = gpu_layout::(); + *layout.align_mut() = Self::CPU_ALIGNMENT; + layout + } +} + +#[derive(Clone, Copy)] +#[repr(C, align(16))] +struct f32x3_size16(pub [f32; 3], [u8; 4]); + +impl CpuLayout for f32x3_size16 { + fn cpu_layout() -> shame::TypeLayout { + rust_layout_with_shame_semantics::() + } } #[derive(Clone, Copy)] @@ -190,7 +309,9 @@ impl CpuLayout for f32x3_align4 { struct f32x3_size32(pub [f32; 3], [u8; 20]); impl CpuLayout for f32x3_size32 { - fn cpu_layout() -> shame::TypeLayout { f32x3::gpu_layout() } + fn cpu_layout() -> shame::TypeLayout { + rust_layout_with_shame_semantics::() + } } @@ -212,7 +333,7 @@ fn unsized_struct_vec3_align_layout_eq() { c: [f32x3_cpu], } - assert_eq!(OnGpu::gpu_layout(), OnCpu::cpu_layout()); + assert_eq!(gpu_layout::(), cpu_layout::()); } #[test] @@ -231,7 +352,7 @@ fn unsized_struct_vec3_align_layout_eq() { // the alignment on the top level of the layout doesn't matter. // two layouts are only considered different if an alignment mismatch // leads to different offsets of fields or array elements - assert_eq!(OnGpu::gpu_layout(), OnCpu::cpu_layout()); + assert_eq!(gpu_layout::(), cpu_layout::()); } #[test] @@ -246,11 +367,11 @@ fn unsized_struct_vec3_align_layout_eq() { struct OnCpu { // size=12, align=4 a: f32x3_align4, } - assert_ne!(OnGpu::gpu_layout(), OnCpu::cpu_layout()); - assert!(OnGpu::gpu_layout().byte_size() == Some(16)); - assert!(OnGpu::gpu_layout().align() == 16); - assert!(OnCpu::cpu_layout().byte_size() == Some(12)); - assert!(OnCpu::cpu_layout().align() == 4); + assert_ne!(gpu_layout::(), cpu_layout::()); + assert!(gpu_layout::().byte_size() == Some(16)); + assert!(gpu_layout::().align().as_u32() == 16); + assert!(cpu_layout::().byte_size() == Some(12)); + assert!(cpu_layout::().align().as_u32() == 4); } #[test] @@ -283,15 +404,16 @@ fn unsized_struct_nested_vec3_align_layout_eq() { c: [InnerCpu], } - assert_eq!(OnGpu::gpu_layout(), OnCpu::cpu_layout()); 
+ assert_eq!(gpu_layout::(), cpu_layout::()); } #[test] fn unsized_array_layout_eq() { - assert_eq!(>::gpu_layout(), <[f32]>::cpu_layout()); - assert_eq!(>::gpu_layout(), <[f32x3_cpu]>::cpu_layout()); - assert_ne!(>::gpu_layout(), <[f32x3_align4]>::cpu_layout()); - assert_ne!(>::gpu_layout(), <[f32x3_size32]>::cpu_layout()); + assert_eq!(gpu_layout::>(), cpu_layout::<[f32]>()); + assert_eq!(gpu_layout::>(), cpu_layout::<[f32x3_cpu]>()); + assert_ne!(gpu_layout::>(), cpu_layout::<[f32x3_align4]>()); + assert_ne!(gpu_layout::>(), cpu_layout::<[f32x3_size16]>()); + assert_ne!(gpu_layout::>(), cpu_layout::<[f32x3_size32]>()); } #[test] @@ -318,14 +440,16 @@ fn layouts_mismatch() { c: i32, } - assert_ne!(OnGpuLess::gpu_layout(), OnCpu::cpu_layout()); - assert_ne!(OnGpuMore::gpu_layout(), OnCpu::cpu_layout()); + assert_ne!(gpu_layout::(), cpu_layout::()); + assert_ne!(gpu_layout::(), cpu_layout::()); } #[test] fn external_vec_type() { // using duck-traiting just so that the proc-macro uses `CpuLayoutExt::layout()` pub mod my_mod { + use super::rust_layout_with_shame_semantics; + use shame::gpu_layout; use shame as sm; use sm::aliases::*; use sm::GpuLayout as _; @@ -335,11 +459,11 @@ fn external_vec_type() { } impl CpuLayoutExt for glam::Vec4 { - fn cpu_layout() -> shame::TypeLayout { f32x4::gpu_layout() } + fn cpu_layout() -> shame::TypeLayout { gpu_layout::() } } impl CpuLayoutExt for glam::Vec3 { - fn cpu_layout() -> shame::TypeLayout { f32x3::gpu_layout() } + fn cpu_layout() -> shame::TypeLayout { rust_layout_with_shame_semantics::() } } } @@ -357,7 +481,7 @@ fn external_vec_type() { b: glam::Vec4, } - assert_eq!(OnGpu::gpu_layout(), OnCpu::cpu_layout()); + assert_eq!(gpu_layout::(), cpu_layout::()); #[derive(sm::GpuLayout)] struct OnGpu2 { @@ -367,15 +491,6 @@ fn external_vec_type() { c: f32x4, } - #[derive(sm::GpuLayout)] - #[gpu_repr(packed)] - struct OnGpu2Packed { - a: f32x3, - b: f32x3, - #[align(16)] - c: f32x4, - } - #[derive(sm::CpuLayout)] #[repr(C)] struct 
OnCpu2 { @@ -384,8 +499,20 @@ fn external_vec_type() { c: glam::Vec4, } - assert_ne!(OnGpu2::gpu_layout(), OnCpu2::cpu_layout()); - assert_eq!(OnGpu2Packed::gpu_layout(), OnCpu2::cpu_layout()); + assert_ne!(gpu_layout::(), cpu_layout::()); + + // TODO: delete or use compile fail test crate like trybuild to make + // sure that align and size attributes aren't allowed on packed structs. + // #[derive(sm::GpuLayout)] + // #[gpu_repr(packed)] + // struct OnGpu2Packed { + // a: f32x3, + // b: f32x3, + // #[align(16)] + // c: f32x4, + // } + + // assert_eq!(gpu_layout::(), cpu_layout::()); } #[test] @@ -407,26 +534,83 @@ fn external_vec_type() { uv : f32x2_align4, } - assert_eq!(OnGpu::gpu_layout(), OnCpu::cpu_layout()); + assert_eq!(gpu_layout::(), cpu_layout::()); } { - #[derive(sm::GpuLayout)] - #[gpu_repr(packed)] - struct OnGpu { - pos: f32x3, - nor: f32x3, - #[align(8)] uv : f32x2, - } - - #[derive(sm::CpuLayout)] - #[repr(C)] - struct OnCpu { - pos: f32x3_align4, - nor: f32x3_align4, - uv : f32x2_cpu, - } + // TODO: delete or use compile fail test crate like trybuild to make + // sure that align and size attributes aren't allowed on packed structs. 
+ // #[derive(sm::GpuLayout)] + // #[gpu_repr(packed)] + // struct OnGpu { + // pos: f32x3, + // nor: f32x3, + // #[align(8)] uv : f32x2, + // } + + // #[derive(sm::CpuLayout)] + // #[repr(C)] + // struct OnCpu { + // pos: f32x3_align4, + // nor: f32x3_align4, + // uv : f32x2_cpu, + // } + + // assert_eq!(gpu_layout::(), cpu_layout::()); + // enum __ where OnGpu: sm::VertexLayout {} + } +} - assert_eq!(OnGpu::gpu_layout(), OnCpu::cpu_layout()); - enum __ where OnGpu: sm::VertexLayout {} +#[rustfmt::skip] +#[test] +fn test_set_align_size() { + #[derive(sm::GpuLayout)] + struct OnGpu { + a: f32x1, + b: u32x1, + c: sm::Array>, + } + + let mut layouts = [ + gpu_layout::(), + gpu_layout::(), + gpu_layout::(), + gpu_layout::(), + gpu_layout::>(), + gpu_layout::>(), + ]; + + for (i, lay) in layouts.iter_mut().enumerate() { + let new_align = sm::U32PowerOf2::_128; + assert_ne!(lay.align(), new_align, "#{i}: align of {lay} is already {new_align:?}, change new_align to make the test work"); + *lay.align_mut() = new_align; + assert_eq!(lay.align(), new_align, "#{i}: align of {lay} is not {new_align:?}"); + + let new_size = 128; + assert_ne!(lay.byte_size(), Some(new_size), "#{i}: size of {lay} is already {new_size}, change new_size to make the test work"); + match lay.removable_byte_size_mut() { + Ok(removable) => *removable = Some(new_size), + Err(fixed) => *fixed = new_size, + }; + assert_eq!(lay.byte_size(), Some(new_size), "#{i}: size of {lay} is not {new_size:?}"); + + if let Ok(removable) = lay.removable_byte_size_mut() { + *removable = None; + assert_eq!(lay.byte_size(), None, "#{i}: size of {lay} is not None"); + }; } } + +/// helper for defining a `TypeLayout` for a cpu type that +/// represents a `GpuSemantics` on the Gpu, but has alignment and size of `Layout` on the Cpu +pub fn rust_layout_with_shame_semantics() -> sm::TypeLayout { + let mut layout = sm::gpu_layout::(); + + *layout.align_mut() = CpuType::CPU_ALIGNMENT; + layout.set_byte_size(size_of::() as u64); + 
+ // these are just here because we are testing + assert_eq!(layout.align().as_u32(), align_of::() as u32); + assert_eq!(layout.byte_size().map(|x| x as _), CpuType::CPU_SIZE); + + layout +} diff --git a/shame_derive/src/derive_layout.rs b/shame_derive/src/derive_layout.rs index c9fe351..4e37afd 100644 --- a/shame_derive/src/derive_layout.rs +++ b/shame_derive/src/derive_layout.rs @@ -8,6 +8,7 @@ use syn::LitInt; use syn::{DataStruct, DeriveInput, FieldsNamed}; use crate::util; +use crate::util::Repr; macro_rules! bail { ($span: expr, $display: expr) => {return Err(syn::Error::new($span, $display,))}; @@ -81,12 +82,17 @@ pub fn impl_for_struct( .into_iter(); let none_if_no_cpu_equivalent_type = cpu_attr.is_none().then_some(quote! { None }).into_iter(); - // #[gpu_repr(packed)] - let gpu_repr_packed = util::find_gpu_repr_packed(&input.attrs)?; - if let (Some(span), WhichDerive::CpuLayout) = (&gpu_repr_packed, &which_derive) { + // #[gpu_repr(packed | storage)] + let gpu_repr = util::try_find_gpu_repr(&input.attrs)?; + if let (Some((span, _)), WhichDerive::CpuLayout) = (&gpu_repr, &which_derive) { bail!(*span, "`gpu_repr` attribute is only supported by `derive(GpuLayout)`") } - let is_gpu_repr_packed = gpu_repr_packed.is_some(); + // if no `#[gpu_repr(_)]` attribute was explicitly specified, we default to `Repr::Wgsl` + let gpu_repr = gpu_repr.map(|(_, repr)| repr).unwrap_or(util::Repr::Wgsl); + let gpu_repr_shame = match gpu_repr { + Repr::Packed => quote!( #re::Repr::Packed ), + Repr::Wgsl => quote!( #re::Repr::Wgsl ), + }; // #[repr(...)] let repr_c_attr = util::try_parse_repr(&input.attrs)?; @@ -150,8 +156,20 @@ pub fn impl_for_struct( align: util::find_literal_list_attr::("align", &field.attrs)?, }; + match gpu_repr { + Repr::Packed => { + if fwa.align.is_some() { + bail!( + field.span(), + "`#[gpu_repr(packed)]` structs do not support `#[align(N)]` attributes" + ); + } + } + Repr::Wgsl => {} + } + if let Some((span, align_lit)) = &fwa.align { - match 
align_lit.base10_parse::().map(u32::is_power_of_two) { + match align_lit.base10_parse().map(u32::is_power_of_two) { Ok(true) => (), Ok(false) => bail!(*span, "alignment attribute must be a power of two"), Err(_) => bail!( @@ -197,30 +215,34 @@ pub fn impl_for_struct( let impl_gpu_layout = quote! { impl<#generics_decl> #re::GpuLayout for #derive_struct_ident<#(#idents_of_generics),*> where - #(#first_fields_type: #re::GpuSized,)* - #last_field_type: #re::GpuAligned, + #(#first_fields_type: #re::NoBools + #re::NoHandles + #re::GpuLayout + #re::GpuSized,)* + #last_field_type: #re::NoBools + #re::NoHandles + #re::GpuLayout, #where_clause_predicates { - - fn gpu_layout() -> #re::TypeLayout { - let result = #re::TypeLayout::struct_from_parts( - #re::TypeLayoutRules::Wgsl, - #is_gpu_repr_packed, - std::stringify!(#derive_struct_ident).into(), + fn layout_recipe() -> #re::TypeLayoutRecipe { + let result = #re::TypeLayoutRecipe::struct_from_parts( + std::stringify!(#derive_struct_ident), [ - #( - #re::FieldLayout { - name: std::stringify!(#field_ident).into(), - ty: <#field_type as #re::GpuLayout>::gpu_layout(), - custom_min_align: #field_align.map(|align: u32| TryFrom::try_from(align).expect("power of two validated during codegen")).into(), - custom_min_size: #field_size.into(), - }, - )* - ].into_iter() + #(( + #re::FieldOptions::new( + std::stringify!(#field_ident), + #field_align.map(|align: u32| TryFrom::try_from(align).expect("power of two validated during codegen")).into(), + #field_size.into(), + ), + <#field_type as #re::GpuLayout>::layout_recipe() + ),)* + ], + #gpu_repr_shame, ); + match result { - Ok(layout) => layout, - Err(e @ #re::StructLayoutError::UnsizedFieldMustBeLast { .. }) => unreachable!("all fields except last require bound `GpuSized`. 
{e}"), + Ok(recipe_type) => recipe_type, + Err(#re::StructFromPartsError::MustHaveAtLeastOneField) => unreachable!("checked above"), + Err(#re::StructFromPartsError::OnlyLastFieldMayBeUnsized) => unreachable!("ensured by field trait bounds"), + // GpuType is not implemented for derived structs directly, so they can't be used + // as the field of another struct, instead shame::Struct has to be used, which + // only accepts sized structs. + Err(#re::StructFromPartsError::MustNotHaveUnsizedStructField) => unreachable!("GpuType bound for fields makes this impossible"), } } @@ -243,8 +265,7 @@ pub fn impl_for_struct( where #(#triv #field_type: #re::VertexAttribute,)* #where_clause_predicates - { - } + { } }; @@ -322,232 +343,238 @@ pub fn impl_for_struct( } }; - if gpu_repr_packed.is_some() { + match gpu_repr { + Repr::Packed => // this is basically only for vertex buffers, so // we only implement `GpuLayout` and `VertexLayout`, as well as their implied traits - Ok(quote! { - #impl_gpu_layout - #impl_vertex_buffer_layout - #impl_fake_auto_traits - #impl_from_anys - }) - } else { - // non gpu_repr(packed) - let struct_ref_doc = format!( - r#"This struct was generated by `#[derive(shame::GpuLayout)]` - as a version of `{derive_struct_ident}` which holds references to its fields. It is used as + { + Ok(quote! { + #impl_gpu_layout + #impl_vertex_buffer_layout + #impl_fake_auto_traits + #impl_from_anys + }) + } + Repr::Wgsl => { + // non gpu_repr(packed) + let struct_ref_doc = format!( + r#"This struct was generated by `#[derive(shame::GpuLayout)]` + as a version of `{derive_struct_ident}` which holds references to its fields. It is used as the `std::ops::Deref` target of `shame::Ref<{derive_struct_ident}>`"# - ); - - Ok(quote! 
{ - #impl_gpu_layout - #impl_vertex_buffer_layout - #impl_fake_auto_traits - #impl_from_anys - - #[doc = #struct_ref_doc] - #[allow(non_camel_case_types)] - #[derive(Clone, Copy)] - #vis struct #derive_struct_ref_ident<_AS: #re::AddressSpace, _AM: #re::AccessMode> - where #( - #triv #field_type: #re::GpuStore + #re::GpuType, - )* - { - #( - #field_vis #field_ident: #re::Ref<#field_type, _AS, _AM>, - )* - } + ); - impl<#generics_decl> #re::BufferFields for #derive_struct_ident<#(#idents_of_generics),*> - where - #(#triv #field_type: #re::GpuStore + #re::GpuType,)* - #(#triv #first_fields_type: #re::GpuSized,)* - #triv #last_field_type: #re::GpuAligned, - #where_clause_predicates - { - fn as_anys(&self) -> impl std::borrow::Borrow<[#re::Any]> { - use #re::AsAny; - [ - #(self.#field_ident.as_any()),* - ] + Ok(quote! { + #impl_gpu_layout + #impl_vertex_buffer_layout + #impl_fake_auto_traits + #impl_from_anys + + #[doc = #struct_ref_doc] + #[allow(non_camel_case_types)] + #[derive(Clone, Copy)] + #vis struct #derive_struct_ref_ident<_AS: #re::AddressSpace, _AM: #re::AccessMode> + where #( + #triv #field_type: #re::GpuStore + #re::GpuType, + )* + { + #( + #field_vis #field_ident: #re::Ref<#field_type, _AS, _AM>, + )* } - fn clone_fields(&self) -> Self { - Self { - #(#field_ident: std::clone::Clone::clone(&self.#field_ident)),* + impl<#generics_decl> #re::BufferFields for #derive_struct_ident<#(#idents_of_generics),*> + where + #(#triv #field_type: #re::GpuStore + #re::GpuType,)* + #(#triv #first_fields_type: #re::GpuSized,)* + #triv #last_field_type: #re::GpuAligned, + #where_clause_predicates + { + fn as_anys(&self) -> impl std::borrow::Borrow<[#re::Any]> { + use #re::AsAny; + [ + #(self.#field_ident.as_any()),* + ] } - } - - fn get_bufferblock_type() -> #re::ir::BufferBlock { - let mut fields = std::vec::Vec::from([ - #( - #re::ir::SizedField { - name: std::stringify!(#first_fields_ident).into(), - custom_min_size: #first_fields_size, - custom_min_align: 
#first_fields_align.map(|align: u32| TryFrom::try_from(align).expect("power of two validated during codegen")), - ty: <#first_fields_type as #re::GpuSized>::sized_ty(), - } - ),* - ]); - let mut last_unsized = None::<#re::ir::RuntimeSizedArrayField>; - #[allow(clippy::no_effect)] - { - // this part is only here to force a compiler error if the last field - // uses the #[size(n)] attribute but the type is not shame::GpuSized. - #(#enable_if_last_field_has_size_attribute; // only generate the line below if the last field has a #[size(n)] attribute - // compiler error if not shame::GpuSized - fn __() where #last_field_type: #re::GpuSized {} - )* - - match <#last_field_type as #re::GpuAligned>::aligned_ty() { - #re::ir::AlignedType::Sized(ty) => - fields.push(#re::ir::SizedField { - name: std::stringify!(#last_field_ident).into(), - custom_min_size: #last_field_size, - custom_min_align: #last_field_align.map(|align: u32| TryFrom::try_from(align).expect("power of two validated during codegen")), - ty - }), - #re::ir::AlignedType::RuntimeSizedArray(element_ty) => - last_unsized = Some(#re::ir::RuntimeSizedArrayField { - name: std::stringify!(#last_field_ident).into(), - custom_min_align: #last_field_align.map(|align: u32| TryFrom::try_from(align).expect("power of two validated during codegen")), - element_ty - }), + fn clone_fields(&self) -> Self { + Self { + #(#field_ident: std::clone::Clone::clone(&self.#field_ident)),* } } - use #re::BufferBlockDefinitionError as E; - match #re::ir::BufferBlock::new( - std::stringify!(#derive_struct_ident).into(), - fields, - last_unsized - ) { - Ok(t) => t, - Err(e) => match e { - E::MustHaveAtLeastOneField => unreachable!(">= 1 field is ensured by derive macro"), - E::FieldNamesMustBeUnique => unreachable!("unique field idents are ensured by rust struct definition"), + fn get_bufferblock_type() -> #re::ir::BufferBlock { + let mut fields = std::vec::Vec::from([ + #( + #re::ir::SizedField { + name: 
std::stringify!(#first_fields_ident).into(), + custom_min_size: #first_fields_size, + custom_min_align: #first_fields_align.map(|align: u32| TryFrom::try_from(align).expect("power of two validated during codegen")), + ty: <#first_fields_type as #re::GpuSized>::sized_ty(), + } + ),* + ]); + + let mut last_unsized = None::<#re::ir::RuntimeSizedArrayField>; + #[allow(clippy::no_effect)] + { + // this part is only here to force a compiler error if the last field + // uses the #[size(n)] attribute but the type is not shame::GpuSized. + #(#enable_if_last_field_has_size_attribute; // only generate the line below if the last field has a #[size(n)] attribute + // compiler error if not shame::GpuSized + fn __() where #last_field_type: #re::GpuSized {} + )* + + match <#last_field_type as #re::GpuAligned>::aligned_ty() { + #re::ir::AlignedType::Sized(ty) => + fields.push(#re::ir::SizedField { + name: std::stringify!(#last_field_ident).into(), + custom_min_size: #last_field_size, + custom_min_align: #last_field_align.map(|align: u32| TryFrom::try_from(align).expect("power of two validated during codegen")), + ty + }), + #re::ir::AlignedType::RuntimeSizedArray(element_ty) => + last_unsized = Some(#re::ir::RuntimeSizedArrayField { + name: std::stringify!(#last_field_ident).into(), + custom_min_align: #last_field_align.map(|align: u32| TryFrom::try_from(align).expect("power of two validated during codegen")), + element_ty + }), + } } - } - } - } - impl<#generics_decl> #re::GpuStore for #derive_struct_ident<#(#idents_of_generics),*> - where - #(#triv #field_type: #re::GpuStore + #re::GpuType,)* - #where_clause_predicates - { - type RefFields = #derive_struct_ref_ident; - - fn store_ty() -> #re::ir::StoreType where - #triv Self: #re::GpuType { - unreachable!("Self: !GpuType") + use #re::BufferBlockDefinitionError as E; + match #re::ir::BufferBlock::new( + std::stringify!(#derive_struct_ident).into(), + fields, + last_unsized + ) { + Ok(t) => t, + Err(e) => match e { + 
E::MustHaveAtLeastOneField => unreachable!(">= 1 field is ensured by derive macro"), + E::FieldNamesMustBeUnique(_) => unreachable!("unique field idents are ensured by rust struct definition"), + } + } + } } - fn instantiate_buffer_inner( - args: Result<#re::BindingArgs, #re::InvalidReason>, - bind_ty: #re::BindingType - ) -> #re::BufferInner + impl<#generics_decl> #re::GpuStore for #derive_struct_ident<#(#idents_of_generics),*> where - #triv Self: - #re::NoAtomics + - #re::NoBools + #(#triv #field_type: #re::GpuStore + #re::GpuType,)* + #where_clause_predicates { - #re::BufferInner::new_fields(args, bind_ty) - } + type RefFields = #derive_struct_ref_ident; - fn instantiate_buffer_ref_inner( - args: Result<#re::BindingArgs, #re::InvalidReason>, - bind_ty: #re::BindingType - ) -> #re::BufferRefInner - where - #triv Self: #re::NoBools, - { - #re::BufferRefInner::new_fields(args, bind_ty) - } + fn store_ty() -> #re::ir::StoreType where + #triv Self: #re::GpuType { + unreachable!("Self: !GpuType") + } + + fn instantiate_buffer_inner( + args: Result<#re::BindingArgs, #re::InvalidReason>, + bind_ty: #re::BindingType + ) -> #re::BufferInner + where + #triv Self: + #re::NoAtomics + + #re::NoBools + { + #re::BufferInner::new_fields(args, bind_ty) + } + + fn instantiate_buffer_ref_inner( + args: Result<#re::BindingArgs, #re::InvalidReason>, + bind_ty: #re::BindingType + ) -> #re::BufferRefInner + where + #triv Self: #re::NoBools, + { + #re::BufferRefInner::new_fields(args, bind_ty) + } - fn impl_category() -> #re::GpuStoreImplCategory { - #re::GpuStoreImplCategory::Fields(::get_bufferblock_type()) + fn impl_category() -> #re::GpuStoreImplCategory { + #re::GpuStoreImplCategory::Fields(::get_bufferblock_type()) + } } - } - impl<#generics_decl> #re::SizedFields for #derive_struct_ident<#(#idents_of_generics),*> - where - #(#triv #field_type: #re::GpuSized + #re::GpuStore + #re::GpuType,)* - #where_clause_predicates - { - fn get_sizedstruct_type() -> #re::ir::SizedStruct { - let 
struct_ = #re::ir::SizedStruct::new_nonempty( - std::stringify!(#derive_struct_ident).into(), - std::vec::Vec::from([ - #( - #re::ir::SizedField { - name: std::stringify!(#first_fields_ident).into(), - custom_min_size: #first_fields_size, - custom_min_align: #first_fields_align.map(|align: u32| TryFrom::try_from(align).expect("power of two validated during codegen")), - ty: <#first_fields_type as #re::GpuSized>::sized_ty(), - } - ),* - ]), - #re::ir::SizedField { - name: std::stringify!(#last_field_ident).into(), - custom_min_size: #last_field_size, - custom_min_align: #last_field_align.map(|align: u32| TryFrom::try_from(align).expect("power of two validated during codegen")), - ty: <#last_field_type as #re::GpuSized>::sized_ty(), + impl<#generics_decl> #re::SizedFields for #derive_struct_ident<#(#idents_of_generics),*> + where + #(#triv #field_type: #re::GpuSized + #re::GpuStore + #re::GpuType,)* + #where_clause_predicates + { + fn get_sizedstruct_type() -> #re::ir::SizedStruct { + let struct_ = #re::ir::SizedStruct::new_nonempty( + std::stringify!(#derive_struct_ident).into(), + std::vec::Vec::from([ + #( + #re::ir::SizedField { + name: std::stringify!(#first_fields_ident).into(), + custom_min_size: #first_fields_size, + custom_min_align: #first_fields_align.map(|align: u32| TryFrom::try_from(align).expect("power of two validated during codegen")), + ty: <#first_fields_type as #re::GpuSized>::sized_ty(), + } + ),* + ]), + #re::ir::SizedField { + name: std::stringify!(#last_field_ident).into(), + custom_min_size: #last_field_size, + custom_min_align: #last_field_align.map(|align: u32| TryFrom::try_from(align).expect("power of two validated during codegen")), + ty: <#last_field_type as #re::GpuSized>::sized_ty(), + } + ); + match struct_ { + Ok(s) => s, + Err(#re::ir::StructureFieldNamesMustBeUnique { .. 
}) => unreachable!("field name uniqueness is checked by rust"), } - ); - match struct_ { - Ok(s) => s, - Err(#re::ir::StructureFieldNamesMustBeUnique) => unreachable!("field name uniqueness is checked by rust"), } } - } - impl<#generics_decl> #re::GetAllFields for #derive_struct_ident<#(#idents_of_generics),*> - where - #(#triv #first_fields_type: #re::GpuSized,)* - #triv #last_field_type: #re::GpuAligned, - #where_clause_predicates - { - fn fields_as_anys_unchecked(self_: #re::Any) -> impl std::borrow::Borrow<[#re::Any]> { - [ - #(self_.get_field(std::stringify!(#field_ident).into())),* - ] + impl<#generics_decl> #re::GetAllFields for #derive_struct_ident<#(#idents_of_generics),*> + where + #(#triv #first_fields_type: #re::GpuSized,)* + #triv #last_field_type: #re::GpuAligned, + #where_clause_predicates + { + fn fields_as_anys_unchecked(self_: #re::Any) -> impl std::borrow::Borrow<[#re::Any]> { + [ + #(self_.get_field(std::stringify!(#field_ident).into())),* + ] + } } - } - impl #re::FromAnys for #derive_struct_ref_ident - where #( - #triv #field_type: #re::GpuStore + #re::GpuType, - )* { - fn expected_num_anys() -> usize {#num_fields} - - #[track_caller] - fn from_anys(mut anys: impl Iterator) -> Self { - use #re::{ - collect_into_array_exact, - push_wrong_amount_of_args_error - }; - const EXPECTED_LEN: usize = #num_fields; - let [#(#field_ident),*] = match collect_into_array_exact::<#re::Any, EXPECTED_LEN>(anys) { - Ok(t) => t, - Err(actual_len) => { - let any = push_wrong_amount_of_args_error(actual_len, EXPECTED_LEN, #re::call_info!()); - [any; EXPECTED_LEN] + impl #re::FromAnys for #derive_struct_ref_ident + where #( + #triv #field_type: #re::GpuStore + #re::GpuType, + )* { + fn expected_num_anys() -> usize {#num_fields} + + #[track_caller] + fn from_anys(mut anys: impl Iterator) -> Self { + use #re::{ + collect_into_array_exact, + push_wrong_amount_of_args_error + }; + const EXPECTED_LEN: usize = #num_fields; + let [#(#field_ident),*] = match 
collect_into_array_exact::<#re::Any, EXPECTED_LEN>(anys) { + Ok(t) => t, + Err(actual_len) => { + let any = push_wrong_amount_of_args_error(actual_len, EXPECTED_LEN, #re::call_info!()); + std::array::from_fn(|_| any) + } + }; + Self { + #(#field_ident: From::from(#field_ident)),* + } } } - } - }) + }) + } } } WhichDerive::CpuLayout => { let align_attr_or_none = match repr_align_attr { None => quote!(None), - Some(n) => quote!(Some(#n as u64)), + Some(n) => quote!( + Some(#re::U32PowerOf2::try_from_usize(#n).expect("rust checks that N in repr(C, align(N)) is a power of 2.")) + ), }; Ok(quote! { @@ -558,6 +585,7 @@ pub fn impl_for_struct( #(#field_type: #re::CpuAligned,)* #where_clause_predicates { + #[track_caller] fn cpu_layout() -> #re::TypeLayout { //use #re::CpuLayout // using `use` instead of `as #re::CpuAligned` allows for duck-traits to circumvent the orphan rule use #re::CpuAligned; @@ -573,7 +601,7 @@ pub fn impl_for_struct( layout: <#first_fields_type>::cpu_layout(), // DO NOT refactor to `as #re::CpuLayout`, that would prevent the duck-trait trick for circumventing the orphan rule }, std::mem::offset_of!(#derive_struct_ident, #first_fields_ident), - std::mem::size_of::<#first_fields_type>(), + std::mem::size_of::<#first_fields_type>(), // TODO(release): it is correct that this uses `std::mem::size_of`. At the time of writing, the example implementations of `CpuLayout` for `f32x3_cpu` in the type layout tests are technically wrong, causing the size of <#first_fields_type>::cpu_layout() to disagree with this one. This needs to be addressed!
)),* ], #re::ReprCField { diff --git a/shame_derive/src/util.rs b/shame_derive/src/util.rs index afa6930..eae1c82 100644 --- a/shame_derive/src/util.rs +++ b/shame_derive/src/util.rs @@ -27,16 +27,28 @@ pub fn find_literal_list_attr( Ok(None) } -pub fn find_gpu_repr_packed(attribs: &[syn::Attribute]) -> Result> { +pub enum Repr { + Packed, + Wgsl, +} + +pub fn try_find_gpu_repr(attribs: &[syn::Attribute]) -> Result> { + let mut repr = Repr::Wgsl; for a in attribs { if a.path().is_ident("gpu_repr") { a.parse_nested_meta(|meta| { if meta.path.is_ident("packed") { + repr = Repr::Packed; + return Ok(()); + } else if meta.path.is_ident("wgsl") { + repr = Repr::Wgsl; return Ok(()); } - Err(meta.error("unrecognized `gpu_repr`. Did you mean `gpu_repr(packed)`?")) + + Err(meta.error("unrecognized `gpu_repr`. Did you mean `gpu_repr(packed)` or `gpu_repr(wgsl)`?")) })?; - return Ok(Some(a.span())); + + return Ok(Some((a.span(), repr))); } } Ok(None)