diff --git a/src/front/glsl/functions.rs b/src/front/glsl/functions.rs index e9cb3dd582..b6c7189bee 100644 --- a/src/front/glsl/functions.rs +++ b/src/front/glsl/functions.rs @@ -5,6 +5,7 @@ use crate::{ ImageQuery, LocalVariable, MathFunction, Module, RelationalFunction, SampleLevel, ScalarKind, ScalarValue, Statement, StructMember, Type, TypeInner, VectorSize, }; +use std::iter; /// Helper struct for texture calls with the separate components from the vector argument /// @@ -130,9 +131,52 @@ impl Program { ScalarKind::Float, width, )?; - let column = match *self.resolve_type(ctx, args[0].0, args[0].1)? { - TypeInner::Scalar { .. } => ctx - .add_expression(Expression::Splat { size: rows, value }, body), + match *self.resolve_type(ctx, value, meta)? { + TypeInner::Scalar { .. } => { + // If a matrix is constructed with a single scalar value, then that + // value is used to initialize all the values along the diagonal of + // the matrix; the rest are given zeros. + let mut components = Vec::with_capacity(columns as usize); + let vector_ty = self.module.types.fetch_or_append(Type { + name: None, + inner: TypeInner::Vector { + size: rows, + kind: ScalarKind::Float, + width, + }, + }); + let zero_constant = + self.module.constants.fetch_or_append(Constant { + name: None, + specialization: None, + inner: ConstantInner::Scalar { + width, + value: ScalarValue::Float(0.0), + }, + }); + let zero = ctx + .add_expression(Expression::Constant(zero_constant), body); + + for i in 0..columns as u32 { + components.push( + ctx.add_expression( + Expression::Compose { + ty: vector_ty, + components: (0..rows as u32) + .into_iter() + .map(|r| match r == i { + true => value, + false => zero, + }) + .collect(), + }, + body, + ), + ) + } + + ctx.add_expression(Expression::Compose { ty, components }, body) + } TypeInner::Matrix { rows: ori_rows, .. } => { let mut components = Vec::new(); @@ -152,26 +196,21 @@ impl Program { components.push(vector) } - let h = ctx.add_expression( - Expression::Compose { ty, components }, - body, - ); - - return Ok(Some(h)); + ctx.add_expression(Expression::Compose { ty, components }, body) } - _ => value, - }; + _ => { + let columns = + iter::repeat(value).take(columns as usize).collect(); - let columns = - std::iter::repeat(column).take(columns as usize).collect(); - - ctx.add_expression( - Expression::Compose { - ty, - components: columns, - }, - body, - ) + ctx.add_expression( + Expression::Compose { + ty, + components: columns, + }, + body, + ) + } + } } TypeInner::Struct { .. } => ctx.add_expression( Expression::Compose { diff --git a/tests/out/wgsl/280-matrix-cast-vert.wgsl b/tests/out/wgsl/280-matrix-cast-vert.wgsl index 3416dfc757..c5676cd288 100644 --- a/tests/out/wgsl/280-matrix-cast-vert.wgsl +++ b/tests/out/wgsl/280-matrix-cast-vert.wgsl @@ -1,9 +1,7 @@ fn main1() { - var a: mat4x4; + var a: mat4x4 = mat4x4(vec4(1.0, 0.0, 0.0, 0.0), vec4(0.0, 1.0, 0.0, 0.0), vec4(0.0, 0.0, 1.0, 0.0), vec4(0.0, 0.0, 0.0, 1.0)); - let _e2: vec4 = vec4(f32(1)); - a = mat4x4(_e2, _e2, _e2, _e2); - return; + let _e1: f32 = f32(1); } [[stage(vertex)]] diff --git a/tests/out/wgsl/long-form-matrix-vert.wgsl b/tests/out/wgsl/long-form-matrix-vert.wgsl index d2c3aa18da..dee8738ffe 100644 --- a/tests/out/wgsl/long-form-matrix-vert.wgsl +++ b/tests/out/wgsl/long-form-matrix-vert.wgsl @@ -1,5 +1,5 @@ fn main1() { - var splat: mat2x2; + var splat: mat2x2 = mat2x2(vec2(1.0, 0.0), vec2(0.0, 1.0)); var normal: mat2x2 = mat2x2(vec2(1.0, 1.0), vec2(2.0, 2.0)); var a: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); var b: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); @@ -7,21 +7,20 @@ fn main1() { var d: mat3x3 = mat3x3(vec3(2.0, 2.0, 1.0), vec3(1.0, 1.0, 1.0), vec3(1.0, 1.0, 1.0)); var e: mat4x4 = mat4x4(vec4(2.0, 2.0, 1.0, 1.0), vec4(1.0, 1.0, 2.0, 2.0), vec4(1.0, 1.0, 1.0, 1.0), vec4(1.0, 1.0, 1.0, 1.0)); - let _e2: vec2 = vec2(f32(1)); - splat = mat2x2(_e2, _e2); - let _e7: vec2 = vec2(f32(1)); - let _e10: vec2 = vec2(f32(2)); - let _e36: vec2 = vec2(f32(2), f32(3)); - let _e51: vec3 = vec3(f32(1)); - let _e54: vec3 = vec3(f32(1)); - let _e71: vec2 = vec2(f32(2)); - let _e75: vec3 = vec3(f32(1)); - let _e78: vec3 = vec3(f32(1)); - let _e95: vec2 = vec2(f32(2)); - let _e98: vec4 = vec4(f32(1)); - let _e101: vec2 = vec2(f32(2)); - let _e104: vec4 = vec4(f32(1)); - let _e107: vec4 = vec4(f32(1)); + let _e1: f32 = f32(1); + let _e9: vec2 = vec2(f32(1)); + let _e12: vec2 = vec2(f32(2)); + let _e38: vec2 = vec2(f32(2), f32(3)); + let _e53: vec3 = vec3(f32(1)); + let _e56: vec3 = vec3(f32(1)); + let _e73: vec2 = vec2(f32(2)); + let _e77: vec3 = vec3(f32(1)); + let _e80: vec3 = vec3(f32(1)); + let _e97: vec2 = vec2(f32(2)); + let _e100: vec4 = vec4(f32(1)); + let _e103: vec2 = vec2(f32(2)); + let _e106: vec4 = vec4(f32(1)); + let _e109: vec4 = vec4(f32(1)); } [[stage(vertex)]]