mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
reorder
This commit is contained in:
@@ -11,6 +11,11 @@ class TestRangeify(unittest.TestCase):
|
||||
ba = A.expand(N, N)
|
||||
((ba+1).sum(axis=1) + (ba+2).sum(axis=0)).realize()
|
||||
|
||||
def test_partial_contig(self):
|
||||
A = Tensor.empty(64, 64, 64)
|
||||
ret = A.sum(axis=2).contiguous(arg=(1,)).sum(axis=1)
|
||||
ret.realize()
|
||||
|
||||
def test_double_gemm(self):
|
||||
A = Tensor.empty(N, N)
|
||||
B = Tensor.empty(N, N)
|
||||
|
||||
@@ -72,14 +72,14 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
|
||||
# ** expander (expand_rewrite) **
|
||||
ret.append(RewriteStep(sym+migrate_indexing, name="initial symbolic"))
|
||||
|
||||
# expand
|
||||
ret.append(RewriteStep(sym+expander, name="expander"))
|
||||
# add locals
|
||||
ret.append(RewriteStep(pm_flatten_range+pm_add_buffers_local+rangeify_codegen, name="add local buffers"))
|
||||
|
||||
# add gpu dims (late). this also handles UNROLL range
|
||||
ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims"))
|
||||
|
||||
# add locals
|
||||
ret.append(RewriteStep(pm_flatten_range+pm_add_buffers_local+rangeify_codegen, name="add local buffers"))
|
||||
# expand
|
||||
ret.append(RewriteStep(sym+expander, name="expander"))
|
||||
|
||||
# ** devectorizer (full_graph_rewrite) **
|
||||
# remove reduce
|
||||
|
||||
@@ -86,7 +86,7 @@ expander = PatternMatcher([
|
||||
(UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
|
||||
lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
|
||||
# do expansion
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, #Ops.BUFFERIZE,
|
||||
Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
|
||||
(UPat(Ops.CONTRACT, name="con"), do_contract),
|
||||
# BARRIERs aren't actually expanded
|
||||
|
||||
@@ -207,8 +207,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
ret: dict[UOp, None] = {}
|
||||
if self.op in range_start.keys():
|
||||
for s in self.src[:range_start[self.op]]: ret.update(s.ranges)
|
||||
for s in self.src[range_start[self.op]:]:
|
||||
if s in ret: del ret[s]
|
||||
delete_ranges = self.src[range_start[self.op]:]
|
||||
if len(delete_ranges):
|
||||
for s in UOp.sink(*delete_ranges).ranges:
|
||||
if s in ret: del ret[s]
|
||||
elif self.op in {Ops.BARRIER}:
|
||||
ret = {x:None for x in self.src[0].ranges if x.arg[1] != AxisType.LOCAL}
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user