This commit is contained in:
George Hotz
2025-10-02 18:04:45 +08:00
parent 3fb3dd4c06
commit dae164ffb1
3 changed files with 14 additions and 2 deletions

View File

@@ -104,9 +104,16 @@ class TestRangeify(unittest.TestCase):
# NOTE: these axes are poorly sorted
args += (Opt(OptOps.TC, 0, (0,0,1,1)),)
args += (Opt(OptOps.TC, 0, (0,0,1,0)),)
args += (Opt(OptOps.UPCAST, 0, 2),)
args += (Opt(OptOps.UPCAST, 1, 2),)
args += (Opt(OptOps.UNROLL, 0, 2),)
args += (Opt(OptOps.UNROLL, 1, 2),)
tst = (A@B@C).contiguous(arg=args).realize()
assert tst.uop.base.op is Ops.BUFFER, "buffer"
with Context(RANGEIFY=0, DEBUG=0):
with Context(RANGEIFY=0, DEBUG=2):
GlobalCounters.reset()
mse = ((A@B@C)-tst).square().mean().item()
print(mse)
@@ -241,6 +248,8 @@ class TestRangeify(unittest.TestCase):
args += (Opt(OptOps.DEMOTE, 5, 8),)
args += (Opt(OptOps.TC, 0, (0,0,1,3)),)
args += (Opt(OptOps.TC, 0, (0,0,1,0)),)
args += (Opt(OptOps.WARP, 1, 32),)
args += (Opt(OptOps.WARP, 2, 32),)
ret = fa().contiguous(arg=args).realize()
with Context(RANGEIFY=0):
with Context(DEBUG=2):

View File

@@ -7,7 +7,7 @@ from tinygrad.uop.ops import AxisType
class OptOps(Enum):
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto(); THREAD = auto() # noqa: E702
GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
DEMOTE = auto()
DEMOTE = auto(); WARP = auto() # noqa: E702
def __lt__(self, x:OptOps): return self.value < x.value
@dataclass(frozen=True, order=True)

View File

@@ -172,6 +172,9 @@ class Scheduler:
check(not self.dont_use_locals, "can't use locals")
check(rng.arg[-1] == AxisType.REDUCE, "group is for reduce")
ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op in {OptOps.GROUPTOP, OptOps.THREAD})
elif opt.op is OptOps.WARP:
warp = UOp.range(cast(int, opt.arg), -1, AxisType.WARP)
ret = self.shift_to(rng, cast(int, opt.arg), AxisType.WARP, input_new_rng=warp)
elif opt.op is OptOps.DEMOTE:
_, rr = self.shift_to(rng, cast(int, opt.arg), AxisType.LOOP)
def do_demote(ctx, x:UOp):