has_st shouldn't exist anymore [pr] (#8446)

* has_st shouldn't exist anymore [pr]

* const also shouldn't be there
This commit is contained in:
qazal
2024-12-30 12:52:18 +02:00
committed by GitHub
parent 7499139239
commit 96f2a83a62
3 changed files with 7 additions and 6 deletions

View File

@@ -685,8 +685,9 @@ class Kernel:
# the living definition of intermediate UOps
def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:dict[UOp, ShapeTracker]) -> None:
if not uop.has_st or uop in sts: return
if uop in sts: return
# restore globals from the two stage reduce
# this is because this LOAD has an implicit movement op
if uop.op is Ops.LOAD and uop.src[0].op is Ops.DEFINE_LOCAL:
_assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
sts[uop] = sts[local_reduce]

View File

@@ -187,7 +187,7 @@ def elementwise_view_right(root:UOp) -> UOp|None:
# push the swizzle from src to root
output_swizzle = swizzles[0]
new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
ret = root.replace(src=tuple(x if not x.has_st else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
# NOTE: swizzle resolves once we hit STORE
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(output_swizzle.shape))

View File

@@ -273,8 +273,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# *** uop shape stuff ***
@property
def has_st(self) -> bool: return self.op not in {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR}
@functools.cached_property
def st(self) -> ShapeTracker|None:
# these ops define a ShapeTracker from the arg
@@ -295,7 +293,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return ShapeTracker.from_shape(shape)
@functools.cached_property
def full_shape(self) -> tuple[sint, ...]:
return self.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
if self.op is Ops.VIEW: return self.shape
# TODO: this should check if st is None, it cannot because local reduce has implicit movement ops
return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL,Ops.DEFINE_VAR,Ops.CONST}]))
@property
def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
@property
@@ -1310,7 +1310,7 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, name="s0").view(name="s1"), lambda
view_left = merge_views+PatternMatcher([
# VIEW before elementwise ops
(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"),
lambda e,v: e.replace(src=tuple(s if not s.has_st else s.view(v.st) if s is s.base else s.base.view(s.st+v.st) for s in e.src))),
lambda e,v: e.replace(src=tuple(s if s.st is None else s.view(v.st) if s is s.base else s.base.view(s.st+v.st) for s in e.src))),
# early merge VIEW buffer ops
(UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
])