mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-25 14:58:46 -05:00)

refactor ast arg and op [compare_schedule] (#6052)
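
In short: in tinygrad/engine/schedule.py, _recursive_lazyop no longer pre-extracts an arg = buf.arg local or wraps the op as cast(Op, buf.op) when building a LazyOp; it passes buf.op and buf.arg through directly, which also lets the Op and cast imports go. The MetaOps.CONST path binds its value with a walrus assignment (val:=buf.arg), and the ReduceOps path is restructured around a single reduce_info.get lookup so a merged reduce returns its recursed source early.
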
@@ -1,8 +1,8 @@
 import sys, pickle, atexit, importlib, contextlib
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
-from typing import Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args
-from tinygrad.ops import MetaOps, BufferOps, LazyOp, Op, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps, reduce_st
+from typing import Tuple, List, Dict, Optional, Set, DefaultDict, get_args
+from tinygrad.ops import MetaOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps, reduce_st
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \
   GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
@@ -55,7 +55,7 @@ def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer,
   """recursively create a lazyop"""
   if buf is not buf.base: st, buf = buf.st+st, buf.base
   if (buf, st) in cache: return cache[(buf, st)]
-  arg = buf.arg
   assert buf.op is not None, "base must be a base itself"

   # buffer ops define ShapeTracker
   if buf.realized is not None or (buf in realizes and buf not in outputs):
@@ -63,11 +63,11 @@ def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer,
     var_vals.update(st_var_vals)
     # if it's a const, we generate it
     if buf.op is MetaOps.CONST:
-      if isinstance(arg, Variable):
-        arg, var_val = arg.unbind()
-        var_vals[arg] = var_val
-      else: assert isinstance(arg, get_args(ConstType)), f"cannot create ConstBuffer with value {arg}"
-      return LazyOp(BufferOps.CONST, (), ConstBuffer(arg, buf.dtype, unbound_st))
+      if isinstance(val:=buf.arg, Variable):
+        val, var_val = val.unbind()
+        var_vals[val] = var_val
+      else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}"
+      return LazyOp(BufferOps.CONST, (), ConstBuffer(val, buf.dtype, unbound_st))
     # otherwise, it's a load and we add it to the inputs
     if buf in assign_targets:
       # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
@@ -80,18 +80,20 @@ def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer,

   # reduce ops change ShapeTracker
   if buf.op in ReduceOps:
+    rinfo = reduce_info.get((buf, st))
+    rsrc = _recursive_lazyop(buf.srcs[0], st:=(rinfo[0] if rinfo else st), outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
     # if we are merging the reduce, skip it
-    if (buf, st) not in reduce_info:
-      assert buf.srcs[0].base.op is buf.op, f"can't merge reduceop {buf.op} with {buf.srcs[0].base.op}\n{st}"
-      return _recursive_lazyop(buf.srcs[0], st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
-    st, arg = reduce_info[(buf, st)]
+    if rinfo is None:
+      assert rsrc.op is buf.op, f"can't merge reduceop {buf.op} with {rsrc.op}\n{st}"
+      return rsrc
+    return cache.setdefault((buf, st), LazyOp(buf.op, (rsrc,), rinfo[1]))

   # elementwise ops pass shapetracker
   in_ops = tuple(_recursive_lazyop(x, st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) for x in buf.srcs)
   if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
     assert buf in outputs, f"{buf.op} must be writable"
     return in_ops[0]
-  return cache.setdefault((buf, st), LazyOp(cast(Op, buf.op), in_ops, arg))
+  return cache.setdefault((buf, st), LazyOp(buf.op, in_ops, buf.arg))

 def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
   permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis) + axis
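
Both mechanical patterns behind this refactor can be shown in isolation. The sketch below is illustrative plain Python, not tinygrad code: reduce_info here, get_axis_before/get_axis_after, and describe are hypothetical stand-ins. It contrasts the old membership-test-then-index lookup with the new single dict.get whose None result marks a miss, plus the walrus binding the MetaOps.CONST path now uses.

from typing import Dict, Optional, Tuple

# hypothetical stand-in for the scheduler's reduce_info map:
# (buffer, shapetracker) -> (input shapetracker, reduce axis)
reduce_info: Dict[Tuple[str, str], Tuple[str, Tuple[int, ...]]] = {
  ("buf0", "st0"): ("st1", (0, 1)),
}

def get_axis_before(key: Tuple[str, str]) -> Optional[Tuple[int, ...]]:
  # old shape: membership test, then a second lookup to index the dict
  if key not in reduce_info: return None
  st, arg = reduce_info[key]
  return arg

def get_axis_after(key: Tuple[str, str]) -> Optional[Tuple[int, ...]]:
  # new shape: one dict.get, with None standing in for the missing entry
  rinfo = reduce_info.get(key)
  return rinfo[1] if rinfo is not None else None

def describe(src: Dict[str, object]) -> str:
  # walrus: bind and test in one expression, as the MetaOps.CONST path now does
  if isinstance(val := src.get("arg"), int): return f"const {val}"
  return f"not a const: {val!r}"

assert get_axis_before(("buf0", "st0")) == get_axis_after(("buf0", "st0")) == (0, 1)
assert get_axis_before(("bufX", "stX")) is None and get_axis_after(("bufX", "stX")) is None
assert describe({"arg": 3}) == "const 3" and describe({}) == "not a const: None"

The payoff in the diff is the same as here: the reduce branch reads reduce_info once and returns its LazyOp directly instead of falling through to the elementwise path via st and arg locals, and passing buf.op/buf.arg straight through is what lets the Op and cast imports disappear.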