diff --git a/test/test_rangeify.py b/test/test_rangeify.py index c0ed402935..7a74d844d7 100644 --- a/test/test_rangeify.py +++ b/test/test_rangeify.py @@ -219,7 +219,10 @@ class TestRangeify(unittest.TestCase): out.realize() def test_flash_attention(self): - BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8 + #BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8 + + # lil bigger + BS, HEADS, SEQLEN, EMB = 4, 32, 128, 64 # bigger #BS, HEADS, SEQLEN, EMB = 4, 32, 1024, 64 @@ -236,8 +239,8 @@ class TestRangeify(unittest.TestCase): GlobalCounters.reset() args = () args += (Opt(OptOps.DEMOTE, 5, 8),) + args += (Opt(OptOps.TC, 0, (0,0,1,3)),) args += (Opt(OptOps.TC, 0, (0,0,1,0)),) - #args += (Opt(OptOps.TC, 0, (0,0,1,3)),) ret = fa().contiguous(arg=args).realize() with Context(RANGEIFY=0): with Context(DEBUG=2): diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py index 9d8d5d5129..5a30511f2c 100644 --- a/tinygrad/codegen/opt/postrange.py +++ b/tinygrad/codegen/opt/postrange.py @@ -288,8 +288,6 @@ class Scheduler: axes[2], new_range = self.shift_to(axes[2], amt, AxisType.UNROLL) ne.append(new_range) - #print("ne", [x.arg for x in ne]) - if use_tensor_cores != 2: # fix the srcs reduceop = get_single_element([x for x in self.ast.toposort() if x.op is Ops.REDUCE and x.tag == "TC"]) @@ -300,18 +298,17 @@ class Scheduler: # get reduce/upcast axes for the tensor cores tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))]) - #print(tc.base_upcast_axes()) base_upcast_axes = tuple([(s,2) for s in self.shape_str_to_axis(tc.base_upcast_axes())]) tc_upcast_axes = tuple([base_upcast_axes[:int(math.log2(tc.elements_per_thread[i]))] for i in range(3)]) - #print(tc_upcast_axes) # axes to range number (was done in lowerer) tc_upcast_axes = tuple([tuple([(self.rngs[a].arg[0], sz) for a,sz in v]) for v in tc_upcast_axes]) tc_reduce_axes = tuple([self.rngs[a].arg[0] for a in tc_reduce_axes]) + #print("ne", [x.arg for x in ne]) #print(tc_upcast_axes) #print(tc_reduce_axes) - tc_upcast_axes = (((ne[0].arg[0], 2),), ((ne[0].arg[0], 2),), ((ne[0].arg[0], 2),)) + #print("hack", tc_upcast_axes) # construct the op # TODO: remove tc_upcast_axes from the arg