mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-28 08:17:58 -05:00
has_st shouldn't exist anymore [pr] (#8446)
* has_st shouldn't exist anymore [pr] * const also shouldn't be there
This commit is contained in:
@@ -685,8 +685,9 @@ class Kernel:
|
||||
# the living definition of intermediate UOps
|
||||
|
||||
def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:dict[UOp, ShapeTracker]) -> None:
|
||||
if not uop.has_st or uop in sts: return
|
||||
if uop in sts: return
|
||||
# restore globals from the two stage reduce
|
||||
# this is because this LOAD has an implicit movement op
|
||||
if uop.op is Ops.LOAD and uop.src[0].op is Ops.DEFINE_LOCAL:
|
||||
_assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
|
||||
sts[uop] = sts[local_reduce]
|
||||
|
||||
@@ -187,7 +187,7 @@ def elementwise_view_right(root:UOp) -> UOp|None:
|
||||
# push the swizzle from src to root
|
||||
output_swizzle = swizzles[0]
|
||||
new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
|
||||
ret = root.replace(src=tuple(x if not x.has_st else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
|
||||
ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
|
||||
# NOTE: swizzle resolves once we hit STORE
|
||||
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(output_swizzle.shape))
|
||||
|
||||
|
||||
@@ -273,8 +273,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
|
||||
# *** uop shape stuff ***
|
||||
|
||||
@property
|
||||
def has_st(self) -> bool: return self.op not in {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR}
|
||||
@functools.cached_property
|
||||
def st(self) -> ShapeTracker|None:
|
||||
# these ops define a ShapeTracker from the arg
|
||||
@@ -295,7 +293,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return ShapeTracker.from_shape(shape)
|
||||
@functools.cached_property
|
||||
def full_shape(self) -> tuple[sint, ...]:
|
||||
return self.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
|
||||
if self.op is Ops.VIEW: return self.shape
|
||||
# TODO: this should check if st is None, it cannot because local reduce has implicit movement ops
|
||||
return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL,Ops.DEFINE_VAR,Ops.CONST}]))
|
||||
@property
|
||||
def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
|
||||
@property
|
||||
@@ -1310,7 +1310,7 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, name="s0").view(name="s1"), lambda
|
||||
view_left = merge_views+PatternMatcher([
|
||||
# VIEW before elementwise ops
|
||||
(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"),
|
||||
lambda e,v: e.replace(src=tuple(s if not s.has_st else s.view(v.st) if s is s.base else s.base.view(s.st+v.st) for s in e.src))),
|
||||
lambda e,v: e.replace(src=tuple(s if s.st is None else s.view(v.st) if s is s.base else s.base.view(s.st+v.st) for s in e.src))),
|
||||
# early merge VIEW buffer ops
|
||||
(UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user