CONST(VIEW(DEVICE)) (#8365)
This commit changes the tensor-const representation from VIEW(DEVICE, CONST) to CONST(VIEW(DEVICE)): the CONST UOp now takes the shape-() device VIEW as its source, instead of being a source of that VIEW.
@@ -1987,7 +1987,7 @@ class TestBigGraph(unittest.TestCase):
     check_schedule(x, 1)

 tensor_const_pm = PatternMatcher([
-  (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.CONST, src=()))), lambda: True),
+  (UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)))), lambda: True),
   (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST))))), lambda: True),
 ])
 class TestConst(unittest.TestCase):
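A minimal usage sketch (not part of the diff): PatternMatcher.rewrite returns the matched rule's result (True here) or None, and .base strips the reshape/expand views sitting on top of the const, so a matcher like tensor_const_pm can assert a tensor's uop is a recognized const form:

from tinygrad import Tensor

# hypothetical check in the spirit of these tests
assert tensor_const_pm.rewrite(Tensor.ones(4, 4).lazydata.base) is True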
@@ -3,7 +3,7 @@ from tinygrad import Tensor
 from tinygrad.ops import UPat, Ops

 realized_pattern = UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),))
-const_pattern = UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.CONST)))
+const_pattern = UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),))))
 def is_pattern(ten:Tensor, pat:UPat): assert pat.match(ten.lazydata, {})

 class TestTensorUopRepresentation(unittest.TestCase):
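As a usage note (a sketch in the spirit of this test file, not from the diff), is_pattern matches ten.lazydata directly, which works for a fresh scalar const since a shape-() tensor carries no reshape/expand views on top of the CONST:

from tinygrad import Tensor

# assumes const_pattern and is_pattern from above are in scope
is_pattern(Tensor(1.0), const_pattern)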
@@ -5,7 +5,7 @@ from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_
 from tinygrad.ops import identity_element, buffers, exec_alu, type_verify
 from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
 from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, ContextVar
-from tinygrad.dtype import ConstType, DType, ImageDType, dtypes
+from tinygrad.dtype import DType, ImageDType, dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View, strides_for_shape
 from tinygrad.device import Buffer
@@ -39,6 +39,9 @@ tensor_uop_spec = PatternMatcher([
   # Tensor variable bindings
   (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),

+  # Tensor const has a ShapeTracker of shape=() and a device
+  (UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)),)), lambda: True),
+
   # DETACH and CONTIGUOUS change how we interpret the source UOp
   # CONTIGUOUS ensures the source UOp realizes
   (UPat((Ops.DETACH, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),), arg=None), lambda root,x: root.dtype == x.dtype),
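To make the accepted tree concrete, a hand-built sketch (not from the diff; Device.DEFAULT stands in for a concrete device name) of the CONST(VIEW(DEVICE)) uop this spec entry validates, assembled the same way UOp.metaop does below:

from tinygrad import Device
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops, UPat
from tinygrad.shape.shapetracker import ShapeTracker

device = UOp(Ops.DEVICE, arg=Device.DEFAULT)
view = UOp(Ops.VIEW, dtypes.void, (device,), ShapeTracker.from_shape(()))  # shape=() ShapeTracker
const = UOp.const(dtypes.float, 1.0).replace(src=(view,))
assert UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)),)).match(const, {})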
@@ -76,10 +79,6 @@ tensor_uop_spec = PatternMatcher([
   # DEVICE and VIEW specify device and shape for BIND
   (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND))), lambda: True),

-  # Tensor const has a ShapeTracker of shape=() and a device
-  (UPat(Ops.VIEW, name="view", arg=ShapeTracker.from_shape(()), src=(UPat(Ops.DEVICE), UPat(Ops.CONST, name="const"))),
-   lambda view,const: view.dtype == const.dtype),
-
   # NOTE: EMPTY just ensures the source BUFFER is allocated before children run
   # TODO: this should be EMPTY(VIEW(BUFFER))
   (UPat(Ops.EMPTY, src=(), arg=None), lambda: True),
@@ -128,9 +127,9 @@ class ScheduleContext:
   contiguous: dict[UOp, UOp] = field(default_factory=dict)  # this maps roots to places they are made contiguous
   children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))

-# TODO: delete this once CONST has a VIEW source
-# currently tensor uop is VIEW(DEVICE, CONST)
-def is_constant(u:UOp): return u.op is Ops.VIEW and len(u.src) == 2 and u.src[1].op in {Ops.CONST, Ops.BIND}
+# TODO: delete this once BIND has a VIEW source
+# currently tensor BIND is VIEW(DEVICE, BIND) - CONST(VIEW(DEVICE)) is a prereq for this
+def is_constant(u:UOp): return u.op is Ops.CONST or (u.op is Ops.VIEW and len(u.src) == 2 and u.src[1].op is Ops.BIND)

 def to_uop(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
   if (r:=cache.get(buf)) is not None: return r
@@ -406,11 +405,6 @@ class UPatScheduled(UPat):

 # ** this is schedule level const folding

-def _as_const(u:UOp, val:ConstType) -> UOp:
-  assert is_scheduled(u), f"must be scheduled to fold {u}"
-  st = (base:=ShapeTracker.from_shape(())).reshape((1,)*len(u.shape)).expand(u.shape)
-  return UOp(Ops.VIEW, u.dtype, (u.buf_uop, UOp.const(u.dtype, val)), base).view(st)
-
 def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
   # remove reduce on unmasked const
   if all_int(x.shape) and x.is_unrealized_unmasked_const():
@@ -451,9 +445,9 @@ def replace_contiguous(ctx:ScheduleContext, alu:UOp):

 ops_folding = PatternMatcher([
   # op with size 0 is zero
-  (UPatScheduled(), lambda b,to_store,base: _as_const(base, 0) if base.size == 0 else None),
+  (UPatScheduled(), lambda b,to_store,base: base.const_like(0) if base.size == 0 else None),
   # if the uop folded to a CONST we can delete the BUFFER
-  (UPatScheduled(Ops.CONST, name="const"), lambda b,base,const: base.replace(src=(UOp(Ops.DEVICE, arg=base.device), const))),
+  (UPatScheduled(Ops.CONST, name="const"), lambda b,base,const: base.const_like(const.const_arg)),
   # DETACH is a NOOP here
   (UPat(Ops.DETACH, name="detach"), lambda detach: detach.src[0]),
   # elementwise const folding
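Both folds now go through UOp.const_like, which rebuilds the value as a proper tensor const instead of hand-assembling a VIEW that still references the buffer. A rough illustration, assuming const_like keeps shape/dtype/device for tensor-level uops as its use here implies (not code from the diff):

from tinygrad import Tensor
from tinygrad.ops import Ops

x = Tensor.empty(2, 2).lazydata
c = x.const_like(0)  # same shape/dtype/device as x, but backed by a CONST, no BUFFER
assert c.base.op is Ops.CONST and c.shape == x.shape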
@@ -563,7 +557,7 @@ def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp:

 break_sched = PatternMatcher([
   # consts are always fused and generated
-  (UPat(Ops.VIEW, name="root", src=(UPat(), UPat.cvar())), lambda root: UOp.const_with_shape(root.dtype.base, root.const_arg, root.shape)),
+  (UPat.cvar(name="root"), lambda root: None if root.st is None else UOp.const_with_shape(root.dtype.base, root.const_arg, root.shape)),
   # values from BIND append to this schedule's var_vals
   (UPat(Ops.VIEW, name="st", src=(UPat(), UPat(Ops.BIND, name="bind"))), unbind_variable),
   # view of realized buffer just loads
@@ -437,10 +437,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None, src:tuple[UOp, ...]=()) -> UOp:
     from tinygrad.shape.shapetracker import ShapeTracker
     if op is Ops.CONST:
-      # Tensor const is a VIEW(DEVICE, CONST) -> RESHAPE -> EXPAND
+      # Tensor const is a CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND
       assert isinstance(arg, get_args(ConstType)), f"trying to create CONST with {arg=}"
-      return UOp(Ops.VIEW, dtype, (UOp(Ops.DEVICE, arg=device), UOp.const(dtype, unwrap(arg))),
-                 ShapeTracker.from_shape(())).reshape((1,)*len(shape)).expand(shape)
+      return UOp.const(dtype, unwrap(arg)).replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),),
+                                                          ShapeTracker.from_shape(())),)).reshape((1,)*len(shape)).expand(shape)
     # TODO: Tensor variable bindings need device and shape from sources
     if op is Ops.BIND:
       assert isinstance(arg, UOp) and arg.op is Ops.BIND and shape == (), f"trying to create BIND with {arg=} {shape=}"
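A quick sketch of what the new CONST path produces (not from the diff; Device.DEFAULT is a stand-in for the device argument):

from tinygrad import Device
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops

u = UOp.metaop(Ops.CONST, (2, 2), dtypes.float, Device.DEFAULT, arg=1.0)
# the base is the CONST; its single source is the shape-() VIEW of the DEVICE
assert u.base.op is Ops.CONST
assert u.base.src[0].op is Ops.VIEW and u.base.src[0].src[0].op is Ops.DEVICE
assert u.shape == (2, 2)  # reshape/expand then give the requested shape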
@@ -458,7 +458,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if not unwrap((src:=self.base).st).contiguous: raise RuntimeError(f"can only copy contiguous {self}")
     return UOp.metaop(Ops.COPY, src.shape, src.dtype, device, (device, clone), (src,)).view(unwrap(self.st))
   def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
-  def is_unrealized_const(self): return (s:=self.base).op is Ops.VIEW and len(s.src) == 2 and s.realized is None and s.src[1].op is Ops.CONST
+  # TODO: CONST is just CONST, delete this
+  def is_unrealized_const(self): return self.base.op is Ops.CONST
   def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in unwrap(self.st).views)
   def can_view(self):
     return (self.st is not None and self._device is not None and self.st.consecutive and not self.is_unrealized_const() and
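For illustration (a sketch assuming realize swaps the base to a buffer-backed uop; not part of the diff):

from tinygrad import Tensor

assert Tensor.full((3,), 2.0).lazydata.is_unrealized_const()          # base is a CONST
assert not Tensor.empty(3).realize().lazydata.is_unrealized_const()   # base is a realized buffer view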