[valid] Handle texture/sampler function argument

This commit is contained in:
João Capucho
2021-06-24 22:01:40 +01:00
committed by Dzmitry Malyshau
parent f98b4e2f48
commit 7df4a52af9
6 changed files with 244 additions and 58 deletions

View File

@@ -159,6 +159,24 @@ impl ExpressionInfo {
}
}
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
enum GlobalOrArgument {
Global(Handle<crate::GlobalVariable>),
Argument(u32),
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct ArgumentSampling {
/// Whether this argument is used as an image or a sampler
image: bool,
/// The other global or argument used in the sampling
uses: Vec<GlobalOrArgument>,
}
#[derive(Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
@@ -181,6 +199,10 @@ pub struct FunctionInfo {
///
/// Each item corresponds to an expression in the function.
expressions: Box<[ExpressionInfo]>,
/// Vector with information of wether or not a functin argument is used for sampling
///
/// Each item corresponds to a function argument
argument_sampling: Box<[Option<ArgumentSampling>]>,
}
impl FunctionInfo {
@@ -273,21 +295,101 @@ impl FunctionInfo {
}
/// Inherit information from a called function.
fn process_call(&mut self, info: &Self) -> FunctionUniformity {
fn process_call(
&mut self,
info: &Self,
arguments: &[Handle<crate::Expression>],
expression_arena: &Arena<crate::Expression>,
) -> Result<FunctionUniformity, FunctionError> {
for key in info.sampling_set.iter() {
self.sampling_set.insert(key.clone());
}
for (i, arg_sampling) in info
.argument_sampling
.iter()
.enumerate()
.filter_map(|(i, s)| Some((i, s.as_ref()?)))
{
let handle = arguments[i];
let arg_storage = match expression_arena[handle] {
crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
_ => {
return Err(FunctionError::Expression {
handle,
error: ExpressionError::ExpectedGlobalVariable,
})
}
};
for other in arg_sampling.uses.iter() {
let other_storage = match *other {
GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
GlobalOrArgument::Argument(i) => {
let other_handle = arguments[i as usize];
match expression_arena[other_handle] {
crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
_ => {
return Err(FunctionError::Expression {
handle,
error: ExpressionError::ExpectedGlobalVariable,
})
}
}
}
};
match (arg_storage, other_storage) {
(GlobalOrArgument::Global(arg), GlobalOrArgument::Global(other)) => {
if arg_sampling.image {
self.sampling_set.insert(SamplingKey {
image: arg,
sampler: other,
});
} else {
self.sampling_set.insert(SamplingKey {
image: other,
sampler: arg,
});
}
}
(GlobalOrArgument::Argument(i), _) => {
let sampling =
self.argument_sampling[i as usize].get_or_insert_with(|| {
ArgumentSampling {
image: arg_sampling.image,
uses: Vec::with_capacity(1),
}
});
sampling.uses.push(other_storage)
}
(_, GlobalOrArgument::Argument(i)) => {
let sampling =
self.argument_sampling[i as usize].get_or_insert_with(|| {
ArgumentSampling {
image: !arg_sampling.image,
uses: Vec::with_capacity(1),
}
});
sampling.uses.push(arg_storage)
}
}
}
}
for (mine, other) in self.global_uses.iter_mut().zip(info.global_uses.iter()) {
*mine |= *other;
}
FunctionUniformity {
Ok(FunctionUniformity {
result: info.uniformity.clone(),
exit: if info.may_kill {
ExitFlags::MAY_KILL
} else {
ExitFlags::empty()
},
}
})
}
/// Computes the expression info and stores it in `self.expressions`.
@@ -397,16 +499,45 @@ impl FunctionInfo {
level,
depth_ref,
} => {
self.sampling_set.insert(SamplingKey {
image: match expression_arena[image] {
crate::Expression::GlobalVariable(var) => var,
_ => return Err(ExpressionError::ExpectedGlobalVariable),
},
sampler: match expression_arena[sampler] {
crate::Expression::GlobalVariable(var) => var,
_ => return Err(ExpressionError::ExpectedGlobalVariable),
},
});
let image_storage = match expression_arena[image] {
crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
_ => return Err(ExpressionError::ExpectedGlobalVariable),
};
let sampler_storage = match expression_arena[sampler] {
crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
_ => return Err(ExpressionError::ExpectedGlobalVariable),
};
match (image_storage, sampler_storage) {
(GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
self.sampling_set.insert(SamplingKey { image, sampler });
}
(GlobalOrArgument::Argument(i), _) => {
let sampling =
self.argument_sampling[i as usize].get_or_insert_with(|| {
ArgumentSampling {
image: true,
uses: Vec::with_capacity(1),
}
});
sampling.uses.push(sampler_storage)
}
(_, GlobalOrArgument::Argument(i)) => {
let sampling =
self.argument_sampling[i as usize].get_or_insert_with(|| {
ArgumentSampling {
image: false,
uses: Vec::with_capacity(1),
}
});
sampling.uses.push(image_storage)
}
}
// "nur" == "Non-Uniform Result"
let array_nur = array_index.and_then(|h| self.add_ref(h));
let level_nur = match level {
@@ -501,10 +632,11 @@ impl FunctionInfo {
requirements: UniformityRequirements::empty(),
},
E::Call(function) => {
let fun = other_functions
let info = other_functions
.get(function.index())
.ok_or(ExpressionError::CallToUndeclaredFunction(function))?;
self.process_call(fun).result
info.uniformity.clone()
}
E::ArrayLength(expr) => Uniformity {
non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
@@ -537,6 +669,7 @@ impl FunctionInfo {
statements: &[crate::Statement],
other_functions: &[FunctionInfo],
mut disruptor: Option<UniformityDisruptor>,
expression_arena: &Arena<crate::Expression>,
) -> Result<FunctionUniformity, FunctionError> {
use crate::Statement as S;
@@ -582,7 +715,9 @@ impl FunctionInfo {
},
exit: ExitFlags::empty(),
},
S::Block(ref b) => self.process_block(b, other_functions, disruptor)?,
S::Block(ref b) => {
self.process_block(b, other_functions, disruptor, expression_arena)?
}
S::If {
condition,
ref accept,
@@ -591,10 +726,18 @@ impl FunctionInfo {
let condition_nur = self.add_ref(condition);
let branch_disruptor =
disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
let accept_uniformity =
self.process_block(accept, other_functions, branch_disruptor)?;
let reject_uniformity =
self.process_block(reject, other_functions, branch_disruptor)?;
let accept_uniformity = self.process_block(
accept,
other_functions,
branch_disruptor,
expression_arena,
)?;
let reject_uniformity = self.process_block(
reject,
other_functions,
branch_disruptor,
expression_arena,
)?;
accept_uniformity | reject_uniformity
}
S::Switch {
@@ -608,8 +751,12 @@ impl FunctionInfo {
let mut uniformity = FunctionUniformity::new();
let mut case_disruptor = branch_disruptor;
for case in cases.iter() {
let case_uniformity =
self.process_block(&case.body, other_functions, case_disruptor)?;
let case_uniformity = self.process_block(
&case.body,
other_functions,
case_disruptor,
expression_arena,
)?;
case_disruptor = if case.fall_through {
case_disruptor.or(case_uniformity.exit_disruptor())
} else {
@@ -618,18 +765,27 @@ impl FunctionInfo {
uniformity = uniformity | case_uniformity;
}
// using the disruptor inherited from the last fall-through chain
let default_exit =
self.process_block(default, other_functions, case_disruptor)?;
let default_exit = self.process_block(
default,
other_functions,
case_disruptor,
expression_arena,
)?;
uniformity | default_exit
}
S::Loop {
ref body,
ref continuing,
} => {
let body_uniformity = self.process_block(body, other_functions, disruptor)?;
let body_uniformity =
self.process_block(body, other_functions, disruptor, expression_arena)?;
let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
let continuing_uniformity =
self.process_block(continuing, other_functions, continuing_disruptor)?;
let continuing_uniformity = self.process_block(
continuing,
other_functions,
continuing_disruptor,
expression_arena,
)?;
body_uniformity | continuing_uniformity
}
S::Return { value } => FunctionUniformity {
@@ -680,7 +836,7 @@ impl FunctionInfo {
},
)?;
//Note: the result is validated by the Validator, not here
self.process_call(info)
self.process_call(info, arguments, expression_arena)?
}
};
@@ -708,6 +864,7 @@ impl ModuleInfo {
sampling_set: crate::FastHashSet::default(),
global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
argument_sampling: vec![None; fun.arguments.len()].into_boxed_slice(),
};
let resolve_context = ResolveContext {
constants: &module.constants,
@@ -730,7 +887,7 @@ impl ModuleInfo {
}
}
let uniformity = info.process_block(&fun.body, &self.functions, None)?;
let uniformity = info.process_block(&fun.body, &self.functions, None, &fun.expressions)?;
info.uniformity = uniformity.result;
info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
@@ -812,6 +969,7 @@ fn uniform_control_flow() {
sampling_set: crate::FastHashSet::default(),
global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
argument_sampling: vec![None; 0].into_boxed_slice(),
};
let resolve_context = ResolveContext {
constants: &constant_arena,
@@ -845,7 +1003,7 @@ fn uniform_control_flow() {
],
};
assert_eq!(
info.process_block(&[stmt_emit1, stmt_if_uniform], &[], None),
info.process_block(&[stmt_emit1, stmt_if_uniform], &[], None, &expressions),
Ok(FunctionUniformity {
result: Uniformity {
non_uniform_result: None,
@@ -870,7 +1028,7 @@ fn uniform_control_flow() {
reject: Vec::new(),
};
assert_eq!(
info.process_block(&[stmt_emit2, stmt_if_non_uniform], &[], None),
info.process_block(&[stmt_emit2, stmt_if_non_uniform], &[], None, &expressions),
Err(FunctionError::NonUniformControlFlow(
UniformityRequirements::DERIVATIVE,
derivative_expr,
@@ -888,7 +1046,8 @@ fn uniform_control_flow() {
info.process_block(
&[stmt_emit3, stmt_return_non_uniform],
&[],
Some(UniformityDisruptor::Return)
Some(UniformityDisruptor::Return),
&expressions
),
Ok(FunctionUniformity {
result: Uniformity {
@@ -914,7 +1073,8 @@ fn uniform_control_flow() {
info.process_block(
&[stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer],
&[],
Some(UniformityDisruptor::Discard)
Some(UniformityDisruptor::Discard),
&expressions
),
Ok(FunctionUniformity {
result: Uniformity {

View File

@@ -339,24 +339,26 @@ impl super::Validator {
depth_ref,
} => {
// check the validity of expressions
let image_var = match function.expressions[image] {
let image_ty = match function.expressions[image] {
crate::Expression::GlobalVariable(var_handle) => {
&module.global_variables[var_handle]
module.global_variables[var_handle].ty
}
crate::Expression::FunctionArgument(i) => function.arguments[i as usize].ty,
_ => return Err(ExpressionError::ExpectedGlobalVariable),
};
let sampler_var = match function.expressions[sampler] {
let sampler_ty = match function.expressions[sampler] {
crate::Expression::GlobalVariable(var_handle) => {
&module.global_variables[var_handle]
module.global_variables[var_handle].ty
}
crate::Expression::FunctionArgument(i) => function.arguments[i as usize].ty,
_ => return Err(ExpressionError::ExpectedGlobalVariable),
};
let comparison = match module.types[sampler_var.ty].inner {
let comparison = match module.types[sampler_ty].inner {
Ti::Sampler { comparison } => comparison,
_ => return Err(ExpressionError::ExpectedSamplerType(sampler_var.ty)),
_ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)),
};
let (class, dim) = match module.types[image_var.ty].inner {
let (class, dim) = match module.types[image_ty].inner {
Ti::Image {
class,
arrayed,
@@ -377,7 +379,7 @@ impl super::Validator {
}
(class, dim)
}
_ => return Err(ExpressionError::ExpectedImageType(image_var.ty)),
_ => return Err(ExpressionError::ExpectedImageType(image_ty)),
};
// check sampling and comparison properties
@@ -515,13 +517,14 @@ impl super::Validator {
array_index,
index,
} => {
let var = match function.expressions[image] {
let ty = match function.expressions[image] {
crate::Expression::GlobalVariable(var_handle) => {
&module.global_variables[var_handle]
module.global_variables[var_handle].ty
}
crate::Expression::FunctionArgument(i) => function.arguments[i as usize].ty,
_ => return Err(ExpressionError::ExpectedGlobalVariable),
};
match module.types[var.ty].inner {
match module.types[ty].inner {
Ti::Image {
class,
arrayed,
@@ -564,18 +567,19 @@ impl super::Validator {
}
}
}
_ => return Err(ExpressionError::ExpectedImageType(var.ty)),
_ => return Err(ExpressionError::ExpectedImageType(ty)),
}
ShaderStages::all()
}
E::ImageQuery { image, query } => {
let var = match function.expressions[image] {
let ty = match function.expressions[image] {
crate::Expression::GlobalVariable(var_handle) => {
&module.global_variables[var_handle]
module.global_variables[var_handle].ty
}
crate::Expression::FunctionArgument(i) => function.arguments[i as usize].ty,
_ => return Err(ExpressionError::ExpectedGlobalVariable),
};
match module.types[var.ty].inner {
match module.types[ty].inner {
Ti::Image { class, arrayed, .. } => {
let can_level = match class {
crate::ImageClass::Sampled { multi, .. } => !multi,
@@ -593,7 +597,7 @@ impl super::Validator {
return Err(ExpressionError::InvalidImageClass(class));
}
}
_ => return Err(ExpressionError::ExpectedImageType(var.ty)),
_ => return Err(ExpressionError::ExpectedImageType(ty)),
}
ShaderStages::all()
}

View File

@@ -578,7 +578,7 @@ impl super::Validator {
for (index, argument) in fun.arguments.iter().enumerate() {
if !self.types[argument.ty.index()]
.flags
.contains(TypeFlags::DATA | TypeFlags::SIZED)
.contains(TypeFlags::ARGUMENT)
{
return Err(FunctionError::InvalidArgumentType {
index,

View File

@@ -46,6 +46,9 @@ bitflags::bitflags! {
/// This is a top-level host-shareable type.
const TOP_LEVEL = 0x10;
/// This type can be passed as a function argument.
const ARGUMENT = 0x20;
}
}
@@ -192,7 +195,8 @@ impl super::Validator {
TypeFlags::DATA
| TypeFlags::SIZED
| TypeFlags::INTERFACE
| TypeFlags::HOST_SHARED,
| TypeFlags::HOST_SHARED
| TypeFlags::ARGUMENT,
width as u32,
)
}
@@ -205,7 +209,8 @@ impl super::Validator {
TypeFlags::DATA
| TypeFlags::SIZED
| TypeFlags::INTERFACE
| TypeFlags::HOST_SHARED,
| TypeFlags::HOST_SHARED
| TypeFlags::ARGUMENT,
count * (width as u32),
)
}
@@ -222,7 +227,8 @@ impl super::Validator {
TypeFlags::DATA
| TypeFlags::SIZED
| TypeFlags::INTERFACE
| TypeFlags::HOST_SHARED,
| TypeFlags::HOST_SHARED
| TypeFlags::ARGUMENT,
count * (width as u32),
)
}
@@ -239,9 +245,9 @@ impl super::Validator {
// `DATA`.
let base_info = &self.types[base.index()];
let data_flag = if base_info.flags.contains(TypeFlags::SIZED) {
TypeFlags::DATA
TypeFlags::DATA | TypeFlags::ARGUMENT
} else if let crate::TypeInner::Struct { .. } = types[base].inner {
TypeFlags::DATA
TypeFlags::DATA | TypeFlags::ARGUMENT
} else {
TypeFlags::empty()
};
@@ -350,7 +356,7 @@ impl super::Validator {
return Err(TypeError::NonPositiveArrayLength(const_handle));
}
TypeFlags::SIZED
TypeFlags::SIZED | TypeFlags::ARGUMENT
}
crate::ArraySize::Dynamic => {
// Non-SIZED types may only appear as the last element of a structure.
@@ -376,7 +382,8 @@ impl super::Validator {
TypeFlags::DATA
| TypeFlags::SIZED
| TypeFlags::HOST_SHARED
| TypeFlags::INTERFACE,
| TypeFlags::INTERFACE
| TypeFlags::ARGUMENT,
1,
);
let mut min_offset = 0;
@@ -460,7 +467,7 @@ impl super::Validator {
ti
}
Ti::Image { .. } | Ti::Sampler { .. } => TypeInfo::new(TypeFlags::empty(), 0),
Ti::Image { .. } | Ti::Sampler { .. } => TypeInfo::new(TypeFlags::ARGUMENT, 0),
})
}
}

View File

@@ -336,6 +336,9 @@
ty: Handle(1),
),
],
argument_sampling: [
None,
],
),
],
entry_points: [
@@ -492,6 +495,9 @@
ty: Handle(1),
),
],
argument_sampling: [
None,
],
),
],
)

View File

@@ -1060,6 +1060,10 @@
)),
),
],
argument_sampling: [
None,
None,
],
),
(
flags: (
@@ -2748,6 +2752,7 @@
ty: Handle(4),
),
],
argument_sampling: [],
),
],
entry_points: [
@@ -2871,6 +2876,10 @@
ty: Handle(4),
),
],
argument_sampling: [
None,
None,
],
),
],
)