mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
break out metaops (#6948)
This commit is contained in:
@@ -4,7 +4,7 @@ from collections import defaultdict
|
||||
from dataclasses import dataclass, replace
|
||||
from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA, dedup
|
||||
from tinygrad.helpers import NO_MEMORY_PLANNER
|
||||
from tinygrad.ops import MetaOps, UOps, UOp
|
||||
from tinygrad.ops import UOps, UOp
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer, sint
|
||||
@@ -188,19 +188,19 @@ class ExecItem:
|
||||
return et
|
||||
|
||||
def lower_schedule_item(si:ScheduleItem) -> ExecItem:
|
||||
assert len(set(x.device for x in si.bufs)) == 1 or (si.ast.op is UOps.EXT and si.ast.arg[0] is MetaOps.COPY)
|
||||
assert len(set(x.device for x in si.bufs)) == 1 or si.ast.op is UOps.COPY
|
||||
if si.ast.op is UOps.SINK:
|
||||
runner = get_runner(si.outputs[0].device, si.ast)
|
||||
return ExecItem(runner, [si.bufs[x] for x in runner.p.globals], si.metadata)
|
||||
out, (op, arg) = si.outputs[0], si.ast.arg
|
||||
if op is MetaOps.COPY:
|
||||
out, arg = si.outputs[0], si.ast.arg
|
||||
if si.ast.op is UOps.COPY:
|
||||
kernel_type = BufferCopy
|
||||
if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
|
||||
kernel_type = BufferXfer
|
||||
return ExecItem(kernel_type(arg, out.device, si.inputs[0].device), list(si.bufs))
|
||||
if op is MetaOps.CUSTOM: return ExecItem(CustomOp(arg), list(si.bufs))
|
||||
if op is MetaOps.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
|
||||
if op is MetaOps.VIEW: return ExecItem(ViewOp(out), list(si.bufs))
|
||||
if si.ast.op is UOps.CUSTOM: return ExecItem(CustomOp(arg), list(si.bufs))
|
||||
if si.ast.op is UOps.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
|
||||
if si.ast.op is UOps.VIEW: return ExecItem(ViewOp(out), list(si.bufs))
|
||||
raise RuntimeError(f"don't know how to lower {si.ast}")
|
||||
|
||||
def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
|
||||
|
||||
@@ -179,7 +179,8 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp]) -> Tupl
|
||||
"""describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
|
||||
if (out:=outs[0]).op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
|
||||
metadata = (out.metadata,) if out.metadata is not None else None
|
||||
return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), (out,)+tuple(x.base for x in out.srcs), metadata), {}
|
||||
uop = {MetaOps.CUSTOM: UOps.CUSTOM, MetaOps.COPY: UOps.COPY, MetaOps.EMPTY: UOps.EMPTY, MetaOps.VIEW: UOps.VIEW}[cast(MetaOps, out.op)]
|
||||
return LBScheduleItem(UOp(uop, out.dtype, (), out.arg), (out,)+tuple(x.base for x in out.srcs), metadata), {}
|
||||
# create the stores
|
||||
var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
|
||||
assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
|
||||
|
||||
@@ -99,7 +99,13 @@ def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.
|
||||
class UOps(FastEnum):
|
||||
# uops that aren't rendered
|
||||
SINK = auto()
|
||||
EXT = auto()
|
||||
|
||||
# metaops
|
||||
CUSTOM = auto()
|
||||
COPY = auto()
|
||||
EMPTY = auto()
|
||||
VIEW = auto()
|
||||
|
||||
EXPAND = auto()
|
||||
CONTRACT = auto()
|
||||
SHAPETRACKER = auto()
|
||||
|
||||
Reference in New Issue
Block a user