refactor ast arg and op [compare_schedule] (#6052)

Author: qazal
Date:   2024-08-13 02:51:00 +08:00 (committed by GitHub)
parent dc2617bffd
commit 71c5901fc1


@@ -1,8 +1,8 @@
 import sys, pickle, atexit, importlib, contextlib
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
-from typing import Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args
-from tinygrad.ops import MetaOps, BufferOps, LazyOp, Op, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps, reduce_st
+from typing import Tuple, List, Dict, Optional, Set, DefaultDict, get_args
+from tinygrad.ops import MetaOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps, reduce_st
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \
   GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
@@ -55,7 +55,7 @@ def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer,
   """recursively create a lazyop"""
   if buf is not buf.base: st, buf = buf.st+st, buf.base
   if (buf, st) in cache: return cache[(buf, st)]
-  arg = buf.arg
+  assert buf.op is not None, "base must be a base itself"
   # buffer ops define ShapeTracker
   if buf.realized is not None or (buf in realizes and buf not in outputs):
     unbound_st, st_var_vals = st.simplify().unbind()
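The `assert buf.op is not None` added here is what lets hunk 1 drop `cast` and `Op` from the imports: after the assert, a type checker narrows `Optional[Op]` to `Op`, so the final return can construct `LazyOp(buf.op, ...)` directly instead of `LazyOp(cast(Op, buf.op), ...)`. A minimal standalone sketch of that assert-based narrowing, with a hypothetical `Node`/`emit` rather than tinygrad's real classes:

from typing import Optional

class Node:
  def __init__(self, op: Optional[str]): self.op = op

def emit(n: Node) -> str:
  # mypy-style checkers narrow n.op from Optional[str] to str after this assert,
  # so no typing.cast(str, n.op) is needed at the use site below
  assert n.op is not None, "base must be a base itself"
  return n.op.upper()

print(emit(Node("add")))  # prints: ADD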
@@ -63,11 +63,11 @@ def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer,
     var_vals.update(st_var_vals)
     # if it's a const, we generate it
     if buf.op is MetaOps.CONST:
-      if isinstance(arg, Variable):
-        arg, var_val = arg.unbind()
-        var_vals[arg] = var_val
-      else: assert isinstance(arg, get_args(ConstType)), f"cannot create ConstBuffer with value {arg}"
-      return LazyOp(BufferOps.CONST, (), ConstBuffer(arg, buf.dtype, unbound_st))
+      if isinstance(val:=buf.arg, Variable):
+        val, var_val = val.unbind()
+        var_vals[val] = var_val
+      else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}"
+      return LazyOp(BufferOps.CONST, (), ConstBuffer(val, buf.dtype, unbound_st))
     # otherwise, it's a load and we add it to the inputs
     if buf in assign_targets:
       # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
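Together with the removal of the hoisted `arg = buf.arg` in hunk 2, this hunk reads `buf.arg` at its point of use via the walrus operator, binding and testing in one expression. A minimal sketch of that bind-and-test pattern, using a stand-in `Variable` whose `unbind` mimics the (unbound variable, value) pair returned by tinygrad's symbolic `Variable`; `Buf` and `const_val` are hypothetical:

class Variable:
  def __init__(self, name, val=None): self.name, self.val = name, val
  def bind(self, val): return Variable(self.name, val)
  def unbind(self): return Variable(self.name), self.val  # (unbound var, bound value)

class Buf:
  def __init__(self, arg): self.arg = arg

def const_val(buf: Buf, var_vals: dict):
  # val:=buf.arg assigns and tests in a single expression, replacing the hoisted local
  if isinstance(val:=buf.arg, Variable):
    val, var_val = val.unbind()
    var_vals[val] = var_val
  else: assert isinstance(val, (int, float)), f"cannot create ConstBuffer with value {val}"
  return val

vv: dict = {}
print(const_val(Buf(3.0), vv))                         # 3.0
print(const_val(Buf(Variable("i").bind(4)), vv).name)  # i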
@@ -80,18 +80,20 @@ def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer,
   # reduce ops change ShapeTracker
   if buf.op in ReduceOps:
+    rinfo = reduce_info.get((buf, st))
+    rsrc = _recursive_lazyop(buf.srcs[0], st:=(rinfo[0] if rinfo else st), outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
     # if we are merging the reduce, skip it
-    if (buf, st) not in reduce_info:
-      assert buf.srcs[0].base.op is buf.op, f"can't merge reduceop {buf.op} with {buf.srcs[0].base.op}\n{st}"
-      return _recursive_lazyop(buf.srcs[0], st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
-    st, arg = reduce_info[(buf, st)]
+    if rinfo is None:
+      assert rsrc.op is buf.op, f"can't merge reduceop {buf.op} with {rsrc.op}\n{st}"
+      return rsrc
+    return cache.setdefault((buf, st), LazyOp(buf.op, (rsrc,), rinfo[1]))
   # elementwise ops pass shapetracker
   in_ops = tuple(_recursive_lazyop(x, st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) for x in buf.srcs)
   if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
     assert buf in outputs, f"{buf.op} must be writable"
     return in_ops[0]
-  return cache.setdefault((buf, st), LazyOp(cast(Op, buf.op), in_ops, arg))
+  return cache.setdefault((buf, st), LazyOp(buf.op, in_ops, buf.arg))
 
 def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
   permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis) + axis
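In the reduce branch, the refactor replaces the `in` check plus indexing of `reduce_info` with a single `dict.get`, and recurses into the source exactly once up front; the merge path then returns `rsrc` unchanged, while the realize path wraps the same `rsrc` with the cached reduce arg. A minimal sketch of that fetch-once control flow, with hypothetical names (`info`, `build`, `recurse`) and tuples standing in for `LazyOp` nodes:

from typing import Any, Dict, Tuple

info: Dict[str, Tuple[str, Any]] = {"k": ("new_st", (0, 1))}

def recurse(st: str): return ("SRC", st)

def build(key: str, st: str):
  rinfo = info.get(key)                            # one lookup instead of `in` + `[...]`
  src = recurse(st:=(rinfo[0] if rinfo else st))   # recurse once; st may be swapped first
  if rinfo is None: return src                     # merging: pass the source through
  return ("REDUCE", src, rinfo[1])                 # otherwise wrap it with the cached arg

print(build("k", "old_st"))  # ('REDUCE', ('SRC', 'new_st'), (0, 1))
print(build("x", "old_st"))  # ('SRC', 'old_st')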