Ray query expressions and special types

This commit is contained in:
Dzmitry Malyshau
2023-02-15 23:35:03 -08:00
parent e46c53d212
commit b856625821
25 changed files with 388 additions and 46 deletions

View File

@@ -566,6 +566,12 @@ fn write_function_expressions(
edges.insert("", expr);
("ArrayLength".into(), 7)
}
E::RayQueryProceedResult => ("rayQueryProceedResult".into(), 4),
E::RayQueryGetIntersection { query, committed } => {
edges.insert("", query);
let ty = if committed { "Committed" } else { "Candidate" };
(format!("rayQueryGet{}Intersection", ty).into(), 4)
}
};
// give uniform expressions an outline

View File

@@ -3280,13 +3280,17 @@ impl<'a, W: Write> Writer<'a, W> {
}
}
// These expressions never show up in `Emit`.
Expression::CallResult(_) | Expression::AtomicResult { .. } => unreachable!(),
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::RayQueryProceedResult => unreachable!(),
// `ArrayLength` is written as `expr.length()` and we convert it to a uint
Expression::ArrayLength(expr) => {
write!(self.out, "uint(")?;
self.write_expr(expr, ctx)?;
write!(self.out, ".length())")?
}
// not supported yet
Expression::RayQueryGetIntersection { .. } => unreachable!(),
}
Ok(())

View File

@@ -2879,8 +2879,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_expr(module, reject, func_ctx)?;
write!(self.out, ")")?
}
// Not supported yet
Expression::RayQueryGetIntersection { .. } => unreachable!(),
// Nothing to do here, since call expression already cached
Expression::CallResult(_) | Expression::AtomicResult { .. } => {}
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::RayQueryProceedResult => {}
}
if !closing_bracket.is_empty() {

View File

@@ -1838,7 +1838,9 @@ impl<W: Write> Writer<W> {
_ => return Err(Error::Validation),
},
// has to be a named expression
crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } => {
crate::Expression::CallResult(_)
| crate::Expression::AtomicResult { .. }
| crate::Expression::RayQueryProceedResult => {
unreachable!()
}
crate::Expression::ArrayLength(expr) => {
@@ -1863,6 +1865,8 @@ impl<W: Write> Writer<W> {
write!(self.out, ")")?;
}
}
// hot supported yet
crate::Expression::RayQueryGetIntersection { .. } => unreachable!(),
}
Ok(())
}

View File

@@ -1386,6 +1386,10 @@ impl<'w> BlockContext<'w> {
id
}
crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
//TODO
crate::Expression::RayQueryProceedResult => unreachable!(),
//TODO
crate::Expression::RayQueryGetIntersection { .. } => unreachable!(),
};
self.cached[expr_handle] = id;

View File

@@ -1622,8 +1622,12 @@ impl<W: Write> Writer<W> {
write!(self.out, ")")?
}
// Not supported yet
Expression::RayQueryGetIntersection { .. } => unreachable!(),
// Nothing to do here, since call expression already cached
Expression::CallResult(_) | Expression::AtomicResult { .. } => {}
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::RayQueryProceedResult => {}
}
Ok(())

View File

@@ -37,6 +37,8 @@ pub enum ConstantSolvingError {
Load,
#[error("Constants don't support image expressions")]
ImageExpression,
#[error("Constants don't support ray query expressions")]
RayQueryExpression,
#[error("Cannot access the type")]
InvalidAccessBase,
#[error("Cannot access at the index")]
@@ -295,6 +297,9 @@ impl<'a> ConstantSolver<'a> {
Expression::ImageSample { .. }
| Expression::ImageLoad { .. }
| Expression::ImageQuery { .. } => Err(ConstantSolvingError::ImageExpression),
Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => {
Err(ConstantSolvingError::RayQueryExpression)
}
}
}

View File

