don't late simplify on marg

This commit is contained in:
George Hotz
2025-10-27 16:51:49 +08:00
parent f4da94af28
commit 7563e20809

View File

@@ -296,6 +296,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
with Context(TRACK_MATCH_STATS=0 if not tracked else TRACK_MATCH_STATS.value):
return graph_rewrite(self, symbolic if full_symbolic else commutative, name="simplify")
def ssimplify(self) -> UOp|ConstType: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
def sintify(self) -> sint: return self.arg if self.op is Ops.CONST else self
def _eval(self, dtype, expected_type:Type[T]) -> T:
assert self.dtype in dtype, f"eval with wrong dtype {self}"
vmin, vmax = (simple_self:=self.simplify())._min_max
@@ -484,7 +485,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
match self.op:
case Ops.CONST: return self.arg
case Ops.VCONST: return self.arg[i]
case Ops.VECTORIZE: return cast(sint, self.src[i].ssimplify())
case Ops.VECTORIZE: return self.src[i].sintify()
case _: raise RuntimeError(f"no sgep on {self.op}")
@functools.cached_property
@@ -506,7 +507,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if len(arg) == 0: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(0)))
elif all(isinstance(x, int) for x in arg): usrcs.append(UOp.const(dtypes.index.vec(len(arg)), arg))
else: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg)))
ret = UOp(op, self.dtype, (self,)+tuple(usrcs), arg if len(usrcs) == 0 else None)
if len(usrcs) == 0: ret = UOp(op, self.dtype, (self,), arg)
else: ret = UOp(op, self.dtype, (self,)+UOp.sink(*usrcs).simplify().src)
# for all movement ops, we check shape property
if ret.shape == self.shape and same_shape_noop: return self
return ret