mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-04-22 03:02:01 -04:00
Fix float-bool casts in MSL, SPV, and HLSL backends (#1459)
This commit is contained in:
@@ -1786,20 +1786,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
write!(self.out, "{}", op_str)?;
|
||||
self.write_expr(module, expr, func_ctx)?;
|
||||
}
|
||||
Expression::As { expr, kind, .. } => {
|
||||
Expression::As {
|
||||
expr,
|
||||
kind,
|
||||
convert,
|
||||
} => {
|
||||
let inner = func_ctx.info[expr].ty.inner_with(&module.types);
|
||||
match *inner {
|
||||
TypeInner::Vector { size, width, .. } => {
|
||||
write!(
|
||||
self.out,
|
||||
"{}{}",
|
||||
kind.to_hlsl_str(width)?,
|
||||
back::vector_size_str(size),
|
||||
)?;
|
||||
}
|
||||
TypeInner::Scalar { width, .. } => {
|
||||
write!(self.out, "{}", kind.to_hlsl_str(width)?)?
|
||||
}
|
||||
let (size_str, src_width) = match *inner {
|
||||
TypeInner::Vector { size, width, .. } => (back::vector_size_str(size), width),
|
||||
TypeInner::Scalar { width, .. } => ("", width),
|
||||
_ => {
|
||||
return Err(Error::Unimplemented(format!(
|
||||
"write_expr expression::as {:?}",
|
||||
@@ -1807,7 +1802,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
)));
|
||||
}
|
||||
};
|
||||
write!(self.out, "(")?;
|
||||
let kind_str = kind.to_hlsl_str(convert.unwrap_or(src_width))?;
|
||||
write!(self.out, "{}{}(", kind_str, size_str,)?;
|
||||
self.write_expr(module, expr, func_ctx)?;
|
||||
write!(self.out, ")")?;
|
||||
}
|
||||
|
||||
@@ -165,7 +165,13 @@ impl<'a> Display for TypeContext<'a> {
|
||||
} else if self.access.contains(crate::StorageAccess::LOAD) {
|
||||
"read"
|
||||
} else {
|
||||
unreachable!("module is not valid")
|
||||
log::warn!(
|
||||
"Storage access for {:?} (name '{}'): {:?}",
|
||||
self.handle,
|
||||
ty.name.as_deref().unwrap_or_default(),
|
||||
self.access
|
||||
);
|
||||
unreachable!("module is not valid");
|
||||
};
|
||||
("texture", "", format.into(), access)
|
||||
}
|
||||
@@ -1223,13 +1229,15 @@ impl<W: Write> Writer<W> {
|
||||
convert,
|
||||
} => {
|
||||
let scalar = scalar_kind_string(kind);
|
||||
let width = match *context.resolve_type(expr) {
|
||||
crate::TypeInner::Scalar { width, .. }
|
||||
| crate::TypeInner::Vector { width, .. } => width,
|
||||
let (src_kind, src_width) = match *context.resolve_type(expr) {
|
||||
crate::TypeInner::Scalar { kind, width }
|
||||
| crate::TypeInner::Vector { kind, width, .. } => (kind, width),
|
||||
_ => return Err(Error::Validation),
|
||||
};
|
||||
let is_bool_cast =
|
||||
kind == crate::ScalarKind::Bool || src_kind == crate::ScalarKind::Bool;
|
||||
let op = match convert {
|
||||
Some(w) if w == width => "static_cast",
|
||||
Some(w) if w == src_width || is_bool_cast => "static_cast",
|
||||
Some(8) if kind == crate::ScalarKind::Float => {
|
||||
return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
|
||||
}
|
||||
|
||||
@@ -232,14 +232,12 @@ impl<'w> BlockContext<'w> {
|
||||
crate::Expression::Constant(handle) => self.writer.constant_ids[handle.index()],
|
||||
crate::Expression::Splat { size, value } => {
|
||||
let value_id = self.cached[value];
|
||||
self.temp_list.clear();
|
||||
self.temp_list.resize(size as usize, value_id);
|
||||
|
||||
let components = [value_id; 4];
|
||||
let id = self.gen_id();
|
||||
block.body.push(Instruction::composite_construct(
|
||||
result_type_id,
|
||||
id,
|
||||
&self.temp_list,
|
||||
&components[..size as usize],
|
||||
));
|
||||
id
|
||||
}
|
||||
@@ -726,25 +724,26 @@ impl<'w> BlockContext<'w> {
|
||||
use crate::ScalarKind as Sk;
|
||||
|
||||
let expr_id = self.cached[expr];
|
||||
let (src_kind, src_width) =
|
||||
let (src_kind, src_size, src_width) =
|
||||
match *self.fun_info[expr].ty.inner_with(&self.ir_module.types) {
|
||||
crate::TypeInner::Scalar { kind, width }
|
||||
| crate::TypeInner::Vector {
|
||||
kind,
|
||||
width,
|
||||
size: _,
|
||||
} => (kind, width),
|
||||
crate::TypeInner::Matrix { width, .. } => (crate::ScalarKind::Float, width),
|
||||
crate::TypeInner::Scalar { kind, width } => (kind, None, width),
|
||||
crate::TypeInner::Vector { kind, width, size } => (kind, Some(size), width),
|
||||
ref other => {
|
||||
log::error!("As source {:?}", other);
|
||||
return Err(Error::Validation("Unexpected Expression::As source"));
|
||||
}
|
||||
};
|
||||
|
||||
let id = self.gen_id();
|
||||
enum Cast {
|
||||
Unary(spirv::Op),
|
||||
Binary(spirv::Op, Word),
|
||||
Ternary(spirv::Op, Word, Word),
|
||||
}
|
||||
|
||||
let instruction = match (src_kind, kind, convert) {
|
||||
(_, Sk::Bool, Some(_)) if src_kind != Sk::Bool => {
|
||||
let cast = match (src_kind, kind, convert) {
|
||||
(_, _, None) | (Sk::Bool, Sk::Bool, Some(_)) => Cast::Unary(spirv::Op::Bitcast),
|
||||
// casting to a bool - generate `OpXxxNotEqual`
|
||||
(_, Sk::Bool, Some(_)) => {
|
||||
let (op, value) = match src_kind {
|
||||
Sk::Sint => (spirv::Op::INotEqual, crate::ScalarValue::Sint(0)),
|
||||
Sk::Uint => (spirv::Op::INotEqual, crate::ScalarValue::Uint(0)),
|
||||
@@ -753,34 +752,102 @@ impl<'w> BlockContext<'w> {
|
||||
}
|
||||
Sk::Bool => unreachable!(),
|
||||
};
|
||||
let zero_id = self.writer.get_constant_scalar(value, 4);
|
||||
let zero_scalar_id = self.writer.get_constant_scalar(value, src_width);
|
||||
let zero_id = match src_size {
|
||||
Some(size) => {
|
||||
let vector_type_id =
|
||||
self.get_type_id(LookupType::Local(LocalType::Value {
|
||||
vector_size: Some(size),
|
||||
kind: src_kind,
|
||||
width: src_width,
|
||||
pointer_class: None,
|
||||
}));
|
||||
let components = [zero_scalar_id; 4];
|
||||
|
||||
Instruction::binary(op, result_type_id, id, expr_id, zero_id)
|
||||
}
|
||||
_ => {
|
||||
let op = match (src_kind, kind, convert) {
|
||||
(_, _, None) => spirv::Op::Bitcast,
|
||||
(Sk::Float, Sk::Uint, Some(_)) => spirv::Op::ConvertFToU,
|
||||
(Sk::Float, Sk::Sint, Some(_)) => spirv::Op::ConvertFToS,
|
||||
(Sk::Float, Sk::Float, Some(dst_width)) if src_width != dst_width => {
|
||||
spirv::Op::FConvert
|
||||
let zero_id = self.gen_id();
|
||||
block.body.push(Instruction::composite_construct(
|
||||
vector_type_id,
|
||||
zero_id,
|
||||
&components[..size as usize],
|
||||
));
|
||||
zero_id
|
||||
}
|
||||
(Sk::Sint, Sk::Float, Some(_)) => spirv::Op::ConvertSToF,
|
||||
(Sk::Sint, Sk::Sint, Some(dst_width)) if src_width != dst_width => {
|
||||
spirv::Op::SConvert
|
||||
}
|
||||
(Sk::Uint, Sk::Float, Some(_)) => spirv::Op::ConvertUToF,
|
||||
(Sk::Uint, Sk::Uint, Some(dst_width)) if src_width != dst_width => {
|
||||
spirv::Op::UConvert
|
||||
}
|
||||
// We assume it's either an identity cast, or int-uint.
|
||||
_ => spirv::Op::Bitcast,
|
||||
None => zero_scalar_id,
|
||||
};
|
||||
|
||||
Instruction::unary(op, result_type_id, id, expr_id)
|
||||
Cast::Binary(op, zero_id)
|
||||
}
|
||||
// casting from a bool - generate `OpSelect`
|
||||
(Sk::Bool, _, Some(dst_width)) => {
|
||||
let (val0, val1) = match kind {
|
||||
Sk::Sint => (crate::ScalarValue::Sint(0), crate::ScalarValue::Sint(1)),
|
||||
Sk::Uint => (crate::ScalarValue::Uint(0), crate::ScalarValue::Uint(1)),
|
||||
Sk::Float => (
|
||||
crate::ScalarValue::Float(0.0),
|
||||
crate::ScalarValue::Float(1.0),
|
||||
),
|
||||
Sk::Bool => unreachable!(),
|
||||
};
|
||||
let scalar0_id = self.writer.get_constant_scalar(val0, dst_width);
|
||||
let scalar1_id = self.writer.get_constant_scalar(val1, dst_width);
|
||||
let (accept_id, reject_id) = match src_size {
|
||||
Some(size) => {
|
||||
let vector_type_id =
|
||||
self.get_type_id(LookupType::Local(LocalType::Value {
|
||||
vector_size: Some(size),
|
||||
kind,
|
||||
width: dst_width,
|
||||
pointer_class: None,
|
||||
}));
|
||||
let components0 = [scalar0_id; 4];
|
||||
let components1 = [scalar1_id; 4];
|
||||
|
||||
let vec0_id = self.gen_id();
|
||||
block.body.push(Instruction::composite_construct(
|
||||
vector_type_id,
|
||||
vec0_id,
|
||||
&components0[..size as usize],
|
||||
));
|
||||
let vec1_id = self.gen_id();
|
||||
block.body.push(Instruction::composite_construct(
|
||||
vector_type_id,
|
||||
vec1_id,
|
||||
&components1[..size as usize],
|
||||
));
|
||||
(vec1_id, vec0_id)
|
||||
}
|
||||
None => (scalar1_id, scalar0_id),
|
||||
};
|
||||
|
||||
Cast::Ternary(spirv::Op::Select, accept_id, reject_id)
|
||||
}
|
||||
(Sk::Float, Sk::Uint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToU),
|
||||
(Sk::Float, Sk::Sint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToS),
|
||||
(Sk::Float, Sk::Float, Some(dst_width)) if src_width != dst_width => {
|
||||
Cast::Unary(spirv::Op::FConvert)
|
||||
}
|
||||
(Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF),
|
||||
(Sk::Sint, Sk::Sint, Some(dst_width)) if src_width != dst_width => {
|
||||
Cast::Unary(spirv::Op::SConvert)
|
||||
}
|
||||
(Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF),
|
||||
(Sk::Uint, Sk::Uint, Some(dst_width)) if src_width != dst_width => {
|
||||
Cast::Unary(spirv::Op::UConvert)
|
||||
}
|
||||
// We assume it's either an identity cast, or int-uint.
|
||||
_ => Cast::Unary(spirv::Op::Bitcast),
|
||||
};
|
||||
|
||||
let id = self.gen_id();
|
||||
let instruction = match cast {
|
||||
Cast::Unary(op) => Instruction::unary(op, result_type_id, id, expr_id),
|
||||
Cast::Binary(op, operand) => {
|
||||
Instruction::binary(op, result_type_id, id, expr_id, operand)
|
||||
}
|
||||
Cast::Ternary(op, op1, op2) => {
|
||||
Instruction::ternary(op, result_type_id, id, expr_id, op1, op2)
|
||||
}
|
||||
};
|
||||
block.body.push(instruction);
|
||||
id
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user