DEMOTE op for putting globals in locals

This commit is contained in:
George Hotz
2025-10-29 17:22:59 +08:00
parent 819592ee67
commit a9d91ffcfc
5 changed files with 27 additions and 1 deletions

View File

@@ -143,6 +143,13 @@ class TestPcontig(unittest.TestCase):
print(f"mse: {mse}")
self.assertLessEqual(mse, 1e-6)
def test_flash_attention_tc(self):
  """Run flash attention with a DEMOTE opt applied (rows in all the matrices)."""
  opts = (Opt(OptOps.DEMOTE, 4, 8),)
  #opts += (Opt(OptOps.TC, 0, (0, 0, 1)),)
  self.test_flash_attention(opts)
def test_flash_attention_opt(self):
opts = ()
# columns in top matrix

View File

@@ -6,6 +6,7 @@ from dataclasses import dataclass
class OptOps(Enum):
  """Kernel optimization ops. Definition order fixes each member's auto() value,
  which __lt__ uses for ordering, so new members must be appended at the end."""
  TC = auto()
  UPCAST = auto()
  UNROLL = auto()
  LOCAL = auto()
  THREAD = auto()
  GROUP = auto()
  GROUPTOP = auto()
  NOLOCALS = auto()
  PADTO = auto()
  SWAP = auto()
  DEMOTE = auto()
  def __lt__(self, x:OptOps):
    # order by declaration position (the auto() value)
    return self.value < x.value
@dataclass(frozen=True, order=True)

View File

@@ -174,6 +174,23 @@ class Scheduler:
check(not self.dont_use_locals, "can't use locals")
check(rng.arg[-1] == AxisType.REDUCE, "group is for reduce")
ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op in {OptOps.GROUPTOP, OptOps.THREAD})
elif opt.op is OptOps.DEMOTE:
_, rr = self.shift_to(rng, cast(int, opt.arg), AxisType.LOOP)
def do_demote(ctx, x:UOp):
  """Rewrite one BUFFERIZE node for the DEMOTE opt, wrapping it in APPENDINDEX.

  ctx is a one-element mutable list holding the current range UOp (the demoted
  axis); it is updated in place so successive matches each get a fresh range.
  Returns the replacement UOp, or None to leave an already-processed node alone.
  """
  # tag=1 is set below on nodes this function has produced; skip them on re-match
  if x.tag is not None: return None
  mr = ctx[0]  # the range UOp currently shared via ctx
  # build a new range whose second-to-last arg field is incremented —
  # NOTE(review): presumably a per-use counter to keep ranges distinct; confirm
  # against the RANGE arg layout
  nr = mr.replace(arg=ctx[0].arg[0:-2]+(mr.arg[-2]+1, mr.arg[-1]))
  ctx[0] = nr  # later matches see the bumped range
  # splice the old range into the bufferize's sources after src[0], mark it
  # handled (tag=1), and substitute the bumped range everywhere else in the tree
  buf = x.replace(src=(x.src[0], mr)+x.src[1:], tag=1).substitute({mr:nr})
  # wrap the result so the INDEX pattern below can pick up the appended range
  return UOp(Ops.APPENDINDEX, dtypes.void, (buf,mr))
# do the demotion
pm_demote = PatternMatcher([
(UPat(Ops.END, src=(UPat(Ops.END, name="e1"),), allow_any_len=True, name="e2"), lambda e1,e2: e1.replace(src=e1.src+e2.src[1:])),
(UPat(Ops.BUFFERIZE, name="x"), do_demote),
(UPat(Ops.INDEX, src=(UPat(Ops.APPENDINDEX, name="x"),), name="y", allow_any_len=True),
lambda x,y: y.replace(src=(x.src[0],)+x.src[1:]+y.src[1:])),
])
self.ast = graph_rewrite(self.ast.src[0].end(rr).sink(), pm_demote, ctx=[rr], bottom_up=True, name="demote")
elif opt.op is OptOps.TC:
check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: remove the need for this by having warps
check(opt.axis is not None, "tensor core opts must have an axis")

View File

@@ -59,6 +59,7 @@ class Ops(FastEnum):
# INDEX is a BinaryOp similar to ADD, but it operates on pointers
INDEX = auto()
APPENDINDEX = auto()
# BinaryOps
ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto() # noqa: E702

View File

@@ -188,7 +188,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.INDEX | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST:
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST | Ops.CONTRACT | Ops.APPENDINDEX:
return None
# some ops init the shape