mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-04-22 03:02:01 -04:00
[wgsl] support unsigned literals
This commit is contained in:
committed by
Dzmitry Malyshau
parent
5b35b04546
commit
0ea8a0a3c2
@@ -13,7 +13,8 @@ fn consume_any(input: &str, what: impl Fn(char) -> bool) -> (&str, &str) {
|
||||
input.split_at(pos)
|
||||
}
|
||||
|
||||
fn consume_number(input: &str) -> (&str, &str) {
|
||||
fn consume_number(input: &str) -> (Token, &str) {
|
||||
//Note: I wish this function was simpler and faster...
|
||||
let mut is_first_char = true;
|
||||
let mut right_after_exponent = false;
|
||||
|
||||
@@ -32,7 +33,28 @@ fn consume_number(input: &str) -> (&str, &str) {
|
||||
}
|
||||
};
|
||||
let pos = input.find(|c| !what(c)).unwrap_or_else(|| input.len());
|
||||
input.split_at(pos)
|
||||
let (value, rest) = input.split_at(pos);
|
||||
|
||||
let mut rest_iter = rest.chars();
|
||||
let ty = rest_iter.next().unwrap_or(' ');
|
||||
match ty {
|
||||
'u' | 'i' | 'f' => {
|
||||
let width_end = rest_iter
|
||||
.position(|c| !('0'..='9').contains(&c))
|
||||
.unwrap_or_else(|| rest.len() - 1);
|
||||
let (width, rest) = rest[1..].split_at(width_end);
|
||||
(Token::Number { value, ty, width }, rest)
|
||||
}
|
||||
// default to `i32` or `f32`
|
||||
_ => (
|
||||
Token::Number {
|
||||
value,
|
||||
ty: if value.contains('.') { 'f' } else { 'i' },
|
||||
width: "",
|
||||
},
|
||||
rest,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn consume_token(mut input: &str) -> (Token<'_>, &str) {
|
||||
@@ -56,10 +78,7 @@ fn consume_token(mut input: &str) -> (Token<'_>, &str) {
|
||||
'.' => {
|
||||
let og_chars = chars.as_str();
|
||||
match chars.next() {
|
||||
Some('0'..='9') => {
|
||||
let (number, rest) = consume_number(input);
|
||||
(Token::Number(number), rest)
|
||||
}
|
||||
Some('0'..='9') => consume_number(input),
|
||||
_ => (Token::Separator(cur), og_chars),
|
||||
}
|
||||
}
|
||||
@@ -83,10 +102,7 @@ fn consume_token(mut input: &str) -> (Token<'_>, &str) {
|
||||
(Token::Paren(cur), input)
|
||||
}
|
||||
}
|
||||
'0'..='9' => {
|
||||
let (number, rest) = consume_number(input);
|
||||
(Token::Number(number), rest)
|
||||
}
|
||||
'0'..='9' => consume_number(input),
|
||||
'a'..='z' | 'A'..='Z' | '_' => {
|
||||
let (word, rest) = consume_any(input, |c| c.is_ascii_alphanumeric() || c == '_');
|
||||
(Token::Word(word), rest)
|
||||
@@ -115,10 +131,7 @@ fn consume_token(mut input: &str) -> (Token<'_>, &str) {
|
||||
let og_chars = chars.as_str();
|
||||
match chars.next() {
|
||||
Some('>') => (Token::Arrow, chars.as_str()),
|
||||
Some('0'..='9') | Some('.') => {
|
||||
let (number, rest) = consume_number(input);
|
||||
(Token::Number(number), rest)
|
||||
}
|
||||
Some('0'..='9') | Some('.') => consume_number(input),
|
||||
_ => (Token::Operation(cur), og_chars),
|
||||
}
|
||||
}
|
||||
@@ -193,21 +206,25 @@ impl<'a> Lexer<'a> {
|
||||
|
||||
fn _next_float_literal(&mut self) -> Result<f32, Error<'a>> {
|
||||
match self.next() {
|
||||
Token::Number(word) => word.parse().map_err(|err| Error::BadFloat(word, err)),
|
||||
Token::Number { value, .. } => value.parse().map_err(|err| Error::BadFloat(value, err)),
|
||||
other => other.unexpected("float literal"),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn next_uint_literal(&mut self) -> Result<u32, Error<'a>> {
|
||||
match self.next() {
|
||||
Token::Number(word) => word.parse().map_err(|err| Error::BadInteger(word, err)),
|
||||
Token::Number { value, .. } => {
|
||||
value.parse().map_err(|err| Error::BadInteger(value, err))
|
||||
}
|
||||
other => other.unexpected("uint literal"),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn next_sint_literal(&mut self) -> Result<i32, Error<'a>> {
|
||||
match self.next() {
|
||||
Token::Number(word) => word.parse().map_err(|err| Error::BadInteger(word, err)),
|
||||
Token::Number { value, .. } => {
|
||||
value.parse().map_err(|err| Error::BadInteger(value, err))
|
||||
}
|
||||
other => other.unexpected("sint literal"),
|
||||
}
|
||||
}
|
||||
@@ -246,7 +263,39 @@ fn sub_test(source: &str, expected_tokens: &[Token]) {
|
||||
#[test]
|
||||
fn test_tokens() {
|
||||
sub_test("id123_OK", &[Token::Word("id123_OK")]);
|
||||
sub_test("92No", &[Token::Number("92"), Token::Word("No")]);
|
||||
sub_test(
|
||||
"92No",
|
||||
&[
|
||||
Token::Number {
|
||||
value: "92",
|
||||
ty: 'i',
|
||||
width: "",
|
||||
},
|
||||
Token::Word("No"),
|
||||
],
|
||||
);
|
||||
sub_test(
|
||||
"2u3o",
|
||||
&[
|
||||
Token::Number {
|
||||
value: "2",
|
||||
ty: 'u',
|
||||
width: "3",
|
||||
},
|
||||
Token::Word("o"),
|
||||
],
|
||||
);
|
||||
sub_test(
|
||||
"2.4f44po",
|
||||
&[
|
||||
Token::Number {
|
||||
value: "2.4",
|
||||
ty: 'f',
|
||||
width: "44",
|
||||
},
|
||||
Token::Word("po"),
|
||||
],
|
||||
);
|
||||
sub_test(
|
||||
"æNoø",
|
||||
&[Token::Unknown('æ'), Token::Word("No"), Token::Unknown('ø')],
|
||||
@@ -264,7 +313,11 @@ fn test_variable_decl() {
|
||||
Token::DoubleParen('['),
|
||||
Token::Word("group"),
|
||||
Token::Paren('('),
|
||||
Token::Number("0"),
|
||||
Token::Number {
|
||||
value: "0",
|
||||
ty: 'i',
|
||||
width: "",
|
||||
},
|
||||
Token::Paren(')'),
|
||||
Token::DoubleParen(']'),
|
||||
Token::Word("var"),
|
||||
|
||||
@@ -23,7 +23,11 @@ pub enum Token<'a> {
|
||||
DoubleColon,
|
||||
Paren(char),
|
||||
DoubleParen(char),
|
||||
Number(&'a str),
|
||||
Number {
|
||||
value: &'a str,
|
||||
ty: char,
|
||||
width: &'a str,
|
||||
},
|
||||
String(&'a str),
|
||||
Word(&'a str),
|
||||
Operation(char),
|
||||
@@ -50,6 +54,8 @@ pub enum Error<'a> {
|
||||
BadInteger(&'a str, std::num::ParseIntError),
|
||||
#[error("unable to parse `{1}` as float: {1}")]
|
||||
BadFloat(&'a str, std::num::ParseFloatError),
|
||||
#[error("unable to parse `{0}{1}{2}` as scalar width: {3}")]
|
||||
BadScalarWidth(&'a str, char, &'a str, std::num::ParseIntError),
|
||||
#[error("bad field accessor `{0}`")]
|
||||
BadAccessor(&'a str),
|
||||
#[error("bad texture {0}`")]
|
||||
@@ -369,16 +375,37 @@ impl Parser {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_scalar_value(word: &str) -> Result<crate::ScalarValue, Error<'_>> {
|
||||
if word.contains('.') {
|
||||
word.parse()
|
||||
.map(crate::ScalarValue::Float)
|
||||
.map_err(|err| Error::BadFloat(word, err))
|
||||
} else {
|
||||
word.parse()
|
||||
fn get_constant_inner<'a>(
|
||||
word: &'a str,
|
||||
ty: char,
|
||||
width: &'a str,
|
||||
) -> Result<crate::ConstantInner, Error<'a>> {
|
||||
let value = match ty {
|
||||
'i' => word
|
||||
.parse()
|
||||
.map(crate::ScalarValue::Sint)
|
||||
.map_err(|err| Error::BadInteger(word, err))
|
||||
}
|
||||
.map_err(|err| Error::BadInteger(word, err))?,
|
||||
'u' => word
|
||||
.parse()
|
||||
.map(crate::ScalarValue::Uint)
|
||||
.map_err(|err| Error::BadInteger(word, err))?,
|
||||
'f' => word
|
||||
.parse()
|
||||
.map(crate::ScalarValue::Float)
|
||||
.map_err(|err| Error::BadFloat(word, err))?,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(crate::ConstantInner::Scalar {
|
||||
value,
|
||||
width: if width.is_empty() {
|
||||
4
|
||||
} else {
|
||||
match width.parse::<crate::Bytes>() {
|
||||
Ok(bits) => bits / 8,
|
||||
Err(e) => return Err(Error::BadScalarWidth(word, ty, width, e)),
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_function_call_inner<'a>(
|
||||
@@ -698,12 +725,9 @@ impl Parser {
|
||||
value: crate::ScalarValue::Bool(false),
|
||||
}
|
||||
}
|
||||
Token::Number(word) => {
|
||||
Token::Number { value, ty, width } => {
|
||||
let _ = lexer.next();
|
||||
crate::ConstantInner::Scalar {
|
||||
width: 4,
|
||||
value: Self::get_scalar_value(word)?,
|
||||
}
|
||||
Self::get_constant_inner(value, ty, width)?
|
||||
}
|
||||
_ => {
|
||||
let (composite_ty, _access) =
|
||||
@@ -773,12 +797,12 @@ impl Parser {
|
||||
});
|
||||
crate::Expression::Constant(handle)
|
||||
}
|
||||
Token::Number(word) => {
|
||||
let value = Self::get_scalar_value(word)?;
|
||||
Token::Number { value, ty, width } => {
|
||||
let inner = Self::get_constant_inner(value, ty, width)?;
|
||||
let handle = ctx.constants.fetch_or_append(crate::Constant {
|
||||
name: None,
|
||||
specialization: None,
|
||||
inner: crate::ConstantInner::Scalar { width: 4, value },
|
||||
inner,
|
||||
});
|
||||
crate::Expression::Constant(handle)
|
||||
}
|
||||
|
||||
@@ -155,30 +155,30 @@ fn convert_wgsl(name: &str, language: Language) {
|
||||
|
||||
#[cfg(feature = "wgsl-in")]
|
||||
#[test]
|
||||
fn converts_wgsl_quad() {
|
||||
fn convert_wgsl_quad() {
|
||||
convert_wgsl("quad", Language::all());
|
||||
}
|
||||
|
||||
#[cfg(feature = "wgsl-in")]
|
||||
#[test]
|
||||
fn converts_wgsl_simple() {
|
||||
convert_wgsl("simple", Language::all());
|
||||
fn convert_wgsl_empty() {
|
||||
convert_wgsl("empty", Language::all());
|
||||
}
|
||||
|
||||
#[cfg(feature = "wgsl-in")]
|
||||
#[test]
|
||||
fn converts_wgsl_boids() {
|
||||
fn convert_wgsl_boids() {
|
||||
convert_wgsl("boids", Language::METAL);
|
||||
}
|
||||
|
||||
#[cfg(feature = "wgsl-in")]
|
||||
#[test]
|
||||
fn converts_wgsl_skybox() {
|
||||
fn convert_wgsl_skybox() {
|
||||
convert_wgsl("skybox", Language::all());
|
||||
}
|
||||
|
||||
#[cfg(feature = "wgsl-in")]
|
||||
#[test]
|
||||
fn converts_wgsl_collatz() {
|
||||
fn convert_wgsl_collatz() {
|
||||
convert_wgsl("collatz", Language::METAL);
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ struct Particles {
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
const index : u32 = gl_GlobalInvocationID.x;
|
||||
if (index >= u32(5)) {
|
||||
if (index >= 5u) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -84,9 +84,9 @@ fn main() {
|
||||
|
||||
var pos : vec2<f32>;
|
||||
var vel : vec2<f32>;
|
||||
var i : u32 = u32(0);
|
||||
var i : u32 = 0u;
|
||||
loop {
|
||||
if (i >= u32(5)) {
|
||||
if (i >= 5u) {
|
||||
break;
|
||||
}
|
||||
if (i == index) {
|
||||
@@ -109,7 +109,7 @@ fn main() {
|
||||
}
|
||||
|
||||
continuing {
|
||||
i = i + u32(1);
|
||||
i = i + 1u;
|
||||
}
|
||||
}
|
||||
if (cMassCount > 0) {
|
||||
|
||||
@@ -15,18 +15,18 @@ var<storage> v_indices: [[access(read_write)]] PrimeIndices;
|
||||
// Though the conjecture has not been proven, no counterexample has ever been found.
|
||||
// This function returns how many times this recurrence needs to be applied to reach 1.
|
||||
fn collatz_iterations(n: u32) -> u32{
|
||||
var i: u32 = u32(0);
|
||||
var i: u32 = 0u;
|
||||
loop {
|
||||
if (n <= u32(1)) {
|
||||
if (n <= 1u) {
|
||||
break;
|
||||
}
|
||||
if (n % u32(2) == u32(0)) {
|
||||
n = n / u32(2);
|
||||
if (n % 2u == 0u) {
|
||||
n = n / 2u;
|
||||
}
|
||||
else {
|
||||
n = u32(3) * n + i32(1);
|
||||
n = 3u * n + 1u;
|
||||
}
|
||||
i = i + u32(1);
|
||||
i = i + 1u;
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
2
tests/snapshots/in/empty.wgsl
Normal file
2
tests/snapshots/in/empty.wgsl
Normal file
@@ -0,0 +1,2 @@
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {}
|
||||
@@ -1,7 +0,0 @@
|
||||
// vertex
|
||||
[[builtin(position)]] var<out> o_position : vec4<f32>;
|
||||
|
||||
[[stage(vertex)]]
|
||||
fn main() {
|
||||
o_position = vec4<f32>(1);
|
||||
}
|
||||
@@ -71,17 +71,16 @@ kernel void main3(
|
||||
type6 cVelCount = 0;
|
||||
type pos1;
|
||||
type vel1;
|
||||
type5 i;
|
||||
if (gl_GlobalInvocationID.x >= static_cast<uint>(5)) {
|
||||
type5 i = 0;
|
||||
if (gl_GlobalInvocationID.x >= 5) {
|
||||
}
|
||||
vPos = particlesA.particles[gl_GlobalInvocationID.x].pos;
|
||||
vVel = particlesA.particles[gl_GlobalInvocationID.x].vel;
|
||||
cMass = metal::float2(0.0, 0.0);
|
||||
cVel = metal::float2(0.0, 0.0);
|
||||
colVel = metal::float2(0.0, 0.0);
|
||||
i = static_cast<uint>(0);
|
||||
while(true) {
|
||||
if (i >= static_cast<uint>(5)) {
|
||||
if (i >= 5) {
|
||||
break;
|
||||
}
|
||||
if (i == gl_GlobalInvocationID.x) {
|
||||
|
||||
@@ -15,18 +15,17 @@ struct PrimeIndices {
|
||||
type1 collatz_iterations(
|
||||
type1 n
|
||||
) {
|
||||
type1 i;
|
||||
i = static_cast<uint>(0);
|
||||
type1 i = 0;
|
||||
while(true) {
|
||||
if (n <= static_cast<uint>(1)) {
|
||||
if (n <= 1) {
|
||||
break;
|
||||
}
|
||||
if (n % static_cast<uint>(2) == static_cast<uint>(0)) {
|
||||
n = n / static_cast<uint>(2);
|
||||
if (n % 2 == 0) {
|
||||
n = n / 2;
|
||||
} else {
|
||||
n = static_cast<uint>(3) * n + static_cast<int>(1);
|
||||
n = 3 * n + 1;
|
||||
}
|
||||
i = i + static_cast<uint>(1);
|
||||
i = i + 1;
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ precision highp float;
|
||||
|
||||
void main() {
|
||||
|
||||
gl_Position = vec4(1);
|
||||
return;
|
||||
}
|
||||
|
||||
12
tests/snapshots/snapshots__empty.msl.snap
Normal file
12
tests/snapshots/snapshots__empty.msl.snap
Normal file
@@ -0,0 +1,12 @@
|
||||
---
|
||||
source: tests/snapshots.rs
|
||||
expression: msl
|
||||
---
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
|
||||
kernel void main1(
|
||||
) {
|
||||
}
|
||||
|
||||
@@ -5,23 +5,15 @@ expression: dis
|
||||
; SPIR-V
|
||||
; Version: 1.0
|
||||
; Generator: rspirv
|
||||
; Bound: 13
|
||||
; Bound: 6
|
||||
OpCapability Shader
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint Vertex %3 "main" %6
|
||||
OpDecorate %6 BuiltIn Position
|
||||
OpEntryPoint GLCompute %3 "main"
|
||||
OpExecutionMode %3 LocalSize 1 1 1
|
||||
%2 = OpTypeVoid
|
||||
%4 = OpTypeFunction %2
|
||||
%8 = OpTypeFloat 32
|
||||
%7 = OpTypeVector %8 4
|
||||
%9 = OpTypePointer Output %7
|
||||
%6 = OpVariable %9 Output
|
||||
%11 = OpTypeInt 32 1
|
||||
%10 = OpConstant %11 1
|
||||
%3 = OpFunction %2 None %4
|
||||
%5 = OpLabel
|
||||
%12 = OpCompositeConstruct %7 %10
|
||||
OpStore %6 %12
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
@@ -1,22 +0,0 @@
|
||||
---
|
||||
source: tests/snapshots.rs
|
||||
expression: msl
|
||||
---
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
typedef metal::float4 type;
|
||||
|
||||
struct main1Input {
|
||||
};
|
||||
struct main1Output {
|
||||
type o_position [[position]];
|
||||
};
|
||||
vertex main1Output main1(
|
||||
main1Input input [[stage_in]]
|
||||
) {
|
||||
main1Output output;
|
||||
output.o_position = metal::float4(1);
|
||||
return output;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user