[naga]: Let TypeInner::Matrix hold a Scalar, not just a width.

Let `naga::TypeInner::Matrix` hold a full `Scalar`, with a kind and
byte width, not merely a byte width, to make it possible to represent
matrices of AbstractFloats for WGSL.
This commit is contained in:
Jim Blandy
2023-11-16 13:16:08 -08:00
committed by Teodor Tanasoaia
parent 4b10ce7e5b
commit 72462267e8
37 changed files with 225 additions and 246 deletions

View File

@@ -95,6 +95,8 @@ Passing an owned value `window` to `Surface` will return a `Surface<'static>`. S
- When reading GLSL, fix the argument types of the double-precision floating-point overloads of the `dot`, `reflect`, `distance`, and `ldexp` builtin functions. Correct the WGSL generated for constructing 64-bit floating-point matrices. Add tests for all the above. By @jimblandy in [#4684](https://github.com/gfx-rs/wgpu/pull/4684).
- Allow Naga's IR types to represent matrices with elements elements of any scalar kind. This makes it possible for Naga IR types to represent WGSL abstract matrices. By @jimblandy in [#4735](https://github.com/gfx-rs/wgpu/pull/4735).
- When evaluating const-expressions and generating SPIR-V, properly handle `Compose` expressions whose operands are `Splat` expressions. Such expressions are created and marked as constant by the constant evaluator. By @jimblandy in [#4695](https://github.com/gfx-rs/wgpu/pull/4695).
- Preserve the source spans for constants and expressions correctly across module compaction. By @jimblandy in [#4696](https://github.com/gfx-rs/wgpu/pull/4696).
@@ -2353,4 +2355,4 @@ DeviceDescriptor {
- concept of the storage hub
- basic recording of passes and command buffers
- submission-based lifetime tracking and command buffer recycling
- automatic resource transitions
- automatic resource transitions

View File

@@ -275,11 +275,9 @@ impl<'a, W> Writer<'a, W> {
for (ty_handle, ty) in self.module.types.iter() {
match ty.inner {
TypeInner::Scalar(scalar) => self.scalar_required_features(scalar),
TypeInner::Vector { scalar, .. } => self.scalar_required_features(scalar),
TypeInner::Matrix { width, .. } => {
self.scalar_required_features(Scalar::float(width))
}
TypeInner::Scalar(scalar)
| TypeInner::Vector { scalar, .. }
| TypeInner::Matrix { scalar, .. } => self.scalar_required_features(scalar),
TypeInner::Array { base, size, .. } => {
if let TypeInner::Array { .. } = self.module.types[base].inner {
self.features.request(Features::ARRAY_OF_ARRAYS)

View File

@@ -985,11 +985,11 @@ impl<'a, W: Write> Writer<'a, W> {
TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => write!(
self.out,
"{}mat{}x{}",
glsl_scalar(crate::Scalar::float(width))?.prefix,
glsl_scalar(scalar)?.prefix,
columns as u8,
rows as u8
)?,

View File

@@ -47,10 +47,10 @@ impl crate::TypeInner {
Self::Matrix {
columns,
rows,
width,
scalar,
} => {
let stride = Alignment::from(rows) * width as u32;
let last_row_size = rows as u32 * width as u32;
let stride = Alignment::from(rows) * scalar.width as u32;
let last_row_size = rows as u32 * scalar.width as u32;
((columns as u32 - 1) * stride) + last_row_size
}
Self::Array { base, size, stride } => {
@@ -82,10 +82,10 @@ impl crate::TypeInner {
crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => Cow::Owned(format!(
"{}{}x{}",
crate::Scalar::float(width).to_hlsl_str()?,
scalar.to_hlsl_str()?,
crate::back::vector_size_str(columns),
crate::back::vector_size_str(rows),
)),

View File

@@ -656,10 +656,9 @@ impl<'a, W: Write> super::Writer<'a, W> {
_ => unreachable!(),
};
let vec_ty = match module.types[member.ty].inner {
crate::TypeInner::Matrix { rows, width, .. } => crate::TypeInner::Vector {
size: rows,
scalar: crate::Scalar::float(width),
},
crate::TypeInner::Matrix { rows, scalar, .. } => {
crate::TypeInner::Vector { size: rows, scalar }
}
_ => unreachable!(),
};
self.write_value_type(module, &vec_ty)?;
@@ -736,9 +735,7 @@ impl<'a, W: Write> super::Writer<'a, W> {
_ => unreachable!(),
};
let scalar_ty = match module.types[member.ty].inner {
crate::TypeInner::Matrix { width, .. } => {
crate::TypeInner::Scalar(crate::Scalar::float(width))
}
crate::TypeInner::Matrix { scalar, .. } => crate::TypeInner::Scalar(scalar),
_ => unreachable!(),
};
self.write_value_type(module, &scalar_ty)?;

View File

@@ -180,23 +180,20 @@ impl<W: fmt::Write> super::Writer<'_, W> {
crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
write!(
self.out,
"{}{}x{}(",
crate::Scalar::float(width).to_hlsl_str()?,
scalar.to_hlsl_str()?,
columns as u8,
rows as u8,
)?;
// Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
let row_stride = Alignment::from(rows) * width as u32;
let row_stride = Alignment::from(rows) * scalar.width as u32;
let iter = (0..columns as u32).map(|i| {
let ty_inner = crate::TypeInner::Vector {
size: rows,
scalar: crate::Scalar::float(width),
};
let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
(TypeResolution::Value(ty_inner), i * row_stride)
});
self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
@@ -316,7 +313,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
// first, assign the value to a temporary
writeln!(self.out, "{level}{{")?;
@@ -325,7 +322,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
self.out,
"{}{}{}x{} {}{} = ",
level.next(),
crate::Scalar::float(width).to_hlsl_str()?,
scalar.to_hlsl_str()?,
columns as u8,
rows as u8,
STORE_TEMP_NAME,
@@ -335,16 +332,13 @@ impl<W: fmt::Write> super::Writer<'_, W> {
writeln!(self.out, ";")?;
// Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
let row_stride = Alignment::from(rows) * width as u32;
let row_stride = Alignment::from(rows) * scalar.width as u32;
// then iterate the stores
for i in 0..columns as u32 {
self.temp_access_chain
.push(SubAccess::Offset(i * row_stride));
let ty_inner = crate::TypeInner::Vector {
size: rows,
scalar: crate::Scalar::float(width),
};
let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
let sv = StoreValue::TempIndex {
depth,
index: i,
@@ -467,10 +461,10 @@ impl<W: fmt::Write> super::Writer<'_, W> {
crate::TypeInner::Vector { scalar, .. } => Parent::Array {
stride: scalar.width as u32,
},
crate::TypeInner::Matrix { rows, width, .. } => Parent::Array {
crate::TypeInner::Matrix { rows, scalar, .. } => Parent::Array {
// The stride between matrices is the count of rows as this is how
// long each column is.
stride: Alignment::from(rows) * width as u32,
stride: Alignment::from(rows) * scalar.width as u32,
},
_ => unreachable!(),
},

View File

@@ -908,12 +908,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
TypeInner::Matrix {
rows,
columns,
width,
scalar,
} if member.binding.is_none() && rows == crate::VectorSize::Bi => {
let vec_ty = crate::TypeInner::Vector {
size: rows,
scalar: crate::Scalar::float(width),
};
let vec_ty = crate::TypeInner::Vector { size: rows, scalar };
let field_name_key = NameKey::StructMember(handle, index as u32);
for i in 0..columns as u8 {
@@ -1037,7 +1034,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
// The IR supports only float matrix
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
@@ -1046,7 +1043,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(
self.out,
"{}{}x{}",
crate::Scalar::float(width).to_hlsl_str()?,
scalar.to_hlsl_str()?,
back::vector_size_str(columns),
back::vector_size_str(rows),
)?;
@@ -3241,11 +3238,11 @@ pub(super) fn get_inner_matrix_data(
TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => Some(MatrixType {
columns,
rows,
width,
width: scalar.width,
}),
TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
_ => None,
@@ -3276,12 +3273,12 @@ pub(super) fn get_inner_matrix_of_struct_array_member(
TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
mat_data = Some(MatrixType {
columns,
rows,
width,
width: scalar.width,
})
}
TypeInner::Array { base, .. } => {
@@ -3333,12 +3330,12 @@ fn get_inner_matrix_of_global_uniform(
TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
mat_data = Some(MatrixType {
columns,
rows,
width,
width: scalar.width,
})
}
TypeInner::Array { base, .. } => {

View File

@@ -1942,11 +1942,11 @@ impl<W: Write> Writer<W> {
crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
let target_scalar = crate::Scalar {
kind,
width: convert.unwrap_or(width),
width: convert.unwrap_or(scalar.width),
};
put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
write!(self.out, "(")?;
@@ -2555,10 +2555,9 @@ impl<W: Write> Writer<W> {
TypeResolution::Value(crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
}) => {
let element = crate::Scalar::float(width);
put_numeric_type(&mut self.out, element, &[rows, columns])?;
put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
}
TypeResolution::Value(ref other) => {
log::warn!("Type {:?} isn't a known local", other); //TEMP!

View File

@@ -494,7 +494,7 @@ impl<'w> BlockContext<'w> {
crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
self.write_matrix_matrix_column_op(
block,
@@ -504,7 +504,7 @@ impl<'w> BlockContext<'w> {
right_id,
columns,
rows,
width,
scalar.width,
spirv::Op::FAdd,
);
@@ -522,7 +522,7 @@ impl<'w> BlockContext<'w> {
crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
self.write_matrix_matrix_column_op(
block,
@@ -532,7 +532,7 @@ impl<'w> BlockContext<'w> {
right_id,
columns,
rows,
width,
scalar.width,
spirv::Op::FSub,
);
@@ -1141,9 +1141,7 @@ impl<'w> BlockContext<'w> {
match *self.fun_info[expr].ty.inner_with(&self.ir_module.types) {
crate::TypeInner::Scalar(scalar) => (scalar, None, false),
crate::TypeInner::Vector { scalar, size } => (scalar, Some(size), false),
crate::TypeInner::Matrix { width, .. } => {
(crate::Scalar::float(width), None, true)
}
crate::TypeInner::Matrix { scalar, .. } => (scalar, None, true),
ref other => {
log::error!("As source {:?}", other);
return Err(Error::Validation("Unexpected Expression::As source"));

View File

@@ -367,11 +367,11 @@ fn make_local(inner: &crate::TypeInner) -> Option<LocalType> {
crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => LocalType::Matrix {
columns,
rows,
width,
width: scalar.width,
},
crate::TypeInner::Pointer { base, space } => LocalType::Pointer {
base,

View File

@@ -1766,10 +1766,10 @@ impl Writer {
if let crate::TypeInner::Matrix {
columns: _,
rows,
width,
scalar,
} = *member_array_subty_inner
{
let byte_stride = Alignment::from(rows) * width as u32;
let byte_stride = Alignment::from(rows) * scalar.width as u32;
self.annotations.push(Instruction::member_decorate(
struct_id,
index as u32,

View File

@@ -524,14 +524,14 @@ impl<W: Write> Writer<W> {
TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
write!(
self.out,
"mat{}x{}<{}>",
back::vector_size_str(columns),
back::vector_size_str(rows),
scalar_kind_str(crate::Scalar::float(width))
scalar_kind_str(scalar)
)?;
}
TypeInner::Pointer { base, space } => {
@@ -1412,12 +1412,11 @@ impl<W: Write> Writer<W> {
TypeInner::Matrix {
columns,
rows,
width,
..
scalar,
} => {
let scalar = crate::Scalar {
kind,
width: convert.unwrap_or(width),
width: convert.unwrap_or(scalar.width),
};
let scalar_kind_str = scalar_kind_str(scalar);
write!(

View File

@@ -1276,7 +1276,7 @@ fn inject_common_builtin(
vec![TypeInner::Matrix {
columns,
rows,
width: float_width,
scalar: float_scalar,
}],
MacroCall::MathFunction(MathFunction::Transpose),
))
@@ -1295,7 +1295,7 @@ fn inject_common_builtin(
let args = vec![TypeInner::Matrix {
columns,
rows,
width: float_width,
scalar: float_scalar,
}];
declaration.overloads.push(module.add_builtin(

View File

@@ -10,7 +10,7 @@ use super::{
use crate::{
front::Typifier, proc::Emitter, AddressSpace, Arena, BinaryOperator, Block, Expression,
FastHashMap, FunctionArgument, Handle, Literal, LocalVariable, RelationalFunction, Scalar,
ScalarKind, Span, Statement, Type, TypeInner, VectorSize,
Span, Statement, Type, TypeInner, VectorSize,
};
use std::ops::Index;
@@ -619,12 +619,12 @@ impl<'a> Context<'a> {
&TypeInner::Matrix {
columns: left_columns,
rows: left_rows,
width: left_width,
scalar: left_scalar,
},
&TypeInner::Matrix {
columns: right_columns,
rows: right_rows,
width: right_width,
scalar: right_scalar,
},
) => {
let dimensions_ok = if op == BinaryOperator::Multiply {
@@ -634,7 +634,7 @@ impl<'a> Context<'a> {
};
// Check that the two arguments have the same dimensions
if !dimensions_ok || left_width != right_width {
if !dimensions_ok || left_scalar != right_scalar {
frontend.errors.push(Error {
kind: ErrorKind::SemanticError(
format!(
@@ -682,7 +682,7 @@ impl<'a> Context<'a> {
inner: TypeInner::Matrix {
columns: left_columns,
rows: left_rows,
width: left_width,
scalar: left_scalar,
},
},
Span::default(),
@@ -824,17 +824,15 @@ impl<'a> Context<'a> {
_ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
},
(
&TypeInner::Scalar(Scalar {
width: left_width, ..
}),
&TypeInner::Scalar(left_scalar),
&TypeInner::Matrix {
rows,
columns,
width: right_width,
scalar: right_scalar,
},
) => {
// Check that the two arguments have the same width
if left_width != right_width {
// Check that the two arguments have the same scalar type
if left_scalar != right_scalar {
frontend.errors.push(Error {
kind: ErrorKind::SemanticError(
format!(
@@ -891,7 +889,7 @@ impl<'a> Context<'a> {
inner: TypeInner::Matrix {
columns,
rows,
width: left_width,
scalar: left_scalar,
},
},
Span::default(),
@@ -909,14 +907,12 @@ impl<'a> Context<'a> {
&TypeInner::Matrix {
rows,
columns,
width: left_width,
scalar: left_scalar,
},
&TypeInner::Scalar(Scalar {
width: right_width, ..
}),
&TypeInner::Scalar(right_scalar),
) => {
// Check that the two arguments have the same width
if left_width != right_width {
// Check that the two arguments have the same scalar type
if left_scalar != right_scalar {
frontend.errors.push(Error {
kind: ErrorKind::SemanticError(
format!(
@@ -974,7 +970,7 @@ impl<'a> Context<'a> {
inner: TypeInner::Matrix {
columns,
rows,
width: left_width,
scalar: left_scalar,
},
},
Span::default(),
@@ -1216,18 +1212,14 @@ impl<'a> Context<'a> {
TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
let ty = TypeInner::Matrix {
columns,
rows,
width,
scalar,
};
Literal::one(Scalar {
kind: ScalarKind::Float,
width,
})
.map(|i| (ty, i, Some(rows), Some(columns)))
Literal::one(scalar).map(|i| (ty, i, Some(rows), Some(columns)))
}
_ => None,
};

View File

@@ -156,8 +156,8 @@ impl Frontend {
TypeInner::Matrix {
columns,
rows,
width,
} => self.matrix_one_arg(ctx, ty, columns, rows, width, (value, expr_meta), meta)?,
scalar,
} => self.matrix_one_arg(ctx, ty, columns, rows, scalar, (value, expr_meta), meta)?,
TypeInner::Struct { ref members, .. } => {
let scalar_components = members
.get(0)
@@ -207,7 +207,7 @@ impl Frontend {
ty: Handle<Type>,
columns: crate::VectorSize,
rows: crate::VectorSize,
width: crate::Bytes,
element_scalar: Scalar,
(mut value, expr_meta): (Handle<Expression>, Span),
meta: Span,
) -> Result<Handle<Expression>> {
@@ -216,10 +216,6 @@ impl Frontend {
// `Expression::As` doesn't support matrix width
// casts so we need to do some extra work for casts
let element_scalar = Scalar {
kind: ScalarKind::Float,
width,
};
ctx.forced_conversion(&mut value, expr_meta, element_scalar)?;
match *ctx.resolve_type(value, expr_meta)? {
TypeInner::Scalar(_) => {
@@ -422,14 +418,10 @@ impl Frontend {
TypeInner::Matrix {
columns,
rows,
width,
scalar: element_scalar,
} => {
let mut flattened = Vec::with_capacity(columns as usize * rows as usize);
let element_scalar = Scalar {
kind: ScalarKind::Float,
width,
};
for (mut arg, meta) in args.iter().copied() {
ctx.forced_conversion(&mut arg, meta, element_scalar)?;
@@ -1532,16 +1524,14 @@ fn conversion(target: &TypeInner, source: &TypeInner) -> Option<Conversion> {
&TypeInner::Matrix {
rows: tgt_rows,
columns: tgt_cols,
width: tgt_width,
scalar: tgt_scalar,
},
&TypeInner::Matrix {
rows: src_rows,
columns: src_cols,
width: src_width,
scalar: src_scalar,
},
) if tgt_cols == src_cols && tgt_rows == src_rows => {
(Scalar::float(tgt_width), Scalar::float(src_width))
}
) if tgt_cols == src_cols && tgt_rows == src_rows => (tgt_scalar, src_scalar),
_ => return None,
};
@@ -1585,16 +1575,12 @@ fn builtin_required_variations<'a>(args: impl Iterator<Item = &'a TypeInner>) ->
match *ty {
TypeInner::ValuePointer { scalar, .. }
| TypeInner::Scalar(scalar)
| TypeInner::Vector { scalar, .. } => {
| TypeInner::Vector { scalar, .. }
| TypeInner::Matrix { scalar, .. } => {
if scalar == Scalar::F64 {
variations |= BuiltinVariations::DOUBLE
}
}
TypeInner::Matrix { width, .. } => {
if width == 8 {
variations |= BuiltinVariations::DOUBLE
}
}
TypeInner::Image {
dim,
arrayed,

View File

@@ -109,9 +109,9 @@ pub fn calculate_offset(
TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
let mut align = Alignment::from(rows) * Alignment::from_width(width);
let mut align = Alignment::from(rows) * Alignment::from_width(scalar.width);
// See comment at the beginning of the function
if StructLayout::Std430 != layout {

View File

@@ -43,13 +43,10 @@ fn element_or_member_type(
),
// The child type of a matrix is a vector of floats with the same
// width and the size of the matrix rows.
TypeInner::Matrix { rows, width, .. } => types.insert(
TypeInner::Matrix { rows, scalar, .. } => types.insert(
Type {
name: None,
inner: TypeInner::Vector {
size: rows,
scalar: Scalar::float(width),
},
inner: TypeInner::Vector { size: rows, scalar },
},
Default::default(),
),

View File

@@ -72,7 +72,7 @@ pub fn parse_type(type_name: &str) -> Option<Type> {
let kind = iter.next()?;
let size = iter.next()?;
let Scalar { width, .. } = kind_width_parse(kind)?;
let scalar = kind_width_parse(kind)?;
let (columns, rows) = if let Some(size) = size_parse(size) {
(size, size)
@@ -89,7 +89,7 @@ pub fn parse_type(type_name: &str) -> Option<Type> {
inner: TypeInner::Matrix {
columns,
rows,
width,
scalar,
},
})
};
@@ -193,8 +193,8 @@ pub const fn scalar_components(ty: &TypeInner) -> Option<Scalar> {
match *ty {
TypeInner::Scalar(scalar)
| TypeInner::Vector { scalar, .. }
| TypeInner::ValuePointer { scalar, .. } => Some(scalar),
TypeInner::Matrix { width, .. } => Some(Scalar::float(width)),
| TypeInner::ValuePointer { scalar, .. }
| TypeInner::Matrix { scalar, .. } => Some(scalar),
_ => None,
}
}

View File

@@ -2831,8 +2831,8 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let ty_lookup = self.lookup_type.lookup(result_type_id)?;
let scalar = match ctx.type_arena[ty_lookup.handle].inner {
crate::TypeInner::Scalar(scalar)
| crate::TypeInner::Vector { scalar, .. } => scalar,
crate::TypeInner::Matrix { width, .. } => crate::Scalar::float(width),
| crate::TypeInner::Vector { scalar, .. }
| crate::TypeInner::Matrix { scalar, .. } => scalar,
_ => return Err(Error::InvalidAsType(ty_lookup.handle)),
};
@@ -4377,7 +4377,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
crate::TypeInner::Vector { size, scalar } => crate::TypeInner::Matrix {
columns: map_vector_size(num_columns)?,
rows: size,
width: scalar.width,
scalar,
},
_ => return Err(Error::InvalidInnerType(vector_type_id)),
};
@@ -4674,17 +4674,17 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
if let crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
} = *inner
{
if let Some(stride) = decor.matrix_stride {
let expected_stride = Alignment::from(rows) * width as u32;
let expected_stride = Alignment::from(rows) * scalar.width as u32;
if stride.get() != expected_stride {
return Err(Error::UnsupportedMatrixStride {
stride: stride.get(),
columns: columns as u8,
rows: rows as u8,
width,
width: scalar.width,
});
}
}

View File

@@ -156,7 +156,7 @@ impl crate::Module {
inner: crate::TypeInner::Matrix {
columns: crate::VectorSize::Quad,
rows: crate::VectorSize::Tri,
width: 4,
scalar: crate::Scalar::F32,
},
},
Span::UNDEFINED,

View File

@@ -244,13 +244,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
&crate::TypeInner::Matrix {
columns: dst_columns,
rows: dst_rows,
width: dst_width,
scalar: dst_scalar,
},
)),
) if dst_columns == src_columns && dst_rows == src_rows => crate::Expression::As {
expr: component,
kind: crate::ScalarKind::Float,
convert: Some(dst_width),
convert: Some(dst_scalar.width),
},
// Matrix conversion (matrix -> matrix) - partial
@@ -336,7 +336,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
(
Components::Many {
components,
first_component_ty_inner: &crate::TypeInner::Scalar(crate::Scalar { width, .. }),
first_component_ty_inner: &crate::TypeInner::Scalar(scalar),
..
},
Constructor::PartialMatrix { columns, rows },
@@ -352,14 +352,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
&crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
},
)),
) => {
let vec_ty = ctx.ensure_type_exists(crate::TypeInner::Vector {
scalar: crate::Scalar::float(width),
size: rows,
});
let vec_ty =
ctx.ensure_type_exists(crate::TypeInner::Vector { scalar, size: rows });
let components = components
.chunks(rows as usize)
@@ -377,7 +375,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
});
crate::Expression::Compose { ty, components }
}
@@ -386,11 +384,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
(
Components::Many {
components,
first_component_ty_inner:
&crate::TypeInner::Vector {
scalar: crate::Scalar { width, .. },
..
},
first_component_ty_inner: &crate::TypeInner::Vector { scalar, .. },
..
},
Constructor::PartialMatrix { columns, rows },
@@ -406,14 +400,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
&crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
},
)),
) => {
let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
});
crate::Expression::Compose { ty, components }
}
@@ -531,7 +525,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
columns,
rows,
width,
scalar: crate::Scalar::float(width),
});
Constructor::Type(ty)
}

View File

@@ -2529,7 +2529,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
} => crate::TypeInner::Matrix {
columns,
rows,
width,
scalar: crate::Scalar::float(width),
},
ast::Type::Atomic(scalar) => scalar.to_inner_atomic(),
ast::Type::Pointer { base, space } => {

View File

@@ -44,13 +44,13 @@ impl crate::TypeInner {
Ti::Matrix {
columns,
rows,
width,
scalar,
} => {
format!(
"mat{}x{}<{}>",
columns as u32,
rows as u32,
crate::Scalar::float(width).to_wgsl(),
scalar.to_wgsl(),
)
}
Ti::Atomic(scalar) => {
@@ -236,7 +236,7 @@ mod tests {
let mat = crate::TypeInner::Matrix {
rows: crate::VectorSize::Quad,
columns: crate::VectorSize::Bi,
width: 8,
scalar: crate::Scalar::F64,
};
assert_eq!(mat.to_wgsl(&gctx), "mat2x4<f64>");

View File

@@ -693,11 +693,11 @@ pub enum TypeInner {
Scalar(Scalar),
/// Vector of numbers.
Vector { size: VectorSize, scalar: Scalar },
/// Matrix of floats.
/// Matrix of numbers.
Matrix {
columns: VectorSize,
rows: VectorSize,
width: Bytes,
scalar: Scalar,
},
/// Atomic scalar.
Atomic(Scalar),

View File

@@ -914,15 +914,12 @@ impl<'a> ConstantEvaluator<'a> {
TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
let vec_ty = self.types.insert(
Type {
name: None,
inner: TypeInner::Vector {
size: rows,
scalar: crate::Scalar::float(width),
},
inner: TypeInner::Vector { size: rows, scalar },
},
span,
);
@@ -1026,7 +1023,7 @@ impl<'a> ConstantEvaluator<'a> {
TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
columns,
rows,
width: target.width,
scalar: target,
},
_ => return Err(ConstantEvaluatorError::InvalidCastArg),
};
@@ -1522,7 +1519,7 @@ mod tests {
inner: TypeInner::Matrix {
columns: VectorSize::Bi,
rows: VectorSize::Tri,
width: 4,
scalar: crate::Scalar::F32,
},
},
Default::default(),

View File

@@ -190,9 +190,9 @@ impl Layouter {
Ti::Matrix {
columns: _,
rows,
width,
scalar,
} => {
let alignment = Alignment::new(width as u32)
let alignment = Alignment::new(scalar.width as u32)
.ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
TypeLayout {
size,

View File

@@ -226,7 +226,7 @@ impl super::TypeInner {
use crate::TypeInner as Ti;
match *self {
Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => Some(scalar),
Ti::Matrix { width, .. } => Some(super::Scalar::float(width)),
Ti::Matrix { scalar, .. } => Some(scalar),
_ => None,
}
}
@@ -266,8 +266,8 @@ impl super::TypeInner {
Self::Matrix {
columns,
rows,
width,
} => Alignment::from(rows) * width as u32 * columns as u32,
scalar,
} => Alignment::from(rows) * scalar.width as u32 * columns as u32,
Self::Pointer { .. } | Self::ValuePointer { .. } => POINTER_SPAN,
Self::Array {
base: _,
@@ -367,10 +367,9 @@ impl super::TypeInner {
pub fn component_type(&self, index: usize) -> Option<TypeResolution> {
Some(match *self {
Self::Vector { scalar, .. } => TypeResolution::Value(crate::TypeInner::Scalar(scalar)),
Self::Matrix { rows, width, .. } => TypeResolution::Value(crate::TypeInner::Vector {
size: rows,
scalar: crate::Scalar::float(width),
}),
Self::Matrix { rows, scalar, .. } => {
TypeResolution::Value(crate::TypeInner::Vector { size: rows, scalar })
}
Self::Array {
base,
size: crate::ArraySize::Constant(_),
@@ -773,7 +772,7 @@ fn test_matrix_size() {
crate::TypeInner::Matrix {
columns: crate::VectorSize::Tri,
rows: crate::VectorSize::Tri,
width: 4
scalar: crate::Scalar::F32,
}
.size(module.to_ctx()),
48,

View File

@@ -124,11 +124,11 @@ impl Clone for TypeResolution {
Ti::Matrix {
rows,
columns,
width,
scalar,
} => Ti::Matrix {
rows,
columns,
width,
scalar,
},
Ti::Pointer { base, space } => Ti::Pointer { base, space },
Ti::ValuePointer {
@@ -239,10 +239,9 @@ impl<'a> ResolveContext<'a> {
// pointer, but that's a validation error, not a type error, so
// go ahead provide a type here.
Ti::Array { base, .. } => TypeResolution::Handle(base),
Ti::Matrix { rows, width, .. } => TypeResolution::Value(Ti::Vector {
size: rows,
scalar: crate::Scalar::float(width),
}),
Ti::Matrix { rows, scalar, .. } => {
TypeResolution::Value(Ti::Vector { size: rows, scalar })
}
Ti::Vector { size: _, scalar } => TypeResolution::Value(Ti::Scalar(scalar)),
Ti::ValuePointer {
size: Some(_),
@@ -265,10 +264,10 @@ impl<'a> ResolveContext<'a> {
Ti::Matrix {
columns: _,
rows,
width,
scalar,
} => Ti::ValuePointer {
size: Some(rows),
scalar: crate::Scalar::float(width),
scalar,
space,
},
Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
@@ -301,15 +300,12 @@ impl<'a> ResolveContext<'a> {
Ti::Matrix {
columns,
rows,
width,
scalar,
} => {
if index >= columns as u32 {
return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
}
TypeResolution::Value(crate::TypeInner::Vector {
size: rows,
scalar: crate::Scalar::float(width),
})
TypeResolution::Value(crate::TypeInner::Vector { size: rows, scalar })
}
Ti::Array { base, .. } => TypeResolution::Handle(base),
Ti::Struct { ref members, .. } => {
@@ -350,14 +346,14 @@ impl<'a> ResolveContext<'a> {
Ti::Matrix {
rows,
columns,
width,
scalar,
} => {
if index >= columns as u32 {
return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
}
Ti::ValuePointer {
size: Some(rows),
scalar: crate::Scalar::float(width),
scalar,
space,
}
}
@@ -535,35 +531,32 @@ impl<'a> ResolveContext<'a> {
&Ti::Matrix {
columns: _,
rows,
width,
scalar,
},
&Ti::Matrix { columns, .. },
) => TypeResolution::Value(Ti::Matrix {
columns,
rows,
width,
scalar,
}),
(
&Ti::Matrix {
columns: _,
rows,
width,
scalar,
},
&Ti::Vector { .. },
) => TypeResolution::Value(Ti::Vector {
size: rows,
scalar: crate::Scalar::float(width),
}),
) => TypeResolution::Value(Ti::Vector { size: rows, scalar }),
(
&Ti::Vector { .. },
&Ti::Matrix {
columns,
rows: _,
width,
scalar,
},
) => TypeResolution::Value(Ti::Vector {
size: columns,
scalar: crate::Scalar::float(width),
scalar,
}),
(&Ti::Scalar { .. }, _) => res_right.clone(),
(_, &Ti::Scalar { .. }) => res_left.clone(),
@@ -718,7 +711,7 @@ impl<'a> ResolveContext<'a> {
) => TypeResolution::Value(Ti::Matrix {
columns,
rows,
width: scalar.width
scalar,
}),
(left, right) =>
return Err(ResolveError::IncompatibleOperands(
@@ -751,11 +744,11 @@ impl<'a> ResolveContext<'a> {
Ti::Matrix {
columns,
rows,
width,
scalar,
} => TypeResolution::Value(Ti::Matrix {
columns: rows,
rows: columns,
width,
scalar,
}),
ref other => return Err(ResolveError::IncompatibleOperands(
format!("{fun:?}({other:?})")
@@ -765,11 +758,11 @@ impl<'a> ResolveContext<'a> {
Ti::Matrix {
columns,
rows,
width,
scalar,
} if columns == rows => TypeResolution::Value(Ti::Matrix {
columns,
rows,
width,
scalar,
}),
ref other => return Err(ResolveError::IncompatibleOperands(
format!("{fun:?}({other:?})")
@@ -777,9 +770,9 @@ impl<'a> ResolveContext<'a> {
},
Mf::Determinant => match *res_arg.inner_with(types) {
Ti::Matrix {
width,
scalar,
..
} => TypeResolution::Value(Ti::Scalar(crate::Scalar::float(width))),
} => TypeResolution::Value(Ti::Scalar(scalar)),
ref other => return Err(ResolveError::IncompatibleOperands(
format!("{fun:?}({other:?})")
)),
@@ -852,12 +845,17 @@ impl<'a> ResolveContext<'a> {
Ti::Matrix {
columns,
rows,
width,
} => TypeResolution::Value(Ti::Matrix {
columns,
rows,
width: convert.unwrap_or(width),
}),
mut scalar,
} => {
if let Some(width) = convert {
scalar.width = width;
}
TypeResolution::Value(Ti::Matrix {
columns,
rows,
scalar,
})
}
ref other => {
return Err(ResolveError::IncompatibleOperands(format!(
"{other:?} as {kind:?}"

View File

@@ -50,12 +50,9 @@ pub fn validate_compose(
Ti::Matrix {
columns,
rows,
width,
scalar,
} => {
let inner = Ti::Vector {
size: rows,
scalar: crate::Scalar::float(width),
};
let inner = Ti::Vector { size: rows, scalar };
if columns as usize != component_resolutions.len() {
return Err(ComposeError::ComponentCount {
expected: columns as u32,

View File

@@ -1501,7 +1501,7 @@ impl super::Validator {
crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => {
scalar
}
crate::TypeInner::Matrix { width, .. } => crate::Scalar::float(width),
crate::TypeInner::Matrix { scalar, .. } => scalar,
_ => return Err(ExpressionError::InvalidCastArgument),
};
base_scalar.kind = kind;

View File

@@ -103,6 +103,8 @@ pub enum TypeError {
InvalidData(Handle<crate::Type>),
#[error("Base type {0:?} for the array is invalid")]
InvalidArrayBaseType(Handle<crate::Type>),
#[error("Matrix elements must always be floating-point types")]
MatrixElementNotFloat,
#[error("The constant {0:?} is specialized, and cannot be used as an array size")]
UnsupportedSpecializedArrayLength(Handle<crate::Constant>),
#[error("Array stride {stride} does not match the expected {expected}")]
@@ -305,9 +307,12 @@ impl super::Validator {
Ti::Matrix {
columns: _,
rows,
width,
scalar,
} => {
self.check_width(crate::Scalar::float(width))?;
if scalar.kind != crate::ScalarKind::Float {
return Err(TypeError::MatrixElementNotFloat);
}
self.check_width(scalar)?;
TypeInfo::new(
TypeFlags::DATA
| TypeFlags::SIZED
@@ -315,7 +320,7 @@ impl super::Validator {
| TypeFlags::HOST_SHAREABLE
| TypeFlags::ARGUMENT
| TypeFlags::CONSTRUCTIBLE,
Alignment::from(rows) * Alignment::from_width(width),
Alignment::from(rows) * Alignment::from_width(scalar.width),
)
}
Ti::Atomic(crate::Scalar { kind, width }) => {

View File

@@ -69,7 +69,10 @@
inner: Matrix(
columns: Quad,
rows: Tri,
width: 4,
scalar: (
kind: Float,
width: 4,
),
),
),
(
@@ -77,7 +80,10 @@
inner: Matrix(
columns: Bi,
rows: Bi,
width: 4,
scalar: (
kind: Float,
width: 4,
),
),
),
(
@@ -178,7 +184,10 @@
inner: Matrix(
columns: Tri,
rows: Bi,
width: 4,
scalar: (
kind: Float,
width: 4,
),
),
),
(
@@ -210,7 +219,10 @@
inner: Matrix(
columns: Quad,
rows: Bi,
width: 4,
scalar: (
kind: Float,
width: 4,
),
),
),
(

View File

@@ -69,7 +69,10 @@
inner: Matrix(
columns: Quad,
rows: Tri,
width: 4,
scalar: (
kind: Float,
width: 4,
),
),
),
(
@@ -77,7 +80,10 @@
inner: Matrix(
columns: Bi,
rows: Bi,
width: 4,
scalar: (
kind: Float,
width: 4,
),
),
),
(
@@ -178,7 +184,10 @@
inner: Matrix(
columns: Tri,
rows: Bi,
width: 4,
scalar: (
kind: Float,
width: 4,
),
),
),
(
@@ -220,7 +229,10 @@
inner: Matrix(
columns: Quad,
rows: Bi,
width: 4,
scalar: (
kind: Float,
width: 4,
),
),
),
(

View File

@@ -90,7 +90,10 @@
inner: Matrix(
columns: Quad,
rows: Quad,
width: 4,
scalar: (
kind: Float,
width: 4,
),
),
),
(

View File

@@ -138,7 +138,10 @@
inner: Matrix(
columns: Quad,
rows: Quad,
width: 4,
scalar: (
kind: Float,
width: 4,
),
),
),
(

View File

@@ -792,10 +792,10 @@ impl Interface {
naga::TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => NumericType {
dim: NumericDimension::Matrix(columns, rows),
scalar: naga::Scalar::float(width),
scalar,
},
naga::TypeInner::Struct { ref members, .. } => {
for member in members {

View File

@@ -1563,7 +1563,7 @@ impl super::Queue {
naga::TypeInner::Matrix {
columns: naga::VectorSize::Bi,
rows: naga::VectorSize::Bi,
width: 4,
scalar: naga::Scalar::F32,
} => {
let data = unsafe { get_data::<f32, 4>(data_bytes, offset) };
unsafe { gl.uniform_matrix_2_f32_slice(location, false, data) };
@@ -1571,7 +1571,7 @@ impl super::Queue {
naga::TypeInner::Matrix {
columns: naga::VectorSize::Bi,
rows: naga::VectorSize::Tri,
width: 4,
scalar: naga::Scalar::F32,
} => {
// repack 2 vec3s into 6 values.
let unpacked_data = unsafe { get_data::<f32, 8>(data_bytes, offset) };
@@ -1585,7 +1585,7 @@ impl super::Queue {
naga::TypeInner::Matrix {
columns: naga::VectorSize::Bi,
rows: naga::VectorSize::Quad,
width: 4,
scalar: naga::Scalar::F32,
} => {
let data = unsafe { get_data::<f32, 8>(data_bytes, offset) };
unsafe { gl.uniform_matrix_2x4_f32_slice(location, false, data) };
@@ -1597,7 +1597,7 @@ impl super::Queue {
naga::TypeInner::Matrix {
columns: naga::VectorSize::Tri,
rows: naga::VectorSize::Bi,
width: 4,
scalar: naga::Scalar::F32,
} => {
let data = unsafe { get_data::<f32, 6>(data_bytes, offset) };
unsafe { gl.uniform_matrix_3x2_f32_slice(location, false, data) };
@@ -1605,7 +1605,7 @@ impl super::Queue {
naga::TypeInner::Matrix {
columns: naga::VectorSize::Tri,
rows: naga::VectorSize::Tri,
width: 4,
scalar: naga::Scalar::F32,
} => {
// repack 3 vec3s into 9 values.
let unpacked_data = unsafe { get_data::<f32, 12>(data_bytes, offset) };
@@ -1620,7 +1620,7 @@ impl super::Queue {
naga::TypeInner::Matrix {
columns: naga::VectorSize::Tri,
rows: naga::VectorSize::Quad,
width: 4,
scalar: naga::Scalar::F32,
} => {
let data = unsafe { get_data::<f32, 12>(data_bytes, offset) };
unsafe { gl.uniform_matrix_3x4_f32_slice(location, false, data) };
@@ -1632,7 +1632,7 @@ impl super::Queue {
naga::TypeInner::Matrix {
columns: naga::VectorSize::Quad,
rows: naga::VectorSize::Bi,
width: 4,
scalar: naga::Scalar::F32,
} => {
let data = unsafe { get_data::<f32, 8>(data_bytes, offset) };
unsafe { gl.uniform_matrix_4x2_f32_slice(location, false, data) };
@@ -1640,7 +1640,7 @@ impl super::Queue {
naga::TypeInner::Matrix {
columns: naga::VectorSize::Quad,
rows: naga::VectorSize::Tri,
width: 4,
scalar: naga::Scalar::F32,
} => {
// repack 4 vec3s into 12 values.
let unpacked_data = unsafe { get_data::<f32, 16>(data_bytes, offset) };
@@ -1656,7 +1656,7 @@ impl super::Queue {
naga::TypeInner::Matrix {
columns: naga::VectorSize::Quad,
rows: naga::VectorSize::Quad,
width: 4,
scalar: naga::Scalar::F32,
} => {
let data = unsafe { get_data::<f32, 16>(data_bytes, offset) };
unsafe { gl.uniform_matrix_4_f32_slice(location, false, data) };