diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index db143edabb..ff2be31b99 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -858,18 +858,14 @@ impl Writer { let mut current_offset = 0; let mut member_ids = Vec::with_capacity(members.len()); for (index, member) in members.iter().enumerate() { - let layout = self.layouter.resolve(member.ty); - current_offset += layout.pad(current_offset); + let (placement, _) = self.layouter.member_placement(current_offset, member); self.annotations.push(Instruction::member_decorate( id, index as u32, spirv::Decoration::Offset, - &[current_offset], + &[placement.start], )); - current_offset += match member.span { - Some(span) => span.get(), - None => layout.size, - }; + current_offset = placement.end; if self.flags.contains(WriterFlags::DEBUG) { if let Some(ref name) = member.name { diff --git a/src/front/glsl/parser.rs b/src/front/glsl/parser.rs index 5efd1076e1..4255ec4955 100644 --- a/src/front/glsl/parser.rs +++ b/src/front/glsl/parser.rs @@ -725,9 +725,12 @@ pomelo! { if let Some(ty) = t { sdl.iter().map(|name| StructMember { name: Some(name.clone()), - span: None, ty, binding: None, //TODO + //TODO: if the struct is a uniform struct, these values have to reflect + // std140 layout. Otherwise, std430. + size: None, + align: None, }).collect() } else { return Err(ErrorKind::SemanticError("Struct member can't be void".into())) diff --git a/src/front/spv/function.rs b/src/front/spv/function.rs index ed804e59be..8bd683a1b6 100644 --- a/src/front/spv/function.rs +++ b/src/front/spv/function.rs @@ -237,9 +237,10 @@ impl> super::Parser { if let super::Variable::Output(ref result) = lvar.inner { members.push(crate::StructMember { name: None, - span: None, ty: result.ty, binding: result.binding.clone(), + size: None, + align: None, }); // populate just the globals first, then do `Load` in a // separate step, so that we can get a range. diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index 318b7f5131..9e35bd5f87 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -2369,9 +2369,10 @@ impl> Parser { host_shared |= decor.offset.is_some(); members.push(crate::StructMember { name: decor.name, - span: None, //TODO ty, binding: None, + size: None, //TODO + align: None, }); } diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 182b16e7bb..7586dc1aff 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -107,6 +107,8 @@ pub enum Error<'a> { UnknownConservativeDepth(&'a str), #[error("array stride must not be 0")] ZeroStride, + #[error("struct member size or array must not be 0")] + ZeroSizeOrAlign, #[error("not a composite type: {0:?}")] NotCompositeType(Handle), #[error("Input/output binding is not consistent: location {0:?}, built-in {1:?} and interpolation {2:?}")] @@ -1502,7 +1504,7 @@ impl Parser { let mut members = Vec::new(); lexer.expect(Token::Paren('{'))?; loop { - let mut span = 0; + let (mut size, mut align) = (None, None); let mut bind_parser = BindingParser::default(); if lexer.skip(Token::DoubleParen('[')) { self.scopes.push(Scope::Decoration); @@ -1517,17 +1519,19 @@ impl Parser { } (Token::Word(word), _) if ready => { match word { - "span" => { + "size" => { lexer.expect(Token::Paren('('))?; - //Note: 0 is not handled - span = lexer.next_uint_literal()?; + let value = lexer.next_uint_literal()?; lexer.expect(Token::Paren(')'))?; + size = + Some(NonZeroU32::new(value).ok_or(Error::ZeroSizeOrAlign)?); } - "offset" => { - // skip - only here for parsing compatibility + "align" => { lexer.expect(Token::Paren('('))?; - let _offset = lexer.next_uint_literal()?; + let value = lexer.next_uint_literal()?; lexer.expect(Token::Paren(')'))?; + align = + Some(NonZeroU32::new(value).ok_or(Error::ZeroSizeOrAlign)?); } _ => bind_parser.parse(lexer, word)?, } @@ -1550,9 +1554,10 @@ impl Parser { members.push(crate::StructMember { name: Some(name.to_owned()), - span: NonZeroU32::new(span), ty, binding: bind_parser.finish()?, + size, + align, }); } } diff --git a/src/front/wgsl/tests.rs b/src/front/wgsl/tests.rs index 847e3dea7e..fbad55201c 100644 --- a/src/front/wgsl/tests.rs +++ b/src/front/wgsl/tests.rs @@ -70,7 +70,11 @@ fn parse_struct() { parse_str( " [[block]] struct Foo { x: i32; }; - struct Bar { [[span(16)]] x: vec2; }; + struct Bar { + [[size(16)]] x: vec2; + [[align(16)]] y: f32; + [[size(32), align(8)]] z: vec3; + }; struct Empty {}; var s: [[access(read_write)]] Foo; ", diff --git a/src/lib.rs b/src/lib.rs index ebca7842a4..52de57311f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -238,10 +238,14 @@ pub enum Interpolation { #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub struct StructMember { pub name: Option, - pub span: Option, + /// Type of the field. pub ty: Handle, /// For I/O structs, defines the binding. pub binding: Option, + /// Overrides the size computed off the type. + pub size: Option, + /// Overrides the alignment computed off the type. + pub align: Option, } /// The number of dimensions an image has. diff --git a/src/proc/layouter.rs b/src/proc/layouter.rs index 8512e96d52..48056d3103 100644 --- a/src/proc/layouter.rs +++ b/src/proc/layouter.rs @@ -1,5 +1,5 @@ use crate::arena::Arena; -use std::num::NonZeroU32; +use std::{num::NonZeroU32, ops}; pub type Alignment = NonZeroU32; @@ -10,17 +10,9 @@ pub struct TypeLayout { pub alignment: Alignment, } -impl TypeLayout { - /// Return padding to this type given an offset. - pub fn pad(&self, offset: u32) -> u32 { - match offset & self.alignment.get() { - 0 => 0, - other => self.alignment.get() - other, - } - } -} - /// Helper processor that derives the sizes of all types. +/// It uses the default layout algorithm/table, described in +/// https://github.com/gpuweb/gpuweb/issues/1393 #[derive(Debug, Default)] pub struct Layouter { layouts: Vec, @@ -33,6 +25,29 @@ impl Layouter { this } + pub fn round_up(alignment: NonZeroU32, offset: u32) -> u32 { + match offset & alignment.get() { + 0 => offset, + other => offset + alignment.get() - other, + } + } + + pub fn member_placement( + &self, + offset: u32, + member: &crate::StructMember, + ) -> (ops::Range, NonZeroU32) { + let layout = self.layouts[member.ty.index()]; + let alignment = member.align.unwrap_or(layout.alignment); + let start = Self::round_up(alignment, offset); + let end = start + + match member.size { + Some(size) => size.get(), + None => layout.size, + }; + (start..end, alignment) + } + pub fn initialize(&mut self, types: &Arena, constants: &Arena) { use crate::TypeInner as Ti; @@ -51,8 +66,10 @@ impl Layouter { width, } => TypeLayout { size: (size as u8 * width) as u32, - //TODO: reconsider if this needs to match the size - alignment: Alignment::new(width as u32).unwrap(), + alignment: { + let count = if size >= crate::VectorSize::Tri { 4 } else { 2 }; + Alignment::new((count * width) as u32).unwrap() + }, }, Ti::Matrix { columns, @@ -60,7 +77,10 @@ impl Layouter { width, } => TypeLayout { size: (columns as u8 * rows as u8 * width) as u32, - alignment: Alignment::new((columns as u8 * width) as u32).unwrap(), + alignment: { + let count = if rows >= crate::VectorSize::Tri { 4 } else { 2 }; + Alignment::new((count * width) as u32).unwrap() + }, }, Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout { size: 4, @@ -83,11 +103,15 @@ impl Layouter { } => value as u32, ref other => unreachable!("Unexpected array size {:?}", other), }, - crate::ArraySize::Dynamic => 1, + crate::ArraySize::Dynamic => 0, }; let stride = match stride { Some(value) => value, - None => Alignment::new(self.layouts[base.index()].size.max(1)).unwrap(), + None => { + let layout = &self.layouts[base.index()]; + let stride = Self::round_up(layout.alignment, layout.size); + Alignment::new(stride).unwrap() + } }; TypeLayout { size: count * stride.get(), @@ -101,18 +125,12 @@ impl Layouter { let mut total = 0; let mut biggest_alignment = Alignment::new(1).unwrap(); for member in members { - let member_layout = self.layouts[member.ty.index()]; - biggest_alignment = biggest_alignment.max(member_layout.alignment); - // align up first - total += member_layout.pad(total); - // then add the size - total += match member.span { - Some(span) => span.get(), - None => member_layout.size, - }; + let (placement, alignment) = self.member_placement(total, member); + biggest_alignment = biggest_alignment.max(alignment); + total = placement.end; } TypeLayout { - size: total, + size: Self::round_up(biggest_alignment, total), alignment: biggest_alignment, } } diff --git a/src/proc/validator.rs b/src/proc/validator.rs index 90fad50a28..2d198237ab 100644 --- a/src/proc/validator.rs +++ b/src/proc/validator.rs @@ -555,7 +555,11 @@ impl Validator { } TypeFlags::SIZED //TODO: `DATA`? } - Ti::Array { base, size, stride } => { + Ti::Array { + base, + size, + stride: _, + } => { if base >= handle { return Err(TypeError::UnresolvedBase(base)); } @@ -596,11 +600,7 @@ impl Validator { } crate::ArraySize::Dynamic => TypeFlags::empty(), }; - let base_mask = if stride.is_none() { - TypeFlags::INTERFACE - } else { - TypeFlags::HOST_SHARED | TypeFlags::INTERFACE - }; + let base_mask = TypeFlags::HOST_SHARED | TypeFlags::INTERFACE; TypeFlags::DATA | (base_flags & base_mask) | sized_flag } Ti::Struct { block, ref members } => { diff --git a/tests/out/collatz.ron.snap b/tests/out/collatz.ron.snap index 5b3a5eca4a..e4e9224c72 100644 --- a/tests/out/collatz.ron.snap +++ b/tests/out/collatz.ron.snap @@ -26,9 +26,10 @@ expression: output members: [ ( name: Some("data"), - span: None, ty: 2, binding: None, + size: None, + align: None, ), ], ), diff --git a/tests/out/shadow.ron.snap b/tests/out/shadow.ron.snap index e1276ca7bf..cda0423ffb 100644 --- a/tests/out/shadow.ron.snap +++ b/tests/out/shadow.ron.snap @@ -110,9 +110,10 @@ expression: output members: [ ( name: Some("num_lights"), - span: None, ty: 13, binding: None, + size: None, + align: None, ), ], ), @@ -153,21 +154,24 @@ expression: output members: [ ( name: Some("proj"), - span: None, ty: 18, binding: None, + size: None, + align: None, ), ( name: Some("pos"), - span: None, ty: 4, binding: None, + size: None, + align: None, ), ( name: Some("color"), - span: None, ty: 4, binding: None, + size: None, + align: None, ), ], ), @@ -187,9 +191,10 @@ expression: output members: [ ( name: Some("data"), - span: None, ty: 20, binding: None, + size: None, + align: None, ), ], ),