typifier: handle forward expression dependencies

This commit is contained in:
Dzmitry Malyshau
2022-01-13 18:27:43 -05:00
parent be7df0d212
commit 68b1ae1499
3 changed files with 109 additions and 98 deletions

View File

@@ -94,7 +94,8 @@ impl Typifier {
) -> Result<(), ResolveError> {
if self.resolutions.len() <= expr_handle.index() {
for (eh, expr) in expressions.iter().skip(self.resolutions.len()) {
let resolution = ctx.resolve(expr, |h| &self.resolutions[h.index()])?;
//Note: the closure can't `Err` by construction
let resolution = ctx.resolve(expr, |h| Ok(&self.resolutions[h.index()]))?;
log::debug!("Resolving {:?} = {:?} : {:?}", eh, expr, resolution);
self.resolutions.push(resolution);
}
@@ -116,7 +117,8 @@ impl Typifier {
self.grow(expr_handle, expressions, ctx)
} else {
let expr = &expressions[expr_handle];
let resolution = ctx.resolve(expr, |h| &self.resolutions[h.index()])?;
//Note: the closure can't `Err` by construction
let resolution = ctx.resolve(expr, |h| Ok(&self.resolutions[h.index()]))?;
self.resolutions[expr_handle.index()] = resolution;
Ok(())
}

View File

@@ -197,6 +197,8 @@ pub enum ResolveError {
GlobalVariableNotFound(Handle<crate::GlobalVariable>),
#[error("Function argument {0} doesn't exist")]
FunctionArgumentNotFound(u32),
#[error("Expression {0:?} depends on expressions that follow")]
ExpressionForwardDependency(Handle<crate::Expression>),
}
pub struct ResolveContext<'a> {
@@ -227,12 +229,12 @@ impl<'a> ResolveContext<'a> {
pub fn resolve(
&self,
expr: &crate::Expression,
past: impl Fn(Handle<crate::Expression>) -> &'a TypeResolution,
past: impl Fn(Handle<crate::Expression>) -> Result<&'a TypeResolution, ResolveError>,
) -> Result<TypeResolution, ResolveError> {
use crate::TypeInner as Ti;
let types = self.types;
Ok(match *expr {
crate::Expression::Access { base, .. } => match *past(base).inner_with(types) {
crate::Expression::Access { base, .. } => match *past(base)?.inner_with(types) {
// Arrays and matrices can only be indexed dynamically behind a
// pointer, but that's a validation error, not a type error, so
// go ahead provide a type here.
@@ -299,106 +301,108 @@ impl<'a> ResolveContext<'a> {
});
}
},
crate::Expression::AccessIndex { base, index } => match *past(base).inner_with(types) {
Ti::Vector { size, kind, width } => {
if index >= size as u32 {
return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
}
TypeResolution::Value(Ti::Scalar { kind, width })
}
Ti::Matrix {
columns,
rows,
width,
} => {
if index >= columns as u32 {
return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
}
TypeResolution::Value(crate::TypeInner::Vector {
size: rows,
kind: crate::ScalarKind::Float,
width,
})
}
Ti::Array { base, .. } => TypeResolution::Handle(base),
Ti::Struct { ref members, .. } => {
let member = members
.get(index as usize)
.ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
TypeResolution::Handle(member.ty)
}
Ti::ValuePointer {
size: Some(size),
kind,
width,
class,
} => {
if index >= size as u32 {
return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
}
TypeResolution::Value(Ti::ValuePointer {
size: None,
kind,
width,
class,
})
}
Ti::Pointer {
base: ty_base,
class,
} => TypeResolution::Value(match types[ty_base].inner {
Ti::Array { base, .. } => Ti::Pointer { base, class },
crate::Expression::AccessIndex { base, index } => {
match *past(base)?.inner_with(types) {
Ti::Vector { size, kind, width } => {
if index >= size as u32 {
return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
}
Ti::ValuePointer {
size: None,
kind,
width,
class,
}
TypeResolution::Value(Ti::Scalar { kind, width })
}
Ti::Matrix {
rows,
columns,
rows,
width,
} => {
if index >= columns as u32 {
return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
}
Ti::ValuePointer {
size: Some(rows),
TypeResolution::Value(crate::TypeInner::Vector {
size: rows,
kind: crate::ScalarKind::Float,
width,
class,
}
})
}
Ti::Array { base, .. } => TypeResolution::Handle(base),
Ti::Struct { ref members, .. } => {
let member = members
.get(index as usize)
.ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
Ti::Pointer {
base: member.ty,
class,
}
TypeResolution::Handle(member.ty)
}
Ti::ValuePointer {
size: Some(size),
kind,
width,
class,
} => {
if index >= size as u32 {
return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
}
TypeResolution::Value(Ti::ValuePointer {
size: None,
kind,
width,
class,
})
}
Ti::Pointer {
base: ty_base,
class,
} => TypeResolution::Value(match types[ty_base].inner {
Ti::Array { base, .. } => Ti::Pointer { base, class },
Ti::Vector { size, kind, width } => {
if index >= size as u32 {
return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
}
Ti::ValuePointer {
size: None,
kind,
width,
class,
}
}
Ti::Matrix {
rows,
columns,
width,
} => {
if index >= columns as u32 {
return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
}
Ti::ValuePointer {
size: Some(rows),
kind: crate::ScalarKind::Float,
width,
class,
}
}
Ti::Struct { ref members, .. } => {
let member = members
.get(index as usize)
.ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
Ti::Pointer {
base: member.ty,
class,
}
}
ref other => {
log::error!("Access index sub-type {:?}", other);
return Err(ResolveError::InvalidSubAccess {
ty: ty_base,
indexed: true,
});
}
}),
ref other => {
log::error!("Access index sub-type {:?}", other);
return Err(ResolveError::InvalidSubAccess {
ty: ty_base,
log::error!("Access index type {:?}", other);
return Err(ResolveError::InvalidAccess {
expr: base,
indexed: true,
});
}
}),
ref other => {
log::error!("Access index type {:?}", other);
return Err(ResolveError::InvalidAccess {
expr: base,
indexed: true,
});
}
},
}
crate::Expression::Constant(h) => match self.constants[h].inner {
crate::ConstantInner::Scalar { width, ref value } => {
TypeResolution::Value(Ti::Scalar {
@@ -408,7 +412,7 @@ impl<'a> ResolveContext<'a> {
}
crate::ConstantInner::Composite { ty, components: _ } => TypeResolution::Handle(ty),
},
crate::Expression::Splat { size, value } => match *past(value).inner_with(types) {
crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) {
Ti::Scalar { kind, width } => {
TypeResolution::Value(Ti::Vector { size, kind, width })
}
@@ -421,7 +425,7 @@ impl<'a> ResolveContext<'a> {
size,
vector,
pattern: _,
} => match *past(vector).inner_with(types) {
} => match *past(vector)?.inner_with(types) {
Ti::Vector {
size: _,
kind,
@@ -464,7 +468,7 @@ impl<'a> ResolveContext<'a> {
class: crate::StorageClass::Function,
})
}
crate::Expression::Load { pointer } => match *past(pointer).inner_with(types) {
crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) {
Ti::Pointer { base, class: _ } => {
if let Ti::Atomic { kind, width } = types[base].inner {
TypeResolution::Value(Ti::Scalar { kind, width })
@@ -490,7 +494,7 @@ impl<'a> ResolveContext<'a> {
image,
gather: Some(_),
..
} => match *past(image).inner_with(types) {
} => match *past(image)?.inner_with(types) {
Ti::Image { class, .. } => TypeResolution::Value(Ti::Vector {
kind: match class {
crate::ImageClass::Sampled { kind, multi: _ } => kind,
@@ -505,7 +509,7 @@ impl<'a> ResolveContext<'a> {
}
},
crate::Expression::ImageSample { image, .. }
| crate::Expression::ImageLoad { image, .. } => match *past(image).inner_with(types) {
| crate::Expression::ImageLoad { image, .. } => match *past(image)?.inner_with(types) {
Ti::Image { class, .. } => TypeResolution::Value(match class {
crate::ImageClass::Depth { multi: _ } => Ti::Scalar {
kind: crate::ScalarKind::Float,
@@ -528,7 +532,7 @@ impl<'a> ResolveContext<'a> {
}
},
crate::Expression::ImageQuery { image, query } => TypeResolution::Value(match query {
crate::ImageQuery::Size { level: _ } => match *past(image).inner_with(types) {
crate::ImageQuery::Size { level: _ } => match *past(image)?.inner_with(types) {
Ti::Image { dim, .. } => match dim {
crate::ImageDimension::D1 => Ti::Scalar {
kind: crate::ScalarKind::Sint,
@@ -557,14 +561,14 @@ impl<'a> ResolveContext<'a> {
width: 4,
},
}),
crate::Expression::Unary { expr, .. } => past(expr).clone(),
crate::Expression::Unary { expr, .. } => past(expr)?.clone(),
crate::Expression::Binary { op, left, right } => match op {
crate::BinaryOperator::Add
| crate::BinaryOperator::Subtract
| crate::BinaryOperator::Divide
| crate::BinaryOperator::Modulo => past(left).clone(),
| crate::BinaryOperator::Modulo => past(left)?.clone(),
crate::BinaryOperator::Multiply => {
let (res_left, res_right) = (past(left), past(right));
let (res_left, res_right) = (past(left)?, past(right)?);
match (res_left.inner_with(types), res_right.inner_with(types)) {
(
&Ti::Matrix {
@@ -623,7 +627,7 @@ impl<'a> ResolveContext<'a> {
| crate::BinaryOperator::LogicalOr => {
let kind = crate::ScalarKind::Bool;
let width = crate::BOOL_WIDTH;
let inner = match *past(left).inner_with(types) {
let inner = match *past(left)?.inner_with(types) {
Ti::Scalar { .. } => Ti::Scalar { kind, width },
Ti::Vector { size, .. } => Ti::Vector { size, kind, width },
ref other => {
@@ -639,7 +643,7 @@ impl<'a> ResolveContext<'a> {
| crate::BinaryOperator::ExclusiveOr
| crate::BinaryOperator::InclusiveOr
| crate::BinaryOperator::ShiftLeft
| crate::BinaryOperator::ShiftRight => past(left).clone(),
| crate::BinaryOperator::ShiftRight => past(left)?.clone(),
},
crate::Expression::AtomicResult {
kind,
@@ -656,8 +660,8 @@ impl<'a> ResolveContext<'a> {
TypeResolution::Value(Ti::Scalar { kind, width })
}
}
crate::Expression::Select { accept, .. } => past(accept).clone(),
crate::Expression::Derivative { axis: _, expr } => past(expr).clone(),
crate::Expression::Select { accept, .. } => past(accept)?.clone(),
crate::Expression::Derivative { axis: _, expr } => past(expr)?.clone(),
crate::Expression::Relational { fun, argument } => match fun {
crate::RelationalFunction::All | crate::RelationalFunction::Any => {
TypeResolution::Value(Ti::Scalar {
@@ -668,7 +672,7 @@ impl<'a> ResolveContext<'a> {
crate::RelationalFunction::IsNan
| crate::RelationalFunction::IsInf
| crate::RelationalFunction::IsFinite
| crate::RelationalFunction::IsNormal => match *past(argument).inner_with(types) {
| crate::RelationalFunction::IsNormal => match *past(argument)?.inner_with(types) {
Ti::Scalar { .. } => TypeResolution::Value(Ti::Scalar {
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
@@ -694,7 +698,7 @@ impl<'a> ResolveContext<'a> {
arg3: _,
} => {
use crate::MathFunction as Mf;
let res_arg = past(arg);
let res_arg = past(arg)?;
match fun {
// comparison
Mf::Abs |
@@ -748,7 +752,7 @@ impl<'a> ResolveContext<'a> {
let arg1 = arg1.ok_or_else(|| ResolveError::IncompatibleOperands(
format!("{:?}(_, None)", fun)
))?;
match (res_arg.inner_with(types), past(arg1).inner_with(types)) {
match (res_arg.inner_with(types), past(arg1)?.inner_with(types)) {
(&Ti::Vector {kind: _, size: columns,width}, &Ti::Vector{ size: rows, .. }) => TypeResolution::Value(Ti::Matrix { columns, rows, width }),
(left, right) =>
return Err(ResolveError::IncompatibleOperands(
@@ -847,7 +851,7 @@ impl<'a> ResolveContext<'a> {
expr,
kind,
convert,
} => match *past(expr).inner_with(types) {
} => match *past(expr)?.inner_with(types) {
Ti::Scalar { kind: _, width } => TypeResolution::Value(Ti::Scalar {
kind,
width: convert.unwrap_or(width),

View File

@@ -10,7 +10,7 @@ use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ShaderStages,
use crate::span::{AddSpan as _, WithSpan};
use crate::{
arena::{Arena, Handle},
proc::{ResolveContext, TypeResolution},
proc::{ResolveContext, ResolveError, TypeResolution},
};
use std::ops;
@@ -598,7 +598,12 @@ impl FunctionInfo {
},
};
let ty = resolve_context.resolve(expression, |h| &self.expressions[h.index()].ty)?;
let ty = resolve_context.resolve(expression, |h| {
self.expressions
.get(h.index())
.map(|ei| &ei.ty)
.ok_or(ResolveError::ExpressionForwardDependency(h))
})?;
self.expressions[handle.index()] = ExpressionInfo {
uniformity,
ref_count: 0,