diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index e9fc133485..d9f7749114 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -322,6 +322,7 @@ fn write_fun( arg, arg1, arg2, + arg3, } => { edges.insert("arg", arg); if let Some(expr) = arg1 { @@ -330,6 +331,9 @@ fn write_fun( if let Some(expr) = arg2 { edges.insert("arg2", expr); } + if let Some(expr) = arg3 { + edges.insert("arg3", expr); + } (format!("{:?}", fun).into(), 7) } E::As { diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index f2406fa9cd..b79aabb870 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2321,6 +2321,7 @@ impl<'a, W: Write> Writer<'a, W> { arg, arg1, arg2, + arg3, } => { use crate::MathFunction as Mf; @@ -2385,17 +2386,56 @@ impl<'a, W: Write> Writer<'a, W> { // bits Mf::CountOneBits => "bitCount", Mf::ReverseBits => "bitfieldReverse", + Mf::ExtractBits => "bitfieldExtract", + Mf::InsertBits => "bitfieldInsert", + // data packing + Mf::Pack4x8snorm => "packSnorm4x8", + Mf::Pack4x8unorm => "packUnorm4x8", + Mf::Pack2x16snorm => "packSnorm2x16", + Mf::Pack2x16unorm => "packUnorm2x16", + Mf::Pack2x16float => "packHalf2x16", + // data unpacking + Mf::Unpack4x8snorm => "unpackSnorm4x8", + Mf::Unpack4x8unorm => "unpackUnorm4x8", + Mf::Unpack2x16snorm => "unpackSnorm2x16", + Mf::Unpack2x16unorm => "unpackUnorm2x16", + Mf::Unpack2x16float => "unpackHalf2x16", }; + let extract_bits = fun == Mf::ExtractBits; + let insert_bits = fun == Mf::InsertBits; + write!(self.out, "{}(", fun_name)?; self.write_expr(arg, ctx)?; if let Some(arg) = arg1 { write!(self.out, ", ")?; - self.write_expr(arg, ctx)?; + if extract_bits { + write!(self.out, "int(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")")?; + } else { + self.write_expr(arg, ctx)?; + } } if let Some(arg) = arg2 { write!(self.out, ", ")?; - self.write_expr(arg, ctx)?; + if extract_bits || insert_bits { + write!(self.out, "int(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")")?; + } else { + self.write_expr(arg, ctx)?; + } + } + if let Some(arg) = arg3 { + write!(self.out, ", ")?; + if insert_bits { + write!(self.out, "int(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")")?; + } else { + self.write_expr(arg, ctx)?; + } } write!(self.out, ")")? } diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 883b67d135..604fd94d81 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1816,6 +1816,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { arg, arg1, arg2, + arg3, } => { use crate::MathFunction as Mf; @@ -1918,6 +1919,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, ", ")?; self.write_expr(module, arg, func_ctx)?; } + if let Some(arg) = arg3 { + write!(self.out, ", ")?; + self.write_expr(module, arg, func_ctx)?; + } write!(self.out, ")")? } } diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index efe58d3dd5..bcd4dd6c48 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1110,6 +1110,7 @@ impl Writer { arg, arg1, arg2, + arg3, } => { use crate::MathFunction as Mf; @@ -1178,6 +1179,20 @@ impl Writer { // bits Mf::CountOneBits => "popcount", Mf::ReverseBits => "reverse_bits", + Mf::ExtractBits => "extract_bits", + Mf::InsertBits => "insert_bits", + // data packing + Mf::Pack4x8snorm => "pack_float_to_unorm4x8", + Mf::Pack4x8unorm => "pack_float_to_snorm4x8", + Mf::Pack2x16snorm => "pack_float_to_unorm2x16", + Mf::Pack2x16unorm => "pack_float_to_snorm2x16", + Mf::Pack2x16float => "", + // data unpacking + Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float", + Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float", + Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float", + Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float", + Mf::Unpack2x16float => "", }; if fun == Mf::Distance && scalar_argument { @@ -1186,9 +1201,20 @@ impl Writer { write!(self.out, " - ")?; self.put_expression(arg1.unwrap(), context, false)?; write!(self.out, ")")?; + } else if fun == Mf::Unpack2x16float { + write!(self.out, "float2(as_type(")?; + self.put_expression(arg, context, false)?; + write!(self.out, "))")?; + } else if fun == Mf::Pack2x16float { + write!(self.out, "as_type(half2(")?; + self.put_expression(arg, context, false)?; + write!(self.out, "))")?; } else { write!(self.out, "{}::{}", NAMESPACE, fun_name)?; - self.put_call_parameters(iter::once(arg).chain(arg1).chain(arg2), context)?; + self.put_call_parameters( + iter::once(arg).chain(arg1).chain(arg2).chain(arg3), + context, + )?; } } crate::Expression::As { @@ -2661,8 +2687,8 @@ fn test_stack_size() { } let stack_size = addresses.end - addresses.start; // check the size (in debug only) - // last observed macOS value: 18304 - if !(15000..=20000).contains(&stack_size) { + // last observed macOS value: 20528 (CI) + if !(15000..=25000).contains(&stack_size) { panic!("`put_expression` stack size {} has changed!", stack_size); } } diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 491ab79ce1..88c16e7cfc 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -439,6 +439,7 @@ impl<'w> BlockContext<'w> { arg, arg1, arg2, + arg3, } => { use crate::MathFunction as Mf; enum MathOp { @@ -457,6 +458,10 @@ impl<'w> BlockContext<'w> { Some(handle) => self.cached[handle], None => 0, }; + let arg3_id = match arg3 { + Some(handle) => self.cached[handle], + None => 0, + }; let id = self.gen_id(); let math_op = match fun { @@ -606,6 +611,40 @@ impl<'w> BlockContext<'w> { log::error!("unimplemented math function {:?}", fun); return Err(Error::FeatureNotImplemented("math function")); } + Mf::ExtractBits => { + let op = match arg_scalar_kind { + Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract, + Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract, + other => unimplemented!("Unexpected sign({:?})", other), + }; + MathOp::Custom(Instruction::ternary( + op, + result_type_id, + id, + arg0_id, + arg1_id, + arg2_id, + )) + } + Mf::InsertBits => MathOp::Custom(Instruction::quaternary( + spirv::Op::BitFieldInsert, + result_type_id, + id, + arg0_id, + arg1_id, + arg2_id, + arg3_id, + )), + Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8), + Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8), + Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16), + Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16), + Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16), + Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8), + Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8), + Mf::Unpack2x16float => MathOp::Ext(spirv::GLOp::UnpackHalf2x16), + Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16), + Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16), }; block.body.push(match math_op { @@ -614,7 +653,7 @@ impl<'w> BlockContext<'w> { op, result_type_id, id, - &[arg0_id, arg1_id, arg2_id][..fun.argument_count()], + &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()], ), MathOp::Custom(inst) => inst, }); diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index ac738dfa6e..211c51a1ed 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -703,6 +703,42 @@ impl super::Instruction { instruction } + pub(super) fn ternary( + op: Op, + result_type_id: Word, + id: Word, + operand_1: Word, + operand_2: Word, + operand_3: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(operand_1); + instruction.add_operand(operand_2); + instruction.add_operand(operand_3); + instruction + } + + pub(super) fn quaternary( + op: Op, + result_type_id: Word, + id: Word, + operand_1: Word, + operand_2: Word, + operand_3: Word, + operand_4: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(operand_1); + instruction.add_operand(operand_2); + instruction.add_operand(operand_3); + instruction.add_operand(operand_4); + instruction + } + pub(super) fn relational(op: Op, result_type_id: Word, id: Word, expr_id: Word) -> Self { let mut instruction = Self::new(op); instruction.set_type(result_type_id); diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index a215722fb3..82b46838b8 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -1456,6 +1456,7 @@ impl Writer { arg, arg1, arg2, + arg3, } => { use crate::MathFunction as Mf; @@ -1523,6 +1524,20 @@ impl Writer { // bits Mf::CountOneBits => Function::Regular("countOneBits"), Mf::ReverseBits => Function::Regular("reverseBits"), + Mf::ExtractBits => Function::Regular("extractBits"), + Mf::InsertBits => Function::Regular("insertBits"), + // data packing + Mf::Pack4x8snorm => Function::Regular("pack4x8snorm"), + Mf::Pack4x8unorm => Function::Regular("pack4x8unorm"), + Mf::Pack2x16snorm => Function::Regular("pack2x16snorm"), + Mf::Pack2x16unorm => Function::Regular("pack2x16unorm"), + Mf::Pack2x16float => Function::Regular("pack2x16float"), + // data unpacking + Mf::Unpack4x8snorm => Function::Regular("unpack4x8snorm"), + Mf::Unpack4x8unorm => Function::Regular("unpack4x8unorm"), + Mf::Unpack2x16snorm => Function::Regular("unpack2x16snorm"), + Mf::Unpack2x16unorm => Function::Regular("unpack2x16unorm"), + Mf::Unpack2x16float => Function::Regular("unpack2x16float"), _ => { return Err(Error::UnsupportedMathFunction(fun)); } @@ -1559,6 +1574,10 @@ impl Writer { write!(self.out, ", ")?; self.write_expr(module, arg, func_ctx)?; } + if let Some(arg) = arg3 { + write!(self.out, ", ")?; + self.write_expr(module, arg, func_ctx)?; + } write!(self.out, ")")? } } diff --git a/src/front/glsl/builtins.rs b/src/front/glsl/builtins.rs index 84d7ee8e29..ed220bc98b 100644 --- a/src/front/glsl/builtins.rs +++ b/src/front/glsl/builtins.rs @@ -630,7 +630,21 @@ pub fn inject_builtin(declaration: &mut FunctionDeclaration, module: &mut Module )) } } - "bitCount" | "bitfieldReverse" => { + "bitCount" | "bitfieldReverse" | "bitfieldExtract" | "bitfieldInsert" => { + let fun = match name { + "bitCount" => MathFunction::CountOneBits, + "bitfieldReverse" => MathFunction::ReverseBits, + "bitfieldExtract" => MathFunction::ExtractBits, + "bitfieldInsert" => MathFunction::InsertBits, + _ => unreachable!(), + }; + + let mc = match fun { + MathFunction::ExtractBits => MacroCall::BitfieldExtract, + MathFunction::InsertBits => MacroCall::BitfieldInsert, + _ => MacroCall::MathFunction(fun), + }; + // bits layout // bit 0 - int/uint // bit 1 trough 2 - dims @@ -646,21 +660,93 @@ pub fn inject_builtin(declaration: &mut FunctionDeclaration, module: &mut Module _ => Some(VectorSize::Quad), }; - let args = vec![match size { + let ty = || match size { Some(size) => TypeInner::Vector { size, kind, width }, None => TypeInner::Scalar { kind, width }, - }]; + }; - declaration.overloads.push(module.add_builtin( - args, - MacroCall::MathFunction(match name { - "bitCount" => MathFunction::CountOneBits, - "bitfieldReverse" => MathFunction::ReverseBits, - _ => unreachable!(), - }), - )) + let mut args = vec![ty()]; + + match fun { + MathFunction::ExtractBits => { + args.push(TypeInner::Scalar { + kind: Sk::Sint, + width: 4, + }); + args.push(TypeInner::Scalar { + kind: Sk::Sint, + width: 4, + }); + } + MathFunction::InsertBits => { + args.push(ty()); + args.push(TypeInner::Scalar { + kind: Sk::Sint, + width: 4, + }); + args.push(TypeInner::Scalar { + kind: Sk::Sint, + width: 4, + }); + } + _ => {} + } + + declaration.overloads.push(module.add_builtin(args, mc)) } } + "packSnorm4x8" | "packUnorm4x8" | "packSnorm2x16" | "packUnorm2x16" | "packHalf2x16" => { + let fun = match name { + "packSnorm4x8" => MathFunction::Pack4x8snorm, + "packUnorm4x8" => MathFunction::Pack4x8unorm, + "packSnorm2x16" => MathFunction::Pack2x16unorm, + "packUnorm2x16" => MathFunction::Pack2x16snorm, + "packHalf2x16" => MathFunction::Pack2x16float, + _ => unreachable!(), + }; + + let ty = match fun { + MathFunction::Pack4x8snorm | MathFunction::Pack4x8unorm => TypeInner::Vector { + size: crate::VectorSize::Quad, + kind: Sk::Float, + width: 4, + }, + MathFunction::Pack2x16unorm + | MathFunction::Pack2x16snorm + | MathFunction::Pack2x16float => TypeInner::Vector { + size: crate::VectorSize::Bi, + kind: Sk::Float, + width: 4, + }, + _ => unreachable!(), + }; + + let args = vec![ty]; + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::MathFunction(fun))); + } + "unpackSnorm4x8" | "unpackUnorm4x8" | "unpackSnorm2x16" | "unpackUnorm2x16" + | "unpackHalf2x16" => { + let fun = match name { + "unpackSnorm4x8" => MathFunction::Unpack4x8snorm, + "unpackUnorm4x8" => MathFunction::Unpack4x8unorm, + "unpackSnorm2x16" => MathFunction::Unpack2x16snorm, + "unpackUnorm2x16" => MathFunction::Unpack2x16unorm, + "unpackHalf2x16" => MathFunction::Unpack2x16float, + _ => unreachable!(), + }; + + let args = vec![TypeInner::Scalar { + kind: Sk::Uint, + width: 4, + }]; + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::MathFunction(fun))); + } "atan" => { // bits layout // bit 0 - atan/atan2 @@ -1454,6 +1540,8 @@ pub enum MacroCall { TextureSize, TexelFetch, MathFunction(MathFunction), + BitfieldExtract, + BitfieldInsert, Relational(RelationalFunction), Binary(BinaryOperator), Mod(Option), @@ -1640,10 +1728,73 @@ impl MacroCall { arg: args[0], arg1: args.get(1).copied(), arg2: args.get(2).copied(), + arg3: args.get(3).copied(), }, Span::default(), body, )), + MacroCall::BitfieldInsert => { + let conv_arg_2 = ctx.add_expression( + Expression::As { + expr: args[2], + kind: Sk::Uint, + convert: Some(4), + }, + Span::default(), + body, + ); + let conv_arg_3 = ctx.add_expression( + Expression::As { + expr: args[3], + kind: Sk::Uint, + convert: Some(4), + }, + Span::default(), + body, + ); + Ok(ctx.add_expression( + Expression::Math { + fun: MathFunction::InsertBits, + arg: args[0], + arg1: Some(args[1]), + arg2: Some(conv_arg_2), + arg3: Some(conv_arg_3), + }, + Span::default(), + body, + )) + } + MacroCall::BitfieldExtract => { + let conv_arg_1 = ctx.add_expression( + Expression::As { + expr: args[1], + kind: Sk::Uint, + convert: Some(4), + }, + Span::default(), + body, + ); + let conv_arg_2 = ctx.add_expression( + Expression::As { + expr: args[2], + kind: Sk::Uint, + convert: Some(4), + }, + Span::default(), + body, + ); + Ok(ctx.add_expression( + Expression::Math { + fun: MathFunction::ExtractBits, + arg: args[0], + arg1: Some(conv_arg_1), + arg2: Some(conv_arg_2), + arg3: None, + }, + Span::default(), + body, + )) + } MacroCall::Relational(fun) => Ok(ctx.add_expression( Expression::Relational { fun, @@ -1683,6 +1834,7 @@ impl MacroCall { arg: args[0], arg1: args.get(1).copied(), arg2: args.get(2).copied(), + arg3: args.get(3).copied(), }, Span::default(), body, @@ -1707,6 +1859,7 @@ impl MacroCall { arg: args[0], arg1: args.get(1).copied(), arg2: args.get(2).copied(), + arg3: args.get(3).copied(), }, Span::default(), body, diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index 5523824806..4cc0481509 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -1330,6 +1330,7 @@ impl> Parser { arg: loaded, arg1: None, arg2: None, + arg3: None, }, span, ); @@ -1407,6 +1408,7 @@ impl> Parser { arg: loaded, arg1: None, arg2: None, + arg3: None, }, span, ) @@ -1811,6 +1813,7 @@ impl> Parser { arg: matrix_handle, arg1: None, arg2: None, + arg3: None, }; self.lookup_expression.insert( result_id, @@ -1837,6 +1840,70 @@ impl> Parser { arg: left_handle, arg1: Some(right_handle), arg2: None, + arg3: None, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::BitFieldInsert => { + inst.expect(7)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let base_id = self.next()?; + let insert_id = self.next()?; + let offset_id = self.next()?; + let count_id = self.next()?; + let base_lexp = self.lookup_expression.lookup(base_id)?; + let base_handle = get_expr_handle!(base_id, base_lexp); + let insert_lexp = self.lookup_expression.lookup(insert_id)?; + let insert_handle = get_expr_handle!(insert_id, insert_lexp); + let offset_lexp = self.lookup_expression.lookup(offset_id)?; + let offset_handle = get_expr_handle!(offset_id, offset_lexp); + let count_lexp = self.lookup_expression.lookup(count_id)?; + let count_handle = get_expr_handle!(count_id, count_lexp); + let expr = crate::Expression::Math { + fun: crate::MathFunction::InsertBits, + arg: base_handle, + arg1: Some(insert_handle), + arg2: Some(offset_handle), + arg3: Some(count_handle), + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::BitFieldSExtract | Op::BitFieldUExtract => { + inst.expect(6)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let base_id = self.next()?; + let offset_id = self.next()?; + let count_id = self.next()?; + let base_lexp = self.lookup_expression.lookup(base_id)?; + let base_handle = get_expr_handle!(base_id, base_lexp); + let offset_lexp = self.lookup_expression.lookup(offset_id)?; + let offset_handle = get_expr_handle!(offset_id, offset_lexp); + let count_lexp = self.lookup_expression.lookup(count_id)?; + let count_handle = get_expr_handle!(count_id, count_lexp); + let expr = crate::Expression::Math { + fun: crate::MathFunction::ExtractBits, + arg: base_handle, + arg1: Some(offset_handle), + arg2: Some(count_handle), + arg3: None, }; self.lookup_expression.insert( result_id, @@ -1863,6 +1930,7 @@ impl> Parser { arg: left_handle, arg1: Some(right_handle), arg2: None, + arg3: None, }; self.lookup_expression.insert( result_id, @@ -2352,6 +2420,16 @@ impl> Parser { Glo::FaceForward => Mf::FaceForward, Glo::Reflect => Mf::Reflect, Glo::Refract => Mf::Refract, + Glo::PackUnorm4x8 => Mf::Pack4x8unorm, + Glo::PackSnorm4x8 => Mf::Pack4x8snorm, + Glo::PackHalf2x16 => Mf::Pack2x16float, + Glo::PackUnorm2x16 => Mf::Pack2x16unorm, + Glo::PackSnorm2x16 => Mf::Pack2x16snorm, + Glo::UnpackUnorm4x8 => Mf::Unpack4x8unorm, + Glo::UnpackSnorm4x8 => Mf::Unpack4x8snorm, + Glo::UnpackHalf2x16 => Mf::Unpack2x16float, + Glo::UnpackUnorm2x16 => Mf::Unpack2x16unorm, + Glo::UnpackSnorm2x16 => Mf::Unpack2x16snorm, _ => return Err(Error::UnsupportedExtInst(inst_id)), }; @@ -2376,12 +2454,20 @@ impl> Parser { } else { None }; + let arg3 = if arg_count > 3 { + let arg_id = self.next()?; + let lexp = self.lookup_expression.lookup(arg_id)?; + Some(get_expr_handle!(arg_id, lexp)) + } else { + None + }; let expr = crate::Expression::Math { fun, arg, arg1, arg2, + arg3, }; self.lookup_expression.insert( result_id, diff --git a/src/front/wgsl/conv.rs b/src/front/wgsl/conv.rs index 519073fc39..add465f189 100644 --- a/src/front/wgsl/conv.rs +++ b/src/front/wgsl/conv.rs @@ -198,6 +198,20 @@ pub fn map_standard_fun(word: &str) -> Option { // bits "countOneBits" => Mf::CountOneBits, "reverseBits" => Mf::ReverseBits, + "extractBits" => Mf::ExtractBits, + "insertBits" => Mf::InsertBits, + // data packing + "pack4x8snorm" => Mf::Pack4x8snorm, + "pack4x8unorm" => Mf::Pack4x8unorm, + "pack2x16snorm" => Mf::Pack2x16snorm, + "pack2x16unorm" => Mf::Pack2x16unorm, + "pack2x16float" => Mf::Pack2x16float, + // data unpacking + "unpack4x8snorm" => Mf::Unpack4x8snorm, + "unpack4x8unorm" => Mf::Unpack4x8unorm, + "unpack2x16snorm" => Mf::Unpack2x16snorm, + "unpack2x16unorm" => Mf::Unpack2x16unorm, + "unpack2x16float" => Mf::Unpack2x16float, _ => return None, }) } diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 800851a744..289282a43f 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -1400,12 +1400,19 @@ impl Parser { } else { None }; + let arg3 = if arg_count > 3 { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(lexer, ctx.reborrow())?) + } else { + None + }; lexer.close_arguments()?; crate::Expression::Math { fun, arg, arg1, arg2, + arg3, } } else { match name { diff --git a/src/lib.rs b/src/lib.rs index 97b5b9f52c..737e0bfca3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -934,6 +934,20 @@ pub enum MathFunction { // bits CountOneBits, ReverseBits, + ExtractBits, + InsertBits, + // data packing + Pack4x8snorm, + Pack4x8unorm, + Pack2x16snorm, + Pack2x16unorm, + Pack2x16float, + // data unpacking + Unpack4x8snorm, + Unpack4x8unorm, + Unpack2x16snorm, + Unpack2x16unorm, + Unpack2x16float, } /// Sampling modifier to control the level of detail. @@ -1255,6 +1269,7 @@ pub enum Expression { arg: Handle, arg1: Option>, arg2: Option>, + arg3: Option>, }, /// Cast a simple type to another kind. As { diff --git a/src/proc/mod.rs b/src/proc/mod.rs index fd6015e623..c9a66585c5 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -245,6 +245,20 @@ impl super::MathFunction { // bits Self::CountOneBits => 1, Self::ReverseBits => 1, + Self::ExtractBits => 3, + Self::InsertBits => 4, + // data packing + Self::Pack4x8snorm => 1, + Self::Pack4x8unorm => 1, + Self::Pack2x16snorm => 1, + Self::Pack2x16unorm => 1, + Self::Pack2x16float => 1, + // data unpacking + Self::Unpack4x8snorm => 1, + Self::Unpack4x8unorm => 1, + Self::Unpack2x16snorm => 1, + Self::Unpack2x16unorm => 1, + Self::Unpack2x16float => 1, } } } diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 217eeb1e2d..e1d2bd6b8d 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -659,6 +659,7 @@ impl<'a> ResolveContext<'a> { arg, arg1, arg2: _, + arg3: _, } => { use crate::MathFunction as Mf; let res_arg = past(arg); @@ -781,7 +782,21 @@ impl<'a> ResolveContext<'a> { }, // bits Mf::CountOneBits | - Mf::ReverseBits => res_arg.clone(), + Mf::ReverseBits | + Mf::ExtractBits | + Mf::InsertBits => res_arg.clone(), + // data packing + Mf::Pack4x8snorm | + Mf::Pack4x8unorm | + Mf::Pack2x16snorm | + Mf::Pack2x16unorm | + Mf::Pack2x16float => TypeResolution::Value(Ti::Scalar { kind: crate::ScalarKind::Uint, width: 4 }), + // data unpacking + Mf::Unpack4x8snorm | + Mf::Unpack4x8unorm => TypeResolution::Value(Ti::Vector { size: crate::VectorSize::Quad, kind: crate::ScalarKind::Float, width: 4 }), + Mf::Unpack2x16snorm | + Mf::Unpack2x16unorm | + Mf::Unpack2x16float => TypeResolution::Value(Ti::Vector { size: crate::VectorSize::Bi, kind: crate::ScalarKind::Float, width: 4 }), } } crate::Expression::As { diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 8c1f659f9c..a9dab24cd6 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -870,15 +870,17 @@ impl super::Validator { arg, arg1, arg2, + arg3, } => { use crate::MathFunction as Mf; let arg_ty = resolver.resolve(arg)?; let arg1_ty = arg1.map(|expr| resolver.resolve(expr)).transpose()?; let arg2_ty = arg2.map(|expr| resolver.resolve(expr)).transpose()?; + let arg3_ty = arg3.map(|expr| resolver.resolve(expr)).transpose()?; match fun { Mf::Abs => { - if arg1_ty.is_some() | arg2_ty.is_some() { + if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } let good = match *arg_ty { @@ -890,8 +892,8 @@ impl super::Validator { } } Mf::Min | Mf::Max => { - let arg1_ty = match (arg1_ty, arg2_ty) { - (Some(ty1), None) => ty1, + let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), None, None) => ty1, _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; let good = match *arg_ty { @@ -910,8 +912,8 @@ impl super::Validator { } } Mf::Clamp => { - let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty) { - (Some(ty1), Some(ty2)) => (ty1, ty2), + let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), Some(ty2), None) => (ty1, ty2), _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; let good = match *arg_ty { @@ -961,7 +963,7 @@ impl super::Validator { | Mf::Sign | Mf::Sqrt | Mf::InverseSqrt => { - if arg1_ty.is_some() | arg2_ty.is_some() { + if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } match *arg_ty { @@ -975,8 +977,8 @@ impl super::Validator { } } Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => { - let arg1_ty = match (arg1_ty, arg2_ty) { - (Some(ty1), None) => ty1, + let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), None, None) => ty1, _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; match *arg_ty { @@ -997,8 +999,8 @@ impl super::Validator { } } Mf::Modf | Mf::Frexp | Mf::Ldexp => { - let arg1_ty = match (arg1_ty, arg2_ty) { - (Some(ty1), None) => ty1, + let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), None, None) => ty1, _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; let (size0, width0) = match *arg_ty { @@ -1032,8 +1034,8 @@ impl super::Validator { } } Mf::Dot | Mf::Outer | Mf::Cross | Mf::Reflect => { - let arg1_ty = match (arg1_ty, arg2_ty) { - (Some(ty1), None) => ty1, + let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), None, None) => ty1, _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; match *arg_ty { @@ -1051,8 +1053,8 @@ impl super::Validator { } } Mf::Refract => { - let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty) { - (Some(ty1), Some(ty2)) => (ty1, ty2), + let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), Some(ty2), None) => (ty1, ty2), _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; @@ -1092,7 +1094,7 @@ impl super::Validator { } } Mf::Normalize => { - if arg1_ty.is_some() | arg2_ty.is_some() { + if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } match *arg_ty { @@ -1103,8 +1105,8 @@ impl super::Validator { } } Mf::FaceForward | Mf::Fma | Mf::SmoothStep => { - let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty) { - (Some(ty1), Some(ty2)) => (ty1, ty2), + let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), Some(ty2), None) => (ty1, ty2), _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; match *arg_ty { @@ -1132,8 +1134,8 @@ impl super::Validator { } } Mf::Mix => { - let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty) { - (Some(ty1), Some(ty2)) => (ty1, ty2), + let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), Some(ty2), None) => (ty1, ty2), _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; let arg_width = match *arg_ty { @@ -1172,7 +1174,7 @@ impl super::Validator { } } Mf::Inverse | Mf::Determinant => { - if arg1_ty.is_some() | arg2_ty.is_some() { + if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } let good = match *arg_ty { @@ -1184,7 +1186,7 @@ impl super::Validator { } } Mf::Transpose => { - if arg1_ty.is_some() | arg2_ty.is_some() { + if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } match *arg_ty { @@ -1193,7 +1195,7 @@ impl super::Validator { } } Mf::CountOneBits | Mf::ReverseBits => { - if arg1_ty.is_some() | arg2_ty.is_some() { + if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } match *arg_ty { @@ -1204,6 +1206,118 @@ impl super::Validator { _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } } + Mf::InsertBits => { + let (arg1_ty, arg2_ty, arg3_ty) = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), Some(ty2), Some(ty3)) => (ty1, ty2, ty3), + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Scalar { kind: Sk::Sint, .. } + | Ti::Scalar { kind: Sk::Uint, .. } + | Ti::Vector { kind: Sk::Sint, .. } + | Ti::Vector { kind: Sk::Uint, .. } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + match *arg2_ty { + Ti::Scalar { kind: Sk::Uint, .. } => {} + _ => { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg2.unwrap(), + )) + } + } + match *arg3_ty { + Ti::Scalar { kind: Sk::Uint, .. } => {} + _ => { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg3.unwrap(), + )) + } + } + } + Mf::ExtractBits => { + let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), Some(ty2), None) => (ty1, ty2), + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Scalar { kind: Sk::Sint, .. } + | Ti::Scalar { kind: Sk::Uint, .. } + | Ti::Vector { kind: Sk::Sint, .. } + | Ti::Vector { kind: Sk::Uint, .. } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + match *arg1_ty { + Ti::Scalar { kind: Sk::Uint, .. } => {} + _ => { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg1.unwrap(), + )) + } + } + match *arg2_ty { + Ti::Scalar { kind: Sk::Uint, .. } => {} + _ => { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg2.unwrap(), + )) + } + } + } + Mf::Pack2x16unorm | Mf::Pack2x16snorm | Mf::Pack2x16float => { + if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Vector { + size: crate::VectorSize::Bi, + kind: Sk::Float, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::Pack4x8snorm | Mf::Pack4x8unorm => { + if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Vector { + size: crate::VectorSize::Quad, + kind: Sk::Float, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::Unpack2x16float + | Mf::Unpack2x16snorm + | Mf::Unpack2x16unorm + | Mf::Unpack4x8snorm + | Mf::Unpack4x8unorm => { + if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Scalar { kind: Sk::Uint, .. } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } } ShaderStages::all() } diff --git a/tests/in/bits.param.ron b/tests/in/bits.param.ron new file mode 100644 index 0000000000..fb45784e55 --- /dev/null +++ b/tests/in/bits.param.ron @@ -0,0 +1,15 @@ +( + msl: ( + lang_version: (1, 2), + per_stage_map: ( + cs: ( + resources: { + }, + sizes_buffer: Some(0), + ) + ), + inline_samplers: [], + spirv_cross_compatibility: false, + fake_missing_bindings: false, + ), +) diff --git a/tests/in/bits.wgsl b/tests/in/bits.wgsl new file mode 100644 index 0000000000..1834411f35 --- /dev/null +++ b/tests/in/bits.wgsl @@ -0,0 +1,39 @@ +[[stage(compute), workgroup_size(1)]] +fn main() { + var i = 0; + var i2 = vec2(0); + var i3 = vec3(0); + var i4 = vec4(0); + var u = 0u; + var u2 = vec2(0u); + var u3 = vec3(0u); + var u4 = vec4(0u); + var f2 = vec2(0.0); + var f4 = vec4(0.0); + u = pack4x8snorm(f4); + u = pack4x8unorm(f4); + u = pack2x16snorm(f2); + u = pack2x16unorm(f2); + u = pack2x16float(f2); + f4 = unpack4x8snorm(u); + f4 = unpack4x8unorm(u); + f2 = unpack2x16snorm(u); + f2 = unpack2x16unorm(u); + f2 = unpack2x16float(u); + i = insertBits(i, i, 5u, 10u); + i2 = insertBits(i2, i2, 5u, 10u); + i3 = insertBits(i3, i3, 5u, 10u); + i4 = insertBits(i4, i4, 5u, 10u); + u = insertBits(u, u, 5u, 10u); + u2 = insertBits(u2, u2, 5u, 10u); + u3 = insertBits(u3, u3, 5u, 10u); + u4 = insertBits(u4, u4, 5u, 10u); + i = extractBits(i, 5u, 10u); + i2 = extractBits(i2, 5u, 10u); + i3 = extractBits(i3, 5u, 10u); + i4 = extractBits(i4, 5u, 10u); + u = extractBits(u, 5u, 10u); + u2 = extractBits(u2, 5u, 10u); + u3 = extractBits(u3, 5u, 10u); + u4 = extractBits(u4, 5u, 10u); +} diff --git a/tests/in/glsl/bits_glsl.frag b/tests/in/glsl/bits_glsl.frag new file mode 100644 index 0000000000..cd960ed861 --- /dev/null +++ b/tests/in/glsl/bits_glsl.frag @@ -0,0 +1,40 @@ +#version 450 + +void main() { + int i = 0; + ivec2 i2 = ivec2(0); + ivec3 i3 = ivec3(0); + ivec4 i4 = ivec4(0); + uint u = 0; + uvec2 u2 = uvec2(0); + uvec3 u3 = uvec3(0); + uvec4 u4 = uvec4(0); + vec2 f2 = vec2(0.0); + vec4 f4 = vec4(0.0); + u = packSnorm4x8(f4); + u = packUnorm4x8(f4); + u = packSnorm2x16(f2); + u = packUnorm2x16(f2); + u = packHalf2x16(f2); + f4 = unpackSnorm4x8(u); + f4 = unpackUnorm4x8(u); + f2 = unpackSnorm2x16(u); + f2 = unpackUnorm2x16(u); + f2 = unpackHalf2x16(u); + i = bitfieldInsert(i, i, 5, 10); + i2 = bitfieldInsert(i2, i2, 5, 10); + i3 = bitfieldInsert(i3, i3, 5, 10); + i4 = bitfieldInsert(i4, i4, 5, 10); + u = bitfieldInsert(u, u, 5, 10); + u2 = bitfieldInsert(u2, u2, 5, 10); + u3 = bitfieldInsert(u3, u3, 5, 10); + u4 = bitfieldInsert(u4, u4, 5, 10); + i = bitfieldExtract(i, 5, 10); + i2 = bitfieldExtract(i2, 5, 10); + i3 = bitfieldExtract(i3, 5, 10); + i4 = bitfieldExtract(i4, 5, 10); + u = bitfieldExtract(u, 5, 10); + u2 = bitfieldExtract(u2, 5, 10); + u3 = bitfieldExtract(u3, 5, 10); + u4 = bitfieldExtract(u4, 5, 10); +} \ No newline at end of file diff --git a/tests/out/glsl/bits.main.Compute.glsl b/tests/out/glsl/bits.main.Compute.glsl new file mode 100644 index 0000000000..303cdbc440 --- /dev/null +++ b/tests/out/glsl/bits.main.Compute.glsl @@ -0,0 +1,90 @@ +#version 310 es + +precision highp float; +precision highp int; + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + + +void main() { + int i = 0; + ivec2 i2_ = ivec2(0, 0); + ivec3 i3_ = ivec3(0, 0, 0); + ivec4 i4_ = ivec4(0, 0, 0, 0); + uint u = 0u; + uvec2 u2_ = uvec2(0u, 0u); + uvec3 u3_ = uvec3(0u, 0u, 0u); + uvec4 u4_ = uvec4(0u, 0u, 0u, 0u); + vec2 f2_ = vec2(0.0, 0.0); + vec4 f4_ = vec4(0.0, 0.0, 0.0, 0.0); + i2_ = ivec2(0); + i3_ = ivec3(0); + i4_ = ivec4(0); + u2_ = uvec2(0u); + u3_ = uvec3(0u); + u4_ = uvec4(0u); + f2_ = vec2(0.0); + f4_ = vec4(0.0); + vec4 _e28 = f4_; + u = packSnorm4x8(_e28); + vec4 _e30 = f4_; + u = packUnorm4x8(_e30); + vec2 _e32 = f2_; + u = packSnorm2x16(_e32); + vec2 _e34 = f2_; + u = packUnorm2x16(_e34); + vec2 _e36 = f2_; + u = packHalf2x16(_e36); + uint _e38 = u; + f4_ = unpackSnorm4x8(_e38); + uint _e40 = u; + f4_ = unpackUnorm4x8(_e40); + uint _e42 = u; + f2_ = unpackSnorm2x16(_e42); + uint _e44 = u; + f2_ = unpackUnorm2x16(_e44); + uint _e46 = u; + f2_ = unpackHalf2x16(_e46); + int _e48 = i; + int _e49 = i; + i = bitfieldInsert(_e48, _e49, int(5u), int(10u)); + ivec2 _e53 = i2_; + ivec2 _e54 = i2_; + i2_ = bitfieldInsert(_e53, _e54, int(5u), int(10u)); + ivec3 _e58 = i3_; + ivec3 _e59 = i3_; + i3_ = bitfieldInsert(_e58, _e59, int(5u), int(10u)); + ivec4 _e63 = i4_; + ivec4 _e64 = i4_; + i4_ = bitfieldInsert(_e63, _e64, int(5u), int(10u)); + uint _e68 = u; + uint _e69 = u; + u = bitfieldInsert(_e68, _e69, int(5u), int(10u)); + uvec2 _e73 = u2_; + uvec2 _e74 = u2_; + u2_ = bitfieldInsert(_e73, _e74, int(5u), int(10u)); + uvec3 _e78 = u3_; + uvec3 _e79 = u3_; + u3_ = bitfieldInsert(_e78, _e79, int(5u), int(10u)); + uvec4 _e83 = u4_; + uvec4 _e84 = u4_; + u4_ = bitfieldInsert(_e83, _e84, int(5u), int(10u)); + int _e88 = i; + i = bitfieldExtract(_e88, int(5u), int(10u)); + ivec2 _e92 = i2_; + i2_ = bitfieldExtract(_e92, int(5u), int(10u)); + ivec3 _e96 = i3_; + i3_ = bitfieldExtract(_e96, int(5u), int(10u)); + ivec4 _e100 = i4_; + i4_ = bitfieldExtract(_e100, int(5u), int(10u)); + uint _e104 = u; + u = bitfieldExtract(_e104, int(5u), int(10u)); + uvec2 _e108 = u2_; + u2_ = bitfieldExtract(_e108, int(5u), int(10u)); + uvec3 _e112 = u3_; + u3_ = bitfieldExtract(_e112, int(5u), int(10u)); + uvec4 _e116 = u4_; + u4_ = bitfieldExtract(_e116, int(5u), int(10u)); + return; +} + diff --git a/tests/out/ir/shadow.ron b/tests/out/ir/shadow.ron index 4dae9ffae9..572ffe39f6 100644 --- a/tests/out/ir/shadow.ron +++ b/tests/out/ir/shadow.ron @@ -968,6 +968,7 @@ arg: 52, arg1: Some(32), arg2: None, + arg3: None, ), Binary( op: GreaterEqual, @@ -1015,6 +1016,7 @@ arg: 65, arg1: None, arg2: None, + arg3: None, ), AccessIndex( base: 6, @@ -1129,18 +1131,21 @@ arg: 93, arg1: None, arg2: None, + arg3: None, ), Math( fun: Dot, arg: 66, arg1: Some(94), arg2: None, + arg3: None, ), Math( fun: Max, arg: 46, arg1: Some(95), arg2: None, + arg3: None, ), Binary( op: Multiply, diff --git a/tests/out/msl/bits.msl b/tests/out/msl/bits.msl new file mode 100644 index 0000000000..70ca905e39 --- /dev/null +++ b/tests/out/msl/bits.msl @@ -0,0 +1,87 @@ +// language: metal1.2 +#include +#include + + +kernel void main1( +) { + int i = 0; + metal::int2 i2_; + metal::int3 i3_; + metal::int4 i4_; + metal::uint u = 0u; + metal::uint2 u2_; + metal::uint3 u3_; + metal::uint4 u4_; + metal::float2 f2_; + metal::float4 f4_; + i2_ = metal::int2(0); + i3_ = metal::int3(0); + i4_ = metal::int4(0); + u2_ = metal::uint2(0u); + u3_ = metal::uint3(0u); + u4_ = metal::uint4(0u); + f2_ = metal::float2(0.0); + f4_ = metal::float4(0.0); + metal::float4 _e28 = f4_; + u = metal::pack_float_to_unorm4x8(_e28); + metal::float4 _e30 = f4_; + u = metal::pack_float_to_snorm4x8(_e30); + metal::float2 _e32 = f2_; + u = metal::pack_float_to_unorm2x16(_e32); + metal::float2 _e34 = f2_; + u = metal::pack_float_to_snorm2x16(_e34); + metal::float2 _e36 = f2_; + u = as_type(half2(_e36)); + metal::uint _e38 = u; + f4_ = metal::unpack_snorm4x8_to_float(_e38); + metal::uint _e40 = u; + f4_ = metal::unpack_unorm4x8_to_float(_e40); + metal::uint _e42 = u; + f2_ = metal::unpack_snorm2x16_to_float(_e42); + metal::uint _e44 = u; + f2_ = metal::unpack_unorm2x16_to_float(_e44); + metal::uint _e46 = u; + f2_ = float2(as_type(_e46)); + int _e48 = i; + int _e49 = i; + i = metal::insert_bits(_e48, _e49, 5u, 10u); + metal::int2 _e53 = i2_; + metal::int2 _e54 = i2_; + i2_ = metal::insert_bits(_e53, _e54, 5u, 10u); + metal::int3 _e58 = i3_; + metal::int3 _e59 = i3_; + i3_ = metal::insert_bits(_e58, _e59, 5u, 10u); + metal::int4 _e63 = i4_; + metal::int4 _e64 = i4_; + i4_ = metal::insert_bits(_e63, _e64, 5u, 10u); + metal::uint _e68 = u; + metal::uint _e69 = u; + u = metal::insert_bits(_e68, _e69, 5u, 10u); + metal::uint2 _e73 = u2_; + metal::uint2 _e74 = u2_; + u2_ = metal::insert_bits(_e73, _e74, 5u, 10u); + metal::uint3 _e78 = u3_; + metal::uint3 _e79 = u3_; + u3_ = metal::insert_bits(_e78, _e79, 5u, 10u); + metal::uint4 _e83 = u4_; + metal::uint4 _e84 = u4_; + u4_ = metal::insert_bits(_e83, _e84, 5u, 10u); + int _e88 = i; + i = metal::extract_bits(_e88, 5u, 10u); + metal::int2 _e92 = i2_; + i2_ = metal::extract_bits(_e92, 5u, 10u); + metal::int3 _e96 = i3_; + i3_ = metal::extract_bits(_e96, 5u, 10u); + metal::int4 _e100 = i4_; + i4_ = metal::extract_bits(_e100, 5u, 10u); + metal::uint _e104 = u; + u = metal::extract_bits(_e104, 5u, 10u); + metal::uint2 _e108 = u2_; + u2_ = metal::extract_bits(_e108, 5u, 10u); + metal::uint3 _e112 = u3_; + u3_ = metal::extract_bits(_e112, 5u, 10u); + metal::uint4 _e116 = u4_; + u4_ = metal::extract_bits(_e116, 5u, 10u); + return; +} diff --git a/tests/out/spv/bits.spvasm b/tests/out/spv/bits.spvasm new file mode 100644 index 0000000000..f7b8384530 --- /dev/null +++ b/tests/out/spv/bits.spvasm @@ -0,0 +1,155 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 111 +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %40 "main" +OpExecutionMode %40 LocalSize 1 1 1 +%2 = OpTypeVoid +%4 = OpTypeInt 32 1 +%3 = OpConstant %4 0 +%6 = OpTypeInt 32 0 +%5 = OpConstant %6 0 +%8 = OpTypeFloat 32 +%7 = OpConstant %8 0.0 +%9 = OpConstant %6 5 +%10 = OpConstant %6 10 +%11 = OpTypeVector %4 2 +%12 = OpTypeVector %4 3 +%13 = OpTypeVector %4 4 +%14 = OpTypeVector %6 2 +%15 = OpTypeVector %6 3 +%16 = OpTypeVector %6 4 +%17 = OpTypeVector %8 2 +%18 = OpTypeVector %8 4 +%20 = OpTypePointer Function %4 +%22 = OpTypePointer Function %11 +%24 = OpTypePointer Function %12 +%26 = OpTypePointer Function %13 +%28 = OpTypePointer Function %6 +%30 = OpTypePointer Function %14 +%32 = OpTypePointer Function %15 +%34 = OpTypePointer Function %16 +%36 = OpTypePointer Function %17 +%38 = OpTypePointer Function %18 +%41 = OpTypeFunction %2 +%40 = OpFunction %2 None %41 +%39 = OpLabel +%37 = OpVariable %38 Function +%31 = OpVariable %32 Function +%25 = OpVariable %26 Function +%19 = OpVariable %20 Function %3 +%33 = OpVariable %34 Function +%27 = OpVariable %28 Function %5 +%21 = OpVariable %22 Function +%35 = OpVariable %36 Function +%29 = OpVariable %30 Function +%23 = OpVariable %24 Function +OpBranch %42 +%42 = OpLabel +%43 = OpCompositeConstruct %11 %3 %3 +OpStore %21 %43 +%44 = OpCompositeConstruct %12 %3 %3 %3 +OpStore %23 %44 +%45 = OpCompositeConstruct %13 %3 %3 %3 %3 +OpStore %25 %45 +%46 = OpCompositeConstruct %14 %5 %5 +OpStore %29 %46 +%47 = OpCompositeConstruct %15 %5 %5 %5 +OpStore %31 %47 +%48 = OpCompositeConstruct %16 %5 %5 %5 %5 +OpStore %33 %48 +%49 = OpCompositeConstruct %17 %7 %7 +OpStore %35 %49 +%50 = OpCompositeConstruct %18 %7 %7 %7 %7 +OpStore %37 %50 +%51 = OpLoad %18 %37 +%52 = OpExtInst %6 %1 PackSnorm4x8 %51 +OpStore %27 %52 +%53 = OpLoad %18 %37 +%54 = OpExtInst %6 %1 PackUnorm4x8 %53 +OpStore %27 %54 +%55 = OpLoad %17 %35 +%56 = OpExtInst %6 %1 PackSnorm2x16 %55 +OpStore %27 %56 +%57 = OpLoad %17 %35 +%58 = OpExtInst %6 %1 PackSnorm2x16 %57 +OpStore %27 %58 +%59 = OpLoad %17 %35 +%60 = OpExtInst %6 %1 PackHalf2x16 %59 +OpStore %27 %60 +%61 = OpLoad %6 %27 +%62 = OpExtInst %18 %1 UnpackSnorm4x8 %61 +OpStore %37 %62 +%63 = OpLoad %6 %27 +%64 = OpExtInst %18 %1 UnpackUnorm4x8 %63 +OpStore %37 %64 +%65 = OpLoad %6 %27 +%66 = OpExtInst %17 %1 UnpackSnorm2x16 %65 +OpStore %35 %66 +%67 = OpLoad %6 %27 +%68 = OpExtInst %17 %1 UnpackSnorm2x16 %67 +OpStore %35 %68 +%69 = OpLoad %6 %27 +%70 = OpExtInst %17 %1 UnpackHalf2x16 %69 +OpStore %35 %70 +%71 = OpLoad %4 %19 +%72 = OpLoad %4 %19 +%73 = OpBitFieldInsert %4 %71 %72 %9 %10 +OpStore %19 %73 +%74 = OpLoad %11 %21 +%75 = OpLoad %11 %21 +%76 = OpBitFieldInsert %11 %74 %75 %9 %10 +OpStore %21 %76 +%77 = OpLoad %12 %23 +%78 = OpLoad %12 %23 +%79 = OpBitFieldInsert %12 %77 %78 %9 %10 +OpStore %23 %79 +%80 = OpLoad %13 %25 +%81 = OpLoad %13 %25 +%82 = OpBitFieldInsert %13 %80 %81 %9 %10 +OpStore %25 %82 +%83 = OpLoad %6 %27 +%84 = OpLoad %6 %27 +%85 = OpBitFieldInsert %6 %83 %84 %9 %10 +OpStore %27 %85 +%86 = OpLoad %14 %29 +%87 = OpLoad %14 %29 +%88 = OpBitFieldInsert %14 %86 %87 %9 %10 +OpStore %29 %88 +%89 = OpLoad %15 %31 +%90 = OpLoad %15 %31 +%91 = OpBitFieldInsert %15 %89 %90 %9 %10 +OpStore %31 %91 +%92 = OpLoad %16 %33 +%93 = OpLoad %16 %33 +%94 = OpBitFieldInsert %16 %92 %93 %9 %10 +OpStore %33 %94 +%95 = OpLoad %4 %19 +%96 = OpBitFieldSExtract %4 %95 %9 %10 +OpStore %19 %96 +%97 = OpLoad %11 %21 +%98 = OpBitFieldSExtract %11 %97 %9 %10 +OpStore %21 %98 +%99 = OpLoad %12 %23 +%100 = OpBitFieldSExtract %12 %99 %9 %10 +OpStore %23 %100 +%101 = OpLoad %13 %25 +%102 = OpBitFieldSExtract %13 %101 %9 %10 +OpStore %25 %102 +%103 = OpLoad %6 %27 +%104 = OpBitFieldUExtract %6 %103 %9 %10 +OpStore %27 %104 +%105 = OpLoad %14 %29 +%106 = OpBitFieldUExtract %14 %105 %9 %10 +OpStore %29 %106 +%107 = OpLoad %15 %31 +%108 = OpBitFieldUExtract %15 %107 %9 %10 +OpStore %31 %108 +%109 = OpLoad %16 %33 +%110 = OpBitFieldUExtract %16 %109 %9 %10 +OpStore %33 %110 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/bits.wgsl b/tests/out/wgsl/bits.wgsl new file mode 100644 index 0000000000..c4054d8e34 --- /dev/null +++ b/tests/out/wgsl/bits.wgsl @@ -0,0 +1,83 @@ +[[stage(compute), workgroup_size(1, 1, 1)]] +fn main() { + var i: i32 = 0; + var i2_: vec2; + var i3_: vec3; + var i4_: vec4; + var u: u32 = 0u; + var u2_: vec2; + var u3_: vec3; + var u4_: vec4; + var f2_: vec2; + var f4_: vec4; + + i2_ = vec2(0); + i3_ = vec3(0); + i4_ = vec4(0); + u2_ = vec2(0u); + u3_ = vec3(0u); + u4_ = vec4(0u); + f2_ = vec2(0.0); + f4_ = vec4(0.0); + let e28: vec4 = f4_; + u = pack4x8snorm(e28); + let e30: vec4 = f4_; + u = pack4x8unorm(e30); + let e32: vec2 = f2_; + u = pack2x16snorm(e32); + let e34: vec2 = f2_; + u = pack2x16unorm(e34); + let e36: vec2 = f2_; + u = pack2x16float(e36); + let e38: u32 = u; + f4_ = unpack4x8snorm(e38); + let e40: u32 = u; + f4_ = unpack4x8unorm(e40); + let e42: u32 = u; + f2_ = unpack2x16snorm(e42); + let e44: u32 = u; + f2_ = unpack2x16unorm(e44); + let e46: u32 = u; + f2_ = unpack2x16float(e46); + let e48: i32 = i; + let e49: i32 = i; + i = insertBits(e48, e49, 5u, 10u); + let e53: vec2 = i2_; + let e54: vec2 = i2_; + i2_ = insertBits(e53, e54, 5u, 10u); + let e58: vec3 = i3_; + let e59: vec3 = i3_; + i3_ = insertBits(e58, e59, 5u, 10u); + let e63: vec4 = i4_; + let e64: vec4 = i4_; + i4_ = insertBits(e63, e64, 5u, 10u); + let e68: u32 = u; + let e69: u32 = u; + u = insertBits(e68, e69, 5u, 10u); + let e73: vec2 = u2_; + let e74: vec2 = u2_; + u2_ = insertBits(e73, e74, 5u, 10u); + let e78: vec3 = u3_; + let e79: vec3 = u3_; + u3_ = insertBits(e78, e79, 5u, 10u); + let e83: vec4 = u4_; + let e84: vec4 = u4_; + u4_ = insertBits(e83, e84, 5u, 10u); + let e88: i32 = i; + i = extractBits(e88, 5u, 10u); + let e92: vec2 = i2_; + i2_ = extractBits(e92, 5u, 10u); + let e96: vec3 = i3_; + i3_ = extractBits(e96, 5u, 10u); + let e100: vec4 = i4_; + i4_ = extractBits(e100, 5u, 10u); + let e104: u32 = u; + u = extractBits(e104, 5u, 10u); + let e108: vec2 = u2_; + u2_ = extractBits(e108, 5u, 10u); + let e112: vec3 = u3_; + u3_ = extractBits(e112, 5u, 10u); + let e116: vec4 = u4_; + u4_ = extractBits(e116, 5u, 10u); + return; +} diff --git a/tests/out/wgsl/bits_glsl-frag.wgsl b/tests/out/wgsl/bits_glsl-frag.wgsl new file mode 100644 index 0000000000..2eb77888ca --- /dev/null +++ b/tests/out/wgsl/bits_glsl-frag.wgsl @@ -0,0 +1,80 @@ +fn main1() { + var i: i32 = 0; + var i2_: vec2 = vec2(0, 0); + var i3_: vec3 = vec3(0, 0, 0); + var i4_: vec4 = vec4(0, 0, 0, 0); + var u: u32 = 0u; + var u2_: vec2 = vec2(0u, 0u); + var u3_: vec3 = vec3(0u, 0u, 0u); + var u4_: vec4 = vec4(0u, 0u, 0u, 0u); + var f2_: vec2 = vec2(0.0, 0.0); + var f4_: vec4 = vec4(0.0, 0.0, 0.0, 0.0); + + let e33: vec4 = f4_; + u = pack4x8snorm(e33); + let e36: vec4 = f4_; + u = pack4x8unorm(e36); + let e39: vec2 = f2_; + u = pack2x16unorm(e39); + let e42: vec2 = f2_; + u = pack2x16snorm(e42); + let e45: vec2 = f2_; + u = pack2x16float(e45); + let e48: u32 = u; + f4_ = unpack4x8snorm(e48); + let e51: u32 = u; + f4_ = unpack4x8unorm(e51); + let e54: u32 = u; + f2_ = unpack2x16snorm(e54); + let e57: u32 = u; + f2_ = unpack2x16unorm(e57); + let e60: u32 = u; + f2_ = unpack2x16float(e60); + let e66: i32 = i; + let e67: i32 = i; + i = insertBits(e66, e67, u32(5), u32(10)); + let e77: vec2 = i2_; + let e78: vec2 = i2_; + i2_ = insertBits(e77, e78, u32(5), u32(10)); + let e88: vec3 = i3_; + let e89: vec3 = i3_; + i3_ = insertBits(e88, e89, u32(5), u32(10)); + let e99: vec4 = i4_; + let e100: vec4 = i4_; + i4_ = insertBits(e99, e100, u32(5), u32(10)); + let e110: u32 = u; + let e111: u32 = u; + u = insertBits(e110, e111, u32(5), u32(10)); + let e121: vec2 = u2_; + let e122: vec2 = u2_; + u2_ = insertBits(e121, e122, u32(5), u32(10)); + let e132: vec3 = u3_; + let e133: vec3 = u3_; + u3_ = insertBits(e132, e133, u32(5), u32(10)); + let e143: vec4 = u4_; + let e144: vec4 = u4_; + u4_ = insertBits(e143, e144, u32(5), u32(10)); + let e153: i32 = i; + i = extractBits(e153, u32(5), u32(10)); + let e162: vec2 = i2_; + i2_ = extractBits(e162, u32(5), u32(10)); + let e171: vec3 = i3_; + i3_ = extractBits(e171, u32(5), u32(10)); + let e180: vec4 = i4_; + i4_ = extractBits(e180, u32(5), u32(10)); + let e189: u32 = u; + u = extractBits(e189, u32(5), u32(10)); + let e198: vec2 = u2_; + u2_ = extractBits(e198, u32(5), u32(10)); + let e207: vec3 = u3_; + u3_ = extractBits(e207, u32(5), u32(10)); + let e216: vec4 = u4_; + u4_ = extractBits(e216, u32(5), u32(10)); + return; +} + +[[stage(fragment)]] +fn main() { + main1(); + return; +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 5725695e66..7d05d0fbfe 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -465,6 +465,10 @@ fn convert_wgsl() { | Targets::HLSL | Targets::WGSL, ), + ( + "bits", + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL, + ), ( "boids", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,