diff --git a/src/front/glsl/ast.rs b/src/front/glsl/ast.rs index 16dc59708e..5367fc74ed 100644 --- a/src/front/glsl/ast.rs +++ b/src/front/glsl/ast.rs @@ -8,7 +8,7 @@ use crate::{ proc::ResolveContext, Arena, BinaryOperator, Binding, Block, Constant, Expression, FastHashMap, Function, FunctionArgument, GlobalVariable, Handle, Interpolation, LocalVariable, Module, RelationalFunction, ResourceBinding, Sampling, ScalarKind, ShaderStage, Statement, - StorageClass, Type, TypeInner, UnaryOperator, + StorageClass, SwizzleComponent, Type, TypeInner, UnaryOperator, VectorSize, }; #[derive(Debug, Clone, Copy)] @@ -625,10 +625,56 @@ impl<'function> Context<'function> { self.implicit_conversion(program, &mut value, value_meta, kind)?; } - self.emit_flush(body); - self.emit_start(); + if let Expression::Swizzle { + size, + vector, + pattern, + } = self.expressions[pointer] + { + // Stores to swizzled values are not directly supported, + // lower them as series of per-component stores. + let size = match size { + VectorSize::Bi => 2, + VectorSize::Tri => 3, + VectorSize::Quad => 4, + }; - body.push(Statement::Store { pointer, value }); + #[allow(clippy::needless_range_loop)] + for index in 0..size { + let dst = self.add_expression( + Expression::AccessIndex { + base: vector, + index: match pattern[index] { + SwizzleComponent::X => 0, + SwizzleComponent::Y => 1, + SwizzleComponent::Z => 2, + SwizzleComponent::W => 3, + }, + }, + body, + ); + let src = self.add_expression( + Expression::AccessIndex { + base: value, + index: index as u32, + }, + body, + ); + + self.emit_flush(body); + self.emit_start(); + + body.push(Statement::Store { + pointer: dst, + value: src, + }); + } + } else { + self.emit_flush(body); + self.emit_start(); + + body.push(Statement::Store { pointer, value }); + } value } diff --git a/src/front/glsl/variables.rs b/src/front/glsl/variables.rs index 03dd2ddfde..fc54c5daff 100644 --- a/src/front/glsl/variables.rs +++ b/src/front/glsl/variables.rs @@ -1,7 +1,7 @@ use crate::{ Binding, Block, BuiltIn, Constant, Expression, GlobalVariable, Handle, ImageClass, - Interpolation, LocalVariable, ScalarKind, StorageAccess, StorageClass, Type, TypeInner, - VectorSize, + Interpolation, LocalVariable, ScalarKind, StorageAccess, StorageClass, SwizzleComponent, Type, + TypeInner, VectorSize, }; use super::ast::*; @@ -146,67 +146,60 @@ impl Program<'_> { )) } // swizzles (xyzw, rgba, stpq) - TypeInner::Vector { size, kind, width } => { + TypeInner::Vector { size, .. } => { let check_swizzle_components = |comps: &str| { name.chars() - .map(|c| { - comps - .find(c) - .and_then(|i| if i < size as usize { Some(i) } else { None }) - }) - .fold(Some(Vec::::new()), |acc, cur| { - cur.and_then(|i| { - acc.map(|mut v| { - v.push(i); - v - }) - }) - }) + .map(|c| comps.find(c).filter(|i| *i < size as usize)) + .collect::>>() }; let indices = check_swizzle_components("xyzw") .or_else(|| check_swizzle_components("rgba")) .or_else(|| check_swizzle_components("stpq")); - if let Some(v) = indices { - let components: Vec> = v - .iter() - .map(|idx| { - ctx.add_expression( - Expression::AccessIndex { - base: expression, - index: *idx as u32, - }, - body, - ) - }) - .collect(); + if let Some(components) = indices { if components.len() == 1 { // only single element swizzle, like pos.y, just return that component - Ok(components[0]) - } else { - let size = match components.len() { - 2 => VectorSize::Bi, - 3 => VectorSize::Tri, - 4 => VectorSize::Quad, - _ => { - return Err(ErrorKind::SemanticError( - meta, - format!("Bad swizzle size for \"{:?}\": {:?}", name, v).into(), - )); - } - }; - Ok(ctx.add_expression( - Expression::Compose { - ty: self.module.types.fetch_or_append(Type { - name: None, - inner: TypeInner::Vector { kind, width, size }, - }), - components, + return Ok(ctx.add_expression( + Expression::AccessIndex { + base: expression, + index: components[0] as u32, }, body, - )) + )); } + + let size = match components.len() { + 2 => VectorSize::Bi, + 3 => VectorSize::Tri, + 4 => VectorSize::Quad, + _ => { + return Err(ErrorKind::SemanticError( + meta, + format!("Bad swizzle size for \"{:?}\": {:?}", name, components) + .into(), + )); + } + }; + + let mut pattern = [SwizzleComponent::X; 4]; + for (pat, index) in pattern.iter_mut().zip(components) { + *pat = match index { + 0 => SwizzleComponent::X, + 1 => SwizzleComponent::Y, + 2 => SwizzleComponent::Z, + _ => SwizzleComponent::W, + }; + } + + Ok(ctx.add_expression( + Expression::Swizzle { + size, + vector: expression, + pattern, + }, + body, + )) } else { Err(ErrorKind::SemanticError( meta, diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index c8d2b02df5..47adb4d5d6 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -322,6 +322,17 @@ impl<'a> ResolveContext<'a> { kind, width, } => TypeResolution::Value(Ti::Vector { size, kind, width }), + Ti::Pointer { base, .. } => match types[base].inner { + Ti::Vector { + size: _, + kind, + width, + } => TypeResolution::Value(Ti::Vector { size, kind, width }), + ref other => { + log::error!("Vector pointer type {:?}", other); + return Err(ResolveError::InvalidVector(vector)); + } + }, ref other => { log::error!("Vector type {:?}", other); return Err(ResolveError::InvalidVector(vector)); diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 561a088a7e..313dfa01bf 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -233,6 +233,13 @@ impl super::Validator { } => { let vec_size = match *resolver.resolve(vector)? { Ti::Vector { size: vec_size, .. } => vec_size, + Ti::Pointer { base, .. } => match module.types[base].inner { + Ti::Vector { size: vec_size, .. } => vec_size, + ref other => { + log::error!("Swizzle vector pointer type {:?}", other); + return Err(ExpressionError::InvalidVectorType(vector)); + } + }, ref other => { log::error!("Swizzle vector type {:?}", other); return Err(ExpressionError::InvalidVectorType(vector));