diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 1aa412f646..f906d3fbaa 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -1,23 +1,19 @@ // TODO: temp #![allow(dead_code)] use super::Error; -use crate::FastHashMap; use crate::{ back::{binary_operation_str, vector_size_str, wgsl::keywords::RESERVED}, - proc::{EntryPointIndex, TypeResolution}, + proc::{EntryPointIndex, NameKey, Namer, TypeResolution}, valid::{FunctionInfo, ModuleInfo}, - Arena, ArraySize, Binding, Constant, Expression, Function, GlobalVariable, Handle, ImageClass, - ImageDimension, Module, ScalarKind, ShaderStage, Statement, StorageFormat, StructLevel, Type, - TypeInner, -}; -use crate::{ - proc::{NameKey, Namer}, - StructMember, + Arena, ArraySize, Binding, Constant, Expression, FastHashMap, Function, GlobalVariable, Handle, + ImageClass, ImageDimension, Interpolation, Module, Sampling, ScalarKind, ScalarValue, + ShaderStage, Statement, StorageFormat, StructLevel, StructMember, Type, TypeInner, }; use bit_set::BitSet; use std::fmt::Write; const INDENT: &str = " "; +const COMPONENTS: &[char] = &['x', 'y', 'z', 'w']; const BAKE_PREFIX: &str = "_e"; /// Shorthand result used internally by the backend @@ -26,10 +22,12 @@ type BackendResult = Result<(), Error>; /// WGSL attribute /// https://gpuweb.github.io/gpuweb/wgsl/#attributes enum Attribute { + Access(crate::StorageAccess), Binding(u32), Block, BuiltIn(crate::BuiltIn), Group(u32), + Interpolate(Option, Option), Location(u32), Stage(ShaderStage), Stride(u32), @@ -75,23 +73,14 @@ impl Writer { } } - pub fn write(&mut self, module: &Module, info: &ModuleInfo) -> BackendResult { + fn reset(&mut self, module: &Module) { self.names.clear(); self.namer.reset(module, RESERVED, &[], &mut self.names); + self.named_expressions.clear(); + } - // Write all constants - for (_, constant) in module.constants.iter() { - if constant.name.is_some() { - self.write_constant(&constant, true)?; - } - } - - // Write all globals - for (_, global) in module.global_variables.iter() { - if global.name.is_some() { - self.write_global(&module, &global)?; - } - } + pub fn write(&mut self, module: &Module, info: &ModuleInfo) -> BackendResult { + self.reset(module); // Write all structs for (handle, ty) in module.types.iter() { @@ -99,13 +88,46 @@ impl Writer { level, ref members, .. } = ty.inner { - let name = &self.names[&NameKey::Type(handle)].clone(); let block = level == StructLevel::Root; - self.write_struct(module, name, block, members)?; + self.write_struct(module, handle, block, members)?; writeln!(self.out)?; } } + // Write all constants + for (handle, constant) in module.constants.iter() { + if constant.name.is_some() { + self.write_global_constant(&constant, handle)?; + } + } + + // Write all globals + for (ty, global) in module.global_variables.iter() { + self.write_global(&module, &global, ty)?; + } + + if !module.global_variables.is_empty() { + // Add extra newline for readability + writeln!(self.out)?; + } + + // Write all regular functions + for (handle, function) in module.functions.iter() { + let fun_info = &info[handle]; + + let func_ctx = FunctionCtx { + ty: FunctionType::Function(handle), + info: fun_info, + expressions: &function.expressions, + }; + + // Write the function + self.write_function(&module, &function, &func_ctx)?; + + writeln!(self.out)?; + } + + // Write all entry points for (index, ep) in module.entry_points.iter().enumerate() { let attributes = match ep.stage { ShaderStage::Vertex | ShaderStage::Fragment => vec![Attribute::Stage(ep.stage)], @@ -115,7 +137,7 @@ impl Writer { ], }; - self.write_attributes(&attributes)?; + self.write_attributes(&attributes, false)?; // Add a newline after attribute writeln!(self.out)?; @@ -125,11 +147,28 @@ impl Writer { expressions: &ep.function.expressions, }; self.write_function(&module, &ep.function, &func_ctx)?; - writeln!(self.out)?; + + if index < module.entry_points.len() - 1 { + writeln!(self.out)?; + } } - // Add a newline at the end of file - writeln!(self.out)?; + Ok(()) + } + + /// Helper method used to write [`ScalarValue`](ScalarValue) + /// + /// # Notes + /// Adds no trailing or leading whitespace + fn write_scalar_value(&mut self, value: ScalarValue) -> BackendResult { + match value { + ScalarValue::Sint(value) => write!(self.out, "{}", value)?, + ScalarValue::Uint(value) => write!(self.out, "{}", value)?, + // Floats are written using `Debug` instead of `Display` because it always appends the + // decimal part even it's zero + ScalarValue::Float(value) => write!(self.out, "{:?}", value)?, + ScalarValue::Bool(value) => write!(self.out, "{}", value)?, + } Ok(()) } @@ -145,61 +184,100 @@ impl Writer { func: &Function, func_ctx: &FunctionCtx<'_>, ) -> BackendResult { - if func.name.is_some() { - write!(self.out, "fn {}(", func.name.as_ref().unwrap())?; + let func_name = match func_ctx.ty { + FunctionType::EntryPoint(index) => self.names[&NameKey::EntryPoint(index)].clone(), + FunctionType::Function(handle) => self.names[&NameKey::Function(handle)].clone(), + }; - // Write function arguments - // TODO: another function type - if let FunctionType::EntryPoint(ep_index) = func_ctx.ty { - for (index, arg) in func.arguments.iter().enumerate() { - // Write argument attribute if a binding is present - if let Some(ref binding) = arg.binding { - self.write_attributes(&[map_binding_to_attribute(binding)])?; - write!(self.out, " ")?; - } - // Write argument name - write!( - self.out, - "{}: ", - &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)] - )?; - // Write argument type - self.write_type(module, arg.ty)?; - if index < func.arguments.len() - 1 { - // Add a separator between args - write!(self.out, ", ")?; - } + // Write function name + write!(self.out, "fn {}(", func_name)?; + + // Write function arguments + for (index, arg) in func.arguments.iter().enumerate() { + // Write argument attribute if a binding is present + if let Some(ref binding) = arg.binding { + self.write_attributes(&map_binding_to_attribute(binding), false)?; + write!(self.out, " ")?; + } + // Write argument name + let argument_name = match func_ctx.ty { + FunctionType::Function(handle) => { + self.names[&NameKey::FunctionArgument(handle, index as u32)].clone() } - write!(self.out, ")")?; - } - - // Write function return type - if let Some(ref result) = func.result { - if let Some(ref binding) = result.binding { - write!(self.out, " -> ")?; - self.write_attributes(&[map_binding_to_attribute(binding)])?; - write!(self.out, " ")?; - self.write_type(module, result.ty)?; - // Extra space only for readability - write!(self.out, " ")?; - } else { - let struct_name = &self.names[&NameKey::Type(result.ty)].clone(); - write!(self.out, " -> {} ", struct_name)?; + FunctionType::EntryPoint(ep_index) => { + self.names[&NameKey::EntryPointArgument(ep_index, index as u32)].clone() } + }; + + write!(self.out, "{}: ", argument_name)?; + // Write argument type + self.write_type(module, arg.ty)?; + if index < func.arguments.len() - 1 { + // Add a separator between args + write!(self.out, ", ")?; } - - write!(self.out, "{{")?; - writeln!(self.out)?; - - // Write the function body (statement list) - for sta in func.body.iter() { - // The indentation should always be 1 when writing the function body - self.write_stmt(&module, sta, &func_ctx, 1)?; - } - - writeln!(self.out, "}}")?; } + write!(self.out, ")")?; + + // Write function return type + if let Some(ref result) = func.result { + if let Some(ref binding) = result.binding { + write!(self.out, " -> ")?; + self.write_attributes(&map_binding_to_attribute(binding), true)?; + self.write_type(module, result.ty)?; + } else { + let struct_name = &self.names[&NameKey::Type(result.ty)].clone(); + write!(self.out, " -> {}", struct_name)?; + } + } + + write!(self.out, " {{")?; + writeln!(self.out)?; + + // Write function local variables + for (handle, local) in func.local_variables.iter() { + // Write indentation (only for readability) + write!(self.out, "{}", INDENT)?; + + // Write the local name + // The leading space is important + let name_key = match func_ctx.ty { + FunctionType::Function(func_handle) => NameKey::FunctionLocal(func_handle, handle), + FunctionType::EntryPoint(idx) => NameKey::EntryPointLocal(idx, handle), + }; + write!(self.out, "var {}: ", self.names[&name_key])?; + + // Write the local type + self.write_type(&module, local.ty)?; + + // Write the local initializer if needed + if let Some(init) = local.init { + // Put the equal signal only if there's a initializer + // The leading and trailing spaces aren't needed but help with readability + write!(self.out, " = ")?; + + // Write the constant + // `write_constant` adds no trailing or leading space/newline + self.write_constant(module, init)?; + } + + // Finish the local with `;` and add a newline (only for readability) + writeln!(self.out, ";")? + } + + if !func.local_variables.is_empty() { + writeln!(self.out)?; + } + + // Write the function body (statement list) + for sta in func.body.iter() { + // The indentation should always be 1 when writing the function body + self.write_stmt(&module, sta, &func_ctx, 1)?; + } + + writeln!(self.out, "}}")?; + self.named_expressions.clear(); Ok(()) @@ -208,45 +286,84 @@ impl Writer { /// Helper method to write a attribute /// /// # Notes - /// Adds no leading or trailing whitespace - fn write_attributes(&mut self, attributes: &[Attribute]) -> BackendResult { - write!(self.out, "[[")?; + /// Adds an extra space if required + fn write_attributes(&mut self, attributes: &[Attribute], extra_space: bool) -> BackendResult { + let mut attributes_str = String::new(); for (index, attribute) in attributes.iter().enumerate() { - match *attribute { - Attribute::Block => { - write!(self.out, "block")?; + let attribute_str = match *attribute { + Attribute::Access(access) => { + let access_str = if access.is_all() { + "read_write" + } else if access.contains(crate::StorageAccess::LOAD) { + "read" + } else { + "write" + }; + format!("access({})", access_str) } - Attribute::Location(id) => write!(self.out, "location({})", id)?, + Attribute::Block => String::from("block"), + Attribute::Location(id) => format!("location({})", id), Attribute::BuiltIn(builtin_attrib) => { let builtin_str = builtin_str(builtin_attrib); if let Some(builtin) = builtin_str { - write!(self.out, "builtin({})", builtin)? + format!("builtin({})", builtin) } else { log::warn!("Unsupported builtin attribute: {:?}", builtin_attrib); + String::from("") } } Attribute::Stage(shader_stage) => match shader_stage { - ShaderStage::Vertex => write!(self.out, "stage(vertex)")?, - ShaderStage::Fragment => write!(self.out, "stage(fragment)")?, - ShaderStage::Compute => write!(self.out, "stage(compute)")?, + ShaderStage::Vertex => String::from("stage(vertex)"), + ShaderStage::Fragment => String::from("stage(fragment)"), + ShaderStage::Compute => String::from("stage(compute)"), }, - Attribute::Stride(stride) => write!(self.out, "stride({})", stride)?, + Attribute::Stride(stride) => format!("stride({})", stride), Attribute::WorkGroupSize(size) => { - write!( - self.out, - "workgroup_size({}, {}, {})", - size[0], size[1], size[2] - )?; + format!("workgroup_size({}, {}, {})", size[0], size[1], size[2]) + } + Attribute::Binding(id) => format!("binding({})", id), + Attribute::Group(id) => format!("group({})", id), + Attribute::Interpolate(interpolation, sampling) => { + if interpolation.is_some() || sampling.is_some() { + let interpolation_str = if let Some(interpolation) = interpolation { + interpolation_str(interpolation) + } else { + "" + }; + let sampling_str = if let Some(sampling) = sampling { + // Center sampling is the default + if sampling == Sampling::Center { + String::from("") + } else { + format!(",{}", sampling_str(sampling)) + } + } else { + String::from("") + }; + format!("interpolate({}{})", interpolation_str, sampling_str) + } else { + String::from("") + } } - Attribute::Binding(id) => write!(self.out, "binding({})", id)?, - Attribute::Group(id) => write!(self.out, "group({})", id)?, }; - if index < attributes.len() - 1 { + if !attribute_str.is_empty() { // Add a separator between args - write!(self.out, ", ")?; + let separator = if index < attributes.len() - 1 { + ", " + } else { + "" + }; + attributes_str = format!("{}{}{}", attributes_str, attribute_str, separator); } } - write!(self.out, "]]")?; + if !attributes_str.is_empty() { + //TODO: looks ugly + if attributes_str.ends_with(", ") { + attributes_str = attributes_str[0..attributes_str.len() - 2].to_string(); + } + let extra_space_str = if extra_space { " " } else { "" }; + write!(self.out, "[[{}]]{}", attributes_str, extra_space_str)?; + } Ok(()) } @@ -258,41 +375,48 @@ impl Writer { fn write_struct( &mut self, module: &Module, - name: &str, + handle: Handle, block: bool, members: &[StructMember], ) -> BackendResult { if block { - self.write_attributes(&[Attribute::Block])?; + self.write_attributes(&[Attribute::Block], false)?; writeln!(self.out)?; } + let name = &self.names[&NameKey::Type(handle)].clone(); write!(self.out, "struct {} {{", name)?; writeln!(self.out)?; - for (_, member) in members.iter().enumerate() { - if member.name.is_some() { - // The indentation is only for readability - write!(self.out, "{}", INDENT)?; - if let Some(ref binding) = member.binding { - self.write_attributes(&[map_binding_to_attribute(binding)])?; - write!(self.out, " ")?; + for (index, member) in members.iter().enumerate() { + // Skip struct member with unsupported built in + if let Some(Binding::BuiltIn(builtin)) = member.binding { + if builtin_str(builtin).is_none() { + log::warn!("Skip member with unsupported builtin {:?}", builtin); + continue; } - // Write struct member name and type - write!(self.out, "{}: ", member.name.as_ref().unwrap())?; - // Write stride attribute for array struct member - if let TypeInner::Array { - base: _, - size: _, - stride, - } = module.types[member.ty].inner - { - self.write_attributes(&[Attribute::Stride(stride)])?; - write!(self.out, " ")?; - } - self.write_type(module, member.ty)?; - write!(self.out, ";")?; - writeln!(self.out)?; } + + // The indentation is only for readability + write!(self.out, "{}", INDENT)?; + if let Some(ref binding) = member.binding { + self.write_attributes(&map_binding_to_attribute(binding), true)?; + } + // Write struct member name and type + let member_name = &self.names[&NameKey::StructMember(handle, index as u32)]; + write!(self.out, "{}: ", member_name)?; + // Write stride attribute for array struct member + if let TypeInner::Array { + base: _, + size: _, + stride, + } = module.types[member.ty].inner + { + self.write_attributes(&[Attribute::Stride(stride)], true)?; + } + self.write_type(module, member.ty)?; + write!(self.out, ";")?; + writeln!(self.out)?; } + write!(self.out, "}};")?; writeln!(self.out)?; @@ -352,7 +476,7 @@ impl Writer { ), ImageClass::Depth => ("depth", "", String::from("")), ImageClass::Storage(storage_format) => ( - "storage", + "storage_", "", format!("<{}>", storage_format_str(storage_format)), ), @@ -375,7 +499,7 @@ impl Writer { ArraySize::Constant(handle) => { self.write_type(module, base)?; write!(self.out, ",")?; - self.write_constant(&module.constants[handle], false)?; + self.write_constant(module, handle)?; } ArraySize::Dynamic => { self.write_type(module, base)?; @@ -423,7 +547,7 @@ impl Writer { let min_ref_count = func_ctx.expressions[handle].bake_ref_count(); if min_ref_count <= func_ctx.info[handle].ref_count { write!(self.out, "{}", INDENT.repeat(indent))?; - self.start_baking_expr(handle, &func_ctx)?; + self.start_baking_expr(module, handle, &func_ctx)?; self.write_expr(module, handle, &func_ctx)?; writeln!(self.out, ";")?; self.named_expressions.insert(handle.index()); @@ -461,19 +585,49 @@ impl Writer { } Statement::Return { value } => { write!(self.out, "{}", INDENT.repeat(indent))?; + write!(self.out, "return")?; if let Some(return_value) = value { - write!(self.out, "return ")?; + // The leading space is important + write!(self.out, " ")?; self.write_expr(module, return_value, &func_ctx)?; - writeln!(self.out, ";")?; - } else { - writeln!(self.out, "return;")?; } + writeln!(self.out, ";")?; } // TODO: copy-paste from glsl-out Statement::Kill => { write!(self.out, "{}", INDENT.repeat(indent))?; writeln!(self.out, "discard;")? } + // TODO: copy-paste from glsl-out + Statement::Store { pointer, value } => { + write!(self.out, "{}", INDENT.repeat(indent))?; + self.write_expr(module, pointer, func_ctx)?; + write!(self.out, " = ")?; + self.write_expr(module, value, func_ctx)?; + writeln!(self.out, ";")? + } + crate::Statement::Call { + function, + ref arguments, + result, + } => { + write!(self.out, "{}", INDENT.repeat(indent))?; + if let Some(expr) = result { + self.start_baking_expr(module, expr, &func_ctx)?; + self.named_expressions.insert(expr.index()); + } + let func_name = &self.names[&NameKey::Function(function)]; + write!(self.out, "{}(", func_name)?; + for (index, argument) in arguments.iter().enumerate() { + self.write_expr(module, *argument, func_ctx)?; + // Only write a comma if isn't the last element + if index != arguments.len().saturating_sub(1) { + // The leading space is for readability only + write!(self.out, ", ")?; + } + } + writeln!(self.out, ");")? + } _ => { return Err(Error::Unimplemented(format!("write_stmt {:?}", stmt))); } @@ -484,6 +638,7 @@ impl Writer { fn start_baking_expr( &mut self, + module: &Module, handle: Handle, context: &FunctionCtx, ) -> BackendResult { @@ -492,6 +647,9 @@ impl Writer { let ty = &context.info[handle].ty; // Write variable type match *ty { + TypeResolution::Handle(ty_handle) => { + self.write_type(module, ty_handle)?; + } TypeResolution::Value(crate::TypeInner::Scalar { kind, .. }) => { write!(self.out, "{}", scalar_kind_str(kind))?; } @@ -530,15 +688,74 @@ impl Writer { } match *expression { - Expression::Constant(constant) => { - self.write_constant(&module.constants[constant], false)? - } + Expression::Constant(constant) => self.write_constant(module, constant)?, Expression::Compose { ty, ref components } => { self.write_type(&module, ty)?; write!(self.out, "(")?; - self.write_slice(components, |this, _, arg| { - this.write_expr(&module, *arg, func_ctx) - })?; + // !spv-in specific notes! + // WGSL does not support all SPIR-V builtins and we should skip it in generated shaders. + // We already skip them when we generate struct type. + // Now we need to find components that used struct with ignored builtins. + + // So, why we can't just return the error to a user? + // We can, but otherwise, we can't generate WGSL shader from any glslang SPIR-V shaders. + // glslang generates gl_PerVertex struct with gl_CullDistance, gl_ClipDistance and gl_PointSize builtin inside by default. + // All of them are not supported by WGSL. + + // We need to copy components to another vec because we don't know which of them we should write. + let mut components_to_write = Vec::with_capacity(components.len()); + for component in components { + let mut skip_component = false; + if let Expression::Load { pointer } = func_ctx.expressions[*component] { + if let Expression::AccessIndex { + base, + index: access_index, + } = func_ctx.expressions[pointer] + { + let base_ty_res = &func_ctx.info[base].ty; + let resolved = base_ty_res.inner_with(&module.types); + if let TypeInner::Pointer { + base: pointer_base_handle, + .. + } = *resolved + { + // Let's check that we try to access a struct member with unsupported built-in and skip it. + if let TypeInner::Struct { ref members, .. } = + module.types[pointer_base_handle].inner + { + if let Some(Binding::BuiltIn(builtin)) = + members[access_index as usize].binding + { + if builtin_str(builtin).is_none() { + // glslang why you did this with us... + log::warn!( + "Skip component with unsupported builtin {:?}", + builtin + ); + skip_component = true; + } + } + } + } + } + } + if skip_component { + continue; + } else { + components_to_write.push(*component); + } + } + + // non spv-in specific notes! + // Real `Expression::Compose` logic generates here. + for (index, component) in components_to_write.iter().enumerate() { + self.write_expr(module, *component, &func_ctx)?; + // Only write a comma if isn't the last element + if index != components_to_write.len().saturating_sub(1) { + // The leading space is for readability only + write!(self.out, ", ")?; + } + } write!(self.out, ")")? } Expression::FunctionArgument(pos) => { @@ -642,6 +859,7 @@ impl Writer { format!("mat{}x{}", vector_size_str(columns), vector_size_str(rows)) } TypeInner::Vector { size, .. } => format!("vec{}", vector_size_str(size)), + TypeInner::Scalar { kind, .. } => String::from(scalar_kind_str(kind)), _ => { return Err(Error::Unimplemented(format!( "write_expr expression::as {:?}", @@ -672,6 +890,64 @@ impl Writer { self.write_expr(module, value, func_ctx)?; write!(self.out, ")")?; } + //TODO: add pointer logic + Expression::Load { pointer } => self.write_expr(module, pointer, func_ctx)?, + Expression::LocalVariable(handle) => { + let name_key = match func_ctx.ty { + FunctionType::Function(func_handle) => { + NameKey::FunctionLocal(func_handle, handle) + } + FunctionType::EntryPoint(idx) => NameKey::EntryPointLocal(idx, handle), + }; + write!(self.out, "{}", self.names[&name_key])? + } + Expression::ArrayLength(expr) => { + write!(self.out, "arrayLength(")?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")")?; + } + Expression::Math { + fun, + arg, + arg1, + arg2, + } => { + use crate::MathFunction as Mf; + + let fun_name = match fun { + Mf::Length => "length", + Mf::Mix => "mix", + _ => { + return Err(Error::Unimplemented(format!( + "write_expr Math func {:?}", + fun + ))); + } + }; + + write!(self.out, "{}(", fun_name)?; + self.write_expr(module, arg, func_ctx)?; + if let Some(arg) = arg1 { + write!(self.out, ", ")?; + self.write_expr(module, arg, func_ctx)?; + } + if let Some(arg) = arg2 { + write!(self.out, ", ")?; + self.write_expr(module, arg, func_ctx)?; + } + write!(self.out, ")")? + } + Expression::Swizzle { + size, + vector, + pattern, + } => { + self.write_expr(module, vector, func_ctx)?; + write!(self.out, ".")?; + for &sc in pattern[..size as usize].iter() { + self.out.write_char(COMPONENTS[sc as usize])?; + } + } _ => { return Err(Error::Unimplemented(format!("write_expr {:?}", expression))); } @@ -680,115 +956,120 @@ impl Writer { Ok(()) } - /// Helper method that writes a list of comma separated `T` with a writer function `F` - /// - /// The writer function `F` receives a mutable reference to `self` that if needed won't cause - /// borrow checker issues (using for example a closure with `self` will cause issues), the - /// second argument is the 0 based index of the element on the list, and the last element is - /// a reference to the element `T` being written - /// - /// # Notes - /// - Adds no newlines or leading/trailing whitespace - /// - The last element won't have a trailing `,` - // TODO: copy-paste from glsl-out - fn write_slice BackendResult>( - &mut self, - data: &[T], - mut f: F, - ) -> BackendResult { - // Loop trough `data` invoking `f` for each element - for (i, item) in data.iter().enumerate() { - f(self, i as u32, item)?; - - // Only write a comma if isn't the last element - if i != data.len().saturating_sub(1) { - // The leading space is for readability only - write!(self.out, ", ")?; - } - } - - Ok(()) - } - /// Helper method used to write global variables - fn write_global(&mut self, module: &Module, global: &GlobalVariable) -> BackendResult { + /// # Notes + /// Always adds a newline + fn write_global( + &mut self, + module: &Module, + global: &GlobalVariable, + handle: Handle, + ) -> BackendResult { + let name = self.names[&NameKey::GlobalVariable(handle)].clone(); + // Write group and dinding attributes if present if let Some(ref binding) = global.binding { - self.write_attributes(&[ - Attribute::Group(binding.group), - Attribute::Binding(binding.binding), - ])?; - write!(self.out, " ")?; - } - - if let Some(ref name) = global.name { - // First write only global name - write!(self.out, "var {}: ", name)?; - // Write global type - self.write_type(module, global.ty)?; - // End with semicolon and extra newline for readability - writeln!(self.out, ";")?; + self.write_attributes( + &[ + Attribute::Group(binding.group), + Attribute::Binding(binding.binding), + ], + false, + )?; writeln!(self.out)?; } + // First write only global name + write!(self.out, "var {}: ", name)?; + // Write access attribute if present + if !global.storage_access.is_empty() { + self.write_attributes(&[Attribute::Access(global.storage_access)], true)?; + } + // Write global type + self.write_type(module, global.ty)?; + // End with semicolon + writeln!(self.out, ";")?; + Ok(()) } /// Helper method used to write constants /// /// # Notes - /// Adds newlines for global constants - fn write_constant(&mut self, constant: &Constant, global: bool) -> BackendResult { + /// Doesn't add any newlines or leading/trailing spaces + fn write_constant(&mut self, module: &Module, handle: Handle) -> BackendResult { + let constant = &module.constants[handle]; match constant.inner { crate::ConstantInner::Scalar { width: _, ref value, } => { if let Some(ref name) = constant.name { - if global { - // First write only constant name - write!(self.out, "let {}: ", name)?; - // Next write constant type and value - match *value { - crate::ScalarValue::Sint(value) => { - write!(self.out, "i32 = {}", value)?; - } - crate::ScalarValue::Uint(value) => { - write!(self.out, "u32 = {}", value)?; - } - crate::ScalarValue::Float(value) => { - write!(self.out, "f32 = {}", value)?; - } - crate::ScalarValue::Bool(value) => { - write!(self.out, "bool = {}", value)?; - } - }; - // End with semicolon and extra newline for readability - writeln!(self.out, ";")?; - writeln!(self.out)?; - } else { - write!(self.out, "{}", name)?; - } + write!(self.out, "{}", name)?; } else { - match *value { - crate::ScalarValue::Sint(value) => { - write!(self.out, "{}", value)?; - } - crate::ScalarValue::Uint(value) => { - write!(self.out, "{}", value)?; - } - // TODO: fix float - crate::ScalarValue::Float(value) => { - write!(self.out, "{:.1}", value)?; - } - crate::ScalarValue::Bool(value) => { - write!(self.out, "{}", value)?; - } - }; + self.write_scalar_value(*value)?; } } + crate::ConstantInner::Composite { ty, ref components } => { + self.write_type(module, ty)?; + write!(self.out, "(")?; + + // Write the comma separated constants + for (index, constant) in components.iter().enumerate() { + self.write_constant(module, *constant)?; + // Only write a comma if isn't the last element + if index != components.len().saturating_sub(1) { + // The leading space is for readability only + write!(self.out, ", ")?; + } + } + write!(self.out, ")")? + } + } + + Ok(()) + } + + /// Helper method used to write global constants + /// + /// # Notes + /// Ends in a newline + fn write_global_constant( + &mut self, + constant: &Constant, + handle: Handle, + ) -> BackendResult { + match constant.inner { + crate::ConstantInner::Scalar { + width: _, + ref value, + } => { + let name = self.names[&NameKey::Constant(handle)].clone(); + // First write only constant name + write!(self.out, "let {}: ", name)?; + // Next write constant type and value + match *value { + crate::ScalarValue::Sint(value) => { + write!(self.out, "i32 = {}", value)?; + } + crate::ScalarValue::Uint(value) => { + write!(self.out, "u32 = {}", value)?; + } + crate::ScalarValue::Float(value) => { + // Floats are written using `Debug` instead of `Display` because it always appends the + // decimal part even it's zero + write!(self.out, "f32 = {:?}", value)?; + } + crate::ScalarValue::Bool(value) => { + write!(self.out, "bool = {}", value)?; + } + }; + // End with semicolon and extra newline for readability + writeln!(self.out, ";")?; + writeln!(self.out)?; + } _ => { return Err(Error::Unimplemented(format!( - "write_constant {:?}", + "write_global_constant {:?}", constant.inner ))); } @@ -876,10 +1157,34 @@ fn storage_format_str(format: StorageFormat) -> &'static str { } } -fn map_binding_to_attribute(binding: &Binding) -> Attribute { - match *binding { - Binding::BuiltIn(built_in) => Attribute::BuiltIn(built_in), - //TODO: Interpolation - Binding::Location { location, .. } => Attribute::Location(location), +/// Helper function that returns the string corresponding to the WGSL interpolation qualifier +fn interpolation_str(interpolation: Interpolation) -> &'static str { + match interpolation { + Interpolation::Perspective => "perspective", + Interpolation::Linear => "linear", + Interpolation::Flat => "flat", + } +} + +/// Return the WGSL auxiliary qualifier for the given sampling value. +fn sampling_str(sampling: Sampling) -> &'static str { + match sampling { + Sampling::Center => "", + Sampling::Centroid => "centroid", + Sampling::Sample => "sample", + } +} + +fn map_binding_to_attribute(binding: &Binding) -> Vec { + match *binding { + Binding::BuiltIn(built_in) => vec![Attribute::BuiltIn(built_in)], + Binding::Location { + location, + interpolation, + sampling, + } => vec![ + Attribute::Location(location), + Attribute::Interpolate(interpolation, sampling), + ], } } diff --git a/tests/out/access.wgsl b/tests/out/access.wgsl deleted file mode 100644 index 4b1cbc85ec..0000000000 --- a/tests/out/access.wgsl +++ /dev/null @@ -1,6 +0,0 @@ -[[stage(vertex)]] -fn foo([[builtin(vertex_index)]] vi: u32) -> [[builtin(position)]] vec4 { - return vec4(vec4(array(1, 2, 3, 4, 5)[vi])); -} - - diff --git a/tests/out/empty.wgsl b/tests/out/empty.wgsl index 31548a146c..1ee73b5809 100644 --- a/tests/out/empty.wgsl +++ b/tests/out/empty.wgsl @@ -1,6 +1,4 @@ [[stage(compute), workgroup_size(1, 1, 1)]] -fn main(){ +fn main() { return; } - - diff --git a/tests/out/quad-vert.wgsl b/tests/out/quad-vert.wgsl new file mode 100644 index 0000000000..66a173b4cf --- /dev/null +++ b/tests/out/quad-vert.wgsl @@ -0,0 +1,29 @@ +[[block]] +struct gl_PerVertex { + [[builtin(position)]] gl_Position: vec4; +}; + +struct type10 { + [[location(0), interpolate(perspective)]] member: vec2; + [[builtin(position)]] gl_Position1: vec4; +}; + +var v_uv: vec2; +var a_uv: vec2; +var perVertexStruct: gl_PerVertex; +var a_pos: vec2; + +fn main() { + v_uv = a_uv; + let _e13: vec2 = a_pos; + perVertexStruct.gl_Position = vec4(_e13[0], _e13[1], 0.0, 1.0); + return; +} + +[[stage(vertex)]] +fn main1([[location(1)]] a_uv1: vec2, [[location(0)]] a_pos1: vec2) -> type10 { + a_uv = a_uv1; + a_pos = a_pos1; + main(); + return type10(v_uv, perVertexStruct.gl_Position); +} diff --git a/tests/out/quad.wgsl b/tests/out/quad.wgsl index ba45b48c26..fe924dc765 100644 --- a/tests/out/quad.wgsl +++ b/tests/out/quad.wgsl @@ -1,26 +1,25 @@ -let c_scale: f32 = 1.2; - -[[group(0), binding(0)]] var u_texture: texture_2d; - -[[group(0), binding(1)]] var u_sampler: sampler; - struct VertexOutput { - [[location(0)]] uv: vec2; + [[location(0), interpolate(perspective)]] uv: vec2; [[builtin(position)]] position: vec4; }; +let c_scale: f32 = 1.2; + +[[group(0), binding(0)]] +var u_texture: texture_2d; +[[group(0), binding(1)]] +var u_sampler: sampler; + [[stage(vertex)]] fn main([[location(0)]] pos: vec2, [[location(1)]] uv1: vec2) -> VertexOutput { return VertexOutput(uv1, vec4(c_scale * pos, 0.0, 1.0)); } [[stage(fragment)]] -fn main([[location(0)]] uv2: vec2) -> [[location(0)]] vec4 { +fn main1([[location(0), interpolate(perspective)]] uv2: vec2) -> [[location(0)]] vec4 { let _e4: vec4 = textureSample(u_texture, u_sampler, uv2); if (_e4[3] == 0.0) { discard; } return _e4[3] * _e4; } - - diff --git a/tests/snapshots.rs b/tests/snapshots.rs index b628075ce5..5fe0c2b450 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -303,7 +303,7 @@ fn convert_spv(name: &str, adjust_coordinate_space: bool, targets: Targets) { #[cfg(feature = "spv-in")] #[test] fn convert_spv_quad_vert() { - convert_spv("quad-vert", false, Targets::METAL | Targets::GLSL); + convert_spv("quad-vert", false, Targets::METAL | Targets::GLSL | Targets::WGSL); } #[cfg(feature = "spv-in")]