diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 8dd28e0476..8c4f24578c 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -34,7 +34,6 @@ use std::{ borrow::Cow, convert::TryFrom, io::{self, Write}, - num::NonZeroU32, ops, }; use thiserror::Error; @@ -140,7 +139,9 @@ pub enum Error<'a> { UnknownType(Span), UnknownStorageFormat(Span), UnknownConservativeDepth(Span), - ZeroSizeOrAlign(Span), + SizeAttributeTooLow(Span, u32), + AlignAttributeTooLow(Span, Alignment), + NonPowerOfTwoAlignAttribute(Span), InconsistentBinding(Span), UnknownLocalFunction(Span), TypeNotConstructible(Span), @@ -366,9 +367,19 @@ impl<'a> Error<'a> { labels: vec![(bad_span.clone(), "unknown type".into())], notes: vec![], }, - Error::ZeroSizeOrAlign(ref bad_span) => ParseError { - message: "struct member size or alignment must not be 0".to_string(), - labels: vec![(bad_span.clone(), "struct member size or alignment must not be 0".into())], + Error::SizeAttributeTooLow(ref bad_span, min_size) => ParseError { + message: format!("struct member size must be at least {}", min_size), + labels: vec![(bad_span.clone(), format!("must be at least {}", min_size).into())], + notes: vec![], + }, + Error::AlignAttributeTooLow(ref bad_span, min_align) => ParseError { + message: format!("struct member alignment must be at least {}", min_align), + labels: vec![(bad_span.clone(), format!("must be at least {}", min_align).into())], + notes: vec![], + }, + Error::NonPowerOfTwoAlignAttribute(ref bad_span) => ParseError { + message: "struct member alignment must be a power of 2".to_string(), + labels: vec![(bad_span.clone(), "must be a power of 2".into())], notes: vec![], }, Error::InconsistentBinding(ref span) => ParseError { @@ -2787,7 +2798,7 @@ impl Parser { const_arena: &mut Arena, ) -> Result<(Vec, u32), Error<'a>> { let mut offset = 0; - let mut alignment = Alignment::new(1).unwrap(); + let mut struct_alignment = Alignment::ONE; let mut members = Vec::new(); lexer.expect(Token::Paren('{'))?; @@ -2799,7 +2810,7 @@ impl Parser { ExpectedToken::Token(Token::Separator(',')), )); } - let (mut size, mut align) = (None, None); + let (mut size_attr, mut align_attr) = (None, None); self.push_scope(Scope::Attribute, lexer); let mut bind_parser = BindingParser::default(); while lexer.skip(Token::Attribute) { @@ -2809,20 +2820,22 @@ impl Parser { let (value, span) = lexer.capture_span(Self::parse_non_negative_i32_literal)?; lexer.expect(Token::Paren(')'))?; - size = Some(NonZeroU32::new(value).ok_or(Error::ZeroSizeOrAlign(span))?); + size_attr = Some((value, span)); } ("align", _) => { lexer.expect(Token::Paren('('))?; let (value, span) = lexer.capture_span(Self::parse_non_negative_i32_literal)?; lexer.expect(Token::Paren(')'))?; - align = Some(Alignment::new(value).ok_or(Error::ZeroSizeOrAlign(span))?); + align_attr = Some((value, span)); } (word, word_span) => bind_parser.parse(lexer, word, word_span)?, } } let bind_span = self.pop_scope(lexer); + let mut binding = bind_parser.finish(bind_span)?; + let (name, span) = match lexer.next() { (Token::Word(word), span) => (word, span), other => return Err(Error::Unexpected(other, ExpectedToken::FieldName)), @@ -2831,29 +2844,57 @@ impl Parser { return Err(Error::ReservedKeyword(span)); } lexer.expect(Token::Separator(':'))?; - let (ty, _access) = self.parse_type_decl(lexer, None, type_arena, const_arena)?; + let (ty, _) = self.parse_type_decl(lexer, None, type_arena, const_arena)?; ready = lexer.skip(Token::Separator(',')); self.layouter.update(type_arena, const_arena).unwrap(); - let (range, align) = self.layouter.member_placement(offset, ty, align, size); - alignment = alignment.max(align); - offset = range.end; + let member_min_size = self.layouter[ty].size; + let member_min_alignment = self.layouter[ty].alignment; + + let member_size = if let Some((size, span)) = size_attr { + if size < member_min_size { + return Err(Error::SizeAttributeTooLow(span, member_min_size)); + } else { + size + } + } else { + member_min_size + }; + + let member_alignment = if let Some((align, span)) = align_attr { + if let Some(alignment) = Alignment::new(align) { + if alignment < member_min_alignment { + return Err(Error::AlignAttributeTooLow(span, member_min_alignment)); + } else { + alignment + } + } else { + return Err(Error::NonPowerOfTwoAlignAttribute(span)); + } + } else { + member_min_alignment + }; + + offset = member_alignment.round_up(offset); + struct_alignment = struct_alignment.max(member_alignment); - let mut binding = bind_parser.finish(bind_span)?; if let Some(ref mut binding) = binding { binding.apply_default_interpolation(&type_arena[ty].inner); } + members.push(crate::StructMember { name: Some(name.to_owned()), ty, binding, - offset: range.start, + offset, }); + + offset += member_size; } - let span = alignment.round_up(offset); - Ok((members, span)) + let struct_size = struct_alignment.round_up(offset); + Ok((members, struct_size)) } fn parse_matrix_scalar_type<'a>( diff --git a/src/front/wgsl/tests.rs b/src/front/wgsl/tests.rs index 8c5604ecc4..33fc541acb 100644 --- a/src/front/wgsl/tests.rs +++ b/src/front/wgsl/tests.rs @@ -91,7 +91,7 @@ fn parse_struct() { struct Bar { @size(16) x: vec2, @align(16) y: f32, - @size(32) @align(8) z: vec3, + @size(32) @align(128) z: vec3, }; struct Empty {} var s: Foo; diff --git a/src/proc/layouter.rs b/src/proc/layouter.rs index 4bfb7ec052..562c681f5d 100644 --- a/src/proc/layouter.rs +++ b/src/proc/layouter.rs @@ -151,32 +151,6 @@ impl Layouter { self.layouts.clear(); } - /// Return the offset and span of a struct member. - /// - /// The member must fall at or after `offset`. The member's alignment and - /// size are `align` and `size` if given, defaulting to the values this - /// `Layouter` has previously determined for `ty`. - /// - /// The return value is the range of offsets within the containing struct to - /// reserve for this member, along with the alignment used. The containing - /// struct must have sufficient space and alignment to accommodate these. - pub fn member_placement( - &self, - offset: u32, - ty: Handle, - align: Option, - size: Option, - ) -> (ops::Range, Alignment) { - let layout = self.layouts[ty.index()]; - let alignment = align.unwrap_or(layout.alignment); - let start = alignment.round_up(offset); - let span = match size { - Some(size) => size.get(), - None => layout.size, - }; - (start..start + span, alignment) - } - /// Extend this `Layouter` with layouts for any new entries in `types`. /// /// Ensure that every type in `types` has a corresponding [TypeLayout] in diff --git a/tests/wgsl-errors.rs b/tests/wgsl-errors.rs index 74f9b621f0..564958f814 100644 --- a/tests/wgsl-errors.rs +++ b/tests/wgsl-errors.rs @@ -428,36 +428,54 @@ fn unknown_conservative_depth() { } #[test] -fn struct_member_zero_size() { +fn struct_member_size_too_low() { check( r#" struct Bar { @size(0) data: array } "#, - r#"error: struct member size or alignment must not be 0 + r#"error: struct member size must be at least 4 ┌─ wgsl:3:23 │ 3 │ @size(0) data: array - │ ^ struct member size or alignment must not be 0 + │ ^ must be at least 4 "#, ); } #[test] -fn struct_member_zero_align() { +fn struct_member_align_too_low() { check( r#" struct Bar { - @align(0) data: array + @align(8) data: vec3 } "#, - r#"error: struct member size or alignment must not be 0 + r#"error: struct member alignment must be at least 16 ┌─ wgsl:3:24 │ -3 │ @align(0) data: array - │ ^ struct member size or alignment must not be 0 +3 │ @align(8) data: vec3 + │ ^ must be at least 16 + +"#, + ); +} + +#[test] +fn struct_member_non_po2_align() { + check( + r#" + struct Bar { + @align(7) data: array + } + "#, + r#"error: struct member alignment must be a power of 2 + ┌─ wgsl:3:24 + │ +3 │ @align(7) data: array + │ ^ must be a power of 2 "#, );