this is better

This commit is contained in:
George Hotz
2025-10-20 18:07:34 +08:00
parent 154b6d5901
commit ec97cec952
2 changed files with 6 additions and 6 deletions

View File

@@ -169,7 +169,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def ptrdtype(self) -> PtrDType:
if not isinstance(self.dtype, PtrDType): raise RuntimeError("ptrdtype called on UOp without PtrDType")
if not isinstance(self.dtype, PtrDType): raise RuntimeError(f"ptrdtype called on UOp with type {self.dtype}")
return self.dtype
# *** uop shape stuff ***
@@ -1178,8 +1178,6 @@ pm_lower_index_dtype = PatternMatcher([
lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))),
# TODO: this is only triggering if they are all casts, correct?
(UPat((Ops.SINK, Ops.NOOP), src=UPat().cast(dtypes.index), name="n"), lambda n: n.replace(src=tuple(s.src[0] for s in n.src))),
# TODO: this should be more general
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=tuple(y.src[0] if y.op is Ops.CAST and y.dtype.scalar()==dtypes.index else y for y in x.src))),
])
def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]

View File

@@ -377,6 +377,11 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
(UPat(GroupOp.Binary, src=(UPat.var("x", dtypes.long), UPat.var("y", dtypes.long)), name="u"), lambda u,x,y:
x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None),
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
# only RANGE/IF/STORE/KERNEL have side effects
(UPat(Ops.AFTER, name="x"), lambda x:
x.replace(src=(x.src[0],)+tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.IF, Ops.STORE, Ops.KERNEL} else y.src for y in x.src[1:]])))),
# after with 1 src is just src[0]
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
])+gep_pushing
symbolic_flat = symbolic+PatternMatcher([
@@ -554,7 +559,4 @@ sym = symbolic_flat+pm_simplify_valid+PatternMatcher([
((UPat.var("x")*UPat.cvar("c", vec=False)).reduce(arg=Ops.ADD, name="r", allow_any_len=True), lambda x,c,r: r.replace(src=(x,)+r.src[1:])*c.arg),
# reduce mul chain, move muls after the reduce
(UPat(Ops.MUL).reduce(name="r", allow_any_len=True), reduce_mul_chain),
# only RANGE/STORE/AFTER have side effects
(UPat(Ops.AFTER, name="x"), lambda x:
x.replace(src=(x.src[0],)+tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.AFTER} else y.src for y in x.src[1:]])))),
])