diff --git a/src/front/glsl/functions.rs b/src/front/glsl/functions.rs index fab1f5ae24..15d8ba7536 100644 --- a/src/front/glsl/functions.rs +++ b/src/front/glsl/functions.rs @@ -1,13 +1,35 @@ use crate::{ - proc::ensure_block_returns, Arena, BinaryOperator, Block, EntryPoint, Expression, Function, + proc::ensure_block_returns, Arena, BinaryOperator, Block, Constant, ConstantInner, EntryPoint, Expression, Function, FunctionArgument, FunctionResult, Handle, ImageQuery, LocalVariable, MathFunction, - RelationalFunction, SampleLevel, ScalarKind, Statement, StructMember, SwizzleComponent, Type, + RelationalFunction, SampleLevel, ScalarKind, ScalarValue, Statement, StructMember, SwizzleComponent, Type, TypeInner, VectorSize, }; use super::{ast::*, error::ErrorKind, SourceMetadata}; impl Program<'_> { + fn add_constant_value( + &mut self, + scalar_kind: ScalarKind, + value: u64, + ) -> Handle { + let value = match scalar_kind { + ScalarKind::Uint => ScalarValue::Uint(value), + ScalarKind::Sint => ScalarValue::Sint(value as i64), + ScalarKind::Float => ScalarValue::Float(value as f64), + _ => unreachable!(), + }; + + self.module.constants.fetch_or_append(Constant { + name: None, + specialization: None, + inner: ConstantInner::Scalar { + width: 4, + value, + }, + }) + } + pub fn function_call( &mut self, ctx: &mut Context, @@ -24,13 +46,40 @@ impl Program<'_> { match fc { FunctionCallKind::TypeConstructor(ty) => { let h = if args.len() == 1 { - let is_vec = match *self.resolve_type(ctx, args[0].0, args[0].1)? { - TypeInner::Vector { .. } => true, - _ => false, + let expr_type = self.resolve_type(ctx, args[0].0, args[0].1)?; + + let vector_size = match *expr_type { + TypeInner::Vector{ size, .. } => Some(size), + _ => None, }; + // Special case: if casting from a bool, we need to use Select and not As. + match self.module.types[ty].inner.scalar_kind() { + Some(result_scalar_kind) if expr_type.scalar_kind() == Some(ScalarKind::Bool) && result_scalar_kind != ScalarKind::Bool => { + let c0 = self.add_constant_value(result_scalar_kind, 0u64); + let c1 = self.add_constant_value(result_scalar_kind, 1u64); + let mut reject = ctx.add_expression(Expression::Constant(c0), body); + let mut accept = ctx.add_expression(Expression::Constant(c1), body); + + ctx.implicit_splat(self, &mut reject, meta, vector_size)?; + ctx.implicit_splat(self, &mut accept, meta, vector_size)?; + + let h = ctx.add_expression( + Expression::Select { + accept, + reject, + condition: args[0].0, + }, + body, + ); + + return Ok(Some(h)); + } + _ => {} + } + match self.module.types[ty].inner { - TypeInner::Vector { size, kind, .. } if !is_vec => { + TypeInner::Vector { size, kind, .. } if vector_size.is_none() => { let (mut value, meta) = args[0]; ctx.implicit_conversion(self, &mut value, meta, kind)?; diff --git a/src/valid/expression.rs b/src/valid/expression.rs index d1bb45f7bf..990026d9b2 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1159,13 +1159,15 @@ impl super::Validator { .resolve(expr)? .scalar_kind() .ok_or(ExpressionError::InvalidCastArgument)?; + + if prev_kind == Sk::Bool || kind == Sk::Bool { + return Err(ExpressionError::InvalidCastArgument); + } + match convert { Some(width) if !self.check_width(kind, width) => { return Err(ExpressionError::InvalidCastArgument) } - None if prev_kind == Sk::Bool || kind == Sk::Bool => { - return Err(ExpressionError::InvalidCastArgument) - } _ => {} } ShaderStages::all() diff --git a/tests/in/glsl/bool-select.frag b/tests/in/glsl/bool-select.frag new file mode 100644 index 0000000000..96cd7d4ae8 --- /dev/null +++ b/tests/in/glsl/bool-select.frag @@ -0,0 +1,17 @@ +#version 440 core +precision highp float; + +layout(location = 0) out vec4 o_color; + +float TevPerCompGT(float a, float b) { + return float(a > b); +} + +vec3 TevPerCompGT(vec3 a, vec3 b) { + return vec3(greaterThan(a, b)); +} + +void main() { + o_color.rgb = TevPerCompGT(vec3(3.0), vec3(5.0)); + o_color.a = TevPerCompGT(3.0, 5.0); +} diff --git a/tests/out/wgsl/bool-select-frag.wgsl b/tests/out/wgsl/bool-select-frag.wgsl new file mode 100644 index 0000000000..51524143bc --- /dev/null +++ b/tests/out/wgsl/bool-select-frag.wgsl @@ -0,0 +1,45 @@ +struct FragmentOutput { + [[location(0), interpolate(perspective)]] o_color: vec4; +}; + +var o_color: vec4; + +fn TevPerCompGT(a: f32, b: f32) -> f32 { + var a1: f32; + var b1: f32; + + a1 = a; + b1 = b; + let _e5: f32 = a1; + let _e6: f32 = b1; + return select(1.0, 0.0, (_e5 > _e6)); +} + +fn TevPerCompGT1(a2: vec3, b2: vec3) -> vec3 { + var a3: vec3; + var b3: vec3; + + a3 = a2; + b3 = b2; + let _e5: vec3 = a3; + let _e6: vec3 = b3; + return select(vec3(1.0), vec3(0.0), (_e5 > _e6)); +} + +fn main1() { + let _e1: vec4 = o_color; + let _e11: vec3 = TevPerCompGT1(vec3(3.0), vec3(5.0)); + o_color.x = _e11.x; + o_color.y = _e11.y; + o_color.z = _e11.z; + let _e23: f32 = TevPerCompGT(3.0, 5.0); + o_color.w = _e23; + return; +} + +[[stage(fragment)]] +fn main() -> FragmentOutput { + main1(); + let _e1: vec4 = o_color; + return FragmentOutput(_e1); +}