mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
warp
This commit is contained in:
@@ -104,9 +104,16 @@ class TestRangeify(unittest.TestCase):
|
||||
# NOTE: these axes are poorly sorted
|
||||
args += (Opt(OptOps.TC, 0, (0,0,1,1)),)
|
||||
args += (Opt(OptOps.TC, 0, (0,0,1,0)),)
|
||||
|
||||
args += (Opt(OptOps.UPCAST, 0, 2),)
|
||||
args += (Opt(OptOps.UPCAST, 1, 2),)
|
||||
args += (Opt(OptOps.UNROLL, 0, 2),)
|
||||
args += (Opt(OptOps.UNROLL, 1, 2),)
|
||||
|
||||
tst = (A@B@C).contiguous(arg=args).realize()
|
||||
assert tst.uop.base.op is Ops.BUFFER, "buffer"
|
||||
with Context(RANGEIFY=0, DEBUG=0):
|
||||
with Context(RANGEIFY=0, DEBUG=2):
|
||||
GlobalCounters.reset()
|
||||
mse = ((A@B@C)-tst).square().mean().item()
|
||||
print(mse)
|
||||
|
||||
@@ -241,6 +248,8 @@ class TestRangeify(unittest.TestCase):
|
||||
args += (Opt(OptOps.DEMOTE, 5, 8),)
|
||||
args += (Opt(OptOps.TC, 0, (0,0,1,3)),)
|
||||
args += (Opt(OptOps.TC, 0, (0,0,1,0)),)
|
||||
args += (Opt(OptOps.WARP, 1, 32),)
|
||||
args += (Opt(OptOps.WARP, 2, 32),)
|
||||
ret = fa().contiguous(arg=args).realize()
|
||||
with Context(RANGEIFY=0):
|
||||
with Context(DEBUG=2):
|
||||
|
||||
@@ -7,7 +7,7 @@ from tinygrad.uop.ops import AxisType
|
||||
class OptOps(Enum):
|
||||
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto(); THREAD = auto() # noqa: E702
|
||||
GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
|
||||
DEMOTE = auto()
|
||||
DEMOTE = auto(); WARP = auto() # noqa: E702
|
||||
def __lt__(self, x:OptOps): return self.value < x.value
|
||||
|
||||
@dataclass(frozen=True, order=True)
|
||||
|
||||
@@ -172,6 +172,9 @@ class Scheduler:
|
||||
check(not self.dont_use_locals, "can't use locals")
|
||||
check(rng.arg[-1] == AxisType.REDUCE, "group is for reduce")
|
||||
ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op in {OptOps.GROUPTOP, OptOps.THREAD})
|
||||
elif opt.op is OptOps.WARP:
|
||||
warp = UOp.range(cast(int, opt.arg), -1, AxisType.WARP)
|
||||
ret = self.shift_to(rng, cast(int, opt.arg), AxisType.WARP, input_new_rng=warp)
|
||||
elif opt.op is OptOps.DEMOTE:
|
||||
_, rr = self.shift_to(rng, cast(int, opt.arg), AxisType.LOOP)
|
||||
def do_demote(ctx, x:UOp):
|
||||
|
||||
Reference in New Issue
Block a user