give define global and friends a shape (#11502)

* give define global and friends a shape

* ignore negative size

* ptx fix
This commit is contained in:
George Hotz
2025-08-04 19:09:39 -07:00
committed by GitHub
parent 83385e7abc
commit 7f6acfb0d5
2 changed files with 13 additions and 4 deletions

View File

@@ -156,7 +156,8 @@ merge_views = PatternMatcher([
# replace MovementOps with VIEW
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.base.view(mop.st)),
# remove NOOP views
(UPat.var("x").view(name="view"), lambda x,view: x if x.st is not None and view.st.contiguous and view.shape == x.shape else None),
(UPat.var("x").view(name="view"),
lambda x,view: x if x.st is not None and x.op not in GroupOp.Defines and view.st.contiguous and view.shape == x.shape else None),
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}).view(name="view"),
lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
# only unmasked VIEW on CONST replaces the ShapeTracker

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
from enum import Enum, auto
from tinygrad.uop import Ops, GroupOp
from tinygrad.uop.mathtraits import MathTrait
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten
from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey
if TYPE_CHECKING:
@@ -150,7 +150,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# BUFFER/BUFFER_VIEW and KERNEL only have a size
if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return ShapeTracker.from_shape((self.size,))
if self.op is Ops.KERNEL: return ShapeTracker.from_shape((self.arg.ast.size,))
#if self.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}: return ShapeTracker.from_shape((self.dtype.size,))
if self.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
sz = cast(PtrDType, self.dtype).size
return ShapeTracker.from_shape((sz,)) if sz > 0 else None
# hack for PTX, CASTing the ptr loses the shape
if self.op is Ops.CAST and self.src[0].op is Ops.DEFINE_GLOBAL: return None
# otherwise we get the shape from sources
if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
@@ -171,7 +176,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
parent_shapes = [x.full_shape for x in self.src]
return tuple(smax(x) for x in itertools.zip_longest(*parent_shapes, fillvalue=1))
@property
def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
def shape(self) -> tuple[sint, ...]:
assert self.st is not None, f"{self.op} doesn't have a shape"
return unwrap(self.st).shape
@property
def size(self) -> int: return self.arg[0] if self.op is Ops.BUFFER_VIEW else self.arg if self.op is Ops.BUFFER else unwrap(self.st).size
@@ -635,6 +642,7 @@ class UPat(MathTrait):
def const(dtype:DType|tuple[DType, ...]|None, b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
# copied from UOp
def sink(self, *srcs:UPat|None, **kwargs): return UPat(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def index(self, idx:UPat, valid:UPat|None=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
def cast(self, dtype=None, **kwargs): return UPat(Ops.CAST, dtype, (self,), **kwargs)