diff --git a/src/front/glsl/ast.rs b/src/front/glsl/ast.rs index 571c5f5696..d1b6a7f416 100644 --- a/src/front/glsl/ast.rs +++ b/src/front/glsl/ast.rs @@ -8,7 +8,7 @@ use crate::{ proc::ResolveContext, Arena, BinaryOperator, Binding, Block, Constant, Expression, FastHashMap, Function, FunctionArgument, GlobalVariable, Handle, Interpolation, LocalVariable, Module, RelationalFunction, ResourceBinding, Sampling, ScalarKind, ShaderStage, Statement, - StorageClass, SwizzleComponent, Type, TypeInner, UnaryOperator, VectorSize, + StorageClass, Type, TypeInner, UnaryOperator, VectorSize, }; #[derive(Debug, Clone, Copy)] @@ -511,7 +511,7 @@ impl<'function> Context<'function> { HirExprKind::Select { base, field } => { let base = self.lower_expect(program, base, lhs, body)?.0; - program.field_selection(self, body, base, &field, meta)? + program.field_selection(self, lhs, body, base, &field, meta)? } HirExprKind::Constant(constant) if !lhs => { self.add_expression(Expression::Constant(constant), body) @@ -651,12 +651,7 @@ impl<'function> Context<'function> { let dst = self.add_expression( Expression::AccessIndex { base: vector, - index: match pattern[index] { - SwizzleComponent::X => 0, - SwizzleComponent::Y => 1, - SwizzleComponent::Z => 2, - SwizzleComponent::W => 3, - }, + index: pattern[index].index(), }, body, ); @@ -805,11 +800,8 @@ impl<'function> Context<'function> { .and_then(type_power)) } - pub fn expr_is_swizzle(&mut self, expr: Handle) -> bool { - match self.expressions[expr] { - Expression::Swizzle { .. } => true, - _ => false, - } + pub fn expr(&self, expr: Handle) -> &Expression { + &self.expressions[expr] } pub fn implicit_conversion( diff --git a/src/front/glsl/functions.rs b/src/front/glsl/functions.rs index ad888467af..d63245b0ba 100644 --- a/src/front/glsl/functions.rs +++ b/src/front/glsl/functions.rs @@ -48,12 +48,7 @@ impl Program<'_> { Expression::Swizzle { size, vector: args[0].0, - pattern: [ - SwizzleComponent::X, - SwizzleComponent::Y, - SwizzleComponent::Z, - SwizzleComponent::W, - ], + pattern: SwizzleComponent::XYZW, }, body, ); @@ -93,12 +88,7 @@ impl Program<'_> { Expression::Swizzle { size: rows, vector, - pattern: [ - SwizzleComponent::X, - SwizzleComponent::Y, - SwizzleComponent::Z, - SwizzleComponent::W, - ], + pattern: SwizzleComponent::XYZW, }, body, ); @@ -372,7 +362,9 @@ impl Program<'_> { let mut proxy_writes = Vec::new(); for (qualifier, expr) in fun.qualifiers.iter().zip(raw_args.iter()) { let handle = ctx.lower_expect(self, *expr, qualifier.is_lhs(), body)?.0; - if qualifier.is_lhs() && ctx.expr_is_swizzle(handle) { + if qualifier.is_lhs() + && matches! { ctx.expr(handle), &Expression::Swizzle { .. } } + { let meta = ctx.hir_exprs[*expr].meta; let ty = self.resolve_handle(ctx, handle, meta)?; let temp_var = ctx.locals.append(LocalVariable { diff --git a/src/front/glsl/parser_tests.rs b/src/front/glsl/parser_tests.rs index 88f252761a..9d203e9424 100644 --- a/src/front/glsl/parser_tests.rs +++ b/src/front/glsl/parser_tests.rs @@ -517,3 +517,47 @@ fn structs() { ) .unwrap_err(); } + +#[test] +fn swizzles() { + let mut entry_points = crate::FastHashMap::default(); + entry_points.insert("".to_string(), ShaderStage::Fragment); + + parse_program( + r#" + # version 450 + void main() { + vec4 v = vec4(1); + v.xyz = vec3(2); + v.x = 5.0; + v.xyz.zxy.yx.xy = vec2(5.0, 1.0); + } + "#, + &entry_points, + ) + .unwrap(); + + parse_program( + r#" + # version 450 + void main() { + vec4 v = vec4(1); + v.xx = vec2(5.0); + } + "#, + &entry_points, + ) + .unwrap_err(); + + parse_program( + r#" + # version 450 + void main() { + vec3 v = vec3(1); + v.w = 2.0; + } + "#, + &entry_points, + ) + .unwrap_err(); +} diff --git a/src/front/glsl/variables.rs b/src/front/glsl/variables.rs index 9eafcf88fb..72fff5e2a4 100644 --- a/src/front/glsl/variables.rs +++ b/src/front/glsl/variables.rs @@ -137,6 +137,7 @@ impl Program<'_> { pub fn field_selection( &mut self, ctx: &mut Context, + lhs: bool, body: &mut Block, expression: Handle, name: &str, @@ -164,49 +165,76 @@ impl Program<'_> { TypeInner::Vector { size, .. } => { let check_swizzle_components = |comps: &str| { name.chars() - .map(|c| comps.find(c).filter(|i| *i < size as usize)) - .collect::>>() + .map(|c| { + comps + .find(c) + .filter(|i| *i < size as usize) + .map(|i| SwizzleComponent::from_index(i as u32)) + }) + .collect::>>() }; - let indices = check_swizzle_components("xyzw") + let components = check_swizzle_components("xyzw") .or_else(|| check_swizzle_components("rgba")) .or_else(|| check_swizzle_components("stpq")); - if let Some(components) = indices { - if components.len() == 1 { - // only single element swizzle, like pos.y, just return that component - return Ok(ctx.add_expression( - Expression::AccessIndex { - base: expression, - index: components[0] as u32, - }, - body, - )); + if let Some(components) = components { + if lhs { + let not_unique = (1..components.len()) + .any(|i| components[i..].contains(&components[i - 1])); + if not_unique { + return Err(ErrorKind::SemanticError( + meta, + format!( + "swizzle cannot have duplicate components in left-hand-side expression for \"{:?}\"", + name + ) + .into(), + )); + } + } + + let mut pattern = [SwizzleComponent::X; 4]; + for (pat, component) in pattern.iter_mut().zip(&components) { + *pat = *component; + } + + // flatten nested swizzles (vec.zyx.xy.x => vec.z) + let mut expression = expression; + if let Expression::Swizzle { + size: _, + ref vector, + pattern: ref src_pattern, + } = *ctx.expr(expression) + { + expression = *vector; + for pat in &mut pattern { + *pat = src_pattern[pat.index() as usize]; + } } let size = match components.len() { + 1 => { + // only single element swizzle, like pos.y, just return that component + return Ok(ctx.add_expression( + Expression::AccessIndex { + base: expression, + index: pattern[0].index(), + }, + body, + )); + } 2 => VectorSize::Bi, 3 => VectorSize::Tri, 4 => VectorSize::Quad, _ => { return Err(ErrorKind::SemanticError( meta, - format!("Bad swizzle size for \"{:?}\": {:?}", name, components) - .into(), + format!("Bad swizzle size for \"{:?}\"", name).into(), )); } }; - let mut pattern = [SwizzleComponent::X; 4]; - for (pat, index) in pattern.iter_mut().zip(components) { - *pat = match index { - 0 => SwizzleComponent::X, - 1 => SwizzleComponent::Y, - 2 => SwizzleComponent::Z, - _ => SwizzleComponent::W, - }; - } - Ok(ctx.add_expression( Expression::Swizzle { size, @@ -393,25 +421,20 @@ impl Program<'_> { return Ok(GlobalOrConstant::Global(handle)); } else if let StorageQualifier::Const = storage { - if let Some(init) = init { - if let Some(name) = name { - self.global_variables.push(( - name, - GlobalLookup { - kind: GlobalLookupKind::Constant(init), - entry_arg: None, - mutable: false, - }, - )); - } - - return Ok(GlobalOrConstant::Constant(init)); - } else { - return Err(ErrorKind::SemanticError( - meta, - "const values must have an initializer".into(), + let init = init.ok_or_else(|| { + ErrorKind::SemanticError(meta, "const values must have an initializer".into()) + })?; + if let Some(name) = name { + self.global_variables.push(( + name, + GlobalLookup { + kind: GlobalLookupKind::Constant(init), + entry_arg: None, + mutable: false, + }, )); } + return Ok(GlobalOrConstant::Constant(init)); } let (class, storage_access) = match self.module.types[ty].inner { diff --git a/src/lib.rs b/src/lib.rs index b2ab989b9b..9942b9983b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -773,6 +773,32 @@ pub enum SwizzleComponent { W = 3, } +impl SwizzleComponent { + pub const XYZW: [SwizzleComponent; 4] = [ + SwizzleComponent::X, + SwizzleComponent::Y, + SwizzleComponent::Z, + SwizzleComponent::W, + ]; + + pub fn index(&self) -> u32 { + match *self { + SwizzleComponent::X => 0, + SwizzleComponent::Y => 1, + SwizzleComponent::Z => 2, + SwizzleComponent::W => 3, + } + } + pub fn from_index(idx: u32) -> Self { + match idx { + 0 => SwizzleComponent::X, + 1 => SwizzleComponent::Y, + 2 => SwizzleComponent::Z, + _ => SwizzleComponent::W, + } + } +} + bitflags::bitflags! { /// Memory barrier flags. #[cfg_attr(feature = "serialize", derive(Serialize))] diff --git a/tests/in/glsl/swizzle_write.frag b/tests/in/glsl/swizzle_write.frag index c4e91ed750..9f11d7e5f3 100644 --- a/tests/in/glsl/swizzle_write.frag +++ b/tests/in/glsl/swizzle_write.frag @@ -3,5 +3,6 @@ void main() { vec3 x = vec3(2.0); - x.rg *= 2.0; + x.zxy.xy = vec2(3.0, 4.0); + x.rg *= 5.0; }