diff --git a/Cargo.toml b/Cargo.toml index 125014f064..8db524e08f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,8 @@ license = "MPL-2.0" bitflags = "1" fxhash = "0.2" log = "0.4" +pest = "2" +pest_derive = "2" spirv_headers = "1" [dev-dependencies] diff --git a/README.md b/README.md index cdb0b3164b..a18d882e14 100644 --- a/README.md +++ b/README.md @@ -12,14 +12,14 @@ This is an experimental shader translation library for the needs of gfx-rs proje Front-end | Status | Notes | --------------- | ------------------ | ----- | SPIR-V (binary) | :construction: | | -Tint | | | +WGSL (Tint) | :construction: | | GLSL (Vulkan) | | | Rust | | | Back-end | Status | Notes | --------------- | ------------------ | ----- | SPIR-V (binary) | | | -Tint | | | +WGSL | | | MSL | :construction: | | HLSL | | | GLSL | | | diff --git a/examples/convert.rs b/examples/convert.rs index e5ce528aaa..f10909462e 100644 --- a/examples/convert.rs +++ b/examples/convert.rs @@ -4,23 +4,40 @@ fn main() { env_logger::init(); let args = env::args().collect::>(); - let input = fs::read(&args[1]).unwrap(); - let module = naga::front::spirv::parse_u8_slice(&input).unwrap(); - //println!("{:?}", module); - - let mut binding_map = naga::back::msl::BindingMap::default(); - binding_map.insert( - naga::back::msl::BindSource { set: 0, binding: 0 }, - naga::back::msl::BindTarget { buffer: None, texture: Some(1), sampler: None }, - ); - binding_map.insert( - naga::back::msl::BindSource { set: 0, binding: 1 }, - naga::back::msl::BindTarget { buffer: None, texture: None, sampler: Some(1) }, - ); - let options = naga::back::msl::Options { - binding_map: &binding_map, + let module = if args.len() <= 1 { + println!("Call with "); + return + } else if args[1].ends_with(".spv") { + let input = fs::read(&args[1]).unwrap(); + naga::front::spirv::parse_u8_slice(&input).unwrap() + } else if args[1].ends_with(".wgsl") { + let input = fs::read_to_string(&args[1]).unwrap(); + naga::front::wgsl::parse_str(&input).unwrap() + } else { + panic!("Unknown input: {:?}", args[1]); }; - let msl = naga::back::msl::write_string(&module, options).unwrap(); - println!("{}", msl); + + if args.len() <= 2 { + println!("{:?}", module); + return + } else if args[2].ends_with(".msl") { + use naga::back::msl; + let mut binding_map = msl::BindingMap::default(); + binding_map.insert( + msl::BindSource { set: 0, binding: 0 }, + msl::BindTarget { buffer: None, texture: Some(1), sampler: None }, + ); + binding_map.insert( + msl::BindSource { set: 0, binding: 1 }, + msl::BindTarget { buffer: None, texture: None, sampler: Some(1) }, + ); + let options = msl::Options { + binding_map: &binding_map, + }; + let msl = msl::write_string(&module, options).unwrap(); + fs::write(&args[2], msl).unwrap(); + } else { + panic!("Unknown output: {:?}", args[2]); + } } diff --git a/grammars/wgsl.pest b/grammars/wgsl.pest new file mode 100644 index 0000000000..831dd99104 --- /dev/null +++ b/grammars/wgsl.pest @@ -0,0 +1,90 @@ +translation_unit = _{ SOI ~ global_decl* ~ EOI } + +global_decl = { + ";" + | import_decl ~ ";" + | global_variable_decl ~ ";" +// | global_constant_decl SEMICOLON +// | entry_point_decl SEMICOLON +// | type_alias SEMICOLON +// | function_decl +} + +import_decl = { "import" ~ string ~ "as" ~ (ident ~ "::")* ~ ident} + +global_variable_decl = { + variable_decoration_list ~ variable_decl + | variable_decoration_list ~ variable_decl ~ "=" ~ const_expr +} + +variable_decoration_list = { "[[" ~ (variable_decoration ~ ",")* ~ variable_decoration ~ "]]" } + +variable_decoration = { + "location" ~ int_literal +// | BUILTIN builtin_decoration +// | BINDING INT_LITERAL +// | SET INT_LITERAL +} + +variable_decl = { "var" ~ variable_storage_decoration? ~ variable_ident_decl } +variable_storage_decoration = _{ "<" ~ storage_class ~ ">" } +variable_ident_decl = { ident ~ ":" ~ type_decl } + +storage_class = { + "in" + | "out" + | "uniform" +// | WORKGROUP +// | UNIFORM_CONSTANT +// | STORAGE_BUFFER +// | IMAGE +// | PUSH_CONSTANT + | "private" + | "function" +} + +const_literal = { + int_literal +// | UINT_LITERAL +// | FLOAT_LITERAL + | "true" + | "false" +} + +const_expr = { + type_decl ~ "(" ~ (const_expr ~ ",")? ~ const_expr ~ ")" + | const_literal +} + +type_decl = { + scalar_type +// | VEC2 LESS_THAN type_decl GREATER_THAN +// | VEC3 LESS_THAN type_decl GREATER_THAN +// | VEC3 LESS_THAN type_decl GREATER_THAN +// | PTR LESS_THAN storage_class, type_decl GREATER_THAN +// | ARRAY LESS_THAN type_decl COMMA INT_LITERAL GREATER_THAN +// | ARRAY LESS_THAN type_decl GREATER_THAN +// | MAT2x2 LESS_THAN type_decl GREATER_THAN +// | MAT2x3 LESS_THAN type_decl GREATER_THAN +// | MAT2x4 LESS_THAN type_decl GREATER_THAN +// | MAT3x2 LESS_THAN type_decl GREATER_THAN +// | MAT3x3 LESS_THAN type_decl GREATER_THAN +// | MAT3x4 LESS_THAN type_decl GREATER_THAN +// | MAT4x2 LESS_THAN type_decl GREATER_THAN +// | MAT4x3 LESS_THAN type_decl GREATER_THAN +// | MAT4x4 LESS_THAN type_decl GREATER_THAN + | ident +} + +scalar_type = { + "bool" + | "f32" + | "i32" + | "u32" +} + +ident = @{ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_")* } +int_literal = @{ ("-"? ~ "0x"? ~ (ASCII_DIGIT | 'a'..'f' | 'A'..'F')+) | "0" | ("-"? ~ '1'..'9' ~ ASCII_DIGIT*) } +string = @{ "\"" ~ ( "\"\"" | (!"\"" ~ ANY) )* ~ "\"" } + +WHITESPACE = _{ " " | "\t" | "\n" } diff --git a/src/front/mod.rs b/src/front/mod.rs index e4919e7735..9c549d16cb 100644 --- a/src/front/mod.rs +++ b/src/front/mod.rs @@ -1 +1,31 @@ pub mod spirv; +pub mod wgsl; + +use crate::storage::Storage; + +pub const GENERATOR: u32 = 0; + +impl crate::Module { + fn from_header(header: crate::Header) -> Self { + crate::Module { + header, + complex_types: crate::ComplexTypes { + pointers: Storage::new(), + arrays: Storage::new(), + structs: Storage::new(), + images: Storage::new(), + samplers: Storage::new(), + }, + global_variables: Storage::new(), + functions: Storage::new(), + entry_points: Vec::new(), + } + } + + fn generate_empty() -> Self { + Self::from_header(crate::Header { + version: (1, 0, 0), + generator: GENERATOR, + }) + } +} diff --git a/src/front/spirv.rs b/src/front/spirv.rs index 67d606cde1..77daeec199 100644 --- a/src/front/spirv.rs +++ b/src/front/spirv.rs @@ -668,31 +668,19 @@ impl> Parser { } pub fn parse(&mut self) -> Result { - let mut module = crate::Module { - header: { - if self.next()? != spirv::MAGIC_NUMBER { - return Err(Error::InvalidHeader); - } - let version_raw = self.next()?.to_le_bytes(); - let generator = self.next()?; - let _bound = self.next()?; - let _schema = self.next()?; - crate::Header { - version: (version_raw[2], version_raw[1], version_raw[0]), - generator, - } - }, - complex_types: crate::ComplexTypes { - pointers: Storage::new(), - arrays: Storage::new(), - structs: Storage::new(), - images: Storage::new(), - samplers: Storage::new(), - }, - global_variables: Storage::new(), - functions: Storage::new(), - entry_points: Vec::new(), - }; + let mut module = crate::Module::from_header({ + if self.next()? != spirv::MAGIC_NUMBER { + return Err(Error::InvalidHeader); + } + let version_raw = self.next()?.to_le_bytes(); + let generator = self.next()?; + let _bound = self.next()?; + let _schema = self.next()?; + crate::Header { + version: (version_raw[2], version_raw[1], version_raw[0]), + generator, + } + }); let mut entry_points = Vec::new(); while let Ok(inst) = self.next_inst() { diff --git a/src/front/wgsl.rs b/src/front/wgsl.rs new file mode 100644 index 0000000000..12f3c15596 --- /dev/null +++ b/src/front/wgsl.rs @@ -0,0 +1,132 @@ +#[derive(Parser)] +#[grammar = "../grammars/wgsl.pest"] +struct Tokenizer; + +#[derive(Debug)] +pub enum Error { + Pest(pest::error::Error), + BadInt(std::num::ParseIntError), + BadStorageClass(String), +} +impl From> for Error { + fn from(error: pest::error::Error) -> Self { + Error::Pest(error) + } +} +impl From for Error { + fn from(error: std::num::ParseIntError) -> Self { + Error::BadInt(error) + } +} + +pub struct Parser { +} + +impl Parser { + fn parse_uint_literal(pair: pest::iterators::Pair) -> Result { + Ok(pair.as_str().parse()?) + } + + fn _parse_int_literal(pair: pest::iterators::Pair) -> Result { + let istr = pair.as_str(); + let (sign, istr) = match &istr[..1] { + "_" => (-1, &istr[1..]), + _ => (1, &istr[..]), + }; + let integer: i32 = istr.parse()?; + Ok(sign * integer) + } + + fn parse_decoration_list(variable_decoration_list: pest::iterators::Pair) -> Result, Error> { + assert_eq!(variable_decoration_list.as_rule(), Rule::variable_decoration_list); + for variable_decoration in variable_decoration_list.into_inner() { + assert_eq!(variable_decoration.as_rule(), Rule::variable_decoration); + let mut inner = variable_decoration.into_inner(); + let first = inner.next().unwrap(); + match first.as_rule() { + Rule::int_literal => { + let location = Self::parse_uint_literal(first)?; + return Ok(Some(crate::Binding::Location(location))); + } + unknown => panic!("Unexpected decoration: {:?}", unknown), + } + } + unimplemented!() + } + + fn parse_storage_class(storage_class: pest::iterators::Pair) -> Result { + match storage_class.as_str() { + "in" => Ok(spirv::StorageClass::Input), + "out" => Ok(spirv::StorageClass::Output), + other => Err(Error::BadStorageClass(other.to_owned())), + } + } + + fn parse_type_decl(type_decl: pest::iterators::Pair) -> Result { + assert_eq!(type_decl.as_rule(), Rule::type_decl); + let inner = type_decl.into_inner().next().unwrap(); + match inner.as_rule() { + Rule::scalar_type => match inner.as_str() { + "f32" => Ok(crate::Type::Scalar { kind: crate::ScalarKind::Float, width: 32 }), + "i32" => Ok(crate::Type::Scalar { kind: crate::ScalarKind::Sint, width: 32 }), + "u32" => Ok(crate::Type::Scalar { kind: crate::ScalarKind::Uint, width: 32 }), + other => panic!("Unexpected scalar {:?}", other), + }, + Rule::ident => unimplemented!(), + other => panic!("Unexpected type {:?}", other), + } + } + + pub fn parse(&mut self, source: &str) -> Result { + use pest::Parser as _; + let pairs = Tokenizer::parse(Rule::translation_unit, source)?; + let mut module = crate::Module::generate_empty(); + for global_decl_maybe in pairs { + match global_decl_maybe.as_rule() { + Rule::global_decl => { + let global_decl = global_decl_maybe.into_inner().next().unwrap(); + match global_decl.as_rule() { + Rule::import_decl => { + let mut import_decl = global_decl.into_inner(); + let path = import_decl.next().unwrap().as_str(); + log::warn!("Ignoring import {:?}", path); + } + Rule::global_variable_decl => { + let mut global_decl_pairs = global_decl.into_inner(); + let binding = Self::parse_decoration_list(global_decl_pairs.next().unwrap())?; + let var_decl = global_decl_pairs.next().unwrap(); + assert_eq!(var_decl.as_rule(), Rule::variable_decl); + let mut var_decl_pairs = var_decl.into_inner(); + let mut body = var_decl_pairs.next().unwrap(); + let class = if body.as_rule() == Rule::storage_class { + let class = Self::parse_storage_class(body)?; + body = var_decl_pairs.next().unwrap(); + class + } else { + spirv::StorageClass::Private + }; + assert_eq!(body.as_rule(), Rule::variable_ident_decl); + let mut var_ident_decl_pairs = body.into_inner(); + let name = var_ident_decl_pairs.next().unwrap().as_str().to_owned(); + let ty = Self::parse_type_decl(var_ident_decl_pairs.next().unwrap())?; + module.global_variables.append(crate::GlobalVariable { + name: Some(name), + class, + binding, + ty, + }); + } + unknown => panic!("Unexpected global decl: {:?}", unknown), + } + } + Rule::EOI => break, + unknown => panic!("Unexpected: {:?}", unknown), + } + } + Ok(module) + } +} + +pub fn parse_str(source: &str) -> Result { + Parser{}.parse(source) +} diff --git a/src/lib.rs b/src/lib.rs index 4bda73b8de..5ab2f86762 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +#[macro_use] +extern crate pest_derive; extern crate spirv_headers as spirv; pub mod back; @@ -15,7 +17,7 @@ use std::{ type FastHashMap = HashMap>; type FastHashSet = HashSet>; -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Header { pub version: (u8, u8, u8), pub generator: u32,