Mirror of https://github.com/tinygrad/tinygrad.git, last synced 2026-04-07 03:00:26 -04:00.
conceptual uop st cleanup [pr] (#7956)
* conceptual uop st cleanup [pr] * unwrap is fine here, better than arg
This commit is contained in:
@@ -263,11 +263,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def has_st(self) -> bool:
  """True when this UOp carries shape information (a ShapeTracker can be derived for it)."""
  # These ops are shapeless by construction: raw buffer/variable definitions and scalar constants.
  shapeless = {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR}
  return self.op not in shapeless
|
||||
@functools.cached_property
def st(self) -> Optional[ShapeTracker]:
  """Return the ShapeTracker describing this UOp's shape, or None if it has none.

  Buffer ops and VIEW store their tracker directly; every other op derives a
  contiguous tracker from the (shape-matching) trackers of its sources.
  """
  # shapeless ops (DEFINE_*, BUFFER, CONST) have no tracker at all
  if not self.has_st: return None
  # buffer ops carry their (possibly non contiguous) shapetracker in the arg
  if self.op in GroupOp.Buffer: return self.st_arg
  if self.op is Ops.VIEW: return self.arg
  # derive from sources; an op whose sources are all shapeless has no tracker either
  if len(src_sts := [x.st for x in self.src if x.st is not None]) == 0: return None
  assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
  # all other ops have a contiguous shapetracker; REDUCE_AXIS collapses the reduced axes first
  # local import avoids a circular dependency between ops and shapetracker modules
  from tinygrad.shape.shapetracker import ShapeTracker
  return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg) if self.op is Ops.REDUCE_AXIS else src_sts[0].shape)
|
||||
@functools.cached_property
|
||||
|
||||
Reference in New Issue
Block a user