Fix float-bool casts in MSL, SPV, and HLSL backends (#1459)

This commit is contained in:
Dzmitry Malyshau
2021-10-12 11:42:20 -04:00
committed by GitHub
parent 3a2f7e611e
commit ee450c1ee4
9 changed files with 285 additions and 165 deletions

View File

@@ -1786,20 +1786,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, "{}", op_str)?;
self.write_expr(module, expr, func_ctx)?;
}
Expression::As { expr, kind, .. } => {
Expression::As {
expr,
kind,
convert,
} => {
let inner = func_ctx.info[expr].ty.inner_with(&module.types);
match *inner {
TypeInner::Vector { size, width, .. } => {
write!(
self.out,
"{}{}",
kind.to_hlsl_str(width)?,
back::vector_size_str(size),
)?;
}
TypeInner::Scalar { width, .. } => {
write!(self.out, "{}", kind.to_hlsl_str(width)?)?
}
let (size_str, src_width) = match *inner {
TypeInner::Vector { size, width, .. } => (back::vector_size_str(size), width),
TypeInner::Scalar { width, .. } => ("", width),
_ => {
return Err(Error::Unimplemented(format!(
"write_expr expression::as {:?}",
@@ -1807,7 +1802,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
)));
}
};
write!(self.out, "(")?;
let kind_str = kind.to_hlsl_str(convert.unwrap_or(src_width))?;
write!(self.out, "{}{}(", kind_str, size_str,)?;
self.write_expr(module, expr, func_ctx)?;
write!(self.out, ")")?;
}

View File

@@ -165,7 +165,13 @@ impl<'a> Display for TypeContext<'a> {
} else if self.access.contains(crate::StorageAccess::LOAD) {
"read"
} else {
unreachable!("module is not valid")
log::warn!(
"Storage access for {:?} (name '{}'): {:?}",
self.handle,
ty.name.as_deref().unwrap_or_default(),
self.access
);
unreachable!("module is not valid");
};
("texture", "", format.into(), access)
}
@@ -1223,13 +1229,15 @@ impl<W: Write> Writer<W> {
convert,
} => {
let scalar = scalar_kind_string(kind);
let width = match *context.resolve_type(expr) {
crate::TypeInner::Scalar { width, .. }
| crate::TypeInner::Vector { width, .. } => width,
let (src_kind, src_width) = match *context.resolve_type(expr) {
crate::TypeInner::Scalar { kind, width }
| crate::TypeInner::Vector { kind, width, .. } => (kind, width),
_ => return Err(Error::Validation),
};
let is_bool_cast =
kind == crate::ScalarKind::Bool || src_kind == crate::ScalarKind::Bool;
let op = match convert {
Some(w) if w == width => "static_cast",
Some(w) if w == src_width || is_bool_cast => "static_cast",
Some(8) if kind == crate::ScalarKind::Float => {
return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
}

View File

@@ -232,14 +232,12 @@ impl<'w> BlockContext<'w> {
crate::Expression::Constant(handle) => self.writer.constant_ids[handle.index()],
crate::Expression::Splat { size, value } => {
let value_id = self.cached[value];
self.temp_list.clear();
self.temp_list.resize(size as usize, value_id);
let components = [value_id; 4];
let id = self.gen_id();
block.body.push(Instruction::composite_construct(
result_type_id,
id,
&self.temp_list,
&components[..size as usize],
));
id
}
@@ -726,25 +724,26 @@ impl<'w> BlockContext<'w> {
use crate::ScalarKind as Sk;
let expr_id = self.cached[expr];
let (src_kind, src_width) =
let (src_kind, src_size, src_width) =
match *self.fun_info[expr].ty.inner_with(&self.ir_module.types) {
crate::TypeInner::Scalar { kind, width }
| crate::TypeInner::Vector {
kind,
width,
size: _,
} => (kind, width),
crate::TypeInner::Matrix { width, .. } => (crate::ScalarKind::Float, width),
crate::TypeInner::Scalar { kind, width } => (kind, None, width),
crate::TypeInner::Vector { kind, width, size } => (kind, Some(size), width),
ref other => {
log::error!("As source {:?}", other);
return Err(Error::Validation("Unexpected Expression::As source"));
}
};
let id = self.gen_id();
enum Cast {
Unary(spirv::Op),
Binary(spirv::Op, Word),
Ternary(spirv::Op, Word, Word),
}
let instruction = match (src_kind, kind, convert) {
(_, Sk::Bool, Some(_)) if src_kind != Sk::Bool => {
let cast = match (src_kind, kind, convert) {
(_, _, None) | (Sk::Bool, Sk::Bool, Some(_)) => Cast::Unary(spirv::Op::Bitcast),
// casting to a bool - generate `OpXxxNotEqual`
(_, Sk::Bool, Some(_)) => {
let (op, value) = match src_kind {
Sk::Sint => (spirv::Op::INotEqual, crate::ScalarValue::Sint(0)),
Sk::Uint => (spirv::Op::INotEqual, crate::ScalarValue::Uint(0)),
@@ -753,34 +752,102 @@ impl<'w> BlockContext<'w> {
}
Sk::Bool => unreachable!(),
};
let zero_id = self.writer.get_constant_scalar(value, 4);
let zero_scalar_id = self.writer.get_constant_scalar(value, src_width);
let zero_id = match src_size {
Some(size) => {
let vector_type_id =
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(size),
kind: src_kind,
width: src_width,
pointer_class: None,
}));
let components = [zero_scalar_id; 4];
Instruction::binary(op, result_type_id, id, expr_id, zero_id)
}
_ => {
let op = match (src_kind, kind, convert) {
(_, _, None) => spirv::Op::Bitcast,
(Sk::Float, Sk::Uint, Some(_)) => spirv::Op::ConvertFToU,
(Sk::Float, Sk::Sint, Some(_)) => spirv::Op::ConvertFToS,
(Sk::Float, Sk::Float, Some(dst_width)) if src_width != dst_width => {
spirv::Op::FConvert
let zero_id = self.gen_id();
block.body.push(Instruction::composite_construct(
vector_type_id,
zero_id,
&components[..size as usize],
));
zero_id
}
(Sk::Sint, Sk::Float, Some(_)) => spirv::Op::ConvertSToF,
(Sk::Sint, Sk::Sint, Some(dst_width)) if src_width != dst_width => {
spirv::Op::SConvert
}
(Sk::Uint, Sk::Float, Some(_)) => spirv::Op::ConvertUToF,
(Sk::Uint, Sk::Uint, Some(dst_width)) if src_width != dst_width => {
spirv::Op::UConvert
}
// We assume it's either an identity cast, or int-uint.
_ => spirv::Op::Bitcast,
None => zero_scalar_id,
};
Instruction::unary(op, result_type_id, id, expr_id)
Cast::Binary(op, zero_id)
}
// casting from a bool - generate `OpSelect`
(Sk::Bool, _, Some(dst_width)) => {
let (val0, val1) = match kind {
Sk::Sint => (crate::ScalarValue::Sint(0), crate::ScalarValue::Sint(1)),
Sk::Uint => (crate::ScalarValue::Uint(0), crate::ScalarValue::Uint(1)),
Sk::Float => (
crate::ScalarValue::Float(0.0),
crate::ScalarValue::Float(1.0),
),
Sk::Bool => unreachable!(),
};
let scalar0_id = self.writer.get_constant_scalar(val0, dst_width);
let scalar1_id = self.writer.get_constant_scalar(val1, dst_width);
let (accept_id, reject_id) = match src_size {
Some(size) => {
let vector_type_id =
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(size),
kind,
width: dst_width,
pointer_class: None,
}));
let components0 = [scalar0_id; 4];
let components1 = [scalar1_id; 4];
let vec0_id = self.gen_id();
block.body.push(Instruction::composite_construct(
vector_type_id,
vec0_id,
&components0[..size as usize],
));
let vec1_id = self.gen_id();
block.body.push(Instruction::composite_construct(
vector_type_id,
vec1_id,
&components1[..size as usize],
));
(vec1_id, vec0_id)
}
None => (scalar1_id, scalar0_id),
};
Cast::Ternary(spirv::Op::Select, accept_id, reject_id)
}
(Sk::Float, Sk::Uint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToU),
(Sk::Float, Sk::Sint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToS),
(Sk::Float, Sk::Float, Some(dst_width)) if src_width != dst_width => {
Cast::Unary(spirv::Op::FConvert)
}
(Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF),
(Sk::Sint, Sk::Sint, Some(dst_width)) if src_width != dst_width => {
Cast::Unary(spirv::Op::SConvert)
}
(Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF),
(Sk::Uint, Sk::Uint, Some(dst_width)) if src_width != dst_width => {
Cast::Unary(spirv::Op::UConvert)
}
// We assume it's either an identity cast, or int-uint.
_ => Cast::Unary(spirv::Op::Bitcast),
};
let id = self.gen_id();
let instruction = match cast {
Cast::Unary(op) => Instruction::unary(op, result_type_id, id, expr_id),
Cast::Binary(op, operand) => {
Instruction::binary(op, result_type_id, id, expr_id, operand)
}
Cast::Ternary(op, op1, op2) => {
Instruction::ternary(op, result_type_id, id, expr_id, op1, op2)
}
};
block.body.push(instruction);
id
}