diff --git a/src/back/glsl.rs b/src/back/glsl.rs index 75780860e5..23187ffaeb 100644 --- a/src/back/glsl.rs +++ b/src/back/glsl.rs @@ -756,7 +756,6 @@ fn write_statement<'a, 'b>( indent: usize, ) -> Result { Ok(match sta { - Statement::Empty => String::new(), Statement::Block(block) => block .iter() .map(|sta| write_statement(sta, module, builder, indent)) @@ -1173,12 +1172,19 @@ fn write_expression<'a, 'b>( let value_expr = write_expression(&builder.expressions[expr], module, builder)?; let (source_kind, ty_expr) = match *builder.typifier.get(expr, &module.types) { - TypeInner::Scalar { width, kind } => ( - kind, + TypeInner::Scalar { + width, + kind: source_kind, + } => ( + source_kind, Cow::Borrowed(map_scalar(kind, width, builder.manager)?.full), ), - TypeInner::Vector { width, kind, size } => ( - kind, + TypeInner::Vector { + width, + kind: source_kind, + size, + } => ( + source_kind, Cow::Owned(format!( "{}vec{}", map_scalar(kind, width, builder.manager)?.prefix, diff --git a/src/back/msl.rs b/src/back/msl.rs index 618d258c6b..78e091f87c 100644 --- a/src/back/msl.rs +++ b/src/back/msl.rs @@ -71,15 +71,15 @@ pub enum Error { BadName(String), UnexpectedGlobalType(Handle), UnimplementedBindTarget(BindTarget), - UnexpectedIndexing(crate::Expression), UnsupportedCompose(Handle), UnsupportedBinaryOp(crate::BinaryOperator), UnexpectedSampleLevel(crate::SampleLevel), UnsupportedCall(String), - UnsupportedExpression(crate::Expression), + UnsupportedDynamicArrayLength, UnableToReturnValue(Handle), - UnsupportedStatement(crate::Statement), AccessIndexExceedsStaticLength(u32, u32), + /// The source IR is not valid. + Validation, } impl From for Error { @@ -370,6 +370,24 @@ fn separate(is_last: bool) -> &'static str { } impl Writer { + fn put_call( + &mut self, + name: &str, + parameters: &[Handle], + function: &crate::Function, + module: &crate::Module, + ) -> Result<(), Error> { + write!(self.out, "{}(", name)?; + for (i, &handle) in parameters.iter().enumerate() { + if i != 0 { + write!(self.out, ", ")?; + } + self.put_expression(handle, function, module)?; + } + write!(self.out, ")")?; + Ok(()) + } + fn put_expression( &mut self, expr_handle: Handle, @@ -381,19 +399,9 @@ impl Writer { match *expression { crate::Expression::Access { base, index } => { self.put_expression(base, function, module)?; - match *self.typifier.get(base, &module.types) { - crate::TypeInner::Array { .. } => { - //TODO: add size check - self.out.write_str("[")?; - self.put_expression(index, function, module)?; - self.out.write_str("]")?; - } - _ => { - return Err(Error::UnexpectedIndexing( - function.expressions[base].clone(), - )) - } - } + self.out.write_str("[")?; + self.put_expression(index, function, module)?; + self.out.write_str("]")?; } crate::Expression::AccessIndex { base, index } => { self.put_expression(base, function, module)?; @@ -419,9 +427,7 @@ impl Writer { write!(self.out, "[{}]", index)?; } _ => { - return Err(Error::UnexpectedIndexing( - function.expressions[base].clone(), - )) + // unexpected indexing, should fail validation } } } @@ -432,26 +438,22 @@ impl Writer { crate::TypeInner::Vector { size, kind, .. } => { write!( self.out, - "{}{}(", + "{}{}", scalar_kind_string(kind), vector_size_string(size) )?; - for (i, &handle) in components.iter().enumerate() { - if i != 0 { - write!(self.out, ", ")?; - } - self.put_expression(handle, function, module)?; - } - write!(self.out, ")")?; + self.put_call("", components, function, module)?; } crate::TypeInner::Scalar { width: 4, kind } if components.len() == 1 => { - write!(self.out, "{}(", scalar_kind_string(kind))?; - self.put_expression(components[0], function, module)?; - write!(self.out, ")")?; + self.put_call(scalar_kind_string(kind), components, function, module)?; } _ => return Err(Error::UnsupportedCompose(ty)), } } + crate::Expression::FunctionParameter(index) => { + let name = Name::from(ParameterIndex(index as usize)); + write!(self.out, "{}", name)?; + } crate::Expression::GlobalVariable(handle) => { let var = &module.global_variables[handle]; match var.class { @@ -480,36 +482,6 @@ impl Writer { //write!(self.out, "*")?; self.put_expression(pointer, function, module)?; } - crate::Expression::Unary { op, expr } => { - let op_str = match op { - crate::UnaryOperator::Negate => "-", - crate::UnaryOperator::Not => "!", - }; - write!(self.out, "{}", op_str)?; - self.put_expression(expr, function, module)?; - } - crate::Expression::Binary { op, left, right } => { - let op_str = match op { - crate::BinaryOperator::Add => "+", - crate::BinaryOperator::Subtract => "-", - crate::BinaryOperator::Multiply => "*", - crate::BinaryOperator::Divide => "/", - crate::BinaryOperator::Modulo => "%", - crate::BinaryOperator::Equal => "==", - crate::BinaryOperator::NotEqual => "!=", - crate::BinaryOperator::Less => "<", - crate::BinaryOperator::LessEqual => "<=", - crate::BinaryOperator::Greater => "==", - crate::BinaryOperator::GreaterEqual => ">=", - crate::BinaryOperator::And => "&", - other => return Err(Error::UnsupportedBinaryOp(other)), - }; - //write!(self.out, "(")?; - self.put_expression(left, function, module)?; - write!(self.out, " {} ", op_str)?; - self.put_expression(right, function, module)?; - //write!(self.out, ")")?; - } crate::Expression::ImageSample { image, sampler, @@ -564,55 +536,113 @@ impl Writer { } write!(self.out, ")")?; } + crate::Expression::Unary { op, expr } => { + let op_str = match op { + crate::UnaryOperator::Negate => "-", + crate::UnaryOperator::Not => "!", + }; + write!(self.out, "{}", op_str)?; + self.put_expression(expr, function, module)?; + } + crate::Expression::Binary { op, left, right } => { + let op_str = match op { + crate::BinaryOperator::Add => "+", + crate::BinaryOperator::Subtract => "-", + crate::BinaryOperator::Multiply => "*", + crate::BinaryOperator::Divide => "/", + crate::BinaryOperator::Modulo => "%", + crate::BinaryOperator::Equal => "==", + crate::BinaryOperator::NotEqual => "!=", + crate::BinaryOperator::Less => "<", + crate::BinaryOperator::LessEqual => "<=", + crate::BinaryOperator::Greater => "==", + crate::BinaryOperator::GreaterEqual => ">=", + crate::BinaryOperator::And => "&", + other => return Err(Error::UnsupportedBinaryOp(other)), + }; + //write!(self.out, "(")?; + self.put_expression(left, function, module)?; + write!(self.out, " {} ", op_str)?; + self.put_expression(right, function, module)?; + //write!(self.out, ")")?; + } + crate::Expression::Intrinsic { fun, argument } => { + let op = match fun { + crate::IntrinsicFunction::Any => "any", + crate::IntrinsicFunction::All => "all", + crate::IntrinsicFunction::IsNan => "", + crate::IntrinsicFunction::IsInf => "", + crate::IntrinsicFunction::IsFinite => "", + crate::IntrinsicFunction::IsNormal => "", + }; + self.put_call(op, &[argument], function, module)?; + } + crate::Expression::Transpose(expr) => { + self.put_call("transpose", &[expr], function, module)?; + } + crate::Expression::DotProduct(a, b) => { + self.put_call("dot", &[a, b], function, module)?; + } + crate::Expression::CrossProduct(a, b) => { + self.put_call("cross", &[a, b], function, module)?; + } + crate::Expression::As { + expr, + kind, + convert, + } => { + let scalar = scalar_kind_string(kind); + let size = match *self.typifier.get(expr, &module.types) { + crate::TypeInner::Scalar { .. } => "", + crate::TypeInner::Vector { size, .. } => vector_size_string(size), + _ => return Err(Error::Validation), + }; + let op = if convert { "static_cast" } else { "as_type" }; + write!(self.out, "{}<{}{}>(", op, scalar, size)?; + self.put_expression(expr, function, module)?; + write!(self.out, ")")?; + } + crate::Expression::Derivative { axis, expr } => { + let op = match axis { + crate::DerivativeAxis::X => "dfdx", + crate::DerivativeAxis::Y => "dfdy", + crate::DerivativeAxis::Width => "fwidth", + }; + self.put_call(op, &[expr], function, module)?; + } + crate::Expression::Call { + origin: crate::FunctionOrigin::Local(handle), + ref arguments, + } => { + let name = module.functions[handle].name.or_index(handle); + write!(self.out, "{}", name)?; + self.put_call("", arguments, function, module)?; + } crate::Expression::Call { origin: crate::FunctionOrigin::External(ref name), ref arguments, } => match name.as_str() { - "cos" | "normalize" | "sin" => { - write!(self.out, "{}(", name)?; - self.put_expression(arguments[0], function, module)?; - write!(self.out, ")")?; + "atan2" | "cos" | "distance" | "length" | "mix" | "normalize" | "sin" => { + self.put_call(name, arguments, function, module)?; } "fclamp" => { - write!(self.out, "clamp(")?; - self.put_expression(arguments[0], function, module)?; - write!(self.out, ", ")?; - self.put_expression(arguments[1], function, module)?; - write!(self.out, ", ")?; - self.put_expression(arguments[2], function, module)?; - write!(self.out, ")")?; - } - "atan2" => { - write!(self.out, "{}(", name)?; - self.put_expression(arguments[0], function, module)?; - write!(self.out, ", ")?; - self.put_expression(arguments[1], function, module)?; - write!(self.out, ")")?; - } - "distance" => { - write!(self.out, "distance(")?; - self.put_expression(arguments[0], function, module)?; - write!(self.out, ", ")?; - self.put_expression(arguments[1], function, module)?; - write!(self.out, ")")?; - } - "length" => { - write!(self.out, "length(")?; - self.put_expression(arguments[0], function, module)?; - write!(self.out, ")")?; - } - "mix" => { - write!(self.out, "mix(")?; - self.put_expression(arguments[0], function, module)?; - write!(self.out, ", ")?; - self.put_expression(arguments[1], function, module)?; - write!(self.out, ", ")?; - self.put_expression(arguments[2], function, module)?; - write!(self.out, ")")?; + self.put_call("clamp", arguments, function, module)?; } other => return Err(Error::UnsupportedCall(other.to_owned())), }, - ref other => return Err(Error::UnsupportedExpression(other.clone())), + crate::Expression::ArrayLength(expr) => { + let size = match *self.typifier.get(expr, &module.types) { + crate::TypeInner::Array { + size: crate::ArraySize::Static(size), + .. + } => size, + crate::TypeInner::Array { .. } => { + return Err(Error::UnsupportedDynamicArrayLength) + } + _ => return Err(Error::Validation), + }; + write!(self.out, "{}", size)?; + } } Ok(()) } @@ -657,79 +687,100 @@ impl Writer { Ok(()) } - fn put_statement( + fn put_block( &mut self, level: Level, - statement: &crate::Statement, + statements: &[crate::Statement], function: &crate::Function, - has_output: bool, module: &crate::Module, ) -> Result<(), Error> { - log::trace!("statement[{}] {:?}", level.0, statement); - match *statement { - crate::Statement::Empty => {} - crate::Statement::If { - condition, - ref accept, - ref reject, - } => { - write!(self.out, "{}if (", level)?; - self.put_expression(condition, function, module)?; - writeln!(self.out, ") {{")?; - for s in accept { - self.put_statement(level.next(), s, function, has_output, module)?; - } - if !reject.is_empty() { - writeln!(self.out, "{}}} else {{", level)?; - for s in reject { - self.put_statement(level.next(), s, function, has_output, module)?; + for statement in statements { + log::trace!("statement[{}] {:?}", level.0, statement); + match *statement { + crate::Statement::Block(ref block) => { + if !block.is_empty() { + writeln!(self.out, "{}{{", level)?; + self.put_block(level.next(), block, function, module)?; + writeln!(self.out, "{}}}", level)?; } } - writeln!(self.out, "{}}}", level)?; - } - crate::Statement::Loop { - ref body, - ref continuing, - } => { - writeln!(self.out, "{}while(true) {{", level)?; - for s in body { - self.put_statement(level.next(), s, function, has_output, module)?; - } - if !continuing.is_empty() { - //TODO - } - writeln!(self.out, "{}}}", level)?; - } - crate::Statement::Store { pointer, value } => { - //write!(self.out, "\t*")?; - write!(self.out, "{}", level)?; - self.put_expression(pointer, function, module)?; - write!(self.out, " = ")?; - self.put_expression(value, function, module)?; - writeln!(self.out, ";")?; - } - crate::Statement::Break => { - writeln!(self.out, "{}break;", level)?; - } - crate::Statement::Continue => { - writeln!(self.out, "{}continue;", level)?; - } - crate::Statement::Return { value } => { - write!(self.out, "{}return ", level)?; - match value { - None if has_output => self.out.write_str(OUTPUT_STRUCT_NAME)?, - None => {} - Some(expr_handle) if has_output => { - return Err(Error::UnableToReturnValue(expr_handle)); - } - Some(expr_handle) => { - self.put_expression(expr_handle, function, module)?; + crate::Statement::If { + condition, + ref accept, + ref reject, + } => { + write!(self.out, "{}if (", level)?; + self.put_expression(condition, function, module)?; + writeln!(self.out, ") {{")?; + self.put_block(level.next(), accept, function, module)?; + if !reject.is_empty() { + writeln!(self.out, "{}}} else {{", level)?; + self.put_block(level.next(), reject, function, module)?; } + writeln!(self.out, "{}}}", level)?; + } + crate::Statement::Switch { + selector, + ref cases, + ref default, + } => { + write!(self.out, "{}switch(", level)?; + self.put_expression(selector, function, module)?; + writeln!(self.out, ") {{")?; + let lcase = level.next(); + for (&value, &(ref block, ref fall_through)) in cases.iter() { + writeln!(self.out, "{}case {}: {{", lcase, value)?; + self.put_block(lcase.next(), block, function, module)?; + if fall_through.is_none() { + writeln!(self.out, "{}break;", lcase.next())?; + } + writeln!(self.out, "{}}}", lcase)?; + } + writeln!(self.out, "{}default: {{", lcase)?; + self.put_block(lcase.next(), default, function, module)?; + writeln!(self.out, "{}}}", lcase)?; + writeln!(self.out, "{}}}", level)?; + } + crate::Statement::Loop { + ref body, + ref continuing, + } => { + writeln!(self.out, "{}while(true) {{", level)?; + self.put_block(level.next(), body, function, module)?; + if !continuing.is_empty() { + //TODO + } + writeln!(self.out, "{}}}", level)?; + } + crate::Statement::Break => { + writeln!(self.out, "{}break;", level)?; + } + crate::Statement::Continue => { + writeln!(self.out, "{}continue;", level)?; + } + crate::Statement::Return { value } => { + write!(self.out, "{}return ", level)?; + match value { + None => self.out.write_str(OUTPUT_STRUCT_NAME)?, + Some(expr_handle) => { + self.put_expression(expr_handle, function, module)?; + } + } + writeln!(self.out, ";")?; + } + crate::Statement::Kill => { + writeln!(self.out, "{}discard_fragment();", level)?; + } + crate::Statement::Store { pointer, value } => { + //write!(self.out, "\t*")?; + write!(self.out, "{}", level)?; + self.put_expression(pointer, function, module)?; + write!(self.out, " = ")?; + self.put_expression(value, function, module)?; + writeln!(self.out, ";")?; } - writeln!(self.out, ";")?; } - ref other => return Err(Error::UnsupportedStatement(other.clone())), - }; + } Ok(()) } @@ -924,9 +975,7 @@ impl Writer { } writeln!(self.out, ";")?; } - for statement in fun.body.iter() { - self.put_statement(Level(1), statement, fun, false, module)?; - } + self.put_block(Level(1), &fun.body, fun, module)?; writeln!(self.out, "}}")?; } @@ -1132,13 +1181,12 @@ impl Writer { } writeln!(self.out, ") {{")?; - let has_output = match stage { + match stage { crate::ShaderStage::Vertex | crate::ShaderStage::Fragment => { writeln!(self.out, "\t{} {};", output_name, OUTPUT_STRUCT_NAME)?; - true } - crate::ShaderStage::Compute => false, - }; + crate::ShaderStage::Compute => {} + } for (local_handle, local) in fun.local_variables.iter() { let ty_name = module.types[local.ty].name.or_index(local.ty); write!( @@ -1153,9 +1201,7 @@ impl Writer { } writeln!(self.out, ";")?; } - for statement in fun.body.iter() { - self.put_statement(Level(1), statement, fun, has_output, module)?; - } + self.put_block(Level(1), &fun.body, fun, module)?; writeln!(self.out, "}}")?; } diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index a0939d73b5..faca512f64 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -321,19 +321,7 @@ impl Writer { function_type, )); - let mut block = Block::new(); - let id = self.generate_id(); - block.label = Some(super::instructions::instruction_label(id)); - for statement in ir_function.body.iter() { - self.write_function_statement( - ir_module, - ir_function, - &statement, - &mut block, - &mut function, - ); - } - function.blocks.push(block); + let id = self.write_block(&ir_function.body, ir_module, ir_function, &mut function); function.to_words(&mut self.logical_layout.function_definitions); super::instructions::instruction_function_end() @@ -996,74 +984,111 @@ impl Writer { } } - fn write_function_statement( + fn write_block( &mut self, + statements: &[crate::Statement], ir_module: &crate::Module, ir_function: &crate::Function, - statement: &crate::Statement, - block: &mut Block, function: &mut Function, - ) { - match statement { - crate::Statement::Return { value } => match ir_function.return_type { - Some(_) => { - let expression = &ir_function.expressions[value.unwrap()]; - let (id, ty) = self - .write_expression(ir_module, ir_function, expression, block, function) + ) -> spirv::Word { + let mut block = Block::new(); + let id = self.generate_id(); + block.label = Some(super::instructions::instruction_label(id)); + + for statement in statements { + match *statement { + crate::Statement::Block(ref ir_block) => { + if !ir_block.is_empty() { + //TODO: link the block with `OpBranch` + self.write_block(ir_block, ir_module, ir_function, function); + } + } + crate::Statement::Return { value } => { + block.termination = Some(match ir_function.return_type { + Some(_) => { + let expression = &ir_function.expressions[value.unwrap()]; + let (id, ty) = self + .write_expression( + ir_module, + ir_function, + expression, + &mut block, + function, + ) + .unwrap(); + + let id = match *expression { + crate::Expression::LocalVariable(_) => { + let load_id = self.generate_id(); + let value_ty_id = self.get_type_id( + &ir_module.types, + LookupType::Handle(ty.unwrap()), + ); + block.body.push(super::instructions::instruction_load( + value_ty_id, + load_id, + id, + None, + )); + load_id + } + _ => id, + }; + super::instructions::instruction_return_value(id) + } + None => super::instructions::instruction_return(), + }); + } + crate::Statement::Store { pointer, value } => { + let pointer_expression = &ir_function.expressions[pointer]; + let value_expression = &ir_function.expressions[value]; + let (pointer_id, _) = self + .write_expression( + ir_module, + ir_function, + pointer_expression, + &mut block, + function, + ) + .unwrap(); + let (value_id, value_ty) = self + .write_expression( + ir_module, + ir_function, + value_expression, + &mut block, + function, + ) .unwrap(); - let id = match expression { + let value_id = match value_expression { crate::Expression::LocalVariable(_) => { let load_id = self.generate_id(); - let value_ty_id = - self.get_type_id(&ir_module.types, LookupType::Handle(ty.unwrap())); + let value_ty_id = self.get_type_id( + &ir_module.types, + LookupType::Handle(value_ty.unwrap()), + ); block.body.push(super::instructions::instruction_load( value_ty_id, load_id, - id, + value_id, None, )); load_id } - _ => id, + _ => value_id, }; - block.termination = Some(super::instructions::instruction_return_value(id)); + + block.body.push(super::instructions::instruction_store( + pointer_id, value_id, None, + )); } - None => block.termination = Some(super::instructions::instruction_return()), - }, - crate::Statement::Store { pointer, value } => { - let pointer_expression = &ir_function.expressions[*pointer]; - let value_expression = &ir_function.expressions[*value]; - let (pointer_id, _) = self - .write_expression(ir_module, ir_function, pointer_expression, block, function) - .unwrap(); - let (value_id, value_ty) = self - .write_expression(ir_module, ir_function, value_expression, block, function) - .unwrap(); - - let value_id = match value_expression { - crate::Expression::LocalVariable(_) => { - let load_id = self.generate_id(); - let value_ty_id = self - .get_type_id(&ir_module.types, LookupType::Handle(value_ty.unwrap())); - block.body.push(super::instructions::instruction_load( - value_ty_id, - load_id, - value_id, - None, - )); - load_id - } - _ => value_id, - }; - - block.body.push(super::instructions::instruction_store( - pointer_id, value_id, None, - )); + _ => unimplemented!("{:?}", statement), } - crate::Statement::Empty => {} - _ => unimplemented!("{:?}", statement), } + + function.blocks.push(block); + id } fn write_physical_layout(&mut self) { diff --git a/src/front/glsl/parser.rs b/src/front/glsl/parser.rs index e6ee55635d..e0dff7325d 100644 --- a/src/front/glsl/parser.rs +++ b/src/front/glsl/parser.rs @@ -726,7 +726,6 @@ pomelo! { extra.context.add_local_var(id, exp); } match statements.len() { - 0 => Statement::Empty, 1 => statements.remove(0), _ => Statement::Block(statements), } @@ -774,7 +773,7 @@ pomelo! { statement_list ::= statement_list(mut ss) statement(s) { ss.push(s); ss } expression_statement ::= Semicolon { - Statement::Empty + Statement::Block(Vec::new()) } expression_statement ::= expression(mut e) Semicolon { match e.statements.len() { diff --git a/src/front/wgsl/conv.rs b/src/front/wgsl/conv.rs new file mode 100644 index 0000000000..0b25ad9435 --- /dev/null +++ b/src/front/wgsl/conv.rs @@ -0,0 +1,116 @@ +use super::Error; + +pub fn map_storage_class(word: &str) -> Result> { + match word { + "in" => Ok(crate::StorageClass::Input), + "out" => Ok(crate::StorageClass::Output), + "uniform" => Ok(crate::StorageClass::Uniform), + "storage_buffer" => Ok(crate::StorageClass::StorageBuffer), + _ => Err(Error::UnknownStorageClass(word)), + } +} + +pub fn map_built_in(word: &str) -> Result> { + Ok(match word { + // vertex + "position" => crate::BuiltIn::Position, + "vertex_idx" => crate::BuiltIn::VertexIndex, + "instance_idx" => crate::BuiltIn::InstanceIndex, + // fragment + "front_facing" => crate::BuiltIn::FrontFacing, + "frag_coord" => crate::BuiltIn::FragCoord, + "frag_depth" => crate::BuiltIn::FragDepth, + // compute + "global_invocation_id" => crate::BuiltIn::GlobalInvocationId, + "local_invocation_id" => crate::BuiltIn::LocalInvocationId, + "local_invocation_idx" => crate::BuiltIn::LocalInvocationIndex, + _ => return Err(Error::UnknownBuiltin(word)), + }) +} + +pub fn map_shader_stage(word: &str) -> Result> { + match word { + "vertex" => Ok(crate::ShaderStage::Vertex), + "fragment" => Ok(crate::ShaderStage::Fragment), + "compute" => Ok(crate::ShaderStage::Compute), + _ => Err(Error::UnknownShaderStage(word)), + } +} + +pub fn map_interpolation(word: &str) -> Result> { + match word { + "linear" => Ok(crate::Interpolation::Linear), + "flat" => Ok(crate::Interpolation::Flat), + "centroid" => Ok(crate::Interpolation::Centroid), + "sample" => Ok(crate::Interpolation::Sample), + "perspective" => Ok(crate::Interpolation::Perspective), + _ => Err(Error::UnknownDecoration(word)), + } +} + +pub fn map_storage_format(word: &str) -> Result> { + use crate::StorageFormat as Sf; + Ok(match word { + "r8unorm" => Sf::R8Unorm, + "r8snorm" => Sf::R8Snorm, + "r8uint" => Sf::R8Uint, + "r8sint" => Sf::R8Sint, + "r16uint" => Sf::R16Uint, + "r16sint" => Sf::R16Sint, + "r16float" => Sf::R16Float, + "rg8unorm" => Sf::Rg8Unorm, + "rg8snorm" => Sf::Rg8Snorm, + "rg8uint" => Sf::Rg8Uint, + "rg8sint" => Sf::Rg8Sint, + "r32uint" => Sf::R32Uint, + "r32sint" => Sf::R32Sint, + "r32float" => Sf::R32Float, + "rg16uint" => Sf::Rg16Uint, + "rg16sint" => Sf::Rg16Sint, + "rg16float" => Sf::Rg16Float, + "rgba8unorm" => Sf::Rgba8Unorm, + "rgba8snorm" => Sf::Rgba8Snorm, + "rgba8uint" => Sf::Rgba8Uint, + "rgba8sint" => Sf::Rgba8Sint, + "rgb10a2unorm" => Sf::Rgb10a2Unorm, + "rg11b10float" => Sf::Rg11b10Float, + "rg32uint" => Sf::Rg32Uint, + "rg32sint" => Sf::Rg32Sint, + "rg32float" => Sf::Rg32Float, + "rgba16uint" => Sf::Rgba16Uint, + "rgba16sint" => Sf::Rgba16Sint, + "rgba16float" => Sf::Rgba16Float, + "rgba32uint" => Sf::Rgba32Uint, + "rgba32sint" => Sf::Rgba32Sint, + "rgba32float" => Sf::Rgba32Float, + _ => return Err(Error::UnknownStorageFormat(word)), + }) +} + +pub fn get_scalar_type(word: &str) -> Option<(crate::ScalarKind, crate::Bytes)> { + match word { + "f32" => Some((crate::ScalarKind::Float, 4)), + "i32" => Some((crate::ScalarKind::Sint, 4)), + "u32" => Some((crate::ScalarKind::Uint, 4)), + _ => None, + } +} + +pub fn get_intrinsic(word: &str) -> Option { + match word { + "any" => Some(crate::IntrinsicFunction::Any), + "all" => Some(crate::IntrinsicFunction::All), + "is_nan" => Some(crate::IntrinsicFunction::IsNan), + "is_inf" => Some(crate::IntrinsicFunction::IsInf), + "is_normal" => Some(crate::IntrinsicFunction::IsNormal), + _ => None, + } +} +pub fn get_derivative(word: &str) -> Option { + match word { + "dpdx" => Some(crate::DerivativeAxis::X), + "dpdy" => Some(crate::DerivativeAxis::Y), + "dwidth" => Some(crate::DerivativeAxis::Width), + _ => None, + } +} diff --git a/src/front/wgsl/lexer.rs b/src/front/wgsl/lexer.rs index af41041752..65bea63842 100644 --- a/src/front/wgsl/lexer.rs +++ b/src/front/wgsl/lexer.rs @@ -1,4 +1,4 @@ -use super::{Error, Token}; +use super::{conv, Error, Token}; fn _consume_str<'a>(input: &'a str, what: &str) -> Option<&'a str> { if input.starts_with(what) { @@ -214,56 +214,17 @@ impl<'a> Lexer<'a> { &mut self, ) -> Result<(crate::ScalarKind, crate::Bytes), Error<'a>> { self.expect(Token::Paren('<'))?; - let pair = match self.next() { - Token::Word("f32") => (crate::ScalarKind::Float, 4), - Token::Word("i32") => (crate::ScalarKind::Sint, 4), - Token::Word("u32") => (crate::ScalarKind::Uint, 4), - other => return Err(Error::Unexpected(other)), - }; + let word = self.next_ident()?; + let pair = conv::get_scalar_type(word).ok_or(Error::UnknownScalarType(word))?; self.expect(Token::Paren('>'))?; Ok(pair) } pub(super) fn next_format_generic(&mut self) -> Result> { - use crate::StorageFormat as Sf; self.expect(Token::Paren('<'))?; - let pair = match self.next() { - Token::Word("r8unorm") => Sf::R8Unorm, - Token::Word("r8snorm") => Sf::R8Snorm, - Token::Word("r8uint") => Sf::R8Uint, - Token::Word("r8sint") => Sf::R8Sint, - Token::Word("r16uint") => Sf::R16Uint, - Token::Word("r16sint") => Sf::R16Sint, - Token::Word("r16float") => Sf::R16Float, - Token::Word("rg8unorm") => Sf::Rg8Unorm, - Token::Word("rg8snorm") => Sf::Rg8Snorm, - Token::Word("rg8uint") => Sf::Rg8Uint, - Token::Word("rg8sint") => Sf::Rg8Sint, - Token::Word("r32uint") => Sf::R32Uint, - Token::Word("r32sint") => Sf::R32Sint, - Token::Word("r32float") => Sf::R32Float, - Token::Word("rg16uint") => Sf::Rg16Uint, - Token::Word("rg16sint") => Sf::Rg16Sint, - Token::Word("rg16float") => Sf::Rg16Float, - Token::Word("rgba8unorm") => Sf::Rgba8Unorm, - Token::Word("rgba8snorm") => Sf::Rgba8Snorm, - Token::Word("rgba8uint") => Sf::Rgba8Uint, - Token::Word("rgba8sint") => Sf::Rgba8Sint, - Token::Word("rgb10a2unorm") => Sf::Rgb10a2Unorm, - Token::Word("rg11b10float") => Sf::Rg11b10Float, - Token::Word("rg32uint") => Sf::Rg32Uint, - Token::Word("rg32sint") => Sf::Rg32Sint, - Token::Word("rg32float") => Sf::Rg32Float, - Token::Word("rgba16uint") => Sf::Rgba16Uint, - Token::Word("rgba16sint") => Sf::Rgba16Sint, - Token::Word("rgba16float") => Sf::Rgba16Float, - Token::Word("rgba32uint") => Sf::Rgba32Uint, - Token::Word("rgba32sint") => Sf::Rgba32Sint, - Token::Word("rgba32float") => Sf::Rgba32Float, - other => return Err(Error::Unexpected(other)), - }; + let format = conv::map_storage_format(self.next_ident()?)?; self.expect(Token::Paren('>'))?; - Ok(pair) + Ok(format) } pub(super) fn take_until(&mut self, what: Token<'_>) -> Result, Error<'a>> { diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 61f2330e37..7a4fc2b36c 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -2,6 +2,7 @@ //! //! [wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html +mod conv; mod lexer; use crate::{ @@ -60,10 +61,14 @@ pub enum Error<'a> { UnknownShaderStage(&'a str), #[error("unknown identifier: `{0}`")] UnknownIdent(&'a str), + #[error("unknown scalar type: `{0}`")] + UnknownScalarType(&'a str), #[error("unknown type: `{0}`")] UnknownType(&'a str), #[error("unknown function: `{0}`")] UnknownFunction(&'a str), + #[error("unknown storage format: `{0}`")] + UnknownStorageFormat(&'a str), #[error("missing offset for structure member `{0}`")] MissingMemberOffset(&'a str), #[error("array stride must not be 0")] @@ -72,8 +77,6 @@ pub enum Error<'a> { NotCompositeType(Handle), #[error("function redefinition: `{0}`")] FunctionRedefinition(&'a str), - //MutabilityViolation(&'a str), - // TODO: these could be replaced with more detailed errors #[error("other error")] Other, } @@ -289,54 +292,6 @@ impl Parser { } } - fn get_storage_class(word: &str) -> Result> { - match word { - "in" => Ok(crate::StorageClass::Input), - "out" => Ok(crate::StorageClass::Output), - "uniform" => Ok(crate::StorageClass::Uniform), - "storage_buffer" => Ok(crate::StorageClass::StorageBuffer), - _ => Err(Error::UnknownStorageClass(word)), - } - } - - fn get_built_in(word: &str) -> Result> { - Ok(match word { - // vertex - "position" => crate::BuiltIn::Position, - "vertex_idx" => crate::BuiltIn::VertexIndex, - "instance_idx" => crate::BuiltIn::InstanceIndex, - // fragment - "front_facing" => crate::BuiltIn::FrontFacing, - "frag_coord" => crate::BuiltIn::FragCoord, - "frag_depth" => crate::BuiltIn::FragDepth, - // compute - "global_invocation_id" => crate::BuiltIn::GlobalInvocationId, - "local_invocation_id" => crate::BuiltIn::LocalInvocationId, - "local_invocation_idx" => crate::BuiltIn::LocalInvocationIndex, - _ => return Err(Error::UnknownBuiltin(word)), - }) - } - - fn get_shader_stage(word: &str) -> Result> { - match word { - "vertex" => Ok(crate::ShaderStage::Vertex), - "fragment" => Ok(crate::ShaderStage::Fragment), - "compute" => Ok(crate::ShaderStage::Compute), - _ => Err(Error::UnknownShaderStage(word)), - } - } - - fn get_interpolation(word: &str) -> Result> { - match word { - "linear" => Ok(crate::Interpolation::Linear), - "flat" => Ok(crate::Interpolation::Flat), - "centroid" => Ok(crate::Interpolation::Centroid), - "sample" => Ok(crate::Interpolation::Sample), - "perspective" => Ok(crate::Interpolation::Perspective), - _ => Err(Error::UnknownDecoration(word)), - } - } - fn deconstruct_composite_type( type_arena: &mut Arena, ty: Handle, @@ -631,25 +586,6 @@ impl Parser { lexer: &mut Lexer<'a>, mut ctx: ExpressionContext<'a, '_, '_>, ) -> Result, Error<'a>> { - fn get_intrinsic(word: &str) -> Option { - match word { - "any" => Some(crate::IntrinsicFunction::Any), - "all" => Some(crate::IntrinsicFunction::All), - "is_nan" => Some(crate::IntrinsicFunction::IsNan), - "is_inf" => Some(crate::IntrinsicFunction::IsInf), - "is_normal" => Some(crate::IntrinsicFunction::IsNormal), - _ => None, - } - } - fn get_derivative(word: &str) -> Option { - match word { - "dpdx" => Some(crate::DerivativeAxis::X), - "dpdy" => Some(crate::DerivativeAxis::Y), - "dwidth" => Some(crate::DerivativeAxis::Width), - _ => None, - } - } - self.scopes.push(Scope::SingularExpr); let backup = lexer.clone(); let expression = match lexer.next() { @@ -662,16 +598,25 @@ impl Parser { expr: self.parse_singular_expression(lexer, ctx.reborrow())?, }), Token::Word(word) => { - if let Some(fun) = get_intrinsic(word) { + if let Some(fun) = conv::get_intrinsic(word) { lexer.expect(Token::Paren('('))?; let argument = self.parse_primary_expression(lexer, ctx.reborrow())?; lexer.expect(Token::Paren(')'))?; Some(crate::Expression::Intrinsic { fun, argument }) - } else if let Some(axis) = get_derivative(word) { + } else if let Some(axis) = conv::get_derivative(word) { lexer.expect(Token::Paren('('))?; let expr = self.parse_primary_expression(lexer, ctx.reborrow())?; lexer.expect(Token::Paren(')'))?; Some(crate::Expression::Derivative { axis, expr }) + } else if let Some((kind, _width)) = conv::get_scalar_type(word) { + lexer.expect(Token::Paren('('))?; + let expr = self.parse_primary_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Paren(')'))?; + Some(crate::Expression::As { + expr, + kind, + convert: true, + }) } else { match word { "dot" => { @@ -978,7 +923,7 @@ impl Parser { let mut class = None; if lexer.skip(Token::Paren('<')) { let class_str = lexer.next_ident()?; - class = Some(Self::get_storage_class(class_str)?); + class = Some(conv::map_storage_class(class_str)?); lexer.expect(Token::Paren('>'))?; } let name = lexer.next_ident()?; @@ -1186,7 +1131,7 @@ impl Parser { } Token::Word("ptr") => { lexer.expect(Token::Paren('<'))?; - let class = Self::get_storage_class(lexer.next_ident()?)?; + let class = conv::map_storage_class(lexer.next_ident()?)?; lexer.expect(Token::Separator(','))?; let base = self.parse_type_decl(lexer, None, type_arena)?; lexer.expect(Token::Paren('>'))?; @@ -1418,11 +1363,10 @@ impl Parser { &mut self, lexer: &mut Lexer<'a>, mut context: StatementContext<'a, '_, '_>, - ) -> Result, Error<'a>> { + ) -> Result> { let backup = lexer.clone(); match lexer.next() { - Token::Separator(';') => Ok(Some(crate::Statement::Empty)), - Token::Paren('}') => Ok(None), + Token::Separator(';') => Ok(crate::Statement::Block(Vec::new())), Token::Word(word) => { self.scopes.push(Scope::Statement); let statement = match word { @@ -1462,7 +1406,7 @@ impl Parser { pointer: expr_id, value, }, - _ => crate::Statement::Empty, + _ => crate::Statement::Block(Vec::new()), } } "return" => { @@ -1501,10 +1445,11 @@ impl Parser { lexer.expect(Token::Paren('}'))?; break; } - match self.parse_statement(lexer, context.reborrow())? { - Some(s) => body.push(s), - None => break, + if lexer.skip(Token::Paren('}')) { + break; } + let s = self.parse_statement(lexer, context.reborrow())?; + body.push(s); } crate::Statement::Loop { body, continuing } } @@ -1529,14 +1474,14 @@ impl Parser { *lexer = new_lexer; context.expressions.append(expr); lexer.expect(Token::Separator(';'))?; - crate::Statement::Empty + crate::Statement::Block(Vec::new()) } else { return Err(Error::UnknownIdent(ident)); } } }; self.scopes.pop(); - Ok(Some(statement)) + Ok(statement) } other => Err(Error::Unexpected(other)), } @@ -1550,7 +1495,8 @@ impl Parser { self.scopes.push(Scope::Block); lexer.expect(Token::Paren('{'))?; let mut statements = Vec::new(); - while let Some(s) = self.parse_statement(lexer, context.reborrow())? { + while !lexer.skip(Token::Paren('}')) { + let s = self.parse_statement(lexer, context.reborrow())?; statements.push(s); } self.scopes.pop(); @@ -1654,7 +1600,7 @@ impl Parser { } "builtin" => { lexer.expect(Token::Paren('('))?; - let builtin = Self::get_built_in(lexer.next_ident()?)?; + let builtin = conv::map_built_in(lexer.next_ident()?)?; lexer.expect(Token::Paren(')'))?; binding = Some(crate::Binding::BuiltIn(builtin)); } @@ -1670,12 +1616,12 @@ impl Parser { } "interpolate" => { lexer.expect(Token::Paren('('))?; - interpolation = Some(Self::get_interpolation(lexer.next_ident()?)?); + interpolation = Some(conv::map_interpolation(lexer.next_ident()?)?); lexer.expect(Token::Paren(')'))?; } "stage" => { lexer.expect(Token::Paren('('))?; - stage = Some(Self::get_shader_stage(lexer.next_ident()?)?); + stage = Some(conv::map_shader_stage(lexer.next_ident()?)?); lexer.expect(Token::Paren(')'))?; } "workgroup_size" => { diff --git a/src/lib.rs b/src/lib.rs index c54d057e88..a05391cb3a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -645,7 +645,7 @@ pub type Block = Vec; /// Marker type, used for falling through in a switch statement. // Clone is used only for error reporting and is not intended for end users -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub struct FallThrough; @@ -656,8 +656,6 @@ pub struct FallThrough; #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum Statement { - /// Empty statement, does nothing. - Empty, /// A block containing more statements, to be executed sequentially. Block(Block), /// Conditionally executes one of two blocks, based on the value of the condition. diff --git a/src/proc/interface.rs b/src/proc/interface.rs index ce117c56fe..33767f597f 100644 --- a/src/proc/interface.rs +++ b/src/proc/interface.rs @@ -119,7 +119,7 @@ where for statement in block { use crate::Statement as S; match *statement { - S::Empty | S::Break | S::Continue | S::Kill => (), + S::Break | S::Continue | S::Kill => (), S::Block(ref b) => { self.traverse(b); } diff --git a/test-data/boids.wgsl b/test-data/boids.wgsl index 6e4cf7c673..5df5e60611 100644 --- a/test-data/boids.wgsl +++ b/test-data/boids.wgsl @@ -70,7 +70,7 @@ type Particles = struct { [[stage(compute), workgroup_size(1)]] fn main() -> void { var index : u32 = gl_GlobalInvocationID.x; - if (index >= 5) { + if (index >= u32(5)) { return; } @@ -87,7 +87,7 @@ fn main() -> void { var vel : vec2; var i : u32 = 0; loop { - if (i >= 5) { + if (i >= u32(5)) { break; } if (i == index) { @@ -110,7 +110,7 @@ fn main() -> void { } continuing { - i = i + 1; + i = i + u32(1); } } if (cMassCount > 0) { diff --git a/test-data/simple/module.ron b/test-data/simple/module.ron index b45c6b417c..c3c34fc6b8 100644 --- a/test-data/simple/module.ron +++ b/test-data/simple/module.ron @@ -104,7 +104,7 @@ ), ], body: [ - Empty, + Block([]), Store( pointer: 2, value: 6,