diff --git a/test/test_schedule.py b/test/test_schedule.py index 8cab22f4fa..dadca5b2c1 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1987,7 +1987,7 @@ class TestBigGraph(unittest.TestCase): check_schedule(x, 1) tensor_const_pm = PatternMatcher([ - (UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)))), lambda: True), + (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.CONST, src=()))), lambda: True), (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST))))), lambda: True), ]) class TestConst(unittest.TestCase): diff --git a/test/unit/test_tensor_uop_representation.py b/test/unit/test_tensor_uop_representation.py index 532e7a4582..999f22b39d 100644 --- a/test/unit/test_tensor_uop_representation.py +++ b/test/unit/test_tensor_uop_representation.py @@ -3,7 +3,7 @@ from tinygrad import Tensor from tinygrad.ops import UPat, Ops realized_pattern = UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)) -const_pattern = UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)))) +const_pattern = UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.CONST))) def is_pattern(ten:Tensor, pat:UPat): assert pat.match(ten.lazydata, {}) class TestTensorUopRepresentation(unittest.TestCase): diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 245e6bab99..7d2b23a913 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -5,7 +5,7 @@ from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_ from tinygrad.ops import identity_element, buffers, exec_alu, type_verify from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, ContextVar -from tinygrad.dtype import DType, ImageDType, dtypes +from tinygrad.dtype import ConstType, DType, ImageDType, dtypes from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View, strides_for_shape from tinygrad.device import Buffer @@ -39,9 +39,6 @@ tensor_uop_spec = PatternMatcher([ # Tensor variable bindings (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True), - # Tensor const has a ShapeTracker of shape=() and a device - (UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)),)), lambda: True), - # DETACH and CONTIGUOUS change how we interpret the source UOp # CONTIGUOUS ensures the source UOp realizes (UPat((Ops.DETACH, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),), arg=None), lambda root,x: root.dtype == x.dtype), @@ -79,6 +76,10 @@ tensor_uop_spec = PatternMatcher([ # DEVICE and VIEW specify device and shape for BIND (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND))), lambda: True), + # Tensor const has a ShapeTracker of shape=() and a device + (UPat(Ops.VIEW, name="view", arg=ShapeTracker.from_shape(()), src=(UPat(Ops.DEVICE), UPat(Ops.CONST, name="const"))), + lambda view,const: view.dtype == const.dtype), + # NOTE: EMPTY just ensures the source BUFFER is allocated before children run # TODO: this should be EMPTY(VIEW(BUFFER)) (UPat(Ops.EMPTY, src=(), arg=None), lambda: True), @@ -127,9 +128,9 @@ class ScheduleContext: contiguous: dict[UOp, UOp] = field(default_factory=dict) # this maps roots to places they are made contiguous children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict)) -# TODO: delete this once BIND has a VIEW source -# currently tensor BIND is VIEW(DEVICE, BIND) - CONST(VIEW(DEVICE)) is a prereq for this -def is_constant(u:UOp): return u.op is Ops.CONST or (u.op is Ops.VIEW and len(u.src) == 2 and u.src[1].op is Ops.BIND) +# TODO: delete this once CONST has a VIEW source +# currently tensor uop is VIEW(DEVICE, CONST) +def is_constant(u:UOp): return u.op is Ops.VIEW and len(u.src) == 2 and u.src[1].op in {Ops.CONST, Ops.BIND} def to_uop(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp: if (r:=cache.get(buf)) is not None: return r @@ -405,6 +406,11 @@ class UPatScheduled(UPat): # ** this is schedule level const folding +def _as_const(u:UOp, val:ConstType) -> UOp: + assert is_scheduled(u), f"must be scheduled to fold {u}" + st = (base:=ShapeTracker.from_shape(())).reshape((1,)*len(u.shape)).expand(u.shape) + return UOp(Ops.VIEW, u.dtype, (u.buf_uop, UOp.const(u.dtype, val)), base).view(st) + def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None: # remove reduce on unmasked const if all_int(x.shape) and x.is_unrealized_unmasked_const(): @@ -445,9 +451,9 @@ def replace_contiguous(ctx:ScheduleContext, alu:UOp): ops_folding = PatternMatcher([ # op with size 0 is zero - (UPatScheduled(), lambda b,to_store,base: base.const_like(0) if base.size == 0 else None), + (UPatScheduled(), lambda b,to_store,base: _as_const(base, 0) if base.size == 0 else None), # if the uop folded to a CONST we can delete the BUFFER - (UPatScheduled(Ops.CONST, name="const"), lambda b,base,const: base.const_like(const.const_arg)), + (UPatScheduled(Ops.CONST, name="const"), lambda b,base,const: base.replace(src=(UOp(Ops.DEVICE, arg=base.device), const))), # DETACH is a NOOP here (UPat(Ops.DETACH, name="detach"), lambda detach: detach.src[0]), # elementwise const folding @@ -557,7 +563,7 @@ def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp: break_sched = PatternMatcher([ # consts are always fused and generated - (UPat.cvar(name="root"), lambda root: None if root.st is None else UOp.const_with_shape(root.dtype.base, root.const_arg, root.shape)), + (UPat(Ops.VIEW, name="root", src=(UPat(), UPat.cvar())), lambda root: UOp.const_with_shape(root.dtype.base, root.const_arg, root.shape)), # values from BIND append to this schedule's var_vals (UPat(Ops.VIEW, name="st", src=(UPat(), UPat(Ops.BIND, name="bind"))), unbind_variable), # view of realized buffer just loads diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 263363b1d9..9ee0a3ba5c 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -437,10 +437,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None, src:tuple[UOp, ...]=()) -> UOp: from tinygrad.shape.shapetracker import ShapeTracker if op is Ops.CONST: - # Tensor const is a CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND + # Tensor const is a VIEW(DEVICE, CONST) -> RESHAPE -> EXPAND assert isinstance(arg, get_args(ConstType)), f"trying to create CONST with {arg=}" - return UOp.const(dtype, unwrap(arg)).replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), - ShapeTracker.from_shape(())),)).reshape((1,)*len(shape)).expand(shape) + return UOp(Ops.VIEW, dtype, (UOp(Ops.DEVICE, arg=device), UOp.const(dtype, unwrap(arg))), + ShapeTracker.from_shape(())).reshape((1,)*len(shape)).expand(shape) # TOOD: Tensor variable bindings need device and shape from sources if op is Ops.BIND: assert isinstance(arg, UOp) and arg.op is Ops.BIND and shape == (), f"trying to create BIND with {arg=} {shape=}" @@ -458,8 +458,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if not unwrap((src:=self.base).st).contiguous: raise RuntimeError(f"can only copy contiguous {self}") return UOp.metaop(Ops.COPY, src.shape, src.dtype, device, (device, clone), (src,)).view(unwrap(self.st)) def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True) - # TODO: CONST is just CONST, delete this - def is_unrealized_const(self): return self.base.op is Ops.CONST + def is_unrealized_const(self): return (s:=self.base).op is Ops.VIEW and len(s.src) == 2 and s.realized is None and s.src[1].op is Ops.CONST def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in unwrap(self.st).views) def can_view(self): return (self.st is not None and self._device is not None and self.st.consecutive and not self.is_unrealized_const() and