flash attention with two gemms

George Hotz committed 2025-10-02 17:48:48 +08:00
parent e5028d58e9, commit 3fb3dd4c06
2 changed files with 7 additions and 7 deletions
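
Context for the diff below (a sketch, not part of the commit): "two gemms" means attention is computed with only Q @ K^T and P @ V as matmuls, with an online softmax between them, so the SEQLEN x SEQLEN score matrix never has to be materialized at once. A minimal NumPy reference of that recurrence:

import numpy as np

def flash_attention_ref(q, k, v, block=32):
  # q, k, v: (BS, HEADS, SEQLEN, EMB)
  seqlen, emb = q.shape[-2], q.shape[-1]
  scale = 1.0 / np.sqrt(emb)
  m = np.full((*q.shape[:-1], 1), -np.inf)  # running row max
  l = np.zeros((*q.shape[:-1], 1))          # running softmax denominator
  o = np.zeros_like(q)                      # running unnormalized output
  for j in range(0, seqlen, block):
    kj, vj = k[..., j:j+block, :], v[..., j:j+block, :]
    s = (q @ kj.swapaxes(-1, -2)) * scale   # gemm 1: Q @ K^T, one key block
    m_new = np.maximum(m, s.max(-1, keepdims=True))
    p = np.exp(s - m_new)
    corr = np.exp(m - m_new)                # rescale previously accumulated blocks
    l = l * corr + p.sum(-1, keepdims=True)
    o = o * corr + p @ vj                   # gemm 2: P @ V, same block
    m = m_new
  return o / l

# agrees with the naive three-pass softmax(Q @ K^T) @ V:
q, k, v = (np.random.randn(4, 32, 128, 64) for _ in range(3))
s = (q @ k.swapaxes(-1, -2)) / np.sqrt(64)
p = np.exp(s - s.max(-1, keepdims=True))
assert np.allclose(flash_attention_ref(q, k, v), (p / p.sum(-1, keepdims=True)) @ v)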


@@ -219,7 +219,10 @@ class TestRangeify(unittest.TestCase):
    out.realize()
  def test_flash_attention(self):
    BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
    #BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
    # lil bigger
    BS, HEADS, SEQLEN, EMB = 4, 32, 128, 64
    # bigger
    #BS, HEADS, SEQLEN, EMB = 4, 32, 1024, 64
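
Rough arithmetic for the three shape presets above (mine, not from the commit): the two gemms cost about 4*BS*HEADS*SEQLEN**2*EMB FLOPs total, while the fusion avoids materializing BS*HEADS*SEQLEN**2 score elements.

for BS, HEADS, SEQLEN, EMB in [(4, 2, 16, 8), (4, 32, 128, 64), (4, 32, 1024, 64)]:
  gemm_flops = 4 * BS * HEADS * SEQLEN**2 * EMB  # 2*S*S*E MACs per gemm, two gemms
  attn_elems = BS * HEADS * SEQLEN**2            # S x S scores per (batch, head)
  print(f"{(BS, HEADS, SEQLEN, EMB)}: ~{gemm_flops/1e6:.2f} MFLOPs, {attn_elems:,} scores avoided")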
@@ -236,8 +239,8 @@ class TestRangeify(unittest.TestCase):
    GlobalCounters.reset()
    args = ()
    args += (Opt(OptOps.DEMOTE, 5, 8),)
    args += (Opt(OptOps.TC, 0, (0,0,1,3)),)
    args += (Opt(OptOps.TC, 0, (0,0,1,0)),)
    #args += (Opt(OptOps.TC, 0, (0,0,1,3)),)
    ret = fa().contiguous(arg=args).realize()
    with Context(RANGEIFY=0):
      with Context(DEBUG=2):
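
The 4-tuple passed to OptOps.TC is scheduler-internal; the visible change in this hunk is its last field going from 3 to 0, with the old line kept commented out. A sketch of the harness pattern the test uses; the Opt/OptOps import path below is an assumption, not verified against this commit:

from tinygrad.helpers import Context, GlobalCounters
from tinygrad.opt.kernel import Opt, OptOps  # assumed path; may differ by version

def run_hand_opted(fa, args):
  GlobalCounters.reset()
  # hand-picked opts ride along to the scheduler via contiguous(arg=...)
  return fa().contiguous(arg=args).realize()

# args and fa as defined in the test above:
# ret = run_hand_opted(fa, (Opt(OptOps.DEMOTE, 5, 8), Opt(OptOps.TC, 0, (0,0,1,0))))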


@@ -288,8 +288,6 @@ class Scheduler:
    axes[2], new_range = self.shift_to(axes[2], amt, AxisType.UNROLL)
    ne.append(new_range)
    #print("ne", [x.arg for x in ne])
    if use_tensor_cores != 2:
      # fix the srcs
      reduceop = get_single_element([x for x in self.ast.toposort() if x.op is Ops.REDUCE and x.tag == "TC"])
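
The lookup above relies on the tensor-core reduce having been tagged "TC" earlier in scheduling; get_single_element then enforces that exactly one such REDUCE exists in the AST. A minimal stand-in for the helper (the real one lives in tinygrad.helpers; the message text is mine):

def get_single_element(x):
  # raise rather than silently pick one when the filter is ambiguous
  assert len(x) == 1, f"expected a single element, got {len(x)}"
  return x[0]

# usage, mirroring the diff:
# reduceop = get_single_element([u for u in ast.toposort() if u.op is Ops.REDUCE and u.tag == "TC"])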
@@ -300,18 +298,17 @@ class Scheduler:
      # get reduce/upcast axes for the tensor cores
      tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))])
      #print(tc.base_upcast_axes())
      base_upcast_axes = tuple([(s,2) for s in self.shape_str_to_axis(tc.base_upcast_axes())])
      tc_upcast_axes = tuple([base_upcast_axes[:int(math.log2(tc.elements_per_thread[i]))] for i in range(3)])
      #print(tc_upcast_axes)
      # axes to range number (was done in lowerer)
      tc_upcast_axes = tuple([tuple([(self.rngs[a].arg[0], sz) for a,sz in v]) for v in tc_upcast_axes])
      tc_reduce_axes = tuple([self.rngs[a].arg[0] for a in tc_reduce_axes])
      #print("ne", [x.arg for x in ne])
      #print(tc_upcast_axes)
      #print(tc_reduce_axes)
      tc_upcast_axes = (((ne[0].arg[0], 2),), ((ne[0].arg[0], 2),), ((ne[0].arg[0], 2),))
      #print("hack", tc_upcast_axes)
      # construct the op
      # TODO: remove tc_upcast_axes from the arg
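
A worked example (hypothetical numbers, not a real tensor-core spec) of the base_upcast_axes slicing above: every base upcast axis has size 2, so operand i takes the first log2(elements_per_thread[i]) (axis, 2) pairs. The hack line then overrides all three operands with the new unroll range ne[0], pending the TODO to drop tc_upcast_axes from the arg.

import math

base_upcast_axes = ((7, 2), (8, 2), (9, 2))  # hypothetical (range number, size) pairs
elements_per_thread = (8, 4, 4)              # hypothetical counts for the 3 operands

tc_upcast_axes = tuple(base_upcast_axes[:int(math.log2(elements_per_thread[i]))] for i in range(3))
print(tc_upcast_axes)
# -> (((7, 2), (8, 2), (9, 2)), ((7, 2), (8, 2)), ((7, 2), (8, 2)))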