mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
give define global and friends a shape (#11502)
* give define global and friends a shape * ignore negative size * ptx fix
This commit is contained in:
@@ -156,7 +156,8 @@ merge_views = PatternMatcher([
|
||||
# replace MovementOps with VIEW
|
||||
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.base.view(mop.st)),
|
||||
# remove NOOP views
|
||||
(UPat.var("x").view(name="view"), lambda x,view: x if x.st is not None and view.st.contiguous and view.shape == x.shape else None),
|
||||
(UPat.var("x").view(name="view"),
|
||||
lambda x,view: x if x.st is not None and x.op not in GroupOp.Defines and view.st.contiguous and view.shape == x.shape else None),
|
||||
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}).view(name="view"),
|
||||
lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
|
||||
# only unmaksed VIEW on CONST replaces the ShapeTracker
|
||||
|
||||
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
|
||||
from enum import Enum, auto
|
||||
from tinygrad.uop import Ops, GroupOp
|
||||
from tinygrad.uop.mathtraits import MathTrait
|
||||
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
|
||||
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType
|
||||
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten
|
||||
from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey
|
||||
if TYPE_CHECKING:
|
||||
@@ -150,7 +150,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
# BUFFER/BUFFER_VIEW and KERNEL only have a size
|
||||
if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return ShapeTracker.from_shape((self.size,))
|
||||
if self.op is Ops.KERNEL: return ShapeTracker.from_shape((self.arg.ast.size,))
|
||||
#if self.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}: return ShapeTracker.from_shape((self.dtype.size,))
|
||||
if self.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
|
||||
sz = cast(PtrDType, self.dtype).size
|
||||
return ShapeTracker.from_shape((sz,)) if sz > 0 else None
|
||||
|
||||
# hack for PTX, CASTing the ptr loses the shape
|
||||
if self.op is Ops.CAST and self.src[0].op is Ops.DEFINE_GLOBAL: return None
|
||||
|
||||
# otherwise we get the shape from sources
|
||||
if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
|
||||
@@ -171,7 +176,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
parent_shapes = [x.full_shape for x in self.src]
|
||||
return tuple(smax(x) for x in itertools.zip_longest(*parent_shapes, fillvalue=1))
|
||||
@property
|
||||
def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
|
||||
def shape(self) -> tuple[sint, ...]:
|
||||
assert self.st is not None, f"{self.op} doesn't have a shape"
|
||||
return unwrap(self.st).shape
|
||||
@property
|
||||
def size(self) -> int: return self.arg[0] if self.op is Ops.BUFFER_VIEW else self.arg if self.op is Ops.BUFFER else unwrap(self.st).size
|
||||
|
||||
@@ -635,6 +642,7 @@ class UPat(MathTrait):
|
||||
def const(dtype:DType|tuple[DType, ...]|None, b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
|
||||
|
||||
# copied from UOp
|
||||
def sink(self, *srcs:UPat|None, **kwargs): return UPat(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def index(self, idx:UPat, valid:UPat|None=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
||||
def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
|
||||
def cast(self, dtype=None, **kwargs): return UPat(Ops.CAST, dtype, (self,), **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user