From b35e901249ffd7a7857ddd396388fb7a017c41dd Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Wed, 9 Dec 2020 00:12:53 -0500 Subject: [PATCH] Add all standard library functions to the IR --- src/back/glsl/mod.rs | 175 +++++++++++++++++++-------------- src/back/msl/writer.rs | 127 +++++++++++++++++------- src/back/spv/writer.rs | 73 +++++++------- src/front/spv/function.rs | 16 ++- src/front/spv/mod.rs | 201 +++++++++++++++++--------------------- src/front/wgsl/conv.rs | 88 +++++++++++++---- src/front/wgsl/mod.rs | 130 ++++++++++++------------ src/lib.rs | 121 ++++++++++++++++------- src/proc/interface.rs | 33 ++++--- src/proc/mod.rs | 59 +++++++++++ src/proc/typifier.rs | 169 ++++++++++++++++++++++---------- test-data/boids.wgsl | 14 +-- 12 files changed, 744 insertions(+), 462 deletions(-) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 98bfe02c9d..3f87d1c760 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -49,10 +49,9 @@ use crate::{ Typifier, Visitor, }, Arena, ArraySize, BinaryOperator, BuiltIn, Bytes, ConservativeDepth, Constant, ConstantInner, - DerivativeAxis, Expression, FastHashMap, Function, FunctionOrigin, GlobalVariable, Handle, - ImageClass, Interpolation, IntrinsicFunction, LocalVariable, Module, ScalarKind, ShaderStage, - Statement, StorageAccess, StorageClass, StorageFormat, StructMember, Type, TypeInner, - UnaryOperator, + DerivativeAxis, Expression, FastHashMap, Function, GlobalVariable, Handle, ImageClass, + Interpolation, LocalVariable, Module, RelationalFunction, ScalarKind, ShaderStage, Statement, + StorageAccess, StorageClass, StorageFormat, StructMember, Type, TypeInner, UnaryOperator, }; use features::FeaturesManager; use std::{ @@ -1371,49 +1370,119 @@ impl<'a, W: Write> Writer<'a, W> { self.write_expr(reject, ctx)?; write!(self.out, ")")? } - // `Intrinsic` is a normal function call to some glsl provided functions - Expression::Intrinsic { fun, argument } => { + // `Derivative` is a function call to a glsl provided function + Expression::Derivative { axis, expr } => { write!( self.out, "{}(", - match fun { - // There's no specific function for this but we can invert the result of `isinf` - IntrinsicFunction::IsFinite => "!isinf", - IntrinsicFunction::IsInf => "isinf", - IntrinsicFunction::IsNan => "isnan", - // There's also no function for this but we can invert `isnan` - IntrinsicFunction::IsNormal => "!isnan", - IntrinsicFunction::All => "all", - IntrinsicFunction::Any => "any", + match axis { + DerivativeAxis::X => "dFdx", + DerivativeAxis::Y => "dFdy", + DerivativeAxis::Width => "fwidth", } )?; + self.write_expr(expr, ctx)?; + write!(self.out, ")")? + } + // `Relational` is a normal function call to some glsl provided functions + Expression::Relational { fun, argument } => { + let fun_name = match fun { + // There's no specific function for this but we can invert the result of `isinf` + RelationalFunction::IsFinite => "!isinf", + RelationalFunction::IsInf => "isinf", + RelationalFunction::IsNan => "isnan", + // There's also no function for this but we can invert `isnan` + RelationalFunction::IsNormal => "!isnan", + RelationalFunction::All => "all", + RelationalFunction::Any => "any", + }; + write!(self.out, "{}(", fun_name)?; self.write_expr(argument, ctx)?; write!(self.out, ")")? } + Expression::Math { + fun, + arg, + arg1, + arg2, + } => { + use crate::MathFunction as Mf; + + let fun_name = match fun { + // comparison + Mf::Abs => "abs", + Mf::Min => "min", + Mf::Max => "max", + Mf::Clamp => "clamp", + // trigonometry + Mf::Cos => "cos", + Mf::Cosh => "cosh", + Mf::Sin => "sin", + Mf::Sinh => "sinh", + Mf::Tan => "tan", + Mf::Tanh => "tanh", + Mf::Acos => "acos", + Mf::Asin => "asin", + Mf::Atan => "atan", + Mf::Atan2 => "atan2", + // decomposition + Mf::Ceil => "ceil", + Mf::Floor => "floor", + Mf::Round => "round", + Mf::Fract => "fract", + Mf::Trunc => "trunc", + Mf::Modf => "modf", + Mf::Frexp => "frexp", + Mf::Ldexp => "ldexp", + // exponent + Mf::Exp => "exp", + Mf::Exp2 => "exp2", + Mf::Log => "log", + Mf::Log2 => "log2", + Mf::Pow => "pow", + // geometry + Mf::Dot => "dot", + Mf::Outer => "outerProduct", + Mf::Cross => "cross", + Mf::Distance => "distance", + Mf::Length => "length", + Mf::Normalize => "normalize", + Mf::FaceForward => "faceforward", + Mf::Reflect => "reflect", + // computational + Mf::Sign => "sign", + Mf::Fma => "fma", + Mf::Mix => "mix", + Mf::Step => "step", + Mf::SmoothStep => "smoothstep", + Mf::Sqrt => "sqrt", + Mf::InverseSqrt => "inversesqrt", + Mf::Determinant => "determinant", + // bits + Mf::CountOneBits => "bitCount", + Mf::ReverseBits => "bitfieldReverse", + }; + + write!(self.out, "{}(", fun_name)?; + self.write_expr(arg, ctx)?; + if let Some(arg) = arg1 { + write!(self.out, ", ")?; + self.write_expr(arg, ctx)?; + } + if let Some(arg) = arg2 { + write!(self.out, ", ")?; + self.write_expr(arg, ctx)?; + } + write!(self.out, ")")? + } // `Transpose` is a call to the glsl function `transpose` Expression::Transpose(matrix) => { write!(self.out, "transpose(")?; self.write_expr(matrix, ctx)?; write!(self.out, ")")? } - // Both `Dot` and `Cross` products are a call to a glsl provide functions with `left` - // and `right` as arguments - Expression::DotProduct(left, right) => { - write!(self.out, "dot(")?; - self.write_expr(left, ctx)?; - write!(self.out, ", ")?; - self.write_expr(right, ctx)?; - write!(self.out, ")")? - } - Expression::CrossProduct(left, right) => { - write!(self.out, "cross(")?; - self.write_expr(left, ctx)?; - write!(self.out, ", ")?; - self.write_expr(right, ctx)?; - write!(self.out, ")")? - } // `As` is always a call. // If `convert` is true the function name is the type // Else the function name is one of the glsl provided bitcast functions @@ -1457,55 +1526,15 @@ impl<'a, W: Write> Writer<'a, W> { self.write_expr(expr, ctx)?; write!(self.out, ")")? } - // `Derivative` is a function call to a glsl provided function - Expression::Derivative { axis, expr } => { - write!( - self.out, - "{}(", - match axis { - DerivativeAxis::X => "dFdx", - DerivativeAxis::Y => "dFdy", - DerivativeAxis::Width => "fwidth", - } - )?; - self.write_expr(expr, ctx)?; - write!(self.out, ")")? - } // A `Call` is written `name(arguments)` where `arguments` is a comma separated expressions list Expression::Call { - origin: FunctionOrigin::Local(ref function), + function, ref arguments, } => { - write!(self.out, "{}(", &self.names[&NameKey::Function(*function)])?; + write!(self.out, "{}(", &self.names[&NameKey::Function(function)])?; self.write_slice(arguments, |this, _, arg| this.write_expr(*arg, ctx))?; write!(self.out, ")")? } - Expression::Call { - origin: crate::FunctionOrigin::External(ref name), - ref arguments, - } => match name.as_str() { - "cos" | "normalize" | "sin" | "length" | "abs" | "floor" | "inverse" - | "distance" | "dot" | "min" | "max" | "reflect" | "pow" | "step" | "cross" - | "fclamp" | "clamp" | "mix" | "smoothstep" => { - let name = match name.as_str() { - "fclamp" => "clamp", - name => name, - }; - - write!(self.out, "{}(", name)?; - self.write_slice(arguments, |this, _, arg| this.write_expr(*arg, ctx))?; - write!(self.out, ")")? - } - // `atan2` is implemented as `atan(y,x)` so we must handle it separately - "atan2" => { - write!(self.out, "atan(")?; - self.write_expr(arguments[1], ctx)?; - write!(self.out, ", ")?; - self.write_expr(arguments[0], ctx)?; - write!(self.out, ")")? - } - _ => return Err(Error::UnsupportedExternal(name.clone())), - }, // `ArrayLength` is written as `expr.length()` and we convert it to a uint Expression::ArrayLength(expr) => { write!(self.out, "uint(")?; diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 936f78ac50..710948d625 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -346,26 +346,103 @@ impl Writer { self.put_expression(reject, context)?; write!(self.out, ")")?; } - crate::Expression::Intrinsic { fun, argument } => { + crate::Expression::Derivative { axis, expr } => { + let op = match axis { + crate::DerivativeAxis::X => "dfdx", + crate::DerivativeAxis::Y => "dfdy", + crate::DerivativeAxis::Width => "fwidth", + }; + self.put_call(op, &[expr], context)?; + } + crate::Expression::Relational { fun, argument } => { let op = match fun { - crate::IntrinsicFunction::Any => "any", - crate::IntrinsicFunction::All => "all", - crate::IntrinsicFunction::IsNan => "", - crate::IntrinsicFunction::IsInf => "", - crate::IntrinsicFunction::IsFinite => "", - crate::IntrinsicFunction::IsNormal => "", + crate::RelationalFunction::Any => "any", + crate::RelationalFunction::All => "all", + crate::RelationalFunction::IsNan => "isnan", + crate::RelationalFunction::IsInf => "isinf", + crate::RelationalFunction::IsFinite => "isfinite", + crate::RelationalFunction::IsNormal => "isnormal", }; self.put_call(op, &[argument], context)?; } + crate::Expression::Math { + fun, + arg, + arg1, + arg2, + } => { + use crate::MathFunction as Mf; + + let fun_name = match fun { + // comparison + Mf::Abs => "abs", + Mf::Min => "min", + Mf::Max => "max", + Mf::Clamp => "clamp", + // trigonometry + Mf::Cos => "cos", + Mf::Cosh => "cosh", + Mf::Sin => "sin", + Mf::Sinh => "sinh", + Mf::Tan => "tan", + Mf::Tanh => "tanh", + Mf::Acos => "acos", + Mf::Asin => "asin", + Mf::Atan => "atan", + Mf::Atan2 => "atan2", + // decomposition + Mf::Ceil => "ceil", + Mf::Floor => "floor", + Mf::Round => "round", + Mf::Fract => "fract", + Mf::Trunc => "trunc", + Mf::Modf => "modf", + Mf::Frexp => "frexp", + Mf::Ldexp => "ldexp", + // exponent + Mf::Exp => "exp", + Mf::Exp2 => "exp2", + Mf::Log => "log", + Mf::Log2 => "log2", + Mf::Pow => "pow", + // geometry + Mf::Dot => "dot", + Mf::Outer => return Err(Error::UnsupportedCall(format!("{:?}", fun))), + Mf::Cross => "cross", + Mf::Distance => "distance", + Mf::Length => "length", + Mf::Normalize => "normalize", + Mf::FaceForward => "faceforward", + Mf::Reflect => "reflect", + // computational + Mf::Sign => "sign", + Mf::Fma => "fma", + Mf::Mix => "mix", + Mf::Step => "step", + Mf::SmoothStep => "smoothstep", + Mf::Sqrt => "sqrt", + Mf::InverseSqrt => "rsqrt", + Mf::Determinant => "determinant", + // bits + Mf::CountOneBits => "popcount", + Mf::ReverseBits => "reverse_bits", + }; + + write!(self.out, "metal::{}(", fun_name)?; + self.put_expression(arg, context)?; + if let Some(arg) = arg1 { + write!(self.out, ", ")?; + self.put_expression(arg, context)?; + } + if let Some(arg) = arg2 { + write!(self.out, ", ")?; + self.put_expression(arg, context)?; + } + write!(self.out, ")")?; + } crate::Expression::Transpose(expr) => { self.put_call("transpose", &[expr], context)?; } - crate::Expression::DotProduct(a, b) => { - self.put_call("dot", &[a, b], context)?; - } - crate::Expression::CrossProduct(a, b) => { - self.put_call("cross", &[a, b], context)?; - } crate::Expression::As { expr, kind, @@ -382,34 +459,14 @@ impl Writer { self.put_expression(expr, context)?; write!(self.out, ")")?; } - crate::Expression::Derivative { axis, expr } => { - let op = match axis { - crate::DerivativeAxis::X => "dfdx", - crate::DerivativeAxis::Y => "dfdy", - crate::DerivativeAxis::Width => "fwidth", - }; - self.put_call(op, &[expr], context)?; - } crate::Expression::Call { - origin: crate::FunctionOrigin::Local(handle), + function, ref arguments, } => { - let name = &self.names[&NameKey::Function(handle)]; + let name = &self.names[&NameKey::Function(function)]; write!(self.out, "{}", name)?; self.put_call("", arguments, context)?; } - crate::Expression::Call { - origin: crate::FunctionOrigin::External(ref name), - ref arguments, - } => match name.as_str() { - "atan2" | "cos" | "distance" | "length" | "mix" | "normalize" | "sin" => { - self.put_call(name, arguments, context)?; - } - "fclamp" => { - self.put_call("clamp", arguments, context)?; - } - other => return Err(Error::UnsupportedCall(other.to_owned())), - }, crate::Expression::ArrayLength(expr) => match *self .typifier .get(expr, &context.module.types) diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 50d354dc9e..9279cce037 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -1232,6 +1232,10 @@ impl Writer { )); Ok((id, result_lookup_ty)) } + crate::Expression::Math { fun, .. } => { + log::error!("unimplemented math function {:?}", fun); + Err(Error::FeatureNotImplemented("math function")) + } crate::Expression::LocalVariable(variable) => { let var = &ir_function.local_variables[variable]; let local_var = &function.variables[&variable]; @@ -1251,49 +1255,38 @@ impl Writer { Ok((load_id, LookupType::Handle(handle))) } crate::Expression::Call { - ref origin, + function: local_function, ref arguments, - } => match *origin { - crate::FunctionOrigin::Local(local_function) => { - let origin_function = &ir_module.functions[local_function]; - let id = self.generate_id(); - let mut argument_ids = vec![]; + } => { + let target_function = &ir_module.functions[local_function]; + let id = self.generate_id(); + let mut argument_ids = vec![]; - for argument in arguments { - let expression = &ir_function.expressions[*argument]; - let (arg_id, _) = self.write_expression( - ir_module, - ir_function, - expression, - block, - function, - )?; - argument_ids.push(arg_id); - } - - let return_type_id = self - .get_function_return_type(origin_function.return_type, &ir_module.types)?; - - block - .body - .push(super::instructions::instruction_function_call( - return_type_id, - id, - *self.lookup_function.get(&local_function).unwrap(), - argument_ids.as_slice(), - )); - - let result_type = match origin_function.return_type { - Some(ty_handle) => LookupType::Handle(ty_handle), - None => LookupType::Local(LocalType::Void), - }; - Ok((id, result_type)) + for argument in arguments { + let expression = &ir_function.expressions[*argument]; + let (arg_id, _) = + self.write_expression(ir_module, ir_function, expression, block, function)?; + argument_ids.push(arg_id); } - crate::FunctionOrigin::External(ref string) => { - log::error!("unimplemented stdlib function {}", string); - Err(Error::FeatureNotImplemented("stdlib function")) - } - }, + + let return_type_id = + self.get_function_return_type(target_function.return_type, &ir_module.types)?; + + block + .body + .push(super::instructions::instruction_function_call( + return_type_id, + id, + *self.lookup_function.get(&local_function).unwrap(), + argument_ids.as_slice(), + )); + + let result_type = match target_function.return_type { + Some(ty_handle) => LookupType::Handle(ty_handle), + None => LookupType::Local(LocalType::Void), + }; + Ok((id, result_type)) + } crate::Expression::As { expr, kind, diff --git a/src/front/spv/function.rs b/src/front/spv/function.rs index d2cb0551a1..654160b303 100644 --- a/src/front/spv/function.rs +++ b/src/front/spv/function.rs @@ -115,8 +115,8 @@ impl> super::Parser { } // Read body - let mut local_function_calls = FastHashMap::default(); let mut flow_graph = FlowGraph::new(); + let base_deferred_call_index = self.deferred_function_calls.len(); // Scan the blocks and add them as nodes loop { @@ -135,7 +135,6 @@ impl> super::Parser { &module.types, &module.constants, &module.global_variables, - &mut local_function_calls, )?; flow_graph.add_node(node); @@ -176,9 +175,14 @@ impl> super::Parser { } }; + for dfc in self.deferred_function_calls[base_deferred_call_index..].iter_mut() { + dfc.source = source.clone(); + } + if let Some(ref prefix) = self.options.flow_graph_dump_prefix { let dump = flow_graph.to_graphviz().unwrap_or_default(); let suffix = match source { + DeferredSource::Undefined => unreachable!(), DeferredSource::EntryPoint(stage, ref name) => { format!("flow.{:?}-{}.dot", stage, name) } @@ -187,14 +191,6 @@ impl> super::Parser { let _ = std::fs::write(prefix.join(suffix), dump); } - for (expr_handle, dst_id) in local_function_calls { - self.deferred_function_calls.push(DeferredFunctionCall { - source: source.clone(), - expr_handle, - dst_id, - }); - } - self.lookup_expression.clear(); self.lookup_sampled_image.clear(); Ok(()) diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index ad7c208519..782c62a6fa 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -289,6 +289,7 @@ struct LookupSampledImage { } #[derive(Clone, Debug)] enum DeferredSource { + Undefined, EntryPoint(crate::ShaderStage, String), Function(Handle), } @@ -296,6 +297,7 @@ struct DeferredFunctionCall { source: DeferredSource, expr_handle: Handle, dst_id: spirv::Word, + arguments: Vec>, } #[derive(Clone, Debug)] @@ -530,7 +532,6 @@ impl> Parser { type_arena: &Arena, const_arena: &Arena, global_arena: &Arena, - local_function_calls: &mut FastHashMap, spirv::Word>, ) -> Result { let mut assignments = Vec::new(); let mut phis = Vec::new(); @@ -850,7 +851,12 @@ impl> Parser { let right_id = self.next()?; let left_lexp = self.lookup_expression.lookup(left_id)?; let right_lexp = self.lookup_expression.lookup(right_id)?; - let expr = crate::Expression::DotProduct(left_lexp.handle, right_lexp.handle); + let expr = crate::Expression::Math { + fun: crate::MathFunction::Dot, + arg: left_lexp.handle, + arg1: Some(right_lexp.handle), + arg2: None, + }; self.lookup_expression.insert( result_id, LookupExpression { @@ -1185,13 +1191,15 @@ impl> Parser { let arg_id = self.next()?; arguments.push(self.lookup_expression.lookup(arg_id)?.handle); } - let expr = crate::Expression::Call { - // will be replaced by `Local()` after all the functions are parsed - origin: crate::FunctionOrigin::External(String::new()), - arguments, - }; + // will be replaced by the actual expression + let expr = crate::Expression::FunctionArgument(!0); let expr_handle = expressions.append(expr); - local_function_calls.insert(expr_handle, func_id); + self.deferred_function_calls.push(DeferredFunctionCall { + source: DeferredSource::Undefined, + expr_handle, + dst_id: func_id, + arguments, + }); self.lookup_expression.insert( result_id, LookupExpression { @@ -1201,6 +1209,9 @@ impl> Parser { ); } Op::ExtInst => { + use crate::MathFunction as Mf; + use spirv::GLOp as Glo; + let base_wc = 5; inst.expect_at_least(base_wc)?; let result_type_id = self.next()?; @@ -1210,106 +1221,76 @@ impl> Parser { return Err(Error::UnsupportedExtInstSet(set_id)); } let inst_id = self.next()?; - let name = match spirv::GLOp::from_u32(inst_id) { - Some(spirv::GLOp::FAbs) | Some(spirv::GLOp::SAbs) => { - inst.expect(base_wc + 1)?; - "abs" - } - Some(spirv::GLOp::FSign) | Some(spirv::GLOp::SSign) => { - inst.expect(base_wc + 1)?; - "sign" - } - Some(spirv::GLOp::Floor) => { - inst.expect(base_wc + 1)?; - "floor" - } - Some(spirv::GLOp::Ceil) => { - inst.expect(base_wc + 1)?; - "ceil" - } - Some(spirv::GLOp::Fract) => { - inst.expect(base_wc + 1)?; - "fract" - } - Some(spirv::GLOp::Sin) => { - inst.expect(base_wc + 1)?; - "sin" - } - Some(spirv::GLOp::Cos) => { - inst.expect(base_wc + 1)?; - "cos" - } - Some(spirv::GLOp::Tan) => { - inst.expect(base_wc + 1)?; - "tan" - } - Some(spirv::GLOp::Atan2) => { - inst.expect(base_wc + 2)?; - "atan2" - } - Some(spirv::GLOp::Pow) => { - inst.expect(base_wc + 2)?; - "pow" - } - Some(spirv::GLOp::MatrixInverse) => { - inst.expect(base_wc + 1)?; - "inverse" - } - Some(spirv::GLOp::FMix) => { - inst.expect(base_wc + 3)?; - "mix" - } - Some(spirv::GLOp::Step) => { - inst.expect(base_wc + 2)?; - "step" - } - Some(spirv::GLOp::SmoothStep) => { - inst.expect(base_wc + 3)?; - "smoothstep" - } - Some(spirv::GLOp::FMin) => { - inst.expect(base_wc + 2)?; - "min" - } - Some(spirv::GLOp::FMax) => { - inst.expect(base_wc + 2)?; - "max" - } - Some(spirv::GLOp::FClamp) => { - inst.expect(base_wc + 3)?; - "clamp" - } - Some(spirv::GLOp::Length) => { - inst.expect(base_wc + 1)?; - "length" - } - Some(spirv::GLOp::Distance) => { - inst.expect(base_wc + 2)?; - "distance" - } - Some(spirv::GLOp::Cross) => { - inst.expect(base_wc + 2)?; - "cross" - } - Some(spirv::GLOp::Normalize) => { - inst.expect(base_wc + 1)?; - "normalize" - } - Some(spirv::GLOp::Reflect) => { - inst.expect(base_wc + 2)?; - "reflect" - } + let gl_op = Glo::from_u32(inst_id).ok_or(Error::UnsupportedExtInst(inst_id))?; + let fun = match gl_op { + Glo::Round => Mf::Round, + Glo::Trunc => Mf::Trunc, + Glo::FAbs | Glo::SAbs => Mf::Abs, + Glo::FSign | Glo::SSign => Mf::Sign, + Glo::Floor => Mf::Floor, + Glo::Ceil => Mf::Ceil, + Glo::Fract => Mf::Fract, + Glo::Sin => Mf::Sin, + Glo::Cos => Mf::Cos, + Glo::Tan => Mf::Tan, + Glo::Asin => Mf::Asin, + Glo::Acos => Mf::Acos, + Glo::Atan => Mf::Atan, + Glo::Sinh => Mf::Sinh, + Glo::Cosh => Mf::Cosh, + Glo::Tanh => Mf::Tanh, + Glo::Atan2 => Mf::Atan2, + Glo::Pow => Mf::Pow, + Glo::Exp => Mf::Exp, + Glo::Log => Mf::Log, + Glo::Exp2 => Mf::Exp2, + Glo::Log2 => Mf::Log2, + Glo::Sqrt => Mf::Sqrt, + Glo::InverseSqrt => Mf::InverseSqrt, + Glo::Determinant => Mf::Determinant, + Glo::Modf => Mf::Modf, + Glo::FMin | Glo::UMin | Glo::SMin | Glo::NMin => Mf::Min, + Glo::FMax | Glo::UMax | Glo::SMax | Glo::NMax => Mf::Max, + Glo::FClamp | Glo::UClamp | Glo::SClamp | Glo::NClamp => Mf::Clamp, + Glo::FMix => Mf::Mix, + Glo::Step => Mf::Step, + Glo::SmoothStep => Mf::SmoothStep, + Glo::Fma => Mf::Fma, + Glo::Frexp => Mf::Frexp, //TODO: FrexpStruct? + Glo::Ldexp => Mf::Ldexp, + Glo::Length => Mf::Length, + Glo::Distance => Mf::Distance, + Glo::Cross => Mf::Cross, + Glo::Normalize => Mf::Normalize, + Glo::FaceForward => Mf::FaceForward, + Glo::Reflect => Mf::Reflect, _ => return Err(Error::UnsupportedExtInst(inst_id)), }; - let mut arguments = Vec::with_capacity((inst.wc - base_wc) as usize); - for _ in 0..arguments.capacity() { + let arg_count = fun.argument_count(); + inst.expect(base_wc + arg_count as u16)?; + let arg = { let arg_id = self.next()?; - arguments.push(self.lookup_expression.lookup(arg_id)?.handle); - } - let expr = crate::Expression::Call { - origin: crate::FunctionOrigin::External(name.to_string()), - arguments, + self.lookup_expression.lookup(arg_id)?.handle + }; + let arg1 = if arg_count > 1 { + let arg_id = self.next()?; + Some(self.lookup_expression.lookup(arg_id)?.handle) + } else { + None + }; + let arg2 = if arg_count > 2 { + let arg_id = self.next()?; + Some(self.lookup_expression.lookup(arg_id)?.handle) + } else { + None + }; + + let expr = crate::Expression::Math { + fun, + arg, + arg1, + arg2, }; self.lookup_expression.insert( result_id, @@ -1553,6 +1534,7 @@ impl> Parser { for dfc in self.deferred_function_calls.drain(..) { let dst_handle = *self.lookup_function.lookup(dfc.dst_id)?; let fun = match dfc.source { + DeferredSource::Undefined => unreachable!(), DeferredSource::Function(fun_handle) => module.functions.get_mut(fun_handle), DeferredSource::EntryPoint(stage, name) => { &mut module @@ -1562,13 +1544,10 @@ impl> Parser { .function } }; - match *fun.expressions.get_mut(dfc.expr_handle) { - crate::Expression::Call { - ref mut origin, - arguments: _, - } => *origin = crate::FunctionOrigin::Local(dst_handle), - _ => unreachable!(), - } + *fun.expressions.get_mut(dfc.expr_handle) = crate::Expression::Call { + function: dst_handle, + arguments: dfc.arguments, + }; } if !self.future_decor.is_empty() { diff --git a/src/front/wgsl/conv.rs b/src/front/wgsl/conv.rs index 1c986a3cd3..18de009aad 100644 --- a/src/front/wgsl/conv.rs +++ b/src/front/wgsl/conv.rs @@ -97,18 +97,7 @@ pub fn get_scalar_type(word: &str) -> Option<(crate::ScalarKind, crate::Bytes)> } } -pub fn get_intrinsic(word: &str) -> Option { - match word { - "any" => Some(crate::IntrinsicFunction::Any), - "all" => Some(crate::IntrinsicFunction::All), - "is_nan" => Some(crate::IntrinsicFunction::IsNan), - "is_inf" => Some(crate::IntrinsicFunction::IsInf), - "is_normal" => Some(crate::IntrinsicFunction::IsNormal), - _ => None, - } -} - -pub fn get_derivative(word: &str) -> Option { +pub fn map_derivative_axis(word: &str) -> Option { match word { "dpdx" => Some(crate::DerivativeAxis::X), "dpdy" => Some(crate::DerivativeAxis::Y), @@ -117,16 +106,73 @@ pub fn get_derivative(word: &str) -> Option { } } -// Returns argument count on success -pub fn get_standard_fun(word: &str) -> Option { +pub fn map_relational_fun(word: &str) -> Option { match word { - "abs" | "acos" | "asin" | "atan" | "ceil" | "cos" | "cosh" | "exp" | "exp2" | "floor" - | "fract" | "inverseSqrt" | "length" | "log" | "log2" | "normalize" | "round" | "sign" - | "sin" | "sinh" | "sqrt" | "tan" | "tanh" | "trunc" => Some(1), - "countOneBits" | "reverseBits" | "determinant" => Some(1), - "atan2" | "distance" | "frexp" | "ldexp" | "max" | "min" | "outerProduct" | "pow" - | "reflect" | "step" => Some(2), - "clamp" | "faceForward" | "fma" | "smoothStep" => Some(3), + "any" => Some(crate::RelationalFunction::Any), + "all" => Some(crate::RelationalFunction::All), + "isFinite" => Some(crate::RelationalFunction::IsFinite), + "isInf" => Some(crate::RelationalFunction::IsInf), + "isNan" => Some(crate::RelationalFunction::IsNan), + "isNormal" => Some(crate::RelationalFunction::IsNormal), _ => None, } } + +pub fn map_standard_fun(word: &str) -> Option { + use crate::MathFunction as Mf; + Some(match word { + // comparison + "abs" => Mf::Abs, + "min" => Mf::Min, + "max" => Mf::Max, + "clamp" => Mf::Clamp, + // trigonometry + "cos" => Mf::Cos, + "cosh" => Mf::Cosh, + "sin" => Mf::Sin, + "sinh" => Mf::Sinh, + "tan" => Mf::Tan, + "tanh" => Mf::Tanh, + "acos" => Mf::Acos, + "asin" => Mf::Asin, + "atan" => Mf::Atan, + "atan2" => Mf::Atan2, + // decomposition + "ceil" => Mf::Ceil, + "floor" => Mf::Floor, + "round" => Mf::Round, + "fract" => Mf::Fract, + "trunc" => Mf::Trunc, + "modf" => Mf::Modf, + "frexp" => Mf::Frexp, + "ldexp" => Mf::Ldexp, + // exponent + "exp" => Mf::Exp, + "exp2" => Mf::Exp2, + "log" => Mf::Log, + "log2" => Mf::Log2, + "pow" => Mf::Pow, + // geometry + "dot" => Mf::Dot, + "outerProduct" => Mf::Outer, + "cross" => Mf::Cross, + "distance" => Mf::Distance, + "length" => Mf::Length, + "normalize" => Mf::Normalize, + "faceForward" => Mf::FaceForward, + "reflect" => Mf::Reflect, + // computational + "sign" => Mf::Sign, + "fma" => Mf::Fma, + "mix" => Mf::Mix, + "step" => Mf::Step, + "smoothStep" => Mf::SmoothStep, + "sqrt" => Mf::Sqrt, + "inverseSqrt" => Mf::InverseSqrt, + "determinant" => Mf::Determinant, + // bits + "countOneBits" => Mf::CountOneBits, + "reverseBits" => Mf::ReverseBits, + _ => return None, + }) +} diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 98d4cc513e..4beeada220 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -336,40 +336,54 @@ impl Parser { ) -> Result)>, Error<'a>> { let mut lexer = lexer.clone(); - let external_function = if let Some(std_namespaces) = self.std_namespace.as_deref() { - std_namespaces.iter().all(|namespace| { - lexer.skip(Token::Word(namespace)) && lexer.skip(Token::DoubleColon) - }) - } else { - false - }; - - let origin = if external_function { - let function = lexer.next_ident()?; - crate::FunctionOrigin::External(function.to_string()) - } else if let Ok(function) = lexer.next_ident() { - if let Some(&function) = self.function_lookup.get(function) { - crate::FunctionOrigin::Local(function) + let name = lexer.next_ident()?; + //TODO: avoid code duplication with `parse_singular_expression` + Ok(if let Some(fun) = conv::map_standard_fun(name) { + lexer.expect(Token::Paren('('))?; + let arg_count = fun.argument_count(); + let arg = self.parse_general_expression(&mut lexer, ctx.reborrow())?; + let arg1 = if arg_count > 1 { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(&mut lexer, ctx.reborrow())?) } else { + None + }; + let arg2 = if arg_count > 1 { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(&mut lexer, ctx.reborrow())?) + } else { + None + }; + lexer.expect(Token::Paren(')'))?; + let expr = crate::Expression::Math { + fun, + arg, + arg1, + arg2, + }; + Some((expr, lexer)) + } else if let Some(&function) = self.function_lookup.get(name) { + if !lexer.skip(Token::Paren('(')) { return Ok(None); } - } else { - return Ok(None); - }; - if !lexer.skip(Token::Paren('(')) { - return Ok(None); - } - - let mut arguments = Vec::new(); - while !lexer.skip(Token::Paren(')')) { - if !arguments.is_empty() { - lexer.expect(Token::Separator(','))?; + let mut arguments = Vec::new(); + while !lexer.skip(Token::Paren(')')) { + if !arguments.is_empty() { + lexer.expect(Token::Separator(','))?; + } + let arg = self.parse_general_expression(&mut lexer, ctx.reborrow())?; + arguments.push(arg); } - let arg = self.parse_general_expression(&mut lexer, ctx.reborrow())?; - arguments.push(arg); - } - Ok(Some((crate::Expression::Call { origin, arguments }, lexer))) + + let expr = crate::Expression::Call { + function, + arguments, + }; + Some((expr, lexer)) + } else { + None + }) } fn parse_const_expression<'a>( @@ -605,12 +619,12 @@ impl Parser { expr: self.parse_singular_expression(lexer, ctx.reborrow())?, }), Token::Word(word) => { - if let Some(fun) = conv::get_intrinsic(word) { + if let Some(fun) = conv::map_relational_fun(word) { lexer.expect(Token::Paren('('))?; let argument = self.parse_general_expression(lexer, ctx.reborrow())?; lexer.expect(Token::Paren(')'))?; - Some(crate::Expression::Intrinsic { fun, argument }) - } else if let Some(axis) = conv::get_derivative(word) { + Some(crate::Expression::Relational { fun, argument }) + } else if let Some(axis) = conv::map_derivative_axis(word) { lexer.expect(Token::Paren('('))?; let expr = self.parse_general_expression(lexer, ctx.reborrow())?; lexer.expect(Token::Paren(')'))?; @@ -624,40 +638,32 @@ impl Parser { kind, convert: true, }) - } else if let Some(arg_count) = conv::get_standard_fun(word) { + } else if let Some(fun) = conv::map_standard_fun(word) { lexer.expect(Token::Paren('('))?; - let mut arguments = Vec::with_capacity(arg_count); - for i in 0..arg_count { - let arg = self.parse_general_expression(lexer, ctx.reborrow())?; - arguments.push(arg); - lexer.expect(if i + 1 == arg_count { - Token::Paren(')') - } else { - Token::Separator(',') - })?; - } - Some(crate::Expression::Call { - origin: crate::FunctionOrigin::External(word.to_string()), - arguments, + let arg_count = fun.argument_count(); + let arg = self.parse_general_expression(lexer, ctx.reborrow())?; + let arg1 = if arg_count > 1 { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(lexer, ctx.reborrow())?) + } else { + None + }; + let arg2 = if arg_count > 2 { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(lexer, ctx.reborrow())?) + } else { + None + }; + lexer.expect(Token::Paren(')'))?; + Some(crate::Expression::Math { + fun, + arg, + arg1, + arg2, }) } else { + // texture sampling match word { - "dot" => { - lexer.expect(Token::Paren('('))?; - let a = self.parse_general_expression(lexer, ctx.reborrow())?; - lexer.expect(Token::Separator(','))?; - let b = self.parse_general_expression(lexer, ctx.reborrow())?; - lexer.expect(Token::Paren(')'))?; - Some(crate::Expression::DotProduct(a, b)) - } - "cross" => { - lexer.expect(Token::Paren('('))?; - let a = self.parse_general_expression(lexer, ctx.reborrow())?; - lexer.expect(Token::Separator(','))?; - let b = self.parse_general_expression(lexer, ctx.reborrow())?; - lexer.expect(Token::Paren(')'))?; - Some(crate::Expression::CrossProduct(a, b)) - } "textureSample" => { lexer.expect(Token::Paren('('))?; let image_name = lexer.next_ident()?; diff --git a/src/lib.rs b/src/lib.rs index 7b141b5388..61b835c4b1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -516,19 +516,6 @@ pub enum BinaryOperator { ShiftRight, } -/// Built-in shader function. -#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] -#[cfg_attr(feature = "serialize", derive(Serialize))] -#[cfg_attr(feature = "deserialize", derive(Deserialize))] -pub enum IntrinsicFunction { - Any, - All, - IsNan, - IsInf, - IsFinite, - IsNormal, -} - /// Axis on which to compute a derivative. #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] @@ -539,17 +526,76 @@ pub enum DerivativeAxis { Width, } -/// Origin of a function to call. -#[derive(Clone, Debug, PartialEq)] +/// Built-in shader function for testing relation between values. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] -pub enum FunctionOrigin { - Local(Handle), - // External { - // namespace: String, // Maybe this should be a handle to a namespace Arena? - // function: String, - // }, - External(String), +pub enum RelationalFunction { + All, + Any, + IsNan, + IsInf, + IsFinite, + IsNormal, +} + +/// Built-in shader function for math. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum MathFunction { + // comparison + Abs, + Min, + Max, + Clamp, + // trigonometry + Cos, + Cosh, + Sin, + Sinh, + Tan, + Tanh, + Acos, + Asin, + Atan, + Atan2, + // decomposition + Ceil, + Floor, + Round, + Fract, + Trunc, + Modf, + Frexp, + Ldexp, + // exponent + Exp, + Exp2, + Log, + Log2, + Pow, + // geometry + Dot, + Outer, + Cross, + Distance, + Length, + Normalize, + FaceForward, + Reflect, + // computational + Sign, + Fma, + Mix, + Step, + SmoothStep, + Sqrt, + InverseSqrt, + Determinant, + // bits + CountOneBits, + ReverseBits, } /// Sampling modifier to control the level of detail. @@ -628,17 +674,26 @@ pub enum Expression { accept: Handle, reject: Handle, }, - /// Call an intrinsic function. - Intrinsic { - fun: IntrinsicFunction, + /// Compute the derivative on an axis. + Derivative { + axis: DerivativeAxis, + //modifier, + expr: Handle, + }, + /// Call a relational function. + Relational { + fun: RelationalFunction, argument: Handle, }, + /// Call a math function + Math { + fun: MathFunction, + arg: Handle, + arg1: Option>, + arg2: Option>, + }, /// Transpose of a matrix. Transpose(Handle), - /// Dot product between two vectors. - DotProduct(Handle, Handle), - /// Cross product between two vectors. - CrossProduct(Handle, Handle), /// Cast a simply type to another kind. As { /// Source expression, which can only be a scalar or a vector. @@ -648,15 +703,9 @@ pub enum Expression { /// True = conversion needs to take place; False = bitcast. convert: bool, }, - /// Compute the derivative on an axis. - Derivative { - axis: DerivativeAxis, - //modifier, - expr: Handle, - }, /// Call another function. Call { - origin: FunctionOrigin, + function: Handle, arguments: Vec>, }, /// Get the length of an array. diff --git a/src/proc/interface.rs b/src/proc/interface.rs index e3835f263a..29f1176847 100644 --- a/src/proc/interface.rs +++ b/src/proc/interface.rs @@ -88,36 +88,37 @@ where self.traverse_expr(accept); self.traverse_expr(reject); } - E::Intrinsic { argument, .. } => { + E::Derivative { expr, .. } => { + self.traverse_expr(expr); + } + E::Relational { argument, .. } => { self.traverse_expr(argument); } + E::Math { + arg, arg1, arg2, .. + } => { + self.traverse_expr(arg); + if let Some(arg) = arg1 { + self.traverse_expr(arg); + } + if let Some(arg) = arg2 { + self.traverse_expr(arg); + } + } E::Transpose(matrix) => { self.traverse_expr(matrix); } - E::DotProduct(left, right) => { - self.traverse_expr(left); - self.traverse_expr(right); - } - E::CrossProduct(left, right) => { - self.traverse_expr(left); - self.traverse_expr(right); - } E::As { expr, .. } => { self.traverse_expr(expr); } - E::Derivative { expr, .. } => { - self.traverse_expr(expr); - } E::Call { - ref origin, + function, ref arguments, } => { for &argument in arguments { self.traverse_expr(argument); } - if let crate::FunctionOrigin::Local(fun) = *origin { - self.visitor.visit_fun(fun); - } + self.visitor.visit_fun(function); } E::ArrayLength(expr) => { self.traverse_expr(expr); diff --git a/src/proc/mod.rs b/src/proc/mod.rs index d8388b4c5c..23ff413734 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -67,3 +67,62 @@ impl crate::TypeInner { } } } + +impl crate::MathFunction { + pub fn argument_count(&self) -> usize { + match *self { + // comparison + Self::Abs => 1, + Self::Min => 2, + Self::Max => 2, + Self::Clamp => 3, + // trigonometry + Self::Cos => 1, + Self::Cosh => 1, + Self::Sin => 1, + Self::Sinh => 1, + Self::Tan => 1, + Self::Tanh => 1, + Self::Acos => 1, + Self::Asin => 1, + Self::Atan => 1, + Self::Atan2 => 2, + // decomposition + Self::Ceil => 1, + Self::Floor => 1, + Self::Round => 1, + Self::Fract => 1, + Self::Trunc => 1, + Self::Modf => 2, + Self::Frexp => 2, + Self::Ldexp => 2, + // exponent + Self::Exp => 1, + Self::Exp2 => 1, + Self::Log => 1, + Self::Log2 => 1, + Self::Pow => 2, + // geometry + Self::Dot => 2, + Self::Outer => 2, + Self::Cross => 2, + Self::Distance => 2, + Self::Length => 1, + Self::Normalize => 1, + Self::FaceForward => 3, + Self::Reflect => 2, + // computational + Self::Sign => 1, + Self::Fma => 3, + Self::Mix => 3, + Self::Step => 2, + Self::SmoothStep => 3, + Self::Sqrt => 1, + Self::InverseSqrt => 1, + Self::Determinant => 1, + // bits + Self::CountOneBits => 1, + Self::ReverseBits => 1, + } + } +} diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index b09893c179..8fe370163b 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -246,7 +246,122 @@ impl Typifier { | crate::BinaryOperator::ShiftRight => self.resolutions[left.index()].clone(), }, crate::Expression::Select { accept, .. } => self.resolutions[accept.index()].clone(), - crate::Expression::Intrinsic { .. } => unimplemented!(), + crate::Expression::Derivative { axis: _, expr } => { + self.resolutions[expr.index()].clone() + } + crate::Expression::Relational { .. } => Resolution::Value(crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + width: 4, + }), + crate::Expression::Math { + fun, + arg, + arg1, + arg2: _, + } => { + use crate::MathFunction as Mf; + match fun { + // comparison + Mf::Abs | + Mf::Min | + Mf::Max | + Mf::Clamp | + // trigonometry + Mf::Cos | + Mf::Cosh | + Mf::Sin | + Mf::Sinh | + Mf::Tan | + Mf::Tanh | + Mf::Acos | + Mf::Asin | + Mf::Atan | + Mf::Atan2 | + // decomposition + Mf::Ceil | + Mf::Floor | + Mf::Round | + Mf::Fract | + Mf::Trunc | + Mf::Modf | + Mf::Frexp | + Mf::Ldexp | + // exponent + Mf::Exp | + Mf::Exp2 | + Mf::Log | + Mf::Log2 | + Mf::Pow => self.resolutions[arg.index()].clone(), + // geometry + Mf::Dot => match *self.get(arg, types) { + crate::TypeInner::Vector { + kind, + size: _, + width, + } => Resolution::Value(crate::TypeInner::Scalar { kind, width }), + ref other => { + return Err(ResolveError::IncompatibleOperand { + op: "dot product".to_string(), + operand: format!("{:?}", other), + }) + } + }, + Mf::Outer => { + let arg1 = arg1.ok_or_else(|| ResolveError::IncompatibleOperand { + op: "outer product".to_string(), + operand: "".to_string(), + })?; + match (self.get(arg, types), self.get(arg1,types)) { + (&crate::TypeInner::Vector {kind: _, size: columns,width}, &crate::TypeInner::Vector{ size: rows, .. }) => Resolution::Value(crate::TypeInner::Matrix { columns, rows, width }), + (left, right) => { + return Err(ResolveError::IncompatibleOperands { + op: "outer product".to_string(), + left: format!("{:?}", left), + right: format!("{:?}", right), + }) + } + } + }, + Mf::Cross => self.resolutions[arg.index()].clone(), + Mf::Distance | + Mf::Length => match *self.get(arg, types) { + crate::TypeInner::Scalar {width,kind} | + crate::TypeInner::Vector {width,kind,size:_} => Resolution::Value(crate::TypeInner::Scalar { kind, width }), + ref other => { + return Err(ResolveError::IncompatibleOperand { + op: format!("{:?}", fun), + operand: format!("{:?}", other), + }) + } + }, + Mf::Normalize | + Mf::FaceForward | + Mf::Reflect => self.resolutions[arg.index()].clone(), + // computational + Mf::Sign | + Mf::Fma | + Mf::Mix | + Mf::Step | + Mf::SmoothStep | + Mf::Sqrt | + Mf::InverseSqrt => self.resolutions[arg.index()].clone(), + Mf::Determinant => match *self.get(arg, types) { + crate::TypeInner::Matrix { + width, + .. + } => Resolution::Value(crate::TypeInner::Scalar { kind: crate::ScalarKind::Float, width }), + ref other => { + return Err(ResolveError::IncompatibleOperand { + op: "determinant".to_string(), + operand: format!("{:?}", other), + }) + } + }, + // bits + Mf::CountOneBits | + Mf::ReverseBits => self.resolutions[arg.index()].clone(), + } + } crate::Expression::Transpose(expr) => match *self.get(expr, types) { crate::TypeInner::Matrix { columns, @@ -264,20 +379,6 @@ impl Typifier { }) } }, - crate::Expression::DotProduct(left_expr, _) => match *self.get(left_expr, types) { - crate::TypeInner::Vector { - kind, - size: _, - width, - } => Resolution::Value(crate::TypeInner::Scalar { kind, width }), - ref other => { - return Err(ResolveError::IncompatibleOperand { - op: "dot product".to_string(), - operand: format!("{:?}", other), - }) - } - }, - crate::Expression::CrossProduct(_, _) => unimplemented!(), crate::Expression::As { expr, kind, @@ -298,45 +399,11 @@ impl Typifier { }) } }, - crate::Expression::Derivative { .. } => unimplemented!(), crate::Expression::Call { - origin: crate::FunctionOrigin::External(ref name), - ref arguments, - } => match name.as_str() { - "distance" | "length" => match *self.get(arguments[0], types) { - crate::TypeInner::Vector { kind, width, .. } - | crate::TypeInner::Scalar { kind, width } => { - Resolution::Value(crate::TypeInner::Scalar { kind, width }) - } - ref other => { - return Err(ResolveError::IncompatibleOperand { - op: name.clone(), - operand: format!("{:?}", other), - }) - } - }, - "dot" => match *self.get(arguments[0], types) { - crate::TypeInner::Vector { kind, width, .. } => { - Resolution::Value(crate::TypeInner::Scalar { kind, width }) - } - ref other => { - return Err(ResolveError::IncompatibleOperand { - op: name.clone(), - operand: format!("{:?}", other), - }) - } - }, - //Note: `cross` is here too, we still need to figure out what to do with it - "abs" | "atan2" | "cos" | "sin" | "floor" | "inverse" | "normalize" | "min" - | "max" | "reflect" | "pow" | "clamp" | "fclamp" | "mix" | "step" - | "smoothstep" | "cross" => self.resolutions[arguments[0].index()].clone(), - _ => return Err(ResolveError::FunctionNotDefined { name: name.clone() }), - }, - crate::Expression::Call { - origin: crate::FunctionOrigin::Local(handle), + function, arguments: _, } => { - let ty = ctx.functions[handle] + let ty = ctx.functions[function] .return_type .ok_or(ResolveError::FunctionReturnsVoid)?; Resolution::Handle(ty) diff --git a/test-data/boids.wgsl b/test-data/boids.wgsl index 9bc98052e9..ec23768870 100644 --- a/test-data/boids.wgsl +++ b/test-data/boids.wgsl @@ -23,10 +23,10 @@ import "GLSL.std.450" as std; [[stage(vertex)]] fn main() -> void { - var angle : f32 = -std::atan2(a_particleVel.x, a_particleVel.y); + var angle : f32 = -atan2(a_particleVel.x, a_particleVel.y); var pos : vec2 = vec2( - (a_pos.x * std::cos(angle)) - (a_pos.y * std::sin(angle)), - (a_pos.x * std::sin(angle)) + (a_pos.y * std::cos(angle))); + (a_pos.x * cos(angle)) - (a_pos.y * sin(angle)), + (a_pos.x * sin(angle)) + (a_pos.y * cos(angle))); gl_Position = vec4(pos + a_particlePos, 0.0, 1.0); return; } @@ -97,14 +97,14 @@ fn main() -> void { pos = particlesA.particles[i].pos.xy; vel = particlesA.particles[i].vel.xy; - if (std::distance(pos, vPos) < params.rule1Distance) { + if (distance(pos, vPos) < params.rule1Distance) { cMass = cMass + pos; cMassCount = cMassCount + 1; } - if (std::distance(pos, vPos) < params.rule2Distance) { + if (distance(pos, vPos) < params.rule2Distance) { colVel = colVel - (pos - vPos); } - if (std::distance(pos, vPos) < params.rule3Distance) { + if (distance(pos, vPos) < params.rule3Distance) { cVel = cVel + vel; cVelCount = cVelCount + 1; } @@ -124,7 +124,7 @@ fn main() -> void { (cVel * params.rule3Scale); # clamp velocity for a more pleasing simulation - vVel = std::normalize(vVel) * std::fclamp(std::length(vVel), 0.0, 0.1); + vVel = normalize(vVel) * clamp(length(vVel), 0.0, 0.1); # kinematic update vPos = vPos + (vVel * params.deltaT);