mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-04-22 03:02:01 -04:00
Add binary op support for the ConstantSolver
This commit is contained in:
committed by
Dzmitry Malyshau
parent
b1fa7471d2
commit
88eee833d7
@@ -401,7 +401,7 @@ pub struct Constant {
|
||||
}
|
||||
|
||||
/// A literal scalar value, used in constants.
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
#[derive(Debug, PartialEq, Clone, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
pub enum ScalarValue {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::{
|
||||
arena::{Arena, Handle},
|
||||
Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, UnaryOperator,
|
||||
BinaryOperator, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, TypeInner,
|
||||
UnaryOperator,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -46,6 +47,8 @@ pub enum ConstantSolvingError {
|
||||
InvalidCastArg,
|
||||
#[error("Cannot apply the unary op to the argument")]
|
||||
InvalidUnaryOpArg,
|
||||
#[error("Cannot apply the binary op to the arguments")]
|
||||
InvalidBinaryOpArgs,
|
||||
}
|
||||
|
||||
impl<'a> ConstantSolver<'a> {
|
||||
@@ -78,7 +81,12 @@ impl<'a> ConstantSolver<'a> {
|
||||
|
||||
self.unary_op(op, tgt)
|
||||
}
|
||||
Expression::Binary { .. } => todo!(),
|
||||
Expression::Binary { left, right, op } => {
|
||||
let left = self.solve(left)?;
|
||||
let right = self.solve(right)?;
|
||||
|
||||
self.binary_op(op, left, right)
|
||||
}
|
||||
Expression::Math { .. } => todo!(),
|
||||
Expression::As {
|
||||
convert,
|
||||
@@ -97,11 +105,11 @@ impl<'a> ConstantSolver<'a> {
|
||||
let array = self.solve(expr)?;
|
||||
|
||||
match self.constants[array].inner {
|
||||
crate::ConstantInner::Scalar { .. } => {
|
||||
ConstantInner::Scalar { .. } => {
|
||||
Err(ConstantSolvingError::InvalidArrayLengthArg)
|
||||
}
|
||||
crate::ConstantInner::Composite { ty, .. } => match self.types[ty].inner {
|
||||
crate::TypeInner::Array { size, .. } => match size {
|
||||
ConstantInner::Composite { ty, .. } => match self.types[ty].inner {
|
||||
TypeInner::Array { size, .. } => match size {
|
||||
crate::ArraySize::Constant(constant) => Ok(constant),
|
||||
crate::ArraySize::Dynamic => {
|
||||
Err(ConstantSolvingError::ArrayLengthDynamic)
|
||||
@@ -134,13 +142,13 @@ impl<'a> ConstantSolver<'a> {
|
||||
let base = self.solve(base)?;
|
||||
|
||||
match self.constants[base].inner {
|
||||
crate::ConstantInner::Scalar { .. } => Err(ConstantSolvingError::InvalidAccessBase),
|
||||
crate::ConstantInner::Composite { ty, ref components } => {
|
||||
ConstantInner::Scalar { .. } => Err(ConstantSolvingError::InvalidAccessBase),
|
||||
ConstantInner::Composite { ty, ref components } => {
|
||||
match self.types[ty].inner {
|
||||
crate::TypeInner::Vector { .. }
|
||||
| crate::TypeInner::Matrix { .. }
|
||||
| crate::TypeInner::Array { .. }
|
||||
| crate::TypeInner::Struct { .. } => (),
|
||||
TypeInner::Vector { .. }
|
||||
| TypeInner::Matrix { .. }
|
||||
| TypeInner::Array { .. }
|
||||
| TypeInner::Struct { .. } => (),
|
||||
_ => return Err(ConstantSolvingError::InvalidAccessBase),
|
||||
}
|
||||
|
||||
@@ -195,7 +203,7 @@ impl<'a> ConstantSolver<'a> {
|
||||
ref mut components,
|
||||
} => {
|
||||
match self.types[ty].inner {
|
||||
crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => (),
|
||||
TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
|
||||
_ => return Err(ConstantSolvingError::InvalidCastArg),
|
||||
}
|
||||
|
||||
@@ -238,7 +246,7 @@ impl<'a> ConstantSolver<'a> {
|
||||
ref mut components,
|
||||
} => {
|
||||
match self.types[ty].inner {
|
||||
crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => (),
|
||||
TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
|
||||
_ => return Err(ConstantSolvingError::InvalidCastArg),
|
||||
}
|
||||
|
||||
@@ -254,6 +262,100 @@ impl<'a> ConstantSolver<'a> {
|
||||
inner,
|
||||
}))
|
||||
}
|
||||
|
||||
fn binary_op(
|
||||
&mut self,
|
||||
op: BinaryOperator,
|
||||
left: Handle<Constant>,
|
||||
right: Handle<Constant>,
|
||||
) -> Result<Handle<Constant>, ConstantSolvingError> {
|
||||
let left = &self.constants[left].inner;
|
||||
let right = &self.constants[right].inner;
|
||||
|
||||
let inner = match (left, right) {
|
||||
(
|
||||
ConstantInner::Scalar {
|
||||
value: left_value,
|
||||
width,
|
||||
},
|
||||
ConstantInner::Scalar {
|
||||
value: right_value, ..
|
||||
},
|
||||
) => {
|
||||
let value = match op {
|
||||
BinaryOperator::Equal => ScalarValue::Bool(left_value == right_value),
|
||||
BinaryOperator::NotEqual => ScalarValue::Bool(left_value != right_value),
|
||||
BinaryOperator::Less => ScalarValue::Bool(left_value < right_value),
|
||||
BinaryOperator::LessEqual => ScalarValue::Bool(left_value <= right_value),
|
||||
BinaryOperator::Greater => ScalarValue::Bool(left_value > right_value),
|
||||
BinaryOperator::GreaterEqual => ScalarValue::Bool(left_value >= right_value),
|
||||
|
||||
_ => match (left_value, right_value) {
|
||||
(ScalarValue::Sint(a), ScalarValue::Sint(b)) => {
|
||||
ScalarValue::Sint(match op {
|
||||
BinaryOperator::Add => a + b,
|
||||
BinaryOperator::Subtract => a - b,
|
||||
BinaryOperator::Multiply => a * b,
|
||||
BinaryOperator::Divide => a / b,
|
||||
BinaryOperator::Modulo => a % b,
|
||||
BinaryOperator::And => a & b,
|
||||
BinaryOperator::ExclusiveOr => a ^ b,
|
||||
BinaryOperator::InclusiveOr => a | b,
|
||||
BinaryOperator::ShiftLeft => a << b,
|
||||
BinaryOperator::ShiftRight => a >> b,
|
||||
_ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
|
||||
})
|
||||
}
|
||||
(ScalarValue::Uint(a), ScalarValue::Uint(b)) => {
|
||||
ScalarValue::Uint(match op {
|
||||
BinaryOperator::Add => a + b,
|
||||
BinaryOperator::Subtract => a - b,
|
||||
BinaryOperator::Multiply => a * b,
|
||||
BinaryOperator::Divide => a / b,
|
||||
BinaryOperator::Modulo => a % b,
|
||||
BinaryOperator::And => a & b,
|
||||
BinaryOperator::ExclusiveOr => a ^ b,
|
||||
BinaryOperator::InclusiveOr => a | b,
|
||||
BinaryOperator::ShiftLeft => a << b,
|
||||
BinaryOperator::ShiftRight => a >> b,
|
||||
_ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
|
||||
})
|
||||
}
|
||||
(ScalarValue::Float(a), ScalarValue::Float(b)) => {
|
||||
ScalarValue::Float(match op {
|
||||
BinaryOperator::Add => a + b,
|
||||
BinaryOperator::Subtract => a - b,
|
||||
BinaryOperator::Multiply => a * b,
|
||||
BinaryOperator::Divide => a / b,
|
||||
BinaryOperator::Modulo => a % b,
|
||||
_ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
|
||||
})
|
||||
}
|
||||
(ScalarValue::Bool(a), ScalarValue::Bool(b)) => {
|
||||
ScalarValue::Bool(match op {
|
||||
BinaryOperator::LogicalAnd => *a && *b,
|
||||
BinaryOperator::LogicalOr => *a || *b,
|
||||
_ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
|
||||
})
|
||||
}
|
||||
_ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
|
||||
},
|
||||
};
|
||||
|
||||
ConstantInner::Scalar {
|
||||
value,
|
||||
width: *width,
|
||||
}
|
||||
}
|
||||
_ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
|
||||
};
|
||||
|
||||
Ok(self.constants.fetch_or_append(Constant {
|
||||
name: None,
|
||||
specialization: None,
|
||||
inner,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
Reference in New Issue
Block a user