it works, it's just slow...

This commit is contained in:
George Hotz
2025-10-29 18:35:44 +08:00
parent e4ef94cf10
commit 154ddd98fd
2 changed files with 43 additions and 4 deletions

View File

@@ -7,6 +7,35 @@ from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer
@unittest.skipUnless(Device.DEFAULT == "METAL" and not CI, "only for METAL TC")
class TestBigDoubleMatmul(unittest.TestCase):
  """Chained 1024x1024 matmul (a @ b @ c) driven by explicit scheduler opts.

  setUp precomputes a reference result with default scheduling; each test
  re-runs the same chain under PCONTIG=2 with a chosen opt tuple and checks
  the output stays numerically close to the reference.
  """
  def setUp(self):
    n = 1024
    with Context(DEBUG=0):
      # three independent NxN operands, realized up front so timing/debug
      # output in the tests covers only the fused computation
      self.a, self.b, self.c = (Tensor.randn(n, n).contiguous().realize() for _ in range(3))
    with Context(DEBUG=2):
      self.ref = (self.a @ self.b @ self.c).realize()

  def _test(self, opts):
    # Run the chained matmul with PCONTIG enabled and the given opts applied
    # via contiguous(arg=...), then compare elementwise against the reference.
    with Context(PCONTIG=2, DEBUG=max(2, DEBUG.value)):
      out = (self.a @ self.b @ self.c).contiguous(arg=opts).realize()
    with Context(DEBUG=0):
      err = (out - self.ref).square()
    self.assertLess(err.max().item(), 1e-4)
    self.assertLess(err.mean().item(), 1e-6)

  def test_demote_tc_both(self):
    # DEMOTE followed by both TC variants, then manual upcast/unroll on the
    # remaining axes.
    opts = (
      Opt(OptOps.DEMOTE, 2, 8),
      Opt(OptOps.TC, 0, (0, 0, 1, 1)),
      Opt(OptOps.TC, 0, (0, 0, 1, 0)),
      Opt(OptOps.UPCAST, 0, 4),
      Opt(OptOps.UPCAST, 1, 4),
      Opt(OptOps.UNROLL, 0, 4),
      Opt(OptOps.UNROLL, 1, 4),
    )
    self._test(opts)
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer, CUDARenderer)), "broken in LVP and PTX")
class TestDoubleMatmul(unittest.TestCase):
def setUp(self):
@@ -62,10 +91,9 @@ class TestDoubleMatmul(unittest.TestCase):
def test_demote_tc_bottom(self):
  """DEMOTE axis 2 by 8, then a single TC opt with variant (0, 0, 1, 1)."""
  opts = (Opt(OptOps.DEMOTE, 2, 8), Opt(OptOps.TC, 0, (0, 0, 1, 1)))
  self._test(opts)
@unittest.skip("broken")
@unittest.skipUnless(Device.DEFAULT == "METAL", "only for METAL TC")
def test_demote_tc_both(self):
self._test((Opt(OptOps.DEMOTE, 2, 8), Opt(OptOps.TC, 0, (0, 0, 1, 0)), Opt(OptOps.TC, 0, (0, 0, 1, 1))))
self._test((Opt(OptOps.DEMOTE, 2, 8), Opt(OptOps.TC, 0, (0, 0, 1, 1)), Opt(OptOps.TC, 0, (0, 0, 1, 0))))
class TestRangeifyAssign(unittest.TestCase):
def test_assign_permuted(self):

View File

@@ -253,8 +253,8 @@ class Scheduler:
in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: -x.arg[0])
red_ranges = sorted(reduceop.src[1:], key=lambda x: -x.arg[0])
if DEBUG >= 3:
print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
print(f"TC({axis}, {reduce_choice}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
if not len(in0_ranges) or not len(in1_ranges) or not len(red_ranges): continue
# pick ranges
@@ -277,6 +277,9 @@ class Scheduler:
axes[i] = self.rngs[idx]
except KernelOptError: continue
upcast_ranges = []
reduce_ranges = []
# we create the warp as a whole thing, in case some of these ranges are moved/removed later
warp = UOp.range(tc.threads, -1, AxisType.WARP)
ne: list[UOp] = []
@@ -286,12 +289,14 @@ class Scheduler:
warp //= 2
elif opt[0] == "u":
axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.UPCAST)
upcast_ranges.append(new_range)
else: raise RuntimeError(f"unsupported opt {opt[0]} in tensor cores")
ne.append(new_range)
for _, amt in tc.get_reduce_axes():
axes[2], new_range = self.shift_to(axes[2], amt, AxisType.UNROLL)
ne.append(new_range)
reduce_ranges.append(new_range)
if use_tensor_cores != 2:
# fix the srcs
@@ -309,6 +314,12 @@ class Scheduler:
# axes to range number (was done in lowerer)
tc_upcast_axes = tuple([tuple([(self.rngs[a].arg[0], sz) for a,sz in v]) for v in tc_upcast_axes])
tc_reduce_axes = tuple([self.rngs[a].arg[0] for a in tc_reduce_axes])
print(tc_reduce_axes, tc_upcast_axes)
# DIRECT: get range number from ranges
tc_upcast_axes = (((upcast_ranges[0].arg[0], 2),), ((upcast_ranges[0].arg[0], 2),), ((upcast_ranges[0].arg[0], 2),))
tc_reduce_axes = tuple([x.arg[0] for x in reduce_ranges])
#print(tc_reduce_axes, tc_upcast_axes)
# construct the op
# TODO: remove tc_upcast_axes from the arg