mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
init folding changes from the tensor_map branch [pr] (#8666)
* init folding changes from the tensor_map branch [pr] * add ops_folding to the viz rewrite
This commit is contained in:
@@ -90,7 +90,7 @@ class ScheduleContext:
|
||||
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
|
||||
def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
|
||||
if (r:=cache.get(buf)) is not None: return r
|
||||
if buf.op is Ops.SINK: return UOp.sink(*[add_buffers(x, ctx, cache) for x in buf.src])
|
||||
if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, ctx, cache) for x in buf.src))
|
||||
# shapeless op is passthrough
|
||||
# realized is passthrough
|
||||
# constants are passthrough
|
||||
@@ -369,7 +369,8 @@ def replace_contiguous(ctx:ScheduleContext, alu:UOp):
|
||||
|
||||
ops_folding = symbolic_simple+PatternMatcher([
|
||||
# op with size 0 is zero
|
||||
(UPatScheduled(), lambda b,to_store,base: base.const_like(0) if base.size == 0 else None),
|
||||
(UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
|
||||
and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
|
||||
# if the uop folded to a CONST we can delete the BUFFER
|
||||
(UPatScheduled(Ops.CONST, name="const"), lambda b,base,const: base.const_like(const.const_arg)),
|
||||
# DETACH is a NOOP here
|
||||
@@ -382,8 +383,8 @@ ops_folding = symbolic_simple+PatternMatcher([
|
||||
# CONST doesn't need COPY
|
||||
(UPat(Ops.COPY, src=(UPat(), UPat.cvar("x"),)), lambda x: x),
|
||||
# no COPY to same device, except clone (arg is True)
|
||||
(UPatScheduled(Ops.COPY, src=(UPat(), UPat(Ops.VIEW, name="copyin")), name="copy"),
|
||||
lambda base,b,copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
|
||||
(UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
|
||||
lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
|
||||
# support for using a contiguous permuted view instead of the parent view if one exists
|
||||
(UPatScheduled(Ops.CONTIGUOUS, name="contig"), found_contiguous),
|
||||
(UPat(GroupOp.ALU, name="alu"), replace_contiguous),
|
||||
@@ -521,7 +522,7 @@ remove_movement_ops = PatternMatcher([
|
||||
def create_schedule_with_vars(outs:list[UOp], skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
|
||||
big_sink = UOp.sink(*outs)
|
||||
# if using VIZ, do a graph rewrite to vizualize the Tensor graph
|
||||
if getenv("VIZ"): graph_rewrite(big_sink, remove_movement_ops)
|
||||
if getenv("VIZ"): graph_rewrite(big_sink, remove_movement_ops+ops_folding)
|
||||
if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec)
|
||||
# to_uop is removing (many) of the movement ops
|
||||
sink = add_buffers(big_sink, ctx:=ScheduleContext(), cache={})
|
||||
|
||||
Reference in New Issue
Block a user