This commit is contained in:
George Hotz
2025-08-26 17:10:06 -07:00
parent 8b067e5dca
commit 78e092d59d
4 changed files with 14 additions and 7 deletions

View File

@@ -11,6 +11,11 @@ class TestRangeify(unittest.TestCase):
ba = A.expand(N, N)
((ba+1).sum(axis=1) + (ba+2).sum(axis=0)).realize()
# Smoke test: builds a reduce -> contiguous(arg=(1,)) -> reduce chain and realizes it.
# There are no assertions on the result; the test passes if realize() does not raise.
def test_partial_contig(self):
# Input created with Tensor.empty, shape (64, 64, 64).
A = Tensor.empty(64, 64, 64)
# Sum over axis 2, call contiguous with arg=(1,), then sum over axis 1 of that result.
# NOTE(review): presumably arg=(1,) asks for contiguity along axis 1 only (the
# "partial" in the test name) -- confirm against Tensor.contiguous.
ret = A.sum(axis=2).contiguous(arg=(1,)).sum(axis=1)
# Realize the final tensor.
ret.realize()
def test_double_gemm(self):
A = Tensor.empty(N, N)
B = Tensor.empty(N, N)

View File

@@ -72,14 +72,14 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
# ** expander (expand_rewrite) **
ret.append(RewriteStep(sym+migrate_indexing, name="initial symbolic"))
# expand
ret.append(RewriteStep(sym+expander, name="expander"))
# add locals
ret.append(RewriteStep(pm_flatten_range+pm_add_buffers_local+rangeify_codegen, name="add local buffers"))
# add gpu dims (late). this also handles UNROLL range
ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims"))
# add locals
ret.append(RewriteStep(pm_flatten_range+pm_add_buffers_local+rangeify_codegen, name="add local buffers"))
# expand
ret.append(RewriteStep(sym+expander, name="expander"))
# ** devectorizer (full_graph_rewrite) **
# remove reduce

View File

@@ -86,7 +86,7 @@ expander = PatternMatcher([
(UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
# do expansion
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, #Ops.BUFFERIZE,
Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
(UPat(Ops.CONTRACT, name="con"), do_contract),
# BARRIERs aren't actually expanded

View File

@@ -207,8 +207,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
ret: dict[UOp, None] = {}
if self.op in range_start.keys():
for s in self.src[:range_start[self.op]]: ret.update(s.ranges)
for s in self.src[range_start[self.op]:]:
if s in ret: del ret[s]
delete_ranges = self.src[range_start[self.op]:]
if len(delete_ranges):
for s in UOp.sink(*delete_ranges).ranges:
if s in ret: del ret[s]
elif self.op in {Ops.BARRIER}:
ret = {x:None for x in self.src[0].ranges if x.arg[1] != AxisType.LOCAL}
else: