readable COPY(VIEW) reordering [pr] (#10505)

* readable COPY(VIEW) reordering [pr]

* assert that

* spec

* resolve

* Revert "resolve"

This reverts commit f5629fbef8.

* arg
This commit is contained in:
qazal
2025-05-24 17:08:58 +03:00
committed by GitHub
parent bbb6deff53
commit dd5601af68
2 changed files with 6 additions and 4 deletions

View File

@@ -67,6 +67,10 @@ def split_reduceop(reduce:UOp, x:UOp):
# reduce original axes, then split
return splitted.r(*reduce.arg).r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape)
def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device)
return base.copy_to_device(copy.device).view(view.arg)
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT}
sym = symbolic_simple+PatternMatcher([
@@ -90,9 +94,7 @@ sym = symbolic_simple+PatternMatcher([
# non device changing COPY is a NOOP
(UPat(Ops.COPY, name="c", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda c,x: x if c.device == x.device else None),
# store a shrink before COPY, otherwise view after the COPY
(UPat(Ops.COPY, src=(UPat(Ops.VIEW, name="v"), UPat(Ops.DEVICE)), name="copy"), lambda copy,v:
v.contiguous().copy_to_device(copy.device, arg=copy.arg) if prod(v.shape) < prod(v.base.shape) else \
v.base.copy_to_device(copy.device, arg=copy.arg).view(v.st)),
(UPat(Ops.COPY, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"), UPat(Ops.DEVICE)), name="copy"), copy_reorder_view),
# remove cast to image when it's already a contiguous image
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),

View File

@@ -81,7 +81,7 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
lambda root,x: root.dtype == x.dtype),
# COPY/ALLREDUCE/MULTI
(UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda copy,x: copy.dtype == x.dtype),
(UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE)), arg=None), lambda copy,x: copy.dtype == x.dtype),
(UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)),
(UPat(Ops.MULTI, name="multi"), lambda multi: all(x.dtype == multi.dtype for x in multi.src) and isinstance(multi.arg, int)),
])