From 02ddf6532c4befd90444188979e1a595b459bfe4 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Tue, 9 May 2023 18:16:01 -0700 Subject: [PATCH] Introduce `Expression::Literal`. --- src/back/dot/mod.rs | 1 + src/back/glsl/mod.rs | 15 ++++++ src/back/hlsl/writer.rs | 9 ++++ src/back/msl/writer.rs | 25 +++++++++ src/back/spv/block.rs | 84 +++++++++++++------------------ src/back/spv/image.rs | 4 +- src/back/spv/instructions.rs | 8 +++ src/back/spv/mod.rs | 14 ++---- src/back/spv/ray.rs | 9 ++-- src/back/spv/writer.rs | 93 ++++++++++++---------------------- src/back/wgsl/writer.rs | 13 +++++ src/front/glsl/constants.rs | 13 +++++ src/lib.rs | 19 ++++++- src/proc/mod.rs | 98 +++++++++++++++++++++++++++++++++++- src/proc/typifier.rs | 1 + src/valid/analyzer.rs | 3 +- src/valid/expression.rs | 1 + src/valid/handles.rs | 1 + 18 files changed, 280 insertions(+), 131 deletions(-) diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index 8667cd7257..1494e9a820 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -397,6 +397,7 @@ fn write_function_expressions( for (handle, expression) in fun.expressions.iter() { use crate::Expression as E; let (label, color_id) = match *expression { + E::Literal(_) => ("Literal".into(), 2), E::Constant(_) => ("Constant".into(), 2), E::ZeroValue(_) => ("ZeroValue".into(), 2), E::Access { base, index } => { diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 030dba4e60..ad76705b84 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2270,6 +2270,21 @@ impl<'a, W: Write> Writer<'a, W> { Expression::ZeroValue(ty) => { self.write_zero_init_value(ty)?; } + Expression::Literal(literal) => { + match literal { + // Floats are written using `Debug` instead of `Display` because it always appends the + // decimal part even it's zero which is needed for a valid glsl float constant + crate::Literal::F64(value) => write!(self.out, "{:?}LF", value)?, + crate::Literal::F32(value) => write!(self.out, "{:?}", value)?, + // Unsigned integers need a `u` at the end + // + // While `core` doesn't necessarily need it, it's allowed and since `es` needs it we + // always write it as the extra branch wouldn't have any benefit in readability + crate::Literal::U32(value) => write!(self.out, "{}u", value)?, + crate::Literal::I32(value) => write!(self.out, "{}", value)?, + crate::Literal::Bool(value) => write!(self.out, "{}", value)?, + } + } // `Splat` needs to actually write down a vector, it's not always inferred in GLSL. Expression::Splat { size: _, value } => { let resolved = ctx.info[expr].ty.inner_with(&self.module.types); diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index c5bda66070..3d0ed9cd3e 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -2058,6 +2058,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { match *expression { Expression::Constant(constant) => self.write_constant(module, constant)?, Expression::ZeroValue(ty) => self.write_default_init(module, ty)?, + Expression::Literal(literal) => match literal { + // Floats are written using `Debug` instead of `Display` because it always appends the + // decimal part even it's zero + crate::Literal::F64(value) => write!(self.out, "{value:?}L")?, + crate::Literal::F32(value) => write!(self.out, "{value:?}")?, + crate::Literal::U32(value) => write!(self.out, "{}u", value)?, + crate::Literal::I32(value) => write!(self.out, "{}", value)?, + crate::Literal::Bool(value) => write!(self.out, "{}", value)?, + }, Expression::Compose { ty, ref components } => { match module.types[ty].inner { TypeInner::Struct { .. } | TypeInner::Array { .. } => { diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 1c474a5e40..7e41e8a7e1 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1344,6 +1344,31 @@ impl Writer { }; write!(self.out, "{ty_name} {{}}")?; } + crate::Expression::Literal(literal) => match literal { + crate::Literal::F64(_) => { + return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64)) + } + crate::Literal::F32(value) => { + if value.is_infinite() { + let sign = if value.is_sign_negative() { "-" } else { "" }; + write!(self.out, "{sign}INFINITY")?; + } else if value.is_nan() { + write!(self.out, "NAN")?; + } else { + let suffix = if value.fract() == 0.0 { ".0" } else { "" }; + write!(self.out, "{value}{suffix}")?; + } + } + crate::Literal::U32(value) => { + write!(self.out, "{value}u")?; + } + crate::Literal::I32(value) => { + write!(self.out, "{value}")?; + } + crate::Literal::Bool(value) => { + write!(self.out, "{value}")?; + } + }, crate::Expression::Splat { size, value } => { let scalar_kind = match *context.resolve_type(value) { crate::TypeInner::Scalar { kind, .. } => kind, diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index a8afa89977..be5fa5da01 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -118,8 +118,8 @@ impl Writer { width: 4, pointer_space: None, })); - let value0_id = self.get_constant_scalar(crate::ScalarValue::Float(0.0), 4); - let value1_id = self.get_constant_scalar(crate::ScalarValue::Float(1.0), 4); + let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0)); + let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0)); let original_id = self.id_gen.next(); body.push(Instruction::load( @@ -135,7 +135,7 @@ impl Writer { spirv::GLOp::FClamp, float_type_id, clamp_id, - &[original_id, value0_id, value1_id], + &[original_id, zero_scalar_id, one_scalar_id], )); body.push(Instruction::store(frag_depth_id, clamp_id, None)); @@ -359,6 +359,7 @@ impl<'w> BlockContext<'w> { } crate::Expression::Constant(handle) => self.writer.constant_ids[handle.index()], crate::Expression::ZeroValue(_) => self.writer.write_constant_null(result_type_id), + crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal), crate::Expression::Splat { size, value } => { let value_id = self.cached[value]; let components = [value_id; 4]; @@ -705,18 +706,14 @@ impl<'w> BlockContext<'w> { crate::TypeInner::Scalar { width, .. } => (None, width), ref other => unimplemented!("Unexpected saturate({:?})", other), }; - - let mut arg1_id = self - .writer - .get_constant_scalar(crate::ScalarValue::Float(0.0), width); - let mut arg2_id = self - .writer - .get_constant_scalar(crate::ScalarValue::Float(1.0), width); + let kind = crate::ScalarKind::Float; + let mut arg1_id = self.writer.get_constant_scalar_with(0, kind, width)?; + let mut arg2_id = self.writer.get_constant_scalar_with(1, kind, width)?; if let Some(size) = maybe_size { let ty = LocalType::Value { vector_size: Some(size), - kind: crate::ScalarKind::Float, + kind, width, pointer_space: None, } @@ -878,12 +875,13 @@ impl<'w> BlockContext<'w> { arg0_id, )), Mf::CountTrailingZeros => { - let uint = crate::ScalarValue::Uint(32); + let kind = crate::ScalarKind::Uint; + let uint_id = match *arg_ty { crate::TypeInner::Vector { size, width, .. } => { let ty = LocalType::Value { vector_size: Some(size), - kind: crate::ScalarKind::Uint, + kind, width, pointer_space: None, } @@ -892,13 +890,13 @@ impl<'w> BlockContext<'w> { self.temp_list.clear(); self.temp_list.resize( size as _, - self.writer.get_constant_scalar(uint, width), + self.writer.get_constant_scalar_with(32, kind, width)?, ); self.writer.get_constant_composite(ty, &self.temp_list) } crate::TypeInner::Scalar { width, .. } => { - self.writer.get_constant_scalar(uint, width) + self.writer.get_constant_scalar_with(32, kind, width)? } _ => unreachable!(), }; @@ -921,21 +919,23 @@ impl<'w> BlockContext<'w> { )) } Mf::CountLeadingZeros => { - let int = crate::ScalarValue::Sint(31); + let kind = crate::ScalarKind::Sint; let (int_type_id, int_id) = match *arg_ty { crate::TypeInner::Vector { size, width, .. } => { let ty = LocalType::Value { vector_size: Some(size), - kind: crate::ScalarKind::Sint, + kind, width, pointer_space: None, } .into(); self.temp_list.clear(); - self.temp_list - .resize(size as _, self.writer.get_constant_scalar(int, width)); + self.temp_list.resize( + size as _, + self.writer.get_constant_scalar_with(31, kind, width)?, + ); ( self.get_type_id(ty), @@ -945,11 +945,11 @@ impl<'w> BlockContext<'w> { crate::TypeInner::Scalar { width, .. } => ( self.get_type_id(LookupType::Local(LocalType::Value { vector_size: None, - kind: crate::ScalarKind::Sint, + kind, width, pointer_space: None, })), - self.writer.get_constant_scalar(int, width), + self.writer.get_constant_scalar_with(31, kind, width)?, ), _ => unreachable!(), }; @@ -1134,15 +1134,14 @@ impl<'w> BlockContext<'w> { (_, _, None) => Cast::Unary(spirv::Op::Bitcast), // casting to a bool - generate `OpXxxNotEqual` (_, Sk::Bool, Some(_)) => { - let (op, value) = match src_kind { - Sk::Sint => (spirv::Op::INotEqual, crate::ScalarValue::Sint(0)), - Sk::Uint => (spirv::Op::INotEqual, crate::ScalarValue::Uint(0)), - Sk::Float => { - (spirv::Op::FUnordNotEqual, crate::ScalarValue::Float(0.0)) - } + let op = match src_kind { + Sk::Sint | Sk::Uint => spirv::Op::INotEqual, + Sk::Float => spirv::Op::FUnordNotEqual, Sk::Bool => unreachable!(), }; - let zero_scalar_id = self.writer.get_constant_scalar(value, src_width); + let zero_scalar_id = self + .writer + .get_constant_scalar_with(0, src_kind, src_width)?; let zero_id = match src_size { Some(size) => { let ty = LocalType::Value { @@ -1165,21 +1164,10 @@ impl<'w> BlockContext<'w> { } // casting from a bool - generate `OpSelect` (Sk::Bool, _, Some(dst_width)) => { - let (val0, val1) = match kind { - Sk::Sint => { - (crate::ScalarValue::Sint(0), crate::ScalarValue::Sint(1)) - } - Sk::Uint => { - (crate::ScalarValue::Uint(0), crate::ScalarValue::Uint(1)) - } - Sk::Float => ( - crate::ScalarValue::Float(0.0), - crate::ScalarValue::Float(1.0), - ), - Sk::Bool => unreachable!(), - }; - let scalar0_id = self.writer.get_constant_scalar(val0, dst_width); - let scalar1_id = self.writer.get_constant_scalar(val1, dst_width); + let zero_scalar_id = + self.writer.get_constant_scalar_with(0, kind, dst_width)?; + let one_scalar_id = + self.writer.get_constant_scalar_with(1, kind, dst_width)?; let (accept_id, reject_id) = match src_size { Some(size) => { let ty = LocalType::Value { @@ -1191,19 +1179,19 @@ impl<'w> BlockContext<'w> { .into(); self.temp_list.clear(); - self.temp_list.resize(size as _, scalar0_id); + self.temp_list.resize(size as _, zero_scalar_id); let vec0_id = self.writer.get_constant_composite(ty, &self.temp_list); - self.temp_list.fill(scalar1_id); + self.temp_list.fill(one_scalar_id); let vec1_id = self.writer.get_constant_composite(ty, &self.temp_list); (vec1_id, vec0_id) } - None => (scalar1_id, scalar0_id), + None => (one_scalar_id, zero_scalar_id), }; Cast::Ternary(spirv::Op::Select, accept_id, reject_id) @@ -1460,8 +1448,8 @@ impl<'w> BlockContext<'w> { BoundsCheckResult::KnownInBounds(known_index) => { // Even if the index is known, `OpAccessIndex` // requires expression operands, not literals. - let scalar = crate::ScalarValue::Uint(known_index as u64); - self.writer.get_constant_scalar(scalar, 4) + let scalar = crate::Literal::U32(known_index); + self.writer.get_constant_scalar(scalar) } BoundsCheckResult::Computed(computed_index_id) => computed_index_id, BoundsCheckResult::Conditional(comparison_id) => { diff --git a/src/back/spv/image.rs b/src/back/spv/image.rs index 27f3520502..81c9de3755 100644 --- a/src/back/spv/image.rs +++ b/src/back/spv/image.rs @@ -901,9 +901,7 @@ impl<'w> BlockContext<'w> { depth_id, ); - let zero_id = self - .writer - .get_constant_scalar(crate::ScalarValue::Float(0.0), 4); + let zero_id = self.writer.get_constant_scalar(crate::Literal::F32(0.0)); mask |= spirv::ImageOperands::LOD; inst.add_operand(mask.bits()); diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index 96d0278285..31ed6e231d 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -343,6 +343,14 @@ impl super::Instruction { instruction } + pub(super) fn constant_32bit(result_type_id: Word, id: Word, value: Word) -> Self { + Self::constant(result_type_id, id, &[value]) + } + + pub(super) fn constant_64bit(result_type_id: Word, id: Word, low: Word, high: Word) -> Self { + Self::constant(result_type_id, id, &[low, high]) + } + pub(super) fn constant(result_type_id: Word, id: Word, values: &[Word]) -> Self { let mut instruction = Self::new(Op::Constant); instruction.set_type(result_type_id); diff --git a/src/back/spv/mod.rs b/src/back/spv/mod.rs index adc3dd7a7a..613912ef8d 100644 --- a/src/back/spv/mod.rs +++ b/src/back/spv/mod.rs @@ -295,12 +295,12 @@ enum LocalType { /// [`BindingArray`]: crate::TypeInner::BindingArray PointerToBindingArray { base: Handle, - size: u64, + size: u32, space: crate::AddressSpace, }, BindingArray { base: Handle, - size: u64, + size: u32, }, AccelerationStructure, RayQuery, @@ -454,10 +454,7 @@ impl recyclable::Recyclable for CachedExpressions { #[derive(Eq, Hash, PartialEq)] enum CachedConstant { - Scalar { - value: crate::ScalarValue, - width: crate::Bytes, - }, + Literal(crate::Literal), Composite { ty: LookupType, constituent_ids: Vec, @@ -568,13 +565,12 @@ impl BlockContext<'_> { } fn get_index_constant(&mut self, index: Word) -> Word { - self.writer - .get_constant_scalar(crate::ScalarValue::Uint(index as _), 4) + self.writer.get_constant_scalar(crate::Literal::U32(index)) } fn get_scope_constant(&mut self, scope: Word) -> Word { self.writer - .get_constant_scalar(crate::ScalarValue::Sint(scope as _), 4) + .get_constant_scalar(crate::Literal::I32(scope as _)) } } diff --git a/src/back/spv/ray.rs b/src/back/spv/ray.rs index 79eb2ff971..0b53b9cc52 100644 --- a/src/back/spv/ray.rs +++ b/src/back/spv/ray.rs @@ -117,12 +117,9 @@ impl<'w> BlockContext<'w> { ) -> spirv::Word { let width = 4; let query_id = self.cached[query]; - let intersection_id = self.writer.get_constant_scalar( - crate::ScalarValue::Uint( - spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _, - ), - width, - ); + let intersection_id = self.writer.get_constant_scalar(crate::Literal::U32( + spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _, + )); let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value { vector_size: None, diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 90c6b5089d..3e3dfc60ad 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -540,8 +540,7 @@ impl Writer { ); iface.varying_ids.push(varying_id); - let default_value_id = - self.get_constant_scalar(crate::ScalarValue::Float(1.0), 4); + let default_value_id = self.get_constant_scalar(crate::Literal::F32(1.0)); prelude .body .push(Instruction::store(varying_id, default_value_id, None)); @@ -937,7 +936,7 @@ impl Writer { } LocalType::BindingArray { base, size } => { let inner_ty = self.get_type_id(LookupType::Handle(base)); - let scalar_id = self.get_constant_scalar(crate::ScalarValue::Uint(size), 4); + let scalar_id = self.get_constant_scalar(crate::Literal::U32(size)); Instruction::type_array(id, inner_ty, scalar_id) } LocalType::PointerToBindingArray { base, size, space } => { @@ -1108,20 +1107,29 @@ impl Writer { } pub(super) fn get_index_constant(&mut self, index: Word) -> Word { - self.get_constant_scalar(crate::ScalarValue::Uint(index as _), 4) + self.get_constant_scalar(crate::Literal::U32(index)) } - pub(super) fn get_constant_scalar( + pub(super) fn get_constant_scalar_with( &mut self, - value: crate::ScalarValue, + value: u8, + kind: crate::ScalarKind, width: crate::Bytes, - ) -> Word { - let scalar = CachedConstant::Scalar { value, width }; + ) -> Result { + Ok( + self.get_constant_scalar(crate::Literal::new(value, kind, width).ok_or( + Error::Validation("Unexpected kind and/or width for Literal"), + )?), + ) + } + + pub(super) fn get_constant_scalar(&mut self, value: crate::Literal) -> Word { + let scalar = CachedConstant::Literal(value); if let Some(&id) = self.cached_constants.get(&scalar) { return id; } let id = self.id_gen.next(); - self.write_constant_scalar(id, &value, width, None); + self.write_constant_scalar(id, &value, None); self.cached_constants.insert(scalar, id); id } @@ -1129,8 +1137,7 @@ impl Writer { fn write_constant_scalar( &mut self, id: Word, - value: &crate::ScalarValue, - width: crate::Bytes, + value: &crate::Literal, debug_name: Option<&String>, ) { if self.flags.contains(WriterFlags::DEBUG) { @@ -1141,56 +1148,19 @@ impl Writer { let type_id = self.get_type_id(LookupType::Local(LocalType::Value { vector_size: None, kind: value.scalar_kind(), - width, + width: value.width(), pointer_space: None, })); - let (solo, pair); let instruction = match *value { - crate::ScalarValue::Sint(val) => { - let words = match width { - 4 => { - solo = [val as u32]; - &solo[..] - } - 8 => { - pair = [val as u32, (val >> 32) as u32]; - &pair - } - _ => unreachable!(), - }; - Instruction::constant(type_id, id, words) + crate::Literal::F64(value) => { + let bits = value.to_bits(); + Instruction::constant_64bit(type_id, id, bits as u32, (bits >> 32) as u32) } - crate::ScalarValue::Uint(val) => { - let words = match width { - 4 => { - solo = [val as u32]; - &solo[..] - } - 8 => { - pair = [val as u32, (val >> 32) as u32]; - &pair - } - _ => unreachable!(), - }; - Instruction::constant(type_id, id, words) - } - crate::ScalarValue::Float(val) => { - let words = match width { - 4 => { - solo = [(val as f32).to_bits()]; - &solo[..] - } - 8 => { - let bits = f64::to_bits(val); - pair = [bits as u32, (bits >> 32) as u32]; - &pair - } - _ => unreachable!(), - }; - Instruction::constant(type_id, id, words) - } - crate::ScalarValue::Bool(true) => Instruction::constant_true(type_id, id), - crate::ScalarValue::Bool(false) => Instruction::constant_false(type_id, id), + crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()), + crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value), + crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32), + crate::Literal::Bool(true) => Instruction::constant_true(type_id, id), + crate::Literal::Bool(false) => Instruction::constant_false(type_id, id), }; instruction.to_words(&mut self.logical_layout.declarations); @@ -1598,7 +1568,7 @@ impl Writer { substitute_inner_type_lookup = Some(LookupType::Local(LocalType::PointerToBindingArray { base, - size: remapped_binding_array_size as u64, + size: remapped_binding_array_size, space: global_variable.space, })) } @@ -1812,13 +1782,16 @@ impl Writer { match constant.inner { crate::ConstantInner::Composite { .. } => continue, crate::ConstantInner::Scalar { width, ref value } => { + let literal = crate::Literal::from_scalar(*value, width).ok_or( + Error::Validation("Unexpected kind and/or width for Literal"), + )?; self.constant_ids[handle.index()] = match constant.name { Some(ref name) => { let id = self.id_gen.next(); - self.write_constant_scalar(id, value, width, Some(name)); + self.write_constant_scalar(id, &literal, Some(name)); id } - None => self.get_constant_scalar(*value, width), + None => self.get_constant_scalar(literal), }; } } diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 97a271dbbb..93909ab7a6 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -1103,6 +1103,19 @@ impl Writer { // `postfix_expression` forms for member/component access and // subscripting. match *expression { + Expression::Literal(literal) => { + match literal { + // Floats are written using `Debug` instead of `Display` because it always appends the + // decimal part even it's zero + crate::Literal::F64(_) => { + return Err(Error::Custom("unsupported f64 literal".to_string())); + } + crate::Literal::F32(value) => write!(self.out, "{:?}", value)?, + crate::Literal::U32(value) => write!(self.out, "{}u", value)?, + crate::Literal::I32(value) => write!(self.out, "{}", value)?, + crate::Literal::Bool(value) => write!(self.out, "{}", value)?, + } + } Expression::Constant(constant) => self.write_constant(module, constant)?, Expression::ZeroValue(ty) => { self.write_type(module, ty)?; diff --git a/src/front/glsl/constants.rs b/src/front/glsl/constants.rs index 5f7fc4f892..9f763f9f04 100644 --- a/src/front/glsl/constants.rs +++ b/src/front/glsl/constants.rs @@ -72,6 +72,7 @@ impl<'a> ConstantSolver<'a> { match self.expressions[expr] { Expression::Constant(constant) => Ok(constant), Expression::ZeroValue(ty) => self.register_zero_constant(ty, span), + Expression::Literal(literal) => Ok(self.register_literal(literal, span)), Expression::AccessIndex { base, index } => self.access(base, index as usize), Expression::Access { base, index } => { let index = self.solve(index)?; @@ -648,6 +649,18 @@ impl<'a> ConstantSolver<'a> { Ok(self.register_constant(inner, span)) } + fn register_literal(&mut self, literal: crate::Literal, span: crate::Span) -> Handle { + let (width, value) = match literal { + crate::Literal::F64(n) => (8, ScalarValue::Float(n)), + crate::Literal::F32(n) => (4, ScalarValue::Float(n as f64)), + crate::Literal::U32(n) => (4, ScalarValue::Uint(n as u64)), + crate::Literal::I32(n) => (4, ScalarValue::Sint(n as i64)), + crate::Literal::Bool(b) => (1, ScalarValue::Bool(b)), + }; + + self.register_constant(ConstantInner::Scalar { width, value }, span) + } + fn register_constant(&mut self, inner: ConstantInner, span: crate::Span) -> Handle { self.constants.fetch_or_append( Constant { diff --git a/src/lib.rs b/src/lib.rs index 6a5bc8e0a2..1cf4c8b1a5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -75,8 +75,8 @@ of `Statement`s and other `Expression`s. Naga's rules for when `Expression`s are evaluated are as follows: -- [`Constant`] and [`ZeroValue`] expressions are considered to be - implicitly evaluated before execution begins. +- [`Literal`], [`Constant`], and [`ZeroValue`] expressions are + considered to be implicitly evaluated before execution begins. - [`FunctionArgument`] and [`LocalVariable`] expressions are considered implicitly evaluated upon entry to the function to which they belong. @@ -175,6 +175,7 @@ tree. [`CallResult`]: Expression::CallResult [`Constant`]: Expression::Constant [`ZeroValue`]: Expression::ZeroValue +[`Literal`]: Expression::Literal [`Derivative`]: Expression::Derivative [`FunctionArgument`]: Expression::FunctionArgument [`GlobalVariable`]: Expression::GlobalVariable @@ -792,6 +793,18 @@ pub struct Constant { pub inner: ConstantInner, } +#[derive(Debug, Clone, Copy, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum Literal { + F64(f64), + F32(f32), + U32(u32), + I32(i32), + Bool(bool), +} + /// A literal scalar value, used in constants. #[derive(Debug, Clone, Copy, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] @@ -1190,6 +1203,8 @@ bitflags::bitflags! { #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum Expression { + /// Literal. + Literal(Literal), /// Constant value. Constant(Handle), /// Zero value of a type. diff --git a/src/proc/mod.rs b/src/proc/mod.rs index c4c6d4f4e7..ed501693d1 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -82,6 +82,101 @@ impl super::ScalarKind { } } +impl PartialEq for crate::Literal { + fn eq(&self, other: &Self) -> bool { + match (*self, *other) { + (Self::F64(a), Self::F64(b)) => a.to_bits() == b.to_bits(), + (Self::F32(a), Self::F32(b)) => a.to_bits() == b.to_bits(), + (Self::U32(a), Self::U32(b)) => a == b, + (Self::I32(a), Self::I32(b)) => a == b, + (Self::Bool(a), Self::Bool(b)) => a == b, + _ => false, + } + } +} +impl Eq for crate::Literal {} +impl std::hash::Hash for crate::Literal { + fn hash(&self, hasher: &mut H) { + match *self { + Self::F64(v) => { + hasher.write_u8(0); + v.to_bits().hash(hasher); + } + Self::F32(v) => { + hasher.write_u8(1); + v.to_bits().hash(hasher); + } + Self::U32(v) => { + hasher.write_u8(2); + v.hash(hasher); + } + Self::I32(v) => { + hasher.write_u8(3); + v.hash(hasher); + } + Self::Bool(v) => { + hasher.write_u8(4); + v.hash(hasher); + } + } + } +} + +impl crate::Literal { + pub const fn new(value: u8, kind: crate::ScalarKind, width: crate::Bytes) -> Option { + match (value, kind, width) { + (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)), + (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)), + (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)), + (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)), + (1, crate::ScalarKind::Bool, 4) => Some(Self::Bool(true)), + (0, crate::ScalarKind::Bool, 4) => Some(Self::Bool(false)), + _ => None, + } + } + + pub const fn from_scalar(scalar: crate::ScalarValue, width: crate::Bytes) -> Option { + match (scalar, width) { + (crate::ScalarValue::Sint(n), 4) => Some(Self::I32(n as _)), + (crate::ScalarValue::Uint(n), 4) => Some(Self::U32(n as _)), + (crate::ScalarValue::Float(n), 4) => Some(Self::F32(n as _)), + (crate::ScalarValue::Float(n), 8) => Some(Self::F64(n)), + (crate::ScalarValue::Bool(b), _) => Some(Self::Bool(b)), + _ => None, + } + } + + pub const fn zero(kind: crate::ScalarKind, width: crate::Bytes) -> Option { + Self::new(0, kind, width) + } + + pub const fn one(kind: crate::ScalarKind, width: crate::Bytes) -> Option { + Self::new(1, kind, width) + } + + pub const fn width(&self) -> crate::Bytes { + match *self { + Self::F64(_) => 8, + Self::F32(_) | Self::U32(_) | Self::I32(_) => 4, + Self::Bool(_) => 1, + } + } + pub const fn scalar_kind(&self) -> crate::ScalarKind { + match *self { + Self::F64(_) | Self::F32(_) => crate::ScalarKind::Float, + Self::U32(_) => crate::ScalarKind::Uint, + Self::I32(_) => crate::ScalarKind::Sint, + Self::Bool(_) => crate::ScalarKind::Bool, + } + } + pub const fn ty_inner(&self) -> crate::TypeInner { + crate::TypeInner::Scalar { + kind: self.scalar_kind(), + width: self.width(), + } + } +} + pub const POINTER_SPAN: u32 = 4; impl super::TypeInner { @@ -311,7 +406,8 @@ impl crate::Expression { /// Returns true if the expression is considered emitted at the start of a function. pub const fn needs_pre_emit(&self) -> bool { match *self { - Self::Constant(_) + Self::Literal(_) + | Self::Constant(_) | Self::ZeroValue(_) | Self::FunctionArgument(_) | Self::GlobalVariable(_) diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 414749ee00..2345cccf8e 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -421,6 +421,7 @@ impl<'a> ResolveContext<'a> { } } } + crate::Expression::Literal(lit) => TypeResolution::Value(lit.ty_inner()), crate::Expression::Constant(h) => match self.constants[h].inner { crate::ConstantInner::Scalar { width, ref value } => { TypeResolution::Value(Ti::Scalar { diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 4860435cec..b93f444a9e 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -492,8 +492,7 @@ impl FunctionInfo { requirements: UniformityRequirements::empty(), }, // always uniform - E::Constant(_) => Uniformity::new(), - E::ZeroValue(_) => Uniformity::new(), + E::Literal(_) | E::Constant(_) | E::ZeroValue(_) => Uniformity::new(), E::Splat { size: _, value } => Uniformity { non_uniform_result: self.add_ref(value), requirements: UniformityRequirements::empty(), diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 7f9e277cfe..ece98d70fd 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -269,6 +269,7 @@ impl super::Validator { } ShaderStages::all() } + E::Literal(_value) => ShaderStages::all(), E::Constant(_handle) => ShaderStages::all(), E::ZeroValue(_type) => ShaderStages::all(), E::Splat { size: _, value } => match resolver[value] { diff --git a/src/valid/handles.rs b/src/valid/handles.rs index 0115cd85d2..24e485277e 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -253,6 +253,7 @@ impl super::Validator { crate::Expression::AccessIndex { base, .. } => { handle.check_dep(base)?; } + crate::Expression::Literal(_value) => {} crate::Expression::Constant(constant) => { validate_constant(constant)?; }