mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
readable COPY(VIEW) reordering [pr] (#10505)
* readable COPY(VIEW) reordering [pr]
* assert that
* spec
* resolve
* Revert "resolve"
This reverts commit f5629fbef8.
* arg
This commit is contained in:
@@ -67,6 +67,10 @@ def split_reduceop(reduce:UOp, x:UOp):
|
||||
# reduce original axes, then split
|
||||
return splitted.r(*reduce.arg).r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape)
|
||||
|
||||
def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
|
||||
if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device)
|
||||
return base.copy_to_device(copy.device).view(view.arg)
|
||||
|
||||
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT}
|
||||
|
||||
sym = symbolic_simple+PatternMatcher([
|
||||
@@ -90,9 +94,7 @@ sym = symbolic_simple+PatternMatcher([
|
||||
# non device changing COPY is a NOOP
|
||||
(UPat(Ops.COPY, name="c", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda c,x: x if c.device == x.device else None),
|
||||
# store a shrink before COPY, otherwise view after the COPY
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.VIEW, name="v"), UPat(Ops.DEVICE)), name="copy"), lambda copy,v:
|
||||
v.contiguous().copy_to_device(copy.device, arg=copy.arg) if prod(v.shape) < prod(v.base.shape) else \
|
||||
v.base.copy_to_device(copy.device, arg=copy.arg).view(v.st)),
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"), UPat(Ops.DEVICE)), name="copy"), copy_reorder_view),
|
||||
# remove cast to image when it's already a contiguous image
|
||||
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
|
||||
lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
|
||||
|
||||
@@ -81,7 +81,7 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
|
||||
lambda root,x: root.dtype == x.dtype),
|
||||
|
||||
# COPY/ALLREDUCE/MULTI
|
||||
(UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda copy,x: copy.dtype == x.dtype),
|
||||
(UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE)), arg=None), lambda copy,x: copy.dtype == x.dtype),
|
||||
(UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)),
|
||||
(UPat(Ops.MULTI, name="multi"), lambda multi: all(x.dtype == multi.dtype for x in multi.src) and isinstance(multi.arg, int)),
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user