#include <metal_stdlib>
#include <simd/simd.h>

// Number of particles simulated. main1 uses it both to reject surplus threads
// and as the upper bound of the neighbour loop.
constexpr constant unsigned NUM_PARTICLES = 1500u;

// One simulated particle: 2D position and 2D velocity.
struct Particle {
    metal::float2 pos;
    metal::float2 vel;
};

// Tuning parameters read by main1.
struct SimParams {
    float deltaT;         // factor applied to the velocity when advancing the position
    float rule1Distance;  // neighbour radius for rule 1 (steer toward mean neighbour position)
    float rule2Distance;  // neighbour radius for rule 2 (steer away from close neighbours)
    float rule3Distance;  // neighbour radius for rule 3 (steer toward mean neighbour velocity)
    float rule1Scale;     // weight of the rule 1 contribution
    float rule2Scale;     // weight of the rule 2 contribution
    float rule3Scale;     // weight of the rule 3 contribution
};

// NOTE(review): the array is declared with a single element, yet main1 indexes
// it with values up to NUM_PARTICLES - 1. Presumably this stands in for a
// runtime-sized array and the bound buffers hold at least NUM_PARTICLES
// entries -- confirm against the host-side buffer allocation.
typedef Particle type3[1];
struct Particles {
    type3 particles;
};

// Empty and not referenced by main1; kept so the file's declarations are unchanged.
struct main1Input {
};

// Advances one particle by one simulation step.
//
// Each thread handles the particle whose index equals the x component of its
// grid position. It scans every other particle in particlesSrc, accumulates
// the three steering contributions described on SimParams, updates the
// velocity (speed capped at 0.1), integrates the position by deltaT, wraps the
// position into [-1, 1] on both axes, and writes the result to particlesDst.
//
// Parameters:
//   global_invocation_id  thread position in the grid; only .x is used.
//   params                simulation tuning parameters (buffer 0).
//   particlesSrc          particle state that is read, never written (buffer 1).
//   particlesDst          particle state that is written, never read (buffer 2).
kernel void main1(
  metal::uint3 global_invocation_id [[thread_position_in_grid]]
, constant SimParams& params [[buffer(0)]]
, constant Particles& particlesSrc [[buffer(1)]]
, device Particles& particlesDst [[buffer(2)]]
) {
    const metal::uint index = global_invocation_id.x;

    // Threads beyond the particle count have nothing to do.
    if (index >= NUM_PARTICLES) {
        return;
    }

    metal::float2 vPos = particlesSrc.particles[index].pos;
    metal::float2 vVel = particlesSrc.particles[index].vel;

    metal::float2 cMass = metal::float2(0.0, 0.0);   // sum of neighbour positions (rule 1)
    metal::float2 cVel = metal::float2(0.0, 0.0);    // sum of neighbour velocities (rule 3)
    metal::float2 colVel = metal::float2(0.0, 0.0);  // accumulated push away from neighbours (rule 2)
    int cMassCount = 0;
    int cVelCount = 0;

    for (metal::uint i = 0u; i < NUM_PARTICLES; i = i + 1u) {
        // A particle is not its own neighbour.
        if (i == index) {
            continue;
        }

        const metal::float2 pos1 = particlesSrc.particles[i].pos;
        const metal::float2 vel1 = particlesSrc.particles[i].vel;

        // vPos is not modified inside the loop, so one distance serves all
        // three rule checks.
        const float dist = metal::distance(pos1, vPos);

        if (dist < params.rule1Distance) {
            cMass = cMass + pos1;
            cMassCount = cMassCount + 1;
        }
        if (dist < params.rule2Distance) {
            colVel = colVel - (pos1 - vPos);
        }
        if (dist < params.rule3Distance) {
            cVel = cVel + vel1;
            cVelCount = cVelCount + 1;
        }
    }

    // Turn the position sum into an offset from this particle to the mean
    // neighbour position. Skipped when there were no neighbours, which also
    // avoids a division by zero.
    if (cMassCount > 0) {
        cMass = (cMass / float2(static_cast<float>(cMassCount))) - vPos;
    }
    // Turn the velocity sum into the mean neighbour velocity.
    if (cVelCount > 0) {
        cVel = cVel / float2(static_cast<float>(cVelCount));
    }

    vVel = ((vVel + (cMass * params.rule1Scale)) + (colVel * params.rule2Scale)) + (cVel * params.rule3Scale);

    // Cap the speed at 0.1 while keeping the direction. Normalising a
    // zero-length vector divides by zero and the resulting NaN would be stored
    // and fed back into every later step, so a zero velocity is left untouched.
    const float speed = metal::length(vVel);
    if (speed > 0.0) {
        vVel = metal::normalize(vVel) * metal::clamp(speed, 0.0, 0.1);
    }

    vPos = vPos + (vVel * params.deltaT);

    // A particle that leaves [-1, 1] on an axis re-enters at the opposite edge.
    if (vPos.x < -1.0) {
        vPos.x = 1.0;
    }
    if (vPos.x > 1.0) {
        vPos.x = -1.0;
    }
    if (vPos.y < -1.0) {
        vPos.y = 1.0;
    }
    if (vPos.y > 1.0) {
        vPos.y = -1.0;
    }

    particlesDst.particles[index].pos = vPos;
    particlesDst.particles[index].vel = vVel;
}