CONST(VIEW(DEVICE)) (#8365)

This commit is contained in:
qazal
2024-12-22 04:18:35 +02:00
committed by GitHub
parent 88bc51385c
commit 83284985f0
4 changed files with 17 additions and 22 deletions

View File

@@ -1987,7 +1987,7 @@ class TestBigGraph(unittest.TestCase):
check_schedule(x, 1)
tensor_const_pm = PatternMatcher([
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.CONST, src=()))), lambda: True),
(UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)))), lambda: True),
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST))))), lambda: True),
])
class TestConst(unittest.TestCase):

View File

@@ -3,7 +3,7 @@ from tinygrad import Tensor
from tinygrad.ops import UPat, Ops
realized_pattern = UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),))
const_pattern = UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.CONST)))
const_pattern = UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),))))
def is_pattern(ten:Tensor, pat:UPat): assert pat.match(ten.lazydata, {})
class TestTensorUopRepresentation(unittest.TestCase):

View File

@@ -5,7 +5,7 @@ from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_
from tinygrad.ops import identity_element, buffers, exec_alu, type_verify
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, ContextVar
from tinygrad.dtype import ConstType, DType, ImageDType, dtypes
from tinygrad.dtype import DType, ImageDType, dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.device import Buffer
@@ -39,6 +39,9 @@ tensor_uop_spec = PatternMatcher([
# Tensor variable bindings
(UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
# Tensor const has a ShapeTracker of shape=() and a device
(UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)),)), lambda: True),
# DETACH and CONTIGUOUS change how we interpret the source UOp
# CONTIGUOUS ensures the source UOp realizes
(UPat((Ops.DETACH, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),), arg=None), lambda root,x: root.dtype == x.dtype),
@@ -76,10 +79,6 @@ tensor_uop_spec = PatternMatcher([
# DEVICE and VIEW specify device and shape for BIND
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND))), lambda: True),
# Tensor const has a ShapeTracker of shape=() and a device
(UPat(Ops.VIEW, name="view", arg=ShapeTracker.from_shape(()), src=(UPat(Ops.DEVICE), UPat(Ops.CONST, name="const"))),
lambda view,const: view.dtype == const.dtype),
# NOTE: EMPTY just ensures the source BUFFER is allocated before children run
# TODO: this should be EMPTY(VIEW(BUFFER))
(UPat(Ops.EMPTY, src=(), arg=None), lambda: True),
@@ -128,9 +127,9 @@ class ScheduleContext:
contiguous: dict[UOp, UOp] = field(default_factory=dict) # this maps roots to places they are made contiguous
children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
# TODO: delete this once CONST has a VIEW source
# currently tensor uop is VIEW(DEVICE, CONST)
def is_constant(u:UOp): return u.op is Ops.VIEW and len(u.src) == 2 and u.src[1].op in {Ops.CONST, Ops.BIND}
# TODO: delete this once BIND has a VIEW source
# currently tensor BIND is VIEW(DEVICE, BIND) - CONST(VIEW(DEVICE)) is a prereq for this
def is_constant(u:UOp): return u.op is Ops.CONST or (u.op is Ops.VIEW and len(u.src) == 2 and u.src[1].op is Ops.BIND)
def to_uop(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
if (r:=cache.get(buf)) is not None: return r
@@ -406,11 +405,6 @@ class UPatScheduled(UPat):
# ** this is schedule level const folding
def _as_const(u:UOp, val:ConstType) -> UOp:
assert is_scheduled(u), f"must be scheduled to fold {u}"
st = (base:=ShapeTracker.from_shape(())).reshape((1,)*len(u.shape)).expand(u.shape)
return UOp(Ops.VIEW, u.dtype, (u.buf_uop, UOp.const(u.dtype, val)), base).view(st)
def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
# remove reduce on unmasked const
if all_int(x.shape) and x.is_unrealized_unmasked_const():
@@ -451,9 +445,9 @@ def replace_contiguous(ctx:ScheduleContext, alu:UOp):
ops_folding = PatternMatcher([
# op with size 0 is zero
(UPatScheduled(), lambda b,to_store,base: _as_const(base, 0) if base.size == 0 else None),
(UPatScheduled(), lambda b,to_store,base: base.const_like(0) if base.size == 0 else None),
# if the uop folded to a CONST we can delete the BUFFER
(UPatScheduled(Ops.CONST, name="const"), lambda b,base,const: base.replace(src=(UOp(Ops.DEVICE, arg=base.device), const))),
(UPatScheduled(Ops.CONST, name="const"), lambda b,base,const: base.const_like(const.const_arg)),
# DETACH is a NOOP here
(UPat(Ops.DETACH, name="detach"), lambda detach: detach.src[0]),
# elementwise const folding
@@ -563,7 +557,7 @@ def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp:
break_sched = PatternMatcher([
# consts are always fused and generated
(UPat(Ops.VIEW, name="root", src=(UPat(), UPat.cvar())), lambda root: UOp.const_with_shape(root.dtype.base, root.const_arg, root.shape)),
(UPat.cvar(name="root"), lambda root: None if root.st is None else UOp.const_with_shape(root.dtype.base, root.const_arg, root.shape)),
# values from BIND append to this schedule's var_vals
(UPat(Ops.VIEW, name="st", src=(UPat(), UPat(Ops.BIND, name="bind"))), unbind_variable),
# view of realized buffer just loads

View File

@@ -437,10 +437,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None, src:tuple[UOp, ...]=()) -> UOp:
from tinygrad.shape.shapetracker import ShapeTracker
if op is Ops.CONST:
# Tensor const is a VIEW(DEVICE, CONST) -> RESHAPE -> EXPAND
# Tensor const is a CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND
assert isinstance(arg, get_args(ConstType)), f"trying to create CONST with {arg=}"
return UOp(Ops.VIEW, dtype, (UOp(Ops.DEVICE, arg=device), UOp.const(dtype, unwrap(arg))),
ShapeTracker.from_shape(())).reshape((1,)*len(shape)).expand(shape)
return UOp.const(dtype, unwrap(arg)).replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),),
ShapeTracker.from_shape(())),)).reshape((1,)*len(shape)).expand(shape)
# TOOD: Tensor variable bindings need device and shape from sources
if op is Ops.BIND:
assert isinstance(arg, UOp) and arg.op is Ops.BIND and shape == (), f"trying to create BIND with {arg=} {shape=}"
@@ -458,7 +458,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if not unwrap((src:=self.base).st).contiguous: raise RuntimeError(f"can only copy contiguous {self}")
return UOp.metaop(Ops.COPY, src.shape, src.dtype, device, (device, clone), (src,)).view(unwrap(self.st))
def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
def is_unrealized_const(self): return (s:=self.base).op is Ops.VIEW and len(s.src) == 2 and s.realized is None and s.src[1].op is Ops.CONST
# TODO: CONST is just CONST, delete this
def is_unrealized_const(self): return self.base.op is Ops.CONST
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in unwrap(self.st).views)
def can_view(self):
return (self.st is not None and self._device is not None and self.st.consecutive and not self.is_unrealized_const() and