[wgsl] support unsigned literals

This commit is contained in:
Dzmitry Malyshau
2021-01-23 01:57:02 -05:00
committed by Dzmitry Malyshau
parent 5b35b04546
commit 0ea8a0a3c2
14 changed files with 156 additions and 105 deletions

View File

@@ -13,7 +13,8 @@ fn consume_any(input: &str, what: impl Fn(char) -> bool) -> (&str, &str) {
input.split_at(pos)
}
fn consume_number(input: &str) -> (&str, &str) {
fn consume_number(input: &str) -> (Token, &str) {
//Note: I wish this function was simpler and faster...
let mut is_first_char = true;
let mut right_after_exponent = false;
@@ -32,7 +33,28 @@ fn consume_number(input: &str) -> (&str, &str) {
}
};
let pos = input.find(|c| !what(c)).unwrap_or_else(|| input.len());
input.split_at(pos)
let (value, rest) = input.split_at(pos);
let mut rest_iter = rest.chars();
let ty = rest_iter.next().unwrap_or(' ');
match ty {
'u' | 'i' | 'f' => {
let width_end = rest_iter
.position(|c| !('0'..='9').contains(&c))
.unwrap_or_else(|| rest.len() - 1);
let (width, rest) = rest[1..].split_at(width_end);
(Token::Number { value, ty, width }, rest)
}
// default to `i32` or `f32`
_ => (
Token::Number {
value,
ty: if value.contains('.') { 'f' } else { 'i' },
width: "",
},
rest,
),
}
}
fn consume_token(mut input: &str) -> (Token<'_>, &str) {
@@ -56,10 +78,7 @@ fn consume_token(mut input: &str) -> (Token<'_>, &str) {
'.' => {
let og_chars = chars.as_str();
match chars.next() {
Some('0'..='9') => {
let (number, rest) = consume_number(input);
(Token::Number(number), rest)
}
Some('0'..='9') => consume_number(input),
_ => (Token::Separator(cur), og_chars),
}
}
@@ -83,10 +102,7 @@ fn consume_token(mut input: &str) -> (Token<'_>, &str) {
(Token::Paren(cur), input)
}
}
'0'..='9' => {
let (number, rest) = consume_number(input);
(Token::Number(number), rest)
}
'0'..='9' => consume_number(input),
'a'..='z' | 'A'..='Z' | '_' => {
let (word, rest) = consume_any(input, |c| c.is_ascii_alphanumeric() || c == '_');
(Token::Word(word), rest)
@@ -115,10 +131,7 @@ fn consume_token(mut input: &str) -> (Token<'_>, &str) {
let og_chars = chars.as_str();
match chars.next() {
Some('>') => (Token::Arrow, chars.as_str()),
Some('0'..='9') | Some('.') => {
let (number, rest) = consume_number(input);
(Token::Number(number), rest)
}
Some('0'..='9') | Some('.') => consume_number(input),
_ => (Token::Operation(cur), og_chars),
}
}
@@ -193,21 +206,25 @@ impl<'a> Lexer<'a> {
fn _next_float_literal(&mut self) -> Result<f32, Error<'a>> {
match self.next() {
Token::Number(word) => word.parse().map_err(|err| Error::BadFloat(word, err)),
Token::Number { value, .. } => value.parse().map_err(|err| Error::BadFloat(value, err)),
other => other.unexpected("float literal"),
}
}
pub(super) fn next_uint_literal(&mut self) -> Result<u32, Error<'a>> {
match self.next() {
Token::Number(word) => word.parse().map_err(|err| Error::BadInteger(word, err)),
Token::Number { value, .. } => {
value.parse().map_err(|err| Error::BadInteger(value, err))
}
other => other.unexpected("uint literal"),
}
}
pub(super) fn next_sint_literal(&mut self) -> Result<i32, Error<'a>> {
match self.next() {
Token::Number(word) => word.parse().map_err(|err| Error::BadInteger(word, err)),
Token::Number { value, .. } => {
value.parse().map_err(|err| Error::BadInteger(value, err))
}
other => other.unexpected("sint literal"),
}
}
@@ -246,7 +263,39 @@ fn sub_test(source: &str, expected_tokens: &[Token]) {
#[test]
fn test_tokens() {
sub_test("id123_OK", &[Token::Word("id123_OK")]);
sub_test("92No", &[Token::Number("92"), Token::Word("No")]);
sub_test(
"92No",
&[
Token::Number {
value: "92",
ty: 'i',
width: "",
},
Token::Word("No"),
],
);
sub_test(
"2u3o",
&[
Token::Number {
value: "2",
ty: 'u',
width: "3",
},
Token::Word("o"),
],
);
sub_test(
"2.4f44po",
&[
Token::Number {
value: "2.4",
ty: 'f',
width: "44",
},
Token::Word("po"),
],
);
sub_test(
"æNoø",
&[Token::Unknown('æ'), Token::Word("No"), Token::Unknown('ø')],
@@ -264,7 +313,11 @@ fn test_variable_decl() {
Token::DoubleParen('['),
Token::Word("group"),
Token::Paren('('),
Token::Number("0"),
Token::Number {
value: "0",
ty: 'i',
width: "",
},
Token::Paren(')'),
Token::DoubleParen(']'),
Token::Word("var"),

View File

@@ -23,7 +23,11 @@ pub enum Token<'a> {
DoubleColon,
Paren(char),
DoubleParen(char),
Number(&'a str),
Number {
value: &'a str,
ty: char,
width: &'a str,
},
String(&'a str),
Word(&'a str),
Operation(char),
@@ -50,6 +54,8 @@ pub enum Error<'a> {
BadInteger(&'a str, std::num::ParseIntError),
#[error("unable to parse `{1}` as float: {1}")]
BadFloat(&'a str, std::num::ParseFloatError),
#[error("unable to parse `{0}{1}{2}` as scalar width: {3}")]
BadScalarWidth(&'a str, char, &'a str, std::num::ParseIntError),
#[error("bad field accessor `{0}`")]
BadAccessor(&'a str),
#[error("bad texture {0}`")]
@@ -369,16 +375,37 @@ impl Parser {
}
}
fn get_scalar_value(word: &str) -> Result<crate::ScalarValue, Error<'_>> {
if word.contains('.') {
word.parse()
.map(crate::ScalarValue::Float)
.map_err(|err| Error::BadFloat(word, err))
} else {
word.parse()
fn get_constant_inner<'a>(
word: &'a str,
ty: char,
width: &'a str,
) -> Result<crate::ConstantInner, Error<'a>> {
let value = match ty {
'i' => word
.parse()
.map(crate::ScalarValue::Sint)
.map_err(|err| Error::BadInteger(word, err))
}
.map_err(|err| Error::BadInteger(word, err))?,
'u' => word
.parse()
.map(crate::ScalarValue::Uint)
.map_err(|err| Error::BadInteger(word, err))?,
'f' => word
.parse()
.map(crate::ScalarValue::Float)
.map_err(|err| Error::BadFloat(word, err))?,
_ => unreachable!(),
};
Ok(crate::ConstantInner::Scalar {
value,
width: if width.is_empty() {
4
} else {
match width.parse::<crate::Bytes>() {
Ok(bits) => bits / 8,
Err(e) => return Err(Error::BadScalarWidth(word, ty, width, e)),
}
},
})
}
fn parse_function_call_inner<'a>(
@@ -698,12 +725,9 @@ impl Parser {
value: crate::ScalarValue::Bool(false),
}
}
Token::Number(word) => {
Token::Number { value, ty, width } => {
let _ = lexer.next();
crate::ConstantInner::Scalar {
width: 4,
value: Self::get_scalar_value(word)?,
}
Self::get_constant_inner(value, ty, width)?
}
_ => {
let (composite_ty, _access) =
@@ -773,12 +797,12 @@ impl Parser {
});
crate::Expression::Constant(handle)
}
Token::Number(word) => {
let value = Self::get_scalar_value(word)?;
Token::Number { value, ty, width } => {
let inner = Self::get_constant_inner(value, ty, width)?;
let handle = ctx.constants.fetch_or_append(crate::Constant {
name: None,
specialization: None,
inner: crate::ConstantInner::Scalar { width: 4, value },
inner,
});
crate::Expression::Constant(handle)
}

View File

@@ -155,30 +155,30 @@ fn convert_wgsl(name: &str, language: Language) {
#[cfg(feature = "wgsl-in")]
#[test]
fn converts_wgsl_quad() {
fn convert_wgsl_quad() {
convert_wgsl("quad", Language::all());
}
#[cfg(feature = "wgsl-in")]
#[test]
fn converts_wgsl_simple() {
convert_wgsl("simple", Language::all());
fn convert_wgsl_empty() {
convert_wgsl("empty", Language::all());
}
#[cfg(feature = "wgsl-in")]
#[test]
fn converts_wgsl_boids() {
fn convert_wgsl_boids() {
convert_wgsl("boids", Language::METAL);
}
#[cfg(feature = "wgsl-in")]
#[test]
fn converts_wgsl_skybox() {
fn convert_wgsl_skybox() {
convert_wgsl("skybox", Language::all());
}
#[cfg(feature = "wgsl-in")]
#[test]
fn converts_wgsl_collatz() {
fn convert_wgsl_collatz() {
convert_wgsl("collatz", Language::METAL);
}

View File

@@ -69,7 +69,7 @@ struct Particles {
[[stage(compute), workgroup_size(1)]]
fn main() {
const index : u32 = gl_GlobalInvocationID.x;
if (index >= u32(5)) {
if (index >= 5u) {
return;
}
@@ -84,9 +84,9 @@ fn main() {
var pos : vec2<f32>;
var vel : vec2<f32>;
var i : u32 = u32(0);
var i : u32 = 0u;
loop {
if (i >= u32(5)) {
if (i >= 5u) {
break;
}
if (i == index) {
@@ -109,7 +109,7 @@ fn main() {
}
continuing {
i = i + u32(1);
i = i + 1u;
}
}
if (cMassCount > 0) {

View File

@@ -15,18 +15,18 @@ var<storage> v_indices: [[access(read_write)]] PrimeIndices;
// Though the conjecture has not been proven, no counterexample has ever been found.
// This function returns how many times this recurrence needs to be applied to reach 1.
fn collatz_iterations(n: u32) -> u32{
var i: u32 = u32(0);
var i: u32 = 0u;
loop {
if (n <= u32(1)) {
if (n <= 1u) {
break;
}
if (n % u32(2) == u32(0)) {
n = n / u32(2);
if (n % 2u == 0u) {
n = n / 2u;
}
else {
n = u32(3) * n + i32(1);
n = 3u * n + 1u;
}
i = i + u32(1);
i = i + 1u;
}
return i;
}

View File

@@ -0,0 +1,2 @@
[[stage(compute), workgroup_size(1)]]
fn main() {}

View File

@@ -1,7 +0,0 @@
// vertex
[[builtin(position)]] var<out> o_position : vec4<f32>;
[[stage(vertex)]]
fn main() {
o_position = vec4<f32>(1);
}

View File

@@ -71,17 +71,16 @@ kernel void main3(
type6 cVelCount = 0;
type pos1;
type vel1;
type5 i;
if (gl_GlobalInvocationID.x >= static_cast<uint>(5)) {
type5 i = 0;
if (gl_GlobalInvocationID.x >= 5) {
}
vPos = particlesA.particles[gl_GlobalInvocationID.x].pos;
vVel = particlesA.particles[gl_GlobalInvocationID.x].vel;
cMass = metal::float2(0.0, 0.0);
cVel = metal::float2(0.0, 0.0);
colVel = metal::float2(0.0, 0.0);
i = static_cast<uint>(0);
while(true) {
if (i >= static_cast<uint>(5)) {
if (i >= 5) {
break;
}
if (i == gl_GlobalInvocationID.x) {

View File

@@ -15,18 +15,17 @@ struct PrimeIndices {
type1 collatz_iterations(
type1 n
) {
type1 i;
i = static_cast<uint>(0);
type1 i = 0;
while(true) {
if (n <= static_cast<uint>(1)) {
if (n <= 1) {
break;
}
if (n % static_cast<uint>(2) == static_cast<uint>(0)) {
n = n / static_cast<uint>(2);
if (n % 2 == 0) {
n = n / 2;
} else {
n = static_cast<uint>(3) * n + static_cast<int>(1);
n = 3 * n + 1;
}
i = i + static_cast<uint>(1);
i = i + 1;
}
return i;
}

View File

@@ -11,7 +11,6 @@ precision highp float;
void main() {
gl_Position = vec4(1);
return;
}

View File

@@ -0,0 +1,12 @@
---
source: tests/snapshots.rs
expression: msl
---
#include <metal_stdlib>
#include <simd/simd.h>
kernel void main1(
) {
}

View File

@@ -5,23 +5,15 @@ expression: dis
; SPIR-V
; Version: 1.0
; Generator: rspirv
; Bound: 13
; Bound: 6
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Vertex %3 "main" %6
OpDecorate %6 BuiltIn Position
OpEntryPoint GLCompute %3 "main"
OpExecutionMode %3 LocalSize 1 1 1
%2 = OpTypeVoid
%4 = OpTypeFunction %2
%8 = OpTypeFloat 32
%7 = OpTypeVector %8 4
%9 = OpTypePointer Output %7
%6 = OpVariable %9 Output
%11 = OpTypeInt 32 1
%10 = OpConstant %11 1
%3 = OpFunction %2 None %4
%5 = OpLabel
%12 = OpCompositeConstruct %7 %10
OpStore %6 %12
OpReturn
OpFunctionEnd

View File

@@ -1,22 +0,0 @@
---
source: tests/snapshots.rs
expression: msl
---
#include <metal_stdlib>
#include <simd/simd.h>
typedef metal::float4 type;
struct main1Input {
};
struct main1Output {
type o_position [[position]];
};
vertex main1Output main1(
main1Input input [[stage_in]]
) {
main1Output output;
output.o_position = metal::float4(1);
return output;
}