Make type a part of ExpressionInfo

This commit is contained in:
Dzmitry Malyshau
2021-03-22 11:33:33 -04:00
parent 1d3f2bbdb1
commit 970b77abaf
9 changed files with 1043 additions and 261 deletions

View File

@@ -8,7 +8,7 @@ mod typifier;
pub use layouter::{Alignment, Layouter};
pub use namer::{EntryPointIndex, NameKey, Namer};
pub use terminator::ensure_block_returns;
pub use typifier::{ResolveContext, ResolveError, Typifier, TypifyError};
pub use typifier::{ResolveContext, ResolveError, TypeResolution, Typifier, TypifyError};
impl From<super::StorageFormat> for super::ScalarKind {
fn from(format: super::StorageFormat) -> Self {

View File

@@ -3,18 +3,36 @@ use crate::arena::{Arena, Handle};
use thiserror::Error;
#[derive(Debug, PartialEq)]
enum Resolution {
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum TypeResolution {
Handle(Handle<crate::Type>),
Value(crate::TypeInner),
}
impl TypeResolution {
pub fn handle(&self) -> Option<Handle<crate::Type>> {
match *self {
Self::Handle(handle) => Some(handle),
Self::Value(_) => None,
}
}
pub fn inner_with<'a>(&'a self, arena: &'a Arena<crate::Type>) -> &'a crate::TypeInner {
match *self {
Self::Handle(handle) => &arena[handle].inner,
Self::Value(ref inner) => inner,
}
}
}
// Clone is only implemented for numeric variants of `TypeInner`.
impl Clone for Resolution {
impl Clone for TypeResolution {
fn clone(&self) -> Self {
use crate::TypeInner as Ti;
match *self {
Resolution::Handle(handle) => Resolution::Handle(handle),
Resolution::Value(ref v) => Resolution::Value(match *v {
Self::Handle(handle) => Self::Handle(handle),
Self::Value(ref v) => Self::Value(match *v {
Ti::Scalar { kind, width } => Ti::Scalar { kind, width },
Ti::Vector { size, kind, width } => Ti::Vector { size, kind, width },
Ti::Matrix {
@@ -47,7 +65,7 @@ impl Clone for Resolution {
/// Helper processor that derives the types of all expressions.
#[derive(Debug)]
pub struct Typifier {
resolutions: Vec<Resolution>,
resolutions: Vec<TypeResolution>,
}
#[derive(Clone, Debug, Error, PartialEq)]
@@ -81,6 +99,7 @@ pub enum ResolveError {
IncompatibleOperands(String),
}
//TODO: remove this
#[repr(C)] // pack this tighter: 48 -> 40 bytes
#[derive(Clone, Debug, Error, PartialEq)]
#[error("Type resolution of {0:?} failed")]
@@ -94,77 +113,34 @@ pub struct ResolveContext<'a> {
pub arguments: &'a [crate::FunctionArgument],
}
impl Typifier {
pub fn new() -> Self {
Typifier {
resolutions: Vec::new(),
}
}
pub fn clear(&mut self) {
self.resolutions.clear()
}
pub fn get<'a>(
&'a self,
expr_handle: Handle<crate::Expression>,
types: &'a Arena<crate::Type>,
) -> &'a crate::TypeInner {
match self.resolutions[expr_handle.index()] {
Resolution::Handle(ty_handle) => &types[ty_handle].inner,
Resolution::Value(ref inner) => inner,
}
}
pub fn try_get<'a>(
&'a self,
expr_handle: Handle<crate::Expression>,
types: &'a Arena<crate::Type>,
) -> Option<&'a crate::TypeInner> {
let resolution = self.resolutions.get(expr_handle.index())?;
Some(match *resolution {
Resolution::Handle(ty_handle) => &types[ty_handle].inner,
Resolution::Value(ref inner) => inner,
})
}
pub fn get_handle(
&self,
expr_handle: Handle<crate::Expression>,
) -> Result<Handle<crate::Type>, &crate::TypeInner> {
match self.resolutions[expr_handle.index()] {
Resolution::Handle(ty_handle) => Ok(ty_handle),
Resolution::Value(ref inner) => Err(inner),
}
}
fn resolve_impl(
&self,
impl TypeResolution {
pub fn new<'a>(
expr: &crate::Expression,
types: &Arena<crate::Type>,
types: &'a Arena<crate::Type>,
ctx: &ResolveContext,
) -> Result<Resolution, ResolveError> {
past: impl Fn(Handle<crate::Expression>) -> &'a Self,
) -> Result<Self, ResolveError> {
use crate::TypeInner as Ti;
Ok(match *expr {
crate::Expression::Access { base, .. } => match *self.get(base, types) {
Ti::Array { base, .. } => Resolution::Handle(base),
crate::Expression::Access { base, .. } => match *past(base).inner_with(types) {
Ti::Array { base, .. } => Self::Handle(base),
Ti::Vector {
size: _,
kind,
width,
} => Resolution::Value(Ti::Scalar { kind, width }),
} => Self::Value(Ti::Scalar { kind, width }),
Ti::ValuePointer {
size: Some(_),
kind,
width,
class,
} => Resolution::Value(Ti::ValuePointer {
} => Self::Value(Ti::ValuePointer {
size: None,
kind,
width,
class,
}),
Ti::Pointer { base, class } => Resolution::Value(match types[base].inner {
Ti::Pointer { base, class } => Self::Value(match types[base].inner {
Ti::Array { base, .. } => Ti::Pointer { base, class },
Ti::Vector {
size: _,
@@ -192,12 +168,12 @@ impl Typifier {
});
}
},
crate::Expression::AccessIndex { base, index } => match *self.get(base, types) {
crate::Expression::AccessIndex { base, index } => match *past(base).inner_with(types) {
Ti::Vector { size, kind, width } => {
if index >= size as u32 {
return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
}
Resolution::Value(Ti::Scalar { kind, width })
Self::Value(Ti::Scalar { kind, width })
}
Ti::Matrix {
columns,
@@ -207,13 +183,13 @@ impl Typifier {
if index >= columns as u32 {
return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
}
Resolution::Value(crate::TypeInner::Vector {
Self::Value(crate::TypeInner::Vector {
size: rows,
kind: crate::ScalarKind::Float,
width,
})
}
Ti::Array { base, .. } => Resolution::Handle(base),
Ti::Array { base, .. } => Self::Handle(base),
Ti::Struct {
block: _,
ref members,
@@ -221,7 +197,7 @@ impl Typifier {
let member = members
.get(index as usize)
.ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
Resolution::Handle(member.ty)
Self::Handle(member.ty)
}
Ti::ValuePointer {
size: Some(size),
@@ -232,7 +208,7 @@ impl Typifier {
if index >= size as u32 {
return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
}
Resolution::Value(Ti::ValuePointer {
Self::Value(Ti::ValuePointer {
size: None,
kind,
width,
@@ -242,7 +218,7 @@ impl Typifier {
Ti::Pointer {
base: ty_base,
class,
} => Resolution::Value(match types[ty_base].inner {
} => Self::Value(match types[ty_base].inner {
Ti::Array { base, .. } => Ti::Pointer { base, class },
Ti::Vector { size, kind, width } => {
if index >= size as u32 {
@@ -299,24 +275,22 @@ impl Typifier {
}
},
crate::Expression::Constant(h) => match ctx.constants[h].inner {
crate::ConstantInner::Scalar { width, ref value } => {
Resolution::Value(Ti::Scalar {
kind: value.scalar_kind(),
width,
})
}
crate::ConstantInner::Composite { ty, components: _ } => Resolution::Handle(ty),
crate::ConstantInner::Scalar { width, ref value } => Self::Value(Ti::Scalar {
kind: value.scalar_kind(),
width,
}),
crate::ConstantInner::Composite { ty, components: _ } => Self::Handle(ty),
},
crate::Expression::Compose { ty, .. } => Resolution::Handle(ty),
crate::Expression::Compose { ty, .. } => Self::Handle(ty),
crate::Expression::FunctionArgument(index) => {
Resolution::Handle(ctx.arguments[index as usize].ty)
Self::Handle(ctx.arguments[index as usize].ty)
}
crate::Expression::GlobalVariable(h) => {
let var = &ctx.global_vars[h];
if var.class == crate::StorageClass::Handle {
Resolution::Handle(var.ty)
Self::Handle(var.ty)
} else {
Resolution::Value(Ti::Pointer {
Self::Value(Ti::Pointer {
base: var.ty,
class: var.class,
})
@@ -324,19 +298,19 @@ impl Typifier {
}
crate::Expression::LocalVariable(h) => {
let var = &ctx.local_vars[h];
Resolution::Value(Ti::Pointer {
Self::Value(Ti::Pointer {
base: var.ty,
class: crate::StorageClass::Function,
})
}
crate::Expression::Load { pointer } => match *self.get(pointer, types) {
Ti::Pointer { base, class: _ } => Resolution::Handle(base),
crate::Expression::Load { pointer } => match *past(pointer).inner_with(types) {
Ti::Pointer { base, class: _ } => Self::Handle(base),
Ti::ValuePointer {
size,
kind,
width,
class: _,
} => Resolution::Value(match size {
} => Self::Value(match size {
Some(size) => Ti::Vector { size, kind, width },
None => Ti::Scalar { kind, width },
}),
@@ -346,8 +320,8 @@ impl Typifier {
}
},
crate::Expression::ImageSample { image, .. }
| crate::Expression::ImageLoad { image, .. } => match *self.get(image, types) {
Ti::Image { class, .. } => Resolution::Value(match class {
| crate::Expression::ImageLoad { image, .. } => match *past(image).inner_with(types) {
Ti::Image { class, .. } => Self::Value(match class {
crate::ImageClass::Depth => Ti::Scalar {
kind: crate::ScalarKind::Float,
width: 4,
@@ -368,8 +342,8 @@ impl Typifier {
return Err(ResolveError::InvalidImage(image));
}
},
crate::Expression::ImageQuery { image, query } => Resolution::Value(match query {
crate::ImageQuery::Size { level: _ } => match *self.get(image, types) {
crate::Expression::ImageQuery { image, query } => Self::Value(match query {
crate::ImageQuery::Size { level: _ } => match *past(image).inner_with(types) {
Ti::Image { dim, .. } => match dim {
crate::ImageDimension::D1 => Ti::Scalar {
kind: crate::ScalarKind::Sint,
@@ -398,28 +372,30 @@ impl Typifier {
width: 4,
},
}),
crate::Expression::Unary { expr, .. } => self.resolutions[expr.index()].clone(),
crate::Expression::Unary { expr, .. } => past(expr).clone(),
crate::Expression::Binary { op, left, right } => match op {
crate::BinaryOperator::Add
| crate::BinaryOperator::Subtract
| crate::BinaryOperator::Divide
| crate::BinaryOperator::Modulo => self.resolutions[left.index()].clone(),
| crate::BinaryOperator::Modulo => past(left).clone(),
crate::BinaryOperator::Multiply => {
let ty_left = self.get(left, types);
let ty_right = self.get(right, types);
let res_left = past(left);
let ty_left = res_left.inner_with(types);
let res_right = past(right);
let ty_right = res_right.inner_with(types);
if ty_left == ty_right {
self.resolutions[left.index()].clone()
res_left.clone()
} else if let Ti::Scalar { .. } = *ty_left {
self.resolutions[right.index()].clone()
res_right.clone()
} else if let Ti::Scalar { .. } = *ty_right {
self.resolutions[left.index()].clone()
res_left.clone()
} else if let Ti::Matrix {
columns: _,
rows,
width,
} = *ty_left
{
Resolution::Value(Ti::Vector {
Self::Value(Ti::Vector {
size: rows,
kind: crate::ScalarKind::Float,
width,
@@ -430,7 +406,7 @@ impl Typifier {
width,
} = *ty_right
{
Resolution::Value(Ti::Vector {
Self::Value(Ti::Vector {
size: columns,
kind: crate::ScalarKind::Float,
width,
@@ -452,7 +428,7 @@ impl Typifier {
| crate::BinaryOperator::LogicalOr => {
let kind = crate::ScalarKind::Bool;
let width = 1;
let inner = match *self.get(left, types) {
let inner = match *past(left).inner_with(types) {
Ti::Scalar { .. } => Ti::Scalar { kind, width },
Ti::Vector { size, .. } => Ti::Vector { size, kind, width },
ref other => {
@@ -462,19 +438,17 @@ impl Typifier {
)))
}
};
Resolution::Value(inner)
Self::Value(inner)
}
crate::BinaryOperator::And
| crate::BinaryOperator::ExclusiveOr
| crate::BinaryOperator::InclusiveOr
| crate::BinaryOperator::ShiftLeft
| crate::BinaryOperator::ShiftRight => self.resolutions[left.index()].clone(),
| crate::BinaryOperator::ShiftRight => past(left).clone(),
},
crate::Expression::Select { accept, .. } => self.resolutions[accept.index()].clone(),
crate::Expression::Derivative { axis: _, expr } => {
self.resolutions[expr.index()].clone()
}
crate::Expression::Relational { .. } => Resolution::Value(Ti::Scalar {
crate::Expression::Select { accept, .. } => past(accept).clone(),
crate::Expression::Derivative { axis: _, expr } => past(expr).clone(),
crate::Expression::Relational { .. } => Self::Value(Ti::Scalar {
kind: crate::ScalarKind::Bool,
width: 4,
}),
@@ -485,6 +459,7 @@ impl Typifier {
arg2: _,
} => {
use crate::MathFunction as Mf;
let res_arg = past(arg);
match fun {
// comparison
Mf::Abs |
@@ -516,14 +491,14 @@ impl Typifier {
Mf::Exp2 |
Mf::Log |
Mf::Log2 |
Mf::Pow => self.resolutions[arg.index()].clone(),
Mf::Pow => res_arg.clone(),
// geometry
Mf::Dot => match *self.get(arg, types) {
Mf::Dot => match *res_arg.inner_with(types) {
Ti::Vector {
kind,
size: _,
width,
} => Resolution::Value(Ti::Scalar { kind, width }),
} => Self::Value(Ti::Scalar { kind, width }),
ref other =>
return Err(ResolveError::IncompatibleOperands(
format!("{:?}({:?}, _)", fun, other)
@@ -533,26 +508,26 @@ impl Typifier {
let arg1 = arg1.ok_or_else(|| ResolveError::IncompatibleOperands(
format!("{:?}(_, None)", fun)
))?;
match (self.get(arg, types), self.get(arg1,types)) {
(&Ti::Vector {kind: _, size: columns,width}, &Ti::Vector{ size: rows, .. }) => Resolution::Value(Ti::Matrix { columns, rows, width }),
match (res_arg.inner_with(types), past(arg1).inner_with(types)) {
(&Ti::Vector {kind: _, size: columns,width}, &Ti::Vector{ size: rows, .. }) => Self::Value(Ti::Matrix { columns, rows, width }),
(left, right) =>
return Err(ResolveError::IncompatibleOperands(
format!("{:?}({:?}, {:?})", fun, left, right)
)),
}
},
Mf::Cross => self.resolutions[arg.index()].clone(),
Mf::Cross => res_arg.clone(),
Mf::Distance |
Mf::Length => match *self.get(arg, types) {
Mf::Length => match *res_arg.inner_with(types) {
Ti::Scalar {width,kind} |
Ti::Vector {width,kind,size:_} => Resolution::Value(Ti::Scalar { kind, width }),
Ti::Vector {width,kind,size:_} => Self::Value(Ti::Scalar { kind, width }),
ref other => return Err(ResolveError::IncompatibleOperands(
format!("{:?}({:?})", fun, other)
)),
},
Mf::Normalize |
Mf::FaceForward |
Mf::Reflect => self.resolutions[arg.index()].clone(),
Mf::Reflect => res_arg.clone(),
// computational
Mf::Sign |
Mf::Fma |
@@ -560,13 +535,13 @@ impl Typifier {
Mf::Step |
Mf::SmoothStep |
Mf::Sqrt |
Mf::InverseSqrt => self.resolutions[arg.index()].clone(),
Mf::Transpose => match *self.get(arg, types) {
Mf::InverseSqrt => res_arg.clone(),
Mf::Transpose => match *res_arg.inner_with(types) {
Ti::Matrix {
columns,
rows,
width,
} => Resolution::Value(Ti::Matrix {
} => Self::Value(Ti::Matrix {
columns: rows,
rows: columns,
width,
@@ -575,12 +550,12 @@ impl Typifier {
format!("{:?}({:?})", fun, other)
)),
},
Mf::Inverse => match *self.get(arg, types) {
Mf::Inverse => match *res_arg.inner_with(types) {
Ti::Matrix {
columns,
rows,
width,
} if columns == rows => Resolution::Value(Ti::Matrix {
} if columns == rows => Self::Value(Ti::Matrix {
columns,
rows,
width,
@@ -589,31 +564,31 @@ impl Typifier {
format!("{:?}({:?})", fun, other)
)),
},
Mf::Determinant => match *self.get(arg, types) {
Mf::Determinant => match *res_arg.inner_with(types) {
Ti::Matrix {
width,
..
} => Resolution::Value(Ti::Scalar { kind: crate::ScalarKind::Float, width }),
} => Self::Value(Ti::Scalar { kind: crate::ScalarKind::Float, width }),
ref other => return Err(ResolveError::IncompatibleOperands(
format!("{:?}({:?})", fun, other)
)),
},
// bits
Mf::CountOneBits |
Mf::ReverseBits => self.resolutions[arg.index()].clone(),
Mf::ReverseBits => res_arg.clone(),
}
}
crate::Expression::As {
expr,
kind,
convert: _,
} => match *self.get(expr, types) {
Ti::Scalar { kind: _, width } => Resolution::Value(Ti::Scalar { kind, width }),
} => match *past(expr).inner_with(types) {
Ti::Scalar { kind: _, width } => Self::Value(Ti::Scalar { kind, width }),
Ti::Vector {
kind: _,
size,
width,
} => Resolution::Value(Ti::Vector { kind, size, width }),
} => Self::Value(Ti::Vector { kind, size, width }),
ref other => {
return Err(ResolveError::IncompatibleOperands(format!(
"{:?} as {:?}",
@@ -626,14 +601,60 @@ impl Typifier {
.result
.as_ref()
.ok_or(ResolveError::FunctionReturnsVoid)?;
Resolution::Handle(result.ty)
Self::Handle(result.ty)
}
crate::Expression::ArrayLength(_) => Resolution::Value(Ti::Scalar {
crate::Expression::ArrayLength(_) => Self::Value(Ti::Scalar {
kind: crate::ScalarKind::Uint,
width: 4,
}),
})
}
}
impl Typifier {
pub fn new() -> Self {
Typifier {
resolutions: Vec::new(),
}
}
pub fn clear(&mut self) {
self.resolutions.clear()
}
//TODO: remove most of these
pub fn get<'a>(
&'a self,
expr_handle: Handle<crate::Expression>,
types: &'a Arena<crate::Type>,
) -> &'a crate::TypeInner {
match self.resolutions[expr_handle.index()] {
TypeResolution::Handle(ty_handle) => &types[ty_handle].inner,
TypeResolution::Value(ref inner) => inner,
}
}
pub fn try_get<'a>(
&'a self,
expr_handle: Handle<crate::Expression>,
types: &'a Arena<crate::Type>,
) -> Option<&'a crate::TypeInner> {
let resolution = self.resolutions.get(expr_handle.index())?;
Some(match *resolution {
TypeResolution::Handle(ty_handle) => &types[ty_handle].inner,
TypeResolution::Value(ref inner) => inner,
})
}
pub fn get_handle(
&self,
expr_handle: Handle<crate::Expression>,
) -> Result<Handle<crate::Type>, &crate::TypeInner> {
match self.resolutions[expr_handle.index()] {
TypeResolution::Handle(ty_handle) => Ok(ty_handle),
TypeResolution::Value(ref inner) => Err(inner),
}
}
pub fn grow(
&mut self,
@@ -644,7 +665,8 @@ impl Typifier {
) -> Result<(), ResolveError> {
if self.resolutions.len() <= expr_handle.index() {
for (eh, expr) in expressions.iter().skip(self.resolutions.len()) {
let resolution = self.resolve_impl(expr, types, ctx)?;
let resolution =
TypeResolution::new(expr, types, ctx, |h| &self.resolutions[h.index()])?;
log::debug!("Resolving {:?} = {:?} : {:?}", eh, expr, resolution);
self.resolutions.push(resolution);
}
@@ -660,9 +682,9 @@ impl Typifier {
) -> Result<(), TypifyError> {
self.clear();
for (handle, expr) in expressions.iter() {
let resolution = self
.resolve_impl(expr, types, ctx)
.map_err(|err| TypifyError(handle, err))?;
let resolution =
TypeResolution::new(expr, types, ctx, |h| &self.resolutions[h.index()])
.map_err(|err| TypifyError(handle, err))?;
self.resolutions.push(resolution);
}
Ok(())

View File

@@ -6,8 +6,11 @@ Figures out the following properties:
- expression reference counts
!*/
use super::{CallError, FunctionError, ModuleInfo, ValidationFlags};
use crate::arena::{Arena, Handle};
use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ValidationFlags};
use crate::{
arena::{Arena, Handle},
proc::{ResolveContext, TypeResolution},
};
use std::ops;
pub type NonUniformResult = Option<Handle<crate::Expression>>;
@@ -138,6 +141,7 @@ pub struct ExpressionInfo {
pub uniformity: Uniformity,
pub ref_count: usize,
assignable_global: Option<Handle<crate::GlobalVariable>>,
pub ty: TypeResolution,
}
impl ExpressionInfo {
@@ -146,6 +150,11 @@ impl ExpressionInfo {
uniformity: Uniformity::new(),
ref_count: 0,
assignable_global: None,
// this doesn't matter at this point, will be overwritten
ty: TypeResolution::Value(crate::TypeInner::Scalar {
kind: crate::ScalarKind::Bool,
width: 0,
}),
}
}
}
@@ -284,15 +293,16 @@ impl FunctionInfo {
fn process_expression(
&mut self,
handle: Handle<crate::Expression>,
expression: &crate::Expression,
expression_arena: &Arena<crate::Expression>,
arguments: &[crate::FunctionArgument],
global_var_arena: &Arena<crate::GlobalVariable>,
other_functions: &[FunctionInfo],
) -> Result<(), FunctionError> {
type_arena: &Arena<crate::Type>,
resolve_context: &ResolveContext,
) -> Result<(), ExpressionError> {
use crate::{Expression as E, SampleLevel as Sl};
let mut assignable_global = None;
let uniformity = match expression_arena[handle] {
let uniformity = match *expression {
E::Access { base, index } => Uniformity {
non_uniform_result: self
.add_assignable_ref(base, &mut assignable_global)
@@ -316,7 +326,7 @@ impl FunctionInfo {
}
// depends on the builtin or interpolation
E::FunctionArgument(index) => {
let arg = &arguments[index as usize];
let arg = &resolve_context.arguments[index as usize];
let uniform = match arg.binding {
Some(crate::Binding::BuiltIn(built_in)) => match built_in {
// per-polygon built-ins are uniform
@@ -339,7 +349,7 @@ impl FunctionInfo {
E::GlobalVariable(gh) => {
use crate::StorageClass as Sc;
assignable_global = Some(gh);
let var = &global_var_arena[gh];
let var = &resolve_context.global_vars[gh];
let uniform = match var.class {
// local data is non-uniform
Sc::Function | Sc::Private => false,
@@ -377,15 +387,11 @@ impl FunctionInfo {
self.sampling_set.insert(SamplingKey {
image: match expression_arena[image] {
crate::Expression::GlobalVariable(var) => var,
ref other => {
return Err(FunctionError::ExpectedGlobalVariable(other.clone()))
}
_ => return Err(ExpressionError::ExpectedGlobalVariable),
},
sampler: match expression_arena[sampler] {
crate::Expression::GlobalVariable(var) => var,
ref other => {
return Err(FunctionError::ExpectedGlobalVariable(other.clone()))
}
_ => return Err(ExpressionError::ExpectedGlobalVariable),
},
});
// "nur" == "Non-Uniform Result"
@@ -482,13 +488,9 @@ impl FunctionInfo {
requirements: UniformityRequirements::empty(),
},
E::Call(function) => {
let fun =
other_functions
.get(function.index())
.ok_or(FunctionError::InvalidCall {
function,
error: CallError::ForwardDeclaredFunction,
})?;
let fun = other_functions
.get(function.index())
.ok_or(ExpressionError::CallToUndeclaredFunction(function))?;
self.process_call(fun).result
}
E::ArrayLength(expr) => Uniformity {
@@ -497,10 +499,14 @@ impl FunctionInfo {
},
};
let ty = TypeResolution::new(expression, type_arena, resolve_context, |h| {
&self.expressions[h.index()].ty
})?;
self.expressions[handle.index()] = ExpressionInfo {
uniformity,
ref_count: 0,
assignable_global,
ty,
};
Ok(())
}
@@ -648,6 +654,7 @@ impl FunctionInfo {
error: CallError::ForwardDeclaredFunction,
},
)?;
//Note: the result is validated by the Validator, not here
self.process_call(info)
}
};
@@ -665,7 +672,7 @@ impl ModuleInfo {
pub(super) fn process_function(
&self,
fun: &crate::Function,
global_var_arena: &Arena<crate::GlobalVariable>,
module: &crate::Module,
flags: ValidationFlags,
) -> Result<FunctionInfo, FunctionError> {
let mut info = FunctionInfo {
@@ -673,18 +680,28 @@ impl ModuleInfo {
uniformity: Uniformity::new(),
may_kill: false,
sampling_set: crate::FastHashSet::default(),
global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
};
let resolve_context = ResolveContext {
constants: &module.constants,
global_vars: &module.global_variables,
local_vars: &fun.local_variables,
functions: &module.functions,
arguments: &fun.arguments,
};
for (handle, _) in fun.expressions.iter() {
info.process_expression(
for (handle, expr) in fun.expressions.iter() {
if let Err(error) = info.process_expression(
handle,
expr,
&fun.expressions,
&fun.arguments,
global_var_arena,
&self.functions,
)?;
&module.types,
&resolve_context,
) {
return Err(FunctionError::Expression { handle, error });
}
}
let uniformity = info.process_block(&fun.body, &self.functions, None)?;
@@ -722,7 +739,8 @@ fn uniform_control_flow() {
let mut type_arena = Arena::new();
let ty = type_arena.append(crate::Type {
name: None,
inner: crate::TypeInner::Scalar {
inner: crate::TypeInner::Vector {
size: crate::VectorSize::Bi,
kind: crate::ScalarKind::Float,
width: 4,
},
@@ -775,9 +793,23 @@ fn uniform_control_flow() {
global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
};
for (handle, _) in expressions.iter() {
info.process_expression(handle, &expressions, &[], &global_var_arena, &[])
.unwrap();
let resolve_context = ResolveContext {
constants: &constant_arena,
global_vars: &global_var_arena,
local_vars: &Arena::new(),
functions: &Arena::new(),
arguments: &[],
};
for (handle, expression) in expressions.iter() {
info.process_expression(
handle,
expression,
&expressions,
&[],
&type_arena,
&resolve_context,
)
.unwrap();
}
assert_eq!(info[non_uniform_global_expr].ref_count, 1);
assert_eq!(info[uniform_global_expr].ref_count, 1);

View File

@@ -1,6 +1,7 @@
use super::FunctionInfo;
use crate::{
arena::{Arena, Handle},
proc::Typifier,
proc::ResolveError,
};
#[derive(Clone, Debug, thiserror::Error)]
@@ -52,12 +53,18 @@ pub enum ExpressionError {
InvalidBooleanVector(Handle<crate::Expression>),
#[error("Relational argument {0:?} is not a float")]
InvalidFloatArgument(Handle<crate::Expression>),
#[error("Type resolution failed")]
Type(#[from] ResolveError),
#[error("Not a global variable")]
ExpectedGlobalVariable,
#[error("Calling an undeclared function {0:?}")]
CallToUndeclaredFunction(Handle<crate::Function>),
}
struct ExpressionTypeResolver<'a> {
root: Handle<crate::Expression>,
types: &'a Arena<crate::Type>,
typifier: &'a Typifier,
info: &'a FunctionInfo,
}
impl<'a> ExpressionTypeResolver<'a> {
@@ -66,7 +73,7 @@ impl<'a> ExpressionTypeResolver<'a> {
handle: Handle<crate::Expression>,
) -> Result<&'a crate::TypeInner, ExpressionError> {
if handle < self.root {
Ok(self.typifier.get(handle, self.types))
Ok(self.info[handle].ty.inner_with(self.types))
} else {
Err(ExpressionError::ForwardDependency(handle))
}
@@ -80,13 +87,14 @@ impl super::Validator {
expression: &crate::Expression,
function: &crate::Function,
module: &crate::Module,
info: &FunctionInfo,
) -> Result<(), ExpressionError> {
use crate::{Expression as E, ScalarKind as Sk, TypeInner as Ti};
let resolver = ExpressionTypeResolver {
root,
types: &module.types,
typifier: &self.typifier,
info,
};
match *expression {

View File

@@ -1,11 +1,8 @@
use super::{
analyzer::{FunctionInfo, UniformityDisruptor, UniformityRequirements},
ExpressionError, ModuleInfo, TypeFlags, ValidationFlags,
};
use crate::{
arena::{Arena, Handle},
proc::{ResolveContext, TypifyError},
analyzer::{UniformityDisruptor, UniformityRequirements},
ExpressionError, FunctionInfo, ModuleInfo, TypeFlags, ValidationFlags,
};
use crate::arena::{Arena, Handle};
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
@@ -32,11 +29,8 @@ pub enum CallError {
required: Handle<crate::Type>,
seen_expression: Handle<crate::Expression>,
},
#[error("Result value {seen_expression:?} does not match the type {required:?}")]
ResultType {
required: Option<Handle<crate::Type>>,
seen_expression: Option<Handle<crate::Expression>>,
},
#[error("The emitted expression doesn't match the call")]
ExpressionMismatch(Option<Handle<crate::Expression>>),
}
#[derive(Clone, Debug, thiserror::Error)]
@@ -49,8 +43,6 @@ pub enum LocalVariableError {
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum FunctionError {
#[error(transparent)]
Resolve(#[from] TypifyError),
#[error("Expression {handle:?} is invalid")]
Expression {
handle: Handle<crate::Expression>,
@@ -103,8 +95,6 @@ pub enum FunctionError {
#[source]
error: CallError,
},
#[error("Expression {0:?} is not a global variable!")]
ExpectedGlobalVariable(crate::Expression),
#[error(
"Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
)]
@@ -127,6 +117,7 @@ bitflags::bitflags! {
struct BlockContext<'a> {
flags: Flags,
info: &'a FunctionInfo,
expressions: &'a Arena<crate::Expression>,
types: &'a Arena<crate::Type>,
functions: &'a Arena<crate::Function>,
@@ -134,9 +125,10 @@ struct BlockContext<'a> {
}
impl<'a> BlockContext<'a> {
pub(super) fn new(fun: &'a crate::Function, module: &'a crate::Module) -> Self {
fn new(fun: &'a crate::Function, module: &'a crate::Module, info: &'a FunctionInfo) -> Self {
Self {
flags: Flags::CAN_JUMP,
info,
expressions: &fun.expressions,
types: &module.types,
functions: &module.functions,
@@ -147,6 +139,7 @@ impl<'a> BlockContext<'a> {
fn with_flags(&self, flags: Flags) -> Self {
BlockContext {
flags,
info: self.info,
expressions: self.expressions,
types: self.types,
functions: self.functions,
@@ -162,6 +155,25 @@ impl<'a> BlockContext<'a> {
.try_get(handle)
.ok_or(FunctionError::InvalidExpression(handle))
}
fn resolve_type_impl(
&self,
handle: Handle<crate::Expression>,
) -> Result<&crate::TypeInner, ExpressionError> {
if handle.index() < self.expressions.len() {
Ok(self.info[handle].ty.inner_with(self.types))
} else {
Err(ExpressionError::DoesntExist)
}
}
fn resolve_type(
&self,
handle: Handle<crate::Expression>,
) -> Result<&crate::TypeInner, FunctionError> {
self.resolve_type_impl(handle)
.map_err(|error| FunctionError::Expression { handle, error })
}
}
impl super::Validator {
@@ -183,8 +195,8 @@ impl super::Validator {
});
}
for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() {
let ty = self
.resolve_statement_type_impl(expr, context.types)
let ty = context
.resolve_type_impl(expr)
.map_err(|error| CallError::Argument { index, error })?;
if ty != &context.types[arg.ty].inner {
return Err(CallError::ArgumentType {
@@ -201,49 +213,17 @@ impl super::Validator {
} else {
return Err(CallError::ResultAlreadyInScope(expr));
}
match context.expressions[expr] {
crate::Expression::Call(callee) if fun.result.is_some() && callee == function => {}
_ => return Err(CallError::ExpressionMismatch(result)),
}
} else if fun.result.is_some() {
return Err(CallError::ExpressionMismatch(result));
}
let result_ty = result
.map(|expr| self.resolve_statement_type_impl(expr, context.types))
.transpose()
.map_err(CallError::ResultValue)?;
let expected_ty = fun.result.as_ref().map(|fr| &context.types[fr.ty].inner);
if result_ty != expected_ty {
log::error!(
"Called function returns {:?} where {:?} is expected",
result_ty,
expected_ty
);
return Err(CallError::ResultType {
required: fun.result.as_ref().map(|fr| fr.ty),
seen_expression: result,
});
}
Ok(())
}
fn resolve_statement_type_impl<'a>(
&'a self,
handle: Handle<crate::Expression>,
types: &'a Arena<crate::Type>,
) -> Result<&'a crate::TypeInner, ExpressionError> {
if !self.valid_expression_set.contains(handle.index()) {
return Err(ExpressionError::NotInScope);
}
self.typifier
.try_get(handle, types)
.ok_or(ExpressionError::DoesntExist)
}
fn resolve_statement_type<'a>(
&'a self,
handle: Handle<crate::Expression>,
types: &'a Arena<crate::Type>,
) -> Result<&'a crate::TypeInner, FunctionError> {
self.resolve_statement_type_impl(handle, types)
.map_err(|error| FunctionError::Expression { handle, error })
}
fn validate_block_impl(
&mut self,
statements: &[crate::Statement],
@@ -271,7 +251,7 @@ impl super::Validator {
ref accept,
ref reject,
} => {
match *self.resolve_statement_type(condition, context.types)? {
match *context.resolve_type(condition)? {
Ti::Scalar {
kind: crate::ScalarKind::Bool,
width: _,
@@ -286,7 +266,7 @@ impl super::Validator {
ref cases,
ref default,
} => {
match *self.resolve_statement_type(selector, context.types)? {
match *context.resolve_type(selector)? {
Ti::Scalar {
kind: crate::ScalarKind::Sint,
width: _,
@@ -330,9 +310,7 @@ impl super::Validator {
if !context.flags.contains(Flags::CAN_JUMP) {
return Err(FunctionError::InvalidReturnSpot);
}
let value_ty = value
.map(|expr| self.resolve_statement_type(expr, context.types))
.transpose()?;
let value_ty = value.map(|expr| context.resolve_type(expr)).transpose()?;
let expected_ty = context.return_type.map(|ty| &context.types[ty].inner);
if value_ty != expected_ty {
log::error!(
@@ -350,12 +328,7 @@ impl super::Validator {
S::Store { pointer, value } => {
let mut current = pointer;
loop {
self.typifier.try_get(current, context.types).ok_or(
FunctionError::Expression {
handle: current,
error: ExpressionError::DoesntExist,
},
)?;
let _ = context.resolve_type(current)?;
match context.expressions[current] {
crate::Expression::Access { base, .. }
| crate::Expression::AccessIndex { base, .. } => current = base,
@@ -366,29 +339,27 @@ impl super::Validator {
}
}
let value_ty = self.resolve_statement_type(value, context.types)?;
let value_ty = context.resolve_type(value)?;
match *value_ty {
Ti::Image { .. } | Ti::Sampler { .. } => {
return Err(FunctionError::InvalidStoreValue(value));
}
_ => {}
}
let good = match self.typifier.try_get(pointer, context.types) {
Some(&Ti::Pointer { base, class: _ }) => {
*value_ty == context.types[base].inner
}
Some(&Ti::ValuePointer {
let good = match *context.resolve_type(pointer)? {
Ti::Pointer { base, class: _ } => *value_ty == context.types[base].inner,
Ti::ValuePointer {
size: Some(size),
kind,
width,
class: _,
}) => *value_ty == Ti::Vector { size, kind, width },
Some(&Ti::ValuePointer {
} => *value_ty == Ti::Vector { size, kind, width },
Ti::ValuePointer {
size: None,
kind,
width,
class: _,
}) => *value_ty == Ti::Scalar { kind, width },
} => *value_ty == Ti::Scalar { kind, width },
_ => false,
};
if !good {
@@ -405,15 +376,14 @@ impl super::Validator {
crate::Expression::GlobalVariable(_var_handle) => (), //TODO
_ => return Err(FunctionError::InvalidImage(image)),
};
let value_ty = self.typifier.get(value, context.types);
match *value_ty {
match *context.resolve_type(value)? {
Ti::Scalar { .. } | Ti::Vector { .. } => {}
_ => {
return Err(FunctionError::InvalidStoreValue(value));
}
}
if let Some(expr) = array_index {
match *self.typifier.get(expr, context.types) {
match *context.resolve_type(expr)? {
Ti::Scalar {
kind: crate::ScalarKind::Sint,
width: _,
@@ -483,16 +453,7 @@ impl super::Validator {
module: &crate::Module,
mod_info: &ModuleInfo,
) -> Result<FunctionInfo, FunctionError> {
let resolve_ctx = ResolveContext {
constants: &module.constants,
global_vars: &module.global_variables,
local_vars: &fun.local_variables,
functions: &module.functions,
arguments: &fun.arguments,
};
self.typifier
.resolve_all(&fun.expressions, &module.types, &resolve_ctx)?;
let info = mod_info.process_function(fun, &module.global_variables, self.flags)?;
let info = mod_info.process_function(fun, module, self.flags)?;
for (var_handle, var) in fun.local_variables.iter() {
self.validate_local_var(var, &module.types, &module.constants)
@@ -521,14 +482,14 @@ impl super::Validator {
self.valid_expression_set.insert(handle.index());
}
if !self.flags.contains(ValidationFlags::EXPRESSIONS) {
if let Err(error) = self.validate_expression(handle, expr, fun, module) {
if let Err(error) = self.validate_expression(handle, expr, fun, module, &info) {
return Err(FunctionError::Expression { handle, error });
}
}
}
if self.flags.contains(ValidationFlags::BLOCKS) {
self.validate_block(&fun.body, &BlockContext::new(fun, module))?;
self.validate_block(&fun.body, &BlockContext::new(fun, module, &info))?;
}
Ok(info)
}

View File

@@ -6,7 +6,7 @@ mod r#type;
use crate::{
arena::{Arena, Handle},
proc::{Layouter, Typifier},
proc::Layouter,
FastHashSet,
};
use bit_set::BitSet;
@@ -40,9 +40,6 @@ pub struct ModuleInfo {
#[derive(Debug)]
pub struct Validator {
flags: ValidationFlags,
//Note: this is a bit tricky: some of the front-ends as well as backends
// already have to use the typifier, so the work here is redundant in a way.
typifier: Typifier,
types: Vec<r#type::TypeInfo>,
location_mask: BitSet,
bind_group_masks: Vec<BitSet>,
@@ -125,7 +122,6 @@ impl Validator {
pub fn new(flags: ValidationFlags) -> Self {
Validator {
flags,
typifier: Typifier::new(),
types: Vec::new(),
location_mask: BitSet::new(),
bind_group_masks: Vec::new(),

View File

@@ -103,7 +103,6 @@ impl super::Validator {
}
pub(super) fn reset_types(&mut self, size: usize) {
self.typifier.clear();
self.types.clear();
self.types.resize(size, TypeInfo::new());
}

View File

@@ -31,6 +31,10 @@ expression: output
),
ref_count: 0,
assignable_global: Some(1),
ty: Value(Pointer(
base: 3,
class: Storage,
)),
),
(
uniformity: (
@@ -41,6 +45,7 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
),
(
uniformity: (
@@ -51,6 +56,10 @@ expression: output
),
ref_count: 7,
assignable_global: None,
ty: Value(Pointer(
base: 1,
class: Function,
)),
),
(
uniformity: (
@@ -61,6 +70,10 @@ expression: output
),
ref_count: 0,
assignable_global: None,
ty: Value(Scalar(
kind: Uint,
width: 4,
)),
),
(
uniformity: (
@@ -71,6 +84,10 @@ expression: output
),
ref_count: 3,
assignable_global: None,
ty: Value(Pointer(
base: 1,
class: Function,
)),
),
(
uniformity: (
@@ -81,6 +98,7 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
),
(
uniformity: (
@@ -91,6 +109,10 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar(
kind: Uint,
width: 4,
)),
),
(
uniformity: (
@@ -101,6 +123,10 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar(
kind: Bool,
width: 1,
)),
),
(
uniformity: (
@@ -111,6 +137,7 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
),
(
uniformity: (
@@ -121,6 +148,10 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar(
kind: Uint,
width: 4,
)),
),
(
uniformity: (
@@ -131,6 +162,7 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
),
(
uniformity: (
@@ -141,6 +173,10 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar(
kind: Uint,
width: 4,
)),
),
(
uniformity: (
@@ -151,6 +187,10 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar(
kind: Bool,
width: 1,
)),
),
(
uniformity: (
@@ -161,6 +201,7 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
),
(
uniformity: (
@@ -171,6 +212,10 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar(
kind: Uint,
width: 4,
)),
),
(
uniformity: (
@@ -181,6 +226,7 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
),
(
uniformity: (
@@ -191,6 +237,10 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar(
kind: Uint,
width: 4,
)),
),
(
uniformity: (
@@ -201,6 +251,7 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
),
(
uniformity: (
@@ -211,6 +262,10 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar(
kind: Uint,
width: 4,
)),
),
(
uniformity: (
@@ -221,6 +276,10 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar(
kind: Uint,
width: 4,
)),
),
(
uniformity: (
@@ -231,6 +290,10 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar(
kind: Uint,
width: 4,
)),
),
(
uniformity: (
@@ -241,6 +304,7 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
),
(
uniformity: (
@@ -251,6 +315,10 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar(
kind: Uint,
width: 4,
)),
),
(
uniformity: (
@@ -261,6 +329,7 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
),
(
uniformity: (
@@ -271,6 +340,7 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
),
],
),
@@ -303,6 +373,10 @@ expression: output
),
ref_count: 2,
assignable_global: Some(1),
ty: Value(Pointer(
base: 3,
class: Storage,
)),
),
(
uniformity: (
@@ -313,6 +387,7 @@ expression: output
),
ref_count: 2,
assignable_global: None,
ty: Handle(4),
),
(
uniformity: (
@@ -323,6 +398,10 @@ expression: output
),
ref_count: 1,
assignable_global: Some(1),
ty: Value(Pointer(
base: 2,
class: Storage,
)),
),
(
uniformity: (
@@ -333,6 +412,10 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar(
kind: Uint,
width: 4,
)),
),
(
uniformity: (
@@ -343,6 +426,10 @@ expression: output
),
ref_count: 1,
assignable_global: Some(1),
ty: Value(Pointer(
base: 1,
class: Storage,
)),
),
(
uniformity: (
@@ -353,6 +440,10 @@ expression: output
),
ref_count: 1,
assignable_global: Some(1),
ty: Value(Pointer(
base: 2,
class: Storage,
)),
),
(
uniformity: (
@@ -363,6 +454,10 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar(
kind: Uint,
width: 4,
)),
),
(
uniformity: (
@@ -373,6 +468,10 @@ expression: output
),
ref_count: 1,
assignable_global: Some(1),
ty: Value(Pointer(
base: 1,
class: Storage,
)),
),
(
uniformity: (
@@ -383,6 +482,7 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
),
(
uniformity: (
@@ -393,6 +493,7 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
),
],
),

File diff suppressed because it is too large Load Diff