From d3b39d9e587218ff451fbdc754dd33d6d65d3c18 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Wed, 14 Apr 2021 01:03:25 -0400 Subject: [PATCH] Splat expression in the IR --- src/back/dot/mod.rs | 4 ++++ src/back/glsl/mod.rs | 4 ++++ src/back/msl/writer.rs | 3 +++ src/front/glsl/constants.rs | 24 ++++++++++++++++++++++++ src/lib.rs | 5 +++++ src/proc/typifier.rs | 11 +++++++++++ src/valid/analyzer.rs | 4 ++++ src/valid/expression.rs | 9 +++++++++ 8 files changed, 64 insertions(+) diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index 8296e0ba2a..d1df175176 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -183,6 +183,10 @@ fn write_fun( (Cow::Owned(format!("AccessIndex[{}]", index)), 1) } E::Constant(_) => (Cow::Borrowed("Constant"), 2), + E::Splat { size, value } => { + edges.insert("value", value); + (Cow::Owned(format!("Splat{:?}", size)), 3) + } E::Compose { ref components, .. } => { payload = Some(Payload::Arguments(components)); (Cow::Borrowed("Compose"), 3) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 3674a2076c..fb686d5b0f 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -1477,6 +1477,10 @@ impl<'a, W: Write> Writer<'a, W> { Expression::Constant(constant) => { self.write_constant(&self.module.constants[constant])? } + // `Splat` is just writing `value` + Expression::Splat { size: _, value } => { + self.write_expr(value, ctx)?; + } // `Compose` is pretty simple we just write `type(components)` where `components` is a // comma separated list of expressions Expression::Compose { ty, ref components } => { diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index da43c71eef..086f23768e 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -608,6 +608,9 @@ impl Writer { }; write!(self.out, "{}", coco)?; } + crate::Expression::Splat { size: _, value } => { + self.put_expression(value, context, is_scoped)?; + } crate::Expression::Compose { ty, ref components } => { let inner = &context.module.types[ty].inner; match *inner { diff --git a/src/front/glsl/constants.rs b/src/front/glsl/constants.rs index c68d2b5f72..372ff17041 100644 --- a/src/front/glsl/constants.rs +++ b/src/front/glsl/constants.rs @@ -49,6 +49,8 @@ pub enum ConstantSolvingError { InvalidUnaryOpArg, #[error("Cannot apply the binary op to the arguments")] InvalidBinaryOpArgs, + #[error("Splat type is not registered")] + SplatType, } impl<'a> ConstantSolver<'a> { @@ -64,6 +66,28 @@ impl<'a> ConstantSolver<'a> { self.access(base, self.constant_index(index)?) } + Expression::Splat { + size, + value: splat_value, + } => { + let tgt = self.solve(splat_value)?; + let ty = match self.constants[tgt].inner { + ConstantInner::Scalar { ref value, width } => { + let kind = value.scalar_kind(); + self.types + .fetch_if(|t| t.inner == crate::TypeInner::Vector { size, kind, width }) + } + ConstantInner::Composite { .. } => None, + }; + Ok(self.constants.fetch_or_append(Constant { + name: None, + specialization: None, + inner: ConstantInner::Composite { + ty: ty.ok_or(ConstantSolvingError::SplatType)?, + components: vec![tgt; size as usize], + }, + })) + } Expression::Compose { ty, ref components } => { let components = components .iter() diff --git a/src/lib.rs b/src/lib.rs index fa12613fbc..151591dc40 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -683,6 +683,11 @@ pub enum Expression { }, /// Constant value. Constant(Handle), + /// Splat scalar into a vector. + Splat { + size: VectorSize, + value: Handle, + }, /// Composite expression. Compose { ty: Handle, diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 50455defca..ce62a5b218 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -79,6 +79,8 @@ pub enum ResolveError { ty: Handle, indexed: bool, }, + #[error("Invalid scalar {0:?}")] + InvalidScalar(Handle), #[error("Invalid pointer {0:?}")] InvalidPointer(Handle), #[error("Invalid image {0:?}")] @@ -265,6 +267,15 @@ impl<'a> ResolveContext<'a> { } crate::ConstantInner::Composite { ty, components: _ } => TypeResolution::Handle(ty), }, + crate::Expression::Splat { size, value } => match *past(value).inner_with(types) { + Ti::Scalar { kind, width } => { + TypeResolution::Value(Ti::Vector { size, kind, width }) + } + ref other => { + log::error!("Scalar type {:?}", other); + return Err(ResolveError::InvalidScalar(value)); + } + }, crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty), crate::Expression::FunctionArgument(index) => { TypeResolution::Handle(self.arguments[index as usize].ty) diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 94d7fecc9b..52e00818fb 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -318,6 +318,10 @@ impl FunctionInfo { }, // always uniform E::Constant(_) => Uniformity::new(), + E::Splat { size: _, value } => Uniformity { + non_uniform_result: self.add_ref(value), + requirements: UniformityRequirements::empty(), + }, E::Compose { ref components, .. } => { let non_uniform_result = components .iter() diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 9df4daae11..0ab075af11 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -31,6 +31,8 @@ pub enum ExpressionError { InvalidPointerType(Handle), #[error("Array length of {0:?} can't be done")] InvalidArrayType(Handle), + #[error("Splatting {0:?} can't be done")] + InvalidSplatType(Handle), #[error("Compose type {0:?} doesn't exist")] ComposeTypeDoesntExist(Handle), #[error("Composing of type {0:?} can't be done")] @@ -197,6 +199,13 @@ impl super::Validator { .ok_or(ExpressionError::ConstantDoesntExist(handle))?; ShaderStages::all() } + E::Splat { size: _, value } => match *resolver.resolve(value)? { + Ti::Scalar { .. } => ShaderStages::all(), + ref other => { + log::error!("Splat scalar type {:?}", other); + return Err(ExpressionError::InvalidSplatType(value)); + } + }, E::Compose { ref components, ty } => { match module .types