From 3a556a7e8bfeb8631ceb719bc54fe57fca2bd778 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Tue, 24 Dec 2024 10:15:56 +0200
Subject: [PATCH] fully local tensor const representation: CONST(VIEW(DEVICE))
 [pr] (#8389)

---
 test/test_const_folding.py                  |  2 +-
 test/test_schedule.py                       |  2 +-
 test/unit/test_tensor_uop_representation.py |  5 ++--
 tinygrad/engine/schedule.py                 | 33 +++++++--------------
 tinygrad/ops.py                             | 12 +++++---
 5 files changed, 24 insertions(+), 30 deletions(-)

diff --git a/test/test_const_folding.py b/test/test_const_folding.py
index 287bac476f..1366853486 100644
--- a/test/test_const_folding.py
+++ b/test/test_const_folding.py
@@ -96,7 +96,7 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
   def test_literal_one_pow(self):
     _check_ast_count(0, 1 ** Tensor([1.0, 2, 3, 4]))
   # this fails because of DETACH, it shouldn't
-  @unittest.expectedFailure
+  # update: passes after CONST(VIEW(DEVICE)) in tensor
   def test_tensor_one_pow(self):
     _check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
 
diff --git a/test/test_schedule.py b/test/test_schedule.py
index dadca5b2c1..620d65fb77 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -1987,7 +1987,7 @@ class TestBigGraph(unittest.TestCase):
     check_schedule(x, 1)
 
 tensor_const_pm = PatternMatcher([
-  (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.CONST, src=()))), lambda: True),
+  (UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)),)), lambda: True),
   (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST))))), lambda: True),
 ])
 class TestConst(unittest.TestCase):
diff --git a/test/unit/test_tensor_uop_representation.py b/test/unit/test_tensor_uop_representation.py
index 999f22b39d..d9d2b48a22 100644
--- a/test/unit/test_tensor_uop_representation.py
+++ b/test/unit/test_tensor_uop_representation.py
@@ -3,7 +3,7 @@ from tinygrad import Tensor
 from tinygrad.ops import UPat, Ops
 
 realized_pattern = UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),))
-const_pattern = UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.CONST)))
+const_pattern = UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),),)))
 def is_pattern(ten:Tensor, pat:UPat): assert pat.match(ten.lazydata, {})
 
 class TestTensorUopRepresentation(unittest.TestCase):
@@ -22,7 +22,8 @@ class TestTensorUopRepresentation(unittest.TestCase):
   def test_const_pattern(self):
     a = Tensor(1)
     print(a.lazydata)
-    is_pattern(a, const_pattern)
+    is_pattern(a, const_pattern) # const in tensor has a DEVICE and VIEW src
+    is_pattern(a, UPat.cvar("x")) # even cvar works!
 
   def test_consts_do_not_realize(self):
     a = Tensor(1)
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 9d2f5914c5..ea94f853ef 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -5,7 +5,7 @@ from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_
 from tinygrad.ops import identity_element, buffers, exec_alu, type_verify
 from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
 from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, ContextVar
-from tinygrad.dtype import ConstType, DType, ImageDType, dtypes
+from tinygrad.dtype import DType, ImageDType, dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View, strides_for_shape
 from tinygrad.device import Buffer
@@ -39,6 +39,9 @@ tensor_uop_spec = PatternMatcher([
   # Tensor variable bindings
   (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
 
+  # Tensor const has a ShapeTracker of shape=() and a device
+  (UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)),)), lambda: True),
+
   # DETACH and CONTIGUOUS change how we interpret the source UOp
   # CONTIGUOUS ensures the source UOp realizes
   (UPat((Ops.DETACH, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),), arg=None), lambda root,x: root.dtype == x.dtype),
@@ -73,10 +76,6 @@ tensor_uop_spec = PatternMatcher([
   # DEVICE and VIEW specify device and shape for BIND
   (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND))), lambda: True),
 
-  # Tensor const has a ShapeTracker of shape=() and a device
-  (UPat(Ops.VIEW, name="view", arg=ShapeTracker.from_shape(()), src=(UPat(Ops.DEVICE), UPat(Ops.CONST, name="const"))),
-   lambda view,const: view.dtype == const.dtype),
-
   # NOTE: EMPTY just ensures the source BUFFER is allocated before children run
   # TODO: this should be EMPTY(VIEW(BUFFER))
   (UPat(Ops.EMPTY, src=(), arg=None), lambda: True),
@@ -127,7 +126,7 @@ class ScheduleContext:
 
 # TODO: delete this once CONST has a VIEW source
 # currently tensor uop is VIEW(DEVICE, CONST)
-def is_constant(u:UOp): return u.op is Ops.VIEW and len(u.src) == 2 and u.src[1].op in {Ops.CONST, Ops.BIND}
+def is_constant(u:UOp): return u.op is Ops.CONST or (u.op is Ops.VIEW and len(u.src) == 2 and u.src[1].op is Ops.BIND)
 
 def to_uop(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
   if (r:=cache.get(buf)) is not None: return r
@@ -403,11 +402,6 @@ class UPatScheduled(UPat):
 
 # ** this is schedule level const folding
 
-def _as_const(u:UOp, val:ConstType) -> UOp:
-  assert is_scheduled(u), f"must be scheduled to fold {u}"
-  st = (base:=ShapeTracker.from_shape(())).reshape((1,)*len(u.shape)).expand(u.shape)
-  return UOp(Ops.VIEW, u.dtype, (u.buf_uop, UOp.const(u.dtype, val)), base).view(st)
-
 def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
   # remove reduce on unmasked const
   if all_int(x.shape) and x.is_unrealized_unmasked_const():
@@ -448,9 +442,9 @@ def replace_contiguous(ctx:ScheduleContext, alu:UOp):
 
 ops_folding = PatternMatcher([
   # op with size 0 is zero
-  (UPatScheduled(), lambda b,to_store,base: _as_const(base, 0) if base.size == 0 else None),
+  (UPatScheduled(), lambda b,to_store,base: base.const_like(0) if base.size == 0 else None),
   # if the uop folded to a CONST we can delete the BUFFER
-  (UPatScheduled(Ops.CONST, name="const"), lambda b,base,const: base.replace(src=(UOp(Ops.DEVICE, arg=base.device), const))),
+  (UPatScheduled(Ops.CONST, name="const"), lambda b,base,const: base.const_like(const.const_arg)),
   # DETACH is a NOOP here
   (UPat(Ops.DETACH, name="detach"), lambda detach: detach.src[0]),
   # elementwise const folding
@@ -543,14 +537,9 @@ do_realize = PatternMatcher([
 
 # **** rewrite VIEW into LOAD/STORE/VALID or fuse the underlying UOp
 
-def generate_const(x:UOp, st:UOp):
-  # NOTE: masked VIEW stacks on top of the CONST, this is required for const folding correctness
-  assert all(v.mask is None for v in unwrap(st.st).views), f"ShapeTracker of CONST must be unmasked, got {st}"
-  return UOp(Ops.VALID, dtypes.bool, (unwrap(st.st).to_uop(),)).where(x.replace(dtype=x.dtype.base), 0)
-
 def unbind_variable(ctx:ScheduleContext, bind:UOp, st:UOp):
   ctx.var_vals.update([bind.unbind()])
-  return generate_const(UOp.const(bind.dtype, bind), st)
+  return UOp.const(bind.dtype, bind).valid(unwrap(st.st))
 
 def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
   assert st.size == b.size and unwrap(st.st).contiguous, f"ShapeTracker of realized {b} BUFFER must match the BUFFER size {st}"
@@ -565,7 +554,7 @@ def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
 
 break_sched = PatternMatcher([
   # CONST is always fused and generated
-  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE), UPat(Ops.CONST, name="x"))), generate_const),
+  (UPat(Ops.CONST, name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: UOp.const(x.dtype.base, x.const_arg).valid(st.st)),
   (UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE), UPat(Ops.BIND, name="bind"))), unbind_variable),
   # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
   (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
@@ -592,8 +581,8 @@ remove_movement_ops = PatternMatcher([
   (UPat(Ops.VIEW, src=(UPat(Ops.VIEW, src=(UPat.var("x"),), name="v1")), name="v2"),
    lambda x,v1,v2: v1.replace(arg=v1.arg+v2.arg) if x.op is not Ops.BUFFER else None),
   # merge unmasked const views
-  (UPat(Ops.VIEW, src=(UPat(Ops.VIEW, src=(UPat(), UPat(Ops.CONST)), name="v1")), name="v2"),
-   lambda v1,v2: v1.replace(arg=v1.arg+v2.arg) if all(v.mask is None for v in v2.st.views) else None),
+  (UPat(Ops.VIEW, name="view", src=(UPat(Ops.CONST, name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
+   lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
 ])
 
 @track_rewrites(named=True)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 1fd18b321d..a08b751f5d 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -391,6 +391,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
     if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
     return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
+  def valid(self, st:ShapeTracker):
+    assert self.op in {Ops.CONST, Ops.DEFINE_VAR}, f"can only create VALID from a constant, got {self.op}"
+    return UOp(Ops.VALID, dtypes.bool, (st.to_uop(),)).where(self, 0)
   @staticmethod
   def range(dtype:DType, start:sint, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(start), sint_to_uop(end)), arg=idx)
   def _reduce_op(self, op:Ops, axis:tuple[int, ...]):
@@ -436,10 +439,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None, src:tuple[UOp, ...]=()) -> UOp:
     from tinygrad.shape.shapetracker import ShapeTracker
     if op is Ops.CONST:
-      # Tensor const is a VIEW(DEVICE, CONST) -> RESHAPE -> EXPAND
+      # Tensor const is CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND
       assert isinstance(arg, get_args(ConstType)), f"trying to create CONST with {arg=}"
-      return UOp(Ops.VIEW, dtype, (UOp(Ops.DEVICE, arg=device), UOp.const(dtype, unwrap(arg))),
-        ShapeTracker.from_shape(())).reshape((1,)*len(shape)).expand(shape)
+      return UOp.const(dtype, unwrap(arg)).replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),),
+        ShapeTracker.from_shape(())),)).reshape((1,)*len(shape)).expand(shape)
     # TOOD: Tensor variable bindings need device and shape from sources
     if op is Ops.BIND:
       assert isinstance(arg, UOp) and arg.op is Ops.BIND and shape == (), f"trying to create BIND with {arg=} {shape=}"
@@ -457,7 +460,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if not unwrap((src:=self.base).st).contiguous: raise RuntimeError(f"can only copy contiguous {self}")
     return UOp.metaop(Ops.COPY, src.shape, src.dtype, device, (device, clone), (src,)).view(unwrap(self.st))
   def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
-  def is_unrealized_const(self): return (s:=self.base).op is Ops.VIEW and len(s.src) == 2 and s.realized is None and s.src[1].op is Ops.CONST
+  # TOOD: checking op is shorter, delete this.
+  def is_unrealized_const(self): return self.base.op is Ops.CONST
   def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in unwrap(self.st).views)
   def can_view(self): return (self.st is not None and self._device is not None and self.st.consecutive and not self.is_unrealized_const() and
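
Reviewer note, not part of the commit: a minimal sketch of how the new representation can be checked from the Tensor side, assuming a tinygrad checkout that includes this patch. It reuses only the imports and the const_pattern UPat shown in the updated test/unit/test_tensor_uop_representation.py above; no new API is assumed.

# sketch: after this patch a scalar tensor's uop is CONST(VIEW(DEVICE)) instead of VIEW(DEVICE, CONST)
from tinygrad import Tensor
from tinygrad.ops import UPat, Ops

# the same pattern the updated unit test uses for the new representation
const_pattern = UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)),))

a = Tensor(1)
print(a.lazydata)                            # a CONST uop whose single source is a VIEW of a DEVICE
assert const_pattern.match(a.lazydata, {})   # same check as test_const_pattern
assert UPat.cvar("x").match(a.lazydata, {})  # a plain cvar matches too, as the new test asserts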