From 70671d9625cf358ae170d1c4cf5e4f8fd1f9ccda Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Thu, 28 Sep 2023 14:02:22 -0400 Subject: [PATCH] fix test_collectives (#1934) * fix: fix test_collectives.py * feat: reenable test_collectives --- .github/workflows/test.yml | 2 +- test/external/dist/test_collectives.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8c21ed86cd..f902b5d578 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -162,7 +162,7 @@ jobs: name: Test multigpu run: | PYTHONPATH="." python test/external/dist/test_world.py - #PYTHONPATH="." python test/external/dist/test_collectives.py + PYTHONPATH="." python test/external/dist/test_collectives.py - if: ${{ matrix.task == 'realworld' }} name: Test KOPT run: PYTHONPATH="." KOPT=1 BUDGET=20 GPU=1 DEBUG=1 python -m pytest -rA -n=auto test/models/test_real_world.py diff --git a/test/external/dist/test_collectives.py b/test/external/dist/test_collectives.py index 2378a988f1..f7c19d19cb 100644 --- a/test/external/dist/test_collectives.py +++ b/test/external/dist/test_collectives.py @@ -30,6 +30,7 @@ def run(): # reset jit allreduce_jit.cnt = 0 + allreduce_jit.input_replace = {} # test uneven chunk sizes for _ in range(3):