mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
flash attention with two gemms
This commit is contained in:
@@ -219,7 +219,10 @@ class TestRangeify(unittest.TestCase):
|
||||
out.realize()
|
||||
|
||||
def test_flash_attention(self):
|
||||
BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
|
||||
#BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
|
||||
|
||||
# lil bigger
|
||||
BS, HEADS, SEQLEN, EMB = 4, 32, 128, 64
|
||||
|
||||
# bigger
|
||||
#BS, HEADS, SEQLEN, EMB = 4, 32, 1024, 64
|
||||
@@ -236,8 +239,8 @@ class TestRangeify(unittest.TestCase):
|
||||
GlobalCounters.reset()
|
||||
args = ()
|
||||
args += (Opt(OptOps.DEMOTE, 5, 8),)
|
||||
args += (Opt(OptOps.TC, 0, (0,0,1,3)),)
|
||||
args += (Opt(OptOps.TC, 0, (0,0,1,0)),)
|
||||
#args += (Opt(OptOps.TC, 0, (0,0,1,3)),)
|
||||
ret = fa().contiguous(arg=args).realize()
|
||||
with Context(RANGEIFY=0):
|
||||
with Context(DEBUG=2):
|
||||
|
||||
@@ -288,8 +288,6 @@ class Scheduler:
|
||||
axes[2], new_range = self.shift_to(axes[2], amt, AxisType.UNROLL)
|
||||
ne.append(new_range)
|
||||
|
||||
#print("ne", [x.arg for x in ne])
|
||||
|
||||
if use_tensor_cores != 2:
|
||||
# fix the srcs
|
||||
reduceop = get_single_element([x for x in self.ast.toposort() if x.op is Ops.REDUCE and x.tag == "TC"])
|
||||
@@ -300,18 +298,17 @@ class Scheduler:
|
||||
|
||||
# get reduce/upcast axes for the tensor cores
|
||||
tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))])
|
||||
#print(tc.base_upcast_axes())
|
||||
base_upcast_axes = tuple([(s,2) for s in self.shape_str_to_axis(tc.base_upcast_axes())])
|
||||
tc_upcast_axes = tuple([base_upcast_axes[:int(math.log2(tc.elements_per_thread[i]))] for i in range(3)])
|
||||
#print(tc_upcast_axes)
|
||||
|
||||
# axes to range number (was done in lowerer)
|
||||
tc_upcast_axes = tuple([tuple([(self.rngs[a].arg[0], sz) for a,sz in v]) for v in tc_upcast_axes])
|
||||
tc_reduce_axes = tuple([self.rngs[a].arg[0] for a in tc_reduce_axes])
|
||||
#print("ne", [x.arg for x in ne])
|
||||
#print(tc_upcast_axes)
|
||||
#print(tc_reduce_axes)
|
||||
|
||||
tc_upcast_axes = (((ne[0].arg[0], 2),), ((ne[0].arg[0], 2),), ((ne[0].arg[0], 2),))
|
||||
#print("hack", tc_upcast_axes)
|
||||
|
||||
# construct the op
|
||||
# TODO: remove tc_upcast_axes from the arg
|
||||
|
||||
Reference in New Issue
Block a user