From b08dfe51469447b6daf86b01ebd56465c8efe9ca Mon Sep 17 00:00:00 2001 From: Ashley Date: Tue, 20 Apr 2021 17:41:13 +0200 Subject: [PATCH] [Metal] Impl `Expression::Splat` (#738) * [Metal] Impl `Expression::Splat` * Add changes to the snapshots * Apply suggestions --- src/back/msl/writer.rs | 13 +++++++++++-- tests/out/access.msl | 2 +- tests/out/boids.msl | 4 ++-- tests/out/operators.msl | 4 ++-- tests/out/shadow.msl | 2 +- 5 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 79a8114824..8b2f6901e3 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -608,8 +608,17 @@ impl Writer { }; write!(self.out, "{}", coco)?; } - crate::Expression::Splat { size: _, value } => { - self.put_expression(value, context, is_scoped)?; + crate::Expression::Splat { size, value } => { + let scalar_kind = match *context.resolve_type(value) { + crate::TypeInner::Scalar { kind, .. } => kind, + _ => return Err(Error::Validation) + }; + let scalar = scalar_kind_string(scalar_kind); + let size = vector_size_string(size); + + write!(self.out, "{}{}(", scalar, size)?; + self.put_expression(value, context, true)?; + write!(self.out, ")")?; } crate::Expression::Compose { ty, ref components } => { let inner = &context.module.types[ty].inner; diff --git a/tests/out/access.msl b/tests/out/access.msl index 8c9d1c2260..2c3da3e904 100644 --- a/tests/out/access.msl +++ b/tests/out/access.msl @@ -11,5 +11,5 @@ struct fooOutput { vertex fooOutput foo( metal::uint vi [[vertex_id]] ) { - return fooOutput { static_cast(type3 {1, 2, 3, 4, 5}[vi]) }; + return fooOutput { static_cast(int4(type3 {1, 2, 3, 4, 5}[vi])) }; } diff --git a/tests/out/boids.msl b/tests/out/boids.msl index f8aa787805..3cabb8e664 100644 --- a/tests/out/boids.msl +++ b/tests/out/boids.msl @@ -73,10 +73,10 @@ kernel void main1( } } if (cMassCount > 0) { - cMass = (cMass / static_cast(cMassCount)) - vPos; + cMass = (cMass / float2(static_cast(cMassCount))) - vPos; } if (cVelCount > 0) { - cVel = cVel / static_cast(cVelCount); + cVel = cVel / float2(static_cast(cVelCount)); } vVel = ((vVel + (cMass * params.rule1Scale)) + (colVel * params.rule2Scale)) + (cVel * params.rule3Scale); vVel = metal::normalize(vVel) * metal::clamp(metal::length(vVel), 0.0, 0.1); diff --git a/tests/out/operators.msl b/tests/out/operators.msl index 2b21eab624..79d2ced847 100644 --- a/tests/out/operators.msl +++ b/tests/out/operators.msl @@ -7,6 +7,6 @@ struct splatOutput { }; vertex splatOutput splat( ) { - metal::float2 _e10 = ((1.0 + 2.0) - 3.0) / 4.0; - return splatOutput { metal::float4(_e10.x, _e10.y, _e10.x, _e10.y) + static_cast(5 % 2) }; + metal::float2 _e10 = ((float2(1.0) + float2(2.0)) - float2(3.0)) / float2(4.0); + return splatOutput { metal::float4(_e10.x, _e10.y, _e10.x, _e10.y) + static_cast(int4(5) % int4(2)) }; } diff --git a/tests/out/shadow.msl b/tests/out/shadow.msl index 772a3cf2e9..12a5c43bba 100644 --- a/tests/out/shadow.msl +++ b/tests/out/shadow.msl @@ -25,7 +25,7 @@ float fetch_shadow( if (homogeneous_coords.w <= 0.0) { return 1.0; } - float _e28 = t_shadow.sample_compare(sampler_shadow, ((metal::float2(homogeneous_coords.x, homogeneous_coords.y) * metal::float2(0.5, -0.5)) / homogeneous_coords.w) + metal::float2(0.5, 0.5), static_cast(light_id), homogeneous_coords.z / homogeneous_coords.w); + float _e28 = t_shadow.sample_compare(sampler_shadow, ((metal::float2(homogeneous_coords.x, homogeneous_coords.y) * metal::float2(0.5, -0.5)) / float2(homogeneous_coords.w)) + metal::float2(0.5, 0.5), static_cast(light_id), homogeneous_coords.z / homogeneous_coords.w); return _e28; }