conceptual uop st cleanup [pr] (#7956)

* conceptual uop st cleanup [pr]

* unwrap is fine here, better than arg
This commit is contained in:
qazal
2024-11-29 06:35:46 -05:00
committed by GitHub
parent 2d11765295
commit e54ff0d3af

View File

@@ -263,11 +263,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def has_st(self) -> bool: return self.op not in {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR}
@functools.cached_property
def st(self) -> Optional[ShapeTracker]:
if not self.has_st: return None
if self.op in GroupOp.Buffer: return self.st_arg
if self.op is Ops.VIEW: return self.arg
src_sts = [x.st for x in self.src if x.st is not None]
# buffer ops can have a non contiguous shapetracker
if self.op in GroupOp.Buffer and len(src_sts:=[unwrap(x.st) for x in self.src if x.op is Ops.VIEW]) != 0: return src_sts[0]
if len(src_sts:=[x.st for x in self.src if x.st is not None]) == 0: return None
assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
# all other ops have a contiguous shapetracker
from tinygrad.shape.shapetracker import ShapeTracker
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg) if self.op is Ops.REDUCE_AXIS else src_sts[0].shape)
@functools.cached_property