pyrender Ops.SPECIAL and use correct dtype for Ops.RANGE rendering (#12931)

This commit is contained in:
Sieds Lykles
2025-10-27 03:21:34 +01:00
committed by GitHub
parent 8c1368cab6
commit eaeaea2f9c

View File

@@ -376,10 +376,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
return ret
@staticmethod
def range(end:sint, *arg):
def range(end:sint, *arg, dtype=dtypes.index):
if len(arg) == 0: raise RuntimeError("range needs an arg")
if len(arg) == 1: arg = arg+(AxisType.LOOP,)
return UOp(Ops.RANGE, dtype=dtypes.index, src=(sint_to_uop(end),), arg=arg)
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=arg)
@staticmethod
def special(end:sint, name:str, dtype=dtypes.index): return UOp(Ops.SPECIAL, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=name)
def r(self, op:Ops, axis:tuple[int, ...]):
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis)) if len(axis) else self
@@ -1161,7 +1163,7 @@ def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, na
for k,v in input_map.items(): new_map[k] = new_map.get(v,v)
return new_map
def sint_to_uop(x:sint) -> UOp: return UOp.const(dtypes.index, x) if isinstance(x, int) else x.cast(dtypes.index)
def sint_to_uop(x:sint, dtype=dtypes.index) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x.cast(dtype)
def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count)
pm_lower_index_dtype = PatternMatcher([
@@ -1238,8 +1240,11 @@ pm_pyrender = PatternMatcher([
(UPat(Ops.BITCAST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.bitcast({x.dtype})")),
(UPat({Ops.MAX, Ops.THREEFRY, Ops.CMPLT, Ops.CMPNE, Ops.POW}, src=UPat(Ops.NOOP), name="x"),
lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.alu({x.op}, {x.src[1].arg})")),
(UPat(Ops.RANGE, src=(UPat(Ops.NOOP),), name="x"), lambda x:
UOp(Ops.NOOP, arg=f"UOp.range({x.src[0].arg}, {str(x.arg[0])}, {str(x.arg[1])})")),
(UPat(Ops.RANGE, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=
f"UOp.range({x.src[0].arg}, {str(x.arg[0])}, {str(x.arg[1])}{', dtype='+str(x.dtype) if x.dtype is not dtypes.index else ''})")),
(UPat(Ops.SPECIAL, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg= f"UOp.special({x.src[0].arg}, \"{x.arg}\", dtype={x.dtype})")),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: UOp(Ops.NOOP, arg=
f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.index else ''})")),
(UPat(set(sugar.keys()), src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP,
arg=f"{x.src[0].arg}.{sugar[x.op]}({', '.join([y.arg for y in x.src[1:]] + ([f'arg={str(x.arg)}'] if x.arg is not None else []))})")),
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.NOOP),), name="x"),