From 819592ee6796aaa65fc79e60426c068dfe7fc59b Mon Sep 17 00:00:00 2001 From: George Hotz Date: Wed, 29 Oct 2025 16:37:17 +0800 Subject: [PATCH] hotfix: disable DoubleMatmul for PTX --- test/test_rangeify.py | 5 +++-- tinygrad/codegen/opt/postrange.py | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/test/test_rangeify.py b/test/test_rangeify.py index e25af7ff7b..b47da8fb98 100644 --- a/test/test_rangeify.py +++ b/test/test_rangeify.py @@ -6,18 +6,19 @@ from tinygrad.codegen.opt import OptOps, Opt from tinygrad.renderer.ptx import PTXRenderer from tinygrad.renderer.nir import NIRRenderer -@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, NIRRenderer), "broken in LVP") +@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer)), "broken in LVP and PTX") class TestDoubleMatmul(unittest.TestCase): def setUp(self): with Context(DEBUG=0): self.a, self.b, self.c = [Tensor.randn(16, 16).contiguous().realize() for _ in range(3)] + self.ref = (self.a @ self.b @ self.c).realize() def _test(self, opts): with Context(PCONTIG=2, DEBUG=max(2, DEBUG.value)): out = (self.a @ self.b @ self.c).contiguous(arg=opts).realize() with Context(DEBUG=0): - err = (out-(self.a @ self.b @ self.c)).square() + err = (out-self.ref).square() self.assertLess(err.max().item(), 1e-4) self.assertLess(err.mean().item(), 1e-6) diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py index c0ac187252..968210878a 100644 --- a/tinygrad/codegen/opt/postrange.py +++ b/tinygrad/codegen/opt/postrange.py @@ -83,10 +83,12 @@ class Scheduler: def colors(self) -> list[str]: output_rngs = self._output_rngs() + globalizible_rngs = self._globalizable_rngs() ret = [] for x,r in zip(self.axis_types, self.rngs): if self.dont_use_locals and x == AxisType.GLOBAL: ret.append("BLUE") elif r not in output_rngs and x == AxisType.LOOP: ret.append("BLACK") + elif r not in globalizible_rngs and x == AxisType.LOOP: ret.append("white") else: ret.append(axis_colors[x]) return ret def colored_shape(self) -> str: return ' '.join([colored(f'{x.src[0].render():>4s}', color) for x,color in zip(self.rngs, self.colors())])