mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
hotfix: disable DoubleMatmul for PTX
This commit is contained in:
@@ -6,18 +6,19 @@ from tinygrad.codegen.opt import OptOps, Opt
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.renderer.nir import NIRRenderer
|
||||
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, NIRRenderer), "broken in LVP")
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer)), "broken in LVP and PTX")
|
||||
class TestDoubleMatmul(unittest.TestCase):
|
||||
def setUp(self):
|
||||
with Context(DEBUG=0):
|
||||
self.a, self.b, self.c = [Tensor.randn(16, 16).contiguous().realize() for _ in range(3)]
|
||||
self.ref = (self.a @ self.b @ self.c).realize()
|
||||
|
||||
def _test(self, opts):
|
||||
with Context(PCONTIG=2, DEBUG=max(2, DEBUG.value)):
|
||||
out = (self.a @ self.b @ self.c).contiguous(arg=opts).realize()
|
||||
|
||||
with Context(DEBUG=0):
|
||||
err = (out-(self.a @ self.b @ self.c)).square()
|
||||
err = (out-self.ref).square()
|
||||
self.assertLess(err.max().item(), 1e-4)
|
||||
self.assertLess(err.mean().item(), 1e-6)
|
||||
|
||||
|
||||
@@ -83,10 +83,12 @@ class Scheduler:
|
||||
|
||||
def colors(self) -> list[str]:
|
||||
output_rngs = self._output_rngs()
|
||||
globalizible_rngs = self._globalizable_rngs()
|
||||
ret = []
|
||||
for x,r in zip(self.axis_types, self.rngs):
|
||||
if self.dont_use_locals and x == AxisType.GLOBAL: ret.append("BLUE")
|
||||
elif r not in output_rngs and x == AxisType.LOOP: ret.append("BLACK")
|
||||
elif r not in globalizible_rngs and x == AxisType.LOOP: ret.append("white")
|
||||
else: ret.append(axis_colors[x])
|
||||
return ret
|
||||
def colored_shape(self) -> str: return ' '.join([colored(f'{x.src[0].render():>4s}', color) for x,color in zip(self.rngs, self.colors())])
|
||||
|
||||
Reference in New Issue
Block a user