This commit is contained in:
George Hotz
2025-08-27 11:49:02 -07:00
parent 73f83e6fe6
commit ea1b853a60
5 changed files with 4 additions and 7 deletions

View File

@@ -116,10 +116,10 @@ class TestRangeify(unittest.TestCase):
out.realize()
def test_flash_attention(self):
#BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
# bigger
BS, HEADS, SEQLEN, EMB = 4, 32, 128, 64
#BS, HEADS, SEQLEN, EMB = 4, 32, 128, 64
# llama 8B
#BS, HEADS, SEQLEN, EMB = 4, 32, 2048, 128

View File

@@ -238,7 +238,6 @@ def no_vectorized_buf(buf:UOp):
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
  """Build a per-lane index into `buf` from a single-lane `idx`.

  The lane count is taken from `cast.dtype.count`. When `idx` itself has more
  than one lane, nothing is built and None is returned. Otherwise `buf` and
  `idx` are both broadcast to the lane count, the broadcast index is multiplied
  by the lane count, and the constant int vector (0, 1, ..., lanes-1) is added
  before indexing.

  NOTE(review): presumably used as a rewrite-rule callback where returning None
  means "leave the match alone" -- confirm against the rule that registers it.
  """
  lanes = cast.dtype.count
  # an index that already carries several lanes is not handled here
  if idx.dtype.count > 1:
    return None
  assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
  wide_buf = buf.broadcast(lanes)
  # lane k of the result addresses element idx*lanes + k
  scaled_idx = idx.broadcast(lanes)*lanes
  lane_offsets = UOp.const(dtypes.int.vec(lanes), tuple(range(lanes)))
  return wide_buf.index(scaled_idx+lane_offsets)

View File

@@ -86,7 +86,7 @@ expander = PatternMatcher([
(UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
# do expansion
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, #Ops.BUFFERIZE,
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
(UPat(Ops.CONTRACT, name="con"), do_contract),
# BARRIERs aren't actually expanded

View File

@@ -211,8 +211,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if len(delete_ranges):
for s in UOp.sink(*delete_ranges).ranges:
if s in ret: del ret[s]
elif self.op in {Ops.BARRIER}:
ret = {x:None for x in self.src[0].ranges if x.arg[1] != AxisType.LOCAL}
else:
for s in self.src: ret.update(s.ranges)
return ret

View File

@@ -202,7 +202,7 @@ spec = PatternMatcher([
(UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
# WMMA has a <a, b, acc>
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), allow_any_len=True, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),