refactor ast creation [compare_schedule] (#6050)

* refactor scheduler lazyop creation [compare_schedule]

* helpful prints

* this will become the default
This commit is contained in:
qazal
2024-08-13 02:15:11 +08:00
committed by GitHub
parent 8f787785d9
commit 529832d223

View File

@@ -57,20 +57,18 @@ def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer,
if (buf, st) in cache: return cache[(buf, st)]
arg = buf.arg
# consts are always fused and generated
if buf.op is MetaOps.CONST:
unbound_st, st_var_vals = st.simplify().unbind()
var_vals.update(st_var_vals)
if isinstance(arg, Variable):
arg, var_val = arg.unbind()
var_vals[arg] = var_val
else: assert isinstance(arg, get_args(ConstType)), f"cannot create ConstBuffer with value {arg}"
return LazyOp(BufferOps.CONST, (), ConstBuffer(arg, buf.dtype, unbound_st))
# if we aren't fusing it, it's a load and we add it to the inputs
# buffer ops define ShapeTracker
if buf.realized is not None or (buf in realizes and buf not in outputs):
unbound_st, st_var_vals = st.simplify().unbind()
var_vals.update(st_var_vals)
# if it's a const, we generate it
if buf.op is MetaOps.CONST:
if isinstance(arg, Variable):
arg, var_val = arg.unbind()
var_vals[arg] = var_val
else: assert isinstance(arg, get_args(ConstType)), f"cannot create ConstBuffer with value {arg}"
return LazyOp(BufferOps.CONST, (), ConstBuffer(arg, buf.dtype, unbound_st))
# otherwise, it's a load and we add it to the inputs
if buf in assign_targets:
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
if unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and\
@@ -80,12 +78,7 @@ def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer,
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.setdefault(buf, len(inputs)), buf.dtype, unbound_st))
# if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
assert buf in outputs
return _recursive_lazyop(buf.srcs[0], st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
# if it's a reduce, we have to change the shapetracker
# reduce ops change ShapeTracker
if buf.op in ReduceOps:
# if we are merging the reduce, skip it
if (buf, st) not in reduce_info:
@@ -93,9 +86,12 @@ def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer,
return _recursive_lazyop(buf.srcs[0], st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
st, arg = reduce_info[(buf, st)]
# otherwise we fuse it like normal
return cache.setdefault((buf, st), LazyOp(cast(Op,buf.op), tuple(_recursive_lazyop(x, st, outputs, var_vals, inputs, realizes, assign_targets, \
reduce_info, cache) for x in buf.srcs), arg))
# elementwise ops pass shapetracker
in_ops = tuple(_recursive_lazyop(x, st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) for x in buf.srcs)
if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
assert buf in outputs, f"{buf.op} must be writable"
return in_ops[0]
return cache.setdefault((buf, st), LazyOp(cast(Op, buf.op), in_ops, arg))
def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis) + axis