mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
delete
This commit is contained in:
@@ -116,10 +116,10 @@ class TestRangeify(unittest.TestCase):
|
||||
out.realize()
|
||||
|
||||
def test_flash_attention(self):
|
||||
#BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
|
||||
BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
|
||||
|
||||
# bigger
|
||||
BS, HEADS, SEQLEN, EMB = 4, 32, 128, 64
|
||||
#BS, HEADS, SEQLEN, EMB = 4, 32, 128, 64
|
||||
|
||||
# llama 8B
|
||||
#BS, HEADS, SEQLEN, EMB = 4, 32, 2048, 128
|
||||
|
||||
@@ -238,7 +238,6 @@ def no_vectorized_buf(buf:UOp):
|
||||
|
||||
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
|
||||
cnt = cast.dtype.count
|
||||
if idx.dtype.count > 1: return None
|
||||
assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
|
||||
return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.int.vec(cnt), tuple(range(cnt))))
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ expander = PatternMatcher([
|
||||
(UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
|
||||
lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
|
||||
# do expansion
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, #Ops.BUFFERIZE,
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
|
||||
Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
|
||||
(UPat(Ops.CONTRACT, name="con"), do_contract),
|
||||
# BARRIERs aren't actually expanded
|
||||
|
||||
@@ -211,8 +211,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if len(delete_ranges):
|
||||
for s in UOp.sink(*delete_ranges).ranges:
|
||||
if s in ret: del ret[s]
|
||||
elif self.op in {Ops.BARRIER}:
|
||||
ret = {x:None for x in self.src[0].ranges if x.arg[1] != AxisType.LOCAL}
|
||||
else:
|
||||
for s in self.src: ret.update(s.ranges)
|
||||
return ret
|
||||
|
||||
@@ -202,7 +202,7 @@ spec = PatternMatcher([
|
||||
(UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
|
||||
|
||||
# WMMA has a <a, b, acc>
|
||||
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), allow_any_len=True, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
|
||||
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
|
||||
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
|
||||
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
|
||||
|
||||
|
||||
Reference in New Issue
Block a user