Validate that used expressions are emitted

This commit is contained in:
Dzmitry Malyshau
2021-04-25 18:51:20 -04:00
parent 452db33947
commit 4a5ff9a053

View File

@@ -3,6 +3,7 @@ use super::{
ExpressionError, FunctionInfo, ModuleInfo, TypeFlags, ValidationFlags,
};
use crate::arena::{Arena, Handle};
use bit_set::BitSet;
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
@@ -160,21 +161,39 @@ impl<'a> BlockContext<'a> {
fn resolve_type_impl(
&self,
handle: Handle<crate::Expression>,
valid_expressions: &BitSet,
) -> Result<&crate::TypeInner, ExpressionError> {
if handle.index() < self.expressions.len() {
Ok(self.info[handle].ty.inner_with(self.types))
} else {
if handle.index() >= self.expressions.len() {
Err(ExpressionError::DoesntExist)
} else if !valid_expressions.contains(handle.index()) {
Err(ExpressionError::NotInScope)
} else {
Ok(self.info[handle].ty.inner_with(self.types))
}
}
fn resolve_type(
&self,
handle: Handle<crate::Expression>,
valid_expressions: &BitSet,
) -> Result<&crate::TypeInner, FunctionError> {
self.resolve_type_impl(handle)
self.resolve_type_impl(handle, valid_expressions)
.map_err(|error| FunctionError::Expression { handle, error })
}
fn resolve_pointer_type(
&self,
handle: Handle<crate::Expression>,
) -> Result<&crate::TypeInner, FunctionError> {
if handle.index() >= self.expressions.len() {
Err(FunctionError::Expression {
handle,
error: ExpressionError::DoesntExist,
})
} else {
Ok(self.info[handle].ty.inner_with(self.types))
}
}
}
impl super::Validator {
@@ -197,7 +216,7 @@ impl super::Validator {
}
for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() {
let ty = context
.resolve_type_impl(expr)
.resolve_type_impl(expr, &self.valid_expression_set)
.map_err(|error| CallError::Argument { index, error })?;
if ty != &context.types[arg.ty].inner {
return Err(CallError::ArgumentType {
@@ -252,7 +271,7 @@ impl super::Validator {
ref accept,
ref reject,
} => {
match *context.resolve_type(condition)? {
match *context.resolve_type(condition, &self.valid_expression_set)? {
Ti::Scalar {
kind: crate::ScalarKind::Bool,
width: _,
@@ -267,7 +286,7 @@ impl super::Validator {
ref cases,
ref default,
} => {
match *context.resolve_type(selector)? {
match *context.resolve_type(selector, &self.valid_expression_set)? {
Ti::Scalar {
kind: crate::ScalarKind::Sint,
width: _,
@@ -311,7 +330,9 @@ impl super::Validator {
if !context.flags.contains(Flags::CAN_JUMP) {
return Err(FunctionError::InvalidReturnSpot);
}
let value_ty = value.map(|expr| context.resolve_type(expr)).transpose()?;
let value_ty = value
.map(|expr| context.resolve_type(expr, &self.valid_expression_set))
.transpose()?;
let expected_ty = context.return_type.map(|ty| &context.types[ty].inner);
if value_ty != expected_ty {
log::error!(
@@ -329,7 +350,7 @@ impl super::Validator {
S::Store { pointer, value } => {
let mut current = pointer;
loop {
let _ = context.resolve_type(current)?;
let _ = context.resolve_pointer_type(current)?;
match *context.get_expression(current)? {
crate::Expression::Access { base, .. }
| crate::Expression::AccessIndex { base, .. } => current = base,
@@ -340,14 +361,14 @@ impl super::Validator {
}
}
let value_ty = context.resolve_type(value)?;
let value_ty = context.resolve_type(value, &self.valid_expression_set)?;
match *value_ty {
Ti::Image { .. } | Ti::Sampler { .. } => {
return Err(FunctionError::InvalidStoreValue(value));
}
_ => {}
}
let good = match *context.resolve_type(pointer)? {
let good = match *context.resolve_pointer_type(pointer)? {
Ti::Pointer { base, class: _ } => *value_ty == context.types[base].inner,
Ti::ValuePointer {
size: Some(size),
@@ -393,7 +414,7 @@ impl super::Validator {
dim,
} => {
match context
.resolve_type(coordinate)?
.resolve_type(coordinate, &self.valid_expression_set)?
.image_storage_coordinates()
{
Some(coord_dim) if coord_dim == dim => {}
@@ -411,7 +432,7 @@ impl super::Validator {
));
}
if let Some(expr) = array_index {
match *context.resolve_type(expr)? {
match *context.resolve_type(expr, &self.valid_expression_set)? {
Ti::Scalar {
kind: crate::ScalarKind::Sint,
width: _,
@@ -443,7 +464,7 @@ impl super::Validator {
}
};
if *context.resolve_type(value)? != value_ty {
if *context.resolve_type(value, &self.valid_expression_set)? != value_ty {
return Err(FunctionError::InvalidStoreValue(value));
}
}