diff --git a/src/front/wgsl/lexer.rs b/src/front/wgsl/lexer.rs index fed5214c78..fb68a47826 100644 --- a/src/front/wgsl/lexer.rs +++ b/src/front/wgsl/lexer.rs @@ -13,7 +13,8 @@ fn consume_any(input: &str, what: impl Fn(char) -> bool) -> (&str, &str) { input.split_at(pos) } -fn consume_number(input: &str) -> (&str, &str) { +fn consume_number(input: &str) -> (Token, &str) { + //Note: I wish this function was simpler and faster... let mut is_first_char = true; let mut right_after_exponent = false; @@ -32,7 +33,28 @@ fn consume_number(input: &str) -> (&str, &str) { } }; let pos = input.find(|c| !what(c)).unwrap_or_else(|| input.len()); - input.split_at(pos) + let (value, rest) = input.split_at(pos); + + let mut rest_iter = rest.chars(); + let ty = rest_iter.next().unwrap_or(' '); + match ty { + 'u' | 'i' | 'f' => { + let width_end = rest_iter + .position(|c| !('0'..='9').contains(&c)) + .unwrap_or_else(|| rest.len() - 1); + let (width, rest) = rest[1..].split_at(width_end); + (Token::Number { value, ty, width }, rest) + } + // default to `i32` or `f32` + _ => ( + Token::Number { + value, + ty: if value.contains('.') { 'f' } else { 'i' }, + width: "", + }, + rest, + ), + } } fn consume_token(mut input: &str) -> (Token<'_>, &str) { @@ -56,10 +78,7 @@ fn consume_token(mut input: &str) -> (Token<'_>, &str) { '.' => { let og_chars = chars.as_str(); match chars.next() { - Some('0'..='9') => { - let (number, rest) = consume_number(input); - (Token::Number(number), rest) - } + Some('0'..='9') => consume_number(input), _ => (Token::Separator(cur), og_chars), } } @@ -83,10 +102,7 @@ fn consume_token(mut input: &str) -> (Token<'_>, &str) { (Token::Paren(cur), input) } } - '0'..='9' => { - let (number, rest) = consume_number(input); - (Token::Number(number), rest) - } + '0'..='9' => consume_number(input), 'a'..='z' | 'A'..='Z' | '_' => { let (word, rest) = consume_any(input, |c| c.is_ascii_alphanumeric() || c == '_'); (Token::Word(word), rest) @@ -115,10 +131,7 @@ fn consume_token(mut input: &str) -> (Token<'_>, &str) { let og_chars = chars.as_str(); match chars.next() { Some('>') => (Token::Arrow, chars.as_str()), - Some('0'..='9') | Some('.') => { - let (number, rest) = consume_number(input); - (Token::Number(number), rest) - } + Some('0'..='9') | Some('.') => consume_number(input), _ => (Token::Operation(cur), og_chars), } } @@ -193,21 +206,25 @@ impl<'a> Lexer<'a> { fn _next_float_literal(&mut self) -> Result> { match self.next() { - Token::Number(word) => word.parse().map_err(|err| Error::BadFloat(word, err)), + Token::Number { value, .. } => value.parse().map_err(|err| Error::BadFloat(value, err)), other => other.unexpected("float literal"), } } pub(super) fn next_uint_literal(&mut self) -> Result> { match self.next() { - Token::Number(word) => word.parse().map_err(|err| Error::BadInteger(word, err)), + Token::Number { value, .. } => { + value.parse().map_err(|err| Error::BadInteger(value, err)) + } other => other.unexpected("uint literal"), } } pub(super) fn next_sint_literal(&mut self) -> Result> { match self.next() { - Token::Number(word) => word.parse().map_err(|err| Error::BadInteger(word, err)), + Token::Number { value, .. } => { + value.parse().map_err(|err| Error::BadInteger(value, err)) + } other => other.unexpected("sint literal"), } } @@ -246,7 +263,39 @@ fn sub_test(source: &str, expected_tokens: &[Token]) { #[test] fn test_tokens() { sub_test("id123_OK", &[Token::Word("id123_OK")]); - sub_test("92No", &[Token::Number("92"), Token::Word("No")]); + sub_test( + "92No", + &[ + Token::Number { + value: "92", + ty: 'i', + width: "", + }, + Token::Word("No"), + ], + ); + sub_test( + "2u3o", + &[ + Token::Number { + value: "2", + ty: 'u', + width: "3", + }, + Token::Word("o"), + ], + ); + sub_test( + "2.4f44po", + &[ + Token::Number { + value: "2.4", + ty: 'f', + width: "44", + }, + Token::Word("po"), + ], + ); sub_test( "æNoø", &[Token::Unknown('æ'), Token::Word("No"), Token::Unknown('ø')], @@ -264,7 +313,11 @@ fn test_variable_decl() { Token::DoubleParen('['), Token::Word("group"), Token::Paren('('), - Token::Number("0"), + Token::Number { + value: "0", + ty: 'i', + width: "", + }, Token::Paren(')'), Token::DoubleParen(']'), Token::Word("var"), diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 578e946969..fbff173b95 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -23,7 +23,11 @@ pub enum Token<'a> { DoubleColon, Paren(char), DoubleParen(char), - Number(&'a str), + Number { + value: &'a str, + ty: char, + width: &'a str, + }, String(&'a str), Word(&'a str), Operation(char), @@ -50,6 +54,8 @@ pub enum Error<'a> { BadInteger(&'a str, std::num::ParseIntError), #[error("unable to parse `{1}` as float: {1}")] BadFloat(&'a str, std::num::ParseFloatError), + #[error("unable to parse `{0}{1}{2}` as scalar width: {3}")] + BadScalarWidth(&'a str, char, &'a str, std::num::ParseIntError), #[error("bad field accessor `{0}`")] BadAccessor(&'a str), #[error("bad texture {0}`")] @@ -369,16 +375,37 @@ impl Parser { } } - fn get_scalar_value(word: &str) -> Result> { - if word.contains('.') { - word.parse() - .map(crate::ScalarValue::Float) - .map_err(|err| Error::BadFloat(word, err)) - } else { - word.parse() + fn get_constant_inner<'a>( + word: &'a str, + ty: char, + width: &'a str, + ) -> Result> { + let value = match ty { + 'i' => word + .parse() .map(crate::ScalarValue::Sint) - .map_err(|err| Error::BadInteger(word, err)) - } + .map_err(|err| Error::BadInteger(word, err))?, + 'u' => word + .parse() + .map(crate::ScalarValue::Uint) + .map_err(|err| Error::BadInteger(word, err))?, + 'f' => word + .parse() + .map(crate::ScalarValue::Float) + .map_err(|err| Error::BadFloat(word, err))?, + _ => unreachable!(), + }; + Ok(crate::ConstantInner::Scalar { + value, + width: if width.is_empty() { + 4 + } else { + match width.parse::() { + Ok(bits) => bits / 8, + Err(e) => return Err(Error::BadScalarWidth(word, ty, width, e)), + } + }, + }) } fn parse_function_call_inner<'a>( @@ -698,12 +725,9 @@ impl Parser { value: crate::ScalarValue::Bool(false), } } - Token::Number(word) => { + Token::Number { value, ty, width } => { let _ = lexer.next(); - crate::ConstantInner::Scalar { - width: 4, - value: Self::get_scalar_value(word)?, - } + Self::get_constant_inner(value, ty, width)? } _ => { let (composite_ty, _access) = @@ -773,12 +797,12 @@ impl Parser { }); crate::Expression::Constant(handle) } - Token::Number(word) => { - let value = Self::get_scalar_value(word)?; + Token::Number { value, ty, width } => { + let inner = Self::get_constant_inner(value, ty, width)?; let handle = ctx.constants.fetch_or_append(crate::Constant { name: None, specialization: None, - inner: crate::ConstantInner::Scalar { width: 4, value }, + inner, }); crate::Expression::Constant(handle) } diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 4732b19cee..9016c3be00 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -155,30 +155,30 @@ fn convert_wgsl(name: &str, language: Language) { #[cfg(feature = "wgsl-in")] #[test] -fn converts_wgsl_quad() { +fn convert_wgsl_quad() { convert_wgsl("quad", Language::all()); } #[cfg(feature = "wgsl-in")] #[test] -fn converts_wgsl_simple() { - convert_wgsl("simple", Language::all()); +fn convert_wgsl_empty() { + convert_wgsl("empty", Language::all()); } #[cfg(feature = "wgsl-in")] #[test] -fn converts_wgsl_boids() { +fn convert_wgsl_boids() { convert_wgsl("boids", Language::METAL); } #[cfg(feature = "wgsl-in")] #[test] -fn converts_wgsl_skybox() { +fn convert_wgsl_skybox() { convert_wgsl("skybox", Language::all()); } #[cfg(feature = "wgsl-in")] #[test] -fn converts_wgsl_collatz() { +fn convert_wgsl_collatz() { convert_wgsl("collatz", Language::METAL); } diff --git a/tests/snapshots/in/boids.wgsl b/tests/snapshots/in/boids.wgsl index 2e8b074259..b75b671a01 100644 --- a/tests/snapshots/in/boids.wgsl +++ b/tests/snapshots/in/boids.wgsl @@ -69,7 +69,7 @@ struct Particles { [[stage(compute), workgroup_size(1)]] fn main() { const index : u32 = gl_GlobalInvocationID.x; - if (index >= u32(5)) { + if (index >= 5u) { return; } @@ -84,9 +84,9 @@ fn main() { var pos : vec2; var vel : vec2; - var i : u32 = u32(0); + var i : u32 = 0u; loop { - if (i >= u32(5)) { + if (i >= 5u) { break; } if (i == index) { @@ -109,7 +109,7 @@ fn main() { } continuing { - i = i + u32(1); + i = i + 1u; } } if (cMassCount > 0) { diff --git a/tests/snapshots/in/collatz.wgsl b/tests/snapshots/in/collatz.wgsl index a0853e0e43..5c97310267 100644 --- a/tests/snapshots/in/collatz.wgsl +++ b/tests/snapshots/in/collatz.wgsl @@ -15,18 +15,18 @@ var v_indices: [[access(read_write)]] PrimeIndices; // Though the conjecture has not been proven, no counterexample has ever been found. // This function returns how many times this recurrence needs to be applied to reach 1. fn collatz_iterations(n: u32) -> u32{ - var i: u32 = u32(0); + var i: u32 = 0u; loop { - if (n <= u32(1)) { + if (n <= 1u) { break; } - if (n % u32(2) == u32(0)) { - n = n / u32(2); + if (n % 2u == 0u) { + n = n / 2u; } else { - n = u32(3) * n + i32(1); + n = 3u * n + 1u; } - i = i + u32(1); + i = i + 1u; } return i; } diff --git a/tests/snapshots/in/simple.param.ron b/tests/snapshots/in/empty.param.ron similarity index 100% rename from tests/snapshots/in/simple.param.ron rename to tests/snapshots/in/empty.param.ron diff --git a/tests/snapshots/in/empty.wgsl b/tests/snapshots/in/empty.wgsl new file mode 100644 index 0000000000..9bd04d80cf --- /dev/null +++ b/tests/snapshots/in/empty.wgsl @@ -0,0 +1,2 @@ +[[stage(compute), workgroup_size(1)]] +fn main() {} diff --git a/tests/snapshots/in/simple.wgsl b/tests/snapshots/in/simple.wgsl deleted file mode 100644 index 54d4923569..0000000000 --- a/tests/snapshots/in/simple.wgsl +++ /dev/null @@ -1,7 +0,0 @@ -// vertex -[[builtin(position)]] var o_position : vec4; - -[[stage(vertex)]] -fn main() { - o_position = vec4(1); -} diff --git a/tests/snapshots/snapshots__boids.msl.snap b/tests/snapshots/snapshots__boids.msl.snap index ca9665c13c..13a3f32fd8 100644 --- a/tests/snapshots/snapshots__boids.msl.snap +++ b/tests/snapshots/snapshots__boids.msl.snap @@ -71,17 +71,16 @@ kernel void main3( type6 cVelCount = 0; type pos1; type vel1; - type5 i; - if (gl_GlobalInvocationID.x >= static_cast(5)) { + type5 i = 0; + if (gl_GlobalInvocationID.x >= 5) { } vPos = particlesA.particles[gl_GlobalInvocationID.x].pos; vVel = particlesA.particles[gl_GlobalInvocationID.x].vel; cMass = metal::float2(0.0, 0.0); cVel = metal::float2(0.0, 0.0); colVel = metal::float2(0.0, 0.0); - i = static_cast(0); while(true) { - if (i >= static_cast(5)) { + if (i >= 5) { break; } if (i == gl_GlobalInvocationID.x) { diff --git a/tests/snapshots/snapshots__collatz.msl.snap b/tests/snapshots/snapshots__collatz.msl.snap index be5eb887a4..f076a55da3 100644 --- a/tests/snapshots/snapshots__collatz.msl.snap +++ b/tests/snapshots/snapshots__collatz.msl.snap @@ -15,18 +15,17 @@ struct PrimeIndices { type1 collatz_iterations( type1 n ) { - type1 i; - i = static_cast(0); + type1 i = 0; while(true) { - if (n <= static_cast(1)) { + if (n <= 1) { break; } - if (n % static_cast(2) == static_cast(0)) { - n = n / static_cast(2); + if (n % 2 == 0) { + n = n / 2; } else { - n = static_cast(3) * n + static_cast(1); + n = 3 * n + 1; } - i = i + static_cast(1); + i = i + 1; } return i; } diff --git a/tests/snapshots/snapshots__simple-Vertex.glsl.snap b/tests/snapshots/snapshots__empty-Compute.glsl.snap similarity index 83% rename from tests/snapshots/snapshots__simple-Vertex.glsl.snap rename to tests/snapshots/snapshots__empty-Compute.glsl.snap index d32d800749..9fb9bfbaa1 100644 --- a/tests/snapshots/snapshots__simple-Vertex.glsl.snap +++ b/tests/snapshots/snapshots__empty-Compute.glsl.snap @@ -11,7 +11,6 @@ precision highp float; void main() { - gl_Position = vec4(1); return; } diff --git a/tests/snapshots/snapshots__empty.msl.snap b/tests/snapshots/snapshots__empty.msl.snap new file mode 100644 index 0000000000..6f2876dcc2 --- /dev/null +++ b/tests/snapshots/snapshots__empty.msl.snap @@ -0,0 +1,12 @@ +--- +source: tests/snapshots.rs +expression: msl +--- +#include +#include + + +kernel void main1( +) { +} + diff --git a/tests/snapshots/snapshots__simple.spvasm.snap b/tests/snapshots/snapshots__empty.spvasm.snap similarity index 51% rename from tests/snapshots/snapshots__simple.spvasm.snap rename to tests/snapshots/snapshots__empty.spvasm.snap index 1d6cad3c4d..ae2b9628bf 100644 --- a/tests/snapshots/snapshots__simple.spvasm.snap +++ b/tests/snapshots/snapshots__empty.spvasm.snap @@ -5,23 +5,15 @@ expression: dis ; SPIR-V ; Version: 1.0 ; Generator: rspirv -; Bound: 13 +; Bound: 6 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint Vertex %3 "main" %6 -OpDecorate %6 BuiltIn Position +OpEntryPoint GLCompute %3 "main" +OpExecutionMode %3 LocalSize 1 1 1 %2 = OpTypeVoid %4 = OpTypeFunction %2 -%8 = OpTypeFloat 32 -%7 = OpTypeVector %8 4 -%9 = OpTypePointer Output %7 -%6 = OpVariable %9 Output -%11 = OpTypeInt 32 1 -%10 = OpConstant %11 1 %3 = OpFunction %2 None %4 %5 = OpLabel -%12 = OpCompositeConstruct %7 %10 -OpStore %6 %12 OpReturn OpFunctionEnd diff --git a/tests/snapshots/snapshots__simple.msl.snap b/tests/snapshots/snapshots__simple.msl.snap deleted file mode 100644 index 1102c4ea8e..0000000000 --- a/tests/snapshots/snapshots__simple.msl.snap +++ /dev/null @@ -1,22 +0,0 @@ ---- -source: tests/snapshots.rs -expression: msl ---- -#include -#include - -typedef metal::float4 type; - -struct main1Input { -}; -struct main1Output { - type o_position [[position]]; -}; -vertex main1Output main1( - main1Input input [[stage_in]] -) { - main1Output output; - output.o_position = metal::float4(1); - return output; -} -