Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-29 03:00:14 -04:00
it works, it's just slow...
@@ -7,6 +7,35 @@ from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer

@unittest.skipUnless(Device.DEFAULT == "METAL" and not CI, "only for METAL TC")
class TestBigDoubleMatmul(unittest.TestCase):
  def setUp(self):
    N = 1024
    with Context(DEBUG=0):
      self.a, self.b, self.c = [Tensor.randn(N, N).contiguous().realize() for _ in range(3)]
    with Context(DEBUG=2):
      self.ref = (self.a @ self.b @ self.c).realize()

  def _test(self, opts):
    with Context(PCONTIG=2, DEBUG=max(2, DEBUG.value)):
      out = (self.a @ self.b @ self.c).contiguous(arg=opts).realize()

    with Context(DEBUG=0):
      err = (out-self.ref).square()
      self.assertLess(err.max().item(), 1e-4)
      self.assertLess(err.mean().item(), 1e-6)

  def test_demote_tc_both(self):
    outs = ()
    outs += (Opt(OptOps.DEMOTE, 2, 8),)
    outs += (Opt(OptOps.TC, 0, (0, 0, 1, 1)),)
    outs += (Opt(OptOps.TC, 0, (0, 0, 1, 0)),)
    outs += (Opt(OptOps.UPCAST, 0, 4),)
    outs += (Opt(OptOps.UPCAST, 1, 4),)
    outs += (Opt(OptOps.UNROLL, 0, 4),)
    outs += (Opt(OptOps.UNROLL, 1, 4),)
    self._test(outs)

@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer, CUDARenderer)), "broken in LVP and PTX")
class TestDoubleMatmul(unittest.TestCase):
  def setUp(self):

@@ -62,10 +91,9 @@ class TestDoubleMatmul(unittest.TestCase):
  def test_demote_tc_bottom(self):
    self._test((Opt(OptOps.DEMOTE, 2, 8), Opt(OptOps.TC, 0, (0, 0, 1, 1))))

  @unittest.skip("broken")
  @unittest.skipUnless(Device.DEFAULT == "METAL", "only for METAL TC")
  def test_demote_tc_both(self):
    self._test((Opt(OptOps.DEMOTE, 2, 8), Opt(OptOps.TC, 0, (0, 0, 1, 0)), Opt(OptOps.TC, 0, (0, 0, 1, 1))))
    self._test((Opt(OptOps.DEMOTE, 2, 8), Opt(OptOps.TC, 0, (0, 0, 1, 1)), Opt(OptOps.TC, 0, (0, 0, 1, 0))))

class TestRangeifyAssign(unittest.TestCase):
  def test_assign_permuted(self):

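Not part of the diff: a minimal standalone sketch of the opt sequence the new test drives, in case you want to reproduce it outside unittest. The import path for Opt/OptOps and the PCONTIG keyword are assumptions that may differ between tinygrad versions; everything else mirrors TestBigDoubleMatmul above.

# sketch: apply the DEMOTE + double-TC + UPCAST/UNROLL opts from test_demote_tc_both by hand
from tinygrad import Tensor
from tinygrad.helpers import Context
from tinygrad.codegen.opt import Opt, OptOps  # assumed path; Opt/OptOps have moved between versions

N = 1024
a, b, c = [Tensor.randn(N, N).contiguous().realize() for _ in range(3)]
ref = (a @ b @ c).realize()

opts = (Opt(OptOps.DEMOTE, 2, 8),
        Opt(OptOps.TC, 0, (0, 0, 1, 1)), Opt(OptOps.TC, 0, (0, 0, 1, 0)),
        Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4),
        Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4))
with Context(PCONTIG=2, DEBUG=2):  # PCONTIG=2 is what the test sets; assumed to be a valid ContextVar here
  out = (a @ b @ c).contiguous(arg=opts).realize()

err = (out - ref).square()
print(err.max().item(), err.mean().item())  # the test asserts < 1e-4 and < 1e-6 respectively
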
@@ -253,8 +253,8 @@ class Scheduler:
      in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: -x.arg[0])
      red_ranges = sorted(reduceop.src[1:], key=lambda x: -x.arg[0])
      if DEBUG >= 3:
        print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
              f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
        print(f"TC({axis}, {reduce_choice}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
              f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
      if not len(in0_ranges) or not len(in1_ranges) or not len(red_ranges): continue

      # pick ranges

@@ -277,6 +277,9 @@ class Scheduler:
          axes[i] = self.rngs[idx]
        except KernelOptError: continue

      upcast_ranges = []
      reduce_ranges = []

      # we create the warp as a whole thing, in case some of these ranges are moved/removed later
      warp = UOp.range(tc.threads, -1, AxisType.WARP)
      ne: list[UOp] = []

@@ -286,12 +289,14 @@ class Scheduler:
          warp //= 2
        elif opt[0] == "u":
          axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.UPCAST)
          upcast_ranges.append(new_range)
        else: raise RuntimeError(f"unsupported opt {opt[0]} in tensor cores")
        ne.append(new_range)

      for _, amt in tc.get_reduce_axes():
        axes[2], new_range = self.shift_to(axes[2], amt, AxisType.UNROLL)
        ne.append(new_range)
        reduce_ranges.append(new_range)

      if use_tensor_cores != 2:
        # fix the srcs

@@ -309,6 +314,12 @@ class Scheduler:
      # axes to range number (was done in lowerer)
      tc_upcast_axes = tuple([tuple([(self.rngs[a].arg[0], sz) for a,sz in v]) for v in tc_upcast_axes])
      tc_reduce_axes = tuple([self.rngs[a].arg[0] for a in tc_reduce_axes])
      print(tc_reduce_axes, tc_upcast_axes)

      # DIRECT: get range number from ranges
      tc_upcast_axes = (((upcast_ranges[0].arg[0], 2),), ((upcast_ranges[0].arg[0], 2),), ((upcast_ranges[0].arg[0], 2),))
      tc_reduce_axes = tuple([x.arg[0] for x in reduce_ranges])
      #print(tc_reduce_axes, tc_upcast_axes)

      # construct the op
      # TODO: remove tc_upcast_axes from the arg

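Not part of the diff: the DIRECT path above rebuilds the tensor-core args from the ranges captured in upcast_ranges/reduce_ranges instead of going through axis indices. A rough illustration of the resulting shapes, with made-up range ids (5 for the upcast range, 7 and 8 for the unroll ranges):

upcast_id = 5                               # stands in for upcast_ranges[0].arg[0]
tc_upcast_axes = (((upcast_id, 2),),) * 3   # one ((range_id, size=2),) entry per TC operand
tc_reduce_axes = (7, 8)                     # stands in for tuple(x.arg[0] for x in reduce_ranges)
print(tc_upcast_axes, tc_reduce_axes)       # (((5, 2),), ((5, 2),), ((5, 2),)) (7, 8)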