DEMOTE op for putting globals in locals
@@ -143,6 +143,13 @@ class TestPcontig(unittest.TestCase):
     print(f"mse: {mse}")
     self.assertLessEqual(mse, 1e-6)
 
+  def test_flash_attention_tc(self):
+    opts = ()
+    # rows in all the matrices
+    opts += (Opt(OptOps.DEMOTE, 4, 8),)
+    #opts += (Opt(OptOps.TC, 0, (0, 0, 1)),)
+    self.test_flash_attention(opts)
+
   def test_flash_attention_opt(self):
     opts = ()
     # columns in top matrix
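For context, `Opt(OptOps.DEMOTE, 4, 8)` appears to ask the scheduler to demote axis 4 with a size of 8: stage a slice of a global buffer into a local buffer once, so the inner loops read the local copy instead of re-reading global memory. A plain-Python sketch of that idea (illustrative only, not tinygrad API; `matvec_demoted` and `tile` are made-up names):

import numpy as np

def matvec_demoted(A, x, tile=8):
  out = np.zeros(A.shape[0])
  for j0 in range(0, A.shape[1], tile):
    x_local = x[j0:j0+tile].copy()          # the "demoted" copy: global -> local
    for i in range(A.shape[0]):
      for j in range(x_local.shape[0]):
        out[i] += A[i, j0+j] * x_local[j]   # inner loop reads the local buffer
  return out

A, x = np.arange(12.).reshape(3, 4), np.ones(4)
assert np.allclose(matvec_demoted(A, x, tile=2), A @ x)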
@@ -6,6 +6,7 @@ from dataclasses import dataclass
 class OptOps(Enum):
   TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto(); THREAD = auto() # noqa: E702
   GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
+  DEMOTE = auto()
   def __lt__(self, x:OptOps): return self.value < x.value
 
 @dataclass(frozen=True, order=True)
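The explicit `__lt__` matters because the `Opt` dataclass below is declared with `order=True`, whose generated comparisons compare fields left to right starting with the `op` field, and plain `Enum` members are not orderable. A minimal self-contained sketch (the `Opt` fields here are assumed from the surrounding diff, not copied from the source):

from dataclasses import dataclass
from enum import Enum, auto

class OptOps(Enum):
  UPCAST = auto(); DEMOTE = auto()  # noqa: E702
  def __lt__(self, x): return self.value < x.value

@dataclass(frozen=True, order=True)
class Opt:
  op: OptOps
  axis: int = 0
  arg: int = 0

# order=True compares fields left to right, so OptOps itself must be orderable;
# without the __lt__ above, sorting a list of Opts would raise TypeError.
assert sorted([Opt(OptOps.DEMOTE, 4, 8), Opt(OptOps.UPCAST, 0, 4)])[0].op is OptOps.UPCAST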
@@ -174,6 +174,23 @@ class Scheduler:
       check(not self.dont_use_locals, "can't use locals")
       check(rng.arg[-1] == AxisType.REDUCE, "group is for reduce")
       ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op in {OptOps.GROUPTOP, OptOps.THREAD})
+    elif opt.op is OptOps.DEMOTE:
+      _, rr = self.shift_to(rng, cast(int, opt.arg), AxisType.LOOP)
+      def do_demote(ctx, x:UOp):
+        if x.tag is not None: return None
+        mr = ctx[0]
+        nr = mr.replace(arg=ctx[0].arg[0:-2]+(mr.arg[-2]+1, mr.arg[-1]))
+        ctx[0] = nr
+        buf = x.replace(src=(x.src[0], mr)+x.src[1:], tag=1).substitute({mr:nr})
+        return UOp(Ops.APPENDINDEX, dtypes.void, (buf,mr))
+      # do the demotion
+      pm_demote = PatternMatcher([
+        (UPat(Ops.END, src=(UPat(Ops.END, name="e1"),), allow_any_len=True, name="e2"), lambda e1,e2: e1.replace(src=e1.src+e2.src[1:])),
+        (UPat(Ops.BUFFERIZE, name="x"), do_demote),
+        (UPat(Ops.INDEX, src=(UPat(Ops.APPENDINDEX, name="x"),), name="y", allow_any_len=True),
+         lambda x,y: y.replace(src=(x.src[0],)+x.src[1:]+y.src[1:])),
+      ])
+      self.ast = graph_rewrite(self.ast.src[0].end(rr).sink(), pm_demote, ctx=[rr], bottom_up=True, name="demote")
    elif opt.op is OptOps.TC:
       check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: remove the need for this by having warps
       check(opt.axis is not None, "tensor core opts must have an axis")
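The demotion itself is a bottom-up `graph_rewrite` driven by a `PatternMatcher`: `do_demote` fires once per `BUFFERIZE` (the `tag` check stops it from re-firing on its own output), wraps the buffer in an `APPENDINDEX` carrying the demoted range, and the `INDEX` rule later splices that range into downstream indexing. A toy reimplementation of the rewrite mechanics, with a stand-in `Node` instead of tinygrad's `UOp` (all names here are illustrative):

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str
  src: tuple = ()
  tag: object = None

def rewrite(node, rules, ctx):
  # bottom-up: rewrite children first, then try the rules on the rebuilt node
  node = Node(node.op, tuple(rewrite(s, rules, ctx) for s in node.src), node.tag)
  for match_op, fn in rules:
    if node.op == match_op and (out := fn(ctx, node)) is not None:
      return rewrite(out, rules, ctx)
  return node

def do_demote(ctx, x):
  if x.tag is not None: return None  # already demoted: the tag stops the rule re-firing
  return Node("APPENDINDEX", (Node(x.op, x.src, tag=1), ctx[0]))

rng = Node("RANGE")
ast = Node("INDEX", (Node("BUFFERIZE", (Node("LOAD"),)),))
out = rewrite(ast, [("BUFFERIZE", do_demote)], ctx=[rng])
assert out.src[0].op == "APPENDINDEX"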
@@ -59,6 +59,7 @@ class Ops(FastEnum):
 
   # INDEX is a BinaryOp similar to ADD, but it operates on pointers
   INDEX = auto()
+  APPENDINDEX = auto()
 
   # BinaryOps
   ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto() # noqa: E702
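`APPENDINDEX` pairs a bufferized value with an extra index source; the third `pm_demote` rule above then flattens `INDEX(APPENDINDEX(buf, extra...), idx...)` into `INDEX(buf, extra..., idx...)`. That source-splicing is plain tuple surgery, shown here on bare tuples (toy names, not tinygrad API):

def merge_index(appendindex_src, index_src):
  # appendindex_src = (buffer, extra_idx, ...); index_src = (appendindex, idx, ...)
  # mirrors: y.replace(src=(x.src[0],) + x.src[1:] + y.src[1:])
  return (appendindex_src[0],) + appendindex_src[1:] + index_src[1:]

assert merge_index(("buf", "r_demoted"), ("ai", "i", "j")) == ("buf", "r_demoted", "i", "j")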
@@ -188,7 +188,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     match self.op:
       # late ops don't have shape
       case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.INDEX | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
-        Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST:
+        Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST | Ops.CONTRACT | Ops.APPENDINDEX:
         return None
 
       # some ops init the shape
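This last change just extends the "no shape" case: `APPENDINDEX`, like `INDEX` and the other late ops, lives below the shape-tracking level, so the structural `match` returns `None` for it. A minimal sketch of that dispatch pattern with a toy enum (not tinygrad's `Ops`):

from enum import Enum, auto

class Ops(Enum):
  INDEX = auto(); APPENDINDEX = auto(); CONST = auto()  # noqa: E702

def shape(op):
  match op:
    case Ops.INDEX | Ops.APPENDINDEX:  # late ops don't have shape
      return None
    case Ops.CONST:  # some ops init the shape
      return ()

assert shape(Ops.APPENDINDEX) is None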