mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-04-22 03:02:01 -04:00
Add support for unary ops to the constant solver
Fix wrong result when casting to bool Add tests
This commit is contained in:
committed by
Dzmitry Malyshau
parent
53bd721895
commit
fcbf2aa4c4
@@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
arena::{Arena, Handle},
|
||||
ArraySize, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type,
|
||||
ArraySize, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, UnaryOperator,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -44,6 +44,8 @@ pub enum ConstantSolvingError {
|
||||
Bitcast,
|
||||
#[error("Cannot cast type")]
|
||||
InvalidCastArg,
|
||||
#[error("Cannot apply the unary op to the argument")]
|
||||
InvalidUnaryOpArg,
|
||||
}
|
||||
|
||||
impl<'a> ConstantSolver<'a> {
|
||||
@@ -71,7 +73,11 @@ impl<'a> ConstantSolver<'a> {
|
||||
inner: ConstantInner::Composite { ty, components },
|
||||
}))
|
||||
}
|
||||
Expression::Unary { .. } => todo!(),
|
||||
Expression::Unary { expr, op } => {
|
||||
let tgt = self.solve(expr)?;
|
||||
|
||||
self.unary_op(op, tgt)
|
||||
}
|
||||
Expression::Binary { .. } => todo!(),
|
||||
Expression::Math { .. } => todo!(),
|
||||
Expression::As {
|
||||
@@ -196,7 +202,7 @@ impl<'a> ConstantSolver<'a> {
|
||||
ScalarKind::Sint => *value = ScalarValue::Sint(inner_cast(intial)),
|
||||
ScalarKind::Uint => *value = ScalarValue::Uint(inner_cast(intial)),
|
||||
ScalarKind::Float => *value = ScalarValue::Float(inner_cast(intial)),
|
||||
ScalarKind::Bool => *value = ScalarValue::Bool(inner_cast::<u64>(intial) == 0),
|
||||
ScalarKind::Bool => *value = ScalarValue::Bool(inner_cast::<u64>(intial) != 0),
|
||||
}
|
||||
}
|
||||
ConstantInner::Composite {
|
||||
@@ -220,4 +226,210 @@ impl<'a> ConstantSolver<'a> {
|
||||
inner,
|
||||
}))
|
||||
}
|
||||
|
||||
fn unary_op(
|
||||
&mut self,
|
||||
op: UnaryOperator,
|
||||
constant: Handle<Constant>,
|
||||
) -> Result<Handle<Constant>, ConstantSolvingError> {
|
||||
let mut inner = self.constants[constant].inner.clone();
|
||||
|
||||
match inner {
|
||||
ConstantInner::Scalar { ref mut value, .. } => match op {
|
||||
UnaryOperator::Negate => match value {
|
||||
ScalarValue::Sint(v) => *v = -*v,
|
||||
ScalarValue::Float(v) => *v = -*v,
|
||||
_ => return Err(ConstantSolvingError::InvalidUnaryOpArg),
|
||||
},
|
||||
UnaryOperator::Not => match value {
|
||||
ScalarValue::Sint(v) => *v = !*v,
|
||||
ScalarValue::Uint(v) => *v = !*v,
|
||||
ScalarValue::Bool(v) => *v = !*v,
|
||||
_ => return Err(ConstantSolvingError::InvalidUnaryOpArg),
|
||||
},
|
||||
},
|
||||
ConstantInner::Composite {
|
||||
ty,
|
||||
ref mut components,
|
||||
} => {
|
||||
match self.types[ty].inner {
|
||||
crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => (),
|
||||
_ => return Err(ConstantSolvingError::InvalidCastArg),
|
||||
}
|
||||
|
||||
for component in components {
|
||||
*component = self.unary_op(op, *component)?
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(self.constants.fetch_or_append(Constant {
|
||||
name: None,
|
||||
specialization: None,
|
||||
inner,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::vec;
|
||||
|
||||
use crate::{
|
||||
Arena, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, TypeInner,
|
||||
UnaryOperator, VectorSize,
|
||||
};
|
||||
|
||||
use super::ConstantSolver;
|
||||
|
||||
#[test]
|
||||
fn unary_op() {
|
||||
let mut types = Arena::new();
|
||||
let mut expressions = Arena::new();
|
||||
let mut constants = Arena::new();
|
||||
|
||||
let vec_ty = types.append(Type {
|
||||
name: None,
|
||||
inner: TypeInner::Vector {
|
||||
size: VectorSize::Bi,
|
||||
kind: ScalarKind::Sint,
|
||||
width: 4,
|
||||
},
|
||||
});
|
||||
|
||||
let h = constants.append(Constant {
|
||||
name: None,
|
||||
specialization: None,
|
||||
inner: ConstantInner::Scalar {
|
||||
width: 4,
|
||||
value: ScalarValue::Sint(4),
|
||||
},
|
||||
});
|
||||
|
||||
let h1 = constants.append(Constant {
|
||||
name: None,
|
||||
specialization: None,
|
||||
inner: ConstantInner::Scalar {
|
||||
width: 4,
|
||||
value: ScalarValue::Sint(8),
|
||||
},
|
||||
});
|
||||
|
||||
let vec_h = constants.append(Constant {
|
||||
name: None,
|
||||
specialization: None,
|
||||
inner: ConstantInner::Composite {
|
||||
ty: vec_ty,
|
||||
components: vec![h, h1],
|
||||
},
|
||||
});
|
||||
|
||||
let expr = expressions.append(Expression::Constant(h));
|
||||
let expr1 = expressions.append(Expression::Constant(vec_h));
|
||||
|
||||
let root1 = expressions.append(Expression::Unary {
|
||||
op: UnaryOperator::Negate,
|
||||
expr,
|
||||
});
|
||||
|
||||
let root2 = expressions.append(Expression::Unary {
|
||||
op: UnaryOperator::Not,
|
||||
expr,
|
||||
});
|
||||
|
||||
let root3 = expressions.append(Expression::Unary {
|
||||
op: UnaryOperator::Not,
|
||||
expr: expr1,
|
||||
});
|
||||
|
||||
let mut solver = ConstantSolver {
|
||||
types: &types,
|
||||
expressions: &expressions,
|
||||
constants: &mut constants,
|
||||
};
|
||||
|
||||
let res1 = solver.solve(root1).unwrap();
|
||||
let res2 = solver.solve(root2).unwrap();
|
||||
let res3 = solver.solve(root3).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
constants[res1].inner,
|
||||
ConstantInner::Scalar {
|
||||
width: 4,
|
||||
value: ScalarValue::Sint(-4),
|
||||
},
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
constants[res2].inner,
|
||||
ConstantInner::Scalar {
|
||||
width: 4,
|
||||
value: ScalarValue::Sint(!4),
|
||||
},
|
||||
);
|
||||
|
||||
let res3_inner = &constants[res3].inner;
|
||||
|
||||
match res3_inner {
|
||||
ConstantInner::Composite { ty, components } => {
|
||||
assert_eq!(*ty, vec_ty);
|
||||
let mut components_iter = components.iter().copied();
|
||||
assert_eq!(
|
||||
constants[components_iter.next().unwrap()].inner,
|
||||
ConstantInner::Scalar {
|
||||
width: 4,
|
||||
value: ScalarValue::Sint(!4),
|
||||
},
|
||||
);
|
||||
assert_eq!(
|
||||
constants[components_iter.next().unwrap()].inner,
|
||||
ConstantInner::Scalar {
|
||||
width: 4,
|
||||
value: ScalarValue::Sint(!8),
|
||||
},
|
||||
);
|
||||
assert!(components_iter.next().is_none());
|
||||
}
|
||||
_ => panic!("Expected vector"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cast() {
|
||||
let mut expressions = Arena::new();
|
||||
let mut constants = Arena::new();
|
||||
|
||||
let h = constants.append(Constant {
|
||||
name: None,
|
||||
specialization: None,
|
||||
inner: ConstantInner::Scalar {
|
||||
width: 4,
|
||||
value: ScalarValue::Sint(4),
|
||||
},
|
||||
});
|
||||
|
||||
let expr = expressions.append(Expression::Constant(h));
|
||||
|
||||
let root = expressions.append(Expression::As {
|
||||
expr,
|
||||
kind: ScalarKind::Bool,
|
||||
convert: true,
|
||||
});
|
||||
|
||||
let mut solver = ConstantSolver {
|
||||
types: &Arena::new(),
|
||||
expressions: &expressions,
|
||||
constants: &mut constants,
|
||||
};
|
||||
|
||||
let res = solver.solve(root).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
constants[res].inner,
|
||||
ConstantInner::Scalar {
|
||||
width: 4,
|
||||
value: ScalarValue::Bool(true),
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user