rename to BUFFER_VIEW + MetaOps cleanup (#6953)

This commit is contained in:
qazal
2024-10-08 15:09:22 +03:00
committed by GitHub
parent 1ff2c98f8a
commit 851f39653a
3 changed files with 6 additions and 6 deletions

View File

@@ -201,7 +201,7 @@ def lower_schedule_item(si:ScheduleItem) -> ExecItem:
return ExecItem(kernel_type(arg, out.device, si.inputs[0].device), list(si.bufs))
if si.ast.op is UOps.CUSTOM: return ExecItem(CustomOp(arg), list(si.bufs))
if si.ast.op is UOps.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
if si.ast.op is UOps.VIEW: return ExecItem(ViewOp(out), list(si.bufs))
if si.ast.op is UOps.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs))
raise RuntimeError(f"don't know how to lower {si.ast}")
def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:

View File

@@ -17,6 +17,7 @@ from tinygrad.shape.view import View, strides_for_shape
sys.setrecursionlimit(10000)
BUF_LIMIT = {"METAL": 32}
METAOPS = {MetaOps.CUSTOM:UOps.CUSTOM, MetaOps.COPY:UOps.COPY, MetaOps.EMPTY:UOps.EMPTY, MetaOps.VIEW:UOps.BUFFER_VIEW}
# *** ScheduleItem return type ***
@@ -177,10 +178,9 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp]) -> Tuple[LBScheduleItem, Dict[Variable, int]]:
"""describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
if (out:=outs[0]).op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
metadata = (out.metadata,) if out.metadata is not None else None
uop = {MetaOps.CUSTOM: UOps.CUSTOM, MetaOps.COPY: UOps.COPY, MetaOps.EMPTY: UOps.EMPTY, MetaOps.VIEW: UOps.VIEW}[cast(MetaOps, out.op)]
return LBScheduleItem(UOp(uop, out.dtype, (), out.arg), (out,)+tuple(x.base for x in out.srcs), metadata), {}
if (out:=outs[0]).op in METAOPS:
return LBScheduleItem(UOp(METAOPS[cast(MetaOps, out.op)], out.dtype, (), out.arg), (out,)+tuple(x.base for x in out.srcs),
(out.metadata,) if out.metadata is not None else None), {}
# create the stores
var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}

View File

@@ -104,7 +104,7 @@ class UOps(FastEnum):
CUSTOM = auto()
COPY = auto()
EMPTY = auto()
VIEW = auto()
BUFFER_VIEW = auto()
EXPAND = auto()
CONTRACT = auto()