mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-26 07:18:40 -05:00)
refactor ast creation [compare_schedule] (#6050)
* refactor scheduler lazyop creation [compare_schedule]
* helpful prints
* this will become the default
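For orientation before the diff: the refactor moves CONST handling inside the buffer-op branch and drops CONTIGUOUS/ASSIGN only after their source has been recursed, so everything shares one exit path. A condensed, runnable toy of the new control flow (all names below are hypothetical stand-ins, not tinygrad's scheduler types):

# A toy sketch of the new control flow; Buf, CONST, etc. are stand-ins.
from dataclasses import dataclass
from typing import Any, Tuple

CONST, CONTIGUOUS, ASSIGN, LOAD = "CONST", "CONTIGUOUS", "ASSIGN", "LOAD"

@dataclass(frozen=True)
class Buf:
  op: str
  srcs: Tuple["Buf", ...] = ()
  arg: Any = None
  realized: bool = False

def recurse(buf: Buf, st: str) -> tuple:
  # buffer ops define ShapeTracker: realized buffers become LOADs, and a
  # CONST is now generated inline as one case of the same branch
  if buf.realized or buf.op == CONST:
    return (buf.op if buf.op == CONST else LOAD, st, buf.arg)
  # elementwise ops pass the shapetracker down to their sources
  in_ops = tuple(recurse(s, st) for s in buf.srcs)
  # CONTIGUOUS/ASSIGN are dropped only after their source is built
  if buf.op in (CONTIGUOUS, ASSIGN): return in_ops[0]
  return (buf.op, in_ops, buf.arg)

print(recurse(Buf("ADD", (Buf(CONST, arg=1.0), Buf(LOAD, realized=True))), "st"))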
@@ -57,20 +57,18 @@ def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer,
   if (buf, st) in cache: return cache[(buf, st)]
   arg = buf.arg
 
-  # consts are always fused and generated
-  if buf.op is MetaOps.CONST:
-    unbound_st, st_var_vals = st.simplify().unbind()
-    var_vals.update(st_var_vals)
-    if isinstance(arg, Variable):
-      arg, var_val = arg.unbind()
-      var_vals[arg] = var_val
-    else: assert isinstance(arg, get_args(ConstType)), f"cannot create ConstBuffer with value {arg}"
-    return LazyOp(BufferOps.CONST, (), ConstBuffer(arg, buf.dtype, unbound_st))
-
-  # if we aren't fusing it, it's a load and we add it to the inputs
+  # buffer ops define ShapeTracker
   if buf.realized is not None or (buf in realizes and buf not in outputs):
     unbound_st, st_var_vals = st.simplify().unbind()
     var_vals.update(st_var_vals)
+    # if it's a const, we generate it
+    if buf.op is MetaOps.CONST:
+      if isinstance(arg, Variable):
+        arg, var_val = arg.unbind()
+        var_vals[arg] = var_val
+      else: assert isinstance(arg, get_args(ConstType)), f"cannot create ConstBuffer with value {arg}"
+      return LazyOp(BufferOps.CONST, (), ConstBuffer(arg, buf.dtype, unbound_st))
+    # otherwise, it's a load and we add it to the inputs
     if buf in assign_targets:
       # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
       if unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and \
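In the CONST path above, `unbind` splits a bound symbolic value into the bare variable and its concrete value, which is stashed in `var_vals` for launch time. A toy illustration of that pattern (hypothetical `Var` class, not tinygrad's `Variable`):

# Toy stand-in for the Variable unbind pattern; not tinygrad's Variable.
from typing import Dict, Optional, Tuple

class Var:
  def __init__(self, name: str, val: Optional[int] = None): self.name, self.val = name, val
  def unbind(self) -> Tuple["Var", int]:
    # split "a variable bound to a value" into (bare variable, concrete value)
    assert self.val is not None, f"{self.name} is not bound"
    return Var(self.name), self.val

var_vals: Dict[str, int] = {}
arg = Var("i", 3)               # a const whose value is a bound symbolic variable
if isinstance(arg, Var):
  arg, var_val = arg.unbind()   # the AST keeps the bare variable...
  var_vals[arg.name] = var_val  # ...and the binding is collected for launch time
print(var_vals)                 # {'i': 3}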
@@ -80,12 +78,7 @@ def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer,
                            +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
     return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.setdefault(buf, len(inputs)), buf.dtype, unbound_st))
 
-  # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
-  if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
-    assert buf in outputs
-    return _recursive_lazyop(buf.srcs[0], st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
-
-  # if it's a reduce, we have to change the shapetracker
+  # reduce ops change ShapeTracker
   if buf.op in ReduceOps:
     # if we are merging the reduce, skip it
     if (buf, st) not in reduce_info:
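The RuntimeError context at the top of this hunk points at the user-facing fix. A minimal repro following the hint in the error text (tinygrad's in-place operators assign back into the same buffer; exact message wording may vary by version):

from tinygrad import Tensor

a = Tensor.rand(4, 4).realize()
# a += a.T            # raises: the assign target is read through a permuted (non-contiguous) view
a += a.T.contiguous() # realizing the transpose into its own buffer first makes the read legal
a.realize()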
@@ -93,9 +86,12 @@ def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer,
       return _recursive_lazyop(buf.srcs[0], st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
     st, arg = reduce_info[(buf, st)]
 
-  # otherwise we fuse it like normal
-  return cache.setdefault((buf, st), LazyOp(cast(Op,buf.op), tuple(_recursive_lazyop(x, st, outputs, var_vals, inputs, realizes, assign_targets, \
-    reduce_info, cache) for x in buf.srcs), arg))
+  # elementwise ops pass shapetracker
+  in_ops = tuple(_recursive_lazyop(x, st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) for x in buf.srcs)
+  if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
+    assert buf in outputs, f"{buf.op} must be writable"
+    return in_ops[0]
+  return cache.setdefault((buf, st), LazyOp(cast(Op, buf.op), in_ops, arg))
 
 def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
   permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis) + axis
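The `cache.setdefault((buf, st), ...)` on the shared exit path memoizes the traversal, so a node reached through multiple paths of the DAG is lowered exactly once. A generic illustration of that pattern (toy node type, not the scheduler's):

# Generic illustration of keyed memoization over a DAG; toy types, not tinygrad's.
from typing import Dict

class Node:
  def __init__(self, op: str, *srcs: "Node"): self.op, self.srcs = op, srcs

def build(node: Node, cache: Dict[int, tuple]) -> tuple:
  # without the cache, a diamond-shaped DAG is re-lowered once per path (exponential)
  if id(node) in cache: return cache[id(node)]
  return cache.setdefault(id(node), (node.op, tuple(build(s, cache) for s in node.srcs)))

shared = Node("load")
root = Node("add", Node("exp", shared), Node("sin", shared))
print(build(root, {}))  # the "load" subtree is built once and reused by both branches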