diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 067246897f..5142dcda96 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -594,7 +594,9 @@ jobs: - name: some unit tests run: METAL=1 RANGEIFY=1 python -m pytest -n=auto test/unit/test_winograd.py test/unit/test_linalg.py --durations=20 - name: Test METAL=1 RANGEIFY=1 - run: METAL=1 RANGEIFY=1 python -m pytest -n=auto test/test_ops.py test/test_multitensor.py --durations=20 + run: | + METAL=1 RANGEIFY=1 python -m pytest -n=auto test/test_ops.py test/test_multitensor.py --durations=20 + METAL=1 MAX_KERNEL_BUFFERS=6 RANGEIFY=1 PYTHONPATH=. python test/test_multitensor.py TestBatchNorm.test_batchnorm - name: Run process replay tests uses: ./.github/actions/process-replay diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index e9dfd47323..2d401f4125 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -506,7 +506,7 @@ def limit_bufs(ctx:RangeifyContext, root:UOp): if s.op in GroupOp.Elementwise: # Insert bufferize: all AxisType.REDUCE before bufferize are AxisType.LOOP orig_ranges, end_ranges = s.ranges, [x.replace(arg=(next(ctx.range_idx), AxisType.LOOP)) if x.op is Ops.RANGE else x for x in s.ranges] - s = s.substitute(dict(zip(orig_ranges, end_ranges))).bufferize(*end_ranges, arg=BufferizeOpts(device=device)).index(*orig_ranges) + s = s.substitute(dict(zip(orig_ranges, end_ranges))).bufferize(*end_ranges, arg=BufferizeOpts(device=s.device)).index(*orig_ranges) srcs.append(s) return root.replace(src=tuple(srcs)) pm_limit_bufs = PatternMatcher([(UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs)])