mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
this is better
This commit is contained in:
@@ -169,7 +169,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
|
||||
@property
|
||||
def ptrdtype(self) -> PtrDType:
|
||||
if not isinstance(self.dtype, PtrDType): raise RuntimeError("ptrdtype called on UOp without PtrDType")
|
||||
if not isinstance(self.dtype, PtrDType): raise RuntimeError(f"ptrdtype called on UOp with type {self.dtype}")
|
||||
return self.dtype
|
||||
|
||||
# *** uop shape stuff ***
|
||||
@@ -1178,8 +1178,6 @@ pm_lower_index_dtype = PatternMatcher([
|
||||
lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))),
|
||||
# TODO: this is only triggering if they are all casts, correct?
|
||||
(UPat((Ops.SINK, Ops.NOOP), src=UPat().cast(dtypes.index), name="n"), lambda n: n.replace(src=tuple(s.src[0] for s in n.src))),
|
||||
# TODO: this should be more general
|
||||
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=tuple(y.src[0] if y.op is Ops.CAST and y.dtype.scalar()==dtypes.index else y for y in x.src))),
|
||||
])
|
||||
def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
|
||||
|
||||
|
||||
@@ -377,6 +377,11 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
(UPat(GroupOp.Binary, src=(UPat.var("x", dtypes.long), UPat.var("y", dtypes.long)), name="u"), lambda u,x,y:
|
||||
x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None),
|
||||
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
|
||||
# only RANGE/IF/STORE/KERNEL have side effects
|
||||
(UPat(Ops.AFTER, name="x"), lambda x:
|
||||
x.replace(src=(x.src[0],)+tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.IF, Ops.STORE, Ops.KERNEL} else y.src for y in x.src[1:]])))),
|
||||
# after with 1 src is just src[0]
|
||||
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
|
||||
])+gep_pushing
|
||||
|
||||
symbolic_flat = symbolic+PatternMatcher([
|
||||
@@ -554,7 +559,4 @@ sym = symbolic_flat+pm_simplify_valid+PatternMatcher([
|
||||
((UPat.var("x")*UPat.cvar("c", vec=False)).reduce(arg=Ops.ADD, name="r", allow_any_len=True), lambda x,c,r: r.replace(src=(x,)+r.src[1:])*c.arg),
|
||||
# reduce mul chain, move muls after the reduce
|
||||
(UPat(Ops.MUL).reduce(name="r", allow_any_len=True), reduce_mul_chain),
|
||||
# only RANGE/STORE/AFTER have side effects
|
||||
(UPat(Ops.AFTER, name="x"), lambda x:
|
||||
x.replace(src=(x.src[0],)+tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.AFTER} else y.src for y in x.src[1:]])))),
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user