mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
pyrender Ops.SPECIAL and use correct dtype for Ops.RANGE rendering (#12931)
This commit is contained in:
@@ -376,10 +376,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
|
||||
return ret
|
||||
@staticmethod
|
||||
def range(end:sint, *arg):
|
||||
def range(end:sint, *arg, dtype=dtypes.index):
|
||||
if len(arg) == 0: raise RuntimeError("range needs an arg")
|
||||
if len(arg) == 1: arg = arg+(AxisType.LOOP,)
|
||||
return UOp(Ops.RANGE, dtype=dtypes.index, src=(sint_to_uop(end),), arg=arg)
|
||||
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=arg)
|
||||
@staticmethod
|
||||
def special(end:sint, name:str, dtype=dtypes.index): return UOp(Ops.SPECIAL, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=name)
|
||||
def r(self, op:Ops, axis:tuple[int, ...]):
|
||||
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
|
||||
return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis)) if len(axis) else self
|
||||
@@ -1161,7 +1163,7 @@ def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, na
|
||||
for k,v in input_map.items(): new_map[k] = new_map.get(v,v)
|
||||
return new_map
|
||||
|
||||
def sint_to_uop(x:sint) -> UOp: return UOp.const(dtypes.index, x) if isinstance(x, int) else x.cast(dtypes.index)
|
||||
def sint_to_uop(x:sint, dtype=dtypes.index) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x.cast(dtype)
|
||||
|
||||
def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count)
|
||||
pm_lower_index_dtype = PatternMatcher([
|
||||
@@ -1238,8 +1240,11 @@ pm_pyrender = PatternMatcher([
|
||||
(UPat(Ops.BITCAST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.bitcast({x.dtype})")),
|
||||
(UPat({Ops.MAX, Ops.THREEFRY, Ops.CMPLT, Ops.CMPNE, Ops.POW}, src=UPat(Ops.NOOP), name="x"),
|
||||
lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.alu({x.op}, {x.src[1].arg})")),
|
||||
(UPat(Ops.RANGE, src=(UPat(Ops.NOOP),), name="x"), lambda x:
|
||||
UOp(Ops.NOOP, arg=f"UOp.range({x.src[0].arg}, {str(x.arg[0])}, {str(x.arg[1])})")),
|
||||
(UPat(Ops.RANGE, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=
|
||||
f"UOp.range({x.src[0].arg}, {str(x.arg[0])}, {str(x.arg[1])}{', dtype='+str(x.dtype) if x.dtype is not dtypes.index else ''})")),
|
||||
(UPat(Ops.SPECIAL, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg= f"UOp.special({x.src[0].arg}, \"{x.arg}\", dtype={x.dtype})")),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: UOp(Ops.NOOP, arg=
|
||||
f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.index else ''})")),
|
||||
(UPat(set(sugar.keys()), src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP,
|
||||
arg=f"{x.src[0].arg}.{sugar[x.op]}({', '.join([y.arg for y in x.src[1:]] + ([f'arg={str(x.arg)}'] if x.arg is not None else []))})")),
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.NOOP),), name="x"),
|
||||
|
||||
Reference in New Issue
Block a user