From 0061dc7447d953d27133ad5b304cf1734cb3265e Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 7 Jan 2025 00:37:59 -0500 Subject: [PATCH] fix benchmark allreduce and add to ci [pr] (#8521) --- .github/workflows/benchmark.yml | 2 ++ test/external/external_benchmark_multitensor_allreduce.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 4567095452..941e97f0ee 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -168,6 +168,8 @@ jobs: run: NV=1 RUN_PROCESS_REPLAY=0 HALF=1 BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt - name: Test speed vs theoretical run: NV=1 IGNORE_BEAM_CACHE=1 BEAM_DEBUG=1 DEBUG=1 python -m pytest -rA test/external/speed_v_theoretical.py --durations=20 + - name: Test benchmark allreduce + run: NV=1 python test/external/external_benchmark_multitensor_allreduce.py - name: Test tensor cores run: | NV=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded diff --git a/test/external/external_benchmark_multitensor_allreduce.py b/test/external/external_benchmark_multitensor_allreduce.py index e3ed9c5aab..c22bb655f8 100644 --- a/test/external/external_benchmark_multitensor_allreduce.py +++ b/test/external/external_benchmark_multitensor_allreduce.py @@ -2,14 +2,14 @@ import time from tinygrad import Tensor, Device, GlobalCounters, TinyJit from tinygrad.ops import Ops, UOp from tinygrad.multi import MultiLazyBuffer, all_reduce -from tinygrad.engine.schedule import create_schedule +from tinygrad.engine.schedule import create_schedule_with_vars from tinygrad.engine.realize import run_schedule from tinygrad.helpers import getenv, Context, RING, DEBUG from typing import List, Union def realize(x: Union[UOp, List[UOp]]): x = x if isinstance(x, list) else [x] - run_schedule(create_schedule(x)) + run_schedule(*create_schedule_with_vars(x)) for lb in x: Device[lb.device].synchronize() def test(devs: List[str], N: int, iters:int = 10):