mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
add locals support to rangeify (#11826)
This commit is contained in:
@@ -5,7 +5,7 @@ from tinygrad.dtype import AddrSpace
|
||||
from tinygrad.helpers import getenv, colored, prod, unwrap
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
from tinygrad.codegen.opt.kernel import axis_colors
|
||||
from tinygrad.codegen.opt.kernel import axis_colors, Opt, OptOps
|
||||
from tinygrad.codegen.opt.swizzler import merge_views, view_left
|
||||
|
||||
def to_colored(full_shape, axis_types):
  """Format a shape as a '_'-separated string with one colored entry per axis.

  Each size is converted with str() and passed to colored() together with the
  axis_colors entry looked up by the matching axis type. Sizes and axis types
  are paired with zip(), so surplus entries in the longer argument are ignored.
  """
  parts = []
  for size, axis_type in zip(full_shape, axis_types):
    parts.append(colored(str(size), axis_colors[axis_type]))
  return '_'.join(parts)
|
||||
@@ -44,6 +44,21 @@ pm = PatternMatcher([
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
|
||||
])
|
||||
|
||||
def rangeify_kernel3():
  """Return the matmul kernel AST produced with RANGEIFY=1, tagged with a fixed opt list.

  Schedules ``a @ b`` for two empty N x N tensors inside ``Context(RANGEIFY=1)``,
  takes the ``ast`` of the last schedule item, and returns it with its ``arg``
  replaced by a ``KernelInfo`` whose ``opts_to_apply`` holds the opts below.
  """
  lhs, rhs = Tensor.empty(N,N), Tensor.empty(N,N)
  out = lhs@rhs
  # optional experiment: out = out.reshape((32,2,16,4,32,2,16,4)).contiguous()
  with Context(RANGEIFY=1):
    sink = out.schedule()[-1].ast
  # print(sink) here to inspect the scheduled AST

  # same sequence as originally written: axis 0 opts, then axis 1 opts, then the unroll
  hand_opts = (
    Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.UPCAST, 0, 2),
    Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.LOCAL, 1, 16), Opt(OptOps.UPCAST, 1, 2),
    Opt(OptOps.UNROLL, 0, 8),
  )
  return sink.replace(arg=KernelInfo(opts_to_apply=hand_opts))
|
||||
|
||||
def top_spec_kernel3():
|
||||
a = Tensor.empty(N,N)
|
||||
b = Tensor.empty(N,N)
|
||||
@@ -309,10 +324,15 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
||||
|
||||
if __name__ == "__main__":
|
||||
HL = getenv("HL")
|
||||
if HL == 2: hprg = top_spec_kernel3()
|
||||
if HL == 3: hprg = rangeify_kernel3()
|
||||
elif HL == 2: hprg = top_spec_kernel3()
|
||||
elif HL == 1: hprg = hl_spec_kernel3()
|
||||
else: hprg = hand_spec_kernel3()
|
||||
prg = get_program(hprg, Device.default.renderer)
|
||||
if HL == 3:
|
||||
with Context(RANGEIFY=1, BLOCK_REORDER=0):
|
||||
prg = get_program(hprg, Device.default.renderer)
|
||||
else:
|
||||
prg = get_program(hprg, Device.default.renderer)
|
||||
print(prg.src)
|
||||
if getenv("SRC"): exit(0)
|
||||
hrunner = CompiledRunner(prg)
|
||||
|
||||
Reference in New Issue
Block a user