@@ -246,14 +246,7 @@ impl Frontend {
expr: Handle<Expression>,
meta: Span,
) -> Result<()> {
let resolve_ctx = ResolveContext {
constants: &self.module.constants,
types: &self.module.types,
global_vars: &self.module.global_variables,
local_vars: &ctx.locals,
functions: &self.module.functions,
arguments: &ctx.arguments,
};
let resolve_ctx = ResolveContext::with_locals(&self.module, &ctx.locals, &ctx.arguments);
ctx.typifier
.grow(expr, &ctx.expressions, &resolve_ctx)
@@ -312,14 +305,7 @@ impl Frontend {
expr: Handle<Expression>,
meta: Span,
) -> Result<()> {
let resolve_ctx = ResolveContext {
constants: &self.module.constants,
types: &self.module.types,
global_vars: &self.module.global_variables,
local_vars: &ctx.locals,
functions: &self.module.functions,
arguments: &ctx.arguments,
};
let resolve_ctx = ResolveContext::with_locals(&self.module, &ctx.locals, &ctx.arguments);
ctx.typifier
.invalidate(expr, &ctx.expressions, &resolve_ctx)

View File

@@ -3,6 +3,7 @@ Frontend parsers that consume binary and text shaders and load them into [`Modul
*/
mod interpolator;
mod type_gen;
#[cfg(feature = "glsl-in")]
pub mod glsl;

153
src/front/type_gen.rs Normal file
View File

@@ -0,0 +1,153 @@
/*!
Type generators.
*/
use crate::{arena::Handle, span::Span};
impl crate::Module {
pub(super) fn generate_ray_desc_type(&mut self) -> Handle<crate::Type> {
if let Some(handle) = self.special_types.ray_desc {
return handle;
}
let width = 4;
let ty_flag = self.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar {
width,
kind: crate::ScalarKind::Uint,
},
},
Span::UNDEFINED,
);
let ty_scalar = self.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar {
width,
kind: crate::ScalarKind::Float,
},
},
Span::UNDEFINED,
);
let ty_vector = self.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Vector {
size: crate::VectorSize::Tri,
kind: crate::ScalarKind::Float,
width,
},
},
Span::UNDEFINED,
);
let handle = self.types.insert(
crate::Type {
name: Some("RayDesc".to_string()),
inner: crate::TypeInner::Struct {
members: vec![
crate::StructMember {
name: Some("flags".to_string()),
ty: ty_flag,
binding: None,
offset: 0,
},
crate::StructMember {
name: Some("cull_mask".to_string()),
ty: ty_flag,
binding: None,
offset: 4,
},
crate::StructMember {
name: Some("tmin".to_string()),
ty: ty_scalar,
binding: None,
offset: 8,
},
crate::StructMember {
name: Some("tmax".to_string()),
ty: ty_scalar,
binding: None,
offset: 12,
},
crate::StructMember {
name: Some("origin".to_string()),
ty: ty_vector,
binding: None,
offset: 16,
},
crate::StructMember {
name: Some("dir".to_string()),
ty: ty_vector,
binding: None,
offset: 32,
},
],
span: 48,
},
},
Span::UNDEFINED,
);
self.special_types.ray_desc = Some(handle);
handle
}
pub(super) fn generate_ray_intersection_type(&mut self) -> Handle<crate::Type> {
if let Some(handle) = self.special_types.ray_intersection {
return handle;
}
let width = 4;
let ty_flag = self.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar {
width,
kind: crate::ScalarKind::Uint,
},
},
Span::UNDEFINED,
);
let ty_scalar = self.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar {
width,
kind: crate::ScalarKind::Float,
},
},
Span::UNDEFINED,
);
let handle = self.types.insert(
crate::Type {
name: Some("RayIntersection".to_string()),
inner: crate::TypeInner::Struct {
members: vec![
crate::StructMember {
name: Some("kind".to_string()),
ty: ty_flag,
binding: None,
offset: 0,
},
crate::StructMember {
name: Some("t".to_string()),
ty: ty_scalar,
binding: None,
offset: 4,
},
//TODO: the rest
],
span: 8,
},
},
Span::UNDEFINED,
);
self.special_types.ray_intersection = Some(handle);
handle
}
}

View File

@@ -660,6 +660,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
});
ConcreteConstructorHandle::Type(ty)
}
ast::ConstructorType::RayDesc => {
let ty = ctx.module.generate_ray_desc_type();
ConcreteConstructorHandle::Type(ty)
}
ast::ConstructorType::Type(ty) => ConcreteConstructorHandle::Type(ty),
};

View File

@@ -234,14 +234,8 @@ impl<'a> ExpressionContext<'a, '_, '_> {
/// [`self.resolved_inner(handle)`]: ExpressionContext::resolved_inner
/// [`Typifier`]: Typifier
fn grow_types(&mut self, handle: Handle<crate::Expression>) -> Result<&mut Self, Error<'a>> {
let resolve_ctx = ResolveContext {
constants: &self.module.constants,
types: &self.module.types,
global_vars: &self.module.global_variables,
local_vars: self.local_vars,
functions: &self.module.functions,
arguments: self.arguments,
};
let resolve_ctx =
ResolveContext::with_locals(&self.module, self.local_vars, self.arguments);
self.typifier
.grow(handle, self.naga_expressions, &resolve_ctx)
.map_err(Error::InvalidResolve)?;
@@ -1919,6 +1913,54 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
query: crate::ImageQuery::NumSamples,
}
}
"rayQueryInitialize" => {
let mut args = ctx.prepare_args(arguments, 3, span);
let query = self.expression(args.next()?, ctx.reborrow())?;
let acceleration_structure =
self.expression(args.next()?, ctx.reborrow())?;
let descriptor = self.expression(args.next()?, ctx.reborrow())?;
args.finish()?;
let _ = ctx.module.generate_ray_desc_type();
let fun = crate::RayQueryFunction::Initialize {
acceleration_structure,
descriptor,
};
ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions));
ctx.emitter.start(ctx.naga_expressions);
ctx.block
.push(crate::Statement::RayQuery { query, fun }, span);
return Ok(None);
}
"rayQueryProceed" => {
let mut args = ctx.prepare_args(arguments, 1, span);
let query = self.expression(args.next()?, ctx.reborrow())?;
args.finish()?;
let fun = crate::RayQueryFunction::Proceed;
ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions));
let result = ctx
.naga_expressions
.append(crate::Expression::RayQueryProceedResult, span);
ctx.emitter.start(ctx.naga_expressions);
ctx.block
.push(crate::Statement::RayQuery { query, fun }, span);
return Ok(Some(result));
}
"rayQueryGetCommittedIntersection" => {
let mut args = ctx.prepare_args(arguments, 1, span);
let query = self.expression(args.next()?, ctx.reborrow())?;
args.finish()?;
let _ = ctx.module.generate_ray_intersection_type();
crate::Expression::RayQueryGetIntersection {
query,
committed: true,
}
}
_ => return Err(Error::UnknownIdent(function.span, function.name)),
}
};

View File

@@ -370,6 +370,9 @@ pub enum ConstructorType<'a> {
size: ArraySize<'a>,
},
/// Ray description.
RayDesc,
/// Constructing a value of a known Naga IR type.
///
/// This variant is produced only during lowering, when we have Naga types

View File

@@ -441,6 +441,7 @@ impl Parser {
}))
}
"array" => ast::ConstructorType::PartialArray,
"RayDesc" => ast::ConstructorType::RayDesc,
"atomic"
| "binding_array"
| "sampler"
@@ -622,6 +623,18 @@ impl Parser {
let num = res.map_err(|err| Error::BadNumber(span, err))?;
ast::Expression::Literal(ast::Literal::Number(num))
}
(Token::Word("RAY_FLAG_NONE"), _) => {
let _ = lexer.next();
ast::Expression::Literal(ast::Literal::Number(Number::U32(0)))
}
(Token::Word("RAY_FLAG_TERMINATE_ON_FIRST_HIT"), _) => {
let _ = lexer.next();
ast::Expression::Literal(ast::Literal::Number(Number::U32(4)))
}
(Token::Word("RAY_QUERY_INTERSECTION_NONE"), _) => {
let _ = lexer.next();
ast::Expression::Literal(ast::Literal::Number(Number::U32(0)))
}
(Token::Word(word), span) => {
let start = lexer.start_byte_offset();
let _ = lexer.next();

View File

@@ -107,6 +107,9 @@ Naga's rules for when `Expression`s are evaluated are as follows:
[`Atomic`] statement, representing the result of the atomic operation, is
evaluated when the `Atomic` statement is executed.
- Similarly, an [`RayQueryProceedResult`] expression, which is a boolean
indicating if the ray query is finished.
- All other expressions are evaluated when the (unique) [`Statement::Emit`]
statement that covers them is executed.
@@ -1441,6 +1444,13 @@ pub enum Expression {
/// This doesn't match the semantics of spirv's `OpArrayLength`, which must be passed
/// a pointer to a structure containing a runtime array in its' last field.
ArrayLength(Handle<Expression>),
/// Result of `rayQueryProceed`.
RayQueryProceedResult,
/// Result of `rayQueryGet*Intersection`.
RayQueryGetIntersection {
query: Handle<Expression>,
committed: bool,
},
}
pub use block::Block;
@@ -1779,6 +1789,19 @@ pub struct EntryPoint {
pub function: Function,
}
/// Set of special types that can be optionally generated by the frontends.
#[derive(Debug, Default)]
#[cfg_attr(feature = "clone", derive(Clone))]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub struct SpecialTypes {
/// Type for `RayDesc`.
ray_desc: Option<Handle<Type>>,
/// Type for `RayIntersection`.
ray_intersection: Option<Handle<Type>>,
}
/// Shader module.
///
/// A module is a set of constants, global variables and functions, as well as
@@ -1798,6 +1821,8 @@ pub struct EntryPoint {
pub struct Module {
/// Arena for the types defined in this module.
pub types: UniqueArena<Type>,
/// Dictionary of special type handles.
pub special_types: SpecialTypes,
/// Arena for the constants defined in this module.
pub constants: Arena<Constant>,
/// Arena for the global variables defined in this module.

View File

@@ -193,11 +193,14 @@ pub enum ResolveError {
IncompatibleOperands(String),
#[error("Function argument {0} doesn't exist")]
FunctionArgumentNotFound(u32),
#[error("Special type is not registered within the module")]
MissingSpecialType,
}
pub struct ResolveContext<'a> {
pub constants: &'a Arena<crate::Constant>,
pub types: &'a UniqueArena<crate::Type>,
pub special_types: &'a crate::SpecialTypes,
pub global_vars: &'a Arena<crate::GlobalVariable>,
pub local_vars: &'a Arena<crate::LocalVariable>,
pub functions: &'a Arena<crate::Function>,
@@ -205,6 +208,23 @@ pub struct ResolveContext<'a> {
}
impl<'a> ResolveContext<'a> {
/// Initialize a resolve context from the module.
pub fn with_locals(
module: &'a crate::Module,
local_vars: &'a Arena<crate::LocalVariable>,
arguments: &'a [crate::FunctionArgument],
) -> Self {
Self {
constants: &module.constants,
types: &module.types,
special_types: &module.special_types,
global_vars: &module.global_variables,
local_vars,
functions: &module.functions,
arguments,
}
}
/// Determine the type of `expr`.
///
/// The `past` argument must be a closure that can resolve the types of any
@@ -867,6 +887,17 @@ impl<'a> ResolveContext<'a> {
kind: crate::ScalarKind::Uint,
width: 4,
}),
crate::Expression::RayQueryProceedResult => TypeResolution::Value(Ti::Scalar {
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
}),
crate::Expression::RayQueryGetIntersection { .. } => {
let result = self
.special_types
.ray_intersection
.ok_or(ResolveError::MissingSpecialType)?;
TypeResolution::Handle(result)
}
})
}
}

View File

@@ -686,7 +686,7 @@ impl FunctionInfo {
requirements: UniformityRequirements::empty(),
},
E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
E::AtomicResult { .. } => Uniformity {
E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
non_uniform_result: Some(handle),
requirements: UniformityRequirements::empty(),
},
@@ -694,6 +694,13 @@ impl FunctionInfo {
non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
requirements: UniformityRequirements::empty(),
},
E::RayQueryGetIntersection {
query,
committed: _,
} => Uniformity {
non_uniform_result: self.add_ref(query),
requirements: UniformityRequirements::empty(),
},
};
let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
@@ -934,14 +941,8 @@ impl ModuleInfo {
expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
sampling: crate::FastHashSet::default(),
};
let resolve_context = ResolveContext {
constants: &module.constants,
types: &module.types,
global_vars: &module.global_variables,
local_vars: &fun.local_variables,
functions: &module.functions,
arguments: &fun.arguments,
};
let resolve_context =
ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
for (handle, expr) in fun.expressions.iter() {
if let Err(source) = info.process_expression(
@@ -1064,6 +1065,7 @@ fn uniform_control_flow() {
let resolve_context = ResolveContext {
constants: &constant_arena,
types: &type_arena,
special_types: &crate::SpecialTypes::default(),
global_vars: &global_var_arena,
local_vars: &Arena::new(),
functions: &Arena::new(),

View File

@@ -1427,6 +1427,7 @@ impl super::Validator {
return Err(ExpressionError::InvalidArrayType(expr));
}
},
E::RayQueryProceedResult | E::RayQueryGetIntersection { .. } => ShaderStages::all(),
};
Ok(stages)
}

View File

@@ -39,6 +39,7 @@ impl super::Validator {
ref functions,
ref global_variables,
ref types,
ref special_types,
} = module;
// NOTE: Types being first is important. All other forms of validation depend on this.
@@ -194,6 +195,13 @@ impl super::Validator {
validate_function(Some(function_handle), function)?;
}
if let Some(ty) = special_types.ray_desc {
validate_type(ty)?;
}
if let Some(ty) = special_types.ray_intersection {
validate_type(ty)?;
}
Ok(())
}
@@ -379,10 +387,16 @@ impl super::Validator {
handle.check_dep(function)?;
}
}
crate::Expression::AtomicResult { .. } => (),
crate::Expression::AtomicResult { .. } | crate::Expression::RayQueryProceedResult => (),
crate::Expression::ArrayLength(array) => {
handle.check_dep(array)?;
}
crate::Expression::RayQueryGetIntersection {
query,
committed: _,
} => {
handle.check_dep(query)?;
}
}
Ok(())
}

View File

@@ -440,7 +440,9 @@ impl super::Validator {
match types[var.ty].inner {
crate::TypeInner::Image { .. }
| crate::TypeInner::Sampler { .. }
| crate::TypeInner::BindingArray { .. } => {}
| crate::TypeInner::BindingArray { .. }
| crate::TypeInner::AccelerationStructure
| crate::TypeInner::RayQuery => {}
_ => {
return Err(GlobalVariableError::InvalidType(var.space));
}

View File

@@ -622,10 +622,14 @@ impl super::Validator {
Ti::Image { .. } | Ti::Sampler { .. } => {
TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE)
}
Ti::AccelerationStructure | Ti::RayQuery => {
Ti::AccelerationStructure => {
self.require_type_capability(Capabilities::RAY_QUERY)?;
TypeInfo::new(TypeFlags::empty(), Alignment::ONE)
}
Ti::RayQuery => {
self.require_type_capability(Capabilities::RAY_QUERY)?;
TypeInfo::new(TypeFlags::DATA | TypeFlags::SIZED, Alignment::ONE)
}
Ti::BindingArray { .. } => TypeInfo::new(TypeFlags::empty(), Alignment::ONE),
})
}

View File

@@ -1,3 +1,4 @@
@group(0) @binding(0)
var acc_struct: acceleration_structure;
/*
@@ -12,24 +13,41 @@ let RAY_QUERY_INTERSECTION_AABB = 4u;
struct RayDesc {
flags: u32,
cull_mask: u32,
origin: vec3<f32>,
t_min: f32,
dir: vec3<f32>,
t_max: f32,
}*/
origin: vec3<f32>,
dir: vec3<f32>,
}
struct RayIntersection {
kind: u32,
t: f32,
instance_custom_index: u32,
instance_id: u32,
sbt_record_offset: u32,
geometry_index: u32,
primitive_index: u32,
barycentrics: vec2<f32>,
front_face: bool,
//TODO: object ray direction, origin, matrices
}
*/
struct Output {
visible: u32,
}
@group(0) @binding(1)
var<storage, read_write> output: Output;
@compute
@compute @workgroup_size(1)
fn main() {
var rq: ray_query;
rayQueryInitialize(rq, acceleration_structure, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFF, vec3<f32>(0.0), 0.1, vec3<f32>(0.0, 1.0, 0.0), 100.0));
rayQueryInitialize(rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, vec3<f32>(0.0), vec3<f32>(0.0, 1.0, 0.0)));
rayQueryProceed(rq);
output.visible = rayQueryGetCommittedIntersectionType(rq) == RAY_QUERY_COMMITTED_INTERSECTION_NONE;
let intersection = rayQueryGetCommittedIntersection(rq);
output.visible = u32(intersection.kind == RAY_QUERY_INTERSECTION_NONE);
}

View File

@@ -333,6 +333,10 @@
),
),
],
special_types: (
ray_desc: None,
ray_intersection: None,
),
constants: [
(
name: None,

View File

@@ -38,6 +38,10 @@
),
),
],
special_types: (
ray_desc: None,
ray_intersection: None,
),
constants: [
(
name: None,

View File

@@ -286,6 +286,10 @@
),
),
],
special_types: (
ray_desc: None,
ray_intersection: None,
),
constants: [
(
name: None,