From c0b7ac7f542cc42ccac6f2ec3fc1fb01309cf4d7 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Tue, 11 Jan 2022 10:56:41 -0500 Subject: [PATCH] WGSL: assignment binary operators --- src/front/wgsl/lexer.rs | 26 +++++++-- src/front/wgsl/mod.rs | 41 +++++++++++++- tests/in/operators.wgsl | 16 +++++- tests/out/glsl/operators.main.Compute.glsl | 34 +++++++++--- tests/out/hlsl/operators.hlsl | 36 ++++++++++--- tests/out/msl/operators.msl | 35 +++++++++--- tests/out/spv/operators.spvasm | 62 ++++++++++++++++------ tests/out/wgsl/operators.wgsl | 35 +++++++++--- 8 files changed, 237 insertions(+), 48 deletions(-) diff --git a/src/front/wgsl/lexer.rs b/src/front/wgsl/lexer.rs index e745205418..be419fd42b 100644 --- a/src/front/wgsl/lexer.rs +++ b/src/front/wgsl/lexer.rs @@ -322,7 +322,12 @@ fn consume_token(mut input: &str, generic: bool) -> (Token<'_>, &str) { if next == Some('=') && !generic { (Token::LogicalOperation(cur), chars.as_str()) } else if next == Some(cur) && !generic { - (Token::ShiftOperation(cur), chars.as_str()) + input = chars.as_str(); + if chars.next() == Some('=') { + (Token::AssignmentOperation(cur), chars.as_str()) + } else { + (Token::ShiftOperation(cur), input) + } } else { (Token::Paren(cur), input) } @@ -356,14 +361,22 @@ fn consume_token(mut input: &str, generic: bool) -> (Token<'_>, &str) { (Token::Trivia, chars.as_str()) } '-' => { - let og_chars = chars.as_str(); + let sub_input = chars.as_str(); match chars.next() { Some('>') => (Token::Arrow, chars.as_str()), Some('0'..='9') | Some('.') => consume_number(input), - _ => (Token::Operation(cur), og_chars), + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), sub_input), + } + } + '+' | '*' | '/' | '%' | '^' => { + input = chars.as_str(); + if chars.next() == Some('=') { + (Token::AssignmentOperation(cur), chars.as_str()) + } else { + (Token::Operation(cur), input) } } - '+' | '*' | '/' | '%' | '^' => (Token::Operation(cur), chars.as_str()), '!' | '~' => { input = chars.as_str(); if chars.next() == Some('=') { @@ -374,8 +387,11 @@ fn consume_token(mut input: &str, generic: bool) -> (Token<'_>, &str) { } '=' | '&' | '|' => { input = chars.as_str(); - if chars.next() == Some(cur) { + let next = chars.next(); + if next == Some(cur) { (Token::LogicalOperation(cur), chars.as_str()) + } else if next == Some('=') { + (Token::AssignmentOperation(cur), chars.as_str()) } else { (Token::Operation(cur), input) } diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 6103cf6192..5b7c343554 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -68,6 +68,7 @@ pub enum Token<'a> { Operation(char), LogicalOperation(char), ShiftOperation(char), + AssignmentOperation(char), Arrow, Unknown(char), UnterminatedString, @@ -198,6 +199,8 @@ impl<'a> Error<'a> { Token::Operation(c) => format!("operation ('{}')", c), Token::LogicalOperation(c) => format!("logical operation ('{}')", c), Token::ShiftOperation(c) => format!("bitshift ('{}{}')", c, c), + Token::AssignmentOperation(c) if c=='<' || c=='>' => format!("bitshift ('{}{}=')", c, c), + Token::AssignmentOperation(c) => format!("operation ('{}=')", c), Token::Arrow => "->".to_string(), Token::Unknown(c) => format!("unknown ('{}')", c), Token::UnterminatedString => "unterminated string".to_string(), @@ -3308,6 +3311,8 @@ impl Parser { lexer: &mut Lexer<'a>, mut context: ExpressionContext<'a, '_, 'out>, ) -> Result<(), Error<'a>> { + use crate::BinaryOperator as Bo; + let span_start = lexer.current_byte_offset(); context.emitter.start(context.expressions); let reference = self.parse_unary_expression(lexer, context.reborrow())?; @@ -3319,8 +3324,40 @@ impl Parser { span, )); } - lexer.expect(Token::Operation('='))?; - let value = self.parse_general_expression(lexer, context.reborrow())?; + + let value = match lexer.next() { + (Token::Operation('='), _) => { + self.parse_general_expression(lexer, context.reborrow())? + } + (Token::AssignmentOperation(c), span) => { + let op = match c { + '<' => Bo::ShiftLeft, + '>' => Bo::ShiftRight, + '+' => Bo::Add, + '-' => Bo::Subtract, + '*' => Bo::Multiply, + '/' => Bo::Divide, + '%' => Bo::Modulo, + '&' => Bo::And, + '|' => Bo::InclusiveOr, + '^' => Bo::ExclusiveOr, + //Note: `consume_token` shouldn't produce any other assignment ops + _ => unreachable!(), + }; + let left = context.expressions.append( + crate::Expression::Load { + pointer: reference.handle, + }, + NagaSpan::from(span_start..lexer.current_byte_offset()), + ); + let right = self.parse_general_expression(lexer, context.reborrow())?; + context + .expressions + .append(crate::Expression::Binary { op, left, right }, span.into()) + } + other => return Err(Error::Unexpected(other, ExpectedToken::SwitchItem)), + }; + let span_end = lexer.current_byte_offset(); context .block diff --git a/tests/in/operators.wgsl b/tests/in/operators.wgsl index da18e3a460..a7aedb7f25 100644 --- a/tests/in/operators.wgsl +++ b/tests/in/operators.wgsl @@ -80,11 +80,22 @@ fn scalar_times_matrix() { let assertion: mat4x4 = 2.0 * model; } -fn binary() { +fn logical() { let a = true | false; let b = true & false; } +fn binary_assignment() { + var a = 1; + a += 1; + a -= 1; + a *= a; + a /= a; + a %= 1; + a ^= 0; + a &= 0; +} + [[stage(compute), workgroup_size(1)]] fn main() { let a = builtins(); @@ -94,5 +105,6 @@ fn main() { let e = constructors(); modulo(); scalar_times_matrix(); - binary(); + logical(); + binary_assignment(); } diff --git a/tests/out/glsl/operators.main.Compute.glsl b/tests/out/glsl/operators.main.Compute.glsl index 58dac45478..1ca8db55f6 100644 --- a/tests/out/glsl/operators.main.Compute.glsl +++ b/tests/out/glsl/operators.main.Compute.glsl @@ -23,9 +23,9 @@ vec4 builtins() { } vec4 splat() { - vec2 a = (((vec2(1.0) + vec2(2.0)) - vec2(3.0)) / vec2(4.0)); + vec2 a_1 = (((vec2(1.0) + vec2(2.0)) - vec2(3.0)) / vec2(4.0)); ivec4 b = (ivec4(5) % ivec4(2)); - return (a.xyxy + vec4(b)); + return (a_1.xyxy + vec4(b)); } int unary() { @@ -51,7 +51,7 @@ float constructors() { } void modulo() { - int a_1 = (1 % 1); + int a_2 = (1 % 1); float b_1 = (1.0 - 1.0 * trunc(1.0 / 1.0)); ivec3 c = (ivec3(1) % ivec3(1)); vec3 d = (vec3(1.0) - vec3(1.0) * trunc(vec3(1.0) / vec3(1.0))); @@ -62,11 +62,32 @@ void scalar_times_matrix() { mat4x4 assertion = (2.0 * model); } -void binary() { - bool a_2 = (true || false); +void logical() { + bool a_3 = (true || false); bool b_2 = (true && false); } +void binary_assignment() { + int a = 1; + int _e6 = a; + a = (_e6 + 1); + int _e9 = a; + a = (_e9 - 1); + int _e12 = a; + int _e13 = a; + a = (_e12 * _e13); + int _e15 = a; + int _e16 = a; + a = (_e15 / _e16); + int _e18 = a; + a = (_e18 % 1); + int _e21 = a; + a = (_e21 ^ 0); + int _e24 = a; + a = (_e24 & 0); + return; +} + void main() { vec4 _e4 = builtins(); vec4 _e5 = splat(); @@ -75,7 +96,8 @@ void main() { float _e9 = constructors(); modulo(); scalar_times_matrix(); - binary(); + logical(); + binary_assignment(); return; } diff --git a/tests/out/hlsl/operators.hlsl b/tests/out/hlsl/operators.hlsl index 1a858a7e0c..8f060d85bd 100644 --- a/tests/out/hlsl/operators.hlsl +++ b/tests/out/hlsl/operators.hlsl @@ -23,9 +23,9 @@ float4 builtins() float4 splat() { - float2 a = (((float2(1.0.xx) + float2(2.0.xx)) - float2(3.0.xx)) / float2(4.0.xx)); + float2 a_1 = (((float2(1.0.xx) + float2(2.0.xx)) - float2(3.0.xx)) / float2(4.0.xx)); int4 b = (int4(5.xxxx) % int4(2.xxxx)); - return (a.xyxy + float4(b)); + return (a_1.xyxy + float4(b)); } int unary() @@ -63,7 +63,7 @@ float constructors() void modulo() { - int a_1 = (1 % 1); + int a_2 = (1 % 1); float b_1 = (1.0 % 1.0); int3 c = (int3(1.xxx) % int3(1.xxx)); float3 d = (float3(1.0.xxx) % float3(1.0.xxx)); @@ -75,12 +75,35 @@ void scalar_times_matrix() float4x4 assertion = mul(model, 2.0); } -void binary() +void logical() { - bool a_2 = (true | false); + bool a_3 = (true | false); bool b_2 = (true & false); } +void binary_assignment() +{ + int a = 1; + + int _expr6 = a; + a = (_expr6 + 1); + int _expr9 = a; + a = (_expr9 - 1); + int _expr12 = a; + int _expr13 = a; + a = (_expr12 * _expr13); + int _expr15 = a; + int _expr16 = a; + a = (_expr15 / _expr16); + int _expr18 = a; + a = (_expr18 % 1); + int _expr21 = a; + a = (_expr21 ^ 0); + int _expr24 = a; + a = (_expr24 & 0); + return; +} + [numthreads(1, 1, 1)] void main() { @@ -91,6 +114,7 @@ void main() const float _e9 = constructors(); modulo(); scalar_times_matrix(); - binary(); + logical(); + binary_assignment(); return; } diff --git a/tests/out/msl/operators.msl b/tests/out/msl/operators.msl index 5aa0fb44c9..7fa931624f 100644 --- a/tests/out/msl/operators.msl +++ b/tests/out/msl/operators.msl @@ -26,9 +26,9 @@ metal::float4 builtins( metal::float4 splat( ) { - metal::float2 a = ((metal::float2(1.0) + metal::float2(2.0)) - metal::float2(3.0)) / metal::float2(4.0); + metal::float2 a_1 = ((metal::float2(1.0) + metal::float2(2.0)) - metal::float2(3.0)) / metal::float2(4.0); metal::int4 b = metal::int4(5) % metal::int4(2); - return a.xyxy + static_cast(b); + return a_1.xyxy + static_cast(b); } int unary( @@ -59,7 +59,7 @@ float constructors( void modulo( ) { - int a_1 = 1 % 1; + int a_2 = 1 % 1; float b_1 = metal::fmod(1.0, 1.0); metal::int3 c = metal::int3(1) % metal::int3(1); metal::float3 d = metal::fmod(metal::float3(1.0), metal::float3(1.0)); @@ -71,12 +71,34 @@ void scalar_times_matrix( metal::float4x4 assertion = 2.0 * model; } -void binary( +void logical( ) { - bool a_2 = true | false; + bool a_3 = true | false; bool b_2 = true & false; } +void binary_assignment( +) { + int a = 1; + int _e6 = a; + a = _e6 + 1; + int _e9 = a; + a = _e9 - 1; + int _e12 = a; + int _e13 = a; + a = _e12 * _e13; + int _e15 = a; + int _e16 = a; + a = _e15 / _e16; + int _e18 = a; + a = _e18 % 1; + int _e21 = a; + a = _e21 ^ 0; + int _e24 = a; + a = _e24 & 0; + return; +} + kernel void main_( ) { metal::float4 _e4 = builtins(); @@ -86,6 +108,7 @@ kernel void main_( float _e9 = constructors(); modulo(); scalar_times_matrix(); - binary(); + logical(); + binary_assignment(); return; } diff --git a/tests/out/spv/operators.spvasm b/tests/out/spv/operators.spvasm index 514e9bfb8e..bd81edce36 100644 --- a/tests/out/spv/operators.spvasm +++ b/tests/out/spv/operators.spvasm @@ -1,12 +1,12 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 154 +; Bound: 176 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %143 "main" -OpExecutionMode %143 LocalSize 1 1 1 +OpEntryPoint GLCompute %164 "main" +OpExecutionMode %164 LocalSize 1 1 1 OpMemberDecorate %23 0 Offset 0 OpMemberDecorate %23 1 Offset 16 %2 = OpTypeVoid @@ -51,6 +51,7 @@ OpMemberDecorate %23 1 Offset 16 %111 = OpConstant %112 0 %117 = OpTypeFunction %2 %121 = OpTypeVector %8 3 +%143 = OpTypePointer Function %8 %32 = OpFunction %19 None %33 %31 = OpLabel OpBranch %34 @@ -178,18 +179,49 @@ OpBranch %139 %141 = OpLogicalAnd %10 %9 %12 OpReturn OpFunctionEnd -%143 = OpFunction %2 None %117 -%142 = OpLabel -OpBranch %144 +%145 = OpFunction %2 None %117 %144 = OpLabel -%145 = OpFunctionCall %19 %32 -%146 = OpFunctionCall %19 %57 -%147 = OpFunctionCall %8 %73 -%148 = OpVectorShuffle %22 %27 %27 0 1 2 -%149 = OpFunctionCall %22 %84 %148 -%150 = OpFunctionCall %4 %96 -%151 = OpFunctionCall %2 %116 -%152 = OpFunctionCall %2 %129 -%153 = OpFunctionCall %2 %138 +%142 = OpVariable %143 Function %7 +OpBranch %146 +%146 = OpLabel +%147 = OpLoad %8 %142 +%148 = OpIAdd %8 %147 %7 +OpStore %142 %148 +%149 = OpLoad %8 %142 +%150 = OpISub %8 %149 %7 +OpStore %142 %150 +%151 = OpLoad %8 %142 +%152 = OpLoad %8 %142 +%153 = OpIMul %8 %151 %152 +OpStore %142 %153 +%154 = OpLoad %8 %142 +%155 = OpLoad %8 %142 +%156 = OpSDiv %8 %154 %155 +OpStore %142 %156 +%157 = OpLoad %8 %142 +%158 = OpSMod %8 %157 %7 +OpStore %142 %158 +%159 = OpLoad %8 %142 +%160 = OpBitwiseXor %8 %159 %11 +OpStore %142 %160 +%161 = OpLoad %8 %142 +%162 = OpBitwiseAnd %8 %161 %11 +OpStore %142 %162 +OpReturn +OpFunctionEnd +%164 = OpFunction %2 None %117 +%163 = OpLabel +OpBranch %165 +%165 = OpLabel +%166 = OpFunctionCall %19 %32 +%167 = OpFunctionCall %19 %57 +%168 = OpFunctionCall %8 %73 +%169 = OpVectorShuffle %22 %27 %27 0 1 2 +%170 = OpFunctionCall %22 %84 %169 +%171 = OpFunctionCall %4 %96 +%172 = OpFunctionCall %2 %116 +%173 = OpFunctionCall %2 %129 +%174 = OpFunctionCall %2 %138 +%175 = OpFunctionCall %2 %145 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/operators.wgsl b/tests/out/wgsl/operators.wgsl index 1920ccb3a9..ee1ef36d6c 100644 --- a/tests/out/wgsl/operators.wgsl +++ b/tests/out/wgsl/operators.wgsl @@ -20,9 +20,9 @@ fn builtins() -> vec4 { } fn splat() -> vec4 { - let a = (((vec2(1.0) + vec2(2.0)) - vec2(3.0)) / vec2(4.0)); + let a_1 = (((vec2(1.0) + vec2(2.0)) - vec2(3.0)) / vec2(4.0)); let b = (vec4(5) % vec4(2)); - return (a.xyxy + vec4(b)); + return (a_1.xyxy + vec4(b)); } fn unary() -> i32 { @@ -49,7 +49,7 @@ fn constructors() -> f32 { } fn modulo() { - let a_1 = (1 % 1); + let a_2 = (1 % 1); let b_1 = (1.0 % 1.0); let c = (vec3(1) % vec3(1)); let d = (vec3(1.0) % vec3(1.0)); @@ -60,11 +60,33 @@ fn scalar_times_matrix() { let assertion = (2.0 * model); } -fn binary() { - let a_2 = (true | false); +fn logical() { + let a_3 = (true | false); let b_2 = (true & false); } +fn binary_assignment() { + var a: i32 = 1; + + let _e6 = a; + a = (_e6 + 1); + let _e9 = a; + a = (_e9 - 1); + let _e12 = a; + let _e13 = a; + a = (_e12 * _e13); + let _e15 = a; + let _e16 = a; + a = (_e15 / _e16); + let _e18 = a; + a = (_e18 % 1); + let _e21 = a; + a = (_e21 ^ 0); + let _e24 = a; + a = (_e24 & 0); + return; +} + [[stage(compute), workgroup_size(1, 1, 1)]] fn main() { let _e4 = builtins(); @@ -74,6 +96,7 @@ fn main() { let _e9 = constructors(); modulo(); scalar_times_matrix(); - binary(); + logical(); + binary_assignment(); return; }